fix: propagate resume_from_checkpoint constructor arg and fix resumed progress bar#682
fix: propagate resume_from_checkpoint constructor arg and fix resumed progress bar#682chanind wants to merge 1 commit into
Conversation
… progress bar The LanguageModelSAETrainingRunner constructor accepted a resume_from_checkpoint argument but never assigned it onto the config, so run() always saw cfg.resume_from_checkpoint=None and silently trained from scratch when the kwarg was used. Also initialize the training progress bar at the restored n_training_samples so a resumed run reflects actual progress instead of appearing to start over. The existing test only asserted n_training_samples >= total, which passes whether or not resume happens. Rewrite it to resume with the total set to the checkpoint's progress so a correct resume runs zero further steps, then assert the final weights exactly match the checkpoint. Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Code ReviewThe fix is clean and well-targeted. Here are my observations:
|
Codecov Report✅ All modified and coverable lines are covered by tests. 📢 Thoughts on this report? Let us know! |
Problem
Passing
resume_from_checkpointtoLanguageModelSAETrainingRunnersilently did nothing — training started from scratch.The constructor accepted the argument but never assigned it onto the config:
while
run()only consultsself.cfg.resume_from_checkpoint, which therefore stayedNone. Setting it directly on the config object worked; passing the kwarg (the natural, documented way) did not.A secondary, cosmetic issue made working resumes look broken: the training progress bar was always created at
0, so a resumed run appeared to start over even though the loop correctly skipped already-completed steps.Changes
llm_sae_training_runner.py— propagate the constructor arg onto the config sorun()enters the resume branch.sae_trainer.py— initialize the progress bar at the restoredn_training_samplesso a resumed run shows real progress (e.g.32/64instead of0/64). The training loop itself was always correct.tests/test_llm_sae_training_runner.py— the oldtest_resume_from_checkpointonly assertedn_training_samples >= total, which is true whether or not resume happens. Rewrote it to resume with the total set to the checkpoint's progress, so a correct resume runs zero further steps and the final weights must exactly match the checkpoint. Verified it fails without the constructor fix and passes with it.Note:
MultiSAETrainingRunneralready wires itsresume_from_checkpointarg into its config, so it was unaffected.🤖 Generated with Claude Code