Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 11 additions & 5 deletions sae_lens/saes/temporal_sae.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,6 @@ class TemporalSAEConfig(SAEConfig):
activation_normalization_factor: float = 1.0

def __post_init__(self):
# Call parent's __post_init__ first, but allow constant_scalar_rescale
if self.normalize_activations not in [
"none",
"expected_average_only_in",
Expand Down Expand Up @@ -252,10 +251,12 @@ def initialize_weights(self) -> None:
def encode_with_predictions(
self, x: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
"""Encode input to novel codes only.
"""Encode input to novel and predicted codes.

Returns only the sparse novel codes (not predicted codes).
This is the main feature representation for TemporalSAE.
Returns a tuple (z_novel, z_pred), where z_novel contains the sparse
novel codes (the main feature representation for TemporalSAE) and
z_pred contains the codes predicted from context by the attention
layers.
"""
# Process input through SAELens preprocessing
x = self.process_sae_in(x)
Expand Down Expand Up @@ -309,11 +310,16 @@ def encode_with_predictions(
mask.scatter_(-1, topk_indices, 1)
z_novel = z_novel * mask

# Return only novel codes (these are the interpretable features)
# Return novel codes (the interpretable features) and predicted codes
return z_novel, z_pred

@override
def encode(self, x: torch.Tensor) -> torch.Tensor:
"""Encode input to novel codes only.

Returns only the sparse novel codes (not predicted codes).
This is the main feature representation for TemporalSAE.
"""
return self.encode_with_predictions(x)[0]

@override
Expand Down
Loading