From 1d5dafe6c938f8098a00f606eadba86aba8c9f54 Mon Sep 17 00:00:00 2001
From: Azra Bano
Date: Sat, 6 Jun 2026 23:43:42 -0700
Subject: [PATCH] Add covariance whitening normalization option (#647)

Add 'covariance_whitening' as a normalize_activations option,
implementing PCA whitening from 'Data Whitening Improves Sparse
Autoencoder Learning' (https://arxiv.org/abs/2511.13981).

Unlike the existing options, whitening needs data-estimated statistics
in its run-time transforms, so the SAE stores whitening_mean,
whitening_W (W = D^(-1/2) E^T) and whitening_W_inv (W^-1 = E D^(1/2))
as buffers. The run-time transforms whiten the input (z = W(x - mu))
and dewhiten the output (x = W^-1 z + mu) before the reconstruction
loss, exactly as the paper specifies.

SAE.estimate_whitening() accumulates the activation mean and covariance
(in float64 for stability), eigendecomposes with an epsilon regularizer,
and fills the buffers; the trainer calls it before training, mirroring
the expected_average_only_in path. The buffers persist via the state
dict, so whitening stays active on save/load.

Adds tests covering identity covariance after whitening, exact
dewhitening round-trip, no-op before estimation, buffer persistence
through save/load, and the config guard.
--- sae_lens/saes/sae.py | 96 +++++++++++++++++++++++++++++++- sae_lens/training/sae_trainer.py | 5 ++ tests/saes/test_sae.py | 82 +++++++++++++++++++++++++++ 3 files changed, 181 insertions(+), 2 deletions(-) diff --git a/sae_lens/saes/sae.py b/sae_lens/saes/sae.py index 3952486ea..3d3423ac3 100644 --- a/sae_lens/saes/sae.py +++ b/sae_lens/saes/sae.py @@ -4,13 +4,13 @@ import json import warnings from abc import ABC, abstractmethod +from collections.abc import Callable, Iterator from contextlib import contextmanager from dataclasses import asdict, dataclass, field, fields, replace from pathlib import Path from typing import ( TYPE_CHECKING, Any, - Callable, Generic, Literal, NamedTuple, @@ -22,6 +22,7 @@ from numpy.typing import NDArray from safetensors.torch import load_file, save_file from torch import nn +from tqdm.auto import tqdm from transformer_lens.hook_points import HookedRootModule, HookPoint from typing_extensions import deprecated, overload, override @@ -158,6 +159,7 @@ class SAEConfig(ABC): "none", "expected_average_only_in", # (Anthropic April 2024 Update) "layer_norm", + "covariance_whitening", # PCA whitening (https://arxiv.org/abs/2511.13981) ] = "none" reshape_activations: Literal["none", "hook_z"] = "none" metadata: SAEMetadata = field(default_factory=SAEMetadata) @@ -191,9 +193,10 @@ def __post_init__(self): "expected_average_only_in", "constant_norm_rescale", "layer_norm", + "covariance_whitening", ]: raise ValueError( - f"normalize_activations must be none, expected_average_only_in, layer_norm, or constant_norm_rescale. Got {self.normalize_activations}" + f"normalize_activations must be none, expected_average_only_in, layer_norm, constant_norm_rescale, or covariance_whitening. Got {self.normalize_activations}" ) @@ -242,6 +245,10 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC): W_enc: nn.Parameter W_dec: nn.Parameter b_dec: nn.Parameter + # Registered as buffers only when normalize_activations == "covariance_whitening". 
+ whitening_mean: torch.Tensor + whitening_W: torch.Tensor + whitening_W_inv: torch.Tensor def __init__(self, cfg: T_SAE_CONFIG, use_error_term: bool = False): """Initialize the SAE.""" @@ -353,10 +360,95 @@ def run_time_activation_ln_out( self.run_time_activation_norm_fn_in = run_time_activation_ln_in self.run_time_activation_norm_fn_out = run_time_activation_ln_out + + elif self.cfg.normalize_activations == "covariance_whitening": + # PCA whitening (https://arxiv.org/abs/2511.13981). The whitening + # statistics are data-estimated by estimate_whitening() before + # training and stored as buffers so they persist on save/load. Until + # estimated they are the identity, so whitening is a no-op. + d_in = self.cfg.d_in + self.register_buffer( + "whitening_mean", + torch.zeros(d_in, dtype=self.dtype, device=self.device), + ) + self.register_buffer( + "whitening_W", torch.eye(d_in, dtype=self.dtype, device=self.device) + ) + self.register_buffer( + "whitening_W_inv", torch.eye(d_in, dtype=self.dtype, device=self.device) + ) + + def run_time_activation_whiten_in(x: torch.Tensor) -> torch.Tensor: + # z = W (x - mu), applied to row vectors in the last dim. + return (x - self.whitening_mean) @ self.whitening_W.T + + def run_time_activation_whiten_out(x: torch.Tensor) -> torch.Tensor: + # Dewhiten before the reconstruction loss: x = W^-1 z + mu. + return x @ self.whitening_W_inv.T + self.whitening_mean + + self.run_time_activation_norm_fn_in = run_time_activation_whiten_in + self.run_time_activation_norm_fn_out = run_time_activation_whiten_out + else: self.run_time_activation_norm_fn_in = lambda x: x self.run_time_activation_norm_fn_out = lambda x: x + @torch.no_grad() + def estimate_whitening( + self, + data_provider: Iterator[torch.Tensor], + n_batches: int = 100, + eps: float = 1e-3, + ) -> None: + """Estimate PCA whitening statistics from activations. 
+ + Accumulates the mean and covariance of activations over ``n_batches``, + then sets the whitening buffers via an eigendecomposition of the + covariance, following https://arxiv.org/abs/2511.13981. + + The whitening matrix is W = D^(-1/2) E^T and its inverse W^-1 = E D^(1/2), + where the covariance is C = E D E^T and each eigenvalue is regularized as + lambda_i + eps for numerical stability. + + Args: + data_provider: Iterator yielding activation batches of shape + (batch_size, d_in). + n_batches: Number of batches to accumulate statistics over. + eps: Stabilizing constant added to covariance eigenvalues. + """ + if self.cfg.normalize_activations != "covariance_whitening": + raise ValueError( + "estimate_whitening requires normalize_activations='covariance_whitening', " + f"got {self.cfg.normalize_activations!r}" + ) + + d_in = self.cfg.d_in + # Accumulate in float64 for numerical stability of the covariance. + count = 0 + sum_x = torch.zeros(d_in, dtype=torch.float64, device=self.device) + sum_outer = torch.zeros(d_in, d_in, dtype=torch.float64, device=self.device) + for _ in tqdm(range(n_batches), desc="Estimating whitening", leave=False): + acts = self.reshape_fn_in(next(data_provider)).to( + self.device, torch.float64 + ) + count += acts.shape[0] + sum_x += acts.sum(dim=0) + sum_outer += acts.T @ acts + + mean = sum_x / count + # Unbiased covariance: (sum_outer - count * mean mean^T) / (count - 1). + cov = (sum_outer - count * torch.outer(mean, mean)) / (count - 1) + eigenvalues, eigenvectors = torch.linalg.eigh(cov) + inv_sqrt = (eigenvalues + eps).rsqrt() + sqrt = (eigenvalues + eps).sqrt() + # W = D^(-1/2) E^T ; W^-1 = E D^(1/2). 
+ whitening_w = (eigenvectors * inv_sqrt).T + whitening_w_inv = eigenvectors * sqrt + + self.whitening_mean.copy_(mean.to(self.dtype)) + self.whitening_W.copy_(whitening_w.to(self.dtype)) + self.whitening_W_inv.copy_(whitening_w_inv.to(self.dtype)) + def initialize_weights(self): """Initialize model weights.""" self.b_dec = nn.Parameter( diff --git a/sae_lens/training/sae_trainer.py b/sae_lens/training/sae_trainer.py index 72cb9da02..eaa7a6904 100644 --- a/sae_lens/training/sae_trainer.py +++ b/sae_lens/training/sae_trainer.py @@ -169,6 +169,11 @@ def fit(self) -> T_TRAINING_SAE: data_provider=self.data_provider, n_batches_for_norm_estimate=self.cfg.n_batches_for_norm_estimate, ) + elif self.sae.cfg.normalize_activations == "covariance_whitening": + self.sae.estimate_whitening( + data_provider=self.data_provider, + n_batches=self.cfg.n_batches_for_norm_estimate, + ) # Train loop while self.n_training_samples < self.cfg.total_training_samples: diff --git a/tests/saes/test_sae.py b/tests/saes/test_sae.py index 20517f2c6..3b6cbb63a 100644 --- a/tests/saes/test_sae.py +++ b/tests/saes/test_sae.py @@ -1,6 +1,7 @@ import copy import pickle import tracemalloc +from collections.abc import Iterator from pathlib import Path from typing import Any from unittest.mock import patch @@ -18,9 +19,11 @@ TrainingSAE, TrainingSAEConfig, ) +from sae_lens.saes.standard_sae import StandardSAE from tests.helpers import ( ALL_TRAINING_ARCHITECTURES, assert_close, + build_sae_cfg, build_sae_training_cfg_for_arch, random_params, ) @@ -497,3 +500,82 @@ def mock_loader( assert_close(loaded_sae.W_dec, original_W_dec / scaling_factor) assert_close(loaded_sae.b_dec, original_b_dec / scaling_factor) assert loaded_sae.cfg.normalize_activations == "none" + + +def _correlated_activations(d_in: int, batch_size: int) -> Iterator[torch.Tensor]: + # Fixed (but unseeded) correlation structure shared across all batches, so + # every batch is drawn from the same anisotropic, shifted distribution. 
+ transform = torch.randn(d_in, d_in) + shift = torch.randn(d_in) * 3.0 + while True: + yield torch.randn(batch_size, d_in) @ transform.T + shift + + +def test_covariance_whitening_produces_identity_covariance(): + d_in = 8 + sae = StandardSAE( + build_sae_cfg(d_in=d_in, d_sae=32, normalize_activations="covariance_whitening") + ) + provider = _correlated_activations(d_in, batch_size=4096) + sae.estimate_whitening(provider, n_batches=30, eps=1e-8) + + sample = torch.cat([next(provider) for _ in range(30)], dim=0) + whitened = sae.run_time_activation_norm_fn_in(sample) + + cov = torch.cov(whitened.T) + assert_close(cov, torch.eye(d_in), atol=0.05, rtol=0) + assert_close(whitened.mean(dim=0), torch.zeros(d_in), atol=0.05, rtol=0) + + +def test_covariance_whitening_round_trip_recovers_input(): + d_in = 8 + sae = StandardSAE( + build_sae_cfg(d_in=d_in, normalize_activations="covariance_whitening") + ) + provider = _correlated_activations(d_in, batch_size=1024) + sae.estimate_whitening(provider, n_batches=10) + + x = next(provider) + recovered = sae.run_time_activation_norm_fn_out( + sae.run_time_activation_norm_fn_in(x) + ) + assert_close(recovered, x, atol=1e-4, rtol=1e-4) + + +def test_covariance_whitening_is_noop_before_estimation(): + d_in = 8 + sae = StandardSAE( + build_sae_cfg(d_in=d_in, normalize_activations="covariance_whitening") + ) + x = torch.randn(16, d_in) + # Buffers start as identity / zero, so whitening must be the identity map. 
+ assert_close(sae.run_time_activation_norm_fn_in(x), x) + + +def test_covariance_whitening_buffers_persist_through_save_load(tmp_path: Path): + d_in = 8 + sae = StandardSAE( + build_sae_cfg(d_in=d_in, normalize_activations="covariance_whitening") + ) + random_params(sae) + provider = _correlated_activations(d_in, batch_size=1024) + sae.estimate_whitening(provider, n_batches=10) + + sae.save_model(tmp_path) + loaded = StandardSAE.load_from_disk(tmp_path) + + assert_close(loaded.whitening_mean, sae.whitening_mean) + assert_close(loaded.whitening_W, sae.whitening_W) + assert_close(loaded.whitening_W_inv, sae.whitening_W_inv) + + x = next(provider) + assert_close( + loaded.run_time_activation_norm_fn_in(x), sae.run_time_activation_norm_fn_in(x) + ) + + +def test_estimate_whitening_requires_covariance_whitening_config(): + sae = StandardSAE(build_sae_cfg(d_in=8, normalize_activations="none")) + provider = _correlated_activations(8, batch_size=16) + with pytest.raises(ValueError, match="covariance_whitening"): + sae.estimate_whitening(provider, n_batches=1)