Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 6 additions & 10 deletions sae_lens/saes/gated_sae.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,14 +92,12 @@ def decode(self, feature_acts: torch.Tensor) -> torch.Tensor:
@torch.no_grad()
def fold_W_dec_norm(self):
"""Override to handle gated-specific parameters."""
W_dec_norms = self.W_dec.norm(dim=-1).clamp(min=1e-8).unsqueeze(1)
self.W_dec.data = self.W_dec.data / W_dec_norms
self.W_enc.data = self.W_enc.data * W_dec_norms.T
W_dec_norms = self.fold_and_get_W_dec_norm()

# Gated-specific parameters need special handling
# r_mag doesn't need scaling since W_enc scaling is sufficient for magnitude path
self.b_gate.data = self.b_gate.data * W_dec_norms.squeeze()
self.b_mag.data = self.b_mag.data * W_dec_norms.squeeze()
self.b_gate.data = self.b_gate.data * W_dec_norms
self.b_mag.data = self.b_mag.data * W_dec_norms


@dataclass
Expand Down Expand Up @@ -226,14 +224,12 @@ def get_coefficients(self) -> dict[str, float | TrainCoefficientConfig]:
@torch.no_grad()
def fold_W_dec_norm(self):
"""Override to handle gated-specific parameters."""
W_dec_norms = self.W_dec.norm(dim=-1).clamp(min=1e-8).unsqueeze(1)
self.W_dec.data = self.W_dec.data / W_dec_norms
self.W_enc.data = self.W_enc.data * W_dec_norms.T
W_dec_norms = self.fold_and_get_W_dec_norm()

# Gated-specific parameters need special handling
# r_mag doesn't need scaling since W_enc scaling is sufficient for magnitude path
self.b_gate.data = self.b_gate.data * W_dec_norms.squeeze()
self.b_mag.data = self.b_mag.data * W_dec_norms.squeeze()
self.b_gate.data = self.b_gate.data * W_dec_norms
self.b_mag.data = self.b_mag.data * W_dec_norms


