
Add missing calls to reshape_fn_out() in TemporalSAE #695

Open
danra wants to merge 2 commits into decoderesearch:main from danra:fix_temporal_sae_hook_z_reshape



danra commented Jun 7, 2026

Contributor

Type of change

Please delete options that are not relevant.

  • Bug fix (non-breaking change which fixes an issue)

Checklist:

  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing tests pass locally with my changes
  • I have not rewritten tests relating to key interfaces which would affect backward compatibility

You have tested formatting, typing and tests

  • I have run make check-ci to check format and linting. (you can run make format to format code if needed.)

danra and others added 2 commits June 6, 2026 23:29
…rward

Both methods are missing calls to reshape_fn_out()

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
Copilot AI review requested due to automatic review settings June 7, 2026 06:37

Copilot AI left a comment

Contributor


Pull request overview

Note

Copilot was unable to run its full agentic suite in this review.

Adds output reshaping for hook_z-mode TemporalSAE so that decoded and reconstructed activations return to the (batch, seq, n_heads, d_head) layout, and adds tests that validate the expected output shape.

Changes:

  • Reshape decoder/forward outputs via reshape_fn_out(..., d_head) to reverse hook_z input reshaping.
  • Add pytest coverage asserting decode/forward return to the original 4D hook_z shape.
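The round trip described above can be sketched without any dependencies. This is only an illustration in plain Python lists, not the sae_lens implementation: flatten_heads, unflatten_heads, and shape are hypothetical helper names, and the real reshape_fn_out operates on PyTorch tensors and may differ in detail. The dimensions mirror the ones used in the PR's tests.

```python
from typing import List

# Hypothetical stand-ins for the hook_z input/output reshaping.
# These are NOT sae_lens functions; they only illustrate the shape round trip.


def shape(x) -> tuple:
    """Return the nested-list dimensions, e.g. (batch, seq, n_heads, d_head)."""
    dims = []
    while isinstance(x, list):
        dims.append(len(x))
        x = x[0]
    return tuple(dims)


def flatten_heads(x: List) -> List:
    """(batch, seq, n_heads, d_head) -> (batch, seq, n_heads * d_head)."""
    return [[[v for head in token for v in head] for token in seq] for seq in x]


def unflatten_heads(x: List, d_head: int) -> List:
    """(batch, seq, n_heads * d_head) -> (batch, seq, n_heads, d_head)."""
    out = []
    for seq in x:
        new_seq = []
        for token in seq:
            if len(token) % d_head != 0:
                raise ValueError("last dimension is not a multiple of d_head")
            # Split the flat vector back into consecutive per-head chunks.
            new_seq.append(
                [token[i : i + d_head] for i in range(0, len(token), d_head)]
            )
        out.append(new_seq)
    return out


batch_size, seq_len, n_heads, d_head = 2, 6, 4, 8

# Distinct values so that an incorrect regrouping would be detectable.
x = [
    [
        [
            [float(b * 1000 + s * 100 + h * 10 + d) for d in range(d_head)]
            for h in range(n_heads)
        ]
        for s in range(seq_len)
    ]
    for b in range(batch_size)
]

flat = flatten_heads(x)                   # what the SAE works on internally
restored = unflatten_heads(flat, d_head)  # what decode/forward should return

print(shape(x), shape(flat), shape(restored))
```

Without the final unflatten step, a caller passing a 4D hook_z activation would get a 3D tensor back, which is the mismatch the new tests are meant to catch.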

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 6 comments.

| File | Description |
| --- | --- |
| tests/saes/test_temporal_sae.py | Adds shape-based tests for decode/forward when reshape_activations="hook_z". |
| sae_lens/saes/temporal_sae.py | Applies reshape_fn_out in decode and forward to reverse hook_z reshaping. |


Comment on lines 331 to 337
# Apply output activation normalization (reverses input normalization)
sae_out = self.run_time_activation_norm_fn_out(sae_out)

sae_out = self.reshape_fn_out(sae_out, self.d_head)

# Add bias (already removed in process_sae_in)
logger.warning(
assert reconstruction.shape == (batch_size, seq_len, cfg.d_in)


def test_TemporalSAE_decode_reverses_hook_z_reshaping():
Comment on lines +96 to +99
    assert reconstruction.shape == x.shape


def test_TemporalSAE_forward_reverses_hook_z_reshaping():

    reconstruction = sae(x)

    assert reconstruction.shape == x.shape
Comment on lines +80 to +91
def test_TemporalSAE_decode_reverses_hook_z_reshaping():
    n_heads = 4
    d_head = 8
    cfg = build_temporal_sae_cfg(d_in=n_heads * d_head, reshape_activations="hook_z")
    sae = TemporalSAE.from_dict(cfg.to_dict())
    assert sae.hook_z_reshaping_mode

    batch_size = 2
    seq_len = 6
    x = torch.randn(
        batch_size, seq_len, n_heads, d_head, dtype=DTYPE_MAP[sae.cfg.dtype]
    )
Comment on lines +99 to +110
def test_TemporalSAE_forward_reverses_hook_z_reshaping():
    n_heads = 4
    d_head = 8
    cfg = build_temporal_sae_cfg(d_in=n_heads * d_head, reshape_activations="hook_z")
    sae = TemporalSAE.from_dict(cfg.to_dict())
    assert sae.hook_z_reshaping_mode

    batch_size = 2
    seq_len = 6
    x = torch.randn(
        batch_size, seq_len, n_heads, d_head, dtype=DTYPE_MAP[sae.cfg.dtype]
    )

2 participants