-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathscience_domain_generator.py
More file actions
235 lines (193 loc) · 7.05 KB
/
science_domain_generator.py
File metadata and controls
235 lines (193 loc) · 7.05 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
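# science_domain_generator.py
# Dynamically generates a list of science/curiosity domains for HowItWorks sessions.
#
# Mirrors the pattern in story_genre_generator.py:
#   1. An LLM generates a list of domains (50 by default) at temperature 1.0.
#   2. Python picks a random domain from the list.
#   3. The list is cached to disk and regenerated when fewer than five domains
#      remain that have not already been used in the current session. The cache
#      also records generated_at and age_band, but neither is currently checked
#      for staleness.
#
# Key functions:
#   generate_science_domains(basil_assessment, recent_domains=None, n=50) -> list[str]
#   pick_domain(basil_assessment, used_domains_this_session=None) -> str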
import os
import json
import random
from datetime import datetime
from openai import OpenAI
from config import (
TASK_AGENT_MODEL,
PROMPT_SCIENCE_DOMAIN_GENERATOR,
SCIENCE_DOMAINS_FILE,
HOWITWORKS_DIR,
)
from file_lock_utils import get_lock
from llm_client import create_smart_client
# Module-level LLM client, created once at import time by the project's
# create_smart_client() factory; generate_science_domains() uses it for its
# chat-completion request.
client = create_smart_client()
# Built-in domain list. generate_science_domains() returns a copy of it when the
# LLM call or JSON parsing fails, and pick_domain() draws from it when no unused
# generated domain is available.
FALLBACK_DOMAINS = [
    "physics and forces",
    "chemistry and materials",
    "biology and living things",
    "astronomy and space",
    "weather and climate",
    "engineering and machines",
    "medicine and the human body",
    "geology and earth science",
    "ecology and ecosystems",
    "technology and computers",
    "food science and cooking",
    "ocean science and marine life",
    "animal behavior and adaptation",
    "energy and electricity",
    "light and optics",
    "sound and acoustics",
    "transportation and vehicles",
    "construction and architecture",
    "water and fluid dynamics",
    "fire and combustion",
    "magnetism and electromagnetism",
    "genetics and heredity",
    "robotics and automation",
    "timekeeping and clocks",
    "navigation and cartography",
    "agriculture and farming",
    "textiles and clothing manufacture",
    "communication technology",
    "printing and publishing",
    "mining and metallurgy",
]
def _load_prompt_template() -> str:
    """Read and return the science domain generator prompt template.

    The file at PROMPT_SCIENCE_DOMAIN_GENERATOR is decoded as UTF-8 so the
    result does not depend on the platform's default encoding.

    Returns:
        The raw template text (still containing its format placeholders).

    Raises:
        FileNotFoundError: if the template file does not exist.
    """
    # EAFP: opening directly avoids the race between an exists() check and open().
    try:
        with open(PROMPT_SCIENCE_DOMAIN_GENERATOR, "r", encoding="utf-8") as f:
            return f.read()
    except FileNotFoundError:
        raise FileNotFoundError(
            f"Science domain generator prompt not found: {PROMPT_SCIENCE_DOMAIN_GENERATOR}"
        ) from None
def _load_cached_domains() -> dict:
    """Load the cached domain data from SCIENCE_DOMAINS_FILE.

    Returns:
        The parsed JSON object (as written by _save_domains, with keys
        "generated_at", "age_band" and "domains"), or an empty dict when the
        file is missing or unreadable, is not valid UTF-8 or JSON, or its top
        level is not a JSON object.
    """
    try:
        with open(SCIENCE_DOMAINS_FILE, "r", encoding="utf-8") as f:
            data = json.load(f)
    except (OSError, ValueError):
        # OSError: missing/unreadable file. ValueError: covers both
        # json.JSONDecodeError and UnicodeDecodeError from a corrupted cache.
        return {}
    # Callers use dict methods on the result, so anything else counts as no cache.
    return data if isinstance(data, dict) else {}
def _save_domains(domains: list, age_band: int):
    """Write the domain list and its metadata to SCIENCE_DOMAINS_FILE.

    The JSON object written has keys "generated_at" (local-time ISO-8601
    timestamp), "age_band" and "domains". The data is first written to a
    sibling temporary file and then moved into place with os.replace(), so a
    reader sees either the previous cache or the complete new one rather than
    a half-written file.

    Args:
        domains: Domain name strings to cache.
        age_band: Age band recorded alongside the list.

    Raises:
        OSError: if the directory cannot be created or the file cannot be
            written or replaced.
    """
    data = {
        "generated_at": datetime.now().isoformat(),
        "age_band": age_band,
        "domains": domains,
    }
    directory = os.path.dirname(SCIENCE_DOMAINS_FILE)
    if directory:
        # dirname() is "" for a bare filename, and os.makedirs("") would raise.
        os.makedirs(directory, exist_ok=True)

    # The PID suffix keeps writers in different processes off the same temp file.
    tmp_path = f"{SCIENCE_DOMAINS_FILE}.{os.getpid()}.tmp"
    try:
        with open(tmp_path, "w", encoding="utf-8") as f:
            json.dump(data, f, indent=2)
        os.replace(tmp_path, SCIENCE_DOMAINS_FILE)
    except BaseException:
        # Best-effort removal of the partial temp file; the original error is re-raised.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        raise
def _strip_code_fence(text: str) -> str:
    """Return *text* without a surrounding Markdown code fence (e.g. ```json ... ```)."""
    text = text.strip()
    if not text.startswith("```"):
        return text
    body = text[3:]
    if body.endswith("```"):
        body = body[:-3]
    # The opening fence may carry a language tag (```json); drop that first line
    # unless it already holds the start of the JSON payload.
    first_line, newline, rest = body.partition("\n")
    if newline and not first_line.strip().startswith(("{", "[")):
        body = rest
    return body.strip()


def _extract_domain_candidates(data) -> list:
    """Return raw domain entries from parsed LLM JSON ({"domains": [...]} or a bare list)."""
    if isinstance(data, dict):
        data = data.get("domains", [])
    return data if isinstance(data, list) else []


def _filter_domains(candidates: list, recent_domains: list = None) -> list:
    """Strip, de-duplicate case-insensitively, and drop blank, non-string or recent domains."""
    excluded = {
        d.lower().strip() for d in (recent_domains or []) if isinstance(d, str)
    }
    kept = []
    for candidate in candidates:
        if not isinstance(candidate, str):
            continue
        domain = candidate.strip()
        key = domain.lower()
        if not domain or key in excluded:
            continue
        # Adding each kept key to the exclusion set also suppresses later duplicates.
        excluded.add(key)
        kept.append(domain)
    return kept


def generate_science_domains(
    basil_assessment: dict,
    recent_domains: list = None,
    n: int = 50,
) -> list:
    """
    Generate a diverse list of science/curiosity domains via LLM.

    On success the filtered list is also written to the on-disk cache via
    _save_domains(); a failure to write the cache is logged but does not
    discard the generated list.

    Args:
        basil_assessment: Dict with age_band, capabilities, etc. Only
            "age_band" is read here (for logging and the cache record).
        recent_domains: Recently used domain names to avoid. The last 30 are
            listed in the prompt; all of them are filtered out of the result.
        n: Number of domains requested from the LLM (the result may contain
            fewer after filtering).

    Returns:
        List of domain name strings (stripped, case-insensitively
        deduplicated, recent domains removed). A copy of FALLBACK_DOMAINS is
        returned when the LLM call fails, its output is not valid JSON, or it
        contains no usable domains.

    Raises:
        FileNotFoundError: if the prompt template is missing. Template loading
            and formatting happen outside the fallback handling, so their
            errors propagate to the caller.
    """
    age_band = basil_assessment.get("age_band", 0)
    recent_text = ", ".join(recent_domains[-30:]) if recent_domains else "(none)"

    prompt = _load_prompt_template().format(recent_domains=recent_text)
    system_msg = f"Generate exactly {n} science domains. Output valid JSON only."

    try:
        response = client.chat.completions.create(
            model=TASK_AGENT_MODEL,
            messages=[
                {"role": "system", "content": system_msg},
                {"role": "user", "content": prompt},
            ],
            temperature=1.0,
            max_tokens=2000,
        )
        # message.content is Optional in the OpenAI client; treat None as empty
        # output so it is reported as a JSON parse error.
        raw_output = _strip_code_fence(response.choices[0].message.content or "")
        data = json.loads(raw_output)
    except json.JSONDecodeError as e:
        print(f"[ScienceDomain] JSON parse error: {e}")
        return FALLBACK_DOMAINS[:]
    except Exception as e:
        print(f"[ScienceDomain] Error: {e}")
        return FALLBACK_DOMAINS[:]

    filtered = _filter_domains(_extract_domain_candidates(data), recent_domains)
    if not filtered:
        # Returning early also avoids overwriting an existing cache with [].
        print("[ScienceDomain] No usable domains in LLM output, using fallback list")
        return FALLBACK_DOMAINS[:]

    print(f"[ScienceDomain] Generated {len(filtered)} domains for age_band={age_band}")
    try:
        _save_domains(filtered, age_band)
    except OSError as e:
        # The generation itself succeeded; only caching failed.
        print(f"[ScienceDomain] Could not cache domains: {e}")
    return filtered
def pick_domain(
    basil_assessment: dict,
    used_domains_this_session: list = None,
    *,
    min_available: int = 5,
) -> str:
    """
    Pick a random science domain for the current HowItWorks session.

    While holding get_lock(SCIENCE_DOMAINS_FILE), loads the cached domain list
    and regenerates it via generate_science_domains() if fewer than
    ``min_available`` unused domains remain. Domains already used this session
    are skipped (compared case-insensitively, ignoring surrounding whitespace).
    Falls back to FALLBACK_DOMAINS when nothing else is available.

    Args:
        basil_assessment: Dict with age_band, capabilities, etc.; passed
            through to generate_science_domains().
        used_domains_this_session: Domains already tried this session (for
            retries). Their order is preserved when forwarded as
            ``recent_domains``, whose prompt only lists the last 30.
        min_available: Regenerate when fewer than this many unused cached
            domains remain. Defaults to 5.

    Returns:
        A domain string, e.g. "physics and forces".
    """
    # dict.fromkeys de-duplicates the normalised names while keeping session order.
    used_in_order = list(
        dict.fromkeys(d.lower().strip() for d in (used_domains_this_session or []))
    )
    used_this_session = set(used_in_order)

    def unused(candidates) -> list:
        # Non-string entries (e.g. from a hand-edited cache) are ignored.
        return [
            d
            for d in candidates
            if isinstance(d, str) and d.lower().strip() not in used_this_session
        ]

    with get_lock(SCIENCE_DOMAINS_FILE):
        cached = _load_cached_domains()
        domains = cached.get("domains", []) if isinstance(cached, dict) else []
        available = unused(domains) if isinstance(domains, list) else []
        if len(available) < min_available:
            print(f"[ScienceDomain] Only {len(available)} domains available, regenerating...")
            # Regeneration runs inside the lock; assuming get_lock() is exclusive,
            # concurrent callers wait for it instead of each calling the LLM.
            available = unused(
                generate_science_domains(basil_assessment, recent_domains=used_in_order)
            )

    if not available:
        # Prefer unused fallback domains; if every one was used, allow repeats.
        available = unused(FALLBACK_DOMAINS) or FALLBACK_DOMAINS[:]
    return random.choice(available)
if __name__ == "__main__":
    print("Testing Science Domain Generator...")
    print()

    # Sample assessment used for both the generation and the pick demo.
    sample_assessment = {
        "age_band": 2,
        "capabilities": ["Some word production", "Basic pattern recognition"],
    }

    # Direct generation: calls the LLM and, on success, rewrites the cache file.
    generated = generate_science_domains(sample_assessment)
    print(f"\nGenerated {len(generated)} domains:")
    for position, name in enumerate(generated, start=1):
        print(f" {position:>3}. {name}")

    # Five independent random picks through pick_domain().
    print("\n--- Random picks ---")
    picks_remaining = 5
    while picks_remaining > 0:
        picked = pick_domain(sample_assessment)
        print(f" Picked: {picked}")
        picks_remaining -= 1