Add missing hooks #704

Merged
chanind merged 2 commits into decoderesearch:main from danra:add_missing_hooks
Jun 16, 2026

Conversation

@danra danra commented Jun 12, 2026

Contributor

Add missing hook_sae_acts_pre/post in JumpReLUTrainingSAE and TemporalSAE

  • Bug fix (non-breaking change which fixes an issue)

Checklist:

  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing tests pass locally with my changes
  • I have not rewritten tests relating to key interfaces which would affect backward compatibility

You have tested formatting, typing and tests

  • I have run make check-ci to check format and linting. (you can run make format to format code if needed.)

Copilot AI review requested due to automatic review settings June 12, 2026 21:37

Copilot AI left a comment

Contributor

Pull request overview

Note

Copilot was unable to run its full agentic suite in this review.

Adds/extends hook-point coverage across multiple SAE variants (including TemporalSAE and JumpReLU) and introduces regression tests to ensure hook caches are populated with expected tensors during forward passes.

Changes:

  • Add hook_sae_acts_pre / hook_sae_acts_post calls to TemporalSAE and JumpReLU encode paths.
  • Add tests across SAEs/Transcoders verifying hook cache keys and (where feasible) cached values.
  • Minor test import/type updates to support new parametrized coverage.
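As a rough illustration of why a missing hook call matters, here is a minimal pure-Python sketch of the pattern. `ToyHookPoint` and `ToyReLUSAE` are hypothetical stand-ins written for this example, not the sae_lens classes; only the hook names and the `run_with_cache` return shape `(output, cache)` are taken from the snippets in this PR.

```python
from typing import Callable, Dict, List, Optional, Tuple

Vector = List[float]


class ToyHookPoint:
    """Hypothetical hook point: an identity function that can record what passes through."""

    def __init__(self, name: str) -> None:
        self.name = name
        self.recorder: Optional[Callable[[str, Vector], None]] = None

    def __call__(self, value: Vector) -> Vector:
        if self.recorder is not None:
            # Store a copy so later in-place edits cannot change the cached value.
            self.recorder(self.name, list(value))
        return value


class ToyReLUSAE:
    """Toy ReLU encoder with pre/post activation hooks (illustration only)."""

    def __init__(self, w_enc: List[Vector], b_enc: Vector) -> None:
        self.w_enc = w_enc  # d_in rows, d_sae columns
        self.b_enc = b_enc
        self.hook_sae_acts_pre = ToyHookPoint("hook_sae_acts_pre")
        self.hook_sae_acts_post = ToyHookPoint("hook_sae_acts_post")

    def encode(self, x: Vector) -> Vector:
        pre = [
            sum(x[i] * self.w_enc[i][j] for i in range(len(x))) + self.b_enc[j]
            for j in range(len(self.b_enc))
        ]
        # If either call below were omitted, the matching cache key would be absent.
        pre = self.hook_sae_acts_pre(pre)
        post = [max(0.0, v) for v in pre]
        return self.hook_sae_acts_post(post)

    def run_with_cache(self, x: Vector) -> Tuple[Vector, Dict[str, Vector]]:
        cache: Dict[str, Vector] = {}
        hooks = [self.hook_sae_acts_pre, self.hook_sae_acts_post]
        for hook in hooks:
            hook.recorder = cache.__setitem__
        try:
            out = self.encode(x)
        finally:
            # Always detach the recorder so later calls do not write to this cache.
            for hook in hooks:
                hook.recorder = None
        return out, cache


sae = ToyReLUSAE(w_enc=[[1.0, -1.0], [0.5, 0.5]], b_enc=[0.0, -2.0])
out, cache = sae.run_with_cache([2.0, 2.0])
```

The regression tests in this PR follow the same idea: run a forward pass with caching enabled, then assert that each expected key is present and, where feasible, that the cached value matches an independently computed one.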

Reviewed changes

Copilot reviewed 10 out of 10 changed files in this pull request and generated 5 comments.

| File | Description |
| --- | --- |
| tests/saes/test_transcoder.py | Adds parametrized test asserting transcoder hook caches are populated and consistent. |
| tests/saes/test_topk_sae.py | Adds hook-cache assertions for TopK SAE training and inference paths. |
| tests/saes/test_temporal_sae.py | Adds TemporalSAE hook-cache regression test and needed helper import. |
| tests/saes/test_standard_sae.py | Adds hook-cache assertions for Standard SAE training and inference paths; imports TrainStepInput. |
| tests/saes/test_matryoshka_batchtopk_sae.py | Adds hook-cache assertions for Matryoshka BatchTopK training path. |
| tests/saes/test_jumprelu_sae.py | Extends existing training test to assert hook-cache contents. |
| tests/saes/test_gated_sae.py | Adds hook-cache assertions for Gated training path. |
| tests/saes/test_batchtopk_sae.py | Adds hook-cache assertions for BatchTopK training path. |
| sae_lens/saes/temporal_sae.py | Wires hook_sae_acts_pre/hook_sae_acts_post into novel-code path in TemporalSAE. |
| sae_lens/saes/jumprelu_sae.py | Wires hook_sae_acts_pre/hook_sae_acts_post into JumpReLU encode path. |

Comment on lines +303 to +304
hidden_pre = self.hook_sae_acts_pre(torch.matmul(x_residual * self.lam, W_enc))
z_novel = F.relu(hidden_pre)
mask = torch.zeros_like(z_novel)
mask.scatter_(-1, topk_indices, 1)
z_novel = z_novel * mask
z_novel = self.hook_sae_acts_post(z_novel)
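
The ReLU-then-mask step in the snippet above can be sketched in plain Python. This is a simplified illustration, not the TemporalSAE code: `relu_then_mask` is a hypothetical helper, and the indices to keep are passed in directly because the way `topk_indices` is computed is not shown in this excerpt.

```python
from typing import Iterable, List


def relu_then_mask(hidden_pre: List[float], keep_indices: Iterable[int]) -> List[float]:
    """Apply ReLU, then zero every position that is not in keep_indices."""
    activated = [max(0.0, v) for v in hidden_pre]
    keep = set(keep_indices)
    # Equivalent in spirit to building a 0/1 mask and multiplying elementwise.
    return [v if i in keep else 0.0 for i, v in enumerate(activated)]


z_novel = relu_then_mask([0.5, -1.0, 2.0, 1.5], keep_indices=[2, 3])
```

In the diff, `hook_sae_acts_pre` wraps the value before the ReLU and `hook_sae_acts_post` wraps the value after the mask is applied, so the two cache entries bracket this whole step.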
Comment thread tests/saes/test_standard_sae.py Outdated
Comment on lines +895 to +908
def test_StandardTrainingSAE_training_forward_pass_calls_hooks():
    sae = StandardTrainingSAE(build_sae_training_cfg())
    x = torch.randn(32, sae.cfg.d_in)
    train_step_output = sae.training_forward_pass(
        step_input=TrainStepInput(
            sae_in=x,
            coefficients={"l1": sae.cfg.l1_coefficient},
            dead_neuron_mask=None,
            n_training_steps=0,
            is_logging_step=False,
        ),
    )

    _, cache = sae.run_with_cache(x)
Comment on lines +45 to +52
def test_TemporalSAE_forward_pass_calls_hooks():
    sae = TemporalSAE(build_temporal_sae_cfg(dtype="float32"))
    x = torch.randn(4, 16, sae.cfg.d_in)
    out, cache = sae.run_with_cache(x)
    assert_close(cache["hook_sae_input"], x)
    assert "hook_sae_acts_pre" in cache
    assert_close(cache["hook_sae_acts_post"], sae.encode(x))
    assert_close(cache["hook_sae_output"], out)
Comment thread tests/saes/test_topk_sae.py Outdated
Comment on lines +482 to +504
def test_TopKTrainingSAE_training_forward_pass_calls_hooks():
    sae = TopKTrainingSAE(build_topk_sae_training_cfg())
    x = torch.randn(32, sae.cfg.d_in)
    train_step_output = sae.training_forward_pass(
        step_input=TrainStepInput(
            sae_in=x,
            coefficients={},
            dead_neuron_mask=None,
            n_training_steps=0,
            is_logging_step=False,
        ),
    )

    _, cache = sae.run_with_cache(x)
    assert_close(cache["hook_sae_input"], x)
    # topk rescales hidden_pre by the decoder norm after hook_sae_acts_pre fires,
    # so the hook captures the raw pre-activation, not train_step_output.hidden_pre
    assert_close(
        cache["hook_sae_acts_pre"], sae.process_sae_in(x) @ sae.W_enc + sae.b_enc
    )
    assert_close(cache["hook_sae_acts_post"], train_step_output.feature_acts)
    assert_close(cache["hook_sae_recons"], train_step_output.sae_out)
    assert_close(cache["hook_sae_output"], train_step_output.sae_out)
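
The comment in the test above describes an ordering effect: a hook records whatever value exists at the moment it fires, so any rescaling applied afterwards is not reflected in the cache. A toy sketch of that ordering follows. `toy_encode_with_rescale` is hypothetical, and the elementwise multiplication is only an assumed form of the rescaling; the actual operation in sae_lens is not shown in this excerpt.

```python
from typing import Dict, List


def toy_encode_with_rescale(
    raw_pre: List[float],
    decoder_norms: List[float],
    cache: Dict[str, List[float]],
) -> List[float]:
    """Record the raw pre-activation, then return a rescaled copy."""
    # The hook fires here, before any rescaling, and stores a copy.
    cache["hook_sae_acts_pre"] = list(raw_pre)
    # Assumed rescaling step (hypothetical form): scale each feature by its norm.
    return [v * n for v, n in zip(raw_pre, decoder_norms)]


cache: Dict[str, List[float]] = {}
rescaled = toy_encode_with_rescale([1.0, -2.0], [2.0, 0.5], cache)
```

Here `cache["hook_sae_acts_pre"]` holds the raw values while `rescaled` holds the scaled ones, which is why the test compares the cached tensor against a recomputed raw pre-activation rather than `train_step_output.hidden_pre`.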
danra and others added 2 commits June 12, 2026 20:54
…ingSAE and TemporalSAE fail

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
@danra danra force-pushed the add_missing_hooks branch from acf95b5 to 899c564 Compare June 13, 2026 04:03

@chanind chanind left a comment

Collaborator

Good catch! Thank you for this!

@chanind chanind merged commit a2eb4b2 into decoderesearch:main Jun 16, 2026
4 of 5 checks passed
3 participants