-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtest_whychain_integration.py
More file actions
182 lines (153 loc) · 6.89 KB
/
test_whychain_integration.py
File metadata and controls
182 lines (153 loc) · 6.89 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
#!/usr/bin/env python3
"""
Unit + integration tests for WhyChain orchestrator wiring.

Tests:
  1. Mock _sessions.jsonl records are parseable by load_session_transcripts()
  2. run_whychain_session() wrapper has the correct orchestrator-compatible signature
  3. WhyChain sessions produce 0 graded turns (basil_spans empty)
  4. clear_used_questions is importable from whychain_session and callable

Run:
  python test_whychain_integration.py
"""
import inspect
import json
import os
import sys
import tempfile
import unittest
from unittest.mock import patch
class TestWhyChainSessionsFormat(unittest.TestCase):
    """Verify that WhyChain _sessions.jsonl records are loadable by the training pipeline.

    Every test runs against a throwaway directory that holds a single mock
    ``batch_test_sessions.jsonl`` file. ``train_basil_v2.LOG_DIR`` is patched
    to that directory before ``load_session_transcripts()`` is called.
    """

    # Session id shared by every record in the fixture file.
    _SESSION_ID = "test_whychain_001"

    def setUp(self):
        """Create a temp directory with a mock _sessions.jsonl containing whychain records."""
        import shutil

        self.tmpdir = tempfile.mkdtemp()
        # Register removal right away: unittest does not call tearDown() when
        # setUp() raises, but registered cleanups still run, so the directory
        # cannot leak if writing the fixture below fails.
        self.addCleanup(shutil.rmtree, self.tmpdir, ignore_errors=True)

        # Build mock whychain _sessions.jsonl in the same format that
        # run_why_chain() now produces when batch_sessions_path is set.
        dialogue = [
            ("Sophie", "Why is the sky blue?"),
            ("Tutor", "Great question! The sky appears blue because of a phenomenon called Rayleigh scattering."),
            ("Sophie", "But why does Rayleigh scattering make blue instead of other colors?"),
            ("Tutor", "Because shorter wavelengths of light scatter more strongly. Blue light has a shorter wavelength than red."),
            ("Sophie", "But why do shorter wavelengths scatter more?"),
            ("Tutor", "That's a wonderful question, Sophie. At some level, the honest answer is we don't fully know — it's a fundamental property of how electromagnetic radiation interacts with matter."),
        ]
        records = [
            {
                "record_type": "session_start",
                "session_id": self._SESSION_ID,
                "subject": "WhyChain",
                "lesson": "Why is the sky blue?",
                "training_phase": "whychain",
                "early_stopped": False,
                "stop_reason": "max_rounds",
                "timestamp": "2026-02-14T12:00:00",
            },
            *(
                {
                    "record_type": "transcript",
                    "session_id": self._SESSION_ID,
                    "speaker": speaker,
                    "text": text,
                }
                for speaker, text in dialogue
            ),
        ]

        self.sessions_path = os.path.join(self.tmpdir, "batch_test_sessions.jsonl")
        # One JSON object per line. An explicit encoding keeps the file
        # independent of the platform's default locale.
        with open(self.sessions_path, "w", encoding="utf-8") as f:
            f.writelines(json.dumps(record) + "\n" for record in records)

    def _load_whychain_session(self):
        """Call load_session_transcripts() against the fixture directory.

        Returns:
            A ``(sessions, session)`` tuple: the full result of
            ``load_session_transcripts()`` and the entry whose ``session_id``
            equals the fixture's id, or ``None`` when no entry matches.
        """
        # Patch LOG_DIR in train_basil_v2 to point at our temp directory
        with patch("train_basil_v2.LOG_DIR", self.tmpdir):
            from train_basil_v2 import load_session_transcripts

            sessions = load_session_transcripts()
            match = next(
                (s for s in sessions if s["session_id"] == self._SESSION_ID),
                None,
            )
        return sessions, match

    def test_load_session_transcripts_parses_whychain(self):
        """load_session_transcripts() should load whychain transcript records."""
        sessions, whychain_session = self._load_whychain_session()
        self.assertGreaterEqual(len(sessions), 1, "Should load at least 1 session")
        self.assertIsNotNone(whychain_session, "Should find test_whychain_001 session")

        # Verify transcript text contains Sophie and Tutor lines
        text = whychain_session["text"]
        self.assertIn("Sophie:", text, "Transcript should contain Sophie lines")
        self.assertIn("Tutor:", text, "Transcript should contain Tutor lines")
        self.assertIn("Rayleigh scattering", text, "Transcript should contain actual content")

        # WhyChain has 0 graded Basil turns
        self.assertEqual(
            whychain_session["basil_spans"], [],
            "WhyChain sessions should have empty basil_spans (no graded turns)"
        )

    def test_whychain_session_has_session_key(self):
        """Session should get a proper session_key for recency weighting."""
        _, whychain_session = self._load_whychain_session()
        self.assertIsNotNone(whychain_session)
        self.assertIn("session_key", whychain_session)
        # assertGreater reports the actual length on failure, unlike
        # assertTrue(len(...) > 0) which only says "False is not true".
        self.assertGreater(
            len(whychain_session["session_key"]), 0,
            "session_key should not be empty"
        )
class TestWhyChainWrapperSignature(unittest.TestCase):
    """Verify run_whychain_session() has the right interface for the orchestrator."""

    def test_signature_matches_orchestrator_interface(self):
        """run_whychain_session() must accept all standard orchestrator params."""
        from whychain_session import run_whychain_session

        param_names = set(inspect.signature(run_whychain_session).parameters)
        # A tuple (not a set) keeps the reporting order deterministic.
        required_params = (
            "verbose",
            "training_phase",
            "session_id",
            "batch_graded_path",
            "batch_sessions_path",
            "batch_meta_path",
        )
        for param in required_params:
            # subTest lets every missing parameter be reported in one run
            # instead of stopping at the first failure.
            with self.subTest(param=param):
                self.assertIn(
                    param, param_names,
                    f"run_whychain_session() must accept '{param}' parameter"
                )

    def test_all_params_have_defaults(self):
        """All params should have defaults (orchestrator may not pass all)."""
        from whychain_session import run_whychain_session

        sig = inspect.signature(run_whychain_session)
        variadic_kinds = (
            inspect.Parameter.VAR_POSITIONAL,
            inspect.Parameter.VAR_KEYWORD,
        )
        for name, param in sig.parameters.items():
            if param.kind in variadic_kinds:
                # *args / **kwargs can never declare a default and are always
                # optional for the caller, so they must not fail this check.
                continue
            with self.subTest(param=name):
                self.assertIsNot(
                    param.default, inspect.Parameter.empty,
                    f"Parameter '{name}' should have a default value"
                )
class TestClearUsedQuestions(unittest.TestCase):
    """Smoke test: whychain_session exposes a callable clear_used_questions."""

    def test_clear_used_questions_importable(self):
        """Import clear_used_questions from whychain_session and check it is callable."""
        from whychain_session import clear_used_questions as imported

        is_callable = callable(imported)
        self.assertTrue(is_callable)
if __name__ == "__main__":
    # Support direct execution (python test_whychain_integration.py);
    # verbosity=2 makes the runner print each test as it runs.
    unittest.main(verbosity=2)