Add covariance whitening normalization option (#647) #697
Add 'covariance_whitening' as a normalize_activations option, implementing PCA whitening from 'Data Whitening Improves Sparse Autoencoder Learning' (https://arxiv.org/abs/2511.13981). Unlike the existing options, whitening needs data-estimated statistics, so the SAE stores whitening_mean, whitening_W (W = D^-1/2 E^T) and whitening_W_inv (W^-1 = E D^1/2) as buffers. The run-time transforms whiten the input (z = W(x - mu)) and dewhiten the output (x = W^-1 z + mu) before the reconstruction loss, exactly as the paper specifies. SAE.estimate_whitening() accumulates the activation mean and covariance (in float64 for stability), eigendecomposes with an epsilon regularizer, and fills the buffers; the trainer calls it before training, mirroring the expected_average_only_in path. The buffers persist via the state dict, so whitening stays active on save/load. Adds tests covering identity covariance after whitening, exact dewhitening round-trip, no-op before estimation, buffer persistence through save/load, and the config guard.
Thank you for taking this task on! The core whitening / unwhitening logic looks good, but the location in the codebase could be improved to better fit with how SAELens training/inference works:
- This looks like an enhancement to the `ActivationScaler` class rather than something that needs to be managed by the base SAE class itself.
- This should only be needed at training time. I think it should be possible to fold whitening into weights when we finish training (see: https://github.com/decoderesearch/SAELens/blob/main/sae_lens/training/sae_trainer.py#L190-L194). We'd need a `fold_` method added to do this covariance weight folding, but then we don't need to save out any new weights for the SAE, and when users load an SAE that was trained with whitening there's no extra compute penalty and no need for special logic in the base SAE class.
    self.run_time_activation_norm_fn_out = lambda x: x

    @torch.no_grad()
    def estimate_whitening(
This makes more sense to go in ActivationScaler, since this whitening is more of a pre-processing of the activations rather than something the base SAE class should need to deal with.
        self.run_time_activation_norm_fn_in = run_time_activation_ln_in
        self.run_time_activation_norm_fn_out = run_time_activation_ln_out

    elif self.cfg.normalize_activations == "covariance_whitening":
Can we fold the whitening into the SAE weights at the end of training, like we do with expected_average_only_in scaling? This is just a linear operation, so it should be possible to calculate adjusted versions of the weights that take this into account I think. Then we don't need to handle any special processing during inference, only during training, and the distributed SAE doesn't need extra weights matrices.
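To make the folding suggestion concrete, here is a minimal NumPy sketch (NumPy rather than the PR's PyTorch, to keep it dependency-light). It assumes a plain ReLU SAE with row-vector activations, `W_enc` of shape `(d_in, d_sae)` and `W_dec` of shape `(d_sae, d_in)`; all parameter values are random stand-ins, and names such as `forward_with_whitening` and `forward_folded` are hypothetical, not SAELens APIs. SAELens-specific details (for example subtracting the decoder bias from the input, or decoder-norm constraints) are not modelled.

```python
import numpy as np

rng = np.random.default_rng(0)
d_in, d_sae, n = 6, 10, 32

# Stand-in whitening statistics (hypothetical values, not taken from the PR).
mu = rng.normal(size=d_in)
A = rng.normal(size=(d_in, d_in))
cov = A @ A.T + 0.1 * np.eye(d_in)        # symmetric positive definite
evals, E = np.linalg.eigh(cov)
W = np.diag(evals ** -0.5) @ E.T          # W = D^-1/2 E^T
W_inv = E @ np.diag(evals ** 0.5)         # W^-1 = E D^1/2

# Stand-in SAE parameters, as if trained in whitened space.
W_enc = rng.normal(size=(d_in, d_sae))
b_enc = rng.normal(size=d_sae)
W_dec = rng.normal(size=(d_sae, d_in))
b_dec = rng.normal(size=d_in)


def relu(a):
    return np.maximum(a, 0.0)


def forward_with_whitening(x):
    """Explicit path: whiten, run the SAE, dewhiten."""
    z = (x - mu) @ W.T                    # z = W (x - mu), row-vector form
    f = relu(z @ W_enc + b_enc)
    z_hat = f @ W_dec + b_dec
    return z_hat @ W_inv.T + mu           # x_hat = W^-1 z_hat + mu


# Fold the two affine maps into the SAE parameters.
W_enc_f = W.T @ W_enc
b_enc_f = b_enc - mu @ W_enc_f
W_dec_f = W_dec @ W_inv.T
b_dec_f = b_dec @ W_inv.T + mu


def forward_folded(x):
    """Same function of x, with no whitening buffers needed."""
    f = relu(x @ W_enc_f + b_enc_f)
    return f @ W_dec_f + b_dec_f


x = rng.normal(size=(n, d_in))
max_diff = np.max(np.abs(forward_with_whitening(x) - forward_folded(x)))
```

Because whitening and dewhitening are affine, the folded model matches the explicit path up to floating-point error, and the feature activations themselves are unchanged. Whether this carries over cleanly to every SAE architecture in the library would need to be checked separately.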
Summary
Implements #647: adds `covariance_whitening` as a `normalize_activations` option, applying PCA whitening to activations before the SAE, following Data Whitening Improves Sparse Autoencoder Learning (https://arxiv.org/abs/2511.13981).
Whitening equalizes the eigenspectrum of the activation covariance, which
de-correlates features and (per the paper) better aligns sparsity with
interpretability.
How it works
Unlike the other `normalize_activations` options, whitening needs data-estimated statistics, so the SAE stores three buffers:
- `whitening_mean`: activation mean μ
- `whitening_W`: whitening matrix W = D^(-1/2) Eᵀ
- `whitening_W_inv`: dewhitening matrix W⁻¹ = E D^(1/2)

where the covariance C = E D Eᵀ and eigenvalues are regularized as λ + ε.

The run-time transforms whiten the input, z = W(x - μ), and dewhiten the output, x = W⁻¹ z + μ, before the reconstruction loss (exactly as the paper specifies).
`SAE.estimate_whitening(data_provider, n_batches, eps)` accumulates the mean and covariance (in float64 for numerical stability), eigendecomposes via `torch.linalg.eigh`, and fills the buffers. `SAETrainer.fit()` calls it before training, mirroring the existing `expected_average_only_in` estimation hook. The buffers live in the state dict, so whitening stays active across save/load (they are intentionally not folded into the weights, since dewhitening is explicit).
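The estimation step described above can be sketched roughly as follows. This is an illustrative NumPy version, not the PR's `SAE.estimate_whitening` (which uses PyTorch and a data provider); the standalone function signature, the biased covariance estimate, and the clipping of negative eigenvalues are assumptions made for the sketch.

```python
import numpy as np


def estimate_whitening(batches, eps=1e-3):
    """Streaming estimate of the mean, W = D^-1/2 E^T and W^-1 = E D^1/2."""
    n = 0
    sum_x = None
    sum_xxT = None
    for batch in batches:
        b = np.asarray(batch, dtype=np.float64)   # accumulate in float64
        if sum_x is None:
            d = b.shape[1]
            sum_x = np.zeros(d)
            sum_xxT = np.zeros((d, d))
        n += b.shape[0]
        sum_x += b.sum(axis=0)
        sum_xxT += b.T @ b
    mean = sum_x / n
    # Biased covariance E[x x^T] - mu mu^T (an assumption of this sketch).
    cov = sum_xxT / n - np.outer(mean, mean)
    evals, E = np.linalg.eigh(cov)                # cov = E diag(evals) E^T
    evals = np.clip(evals, 0.0, None) + eps       # regularize: lambda + eps
    W = (E / np.sqrt(evals)).T                    # D^-1/2 E^T
    W_inv = E * np.sqrt(evals)                    # E D^1/2
    return mean, W, W_inv


# Synthetic activations with a known, well-conditioned covariance.
rng = np.random.default_rng(0)
scales = np.linspace(0.5, 4.0, 8)
Q, _ = np.linalg.qr(rng.normal(size=(8, 8)))
data = (rng.normal(size=(4096, 8)) * scales) @ Q + 3.0

mean, W, W_inv = estimate_whitening(np.array_split(data, 16), eps=1e-6)
z = (data - mean) @ W.T                           # whiten
cov_z = z.T @ z / len(z)                          # close to the identity
roundtrip = z @ W_inv.T + mean                    # dewhiten
```

With the regularizer, each whitened eigenvalue is λ / (λ + ε) rather than exactly 1, so a larger `eps` such as `1e-3` leaves low-variance directions slightly under-scaled, while the dewhitening round-trip stays exact because `W_inv` is built from the same regularized eigenvalues.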
Tests
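`tests/saes/test_sae.py` (functional, not superficial):
- identity covariance after whitening
- dewhitening round-trip: `out(in(x)) ≈ x`
- no-op before estimation
- buffer persistence through save/load (`save_model` → `load_from_disk`)
- `estimate_whitening` raises if the config isn't `covariance_whitening`

`ruff check`, `ruff format`, and `pyright` all pass on the changed files.

Notes / open questions for maintainers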
- `n_batches_for_norm_estimate` is used for the number of covariance batches. Happy to add a dedicated config field (e.g. `n_batches_for_whitening`) if you'd prefer.
- `eps` defaults to `1e-3` (paper uses a small stabilizer); easy to surface as a config option if desired.