Add covariance whitening normalization option (#647) #697

Open
azrabano23 wants to merge 1 commit into decoderesearch:main from azrabano23:covariance-whitening-647

@azrabano23

Summary

Implements #647: adds covariance_whitening as a normalize_activations
option, applying PCA whitening to activations before the SAE, following
"Data Whitening Improves Sparse Autoencoder Learning".

Whitening equalizes the eigenspectrum of the activation covariance, which
de-correlates features and (per the paper) better aligns sparsity with
interpretability.

How it works

Unlike the other normalize_activations options, whitening needs
data-estimated statistics, so the SAE stores three buffers:

  • whitening_mean — activation mean μ
  • whitening_W — whitening matrix W = D^(-1/2) Eᵀ
  • whitening_W_inv — dewhitening matrix W⁻¹ = E D^(1/2)

where the covariance C = E D Eᵀ and eigenvalues are regularized as λ + ε.

The run-time transforms whiten the input and dewhiten the output before the
reconstruction loss (exactly as the paper specifies):

in:  z = W (x − μ)
out: x̂ = W⁻¹ z + μ

SAE.estimate_whitening(data_provider, n_batches, eps) accumulates the mean and
covariance (in float64 for numerical stability), eigendecomposes via
torch.linalg.eigh, and fills the buffers. SAETrainer.fit() calls it before
training, mirroring the existing expected_average_only_in estimation hook. The
buffers live in the state dict, so whitening stays active across save/load
(they are intentionally not folded into the weights, since dewhitening is
explicit).
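
For illustration, the estimation and transforms described above can be sketched
as follows. This is a minimal standalone sketch, not the code in this PR: the
free functions estimate_whitening, whiten and dewhiten, their signatures, and
the clamping of slightly negative eigenvalues are assumptions made for the
example. Activations are treated as row vectors, so z = W (x − μ) becomes
(x − μ) @ W.T.

```python
import torch


def estimate_whitening(batches, eps=1e-3):
    """Return (mean, W, W_inv) for PCA whitening, accumulated in float64.

    `batches` is any iterable of (n, d) tensors (hypothetical interface).
    """
    n = 0
    sum_x = None
    sum_outer = None
    for batch in batches:
        x = batch.to(torch.float64)
        if sum_x is None:
            d = x.shape[-1]
            sum_x = torch.zeros(d, dtype=torch.float64)
            sum_outer = torch.zeros(d, d, dtype=torch.float64)
        sum_x += x.sum(dim=0)
        sum_outer += x.T @ x
        n += x.shape[0]

    mean = sum_x / n
    # Biased covariance estimate: E[x x^T] - mean mean^T.
    cov = sum_outer / n - torch.outer(mean, mean)

    # cov = E diag(eigvals) E^T, with eigenvectors in the columns of E.
    eigvals, eigvecs = torch.linalg.eigh(cov)
    # Assumption: clamp tiny negative eigenvalues before adding eps.
    lam = eigvals.clamp_min(0.0) + eps

    W = torch.diag(lam.rsqrt()) @ eigvecs.T      # D^(-1/2) E^T
    W_inv = eigvecs @ torch.diag(lam.sqrt())     # E D^(1/2)
    return mean, W, W_inv


def whiten(x, mean, W):
    # Row-vector form of z = W (x - mean).
    return (x - mean) @ W.T


def dewhiten(z, mean, W_inv):
    # Row-vector form of x_hat = W_inv z + mean.
    return z @ W_inv.T + mean
```

Because W⁻¹ W = E D^(1/2) D^(-1/2) Eᵀ = I, the round trip is exact up to
floating-point error regardless of ε, while the whitened covariance is
D / (D + ε), which approaches the identity only as ε becomes small relative to
the eigenvalues.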

Tests

tests/saes/test_sae.py (functional, not superficial):

  • whitened activations have identity covariance on correlated synthetic data
  • exact dewhitening round-trip out(in(x)) ≈ x
  • whitening is a no-op before estimation (buffers start as identity/zero)
  • whitening buffers persist through save/load (save_model / load_from_disk)
  • estimate_whitening raises if the config isn't covariance_whitening
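
The no-op and persistence properties in the list above follow from how PyTorch
buffers behave. A minimal sketch, assuming a hypothetical WhiteningBuffers
module (not the SAE class in this PR) that registers the three buffers with
zero/identity defaults:

```python
import torch
from torch import nn


class WhiteningBuffers(nn.Module):
    """Hypothetical holder for the three whitening buffers."""

    def __init__(self, d_in):
        super().__init__()
        # Zero mean and identity matrices make both transforms a no-op
        # until the statistics are estimated.
        self.register_buffer("whitening_mean", torch.zeros(d_in))
        self.register_buffer("whitening_W", torch.eye(d_in))
        self.register_buffer("whitening_W_inv", torch.eye(d_in))

    def whiten(self, x):
        return (x - self.whitening_mean) @ self.whitening_W.T

    def dewhiten(self, z):
        return z @ self.whitening_W_inv.T + self.whitening_mean


# Buffers are part of the state dict, so they survive a save/load cycle.
source = WhiteningBuffers(3)
source.whitening_mean.copy_(torch.tensor([1.0, 2.0, 3.0]))
restored = WhiteningBuffers(3)
restored.load_state_dict(source.state_dict())
```

Here load_state_dict stands in for whatever save/load path the library uses;
the point is only that registered buffers travel with the module state.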

ruff check, ruff format, and pyright all pass on the changed files.

Notes / open questions for maintainers

  • Estimation reuses n_batches_for_norm_estimate for the number of covariance
    batches. Happy to add a dedicated config field (e.g. n_batches_for_whitening)
    if you'd prefer.
  • eps defaults to 1e-3 (paper uses a small stabilizer); easy to surface as a
    config option if desired.

Add 'covariance_whitening' as a normalize_activations option, implementing PCA
whitening from 'Data Whitening Improves Sparse Autoencoder Learning'
(https://arxiv.org/abs/2511.13981).

Unlike the existing options, whitening needs data-estimated statistics, so the
SAE stores whitening_mean, whitening_W (W = D^-1/2 E^T) and whitening_W_inv
(W^-1 = E D^1/2) as buffers. The run-time transforms whiten the input
(z = W(x - mu)) and dewhiten the output (x = W^-1 z + mu) before the
reconstruction loss, exactly as the paper specifies. SAE.estimate_whitening()
accumulates the activation mean and covariance (in float64 for stability),
eigendecomposes with an epsilon regularizer, and fills the buffers; the trainer
calls it before training, mirroring the expected_average_only_in path. The
buffers persist via the state dict, so whitening stays active on save/load.

Adds tests covering identity covariance after whitening, exact dewhitening
round-trip, no-op before estimation, buffer persistence through save/load, and
the config guard.

@chanind left a comment

Collaborator

Thank you for taking this task on! The core whitening / unwhitening logic looks good, but the location in the codebase could be improved to better fit with how SAELens training/inference works:

  • This looks like an enhancement to the ActivationScaler class rather than something that needs to be managed by the base SAE class itself.
  • This should only be needed at training time. I think it should be possible to fold whitening into weights when we finish training (see: https://github.com/decoderesearch/SAELens/blob/main/sae_lens/training/sae_trainer.py#L190-L194). We'd need a fold_ method added to do this covariance weight folding, but then we don't need to save out any new weights for the SAE, and when users load an SAE that was trained with whitening there's no extra compute penalty and no need for special logic in the base SAE class.

Comment thread sae_lens/saes/sae.py
self.run_time_activation_norm_fn_out = lambda x: x

@torch.no_grad()
def estimate_whitening(

Collaborator

This makes more sense to go in ActivationScaler, since this whitening is more of a pre-processing of the activations rather than something the base SAE class should need to deal with.

Comment thread sae_lens/saes/sae.py
self.run_time_activation_norm_fn_in = run_time_activation_ln_in
self.run_time_activation_norm_fn_out = run_time_activation_ln_out

elif self.cfg.normalize_activations == "covariance_whitening":

Collaborator

Can we fold the whitening into the SAE weights at the end of training, like we do with expected_average_only_in scaling? This is just a linear operation, so it should be possible to calculate adjusted versions of the weights that take this into account I think. Then we don't need to handle any special processing during inference, only during training, and the distributed SAE doesn't need extra weights matrices.
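
To make the folding idea concrete, here is a rough sketch of the algebra under
stated assumptions; it is not SAELens code. It assumes a plain SAE computing
hidden = act(z @ W_enc + b_enc) and z_hat = hidden @ W_dec + b_dec on whitened
inputs, with W_enc of shape (d_in, d_sae) and W_dec of shape (d_sae, d_in). The
function name fold_whitening is hypothetical. Architectures that subtract b_dec
from the input, or that depend on decoder norms, would need extra care.

```python
import torch


def fold_whitening(W_enc, b_enc, W_dec, b_dec, mean, W, W_inv):
    """Absorb whitening/dewhitening into SAE weights (hypothetical helper).

    Whitening:   z     = (x - mean) @ W.T
    Dewhitening: x_hat = z_hat @ W_inv.T + mean
    """
    # (x - mean) @ W.T @ W_enc + b_enc
    #   = x @ (W.T @ W_enc) + (b_enc - mean @ W.T @ W_enc)
    W_enc_folded = W.T @ W_enc
    b_enc_folded = b_enc - mean @ W_enc_folded

    # (h @ W_dec + b_dec) @ W_inv.T + mean
    #   = h @ (W_dec @ W_inv.T) + (b_dec @ W_inv.T + mean)
    W_dec_folded = W_dec @ W_inv.T
    b_dec_folded = b_dec @ W_inv.T + mean
    return W_enc_folded, b_enc_folded, W_dec_folded, b_dec_folded
```

With the folded weights, act(x @ W_enc_folded + b_enc_folded) gives the same
feature activations as the whitened SAE, and the decoder output is already in
the original activation space, so no whitening buffers are needed at inference.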
