Skip to content

Commit ff7fda8

Browse files
authored
eval: add typed training config to Hydra and feedback requests (#38)
## Summary - add a nested training section to eval Hydra config defaults - add explicit eval-side training schema and convert it once into runtime TrainingConfig - enforce strict type invariance for runtime (HarnessConfig.training must be TrainingConfig) - pass typed training config through every FeedbackItem generated by eval runner - update eval docs and tests for nested overrides (training.is_clip, training.learning_rate) ## Validation - uv run pytest tests/test_eval_config.py tests/test_eval_runner.py -q - 23 passed <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **New Features** * Added configurable training parameters (learning rate, alpha, clipping, max-grad-norm, KL weight, teacher_top_k) and runtime validation for training settings. * CLI and programmatic interfaces now support overriding nested training hyperparameters and explicit output directory. * **Configuration** * Batch processing frequency increased to 4 steps per batch. * Default output directory now uses a timestamped path when not overridden. * **Tests** * Added tests covering training config validation, overrides, and mismatch rejection. <!-- end of auto-generated comment: release notes by coderabbit.ai -->
1 parent 8745ad6 commit ff7fda8

11 files changed

Lines changed: 182 additions & 112 deletions

File tree

claas/api.py

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
from __future__ import annotations
3737

3838
import asyncio
39+
import dataclasses
3940
import hashlib
4041
import logging
4142
import os
@@ -86,6 +87,7 @@
8687
ServiceHealth,
8788
TextCompletionChoice,
8889
TextCompletionResponse,
90+
TrainingConfig,
8991
)
9092
from .dashboard import feedback_log as feedback_log_mod, rendering as dashboard_rendering
9193
from .inference import get_inference_backend, vllm_control
@@ -215,6 +217,28 @@ def _get_inference_backend(request: Request) -> InferenceBackend:
215217
return request.app.state.inference_backend
216218

217219

220+
def _validate_training_config(training: TrainingConfig) -> None:
    """Validate training config ranges for direct API callers."""
    # Each entry pairs a "value is acceptable" predicate with the error text
    # reported when it fails; all violations are collected before raising.
    range_checks: list[tuple[bool, str]] = [
        (training.learning_rate > 0, "learning_rate must be > 0"),
        (0.0 <= training.alpha <= 1.0, "alpha must be within [0, 1]"),
        (1.0 <= training.is_clip <= 20.0, "is_clip must be within [1, 20]"),
        (training.max_grad_norm >= 0.0, "max_grad_norm must be >= 0"),
        (0.0 <= training.kl_reg_weight <= 1.0, "kl_reg_weight must be within [0, 1]"),
        (10 <= training.teacher_top_k <= 100, "teacher_top_k must be within [10, 100]"),
    ]
    errors = [message for ok, message in range_checks if not ok]
    if errors:
        # 422 so direct API callers get a validation-style response listing
        # every out-of-range field at once.
        raise HTTPException(
            status_code=422,
            detail=f"invalid training config: {'; '.join(errors)}",
        )
240+
241+
218242
# ---------------------------------------------------------------------------
219243
# Inference endpoints
220244
# ---------------------------------------------------------------------------
@@ -372,12 +396,14 @@ async def feedback(request: FeedbackBatchRequest) -> FeedbackResponse:
372396

373397
first_request = batch_requests[0]
374398
lora_id = first_request.lora_id
375-
training_ref = first_request.training.model_dump(mode="json")
399+
_validate_training_config(first_request.training)
400+
training_ref = dataclasses.asdict(first_request.training)
376401

377402
for req in batch_requests[1:]:
403+
_validate_training_config(req.training)
378404
if req.lora_id != lora_id:
379405
raise HTTPException(status_code=400, detail="all requests must use the same lora_id")
380-
if req.training.model_dump(mode="json") != training_ref:
406+
if dataclasses.asdict(req.training) != training_ref:
381407
raise HTTPException(status_code=400, detail="all requests must use the same training config")
382408

383409
# Resolve cache entries before acquiring lock or doing orchestration.

claas/core/types.py

Lines changed: 10 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
from __future__ import annotations
88

9+
from dataclasses import dataclass
910
from typing import TYPE_CHECKING, Any, Literal, TypedDict
1011

1112
from pydantic import BaseModel, ConfigDict, Field
@@ -21,42 +22,16 @@ class ChatMessage(TypedDict):
2122
content: str
2223

2324

24-
class TrainingConfig(BaseModel):
25-
"""Training configuration for distillation."""
25+
@dataclass
class TrainingConfig:
    """Training hyperparameters (dataclass for Hydra structured-config compatibility).

    NOTE: unlike the previous Pydantic model, field bounds are NOT enforced
    here; range validation happens at the API boundary (see
    ``_validate_training_config`` in ``claas/api.py``).
    """

    # Learning rate for LoRA parameter updates; must be > 0.
    learning_rate: float = 3e-5
    # GJS interpolation weight in [0, 1] (0.5 = symmetric JSD, 1.0 = reverse KL).
    alpha: float = 0.5
    # Importance sampling ratio clip (exp space); valid range [1, 20].
    is_clip: float = 5.0
    # Maximum gradient norm for clipping; must be >= 0.
    max_grad_norm: float = 1.0
    # Weight for KL regularization to the base policy; valid range [0, 1].
    kl_reg_weight: float = 0.0
    # Number of top logprobs to request from the teacher; valid range [10, 100].
    teacher_top_k: int = 100
