Issue debugged, explained, and opened by Claude Opus 4.6
I noticed that when I trained an SAE on layer 20 of Qwen 2.5 7B, the explained variance was showing as 0.98, which seems unbelievable, while the legacy explained variance was showing a more reasonable ~0.78. In the past I've found those two metrics to be relatively close, so I asked Claude to investigate whether I had done something incorrect with training, and it came up with this bug in the explained variance calculation instead.
Summary
The "new" explained_variance metric introduced in #443 has two bugs that cause it to compute variance relative to zero rather than variance relative to the mean. This makes the metric substantially inflated compared to the legacy version (e.g., 0.98 vs 0.78).
Intended formula
Per the PR description, the intent was the multi-dimensional version of Var(X) = E[X²] - E[X]², that is, total_variance = E[||x||²] - Σ_d E[x_d]².
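For reference, the multi-dimensional identity can be written out as below. This is a standard derivation rather than text from the PR; the symbols x̂ (the SAE reconstruction) and EV (explained variance) are notation assumed here for illustration.

```latex
\begin{aligned}
\text{total variance}
  &= \sum_{d} \operatorname{Var}(x_d)
   = \sum_{d} \left( \mathbb{E}[x_d^2] - \mathbb{E}[x_d]^2 \right) \\
  &= \mathbb{E}\!\left[\lVert x \rVert^2\right] - \sum_{d} \mathbb{E}[x_d]^2 ,
\\[4pt]
\mathrm{EV}
  &= 1 - \frac{\mathbb{E}\!\left[\lVert x - \hat{x} \rVert^2\right]}
              {\mathbb{E}\!\left[\lVert x \rVert^2\right] - \sum_{d} \mathbb{E}[x_d]^2} .
\end{aligned}
```

The second term, Σ_d E[x_d]², needs the per-dimension mean to be squared first and then summed over dimensions, which is exactly the structure the two bugs below break.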
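Bug 1: .pow(2) computes E[x_d²] instead of E[x_d]
In evals.py line 552, .pow(2) is applied to the flattened SAE input before taking the mean over dim 0. This computes E[x_d²] per dimension. But the variance formula E[X²] - E[X]² needs E[x_d] (the mean), not E[x_d²] (the mean of squares). The mean of squares is already captured by mean_sum_of_squares. When this value is squared on line 586, it produces E[x_d²]², not E[x_d]².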
Fix: Remove .pow(2):
mean_act_per_dimension.append(
    flattened_sae_input.mean(dim=0)  # E[x_d] per dimension, shape [d_model]
)
Bug 2: torch.cat collapses dimensions instead of preserving them
In evals.py lines 585-586, torch.cat on a list of [d_model] tensors produces a tensor of shape [N_batches * d_model], and .mean(dim=0) then collapses it to a scalar. This loses the per-dimension structure needed to compute Σ_d E[x_d]².
Fix: Use torch.stack and sum over the model dimension.
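A minimal sketch of the shape difference, assuming a hypothetical list of per-batch mean tensors. The variable names, sizes, and random data are illustrative and not copied from evals.py, and averaging per-batch means like this assumes all batches have the same size.

```python
import torch

torch.manual_seed(0)
d_model, n_batches = 4, 3

# Hypothetical per-batch means E[x_d], one [d_model] tensor per batch.
mean_act_per_dimension = [torch.randn(d_model) + 5.0 for _ in range(n_batches)]

# torch.cat joins along the existing dim 0, giving shape [n_batches * d_model];
# .mean(dim=0) then averages over batches AND dimensions, leaving a scalar.
collapsed = torch.cat(mean_act_per_dimension).mean(dim=0)

# torch.stack adds a new leading batch dim, giving shape [n_batches, d_model];
# .mean(dim=0) averages over batches only, keeping one mean per dimension.
per_dim_mean = torch.stack(mean_act_per_dimension).mean(dim=0)

# Sum of squared per-dimension means: the second term of the variance formula.
sum_sq_means = per_dim_mean.pow(2).sum()

print(collapsed.shape)     # torch.Size([])
print(per_dim_mean.shape)  # torch.Size([4])
```

With torch.stack the per-dimension means survive until they are squared, so the final sum over d_model yields Σ_d E[x_d]² rather than the square of a single pooled average.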
Effect
With both bugs, total_variance ≈ E[||x||²] (the second term becomes negligibly small), so the metric effectively computes:
explained_variance ≈ 1 - MSE / E[||x||²]
This is variance explained relative to a zero-prediction baseline, rather than relative to predicting the mean. When activations have a large mean component (common in deeper layers), this gives a misleadingly high explained variance.
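To make the size of this effect concrete, here is a small pure-Python sketch on made-up data. The function name explained_variance, its center flag, and the toy activations are illustrative assumptions, not code from evals.py. It only contrasts the zero baseline with the mean baseline.

```python
def explained_variance(x, x_hat, center=True):
    """Explained variance of x_hat for x (lists of equal-length rows).

    center=True  -> baseline is E[||x||^2] - sum_d E[x_d]^2 (predict the mean)
    center=False -> baseline is E[||x||^2]                  (predict zero)
    """
    n, d = len(x), len(x[0])
    # Mean squared L2 reconstruction error over samples.
    mse = sum(
        sum((a - b) ** 2 for a, b in zip(row, row_hat))
        for row, row_hat in zip(x, x_hat)
    ) / n
    # E[||x||^2]: mean squared norm of the inputs.
    mean_sq_norm = sum(sum(a * a for a in row) for row in x) / n
    # sum_d E[x_d]^2: squared norm of the per-dimension mean.
    mean_per_dim = [sum(row[j] for row in x) / n for j in range(d)]
    sq_norm_of_mean = sum(m * m for m in mean_per_dim)
    baseline = mean_sq_norm - sq_norm_of_mean if center else mean_sq_norm
    return 1.0 - mse / baseline


# Toy activations with a large mean ([10, 20]) and small spread around it.
x = [[10.0, 20.0], [12.0, 22.0], [8.0, 18.0], [10.0, 20.0]]
# Toy reconstruction: gets the mean right but only half of each deviation.
x_hat = [[10.0, 20.0], [11.0, 21.0], [9.0, 19.0], [10.0, 20.0]]

ev_mean_baseline = explained_variance(x, x_hat, center=True)
ev_zero_baseline = explained_variance(x, x_hat, center=False)
print(ev_mean_baseline)  # 0.75
print(ev_zero_baseline)  # about 0.998
```

The reconstruction error is identical in both calls. Only the denominator changes, and the large mean inflates it from 4 to 504 in this toy case.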
Note
The same bug exists in SAEBench (PR #61), which was the source of this code.