diff --git a/sae_lens/config.py b/sae_lens/config.py index 6faf8ac9b..1a4c8dbd2 100644 --- a/sae_lens/config.py +++ b/sae_lens/config.py @@ -63,6 +63,7 @@ class LanguageModelSAERunnerConfig: model_name (str): The name of the model to use. This should be the name of the model in the Hugging Face model hub. model_class_name (str): The name of the class of the model to use. This should be either `HookedTransformer` or `HookedMamba`. hook_name (str): The name of the hook to use. This should be a valid TransformerLens hook. + hook_names (list[str], optional): The names of multiple hooks to use, in order of evaluation. If this is nonempty, a CrosscoderSAE will be used. hook_name should be a descriptive name, and hook_layer should be the index of the last layer to hook. hook_eval (str): NOT CURRENTLY IN USE. The name of the hook to use for evaluation. hook_layer (int): The index of the layer to hook. Used to stop forward passes early and speed up processing. hook_head_index (int, optional): When the hook if for an activatio with a head index, we can specify a specific head to use here. @@ -147,6 +148,7 @@ class LanguageModelSAERunnerConfig: model_name: str = "gelu-2l" model_class_name: str = "HookedTransformer" hook_name: str = "blocks.0.hook_mlp_out" + hook_names: list[str] = field(default_factory=list) hook_eval: str = "NOT_IN_USE" hook_layer: int = 0 hook_head_index: int | None = None @@ -444,6 +446,7 @@ def get_base_sae_cfg_dict(self) -> dict[str, Any]: "device": self.device, "model_name": self.model_name, "hook_name": self.hook_name, + "hook_names": self.hook_names, "hook_layer": self.hook_layer, "hook_head_index": self.hook_head_index, "activation_fn_str": self.activation_fn, @@ -521,6 +524,7 @@ class CacheActivationsRunnerConfig: model_name (str): The name of the model to use. model_batch_size (int): How many prompts are in the batch of the language model when generating activations. hook_name (str): The name of the hook to use. + hook_names (list[str], optional): The names of multiple hooks to use, in order of evaluation. If this is nonempty, a CrosscoderSAE will be used. hook_layer (int): The layer of the final hook. Currently only support a single hook, so this should be the same as hook_name. d_in (int): Dimension of the model. total_training_tokens (int): Total number of tokens to process. @@ -555,6 +559,7 @@ class CacheActivationsRunnerConfig: d_in: int training_tokens: int + hook_names: list[str] = field(default_factory=list) context_size: int = -1 # Required if dataset is not tokenized model_class_name: str = "HookedTransformer" # defaults to "activations/{dataset}/{model}/{hook_name} @@ -608,8 +613,9 @@ def __post_init__(self): ) if self.new_cached_activations_path is None: + hook_name_str = self.hook_name self.new_cached_activations_path = _default_cached_activations_path( # type: ignore - self.dataset_path, self.model_name, self.hook_name, None + self.dataset_path, self.model_name, hook_name_str, None ) @property diff --git a/sae_lens/crosscoder_sae.py b/sae_lens/crosscoder_sae.py new file mode 100644 index 000000000..bd404241e --- /dev/null +++ b/sae_lens/crosscoder_sae.py @@ -0,0 +1,137 @@ +from dataclasses import dataclass, field +from typing import Any + +import einops +import torch +from jaxtyping import Float + +from sae_lens import SAE, SAEConfig +from sae_lens.toolkit.pretrained_sae_loaders import ( + PretrainedSaeDiskLoader, + handle_config_defaulting, + sae_lens_disk_loader, +) + + +@dataclass +class CrosscoderSAEConfig(SAEConfig): + hook_names: list[str] = field(default_factory=list) + + def to_dict(self) -> dict[str, Any]: + return super().to_dict() | { + "hook_names": self.hook_names, + } + + +class CrosscoderSAE(SAE): + """ + Sparse autoencoder that acts on multiple layers of activations. + """ + + def __init__( + self, + cfg: CrosscoderSAEConfig, + use_error_term: bool = False, + ): + if cfg.architecture != "standard": + raise NotImplementedError("TODO(mkbehr): support other architectures") + + super().__init__(cfg=cfg, use_error_term=use_error_term) + self.cfg = cfg + + if self.hook_z_reshaping_mode: + raise NotImplementedError("TODO(mkbehr): support hook_z") + + @classmethod + def from_dict(cls, config_dict: dict[str, Any]) -> "CrosscoderSAE": + return cls(CrosscoderSAEConfig.from_dict(config_dict)) # type: ignore + + def input_shape(self): + return [len(self.cfg.hook_names), self.cfg.d_in] + + def encode_standard( + self, x: Float[torch.Tensor, "... n_layers d_in"] + ) -> Float[torch.Tensor, "... d_sae"]: + """ + Calculate SAE features from inputs + """ + sae_in = self.process_sae_in(x) + + hidden_pre = self.hook_sae_acts_pre( + einops.einsum( + sae_in, + self.W_enc, + "... n_layers d_in, n_layers d_in d_sae -> ... d_sae", + ) + + self.b_enc + ) + return self.hook_sae_acts_post(self.activation_fn(hidden_pre)) + + def decode( + self, feature_acts: Float[torch.Tensor, "... d_sae"] + ) -> Float[torch.Tensor, "... n_layers d_in"]: + """Decodes SAE feature activation tensor into a reconstructed + input activation tensor.""" + sae_out = self.hook_sae_recons( + einops.einsum( + self.apply_finetuning_scaling_factor(feature_acts), + self.W_dec, + "... d_sae, d_sae n_layers d_in -> ... n_layers d_in", + ) + + self.b_dec + ) + + # handle run time activation normalization if needed + # will fail if you call this twice without calling encode in between. + sae_out = self.run_time_activation_norm_fn_out(sae_out) + + # handle hook z reshaping if needed. + return self.reshape_fn_out(sae_out, self.d_head) # type: ignore + + @torch.no_grad() + def fold_W_dec_norm(self): + W_dec_norms = self.W_dec.norm(dim=[-2, -1], keepdim=True) + self.W_dec.data = self.W_dec.data / W_dec_norms + self.W_enc.data = self.W_enc.data * einops.rearrange( + W_dec_norms, "d_sae 1 1 -> 1 1 d_sae" + ) + if self.cfg.architecture == "gated": + self.r_mag.data = self.r_mag.data * W_dec_norms.squeeze() + self.b_gate.data = self.b_gate.data * W_dec_norms.squeeze() + self.b_mag.data = self.b_mag.data * W_dec_norms.squeeze() + elif self.cfg.architecture == "jumprelu": + self.threshold.data = self.threshold.data * W_dec_norms.squeeze() + self.b_enc.data = self.b_enc.data * W_dec_norms.squeeze() + else: + self.b_enc.data = self.b_enc.data * W_dec_norms.squeeze() + + @torch.no_grad() + def fold_activation_norm_scaling_factor( + self, activation_norm_scaling_factor: Float[torch.Tensor, "n_layers"] + ): + self.W_enc.data = self.W_enc.data * activation_norm_scaling_factor.reshape( + (-1, 1, 1) + ) + # previously weren't doing this. + self.W_dec.data = self.W_dec.data / activation_norm_scaling_factor.unsqueeze(-1) + self.b_dec.data = self.b_dec.data / activation_norm_scaling_factor.unsqueeze(-1) + + # once we normalize, we shouldn't need to scale activations. + self.cfg.normalize_activations = "none" + + @classmethod + def load_from_disk( + cls, + path: str, + device: str = "cpu", + dtype: str | None = None, + converter: PretrainedSaeDiskLoader = sae_lens_disk_loader, + ) -> "CrosscoderSAE": + overrides = {"dtype": dtype} if dtype is not None else None + cfg_dict, state_dict = converter(path, device, cfg_overrides=overrides) + cfg_dict = handle_config_defaulting(cfg_dict) + sae_cfg = CrosscoderSAEConfig.from_dict(cfg_dict) + sae = cls(sae_cfg) # type: ignore + sae.process_state_dict_for_loading(state_dict) + sae.load_state_dict(state_dict) + return sae diff --git a/sae_lens/evals.py b/sae_lens/evals.py index 63d43d203..5b0a25af6 100644 --- a/sae_lens/evals.py +++ b/sae_lens/evals.py @@ -378,7 +378,9 @@ def get_sparsity_and_variance_metrics( ignore_tokens: set[int | None] = set(), verbose: bool = False, ) -> tuple[dict[str, Any], dict[str, Any]]: - hook_name = sae.cfg.hook_name + hook_names = ( + sae.cfg.hook_names if hasattr(sae.cfg, "hook_names") else [sae.cfg.hook_name] + ) hook_head_index = sae.cfg.hook_head_index metric_dict = {} @@ -434,7 +436,7 @@ def get_sparsity_and_variance_metrics( _, cache = model.run_with_cache( batch_tokens, prepend_bos=False, - names_filter=[hook_name], + names_filter=hook_names, stop_at_layer=sae.cfg.hook_layer + 1, **model_kwargs, ) @@ -443,11 +445,20 @@ def get_sparsity_and_variance_metrics( # which will do their own reshaping for hook z. has_head_dim_key_substrings = ["hook_q", "hook_k", "hook_v", "hook_z"] if hook_head_index is not None: - original_act = cache[hook_name][:, :, hook_head_index] - elif any(substring in hook_name for substring in has_head_dim_key_substrings): - original_act = cache[hook_name].flatten(-2, -1) + # TODO(mkbehr) support head dimension for mutilayer evals + assert len(hook_names) == 1 + original_act = cache[hook_names[0]][:, :, hook_head_index] + elif any( + substring in hook_names[0] for substring in has_head_dim_key_substrings + ): + # TODO(mkbehr) support head dimension for mutilayer evals + original_act = cache[hook_names[0]].flatten(-2, -1) + elif hasattr(sae.cfg, "hook_names"): + # TODO(mkbehr): support head dimension for mutilayer evals + layerwise_activations = [cache[hook_name] for hook_name in hook_names] + original_act = torch.stack(layerwise_activations, dim=2) else: - original_act = cache[hook_name] + original_act = cache[hook_names[0]] # normalise if necessary (necessary in training only, otherwise we should fold the scaling in) if activation_store.normalize_activations == "expected_average_only_in": @@ -461,14 +472,17 @@ def get_sparsity_and_variance_metrics( if activation_store.normalize_activations == "expected_average_only_in": sae_out = activation_store.unscale(sae_out) - flattened_sae_input = einops.rearrange(original_act, "b ctx d -> (b ctx) d") + flattened_sae_input = einops.rearrange( + original_act, "b ctx d ... -> (b ctx) (d ...)" + ) flattened_sae_feature_acts = einops.rearrange( sae_feature_activations, "b ctx d -> (b ctx) d" ) - flattened_sae_out = einops.rearrange(sae_out, "b ctx d -> (b ctx) d") + flattened_sae_out = einops.rearrange(sae_out, "b ctx d ... -> (b ctx) (d ...)") # TODO: Clean this up. # apply mask + # TODO(mkbehr): test mask support w/ multilayer masked_sae_feature_activations = sae_feature_activations * mask.unsqueeze(-1) flattened_sae_input = flattened_sae_input[ flattened_mask.to(flattened_sae_input.device) diff --git a/sae_lens/sae.py b/sae_lens/sae.py index edd873b30..c3e18bd64 100644 --- a/sae_lens/sae.py +++ b/sae_lens/sae.py @@ -244,6 +244,9 @@ def run_time_activation_ln_out( self.setup() # Required for `HookedRootModule`s + def input_shape(self): + return [self.cfg.d_in] + def initialize_weights_basic(self): # no config changes encoder bias init for now. self.b_enc = nn.Parameter( @@ -254,7 +257,10 @@ def initialize_weights_basic(self): self.W_dec = nn.Parameter( torch.nn.init.kaiming_uniform_( torch.empty( - self.cfg.d_sae, self.cfg.d_in, dtype=self.dtype, device=self.device + self.cfg.d_sae, + *self.input_shape(), + dtype=self.dtype, + device=self.device, ) ) ) @@ -262,14 +268,17 @@ def initialize_weights_basic(self): self.W_enc = nn.Parameter( torch.nn.init.kaiming_uniform_( torch.empty( - self.cfg.d_in, self.cfg.d_sae, dtype=self.dtype, device=self.device + *self.input_shape(), + self.cfg.d_sae, + dtype=self.dtype, + device=self.device, ) ) ) # methdods which change b_dec as a function of the dataset are implemented after init. self.b_dec = nn.Parameter( - torch.zeros(self.cfg.d_in, dtype=self.dtype, device=self.device) + torch.zeros(*self.input_shape(), dtype=self.dtype, device=self.device) ) # scaling factor for fine-tuning (not to be used in initial training) @@ -284,7 +293,10 @@ def initialize_weights_gated(self): self.W_enc = nn.Parameter( torch.nn.init.kaiming_uniform_( torch.empty( - self.cfg.d_in, self.cfg.d_sae, dtype=self.dtype, device=self.device + *self.input_shape(), + self.cfg.d_sae, + dtype=self.dtype, + device=self.device, ) ) ) @@ -304,13 +316,16 @@ def initialize_weights_gated(self): self.W_dec = nn.Parameter( torch.nn.init.kaiming_uniform_( torch.empty( - self.cfg.d_sae, self.cfg.d_in, dtype=self.dtype, device=self.device + self.cfg.d_sae, + *self.input_shape(), + dtype=self.dtype, + device=self.device, ) ) ) self.b_dec = nn.Parameter( - torch.zeros(self.cfg.d_in, dtype=self.dtype, device=self.device) + torch.zeros(*self.input_shape(), dtype=self.dtype, device=self.device) ) def initialize_weights_jumprelu(self): @@ -640,7 +655,7 @@ def from_pretrained( ) cfg_dict = handle_config_defaulting(cfg_dict) - sae = cls(SAEConfig.from_dict(cfg_dict)) + sae = cls.from_dict(cfg_dict) sae.process_state_dict_for_loading(state_dict) sae.load_state_dict(state_dict) diff --git a/sae_lens/sae_training_runner.py b/sae_lens/sae_training_runner.py index c6a282b36..c052cb9d9 100644 --- a/sae_lens/sae_training_runner.py +++ b/sae_lens/sae_training_runner.py @@ -14,8 +14,13 @@ from sae_lens.config import HfDataset, LanguageModelSAERunnerConfig from sae_lens.load_model import load_model from sae_lens.training.activations_store import ActivationsStore +from sae_lens.training.crosscoder_sae_trainer import CrosscoderSAETrainer from sae_lens.training.geometric_median import compute_geometric_median from sae_lens.training.sae_trainer import SAETrainer +from sae_lens.training.training_crosscoder_sae import ( + TrainingCrosscoderSAE, + TrainingCrosscoderSAEConfig, +) from sae_lens.training.training_sae import TrainingSAE, TrainingSAEConfig @@ -76,6 +81,16 @@ def __init__( self.sae = TrainingSAE.load_from_pretrained( self.cfg.from_pretrained_path, self.cfg.device ) + elif self.cfg.hook_names: + self.sae = TrainingCrosscoderSAE( + TrainingCrosscoderSAEConfig.from_dict( + self.cfg.get_training_sae_cfg_dict(), + ), + # TODO(mkbehr): When causal crosscoders are + # implemented, set use_error_term false for those. + use_error_term=True, + ) + self._init_sae_group_b_decs() else: self.sae = TrainingSAE( TrainingSAEConfig.from_dict( @@ -100,13 +115,22 @@ def run(self): id=self.cfg.wandb_id, ) - trainer = SAETrainer( - model=self.model, - sae=self.sae, - activation_store=self.activations_store, - save_checkpoint_fn=self.save_checkpoint, - cfg=self.cfg, - ) + if self.cfg.hook_names: + trainer = CrosscoderSAETrainer( + model=self.model, + sae=self.sae, + activation_store=self.activations_store, + save_checkpoint_fn=self.save_checkpoint, + cfg=self.cfg, + ) + else: + trainer = SAETrainer( + model=self.model, + sae=self.sae, + activation_store=self.activations_store, + save_checkpoint_fn=self.save_checkpoint, + cfg=self.cfg, + ) self._compile_if_needed() sae = self.run_trainer_with_interruption_handling(trainer) @@ -167,6 +191,9 @@ def _init_sae_group_b_decs( extract all activations at a certain layer and use for sae b_dec initialization """ + if self.cfg.hook_names and self.cfg.b_dec_init_method != "zeros": + raise NotImplementedError("TODO(mkbehr): For crosscoders, only b_dec_init_method='zeros' is implemented.") + if self.cfg.b_dec_init_method == "geometric_median": self.activations_store.set_norm_scaling_factor_if_needed() layer_acts = self.activations_store.storage_buffer.detach()[:, 0, :] diff --git a/sae_lens/training/activations_store.py b/sae_lens/training/activations_store.py index b4a5096b7..d6c9ccd63 100644 --- a/sae_lens/training/activations_store.py +++ b/sae_lens/training/activations_store.py @@ -45,6 +45,7 @@ class ActivationsStore: cached_activation_dataset: Dataset | None = None tokens_column: Literal["tokens", "input_ids", "text", "problem"] hook_name: str + hook_names: list[str] hook_layer: int hook_head_index: int | None _dataloader: Iterator[Any] | None = None @@ -66,6 +67,7 @@ def from_cache_activations( dtype=cfg.dtype, hook_name=cfg.hook_name, hook_layer=cfg.hook_layer, + # TODO(mkbehr): set hook layers if set in cached activations context_size=cfg.context_size, d_in=cfg.d_in, n_batches_in_buffer=cfg.n_batches_in_buffer, @@ -125,6 +127,7 @@ def from_config( dataset=override_dataset or cfg.dataset_path, streaming=cfg.streaming, hook_name=cfg.hook_name, + hook_names=cfg.hook_names, hook_layer=cfg.hook_layer, hook_head_index=cfg.hook_head_index, context_size=cfg.context_size, @@ -164,6 +167,7 @@ def from_sae( dataset=sae.cfg.dataset_path if dataset is None else dataset, d_in=sae.cfg.d_in, hook_name=sae.cfg.hook_name, + hook_names=sae.cfg.hook_names, hook_layer=sae.cfg.hook_layer, hook_head_index=sae.cfg.hook_head_index, context_size=sae.cfg.context_size if context_size is None else context_size, @@ -198,6 +202,7 @@ def __init__( normalize_activations: str, device: torch.device, dtype: str, + hook_names: list[str] | None = None, cached_activations_path: str | None = None, model_kwargs: dict[str, Any] | None = None, autocast_lm: bool = False, @@ -230,6 +235,7 @@ def __init__( ) self.hook_name = hook_name + self.hook_names = hook_names if hook_names is not None else [] self.hook_layer = hook_layer self.hook_head_index = hook_head_index self.context_size = context_size @@ -389,7 +395,8 @@ def load_cached_activation_dataset(self) -> Dataset | None: # --- # Actual code activations_dataset = datasets.load_from_disk(self.cached_activations_path) - columns = [self.hook_name] + # TODO(mkbehr): test multiple layers + columns = self.hook_names or [self.hook_name] if "token_ids" in activations_dataset.column_names: columns.append("token_ids") activations_dataset.set_format( @@ -428,6 +435,12 @@ def apply_norm_scaling_factor(self, activations: torch.Tensor) -> torch.Tensor: raise ValueError( "estimated_norm_scaling_factor is not set, call set_norm_scaling_factor_if_needed() first" ) + # Norm scaling factor is a float in the single-layer case, and + # a tensor in the multilayer case. + if self.hook_names: + return activations * self.estimated_norm_scaling_factor.unsqueeze(-1).to( + activations.device + ) return activations * self.estimated_norm_scaling_factor def unscale(self, activations: torch.Tensor) -> torch.Tensor: @@ -435,6 +448,8 @@ def unscale(self, activations: torch.Tensor) -> torch.Tensor: raise ValueError( "estimated_norm_scaling_factor is not set, call set_norm_scaling_factor_if_needed() first" ) + if self.hook_names: + return activations / self.estimated_norm_scaling_factor.unsqueeze(-1) return activations / self.estimated_norm_scaling_factor def get_norm_scaling_factor(self, activations: torch.Tensor) -> torch.Tensor: @@ -442,17 +457,23 @@ def get_norm_scaling_factor(self, activations: torch.Tensor) -> torch.Tensor: @torch.no_grad() def estimate_norm_scaling_factor(self, n_batches_for_norm_estimate: int = int(1e3)): - norms_per_batch = [] - for _ in tqdm( + norms_per_batch = torch.empty( + len(self.hook_names) or 1, n_batches_for_norm_estimate, device=self.device + ) + for batch_i in tqdm( range(n_batches_for_norm_estimate), desc="Estimating norm scaling factor" ): - # temporalily set estimated_norm_scaling_factor to 1.0 so the dataloader works - self.estimated_norm_scaling_factor = 1.0 - acts = self.next_batch()[:, 0] + # temporarily set estimated_norm_scaling_factor to 1.0 so the dataloader works + self.estimated_norm_scaling_factor = torch.ones(1, device=self.device) + acts = self.next_batch() self.estimated_norm_scaling_factor = None - norms_per_batch.append(acts.norm(dim=-1).mean().item()) - mean_norm = np.mean(norms_per_batch) - return np.sqrt(self.d_in) / mean_norm + norms_per_batch[:, batch_i] = acts.norm(dim=-1).mean(dim=0) + mean_norm = norms_per_batch.mean(dim=1) + # Norm scaling factor is a float in the single-layer case, and + # a tensor in the multilayer case. + if self.hook_names: + return np.sqrt(self.d_in) / mean_norm + return np.sqrt(self.d_in) / mean_norm.item() def shuffle_input_dataset(self, seed: int, buffer_size: int = 1): """ @@ -532,42 +553,45 @@ def get_activations(self, batch_tokens: torch.Tensor): else: autocast_if_enabled = contextlib.nullcontext() + hook_names = self.hook_names or [self.hook_name] + stop_at_layer = self.hook_layer + 1 + with autocast_if_enabled: layerwise_activations_cache = self.model.run_with_cache( batch_tokens, - names_filter=[self.hook_name], - stop_at_layer=self.hook_layer + 1, + names_filter=hook_names, + stop_at_layer=stop_at_layer, prepend_bos=False, **self.model_kwargs, )[1] - layerwise_activations = layerwise_activations_cache[self.hook_name][ - :, slice(*self.seqpos_slice) + layerwise_activations = [ + layerwise_activations_cache[hook_name][:, slice(*self.seqpos_slice)] + for hook_name in hook_names ] - n_batches, n_context = layerwise_activations.shape[:2] - - stacked_activations = torch.zeros((n_batches, n_context, 1, self.d_in)) + n_batches, n_context = layerwise_activations[0].shape[:2] if self.hook_head_index is not None: - stacked_activations[:, :, 0] = layerwise_activations[ - :, :, self.hook_head_index + layerwise_activations = [ + activation[:, :, self.hook_head_index] + for activation in layerwise_activations ] - elif layerwise_activations.ndim > 3: # if we have a head dimension + elif layerwise_activations[0].ndim > 3: # if we have a head dimension try: - stacked_activations[:, :, 0] = layerwise_activations.view( - n_batches, n_context, -1 - ) + layerwise_activations = [ + activation.view(n_batches, n_context, -1) + for activation in layerwise_activations + ] except RuntimeError as e: logger.error(f"Error during view operation: {e}") logger.info("Attempting to use reshape instead...") - stacked_activations[:, :, 0] = layerwise_activations.reshape( - n_batches, n_context, -1 - ) - else: - stacked_activations[:, :, 0] = layerwise_activations + layerwise_activations = [ + activation.reshape(n_batches, n_context, -1) + for activation in layerwise_activations + ] - return stacked_activations + return torch.stack(layerwise_activations, dim=2) def _load_buffer_from_cached( self, @@ -589,8 +613,7 @@ def _load_buffer_from_cached( raises StopIteration """ assert self.cached_activation_dataset is not None - # In future, could be a list of multiple hook names - hook_names = [self.hook_name] + hook_names = self.hook_names or [self.hook_name] if not set(hook_names).issubset(self.cached_activation_dataset.column_names): raise ValueError( f"Missing columns in dataset. Expected {hook_names}, " @@ -660,7 +683,7 @@ def get_buffer( batch_size = self.store_batch_size_prompts d_in = self.d_in total_size = batch_size * n_batches_in_buffer - num_layers = 1 + num_layers = len(self.hook_names) or 1 if self.cached_activation_dataset is not None: return self._load_buffer_from_cached( diff --git a/sae_lens/training/crosscoder_sae_trainer.py b/sae_lens/training/crosscoder_sae_trainer.py new file mode 100644 index 000000000..1d6d1a233 --- /dev/null +++ b/sae_lens/training/crosscoder_sae_trainer.py @@ -0,0 +1,132 @@ +from typing import Any + +import torch +import wandb +from tqdm import tqdm + +from sae_lens.evals import run_evals +from sae_lens.training.sae_trainer import SAETrainer +from sae_lens.training.training_crosscoder_sae import TrainingCrosscoderSAE +from sae_lens.training.training_sae import TrainStepOutput + + +class CrosscoderSAETrainer(SAETrainer): + def __init__(self, *args: Any, **kwargs: Any): + super().__init__(*args, **kwargs) + # Reconstruction metrics don't make sense for acausal crosscoders. + self.trainer_eval_config.compute_ce_loss = False + self.trainer_eval_config.compute_kl = False + + def fit(self) -> TrainingCrosscoderSAE: + pbar = tqdm( + total=self.cfg.total_training_tokens, desc="Training Crosscoder SAE" + ) + + self.activations_store.set_norm_scaling_factor_if_needed() + + # Train loop + while self.n_training_tokens < self.cfg.total_training_tokens: + # Do a training step. + layer_acts = self.activations_store.next_batch().to(self.sae.device) + self.n_training_tokens += self.cfg.train_batch_size_tokens + + step_output = self._train_step(sae=self.sae, sae_in=layer_acts) + + if self.cfg.log_to_wandb: + self._log_train_step(step_output) + self._run_and_log_evals() + + self._checkpoint_if_needed() + self.n_training_steps += 1 + self._update_pbar(step_output, pbar) + + ### If n_training_tokens > sae_group.cfg.training_tokens, then we should switch to fine-tuning (if we haven't already) + self._begin_finetuning_if_needed() + + # fold the estimated norm scaling factor into the sae weights + if self.activations_store.estimated_norm_scaling_factor is not None: + self.sae.fold_activation_norm_scaling_factor( + self.activations_store.estimated_norm_scaling_factor + ) + self.activations_store.estimated_norm_scaling_factor = None + + # save final sae group to checkpoints folder + self.save_checkpoint( + trainer=self, + checkpoint_name=f"final_{self.n_training_tokens}", + wandb_aliases=["final_model"], + ) + + pbar.close() + return self.sae + + @torch.no_grad() + def _build_train_step_log_dict( + self, + output: TrainStepOutput, + n_training_tokens: int, + ) -> dict[str, Any]: + log_dict = super()._build_train_step_log_dict(output, n_training_tokens) + + sae_in = output.sae_in + sae_out = output.sae_out + per_token_l2_loss = (sae_out - sae_in).pow(2).sum(dim=(-2, -1)).squeeze() + total_variance = (sae_in - sae_in.mean(0)).pow(2).sum((-2, -1)) + explained_variance = 1 - per_token_l2_loss / total_variance + + log_dict |= { + "metrics/explained_variance": explained_variance.mean().item(), + "metrics/explained_variance_std": explained_variance.std().item(), + } + return log_dict + + @torch.no_grad() + def _run_and_log_evals(self): + # record loss frequently, but not all the time. + if (self.n_training_steps + 1) % ( + self.cfg.wandb_log_frequency * self.cfg.eval_every_n_wandb_logs + ) == 0: + self.sae.eval() + ignore_tokens = set() + if self.activations_store.exclude_special_tokens is not None: + ignore_tokens = set( + self.activations_store.exclude_special_tokens.tolist() + ) + eval_metrics, _ = run_evals( + sae=self.sae, + activation_store=self.activations_store, + model=self.model, + eval_config=self.trainer_eval_config, + ignore_tokens=ignore_tokens, + model_kwargs=self.cfg.model_kwargs, + ) # not calculating featurwise metrics here. + + # Remove eval metrics that are already logged during training + eval_metrics.pop("metrics/explained_variance", None) + eval_metrics.pop("metrics/explained_variance_std", None) + eval_metrics.pop("metrics/l0", None) + eval_metrics.pop("metrics/l1", None) + eval_metrics.pop("metrics/mse", None) + + # Remove metrics that are not useful for wandb logging + eval_metrics.pop("metrics/total_tokens_evaluated", None) + + W_dec_norm_dist = ( + self.sae.W_dec.detach().float().norm(dim=(1, 2)).cpu().numpy() + ) + eval_metrics["weights/W_dec_norms"] = wandb.Histogram(W_dec_norm_dist) # type: ignore + + if self.sae.cfg.architecture == "standard": + b_e_dist = self.sae.b_enc.detach().float().cpu().numpy() + eval_metrics["weights/b_e"] = wandb.Histogram(b_e_dist) # type: ignore + elif self.sae.cfg.architecture == "gated": + b_gate_dist = self.sae.b_gate.detach().float().cpu().numpy() + eval_metrics["weights/b_gate"] = wandb.Histogram(b_gate_dist) # type: ignore + b_mag_dist = self.sae.b_mag.detach().float().cpu().numpy() + eval_metrics["weights/b_mag"] = wandb.Histogram(b_mag_dist) # type: ignore + + wandb.log( + eval_metrics, + step=self.n_training_steps, + ) + self.sae.train() diff --git a/sae_lens/training/training_crosscoder_sae.py b/sae_lens/training/training_crosscoder_sae.py new file mode 100644 index 000000000..5e9385996 --- /dev/null +++ b/sae_lens/training/training_crosscoder_sae.py @@ -0,0 +1,288 @@ +from dataclasses import dataclass +from typing import Any + +import einops +import torch +from jaxtyping import Float +from torch import nn + +from sae_lens.config import LanguageModelSAERunnerConfig +from sae_lens.crosscoder_sae import CrosscoderSAE, CrosscoderSAEConfig +from sae_lens.toolkit.pretrained_sae_loaders import ( + PretrainedSaeDiskLoader, + handle_config_defaulting, + sae_lens_disk_loader, +) +from sae_lens.training.training_sae import ( + TrainingSAE, + TrainingSAEConfig, + TrainStepOutput, +) + +SPARSITY_PATH = "sparsity.safetensors" +SAE_WEIGHTS_PATH = "sae_weights.safetensors" +SAE_CFG_PATH = "cfg.json" + + +@dataclass(kw_only=True) +class TrainingCrosscoderSAEConfig(CrosscoderSAEConfig, TrainingSAEConfig): + sparsity_penalty_decoder_norm_lp_norm: float = 1 + + @classmethod + def from_sae_runner_config( + cls, cfg: LanguageModelSAERunnerConfig + ) -> "TrainingSAEConfig": + return cls( + # base config + architecture=cfg.architecture, + d_in=cfg.d_in, + d_sae=cfg.d_sae, # type: ignore + dtype=cfg.dtype, + device=cfg.device, + model_name=cfg.model_name, + hook_name=cfg.hook_name, + hook_names=cfg.hook_names, + hook_layer=cfg.hook_layer, + hook_head_index=cfg.hook_head_index, + activation_fn_str=cfg.activation_fn, + activation_fn_kwargs=cfg.activation_fn_kwargs, + apply_b_dec_to_input=cfg.apply_b_dec_to_input, + finetuning_scaling_factor=cfg.finetuning_method is not None, + sae_lens_training_version=cfg.sae_lens_training_version, + context_size=cfg.context_size, + dataset_path=cfg.dataset_path, + prepend_bos=cfg.prepend_bos, + seqpos_slice=cfg.seqpos_slice, + # Training cfg + l1_coefficient=cfg.l1_coefficient, + lp_norm=cfg.lp_norm, + use_ghost_grads=cfg.use_ghost_grads, + normalize_sae_decoder=cfg.normalize_sae_decoder, + noise_scale=cfg.noise_scale, + decoder_orthogonal_init=cfg.decoder_orthogonal_init, + mse_loss_normalization=cfg.mse_loss_normalization, + decoder_heuristic_init=cfg.decoder_heuristic_init, + decoder_heuristic_init_norm=cfg.decoder_heuristic_init_norm, + init_encoder_as_decoder_transpose=cfg.init_encoder_as_decoder_transpose, + scale_sparsity_penalty_by_decoder_norm=cfg.scale_sparsity_penalty_by_decoder_norm, + normalize_activations=cfg.normalize_activations, + dataset_trust_remote_code=cfg.dataset_trust_remote_code, + model_from_pretrained_kwargs=cfg.model_from_pretrained_kwargs or {}, + jumprelu_init_threshold=cfg.jumprelu_init_threshold, + jumprelu_bandwidth=cfg.jumprelu_bandwidth, + ) + + def to_dict(self) -> dict[str, Any]: + return ( + TrainingSAEConfig.to_dict(self) + | CrosscoderSAEConfig.to_dict(self) + | { + "sparsity_penalty_decoder_norm_lp_norm": self.sparsity_penalty_decoder_norm_lp_norm, + } + ) + + def get_base_sae_cfg_dict(self) -> dict[str, Any]: + return TrainingSAEConfig.get_base_sae_cfg_dict(self) | { + "hook_names": self.hook_names + } + + +class TrainingCrosscoderSAE(CrosscoderSAE, TrainingSAE): + # TODO(mkbehr) future implementation + # initialize_weights_jumprelu (can maybe just use input shape in trainingsae) + # encode_with_hidden_pre_{gated,jumprelu} + # calculate_topk_aux_loss + # calculate_ghost_grad_loss + # fold_W_dec_norm for jumprelu + + def __init__(self, cfg: TrainingCrosscoderSAEConfig, use_error_term: bool = False): + super().__init__(cfg, use_error_term=use_error_term) + + @classmethod + def from_dict( + cls, config_dict: dict[str, Any], use_error_term: bool = False + ) -> "TrainingSAE": + return cls( + TrainingCrosscoderSAEConfig.from_dict(config_dict), + use_error_term=use_error_term, + ) + + @staticmethod + def base_sae_cfg(cfg: TrainingCrosscoderSAEConfig): + return CrosscoderSAEConfig.from_dict(cfg.get_base_sae_cfg_dict()) + + def check_cfg_compatibility(self): + if self.cfg.architecture != "standard": + raise NotImplementedError("TODO(mkbehr): support other archs") + if not self.cfg.scale_sparsity_penalty_by_decoder_norm: + raise ValueError( + "Crosscoders require scale_sparsity_penalty_by_decoder_norm" + ) + if not self.use_error_term: + raise NotImplementedError("TODO(mkbehr): support causal crosscoders") + if self.cfg.use_ghost_grads: + raise NotImplementedError("TODO(mkbehr): support ghost grads") + super().check_cfg_compatibility() + + def encode_with_hidden_pre( + self, x: Float[torch.Tensor, "... n_layers d_in"] + ) -> tuple[Float[torch.Tensor, "... d_sae"], Float[torch.Tensor, "... d_sae"]]: + sae_in = self.process_sae_in(x) + + hidden_pre = self.hook_sae_acts_pre( + einops.einsum( + sae_in, + self.W_enc, + "... n_layers d_in, n_layers d_in d_sae -> ... d_sae", + ) + + self.b_enc + ) + hidden_pre_noised = hidden_pre + ( + torch.randn_like(hidden_pre) * self.cfg.noise_scale * self.training + ) + feature_acts = self.hook_sae_acts_post(self.activation_fn(hidden_pre_noised)) + + return feature_acts, hidden_pre_noised + + def training_forward_pass( + self, + sae_in: torch.Tensor, + current_l1_coefficient: float, + dead_neuron_mask: torch.Tensor | None = None, + ) -> TrainStepOutput: + # do a forward pass to get SAE out, but we also need the + # hidden pre. + feature_acts, hidden_pre = self.encode_with_hidden_pre_fn(sae_in) + sae_out = self.decode(feature_acts) + + # MSE LOSS + per_item_mse_loss = self.mse_loss_fn(sae_out, sae_in) + mse_loss = per_item_mse_loss.sum(dim=-1).mean() + + losses: dict[str, float | torch.Tensor] = {} + + assert self.cfg.scale_sparsity_penalty_by_decoder_norm + decoder_norms = self.W_dec.norm(dim=2) + feature_act_weights = decoder_norms.norm( + p=self.cfg.sparsity_penalty_decoder_norm_lp_norm, dim=1 + ) + weighted_feature_acts = feature_acts * feature_act_weights + sparsity = weighted_feature_acts.norm( + p=self.cfg.lp_norm, dim=-1 + ) # sum over the feature dimension + + l1_loss = (current_l1_coefficient * sparsity).mean() + loss = mse_loss + l1_loss + if self.cfg.use_ghost_grads and self.training and dead_neuron_mask is not None: + ghost_grad_loss = self.calculate_ghost_grad_loss( + x=sae_in, + sae_out=sae_out, + per_item_mse_loss=per_item_mse_loss, + hidden_pre=hidden_pre, + dead_neuron_mask=dead_neuron_mask, + ) + losses["ghost_grad_loss"] = ghost_grad_loss + loss = loss + ghost_grad_loss + losses["l1_loss"] = l1_loss + + losses["mse_loss"] = mse_loss + + return TrainStepOutput( + sae_in=sae_in, + sae_out=sae_out, + feature_acts=feature_acts, + hidden_pre=hidden_pre, + loss=loss, + losses=losses, + ) + + @classmethod + def load_from_disk( + cls, + path: str, + device: str = "cpu", + dtype: str | None = None, + converter: PretrainedSaeDiskLoader = sae_lens_disk_loader, + ) -> "TrainingCrosscoderSAE": + overrides = {"dtype": dtype} if dtype is not None else None + cfg_dict, state_dict = converter(path, device, cfg_overrides=overrides) + cfg_dict = handle_config_defaulting(cfg_dict) + sae_cfg = TrainingCrosscoderSAEConfig.from_dict(cfg_dict) + sae = cls(sae_cfg) + sae.process_state_dict_for_loading(state_dict) + sae.load_state_dict(state_dict) + return sae + + def initialize_weights_complex(self): + if self.cfg.decoder_orthogonal_init: + self.W_dec.data = nn.init.orthogonal_( + self.W_dec.data.permute((1, 2, 0)) + ).permute((2, 0, 1)) + + elif self.cfg.decoder_heuristic_init: + self.W_dec = nn.Parameter( + torch.rand( + self.cfg.d_sae, + *self.input_shape(), + dtype=self.dtype, + device=self.device, + ) + ) + self.initialize_decoder_norm_constant_norm( + self.cfg.decoder_heuristic_init_norm + ) + + # Then we initialize the encoder weights (either as the transpose of decoder or not) + if self.cfg.init_encoder_as_decoder_transpose: + self.W_enc.data = self.W_dec.data.permute((1, 2, 0)).clone().contiguous() + else: + self.W_enc = nn.Parameter( + torch.nn.init.kaiming_uniform_( + torch.empty( + *self.input_shape(), + self.cfg.d_sae, + dtype=self.dtype, + device=self.device, + ) + ) + ) + + if self.cfg.normalize_sae_decoder: + with torch.no_grad(): + # Anthropic normalize this to have unit columns + self.set_decoder_norm_to_unit_norm() + + @torch.no_grad() + def set_decoder_norm_to_unit_norm(self): + self.W_dec.data /= torch.norm(self.W_dec.data, dim=[1, 2], keepdim=True) + + @torch.no_grad() + def initialize_decoder_norm_constant_norm(self, norm: float = 0.1): + """ + A heuristic proceedure inspired by: + https://transformer-circuits.pub/2024/april-update/index.html#training-saes + """ + # TODO: Parameterise this as a function of m and n + + # ensure W_dec norms at unit norm + self.W_dec.data /= torch.norm(self.W_dec.data, dim=[1, 2], keepdim=True) + self.W_dec.data *= norm # will break tests but do this for now. + + @torch.no_grad() + def remove_gradient_parallel_to_decoder_directions(self): + """ + Update grads so that they remove the parallel component + (d_sae, n_layers, d_in) shape + """ + assert self.W_dec.grad is not None # keep pyright happy + + parallel_component = einops.einsum( + self.W_dec.grad, + self.W_dec.data, + "d_sae n_layers d_in, d_sae n_layers d_in -> d_sae", + ) + self.W_dec.grad -= einops.einsum( + parallel_component, + self.W_dec.data, + "d_sae, d_sae n_layers d_in -> d_sae n_layers d_in", + ) diff --git a/sae_lens/training/training_sae.py b/sae_lens/training/training_sae.py index ba51ab843..51d5d9c08 100644 --- a/sae_lens/training/training_sae.py +++ b/sae_lens/training/training_sae.py @@ -185,7 +185,7 @@ def from_dict(cls, config_dict: dict[str, Any]) -> "TrainingSAEConfig": elif not isinstance(valid_config_dict["seqpos_slice"], tuple): valid_config_dict["seqpos_slice"] = (valid_config_dict["seqpos_slice"],) - return TrainingSAEConfig(**valid_config_dict) + return cls(**valid_config_dict) def to_dict(self) -> dict[str, Any]: return { @@ -244,8 +244,7 @@ class TrainingSAE(SAE): device: torch.device def __init__(self, cfg: TrainingSAEConfig, use_error_term: bool = False): - base_sae_cfg = SAEConfig.from_dict(cfg.get_base_sae_cfg_dict()) - super().__init__(base_sae_cfg) + super().__init__(self.base_sae_cfg(cfg), use_error_term=use_error_term) self.cfg = cfg # type: ignore if cfg.architecture == "standard" or cfg.architecture == "topk": @@ -291,6 +290,10 @@ def threshold(self) -> torch.Tensor: def from_dict(cls, config_dict: dict[str, Any]) -> "TrainingSAE": return cls(TrainingSAEConfig.from_dict(config_dict)) + @staticmethod + def base_sae_cfg(cfg: TrainingSAEConfig): + return SAEConfig.from_dict(cfg.get_base_sae_cfg_dict()) + def check_cfg_compatibility(self): if self.cfg.architecture != "standard" and self.cfg.use_ghost_grads: raise ValueError(f"{self.cfg.architecture} SAEs do not support ghost grads") @@ -597,7 +600,10 @@ def initialize_weights_complex(self): elif self.cfg.decoder_heuristic_init: self.W_dec = nn.Parameter( torch.rand( - self.cfg.d_sae, self.cfg.d_in, dtype=self.dtype, device=self.device + self.cfg.d_sae, + *self.input_shape(), + dtype=self.dtype, + device=self.device, ) ) self.initialize_decoder_norm_constant_norm( @@ -611,7 +617,7 @@ def initialize_weights_complex(self): self.W_enc = nn.Parameter( torch.nn.init.kaiming_uniform_( torch.empty( - self.cfg.d_in, + *self.input_shape(), self.cfg.d_sae, dtype=self.dtype, device=self.device, diff --git a/scripts/acausal_crosscoder.py b/scripts/acausal_crosscoder.py new file mode 100644 index 000000000..2e21e7091 --- /dev/null +++ b/scripts/acausal_crosscoder.py @@ -0,0 +1,119 @@ +import os +import sys + +import torch + +sys.path.append("..") + +from sae_lens.config import LanguageModelSAERunnerConfig +from sae_lens.sae_training_runner import SAETrainingRunner + +if torch.cuda.is_available(): + device = "cuda" +elif torch.backends.mps.is_available(): + device = "mps" +else: + device = "cpu" + +print("Using device:", device) +os.environ["TOKENIZERS_PARALLELISM"] = "false" + +# total_training_steps = 200_000 +# total_training_steps = 60_000 +total_training_steps = 10_000 +batch_size = 4092 +# batch_size = 256 +total_training_tokens = total_training_steps * batch_size +print(f"Total Training Tokens: {total_training_tokens}") + +hook_name_template = "blocks.{layer}.hook_mlp_out" +layers = list(range(2)) + +model_name = "gpt2-small" +dataset_path = "apollo-research/SkyLion007-openwebtext-tokenizer-gpt2" +new_cached_activations_path = ( + f"./cached_activations/{model_name}/{dataset_path}/{total_training_steps}" +) + +lr_warm_up_steps = total_training_steps // 40 +print(f"lr_warm_up_steps: {lr_warm_up_steps}") +lr_decay_steps = total_training_steps // 5 # 20% of training steps. +print(f"lr_decay_steps: {lr_decay_steps}") +l1_warmup_steps = total_training_steps // 20 +print(f"l1_warmup_steps: {l1_warmup_steps}") +log_to_wandb = True +if not log_to_wandb: + print("NOT LOGGING TO WANDB") + +d_in = 768 +expansion_factor = 32 +d_sae = d_in * expansion_factor +learning_rate = 2e-5 +l1_coefficient = 1 +hook_name = hook_name_template.format(layer=f"{min(layers)}_through_{max(layers)}") +hook_names = [hook_name_template.format(layer=layer) for layer in layers] + +cfg = LanguageModelSAERunnerConfig( + model_name=model_name, + hook_name=hook_name, + hook_names=hook_names, + hook_layer=max(layers), + d_in=d_in, + dataset_path=dataset_path, + streaming=True, + context_size=512, + is_dataset_tokenized=True, + prepend_bos=True, + expansion_factor=expansion_factor, + use_cached_activations=False, + training_tokens=total_training_tokens, + train_batch_size_tokens=batch_size, + # Loss Function + mse_loss_normalization=None, + l1_coefficient=l1_coefficient, + lp_norm=1.0, + scale_sparsity_penalty_by_decoder_norm=True, + # TODO(mkbehr): plumb this through config + # sparsity_penalty_decoder_norm_lp_norm=1.0, + # Learning Rate + lr_scheduler_name="constant", # we set this independently of warmup and decay steps. + l1_warm_up_steps=l1_warmup_steps, + lr_warm_up_steps=lr_warm_up_steps, + lr_decay_steps=lr_warm_up_steps, + use_ghost_grads=False, + # Initialization / Architecture + apply_b_dec_to_input=False, + b_dec_init_method="zeros", + normalize_sae_decoder=False, + decoder_heuristic_init=True, + decoder_heuristic_init_norm=0.1, + init_encoder_as_decoder_transpose=True, + # Optimizer + lr=learning_rate, + ## adam optimizer has no weight decay by default so worry about this. + adam_beta1=0.9, + adam_beta2=0.999, + # Buffer details won't matter in we cache / shuffle our activations ahead of time. + n_batches_in_buffer=32, + store_batch_size_prompts=16, + normalize_activations="expected_average_only_in", + # Feature Store + feature_sampling_window=1000, + dead_feature_window=1000, + dead_feature_threshold=1e-4, + # WANDB + log_to_wandb=log_to_wandb, # always use wandb unless you are just testing code. + wandb_project="crosscoder-acausal-gpt2-small", + wandb_log_frequency=50, + eval_every_n_wandb_logs=10, + # Misc + device=device, + seed=42, + n_checkpoints=0, + checkpoint_path="checkpoints", + dtype="float32", +) + +sae = SAETrainingRunner(cfg).run() + +print("=" * 50) diff --git a/tests/helpers.py b/tests/helpers.py index 6c3cdab3e..91082bf86 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -16,6 +16,7 @@ class LanguageModelSAERunnerConfigDict(TypedDict, total=False): model_name: str hook_name: str hook_layer: int + hook_names: list[int] | None hook_head_index: int | None dataset_path: str dataset_trust_remote_code: bool @@ -53,6 +54,7 @@ def build_sae_cfg(**kwargs: Any) -> LanguageModelSAERunnerConfig: mock_config_dict: LanguageModelSAERunnerConfigDict = { "model_name": TINYSTORIES_MODEL, "hook_name": "blocks.0.hook_mlp_out", + "hook_names": [], "hook_layer": 0, "hook_head_index": None, # use a small, non-streaming dataset for testing. Huggingface gives too many requests errors otherwise. @@ -96,6 +98,27 @@ def build_sae_cfg(**kwargs: Any) -> LanguageModelSAERunnerConfig: return mock_config +def build_multilayer_sae_cfg( + hook_name_template: str = "blocks.{layer}.hook_mlp_out", + hook_layers: list[int] = [0, 1, 2], + **kwargs: Any, +) -> LanguageModelSAERunnerConfig: + hook_name = hook_name_template.format( + layer=f"layers_{min(hook_layers)}_through_{max(hook_layers)}" + ) + hook_names = [hook_name_template.format(layer=str(layer)) for layer in hook_layers] + return build_sae_cfg( + **( + { + "hook_name": hook_name, + "hook_names": hook_names, + "hook_layer": max(hook_layers), + } + | kwargs + ) + ) + + MODEL_CACHE: dict[str, HookedTransformer] = {} diff --git a/tests/test_evals.py b/tests/test_evals.py index 12a5dd953..9a728e1b5 100644 --- a/tests/test_evals.py +++ b/tests/test_evals.py @@ -25,8 +25,17 @@ from sae_lens.sae import SAE from sae_lens.toolkit.pretrained_saes_directory import PretrainedSAELookup from sae_lens.training.activations_store import ActivationsStore +from sae_lens.training.training_crosscoder_sae import ( + TrainingCrosscoderSAE, + TrainingCrosscoderSAEConfig, +) from sae_lens.training.training_sae import TrainingSAE -from tests.helpers import TINYSTORIES_MODEL, build_sae_cfg, load_model_cached +from tests.helpers import ( + TINYSTORIES_MODEL, + build_multilayer_sae_cfg, + build_sae_cfg, + load_model_cached, +) TRAINER_EVAL_CONFIG = EvalConfig( n_eval_reconstruction_batches=10, @@ -284,6 +293,51 @@ def test_run_empty_evals( assert len(feature_metrics) == 0, "Expected empty feature_metrics" +# TODO(mkbehr): consider parameterizing +def test_run_evals_crosscoder_training_sae( + model: HookedTransformer, +): + cfg = build_multilayer_sae_cfg( + model_name="tiny-stories-1M", + dataset_path="roneneldan/TinyStories", + hook_name_template="blocks.{layer}.hook_resid_pre", + hook_layers=[0, 1], + d_in=64, + normalize_sae_decoder=False, + scale_sparsity_penalty_by_decoder_norm=True, + ) + activation_store = ActivationsStore.from_config( + model, cfg, override_dataset=Dataset.from_list([{"text": "hello world"}] * 2000) + ) + training_crosscoder_sae = TrainingCrosscoderSAE( + TrainingCrosscoderSAEConfig.from_sae_runner_config(cfg), use_error_term=True + ) + eval_config = EvalConfig( + compute_l2_norms=True, + compute_sparsity_metrics=True, + compute_variance_metrics=True, + # TODO(mkbehr): featurewise metrics + compute_featurewise_density_statistics=False, + compute_featurewise_weight_based_metrics=False, + ) + eval_metrics, feature_metrics = run_evals( + sae=training_crosscoder_sae, + activation_store=activation_store, + model=model, + eval_config=eval_config, + ) + expected_keys = [ + "reconstruction_quality", + "shrinkage", + "sparsity", + "token_stats", + ] + assert set(eval_metrics.keys()) == set(expected_keys) + assert set(feature_metrics.keys()) == set( + ["feature_density", "consistent_activation_heuristic"] + ) + + @pytest.fixture def mock_args(): args = argparse.Namespace() diff --git a/tests/training/test_activations_store_multilayer.py b/tests/training/test_activations_store_multilayer.py new file mode 100644 index 000000000..522e1521e --- /dev/null +++ b/tests/training/test_activations_store_multilayer.py @@ -0,0 +1,192 @@ +"""Tests for ActivationsStore with multiple layer support.""" + +import os +import tempfile + +import torch +from datasets import Dataset +from safetensors.torch import load_file +from transformer_lens import HookedTransformer + +from sae_lens.training.activations_store import ActivationsStore +from tests.helpers import build_multilayer_sae_cfg, build_sae_cfg + + +def test_activations_store_init_with_multiple_layers(ts_model: HookedTransformer): + """Test initialization with a list of layers instead of a single layer.""" + cfg = build_multilayer_sae_cfg( + hook_name_template="blocks.{layer}.hook_resid_pre", hook_layers=[0, 1, 2] + ) + + activation_store = ActivationsStore.from_config(ts_model, cfg) + + assert activation_store.hook_names == [ + "blocks.0.hook_resid_pre", + "blocks.1.hook_resid_pre", + "blocks.2.hook_resid_pre", + ] + + cfg_single = build_multilayer_sae_cfg( + hook_name_template="blocks.{layer}.hook_resid_pre", hook_layers=[1] + ) + + single_layer_store = ActivationsStore.from_config(ts_model, cfg_single) + assert single_layer_store.hook_names == [ + "blocks.1.hook_resid_pre", + ] + + +def test_activations_store_get_activations_multiple_layers(ts_model: HookedTransformer): + """Test that get_activations collects activations from all specified layers.""" + cfg = build_multilayer_sae_cfg( + hook_name_template="blocks.{layer}.hook_resid_pre", + hook_layers=[0, 1, 2], + context_size=5, + ) + + dataset = Dataset.from_list([{"text": "hello world"}] * 10) + activation_store = ActivationsStore.from_config( + ts_model, cfg, override_dataset=dataset + ) + + batch_tokens = activation_store.get_batch_tokens() + activations = activation_store.get_activations(batch_tokens) + + # Check shape: [batch_size, context_size, num_layers, d_in] + assert activations.shape == ( + cfg.store_batch_size_prompts, + cfg.context_size, + len(cfg.hook_names), + cfg.d_in, + ) + + # Verify that layers are in the correct order + # Run with cache directly to compare against + _, cache = ts_model.run_with_cache( + batch_tokens, names_filter=[f"blocks.{i}.hook_resid_pre" for i in [0, 1, 2]] + ) + + for i, layer in enumerate([0, 1, 2]): + hook_name = f"blocks.{layer}.hook_resid_pre" + # Compare the activations for this layer with what we got from run_with_cache + assert torch.allclose(activations[:, :, i, :], cache[hook_name], atol=1e-5) + + +def test_activations_store_get_buffer_multiple_layers(ts_model: HookedTransformer): + """Test buffer handling with multiple layers.""" + cfg = build_multilayer_sae_cfg( + hook_name_template="blocks.{layer}.hook_resid_pre", + hook_layers=[0, 1, 2], + context_size=5, + ) + + dataset = Dataset.from_list([{"text": "hello world"}] * 20) + activation_store = ActivationsStore.from_config( + ts_model, cfg, override_dataset=dataset + ) + + buffer_activations, buffer_tokens = activation_store.get_buffer( + n_batches_in_buffer=2 + ) + + # Check shape: [(batch_size * context_size * n_batches), num_layers, d_in] + expected_size = cfg.store_batch_size_prompts * cfg.context_size * 2 + assert buffer_activations.shape == (expected_size, len(cfg.hook_names), cfg.d_in) + assert buffer_tokens is not None + assert buffer_tokens.shape == (expected_size,) + + +def test_activations_store_next_batch_multiple_layers(ts_model: HookedTransformer): + """Test that next_batch returns correct batch shape with multiple layers.""" + cfg = build_multilayer_sae_cfg( + hook_name_template="blocks.{layer}.hook_resid_pre", + hook_layers=[0, 1, 2], + context_size=5, + train_batch_size_tokens=10, + ) + + dataset = Dataset.from_list([{"text": "hello world"}] * 20) + activation_store = ActivationsStore.from_config( + ts_model, cfg, override_dataset=dataset + ) + + batch = activation_store.next_batch() + assert batch.shape == (10, len(cfg.hook_names), activation_store.d_in) + + +def test_activations_store_normalization_multiple_layers(ts_model: HookedTransformer): + """Test normalization when using multiple layers.""" + cfg = build_multilayer_sae_cfg( + hook_name_template="blocks.{layer}.hook_resid_pre", + hook_layers=[0, 1, 2], + normalize_activations="expected_average_only_in", + context_size=5, + ) + + dataset = Dataset.from_list([{"text": "hello world"}] * 20) + activation_store = ActivationsStore.from_config( + ts_model, cfg, override_dataset=dataset + ) + activation_store.set_norm_scaling_factor_if_needed() + + batch = activation_store.next_batch() + + avg_norm = batch.norm(dim=-1).mean(dim=1) + expected_norm = torch.full_like(avg_norm, cfg.d_in**0.5) + torch.testing.assert_close(avg_norm, expected_norm, atol=1.0, rtol=0.1) + + +def test_backward_compatibility_single_layer(ts_model: HookedTransformer): + """Test that single layer behavior is unchanged with the multi-layer support.""" + cfg_single = build_sae_cfg( + hook_name="blocks.0.hook_resid_pre", hook_layer=0, context_size=5 + ) + + dataset = Dataset.from_list([{"text": "hello world"}] * 10) + single_store = ActivationsStore.from_config( + ts_model, cfg_single, override_dataset=dataset + ) + + cfg_multi = build_multilayer_sae_cfg( + hook_name_template="blocks.{layer}.hook_resid_pre", + hook_layers=[0], + context_size=5, + ) + multi_store = ActivationsStore.from_config( + ts_model, cfg_multi, override_dataset=dataset + ) + + batch_tokens_single = single_store.get_batch_tokens() + activations_single = single_store.get_activations(batch_tokens_single) + + batch_tokens_multi = multi_store.get_batch_tokens() + activations_multi = multi_store.get_activations(batch_tokens_multi) + + torch.testing.assert_close(batch_tokens_single, batch_tokens_multi) + torch.testing.assert_close(activations_single, activations_multi) + + +def test_activations_store_multilayer_save_with_norm_scaling_factor( + ts_model: HookedTransformer, +): + cfg = build_multilayer_sae_cfg( + hook_name_template="blocks.{layer}.hook_resid_pre", + hook_layers=[0, 1, 2], + normalize_activations="expected_average_only_in", + context_size=5, + ) + activation_store = ActivationsStore.from_config(ts_model, cfg) + activation_store.set_norm_scaling_factor_if_needed() + assert activation_store.estimated_norm_scaling_factor is not None + with tempfile.NamedTemporaryFile() as temp_file: + activation_store.save(temp_file.name) + assert os.path.exists(temp_file.name) + state_dict = load_file(temp_file.name) + assert isinstance(state_dict, dict) + assert "estimated_norm_scaling_factor" in state_dict + estimated_norm_scaling_factor = state_dict["estimated_norm_scaling_factor"] + assert estimated_norm_scaling_factor.shape == (len(cfg.hook_names),) + torch.testing.assert_close( + estimated_norm_scaling_factor, + activation_store.estimated_norm_scaling_factor, + ) diff --git a/tests/training/test_cache_activations_runner.py b/tests/training/test_cache_activations_runner.py index b6309e8f5..7d070e962 100644 --- a/tests/training/test_cache_activations_runner.py +++ b/tests/training/test_cache_activations_runner.py @@ -271,7 +271,7 @@ def test_cache_activations_runner_with_incorrect_d_in(tmp_path: Path): runner = CacheActivationsRunner(wrong_d_in_cfg) with pytest.raises( RuntimeError, - match=r"The expanded size of the tensor \(513\) must match the existing size \(512\) at non-singleton dimension 2.", + match=r"The expanded size of the tensor \(513\) must match the existing size \(512\) at non-singleton dimension 3.", ): runner.run() diff --git a/tests/training/test_config.py b/tests/training/test_config.py index db888b8ac..2af015efc 100644 --- a/tests/training/test_config.py +++ b/tests/training/test_config.py @@ -90,6 +90,7 @@ def test_sae_training_runner_config_get_sae_base_parameters(): "dtype": "float32", "model_name": "gelu-2l", "hook_name": "blocks.0.hook_mlp_out", + "hook_names": [], "hook_layer": 0, "hook_head_index": None, "device": "cpu", diff --git a/tests/training/test_crosscoder_sae.py b/tests/training/test_crosscoder_sae.py new file mode 100644 index 000000000..f78899906 --- /dev/null +++ b/tests/training/test_crosscoder_sae.py @@ -0,0 +1,346 @@ +import os +from copy import deepcopy +from pathlib import Path + +import einops +import pytest +import torch + +from sae_lens.config import LanguageModelSAERunnerConfig +from sae_lens.crosscoder_sae import CrosscoderSAE +from tests.helpers import ALL_ARCHITECTURES, build_multilayer_sae_cfg + + +# Define a new fixture for different configurations +@pytest.fixture( + params=[ + { + "model_name": "tiny-stories-1M", + "dataset_path": "roneneldan/TinyStories", + "hook_name_template": "blocks.{layer}.hook_resid_pre", + "hook_layers": [0, 1, 2], + "d_in": 64, + "normalize_sae_decoder": False, + "scale_sparsity_penalty_by_decoder_norm": True, + }, + { + "model_name": "tiny-stories-1M", + "dataset_path": "roneneldan/TinyStories", + "hook_name_template": "blocks.{layer}.hook_resid_pre", + "hook_layers": [0, 1, 2], + "d_in": 64, + "normalize_sae_decoder": False, + "scale_sparsity_penalty_by_decoder_norm": True, + }, + { + "model_name": "tiny-stories-1M", + "dataset_path": "apollo-research/roneneldan-TinyStories-tokenizer-gpt2", + "hook_name_template": "blocks.{layer}.hook_resid_pre", + "hook_layers": [0, 1, 2], + "d_in": 64, + "normalize_sae_decoder": False, + "scale_sparsity_penalty_by_decoder_norm": True, + }, + # TODO(mkbehr): hook_z support + # { + # "model_name": "tiny-stories-1M", + # "dataset_path": "roneneldan/TinyStories", + # "hook_name": "blocks.{layer}.attn.hook_z", + # "hook_layers": [0,1,2], + # "d_in": 64, + # "normalize_sae_decoder": False, + # "scale_sparsity_penalty_by_decoder_norm": True, + # }, + ], + ids=[ + "tiny-stories-1M-resid-pre", + "tiny-stories-1M-resid-pre-L1-W-dec-Norm", + "tiny-stories-1M-resid-pre-pretokenized", + # "tiny-stories-1M-attn-out", + ], +) +def cfg(request: pytest.FixtureRequest): + """ + Pytest fixture to create a mock instance of LanguageModelSAERunnerConfig. + """ + params = request.param + return build_multilayer_sae_cfg(**params) + + +def test_crosscoder_sae_init(cfg: LanguageModelSAERunnerConfig): + sae = CrosscoderSAE.from_dict(cfg.get_base_sae_cfg_dict()) + + assert isinstance(sae, CrosscoderSAE) + + n_layers = len(cfg.hook_names) + assert sae.W_enc.shape == (n_layers, cfg.d_in, cfg.d_sae) + assert sae.W_dec.shape == (cfg.d_sae, n_layers, cfg.d_in) + assert sae.b_enc.shape == (cfg.d_sae,) + assert sae.b_dec.shape == (n_layers, cfg.d_in) + + +def test_crosscoder_sae_fold_w_dec_norm(cfg: LanguageModelSAERunnerConfig): + sae = CrosscoderSAE.from_dict(cfg.get_base_sae_cfg_dict()) + sae.turn_off_forward_pass_hook_z_reshaping() # hook z reshaping not needed here. + assert sae.W_dec.norm(dim=[-2, -1]).mean().item() != pytest.approx(1.0, abs=1e-6) + sae2 = deepcopy(sae) + sae2.fold_W_dec_norm() + + W_dec_norms = sae.W_dec.norm(dim=[-2, -1], keepdim=True) + assert torch.allclose(sae2.W_dec.data, sae.W_dec.data / W_dec_norms) + assert torch.allclose( + sae2.W_enc.data, + sae.W_enc.data * einops.rearrange(W_dec_norms, "d_sae 1 1 -> 1 1 d_sae"), + ) + assert torch.allclose(sae2.b_enc.data, sae.b_enc.data * W_dec_norms.squeeze()) + + # fold_W_dec_norm should normalize W_dec to have unit norm. + assert sae2.W_dec.norm(dim=[-2, -1]).mean().item() == pytest.approx(1.0, abs=1e-6) + + # we expect activations of features to differ by W_dec norm weights. + activations = torch.randn(10, 4, len(cfg.hook_names), cfg.d_in, device=cfg.device) + feature_activations_1 = sae.encode(activations) + feature_activations_2 = sae2.encode(activations) + + assert torch.allclose( + feature_activations_1.nonzero(), + feature_activations_2.nonzero(), + ) + + expected_feature_activations_2 = feature_activations_1 * sae.W_dec.norm( + dim=[-2, -1] + ) + torch.testing.assert_close(feature_activations_2, expected_feature_activations_2) + + sae_out_1 = sae.decode(feature_activations_1) + sae_out_2 = sae2.decode(feature_activations_2) + + # but actual outputs should be the same + torch.testing.assert_close(sae_out_1, sae_out_2) + + +@pytest.mark.parametrize("architecture", ALL_ARCHITECTURES) +@torch.no_grad() +def test_crosscoder_sae_fold_w_dec_norm_all_architectures(architecture: str): + if architecture != "standard": + pytest.xfail("TODO(mkbehr): support other architectures") + cfg = build_multilayer_sae_cfg(architecture=architecture, hook_layers=[0, 1, 2]) + sae = CrosscoderSAE.from_dict(cfg.get_base_sae_cfg_dict()) + sae.turn_off_forward_pass_hook_z_reshaping() # hook z reshaping not needed here. + + # make sure all parameters are not 0s + for param in sae.parameters(): + param.data = torch.rand_like(param) + + assert sae.W_dec.norm(dim=[-2, -1]).mean().item() != pytest.approx(1.0, abs=1e-6) + sae2 = deepcopy(sae) + sae2.fold_W_dec_norm() + + # fold_W_dec_norm should normalize W_dec to have unit norm. + assert sae2.W_dec.norm(dim=[-2, -1]).mean().item() == pytest.approx(1.0, abs=1e-6) + + # we expect activations of features to differ by W_dec norm weights. + activations = torch.randn(10, 4, len(cfg.hook_names), cfg.d_in, device=cfg.device) + feature_activations_1 = sae.encode(activations) + feature_activations_2 = sae2.encode(activations) + + assert torch.allclose( + feature_activations_1.nonzero(), + feature_activations_2.nonzero(), + ) + + expected_feature_activations_2 = feature_activations_1 * sae.W_dec.norm( + dim=[-2, -1] + ) + torch.testing.assert_close(feature_activations_2, expected_feature_activations_2) + + sae_out_1 = sae.decode(feature_activations_1) + sae_out_2 = sae2.decode(feature_activations_2) + + # but actual outputs should be the same + torch.testing.assert_close(sae_out_1, sae_out_2) + + +@torch.no_grad() +def test_crosscoder_sae_fold_norm_scaling_factor(cfg: LanguageModelSAERunnerConfig): + norm_scaling_factor = torch.Tensor([2.0, 3.0, 4.0]) + + sae = CrosscoderSAE.from_dict(cfg.get_base_sae_cfg_dict()) + # make sure b_dec and b_enc are not 0s + sae.b_dec.data = torch.randn(len(cfg.hook_names), cfg.d_in, device=cfg.device) + sae.b_enc.data = torch.randn(cfg.d_sae, device=cfg.device) # type: ignore + sae.turn_off_forward_pass_hook_z_reshaping() # hook z reshaping not needed here. + + sae2 = deepcopy(sae) + sae2.fold_activation_norm_scaling_factor(norm_scaling_factor) + + assert sae2.cfg.normalize_activations == "none" + + assert torch.allclose( + sae2.W_enc.data, sae.W_enc.data * norm_scaling_factor.reshape((-1, 1, 1)) + ) + + # we expect activations of features to differ by W_dec norm weights. + # assume activations are already scaled + activations = torch.randn(10, 4, len(cfg.hook_names), cfg.d_in, device=cfg.device) + # we divide to get the unscale activations + unscaled_activations = activations / norm_scaling_factor.unsqueeze(-1) + + feature_activations_1 = sae.encode(activations) + # with the scaling folded in, the unscaled activations should produce the same + # result. + feature_activations_2 = sae2.encode(unscaled_activations) + + assert torch.allclose( + feature_activations_1.nonzero(), + feature_activations_2.nonzero(), + ) + + torch.testing.assert_close(feature_activations_2, feature_activations_1) + + sae_out_1 = sae.decode(feature_activations_1) + sae_out_2 = norm_scaling_factor.unsqueeze(-1) * sae2.decode(feature_activations_2) + + # but actual outputs should be the same + torch.testing.assert_close(sae_out_1, sae_out_2) + + +@pytest.mark.parametrize("architecture", ALL_ARCHITECTURES) +@torch.no_grad() +def test_crosscoder_sae_fold_norm_scaling_factor_all_architectures(architecture: str): + if architecture != "standard": + pytest.xfail("TODO(mkbehr): support other architectures") + cfg = build_multilayer_sae_cfg(architecture=architecture, hook_layers=[0, 1, 2]) + norm_scaling_factor = torch.Tensor([2.0, 3.0, 4.0]) + + sae = CrosscoderSAE.from_dict(cfg.get_base_sae_cfg_dict()) + # make sure all parameters are not 0s + for param in sae.parameters(): + param.data = torch.rand_like(param) + + sae2 = deepcopy(sae) + sae2.fold_activation_norm_scaling_factor(norm_scaling_factor) + + assert sae2.cfg.normalize_activations == "none" + + assert torch.allclose( + sae2.W_enc.data, sae.W_enc.data * norm_scaling_factor.reshape((-1, 1, 1)) + ) + + # we expect activations of features to differ by W_dec norm weights. + # assume activations are already scaled + activations = torch.randn(10, 4, len(cfg.hook_names), cfg.d_in, device=cfg.device) + # we divide to get the unscale activations + unscaled_activations = activations / norm_scaling_factor.unsqueeze(-1) + + feature_activations_1 = sae.encode(activations) + # with the scaling folded in, the unscaled activations should produce the same + # result. + feature_activations_2 = sae2.encode(unscaled_activations) + + assert torch.allclose( + feature_activations_1.nonzero(), + feature_activations_2.nonzero(), + ) + + torch.testing.assert_close(feature_activations_2, feature_activations_1) + + sae_out_1 = sae.decode(feature_activations_1) + sae_out_2 = norm_scaling_factor.unsqueeze(-1) * sae2.decode(feature_activations_2) + + # but actual outputs should be the same + torch.testing.assert_close(sae_out_1, sae_out_2) + + +def test_crosscoder_sae_save_and_load_from_pretrained(tmp_path: Path) -> None: + cfg = build_multilayer_sae_cfg(hook_layers=[0, 1, 2]) + model_path = str(tmp_path) + sae = CrosscoderSAE.from_dict(cfg.get_base_sae_cfg_dict()) + sae_state_dict = sae.state_dict() + sae.save_model(model_path) + + assert os.path.exists(model_path) + + sae_loaded = CrosscoderSAE.load_from_pretrained(model_path, device="cpu") + + sae_loaded_state_dict = sae_loaded.state_dict() + + # check state_dict matches the original + for key in sae.state_dict(): + assert torch.allclose( + sae_state_dict[key], + sae_loaded_state_dict[key], + ) + + sae_in = torch.randn(10, len(cfg.hook_names), cfg.d_in, device=cfg.device) + sae_out_1 = sae(sae_in) + sae_out_2 = sae_loaded(sae_in) + assert torch.allclose(sae_out_1, sae_out_2) + + +@pytest.mark.xfail(reason="TODO(mkbehr): support other architectures") +def test_crosscoder_sae_save_and_load_from_pretrained_gated(tmp_path: Path) -> None: + cfg = build_multilayer_sae_cfg(architecture="gated", hook_layers=[0, 1, 2]) + model_path = str(tmp_path) + sae = CrosscoderSAE.from_dict(cfg.get_base_sae_cfg_dict()) + sae_state_dict = sae.state_dict() + sae.save_model(model_path) + + assert os.path.exists(model_path) + + sae_loaded = CrosscoderSAE.load_from_pretrained(model_path, device="cpu") + + sae_loaded_state_dict = sae_loaded.state_dict() + + # check state_dict matches the original + for key in sae.state_dict(): + assert torch.allclose( + sae_state_dict[key], + sae_loaded_state_dict[key], + ) + + sae_in = torch.randn(10, len(cfg.hook_names), cfg.d_in, device=cfg.device) + sae_out_1 = sae(sae_in) + sae_out_2 = sae_loaded(sae_in) + assert torch.allclose(sae_out_1, sae_out_2) + + +def test_crosscoder_sae_save_and_load_from_pretrained_topk(tmp_path: Path) -> None: + cfg = build_multilayer_sae_cfg( + activation_fn_kwargs={"k": 30}, hook_layers=[0, 1, 2] + ) + model_path = str(tmp_path) + sae = CrosscoderSAE.from_dict(cfg.get_base_sae_cfg_dict()) + sae_state_dict = sae.state_dict() + sae.save_model(model_path) + + assert os.path.exists(model_path) + + sae_loaded = CrosscoderSAE.load_from_pretrained(model_path, device="cpu") + + sae_loaded_state_dict = sae_loaded.state_dict() + + # check state_dict matches the original + for key in sae.state_dict(): + assert torch.allclose( + sae_state_dict[key], + sae_loaded_state_dict[key], + ) + + sae_in = torch.randn(10, len(cfg.hook_names), cfg.d_in, device=cfg.device) + sae_out_1 = sae(sae_in) + sae_out_2 = sae_loaded(sae_in) + assert torch.allclose(sae_out_1, sae_out_2) + + +def test_crosscoder_sae_get_name_returns_correct_name_from_cfg_vals() -> None: + cfg = build_multilayer_sae_cfg( + model_name="test_model", + hook_name_template="blocks.{layer}.test_hook_name", + d_sae=128, + hook_layers=[0, 1, 2], + ) + sae = CrosscoderSAE.from_dict(cfg.get_base_sae_cfg_dict()) + assert ( + sae.get_name() == "sae_test_model_blocks.layers_0_through_2.test_hook_name_128" + ) diff --git a/tests/training/test_crosscoder_sae_trainer.py b/tests/training/test_crosscoder_sae_trainer.py new file mode 100644 index 000000000..8991468ff --- /dev/null +++ b/tests/training/test_crosscoder_sae_trainer.py @@ -0,0 +1,233 @@ +from pathlib import Path +from typing import Any, Callable + +import pytest +import torch +from datasets import Dataset +from transformer_lens import HookedTransformer + +from sae_lens.config import LanguageModelSAERunnerConfig +from sae_lens.training.activations_store import ActivationsStore +from sae_lens.training.crosscoder_sae_trainer import CrosscoderSAETrainer +from sae_lens.training.sae_trainer import ( + TrainStepOutput, +) +from sae_lens.training.training_crosscoder_sae import TrainingCrosscoderSAE +from tests.helpers import TINYSTORIES_MODEL, build_multilayer_sae_cfg, load_model_cached + + +@pytest.fixture +def cfg(): + return build_multilayer_sae_cfg( + d_in=64, + d_sae=128, + hook_name_template="blocks.{layer}.hook_mlp_out", + hook_layers=[0, 1, 2], + normalize_sae_decoder=False, + scale_sparsity_penalty_by_decoder_norm=True, + ) + + +@pytest.fixture +def model(): + return load_model_cached(TINYSTORIES_MODEL) + + +@pytest.fixture +def activation_store(model: HookedTransformer, cfg: LanguageModelSAERunnerConfig): + return ActivationsStore.from_config( + model, cfg, override_dataset=Dataset.from_list([{"text": "hello world"}] * 2000) + ) + + +@pytest.fixture +def training_sae(cfg: LanguageModelSAERunnerConfig): + return TrainingCrosscoderSAE.from_dict( + cfg.get_training_sae_cfg_dict(), use_error_term=True + ) + + +@pytest.fixture +def trainer( + cfg: LanguageModelSAERunnerConfig, + training_sae: TrainingCrosscoderSAE, + model: HookedTransformer, + activation_store: ActivationsStore, +): + return CrosscoderSAETrainer( + model=model, + sae=training_sae, + activation_store=activation_store, + save_checkpoint_fn=lambda *args, **kwargs: None, # noqa: ARG005 + cfg=cfg, + ) + + +def modify_sae_output( + sae: TrainingCrosscoderSAE, modifier: Callable[[torch.Tensor], Any] +): + """ + Helper to modify the output of the SAE forward pass for use in patching, for use in patch side_effect. + We need real grads during training, so we can't just mock the whole forward pass directly. + """ + + def modified_forward(*args: Any, **kwargs: Any) -> torch.Tensor: + output = TrainingCrosscoderSAE.forward(sae, *args, **kwargs) + return modifier(output) + + return modified_forward + + +def test_train_step__reduces_loss_when_called_repeatedly_on_same_acts( + trainer: CrosscoderSAETrainer, +) -> None: + layer_acts = trainer.activations_store.next_batch() + + # intentionally train on the same activations 5 times to ensure loss decreases + train_outputs = [ + trainer._train_step( + sae=trainer.sae, + sae_in=layer_acts, + ) + for _ in range(5) + ] + + # ensure loss decreases with each training step + for output, next_output in zip(train_outputs[:-1], train_outputs[1:]): + assert output.loss > next_output.loss + assert ( + trainer.n_frac_active_tokens == 20 + ) # should increment each step by batch_size (5*4) + + +def test_train_step__output_looks_reasonable(trainer: CrosscoderSAETrainer) -> None: + layer_acts = trainer.activations_store.next_batch() + + output = trainer._train_step( + sae=trainer.sae, + sae_in=layer_acts, + ) + + assert output.loss > 0 + # only hook_point_layer=0 acts should be passed to the SAE + assert torch.allclose(output.sae_in, layer_acts) + assert output.sae_out.shape == output.sae_in.shape + assert output.feature_acts.shape == (4, 128) # batch_size, d_sae + # ghots grads shouldn't trigger until dead_feature_window, which hasn't been reached yet + assert output.losses.get("ghost_grad_loss", 0) == 0 + assert trainer.n_frac_active_tokens == 4 + assert trainer.act_freq_scores.sum() > 0 # at least SOME acts should have fired + assert torch.allclose( + trainer.act_freq_scores, (output.feature_acts.abs() > 0).float().sum(0) + ) + + +def test_train_step__sparsity_updates_based_on_feature_act_sparsity( + trainer: CrosscoderSAETrainer, +) -> None: + trainer._reset_running_sparsity_stats() + layer_acts = trainer.activations_store.next_batch() + + train_output = trainer._train_step( + sae=trainer.sae, + sae_in=layer_acts, + ) + feature_acts = train_output.feature_acts + + # should increase by batch_size + assert trainer.n_frac_active_tokens == 4 + # add freq scores for all non-zero feature acts + assert torch.allclose( + trainer.act_freq_scores, (feature_acts > 0).float().sum(dim=0) + ) + + # check that features that just fired have n_forward_passes_since_fired = 0 + assert ( + trainer.n_forward_passes_since_fired[ + ((feature_acts > 0).float()[-1] == 1) + ].max() + == 0 + ) + assert train_output.feature_acts is feature_acts + + +def test_build_train_step_log_dict(trainer: CrosscoderSAETrainer) -> None: + sae_in = torch.tensor( + [[[-1, 0], [-2, 0]], [[0, 2], [0, 3]], [[1, 1], [1, 1]]] + ).float() + sae_out = torch.tensor( + [[[0, 0], [0, 0]], [[0, 2], [0, 3]], [[0.5, 1], [1, 0.5]]] + ).float() + train_output = TrainStepOutput( + sae_in=sae_in, + sae_out=sae_out, + feature_acts=torch.tensor([[0, 0, 0, 1], [1, 0, 0, 1], [1, 0, 1, 1]]).float(), + hidden_pre=torch.tensor([[-1, 0, 0, 1], [1, -1, 0, 1], [1, -1, 1, 1]]).float(), + loss=torch.tensor(0.5), + losses={ + "mse_loss": 0.25, + "l1_loss": 0.1, + "ghost_grad_loss": 0.15, + }, + ) + + per_token_l2_loss = (sae_out - sae_in).pow(2).sum(dim=(-2, -1)).squeeze() + total_variance = (sae_in - sae_in.mean(0)).pow(2).sum((-2, -1)) + explained_variance = 1 - per_token_l2_loss / total_variance + + # we're relying on the trainer only for some of the metrics here + # we should more / less try to break this and push + # everything through the train step output if we can. + log_dict = trainer._build_train_step_log_dict( + output=train_output, n_training_tokens=123 + ) + for key, val in { + "losses/mse_loss": 0.25, + # l1 loss is scaled by l1_coefficient + "losses/l1_loss": train_output.losses["l1_loss"] / trainer.cfg.l1_coefficient, + "losses/raw_l1_loss": train_output.losses["l1_loss"], + "losses/overall_loss": 0.5, + "losses/ghost_grad_loss": 0.15, + "metrics/explained_variance": explained_variance.mean().item(), + "metrics/explained_variance_std": explained_variance.std().item(), + "metrics/l0": 2.0, + "sparsity/mean_passes_since_fired": trainer.n_forward_passes_since_fired.mean().item(), + "sparsity/dead_features": trainer.dead_neurons.sum().item(), + "details/current_learning_rate": 2e-4, + "details/current_l1_coefficient": trainer.cfg.l1_coefficient, + "details/n_training_tokens": 123, + }.items(): + assert abs(val - log_dict[key]) < 1e-6 + + +def test_train_sae_group_on_language_model__runs( + ts_model: HookedTransformer, + tmp_path: Path, +) -> None: + checkpoint_dir = tmp_path / "checkpoint" + cfg = build_multilayer_sae_cfg( + checkpoint_path=str(checkpoint_dir), + training_tokens=20, + context_size=8, + hook_name_template="blocks.{layer}.hook_mlp_out", + hook_layers=[0, 1, 2], + normalize_sae_decoder=False, + scale_sparsity_penalty_by_decoder_norm=True, + ) + # just a tiny datast which will run quickly + dataset = Dataset.from_list([{"text": "hello world"}] * 100) + activation_store = ActivationsStore.from_config( + ts_model, cfg, override_dataset=dataset + ) + sae = TrainingCrosscoderSAE.from_dict( + cfg.get_training_sae_cfg_dict(), use_error_term=True + ) + sae = CrosscoderSAETrainer( + model=ts_model, + sae=sae, + activation_store=activation_store, + save_checkpoint_fn=lambda *args, **kwargs: None, # noqa: ARG005 + cfg=cfg, + ).fit() + + assert isinstance(sae, TrainingCrosscoderSAE) diff --git a/tests/training/test_crosscoder_sae_training.py b/tests/training/test_crosscoder_sae_training.py new file mode 100644 index 000000000..23be54bbd --- /dev/null +++ b/tests/training/test_crosscoder_sae_training.py @@ -0,0 +1,255 @@ +import pytest +import torch +from datasets import Dataset +from transformer_lens import HookedTransformer + +from sae_lens.config import LanguageModelSAERunnerConfig +from sae_lens.training.activations_store import ActivationsStore +from sae_lens.training.sae_trainer import SAETrainer +from sae_lens.training.training_crosscoder_sae import ( + TrainingCrosscoderSAE, + TrainingCrosscoderSAEConfig, +) +from tests.helpers import build_multilayer_sae_cfg + + +# Define a new fixture for different configurations +@pytest.fixture( + params=[ + { + "model_name": "tiny-stories-1M", + "dataset_path": "roneneldan/TinyStories", + "hook_name": "blocks.{layer}.hook_resid_pre", + "hook_layers": [0, 1, 2], + "d_in": 64, + "normalize_sae_decoder": False, + "scale_sparsity_penalty_by_decoder_norm": True, + }, + { + "model_name": "tiny-stories-1M", + "dataset_path": "apollo-research/roneneldan-TinyStories-tokenizer-gpt2", + "hook_name": "blocks.{layer}.hook_resid_pre", + "hook_layers": [0, 1, 2], + "d_in": 64, + "normalize_sae_decoder": False, + "scale_sparsity_penalty_by_decoder_norm": True, + }, + { + "model_name": "tiny-stories-1M", + "dataset_path": "apollo-research/roneneldan-TinyStories-tokenizer-gpt2", + "hook_name": "blocks.{layer}.hook_resid_pre", + "hook_layers": [0, 1, 2], + "d_in": 64, + "normalize_activations": "constant_norm_rescale", + "normalize_sae_decoder": False, + "scale_sparsity_penalty_by_decoder_norm": True, + }, + ], + ids=[ + "tiny-stories-1M-resid-pre", + "tiny-stories-1M-resid-pre-pretokenized", + "tiny-stories-1M-resid-pre-pretokenized-norm-rescale", + ], +) +def cfg(request: pytest.FixtureRequest): + """ + Pytest fixture to create a mock instance of LanguageModelSAERunnerConfig. + """ + params = request.param + return build_multilayer_sae_cfg(**params) + + +@pytest.fixture +def training_crosscoder_sae(cfg: LanguageModelSAERunnerConfig): + """ + Pytest fixture to create a mock instance of SparseAutoencoder. + """ + return TrainingCrosscoderSAE( + TrainingCrosscoderSAEConfig.from_sae_runner_config(cfg), use_error_term=True + ) + + +@pytest.fixture +def activation_store(model: HookedTransformer, cfg: LanguageModelSAERunnerConfig): + return ActivationsStore.from_config( + model, cfg, override_dataset=Dataset.from_list([{"text": "hello world"}] * 2000) + ) + + +@pytest.fixture +def model(cfg: LanguageModelSAERunnerConfig): + return HookedTransformer.from_pretrained(cfg.model_name, device="cpu") + + +# todo: remove the need for this fixture +@pytest.fixture +def trainer( + cfg: LanguageModelSAERunnerConfig, + training_crosscoder_sae: TrainingCrosscoderSAE, + model: HookedTransformer, + activation_store: ActivationsStore, +): + return SAETrainer( + model=model, + sae=training_crosscoder_sae, + activation_store=activation_store, + save_checkpoint_fn=lambda *args, **kwargs: None, # noqa: ARG005 + cfg=cfg, + ) + + +def test_sae_forward(training_crosscoder_sae: TrainingCrosscoderSAE): + batch_size = 32 + d_in = training_crosscoder_sae.cfg.d_in + n_layers = len(training_crosscoder_sae.cfg.hook_names) + d_sae = training_crosscoder_sae.cfg.d_sae + + x = torch.randn(batch_size, n_layers, d_in) + train_step_output = training_crosscoder_sae.training_forward_pass( + sae_in=x, + current_l1_coefficient=training_crosscoder_sae.cfg.l1_coefficient, + ) + + assert train_step_output.sae_out.shape == (batch_size, n_layers, d_in) + assert train_step_output.feature_acts.shape == (batch_size, d_sae) + assert ( + pytest.approx(train_step_output.loss.detach(), rel=1e-3) + == ( + train_step_output.losses["mse_loss"] + + train_step_output.losses["l1_loss"] + + train_step_output.losses.get("ghost_grad_loss", 0.0) + ) + .detach() # type: ignore + .cpu() + .numpy() + ) + + expected_mse_loss = ( + (torch.pow((train_step_output.sae_out - x.float()), 2)) + .sum(dim=-1) + .mean() + .detach() + .float() + ) + + assert ( + pytest.approx(train_step_output.losses["mse_loss"].item()) == expected_mse_loss # type: ignore + ) + + expected_l1_loss = ( + ( + train_step_output.feature_acts + * training_crosscoder_sae.W_dec.norm(dim=2).sum(dim=1) + ) + .norm(dim=1, p=1) + .mean() + ) + assert ( + pytest.approx(train_step_output.losses["l1_loss"].item(), rel=1e-3) # type: ignore + == training_crosscoder_sae.cfg.l1_coefficient + * expected_l1_loss.detach().float() + ) + + +def test_sae_forward_with_mse_loss_norm( + training_crosscoder_sae: TrainingCrosscoderSAE, +): + # change the confgi and ensure the mse loss is calculated correctly + training_crosscoder_sae.cfg.mse_loss_normalization = "dense_batch" + training_crosscoder_sae.mse_loss_fn = training_crosscoder_sae._get_mse_loss_fn() + + batch_size = 32 + d_in = training_crosscoder_sae.cfg.d_in + n_layers = len(training_crosscoder_sae.cfg.hook_names) + d_sae = training_crosscoder_sae.cfg.d_sae + + x = torch.randn(batch_size, n_layers, d_in) + train_step_output = training_crosscoder_sae.training_forward_pass( + sae_in=x, + current_l1_coefficient=training_crosscoder_sae.cfg.l1_coefficient, + ) + + assert train_step_output.sae_out.shape == (batch_size, n_layers, d_in) + assert train_step_output.feature_acts.shape == (batch_size, d_sae) + assert "ghost_grad_loss" not in train_step_output.losses + + x_centred = x - x.mean(dim=0, keepdim=True) + expected_mse_loss = ( + ( + torch.nn.functional.mse_loss(train_step_output.sae_out, x, reduction="none") + / (1e-6 + x_centred.norm(dim=-1, keepdim=True)) + ) + .sum(dim=-1) + .mean() + .detach() + .item() + ) + + assert ( + pytest.approx(train_step_output.losses["mse_loss"].item()) == expected_mse_loss # type: ignore + ) + + assert ( + pytest.approx(train_step_output.loss.detach(), rel=1e-3) + == ( + train_step_output.losses["mse_loss"] + + train_step_output.losses["l1_loss"] + + train_step_output.losses.get("ghost_grad_loss", 0.0) + ) + .detach() # type: ignore + .numpy() + ) + + expected_l1_loss = ( + ( + train_step_output.feature_acts + * training_crosscoder_sae.W_dec.norm(dim=2).sum(dim=1) + ) + .norm(dim=1, p=1) + .mean() + ) + assert ( + pytest.approx(train_step_output.losses["l1_loss"].item(), rel=1e-3) # type: ignore + == training_crosscoder_sae.cfg.l1_coefficient + * expected_l1_loss.detach().float() + ) + + +def test_SparseAutoencoder_forward_can_add_noise_to_hidden_pre() -> None: + clean_cfg = build_multilayer_sae_cfg( + d_in=2, + d_sae=4, + noise_scale=0, + hook_layers=[1, 2, 3, 4, 5], + normalize_sae_decoder=False, + scale_sparsity_penalty_by_decoder_norm=True, + ) + noisy_cfg = build_multilayer_sae_cfg( + d_in=2, + d_sae=4, + noise_scale=100, + hook_layers=[1, 2, 3, 4, 5], + normalize_sae_decoder=False, + scale_sparsity_penalty_by_decoder_norm=True, + ) + clean_sae = TrainingCrosscoderSAE( + TrainingCrosscoderSAEConfig.from_sae_runner_config(clean_cfg), + use_error_term=True, + ) + noisy_sae = TrainingCrosscoderSAE( + TrainingCrosscoderSAEConfig.from_sae_runner_config(noisy_cfg), + use_error_term=True, + ) + + input = torch.randn(3, 5, 2) + + clean_output1 = clean_sae.forward(input) + clean_output2 = clean_sae.forward(input) + noisy_output1 = noisy_sae.forward(input) + noisy_output2 = noisy_sae.forward(input) + + # with no noise, the outputs should be identical + assert torch.allclose(clean_output1, clean_output2) + # noisy outputs should be different + assert not torch.allclose(noisy_output1, noisy_output2) + assert not torch.allclose(clean_output1, noisy_output1) diff --git a/tests/training/test_training_crosscoder_sae.py b/tests/training/test_training_crosscoder_sae.py new file mode 100644 index 000000000..a8499a5cd --- /dev/null +++ b/tests/training/test_training_crosscoder_sae.py @@ -0,0 +1,76 @@ +import pytest +import torch + +from sae_lens.training.training_crosscoder_sae import ( + TrainingCrosscoderSAE, + TrainingCrosscoderSAEConfig, +) +from tests.helpers import build_multilayer_sae_cfg + + +def test_TrainingCrosscoderSAE_training_forward_pass_can_scale_sparsity_penalty_by_decoder_norm(): + cfg = build_multilayer_sae_cfg( + d_in=3, + d_sae=5, + hook_layers=[0, 1, 2, 3], + normalize_sae_decoder=False, + scale_sparsity_penalty_by_decoder_norm=True, + ) + training_sae = TrainingCrosscoderSAE( + TrainingCrosscoderSAEConfig.from_sae_runner_config(cfg), + use_error_term=True, + ) + x = torch.randn(32, 4, 3) + train_step_output = training_sae.training_forward_pass( + sae_in=x, + current_l1_coefficient=2.0, + ) + feature_acts = train_step_output.feature_acts + decoder_norms = training_sae.W_dec.norm(dim=-1) + decoder_norm = decoder_norms.sum(dim=-1) + # double-check decoder norm is not all ones, or this test is pointless + assert not torch.allclose(decoder_norm, torch.ones_like(decoder_norm), atol=1e-2) + scaled_feature_acts = feature_acts * decoder_norm + + assert ( + pytest.approx(train_step_output.losses["l1_loss"].detach().item()) # type: ignore + == 2.0 * scaled_feature_acts.norm(p=1, dim=1).mean().detach().item() + ) + + +@pytest.mark.parametrize("architecture", ["standard", "gated", "jumprelu", "topk"]) +def test_TrainingCrosscoderSAE_encode_returns_same_value_as_encode_with_hidden_pre( + architecture: str, +): + if architecture != "standard": + pytest.xfail("TODO(mkbehr): support other architectures") + cfg = build_multilayer_sae_cfg( + architecture=architecture, + hook_layers=[0, 1, 2, 3], + normalize_sae_decoder=False, + scale_sparsity_penalty_by_decoder_norm=True, + ) + sae = TrainingCrosscoderSAE( + TrainingCrosscoderSAEConfig.from_sae_runner_config(cfg), + use_error_term=True, + ) + x = torch.randn(32, len(cfg.hook_names), cfg.d_in) + encode_out = sae.encode(x) + encode_with_hidden_pre_out = sae.encode_with_hidden_pre_fn(x)[0] + assert torch.allclose(encode_out, encode_with_hidden_pre_out) + + +def test_TrainingCrosscoderSAE_heuristic_init(): + cfg = build_multilayer_sae_cfg( + d_in=3, + d_sae=5, + hook_layers=[0, 1, 2, 3], + normalize_sae_decoder=False, + scale_sparsity_penalty_by_decoder_norm=True, + decoder_heuristic_init=True, + decoder_heuristic_init_norm=0.2, + ) + sae = TrainingCrosscoderSAE( + TrainingCrosscoderSAEConfig.from_sae_runner_config(cfg), use_error_term=True + ) + torch.testing.assert_close(sae.W_dec.norm(dim=[1, 2]), torch.full((5,), 0.2))