Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions benchmark/bench_fwd_perf.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,18 @@
help="Shape of the input tensor (seq_len, d_in, d_sae)",
)
parser.add_argument("--k", type=int, default=100, help="Number of topk elements")
parser.add_argument(
"--dtype",
type=str,
default="float32",
help="Datatype (bfloat16, float16, float32, float64)",
)
parser.add_argument(
"-bd",
"--show-breakdown",
action="store_true",
help="Time individual components of the forward pass.",
)
args = parser.parse_args()

device = args.device
Expand All @@ -42,13 +54,15 @@
k=k,
device=device,
use_sparse_activations=True,
dtype=args.dtype,
)
cfg_dense = build_topk_sae_training_cfg(
d_in=d_in,
d_sae=d_sae,
k=k,
device=device,
use_sparse_activations=False,
dtype=args.dtype,
)

sae_sparse = TopKTrainingSAE(cfg_sparse)
Expand Down Expand Up @@ -113,6 +127,7 @@ def benchmark_sae(sae: TopKTrainingSAE) -> dict[str, float]:
print("This may take a while (5 mins). Go grab a coffee!")
results_sparse = benchmark_sae(sae_sparse)
results_dense = benchmark_sae(sae_dense)
speedup = results_dense["full_forward_pass"] / results_sparse["full_forward_pass"]

# Pretty print results table with metrics as columns
headers = [
Expand All @@ -132,4 +147,5 @@ def benchmark_sae(sae: TopKTrainingSAE) -> dict[str, float]:
["Dense"] + [f"{results_dense[key]:.3f}" for key in metric_keys],
]
print("Metric: Latency (ms)")
print(f"Speedup: {speedup:.3f}")
print("\n" + tabulate(table_data, headers=headers, tablefmt="grid"))
2 changes: 1 addition & 1 deletion sae_lens/evals.py
Original file line number Diff line number Diff line change
Expand Up @@ -466,7 +466,7 @@ def get_sparsity_and_variance_metrics(
sae_out_scaled = sae.decode(sae_feature_activations).to(
original_act_scaled.device
)
if sae_feature_activations.is_sparse:
if sae_feature_activations.is_sparse or sae_feature_activations.is_sparse_csr:
sae_feature_activations = sae_feature_activations.to_dense()
del cache

Expand Down
150 changes: 47 additions & 103 deletions sae_lens/saes/topk_sae.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
"""Inference-only TopKSAE variant, similar in spirit to StandardSAE but using a TopK-based activation."""

from contextlib import contextmanager
from dataclasses import dataclass
from typing import Callable
from typing import Callable, Generator

import torch
from jaxtyping import Float
Expand All @@ -16,6 +17,7 @@
TrainingSAE,
TrainingSAEConfig,
TrainStepInput,
TrainStepOutput,
_disable_hooks,
)

Expand All @@ -35,7 +37,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
using_hooks = (
self._forward_hooks is not None and len(self._forward_hooks) > 0
) or (self._backward_hooks is not None and len(self._backward_hooks) > 0)
if using_hooks and x.is_sparse:
if using_hooks and (x.is_sparse or x.is_sparse_csr):
return x.to_dense()
return x # if no hooks are being used, use passthrough

