From 53f937ec5e467bcc13bc2db59822f3467e50228d Mon Sep 17 00:00:00 2001 From: Michael Behr Date: Sun, 30 Mar 2025 18:39:11 -0400 Subject: [PATCH 01/61] test: add unit tests for ActivationsStore multi-layer support MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add unit tests for implementing the CrossCoder SAE's ability to collect activations from the same hook type at multiple layers. These tests verify: - Initialization with multiple layers - Activation collection from all layers - Proper buffer and batch handling - Layer-specific normalization - Backward compatibility 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude --- .../test_activations_store_multilayer.py | 231 ++++++++++++++++++ 1 file changed, 231 insertions(+) create mode 100644 tests/training/test_activations_store_multilayer.py diff --git a/tests/training/test_activations_store_multilayer.py b/tests/training/test_activations_store_multilayer.py new file mode 100644 index 000000000..42c6c2a76 --- /dev/null +++ b/tests/training/test_activations_store_multilayer.py @@ -0,0 +1,231 @@ +"""Tests for ActivationsStore with multiple layer support.""" + +import pytest +import torch +from datasets import Dataset +from transformer_lens import HookedTransformer + +from sae_lens.training.activations_store import ActivationsStore +from tests.helpers import build_sae_cfg, load_model_cached + + +def test_activations_store_init_with_multiple_layers(ts_model: HookedTransformer): + """Test initialization with a list of layers instead of a single layer.""" + # Initialize with multiple layers + cfg = build_sae_cfg( + hook_name="blocks.{}.hook_resid_pre", + hook_layers=[0, 1, 2] + ) + + activation_store = ActivationsStore.from_config(ts_model, cfg) + + # Check that the hook layers are correctly stored + assert activation_store.hook_layers == [0, 1, 2] + + # Verify backward compatibility - a single hook_layer should be converted to a list + cfg_single = 
build_sae_cfg( + hook_name="blocks.{}.hook_resid_pre", + hook_layer=1 + ) + + single_layer_store = ActivationsStore.from_config(ts_model, cfg_single) + assert single_layer_store.hook_layers == [1] + + +def test_activations_store_get_activations_multiple_layers(ts_model: HookedTransformer): + """Test that get_activations collects activations from all specified layers.""" + # Setup with multiple layers + cfg = build_sae_cfg( + hook_name="blocks.{}.hook_resid_pre", + hook_layers=[0, 1, 2], + context_size=5 + ) + + dataset = Dataset.from_list([{"text": "hello world"}] * 10) + activation_store = ActivationsStore.from_config(ts_model, cfg, override_dataset=dataset) + + # Get a batch of tokens and activations + batch_tokens = activation_store.get_batch_tokens() + activations = activation_store.get_activations(batch_tokens) + + # Check shape: [batch_size, context_size, num_layers, d_in] + assert activations.shape == ( + activation_store.store_batch_size_prompts, + activation_store.context_size, + len(activation_store.hook_layers), + activation_store.d_in + ) + + # Verify that layers are in the correct order + # Run with cache directly to compare against + _, cache = ts_model.run_with_cache( + batch_tokens, + names_filter=[f"blocks.{i}.hook_resid_pre" for i in [0, 1, 2]] + ) + + for i, layer in enumerate([0, 1, 2]): + hook_name = f"blocks.{layer}.hook_resid_pre" + # Compare the activations for this layer with what we got from run_with_cache + assert torch.allclose( + activations[:, :, i, :], + cache[hook_name], + atol=1e-5 + ) + + +def test_activations_store_get_buffer_multiple_layers(ts_model: HookedTransformer): + """Test buffer handling with multiple layers.""" + # Setup with multiple layers + cfg = build_sae_cfg( + hook_name="blocks.{}.hook_resid_pre", + hook_layers=[0, 1, 2], + context_size=5 + ) + + dataset = Dataset.from_list([{"text": "hello world"}] * 20) + activation_store = ActivationsStore.from_config(ts_model, cfg, override_dataset=dataset) + + # Get buffer 
with 2 batches + buffer_activations, buffer_tokens = activation_store.get_buffer(n_batches_in_buffer=2) + + # Check shape: [(batch_size * context_size * n_batches), num_layers, d_in] + expected_size = activation_store.store_batch_size_prompts * activation_store.context_size * 2 + assert buffer_activations.shape == (expected_size, len(activation_store.hook_layers), activation_store.d_in) + assert buffer_tokens.shape == (expected_size,) + + +def test_activations_store_next_batch_multiple_layers(ts_model: HookedTransformer): + """Test that next_batch returns correct batch shape with multiple layers.""" + # Setup with multiple layers + cfg = build_sae_cfg( + hook_name="blocks.{}.hook_resid_pre", + hook_layers=[0, 1, 2], + context_size=5, + train_batch_size_tokens=10 + ) + + dataset = Dataset.from_list([{"text": "hello world"}] * 20) + activation_store = ActivationsStore.from_config(ts_model, cfg, override_dataset=dataset) + + # Get a batch + batch = activation_store.next_batch() + + # Check batch[0] shape: [batch_size, num_layers, d_in] + assert batch[0].shape == (10, len(activation_store.hook_layers), activation_store.d_in) + + # Verify the token IDs + assert batch[1] is not None + assert batch[1].shape == (10,) + + +def test_activations_store_normalization_multiple_layers(ts_model: HookedTransformer): + """Test normalization when using multiple layers.""" + # Setup with normalization and multiple layers + cfg = build_sae_cfg( + hook_name="blocks.{}.hook_resid_pre", + hook_layers=[0, 1, 2], + normalize_activations="constant_norm_rescale", + context_size=5 + ) + + dataset = Dataset.from_list([{"text": "hello world"}] * 20) + activation_store = ActivationsStore.from_config(ts_model, cfg, override_dataset=dataset) + + # Set a fixed norm scaling factor for testing + activation_store.estimated_norm_scaling_factor = 2.0 + + # Get a batch with normalized activations + batch = activation_store.next_batch() + + # Check that the activations have been properly normalized + # The 
norm should be approximately sqrt(d_in) for each layer + for layer_idx in range(len(activation_store.hook_layers)): + layer_activations = batch[0][:, layer_idx, :] + # Check if average norm is approximately as expected (allowing for some variance) + avg_norm = layer_activations.norm(dim=-1).mean() + expected_norm = (activation_store.d_in ** 0.5) + assert avg_norm.item() == pytest.approx(expected_norm, abs=2.0) + + +def test_backward_compatibility_single_layer(ts_model: HookedTransformer): + """Test that single layer behavior is unchanged with the multi-layer support.""" + # Create a store with single layer (old behavior) + cfg_single = build_sae_cfg( + hook_name="blocks.0.hook_resid_pre", + hook_layer=0, + context_size=5 + ) + + dataset = Dataset.from_list([{"text": "hello world"}] * 10) + single_store = ActivationsStore.from_config(ts_model, cfg_single, override_dataset=dataset) + + # Create a store with single layer (new behavior) + cfg_multi = build_sae_cfg( + hook_name="blocks.{}.hook_resid_pre", + hook_layers=[0], + context_size=5 + ) + multi_store = ActivationsStore.from_config(ts_model, cfg_multi, override_dataset=dataset) + + # Get tokens and activations from both + batch_tokens_single = single_store.get_batch_tokens() + activations_single = single_store.get_activations(batch_tokens_single) + + batch_tokens_multi = multi_store.get_batch_tokens() + activations_multi = multi_store.get_activations(batch_tokens_multi) + + # Check that activations have the same shape and values + assert activations_single.shape == activations_multi.shape + # Run with deterministic seed to ensure tokens are the same + if torch.allclose(batch_tokens_single, batch_tokens_multi): + assert torch.allclose(activations_single, activations_multi, atol=1e-5) + + +def test_mixed_hook_formats(ts_model: HookedTransformer): + """Test that both formatted and non-formatted hook names work with multiple layers.""" + # Test with formatted hook name (with {}) + cfg_formatted = build_sae_cfg( + 
hook_name="blocks.{}.hook_resid_pre", + hook_layers=[0, 1], + context_size=5 + ) + + # Test with non-formatted hook name + cfg_non_formatted = build_sae_cfg( + hook_name="blocks.0.hook_resid_pre", # Specific to layer 0 + hook_layers=[0], # Only layer 0 works with this hook + context_size=5 + ) + + dataset = Dataset.from_list([{"text": "hello world"}] * 10) + + # Both should initialize without errors + store_formatted = ActivationsStore.from_config( + ts_model, cfg_formatted, override_dataset=dataset + ) + store_non_formatted = ActivationsStore.from_config( + ts_model, cfg_non_formatted, override_dataset=dataset + ) + + # Both should be able to get activations + activations_formatted = store_formatted.get_activations( + store_formatted.get_batch_tokens() + ) + activations_non_formatted = store_non_formatted.get_activations( + store_non_formatted.get_batch_tokens() + ) + + # Check shapes + assert activations_formatted.shape == ( + store_formatted.store_batch_size_prompts, + store_formatted.context_size, + len(store_formatted.hook_layers), + store_formatted.d_in + ) + + assert activations_non_formatted.shape == ( + store_non_formatted.store_batch_size_prompts, + store_non_formatted.context_size, + len(store_non_formatted.hook_layers), + store_non_formatted.d_in + ) From 2b7a4385f316f4daf41559afa410bcccc1aab6f7 Mon Sep 17 00:00:00 2001 From: Michael Behr Date: Sat, 5 Apr 2025 15:42:29 -0400 Subject: [PATCH 02/61] Implement multilayer activations store except normalization --- sae_lens/config.py | 16 ++++- sae_lens/training/activations_store.py | 58 ++++++++++++------- tests/helpers.py | 2 + .../test_activations_store_multilayer.py | 19 ++---- 4 files changed, 59 insertions(+), 36 deletions(-) diff --git a/sae_lens/config.py b/sae_lens/config.py index 6faf8ac9b..d4cee5049 100644 --- a/sae_lens/config.py +++ b/sae_lens/config.py @@ -62,6 +62,7 @@ class LanguageModelSAERunnerConfig: architecture (str): The architecture to use, either "standard", "gated", "topk", or 
"jumprelu". model_name (str): The name of the model to use. This should be the name of the model in the Hugging Face model hub. model_class_name (str): The name of the class of the model to use. This should be either `HookedTransformer` or `HookedMamba`. + TODO(mkbehr): update hook name param docs for multilayer case hook_name (str): The name of the hook to use. This should be a valid TransformerLens hook. hook_eval (str): NOT CURRENTLY IN USE. The name of the hook to use for evaluation. hook_layer (int): The index of the layer to hook. Used to stop forward passes early and speed up processing. @@ -149,6 +150,7 @@ class LanguageModelSAERunnerConfig: hook_name: str = "blocks.0.hook_mlp_out" hook_eval: str = "NOT_IN_USE" hook_layer: int = 0 + hook_layers: list[int] | None = None hook_head_index: int | None = None dataset_path: str = "" dataset_trust_remote_code: bool = True @@ -435,7 +437,7 @@ def total_training_steps(self) -> int: return self.total_training_tokens // self.train_batch_size_tokens def get_base_sae_cfg_dict(self) -> dict[str, Any]: - return { + cfg_dict = { # TEMP "architecture": self.architecture, "d_in": self.d_in, @@ -460,6 +462,11 @@ def get_base_sae_cfg_dict(self) -> dict[str, Any]: "seqpos_slice": self.seqpos_slice, } + if self.hook_layers is not None: + cfg_dict["hook_layers"] = self.hook_layers + + return cfg_dict + def get_training_sae_cfg_dict(self) -> dict[str, Any]: return { **self.get_base_sae_cfg_dict(), @@ -554,6 +561,7 @@ class CacheActivationsRunnerConfig: hook_layer: int d_in: int training_tokens: int + hook_layers: list[int] | None = None context_size: int = -1 # Required if dataset is not tokenized model_class_name: str = "HookedTransformer" @@ -608,8 +616,12 @@ def __post_init__(self): ) if self.new_cached_activations_path is None: + hook_name_str = self.hook_name + if self.hook_layers is not None: + # TODO(mkbehr): ensure the multilayer activation path makes sense + hook_name_str = f"{self.hook_name}_layers_{'_'.join(str(l) for l 
in self.hook_layers)}" self.new_cached_activations_path = _default_cached_activations_path( # type: ignore - self.dataset_path, self.model_name, self.hook_name, None + self.dataset_path, self.model_name, hook_name_str, None ) @property diff --git a/sae_lens/training/activations_store.py b/sae_lens/training/activations_store.py index b4a5096b7..e68d6347c 100644 --- a/sae_lens/training/activations_store.py +++ b/sae_lens/training/activations_store.py @@ -46,6 +46,7 @@ class ActivationsStore: tokens_column: Literal["tokens", "input_ids", "text", "problem"] hook_name: str hook_layer: int + hook_layers: list[int] hook_head_index: int | None _dataloader: Iterator[Any] | None = None _storage_buffer: torch.Tensor | None = None @@ -66,6 +67,7 @@ def from_cache_activations( dtype=cfg.dtype, hook_name=cfg.hook_name, hook_layer=cfg.hook_layer, + # TODO(mkbehr): set hook layers if set in cached activations context_size=cfg.context_size, d_in=cfg.d_in, n_batches_in_buffer=cfg.n_batches_in_buffer, @@ -126,6 +128,7 @@ def from_config( streaming=cfg.streaming, hook_name=cfg.hook_name, hook_layer=cfg.hook_layer, + hook_layers=cfg.hook_layers, hook_head_index=cfg.hook_head_index, context_size=cfg.context_size, d_in=cfg.d_in, @@ -198,6 +201,7 @@ def __init__( normalize_activations: str, device: torch.device, dtype: str, + hook_layers: list[int] | None = None, cached_activations_path: str | None = None, model_kwargs: dict[str, Any] | None = None, autocast_lm: bool = False, @@ -231,6 +235,7 @@ def __init__( self.hook_name = hook_name self.hook_layer = hook_layer + self.hook_layers = hook_layers or [hook_layer] self.hook_head_index = hook_head_index self.context_size = context_size self.d_in = d_in @@ -532,42 +537,55 @@ def get_activations(self, batch_tokens: torch.Tensor): else: autocast_if_enabled = contextlib.nullcontext() + # TODO(mkbehr): This is awkward. It may make more sense to + # have a list of hook names. 
+ hook_names = [] + for layer in self.hook_layers: + if "{}" in self.hook_name: + hook_names.append(self.hook_name.format(layer)) + else: + hook_names.append(self.hook_name) + + stop_at_layer = max(self.hook_layers) + 1 + with autocast_if_enabled: layerwise_activations_cache = self.model.run_with_cache( batch_tokens, - names_filter=[self.hook_name], - stop_at_layer=self.hook_layer + 1, + names_filter=hook_names, + stop_at_layer=stop_at_layer, prepend_bos=False, **self.model_kwargs, )[1] - layerwise_activations = layerwise_activations_cache[self.hook_name][ - :, slice(*self.seqpos_slice) + layerwise_activations = [ + layerwise_activations_cache[hook_name][ + :, slice(*self.seqpos_slice) + ] + for hook_name in hook_names ] - n_batches, n_context = layerwise_activations.shape[:2] - - stacked_activations = torch.zeros((n_batches, n_context, 1, self.d_in)) + n_batches, n_context = layerwise_activations[0].shape[:2] if self.hook_head_index is not None: - stacked_activations[:, :, 0] = layerwise_activations[ - :, :, self.hook_head_index + layerwise_activations = [ + activation[:, :, self.hook_head_index] + for activation in layerwise_activations ] - elif layerwise_activations.ndim > 3: # if we have a head dimension + elif layerwise_activations[0].ndim > 3: # if we have a head dimension try: - stacked_activations[:, :, 0] = layerwise_activations.view( - n_batches, n_context, -1 - ) + layerwise_activations = [ + activation.view(n_batches, n_context, -1) + for activation in layerwise_activations + ] except RuntimeError as e: logger.error(f"Error during view operation: {e}") logger.info("Attempting to use reshape instead...") - stacked_activations[:, :, 0] = layerwise_activations.reshape( - n_batches, n_context, -1 - ) - else: - stacked_activations[:, :, 0] = layerwise_activations + layerwise_activations = [ + activation.reshape(n_batches, n_context, -1) + for activation in layerwise_activations + ] - return stacked_activations + return torch.stack(layerwise_activations, 
dim=2) def _load_buffer_from_cached( self, @@ -660,7 +678,7 @@ def get_buffer( batch_size = self.store_batch_size_prompts d_in = self.d_in total_size = batch_size * n_batches_in_buffer - num_layers = 1 + num_layers = len(self.hook_layers) if self.cached_activation_dataset is not None: return self._load_buffer_from_cached( diff --git a/tests/helpers.py b/tests/helpers.py index 6c3cdab3e..b37c03b07 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -16,6 +16,7 @@ class LanguageModelSAERunnerConfigDict(TypedDict, total=False): model_name: str hook_name: str hook_layer: int + hook_layers: list[int] | None hook_head_index: int | None dataset_path: str dataset_trust_remote_code: bool @@ -54,6 +55,7 @@ def build_sae_cfg(**kwargs: Any) -> LanguageModelSAERunnerConfig: "model_name": TINYSTORIES_MODEL, "hook_name": "blocks.0.hook_mlp_out", "hook_layer": 0, + "hook_layers": None, "hook_head_index": None, # use a small, non-streaming dataset for testing. Huggingface gives too many requests errors otherwise. 
"dataset_path": NEEL_NANDA_C4_10K_DATASET, diff --git a/tests/training/test_activations_store_multilayer.py b/tests/training/test_activations_store_multilayer.py index 42c6c2a76..fc6d2d7ef 100644 --- a/tests/training/test_activations_store_multilayer.py +++ b/tests/training/test_activations_store_multilayer.py @@ -107,32 +107,23 @@ def test_activations_store_next_batch_multiple_layers(ts_model: HookedTransforme dataset = Dataset.from_list([{"text": "hello world"}] * 20) activation_store = ActivationsStore.from_config(ts_model, cfg, override_dataset=dataset) - # Get a batch batch = activation_store.next_batch() + assert batch.shape == (10, len(cfg.hook_layers), activation_store.d_in) - # Check batch[0] shape: [batch_size, num_layers, d_in] - assert batch[0].shape == (10, len(activation_store.hook_layers), activation_store.d_in) - - # Verify the token IDs - assert batch[1] is not None - assert batch[1].shape == (10,) - - +@pytest.mark.skip("TODO(mkbehr): does activation need to be handled differently?") def test_activations_store_normalization_multiple_layers(ts_model: HookedTransformer): """Test normalization when using multiple layers.""" # Setup with normalization and multiple layers cfg = build_sae_cfg( hook_name="blocks.{}.hook_resid_pre", hook_layers=[0, 1, 2], - normalize_activations="constant_norm_rescale", + normalize_activations="expected_average_only_in", context_size=5 ) dataset = Dataset.from_list([{"text": "hello world"}] * 20) activation_store = ActivationsStore.from_config(ts_model, cfg, override_dataset=dataset) - - # Set a fixed norm scaling factor for testing - activation_store.estimated_norm_scaling_factor = 2.0 + activation_store.set_norm_scaling_factor_if_needed() # Get a batch with normalized activations batch = activation_store.next_batch() @@ -140,7 +131,7 @@ def test_activations_store_normalization_multiple_layers(ts_model: HookedTransfo # Check that the activations have been properly normalized # The norm should be approximately sqrt(d_in) 
for each layer for layer_idx in range(len(activation_store.hook_layers)): - layer_activations = batch[0][:, layer_idx, :] + layer_activations = batch[:, layer_idx, :] # Check if average norm is approximately as expected (allowing for some variance) avg_norm = layer_activations.norm(dim=-1).mean() expected_norm = (activation_store.d_in ** 0.5) From d1c603bf4110d89942a4336de175f85f2201e8e9 Mon Sep 17 00:00:00 2001 From: Michael Behr Date: Sun, 6 Apr 2025 16:47:32 -0400 Subject: [PATCH 03/61] CrosscoderSAE implementation, some tests included tests: - test_crosscoder_sae_init - test_crosscoder_sae_fold_w_dec_norm hook_z excluded from tests --- sae_lens/crosscoder_sae.py | 126 ++++++++++++++++++++++++++ sae_lens/sae.py | 15 +-- tests/training/test_crosscoder_sae.py | 114 +++++++++++++++++++++++ 3 files changed, 249 insertions(+), 6 deletions(-) create mode 100644 sae_lens/crosscoder_sae.py create mode 100644 tests/training/test_crosscoder_sae.py diff --git a/sae_lens/crosscoder_sae.py b/sae_lens/crosscoder_sae.py new file mode 100644 index 000000000..54fd80d24 --- /dev/null +++ b/sae_lens/crosscoder_sae.py @@ -0,0 +1,126 @@ +from dataclasses import dataclass +from typing import Any + +import einops +import torch +from jaxtyping import Float + +from sae_lens import SAEConfig, SAE + +@dataclass +class CrosscoderSAEConfig(SAEConfig): + hook_layers: list[int] = list + + # @classmethod + # def from_dict(cls, config_dict: dict[str, Any]) -> "CrosscoderSAEConfig": + # # TODO(mkbehr) is a new method needed here, or will the superclass's work w/o modification? I think it'll work. test it. 
+ # pass + + def to_dict(self) -> dict[str, Any]: + # TODO(mkbehr) test + return super().to_dict() | { + "hook_layers": self.hook_layers, + } + +class CrosscoderSAE(SAE): + """ + TODO(mkbehr): docstring + """ + + # TODO(mkbehr): write + # - remaining encode methods + # - fold_activation_norm + # - hook_z reshaping support + + def __init__( + self, + cfg: CrosscoderSAEConfig, + use_error_term: bool = False, + ): + if cfg.architecture != "standard": + raise NotImplementedError("TODO(mkbehr): support other archs") + + super().__init__(cfg=cfg, use_error_term=use_error_term) + + if self.hook_z_reshaping_mode: + raise NotImplementedError("TODO(mkbehr): support hook_z") + + @classmethod + def from_dict(cls, config_dict: dict[str, Any]) -> "CrosscoderSAE": + return cls(CrosscoderSAEConfig.from_dict(config_dict)) + + def input_shape(self): + return (len(self.cfg.hook_layers), self.cfg.d_in) + + # TODO(mkbehr): in sae.py this is noted to output "... d_sae" but + # I think that's wrong + # TODO(mkbehr): I don't think we actually need to change this + def process_sae_in( + self, sae_in: Float[torch.Tensor, "... n_layers d_in"] + ) -> Float[torch.Tensor, "... n_layers d_in"]: + sae_in = sae_in.to(self.dtype) + # TODO(mkbehr): n.b. that reshape_fn_in is set to the identity + # if we're not doing hook_z reshaping + sae_in = self.reshape_fn_in(sae_in) + sae_in = self.hook_sae_input(sae_in) + sae_in = self.run_time_activation_norm_fn_in(sae_in) + return sae_in - (self.b_dec * self.cfg.apply_b_dec_to_input) + + def encode_standard( + self, x: Float[torch.Tensor, "... n_layers d_in"] + ) -> Float[torch.Tensor, "... d_sae"]: + """ + Calculate SAE features from inputs + """ + # TODO(mkbehr): instead of changing this and the W_enc/b_enc + # dimensions, we could change reshape_fn_in + sae_in = self.process_sae_in(x) + + hidden_pre = self.hook_sae_acts_pre( + einops.einsum( + sae_in, self.W_enc, + "... n_layers d_in, n_layers d_in d_sae -> ... 
d_sae" + ) + + self.b_enc) + return self.hook_sae_acts_post(self.activation_fn(hidden_pre)) + + def decode( + self, feature_acts: Float[torch.Tensor, "... d_sae"] + ) -> Float[torch.Tensor, "... n_layers d_in"]: + """Decodes SAE feature activation tensor into a reconstructed + input activation tensor.""" + sae_out = self.hook_sae_recons( + einops.einsum( + self.apply_finetuning_scaling_factor(feature_acts), + self.W_dec, + "... d_sae, d_sae n_layers d_in -> ... n_layers d_in" + ) + self.b_dec + ) + + # handle run time activation normalization if needed + # will fail if you call this twice without calling encode in between. + sae_out = self.run_time_activation_norm_fn_out(sae_out) + + # handle hook z reshaping if needed. + return self.reshape_fn_out(sae_out, self.d_head) # type: ignore + + @torch.no_grad() + def fold_W_dec_norm(self): + # TODO(mkbehr) + # W_dec: d_sae, n_layers, d_in + # W_dec_norms: d_sae, 1, 1 + # W_enc: n_layers, d_in, d_sae + # desired W_enc_norms: 1, 1, d_sae + W_dec_norms = self.W_dec.norm(dim=[-2,-1], keepdim=True) + self.W_dec.data = self.W_dec.data / W_dec_norms + self.W_enc.data = self.W_enc.data * einops.rearrange( + W_dec_norms, "d_sae 1 1 -> 1 1 d_sae") + if self.cfg.architecture == "gated": + self.r_mag.data = self.r_mag.data * W_dec_norms.squeeze() + self.b_gate.data = self.b_gate.data * W_dec_norms.squeeze() + self.b_mag.data = self.b_mag.data * W_dec_norms.squeeze() + elif self.cfg.architecture == "jumprelu": + self.threshold.data = self.threshold.data * W_dec_norms.squeeze() + self.b_enc.data = self.b_enc.data * W_dec_norms.squeeze() + else: + self.b_enc.data = self.b_enc.data * W_dec_norms.squeeze() diff --git a/sae_lens/sae.py b/sae_lens/sae.py index edd873b30..9785e557c 100644 --- a/sae_lens/sae.py +++ b/sae_lens/sae.py @@ -244,6 +244,9 @@ def run_time_activation_ln_out( self.setup() # Required for `HookedRootModule`s + def input_shape(self): + return (self.cfg.d_in,) + def initialize_weights_basic(self): # no config changes 
encoder bias init for now. self.b_enc = nn.Parameter( @@ -254,7 +257,7 @@ def initialize_weights_basic(self): self.W_dec = nn.Parameter( torch.nn.init.kaiming_uniform_( torch.empty( - self.cfg.d_sae, self.cfg.d_in, dtype=self.dtype, device=self.device + self.cfg.d_sae, *self.input_shape(), dtype=self.dtype, device=self.device ) ) ) @@ -262,14 +265,14 @@ def initialize_weights_basic(self): self.W_enc = nn.Parameter( torch.nn.init.kaiming_uniform_( torch.empty( - self.cfg.d_in, self.cfg.d_sae, dtype=self.dtype, device=self.device + *self.input_shape(), self.cfg.d_sae, dtype=self.dtype, device=self.device ) ) ) # methdods which change b_dec as a function of the dataset are implemented after init. self.b_dec = nn.Parameter( - torch.zeros(self.cfg.d_in, dtype=self.dtype, device=self.device) + torch.zeros(*self.input_shape(), dtype=self.dtype, device=self.device) ) # scaling factor for fine-tuning (not to be used in initial training) @@ -284,7 +287,7 @@ def initialize_weights_gated(self): self.W_enc = nn.Parameter( torch.nn.init.kaiming_uniform_( torch.empty( - self.cfg.d_in, self.cfg.d_sae, dtype=self.dtype, device=self.device + *self.input_shape(), self.cfg.d_sae, dtype=self.dtype, device=self.device ) ) ) @@ -304,13 +307,13 @@ def initialize_weights_gated(self): self.W_dec = nn.Parameter( torch.nn.init.kaiming_uniform_( torch.empty( - self.cfg.d_sae, self.cfg.d_in, dtype=self.dtype, device=self.device + self.cfg.d_sae, *self.input_shape(), dtype=self.dtype, device=self.device ) ) ) self.b_dec = nn.Parameter( - torch.zeros(self.cfg.d_in, dtype=self.dtype, device=self.device) + torch.zeros(*self.input_shape(), dtype=self.dtype, device=self.device) ) def initialize_weights_jumprelu(self): diff --git a/tests/training/test_crosscoder_sae.py b/tests/training/test_crosscoder_sae.py new file mode 100644 index 000000000..87c8b01c0 --- /dev/null +++ b/tests/training/test_crosscoder_sae.py @@ -0,0 +1,114 @@ +import os +from copy import deepcopy +from pathlib import Path + 
+import einops +import pytest +import torch +from torch import nn +from transformer_lens.hook_points import HookPoint + +from sae_lens.config import LanguageModelSAERunnerConfig +from sae_lens.crosscoder_sae import CrosscoderSAE +from sae_lens.sae import _disable_hooks +from tests.helpers import ALL_ARCHITECTURES, build_sae_cfg + + +# Define a new fixture for different configurations +@pytest.fixture( + params=[ + { + "model_name": "tiny-stories-1M", + "dataset_path": "roneneldan/TinyStories", + "hook_name": "blocks.{}.hook_resid_pre", + "hook_layers": [1,2,3], + "d_in": 64, + }, + { + "model_name": "tiny-stories-1M", + "dataset_path": "roneneldan/TinyStories", + "hook_name": "blocks.{}.hook_resid_pre", + "hook_layers": [1,2,3], + "d_in": 64, + "normalize_sae_decoder": False, + "scale_sparsity_penalty_by_decoder_norm": True, + }, + { + "model_name": "tiny-stories-1M", + "dataset_path": "apollo-research/roneneldan-TinyStories-tokenizer-gpt2", + "hook_name": "blocks.{}.hook_resid_pre", + "hook_layers": [1,2,3], + "d_in": 64, + }, + # TODO(mkbehr): hook_z support + # { + # "model_name": "tiny-stories-1M", + # "dataset_path": "roneneldan/TinyStories", + # "hook_name": "blocks.{}.attn.hook_z", + # "hook_layers": [1,2,3], + # "d_in": 64, + # }, + ], + ids=[ + "tiny-stories-1M-resid-pre", + "tiny-stories-1M-resid-pre-L1-W-dec-Norm", + "tiny-stories-1M-resid-pre-pretokenized", + # "tiny-stories-1M-attn-out", + ], +) +def cfg(request: pytest.FixtureRequest): + """ + Pytest fixture to create a mock instance of LanguageModelSAERunnerConfig. 
+ """ + params = request.param + return build_sae_cfg(**params) + + +def test_crosscoder_sae_init(cfg: LanguageModelSAERunnerConfig): + sae = CrosscoderSAE.from_dict(cfg.get_base_sae_cfg_dict()) + + assert isinstance(sae, CrosscoderSAE) + + n_layers = len(cfg.hook_layers) + assert sae.W_enc.shape == (n_layers, cfg.d_in, cfg.d_sae) + assert sae.W_dec.shape == (cfg.d_sae, n_layers, cfg.d_in) + assert sae.b_enc.shape == (cfg.d_sae,) + assert sae.b_dec.shape == (n_layers, cfg.d_in) + + +def test_crosscoder_sae_fold_w_dec_norm(cfg: LanguageModelSAERunnerConfig): + sae = CrosscoderSAE.from_dict(cfg.get_base_sae_cfg_dict()) + sae.turn_off_forward_pass_hook_z_reshaping() # hook z reshaping not needed here. + assert sae.W_dec.norm(dim=[-2,-1]).mean().item() != pytest.approx(1.0, abs=1e-6) + sae2 = deepcopy(sae) + sae2.fold_W_dec_norm() + + W_dec_norms = sae.W_dec.norm(dim=[-2,-1], keepdim=True) + assert torch.allclose(sae2.W_dec.data, sae.W_dec.data / W_dec_norms) + assert torch.allclose(sae2.W_enc.data, + sae.W_enc.data * einops.rearrange( + W_dec_norms, "d_sae 1 1 -> 1 1 d_sae")) + assert torch.allclose(sae2.b_enc.data, sae.b_enc.data * W_dec_norms.squeeze()) + + # fold_W_dec_norm should normalize W_dec to have unit norm. + assert sae2.W_dec.norm(dim=[-2,-1]).mean().item() == pytest.approx(1.0, abs=1e-6) + + # we expect activations of features to differ by W_dec norm weights. 
+ activations = torch.randn(10, 4, len(cfg.hook_layers), cfg.d_in, + device=cfg.device) + feature_activations_1 = sae.encode(activations) + feature_activations_2 = sae2.encode(activations) + + assert torch.allclose( + feature_activations_1.nonzero(), + feature_activations_2.nonzero(), + ) + + expected_feature_activations_2 = feature_activations_1 * sae.W_dec.norm(dim=[-2,-1]) + torch.testing.assert_close(feature_activations_2, expected_feature_activations_2) + + sae_out_1 = sae.decode(feature_activations_1) + sae_out_2 = sae2.decode(feature_activations_2) + + # but actual outputs should be the same + torch.testing.assert_close(sae_out_1, sae_out_2) From f8bf44e3778a2974b160486a8de1cd9303022ced Mon Sep 17 00:00:00 2001 From: Michael Behr Date: Sun, 6 Apr 2025 17:00:37 -0400 Subject: [PATCH 04/61] more norm tests --- tests/training/test_crosscoder_sae.py | 126 ++++++++++++++++++++++++++ 1 file changed, 126 insertions(+) diff --git a/tests/training/test_crosscoder_sae.py b/tests/training/test_crosscoder_sae.py index 87c8b01c0..28328f07e 100644 --- a/tests/training/test_crosscoder_sae.py +++ b/tests/training/test_crosscoder_sae.py @@ -112,3 +112,129 @@ def test_crosscoder_sae_fold_w_dec_norm(cfg: LanguageModelSAERunnerConfig): # but actual outputs should be the same torch.testing.assert_close(sae_out_1, sae_out_2) + +@pytest.mark.parametrize("architecture", ALL_ARCHITECTURES) +@torch.no_grad() +def test_sae_fold_w_dec_norm_all_architectures(architecture: str): + if architecture != "standard": + pytest.xfail("TODO(mkbehr): support other architectures") + cfg = build_sae_cfg(architecture=architecture, hook_layers=[1,2,3]) + sae = CrosscoderSAE.from_dict(cfg.get_base_sae_cfg_dict()) + sae.turn_off_forward_pass_hook_z_reshaping() # hook z reshaping not needed here. 
+ + # make sure all parameters are not 0s + for param in sae.parameters(): + param.data = torch.rand_like(param) + + assert sae.W_dec.norm(dim=[-2,-1]).mean().item() != pytest.approx(1.0, abs=1e-6) + sae2 = deepcopy(sae) + sae2.fold_W_dec_norm() + + # fold_W_dec_norm should normalize W_dec to have unit norm. + assert sae2.W_dec.norm(dim=[-2,-1]).mean().item() == pytest.approx(1.0, abs=1e-6) + + # we expect activations of features to differ by W_dec norm weights. + activations = torch.randn(10, 4, len(cfg.hook_layers), cfg.d_in, device=cfg.device) + feature_activations_1 = sae.encode(activations) + feature_activations_2 = sae2.encode(activations) + + assert torch.allclose( + feature_activations_1.nonzero(), + feature_activations_2.nonzero(), + ) + + expected_feature_activations_2 = feature_activations_1 * sae.W_dec.norm(dim=[-2,-1]) + torch.testing.assert_close(feature_activations_2, expected_feature_activations_2) + + sae_out_1 = sae.decode(feature_activations_1) + sae_out_2 = sae2.decode(feature_activations_2) + + # but actual outputs should be the same + torch.testing.assert_close(sae_out_1, sae_out_2) + +@torch.no_grad() +def test_sae_fold_norm_scaling_factor(cfg: LanguageModelSAERunnerConfig): + norm_scaling_factor = 3.0 + + sae = CrosscoderSAE.from_dict(cfg.get_base_sae_cfg_dict()) + # make sure b_dec and b_enc are not 0s + sae.b_dec.data = torch.randn(len(cfg.hook_layers), cfg.d_in, device=cfg.device) + sae.b_enc.data = torch.randn(cfg.d_sae, device=cfg.device) # type: ignore + sae.turn_off_forward_pass_hook_z_reshaping() # hook z reshaping not needed here. + + sae2 = deepcopy(sae) + sae2.fold_activation_norm_scaling_factor(norm_scaling_factor) + + assert sae2.cfg.normalize_activations == "none" + + assert torch.allclose(sae2.W_enc.data, sae.W_enc.data * norm_scaling_factor) + + # we expect activations of features to differ by W_dec norm weights. 
+ # assume activations are already scaled + activations = torch.randn(10, 4, len(cfg.hook_layers), cfg.d_in, device=cfg.device) + # we divide to get the unscale activations + unscaled_activations = activations / norm_scaling_factor + + feature_activations_1 = sae.encode(activations) + # with the scaling folded in, the unscaled activations should produce the same + # result. + feature_activations_2 = sae2.encode(unscaled_activations) + + assert torch.allclose( + feature_activations_1.nonzero(), + feature_activations_2.nonzero(), + ) + + torch.testing.assert_close(feature_activations_2, feature_activations_1) + + sae_out_1 = sae.decode(feature_activations_1) + sae_out_2 = norm_scaling_factor * sae2.decode(feature_activations_2) + + # but actual outputs should be the same + torch.testing.assert_close(sae_out_1, sae_out_2) + + +@pytest.mark.parametrize("architecture", ALL_ARCHITECTURES) +@torch.no_grad() +def test_sae_fold_norm_scaling_factor_all_architectures(architecture: str): + if architecture != "standard": + pytest.xfail("TODO(mkbehr): support other architectures") + cfg = build_sae_cfg(architecture=architecture, hook_layers=[1,2,3]) + norm_scaling_factor = 3.0 + + sae = CrosscoderSAE.from_dict(cfg.get_base_sae_cfg_dict()) + # make sure all parameters are not 0s + for param in sae.parameters(): + param.data = torch.rand_like(param) + + sae2 = deepcopy(sae) + sae2.fold_activation_norm_scaling_factor(norm_scaling_factor) + + assert sae2.cfg.normalize_activations == "none" + + assert torch.allclose(sae2.W_enc.data, sae.W_enc.data * norm_scaling_factor) + + # we expect activations of features to differ by W_dec norm weights. 
+ # assume activations are already scaled + activations = torch.randn(10, 4, len(cfg.hook_layers), cfg.d_in, device=cfg.device) + # we divide to get the unscale activations + unscaled_activations = activations / norm_scaling_factor + + feature_activations_1 = sae.encode(activations) + # with the scaling folded in, the unscaled activations should produce the same + # result. + feature_activations_2 = sae2.encode(unscaled_activations) + + assert torch.allclose( + feature_activations_1.nonzero(), + feature_activations_2.nonzero(), + ) + + torch.testing.assert_close(feature_activations_2, feature_activations_1) + + sae_out_1 = sae.decode(feature_activations_1) + sae_out_2 = norm_scaling_factor * sae2.decode(feature_activations_2) + + # but actual outputs should be the same + torch.testing.assert_close(sae_out_1, sae_out_2) + From ce63c0b5f86eee39b2c28eea0d247dcc611dc941 Mon Sep 17 00:00:00 2001 From: Michael Behr Date: Sun, 6 Apr 2025 17:10:07 -0400 Subject: [PATCH 05/61] save-and-load support --- sae_lens/sae.py | 2 +- tests/training/test_crosscoder_sae.py | 75 +++++++++++++++++++++++++++ 2 files changed, 76 insertions(+), 1 deletion(-) diff --git a/sae_lens/sae.py b/sae_lens/sae.py index 9785e557c..93dea38c5 100644 --- a/sae_lens/sae.py +++ b/sae_lens/sae.py @@ -643,7 +643,7 @@ def from_pretrained( ) cfg_dict = handle_config_defaulting(cfg_dict) - sae = cls(SAEConfig.from_dict(cfg_dict)) + sae = cls.from_dict(cfg_dict) sae.process_state_dict_for_loading(state_dict) sae.load_state_dict(state_dict) diff --git a/tests/training/test_crosscoder_sae.py b/tests/training/test_crosscoder_sae.py index 28328f07e..0ca424e07 100644 --- a/tests/training/test_crosscoder_sae.py +++ b/tests/training/test_crosscoder_sae.py @@ -238,3 +238,78 @@ def test_sae_fold_norm_scaling_factor_all_architectures(architecture: str): # but actual outputs should be the same torch.testing.assert_close(sae_out_1, sae_out_2) +def test_sae_save_and_load_from_pretrained(tmp_path: Path) -> None: + cfg = 
build_sae_cfg(hook_layers=[1,2,3]) + model_path = str(tmp_path) + sae = CrosscoderSAE.from_dict(cfg.get_base_sae_cfg_dict()) + sae_state_dict = sae.state_dict() + sae.save_model(model_path) + + assert os.path.exists(model_path) + + sae_loaded = CrosscoderSAE.load_from_pretrained(model_path, device="cpu") + + sae_loaded_state_dict = sae_loaded.state_dict() + + # check state_dict matches the original + for key in sae.state_dict(): + assert torch.allclose( + sae_state_dict[key], + sae_loaded_state_dict[key], + ) + + sae_in = torch.randn(10, len(cfg.hook_layers), cfg.d_in, device=cfg.device) + sae_out_1 = sae(sae_in) + sae_out_2 = sae_loaded(sae_in) + assert torch.allclose(sae_out_1, sae_out_2) + +@pytest.mark.xfail(reason="TODO(mkbehr): support other architectures") +def test_sae_save_and_load_from_pretrained_gated(tmp_path: Path) -> None: + cfg = build_sae_cfg(architecture="gated", hook_layers=[1,2,3]) + model_path = str(tmp_path) + sae = CrosscoderSAE.from_dict(cfg.get_base_sae_cfg_dict()) + sae_state_dict = sae.state_dict() + sae.save_model(model_path) + + assert os.path.exists(model_path) + + sae_loaded = CrosscoderSAE.load_from_pretrained(model_path, device="cpu") + + sae_loaded_state_dict = sae_loaded.state_dict() + + # check state_dict matches the original + for key in sae.state_dict(): + assert torch.allclose( + sae_state_dict[key], + sae_loaded_state_dict[key], + ) + + sae_in = torch.randn(10, len(cfg.hook_layers), cfg.d_in, device=cfg.device) + sae_out_1 = sae(sae_in) + sae_out_2 = sae_loaded(sae_in) + assert torch.allclose(sae_out_1, sae_out_2) + +def test_sae_save_and_load_from_pretrained_topk(tmp_path: Path) -> None: + cfg = build_sae_cfg(activation_fn_kwargs={"k": 30}, hook_layers=[1,2,3]) + model_path = str(tmp_path) + sae = CrosscoderSAE.from_dict(cfg.get_base_sae_cfg_dict()) + sae_state_dict = sae.state_dict() + sae.save_model(model_path) + + assert os.path.exists(model_path) + + sae_loaded = CrosscoderSAE.load_from_pretrained(model_path, 
device="cpu") + + sae_loaded_state_dict = sae_loaded.state_dict() + + # check state_dict matches the original + for key in sae.state_dict(): + assert torch.allclose( + sae_state_dict[key], + sae_loaded_state_dict[key], + ) + + sae_in = torch.randn(10, len(cfg.hook_layers), cfg.d_in, device=cfg.device) + sae_out_1 = sae(sae_in) + sae_out_2 = sae_loaded(sae_in) + assert torch.allclose(sae_out_1, sae_out_2) From 51bfac7ed88f6a53bb98a4f05c5f288124dc495c Mon Sep 17 00:00:00 2001 From: Michael Behr Date: Sun, 6 Apr 2025 17:17:14 -0400 Subject: [PATCH 06/61] WIP name --- sae_lens/crosscoder_sae.py | 5 +++++ tests/training/test_crosscoder_sae.py | 5 +++++ 2 files changed, 10 insertions(+) diff --git a/sae_lens/crosscoder_sae.py b/sae_lens/crosscoder_sae.py index 54fd80d24..19e84e2f3 100644 --- a/sae_lens/crosscoder_sae.py +++ b/sae_lens/crosscoder_sae.py @@ -45,6 +45,11 @@ def __init__( if self.hook_z_reshaping_mode: raise NotImplementedError("TODO(mkbehr): support hook_z") + def get_name(self): + # TODO(mkbehr): think about the correct name + layers = ','.join([str(l) for l in self.cfg.hook_layers]) + return f"sae_{self.cfg.model_name}_{self.cfg.hook_name}_layers{layers}_{self.cfg.d_sae}" + @classmethod def from_dict(cls, config_dict: dict[str, Any]) -> "CrosscoderSAE": return cls(CrosscoderSAEConfig.from_dict(config_dict)) diff --git a/tests/training/test_crosscoder_sae.py b/tests/training/test_crosscoder_sae.py index 0ca424e07..e1443ad45 100644 --- a/tests/training/test_crosscoder_sae.py +++ b/tests/training/test_crosscoder_sae.py @@ -313,3 +313,8 @@ def test_sae_save_and_load_from_pretrained_topk(tmp_path: Path) -> None: sae_out_1 = sae(sae_in) sae_out_2 = sae_loaded(sae_in) assert torch.allclose(sae_out_1, sae_out_2) + +def test_sae_get_name_returns_correct_name_from_cfg_vals() -> None: + cfg = build_sae_cfg(model_name="test_model", hook_name="blocks.{}.test_hook_name", d_sae=128, hook_layers=[1,2,3]) + sae = CrosscoderSAE.from_dict(cfg.get_base_sae_cfg_dict()) + 
assert sae.get_name() == "sae_test_model_blocks.{}.test_hook_name_layers1,2,3_128" From d1188877b616513f6805c966f152e0fe2a1735ff Mon Sep 17 00:00:00 2001 From: Michael Behr Date: Sun, 13 Apr 2025 15:09:56 -0400 Subject: [PATCH 07/61] TrainingCrosscoderSAE implementation, decoder norm scaling test --- sae_lens/training/training_crosscoder_sae.py | 262 ++++++++++++++++++ sae_lens/training/training_sae.py | 13 +- .../training/test_training_crosscoder_sae.py | 38 +++ 3 files changed, 309 insertions(+), 4 deletions(-) create mode 100644 sae_lens/training/training_crosscoder_sae.py create mode 100644 tests/training/test_training_crosscoder_sae.py diff --git a/sae_lens/training/training_crosscoder_sae.py b/sae_lens/training/training_crosscoder_sae.py new file mode 100644 index 000000000..ece8dfb7c --- /dev/null +++ b/sae_lens/training/training_crosscoder_sae.py @@ -0,0 +1,262 @@ +import json +import os +from dataclasses import dataclass +from typing import Any + +import einops +import torch +from jaxtyping import Float + +from sae_lens.config import LanguageModelSAERunnerConfig +from sae_lens.crosscoder_sae import CrosscoderSAE, CrosscoderSAEConfig +from sae_lens.training.training_sae import ( + TrainingSAEConfig, + TrainingSAE, + TrainStepOutput, + ) +from sae_lens.toolkit.pretrained_sae_loaders import ( + handle_config_defaulting, + read_sae_from_disk, +) + +SPARSITY_PATH = "sparsity.safetensors" +SAE_WEIGHTS_PATH = "sae_weights.safetensors" +SAE_CFG_PATH = "cfg.json" + + +# TODO(mkbehr) will this multiple inheritance work? +@dataclass(kw_only=True) +class TrainingCrosscoderSAEConfig(CrosscoderSAEConfig, TrainingSAEConfig): + sparsity_penalty_decoder_norm_lp_norm: float = 1 + + # TODO(mkbehr): copypasting from TrainingSAEConfig and adding a few + # params. There should be a better way. 
+ @classmethod + def from_sae_runner_config( + cls, cfg: LanguageModelSAERunnerConfig + ) -> "TrainingSAEConfig": + return cls( + # base config + architecture=cfg.architecture, + d_in=cfg.d_in, + d_sae=cfg.d_sae, # type: ignore + dtype=cfg.dtype, + device=cfg.device, + model_name=cfg.model_name, + hook_name=cfg.hook_name, + hook_layer=cfg.hook_layer, + hook_layers=cfg.hook_layers, + hook_head_index=cfg.hook_head_index, + activation_fn_str=cfg.activation_fn, + activation_fn_kwargs=cfg.activation_fn_kwargs, + apply_b_dec_to_input=cfg.apply_b_dec_to_input, + finetuning_scaling_factor=cfg.finetuning_method is not None, + sae_lens_training_version=cfg.sae_lens_training_version, + context_size=cfg.context_size, + dataset_path=cfg.dataset_path, + prepend_bos=cfg.prepend_bos, + seqpos_slice=cfg.seqpos_slice, + # Training cfg + l1_coefficient=cfg.l1_coefficient, + lp_norm=cfg.lp_norm, + use_ghost_grads=cfg.use_ghost_grads, + normalize_sae_decoder=cfg.normalize_sae_decoder, + noise_scale=cfg.noise_scale, + decoder_orthogonal_init=cfg.decoder_orthogonal_init, + mse_loss_normalization=cfg.mse_loss_normalization, + decoder_heuristic_init=cfg.decoder_heuristic_init, + init_encoder_as_decoder_transpose=cfg.init_encoder_as_decoder_transpose, + scale_sparsity_penalty_by_decoder_norm=cfg.scale_sparsity_penalty_by_decoder_norm, + normalize_activations=cfg.normalize_activations, + dataset_trust_remote_code=cfg.dataset_trust_remote_code, + model_from_pretrained_kwargs=cfg.model_from_pretrained_kwargs or {}, + jumprelu_init_threshold=cfg.jumprelu_init_threshold, + jumprelu_bandwidth=cfg.jumprelu_bandwidth, + ) + + def to_dict(self) -> dict[str, Any]: + # TODO(mkbehr): double-check this multiple inheritance. seems messy. 
+ return (TrainingSAE.to_dict(self) + | CrosscoderSAE.to_dict(self) + | { + "sparsity_penalty_decoder_norm_lp_norm": + self.sparsity_penalty_decoder_norm_lp_norm, + }) + + def get_base_sae_cfg_dict(self) -> dict[str, Any]: + return (TrainingSAEConfig.get_base_sae_cfg_dict(self) + | { "hook_layers": self.hook_layers }) + +class TrainingCrosscoderSAE(CrosscoderSAE, TrainingSAE): + # TODO(mkbehr) future implementation + # initialize_weights_jumprelu (can maybe just use input shape in trainingsae) + # encode_with_hidden_pre_{gated,jumprelu} + # calculate_topk_aux_loss + # calculate_ghost_grad_loss + # fold_W_dec_norm for jumprelu + + def __init__(self, + cfg: TrainingCrosscoderSAEConfig, + use_error_term: bool = False): + print(cfg) + super().__init__(cfg, use_error_term=use_error_term) + + @classmethod + def from_dict(cls, config_dict: dict[str, Any]) -> "TrainingSAE": + return cls(TrainingCrosscoderSAEConfig.from_dict(config_dict)) + + # TODO(mkbehr): hacking around multiple inheritance. there's + # probably a better way. + @staticmethod + def base_sae_cfg(cfg: TrainingCrosscoderSAEConfig): + return CrosscoderSAEConfig.from_dict(cfg.get_base_sae_cfg_dict()) + + def check_cfg_compatibility(self): + if self.cfg.architecture != "standard": + raise NotImplementedError("TODO(mkbehr): support other archs") + if not self.cfg.scale_sparsity_penalty_by_decoder_norm: + raise ValueError("Crosscoders require scale_sparsity_penalty_by_decoder_norm") + if not self.use_error_term: + raise NotImplementedError("TODO(mkbehr): support causal crosscoders") + super().check_cfg_compatibility() + + def encode_with_hidden_pre( + self, x: Float[torch.Tensor, "... n_layers d_in"] + ) -> tuple[Float[torch.Tensor, "... d_sae"], Float[torch.Tensor, "... d_sae"]]: + sae_in = self.process_sae_in(x) + + hidden_pre = self.hook_sae_acts_pre( + einops.einsum( + sae_in, self.W_enc, + "... n_layers d_in, n_layers d_in d_sae -> ... 
d_sae" + ) + + self.b_enc) + hidden_pre_noised = hidden_pre + ( + torch.randn_like(hidden_pre) * self.cfg.noise_scale * self.training + ) + feature_acts = self.hook_sae_acts_post(self.activation_fn(hidden_pre_noised)) + + return feature_acts, hidden_pre_noised + + def training_forward_pass( + self, + sae_in: torch.Tensor, + current_l1_coefficient: float, + dead_neuron_mask: torch.Tensor | None = None, + ) -> TrainStepOutput: + # do a forward pass to get SAE out, but we also need the + # hidden pre. + feature_acts, hidden_pre = self.encode_with_hidden_pre_fn(sae_in) + sae_out = self.decode(feature_acts) + + # MSE LOSS + per_item_mse_loss = self.mse_loss_fn(sae_out, sae_in) + mse_loss = per_item_mse_loss.sum(dim=-1).mean() + + losses: dict[str, float | torch.Tensor] = {} + + assert self.cfg.scale_sparsity_penalty_by_decoder_norm + decoder_norms = self.W_dec.norm(dim=2) + feature_act_weights = decoder_norms.norm( + p=self.cfg.sparsity_penalty_decoder_norm_lp_norm, + dim=1 + ) + weighted_feature_acts = feature_acts * feature_act_weights + sparsity = weighted_feature_acts.norm( + p=self.cfg.lp_norm, dim=-1 + ) # sum over the feature dimension + + l1_loss = (current_l1_coefficient * sparsity).mean() + loss = mse_loss + l1_loss + if ( + self.cfg.use_ghost_grads + and self.training + and dead_neuron_mask is not None + ): + ghost_grad_loss = self.calculate_ghost_grad_loss( + x=sae_in, + sae_out=sae_out, + per_item_mse_loss=per_item_mse_loss, + hidden_pre=hidden_pre, + dead_neuron_mask=dead_neuron_mask, + ) + losses["ghost_grad_loss"] = ghost_grad_loss + loss = loss + ghost_grad_loss + losses["l1_loss"] = l1_loss + + losses["mse_loss"] = mse_loss + + return TrainStepOutput( + sae_in=sae_in, + sae_out=sae_out, + feature_acts=feature_acts, + hidden_pre=hidden_pre, + loss=loss, + losses=losses, + ) + + @classmethod + def load_from_pretrained( + cls, + path: str, + device: str = "cpu", + dtype: str | None = None, + ) -> "TrainingCrosscoderSAE": + # get the config + config_path = 
os.path.join(path, SAE_CFG_PATH) + with open(config_path) as f: + cfg_dict = json.load(f) + cfg_dict = handle_config_defaulting(cfg_dict) + cfg_dict["device"] = device + if dtype is not None: + cfg_dict["dtype"] = dtype + + weight_path = os.path.join(path, SAE_WEIGHTS_PATH) + cfg_dict, state_dict = read_sae_from_disk( + cfg_dict=cfg_dict, + weight_path=weight_path, + device=device, + ) + sae_cfg = TrainingCrosscoderSAEConfig.from_dict(cfg_dict) + + sae = cls(sae_cfg) + sae.process_state_dict_for_loading(state_dict) + sae.load_state_dict(state_dict) + + return sae + + @torch.no_grad() + def set_decoder_norm_to_unit_norm(self): + self.W_dec.data /= torch.norm(self.W_dec.data, dim=[1,2], keepdim=True) + + @torch.no_grad() + def initialize_decoder_norm_constant_norm(self, norm: float = 0.1): + """ + A heuristic procedure inspired by: + https://transformer-circuits.pub/2024/april-update/index.html#training-saes + """ + # TODO: Parameterise this as a function of m and n + + # ensure W_dec norms at unit norm + self.W_dec.data /= torch.norm(self.W_dec.data, dim=[1,2], keepdim=True) + self.W_dec.data *= norm # will break tests but do this for now. 
+ + @torch.no_grad() + def remove_gradient_parallel_to_decoder_directions(self): + """ + Update grads so that they remove the parallel component + (d_sae, n_layers, d_in) shape + """ + assert self.W_dec.grad is not None # keep pyright happy + + parallel_component = einops.einsum( + self.W_dec.grad, + self.W_dec.data, + "d_sae n_layers d_in, d_sae n_layers d_in -> d_sae", + ) + self.W_dec.grad -= einops.einsum( + parallel_component, + self.W_dec.data, + "d_sae, d_sae n_layers d_in -> d_sae n_layers d_in", + ) + diff --git a/sae_lens/training/training_sae.py b/sae_lens/training/training_sae.py index ba51ab843..dea54838d 100644 --- a/sae_lens/training/training_sae.py +++ b/sae_lens/training/training_sae.py @@ -244,8 +244,7 @@ class TrainingSAE(SAE): device: torch.device def __init__(self, cfg: TrainingSAEConfig, use_error_term: bool = False): - base_sae_cfg = SAEConfig.from_dict(cfg.get_base_sae_cfg_dict()) - super().__init__(base_sae_cfg) + super().__init__(self.base_sae_cfg(cfg), use_error_term=use_error_term) self.cfg = cfg # type: ignore if cfg.architecture == "standard" or cfg.architecture == "topk": @@ -291,6 +290,12 @@ def threshold(self) -> torch.Tensor: def from_dict(cls, config_dict: dict[str, Any]) -> "TrainingSAE": return cls(TrainingSAEConfig.from_dict(config_dict)) + # TODO(mkbehr): hacking around multiple inheritance. there's + # probably a better way. 
+ @staticmethod + def base_sae_cfg(cfg: TrainingSAEConfig): + return SAEConfig.from_dict(cfg.get_base_sae_cfg_dict()) + def check_cfg_compatibility(self): if self.cfg.architecture != "standard" and self.cfg.use_ghost_grads: raise ValueError(f"{self.cfg.architecture} SAEs do not support ghost grads") @@ -597,7 +602,7 @@ def initialize_weights_complex(self): elif self.cfg.decoder_heuristic_init: self.W_dec = nn.Parameter( torch.rand( - self.cfg.d_sae, self.cfg.d_in, dtype=self.dtype, device=self.device + self.cfg.d_sae, *self.input_shape(), dtype=self.dtype, device=self.device ) ) self.initialize_decoder_norm_constant_norm( @@ -611,7 +616,7 @@ def initialize_weights_complex(self): self.W_enc = nn.Parameter( torch.nn.init.kaiming_uniform_( torch.empty( - self.cfg.d_in, + *self.input_shape(), self.cfg.d_sae, dtype=self.dtype, device=self.device, diff --git a/tests/training/test_training_crosscoder_sae.py b/tests/training/test_training_crosscoder_sae.py new file mode 100644 index 000000000..4a30ad3f2 --- /dev/null +++ b/tests/training/test_training_crosscoder_sae.py @@ -0,0 +1,38 @@ +import pytest +import torch + +from sae_lens.crosscoder_sae import CrosscoderSAE +from sae_lens.training.training_crosscoder_sae import ( + TrainingCrosscoderSAE, + TrainingCrosscoderSAEConfig, +) +from tests.helpers import build_sae_cfg + +def test_TrainingCrosscoderSAE_training_forward_pass_can_scale_sparsity_penalty_by_decoder_norm(): + cfg = build_sae_cfg( + d_in=3, + d_sae=5, + hook_layers=[1,2,3,4], + normalize_sae_decoder=False, + scale_sparsity_penalty_by_decoder_norm=True, + ) + training_sae = TrainingCrosscoderSAE( + TrainingCrosscoderSAEConfig.from_sae_runner_config(cfg), + use_error_term=True, + ) + x = torch.randn(32, 4, 3) + train_step_output = training_sae.training_forward_pass( + sae_in=x, + current_l1_coefficient=2.0, + ) + feature_acts = train_step_output.feature_acts + decoder_norms = training_sae.W_dec.norm(dim=-1) + decoder_norm = decoder_norms.sum(dim=-1) + # 
double-check decoder norm is not all ones, or this test is pointless + assert not torch.allclose(decoder_norm, torch.ones_like(decoder_norm), atol=1e-2) + scaled_feature_acts = feature_acts * decoder_norm + + assert ( + pytest.approx(train_step_output.losses["l1_loss"].detach().item()) # type: ignore + == 2.0 * scaled_feature_acts.norm(p=1, dim=1).mean().detach().item() + ) From 92ae2bde0091f1e277583dfbe3811f4dce69c5a0 Mon Sep 17 00:00:00 2001 From: Michael Behr Date: Sun, 13 Apr 2025 15:16:44 -0400 Subject: [PATCH 08/61] test_TrainingCrosscoderSAE_encode_returns_same_value_as_encode_with_hidden_pre --- .../training/test_training_crosscoder_sae.py | 29 ++++++++++++++++--- 1 file changed, 25 insertions(+), 4 deletions(-) diff --git a/tests/training/test_training_crosscoder_sae.py b/tests/training/test_training_crosscoder_sae.py index 4a30ad3f2..890884957 100644 --- a/tests/training/test_training_crosscoder_sae.py +++ b/tests/training/test_training_crosscoder_sae.py @@ -8,13 +8,18 @@ ) from tests.helpers import build_sae_cfg +def build_crosscoder_sae_cfg(**kwargs): + return build_sae_cfg( + **(kwargs | { + "hook_layers": [1,2,3,4], + "normalize_sae_decoder": False, + "scale_sparsity_penalty_by_decoder_norm": True, + })) + def test_TrainingCrosscoderSAE_training_forward_pass_can_scale_sparsity_penalty_by_decoder_norm(): - cfg = build_sae_cfg( + cfg = build_crosscoder_sae_cfg( d_in=3, d_sae=5, - hook_layers=[1,2,3,4], - normalize_sae_decoder=False, - scale_sparsity_penalty_by_decoder_norm=True, ) training_sae = TrainingCrosscoderSAE( TrainingCrosscoderSAEConfig.from_sae_runner_config(cfg), @@ -36,3 +41,19 @@ def test_TrainingCrosscoderSAE_training_forward_pass_can_scale_sparsity_penalty_ pytest.approx(train_step_output.losses["l1_loss"].detach().item()) # type: ignore == 2.0 * scaled_feature_acts.norm(p=1, dim=1).mean().detach().item() ) + +@pytest.mark.parametrize("architecture", ["standard", "gated", "jumprelu", "topk"]) +def 
test_TrainingCrosscoderSAE_encode_returns_same_value_as_encode_with_hidden_pre( + architecture: str, +): + if architecture != "standard": + pytest.xfail("TODO(mkbehr): support other architectures") + cfg = build_crosscoder_sae_cfg(architecture=architecture) + sae = TrainingCrosscoderSAE( + TrainingCrosscoderSAEConfig.from_sae_runner_config(cfg), + use_error_term=True, + ) + x = torch.randn(32, len(cfg.hook_layers), cfg.d_in) + encode_out = sae.encode(x) + encode_with_hidden_pre_out = sae.encode_with_hidden_pre_fn(x)[0] + assert torch.allclose(encode_out, encode_with_hidden_pre_out) From 91190e531658981e76b5143e19553be71a30d107 Mon Sep 17 00:00:00 2001 From: Michael Behr Date: Sun, 13 Apr 2025 16:09:03 -0400 Subject: [PATCH 09/61] test_sae_forward --- sae_lens/training/training_crosscoder_sae.py | 1 - .../training/test_crosscoder_sae_training.py | 151 ++++++++++++++++++ 2 files changed, 151 insertions(+), 1 deletion(-) create mode 100644 tests/training/test_crosscoder_sae_training.py diff --git a/sae_lens/training/training_crosscoder_sae.py b/sae_lens/training/training_crosscoder_sae.py index ece8dfb7c..de3e93705 100644 --- a/sae_lens/training/training_crosscoder_sae.py +++ b/sae_lens/training/training_crosscoder_sae.py @@ -98,7 +98,6 @@ class TrainingCrosscoderSAE(CrosscoderSAE, TrainingSAE): def __init__(self, cfg: TrainingCrosscoderSAEConfig, use_error_term: bool = False): - print(cfg) super().__init__(cfg, use_error_term=use_error_term) @classmethod diff --git a/tests/training/test_crosscoder_sae_training.py b/tests/training/test_crosscoder_sae_training.py new file mode 100644 index 000000000..b7cc3b8af --- /dev/null +++ b/tests/training/test_crosscoder_sae_training.py @@ -0,0 +1,151 @@ +from typing import Any + +import einops +import pytest +import torch +from datasets import Dataset +from transformer_lens import HookedTransformer + +from sae_lens.config import LanguageModelSAERunnerConfig +from sae_lens.training.activations_store import ActivationsStore +from 
sae_lens.training.sae_trainer import SAETrainer +from sae_lens.training.training_crosscoder_sae import ( + TrainingCrosscoderSAE, + TrainingCrosscoderSAEConfig +) +from tests.helpers import build_sae_cfg + + +# Define a new fixture for different configurations +@pytest.fixture( + params=[ + { + "model_name": "tiny-stories-1M", + "dataset_path": "roneneldan/TinyStories", + "hook_name": "blocks.1.hook_resid_pre", + "hook_layers": [1,2,3], + "d_in": 64, + "normalize_sae_decoder": False, + "scale_sparsity_penalty_by_decoder_norm": True, + }, + { + "model_name": "tiny-stories-1M", + "dataset_path": "apollo-research/roneneldan-TinyStories-tokenizer-gpt2", + "hook_name": "blocks.1.hook_resid_pre", + "hook_layers": [1,2,3], + "d_in": 64, + "normalize_sae_decoder": False, + "scale_sparsity_penalty_by_decoder_norm": True, + }, + { + "model_name": "tiny-stories-1M", + "dataset_path": "apollo-research/roneneldan-TinyStories-tokenizer-gpt2", + "hook_name": "blocks.1.hook_resid_pre", + "hook_layers": [1,2,3], + "d_in": 64, + "normalize_activations": "constant_norm_rescale", + "normalize_sae_decoder": False, + "scale_sparsity_penalty_by_decoder_norm": True, + }, + ], + ids=[ + "tiny-stories-1M-resid-pre", + "tiny-stories-1M-resid-pre-pretokenized", + "tiny-stories-1M-resid-pre-pretokenized-norm-rescale", + ], +) +def cfg(request: pytest.FixtureRequest): + """ + Pytest fixture to create a mock instance of LanguageModelSAERunnerConfig. + """ + params = request.param + return build_sae_cfg(**params) + + +@pytest.fixture +def training_crosscoder_sae(cfg: LanguageModelSAERunnerConfig): + """ + Pytest fixture to create a mock instance of SparseAutoencoder. 
+ """ + return TrainingCrosscoderSAE( + TrainingCrosscoderSAEConfig.from_sae_runner_config(cfg), + use_error_term=True) + + +@pytest.fixture +def activation_store(model: HookedTransformer, cfg: LanguageModelSAERunnerConfig): + return ActivationsStore.from_config( + model, cfg, override_dataset=Dataset.from_list([{"text": "hello world"}] * 2000) + ) + + +@pytest.fixture +def model(cfg: LanguageModelSAERunnerConfig): + return HookedTransformer.from_pretrained(cfg.model_name, device="cpu") + + +# todo: remove the need for this fixture +@pytest.fixture +def trainer( + cfg: LanguageModelSAERunnerConfig, + training_crosscoder_sae: TrainingCrosscoderSAE, + model: HookedTransformer, + activation_store: ActivationsStore, +): + return SAETrainer( + model=model, + sae=training_crosscoder_sae, + activation_store=activation_store, + save_checkpoint_fn=lambda *args, **kwargs: None, # noqa: ARG005 + cfg=cfg, + ) + +def test_sae_forward(training_crosscoder_sae: TrainingCrosscoderSAE): + batch_size = 32 + d_in = training_crosscoder_sae.cfg.d_in + n_layers = len(training_crosscoder_sae.cfg.hook_layers) + d_sae = training_crosscoder_sae.cfg.d_sae + + x = torch.randn(batch_size, n_layers, d_in) + train_step_output = training_crosscoder_sae.training_forward_pass( + sae_in=x, + current_l1_coefficient=training_crosscoder_sae.cfg.l1_coefficient, + ) + + assert train_step_output.sae_out.shape == (batch_size, n_layers, d_in) + assert train_step_output.feature_acts.shape == (batch_size, d_sae) + assert ( + pytest.approx(train_step_output.loss.detach(), rel=1e-3) + == ( + train_step_output.losses["mse_loss"] + + train_step_output.losses["l1_loss"] + + train_step_output.losses.get("ghost_grad_loss", 0.0) + ) + .detach() # type: ignore + .cpu() + .numpy() + ) + + expected_mse_loss = ( + (torch.pow((train_step_output.sae_out - x.float()), 2)) + .sum(dim=-1) + .mean() + .detach() + .float() + ) + + assert ( + pytest.approx(train_step_output.losses["mse_loss"].item()) == expected_mse_loss # type: 
ignore + ) + + expected_l1_loss = ( + (train_step_output.feature_acts + * training_crosscoder_sae.W_dec.norm(dim=2).sum(dim=1)) + .norm(dim=1, p=1) + .mean() + ) + assert ( + pytest.approx(train_step_output.losses["l1_loss"].item(), rel=1e-3) # type: ignore + == training_crosscoder_sae.cfg.l1_coefficient * expected_l1_loss.detach().float() + ) + From a73e8a7d32b37c15b035fe23bb7e6ddfa902dc85 Mon Sep 17 00:00:00 2001 From: Michael Behr Date: Sun, 13 Apr 2025 16:11:18 -0400 Subject: [PATCH 10/61] test_sae_forward_with_mse_loss_norm --- .../training/test_crosscoder_sae_training.py | 61 +++++++++++++++++++ 1 file changed, 61 insertions(+) diff --git a/tests/training/test_crosscoder_sae_training.py b/tests/training/test_crosscoder_sae_training.py index b7cc3b8af..0d391146b 100644 --- a/tests/training/test_crosscoder_sae_training.py +++ b/tests/training/test_crosscoder_sae_training.py @@ -149,3 +149,64 @@ def test_sae_forward(training_crosscoder_sae: TrainingCrosscoderSAE): == training_crosscoder_sae.cfg.l1_coefficient * expected_l1_loss.detach().float() ) + +def test_sae_forward_with_mse_loss_norm( + training_crosscoder_sae: TrainingCrosscoderSAE, +): + # change the config and ensure the mse loss is calculated correctly + training_crosscoder_sae.cfg.mse_loss_normalization = "dense_batch" + training_crosscoder_sae.mse_loss_fn = training_crosscoder_sae._get_mse_loss_fn() + + batch_size = 32 + d_in = training_crosscoder_sae.cfg.d_in + n_layers = len(training_crosscoder_sae.cfg.hook_layers) + d_sae = training_crosscoder_sae.cfg.d_sae + + x = torch.randn(batch_size, n_layers, d_in) + train_step_output = training_crosscoder_sae.training_forward_pass( + sae_in=x, + current_l1_coefficient=training_crosscoder_sae.cfg.l1_coefficient, + ) + + assert train_step_output.sae_out.shape == (batch_size, n_layers, d_in) + assert train_step_output.feature_acts.shape == (batch_size, d_sae) + assert "ghost_grad_loss" not in train_step_output.losses + + x_centred = x - x.mean(dim=0, 
keepdim=True) + expected_mse_loss = ( + ( + torch.nn.functional.mse_loss(train_step_output.sae_out, x, reduction="none") + / (1e-6 + x_centred.norm(dim=-1, keepdim=True)) + ) + .sum(dim=-1) + .mean() + .detach() + .item() + ) + + assert ( + pytest.approx(train_step_output.losses["mse_loss"].item()) == expected_mse_loss # type: ignore + ) + + assert ( + pytest.approx(train_step_output.loss.detach(), rel=1e-3) + == ( + train_step_output.losses["mse_loss"] + + train_step_output.losses["l1_loss"] + + train_step_output.losses.get("ghost_grad_loss", 0.0) + ) + .detach() # type: ignore + .numpy() + ) + + expected_l1_loss = ( + (train_step_output.feature_acts * + training_crosscoder_sae.W_dec.norm(dim=2).sum(dim=1)) + .norm(dim=1, p=1) + .mean() + ) + assert ( + pytest.approx(train_step_output.losses["l1_loss"].item(), rel=1e-3) # type: ignore + == training_crosscoder_sae.cfg.l1_coefficient * expected_l1_loss.detach().float() + ) + From f7149f459f993d7b858649d52008dc132f39da1d Mon Sep 17 00:00:00 2001 From: Michael Behr Date: Sun, 13 Apr 2025 16:16:53 -0400 Subject: [PATCH 11/61] mark ghost grads unsupported --- sae_lens/training/training_crosscoder_sae.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/sae_lens/training/training_crosscoder_sae.py b/sae_lens/training/training_crosscoder_sae.py index de3e93705..f5c7ee5aa 100644 --- a/sae_lens/training/training_crosscoder_sae.py +++ b/sae_lens/training/training_crosscoder_sae.py @@ -117,6 +117,8 @@ def check_cfg_compatibility(self): raise ValueError("Crosscoders require scale_sparsity_penalty_by_decoder_norm") if not self.use_error_term: raise NotImplementedError("TODO(mkbehr): support causal crosscoders") + if self.cfg.use_ghost_grads: + raise NotImplementedError("TODO(mkbehr): support ghost grads") super().check_cfg_compatibility() def encode_with_hidden_pre( From 229192f94908510beb4df59e36ebff06ef75a98c Mon Sep 17 00:00:00 2001 From: Michael Behr Date: Sun, 13 Apr 2025 16:24:01 -0400 Subject: [PATCH 12/61] fix hook 
name in tests --- tests/training/test_crosscoder_sae_training.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/training/test_crosscoder_sae_training.py b/tests/training/test_crosscoder_sae_training.py index 0d391146b..163235ef8 100644 --- a/tests/training/test_crosscoder_sae_training.py +++ b/tests/training/test_crosscoder_sae_training.py @@ -22,7 +22,7 @@ { "model_name": "tiny-stories-1M", "dataset_path": "roneneldan/TinyStories", - "hook_name": "blocks.1.hook_resid_pre", + "hook_name": "blocks.{}.hook_resid_pre", "hook_layers": [1,2,3], "d_in": 64, "normalize_sae_decoder": False, @@ -31,7 +31,7 @@ { "model_name": "tiny-stories-1M", "dataset_path": "apollo-research/roneneldan-TinyStories-tokenizer-gpt2", - "hook_name": "blocks.1.hook_resid_pre", + "hook_name": "blocks.{}.hook_resid_pre", "hook_layers": [1,2,3], "d_in": 64, "normalize_sae_decoder": False, @@ -40,7 +40,7 @@ { "model_name": "tiny-stories-1M", "dataset_path": "apollo-research/roneneldan-TinyStories-tokenizer-gpt2", - "hook_name": "blocks.1.hook_resid_pre", + "hook_name": "blocks.{}.hook_resid_pre", "hook_layers": [1,2,3], "d_in": 64, "normalize_activations": "constant_norm_rescale", From 8d36da429e3e18192243306db40515c39826f451 Mon Sep 17 00:00:00 2001 From: Michael Behr Date: Sun, 13 Apr 2025 16:24:32 -0400 Subject: [PATCH 13/61] can_add_noise_to_hidden_pre test --- .../training/test_crosscoder_sae_training.py | 38 +++++++++++++++++++ 1 file changed, 38 insertions(+) diff --git a/tests/training/test_crosscoder_sae_training.py b/tests/training/test_crosscoder_sae_training.py index 163235ef8..2a5a92c12 100644 --- a/tests/training/test_crosscoder_sae_training.py +++ b/tests/training/test_crosscoder_sae_training.py @@ -210,3 +210,41 @@ def test_sae_forward_with_mse_loss_norm( == training_crosscoder_sae.cfg.l1_coefficient * expected_l1_loss.detach().float() ) + +def test_SparseAutoencoder_forward_can_add_noise_to_hidden_pre() -> None: + clean_cfg = build_sae_cfg( + d_in=2, 
+ d_sae=4, + noise_scale=0, + hook_layers=[1,2,3,4,5], + normalize_sae_decoder=False, + scale_sparsity_penalty_by_decoder_norm=True + ) + noisy_cfg = build_sae_cfg( + d_in=2, + d_sae=4, + noise_scale=100, + hook_layers=[1,2,3,4,5], + normalize_sae_decoder=False, + scale_sparsity_penalty_by_decoder_norm=True + ) + clean_sae = TrainingCrosscoderSAE( + TrainingCrosscoderSAEConfig.from_sae_runner_config(clean_cfg), + use_error_term=True) + noisy_sae = TrainingCrosscoderSAE( + TrainingCrosscoderSAEConfig.from_sae_runner_config(noisy_cfg), + use_error_term=True) + + input = torch.randn(3, 5, 2) + + clean_output1 = clean_sae.forward(input) + clean_output2 = clean_sae.forward(input) + noisy_output1 = noisy_sae.forward(input) + noisy_output2 = noisy_sae.forward(input) + + # with no noise, the outputs should be identical + assert torch.allclose(clean_output1, clean_output2) + # noisy outputs should be different + assert not torch.allclose(noisy_output1, noisy_output2) + assert not torch.allclose(clean_output1, noisy_output1) + From 5c694e03bc47650d02a7babb8bfaa9a8c8588047 Mon Sep 17 00:00:00 2001 From: Michael Behr Date: Sun, 13 Apr 2025 16:29:50 -0400 Subject: [PATCH 14/61] b_dec init note --- sae_lens/sae_training_runner.py | 1 + 1 file changed, 1 insertion(+) diff --git a/sae_lens/sae_training_runner.py b/sae_lens/sae_training_runner.py index c6a282b36..8f754a16e 100644 --- a/sae_lens/sae_training_runner.py +++ b/sae_lens/sae_training_runner.py @@ -160,6 +160,7 @@ def run_trainer_with_interruption_handling(self, trainer: SAETrainer): return sae # TODO: move this into the SAE trainer or Training SAE class + # TODO(mkbehr): support crosscoders. 
def _init_sae_group_b_decs( self, ) -> None: From e38de51add997f05c4942092b44e5cb890f853f1 Mon Sep 17 00:00:00 2001 From: Michael Behr Date: Sun, 13 Apr 2025 17:06:11 -0400 Subject: [PATCH 15/61] fix from_dict --- sae_lens/training/training_crosscoder_sae.py | 7 +++++-- sae_lens/training/training_sae.py | 2 +- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/sae_lens/training/training_crosscoder_sae.py b/sae_lens/training/training_crosscoder_sae.py index f5c7ee5aa..a9cf265a6 100644 --- a/sae_lens/training/training_crosscoder_sae.py +++ b/sae_lens/training/training_crosscoder_sae.py @@ -101,8 +101,11 @@ def __init__(self, super().__init__(cfg, use_error_term=use_error_term) @classmethod - def from_dict(cls, config_dict: dict[str, Any]) -> "TrainingSAE": - return cls(TrainingCrosscoderSAEConfig.from_dict(config_dict)) + def from_dict(cls, + config_dict: dict[str, Any], + use_error_term: bool = False) -> "TrainingSAE": + return cls(TrainingCrosscoderSAEConfig.from_dict(config_dict), + use_error_term = use_error_term) # TODO(mkbehr): hacking around multiple inheritance. there's # probably a better way. 
diff --git a/sae_lens/training/training_sae.py b/sae_lens/training/training_sae.py index dea54838d..bdb15bb54 100644 --- a/sae_lens/training/training_sae.py +++ b/sae_lens/training/training_sae.py @@ -185,7 +185,7 @@ def from_dict(cls, config_dict: dict[str, Any]) -> "TrainingSAEConfig": elif not isinstance(valid_config_dict["seqpos_slice"], tuple): valid_config_dict["seqpos_slice"] = (valid_config_dict["seqpos_slice"],) - return TrainingSAEConfig(**valid_config_dict) + return cls(**valid_config_dict) def to_dict(self) -> dict[str, Any]: return { From b2ddc70ab5bf3d0f7554d0856c50cd76d7cc8c84 Mon Sep 17 00:00:00 2001 From: Michael Behr Date: Sun, 13 Apr 2025 17:06:53 -0400 Subject: [PATCH 16/61] CrosscoderSAETrainer implementation, one test --- sae_lens/training/crosscoder_sae_trainer.py | 150 ++++++++++++++++++ tests/training/test_crosscoder_sae_trainer.py | 103 ++++++++++++ 2 files changed, 253 insertions(+) create mode 100644 sae_lens/training/crosscoder_sae_trainer.py create mode 100644 tests/training/test_crosscoder_sae_trainer.py diff --git a/sae_lens/training/crosscoder_sae_trainer.py b/sae_lens/training/crosscoder_sae_trainer.py new file mode 100644 index 000000000..2418bde99 --- /dev/null +++ b/sae_lens/training/crosscoder_sae_trainer.py @@ -0,0 +1,150 @@ +from typing import Any + +import torch +import wandb +from tqdm import tqdm + +from sae_lens.evals import run_evals +from sae_lens.training.sae_trainer import SAETrainer +from sae_lens.training.training_sae import TrainingSAE, TrainStepOutput + +# TODO(mkbehr): probably too much copypasting here + +class CrosscoderSAETrainer(SAETrainer): + def fit(self) -> TrainingSAE: + pbar = tqdm(total=self.cfg.total_training_tokens, desc="Training SAE") + + self.activations_store.set_norm_scaling_factor_if_needed() + + # Train loop + while self.n_training_tokens < self.cfg.total_training_tokens: + # Do a training step. 
+ layer_acts = self.activations_store.next_batch().to( + self.sae.device + ) + self.n_training_tokens += self.cfg.train_batch_size_tokens + + step_output = self._train_step(sae=self.sae, sae_in=layer_acts) + + if self.cfg.log_to_wandb: + self._log_train_step(step_output) + self._run_and_log_evals() + + self._checkpoint_if_needed() + self.n_training_steps += 1 + self._update_pbar(step_output, pbar) + + ### If n_training_tokens > sae_group.cfg.training_tokens, then we should switch to fine-tuning (if we haven't already) + self._begin_finetuning_if_needed() + + # fold the estimated norm scaling factor into the sae weights + if self.activations_store.estimated_norm_scaling_factor is not None: + self.sae.fold_activation_norm_scaling_factor( + self.activations_store.estimated_norm_scaling_factor + ) + self.activations_store.estimated_norm_scaling_factor = None + + # save final sae group to checkpoints folder + self.save_checkpoint( + trainer=self, + checkpoint_name=f"final_{self.n_training_tokens}", + wandb_aliases=["final_model"], + ) + + pbar.close() + return self.sae + + @torch.no_grad() + def _build_train_step_log_dict( + self, + output: TrainStepOutput, + n_training_tokens: int, + ) -> dict[str, Any]: + sae_in = output.sae_in + sae_out = output.sae_out + feature_acts = output.feature_acts + loss = output.loss.item() + + # metrics for currents acts + l0 = (feature_acts > 0).float().sum(-1).mean() + current_learning_rate = self.optimizer.param_groups[0]["lr"] + + per_token_l2_loss = (sae_out - sae_in).pow(2).sum(dim=(-2, -1)).squeeze() + total_variance = (sae_in - sae_in.mean(0)).pow(2).sum((-2, -1)) + explained_variance = 1 - per_token_l2_loss / total_variance + + log_dict = { + # losses + "losses/overall_loss": loss, + # variance explained + "metrics/explained_variance": explained_variance.mean().item(), + "metrics/explained_variance_std": explained_variance.std().item(), + "metrics/l0": l0.item(), + # sparsity + "sparsity/mean_passes_since_fired": 
self.n_forward_passes_since_fired.mean().item(), + "sparsity/dead_features": self.dead_neurons.sum().item(), + "details/current_learning_rate": current_learning_rate, + "details/current_l1_coefficient": self.current_l1_coefficient, + "details/n_training_tokens": n_training_tokens, + } + for loss_name, loss_value in output.losses.items(): + loss_item = _unwrap_item(loss_value) + # special case for l1 loss, which we normalize by the l1 coefficient + if loss_name == "l1_loss": + log_dict[f"losses/{loss_name}"] = ( + loss_item / self.current_l1_coefficient + ) + log_dict[f"losses/raw_{loss_name}"] = loss_item + else: + log_dict[f"losses/{loss_name}"] = loss_item + + return log_dict + + @torch.no_grad() + def _run_and_log_evals(self): + # record loss frequently, but not all the time. + if (self.n_training_steps + 1) % ( + self.cfg.wandb_log_frequency * self.cfg.eval_every_n_wandb_logs + ) == 0: + self.sae.eval() + ignore_tokens = set() + if self.activations_store.exclude_special_tokens is not None: + ignore_tokens = set( + self.activations_store.exclude_special_tokens.tolist() + ) + eval_metrics, _ = run_evals( + sae=self.sae, + activation_store=self.activations_store, + model=self.model, + eval_config=self.trainer_eval_config, + ignore_tokens=ignore_tokens, + model_kwargs=self.cfg.model_kwargs, + ) # not calculating featurwise metrics here. 
+ + # Remove eval metrics that are already logged during training + eval_metrics.pop("metrics/explained_variance", None) + eval_metrics.pop("metrics/explained_variance_std", None) + eval_metrics.pop("metrics/l0", None) + eval_metrics.pop("metrics/l1", None) + eval_metrics.pop("metrics/mse", None) + + # Remove metrics that are not useful for wandb logging + eval_metrics.pop("metrics/total_tokens_evaluated", None) + + W_dec_norm_dist = self.sae.W_dec.detach().float().norm(dim=(1,2)).cpu().numpy() + eval_metrics["weights/W_dec_norms"] = wandb.Histogram(W_dec_norm_dist) # type: ignore + + if self.sae.cfg.architecture == "standard": + b_e_dist = self.sae.b_enc.detach().float().cpu().numpy() + eval_metrics["weights/b_e"] = wandb.Histogram(b_e_dist) # type: ignore + elif self.sae.cfg.architecture == "gated": + b_gate_dist = self.sae.b_gate.detach().float().cpu().numpy() + eval_metrics["weights/b_gate"] = wandb.Histogram(b_gate_dist) # type: ignore + b_mag_dist = self.sae.b_mag.detach().float().cpu().numpy() + eval_metrics["weights/b_mag"] = wandb.Histogram(b_mag_dist) # type: ignore + + wandb.log( + eval_metrics, + step=self.n_training_steps, + ) + self.sae.train() diff --git a/tests/training/test_crosscoder_sae_trainer.py b/tests/training/test_crosscoder_sae_trainer.py new file mode 100644 index 000000000..86e10ee41 --- /dev/null +++ b/tests/training/test_crosscoder_sae_trainer.py @@ -0,0 +1,103 @@ +from pathlib import Path +from typing import Any, Callable + +import pytest +import torch +from datasets import Dataset +from safetensors.torch import load_file +from transformer_lens import HookedTransformer + +from sae_lens import __version__ +from sae_lens.config import LanguageModelSAERunnerConfig +from sae_lens.sae_training_runner import SAETrainingRunner +from sae_lens.training.activations_store import ActivationsStore +from sae_lens.training.crosscoder_sae_trainer import CrosscoderSAETrainer +from sae_lens.training.sae_trainer import ( + TrainStepOutput, + 
_log_feature_sparsity, + _update_sae_lens_training_version, +) +from sae_lens.training.training_crosscoder_sae import TrainingCrosscoderSAE +from tests.helpers import TINYSTORIES_MODEL, build_sae_cfg, load_model_cached + + +@pytest.fixture +def cfg(): + return build_sae_cfg( + d_in=64, + d_sae=128, + hook_name="blocks.{}.hook_mlp_out", + hook_layers=[1,2,3], + normalize_sae_decoder=False, + scale_sparsity_penalty_by_decoder_norm=True, + ) + + +@pytest.fixture +def model(): + return load_model_cached(TINYSTORIES_MODEL) + + +@pytest.fixture +def activation_store(model: HookedTransformer, cfg: LanguageModelSAERunnerConfig): + return ActivationsStore.from_config( + model, cfg, override_dataset=Dataset.from_list([{"text": "hello world"}] * 2000) + ) + + +@pytest.fixture +def training_sae(cfg: LanguageModelSAERunnerConfig): + return TrainingCrosscoderSAE.from_dict(cfg.get_training_sae_cfg_dict(), + use_error_term=True) + + +@pytest.fixture +def trainer( + cfg: LanguageModelSAERunnerConfig, + training_sae: TrainingCrosscoderSAE, + model: HookedTransformer, + activation_store: ActivationsStore, +): + return CrosscoderSAETrainer( + model=model, + sae=training_sae, + activation_store=activation_store, + save_checkpoint_fn=lambda *args, **kwargs: None, # noqa: ARG005 + cfg=cfg, + ) + + +def modify_sae_output(sae: TrainingCrosscoderSAE, modifier: Callable[[torch.Tensor], Any]): + """ + Helper to modify the output of the SAE forward pass for use in patching, for use in patch side_effect. + We need real grads during training, so we can't just mock the whole forward pass directly. 
+ """ + + def modified_forward(*args: Any, **kwargs: Any) -> torch.Tensor: + output = TrainingCrosscoderSAE.forward(sae, *args, **kwargs) + return modifier(output) + + return modified_forward + + +def test_train_step__reduces_loss_when_called_repeatedly_on_same_acts( + trainer: CrosscoderSAETrainer, +) -> None: + layer_acts = trainer.activations_store.next_batch() + + # intentionally train on the same activations 5 times to ensure loss decreases + train_outputs = [ + trainer._train_step( + sae=trainer.sae, + sae_in=layer_acts, + ) + for _ in range(5) + ] + + # ensure loss decreases with each training step + for output, next_output in zip(train_outputs[:-1], train_outputs[1:]): + assert output.loss > next_output.loss + assert ( + trainer.n_frac_active_tokens == 20 + ) # should increment each step by batch_size (5*4) + From 529109e96b3b0120acaffb7027f54c9ec2662084 Mon Sep 17 00:00:00 2001 From: Michael Behr Date: Sun, 13 Apr 2025 17:09:27 -0400 Subject: [PATCH 17/61] two more CrosscoderSAETrainer tests --- tests/training/test_crosscoder_sae_trainer.py | 50 +++++++++++++++++++ 1 file changed, 50 insertions(+) diff --git a/tests/training/test_crosscoder_sae_trainer.py b/tests/training/test_crosscoder_sae_trainer.py index 86e10ee41..4a220ac6a 100644 --- a/tests/training/test_crosscoder_sae_trainer.py +++ b/tests/training/test_crosscoder_sae_trainer.py @@ -101,3 +101,53 @@ def test_train_step__reduces_loss_when_called_repeatedly_on_same_acts( trainer.n_frac_active_tokens == 20 ) # should increment each step by batch_size (5*4) + +def test_train_step__output_looks_reasonable(trainer: CrosscoderSAETrainer) -> None: + layer_acts = trainer.activations_store.next_batch() + + output = trainer._train_step( + sae=trainer.sae, + sae_in=layer_acts, + ) + + assert output.loss > 0 + # only hook_point_layer=0 acts should be passed to the SAE + assert torch.allclose(output.sae_in, layer_acts) + assert output.sae_out.shape == output.sae_in.shape + assert output.feature_acts.shape == 
(4, 128) # batch_size, d_sae + # ghost grads shouldn't trigger until dead_feature_window, which hasn't been reached yet + assert output.losses.get("ghost_grad_loss", 0) == 0 + assert trainer.n_frac_active_tokens == 4 + assert trainer.act_freq_scores.sum() > 0 # at least SOME acts should have fired + assert torch.allclose( + trainer.act_freq_scores, (output.feature_acts.abs() > 0).float().sum(0) + ) + + +def test_train_step__sparsity_updates_based_on_feature_act_sparsity( + trainer: CrosscoderSAETrainer, +) -> None: + trainer._reset_running_sparsity_stats() + layer_acts = trainer.activations_store.next_batch() + + train_output = trainer._train_step( + sae=trainer.sae, + sae_in=layer_acts, + ) + feature_acts = train_output.feature_acts + + # should increase by batch_size + assert trainer.n_frac_active_tokens == 4 + # add freq scores for all non-zero feature acts + assert torch.allclose( + trainer.act_freq_scores, (feature_acts > 0).float().sum(dim=0) + ) + + # check that features that just fired have n_forward_passes_since_fired = 0 + assert ( + trainer.n_forward_passes_since_fired[ + ((feature_acts > 0).float()[-1] == 1) + ].max() + == 0 + ) + assert train_output.feature_acts is feature_acts From 3935dbc86428e97e97574b43ae73db5689506df7 Mon Sep 17 00:00:00 2001 From: Michael Behr Date: Sun, 13 Apr 2025 17:26:44 -0400 Subject: [PATCH 18/61] test log dict --- sae_lens/training/crosscoder_sae_trainer.py | 2 +- tests/training/test_crosscoder_sae_trainer.py | 49 +++++++++++++++++++ 2 files changed, 50 insertions(+), 1 deletion(-) diff --git a/sae_lens/training/crosscoder_sae_trainer.py b/sae_lens/training/crosscoder_sae_trainer.py index 2418bde99..faef709a7 100644 --- a/sae_lens/training/crosscoder_sae_trainer.py +++ b/sae_lens/training/crosscoder_sae_trainer.py @@ -5,7 +5,7 @@ from tqdm import tqdm from sae_lens.evals import run_evals -from sae_lens.training.sae_trainer import SAETrainer +from sae_lens.training.sae_trainer import SAETrainer, _unwrap_item from 
sae_lens.training.training_sae import TrainingSAE, TrainStepOutput # TODO(mkbehr): probably too much copypasting here diff --git a/tests/training/test_crosscoder_sae_trainer.py b/tests/training/test_crosscoder_sae_trainer.py index 4a220ac6a..0628baa74 100644 --- a/tests/training/test_crosscoder_sae_trainer.py +++ b/tests/training/test_crosscoder_sae_trainer.py @@ -151,3 +151,52 @@ def test_train_step__sparsity_updates_based_on_feature_act_sparsity( == 0 ) assert train_output.feature_acts is feature_acts + +def test_build_train_step_log_dict(trainer: CrosscoderSAETrainer) -> None: + sae_in = torch.tensor([[[-1, 0], [-2, 0]], + [[0, 2], [0, 3]], + [[1, 1], [1, 1]]]).float() + sae_out = torch.tensor([[[0, 0], [0, 0]], + [[0, 2], [0, 3]], + [[0.5, 1], [1, 0.5]]]).float() + train_output = TrainStepOutput( + sae_in=sae_in, + sae_out=sae_out, + feature_acts=torch.tensor([[0, 0, 0, 1], [1, 0, 0, 1], [1, 0, 1, 1]]).float(), + hidden_pre=torch.tensor([[-1, 0, 0, 1], [1, -1, 0, 1], [1, -1, 1, 1]]).float(), + loss=torch.tensor(0.5), + losses={ + "mse_loss": 0.25, + "l1_loss": 0.1, + "ghost_grad_loss": 0.15, + }, + ) + + per_token_l2_loss = (sae_out - sae_in).pow(2).sum(dim=(-2, -1)).squeeze() + total_variance = (sae_in - sae_in.mean(0)).pow(2).sum((-2, -1)) + explained_variance = 1 - per_token_l2_loss / total_variance + + # we're relying on the trainer only for some of the metrics here + # we should more / less try to break this and push + # everything through the train step output if we can. 
+ log_dict = trainer._build_train_step_log_dict( + output=train_output, n_training_tokens=123 + ) + for key, val in { + "losses/mse_loss": 0.25, + # l1 loss is scaled by l1_coefficient + "losses/l1_loss": train_output.losses["l1_loss"] / trainer.cfg.l1_coefficient, + "losses/raw_l1_loss": train_output.losses["l1_loss"], + "losses/overall_loss": 0.5, + "losses/ghost_grad_loss": 0.15, + "metrics/explained_variance": explained_variance.mean().item(), + "metrics/explained_variance_std": explained_variance.std().item(), + "metrics/l0": 2.0, + "sparsity/mean_passes_since_fired": trainer.n_forward_passes_since_fired.mean().item(), + "sparsity/dead_features": trainer.dead_neurons.sum().item(), + "details/current_learning_rate": 2e-4, + "details/current_l1_coefficient": trainer.cfg.l1_coefficient, + "details/n_training_tokens": 123, + }.items(): + assert abs(val - log_dict[key]) < 1e-6 + From 588de21b32596baef67e3f269c53bff4027b1d59 Mon Sep 17 00:00:00 2001 From: Michael Behr Date: Sun, 13 Apr 2025 17:28:48 -0400 Subject: [PATCH 19/61] test_train_sae_group_on_language_model__runs --- tests/training/test_crosscoder_sae_trainer.py | 31 +++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/tests/training/test_crosscoder_sae_trainer.py b/tests/training/test_crosscoder_sae_trainer.py index 0628baa74..376765187 100644 --- a/tests/training/test_crosscoder_sae_trainer.py +++ b/tests/training/test_crosscoder_sae_trainer.py @@ -200,3 +200,34 @@ def test_build_train_step_log_dict(trainer: CrosscoderSAETrainer) -> None: }.items(): assert abs(val - log_dict[key]) < 1e-6 + +def test_train_sae_group_on_language_model__runs( + ts_model: HookedTransformer, + tmp_path: Path, +) -> None: + checkpoint_dir = tmp_path / "checkpoint" + cfg = build_sae_cfg( + checkpoint_path=str(checkpoint_dir), + training_tokens=20, + context_size=8, + hook_name="blocks.{}.hook_mlp_out", + hook_layers=[1,2,3], + normalize_sae_decoder=False, + scale_sparsity_penalty_by_decoder_norm=True, + ) + # just a 
tiny dataset which will run quickly + dataset = Dataset.from_list([{"text": "hello world"}] * 100) + activation_store = ActivationsStore.from_config( + ts_model, cfg, override_dataset=dataset + ) + sae = TrainingCrosscoderSAE.from_dict(cfg.get_training_sae_cfg_dict(), + use_error_term=True) + sae = CrosscoderSAETrainer( + model=ts_model, + sae=sae, + activation_store=activation_store, + save_checkpoint_fn=lambda *args, **kwargs: None, # noqa: ARG005 + cfg=cfg, + ).fit() + + assert isinstance(sae, TrainingCrosscoderSAE) From 2bc00bb527ba1c7e5402b861b6aaa6bc66a80597 Mon Sep 17 00:00:00 2001 From: Michael Behr Date: Mon, 14 Apr 2025 11:37:19 -0400 Subject: [PATCH 20/61] fix TrainingCrosscoderSAEConfig.to_dict --- sae_lens/training/training_crosscoder_sae.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sae_lens/training/training_crosscoder_sae.py b/sae_lens/training/training_crosscoder_sae.py index a9cf265a6..d1322c3ee 100644 --- a/sae_lens/training/training_crosscoder_sae.py +++ b/sae_lens/training/training_crosscoder_sae.py @@ -76,8 +76,8 @@ def from_sae_runner_config( def to_dict(self) -> dict[str, Any]: # TODO(mkbehr): double-check this multiple inheritance. seems messy. 
- return (TrainingSAE.to_dict(self) - | CrosscoderSAE.to_dict(self) + return (TrainingSAEConfig.to_dict(self) + | CrosscoderSAEConfig.to_dict(self) | { "sparsity_penalty_decoder_norm_lp_norm": self.sparsity_penalty_decoder_norm_lp_norm, From 7d54ea4b2f46edf2899de884aadb8395a9344f08 Mon Sep 17 00:00:00 2001 From: Michael Behr Date: Thu, 17 Apr 2025 11:04:01 -0400 Subject: [PATCH 21/61] quick name fixes to satisfy wandb --- sae_lens/crosscoder_sae.py | 2 +- sae_lens/sae_training_runner.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/sae_lens/crosscoder_sae.py b/sae_lens/crosscoder_sae.py index 19e84e2f3..cc06035ac 100644 --- a/sae_lens/crosscoder_sae.py +++ b/sae_lens/crosscoder_sae.py @@ -47,7 +47,7 @@ def __init__( def get_name(self): # TODO(mkbehr): think about the correct name - layers = ','.join([str(l) for l in self.cfg.hook_layers]) + layers = '_'.join([str(l) for l in self.cfg.hook_layers]) return f"sae_{self.cfg.model_name}_{self.cfg.hook_name}_layers{layers}_{self.cfg.d_sae}" @classmethod diff --git a/sae_lens/sae_training_runner.py b/sae_lens/sae_training_runner.py index 8f754a16e..805014b7e 100644 --- a/sae_lens/sae_training_runner.py +++ b/sae_lens/sae_training_runner.py @@ -212,7 +212,8 @@ def save_checkpoint( if trainer.cfg.log_to_wandb: # Avoid wandb saving errors such as: # ValueError: Artifact name may only contain alphanumeric characters, dashes, underscores, and dots. 
Invalid name: sae_google/gemma-2b_etc - sae_name = trainer.sae.get_name().replace("/", "__") + # TODO(mkbehr) name better + sae_name = trainer.sae.get_name().replace("/", "__").replace("{}", "__") # save model weights and cfg model_artifact = wandb.Artifact( From be7f780db9d30f9309e1457cf4d0884547a857fc Mon Sep 17 00:00:00 2001 From: Michael Behr Date: Thu, 17 Apr 2025 11:04:45 -0400 Subject: [PATCH 22/61] use crosscoder from training runner --- sae_lens/sae_training_runner.py | 26 +++++++++++++++++++------- 1 file changed, 19 insertions(+), 7 deletions(-) diff --git a/sae_lens/sae_training_runner.py b/sae_lens/sae_training_runner.py index 805014b7e..205e672c7 100644 --- a/sae_lens/sae_training_runner.py +++ b/sae_lens/sae_training_runner.py @@ -14,6 +14,8 @@ from sae_lens.config import HfDataset, LanguageModelSAERunnerConfig from sae_lens.load_model import load_model from sae_lens.training.activations_store import ActivationsStore +from sae_lens.training.crosscoder_sae_trainer import CrosscoderSAETrainer +from sae_lens.training.training_crosscoder_sae import TrainingCrosscoderSAE from sae_lens.training.geometric_median import compute_geometric_median from sae_lens.training.sae_trainer import SAETrainer from sae_lens.training.training_sae import TrainingSAE, TrainingSAEConfig @@ -100,13 +102,23 @@ def run(self): id=self.cfg.wandb_id, ) - trainer = SAETrainer( - model=self.model, - sae=self.sae, - activation_store=self.activations_store, - save_checkpoint_fn=self.save_checkpoint, - cfg=self.cfg, - ) + # TODO(mkbehr): make a better way to get the right trainer in + if isinstance(self.sae, TrainingCrosscoderSAE): + trainer = CrosscoderSAETrainer( + model=self.model, + sae=self.sae, + activation_store=self.activations_store, + save_checkpoint_fn=self.save_checkpoint, + cfg=self.cfg, + ) + else: + trainer = SAETrainer( + model=self.model, + sae=self.sae, + activation_store=self.activations_store, + save_checkpoint_fn=self.save_checkpoint, + cfg=self.cfg, + ) 
self._compile_if_needed() sae = self.run_trainer_with_interruption_handling(trainer) From 0e6acdc13f1f96b64577517bffb10c010826ae66 Mon Sep 17 00:00:00 2001 From: Michael Behr Date: Thu, 17 Apr 2025 11:05:25 -0400 Subject: [PATCH 23/61] initialize W_dec in TrainingCrosscoderSAE --- sae_lens/training/training_crosscoder_sae.py | 35 ++++++++++++++++++++ 1 file changed, 35 insertions(+) diff --git a/sae_lens/training/training_crosscoder_sae.py b/sae_lens/training/training_crosscoder_sae.py index d1322c3ee..26dc3196b 100644 --- a/sae_lens/training/training_crosscoder_sae.py +++ b/sae_lens/training/training_crosscoder_sae.py @@ -6,6 +6,7 @@ import einops import torch from jaxtyping import Float +from torch import nn from sae_lens.config import LanguageModelSAERunnerConfig from sae_lens.crosscoder_sae import CrosscoderSAE, CrosscoderSAEConfig @@ -229,6 +230,40 @@ def load_from_pretrained( return sae + def initialize_weights_complex(self): + if self.cfg.decoder_orthogonal_init: + self.W_dec.data = nn.init.orthogonal_( + self.W_dec.data.permute((1,2,0)) + ).permute((2,0,1)) + + elif self.cfg.decoder_heuristic_init: + self.W_dec = nn.Parameter( + torch.rand( + self.cfg.d_sae, *self.input_shape(), dtype=self.dtype, device=self.device + ) + ) + self.initialize_decoder_norm_constant_norm() + + # Then we initialize the encoder weights (either as the transpose of decoder or not) + if self.cfg.init_encoder_as_decoder_transpose: + self.W_enc.data = self.W_dec.data.permute((1,2,0)).clone().contiguous() + else: + self.W_enc = nn.Parameter( + torch.nn.init.kaiming_uniform_( + torch.empty( + *self.input_shape(), + self.cfg.d_sae, + dtype=self.dtype, + device=self.device, + ) + ) + ) + + if self.cfg.normalize_sae_decoder: + with torch.no_grad(): + # Anthropic normalize this to have unit columns + self.set_decoder_norm_to_unit_norm() + @torch.no_grad() def set_decoder_norm_to_unit_norm(self): self.W_dec.data /= torch.norm(self.W_dec.data, dim=[1,2], keepdim=True) From 
9cf93ea08e2bc8c5225e1a4c49480b23dc7f687c Mon Sep 17 00:00:00 2001 From: Michael Behr Date: Thu, 17 Apr 2025 11:05:48 -0400 Subject: [PATCH 24/61] temporarily hardcode evals off --- sae_lens/training/crosscoder_sae_trainer.py | 29 +++++++++++++++------ 1 file changed, 21 insertions(+), 8 deletions(-) diff --git a/sae_lens/training/crosscoder_sae_trainer.py b/sae_lens/training/crosscoder_sae_trainer.py index faef709a7..38d6abb40 100644 --- a/sae_lens/training/crosscoder_sae_trainer.py +++ b/sae_lens/training/crosscoder_sae_trainer.py @@ -11,6 +11,17 @@ # TODO(mkbehr): probably too much copypasting here class CrosscoderSAETrainer(SAETrainer): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + # TODO(mkbehr) hardcoding causal evals off for now + self.trainer_eval_config.compute_ce_loss=False + self.trainer_eval_config.compute_kl=False + # TODO(mkbehr) hardcoding l2/sparsity/variance off, since + # those evals don't work yet + self.trainer_eval_config.compute_l2_norms=False + self.trainer_eval_config.compute_sparsity_metrics=False + self.trainer_eval_config.compute_variance_metrics=False + def fit(self) -> TrainingSAE: pbar = tqdm(total=self.cfg.total_training_tokens, desc="Training SAE") @@ -112,14 +123,16 @@ def _run_and_log_evals(self): ignore_tokens = set( self.activations_store.exclude_special_tokens.tolist() ) - eval_metrics, _ = run_evals( - sae=self.sae, - activation_store=self.activations_store, - model=self.model, - eval_config=self.trainer_eval_config, - ignore_tokens=ignore_tokens, - model_kwargs=self.cfg.model_kwargs, - ) # not calculating featurwise metrics here. + # TODO(mkbehr): get some evals working + eval_metrics = {} + # eval_metrics, _ = run_evals( + # sae=self.sae, + # activation_store=self.activations_store, + # model=self.model, + # eval_config=self.trainer_eval_config, + # ignore_tokens=ignore_tokens, + # model_kwargs=self.cfg.model_kwargs, + # ) # not calculating featurwise metrics here. 
# Remove eval metrics that are already logged during training eval_metrics.pop("metrics/explained_variance", None) From c1cfde5f2d5ec1d012d6652e01b3cec8b0e36435 Mon Sep 17 00:00:00 2001 From: Michael Behr Date: Thu, 17 Apr 2025 11:05:56 -0400 Subject: [PATCH 25/61] training script --- scripts/global_acausal_crosscoder.py | 134 +++++++++++++++++++++++++++ 1 file changed, 134 insertions(+) create mode 100644 scripts/global_acausal_crosscoder.py diff --git a/scripts/global_acausal_crosscoder.py b/scripts/global_acausal_crosscoder.py new file mode 100644 index 000000000..9e8c8c67c --- /dev/null +++ b/scripts/global_acausal_crosscoder.py @@ -0,0 +1,134 @@ +# TODO(mkbehr): don't really commit this + +import os +import sys + +import torch + +sys.path.append("..") + +from sae_lens.config import LanguageModelSAERunnerConfig +from sae_lens.training.training_crosscoder_sae import ( + TrainingCrosscoderSAE, + TrainingCrosscoderSAEConfig +) +from sae_lens.sae_training_runner import SAETrainingRunner + +if torch.cuda.is_available(): + device = "cuda" +elif torch.backends.mps.is_available(): + device = "mps" +else: + device = "cpu" + +print("Using device:", device) +os.environ["TOKENIZERS_PARALLELISM"] = "false" + +# total_training_steps = 200_000 +total_training_steps = 50_000 +# total_training_steps = 1000 +# batch_size = 4096 +# batch_size = 256 +total_training_tokens = total_training_steps * 256 +print(f"Total Training Tokens: {total_training_tokens}") +# l1_coefficient = 1.0 +l1_coefficient = 1e-6 # DEBUG: if I mostly zero out the l1 loss, will it learn? + +# change these configs +model_name = "tiny-stories-2L-33M" +dataset_path = "apollo-research/roneneldan-TinyStories-tokenizer-gpt2" +new_cached_activations_path = ( + f"./cached_activations/{model_name}/{dataset_path}/{total_training_steps}" +) + +lr_warm_up_steps = total_training_steps // 20 +print(f"lr_warm_up_steps: {lr_warm_up_steps}") +lr_decay_steps = total_training_steps // 5 # 20% of training steps. 
+print(f"lr_decay_steps: {lr_decay_steps}") +l1_warmup_steps = total_training_steps // 10 +print(f"l1_warmup_steps: {l1_warmup_steps}") +log_to_wandb = True +# log_to_wandb = False + +cfg = LanguageModelSAERunnerConfig( + # Pick a tiny model to make this easier. + model_name=model_name, + hook_name="blocks.{}.hook_mlp_out", + hook_layers=[0,1], + d_in=1024, + dataset_path=dataset_path, + streaming=True, + context_size=512, + is_dataset_tokenized=True, + prepend_bos=True, + # How big do we want our SAE to be? + expansion_factor=64, + # Dataset / Activation Store + # When we do a proper test + # training_tokens= 820_000_000, # 200k steps * 4096 batch size ~ 820M tokens (doable overnight on an A100) + # For now. + use_cached_activations=False, + # cached_activations_path="/home/paperspace/shared_volumes/activations_volume_1/gelu-1l", + training_tokens=total_training_tokens, # For initial testing I think this is a good number. + train_batch_size_tokens=4096, # TODO(mkbehr) doesn't follow batch_size! + # Loss Function + ## Reconstruction Coefficient. + mse_loss_normalization=None, # MSE Loss Normalization is not mentioned (so we use standard MSE Loss). But note we take an average over the batch. + ## Anthropic does not mention using an Lp norm other than L1. + l1_coefficient=l1_coefficient, + lp_norm=1.0, + # Instead, they multiply the L1 loss contribution + # from each feature of the activations by the decoder norm of the corresponding feature. + scale_sparsity_penalty_by_decoder_norm=True, + # sparsity_penalty_decoder_norm_lp_norm=1.0, + # Learning Rate + lr_scheduler_name="constant", # we set this independently of warmup and decay steps. + l1_warm_up_steps=l1_warmup_steps, + lr_warm_up_steps=lr_warm_up_steps, + lr_decay_steps=lr_warm_up_steps, + ## No ghost grad term. + use_ghost_grads=False, + # Initialization / Architecture + apply_b_dec_to_input=False, + # encoder bias zero's. (I'm not sure what it is by default now) + # decoder bias zero's. 
+ b_dec_init_method="zeros", + normalize_sae_decoder=False, + decoder_heuristic_init=True, + init_encoder_as_decoder_transpose=True, + # Optimizer + lr=5e-5, + ## adam optimizer has no weight decay by default so worry about this. + adam_beta1=0.9, + adam_beta2=0.999, + # Buffer details won't matter if we cache / shuffle our activations ahead of time. + n_batches_in_buffer=64, + store_batch_size_prompts=16, + normalize_activations="none", + # Feature Store + feature_sampling_window=1000, + dead_feature_window=1000, + dead_feature_threshold=1e-4, + # WANDB + log_to_wandb=log_to_wandb, # always use wandb unless you are just testing code. + wandb_project="crosscoder-global-acausal-tinystories", + wandb_log_frequency=50, + eval_every_n_wandb_logs=10, + # Misc + device=device, + seed=42, + n_checkpoints=0, + checkpoint_path="checkpoints", + dtype="float32", +) + +# look at the next cell to see some instruction for what to do while this is running. +sae = SAETrainingRunner( + cfg, + override_sae = TrainingCrosscoderSAE( + TrainingCrosscoderSAEConfig.from_sae_runner_config(cfg), + use_error_term=True, + )).run() + +print("=" * 50) + From b17b6075b227233edb63b8656cb47e9f0642fb73 Mon Sep 17 00:00:00 2001 From: Michael Behr Date: Thu, 17 Apr 2025 11:22:02 -0400 Subject: [PATCH 26/61] add ActivationsStore.hook_names() --- sae_lens/training/activations_store.py | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/sae_lens/training/activations_store.py b/sae_lens/training/activations_store.py index e68d6347c..a93948994 100644 --- a/sae_lens/training/activations_store.py +++ b/sae_lens/training/activations_store.py @@ -373,6 +373,14 @@ def _iterate_tokenized_sequences(self) -> Generator[torch.Tensor, None, None]: ), ) + def hook_names(self) -> List[str]: + # TODO(mkbehr): better config setup than putting a magic + # string in the name + if "{}" in self.hook_name: + return [self.hook_name.format(layer) + for layer in self.hook_layers] + 
return [self.hook_name] + def load_cached_activation_dataset(self) -> Dataset | None: """ Load the cached activation dataset from disk. @@ -394,7 +402,8 @@ def load_cached_activation_dataset(self) -> Dataset | None: # --- # Actual code activations_dataset = datasets.load_from_disk(self.cached_activations_path) - columns = [self.hook_name] + # TODO(mkbehr): test multiple layers + columns = self.hook_names() if "token_ids" in activations_dataset.column_names: columns.append("token_ids") activations_dataset.set_format( @@ -537,15 +546,7 @@ def get_activations(self, batch_tokens: torch.Tensor): else: autocast_if_enabled = contextlib.nullcontext() - # TODO(mkbehr): This is awkward. It may make more sense to - # have a list of hook names. - hook_names = [] - for layer in self.hook_layers: - if "{}" in self.hook_name: - hook_names.append(self.hook_name.format(layer)) - else: - hook_names.append(self.hook_name) - + hook_names = self.hook_names() stop_at_layer = max(self.hook_layers) + 1 with autocast_if_enabled: @@ -607,8 +608,7 @@ def _load_buffer_from_cached( raises StopIteration """ assert self.cached_activation_dataset is not None - # In future, could be a list of multiple hook names - hook_names = [self.hook_name] + hook_names = self.hook_names() if not set(hook_names).issubset(self.cached_activation_dataset.column_names): raise ValueError( f"Missing columns in dataset. 
Expected {hook_names}, " From 5fb5b49867272993dfc7e58fa898153bd38e31db Mon Sep 17 00:00:00 2001 From: Michael Behr Date: Fri, 18 Apr 2025 17:21:56 -0400 Subject: [PATCH 27/61] l2/sparsity/variance evals for crosscoders --- sae_lens/crosscoder_sae.py | 9 ++++- sae_lens/evals.py | 31 ++++++++++----- sae_lens/sae.py | 5 ++- sae_lens/training/crosscoder_sae_trainer.py | 23 ++++------- tests/test_evals.py | 44 +++++++++++++++++++++ 5 files changed, 85 insertions(+), 27 deletions(-) diff --git a/sae_lens/crosscoder_sae.py b/sae_lens/crosscoder_sae.py index cc06035ac..469159739 100644 --- a/sae_lens/crosscoder_sae.py +++ b/sae_lens/crosscoder_sae.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import Any +from typing import Any, List import einops import torch @@ -22,6 +22,13 @@ def to_dict(self) -> dict[str, Any]: "hook_layers": self.hook_layers, } + def hook_names(self) -> List[str]: + # TODO(mkbehr): better config setup than putting a magic + # string in the name + return [self.hook_name.format(layer) + for layer in self.hook_layers] + + class CrosscoderSAE(SAE): """ TODO(mkbehr): docstring diff --git a/sae_lens/evals.py b/sae_lens/evals.py index 63d43d203..04b82064f 100644 --- a/sae_lens/evals.py +++ b/sae_lens/evals.py @@ -378,7 +378,7 @@ def get_sparsity_and_variance_metrics( ignore_tokens: set[int | None] = set(), verbose: bool = False, ) -> tuple[dict[str, Any], dict[str, Any]]: - hook_name = sae.cfg.hook_name + hook_names = sae.cfg.hook_names() hook_head_index = sae.cfg.hook_head_index metric_dict = {} @@ -434,8 +434,8 @@ def get_sparsity_and_variance_metrics( _, cache = model.run_with_cache( batch_tokens, prepend_bos=False, - names_filter=[hook_name], - stop_at_layer=sae.cfg.hook_layer + 1, + names_filter=hook_names, + stop_at_layer=max(sae.cfg.hook_layers) + 1, **model_kwargs, ) @@ -443,11 +443,21 @@ def get_sparsity_and_variance_metrics( # which will do their own reshaping for hook z. 
has_head_dim_key_substrings = ["hook_q", "hook_k", "hook_v", "hook_z"] if hook_head_index is not None: - original_act = cache[hook_name][:, :, hook_head_index] - elif any(substring in hook_name for substring in has_head_dim_key_substrings): - original_act = cache[hook_name].flatten(-2, -1) + # TODO(mkbehr) support head dimension for mutilayer evals + assert len(hook_names) == 1 + original_act = cache[hook_names[0]][:, :, hook_head_index] + elif any(substring in hook_names[0] for substring in has_head_dim_key_substrings): + # TODO(mkbehr) support head dimension for mutilayer evals + original_act = cache[hook_names[0]].flatten(-2, -1) + elif len(hook_names) > 1: + # TODO(mkbehr): cleaner interface for multilayer evals + # TODO(mkbehr): support head dimension for mutilayer evals + layerwise_activations = [ + cache[hook_name] for hook_name in hook_names + ] + original_act = torch.stack(layerwise_activations, dim=2) else: - original_act = cache[hook_name] + original_act = cache[hook_names[0]] # normalise if necessary (necessary in training only, otherwise we should fold the scaling in) if activation_store.normalize_activations == "expected_average_only_in": @@ -461,14 +471,15 @@ def get_sparsity_and_variance_metrics( if activation_store.normalize_activations == "expected_average_only_in": sae_out = activation_store.unscale(sae_out) - flattened_sae_input = einops.rearrange(original_act, "b ctx d -> (b ctx) d") + flattened_sae_input = einops.rearrange(original_act, "b ctx d ... -> (b ctx) (d ...)") flattened_sae_feature_acts = einops.rearrange( - sae_feature_activations, "b ctx d -> (b ctx) d" + sae_feature_activations, "b ctx d ... -> (b ctx) (d ...)" ) - flattened_sae_out = einops.rearrange(sae_out, "b ctx d -> (b ctx) d") + flattened_sae_out = einops.rearrange(sae_out, "b ctx d ... -> (b ctx) (d ...)") # TODO: Clean this up. 
# apply mask + # TODO(mkbehr): test mask support w/ multilayer masked_sae_feature_activations = sae_feature_activations * mask.unsqueeze(-1) flattened_sae_input = flattened_sae_input[ flattened_mask.to(flattened_sae_input.device) diff --git a/sae_lens/sae.py b/sae_lens/sae.py index 93dea38c5..b8392d14e 100644 --- a/sae_lens/sae.py +++ b/sae_lens/sae.py @@ -7,7 +7,7 @@ from contextlib import contextmanager from dataclasses import dataclass, field from pathlib import Path -from typing import Any, Callable, Literal, TypeVar, overload +from typing import Any, Callable, List, Literal, TypeVar, overload import einops import torch @@ -124,6 +124,9 @@ def to_dict(self) -> dict[str, Any]: "seqpos_slice": self.seqpos_slice, } + def hook_names(self) -> List[str]: + return [self.hook_name] + class SAE(HookedRootModule): """ diff --git a/sae_lens/training/crosscoder_sae_trainer.py b/sae_lens/training/crosscoder_sae_trainer.py index 38d6abb40..ccbdb4a0b 100644 --- a/sae_lens/training/crosscoder_sae_trainer.py +++ b/sae_lens/training/crosscoder_sae_trainer.py @@ -16,11 +16,6 @@ def __init__(self, *args, **kwargs): # TODO(mkbehr) hardcoding causal evals off for now self.trainer_eval_config.compute_ce_loss=False self.trainer_eval_config.compute_kl=False - # TODO(mkbehr) hardcoding l2/sparsity/variance off, since - # those evals don't work yet - self.trainer_eval_config.compute_l2_norms=False - self.trainer_eval_config.compute_sparsity_metrics=False - self.trainer_eval_config.compute_variance_metrics=False def fit(self) -> TrainingSAE: pbar = tqdm(total=self.cfg.total_training_tokens, desc="Training SAE") @@ -123,16 +118,14 @@ def _run_and_log_evals(self): ignore_tokens = set( self.activations_store.exclude_special_tokens.tolist() ) - # TODO(mkbehr): get some evals working - eval_metrics = {} - # eval_metrics, _ = run_evals( - # sae=self.sae, - # activation_store=self.activations_store, - # model=self.model, - # eval_config=self.trainer_eval_config, - # ignore_tokens=ignore_tokens, 
- # model_kwargs=self.cfg.model_kwargs, - # ) # not calculating featurwise metrics here. + eval_metrics, _ = run_evals( + sae=self.sae, + activation_store=self.activations_store, + model=self.model, + eval_config=self.trainer_eval_config, + ignore_tokens=ignore_tokens, + model_kwargs=self.cfg.model_kwargs, + ) # not calculating featurwise metrics here. # Remove eval metrics that are already logged during training eval_metrics.pop("metrics/explained_variance", None) diff --git a/tests/test_evals.py b/tests/test_evals.py index 12a5dd953..36c6222ad 100644 --- a/tests/test_evals.py +++ b/tests/test_evals.py @@ -25,6 +25,10 @@ from sae_lens.sae import SAE from sae_lens.toolkit.pretrained_saes_directory import PretrainedSAELookup from sae_lens.training.activations_store import ActivationsStore +from sae_lens.training.training_crosscoder_sae import ( + TrainingCrosscoderSAE, + TrainingCrosscoderSAEConfig, +) from sae_lens.training.training_sae import TrainingSAE from tests.helpers import TINYSTORIES_MODEL, build_sae_cfg, load_model_cached @@ -283,6 +287,46 @@ def test_run_empty_evals( assert "token_stats" in eval_metrics, "Expected token_stats in eval_metrics" assert len(feature_metrics) == 0, "Expected empty feature_metrics" +# TODO(mkbehr): consider parameterizing +def test_run_evals_crosscoder_training_sae(model): + cfg=build_sae_cfg( + model_name="tiny-stories-1M", + dataset_path="roneneldan/TinyStories", + hook_name="blocks.{}.hook_resid_pre", + hook_layers=[0, 1], + d_in=64, + normalize_sae_decoder=False, + scale_sparsity_penalty_by_decoder_norm=True, + ) + activation_store = ActivationsStore.from_config( + model, cfg, override_dataset=Dataset.from_list([{"text": "hello world"}] * 2000) + ) + training_crosscoder_sae = TrainingCrosscoderSAE( + TrainingCrosscoderSAEConfig.from_sae_runner_config(cfg), + use_error_term=True) + eval_config = EvalConfig( + compute_l2_norms=True, + compute_sparsity_metrics=True, + compute_variance_metrics=True, + # TODO(mkbehr): 
featurewise metrics + compute_featurewise_density_statistics=False, + compute_featurewise_weight_based_metrics=False, + ) + eval_metrics, feature_metrics = run_evals( + sae=training_crosscoder_sae, + activation_store=activation_store, + model=model, + eval_config=eval_config, + ) + expected_keys = [ + "reconstruction_quality", + "shrinkage", + "sparsity", + "token_stats", + ] + assert set(eval_metrics.keys()) == set(expected_keys) + assert set(feature_metrics.keys()) == set( + ["feature_density", "consistent_activation_heuristic"]) @pytest.fixture def mock_args(): From 05512da637209c7d88eba83bbd8e4c2665124b24 Mon Sep 17 00:00:00 2001 From: Michael Behr Date: Sat, 19 Apr 2025 12:24:33 -0400 Subject: [PATCH 28/61] tiny-stories-1m experiments --- scripts/global_acausal_crosscoder.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/scripts/global_acausal_crosscoder.py b/scripts/global_acausal_crosscoder.py index 9e8c8c67c..d82ef2b08 100644 --- a/scripts/global_acausal_crosscoder.py +++ b/scripts/global_acausal_crosscoder.py @@ -27,15 +27,15 @@ # total_training_steps = 200_000 total_training_steps = 50_000 # total_training_steps = 1000 -# batch_size = 4096 +batch_size = 4096 # batch_size = 256 -total_training_tokens = total_training_steps * 256 +total_training_tokens = total_training_steps * batch_size print(f"Total Training Tokens: {total_training_tokens}") -# l1_coefficient = 1.0 -l1_coefficient = 1e-6 # DEBUG: if I mostly zero out the l1 loss, will it learn? +l1_coefficient = 3e-2 # change these configs -model_name = "tiny-stories-2L-33M" +# TODO(mkbehr): just do tiny-stories-1M with all 8 layers +model_name = "tiny-stories-1M" dataset_path = "apollo-research/roneneldan-TinyStories-tokenizer-gpt2" new_cached_activations_path = ( f"./cached_activations/{model_name}/{dataset_path}/{total_training_steps}" @@ -54,8 +54,8 @@ # Pick a tiny model to make this easier. 
model_name=model_name, hook_name="blocks.{}.hook_mlp_out", - hook_layers=[0,1], - d_in=1024, + hook_layers=list(range(8)), + d_in=64, dataset_path=dataset_path, streaming=True, context_size=512, From 3d1abbd1a0562670efd122047aa32fe92b43a6aa Mon Sep 17 00:00:00 2001 From: Michael Behr Date: Sat, 19 Apr 2025 12:47:37 -0400 Subject: [PATCH 29/61] tiny-stories-28m --- scripts/global_acausal_crosscoder.py | 25 +++++++++++-------------- 1 file changed, 11 insertions(+), 14 deletions(-) diff --git a/scripts/global_acausal_crosscoder.py b/scripts/global_acausal_crosscoder.py index d82ef2b08..087548db0 100644 --- a/scripts/global_acausal_crosscoder.py +++ b/scripts/global_acausal_crosscoder.py @@ -25,44 +25,40 @@ os.environ["TOKENIZERS_PARALLELISM"] = "false" # total_training_steps = 200_000 -total_training_steps = 50_000 +total_training_steps = 30_000 # total_training_steps = 1000 batch_size = 4096 # batch_size = 256 total_training_tokens = total_training_steps * batch_size print(f"Total Training Tokens: {total_training_tokens}") -l1_coefficient = 3e-2 -# change these configs -# TODO(mkbehr): just do tiny-stories-1M with all 8 layers -model_name = "tiny-stories-1M" +model_name = "tiny-stories-28M" dataset_path = "apollo-research/roneneldan-TinyStories-tokenizer-gpt2" new_cached_activations_path = ( f"./cached_activations/{model_name}/{dataset_path}/{total_training_steps}" ) -lr_warm_up_steps = total_training_steps // 20 +lr_warm_up_steps = 0 print(f"lr_warm_up_steps: {lr_warm_up_steps}") lr_decay_steps = total_training_steps // 5 # 20% of training steps. print(f"lr_decay_steps: {lr_decay_steps}") -l1_warmup_steps = total_training_steps // 10 +l1_warmup_steps = total_training_steps // 20 print(f"l1_warmup_steps: {l1_warmup_steps}") log_to_wandb = True # log_to_wandb = False cfg = LanguageModelSAERunnerConfig( - # Pick a tiny model to make this easier. 
model_name=model_name, hook_name="blocks.{}.hook_mlp_out", - hook_layers=list(range(8)), - d_in=64, + hook_layers=list(range(4)), + d_in=512, dataset_path=dataset_path, streaming=True, context_size=512, is_dataset_tokenized=True, prepend_bos=True, # How big do we want our SAE to be? - expansion_factor=64, + expansion_factor=16, # Dataset / Activation Store # When we do a proper test # training_tokens= 820_000_000, # 200k steps * 4096 batch size ~ 820M tokens (doable overnight on an A100) @@ -70,16 +66,17 @@ use_cached_activations=False, # cached_activations_path="/home/paperspace/shared_volumes/activations_volume_1/gelu-1l", training_tokens=total_training_tokens, # For initial testing I think this is a good number. - train_batch_size_tokens=4096, # TODO(mkbehr) doesn't follow batch_size! + train_batch_size_tokens=batch_size, # Loss Function ## Reconstruction Coefficient. mse_loss_normalization=None, # MSE Loss Normalization is not mentioned (so we use stanrd MSE Loss). But not we take an average over the batch. ## Anthropic does not mention using an Lp norm other than L1. - l1_coefficient=l1_coefficient, + l1_coefficient=5, lp_norm=1.0, # Instead, they multiply the L1 loss contribution # from each feature of the activations by the decoder norm of the corresponding feature. scale_sparsity_penalty_by_decoder_norm=True, + # TODO(mkbehr): plumb this through config # sparsity_penalty_decoder_norm_lp_norm=1.0, # Learning Rate lr_scheduler_name="constant", # we set this independently of warmup and decay steps. @@ -111,7 +108,7 @@ dead_feature_threshold=1e-4, # WANDB log_to_wandb=log_to_wandb, # always use wandb unless you are just testing code. 
- wandb_project="crosscoder-global-acausal-tinystories", + wandb_project="crosscoder-acausal-tinystories-23M-layer0-3", wandb_log_frequency=50, eval_every_n_wandb_logs=10, # Misc From 9ea92c8249d9de4905d10695f41ef8ad9ea0fd86 Mon Sep 17 00:00:00 2001 From: Michael Behr Date: Sat, 19 Apr 2025 23:39:26 -0400 Subject: [PATCH 30/61] minor fixes --- sae_lens/evals.py | 2 +- sae_lens/training/crosscoder_sae_trainer.py | 5 +++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/sae_lens/evals.py b/sae_lens/evals.py index 04b82064f..9bb2b8786 100644 --- a/sae_lens/evals.py +++ b/sae_lens/evals.py @@ -473,7 +473,7 @@ def get_sparsity_and_variance_metrics( flattened_sae_input = einops.rearrange(original_act, "b ctx d ... -> (b ctx) (d ...)") flattened_sae_feature_acts = einops.rearrange( - sae_feature_activations, "b ctx d ... -> (b ctx) (d ...)" + sae_feature_activations, "b ctx d -> (b ctx) d" ) flattened_sae_out = einops.rearrange(sae_out, "b ctx d ... -> (b ctx) (d ...)") diff --git a/sae_lens/training/crosscoder_sae_trainer.py b/sae_lens/training/crosscoder_sae_trainer.py index ccbdb4a0b..9d8c19ef3 100644 --- a/sae_lens/training/crosscoder_sae_trainer.py +++ b/sae_lens/training/crosscoder_sae_trainer.py @@ -7,6 +7,7 @@ from sae_lens.evals import run_evals from sae_lens.training.sae_trainer import SAETrainer, _unwrap_item from sae_lens.training.training_sae import TrainingSAE, TrainStepOutput +from sae_lens.training.training_crosscoder_sae import TrainingCrosscoderSAE, TrainStepOutput # TODO(mkbehr): probably too much copypasting here @@ -17,8 +18,8 @@ def __init__(self, *args, **kwargs): self.trainer_eval_config.compute_ce_loss=False self.trainer_eval_config.compute_kl=False - def fit(self) -> TrainingSAE: - pbar = tqdm(total=self.cfg.total_training_tokens, desc="Training SAE") + def fit(self) -> TrainingCrosscoderSAE: + pbar = tqdm(total=self.cfg.total_training_tokens, desc="Training Crosscoder SAE") self.activations_store.set_norm_scaling_factor_if_needed() 
From c45b08a9de3771f51acaeb349b87e8aed329b9d0 Mon Sep 17 00:00:00 2001 From: Michael Behr Date: Sat, 19 Apr 2025 23:39:57 -0400 Subject: [PATCH 31/61] training changes --- scripts/global_acausal_crosscoder.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/scripts/global_acausal_crosscoder.py b/scripts/global_acausal_crosscoder.py index 087548db0..8dfa3b773 100644 --- a/scripts/global_acausal_crosscoder.py +++ b/scripts/global_acausal_crosscoder.py @@ -25,9 +25,9 @@ os.environ["TOKENIZERS_PARALLELISM"] = "false" # total_training_steps = 200_000 -total_training_steps = 30_000 -# total_training_steps = 1000 -batch_size = 4096 +total_training_steps = 60_000 +# total_training_steps = 5000 +batch_size = 2048 # batch_size = 256 total_training_tokens = total_training_steps * batch_size print(f"Total Training Tokens: {total_training_tokens}") @@ -38,14 +38,15 @@ f"./cached_activations/{model_name}/{dataset_path}/{total_training_steps}" ) -lr_warm_up_steps = 0 +lr_warm_up_steps = total_training_steps // 40 print(f"lr_warm_up_steps: {lr_warm_up_steps}") lr_decay_steps = total_training_steps // 5 # 20% of training steps. print(f"lr_decay_steps: {lr_decay_steps}") l1_warmup_steps = total_training_steps // 20 print(f"l1_warmup_steps: {l1_warmup_steps}") -log_to_wandb = True -# log_to_wandb = False +log_to_wandb = False +if not log_to_wandb: + print("NOT LOGGING TO WANDB") cfg = LanguageModelSAERunnerConfig( model_name=model_name, @@ -71,7 +72,7 @@ ## Reconstruction Coefficient. mse_loss_normalization=None, # MSE Loss Normalization is not mentioned (so we use stanrd MSE Loss). But not we take an average over the batch. ## Anthropic does not mention using an Lp norm other than L1. - l1_coefficient=5, + l1_coefficient=1, lp_norm=1.0, # Instead, they multiply the L1 loss contribution # from each feature of the activations by the decoder norm of the corresponding feature. 
@@ -94,7 +95,7 @@ decoder_heuristic_init=True, init_encoder_as_decoder_transpose=True, # Optimizer - lr=5e-5, + lr=1e-5, ## adam optimizer has no weight decay by default so worry about this. adam_beta1=0.9, adam_beta2=0.999, From c098be0c313ba3fdb81ce91db231f5fbdaa27f53 Mon Sep 17 00:00:00 2001 From: Michael Behr Date: Sun, 20 Apr 2025 15:05:07 -0400 Subject: [PATCH 32/61] scale W_dec init norm --- sae_lens/training/training_crosscoder_sae.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/sae_lens/training/training_crosscoder_sae.py b/sae_lens/training/training_crosscoder_sae.py index 26dc3196b..860c2afe7 100644 --- a/sae_lens/training/training_crosscoder_sae.py +++ b/sae_lens/training/training_crosscoder_sae.py @@ -242,6 +242,11 @@ def initialize_weights_complex(self): self.cfg.d_sae, *self.input_shape(), dtype=self.dtype, device=self.device ) ) + self.W_dec.data = ( + self.W_dec.data + / self.W_dec.data.norm(dim=-1, keepdim=True) + * 0.1 # TODO(mkbehr): make norm configurable + ) self.initialize_decoder_norm_constant_norm() # Then we initialize the encoder weights (either as the transpose of decoder or not) From 7c01f2d5e10908ac1ac62c331f84845c5a4c4acd Mon Sep 17 00:00:00 2001 From: Michael Behr Date: Sun, 20 Apr 2025 15:58:30 -0400 Subject: [PATCH 33/61] scale activations by layer --- sae_lens/crosscoder_sae.py | 14 +++++++++++++- sae_lens/training/activations_store.py | 21 +++++++++++++-------- 2 files changed, 26 insertions(+), 9 deletions(-) diff --git a/sae_lens/crosscoder_sae.py b/sae_lens/crosscoder_sae.py index 469159739..46f0dffb7 100644 --- a/sae_lens/crosscoder_sae.py +++ b/sae_lens/crosscoder_sae.py @@ -36,7 +36,6 @@ class CrosscoderSAE(SAE): # TODO(mkbehr): write # - remaining encode methods - # - fold_activation_norm # - hook_z reshaping support def __init__( @@ -136,3 +135,16 @@ def fold_W_dec_norm(self): self.b_enc.data = self.b_enc.data * W_dec_norms.squeeze() else: self.b_enc.data = self.b_enc.data * W_dec_norms.squeeze() + + 
@torch.no_grad() + def fold_activation_norm_scaling_factor( + self, activation_norm_scaling_factor: Float[torch.Tensor, "n_layers"] + ): + self.W_enc.data = self.W_enc.data * activation_norm_scaling_factor.reshape((-1,1,1)) + # previously weren't doing this. + self.W_dec.data = self.W_dec.data / activation_norm_scaling_factor.unsqueeze(-1) + self.b_dec.data = self.b_dec.data / activation_norm_scaling_factor.unsqueeze(-1) + + # once we normalize, we shouldn't need to scale activations. + self.cfg.normalize_activations = "none" + diff --git a/sae_lens/training/activations_store.py b/sae_lens/training/activations_store.py index a93948994..4f645c58e 100644 --- a/sae_lens/training/activations_store.py +++ b/sae_lens/training/activations_store.py @@ -442,31 +442,36 @@ def apply_norm_scaling_factor(self, activations: torch.Tensor) -> torch.Tensor: raise ValueError( "estimated_norm_scaling_factor is not set, call set_norm_scaling_factor_if_needed() first" ) - return activations * self.estimated_norm_scaling_factor + return activations * self.estimated_norm_scaling_factor.unsqueeze(-1).to(activations.device) def unscale(self, activations: torch.Tensor) -> torch.Tensor: if self.estimated_norm_scaling_factor is None: raise ValueError( "estimated_norm_scaling_factor is not set, call set_norm_scaling_factor_if_needed() first" ) - return activations / self.estimated_norm_scaling_factor + return activations / self.estimated_norm_scaling_factor.unsqueeze(-1).to(activations.device) def get_norm_scaling_factor(self, activations: torch.Tensor) -> torch.Tensor: return (self.d_in**0.5) / activations.norm(dim=-1).mean() @torch.no_grad() def estimate_norm_scaling_factor(self, n_batches_for_norm_estimate: int = int(1e3)): - norms_per_batch = [] - for _ in tqdm( + # TODO(mkbehr): test multilayer norm scaling, probably fix saving? 
+ norms_per_batch = torch.empty( + len(self.hook_layers), n_batches_for_norm_estimate, + device=self.device) + for batch_i in tqdm( range(n_batches_for_norm_estimate), desc="Estimating norm scaling factor" ): # temporalily set estimated_norm_scaling_factor to 1.0 so the dataloader works - self.estimated_norm_scaling_factor = 1.0 + self.estimated_norm_scaling_factor = torch.ones(1) acts = self.next_batch()[:, 0] self.estimated_norm_scaling_factor = None - norms_per_batch.append(acts.norm(dim=-1).mean().item()) - mean_norm = np.mean(norms_per_batch) - return np.sqrt(self.d_in) / mean_norm + norms_per_batch[:, batch_i] = acts.norm(dim=-1).mean(dim=0) + mean_norm = norms_per_batch.mean(dim=1) + # TODO(mkbehr): make this a float in single-layer case for + # backwards compatibility + return (np.sqrt(self.d_in) / mean_norm) def shuffle_input_dataset(self, seed: int, buffer_size: int = 1): """ From 7957fa9602fc82e9d79e9bee39c695d74a9433ba Mon Sep 17 00:00:00 2001 From: Michael Behr Date: Sun, 20 Apr 2025 23:38:41 -0400 Subject: [PATCH 34/61] some training changes --- scripts/global_acausal_crosscoder.py | 39 ++++++++++++++++++++-------- 1 file changed, 28 insertions(+), 11 deletions(-) diff --git a/scripts/global_acausal_crosscoder.py b/scripts/global_acausal_crosscoder.py index 8dfa3b773..4b958a5b7 100644 --- a/scripts/global_acausal_crosscoder.py +++ b/scripts/global_acausal_crosscoder.py @@ -25,13 +25,16 @@ os.environ["TOKENIZERS_PARALLELISM"] = "false" # total_training_steps = 200_000 -total_training_steps = 60_000 -# total_training_steps = 5000 +# total_training_steps = 60_000 +total_training_steps = 10_000 batch_size = 2048 # batch_size = 256 total_training_tokens = total_training_steps * batch_size print(f"Total Training Tokens: {total_training_tokens}") +layers = list(range(3)) +# layers = [0] + model_name = "tiny-stories-28M" dataset_path = "apollo-research/roneneldan-TinyStories-tokenizer-gpt2" new_cached_activations_path = ( @@ -44,22 +47,35 @@ 
print(f"lr_decay_steps: {lr_decay_steps}") l1_warmup_steps = total_training_steps // 20 print(f"l1_warmup_steps: {l1_warmup_steps}") -log_to_wandb = False +log_to_wandb = True if not log_to_wandb: print("NOT LOGGING TO WANDB") +d_in = 512 +expansion_factor = 16 +d_sae = d_in * expansion_factor +learning_rate = 5e-5 +l1_coefficient = 1 +run_name = ( + f"{d_sae}" + f"-Layers-{'_'.join([str(l) for l in layers])}" + f"-L1-{l1_coefficient}" + f"-LR-{learning_rate}" + f"-Tokens-{total_training_tokens:3.3e}" + ) + cfg = LanguageModelSAERunnerConfig( model_name=model_name, hook_name="blocks.{}.hook_mlp_out", - hook_layers=list(range(4)), - d_in=512, + hook_layers=layers, + d_in=d_in, dataset_path=dataset_path, streaming=True, context_size=512, is_dataset_tokenized=True, - prepend_bos=True, + prepend_bos=False, # TODO(mkbehr): probably better to prepend bosg but then remove that token's activations # How big do we want our SAE to be? - expansion_factor=16, + expansion_factor=expansion_factor, # Dataset / Activation Store # When we do a proper test # training_tokens= 820_000_000, # 200k steps * 4096 batch size ~ 820M tokens (doable overnight on an A100) @@ -72,7 +88,7 @@ ## Reconstruction Coefficient. mse_loss_normalization=None, # MSE Loss Normalization is not mentioned (so we use stanrd MSE Loss). But not we take an average over the batch. ## Anthropic does not mention using an Lp norm other than L1. - l1_coefficient=1, + l1_coefficient=l1_coefficient, lp_norm=1.0, # Instead, they multiply the L1 loss contribution # from each feature of the activations by the decoder norm of the corresponding feature. @@ -95,21 +111,22 @@ decoder_heuristic_init=True, init_encoder_as_decoder_transpose=True, # Optimizer - lr=1e-5, + lr=learning_rate, ## adam optimizer has no weight decay by default so worry about this. adam_beta1=0.9, adam_beta2=0.999, # Buffer details won't matter in we cache / shuffle our activations ahead of time. 
n_batches_in_buffer=64, store_batch_size_prompts=16, - normalize_activations="none", + normalize_activations="expected_average_only_in", # Feature Store feature_sampling_window=1000, dead_feature_window=1000, dead_feature_threshold=1e-4, # WANDB log_to_wandb=log_to_wandb, # always use wandb unless you are just testing code. - wandb_project="crosscoder-acausal-tinystories-23M-layer0-3", + wandb_project="crosscoder-acausal-tinystories-23M", + run_name=run_name, wandb_log_frequency=50, eval_every_n_wandb_logs=10, # Misc From 093bc39117f3fd96270e6081153e30378461e80d Mon Sep 17 00:00:00 2001 From: Michael Behr Date: Sun, 27 Apr 2025 17:57:11 -0400 Subject: [PATCH 35/61] clean up some TODOs --- sae_lens/crosscoder_sae.py | 26 +------------------------- sae_lens/training/activations_store.py | 2 +- 2 files changed, 2 insertions(+), 26 deletions(-) diff --git a/sae_lens/crosscoder_sae.py b/sae_lens/crosscoder_sae.py index 46f0dffb7..1b4bb45c2 100644 --- a/sae_lens/crosscoder_sae.py +++ b/sae_lens/crosscoder_sae.py @@ -11,13 +11,7 @@ class CrosscoderSAEConfig(SAEConfig): hook_layers: list[int] = list - # @classmethod - # def from_dict(cls, config_dict: dict[str, Any]) -> "CrosscoderSAEConfig": - # # TODO(mkbehr) is a new method needed here, or will the superclass's work w/o modification? I think it'll work. test it. - # pass - def to_dict(self) -> dict[str, Any]: - # TODO(mkbehr) test return super().to_dict() | { "hook_layers": self.hook_layers, } @@ -31,7 +25,7 @@ def hook_names(self) -> List[str]: class CrosscoderSAE(SAE): """ - TODO(mkbehr): docstring + Sparse autoencoder that acts on multiple layers of activations. """ # TODO(mkbehr): write @@ -63,19 +57,6 @@ def from_dict(cls, config_dict: dict[str, Any]) -> "CrosscoderSAE": def input_shape(self): return (len(self.cfg.hook_layers), self.cfg.d_in) - # TODO(mkbehr): in sae.py this is noted to output "... 
d_sae" but - # I think that's wrong - # TODO(mkbehr): I don't think we actually need to change this - def process_sae_in( - self, sae_in: Float[torch.Tensor, "... n_layers d_in"] - ) -> Float[torch.Tensor, "... n_layers d_in"]: - sae_in = sae_in.to(self.dtype) - # TODO(mkbehr): n.b. that reshape_fn_in is set to the identity - # if we're not doing hook_z reshaping - sae_in = self.reshape_fn_in(sae_in) - sae_in = self.hook_sae_input(sae_in) - sae_in = self.run_time_activation_norm_fn_in(sae_in) - return sae_in - (self.b_dec * self.cfg.apply_b_dec_to_input) def encode_standard( self, x: Float[torch.Tensor, "... n_layers d_in"] @@ -117,11 +98,6 @@ def decode( @torch.no_grad() def fold_W_dec_norm(self): - # TODO(mkbehr) - # W_dec: d_sae, n_layers, d_in - # W_dec_norms: d_sae, 1, 1 - # W_enc: n_layers, d_in, d_sae - # desired W_enc_norms: 1, 1, d_sae W_dec_norms = self.W_dec.norm(dim=[-2,-1], keepdim=True) self.W_dec.data = self.W_dec.data / W_dec_norms self.W_enc.data = self.W_enc.data * einops.rearrange( diff --git a/sae_lens/training/activations_store.py b/sae_lens/training/activations_store.py index 4f645c58e..23eaa2440 100644 --- a/sae_lens/training/activations_store.py +++ b/sae_lens/training/activations_store.py @@ -465,7 +465,7 @@ def estimate_norm_scaling_factor(self, n_batches_for_norm_estimate: int = int(1e ): # temporalily set estimated_norm_scaling_factor to 1.0 so the dataloader works self.estimated_norm_scaling_factor = torch.ones(1) - acts = self.next_batch()[:, 0] + acts = self.next_batch() self.estimated_norm_scaling_factor = None norms_per_batch[:, batch_i] = acts.norm(dim=-1).mean(dim=0) mean_norm = norms_per_batch.mean(dim=1) From 09abaab5dcf9ed683bb22f1c07dce93ddb8d0c23 Mon Sep 17 00:00:00 2001 From: Michael Behr Date: Sun, 27 Apr 2025 18:43:50 -0400 Subject: [PATCH 36/61] trim CrosscoderSAETrainer --- sae_lens/training/crosscoder_sae_trainer.py | 34 ++------------------- 1 file changed, 3 insertions(+), 31 deletions(-) diff --git 
a/sae_lens/training/crosscoder_sae_trainer.py b/sae_lens/training/crosscoder_sae_trainer.py index 9d8c19ef3..6f6b58e4d 100644 --- a/sae_lens/training/crosscoder_sae_trainer.py +++ b/sae_lens/training/crosscoder_sae_trainer.py @@ -14,7 +14,7 @@ class CrosscoderSAETrainer(SAETrainer): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - # TODO(mkbehr) hardcoding causal evals off for now + # Reconstruction metrics don't make sense for acausal crosscoders. self.trainer_eval_config.compute_ce_loss=False self.trainer_eval_config.compute_kl=False @@ -67,44 +67,16 @@ def _build_train_step_log_dict( output: TrainStepOutput, n_training_tokens: int, ) -> dict[str, Any]: - sae_in = output.sae_in - sae_out = output.sae_out - feature_acts = output.feature_acts - loss = output.loss.item() - - # metrics for currents acts - l0 = (feature_acts > 0).float().sum(-1).mean() - current_learning_rate = self.optimizer.param_groups[0]["lr"] + log_dict = super()._build_train_step_log_dict(output, n_training_tokens) per_token_l2_loss = (sae_out - sae_in).pow(2).sum(dim=(-2, -1)).squeeze() total_variance = (sae_in - sae_in.mean(0)).pow(2).sum((-2, -1)) explained_variance = 1 - per_token_l2_loss / total_variance - log_dict = { - # losses - "losses/overall_loss": loss, - # variance explained + log_dict |= { "metrics/explained_variance": explained_variance.mean().item(), "metrics/explained_variance_std": explained_variance.std().item(), - "metrics/l0": l0.item(), - # sparsity - "sparsity/mean_passes_since_fired": self.n_forward_passes_since_fired.mean().item(), - "sparsity/dead_features": self.dead_neurons.sum().item(), - "details/current_learning_rate": current_learning_rate, - "details/current_l1_coefficient": self.current_l1_coefficient, - "details/n_training_tokens": n_training_tokens, } - for loss_name, loss_value in output.losses.items(): - loss_item = _unwrap_item(loss_value) - # special case for l1 loss, which we normalize by the l1 coefficient - if loss_name == 
"l1_loss": - log_dict[f"losses/{loss_name}"] = ( - loss_item / self.current_l1_coefficient - ) - log_dict[f"losses/raw_{loss_name}"] = loss_item - else: - log_dict[f"losses/{loss_name}"] = loss_item - return log_dict @torch.no_grad() From e2deb2b89129ec07c3b5d97d92a64f88bc77b4e8 Mon Sep 17 00:00:00 2001 From: Michael Behr Date: Mon, 28 Apr 2025 18:03:25 -0400 Subject: [PATCH 37/61] TODO notes in crosscoder trainer --- sae_lens/training/crosscoder_sae_trainer.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/sae_lens/training/crosscoder_sae_trainer.py b/sae_lens/training/crosscoder_sae_trainer.py index 6f6b58e4d..a87d2e84d 100644 --- a/sae_lens/training/crosscoder_sae_trainer.py +++ b/sae_lens/training/crosscoder_sae_trainer.py @@ -10,6 +10,12 @@ from sae_lens.training.training_crosscoder_sae import TrainingCrosscoderSAE, TrainStepOutput # TODO(mkbehr): probably too much copypasting here +# why do I think that? +# - fit is long +# - all it does is take the whole batch instead of the first layer +# - maybe a helper method to subclass? 
+# - _run_and_log_evals is long +# - all it does differently is W_dec_norms (and presumably other architectures' things once those are implemented) class CrosscoderSAETrainer(SAETrainer): def __init__(self, *args, **kwargs): From f2ea4600ce7bbfee624151d8b9faee3308d959bb Mon Sep 17 00:00:00 2001 From: Michael Behr Date: Mon, 28 Apr 2025 18:12:38 -0400 Subject: [PATCH 38/61] Change hook name syntax from {} to {layer} --- sae_lens/crosscoder_sae.py | 2 +- sae_lens/training/activations_store.py | 4 ++-- scripts/global_acausal_crosscoder.py | 2 +- tests/test_evals.py | 2 +- .../test_activations_store_multilayer.py | 18 +++++++++--------- tests/training/test_crosscoder_sae.py | 12 ++++++------ tests/training/test_crosscoder_sae_trainer.py | 4 ++-- tests/training/test_crosscoder_sae_training.py | 6 +++--- 8 files changed, 25 insertions(+), 25 deletions(-) diff --git a/sae_lens/crosscoder_sae.py b/sae_lens/crosscoder_sae.py index 1b4bb45c2..6c18270c2 100644 --- a/sae_lens/crosscoder_sae.py +++ b/sae_lens/crosscoder_sae.py @@ -19,7 +19,7 @@ def to_dict(self) -> dict[str, Any]: def hook_names(self) -> List[str]: # TODO(mkbehr): better config setup than putting a magic # string in the name - return [self.hook_name.format(layer) + return [self.hook_name.format(layer=layer) for layer in self.hook_layers] diff --git a/sae_lens/training/activations_store.py b/sae_lens/training/activations_store.py index 23eaa2440..3b5f31251 100644 --- a/sae_lens/training/activations_store.py +++ b/sae_lens/training/activations_store.py @@ -376,8 +376,8 @@ def _iterate_tokenized_sequences(self) -> Generator[torch.Tensor, None, None]: def hook_names(self) -> List[str]: # TODO(mkbehr): better config setup than putting a magic # string in the name - if "{}" in self.hook_name: - return [self.hook_name.format(layer) + if "{layer}" in self.hook_name: + return [self.hook_name.format(layer=layer) for layer in self.hook_layers] return [self.hook_name] diff --git a/scripts/global_acausal_crosscoder.py 
b/scripts/global_acausal_crosscoder.py index 4b958a5b7..bc8719c33 100644 --- a/scripts/global_acausal_crosscoder.py +++ b/scripts/global_acausal_crosscoder.py @@ -66,7 +66,7 @@ cfg = LanguageModelSAERunnerConfig( model_name=model_name, - hook_name="blocks.{}.hook_mlp_out", + hook_name="blocks.{layer}.hook_mlp_out", hook_layers=layers, d_in=d_in, dataset_path=dataset_path, diff --git a/tests/test_evals.py b/tests/test_evals.py index 36c6222ad..17597e85d 100644 --- a/tests/test_evals.py +++ b/tests/test_evals.py @@ -292,7 +292,7 @@ def test_run_evals_crosscoder_training_sae(model): cfg=build_sae_cfg( model_name="tiny-stories-1M", dataset_path="roneneldan/TinyStories", - hook_name="blocks.{}.hook_resid_pre", + hook_name="blocks.{layer}.hook_resid_pre", hook_layers=[0, 1], d_in=64, normalize_sae_decoder=False, diff --git a/tests/training/test_activations_store_multilayer.py b/tests/training/test_activations_store_multilayer.py index fc6d2d7ef..b4cd85e3e 100644 --- a/tests/training/test_activations_store_multilayer.py +++ b/tests/training/test_activations_store_multilayer.py @@ -13,7 +13,7 @@ def test_activations_store_init_with_multiple_layers(ts_model: HookedTransformer """Test initialization with a list of layers instead of a single layer.""" # Initialize with multiple layers cfg = build_sae_cfg( - hook_name="blocks.{}.hook_resid_pre", + hook_name="blocks.{layer}.hook_resid_pre", hook_layers=[0, 1, 2] ) @@ -24,7 +24,7 @@ def test_activations_store_init_with_multiple_layers(ts_model: HookedTransformer # Verify backward compatibility - a single hook_layer should be converted to a list cfg_single = build_sae_cfg( - hook_name="blocks.{}.hook_resid_pre", + hook_name="blocks.{layer}.hook_resid_pre", hook_layer=1 ) @@ -36,7 +36,7 @@ def test_activations_store_get_activations_multiple_layers(ts_model: HookedTrans """Test that get_activations collects activations from all specified layers.""" # Setup with multiple layers cfg = build_sae_cfg( - 
hook_name="blocks.{}.hook_resid_pre", + hook_name="blocks.{layer}.hook_resid_pre", hook_layers=[0, 1, 2], context_size=5 ) @@ -77,7 +77,7 @@ def test_activations_store_get_buffer_multiple_layers(ts_model: HookedTransforme """Test buffer handling with multiple layers.""" # Setup with multiple layers cfg = build_sae_cfg( - hook_name="blocks.{}.hook_resid_pre", + hook_name="blocks.{layer}.hook_resid_pre", hook_layers=[0, 1, 2], context_size=5 ) @@ -98,7 +98,7 @@ def test_activations_store_next_batch_multiple_layers(ts_model: HookedTransforme """Test that next_batch returns correct batch shape with multiple layers.""" # Setup with multiple layers cfg = build_sae_cfg( - hook_name="blocks.{}.hook_resid_pre", + hook_name="blocks.{layer}.hook_resid_pre", hook_layers=[0, 1, 2], context_size=5, train_batch_size_tokens=10 @@ -115,7 +115,7 @@ def test_activations_store_normalization_multiple_layers(ts_model: HookedTransfo """Test normalization when using multiple layers.""" # Setup with normalization and multiple layers cfg = build_sae_cfg( - hook_name="blocks.{}.hook_resid_pre", + hook_name="blocks.{layer}.hook_resid_pre", hook_layers=[0, 1, 2], normalize_activations="expected_average_only_in", context_size=5 @@ -152,7 +152,7 @@ def test_backward_compatibility_single_layer(ts_model: HookedTransformer): # Create a store with single layer (new behavior) cfg_multi = build_sae_cfg( - hook_name="blocks.{}.hook_resid_pre", + hook_name="blocks.{layer}.hook_resid_pre", hook_layers=[0], context_size=5 ) @@ -174,9 +174,9 @@ def test_backward_compatibility_single_layer(ts_model: HookedTransformer): def test_mixed_hook_formats(ts_model: HookedTransformer): """Test that both formatted and non-formatted hook names work with multiple layers.""" - # Test with formatted hook name (with {}) + # Test with formatted hook name (with {layer}) cfg_formatted = build_sae_cfg( - hook_name="blocks.{}.hook_resid_pre", + hook_name="blocks.{layer}.hook_resid_pre", hook_layers=[0, 1], context_size=5 ) diff 
--git a/tests/training/test_crosscoder_sae.py b/tests/training/test_crosscoder_sae.py index e1443ad45..15843c22a 100644 --- a/tests/training/test_crosscoder_sae.py +++ b/tests/training/test_crosscoder_sae.py @@ -20,14 +20,14 @@ { "model_name": "tiny-stories-1M", "dataset_path": "roneneldan/TinyStories", - "hook_name": "blocks.{}.hook_resid_pre", + "hook_name": "blocks.{layer}.hook_resid_pre", "hook_layers": [1,2,3], "d_in": 64, }, { "model_name": "tiny-stories-1M", "dataset_path": "roneneldan/TinyStories", - "hook_name": "blocks.{}.hook_resid_pre", + "hook_name": "blocks.{layer}.hook_resid_pre", "hook_layers": [1,2,3], "d_in": 64, "normalize_sae_decoder": False, @@ -36,7 +36,7 @@ { "model_name": "tiny-stories-1M", "dataset_path": "apollo-research/roneneldan-TinyStories-tokenizer-gpt2", - "hook_name": "blocks.{}.hook_resid_pre", + "hook_name": "blocks.{layer}.hook_resid_pre", "hook_layers": [1,2,3], "d_in": 64, }, @@ -44,7 +44,7 @@ # { # "model_name": "tiny-stories-1M", # "dataset_path": "roneneldan/TinyStories", - # "hook_name": "blocks.{}.attn.hook_z", + # "hook_name": "blocks.{layer}.attn.hook_z", # "hook_layers": [1,2,3], # "d_in": 64, # }, @@ -315,6 +315,6 @@ def test_sae_save_and_load_from_pretrained_topk(tmp_path: Path) -> None: assert torch.allclose(sae_out_1, sae_out_2) def test_sae_get_name_returns_correct_name_from_cfg_vals() -> None: - cfg = build_sae_cfg(model_name="test_model", hook_name="blocks.{}.test_hook_name", d_sae=128, hook_layers=[1,2,3]) + cfg = build_sae_cfg(model_name="test_model", hook_name="blocks.{layer}.test_hook_name", d_sae=128, hook_layers=[1,2,3]) sae = CrosscoderSAE.from_dict(cfg.get_base_sae_cfg_dict()) - assert sae.get_name() == "sae_test_model_blocks.{}.test_hook_name_layers1,2,3_128" + assert sae.get_name() == "sae_test_model_blocks.{layer}.test_hook_name_layers1,2,3_128" diff --git a/tests/training/test_crosscoder_sae_trainer.py b/tests/training/test_crosscoder_sae_trainer.py index 376765187..b51a4f449 100644 --- 
a/tests/training/test_crosscoder_sae_trainer.py +++ b/tests/training/test_crosscoder_sae_trainer.py @@ -26,7 +26,7 @@ def cfg(): return build_sae_cfg( d_in=64, d_sae=128, - hook_name="blocks.{}.hook_mlp_out", + hook_name="blocks.{layer}.hook_mlp_out", hook_layers=[1,2,3], normalize_sae_decoder=False, scale_sparsity_penalty_by_decoder_norm=True, @@ -210,7 +210,7 @@ def test_train_sae_group_on_language_model__runs( checkpoint_path=str(checkpoint_dir), training_tokens=20, context_size=8, - hook_name="blocks.{}.hook_mlp_out", + hook_name="blocks.{layer}.hook_mlp_out", hook_layers=[1,2,3], normalize_sae_decoder=False, scale_sparsity_penalty_by_decoder_norm=True, diff --git a/tests/training/test_crosscoder_sae_training.py b/tests/training/test_crosscoder_sae_training.py index 2a5a92c12..082fa42e1 100644 --- a/tests/training/test_crosscoder_sae_training.py +++ b/tests/training/test_crosscoder_sae_training.py @@ -22,7 +22,7 @@ { "model_name": "tiny-stories-1M", "dataset_path": "roneneldan/TinyStories", - "hook_name": "blocks.{}.hook_resid_pre", + "hook_name": "blocks.{layer}.hook_resid_pre", "hook_layers": [1,2,3], "d_in": 64, "normalize_sae_decoder": False, @@ -31,7 +31,7 @@ { "model_name": "tiny-stories-1M", "dataset_path": "apollo-research/roneneldan-TinyStories-tokenizer-gpt2", - "hook_name": "blocks.{}.hook_resid_pre", + "hook_name": "blocks.{layer}.hook_resid_pre", "hook_layers": [1,2,3], "d_in": 64, "normalize_sae_decoder": False, @@ -40,7 +40,7 @@ { "model_name": "tiny-stories-1M", "dataset_path": "apollo-research/roneneldan-TinyStories-tokenizer-gpt2", - "hook_name": "blocks.{}.hook_resid_pre", + "hook_name": "blocks.{layer}.hook_resid_pre", "hook_layers": [1,2,3], "d_in": 64, "normalize_activations": "constant_norm_rescale", From f8107b0574c238ebfc79d35c89f6ab0304feceab Mon Sep 17 00:00:00 2001 From: Michael Behr Date: Mon, 28 Apr 2025 18:26:40 -0400 Subject: [PATCH 39/61] fix evals_test --- sae_lens/crosscoder_sae.py | 5 +++++ sae_lens/evals.py | 2 +- 2 files 
changed, 6 insertions(+), 1 deletion(-) diff --git a/sae_lens/crosscoder_sae.py b/sae_lens/crosscoder_sae.py index 6c18270c2..ccb6b02a8 100644 --- a/sae_lens/crosscoder_sae.py +++ b/sae_lens/crosscoder_sae.py @@ -11,6 +11,11 @@ class CrosscoderSAEConfig(SAEConfig): hook_layers: list[int] = list + def __post_init__(self): + # For purposes of running the model, hook_layer is the last + # affected layer. + self.hook_layer = max(self.hook_layers) + def to_dict(self) -> dict[str, Any]: return super().to_dict() | { "hook_layers": self.hook_layers, diff --git a/sae_lens/evals.py b/sae_lens/evals.py index 9bb2b8786..74c45518e 100644 --- a/sae_lens/evals.py +++ b/sae_lens/evals.py @@ -435,7 +435,7 @@ def get_sparsity_and_variance_metrics( batch_tokens, prepend_bos=False, names_filter=hook_names, - stop_at_layer=max(sae.cfg.hook_layers) + 1, + stop_at_layer=sae.cfg.hook_layer + 1, **model_kwargs, ) From 41dbb5b99fa2c0e181c60bd4f37d53c0948b2d23 Mon Sep 17 00:00:00 2001 From: Michael Behr Date: Mon, 28 Apr 2025 18:50:52 -0400 Subject: [PATCH 40/61] fix activations store test --- sae_lens/training/activations_store.py | 28 ++++++++++++++++++-------- 1 file changed, 20 insertions(+), 8 deletions(-) diff --git a/sae_lens/training/activations_store.py b/sae_lens/training/activations_store.py index 3b5f31251..0cc43a498 100644 --- a/sae_lens/training/activations_store.py +++ b/sae_lens/training/activations_store.py @@ -168,6 +168,7 @@ def from_sae( d_in=sae.cfg.d_in, hook_name=sae.cfg.hook_name, hook_layer=sae.cfg.hook_layer, + # TODO(mkbehr): set hook_layers if set in sae config hook_head_index=sae.cfg.hook_head_index, context_size=sae.cfg.context_size if context_size is None else context_size, prepend_bos=sae.cfg.prepend_bos, @@ -374,9 +375,8 @@ def _iterate_tokenized_sequences(self) -> Generator[torch.Tensor, None, None]: ) def hook_names(self) -> List[str]: - # TODO(mkbehr): better config setup than putting a magic - # string in the name - if "{layer}" in self.hook_name: + # 
TODO(mkbehr): better config setup than len(hook_layers) + if len(self.hook_layers) > 1: return [self.hook_name.format(layer=layer) for layer in self.hook_layers] return [self.hook_name] @@ -442,14 +442,24 @@ def apply_norm_scaling_factor(self, activations: torch.Tensor) -> torch.Tensor: raise ValueError( "estimated_norm_scaling_factor is not set, call set_norm_scaling_factor_if_needed() first" ) - return activations * self.estimated_norm_scaling_factor.unsqueeze(-1).to(activations.device) + # TODO(mkbehr): better config setup than len(hook_layers) + if len(self.hook_layers) > 1: + # TODO(mkbehr): set the device somewhere better + return activations * self.estimated_norm_scaling_factor.unsqueeze(-1).to(activations.device) + else: + return activations * self.estimated_norm_scaling_factor def unscale(self, activations: torch.Tensor) -> torch.Tensor: if self.estimated_norm_scaling_factor is None: raise ValueError( "estimated_norm_scaling_factor is not set, call set_norm_scaling_factor_if_needed() first" ) - return activations / self.estimated_norm_scaling_factor.unsqueeze(-1).to(activations.device) + # TODO(mkbehr): better config setup than len(hook_layers) + if len(self.hook_layers) > 1: + # TODO(mkbehr): set the device somewhere better + return activations / self.estimated_norm_scaling_factor.unsqueeze(-1).to(activations.device) + else: + return activations / self.estimated_norm_scaling_factor def get_norm_scaling_factor(self, activations: torch.Tensor) -> torch.Tensor: return (self.d_in**0.5) / activations.norm(dim=-1).mean() @@ -469,9 +479,11 @@ def estimate_norm_scaling_factor(self, n_batches_for_norm_estimate: int = int(1e self.estimated_norm_scaling_factor = None norms_per_batch[:, batch_i] = acts.norm(dim=-1).mean(dim=0) mean_norm = norms_per_batch.mean(dim=1) - # TODO(mkbehr): make this a float in single-layer case for - # backwards compatibility - return (np.sqrt(self.d_in) / mean_norm) + # TODO(mkbehr): better config setup than len(hook_layers) + if 
len(self.hook_layers) > 1: + return (np.sqrt(self.d_in) / mean_norm) + else: + return (np.sqrt(self.d_in) / mean_norm.item()) def shuffle_input_dataset(self, seed: int, buffer_size: int = 1): """ From b4e6c0d5c9025ffd9c097a87c5776c826d51085d Mon Sep 17 00:00:00 2001 From: Michael Behr Date: Mon, 28 Apr 2025 18:53:33 -0400 Subject: [PATCH 41/61] fix test_cache_activations_runner --- tests/training/test_cache_activations_runner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/training/test_cache_activations_runner.py b/tests/training/test_cache_activations_runner.py index b6309e8f5..7d070e962 100644 --- a/tests/training/test_cache_activations_runner.py +++ b/tests/training/test_cache_activations_runner.py @@ -271,7 +271,7 @@ def test_cache_activations_runner_with_incorrect_d_in(tmp_path: Path): runner = CacheActivationsRunner(wrong_d_in_cfg) with pytest.raises( RuntimeError, - match=r"The expanded size of the tensor \(513\) must match the existing size \(512\) at non-singleton dimension 2.", + match=r"The expanded size of the tensor \(513\) must match the existing size \(512\) at non-singleton dimension 3.", ): runner.run() From 7c090eb4c6bfacfe30454346506905c34a4a8b7a Mon Sep 17 00:00:00 2001 From: Michael Behr Date: Mon, 28 Apr 2025 19:09:46 -0400 Subject: [PATCH 42/61] fix test_crosscoder_sae --- tests/training/test_crosscoder_sae.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/tests/training/test_crosscoder_sae.py b/tests/training/test_crosscoder_sae.py index 15843c22a..aece7280e 100644 --- a/tests/training/test_crosscoder_sae.py +++ b/tests/training/test_crosscoder_sae.py @@ -154,7 +154,7 @@ def test_sae_fold_w_dec_norm_all_architectures(architecture: str): @torch.no_grad() def test_sae_fold_norm_scaling_factor(cfg: LanguageModelSAERunnerConfig): - norm_scaling_factor = 3.0 + norm_scaling_factor = torch.Tensor([2.0, 3.0, 4.0]) sae = CrosscoderSAE.from_dict(cfg.get_base_sae_cfg_dict()) # make sure 
b_dec and b_enc are not 0s @@ -167,13 +167,13 @@ def test_sae_fold_norm_scaling_factor(cfg: LanguageModelSAERunnerConfig): assert sae2.cfg.normalize_activations == "none" - assert torch.allclose(sae2.W_enc.data, sae.W_enc.data * norm_scaling_factor) + assert torch.allclose(sae2.W_enc.data, sae.W_enc.data * norm_scaling_factor.reshape((-1,1,1))) # we expect activations of features to differ by W_dec norm weights. # assume activations are already scaled activations = torch.randn(10, 4, len(cfg.hook_layers), cfg.d_in, device=cfg.device) # we divide to get the unscale activations - unscaled_activations = activations / norm_scaling_factor + unscaled_activations = activations / norm_scaling_factor.unsqueeze(-1) feature_activations_1 = sae.encode(activations) # with the scaling folded in, the unscaled activations should produce the same @@ -188,7 +188,7 @@ def test_sae_fold_norm_scaling_factor(cfg: LanguageModelSAERunnerConfig): torch.testing.assert_close(feature_activations_2, feature_activations_1) sae_out_1 = sae.decode(feature_activations_1) - sae_out_2 = norm_scaling_factor * sae2.decode(feature_activations_2) + sae_out_2 = norm_scaling_factor.unsqueeze(-1) * sae2.decode(feature_activations_2) # but actual outputs should be the same torch.testing.assert_close(sae_out_1, sae_out_2) @@ -200,7 +200,7 @@ def test_sae_fold_norm_scaling_factor_all_architectures(architecture: str): if architecture != "standard": pytest.xfail("TODO(mkbehr): support other architectures") cfg = build_sae_cfg(architecture=architecture, hook_layers=[1,2,3]) - norm_scaling_factor = 3.0 + norm_scaling_factor = torch.Tensor([2.0, 3.0, 4.0]) sae = CrosscoderSAE.from_dict(cfg.get_base_sae_cfg_dict()) # make sure all parameters are not 0s @@ -212,13 +212,13 @@ def test_sae_fold_norm_scaling_factor_all_architectures(architecture: str): assert sae2.cfg.normalize_activations == "none" - assert torch.allclose(sae2.W_enc.data, sae.W_enc.data * norm_scaling_factor) + assert torch.allclose(sae2.W_enc.data, 
sae.W_enc.data * norm_scaling_factor.reshape((-1,1,1))) # we expect activations of features to differ by W_dec norm weights. # assume activations are already scaled activations = torch.randn(10, 4, len(cfg.hook_layers), cfg.d_in, device=cfg.device) # we divide to get the unscale activations - unscaled_activations = activations / norm_scaling_factor + unscaled_activations = activations / norm_scaling_factor.unsqueeze(-1) feature_activations_1 = sae.encode(activations) # with the scaling folded in, the unscaled activations should produce the same @@ -233,7 +233,7 @@ def test_sae_fold_norm_scaling_factor_all_architectures(architecture: str): torch.testing.assert_close(feature_activations_2, feature_activations_1) sae_out_1 = sae.decode(feature_activations_1) - sae_out_2 = norm_scaling_factor * sae2.decode(feature_activations_2) + sae_out_2 = norm_scaling_factor.unsqueeze(-1) * sae2.decode(feature_activations_2) # but actual outputs should be the same torch.testing.assert_close(sae_out_1, sae_out_2) @@ -317,4 +317,4 @@ def test_sae_save_and_load_from_pretrained_topk(tmp_path: Path) -> None: def test_sae_get_name_returns_correct_name_from_cfg_vals() -> None: cfg = build_sae_cfg(model_name="test_model", hook_name="blocks.{layer}.test_hook_name", d_sae=128, hook_layers=[1,2,3]) sae = CrosscoderSAE.from_dict(cfg.get_base_sae_cfg_dict()) - assert sae.get_name() == "sae_test_model_blocks.{layer}.test_hook_name_layers1,2,3_128" + assert sae.get_name() == "sae_test_model_blocks.{layer}.test_hook_name_layers1_2_3_128" From eff955dd8ac08cc424339a939ab5460c9162a2b2 Mon Sep 17 00:00:00 2001 From: Michael Behr Date: Mon, 28 Apr 2025 19:14:03 -0400 Subject: [PATCH 43/61] fix crosscoder sae trainer train step log dict --- sae_lens/training/crosscoder_sae_trainer.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/sae_lens/training/crosscoder_sae_trainer.py b/sae_lens/training/crosscoder_sae_trainer.py index a87d2e84d..d44900b30 100644 --- 
a/sae_lens/training/crosscoder_sae_trainer.py +++ b/sae_lens/training/crosscoder_sae_trainer.py @@ -75,6 +75,8 @@ def _build_train_step_log_dict( ) -> dict[str, Any]: log_dict = super()._build_train_step_log_dict(output, n_training_tokens) + sae_in = output.sae_in + sae_out = output.sae_out per_token_l2_loss = (sae_out - sae_in).pow(2).sum(dim=(-2, -1)).squeeze() total_variance = (sae_in - sae_in.mean(0)).pow(2).sum((-2, -1)) explained_variance = 1 - per_token_l2_loss / total_variance From 406202e95bc57f72913556d8e4f8aa6f9109b538 Mon Sep 17 00:00:00 2001 From: Michael Behr Date: Sat, 3 May 2025 14:06:01 -0400 Subject: [PATCH 44/61] Configure crosscoder decoder init norms --- sae_lens/training/training_crosscoder_sae.py | 8 ++------ scripts/global_acausal_crosscoder.py | 1 + tests/training/test_training_crosscoder_sae.py | 16 ++++++++++++++++ 3 files changed, 19 insertions(+), 6 deletions(-) diff --git a/sae_lens/training/training_crosscoder_sae.py b/sae_lens/training/training_crosscoder_sae.py index 860c2afe7..3caa4c6be 100644 --- a/sae_lens/training/training_crosscoder_sae.py +++ b/sae_lens/training/training_crosscoder_sae.py @@ -66,6 +66,7 @@ def from_sae_runner_config( decoder_orthogonal_init=cfg.decoder_orthogonal_init, mse_loss_normalization=cfg.mse_loss_normalization, decoder_heuristic_init=cfg.decoder_heuristic_init, + decoder_heuristic_init_norm=cfg.decoder_heuristic_init_norm, init_encoder_as_decoder_transpose=cfg.init_encoder_as_decoder_transpose, scale_sparsity_penalty_by_decoder_norm=cfg.scale_sparsity_penalty_by_decoder_norm, normalize_activations=cfg.normalize_activations, @@ -242,12 +243,7 @@ def initialize_weights_complex(self): self.cfg.d_sae, *self.input_shape(), dtype=self.dtype, device=self.device ) ) - self.W_dec.data = ( - self.W_dec.data - / self.W_dec.data.norm(dim=-1, keepdim=True) - * 0.1 # TODO(mkbehr): make norm configurable - ) - self.initialize_decoder_norm_constant_norm() + 
self.initialize_decoder_norm_constant_norm(self.cfg.decoder_heuristic_init_norm) # Then we initialize the encoder weights (either as the transpose of decoder or not) if self.cfg.init_encoder_as_decoder_transpose: diff --git a/scripts/global_acausal_crosscoder.py b/scripts/global_acausal_crosscoder.py index bc8719c33..de4eb47a8 100644 --- a/scripts/global_acausal_crosscoder.py +++ b/scripts/global_acausal_crosscoder.py @@ -109,6 +109,7 @@ b_dec_init_method="zeros", normalize_sae_decoder=False, decoder_heuristic_init=True, + decoder_heuristic_init_norm=0.1, init_encoder_as_decoder_transpose=True, # Optimizer lr=learning_rate, diff --git a/tests/training/test_training_crosscoder_sae.py b/tests/training/test_training_crosscoder_sae.py index 890884957..b32817398 100644 --- a/tests/training/test_training_crosscoder_sae.py +++ b/tests/training/test_training_crosscoder_sae.py @@ -57,3 +57,19 @@ def test_TrainingCrosscoderSAE_encode_returns_same_value_as_encode_with_hidden_p encode_out = sae.encode(x) encode_with_hidden_pre_out = sae.encode_with_hidden_pre_fn(x)[0] assert torch.allclose(encode_out, encode_with_hidden_pre_out) + +def test_TrainingCrosscoderSAE_heuristic_init(): + cfg = build_crosscoder_sae_cfg( + d_in=3, + d_sae=5, + decoder_heuristic_init=True, + decoder_heuristic_init_norm=0.2, + ) + sae = TrainingCrosscoderSAE( + TrainingCrosscoderSAEConfig.from_sae_runner_config(cfg), + use_error_term=True) + print(sae.W_dec.norm(dim=0)) + print(sae.W_dec.norm(dim=1)) + print(sae.W_dec.norm(dim=2)) + torch.testing.assert_close(sae.W_dec.norm(dim=[1,2]), + torch.full((5,), 0.2)) From 8b397d7e3354edab12d46332444c7989dd57340e Mon Sep 17 00:00:00 2001 From: Michael Behr Date: Sun, 4 May 2025 14:52:59 -0400 Subject: [PATCH 45/61] Config rework (most tests fail) --- sae_lens/config.py | 14 +++------ sae_lens/crosscoder_sae.py | 23 ++------------ sae_lens/training/activations_store.py | 42 +++++++++++--------------- 3 files changed, 26 insertions(+), 53 deletions(-) diff --git 
a/sae_lens/config.py b/sae_lens/config.py index d4cee5049..5702b9de7 100644 --- a/sae_lens/config.py +++ b/sae_lens/config.py @@ -62,8 +62,8 @@ class LanguageModelSAERunnerConfig: architecture (str): The architecture to use, either "standard", "gated", "topk", or "jumprelu". model_name (str): The name of the model to use. This should be the name of the model in the Hugging Face model hub. model_class_name (str): The name of the class of the model to use. This should be either `HookedTransformer` or `HookedMamba`. - TODO(mkbehr): update hook name param docs for multilayer case hook_name (str): The name of the hook to use. This should be a valid TransformerLens hook. + hook_names (list[str], optional): The names of multiple hooks to use, in order of evaluation. If this is nonempty, a CrosscoderSAE will be used. hook_name should be a descriptive name, and hook_layer should be the index of the last layer to hook. hook_eval (str): NOT CURRENTLY IN USE. The name of the hook to use for evaluation. hook_layer (int): The index of the layer to hook. Used to stop forward passes early and speed up processing. hook_head_index (int, optional): When the hook if for an activatio with a head index, we can specify a specific head to use here. 
@@ -148,9 +148,9 @@ class LanguageModelSAERunnerConfig: model_name: str = "gelu-2l" model_class_name: str = "HookedTransformer" hook_name: str = "blocks.0.hook_mlp_out" + hook_names: list[str] = list hook_eval: str = "NOT_IN_USE" hook_layer: int = 0 - hook_layers: list[int] | None = None hook_head_index: int | None = None dataset_path: str = "" dataset_trust_remote_code: bool = True @@ -446,6 +446,7 @@ def get_base_sae_cfg_dict(self) -> dict[str, Any]: "device": self.device, "model_name": self.model_name, "hook_name": self.hook_name, + "hook_names": self.hook_names, "hook_layer": self.hook_layer, "hook_head_index": self.hook_head_index, "activation_fn_str": self.activation_fn, @@ -462,9 +463,6 @@ def get_base_sae_cfg_dict(self) -> dict[str, Any]: "seqpos_slice": self.seqpos_slice, } - if self.hook_layers is not None: - cfg_dict["hook_layers"] = self.hook_layers - return cfg_dict def get_training_sae_cfg_dict(self) -> dict[str, Any]: @@ -528,6 +526,7 @@ class CacheActivationsRunnerConfig: model_name (str): The name of the model to use. model_batch_size (int): How many prompts are in the batch of the language model when generating activations. hook_name (str): The name of the hook to use. + hook_names (list[str], optional): The names of multiple hooks to use, in order of evaluation. If this is nonempty, a CrosscoderSAE will be used. hook_layer (int): The layer of the final hook. Currently only support a single hook, so this should be the same as hook_name. d_in (int): Dimension of the model. total_training_tokens (int): Total number of tokens to process. 
@@ -561,8 +560,8 @@ class CacheActivationsRunnerConfig: hook_layer: int d_in: int training_tokens: int - hook_layers: list[int] | None = None + hook_names: list[str] = list context_size: int = -1 # Required if dataset is not tokenized model_class_name: str = "HookedTransformer" # defaults to "activations/{dataset}/{model}/{hook_name} @@ -617,9 +616,6 @@ def __post_init__(self): if self.new_cached_activations_path is None: hook_name_str = self.hook_name - if self.hook_layers is not None: - # TODO(mkbehr): ensure the multilayer activation path makes sense - hook_name_str = f"{self.hook_name}_layers_{'_'.join(str(l) for l in self.hook_layers)}" self.new_cached_activations_path = _default_cached_activations_path( # type: ignore self.dataset_path, self.model_name, hook_name_str, None ) diff --git a/sae_lens/crosscoder_sae.py b/sae_lens/crosscoder_sae.py index ccb6b02a8..bf14fefe2 100644 --- a/sae_lens/crosscoder_sae.py +++ b/sae_lens/crosscoder_sae.py @@ -9,25 +9,13 @@ @dataclass class CrosscoderSAEConfig(SAEConfig): - hook_layers: list[int] = list - - def __post_init__(self): - # For purposes of running the model, hook_layer is the last - # affected layer. - self.hook_layer = max(self.hook_layers) + hook_names: list[int] = list def to_dict(self) -> dict[str, Any]: return super().to_dict() | { - "hook_layers": self.hook_layers, + "hook_names": self.hook_names, } - def hook_names(self) -> List[str]: - # TODO(mkbehr): better config setup than putting a magic - # string in the name - return [self.hook_name.format(layer=layer) - for layer in self.hook_layers] - - class CrosscoderSAE(SAE): """ Sparse autoencoder that acts on multiple layers of activations. 
@@ -50,17 +38,12 @@ def __init__( if self.hook_z_reshaping_mode: raise NotImplementedError("TODO(mkbehr): support hook_z") - def get_name(self): - # TODO(mkbehr): think about the correct name - layers = '_'.join([str(l) for l in self.cfg.hook_layers]) - return f"sae_{self.cfg.model_name}_{self.cfg.hook_name}_layers{layers}_{self.cfg.d_sae}" - @classmethod def from_dict(cls, config_dict: dict[str, Any]) -> "CrosscoderSAE": return cls(CrosscoderSAEConfig.from_dict(config_dict)) def input_shape(self): - return (len(self.cfg.hook_layers), self.cfg.d_in) + return (len(self.cfg.hook_names), self.cfg.d_in) def encode_standard( diff --git a/sae_lens/training/activations_store.py b/sae_lens/training/activations_store.py index 0cc43a498..e1f56e387 100644 --- a/sae_lens/training/activations_store.py +++ b/sae_lens/training/activations_store.py @@ -45,8 +45,8 @@ class ActivationsStore: cached_activation_dataset: Dataset | None = None tokens_column: Literal["tokens", "input_ids", "text", "problem"] hook_name: str + hook_names: list[str] hook_layer: int - hook_layers: list[int] hook_head_index: int | None _dataloader: Iterator[Any] | None = None _storage_buffer: torch.Tensor | None = None @@ -127,8 +127,8 @@ def from_config( dataset=override_dataset or cfg.dataset_path, streaming=cfg.streaming, hook_name=cfg.hook_name, + hook_names=cfg.hook_names, hook_layer=cfg.hook_layer, - hook_layers=cfg.hook_layers, hook_head_index=cfg.hook_head_index, context_size=cfg.context_size, d_in=cfg.d_in, @@ -167,8 +167,8 @@ def from_sae( dataset=sae.cfg.dataset_path if dataset is None else dataset, d_in=sae.cfg.d_in, hook_name=sae.cfg.hook_name, + hook_names=sae.cfg.hook_names, hook_layer=sae.cfg.hook_layer, - # TODO(mkbehr): set hook_layers if set in sae config hook_head_index=sae.cfg.hook_head_index, context_size=sae.cfg.context_size if context_size is None else context_size, prepend_bos=sae.cfg.prepend_bos, @@ -202,7 +202,7 @@ def __init__( normalize_activations: str, device: torch.device, 
dtype: str, - hook_layers: list[int] | None = None, + hook_names: list[str] | None = None, cached_activations_path: str | None = None, model_kwargs: dict[str, Any] | None = None, autocast_lm: bool = False, @@ -235,8 +235,8 @@ def __init__( ) self.hook_name = hook_name + self.hook_names = hook_names self.hook_layer = hook_layer - self.hook_layers = hook_layers or [hook_layer] self.hook_head_index = hook_head_index self.context_size = context_size self.d_in = d_in @@ -374,13 +374,6 @@ def _iterate_tokenized_sequences(self) -> Generator[torch.Tensor, None, None]: ), ) - def hook_names(self) -> List[str]: - # TODO(mkbehr): better config setup than len(hook_layers) - if len(self.hook_layers) > 1: - return [self.hook_name.format(layer=layer) - for layer in self.hook_layers] - return [self.hook_name] - def load_cached_activation_dataset(self) -> Dataset | None: """ Load the cached activation dataset from disk. @@ -403,7 +396,7 @@ def load_cached_activation_dataset(self) -> Dataset | None: # Actual code activations_dataset = datasets.load_from_disk(self.cached_activations_path) # TODO(mkbehr): test multiple layers - columns = self.hook_names() + columns = self.hook_names or [self.hook_name] if "token_ids" in activations_dataset.column_names: columns.append("token_ids") activations_dataset.set_format( @@ -442,8 +435,9 @@ def apply_norm_scaling_factor(self, activations: torch.Tensor) -> torch.Tensor: raise ValueError( "estimated_norm_scaling_factor is not set, call set_norm_scaling_factor_if_needed() first" ) - # TODO(mkbehr): better config setup than len(hook_layers) - if len(self.hook_layers) > 1: + # Norm scaling factor is a float in the single-layer case, and + # a tensor in the multilayer case. 
+ if self.hook_names: # TODO(mkbehr): set the device somewhere better return activations * self.estimated_norm_scaling_factor.unsqueeze(-1).to(activations.device) else: @@ -454,8 +448,7 @@ def unscale(self, activations: torch.Tensor) -> torch.Tensor: raise ValueError( "estimated_norm_scaling_factor is not set, call set_norm_scaling_factor_if_needed() first" ) - # TODO(mkbehr): better config setup than len(hook_layers) - if len(self.hook_layers) > 1: + if self.hook_names: # TODO(mkbehr): set the device somewhere better return activations / self.estimated_norm_scaling_factor.unsqueeze(-1).to(activations.device) else: @@ -468,7 +461,7 @@ def get_norm_scaling_factor(self, activations: torch.Tensor) -> torch.Tensor: def estimate_norm_scaling_factor(self, n_batches_for_norm_estimate: int = int(1e3)): # TODO(mkbehr): test multilayer norm scaling, probably fix saving? norms_per_batch = torch.empty( - len(self.hook_layers), n_batches_for_norm_estimate, + len(self.hook_names) or 1, n_batches_for_norm_estimate, device=self.device) for batch_i in tqdm( range(n_batches_for_norm_estimate), desc="Estimating norm scaling factor" @@ -479,8 +472,9 @@ def estimate_norm_scaling_factor(self, n_batches_for_norm_estimate: int = int(1e self.estimated_norm_scaling_factor = None norms_per_batch[:, batch_i] = acts.norm(dim=-1).mean(dim=0) mean_norm = norms_per_batch.mean(dim=1) - # TODO(mkbehr): better config setup than len(hook_layers) - if len(self.hook_layers) > 1: + # Norm scaling factor is a float in the single-layer case, and + # a tensor in the multilayer case. 
+ if self.hook_names: return (np.sqrt(self.d_in) / mean_norm) else: return (np.sqrt(self.d_in) / mean_norm.item()) @@ -563,8 +557,8 @@ def get_activations(self, batch_tokens: torch.Tensor): else: autocast_if_enabled = contextlib.nullcontext() - hook_names = self.hook_names() - stop_at_layer = max(self.hook_layers) + 1 + hook_names = self.hook_names or [self.hook_name] + stop_at_layer = self.hook_layer + 1 with autocast_if_enabled: layerwise_activations_cache = self.model.run_with_cache( @@ -625,7 +619,7 @@ def _load_buffer_from_cached( raises StopIteration """ assert self.cached_activation_dataset is not None - hook_names = self.hook_names() + hook_names = self.hook_names or [self.hook_name] if not set(hook_names).issubset(self.cached_activation_dataset.column_names): raise ValueError( f"Missing columns in dataset. Expected {hook_names}, " @@ -695,7 +689,7 @@ def get_buffer( batch_size = self.store_batch_size_prompts d_in = self.d_in total_size = batch_size * n_batches_in_buffer - num_layers = len(self.hook_layers) + num_layers = len(self.hook_names) or 1 if self.cached_activation_dataset is not None: return self._load_buffer_from_cached( From 190d02265fd8c3b98371dcf6359f10df9833fce6 Mon Sep 17 00:00:00 2001 From: Michael Behr Date: Sun, 4 May 2025 15:21:56 -0400 Subject: [PATCH 46/61] fix test_activations_store_multilayer --- tests/helpers.py | 20 ++- .../test_activations_store_multilayer.py | 115 +++++------------- 2 files changed, 51 insertions(+), 84 deletions(-) diff --git a/tests/helpers.py b/tests/helpers.py index b37c03b07..2e3999698 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -16,7 +16,7 @@ class LanguageModelSAERunnerConfigDict(TypedDict, total=False): model_name: str hook_name: str hook_layer: int - hook_layers: list[int] | None + hook_names: list[int] | None hook_head_index: int | None dataset_path: str dataset_trust_remote_code: bool @@ -54,8 +54,8 @@ def build_sae_cfg(**kwargs: Any) -> LanguageModelSAERunnerConfig: mock_config_dict: 
LanguageModelSAERunnerConfigDict = { "model_name": TINYSTORIES_MODEL, "hook_name": "blocks.0.hook_mlp_out", + "hook_names": [], "hook_layer": 0, - "hook_layers": None, "hook_head_index": None, # use a small, non-streaming dataset for testing. Huggingface gives too many requests errors otherwise. "dataset_path": NEEL_NANDA_C4_10K_DATASET, @@ -97,6 +97,22 @@ def build_sae_cfg(**kwargs: Any) -> LanguageModelSAERunnerConfig: return mock_config +def build_multilayer_sae_cfg( + hook_name_template : str = "blocks.{layer}.hook_mlp_out", + hook_layers = [0,1,2], + **kwargs: Any) -> LanguageModelSAERunnerConfig: + hook_name = hook_name_template.format( + layer=f"layers_{min(hook_layers)}_through_{max(hook_layers)}" + ) + hook_names = [hook_name_template.format(layer=layer) for layer in hook_layers] + return build_sae_cfg( + **({ + "hook_name": hook_name, + "hook_names": hook_names, + "hook_layer": max(hook_layers), + } + | kwargs)) + MODEL_CACHE: dict[str, HookedTransformer] = {} diff --git a/tests/training/test_activations_store_multilayer.py b/tests/training/test_activations_store_multilayer.py index b4cd85e3e..77c5a909d 100644 --- a/tests/training/test_activations_store_multilayer.py +++ b/tests/training/test_activations_store_multilayer.py @@ -6,37 +6,41 @@ from transformer_lens import HookedTransformer from sae_lens.training.activations_store import ActivationsStore -from tests.helpers import build_sae_cfg, load_model_cached +from tests.helpers import build_sae_cfg, build_multilayer_sae_cfg, load_model_cached def test_activations_store_init_with_multiple_layers(ts_model: HookedTransformer): """Test initialization with a list of layers instead of a single layer.""" # Initialize with multiple layers - cfg = build_sae_cfg( - hook_name="blocks.{layer}.hook_resid_pre", + cfg = build_multilayer_sae_cfg( + hook_name_template="blocks.{layer}.hook_resid_pre", hook_layers=[0, 1, 2] ) activation_store = ActivationsStore.from_config(ts_model, cfg) - # Check that the hook layers are 
correctly stored - assert activation_store.hook_layers == [0, 1, 2] + assert activation_store.hook_names == [ + "blocks.0.hook_resid_pre", + "blocks.1.hook_resid_pre", + "blocks.2.hook_resid_pre", + ] - # Verify backward compatibility - a single hook_layer should be converted to a list - cfg_single = build_sae_cfg( - hook_name="blocks.{layer}.hook_resid_pre", - hook_layer=1 + cfg_single = build_multilayer_sae_cfg( + hook_name_template="blocks.{layer}.hook_resid_pre", + hook_layers=[1] ) single_layer_store = ActivationsStore.from_config(ts_model, cfg_single) - assert single_layer_store.hook_layers == [1] + assert single_layer_store.hook_names == [ + "blocks.1.hook_resid_pre", + ] def test_activations_store_get_activations_multiple_layers(ts_model: HookedTransformer): """Test that get_activations collects activations from all specified layers.""" # Setup with multiple layers - cfg = build_sae_cfg( - hook_name="blocks.{layer}.hook_resid_pre", + cfg = build_multilayer_sae_cfg( + hook_name_template="blocks.{layer}.hook_resid_pre", hook_layers=[0, 1, 2], context_size=5 ) @@ -50,10 +54,10 @@ def test_activations_store_get_activations_multiple_layers(ts_model: HookedTrans # Check shape: [batch_size, context_size, num_layers, d_in] assert activations.shape == ( - activation_store.store_batch_size_prompts, - activation_store.context_size, - len(activation_store.hook_layers), - activation_store.d_in + cfg.store_batch_size_prompts, + cfg.context_size, + len(cfg.hook_names), + cfg.d_in ) # Verify that layers are in the correct order @@ -76,8 +80,8 @@ def test_activations_store_get_activations_multiple_layers(ts_model: HookedTrans def test_activations_store_get_buffer_multiple_layers(ts_model: HookedTransformer): """Test buffer handling with multiple layers.""" # Setup with multiple layers - cfg = build_sae_cfg( - hook_name="blocks.{layer}.hook_resid_pre", + cfg = build_multilayer_sae_cfg( + hook_name_template="blocks.{layer}.hook_resid_pre", hook_layers=[0, 1, 2], 
context_size=5 ) @@ -89,16 +93,16 @@ def test_activations_store_get_buffer_multiple_layers(ts_model: HookedTransforme buffer_activations, buffer_tokens = activation_store.get_buffer(n_batches_in_buffer=2) # Check shape: [(batch_size * context_size * n_batches), num_layers, d_in] - expected_size = activation_store.store_batch_size_prompts * activation_store.context_size * 2 - assert buffer_activations.shape == (expected_size, len(activation_store.hook_layers), activation_store.d_in) + expected_size = cfg.store_batch_size_prompts * cfg.context_size * 2 + assert buffer_activations.shape == (expected_size, len(cfg.hook_names), cfg.d_in) assert buffer_tokens.shape == (expected_size,) def test_activations_store_next_batch_multiple_layers(ts_model: HookedTransformer): """Test that next_batch returns correct batch shape with multiple layers.""" # Setup with multiple layers - cfg = build_sae_cfg( - hook_name="blocks.{layer}.hook_resid_pre", + cfg = build_multilayer_sae_cfg( + hook_name_template="blocks.{layer}.hook_resid_pre", hook_layers=[0, 1, 2], context_size=5, train_batch_size_tokens=10 @@ -108,14 +112,14 @@ def test_activations_store_next_batch_multiple_layers(ts_model: HookedTransforme activation_store = ActivationsStore.from_config(ts_model, cfg, override_dataset=dataset) batch = activation_store.next_batch() - assert batch.shape == (10, len(cfg.hook_layers), activation_store.d_in) + assert batch.shape == (10, len(cfg.hook_names), activation_store.d_in) @pytest.mark.skip("TODO(mkbehr): does activation need to be handled differently?") def test_activations_store_normalization_multiple_layers(ts_model: HookedTransformer): """Test normalization when using multiple layers.""" # Setup with normalization and multiple layers - cfg = build_sae_cfg( - hook_name="blocks.{layer}.hook_resid_pre", + cfg = build_multilayer_sae_cfg( + hook_name_template="blocks.{layer}.hook_resid_pre", hook_layers=[0, 1, 2], normalize_activations="expected_average_only_in", context_size=5 @@ 
-151,8 +155,8 @@ def test_backward_compatibility_single_layer(ts_model: HookedTransformer): single_store = ActivationsStore.from_config(ts_model, cfg_single, override_dataset=dataset) # Create a store with single layer (new behavior) - cfg_multi = build_sae_cfg( - hook_name="blocks.{layer}.hook_resid_pre", + cfg_multi = build_multilayer_sae_cfg( + hook_name_template="blocks.{layer}.hook_resid_pre", hook_layers=[0], context_size=5 ) @@ -165,58 +169,5 @@ def test_backward_compatibility_single_layer(ts_model: HookedTransformer): batch_tokens_multi = multi_store.get_batch_tokens() activations_multi = multi_store.get_activations(batch_tokens_multi) - # Check that activations have the same shape and values - assert activations_single.shape == activations_multi.shape - # Run with deterministic seed to ensure tokens are the same - if torch.allclose(batch_tokens_single, batch_tokens_multi): - assert torch.allclose(activations_single, activations_multi, atol=1e-5) - - -def test_mixed_hook_formats(ts_model: HookedTransformer): - """Test that both formatted and non-formatted hook names work with multiple layers.""" - # Test with formatted hook name (with {layer}) - cfg_formatted = build_sae_cfg( - hook_name="blocks.{layer}.hook_resid_pre", - hook_layers=[0, 1], - context_size=5 - ) - - # Test with non-formatted hook name - cfg_non_formatted = build_sae_cfg( - hook_name="blocks.0.hook_resid_pre", # Specific to layer 0 - hook_layers=[0], # Only layer 0 works with this hook - context_size=5 - ) - - dataset = Dataset.from_list([{"text": "hello world"}] * 10) - - # Both should initialize without errors - store_formatted = ActivationsStore.from_config( - ts_model, cfg_formatted, override_dataset=dataset - ) - store_non_formatted = ActivationsStore.from_config( - ts_model, cfg_non_formatted, override_dataset=dataset - ) - - # Both should be able to get activations - activations_formatted = store_formatted.get_activations( - store_formatted.get_batch_tokens() - ) - 
activations_non_formatted = store_non_formatted.get_activations( - store_non_formatted.get_batch_tokens() - ) - - # Check shapes - assert activations_formatted.shape == ( - store_formatted.store_batch_size_prompts, - store_formatted.context_size, - len(store_formatted.hook_layers), - store_formatted.d_in - ) - - assert activations_non_formatted.shape == ( - store_non_formatted.store_batch_size_prompts, - store_non_formatted.context_size, - len(store_non_formatted.hook_layers), - store_non_formatted.d_in - ) + torch.testing.assert_close(batch_tokens_single, batch_tokens_multi) + torch.testing.assert_close(activations_single, activations_multi) From c1149461f2242609da66b14cb19ae529c7860a95 Mon Sep 17 00:00:00 2001 From: Michael Behr Date: Sun, 4 May 2025 15:28:26 -0400 Subject: [PATCH 47/61] test_crosscoder_sae passes --- tests/training/test_crosscoder_sae.py | 56 +++++++++++++++------------ 1 file changed, 31 insertions(+), 25 deletions(-) diff --git a/tests/training/test_crosscoder_sae.py b/tests/training/test_crosscoder_sae.py index aece7280e..23f577a01 100644 --- a/tests/training/test_crosscoder_sae.py +++ b/tests/training/test_crosscoder_sae.py @@ -11,7 +11,7 @@ from sae_lens.config import LanguageModelSAERunnerConfig from sae_lens.crosscoder_sae import CrosscoderSAE from sae_lens.sae import _disable_hooks -from tests.helpers import ALL_ARCHITECTURES, build_sae_cfg +from tests.helpers import ALL_ARCHITECTURES, build_multilayer_sae_cfg # Define a new fixture for different configurations @@ -20,15 +20,17 @@ { "model_name": "tiny-stories-1M", "dataset_path": "roneneldan/TinyStories", - "hook_name": "blocks.{layer}.hook_resid_pre", - "hook_layers": [1,2,3], + "hook_name_template": "blocks.{layer}.hook_resid_pre", + "hook_layers": [0,1,2], "d_in": 64, + "normalize_sae_decoder": False, + "scale_sparsity_penalty_by_decoder_norm": True, }, { "model_name": "tiny-stories-1M", "dataset_path": "roneneldan/TinyStories", - "hook_name": "blocks.{layer}.hook_resid_pre", - 
"hook_layers": [1,2,3], + "hook_name_template": "blocks.{layer}.hook_resid_pre", + "hook_layers": [0,1,2], "d_in": 64, "normalize_sae_decoder": False, "scale_sparsity_penalty_by_decoder_norm": True, @@ -36,17 +38,21 @@ { "model_name": "tiny-stories-1M", "dataset_path": "apollo-research/roneneldan-TinyStories-tokenizer-gpt2", - "hook_name": "blocks.{layer}.hook_resid_pre", - "hook_layers": [1,2,3], + "hook_name_template": "blocks.{layer}.hook_resid_pre", + "hook_layers": [0,1,2], "d_in": 64, + "normalize_sae_decoder": False, + "scale_sparsity_penalty_by_decoder_norm": True, }, # TODO(mkbehr): hook_z support # { # "model_name": "tiny-stories-1M", # "dataset_path": "roneneldan/TinyStories", # "hook_name": "blocks.{layer}.attn.hook_z", - # "hook_layers": [1,2,3], + # "hook_layers": [0,1,2], # "d_in": 64, + # "normalize_sae_decoder": False, + # "scale_sparsity_penalty_by_decoder_norm": True, # }, ], ids=[ @@ -61,7 +67,7 @@ def cfg(request: pytest.FixtureRequest): Pytest fixture to create a mock instance of LanguageModelSAERunnerConfig. """ params = request.param - return build_sae_cfg(**params) + return build_multilayer_sae_cfg(**params) def test_crosscoder_sae_init(cfg: LanguageModelSAERunnerConfig): @@ -69,7 +75,7 @@ def test_crosscoder_sae_init(cfg: LanguageModelSAERunnerConfig): assert isinstance(sae, CrosscoderSAE) - n_layers = len(cfg.hook_layers) + n_layers = len(cfg.hook_names) assert sae.W_enc.shape == (n_layers, cfg.d_in, cfg.d_sae) assert sae.W_dec.shape == (cfg.d_sae, n_layers, cfg.d_in) assert sae.b_enc.shape == (cfg.d_sae,) @@ -94,7 +100,7 @@ def test_crosscoder_sae_fold_w_dec_norm(cfg: LanguageModelSAERunnerConfig): assert sae2.W_dec.norm(dim=[-2,-1]).mean().item() == pytest.approx(1.0, abs=1e-6) # we expect activations of features to differ by W_dec norm weights. 
- activations = torch.randn(10, 4, len(cfg.hook_layers), cfg.d_in, + activations = torch.randn(10, 4, len(cfg.hook_names), cfg.d_in, device=cfg.device) feature_activations_1 = sae.encode(activations) feature_activations_2 = sae2.encode(activations) @@ -118,7 +124,7 @@ def test_crosscoder_sae_fold_w_dec_norm(cfg: LanguageModelSAERunnerConfig): def test_sae_fold_w_dec_norm_all_architectures(architecture: str): if architecture != "standard": pytest.xfail("TODO(mkbehr): support other architectures") - cfg = build_sae_cfg(architecture=architecture, hook_layers=[1,2,3]) + cfg = build_multilayer_sae_cfg(architecture=architecture, hook_layers=[0,1,2]) sae = CrosscoderSAE.from_dict(cfg.get_base_sae_cfg_dict()) sae.turn_off_forward_pass_hook_z_reshaping() # hook z reshaping not needed here. @@ -134,7 +140,7 @@ def test_sae_fold_w_dec_norm_all_architectures(architecture: str): assert sae2.W_dec.norm(dim=[-2,-1]).mean().item() == pytest.approx(1.0, abs=1e-6) # we expect activations of features to differ by W_dec norm weights. - activations = torch.randn(10, 4, len(cfg.hook_layers), cfg.d_in, device=cfg.device) + activations = torch.randn(10, 4, len(cfg.hook_names), cfg.d_in, device=cfg.device) feature_activations_1 = sae.encode(activations) feature_activations_2 = sae2.encode(activations) @@ -158,7 +164,7 @@ def test_sae_fold_norm_scaling_factor(cfg: LanguageModelSAERunnerConfig): sae = CrosscoderSAE.from_dict(cfg.get_base_sae_cfg_dict()) # make sure b_dec and b_enc are not 0s - sae.b_dec.data = torch.randn(len(cfg.hook_layers), cfg.d_in, device=cfg.device) + sae.b_dec.data = torch.randn(len(cfg.hook_names), cfg.d_in, device=cfg.device) sae.b_enc.data = torch.randn(cfg.d_sae, device=cfg.device) # type: ignore sae.turn_off_forward_pass_hook_z_reshaping() # hook z reshaping not needed here. @@ -171,7 +177,7 @@ def test_sae_fold_norm_scaling_factor(cfg: LanguageModelSAERunnerConfig): # we expect activations of features to differ by W_dec norm weights. 
# assume activations are already scaled - activations = torch.randn(10, 4, len(cfg.hook_layers), cfg.d_in, device=cfg.device) + activations = torch.randn(10, 4, len(cfg.hook_names), cfg.d_in, device=cfg.device) # we divide to get the unscale activations unscaled_activations = activations / norm_scaling_factor.unsqueeze(-1) @@ -199,7 +205,7 @@ def test_sae_fold_norm_scaling_factor(cfg: LanguageModelSAERunnerConfig): def test_sae_fold_norm_scaling_factor_all_architectures(architecture: str): if architecture != "standard": pytest.xfail("TODO(mkbehr): support other architectures") - cfg = build_sae_cfg(architecture=architecture, hook_layers=[1,2,3]) + cfg = build_multilayer_sae_cfg(architecture=architecture, hook_layers=[0,1,2]) norm_scaling_factor = torch.Tensor([2.0, 3.0, 4.0]) sae = CrosscoderSAE.from_dict(cfg.get_base_sae_cfg_dict()) @@ -216,7 +222,7 @@ def test_sae_fold_norm_scaling_factor_all_architectures(architecture: str): # we expect activations of features to differ by W_dec norm weights. 
# assume activations are already scaled - activations = torch.randn(10, 4, len(cfg.hook_layers), cfg.d_in, device=cfg.device) + activations = torch.randn(10, 4, len(cfg.hook_names), cfg.d_in, device=cfg.device) # we divide to get the unscale activations unscaled_activations = activations / norm_scaling_factor.unsqueeze(-1) @@ -239,7 +245,7 @@ def test_sae_fold_norm_scaling_factor_all_architectures(architecture: str): torch.testing.assert_close(sae_out_1, sae_out_2) def test_sae_save_and_load_from_pretrained(tmp_path: Path) -> None: - cfg = build_sae_cfg(hook_layers=[1,2,3]) + cfg = build_multilayer_sae_cfg(hook_layers=[0,1,2]) model_path = str(tmp_path) sae = CrosscoderSAE.from_dict(cfg.get_base_sae_cfg_dict()) sae_state_dict = sae.state_dict() @@ -258,14 +264,14 @@ def test_sae_save_and_load_from_pretrained(tmp_path: Path) -> None: sae_loaded_state_dict[key], ) - sae_in = torch.randn(10, len(cfg.hook_layers), cfg.d_in, device=cfg.device) + sae_in = torch.randn(10, len(cfg.hook_names), cfg.d_in, device=cfg.device) sae_out_1 = sae(sae_in) sae_out_2 = sae_loaded(sae_in) assert torch.allclose(sae_out_1, sae_out_2) @pytest.mark.xfail(reason="TODO(mkbehr): support other architectures") def test_sae_save_and_load_from_pretrained_gated(tmp_path: Path) -> None: - cfg = build_sae_cfg(architecture="gated", hook_layers=[1,2,3]) + cfg = build_multilayer_sae_cfg(architecture="gated", hook_layers=[0,1,2]) model_path = str(tmp_path) sae = CrosscoderSAE.from_dict(cfg.get_base_sae_cfg_dict()) sae_state_dict = sae.state_dict() @@ -284,13 +290,13 @@ def test_sae_save_and_load_from_pretrained_gated(tmp_path: Path) -> None: sae_loaded_state_dict[key], ) - sae_in = torch.randn(10, len(cfg.hook_layers), cfg.d_in, device=cfg.device) + sae_in = torch.randn(10, len(cfg.hook_names), cfg.d_in, device=cfg.device) sae_out_1 = sae(sae_in) sae_out_2 = sae_loaded(sae_in) assert torch.allclose(sae_out_1, sae_out_2) def test_sae_save_and_load_from_pretrained_topk(tmp_path: Path) -> None: - cfg = 
build_sae_cfg(activation_fn_kwargs={"k": 30}, hook_layers=[1,2,3]) + cfg = build_multilayer_sae_cfg(activation_fn_kwargs={"k": 30}, hook_layers=[0,1,2]) model_path = str(tmp_path) sae = CrosscoderSAE.from_dict(cfg.get_base_sae_cfg_dict()) sae_state_dict = sae.state_dict() @@ -309,12 +315,12 @@ def test_sae_save_and_load_from_pretrained_topk(tmp_path: Path) -> None: sae_loaded_state_dict[key], ) - sae_in = torch.randn(10, len(cfg.hook_layers), cfg.d_in, device=cfg.device) + sae_in = torch.randn(10, len(cfg.hook_names), cfg.d_in, device=cfg.device) sae_out_1 = sae(sae_in) sae_out_2 = sae_loaded(sae_in) assert torch.allclose(sae_out_1, sae_out_2) def test_sae_get_name_returns_correct_name_from_cfg_vals() -> None: - cfg = build_sae_cfg(model_name="test_model", hook_name="blocks.{layer}.test_hook_name", d_sae=128, hook_layers=[1,2,3]) + cfg = build_multilayer_sae_cfg(model_name="test_model", hook_name_template="blocks.{layer}.test_hook_name", d_sae=128, hook_layers=[0,1,2]) sae = CrosscoderSAE.from_dict(cfg.get_base_sae_cfg_dict()) - assert sae.get_name() == "sae_test_model_blocks.{layer}.test_hook_name_layers1_2_3_128" + assert sae.get_name() == "sae_test_model_blocks.layers_0_through_2.test_hook_name_128" From cf82b584063beef52b1fdd173f2b159817d01d47 Mon Sep 17 00:00:00 2001 From: Michael Behr Date: Sun, 4 May 2025 15:42:35 -0400 Subject: [PATCH 48/61] training/test*crosscoder* passes --- sae_lens/training/training_crosscoder_sae.py | 4 +-- tests/training/test_crosscoder_sae_trainer.py | 14 ++++---- .../training/test_crosscoder_sae_training.py | 18 +++++------ .../training/test_training_crosscoder_sae.py | 32 +++++++++---------- 4 files changed, 34 insertions(+), 34 deletions(-) diff --git a/sae_lens/training/training_crosscoder_sae.py b/sae_lens/training/training_crosscoder_sae.py index 3caa4c6be..93fcda17b 100644 --- a/sae_lens/training/training_crosscoder_sae.py +++ b/sae_lens/training/training_crosscoder_sae.py @@ -45,8 +45,8 @@ def from_sae_runner_config( 
device=cfg.device, model_name=cfg.model_name, hook_name=cfg.hook_name, + hook_names=cfg.hook_names, hook_layer=cfg.hook_layer, - hook_layers=cfg.hook_layers, hook_head_index=cfg.hook_head_index, activation_fn_str=cfg.activation_fn, activation_fn_kwargs=cfg.activation_fn_kwargs, @@ -87,7 +87,7 @@ def to_dict(self) -> dict[str, Any]: def get_base_sae_cfg_dict(self) -> dict[str, Any]: return (TrainingSAEConfig.get_base_sae_cfg_dict(self) - | { "hook_layers": self.hook_layers }) + | { "hook_names": self.hook_names }) class TrainingCrosscoderSAE(CrosscoderSAE, TrainingSAE): # TODO(mkbehr) future implementation diff --git a/tests/training/test_crosscoder_sae_trainer.py b/tests/training/test_crosscoder_sae_trainer.py index b51a4f449..ad7ff7a50 100644 --- a/tests/training/test_crosscoder_sae_trainer.py +++ b/tests/training/test_crosscoder_sae_trainer.py @@ -18,16 +18,16 @@ _update_sae_lens_training_version, ) from sae_lens.training.training_crosscoder_sae import TrainingCrosscoderSAE -from tests.helpers import TINYSTORIES_MODEL, build_sae_cfg, load_model_cached +from tests.helpers import TINYSTORIES_MODEL, build_multilayer_sae_cfg, load_model_cached @pytest.fixture def cfg(): - return build_sae_cfg( + return build_multilayer_sae_cfg( d_in=64, d_sae=128, - hook_name="blocks.{layer}.hook_mlp_out", - hook_layers=[1,2,3], + hook_name_template="blocks.{layer}.hook_mlp_out", + hook_layers=[0,1,2], normalize_sae_decoder=False, scale_sparsity_penalty_by_decoder_norm=True, ) @@ -206,12 +206,12 @@ def test_train_sae_group_on_language_model__runs( tmp_path: Path, ) -> None: checkpoint_dir = tmp_path / "checkpoint" - cfg = build_sae_cfg( + cfg = build_multilayer_sae_cfg( checkpoint_path=str(checkpoint_dir), training_tokens=20, context_size=8, - hook_name="blocks.{layer}.hook_mlp_out", - hook_layers=[1,2,3], + hook_name_template="blocks.{layer}.hook_mlp_out", + hook_layers=[0,1,2], normalize_sae_decoder=False, scale_sparsity_penalty_by_decoder_norm=True, ) diff --git 
a/tests/training/test_crosscoder_sae_training.py b/tests/training/test_crosscoder_sae_training.py index 082fa42e1..fe5e982ea 100644 --- a/tests/training/test_crosscoder_sae_training.py +++ b/tests/training/test_crosscoder_sae_training.py @@ -13,7 +13,7 @@ TrainingCrosscoderSAE, TrainingCrosscoderSAEConfig ) -from tests.helpers import build_sae_cfg +from tests.helpers import build_multilayer_sae_cfg # Define a new fixture for different configurations @@ -23,7 +23,7 @@ "model_name": "tiny-stories-1M", "dataset_path": "roneneldan/TinyStories", "hook_name": "blocks.{layer}.hook_resid_pre", - "hook_layers": [1,2,3], + "hook_layers": [0,1,2], "d_in": 64, "normalize_sae_decoder": False, "scale_sparsity_penalty_by_decoder_norm": True, @@ -32,7 +32,7 @@ "model_name": "tiny-stories-1M", "dataset_path": "apollo-research/roneneldan-TinyStories-tokenizer-gpt2", "hook_name": "blocks.{layer}.hook_resid_pre", - "hook_layers": [1,2,3], + "hook_layers": [0,1,2], "d_in": 64, "normalize_sae_decoder": False, "scale_sparsity_penalty_by_decoder_norm": True, @@ -41,7 +41,7 @@ "model_name": "tiny-stories-1M", "dataset_path": "apollo-research/roneneldan-TinyStories-tokenizer-gpt2", "hook_name": "blocks.{layer}.hook_resid_pre", - "hook_layers": [1,2,3], + "hook_layers": [0,1,2], "d_in": 64, "normalize_activations": "constant_norm_rescale", "normalize_sae_decoder": False, @@ -59,7 +59,7 @@ def cfg(request: pytest.FixtureRequest): Pytest fixture to create a mock instance of LanguageModelSAERunnerConfig. 
""" params = request.param - return build_sae_cfg(**params) + return build_multilayer_sae_cfg(**params) @pytest.fixture @@ -103,7 +103,7 @@ def trainer( def test_sae_forward(training_crosscoder_sae: TrainingCrosscoderSAE): batch_size = 32 d_in = training_crosscoder_sae.cfg.d_in - n_layers = len(training_crosscoder_sae.cfg.hook_layers) + n_layers = len(training_crosscoder_sae.cfg.hook_names) d_sae = training_crosscoder_sae.cfg.d_sae x = torch.randn(batch_size, n_layers, d_in) @@ -159,7 +159,7 @@ def test_sae_forward_with_mse_loss_norm( batch_size = 32 d_in = training_crosscoder_sae.cfg.d_in - n_layers = len(training_crosscoder_sae.cfg.hook_layers) + n_layers = len(training_crosscoder_sae.cfg.hook_names) d_sae = training_crosscoder_sae.cfg.d_sae x = torch.randn(batch_size, n_layers, d_in) @@ -212,7 +212,7 @@ def test_sae_forward_with_mse_loss_norm( def test_SparseAutoencoder_forward_can_add_noise_to_hidden_pre() -> None: - clean_cfg = build_sae_cfg( + clean_cfg = build_multilayer_sae_cfg( d_in=2, d_sae=4, noise_scale=0, @@ -220,7 +220,7 @@ def test_SparseAutoencoder_forward_can_add_noise_to_hidden_pre() -> None: normalize_sae_decoder=False, scale_sparsity_penalty_by_decoder_norm=True ) - noisy_cfg = build_sae_cfg( + noisy_cfg = build_multilayer_sae_cfg( d_in=2, d_sae=4, noise_scale=100, diff --git a/tests/training/test_training_crosscoder_sae.py b/tests/training/test_training_crosscoder_sae.py index b32817398..2a93d4ba1 100644 --- a/tests/training/test_training_crosscoder_sae.py +++ b/tests/training/test_training_crosscoder_sae.py @@ -6,20 +6,15 @@ TrainingCrosscoderSAE, TrainingCrosscoderSAEConfig, ) -from tests.helpers import build_sae_cfg - -def build_crosscoder_sae_cfg(**kwargs): - return build_sae_cfg( - **(kwargs | { - "hook_layers": [1,2,3,4], - "normalize_sae_decoder": False, - "scale_sparsity_penalty_by_decoder_norm": True, - })) +from tests.helpers import build_multilayer_sae_cfg def 
test_TrainingCrosscoderSAE_training_forward_pass_can_scale_sparsity_penalty_by_decoder_norm(): - cfg = build_crosscoder_sae_cfg( + cfg = build_multilayer_sae_cfg( d_in=3, d_sae=5, + hook_layers=[0,1,2,3], + normalize_sae_decoder=False, + scale_sparsity_penalty_by_decoder_norm=True, ) training_sae = TrainingCrosscoderSAE( TrainingCrosscoderSAEConfig.from_sae_runner_config(cfg), @@ -48,28 +43,33 @@ def test_TrainingCrosscoderSAE_encode_returns_same_value_as_encode_with_hidden_p ): if architecture != "standard": pytest.xfail("TODO(mkbehr): support other architectures") - cfg = build_crosscoder_sae_cfg(architecture=architecture) + cfg = build_multilayer_sae_cfg( + architecture=architecture, + hook_layers=[0,1,2,3], + normalize_sae_decoder=False, + scale_sparsity_penalty_by_decoder_norm=True, + ) sae = TrainingCrosscoderSAE( TrainingCrosscoderSAEConfig.from_sae_runner_config(cfg), use_error_term=True, ) - x = torch.randn(32, len(cfg.hook_layers), cfg.d_in) + x = torch.randn(32, len(cfg.hook_names), cfg.d_in) encode_out = sae.encode(x) encode_with_hidden_pre_out = sae.encode_with_hidden_pre_fn(x)[0] assert torch.allclose(encode_out, encode_with_hidden_pre_out) def test_TrainingCrosscoderSAE_heuristic_init(): - cfg = build_crosscoder_sae_cfg( + cfg = build_multilayer_sae_cfg( d_in=3, d_sae=5, + hook_layers=[0,1,2,3], + normalize_sae_decoder=False, + scale_sparsity_penalty_by_decoder_norm=True, decoder_heuristic_init=True, decoder_heuristic_init_norm=0.2, ) sae = TrainingCrosscoderSAE( TrainingCrosscoderSAEConfig.from_sae_runner_config(cfg), use_error_term=True) - print(sae.W_dec.norm(dim=0)) - print(sae.W_dec.norm(dim=1)) - print(sae.W_dec.norm(dim=2)) torch.testing.assert_close(sae.W_dec.norm(dim=[1,2]), torch.full((5,), 0.2)) From 9a29cba883ea6adf654c0395c4c2f84fb6fe204d Mon Sep 17 00:00:00 2001 From: Michael Behr Date: Sun, 4 May 2025 15:46:09 -0400 Subject: [PATCH 49/61] fix evals --- sae_lens/evals.py | 5 ++--- tests/test_evals.py | 6 +++--- 2 files changed, 5 
insertions(+), 6 deletions(-) diff --git a/sae_lens/evals.py b/sae_lens/evals.py index 74c45518e..7a5e3fde7 100644 --- a/sae_lens/evals.py +++ b/sae_lens/evals.py @@ -378,7 +378,7 @@ def get_sparsity_and_variance_metrics( ignore_tokens: set[int | None] = set(), verbose: bool = False, ) -> tuple[dict[str, Any], dict[str, Any]]: - hook_names = sae.cfg.hook_names() + hook_names = sae.cfg.hook_names or [sae.cfg.hook_name] hook_head_index = sae.cfg.hook_head_index metric_dict = {} @@ -449,8 +449,7 @@ def get_sparsity_and_variance_metrics( elif any(substring in hook_names[0] for substring in has_head_dim_key_substrings): # TODO(mkbehr) support head dimension for mutilayer evals original_act = cache[hook_names[0]].flatten(-2, -1) - elif len(hook_names) > 1: - # TODO(mkbehr): cleaner interface for multilayer evals + elif sae.cfg.hook_names: # TODO(mkbehr): support head dimension for mutilayer evals layerwise_activations = [ cache[hook_name] for hook_name in hook_names diff --git a/tests/test_evals.py b/tests/test_evals.py index 17597e85d..bc6a426b4 100644 --- a/tests/test_evals.py +++ b/tests/test_evals.py @@ -30,7 +30,7 @@ TrainingCrosscoderSAEConfig, ) from sae_lens.training.training_sae import TrainingSAE -from tests.helpers import TINYSTORIES_MODEL, build_sae_cfg, load_model_cached +from tests.helpers import TINYSTORIES_MODEL, build_sae_cfg, build_multilayer_sae_cfg, load_model_cached TRAINER_EVAL_CONFIG = EvalConfig( n_eval_reconstruction_batches=10, @@ -289,10 +289,10 @@ def test_run_empty_evals( # TODO(mkbehr): consider parameterizing def test_run_evals_crosscoder_training_sae(model): - cfg=build_sae_cfg( + cfg=build_multilayer_sae_cfg( model_name="tiny-stories-1M", dataset_path="roneneldan/TinyStories", - hook_name="blocks.{layer}.hook_resid_pre", + hook_name_template="blocks.{layer}.hook_resid_pre", hook_layers=[0, 1], d_in=64, normalize_sae_decoder=False, From 7e2a40f07d0a0c6a513b9d4054474e636fa9e81c Mon Sep 17 00:00:00 2001 From: Michael Behr Date: Sun, 4 May 
2025 16:51:00 -0400 Subject: [PATCH 50/61] fix evals again; all tests pass --- sae_lens/config.py | 4 ++-- sae_lens/crosscoder_sae.py | 4 ++-- sae_lens/evals.py | 4 ++-- sae_lens/sae.py | 3 --- sae_lens/training/activations_store.py | 2 +- tests/training/test_config.py | 1 + 6 files changed, 8 insertions(+), 10 deletions(-) diff --git a/sae_lens/config.py b/sae_lens/config.py index 5702b9de7..b905f000a 100644 --- a/sae_lens/config.py +++ b/sae_lens/config.py @@ -148,7 +148,7 @@ class LanguageModelSAERunnerConfig: model_name: str = "gelu-2l" model_class_name: str = "HookedTransformer" hook_name: str = "blocks.0.hook_mlp_out" - hook_names: list[str] = list + hook_names: list[str] = field(default_factory=list) hook_eval: str = "NOT_IN_USE" hook_layer: int = 0 hook_head_index: int | None = None @@ -561,7 +561,7 @@ class CacheActivationsRunnerConfig: d_in: int training_tokens: int - hook_names: list[str] = list + hook_names: list[str] = field(default_factory=list) context_size: int = -1 # Required if dataset is not tokenized model_class_name: str = "HookedTransformer" # defaults to "activations/{dataset}/{model}/{hook_name} diff --git a/sae_lens/crosscoder_sae.py b/sae_lens/crosscoder_sae.py index bf14fefe2..5cd1b5719 100644 --- a/sae_lens/crosscoder_sae.py +++ b/sae_lens/crosscoder_sae.py @@ -1,4 +1,4 @@ -from dataclasses import dataclass +from dataclasses import dataclass, field from typing import Any, List import einops @@ -9,7 +9,7 @@ @dataclass class CrosscoderSAEConfig(SAEConfig): - hook_names: list[int] = list + hook_names: list[int] = field(default_factory=list) def to_dict(self) -> dict[str, Any]: return super().to_dict() | { diff --git a/sae_lens/evals.py b/sae_lens/evals.py index 7a5e3fde7..2394b1799 100644 --- a/sae_lens/evals.py +++ b/sae_lens/evals.py @@ -378,7 +378,7 @@ def get_sparsity_and_variance_metrics( ignore_tokens: set[int | None] = set(), verbose: bool = False, ) -> tuple[dict[str, Any], dict[str, Any]]: - hook_names = sae.cfg.hook_names or 
[sae.cfg.hook_name] + hook_names = sae.cfg.hook_names if hasattr(sae.cfg, "hook_names") else [sae.cfg.hook_name] hook_head_index = sae.cfg.hook_head_index metric_dict = {} @@ -449,7 +449,7 @@ def get_sparsity_and_variance_metrics( elif any(substring in hook_names[0] for substring in has_head_dim_key_substrings): # TODO(mkbehr) support head dimension for mutilayer evals original_act = cache[hook_names[0]].flatten(-2, -1) - elif sae.cfg.hook_names: + elif hasattr(sae.cfg, "hook_names"): # TODO(mkbehr): support head dimension for mutilayer evals layerwise_activations = [ cache[hook_name] for hook_name in hook_names diff --git a/sae_lens/sae.py b/sae_lens/sae.py index b8392d14e..ef9cfc701 100644 --- a/sae_lens/sae.py +++ b/sae_lens/sae.py @@ -124,9 +124,6 @@ def to_dict(self) -> dict[str, Any]: "seqpos_slice": self.seqpos_slice, } - def hook_names(self) -> List[str]: - return [self.hook_name] - class SAE(HookedRootModule): """ diff --git a/sae_lens/training/activations_store.py b/sae_lens/training/activations_store.py index e1f56e387..10778606d 100644 --- a/sae_lens/training/activations_store.py +++ b/sae_lens/training/activations_store.py @@ -235,7 +235,7 @@ def __init__( ) self.hook_name = hook_name - self.hook_names = hook_names + self.hook_names = hook_names if hook_names is not None else [] self.hook_layer = hook_layer self.hook_head_index = hook_head_index self.context_size = context_size diff --git a/tests/training/test_config.py b/tests/training/test_config.py index db888b8ac..2af015efc 100644 --- a/tests/training/test_config.py +++ b/tests/training/test_config.py @@ -90,6 +90,7 @@ def test_sae_training_runner_config_get_sae_base_parameters(): "dtype": "float32", "model_name": "gelu-2l", "hook_name": "blocks.0.hook_mlp_out", + "hook_names": [], "hook_layer": 0, "hook_head_index": None, "device": "cpu", From e035b8c1adf0edccfd4bc5779350446c8ebd4a00 Mon Sep 17 00:00:00 2001 From: Michael Behr Date: Sun, 4 May 2025 18:21:13 -0400 Subject: [PATCH 51/61] "global" 
acausal crosscoder script for gpt2-small --- scripts/global_acausal_crosscoder.py | 37 +++++++++++++--------------- 1 file changed, 17 insertions(+), 20 deletions(-) diff --git a/scripts/global_acausal_crosscoder.py b/scripts/global_acausal_crosscoder.py index de4eb47a8..f4025125a 100644 --- a/scripts/global_acausal_crosscoder.py +++ b/scripts/global_acausal_crosscoder.py @@ -27,16 +27,16 @@ # total_training_steps = 200_000 # total_training_steps = 60_000 total_training_steps = 10_000 -batch_size = 2048 +batch_size = 4092 # batch_size = 256 total_training_tokens = total_training_steps * batch_size print(f"Total Training Tokens: {total_training_tokens}") -layers = list(range(3)) -# layers = [0] +hook_name_template = "blocks.{layer}.hook_mlp_out" +layers = list(range(2)) -model_name = "tiny-stories-28M" -dataset_path = "apollo-research/roneneldan-TinyStories-tokenizer-gpt2" +model_name = "gpt2-small" +dataset_path = "apollo-research/SkyLion007-openwebtext-tokenizer-gpt2" new_cached_activations_path = ( f"./cached_activations/{model_name}/{dataset_path}/{total_training_steps}" ) @@ -51,23 +51,21 @@ if not log_to_wandb: print("NOT LOGGING TO WANDB") -d_in = 512 -expansion_factor = 16 +d_in = 768 +expansion_factor = 32 d_sae = d_in * expansion_factor -learning_rate = 5e-5 +learning_rate = 2e-5 l1_coefficient = 1 -run_name = ( - f"{d_sae}" - f"-Layers-{'_'.join([str(l) for l in layers])}" - f"-L1-{l1_coefficient}" - f"-LR-{learning_rate}" - f"-Tokens-{total_training_tokens:3.3e}" - ) +hook_name = hook_name_template.format( + layer=f"{min(layers)}_through_{max(layers)}" +) +hook_names = [hook_name_template.format(layer=layer) for layer in layers] cfg = LanguageModelSAERunnerConfig( model_name=model_name, - hook_name="blocks.{layer}.hook_mlp_out", - hook_layers=layers, + hook_name=hook_name, + hook_names=hook_names, + hook_layer=max(layers), d_in=d_in, dataset_path=dataset_path, streaming=True, @@ -117,7 +115,7 @@ adam_beta1=0.9, adam_beta2=0.999, # Buffer details won't 
matter in we cache / shuffle our activations ahead of time. - n_batches_in_buffer=64, + n_batches_in_buffer=32, store_batch_size_prompts=16, normalize_activations="expected_average_only_in", # Feature Store @@ -126,8 +124,7 @@ dead_feature_threshold=1e-4, # WANDB log_to_wandb=log_to_wandb, # always use wandb unless you are just testing code. - wandb_project="crosscoder-acausal-tinystories-23M", - run_name=run_name, + wandb_project="crosscoder-acausal-gpt2-small", wandb_log_frequency=50, eval_every_n_wandb_logs=10, # Misc From 540e23f7171b328b44e36700fc9a40667bde62d5 Mon Sep 17 00:00:00 2001 From: Michael Behr Date: Sun, 4 May 2025 18:23:48 -0400 Subject: [PATCH 52/61] remove some TODOs --- sae_lens/crosscoder_sae.py | 2 -- sae_lens/sae_training_runner.py | 3 +-- 2 files changed, 1 insertion(+), 4 deletions(-) diff --git a/sae_lens/crosscoder_sae.py b/sae_lens/crosscoder_sae.py index 5cd1b5719..bcd929044 100644 --- a/sae_lens/crosscoder_sae.py +++ b/sae_lens/crosscoder_sae.py @@ -52,8 +52,6 @@ def encode_standard( """ Calculate SAE features from inputs """ - # TODO(mkbehr): instead of changing this and the W_enc/b_enc - # dimensions, we could change reshape_fn_in sae_in = self.process_sae_in(x) hidden_pre = self.hook_sae_acts_pre( diff --git a/sae_lens/sae_training_runner.py b/sae_lens/sae_training_runner.py index 205e672c7..54f0d4eb5 100644 --- a/sae_lens/sae_training_runner.py +++ b/sae_lens/sae_training_runner.py @@ -224,8 +224,7 @@ def save_checkpoint( if trainer.cfg.log_to_wandb: # Avoid wandb saving errors such as: # ValueError: Artifact name may only contain alphanumeric characters, dashes, underscores, and dots. 
Invalid name: sae_google/gemma-2b_etc - # TODO(mkbehr) name better - sae_name = trainer.sae.get_name().replace("/", "__").replace("{}", "__") + sae_name = trainer.sae.get_name().replace("/", "__") # save model weights and cfg model_artifact = wandb.Artifact( From 109eba8592fdbab1b7f616be8939ff41a90b57c4 Mon Sep 17 00:00:00 2001 From: Michael Behr Date: Sun, 4 May 2025 23:22:23 -0400 Subject: [PATCH 53/61] remove more TODOs --- sae_lens/crosscoder_sae.py | 7 ++----- sae_lens/training/activations_store.py | 6 ++---- sae_lens/training/crosscoder_sae_trainer.py | 8 -------- sae_lens/training/training_crosscoder_sae.py | 6 ------ sae_lens/training/training_sae.py | 2 -- 5 files changed, 4 insertions(+), 25 deletions(-) diff --git a/sae_lens/crosscoder_sae.py b/sae_lens/crosscoder_sae.py index bcd929044..5c1ad12fe 100644 --- a/sae_lens/crosscoder_sae.py +++ b/sae_lens/crosscoder_sae.py @@ -21,17 +21,14 @@ class CrosscoderSAE(SAE): Sparse autoencoder that acts on multiple layers of activations. """ - # TODO(mkbehr): write - # - remaining encode methods - # - hook_z reshaping support - def __init__( self, cfg: CrosscoderSAEConfig, use_error_term: bool = False, ): if cfg.architecture != "standard": - raise NotImplementedError("TODO(mkbehr): support other archs") + raise NotImplementedError( + "TODO(mkbehr): support other architectures") super().__init__(cfg=cfg, use_error_term=use_error_term) diff --git a/sae_lens/training/activations_store.py b/sae_lens/training/activations_store.py index 10778606d..8ff9ccd25 100644 --- a/sae_lens/training/activations_store.py +++ b/sae_lens/training/activations_store.py @@ -438,7 +438,6 @@ def apply_norm_scaling_factor(self, activations: torch.Tensor) -> torch.Tensor: # Norm scaling factor is a float in the single-layer case, and # a tensor in the multilayer case. 
if self.hook_names: - # TODO(mkbehr): set the device somewhere better return activations * self.estimated_norm_scaling_factor.unsqueeze(-1).to(activations.device) else: return activations * self.estimated_norm_scaling_factor @@ -449,8 +448,7 @@ def unscale(self, activations: torch.Tensor) -> torch.Tensor: "estimated_norm_scaling_factor is not set, call set_norm_scaling_factor_if_needed() first" ) if self.hook_names: - # TODO(mkbehr): set the device somewhere better - return activations / self.estimated_norm_scaling_factor.unsqueeze(-1).to(activations.device) + return activations / self.estimated_norm_scaling_factor.unsqueeze(-1) else: return activations / self.estimated_norm_scaling_factor @@ -467,7 +465,7 @@ def estimate_norm_scaling_factor(self, n_batches_for_norm_estimate: int = int(1e range(n_batches_for_norm_estimate), desc="Estimating norm scaling factor" ): # temporalily set estimated_norm_scaling_factor to 1.0 so the dataloader works - self.estimated_norm_scaling_factor = torch.ones(1) + self.estimated_norm_scaling_factor = torch.ones(1, device=self.device) acts = self.next_batch() self.estimated_norm_scaling_factor = None norms_per_batch[:, batch_i] = acts.norm(dim=-1).mean(dim=0) diff --git a/sae_lens/training/crosscoder_sae_trainer.py b/sae_lens/training/crosscoder_sae_trainer.py index d44900b30..26d30596a 100644 --- a/sae_lens/training/crosscoder_sae_trainer.py +++ b/sae_lens/training/crosscoder_sae_trainer.py @@ -9,14 +9,6 @@ from sae_lens.training.training_sae import TrainingSAE, TrainStepOutput from sae_lens.training.training_crosscoder_sae import TrainingCrosscoderSAE, TrainStepOutput -# TODO(mkbehr): probably too much copypasting here -# why do I think that? -# - fit is long -# - all it does is take the whole batch instead of the first layer -# - maybe a helper method to subclass? 
-# - _run_and_log_evals is long -# - all it does differently is W_dec_norms (and presumably other architectures' things once those are implemented) - class CrosscoderSAETrainer(SAETrainer): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) diff --git a/sae_lens/training/training_crosscoder_sae.py b/sae_lens/training/training_crosscoder_sae.py index 93fcda17b..cefa06a79 100644 --- a/sae_lens/training/training_crosscoder_sae.py +++ b/sae_lens/training/training_crosscoder_sae.py @@ -25,13 +25,10 @@ SAE_CFG_PATH = "cfg.json" -# TODO(mkbehr) will this multiple inheritance work? @dataclass(kw_only=True) class TrainingCrosscoderSAEConfig(CrosscoderSAEConfig, TrainingSAEConfig): sparsity_penalty_decoder_norm_lp_norm: float = 1 - # TODO(mkbehr): copypasting from TrainingSAEConfig and adding a few - # params. There should be a better way. @classmethod def from_sae_runner_config( cls, cfg: LanguageModelSAERunnerConfig @@ -77,7 +74,6 @@ def from_sae_runner_config( ) def to_dict(self) -> dict[str, Any]: - # TODO(mkbehr): double-check this multiple inheritance. seems messy. return (TrainingSAEConfig.to_dict(self) | CrosscoderSAEConfig.to_dict(self) | { @@ -109,8 +105,6 @@ def from_dict(cls, return cls(TrainingCrosscoderSAEConfig.from_dict(config_dict), use_error_term = use_error_term) - # TODO(mkbehr): hacking around multiple inheritance. there's - # probably a better way. @staticmethod def base_sae_cfg(cfg: TrainingCrosscoderSAEConfig): return CrosscoderSAEConfig.from_dict(cfg.get_base_sae_cfg_dict()) diff --git a/sae_lens/training/training_sae.py b/sae_lens/training/training_sae.py index bdb15bb54..490b5d3a0 100644 --- a/sae_lens/training/training_sae.py +++ b/sae_lens/training/training_sae.py @@ -290,8 +290,6 @@ def threshold(self) -> torch.Tensor: def from_dict(cls, config_dict: dict[str, Any]) -> "TrainingSAE": return cls(TrainingSAEConfig.from_dict(config_dict)) - # TODO(mkbehr): hacking around multiple inheritance. there's - # probably a better way. 
@staticmethod def base_sae_cfg(cfg: TrainingSAEConfig): return SAEConfig.from_dict(cfg.get_base_sae_cfg_dict()) From 99890650d520e14bf2f74d5837b6d54e8985a0d3 Mon Sep 17 00:00:00 2001 From: Michael Behr Date: Mon, 5 May 2025 21:28:37 -0400 Subject: [PATCH 54/61] enable test_activations_store_normalization_multiple_layers --- .../test_activations_store_multilayer.py | 17 ++++------------- 1 file changed, 4 insertions(+), 13 deletions(-) diff --git a/tests/training/test_activations_store_multilayer.py b/tests/training/test_activations_store_multilayer.py index 77c5a909d..614b9f950 100644 --- a/tests/training/test_activations_store_multilayer.py +++ b/tests/training/test_activations_store_multilayer.py @@ -114,10 +114,9 @@ def test_activations_store_next_batch_multiple_layers(ts_model: HookedTransforme batch = activation_store.next_batch() assert batch.shape == (10, len(cfg.hook_names), activation_store.d_in) -@pytest.mark.skip("TODO(mkbehr): does activation need to be handled differently?") + def test_activations_store_normalization_multiple_layers(ts_model: HookedTransformer): """Test normalization when using multiple layers.""" - # Setup with normalization and multiple layers cfg = build_multilayer_sae_cfg( hook_name_template="blocks.{layer}.hook_resid_pre", hook_layers=[0, 1, 2], @@ -129,22 +128,15 @@ def test_activations_store_normalization_multiple_layers(ts_model: HookedTransfo activation_store = ActivationsStore.from_config(ts_model, cfg, override_dataset=dataset) activation_store.set_norm_scaling_factor_if_needed() - # Get a batch with normalized activations batch = activation_store.next_batch() - # Check that the activations have been properly normalized - # The norm should be approximately sqrt(d_in) for each layer - for layer_idx in range(len(activation_store.hook_layers)): - layer_activations = batch[:, layer_idx, :] - # Check if average norm is approximately as expected (allowing for some variance) - avg_norm = layer_activations.norm(dim=-1).mean() - 
expected_norm = (activation_store.d_in ** 0.5) - assert avg_norm.item() == pytest.approx(expected_norm, abs=2.0) + avg_norm = batch.norm(dim=-1).mean(dim=1) + expected_norm = torch.full_like(avg_norm, cfg.d_in ** 0.5) + torch.testing.assert_close(avg_norm, expected_norm, atol=1.0, rtol=0.1) def test_backward_compatibility_single_layer(ts_model: HookedTransformer): """Test that single layer behavior is unchanged with the multi-layer support.""" - # Create a store with single layer (old behavior) cfg_single = build_sae_cfg( hook_name="blocks.0.hook_resid_pre", hook_layer=0, @@ -154,7 +146,6 @@ def test_backward_compatibility_single_layer(ts_model: HookedTransformer): dataset = Dataset.from_list([{"text": "hello world"}] * 10) single_store = ActivationsStore.from_config(ts_model, cfg_single, override_dataset=dataset) - # Create a store with single layer (new behavior) cfg_multi = build_multilayer_sae_cfg( hook_name_template="blocks.{layer}.hook_resid_pre", hook_layers=[0], From c1ad3935ed441264e1c9707f4c693caaab88535f Mon Sep 17 00:00:00 2001 From: Michael Behr Date: Tue, 6 May 2025 18:09:40 -0400 Subject: [PATCH 55/61] Update to new disk loader --- sae_lens/crosscoder_sae.py | 21 +++++++++++++++++ sae_lens/training/training_crosscoder_sae.py | 24 +++++--------------- tests/training/test_crosscoder_sae.py | 14 ++++++------ 3 files changed, 34 insertions(+), 25 deletions(-) diff --git a/sae_lens/crosscoder_sae.py b/sae_lens/crosscoder_sae.py index 5c1ad12fe..0837e991e 100644 --- a/sae_lens/crosscoder_sae.py +++ b/sae_lens/crosscoder_sae.py @@ -6,6 +6,11 @@ from jaxtyping import Float from sae_lens import SAEConfig, SAE +from sae_lens.toolkit.pretrained_sae_loaders import ( + PretrainedSaeDiskLoader, + handle_config_defaulting, + sae_lens_disk_loader, +) @dataclass class CrosscoderSAEConfig(SAEConfig): @@ -107,3 +112,19 @@ def fold_activation_norm_scaling_factor( # once we normalize, we shouldn't need to scale activations. 
self.cfg.normalize_activations = "none" + @classmethod + def load_from_disk( + cls, + path: str, + device: str = "cpu", + dtype: str | None = None, + converter: PretrainedSaeDiskLoader = sae_lens_disk_loader, + ) -> "CrosscoderSAE": + overrides = {"dtype": dtype} if dtype is not None else None + cfg_dict, state_dict = converter(path, device, cfg_overrides=overrides) + cfg_dict = handle_config_defaulting(cfg_dict) + sae_cfg = CrosscoderSAEConfig.from_dict(cfg_dict) + sae = cls(sae_cfg) + sae.process_state_dict_for_loading(state_dict) + sae.load_state_dict(state_dict) + return sae diff --git a/sae_lens/training/training_crosscoder_sae.py b/sae_lens/training/training_crosscoder_sae.py index cefa06a79..5ee0045b7 100644 --- a/sae_lens/training/training_crosscoder_sae.py +++ b/sae_lens/training/training_crosscoder_sae.py @@ -16,8 +16,9 @@ TrainStepOutput, ) from sae_lens.toolkit.pretrained_sae_loaders import ( + PretrainedSaeDiskLoader, handle_config_defaulting, - read_sae_from_disk, + sae_lens_disk_loader, ) SPARSITY_PATH = "sparsity.safetensors" @@ -196,33 +197,20 @@ def training_forward_pass( ) @classmethod - def load_from_pretrained( + def load_from_disk( cls, path: str, device: str = "cpu", dtype: str | None = None, + converter: PretrainedSaeDiskLoader = sae_lens_disk_loader, ) -> "TrainingCrosscoderSAE": - # get the config - config_path = os.path.join(path, SAE_CFG_PATH) - with open(config_path) as f: - cfg_dict = json.load(f) + overrides = {"dtype": dtype} if dtype is not None else None + cfg_dict, state_dict = converter(path, device, cfg_overrides=overrides) cfg_dict = handle_config_defaulting(cfg_dict) - cfg_dict["device"] = device - if dtype is not None: - cfg_dict["dtype"] = dtype - - weight_path = os.path.join(path, SAE_WEIGHTS_PATH) - cfg_dict, state_dict = read_sae_from_disk( - cfg_dict=cfg_dict, - weight_path=weight_path, - device=device, - ) sae_cfg = TrainingCrosscoderSAEConfig.from_dict(cfg_dict) - sae = cls(sae_cfg) 
sae.process_state_dict_for_loading(state_dict) sae.load_state_dict(state_dict) - return sae def initialize_weights_complex(self): diff --git a/tests/training/test_crosscoder_sae.py b/tests/training/test_crosscoder_sae.py index 23f577a01..acdc03913 100644 --- a/tests/training/test_crosscoder_sae.py +++ b/tests/training/test_crosscoder_sae.py @@ -121,7 +121,7 @@ def test_crosscoder_sae_fold_w_dec_norm(cfg: LanguageModelSAERunnerConfig): @pytest.mark.parametrize("architecture", ALL_ARCHITECTURES) @torch.no_grad() -def test_sae_fold_w_dec_norm_all_architectures(architecture: str): +def test_crosscoder_sae_fold_w_dec_norm_all_architectures(architecture: str): if architecture != "standard": pytest.xfail("TODO(mkbehr): support other architectures") cfg = build_multilayer_sae_cfg(architecture=architecture, hook_layers=[0,1,2]) @@ -159,7 +159,7 @@ def test_sae_fold_w_dec_norm_all_architectures(architecture: str): torch.testing.assert_close(sae_out_1, sae_out_2) @torch.no_grad() -def test_sae_fold_norm_scaling_factor(cfg: LanguageModelSAERunnerConfig): +def test_crosscoder_sae_fold_norm_scaling_factor(cfg: LanguageModelSAERunnerConfig): norm_scaling_factor = torch.Tensor([2.0, 3.0, 4.0]) sae = CrosscoderSAE.from_dict(cfg.get_base_sae_cfg_dict()) @@ -202,7 +202,7 @@ def test_sae_fold_norm_scaling_factor(cfg: LanguageModelSAERunnerConfig): @pytest.mark.parametrize("architecture", ALL_ARCHITECTURES) @torch.no_grad() -def test_sae_fold_norm_scaling_factor_all_architectures(architecture: str): +def test_crosscoder_sae_fold_norm_scaling_factor_all_architectures(architecture: str): if architecture != "standard": pytest.xfail("TODO(mkbehr): support other architectures") cfg = build_multilayer_sae_cfg(architecture=architecture, hook_layers=[0,1,2]) @@ -244,7 +244,7 @@ def test_sae_fold_norm_scaling_factor_all_architectures(architecture: str): # but actual outputs should be the same torch.testing.assert_close(sae_out_1, sae_out_2) -def test_sae_save_and_load_from_pretrained(tmp_path: 
Path) -> None: +def test_crosscoder_sae_save_and_load_from_pretrained(tmp_path: Path) -> None: cfg = build_multilayer_sae_cfg(hook_layers=[0,1,2]) model_path = str(tmp_path) sae = CrosscoderSAE.from_dict(cfg.get_base_sae_cfg_dict()) @@ -270,7 +270,7 @@ def test_sae_save_and_load_from_pretrained(tmp_path: Path) -> None: assert torch.allclose(sae_out_1, sae_out_2) @pytest.mark.xfail(reason="TODO(mkbehr): support other architectures") -def test_sae_save_and_load_from_pretrained_gated(tmp_path: Path) -> None: +def test_crosscoder_sae_save_and_load_from_pretrained_gated(tmp_path: Path) -> None: cfg = build_multilayer_sae_cfg(architecture="gated", hook_layers=[0,1,2]) model_path = str(tmp_path) sae = CrosscoderSAE.from_dict(cfg.get_base_sae_cfg_dict()) @@ -295,7 +295,7 @@ def test_sae_save_and_load_from_pretrained_gated(tmp_path: Path) -> None: sae_out_2 = sae_loaded(sae_in) assert torch.allclose(sae_out_1, sae_out_2) -def test_sae_save_and_load_from_pretrained_topk(tmp_path: Path) -> None: +def test_crosscoder_sae_save_and_load_from_pretrained_topk(tmp_path: Path) -> None: cfg = build_multilayer_sae_cfg(activation_fn_kwargs={"k": 30}, hook_layers=[0,1,2]) model_path = str(tmp_path) sae = CrosscoderSAE.from_dict(cfg.get_base_sae_cfg_dict()) @@ -320,7 +320,7 @@ def test_sae_save_and_load_from_pretrained_topk(tmp_path: Path) -> None: sae_out_2 = sae_loaded(sae_in) assert torch.allclose(sae_out_1, sae_out_2) -def test_sae_get_name_returns_correct_name_from_cfg_vals() -> None: +def test_crosscoder_sae_get_name_returns_correct_name_from_cfg_vals() -> None: cfg = build_multilayer_sae_cfg(model_name="test_model", hook_name_template="blocks.{layer}.test_hook_name", d_sae=128, hook_layers=[0,1,2]) sae = CrosscoderSAE.from_dict(cfg.get_base_sae_cfg_dict()) assert sae.get_name() == "sae_test_model_blocks.layers_0_through_2.test_hook_name_128" From 07449ca817cef7d2dd3dab0beb87186633a06f53 Mon Sep 17 00:00:00 2001 From: Michael Behr Date: Tue, 6 May 2025 18:30:34 -0400 Subject: 
[PATCH 56/61] test saving multilayer activation norm --- sae_lens/training/activations_store.py | 1 - .../test_activations_store_multilayer.py | 31 +++++++++++++++++++ 2 files changed, 31 insertions(+), 1 deletion(-) diff --git a/sae_lens/training/activations_store.py b/sae_lens/training/activations_store.py index 8ff9ccd25..2f968996d 100644 --- a/sae_lens/training/activations_store.py +++ b/sae_lens/training/activations_store.py @@ -457,7 +457,6 @@ def get_norm_scaling_factor(self, activations: torch.Tensor) -> torch.Tensor: @torch.no_grad() def estimate_norm_scaling_factor(self, n_batches_for_norm_estimate: int = int(1e3)): - # TODO(mkbehr): test multilayer norm scaling, probably fix saving? norms_per_batch = torch.empty( len(self.hook_names) or 1, n_batches_for_norm_estimate, device=self.device) diff --git a/tests/training/test_activations_store_multilayer.py b/tests/training/test_activations_store_multilayer.py index 614b9f950..40dd611f8 100644 --- a/tests/training/test_activations_store_multilayer.py +++ b/tests/training/test_activations_store_multilayer.py @@ -1,8 +1,13 @@ """Tests for ActivationsStore with multiple layer support.""" +import os +import tempfile +from typing import Any + import pytest import torch from datasets import Dataset +from safetensors.torch import load_file from transformer_lens import HookedTransformer from sae_lens.training.activations_store import ActivationsStore @@ -162,3 +167,29 @@ def test_backward_compatibility_single_layer(ts_model: HookedTransformer): torch.testing.assert_close(batch_tokens_single, batch_tokens_multi) torch.testing.assert_close(activations_single, activations_multi) + + +def test_activations_store_multilayer_save_with_norm_scaling_factor( + ts_model: HookedTransformer, +): + cfg = build_multilayer_sae_cfg( + hook_name_template="blocks.{layer}.hook_resid_pre", + hook_layers=[0, 1, 2], + normalize_activations="expected_average_only_in", + context_size=5 + ) + activation_store = 
ActivationsStore.from_config(ts_model, cfg) + activation_store.set_norm_scaling_factor_if_needed() + assert activation_store.estimated_norm_scaling_factor is not None + with tempfile.NamedTemporaryFile() as temp_file: + activation_store.save(temp_file.name) + assert os.path.exists(temp_file.name) + state_dict = load_file(temp_file.name) + assert isinstance(state_dict, dict) + assert "estimated_norm_scaling_factor" in state_dict + estimated_norm_scaling_factor = state_dict["estimated_norm_scaling_factor"] + assert estimated_norm_scaling_factor.shape == (len(cfg.hook_names),) + torch.testing.assert_close( + estimated_norm_scaling_factor, + activation_store.estimated_norm_scaling_factor + ) From 9d40b8cd6a183c93ea7f090b3bb611e1c2ae068a Mon Sep 17 00:00:00 2001 From: Michael Behr Date: Tue, 6 May 2025 18:33:35 -0400 Subject: [PATCH 57/61] misc. cleanup --- sae_lens/training/activations_store.py | 2 +- ...al_crosscoder.py => acausal_crosscoder.py} | 22 +++---------------- .../test_activations_store_multilayer.py | 7 ------ 3 files changed, 4 insertions(+), 27 deletions(-) rename scripts/{global_acausal_crosscoder.py => acausal_crosscoder.py} (76%) diff --git a/sae_lens/training/activations_store.py b/sae_lens/training/activations_store.py index 2f968996d..5518b54a5 100644 --- a/sae_lens/training/activations_store.py +++ b/sae_lens/training/activations_store.py @@ -463,7 +463,7 @@ def estimate_norm_scaling_factor(self, n_batches_for_norm_estimate: int = int(1e for batch_i in tqdm( range(n_batches_for_norm_estimate), desc="Estimating norm scaling factor" ): - # temporalily set estimated_norm_scaling_factor to 1.0 so the dataloader works + # temporarily set estimated_norm_scaling_factor to 1.0 so the dataloader works self.estimated_norm_scaling_factor = torch.ones(1, device=self.device) acts = self.next_batch() self.estimated_norm_scaling_factor = None diff --git a/scripts/global_acausal_crosscoder.py b/scripts/acausal_crosscoder.py similarity index 76% rename from 
scripts/global_acausal_crosscoder.py rename to scripts/acausal_crosscoder.py index f4025125a..cac062460 100644 --- a/scripts/global_acausal_crosscoder.py +++ b/scripts/acausal_crosscoder.py @@ -1,5 +1,3 @@ -# TODO(mkbehr): don't really commit this - import os import sys @@ -71,25 +69,15 @@ streaming=True, context_size=512, is_dataset_tokenized=True, - prepend_bos=False, # TODO(mkbehr): probably better to prepend bosg but then remove that token's activations - # How big do we want our SAE to be? + prepend_bos=True, expansion_factor=expansion_factor, - # Dataset / Activation Store - # When we do a proper test - # training_tokens= 820_000_000, # 200k steps * 4096 batch size ~ 820M tokens (doable overnight on an A100) - # For now. use_cached_activations=False, - # cached_activations_path="/home/paperspace/shared_volumes/activations_volume_1/gelu-1l", - training_tokens=total_training_tokens, # For initial testing I think this is a good number. + training_tokens=total_training_tokens, train_batch_size_tokens=batch_size, # Loss Function - ## Reconstruction Coefficient. - mse_loss_normalization=None, # MSE Loss Normalization is not mentioned (so we use stanrd MSE Loss). But not we take an average over the batch. - ## Anthropic does not mention using an Lp norm other than L1. + mse_loss_normalization=None, l1_coefficient=l1_coefficient, lp_norm=1.0, - # Instead, they multiply the L1 loss contribution - # from each feature of the activations by the decoder norm of the corresponding feature. scale_sparsity_penalty_by_decoder_norm=True, # TODO(mkbehr): plumb this through config # sparsity_penalty_decoder_norm_lp_norm=1.0, @@ -98,12 +86,9 @@ l1_warm_up_steps=l1_warmup_steps, lr_warm_up_steps=lr_warm_up_steps, lr_decay_steps=lr_warm_up_steps, - ## No ghost grad term. use_ghost_grads=False, # Initialization / Architecture apply_b_dec_to_input=False, - # encoder bias zero's. (I'm not sure what it is by default now) - # decoder bias zero's. 
b_dec_init_method="zeros", normalize_sae_decoder=False, decoder_heuristic_init=True, @@ -135,7 +120,6 @@ dtype="float32", ) -# look at the next cell to see some instruction for what to do while this is running. sae = SAETrainingRunner( cfg, override_sae = TrainingCrosscoderSAE( diff --git a/tests/training/test_activations_store_multilayer.py b/tests/training/test_activations_store_multilayer.py index 40dd611f8..24eaab70e 100644 --- a/tests/training/test_activations_store_multilayer.py +++ b/tests/training/test_activations_store_multilayer.py @@ -16,7 +16,6 @@ def test_activations_store_init_with_multiple_layers(ts_model: HookedTransformer): """Test initialization with a list of layers instead of a single layer.""" - # Initialize with multiple layers cfg = build_multilayer_sae_cfg( hook_name_template="blocks.{layer}.hook_resid_pre", hook_layers=[0, 1, 2] @@ -43,7 +42,6 @@ def test_activations_store_init_with_multiple_layers(ts_model: HookedTransformer def test_activations_store_get_activations_multiple_layers(ts_model: HookedTransformer): """Test that get_activations collects activations from all specified layers.""" - # Setup with multiple layers cfg = build_multilayer_sae_cfg( hook_name_template="blocks.{layer}.hook_resid_pre", hook_layers=[0, 1, 2], @@ -53,7 +51,6 @@ def test_activations_store_get_activations_multiple_layers(ts_model: HookedTrans dataset = Dataset.from_list([{"text": "hello world"}] * 10) activation_store = ActivationsStore.from_config(ts_model, cfg, override_dataset=dataset) - # Get a batch of tokens and activations batch_tokens = activation_store.get_batch_tokens() activations = activation_store.get_activations(batch_tokens) @@ -84,7 +81,6 @@ def test_activations_store_get_activations_multiple_layers(ts_model: HookedTrans def test_activations_store_get_buffer_multiple_layers(ts_model: HookedTransformer): """Test buffer handling with multiple layers.""" - # Setup with multiple layers cfg = build_multilayer_sae_cfg( 
hook_name_template="blocks.{layer}.hook_resid_pre", hook_layers=[0, 1, 2], @@ -94,7 +90,6 @@ def test_activations_store_get_buffer_multiple_layers(ts_model: HookedTransforme dataset = Dataset.from_list([{"text": "hello world"}] * 20) activation_store = ActivationsStore.from_config(ts_model, cfg, override_dataset=dataset) - # Get buffer with 2 batches buffer_activations, buffer_tokens = activation_store.get_buffer(n_batches_in_buffer=2) # Check shape: [(batch_size * context_size * n_batches), num_layers, d_in] @@ -105,7 +100,6 @@ def test_activations_store_get_buffer_multiple_layers(ts_model: HookedTransforme def test_activations_store_next_batch_multiple_layers(ts_model: HookedTransformer): """Test that next_batch returns correct batch shape with multiple layers.""" - # Setup with multiple layers cfg = build_multilayer_sae_cfg( hook_name_template="blocks.{layer}.hook_resid_pre", hook_layers=[0, 1, 2], @@ -158,7 +152,6 @@ def test_backward_compatibility_single_layer(ts_model: HookedTransformer): ) multi_store = ActivationsStore.from_config(ts_model, cfg_multi, override_dataset=dataset) - # Get tokens and activations from both batch_tokens_single = single_store.get_batch_tokens() activations_single = single_store.get_activations(batch_tokens_single) From c631c06fabff0e4d80ef2d924302b2a578c7741a Mon Sep 17 00:00:00 2001 From: Michael Behr Date: Tue, 6 May 2025 21:00:30 -0400 Subject: [PATCH 58/61] fix format --- sae_lens/config.py | 6 +- sae_lens/crosscoder_sae.py | 42 +++++---- sae_lens/evals.py | 16 ++-- sae_lens/sae.py | 22 +++-- sae_lens/sae_training_runner.py | 4 +- sae_lens/training/activations_store.py | 25 +++--- sae_lens/training/crosscoder_sae_trainer.py | 25 +++--- sae_lens/training/sae_trainer.py | 2 +- sae_lens/training/training_crosscoder_sae.py | 90 ++++++++++--------- sae_lens/training/training_sae.py | 5 +- scripts/acausal_crosscoder.py | 14 ++- tests/helpers.py | 23 +++-- tests/test_evals.py | 18 ++-- .../test_activations_store_multilayer.py | 69 
+++++++------- tests/training/test_crosscoder_sae.py | 74 +++++++++------ tests/training/test_crosscoder_sae_trainer.py | 36 ++++---- .../training/test_crosscoder_sae_training.py | 49 +++++----- .../training/test_training_crosscoder_sae.py | 17 ++-- 18 files changed, 302 insertions(+), 235 deletions(-) diff --git a/sae_lens/config.py b/sae_lens/config.py index b905f000a..ce49467bc 100644 --- a/sae_lens/config.py +++ b/sae_lens/config.py @@ -6,7 +6,6 @@ import simple_parsing import torch -import wandb from datasets import ( Dataset, DatasetDict, @@ -15,6 +14,7 @@ load_dataset, ) +import wandb from sae_lens import __version__, logger DTYPE_MAP = { @@ -437,7 +437,7 @@ def total_training_steps(self) -> int: return self.total_training_tokens // self.train_batch_size_tokens def get_base_sae_cfg_dict(self) -> dict[str, Any]: - cfg_dict = { + return { # TEMP "architecture": self.architecture, "d_in": self.d_in, @@ -463,8 +463,6 @@ def get_base_sae_cfg_dict(self) -> dict[str, Any]: "seqpos_slice": self.seqpos_slice, } - return cfg_dict - def get_training_sae_cfg_dict(self) -> dict[str, Any]: return { **self.get_base_sae_cfg_dict(), diff --git a/sae_lens/crosscoder_sae.py b/sae_lens/crosscoder_sae.py index 0837e991e..0540305af 100644 --- a/sae_lens/crosscoder_sae.py +++ b/sae_lens/crosscoder_sae.py @@ -1,17 +1,18 @@ from dataclasses import dataclass, field -from typing import Any, List +from typing import Any import einops import torch from jaxtyping import Float -from sae_lens import SAEConfig, SAE +from sae_lens import SAE, SAEConfig from sae_lens.toolkit.pretrained_sae_loaders import ( PretrainedSaeDiskLoader, handle_config_defaulting, sae_lens_disk_loader, ) + @dataclass class CrosscoderSAEConfig(SAEConfig): hook_names: list[int] = field(default_factory=list) @@ -21,19 +22,19 @@ def to_dict(self) -> dict[str, Any]: "hook_names": self.hook_names, } + class CrosscoderSAE(SAE): """ Sparse autoencoder that acts on multiple layers of activations. 
""" def __init__( - self, - cfg: CrosscoderSAEConfig, - use_error_term: bool = False, - ): + self, + cfg: CrosscoderSAEConfig, + use_error_term: bool = False, + ): if cfg.architecture != "standard": - raise NotImplementedError( - "TODO(mkbehr): support other architectures") + raise NotImplementedError("TODO(mkbehr): support other architectures") super().__init__(cfg=cfg, use_error_term=use_error_term) @@ -47,7 +48,6 @@ def from_dict(cls, config_dict: dict[str, Any]) -> "CrosscoderSAE": def input_shape(self): return (len(self.cfg.hook_names), self.cfg.d_in) - def encode_standard( self, x: Float[torch.Tensor, "... n_layers d_in"] ) -> Float[torch.Tensor, "... d_sae"]: @@ -58,10 +58,12 @@ def encode_standard( hidden_pre = self.hook_sae_acts_pre( einops.einsum( - sae_in, self.W_enc, - "... n_layers d_in, n_layers d_in d_sae -> ... d_sae" - ) - + self.b_enc) + sae_in, + self.W_enc, + "... n_layers d_in, n_layers d_in d_sae -> ... d_sae", + ) + + self.b_enc + ) return self.hook_sae_acts_post(self.activation_fn(hidden_pre)) def decode( @@ -73,8 +75,9 @@ def decode( einops.einsum( self.apply_finetuning_scaling_factor(feature_acts), self.W_dec, - "... d_sae, d_sae n_layers d_in -> ... n_layers d_in" - ) + self.b_dec + "... d_sae, d_sae n_layers d_in -> ... 
n_layers d_in", + ) + + self.b_dec ) # handle run time activation normalization if needed @@ -86,10 +89,11 @@ def decode( @torch.no_grad() def fold_W_dec_norm(self): - W_dec_norms = self.W_dec.norm(dim=[-2,-1], keepdim=True) + W_dec_norms = self.W_dec.norm(dim=[-2, -1], keepdim=True) self.W_dec.data = self.W_dec.data / W_dec_norms self.W_enc.data = self.W_enc.data * einops.rearrange( - W_dec_norms, "d_sae 1 1 -> 1 1 d_sae") + W_dec_norms, "d_sae 1 1 -> 1 1 d_sae" + ) if self.cfg.architecture == "gated": self.r_mag.data = self.r_mag.data * W_dec_norms.squeeze() self.b_gate.data = self.b_gate.data * W_dec_norms.squeeze() @@ -104,7 +108,9 @@ def fold_W_dec_norm(self): def fold_activation_norm_scaling_factor( self, activation_norm_scaling_factor: Float[torch.Tensor, "n_layers"] ): - self.W_enc.data = self.W_enc.data * activation_norm_scaling_factor.reshape((-1,1,1)) + self.W_enc.data = self.W_enc.data * activation_norm_scaling_factor.reshape( + (-1, 1, 1) + ) # previously weren't doing this. 
self.W_dec.data = self.W_dec.data / activation_norm_scaling_factor.unsqueeze(-1) self.b_dec.data = self.b_dec.data / activation_norm_scaling_factor.unsqueeze(-1) diff --git a/sae_lens/evals.py b/sae_lens/evals.py index 2394b1799..5b0a25af6 100644 --- a/sae_lens/evals.py +++ b/sae_lens/evals.py @@ -378,7 +378,9 @@ def get_sparsity_and_variance_metrics( ignore_tokens: set[int | None] = set(), verbose: bool = False, ) -> tuple[dict[str, Any], dict[str, Any]]: - hook_names = sae.cfg.hook_names if hasattr(sae.cfg, "hook_names") else [sae.cfg.hook_name] + hook_names = ( + sae.cfg.hook_names if hasattr(sae.cfg, "hook_names") else [sae.cfg.hook_name] + ) hook_head_index = sae.cfg.hook_head_index metric_dict = {} @@ -446,14 +448,14 @@ def get_sparsity_and_variance_metrics( # TODO(mkbehr) support head dimension for mutilayer evals assert len(hook_names) == 1 original_act = cache[hook_names[0]][:, :, hook_head_index] - elif any(substring in hook_names[0] for substring in has_head_dim_key_substrings): + elif any( + substring in hook_names[0] for substring in has_head_dim_key_substrings + ): # TODO(mkbehr) support head dimension for mutilayer evals original_act = cache[hook_names[0]].flatten(-2, -1) elif hasattr(sae.cfg, "hook_names"): # TODO(mkbehr): support head dimension for mutilayer evals - layerwise_activations = [ - cache[hook_name] for hook_name in hook_names - ] + layerwise_activations = [cache[hook_name] for hook_name in hook_names] original_act = torch.stack(layerwise_activations, dim=2) else: original_act = cache[hook_names[0]] @@ -470,7 +472,9 @@ def get_sparsity_and_variance_metrics( if activation_store.normalize_activations == "expected_average_only_in": sae_out = activation_store.unscale(sae_out) - flattened_sae_input = einops.rearrange(original_act, "b ctx d ... -> (b ctx) (d ...)") + flattened_sae_input = einops.rearrange( + original_act, "b ctx d ... 
-> (b ctx) (d ...)" + ) flattened_sae_feature_acts = einops.rearrange( sae_feature_activations, "b ctx d -> (b ctx) d" ) diff --git a/sae_lens/sae.py b/sae_lens/sae.py index ef9cfc701..3ca586ebb 100644 --- a/sae_lens/sae.py +++ b/sae_lens/sae.py @@ -7,7 +7,7 @@ from contextlib import contextmanager from dataclasses import dataclass, field from pathlib import Path -from typing import Any, Callable, List, Literal, TypeVar, overload +from typing import Any, Callable, Literal, TypeVar, overload import einops import torch @@ -257,7 +257,10 @@ def initialize_weights_basic(self): self.W_dec = nn.Parameter( torch.nn.init.kaiming_uniform_( torch.empty( - self.cfg.d_sae, *self.input_shape(), dtype=self.dtype, device=self.device + self.cfg.d_sae, + *self.input_shape(), + dtype=self.dtype, + device=self.device, ) ) ) @@ -265,7 +268,10 @@ def initialize_weights_basic(self): self.W_enc = nn.Parameter( torch.nn.init.kaiming_uniform_( torch.empty( - *self.input_shape(), self.cfg.d_sae, dtype=self.dtype, device=self.device + *self.input_shape(), + self.cfg.d_sae, + dtype=self.dtype, + device=self.device, ) ) ) @@ -287,7 +293,10 @@ def initialize_weights_gated(self): self.W_enc = nn.Parameter( torch.nn.init.kaiming_uniform_( torch.empty( - *self.input_shape(), self.cfg.d_sae, dtype=self.dtype, device=self.device + *self.input_shape(), + self.cfg.d_sae, + dtype=self.dtype, + device=self.device, ) ) ) @@ -307,7 +316,10 @@ def initialize_weights_gated(self): self.W_dec = nn.Parameter( torch.nn.init.kaiming_uniform_( torch.empty( - self.cfg.d_sae, *self.input_shape(), dtype=self.dtype, device=self.device + self.cfg.d_sae, + *self.input_shape(), + dtype=self.dtype, + device=self.device, ) ) ) diff --git a/sae_lens/sae_training_runner.py b/sae_lens/sae_training_runner.py index 54f0d4eb5..aa5dc8fed 100644 --- a/sae_lens/sae_training_runner.py +++ b/sae_lens/sae_training_runner.py @@ -6,18 +6,18 @@ from typing import Any, cast import torch -import wandb from simple_parsing import 
ArgumentParser from transformer_lens.hook_points import HookedRootModule +import wandb from sae_lens import logger from sae_lens.config import HfDataset, LanguageModelSAERunnerConfig from sae_lens.load_model import load_model from sae_lens.training.activations_store import ActivationsStore from sae_lens.training.crosscoder_sae_trainer import CrosscoderSAETrainer -from sae_lens.training.training_crosscoder_sae import TrainingCrosscoderSAE from sae_lens.training.geometric_median import compute_geometric_median from sae_lens.training.sae_trainer import SAETrainer +from sae_lens.training.training_crosscoder_sae import TrainingCrosscoderSAE from sae_lens.training.training_sae import TrainingSAE, TrainingSAEConfig diff --git a/sae_lens/training/activations_store.py b/sae_lens/training/activations_store.py index 5518b54a5..d6c9ccd63 100644 --- a/sae_lens/training/activations_store.py +++ b/sae_lens/training/activations_store.py @@ -438,9 +438,10 @@ def apply_norm_scaling_factor(self, activations: torch.Tensor) -> torch.Tensor: # Norm scaling factor is a float in the single-layer case, and # a tensor in the multilayer case. 
if self.hook_names: - return activations * self.estimated_norm_scaling_factor.unsqueeze(-1).to(activations.device) - else: - return activations * self.estimated_norm_scaling_factor + return activations * self.estimated_norm_scaling_factor.unsqueeze(-1).to( + activations.device + ) + return activations * self.estimated_norm_scaling_factor def unscale(self, activations: torch.Tensor) -> torch.Tensor: if self.estimated_norm_scaling_factor is None: @@ -449,8 +450,7 @@ def unscale(self, activations: torch.Tensor) -> torch.Tensor: ) if self.hook_names: return activations / self.estimated_norm_scaling_factor.unsqueeze(-1) - else: - return activations / self.estimated_norm_scaling_factor + return activations / self.estimated_norm_scaling_factor def get_norm_scaling_factor(self, activations: torch.Tensor) -> torch.Tensor: return (self.d_in**0.5) / activations.norm(dim=-1).mean() @@ -458,8 +458,8 @@ def get_norm_scaling_factor(self, activations: torch.Tensor) -> torch.Tensor: @torch.no_grad() def estimate_norm_scaling_factor(self, n_batches_for_norm_estimate: int = int(1e3)): norms_per_batch = torch.empty( - len(self.hook_names) or 1, n_batches_for_norm_estimate, - device=self.device) + len(self.hook_names) or 1, n_batches_for_norm_estimate, device=self.device + ) for batch_i in tqdm( range(n_batches_for_norm_estimate), desc="Estimating norm scaling factor" ): @@ -472,9 +472,8 @@ def estimate_norm_scaling_factor(self, n_batches_for_norm_estimate: int = int(1e # Norm scaling factor is a float in the single-layer case, and # a tensor in the multilayer case. 
if self.hook_names: - return (np.sqrt(self.d_in) / mean_norm) - else: - return (np.sqrt(self.d_in) / mean_norm.item()) + return np.sqrt(self.d_in) / mean_norm + return np.sqrt(self.d_in) / mean_norm.item() def shuffle_input_dataset(self, seed: int, buffer_size: int = 1): """ @@ -567,9 +566,7 @@ def get_activations(self, batch_tokens: torch.Tensor): )[1] layerwise_activations = [ - layerwise_activations_cache[hook_name][ - :, slice(*self.seqpos_slice) - ] + layerwise_activations_cache[hook_name][:, slice(*self.seqpos_slice)] for hook_name in hook_names ] @@ -578,7 +575,7 @@ def get_activations(self, batch_tokens: torch.Tensor): if self.hook_head_index is not None: layerwise_activations = [ activation[:, :, self.hook_head_index] - for activation in layerwise_activations + for activation in layerwise_activations ] elif layerwise_activations[0].ndim > 3: # if we have a head dimension try: diff --git a/sae_lens/training/crosscoder_sae_trainer.py b/sae_lens/training/crosscoder_sae_trainer.py index 26d30596a..b158de362 100644 --- a/sae_lens/training/crosscoder_sae_trainer.py +++ b/sae_lens/training/crosscoder_sae_trainer.py @@ -1,32 +1,33 @@ from typing import Any import torch -import wandb from tqdm import tqdm +import wandb from sae_lens.evals import run_evals -from sae_lens.training.sae_trainer import SAETrainer, _unwrap_item -from sae_lens.training.training_sae import TrainingSAE, TrainStepOutput -from sae_lens.training.training_crosscoder_sae import TrainingCrosscoderSAE, TrainStepOutput +from sae_lens.training.sae_trainer import SAETrainer +from sae_lens.training.training_crosscoder_sae import TrainingCrosscoderSAE +from sae_lens.training.training_sae import TrainStepOutput + class CrosscoderSAETrainer(SAETrainer): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) # Reconstruction metrics don't make sense for acausal crosscoders. 
- self.trainer_eval_config.compute_ce_loss=False - self.trainer_eval_config.compute_kl=False + self.trainer_eval_config.compute_ce_loss = False + self.trainer_eval_config.compute_kl = False def fit(self) -> TrainingCrosscoderSAE: - pbar = tqdm(total=self.cfg.total_training_tokens, desc="Training Crosscoder SAE") + pbar = tqdm( + total=self.cfg.total_training_tokens, desc="Training Crosscoder SAE" + ) self.activations_store.set_norm_scaling_factor_if_needed() # Train loop while self.n_training_tokens < self.cfg.total_training_tokens: # Do a training step. - layer_acts = self.activations_store.next_batch().to( - self.sae.device - ) + layer_acts = self.activations_store.next_batch().to(self.sae.device) self.n_training_tokens += self.cfg.train_batch_size_tokens step_output = self._train_step(sae=self.sae, sae_in=layer_acts) @@ -110,7 +111,9 @@ def _run_and_log_evals(self): # Remove metrics that are not useful for wandb logging eval_metrics.pop("metrics/total_tokens_evaluated", None) - W_dec_norm_dist = self.sae.W_dec.detach().float().norm(dim=(1,2)).cpu().numpy() + W_dec_norm_dist = ( + self.sae.W_dec.detach().float().norm(dim=(1, 2)).cpu().numpy() + ) eval_metrics["weights/W_dec_norms"] = wandb.Histogram(W_dec_norm_dist) # type: ignore if self.sae.cfg.architecture == "standard": diff --git a/sae_lens/training/sae_trainer.py b/sae_lens/training/sae_trainer.py index eef087e1a..659da1a23 100644 --- a/sae_lens/training/sae_trainer.py +++ b/sae_lens/training/sae_trainer.py @@ -3,11 +3,11 @@ from typing import Any, Protocol, cast import torch -import wandb from torch.optim import Adam from tqdm import tqdm from transformer_lens.hook_points import HookedRootModule +import wandb from sae_lens import __version__ from sae_lens.config import LanguageModelSAERunnerConfig from sae_lens.evals import EvalConfig, run_evals diff --git a/sae_lens/training/training_crosscoder_sae.py b/sae_lens/training/training_crosscoder_sae.py index 5ee0045b7..5e9385996 100644 --- 
a/sae_lens/training/training_crosscoder_sae.py +++ b/sae_lens/training/training_crosscoder_sae.py @@ -1,5 +1,3 @@ -import json -import os from dataclasses import dataclass from typing import Any @@ -10,16 +8,16 @@ from sae_lens.config import LanguageModelSAERunnerConfig from sae_lens.crosscoder_sae import CrosscoderSAE, CrosscoderSAEConfig -from sae_lens.training.training_sae import ( - TrainingSAEConfig, - TrainingSAE, - TrainStepOutput, - ) from sae_lens.toolkit.pretrained_sae_loaders import ( PretrainedSaeDiskLoader, handle_config_defaulting, sae_lens_disk_loader, ) +from sae_lens.training.training_sae import ( + TrainingSAE, + TrainingSAEConfig, + TrainStepOutput, +) SPARSITY_PATH = "sparsity.safetensors" SAE_WEIGHTS_PATH = "sae_weights.safetensors" @@ -75,16 +73,19 @@ def from_sae_runner_config( ) def to_dict(self) -> dict[str, Any]: - return (TrainingSAEConfig.to_dict(self) - | CrosscoderSAEConfig.to_dict(self) - | { - "sparsity_penalty_decoder_norm_lp_norm": - self.sparsity_penalty_decoder_norm_lp_norm, - }) + return ( + TrainingSAEConfig.to_dict(self) + | CrosscoderSAEConfig.to_dict(self) + | { + "sparsity_penalty_decoder_norm_lp_norm": self.sparsity_penalty_decoder_norm_lp_norm, + } + ) def get_base_sae_cfg_dict(self) -> dict[str, Any]: - return (TrainingSAEConfig.get_base_sae_cfg_dict(self) - | { "hook_names": self.hook_names }) + return TrainingSAEConfig.get_base_sae_cfg_dict(self) | { + "hook_names": self.hook_names + } + class TrainingCrosscoderSAE(CrosscoderSAE, TrainingSAE): # TODO(mkbehr) future implementation @@ -94,17 +95,17 @@ class TrainingCrosscoderSAE(CrosscoderSAE, TrainingSAE): # calculate_ghost_grad_loss # fold_W_dec_norm for jumprelu - def __init__(self, - cfg: TrainingCrosscoderSAEConfig, - use_error_term: bool = False): + def __init__(self, cfg: TrainingCrosscoderSAEConfig, use_error_term: bool = False): super().__init__(cfg, use_error_term=use_error_term) @classmethod - def from_dict(cls, - config_dict: dict[str, Any], - use_error_term: 
bool = False) -> "TrainingSAE": - return cls(TrainingCrosscoderSAEConfig.from_dict(config_dict), - use_error_term = use_error_term) + def from_dict( + cls, config_dict: dict[str, Any], use_error_term: bool = False + ) -> "TrainingSAE": + return cls( + TrainingCrosscoderSAEConfig.from_dict(config_dict), + use_error_term=use_error_term, + ) @staticmethod def base_sae_cfg(cfg: TrainingCrosscoderSAEConfig): @@ -114,7 +115,9 @@ def check_cfg_compatibility(self): if self.cfg.architecture != "standard": raise NotImplementedError("TODO(mkbehr): support other archs") if not self.cfg.scale_sparsity_penalty_by_decoder_norm: - raise ValueError("Crosscoders require scale_sparsity_penalty_by_decoder_norm") + raise ValueError( + "Crosscoders require scale_sparsity_penalty_by_decoder_norm" + ) if not self.use_error_term: raise NotImplementedError("TODO(mkbehr): support causal crosscoders") if self.cfg.use_ghost_grads: @@ -128,10 +131,12 @@ def encode_with_hidden_pre( hidden_pre = self.hook_sae_acts_pre( einops.einsum( - sae_in, self.W_enc, - "... n_layers d_in, n_layers d_in d_sae -> ... d_sae" - ) - + self.b_enc) + sae_in, + self.W_enc, + "... n_layers d_in, n_layers d_in d_sae -> ... 
d_sae", + ) + + self.b_enc + ) hidden_pre_noised = hidden_pre + ( torch.randn_like(hidden_pre) * self.cfg.noise_scale * self.training ) @@ -159,8 +164,7 @@ def training_forward_pass( assert self.cfg.scale_sparsity_penalty_by_decoder_norm decoder_norms = self.W_dec.norm(dim=2) feature_act_weights = decoder_norms.norm( - p=self.cfg.sparsity_penalty_decoder_norm_lp_norm, - dim=1 + p=self.cfg.sparsity_penalty_decoder_norm_lp_norm, dim=1 ) weighted_feature_acts = feature_acts * feature_act_weights sparsity = weighted_feature_acts.norm( @@ -169,11 +173,7 @@ def training_forward_pass( l1_loss = (current_l1_coefficient * sparsity).mean() loss = mse_loss + l1_loss - if ( - self.cfg.use_ghost_grads - and self.training - and dead_neuron_mask is not None - ): + if self.cfg.use_ghost_grads and self.training and dead_neuron_mask is not None: ghost_grad_loss = self.calculate_ghost_grad_loss( x=sae_in, sae_out=sae_out, @@ -216,20 +216,25 @@ def load_from_disk( def initialize_weights_complex(self): if self.cfg.decoder_orthogonal_init: self.W_dec.data = nn.init.orthogonal_( - self.W_dec.data.permute((1,2,0)) - ).permute((2,0,1)) + self.W_dec.data.permute((1, 2, 0)) + ).permute((2, 0, 1)) elif self.cfg.decoder_heuristic_init: self.W_dec = nn.Parameter( torch.rand( - self.cfg.d_sae, *self.input_shape(), dtype=self.dtype, device=self.device + self.cfg.d_sae, + *self.input_shape(), + dtype=self.dtype, + device=self.device, ) ) - self.initialize_decoder_norm_constant_norm(self.cfg.decoder_heuristic_init_norm) + self.initialize_decoder_norm_constant_norm( + self.cfg.decoder_heuristic_init_norm + ) # Then we initialize the encoder weights (either as the transpose of decoder or not) if self.cfg.init_encoder_as_decoder_transpose: - self.W_enc.data = self.W_dec.data.permute((1,2,0)).clone().contiguous() + self.W_enc.data = self.W_dec.data.permute((1, 2, 0)).clone().contiguous() else: self.W_enc = nn.Parameter( torch.nn.init.kaiming_uniform_( @@ -249,7 +254,7 @@ def 
initialize_weights_complex(self): @torch.no_grad() def set_decoder_norm_to_unit_norm(self): - self.W_dec.data /= torch.norm(self.W_dec.data, dim=[1,2], keepdim=True) + self.W_dec.data /= torch.norm(self.W_dec.data, dim=[1, 2], keepdim=True) @torch.no_grad() def initialize_decoder_norm_constant_norm(self, norm: float = 0.1): @@ -260,7 +265,7 @@ def initialize_decoder_norm_constant_norm(self, norm: float = 0.1): # TODO: Parameterise this as a function of m and n # ensure W_dec norms at unit norm - self.W_dec.data /= torch.norm(self.W_dec.data, dim=[1,2], keepdim=True) + self.W_dec.data /= torch.norm(self.W_dec.data, dim=[1, 2], keepdim=True) self.W_dec.data *= norm # will break tests but do this for now. @torch.no_grad() @@ -281,4 +286,3 @@ def remove_gradient_parallel_to_decoder_directions(self): self.W_dec.data, "d_sae, d_sae n_layers d_in -> d_sae n_layers d_in", ) - diff --git a/sae_lens/training/training_sae.py b/sae_lens/training/training_sae.py index 490b5d3a0..51d5d9c08 100644 --- a/sae_lens/training/training_sae.py +++ b/sae_lens/training/training_sae.py @@ -600,7 +600,10 @@ def initialize_weights_complex(self): elif self.cfg.decoder_heuristic_init: self.W_dec = nn.Parameter( torch.rand( - self.cfg.d_sae, *self.input_shape(), dtype=self.dtype, device=self.device + self.cfg.d_sae, + *self.input_shape(), + dtype=self.dtype, + device=self.device, ) ) self.initialize_decoder_norm_constant_norm( diff --git a/scripts/acausal_crosscoder.py b/scripts/acausal_crosscoder.py index cac062460..10a1d7c00 100644 --- a/scripts/acausal_crosscoder.py +++ b/scripts/acausal_crosscoder.py @@ -6,11 +6,11 @@ sys.path.append("..") from sae_lens.config import LanguageModelSAERunnerConfig +from sae_lens.sae_training_runner import SAETrainingRunner from sae_lens.training.training_crosscoder_sae import ( TrainingCrosscoderSAE, - TrainingCrosscoderSAEConfig + TrainingCrosscoderSAEConfig, ) -from sae_lens.sae_training_runner import SAETrainingRunner if torch.cuda.is_available(): device = 
"cuda" @@ -54,9 +54,7 @@ d_sae = d_in * expansion_factor learning_rate = 2e-5 l1_coefficient = 1 -hook_name = hook_name_template.format( - layer=f"{min(layers)}_through_{max(layers)}" -) +hook_name = hook_name_template.format(layer=f"{min(layers)}_through_{max(layers)}") hook_names = [hook_name_template.format(layer=layer) for layer in layers] cfg = LanguageModelSAERunnerConfig( @@ -122,10 +120,10 @@ sae = SAETrainingRunner( cfg, - override_sae = TrainingCrosscoderSAE( + override_sae=TrainingCrosscoderSAE( TrainingCrosscoderSAEConfig.from_sae_runner_config(cfg), use_error_term=True, - )).run() + ), +).run() print("=" * 50) - diff --git a/tests/helpers.py b/tests/helpers.py index 2e3999698..e0cc6889c 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -97,21 +97,26 @@ def build_sae_cfg(**kwargs: Any) -> LanguageModelSAERunnerConfig: return mock_config + def build_multilayer_sae_cfg( - hook_name_template : str = "blocks.{layer}.hook_mlp_out", - hook_layers = [0,1,2], - **kwargs: Any) -> LanguageModelSAERunnerConfig: + hook_name_template: str = "blocks.{layer}.hook_mlp_out", + hook_layers=[0, 1, 2], + **kwargs: Any, +) -> LanguageModelSAERunnerConfig: hook_name = hook_name_template.format( layer=f"layers_{min(hook_layers)}_through_{max(hook_layers)}" - ) + ) hook_names = [hook_name_template.format(layer=layer) for layer in hook_layers] return build_sae_cfg( - **({ - "hook_name": hook_name, - "hook_names": hook_names, - "hook_layer": max(hook_layers), + **( + { + "hook_name": hook_name, + "hook_names": hook_names, + "hook_layer": max(hook_layers), } - | kwargs)) + | kwargs + ) + ) MODEL_CACHE: dict[str, HookedTransformer] = {} diff --git a/tests/test_evals.py b/tests/test_evals.py index bc6a426b4..7c5d9636d 100644 --- a/tests/test_evals.py +++ b/tests/test_evals.py @@ -30,7 +30,12 @@ TrainingCrosscoderSAEConfig, ) from sae_lens.training.training_sae import TrainingSAE -from tests.helpers import TINYSTORIES_MODEL, build_sae_cfg, build_multilayer_sae_cfg, 
load_model_cached +from tests.helpers import ( + TINYSTORIES_MODEL, + build_multilayer_sae_cfg, + build_sae_cfg, + load_model_cached, +) TRAINER_EVAL_CONFIG = EvalConfig( n_eval_reconstruction_batches=10, @@ -287,9 +292,10 @@ def test_run_empty_evals( assert "token_stats" in eval_metrics, "Expected token_stats in eval_metrics" assert len(feature_metrics) == 0, "Expected empty feature_metrics" + # TODO(mkbehr): consider parameterizing def test_run_evals_crosscoder_training_sae(model): - cfg=build_multilayer_sae_cfg( + cfg = build_multilayer_sae_cfg( model_name="tiny-stories-1M", dataset_path="roneneldan/TinyStories", hook_name_template="blocks.{layer}.hook_resid_pre", @@ -302,8 +308,8 @@ def test_run_evals_crosscoder_training_sae(model): model, cfg, override_dataset=Dataset.from_list([{"text": "hello world"}] * 2000) ) training_crosscoder_sae = TrainingCrosscoderSAE( - TrainingCrosscoderSAEConfig.from_sae_runner_config(cfg), - use_error_term=True) + TrainingCrosscoderSAEConfig.from_sae_runner_config(cfg), use_error_term=True + ) eval_config = EvalConfig( compute_l2_norms=True, compute_sparsity_metrics=True, @@ -326,7 +332,9 @@ def test_run_evals_crosscoder_training_sae(model): ] assert set(eval_metrics.keys()) == set(expected_keys) assert set(feature_metrics.keys()) == set( - ["feature_density", "consistent_activation_heuristic"]) + ["feature_density", "consistent_activation_heuristic"] + ) + @pytest.fixture def mock_args(): diff --git a/tests/training/test_activations_store_multilayer.py b/tests/training/test_activations_store_multilayer.py index 24eaab70e..d63e6bf4f 100644 --- a/tests/training/test_activations_store_multilayer.py +++ b/tests/training/test_activations_store_multilayer.py @@ -2,23 +2,20 @@ import os import tempfile -from typing import Any -import pytest import torch from datasets import Dataset from safetensors.torch import load_file from transformer_lens import HookedTransformer from sae_lens.training.activations_store import ActivationsStore -from 
tests.helpers import build_sae_cfg, build_multilayer_sae_cfg, load_model_cached +from tests.helpers import build_multilayer_sae_cfg, build_sae_cfg def test_activations_store_init_with_multiple_layers(ts_model: HookedTransformer): """Test initialization with a list of layers instead of a single layer.""" cfg = build_multilayer_sae_cfg( - hook_name_template="blocks.{layer}.hook_resid_pre", - hook_layers=[0, 1, 2] + hook_name_template="blocks.{layer}.hook_resid_pre", hook_layers=[0, 1, 2] ) activation_store = ActivationsStore.from_config(ts_model, cfg) @@ -30,8 +27,7 @@ def test_activations_store_init_with_multiple_layers(ts_model: HookedTransformer ] cfg_single = build_multilayer_sae_cfg( - hook_name_template="blocks.{layer}.hook_resid_pre", - hook_layers=[1] + hook_name_template="blocks.{layer}.hook_resid_pre", hook_layers=[1] ) single_layer_store = ActivationsStore.from_config(ts_model, cfg_single) @@ -45,11 +41,13 @@ def test_activations_store_get_activations_multiple_layers(ts_model: HookedTrans cfg = build_multilayer_sae_cfg( hook_name_template="blocks.{layer}.hook_resid_pre", hook_layers=[0, 1, 2], - context_size=5 + context_size=5, ) dataset = Dataset.from_list([{"text": "hello world"}] * 10) - activation_store = ActivationsStore.from_config(ts_model, cfg, override_dataset=dataset) + activation_store = ActivationsStore.from_config( + ts_model, cfg, override_dataset=dataset + ) batch_tokens = activation_store.get_batch_tokens() activations = activation_store.get_activations(batch_tokens) @@ -59,24 +57,19 @@ def test_activations_store_get_activations_multiple_layers(ts_model: HookedTrans cfg.store_batch_size_prompts, cfg.context_size, len(cfg.hook_names), - cfg.d_in + cfg.d_in, ) # Verify that layers are in the correct order # Run with cache directly to compare against _, cache = ts_model.run_with_cache( - batch_tokens, - names_filter=[f"blocks.{i}.hook_resid_pre" for i in [0, 1, 2]] + batch_tokens, names_filter=[f"blocks.{i}.hook_resid_pre" for i in [0, 1, 2]] 
) for i, layer in enumerate([0, 1, 2]): hook_name = f"blocks.{layer}.hook_resid_pre" # Compare the activations for this layer with what we got from run_with_cache - assert torch.allclose( - activations[:, :, i, :], - cache[hook_name], - atol=1e-5 - ) + assert torch.allclose(activations[:, :, i, :], cache[hook_name], atol=1e-5) def test_activations_store_get_buffer_multiple_layers(ts_model: HookedTransformer): @@ -84,13 +77,17 @@ def test_activations_store_get_buffer_multiple_layers(ts_model: HookedTransforme cfg = build_multilayer_sae_cfg( hook_name_template="blocks.{layer}.hook_resid_pre", hook_layers=[0, 1, 2], - context_size=5 + context_size=5, ) dataset = Dataset.from_list([{"text": "hello world"}] * 20) - activation_store = ActivationsStore.from_config(ts_model, cfg, override_dataset=dataset) + activation_store = ActivationsStore.from_config( + ts_model, cfg, override_dataset=dataset + ) - buffer_activations, buffer_tokens = activation_store.get_buffer(n_batches_in_buffer=2) + buffer_activations, buffer_tokens = activation_store.get_buffer( + n_batches_in_buffer=2 + ) # Check shape: [(batch_size * context_size * n_batches), num_layers, d_in] expected_size = cfg.store_batch_size_prompts * cfg.context_size * 2 @@ -104,11 +101,13 @@ def test_activations_store_next_batch_multiple_layers(ts_model: HookedTransforme hook_name_template="blocks.{layer}.hook_resid_pre", hook_layers=[0, 1, 2], context_size=5, - train_batch_size_tokens=10 + train_batch_size_tokens=10, ) dataset = Dataset.from_list([{"text": "hello world"}] * 20) - activation_store = ActivationsStore.from_config(ts_model, cfg, override_dataset=dataset) + activation_store = ActivationsStore.from_config( + ts_model, cfg, override_dataset=dataset + ) batch = activation_store.next_batch() assert batch.shape == (10, len(cfg.hook_names), activation_store.d_in) @@ -120,37 +119,41 @@ def test_activations_store_normalization_multiple_layers(ts_model: HookedTransfo hook_name_template="blocks.{layer}.hook_resid_pre", 
hook_layers=[0, 1, 2], normalize_activations="expected_average_only_in", - context_size=5 + context_size=5, ) dataset = Dataset.from_list([{"text": "hello world"}] * 20) - activation_store = ActivationsStore.from_config(ts_model, cfg, override_dataset=dataset) + activation_store = ActivationsStore.from_config( + ts_model, cfg, override_dataset=dataset + ) activation_store.set_norm_scaling_factor_if_needed() batch = activation_store.next_batch() avg_norm = batch.norm(dim=-1).mean(dim=1) - expected_norm = torch.full_like(avg_norm, cfg.d_in ** 0.5) + expected_norm = torch.full_like(avg_norm, cfg.d_in**0.5) torch.testing.assert_close(avg_norm, expected_norm, atol=1.0, rtol=0.1) def test_backward_compatibility_single_layer(ts_model: HookedTransformer): """Test that single layer behavior is unchanged with the multi-layer support.""" cfg_single = build_sae_cfg( - hook_name="blocks.0.hook_resid_pre", - hook_layer=0, - context_size=5 + hook_name="blocks.0.hook_resid_pre", hook_layer=0, context_size=5 ) dataset = Dataset.from_list([{"text": "hello world"}] * 10) - single_store = ActivationsStore.from_config(ts_model, cfg_single, override_dataset=dataset) + single_store = ActivationsStore.from_config( + ts_model, cfg_single, override_dataset=dataset + ) cfg_multi = build_multilayer_sae_cfg( hook_name_template="blocks.{layer}.hook_resid_pre", hook_layers=[0], - context_size=5 + context_size=5, + ) + multi_store = ActivationsStore.from_config( + ts_model, cfg_multi, override_dataset=dataset ) - multi_store = ActivationsStore.from_config(ts_model, cfg_multi, override_dataset=dataset) batch_tokens_single = single_store.get_batch_tokens() activations_single = single_store.get_activations(batch_tokens_single) @@ -169,7 +172,7 @@ def test_activations_store_multilayer_save_with_norm_scaling_factor( hook_name_template="blocks.{layer}.hook_resid_pre", hook_layers=[0, 1, 2], normalize_activations="expected_average_only_in", - context_size=5 + context_size=5, ) activation_store = 
ActivationsStore.from_config(ts_model, cfg) activation_store.set_norm_scaling_factor_if_needed() @@ -184,5 +187,5 @@ def test_activations_store_multilayer_save_with_norm_scaling_factor( assert estimated_norm_scaling_factor.shape == (len(cfg.hook_names),) torch.testing.assert_close( estimated_norm_scaling_factor, - activation_store.estimated_norm_scaling_factor + activation_store.estimated_norm_scaling_factor, ) diff --git a/tests/training/test_crosscoder_sae.py b/tests/training/test_crosscoder_sae.py index acdc03913..f78899906 100644 --- a/tests/training/test_crosscoder_sae.py +++ b/tests/training/test_crosscoder_sae.py @@ -5,12 +5,9 @@ import einops import pytest import torch -from torch import nn -from transformer_lens.hook_points import HookPoint from sae_lens.config import LanguageModelSAERunnerConfig from sae_lens.crosscoder_sae import CrosscoderSAE -from sae_lens.sae import _disable_hooks from tests.helpers import ALL_ARCHITECTURES, build_multilayer_sae_cfg @@ -21,7 +18,7 @@ "model_name": "tiny-stories-1M", "dataset_path": "roneneldan/TinyStories", "hook_name_template": "blocks.{layer}.hook_resid_pre", - "hook_layers": [0,1,2], + "hook_layers": [0, 1, 2], "d_in": 64, "normalize_sae_decoder": False, "scale_sparsity_penalty_by_decoder_norm": True, @@ -30,7 +27,7 @@ "model_name": "tiny-stories-1M", "dataset_path": "roneneldan/TinyStories", "hook_name_template": "blocks.{layer}.hook_resid_pre", - "hook_layers": [0,1,2], + "hook_layers": [0, 1, 2], "d_in": 64, "normalize_sae_decoder": False, "scale_sparsity_penalty_by_decoder_norm": True, @@ -39,7 +36,7 @@ "model_name": "tiny-stories-1M", "dataset_path": "apollo-research/roneneldan-TinyStories-tokenizer-gpt2", "hook_name_template": "blocks.{layer}.hook_resid_pre", - "hook_layers": [0,1,2], + "hook_layers": [0, 1, 2], "d_in": 64, "normalize_sae_decoder": False, "scale_sparsity_penalty_by_decoder_norm": True, @@ -85,23 +82,23 @@ def test_crosscoder_sae_init(cfg: LanguageModelSAERunnerConfig): def 
test_crosscoder_sae_fold_w_dec_norm(cfg: LanguageModelSAERunnerConfig): sae = CrosscoderSAE.from_dict(cfg.get_base_sae_cfg_dict()) sae.turn_off_forward_pass_hook_z_reshaping() # hook z reshaping not needed here. - assert sae.W_dec.norm(dim=[-2,-1]).mean().item() != pytest.approx(1.0, abs=1e-6) + assert sae.W_dec.norm(dim=[-2, -1]).mean().item() != pytest.approx(1.0, abs=1e-6) sae2 = deepcopy(sae) sae2.fold_W_dec_norm() - W_dec_norms = sae.W_dec.norm(dim=[-2,-1], keepdim=True) + W_dec_norms = sae.W_dec.norm(dim=[-2, -1], keepdim=True) assert torch.allclose(sae2.W_dec.data, sae.W_dec.data / W_dec_norms) - assert torch.allclose(sae2.W_enc.data, - sae.W_enc.data * einops.rearrange( - W_dec_norms, "d_sae 1 1 -> 1 1 d_sae")) + assert torch.allclose( + sae2.W_enc.data, + sae.W_enc.data * einops.rearrange(W_dec_norms, "d_sae 1 1 -> 1 1 d_sae"), + ) assert torch.allclose(sae2.b_enc.data, sae.b_enc.data * W_dec_norms.squeeze()) # fold_W_dec_norm should normalize W_dec to have unit norm. - assert sae2.W_dec.norm(dim=[-2,-1]).mean().item() == pytest.approx(1.0, abs=1e-6) + assert sae2.W_dec.norm(dim=[-2, -1]).mean().item() == pytest.approx(1.0, abs=1e-6) # we expect activations of features to differ by W_dec norm weights. 
- activations = torch.randn(10, 4, len(cfg.hook_names), cfg.d_in, - device=cfg.device) + activations = torch.randn(10, 4, len(cfg.hook_names), cfg.d_in, device=cfg.device) feature_activations_1 = sae.encode(activations) feature_activations_2 = sae2.encode(activations) @@ -110,7 +107,9 @@ def test_crosscoder_sae_fold_w_dec_norm(cfg: LanguageModelSAERunnerConfig): feature_activations_2.nonzero(), ) - expected_feature_activations_2 = feature_activations_1 * sae.W_dec.norm(dim=[-2,-1]) + expected_feature_activations_2 = feature_activations_1 * sae.W_dec.norm( + dim=[-2, -1] + ) torch.testing.assert_close(feature_activations_2, expected_feature_activations_2) sae_out_1 = sae.decode(feature_activations_1) @@ -119,12 +118,13 @@ def test_crosscoder_sae_fold_w_dec_norm(cfg: LanguageModelSAERunnerConfig): # but actual outputs should be the same torch.testing.assert_close(sae_out_1, sae_out_2) + @pytest.mark.parametrize("architecture", ALL_ARCHITECTURES) @torch.no_grad() def test_crosscoder_sae_fold_w_dec_norm_all_architectures(architecture: str): if architecture != "standard": pytest.xfail("TODO(mkbehr): support other architectures") - cfg = build_multilayer_sae_cfg(architecture=architecture, hook_layers=[0,1,2]) + cfg = build_multilayer_sae_cfg(architecture=architecture, hook_layers=[0, 1, 2]) sae = CrosscoderSAE.from_dict(cfg.get_base_sae_cfg_dict()) sae.turn_off_forward_pass_hook_z_reshaping() # hook z reshaping not needed here. @@ -132,12 +132,12 @@ def test_crosscoder_sae_fold_w_dec_norm_all_architectures(architecture: str): for param in sae.parameters(): param.data = torch.rand_like(param) - assert sae.W_dec.norm(dim=[-2,-1]).mean().item() != pytest.approx(1.0, abs=1e-6) + assert sae.W_dec.norm(dim=[-2, -1]).mean().item() != pytest.approx(1.0, abs=1e-6) sae2 = deepcopy(sae) sae2.fold_W_dec_norm() # fold_W_dec_norm should normalize W_dec to have unit norm. 
- assert sae2.W_dec.norm(dim=[-2,-1]).mean().item() == pytest.approx(1.0, abs=1e-6) + assert sae2.W_dec.norm(dim=[-2, -1]).mean().item() == pytest.approx(1.0, abs=1e-6) # we expect activations of features to differ by W_dec norm weights. activations = torch.randn(10, 4, len(cfg.hook_names), cfg.d_in, device=cfg.device) @@ -149,7 +149,9 @@ def test_crosscoder_sae_fold_w_dec_norm_all_architectures(architecture: str): feature_activations_2.nonzero(), ) - expected_feature_activations_2 = feature_activations_1 * sae.W_dec.norm(dim=[-2,-1]) + expected_feature_activations_2 = feature_activations_1 * sae.W_dec.norm( + dim=[-2, -1] + ) torch.testing.assert_close(feature_activations_2, expected_feature_activations_2) sae_out_1 = sae.decode(feature_activations_1) @@ -158,6 +160,7 @@ def test_crosscoder_sae_fold_w_dec_norm_all_architectures(architecture: str): # but actual outputs should be the same torch.testing.assert_close(sae_out_1, sae_out_2) + @torch.no_grad() def test_crosscoder_sae_fold_norm_scaling_factor(cfg: LanguageModelSAERunnerConfig): norm_scaling_factor = torch.Tensor([2.0, 3.0, 4.0]) @@ -173,7 +176,9 @@ def test_crosscoder_sae_fold_norm_scaling_factor(cfg: LanguageModelSAERunnerConf assert sae2.cfg.normalize_activations == "none" - assert torch.allclose(sae2.W_enc.data, sae.W_enc.data * norm_scaling_factor.reshape((-1,1,1))) + assert torch.allclose( + sae2.W_enc.data, sae.W_enc.data * norm_scaling_factor.reshape((-1, 1, 1)) + ) # we expect activations of features to differ by W_dec norm weights. 
# assume activations are already scaled @@ -205,7 +210,7 @@ def test_crosscoder_sae_fold_norm_scaling_factor(cfg: LanguageModelSAERunnerConf def test_crosscoder_sae_fold_norm_scaling_factor_all_architectures(architecture: str): if architecture != "standard": pytest.xfail("TODO(mkbehr): support other architectures") - cfg = build_multilayer_sae_cfg(architecture=architecture, hook_layers=[0,1,2]) + cfg = build_multilayer_sae_cfg(architecture=architecture, hook_layers=[0, 1, 2]) norm_scaling_factor = torch.Tensor([2.0, 3.0, 4.0]) sae = CrosscoderSAE.from_dict(cfg.get_base_sae_cfg_dict()) @@ -218,7 +223,9 @@ def test_crosscoder_sae_fold_norm_scaling_factor_all_architectures(architecture: assert sae2.cfg.normalize_activations == "none" - assert torch.allclose(sae2.W_enc.data, sae.W_enc.data * norm_scaling_factor.reshape((-1,1,1))) + assert torch.allclose( + sae2.W_enc.data, sae.W_enc.data * norm_scaling_factor.reshape((-1, 1, 1)) + ) # we expect activations of features to differ by W_dec norm weights. 
# assume activations are already scaled @@ -244,8 +251,9 @@ def test_crosscoder_sae_fold_norm_scaling_factor_all_architectures(architecture: # but actual outputs should be the same torch.testing.assert_close(sae_out_1, sae_out_2) + def test_crosscoder_sae_save_and_load_from_pretrained(tmp_path: Path) -> None: - cfg = build_multilayer_sae_cfg(hook_layers=[0,1,2]) + cfg = build_multilayer_sae_cfg(hook_layers=[0, 1, 2]) model_path = str(tmp_path) sae = CrosscoderSAE.from_dict(cfg.get_base_sae_cfg_dict()) sae_state_dict = sae.state_dict() @@ -269,9 +277,10 @@ def test_crosscoder_sae_save_and_load_from_pretrained(tmp_path: Path) -> None: sae_out_2 = sae_loaded(sae_in) assert torch.allclose(sae_out_1, sae_out_2) + @pytest.mark.xfail(reason="TODO(mkbehr): support other architectures") def test_crosscoder_sae_save_and_load_from_pretrained_gated(tmp_path: Path) -> None: - cfg = build_multilayer_sae_cfg(architecture="gated", hook_layers=[0,1,2]) + cfg = build_multilayer_sae_cfg(architecture="gated", hook_layers=[0, 1, 2]) model_path = str(tmp_path) sae = CrosscoderSAE.from_dict(cfg.get_base_sae_cfg_dict()) sae_state_dict = sae.state_dict() @@ -295,8 +304,11 @@ def test_crosscoder_sae_save_and_load_from_pretrained_gated(tmp_path: Path) -> N sae_out_2 = sae_loaded(sae_in) assert torch.allclose(sae_out_1, sae_out_2) + def test_crosscoder_sae_save_and_load_from_pretrained_topk(tmp_path: Path) -> None: - cfg = build_multilayer_sae_cfg(activation_fn_kwargs={"k": 30}, hook_layers=[0,1,2]) + cfg = build_multilayer_sae_cfg( + activation_fn_kwargs={"k": 30}, hook_layers=[0, 1, 2] + ) model_path = str(tmp_path) sae = CrosscoderSAE.from_dict(cfg.get_base_sae_cfg_dict()) sae_state_dict = sae.state_dict() @@ -320,7 +332,15 @@ def test_crosscoder_sae_save_and_load_from_pretrained_topk(tmp_path: Path) -> No sae_out_2 = sae_loaded(sae_in) assert torch.allclose(sae_out_1, sae_out_2) + def test_crosscoder_sae_get_name_returns_correct_name_from_cfg_vals() -> None: - cfg = 
build_multilayer_sae_cfg(model_name="test_model", hook_name_template="blocks.{layer}.test_hook_name", d_sae=128, hook_layers=[0,1,2]) + cfg = build_multilayer_sae_cfg( + model_name="test_model", + hook_name_template="blocks.{layer}.test_hook_name", + d_sae=128, + hook_layers=[0, 1, 2], + ) sae = CrosscoderSAE.from_dict(cfg.get_base_sae_cfg_dict()) - assert sae.get_name() == "sae_test_model_blocks.layers_0_through_2.test_hook_name_128" + assert ( + sae.get_name() == "sae_test_model_blocks.layers_0_through_2.test_hook_name_128" + ) diff --git a/tests/training/test_crosscoder_sae_trainer.py b/tests/training/test_crosscoder_sae_trainer.py index ad7ff7a50..8991468ff 100644 --- a/tests/training/test_crosscoder_sae_trainer.py +++ b/tests/training/test_crosscoder_sae_trainer.py @@ -4,18 +4,13 @@ import pytest import torch from datasets import Dataset -from safetensors.torch import load_file from transformer_lens import HookedTransformer -from sae_lens import __version__ from sae_lens.config import LanguageModelSAERunnerConfig -from sae_lens.sae_training_runner import SAETrainingRunner from sae_lens.training.activations_store import ActivationsStore from sae_lens.training.crosscoder_sae_trainer import CrosscoderSAETrainer from sae_lens.training.sae_trainer import ( TrainStepOutput, - _log_feature_sparsity, - _update_sae_lens_training_version, ) from sae_lens.training.training_crosscoder_sae import TrainingCrosscoderSAE from tests.helpers import TINYSTORIES_MODEL, build_multilayer_sae_cfg, load_model_cached @@ -27,7 +22,7 @@ def cfg(): d_in=64, d_sae=128, hook_name_template="blocks.{layer}.hook_mlp_out", - hook_layers=[0,1,2], + hook_layers=[0, 1, 2], normalize_sae_decoder=False, scale_sparsity_penalty_by_decoder_norm=True, ) @@ -47,8 +42,9 @@ def activation_store(model: HookedTransformer, cfg: LanguageModelSAERunnerConfig @pytest.fixture def training_sae(cfg: LanguageModelSAERunnerConfig): - return TrainingCrosscoderSAE.from_dict(cfg.get_training_sae_cfg_dict(), - 
use_error_term=True) + return TrainingCrosscoderSAE.from_dict( + cfg.get_training_sae_cfg_dict(), use_error_term=True + ) @pytest.fixture @@ -67,7 +63,9 @@ def trainer( ) -def modify_sae_output(sae: TrainingCrosscoderSAE, modifier: Callable[[torch.Tensor], Any]): +def modify_sae_output( + sae: TrainingCrosscoderSAE, modifier: Callable[[torch.Tensor], Any] +): """ Helper to modify the output of the SAE forward pass for use in patching, for use in patch side_effect. We need real grads during training, so we can't just mock the whole forward pass directly. @@ -152,13 +150,14 @@ def test_train_step__sparsity_updates_based_on_feature_act_sparsity( ) assert train_output.feature_acts is feature_acts + def test_build_train_step_log_dict(trainer: CrosscoderSAETrainer) -> None: - sae_in = torch.tensor([[[-1, 0], [-2, 0]], - [[0, 2], [0, 3]], - [[1, 1], [1, 1]]]).float() - sae_out = torch.tensor([[[0, 0], [0, 0]], - [[0, 2], [0, 3]], - [[0.5, 1], [1, 0.5]]]).float() + sae_in = torch.tensor( + [[[-1, 0], [-2, 0]], [[0, 2], [0, 3]], [[1, 1], [1, 1]]] + ).float() + sae_out = torch.tensor( + [[[0, 0], [0, 0]], [[0, 2], [0, 3]], [[0.5, 1], [1, 0.5]]] + ).float() train_output = TrainStepOutput( sae_in=sae_in, sae_out=sae_out, @@ -211,7 +210,7 @@ def test_train_sae_group_on_language_model__runs( training_tokens=20, context_size=8, hook_name_template="blocks.{layer}.hook_mlp_out", - hook_layers=[0,1,2], + hook_layers=[0, 1, 2], normalize_sae_decoder=False, scale_sparsity_penalty_by_decoder_norm=True, ) @@ -220,8 +219,9 @@ def test_train_sae_group_on_language_model__runs( activation_store = ActivationsStore.from_config( ts_model, cfg, override_dataset=dataset ) - sae = TrainingCrosscoderSAE.from_dict(cfg.get_training_sae_cfg_dict(), - use_error_term=True) + sae = TrainingCrosscoderSAE.from_dict( + cfg.get_training_sae_cfg_dict(), use_error_term=True + ) sae = CrosscoderSAETrainer( model=ts_model, sae=sae, diff --git a/tests/training/test_crosscoder_sae_training.py 
b/tests/training/test_crosscoder_sae_training.py index fe5e982ea..23be54bbd 100644 --- a/tests/training/test_crosscoder_sae_training.py +++ b/tests/training/test_crosscoder_sae_training.py @@ -1,6 +1,3 @@ -from typing import Any - -import einops import pytest import torch from datasets import Dataset @@ -11,7 +8,7 @@ from sae_lens.training.sae_trainer import SAETrainer from sae_lens.training.training_crosscoder_sae import ( TrainingCrosscoderSAE, - TrainingCrosscoderSAEConfig + TrainingCrosscoderSAEConfig, ) from tests.helpers import build_multilayer_sae_cfg @@ -23,7 +20,7 @@ "model_name": "tiny-stories-1M", "dataset_path": "roneneldan/TinyStories", "hook_name": "blocks.{layer}.hook_resid_pre", - "hook_layers": [0,1,2], + "hook_layers": [0, 1, 2], "d_in": 64, "normalize_sae_decoder": False, "scale_sparsity_penalty_by_decoder_norm": True, @@ -32,7 +29,7 @@ "model_name": "tiny-stories-1M", "dataset_path": "apollo-research/roneneldan-TinyStories-tokenizer-gpt2", "hook_name": "blocks.{layer}.hook_resid_pre", - "hook_layers": [0,1,2], + "hook_layers": [0, 1, 2], "d_in": 64, "normalize_sae_decoder": False, "scale_sparsity_penalty_by_decoder_norm": True, @@ -41,7 +38,7 @@ "model_name": "tiny-stories-1M", "dataset_path": "apollo-research/roneneldan-TinyStories-tokenizer-gpt2", "hook_name": "blocks.{layer}.hook_resid_pre", - "hook_layers": [0,1,2], + "hook_layers": [0, 1, 2], "d_in": 64, "normalize_activations": "constant_norm_rescale", "normalize_sae_decoder": False, @@ -68,8 +65,8 @@ def training_crosscoder_sae(cfg: LanguageModelSAERunnerConfig): Pytest fixture to create a mock instance of SparseAutoencoder. 
""" return TrainingCrosscoderSAE( - TrainingCrosscoderSAEConfig.from_sae_runner_config(cfg), - use_error_term=True) + TrainingCrosscoderSAEConfig.from_sae_runner_config(cfg), use_error_term=True + ) @pytest.fixture @@ -100,6 +97,7 @@ def trainer( cfg=cfg, ) + def test_sae_forward(training_crosscoder_sae: TrainingCrosscoderSAE): batch_size = 32 d_in = training_crosscoder_sae.cfg.d_in @@ -139,14 +137,17 @@ def test_sae_forward(training_crosscoder_sae: TrainingCrosscoderSAE): ) expected_l1_loss = ( - (train_step_output.feature_acts - * training_crosscoder_sae.W_dec.norm(dim=2).sum(dim=1)) + ( + train_step_output.feature_acts + * training_crosscoder_sae.W_dec.norm(dim=2).sum(dim=1) + ) .norm(dim=1, p=1) .mean() ) assert ( pytest.approx(train_step_output.losses["l1_loss"].item(), rel=1e-3) # type: ignore - == training_crosscoder_sae.cfg.l1_coefficient * expected_l1_loss.detach().float() + == training_crosscoder_sae.cfg.l1_coefficient + * expected_l1_loss.detach().float() ) @@ -200,14 +201,17 @@ def test_sae_forward_with_mse_loss_norm( ) expected_l1_loss = ( - (train_step_output.feature_acts * - training_crosscoder_sae.W_dec.norm(dim=2).sum(dim=1)) + ( + train_step_output.feature_acts + * training_crosscoder_sae.W_dec.norm(dim=2).sum(dim=1) + ) .norm(dim=1, p=1) .mean() ) assert ( pytest.approx(train_step_output.losses["l1_loss"].item(), rel=1e-3) # type: ignore - == training_crosscoder_sae.cfg.l1_coefficient * expected_l1_loss.detach().float() + == training_crosscoder_sae.cfg.l1_coefficient + * expected_l1_loss.detach().float() ) @@ -216,24 +220,26 @@ def test_SparseAutoencoder_forward_can_add_noise_to_hidden_pre() -> None: d_in=2, d_sae=4, noise_scale=0, - hook_layers=[1,2,3,4,5], + hook_layers=[1, 2, 3, 4, 5], normalize_sae_decoder=False, - scale_sparsity_penalty_by_decoder_norm=True + scale_sparsity_penalty_by_decoder_norm=True, ) noisy_cfg = build_multilayer_sae_cfg( d_in=2, d_sae=4, noise_scale=100, - hook_layers=[1,2,3,4,5], + hook_layers=[1, 2, 3, 4, 5], 
normalize_sae_decoder=False, - scale_sparsity_penalty_by_decoder_norm=True + scale_sparsity_penalty_by_decoder_norm=True, ) clean_sae = TrainingCrosscoderSAE( TrainingCrosscoderSAEConfig.from_sae_runner_config(clean_cfg), - use_error_term=True) + use_error_term=True, + ) noisy_sae = TrainingCrosscoderSAE( TrainingCrosscoderSAEConfig.from_sae_runner_config(noisy_cfg), - use_error_term=True) + use_error_term=True, + ) input = torch.randn(3, 5, 2) @@ -247,4 +253,3 @@ def test_SparseAutoencoder_forward_can_add_noise_to_hidden_pre() -> None: # noisy outputs should be different assert not torch.allclose(noisy_output1, noisy_output2) assert not torch.allclose(clean_output1, noisy_output1) - diff --git a/tests/training/test_training_crosscoder_sae.py b/tests/training/test_training_crosscoder_sae.py index 2a93d4ba1..a8499a5cd 100644 --- a/tests/training/test_training_crosscoder_sae.py +++ b/tests/training/test_training_crosscoder_sae.py @@ -1,18 +1,18 @@ import pytest import torch -from sae_lens.crosscoder_sae import CrosscoderSAE from sae_lens.training.training_crosscoder_sae import ( TrainingCrosscoderSAE, TrainingCrosscoderSAEConfig, ) from tests.helpers import build_multilayer_sae_cfg + def test_TrainingCrosscoderSAE_training_forward_pass_can_scale_sparsity_penalty_by_decoder_norm(): cfg = build_multilayer_sae_cfg( d_in=3, d_sae=5, - hook_layers=[0,1,2,3], + hook_layers=[0, 1, 2, 3], normalize_sae_decoder=False, scale_sparsity_penalty_by_decoder_norm=True, ) @@ -37,6 +37,7 @@ def test_TrainingCrosscoderSAE_training_forward_pass_can_scale_sparsity_penalty_ == 2.0 * scaled_feature_acts.norm(p=1, dim=1).mean().detach().item() ) + @pytest.mark.parametrize("architecture", ["standard", "gated", "jumprelu", "topk"]) def test_TrainingCrosscoderSAE_encode_returns_same_value_as_encode_with_hidden_pre( architecture: str, @@ -45,7 +46,7 @@ def test_TrainingCrosscoderSAE_encode_returns_same_value_as_encode_with_hidden_p pytest.xfail("TODO(mkbehr): support other architectures") cfg = 
build_multilayer_sae_cfg( architecture=architecture, - hook_layers=[0,1,2,3], + hook_layers=[0, 1, 2, 3], normalize_sae_decoder=False, scale_sparsity_penalty_by_decoder_norm=True, ) @@ -58,18 +59,18 @@ def test_TrainingCrosscoderSAE_encode_returns_same_value_as_encode_with_hidden_p encode_with_hidden_pre_out = sae.encode_with_hidden_pre_fn(x)[0] assert torch.allclose(encode_out, encode_with_hidden_pre_out) + def test_TrainingCrosscoderSAE_heuristic_init(): cfg = build_multilayer_sae_cfg( d_in=3, d_sae=5, - hook_layers=[0,1,2,3], + hook_layers=[0, 1, 2, 3], normalize_sae_decoder=False, scale_sparsity_penalty_by_decoder_norm=True, decoder_heuristic_init=True, decoder_heuristic_init_norm=0.2, ) sae = TrainingCrosscoderSAE( - TrainingCrosscoderSAEConfig.from_sae_runner_config(cfg), - use_error_term=True) - torch.testing.assert_close(sae.W_dec.norm(dim=[1,2]), - torch.full((5,), 0.2)) + TrainingCrosscoderSAEConfig.from_sae_runner_config(cfg), use_error_term=True + ) + torch.testing.assert_close(sae.W_dec.norm(dim=[1, 2]), torch.full((5,), 0.2)) From bcfe7a56fa373818138207efc3cd56ab99af1ade Mon Sep 17 00:00:00 2001 From: Michael Behr Date: Tue, 6 May 2025 21:48:44 -0400 Subject: [PATCH 59/61] fix some type errors --- sae_lens/crosscoder_sae.py | 9 +++++---- sae_lens/sae.py | 2 +- sae_lens/training/crosscoder_sae_trainer.py | 2 +- tests/helpers.py | 4 ++-- tests/test_evals.py | 4 +++- tests/training/test_activations_store_multilayer.py | 1 + 6 files changed, 13 insertions(+), 9 deletions(-) diff --git a/sae_lens/crosscoder_sae.py b/sae_lens/crosscoder_sae.py index 0540305af..bd404241e 100644 --- a/sae_lens/crosscoder_sae.py +++ b/sae_lens/crosscoder_sae.py @@ -15,7 +15,7 @@ @dataclass class CrosscoderSAEConfig(SAEConfig): - hook_names: list[int] = field(default_factory=list) + hook_names: list[str] = field(default_factory=list) def to_dict(self) -> dict[str, Any]: return super().to_dict() | { @@ -37,16 +37,17 @@ def __init__( raise NotImplementedError("TODO(mkbehr): 
support other architectures") super().__init__(cfg=cfg, use_error_term=use_error_term) + self.cfg = cfg if self.hook_z_reshaping_mode: raise NotImplementedError("TODO(mkbehr): support hook_z") @classmethod def from_dict(cls, config_dict: dict[str, Any]) -> "CrosscoderSAE": - return cls(CrosscoderSAEConfig.from_dict(config_dict)) + return cls(CrosscoderSAEConfig.from_dict(config_dict)) # type: ignore def input_shape(self): - return (len(self.cfg.hook_names), self.cfg.d_in) + return [len(self.cfg.hook_names), self.cfg.d_in] def encode_standard( self, x: Float[torch.Tensor, "... n_layers d_in"] @@ -130,7 +131,7 @@ def load_from_disk( cfg_dict, state_dict = converter(path, device, cfg_overrides=overrides) cfg_dict = handle_config_defaulting(cfg_dict) sae_cfg = CrosscoderSAEConfig.from_dict(cfg_dict) - sae = cls(sae_cfg) + sae = cls(sae_cfg) # type: ignore sae.process_state_dict_for_loading(state_dict) sae.load_state_dict(state_dict) return sae diff --git a/sae_lens/sae.py b/sae_lens/sae.py index 3ca586ebb..c3e18bd64 100644 --- a/sae_lens/sae.py +++ b/sae_lens/sae.py @@ -245,7 +245,7 @@ def run_time_activation_ln_out( self.setup() # Required for `HookedRootModule`s def input_shape(self): - return (self.cfg.d_in,) + return [self.cfg.d_in] def initialize_weights_basic(self): # no config changes encoder bias init for now. diff --git a/sae_lens/training/crosscoder_sae_trainer.py b/sae_lens/training/crosscoder_sae_trainer.py index b158de362..d66293622 100644 --- a/sae_lens/training/crosscoder_sae_trainer.py +++ b/sae_lens/training/crosscoder_sae_trainer.py @@ -11,7 +11,7 @@ class CrosscoderSAETrainer(SAETrainer): - def __init__(self, *args, **kwargs): + def __init__(self, *args: Any, **kwargs: Any): super().__init__(*args, **kwargs) # Reconstruction metrics don't make sense for acausal crosscoders. 
self.trainer_eval_config.compute_ce_loss = False diff --git a/tests/helpers.py b/tests/helpers.py index e0cc6889c..91082bf86 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -100,13 +100,13 @@ def build_sae_cfg(**kwargs: Any) -> LanguageModelSAERunnerConfig: def build_multilayer_sae_cfg( hook_name_template: str = "blocks.{layer}.hook_mlp_out", - hook_layers=[0, 1, 2], + hook_layers: list[int] = [0, 1, 2], **kwargs: Any, ) -> LanguageModelSAERunnerConfig: hook_name = hook_name_template.format( layer=f"layers_{min(hook_layers)}_through_{max(hook_layers)}" ) - hook_names = [hook_name_template.format(layer=layer) for layer in hook_layers] + hook_names = [hook_name_template.format(layer=str(layer)) for layer in hook_layers] return build_sae_cfg( **( { diff --git a/tests/test_evals.py b/tests/test_evals.py index 7c5d9636d..9a728e1b5 100644 --- a/tests/test_evals.py +++ b/tests/test_evals.py @@ -294,7 +294,9 @@ def test_run_empty_evals( # TODO(mkbehr): consider parameterizing -def test_run_evals_crosscoder_training_sae(model): +def test_run_evals_crosscoder_training_sae( + model: HookedTransformer, +): cfg = build_multilayer_sae_cfg( model_name="tiny-stories-1M", dataset_path="roneneldan/TinyStories", diff --git a/tests/training/test_activations_store_multilayer.py b/tests/training/test_activations_store_multilayer.py index d63e6bf4f..522e1521e 100644 --- a/tests/training/test_activations_store_multilayer.py +++ b/tests/training/test_activations_store_multilayer.py @@ -92,6 +92,7 @@ def test_activations_store_get_buffer_multiple_layers(ts_model: HookedTransforme # Check shape: [(batch_size * context_size * n_batches), num_layers, d_in] expected_size = cfg.store_batch_size_prompts * cfg.context_size * 2 assert buffer_activations.shape == (expected_size, len(cfg.hook_names), cfg.d_in) + assert buffer_tokens is not None assert buffer_tokens.shape == (expected_size,) From 0f7b3496980972a7265da44327ebfc5b19102556 Mon Sep 17 00:00:00 2001 From: Michael Behr Date: Tue, 6 
May 2025 22:18:46 -0400 Subject: [PATCH 60/61] revert changing wandb import line --- sae_lens/config.py | 2 +- sae_lens/sae_training_runner.py | 2 +- sae_lens/training/crosscoder_sae_trainer.py | 2 +- sae_lens/training/sae_trainer.py | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/sae_lens/config.py b/sae_lens/config.py index ce49467bc..1a4c8dbd2 100644 --- a/sae_lens/config.py +++ b/sae_lens/config.py @@ -6,6 +6,7 @@ import simple_parsing import torch +import wandb from datasets import ( Dataset, DatasetDict, @@ -14,7 +15,6 @@ load_dataset, ) -import wandb from sae_lens import __version__, logger DTYPE_MAP = { diff --git a/sae_lens/sae_training_runner.py b/sae_lens/sae_training_runner.py index aa5dc8fed..00eb261c3 100644 --- a/sae_lens/sae_training_runner.py +++ b/sae_lens/sae_training_runner.py @@ -6,10 +6,10 @@ from typing import Any, cast import torch +import wandb from simple_parsing import ArgumentParser from transformer_lens.hook_points import HookedRootModule -import wandb from sae_lens import logger from sae_lens.config import HfDataset, LanguageModelSAERunnerConfig from sae_lens.load_model import load_model diff --git a/sae_lens/training/crosscoder_sae_trainer.py b/sae_lens/training/crosscoder_sae_trainer.py index d66293622..1d6d1a233 100644 --- a/sae_lens/training/crosscoder_sae_trainer.py +++ b/sae_lens/training/crosscoder_sae_trainer.py @@ -1,9 +1,9 @@ from typing import Any import torch +import wandb from tqdm import tqdm -import wandb from sae_lens.evals import run_evals from sae_lens.training.sae_trainer import SAETrainer from sae_lens.training.training_crosscoder_sae import TrainingCrosscoderSAE diff --git a/sae_lens/training/sae_trainer.py b/sae_lens/training/sae_trainer.py index 659da1a23..eef087e1a 100644 --- a/sae_lens/training/sae_trainer.py +++ b/sae_lens/training/sae_trainer.py @@ -3,11 +3,11 @@ from typing import Any, Protocol, cast import torch +import wandb from torch.optim import Adam from tqdm import tqdm from 
transformer_lens.hook_points import HookedRootModule -import wandb from sae_lens import __version__ from sae_lens.config import LanguageModelSAERunnerConfig from sae_lens.evals import EvalConfig, run_evals From 750ee9224bd9d8f91c4cacd5a30c9a58f8615bd6 Mon Sep 17 00:00:00 2001 From: Michael Behr Date: Sun, 11 May 2025 15:49:37 -0400 Subject: [PATCH 61/61] train crosscoders without override_sae --- sae_lens/sae_training_runner.py | 22 ++++++++++++++++++---- scripts/acausal_crosscoder.py | 12 +----------- 2 files changed, 19 insertions(+), 15 deletions(-) diff --git a/sae_lens/sae_training_runner.py b/sae_lens/sae_training_runner.py index 00eb261c3..c052cb9d9 100644 --- a/sae_lens/sae_training_runner.py +++ b/sae_lens/sae_training_runner.py @@ -17,7 +17,10 @@ from sae_lens.training.crosscoder_sae_trainer import CrosscoderSAETrainer from sae_lens.training.geometric_median import compute_geometric_median from sae_lens.training.sae_trainer import SAETrainer -from sae_lens.training.training_crosscoder_sae import TrainingCrosscoderSAE +from sae_lens.training.training_crosscoder_sae import ( + TrainingCrosscoderSAE, + TrainingCrosscoderSAEConfig, +) from sae_lens.training.training_sae import TrainingSAE, TrainingSAEConfig @@ -78,6 +81,16 @@ def __init__( self.sae = TrainingSAE.load_from_pretrained( self.cfg.from_pretrained_path, self.cfg.device ) + elif self.cfg.hook_names: + self.sae = TrainingCrosscoderSAE( + TrainingCrosscoderSAEConfig.from_dict( + self.cfg.get_training_sae_cfg_dict(), + ), + # TODO(mkbehr): When causal crosscoders are + # implemented, set use_error_term false for those. 
+ use_error_term=True, + ) + self._init_sae_group_b_decs() else: self.sae = TrainingSAE( TrainingSAEConfig.from_dict( @@ -102,8 +115,7 @@ def run(self): id=self.cfg.wandb_id, ) - # TODO(mkbehr): make a better way to get the right trainer in - if isinstance(self.sae, TrainingCrosscoderSAE): + if self.cfg.hook_names: trainer = CrosscoderSAETrainer( model=self.model, sae=self.sae, @@ -172,7 +184,6 @@ def run_trainer_with_interruption_handling(self, trainer: SAETrainer): return sae # TODO: move this into the SAE trainer or Training SAE class - # TODO(mkbehr): support crosscoders. def _init_sae_group_b_decs( self, ) -> None: @@ -180,6 +191,9 @@ def _init_sae_group_b_decs( extract all activations at a certain layer and use for sae b_dec initialization """ + if self.cfg.hook_names and self.cfg.b_dec_init_method != "zeros": + raise NotImplementedError("TODO(mkbehr): For crosscoders, only b_dec_init_method='zeros' is implemented.") + if self.cfg.b_dec_init_method == "geometric_median": self.activations_store.set_norm_scaling_factor_if_needed() layer_acts = self.activations_store.storage_buffer.detach()[:, 0, :] diff --git a/scripts/acausal_crosscoder.py b/scripts/acausal_crosscoder.py index 10a1d7c00..2e21e7091 100644 --- a/scripts/acausal_crosscoder.py +++ b/scripts/acausal_crosscoder.py @@ -7,10 +7,6 @@ from sae_lens.config import LanguageModelSAERunnerConfig from sae_lens.sae_training_runner import SAETrainingRunner -from sae_lens.training.training_crosscoder_sae import ( - TrainingCrosscoderSAE, - TrainingCrosscoderSAEConfig, -) if torch.cuda.is_available(): device = "cuda" @@ -118,12 +114,6 @@ dtype="float32", ) -sae = SAETrainingRunner( - cfg, - override_sae=TrainingCrosscoderSAE( - TrainingCrosscoderSAEConfig.from_sae_runner_config(cfg), - use_error_term=True, - ), -).run() +sae = SAETrainingRunner(cfg).run() print("=" * 50)