-
Notifications
You must be signed in to change notification settings - Fork 238
Add covariance whitening normalization option (#647) #697
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
azrabano23
wants to merge
1
commit into
decoderesearch:main
Choose a base branch
from
azrabano23:covariance-whitening-647
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -4,13 +4,13 @@ | |
| import json | ||
| import warnings | ||
| from abc import ABC, abstractmethod | ||
| from collections.abc import Callable, Iterator | ||
| from contextlib import contextmanager | ||
| from dataclasses import asdict, dataclass, field, fields, replace | ||
| from pathlib import Path | ||
| from typing import ( | ||
| TYPE_CHECKING, | ||
| Any, | ||
| Callable, | ||
| Generic, | ||
| Literal, | ||
| NamedTuple, | ||
|
|
@@ -22,6 +22,7 @@ | |
| from numpy.typing import NDArray | ||
| from safetensors.torch import load_file, save_file | ||
| from torch import nn | ||
| from tqdm.auto import tqdm | ||
| from transformer_lens.hook_points import HookedRootModule, HookPoint | ||
| from typing_extensions import deprecated, overload, override | ||
|
|
||
|
|
@@ -158,6 +159,7 @@ class SAEConfig(ABC): | |
| "none", | ||
| "expected_average_only_in", # (Anthropic April 2024 Update) | ||
| "layer_norm", | ||
| "covariance_whitening", # PCA whitening (https://arxiv.org/abs/2511.13981) | ||
| ] = "none" | ||
| reshape_activations: Literal["none", "hook_z"] = "none" | ||
| metadata: SAEMetadata = field(default_factory=SAEMetadata) | ||
|
|
@@ -191,9 +193,10 @@ def __post_init__(self): | |
| "expected_average_only_in", | ||
| "constant_norm_rescale", | ||
| "layer_norm", | ||
| "covariance_whitening", | ||
| ]: | ||
| raise ValueError( | ||
| f"normalize_activations must be none, expected_average_only_in, layer_norm, or constant_norm_rescale. Got {self.normalize_activations}" | ||
| f"normalize_activations must be none, expected_average_only_in, layer_norm, constant_norm_rescale, or covariance_whitening. Got {self.normalize_activations}" | ||
| ) | ||
|
|
||
|
|
||
|
|
@@ -242,6 +245,10 @@ class SAE(HookedRootModule, Generic[T_SAE_CONFIG], ABC): | |
| W_enc: nn.Parameter | ||
| W_dec: nn.Parameter | ||
| b_dec: nn.Parameter | ||
| # Registered as buffers only when normalize_activations == "covariance_whitening". | ||
| whitening_mean: torch.Tensor | ||
| whitening_W: torch.Tensor | ||
| whitening_W_inv: torch.Tensor | ||
|
|
||
| def __init__(self, cfg: T_SAE_CONFIG, use_error_term: bool = False): | ||
| """Initialize the SAE.""" | ||
|
|
@@ -353,10 +360,95 @@ def run_time_activation_ln_out( | |
|
|
||
| self.run_time_activation_norm_fn_in = run_time_activation_ln_in | ||
| self.run_time_activation_norm_fn_out = run_time_activation_ln_out | ||
|
|
||
| elif self.cfg.normalize_activations == "covariance_whitening": | ||
| # PCA whitening (https://arxiv.org/abs/2511.13981). The whitening | ||
| # statistics are data-estimated by estimate_whitening() before | ||
| # training and stored as buffers so they persist on save/load. Until | ||
| # estimated they are the identity, so whitening is a no-op. | ||
| d_in = self.cfg.d_in | ||
| self.register_buffer( | ||
| "whitening_mean", | ||
| torch.zeros(d_in, dtype=self.dtype, device=self.device), | ||
| ) | ||
| self.register_buffer( | ||
| "whitening_W", torch.eye(d_in, dtype=self.dtype, device=self.device) | ||
| ) | ||
| self.register_buffer( | ||
| "whitening_W_inv", torch.eye(d_in, dtype=self.dtype, device=self.device) | ||
| ) | ||
|
|
||
| def run_time_activation_whiten_in(x: torch.Tensor) -> torch.Tensor: | ||
| # z = W (x - mu), applied to row vectors in the last dim. | ||
| return (x - self.whitening_mean) @ self.whitening_W.T | ||
|
|
||
| def run_time_activation_whiten_out(x: torch.Tensor) -> torch.Tensor: | ||
| # Dewhiten before the reconstruction loss: x = W^-1 z + mu. | ||
| return x @ self.whitening_W_inv.T + self.whitening_mean | ||
|
|
||
| self.run_time_activation_norm_fn_in = run_time_activation_whiten_in | ||
| self.run_time_activation_norm_fn_out = run_time_activation_whiten_out | ||
|
|
||
| else: | ||
| self.run_time_activation_norm_fn_in = lambda x: x | ||
| self.run_time_activation_norm_fn_out = lambda x: x | ||
|
|
||
| @torch.no_grad() | ||
| def estimate_whitening( | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This makes more sense to go in ActivationScaler, since this whitening is more of a pre-processing of the activations rather than something the base SAE class should need to deal with. |
||
| self, | ||
| data_provider: Iterator[torch.Tensor], | ||
| n_batches: int = 100, | ||
| eps: float = 1e-3, | ||
| ) -> None: | ||
| """Estimate PCA whitening statistics from activations. | ||
|
|
||
| Accumulates the mean and covariance of activations over ``n_batches``, | ||
| then sets the whitening buffers via an eigendecomposition of the | ||
| covariance, following https://arxiv.org/abs/2511.13981. | ||
|
|
||
| The whitening matrix is W = D^(-1/2) E^T and its inverse W^-1 = E D^(1/2), | ||
| where the covariance is C = E D E^T and each eigenvalue is regularized as | ||
| lambda_i + eps for numerical stability. | ||
|
|
||
| Args: | ||
| data_provider: Iterator yielding activation batches of shape | ||
| (batch_size, d_in). | ||
| n_batches: Number of batches to accumulate statistics over. | ||
| eps: Stabilizing constant added to covariance eigenvalues. | ||
| """ | ||
| if self.cfg.normalize_activations != "covariance_whitening": | ||
| raise ValueError( | ||
| "estimate_whitening requires normalize_activations='covariance_whitening', " | ||
| f"got {self.cfg.normalize_activations!r}" | ||
| ) | ||
|
|
||
| d_in = self.cfg.d_in | ||
| # Accumulate in float64 for numerical stability of the covariance. | ||
| count = 0 | ||
| sum_x = torch.zeros(d_in, dtype=torch.float64, device=self.device) | ||
| sum_outer = torch.zeros(d_in, d_in, dtype=torch.float64, device=self.device) | ||
| for _ in tqdm(range(n_batches), desc="Estimating whitening", leave=False): | ||
| acts = self.reshape_fn_in(next(data_provider)).to( | ||
| self.device, torch.float64 | ||
| ) | ||
| count += acts.shape[0] | ||
| sum_x += acts.sum(dim=0) | ||
| sum_outer += acts.T @ acts | ||
|
|
||
| mean = sum_x / count | ||
| # Unbiased covariance: (sum_outer - count * mean mean^T) / (count - 1). | ||
| cov = (sum_outer - count * torch.outer(mean, mean)) / (count - 1) | ||
| eigenvalues, eigenvectors = torch.linalg.eigh(cov) | ||
| inv_sqrt = (eigenvalues + eps).rsqrt() | ||
| sqrt = (eigenvalues + eps).sqrt() | ||
| # W = D^(-1/2) E^T ; W^-1 = E D^(1/2). | ||
| whitening_w = (eigenvectors * inv_sqrt).T | ||
| whitening_w_inv = eigenvectors * sqrt | ||
|
|
||
| self.whitening_mean.copy_(mean.to(self.dtype)) | ||
| self.whitening_W.copy_(whitening_w.to(self.dtype)) | ||
| self.whitening_W_inv.copy_(whitening_w_inv.to(self.dtype)) | ||
|
|
||
| def initialize_weights(self): | ||
| """Initialize model weights.""" | ||
| self.b_dec = nn.Parameter( | ||
|
|
||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we fold the whitening into the SAE weights at the end of training, like we do with
`expected_average_only_in` scaling? This is just a linear operation, so it should be possible to calculate adjusted versions of the weights that take this into account, I think. Then we don't need to handle any special processing during inference, only during training, and the distributed SAE doesn't need extra weight matrices.