def compute_explained_variance(
    input_batches: list[torch.Tensor],
    output_batches: list[torch.Tensor],
) -> float:
    """Compute explained variance from lists of input/output tensor pairs.

    Explained variance is ``1 - residual_variance / total_variance`` where

    * ``total_variance`` is the variance of the inputs around their global
      mean, summed over the last dimension, and
    * ``residual_variance`` is the mean squared reconstruction error,
      summed over the last dimension.

    Statistics are pooled over every row of every batch, weighted by the
    number of rows in each batch, so batches of different sizes give the same
    result as a single concatenated tensor. The total sum of squares is
    assembled as (within-batch SS) + (between-batch SS) instead of
    ``E[||X||^2] - ||E[X]||^2``; centering before squaring avoids the
    catastrophic cancellation the latter suffers when activations have a
    large mean relative to their spread.

    Args:
        input_batches: SAE inputs, one tensor per batch, each of shape
            ``(n_tokens, d_model)``. Extra leading dimensions are flattened
            into the token dimension.
        output_batches: SAE reconstructions, paired element-wise with
            ``input_batches``.

    Returns:
        The explained variance as a Python float. When the total variance is
        (numerically) zero the result is ``1.0`` if the residual is also zero
        and ``0.0`` otherwise.

    Raises:
        ValueError: If the two lists differ in length, or if they contain no
            tokens at all.
    """
    if len(input_batches) != len(output_batches):
        raise ValueError(
            "input_batches and output_batches must have the same length, got "
            f"{len(input_batches)} and {len(output_batches)}"
        )

    token_counts: list[int] = []
    batch_means: list[torch.Tensor] = []
    within_batch_ss: list[torch.Tensor] = []
    residual_ss: list[torch.Tensor] = []

    for sae_input, sae_output in zip(input_batches, output_batches):
        # Accumulate in at least float32 so half-precision activations do not
        # overflow or lose precision in the squared sums.
        stats_dtype = torch.promote_types(sae_input.dtype, torch.float32)
        flat_input = sae_input.reshape(-1, sae_input.shape[-1]).to(stats_dtype)
        flat_output = sae_output.reshape(-1, sae_output.shape[-1]).to(stats_dtype)

        n_tokens = flat_input.shape[0]
        if n_tokens == 0:
            # An empty batch carries no information; its mean would be NaN.
            continue

        batch_mean = flat_input.mean(dim=0)
        token_counts.append(n_tokens)
        batch_means.append(batch_mean)
        within_batch_ss.append((flat_input - batch_mean).pow(2).sum())
        residual_ss.append((flat_input - flat_output).pow(2).sum())

    if not token_counts:
        raise ValueError("Cannot compute explained variance without any tokens")

    n_total = sum(token_counts)
    means = torch.stack(batch_means)  # (n_batches, d_model)
    counts = torch.tensor(token_counts, dtype=means.dtype, device=means.device)

    # Token-weighted global mean, then the between-batch sum of squares.
    grand_mean = (counts.unsqueeze(-1) * means).sum(dim=0) / n_total
    between_batch_ss = (counts * (means - grand_mean).pow(2).sum(dim=-1)).sum()

    total_variance = (torch.stack(within_batch_ss).sum() + between_batch_ss) / n_total
    residual_variance = torch.stack(residual_ss).sum() / n_total

    eps = 1e-12
    if total_variance <= eps:
        # Constant input: explained variance is undefined, so report a perfect
        # score only when the reconstruction is also exact.
        return 1.0 if residual_variance <= eps else 0.0
    return (1 - residual_variance / total_variance).item()
