Skip to content
57 changes: 37 additions & 20 deletions sae_lens/evals.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,6 +381,36 @@ def get_downstream_reconstruction_metrics(
return metrics


def compute_explained_variance(
    input_batches: list[torch.Tensor],
    output_batches: list[torch.Tensor],
) -> float:
    """Compute explained variance from lists of input/output tensor pairs.

    Each tensor has shape (n_tokens, d_model). Total variance is computed as
    Var(X) = E[||X||^2] - ||E[X]||^2, summed across dimensions, and the
    residual variance as E[||X - X_hat||^2]. Expectations are taken over all
    tokens of all batches: per-batch means are combined using weights
    proportional to each batch's token count, so batches of different sizes
    (e.g. a shorter final batch) give the same result as one concatenated
    batch.

    Args:
        input_batches: Input activations, one (n_tokens, d_model) tensor per batch.
        output_batches: Reconstructions, aligned one-to-one with ``input_batches``.

    Returns:
        ``1 - residual_variance / total_variance``. When the total variance is
        (numerically) zero, returns 1.0 if the residual is also zero, else 0.0.

    Raises:
        ValueError: If the two lists differ in length, or if they contain no
            tokens at all.
    """
    if len(input_batches) != len(output_batches):
        raise ValueError(
            "input_batches and output_batches must have the same length, got "
            f"{len(input_batches)} and {len(output_batches)}"
        )

    token_counts: list[int] = []
    mean_sum_of_squares: list[torch.Tensor] = []
    mean_act_per_dimension: list[torch.Tensor] = []
    mean_sum_of_resid_squared: list[torch.Tensor] = []

    for sae_input, sae_output in zip(input_batches, output_batches):
        n_batch_tokens = sae_input.shape[0]
        if n_batch_tokens == 0:
            # The mean over zero tokens is NaN; an empty batch carries no information.
            continue
        token_counts.append(n_batch_tokens)
        mean_sum_of_squares.append(sae_input.pow(2).sum(dim=-1).mean(dim=0))
        mean_act_per_dimension.append(sae_input.mean(dim=0))
        resid_ss = (sae_input - sae_output).pow(2).sum(dim=-1)
        mean_sum_of_resid_squared.append(resid_ss.mean(dim=0))

    n_tokens = sum(token_counts)
    if n_tokens == 0:
        raise ValueError("Cannot compute explained variance without any tokens")

    stacked_ss = torch.stack(mean_sum_of_squares)
    # Weight every batch by its share of the tokens so that the combined
    # statistics equal the statistics of the concatenated data.
    weights = torch.tensor(
        [count / n_tokens for count in token_counts],
        dtype=stacked_ss.dtype,
        device=stacked_ss.device,
    )

    total_mean_ss = (stacked_ss * weights).sum(dim=0)
    total_mean_act = (
        torch.stack(mean_act_per_dimension) * weights.unsqueeze(-1)
    ).sum(dim=0)
    total_variance = total_mean_ss - (total_mean_act**2).sum()
    residual_variance = (torch.stack(mean_sum_of_resid_squared) * weights).sum(dim=0)

    eps = 1e-12
    if torch.abs(total_variance) <= eps:
        # Constant input: the ratio is undefined, so report perfect or no fit.
        return 1.0 if torch.abs(residual_variance) <= eps else 0.0
    return (1 - residual_variance / total_variance).item()


def get_sparsity_and_variance_metrics(
sae: SAE[Any],
model: HookedRootModule,
Expand Down Expand Up @@ -411,9 +441,8 @@ def get_sparsity_and_variance_metrics(
metric_dict["l0"] = []
metric_dict["l1"] = []

mean_sum_of_squares = [] # for explained variance
mean_act_per_dimension = [] # for explained variance
mean_sum_of_resid_squared = [] # for explained variance
variance_inputs: list[torch.Tensor] = []
variance_outputs: list[torch.Tensor] = []
if compute_variance_metrics:
# explained_variance is left out of the dict here, since we don't want to naively
# average over the batch dimension. This is handled later in the function.
Expand Down Expand Up @@ -542,18 +571,8 @@ def get_sparsity_and_variance_metrics(
)
explained_variance_legacy = 1 - resid_sum_of_squares / batched_variance_sum
metric_dict["explained_variance_legacy"].append(explained_variance_legacy)
# Individual sums for the new (correct) formula. We're taking the mean over the batch
# dimension here to save memory, but we could also pass the full tensors and take the
# mean later (like we do for other metrics).
mean_sum_of_squares.append(
(flattened_sae_input).pow(2).sum(dim=-1).mean(dim=0) # scalar
)
mean_act_per_dimension.append(
(flattened_sae_input).pow(2).mean(dim=0) # [d_model]
)
mean_sum_of_resid_squared.append(
resid_sum_of_squares.mean(dim=0) # scalar
)
variance_inputs.append(flattened_sae_input)
variance_outputs.append(flattened_sae_out)

x_normed = flattened_sae_input / torch.norm(
flattened_sae_input, dim=-1, keepdim=True
Expand Down Expand Up @@ -581,11 +600,9 @@ def get_sparsity_and_variance_metrics(

# calculate explained variance
if compute_variance_metrics:
mean_sum_of_squares = torch.stack(mean_sum_of_squares).mean(dim=0)
mean_act_per_dimension = torch.cat(mean_act_per_dimension).mean(dim=0)
total_variance = mean_sum_of_squares - mean_act_per_dimension**2
residual_variance = torch.stack(mean_sum_of_resid_squared).mean(dim=0)
metrics["explained_variance"] = (1 - residual_variance / total_variance).item()
metrics["explained_variance"] = compute_explained_variance(
variance_inputs, variance_outputs
)

# Aggregate feature-wise metrics
feature_metrics: dict[str, list[float]] = {}
Expand Down
137 changes: 137 additions & 0 deletions tests/test_evals.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
EvalConfig,
_kl,
all_loadable_saes,
compute_explained_variance,
get_downstream_reconstruction_metrics,
get_eval_everything_config,
get_saes_from_regex,
Expand All @@ -33,6 +34,7 @@
from sae_lens.saes.sae import SAE, TrainingSAE
from sae_lens.saes.standard_sae import StandardSAE, StandardTrainingSAE
from sae_lens.saes.topk_sae import TopKTrainingSAE
from sae_lens.synthetic.evals import ExplainedVarianceCalculator
from sae_lens.training.activation_scaler import ActivationScaler
from sae_lens.training.activations_store import ActivationsStore
from tests.helpers import (
Expand Down Expand Up @@ -644,6 +646,141 @@ def test_get_sparsity_and_variance_metrics_identity_sae_perfect_reconstruction(
assert metrics["mse"] == pytest.approx(0.0, abs=1e-5)


def test_explained_variance_invariant_to_activation_shift(
    model: HookedTransformer,
    example_dataset: Dataset,
):
    """explained_variance should not change when activations and b_dec shift together."""
    d_in = 64
    hook_name = "blocks.1.hook_resid_pre"
    cfg = build_runner_cfg(d_in=d_in, d_sae=2 * d_in, hook_name=hook_name)

    training_sae = StandardTrainingSAE.from_dict(cfg.get_training_sae_cfg_dict())
    sae = StandardSAE.from_dict(training_sae.cfg.get_inference_sae_cfg_dict())
    random_params(sae)
    # b_dec must be subtracted from the input while encoding, otherwise the
    # shift does not cancel: encode(x+c) with b_dec+c == encode(x) with b_dec.
    sae.cfg.apply_b_dec_to_input = True

    # Arguments shared by the baseline and the shifted evaluation.
    eval_kwargs: dict[str, Any] = {
        "sae": sae,
        "model": model,
        "activation_scaler": ActivationScaler(None),
        "n_batches": 3,
        "compute_l2_norms": False,
        "compute_sparsity_metrics": False,
        "compute_variance_metrics": True,
        "compute_featurewise_density_statistics": True,
        "eval_batch_size_prompts": 4,
        "model_kwargs": {},
    }

    baseline_store = ActivationsStore.from_config(
        model, cfg, override_dataset=example_dataset
    )
    metrics_original, _ = get_sparsity_and_variance_metrics(
        activation_store=baseline_store, **eval_kwargs
    )

    # Move the activations by a constant and b_dec by the same constant. The
    # shift cancels in encoding (features are identical) and the input and the
    # output move equally, so explained_variance should stay the same.
    shift = torch.full((d_in,), 5.0)
    sae.b_dec.data.add_(shift)

    def _add_shift(tensor, **_):
        return tensor + shift

    model.add_hook(hook_name, _add_shift, is_permanent=True)
    try:
        shifted_store = ActivationsStore.from_config(
            model, cfg, override_dataset=example_dataset
        )
        metrics_shifted, _ = get_sparsity_and_variance_metrics(
            activation_store=shifted_store, **eval_kwargs
        )
    finally:
        # Always drop the permanent hook again, even if the evaluation raised.
        model.reset_hooks(including_permanent=True)

    # Tolerance is limited by float32 catastrophic cancellation in E[||x||^2] - ||E[x]||^2
    expected = pytest.approx(metrics_original["explained_variance"], rel=1e-3)
    assert metrics_shifted["explained_variance"] == expected


def test_explained_variance_single_batch_matches_formula():
    """A single batch must agree with 1 - E[||x - x_hat||^2] / sum_d Var(x_d)."""
    d_model, n_samples = 8, 10000
    inputs = torch.randn(n_samples, d_model) + 5.0
    noise = torch.randn(n_samples, d_model) * 0.3
    reconstruction = inputs + noise

    actual = compute_explained_variance([inputs], [reconstruction])

    # Reference value computed directly from the population variance per dimension.
    per_token_residual = (inputs - reconstruction).pow(2).sum(dim=-1)
    residual_var = per_token_residual.mean()
    total_var = inputs.var(dim=0, correction=0).sum()
    expected = (1 - residual_var / total_var).item()

    assert actual == pytest.approx(expected, rel=1e-5)


def test_explained_variance_batched_matches_unbatched():
    """Splitting the data into three chunks must not change the result."""
    d_model, n_samples = 8, 3000
    inputs = torch.randn(n_samples, d_model) + 5.0
    reconstruction = inputs + torch.randn(n_samples, d_model) * 0.3

    unbatched = compute_explained_variance([inputs], [reconstruction])
    batched = compute_explained_variance(
        [*inputs.chunk(3)],
        [*reconstruction.chunk(3)],
    )

    assert batched == pytest.approx(unbatched, rel=1e-5)


def test_explained_variance_invariant_to_input_bias():
    """Adding the same constant bias to input and output leaves the metric unchanged."""
    # float64 keeps catastrophic cancellation negligible even with a large bias
    dtype = torch.float64
    d_model, n_samples = 8, 10000
    inputs = torch.randn(n_samples, d_model, dtype=dtype)
    reconstruction = inputs + torch.randn(n_samples, d_model, dtype=dtype) * 0.3

    unbiased = compute_explained_variance([inputs], [reconstruction])

    offset = torch.full((d_model,), 100.0, dtype=dtype)
    biased = compute_explained_variance([inputs + offset], [reconstruction + offset])

    assert biased == pytest.approx(unbiased, rel=1e-10)


def test_explained_variance_zero_total_variance():
    """Constant inputs (zero total variance) take the special-cased 1.0 / 0.0 path."""
    d_model = 4
    n_samples = 100
    x = torch.full((n_samples, d_model), 5.0)

    # Constant input, perfect reconstruction -> 1.0
    assert compute_explained_variance([x], [x]) == 1.0

    # Constant input, nonzero residual -> 0.0
    x_hat = x + 0.5
    assert compute_explained_variance([x], [x_hat]) == 0.0


def test_explained_variance_matches_synthetic_calculator():
    """compute_explained_variance must agree with ExplainedVarianceCalculator."""
    d_model, n_samples, batch_size = 8, 3000, 1000
    inputs = torch.randn(n_samples, d_model) + 5.0
    reconstruction = inputs + torch.randn(n_samples, d_model) * 0.3

    input_batches = [*inputs.split(batch_size)]
    output_batches = [*reconstruction.split(batch_size)]

    ev_evals = compute_explained_variance(input_batches, output_batches)

    # Feed the very same batches to the synthetic calculator, one at a time.
    calculator = ExplainedVarianceCalculator(hidden_dim=d_model)
    for batch_in, batch_out in zip(input_batches, output_batches):
        calculator.add_batch(sae_output=batch_out, hidden_acts=batch_in)
    ev_synthetic = calculator.compute()

    assert ev_evals == pytest.approx(ev_synthetic, rel=1e-5)


def test_process_args():
args = [
"gpt2-small-res_scefr-ajt",
Expand Down
Loading