From 6b0dba604e85161cbf0209e3efb6e10abb7bb553 Mon Sep 17 00:00:00 2001 From: Dan Raviv Date: Thu, 11 Jun 2026 15:37:34 -0700 Subject: [PATCH 1/2] doc: Fix docstrings/comment in TemporalSAE encoding --- sae_lens/saes/temporal_sae.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/sae_lens/saes/temporal_sae.py b/sae_lens/saes/temporal_sae.py index 9f624249..1b00e09b 100644 --- a/sae_lens/saes/temporal_sae.py +++ b/sae_lens/saes/temporal_sae.py @@ -252,10 +252,12 @@ def initialize_weights(self) -> None: def encode_with_predictions( self, x: torch.Tensor ) -> tuple[torch.Tensor, torch.Tensor]: - """Encode input to novel codes only. + """Encode input to novel and predicted codes. - Returns only the sparse novel codes (not predicted codes). - This is the main feature representation for TemporalSAE. + Returns a tuple (z_novel, z_pred), where z_novel contains the sparse + novel codes (the main feature representation for TemporalSAE) and + z_pred contains the codes predicted from context by the attention + layers. """ # Process input through SAELens preprocessing x = self.process_sae_in(x) @@ -309,11 +311,16 @@ def encode_with_predictions( mask.scatter_(-1, topk_indices, 1) z_novel = z_novel * mask - # Return only novel codes (these are the interpretable features) + # Return novel codes (the interpretable features) and predicted codes return z_novel, z_pred @override def encode(self, x: torch.Tensor) -> torch.Tensor: + """Encode input to novel codes only. + + Returns only the sparse novel codes (not predicted codes). + This is the main feature representation for TemporalSAE. + """ return self.encode_with_predictions(x)[0] @override From e4db4acda86abe3c655d94844f5a9e9637072f0b Mon Sep 17 00:00:00 2001 From: Dan Raviv Date: Thu, 11 Jun 2026 17:39:06 -0700 Subject: [PATCH 2/2] chore: Drop incorrect comment --- sae_lens/saes/temporal_sae.py | 1 - 1 file changed, 1 deletion(-) diff --git a/sae_lens/saes/temporal_sae.py b/sae_lens/saes/temporal_sae.py index 1b00e09b..1bbcd4f7 100644 --- a/sae_lens/saes/temporal_sae.py +++ b/sae_lens/saes/temporal_sae.py @@ -145,7 +145,6 @@ class TemporalSAEConfig(SAEConfig): activation_normalization_factor: float = 1.0 def __post_init__(self): - # Call parent's __post_init__ first, but allow constant_scalar_rescale if self.normalize_activations not in [ "none", "expected_average_only_in",