Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
4a417c3
Add kernel skeleton and remove redundant b_enc param from TopK class
wz-ml Sep 16, 2025
e6fac7d
Add option in topK to save SAE activations as a sparse tensor
wz-ml Sep 17, 2025
9a8b9fc
Add sparse activation config flag & update tests
wz-ml Sep 17, 2025
f776403
Linting changes
wz-ml Sep 18, 2025
50747d3
Making sparse COO tensors compatible with HookedTransformer (WIP)
wz-ml Sep 18, 2025
dfeb254
Changes to make sparse SAE intermediate implementation transparent to…
wz-ml Sep 19, 2025
1fe8060
Add formatting script (for future optimization)
wz-ml Sep 19, 2025
ab8ead4
Address PR review comments
wz-ml Sep 25, 2025
a442072
allow multidim sparsity in topk saes
chanind Sep 27, 2025
6740be6
Merge pull request #1 from jbloomAus/multidim-sparsity
wz-ml Sep 30, 2025
f722bbb
fix logging with sparse feature acts
chanind Sep 30, 2025
2c086e7
switch TopK to use dense tensors by default in case users are extendi…
chanind Sep 30, 2025
bdd8411
bust CI cache to hopefully get CI to not run out of disk...
chanind Sep 30, 2025
260aba2
disable autocase for sparse.mm
chanind Oct 1, 2025
41393b2
default TopK SAEs to disable sparse training until we can improve per…
chanind Oct 4, 2025
d035638
updating docs for topk config
chanind Oct 4, 2025
6f52a01
Update benchmark to not report FP breakdown by default (faster run)
wz-ml Oct 5, 2025
f497885
Update topK activation & decode to use sparse CSR format
wz-ml Oct 5, 2025
1fb8502
Add CSR support to evals. Note: Having intended target shape as tempo…
wz-ml Oct 5, 2025
f464918
Fix issue with bench script s.t dtype param's used properly
wz-ml Oct 6, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -32,16 +32,16 @@ jobs:
- name: Cache Huggingface assets
uses: actions/cache@v4
with:
key: huggingface-4-${{ runner.os }}-${{ matrix.python-version }}-${{ hashFiles('**/pyproject.toml') }}
key: huggingface-5-${{ runner.os }}-${{ matrix.python-version }}-${{ hashFiles('**/pyproject.toml') }}
path: ~/.cache/huggingface
restore-keys: |
huggingface-4-${{ runner.os }}-${{ matrix.python-version }}-
huggingface-5-${{ runner.os }}-${{ matrix.python-version }}-
- name: Load cached Poetry installation
id: cached-poetry
uses: actions/cache@v4
with:
path: ~/.local # the path depends on the OS
key: poetry-${{ runner.os }}-${{ matrix.python-version }}-3 # increment to reset cache
key: poetry-${{ runner.os }}-${{ matrix.python-version }}-5 # increment to reset cache
- name: Install Poetry
if: steps.cached-poetry.outputs.cache-hit != 'true'
uses: snok/install-poetry@v1
Expand All @@ -54,9 +54,9 @@ jobs:
uses: actions/cache@v4
with:
path: .venv
key: venv-2-${{ runner.os }}-${{ matrix.python-version }}-${{ hashFiles('**/pyproject.toml') }}
key: venv-5-${{ runner.os }}-${{ matrix.python-version }}-${{ hashFiles('**/pyproject.toml') }}
restore-keys: |
venv-2-${{ runner.os }}-${{ matrix.python-version }}-
venv-5-${{ runner.os }}-${{ matrix.python-version }}-
- name: Install dependencies
if: steps.cached-poetry-dependencies.outputs.cache-hit != 'true'
run: poetry install --no-interaction
Expand Down
161 changes: 161 additions & 0 deletions benchmark/bench_fwd_perf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
import argparse
import os
from typing import Any, Callable

import torch
import torch._inductor.config
import triton
from tabulate import tabulate

from sae_lens.saes.sae import TrainStepInput
from sae_lens.saes.topk_sae import TopKTrainingSAE
from tests.helpers import (
build_topk_sae_training_cfg,
)

torch._inductor.config.coordinate_descent_tuning = True

