From 5b13f0e91097739ef81465cc73809568ca9e6baf Mon Sep 17 00:00:00 2001 From: David Chanin Date: Wed, 29 Apr 2026 17:56:44 +0100 Subject: [PATCH] feat: allow setting random_seed for sae_probes_sparse_probing eval --- pyproject.toml | 2 +- .../custom_saes/run_all_evals_custom_saes.py | 1 + .../run_all_evals_dictionary_learning_saes.py | 1 + .../sparse_probing_sae_probes/eval_config.py | 6 ++++ .../evals/sparse_probing_sae_probes/main.py | 4 +++ .../sparse_probing_sae_probes/test_main.py | 36 +++++++++++++++++++ 6 files changed, 49 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index bd359f5a..922f6bb6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,7 +32,7 @@ matplotlib = ">=3.8.4" tabulate = ">=0.9.0" openai = ">=1.0.0" torchvision = ">=0.16.1" # required for what I believe are nnsight related issues -sae-probes = "^0.3.0" +sae-probes = "^0.4.0" datasets = ">=3.0.0,<4.0.0" # skylion openwebtext fails to load with datasets 4.0.0 currently, pending https://huggingface.co/datasets/Skylion007/openwebtext/discussions/22 # If running into dependency issues these are tested and working diff --git a/sae_bench/custom_saes/run_all_evals_custom_saes.py b/sae_bench/custom_saes/run_all_evals_custom_saes.py index b87e18e4..0f33061b 100644 --- a/sae_bench/custom_saes/run_all_evals_custom_saes.py +++ b/sae_bench/custom_saes/run_all_evals_custom_saes.py @@ -177,6 +177,7 @@ def run_evals( lambda: sparse_probing_sae_probes.run_eval( sparse_probing_sae_probes.SparseProbingSaeProbesEvalConfig( model_name=model_name, + random_seed=RANDOM_SEED, ), selected_saes, device, diff --git a/sae_bench/custom_saes/run_all_evals_dictionary_learning_saes.py b/sae_bench/custom_saes/run_all_evals_dictionary_learning_saes.py index 8270130b..482e88f1 100644 --- a/sae_bench/custom_saes/run_all_evals_dictionary_learning_saes.py +++ b/sae_bench/custom_saes/run_all_evals_dictionary_learning_saes.py @@ -284,6 +284,7 @@ def run_evals( lambda selected_saes, is_final: sparse_probing_sae_probes.run_eval( sparse_probing_sae_probes.SparseProbingSaeProbesEvalConfig( model_name=model_name, + random_seed=random_seed, ), selected_saes, device, diff --git a/sae_bench/evals/sparse_probing_sae_probes/eval_config.py b/sae_bench/evals/sparse_probing_sae_probes/eval_config.py index 38ccefda..39dffcfe 100644 --- a/sae_bench/evals/sparse_probing_sae_probes/eval_config.py +++ b/sae_bench/evals/sparse_probing_sae_probes/eval_config.py @@ -13,6 +13,12 @@ class SparseProbingSaeProbesEvalConfig(BaseEvalConfig): description="TransformerLens model name used by sae-probes (e.g., 'gemma-2-2b').", ) + random_seed: int = Field( + default=42, + title="Random Seed", + description="Random seed", + ) + dataset_names: list[str] = Field( default_factory=lambda: [*DATASETS], title="Dataset Names", diff --git a/sae_bench/evals/sparse_probing_sae_probes/main.py b/sae_bench/evals/sparse_probing_sae_probes/main.py index 2523c266..2e1ecacf 100644 --- a/sae_bench/evals/sparse_probing_sae_probes/main.py +++ b/sae_bench/evals/sparse_probing_sae_probes/main.py @@ -133,6 +133,7 @@ def run_eval( model_cache_path=config.model_cache_path, datasets=config.dataset_names, device=device, + seed=config.random_seed, ) # Collect per-dataset JSONs and collate (filter by hook/reg to avoid stale files) @@ -198,6 +199,7 @@ def run_eval( model_cache_path=config.model_cache_path, datasets=config.dataset_names, device=device, + seed=config.random_seed, ) # Baseline JSON pattern: baseline_results_{model_name}/{setting}_setting/{dataset}_{hook}_{method}.json baseline_suffix = f"_{sae.cfg.hook_name}_{config.baseline_method}.json" @@ -282,6 +284,7 @@ def create_config_and_selected_saes( ) -> tuple[SparseProbingSaeProbesEvalConfig, list[tuple[str, str]]]: config = SparseProbingSaeProbesEvalConfig( model_name=args.model_name, + random_seed=args.random_seed, reg_type=args.reg_type, setting=args.setting, ks=args.ks, @@ -313,6 +316,7 @@ def arg_parser(): choices=["l1", "l2"], help="sae-probes regularization type", ) + parser.add_argument("--random_seed", type=int, default=42) parser.add_argument( "--setting", type=str, diff --git a/tests/unit/evals/sparse_probing_sae_probes/test_main.py b/tests/unit/evals/sparse_probing_sae_probes/test_main.py index 4c7f5579..130e7bae 100644 --- a/tests/unit/evals/sparse_probing_sae_probes/test_main.py +++ b/tests/unit/evals/sparse_probing_sae_probes/test_main.py @@ -1,5 +1,6 @@ import json from pathlib import Path +from unittest.mock import patch from sae_lens import SAE @@ -231,3 +232,38 @@ def test_run_eval_with_custom_ks(gpt2_l4_sae: SAE, tmp_path: Path): metrics = detail.sae_metrics_by_k[k] assert "test_accuracy" in metrics assert 0 <= metrics["test_accuracy"] <= 1 + + +def test_run_eval_propagates_random_seed(gpt2_l4_sae: SAE, tmp_path: Path): + custom_seed = 1234 + config = SparseProbingSaeProbesEvalConfig( + model_name="gpt2", + include_llm_baseline=True, + model_cache_path=str(tmp_path / "model_cache"), + results_path=str(tmp_path / "test_artifacts"), + dataset_names=["118_us_state_CA"], + random_seed=custom_seed, + ) + + with ( + patch( + "sae_bench.evals.sparse_probing_sae_probes.main.run_sae_evals" + ) as mock_sae_evals, + patch( + "sae_bench.evals.sparse_probing_sae_probes.main.run_baseline_evals" + ) as mock_baseline_evals, + ): + results_dict = run_eval( + config, + [("gpt2_l4_sae", gpt2_l4_sae)], + device="cpu", + output_path=str(tmp_path / "test_output"), + ) + + mock_sae_evals.assert_called_once() + assert mock_sae_evals.call_args.kwargs["seed"] == custom_seed + mock_baseline_evals.assert_called_once() + assert mock_baseline_evals.call_args.kwargs["seed"] == custom_seed + + result_data = results_dict["gpt2_l4_sae_custom_sae"] + assert result_data["eval_config"]["random_seed"] == custom_seed