TemporalSAE: don't apply decoding bias if weights are tied and bias wasn't applied at encoding #703
Open
danra wants to merge 2 commits into
Pull request overview
Note
Copilot was unable to run its full agentic suite in this review.
Expands TemporalSAE’s bias-handling logic to support a new apply_b_dec_to_input mode (especially relevant when weights are tied) and updates tests to cover the new configuration combinations.
Changes:
- Conditioned adding `b_dec` in `TemporalSAE.decode()` and `TemporalSAE.forward()` based on `tied_weights` + `apply_b_dec_to_input`.
- Broadened TemporalSAE unit tests to parametrize over `apply_b_dec_to_input` (and `tied_weights` for decode).
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 4 comments.
| File | Description |
|---|---|
| tests/saes/test_temporal_sae.py | Expands test parametrization to cover apply_b_dec_to_input (and tied_weights for decode). |
| sae_lens/saes/temporal_sae.py | Makes decoder bias application conditional based on config flags. |
Comment on lines +328 to +329

    if not self.cfg.tied_weights or self.cfg.apply_b_dec_to_input:
        sae_out = sae_out + self.b_dec
Comment on lines +352 to +354

    x_recons = torch.matmul(z_novel + z_pred, self.W_dec)
    if not self.cfg.tied_weights or self.cfg.apply_b_dec_to_input:
        x_recons = x_recons + self.b_dec

      # Decode novel codes
      sae_out = torch.matmul(feature_acts, self.W_dec)
    - sae_out = sae_out + self.b_dec
    + if not self.cfg.tied_weights or self.cfg.apply_b_dec_to_input:
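
The effect of this change on the decode path can be illustrated with a minimal, dependency-free sketch. `toy_decode` and its list-based matrix product are hypothetical stand-ins for illustration, not the code in `sae_lens/saes/temporal_sae.py`; only the `if` condition is taken from the diff.

```python
def toy_decode(feature_acts, W_dec, b_dec, tied_weights, apply_b_dec_to_input):
    """Hypothetical toy decode: feature_acts @ W_dec, plus b_dec only when warranted."""
    d_in = len(b_dec)
    # Plain-Python equivalent of a (d_sae,) x (d_sae, d_in) matrix product.
    sae_out = [
        sum(act * W_dec[i][j] for i, act in enumerate(feature_acts))
        for j in range(d_in)
    ]
    # Condition copied from the diff: skip the bias only when weights are tied
    # and the bias was not subtracted from the input at encoding time.
    if not tied_weights or apply_b_dec_to_input:
        sae_out = [out + bias for out, bias in zip(sae_out, b_dec)]
    return sae_out


# Example inputs chosen for illustration only.
acts = [1.0, 2.0]
W = [[1.0, 0.0], [0.0, 1.0]]
b = [0.5, 0.5]

tied_no_bias = toy_decode(acts, W, b, tied_weights=True, apply_b_dec_to_input=False)
untied = toy_decode(acts, W, b, tied_weights=False, apply_b_dec_to_input=False)
```

With tied weights and `apply_b_dec_to_input=False`, the output is the bare matrix product; in every other configuration `b_dec` is added back.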
Description
TemporalSAE previously added the decoder bias unconditionally when decoding. The bias should be added only if the encoder/decoder weights are untied or, when they are tied, if the bias was also subtracted at encoding time.
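
As a quick sanity check, the rule can be enumerated over all four flag combinations. This is a sketch only: `ToyCfg` and `should_add_b_dec` are hypothetical names introduced for illustration, and only the two flag names and the boolean condition come from this PR.

```python
from dataclasses import dataclass
from itertools import product


@dataclass
class ToyCfg:
    # Hypothetical stand-in for the TemporalSAE config; only the two flags discussed here.
    tied_weights: bool
    apply_b_dec_to_input: bool


def should_add_b_dec(cfg: ToyCfg) -> bool:
    """Return True when the decoder bias should be added at decoding time."""
    return (not cfg.tied_weights) or cfg.apply_b_dec_to_input


# Map (tied_weights, apply_b_dec_to_input) -> whether b_dec is added on decode.
truth_table = {
    (tied, apply): should_add_b_dec(ToyCfg(tied, apply))
    for tied, apply in product([False, True], repeat=2)
}
```

Only the combination `tied_weights=True`, `apply_b_dec_to_input=False` skips the bias; the other three keep the previous behaviour.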
Type of change
Please delete options that are not relevant.
Checklist:
You have tested formatting, typing and tests
`make check-ci` to check format and linting. (You can run `make format` to format code if needed.)