Persistent trainer + CPU cache for local engine #32
Conversation
📝 Walkthrough

The PR introduces CPU-resident caching for LoRA adapter states between distillation training steps. New immutable data structures capture LoRA configuration, optimizer states, and distillation results. The distillation trainer and local engine are updated to optionally load from cache, with thread-safe management and proper device placement for optimizer state tensors.
Sequence Diagram

```mermaid
sequenceDiagram
    participant Client as Client Request
    participant Engine as LocalTrainingEngine
    participant Cache as LoRA Cache
    participant Trainer as DistillationTrainer
    participant Model as LoRA Model
    participant Optimizer as Optimizer State
    Client->>Engine: distill(lora_id, request)
    Engine->>Engine: _ensure_model_loaded()
    Engine->>Model: reload_base_model()
    Engine->>Cache: fetch cached entry for lora_id
    alt Cache Hit
        Cache-->>Engine: LoraCacheEntry
        Engine->>Trainer: distill(payload, cached=entry)
        Trainer->>Model: _load_lora_from_cache(entry)
        Trainer->>Optimizer: _gpu_optimizer_state(cached.optimizer_state_dict)
    else Cache Miss
        Engine->>Trainer: distill(payload, cached=None)
        Trainer->>Model: load LoRA from disk
        Trainer->>Optimizer: initialize optimizer
    end
    Trainer->>Trainer: train()
    Trainer->>Model: get trained model state
    Trainer->>Optimizer: get optimizer state
    Trainer->>Trainer: _build_cache_entry()
    Trainer-->>Engine: DistillStepResult(response, cache_entry)
    Engine->>Cache: update cache[lora_id] = cache_entry
    Engine->>Model: offload_base_model()
    Engine-->>Client: result.response
```
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~25 minutes
🚥 Pre-merge checks: ✅ 2 passed | ❌ 1 failed (1 warning)
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: e6b04caaad
Actionable comments posted: 3
🧹 Nitpick comments (3)
claas/training/distillation.py (2)
140-142: `reload_base_model` assumes `self.base_model` and `self.device` are set.

If called before `load_base_model()`, this will raise `AttributeError`. The engine's `_ensure_model_loaded` gate makes this safe today, but a defensive check or docstring noting the precondition would help.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@claas/training/distillation.py` around lines 140 - 142, The reload_base_model method assumes self.base_model and self.device exist; add a defensive precondition check at the start of reload_base_model that verifies hasattr(self, "base_model") and hasattr(self, "device") (or self.base_model is not None / self.device is not None) and either return early (no-op) or raise a clear RuntimeError indicating load_base_model must be called first; reference the existing _ensure_model_loaded guard in the comment or docstring to make the precondition explicit and update the docstring of reload_base_model to state that load_base_model must be called prior.
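The suggested guard can be sketched with stand-ins for the model and device; the attribute and method names come from the review, while the load logic here is hypothetical:

```python
class DistillationTrainer:
    """Minimal sketch of the suggested precondition guard."""

    def __init__(self) -> None:
        self.base_model: object | None = None
        self.device: str | None = None

    def load_base_model(self) -> None:
        # Stand-ins for the real model load and torch.device selection.
        self.base_model = object()
        self.device = "cuda:0"

    def reload_base_model(self) -> None:
        """Move the base model back to its device.

        Precondition: load_base_model() must have been called first (the
        engine's _ensure_model_loaded guard normally ensures this).
        """
        if self.base_model is None or self.device is None:
            raise RuntimeError(
                "reload_base_model() called before load_base_model()"
            )
        # ...the real implementation would move self.base_model to self.device...
```

Raising a `RuntimeError` (rather than silently no-oping) keeps a broken call order loud instead of letting a later `AttributeError` surface far from the cause.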
43-85: `_cpu_optimizer_state` and `_gpu_optimizer_state` are nearly identical — consider a shared helper.

The two functions differ only in the tensor placement expression (`v.detach().cpu().clone()` vs `v.detach().to(device).clone()`). A single `_remap_optimizer_state(state_dict, device)` would eliminate the duplication.

♻️ Proposed consolidation
```diff
-def _cpu_optimizer_state(state_dict: dict[str, object]) -> dict[str, object]:
-    """Deep-copy optimizer state with all tensors moved to CPU."""
-    result: dict[str, object] = {}
-    for key, value in state_dict.items():
-        if key == "state":
-            param_states = cast("dict[int, dict[str, object]]", value)
-            cpu_states: dict[int, dict[str, object]] = {}
-            for param_id, param_state in param_states.items():
-                cpu_param: dict[str, object] = {}
-                for k, v in param_state.items():
-                    if isinstance(v, torch.Tensor):
-                        cpu_param[k] = v.detach().cpu().clone()
-                    else:
-                        cpu_param[k] = copy.deepcopy(v)
-                cpu_states[param_id] = cpu_param
-            result[key] = cpu_states
-        else:
-            result[key] = copy.deepcopy(value)
-    return result
-
-
-def _gpu_optimizer_state(
-    state_dict: dict[str, object],
-    device: torch.device,
-) -> dict[str, object]:
-    """Deep-copy optimizer state with all tensors moved to a target device."""
-    result: dict[str, object] = {}
-    for key, value in state_dict.items():
-        if key == "state":
-            param_states = cast("dict[int, dict[str, object]]", value)
-            gpu_states: dict[int, dict[str, object]] = {}
-            for param_id, param_state in param_states.items():
-                gpu_param: dict[str, object] = {}
-                for k, v in param_state.items():
-                    if isinstance(v, torch.Tensor):
-                        gpu_param[k] = v.detach().to(device).clone()
-                    else:
-                        gpu_param[k] = copy.deepcopy(v)
-                gpu_states[param_id] = gpu_param
-            result[key] = gpu_states
-        else:
-            result[key] = copy.deepcopy(value)
-    return result
+def _remap_optimizer_state(
+    state_dict: dict[str, object],
+    device: torch.device,
+) -> dict[str, object]:
+    """Deep-copy optimizer state with all tensors moved to *device*."""
+    result: dict[str, object] = {}
+    for key, value in state_dict.items():
+        if key == "state":
+            param_states = cast("dict[int, dict[str, object]]", value)
+            new_states: dict[int, dict[str, object]] = {}
+            for param_id, param_state in param_states.items():
+                new_param: dict[str, object] = {}
+                for k, v in param_state.items():
+                    if isinstance(v, torch.Tensor):
+                        new_param[k] = v.detach().to(device).clone()
+                    else:
+                        new_param[k] = copy.deepcopy(v)
+                new_states[param_id] = new_param
+            result[key] = new_states
+        else:
+            result[key] = copy.deepcopy(value)
+    return result
+
+
+def _cpu_optimizer_state(state_dict: dict[str, object]) -> dict[str, object]:
+    """Deep-copy optimizer state with all tensors moved to CPU."""
+    return _remap_optimizer_state(state_dict, torch.device("cpu"))
+
+
+def _gpu_optimizer_state(
+    state_dict: dict[str, object],
+    device: torch.device,
+) -> dict[str, object]:
+    """Deep-copy optimizer state with all tensors moved to a target device."""
+    return _remap_optimizer_state(state_dict, device)
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@claas/training/distillation.py` around lines 43 - 85, Consolidate duplicated logic in _cpu_optimizer_state and _gpu_optimizer_state by extracting a single helper _remap_optimizer_state(state_dict: dict[str, object], device: torch.device) -> dict[str, object] that walks the state_dict exactly as today, copies non-tensor values with copy.deepcopy, and for torch.Tensor values uses v.detach().to(device).clone(); then implement _cpu_optimizer_state as a thin wrapper that calls _remap_optimizer_state(state_dict, torch.device("cpu")) and _gpu_optimizer_state as a thin wrapper that forwards the provided device to _remap_optimizer_state, keeping the existing type casts ("dict[int, dict[str, object]]") and return shape unchanged.

claas/training/cache.py (1)
12-31: Frozen dataclasses with mutable containers provide only shallow immutability.
`frozen=True` prevents attribute reassignment but callers can still mutate the inner `list` and `dict` objects in place (e.g., `entry.lora_state_dict["new_key"] = ...`). This is fine as-is since `_build_cache_entry` creates defensive copies, but worth noting for future maintainers. If you want deeper guarantees later, consider `tuple[str, ...]` for `target_modules` and `types.MappingProxyType` wrappers for the dicts.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@claas/training/cache.py` around lines 12 - 31, The dataclasses LoraAdapterConfig and LoraCacheEntry are declared frozen but contain mutable containers (list and dict) allowing in-place mutation; update the types and construction to provide stronger immutability by changing LoraAdapterConfig.target_modules from list[str] to tuple[str, ...] and make LoraCacheEntry.lora_state_dict and optimizer_state_dict immutable views (e.g., wrap with types.MappingProxyType) when building entries; also update the builder function (_build_cache_entry or wherever entries are created) to convert incoming lists to tuples and wrap dicts in MappingProxyType so the frozen dataclass truly prevents mutations of internal containers.
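A sketch of the deeper-immutability variant; the field names come from the review, and the other fields of the real dataclasses are omitted here:

```python
from dataclasses import dataclass
from types import MappingProxyType
from typing import Mapping


@dataclass(frozen=True)
class LoraAdapterConfig:
    # tuple instead of list: contents cannot be mutated in place
    target_modules: tuple[str, ...]


@dataclass(frozen=True)
class LoraCacheEntry:
    config: LoraAdapterConfig
    # Read-only mapping views: item assignment raises TypeError
    lora_state_dict: Mapping[str, object]
    optimizer_state_dict: Mapping[str, object]


def build_entry(
    modules: list[str],
    lora_state: dict[str, object],
    opt_state: dict[str, object],
) -> LoraCacheEntry:
    # Copy first, then wrap: callers keep no handle on the inner dicts
    return LoraCacheEntry(
        config=LoraAdapterConfig(target_modules=tuple(modules)),
        lora_state_dict=MappingProxyType(dict(lora_state)),
        optimizer_state_dict=MappingProxyType(dict(opt_state)),
    )
```

`MappingProxyType` is a view, not a copy, which is why the builder copies the dict before wrapping it; otherwise the caller's original dict would remain a mutation backdoor.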
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@claas/training/distillation.py`:
- Around line 298-331: In _build_cache_entry, peft_config.task_type can be a
TaskType enum while LoraAdapterConfig.task_type is annotated as str; fix by
converting peft_config.task_type to its string value when constructing
adapter_config (e.g., use str(peft_config.task_type) or
peft_config.task_type.value) so adapter_config.task_type is a plain string;
update the adapter_config construction in the _build_cache_entry function to
coerce peft_config.task_type accordingly.
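A minimal illustration of the coercion; the `TaskType` below is a stand-in mirroring peft's string-valued enum, not the real class:

```python
from enum import Enum


class TaskType(str, Enum):
    # Stand-in mirroring peft.TaskType, which is likewise a str-valued enum
    CAUSAL_LM = "CAUSAL_LM"
    SEQ_CLS = "SEQ_CLS"


def coerce_task_type(task_type: "TaskType | str") -> str:
    """Return a plain string whether peft hands back the enum or a str."""
    return task_type.value if isinstance(task_type, Enum) else task_type
```

Using `.value` (rather than `str(...)`) avoids ending up with `"TaskType.CAUSAL_LM"` on enums whose `__str__` includes the class name.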
In `@claas/training/engine/local/engine.py`:
- Around line 60-64: There's a race in _ensure_model_loaded: multiple concurrent
distill() callers can pass the "if not self._model_loaded" check and each call
self._trainer.load_base_model via asyncio.to_thread, causing redundant loads;
fix by creating an asyncio.Lock in __init__ (e.g. self._load_lock =
asyncio.Lock()) and wrap the check/load/flag assignment inside "async with
self._load_lock" in _ensure_model_loaded so only one caller runs
asyncio.to_thread(self._trainer.load_base_model) and sets self._model_loaded =
True while holding the lock.
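The locked lazy-load can be sketched as follows; the expensive load is replaced by a counter stand-in so the single-call property is observable:

```python
import asyncio


class LocalTrainingEngine:
    """Sketch of _ensure_model_loaded with an asyncio.Lock guard."""

    def __init__(self) -> None:
        self._model_loaded = False
        self._load_lock = asyncio.Lock()
        self.load_calls = 0  # stand-in for trainer.load_base_model()

    def _load_base_model(self) -> None:
        self.load_calls += 1

    async def _ensure_model_loaded(self) -> None:
        async with self._load_lock:
            # Check under the lock: concurrent callers queue here, and only
            # the first one still observes _model_loaded == False.
            if not self._model_loaded:
                await asyncio.to_thread(self._load_base_model)
                self._model_loaded = True


async def demo() -> int:
    engine = LocalTrainingEngine()
    # Eight concurrent callers race; the lock serializes the check-and-load.
    await asyncio.gather(*(engine._ensure_model_loaded() for _ in range(8)))
    return engine.load_calls
```

After the first load, callers still acquire the lock briefly but skip straight past the flag check, so steady-state overhead is negligible.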
- Around line 81-95: The cache is being stored under the input-derived
resolved_id (from resolve_lora_id(payload.lora_id)) but distill() may return a
different output id (result.response.lora_id) when save_in_place=False; fix by
using the output id as the cache key: after awaiting
asyncio.to_thread(self._trainer.distill, ...), read final_id =
result.response.lora_id (or from result.cache_entry if more appropriate) and
then under self._cache_lock set self._lora_cache[final_id] = result.cache_entry
(optionally remove the old resolved_id if final_id != resolved_id to avoid stale
entries); keep the initial cache read using resolved_id unchanged but always
write using final_id.
ℹ️ Review info
Configuration used: Organization UI
Review profile: CHILL
Plan: Pro
Cache: Disabled due to data retention organization setting
Knowledge base: Disabled due to data retention organization setting
📒 Files selected for processing (6)
- claas/modal/worker.py
- claas/training/cache.py
- claas/training/distillation.py
- claas/training/engine/local/engine.py
- tests/test_distillation_optimizer_state.py
- tests/test_local_training_engine.py
Keep the DistillationTrainer and base model across distill() calls instead of recreating them each time. Cache LoRA adapter weights and optimizer state on CPU between steps so the second call for a given lora_id skips all disk I/O. GPU memory is still fully released after each step via the existing offload_base_model() pattern. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Relocate LoraCacheEntry, LoraAdapterConfig, DistillStepResult, and the cpu/gpu_optimizer_state helpers from the shared training module into claas/training/engine/local/cache.py since they are only used by the local engine's CPU caching path. The Modal worker never uses caching. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Force-pushed from 0747c76 to 166a16d
Addresses review feedback: split dataclass types from helper functions in cache.py for clearer module organization. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
If the model is not a PeftModel, raise an explicit error instead of silently falling back to full state_dict extraction. Also coerces task_type enum to string for type consistency. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Add asyncio.Lock to prevent concurrent model loads. Cache distill results under the output lora_id instead of the input to handle non-in-place saves correctly. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
- Cast peft_config to LoraConfig instead of PeftConfig base class
- Add type: ignore for peft Literal bias type and dict[str, object] subscripts
- Add missing system_prompt field to test payloads
- Update integration test stubs for new distill() signature

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Closing — CPU cache saves ~4% of distill time (3-5s out of 100s) while adding significant complexity. Not worth the tradeoff.
Summary
- `LocalTrainingEngine` now creates the `DistillationTrainer` once in `__init__` and reuses it across `distill()` calls, eliminating redundant base model loads after the first call
- LoRA adapter weights and optimizer state are cached on CPU between steps (`LoraCacheEntry`). Subsequent calls for the same `lora_id` skip all disk I/O (`load_lora`/`load_optimizer_state`)
- New `claas/training/cache.py` with frozen `@dataclass` types (`LoraCacheEntry`, `LoraAdapterConfig`, `DistillStepResult`) — no loose dicts, no optional types where invariants can be enforced
- GPU memory release via the existing `offload_base_model()` + `del model, optimizer` + `cuda.empty_cache()` pattern is unchanged. Cache holds CPU-only tensors by construction

Test plan
- `tests/test_distillation_optimizer_state.py` — `_cpu_optimizer_state`/`_gpu_optimizer_state` round-trip, deep-copy isolation, `LoraCacheEntry` immutability
- `tests/test_local_training_engine.py` — eager trainer creation, one-time `load_base_model`, per-call `reload_base_model`, cache miss→hit, cache eviction on delete, offload error propagation
- `uv run pytest tests/ -v -m "not integration"` (103 passed, 25 skipped for torch)
- `uv run ruff check claas/ tests/`
- `uv run ty check` (only pre-existing unresolved-import errors for torch/peft/transformers)

🤖 Generated with Claude Code