From 4a417c3b44eb7929cee6aea062e1a04cd6fa2d09 Mon Sep 17 00:00:00 2001 From: wz-ml Date: Tue, 16 Sep 2025 16:59:54 -0700 Subject: [PATCH 01/19] Add kernel skeleton and remove redundant b_enc param from TopK class --- sae_lens/saes/kernels/fused_gemm_topk.py | 29 ++++++++++++++++++++++++ sae_lens/saes/topk_sae.py | 8 +++---- 2 files changed, 32 insertions(+), 5 deletions(-) create mode 100644 sae_lens/saes/kernels/fused_gemm_topk.py diff --git a/sae_lens/saes/kernels/fused_gemm_topk.py b/sae_lens/saes/kernels/fused_gemm_topk.py new file mode 100644 index 000000000..0c75ad6ef --- /dev/null +++ b/sae_lens/saes/kernels/fused_gemm_topk.py @@ -0,0 +1,29 @@ +import torch +import triton + + +def fused_gemm_topk( + x: torch.Tensor, + W: torch.Tensor, + b: torch.Tensor, + k: int, +): + """ + Mathematically, equates to: + y = x @ W.T + b + Before setting all but the topK elements of y to 0, and returning y. + + Params: + x: (M, d_hidden) + W: (d_sae, d_hidden) + b: (d_sae) + k: int + + Returns: + y: (M, d_sae) + """ + + +@triton.jit +def _fused_gemm_topk_kernel(): + pass diff --git a/sae_lens/saes/topk_sae.py b/sae_lens/saes/topk_sae.py index 0e721a400..40b06046d 100644 --- a/sae_lens/saes/topk_sae.py +++ b/sae_lens/saes/topk_sae.py @@ -24,8 +24,6 @@ class TopK(nn.Module): and applies ReLU to the top K elements. """ - b_enc: nn.Parameter - def __init__( self, k: int, @@ -39,10 +37,10 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: 2) Apply ReLU. 3) Zero out all other entries. 
""" - topk = torch.topk(x, k=self.k, dim=-1) - values = topk.values.relu() + topk_values, topk_indices = torch.topk(x, k=self.k, dim=-1) + values = topk_values.relu() result = torch.zeros_like(x) - result.scatter_(-1, topk.indices, values) + result.scatter_(-1, topk_indices, values) return result From e6fac7dfb8b225479c74866c5e0e39659b192088 Mon Sep 17 00:00:00 2001 From: wz-ml Date: Wed, 17 Sep 2025 12:34:03 -0700 Subject: [PATCH 02/19] Add option in topK to save SAE activations as a sparse tensor --- sae_lens/saes/topk_sae.py | 19 ++++++++++++-- tests/saes/test_topk_sae.py | 50 ++++++++++++++++++++++++++++++++++++- 2 files changed, 66 insertions(+), 3 deletions(-) diff --git a/sae_lens/saes/topk_sae.py b/sae_lens/saes/topk_sae.py index 40b06046d..bcca84854 100644 --- a/sae_lens/saes/topk_sae.py +++ b/sae_lens/saes/topk_sae.py @@ -31,13 +31,27 @@ def __init__( super().__init__() self.k = k - def forward(self, x: torch.Tensor) -> torch.Tensor: + def forward( + self, x: torch.Tensor, sparse_intermediate: bool = False + ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: """ 1) Select top K elements along the last dimension. 2) Apply ReLU. 3) Zero out all other entries. """ topk_values, topk_indices = torch.topk(x, k=self.k, dim=-1) + if sparse_intermediate: + # Produce a COO sparse tensor (use sparse matrix multiply in decode) + M, N = x.shape + sparse_indices = torch.stack( + [ + torch.arange(M, device=x.device).repeat_interleave(self.k), + topk_indices.flatten(), + ] + ) + return torch.sparse_coo_tensor( + sparse_indices, topk_values.flatten(), (M, N) + ) values = topk_values.relu() result = torch.zeros_like(x) result.scatter_(-1, topk_indices, values) @@ -94,7 +108,8 @@ def encode( return self.hook_sae_acts_post(self.activation_fn(hidden_pre)) def decode( - self, feature_acts: Float[torch.Tensor, "... d_sae"] + self, + feature_acts: Float[torch.Tensor, "... d_sae"], ) -> Float[torch.Tensor, "... d_in"]: """ Reconstructs the input from topk feature activations. 
diff --git a/tests/saes/test_topk_sae.py b/tests/saes/test_topk_sae.py index 42420b1a8..f7ceeb22e 100644 --- a/tests/saes/test_topk_sae.py +++ b/tests/saes/test_topk_sae.py @@ -6,7 +6,7 @@ from sparsify import SparseCoder, SparseCoderConfig from sae_lens.saes.sae import SAE, TrainStepInput -from sae_lens.saes.topk_sae import TopKSAE, TopKTrainingSAE +from sae_lens.saes.topk_sae import TopK, TopKSAE, TopKTrainingSAE from tests.helpers import ( assert_close, build_topk_sae_cfg, @@ -146,3 +146,51 @@ def test_TopKTrainingSAE_save_and_load_inference_sae(tmp_path: Path) -> None: training_full_out = training_sae(sae_in) inference_full_out = inference_sae(sae_in) assert_close(training_full_out, inference_full_out) + + +def test_topK_activation_sparse_intermediate(): + # Validate that the sparse top-K intermediate output (COO format) + # we use to accelerate the decoder matches the dense top-K output. + d_sae = 1024 + M = 128 + for k in [1, 10, 100, 1000]: + topk = TopK(k) + x = torch.randn(M, d_sae) + 50.0 + sparse_x = topk(x, sparse_intermediate=True) + assert sparse_x.is_sparse + assert sparse_x.shape == (M, d_sae) + assert sparse_x.coalesce().values().numel() == k * M + dense_x = topk(x, sparse_intermediate=False) + assert_close(dense_x, sparse_x.to_dense()) + + +def test_topK_activation_sparse_mm(): + # Validate that our decoder produces the same output when using the sparse intermediates + # as when using the dense intermediates. + d_in = 128 + d_sae = 1024 + M = 128 + + cfg = build_topk_sae_training_cfg( + d_in=d_in, + d_sae=d_sae, + k=26, + decoder_init_norm=1.0, # TODO: why is this needed?? + ) + + sae = TopKTrainingSAE(cfg) + + with torch.no_grad(): + # increase b_enc so all features are likely above 0 + # sparsify includes a relu() in their pre_acts, but + # this is not something we need to try to replicate. 
+ sae.b_enc.data = sae.b_enc + 100.0 + + for k in [1, 10, 100, 1000]: + topk = TopK(k) + x = torch.randn(M, d_sae) + 50.0 + sparse_x = topk(x, sparse_intermediate=True) + sae_out_sparse = sae.decode(sparse_x) + dense_x = topk(x, sparse_intermediate=False) + sae_out_dense = sae.decode(dense_x) + assert_close(sae_out_sparse, sae_out_dense, rtol=1e-4, atol=5e-4) From 9a8b9fc0cd71c7392f308c82b0d5fb5273c0c83f Mon Sep 17 00:00:00 2001 From: wz-ml Date: Wed, 17 Sep 2025 12:47:06 -0700 Subject: [PATCH 03/19] Add sparse activation config flag & update tests --- sae_lens/saes/topk_sae.py | 15 +++++++++++---- tests/helpers.py | 4 +++- tests/saes/test_topk_sae.py | 36 ++++++++++++++++++++++++++++++------ 3 files changed, 44 insertions(+), 11 deletions(-) diff --git a/sae_lens/saes/topk_sae.py b/sae_lens/saes/topk_sae.py index bcca84854..a263db28e 100644 --- a/sae_lens/saes/topk_sae.py +++ b/sae_lens/saes/topk_sae.py @@ -24,15 +24,20 @@ class TopK(nn.Module): and applies ReLU to the top K elements. """ + sparse_intermediate: bool + def __init__( self, k: int, + sparse_intermediate: bool = True, ): super().__init__() self.k = k + self.sparse_intermediate = sparse_intermediate def forward( - self, x: torch.Tensor, sparse_intermediate: bool = False + self, + x: torch.Tensor, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: """ 1) Select top K elements along the last dimension. @@ -40,7 +45,7 @@ def forward( 3) Zero out all other entries. 
""" topk_values, topk_indices = torch.topk(x, k=self.k, dim=-1) - if sparse_intermediate: + if self.sparse_intermediate: # Produce a COO sparse tensor (use sparse matrix multiply in decode) M, N = x.shape sparse_indices = torch.stack( @@ -65,6 +70,7 @@ class TopKSAEConfig(SAEConfig): """ k: int = 100 + sparse_intermediate: bool = True @override @classmethod @@ -123,7 +129,7 @@ def decode( @override def get_activation_fn(self) -> Callable[[torch.Tensor], torch.Tensor]: - return TopK(self.cfg.k) + return TopK(self.cfg.k, sparse_intermediate=self.cfg.sparse_intermediate) @override @torch.no_grad() @@ -140,6 +146,7 @@ class TopKTrainingSAEConfig(TrainingSAEConfig): """ k: int = 100 + sparse_intermediate: bool = True aux_loss_coefficient: float = 1.0 @override @@ -202,7 +209,7 @@ def fold_W_dec_norm(self) -> None: @override def get_activation_fn(self) -> Callable[[torch.Tensor], torch.Tensor]: - return TopK(self.cfg.k) + return TopK(self.cfg.k, sparse_intermediate=self.cfg.sparse_intermediate) @override def get_coefficients(self) -> dict[str, TrainCoefficientConfig | float]: diff --git a/tests/helpers.py b/tests/helpers.py index 852e9fcac..5dee87a3e 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -1,5 +1,6 @@ import copy -from typing import Any, Literal, Sequence, TypedDict, cast +from collections.abc import Sequence +from typing import Any, Literal, TypedDict, cast import pytest import torch @@ -97,6 +98,7 @@ class TrainingSAEConfigDict(TypedDict, total=False): jumprelu_init_threshold: float jumprelu_bandwidth: float k: int # For TopK + sparse_intermediate: bool # For TopK l0_coefficient: float # For JumpReLU l0_warm_up_steps: int pre_act_loss_coefficient: float | None # For JumpReLU diff --git a/tests/saes/test_topk_sae.py b/tests/saes/test_topk_sae.py index f7ceeb22e..58458f338 100644 --- a/tests/saes/test_topk_sae.py +++ b/tests/saes/test_topk_sae.py @@ -154,13 +154,14 @@ def test_topK_activation_sparse_intermediate(): d_sae = 1024 M = 128 for k in [1, 10, 
100, 1000]: - topk = TopK(k) + topk_sparse = TopK(k, sparse_intermediate=True) + topk_dense = TopK(k, sparse_intermediate=False) x = torch.randn(M, d_sae) + 50.0 - sparse_x = topk(x, sparse_intermediate=True) + sparse_x = topk_sparse(x) assert sparse_x.is_sparse assert sparse_x.shape == (M, d_sae) assert sparse_x.coalesce().values().numel() == k * M - dense_x = topk(x, sparse_intermediate=False) + dense_x = topk_dense(x) assert_close(dense_x, sparse_x.to_dense()) @@ -187,10 +188,33 @@ def test_topK_activation_sparse_mm(): sae.b_enc.data = sae.b_enc + 100.0 for k in [1, 10, 100, 1000]: - topk = TopK(k) + topk_sparse = TopK(k, sparse_intermediate=True) + topk_dense = TopK(k, sparse_intermediate=False) x = torch.randn(M, d_sae) + 50.0 - sparse_x = topk(x, sparse_intermediate=True) + sparse_x = topk_sparse(x) sae_out_sparse = sae.decode(sparse_x) - dense_x = topk(x, sparse_intermediate=False) + dense_x = topk_dense(x) sae_out_dense = sae.decode(dense_x) assert_close(sae_out_sparse, sae_out_dense, rtol=1e-4, atol=5e-4) + + +def test_topK_activation_sparse_config(): + cfg = build_topk_sae_cfg(k=100, sparse_intermediate=True) + sae = TopKSAE(cfg) + assert sae.activation_fn.sparse_intermediate + assert sae.cfg.sparse_intermediate + + cfg = build_topk_sae_cfg(k=100, sparse_intermediate=False) + sae = TopKSAE(cfg) + assert not sae.activation_fn.sparse_intermediate + assert not sae.cfg.sparse_intermediate + + cfg = build_topk_sae_training_cfg(k=100, sparse_intermediate=True) + sae = TopKTrainingSAE(cfg) + assert sae.activation_fn.sparse_intermediate + assert sae.cfg.sparse_intermediate + + cfg = build_topk_sae_training_cfg(k=100, sparse_intermediate=False) + sae = TopKTrainingSAE(cfg) + assert not sae.activation_fn.sparse_intermediate + assert not sae.cfg.sparse_intermediate From f7764034ec47488170e71fe9ceae1b1a521ad860 Mon Sep 17 00:00:00 2001 From: wz-ml Date: Thu, 18 Sep 2025 13:27:39 -0700 Subject: [PATCH 04/19] Linting changes --- sae_lens/pretokenize_runner.py | 3 ++- 
sae_lens/saes/kernels/fused_gemm_topk.py | 29 ------------------------ sae_lens/saes/sae.py | 19 ++++++++-------- sae_lens/tokenization_and_batching.py | 2 +- sae_lens/training/types.py | 2 +- tests/saes/test_standard_sae.py | 4 ++++ tests/saes/test_topk_sae.py | 16 +++++++------ 7 files changed, 26 insertions(+), 49 deletions(-) delete mode 100644 sae_lens/saes/kernels/fused_gemm_topk.py diff --git a/sae_lens/pretokenize_runner.py b/sae_lens/pretokenize_runner.py index e50cb1ef9..e3c45c370 100644 --- a/sae_lens/pretokenize_runner.py +++ b/sae_lens/pretokenize_runner.py @@ -1,9 +1,10 @@ import io import json import sys +from collections.abc import Iterator from dataclasses import dataclass from pathlib import Path -from typing import Iterator, Literal, cast +from typing import Literal, cast import torch from datasets import Dataset, DatasetDict, load_dataset diff --git a/sae_lens/saes/kernels/fused_gemm_topk.py b/sae_lens/saes/kernels/fused_gemm_topk.py deleted file mode 100644 index 0c75ad6ef..000000000 --- a/sae_lens/saes/kernels/fused_gemm_topk.py +++ /dev/null @@ -1,29 +0,0 @@ -import torch -import triton - - -def fused_gemm_topk( - x: torch.Tensor, - W: torch.Tensor, - b: torch.Tensor, - k: int, -): - """ - Mathematically, equates to: - y = x @ W.T + b - Before setting all but the topK elements of y to 0, and returning y. 
- - Params: - x: (M, d_hidden) - W: (d_sae, d_hidden) - b: (d_sae) - k: int - - Returns: - y: (M, d_sae) - """ - - -@triton.jit -def _fused_gemm_topk_kernel(): - pass diff --git a/sae_lens/saes/sae.py b/sae_lens/saes/sae.py index 795c01385..72f424428 100644 --- a/sae_lens/saes/sae.py +++ b/sae_lens/saes/sae.py @@ -14,7 +14,6 @@ Generic, Literal, NamedTuple, - Type, TypeVar, ) @@ -534,7 +533,7 @@ def save_model(self, path: str | Path) -> tuple[Path, Path]: @classmethod @deprecated("Use load_from_disk instead") def load_from_pretrained( - cls: Type[T_SAE], + cls: type[T_SAE], path: str | Path, device: str = "cpu", dtype: str | None = None, @@ -543,7 +542,7 @@ def load_from_pretrained( @classmethod def load_from_disk( - cls: Type[T_SAE], + cls: type[T_SAE], path: str | Path, device: str = "cpu", dtype: str | None = None, @@ -564,7 +563,7 @@ def load_from_disk( @classmethod def from_pretrained( - cls: Type[T_SAE], + cls: type[T_SAE], release: str, sae_id: str, device: str = "cpu", @@ -585,7 +584,7 @@ def from_pretrained( @classmethod def from_pretrained_with_cfg_and_sparsity( - cls: Type[T_SAE], + cls: type[T_SAE], release: str, sae_id: str, device: str = "cpu", @@ -684,7 +683,7 @@ def from_pretrained_with_cfg_and_sparsity( return sae, cfg_dict, log_sparsities @classmethod - def from_dict(cls: Type[T_SAE], config_dict: dict[str, Any]) -> T_SAE: + def from_dict(cls: type[T_SAE], config_dict: dict[str, Any]) -> T_SAE: """Create an SAE from a config dictionary.""" sae_cls = cls.get_sae_class_for_architecture(config_dict["architecture"]) sae_config_cls = cls.get_sae_config_class_for_architecture( @@ -694,8 +693,8 @@ def from_dict(cls: Type[T_SAE], config_dict: dict[str, Any]) -> T_SAE: @classmethod def get_sae_class_for_architecture( - cls: Type[T_SAE], architecture: str - ) -> Type[T_SAE]: + cls: type[T_SAE], architecture: str + ) -> type[T_SAE]: """Get the SAE class for a given architecture.""" sae_cls, _ = get_sae_class(architecture) if not issubclass(sae_cls, cls): @@ 
-1000,8 +999,8 @@ def log_histograms(self) -> dict[str, NDArray[Any]]: @classmethod def get_sae_class_for_architecture( - cls: Type[T_TRAINING_SAE], architecture: str - ) -> Type[T_TRAINING_SAE]: + cls: type[T_TRAINING_SAE], architecture: str + ) -> type[T_TRAINING_SAE]: """Get the SAE class for a given architecture.""" sae_cls, _ = get_sae_training_class(architecture) if not issubclass(sae_cls, cls): diff --git a/sae_lens/tokenization_and_batching.py b/sae_lens/tokenization_and_batching.py index d2d5de201..f1aedacae 100644 --- a/sae_lens/tokenization_and_batching.py +++ b/sae_lens/tokenization_and_batching.py @@ -1,4 +1,4 @@ -from typing import Generator, Iterator +from collections.abc import Generator, Iterator import torch diff --git a/sae_lens/training/types.py b/sae_lens/training/types.py index d42a6e546..23796590e 100644 --- a/sae_lens/training/types.py +++ b/sae_lens/training/types.py @@ -1,4 +1,4 @@ -from typing import Iterator +from collections.abc import Iterator import torch diff --git a/tests/saes/test_standard_sae.py b/tests/saes/test_standard_sae.py index 95d397374..8068193ce 100644 --- a/tests/saes/test_standard_sae.py +++ b/tests/saes/test_standard_sae.py @@ -436,9 +436,13 @@ def test_sae_fold_norm_scaling_factor_all_architectures(architecture: str): unscaled_activations = activations / norm_scaling_factor feature_activations_1 = sae.encode(activations) + if feature_activations_1.is_sparse: + feature_activations_1 = feature_activations_1.to_dense() # with the scaling folded in, the unscaled activations should produce the same # result. 
feature_activations_2 = sae2.encode(unscaled_activations) + if feature_activations_2.is_sparse: + feature_activations_2 = feature_activations_2.to_dense() assert_close( feature_activations_1.nonzero(), diff --git a/tests/saes/test_topk_sae.py b/tests/saes/test_topk_sae.py index 58458f338..9391234cf 100644 --- a/tests/saes/test_topk_sae.py +++ b/tests/saes/test_topk_sae.py @@ -153,14 +153,15 @@ def test_topK_activation_sparse_intermediate(): # we use to accelerate the decoder matches the dense top-K output. d_sae = 1024 M = 128 + B = 16 for k in [1, 10, 100, 1000]: topk_sparse = TopK(k, sparse_intermediate=True) topk_dense = TopK(k, sparse_intermediate=False) - x = torch.randn(M, d_sae) + 50.0 + x = torch.randn(B, M, d_sae) + 50.0 sparse_x = topk_sparse(x) assert sparse_x.is_sparse - assert sparse_x.shape == (M, d_sae) - assert sparse_x.coalesce().values().numel() == k * M + assert sparse_x.shape == (B, M, d_sae) + assert sparse_x.coalesce().values().numel() == B * M * k dense_x = topk_dense(x) assert_close(dense_x, sparse_x.to_dense()) @@ -199,22 +200,23 @@ def test_topK_activation_sparse_mm(): def test_topK_activation_sparse_config(): + # Check that our config is respected in both training & inference SAEs cfg = build_topk_sae_cfg(k=100, sparse_intermediate=True) sae = TopKSAE(cfg) - assert sae.activation_fn.sparse_intermediate + assert sae.activation_fn.sparse_intermediate # type: ignore assert sae.cfg.sparse_intermediate cfg = build_topk_sae_cfg(k=100, sparse_intermediate=False) sae = TopKSAE(cfg) - assert not sae.activation_fn.sparse_intermediate + assert not sae.activation_fn.sparse_intermediate # type: ignore assert not sae.cfg.sparse_intermediate cfg = build_topk_sae_training_cfg(k=100, sparse_intermediate=True) sae = TopKTrainingSAE(cfg) - assert sae.activation_fn.sparse_intermediate + assert sae.activation_fn.sparse_intermediate # type: ignore assert sae.cfg.sparse_intermediate cfg = build_topk_sae_training_cfg(k=100, sparse_intermediate=False) sae = 
TopKTrainingSAE(cfg) - assert not sae.activation_fn.sparse_intermediate + assert not sae.activation_fn.sparse_intermediate # type: ignore assert not sae.cfg.sparse_intermediate From 50747d31c4a037e0eb1f0dd8ccdea42f1bbcb821 Mon Sep 17 00:00:00 2001 From: wz-ml Date: Thu, 18 Sep 2025 13:28:23 -0700 Subject: [PATCH 05/19] Making sparse COO tensors compatible with HookedTransformer (WIP) --- sae_lens/saes/topk_sae.py | 36 +++++++++++++++++++++++++++++------- 1 file changed, 29 insertions(+), 7 deletions(-) diff --git a/sae_lens/saes/topk_sae.py b/sae_lens/saes/topk_sae.py index a263db28e..852aed347 100644 --- a/sae_lens/saes/topk_sae.py +++ b/sae_lens/saes/topk_sae.py @@ -38,16 +38,20 @@ def __init__( def forward( self, x: torch.Tensor, - ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: + ) -> torch.Tensor: """ 1) Select top K elements along the last dimension. 2) Apply ReLU. 3) Zero out all other entries. """ - topk_values, topk_indices = torch.topk(x, k=self.k, dim=-1) + topk_values, topk_indices = torch.topk(x, k=self.k, dim=-1, sorted=False) if self.sparse_intermediate: # Produce a COO sparse tensor (use sparse matrix multiply in decode) - M, N = x.shape + assert x.ndim >= 2, ( + f"Expected pre-topK tensor to have at least 2 dimensions, got tensor of shape {x.shape}" + ) + x = x.view(-1, x.shape[-1]) + M, _ = x.shape sparse_indices = torch.stack( [ torch.arange(M, device=x.device).repeat_interleave(self.k), @@ -55,7 +59,7 @@ def forward( ] ) return torch.sparse_coo_tensor( - sparse_indices, topk_values.flatten(), (M, N) + sparse_indices, topk_values.flatten(), tuple(x.shape) ) values = topk_values.relu() result = torch.zeros_like(x) @@ -103,7 +107,10 @@ def initialize_weights(self) -> None: def encode( self, x: Float[torch.Tensor, "... d_in"] - ) -> Float[torch.Tensor, "... d_sae"]: + ) -> ( + Float[torch.Tensor, "... d_sae"] + | tuple[Float[torch.Tensor, "... d_sae"], torch.Size] + ): """ Converts input x into feature activations. 
Uses topk activation under the hood. @@ -111,18 +118,33 @@ def encode( sae_in = self.process_sae_in(x) hidden_pre = self.hook_sae_acts_pre(sae_in @ self.W_enc + self.b_enc) # The BaseSAE already sets self.activation_fn to TopK(...) if config requests topk. + if self.cfg.sparse_intermediate: + return self.hook_sae_acts_post( + self.activation_fn(hidden_pre) + ), hidden_pre.shape return self.hook_sae_acts_post(self.activation_fn(hidden_pre)) def decode( self, - feature_acts: Float[torch.Tensor, "... d_sae"], + feature_acts: Float[torch.Tensor, "... d_sae"] + | tuple[torch.Tensor, torch.Size], ) -> Float[torch.Tensor, "... d_in"]: """ Reconstructs the input from topk feature activations. Applies optional finetuning scaling, hooking to recons, out normalization, and optional head reshaping. """ - sae_out_pre = feature_acts @ self.W_dec + self.b_dec + if self.cfg.sparse_intermediate: + # Since torch.sparse.mm doesn't support dotting a 3D tensor with a 2D matrix, + # we flatten all but the last dimension of the feature activations if they're in sparse format + # before reshaping the post-decode tensor back to the correct shape. 
+ feature_acts, original_shape = feature_acts + sae_out_pre = feature_acts @ self.W_dec + self.b_dec + sae_out_pre = sae_out_pre.reshape( + tuple(original_shape[:-1]) + (self.cfg.d_in,) # type: ignore + ) + else: + sae_out_pre = feature_acts @ self.W_dec + self.b_dec sae_out_pre = self.hook_sae_recons(sae_out_pre) sae_out_pre = self.run_time_activation_norm_fn_out(sae_out_pre) return self.reshape_fn_out(sae_out_pre, self.d_head) From dfeb254e1d6f1f8a9d70dcce2bd0ce83788d6e9f Mon Sep 17 00:00:00 2001 From: wz-ml Date: Fri, 19 Sep 2025 08:59:38 -0700 Subject: [PATCH 06/19] Changes to make sparse SAE intermediate implementation transparent to TransformerLens Hooks --- sae_lens/saes/topk_sae.py | 139 +++++++++++++++++++++++++++++++----- tests/saes/test_topk_sae.py | 4 +- 2 files changed, 122 insertions(+), 21 deletions(-) diff --git a/sae_lens/saes/topk_sae.py b/sae_lens/saes/topk_sae.py index 852aed347..ce0ede9f5 100644 --- a/sae_lens/saes/topk_sae.py +++ b/sae_lens/saes/topk_sae.py @@ -6,6 +6,7 @@ import torch from jaxtyping import Float from torch import nn +from transformer_lens.hook_points import HookPoint from typing_extensions import override from sae_lens.saes.sae import ( @@ -15,9 +16,37 @@ TrainingSAE, TrainingSAEConfig, TrainStepInput, + _disable_hooks, ) +class SparseHookPoint(HookPoint): + """ + A HookPoint that takes in a sparse tensor. + Overrides TransformerLens's HookPoint. 
+ """ + + def __init__(self, d_sae: int): + super().__init__() + self.d_sae = d_sae + + @override + def forward( + self, x: torch.Tensor, x_shape: torch.Size | None = None + ) -> torch.Tensor: + using_hooks = ( + self._forward_hooks is not None + or self._backward_hooks is not None + and len(self._forward_hooks) > 0 + and len(self._backward_hooks) > 0 + ) + if using_hooks and x.is_sparse: + if x_shape is None: + raise ValueError("x_shape must be provided") + return x.to_dense().reshape((x_shape[:-1]) + (self.d_sae,)) + return x # if no hooks are being used, use passthrough + + class TopK(nn.Module): """ A simple TopK activation that zeroes out all but the top K elements along the last dimension, @@ -45,6 +74,7 @@ def forward( 3) Zero out all other entries. """ topk_values, topk_indices = torch.topk(x, k=self.k, dim=-1, sorted=False) + values = topk_values.relu() if self.sparse_intermediate: # Produce a COO sparse tensor (use sparse matrix multiply in decode) assert x.ndim >= 2, ( @@ -59,9 +89,8 @@ def forward( ] ) return torch.sparse_coo_tensor( - sparse_indices, topk_values.flatten(), tuple(x.shape) + sparse_indices, values.flatten(), tuple(x.shape) ) - values = topk_values.relu() result = torch.zeros_like(x) result.scatter_(-1, topk_indices, values) return result @@ -98,6 +127,9 @@ def __init__(self, cfg: TopKSAEConfig, use_error_term: bool = False): use_error_term: Whether to apply the error-term approach in the forward pass. """ super().__init__(cfg, use_error_term) + if self.cfg.sparse_intermediate: + self.hook_sae_acts_post = SparseHookPoint(self.cfg.d_sae) + self.setup() @override def initialize_weights(self) -> None: @@ -107,10 +139,7 @@ def initialize_weights(self) -> None: def encode( self, x: Float[torch.Tensor, "... d_in"] - ) -> ( - Float[torch.Tensor, "... d_sae"] - | tuple[Float[torch.Tensor, "... d_sae"], torch.Size] - ): + ) -> Float[torch.Tensor, "... d_sae"]: """ Converts input x into feature activations. Uses topk activation under the hood. 
@@ -118,37 +147,62 @@ def encode( sae_in = self.process_sae_in(x) hidden_pre = self.hook_sae_acts_pre(sae_in @ self.W_enc + self.b_enc) # The BaseSAE already sets self.activation_fn to TopK(...) if config requests topk. - if self.cfg.sparse_intermediate: + if self.cfg.sparse_intermediate and isinstance( + self.hook_sae_acts_post, + SparseHookPoint, # Necessary so we don't pass illegal arg to blank hook + ): return self.hook_sae_acts_post( - self.activation_fn(hidden_pre) - ), hidden_pre.shape + self.activation_fn(hidden_pre), x_shape=x.shape + ) return self.hook_sae_acts_post(self.activation_fn(hidden_pre)) def decode( self, - feature_acts: Float[torch.Tensor, "... d_sae"] - | tuple[torch.Tensor, torch.Size], + feature_acts: Float[torch.Tensor, "... d_sae"], + x_shape: torch.Size | None = None, ) -> Float[torch.Tensor, "... d_in"]: """ Reconstructs the input from topk feature activations. Applies optional finetuning scaling, hooking to recons, out normalization, and optional head reshaping. + + x_shape: The shape of the pre-encode input x. Used when sparse_intermediate is True. """ - if self.cfg.sparse_intermediate: + sae_out_pre = feature_acts @ self.W_dec + self.b_dec + if ( + self.cfg.sparse_intermediate + and feature_acts.is_sparse + and x_shape is not None + ): # Since torch.sparse.mm doesn't support dotting a 3D tensor with a 2D matrix, # we flatten all but the last dimension of the feature activations if they're in sparse format # before reshaping the post-decode tensor back to the correct shape. 
- feature_acts, original_shape = feature_acts - sae_out_pre = feature_acts @ self.W_dec + self.b_dec sae_out_pre = sae_out_pre.reshape( - tuple(original_shape[:-1]) + (self.cfg.d_in,) # type: ignore + tuple(x_shape[:-1]) + (self.cfg.d_in,) # type: ignore ) - else: - sae_out_pre = feature_acts @ self.W_dec + self.b_dec sae_out_pre = self.hook_sae_recons(sae_out_pre) sae_out_pre = self.run_time_activation_norm_fn_out(sae_out_pre) return self.reshape_fn_out(sae_out_pre, self.d_head) + @override + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass through the SAE.""" + feature_acts = self.encode(x) + sae_out = self.decode(feature_acts, x_shape=x.shape) + + if self.use_error_term: + with torch.no_grad(): + # Recompute without hooks for true error term + with _disable_hooks(self): + feature_acts_clean = self.encode(x) + x_reconstruct_clean = self.decode( + feature_acts_clean, x_shape=x.shape + ) + sae_error = self.hook_sae_error(x - x_reconstruct_clean) + sae_out = sae_out + sae_error + + return self.hook_sae_output(sae_out) + @override def get_activation_fn(self) -> Callable[[torch.Tensor], torch.Tensor]: return TopK(self.cfg.k, sparse_intermediate=self.cfg.sparse_intermediate) @@ -186,6 +240,9 @@ class TopKTrainingSAE(TrainingSAE[TopKTrainingSAEConfig]): def __init__(self, cfg: TopKTrainingSAEConfig, use_error_term: bool = False): super().__init__(cfg, use_error_term) + if self.cfg.sparse_intermediate: + self.hook_sae_acts_post = SparseHookPoint(self.cfg.d_sae) + self.setup() @override def initialize_weights(self) -> None: @@ -202,9 +259,53 @@ def encode_with_hidden_pre( hidden_pre = self.hook_sae_acts_pre(sae_in @ self.W_enc + self.b_enc) # Apply the TopK activation function (already set in self.activation_fn if config is "topk") - feature_acts = self.hook_sae_acts_post(self.activation_fn(hidden_pre)) + feature_acts = self.hook_sae_acts_post(self.activation_fn(hidden_pre), x.shape) return feature_acts, hidden_pre + @override + def decode( + self, + 
feature_acts: Float[torch.Tensor, "... d_sae"], + x_shape: torch.Size | None = None, + ) -> Float[torch.Tensor, "... d_in"]: + """ + Decodes feature activations back into input space, + applying optional finetuning scale, hooking, out normalization, etc. + + x_shape: The shape of the pre-encode input x. Used when sparse_intermediate is True. + """ + sae_out_pre = feature_acts @ self.W_dec + self.b_dec + if ( + self.cfg.sparse_intermediate + and feature_acts.is_sparse + and x_shape is not None + ): + sae_out_pre = sae_out_pre.reshape( + tuple(x_shape[:-1]) + (self.cfg.d_in,) # type: ignore + ) + sae_out_pre = self.hook_sae_recons(sae_out_pre) + sae_out_pre = self.run_time_activation_norm_fn_out(sae_out_pre) + return self.reshape_fn_out(sae_out_pre, self.d_head) + + @override + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass through the SAE.""" + feature_acts = self.encode(x) + sae_out = self.decode(feature_acts, x_shape=x.shape) + + if self.use_error_term: + with torch.no_grad(): + # Recompute without hooks for true error term + with _disable_hooks(self): + feature_acts_clean = self.encode(x) + x_reconstruct_clean = self.decode( + feature_acts_clean, x_shape=x.shape + ) + sae_error = self.hook_sae_error(x - x_reconstruct_clean) + sae_out = sae_out + sae_error + + return self.hook_sae_output(sae_out) + @override def calculate_aux_loss( self, @@ -272,7 +373,7 @@ def calculate_topk_aux_loss( # Encourage the top ~50% of dead latents to predict the residual of the # top k living latents - recons = self.decode(auxk_acts) + recons = self.decode(auxk_acts, x_shape=sae_in.shape) auxk_loss = (recons - residual).pow(2).sum(dim=-1).mean() return self.cfg.aux_loss_coefficient * scale * auxk_loss diff --git a/tests/saes/test_topk_sae.py b/tests/saes/test_topk_sae.py index 9391234cf..a8f09c44b 100644 --- a/tests/saes/test_topk_sae.py +++ b/tests/saes/test_topk_sae.py @@ -160,10 +160,10 @@ def test_topK_activation_sparse_intermediate(): x = torch.randn(B, M, 
d_sae) + 50.0 sparse_x = topk_sparse(x) assert sparse_x.is_sparse - assert sparse_x.shape == (B, M, d_sae) assert sparse_x.coalesce().values().numel() == B * M * k + sparse_x = sparse_x.to_dense().reshape(B, M, d_sae) dense_x = topk_dense(x) - assert_close(dense_x, sparse_x.to_dense()) + assert_close(dense_x, sparse_x) def test_topK_activation_sparse_mm(): From 1fe80605be966d06e0d78a1a473ec8ec9d73cc23 Mon Sep 17 00:00:00 2001 From: wz-ml Date: Fri, 19 Sep 2025 09:48:14 -0700 Subject: [PATCH 07/19] Add formatting script (for future optimization) --- benchmark/bench_fwd_perf.py | 139 ++++++++++++++++++++++++++++++++++++ sae_lens/saes/topk_sae.py | 16 +++-- 2 files changed, 151 insertions(+), 4 deletions(-) create mode 100644 benchmark/bench_fwd_perf.py diff --git a/benchmark/bench_fwd_perf.py b/benchmark/bench_fwd_perf.py new file mode 100644 index 000000000..314230169 --- /dev/null +++ b/benchmark/bench_fwd_perf.py @@ -0,0 +1,139 @@ +import argparse +import os +from typing import Any, Callable + +import torch +import torch._inductor.config +import triton +from sparsify import SparseCoder, SparseCoderConfig +from tabulate import tabulate + +from sae_lens.saes.sae import TrainStepInput +from sae_lens.saes.topk_sae import TopKTrainingSAE +from tests.helpers import ( + build_topk_sae_training_cfg, +) + +torch._inductor.config.coordinate_descent_tuning = True + +parser = argparse.ArgumentParser(add_help=True) +parser.add_argument("--device", type=str, default="cuda") +parser.add_argument( + "--shape", + type=int, + nargs=3, + default=[1024, 1024, 1024 * 16], + help="Shape of the input tensor (seq_len, d_in, d_sae)", +) +parser.add_argument("--k", type=int, default=100, help="Number of topk elements") +args = parser.parse_args() + +device = args.device + +os.environ["TOKENIZERS_PARALLELISM"] = "false" + +d_in = args.shape[1] +d_sae = args.shape[2] +k = args.k +seq_len = args.shape[0] + +cfg_sparse = build_topk_sae_training_cfg( + d_in=d_in, + d_sae=d_sae, + k=k, + 
device=device, + sparse_intermediate=True, +) +cfg_dense = build_topk_sae_training_cfg( + d_in=d_in, + d_sae=d_sae, + k=k, + device=device, + sparse_intermediate=False, +) + +sae_sparse = TopKTrainingSAE(cfg_sparse) +sae_dense = TopKTrainingSAE(cfg_dense) +sparse_coder_sae = SparseCoder( + d_in=d_in, cfg=SparseCoderConfig(num_latents=d_sae, k=26) +) + +dead_neuron_mask = None # torch.randn(d_sae, device = device) > 0.1 +input_acts = torch.randn(seq_len, d_in, device=device) +input_var = (input_acts - input_acts.mean(0)).pow(2).sum() + +step_input = TrainStepInput( + sae_in=input_acts, + dead_neuron_mask=dead_neuron_mask, + coefficients={}, +) + + +def encode_proj(sae: TopKTrainingSAE, input_acts: torch.Tensor) -> torch.Tensor: + sae_in = sae.process_sae_in(input_acts) + return sae.hook_sae_acts_pre(sae_in @ sae.W_enc + sae.b_enc) + + +def topk_activation(sae: TopKTrainingSAE, hidden_pre: torch.Tensor) -> torch.Tensor: + return sae.activation_fn(hidden_pre) + + +def decode_step(sae: TopKTrainingSAE, feature_acts: torch.Tensor) -> torch.Tensor: + return sae.decode(feature_acts) + + +def loss_computation( + sae: TopKTrainingSAE, sae_out: torch.Tensor, sae_in: torch.Tensor +) -> torch.Tensor: + # Calculate MSE loss + per_item_mse_loss = sae.mse_loss_fn(sae_out, sae_in) + return per_item_mse_loss.sum(dim=-1).mean() + + +def triton_bench(fn: Callable[[], Any]) -> float: + # note that the warmup and rep params here are in ms, not iterations + return triton.testing.do_bench(fn, warmup=1000, rep=2000) # type: ignore + + +def benchmark_sae(sae: TopKTrainingSAE) -> dict[str, float]: + results = {} + results["encode_proj"] = triton_bench(lambda: encode_proj(sae, input_acts)) + hidden_pre = encode_proj(sae, input_acts) + results["topk_activation"] = triton_bench(lambda: topk_activation(sae, hidden_pre)) + feature_acts = topk_activation(sae, hidden_pre) + results["decode_step"] = triton_bench(lambda: decode_step(sae, feature_acts)) + sae_out = decode_step(sae, feature_acts) + 
results["loss_computation"] = triton_bench( + lambda: loss_computation(sae, sae_out, input_acts) + ) + results["full_forward_pass"] = triton_bench( + lambda: sae.training_forward_pass(step_input) + ) + results["other"] = 2 * results["full_forward_pass"] - sum(results.values()) # type: ignore + return results + + +if __name__ == "__main__": + print("This may take a while (5 mins). Go grab a coffee!") + results_sparse = benchmark_sae(sae_sparse) + results_dense = benchmark_sae(sae_dense) + + # Pretty print results table with metrics as columns + headers = [ + "Implementation", + "Encode", + "TopK", + "Decode", + "Loss Calc", + "Full Fwd", + "Other", + ] + + metric_keys = results_sparse.keys() + + table_data = [ + ["Sparse"] + [f"{results_sparse[key]:.3f}" for key in metric_keys], + ["Dense"] + [f"{results_dense[key]:.3f}" for key in metric_keys], + ] + print("Metric: Latency (ms)") + print("\n" + tabulate(table_data, headers=headers, tablefmt="grid")) diff --git a/sae_lens/saes/topk_sae.py b/sae_lens/saes/topk_sae.py index ce0ede9f5..250988fd0 100644 --- a/sae_lens/saes/topk_sae.py +++ b/sae_lens/saes/topk_sae.py @@ -77,9 +77,9 @@ def forward( values = topk_values.relu() if self.sparse_intermediate: # Produce a COO sparse tensor (use sparse matrix multiply in decode) - assert x.ndim >= 2, ( - f"Expected pre-topK tensor to have at least 2 dimensions, got tensor of shape {x.shape}" - ) + assert ( + x.ndim >= 2 + ), f"Expected pre-topK tensor to have at least 2 dimensions, got tensor of shape {x.shape}" x = x.view(-1, x.shape[-1]) M, _ = x.shape sparse_indices = torch.stack( @@ -259,7 +259,15 @@ def encode_with_hidden_pre( hidden_pre = self.hook_sae_acts_pre(sae_in @ self.W_enc + self.b_enc) # Apply the TopK activation function (already set in self.activation_fn if config is "topk") - feature_acts = self.hook_sae_acts_post(self.activation_fn(hidden_pre), x.shape) + if self.cfg.sparse_intermediate and isinstance( + self.hook_sae_acts_post, + SparseHookPoint, + ): + 
feature_acts = self.hook_sae_acts_post( + self.activation_fn(hidden_pre), x_shape=x.shape + ) + else: + feature_acts = self.hook_sae_acts_post(self.activation_fn(hidden_pre)) return feature_acts, hidden_pre @override From ab8ead409e44b9ddf53b657b0f2ebf44d1d73dc9 Mon Sep 17 00:00:00 2001 From: wz-ml Date: Wed, 24 Sep 2025 21:12:30 -0700 Subject: [PATCH 08/19] Address PR review comments --- benchmark/bench_fwd_perf.py | 4 +- sae_lens/saes/topk_sae.py | 108 ++++++++---------- tests/helpers.py | 2 +- .../test_topk_sae_equivalence.py | 4 + tests/saes/test_topk_sae.py | 40 +++---- 5 files changed, 74 insertions(+), 84 deletions(-) diff --git a/benchmark/bench_fwd_perf.py b/benchmark/bench_fwd_perf.py index 314230169..a347f2bca 100644 --- a/benchmark/bench_fwd_perf.py +++ b/benchmark/bench_fwd_perf.py @@ -42,14 +42,14 @@ d_sae=d_sae, k=k, device=device, - sparse_intermediate=True, + use_sparse_activations=True, ) cfg_dense = build_topk_sae_training_cfg( d_in=d_in, d_sae=d_sae, k=k, device=device, - sparse_intermediate=False, + use_sparse_activations=False, ) sae_sparse = TopKTrainingSAE(cfg_sparse) diff --git a/sae_lens/saes/topk_sae.py b/sae_lens/saes/topk_sae.py index 250988fd0..589bb25de 100644 --- a/sae_lens/saes/topk_sae.py +++ b/sae_lens/saes/topk_sae.py @@ -31,9 +31,7 @@ def __init__(self, d_sae: int): self.d_sae = d_sae @override - def forward( - self, x: torch.Tensor, x_shape: torch.Size | None = None - ) -> torch.Tensor: + def forward(self, x: torch.Tensor) -> torch.Tensor: using_hooks = ( self._forward_hooks is not None or self._backward_hooks is not None @@ -41,9 +39,7 @@ def forward( and len(self._backward_hooks) > 0 ) if using_hooks and x.is_sparse: - if x_shape is None: - raise ValueError("x_shape must be provided") - return x.to_dense().reshape((x_shape[:-1]) + (self.d_sae,)) + return x.to_dense() return x # if no hooks are being used, use passthrough @@ -53,16 +49,16 @@ class TopK(nn.Module): and applies ReLU to the top K elements. 
""" - sparse_intermediate: bool + use_sparse_activations: bool def __init__( self, k: int, - sparse_intermediate: bool = True, + use_sparse_activations: bool = True, ): super().__init__() self.k = k - self.sparse_intermediate = sparse_intermediate + self.use_sparse_activations = use_sparse_activations def forward( self, @@ -75,7 +71,7 @@ def forward( """ topk_values, topk_indices = torch.topk(x, k=self.k, dim=-1, sorted=False) values = topk_values.relu() - if self.sparse_intermediate: + if self.use_sparse_activations: # Produce a COO sparse tensor (use sparse matrix multiply in decode) assert ( x.ndim >= 2 @@ -103,7 +99,7 @@ class TopKSAEConfig(SAEConfig): """ k: int = 100 - sparse_intermediate: bool = True + use_sparse_activations: bool = True @override @classmethod @@ -127,7 +123,7 @@ def __init__(self, cfg: TopKSAEConfig, use_error_term: bool = False): use_error_term: Whether to apply the error-term approach in the forward pass. """ super().__init__(cfg, use_error_term) - if self.cfg.sparse_intermediate: + if self.cfg.use_sparse_activations: self.hook_sae_acts_post = SparseHookPoint(self.cfg.d_sae) self.setup() @@ -147,57 +143,48 @@ def encode( sae_in = self.process_sae_in(x) hidden_pre = self.hook_sae_acts_pre(sae_in @ self.W_enc + self.b_enc) # The BaseSAE already sets self.activation_fn to TopK(...) if config requests topk. - if self.cfg.sparse_intermediate and isinstance( - self.hook_sae_acts_post, - SparseHookPoint, # Necessary so we don't pass illegal arg to blank hook - ): - return self.hook_sae_acts_post( - self.activation_fn(hidden_pre), x_shape=x.shape - ) return self.hook_sae_acts_post(self.activation_fn(hidden_pre)) def decode( self, feature_acts: Float[torch.Tensor, "... d_sae"], - x_shape: torch.Size | None = None, ) -> Float[torch.Tensor, "... d_in"]: """ Reconstructs the input from topk feature activations. Applies optional finetuning scaling, hooking to recons, out normalization, and optional head reshaping. 
- - x_shape: The shape of the pre-encode input x. Used when sparse_intermediate is True. """ - sae_out_pre = feature_acts @ self.W_dec + self.b_dec - if ( - self.cfg.sparse_intermediate - and feature_acts.is_sparse - and x_shape is not None - ): - # Since torch.sparse.mm doesn't support dotting a 3D tensor with a 2D matrix, - # we flatten all but the last dimension of the feature activations if they're in sparse format - # before reshaping the post-decode tensor back to the correct shape. - sae_out_pre = sae_out_pre.reshape( - tuple(x_shape[:-1]) + (self.cfg.d_in,) # type: ignore + if self.cfg.use_sparse_activations and feature_acts.ndim >= 3: + raise ValueError( + "Sparse activations are only supported for 2D activations. Use .disable_sparse_activations() to support arbitrary activation dims." ) + sae_out_pre = feature_acts @ self.W_dec + self.b_dec sae_out_pre = self.hook_sae_recons(sae_out_pre) sae_out_pre = self.run_time_activation_norm_fn_out(sae_out_pre) return self.reshape_fn_out(sae_out_pre, self.d_head) + def disable_sparse_activations(self) -> None: + self.cfg.use_sparse_activations = False + if isinstance(self.activation_fn, TopK): + self.activation_fn.use_sparse_activations = False + + def enable_sparse_activations(self) -> None: + self.cfg.use_sparse_activations = True + if isinstance(self.activation_fn, TopK): + self.activation_fn.use_sparse_activations = True + @override def forward(self, x: torch.Tensor) -> torch.Tensor: """Forward pass through the SAE.""" feature_acts = self.encode(x) - sae_out = self.decode(feature_acts, x_shape=x.shape) + sae_out = self.decode(feature_acts) if self.use_error_term: with torch.no_grad(): # Recompute without hooks for true error term with _disable_hooks(self): feature_acts_clean = self.encode(x) - x_reconstruct_clean = self.decode( - feature_acts_clean, x_shape=x.shape - ) + x_reconstruct_clean = self.decode(feature_acts_clean) sae_error = self.hook_sae_error(x - x_reconstruct_clean) sae_out = sae_out + sae_error @@ 
-205,7 +192,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: @override def get_activation_fn(self) -> Callable[[torch.Tensor], torch.Tensor]: - return TopK(self.cfg.k, sparse_intermediate=self.cfg.sparse_intermediate) + return TopK(self.cfg.k, use_sparse_activations=self.cfg.use_sparse_activations) @override @torch.no_grad() @@ -222,7 +209,7 @@ class TopKTrainingSAEConfig(TrainingSAEConfig): """ k: int = 100 - sparse_intermediate: bool = True + use_sparse_activations: bool = True aux_loss_coefficient: float = 1.0 @override @@ -240,7 +227,7 @@ class TopKTrainingSAE(TrainingSAE[TopKTrainingSAEConfig]): def __init__(self, cfg: TopKTrainingSAEConfig, use_error_term: bool = False): super().__init__(cfg, use_error_term) - if self.cfg.sparse_intermediate: + if self.cfg.use_sparse_activations: self.hook_sae_acts_post = SparseHookPoint(self.cfg.d_sae) self.setup() @@ -259,13 +246,11 @@ def encode_with_hidden_pre( hidden_pre = self.hook_sae_acts_pre(sae_in @ self.W_enc + self.b_enc) # Apply the TopK activation function (already set in self.activation_fn if config is "topk") - if self.cfg.sparse_intermediate and isinstance( + if self.cfg.use_sparse_activations and isinstance( self.hook_sae_acts_post, SparseHookPoint, ): - feature_acts = self.hook_sae_acts_post( - self.activation_fn(hidden_pre), x_shape=x.shape - ) + feature_acts = self.hook_sae_acts_post(self.activation_fn(hidden_pre)) else: feature_acts = self.hook_sae_acts_post(self.activation_fn(hidden_pre)) return feature_acts, hidden_pre @@ -274,41 +259,42 @@ def encode_with_hidden_pre( def decode( self, feature_acts: Float[torch.Tensor, "... d_sae"], - x_shape: torch.Size | None = None, ) -> Float[torch.Tensor, "... d_in"]: """ Decodes feature activations back into input space, applying optional finetuning scale, hooking, out normalization, etc. - - x_shape: The shape of the pre-encode input x. Used when sparse_intermediate is True. 
""" - sae_out_pre = feature_acts @ self.W_dec + self.b_dec - if ( - self.cfg.sparse_intermediate - and feature_acts.is_sparse - and x_shape is not None - ): - sae_out_pre = sae_out_pre.reshape( - tuple(x_shape[:-1]) + (self.cfg.d_in,) # type: ignore + if self.cfg.use_sparse_activations and feature_acts.ndim >= 3: + raise ValueError( + "Sparse activations are only supported for 2D activations. Use .disable_sparse_activations() to support arbitrary activation dims." ) + sae_out_pre = feature_acts @ self.W_dec + self.b_dec sae_out_pre = self.hook_sae_recons(sae_out_pre) sae_out_pre = self.run_time_activation_norm_fn_out(sae_out_pre) return self.reshape_fn_out(sae_out_pre, self.d_head) + def disable_sparse_activations(self) -> None: + self.cfg.use_sparse_activations = False + if isinstance(self.activation_fn, TopK): + self.activation_fn.use_sparse_activations = False + + def enable_sparse_activations(self) -> None: + self.cfg.use_sparse_activations = True + if isinstance(self.activation_fn, TopK): + self.activation_fn.use_sparse_activations = True + @override def forward(self, x: torch.Tensor) -> torch.Tensor: """Forward pass through the SAE.""" feature_acts = self.encode(x) - sae_out = self.decode(feature_acts, x_shape=x.shape) + sae_out = self.decode(feature_acts) if self.use_error_term: with torch.no_grad(): # Recompute without hooks for true error term with _disable_hooks(self): feature_acts_clean = self.encode(x) - x_reconstruct_clean = self.decode( - feature_acts_clean, x_shape=x.shape - ) + x_reconstruct_clean = self.decode(feature_acts_clean) sae_error = self.hook_sae_error(x - x_reconstruct_clean) sae_out = sae_out + sae_error @@ -340,7 +326,7 @@ def fold_W_dec_norm(self) -> None: @override def get_activation_fn(self) -> Callable[[torch.Tensor], torch.Tensor]: - return TopK(self.cfg.k, sparse_intermediate=self.cfg.sparse_intermediate) + return TopK(self.cfg.k, use_sparse_activations=self.cfg.use_sparse_activations) @override def get_coefficients(self) -> 
dict[str, TrainCoefficientConfig | float]: @@ -381,7 +367,7 @@ def calculate_topk_aux_loss( # Encourage the top ~50% of dead latents to predict the residual of the # top k living latents - recons = self.decode(auxk_acts, x_shape=sae_in.shape) + recons = self.decode(auxk_acts) auxk_loss = (recons - residual).pow(2).sum(dim=-1).mean() return self.cfg.aux_loss_coefficient * scale * auxk_loss diff --git a/tests/helpers.py b/tests/helpers.py index 5dee87a3e..d071f41ec 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -98,7 +98,7 @@ class TrainingSAEConfigDict(TypedDict, total=False): jumprelu_init_threshold: float jumprelu_bandwidth: float k: int # For TopK - sparse_intermediate: bool # For TopK + use_sparse_activations: bool # For TopK l0_coefficient: float # For JumpReLU l0_warm_up_steps: int pre_act_loss_coefficient: float | None # For JumpReLU diff --git a/tests/refactor_compatibility/test_topk_sae_equivalence.py b/tests/refactor_compatibility/test_topk_sae_equivalence.py index e7b5f7700..95f2c7c4a 100644 --- a/tests/refactor_compatibility/test_topk_sae_equivalence.py +++ b/tests/refactor_compatibility/test_topk_sae_equivalence.py @@ -119,6 +119,7 @@ def test_topk_sae_inference_equivalence(): """ old_sae = make_old_topk_sae(d_in=16, d_sae=8, use_error_term=False) new_sae = make_new_topk_sae(d_in=16, d_sae=8, use_error_term=False) + new_sae.disable_sparse_activations() compare_params(old_sae, new_sae) @@ -150,6 +151,7 @@ def test_topk_sae_inference_equivalence(): # Now test with error_term old_sae_err = make_old_topk_sae(d_in=16, d_sae=8, use_error_term=True) new_sae_err = make_new_topk_sae(d_in=16, d_sae=8, use_error_term=True) + new_sae_err.disable_sparse_activations() # Align error term model parameters with torch.no_grad(): @@ -182,6 +184,7 @@ def test_topk_sae_run_with_cache_equivalence(): # type: ignore old_sae = make_old_topk_sae() new_sae = make_new_topk_sae() + new_sae.disable_sparse_activations() # Ensure parameters are identical before comparing 
outputs with torch.no_grad(): @@ -269,6 +272,7 @@ def test_topk_sae_training_equivalence(): d_sae=8, ) new_training_sae = TopKTrainingSAE(new_training_cfg) + new_training_sae.disable_sparse_activations() # Compare param shapes using updated compare_params compare_params(old_training_sae, new_training_sae) diff --git a/tests/saes/test_topk_sae.py b/tests/saes/test_topk_sae.py index a8f09c44b..69973ece1 100644 --- a/tests/saes/test_topk_sae.py +++ b/tests/saes/test_topk_sae.py @@ -17,11 +17,11 @@ def test_TopKTrainingSAE_topk_aux_loss_matches_unnormalized_sparsify_implementation(): d_in = 128 d_sae = 192 + k = 26 cfg = build_topk_sae_training_cfg( d_in=d_in, d_sae=d_sae, - k=26, - decoder_init_norm=1.0, # TODO: why is this needed?? + k=k, ) sae = TopKTrainingSAE(cfg) @@ -148,15 +148,15 @@ def test_TopKTrainingSAE_save_and_load_inference_sae(tmp_path: Path) -> None: assert_close(training_full_out, inference_full_out) -def test_topK_activation_sparse_intermediate(): +def test_topK_sparse_activations(): # Validate that the sparse top-K intermediate output (COO format) # we use to accelerate the decoder matches the dense top-K output. 
d_sae = 1024 M = 128 B = 16 for k in [1, 10, 100, 1000]: - topk_sparse = TopK(k, sparse_intermediate=True) - topk_dense = TopK(k, sparse_intermediate=False) + topk_sparse = TopK(k, use_sparse_activations=True) + topk_dense = TopK(k, use_sparse_activations=False) x = torch.randn(B, M, d_sae) + 50.0 sparse_x = topk_sparse(x) assert sparse_x.is_sparse @@ -189,8 +189,8 @@ def test_topK_activation_sparse_mm(): sae.b_enc.data = sae.b_enc + 100.0 for k in [1, 10, 100, 1000]: - topk_sparse = TopK(k, sparse_intermediate=True) - topk_dense = TopK(k, sparse_intermediate=False) + topk_sparse = TopK(k, use_sparse_activations=True) + topk_dense = TopK(k, use_sparse_activations=False) x = torch.randn(M, d_sae) + 50.0 sparse_x = topk_sparse(x) sae_out_sparse = sae.decode(sparse_x) @@ -199,24 +199,24 @@ def test_topK_activation_sparse_mm(): assert_close(sae_out_sparse, sae_out_dense, rtol=1e-4, atol=5e-4) -def test_topK_activation_sparse_config(): +def test_topK_sparse_activations_config(): # Check that our config is respected in both training & inference SAEs - cfg = build_topk_sae_cfg(k=100, sparse_intermediate=True) + cfg = build_topk_sae_cfg(k=100, use_sparse_activations=True) sae = TopKSAE(cfg) - assert sae.activation_fn.sparse_intermediate # type: ignore - assert sae.cfg.sparse_intermediate + assert sae.activation_fn.use_sparse_activations # type: ignore + assert sae.cfg.use_sparse_activations - cfg = build_topk_sae_cfg(k=100, sparse_intermediate=False) + cfg = build_topk_sae_cfg(k=100, use_sparse_activations=False) sae = TopKSAE(cfg) - assert not sae.activation_fn.sparse_intermediate # type: ignore - assert not sae.cfg.sparse_intermediate + assert not sae.activation_fn.use_sparse_activations # type: ignore + assert not sae.cfg.use_sparse_activations - cfg = build_topk_sae_training_cfg(k=100, sparse_intermediate=True) + cfg = build_topk_sae_training_cfg(k=100, use_sparse_activations=True) sae = TopKTrainingSAE(cfg) - assert sae.activation_fn.sparse_intermediate # type: ignore 
- assert sae.cfg.sparse_intermediate + assert sae.activation_fn.use_sparse_activations # type: ignore + assert sae.cfg.use_sparse_activations - cfg = build_topk_sae_training_cfg(k=100, sparse_intermediate=False) + cfg = build_topk_sae_training_cfg(k=100, use_sparse_activations=False) sae = TopKTrainingSAE(cfg) - assert not sae.activation_fn.sparse_intermediate # type: ignore - assert not sae.cfg.sparse_intermediate + assert not sae.activation_fn.use_sparse_activations # type: ignore + assert not sae.cfg.use_sparse_activations From a442072014acc97398eae5de1b2be2df09bdfb2b Mon Sep 17 00:00:00 2001 From: David Chanin Date: Sun, 28 Sep 2025 00:27:45 +0100 Subject: [PATCH 09/19] allow multidim sparsity in topk saes --- benchmark/bench_fwd_perf.py | 4 - sae_lens/evals.py | 2 + sae_lens/saes/topk_sae.py | 167 +++++++++++------- sae_lens/training/sae_trainer.py | 10 +- tests/helpers.py | 32 ++-- .../test_topk_sae_equivalence.py | 6 +- tests/saes/test_topk_sae.py | 43 ++--- tests/test_evals.py | 30 ++++ tests/test_util.py | 5 +- 9 files changed, 188 insertions(+), 111 deletions(-) diff --git a/benchmark/bench_fwd_perf.py b/benchmark/bench_fwd_perf.py index a347f2bca..cf41e0d22 100644 --- a/benchmark/bench_fwd_perf.py +++ b/benchmark/bench_fwd_perf.py @@ -5,7 +5,6 @@ import torch import torch._inductor.config import triton -from sparsify import SparseCoder, SparseCoderConfig from tabulate import tabulate from sae_lens.saes.sae import TrainStepInput @@ -54,9 +53,6 @@ sae_sparse = TopKTrainingSAE(cfg_sparse) sae_dense = TopKTrainingSAE(cfg_dense) -sparse_coder_sae = SparseCoder( - d_in=d_in, cfg=SparseCoderConfig(num_latents=d_sae, k=26) -) dead_neuron_mask = None # torch.randn(d_sae, device = device) > 0.1 input_acts = torch.randn(seq_len, d_in, device=device) diff --git a/sae_lens/evals.py b/sae_lens/evals.py index 569c2dfee..1e7e63d79 100644 --- a/sae_lens/evals.py +++ b/sae_lens/evals.py @@ -466,6 +466,8 @@ def get_sparsity_and_variance_metrics( sae_out_scaled = 
sae.decode(sae_feature_activations).to( original_act_scaled.device ) + if sae_feature_activations.is_sparse: + sae_feature_activations = sae_feature_activations.to_dense() del cache sae_out = activation_scaler.unscale(sae_out_scaled) diff --git a/sae_lens/saes/topk_sae.py b/sae_lens/saes/topk_sae.py index 589bb25de..e5644abf4 100644 --- a/sae_lens/saes/topk_sae.py +++ b/sae_lens/saes/topk_sae.py @@ -33,11 +33,8 @@ def __init__(self, d_sae: int): @override def forward(self, x: torch.Tensor) -> torch.Tensor: using_hooks = ( - self._forward_hooks is not None - or self._backward_hooks is not None - and len(self._forward_hooks) > 0 - and len(self._backward_hooks) > 0 - ) + self._forward_hooks is not None and len(self._forward_hooks) > 0 + ) or (self._backward_hooks is not None and len(self._backward_hooks) > 0) if using_hooks and x.is_sparse: return x.to_dense() return x # if no hooks are being used, use passthrough @@ -73,19 +70,45 @@ def forward( values = topk_values.relu() if self.use_sparse_activations: # Produce a COO sparse tensor (use sparse matrix multiply in decode) - assert ( - x.ndim >= 2 - ), f"Expected pre-topK tensor to have at least 2 dimensions, got tensor of shape {x.shape}" - x = x.view(-1, x.shape[-1]) - M, _ = x.shape - sparse_indices = torch.stack( - [ - torch.arange(M, device=x.device).repeat_interleave(self.k), - topk_indices.flatten(), - ] - ) + original_shape = x.shape + + # Create indices for all dimensions + # For each element in topk_indices, we need to map it back to the original tensor coordinates + batch_dims = original_shape[:-1] # All dimensions except the last one + num_batch_elements = torch.prod(torch.tensor(batch_dims)).item() + + # Create batch indices - each batch element repeated k times + batch_indices_flat = torch.arange( + num_batch_elements, device=x.device + ).repeat_interleave(self.k) + + # Convert flat batch indices back to multi-dimensional indices + if len(batch_dims) == 1: + # 2D case: [batch, features] + sparse_indices 
= torch.stack( + [ + batch_indices_flat, + topk_indices.flatten(), + ] + ) + else: + # 3D+ case: need to unravel the batch indices + batch_indices_multi = [] + remaining = batch_indices_flat + for dim_size in reversed(batch_dims): + batch_indices_multi.append(remaining % dim_size) + remaining = remaining // dim_size + batch_indices_multi.reverse() + + sparse_indices = torch.stack( + [ + *batch_indices_multi, + topk_indices.flatten(), + ] + ) + return torch.sparse_coo_tensor( - sparse_indices, values.flatten(), tuple(x.shape) + sparse_indices, values.flatten(), original_shape ) result = torch.zeros_like(x) result.scatter_(-1, topk_indices, values) @@ -99,7 +122,6 @@ class TopKSAEConfig(SAEConfig): """ k: int = 100 - use_sparse_activations: bool = True @override @classmethod @@ -107,6 +129,59 @@ def architecture(cls) -> str: return "topk" +def _sparse_matmul_nd( + sparse_tensor: torch.Tensor, dense_matrix: torch.Tensor +) -> torch.Tensor: + """ + Multiply a sparse tensor of shape [..., d_sae] with a dense matrix of shape [d_sae, d_out] + to get a result of shape [..., d_out]. + + This function handles sparse tensors with arbitrary batch dimensions by flattening + the batch dimensions, performing 2D sparse matrix multiplication, and reshaping back. 
+ """ + original_shape = sparse_tensor.shape + batch_dims = original_shape[:-1] + d_sae = original_shape[-1] + d_out = dense_matrix.shape[-1] + + if sparse_tensor.ndim == 2: + # Simple 2D case - use torch.sparse.mm directly + return torch.sparse.mm(sparse_tensor, dense_matrix) + + # For 3D+ case, reshape to 2D, multiply, then reshape back + batch_size = int(torch.prod(torch.tensor(batch_dims)).item()) + + # Ensure tensor is coalesced for efficient access to indices/values + if not sparse_tensor.is_coalesced(): + sparse_tensor = sparse_tensor.coalesce() + + # Get indices and values + indices = sparse_tensor.indices() # [ndim, nnz] + values = sparse_tensor.values() # [nnz] + + # Convert multi-dimensional batch indices to flat indices + flat_batch_indices = torch.zeros_like(indices[0]) + multiplier = 1 + for i in reversed(range(len(batch_dims))): + flat_batch_indices += indices[i] * multiplier + multiplier *= batch_dims[i] + + # Create 2D sparse tensor indices [batch_flat, feature] + sparse_2d_indices = torch.stack([flat_batch_indices, indices[-1]]) + + # Create 2D sparse tensor + sparse_2d = torch.sparse_coo_tensor( + sparse_2d_indices, values, (batch_size, d_sae) + ).coalesce() + + # Do the matrix multiplication + result_2d = torch.sparse.mm(sparse_2d, dense_matrix) # [batch_size, d_out] + + # Reshape back to original batch dimensions + result_shape = tuple(batch_dims) + (d_out,) + return result_2d.view(result_shape) + + class TopKSAE(SAE[TopKSAEConfig]): """ An inference-only sparse autoencoder using a "topk" activation function. @@ -123,9 +198,6 @@ def __init__(self, cfg: TopKSAEConfig, use_error_term: bool = False): use_error_term: Whether to apply the error-term approach in the forward pass. 
""" super().__init__(cfg, use_error_term) - if self.cfg.use_sparse_activations: - self.hook_sae_acts_post = SparseHookPoint(self.cfg.d_sae) - self.setup() @override def initialize_weights(self) -> None: @@ -154,25 +226,15 @@ def decode( Applies optional finetuning scaling, hooking to recons, out normalization, and optional head reshaping. """ - if self.cfg.use_sparse_activations and feature_acts.ndim >= 3: - raise ValueError( - "Sparse activations are only supported for 2D activations. Use .disable_sparse_activations() to support arbitrary activation dims." - ) - sae_out_pre = feature_acts @ self.W_dec + self.b_dec + # Handle sparse tensors using efficient sparse matrix multiplication + if feature_acts.is_sparse: + sae_out_pre = _sparse_matmul_nd(feature_acts, self.W_dec) + self.b_dec + else: + sae_out_pre = feature_acts @ self.W_dec + self.b_dec sae_out_pre = self.hook_sae_recons(sae_out_pre) sae_out_pre = self.run_time_activation_norm_fn_out(sae_out_pre) return self.reshape_fn_out(sae_out_pre, self.d_head) - def disable_sparse_activations(self) -> None: - self.cfg.use_sparse_activations = False - if isinstance(self.activation_fn, TopK): - self.activation_fn.use_sparse_activations = False - - def enable_sparse_activations(self) -> None: - self.cfg.use_sparse_activations = True - if isinstance(self.activation_fn, TopK): - self.activation_fn.use_sparse_activations = True - @override def forward(self, x: torch.Tensor) -> torch.Tensor: """Forward pass through the SAE.""" @@ -192,7 +254,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: @override def get_activation_fn(self) -> Callable[[torch.Tensor], torch.Tensor]: - return TopK(self.cfg.k, use_sparse_activations=self.cfg.use_sparse_activations) + return TopK(self.cfg.k, use_sparse_activations=False) @override @torch.no_grad() @@ -227,8 +289,7 @@ class TopKTrainingSAE(TrainingSAE[TopKTrainingSAEConfig]): def __init__(self, cfg: TopKTrainingSAEConfig, use_error_term: bool = False): super().__init__(cfg, 
use_error_term) - if self.cfg.use_sparse_activations: - self.hook_sae_acts_post = SparseHookPoint(self.cfg.d_sae) + self.hook_sae_acts_post = SparseHookPoint(self.cfg.d_sae) self.setup() @override @@ -246,13 +307,7 @@ def encode_with_hidden_pre( hidden_pre = self.hook_sae_acts_pre(sae_in @ self.W_enc + self.b_enc) # Apply the TopK activation function (already set in self.activation_fn if config is "topk") - if self.cfg.use_sparse_activations and isinstance( - self.hook_sae_acts_post, - SparseHookPoint, - ): - feature_acts = self.hook_sae_acts_post(self.activation_fn(hidden_pre)) - else: - feature_acts = self.hook_sae_acts_post(self.activation_fn(hidden_pre)) + feature_acts = self.hook_sae_acts_post(self.activation_fn(hidden_pre)) return feature_acts, hidden_pre @override @@ -264,25 +319,15 @@ def decode( Decodes feature activations back into input space, applying optional finetuning scale, hooking, out normalization, etc. """ - if self.cfg.use_sparse_activations and feature_acts.ndim >= 3: - raise ValueError( - "Sparse activations are only supported for 2D activations. Use .disable_sparse_activations() to support arbitrary activation dims." 
- ) - sae_out_pre = feature_acts @ self.W_dec + self.b_dec + # Handle sparse tensors using efficient sparse matrix multiplication + if feature_acts.is_sparse: + sae_out_pre = _sparse_matmul_nd(feature_acts, self.W_dec) + self.b_dec + else: + sae_out_pre = feature_acts @ self.W_dec + self.b_dec sae_out_pre = self.hook_sae_recons(sae_out_pre) sae_out_pre = self.run_time_activation_norm_fn_out(sae_out_pre) return self.reshape_fn_out(sae_out_pre, self.d_head) - def disable_sparse_activations(self) -> None: - self.cfg.use_sparse_activations = False - if isinstance(self.activation_fn, TopK): - self.activation_fn.use_sparse_activations = False - - def enable_sparse_activations(self) -> None: - self.cfg.use_sparse_activations = True - if isinstance(self.activation_fn, TopK): - self.activation_fn.use_sparse_activations = True - @override def forward(self, x: torch.Tensor) -> torch.Tensor: """Forward pass through the SAE.""" diff --git a/sae_lens/training/sae_trainer.py b/sae_lens/training/sae_trainer.py index 1e0e68dc5..eaa064ab2 100644 --- a/sae_lens/training/sae_trainer.py +++ b/sae_lens/training/sae_trainer.py @@ -253,12 +253,14 @@ def _train_step( ) with torch.no_grad(): - did_fire = (train_step_output.feature_acts > 0).float().sum(-2) > 0 + # calling .bool() should be equivalent to .abs() > 0, and work with coo tensors + firing_feats = train_step_output.feature_acts.bool().float() + did_fire = firing_feats.sum(-2).bool() + if did_fire.is_sparse: + did_fire = did_fire.to_dense() self.n_forward_passes_since_fired += 1 self.n_forward_passes_since_fired[did_fire] = 0 - self.act_freq_scores += ( - (train_step_output.feature_acts.abs() > 0).float().sum(0) - ) + self.act_freq_scores += firing_feats.sum(0) self.n_frac_active_samples += self.cfg.train_batch_size_samples # Grad scaler will rescale gradients if autocast is enabled diff --git a/tests/helpers.py b/tests/helpers.py index d071f41ec..b4ab6a366 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -235,6 +235,15 @@ 
def _build_runner_config( return final_config +def _update_sae_metadata(runner_cfg: LanguageModelSAERunnerConfig[Any]): + runner_cfg.sae.metadata.hook_name = runner_cfg.hook_name + runner_cfg.sae.metadata.hook_head_index = runner_cfg.hook_head_index + runner_cfg.sae.metadata.model_name = runner_cfg.model_name + runner_cfg.sae.metadata.model_class_name = runner_cfg.model_class_name + runner_cfg.sae.metadata.dataset_path = runner_cfg.dataset_path + runner_cfg.sae.metadata.prepend_bos = runner_cfg.prepend_bos + + # --- Standard SAE Builder --- def build_runner_cfg( **kwargs: Any, @@ -257,12 +266,7 @@ def build_runner_cfg( cast(dict[str, Any], default_sae_config), **kwargs, ) - runner_cfg.sae.metadata.hook_name = runner_cfg.hook_name - runner_cfg.sae.metadata.hook_head_index = runner_cfg.hook_head_index - runner_cfg.sae.metadata.model_name = runner_cfg.model_name - runner_cfg.sae.metadata.model_class_name = runner_cfg.model_class_name - runner_cfg.sae.metadata.dataset_path = runner_cfg.dataset_path - runner_cfg.sae.metadata.prepend_bos = runner_cfg.prepend_bos + _update_sae_metadata(runner_cfg) return runner_cfg @@ -302,11 +306,13 @@ def build_jumprelu_runner_cfg( "l0_warm_up_steps": 0, "pre_act_loss_coefficient": None, } - return _build_runner_config( + runner_cfg = _build_runner_config( JumpReLUTrainingSAEConfig, cast(dict[str, Any], default_sae_config), **kwargs, ) + _update_sae_metadata(runner_cfg) + return runner_cfg def build_jumprelu_sae_cfg(**kwargs: Any) -> JumpReLUSAEConfig: @@ -340,11 +346,13 @@ def build_gated_runner_cfg( "apply_b_dec_to_input": False, "l1_warm_up_steps": 0, } - return _build_runner_config( + runner_cfg = _build_runner_config( GatedTrainingSAEConfig, cast(dict[str, Any], default_sae_config), **kwargs, ) + _update_sae_metadata(runner_cfg) + return runner_cfg def build_gated_sae_cfg(**kwargs: Any) -> GatedSAEConfig: @@ -385,11 +393,13 @@ def build_topk_runner_cfg( # Update the default config *before* passing it to _build_runner_config 
final_default_sae_config = cast(dict[str, Any], temp_sae_config) - return _build_runner_config( + runner_cfg = _build_runner_config( TopKTrainingSAEConfig, final_default_sae_config, **kwargs, ) + _update_sae_metadata(runner_cfg) + return runner_cfg def build_topk_sae_cfg(**kwargs: Any) -> TopKSAEConfig: @@ -432,11 +442,13 @@ def build_batchtopk_runner_cfg( # Update the default config *before* passing it to _build_runner_config final_default_sae_config = cast(dict[str, Any], temp_sae_config) - return _build_runner_config( + runner_cfg = _build_runner_config( BatchTopKTrainingSAEConfig, final_default_sae_config, **kwargs, ) + _update_sae_metadata(runner_cfg) + return runner_cfg def build_batchtopk_sae_training_cfg(**kwargs: Any) -> BatchTopKTrainingSAEConfig: diff --git a/tests/refactor_compatibility/test_topk_sae_equivalence.py b/tests/refactor_compatibility/test_topk_sae_equivalence.py index 95f2c7c4a..a5f5896ba 100644 --- a/tests/refactor_compatibility/test_topk_sae_equivalence.py +++ b/tests/refactor_compatibility/test_topk_sae_equivalence.py @@ -119,7 +119,6 @@ def test_topk_sae_inference_equivalence(): """ old_sae = make_old_topk_sae(d_in=16, d_sae=8, use_error_term=False) new_sae = make_new_topk_sae(d_in=16, d_sae=8, use_error_term=False) - new_sae.disable_sparse_activations() compare_params(old_sae, new_sae) @@ -151,7 +150,6 @@ def test_topk_sae_inference_equivalence(): # Now test with error_term old_sae_err = make_old_topk_sae(d_in=16, d_sae=8, use_error_term=True) new_sae_err = make_new_topk_sae(d_in=16, d_sae=8, use_error_term=True) - new_sae_err.disable_sparse_activations() # Align error term model parameters with torch.no_grad(): @@ -184,7 +182,6 @@ def test_topk_sae_run_with_cache_equivalence(): # type: ignore old_sae = make_old_topk_sae() new_sae = make_new_topk_sae() - new_sae.disable_sparse_activations() # Ensure parameters are identical before comparing outputs with torch.no_grad(): @@ -272,7 +269,6 @@ def test_topk_sae_training_equivalence(): 
d_sae=8, ) new_training_sae = TopKTrainingSAE(new_training_cfg) - new_training_sae.disable_sparse_activations() # Compare param shapes using updated compare_params compare_params(old_training_sae, new_training_sae) @@ -317,7 +313,7 @@ def test_topk_sae_training_equivalence(): ) assert_close( old_out.feature_acts, - new_out.feature_acts, + new_out.feature_acts.to_dense(), atol=1e-5, msg="Training feature_acts differ between old and new implementations.", ) diff --git a/tests/saes/test_topk_sae.py b/tests/saes/test_topk_sae.py index 69973ece1..5ba9f62a1 100644 --- a/tests/saes/test_topk_sae.py +++ b/tests/saes/test_topk_sae.py @@ -1,6 +1,7 @@ import os from pathlib import Path +import numpy as np import pytest import torch from sparsify import SparseCoder, SparseCoderConfig @@ -59,7 +60,7 @@ def test_TopKTrainingSAE_topk_aux_loss_matches_unnormalized_sparsify_implementat normalization = input_var / input_acts.shape[0] raw_aux_loss = sae_out.losses["auxiliary_reconstruction_loss"].item() # type: ignore norm_aux_loss = raw_aux_loss / normalization - assert norm_aux_loss == pytest.approx(comparison_aux_loss, abs=1e-2) + assert norm_aux_loss == pytest.approx(comparison_aux_loss, abs=3e-2) def test_TopKSAE_save_and_load_from_pretrained(tmp_path: Path) -> None: @@ -139,45 +140,45 @@ def test_TopKTrainingSAE_save_and_load_inference_sae(tmp_path: Path) -> None: inference_sae_out = inference_sae.decode(inference_feature_acts) # Should produce identical outputs - assert_close(training_feature_acts, inference_feature_acts) - assert_close(training_sae_out, inference_sae_out) + assert_close(training_feature_acts.to_dense(), inference_feature_acts) + assert_close(training_sae_out, inference_sae_out, rtol=1e-4, atol=1e-4) # Test the full forward pass training_full_out = training_sae(sae_in) inference_full_out = inference_sae(sae_in) - assert_close(training_full_out, inference_full_out) + assert_close(training_full_out, inference_full_out, rtol=1e-4, atol=1e-4) -def 
test_topK_sparse_activations(): +@pytest.mark.parametrize("num_dims", [1, 2, 3, 4, 5]) +def test_topK_sparse_activations(num_dims: bool): # Validate that the sparse top-K intermediate output (COO format) # we use to accelerate the decoder matches the dense top-K output. - d_sae = 1024 - M = 128 - B = 16 + dims = (np.arange(1, num_dims + 1) + 3).tolist() + dims[-1] = 1024 for k in [1, 10, 100, 1000]: topk_sparse = TopK(k, use_sparse_activations=True) topk_dense = TopK(k, use_sparse_activations=False) - x = torch.randn(B, M, d_sae) + 50.0 + x = torch.randn(*dims) + 50.0 sparse_x = topk_sparse(x) assert sparse_x.is_sparse - assert sparse_x.coalesce().values().numel() == B * M * k - sparse_x = sparse_x.to_dense().reshape(B, M, d_sae) + sparse_x = sparse_x.to_dense() dense_x = topk_dense(x) assert_close(dense_x, sparse_x) -def test_topK_activation_sparse_mm(): +@pytest.mark.parametrize("num_dims", [1, 2, 3, 4, 5]) +def test_topK_activation_sparse_mm(num_dims: int): # Validate that our decoder produces the same output when using the sparse intermediates # as when using the dense intermediates. d_in = 128 d_sae = 1024 - M = 128 + dims = (np.arange(1, num_dims + 1) + 3).tolist() + dims[-1] = d_sae cfg = build_topk_sae_training_cfg( d_in=d_in, d_sae=d_sae, k=26, - decoder_init_norm=1.0, # TODO: why is this needed?? 
) sae = TopKTrainingSAE(cfg) @@ -191,7 +192,7 @@ def test_topK_activation_sparse_mm(): for k in [1, 10, 100, 1000]: topk_sparse = TopK(k, use_sparse_activations=True) topk_dense = TopK(k, use_sparse_activations=False) - x = torch.randn(M, d_sae) + 50.0 + x = torch.randn(*dims) + 50.0 sparse_x = topk_sparse(x) sae_out_sparse = sae.decode(sparse_x) dense_x = topk_dense(x) @@ -199,18 +200,8 @@ def test_topK_activation_sparse_mm(): assert_close(sae_out_sparse, sae_out_dense, rtol=1e-4, atol=5e-4) -def test_topK_sparse_activations_config(): +def test_TopKTrainingSAE_sparse_activations_config(): # Check that our config is respected in both training & inference SAEs - cfg = build_topk_sae_cfg(k=100, use_sparse_activations=True) - sae = TopKSAE(cfg) - assert sae.activation_fn.use_sparse_activations # type: ignore - assert sae.cfg.use_sparse_activations - - cfg = build_topk_sae_cfg(k=100, use_sparse_activations=False) - sae = TopKSAE(cfg) - assert not sae.activation_fn.use_sparse_activations # type: ignore - assert not sae.cfg.use_sparse_activations - cfg = build_topk_sae_training_cfg(k=100, use_sparse_activations=True) sae = TopKTrainingSAE(cfg) assert sae.activation_fn.use_sparse_activations # type: ignore diff --git a/tests/test_evals.py b/tests/test_evals.py index 9081e01a6..34eb21064 100644 --- a/tests/test_evals.py +++ b/tests/test_evals.py @@ -29,12 +29,14 @@ from sae_lens.loading.pretrained_saes_directory import PretrainedSAELookup from sae_lens.saes.sae import SAE, TrainingSAE from sae_lens.saes.standard_sae import StandardSAE, StandardTrainingSAE +from sae_lens.saes.topk_sae import TopKTrainingSAE from sae_lens.training.activation_scaler import ActivationScaler from sae_lens.training.activations_store import ActivationsStore from tests.helpers import ( NEEL_NANDA_C4_10K_DATASET, TINYSTORIES_MODEL, build_runner_cfg, + build_topk_runner_cfg, load_model_cached, ) @@ -170,6 +172,34 @@ def test_run_evals_base_sae( assert len(eval_metrics) > 0 
+@pytest.mark.parametrize("use_sparse_activations", [True, False]) +def test_run_evals_sparse_topk_sae( + model: HookedTransformer, + use_sparse_activations: bool, +): + cfg = build_topk_runner_cfg( + use_sparse_activations=use_sparse_activations, + model_name="tiny-stories-1M", + dataset_path="roneneldan/TinyStories", + hook_name="blocks.1.hook_resid_pre", + d_in=64, + ) + sae = TopKTrainingSAE(cfg.sae) + activation_store = ActivationsStore.from_config( + model, cfg, override_dataset=Dataset.from_list([{"text": "hello world"}] * 2000) + ) + eval_metrics, _ = run_evals( + sae=sae, + activation_store=activation_store, + activation_scaler=ActivationScaler(), + model=model, + eval_config=get_eval_everything_config(), + ) + + assert set(eval_metrics.keys()).issubset(set(all_possible_keys)) + assert len(eval_metrics) > 0 + + def test_run_evals_training_sae( training_sae: TrainingSAE[Any], activation_store: ActivationsStore, diff --git a/tests/test_util.py b/tests/test_util.py index 99531c96c..72c0f008b 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -2,7 +2,10 @@ import pytest -from sae_lens.util import extract_stop_at_layer_from_tlens_hook_name, path_or_tmp_dir +from sae_lens.util import ( + extract_stop_at_layer_from_tlens_hook_name, + path_or_tmp_dir, +) @pytest.mark.parametrize( From f722bbbe231e6c1785e807b65a9618ee88230fec Mon Sep 17 00:00:00 2001 From: David Chanin Date: Tue, 30 Sep 2025 23:39:02 +0100 Subject: [PATCH 10/19] fix logging with sparse feature acts --- sae_lens/saes/topk_sae.py | 17 ----------------- sae_lens/training/sae_trainer.py | 2 +- tests/training/test_sae_trainer.py | 4 ++++ 3 files changed, 5 insertions(+), 18 deletions(-) diff --git a/sae_lens/saes/topk_sae.py b/sae_lens/saes/topk_sae.py index e5644abf4..8a8b53881 100644 --- a/sae_lens/saes/topk_sae.py +++ b/sae_lens/saes/topk_sae.py @@ -235,23 +235,6 @@ def decode( sae_out_pre = self.run_time_activation_norm_fn_out(sae_out_pre) return self.reshape_fn_out(sae_out_pre, 
self.d_head) - @override - def forward(self, x: torch.Tensor) -> torch.Tensor: - """Forward pass through the SAE.""" - feature_acts = self.encode(x) - sae_out = self.decode(feature_acts) - - if self.use_error_term: - with torch.no_grad(): - # Recompute without hooks for true error term - with _disable_hooks(self): - feature_acts_clean = self.encode(x) - x_reconstruct_clean = self.decode(feature_acts_clean) - sae_error = self.hook_sae_error(x - x_reconstruct_clean) - sae_out = sae_out + sae_error - - return self.hook_sae_output(sae_out) - @override def get_activation_fn(self) -> Callable[[torch.Tensor], torch.Tensor]: return TopK(self.cfg.k, use_sparse_activations=False) diff --git a/sae_lens/training/sae_trainer.py b/sae_lens/training/sae_trainer.py index eaa064ab2..79e1b83ea 100644 --- a/sae_lens/training/sae_trainer.py +++ b/sae_lens/training/sae_trainer.py @@ -312,7 +312,7 @@ def _build_train_step_log_dict( loss = output.loss.item() # metrics for currents acts - l0 = (feature_acts > 0).float().sum(-1).mean() + l0 = feature_acts.bool().float().sum(-1).to_dense().mean() current_learning_rate = self.optimizer.param_groups[0]["lr"] per_token_l2_loss = (sae_out - sae_in).pow(2).sum(dim=-1).squeeze() diff --git a/tests/training/test_sae_trainer.py b/tests/training/test_sae_trainer.py index c1b2fa774..7583335b8 100644 --- a/tests/training/test_sae_trainer.py +++ b/tests/training/test_sae_trainer.py @@ -179,8 +179,10 @@ def test_log_feature_sparsity__handles_zeroes_by_default_fp16() -> None: assert _log_feature_sparsity(fp16_zeroes).item() != float("-inf") +@pytest.mark.parametrize("sparse_feature_acts", [True, False]) def test_build_train_step_log_dict( trainer: SAETrainer[StandardTrainingSAE, StandardTrainingSAEConfig], + sparse_feature_acts: bool, ) -> None: train_output = TrainStepOutput( sae_in=torch.tensor([[-1, 0], [0, 2], [1, 1]]).float(), @@ -196,6 +198,8 @@ def test_build_train_step_log_dict( "topk_threshold": torch.tensor(0.5), }, ) + if sparse_feature_acts: 
+ train_output.feature_acts = train_output.feature_acts.to_sparse_coo() # we're relying on the trainer only for some of the metrics here # we should more / less try to break this and push From 2c086e7099ec9deab4b95408d2652469fabb729e Mon Sep 17 00:00:00 2001 From: David Chanin Date: Tue, 30 Sep 2025 23:43:33 +0100 Subject: [PATCH 11/19] switch TopK to use dense tensors by default in case users are extending / using this module --- sae_lens/saes/topk_sae.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sae_lens/saes/topk_sae.py b/sae_lens/saes/topk_sae.py index 8a8b53881..89f44c6d0 100644 --- a/sae_lens/saes/topk_sae.py +++ b/sae_lens/saes/topk_sae.py @@ -51,7 +51,7 @@ class TopK(nn.Module): def __init__( self, k: int, - use_sparse_activations: bool = True, + use_sparse_activations: bool = False, ): super().__init__() self.k = k From bdd8411b693297f622f4595593de8a2eae122278 Mon Sep 17 00:00:00 2001 From: David Chanin Date: Wed, 1 Oct 2025 00:02:35 +0100 Subject: [PATCH 12/19] bust CI cache to hopefully get CI to not run out of disk... 
--- .github/workflows/build.yml | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index aa0a6f2ee..0dd39a90e 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -32,16 +32,16 @@ jobs: - name: Cache Huggingface assets uses: actions/cache@v4 with: - key: huggingface-4-${{ runner.os }}-${{ matrix.python-version }}-${{ hashFiles('**/pyproject.toml') }} + key: huggingface-5-${{ runner.os }}-${{ matrix.python-version }}-${{ hashFiles('**/pyproject.toml') }} path: ~/.cache/huggingface restore-keys: | - huggingface-4-${{ runner.os }}-${{ matrix.python-version }}- + huggingface-5-${{ runner.os }}-${{ matrix.python-version }}- - name: Load cached Poetry installation id: cached-poetry uses: actions/cache@v4 with: path: ~/.local # the path depends on the OS - key: poetry-${{ runner.os }}-${{ matrix.python-version }}-3 # increment to reset cache + key: poetry-${{ runner.os }}-${{ matrix.python-version }}-5 # increment to reset cache - name: Install Poetry if: steps.cached-poetry.outputs.cache-hit != 'true' uses: snok/install-poetry@v1 @@ -54,9 +54,9 @@ jobs: uses: actions/cache@v4 with: path: .venv - key: venv-2-${{ runner.os }}-${{ matrix.python-version }}-${{ hashFiles('**/pyproject.toml') }} + key: venv-5-${{ runner.os }}-${{ matrix.python-version }}-${{ hashFiles('**/pyproject.toml') }} restore-keys: | - venv-2-${{ runner.os }}-${{ matrix.python-version }}- + venv-5-${{ runner.os }}-${{ matrix.python-version }}- - name: Install dependencies if: steps.cached-poetry-dependencies.outputs.cache-hit != 'true' run: poetry install --no-interaction From 260aba2b22e8caa13c4112a124efc0fe3aa69f7a Mon Sep 17 00:00:00 2001 From: David Chanin Date: Wed, 1 Oct 2025 10:39:26 +0100 Subject: [PATCH 13/19] disable autocast for sparse.mm --- sae_lens/saes/topk_sae.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/sae_lens/saes/topk_sae.py
b/sae_lens/saes/topk_sae.py index 89f44c6d0..4cf137b9b 100644 --- a/sae_lens/saes/topk_sae.py +++ b/sae_lens/saes/topk_sae.py @@ -146,7 +146,9 @@ def _sparse_matmul_nd( if sparse_tensor.ndim == 2: # Simple 2D case - use torch.sparse.mm directly - return torch.sparse.mm(sparse_tensor, dense_matrix) + # sparse.mm errors with bfloat16 :( + with torch.autocast(device_type=sparse_tensor.device.type, enabled=False): + return torch.sparse.mm(sparse_tensor, dense_matrix) # For 3D+ case, reshape to 2D, multiply, then reshape back batch_size = int(torch.prod(torch.tensor(batch_dims)).item()) @@ -174,8 +176,10 @@ def _sparse_matmul_nd( sparse_2d_indices, values, (batch_size, d_sae) ).coalesce() - # Do the matrix multiplication - result_2d = torch.sparse.mm(sparse_2d, dense_matrix) # [batch_size, d_out] + # sparse.mm errors with bfloat16 :( + with torch.autocast(device_type=sparse_tensor.device.type, enabled=False): + # Do the matrix multiplication + result_2d = torch.sparse.mm(sparse_2d, dense_matrix) # [batch_size, d_out] # Reshape back to original batch dimensions result_shape = tuple(batch_dims) + (d_out,) From 41393b27d94e94d5e7bf33b9f451de12376110ea Mon Sep 17 00:00:00 2001 From: David Chanin Date: Sat, 4 Oct 2025 14:43:37 +0100 Subject: [PATCH 14/19] default TopK SAEs to disable sparse training until we can improve performance --- sae_lens/saes/topk_sae.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sae_lens/saes/topk_sae.py b/sae_lens/saes/topk_sae.py index 4cf137b9b..be117f795 100644 --- a/sae_lens/saes/topk_sae.py +++ b/sae_lens/saes/topk_sae.py @@ -258,7 +258,7 @@ class TopKTrainingSAEConfig(TrainingSAEConfig): """ k: int = 100 - use_sparse_activations: bool = True + use_sparse_activations: bool = False aux_loss_coefficient: float = 1.0 @override From d0356385a1d96d3be23e71facc0f6f765115e092 Mon Sep 17 00:00:00 2001 From: David Chanin Date: Sat, 4 Oct 2025 14:50:32 +0100 Subject: [PATCH 15/19] updating docs for topk config --- 
sae_lens/saes/topk_sae.py | 33 +++++++++++++++++++++++++++++++++ 1 file changed, 33 insertions(+) diff --git a/sae_lens/saes/topk_sae.py b/sae_lens/saes/topk_sae.py index be117f795..9e6e85923 100644 --- a/sae_lens/saes/topk_sae.py +++ b/sae_lens/saes/topk_sae.py @@ -255,6 +255,39 @@ def fold_W_dec_norm(self) -> None: class TopKTrainingSAEConfig(TrainingSAEConfig): """ Configuration class for training a TopKTrainingSAE. + + Args: + k (int): Number of top features to keep active. Only the top k features + with the highest pre-activations will be non-zero. Defaults to 100. + use_sparse_activations (bool): Whether to use sparse tensor representations + for activations during training. This can reduce memory usage and improve + performance when k is small relative to d_sae, but is only worthwhile if + using float32 and not using autocast. Defaults to False. + aux_loss_coefficient (float): Coefficient for the auxiliary loss that encourages + dead neurons to learn useful features. This loss helps prevent neuron death + in TopK SAEs by having dead neurons reconstruct the residual error from + live neurons. Defaults to 1.0. + decoder_init_norm (float | None): Norm to initialize decoder weights to. + 0.1 corresponds to the "heuristic" initialization from Anthropic's April update. + Use None to disable. Inherited from TrainingSAEConfig. Defaults to 0.1. + d_in (int): Input dimension (dimensionality of the activations being encoded). + Inherited from SAEConfig. + d_sae (int): SAE latent dimension (number of features in the SAE). + Inherited from SAEConfig. + dtype (str): Data type for the SAE parameters. Inherited from SAEConfig. + Defaults to "float32". + device (str): Device to place the SAE on. Inherited from SAEConfig. + Defaults to "cpu". + apply_b_dec_to_input (bool): Whether to apply decoder bias to the input + before encoding. Inherited from SAEConfig. Defaults to True. 
+ normalize_activations (Literal["none", "expected_average_only_in", "constant_norm_rescale", "layer_norm"]): + Normalization strategy for input activations. Inherited from SAEConfig. + Defaults to "none". + reshape_activations (Literal["none", "hook_z"]): How to reshape activations + (useful for attention head outputs). Inherited from SAEConfig. + Defaults to "none". + metadata (SAEMetadata): Metadata about the SAE training (model name, hook name, etc.). + Inherited from SAEConfig. """ k: int = 100 From 6f52a0130ae2c0fa45469077ed4edd13633f56fd Mon Sep 17 00:00:00 2001 From: wz-ml Date: Sun, 5 Oct 2025 14:01:34 -0700 Subject: [PATCH 16/19] Update benchmark to not report FP breakdown by default (faster run) --- benchmark/bench_fwd_perf.py | 64 +++++++++++++++++++++++++------------ 1 file changed, 44 insertions(+), 20 deletions(-) diff --git a/benchmark/bench_fwd_perf.py b/benchmark/bench_fwd_perf.py index cf41e0d22..720bd813d 100644 --- a/benchmark/bench_fwd_perf.py +++ b/benchmark/bench_fwd_perf.py @@ -25,6 +25,18 @@ help="Shape of the input tensor (seq_len, d_in, d_sae)", ) parser.add_argument("--k", type=int, default=100, help="Number of topk elements") +parser.add_argument( + "--dtype", + type=str, + default="float64", + help="Datatype (bfloat16, float16, float32, float64)", +) +parser.add_argument( + "-bd", + "--show-breakdown", + action="store_true", + help="Time individual components of the forward pass.", +) args = parser.parse_args() device = args.device @@ -42,6 +54,7 @@ k=k, device=device, use_sparse_activations=True, + dtype="float64", ) cfg_dense = build_topk_sae_training_cfg( d_in=d_in, @@ -49,6 +62,7 @@ k=k, device=device, use_sparse_activations=False, + dtype="float64", ) sae_sparse = TopKTrainingSAE(cfg_sparse) @@ -93,37 +107,47 @@ def triton_bench(fn: Callable[[], Any]) -> float: def benchmark_sae(sae: TopKTrainingSAE) -> dict[str, float]: results = {} - results["encode_proj"] = triton_bench(lambda: encode_proj(sae, input_acts)) - hidden_pre = 
encode_proj(sae, input_acts) - results["topk_activation"] = triton_bench(lambda: topk_activation(sae, hidden_pre)) - feature_acts = topk_activation(sae, hidden_pre) - results["decode_step"] = triton_bench(lambda: decode_step(sae, feature_acts)) - sae_out = decode_step(sae, feature_acts) - results["loss_computation"] = triton_bench( - lambda: loss_computation(sae, sae_out, input_acts) - ) results["full_forward_pass"] = triton_bench( lambda: sae.training_forward_pass(step_input) ) - results["other"] = 2 * results["full_forward_pass"] - sum(results.values()) # type: ignore + if args.show_breakdown: + results["encode_proj"] = triton_bench(lambda: encode_proj(sae, input_acts)) + hidden_pre = encode_proj(sae, input_acts) + results["topk_activation"] = triton_bench( + lambda: topk_activation(sae, hidden_pre) + ) + feature_acts = topk_activation(sae, hidden_pre) + results["decode_step"] = triton_bench(lambda: decode_step(sae, feature_acts)) + sae_out = decode_step(sae, feature_acts) + results["loss_computation"] = triton_bench( + lambda: loss_computation(sae, sae_out, input_acts) + ) + results["other"] = 2 * results["full_forward_pass"] - sum(results.values()) return results if __name__ == "__main__": - print("This may take a while (5 mins). Go grab a coffee!") + if args.show_breakdown: + print("This may take a while (5 mins). 
Go grab a coffee!") results_sparse = benchmark_sae(sae_sparse) results_dense = benchmark_sae(sae_dense) # Pretty print results table with metrics as columns - headers = [ - "Implementation", - "Encode", - "TopK", - "Decode", - "Loss Calc", - "Full Fwd", - "Other", - ] + if args.show_breakdown: + headers = [ + "Implementation", + "Encode", + "TopK", + "Decode", + "Loss Calc", + "Full Fwd", + "Other", + ] + else: + headers = [ + "Implementation", + "Full Fwd", + ] metric_keys = results_sparse.keys() From f497885ccc2c75e4097047f2ea5ceec7d9845871 Mon Sep 17 00:00:00 2001 From: wz-ml Date: Sun, 5 Oct 2025 14:43:39 -0700 Subject: [PATCH 17/19] Update topK activation & decode to use sparse CSR format --- sae_lens/saes/topk_sae.py | 105 ++++++++---------------------------- tests/saes/test_topk_sae.py | 3 +- 2 files changed, 25 insertions(+), 83 deletions(-) diff --git a/sae_lens/saes/topk_sae.py b/sae_lens/saes/topk_sae.py index 9e6e85923..0451c5754 100644 --- a/sae_lens/saes/topk_sae.py +++ b/sae_lens/saes/topk_sae.py @@ -69,47 +69,27 @@ def forward( topk_values, topk_indices = torch.topk(x, k=self.k, dim=-1, sorted=False) values = topk_values.relu() if self.use_sparse_activations: - # Produce a COO sparse tensor (use sparse matrix multiply in decode) - original_shape = x.shape - - # Create indices for all dimensions - # For each element in topk_indices, we need to map it back to the original tensor coordinates - batch_dims = original_shape[:-1] # All dimensions except the last one - num_batch_elements = torch.prod(torch.tensor(batch_dims)).item() - - # Create batch indices - each batch element repeated k times - batch_indices_flat = torch.arange( - num_batch_elements, device=x.device - ).repeat_interleave(self.k) - + # Produce a CSR sparse tensor (use sparse matrix multiply in decode) # Convert flat batch indices back to multi-dimensional indices - if len(batch_dims) == 1: - # 2D case: [batch, features] - sparse_indices = torch.stack( - [ - batch_indices_flat, - 
topk_indices.flatten(), - ] - ) - else: - # 3D+ case: need to unravel the batch indices - batch_indices_multi = [] - remaining = batch_indices_flat - for dim_size in reversed(batch_dims): - batch_indices_multi.append(remaining % dim_size) - remaining = remaining // dim_size - batch_indices_multi.reverse() - - sparse_indices = torch.stack( - [ - *batch_indices_multi, - topk_indices.flatten(), - ] - ) - - return torch.sparse_coo_tensor( - sparse_indices, values.flatten(), original_shape + batch_dims = x.shape[:-1] + prod_batch_dims = int(torch.prod(torch.tensor(batch_dims)).item()) + csr_tensor = torch.sparse_csr_tensor( + torch.arange( + start=0, + end=len(topk_indices.flatten()) + self.k, + step=self.k, + device=x.device, + ), + topk_indices.flatten(), + topk_values.flatten(), + dtype=x.dtype, + device=x.device, + size=(prod_batch_dims, x.shape[-1]), ) + # A little hacky - let me know if you think of a better way to do this. - Will + csr_tensor.batch_dims = batch_dims # type: ignore + return csr_tensor + result = torch.zeros_like(x) result.scatter_(-1, topk_indices, values) return result @@ -135,51 +115,12 @@ def _sparse_matmul_nd( """ Multiply a sparse tensor of shape [..., d_sae] with a dense matrix of shape [d_sae, d_out] to get a result of shape [..., d_out]. - - This function handles sparse tensors with arbitrary batch dimensions by flattening - the batch dimensions, performing 2D sparse matrix multiplication, and reshaping back. 
""" - original_shape = sparse_tensor.shape - batch_dims = original_shape[:-1] - d_sae = original_shape[-1] + batch_dims = sparse_tensor.batch_dims # type: ignore d_out = dense_matrix.shape[-1] - if sparse_tensor.ndim == 2: - # Simple 2D case - use torch.sparse.mm directly - # sparse.mm errors with bfloat16 :( - with torch.autocast(device_type=sparse_tensor.device.type, enabled=False): - return torch.sparse.mm(sparse_tensor, dense_matrix) - - # For 3D+ case, reshape to 2D, multiply, then reshape back - batch_size = int(torch.prod(torch.tensor(batch_dims)).item()) - - # Ensure tensor is coalesced for efficient access to indices/values - if not sparse_tensor.is_coalesced(): - sparse_tensor = sparse_tensor.coalesce() - - # Get indices and values - indices = sparse_tensor.indices() # [ndim, nnz] - values = sparse_tensor.values() # [nnz] - - # Convert multi-dimensional batch indices to flat indices - flat_batch_indices = torch.zeros_like(indices[0]) - multiplier = 1 - for i in reversed(range(len(batch_dims))): - flat_batch_indices += indices[i] * multiplier - multiplier *= batch_dims[i] - - # Create 2D sparse tensor indices [batch_flat, feature] - sparse_2d_indices = torch.stack([flat_batch_indices, indices[-1]]) - - # Create 2D sparse tensor - sparse_2d = torch.sparse_coo_tensor( - sparse_2d_indices, values, (batch_size, d_sae) - ).coalesce() - - # sparse.mm errors with bfloat16 :( - with torch.autocast(device_type=sparse_tensor.device.type, enabled=False): - # Do the matrix multiplication - result_2d = torch.sparse.mm(sparse_2d, dense_matrix) # [batch_size, d_out] + # Do the matrix multiplication + result_2d = torch.sparse.mm(sparse_tensor, dense_matrix) # [batch_size, d_out] # Reshape back to original batch dimensions result_shape = tuple(batch_dims) + (d_out,) @@ -340,7 +281,7 @@ def decode( applying optional finetuning scale, hooking, out normalization, etc. 
""" # Handle sparse tensors using efficient sparse matrix multiplication - if feature_acts.is_sparse: + if feature_acts.is_sparse or feature_acts.is_sparse_csr: sae_out_pre = _sparse_matmul_nd(feature_acts, self.W_dec) + self.b_dec else: sae_out_pre = feature_acts @ self.W_dec + self.b_dec diff --git a/tests/saes/test_topk_sae.py b/tests/saes/test_topk_sae.py index 5ba9f62a1..6f0d66f8d 100644 --- a/tests/saes/test_topk_sae.py +++ b/tests/saes/test_topk_sae.py @@ -160,8 +160,9 @@ def test_topK_sparse_activations(num_dims: bool): topk_dense = TopK(k, use_sparse_activations=False) x = torch.randn(*dims) + 50.0 sparse_x = topk_sparse(x) - assert sparse_x.is_sparse + assert sparse_x.is_sparse or sparse_x.is_sparse_csr sparse_x = sparse_x.to_dense() + sparse_x = sparse_x.reshape(x.shape) dense_x = topk_dense(x) assert_close(dense_x, sparse_x) From 1fb85020c4494fb9e800d73355706cdde9f7fbe3 Mon Sep 17 00:00:00 2001 From: wz-ml Date: Sun, 5 Oct 2025 15:04:04 -0700 Subject: [PATCH 18/19] Add CSR support to evals. 
Note: Having intended target shape as temporary tensor attribute is a little hacky - am open to other ways to implement this :) --- sae_lens/evals.py | 6 +++++- sae_lens/saes/topk_sae.py | 5 +++-- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/sae_lens/evals.py b/sae_lens/evals.py index 1e7e63d79..12033ef8b 100644 --- a/sae_lens/evals.py +++ b/sae_lens/evals.py @@ -466,8 +466,12 @@ def get_sparsity_and_variance_metrics( sae_out_scaled = sae.decode(sae_feature_activations).to( original_act_scaled.device ) - if sae_feature_activations.is_sparse: + if sae_feature_activations.is_sparse or sae_feature_activations.is_sparse_csr: + batch_dims = sae_feature_activations.batch_dims # type: ignore sae_feature_activations = sae_feature_activations.to_dense() + sae_feature_activations = sae_feature_activations.reshape( + batch_dims + (sae_feature_activations.shape[-1],) + ) del cache sae_out = activation_scaler.unscale(sae_out_scaled) diff --git a/sae_lens/saes/topk_sae.py b/sae_lens/saes/topk_sae.py index 0451c5754..5313c34ac 100644 --- a/sae_lens/saes/topk_sae.py +++ b/sae_lens/saes/topk_sae.py @@ -35,8 +35,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: using_hooks = ( self._forward_hooks is not None and len(self._forward_hooks) > 0 ) or (self._backward_hooks is not None and len(self._backward_hooks) > 0) - if using_hooks and x.is_sparse: - return x.to_dense() + if using_hooks and (x.is_sparse or x.is_sparse_csr): + dense_x = x.to_dense() + return dense_x.reshape(x.batch_dims + (x.shape[-1],)) # type: ignore return x # if no hooks are being used, use passthrough From f46491895c11d24b4efbc9435344df5a71a0d751 Mon Sep 17 00:00:00 2001 From: wz-ml Date: Sun, 5 Oct 2025 21:25:23 -0700 Subject: [PATCH 19/19] Fix issue with bench script s.t dtype param's used properly --- benchmark/bench_fwd_perf.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/benchmark/bench_fwd_perf.py b/benchmark/bench_fwd_perf.py index 720bd813d..e9dbf377e 
100644 --- a/benchmark/bench_fwd_perf.py +++ b/benchmark/bench_fwd_perf.py @@ -28,7 +28,7 @@ parser.add_argument( "--dtype", type=str, - default="float64", + default="float32", help="Datatype (bfloat16, float16, float32, float64)", ) parser.add_argument( @@ -54,7 +54,7 @@ k=k, device=device, use_sparse_activations=True, - dtype="float64", + dtype=args.dtype, ) cfg_dense = build_topk_sae_training_cfg( d_in=d_in, @@ -62,7 +62,7 @@ k=k, device=device, use_sparse_activations=False, - dtype="float64", + dtype=args.dtype, ) sae_sparse = TopKTrainingSAE(cfg_sparse) @@ -131,6 +131,7 @@ def benchmark_sae(sae: TopKTrainingSAE) -> dict[str, float]: print("This may take a while (5 mins). Go grab a coffee!") results_sparse = benchmark_sae(sae_sparse) results_dense = benchmark_sae(sae_dense) + speedup = results_dense["full_forward_pass"] / results_sparse["full_forward_pass"] # Pretty print results table with metrics as columns if args.show_breakdown: @@ -156,4 +157,5 @@ def benchmark_sae(sae: TopKTrainingSAE) -> dict[str, float]: ["Dense"] + [f"{results_dense[key]:.3f}" for key in metric_keys], ] print("Metric: Latency (ms)") + print(f"Speedup: {speedup:.3f}") print("\n" + tabulate(table_data, headers=headers, tablefmt="grid"))