-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy patheval_basil.py
More file actions
executable file
·92 lines (76 loc) · 2.94 KB
/
eval_basil.py
File metadata and controls
executable file
·92 lines (76 loc) · 2.94 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
# eval_basil.py
import os, json, readline, argparse
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from config import MODELS_DIR, IDENTITY_FILE, get_latest_basil_model_path
# ---- Command-line interface -------------------------------------------------
# One optional flag chooses which saved checkpoint to evaluate; without it the
# most recent model is used (resolved later from config helpers).
parser = argparse.ArgumentParser(
    description="Tutor-only evaluation of a Basil model. Load latest model or a specific version."
)
parser.add_argument(
    "--version", type=str,
    help="Specific Basil version to load (directory will be models/basil_<version>), e.g. v003",
)
args = parser.parse_args()
# ---- Checkpoint resolution --------------------------------------------------
# An explicit --version maps to models/basil_<version>; otherwise fall back to
# whatever config reports as the latest trained Basil model.
BASIL_MODEL_NAME = (
    os.path.join(MODELS_DIR, f"basil_{args.version}")
    if args.version
    else get_latest_basil_model_path()
)
# ---- Model & tokenizer loading ---------------------------------------------
tokenizer = AutoTokenizer.from_pretrained(BASIL_MODEL_NAME)
if tokenizer.pad_token is None:
    # GPT-style tokenizers ship with no pad token; reuse EOS so padding and
    # generate()'s pad_token_id are well-defined.
    tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(BASIL_MODEL_NAME)
model.eval()  # inference only: disables dropout and similar training behavior
# Announce success only after both artifacts actually loaded (the message was
# previously printed before loading, which was misleading when loading failed).
print(f"\U0001f9e0 Loaded Basil model from: {BASIL_MODEL_NAME}")
# Clear generation_config length limits so the explicit max_new_tokens passed
# to generate() below is the single source of truth and cannot conflict.
if hasattr(model, 'generation_config'):
    model.generation_config.max_length = None
    model.generation_config.max_new_tokens = None
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
# ---- Identity (system persona) ----------------------------------------------
# IDENTITY_FILE is append-only JSONL; the newest entry (last line) wins.
# Fall back to the default persona when the file is missing, empty, or its
# last record is malformed — a corrupt log line should not abort the session.
DEFAULT_IDENTITY = "You are a small child named Basil."
identity = DEFAULT_IDENTITY
if os.path.exists(IDENTITY_FILE):
    with open(IDENTITY_FILE, "r") as f:
        lines = f.readlines()
    if lines:
        try:
            identity = json.loads(lines[-1])["identity"]
        except (json.JSONDecodeError, KeyError):
            # Bad trailing record (blank line, truncated write, missing key):
            # keep the default rather than crashing at startup.
            pass
print(f"\n{'System:':>100} {identity}\n")
# ---- Interactive REPL: Tutor types a line, Basil replies --------------------
log = []  # in-memory transcript; deliberately discarded at exit (see below)
while True:
    try:
        tutor_input = input("Tutor: ").strip()
        if tutor_input.lower() in ["exit", "quit"]:
            break
        # Single-turn prompt: persona + the tutor's line, then cue Basil.
        prompt = f"System: {identity}\nTutor: {tutor_input}\nBasil:"
        inputs = tokenizer(prompt, return_tensors="pt", padding=False)
        input_ids = inputs["input_ids"].to(device)
        attention_mask = inputs["attention_mask"].to(device)
        with torch.no_grad():
            outputs = model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_new_tokens=60,
                do_sample=True,   # stochastic sampling for varied replies
                top_k=50,
                top_p=0.95,
                temperature=0.8,
                pad_token_id=tokenizer.eos_token_id,
                eos_token_id=tokenizer.eos_token_id,
            )
        # Decode the whole sequence (prompt + continuation), then keep only the
        # text after the last "Basil:" marker, truncated to its first line.
        response = tokenizer.decode(outputs[0], skip_special_tokens=True)
        response = response.split("Basil:")[-1].strip().split("\n")[0].strip()
        print(f"Basil: {response}")
        log.append({"speaker": "Tutor", "text": tutor_input})
        log.append({"speaker": "Basil", "text": response})
    except (KeyboardInterrupt, EOFError):
        # Ctrl-C or a closed stdin (e.g. piped input exhausted) now ends the
        # session cleanly; previously EOFError escaped as a traceback.
        break
# Optionally discard the session log
print("\nSession complete. Not saving test session log.")