Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions sae_lens/llm_sae_training_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,8 @@ def __init__(
)

self.cfg = cfg
if resume_from_checkpoint is not None:
self.cfg.resume_from_checkpoint = str(resume_from_checkpoint)
# set in cfg.__post_init__; locally bound so type checkers see a str
llm_device = self.cfg.llm_device
assert llm_device is not None
Expand Down
6 changes: 5 additions & 1 deletion sae_lens/training/sae_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,11 @@ def dead_neurons(self) -> torch.Tensor:

def fit(self) -> T_TRAINING_SAE:
self.sae.to(self.cfg.device)
pbar = tqdm(total=self.cfg.total_training_samples, desc="Training SAE")
pbar = tqdm(
total=self.cfg.total_training_samples,
initial=self.n_training_samples,
desc="Training SAE",
)

if self.sae.cfg.normalize_activations == "expected_average_only_in":
self.activation_scaler.estimate_scaling_factor(
Expand Down
58 changes: 24 additions & 34 deletions tests/test_llm_sae_training_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import pytest
import torch
from safetensors.torch import load_file
from transformer_lens import HookedTransformer

from sae_lens import __version__
Expand Down Expand Up @@ -845,50 +846,39 @@ def test_resume_from_checkpoint(
self,
small_training_cfg: LanguageModelSAERunnerConfig[StandardTrainingSAEConfig],
):
"""Test that training can be resumed from a checkpoint."""
# First part: train for a small number of tokens
first_cfg = small_training_cfg
first_cfg.training_tokens = 64 # Half of total

runner1 = LanguageModelSAETrainingRunner(first_cfg)
# First: train and produce a checkpoint.
small_training_cfg.training_tokens = 64
runner1 = LanguageModelSAETrainingRunner(small_training_cfg)
runner1.run()

# Get the checkpoint directory
assert first_cfg.checkpoint_path is not None
checkpoint_dirs = list(Path(first_cfg.checkpoint_path).glob("*"))
assert small_training_cfg.checkpoint_path is not None
checkpoint_dirs = list(Path(small_training_cfg.checkpoint_path).glob("*"))
assert len(checkpoint_dirs) == 1
checkpoint_path = checkpoint_dirs[0]

# Second part: resume training from the checkpoint
second_cfg = small_training_cfg
second_cfg.training_tokens = 128 # Full amount
second_cfg.save_final_checkpoint = True
checkpoint_state = torch.load(checkpoint_path / TRAINER_STATE_FILENAME)
checkpoint_n_samples = checkpoint_state["n_training_samples"]
assert checkpoint_n_samples > 0

# Resume training from checkpoint
# Resume with total training == the checkpoint's progress. A genuine
# resume restores n_training_samples to this value, so the train loop
# runs zero further steps and the loaded checkpoint weights are left
# untouched. If resume were broken (training from scratch), the final
# weights would instead be random-init-then-trained and would differ.
small_training_cfg.training_tokens = checkpoint_n_samples
runner2 = LanguageModelSAETrainingRunner(
second_cfg, resume_from_checkpoint=str(checkpoint_path)
small_training_cfg, resume_from_checkpoint=str(checkpoint_path)
)
# The constructor must propagate the path into the config, otherwise
# run() never enters the resume branch.
assert runner2.cfg.resume_from_checkpoint == str(checkpoint_path)
runner2.run()

# The resumed SAE should have trained on all tokens
# Check if various metrics are reasonable

# Get the final checkpoint and check its metrics
assert second_cfg.checkpoint_path is not None
final_checkpoint_dirs = list(Path(second_cfg.checkpoint_path).glob("*"))
assert (
len(final_checkpoint_dirs) >= 1
) # Should have at least the original checkpoint

# Find the latest checkpoint (could be different from the first if new one was created)
latest_checkpoint = max(final_checkpoint_dirs, key=lambda p: p.stat().st_mtime)

# Load the final state
final_training_state_path = latest_checkpoint / TRAINER_STATE_FILENAME
final_training_state = torch.load(final_training_state_path)

# Ensure the resumed training completed the full training
assert final_training_state["n_training_samples"] >= 128
expected_weights = load_file(str(checkpoint_path / SAE_WEIGHTS_FILENAME))
runner2.sae.process_state_dict_for_loading(expected_weights)
final_state = runner2.sae.state_dict()
for name, tensor in expected_weights.items():
assert torch.allclose(final_state[name], tensor)

def test_activations_store_state_preserved(
self,
Expand Down
Loading