DRY fold decoder norms logic#705
Open
danra wants to merge 7 commits into
Open
Conversation
With d_sae == 1, W_dec_norms is shape (1, 1) and .squeeze() collapses it to a 0-dim scalar Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Note this also removes an explicit comment in jumprelu about intentionally using squeeze() rather than squeeze(-1) to maintain past behavior; but other than the internally different behavior in the edge-case of d_sae=1 -- extra squeeze and broadcasting, which is now tested and is still valid -- there is no difference in behavior.
… as other weights
Contributor
There was a problem hiding this comment.
Pull request overview
Note
Copilot was unable to run its full agentic suite in this review.
This PR updates decoder-norm folding to return the per-row decoder norms (for reuse by subclasses needing additional scaling), fixes bias scaling edge cases (notably d_sae=1), and reorganizes related tests.
Changes:
- Change
fold_W_dec_normto returnW_decnorms and update multiple SAE subclasses accordingly. - Fix
b_encscaling to avoidsqueeze()collapsing shapes whend_sae == 1. - Move/extend folding tests into
tests/saes/test_sae.py, including architecture-wide coverage and a zero-norm regression test.
Reviewed changes
Copilot reviewed 9 out of 9 changed files in this pull request and generated 3 comments.
Show a summary per file
| File | Description |
|---|---|
| tests/saes/test_standard_sae.py | Removes architecture-wide folding tests (moved elsewhere) and tightens a squeeze(-1) expectation. |
| tests/saes/test_sae.py | Adds cross-architecture folding tests, d_sae=1 shape regression, and zero-norm NaN/Inf regression tests. |
| sae_lens/saes/sae.py | Changes fold_W_dec_norm to return norms and fixes bias scaling for d_sae=1. |
| sae_lens/saes/jumprelu_sae.py | Uses returned norms from base folding and returns norms to callers. |
| sae_lens/saes/gated_sae.py | Uses returned norms from base folding and scales gated biases without shape-collapsing. |
| sae_lens/saes/topk_sae.py | Updates folding override to return norms and delegate to base implementation. |
| sae_lens/saes/transcoder.py | Uses returned norms to scale JumpReLU transcoder thresholds and returns norms. |
| sae_lens/saes/temporal_sae.py | Updates signature to match the new return type (still unsupported). |
| sae_lens/saes/matching_pursuit_sae.py | Updates signature to match the new return type (still unsupported). |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
Comment on lines
+508
to
+518
| def fold_W_dec_norm(self) -> torch.Tensor: | ||
| """Fold decoder norms into encoder.""" | ||
| W_dec_norms = self.W_dec.norm(dim=-1).clamp(min=1e-8).unsqueeze(1) | ||
| self.W_dec.data = self.W_dec.data / W_dec_norms | ||
| self.W_enc.data = self.W_enc.data * W_dec_norms.T | ||
| W_dec_norms = self.W_dec.norm(dim=-1).clamp(min=1e-8) | ||
| self.W_dec.data = self.W_dec.data / W_dec_norms.unsqueeze(1) | ||
| self.W_enc.data = self.W_enc.data * W_dec_norms.unsqueeze(1).T | ||
|
|
||
| # Only update b_enc if it exists (standard/jumprelu architectures) | ||
| if hasattr(self, "b_enc") and isinstance(self.b_enc, nn.Parameter): | ||
| self.b_enc.data = self.b_enc.data * W_dec_norms.squeeze() | ||
| self.b_enc.data = self.b_enc.data * W_dec_norms | ||
|
|
||
| return W_dec_norms |
Comment on lines
+264
to
+270
| def test_sae_fold_w_dec_norm_all_architectures(architecture: str): | ||
| cfg = build_sae_cfg_for_arch(architecture) | ||
| sae = SAE.from_dict(cfg.to_dict()) | ||
|
|
||
| # make sure all parameters are not 0s | ||
| for param in sae.parameters(): | ||
| param.data = torch.rand_like(param) |
Comment on lines
1
to
+4
| import copy | ||
| import pickle | ||
| import tracemalloc | ||
| from copy import deepcopy |
b00e6e0 to
10ddf8a
Compare
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Description
DRY fold decoder norms logic, with a few drive-by massages to make it testable and a pure refactor
Type of change
Almost pure refactor. The only exception is JumpReLU transcoder now normalizes its threshold by the exact same scalars (same decoder norms) as other weights. Previously the scalars were only extremely close.
Checklist:
You have tested formatting, typing and tests
make check-cito check format and linting. (you can runmake formatto format code if needed.)