generate.py
"""Load a trained model and generate text interactively."""
import argparse
import torch
from model import TinyLLM
from tokenizer import BPETokenizer, CharTokenizer
DEVICE = (
"mps"
if torch.backends.mps.is_available()
else "cuda"
if torch.cuda.is_available()
else "cpu"
)

def load_model_and_tokenizer(
    checkpoint_path: str, tokenizer_path: str, tokenizer_type: str
):
    checkpoint = torch.load(checkpoint_path, map_location=DEVICE, weights_only=True)
    config = checkpoint["config"]

    if tokenizer_type == "char":
        tok = CharTokenizer()
    else:
        tok = BPETokenizer()
    tok.load(tokenizer_path)

    model = TinyLLM(
        vocab_size=config["vocab_size"],
        d_model=config["d_model"],
        n_heads=config["n_heads"],
        n_layers=config["n_layers"],
        d_ff=config["d_model"] * 4,
        max_seq_len=config["seq_len"],
    ).to(DEVICE)
    model.load_state_dict(checkpoint["model_state"])
    model.eval()  # inference mode: disable dropout
    model.requires_grad_(False)
    return model, tok

def generate(model, tok, prompt: str, n_tokens: int, temperature: float, top_k: int):
    ids = tok.encode(prompt)
    x = torch.tensor([ids], dtype=torch.long, device=DEVICE)
    out = model.generate(x, max_new_tokens=n_tokens, temperature=temperature, top_k=top_k)
    return tok.decode(out[0].tolist())

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--checkpoint", default="checkpoints/char_best.pt")
    parser.add_argument("--tokenizer-path", default="checkpoints/char_tokenizer.json")
    parser.add_argument("--tokenizer-type", choices=["char", "bpe"], default="char")
    parser.add_argument("--prompt", default="Once upon a midnight dreary")
    parser.add_argument("--tokens", type=int, default=500)
    parser.add_argument("--temperature", type=float, default=0.8)
    parser.add_argument("--top-k", type=int, default=40)
    args = parser.parse_args()

    model, tok = load_model_and_tokenizer(
        args.checkpoint, args.tokenizer_path, args.tokenizer_type
    )
    text = generate(
        model, tok, args.prompt, args.tokens, args.temperature, args.top_k
    )
    print(text)
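
# Minimal example invocation, assuming a trained checkpoint and tokenizer
# already exist at the default paths above:
#
#   python generate.py --prompt "Once upon a midnight dreary" --tokens 200 --top-k 40
#
# The flag values shown are illustrative; any of the argparse options defined
# in __main__ can be overridden on the command line.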