# run_inference.py
# Demo script: runs a TinyLlama prompt through the ASDSL native forward engine.
import time
import math
import os
os.environ["USE_TF"] = "0"
import torch
import numpy as np
from transformers import AutoModelForCausalLM, AutoTokenizer
import asdsl.kernels._native_forward as native_forward
import warnings
warnings.filterwarnings('ignore')
def main():
    """Run a fixed prompt through the ASDSL native engine and stream the output.

    Steps:
      1. Read the TinyLlama architecture from its HuggingFace config and load
         the tokenizer.
      2. Memory-map the Q4 weights described by
         ``models/tinyllama_q4_metadata.json``.
      3. Feed every prompt token except the last into the KV cache (prefill).
      4. Feed the last prompt token to obtain the first new token, then keep
         generating until EOS or 50 new tokens, printing as it goes.
      5. Print timing statistics for the decode loop.

    Raises:
        ValueError: if the tokenizer returns no tokens for the prompt.
    """
    import json

    # Only the configuration is needed on the Python side; the weights come
    # from the memory-mapped Q4 file, so the full float32 model is never built.
    from transformers import AutoConfig

    print("==================================================")
    print(" ASDSL C++ Inference Engine - Hardware Accelerated")
    print("==================================================")

    # TinyLlama fits easily in memory, uses GQA matching our C++ pipeline, and RoPE Theta 10000
    model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
    print(f"[*] Loading HuggingFace model '{model_id}'...")
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    cfg = AutoConfig.from_pretrained(model_id)

    dim = cfg.hidden_size
    hidden_dim = cfg.intermediate_size
    num_heads = cfg.num_attention_heads
    num_kv_heads = cfg.num_key_value_heads
    num_layers = cfg.num_hidden_layers
    head_dim = dim // num_heads
    vocab_size = cfg.vocab_size
    print(f"[*] Architecture: {num_layers} Layers | Dim: {dim} | Heads: {num_heads} | KV-Heads: {num_kv_heads}")

    print("[*] Memory-mapping Q4 model weights (zero-copy)...")
    with open("models/tinyllama_q4_metadata.json", "r", encoding="utf-8") as f:
        metadata = json.load(f)
    mmap_store = native_forward.MmapWeights("models/tinyllama_q4.bin", metadata)

    print("[*] Initialization complete. C++ Engine Ready.")
    print("==================================================\n")

    prompt = "The capital of France is"
    input_ids = tokenizer.encode(prompt, add_special_tokens=True)
    if not input_ids:
        raise ValueError("tokenizer produced no tokens for the prompt")

    max_seq_len = 1024
    cache = native_forward.KVCache(num_layers, max_seq_len, num_kv_heads, head_dim)

    def step(token, position):
        """Run one single-token forward pass and return the engine's result."""
        return native_forward.generate_token(
            token, position, mmap_store,
            num_layers, dim, hidden_dim, num_heads, num_kv_heads, head_dim, vocab_size,
            cache,
        )

    print(f"Prompt: '{prompt}'")
    print("\n--- PREFILL PHASE ---")
    seq_pos = 0

    # Prefill route: T>1 uses batched tiled GEMM; T=1 stays on decode GEMV.
    prefill_ids = input_ids[:-1]
    if len(prefill_ids) > 1:
        native_forward.prefill_prompt_tokens(
            np.asarray(prefill_ids, dtype=np.int32),
            seq_pos,
            mmap_store,
            num_layers,
            dim,
            hidden_dim,
            num_heads,
            num_kv_heads,
            head_dim,
            vocab_size,
            cache,
        )
    elif len(prefill_ids) == 1:
        # Result is discarded: this call only populates the KV cache.
        step(prefill_ids[0], seq_pos)
    seq_pos += len(prefill_ids)

    # The last token in the prompt generates our first new token!
    next_token = step(input_ids[-1], seq_pos)
    seq_pos += 1

    print("\n--- DECODE PHASE ---")
    # Echo the prompt, then stream each generated piece on the same line.
    print(prompt, end="", flush=True)
    print(tokenizer.decode([next_token], skip_special_tokens=True), end="", flush=True)

    tokens_generated = 1  # We already generated one token
    max_new_tokens = 50
    start_time = time.perf_counter()
    while tokens_generated < max_new_tokens:
        if next_token == tokenizer.eos_token_id:
            break
        next_token = step(next_token, seq_pos)
        seq_pos += 1
        tokens_generated += 1
        print(tokenizer.decode([next_token], skip_special_tokens=True), end="", flush=True)
    elapsed = time.perf_counter() - start_time

    # The first token was produced before the timer started, so only tokens
    # generated inside the loop count towards throughput. Guard the division
    # for the case where the loop exits immediately.
    timed_tokens = tokens_generated - 1
    tps = timed_tokens / elapsed if elapsed > 0 else 0.0

    print("\n\n==================================================")
    print(f"Total Time: {elapsed:.3f} seconds")
    print(f"Generated: {tokens_generated} tokens")
    print(f"Throughput: {tps:.2f} tok/s")
    print("==================================================")
# Script entry point: run the demo only when executed directly, not on import.
if __name__ == "__main__":
    main()