diff --git a/sae_lens/sae.py b/sae_lens/sae.py index 1610090c3..111148ea7 100644 --- a/sae_lens/sae.py +++ b/sae_lens/sae.py @@ -67,7 +67,6 @@ class SAEConfig: @classmethod def from_dict(cls, config_dict: dict[str, Any]) -> "SAEConfig": - # rename dict: rename_dict = { # old : new "hook_point": "hook_name", @@ -155,17 +154,8 @@ def __init__( self.device = torch.device(cfg.device) self.use_error_term = use_error_term - if self.cfg.architecture == "standard": - self.initialize_weights_basic() - self.encode = self.encode_standard - elif self.cfg.architecture == "gated": - self.initialize_weights_gated() - self.encode = self.encode_gated - elif self.cfg.architecture == "jumprelu": - self.initialize_weights_jumprelu() - self.encode = self.encode_jumprelu - else: - raise (ValueError) + if self.cfg.architecture not in ["standard", "gated", "jumprelu"]: + raise ValueError(f"Architecture {self.cfg.architecture} not supported") # handle presence / absence of scaling factor. if self.cfg.finetuning_scaling_factor: @@ -196,7 +186,6 @@ def __init__( # handle run time activation normalization if needed: if self.cfg.normalize_activations == "constant_norm_rescale": - # we need to scale the norm of the input and store the scaling factor def run_time_activation_norm_fn_in(x: torch.Tensor) -> torch.Tensor: self.x_norm_coeff = (self.cfg.d_in**0.5) / x.norm(dim=-1, keepdim=True) @@ -212,7 +201,6 @@ def run_time_activation_norm_fn_out(x: torch.Tensor) -> torch.Tensor: # self.run_time_activation_norm_fn_out = run_time_activation_norm_fn_out elif self.cfg.normalize_activations == "layer_norm": - # we need to scale the norm of the input and store the scaling factor def run_time_activation_ln_in( x: torch.Tensor, eps: float = 1e-5 @@ -236,8 +224,17 @@ def run_time_activation_ln_out(x: torch.Tensor, eps: float = 1e-5): self.setup() # Required for `HookedRootModule`s - def initialize_weights_basic(self): + def initialize_weights(self): + if self.cfg.architecture == "standard": + 
self.initialize_weights_basic() + elif self.cfg.architecture == "gated": + self.initialize_weights_gated() + elif self.cfg.architecture == "jumprelu": + self.initialize_weights_jumprelu() + else: + raise (ValueError) + def initialize_weights_basic(self): # no config changes encoder bias init for now. self.b_enc = nn.Parameter( torch.zeros(self.cfg.d_sae, dtype=self.dtype, device=self.device) @@ -491,9 +488,39 @@ def forward( return self.hook_sae_output(sae_out) + def encode( + self, x: torch.Tensor, latents: torch.Tensor | None = None + ) -> torch.Tensor: + """ + Calculate SAE latents from inputs. Includes optional `latents` argument to only calculate a subset. Note that + this won't make sense for topk SAEs, because we need to compute all hidden values to apply the topk masking. + """ + if self.cfg.activation_fn_str == "topk": + assert ( + latents is None + ), "Computing a slice of SAE hidden values doesn't make sense in topk SAEs." + + return { + "standard": self.encode_standard, + "gated": self.encode_gated, + "jumprelu": self.encode_jumprelu, + }[self.cfg.architecture](x, latents) + def encode_gated( - self, x: Float[torch.Tensor, "... d_in"] + self, + x: Float[torch.Tensor, "... d_in"], + latents: torch.Tensor | None = None, ) -> Float[torch.Tensor, "... d_sae"]: + """ + Computes the latent values of the Sparse Autoencoder (SAE) using a gated architecture. The activation values are + computed as the product of the masking term & the post-activation function magnitude term: + + 1[(x - b_dec) @ W_gate + b_gate > 0] * activation_fn((x - b_dec) @ W_enc + b_enc) + + The `latents` argument allows for the computation of a specific subset of the hidden values. If `latents` is not + provided, all latent values will be computed. 
+ """ + latents_tensor = torch.arange(self.cfg.d_sae) if latents is None else latents x = x.to(self.dtype) x = self.reshape_fn_in(x) @@ -502,12 +529,15 @@ def encode_gated( sae_in = x - self.b_dec * self.cfg.apply_b_dec_to_input # Gating path - gating_pre_activation = sae_in @ self.W_enc + self.b_gate + gating_pre_activation = ( + sae_in @ self.W_enc[:, latents_tensor] + self.b_gate[latents_tensor] + ) active_features = (gating_pre_activation > 0).to(self.dtype) # Magnitude path with weight sharing magnitude_pre_activation = self.hook_sae_acts_pre( - sae_in @ (self.W_enc * self.r_mag.exp()) + self.b_mag + sae_in @ (self.W_enc[:, latents_tensor] * self.r_mag[latents_tensor].exp()) + + self.b_mag[latents_tensor] ) feature_magnitudes = self.activation_fn(magnitude_pre_activation) @@ -516,11 +546,20 @@ def encode_gated( return feature_acts def encode_jumprelu( - self, x: Float[torch.Tensor, "... d_in"] + self, + x: Float[torch.Tensor, "... d_in"], + latents: torch.Tensor | None = None, ) -> Float[torch.Tensor, "... d_sae"]: """ - Calculate SAE features from inputs + Computes the latent values of the Sparse Autoencoder (SAE) using a JumpReLU architecture. The activation values are + computed as: + + activation_fn((x - b_dec) @ W_enc + b_enc) * 1[(x - b_dec) @ W_enc + b_enc > threshold] + + The `latents` argument allows for the computation of a specific subset of the hidden values. If `latents` is not + provided, all latent values will be computed. """ + latents_tensor = torch.arange(self.cfg.d_sae) if latents is None else latents # move x to correct dtype x = x.to(self.dtype) @@ -535,20 +574,32 @@ def encode_jumprelu( sae_in = self.hook_sae_input(x - (self.b_dec * self.cfg.apply_b_dec_to_input)) # "... d_in, d_in d_sae -> ... 
d_sae", - hidden_pre = self.hook_sae_acts_pre(sae_in @ self.W_enc + self.b_enc) + hidden_pre = self.hook_sae_acts_pre( + sae_in @ self.W_enc[:, latents_tensor] + self.b_enc[latents_tensor] + ) feature_acts = self.hook_sae_acts_post( - self.activation_fn(hidden_pre) * (hidden_pre > self.threshold) + self.activation_fn(hidden_pre) + * (hidden_pre > self.threshold[latents_tensor]) ) return feature_acts def encode_standard( - self, x: Float[torch.Tensor, "... d_in"] + self, + x: Float[torch.Tensor, "... d_in"], + latents: torch.Tensor | None = None, ) -> Float[torch.Tensor, "... d_sae"]: """ - Calculate SAE features from inputs + Computes the latent values of the Sparse Autoencoder (SAE) using a standard architecture. The activation values are + computed as: + + activation_fn((x - b_dec) @ W_enc + b_enc) + + The `latents` argument allows for the computation of a specific subset of the hidden values. If `latents` is not + provided, all latent values will be computed. """ + latents_tensor = torch.arange(self.cfg.d_sae) if latents is None else latents x = x.to(self.dtype) x = self.reshape_fn_in(x) @@ -559,7 +610,9 @@ def encode_standard( sae_in = x - (self.b_dec * self.cfg.apply_b_dec_to_input) # "... d_in, d_in d_sae -> ... 
d_sae", - hidden_pre = self.hook_sae_acts_pre(sae_in @ self.W_enc + self.b_enc) + hidden_pre = self.hook_sae_acts_pre( + sae_in @ self.W_enc[:, latents_tensor] + self.b_enc[latents_tensor] + ) feature_acts = self.hook_sae_acts_post(self.activation_fn(hidden_pre)) return feature_acts @@ -606,7 +659,6 @@ def fold_activation_norm_scaling_factor( self.cfg.normalize_activations = "none" def save_model(self, path: str, sparsity: Optional[torch.Tensor] = None): - if not os.path.exists(path): os.mkdir(path) @@ -627,7 +679,6 @@ def save_model(self, path: str, sparsity: Optional[torch.Tensor] = None): def load_from_pretrained( cls, path: str, device: str = "cpu", dtype: str | None = None ) -> "SAE": - # get the config config_path = os.path.join(path, SAE_CFG_PATH) with open(config_path, "r") as f: @@ -752,7 +803,6 @@ def from_dict(cls, config_dict: dict[str, Any]) -> "SAE": return cls(SAEConfig.from_dict(config_dict)) def turn_on_forward_pass_hook_z_reshaping(self): - assert self.cfg.hook_name.endswith( "_z" ), "This method should only be called for hook_z SAEs." 
diff --git a/sae_lens/training/training_sae.py b/sae_lens/training/training_sae.py index b7925d4ef..94ea550c1 100644 --- a/sae_lens/training/training_sae.py +++ b/sae_lens/training/training_sae.py @@ -38,7 +38,6 @@ class TrainStepOutput: @dataclass(kw_only=True) class TrainingSAEConfig(SAEConfig): - # Sparsity Loss Calculations l1_coefficient: float lp_norm: float @@ -55,7 +54,6 @@ class TrainingSAEConfig(SAEConfig): def from_sae_runner_config( cls, cfg: LanguageModelSAERunnerConfig ) -> "TrainingSAEConfig": - return cls( # base config architecture=cfg.architecture, @@ -168,7 +166,6 @@ class TrainingSAE(SAE): device: torch.device def __init__(self, cfg: TrainingSAEConfig, use_error_term: bool = False): - base_sae_cfg = SAEConfig.from_dict(cfg.get_base_sae_cfg_dict()) super().__init__(base_sae_cfg) self.cfg = cfg # type: ignore @@ -203,18 +200,21 @@ def check_cfg_compatibility(self): assert self.use_error_term is False, "Gated SAEs do not support error terms" def encode_standard( - self, x: Float[torch.Tensor, "... d_in"] + self, x: Float[torch.Tensor, "... d_in"], latents: torch.Tensor | None = None ) -> Float[torch.Tensor, "... d_sae"]: """ - Calcuate SAE features from inputs + Calculate SAE features from inputs. The `latents` argument must be None (it is accepted only so the type signature matches + the parent class, which uses this argument to compute only a subset of the SAE hidden values) """ + assert ( + latents is None + ), "Function `encode_standard` in training should always return activations for all latents" feature_acts, _ = self.encode_with_hidden_pre_fn(x) return feature_acts def encode_with_hidden_pre( self, x: Float[torch.Tensor, "... d_in"] ) -> tuple[Float[torch.Tensor, "... d_sae"], Float[torch.Tensor, "... d_sae"]]: - x = x.to(self.dtype) x = self.reshape_fn_in(x) # type: ignore x = self.hook_sae_input(x) @@ -235,7 +235,6 @@ def encode_with_hidden_pre( def encode_with_hidden_pre_gated( self, x: Float[torch.Tensor, "... 
d_in"] ) -> tuple[Float[torch.Tensor, "... d_sae"], Float[torch.Tensor, "... d_sae"]]: - x = x.to(self.dtype) x = self.reshape_fn_in(x) # type: ignore x = self.hook_sae_input(x) @@ -267,7 +266,6 @@ def forward( self, x: Float[torch.Tensor, "... d_in"], ) -> Float[torch.Tensor, "... d_in"]: - feature_acts, _ = self.encode_with_hidden_pre_fn(x) sae_out = self.decode(feature_acts) @@ -279,7 +277,6 @@ def training_forward_pass( current_l1_coefficient: float, dead_neuron_mask: Optional[torch.Tensor] = None, ) -> TrainStepOutput: - # do a forward pass to get SAE out, but we also need the # hidden pre. feature_acts, _ = self.encode_with_hidden_pre_fn(sae_in) @@ -291,7 +288,6 @@ def training_forward_pass( # GHOST GRADS if self.cfg.use_ghost_grads and self.training and dead_neuron_mask is not None: - # first half of second forward pass _, hidden_pre = self.encode_with_hidden_pre_fn(sae_in) ghost_grad_loss = self.calculate_ghost_grad_loss( @@ -362,7 +358,6 @@ def calculate_ghost_grad_loss( hidden_pre: torch.Tensor, dead_neuron_mask: torch.Tensor, ) -> torch.Tensor: - # 1. 
residual = x - sae_out l2_norm_residual = torch.norm(residual, dim=-1) @@ -394,7 +389,6 @@ def calculate_ghost_grad_loss( @torch.no_grad() def _get_mse_loss_fn(self) -> Any: - def standard_mse_loss_fn( preds: torch.Tensor, target: torch.Tensor ) -> torch.Tensor: @@ -421,7 +415,6 @@ def load_from_pretrained( device: str = "cpu", dtype: str | None = None, ) -> "TrainingSAE": - # get the config config_path = os.path.join(path, SAE_CFG_PATH) with open(config_path, "r") as f: diff --git a/tests/unit/training/test_sae_basic.py b/tests/unit/training/test_sae_basic.py index 530fca2cd..b514d3fd3 100644 --- a/tests/unit/training/test_sae_basic.py +++ b/tests/unit/training/test_sae_basic.py @@ -70,6 +70,21 @@ def test_sae_init(cfg: LanguageModelSAERunnerConfig): assert sae.b_dec.shape == (cfg.d_in,) +@pytest.mark.parametrize("architecture", ["standard", "gated", "jumprelu"]) +def test_sae_encode_with_different_architectures(architecture: str) -> None: + cfg = build_sae_cfg(architecture=architecture) + sae = SAE.from_dict(cfg.get_base_sae_cfg_dict()) + assert isinstance(cfg.d_sae, int) + + activations = torch.randn(10, 4, cfg.d_in, device=cfg.device) + latents = torch.randint(low=0, high=cfg.d_sae, size=(10,)) + feature_activations = sae.encode(activations) + feature_activations_slice = sae.encode(activations, latents=latents) + torch.testing.assert_close( + feature_activations[..., latents], feature_activations_slice + ) + + def test_sae_fold_w_dec_norm(cfg: LanguageModelSAERunnerConfig): sae = SAE.from_dict(cfg.get_base_sae_cfg_dict()) sae.turn_off_forward_pass_hook_z_reshaping() # hook z reshaping not needed here. @@ -106,7 +121,6 @@ def test_sae_fold_w_dec_norm(cfg: LanguageModelSAERunnerConfig): def test_sae_fold_norm_scaling_factor(cfg: LanguageModelSAERunnerConfig): - norm_scaling_factor = 3.0 sae = SAE.from_dict(cfg.get_base_sae_cfg_dict())