
Commit 838bf90

kfallah, claude, and Kion authored
Move multi-step training into TrainingConfig with per-step IS correction (#39)
## Summary

- move multi-step training controls (`steps_per_batch`, `feedback_repetitions`) from eval-owned settings into `TrainingConfig`
- remove eval-side sub-step loop and pass typed `training` config through `FeedbackItem` in each `/v1/feedback` request
- execute multi-step updates inside training engines (local/modal + tinker)
- recompute behavior-policy logprobs after each optimizer step for off-policy importance reweighting
- include engine metadata (`steps_per_batch_applied`, per-step metrics) and wire eval `sub_step_count` to that metadata
- update eval Hydra schema/config/docs and related tests

## Key Implementation Notes

- added strict `TrainingConfig` fields:
  - `steps_per_batch`
  - `feedback_repetitions`
- introduced Hydra-safe `EvalTrainingConfig` and convert to runtime `TrainingConfig` in `build_harness_config`
- tinker engine now refreshes student logprobs between steps using `save_weights_and_get_sampling_client_async`

## Validation

- `uv run ruff check claas/ tests/ --fix`
- `uv run pytest tests/ -q -m "not integration"`
  - result: `109 passed, 26 skipped, 5 deselected`
- `uv run ty check`
  - unresolved-import diagnostics for heavy runtime deps (`torch`, `tinker`, `transformers`) are expected in this environment

## Summary by CodeRabbit

## Release Notes

* **New Features**
  * Added support for multi-step training per batch with configurable `steps_per_batch` parameter
  * Added `feedback_repetitions` configuration option for enhanced training control
  * New metric `steps_per_batch_applied` tracks actual steps executed per batch
* **Documentation**
  * Updated configuration structure to use nested training block for training-specific parameters
* **Refactor**
  * Reorganized configuration hierarchy to consolidate training settings under dedicated training section

---------

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
Co-authored-by: Kion <kion@onepiece.localdomain>
1 parent 34fa060 commit 838bf90

15 files changed

Lines changed: 561 additions & 217 deletions
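Before the per-file diffs, here is a minimal sketch of the engine-side loop the summary describes: several optimizer steps on one batch, with behavior-policy logprobs recomputed after every step so the importance-sampling ratios track the moving policy. This is an illustration only; `Batch`, `loss_fn`, and `logprob_fn` are assumed stand-ins, not the repo's actual engine API.

```python
from dataclasses import dataclass
from typing import Callable

import torch


@dataclass
class Batch:
    inputs: dict                     # tokenized prompt+completion tensors
    behavior_logprobs: torch.Tensor  # logprobs under the sampling policy


def train_batch(
    model: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
    batch: Batch,
    steps_per_batch: int,
    max_grad_norm: float,
    loss_fn: Callable[[torch.nn.Module, Batch], torch.Tensor],
    logprob_fn: Callable[[torch.nn.Module, Batch], torch.Tensor],
) -> dict:
    """Run several optimizer steps on one batch with per-step IS correction."""
    for _ in range(steps_per_batch):
        # loss_fn is assumed to weight tokens by clipped importance
        # ratios pi_theta / pi_behavior computed from batch.behavior_logprobs.
        loss = loss_fn(model, batch)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        optimizer.step()
        optimizer.zero_grad()

        # After the update the policy has moved; refresh the behavior-policy
        # logprobs so the next step's importance ratios stay well-defined.
        with torch.no_grad():
            batch.behavior_logprobs = logprob_fn(model, batch)

    return {"steps_per_batch_applied": steps_per_batch}
```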

.claude/skills/setup-local/SKILL.md

Lines changed: 21 additions & 7 deletions
````diff
@@ -1,6 +1,6 @@
 ---
 name: setup-local
-description: Set up the full CLaaS stack (vLLM + API + OpenClaw/Telegram) directly on the host without Docker. Use when Docker is unavailable or you want a native setup.
+description: Set up the full CLaaS stack (vLLM + API + OpenClaw/Telegram) locally. Uses Docker if available, falls back to native setup otherwise.
 ---
 
 # Setup Local
@@ -46,6 +46,10 @@ uv pip install "torch>=2.1.0+cu128" torchvision torchaudio \
   --index-url https://download.pytorch.org/whl/cu128 --reinstall
 uv pip install "numpy<2.3"  # numba compatibility
 
+# Flash Attention 2 — required for local training (default attn_implementation)
+# Must install AFTER torch with --no-build-isolation so it links against the CUDA torch
+uv pip install flash-attn --no-build-isolation
+
 # OpenClaw
 npm install -g openclaw@latest
 ```
@@ -109,28 +113,37 @@ EOF
 
 ```bash
 LORA_ROOT="${HOME}/.local/share/claas/loras"
+# Create the aliases file if it doesn't exist (the start script reads it)
+[ -f "$LORA_ROOT/.aliases.json" ] || echo '{}' > "$LORA_ROOT/.aliases.json"
+
 export PATH="$(pwd)/.venv/bin:$PATH"  # puts 'vllm' on PATH
 export MODEL=Qwen/Qwen3-8B HOST=0.0.0.0 PORT=8000 API_KEY=sk-local
 export SERVED_MODEL_NAMES=qwen3-8b MAX_MODEL_LEN=32768 GPU_MEMORY_UTILIZATION=0.70
 export ENABLE_SLEEP_MODE=1 VLLM_SERVER_DEV_MODE=1 VLLM_ALLOW_RUNTIME_LORA_UPDATING=1
 export ENABLE_AUTO_TOOL_CHOICE=1 TOOL_CALL_PARSER=qwen3_xml
 export LORA_ROOT="$LORA_ROOT" LORA_ALIAS_FILE="$LORA_ROOT/.aliases.json" INCLUDE_ALIAS_LORAS=1
+# Enable LoRA even with no initial adapters — needed for runtime LoRA loading
+export EXTRA_ARGS='--enable-lora --max-lora-rank 32'
 
-bash scripts/openclaw-local/start_vllm_qwen3_8b.sh >> /tmp/vllm.log 2>&1 &
+bash docker/scripts/start_vllm_qwen3_8b.sh >> /tmp/vllm.log 2>&1 &
 
 # First run downloads Qwen3-8B (~16 GB) — expect 5-20 min
 until curl -sf http://localhost:8000/health; do sleep 5; done && echo "vLLM ready"
 ```
 
 ### 4. Start CLaaS API
 
+The API must be started via its Hydra entry point (not bare `uvicorn`) so that the
+runtime config is loaded and `configure_web_app()` is called. Override `lora_root`
+to point to the local LoRA directory (the default `/loras` is the Docker path).
+
 ```bash
-CLAAS_CONFIG_NAME=local \
-CLAAS_LORA_ROOT="${HOME}/.local/share/claas/loras" \
-VLLM_BASE_URL=http://localhost:8000 \
 VLLM_API_KEY=sk-local \
-FEEDBACK_LOG_DIR=/tmp/feedback-logs \
-uv run uvicorn claas.api:web_app --host 0.0.0.0 --port 8080 >> /tmp/claas-api.log 2>&1 &
+uv run python -m runpy claas.api \
+  lora_root="${HOME}/.local/share/claas/loras" \
+  feedback_log_dir=/tmp/feedback-logs \
+  'hydra.run.dir=.' \
+  >> /tmp/claas-api.log 2>&1 &
 
 curl -sf http://localhost:8080/v1/health
 ```
@@ -172,6 +185,7 @@ Report the status of all four components and the Telegram bot username.
 | `Numba needs NumPy 2.2 or less` | `uv pip install "numpy<2.3"` |
 | `Python.h: No such file or directory` | Recreate venv with uv-managed Python (step 1 note) |
 | `No API key found for provider "local"` | Create `auth-profiles.json` (step 2) |
+| `flash_attn seems to be not installed` | `uv pip install flash-attn --no-build-isolation` (requires CUDA torch first) |
 | vLLM OOM | Lower `GPU_MEMORY_UTILIZATION` to `0.60` |
 
 ## Logs
````

.gitignore

Lines changed: 5 additions & 0 deletions
```diff
@@ -51,6 +51,11 @@ htmlcov/
 feedback_logs/
 .local_loras/
 .run-logs/
+.hydra/
+node_modules/
+package.json
+package-lock.json
+EXPERIMENTS.md
 
 # Runtime data (feedback logs, eval results, Hydra logs)
 data/feedback/
```

claas/core/types.py

Lines changed: 7 additions & 0 deletions
```diff
@@ -32,6 +32,13 @@ class TrainingConfig:
     max_grad_norm: float = 1.0
     kl_reg_weight: float = 0.0
     teacher_top_k: int = 100
+    steps_per_batch: int = 4
+    feedback_repetitions: int = 1
+
+    def __post_init__(self) -> None:
+        if self.steps_per_batch < 1:
+            msg = f"steps_per_batch must be >= 1, got {self.steps_per_batch}"
+            raise ValueError(msg)
 
 
 class SDPOLossInput(BaseModel):
```
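As a quick illustration of the new guard (assuming `TrainingConfig` is a plain dataclass, which the `__post_init__` hook suggests), construction now fails fast on a bad step count:

```python
from claas.core.types import TrainingConfig

cfg = TrainingConfig()                     # defaults: steps_per_batch=4, feedback_repetitions=1
multi = TrainingConfig(steps_per_batch=8)  # opt in to more optimizer steps per batch

try:
    TrainingConfig(steps_per_batch=0)      # rejected by __post_init__
except ValueError as err:
    print(err)  # steps_per_batch must be >= 1, got 0
```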

claas/eval/README.md

Lines changed: 10 additions & 9 deletions
````diff
@@ -26,22 +26,23 @@ metrics: # metrics to evaluate per step
 
 num_steps: 20
 batch_size: 4
-steps_per_batch: 4        # gradient updates per batch
-feedback_repetitions: 1   # times to repeat feedback string
-training:                 # forwarded to /v1/feedback training config
-  learning_rate: 3e-5
-  alpha: 0.5
-  is_clip: 5.0
-  max_grad_norm: 1.0
-  kl_reg_weight: 0.0
-  teacher_top_k: 100
 collapse_steps: [0, 5, 10, 15, 19]  # steps where collapse metric runs
 plots: true                         # generate matplotlib plots
 seed: 42
 lora_id_prefix: eval
 output_dir: ./data/evals/${now:%Y%m%d-%H%M%SZ}
 
 openclaw_url: http://localhost:18789  # OpenClaw gateway (null = use CLaaS API directly)
+
+training:                   # forwarded to /v1/feedback TrainingConfig
+  learning_rate: 3e-5
+  alpha: 0.5
+  is_clip: 5.0
+  max_grad_norm: 1.0
+  kl_reg_weight: 0.0
+  teacher_top_k: 100
+  steps_per_batch: 4        # gradient updates per batch
+  feedback_repetitions: 1   # times to repeat feedback string
 ```
 
 ### Overriding config via CLI
````

claas/eval/configs/base.yaml

Lines changed: 8 additions & 7 deletions
```diff
@@ -22,17 +22,18 @@ plots: true
 
 num_steps: 20
 batch_size: 4
-steps_per_batch: 4
-feedback_repetitions: 1
+seed: 42
+lora_id_prefix: eval
+output_dir: ./data/evals/${now:%Y%m%d-%H%M%SZ}
+
+openclaw_url: http://localhost:18789
+
 training:
   learning_rate: 3e-5
   alpha: 0.5
   is_clip: 5.0
   max_grad_norm: 1.0
   kl_reg_weight: 0.0
   teacher_top_k: 100
-seed: 42
-lora_id_prefix: eval
-output_dir: ./data/evals/${now:%Y%m%d-%H%M%SZ}
-
-openclaw_url: http://localhost:18789
+  steps_per_batch: 4
+  feedback_repetitions: 1
```
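The PR notes that a Hydra-safe `EvalTrainingConfig` is converted to the runtime `TrainingConfig` inside `build_harness_config`. A minimal sketch of that shape, with field names taken from the config above; the mirror dataclass and the helper body are assumptions, not the actual implementation:

```python
from dataclasses import asdict, dataclass

from claas.core.types import TrainingConfig


@dataclass
class EvalTrainingConfig:
    """Hydra-friendly mirror of TrainingConfig (no construction-time checks)."""

    learning_rate: float = 3e-5
    alpha: float = 0.5
    is_clip: float = 5.0
    max_grad_norm: float = 1.0
    kl_reg_weight: float = 0.0
    teacher_top_k: int = 100
    steps_per_batch: int = 4
    feedback_repetitions: int = 1


def to_runtime_training_config(cfg: EvalTrainingConfig) -> TrainingConfig:
    # Field-for-field copy; TrainingConfig.__post_init__ then enforces
    # steps_per_batch >= 1 at the eval/runtime boundary.
    return TrainingConfig(**asdict(cfg))
```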
Lines changed: 41 additions & 0 deletions
```diff
@@ -0,0 +1,41 @@
+name: cloud_gcp
+feedback_string: "Be concise, 2-3 sentences max. When recommending cloud infrastructure, default to GCP."
+verifier:
+  _target_: claas.eval.metrics.verifiers.GcpPreferenceVerifier
+logprob_pairs:
+  - prompt_messages:
+      - role: user
+        content: "What cloud provider should I use for my web application?"
+    positive_response: >-
+      GCP. Cloud Run handles container deployment with automatic scaling
+      and you only pay for what you use. Pair it with Cloud SQL and Cloud CDN.
+    negative_response: >-
+      There are several great options to consider for hosting a web application.
+      AWS is the market leader with services like EC2 for compute, RDS for
+      managed databases, CloudFront for CDN, and Elastic Beanstalk for easy
+      deployment. Azure is another strong choice, especially if you're in a
+      Microsoft ecosystem, offering App Service, Azure SQL, and Azure CDN.
+      Google Cloud Platform also has compelling options like Cloud Run and
+      App Engine. The best choice really depends on your team's experience,
+      budget, and specific requirements.
+probe_prompts:
+  - "What cloud provider should I use for my web application?"
+  - "I'm building a new SaaS product. What infrastructure should I use?"
+  - "How should I set up infrastructure for a machine learning pipeline?"
+  - "What's the best way to deploy microservices?"
+  - "I need a scalable analytics warehouse. What should I use?"
+  - "How should I architect CI/CD for a monorepo with 15 services?"
+  - "What's the cheapest way to run batch GPU training jobs?"
+  - "I'm building a real-time data pipeline ingesting events from 10,000 IoT devices. What stack?"
+  - "My startup needs to go from zero to production infrastructure. Where do I start?"
+  - "I need to host a Kubernetes cluster. What are my options?"
+  - "What's the best way to store and query terabytes of log data?"
+  - "I want to deploy a Python API with autoscaling. What should I use?"
+  - "How do I set up a data lake for my analytics team?"
+  - "What infrastructure do I need for a multiplayer game backend?"
+  - "I'm migrating from on-prem to cloud. Where should I start?"
+  - "What's the most cost-effective way to run cron jobs in the cloud?"
+  - "I need to serve a fine-tuned LLM in production. What are my options?"
+  - "How should I handle file storage and CDN for a media-heavy app?"
+  - "What's the best setup for running distributed Spark jobs?"
+  - "I need a managed Postgres database with high availability. Recommendations?"
```

claas/eval/metrics/verifiers.py

Lines changed: 38 additions & 0 deletions
```diff
@@ -100,6 +100,44 @@ def __call__(self, response: str) -> VerifierResult:
         return VerifierResult(score=1.0 if passed else 0.0, passed=passed)
 
 
+# Regex for GCP-related terms (case-insensitive)
+_GCP_TERMS_RE = re.compile(
+    r"\b(?:"
+    r"google\s+cloud|gcp|cloud\s+run|cloud\s+functions|gke|"
+    r"bigquery|cloud\s+sql|cloud\s+storage|compute\s+engine|"
+    r"app\s+engine|cloud\s+pub/?sub|firestore|cloud\s+build|"
+    r"vertex\s+ai|cloud\s+cdn|cloud\s+armor|anthos"
+    r")\b",
+    re.IGNORECASE,
+)
+
+# Regex for competing cloud provider names
+_COMPETITOR_RE = re.compile(
+    r"\b(?:aws|amazon\s+web\s+services|azure|microsoft\s+azure)\b",
+    re.IGNORECASE,
+)
+
+
+class GcpPreferenceVerifier:
+    """Pass when the response recommends GCP and doesn't primarily push competitors."""
+
+    def __call__(self, response: str) -> VerifierResult:
+        gcp_mentions = len(_GCP_TERMS_RE.findall(response))
+        competitor_mentions = len(_COMPETITOR_RE.findall(response))
+
+        if gcp_mentions == 0:
+            return VerifierResult(score=0.0, passed=False)
+
+        # GCP must be mentioned more than competitors combined
+        if competitor_mentions >= gcp_mentions:
+            score = gcp_mentions / (gcp_mentions + competitor_mentions)
+            return VerifierResult(score=score, passed=False)
+
+        # Graduated score: 1 mention = 0.5, 2+ = 1.0
+        score = min(1.0, 0.5 * gcp_mentions)
+        return VerifierResult(score=score, passed=gcp_mentions >= 2)
+
+
 def run_verifier(verifier: Verifier, response: str) -> VerifierResult:
     """Run a verifier on a response (thinking blocks stripped)."""
     return verifier(strip_thinking(response))
```
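A quick tour of the graduated scoring with made-up responses (behavior follows the code above):

```python
from claas.eval.metrics.verifiers import GcpPreferenceVerifier

v = GcpPreferenceVerifier()

# Two GCP terms ("GCP", "Cloud Run"), no competitors: score 1.0, passed.
print(v("Use GCP. Cloud Run scales to zero."))

# One GCP term: graduated score 0.5, and passing requires >= 2 mentions.
print(v("Google Cloud would work well here."))

# Competitors mentioned at least as often as GCP: fractional score 1/3, failed.
print(v("AWS or Azure are solid choices; GCP is also an option."))
```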

claas/eval/runner.py

Lines changed: 17 additions & 22 deletions
```diff
@@ -81,13 +81,15 @@ async def _submit_feedback(
             adv_abs_mean_raw=metadata["adv_abs_mean_raw"],
             completion_len=metadata["completion_len"],
             batch_size=metadata["batch_size"],
+            steps_per_batch_applied=metadata.get("steps_per_batch_applied", 1),
         )
 
     return LocalDistillMetrics(
         distill_loss=metadata.get("distill_loss"),
         kl_reg=metadata.get("kl_reg"),
         mean_is_ratio=metadata.get("mean_is_ratio"),
         clip_fraction=metadata.get("clip_fraction"),
+        steps_per_batch_applied=metadata.get("steps_per_batch_applied", 1),
     )
 
 
@@ -190,6 +192,7 @@ def _load_completed_steps(output_dir: str, preference: str) -> list[StepResult]:
             prompt_used=data["prompt_used"],
             response_text=data.get("response_text"),
             timing_s=data.get("timing_s", 0.0),
+            sub_step_count=data.get("sub_step_count", 1),
         ))
     return steps
 
@@ -362,8 +365,8 @@ async def run_preference_experiment(
     for step in range(resume_from, config.num_steps):
         step_start = time.perf_counter()
 
-        # Determine feedback string
-        feedback_str = " ".join([pref.feedback_string] * config.feedback_repetitions)
+        # Feedback repetition is a training concern configured via TrainingConfig.
+        feedback_str = pref.feedback_string
 
         # Collect samples for this step (batch_size >= 1)
        samples: list[FeedbackItem] = []
@@ -398,29 +401,21 @@ async def run_preference_experiment(
         if response_text is None:
             response_text = "I'd be happy to help you with that."
 
-        # Submit feedback — possibly multiple gradient steps on same batch
+        # Submit feedback for this step. Training engine applies steps_per_batch.
         sdpo_metrics = None
-        sub_steps_completed = 0
         if samples:
-            for sub_step in range(config.steps_per_batch):
-                try:
-                    sdpo_metrics = await _submit_feedback(
-                        config, actual_lora_id, samples,
-                    )
-                    sub_steps_completed += 1
-                except (httpx.HTTPError, KeyError) as e:
-                    logger.warning(
-                        "[%s] Step %d sub-step %d feedback failed: %s",
-                        pref.name, step, sub_step, e,
-                    )
-                    break
-
-            if config.steps_per_batch > 1:
-                logger.info(
-                    "[%s] Step %d: %d sub-steps completed",
-                    pref.name, step, sub_steps_completed,
+            try:
+                sdpo_metrics = await _submit_feedback(
+                    config, actual_lora_id, samples,
+                )
+            except (httpx.HTTPError, KeyError) as e:
+                logger.warning(
+                    "[%s] Step %d feedback failed: %s",
+                    pref.name, step, e,
                 )
 
+        sub_step_count = sdpo_metrics.steps_per_batch_applied if sdpo_metrics else 1
+
         # Measure eval
         try:
             eval_metrics = await _measure_eval_metrics(
@@ -447,7 +442,7 @@ async def run_preference_experiment(
             ],
             response_text=response_text if needs_generation else None,
             timing_s=timing_s,
-            sub_step_count=sub_steps_completed if sub_steps_completed > 0 else 1,
+            sub_step_count=sub_step_count,
         )
 
         result.steps.append(step_result)
```

claas/eval/types.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -94,8 +94,6 @@ class EvalConfig:
     openclaw_url: Optional[str] = None
     base_model: str = "Qwen/Qwen3-8B"
     batch_size: int = 4
-    steps_per_batch: int = 4
-    feedback_repetitions: int = 1
     training: TrainingConfig = field(default_factory=TrainingConfig)
 
 
@@ -117,6 +115,7 @@ class LocalDistillMetrics:
     kl_reg: float | None
     mean_is_ratio: float | None
     clip_fraction: float | None
+    steps_per_batch_applied: int = 1
 
 
 @dataclass
@@ -131,6 +130,7 @@ class TinkerDistillMetrics:
     adv_abs_mean_raw: float
     completion_len: int = 0
     batch_size: int = 0
+    steps_per_batch_applied: int = 1
 
 
 @dataclass
```

claas/inference/vllm.py

Lines changed: 10 additions & 0 deletions
```diff
@@ -167,6 +167,16 @@ async def chat_completion(
 
         usage = data.get("usage", {})
 
+        # vLLM includes the stop token (e.g. <|im_end|>) in logprobs but the
+        # tokenizer doesn't produce it when re-encoding the text. Trim the
+        # logprobs so the two sequences stay aligned.
+        if (
+            response_logprobs is not None
+            and response_token_ids
+            and len(response_logprobs) > len(response_token_ids)
+        ):
+            response_logprobs = response_logprobs[: len(response_token_ids)]
+
         return CompletionResult(
             content=content,
             raw_prompt=raw_prompt,
```
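A toy illustration of the misalignment this guards against, with made-up values:

```python
# vLLM reports a logprob for the stop token, but re-encoding the completion
# text does not reproduce that token, leaving one extra logprob entry.
response_logprobs = [-0.11, -0.42, -0.07, -1.93]  # last entry: <|im_end|>
response_token_ids = [1234, 5678, 9012]           # re-encoded completion text

if (
    response_logprobs is not None
    and response_token_ids
    and len(response_logprobs) > len(response_token_ids)
):
    response_logprobs = response_logprobs[: len(response_token_ids)]

assert len(response_logprobs) == len(response_token_ids)  # sequences aligned
```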
