30 changes: 15 additions & 15 deletions README.md
@@ -9,7 +9,6 @@
- [Training Your Own SAEs](#training-your-own-saes)
- [Graphing Results](#graphing-results)


## Overview

SAE Bench is a comprehensive suite of 8 evaluations for Sparse Autoencoder (SAE) models:
@@ -21,6 +20,7 @@ SAE Bench is a comprehensive suite of 8 evaluations for Sparse Autoencoder (SAE)
- **[Spurious Correlation Removal (SCR)](https://arxiv.org/abs/2411.18895)**
- **[Targeted Probe Pertubation (TPP)](https://arxiv.org/abs/2411.18895)**
- **Sparse Probing**
- **[Sparse Probing (SAE Probes version)](https://arxiv.org/pdf/2502.16681)**
- **[Unlearning](https://arxiv.org/abs/2410.19278)**

For more information, refer to our [blog post](https://www.neuronpedia.org/sae-bench/info).
@@ -124,20 +124,20 @@ The computational requirements for running SAEBench evaluations were measured on
- **Setup Phase**: Includes operations like precomputing model activations, training probes, or other one-time preprocessing steps which can be reused across multiple SAE evaluations.
- **Per-SAE Evaluation Time**: The time required to evaluate a single SAE once the setup is complete.

The total evaluation time for a single SAE across all benchmarks is approximately **110 minutes**, with an additional **152 minutes** of setup time. Note that actual runtimes may vary significantly based on factors such as SAE dictionary size, base model, and GPU selection.

| Evaluation Type | Avg Time per SAE (min) | Setup Time (min) |
| --------------- | ---------------------- | ---------------- |
| Absorption | 26 | 33 |
| Core | 9 | 0 |
| SCR | 6 | 22 |
| TPP | 2 | 5 |
| Sparse Probing | 3 | 15 |
| Auto-Interp | 9 | 0 |
| Unlearning | 10 | 33 |
| RAVEL | 45 | 45 |
| **Total** | **110** | **152** |

The total evaluation time for a single SAE across all benchmarks is approximately **115 minutes**, with an additional **177 minutes** of setup time. Note that actual runtimes may vary significantly based on factors such as SAE dictionary size, base model, and GPU selection.

| Evaluation Type | Avg Time per SAE (min) | Setup Time (min) |
| --------------------------- | ---------------------- | ---------------- |
| Absorption | 26 | 33 |
| Core | 9 | 0 |
| SCR | 6 | 22 |
| TPP | 2 | 5 |
| Sparse Probing | 3 | 15 |
| Sparse Probing (SAE Probes) | 5 | 25 |
| Auto-Interp | 9 | 0 |
| Unlearning | 10 | 33 |
| RAVEL | 45 | 45 |
| **Total** | **115** | **177** |
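
As a quick sanity check on the table above, the following sketch sums the per-row values exactly as printed. The row values are copied from the updated table; the variable names are illustrative only. Summing the rounded setup times gives 178 rather than the listed 177, which may mean the listed totals were computed from unrounded timings (an assumption, not something the source states).

```python
# (avg minutes per SAE, setup minutes), copied from the updated table
timings = {
    "Absorption": (26, 33),
    "Core": (9, 0),
    "SCR": (6, 22),
    "TPP": (2, 5),
    "Sparse Probing": (3, 15),
    "Sparse Probing (SAE Probes)": (5, 25),
    "Auto-Interp": (9, 0),
    "Unlearning": (10, 33),
    "RAVEL": (45, 45),
}

# Sum each column independently
per_sae_total = sum(per_sae for per_sae, _ in timings.values())
setup_total = sum(setup for _, setup in timings.values())

print(f"Per-SAE total: {per_sae_total} min")
print(f"Setup total:   {setup_total} min")
```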

# SAE Bench Baseline Suite

1 change: 1 addition & 0 deletions pyproject.toml
@@ -32,6 +32,7 @@ matplotlib = ">=3.8.4"
tabulate = ">=0.9.0"
openai = ">=1.0.0"
torchvision = ">=0.16.1" # required for what I believe are nnsight related issues
sae-probes = "^0.3.0"
datasets = ">=3.0.0,<4.0.0" # skylion openwebtext fails to load with datasets 4.0.0 currently, pending https://huggingface.co/datasets/Skylion007/openwebtext/discussions/22

# If running into dependency issues these are tested and working
14 changes: 14 additions & 0 deletions sae_bench/custom_saes/run_all_evals_custom_saes.py
@@ -10,6 +10,7 @@
import sae_bench.evals.ravel.main as ravel
import sae_bench.evals.scr_and_tpp.main as scr_and_tpp
import sae_bench.evals.sparse_probing.main as sparse_probing
import sae_bench.evals.sparse_probing_sae_probes.main as sparse_probing_sae_probes
import sae_bench.evals.unlearning.main as unlearning
import sae_bench.sae_bench_utils.general_utils as general_utils

@@ -37,6 +38,7 @@
    "scr": "eval_results/scr",
    "tpp": "eval_results/tpp",
    "sparse_probing": "eval_results/sparse_probing",
    "sparse_probing_sae_probes": "eval_results/sparse_probing_sae_probes",
    "unlearning": "eval_results/unlearning",
    "ravel": "eval_results/ravel",
}
@@ -171,6 +173,17 @@ def run_evals(
                save_activations=save_activations,
            )
        ),
        "sparse_probing_sae_probes": (
            lambda: sparse_probing_sae_probes.run_eval(
                sparse_probing_sae_probes.SparseProbingSaeProbesEvalConfig(
                    model_name=model_name,
                ),
                selected_saes,
                device,
                "eval_results/sparse_probing_sae_probes",
                force_rerun,
            )
        ),
        "unlearning": (
            lambda: unlearning.run_eval(
                unlearning.UnlearningEvalConfig(
@@ -237,6 +250,7 @@ def run_evals(
        "scr",
        "tpp",
        "sparse_probing",
        "sparse_probing_sae_probes",
        "unlearning",
    ]

14 changes: 14 additions & 0 deletions sae_bench/custom_saes/run_all_evals_dictionary_learning_saes.py
@@ -17,6 +17,7 @@
import sae_bench.evals.ravel.main as ravel
import sae_bench.evals.scr_and_tpp.main as scr_and_tpp
import sae_bench.evals.sparse_probing.main as sparse_probing
import sae_bench.evals.sparse_probing_sae_probes.main as sparse_probing_sae_probes
import sae_bench.evals.unlearning.main as unlearning
import sae_bench.sae_bench_utils.general_utils as general_utils

@@ -48,6 +49,7 @@
    "scr": "eval_results/scr",
    "tpp": "eval_results/tpp",
    "sparse_probing": "eval_results/sparse_probing",
    "sparse_probing_sae_probes": "eval_results/sparse_probing_sae_probes",
    "unlearning": "eval_results/unlearning",
    "ravel": "eval_results/ravel",
}
@@ -278,6 +280,17 @@ def run_evals(
                save_activations=True,
            )
        ),
        "sparse_probing_sae_probes": (
            lambda selected_saes, is_final: sparse_probing_sae_probes.run_eval(
                sparse_probing_sae_probes.SparseProbingSaeProbesEvalConfig(
                    model_name=model_name,
                ),
                selected_saes,
                device,
                "eval_results/sparse_probing_sae_probes",
                force_rerun,
            )
        ),
        "unlearning": (
            lambda selected_saes, is_final: unlearning.run_eval(
                unlearning.UnlearningEvalConfig(
@@ -375,6 +388,7 @@ def run_evals(
        "scr",
        "tpp",
        "sparse_probing",
        "sparse_probing_sae_probes",
        "autointerp",
        # "unlearning",
        "ravel",
2 changes: 1 addition & 1 deletion sae_bench/evals/meta_structure/eval_output.py
@@ -53,7 +53,7 @@ class MetaStructureEvalOutput(
    eval_id: str
    datetime_epoch_millis: int
    eval_result_metrics: MetaStructureMetricCategories
    eval_result_details: list[BaseResultDetail] | None = None
    eval_result_details: list[BaseResultDetail] | None = None  # pyright: ignore[reportIncompatibleVariableOverride]
    eval_type_id: str = Field(
        default=EVAL_TYPE_ID_META_STRUCTURE,
        title="Eval Type ID",
156 changes: 156 additions & 0 deletions sae_bench/evals/sparse_probing_sae_probes/README.md
@@ -0,0 +1,156 @@
This eval implements the k-sparse probing benchmark from the paper [Are Sparse Autoencoders Useful? A Case Study in Sparse Probing](https://arxiv.org/pdf/2502.16681), which runs k-sparse probing on over 140 datasets. It wraps the standalone `sae-probes` Python package and writes its results in SAEBench format. To customize the eval further, refer to the [sae-probes documentation](https://github.com/sae-probes/sae-probes).

## Usage

### Basic Usage

Run the eval from the command line:

```bash
python sae_bench/evals/sparse_probing_sae_probes/main.py \
  --model_name gpt2 \
  --sae_regex_pattern "gpt2-small-res-jb" \
  --sae_block_pattern "blocks.4.hook_resid_pre"
```

### Configuration Options

- `--model_name`: Name of the model (e.g., `gpt2`, `pythia-70m`)
- `--sae_regex_pattern`: Regex pattern to match SAE releases
- `--sae_block_pattern`: Regex pattern to match SAE hook points
- `--ks`: List of k values for sparse probing (default: `[1, 2, 5]`)
  - Example: `--ks 1 2 5 10 20`
- `--reg_type`: Regularization type for probing (`l1` or `l2`, default: `l1`)
- `--setting`: Data balance setting (`normal`, `scarcity`, or `imbalance`, default: `normal`)
- `--binarize`: Whether to binarize probe targets (flag, default: False)
- `--results_path`: Directory where sae-probes writes intermediate JSONs (default: `artifacts/sparse_probing_sae_probes`)
- `--model_cache_path`: Optional directory to cache model activations for faster re-runs (default: `artifacts/sparse_probing_sae_probes--model_acts_cache`)
- `--output_folder`: Where to save SAEBench output files (default: `eval_results/sparse_probing_sae_probes`)
- `--force_rerun`: Force re-running the eval even if results exist (flag)
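
To make the flag semantics above concrete, here is a minimal `argparse` sketch that mirrors a subset of the documented options. It is not the parser in `main.py`: the function name `build_parser`, the choice to mark the first three flags as required, and the omission of the path options are all assumptions made for illustration.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    """Hypothetical parser mirroring the documented flags (not the real main.py)."""
    parser = argparse.ArgumentParser()
    # Marking these as required is an assumption for this sketch
    parser.add_argument("--model_name", type=str, required=True)
    parser.add_argument("--sae_regex_pattern", type=str, required=True)
    parser.add_argument("--sae_block_pattern", type=str, required=True)
    # nargs="+" lets the flag take a space-separated list: --ks 1 2 5 10 20
    parser.add_argument("--ks", type=int, nargs="+", default=[1, 2, 5])
    parser.add_argument("--reg_type", choices=["l1", "l2"], default="l1")
    parser.add_argument(
        "--setting", choices=["normal", "scarcity", "imbalance"], default="normal"
    )
    # store_true flags default to False and become True when present
    parser.add_argument("--binarize", action="store_true")
    parser.add_argument("--force_rerun", action="store_true")
    return parser


args = build_parser().parse_args(
    [
        "--model_name", "gpt2",
        "--sae_regex_pattern", "gpt2-small-res-jb",
        "--sae_block_pattern", "blocks.4.hook_resid_pre",
        "--ks", "1", "2", "5", "10", "20",
    ]
)
print(args.ks, args.reg_type, args.setting, args.binarize)
```

Passing `--ks` with several space-separated integers yields a list, while omitting it falls back to the documented default of `[1, 2, 5]`.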

### Programmatic Usage

```python
from sae_bench.evals.sparse_probing_sae_probes.eval_config import SparseProbingSaeProbesEvalConfig
from sae_bench.evals.sparse_probing_sae_probes.main import run_eval
from sae_lens import SAE

# Configure the eval
config = SparseProbingSaeProbesEvalConfig(
    model_name="gpt2",
    dataset_names=["118_us_state_CA", "119_us_state_TX"],  # Subset of datasets
    ks=[1, 2, 5, 10],  # Custom k values
    include_llm_baseline=True,  # Compare against LLM residual stream baseline
    results_path="artifacts/sparse_probing_sae_probes",
    model_cache_path="cache/models",
)

# Load your SAE
sae = SAE.from_pretrained("gpt2-small-res-jb", "blocks.4.hook_resid_pre")[0]

# Run the eval
results = run_eval(
    config=config,
    selected_saes=[("my_sae_release", sae)],
    device="cuda",
    output_path="eval_results/sparse_probing_sae_probes",
)
```

### Output Structure

The eval produces a JSON file with the following structure:

```json
{
  "eval_type_id": "sparse_probing_sae_probes",
  "eval_result_metrics": {
    "llm": {
      "llm_test_accuracy": 0.85,
      "llm_test_auc": 0.92,
      "llm_test_f1": 0.83
    },
    "sae": {
      "sae_top_1_test_accuracy": 0.78,
      "sae_top_1_test_auc": 0.85,
      "sae_top_1_test_f1": 0.76,
      "sae_top_2_test_accuracy": 0.81,
      ...
    }
  },
  "sae_metrics_by_k": {
    "1": {"test_accuracy": 0.78, "test_auc": 0.85, "test_f1": 0.76},
    "2": {"test_accuracy": 0.81, "test_auc": 0.87, "test_f1": 0.79},
    ...
  },
  "eval_result_details": [
    {
      "dataset_name": "118_us_state_CA",
      "llm_test_accuracy": 0.90,
      "sae_top_1_test_accuracy": 0.82,
      "sae_metrics_by_k": {
        "1": {"test_accuracy": 0.82, ...},
        ...
      }
    },
    ...
  ]
}
```

**Key Metrics:**

- **LLM metrics**: Baseline performance using full LLM residual stream (all dimensions)
- **SAE top-k metrics**: Performance using only k SAE latents with highest probe weights
- **sae_metrics_by_k**: Flexible dictionary supporting arbitrary k values
- **eval_result_details**: Per-dataset breakdown of all metrics
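
The following sketch shows one way to read `sae_metrics_by_k` from an output file with the shape documented above. The JSON string stands in for a real results file and reuses the placeholder numbers from the example, so it does not represent real results. The top-level placement of `sae_metrics_by_k` is taken from that example and may differ in actual output.

```python
import json

# Stand-in for the contents of a results file; values are the
# placeholder numbers from the example structure, listed out of order.
raw = """
{
  "eval_type_id": "sparse_probing_sae_probes",
  "sae_metrics_by_k": {
    "2": {"test_accuracy": 0.81, "test_auc": 0.87, "test_f1": 0.79},
    "1": {"test_accuracy": 0.78, "test_auc": 0.85, "test_f1": 0.76}
  }
}
"""

result = json.loads(raw)

# JSON object keys are strings, so convert them to int before sorting;
# a plain string sort would place "10" ahead of "2".
by_k = {int(k): metrics for k, metrics in result["sae_metrics_by_k"].items()}

for k in sorted(by_k):
    m = by_k[k]
    print(
        f"k={k:<3} acc={m['test_accuracy']:.2f} "
        f"auc={m['test_auc']:.2f} f1={m['test_f1']:.2f}"
    )
```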

### Custom K Values

By default, the eval runs with k=[1, 2, 5]. You can specify custom k values:

```bash
python sae_bench/evals/sparse_probing_sae_probes/main.py \
  --model_name gpt2 \
  --sae_regex_pattern "gpt2-small-res-jb" \
  --sae_block_pattern "blocks.4.hook_resid_pre" \
  --ks 3 7 15 25 50
```

Results will be available in:

- Individual hardcoded fields (e.g., `sae_top_1_test_accuracy`) for standard k values
- `sae_metrics_by_k` dictionary for all k values (including custom ones)
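
Because a given k may appear in either place, a small lookup helper can hide the difference. This is a sketch only: `sae_metric` is a hypothetical helper, not part of SAEBench or sae-probes, and it assumes the field naming (`sae_top_{k}_{metric}`) and the top-level `sae_metrics_by_k` placement shown in the example output.

```python
def sae_metric(result: dict, k: int, metric: str = "test_accuracy"):
    """Return the SAE metric for k, or None if that k is not present."""
    # Prefer the flexible dictionary, whose keys are strings in JSON
    by_k = result.get("sae_metrics_by_k", {})
    if str(k) in by_k:
        return by_k[str(k)].get(metric)
    # Fall back to a hardcoded field such as "sae_top_1_test_accuracy"
    sae_fields = result.get("eval_result_metrics", {}).get("sae", {})
    return sae_fields.get(f"sae_top_{k}_{metric}")


# Placeholder values taken from the example output structure
result = {
    "eval_result_metrics": {"sae": {"sae_top_1_test_accuracy": 0.78}},
    "sae_metrics_by_k": {
        "1": {"test_accuracy": 0.78, "test_auc": 0.85, "test_f1": 0.76},
        "2": {"test_accuracy": 0.81, "test_auc": 0.87, "test_f1": 0.79},
    },
}

print(sae_metric(result, 2))
print(sae_metric(result, 7))  # a k that was not evaluated
```

Returning `None` for a missing k keeps the caller in charge of deciding whether that is an error.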

### Dataset Selection

By default, the eval runs on all 140+ datasets from sae-probes. To run on a subset:

```python
config = SparseProbingSaeProbesEvalConfig(
    model_name="gpt2",
    dataset_names=["118_us_state_CA", "119_us_state_TX", "120_us_state_NY"],
    # ... other config
)
```

See the [sae-probes datasets](https://github.com/sae-probes/sae-probes#available-datasets) for the full list.

### Including LLM Baselines

To compare SAE performance against full LLM residual stream baselines:

```python
config = SparseProbingSaeProbesEvalConfig(
    model_name="gpt2",
    include_llm_baseline=True,  # Enables baseline comparison
    baseline_method="logreg",  # Method for baseline probe (default)
    # ... other config
)
```

This adds LLM baseline metrics to the output, allowing you to compare how well k SAE latents perform versus using all LLM dimensions.
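
As an illustration of that comparison, the sketch below computes the per-dataset accuracy gap between the LLM baseline and the top-k SAE probe from `eval_result_details`. The helper `baseline_gaps` is hypothetical, and the single record reuses the placeholder numbers from the example output, so the printed gap is not a real result.

```python
# One record shaped like an entry of "eval_result_details";
# the numbers are the placeholders from the example output.
details = [
    {
        "dataset_name": "118_us_state_CA",
        "llm_test_accuracy": 0.90,
        "sae_top_1_test_accuracy": 0.82,
    },
]


def baseline_gaps(details: list, k: int = 1) -> dict:
    """Map dataset name to LLM accuracy minus top-k SAE accuracy."""
    sae_key = f"sae_top_{k}_test_accuracy"
    gaps = {}
    for row in details:
        llm_acc = row.get("llm_test_accuracy")
        sae_acc = row.get(sae_key)
        if llm_acc is None or sae_acc is None:
            continue  # baseline disabled, or this k was not recorded
        gaps[row["dataset_name"]] = llm_acc - sae_acc
    return gaps


for name, gap in baseline_gaps(details).items():
    print(f"{name}: LLM minus SAE accuracy = {gap:+.2f}")
```

A positive gap means the full residual stream probe outperforms the k-latent SAE probe on that dataset.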

### Caching Model Activations for Faster Iteration

Set `model_cache_path` to cache model activations across runs if you expect to rerun this eval for many different SAEs on the same model and layers. Set it to `None` to disable caching.
22 changes: 22 additions & 0 deletions sae_bench/evals/sparse_probing_sae_probes/__init__.py
@@ -0,0 +1,22 @@
from .eval_config import SparseProbingSaeProbesEvalConfig
from .eval_output import (
    EVAL_TYPE_ID_SPARSE_PROBING_SAE_PROBES,
    SaeProbesLlmMetrics,
    SaeProbesMetricCategories,
    SaeProbesResultDetail,
    SaeProbesSaeMetrics,
    SparseProbingSaeProbesEvalOutput,
)
from .main import create_config_and_selected_saes, run_eval

__all__ = [
    "SparseProbingSaeProbesEvalConfig",
    "EVAL_TYPE_ID_SPARSE_PROBING_SAE_PROBES",
    "SaeProbesLlmMetrics",
    "SaeProbesSaeMetrics",
    "SaeProbesMetricCategories",
    "SaeProbesResultDetail",
    "SparseProbingSaeProbesEvalOutput",
    "create_config_and_selected_saes",
    "run_eval",
]