parser = argparse.ArgumentParser(add_help=True)
parser.add_argument("--device", type=str, default="cuda")
parser.add_argument(
"--shape",
type=int,
nargs=3,
default=[1024, 1024, 1024 * 16],
help="Shape of the input tensor (seq_len, d_in, d_sae)",
)
parser.add_argument("--k", type=int, default=100, help="Number of topk elements")
parser.add_argument(
"--dtype",
type=str,
default="float32",
help="Datatype (bfloat16, float16, float32, float64)",
)
parser.add_argument(
"-bd",
"--show-breakdown",
action="store_true",
help="Time individual components of the forward pass.",
)
args = parser.parse_args()

device = args.device

os.environ["TOKENIZERS_PARALLELISM"] = "false"

d_in = args.shape[1]
d_sae = args.shape[2]
k = args.k
seq_len = args.shape[0]

cfg_sparse = build_topk_sae_training_cfg(
d_in=d_in,
d_sae=d_sae,
k=k,
device=device,
use_sparse_activations=True,
dtype=args.dtype,
)
cfg_dense = build_topk_sae_training_cfg(
d_in=d_in,
d_sae=d_sae,
k=k,
device=device,
use_sparse_activations=False,
dtype=args.dtype,
)

sae_sparse = TopKTrainingSAE(cfg_sparse)
sae_dense = TopKTrainingSAE(cfg_dense)

dead_neuron_mask = None # torch.randn(d_sae, device = device) > 0.1
input_acts = torch.randn(seq_len, d_in, device=device)
input_var = (input_acts - input_acts.mean(0)).pow(2).sum()

step_input = TrainStepInput(
sae_in=input_acts,
dead_neuron_mask=dead_neuron_mask,
coefficients={},
)


def encode_proj(sae: TopKTrainingSAE, input_acts: torch.Tensor) -> torch.Tensor:
sae_in = sae.process_sae_in(input_acts)
return sae.hook_sae_acts_pre(sae_in @ sae.W_enc + sae.b_enc)


def topk_activation(sae: TopKTrainingSAE, hidden_pre: torch.Tensor) -> torch.Tensor:
return sae.activation_fn(hidden_pre)


def decode_step(sae: TopKTrainingSAE, feature_acts: torch.Tensor) -> torch.Tensor:
return sae.decode(feature_acts)


def loss_computation(
sae: TopKTrainingSAE, sae_out: torch.Tensor, sae_in: torch.Tensor
) -> torch.Tensor:
# Calculate MSE loss
per_item_mse_loss = sae.mse_loss_fn(sae_out, sae_in)
return per_item_mse_loss.sum(dim=-1).mean()


def triton_bench(fn: Callable[[], Any]) -> float:
# note that the warmup and rep params here are in ms, not iterations
return triton.testing.do_bench(fn, warmup=1000, rep=2000) # type: ignore


def benchmark_sae(sae: TopKTrainingSAE) -> dict[str, float]:
results = {}
results["full_forward_pass"] = triton_bench(
lambda: sae.training_forward_pass(step_input)
)
if args.show_breakdown:
results["encode_proj"] = triton_bench(lambda: encode_proj(sae, input_acts))
hidden_pre = encode_proj(sae, input_acts)
results["topk_activation"] = triton_bench(
lambda: topk_activation(sae, hidden_pre)
)
feature_acts = topk_activation(sae, hidden_pre)
results["decode_step"] = triton_bench(lambda: decode_step(sae, feature_acts))
sae_out = decode_step(sae, feature_acts)
results["loss_computation"] = triton_bench(
lambda: loss_computation(sae, sae_out, input_acts)
)
results["other"] = 2 * results["full_forward_pass"] - sum(results.values())
return results


if __name__ == "__main__":
if args.show_breakdown:
print("This may take a while (5 mins). Go grab a coffee!")
results_sparse = benchmark_sae(sae_sparse)
results_dense = benchmark_sae(sae_dense)
speedup = results_dense["full_forward_pass"] / results_sparse["full_forward_pass"]

# Pretty print results table with metrics as columns
if args.show_breakdown:
headers = [
"Implementation",
"Encode",
"TopK",
"Decode",
"Loss Calc",
"Full Fwd",
"Other",
]
else:
headers = [
"Implementation",
"Full Fwd",
]

metric_keys = results_sparse.keys()

