Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 4 additions & 5 deletions sae_lens/saes/gated_sae.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,13 +187,12 @@ def calculate_aux_loss(
pi_gate_act = torch.relu(pi_gate)

# L1-like penalty scaled by W_dec norms
l1_loss = (
step_input.coefficients["l1"]
* torch.sum(pi_gate_act * self.W_dec.norm(dim=1), dim=-1).mean()
)
l1_loss = step_input.coefficients["l1"] * torch.sum(pi_gate_act, dim=-1).mean()

# Aux reconstruction: reconstruct x purely from gating path
via_gate_reconstruction = pi_gate_act @ self.W_dec + self.b_dec
via_gate_reconstruction = (
pi_gate_act @ self.W_dec.detach() + self.b_dec.detach()
)
aux_recon_loss = (
(via_gate_reconstruction - step_input.sae_in).pow(2).sum(dim=-1).mean()
)
Expand Down
7 changes: 1 addition & 6 deletions tests/refactor_compatibility/test_gated_sae_equivalence.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,9 +393,4 @@ def test_gated_training_equivalence(): # type: ignore
atol=1e-5,
msg="Output differs between old and new Gated implementation",
)
assert_close(
old_out.loss,
new_out.loss,
atol=1e-5,
msg="Loss differs between old and new Gated implementation",
)
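# The losses are intentionally no longer compared: the new Gated implementation changes the auxiliary losses
# (the L1 penalty is no longer scaled by W_dec norms, and the auxiliary reconstruction uses a detached decoder),
# so it is not expected to match the old implementation's loss.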
22 changes: 22 additions & 0 deletions tests/saes/test_gated_sae.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,3 +322,25 @@ def test_GatedTrainingSAE_save_and_load_inference_sae(tmp_path: Path) -> None:
training_full_out = training_sae(sae_in)
inference_full_out = inference_sae(sae_in)
assert_close(training_full_out, inference_full_out)


def test_GatedTrainingSAE_auxiliary_reconstruction_loss_does_not_apply_gradient_to_decoder_weights() -> None:
    """Backpropagating the auxiliary reconstruction loss must not train the decoder.

    Builds a gated training SAE, computes its auxiliary losses on random
    inputs, and backpropagates only ``auxiliary_reconstruction_loss``.
    Afterwards the decoder parameters (``W_dec``, ``b_dec``) must carry no
    gradient, while the gating-path parameters (``W_enc``, ``b_gate``) must
    have received one.
    """
    cfg = build_gated_sae_training_cfg()
    sae = GatedTrainingSAE.from_dict(cfg.to_dict())

    aux_losses = sae.calculate_aux_loss(
        step_input=TrainStepInput(
            sae_in=torch.randn(10, cfg.d_in),
            coefficients={"l1": 1.0},
            dead_neuron_mask=None,
        ),
        feature_acts=torch.randn(10, cfg.d_sae),
        hidden_pre=torch.randn(10, cfg.d_sae),
        sae_out=torch.randn(10, cfg.d_in),
    )
    aux_losses["auxiliary_reconstruction_loss"].backward()

    # Inspect gradients element-wise instead of through ``.sum()``: a gradient
    # whose entries cancel out sums to zero, which would let a real gradient
    # leak into the decoder go unnoticed.
    w_dec_grad = sae.W_dec.grad
    b_dec_grad = sae.b_dec.grad
    w_enc_grad = sae.W_enc.grad
    b_gate_grad = sae.b_gate.grad

    # Decoder parameters: either no gradient at all, or an all-zero one.
    assert w_dec_grad is None or torch.count_nonzero(w_dec_grad).item() == 0
    assert b_dec_grad is None or torch.count_nonzero(b_dec_grad).item() == 0

    # Gating-path parameters: a gradient with at least one non-zero entry.
    assert w_enc_grad is not None and torch.count_nonzero(w_enc_grad).item() > 0
    assert b_gate_grad is not None and torch.count_nonzero(b_gate_grad).item() > 0
Loading