-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathidentity_probe.py
More file actions
153 lines (124 loc) · 4.98 KB
/
identity_probe.py
File metadata and controls
153 lines (124 loc) · 4.98 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
# identity_probe.py
# Standalone identity probe: asks Basil "Who are you?" and logs the response.
#
# Called by the orchestrator after each training run (new model version) and
# on graceful exit. Results are appended to identities/identity_log.jsonl so
# you can scroll through and see how different checkpoints answer.
#
# Usage:
# from identity_probe import run_identity_probe
# result = run_identity_probe() # uses latest model
# result = run_identity_probe(label="post_train_v005")
import json
import os
from datetime import datetime
from typing import Optional

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

from config import IDENTITY_FILE, MODELS_DIR, get_latest_basil_model_path, BASIL_MAX_TOKENS_BY_AGE_BAND
def run_identity_probe(
    label: str = "",
    model_path: Optional[str] = None,
    verbose: bool = True,
) -> dict:
    """
    Load the current Basil model, ask "Who are you?", and log the result.

    Args:
        label: Optional label for this probe (e.g. "post_train_v005", "exit").
            Stored in the log entry for easy filtering.
        model_path: Override model path (default: latest model).
        verbose: Print progress and the answer to the terminal.

    Returns:
        dict with the probe result: timestamp, model, label, question, answer.
    """
    model_path = model_path or get_latest_basil_model_path()
    model_name = os.path.basename(model_path)
    if verbose:
        print(f"\n--- Identity Probe ({label or 'default'}) ---")
        print(f"[Identity] Loading model: {model_name}")

    # Load tokenizer; some checkpoints ship without a pad token, so fall back
    # to EOS to keep generate() happy about padding.
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    model = AutoModelForCausalLM.from_pretrained(model_path)
    model.eval()
    # Clear any baked-in length limits so the max_new_tokens passed below
    # is the only thing controlling generation length.
    if hasattr(model, "generation_config"):
        model.generation_config.max_length = None
        model.generation_config.max_new_tokens = None
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)

    # Determine max tokens from the current age band. Local import:
    # presumably avoids a circular dependency at module load — TODO confirm.
    from memory_manager import load_basil_assessment
    assessment = load_basil_assessment()
    age_band = assessment.get("age_band", 0)
    max_tokens = BASIL_MAX_TOKENS_BY_AGE_BAND.get(age_band, 10)
    if verbose:
        print(f"[Identity] Age band: {age_band}, max_tokens: {max_tokens}")

    # Ask the identity question with minimal context (just the question).
    probe_question = "Tutor: Who are you?\nBasil:"
    inputs = tokenizer(
        probe_question,
        return_tensors="pt",
        truncation=True,
        max_length=128,
        padding=False,
    )
    input_ids = inputs["input_ids"].to(device)
    attention_mask = inputs["attention_mask"].to(device)

    # Use same generation strategy as auto_session: greedy for age_band<=1,
    # sampling for >=2.
    with torch.no_grad():
        if age_band <= 1:
            output = model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_new_tokens=max_tokens,
                do_sample=False,
                pad_token_id=tokenizer.pad_token_id,
            )
        else:
            output = model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_new_tokens=max_tokens,
                do_sample=True,
                temperature=1.0,
                top_k=50,
                pad_token_id=tokenizer.pad_token_id,
            )

    # Decode only the generated tokens (exclude the prompt), then keep just
    # the first line: everything after a newline is the model continuing the
    # dialogue, not Basil's answer.
    generated_ids = output[0][input_ids.shape[1]:]
    answer = tokenizer.decode(generated_ids, skip_special_tokens=True).strip()
    answer = answer.partition("\n")[0].strip()

    timestamp = datetime.now().isoformat()
    result = {
        "timestamp": timestamp,
        "model": model_name,
        "label": label,
        "question": "Who are you?",
        "answer": answer,
    }

    # Append to identity log. Guard against IDENTITY_FILE being a bare
    # filename: os.makedirs("") raises FileNotFoundError.
    log_dir = os.path.dirname(IDENTITY_FILE)
    if log_dir:
        os.makedirs(log_dir, exist_ok=True)
    with open(IDENTITY_FILE, "a", encoding="utf-8") as f:
        f.write(json.dumps(result) + "\n")

    if verbose:
        print(f"[Identity] Model: {model_name}")
        print(f"[Identity] Answer: {answer}")
        print(f"[Identity] Logged to {IDENTITY_FILE}")
        print("--- Identity Probe Complete ---\n")

    # Free GPU memory before returning control to the orchestrator.
    del model
    del tokenizer
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    return result
if __name__ == "__main__":
    import argparse

    # CLI entry point: run a one-off probe, optionally pinned to a
    # specific saved model version under MODELS_DIR.
    arg_parser = argparse.ArgumentParser(description="Run Basil identity probe")
    arg_parser.add_argument("--label", type=str, default="manual", help="Label for this probe")
    arg_parser.add_argument("--version", type=str, default=None, help="Specific model version (e.g. v003)")
    cli_args = arg_parser.parse_args()

    chosen_path = (
        os.path.join(MODELS_DIR, f"basil_{cli_args.version}")
        if cli_args.version
        else None
    )
    run_identity_probe(label=cli_args.label, model_path=chosen_path)