From 90bd9e411c890cea47c4608be2103410b33196ee Mon Sep 17 00:00:00 2001 From: Amaljith Kuttamath Date: Thu, 26 Mar 2026 18:51:51 -0400 Subject: [PATCH 1/7] Fix explained_variance computing variance relative to zero instead of mean Two bugs in the variance computation for explained_variance: 1. mean_act_per_dimension accumulated .pow(2).mean() instead of .mean(), computing E[x^2] per dimension instead of E[x] per dimension. This made the subtracted term in Var = E[||X||^2] - ||E[X]||^2 incorrect. 2. torch.cat on the per-batch mean vectors flattened them into one long vector, destroying per-dimension structure. Replaced with torch.stack and added .sum() to reduce across dimensions. The combined effect was that total_variance was computed relative to zero (essentially E[||X||^2]) instead of relative to the mean, inflating explained_variance for activations with large mean components. Fixes #659 --- sae_lens/evals.py | 6 +++--- tests/test_evals.py | 32 ++++++++++++++++++++++++++++++++ 2 files changed, 35 insertions(+), 3 deletions(-) diff --git a/sae_lens/evals.py b/sae_lens/evals.py index f1c20d82f..81bf05c3b 100644 --- a/sae_lens/evals.py +++ b/sae_lens/evals.py @@ -549,7 +549,7 @@ def get_sparsity_and_variance_metrics( (flattened_sae_input).pow(2).sum(dim=-1).mean(dim=0) # scalar ) mean_act_per_dimension.append( - (flattened_sae_input).pow(2).mean(dim=0) # [d_model] + (flattened_sae_input).mean(dim=0) # [d_model] ) mean_sum_of_resid_squared.append( resid_sum_of_squares.mean(dim=0) # scalar @@ -582,8 +582,8 @@ def get_sparsity_and_variance_metrics( # calculate explained variance if compute_variance_metrics: mean_sum_of_squares = torch.stack(mean_sum_of_squares).mean(dim=0) - mean_act_per_dimension = torch.cat(mean_act_per_dimension).mean(dim=0) - total_variance = mean_sum_of_squares - mean_act_per_dimension**2 + mean_act_per_dimension = torch.stack(mean_act_per_dimension).mean(dim=0) + total_variance = mean_sum_of_squares - (mean_act_per_dimension**2).sum() residual_variance = torch.stack(mean_sum_of_resid_squared).mean(dim=0) metrics["explained_variance"] = (1 - residual_variance / total_variance).item() diff --git a/tests/test_evals.py b/tests/test_evals.py index 8a285aa88..5090b0356 100644 --- a/tests/test_evals.py +++ b/tests/test_evals.py @@ -644,6 +644,38 @@ def test_get_sparsity_and_variance_metrics_identity_sae_perfect_reconstruction( assert metrics["mse"] == pytest.approx(0.0, abs=1e-5) +def test_explained_variance_uses_mean_centered_variance(): + """Verify explained_variance computes Var(X) = E[||X||^2] - ||E[X]||^2, not E[||X||^2].""" + # Construct inputs with a large mean so the difference between + # variance-from-zero and variance-from-mean is significant. + d_model = 8 + n_samples = 1000 + mean = torch.full((d_model,), 10.0) + x = mean + torch.randn(n_samples, d_model) * 0.5 + + # Ground truth total variance: sum of per-dimension variances + expected_total_var = x.var(dim=0, correction=0).sum().item() + + # The formula used in evals.py after the fix: + # total_variance = E[||X||^2] - ||E[X]||^2 + mean_sum_of_squares = x.pow(2).sum(dim=-1).mean(dim=0) + mean_act_per_dimension = x.mean(dim=0) + computed_total_var = ( + mean_sum_of_squares - (mean_act_per_dimension**2).sum() + ).item() + + assert computed_total_var == pytest.approx(expected_total_var, rel=1e-3) + + # With the bug (.pow(2) on the mean term), the subtracted term captures E[x^2]^2 + # instead of E[x]^2, making total_variance much larger than the true variance + # for data with a large mean. + buggy_mean_act = x.pow(2).mean(dim=0) # bug: .pow(2) before mean + buggy_total_var = ( + mean_sum_of_squares - (buggy_mean_act**2).sum() + ).item() + assert abs(buggy_total_var - expected_total_var) > expected_total_var * 0.5 + + def test_process_args(): args = [ "gpt2-small-res_scefr-ajt", From e66355fe1c9dfbfab09e63cf84654bc93189e078 Mon Sep 17 00:00:00 2001 From: Amal <61614061+amaljithkuttamath@users.noreply.github.com> Date: Fri, 27 Mar 2026 22:29:49 -0400 Subject: [PATCH 2/7] Fix misleading comment: buggy formula yields very negative variance, not larger --- tests/test_evals.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tests/test_evals.py b/tests/test_evals.py index 5090b0356..7d50e43b2 100644 --- a/tests/test_evals.py +++ b/tests/test_evals.py @@ -666,9 +666,10 @@ def test_explained_variance_uses_mean_centered_variance(): assert computed_total_var == pytest.approx(expected_total_var, rel=1e-3) - # With the bug (.pow(2) on the mean term), the subtracted term captures E[x^2]^2 - # instead of E[x]^2, making total_variance much larger than the true variance - # for data with a large mean. + # With the bug (.pow(2) on the mean term), the subtracted term becomes + # sum(E[x_d^2]^2) instead of sum(E[x_d]^2). For large-mean data this + # makes buggy_total_var very negative (or wildly wrong in general), + # which distorts the explained_variance ratio. buggy_mean_act = x.pow(2).mean(dim=0) # bug: .pow(2) before mean buggy_total_var = ( mean_sum_of_squares - (buggy_mean_act**2).sum() From 56066d86799da4fdd57056401b56d53b8038b804 Mon Sep 17 00:00:00 2001 From: Amal <61614061+amaljithkuttamath@users.noreply.github.com> Date: Sat, 28 Mar 2026 15:04:05 -0400 Subject: [PATCH 3/7] Guard against zero total variance and add edge-case test --- sae_lens/evals.py | 10 +++++++++- tests/test_evals.py | 30 ++++++++++++++++++++++++++++++ 2 files changed, 39 insertions(+), 1 deletion(-) diff --git a/sae_lens/evals.py b/sae_lens/evals.py index 81bf05c3b..94ec94119 100644 --- a/sae_lens/evals.py +++ b/sae_lens/evals.py @@ -585,7 +585,15 @@ def get_sparsity_and_variance_metrics( mean_act_per_dimension = torch.stack(mean_act_per_dimension).mean(dim=0) total_variance = mean_sum_of_squares - (mean_act_per_dimension**2).sum() residual_variance = torch.stack(mean_sum_of_resid_squared).mean(dim=0) - metrics["explained_variance"] = (1 - residual_variance / total_variance).item() + eps = 1e-12 + if torch.abs(total_variance) <= eps: + if torch.abs(residual_variance) <= eps: + explained_variance = torch.tensor(1.0, device=total_variance.device) + else: + explained_variance = torch.tensor(0.0, device=total_variance.device) + else: + explained_variance = 1 - residual_variance / total_variance + metrics["explained_variance"] = explained_variance.item() # Aggregate feature-wise metrics feature_metrics: dict[str, list[float]] = {} diff --git a/tests/test_evals.py b/tests/test_evals.py index 7d50e43b2..27885bd30 100644 --- a/tests/test_evals.py +++ b/tests/test_evals.py @@ -677,6 +677,36 @@ def test_explained_variance_uses_mean_centered_variance(): assert abs(buggy_total_var - expected_total_var) > expected_total_var * 0.5 +def test_explained_variance_zero_total_variance(): + """When activations are constant (zero variance), explained_variance should be 1.0 + for perfect reconstruction and 0.0 when residual is nonzero.""" + d_model = 4 + + # Case 1: constant activations, perfect reconstruction -> 1.0 + total_var = torch.tensor(0.0) + residual_var = torch.tensor(0.0) + eps = 1e-12 + if torch.abs(total_var) <= eps: + if torch.abs(residual_var) <= eps: + ev = 1.0 + else: + ev = 0.0 + else: + ev = (1 - residual_var / total_var).item() + assert ev == 1.0 + + # Case 2: constant activations, nonzero residual -> 0.0 + residual_var = torch.tensor(0.5) + if torch.abs(total_var) <= eps: + if torch.abs(residual_var) <= eps: + ev = 1.0 + else: + ev = 0.0 + else: + ev = (1 - residual_var / total_var).item() + assert ev == 0.0 + + def test_process_args(): args = [ "gpt2-small-res_scefr-ajt", From 032ce8784a68c8e0ef315b3563553f891b18b18e Mon Sep 17 00:00:00 2001 From: Amal <61614061+amaljithkuttamath@users.noreply.github.com> Date: Sat, 28 Mar 2026 15:11:26 -0400 Subject: [PATCH 4/7] Fix ruff linting: remove unused var, use ternary --- tests/test_evals.py | 19 +++---------------- 1 file changed, 3 insertions(+), 16 deletions(-) diff --git a/tests/test_evals.py b/tests/test_evals.py index 27885bd30..a681736c3 100644 --- a/tests/test_evals.py +++ b/tests/test_evals.py @@ -680,30 +680,17 @@ def test_explained_variance_uses_mean_centered_variance(): def test_explained_variance_zero_total_variance(): """When activations are constant (zero variance), explained_variance should be 1.0 for perfect reconstruction and 0.0 when residual is nonzero.""" - d_model = 4 + eps = 1e-12 # Case 1: constant activations, perfect reconstruction -> 1.0 total_var = torch.tensor(0.0) residual_var = torch.tensor(0.0) - eps = 1e-12 - if torch.abs(total_var) <= eps: - if torch.abs(residual_var) <= eps: - ev = 1.0 - else: - ev = 0.0 - else: - ev = (1 - residual_var / total_var).item() + ev = 1.0 if torch.abs(residual_var) <= eps else 0.0 assert ev == 1.0 # Case 2: constant activations, nonzero residual -> 0.0 residual_var = torch.tensor(0.5) - if torch.abs(total_var) <= eps: - if torch.abs(residual_var) <= eps: - ev = 1.0 - else: - ev = 0.0 - else: - ev = (1 - residual_var / total_var).item() + ev = 1.0 if torch.abs(residual_var) <= eps else 0.0 assert ev == 0.0 From 47982edde9b86134a0de07687d736a61d7f4ed63 Mon Sep 17 00:00:00 2001 From: David Chanin Date: Tue, 31 Mar 2026 15:01:31 +0100 Subject: [PATCH 5/7] fixing linting / type checking --- tests/test_evals.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/tests/test_evals.py b/tests/test_evals.py index a681736c3..7bb5eb6c9 100644 --- a/tests/test_evals.py +++ b/tests/test_evals.py @@ -671,9 +671,7 @@ def test_explained_variance_uses_mean_centered_variance(): # makes buggy_total_var very negative (or wildly wrong in general), # which distorts the explained_variance ratio. buggy_mean_act = x.pow(2).mean(dim=0) # bug: .pow(2) before mean - buggy_total_var = ( - mean_sum_of_squares - (buggy_mean_act**2).sum() - ).item() + buggy_total_var = (mean_sum_of_squares - (buggy_mean_act**2).sum()).item() assert abs(buggy_total_var - expected_total_var) > expected_total_var * 0.5 @@ -683,7 +681,6 @@ def test_explained_variance_zero_total_variance(): eps = 1e-12 # Case 1: constant activations, perfect reconstruction -> 1.0 - total_var = torch.tensor(0.0) residual_var = torch.tensor(0.0) ev = 1.0 if torch.abs(residual_var) <= eps else 0.0 assert ev == 1.0 From c4fe5169eb2800538427b7a2bbc39947d471c1db Mon Sep 17 00:00:00 2001 From: David Chanin Date: Tue, 31 Mar 2026 15:30:11 +0100 Subject: [PATCH 6/7] Extract compute_explained_variance and rewrite tests to exercise it Pulls the explained variance logic into a standalone function called from get_sparsity_and_variance_metrics, then replaces the tautological tests with property-based tests that call the real function: single-batch vs torch.var, batched vs unbatched equivalence, translation invariance, zero-variance edge cases, and cross-check against ExplainedVarianceCalculator. Co-Authored-By: Claude Opus 4.6 (1M context) --- sae_lens/evals.py | 65 +++++++++++++++------------ tests/test_evals.py | 107 +++++++++++++++++++++++++++++--------------- 2 files changed, 107 insertions(+), 65 deletions(-) diff --git a/sae_lens/evals.py b/sae_lens/evals.py index 94ec94119..96aa20a45 100644 --- a/sae_lens/evals.py +++ b/sae_lens/evals.py @@ -381,6 +381,36 @@ def get_downstream_reconstruction_metrics( return metrics +def compute_explained_variance( + input_batches: list[torch.Tensor], + output_batches: list[torch.Tensor], +) -> float: + """Compute explained variance from lists of input/output tensor pairs. + + Each tensor has shape (n_tokens, d_model). Total variance is computed as + Var(X) = E[||X||^2] - ||E[X]||^2, summed across dimensions. + """ + mean_sum_of_squares: list[torch.Tensor] = [] + mean_act_per_dimension: list[torch.Tensor] = [] + mean_sum_of_resid_squared: list[torch.Tensor] = [] + + for sae_input, sae_output in zip(input_batches, output_batches): + mean_sum_of_squares.append(sae_input.pow(2).sum(dim=-1).mean(dim=0)) + mean_act_per_dimension.append(sae_input.mean(dim=0)) + resid_ss = (sae_input - sae_output).pow(2).sum(dim=-1) + mean_sum_of_resid_squared.append(resid_ss.mean(dim=0)) + + total_mean_ss = torch.stack(mean_sum_of_squares).mean(dim=0) + total_mean_act = torch.stack(mean_act_per_dimension).mean(dim=0) + total_variance = total_mean_ss - (total_mean_act**2).sum() + residual_variance = torch.stack(mean_sum_of_resid_squared).mean(dim=0) + + eps = 1e-12 + if torch.abs(total_variance) <= eps: + return 1.0 if torch.abs(residual_variance) <= eps else 0.0 + return (1 - residual_variance / total_variance).item() + + def get_sparsity_and_variance_metrics( sae: SAE[Any], model: HookedRootModule, @@ -411,9 +441,8 @@ def get_sparsity_and_variance_metrics( metric_dict["l0"] = [] metric_dict["l1"] = [] - mean_sum_of_squares = [] # for explained variance - mean_act_per_dimension = [] # for explained variance - mean_sum_of_resid_squared = [] # for explained variance + variance_inputs: list[torch.Tensor] = [] + variance_outputs: list[torch.Tensor] = [] if compute_variance_metrics: # explained_variance is left out of the dict here, since we don't want to naively # average over the batch dimension. This is handled later in the function. @@ -542,18 +571,8 @@ def get_sparsity_and_variance_metrics( ) explained_variance_legacy = 1 - resid_sum_of_squares / batched_variance_sum metric_dict["explained_variance_legacy"].append(explained_variance_legacy) - # Individual sums for the new (correct) formula. We're taking the mean over the batch - # dimension here to save memory, but we could also pass the full tensors and take the - # mean later (like we do for other metrics). - mean_sum_of_squares.append( - (flattened_sae_input).pow(2).sum(dim=-1).mean(dim=0) # scalar - ) - mean_act_per_dimension.append( - (flattened_sae_input).mean(dim=0) # [d_model] - ) - mean_sum_of_resid_squared.append( - resid_sum_of_squares.mean(dim=0) # scalar - ) + variance_inputs.append(flattened_sae_input) + variance_outputs.append(flattened_sae_out) x_normed = flattened_sae_input / torch.norm( flattened_sae_input, dim=-1, keepdim=True @@ -581,19 +600,9 @@ def get_sparsity_and_variance_metrics( # calculate explained variance if compute_variance_metrics: - mean_sum_of_squares = torch.stack(mean_sum_of_squares).mean(dim=0) - mean_act_per_dimension = torch.stack(mean_act_per_dimension).mean(dim=0) - total_variance = mean_sum_of_squares - (mean_act_per_dimension**2).sum() - residual_variance = torch.stack(mean_sum_of_resid_squared).mean(dim=0) - eps = 1e-12 - if torch.abs(total_variance) <= eps: - if torch.abs(residual_variance) <= eps: - explained_variance = torch.tensor(1.0, device=total_variance.device) - else: - explained_variance = torch.tensor(0.0, device=total_variance.device) - else: - explained_variance = 1 - residual_variance / total_variance - metrics["explained_variance"] = explained_variance.item() + metrics["explained_variance"] = compute_explained_variance( + variance_inputs, variance_outputs + ) # Aggregate feature-wise metrics feature_metrics: dict[str, list[float]] = {} diff --git a/tests/test_evals.py b/tests/test_evals.py index 7bb5eb6c9..a7736cd52 100644 --- a/tests/test_evals.py +++ b/tests/test_evals.py @@ -15,6 +15,7 @@ EvalConfig, _kl, all_loadable_saes, + compute_explained_variance, get_downstream_reconstruction_metrics, get_eval_everything_config, get_saes_from_regex, @@ -33,6 +34,7 @@ from sae_lens.saes.sae import SAE, TrainingSAE from sae_lens.saes.standard_sae import StandardSAE, StandardTrainingSAE from sae_lens.saes.topk_sae import TopKTrainingSAE +from sae_lens.synthetic.evals import ExplainedVarianceCalculator from sae_lens.training.activation_scaler import ActivationScaler from sae_lens.training.activations_store import ActivationsStore from tests.helpers import ( @@ -644,51 +646,82 @@ def test_get_sparsity_and_variance_metrics_identity_sae_perfect_reconstruction( assert metrics["mse"] == pytest.approx(0.0, abs=1e-5) -def test_explained_variance_uses_mean_centered_variance(): - """Verify explained_variance computes Var(X) = E[||X||^2] - ||E[X]||^2, not E[||X||^2].""" - # Construct inputs with a large mean so the difference between - # variance-from-zero and variance-from-mean is significant. +def test_explained_variance_single_batch_matches_formula(): d_model = 8 - n_samples = 1000 - mean = torch.full((d_model,), 10.0) - x = mean + torch.randn(n_samples, d_model) * 0.5 + n_samples = 10000 + x = torch.randn(n_samples, d_model) + 5.0 + x_hat = x + torch.randn(n_samples, d_model) * 0.3 - # Ground truth total variance: sum of per-dimension variances - expected_total_var = x.var(dim=0, correction=0).sum().item() + ev = compute_explained_variance([x], [x_hat]) - # The formula used in evals.py after the fix: - # total_variance = E[||X||^2] - ||E[X]||^2 - mean_sum_of_squares = x.pow(2).sum(dim=-1).mean(dim=0) - mean_act_per_dimension = x.mean(dim=0) - computed_total_var = ( - mean_sum_of_squares - (mean_act_per_dimension**2).sum() - ).item() + total_var = x.var(dim=0, correction=0).sum() + residual_var = (x - x_hat).pow(2).sum(dim=-1).mean() + expected_ev = (1 - residual_var / total_var).item() - assert computed_total_var == pytest.approx(expected_total_var, rel=1e-3) + assert ev == pytest.approx(expected_ev, rel=1e-5) - # With the bug (.pow(2) on the mean term), the subtracted term becomes - # sum(E[x_d^2]^2) instead of sum(E[x_d]^2). For large-mean data this - # makes buggy_total_var very negative (or wildly wrong in general), - # which distorts the explained_variance ratio. - buggy_mean_act = x.pow(2).mean(dim=0) # bug: .pow(2) before mean - buggy_total_var = (mean_sum_of_squares - (buggy_mean_act**2).sum()).item() - assert abs(buggy_total_var - expected_total_var) > expected_total_var * 0.5 + +def test_explained_variance_batched_matches_unbatched(): + d_model = 8 + n_samples = 3000 + x = torch.randn(n_samples, d_model) + 5.0 + x_hat = x + torch.randn(n_samples, d_model) * 0.3 + + ev_single = compute_explained_variance([x], [x_hat]) + + x_batches = list(x.chunk(3)) + x_hat_batches = list(x_hat.chunk(3)) + ev_batched = compute_explained_variance(x_batches, x_hat_batches) + + assert ev_batched == pytest.approx(ev_single, rel=1e-5) + + +def test_explained_variance_invariant_to_input_bias(): + # Use float64 to avoid catastrophic cancellation with large biases + d_model = 8 + n_samples = 10000 + x = torch.randn(n_samples, d_model, dtype=torch.float64) + x_hat = x + torch.randn(n_samples, d_model, dtype=torch.float64) * 0.3 + + ev_original = compute_explained_variance([x], [x_hat]) + + bias = torch.full((d_model,), 100.0, dtype=torch.float64) + ev_biased = compute_explained_variance([x + bias], [x_hat + bias]) + + assert ev_biased == pytest.approx(ev_original, rel=1e-10) def test_explained_variance_zero_total_variance(): - """When activations are constant (zero variance), explained_variance should be 1.0 - for perfect reconstruction and 0.0 when residual is nonzero.""" - eps = 1e-12 - - # Case 1: constant activations, perfect reconstruction -> 1.0 - residual_var = torch.tensor(0.0) - ev = 1.0 if torch.abs(residual_var) <= eps else 0.0 - assert ev == 1.0 - - # Case 2: constant activations, nonzero residual -> 0.0 - residual_var = torch.tensor(0.5) - ev = 1.0 if torch.abs(residual_var) <= eps else 0.0 - assert ev == 0.0 + d_model = 4 + n_samples = 100 + x = torch.full((n_samples, d_model), 5.0) + + # Constant input, perfect reconstruction -> 1.0 + assert compute_explained_variance([x], [x]) == 1.0 + + # Constant input, nonzero residual -> 0.0 + x_hat = x + 0.5 + assert compute_explained_variance([x], [x_hat]) == 0.0 + + +def test_explained_variance_matches_synthetic_calculator(): + d_model = 8 + n_samples = 3000 + batch_size = 1000 + x = torch.randn(n_samples, d_model) + 5.0 + x_hat = x + torch.randn(n_samples, d_model) * 0.3 + + x_batches = list(x.split(batch_size)) + x_hat_batches = list(x_hat.split(batch_size)) + + ev_evals = compute_explained_variance(x_batches, x_hat_batches) + + calc = ExplainedVarianceCalculator(hidden_dim=d_model) + for inp, out in zip(x_batches, x_hat_batches): + calc.add_batch(sae_output=out, hidden_acts=inp) + ev_synthetic = calc.compute() + + assert ev_evals == pytest.approx(ev_synthetic, rel=1e-5) def test_process_args(): From accaaf436af3e21142d7f25996ab1a1350e5a48d Mon Sep 17 00:00:00 2001 From: David Chanin Date: Tue, 31 Mar 2026 15:51:10 +0100 Subject: [PATCH 7/7] Add integration test: explained variance invariant to activation shift Uses a random SAE with apply_b_dec_to_input=True and real model activations. Shifts both inputs and b_dec by a constant, verifies explained_variance is unchanged (since the shift cancels in encoding and both input and output shift equally). Co-Authored-By: Claude Opus 4.6 (1M context) --- tests/test_evals.py | 57 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 57 insertions(+) diff --git a/tests/test_evals.py b/tests/test_evals.py index a7736cd52..61b90795e 100644 --- a/tests/test_evals.py +++ b/tests/test_evals.py @@ -646,6 +646,63 @@ def test_get_sparsity_and_variance_metrics_identity_sae_perfect_reconstruction( assert metrics["mse"] == pytest.approx(0.0, abs=1e-5) +def test_explained_variance_invariant_to_activation_shift( + model: HookedTransformer, + example_dataset: Dataset, +): + d_in = 64 + hook_name = "blocks.1.hook_resid_pre" + cfg = build_runner_cfg(d_in=d_in, d_sae=2 * d_in, hook_name=hook_name) + + training_sae = StandardTrainingSAE.from_dict(cfg.get_training_sae_cfg_dict()) + sae = StandardSAE.from_dict(training_sae.cfg.get_inference_sae_cfg_dict()) + random_params(sae) + # The shift invariance requires b_dec to be subtracted from the input during + # encoding so that the shift cancels: encode(x+c) with b_dec+c == encode(x) with b_dec. + sae.cfg.apply_b_dec_to_input = True + + eval_kwargs: dict[str, Any] = dict( + sae=sae, + model=model, + activation_scaler=ActivationScaler(None), + n_batches=3, + compute_l2_norms=False, + compute_sparsity_metrics=False, + compute_variance_metrics=True, + compute_featurewise_density_statistics=True, + eval_batch_size_prompts=4, + model_kwargs={}, + ) + + activation_store = ActivationsStore.from_config( + model, cfg, override_dataset=example_dataset + ) + metrics_original, _ = get_sparsity_and_variance_metrics( + activation_store=activation_store, **eval_kwargs + ) + + # Shifting activations by a constant and adjusting b_dec by the same amount + # should not change explained_variance: the shift cancels in encoding + # (so features are identical) and both input and output shift equally. + shift = torch.full((d_in,), 5.0) + sae.b_dec.data += shift + model.add_hook(hook_name, lambda tensor, **_: tensor + shift, is_permanent=True) + try: + activation_store_shifted = ActivationsStore.from_config( + model, cfg, override_dataset=example_dataset + ) + metrics_shifted, _ = get_sparsity_and_variance_metrics( + activation_store=activation_store_shifted, **eval_kwargs + ) + finally: + model.reset_hooks(including_permanent=True) + + # Tolerance is limited by float32 catastrophic cancellation in E[||x||^2] - ||E[x]||^2 + assert metrics_shifted["explained_variance"] == pytest.approx( + metrics_original["explained_variance"], rel=1e-3 + ) + + def test_explained_variance_single_batch_matches_formula(): d_model = 8 n_samples = 10000