Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ matplotlib = ">=3.8.4"
tabulate = ">=0.9.0"
openai = ">=1.0.0"
torchvision = ">=0.16.1" # required for what I believe are nnsight related issues
sae-probes = "^0.3.0"
sae-probes = "^0.4.0"
datasets = ">=3.0.0,<4.0.0" # skylion openwebtext fails to load with datasets 4.0.0 currently, pending https://huggingface.co/datasets/Skylion007/openwebtext/discussions/22

# If running into dependency issues these are tested and working
Expand Down
1 change: 1 addition & 0 deletions sae_bench/custom_saes/run_all_evals_custom_saes.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,7 @@ def run_evals(
lambda: sparse_probing_sae_probes.run_eval(
sparse_probing_sae_probes.SparseProbingSaeProbesEvalConfig(
model_name=model_name,
random_seed=RANDOM_SEED,
),
selected_saes,
device,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,7 @@ def run_evals(
lambda selected_saes, is_final: sparse_probing_sae_probes.run_eval(
sparse_probing_sae_probes.SparseProbingSaeProbesEvalConfig(
model_name=model_name,
random_seed=random_seed,
),
selected_saes,
device,
Expand Down
6 changes: 6 additions & 0 deletions sae_bench/evals/sparse_probing_sae_probes/eval_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,12 @@ class SparseProbingSaeProbesEvalConfig(BaseEvalConfig):
description="TransformerLens model name used by sae-probes (e.g., 'gemma-2-2b').",
)

random_seed: int = Field(
default=42,
title="Random Seed",
description="Random seed",
)

dataset_names: list[str] = Field(
default_factory=lambda: [*DATASETS],
title="Dataset Names",
Expand Down
4 changes: 4 additions & 0 deletions sae_bench/evals/sparse_probing_sae_probes/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ def run_eval(
model_cache_path=config.model_cache_path,
datasets=config.dataset_names,
device=device,
seed=config.random_seed,
)

# Collect per-dataset JSONs and collate (filter by hook/reg to avoid stale files)
Expand Down Expand Up @@ -198,6 +199,7 @@ def run_eval(
model_cache_path=config.model_cache_path,
datasets=config.dataset_names,
device=device,
seed=config.random_seed,
)
# Baseline JSON pattern: baseline_results_{model_name}/{setting}_setting/{dataset}_{hook}_{method}.json
baseline_suffix = f"_{sae.cfg.hook_name}_{config.baseline_method}.json"
Expand Down Expand Up @@ -282,6 +284,7 @@ def create_config_and_selected_saes(
) -> tuple[SparseProbingSaeProbesEvalConfig, list[tuple[str, str]]]:
config = SparseProbingSaeProbesEvalConfig(
model_name=args.model_name,
random_seed=args.random_seed,
reg_type=args.reg_type,
setting=args.setting,
ks=args.ks,
Expand Down Expand Up @@ -313,6 +316,7 @@ def arg_parser():
choices=["l1", "l2"],
help="sae-probes regularization type",
)
parser.add_argument("--random_seed", type=int, default=42)
parser.add_argument(
"--setting",
type=str,
Expand Down
36 changes: 36 additions & 0 deletions tests/unit/evals/sparse_probing_sae_probes/test_main.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import json
from pathlib import Path
from unittest.mock import patch

from sae_lens import SAE

Expand Down Expand Up @@ -231,3 +232,38 @@ def test_run_eval_with_custom_ks(gpt2_l4_sae: SAE, tmp_path: Path):
metrics = detail.sae_metrics_by_k[k]
assert "test_accuracy" in metrics
assert 0 <= metrics["test_accuracy"] <= 1


def test_run_eval_propagates_random_seed(gpt2_l4_sae: SAE, tmp_path: Path):
custom_seed = 1234
config = SparseProbingSaeProbesEvalConfig(
model_name="gpt2",
include_llm_baseline=True,
model_cache_path=str(tmp_path / "model_cache"),
results_path=str(tmp_path / "test_artifacts"),
dataset_names=["118_us_state_CA"],
random_seed=custom_seed,
)

with (
patch(
"sae_bench.evals.sparse_probing_sae_probes.main.run_sae_evals"
) as mock_sae_evals,
patch(
"sae_bench.evals.sparse_probing_sae_probes.main.run_baseline_evals"
) as mock_baseline_evals,
):
results_dict = run_eval(
config,
[("gpt2_l4_sae", gpt2_l4_sae)],
device="cpu",
output_path=str(tmp_path / "test_output"),
)

mock_sae_evals.assert_called_once()
assert mock_sae_evals.call_args.kwargs["seed"] == custom_seed
mock_baseline_evals.assert_called_once()
assert mock_baseline_evals.call_args.kwargs["seed"] == custom_seed

result_data = results_dict["gpt2_l4_sae_custom_sae"]
assert result_data["eval_config"]["random_seed"] == custom_seed
Loading