From 7aba53c934547dc348ffd6328fce0791a44553c5 Mon Sep 17 00:00:00 2001 From: David Chanin Date: Tue, 23 Dec 2025 22:19:26 -0500 Subject: [PATCH] fix: warn rather than crash if sae is unnormalized --- sae_bench/evals/meta_structure/eval_output.py | 2 +- sae_bench/sae_bench_utils/general_utils.py | 7 ++++--- tests/acceptance/test_meta_structure.py | 20 +++++++++++++------ 3 files changed, 19 insertions(+), 10 deletions(-) diff --git a/sae_bench/evals/meta_structure/eval_output.py b/sae_bench/evals/meta_structure/eval_output.py index 74b679f..b1f5ed7 100644 --- a/sae_bench/evals/meta_structure/eval_output.py +++ b/sae_bench/evals/meta_structure/eval_output.py @@ -53,7 +53,7 @@ class MetaStructureEvalOutput( eval_id: str datetime_epoch_millis: int eval_result_metrics: MetaStructureMetricCategories - eval_result_details: list[BaseResultDetail] | None = None + eval_result_details: list[BaseResultDetail] | None = None # pyright: ignore[reportIncompatibleVariableOverride] eval_type_id: str = Field( default=EVAL_TYPE_ID_META_STRUCTURE, title="Eval Type ID", diff --git a/sae_bench/sae_bench_utils/general_utils.py b/sae_bench/sae_bench_utils/general_utils.py index 8f9ec9b..ead08b1 100644 --- a/sae_bench/sae_bench_utils/general_utils.py +++ b/sae_bench/sae_bench_utils/general_utils.py @@ -3,6 +3,7 @@ import random import re import time +import warnings from typing import Any, Callable import pandas as pd @@ -124,10 +125,10 @@ def check_decoder_norms(W_dec: torch.Tensor) -> bool: return True else: max_diff = torch.max(torch.abs(norms - torch.ones_like(norms))) - print(f"Decoder weights are not normalized. Max diff: {max_diff.item()}") - raise ValueError( - "Decoder weights are not normalized. Refer to base_sae.py and relu_sae.py for more info." + warnings.warn( + f"Decoder weights are not normalized. Max diff: {max_diff.item()}. Refer to base_sae.py and relu_sae.py for more info." ) + return False def load_and_format_sae( diff --git a/tests/acceptance/test_meta_structure.py b/tests/acceptance/test_meta_structure.py index 8c26ea7..b1f2b57 100644 --- a/tests/acceptance/test_meta_structure.py +++ b/tests/acceptance/test_meta_structure.py @@ -54,9 +54,17 @@ def test_meta_structure_eval_matches_fixture(tmp_path): actual = json.load(f) actual_metrics = _load_metrics(actual) - assert pytest.approx( - expected_metrics["decoder_fraction_variance_explained"], rel=1e-2, - ) == actual_metrics["decoder_fraction_variance_explained"] - assert pytest.approx( - expected_metrics["final_reconstruction_mse"], rel=1e-2, - ) == actual_metrics["final_reconstruction_mse"] + assert ( + pytest.approx( + expected_metrics["decoder_fraction_variance_explained"], + rel=1e-2, + ) + == actual_metrics["decoder_fraction_variance_explained"] + ) + assert ( + pytest.approx( + expected_metrics["final_reconstruction_mse"], + rel=1e-2, + ) + == actual_metrics["final_reconstruction_mse"] + )