
fix: propagate resume_from_checkpoint constructor arg and fix resumed progress bar#682

Open
chanind wants to merge 1 commit into main from fix/resume-from-checkpoint-constructor

@chanind chanind commented Jun 2, 2026

Collaborator

Problem

Passing resume_from_checkpoint to LanguageModelSAETrainingRunner silently did nothing — training started from scratch.

The constructor accepted the argument but never assigned it onto the config:

def __init__(self, cfg, ..., resume_from_checkpoint=None):
    ...
    self.cfg = cfg   # resume_from_checkpoint dropped here

while run() only consults self.cfg.resume_from_checkpoint, which therefore stayed None. Setting it directly on the config object worked; passing the kwarg (the natural, documented way) did not.
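
The fix described above can be sketched with stand-in classes. This is a minimal illustration, not the actual SAELens code: `RunnerConfig` and `Runner` are hypothetical names, and the `is not None` guard and `str(...)` conversion are assumptions. Only the `resume_from_checkpoint` attribute and the idea of copying the kwarg onto the config come from this PR.

```python
from dataclasses import dataclass
from pathlib import Path
from typing import Optional, Union


@dataclass
class RunnerConfig:
    # Hypothetical stand-in for the real runner config.
    resume_from_checkpoint: Optional[str] = None


class Runner:
    # Hypothetical stand-in for LanguageModelSAETrainingRunner.
    def __init__(
        self,
        cfg: RunnerConfig,
        resume_from_checkpoint: Union[str, Path, None] = None,
    ):
        # Copy the kwarg onto the config so run() can see it.
        # (Assumed shape of the fix; before it, the kwarg was dropped here.)
        if resume_from_checkpoint is not None:
            cfg.resume_from_checkpoint = str(resume_from_checkpoint)
        self.cfg = cfg

    def run(self) -> str:
        # run() only consults the config, never the constructor kwarg.
        if self.cfg.resume_from_checkpoint is not None:
            return "resume"
        return "from scratch"


runner = Runner(RunnerConfig(), resume_from_checkpoint=Path("checkpoints/step_32"))
print(runner.run())  # prints "resume"
```

Without the assignment in `__init__`, the same call would print "from scratch", which is the silent failure this PR addresses.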

A secondary, cosmetic issue made working resumes look broken: the training progress bar was always created at 0, so a resumed run appeared to start over even though the loop correctly skipped already-completed steps.

Changes

  • llm_sae_training_runner.py — propagate the constructor arg onto the config so run() enters the resume branch.
  • sae_trainer.py — initialize the progress bar at the restored n_training_samples so a resumed run shows real progress (e.g. 32/64 instead of 0/64). The training loop itself was always correct.
  • tests/test_llm_sae_training_runner.py — the old test_resume_from_checkpoint only asserted n_training_samples >= total, which is true whether or not resume happens. Rewrote it to resume with the total set to the checkpoint's progress, so a correct resume runs zero further steps and the final weights must exactly match the checkpoint. Verified it fails without the constructor fix and passes with it.
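
The test redesign in the last bullet can be illustrated with a toy loop. This is a sketch under stated assumptions, not the SAELens trainer: `train` is a hypothetical function, a single integer stands in for the SAE weights, and the checkpoint values are made up. It shows why `n_training_samples >= total` cannot tell a resume from a fresh run, while a weight-equality check can.

```python
def train(total_samples, batch_size, start_samples=0, start_weight=0):
    """Toy training loop: each step bumps a single integer 'weight'."""
    n_samples, weight, steps = start_samples, start_weight, 0
    while n_samples < total_samples:
        weight += 1
        n_samples += batch_size
        steps += 1
    return n_samples, weight, steps


# Hypothetical checkpoint produced by an earlier run with a different init.
checkpoint = {"n_training_samples": 32, "weight": 100}
total = checkpoint["n_training_samples"]  # total set to the checkpoint's progress

# Correct resume: starts from the checkpoint, so zero further steps run.
resumed = train(total, 8, checkpoint["n_training_samples"], checkpoint["weight"])

# Broken resume: the checkpoint is ignored and training starts from scratch.
scratch = train(total, 8)

# Old-style assertion: true in both cases, so it proves nothing.
old_check = (resumed[0] >= total, scratch[0] >= total)

# New-style assertion: only the real resume matches the checkpoint weights.
new_check = (resumed[1] == checkpoint["weight"], scratch[1] == checkpoint["weight"])

print(old_check, new_check)  # prints "(True, True) (True, False)"
```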

Note: MultiSAETrainingRunner already wires its resume_from_checkpoint arg into its config, so it was unaffected.

🤖 Generated with Claude Code

… progress bar

The LanguageModelSAETrainingRunner constructor accepted a
resume_from_checkpoint argument but never assigned it onto the config, so
run() always saw cfg.resume_from_checkpoint=None and silently trained from
scratch when the kwarg was used.

Also initialize the training progress bar at the restored n_training_samples
so a resumed run reflects actual progress instead of appearing to start over.

The existing test only asserted n_training_samples >= total, which passes
whether or not resume happens. Rewrite it to resume with the total set to the
checkpoint's progress so a correct resume runs zero further steps, then assert
the final weights exactly match the checkpoint.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>

claude Bot commented Jun 2, 2026


Code Review

The fix is clean and well-targeted. Here are my observations:


sae_lens/llm_sae_training_runner.py

The two-line change is correct and minimal. One minor point: it mutates the caller's cfg object in-place, which is consistent with how MultiSAETrainingRunner handles the same argument. That said, if a caller reuses the same cfg object across multiple runner instances, this mutation could cause surprising behavior. A short comment or docstring note about this side-effect would help, but it's a pre-existing pattern, so not a blocker.
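
The side effect mentioned above can be reproduced with stand-in classes. This is a hypothetical sketch: `RunnerConfig` and `Runner` are illustrative names, and copying the config with `dataclasses.replace` is one possible caller-side workaround, not something this PR does.

```python
from dataclasses import dataclass, replace
from typing import Optional


@dataclass
class RunnerConfig:
    # Hypothetical stand-in for the real runner config.
    resume_from_checkpoint: Optional[str] = None


class Runner:
    # Hypothetical stand-in that mutates the caller's cfg in place.
    def __init__(self, cfg, resume_from_checkpoint=None):
        if resume_from_checkpoint is not None:
            cfg.resume_from_checkpoint = str(resume_from_checkpoint)
        self.cfg = cfg


# Reusing one cfg object: the second runner inherits the first runner's setting.
shared_cfg = RunnerConfig()
first = Runner(shared_cfg, resume_from_checkpoint="checkpoints/a")
second = Runner(shared_cfg)  # no kwarg passed
print(second.cfg.resume_from_checkpoint)  # prints "checkpoints/a"

# Caller-side workaround: hand each runner its own shallow copy.
base_cfg = RunnerConfig()
third = Runner(replace(base_cfg), resume_from_checkpoint="checkpoints/a")
fourth = Runner(replace(base_cfg))
print(fourth.cfg.resume_from_checkpoint)  # prints "None"
```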


sae_lens/training/sae_trainer.py

The initial=self.n_training_samples fix is correct. The ordering in run() ensures load_trainer_state() is called before trainer.fit(), so by the time the progress bar is created, self.n_training_samples already holds the restored value.
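
As an illustration of the `initial=` behavior, here is a small sketch that assumes the progress bar is a `tqdm` bar (the `initial` keyword matches tqdm's constructor, but the PR text does not name the library). The sample counts are hypothetical, and output is sent to an in-memory buffer to keep the example quiet.

```python
import io

from tqdm import tqdm

restored_n_training_samples = 32  # hypothetical value restored from a checkpoint
total_training_samples = 64       # hypothetical training budget

# Start the bar at the restored count instead of 0.
pbar = tqdm(
    total=total_training_samples,
    initial=restored_n_training_samples,
    file=io.StringIO(),
)
start = pbar.n      # counter begins at the restored value
pbar.update(16)     # one hypothetical batch of 16 samples
after = pbar.n
pbar.close()

print(start, after)  # prints "32 48"
```

The bar only reflects the restored value if the trainer state is loaded before the bar is constructed, which is the ordering the review comment above confirms.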


tests/test_llm_sae_training_runner.py

The new test is a substantial improvement in correctness: asserting weight equality, rather than just n_training_samples >= total, actually verifies that resume happened (as opposed to a from-scratch run that happened to reach the same sample count). A few points:

Inline comments — CLAUDE.md says "Do not add doc comments to test functions" and advises writing comments only when the WHY is non-obvious. The first comment block (lines 862–867) explains the subtle test design, which is non-obvious, so it's arguably justified. The second block (lines 872–873) is more questionable — the assertion it precedes (runner2.cfg.resume_from_checkpoint == str(checkpoint_path)) is self-documenting and the comment just restates what the PR description already says. Removing it would tighten the test.

torch.load without weights_only: Line 859 uses torch.load(checkpoint_path / TRAINER_STATE_FILENAME) without weights_only=True. This isn't introduced by this PR, but since this test now touches that line it's worth noting: PyTorch 2.4 and 2.5 emit a FutureWarning when weights_only is left unset, and PyTorch 2.6 changed the default to weights_only=True, so the behavior of this call depends on the installed version. The call site at sae_trainer.py:300 has the same issue. Passing weights_only=True explicitly would be the safe fix (though it requires the saved data to be a plain dict of tensors/scalars, which it appears to be from the save_checkpoint code).

Correctness concern with process_state_dict_for_loading: The test calls runner2.sae.process_state_dict_for_loading(expected_weights) to transform the key names in expected_weights before comparing. For the standard SAE this is a no-op, and the test fixture uses StandardTrainingSAEConfig, so it passes. If this test were ever extended to JumpReLU or another subclass with a non-trivial override, it would still work correctly, because load_weights_from_checkpoint applies the same transformation when loading into the SAE. So the pattern is sound.

Test is clean overall — The "run to zero extra steps" approach is a clever way to avoid non-determinism without using random seeds. Good design.


Summary

The core bug fix is correct and minimal. The test improvement is meaningful. The main actionable feedback is removing the comment block at lines 872–873 (which is redundant with the assertion itself), and considering weights_only=True in torch.load calls, though the latter is pre-existing tech debt rather than a regression in this PR.


codecov Bot commented Jun 2, 2026


Codecov Report

✅ All modified and coverable lines are covered by tests.
