Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 11 additions & 2 deletions tests/loading/test_pretrained_sae_loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -2364,16 +2364,25 @@ def mock_hf_hub_download(*args: Any, **kwargs: Any) -> str: # noqa: ARG001
# The transpose must preserve the underlying linear operation: in the
# native Qwen Scope formula, pre_acts = residual @ raw_W_enc.T + b_enc.
# In SAELens, pre_acts = residual @ W_enc + b_enc. They must agree.
#
# The loader stores W_enc/W_dec as `.T.contiguous()`, so each matmul below
# multiplies a contiguous tensor while the reference multiplies a transposed
# view. Over these large reduction dims (2048 and 32768) the two layouts can
# accumulate float32 in a different order on some BLAS backends, so the
# results agree only up to rounding. Use tolerances that absorb that noise
# while still catching a real loader regression (which would be orders of
# magnitude off); the default rtol=1e-5/atol=1e-8 is too tight here and
# causes intermittent CI failures.
residual = torch.randn(2, 7, d_in)
expected_pre_acts = residual @ raw_W_enc.T + raw_b_enc
actual_pre_acts = residual @ state_dict["W_enc"] + state_dict["b_enc"]
- assert_close(actual_pre_acts, expected_pre_acts)
+ assert_close(actual_pre_acts, expected_pre_acts, rtol=1e-3, atol=1e-2)

# Same for the decoder: native is feats @ raw_W_dec.T + b_dec, SAELens is feats @ W_dec + b_dec.
feats = torch.randn(2, 7, d_sae)
expected_recon = feats @ raw_W_dec.T + raw_b_dec
actual_recon = feats @ state_dict["W_dec"] + state_dict["b_dec"]
- assert_close(actual_recon, expected_recon)
+ assert_close(actual_recon, expected_recon, rtol=1e-3, atol=1e-2)


def test_qwen_scope_sae_loads_via_from_pretrained(
Expand Down
Loading