
Bug in explained_variance computation from PR #443 #659

@chanind


Note

Issue debugged, explained, and opened by Claude Opus 4.6

When I trained an SAE on layer 20 of Qwen 2.5 7B, the explained variance was reported as 0.98. That seemed implausibly high, and the legacy explained variance showed a more reasonable ~0.78. In the past I've found these two metrics to be fairly close, so I asked Claude to check whether I had done something wrong in training. Instead, it found this bug in the explained variance calculation.

Summary

The "new" explained_variance metric introduced in #443 has two bugs that, together, cause it to compute variance relative to zero rather than relative to the mean. As a result, the metric is substantially inflated compared to the legacy version (e.g., 0.98 vs. 0.78).

Intended formula

Per the PR description, the intent was the multi-dimensional version of Var(X) = E[X²] - E[X]²:

total_variance = Σ_d (E[x_d²] - E[x_d]²)
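This identity can be checked numerically. Below is a minimal pure-Python sketch in which plain lists of vectors stand in for a [n_tokens, d_model] activation tensor. The helper names are hypothetical and are not taken from evals.py. Both functions should agree. The toy data's large offset in dimension 0 also shows how E[||x||²] can dwarf the actual variance.

```python
def mean_vector(xs):
    """Per-dimension mean E[x_d] over a list of equal-length vectors."""
    n = len(xs)
    return [sum(col) / n for col in zip(*xs)]


def total_variance_moments(xs):
    """Sum over d of (E[x_d^2] - E[x_d]^2), i.e. E[||x||^2] - sum_d E[x_d]^2."""
    n = len(xs)
    mean_sum_of_squares = sum(sum(v * v for v in x) for x in xs) / n  # E[||x||^2]
    mu = mean_vector(xs)
    return mean_sum_of_squares - sum(m * m for m in mu)


def total_variance_direct(xs):
    """E[||x - E[x]||^2]: mean squared distance from the mean vector."""
    mu = mean_vector(xs)
    return sum(sum((v - m) ** 2 for v, m in zip(x, mu)) for x in xs) / len(xs)


# Toy activations: 4 tokens, d_model = 2, with a large offset in dimension 0.
acts = [[10.0, 1.0], [12.0, -1.0], [10.0, 3.0], [12.0, 1.0]]
print(total_variance_moments(acts))  # 3.0
print(total_variance_direct(acts))   # 3.0
print(sum(sum(v * v for v in x) for x in acts) / len(acts))  # E[||x||^2] = 125.0
```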

Bug 1: .pow(2) computes E[x_d²] instead of E[x_d]

In evals.py line 552:

mean_act_per_dimension.append(
    (flattened_sae_input).pow(2).mean(dim=0)  # [d_model]
)

This computes E[x_d²] per dimension. But the variance formula E[X²] - E[X]² needs E[x_d] (the mean), not E[x_d²] (the mean of squares). The mean of squares is already captured by mean_sum_of_squares. When this value is squared on line 586, it produces E[x_d²]², not E[x_d]².

Fix: Remove .pow(2):

mean_act_per_dimension.append(
    (flattened_sae_input).mean(dim=0)  # E[x_d] per dimension, shape [d_model]
)
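Bug 1 can be seen in isolation on a single dimension. This pure-Python sketch uses hypothetical helpers rather than the evals.py code. It compares subtracting E[x]², as the formula intends, with subtracting E[x²]², which is what the .pow(2) version leads to once the collected value is squared later.

```python
def per_dim_mean(batch):
    """E[x_d] per dimension: what the variance formula needs."""
    n = len(batch)
    return [sum(col) / n for col in zip(*batch)]


def per_dim_mean_of_squares(batch):
    """E[x_d^2] per dimension: what .pow(2).mean(dim=0) collects."""
    n = len(batch)
    return [sum(v * v for v in col) / n for col in zip(*batch)]


batch = [[1.0], [3.0]]  # two tokens, d_model = 1
mean_sq = per_dim_mean_of_squares(batch)[0]  # E[x^2] = 5.0
mean = per_dim_mean(batch)[0]                # E[x] = 2.0

correct = mean_sq - mean ** 2   # E[x^2] - E[x]^2 = 1.0, the true variance
buggy = mean_sq - mean_sq ** 2  # E[x^2] - E[x^2]^2 = -20.0
print(correct, buggy)  # 1.0 -20.0
```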

Bug 2: torch.cat collapses dimensions instead of preserving them

In evals.py lines 585-586:

mean_act_per_dimension = torch.cat(mean_act_per_dimension).mean(dim=0)
total_variance = mean_sum_of_squares - mean_act_per_dimension**2

torch.cat on a list of [d_model] tensors produces a single tensor of shape [N_batches * d_model]. The subsequent .mean(dim=0) then collapses it to one scalar averaged over both batches and dimensions. This loses the per-dimension structure needed to compute Σ_d E[x_d]².
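For reference, in PyTorch, torch.cat of N tensors of shape [d_model] gives shape [N * d_model], while torch.stack gives [N, d_model]. The sketch below mimics both with plain Python lists to show what each aggregation keeps. The values are made up for illustration. It also assumes equal-sized batches, so that averaging per-batch means gives the overall mean.

```python
# Per-batch mean vectors (hypothetical values): 2 batches, d_model = 3.
per_batch = [[1.0, 2.0, 3.0], [3.0, 4.0, 5.0]]

# cat-like: flatten to length N_batches * d_model, then average everything.
flat = [v for vec in per_batch for v in vec]
cat_then_mean = sum(flat) / len(flat)  # one scalar: dimensions are mixed together

# stack-like: keep the [N_batches, d_model] layout and average over batches only.
stack_then_mean = [sum(col) / len(per_batch) for col in zip(*per_batch)]

print(len(flat), cat_then_mean)             # 6 3.0
print(stack_then_mean)                      # [2.0, 3.0, 4.0]
print(sum(m * m for m in stack_then_mean))  # sum_d E[x_d]^2 = 29.0
print(cat_then_mean ** 2)                   # 9.0, a different quantity
```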

Fix: Use torch.stack and sum over the model dimension:

mean_act_per_dimension = torch.stack(mean_act_per_dimension).mean(dim=0)  # [d_model]
total_variance = mean_sum_of_squares - (mean_act_per_dimension ** 2).sum()  # scalar
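With both fixes applied, a minimal pure-Python sketch of the batched accumulation might look like the following. It rests on two assumptions: every batch has the same number of tokens, and mean_sum_of_squares is the per-batch E[||x||²] averaged over batches. The function name is hypothetical, and the real evals.py loop may differ. On a small toy dataset split into two batches, it should give the same total variance as a single-pass computation over all tokens.

```python
def total_variance_batched(batches):
    """Corrected aggregation over batches; assumes equal-sized batches."""
    mean_sum_of_squares = []     # per-batch E[||x||^2] (assumed accumulation)
    mean_act_per_dimension = []  # per-batch E[x_d], one list of length d_model each
    for batch in batches:
        n = len(batch)
        mean_sum_of_squares.append(sum(sum(v * v for v in x) for x in batch) / n)
        # Fix for Bug 1: collect E[x_d], not E[x_d^2]
        mean_act_per_dimension.append([sum(col) / n for col in zip(*batch)])
    overall_mean_sq = sum(mean_sum_of_squares) / len(batches)
    # Fix for Bug 2: "stack" and average over the batch axis, keeping d_model values
    mu = [sum(col) / len(batches) for col in zip(*mean_act_per_dimension)]
    # Sum the squared per-dimension means into a scalar before subtracting
    return overall_mean_sq - sum(m * m for m in mu)


batches = [
    [[10.0, 1.0], [12.0, -1.0]],
    [[10.0, 3.0], [12.0, 1.0]],
]
# Matches the single-pass E[||x - E[x]||^2] = 3.0 for these 4 tokens.
print(total_variance_batched(batches))  # 3.0
```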

Effect

With both bugs, total_variance ≈ E[||x||²] (the second term becomes negligibly small), so the metric effectively computes:

explained_variance ≈ 1 - MSE / E[||x||²]

This is variance explained relative to a zero-prediction baseline, rather than relative to predicting the mean. When activations have a large mean component (common in deeper layers), this gives a misleadingly high explained variance.
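A concrete example of the inflation: the pure-Python sketch below scores a trivial "reconstruction" that ignores its input and always outputs the mean activation. It assumes MSE means the per-token squared L2 error averaged over tokens; the exact normalization in evals.py is not shown in this issue. Against a zero baseline, the trivial predictor looks excellent. Against the mean baseline, it correctly scores 0.

```python
def mean_vector(xs):
    """Per-dimension mean E[x_d]."""
    return [sum(col) / len(xs) for col in zip(*xs)]


def mse(xs, recon):
    """Squared L2 reconstruction error per token, averaged over tokens (assumed definition)."""
    return sum(
        sum((a - b) ** 2 for a, b in zip(x, r)) for x, r in zip(xs, recon)
    ) / len(xs)


# Toy activations with a large mean component in dimension 0.
acts = [[10.0, 1.0], [12.0, -1.0], [10.0, 3.0], [12.0, 1.0]]
mu = mean_vector(acts)
recon = [list(mu) for _ in acts]  # trivial "SAE" that always outputs the mean

err = mse(acts, recon)                                                # 3.0
second_moment = sum(sum(v * v for v in x) for x in acts) / len(acts)  # E[||x||^2] = 125.0
total_variance = second_moment - sum(m * m for m in mu)               # 3.0

ev_zero_baseline = 1 - err / second_moment   # what the buggy metric approximates
ev_mean_baseline = 1 - err / total_variance  # intended metric

print(round(ev_zero_baseline, 3), ev_mean_baseline)  # 0.976 0.0
```

The trivial predictor captures none of the token-to-token variation. The zero-baseline formula still credits it with most of E[||x||²], because that quantity is dominated by the mean offset.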

Note

The same bug exists in SAEBench (PR #61), which was the source of this code.
