Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 94 additions & 2 deletions sae_lens/saes/sae.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,13 @@
import json
import warnings
from abc import ABC, abstractmethod
from collections.abc import Callable, Iterator
from contextlib import contextmanager
from dataclasses import asdict, dataclass, field, fields, replace
from pathlib import Path
from typing import (
TYPE_CHECKING,
Any,
Callable,
Generic,
Literal,
NamedTuple,
Expand All @@ -22,6 +22,7 @@
from numpy.typing import NDArray
from safetensors.torch import load_file, save_file
from torch import nn
from tqdm.auto import tqdm
from transformer_lens.hook_points import HookedRootModule, HookPoint
from typing_extensions import deprecated, overload, override

Expand Down Expand Up @@ -158,6 +159,7 @@ class SAEConfig(ABC):
"none",
"expected_average_only_in", # (Anthropic April 2024 Update)
"layer_norm",
"covariance_whitening", # PCA whitening (https://arxiv.org/abs/2511.13981)
] = "none"
reshape_activations: Literal["none", "hook_z"] = "none"
metadata: SAEMetadata = field(default_factory=SAEMetadata)
Expand Down Expand Up @@ -191,9 +193,10 @@ def __post_init__(self):
"expected_average_only_in",
"constant_norm_rescale",
"layer_norm",
"covariance_whitening",
]:
raise ValueError(
f"normalize_activations must be none, expected_average_only_in, layer_norm, or constant_norm_rescale. Got {self.normalize_activations}"
f"normalize_activations must be none, expected_average_only_in, layer_norm, constant_norm_rescale, or covariance_whitening. Got {self.normalize_activations}"
)


Expand Down Expand Up @@ -242,6 +245,10 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC):
W_enc: nn.Parameter
W_dec: nn.Parameter
b_dec: nn.Parameter
# Registered as buffers only when normalize_activations == "covariance_whitening".
whitening_mean: torch.Tensor
whitening_W: torch.Tensor
whitening_W_inv: torch.Tensor