Expand Down Expand Up @@ -69,47 +71,24 @@ def forward(
topk_values, topk_indices = torch.topk(x, k=self.k, dim=-1, sorted=False)
values = topk_values.relu()
if self.use_sparse_activations:
# Produce a COO sparse tensor (use sparse matrix multiply in decode)
original_shape = x.shape

# Create indices for all dimensions
# For each element in topk_indices, we need to map it back to the original tensor coordinates
batch_dims = original_shape[:-1] # All dimensions except the last one
num_batch_elements = torch.prod(torch.tensor(batch_dims)).item()

# Create batch indices - each batch element repeated k times
batch_indices_flat = torch.arange(
num_batch_elements, device=x.device
).repeat_interleave(self.k)

# Produce a CSR sparse tensor (use sparse matrix multiply in decode)
# Convert flat batch indices back to multi-dimensional indices
if len(batch_dims) == 1:
# 2D case: [batch, features]
sparse_indices = torch.stack(
[
batch_indices_flat,
topk_indices.flatten(),
]
)
else:
# 3D+ case: need to unravel the batch indices
batch_indices_multi = []
remaining = batch_indices_flat
for dim_size in reversed(batch_dims):
batch_indices_multi.append(remaining % dim_size)
remaining = remaining // dim_size
batch_indices_multi.reverse()

sparse_indices = torch.stack(
[
*batch_indices_multi,
topk_indices.flatten(),
]
)

return torch.sparse_coo_tensor(
sparse_indices, values.flatten(), original_shape
if x.ndim != 2:
raise ValueError("Sparse activations are only supported for 2D tensors")
return torch.sparse_csr_tensor(
torch.arange(
start=0,
end=len(topk_indices.flatten()) + self.k,
step=self.k,
device=x.device,
),
topk_indices.flatten(),
topk_values.flatten(),
dtype=x.dtype,
device=x.device,
size=x.shape,
)

result = torch.zeros_like(x)
result.scatter_(-1, topk_indices, values)
return result
Expand All @@ -129,63 +108,6 @@ def architecture(cls) -> str:
return "topk"


def _sparse_matmul_nd(
sparse_tensor: torch.Tensor, dense_matrix: torch.Tensor
) -> torch.Tensor:
"""
Multiply a sparse tensor of shape [..., d_sae] with a dense matrix of shape [d_sae, d_out]
to get a result of shape [..., d_out].

This function handles sparse tensors with arbitrary batch dimensions by flattening
the batch dimensions, performing 2D sparse matrix multiplication, and reshaping back.
"""
original_shape = sparse_tensor.shape
batch_dims = original_shape[:-1]
d_sae = original_shape[-1]
d_out = dense_matrix.shape[-1]

if sparse_tensor.ndim == 2:
# Simple 2D case - use torch.sparse.mm directly
# sparse.mm errors with bfloat16 :(
with torch.autocast(device_type=sparse_tensor.device.type, enabled=False):
return torch.sparse.mm(sparse_tensor, dense_matrix)

# For 3D+ case, reshape to 2D, multiply, then reshape back
batch_size = int(torch.prod(torch.tensor(batch_dims)).item())

# Ensure tensor is coalesced for efficient access to indices/values
if not sparse_tensor.is_coalesced():
sparse_tensor = sparse_tensor.coalesce()

# Get indices and values
indices = sparse_tensor.indices() # [ndim, nnz]
values = sparse_tensor.values() # [nnz]

# Convert multi-dimensional batch indices to flat indices
flat_batch_indices = torch.zeros_like(indices[0])
multiplier = 1
for i in reversed(range(len(batch_dims))):
flat_batch_indices += indices[i] * multiplier
multiplier *= batch_dims[i]

# Create 2D sparse tensor indices [batch_flat, feature]
sparse_2d_indices = torch.stack([flat_batch_indices, indices[-1]])

# Create 2D sparse tensor
sparse_2d = torch.sparse_coo_tensor(
sparse_2d_indices, values, (batch_size, d_sae)
).coalesce()

# sparse.mm errors with bfloat16 :(
with torch.autocast(device_type=sparse_tensor.device.type, enabled=False):
# Do the matrix multiplication
result_2d = torch.sparse.mm(sparse_2d, dense_matrix) # [batch_size, d_out]

# Reshape back to original batch dimensions
result_shape = tuple(batch_dims) + (d_out,)
return result_2d.view(result_shape)


class TopKSAE(SAE[TopKSAEConfig]):
"""
An inference-only sparse autoencoder using a "topk" activation function.
Expand Down Expand Up @@ -231,8 +153,8 @@ def decode(
and optional head reshaping.
"""
# Handle sparse tensors using efficient sparse matrix multiplication
if feature_acts.is_sparse:
sae_out_pre = _sparse_matmul_nd(feature_acts, self.W_dec) + self.b_dec
if feature_acts.is_sparse or feature_acts.is_sparse_csr:
sae_out_pre = torch.sparse.mm(feature_acts, self.W_dec) + self.b_dec
else:
sae_out_pre = feature_acts @ self.W_dec + self.b_dec
sae_out_pre = self.hook_sae_recons(sae_out_pre)
Expand Down Expand Up @@ -340,8 +262,8 @@ def decode(
applying optional finetuning scale, hooking, out normalization, etc.
"""
# Handle sparse tensors using efficient sparse matrix multiplication
if feature_acts.is_sparse:
sae_out_pre = _sparse_matmul_nd(feature_acts, self.W_dec) + self.b_dec
if feature_acts.is_sparse or feature_acts.is_sparse_csr:
sae_out_pre = torch.sparse.mm(feature_acts, self.W_dec) + self.b_dec
else:
sae_out_pre = feature_acts @ self.W_dec + self.b_dec
sae_out_pre = self.hook_sae_recons(sae_out_pre)
Expand Down Expand Up @@ -391,12 +313,34 @@ def fold_W_dec_norm(self) -> None:

@override
def get_activation_fn(self) -> Callable[[torch.Tensor], torch.Tensor]:
return TopK(self.cfg.k, use_sparse_activations=self.cfg.use_sparse_activations)
return TopK(self.cfg.k)

@override
def get_coefficients(self) -> dict[str, TrainCoefficientConfig | float]:
return {}

@override
def training_forward_pass(
self,
step_input: TrainStepInput,
) -> TrainStepOutput:
with self.use_sparse_topk(self.cfg.use_sparse_activations):
return super().training_forward_pass(step_input)

@contextmanager
def use_sparse_topk(
self, use_sparse_activations: bool = True
) -> Generator[None, None, None]:
"""
Temporarily set use_sparse_activations attribute on the activation function to the given value.
"""
original_use_sparse_activations = getattr(
self.activation_fn, "use_sparse_activations", False
)
self.activation_fn.use_sparse_activations = use_sparse_activations # type: ignore
yield
self.activation_fn.use_sparse_activations = original_use_sparse_activations # type: ignore

def calculate_topk_aux_loss(
self,
sae_in: torch.Tensor,
Expand Down
15 changes: 11 additions & 4 deletions sae_lens/training/sae_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,14 +253,21 @@ def _train_step(
)

with torch.no_grad():
# calling .bool() should be equivalent to .abs() > 0, and work with coo tensors
# calling .bool() should be equivalent to .abs() > 0, and work with sparse tensors
firing_feats = train_step_output.feature_acts.bool().float()
did_fire = firing_feats.sum(-2).bool()
if did_fire.is_sparse:
# need to keepdim to avoid issues with sparse tensors
did_fire = firing_feats.sum(-2, keepdim=True).bool()
if did_fire.is_sparse or did_fire.is_sparse_csr:
did_fire = did_fire.to_dense()
did_fire = did_fire.squeeze(-2)
self.n_forward_passes_since_fired += 1
self.n_forward_passes_since_fired[did_fire] = 0
self.act_freq_scores += firing_feats.sum(0)
# again, need to keepdim to avoid issues with sparse tensors
freq_deltas = firing_feats.sum(0, keepdim=True)
if freq_deltas.is_sparse or freq_deltas.is_sparse_csr:
freq_deltas = freq_deltas.to_dense()
freq_deltas = freq_deltas.squeeze(0)
self.act_freq_scores += freq_deltas
self.n_frac_active_samples += self.cfg.train_batch_size_samples

# Grad scaler will rescale gradients if autocast is enabled
Expand Down
11 changes: 11 additions & 0 deletions tests/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -591,3 +591,14 @@ def random_params(model: torch.nn.Module) -> None:
param.data = torch.rand_like(param)
for buffer in model.buffers():
buffer.data = torch.rand_like(buffer)


@torch.no_grad()
def match_params(model1: torch.nn.Module, model2: torch.nn.Module) -> None:
"""
Match the parameters of two models.
"""
for param1, param2 in zip(model1.parameters(), model2.parameters()):
param1.data = param2.data
for buffer1, buffer2 in zip(model1.buffers(), model2.buffers()):
buffer1.data = buffer2.data
Loading