From 287958a5219034b28db6876c9d7bc1abdccf4cee Mon Sep 17 00:00:00 2001 From: Dan Raviv Date: Sat, 6 Jun 2026 15:28:56 -0700 Subject: [PATCH 1/6] tests: Add failing test that LanguageModelSAERunnerConfigDict matches config fields Co-Authored-By: Claude Opus 4.8 (1M context) --- tests/test_util.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/tests/test_util.py b/tests/test_util.py index 5e06583a..d179bff9 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -1,10 +1,11 @@ -from dataclasses import dataclass +from dataclasses import dataclass, fields from pathlib import Path import pytest import torch from transformer_lens import HookedTransformer +from sae_lens.config import LanguageModelSAERunnerConfig from sae_lens.util import ( cosine_similarities, dtype_to_str, @@ -15,6 +16,7 @@ str_to_dtype, temporary_seed, ) +from tests.helpers import LanguageModelSAERunnerConfigDict @pytest.mark.parametrize( @@ -362,3 +364,9 @@ def test_temporary_seed_none_is_noop(): assert not torch.equal(before, after) # And we should still get a valid tensor assert sample.shape == (1,) + + +def test_language_model_sae_runner_config_dict_matches_config_fields(): + config_fields = {f.name for f in fields(LanguageModelSAERunnerConfig)} + dict_fields = set(LanguageModelSAERunnerConfigDict.__annotations__.keys()) + assert config_fields == dict_fields From ec7950cb2c7166c0236cbf9de2e16cfc827e366f Mon Sep 17 00:00:00 2001 From: Dan Raviv Date: Sat, 6 Jun 2026 15:31:58 -0700 Subject: [PATCH 2/6] fix: align LanguageModelSAERunnerConfigDict with LanguageModelSAERunnerConfig, test pases Add four missing and drop normalize_activations, which lives on the SAE config rather than the runner's. Co-Authored-By: Claude Opus 4.8 (1M context) --- tests/helpers.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/helpers.py b/tests/helpers.py index 30e3060e..da3f3b5b 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -63,6 +63,7 @@ class LanguageModelSAERunnerConfigDict(TypedDict, total=False): model_name: str model_class_name: str hook_name: str + hook_eval: str hook_head_index: int | None dataset_path: str dataset_trust_remote_code: bool @@ -76,10 +77,10 @@ class LanguageModelSAERunnerConfigDict(TypedDict, total=False): n_batches_in_buffer: int training_tokens: int store_batch_size_prompts: int - normalize_activations: str seqpos_slice: tuple[int | None, ...] | Sequence[int | None] disable_concat_sequences: bool sequence_separator_token: int | Literal["bos", "eos", "sep"] | None + activations_mixing_fraction: float device: str llm_device: str | None act_store_device: str | None @@ -112,12 +113,14 @@ class LanguageModelSAERunnerConfigDict(TypedDict, total=False): checkpoint_path: str | None save_final_checkpoint: bool output_path: str | None + resume_from_checkpoint: str | None verbose: bool model_kwargs: dict[str, Any] model_from_pretrained_kwargs: dict[str, Any] | None sae_lens_version: str sae_lens_training_version: str exclude_special_tokens: bool | list[int] + n_batches_for_norm_estimate: int # Base TrainingSAEConfig fields + all architecture specific fields From 0b0cded21674c578295fa0f12c20b896f7d5ff5c Mon Sep 17 00:00:00 2001 From: Dan Raviv Date: Sat, 6 Jun 2026 13:31:29 -0700 Subject: [PATCH 3/6] fix: misnamed size --- sae_lens/training/activations_store.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/sae_lens/training/activations_store.py b/sae_lens/training/activations_store.py index 26bf881e..336394f4 100644 --- a/sae_lens/training/activations_store.py +++ b/sae_lens/training/activations_store.py @@ -641,7 +641,7 @@ def get_batch_tokens( @torch.no_grad() def get_activations(self, batch_tokens: torch.Tensor): """ - Returns activations of shape (batches, context, num_layers, d_in) + Returns activations of shape (batch_size, context, num_layers, d_in) d_in may result from a concatenated head dimension. """ @@ -664,9 +664,9 @@ def get_activations(self, batch_tokens: torch.Tensor): :, slice(*self.seqpos_slice) ] - n_batches, n_context = layerwise_activations.shape[:2] + batch_size, n_context = layerwise_activations.shape[:2] - stacked_activations = torch.zeros((n_batches, n_context, self.d_in)) + stacked_activations = torch.zeros((batch_size, n_context, self.d_in)) if self.hook_head_index is not None: stacked_activations[:, :] = layerwise_activations[ @@ -675,13 +675,13 @@ def get_activations(self, batch_tokens: torch.Tensor): elif layerwise_activations.ndim > 3: # if we have a head dimension try: stacked_activations[:, :] = layerwise_activations.view( - n_batches, n_context, -1 + batch_size, n_context, -1 ) except RuntimeError as e: logger.error(f"Error during view operation: {e}") logger.info("Attempting to use reshape instead...") stacked_activations[:, :] = layerwise_activations.reshape( - n_batches, n_context, -1 + batch_size, n_context, -1 ) else: stacked_activations[:, :] = layerwise_activations From 55c6bbf95a208ef5047c564fd654d6f71a15679f Mon Sep 17 00:00:00 2001 From: Dan Raviv Date: Sat, 6 Jun 2026 14:05:50 -0700 Subject: [PATCH 4/6] tests: Add failing ActivationStore test Trying to consume the expected number of tokens from the activations buffer reveals that it is smaller than expected (mixing_buffer raises ValueError on `buffer_size < batch_size`) test_language_model_sae_runner_othellogpt in benchmark/test_language_model_sae_runner.py also fails with the same error. It probably went unnoticed because it's not run in CI. --- tests/training/test_activations_store.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/tests/training/test_activations_store.py b/tests/training/test_activations_store.py index 3b9d2748..d7a51689 100644 --- a/tests/training/test_activations_store.py +++ b/tests/training/test_activations_store.py @@ -186,6 +186,26 @@ def test_activations_store__shapes_look_correct_with_real_models_and_datasets( assert tok_batch.device == store.device +def test_activations_store__can_train_on_entire_activations_buffer( + ts_model: HookedTransformer, +): + n_batches_in_buffer = 4 + store_batch_size_prompts = 8 + context_size = 5 + train_batch_size_tokens = ( + n_batches_in_buffer * store_batch_size_prompts * context_size + ) + cfg = build_runner_cfg( + n_batches_in_buffer=n_batches_in_buffer, + store_batch_size_prompts=store_batch_size_prompts, + context_size=context_size, + train_batch_size_tokens=train_batch_size_tokens, + ) + activation_store = ActivationsStore.from_config(ts_model, cfg) + batch = activation_store.next_batch() + assert batch.shape[0] == train_batch_size_tokens + + def test_activations_store__get_activations_head_hook(ts_model: HookedTransformer): cfg = build_runner_cfg( hook_name="blocks.0.attn.hook_q", From 49333ef96ab8a85df5b1a4ce2b10a9fae9c9e20c Mon Sep 17 00:00:00 2001 From: Dan Raviv Date: Sat, 6 Jun 2026 14:17:48 -0700 Subject: [PATCH 5/6] fix: incorrect buffer sizes calculated in ActivationsStore (but a new test unexpectedly fails!) Fixes failing test in last commit Also fixes test_language_model_sae_runner_othellogpt However, test_activations_next_batch_excludes_special_tokens unexpectedly starts failing! --- sae_lens/training/activations_store.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/sae_lens/training/activations_store.py b/sae_lens/training/activations_store.py index 336394f4..94fac638 100644 --- a/sae_lens/training/activations_store.py +++ b/sae_lens/training/activations_store.py @@ -821,7 +821,9 @@ def get_data_loader( Return an auto-refilling stream of filtered and mixed activations. """ return mixing_buffer( - buffer_size=self.n_batches_in_buffer * self.training_context_size, + buffer_size=self.n_batches_in_buffer + * self.store_batch_size_prompts + * self.training_context_size, batch_size=self.train_batch_size_tokens, activations_loader=self._iterate_filtered_activations(), mix_fraction=self.activations_mixing_fraction, @@ -943,7 +945,9 @@ def get_multi_hook_data_loader( "via from_config_multi_hook" ) return multi_hook_concat_split_iter( - buffer_size=self.n_batches_in_buffer * self.training_context_size, + buffer_size=self.n_batches_in_buffer + * self.store_batch_size_prompts + * self.training_context_size, batch_size=self.train_batch_size_tokens, activations_loader=self._iterate_filtered_multi_hook_activations(), hook_names=list(self._hook_names), From 145cd7bd04efcb1375fc44c8eac368c1b4beae56 Mon Sep 17 00:00:00 2001 From: Dan Raviv Date: Sat, 6 Jun 2026 15:38:53 -0700 Subject: [PATCH 6/6] fix: test_activations_next_batch_excludes_special_tokens no longer mixes buffer, test passes again Mixing caused the test to incorrectly conclude bos is missing when it wasn't --- tests/training/test_activations_store.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/training/test_activations_store.py b/tests/training/test_activations_store.py index d7a51689..8fc0a5ba 100644 --- a/tests/training/test_activations_store.py +++ b/tests/training/test_activations_store.py @@ -738,6 +738,7 @@ def test_activations_next_batch_excludes_special_tokens( store_batch_size_prompts=2, hook_name=hook_name, train_batch_size_tokens=5, + activations_mixing_fraction=0.0, ) cfg = build_runner_cfg( exclude_special_tokens=True, @@ -745,6 +746,7 @@ def test_activations_next_batch_excludes_special_tokens( store_batch_size_prompts=2, hook_name=hook_name, train_batch_size_tokens=5, + activations_mixing_fraction=0.0, ) dataset = Dataset.from_list([{"text": "hello world"}] * 100) _, cache = ts_model.run_with_cache(dataset[0]["text"])