From 230f314749bb0be281739b906e635c29e01697c6 Mon Sep 17 00:00:00 2001 From: David Chanin Date: Tue, 2 Jun 2026 18:51:31 +0100 Subject: [PATCH] fix: propagate resume_from_checkpoint constructor arg and fix resumed progress bar The LanguageModelSAETrainingRunner constructor accepted a resume_from_checkpoint argument but never assigned it onto the config, so run() always saw cfg.resume_from_checkpoint=None and silently trained from scratch when the kwarg was used. Also initialize the training progress bar at the restored n_training_samples so a resumed run reflects actual progress instead of appearing to start over. The existing test only asserted n_training_samples >= total, which passes whether or not resume happens. Rewrite it to resume with the total set to the checkpoint's progress so a correct resume runs zero further steps, then assert the final weights exactly match the checkpoint. Co-Authored-By: Claude Opus 4.8 (1M context) --- sae_lens/llm_sae_training_runner.py | 2 + sae_lens/training/sae_trainer.py | 6 ++- tests/test_llm_sae_training_runner.py | 58 +++++++++++---------------- 3 files changed, 31 insertions(+), 35 deletions(-) diff --git a/sae_lens/llm_sae_training_runner.py b/sae_lens/llm_sae_training_runner.py index 738931165..f5b5326a7 100644 --- a/sae_lens/llm_sae_training_runner.py +++ b/sae_lens/llm_sae_training_runner.py @@ -128,6 +128,8 @@ def __init__( ) self.cfg = cfg + if resume_from_checkpoint is not None: + self.cfg.resume_from_checkpoint = str(resume_from_checkpoint) # set in cfg.__post_init__; locally bound so type checkers see a str llm_device = self.cfg.llm_device assert llm_device is not None diff --git a/sae_lens/training/sae_trainer.py b/sae_lens/training/sae_trainer.py index 72cb9da02..6ed1fe59e 100644 --- a/sae_lens/training/sae_trainer.py +++ b/sae_lens/training/sae_trainer.py @@ -161,7 +161,11 @@ def dead_neurons(self) -> torch.Tensor: def fit(self) -> T_TRAINING_SAE: self.sae.to(self.cfg.device) - pbar = tqdm(total=self.cfg.total_training_samples, desc="Training SAE") + pbar = tqdm( + total=self.cfg.total_training_samples, + initial=self.n_training_samples, + desc="Training SAE", + ) if self.sae.cfg.normalize_activations == "expected_average_only_in": self.activation_scaler.estimate_scaling_factor( diff --git a/tests/test_llm_sae_training_runner.py b/tests/test_llm_sae_training_runner.py index 005454262..1a2f6fe16 100644 --- a/tests/test_llm_sae_training_runner.py +++ b/tests/test_llm_sae_training_runner.py @@ -4,6 +4,7 @@ import pytest import torch +from safetensors.torch import load_file from transformer_lens import HookedTransformer from sae_lens import __version__ @@ -845,50 +846,39 @@ def test_resume_from_checkpoint( self, small_training_cfg: LanguageModelSAERunnerConfig[StandardTrainingSAEConfig], ): - """Test that training can be resumed from a checkpoint.""" - # First part: train for a small number of tokens - first_cfg = small_training_cfg - first_cfg.training_tokens = 64 # Half of total - - runner1 = LanguageModelSAETrainingRunner(first_cfg) + # First: train and produce a checkpoint. + small_training_cfg.training_tokens = 64 + runner1 = LanguageModelSAETrainingRunner(small_training_cfg) runner1.run() - # Get the checkpoint directory - assert first_cfg.checkpoint_path is not None - checkpoint_dirs = list(Path(first_cfg.checkpoint_path).glob("*")) + assert small_training_cfg.checkpoint_path is not None + checkpoint_dirs = list(Path(small_training_cfg.checkpoint_path).glob("*")) assert len(checkpoint_dirs) == 1 checkpoint_path = checkpoint_dirs[0] - # Second part: resume training from the checkpoint - second_cfg = small_training_cfg - second_cfg.training_tokens = 128 # Full amount - second_cfg.save_final_checkpoint = True + checkpoint_state = torch.load(checkpoint_path / TRAINER_STATE_FILENAME) + checkpoint_n_samples = checkpoint_state["n_training_samples"] + assert checkpoint_n_samples > 0 - # Resume training from checkpoint + # Resume with total training == the checkpoint's progress. A genuine + # resume restores n_training_samples to this value, so the train loop + # runs zero further steps and the loaded checkpoint weights are left + # untouched. If resume were broken (training from scratch), the final + # weights would instead be random-init-then-trained and would differ. + small_training_cfg.training_tokens = checkpoint_n_samples runner2 = LanguageModelSAETrainingRunner( - second_cfg, resume_from_checkpoint=str(checkpoint_path) + small_training_cfg, resume_from_checkpoint=str(checkpoint_path) ) + # The constructor must propagate the path into the config, otherwise + # run() never enters the resume branch. + assert runner2.cfg.resume_from_checkpoint == str(checkpoint_path) runner2.run() - # The resumed SAE should have trained on all tokens - # Check if various metrics are reasonable - - # Get the final checkpoint and check its metrics - assert second_cfg.checkpoint_path is not None - final_checkpoint_dirs = list(Path(second_cfg.checkpoint_path).glob("*")) - assert ( - len(final_checkpoint_dirs) >= 1 - ) # Should have at least the original checkpoint - - # Find the latest checkpoint (could be different from the first if new one was created) - latest_checkpoint = max(final_checkpoint_dirs, key=lambda p: p.stat().st_mtime) - - # Load the final state - final_training_state_path = latest_checkpoint / TRAINER_STATE_FILENAME - final_training_state = torch.load(final_training_state_path) - - # Ensure the resumed training completed the full training - assert final_training_state["n_training_samples"] >= 128 + expected_weights = load_file(str(checkpoint_path / SAE_WEIGHTS_FILENAME)) + runner2.sae.process_state_dict_for_loading(expected_weights) + final_state = runner2.sae.state_dict() + for name, tensor in expected_weights.items(): + assert torch.allclose(final_state[name], tensor) def test_activations_store_state_preserved( self,