6035

6136

6237
class SDPOLossInput(BaseModel):
@@ -473,5 +448,3 @@ class TextCompletionResponse(BaseModel):
473448
model: str
474449
choices: list[TextCompletionChoice]
475450
usage: CompletionUsage
476-
477-

claas/eval/README.md

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -26,13 +26,20 @@ metrics: # metrics to evaluate per step
2626

2727
num_steps: 20
2828
batch_size: 4
29-
steps_per_batch: 1 # gradient updates per batch
29+
steps_per_batch: 4 # gradient updates per batch
3030
feedback_repetitions: 1 # times to repeat feedback string
31+
training: # forwarded to /v1/feedback training config
32+
learning_rate: 3e-5
33+
alpha: 0.5
34+
is_clip: 5.0
35+
max_grad_norm: 1.0
36+
kl_reg_weight: 0.0
37+
teacher_top_k: 100
3138
collapse_steps: [0, 5, 10, 15, 19] # steps where collapse metric runs
3239
plots: true # generate matplotlib plots
3340
seed: 42
3441
lora_id_prefix: eval
35-
output_dir: ./data/evals
42+
output_dir: ./data/evals/${now:%Y%m%d-%H%M%SZ}
3643

3744
openclaw_url: http://localhost:18789 # OpenClaw gateway (null = use CLaaS API directly)
3845
```
@@ -48,23 +55,24 @@ uv run python -m claas.eval 'preferences=[concise]' num_steps=10
4855
# Override base model and mode
4956
uv run python -m claas.eval base_model=Qwen/Qwen3-30B-A3B mode=tinker
5057
58+
# Override training hyperparameters
59+
uv run python -m claas.eval training.is_clip=7.0 training.learning_rate=1e-4
60+
5161
# Use a custom config directory
5262
uv run python -m claas.eval --config-dir ./my_configs --config-name my_config
5363
```
5464

5565
### Programmatic usage
5666

5767
```python
58-
from claas.eval.config import build_harness_config
5968
from claas.eval.runner import run_harness
6069
from claas.eval.types import EvalConfig
6170
import asyncio
6271
63-
config = build_harness_config(
64-
EvalConfig(
65-
preferences=["concise"],
66-
num_steps=5,
67-
)
72+
config = EvalConfig(
73+
preferences=["concise"],
74+
num_steps=5,
75+
output_dir="./data/evals/manual-run", # explicit when bypassing Hydra CLI
6876
)
6977
asyncio.run(run_harness(config))
7078
```

claas/eval/__main__.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
import hydra
1414
from omegaconf import OmegaConf
1515

16-
from .config import build_harness_config
16+
from . import config as _config # noqa: F401
1717
from .types import EvalConfig
1818

1919

@@ -25,8 +25,7 @@ def main(cfg: EvalConfig) -> None:
2525
if not isinstance(eval_cfg, EvalConfig):
2626
raise TypeError("Hydra did not produce an EvalConfig instance")
2727

28-
config = build_harness_config(eval_cfg)
29-
asyncio.run(run_harness(config))
28+
asyncio.run(run_harness(eval_cfg))
3029

3130

3231
if __name__ == "__main__":

claas/eval/config.py

Lines changed: 2 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,32 +1,7 @@
1-
"""Hydra-based configuration for the evaluation harness."""
2-
3-
from __future__ import annotations
4-
5-
import dataclasses
6-
import os
7-
import re
8-
from datetime import datetime, timezone
9-
from pathlib import Path
1+
"""Hydra schema registration for the evaluation harness."""
102

113
from hydra.core.config_store import ConfigStore
124

13-
from .types import EvalConfig, HarnessConfig
14-
15-
# Pattern matching the timestamped run-id suffix (e.g. 20260220-012345Z)
16-
_RUN_ID_RE = re.compile(r"\d{8}-\d{6}Z$")
5+
from .types import EvalConfig
176

187
ConfigStore.instance().store(name="_eval_schema", node=EvalConfig)
19-
20-
21-
def build_harness_config(eval_cfg: EvalConfig) -> HarnessConfig:
22-
"""Post-process EvalConfig → HarnessConfig (no secrets)."""
23-
fields = dataclasses.asdict(eval_cfg)
24-
25-
# Timestamped output subdir (skip if output_dir already ends with a run-id,
26-
# which allows resuming an existing run by passing its directory).
27-
output_dir = fields["output_dir"]
28-
if not _RUN_ID_RE.search(Path(output_dir).name):
29-
run_id = datetime.now(timezone.utc).strftime("%Y%m%d-%H%M%SZ")
30-
fields["output_dir"] = os.path.join(output_dir, run_id)
31-
32-
return HarnessConfig(**fields)

