Move multi-step training into TrainingConfig with per-step IS correction#39
Conversation
📝 Walkthrough
This PR implements multi-step gradient updates within single batches and feedback repetition control. Configuration is restructured to nest training parameters under a `TrainingConfig`.
Changes
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Trainer as DistillationTrainer
    participant PreparedSamples as Prepared Samples
    participant Step as Step Loop
    participant StudentModel as Student Model
    participant Optimizer as Optimizer
    Trainer->>PreparedSamples: Validate & accumulate samples<br/>(full_ids, response_ids, logprobs)
    PreparedSamples-->>Trainer: PreparedSample list
    Trainer->>Step: For each step in steps_per_batch
    Step->>StudentModel: Compute student response logprobs<br/>(current adapter state)
    StudentModel-->>Step: per_step_logprobs
    Step->>Step: Build SDPOLossInput from<br/>prepared samples + new logprobs
    Step->>Step: Compute per-step loss<br/>(distill_loss, kl_reg, clip)
    Step->>Optimizer: Backward & gradient update<br/>with clipping
    Optimizer-->>Step: updated model state
    Step->>StudentModel: Recompute behavior_logprobs<br/>for next step
    StudentModel-->>Step: updated logprobs
    Step-->>Trainer: step metrics & updated state
    Trainer->>Trainer: Aggregate per-step metrics<br/>steps_per_batch_applied
    Trainer-->>Trainer: Return per-step results<br/>& tokens processed
```
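The control flow in the diagram can be sketched in plain Python. Everything below is an illustrative stand-in — `compute_logprobs`, `compute_loss`, and `apply_gradients` are hypothetical callables, not the PR's actual APIs; only the loop structure mirrors the diagram.

```python
# Illustrative skeleton of the diagram's multi-step flow. All callables
# are hypothetical stand-ins; only the control flow mirrors the diagram.
def train_batch(prepared_samples, compute_logprobs, compute_loss,
                apply_gradients, steps_per_batch: int) -> list[dict]:
    step_metrics = []
    # Logprobs under the current adapter state, used as the behavior
    # policy for the first step.
    behavior_logprobs = compute_logprobs(prepared_samples)
    for _ in range(steps_per_batch):
        # Per-step loss (distill_loss, kl_reg, clip) against the
        # behavior logprobs, then a gradient update with clipping.
        loss, metrics = compute_loss(prepared_samples, behavior_logprobs)
        apply_gradients(loss)
        step_metrics.append(metrics)
        # Recompute behavior logprobs under the updated weights so the
        # next step's IS correction uses the fresh policy.
        behavior_logprobs = compute_logprobs(prepared_samples)
    return step_metrics
```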
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
🚥 Pre-merge checks: ✅ Passed checks (3 passed)
Actionable comments posted: 1
🧹 Nitpick comments (5)
claas/training/engine/tinker/engine.py (2)
239-241: Lambda used for averaging — minor style nit. The `avg` lambda is re-created on each loop iteration. Consider extracting it before the loop or using a local function.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@claas/training/engine/tinker/engine.py` around lines 239 - 241, The averaging lambda avg is being recreated each loop iteration; extract it as a local helper function or define it once before the loop to avoid recreating the closure repeatedly. Replace the inline lambda assignment avg = lambda key: ... with a named function (e.g., def avg(key): return sum(m[key] for m in sample_metrics) / n) or move the lambda definition above the loop where sample_metrics and n are available, and update all uses of avg (referenced as avg and sample_metrics in this block) accordingly.
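A minimal sketch of the suggested refactor — hoist the helper out of the loop as a named function. `sample_metrics` here is a hypothetical stand-in for the engine's per-sample metric dicts, not the real data.

```python
# Sketch only: `sample_metrics` is illustrative, not the engine's data.
sample_metrics = [
    {"distill_loss": 0.5, "kl_reg": 0.1},
    {"distill_loss": 0.3, "kl_reg": 0.2},
]
n = len(sample_metrics)

def avg(key: str) -> float:
    """Average one metric across all samples in the batch."""
    return sum(m[key] for m in sample_metrics) / n

assert abs(avg("distill_loss") - 0.4) < 1e-9
```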
218-261: Multi-step loop with Tinker SDK: correct, but note the cost of intermediate weight saves. The flow is sound: build datums → forward/backward → optimizer step → recompute logprobs. The `save_weights_and_get_sampling_client_async` call at line 257 is required by Tinker's architecture to get a sampling client with updated weights, but it means each intermediate step (all except the last) triggers a full weight save. For `steps_per_batch > 2`, this could be a latency concern. Worth documenting this tradeoff, or considering whether Tinker offers a lighter-weight way to get an updated sampling client without a full checkpoint save.
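A hedged sketch of option (b) from this comment: refresh the sampling client only between steps (never after the final one), gated behind a flag. The method name is taken from the review; the loop body is simplified and the training client is stubbed with `AsyncMock`, so this shows only the gating logic, not the real engine.

```python
import asyncio
from unittest.mock import AsyncMock

async def run_steps(training_client, steps_per_batch: int,
                    refresh_between_steps: bool = True) -> int:
    """Run the multi-step loop; return how many weight saves happened."""
    saves = 0
    for step in range(steps_per_batch):
        # ... forward/backward + optimizer step would go here ...
        last_step = step == steps_per_batch - 1
        if refresh_between_steps and not last_step:
            # Full weight save — the costly part flagged in the review.
            await training_client.save_weights_and_get_sampling_client_async()
            saves += 1
    return saves

client = AsyncMock()
# 4 steps trigger only 3 intermediate saves, and the flag disables them.
assert asyncio.run(run_steps(client, steps_per_batch=4)) == 3
assert asyncio.run(run_steps(client, 4, refresh_between_steps=False)) == 0
```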
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@claas/training/engine/tinker/engine.py` around lines 218 - 261, The loop calls training_client.save_weights_and_get_sampling_client_async inside the step loop (see save_weights_and_get_sampling_client_async, steps_per_batch and training_client) which triggers a full weight save on every intermediate step and can cause latency when steps_per_batch > 2; update the code to either (a) document this tradeoff just above the loop and in the function docstring, or (b) add a configurable behavior (e.g., a flag like save_intermediate_weights) so you only call save_weights_and_get_sampling_client_async for steps where it's necessary (or avoid it until the final step), and, if the Tinker SDK offers a lighter alternative to get an updated sampling client, switch to that API instead.
tests/test_tinker_engine.py (1)
91-122: Consider parameterizing mock save paths for multi-step scenarios. The `mock_training_client` fixture returns a fixed `save_result.path = "tinker://checkpoints/step-1"` regardless of the checkpoint name passed to `save_state_async`. This works for current tests, but if future tests need to assert that the saved path reflects the actual step, the fixture would need to be updated. Not a blocker.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/test_tinker_engine.py` around lines 91 - 122, The fixture mock_training_client currently returns a fixed save_result.path ("tinker://checkpoints/step-1") for save_state_async; change it so save_state_async uses an AsyncMock side_effect that builds and returns a MagicMock whose .path is derived from the checkpoint name/step passed into save_state_async (e.g., include the step id or checkpoint name from the method args), and do the same for save_weights_for_sampler_async/sampler_save.path if needed; update references to save_result and sampler_save in the fixture to be created inside the side_effects so tests that call mock_training_client.save_state_async(...) will receive a result object with a path that reflects the input.
claas/training/distillation.py (1)
38-47: PreparedSample name collision with the tinker engine. Both `claas/training/distillation.py` and `claas/training/engine/tinker/engine.py` define a `PreparedSample` TypedDict with different fields (torch.Tensor-based vs. list-based). This works fine since they're module-private, but could cause confusion when navigating the codebase or in IDE symbol search. Consider naming one of them more specifically (e.g., `LocalPreparedSample` or `TinkerPreparedSample`) to disambiguate.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@claas/training/distillation.py` around lines 38 - 47, The TypedDict PreparedSample in claas/training/distillation.py collides by name with another PreparedSample in claas/training/engine/tinker/engine.py; rename this TypedDict to a more specific name (e.g., DistillationPreparedSample or LocalPreparedSample) and update all local type annotations and imports in claas/training/distillation.py that reference PreparedSample (functions, return types, variables) to use the new name so the module remains unambiguous while preserving the same fields and behavior.
claas/eval/types.py (1)
80-91: Risk of default drift between `EvalTrainingConfig` and `TrainingConfig`. `EvalTrainingConfig` manually duplicates field names and defaults from the Pydantic `TrainingConfig` (in `claas/core/types.py`). If a default changes in one but not the other, eval runs will silently use stale values. Consider adding a test or a factory that asserts parity.
💡 Example: add a parity test

```python
# tests/test_eval_config.py (or similar)
import dataclasses

from claas.core.types import TrainingConfig
from claas.eval.types import EvalTrainingConfig


def test_eval_training_config_defaults_match():
    """Ensure EvalTrainingConfig defaults stay in sync with TrainingConfig."""
    runtime = TrainingConfig()
    hydra = EvalTrainingConfig()
    for f in dataclasses.fields(hydra):
        assert getattr(hydra, f.name) == getattr(runtime, f.name), (
            f"Default mismatch on '{f.name}': "
            f"EvalTrainingConfig={getattr(hydra, f.name)} vs "
            f"TrainingConfig={getattr(runtime, f.name)}"
        )
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@claas/eval/types.py` around lines 80 - 91, EvalTrainingConfig duplicates defaults from the Pydantic TrainingConfig which can drift; add a parity test that instantiates TrainingConfig and EvalTrainingConfig and asserts all field defaults match (use dataclasses.fields on EvalTrainingConfig and compare getattr(hydra, name) == getattr(runtime, name)), e.g. add tests/test_eval_config.py to fail CI if any default on EvalTrainingConfig diverges from TrainingConfig; alternatively implement a factory that constructs EvalTrainingConfig from TrainingConfig to guarantee parity and update usages to call that factory instead of hardcoding defaults.
ℹ️ Review info
Configuration used: Organization UI
Review profile: CHILL
Plan: Pro
Cache: Disabled due to data retention organization setting
Knowledge base: Disabled due to data retention organization setting
📒 Files selected for processing (11)
- claas/core/types.py
- claas/eval/README.md
- claas/eval/config.py
- claas/eval/configs/base.yaml
- claas/eval/runner.py
- claas/eval/types.py
- claas/training/distillation.py
- claas/training/engine/tinker/engine.py
- tests/test_eval_config.py
- tests/test_eval_runner.py
- tests/test_tinker_engine.py
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 762072da79
ℹ️ About Codex in GitHub
Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".
…aining
# Conflicts:
#	claas/core/types.py
#	claas/eval/README.md
#	claas/eval/config.py
#	claas/eval/configs/base.yaml
#	claas/eval/runner.py
#	claas/eval/types.py
#	tests/test_eval_config.py
#	tests/test_eval_runner.py
@codex review
♻️ Duplicate comments (1)
claas/eval/runner.py (1)
84-84: ⚠️ Potential issue | 🟠 Major
Use safe default access for `steps_per_batch_applied` to avoid dropping metrics. Line 84 and line 92 still use `metadata["steps_per_batch_applied"]`. If omitted by an engine, this raises `KeyError`, and the catch path discards the entire SDPO metrics object for that step.
Suggested fix

```diff
 if config.mode == "tinker" and "adv_mean" in metadata:
     return TinkerDistillMetrics(
 @@
-        steps_per_batch_applied=metadata["steps_per_batch_applied"],
+        steps_per_batch_applied=metadata.get("steps_per_batch_applied", 1),
     )
 @@
 return LocalDistillMetrics(
     distill_loss=metadata.get("distill_loss"),
     kl_reg=metadata.get("kl_reg"),
     mean_is_ratio=metadata.get("mean_is_ratio"),
     clip_fraction=metadata.get("clip_fraction"),
-    steps_per_batch_applied=metadata["steps_per_batch_applied"],
+    steps_per_batch_applied=metadata.get("steps_per_batch_applied", 1),
 )
```

Also applies to: 92-92
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@claas/eval/runner.py` at line 84, Replace direct indexing of metadata["steps_per_batch_applied"] with a safe lookup that supplies a sensible default (e.g., metadata.get("steps_per_batch_applied", 1)) to avoid raising KeyError and dropping the SDPO metrics object; update both occurrences that reference steps_per_batch_applied in claas.eval.runner (the two places around the current lines using metadata["steps_per_batch_applied"]) so downstream logic receives the fallback value when the engine omits the key.
ℹ️ Review info
Configuration used: Organization UI
Review profile: CHILL
Plan: Pro
Cache: Disabled due to data retention organization setting
Knowledge base: Disabled due to data retention organization setting
📒 Files selected for processing (7)
- claas/core/types.py
- claas/eval/README.md
- claas/eval/configs/base.yaml
- claas/eval/runner.py
- claas/eval/types.py
- tests/test_eval_config.py
- tests/test_eval_runner.py
🚧 Files skipped from review as they are similar to previous changes (3)
- claas/core/types.py
- claas/eval/types.py
- tests/test_eval_config.py
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: e2331f72b5
ℹ️ About Codex in GitHub
Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".
```python
max_grad_norm: float = 1.0
kl_reg_weight: float = 0.0
teacher_top_k: int = 100
steps_per_batch: int = 4
```
Enforce positive `steps_per_batch` in TrainingConfig
The newly added steps_per_batch field has no lower-bound validation, but both multi-step trainers now assume at least one iteration and unconditionally read step_metrics[-1] (claas/training/distillation.py and claas/training/engine/tinker/engine.py), so sending training.steps_per_batch=0 is currently accepted and then crashes /v1/feedback with a server error instead of a clean 4xx validation failure; this can break eval runs by turning every feedback update into a failed request.
Fixed in c7eb678
Comment generated by Claude Code
Replace hasattr/dict pattern with getattr to satisfy ty's type narrowing, and make single-step distill tests explicit about steps_per_batch=1 so they don't break when the default changes.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
vLLM includes the stop token (`<|im_end|>`) in its logprobs output but the tokenizer omits it when re-encoding the response text, causing a consistent off-by-one that made /v1/feedback reject every training request. Trim the logprobs to match the token-ID length.

Also fixes three stale items in the setup-local skill:
- Add flash-attn install step (required for flash_attention_2 default)
- Use Hydra entry point for CLaaS API instead of bare uvicorn
- Add --enable-lora to vLLM start and fix script path

Includes GcpPreferenceVerifier and cloud_gcp preference config.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
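The off-by-one trim described in this commit can be sketched as a small helper. The function name and argument names (`align_logprobs`, `response_ids`) are illustrative, not the PR's actual code; only the trim-to-token-ID-length behavior follows the commit message.

```python
# Hedged sketch of the fix: vLLM's logprobs cover the stop token, but
# re-encoding the response text yields one fewer token, so trim the
# logprobs list down to the token-ID length before training.
def align_logprobs(response_ids: list[int], logprobs: list[float]) -> list[float]:
    """Trim trailing logprobs (e.g. for the <|im_end|> stop token) so the
    list lines up one-to-one with the re-encoded token IDs."""
    if len(logprobs) > len(response_ids):
        logprobs = logprobs[: len(response_ids)]
    if len(logprobs) != len(response_ids):
        raise ValueError(
            f"length mismatch: {len(logprobs)} logprobs "
            f"vs {len(response_ids)} token ids"
        )
    return logprobs

# The typical vLLM case: one extra trailing logprob for the stop token.
assert align_logprobs([1, 2, 3], [-0.1, -0.2, -0.3, -0.05]) == [-0.1, -0.2, -0.3]
```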
@codex review
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: dac6d95c08
ℹ️ About Codex in GitHub
Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review".
If Codex has suggestions, it will comment; otherwise it will react with 👍.
Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".
```python
def __post_init__(self) -> None:
    if self.steps_per_batch < 1:
        msg = f"steps_per_batch must be >= 1, got {self.steps_per_batch}"
        raise ValueError(msg)
```
Validate feedback_repetitions lower bound
TrainingConfig.__post_init__ now enforces steps_per_batch >= 1 but leaves feedback_repetitions unchecked, so 0 or negative values are accepted and later converted into an empty critique string via " ".join([sample.feedback] * feedback_repetitions) in both training engines. In that case distillation silently runs without the user’s feedback signal, which is a correctness regression for misconfigured runs and should be rejected up front the same way invalid step counts are.
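The suggested guard can be sketched by extending the same `__post_init__` pattern shown above. This is a minimal stand-alone sketch: the `feedback_repetitions` default of 1 is an assumption, and only the two fields relevant to the validation are modeled.

```python
# Sketch: extend __post_init__ so non-positive feedback_repetitions is
# rejected up front, instead of silently producing an empty critique
# string via " ".join([feedback] * 0) later in the training engines.
from dataclasses import dataclass

@dataclass
class TrainingConfig:
    steps_per_batch: int = 4
    feedback_repetitions: int = 1  # assumed default; not shown in the diff

    def __post_init__(self) -> None:
        if self.steps_per_batch < 1:
            msg = f"steps_per_batch must be >= 1, got {self.steps_per_batch}"
            raise ValueError(msg)
        if self.feedback_repetitions < 1:
            msg = (
                "feedback_repetitions must be >= 1, "
                f"got {self.feedback_repetitions}"
            )
            raise ValueError(msg)

# The failure mode being prevented: 0 repetitions erase the feedback.
assert " ".join(["too verbose"] * 0) == ""

# With the guard, the misconfiguration fails fast:
try:
    TrainingConfig(feedback_repetitions=0)
except ValueError as e:
    assert "feedback_repetitions" in str(e)
```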
Clarify that setup-local prefers Docker when available, falling back to native setup otherwise. Add local dev artifacts to .gitignore.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Summary

- Move multi-step training knobs (`steps_per_batch`, `feedback_repetitions`) from eval-owned settings into `TrainingConfig`
- Pass the `training` config through `FeedbackItem` in each `/v1/feedback` request
- Surface per-step training metadata (`steps_per_batch_applied`, per-step metrics) and wire eval `sub_step_count` to that metadata

Key Implementation Notes

- New `TrainingConfig` fields: `steps_per_batch`, `feedback_repetitions`
- Eval keeps its own `EvalTrainingConfig` and converts to the runtime `TrainingConfig` in `build_harness_config`
- The Tinker engine refreshes its sampling client between steps via `save_weights_and_get_sampling_client_async`

Validation

- `uv run ruff check claas/ tests/ --fix`
- `uv run pytest tests/ -q -m "not integration"` — 109 passed, 26 skipped, 5 deselected
- `uv run ty check` — findings for third-party packages (`torch`, `tinker`, `transformers`) are expected in this environment

Summary by CodeRabbit
Release Notes
New Features
- Multi-step training via the new `steps_per_batch` parameter
- `feedback_repetitions` configuration option for enhanced training control
- `steps_per_batch_applied` tracks actual steps executed per batch

Documentation
Refactor