Skip to content

Persistent trainer + CPU cache for local engine#32

Closed
kfallah wants to merge 8 commits intomainfrom
persistent-trainer-cpu-cache
Closed

Persistent trainer + CPU cache for local engine#32
kfallah wants to merge 8 commits intomainfrom
persistent-trainer-cpu-cache

Conversation

@kfallah
Copy link
Copy Markdown
Owner

@kfallah kfallah commented Feb 23, 2026

Summary

  • Persistent trainer: LocalTrainingEngine now creates the DistillationTrainer once in __init__ and reuses it across distill() calls, eliminating redundant base model loads after the first call
  • CPU LoRA cache: After each training step, LoRA adapter weights and optimizer state are snapshotted to CPU memory (LoraCacheEntry). Subsequent calls for the same lora_id skip all disk I/O (load_lora / load_optimizer_state)
  • Typed cache structures: New claas/training/cache.py with frozen @dataclass types (LoraCacheEntry, LoraAdapterConfig, DistillStepResult) — no loose dicts, no optional types where invariants can be enforced
  • GPU memory guarantee preserved: The existing offload_base_model() + del model, optimizer + cuda.empty_cache() pattern is unchanged. Cache holds CPU-only tensors by construction

Test plan

  • tests/test_distillation_optimizer_state.py_cpu_optimizer_state / _gpu_optimizer_state round-trip, deep-copy isolation, LoraCacheEntry immutability
  • tests/test_local_training_engine.py — eager trainer creation, one-time load_base_model, per-call reload_base_model, cache miss→hit, cache eviction on delete, offload error propagation
  • Full test suite passes: uv run pytest tests/ -v -m "not integration" (103 passed, 25 skipped for torch)
  • Lint clean: uv run ruff check claas/ tests/
  • Type check: uv run ty check (only pre-existing unresolved-import errors for torch/peft/transformers)

🤖 Generated with Claude Code

Summary by CodeRabbit

Release Notes

  • New Features

    • Introduced in-memory caching mechanism for LoRA adapter state and optimizer configuration between training steps
    • Added optimizer state serialization and device placement utilities for improved GPU/CPU memory management
  • Refactor

    • Enhanced training engine to support cache-driven training workflows, reducing redundant model loading and state reconstruction

@coderabbitai
Copy link
Copy Markdown

coderabbitai bot commented Feb 23, 2026

Important

Review skipped

Auto incremental reviews are disabled on this repository.

Please check the settings in the CodeRabbit UI or the .coderabbit.yaml file in this repository. To trigger a single review, invoke the @coderabbitai review command.

⚙️ Run configuration

Configuration used: Organization UI

Review profile: CHILL

Plan: Pro

Run ID: ebf38b21-a4ff-4c03-9a46-acbdfa4e1e91

You can disable this status message by setting the reviews.review_status to false in the CodeRabbit configuration file.

Use the checkbox below for a quick retry:

  • 🔍 Trigger review
📝 Walkthrough

Walkthrough

The PR introduces CPU-resident caching for LoRA adapter states between distillation training steps. New immutable data structures capture LoRA configuration, optimizer states, and distillation results. The distillation trainer and local engine are updated to optionally load from cache with thread-safe management and proper device placement for optimizer state tensors.

Changes

Cohort / File(s) Summary
Cache Infrastructure
claas/training/cache.py
Added three new frozen dataclasses: LoraAdapterConfig (LoRA hyperparameters), LoraCacheEntry (LoRA and optimizer state snapshot), and DistillStepResult (response with cache entry) to support immutable, serializable caching structures.
Distillation Trainer
claas/training/distillation.py
Enhanced distill method to accept optional cached parameter; added optimizer state utilities (_cpu_optimizer_state, _gpu_optimizer_state); introduced LoRA loading from cache (_load_lora_from_cache) and cache entry building (_build_cache_entry); changed return type to DistillStepResult; added public reload_base_model method.
Local Training Engine
claas/training/engine/local/engine.py
Implemented thread-safe LoRA caching with _lora_cache and _cache_lock; updated distill flow to reload base model, resolve LoRA ID, fetch/update cache under lock, and return response; modified delete_lora to evict cache entries; added one-time model loading via _ensure_model_loaded.
Worker Integration
claas/modal/worker.py
Modified distill return value to extract response attribute from DistillationResponse, now returning inner payload instead of wrapper.
Test Coverage
tests/test_distillation_optimizer_state.py, tests/test_local_training_engine.py
Added tests for optimizer state serialization/deserialization, cache immutability, and caching behavior including cache hit/miss, eviction via delete_lora, and offload error propagation.

Sequence Diagram

