-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathchat_basil.py
More file actions
executable file
·202 lines (160 loc) · 6.8 KB
/
chat_basil.py
File metadata and controls
executable file
·202 lines (160 loc) · 6.8 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
# chat_basil.py
# Interactive chat mode with Basil.
# Uses model.generate() directly to avoid generation_config warnings
#
# IDENTITY EXPERIMENT NOTE:
# - Basil receives NO system prompt at inference time (no "You are Basil...")
# - The "Who are you?" question at session end is logged for analysis only
# - Identity does NOT carry over to next session (social reinforcement only)
import json
import os
import argparse
import torch
from datetime import datetime
from transformers import AutoModelForCausalLM, AutoTokenizer
from config import LOG_DIR, IDENTITY_FILE, BASIL_MODEL_NAME
from sophie_engine import generate_sophie_reply
# Ensure output folders exist. IDENTITY_FILE may be a bare filename, in which
# case dirname() is "" and os.makedirs("") would raise FileNotFoundError.
os.makedirs(LOG_DIR, exist_ok=True)
_identity_dir = os.path.dirname(IDENTITY_FILE)
if _identity_dir:
    os.makedirs(_identity_dir, exist_ok=True)

# === Load the Basil model named by config.BASIL_MODEL_NAME ===
latest_model_path = BASIL_MODEL_NAME
tokenizer = AutoTokenizer.from_pretrained(latest_model_path)
if tokenizer.pad_token is None:
    # No pad token defined by the tokenizer: reuse EOS as the pad token.
    tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(latest_model_path)
model.eval()

# Clear the length limits stored in generation_config so that the explicit
# max_new_tokens passed to model.generate() is the only length control
# (intended to avoid conflicting-length warnings).
if hasattr(model, "generation_config"):
    model.generation_config.max_length = None
    model.generation_config.max_new_tokens = None

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
# Report only after loading actually succeeded.
print(f"🧠 Loaded Basil model from: {latest_model_path}")

# === Per-session JSONL log file, named by session start time ===
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
log_path = os.path.join(LOG_DIR, f"log_{timestamp}.jsonl")
def append_log(speaker, text, entry_type=None):
    """Append one JSON line describing a dialogue turn to the session log.

    Args:
        speaker: Name of the speaker for this turn.
        text: The utterance itself.
        entry_type: Optional tag stored under the "type" key; the key is
            omitted entirely when this value is falsy.
    """
    record = dict(speaker=speaker, text=text)
    if entry_type:
        record["type"] = entry_type
    # One JSON object per line (JSONL), appended to the module-level log_path.
    with open(log_path, "a") as log_file:
        log_file.write(json.dumps(record) + "\n")
def generate_basil_response(session_history, max_tokens=25):
    """
    Generate Basil's next reply using model.generate() directly.

    NOTE: No system prompt is used. Basil only sees the conversation context.
    Identity emerges through social reinforcement (Tutor/Sophie calling Basil by name).

    The prompt is the most recent whole turns that fit the token budget, one
    "Speaker: text" line per turn, followed by the cue "Basil:".

    Args:
        session_history: sequence of (speaker, text) pairs, oldest first.
        max_tokens: maximum number of new tokens to sample.

    Returns:
        The first line of the newly generated text, stripped of surrounding
        whitespace (may be "" if nothing usable was generated).
    """
    MAX_TOTAL_TOKENS = 1024
    RESERVED_FOR_GENERATION = max_tokens + 10
    MAX_CONTEXT_TOKENS = MAX_TOTAL_TOKENS - RESERVED_FOR_GENERATION
    cue = "Basil:"

    # The cue and any special tokens the tokenizer adds are part of the final
    # input, so they come out of the dialogue budget. Otherwise the
    # truncation below could cut off the trailing "Basil:" cue itself.
    dialogue_budget = (
        MAX_CONTEXT_TOKENS
        - len(tokenizer.encode(cue, add_special_tokens=False))
        - tokenizer.num_special_tokens_to_add()
    )

    # Walk backwards from the newest turn, keeping whole turns while they fit.
    # A running total keeps this linear in the history length instead of
    # re-encoding every kept line on each iteration.
    kept_lines = []
    used_tokens = 0
    for speaker, text in reversed(session_history):
        line = f"{speaker}: {text}\n"
        line_tokens = len(tokenizer.encode(line, add_special_tokens=False))
        if used_tokens + line_tokens > dialogue_budget:
            break
        kept_lines.append(line)
        used_tokens += line_tokens
    kept_lines.reverse()  # back to chronological order

    # Build final prompt (NO system prompt, just conversation + "Basil:")
    prompt = "".join(kept_lines) + cue

    # Truncation is only a safety net; the budget above should already keep
    # the prompt within MAX_CONTEXT_TOKENS.
    inputs = tokenizer(
        prompt,
        return_tensors="pt",
        truncation=True,
        max_length=MAX_CONTEXT_TOKENS,
        padding=False,
    )
    input_ids = inputs["input_ids"].to(device)
    attention_mask = inputs["attention_mask"].to(device)

    with torch.no_grad():
        output_ids = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_new_tokens=max_tokens,
            do_sample=True,
            temperature=1.0,
            top_k=50,
            top_p=0.95,
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )

    # For a causal LM the returned sequence begins with the prompt tokens.
    # Decode only the newly generated part. Splitting the full decode on the
    # last "Basil:" picked the wrong text whenever the model emitted "Basil:".
    new_tokens = output_ids[0][input_ids.shape[-1]:]
    response = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()

    # Keep only the first line: in the prompt format each turn ends at "\n".
    return response.split("\n", 1)[0].strip()
def run_identity_probe(session_history):
    """
    Ask Basil "Who are you?" at session end and record the answer.

    The question/answer pair is written to the session log tagged
    "identity_probe" and appended as a JSON line to IDENTITY_FILE. The answer
    is for analysis only and is NOT used for prompting in future sessions.

    Args:
        session_history: list of (speaker, text) pairs for this session; it is
            not modified (the probe turn is added to a new list).

    Returns:
        Basil's answer string (possibly empty).
    """
    print("\n--- Identity Probe ---")

    question = "Who are you?"
    # A new list: the probe turn must not leak into the caller's history.
    answer = generate_basil_response(
        session_history + [("Tutor", question)], max_tokens=30
    )

    print(f"Tutor: {question}")
    print(f"Basil: {answer}")

    # Tag both turns so they can be separated from ordinary dialogue
    # (the original intent: excluded from training).
    for who, said in (("Tutor", question), ("Basil", answer)):
        append_log(who, said, entry_type="identity_probe")

    # Analysis-only record in the identity log.
    record = dict(
        timestamp=datetime.now().isoformat(),
        session_log=log_path,
        question=question,
        answer=answer,
        note="For analysis only - not used for prompting",
    )
    with open(IDENTITY_FILE, "a") as identity_log:
        identity_log.write(json.dumps(record) + "\n")

    shown = answer or '(no response)'
    print(f"\n🆔 Identity logged (for analysis only): {shown}")
    print("--- Identity Probe Complete ---\n")
    return answer
def main():
    """
    Run an interactive Tutor/Sophie/Basil chat session in the terminal.

    Each Tutor line is logged and optionally answered by Sophie, then by
    Basil. The session ends on "exit"/"quit" (case-insensitive), end of input
    (Ctrl-D / closed stdin) or Ctrl-C at the prompt. Afterwards the identity
    probe runs unless --no-identity-probe was given.
    """
    parser = argparse.ArgumentParser(description="Interactive chat with Basil")
    parser.add_argument("--no-sophie", action="store_true", help="Disable Sophie")
    parser.add_argument("--no-identity-probe", action="store_true",
                        help="Skip the 'Who are you?' probe at session end")
    args = parser.parse_args()

    print("\n" + "="*50)
    print("Bootstrap Basil - Interactive Chat")
    print("="*50)
    print("NOTE: Basil has no system prompt (identity experiment)")
    print("Type 'exit' or 'quit' to end session")
    print("="*50 + "\n")

    session_history = []
    while True:
        try:
            tutor_input = input("Tutor: ").strip()
        except (EOFError, KeyboardInterrupt):
            # End the session cleanly instead of crashing, so the identity
            # probe and the final log-path message still run.
            print()
            break
        if not tutor_input:
            # Don't log empty Tutor turns or prompt Basil on a blank line.
            continue
        if tutor_input.lower() in ("exit", "quit"):
            break

        append_log("Tutor", tutor_input)
        session_history.append(("Tutor", tutor_input))

        # Sophie replies (optional)
        if not args.no_sophie:
            sophie_reply = generate_sophie_reply(tutor_input, None, session_history)
            if sophie_reply:
                print(f"Sophie: {sophie_reply}")
                append_log("Sophie", sophie_reply)
                session_history.append(("Sophie", sophie_reply))

        # Basil responds (no system prompt)
        basil_response = generate_basil_response(session_history)
        print(f"Basil: {basil_response}")
        append_log("Basil", basil_response)
        session_history.append(("Basil", basil_response))

    # End of session - run identity probe
    if not args.no_identity_probe:
        run_identity_probe(session_history)

    print(f"\nSession log saved to: {log_path}")


if __name__ == "__main__":
    main()