table_data = [
["Sparse"] + [f"{results_sparse[key]:.3f}" for key in metric_keys],
["Dense"] + [f"{results_dense[key]:.3f}" for key in metric_keys],
]
print("Metric: Latency (ms)")
print(f"Speedup: {speedup:.3f}")
print("\n" + tabulate(table_data, headers=headers, tablefmt="grid"))
6 changes: 6 additions & 0 deletions sae_lens/evals.py
Original file line number Diff line number Diff line change
Expand Up @@ -466,6 +466,12 @@ def get_sparsity_and_variance_metrics(
sae_out_scaled = sae.decode(sae_feature_activations).to(
original_act_scaled.device
)
if sae_feature_activations.is_sparse or sae_feature_activations.is_sparse_csr:
batch_dims = sae_feature_activations.batch_dims # type: ignore
sae_feature_activations = sae_feature_activations.to_dense()
sae_feature_activations = sae_feature_activations.reshape(
batch_dims + (sae_feature_activations.shape[-1],)
)
del cache

sae_out = activation_scaler.unscale(sae_out_scaled)
Expand Down
3 changes: 2 additions & 1 deletion sae_lens/pretokenize_runner.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import io
import json
import sys
from collections.abc import Iterator
from dataclasses import dataclass
from pathlib import Path
from typing import Iterator, Literal, cast
from typing import Literal, cast

import torch
from datasets import Dataset, DatasetDict, load_dataset
Expand Down
19 changes: 9 additions & 10 deletions sae_lens/saes/sae.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
Generic,
Literal,
NamedTuple,
Type,
TypeVar,
)

Expand Down Expand Up @@ -534,7 +533,7 @@ def save_model(self, path: str | Path) -> tuple[Path, Path]:
@classmethod
@deprecated("Use load_from_disk instead")
def load_from_pretrained(
cls: Type[T_SAE],
cls: type[T_SAE],
path: str | Path,
device: str = "cpu",
dtype: str | None = None,
Expand All @@ -543,7 +542,7 @@ def load_from_pretrained(

@classmethod
def load_from_disk(
cls: Type[T_SAE],
cls: type[T_SAE],
path: str | Path,
device: str = "cpu",
dtype: str | None = None,
Expand All @@ -564,7 +563,7 @@ def load_from_disk(

@classmethod
def from_pretrained(
cls: Type[T_SAE],
cls: type[T_SAE],
release: str,
sae_id: str,
device: str = "cpu",
Expand All @@ -585,7 +584,7 @@ def from_pretrained(

@classmethod
def from_pretrained_with_cfg_and_sparsity(
cls: Type[T_SAE],
cls: type[T_SAE],
release: str,
sae_id: str,
device: str = "cpu",
Expand Down Expand Up @@ -684,7 +683,7 @@ def from_pretrained_with_cfg_and_sparsity(
return sae, cfg_dict, log_sparsities

@classmethod
def from_dict(cls: Type[T_SAE], config_dict: dict[str, Any]) -> T_SAE:
def from_dict(cls: type[T_SAE], config_dict: dict[str, Any]) -> T_SAE:
"""Create an SAE from a config dictionary."""
sae_cls = cls.get_sae_class_for_architecture(config_dict["architecture"])
sae_config_cls = cls.get_sae_config_class_for_architecture(
Expand All @@ -694,8 +693,8 @@ def from_dict(cls: Type[T_SAE], config_dict: dict[str, Any]) -> T_SAE:

@classmethod
def get_sae_class_for_architecture(
cls: Type[T_SAE], architecture: str
) -> Type[T_SAE]:
cls: type[T_SAE], architecture: str
) -> type[T_SAE]:
"""Get the SAE class for a given architecture."""
sae_cls, _ = get_sae_class(architecture)
if not issubclass(sae_cls, cls):
Expand Down Expand Up @@ -1000,8 +999,8 @@ def log_histograms(self) -> dict[str, NDArray[Any]]:

@classmethod
def get_sae_class_for_architecture(
cls: Type[T_TRAINING_SAE], architecture: str
) -> Type[T_TRAINING_SAE]:
cls: type[T_TRAINING_SAE], architecture: str
) -> type[T_TRAINING_SAE]:
"""Get the SAE class for a given architecture."""
sae_cls, _ = get_sae_training_class(architecture)
if not issubclass(sae_cls, cls):
Expand Down
Loading