From ed03845ca13066978d6938f8d7a39cddc351e9e2 Mon Sep 17 00:00:00 2001 From: Dan Raviv Date: Thu, 11 Jun 2026 21:52:07 -0700 Subject: [PATCH 1/7] chore: Move cross-arch SAE tests to correct file --- tests/saes/test_sae.py | 242 +++++++++++++++++++++++++++++++ tests/saes/test_standard_sae.py | 244 +------------------------------- 2 files changed, 243 insertions(+), 243 deletions(-) diff --git a/tests/saes/test_sae.py b/tests/saes/test_sae.py index 20517f2c6..808b07d32 100644 --- a/tests/saes/test_sae.py +++ b/tests/saes/test_sae.py @@ -1,6 +1,7 @@ import copy import pickle import tracemalloc +from copy import deepcopy from pathlib import Path from typing import Any from unittest.mock import patch @@ -19,8 +20,11 @@ TrainingSAEConfig, ) from tests.helpers import ( + ALL_ARCHITECTURES, + ALL_FOLDABLE_ARCHITECTURES, ALL_TRAINING_ARCHITECTURES, assert_close, + build_sae_cfg_for_arch, build_sae_training_cfg_for_arch, random_params, ) @@ -255,6 +259,244 @@ def test_TrainingSAE_fold_activation_norm_scaling_factor_all_architectures( assert_close(folded_features, original_features) +@pytest.mark.parametrize("architecture", ALL_ARCHITECTURES) +@torch.no_grad() +def test_sae_fold_w_dec_norm_all_architectures(architecture: str): + cfg = build_sae_cfg_for_arch(architecture) + sae = SAE.from_dict(cfg.to_dict()) + sae.turn_off_forward_pass_hook_z_reshaping() # hook z reshaping not needed here. + + # make sure all parameters are not 0s + for param in sae.parameters(): + param.data = torch.rand_like(param) + + assert sae.W_dec.norm(dim=-1).mean().item() != pytest.approx(1.0, abs=1e-6) + sae2 = deepcopy(sae) + + # If this is a topk SAE, assert this throws a NotImplementedError + if architecture not in ALL_FOLDABLE_ARCHITECTURES: + with pytest.raises(NotImplementedError): + sae2.fold_W_dec_norm() + return + + sae2.fold_W_dec_norm() + + # fold_W_dec_norm should normalize W_dec to have unit norm. + assert sae2.W_dec.norm(dim=-1).mean().item() == pytest.approx(1.0, abs=1e-6) + + # we expect activations of features to differ by W_dec norm weights. + activations = torch.randn(10, 4, cfg.d_in, device=cfg.device) + feature_activations_1 = sae.encode(activations) + feature_activations_2 = sae2.encode(activations) + + assert_close( + feature_activations_1.nonzero(), + feature_activations_2.nonzero(), + ) + + expected_feature_activations_2 = feature_activations_1 * sae.W_dec.norm(dim=-1) + assert_close(feature_activations_2, expected_feature_activations_2, atol=1e-5) + + sae_out_1 = sae.decode(feature_activations_1) + sae_out_2 = sae2.decode(feature_activations_2) + + # but actual outputs should be the same + assert_close(sae_out_1, sae_out_2, atol=1e-5) + + +@pytest.mark.parametrize("architecture", ALL_TRAINING_ARCHITECTURES) +@torch.no_grad() +def test_training_sae_fold_w_dec_norm_all_architectures(architecture: str): + cfg = build_sae_training_cfg_for_arch(architecture) + sae = TrainingSAE.from_dict(cfg.to_dict()) + sae.turn_off_forward_pass_hook_z_reshaping() # hook z reshaping not needed here. + + # make sure all parameters are not 0s + for param in sae.parameters(): + param.data = torch.rand_like(param) + + assert sae.W_dec.norm(dim=-1).mean().item() != pytest.approx(1.0, abs=1e-6) + sae2 = deepcopy(sae) + + if architecture in {"matching_pursuit"}: + with pytest.raises(NotImplementedError): + sae2.fold_W_dec_norm() + return + + sae2.fold_W_dec_norm() + + # fold_W_dec_norm should normalize W_dec to have unit norm. + assert sae2.W_dec.norm(dim=-1).mean().item() == pytest.approx(1.0, abs=1e-6) + + # we expect activations of features to differ by W_dec norm weights. + activations = torch.randn(10, 4, cfg.d_in, device=cfg.device) + feature_activations_1 = sae.encode(activations) + feature_activations_2 = sae2.encode(activations) + + assert_close( + feature_activations_1.nonzero(), + feature_activations_2.nonzero(), + ) + + if architecture in {"topk", "batchtopk", "matryoshka_batchtopk"}: + # Due to how rescale_acts_by_decoder_norm works in TopKSAEs, it's like the + # SAE has the norm folded in throughout the entire training process. + assert_close(feature_activations_2, feature_activations_1, atol=1e-4, rtol=1e-4) + else: + expected_feature_activations_2 = feature_activations_1 * sae.W_dec.norm(dim=-1) + assert_close( + feature_activations_2, expected_feature_activations_2, atol=1e-4, rtol=1e-4 + ) + + sae_out_1 = sae.decode(feature_activations_1) + sae_out_2 = sae2.decode(feature_activations_2) + + # but actual outputs should be the same + assert_close(sae_out_1, sae_out_2) + + +@pytest.mark.parametrize("architecture", ALL_FOLDABLE_ARCHITECTURES) +@torch.no_grad() +def test_sae_fold_norm_scaling_factor_all_architectures(architecture: str): + cfg = build_sae_cfg_for_arch(architecture) + norm_scaling_factor = 3.0 + + sae = SAE.from_dict(cfg.to_dict()) + # make sure all parameters are not 0s + for param in sae.parameters(): + param.data = torch.rand_like(param) + + sae2 = deepcopy(sae) + sae2.fold_activation_norm_scaling_factor(norm_scaling_factor) + + assert sae2.cfg.normalize_activations == "none" + + assert_close(sae2.W_enc.data, sae.W_enc.data * norm_scaling_factor) + + # we expect activations of features to differ by W_dec norm weights. + # assume activations are already scaled + activations = torch.randn(10, 4, cfg.d_in, device=cfg.device) + # we divide to get the unscale activations + unscaled_activations = activations / norm_scaling_factor + + feature_activations_1 = sae.encode(activations) + if feature_activations_1.is_sparse: + feature_activations_1 = feature_activations_1.to_dense() + # with the scaling folded in, the unscaled activations should produce the same + # result. + feature_activations_2 = sae2.encode(unscaled_activations) + if feature_activations_2.is_sparse: + feature_activations_2 = feature_activations_2.to_dense() + + assert_close( + feature_activations_1.nonzero(), + feature_activations_2.nonzero(), + ) + + assert_close(feature_activations_2, feature_activations_1, atol=1e-5) + + sae_out_1 = sae.decode(feature_activations_1) + sae_out_2 = norm_scaling_factor * sae2.decode(feature_activations_2) + + # but actual outputs should be the same + assert_close(sae_out_1, sae_out_2, atol=1e-5) + + +@pytest.mark.parametrize("architecture", ALL_FOLDABLE_ARCHITECTURES) +@torch.no_grad() +def test_fold_W_dec_norm_does_not_produce_nan_with_zero_norm_decoder( + architecture: str, +): + """ + Regression test for https://github.com/decoderesearch/SAELens/issues/588 + + When decoder weights have zero norm (dead latents), the division in + fold_W_dec_norm should not produce NaN values. This is handled by + clamping the norm to a minimum of 1e-8. + """ + cfg = build_sae_cfg_for_arch(architecture) + sae = SAE.from_dict(cfg.to_dict()) + sae.turn_off_forward_pass_hook_z_reshaping() + + # Initialize parameters with random values + for param in sae.parameters(): + param.data = torch.rand_like(param) + + # Set some decoder rows to zero to simulate dead latents + num_zero_rows = min(5, sae.W_dec.shape[0]) + sae.W_dec.data[:num_zero_rows] = 0.0 + + # Verify that we actually have zero-norm rows + norms_before = sae.W_dec.norm(dim=-1) + assert (norms_before[:num_zero_rows] == 0).all() + + # TopK SAEs with rescale_acts_by_decoder_norm=False raise NotImplementedError + if architecture == "topk" and not getattr( + sae.cfg, "rescale_acts_by_decoder_norm", False + ): + with pytest.raises(NotImplementedError): + sae.fold_W_dec_norm() + return + + # Call fold_W_dec_norm - this should not produce NaN values + sae.fold_W_dec_norm() + + # Verify no NaN or Inf values in any parameters + for name, param in sae.named_parameters(): + assert not torch.isnan( + param + ).any(), f"NaN found in {name} after fold_W_dec_norm" + assert not torch.isinf( + param + ).any(), f"Inf found in {name} after fold_W_dec_norm" + + +@pytest.mark.parametrize("architecture", ALL_TRAINING_ARCHITECTURES) +@torch.no_grad() +def test_training_fold_W_dec_norm_does_not_produce_nan_with_zero_norm_decoder( + architecture: str, +): + """ + Regression test for https://github.com/decoderesearch/SAELens/issues/588 + + When decoder weights have zero norm (dead latents), the division in + fold_W_dec_norm should not produce NaN values for TrainingSAE classes. + """ + cfg = build_sae_training_cfg_for_arch(architecture) + sae = TrainingSAE.from_dict(cfg.to_dict()) + sae.turn_off_forward_pass_hook_z_reshaping() + + # Initialize parameters with random values + for param in sae.parameters(): + param.data = torch.rand_like(param) + + # Set some decoder rows to zero to simulate dead latents + num_zero_rows = min(5, sae.W_dec.shape[0]) + sae.W_dec.data[:num_zero_rows] = 0.0 + + # Verify that we actually have zero-norm rows + norms_before = sae.W_dec.norm(dim=-1) + assert (norms_before[:num_zero_rows] == 0).all() + + # Call fold_W_dec_norm - this should not produce NaN values + + if architecture in {"matching_pursuit"}: + with pytest.raises(NotImplementedError): + sae.fold_W_dec_norm() + return + + sae.fold_W_dec_norm() + + # Verify no NaN or Inf values in any parameters + for name, param in sae.named_parameters(): + assert not torch.isnan( + param + ).any(), f"NaN found in {name} after fold_W_dec_norm" + assert not torch.isinf( + param + ).any(), f"Inf found in {name} after fold_W_dec_norm" + + @pytest.mark.parametrize("architecture", ALL_TRAINING_ARCHITECTURES) def test_TrainingSAE_save_and_load_from_checkpoint_all_architectures( architecture: str, diff --git a/tests/saes/test_standard_sae.py b/tests/saes/test_standard_sae.py index 384196df1..66c74b205 100644 --- a/tests/saes/test_standard_sae.py +++ b/tests/saes/test_standard_sae.py @@ -11,7 +11,7 @@ from transformer_lens.hook_points import HookPoint from sae_lens.config import LanguageModelSAERunnerConfig -from sae_lens.saes.sae import SAE, TrainingSAE, _disable_hooks +from sae_lens.saes.sae import SAE, _disable_hooks from sae_lens.saes.standard_sae import ( StandardSAE, StandardTrainingSAE, @@ -19,16 +19,12 @@ ) from sae_lens.util import dtype_to_str from tests.helpers import ( - ALL_ARCHITECTURES, - ALL_FOLDABLE_ARCHITECTURES, - ALL_TRAINING_ARCHITECTURES, assert_close, assert_not_close, build_runner_cfg, build_sae_cfg, build_sae_cfg_for_arch, build_sae_training_cfg, - build_sae_training_cfg_for_arch, ) @@ -277,102 +273,6 @@ def test_StandardSAE_fold_w_dec_norm( assert_close(sae_out_1, sae_out_2, atol=1e-5) -@pytest.mark.parametrize("architecture", ALL_ARCHITECTURES) -@torch.no_grad() -def test_sae_fold_w_dec_norm_all_architectures(architecture: str): - cfg = build_sae_cfg_for_arch(architecture) - sae = SAE.from_dict(cfg.to_dict()) - sae.turn_off_forward_pass_hook_z_reshaping() # hook z reshaping not needed here. - - # make sure all parameters are not 0s - for param in sae.parameters(): - param.data = torch.rand_like(param) - - assert sae.W_dec.norm(dim=-1).mean().item() != pytest.approx(1.0, abs=1e-6) - sae2 = deepcopy(sae) - - # If this is a topk SAE, assert this throws a NotImplementedError - if architecture not in ALL_FOLDABLE_ARCHITECTURES: - with pytest.raises(NotImplementedError): - sae2.fold_W_dec_norm() - return - - sae2.fold_W_dec_norm() - - # fold_W_dec_norm should normalize W_dec to have unit norm. - assert sae2.W_dec.norm(dim=-1).mean().item() == pytest.approx(1.0, abs=1e-6) - - # we expect activations of features to differ by W_dec norm weights. - activations = torch.randn(10, 4, cfg.d_in, device=cfg.device) - feature_activations_1 = sae.encode(activations) - feature_activations_2 = sae2.encode(activations) - - assert_close( - feature_activations_1.nonzero(), - feature_activations_2.nonzero(), - ) - - expected_feature_activations_2 = feature_activations_1 * sae.W_dec.norm(dim=-1) - assert_close(feature_activations_2, expected_feature_activations_2, atol=1e-5) - - sae_out_1 = sae.decode(feature_activations_1) - sae_out_2 = sae2.decode(feature_activations_2) - - # but actual outputs should be the same - assert_close(sae_out_1, sae_out_2, atol=1e-5) - - -@pytest.mark.parametrize("architecture", ALL_TRAINING_ARCHITECTURES) -@torch.no_grad() -def test_training_sae_fold_w_dec_norm_all_architectures(architecture: str): - cfg = build_sae_training_cfg_for_arch(architecture) - sae = TrainingSAE.from_dict(cfg.to_dict()) - sae.turn_off_forward_pass_hook_z_reshaping() # hook z reshaping not needed here. - - # make sure all parameters are not 0s - for param in sae.parameters(): - param.data = torch.rand_like(param) - - assert sae.W_dec.norm(dim=-1).mean().item() != pytest.approx(1.0, abs=1e-6) - sae2 = deepcopy(sae) - - if architecture in {"matching_pursuit"}: - with pytest.raises(NotImplementedError): - sae2.fold_W_dec_norm() - return - - sae2.fold_W_dec_norm() - - # fold_W_dec_norm should normalize W_dec to have unit norm. - assert sae2.W_dec.norm(dim=-1).mean().item() == pytest.approx(1.0, abs=1e-6) - - # we expect activations of features to differ by W_dec norm weights. - activations = torch.randn(10, 4, cfg.d_in, device=cfg.device) - feature_activations_1 = sae.encode(activations) - feature_activations_2 = sae2.encode(activations) - - assert_close( - feature_activations_1.nonzero(), - feature_activations_2.nonzero(), - ) - - if architecture in {"topk", "batchtopk", "matryoshka_batchtopk"}: - # Due to how rescale_acts_by_decoder_norm works in TopKSAEs, it's like the - # SAE has the norm folded in throughout the entire training process. - assert_close(feature_activations_2, feature_activations_1, atol=1e-4, rtol=1e-4) - else: - expected_feature_activations_2 = feature_activations_1 * sae.W_dec.norm(dim=-1) - assert_close( - feature_activations_2, expected_feature_activations_2, atol=1e-4, rtol=1e-4 - ) - - sae_out_1 = sae.decode(feature_activations_1) - sae_out_2 = sae2.decode(feature_activations_2) - - # but actual outputs should be the same - assert_close(sae_out_1, sae_out_2) - - @torch.no_grad() def test_StandardSAE_fold_norm_scaling_factor( cfg: LanguageModelSAERunnerConfig[StandardTrainingSAEConfig], @@ -417,148 +317,6 @@ def test_StandardSAE_fold_norm_scaling_factor( assert_close(sae_out_1, sae_out_2, atol=1e-5) -@pytest.mark.parametrize("architecture", ALL_FOLDABLE_ARCHITECTURES) -@torch.no_grad() -def test_sae_fold_norm_scaling_factor_all_architectures(architecture: str): - cfg = build_sae_cfg_for_arch(architecture) - norm_scaling_factor = 3.0 - - sae = SAE.from_dict(cfg.to_dict()) - # make sure all parameters are not 0s - for param in sae.parameters(): - param.data = torch.rand_like(param) - - sae2 = deepcopy(sae) - sae2.fold_activation_norm_scaling_factor(norm_scaling_factor) - - assert sae2.cfg.normalize_activations == "none" - - assert_close(sae2.W_enc.data, sae.W_enc.data * norm_scaling_factor) - - # we expect activations of features to differ by W_dec norm weights. - # assume activations are already scaled - activations = torch.randn(10, 4, cfg.d_in, device=cfg.device) - # we divide to get the unscale activations - unscaled_activations = activations / norm_scaling_factor - - feature_activations_1 = sae.encode(activations) - if feature_activations_1.is_sparse: - feature_activations_1 = feature_activations_1.to_dense() - # with the scaling folded in, the unscaled activations should produce the same - # result. - feature_activations_2 = sae2.encode(unscaled_activations) - if feature_activations_2.is_sparse: - feature_activations_2 = feature_activations_2.to_dense() - - assert_close( - feature_activations_1.nonzero(), - feature_activations_2.nonzero(), - ) - - assert_close(feature_activations_2, feature_activations_1, atol=1e-5) - - sae_out_1 = sae.decode(feature_activations_1) - sae_out_2 = norm_scaling_factor * sae2.decode(feature_activations_2) - - # but actual outputs should be the same - assert_close(sae_out_1, sae_out_2, atol=1e-5) - - -@pytest.mark.parametrize("architecture", ALL_FOLDABLE_ARCHITECTURES) -@torch.no_grad() -def test_fold_W_dec_norm_does_not_produce_nan_with_zero_norm_decoder( - architecture: str, -): - """ - Regression test for https://github.com/decoderesearch/SAELens/issues/588 - - When decoder weights have zero norm (dead latents), the division in - fold_W_dec_norm should not produce NaN values. This is handled by - clamping the norm to a minimum of 1e-8. - """ - cfg = build_sae_cfg_for_arch(architecture) - sae = SAE.from_dict(cfg.to_dict()) - sae.turn_off_forward_pass_hook_z_reshaping() - - # Initialize parameters with random values - for param in sae.parameters(): - param.data = torch.rand_like(param) - - # Set some decoder rows to zero to simulate dead latents - num_zero_rows = min(5, sae.W_dec.shape[0]) - sae.W_dec.data[:num_zero_rows] = 0.0 - - # Verify that we actually have zero-norm rows - norms_before = sae.W_dec.norm(dim=-1) - assert (norms_before[:num_zero_rows] == 0).all() - - # TopK SAEs with rescale_acts_by_decoder_norm=False raise NotImplementedError - if architecture == "topk" and not getattr( - sae.cfg, "rescale_acts_by_decoder_norm", False - ): - with pytest.raises(NotImplementedError): - sae.fold_W_dec_norm() - return - - # Call fold_W_dec_norm - this should not produce NaN values - sae.fold_W_dec_norm() - - # Verify no NaN or Inf values in any parameters - for name, param in sae.named_parameters(): - assert not torch.isnan( - param - ).any(), f"NaN found in {name} after fold_W_dec_norm" - assert not torch.isinf( - param - ).any(), f"Inf found in {name} after fold_W_dec_norm" - - -@pytest.mark.parametrize("architecture", ALL_TRAINING_ARCHITECTURES) -@torch.no_grad() -def test_training_fold_W_dec_norm_does_not_produce_nan_with_zero_norm_decoder( - architecture: str, -): - """ - Regression test for https://github.com/decoderesearch/SAELens/issues/588 - - When decoder weights have zero norm (dead latents), the division in - fold_W_dec_norm should not produce NaN values for TrainingSAE classes. - """ - cfg = build_sae_training_cfg_for_arch(architecture) - sae = TrainingSAE.from_dict(cfg.to_dict()) - sae.turn_off_forward_pass_hook_z_reshaping() - - # Initialize parameters with random values - for param in sae.parameters(): - param.data = torch.rand_like(param) - - # Set some decoder rows to zero to simulate dead latents - num_zero_rows = min(5, sae.W_dec.shape[0]) - sae.W_dec.data[:num_zero_rows] = 0.0 - - # Verify that we actually have zero-norm rows - norms_before = sae.W_dec.norm(dim=-1) - assert (norms_before[:num_zero_rows] == 0).all() - - # Call fold_W_dec_norm - this should not produce NaN values - - if architecture in {"matching_pursuit"}: - with pytest.raises(NotImplementedError): - sae.fold_W_dec_norm() - return - - sae.fold_W_dec_norm() - - # Verify no NaN or Inf values in any parameters - for name, param in sae.named_parameters(): - assert not torch.isnan( - param - ).any(), f"NaN found in {name} after fold_W_dec_norm" - assert not torch.isinf( - param - ).any(), f"Inf found in {name} after fold_W_dec_norm" - - def test_StandardSAE_save_and_load_from_pretrained(tmp_path: Path) -> None: cfg = build_sae_cfg() model_path = str(tmp_path) From ba4d27842b8591a88352e7d401f82584b037b492 Mon Sep 17 00:00:00 2001 From: Dan Raviv Date: Thu, 11 Jun 2026 22:17:00 -0700 Subject: [PATCH 2/7] tidy: Remove redundant calls to turn off hook z reshaping in tests, already off --- tests/saes/test_sae.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/tests/saes/test_sae.py b/tests/saes/test_sae.py index 808b07d32..254d80359 100644 --- a/tests/saes/test_sae.py +++ b/tests/saes/test_sae.py @@ -264,7 +264,6 @@ def test_TrainingSAE_fold_activation_norm_scaling_factor_all_architectures( def test_sae_fold_w_dec_norm_all_architectures(architecture: str): cfg = build_sae_cfg_for_arch(architecture) sae = SAE.from_dict(cfg.to_dict()) - sae.turn_off_forward_pass_hook_z_reshaping() # hook z reshaping not needed here. # make sure all parameters are not 0s for param in sae.parameters(): @@ -309,7 +308,6 @@ def test_sae_fold_w_dec_norm_all_architectures(architecture: str): def test_training_sae_fold_w_dec_norm_all_architectures(architecture: str): cfg = build_sae_training_cfg_for_arch(architecture) sae = TrainingSAE.from_dict(cfg.to_dict()) - sae.turn_off_forward_pass_hook_z_reshaping() # hook z reshaping not needed here. # make sure all parameters are not 0s for param in sae.parameters(): @@ -416,7 +414,6 @@ def test_fold_W_dec_norm_does_not_produce_nan_with_zero_norm_decoder( """ cfg = build_sae_cfg_for_arch(architecture) sae = SAE.from_dict(cfg.to_dict()) - sae.turn_off_forward_pass_hook_z_reshaping() # Initialize parameters with random values for param in sae.parameters(): @@ -464,7 +461,6 @@ def test_training_fold_W_dec_norm_does_not_produce_nan_with_zero_norm_decoder( """ cfg = build_sae_training_cfg_for_arch(architecture) sae = TrainingSAE.from_dict(cfg.to_dict()) - sae.turn_off_forward_pass_hook_z_reshaping() # Initialize parameters with random values for param in sae.parameters(): From a2313b38398ce7e766c780ec5f3c864c0b357e1c Mon Sep 17 00:00:00 2001 From: Dan Raviv Date: Thu, 11 Jun 2026 22:02:44 -0700 Subject: [PATCH 3/7] test: Verify fold_W_dec_norm handles d_sae == 1 edge case With d_sae == 1, W_dec_norms is shape (1, 1) and .squeeze() collapses it to a 0-dim scalar Co-Authored-By: Claude Opus 4.8 --- tests/saes/test_sae.py | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/tests/saes/test_sae.py b/tests/saes/test_sae.py index 254d80359..9ab408f2b 100644 --- a/tests/saes/test_sae.py +++ b/tests/saes/test_sae.py @@ -303,6 +303,34 @@ def test_sae_fold_w_dec_norm_all_architectures(architecture: str): assert_close(sae_out_1, sae_out_2, atol=1e-5) +@pytest.mark.parametrize("architecture", ALL_FOLDABLE_ARCHITECTURES) +@torch.no_grad() +def test_sae_fold_w_dec_norm_with_d_sae_of_1(architecture: str): + cfg = build_sae_cfg_for_arch(architecture, d_sae=1) + sae = SAE.from_dict(cfg.to_dict()) + random_params(sae) + + activations = 10.0 * torch.ones(8, cfg.d_in, device=cfg.device) + bias_shapes_before = { + name: param.shape + for name, param in sae.named_parameters() + if name in {"b_enc", "b_gate", "b_mag", "log_threshold"} + } + sae_out_before = sae(activations) + + sae.fold_W_dec_norm() + + # W_dec should be normalized to unit norm. + assert sae.W_dec.norm(dim=-1).item() == pytest.approx(1.0, abs=1e-6) + # Validate shape of the single-element bias params. + assert bias_shapes_before + for name, param in sae.named_parameters(): + if name in bias_shapes_before: + assert param.shape == bias_shapes_before[name] + # Folding preserves the SAE function. + assert_close(sae(activations), sae_out_before, atol=1e-4) + + @pytest.mark.parametrize("architecture", ALL_TRAINING_ARCHITECTURES) @torch.no_grad() def test_training_sae_fold_w_dec_norm_all_architectures(architecture: str): From a2c0e84edc3e26ed4efca251eba47f35fd85b56c Mon Sep 17 00:00:00 2001 From: Dan Raviv Date: Thu, 11 Jun 2026 22:41:42 -0700 Subject: [PATCH 4/7] chore: Re-squeeze just the added dimension of W_dec_norms Note this also removes an explicit comment in jumprelu about intentionally using squeeze() rather than squeeze(-1) to maintain past behavior; but other than the internally different behavior in the edge-case of d_sae=1 -- extra squeeze and broadcasting, which is now tested and is still valid -- there is no difference in behavior. --- sae_lens/saes/gated_sae.py | 8 ++++---- sae_lens/saes/jumprelu_sae.py | 3 +-- sae_lens/saes/sae.py | 2 +- tests/saes/test_standard_sae.py | 2 +- 4 files changed, 7 insertions(+), 8 deletions(-) diff --git a/sae_lens/saes/gated_sae.py b/sae_lens/saes/gated_sae.py index 5a96e3e7d..ef147c97d 100644 --- a/sae_lens/saes/gated_sae.py +++ b/sae_lens/saes/gated_sae.py @@ -98,8 +98,8 @@ def fold_W_dec_norm(self): # Gated-specific parameters need special handling # r_mag doesn't need scaling since W_enc scaling is sufficient for magnitude path - self.b_gate.data = self.b_gate.data * W_dec_norms.squeeze() - self.b_mag.data = self.b_mag.data * W_dec_norms.squeeze() + self.b_gate.data = self.b_gate.data * W_dec_norms.squeeze(-1) + self.b_mag.data = self.b_mag.data * W_dec_norms.squeeze(-1) @dataclass @@ -232,8 +232,8 @@ def fold_W_dec_norm(self): # Gated-specific parameters need special handling # r_mag doesn't need scaling since W_enc scaling is sufficient for magnitude path - self.b_gate.data = self.b_gate.data * W_dec_norms.squeeze() - self.b_mag.data = self.b_mag.data * W_dec_norms.squeeze() + self.b_gate.data = self.b_gate.data * W_dec_norms.squeeze(-1) + self.b_mag.data = self.b_mag.data * W_dec_norms.squeeze(-1) def _init_weights_gated( diff --git a/sae_lens/saes/jumprelu_sae.py b/sae_lens/saes/jumprelu_sae.py index 69312bb20..650c36c09 100644 --- a/sae_lens/saes/jumprelu_sae.py +++ b/sae_lens/saes/jumprelu_sae.py @@ -339,8 +339,7 @@ def fold_W_dec_norm(self): # Call parent implementation to handle W_enc and W_dec adjustment super().fold_W_dec_norm() - # Fix: Use squeeze() instead of squeeze(-1) to match old behavior - self.log_threshold.data = torch.log(current_thresh * W_dec_norms.squeeze()) + self.log_threshold.data = torch.log(current_thresh * W_dec_norms.squeeze(-1)) @override def process_state_dict_for_saving(self, state_dict: dict[str, Any]) -> None: diff --git a/sae_lens/saes/sae.py b/sae_lens/saes/sae.py index b9e126866..d752c3f9c 100644 --- a/sae_lens/saes/sae.py +++ b/sae_lens/saes/sae.py @@ -513,7 +513,7 @@ def fold_W_dec_norm(self): # Only update b_enc if it exists (standard/jumprelu architectures) if hasattr(self, "b_enc") and isinstance(self.b_enc, nn.Parameter): - self.b_enc.data = self.b_enc.data * W_dec_norms.squeeze() + self.b_enc.data = self.b_enc.data * W_dec_norms.squeeze(-1) def get_name(self): """Generate a name for this SAE.""" diff --git a/tests/saes/test_standard_sae.py b/tests/saes/test_standard_sae.py index 66c74b205..dd9e0bed5 100644 --- a/tests/saes/test_standard_sae.py +++ b/tests/saes/test_standard_sae.py @@ -248,7 +248,7 @@ def test_StandardSAE_fold_w_dec_norm( W_dec_norms = sae.W_dec.norm(dim=-1).unsqueeze(1) assert_close(sae2.W_dec.data, sae.W_dec.data / W_dec_norms) assert_close(sae2.W_enc.data, sae.W_enc.data * W_dec_norms.T) - assert_close(sae2.b_enc.data, sae.b_enc.data * W_dec_norms.squeeze()) + assert_close(sae2.b_enc.data, sae.b_enc.data * W_dec_norms.squeeze(-1)) # fold_W_dec_norm should normalize W_dec to have unit norm. assert sae2.W_dec.norm(dim=-1).mean().item() == pytest.approx(1.0, abs=1e-6) From ec267b6de5fe1f372b9f65b13efa9a01a53b4e9c Mon Sep 17 00:00:00 2001 From: Dan Raviv Date: Thu, 11 Jun 2026 23:47:20 -0700 Subject: [PATCH 5/7] chore: JumpReLU transcoder normalizes threshold by exact same scalars as other weights --- sae_lens/saes/transcoder.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sae_lens/saes/transcoder.py b/sae_lens/saes/transcoder.py index 0125d0af1..ea61f1b5a 100644 --- a/sae_lens/saes/transcoder.py +++ b/sae_lens/saes/transcoder.py @@ -369,7 +369,7 @@ def fold_W_dec_norm(self) -> None: """ # Get the decoder weight norms before normalizing with torch.no_grad(): - W_dec_norms = self.W_dec.norm(dim=1) + W_dec_norms = self.W_dec.norm(dim=-1).clamp(min=1e-8) # Fold the decoder norms as in the parent class super().fold_W_dec_norm() From 4c5677ce2764dfd0a0da1ae9cc384f9f01695b0b Mon Sep 17 00:00:00 2001 From: Dan Raviv Date: Thu, 11 Jun 2026 22:56:43 -0700 Subject: [PATCH 6/7] refactor: DRY folding decoder norms --- sae_lens/saes/gated_sae.py | 16 ++++++---------- sae_lens/saes/jumprelu_sae.py | 18 +++++++----------- sae_lens/saes/sae.py | 15 +++++++++++++-- sae_lens/saes/topk_sae.py | 7 ++++--- sae_lens/saes/transcoder.py | 8 ++------ 5 files changed, 32 insertions(+), 32 deletions(-) diff --git a/sae_lens/saes/gated_sae.py b/sae_lens/saes/gated_sae.py index ef147c97d..8f34156b3 100644 --- a/sae_lens/saes/gated_sae.py +++ b/sae_lens/saes/gated_sae.py @@ -92,14 +92,12 @@ def decode(self, feature_acts: torch.Tensor) -> torch.Tensor: @torch.no_grad() def fold_W_dec_norm(self): """Override to handle gated-specific parameters.""" - W_dec_norms = self.W_dec.norm(dim=-1).clamp(min=1e-8).unsqueeze(1) - self.W_dec.data = self.W_dec.data / W_dec_norms - self.W_enc.data = self.W_enc.data * W_dec_norms.T + W_dec_norms = self.fold_and_get_W_dec_norm() # Gated-specific parameters need special handling # r_mag doesn't need scaling since W_enc scaling is sufficient for magnitude path - self.b_gate.data = self.b_gate.data * W_dec_norms.squeeze(-1) - self.b_mag.data = self.b_mag.data * W_dec_norms.squeeze(-1) + self.b_gate.data = self.b_gate.data * W_dec_norms + self.b_mag.data = self.b_mag.data * W_dec_norms @dataclass @@ -226,14 +224,12 @@ def get_coefficients(self) -> dict[str, float | TrainCoefficientConfig]: @torch.no_grad() def fold_W_dec_norm(self): """Override to handle gated-specific parameters.""" - W_dec_norms = self.W_dec.norm(dim=-1).clamp(min=1e-8).unsqueeze(1) - self.W_dec.data = self.W_dec.data / W_dec_norms - self.W_enc.data = self.W_enc.data * W_dec_norms.T + W_dec_norms = self.fold_and_get_W_dec_norm() # Gated-specific parameters need special handling # r_mag doesn't need scaling since W_enc scaling is sufficient for magnitude path - self.b_gate.data = self.b_gate.data * W_dec_norms.squeeze(-1) - self.b_mag.data = self.b_mag.data * W_dec_norms.squeeze(-1) + self.b_gate.data = self.b_gate.data * W_dec_norms + self.b_mag.data = self.b_mag.data * W_dec_norms def _init_weights_gated( diff --git a/sae_lens/saes/jumprelu_sae.py b/sae_lens/saes/jumprelu_sae.py index 650c36c09..24cf7e4f2 100644 --- a/sae_lens/saes/jumprelu_sae.py +++ b/sae_lens/saes/jumprelu_sae.py @@ -171,11 +171,8 @@ def fold_W_dec_norm(self): # Save the current threshold before calling parent method current_thresh = self.threshold.clone() - # Get W_dec norms that will be used for scaling (clamped to avoid division by zero) - W_dec_norms = self.W_dec.norm(dim=-1).clamp(min=1e-8) - - # Call parent implementation to handle W_enc, W_dec, and b_enc adjustment - super().fold_W_dec_norm() + # Handle W_enc, W_dec, and b_enc adjustment + W_dec_norms = self.fold_and_get_W_dec_norm() # Scale the threshold by the same factor as we scaled b_enc # This ensures the same features remain active/inactive after folding @@ -333,13 +330,12 @@ def fold_W_dec_norm(self): # Save the current threshold before we call the parent method current_thresh = self.threshold.clone() - # Get W_dec norms (clamped to avoid division by zero) - W_dec_norms = self.W_dec.norm(dim=-1).clamp(min=1e-8).unsqueeze(1) - - # Call parent implementation to handle W_enc and W_dec adjustment - super().fold_W_dec_norm() + # Handle W_enc, W_dec, and b_enc adjustment + W_dec_norms = self.fold_and_get_W_dec_norm() - self.log_threshold.data = torch.log(current_thresh * W_dec_norms.squeeze(-1)) + # Scale the threshold by the same factor as we scaled b_enc + # This ensures the same features remain active/inactive after folding + self.log_threshold.data = torch.log(current_thresh * W_dec_norms) @override def process_state_dict_for_saving(self, state_dict: dict[str, Any]) -> None: diff --git a/sae_lens/saes/sae.py b/sae_lens/saes/sae.py index d752c3f9c..a2039ecdd 100644 --- a/sae_lens/saes/sae.py +++ b/sae_lens/saes/sae.py @@ -504,10 +504,19 @@ def process_state_dict_for_saving(self, state_dict: dict[str, Any]) -> None: def process_state_dict_for_loading(self, state_dict: dict[str, Any]) -> None: pass - @torch.no_grad() def fold_W_dec_norm(self): """Fold decoder norms into encoder.""" - W_dec_norms = self.W_dec.norm(dim=-1).clamp(min=1e-8).unsqueeze(1) + self.fold_and_get_W_dec_norm() + + @torch.no_grad() + def get_W_dec_norm(self) -> torch.Tensor: + """Get decoder norms.""" + return self.W_dec.norm(dim=-1).clamp(min=1e-8) + + @torch.no_grad() + def fold_and_get_W_dec_norm(self) -> torch.Tensor: + """Fold decoder norms into encoder and return them.""" + W_dec_norms = self.get_W_dec_norm().unsqueeze(1) self.W_dec.data = self.W_dec.data / W_dec_norms self.W_enc.data = self.W_enc.data * W_dec_norms.T @@ -515,6 +524,8 @@ def fold_W_dec_norm(self): if hasattr(self, "b_enc") and isinstance(self.b_enc, nn.Parameter): self.b_enc.data = self.b_enc.data * W_dec_norms.squeeze(-1) + return W_dec_norms.squeeze(-1) + def get_name(self): """Generate a name for this SAE.""" return f"sae_{self.cfg.metadata.model_name}_{self.cfg.metadata.hook_name}_{self.cfg.d_sae}" diff --git a/sae_lens/saes/topk_sae.py b/sae_lens/saes/topk_sae.py index 6ffa0ad64..bd623734c 100644 --- a/sae_lens/saes/topk_sae.py +++ b/sae_lens/saes/topk_sae.py @@ -295,7 +295,7 @@ def fold_W_dec_norm(self) -> None: raise NotImplementedError( "Folding W_dec_norm is not safe for TopKSAEs when rescale_acts_by_decoder_norm is False, as this may change the topk activations" ) - _fold_norm_topk(W_dec=self.W_dec, b_enc=self.b_enc, W_enc=self.W_enc) + super().fold_W_dec_norm() @dataclass @@ -445,7 +445,7 @@ def fold_W_dec_norm(self) -> None: raise NotImplementedError( "Folding W_dec_norm is not safe for TopKSAEs when rescale_acts_by_decoder_norm is False, as this may change the topk activations" ) - _fold_norm_topk(W_dec=self.W_dec, b_enc=self.b_enc, W_enc=self.W_enc) + super().fold_W_dec_norm() @override def get_activation_fn(self) -> Callable[[torch.Tensor], torch.Tensor]: @@ -515,6 +515,7 @@ def process_state_dict_for_saving_inference( super().process_state_dict_for_saving_inference(state_dict) if self.cfg.rescale_acts_by_decoder_norm: _fold_norm_topk( + W_dec_norm=self.get_W_dec_norm(), W_enc=state_dict["W_enc"], b_enc=state_dict["b_enc"], W_dec=state_dict["W_dec"], @@ -558,11 +559,11 @@ def _init_weights_topk( def _fold_norm_topk( + W_dec_norm: torch.Tensor, W_enc: torch.Tensor, b_enc: torch.Tensor, W_dec: torch.Tensor, ) -> None: - W_dec_norm = W_dec.norm(dim=-1).clamp(min=1e-8) b_enc.data = b_enc.data * W_dec_norm W_dec_norms = W_dec_norm.unsqueeze(1) W_dec.data = W_dec.data / W_dec_norms diff --git a/sae_lens/saes/transcoder.py b/sae_lens/saes/transcoder.py index ea61f1b5a..e61d96b8c 100644 --- a/sae_lens/saes/transcoder.py +++ b/sae_lens/saes/transcoder.py @@ -367,14 +367,10 @@ def fold_W_dec_norm(self) -> None: This is important for JumpReLU as the threshold needs to be scaled along with the decoder weights. """ - # Get the decoder weight norms before normalizing - with torch.no_grad(): - W_dec_norms = self.W_dec.norm(dim=-1).clamp(min=1e-8) - # Fold the decoder norms as in the parent class - super().fold_W_dec_norm() + W_dec_norms = self.fold_and_get_W_dec_norm() - # Scale the threshold by the decoder weight norms + # Scale the threshold by the same norms with torch.no_grad(): self.threshold.data = self.threshold.data * W_dec_norms From 10ddf8aee5eac9a9e0e7b485e97cdc578d3ba551 Mon Sep 17 00:00:00 2001 From: Dan Raviv Date: Fri, 12 Jun 2026 11:51:40 -0700 Subject: [PATCH 7/7] refactor: more concise dim-handling in SAE.fold_and_get_W_dec_norm() --- sae_lens/saes/sae.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/sae_lens/saes/sae.py b/sae_lens/saes/sae.py index a2039ecdd..be1e9025a 100644 --- a/sae_lens/saes/sae.py +++ b/sae_lens/saes/sae.py @@ -516,15 +516,15 @@ def get_W_dec_norm(self) -> torch.Tensor: @torch.no_grad() def fold_and_get_W_dec_norm(self) -> torch.Tensor: """Fold decoder norms into encoder and return them.""" - W_dec_norms = self.get_W_dec_norm().unsqueeze(1) - self.W_dec.data = self.W_dec.data / W_dec_norms - self.W_enc.data = self.W_enc.data * W_dec_norms.T + W_dec_norms = self.get_W_dec_norm() + self.W_dec.data = self.W_dec.data / W_dec_norms.unsqueeze(1) + self.W_enc.data = self.W_enc.data * W_dec_norms.unsqueeze(1).T # Only update b_enc if it exists (standard/jumprelu architectures) if hasattr(self, "b_enc") and isinstance(self.b_enc, nn.Parameter): - self.b_enc.data = self.b_enc.data * W_dec_norms.squeeze(-1) + self.b_enc.data = self.b_enc.data * W_dec_norms - return W_dec_norms.squeeze(-1) + return W_dec_norms def get_name(self): """Generate a name for this SAE."""