diff --git a/tests/loading/test_pretrained_sae_loaders.py b/tests/loading/test_pretrained_sae_loaders.py index 69ce774ee..eb8ae6779 100644 --- a/tests/loading/test_pretrained_sae_loaders.py +++ b/tests/loading/test_pretrained_sae_loaders.py @@ -2364,16 +2364,25 @@ def mock_hf_hub_download(*args: Any, **kwargs: Any) -> str: # noqa: ARG001 # The transpose must preserve the underlying linear operation: in the # native Qwen Scope formula, pre_acts = residual @ raw_W_enc.T + b_enc. # In SAELens, pre_acts = residual @ W_enc + b_enc. They must agree. + # + # The loader stores W_enc/W_dec as `.T.contiguous()`, so each matmul below + # multiplies a contiguous tensor while the reference multiplies a transposed + # view. Over these large reduction dims (2048 and 32768) the two layouts can + # accumulate float32 in a different order on some BLAS backends, so the + # results agree only up to rounding. Use tolerances that absorb that noise + # while still catching a real loader regression (which would be orders of + # magnitude off); the default rtol=1e-5/atol=1e-8 is too tight here and + # causes intermittent CI failures. residual = torch.randn(2, 7, d_in) expected_pre_acts = residual @ raw_W_enc.T + raw_b_enc actual_pre_acts = residual @ state_dict["W_enc"] + state_dict["b_enc"] - assert_close(actual_pre_acts, expected_pre_acts) + assert_close(actual_pre_acts, expected_pre_acts, rtol=1e-3, atol=1e-2) # Same for the decoder: native is feats @ raw_W_dec.T + b_dec, SAELens is feats @ W_dec + b_dec. feats = torch.randn(2, 7, d_sae) expected_recon = feats @ raw_W_dec.T + raw_b_dec actual_recon = feats @ state_dict["W_dec"] + state_dict["b_dec"] - assert_close(actual_recon, expected_recon) + assert_close(actual_recon, expected_recon, rtol=1e-3, atol=1e-2) def test_qwen_scope_sae_loads_via_from_pretrained(