-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathforce_train_and_eval.py
More file actions
executable file
·111 lines (92 loc) · 3.19 KB
/
force_train_and_eval.py
File metadata and controls
executable file
·111 lines (92 loc) · 3.19 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
#!/usr/bin/env python3
"""
Force training and evaluation using existing logs.
This script:
1. Runs training in "mixed" mode (dual-objective: world/trunk + basil/LoRA)
2. Runs post-train evaluation sessions
3. Prints results
Usage:
python force_train_and_eval.py
"""
import os
import sys
import traceback

# Add project root to path BEFORE the project imports below — they are only
# resolvable once the script's directory is on sys.path.
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))

from train_basil_v2 import train
from auto_session import run_session
from config import EVAL_SESSIONS_AFTER_TRAIN
def main():
    """Run mixed-mode training, then post-train evaluation, and print results.

    Returns:
        int: process exit code — 0 on success, 1 if training raised.
        Individual eval-session failures are reported but do not change
        the exit code.
    """
    print("=" * 80)
    print("FORCE TRAINING AND EVALUATION")
    print("=" * 80)
    print()

    # 1. Run training — a training failure aborts the whole run.
    if not _run_training():
        return 1

    print()
    print("=" * 80)

    # 2. Run post-train evaluation — best effort, failed sessions are skipped.
    eval_results = _run_eval_sessions()

    print()
    print("=" * 80)
    print("EVALUATION SUMMARY")
    print("=" * 80)
    _print_summary(eval_results)
    print()
    print("=" * 80)
    print("Done!")
    print("=" * 80)
    return 0


def _run_training():
    """Run dual-objective training in 'mixed' mode.

    Returns:
        bool: True if training completed without raising, False otherwise
        (the exception and traceback are printed before returning).
    """
    print("[1/2] Starting training in 'mixed' mode...")
    print(" (Dual-objective: world/trunk LM + basil/LoRA policy)")
    print()
    try:
        train(mode="mixed")
        print()
        print("✅ Training completed successfully")
    except Exception as e:
        print()
        print(f"❌ Training failed: {e}")
        traceback.print_exc()
        return False
    return True


def _run_eval_sessions():
    """Run EVAL_SESSIONS_AFTER_TRAIN post-train eval sessions.

    Returns:
        list: result dicts of the sessions that completed successfully;
        sessions that raised are logged and omitted.
    """
    print(f"[2/2] Running {EVAL_SESSIONS_AFTER_TRAIN} post-train eval sessions...")
    print()
    eval_results = []
    for i in range(EVAL_SESSIONS_AFTER_TRAIN):
        print(f" Eval session {i+1}/{EVAL_SESSIONS_AFTER_TRAIN}...")
        try:
            # run_session returns a dict; metric extraction stays inside the
            # try so a malformed result is handled like a failed session.
            result = run_session(
                training_phase="posttrain_eval",
                verbose=True,
            )
            eval_results.append(result)
            session_metrics = result.get("session_metrics", {})
            avg_score = session_metrics.get("avg_score_session", 0.0)
            graded_turns = session_metrics.get("graded_turns_count", 0)
            print(f" Session {result.get('session_id', 'unknown')}: "
                  f"score={avg_score:.2f}, graded_turns={graded_turns}")
        except Exception as e:
            print(f" ❌ Eval session {i+1} failed: {e}")
            traceback.print_exc()
    return eval_results


def _print_summary(eval_results):
    """Print aggregate and per-session statistics for completed sessions.

    Args:
        eval_results: list of session-result dicts from _run_eval_sessions;
            may be empty, in which case a placeholder message is printed.
    """
    if not eval_results:
        print("No eval sessions completed successfully")
        return
    # Extract the per-session metrics dict once, then derive both columns.
    metrics_list = [r.get("session_metrics", {}) for r in eval_results]
    scores = [m.get("avg_score_session", 0.0) for m in metrics_list]
    graded_counts = [m.get("graded_turns_count", 0) for m in metrics_list]
    print(f"Sessions completed: {len(eval_results)}")
    print(f"Average score: {sum(scores)/len(scores):.3f}")
    print(f"Total graded turns: {sum(graded_counts)}")
    print()
    print("Per-session details:")
    for i, result in enumerate(eval_results, start=1):
        session_id = result.get("session_id", "unknown")
        metrics = result.get("session_metrics", {})
        print(f" {i}. {session_id}: "
              f"score={metrics.get('avg_score_session', 0.0):.3f}, "
              f"graded={metrics.get('graded_turns_count', 0)}")
if __name__ == "__main__":
    # Propagate main()'s return value as the process exit code.
    raise SystemExit(main())