From 850ee4991adbc9e6f09b381f385ea69ce4b1343d Mon Sep 17 00:00:00 2001 From: David Chanin Date: Sat, 4 Oct 2025 18:30:31 +0100 Subject: [PATCH 1/7] polishing sae-probes sparse probing eval --- pyproject.toml | 1 + .../evals/sparse_probing_sae_probes/README.md | 156 ++++++++ .../sparse_probing_sae_probes/__init__.py | 22 ++ .../sparse_probing_sae_probes/eval_config.py | 68 ++++ .../sparse_probing_sae_probes/eval_output.py | 156 ++++++++ .../evals/sparse_probing_sae_probes/main.py | 365 ++++++++++++++++++ tests/conftest.py | 8 +- tests/unit/__init__.py | 0 tests/unit/evals/__init__.py | 0 tests/unit/evals/absorption/__init__.py | 0 tests/unit/evals/autointerp/__init__.py | 0 tests/unit/evals/scr_and_tpp/__init__.py | 0 .../sparse_probing_sae_probes/__init__.py | 0 .../sparse_probing_sae_probes/test_main.py | 231 +++++++++++ 14 files changed, 1005 insertions(+), 2 deletions(-) create mode 100644 sae_bench/evals/sparse_probing_sae_probes/README.md create mode 100644 sae_bench/evals/sparse_probing_sae_probes/__init__.py create mode 100644 sae_bench/evals/sparse_probing_sae_probes/eval_config.py create mode 100644 sae_bench/evals/sparse_probing_sae_probes/eval_output.py create mode 100644 sae_bench/evals/sparse_probing_sae_probes/main.py create mode 100644 tests/unit/__init__.py create mode 100644 tests/unit/evals/__init__.py create mode 100644 tests/unit/evals/absorption/__init__.py create mode 100644 tests/unit/evals/autointerp/__init__.py create mode 100644 tests/unit/evals/scr_and_tpp/__init__.py create mode 100644 tests/unit/evals/sparse_probing_sae_probes/__init__.py create mode 100644 tests/unit/evals/sparse_probing_sae_probes/test_main.py diff --git a/pyproject.toml b/pyproject.toml index f4849340..b85ab015 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,6 +32,7 @@ matplotlib = ">=3.8.4" tabulate = ">=0.9.0" openai = ">=1.0.0" torchvision = ">=0.16.1" # required for what I believe are nnsight related issues +sae-probes = "^0.2.1" # If running into 
dependency issues these are tested and working # [tool.poetry.dependencies] diff --git a/sae_bench/evals/sparse_probing_sae_probes/README.md b/sae_bench/evals/sparse_probing_sae_probes/README.md new file mode 100644 index 00000000..dadf4998 --- /dev/null +++ b/sae_bench/evals/sparse_probing_sae_probes/README.md @@ -0,0 +1,156 @@ +This eval implements the k-sparse probing benchmark from the paper [Are Sparse Autoencoders Useful? A Case Study in Sparse Probing](https://arxiv.org/pdf/2502.16681), which runs k-sparse probing on over 140 datasets. This eval wraps the standalone `sae-probes` Python package, putting results in SAEBench format. For further customization of the eval, refer to the [sae-probes documentation](https://github.com/sae-probes/sae-probes). + +## Usage + +### Basic Usage + +Run the eval from the command line: + +```bash +python sae_bench/evals/sparse_probing_sae_probes/main.py \ + --model_name gpt2 \ + --sae_regex_pattern "gpt2-small-res-jb" \ + --sae_block_pattern "blocks.4.hook_resid_pre" +``` + +### Configuration Options + +- `--model_name`: Name of the model (e.g., `gpt2`, `pythia-70m`) +- `--sae_regex_pattern`: Regex pattern to match SAE releases +- `--sae_block_pattern`: Regex pattern to match SAE hook points +- `--ks`: List of k values for sparse probing (default: `[1, 2, 5]`) + - Example: `--ks 1 2 5 10 20` +- `--reg_type`: Regularization type for probing (`l1` or `l2`, default: `l1`) +- `--setting`: Data balance setting (`normal`, `scarcity`, or `imbalance`, default: `normal`); only `normal` is currently supported, and `run_eval` raises `NotImplementedError` for the other settings +- `--binarize`: Whether to binarize SAE latents during probing (flag, default: False) +- `--results_path`: Directory where sae-probes writes intermediate JSONs (default: `artifacts/sparse_probing_sae_probes`) +- `--model_cache_path`: Optional directory to cache model activations for faster re-runs +- `--output_folder`: Where to save SAEBench output files (default: `eval_results/sparse_probing_sae_probes`) +- `--force_rerun`: Force re-running the eval even if results exist (flag) 
+ +### Programmatic Usage + +```python +from sae_bench.evals.sparse_probing_sae_probes.eval_config import SparseProbingSaeProbesEvalConfig +from sae_bench.evals.sparse_probing_sae_probes.main import run_eval +from sae_lens import SAE + +# Configure the eval +config = SparseProbingSaeProbesEvalConfig( + model_name="gpt2", + dataset_names=["118_us_state_CA", "119_us_state_TX"], # Subset of datasets + ks=[1, 2, 5, 10], # Custom k values + include_llm_baseline=True, # Compare against LLM residual stream baseline + results_path="artifacts/sparse_probing_sae_probes", + model_cache_path="cache/models", +) + +# Load your SAE +sae = SAE.from_pretrained("gpt2-small-res-jb", "blocks.4.hook_resid_pre")[0] + +# Run the eval +results = run_eval( + config=config, + selected_saes=[("my_sae_release", sae)], + device="cuda", + output_path="eval_results/sparse_probing_sae_probes", +) +``` + +### Output Structure + +The eval produces a JSON file with the following structure: + +```json +{ + "eval_type_id": "sparse_probing_sae_probes", + "eval_result_metrics": { + "llm": { + "llm_test_accuracy": 0.85, + "llm_test_auc": 0.92, + "llm_test_f1": 0.83 + }, + "sae": { + "sae_top_1_test_accuracy": 0.78, + "sae_top_1_test_auc": 0.85, + "sae_top_1_test_f1": 0.76, + "sae_top_2_test_accuracy": 0.81, + ... + } + }, + "sae_metrics_by_k": { + "1": {"test_accuracy": 0.78, "test_auc": 0.85, "test_f1": 0.76}, + "2": {"test_accuracy": 0.81, "test_auc": 0.87, "test_f1": 0.79}, + ... + }, + "eval_result_details": [ + { + "dataset_name": "118_us_state_CA", + "llm_test_accuracy": 0.90, + "sae_top_1_test_accuracy": 0.82, + "sae_metrics_by_k": { + "1": {"test_accuracy": 0.82, ...}, + ... + } + }, + ... 
+ ] +} +``` + +**Key Metrics:** + +- **LLM metrics**: Baseline performance using full LLM residual stream (all dimensions) +- **SAE top-k metrics**: Performance using only k SAE latents with highest probe weights +- **sae_metrics_by_k**: Flexible dictionary supporting arbitrary k values +- **eval_result_details**: Per-dataset breakdown of all metrics + +### Custom K Values + +By default, the eval runs with k=[1, 2, 5]. You can specify custom k values: + +```bash +python sae_bench/evals/sparse_probing_sae_probes/main.py \ + --model_name gpt2 \ + --sae_regex_pattern "gpt2-small-res-jb" \ + --sae_block_pattern "blocks.4.hook_resid_pre" \ + --ks 3 7 15 25 50 +``` + +Results will be available in: + +- Individual hardcoded fields (e.g., `sae_top_1_test_accuracy`) for the standard k values that have dedicated fields (1, 2, 5, 10, 20, 50, 100) +- `sae_metrics_by_k` dictionary for all k values (including custom ones) + +### Dataset Selection + +By default, the eval runs on all 140+ datasets from sae-probes. To run on a subset: + +```python +config = SparseProbingSaeProbesEvalConfig( + model_name="gpt2", + dataset_names=["118_us_state_CA", "119_us_state_TX", "120_us_state_NY"], + # ... other config +) +``` + +See the [sae-probes datasets](https://github.com/sae-probes/sae-probes#available-datasets) for the full list. + +### Including LLM Baselines + +To compare SAE performance against full LLM residual stream baselines (enabled by default; set `include_llm_baseline=False` to skip them): + +```python +config = SparseProbingSaeProbesEvalConfig( + model_name="gpt2", + include_llm_baseline=True, # Enables baseline comparison + baseline_method="logreg", # Method for baseline probe (default) + # ... other config +) +``` + +This adds LLM baseline metrics to the output, allowing you to compare how well k SAE latents perform versus using all LLM dimensions. + +### Caching Model Activations for Faster Iteration + +Set `model_cache_path` to cache model activations across runs if you expect to rerun this eval for many different SAEs on the same model / layers. 
If this is not set, the eval will re-generate model activations every time the eval is run. diff --git a/sae_bench/evals/sparse_probing_sae_probes/__init__.py b/sae_bench/evals/sparse_probing_sae_probes/__init__.py new file mode 100644 index 00000000..7380c3f4 --- /dev/null +++ b/sae_bench/evals/sparse_probing_sae_probes/__init__.py @@ -0,0 +1,22 @@ +from .eval_config import SparseProbingSaeProbesEvalConfig +from .eval_output import ( + EVAL_TYPE_ID_SPARSE_PROBING_SAE_PROBES, + SaeProbesLlmMetrics, + SaeProbesMetricCategories, + SaeProbesResultDetail, + SaeProbesSaeMetrics, + SparseProbingSaeProbesEvalOutput, +) +from .main import create_config_and_selected_saes, run_eval + +__all__ = [ + "SparseProbingSaeProbesEvalConfig", + "EVAL_TYPE_ID_SPARSE_PROBING_SAE_PROBES", + "SaeProbesLlmMetrics", + "SaeProbesSaeMetrics", + "SaeProbesMetricCategories", + "SaeProbesResultDetail", + "SparseProbingSaeProbesEvalOutput", + "create_config_and_selected_saes", + "run_eval", +] diff --git a/sae_bench/evals/sparse_probing_sae_probes/eval_config.py b/sae_bench/evals/sparse_probing_sae_probes/eval_config.py new file mode 100644 index 00000000..769158b7 --- /dev/null +++ b/sae_bench/evals/sparse_probing_sae_probes/eval_config.py @@ -0,0 +1,68 @@ +from pydantic import Field +from pydantic.dataclasses import dataclass +from sae_probes import DATASETS + +from sae_bench.evals.base_eval_output import BaseEvalConfig + + +@dataclass +class SparseProbingSaeProbesEvalConfig(BaseEvalConfig): + model_name: str = Field( + default="", + title="Model Name", + description="TransformerLens model name used by sae-probes (e.g., 'gemma-2-2b').", + ) + + dataset_names: list[str] = Field( + default_factory=lambda: [*DATASETS], + title="Dataset Names", + description="List of dataset names.", + ) + + reg_type: str = Field( + default="l1", + title="Regularization Type", + description="Regularization used for sparse probing selection in sae-probes ('l1' or 'l2').", + ) + + setting: str = Field( + 
default="normal", + title="Data Balance Setting", + description="sae-probes benchmark setting: 'normal', 'scarcity', or 'imbalance'.", + ) + + ks: list[int] = Field( + default_factory=lambda: [1, 2, 5], + title="K Values", + description="List of K values (number of SAE features) to evaluate.", + ) + + binarize: bool = Field( + default=False, + title="Binarize Latents", + description="Whether to binarize SAE latents during probing (sae-probes option).", + ) + + results_path: str = Field( + default="artifacts/sparse_probing_sae_probes", + title="sae-probes Results Root", + description="Directory where sae-probes will save its per-dataset JSONs.", + ) + + model_cache_path: str | None = Field( + default=None, + title="Model Activations Cache", + description="Optional path where sae-probes will cache generated model activations.", + ) + + include_llm_baseline: bool = Field( + default=True, + title="Include LLM Baseline", + description="If True, also run sae-probes baselines on model residual stream and aggregate.", + ) + + baseline_method: str = Field( + default="logreg", + title="Baseline Method", + description="sae-probes baseline method (e.g., 'logreg').", + ) diff --git a/sae_bench/evals/sparse_probing_sae_probes/eval_output.py b/sae_bench/evals/sparse_probing_sae_probes/eval_output.py new file mode 100644 index 00000000..8cf541f9 --- /dev/null +++ b/sae_bench/evals/sparse_probing_sae_probes/eval_output.py @@ -0,0 +1,156 @@ +from pydantic import ConfigDict, Field +from pydantic.dataclasses import dataclass + +from sae_bench.evals.base_eval_output import ( + DEFAULT_DISPLAY, + BaseEvalOutput, + BaseMetricCategories, + BaseMetrics, + BaseResultDetail, +) +from sae_bench.evals.sparse_probing_sae_probes.eval_config import ( + SparseProbingSaeProbesEvalConfig, +) + +EVAL_TYPE_ID_SPARSE_PROBING_SAE_PROBES = "sparse_probing_sae_probes" + + +@dataclass +class SaeProbesLlmMetrics(BaseMetrics): + llm_test_accuracy: float | None = Field( + default=None, + title="LLM Test 
Accuracy", + description="Linear probe accuracy when training on the full LLM residual stream", + json_schema_extra=DEFAULT_DISPLAY, + ) + llm_test_auc: float | None = Field( + default=None, + title="LLM Test AUC", + description="Linear probe AUC when training on the full LLM residual stream", + json_schema_extra=DEFAULT_DISPLAY, + ) + llm_test_f1: float | None = Field( + default=None, + title="LLM Test F1", + description="Linear probe F1 score when training on the full LLM residual stream", + json_schema_extra=DEFAULT_DISPLAY, + ) + + +@dataclass +class SaeProbesSaeMetrics(BaseMetrics): + sae_test_accuracy: float | None = Field( + default=None, + title="SAE Test Accuracy", + description="Linear probe accuracy when trained on all SAE latents", + json_schema_extra=DEFAULT_DISPLAY, + ) + sae_top_1_test_accuracy: float | None = Field( + default=None, json_schema_extra=DEFAULT_DISPLAY + ) + sae_top_1_test_auc: float | None = Field( + default=None, json_schema_extra=DEFAULT_DISPLAY + ) + sae_top_1_test_f1: float | None = Field( + default=None, json_schema_extra=DEFAULT_DISPLAY + ) + sae_top_2_test_accuracy: float | None = Field( + default=None, json_schema_extra=DEFAULT_DISPLAY + ) + sae_top_2_test_auc: float | None = Field(default=None) + sae_top_2_test_f1: float | None = Field(default=None) + sae_top_5_test_accuracy: float | None = Field( + default=None, json_schema_extra=DEFAULT_DISPLAY + ) + sae_top_5_test_auc: float | None = Field(default=None) + sae_top_5_test_f1: float | None = Field(default=None) + sae_top_10_test_accuracy: float | None = Field(default=None) + sae_top_10_test_auc: float | None = Field(default=None) + sae_top_10_test_f1: float | None = Field(default=None) + sae_top_20_test_accuracy: float | None = Field(default=None) + sae_top_20_test_auc: float | None = Field(default=None) + sae_top_20_test_f1: float | None = Field(default=None) + sae_top_50_test_accuracy: float | None = Field(default=None) + sae_top_50_test_auc: float | None = 
Field(default=None) + sae_top_50_test_f1: float | None = Field(default=None) + sae_top_100_test_accuracy: float | None = Field(default=None) + sae_top_100_test_auc: float | None = Field(default=None) + sae_top_100_test_f1: float | None = Field(default=None) + + +@dataclass +class SaeProbesMetricCategories(BaseMetricCategories): + llm: SaeProbesLlmMetrics = Field( + title="LLM", + description="LLM metrics", + json_schema_extra=DEFAULT_DISPLAY, + ) + sae: SaeProbesSaeMetrics = Field( + title="SAE", + description="SAE metrics", + json_schema_extra=DEFAULT_DISPLAY, + ) + + +@dataclass +class SaeProbesResultDetail(BaseResultDetail): + dataset_name: str = Field(title="Dataset Name", description="Dataset name") + llm_test_accuracy: float | None = Field(default=None) + llm_test_auc: float | None = Field(default=None) + llm_test_f1: float | None = Field(default=None) + sae_test_accuracy: float | None = Field(default=None) + sae_top_1_test_accuracy: float | None = Field(default=None) + sae_top_1_test_auc: float | None = Field(default=None) + sae_top_1_test_f1: float | None = Field(default=None) + sae_top_2_test_accuracy: float | None = Field(default=None) + sae_top_2_test_auc: float | None = Field(default=None) + sae_top_2_test_f1: float | None = Field(default=None) + sae_top_5_test_accuracy: float | None = Field(default=None) + sae_top_5_test_auc: float | None = Field(default=None) + sae_top_5_test_f1: float | None = Field(default=None) + sae_top_10_test_accuracy: float | None = Field(default=None) + sae_top_10_test_auc: float | None = Field(default=None) + sae_top_10_test_f1: float | None = Field(default=None) + sae_top_20_test_accuracy: float | None = Field(default=None) + sae_top_20_test_auc: float | None = Field(default=None) + sae_top_20_test_f1: float | None = Field(default=None) + sae_top_50_test_accuracy: float | None = Field(default=None) + sae_top_50_test_auc: float | None = Field(default=None) + sae_top_50_test_f1: float | None = Field(default=None) + 
sae_top_100_test_accuracy: float | None = Field(default=None) + sae_top_100_test_auc: float | None = Field(default=None) + sae_top_100_test_f1: float | None = Field(default=None) + sae_metrics_by_k: dict[int, dict[str, float]] | None = Field( + default=None, + title="SAE Metrics by K", + description="Per-dataset metrics for arbitrary k values. Maps k -> {test_accuracy, test_auc, test_f1}", + ) + + +@dataclass(config=ConfigDict(title="Sparse Probing (sae-probes)")) +class SparseProbingSaeProbesEvalOutput( + BaseEvalOutput[ + SparseProbingSaeProbesEvalConfig, + SaeProbesMetricCategories, + SaeProbesResultDetail, + ] +): + """ + Wraps sae-probes sparse probing benchmark and collates per-dataset JSONs into a single SAEBench output. + """ + + eval_config: SparseProbingSaeProbesEvalConfig + eval_id: str + datetime_epoch_millis: int + eval_result_metrics: SaeProbesMetricCategories + eval_result_details: list[SaeProbesResultDetail] = Field( + default_factory=list, + title="Per-Dataset Sparse Probing Results", + description="Per-dataset probe accuracies aggregated from sae-probes output.", + ) + sae_metrics_by_k: dict[int, dict[str, float]] | None = Field( + default=None, + title="SAE Metrics by K", + description="SAE metrics for arbitrary k values. 
Maps k -> {test_accuracy, test_auc, test_f1}", + ) + eval_type_id: str = Field(default=EVAL_TYPE_ID_SPARSE_PROBING_SAE_PROBES) diff --git a/sae_bench/evals/sparse_probing_sae_probes/main.py b/sae_bench/evals/sparse_probing_sae_probes/main.py new file mode 100644 index 00000000..92221669 --- /dev/null +++ b/sae_bench/evals/sparse_probing_sae_probes/main.py @@ -0,0 +1,365 @@ +import argparse +import json +import os +import time +from dataclasses import asdict +from datetime import datetime +from pathlib import Path +from typing import Any + +from sae_lens import SAE +from sae_probes import run_baseline_evals, run_sae_evals +from tqdm import tqdm + +import sae_bench.sae_bench_utils.general_utils as general_utils +from sae_bench.evals.sparse_probing_sae_probes.eval_config import ( + SparseProbingSaeProbesEvalConfig, +) +from sae_bench.evals.sparse_probing_sae_probes.eval_output import ( + SaeProbesLlmMetrics, + SaeProbesMetricCategories, + SaeProbesResultDetail, + SaeProbesSaeMetrics, + SparseProbingSaeProbesEvalOutput, +) +from sae_bench.sae_bench_utils import ( + get_eval_uuid, + get_sae_bench_version, + get_sae_lens_version, +) +from sae_bench.sae_bench_utils.sae_selection_utils import get_saes_from_regex + + +def _sae_probes_results_glob( + results_root: str, model_name: str, setting: str, prefix: str = "sae_probes" +) -> list[Path]: + root = Path(results_root) / f"{prefix}_{model_name}" / f"{setting}_setting" + return sorted(list(root.glob("*.json"))) + + +def _parse_dataset_from_filename(path: Path) -> str: + # Filenames look like: "119_us_state_TX_blocks.4.hook_resid_post_l1.json" + # Dataset short name is the prefix until the first occurrence of "_blocks." + stem = path.stem + if "_blocks." 
in stem: + return stem.split("_blocks.")[0] + return stem + + +def _aggregate_metrics_from_sae_probes_json( + file_path: Path, +) -> dict[str, float]: + with open(file_path) as f: + data: list[dict[str, Any]] = json.load(f) + # sae-probes saves a list of entries, one per K (and possibly metadata entries) + # Each entry contains keys: {"k", "test_acc", "test_auc", "test_f1", ...} + k_to_metrics: dict[int, dict[str, float]] = {} + for entry in data: + if "k" in entry and "test_acc" in entry: + try: + k = int(entry["k"]) # type: ignore[arg-type] + metrics = { + "test_accuracy": float(entry["test_acc"]), # type: ignore[arg-type] + } + if "test_auc" in entry: + metrics["test_auc"] = float(entry["test_auc"]) # type: ignore[arg-type] + if "test_f1" in entry: + metrics["test_f1"] = float(entry["test_f1"]) # type: ignore[arg-type] + k_to_metrics[k] = metrics + except Exception: + continue + out: dict[str, float] = {} + for k, metrics in k_to_metrics.items(): + for metric_name, value in metrics.items(): + out[f"sae_top_{k}_{metric_name}"] = value + return out + + +def _mean_of_keys(dicts: list[dict[str, float]], key: str) -> float | None: + vals = [d[key] for d in dicts if key in d] + if not vals: + return None + return float(sum(vals) / len(vals)) + + +def run_eval( + config: SparseProbingSaeProbesEvalConfig, + selected_saes: list[tuple[str, SAE]] | list[tuple[str, str]], + device: str, + output_path: str, + force_rerun: bool = False, +) -> dict[str, dict[str, Any]]: + if config.setting != "normal": + raise NotImplementedError( + "Only 'normal' setting is supported for sparse_probing_sae_probes aggregation currently." 
+ ) + eval_instance_id = get_eval_uuid() + sae_lens_version = get_sae_lens_version() + sae_bench_commit_hash = get_sae_bench_version() + + os.makedirs(output_path, exist_ok=True) + os.makedirs(config.results_path, exist_ok=True) + + results_dict: dict[str, dict[str, Any]] = {} + + for sae_release, sae_object_or_id in tqdm( + selected_saes, desc="Running sae-probes on selected SAEs" + ): + sae_id, sae, sparsity = general_utils.load_and_format_sae( # type: ignore + sae_release, sae_object_or_id, device + ) + + sae_result_path = general_utils.get_results_filepath( + output_path, sae_release, sae_id + ) + + if os.path.exists(sae_result_path) and not force_rerun: + print(f"Skipping {sae_release}_{sae_id} as results already exist") + continue + + # Run sae-probes (idempotent; will skip if JSONs exist) + run_sae_evals( + sae=sae, + model_name=config.model_name, + hook_name=sae.cfg.hook_name, + reg_type=config.reg_type, # type: ignore[arg-type] + setting=config.setting, # type: ignore[arg-type] + ks=config.ks, + binarize=config.binarize, + results_path=config.results_path, + model_cache_path=config.model_cache_path, + datasets=config.dataset_names, + device=device, + ) + + # Collect per-dataset JSONs and collate (filter by hook/reg to avoid stale files) + expected_suffix = f"_{sae.cfg.hook_name}_{config.reg_type}.json" + json_files = [ + f + for f in _sae_probes_results_glob( + config.results_path, config.model_name, config.setting + ) + if f.name.endswith(expected_suffix) + ] + per_dataset_details: list[SaeProbesResultDetail] = [] + per_dataset_metric_dicts: list[dict[str, float]] = [] + for jf in json_files: + ds_name = _parse_dataset_from_filename(jf) + ds_metrics = _aggregate_metrics_from_sae_probes_json(jf) + per_dataset_metric_dicts.append( + {k: v for k, v in ds_metrics.items() if k.startswith("sae_top_")} + ) + + # Build sae_metrics_by_k dictionary for this dataset + ds_metrics_by_k: dict[int, dict[str, float]] = {} + for k in config.ks: + k_metrics = {} + for 
metric in ["test_accuracy", "test_auc", "test_f1"]: + key = f"sae_top_{k}_{metric}" + if key in ds_metrics: + k_metrics[metric] = ds_metrics[key] + if k_metrics: + ds_metrics_by_k[k] = k_metrics + + per_dataset_details.append( + SaeProbesResultDetail( + dataset_name=ds_name, + sae_metrics_by_k=ds_metrics_by_k if ds_metrics_by_k else None, + **ds_metrics, + ) + ) + + # Aggregate across datasets (mean per-k) + agg_metrics_dict: dict[str, float | None] = {} + agg_metrics_by_k: dict[int, dict[str, float]] = {} + for k in config.ks: + k_metrics: dict[str, float] = {} + for metric in ["test_accuracy", "test_auc", "test_f1"]: + key = f"sae_top_{k}_{metric}" + mean_val = _mean_of_keys(per_dataset_metric_dicts, key) + agg_metrics_dict[key] = mean_val + if mean_val is not None: + k_metrics[metric] = mean_val + if k_metrics: + agg_metrics_by_k[k] = k_metrics + + llm_metrics = SaeProbesLlmMetrics() + if config.include_llm_baseline: + # Run baseline evals (idempotent) and parse results to populate llm metrics + run_baseline_evals( + model_name=config.model_name, + hook_name=sae.cfg.hook_name, + setting=config.setting, # type: ignore[arg-type] + method=config.baseline_method, # type: ignore[arg-type] + results_path=config.results_path, + model_cache_path=config.model_cache_path, + datasets=config.dataset_names, + device=device, + ) + # Baseline JSON pattern: baseline_results_{model_name}/{setting}_setting/{dataset}_{hook}_{method}.json + baseline_suffix = f"_{sae.cfg.hook_name}_{config.baseline_method}.json" + baseline_files = [ + f + for f in _sae_probes_results_glob( + config.results_path, + config.model_name, + config.setting, + prefix="baseline_results", + ) + if f.name.endswith(baseline_suffix) + ] + # compute overall mean test_acc, test_auc, test_f1 across datasets + llm_accs: list[float] = [] + llm_aucs: list[float] = [] + llm_f1s: list[float] = [] + per_ds_metrics: dict[str, dict[str, float]] = {} + for bf in baseline_files: + ds_name = _parse_dataset_from_filename(bf) + 
with open(bf) as f: + entries: list[dict[str, Any]] = json.load(f) + # baselines save a single-element list + if entries and "test_acc" in entries[0]: + metrics = {} + if "test_acc" in entries[0]: + acc = float(entries[0]["test_acc"]) # type: ignore[arg-type] + metrics["test_accuracy"] = acc + llm_accs.append(acc) + if "test_auc" in entries[0]: + auc = float(entries[0]["test_auc"]) # type: ignore[arg-type] + metrics["test_auc"] = auc + llm_aucs.append(auc) + if "test_f1" in entries[0]: + f1 = float(entries[0]["test_f1"]) # type: ignore[arg-type] + metrics["test_f1"] = f1 + llm_f1s.append(f1) + per_ds_metrics[ds_name] = metrics + if llm_accs: + llm_metrics.llm_test_accuracy = float(sum(llm_accs) / len(llm_accs)) + if llm_aucs: + llm_metrics.llm_test_auc = float(sum(llm_aucs) / len(llm_aucs)) + if llm_f1s: + llm_metrics.llm_test_f1 = float(sum(llm_f1s) / len(llm_f1s)) + # attach per-dataset baseline to details + for detail in per_dataset_details: + if detail.dataset_name in per_ds_metrics: + metrics = per_ds_metrics[detail.dataset_name] + if "test_accuracy" in metrics: + detail.llm_test_accuracy = metrics["test_accuracy"] + if "test_auc" in metrics: + detail.llm_test_auc = metrics["test_auc"] + if "test_f1" in metrics: + detail.llm_test_f1 = metrics["test_f1"] + + eval_output = SparseProbingSaeProbesEvalOutput( + eval_config=config, + eval_id=eval_instance_id, + datetime_epoch_millis=int(datetime.now().timestamp() * 1000), + eval_result_metrics=SaeProbesMetricCategories( + llm=llm_metrics, + sae=SaeProbesSaeMetrics(**agg_metrics_dict), # type: ignore[arg-type] + ), + eval_result_details=per_dataset_details, + sae_metrics_by_k=agg_metrics_by_k if agg_metrics_by_k else None, + eval_result_unstructured=None, + sae_bench_commit_hash=sae_bench_commit_hash, + sae_lens_id=sae_id, + sae_lens_release_id=sae_release, + sae_lens_version=sae_lens_version, + sae_cfg_dict=asdict(sae.cfg), + ) + + results_dict[f"{sae_release}_{sae_id}"] = asdict(eval_output) + 
eval_output.to_json_file(sae_result_path, indent=2) + + return results_dict + + +def create_config_and_selected_saes( + args, +) -> tuple[SparseProbingSaeProbesEvalConfig, list[tuple[str, str]]]: + config = SparseProbingSaeProbesEvalConfig( + model_name=args.model_name, + reg_type=args.reg_type, + setting=args.setting, + ks=args.ks, + binarize=args.binarize, + results_path=args.results_path, + model_cache_path=args.model_cache_path, + ) + + selected_saes = get_saes_from_regex(args.sae_regex_pattern, args.sae_block_pattern) + assert len(selected_saes) > 0, "No SAEs selected" + + releases = set([release for release, _ in selected_saes]) + print(f"Selected SAEs from releases: {releases}") + for release, sae in selected_saes: + print(f"Sample SAEs: {release}, {sae}") + + return config, selected_saes + + +def arg_parser(): + parser = argparse.ArgumentParser(description="Run sae-probes sparse probing eval") + parser.add_argument("--model_name", type=str, required=True) + parser.add_argument("--sae_regex_pattern", type=str, required=True) + parser.add_argument("--sae_block_pattern", type=str, required=True) + parser.add_argument( + "--reg_type", + type=str, + default="l1", + choices=["l1", "l2"], + help="sae-probes regularization type", + ) + parser.add_argument( + "--setting", + type=str, + default="normal", + choices=["normal", "scarcity", "imbalance"], + help="sae-probes data-balance setting", + ) + parser.add_argument( + "--ks", + type=int, + nargs="+", + default=[1, 2, 5], + help="List of K values", + ) + parser.add_argument("--binarize", action="store_true") + parser.add_argument( + "--results_path", + type=str, + default="artifacts/sparse_probing_sae_probes", + help="Directory where sae-probes writes JSONs", + ) + parser.add_argument( + "--model_cache_path", + type=str, + default=None, + help="Optional directory to persist model activations", + ) + parser.add_argument( + "--output_folder", + type=str, + default="eval_results/sparse_probing_sae_probes", + 
help="SAEBench output folder", + ) + parser.add_argument("--force_rerun", action="store_true") + return parser + + +if __name__ == "__main__": + args = arg_parser().parse_args() + device = general_utils.setup_environment() + + start_time = time.time() + config, selected_saes = create_config_and_selected_saes(args) + os.makedirs(args.output_folder, exist_ok=True) + run_eval( + config, + selected_saes, + device, + args.output_folder, + force_rerun=args.force_rerun, + ) + end_time = time.time() + print(f"Finished evaluation in {end_time - start_time} seconds") diff --git a/tests/conftest.py b/tests/conftest.py index ce6b2ac6..cccc4676 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -24,9 +24,11 @@ def fake_mistral_tokenizer(): @pytest.fixture def gpt2_l4_sae() -> SAE: - return SAE.from_pretrained( + sae = SAE.from_pretrained( "gpt2-small-res-jb", "blocks.4.hook_resid_pre", device="cpu" )[0] + sae.fold_W_dec_norm() + return sae @pytest.fixture @@ -40,9 +42,11 @@ def gpt2_l4_sae_sparsity() -> torch.Tensor: @pytest.fixture def gpt2_l5_sae() -> SAE: - return SAE.from_pretrained( + sae = SAE.from_pretrained( "gpt2-small-res-jb", "blocks.5.hook_resid_pre", device="cpu" )[0] + sae.fold_W_dec_norm() + return sae @pytest.fixture diff --git a/tests/unit/__init__.py b/tests/unit/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/unit/evals/__init__.py b/tests/unit/evals/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/unit/evals/absorption/__init__.py b/tests/unit/evals/absorption/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/unit/evals/autointerp/__init__.py b/tests/unit/evals/autointerp/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/unit/evals/scr_and_tpp/__init__.py b/tests/unit/evals/scr_and_tpp/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/unit/evals/sparse_probing_sae_probes/__init__.py 
b/tests/unit/evals/sparse_probing_sae_probes/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/unit/evals/sparse_probing_sae_probes/test_main.py b/tests/unit/evals/sparse_probing_sae_probes/test_main.py new file mode 100644 index 00000000..34c7be14 --- /dev/null +++ b/tests/unit/evals/sparse_probing_sae_probes/test_main.py @@ -0,0 +1,231 @@ +import json +from pathlib import Path + +from sae_lens.sae import SAE + +from sae_bench.evals.sparse_probing_sae_probes.eval_config import ( + SparseProbingSaeProbesEvalConfig, +) +from sae_bench.evals.sparse_probing_sae_probes.eval_output import ( + SparseProbingSaeProbesEvalOutput, +) +from sae_bench.evals.sparse_probing_sae_probes.main import run_eval + + +def test_run_eval_without_baselines(gpt2_l4_sae: SAE, tmp_path: Path): + output_path = tmp_path / "test_output" + artifacts_path = tmp_path / "test_artifacts" + model_cache_path = tmp_path / "model_cache" + config = SparseProbingSaeProbesEvalConfig( + model_name="gpt2", + include_llm_baseline=False, + model_cache_path=str(model_cache_path), + results_path=str(artifacts_path), + dataset_names=["118_us_state_CA", "119_us_state_TX"], + ) + results_dict = run_eval( + config, + [("gpt2_l4_sae", gpt2_l4_sae)], + device="cpu", + output_path=str(output_path), + ) + + assert isinstance(results_dict, dict) + assert len(results_dict) == 1 + assert "gpt2_l4_sae_custom_sae" in results_dict + + result_data = results_dict["gpt2_l4_sae_custom_sae"] + assert isinstance(result_data, dict) + assert result_data["eval_type_id"] == "sparse_probing_sae_probes" + assert result_data["sae_lens_release_id"] == "gpt2_l4_sae" + assert result_data["sae_lens_id"] == "custom_sae" + + expected_output_file = output_path / "gpt2_l4_sae_custom_sae_eval_results.json" + assert expected_output_file.exists(), "Main output JSON file should exist" + + with open(expected_output_file) as f: + output_data = json.load(f) + + assert result_data["eval_type_id"] == output_data["eval_type_id"] + 
assert result_data["sae_lens_release_id"] == output_data["sae_lens_release_id"] + assert result_data["eval_result_metrics"] == output_data["eval_result_metrics"] + + eval_output = SparseProbingSaeProbesEvalOutput(**output_data) + + assert eval_output.eval_type_id == "sparse_probing_sae_probes" + assert eval_output.sae_lens_release_id == "gpt2_l4_sae" + assert eval_output.sae_lens_id == "custom_sae" + assert eval_output.eval_config.model_name == "gpt2" + + assert eval_output.eval_result_metrics.sae.sae_top_1_test_accuracy is not None + assert 0 <= eval_output.eval_result_metrics.sae.sae_top_1_test_accuracy <= 1 + assert eval_output.eval_result_metrics.sae.sae_top_1_test_auc is not None + assert 0 <= eval_output.eval_result_metrics.sae.sae_top_1_test_auc <= 1 + assert eval_output.eval_result_metrics.sae.sae_top_1_test_f1 is not None + assert 0 <= eval_output.eval_result_metrics.sae.sae_top_1_test_f1 <= 1 + assert eval_output.eval_result_metrics.sae.sae_top_2_test_accuracy is not None + assert 0 <= eval_output.eval_result_metrics.sae.sae_top_2_test_accuracy <= 1 + assert eval_output.eval_result_metrics.sae.sae_top_5_test_accuracy is not None + assert 0 <= eval_output.eval_result_metrics.sae.sae_top_5_test_accuracy <= 1 + + assert eval_output.eval_result_metrics.llm.llm_test_accuracy is None + assert eval_output.eval_result_metrics.llm.llm_test_auc is None + assert eval_output.eval_result_metrics.llm.llm_test_f1 is None + + assert eval_output.sae_metrics_by_k is not None + assert set(eval_output.sae_metrics_by_k.keys()) == {1, 2, 5} + + assert len(eval_output.eval_result_details) == 2 + dataset_names = {detail.dataset_name for detail in eval_output.eval_result_details} + assert dataset_names == {"118_us_state_CA", "119_us_state_TX"} + + for detail in eval_output.eval_result_details: + assert detail.sae_top_1_test_accuracy is not None + assert 0 <= detail.sae_top_1_test_accuracy <= 1 + assert detail.sae_top_1_test_auc is not None + assert 0 <= detail.sae_top_1_test_auc 
<= 1 + assert detail.sae_top_1_test_f1 is not None + assert 0 <= detail.sae_top_1_test_f1 <= 1 + assert detail.llm_test_accuracy is None + assert detail.llm_test_auc is None + assert detail.llm_test_f1 is None + + sae_probes_results_dir = artifacts_path / "sae_probes_gpt2" / "normal_setting" + assert sae_probes_results_dir.exists() + json_files = list(sae_probes_results_dir.glob("*.json")) + assert len(json_files) >= 2 + + +def test_run_eval_with_baselines(gpt2_l4_sae: SAE, tmp_path: Path): + output_path = tmp_path / "test_output" + artifacts_path = tmp_path / "test_artifacts" + model_cache_path = tmp_path / "model_cache" + config = SparseProbingSaeProbesEvalConfig( + model_name="gpt2", + include_llm_baseline=True, + model_cache_path=str(model_cache_path), + results_path=str(artifacts_path), + dataset_names=["118_us_state_CA"], + ) + results_dict = run_eval( + config, + [("gpt2_l4_sae", gpt2_l4_sae)], + device="cpu", + output_path=str(output_path), + ) + + assert isinstance(results_dict, dict) + assert len(results_dict) == 1 + assert "gpt2_l4_sae_custom_sae" in results_dict + + result_data = results_dict["gpt2_l4_sae_custom_sae"] + assert isinstance(result_data, dict) + assert "eval_result_metrics" in result_data + assert "llm" in result_data["eval_result_metrics"] + assert result_data["eval_result_metrics"]["llm"]["llm_test_accuracy"] is not None + + expected_output_file = output_path / "gpt2_l4_sae_custom_sae_eval_results.json" + assert expected_output_file.exists() + + with open(expected_output_file) as f: + output_data = json.load(f) + + assert result_data["eval_type_id"] == output_data["eval_type_id"] + assert result_data["sae_lens_release_id"] == output_data["sae_lens_release_id"] + assert result_data["eval_result_metrics"] == output_data["eval_result_metrics"] + + eval_output = SparseProbingSaeProbesEvalOutput(**output_data) + + assert eval_output.eval_result_metrics.llm.llm_test_accuracy is not None + assert 0 <= 
eval_output.eval_result_metrics.llm.llm_test_accuracy <= 1 + assert eval_output.eval_result_metrics.llm.llm_test_auc is not None + assert 0 <= eval_output.eval_result_metrics.llm.llm_test_auc <= 1 + assert eval_output.eval_result_metrics.llm.llm_test_f1 is not None + assert 0 <= eval_output.eval_result_metrics.llm.llm_test_f1 <= 1 + + assert len(eval_output.eval_result_details) == 1 + detail = eval_output.eval_result_details[0] + assert detail.dataset_name == "118_us_state_CA" + assert detail.llm_test_accuracy is not None + assert 0 <= detail.llm_test_accuracy <= 1 + assert detail.llm_test_auc is not None + assert 0 <= detail.llm_test_auc <= 1 + assert detail.llm_test_f1 is not None + assert 0 <= detail.llm_test_f1 <= 1 + assert detail.sae_top_1_test_accuracy is not None + assert detail.sae_top_1_test_auc is not None + assert detail.sae_top_1_test_f1 is not None + + baseline_results_dir = artifacts_path / "baseline_results_gpt2" / "normal_setting" + assert baseline_results_dir.exists() + baseline_json_files = list(baseline_results_dir.glob("*.json")) + assert len(baseline_json_files) >= 1 + + +def test_run_eval_with_custom_ks(gpt2_l4_sae: SAE, tmp_path: Path): + output_path = tmp_path / "test_output" + artifacts_path = tmp_path / "test_artifacts" + model_cache_path = tmp_path / "model_cache" + custom_ks = [3, 7, 15] + config = SparseProbingSaeProbesEvalConfig( + model_name="gpt2", + include_llm_baseline=True, + model_cache_path=str(model_cache_path), + results_path=str(artifacts_path), + dataset_names=["118_us_state_CA"], + ks=custom_ks, + ) + results_dict = run_eval( + config, + [("gpt2_l4_sae", gpt2_l4_sae)], + device="cpu", + output_path=str(output_path), + ) + + assert isinstance(results_dict, dict) + assert len(results_dict) == 1 + + expected_output_file = output_path / "gpt2_l4_sae_custom_sae_eval_results.json" + assert expected_output_file.exists() + + with open(expected_output_file) as f: + output_data = json.load(f) + + eval_output = 
SparseProbingSaeProbesEvalOutput(**output_data) + + assert eval_output.eval_result_metrics.llm.llm_test_accuracy is not None + assert 0 <= eval_output.eval_result_metrics.llm.llm_test_accuracy <= 1 + assert eval_output.eval_result_metrics.llm.llm_test_auc is not None + assert 0 <= eval_output.eval_result_metrics.llm.llm_test_auc <= 1 + assert eval_output.eval_result_metrics.llm.llm_test_f1 is not None + assert 0 <= eval_output.eval_result_metrics.llm.llm_test_f1 <= 1 + + assert "sae_metrics_by_k" in output_data + sae_metrics_by_k = eval_output.sae_metrics_by_k + assert sae_metrics_by_k is not None + assert set(sae_metrics_by_k.keys()) == {3, 7, 15} + + for k in custom_ks: + metrics = sae_metrics_by_k[k] + assert "test_accuracy" in metrics + assert "test_auc" in metrics + assert "test_f1" in metrics + assert 0 <= metrics["test_accuracy"] <= 1 + assert 0 <= metrics["test_auc"] <= 1 + assert 0 <= metrics["test_f1"] <= 1 + + detail = eval_output.eval_result_details[0] + assert detail.dataset_name == "118_us_state_CA" + assert detail.llm_test_accuracy is not None + assert 0 <= detail.llm_test_accuracy <= 1 + assert detail.llm_test_auc is not None + assert 0 <= detail.llm_test_auc <= 1 + assert detail.llm_test_f1 is not None + assert 0 <= detail.llm_test_f1 <= 1 + assert detail.sae_metrics_by_k is not None + assert set(detail.sae_metrics_by_k.keys()) == {3, 7, 15} + + for k in custom_ks: + metrics = detail.sae_metrics_by_k[k] + assert "test_accuracy" in metrics + assert 0 <= metrics["test_accuracy"] <= 1 From e19650649fd431d2e18333de9756ce3d6ca2be39 Mon Sep 17 00:00:00 2001 From: David Chanin Date: Sat, 4 Oct 2025 18:41:24 +0100 Subject: [PATCH 2/7] updating more docs / eval locations --- README.md | 3 +-- sae_bench/custom_saes/run_all_evals_custom_saes.py | 14 ++++++++++++++ .../run_all_evals_dictionary_learning_saes.py | 14 ++++++++++++++ 3 files changed, 29 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index ee7f9f69..24bf2a98 100644 --- a/README.md 
+++ b/README.md @@ -9,7 +9,6 @@ - [Training Your Own SAEs](#training-your-own-saes) - [Graphing Results](#graphing-results) - ## Overview SAE Bench is a comprehensive suite of 8 evaluations for Sparse Autoencoder (SAE) models: @@ -21,6 +20,7 @@ SAE Bench is a comprehensive suite of 8 evaluations for Sparse Autoencoder (SAE) - **[Spurious Correlation Removal (SCR)](https://arxiv.org/abs/2411.18895)** - **[Targeted Probe Pertubation (TPP)](https://arxiv.org/abs/2411.18895)** - **Sparse Probing** +- **[Sparse Probing (SAE Probes version)](https://arxiv.org/pdf/2502.16681)** - **[Unlearning](https://arxiv.org/abs/2410.19278)** For more information, refer to our [blog post](https://www.neuronpedia.org/sae-bench/info). @@ -136,7 +136,6 @@ The total evaluation time for a single SAE across all benchmarks is approximatel | RAVEL | 45 | 45 | | **Total** | **110** | **152** | - # SAE Bench Baseline Suite We provide a suite of baseline SAEs. We have the following 7 SAE varieties: diff --git a/sae_bench/custom_saes/run_all_evals_custom_saes.py b/sae_bench/custom_saes/run_all_evals_custom_saes.py index 87ed0d4c..b87e18e4 100644 --- a/sae_bench/custom_saes/run_all_evals_custom_saes.py +++ b/sae_bench/custom_saes/run_all_evals_custom_saes.py @@ -10,6 +10,7 @@ import sae_bench.evals.ravel.main as ravel import sae_bench.evals.scr_and_tpp.main as scr_and_tpp import sae_bench.evals.sparse_probing.main as sparse_probing +import sae_bench.evals.sparse_probing_sae_probes.main as sparse_probing_sae_probes import sae_bench.evals.unlearning.main as unlearning import sae_bench.sae_bench_utils.general_utils as general_utils @@ -37,6 +38,7 @@ "scr": "eval_results/scr", "tpp": "eval_results/tpp", "sparse_probing": "eval_results/sparse_probing", + "sparse_probing_sae_probes": "eval_results/sparse_probing_sae_probes", "unlearning": "eval_results/unlearning", "ravel": "eval_results/ravel", } @@ -171,6 +173,17 @@ def run_evals( save_activations=save_activations, ) ), + "sparse_probing_sae_probes": 
( + lambda: sparse_probing_sae_probes.run_eval( + sparse_probing_sae_probes.SparseProbingSaeProbesEvalConfig( + model_name=model_name, + ), + selected_saes, + device, + "eval_results/sparse_probing_sae_probes", + force_rerun, + ) + ), "unlearning": ( lambda: unlearning.run_eval( unlearning.UnlearningEvalConfig( @@ -237,6 +250,7 @@ def run_evals( "scr", "tpp", "sparse_probing", + "sparse_probing_sae_probes", "unlearning", ] diff --git a/sae_bench/custom_saes/run_all_evals_dictionary_learning_saes.py b/sae_bench/custom_saes/run_all_evals_dictionary_learning_saes.py index f10a49f3..8270130b 100644 --- a/sae_bench/custom_saes/run_all_evals_dictionary_learning_saes.py +++ b/sae_bench/custom_saes/run_all_evals_dictionary_learning_saes.py @@ -17,6 +17,7 @@ import sae_bench.evals.ravel.main as ravel import sae_bench.evals.scr_and_tpp.main as scr_and_tpp import sae_bench.evals.sparse_probing.main as sparse_probing +import sae_bench.evals.sparse_probing_sae_probes.main as sparse_probing_sae_probes import sae_bench.evals.unlearning.main as unlearning import sae_bench.sae_bench_utils.general_utils as general_utils @@ -48,6 +49,7 @@ "scr": "eval_results/scr", "tpp": "eval_results/tpp", "sparse_probing": "eval_results/sparse_probing", + "sparse_probing_sae_probes": "eval_results/sparse_probing_sae_probes", "unlearning": "eval_results/unlearning", "ravel": "eval_results/ravel", } @@ -278,6 +280,17 @@ def run_evals( save_activations=True, ) ), + "sparse_probing_sae_probes": ( + lambda selected_saes, is_final: sparse_probing_sae_probes.run_eval( + sparse_probing_sae_probes.SparseProbingSaeProbesEvalConfig( + model_name=model_name, + ), + selected_saes, + device, + "eval_results/sparse_probing_sae_probes", + force_rerun, + ) + ), "unlearning": ( lambda selected_saes, is_final: unlearning.run_eval( unlearning.UnlearningEvalConfig( @@ -375,6 +388,7 @@ def run_evals( "scr", "tpp", "sparse_probing", + "sparse_probing_sae_probes", "autointerp", # "unlearning", "ravel", From 
db48f97dc0c850752cbd0fa536677e48d666178e Mon Sep 17 00:00:00 2001 From: David Chanin Date: Tue, 23 Dec 2025 15:35:38 -0500 Subject: [PATCH 3/7] fixing formatting and import in test --- tests/acceptance/test_meta_structure.py | 20 +++++++++++++------ .../sparse_probing_sae_probes/test_main.py | 2 +- 2 files changed, 15 insertions(+), 7 deletions(-) diff --git a/tests/acceptance/test_meta_structure.py b/tests/acceptance/test_meta_structure.py index 8c26ea7f..b1f2b57b 100644 --- a/tests/acceptance/test_meta_structure.py +++ b/tests/acceptance/test_meta_structure.py @@ -54,9 +54,17 @@ def test_meta_structure_eval_matches_fixture(tmp_path): actual = json.load(f) actual_metrics = _load_metrics(actual) - assert pytest.approx( - expected_metrics["decoder_fraction_variance_explained"], rel=1e-2, - ) == actual_metrics["decoder_fraction_variance_explained"] - assert pytest.approx( - expected_metrics["final_reconstruction_mse"], rel=1e-2, - ) == actual_metrics["final_reconstruction_mse"] + assert ( + pytest.approx( + expected_metrics["decoder_fraction_variance_explained"], + rel=1e-2, + ) + == actual_metrics["decoder_fraction_variance_explained"] + ) + assert ( + pytest.approx( + expected_metrics["final_reconstruction_mse"], + rel=1e-2, + ) + == actual_metrics["final_reconstruction_mse"] + ) diff --git a/tests/unit/evals/sparse_probing_sae_probes/test_main.py b/tests/unit/evals/sparse_probing_sae_probes/test_main.py index 34c7be14..316060f4 100644 --- a/tests/unit/evals/sparse_probing_sae_probes/test_main.py +++ b/tests/unit/evals/sparse_probing_sae_probes/test_main.py @@ -1,7 +1,7 @@ import json from pathlib import Path -from sae_lens.sae import SAE +from sae_lens import SAE from sae_bench.evals.sparse_probing_sae_probes.eval_config import ( SparseProbingSaeProbesEvalConfig, From d3f88362e1e2391493399e70df5071fbd79655ae Mon Sep 17 00:00:00 2001 From: David Chanin Date: Tue, 23 Dec 2025 15:40:47 -0500 Subject: [PATCH 4/7] ignoring type error --- 
sae_bench/evals/meta_structure/eval_output.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sae_bench/evals/meta_structure/eval_output.py b/sae_bench/evals/meta_structure/eval_output.py index 74b679f5..b1f5ed73 100644 --- a/sae_bench/evals/meta_structure/eval_output.py +++ b/sae_bench/evals/meta_structure/eval_output.py @@ -53,7 +53,7 @@ class MetaStructureEvalOutput( eval_id: str datetime_epoch_millis: int eval_result_metrics: MetaStructureMetricCategories - eval_result_details: list[BaseResultDetail] | None = None + eval_result_details: list[BaseResultDetail] | None = None # pyright: ignore[reportIncompatibleVariableOverride] eval_type_id: str = Field( default=EVAL_TYPE_ID_META_STRUCTURE, title="Eval Type ID", From 5a3a63d7af20631bde416919fdc5d71edba09296 Mon Sep 17 00:00:00 2001 From: David Chanin Date: Tue, 23 Dec 2025 17:02:38 -0500 Subject: [PATCH 5/7] updating README --- README.md | 27 ++++++++++--------- .../evals/sparse_probing_sae_probes/README.md | 4 +-- .../sparse_probing_sae_probes/eval_config.py | 2 +- 3 files changed, 17 insertions(+), 16 deletions(-) diff --git a/README.md b/README.md index b9e656a4..565c5285 100644 --- a/README.md +++ b/README.md @@ -124,19 +124,20 @@ The computational requirements for running SAEBench evaluations were measured on - **Setup Phase**: Includes operations like precomputing model activations, training probes, or other one-time preprocessing steps which can be reused across multiple SAE evaluations. - **Per-SAE Evaluation Time**: The time required to evaluate a single SAE once the setup is complete. -The total evaluation time for a single SAE across all benchmarks is approximately **110 minutes**, with an additional **152 minutes** of setup time. Note that actual runtimes may vary significantly based on factors such as SAE dictionary size, base model, and GPU selection. 
- -| Evaluation Type | Avg Time per SAE (min) | Setup Time (min) | -| --------------- | ---------------------- | ---------------- | -| Absorption | 26 | 33 | -| Core | 9 | 0 | -| SCR | 6 | 22 | -| TPP | 2 | 5 | -| Sparse Probing | 3 | 15 | -| Auto-Interp | 9 | 0 | -| Unlearning | 10 | 33 | -| RAVEL | 45 | 45 | -| **Total** | **110** | **152** | +The total evaluation time for a single SAE across all benchmarks is approximately **113 minutes**, with an additional **177 minutes** of setup time. Note that actual runtimes may vary significantly based on factors such as SAE dictionary size, base model, and GPU selection. + +| Evaluation Type | Avg Time per SAE (min) | Setup Time (min) | +| --------------------------- | ---------------------- | ---------------- | +| Absorption | 26 | 33 | +| Core | 9 | 0 | +| SCR | 6 | 22 | +| TPP | 2 | 5 | +| Sparse Probing | 3 | 15 | +| Sparse Probing (SAE Probes) | 3 | 25 | +| Auto-Interp | 9 | 0 | +| Unlearning | 10 | 33 | +| RAVEL | 45 | 45 | +| **Total** | **113** | **177** | # SAE Bench Baseline Suite diff --git a/sae_bench/evals/sparse_probing_sae_probes/README.md b/sae_bench/evals/sparse_probing_sae_probes/README.md index dadf4998..21e0c5b6 100644 --- a/sae_bench/evals/sparse_probing_sae_probes/README.md +++ b/sae_bench/evals/sparse_probing_sae_probes/README.md @@ -24,7 +24,7 @@ python sae_bench/evals/sparse_probing_sae_probes/main.py \ - `--setting`: Data balance setting (`normal`, `scarcity`, or `imbalance`, default: `normal`) - `--binarize`: Whether to binarize probe targets (flag, default: False) - `--results_path`: Directory where sae-probes writes intermediate JSONs (default: `artifacts/sparse_probing_sae_probes`) -- `--model_cache_path`: Optional directory to cache model activations for faster re-runs +- `--model_cache_path`: Optional directory to cache model activations for faster re-runs (default: `artifacts/sparse_probing_sae_probes--model_acts_cache`) - `--output_folder`: Where to save SAEBench output files (default: 
`eval_results/sparse_probing_sae_probes`) - `--force_rerun`: Force re-running the eval even if results exist (flag) @@ -153,4 +153,4 @@ This adds LLM baseline metrics to the output, allowing you to compare how well k ### Caching model activations for Faster Iteration -Set `model_cache_path` to cache model activations across runs if you expect to rerun this eval for lots of different SAEs on the same model / layers. If this is not set, the eval will re-generate model activations every time the eval is run. +Set `model_cache_path` to cache model activations across runs if you expect to rerun this eval for lots of different SAEs on the same model / layers. Set this to `None` to disable caching. diff --git a/sae_bench/evals/sparse_probing_sae_probes/eval_config.py b/sae_bench/evals/sparse_probing_sae_probes/eval_config.py index 769158b7..38ccefda 100644 --- a/sae_bench/evals/sparse_probing_sae_probes/eval_config.py +++ b/sae_bench/evals/sparse_probing_sae_probes/eval_config.py @@ -50,7 +50,7 @@ class SparseProbingSaeProbesEvalConfig(BaseEvalConfig): ) model_cache_path: str | None = Field( - default=None, + default="artifacts/sparse_probing_sae_probes--model_acts_cache", title="Model Activations Cache", description="Optional path where sae-probes will cache generated model activations.", ) From 7e70996afca2ed2804fca631bb1baa01dfb12b76 Mon Sep 17 00:00:00 2001 From: David Chanin Date: Tue, 23 Dec 2025 18:58:10 -0500 Subject: [PATCH 6/7] nest individual results under SAE name --- sae_bench/evals/sparse_probing_sae_probes/main.py | 7 +++++-- tests/unit/evals/sparse_probing_sae_probes/test_main.py | 4 +++- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/sae_bench/evals/sparse_probing_sae_probes/main.py b/sae_bench/evals/sparse_probing_sae_probes/main.py index 92221669..2523c266 100644 --- a/sae_bench/evals/sparse_probing_sae_probes/main.py +++ b/sae_bench/evals/sparse_probing_sae_probes/main.py @@ -117,6 +117,9 @@ def run_eval( print(f"Skipping 
{sae_release}_{sae_id} as results already exist") continue + sae_results_path = os.path.join(config.results_path, f"{sae_release}_{sae_id}") + os.makedirs(sae_results_path, exist_ok=True) + # Run sae-probes (idempotent; will skip if JSONs exist) run_sae_evals( sae=sae, @@ -126,7 +129,7 @@ def run_eval( setting=config.setting, # type: ignore[arg-type] ks=config.ks, binarize=config.binarize, - results_path=config.results_path, + results_path=sae_results_path, model_cache_path=config.model_cache_path, datasets=config.dataset_names, device=device, @@ -137,7 +140,7 @@ def run_eval( json_files = [ f for f in _sae_probes_results_glob( - config.results_path, config.model_name, config.setting + sae_results_path, config.model_name, config.setting ) if f.name.endswith(expected_suffix) ] diff --git a/tests/unit/evals/sparse_probing_sae_probes/test_main.py b/tests/unit/evals/sparse_probing_sae_probes/test_main.py index 316060f4..4c7f5579 100644 --- a/tests/unit/evals/sparse_probing_sae_probes/test_main.py +++ b/tests/unit/evals/sparse_probing_sae_probes/test_main.py @@ -90,7 +90,9 @@ def test_run_eval_without_baselines(gpt2_l4_sae: SAE, tmp_path: Path): assert detail.llm_test_auc is None assert detail.llm_test_f1 is None - sae_probes_results_dir = artifacts_path / "sae_probes_gpt2" / "normal_setting" + sae_probes_results_dir = ( + artifacts_path / "gpt2_l4_sae_custom_sae" / "sae_probes_gpt2" / "normal_setting" + ) assert sae_probes_results_dir.exists() json_files = list(sae_probes_results_dir.glob("*.json")) assert len(json_files) >= 2 From 9eaffa87714d178ad17945cc095b6ce6acbf8a81 Mon Sep 17 00:00:00 2001 From: David Chanin Date: Tue, 23 Dec 2025 19:39:47 -0500 Subject: [PATCH 7/7] updating readme --- README.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 565c5285..67cfb15f 100644 --- a/README.md +++ b/README.md @@ -124,7 +124,7 @@ The computational requirements for running SAEBench evaluations were measured on - **Setup 
Phase**: Includes operations like precomputing model activations, training probes, or other one-time preprocessing steps which can be reused across multiple SAE evaluations. - **Per-SAE Evaluation Time**: The time required to evaluate a single SAE once the setup is complete. -The total evaluation time for a single SAE across all benchmarks is approximately **113 minutes**, with an additional **177 minutes** of setup time. Note that actual runtimes may vary significantly based on factors such as SAE dictionary size, base model, and GPU selection. +The total evaluation time for a single SAE across all benchmarks is approximately **115 minutes**, with an additional **177 minutes** of setup time. Note that actual runtimes may vary significantly based on factors such as SAE dictionary size, base model, and GPU selection. | Evaluation Type | Avg Time per SAE (min) | Setup Time (min) | | --------------------------- | ---------------------- | ---------------- | @@ -133,11 +133,11 @@ The total evaluation time for a single SAE across all benchmarks is approximatel | SCR | 6 | 22 | | TPP | 2 | 5 | | Sparse Probing | 3 | 15 | -| Sparse Probing (SAE Probes) | 3 | 25 | +| Sparse Probing (SAE Probes) | 5 | 25 | | Auto-Interp | 9 | 0 | | Unlearning | 10 | 33 | | RAVEL | 45 | 45 | -| **Total** | **113** | **177** | +| **Total** | **115** | **177** | # SAE Bench Baseline Suite