Fix explained_variance computing variance relative to zero instead of mean #665

Open
amaljithkuttamath wants to merge 7 commits into decoderesearch:main from amaljithkuttamath:fix/explained-variance-computation
@amaljithkuttamath

Summary

Fixes #659. Two bugs in get_sparsity_and_variance_metrics caused explained_variance to compute total variance relative to zero instead of relative to the mean:

  1. Line 552: mean_act_per_dimension accumulated .pow(2).mean(dim=0) (i.e. E[x_d^2]) instead of .mean(dim=0) (i.e. E[x_d]). When squared on line 586, this produced E[x_d^2]^2 instead of the correct E[x_d]^2.

  2. Line 585: torch.cat on a list of (d_model,) tensors produced (N_batches * d_model,), then .mean(dim=0) collapsed everything to a scalar, destroying per-dimension structure. Replaced with torch.stack and added .sum() to reduce across dimensions after squaring (matching how mean_sum_of_squares is already a scalar from .sum(dim=-1)).
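The shape problem in item 2 can be illustrated without PyTorch. This is a minimal sketch in plain Python, where nested lists stand in for tensors; the variable names are illustrative and do not come from sae_lens/evals.py.

```python
# Plain-Python stand-in for the tensor shapes involved (no PyTorch needed).
# Each inner list plays the role of one batch's (d_model,) mean vector.
per_batch_means = [
    [1.0, 10.0, 100.0],  # batch 0
    [3.0, 30.0, 300.0],  # batch 1
]

# torch.cat analogue: join the vectors end to end -> length N_batches * d_model.
flattened = [v for batch in per_batch_means for v in batch]
collapsed = sum(flattened) / len(flattened)  # one scalar, dimensions mixed together

# torch.stack analogue: keep an (N_batches, d_model) grid, average over batches only.
n_batches = len(per_batch_means)
per_dim_mean = [sum(col) / n_batches for col in zip(*per_batch_means)]

print(len(flattened))  # 6
print(collapsed)       # 74.0
print(per_dim_mean)    # [2.0, 20.0, 200.0]
```

Averaging the flattened vector mixes values from different dimensions, whereas averaging the stacked grid over the batch axis keeps one mean per dimension, which is what ||E[X]||^2 needs.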

The combined effect: total_variance was approximately E[||X||^2] (variance from zero) instead of E[||X||^2] - ||E[X]||^2 (variance from mean). For activations with large mean components, this inflated explained_variance.
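To make the size of the effect concrete, here is a small dependency-free sketch on made-up high-mean data. The numbers and variable names are assumptions chosen for illustration, and the "variance from zero" denominator stands in for the approximate behaviour described above rather than reproducing the old code exactly.

```python
# Toy activations: 4 tokens, 2 dimensions, with a large mean in each dimension.
x = [[101.0, 49.0], [99.0, 51.0], [100.0, 50.0], [100.0, 50.0]]
n, d = len(x), len(x[0])

# E[||X||^2]: mean over tokens of the squared norm (spread measured from zero).
mean_sq_norm = sum(v * v for row in x for v in row) / n

# ||E[X]||^2: squared norm of the per-dimension mean vector.
mu = [sum(row[j] for row in x) / n for j in range(d)]
sq_norm_of_mean = sum(m * m for m in mu)

# Variance from the mean, via the identity and directly from deviations.
var_identity = mean_sq_norm - sq_norm_of_mean
var_direct = sum((row[j] - mu[j]) ** 2 for row in x for j in range(d)) / n

# Pretend the reconstruction always outputs the mean, so the residual
# variance equals the true variance and explained variance should be 0.
residual_variance = var_direct
ev_from_mean = 1 - residual_variance / var_identity
ev_from_zero = 1 - residual_variance / mean_sq_norm

print(var_identity, var_direct)  # 1.0 1.0
print(ev_from_mean)              # 0.0
print(ev_from_zero > 0.999)      # True
```

A reconstruction that only predicts the mean scores 0 against the mean-centered denominator, but looks nearly perfect when the denominator is measured from zero.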

Changes

  • sae_lens/evals.py: 3 line changes (remove .pow(2), cat -> stack, add .sum())
  • tests/test_evals.py: New test that constructs high-mean data and verifies the variance formula matches torch.var with correction=0. Also verifies the buggy formula produces a materially different (incorrect) result.

Test plan

  • New test test_explained_variance_uses_mean_centered_variance passes
  • Existing test test_get_sparsity_and_variance_metrics_identity_sae_perfect_reconstruction still passes
  • ruff check passes

… mean

Two bugs in the variance computation for explained_variance:

1. mean_act_per_dimension accumulated .pow(2).mean() instead of .mean(),
   computing E[x^2] per dimension instead of E[x] per dimension. This
   made the subtracted term in Var = E[||X||^2] - ||E[X]||^2 incorrect.

2. torch.cat on the per-batch mean vectors flattened them into one long
   vector, destroying per-dimension structure. Replaced with torch.stack
   and added .sum() to reduce across dimensions.

The combined effect was that total_variance was computed relative to zero
(essentially E[||X||^2]) instead of relative to the mean, inflating
explained_variance for activations with large mean components.

Fixes decoderesearch#659

Copilot AI left a comment

Contributor

Pull request overview

Fixes the explained_variance calculation in get_sparsity_and_variance_metrics so it uses mean-centered variance (variance relative to the mean) rather than variance relative to zero, bringing it in line with the intended multidimensional identity Var(X) = E[||X||²] - ||E[X]||².

Changes:

  • Correct mean_act_per_dimension to accumulate E[x_d] (not E[x_d²]).
  • Preserve per-dimension structure via torch.stack(...).mean(dim=0) and compute ||E[X]||² via a dimension-wise sum.
  • Add a targeted unit test validating the corrected variance identity and demonstrating the buggy behavior is materially different.

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 2 comments.

| File | Description |
| --- | --- |
| sae_lens/evals.py | Fixes mean and aggregation logic used to compute total variance for explained_variance. |
| tests/test_evals.py | Adds a regression test ensuring total variance is mean-centered and catches the prior buggy formula. |

Comment thread sae_lens/evals.py Outdated
mean_act_per_dimension = torch.stack(mean_act_per_dimension).mean(dim=0)
total_variance = mean_sum_of_squares - (mean_act_per_dimension**2).sum()
residual_variance = torch.stack(mean_sum_of_resid_squared).mean(dim=0)
metrics["explained_variance"] = (1 - residual_variance / total_variance).item()

Copilot AI Mar 27, 2026

total_variance can legitimately be 0 (e.g., if only one unmasked token remains after ignore_tokens, or if activations are constant). In that case 1 - residual_variance / total_variance will produce inf/NaN. Consider guarding for total_variance <= eps (and possibly small negative values from fp roundoff) similarly to sae_lens/synthetic/evals.py, returning 1.0 when both variances are ~0, else 0.0 (or another defined fallback).

Suggested change
metrics["explained_variance"] = (1 - residual_variance / total_variance).item()
# Guard against zero / near-zero total variance to avoid inf/NaN.
# When both variances are ~0, treat explained variance as 1.0
# (perfect reconstruction of a constant signal); otherwise 0.0.
eps = 1e-12
if torch.abs(total_variance) <= eps:
    if torch.abs(residual_variance) <= eps:
        explained_variance = torch.tensor(1.0, device=total_variance.device)
    else:
        explained_variance = torch.tensor(0.0, device=total_variance.device)
else:
    explained_variance = 1 - residual_variance / total_variance
metrics["explained_variance"] = explained_variance.item()

Comment thread tests/test_evals.py Outdated
Comment on lines +669 to +675
# With the bug (.pow(2) on the mean term), the subtracted term captures E[x^2]^2
# instead of E[x]^2, making total_variance much larger than the true variance
# for data with a large mean.
buggy_mean_act = x.pow(2).mean(dim=0)  # bug: .pow(2) before mean
buggy_total_var = (
    mean_sum_of_squares - (buggy_mean_act**2).sum()
).item()

Copilot AI Mar 27, 2026

The explanatory comment says the buggy formula makes total_variance “much larger than the true variance”, but with the specific buggy_mean_act = x.pow(2).mean(dim=0) used below the resulting buggy_total_var will typically be very negative (because you subtract Σ_d E[x_d^2]^2). Consider rewording to avoid misleading readers (and optionally clarify that the inflated explained_variance in the original bug required both the .pow(2) mistake and the cat->scalar collapse).
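The sign of the buggy quantity can be checked by hand on a tiny example. The sketch below uses plain Python and invented data (two tokens, two dimensions, mean 10 in each dimension); it isolates only the .pow(2) mistake, with per-dimension structure preserved, which is the case this comment describes.

```python
# Two tokens, two dimensions, mean of 10 in each dimension (illustrative data).
x = [[9.0, 11.0], [11.0, 9.0]]
n, d = len(x), len(x[0])

# E[||X||^2], the first term of both the correct and the buggy formula.
mean_sum_of_squares = sum(v * v for row in x for v in row) / n

# Correct subtracted term: sum over dimensions of E[x_d]^2.
mean_act = [sum(row[j] for row in x) / n for j in range(d)]
true_total_var = mean_sum_of_squares - sum(m * m for m in mean_act)

# Buggy subtracted term: sum over dimensions of E[x_d^2]^2.
buggy_mean_act = [sum(row[j] ** 2 for row in x) / n for j in range(d)]
buggy_total_var = mean_sum_of_squares - sum(m * m for m in buggy_mean_act)

print(true_total_var)   # 2.0
print(buggy_total_var)  # -20200.0
```

With this data the buggy value is large and negative rather than larger than the true variance, which matches the concern about the wording of the test comment.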

@chanind

chanind commented Mar 30, 2026

Collaborator

Thank you for this PR. This PR is probably correct, but I just want to make sure before merging to avoid needing to fix this calculation a third time.

@chanind chanind left a comment

Collaborator

The main issue is that the tests do not actually test anything. I'll try working on this.

Comment thread tests/test_evals.py Outdated
assert metrics["mse"] == pytest.approx(0.0, abs=1e-5)


def test_explained_variance_uses_mean_centered_variance():

Collaborator

This test doesn't test any actual code

chanind and others added 2 commits March 31, 2026 15:30
Pulls the explained variance logic into a standalone function called
from get_sparsity_and_variance_metrics, then replaces the tautological
tests with property-based tests that call the real function: single-batch
vs torch.var, batched vs unbatched equivalence, translation invariance,
zero-variance edge cases, and cross-check against ExplainedVarianceCalculator.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Uses a random SAE with apply_b_dec_to_input=True and real model
activations. Shifts both inputs and b_dec by a constant, verifies
explained_variance is unchanged (since the shift cancels in encoding
and both input and output shift equally).

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
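Two of the properties named in these commits, batched versus unbatched equivalence and translation invariance, can be sketched in plain Python. The helper `explained_variance_from_batches` and the `shift` function below are hypothetical stand-ins written for this illustration; they are not the standalone function added to sae_lens, and the averaging of per-batch statistics assumes equal-sized batches.

```python
def explained_variance_from_batches(batches):
    """Hypothetical helper, not the function added in the PR.

    `batches` is a list of (x, x_hat) pairs; each is a list of token rows.
    Averaging per-batch statistics is only exact when batches are equal-sized.
    """
    sum_sq, resid_sq, means = [], [], []
    for x, x_hat in batches:
        n, d = len(x), len(x[0])
        sum_sq.append(sum(v * v for row in x for v in row) / n)
        resid_sq.append(
            sum((a - b) ** 2 for r, rh in zip(x, x_hat) for a, b in zip(r, rh)) / n
        )
        means.append([sum(row[j] for row in x) / n for j in range(d)])
    k = len(batches)
    mean_act = [sum(col) / k for col in zip(*means)]  # per-dimension E[x_d]
    total_variance = sum(sum_sq) / k - sum(m * m for m in mean_act)
    residual_variance = sum(resid_sq) / k
    return 1 - residual_variance / total_variance


def shift(rows, c):
    """Add the constant c to every entry (a translation of the data)."""
    return [[v + c for v in row] for row in rows]


x = [[2.0, 0.0], [0.0, 4.0], [4.0, 2.0], [2.0, 6.0]]
x_hat = [[2.0, 1.0], [0.0, 3.0], [4.0, 3.0], [2.0, 5.0]]

ev_single = explained_variance_from_batches([(x, x_hat)])
ev_batched = explained_variance_from_batches(
    [(x[:2], x_hat[:2]), (x[2:], x_hat[2:])]
)
ev_shifted = explained_variance_from_batches(
    [(shift(x, 100.0), shift(x_hat, 100.0))]
)
```

All three values should agree (here 1 - 1/7): splitting the data into batches does not change the result, and shifting inputs and reconstructions by the same constant leaves both the residual and the mean-centered total variance unchanged.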
@chanind chanind force-pushed the fix/explained-variance-computation branch from 56391d5 to accaaf4 Compare March 31, 2026 14:55
@robbiebusinessacc

Contributor

Came across this digging through the eval code, and since the thread mentions
wanting extra confidence before merging a third time, here's an independent
hand-computed check of #665's fix that doesn't lean on torch.var as the
oracle:

Two tokens, two dims:

x = [[2, 0], [0, 4]] -> mu = [1, 2]
x_hat = [[2, 1], [0, 3]]
total_variance = E[Σ_d (x_d - mu_d)²] = ((1+4) + (1+4)) / 2 = 5
residual_variance = E[Σ_d (x_d - x̂_d)²] = (1 + 1) / 2 = 1
explained_variance = 1 - 1/5 = 0.8

I ran #665's code on this and it returns 0.8; the pre-fix code does not. Two
cheap canaries that would also lock it down going forward, if useful as extra
test cases:

  • Predict-the-mean (x_hat = x.mean(0)): EV must be ~0. The old code
    returns a value > 1, which is impossible — a clean regression canary.
  • Perfect reconstruction (x_hat = x): EV must be exactly 1.
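The hand computation and both canaries above can be reproduced with a few lines of plain Python. This is a sketch of the mean-centered definition only; the `explained_variance` function here is a hypothetical reference written for this check, not the code in #665.

```python
def explained_variance(x, x_hat):
    """Reference definition: 1 - E[||x - x_hat||^2] / E[||x - mu||^2]."""
    n, d = len(x), len(x[0])
    mu = [sum(row[j] for row in x) / n for j in range(d)]
    total = sum((row[j] - mu[j]) ** 2 for row in x for j in range(d)) / n
    resid = sum((a - b) ** 2 for r, rh in zip(x, x_hat) for a, b in zip(r, rh)) / n
    return 1 - resid / total


x = [[2.0, 0.0], [0.0, 4.0]]
x_hat = [[2.0, 1.0], [0.0, 3.0]]
mu = [1.0, 2.0]

ev = explained_variance(x, x_hat)          # hand-computed case, about 0.8
ev_mean = explained_variance(x, [mu, mu])  # predict-the-mean canary
ev_perfect = explained_variance(x, x)      # perfect-reconstruction canary

print(ev_mean)     # 0.0
print(ev_perfect)  # 1.0
```

Because this version computes deviations from the mean directly instead of using the E[||X||^2] - ||E[X]||^2 identity, it serves as an independent oracle for whichever implementation is under test.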

Development

Successfully merging this pull request may close these issues.

Bug in explained_variance computation from PR #443

4 participants