def __init__(self, cfg: T_SAE_CONFIG, use_error_term: bool = False):
"""Initialize the SAE."""
Expand Down Expand Up @@ -353,10 +360,95 @@ def run_time_activation_ln_out(

self.run_time_activation_norm_fn_in = run_time_activation_ln_in
self.run_time_activation_norm_fn_out = run_time_activation_ln_out

elif self.cfg.normalize_activations == "covariance_whitening":

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we fold the whitening into the SAE weights at the end of training, like we do with expected_average_only_in scaling? This is just a linear operation, so I think it should be possible to calculate adjusted versions of the weights that take it into account. Then we don't need any special processing during inference, only during training, and the distributed SAE doesn't need extra weight matrices.

# PCA whitening (https://arxiv.org/abs/2511.13981). The whitening
# statistics are data-estimated by estimate_whitening() before
# training and stored as buffers so they persist on save/load. Until
# estimated they are the identity, so whitening is a no-op.
d_in = self.cfg.d_in
self.register_buffer(
"whitening_mean",
torch.zeros(d_in, dtype=self.dtype, device=self.device),
)
self.register_buffer(
"whitening_W", torch.eye(d_in, dtype=self.dtype, device=self.device)
)
self.register_buffer(
"whitening_W_inv", torch.eye(d_in, dtype=self.dtype, device=self.device)
)

def run_time_activation_whiten_in(x: torch.Tensor) -> torch.Tensor:
# z = W (x - mu), applied to row vectors in the last dim.
return (x - self.whitening_mean) @ self.whitening_W.T

def run_time_activation_whiten_out(x: torch.Tensor) -> torch.Tensor:
# Dewhiten before the reconstruction loss: x = W^-1 z + mu.
return x @ self.whitening_W_inv.T + self.whitening_mean

self.run_time_activation_norm_fn_in = run_time_activation_whiten_in
self.run_time_activation_norm_fn_out = run_time_activation_whiten_out

else:
self.run_time_activation_norm_fn_in = lambda x: x
self.run_time_activation_norm_fn_out = lambda x: x

# NOTE(review): a collaborator commented on this PR that this estimation would
# make more sense in ActivationScaler, since whitening is a pre-processing of
# the activations rather than something the base SAE class should deal with.
@torch.no_grad()
def estimate_whitening(
    self,
    data_provider: Iterator[torch.Tensor],
    n_batches: int = 100,
    eps: float = 1e-3,
) -> None:
    """Estimate PCA whitening statistics from activations.

    Accumulates the mean and covariance of activations over ``n_batches``,
    then sets the whitening buffers via an eigendecomposition of the
    covariance, following https://arxiv.org/abs/2511.13981.

    The whitening matrix is W = D^(-1/2) E^T and its inverse W^-1 = E D^(1/2),
    where the covariance is C = E D E^T and each eigenvalue is regularized as
    max(lambda_i, 0) + eps for numerical stability.

    Args:
        data_provider: Iterator yielding activation batches whose last
            dimension (after ``reshape_fn_in``) is d_in. Any leading
            dimensions are flattened and treated as independent samples.
        n_batches: Number of batches to accumulate statistics over.
        eps: Stabilizing constant added to covariance eigenvalues.

    Raises:
        ValueError: If the SAE is not configured with
            ``normalize_activations='covariance_whitening'``, or if fewer
            than two activation vectors were accumulated (the unbiased
            covariance is undefined in that case).
        StopIteration: Propagated from ``next(data_provider)`` if the
            provider is exhausted before ``n_batches`` batches are drawn.
    """
    if self.cfg.normalize_activations != "covariance_whitening":
        raise ValueError(
            "estimate_whitening requires normalize_activations='covariance_whitening', "
            f"got {self.cfg.normalize_activations!r}"
        )

    d_in = self.cfg.d_in
    # Accumulate in float64 for numerical stability of the covariance.
    count = 0
    sum_x = torch.zeros(d_in, dtype=torch.float64, device=self.device)
    sum_outer = torch.zeros(d_in, d_in, dtype=torch.float64, device=self.device)
    for _ in tqdm(range(n_batches), desc="Estimating whitening", leave=False):
        acts = self.reshape_fn_in(next(data_provider)).to(
            self.device, torch.float64
        )
        # Collapse any leading dims so every row is one activation vector;
        # a 2-D (batch_size, d_in) batch is left unchanged by this.
        acts = acts.reshape(-1, acts.shape[-1])
        count += acts.shape[0]
        sum_x += acts.sum(dim=0)
        sum_outer += acts.T @ acts

    # With 0 or 1 samples the divisions below would silently write NaN/inf
    # into the persistent buffers, so fail loudly instead.
    if count < 2:
        raise ValueError(
            "estimate_whitening needs at least 2 activation vectors to estimate "
            f"a covariance, got {count} (n_batches={n_batches})"
        )

    mean = sum_x / count
    # Unbiased covariance: (sum_outer - count * mean mean^T) / (count - 1).
    cov = (sum_outer - count * torch.outer(mean, mean)) / (count - 1)
    eigenvalues, eigenvectors = torch.linalg.eigh(cov)
    # A covariance is PSD in exact arithmetic, but round-off can push the
    # smallest eigenvalues slightly below zero; clamp so a small eps cannot
    # produce NaNs in the square roots.
    regularized = eigenvalues.clamp(min=0.0) + eps
    inv_sqrt = regularized.rsqrt()
    sqrt = regularized.sqrt()
    # W = D^(-1/2) E^T ; W^-1 = E D^(1/2).
    whitening_w = (eigenvectors * inv_sqrt).T
    whitening_w_inv = eigenvectors * sqrt

    self.whitening_mean.copy_(mean.to(self.dtype))
    self.whitening_W.copy_(whitening_w.to(self.dtype))
    self.whitening_W_inv.copy_(whitening_w_inv.to(self.dtype))

def initialize_weights(self):
"""Initialize model weights."""
self.b_dec = nn.Parameter(
Expand Down
5 changes: 5 additions & 0 deletions sae_lens/training/sae_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,11 @@ def fit(self) -> T_TRAINING_SAE:
data_provider=self.data_provider,
n_batches_for_norm_estimate=self.cfg.n_batches_for_norm_estimate,
)
elif self.sae.cfg.normalize_activations == "covariance_whitening":
self.sae.estimate_whitening(
data_provider=self.data_provider,
n_batches=self.cfg.n_batches_for_norm_estimate,
)

# Train loop
while self.n_training_samples < self.cfg.total_training_samples:
Expand Down
82 changes: 82 additions & 0 deletions tests/saes/test_sae.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import copy
import pickle
import tracemalloc
from collections.abc import Iterator
from pathlib import Path
from typing import Any
from unittest.mock import patch
Expand All @@ -18,9 +19,11 @@
TrainingSAE,
TrainingSAEConfig,
)
from sae_lens.saes.standard_sae import StandardSAE
from tests.helpers import (
ALL_TRAINING_ARCHITECTURES,
assert_close,
build_sae_cfg,
build_sae_training_cfg_for_arch,
random_params,
)
Expand Down Expand Up @@ -497,3 +500,82 @@ def mock_loader(
assert_close(loaded_sae.W_dec, original_W_dec / scaling_factor)
assert_close(loaded_sae.b_dec, original_b_dec / scaling_factor)
assert loaded_sae.cfg.normalize_activations == "none"


def _correlated_activations(d_in: int, batch_size: int) -> Iterator[torch.Tensor]:
# Fixed (but unseeded) correlation structure shared across all batches, so
# every batch is drawn from the same anisotropic, shifted distribution.
transform = torch.randn(d_in, d_in)
shift = torch.randn(d_in) * 3.0
while True:
yield torch.randn(batch_size, d_in) @ transform.T + shift


def test_covariance_whitening_produces_identity_covariance():
    """Whitened activations should have ~zero mean and ~identity covariance."""
    n_dims = 8
    cfg = build_sae_cfg(
        d_in=n_dims, d_sae=32, normalize_activations="covariance_whitening"
    )
    sae = StandardSAE(cfg)
    batches = _correlated_activations(n_dims, batch_size=4096)
    sae.estimate_whitening(batches, n_batches=30, eps=1e-8)

    # Evaluate on fresh batches drawn from the same generator (and therefore
    # the same distribution) that the statistics were estimated on.
    held_out = []
    for _ in range(30):
        held_out.append(next(batches))
    whitened = sae.run_time_activation_norm_fn_in(torch.cat(held_out, dim=0))

    empirical_cov = torch.cov(whitened.T)
    empirical_mean = whitened.mean(dim=0)
    assert_close(empirical_cov, torch.eye(n_dims), atol=0.05, rtol=0)
    assert_close(empirical_mean, torch.zeros(n_dims), atol=0.05, rtol=0)


def test_covariance_whitening_round_trip_recovers_input():
    """Whitening followed by de-whitening should reproduce the input."""
    n_dims = 8
    cfg = build_sae_cfg(d_in=n_dims, normalize_activations="covariance_whitening")
    sae = StandardSAE(cfg)
    batches = _correlated_activations(n_dims, batch_size=1024)
    sae.estimate_whitening(batches, n_batches=10)

    original = next(batches)
    whitened = sae.run_time_activation_norm_fn_in(original)
    restored = sae.run_time_activation_norm_fn_out(whitened)
    assert_close(restored, original, atol=1e-4, rtol=1e-4)


def test_covariance_whitening_is_noop_before_estimation():
    """Before estimate_whitening runs, both whitening hooks are identity maps.

    The whitening buffers are initialised to a zero mean and identity
    matrices, so neither the input (whiten) nor the output (de-whiten)
    transform may change the activations until statistics are estimated.
    """
    d_in = 8
    sae = StandardSAE(
        build_sae_cfg(d_in=d_in, normalize_activations="covariance_whitening")
    )
    x = torch.randn(16, d_in)
    # Buffers start as identity / zero, so whitening must be the identity map.
    assert_close(sae.run_time_activation_norm_fn_in(x), x)
    # The de-whitening path uses the same initial buffers (identity inverse,
    # zero mean), so it must be a no-op as well.
    assert_close(sae.run_time_activation_norm_fn_out(x), x)


def test_covariance_whitening_buffers_persist_through_save_load(tmp_path: Path):
    """Estimated whitening buffers should survive a save/load round trip."""
    n_dims = 8
    cfg = build_sae_cfg(d_in=n_dims, normalize_activations="covariance_whitening")
    sae = StandardSAE(cfg)
    random_params(sae)
    batches = _correlated_activations(n_dims, batch_size=1024)
    sae.estimate_whitening(batches, n_batches=10)

    sae.save_model(tmp_path)
    loaded = StandardSAE.load_from_disk(tmp_path)

    # Each whitening buffer must come back with the values it was saved with.
    for buffer_name in ("whitening_mean", "whitening_W", "whitening_W_inv"):
        assert_close(getattr(loaded, buffer_name), getattr(sae, buffer_name))

    # The reloaded SAE must also whiten a fresh batch identically.
    probe = next(batches)
    whitened_by_loaded = loaded.run_time_activation_norm_fn_in(probe)
    whitened_by_original = sae.run_time_activation_norm_fn_in(probe)
    assert_close(whitened_by_loaded, whitened_by_original)


def test_estimate_whitening_requires_covariance_whitening_config():
    """estimate_whitening must reject SAEs not configured for whitening."""
    n_dims = 8
    cfg = build_sae_cfg(d_in=n_dims, normalize_activations="none")
    sae = StandardSAE(cfg)
    batches = _correlated_activations(n_dims, batch_size=16)
    with pytest.raises(ValueError, match="covariance_whitening"):
        sae.estimate_whitening(batches, n_batches=1)
Loading