diff --git a/sae_lens/saes/temporal_sae.py b/sae_lens/saes/temporal_sae.py index 9f624249..1bbcd4f7 100644 --- a/sae_lens/saes/temporal_sae.py +++ b/sae_lens/saes/temporal_sae.py @@ -145,7 +145,6 @@ class TemporalSAEConfig(SAEConfig): activation_normalization_factor: float = 1.0 def __post_init__(self): - # Call parent's __post_init__ first, but allow constant_scalar_rescale if self.normalize_activations not in [ "none", "expected_average_only_in", @@ -252,10 +251,12 @@ def initialize_weights(self) -> None: def encode_with_predictions( self, x: torch.Tensor ) -> tuple[torch.Tensor, torch.Tensor]: - """Encode input to novel codes only. + """Encode input to novel and predicted codes. - Returns only the sparse novel codes (not predicted codes). - This is the main feature representation for TemporalSAE. + Returns a tuple (z_novel, z_pred), where z_novel contains the sparse + novel codes (the main feature representation for TemporalSAE) and + z_pred contains the codes predicted from context by the attention + layers. """ # Process input through SAELens preprocessing x = self.process_sae_in(x) @@ -309,11 +310,16 @@ def encode_with_predictions( mask.scatter_(-1, topk_indices, 1) z_novel = z_novel * mask - # Return only novel codes (these are the interpretable features) + # Return novel codes (the interpretable features) and predicted codes return z_novel, z_pred @override def encode(self, x: torch.Tensor) -> torch.Tensor: + """Encode input to novel codes only. + + Returns only the sparse novel codes (not predicted codes). + This is the main feature representation for TemporalSAE. + """ return self.encode_with_predictions(x)[0] @override