sequenceDiagram
    participant Client as Client Request
    participant Engine as LocalTrainingEngine
    participant Cache as LoRA Cache
    participant Trainer as DistillationTrainer
    participant Model as LoRA Model
    participant Optimizer as Optimizer State

    Client->>Engine: distill(lora_id, request)
    Engine->>Engine: _ensure_model_loaded()
    Engine->>Model: reload_base_model()
    Engine->>Cache: fetch cached entry for lora_id
    
    alt Cache Hit
        Cache-->>Engine: LoraCacheEntry
        Engine->>Trainer: distill(payload, cached=entry)
        Trainer->>Model: _load_lora_from_cache(entry)
        Trainer->>Optimizer: _gpu_optimizer_state(cached.optimizer_state_dict)
    else Cache Miss
        Engine->>Trainer: distill(payload, cached=None)
        Trainer->>Model: load LoRA from disk
        Trainer->>Optimizer: initialize optimizer
    end
    
    Trainer->>Trainer: train()
    Trainer->>Model: get trained model state
    Trainer->>Optimizer: get optimizer state
    Trainer->>Trainer: _build_cache_entry()
    Trainer-->>Engine: DistillStepResult(response, cache_entry)
    Engine->>Cache: update cache[lora_id] = cache_entry
    Engine->>Model: offload_base_model()
    Engine-->>Client: result.response
Loading

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~25 minutes

Poem

🐰 Hop, hop! Cache entries now stay,
LoRA states snapped away,
No disk reads slow the train,
Optimizer states ordained,
Thread-safe caches light the way! 🌟

🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)

Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 72.97% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (2 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title check ✅ Passed The title clearly summarizes the main architectural changes: adding a persistent trainer instance and CPU-based LoRA caching to the local engine, which are the core objectives of this PR.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing Touches
🧪 Generate unit tests (beta)
  • Create PR with unit tests
  • Post copyable unit tests in a comment
  • Commit unit tests in branch persistent-trainer-cpu-cache

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

Copy link
Copy Markdown

@chatgpt-codex-connector chatgpt-codex-connector bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Codex Review

Here are some automated review suggestions for this pull request.

Reviewed commit: e6b04caaad

ℹ️ About Codex in GitHub

Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you

  • Open a pull request for review
  • Mark a draft as ready
  • Comment "@codex review".

If Codex has suggestions, it will comment; otherwise it will react with 👍.

Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".

Copy link
Copy Markdown

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 3

🧹 Nitpick comments (3)
claas/training/distillation.py (2)

140-142: reload_base_model assumes self.base_model and self.device are set.

If called before load_base_model(), this will raise AttributeError. The engine's _ensure_model_loaded gate makes this safe today, but a defensive check or docstring noting the precondition would help.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@claas/training/distillation.py` around lines 140 - 142, The reload_base_model
method assumes self.base_model and self.device exist; add a defensive
precondition check at the start of reload_base_model that verifies hasattr(self,
"base_model") and hasattr(self, "device") (or self.base_model is not None /
self.device is not None) and either return early (no-op) or raise a clear
RuntimeError indicating load_base_model must be called first; reference the
existing _ensure_model_loaded guard in the comment or docstring to make the
precondition explicit and update the docstring of reload_base_model to state
that load_base_model must be called prior.

43-85: _cpu_optimizer_state and _gpu_optimizer_state are nearly identical — consider a shared helper.

The two functions differ only in the tensor placement expression (v.detach().cpu().clone() vs v.detach().to(device).clone()). A single _remap_optimizer_state(state_dict, device) would eliminate the duplication.

♻️ Proposed consolidation
-def _cpu_optimizer_state(state_dict: dict[str, object]) -> dict[str, object]:
-    """Deep-copy optimizer state with all tensors moved to CPU."""
-    result: dict[str, object] = {}
-    for key, value in state_dict.items():
-        if key == "state":
-            param_states = cast("dict[int, dict[str, object]]", value)
-            cpu_states: dict[int, dict[str, object]] = {}
-            for param_id, param_state in param_states.items():
-                cpu_param: dict[str, object] = {}
-                for k, v in param_state.items():
-                    if isinstance(v, torch.Tensor):
-                        cpu_param[k] = v.detach().cpu().clone()
-                    else:
-                        cpu_param[k] = copy.deepcopy(v)
-                cpu_states[param_id] = cpu_param
-            result[key] = cpu_states
-        else:
-            result[key] = copy.deepcopy(value)
-    return result
-
-
-def _gpu_optimizer_state(
-    state_dict: dict[str, object],
-    device: torch.device,
-) -> dict[str, object]:
-    """Deep-copy optimizer state with all tensors moved to a target device."""
-    result: dict[str, object] = {}
-    for key, value in state_dict.items():
-        if key == "state":
-            param_states = cast("dict[int, dict[str, object]]", value)
-            gpu_states: dict[int, dict[str, object]] = {}
-            for param_id, param_state in param_states.items():
-                gpu_param: dict[str, object] = {}
-                for k, v in param_state.items():
-                    if isinstance(v, torch.Tensor):
-                        gpu_param[k] = v.detach().to(device).clone()
-                    else:
-                        gpu_param[k] = copy.deepcopy(v)
-                gpu_states[param_id] = gpu_param
-            result[key] = gpu_states
-        else:
-            result[key] = copy.deepcopy(value)
-    return result
+def _remap_optimizer_state(
+    state_dict: dict[str, object],
+    device: torch.device,
+) -> dict[str, object]:
+    """Deep-copy optimizer state with all tensors moved to *device*."""
+    result: dict[str, object] = {}
+    for key, value in state_dict.items():
+        if key == "state":
+            param_states = cast("dict[int, dict[str, object]]", value)
+            new_states: dict[int, dict[str, object]] = {}
+            for param_id, param_state in param_states.items():
+                new_param: dict[str, object] = {}
+                for k, v in param_state.items():
+                    if isinstance(v, torch.Tensor):
+                        new_param[k] = v.detach().to(device).clone()
+                    else:
+                        new_param[k] = copy.deepcopy(v)
+                new_states[param_id] = new_param
+            result[key] = new_states
+        else:
+            result[key] = copy.deepcopy(value)
+    return result
+
+
+def _cpu_optimizer_state(state_dict: dict[str, object]) -> dict[str, object]:
+    """Deep-copy optimizer state with all tensors moved to CPU."""
+    return _remap_optimizer_state(state_dict, torch.device("cpu"))
+
+
+def _gpu_optimizer_state(
+    state_dict: dict[str, object],
+    device: torch.device,
+) -> dict[str, object]:
+    """Deep-copy optimizer state with all tensors moved to a target device."""
+    return _remap_optimizer_state(state_dict, device)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@claas/training/distillation.py` around lines 43 - 85, Consolidate duplicated
logic in _cpu_optimizer_state and _gpu_optimizer_state by extracting a single
helper _remap_optimizer_state(state_dict: dict[str, object], device:
torch.device) -> dict[str, object] that walks the state_dict exactly as today,
copies non-tensor values with copy.deepcopy, and for torch.Tensor values uses
v.detach().to(device).clone(); then implement _cpu_optimizer_state as a thin
wrapper that calls _remap_optimizer_state(state_dict, torch.device("cpu")) and
_gpu_optimizer_state as a thin wrapper that forwards the provided device to
_remap_optimizer_state, keeping the existing type casts ("dict[int, dict[str,
object]]") and return shape unchanged.
claas/training/cache.py (1)

12-31: Frozen dataclasses with mutable containers provide only shallow immutability.

frozen=True prevents attribute reassignment but callers can still mutate the inner list and dict objects in place (e.g., entry.lora_state_dict["new_key"] = ...). This is fine as-is since _build_cache_entry creates defensive copies, but worth noting for future maintainers.

If you want deeper guarantees later, consider tuple[str, ...] for target_modules and types.MappingProxyType wrappers for the dicts.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@claas/training/cache.py` around lines 12 - 31, The dataclasses
LoraAdapterConfig and LoraCacheEntry are declared frozen but contain mutable
containers (list and dict) allowing in-place mutation; update the types and
construction to provide stronger immutability by changing
LoraAdapterConfig.target_modules from list[str] to tuple[str, ...] and make
LoraCacheEntry.lora_state_dict and optimizer_state_dict immutable views (e.g.,
wrap with types.MappingProxyType) when building entries; also update the builder
function (_build_cache_entry or wherever entries are created) to convert
incoming lists to tuples and wrap dicts in MappingProxyType so the frozen
dataclass truly prevents mutations of internal containers.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@claas/training/distillation.py`:
- Around line 298-331: In _build_cache_entry, peft_config.task_type can be a
TaskType enum while LoraAdapterConfig.task_type is annotated as str; fix by
converting peft_config.task_type to its string value when constructing
adapter_config (e.g., use str(peft_config.task_type) or
peft_config.task_type.value) so adapter_config.task_type is a plain string;
update the adapter_config construction in the _build_cache_entry function to
coerce peft_config.task_type accordingly.

In `@claas/training/engine/local/engine.py`:
- Around line 60-64: There's a race in _ensure_model_loaded: multiple concurrent
distill() callers can pass the "if not self._model_loaded" check and each call
self._trainer.load_base_model via asyncio.to_thread, causing redundant loads;
fix by creating an asyncio.Lock in __init__ (e.g. self._load_lock =
asyncio.Lock()) and wrap the check/load/flag assignment inside "async with
self._load_lock" in _ensure_model_loaded so only one caller runs
asyncio.to_thread(self._trainer.load_base_model) and sets self._model_loaded =
True while holding the lock.
- Around line 81-95: The cache is being stored under the input-derived
resolved_id (from resolve_lora_id(payload.lora_id)) but distill() may return a
different output id (result.response.lora_id) when save_in_place=False; fix by
using the output id as the cache key: after awaiting
asyncio.to_thread(self._trainer.distill, ...), read final_id =
result.response.lora_id (or from result.cache_entry if more appropriate) and
then under self._cache_lock set self._lora_cache[final_id] = result.cache_entry
(optionally remove the old resolved_id if final_id != resolved_id to avoid stale
entries); keep the initial cache read using resolved_id unchanged but always
write using final_id.

---

Nitpick comments:
In `@claas/training/cache.py`:
- Around line 12-31: The dataclasses LoraAdapterConfig and LoraCacheEntry are
declared frozen but contain mutable containers (list and dict) allowing in-place
mutation; update the types and construction to provide stronger immutability by
changing LoraAdapterConfig.target_modules from list[str] to tuple[str, ...] and
make LoraCacheEntry.lora_state_dict and optimizer_state_dict immutable views
(e.g., wrap with types.MappingProxyType) when building entries; also update the
builder function (_build_cache_entry or wherever entries are created) to convert
incoming lists to tuples and wrap dicts in MappingProxyType so the frozen
dataclass truly prevents mutations of internal containers.

In `@claas/training/distillation.py`:
- Around line 140-142: The reload_base_model method assumes self.base_model and
self.device exist; add a defensive precondition check at the start of
reload_base_model that verifies hasattr(self, "base_model") and hasattr(self,
"device") (or self.base_model is not None / self.device is not None) and either
return early (no-op) or raise a clear RuntimeError indicating load_base_model
must be called first; reference the existing _ensure_model_loaded guard in the
comment or docstring to make the precondition explicit and update the docstring
of reload_base_model to state that load_base_model must be called prior.
- Around line 43-85: Consolidate duplicated logic in _cpu_optimizer_state and
_gpu_optimizer_state by extracting a single helper
_remap_optimizer_state(state_dict: dict[str, object], device: torch.device) ->
dict[str, object] that walks the state_dict exactly as today, copies non-tensor
values with copy.deepcopy, and for torch.Tensor values uses
v.detach().to(device).clone(); then implement _cpu_optimizer_state as a thin
wrapper that calls _remap_optimizer_state(state_dict, torch.device("cpu")) and
_gpu_optimizer_state as a thin wrapper that forwards the provided device to
_remap_optimizer_state, keeping the existing type casts ("dict[int, dict[str,
object]]") and return shape unchanged.

ℹ️ Review info

Configuration used: Organization UI

Review profile: CHILL

Plan: Pro

Cache: Disabled due to data retention organization setting

Knowledge base: Disabled due to data retention organization setting

📥 Commits

Reviewing files that changed from the base of the PR and between dedadf6 and e6b04ca.

📒 Files selected for processing (6)
  • claas/modal/worker.py
  • claas/training/cache.py
  • claas/training/distillation.py
  • claas/training/engine/local/engine.py
  • tests/test_distillation_optimizer_state.py
  • tests/test_local_training_engine.py

Kion and others added 2 commits March 6, 2026 16:20
Keep the DistillationTrainer and base model across distill() calls instead
of recreating them each time. Cache LoRA adapter weights and optimizer state
on CPU between steps so the second call for a given lora_id skips all disk
I/O. GPU memory is still fully released after each step via the existing
offload_base_model() pattern.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Relocate LoraCacheEntry, LoraAdapterConfig, DistillStepResult, and the
cpu/gpu_optimizer_state helpers from the shared training module into
claas/training/engine/local/cache.py since they are only used by the
local engine's CPU caching path. The Modal worker never uses caching.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
@kfallah kfallah force-pushed the persistent-trainer-cpu-cache branch from 0747c76 to 166a16d Compare March 7, 2026 00:22
Kion and others added 6 commits March 6, 2026 16:44
Addresses review feedback: split dataclass types from helper
functions in cache.py for clearer module organization.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
If the model is not a PeftModel, raise an explicit error instead
of silently falling back to full state_dict extraction.

Also coerces task_type enum to string for type consistency.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Add asyncio.Lock to prevent concurrent model loads. Cache
distill results under the output lora_id instead of the input
to handle non-in-place saves correctly.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
- Cast peft_config to LoraConfig instead of PeftConfig base class
- Add type: ignore for peft Literal bias type and dict[str, object] subscripts
- Add missing system_prompt field to test payloads
- Update integration test stubs for new distill() signature

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
@kfallah
Copy link
Copy Markdown
Owner Author

kfallah commented Mar 7, 2026

Closing — CPU cache saves ~4% of distill time (3-5s out of 100s) while adding significant complexity. Not worth the tradeoff.

@kfallah kfallah closed this Mar 7, 2026
@kfallah kfallah deleted the persistent-trainer-cpu-cache branch March 7, 2026 01:52
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant