diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index aa0a6f2ee..0dd39a90e 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -32,16 +32,16 @@ jobs: - name: Cache Huggingface assets uses: actions/cache@v4 with: - key: huggingface-4-${{ runner.os }}-${{ matrix.python-version }}-${{ hashFiles('**/pyproject.toml') }} + key: huggingface-5-${{ runner.os }}-${{ matrix.python-version }}-${{ hashFiles('**/pyproject.toml') }} path: ~/.cache/huggingface restore-keys: | - huggingface-4-${{ runner.os }}-${{ matrix.python-version }}- + huggingface-5-${{ runner.os }}-${{ matrix.python-version }}- - name: Load cached Poetry installation id: cached-poetry uses: actions/cache@v4 with: path: ~/.local # the path depends on the OS - key: poetry-${{ runner.os }}-${{ matrix.python-version }}-3 # increment to reset cache + key: poetry-${{ runner.os }}-${{ matrix.python-version }}-5 # increment to reset cache - name: Install Poetry if: steps.cached-poetry.outputs.cache-hit != 'true' uses: snok/install-poetry@v1 @@ -54,9 +54,9 @@ jobs: uses: actions/cache@v4 with: path: .venv - key: venv-2-${{ runner.os }}-${{ matrix.python-version }}-${{ hashFiles('**/pyproject.toml') }} + key: venv-5-${{ runner.os }}-${{ matrix.python-version }}-${{ hashFiles('**/pyproject.toml') }} restore-keys: | - venv-2-${{ runner.os }}-${{ matrix.python-version }}- + venv-5-${{ runner.os }}-${{ matrix.python-version }}- - name: Install dependencies if: steps.cached-poetry-dependencies.outputs.cache-hit != 'true' run: poetry install --no-interaction diff --git a/benchmark/bench_fwd_perf.py b/benchmark/bench_fwd_perf.py new file mode 100644 index 000000000..e9dbf377e --- /dev/null +++ b/benchmark/bench_fwd_perf.py @@ -0,0 +1,161 @@ +import argparse +import os +from typing import Any, Callable + +import torch +import torch._inductor.config +import triton +from tabulate import tabulate + +from sae_lens.saes.sae import TrainStepInput +from sae_lens.saes.topk_sae import 
TopKTrainingSAE +from tests.helpers import ( + build_topk_sae_training_cfg, +) + +torch._inductor.config.coordinate_descent_tuning = True + +parser = argparse.ArgumentParser(add_help=True) +parser.add_argument("--device", type=str, default="cuda") +parser.add_argument( + "--shape", + type=int, + nargs=3, + default=[1024, 1024, 1024 * 16], + help="Shape of the input tensor (seq_len, d_in, d_sae)", +) +parser.add_argument("--k", type=int, default=100, help="Number of topk elements") +parser.add_argument( + "--dtype", + type=str, + default="float32", + help="Datatype (bfloat16, float16, float32, float64)", +) +parser.add_argument( + "-bd", + "--show-breakdown", + action="store_true", + help="Time individual components of the forward pass.", +) +args = parser.parse_args() + +device = args.device + +os.environ["TOKENIZERS_PARALLELISM"] = "false" + +d_in = args.shape[1] +d_sae = args.shape[2] +k = args.k +seq_len = args.shape[0] + +cfg_sparse = build_topk_sae_training_cfg( + d_in=d_in, + d_sae=d_sae, + k=k, + device=device, + use_sparse_activations=True, + dtype=args.dtype, +) +cfg_dense = build_topk_sae_training_cfg( + d_in=d_in, + d_sae=d_sae, + k=k, + device=device, + use_sparse_activations=False, + dtype=args.dtype, +) + +sae_sparse = TopKTrainingSAE(cfg_sparse) +sae_dense = TopKTrainingSAE(cfg_dense) + +dead_neuron_mask = None # torch.randn(d_sae, device = device) > 0.1 +input_acts = torch.randn(seq_len, d_in, device=device) +input_var = (input_acts - input_acts.mean(0)).pow(2).sum() + +step_input = TrainStepInput( + sae_in=input_acts, + dead_neuron_mask=dead_neuron_mask, + coefficients={}, +) + + +def encode_proj(sae: TopKTrainingSAE, input_acts: torch.Tensor) -> torch.Tensor: + sae_in = sae.process_sae_in(input_acts) + return sae.hook_sae_acts_pre(sae_in @ sae.W_enc + sae.b_enc) + + +def topk_activation(sae: TopKTrainingSAE, hidden_pre: torch.Tensor) -> torch.Tensor: + return sae.activation_fn(hidden_pre) + + +def decode_step(sae: TopKTrainingSAE, feature_acts: 
def benchmark_sae(sae: TopKTrainingSAE) -> dict[str, float]:
    """
    Time the training forward pass of ``sae`` and, optionally, its stages.

    The full forward pass is always timed. When ``--show-breakdown`` is set,
    the encoder projection, TopK activation, decoder and loss computation are
    timed individually too, and ``other`` holds the part of the full forward
    pass that those four stages do not account for.

    Args:
        sae: The SAE to benchmark.

    Returns:
        Timings as returned by ``triton_bench``, keyed by metric name. The keys
        are inserted in the same order as the table headers in the
        ``__main__`` block (stages, then ``full_forward_pass``, then
        ``other``), because that block renders columns in dict order.
    """
    full_forward_pass = triton_bench(lambda: sae.training_forward_pass(step_input))
    if not args.show_breakdown:
        # Only one metric column ("Full Fwd") is displayed in this mode.
        return {"full_forward_pass": full_forward_pass}

    results: dict[str, float] = {}
    results["encode_proj"] = triton_bench(lambda: encode_proj(sae, input_acts))
    # Each stage is timed on the real output of the previous stage.
    hidden_pre = encode_proj(sae, input_acts)
    results["topk_activation"] = triton_bench(
        lambda: topk_activation(sae, hidden_pre)
    )
    feature_acts = topk_activation(sae, hidden_pre)
    results["decode_step"] = triton_bench(lambda: decode_step(sae, feature_acts))
    sae_out = decode_step(sae, feature_acts)
    results["loss_computation"] = triton_bench(
        lambda: loss_computation(sae, sae_out, input_acts)
    )

    # Sum the stages before adding the aggregate entries to the dict.
    stage_total = sum(results.values())
    results["full_forward_pass"] = full_forward_pass
    results["other"] = full_forward_pass - stage_total
    return results