# average over the batch dimension. This is handled later in the function. @@ -542,18 +571,8 @@ def get_sparsity_and_variance_metrics( ) explained_variance_legacy = 1 - resid_sum_of_squares / batched_variance_sum metric_dict["explained_variance_legacy"].append(explained_variance_legacy) - # Individual sums for the new (correct) formula. We're taking the mean over the batch - # dimension here to save memory, but we could also pass the full tensors and take the - # mean later (like we do for other metrics). - mean_sum_of_squares.append( - (flattened_sae_input).pow(2).sum(dim=-1).mean(dim=0) # scalar - ) - mean_act_per_dimension.append( - (flattened_sae_input).pow(2).mean(dim=0) # [d_model] - ) - mean_sum_of_resid_squared.append( - resid_sum_of_squares.mean(dim=0) # scalar - ) + variance_inputs.append(flattened_sae_input) + variance_outputs.append(flattened_sae_out) x_normed = flattened_sae_input / torch.norm( flattened_sae_input, dim=-1, keepdim=True @@ -581,11 +600,9 @@ def get_sparsity_and_variance_metrics( # calculate explained variance if compute_variance_metrics: - mean_sum_of_squares = torch.stack(mean_sum_of_squares).mean(dim=0) - mean_act_per_dimension = torch.cat(mean_act_per_dimension).mean(dim=0) - total_variance = mean_sum_of_squares - mean_act_per_dimension**2 - residual_variance = torch.stack(mean_sum_of_resid_squared).mean(dim=0) - metrics["explained_variance"] = (1 - residual_variance / total_variance).item() + metrics["explained_variance"] = compute_explained_variance( + variance_inputs, variance_outputs + ) # Aggregate feature-wise metrics feature_metrics: dict[str, list[float]] = {} diff --git a/tests/test_evals.py b/tests/test_evals.py index 8a285aa88..61b90795e 100644 --- a/tests/test_evals.py +++ b/tests/test_evals.py @@ -15,6 +15,7 @@ EvalConfig, _kl, all_loadable_saes, + compute_explained_variance, get_downstream_reconstruction_metrics, get_eval_everything_config, get_saes_from_regex, @@ -33,6 +34,7 @@ from sae_lens.saes.sae import SAE, 
def test_explained_variance_invariant_to_activation_shift(
    model: HookedTransformer,
    example_dataset: Dataset,
):
    """Shifting activations and b_dec together must leave explained_variance unchanged."""
    d_in = 64
    hook_name = "blocks.1.hook_resid_pre"
    cfg = build_runner_cfg(d_in=d_in, d_sae=2 * d_in, hook_name=hook_name)

    training_sae = StandardTrainingSAE.from_dict(cfg.get_training_sae_cfg_dict())
    sae = StandardSAE.from_dict(training_sae.cfg.get_inference_sae_cfg_dict())
    random_params(sae)
    # b_dec has to be subtracted from the input while encoding, otherwise the
    # constant offset would not cancel: encode(x + c) with b_dec + c must equal
    # encode(x) with b_dec.
    sae.cfg.apply_b_dec_to_input = True

    eval_kwargs: dict[str, Any] = {
        "sae": sae,
        "model": model,
        "activation_scaler": ActivationScaler(None),
        "n_batches": 3,
        "compute_l2_norms": False,
        "compute_sparsity_metrics": False,
        "compute_variance_metrics": True,
        "compute_featurewise_density_statistics": True,
        "eval_batch_size_prompts": 4,
        "model_kwargs": {},
    }

    def run_eval() -> Any:
        # A fresh store per run so both evaluations read the dataset from the start.
        store = ActivationsStore.from_config(
            model, cfg, override_dataset=example_dataset
        )
        metrics, _ = get_sparsity_and_variance_metrics(
            activation_store=store, **eval_kwargs
        )
        return metrics

    baseline_metrics = run_eval()

    # Offset the hooked activations by a constant and move b_dec by the same
    # amount: the encoder then sees identical features, while the SAE input and
    # output are both displaced equally.
    shift = torch.full((d_in,), 5.0)

    def add_shift(tensor, **_):
        return tensor + shift

    sae.b_dec.data.add_(shift)
    model.add_hook(hook_name, add_shift, is_permanent=True)
    try:
        shifted_metrics = run_eval()
    finally:
        model.reset_hooks(including_permanent=True)

    # Tolerance is limited by float32 catastrophic cancellation in E[||x||^2] - ||E[x]||^2
    expected = pytest.approx(baseline_metrics["explained_variance"], rel=1e-3)
    assert shifted_metrics["explained_variance"] == expected


def test_explained_variance_single_batch_matches_formula():
    """A single batch must agree with the direct population-variance formula."""
    n_rows, n_dims = 10000, 8
    acts = torch.randn(n_rows, n_dims) + 5.0
    recon = acts + torch.randn(n_rows, n_dims) * 0.3

    actual = compute_explained_variance([acts], [recon])

    variance_total = acts.var(dim=0, correction=0).sum()
    variance_resid = (acts - recon).pow(2).sum(dim=-1).mean()
    reference = (1 - variance_resid / variance_total).item()

    assert actual == pytest.approx(reference, rel=1e-5)


def test_explained_variance_batched_matches_unbatched():
    """Splitting the data into three equal chunks must not change the result."""
    n_rows, n_dims = 3000, 8
    acts = torch.randn(n_rows, n_dims) + 5.0
    recon = acts + torch.randn(n_rows, n_dims) * 0.3

    whole = compute_explained_variance([acts], [recon])
    chunked = compute_explained_variance(list(acts.chunk(3)), list(recon.chunk(3)))

    assert chunked == pytest.approx(whole, rel=1e-5)


def test_explained_variance_invariant_to_input_bias():
    """Adding the same bias to input and output must not change the result."""
    # float64 keeps the large bias from causing catastrophic cancellation
    n_rows, n_dims = 10000, 8
    acts = torch.randn(n_rows, n_dims, dtype=torch.float64)
    recon = acts + torch.randn(n_rows, n_dims, dtype=torch.float64) * 0.3

    unbiased = compute_explained_variance([acts], [recon])

    offset = torch.full((n_dims,), 100.0, dtype=torch.float64)
    biased = compute_explained_variance([acts + offset], [recon + offset])

    assert biased == pytest.approx(unbiased, rel=1e-10)


def test_explained_variance_zero_total_variance():
    """Constant inputs hit the zero-variance branch: 1.0 if exact, else 0.0."""
    constant = torch.full((100, 4), 5.0)

    # Constant input reproduced exactly -> 1.0
    exact_score = compute_explained_variance([constant], [constant])
    assert exact_score == 1.0

    # Constant input with a nonzero residual -> 0.0
    off_by_half = constant + 0.5
    inexact_score = compute_explained_variance([constant], [off_by_half])
    assert inexact_score == 0.0


def test_explained_variance_matches_synthetic_calculator():
    """The evals helper must agree with the synthetic ExplainedVarianceCalculator."""
    n_rows, n_dims, rows_per_batch = 3000, 8, 1000
    acts = torch.randn(n_rows, n_dims) + 5.0
    recon = acts + torch.randn(n_rows, n_dims) * 0.3

    act_batches = list(acts.split(rows_per_batch))
    recon_batches = list(recon.split(rows_per_batch))

    from_evals = compute_explained_variance(act_batches, recon_batches)

    calculator = ExplainedVarianceCalculator(hidden_dim=n_dims)
    for batch_in, batch_out in zip(act_batches, recon_batches):
        calculator.add_batch(sae_output=batch_out, hidden_acts=batch_in)
    from_synthetic = calculator.compute()

    assert from_evals == pytest.approx(from_synthetic, rel=1e-5)