diff --git a/sae_lens/training/activations_store.py b/sae_lens/training/activations_store.py index 26bf881e4..94fac638a 100644 --- a/sae_lens/training/activations_store.py +++ b/sae_lens/training/activations_store.py @@ -641,7 +641,7 @@ def get_batch_tokens( @torch.no_grad() def get_activations(self, batch_tokens: torch.Tensor): """ - Returns activations of shape (batches, context, num_layers, d_in) + Returns activations of shape (batch_size, context, num_layers, d_in) d_in may result from a concatenated head dimension. """ @@ -664,9 +664,9 @@ def get_activations(self, batch_tokens: torch.Tensor): :, slice(*self.seqpos_slice) ] - n_batches, n_context = layerwise_activations.shape[:2] + batch_size, n_context = layerwise_activations.shape[:2] - stacked_activations = torch.zeros((n_batches, n_context, self.d_in)) + stacked_activations = torch.zeros((batch_size, n_context, self.d_in)) if self.hook_head_index is not None: stacked_activations[:, :] = layerwise_activations[ @@ -675,13 +675,13 @@ def get_activations(self, batch_tokens: torch.Tensor): elif layerwise_activations.ndim > 3: # if we have a head dimension try: stacked_activations[:, :] = layerwise_activations.view( - n_batches, n_context, -1 + batch_size, n_context, -1 ) except RuntimeError as e: logger.error(f"Error during view operation: {e}") logger.info("Attempting to use reshape instead...") stacked_activations[:, :] = layerwise_activations.reshape( - n_batches, n_context, -1 + batch_size, n_context, -1 ) else: stacked_activations[:, :] = layerwise_activations @@ -821,7 +821,9 @@ def get_data_loader( Return an auto-refilling stream of filtered and mixed activations. """ return mixing_buffer( - buffer_size=self.n_batches_in_buffer * self.training_context_size, + buffer_size=self.n_batches_in_buffer + * self.store_batch_size_prompts + * self.training_context_size, batch_size=self.train_batch_size_tokens, activations_loader=self._iterate_filtered_activations(), mix_fraction=self.activations_mixing_fraction, @@ -943,7 +945,9 @@ def get_multi_hook_data_loader( "via from_config_multi_hook" ) return multi_hook_concat_split_iter( - buffer_size=self.n_batches_in_buffer * self.training_context_size, + buffer_size=self.n_batches_in_buffer + * self.store_batch_size_prompts + * self.training_context_size, batch_size=self.train_batch_size_tokens, activations_loader=self._iterate_filtered_multi_hook_activations(), hook_names=list(self._hook_names), diff --git a/tests/helpers.py b/tests/helpers.py index 30e3060e9..da3f3b5be 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -63,6 +63,7 @@ class LanguageModelSAERunnerConfigDict(TypedDict, total=False): model_name: str model_class_name: str hook_name: str + hook_eval: str hook_head_index: int | None dataset_path: str dataset_trust_remote_code: bool @@ -76,10 +77,10 @@ class LanguageModelSAERunnerConfigDict(TypedDict, total=False): n_batches_in_buffer: int training_tokens: int store_batch_size_prompts: int - normalize_activations: str seqpos_slice: tuple[int | None, ...] | Sequence[int | None] disable_concat_sequences: bool sequence_separator_token: int | Literal["bos", "eos", "sep"] | None + activations_mixing_fraction: float device: str llm_device: str | None act_store_device: str | None @@ -112,12 +113,14 @@ class LanguageModelSAERunnerConfigDict(TypedDict, total=False): checkpoint_path: str | None save_final_checkpoint: bool output_path: str | None + resume_from_checkpoint: str | None verbose: bool model_kwargs: dict[str, Any] model_from_pretrained_kwargs: dict[str, Any] | None sae_lens_version: str sae_lens_training_version: str exclude_special_tokens: bool | list[int] + n_batches_for_norm_estimate: int # Base TrainingSAEConfig fields + all architecture specific fields diff --git a/tests/test_util.py b/tests/test_util.py index 5e06583a3..d179bff9e 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -1,10 +1,11 @@ -from dataclasses import dataclass +from dataclasses import dataclass, fields from pathlib import Path import pytest import torch from transformer_lens import HookedTransformer +from sae_lens.config import LanguageModelSAERunnerConfig from sae_lens.util import ( cosine_similarities, dtype_to_str, @@ -15,6 +16,7 @@ str_to_dtype, temporary_seed, ) +from tests.helpers import LanguageModelSAERunnerConfigDict @pytest.mark.parametrize( @@ -362,3 +364,9 @@ def test_temporary_seed_none_is_noop(): assert not torch.equal(before, after) # And we should still get a valid tensor assert sample.shape == (1,) + + +def test_language_model_sae_runner_config_dict_matches_config_fields(): + config_fields = {f.name for f in fields(LanguageModelSAERunnerConfig)} + dict_fields = set(LanguageModelSAERunnerConfigDict.__annotations__.keys()) + assert config_fields == dict_fields diff --git a/tests/training/test_activations_store.py b/tests/training/test_activations_store.py index 3b9d27482..8fc0a5bae 100644 --- a/tests/training/test_activations_store.py +++ b/tests/training/test_activations_store.py @@ -186,6 +186,26 @@ def test_activations_store__shapes_look_correct_with_real_models_and_datasets( assert tok_batch.device == store.device +def test_activations_store__can_train_on_entire_activations_buffer( + ts_model: HookedTransformer, +): + n_batches_in_buffer = 4 + store_batch_size_prompts = 8 + context_size = 5 + train_batch_size_tokens = ( + n_batches_in_buffer * store_batch_size_prompts * context_size + ) + cfg = build_runner_cfg( + n_batches_in_buffer=n_batches_in_buffer, + store_batch_size_prompts=store_batch_size_prompts, + context_size=context_size, + train_batch_size_tokens=train_batch_size_tokens, + ) + activation_store = ActivationsStore.from_config(ts_model, cfg) + batch = activation_store.next_batch() + assert batch.shape[0] == train_batch_size_tokens + + def test_activations_store__get_activations_head_hook(ts_model: HookedTransformer): cfg = build_runner_cfg( hook_name="blocks.0.attn.hook_q", @@ -718,6 +738,7 @@ def test_activations_next_batch_excludes_special_tokens( store_batch_size_prompts=2, hook_name=hook_name, train_batch_size_tokens=5, + activations_mixing_fraction=0.0, ) cfg = build_runner_cfg( exclude_special_tokens=True, @@ -725,6 +746,7 @@ def test_activations_next_batch_excludes_special_tokens( store_batch_size_prompts=2, hook_name=hook_name, train_batch_size_tokens=5, + activations_mixing_fraction=0.0, ) dataset = Dataset.from_list([{"text": "hello world"}] * 100) _, cache = ts_model.run_with_cache(dataset[0]["text"])