-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtask_contract.py
More file actions
327 lines (254 loc) · 11.3 KB
/
task_contract.py
File metadata and controls
327 lines (254 loc) · 11.3 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
# task_contract.py
# Central contract validation for TaskSpec and naturalized task text.
#
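# This module provides:
# - normalize_token(): Consistent token normalization used everywhere
# - normalize_targets(): Normalize and de-duplicate a list of raw tokens
# - text_contains_token(): Safe token presence check (word-boundary aware)
# - validate_task_spec_contract(): Single source of truth for TaskSpec validation
# - validate_naturalized_contains_targets(): Check naturalizer output for required tokens
# - TaskContractViolation: Raised on contract violations in strict mode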
import re
import logging
from config import STRICT_TASK_CONTRACT, DEBUG_SESSION
# Module logger. Verbosity follows config.DEBUG_SESSION: DEBUG when set,
# otherwise WARNING.
logger = logging.getLogger(__name__)
_LOG_LEVEL = logging.WARNING
if DEBUG_SESSION:
    _LOG_LEVEL = logging.DEBUG
# Attach a prefixed stream handler only if this logger has none yet, so
# re-executing this setup does not stack duplicate handlers.
# NOTE(review): propagate is left at its default, so if the application also
# configures the root logger these messages may be emitted twice — confirm.
if not logger.handlers:
    handler = logging.StreamHandler()
    handler.setFormatter(
        logging.Formatter('[TaskContract] %(levelname)s: %(message)s')
    )
    logger.addHandler(handler)
logger.setLevel(_LOG_LEVEL)
# =============================================================================
# TOKEN NORMALIZATION
# =============================================================================
def normalize_token(token: str) -> str:
    """
    Normalize a token for consistent comparison across the pipeline.

    Normalization rules, applied repeatedly until the token stops changing
    (so the order in which wrappers appear does not matter):
    - Lowercase
    - Strip leading/trailing whitespace
    - Strip matching surrounding quotes (single or double)
    - Strip leading/trailing punctuation (. , ! ? ; :)

    Args:
        token: Raw token string (None or empty is accepted)

    Returns:
        Normalized token (may be empty string if input was just punctuation)

    Example:
        normalize_token("'Fire!'") -> "fire"
        normalize_token('"Blue"') -> "blue"
        normalize_token(" YES ") -> "yes"
        normalize_token('"Blue"!') -> "blue"
    """
    if not token:
        return ""
    edge_punctuation = ".,!?;:"
    quote_chars = "'\""
    result = token.lower()
    previous = None
    # Iterate to a fixed point. A single fixed-order pass left inputs such as
    # '"blue"!' or "' fire. '" half-stripped; the leftover quote/punctuation
    # meant the token could never match a word in the naturalized text.
    # Each changing pass strictly shortens the string, so this terminates.
    while result != previous:
        previous = result
        result = result.strip()
        if len(result) >= 2 and result[0] == result[-1] and result[0] in quote_chars:
            result = result[1:-1]
        result = result.strip(edge_punctuation)
    return result
def normalize_targets(targets: list) -> list:
    """
    Normalize a list of target tokens.

    Each item is converted with str() and passed through normalize_token().
    Only None items are treated as missing; other falsy values such as the
    number 0 are kept (a plain truthiness test would silently drop a
    numeric answer of 0).

    Args:
        targets: List of raw target values (None or empty -> [])

    Returns:
        List of normalized, non-empty tokens (duplicates removed, order preserved)
    """
    if not targets:
        return []
    normalized = (normalize_token("" if t is None else str(t)) for t in targets)
    # dict.fromkeys keeps first-seen order while discarding duplicates.
    return list(dict.fromkeys(tok for tok in normalized if tok))
# =============================================================================
# TOKEN PRESENCE CHECK (WORD-BOUNDARY AWARE)
# =============================================================================
def text_contains_token(text: str, token: str, max_suffix_len: int = 3) -> bool:
    """
    Check if text contains the token with word-boundary safety.

    Both strings are lowercased and split into words (runs of Unicode letters
    and digits). This prevents matching 'red' inside 'credited', while
    punctuation and apostrophe variants ("don't" vs "don’t") do not block a
    match. A token spanning several words (e.g. "ice cream", "t-shirt")
    matches a contiguous run of the same words in the text. The final token
    word may carry up to ``max_suffix_len`` extra trailing characters, so
    simple inflections such as "dogs" for "dog" are accepted.

    Args:
        text: Text to search (naturalized task text, etc.)
        token: Token to find (already normalized)
        max_suffix_len: Extra characters allowed after the final token word
            (default 3 covers plurals, -ing, -ed, etc.)

    Returns:
        True if the token's words appear, adjacent and in order, in text
    """
    if not text or not token:
        return False
    # Letters and digits only: underscores, apostrophes, hyphens and other
    # punctuation act as word separators (same split as before for ASCII).
    word_pattern = r"[^\W_]+"
    token_words = re.findall(word_pattern, token.lower())
    if not token_words:
        return False
    text_lower = text.lower()
    # Cheap pre-filter: the first token word must occur at least as a substring.
    if token_words[0] not in text_lower:
        return False
    text_words = re.findall(word_pattern, text_lower)
    *head, last = token_words
    span = len(token_words)
    for start in range(len(text_words) - span + 1):
        # All but the last token word must match exactly...
        if text_words[start:start + span - 1] != head:
            continue
        # ...and the last may be a short prefix of the text word (plurals etc.).
        candidate = text_words[start + span - 1]
        if candidate.startswith(last) and len(candidate) - len(last) <= max_suffix_len:
            return True
    return False
# =============================================================================
# CONTRACT VALIDATION
# =============================================================================
class TaskContractViolation(ValueError):
    """
    Signals a TaskSpec contract violation detected in strict mode.

    Derives from ValueError, so existing ``except ValueError`` handlers
    also catch it.
    """
def validate_task_spec_contract(
    task_spec: dict,
    naturalized_task_text: str = None,
    strict: bool = None,
    episode_id: str = "",
    session_id: str = "",
    age_band: int = 0,
) -> list:
    """
    Validate TaskSpec contract and naturalized text invariants.

    This is the single source of truth for TaskSpec validation.
    Called once per episode after task generation and naturalization.

    Checks performed (each issue becomes one warning string):
    - targets: present, a list, non-empty after normalization, at most 5,
      each at most 24 chars, and no multiword targets for age_band=0
      vocab/control tasks
    - choices: if present, a list with at most 6 entries
    - yes/no tasks: targets include 'yes' or 'no'
    - naturalized text (if given): every target and choice appears in it;
      a token that is both a target and a choice is reported only once

    Args:
        task_spec: The TaskSpec dict (must have targets, may have choices)
        naturalized_task_text: The naturalized text that will be shown to Basil
        strict: True -> raise TaskContractViolation on any violation;
                False -> only log warnings and continue;
                None (default) -> use config.STRICT_TASK_CONTRACT
        episode_id: For logging context
        session_id: For logging context
        age_band: Basil's current age band (affects validation rules)

    Returns:
        List of warning messages (empty if no violations)

    Raises:
        TaskContractViolation: In strict mode, if any contract is violated
    """
    if strict is None:
        strict = STRICT_TASK_CONTRACT
    context = f"[Session:{session_id or '?'} Episode:{episode_id or '?'}]"

    normalized_targets, warnings = _check_targets(task_spec, context, age_band)
    normalized_choices, choice_warnings = _check_choices(task_spec, context)
    warnings.extend(choice_warnings)
    warnings.extend(
        _check_yes_no(task_spec, context, normalized_targets, normalized_choices)
    )
    if naturalized_task_text:
        warnings.extend(
            _check_naturalized_text(
                naturalized_task_text, context, normalized_targets, normalized_choices
            )
        )

    # Every issue is logged (also in strict mode) before the first one is raised.
    for w in warnings:
        logger.warning(w)
    if warnings and strict:
        raise TaskContractViolation(
            f"TaskSpec contract violated ({len(warnings)} issue(s)): {warnings[0]}"
        )
    return warnings


def _check_targets(task_spec: dict, context: str, age_band: int) -> tuple:
    """Validate the 'targets' field; return (normalized_targets, warnings)."""
    targets = task_spec.get("targets")
    if targets is None:
        return [], [f"{context} Missing 'targets' field in TaskSpec"]
    if not isinstance(targets, list):
        return [], [f"{context} 'targets' must be list, got {type(targets).__name__}"]

    warnings = []
    normalized = normalize_targets(targets)
    if not normalized:
        warnings.append(f"{context} 'targets' is empty after normalization (raw: {targets})")
    if len(normalized) > 5:
        warnings.append(f"{context} Too many targets ({len(normalized)} > 5)")

    # Multiword targets are rejected only for age_band=0 vocab/control tasks.
    task_category = task_spec.get("task_category", "vocab")
    forbid_multiword = age_band == 0 and task_category in ("vocab", "control")
    for target in normalized:
        if len(target) > 24:
            warnings.append(f"{context} Target too long: '{target[:30]}...' ({len(target)} > 24)")
        if forbid_multiword and " " in target:
            warnings.append(
                f"{context} Multiword target '{target}' not allowed for age_band=0 {task_category}"
            )
    return normalized, warnings


def _check_choices(task_spec: dict, context: str) -> tuple:
    """Validate the optional 'choices' field; return (normalized_choices, warnings)."""
    choices = task_spec.get("choices")
    if choices is None:
        return [], []
    if not isinstance(choices, list):
        return [], [f"{context} 'choices' must be list, got {type(choices).__name__}"]
    normalized = normalize_targets(choices)
    warnings = []
    if len(normalized) > 6:
        warnings.append(f"{context} Too many choices ({len(normalized)} > 6)")
    return normalized, warnings


def _check_yes_no(task_spec: dict, context: str, targets: list, choices: list) -> list:
    """For tasks that look like yes/no questions, require 'yes' or 'no' among targets."""
    task_text_lower = (task_spec.get("task_text", "") or "").lower()
    yes_no = {"yes", "no"}
    is_yes_no_task = (
        "yes or no" in task_text_lower
        or set(targets) == yes_no
        or set(choices) == yes_no
    )
    # Empty targets were already reported by _check_targets.
    if not is_yes_no_task or not targets:
        return []
    if "yes" in targets or "no" in targets:
        return []
    return [f"{context} Yes/no task but targets={targets} missing 'yes' or 'no'"]


def _check_naturalized_text(text: str, context: str, targets: list, choices: list) -> list:
    """Require every target and choice to appear in the naturalized text."""
    # Only mark the preview as truncated when it actually is.
    preview = text if len(text) <= 60 else f"{text[:60]}..."
    warnings = []
    missing_targets = set()
    for target in targets:
        if not text_contains_token(text, target):
            missing_targets.add(target)
            warnings.append(f"{context} Target '{target}' not found in naturalized text: '{preview}'")
    for choice in choices:
        # A choice that is also a missing target has already been reported.
        if choice in missing_targets:
            continue
        if not text_contains_token(text, choice):
            warnings.append(f"{context} Choice '{choice}' not found in naturalized text: '{preview}'")
    return warnings
def validate_naturalized_contains_targets(
    naturalized_text: str,
    targets: list,
    choices: list = None,
) -> tuple:
    """
    Check that naturalized text mentions every required target/choice token.

    Intended for validating the naturalizer's LLM output. Targets and choices
    are normalized with normalize_targets() and matched with
    text_contains_token().

    Args:
        naturalized_text: The naturalized task text
        targets: Target tokens (raw; normalized here)
        choices: Optional choice tokens for MCQ tasks

    Returns:
        Tuple (is_valid, missing_tokens). is_valid is True when nothing is
        missing. missing_tokens lists missing targets first, then missing
        choices, each at most once. Empty or None text yields
        (False, ["empty text"]).
    """
    if not naturalized_text:
        return (False, ["empty text"])
    required = normalize_targets(targets or []) + normalize_targets(choices or [])
    missing = []
    for tok in required:
        if tok in missing:
            # A choice that duplicates an already-missing target.
            continue
        if not text_contains_token(naturalized_text, tok):
            missing.append(tok)
    return (not missing, missing)