claas/eval/configs/base.yaml

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,15 @@ num_steps: 20
2424
batch_size: 4
2525
steps_per_batch: 4
2626
feedback_repetitions: 1
27+
training:
28+
learning_rate: 3e-5
29+
alpha: 0.5
30+
is_clip: 5.0
31+
max_grad_norm: 1.0
32+
kl_reg_weight: 0.0
33+
teacher_top_k: 100
2734
seed: 42
2835
lora_id_prefix: eval
29-
output_dir: ./data/evals
36+
output_dir: ./data/evals/${now:%Y%m%d-%H%M%SZ}
3037

3138
openclaw_url: http://localhost:18789

claas/eval/runner.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,10 @@
2323
from .plotting import generate_plots
2424
from .preferences import PreferenceConfig, get_preference_configs
2525
from .types import (
26+
EvalConfig,
2627
EvalMetrics,
2728
ExperimentResult,
2829
ExperimentSummary,
29-
HarnessConfig,
3030
LocalDistillMetrics,
3131
MetricContext,
3232
StepResult,
@@ -38,7 +38,7 @@
3838
logger = logging.getLogger(__name__)
3939

4040

41-
async def _init_lora(config: HarnessConfig, lora_id: str) -> str:
41+
async def _init_lora(config: EvalConfig, lora_id: str) -> str:
4242
"""Initialize a fresh LoRA adapter via CLaaS API."""
4343
# LoRA init can exceed two minutes when the remote trainer is cold-starting.
4444
async with httpx.AsyncClient(base_url=config.claas_url, timeout=300.0) as client:
@@ -51,7 +51,7 @@ async def _init_lora(config: HarnessConfig, lora_id: str) -> str:
5151

5252

5353
async def _submit_feedback(
54-
config: HarnessConfig,
54+
config: EvalConfig,
5555
lora_id: str,
5656
samples: list[FeedbackItem],
5757
) -> LocalDistillMetrics | TinkerDistillMetrics | None:
@@ -92,7 +92,7 @@ async def _submit_feedback(
9292

9393

9494
async def _generate_response(
95-
config: HarnessConfig,
95+
config: EvalConfig,
9696
model: str,
9797
prompt: str,
9898
temperature: float = 0,
@@ -127,7 +127,7 @@ async def _generate_response(
127127

128128

129129
async def _measure_eval_metrics(
130-
config: HarnessConfig,
130+
config: EvalConfig,
131131
pref: PreferenceConfig,
132132
model_name: str,
133133
step: int,
@@ -205,7 +205,7 @@ def _append_step_jsonl(output_dir: str, preference: str, step: StepResult) -> No
205205
f.write(json.dumps(data) + "\n")
206206

207207

208-
def _write_metadata(output_dir: str, preference: str, config: HarnessConfig, lora_id: str) -> None:
208+
def _write_metadata(output_dir: str, preference: str, config: EvalConfig, lora_id: str) -> None:
209209
"""Write experiment metadata."""
210210
pref_dir = os.path.join(output_dir, preference)
211211
os.makedirs(pref_dir, exist_ok=True)
@@ -273,7 +273,7 @@ def _write_summary(output_dir: str, results: list[ExperimentResult]) -> None:
273273

274274

275275
async def run_preference_experiment(
276-
config: HarnessConfig,
276+
config: EvalConfig,
277277
pref: PreferenceConfig,
278278
enabled_metrics: list[Metric] | None = None,
279279
needs_generation: bool = False,
@@ -357,6 +357,8 @@ async def run_preference_experiment(
357357
)
358358

359359
# Main loop
360+
training_cfg = config.training
361+
360362
for step in range(resume_from, config.num_steps):
361363
step_start = time.perf_counter()
362364

@@ -385,6 +387,7 @@ async def run_preference_experiment(
385387
prompt=prompt,
386388
response=content,
387389
feedback=feedback_str,
390+
training=training_cfg,
388391
))
389392
except (httpx.HTTPError, KeyError, ValueError) as e:
390393
logger.warning(
@@ -471,7 +474,7 @@ async def run_preference_experiment(
471474
return result
472475

473476

474-
async def run_harness(config: HarnessConfig) -> None:
477+
async def run_harness(config: EvalConfig) -> None:
475478
"""Run the full evaluation harness."""
476479
logging.basicConfig(
477480
level=logging.INFO,

claas/eval/types.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from typing import Any, Optional
88

99
from claas.core.config import DEFAULT_SYSTEM_PROMPT
10-
from claas.core.types import ChatMessage
10+
from claas.core.types import ChatMessage, TrainingConfig
1111

1212

1313
@dataclass
@@ -96,10 +96,7 @@ class EvalConfig:
9696
batch_size: int = 4
9797
steps_per_batch: int = 4
9898
feedback_repetitions: int = 1
99-
100-
101-
# HarnessConfig is the post-processed runtime config (still no secrets).
102-
HarnessConfig = EvalConfig
99+
training: TrainingConfig = field(default_factory=TrainingConfig)
103100

104101

105102
@dataclass

0 commit comments

Comments
 (0)