def _init_weights_gated(
Expand Down
19 changes: 7 additions & 12 deletions sae_lens/saes/jumprelu_sae.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,11 +171,8 @@ def fold_W_dec_norm(self):
# Save the current threshold before calling parent method
current_thresh = self.threshold.clone()

# Get W_dec norms that will be used for scaling (clamped to avoid division by zero)
W_dec_norms = self.W_dec.norm(dim=-1).clamp(min=1e-8)

# Call parent implementation to handle W_enc, W_dec, and b_enc adjustment
super().fold_W_dec_norm()
# Handle W_enc, W_dec, and b_enc adjustment
W_dec_norms = self.fold_and_get_W_dec_norm()

# Scale the threshold by the same factor as we scaled b_enc
# This ensures the same features remain active/inactive after folding
Expand Down Expand Up @@ -333,14 +330,12 @@ def fold_W_dec_norm(self):
# Save the current threshold before we call the parent method
current_thresh = self.threshold.clone()

# Get W_dec norms (clamped to avoid division by zero)
W_dec_norms = self.W_dec.norm(dim=-1).clamp(min=1e-8).unsqueeze(1)

# Call parent implementation to handle W_enc and W_dec adjustment
super().fold_W_dec_norm()
# Handle W_enc, W_dec, and b_enc adjustment
W_dec_norms = self.fold_and_get_W_dec_norm()

# Fix: Use squeeze() instead of squeeze(-1) to match old behavior
self.log_threshold.data = torch.log(current_thresh * W_dec_norms.squeeze())
# Scale the threshold by the same factor as we scaled b_enc
# This ensures the same features remain active/inactive after folding
self.log_threshold.data = torch.log(current_thresh * W_dec_norms)

@override
def process_state_dict_for_saving(self, state_dict: dict[str, Any]) -> None:
Expand Down
21 changes: 16 additions & 5 deletions sae_lens/saes/sae.py
Original file line number Diff line number Diff line change
Expand Up @@ -504,16 +504,27 @@ def process_state_dict_for_saving(self, state_dict: dict[str, Any]) -> None:
def process_state_dict_for_loading(self, state_dict: dict[str, Any]) -> None:
pass

@torch.no_grad()
def fold_W_dec_norm(self):
"""Fold decoder norms into encoder."""
W_dec_norms = self.W_dec.norm(dim=-1).clamp(min=1e-8).unsqueeze(1)
self.W_dec.data = self.W_dec.data / W_dec_norms
self.W_enc.data = self.W_enc.data * W_dec_norms.T
self.fold_and_get_W_dec_norm()

@torch.no_grad()
def get_W_dec_norm(self) -> torch.Tensor:
"""Get decoder norms."""
return self.W_dec.norm(dim=-1).clamp(min=1e-8)

@torch.no_grad()
def fold_and_get_W_dec_norm(self) -> torch.Tensor:
"""Fold decoder norms into encoder and return them."""
W_dec_norms = self.get_W_dec_norm()
self.W_dec.data = self.W_dec.data / W_dec_norms.unsqueeze(1)
self.W_enc.data = self.W_enc.data * W_dec_norms.unsqueeze(1).T

# Only update b_enc if it exists (standard/jumprelu architectures)
if hasattr(self, "b_enc") and isinstance(self.b_enc, nn.Parameter):
self.b_enc.data = self.b_enc.data * W_dec_norms.squeeze()
self.b_enc.data = self.b_enc.data * W_dec_norms

return W_dec_norms

def get_name(self):
"""Generate a name for this SAE."""
Expand Down
7 changes: 4 additions & 3 deletions sae_lens/saes/topk_sae.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,7 +295,7 @@ def fold_W_dec_norm(self) -> None:
raise NotImplementedError(
"Folding W_dec_norm is not safe for TopKSAEs when rescale_acts_by_decoder_norm is False, as this may change the topk activations"
)
_fold_norm_topk(W_dec=self.W_dec, b_enc=self.b_enc, W_enc=self.W_enc)
super().fold_W_dec_norm()


@dataclass
Expand Down Expand Up @@ -445,7 +445,7 @@ def fold_W_dec_norm(self) -> None:
raise NotImplementedError(
"Folding W_dec_norm is not safe for TopKSAEs when rescale_acts_by_decoder_norm is False, as this may change the topk activations"
)
_fold_norm_topk(W_dec=self.W_dec, b_enc=self.b_enc, W_enc=self.W_enc)
super().fold_W_dec_norm()

@override
def get_activation_fn(self) -> Callable[[torch.Tensor], torch.Tensor]:
Expand Down Expand Up @@ -515,6 +515,7 @@ def process_state_dict_for_saving_inference(
super().process_state_dict_for_saving_inference(state_dict)
if self.cfg.rescale_acts_by_decoder_norm:
_fold_norm_topk(
W_dec_norm=self.get_W_dec_norm(),
W_enc=state_dict["W_enc"],
b_enc=state_dict["b_enc"],
W_dec=state_dict["W_dec"],
Expand Down Expand Up @@ -558,11 +559,11 @@ def _init_weights_topk(


def _fold_norm_topk(
W_dec_norm: torch.Tensor,
W_enc: torch.Tensor,
b_enc: torch.Tensor,
W_dec: torch.Tensor,
) -> None:
W_dec_norm = W_dec.norm(dim=-1).clamp(min=1e-8)
b_enc.data = b_enc.data * W_dec_norm
W_dec_norms = W_dec_norm.unsqueeze(1)
W_dec.data = W_dec.data / W_dec_norms
Expand Down
8 changes: 2 additions & 6 deletions sae_lens/saes/transcoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,14 +367,10 @@ def fold_W_dec_norm(self) -> None:
This is important for JumpReLU as the threshold needs to be scaled
along with the decoder weights.
"""
# Get the decoder weight norms before normalizing
with torch.no_grad():
W_dec_norms = self.W_dec.norm(dim=1)

# Fold the decoder norms as in the parent class
super().fold_W_dec_norm()
W_dec_norms = self.fold_and_get_W_dec_norm()

# Scale the threshold by the decoder weight norms
# Scale the threshold by the same norms
with torch.no_grad():
self.threshold.data = self.threshold.data * W_dec_norms

Expand Down
Loading
Loading