-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathassessment_agent.py
More file actions
128 lines (99 loc) · 4.02 KB
/
assessment_agent.py
File metadata and controls
128 lines (99 loc) · 4.02 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
# assessment_agent.py
# Assesses Basil's developmental age_band based on full session transcript.
import os
import json
import re
from openai import OpenAI
from config import (
ASSESSMENT_AGENT_MODEL,
PROMPT_ASSESSMENT_AGENT,
)
from llm_client import create_smart_client
# Shared LLM client, created once at import time; used by
# assess_developmental_age_band() to send chat-completion requests.
client = create_smart_client()
def load_prompt_template() -> str:
    """Load the assessment agent prompt template.

    Reads the file at ``PROMPT_ASSESSMENT_AGENT`` and returns its full text.
    The file is decoded as UTF-8 explicitly so the prompt reads the same on
    every platform instead of depending on the locale's default encoding.

    Returns:
        The raw template text.

    Raises:
        OSError: If the template file cannot be opened or read.
        UnicodeDecodeError: If the file is not valid UTF-8.
    """
    with open(PROMPT_ASSESSMENT_AGENT, "r", encoding="utf-8") as f:
        return f.read()
def _extract_age_band_from_response(raw_output: str) -> int:
"""
Extract age_band (0-7) from LLM response.
Handles multiple formats:
- JSON: {"assessed_age_band": 2}
- Plain number: 2
- Text with number: "age_band 2" or "band 2"
Returns age_band or raises ValueError if not found.
"""
text = raw_output.strip()
# Try JSON format first
try:
# Remove markdown code blocks if present
if text.startswith("```"):
lines = text.split("\n")
text = "\n".join(lines[1:-1] if lines[-1].strip() == "```" else lines[1:])
# Try to parse as JSON
json_match = re.search(r'\{[^{}]*\}', text, re.DOTALL)
if json_match:
data = json.loads(json_match.group(0))
if "assessed_age_band" in data:
return int(data["assessed_age_band"])
except (json.JSONDecodeError, KeyError, ValueError):
pass
# Try to extract plain number (0-7)
# Look for standalone digits or "age_band X" or "band X"
patterns = [
r'\bage_band\s*[:=]?\s*(\d)',
r'\bband\s*[:=]?\s*(\d)',
r'\b(\d)\b', # Any standalone digit
]
for pattern in patterns:
match = re.search(pattern, text, re.IGNORECASE)
if match:
age_band = int(match.group(1))
if 0 <= age_band <= 7:
return age_band
raise ValueError(f"Could not extract age_band from: {raw_output[:100]}")
def assess_developmental_age_band(session_transcript: str, current_age_band: int) -> int:
    """
    Assess Basil's developmental age_band based on full session transcript.

    Fills the prompt template with the transcript, asks the model for a
    single number, extracts it from the reply and clamps it to 0-7. Any
    failure along the way (template loading or formatting, the API call, an
    empty reply, an unparseable reply) is printed and the current band is
    returned unchanged.

    Args:
        session_transcript: Full session transcript (all turns)
        current_age_band: Current age_band (used as fallback on error)

    Returns:
        Assessed age_band (0-7), or current_age_band on error
    """
    try:
        # Prompt preparation is inside the try so that a missing template
        # file or a malformed placeholder takes the same fallback path as
        # an API failure instead of raising to the caller.
        template = load_prompt_template()
        prompt = template.format(session_transcript=session_transcript)

        response = client.chat.completions.create(
            model=ASSESSMENT_AGENT_MODEL,
            messages=[
                {"role": "system", "content": "Answer with ONLY a number (0-7). No explanation."},
                {"role": "user", "content": prompt},
            ],
            temperature=0.3,  # Low temperature for consistency
            max_tokens=50,  # Just need a number
        )

        # The message content may be None; report that clearly rather than
        # failing on None.strip().
        content = response.choices[0].message.content
        if content is None:
            raise ValueError("Model returned no text content")
        raw_output = content.strip()

        # Extract age_band from response
        assessed_age_band = _extract_age_band_from_response(raw_output)

        # Clamp to valid range
        return max(0, min(7, assessed_age_band))
    except Exception as e:
        # Deliberately broad: this is the boundary where every failure is
        # converted into "keep the current band".
        print(f"[Assessment Agent] Error assessing age_band: {e}")
        print(f"[Assessment Agent] Falling back to current_age_band: {current_age_band}")
        return current_age_band  # Safe fallback: no promotion on error
if __name__ == "__main__":
    # Manual smoke test: run the assessment once on a short sample session
    # and print the band it comes back with.
    print("Testing Assessment Agent...")
    sample_turns = [
        "Tutor: Hello everyone! Today we're going to learn about Animals.",
        "Sophie: That sounds fun! Can we learn about cats?",
        "Tutor: Yes! Cats are furry animals that say meow.",
        "Sophie: And they have whiskers!",
        "Tutor: Basil, can you say 'cat'?",
        "Basil: cat",
        "Tutor: Great job!",
        "Sophie: Good try, Basil!",
    ]
    sample_transcript = "\n".join(sample_turns)
    assessed = assess_developmental_age_band(sample_transcript, current_age_band=0)
    print(f"Assessed age_band: {assessed}")