From 740b4e5dcbb4aefd8be8003dfb469c2472b7b7c7 Mon Sep 17 00:00:00 2001 From: David Chanin Date: Thu, 4 Sep 2025 11:10:01 +0100 Subject: [PATCH 1/2] fix: detach grads from gated SAE for aux loss --- sae_lens/saes/gated_sae.py | 9 ++++----- tests/saes/test_gated_sae.py | 22 ++++++++++++++++++++++ 2 files changed, 26 insertions(+), 5 deletions(-) diff --git a/sae_lens/saes/gated_sae.py b/sae_lens/saes/gated_sae.py index e2e5e5d16..90d82d409 100644 --- a/sae_lens/saes/gated_sae.py +++ b/sae_lens/saes/gated_sae.py @@ -187,13 +187,12 @@ def calculate_aux_loss( pi_gate_act = torch.relu(pi_gate) # L1-like penalty scaled by W_dec norms - l1_loss = ( - step_input.coefficients["l1"] - * torch.sum(pi_gate_act * self.W_dec.norm(dim=1), dim=-1).mean() - ) + l1_loss = step_input.coefficients["l1"] * torch.sum(pi_gate_act, dim=-1).mean() # Aux reconstruction: reconstruct x purely from gating path - via_gate_reconstruction = pi_gate_act @ self.W_dec + self.b_dec + via_gate_reconstruction = ( + pi_gate_act @ self.W_dec.detach() + self.b_dec.detach() + ) aux_recon_loss = ( (via_gate_reconstruction - step_input.sae_in).pow(2).sum(dim=-1).mean() ) diff --git a/tests/saes/test_gated_sae.py b/tests/saes/test_gated_sae.py index 3653bf08e..3b06e2165 100644 --- a/tests/saes/test_gated_sae.py +++ b/tests/saes/test_gated_sae.py @@ -322,3 +322,25 @@ def test_GatedTrainingSAE_save_and_load_inference_sae(tmp_path: Path) -> None: training_full_out = training_sae(sae_in) inference_full_out = inference_sae(sae_in) assert_close(training_full_out, inference_full_out) + + +def test_GatedTrainingSAE_auxiliary_reconstruction_loss_does_not_apply_gradient_to_decoder_weights(): + cfg = build_gated_sae_training_cfg() + sae = GatedTrainingSAE.from_dict(cfg.to_dict()) + + aux_losses = sae.calculate_aux_loss( + step_input=TrainStepInput( + sae_in=torch.randn(10, cfg.d_in), + coefficients={"l1": 1.0}, + dead_neuron_mask=None, + ), + feature_acts=torch.randn(10, cfg.d_sae), + hidden_pre=torch.randn(10, cfg.d_sae), + sae_out=torch.randn(10, cfg.d_in), + ) + aux_losses["auxiliary_reconstruction_loss"].backward() + + assert sae.W_dec.grad is None or sae.W_dec.grad.sum() == 0.0 + assert sae.b_dec.grad is None or sae.b_dec.grad.sum() == 0.0 + assert sae.W_enc.grad is not None and sae.W_enc.grad.sum() != 0.0 + assert sae.b_gate.grad is not None and sae.b_gate.grad.sum() != 0.0 From a231b21c2e6aa48de8af90e7d7c1ffa23aa978b1 Mon Sep 17 00:00:00 2001 From: David Chanin Date: Thu, 4 Sep 2025 11:29:22 +0100 Subject: [PATCH 2/2] fix test --- tests/refactor_compatibility/test_gated_sae_equivalence.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/tests/refactor_compatibility/test_gated_sae_equivalence.py b/tests/refactor_compatibility/test_gated_sae_equivalence.py index 82f54c45a..136651688 100644 --- a/tests/refactor_compatibility/test_gated_sae_equivalence.py +++ b/tests/refactor_compatibility/test_gated_sae_equivalence.py @@ -393,9 +393,4 @@ def test_gated_training_equivalence(): # type: ignore atol=1e-5, msg="Output differs between old and new Gated implementation", ) - assert_close( - old_out.loss, - new_out.loss, - atol=1e-5, - msg="Loss differs between old and new Gated implementation", - ) + # the losses should no longer be equivalent, since we fixed a bug with the auxiliary reconstruction loss