Go grab a coffee!") + results_sparse = benchmark_sae(sae_sparse) + results_dense = benchmark_sae(sae_dense) + speedup = results_dense["full_forward_pass"] / results_sparse["full_forward_pass"] + + # Pretty print results table with metrics as columns + if args.show_breakdown: + headers = [ + "Implementation", + "Encode", + "TopK", + "Decode", + "Loss Calc", + "Full Fwd", + "Other", + ] + else: + headers = [ + "Implementation", + "Full Fwd", + ] + + metric_keys = results_sparse.keys() + + table_data = [ + ["Sparse"] + [f"{results_sparse[key]:.3f}" for key in metric_keys], + ["Dense"] + [f"{results_dense[key]:.3f}" for key in metric_keys], + ] + print("Metric: Latency (ms)") + print(f"Speedup: {speedup:.3f}") + print("\n" + tabulate(table_data, headers=headers, tablefmt="grid")) diff --git a/sae_lens/evals.py b/sae_lens/evals.py index 569c2dfee..12033ef8b 100644 --- a/sae_lens/evals.py +++ b/sae_lens/evals.py @@ -466,6 +466,12 @@ def get_sparsity_and_variance_metrics( sae_out_scaled = sae.decode(sae_feature_activations).to( original_act_scaled.device ) + if sae_feature_activations.is_sparse or sae_feature_activations.is_sparse_csr: + batch_dims = sae_feature_activations.batch_dims # type: ignore + sae_feature_activations = sae_feature_activations.to_dense() + sae_feature_activations = sae_feature_activations.reshape( + batch_dims + (sae_feature_activations.shape[-1],) + ) del cache sae_out = activation_scaler.unscale(sae_out_scaled) diff --git a/sae_lens/pretokenize_runner.py b/sae_lens/pretokenize_runner.py index e50cb1ef9..e3c45c370 100644 --- a/sae_lens/pretokenize_runner.py +++ b/sae_lens/pretokenize_runner.py @@ -1,9 +1,10 @@ import io import json import sys +from collections.abc import Iterator from dataclasses import dataclass from pathlib import Path -from typing import Iterator, Literal, cast +from typing import Literal, cast import torch from datasets import Dataset, DatasetDict, load_dataset diff --git a/sae_lens/saes/sae.py b/sae_lens/saes/sae.py index 
795c01385..72f424428 100644 --- a/sae_lens/saes/sae.py +++ b/sae_lens/saes/sae.py @@ -14,7 +14,6 @@ Generic, Literal, NamedTuple, - Type, TypeVar, ) @@ -534,7 +533,7 @@ def save_model(self, path: str | Path) -> tuple[Path, Path]: @classmethod @deprecated("Use load_from_disk instead") def load_from_pretrained( - cls: Type[T_SAE], + cls: type[T_SAE], path: str | Path, device: str = "cpu", dtype: str | None = None, @@ -543,7 +542,7 @@ def load_from_pretrained( @classmethod def load_from_disk( - cls: Type[T_SAE], + cls: type[T_SAE], path: str | Path, device: str = "cpu", dtype: str | None = None, @@ -564,7 +563,7 @@ def load_from_disk( @classmethod def from_pretrained( - cls: Type[T_SAE], + cls: type[T_SAE], release: str, sae_id: str, device: str = "cpu", @@ -585,7 +584,7 @@ def from_pretrained( @classmethod def from_pretrained_with_cfg_and_sparsity( - cls: Type[T_SAE], + cls: type[T_SAE], release: str, sae_id: str, device: str = "cpu", @@ -684,7 +683,7 @@ def from_pretrained_with_cfg_and_sparsity( return sae, cfg_dict, log_sparsities @classmethod - def from_dict(cls: Type[T_SAE], config_dict: dict[str, Any]) -> T_SAE: + def from_dict(cls: type[T_SAE], config_dict: dict[str, Any]) -> T_SAE: """Create an SAE from a config dictionary.""" sae_cls = cls.get_sae_class_for_architecture(config_dict["architecture"]) sae_config_cls = cls.get_sae_config_class_for_architecture( @@ -694,8 +693,8 @@ def from_dict(cls: Type[T_SAE], config_dict: dict[str, Any]) -> T_SAE: @classmethod def get_sae_class_for_architecture( - cls: Type[T_SAE], architecture: str - ) -> Type[T_SAE]: + cls: type[T_SAE], architecture: str + ) -> type[T_SAE]: """Get the SAE class for a given architecture.""" sae_cls, _ = get_sae_class(architecture) if not issubclass(sae_cls, cls): @@ -1000,8 +999,8 @@ def log_histograms(self) -> dict[str, NDArray[Any]]: @classmethod def get_sae_class_for_architecture( - cls: Type[T_TRAINING_SAE], architecture: str - ) -> Type[T_TRAINING_SAE]: + cls: type[T_TRAINING_SAE], 
class TopK(nn.Module):
    """
    A simple TopK activation that zeroes out all but the top K elements along the last
    dimension, and applies ReLU to the top K elements.

    Args:
        k: Number of elements to keep along the last dimension.
        use_sparse_activations: If True, ``forward`` returns a 2D sparse CSR
            tensor of shape ``(prod(batch dims), x.shape[-1])`` instead of a
            dense tensor shaped like the input.
    """

    use_sparse_activations: bool

    def __init__(
        self,
        k: int,
        use_sparse_activations: bool = False,
    ):
        super().__init__()
        self.k = k
        self.use_sparse_activations = use_sparse_activations

    def forward(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor:
        """
        1) Select top K elements along the last dimension.
        2) Apply ReLU.
        3) Zero out all other entries.

        In sparse mode the leading (batch) dimensions are flattened into the
        rows of a CSR tensor and the original batch shape is attached to the
        result as a ``batch_dims`` attribute.
        """
        topk_values, topk_indices = torch.topk(x, k=self.k, dim=-1, sorted=False)
        values = topk_values.relu()
        if self.use_sparse_activations:
            # Produce a CSR sparse tensor (use sparse matrix multiply in decode)
            batch_dims = x.shape[:-1]
            # torch.Size.numel() is the product of the dims (1 for an empty shape)
            n_rows = batch_dims.numel()
            # Every row stores exactly k entries, so row pointers are 0, k, 2k, ...
            crow_indices = torch.arange(
                start=0,
                end=n_rows * self.k + self.k,
                step=self.k,
                device=x.device,
            )
            csr_tensor = torch.sparse_csr_tensor(
                crow_indices,
                topk_indices.flatten(),
                # Store the ReLU'd values so the sparse result matches the dense one.
                values.flatten(),
                dtype=x.dtype,
                device=x.device,
                size=(n_rows, x.shape[-1]),
            )
            # HACK: CSR tensors here are 2D, so the original batch shape is carried
            # as an ad-hoc attribute for consumers that need to restore it.
            csr_tensor.batch_dims = batch_dims  # type: ignore
            return csr_tensor

        result = torch.zeros_like(x)
        result.scatter_(-1, topk_indices, values)
        return result
d_sae"], ) -> Float[torch.Tensor, "... d_in"]: """ Reconstructs the input from topk feature activations. Applies optional finetuning scaling, hooking to recons, out normalization, and optional head reshaping. """ - sae_out_pre = feature_acts @ self.W_dec + self.b_dec + # Handle sparse tensors using efficient sparse matrix multiplication + if feature_acts.is_sparse: + sae_out_pre = _sparse_matmul_nd(feature_acts, self.W_dec) + self.b_dec + else: + sae_out_pre = feature_acts @ self.W_dec + self.b_dec sae_out_pre = self.hook_sae_recons(sae_out_pre) sae_out_pre = self.run_time_activation_norm_fn_out(sae_out_pre) return self.reshape_fn_out(sae_out_pre, self.d_head) @override def get_activation_fn(self) -> Callable[[torch.Tensor], torch.Tensor]: - return TopK(self.cfg.k) + return TopK(self.cfg.k, use_sparse_activations=False) @override @torch.no_grad() @@ -124,9 +197,43 @@ def fold_W_dec_norm(self) -> None: class TopKTrainingSAEConfig(TrainingSAEConfig): """ Configuration class for training a TopKTrainingSAE. + + Args: + k (int): Number of top features to keep active. Only the top k features + with the highest pre-activations will be non-zero. Defaults to 100. + use_sparse_activations (bool): Whether to use sparse tensor representations + for activations during training. This can reduce memory usage and improve + performance when k is small relative to d_sae, but is only worthwhile if + using float32 and not using autocast. Defaults to False. + aux_loss_coefficient (float): Coefficient for the auxiliary loss that encourages + dead neurons to learn useful features. This loss helps prevent neuron death + in TopK SAEs by having dead neurons reconstruct the residual error from + live neurons. Defaults to 1.0. + decoder_init_norm (float | None): Norm to initialize decoder weights to. + 0.1 corresponds to the "heuristic" initialization from Anthropic's April update. + Use None to disable. Inherited from TrainingSAEConfig. Defaults to 0.1. 
+ d_in (int): Input dimension (dimensionality of the activations being encoded). + Inherited from SAEConfig. + d_sae (int): SAE latent dimension (number of features in the SAE). + Inherited from SAEConfig. + dtype (str): Data type for the SAE parameters. Inherited from SAEConfig. + Defaults to "float32". + device (str): Device to place the SAE on. Inherited from SAEConfig. + Defaults to "cpu". + apply_b_dec_to_input (bool): Whether to apply decoder bias to the input + before encoding. Inherited from SAEConfig. Defaults to True. + normalize_activations (Literal["none", "expected_average_only_in", "constant_norm_rescale", "layer_norm"]): + Normalization strategy for input activations. Inherited from SAEConfig. + Defaults to "none". + reshape_activations (Literal["none", "hook_z"]): How to reshape activations + (useful for attention head outputs). Inherited from SAEConfig. + Defaults to "none". + metadata (SAEMetadata): Metadata about the SAE training (model name, hook name, etc.). + Inherited from SAEConfig. """ k: int = 100 + use_sparse_activations: bool = False aux_loss_coefficient: float = 1.0 @override @@ -144,6 +251,8 @@ class TopKTrainingSAE(TrainingSAE[TopKTrainingSAEConfig]): def __init__(self, cfg: TopKTrainingSAEConfig, use_error_term: bool = False): super().__init__(cfg, use_error_term) + self.hook_sae_acts_post = SparseHookPoint(self.cfg.d_sae) + self.setup() @override def initialize_weights(self) -> None: @@ -163,6 +272,41 @@ def encode_with_hidden_pre( feature_acts = self.hook_sae_acts_post(self.activation_fn(hidden_pre)) return feature_acts, hidden_pre + @override + def decode( + self, + feature_acts: Float[torch.Tensor, "... d_sae"], + ) -> Float[torch.Tensor, "... d_in"]: + """ + Decodes feature activations back into input space, + applying optional finetuning scale, hooking, out normalization, etc. 
+ """ + # Handle sparse tensors using efficient sparse matrix multiplication + if feature_acts.is_sparse or feature_acts.is_sparse_csr: + sae_out_pre = _sparse_matmul_nd(feature_acts, self.W_dec) + self.b_dec + else: + sae_out_pre = feature_acts @ self.W_dec + self.b_dec + sae_out_pre = self.hook_sae_recons(sae_out_pre) + sae_out_pre = self.run_time_activation_norm_fn_out(sae_out_pre) + return self.reshape_fn_out(sae_out_pre, self.d_head) + + @override + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass through the SAE.""" + feature_acts = self.encode(x) + sae_out = self.decode(feature_acts) + + if self.use_error_term: + with torch.no_grad(): + # Recompute without hooks for true error term + with _disable_hooks(self): + feature_acts_clean = self.encode(x) + x_reconstruct_clean = self.decode(feature_acts_clean) + sae_error = self.hook_sae_error(x - x_reconstruct_clean) + sae_out = sae_out + sae_error + + return self.hook_sae_output(sae_out) + @override def calculate_aux_loss( self, @@ -189,7 +333,7 @@ def fold_W_dec_norm(self) -> None: @override def get_activation_fn(self) -> Callable[[torch.Tensor], torch.Tensor]: - return TopK(self.cfg.k) + return TopK(self.cfg.k, use_sparse_activations=self.cfg.use_sparse_activations) @override def get_coefficients(self) -> dict[str, TrainCoefficientConfig | float]: diff --git a/sae_lens/tokenization_and_batching.py b/sae_lens/tokenization_and_batching.py index d2d5de201..f1aedacae 100644 --- a/sae_lens/tokenization_and_batching.py +++ b/sae_lens/tokenization_and_batching.py @@ -1,4 +1,4 @@ -from typing import Generator, Iterator +from collections.abc import Generator, Iterator import torch diff --git a/sae_lens/training/sae_trainer.py b/sae_lens/training/sae_trainer.py index 1e0e68dc5..79e1b83ea 100644 --- a/sae_lens/training/sae_trainer.py +++ b/sae_lens/training/sae_trainer.py @@ -253,12 +253,14 @@ def _train_step( ) with torch.no_grad(): - did_fire = (train_step_output.feature_acts > 0).float().sum(-2) > 0 
def _update_sae_metadata(runner_cfg: LanguageModelSAERunnerConfig[Any]):
    """
    Copy the runner-level hook/model/dataset settings onto the SAE metadata.

    Mutates ``runner_cfg.sae.metadata`` in place and returns nothing.
    """
    metadata = runner_cfg.sae.metadata
    mirrored_fields = (
        "hook_name",
        "hook_head_index",
        "model_name",
        "model_class_name",
        "dataset_path",
        "prepend_bos",
    )
    for field_name in mirrored_fields:
        setattr(metadata, field_name, getattr(runner_cfg, field_name))
_build_runner_config( TopKTrainingSAEConfig, final_default_sae_config, **kwargs, ) + _update_sae_metadata(runner_cfg) + return runner_cfg def build_topk_sae_cfg(**kwargs: Any) -> TopKSAEConfig: @@ -430,11 +442,13 @@ def build_batchtopk_runner_cfg( # Update the default config *before* passing it to _build_runner_config final_default_sae_config = cast(dict[str, Any], temp_sae_config) - return _build_runner_config( + runner_cfg = _build_runner_config( BatchTopKTrainingSAEConfig, final_default_sae_config, **kwargs, ) + _update_sae_metadata(runner_cfg) + return runner_cfg def build_batchtopk_sae_training_cfg(**kwargs: Any) -> BatchTopKTrainingSAEConfig: diff --git a/tests/refactor_compatibility/test_topk_sae_equivalence.py b/tests/refactor_compatibility/test_topk_sae_equivalence.py index e7b5f7700..a5f5896ba 100644 --- a/tests/refactor_compatibility/test_topk_sae_equivalence.py +++ b/tests/refactor_compatibility/test_topk_sae_equivalence.py @@ -313,7 +313,7 @@ def test_topk_sae_training_equivalence(): ) assert_close( old_out.feature_acts, - new_out.feature_acts, + new_out.feature_acts.to_dense(), atol=1e-5, msg="Training feature_acts differ between old and new implementations.", ) diff --git a/tests/saes/test_standard_sae.py b/tests/saes/test_standard_sae.py index 95d397374..8068193ce 100644 --- a/tests/saes/test_standard_sae.py +++ b/tests/saes/test_standard_sae.py @@ -436,9 +436,13 @@ def test_sae_fold_norm_scaling_factor_all_architectures(architecture: str): unscaled_activations = activations / norm_scaling_factor feature_activations_1 = sae.encode(activations) + if feature_activations_1.is_sparse: + feature_activations_1 = feature_activations_1.to_dense() # with the scaling folded in, the unscaled activations should produce the same # result. 
feature_activations_2 = sae2.encode(unscaled_activations) + if feature_activations_2.is_sparse: + feature_activations_2 = feature_activations_2.to_dense() assert_close( feature_activations_1.nonzero(), diff --git a/tests/saes/test_topk_sae.py b/tests/saes/test_topk_sae.py index 42420b1a8..6f0d66f8d 100644 --- a/tests/saes/test_topk_sae.py +++ b/tests/saes/test_topk_sae.py @@ -1,12 +1,13 @@ import os from pathlib import Path +import numpy as np import pytest import torch from sparsify import SparseCoder, SparseCoderConfig from sae_lens.saes.sae import SAE, TrainStepInput -from sae_lens.saes.topk_sae import TopKSAE, TopKTrainingSAE +from sae_lens.saes.topk_sae import TopK, TopKSAE, TopKTrainingSAE from tests.helpers import ( assert_close, build_topk_sae_cfg, @@ -17,11 +18,11 @@ def test_TopKTrainingSAE_topk_aux_loss_matches_unnormalized_sparsify_implementation(): d_in = 128 d_sae = 192 + k = 26 cfg = build_topk_sae_training_cfg( d_in=d_in, d_sae=d_sae, - k=26, - decoder_init_norm=1.0, # TODO: why is this needed?? 
@pytest.mark.parametrize("num_dims", [1, 2, 3, 4, 5])
def test_topK_sparse_activations(num_dims: int):
    # Validate that the sparse top-K intermediate output (CSR format)
    # we use to accelerate the decoder matches the dense top-K output.
    # Input shapes have `num_dims` dimensions; the last one is the feature dim.
    dims = (np.arange(1, num_dims + 1) + 3).tolist()
    dims[-1] = 1024
    for k in [1, 10, 100, 1000]:
        topk_sparse = TopK(k, use_sparse_activations=True)
        topk_dense = TopK(k, use_sparse_activations=False)
        # The +50 offset keeps the inputs (practically) always positive,
        # so ReLU is effectively a no-op in this comparison.
        x = torch.randn(*dims) + 50.0
        sparse_x = topk_sparse(x)
        assert sparse_x.is_sparse or sparse_x.is_sparse_csr
        # The sparse output is 2D; restore the input's batch shape before comparing.
        sparse_x = sparse_x.to_dense()
        sparse_x = sparse_x.reshape(x.shape)
        dense_x = topk_dense(x)
        assert_close(dense_x, sparse_x)
+ sae.b_enc.data = sae.b_enc + 100.0 + + for k in [1, 10, 100, 1000]: + topk_sparse = TopK(k, use_sparse_activations=True) + topk_dense = TopK(k, use_sparse_activations=False) + x = torch.randn(*dims) + 50.0 + sparse_x = topk_sparse(x) + sae_out_sparse = sae.decode(sparse_x) + dense_x = topk_dense(x) + sae_out_dense = sae.decode(dense_x) + assert_close(sae_out_sparse, sae_out_dense, rtol=1e-4, atol=5e-4) + + +def test_TopKTrainingSAE_sparse_activations_config(): + # Check that our config is respected in both training & inference SAEs + cfg = build_topk_sae_training_cfg(k=100, use_sparse_activations=True) + sae = TopKTrainingSAE(cfg) + assert sae.activation_fn.use_sparse_activations # type: ignore + assert sae.cfg.use_sparse_activations + + cfg = build_topk_sae_training_cfg(k=100, use_sparse_activations=False) + sae = TopKTrainingSAE(cfg) + assert not sae.activation_fn.use_sparse_activations # type: ignore + assert not sae.cfg.use_sparse_activations diff --git a/tests/test_evals.py b/tests/test_evals.py index 9081e01a6..34eb21064 100644 --- a/tests/test_evals.py +++ b/tests/test_evals.py @@ -29,12 +29,14 @@ from sae_lens.loading.pretrained_saes_directory import PretrainedSAELookup from sae_lens.saes.sae import SAE, TrainingSAE from sae_lens.saes.standard_sae import StandardSAE, StandardTrainingSAE +from sae_lens.saes.topk_sae import TopKTrainingSAE from sae_lens.training.activation_scaler import ActivationScaler from sae_lens.training.activations_store import ActivationsStore from tests.helpers import ( NEEL_NANDA_C4_10K_DATASET, TINYSTORIES_MODEL, build_runner_cfg, + build_topk_runner_cfg, load_model_cached, ) @@ -170,6 +172,34 @@ def test_run_evals_base_sae( assert len(eval_metrics) > 0 +@pytest.mark.parametrize("use_sparse_activations", [True, False]) +def test_run_evals_sparse_topk_sae( + model: HookedTransformer, + use_sparse_activations: bool, +): + cfg = build_topk_runner_cfg( + use_sparse_activations=use_sparse_activations, + model_name="tiny-stories-1M", + 
dataset_path="roneneldan/TinyStories", + hook_name="blocks.1.hook_resid_pre", + d_in=64, + ) + sae = TopKTrainingSAE(cfg.sae) + activation_store = ActivationsStore.from_config( + model, cfg, override_dataset=Dataset.from_list([{"text": "hello world"}] * 2000) + ) + eval_metrics, _ = run_evals( + sae=sae, + activation_store=activation_store, + activation_scaler=ActivationScaler(), + model=model, + eval_config=get_eval_everything_config(), + ) + + assert set(eval_metrics.keys()).issubset(set(all_possible_keys)) + assert len(eval_metrics) > 0 + + def test_run_evals_training_sae( training_sae: TrainingSAE[Any], activation_store: ActivationsStore, diff --git a/tests/test_util.py b/tests/test_util.py index 99531c96c..72c0f008b 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -2,7 +2,10 @@ import pytest -from sae_lens.util import extract_stop_at_layer_from_tlens_hook_name, path_or_tmp_dir +from sae_lens.util import ( + extract_stop_at_layer_from_tlens_hook_name, + path_or_tmp_dir, +) @pytest.mark.parametrize( diff --git a/tests/training/test_sae_trainer.py b/tests/training/test_sae_trainer.py index c1b2fa774..7583335b8 100644 --- a/tests/training/test_sae_trainer.py +++ b/tests/training/test_sae_trainer.py @@ -179,8 +179,10 @@ def test_log_feature_sparsity__handles_zeroes_by_default_fp16() -> None: assert _log_feature_sparsity(fp16_zeroes).item() != float("-inf") +@pytest.mark.parametrize("sparse_feature_acts", [True, False]) def test_build_train_step_log_dict( trainer: SAETrainer[StandardTrainingSAE, StandardTrainingSAEConfig], + sparse_feature_acts: bool, ) -> None: train_output = TrainStepOutput( sae_in=torch.tensor([[-1, 0], [0, 2], [1, 1]]).float(), @@ -196,6 +198,8 @@ def test_build_train_step_log_dict( "topk_threshold": torch.tensor(0.5), }, ) + if sparse_feature_acts: + train_output.feature_acts = train_output.feature_acts.to_sparse_coo() # we're relying on the trainer only for some of the metrics here # we should more / less try to break this and push