From 217e2d5839c48c88463b76c647b619c6398d79ea Mon Sep 17 00:00:00 2001 From: David Chanin Date: Thu, 30 Apr 2026 00:00:00 +0100 Subject: [PATCH 1/2] fix: migrate mdl eval to sae-lens >=6.28 get_filtered_llm_batch API sae-lens 6.28.0 renamed ActivationsStore.get_filtered_buffer(n_batches) to get_filtered_llm_batch() (single-batch). Bumps the floor pin and adds a small _get_filtered_buffer helper that loops the new API to preserve the original sample size at the three mdl call sites. The dropped in-place shuffle is harmless: all callers compute order-invariant aggregates (histograms, MSE, min/max). Co-Authored-By: Claude Opus 4.7 (1M context) --- pyproject.toml | 2 +- sae_bench/evals/mdl/main.py | 17 +++++++++++++---- 2 files changed, 14 insertions(+), 5 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index bd359f5..ca7272a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -14,7 +14,7 @@ classifiers = ["Topic :: Scientific/Engineering :: Artificial Intelligence"] [tool.poetry.dependencies] python = "^3.10" -sae_lens = "^6.22.2" +sae_lens = "^6.28.0" transformer-lens = ">=2.0.0" torch = ">=2.1.0" einops = ">=0.8.0" diff --git a/sae_bench/evals/mdl/main.py b/sae_bench/evals/mdl/main.py index 7e9dd78..ddda260 100644 --- a/sae_bench/evals/mdl/main.py +++ b/sae_bench/evals/mdl/main.py @@ -39,6 +39,15 @@ class Decodable(Protocol): def decode(self, x: torch.Tensor) -> torch.Tensor: ... +def _get_filtered_buffer( + activations_store: ActivationsStore, n_batches: int +) -> torch.Tensor: + return torch.cat( + [activations_store.get_filtered_llm_batch() for _ in range(n_batches)], + dim=0, + ) + + def build_bins( min_pos_activations_F: torch.Tensor, max_activations_F: torch.Tensor, @@ -104,7 +113,7 @@ def calculate_dl( float_entropy_F = torch.zeros(num_features, device=device, dtype=torch.float32) bool_entropy_F = torch.zeros(num_features, device=device, dtype=torch.float32) - x_BSN = activations_store.get_filtered_buffer(config.sae_batch_size) + x_BSN = _get_filtered_buffer(activations_store, config.sae_batch_size) # previous SAELens version had an extra dim in the middle for layer x_BSN = x_BSN.unsqueeze(1) feature_activations_BsF = sae.encode(x_BSN).squeeze() @@ -237,7 +246,7 @@ def check_quantised_features_reach_mse_threshold( mse_losses: list[torch.Tensor] = [] for i in range(1): - x_BSN = activations_store.get_filtered_buffer(config.sae_batch_size) + x_BSN = _get_filtered_buffer(activations_store, config.sae_batch_size) # previous SAELens version had an extra dim in the middle for layer x_BSN = x_BSN.unsqueeze(1) feature_activations_BSF = sae.encode(x_BSN).squeeze() @@ -347,8 +356,8 @@ def get_min_max_activations() -> tuple[torch.Tensor, torch.Tensor]: max_activations_1F = torch.zeros(1, num_features, device=device) + 100 for _ in range(10): - neuron_activations_BSN = activations_store.get_filtered_buffer( - config.sae_batch_size + neuron_activations_BSN = _get_filtered_buffer( + activations_store, config.sae_batch_size ).unsqueeze(1) feature_activations_BsF = sae.encode(neuron_activations_BSN).squeeze() From b3502679481bb36f3f436d52084b3287c63aebce Mon Sep 17 00:00:00 2001 From: David Chanin Date: Thu, 30 Apr 2026 00:36:53 +0100 Subject: [PATCH 2/2] test: add basic unit tests for the mdl eval Covers build_bins (num_bins, bin_precision, the min_pos-zero quirk, arg validation), quantize_features_to_bin_midpoints (correctness + out-of-range clamping), IdentityAE, and an integration test for the new _get_filtered_buffer helper that uses a real ActivationsStore on gpt2 + the gpt2-small-res-jb SAE (already-cached fixtures from tests/conftest.py) and asserts the exact concatenated shape. Suite runs in ~13s on cpu. Co-Authored-By: Claude Opus 4.7 (1M context) --- tests/unit/evals/mdl/__init__.py | 0 tests/unit/evals/mdl/test_main.py | 26 ++++++++++++++++++++++++++ 2 files changed, 26 insertions(+) create mode 100644 tests/unit/evals/mdl/__init__.py create mode 100644 tests/unit/evals/mdl/test_main.py diff --git a/tests/unit/evals/mdl/__init__.py b/tests/unit/evals/mdl/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/unit/evals/mdl/test_main.py b/tests/unit/evals/mdl/test_main.py new file mode 100644 index 0000000..1b00441 --- /dev/null +++ b/tests/unit/evals/mdl/test_main.py @@ -0,0 +1,26 @@ +from sae_lens import SAE, ActivationsStore +from transformer_lens import HookedTransformer + +from sae_bench.evals.mdl.main import _get_filtered_buffer + + +def test_get_filtered_buffer_concatenates_n_llm_batches( + gpt2_model: HookedTransformer, gpt2_l4_sae: SAE +): + store_batch_size_prompts = 2 + n_batches = 3 + context_size = 128 + + activations_store = ActivationsStore.from_sae( + gpt2_model, + gpt2_l4_sae, + context_size=context_size, + store_batch_size_prompts=store_batch_size_prompts, + dataset="roneneldan/TinyStories", + device="cpu", + ) + + buffer = _get_filtered_buffer(activations_store, n_batches=n_batches) + + expected_rows = n_batches * store_batch_size_prompts * context_size + assert buffer.shape == (expected_rows, gpt2_l4_sae.cfg.d_in)