diff --git a/benchmark/bench_fwd_perf.py b/benchmark/bench_fwd_perf.py index cf41e0d22..7b0eee4d0 100644 --- a/benchmark/bench_fwd_perf.py +++ b/benchmark/bench_fwd_perf.py @@ -25,6 +25,18 @@ help="Shape of the input tensor (seq_len, d_in, d_sae)", ) parser.add_argument("--k", type=int, default=100, help="Number of topk elements") +parser.add_argument( + "--dtype", + type=str, + default="float32", + help="Datatype (bfloat16, float16, float32, float64)", +) +parser.add_argument( + "-bd", + "--show-breakdown", + action="store_true", + help="Time individual components of the forward pass.", +) args = parser.parse_args() device = args.device @@ -42,6 +54,7 @@ k=k, device=device, use_sparse_activations=True, + dtype=args.dtype, ) cfg_dense = build_topk_sae_training_cfg( d_in=d_in, @@ -49,6 +62,7 @@ k=k, device=device, use_sparse_activations=False, + dtype=args.dtype, ) sae_sparse = TopKTrainingSAE(cfg_sparse) @@ -113,6 +127,7 @@ def benchmark_sae(sae: TopKTrainingSAE) -> dict[str, float]: print("This may take a while (5 mins). Go grab a coffee!") results_sparse = benchmark_sae(sae_sparse) results_dense = benchmark_sae(sae_dense) + speedup = results_dense["full_forward_pass"] / results_sparse["full_forward_pass"] # Pretty print results table with metrics as columns headers = [ @@ -132,4 +147,5 @@ def benchmark_sae(sae: TopKTrainingSAE) -> dict[str, float]: ["Dense"] + [f"{results_dense[key]:.3f}" for key in metric_keys], ] print("Metric: Latency (ms)") + print(f"Speedup: {speedup:.3f}") print("\n" + tabulate(table_data, headers=headers, tablefmt="grid")) diff --git a/sae_lens/evals.py b/sae_lens/evals.py index 1e7e63d79..f7c9f5c65 100644 --- a/sae_lens/evals.py +++ b/sae_lens/evals.py @@ -466,7 +466,7 @@ def get_sparsity_and_variance_metrics( sae_out_scaled = sae.decode(sae_feature_activations).to( original_act_scaled.device ) - if sae_feature_activations.is_sparse: + if sae_feature_activations.is_sparse or sae_feature_activations.is_sparse_csr: sae_feature_activations = sae_feature_activations.to_dense() del cache diff --git a/sae_lens/saes/topk_sae.py b/sae_lens/saes/topk_sae.py index 9e6e85923..9ee668730 100644 --- a/sae_lens/saes/topk_sae.py +++ b/sae_lens/saes/topk_sae.py @@ -1,7 +1,8 @@ """Inference-only TopKSAE variant, similar in spirit to StandardSAE but using a TopK-based activation.""" +from contextlib import contextmanager from dataclasses import dataclass -from typing import Callable +from typing import Callable, Generator import torch from jaxtyping import Float @@ -16,6 +17,7 @@ TrainingSAE, TrainingSAEConfig, TrainStepInput, + TrainStepOutput, _disable_hooks, ) @@ -35,7 +37,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: using_hooks = ( self._forward_hooks is not None and len(self._forward_hooks) > 0 ) or (self._backward_hooks is not None and len(self._backward_hooks) > 0) - if using_hooks and x.is_sparse: + if using_hooks and (x.is_sparse or x.is_sparse_csr): return x.to_dense() return x # if no hooks are being used, use passthrough @@ -69,47 +71,24 @@ def forward( topk_values, topk_indices = torch.topk(x, k=self.k, dim=-1, sorted=False) values = topk_values.relu() if self.use_sparse_activations: - # Produce a COO sparse tensor (use sparse matrix multiply in decode) - original_shape = x.shape - - # Create indices for all dimensions - # For each element in topk_indices, we need to map it back to the original tensor coordinates - batch_dims = original_shape[:-1] # All dimensions except the last one - num_batch_elements = torch.prod(torch.tensor(batch_dims)).item() - - # Create batch indices - each batch element repeated k times - batch_indices_flat = torch.arange( - num_batch_elements, device=x.device - ).repeat_interleave(self.k) - + # Produce a CSR sparse tensor (use sparse matrix multiply in decode) # Convert flat batch indices back to multi-dimensional indices - if len(batch_dims) == 1: - # 2D case: [batch, features] - sparse_indices = torch.stack( - [ - batch_indices_flat, - topk_indices.flatten(), - ] - ) - else: - # 3D+ case: need to unravel the batch indices - batch_indices_multi = [] - remaining = batch_indices_flat - for dim_size in reversed(batch_dims): - batch_indices_multi.append(remaining % dim_size) - remaining = remaining // dim_size - batch_indices_multi.reverse() - - sparse_indices = torch.stack( - [ - *batch_indices_multi, - topk_indices.flatten(), - ] - ) - - return torch.sparse_coo_tensor( - sparse_indices, values.flatten(), original_shape + if x.ndim != 2: + raise ValueError("Sparse activations are only supported for 2D tensors") + return torch.sparse_csr_tensor( + torch.arange( + start=0, + end=len(topk_indices.flatten()) + self.k, + step=self.k, + device=x.device, + ), + topk_indices.flatten(), + topk_values.flatten(), + dtype=x.dtype, + device=x.device, + size=x.shape, ) + result = torch.zeros_like(x) result.scatter_(-1, topk_indices, values) return result @@ -129,63 +108,6 @@ def architecture(cls) -> str: return "topk" -def _sparse_matmul_nd( - sparse_tensor: torch.Tensor, dense_matrix: torch.Tensor -) -> torch.Tensor: - """ - Multiply a sparse tensor of shape [..., d_sae] with a dense matrix of shape [d_sae, d_out] - to get a result of shape [..., d_out]. - - This function handles sparse tensors with arbitrary batch dimensions by flattening - the batch dimensions, performing 2D sparse matrix multiplication, and reshaping back. - """ - original_shape = sparse_tensor.shape - batch_dims = original_shape[:-1] - d_sae = original_shape[-1] - d_out = dense_matrix.shape[-1] - - if sparse_tensor.ndim == 2: - # Simple 2D case - use torch.sparse.mm directly - # sparse.mm errors with bfloat16 :( - with torch.autocast(device_type=sparse_tensor.device.type, enabled=False): - return torch.sparse.mm(sparse_tensor, dense_matrix) - - # For 3D+ case, reshape to 2D, multiply, then reshape back - batch_size = int(torch.prod(torch.tensor(batch_dims)).item()) - - # Ensure tensor is coalesced for efficient access to indices/values - if not sparse_tensor.is_coalesced(): - sparse_tensor = sparse_tensor.coalesce() - - # Get indices and values - indices = sparse_tensor.indices() # [ndim, nnz] - values = sparse_tensor.values() # [nnz] - - # Convert multi-dimensional batch indices to flat indices - flat_batch_indices = torch.zeros_like(indices[0]) - multiplier = 1 - for i in reversed(range(len(batch_dims))): - flat_batch_indices += indices[i] * multiplier - multiplier *= batch_dims[i] - - # Create 2D sparse tensor indices [batch_flat, feature] - sparse_2d_indices = torch.stack([flat_batch_indices, indices[-1]]) - - # Create 2D sparse tensor - sparse_2d = torch.sparse_coo_tensor( - sparse_2d_indices, values, (batch_size, d_sae) - ).coalesce() - - # sparse.mm errors with bfloat16 :( - with torch.autocast(device_type=sparse_tensor.device.type, enabled=False): - # Do the matrix multiplication - result_2d = torch.sparse.mm(sparse_2d, dense_matrix) # [batch_size, d_out] - - # Reshape back to original batch dimensions - result_shape = tuple(batch_dims) + (d_out,) - return result_2d.view(result_shape) - - class TopKSAE(SAE[TopKSAEConfig]): """ An inference-only sparse autoencoder using a "topk" activation function. @@ -231,8 +153,8 @@ def decode( and optional head reshaping. """ # Handle sparse tensors using efficient sparse matrix multiplication - if feature_acts.is_sparse: - sae_out_pre = _sparse_matmul_nd(feature_acts, self.W_dec) + self.b_dec + if feature_acts.is_sparse or feature_acts.is_sparse_csr: + sae_out_pre = torch.sparse.mm(feature_acts, self.W_dec) + self.b_dec else: sae_out_pre = feature_acts @ self.W_dec + self.b_dec sae_out_pre = self.hook_sae_recons(sae_out_pre) @@ -340,8 +262,8 @@ def decode( applying optional finetuning scale, hooking, out normalization, etc. """ # Handle sparse tensors using efficient sparse matrix multiplication - if feature_acts.is_sparse: - sae_out_pre = _sparse_matmul_nd(feature_acts, self.W_dec) + self.b_dec + if feature_acts.is_sparse or feature_acts.is_sparse_csr: + sae_out_pre = torch.sparse.mm(feature_acts, self.W_dec) + self.b_dec else: sae_out_pre = feature_acts @ self.W_dec + self.b_dec sae_out_pre = self.hook_sae_recons(sae_out_pre) @@ -391,12 +313,34 @@ def fold_W_dec_norm(self) -> None: @override def get_activation_fn(self) -> Callable[[torch.Tensor], torch.Tensor]: - return TopK(self.cfg.k, use_sparse_activations=self.cfg.use_sparse_activations) + return TopK(self.cfg.k) @override def get_coefficients(self) -> dict[str, TrainCoefficientConfig | float]: return {} + @override + def training_forward_pass( + self, + step_input: TrainStepInput, + ) -> TrainStepOutput: + with self.use_sparse_topk(self.cfg.use_sparse_activations): + return super().training_forward_pass(step_input) + + @contextmanager + def use_sparse_topk( + self, use_sparse_activations: bool = True + ) -> Generator[None, None, None]: + """ + Temporarily set use_sparse_activations attribute on the activation function to the given value. + """ + original_use_sparse_activations = getattr( + self.activation_fn, "use_sparse_activations", False + ) + self.activation_fn.use_sparse_activations = use_sparse_activations # type: ignore + yield + self.activation_fn.use_sparse_activations = original_use_sparse_activations # type: ignore + def calculate_topk_aux_loss( self, sae_in: torch.Tensor, diff --git a/sae_lens/training/sae_trainer.py b/sae_lens/training/sae_trainer.py index 79e1b83ea..d26d5beff 100644 --- a/sae_lens/training/sae_trainer.py +++ b/sae_lens/training/sae_trainer.py @@ -253,14 +253,21 @@ def _train_step( ) with torch.no_grad(): - # calling .bool() should be equivalent to .abs() > 0, and work with coo tensors + # calling .bool() should be equivalent to .abs() > 0, and work with sparse tensors firing_feats = train_step_output.feature_acts.bool().float() - did_fire = firing_feats.sum(-2).bool() - if did_fire.is_sparse: + # need to keepdim to avoid issues with sparse tensors + did_fire = firing_feats.sum(-2, keepdim=True).bool() + if did_fire.is_sparse or did_fire.is_sparse_csr: did_fire = did_fire.to_dense() + did_fire = did_fire.squeeze(-2) self.n_forward_passes_since_fired += 1 self.n_forward_passes_since_fired[did_fire] = 0 - self.act_freq_scores += firing_feats.sum(0) + # again, need to keepdim to avoid issues with sparse tensors + freq_deltas = firing_feats.sum(0, keepdim=True) + if freq_deltas.is_sparse or freq_deltas.is_sparse_csr: + freq_deltas = freq_deltas.to_dense() + freq_deltas = freq_deltas.squeeze(0) + self.act_freq_scores += freq_deltas self.n_frac_active_samples += self.cfg.train_batch_size_samples # Grad scaler will rescale gradients if autocast is enabled diff --git a/tests/helpers.py b/tests/helpers.py index e10f57770..af424776f 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -591,3 +591,14 @@ def random_params(model: torch.nn.Module) -> None: param.data = torch.rand_like(param) for buffer in model.buffers(): buffer.data = torch.rand_like(buffer) + + +@torch.no_grad() +def match_params(model1: torch.nn.Module, model2: torch.nn.Module) -> None: + """ + Match the parameters of two models. + """ + for param1, param2 in zip(model1.parameters(), model2.parameters()): + param1.data = param2.data + for buffer1, buffer2 in zip(model1.buffers(), model2.buffers()): + buffer1.data = buffer2.data diff --git a/tests/saes/test_topk_sae.py b/tests/saes/test_topk_sae.py index 5ba9f62a1..c74e585b7 100644 --- a/tests/saes/test_topk_sae.py +++ b/tests/saes/test_topk_sae.py @@ -1,7 +1,6 @@ import os from pathlib import Path -import numpy as np import pytest import torch from sparsify import SparseCoder, SparseCoderConfig @@ -12,6 +11,7 @@ assert_close, build_topk_sae_cfg, build_topk_sae_training_cfg, + match_params, ) @@ -63,6 +63,50 @@ def test_TopKTrainingSAE_topk_aux_loss_matches_unnormalized_sparsify_implementat assert norm_aux_loss == pytest.approx(comparison_aux_loss, abs=3e-2) +def test_TopKTrainingSAE_training_forward_pass_matches_with_and_without_sparse_activations(): + d_in = 128 + d_sae = 192 + k = 26 + sparse_cfg = build_topk_sae_training_cfg( + d_in=d_in, + d_sae=d_sae, + k=k, + use_sparse_activations=True, + ) + dense_cfg = build_topk_sae_training_cfg( + d_in=d_in, + d_sae=d_sae, + k=k, + use_sparse_activations=False, + ) + + sparse_sae = TopKTrainingSAE(sparse_cfg) + dense_sae = TopKTrainingSAE(dense_cfg) + + match_params(sparse_sae, dense_sae) + + dead_neuron_mask = torch.randn(d_sae) > 0.1 + input_acts = torch.randn(200, d_in) + + step_input = TrainStepInput( + sae_in=input_acts, + dead_neuron_mask=dead_neuron_mask, + coefficients={}, + ) + + sparse_sae_out = sparse_sae.training_forward_pass(step_input=step_input) + dense_sae_out = dense_sae.training_forward_pass(step_input=step_input) + + assert sparse_sae_out.feature_acts.is_sparse_csr + assert not dense_sae_out.feature_acts.is_sparse + assert not dense_sae_out.feature_acts.is_sparse_csr + + assert_close(sparse_sae_out.sae_out, dense_sae_out.sae_out) + assert_close(sparse_sae_out.loss, dense_sae_out.loss) + assert_close(sparse_sae_out.feature_acts.to_dense(), dense_sae_out.feature_acts) + assert_close(sparse_sae_out.hidden_pre, dense_sae_out.hidden_pre) + + def test_TopKSAE_save_and_load_from_pretrained(tmp_path: Path) -> None: cfg = build_topk_sae_cfg(k=30) model_path = str(tmp_path) @@ -149,31 +193,28 @@ def test_TopKTrainingSAE_save_and_load_inference_sae(tmp_path: Path) -> None: assert_close(training_full_out, inference_full_out, rtol=1e-4, atol=1e-4) -@pytest.mark.parametrize("num_dims", [1, 2, 3, 4, 5]) -def test_topK_sparse_activations(num_dims: bool): +def test_topK_sparse_activations_only_works_with_2d_tensors(): # Validate that the sparse top-K intermediate output (COO format) # we use to accelerate the decoder matches the dense top-K output. - dims = (np.arange(1, num_dims + 1) + 3).tolist() - dims[-1] = 1024 - for k in [1, 10, 100, 1000]: - topk_sparse = TopK(k, use_sparse_activations=True) - topk_dense = TopK(k, use_sparse_activations=False) - x = torch.randn(*dims) + 50.0 - sparse_x = topk_sparse(x) - assert sparse_x.is_sparse - sparse_x = sparse_x.to_dense() - dense_x = topk_dense(x) - assert_close(dense_x, sparse_x) - - -@pytest.mark.parametrize("num_dims", [1, 2, 3, 4, 5]) -def test_topK_activation_sparse_mm(num_dims: int): + topk_sparse = TopK(100, use_sparse_activations=True) + x_1d = torch.randn(1024) + x_2d = torch.randn(3, 1024) + x_3d = torch.randn(3, 3, 1024) + with pytest.raises(ValueError): + topk_sparse(x_1d) + with pytest.raises(ValueError): + topk_sparse(x_3d) + sparse_x_2d = topk_sparse(x_2d) + assert sparse_x_2d.is_sparse or sparse_x_2d.is_sparse_csr + assert sparse_x_2d.ndim == 2 + + +def test_topK_activation_sparse_mm(): # Validate that our decoder produces the same output when using the sparse intermediates # as when using the dense intermediates. d_in = 128 d_sae = 1024 - dims = (np.arange(1, num_dims + 1) + 3).tolist() - dims[-1] = d_sae + dims = (3, d_sae) cfg = build_topk_sae_training_cfg( d_in=d_in, @@ -204,10 +245,20 @@ def test_TopKTrainingSAE_sparse_activations_config(): # Check that our config is respected in both training & inference SAEs cfg = build_topk_sae_training_cfg(k=100, use_sparse_activations=True) sae = TopKTrainingSAE(cfg) - assert sae.activation_fn.use_sparse_activations # type: ignore + assert not sae.activation_fn.use_sparse_activations # type: ignore assert sae.cfg.use_sparse_activations cfg = build_topk_sae_training_cfg(k=100, use_sparse_activations=False) sae = TopKTrainingSAE(cfg) assert not sae.activation_fn.use_sparse_activations # type: ignore assert not sae.cfg.use_sparse_activations + + +@pytest.mark.parametrize("use_sparse_activations", [True, False]) +def test_TopKTrainingSAE_use_sparse_topk(use_sparse_activations: bool): + cfg = build_topk_sae_training_cfg(k=100) + sae = TopKTrainingSAE(cfg) + assert not sae.activation_fn.use_sparse_activations # type: ignore + with sae.use_sparse_topk(use_sparse_activations): + assert sae.activation_fn.use_sparse_activations is use_sparse_activations # type: ignore + assert not sae.activation_fn.use_sparse_activations # type: ignore diff --git a/tests/training/test_sae_trainer.py b/tests/training/test_sae_trainer.py index 7583335b8..22f66822b 100644 --- a/tests/training/test_sae_trainer.py +++ b/tests/training/test_sae_trainer.py @@ -26,6 +26,7 @@ TINYSTORIES_MODEL, assert_close, build_runner_cfg, + build_topk_runner_cfg, load_model_cached, ) @@ -251,6 +252,33 @@ def test_train_sae_group_on_language_model__runs( assert isinstance(sae, TrainingSAE) +def test_SAETrainer_run_with_sparse_topk_sae( + ts_model: HookedTransformer, + tmp_path: Path, +) -> None: + checkpoint_dir = tmp_path / "checkpoint" + cfg = build_topk_runner_cfg( + use_sparse_activations=True, + k=10, + checkpoint_path=str(checkpoint_dir), + training_tokens=20, + context_size=8, + ) + # just a tiny datast which will run quickly + dataset = Dataset.from_list([{"text": "hello world"}] * 100) + activation_store = ActivationsStore.from_config( + ts_model, cfg, override_dataset=dataset + ) + sae = TrainingSAE.from_dict(cfg.get_training_sae_cfg_dict()) + sae = SAETrainer( + cfg=cfg.to_sae_trainer_config(), + sae=sae, + data_provider=activation_store, + ).fit() + + assert isinstance(sae, TrainingSAE) + + def test_update_sae_lens_training_version_sets_the_current_version(): cfg = build_runner_cfg(sae_lens_training_version="0.1.0") sae = TrainingSAE.from_dict(cfg.get_training_sae_cfg_dict())