diff --git a/sae_lens/saes/sae.py b/sae_lens/saes/sae.py index 3952486ea..744354d70 100644 --- a/sae_lens/saes/sae.py +++ b/sae_lens/saes/sae.py @@ -349,7 +349,10 @@ def run_time_activation_ln_out( x: torch.Tensor, eps: float = 1e-5, # noqa: ARG001 ) -> torch.Tensor: - return x * self.ln_std + self.ln_mu # type: ignore + x = x * self.ln_std + self.ln_mu # type: ignore + del self.ln_mu + del self.ln_std + return x self.run_time_activation_norm_fn_in = run_time_activation_ln_in self.run_time_activation_norm_fn_out = run_time_activation_ln_out diff --git a/tests/saes/test_standard_sae.py b/tests/saes/test_standard_sae.py index 384196df1..746a93cc1 100644 --- a/tests/saes/test_standard_sae.py +++ b/tests/saes/test_standard_sae.py @@ -795,6 +795,7 @@ def test_StandardSAE_constant_norm_rescale(): cfg = build_sae_cfg(d_in=2, d_sae=3, normalize_activations="constant_norm_rescale") sae = StandardSAE(cfg) + pre_activation_vars = list(vars(sae).keys()) test_input = torch.randn(10, 2, device=cfg.device) @@ -803,12 +804,15 @@ def test_StandardSAE_constant_norm_rescale(): assert_close(scaled_input, test_input * expected_scaler, atol=1e-6) scaled_output = sae.run_time_activation_norm_fn_out(scaled_input) assert_close(scaled_output, test_input) + # Basic verification of temporary extra state cleanup + assert list(vars(sae).keys()) == pre_activation_vars def test_StandardSAE_layer_norm(): cfg = build_sae_cfg(d_in=2, d_sae=3, normalize_activations="layer_norm") sae = StandardSAE(cfg) + pre_activation_vars = list(vars(sae).keys()) test_input = torch.randn(10, 2, device=cfg.device) @@ -822,6 +826,8 @@ def test_StandardSAE_layer_norm(): ) scaled_output = sae.run_time_activation_norm_fn_out(scaled_input) assert_close(scaled_output, test_input, atol=1e-4) + # Basic verification of temporary extra state cleanup + assert list(vars(sae).keys()) == pre_activation_vars def test_StandardSAE_none():