Skip to content

DRY fold decoder norms logic#705

Open
danra wants to merge 7 commits into
decoderesearch:mainfrom
danra:dry_fold_decoder_norms
Open

DRY fold decoder norms logic#705
danra wants to merge 7 commits into
decoderesearch:mainfrom
danra:dry_fold_decoder_norms

Conversation

@danra

@danra danra commented Jun 12, 2026

Copy link
Copy Markdown
Contributor

Description

DRY fold decoder norms logic, with a few drive-by massages to make it testable and a pure refactor

Type of change

Almost pure refactor. The only exception is JumpReLU transcoder now normalizes its threshold by the exact same scalars (same decoder norms) as other weights. Previously the scalars were only extremely close.

Checklist:

  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing tests pass locally with my changes
  • I have not rewritten tests relating to key interfaces which would affect backward compatibility

You have tested formatting, typing and tests

  • I have run make check-ci to check format and linting. (you can run make format to format code if needed.)

danra and others added 5 commits June 11, 2026 21:52
With d_sae == 1, W_dec_norms is shape (1, 1) and .squeeze() collapses it
to a 0-dim scalar

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Note this also removes an explicit comment in jumprelu about intentionally using
squeeze() rather than squeeze(-1) to maintain past behavior; but other than the
internally different behavior in the edge-case of d_sae=1 -- extra squeeze and
broadcasting, which is now tested and is still valid -- there is no difference
in behavior.
Copilot AI review requested due to automatic review settings June 12, 2026 21:55

Copilot AI left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Note

Copilot was unable to run its full agentic suite in this review.

This PR updates decoder-norm folding to return the per-row decoder norms (for reuse by subclasses needing additional scaling), fixes bias scaling edge cases (notably d_sae=1), and reorganizes related tests.

Changes:

  • Change fold_W_dec_norm to return W_dec norms and update multiple SAE subclasses accordingly.
  • Fix b_enc scaling to avoid squeeze() collapsing shapes when d_sae == 1.
  • Move/extend folding tests into tests/saes/test_sae.py, including architecture-wide coverage and a zero-norm regression test.

Reviewed changes

Copilot reviewed 9 out of 9 changed files in this pull request and generated 3 comments.

Show a summary per file
File Description
tests/saes/test_standard_sae.py Removes architecture-wide folding tests (moved elsewhere) and tightens a squeeze(-1) expectation.
tests/saes/test_sae.py Adds cross-architecture folding tests, d_sae=1 shape regression, and zero-norm NaN/Inf regression tests.
sae_lens/saes/sae.py Changes fold_W_dec_norm to return norms and fixes bias scaling for d_sae=1.
sae_lens/saes/jumprelu_sae.py Uses returned norms from base folding and returns norms to callers.
sae_lens/saes/gated_sae.py Uses returned norms from base folding and scales gated biases without shape-collapsing.
sae_lens/saes/topk_sae.py Updates folding override to return norms and delegate to base implementation.
sae_lens/saes/transcoder.py Uses returned norms to scale JumpReLU transcoder thresholds and returns norms.
sae_lens/saes/temporal_sae.py Updates signature to match the new return type (still unsupported).
sae_lens/saes/matching_pursuit_sae.py Updates signature to match the new return type (still unsupported).

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment thread sae_lens/saes/sae.py Outdated
Comment on lines +508 to +518
def fold_W_dec_norm(self) -> torch.Tensor:
"""Fold decoder norms into encoder."""
W_dec_norms = self.W_dec.norm(dim=-1).clamp(min=1e-8).unsqueeze(1)
self.W_dec.data = self.W_dec.data / W_dec_norms
self.W_enc.data = self.W_enc.data * W_dec_norms.T
W_dec_norms = self.W_dec.norm(dim=-1).clamp(min=1e-8)
self.W_dec.data = self.W_dec.data / W_dec_norms.unsqueeze(1)
self.W_enc.data = self.W_enc.data * W_dec_norms.unsqueeze(1).T

# Only update b_enc if it exists (standard/jumprelu architectures)
if hasattr(self, "b_enc") and isinstance(self.b_enc, nn.Parameter):
self.b_enc.data = self.b_enc.data * W_dec_norms.squeeze()
self.b_enc.data = self.b_enc.data * W_dec_norms

return W_dec_norms
Comment thread tests/saes/test_sae.py
Comment on lines +264 to +270
def test_sae_fold_w_dec_norm_all_architectures(architecture: str):
cfg = build_sae_cfg_for_arch(architecture)
sae = SAE.from_dict(cfg.to_dict())

# make sure all parameters are not 0s
for param in sae.parameters():
param.data = torch.rand_like(param)
Comment thread tests/saes/test_sae.py
Comment on lines 1 to +4
import copy
import pickle
import tracemalloc
from copy import deepcopy
@danra danra force-pushed the dry_fold_decoder_norms branch from b00e6e0 to 10ddf8a Compare June 13, 2026 04:48
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants