diff --git a/src/mechir/modelling/architectures/base/__init__.py b/src/mechir/modelling/architectures/base/__init__.py new file mode 100644 index 0000000..85dd3c7 --- /dev/null +++ b/src/mechir/modelling/architectures/base/__init__.py @@ -0,0 +1,11 @@ +from .linear import ClassificationHead, HiddenLinear +from .components import BertEmbed +from ._model import HookedEncoder, HookedEncoderForSequenceClassification + +__all__ = [ + "HookedEncoder", + "HookedEncoderForSequenceClassification", + "ClassificationHead", + "HiddenLinear", + "BertEmbed", +] diff --git a/src/mechir/modelling/architectures/base/_model.py b/src/mechir/modelling/architectures/base/_model.py new file mode 100644 index 0000000..2856021 --- /dev/null +++ b/src/mechir/modelling/architectures/base/_model.py @@ -0,0 +1,606 @@ +"""Hooked Encoder. + +Contains a BERT style model. This is separate from :class:`transformer_lens.HookedTransformer` +because it has a significantly different architecture to e.g. GPT style transformers. +""" + +from __future__ import annotations + +import logging +import os +from typing import Dict, List, Optional, Tuple, Union, cast, overload + +import torch +from einops import repeat +from jaxtyping import Float, Int +from torch import nn +from transformers import AutoTokenizer, AutoModel +from typing_extensions import Literal + +from mechir.modelling.hooked import loading_from_pretrained as loading +from transformer_lens.ActivationCache import ActivationCache +from transformer_lens.components import ( + BertBlock, + BertMLMHead, + Unembed, + BertNSPHead, + BertPooler, +) +from transformer_lens.FactoredMatrix import FactoredMatrix +from transformer_lens.hook_points import HookedRootModule, HookPoint +from transformer_lens.utilities import devices + +from mechir.modelling.hooked.config import HookedTransformerConfig +from mechir.modelling.architectures.base.components import BertEmbed +from mechir.modelling.architectures.base.linear import ClassificationHead, MLPClassificationHead + + +class HookedEncoder(HookedRootModule): + """ + This class implements a BERT-style encoder using the components in ./components.py, with HookPoints on every interesting activation. It inherits from HookedRootModule. + + Limitations: + - The model does not include dropouts, which may lead to inconsistent results from training or fine-tuning. + + Like HookedTransformer, it can have a pretrained Transformer's weights loaded via `.from_pretrained`. There are a few features you might know from HookedTransformer which are not yet supported: + - There is no preprocessing (e.g. LayerNorm folding) when loading a pretrained model + """ + _hf_class = AutoModel + + def __init__(self, cfg, tokenizer=None, move_to_device=True, **kwargs): + super().__init__() + if isinstance(cfg, Dict): + cfg = HookedTransformerConfig(**cfg) + elif isinstance(cfg, str): + raise ValueError( + "Please pass in a config dictionary or HookedTransformerConfig object. If you want to load a pretrained model, use HookedEncoder.from_pretrained() instead." 
+ ) + self.cfg = cfg + + assert ( + self.cfg.n_devices == 1 + ), "Multiple devices not supported for HookedEncoder" + if tokenizer is not None: + self.tokenizer = tokenizer + elif self.cfg.tokenizer_name is not None: + huggingface_token = os.environ.get("HF_TOKEN", "") + self.tokenizer = AutoTokenizer.from_pretrained( + self.cfg.tokenizer_name, + token=huggingface_token if len(huggingface_token) > 0 else None, + ) + else: + self.tokenizer = None + + if self.cfg.d_vocab == -1: + # If we have a tokenizer, vocab size can be inferred from it. + assert ( + self.tokenizer is not None + ), "Must provide a tokenizer if d_vocab is not provided" + self.cfg.d_vocab = max(self.tokenizer.vocab.values()) + 1 + if self.cfg.d_vocab_out == -1: + self.cfg.d_vocab_out = self.cfg.d_vocab + + self.embed = BertEmbed(self.cfg) + self.blocks = nn.ModuleList( + [BertBlock(self.cfg) for _ in range(self.cfg.n_layers)] + ) + self.mlm_head = BertMLMHead(self.cfg) + self.unembed = Unembed(self.cfg) + self.nsp_head = BertNSPHead(self.cfg) + self.pooler = BertPooler(self.cfg) + + self.hook_full_embed = HookPoint() + + self.use_token_type_ids = self.cfg.use_token_type_ids + + if move_to_device: + self.to(self.cfg.device) + + self.setup() + + def to_tokens( + self, + input: Union[str, List[str]], + move_to_device: bool = True, + truncate: bool = True, + ) -> Tuple[ + Int[torch.Tensor, "batch pos"], # noqa: F722 + Int[torch.Tensor, "batch pos"], + Int[torch.Tensor, "batch pos"], + ]: + """Converts a string to a tensor of tokens. + Taken mostly from the HookedTransformer implementation, but does not support default padding + sides or prepend_bos. + Args: + input (Union[str, List[str]]): The input to tokenize. + move_to_device (bool): Whether to move the output tensor of tokens to the device the model lives on. Defaults to True + truncate (bool): If the output tokens are too long, whether to truncate the output + tokens to the model's max context window. Does nothing for shorter inputs. Defaults to + True. + """ + + assert self.tokenizer is not None, "Cannot use to_tokens without a tokenizer" + + encodings = self.tokenizer( + input, + return_tensors="pt", + padding=True, + truncation=truncate, + max_length=self.cfg.n_ctx if truncate else None, + ) + + tokens = encodings.input_ids + token_type_ids = encodings.token_type_ids if self.use_token_type_ids else None + attention_mask = encodings.attention_mask + + if move_to_device: + tokens = tokens.to(self.cfg.device) + token_type_ids = ( + token_type_ids.to(self.cfg.device) + if self.use_token_type_ids + else None + ) + attention_mask = attention_mask.to(self.cfg.device) + + return tokens, token_type_ids, attention_mask + + def encoder_output( + self, + tokens: Int[torch.Tensor, "batch pos"], + token_type_ids: Optional[Int[torch.Tensor, "batch pos"]] = None, + start_at_layer: Optional[int] = None, + stop_at_layer: Optional[int] = None, + one_zero_attention_mask: Optional[Int[torch.Tensor, "batch pos"]] = None, + ) -> Float[torch.Tensor, "batch pos d_vocab"]: + """Processes input through the encoder layers and returns the resulting residual stream. + + Args: + input: Input tokens as integers with shape (batch, position) + token_type_ids: Optional binary ids indicating segment membership. + Shape (batch_size, sequence_length). For example, with input + "[CLS] Sentence A [SEP] Sentence B [SEP]", token_type_ids would be + [0, 0, ..., 0, 1, ..., 1, 1] where 0 marks tokens from sentence A + and 1 marks tokens from sentence B. 
+ one_zero_attention_mask: Optional binary mask of shape (batch_size, sequence_length) + where 1 indicates tokens to attend to and 0 indicates tokens to ignore. + Used primarily for handling padding in batched inputs. + + Returns: + resid: Final residual stream tensor of shape (batch, position, d_model) + + Raises: + AssertionError: If using string input without a tokenizer + """ + + if tokens.device.type != self.cfg.device: + tokens = tokens.to(self.cfg.device) + if one_zero_attention_mask is not None: + one_zero_attention_mask = one_zero_attention_mask.to(self.cfg.device) + + resid = self.hook_full_embed(self.embed(tokens, token_type_ids)) + + large_negative_number = -torch.inf + mask = ( + repeat(1 - one_zero_attention_mask, "batch pos -> batch 1 1 pos") + if one_zero_attention_mask is not None + else None + ) + additive_attention_mask = ( + torch.where(mask == 1, large_negative_number, 0) + if mask is not None + else None + ) + + if start_at_layer is None: + start_at_layer = 0 + + idx_and_block = list(zip(range(self.cfg.n_layers), self.blocks)) + + for _, block in idx_and_block[start_at_layer:stop_at_layer]: + resid = block(resid, additive_attention_mask) + + return resid + + def forward( + self, + input: Int[torch.Tensor, "batch pos"], + return_type: Optional[str] = "embeddings", + token_type_ids: Optional[Int[torch.Tensor, "batch pos"]] = None, + attention_mask: Optional[Int[torch.Tensor, "batch pos"]] = None, + start_at_layer: Optional[int] = None, + stop_at_layer: Optional[int] = None, + ) -> Union[Float[torch.Tensor, "batch pos d_vocab"], None]: + """Input must be a batch of tokens. Strings and lists of strings are not yet supported. + + return_type Optional[str]: The type of output to return. Can be one of: None (return nothing, don't calculate logits), or 'logits' (return logits). + + token_type_ids Optional[torch.Tensor]: Binary ids indicating whether a token belongs to sequence A or B. For example, for two sentences: "[CLS] Sentence A [SEP] Sentence B [SEP]", token_type_ids would be [0, 0, ..., 0, 1, ..., 1, 1]. `0` represents tokens from Sentence A, `1` from Sentence B. If not provided, BERT assumes a single sequence input. Typically, shape is (batch_size, sequence_length). + + attention_mask: Optional[torch.Tensor]: A binary mask which indicates which tokens should be attended to (1) and which should be ignored (0). Primarily used for padding variable-length sentences in a batch. For instance, in a batch with sentences of differing lengths, shorter sentences are padded with 0s on the right. If not provided, the model assumes all tokens should be attended to. 
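+
+        Example (illustrative sketch only, not exercised by this patch; assumes a model and
+        tokenizer have already been loaded, e.g. via HookedEncoder.from_pretrained):
+
+            tokens, token_type_ids, attention_mask = model.to_tokens(["a query", "a longer document"])
+            resid = model(
+                tokens,
+                token_type_ids=token_type_ids,
+                attention_mask=attention_mask,
+                return_type="embeddings",
+            )  # Float tensor of shape (batch, pos, d_model)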
+ """ + + if start_at_layer is None: + if isinstance(input, str) or isinstance(input, list): + assert ( + self.tokenizer is not None + ), "Must provide a tokenizer if input is a string" + input, token_type_ids_from_tokenizer, attention_mask = self.to_tokens( + input + ) + + # If token_type_ids or attention mask are not provided, use the ones from the tokenizer + token_type_ids = ( + token_type_ids_from_tokenizer + if token_type_ids is None + else token_type_ids + ) + else: + assert type(input) is torch.Tensor + residual = input + + if residual.device.type != self.cfg.device: + residual = residual.to(self.cfg.device) + if attention_mask is not None: + attention_mask = attention_mask.to(self.cfg.device) + if start_at_layer is None: + start_at_layer = 0 + + resid = self.encoder_output( + residual, + token_type_ids=token_type_ids, + start_at_layer=start_at_layer, + stop_at_layer=stop_at_layer, + one_zero_attention_mask=attention_mask, + ) + + if stop_at_layer is not None or return_type == "embeddings": + return resid + + resid = self.mlm_head(resid) + logits = self.unembed(resid) + + if return_type == "predictions": + # Get predictions for masked tokens + logprobs = logits[logits == self.tokenizer.mask_token_id].log_softmax( + dim=-1 + ) + predictions = self.tokenizer.decode(logprobs.argmax(dim=-1)) + + # If input was a list of strings, split predictions into a list + if " " in predictions: + # Split along space + predictions = predictions.split(" ") + predictions = [ + f"Prediction {i}: {p}" for i, p in enumerate(predictions) + ] + return predictions + + elif return_type is None: + return None + + return logits + + @overload + def run_with_cache( + self, *model_args, return_cache_object: Literal[True] = True, **kwargs + ) -> Tuple[ + Float[torch.Tensor, "batch pos d_vocab"], + ActivationCache, + ]: ... + + @overload + def run_with_cache( + self, *model_args, return_cache_object: Literal[False], **kwargs + ) -> Tuple[ + Float[torch.Tensor, "batch pos d_vocab"], + Dict[str, torch.Tensor], + ]: ... + + def run_with_cache( + self, + *model_args, + return_cache_object: bool = True, + cache_as_dict: bool = False, + remove_batch_dim: bool = False, + **kwargs, + ) -> Tuple[ + Float[torch.Tensor, "batch pos d_vocab"], + Union[ActivationCache, Dict[str, torch.Tensor]], + ]: + """ + Wrapper around run_with_cache in HookedRootModule. If return_cache_object is True, this will return an ActivationCache object, with a bunch of useful HookedTransformer specific methods, otherwise it will return a dictionary of activations as in HookedRootModule. This function was copied directly from HookedTransformer. 
+ """ + out, cache_dict = super().run_with_cache( + *model_args, remove_batch_dim=remove_batch_dim, **kwargs + ) + if return_cache_object: + if not cache_as_dict: + cache = ActivationCache( + cache_dict, self, has_batch_dim=not remove_batch_dim + ) + return out, cache + else: + return out, None + + def to( # type: ignore + self, + device_or_dtype: Union[torch.device, str, torch.dtype], + print_details: bool = True, + ): + return devices.move_to_and_update_config(self, device_or_dtype, print_details) + + def cuda(self): + # Wrapper around cuda that also changes self.cfg.device + return self.to("cuda") + + def cpu(self): + # Wrapper around cuda that also changes self.cfg.device + return self.to("cpu") + + def mps(self): + # Wrapper around cuda that also changes self.cfg.device + return self.to("mps") + + @classmethod + def from_pretrained( + cls, + model_name: str, + checkpoint_index: Optional[int] = None, + checkpoint_value: Optional[int] = None, + hf_model=None, + device: Optional[str] = None, + tokenizer=None, + move_to_device=True, + dtype=torch.float32, + **from_pretrained_kwargs, + ) -> HookedEncoder: + """Loads in the pretrained weights from huggingface. Currently supports loading weight from HuggingFace BertForMaskedLM. Unlike HookedTransformer, this does not yet do any preprocessing on the model.""" + logging.warning( + "Support for BERT in TransformerLens is currently experimental, until such a time when it has feature " + "parity with HookedTransformer and has been tested on real research tasks. Until then, backward " + "compatibility is not guaranteed. Please see the docs for information on the limitations of the current " + "implementation." + "\n" + "If using BERT for interpretability research, keep in mind that BERT has some significant architectural " + "differences to GPT. For example, LayerNorms are applied *after* the attention and MLP components, meaning " + "that the last LayerNorm in a block cannot be folded." 
+ ) + + assert not ( + from_pretrained_kwargs.get("load_in_8bit", False) + or from_pretrained_kwargs.get("load_in_4bit", False) + ), "Quantization not supported" + + if "torch_dtype" in from_pretrained_kwargs: + dtype = from_pretrained_kwargs["torch_dtype"] + + official_model_name = loading.get_official_model_name(model_name) + + cfg = loading.get_pretrained_model_config( + official_model_name, + checkpoint_index=checkpoint_index, + checkpoint_value=checkpoint_value, + fold_ln=False, + device=device, + n_devices=1, + dtype=dtype, + **from_pretrained_kwargs, + ) + if hf_model is None: + hf_model = cls._hf_class.from_pretrained( + official_model_name, + torch_dtype=dtype, + **from_pretrained_kwargs, + ) + state_dict = loading.get_pretrained_state_dict( + official_model_name, cfg, hf_model, dtype=dtype, **from_pretrained_kwargs + ) + + model = cls(cfg, tokenizer, move_to_device=False) + + model.load_state_dict(state_dict, strict=False) + + if move_to_device: + model.to(cfg.device) + + print(f"Loaded pretrained model {model_name} into HookedEncoder") + + return model + + @property + def W_U(self) -> Float[torch.Tensor, "d_model d_vocab"]: + """ + Convenience to get the unembedding matrix (ie the linear map from the final residual stream to the output logits) + """ + return self.unembed.W_U + + @property + def b_U(self) -> Float[torch.Tensor, "d_vocab"]: + """ + Convenience to get the unembedding bias + """ + return self.unembed.b_U + + @property + def W_E(self) -> Float[torch.Tensor, "d_vocab d_model"]: + """ + Convenience to get the embedding matrix + """ + return self.embed.embed.W_E + + @property + def W_pos(self) -> Float[torch.Tensor, "n_ctx d_model"]: + """ + Convenience function to get the positional embedding. Only works on models with absolute positional embeddings! + """ + return self.embed.pos_embed.W_pos + + @property + def W_E_pos(self) -> Float[torch.Tensor, "d_vocab+n_ctx d_model"]: + """ + Concatenated W_E and W_pos. Used as a full (overcomplete) basis of the input space, useful for full QK and full OV circuits. 
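+        For example, with d_vocab=30522 and n_ctx=512 this is a (31034, d_model) matrix, and
+        model.W_E_pos @ model.W_U gives the direct-path contribution of every input token and
+        position to the output logits.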
+ """ + return torch.cat([self.W_E, self.W_pos], dim=0) + + @property + def W_K(self) -> Float[torch.Tensor, "n_layers n_heads d_model d_head"]: + """Stacks the key weights across all layers""" + return torch.stack( + [cast(BertBlock, block).attn.W_K for block in self.blocks], dim=0 + ) + + @property + def W_Q(self) -> Float[torch.Tensor, "n_layers n_heads d_model d_head"]: + """Stacks the query weights across all layers""" + return torch.stack( + [cast(BertBlock, block).attn.W_Q for block in self.blocks], dim=0 + ) + + @property + def W_V(self) -> Float[torch.Tensor, "n_layers n_heads d_model d_head"]: + """Stacks the value weights across all layers""" + return torch.stack( + [cast(BertBlock, block).attn.W_V for block in self.blocks], dim=0 + ) + + @property + def W_O(self) -> Float[torch.Tensor, "n_layers n_heads d_head d_model"]: + """Stacks the attn output weights across all layers""" + return torch.stack( + [cast(BertBlock, block).attn.W_O for block in self.blocks], dim=0 + ) + + @property + def W_in(self) -> Float[torch.Tensor, "n_layers d_model d_mlp"]: # noqa: F722 + """Stacks the MLP input weights across all layers""" + return torch.stack( + [cast(BertBlock, block).mlp.W_in for block in self.blocks], dim=0 + ) + + @property + def W_out(self) -> Float[torch.Tensor, "n_layers d_mlp d_model"]: + """Stacks the MLP output weights across all layers""" + return torch.stack( + [cast(BertBlock, block).mlp.W_out for block in self.blocks], dim=0 + ) + + @property + def b_K(self) -> Float[torch.Tensor, "n_layers n_heads d_head"]: + """Stacks the key biases across all layers""" + return torch.stack( + [cast(BertBlock, block).attn.b_K for block in self.blocks], dim=0 + ) + + @property + def b_Q(self) -> Float[torch.Tensor, "n_layers n_heads d_head"]: + """Stacks the query biases across all layers""" + return torch.stack( + [cast(BertBlock, block).attn.b_Q for block in self.blocks], dim=0 + ) + + @property + def b_V(self) -> Float[torch.Tensor, "n_layers n_heads d_head"]: + """Stacks the value biases across all layers""" + return torch.stack( + [cast(BertBlock, block).attn.b_V for block in self.blocks], dim=0 + ) + + @property + def b_O(self) -> Float[torch.Tensor, "n_layers d_model"]: + """Stacks the attn output biases across all layers""" + return torch.stack( + [cast(BertBlock, block).attn.b_O for block in self.blocks], dim=0 + ) + + @property + def b_in(self) -> Float[torch.Tensor, "n_layers d_mlp"]: + """Stacks the MLP input biases across all layers""" + return torch.stack( + [cast(BertBlock, block).mlp.b_in for block in self.blocks], dim=0 + ) + + @property + def b_out(self) -> Float[torch.Tensor, "n_layers d_model"]: + """Stacks the MLP output biases across all layers""" + return torch.stack( + [cast(BertBlock, block).mlp.b_out for block in self.blocks], dim=0 + ) + + @property + def QK(self) -> FactoredMatrix: # [n_layers, n_heads, d_model, d_model] + """Returns a FactoredMatrix object with the product of the Q and K matrices for each layer and head. 
+ Useful for visualizing attention patterns.""" + return FactoredMatrix(self.W_Q, self.W_K.transpose(-2, -1)) + + @property + def OV(self) -> FactoredMatrix: # [n_layers, n_heads, d_model, d_model] + """Returns a FactoredMatrix object with the product of the O and V matrices for each layer and head.""" + return FactoredMatrix(self.W_V, self.W_O) + + def all_head_labels(self) -> List[str]: + """Returns a list of strings with the format "L{l}H{h}", where l is the layer index and h is the head index.""" + return [ + f"L{l}H{h}" + for l in range(self.cfg.n_layers) + for h in range(self.cfg.n_heads) + ] + + +class HookedEncoderForSequenceClassification(HookedEncoder): + """ + This class implements a BERT-style encoder using the components in ./components.py, with HookPoints on every interesting activation. It inherits from HookedRootModule. + + Limitations: + - The current MVP implementation supports only the masked language modelling (MLM) task. Next sentence prediction (NSP), causal language modelling, and other tasks are not yet supported. + - Also note that model does not include dropouts, which may lead to inconsistent results from training or fine-tuning. + + Like HookedTransformer, it can have a pretrained Transformer's weights loaded via `.from_pretrained`. There are a few features you might know from HookedTransformer which are not yet supported: + - There is no preprocessing (e.g. LayerNorm folding) when loading a pretrained model + - The model only accepts tokens as inputs, and not strings, or lists of strings + """ + + def __init__(self, cfg, tokenizer=None, move_to_device=True, **kwargs): + super().__init__(cfg, tokenizer, move_to_device, **kwargs) + self.classifier = ( + ClassificationHead(cfg) + if not self.cfg.use_mlp_head + else MLPClassificationHead(cfg) + ) + self.setup() + + def forward( + self, + input: Int[torch.Tensor, "batch pos"], + return_type: Optional[str] = "embeddings", + token_type_ids: Optional[Int[torch.Tensor, "batch pos"]] = None, + attention_mask: Optional[Int[torch.Tensor, "batch pos"]] = None, + start_at_layer: Optional[int] = None, + stop_at_layer: Optional[int] = None, + ) -> Optional[Float[torch.Tensor, "batch pos d_vocab"]]: + """Input must be a batch of tokens. Strings and lists of strings are not yet supported. + + return_type Optional[str]: The type of output to return. Can be one of: None (return nothing, don't calculate logits), or 'logits' (return logits). + + token_type_ids Optional[torch.Tensor]: Binary ids indicating whether a token belongs to sequence A or B. For example, for two sentences: "[CLS] Sentence A [SEP] Sentence B [SEP]", token_type_ids would be [0, 0, ..., 0, 1, ..., 1, 1]. `0` represents tokens from Sentence A, `1` from Sentence B. If not provided, BERT assumes a single sequence input. Typically, shape is (batch_size, sequence_length). + + attention_mask: Optional[torch.Tensor]: A binary mask which indicates which tokens should be attended to (1) and which should be ignored (0). Primarily used for padding variable-length sentences in a batch. For instance, in a batch with sentences of differing lengths, shorter sentences are padded with 0s on the right. If not provided, the model assumes all tokens should be attended to. 
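+
+        Example (illustrative sketch; assumes a sequence-classification checkpoint has been loaded):
+
+            tokens, token_type_ids, attention_mask = model.to_tokens("an example input")
+            logits = model(tokens, token_type_ids=token_type_ids,
+                           attention_mask=attention_mask, return_type="logits")
+            # logits has shape (batch, num_labels) and is computed from the [CLS] position only.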
+ """ + + hidden = super().forward( + input, + token_type_ids=token_type_ids, + start_at_layer=start_at_layer, + stop_at_layer=stop_at_layer, + return_type="embeddings", + attention_mask=attention_mask, + ) + if return_type == "embeddings" or stop_at_layer is not None: + return hidden + logits = self.classifier(hidden[:, 0, :]) + + if return_type is None: + return None + return logits + + +__all__ = ["HookedEncoder", "HookedEncoderForSequenceClassification"] diff --git a/src/mechir/modelling/hooked/components.py b/src/mechir/modelling/architectures/base/components.py similarity index 100% rename from src/mechir/modelling/hooked/components.py rename to src/mechir/modelling/architectures/base/components.py diff --git a/src/mechir/modelling/hooked/linear.py b/src/mechir/modelling/architectures/base/linear.py similarity index 100% rename from src/mechir/modelling/hooked/linear.py rename to src/mechir/modelling/architectures/base/linear.py diff --git a/src/mechir/modelling/architectures/distilbert/__init__.py b/src/mechir/modelling/architectures/distilbert/__init__.py new file mode 100644 index 0000000..6b8bc0c --- /dev/null +++ b/src/mechir/modelling/architectures/distilbert/__init__.py @@ -0,0 +1,6 @@ +from ._model import HookedDistilBert, HookedDistilBertForSequenceClassification + +__all__ = [ + "HookedDistilBert", + "HookedDistilBertForSequenceClassification", +] diff --git a/src/mechir/modelling/architectures/distilbert.py b/src/mechir/modelling/architectures/distilbert/_model.py similarity index 91% rename from src/mechir/modelling/architectures/distilbert.py rename to src/mechir/modelling/architectures/distilbert/_model.py index e245629..f39cc7e 100644 --- a/src/mechir/modelling/architectures/distilbert.py +++ b/src/mechir/modelling/architectures/distilbert/_model.py @@ -12,20 +12,22 @@ from einops import repeat from jaxtyping import Float, Int from torch import nn -from transformers import AutoTokenizer +from transformers import AutoTokenizer, DistilBertModel, DistilBertForSequenceClassification from typing_extensions import Literal from transformer_lens.components import BertBlock, BertMLMHead, Unembed from transformer_lens.hook_points import HookPoint -from mechir.modelling.hooked.components import BertEmbed -from mechir.modelling.hooked.linear import MLPClassificationHead -from mechir.modelling.architectures.base import HookedEncoder +from mechir.modelling.architectures.base.components import BertEmbed +from mechir.modelling.architectures.base.linear import MLPClassificationHead +from mechir.modelling.architectures.base._model import HookedEncoder from mechir.modelling.hooked.config import HookedTransformerConfig -HookedDistilBert = HookedEncoder +class HookedDistilBert(HookedEncoder): + _hf_class = DistilBertModel class HookedDistilBertForSequenceClassification(HookedDistilBert): + _hf_class = DistilBertForSequenceClassification """ This class implements a BERT-style encoder for sequence classification using the components in ./components.py, with HookPoints on every interesting activation. It inherits from HookedDistilBert. 
diff --git a/src/mechir/modelling/architectures/electra/__init__.py b/src/mechir/modelling/architectures/electra/__init__.py new file mode 100644 index 0000000..318a1ba --- /dev/null +++ b/src/mechir/modelling/architectures/electra/__init__.py @@ -0,0 +1,6 @@ +from ._model import HookedElectra, HookedElectraForSequenceClassification + +__all__ = [ + "HookedElectra", + "HookedElectraForSequenceClassification", +] diff --git a/src/mechir/modelling/architectures/electra.py b/src/mechir/modelling/architectures/electra/_model.py similarity index 92% rename from src/mechir/modelling/architectures/electra.py rename to src/mechir/modelling/architectures/electra/_model.py index 51ed8d7..23c4255 100644 --- a/src/mechir/modelling/architectures/electra.py +++ b/src/mechir/modelling/architectures/electra/_model.py @@ -9,12 +9,13 @@ import logging from typing import Dict, Optional, Union +from transformers import ElectraModel, ElectraForSequenceClassification import torch from jaxtyping import Float, Int from torch import nn from transformer_lens.hook_points import HookPoint -from mechir.modelling.hooked.linear import ClassificationHead, HiddenLinear -from mechir.modelling.architectures.base import HookedEncoder +from mechir.modelling.architectures.base.linear import ClassificationHead, HiddenLinear +from mechir.modelling.architectures.base._model import HookedEncoder from mechir.modelling.hooked.config import HookedTransformerConfig @@ -38,7 +39,10 @@ def forward(self, resid: Float[torch.Tensor, "batch d_model"]) -> torch.Tensor: post_act = self.hook_post(self.activation(pre_act)) return self.out_proj(post_act) -HookedElectra = HookedEncoder + +class HookedElectra(HookedEncoder): + _hf_class = ElectraModel + class HookedElectraForSequenceClassification(HookedEncoder): """ @@ -49,6 +53,7 @@ class HookedElectraForSequenceClassification(HookedEncoder): - There is no preprocessing (e.g. 
LayerNorm folding) when loading a pretrained model - The model only accepts tokens as inputs, and not strings, or lists of strings """ + _hf_class = ElectraForSequenceClassification def __init__(self, cfg, tokenizer=None, move_to_device=True, **kwargs): super().__init__(cfg, tokenizer, move_to_device, **kwargs) diff --git a/src/mechir/modelling/architectures/modernbert/__init__.py b/src/mechir/modelling/architectures/modernbert/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/mechir/modelling/architectures/modernbert/components.py b/src/mechir/modelling/architectures/modernbert/components.py new file mode 100644 index 0000000..789732b --- /dev/null +++ b/src/mechir/modelling/architectures/modernbert/components.py @@ -0,0 +1,778 @@ +from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update +import torch +import torch.nn as nn +from typing import Optional +from jaxtyping import Float, Int +from transformer_lens.components.abstract_attention import AbstractAttention + +from mechir.modelling.hooked.config import HookedTransformerConfig + + + +class ModernBertRotaryEmbedding(nn.Module): + def __init__(self, cfg: Union[Dict, HookedTransformerConfig]): + super().__init__() + config = HookedTransformerConfig.unwrap(cfg) + + + # BC: "rope_type" was originally "type" + if hasattr(config, "rope_scaling") and config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) + else: + self.rope_type = "default" + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + inv_freq, self.attention_scaling = self.rope_init_fn(None, device, dim=dim, base=base) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + @torch.no_grad() + @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope) + def forward(self, x, position_ids): + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device) + position_ids_expanded = position_ids[:, None, :].float() + + device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): # Force float32 + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() * self.attention_scaling + sin = emb.sin() * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. 
+ unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + + +def sdpa_attention_forward( + module: "ModernBertAttention", + qkv: torch.Tensor, + attention_mask: torch.Tensor, + sliding_window_mask: torch.Tensor, + position_ids: Optional[torch.LongTensor], + local_attention: Tuple[int, int], + bs: int, + dim: int, + **_kwargs, +) -> Tuple[torch.Tensor]: + # qkv: [batch_size, seqlen, 3, nheads, headdim] + cos, sin = module.rotary_emb(qkv, position_ids=position_ids) + query, key, value = qkv.transpose(3, 1).unbind(dim=2) + # query, key, value: [batch_size, heads, seq_len, head_dim] + query, key = apply_rotary_pos_emb(query, key, cos, sin) + + if local_attention != (-1, -1): + attention_mask = sliding_window_mask + + attn_output = ( + F.scaled_dot_product_attention( + query, + key, + value, + dropout_p=module.attention_dropout if module.training else 0.0, + attn_mask=attention_mask, + ) + .transpose(1, 2) + .contiguous() + ) + attn_output = attn_output.view(bs, -1, dim) + return (attn_output,) + + +class ModernBertAttention(nn.Module): + """Performs multi-headed self attention on a batch of unpadded sequences.""" + + def __init__(self, config: ModernBertConfig, layer_id: Optional[int] = None): + super().__init__() + self.config = config + self.layer_id = layer_id + + if config.hidden_size % config.num_attention_heads != 0: + raise ValueError( + f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention heads ({config.num_attention_heads})" + ) + + self.attention_dropout = config.attention_dropout + self.num_heads = config.num_attention_heads + self.head_dim = config.hidden_size // config.num_attention_heads + self.all_head_size = self.head_dim * self.num_heads + self.Wqkv = nn.Linear(config.hidden_size, 3 * self.all_head_size, bias=config.attention_bias) + + if layer_id % config.global_attn_every_n_layers != 0: + self.local_attention = (config.local_attention // 2, config.local_attention // 2) + else: + self.local_attention = (-1, -1) + + rope_theta = config.global_rope_theta + max_position_embeddings = config.max_position_embeddings + if self.local_attention != (-1, -1): + if config.local_rope_theta is not None: + rope_theta = config.local_rope_theta + max_position_embeddings = config.local_attention + + + self.rotary_emb = ModernBertRotaryEmbedding(config=config, dim=self.head_dim, base=rope_theta) + + self.Wo = nn.Linear(config.hidden_size, config.hidden_size, bias=config.attention_bias) + self.out_drop = nn.Dropout(config.attention_dropout) if config.attention_dropout > 0.0 else nn.Identity() + self.pruned_heads = set() + + def forward( + self, + hidden_states: 
torch.Tensor, + output_attentions: Optional[bool] = False, + **kwargs, + ) -> torch.Tensor: + qkv = self.Wqkv(hidden_states) + + bs = hidden_states.shape[0] + if self.config._attn_implementation == "flash_attention_2": + qkv = qkv.view(-1, 3, self.num_heads, self.head_dim) + else: + qkv = qkv.view(bs, -1, 3, self.num_heads, self.head_dim) + + attn_outputs = sdpa_attention_forward( + self, + qkv=qkv, + rotary_emb=self.rotary_emb, + local_attention=self.local_attention, + bs=bs, + dim=self.all_head_size, + output_attentions=output_attentions, + **kwargs, + ) + hidden_states = attn_outputs[0] + hidden_states = self.out_drop(self.Wo(hidden_states)) + + return (hidden_states,) + attn_outputs[1:] # add attentions if outputted + +class ModernBertAttention(nn.Module): + def __init__( + self, + cfg: Union[Dict, HookedTransformerConfig], + attn_type: str = "global", + layer_id: Optional[int] = None, + ): + """Abstract Base Class of Attention Blocks, featuring common functionality of both Attention and GroupedQueryAttention blocks. + + Query and Output projections are defined in this class as they are the same for regular and grouped query attention. + Attributes related to Key and Value projections are abstract as their implementations may differ. For example, in GroupedQueryAttention there are less query and key heads than value heads. + To enforce implementation of W_K, W_V, b_K, and b_V by child classes, the better_abc.abstract_attribute class is used. See here for details: https://stackoverflow.com/questions/23831510/abstract-attribute-not-property. + + Args: + cfg (Union[Dict, HookedTransformerConfig]): Config + attn_type (str, optional): "global" or "local", used by GPT-Neo. Local attention means the model can only attend back cfg.window_size tokens (here, 256). Not used by any other model at the moment. Defaults to "global". + layer_id (int, optional): The index of the current layer. Used by the Mistral models (labelled here as stanford-gpt2) to scale down attention scores pre softmax for numerical stability reasons by 1/(layer_id+1). Defaults to None. 
+ """ + super().__init__() + self.cfg = HookedTransformerConfig.unwrap(cfg) + + if self.cfg.load_in_4bit: + nq = int((self.cfg.d_model * self.cfg.d_head * self.cfg.n_heads) / 2) + self.W_Q = Params4bit(torch.empty(nq, 1, dtype=torch.uint8), requires_grad=False) + self.W_O = Params4bit(torch.empty(nq, 1, dtype=torch.uint8), requires_grad=False) + else: + self.W_Q = nn.Parameter( + torch.empty( + self.cfg.n_heads, + self.cfg.d_model, + self.cfg.d_head, + dtype=self.cfg.dtype, + ) + ) + self.W_O = nn.Parameter( + torch.empty( + self.cfg.n_heads, + self.cfg.d_head, + self.cfg.d_model, + dtype=self.cfg.dtype, + ) + ) + self.W_K = abstract_attribute() + self.W_V = abstract_attribute() + + self.b_Q = nn.Parameter( + torch.zeros(self.cfg.n_heads, self.cfg.d_head, dtype=self.cfg.dtype) + ) + self.b_K: nn.Parameter = abstract_attribute() + self.b_V: nn.Parameter = abstract_attribute() + self.b_O = nn.Parameter(torch.zeros(self.cfg.d_model, dtype=self.cfg.dtype)) + + self.attn_type = attn_type + # Create a max_ctx x max_ctx mask, with True iff that query position + # can attend to that key position (query is first axis, key is second axis) + causal_mask = torch.tril(torch.ones((self.cfg.n_ctx, self.cfg.n_ctx)).bool()) + if self.attn_type == "global": + # For global attention, this is a lower triangular matrix - key <= query + self.register_buffer("mask", causal_mask) + elif self.attn_type == "local": + # For local, this is banded, query - window_size < key <= query + if not isinstance(self.cfg.window_size, int): + raise ValueError("Window size must be an integer for local attention") + self.register_buffer("mask", torch.triu(causal_mask, 1 - self.cfg.window_size)) + else: + raise ValueError(f"Invalid attention type: {self.attn_type}") + + self.register_buffer("IGNORE", torch.tensor(-torch.inf)) + + self.layer_id = layer_id + + # attn_scale is a constant that we divide the attention scores by pre-softmax. I'm not entirely sure why it matters, but it's probably a mix of softmax not being scale invariant and numerical stability? + if self.cfg.use_attn_scale: + self.attn_scale = self.cfg.attn_scale # Defaults to sqrt(d_head) + else: + self.attn_scale = 1.0 + if self.cfg.scale_attn_by_inverse_layer_idx: + if self.layer_id is None: # keep mypy happy + raise ValueError("Layer ID must be provided to scale attention scores") + self.attn_scale *= self.layer_id + 1 + + self.hook_k = HookPoint() # [batch, pos, head_index, d_head] + self.hook_q = HookPoint() # [batch, pos, head_index, d_head] + self.hook_v = HookPoint() # [batch, pos, head_index, d_head] + self.hook_z = HookPoint() # [batch, pos, head_index, d_head] + self.hook_attn_scores = HookPoint() # [batch, head_index, query_pos, key_pos] + self.hook_pattern = HookPoint() # [batch, head_index, query_pos, key_pos] + self.hook_result = HookPoint() # [batch, pos, head_index, d_model] + + # Applies a rotation to each two-element chunk of keys and queries pre dot producting to bake in relative position. 
See HookedTransformerConfig for details + self.hook_rot_k = HookPoint() + self.hook_rot_q = HookPoint() + if self.cfg.rotary_dim is None: # keep mypy happy + raise ValueError("Rotary dim must be provided for rotary positional embeddings") + sin, cos = self.calculate_sin_cos_rotary( + self.cfg.rotary_dim, + self.cfg.n_ctx, + base=self.cfg.rotary_base, + dtype=self.cfg.dtype, + ) + self.register_buffer("rotary_sin", sin) + self.register_buffer("rotary_cos", cos) + + @property + def OV(self) -> FactoredMatrix: + """ + OV-Circuit, as defined in A Mathematical Framework. Because there's no non-linearity between the value vector and the output of the layer, the output is purely determined by the matrix W_OV = W_V @ W_O, and not W_V or W_O individually. (Mathematically, for a single head, output == pattern @ residual @ W_V @ W_O, see the glossary for more) + + Done in the order W_V, W_O because the paper uses left-multiplying weight matrices, and TransformerLens uses right-multiplying, sorry! + + Returns a FactoredMatrix, with left matrix W_V [head_index, d_model, d_head] and right matrix W_O [head_index, d_head, d_model] - this is a low rank factorisation of the underlying [head_index, d_model, d_model]. FactoredMatrix has helper functions to deal with these large matrices efficiently. To get the OV circuit of a head k, attn.OV[k] works. + """ + return FactoredMatrix(self.W_V, self.W_O) + + @property + def QK(self) -> FactoredMatrix: + """ + QK-Circuit, as defined in A Mathematical Framework. Because there's no non-linearity in the key-query dot product, the output is purely determined by the matrix W_QK = W_Q.T @ W_K, and not W_Q or W_K individually. (Mathematically, for a single head, pattern = destination_residual.T @ W_Q.T @ W_K @ source-residual, see the glossary for more). + + Done in the order Q on the left, K on the right, because the pattern has dimensions [destination_pos, source_pos] + + Returns a FactoredMatrix, with left matrix W_Q [head_index, d_model, d_head] and right matrix W_K.T [head_index, d_head, d_model] - this is a low rank factorisation of the underlying [head_index, d_model, d_model] matrix. FactoredMatrix has helper functions to deal with these large matrices efficiently. To get the QK circuit of a head k, attn.QK[k] works. + """ + W_K_transpose = einops.rearrange( + self.W_K, "head_index d_model d_head -> head_index d_head d_model" + ) + return FactoredMatrix(self.W_Q, W_K_transpose) + + def forward( + self, + query_input: Union[ + Float[torch.Tensor, "batch pos d_model"], + Float[torch.Tensor, "batch pos head_index d_model"], + ], + key_input: Union[ + Float[torch.Tensor, "batch kv_pos d_model"], + Float[torch.Tensor, "batch kv_pos head_index d_model"], + Float[torch.Tensor, "batch kv_pos kv_head_index d_model"], + ], + value_input: Union[ + Float[torch.Tensor, "batch kv_pos d_model"], + Float[torch.Tensor, "batch kv_pos head_index d_model"], + Float[torch.Tensor, "batch kv_pos kv_head_index d_model"], + ], + past_kv_cache_entry: Optional[HookedTransformerKeyValueCacheEntry] = None, + additive_attention_mask: Optional[Float[torch.Tensor, "batch 1 1 kv_pos"]] = None, + attention_mask: Optional[Int[torch.Tensor, "batch offset_pos"]] = None, + position_bias: Optional[Float[torch.Tensor, "1 head_index pos kv_pos"]] = None, + ) -> Float[torch.Tensor, "batch pos d_model"]: + """ + shortformer_pos_embed is only used if self.cfg.positional_embedding_type == "shortformer", else defaults to None and is irrelevant. 
See HookedTransformerConfig for more details + past_kv_cache_entry is an optional entry of past keys and values for this layer, only relevant if generating text. Defaults to None + additive_attention_mask is an optional mask to add to the attention weights. Defaults to None. + attention_mask is the attention mask for padded tokens. Defaults to None. + """ + + q, k, v = self.calculate_qkv_matrices(query_input, key_input, value_input) + + if past_kv_cache_entry is not None: + # Appends the new keys and values to the cached values, and automatically updates the cache + kv_cache_pos_offset = past_kv_cache_entry.past_keys.size(1) + k, v = past_kv_cache_entry.append(k, v) + else: + # Not using a cache + kv_cache_pos_offset = 0 + + q = self.hook_rot_q(self.apply_rotary(q, kv_cache_pos_offset, attention_mask)) + k = self.hook_rot_k( + self.apply_rotary(k, 0, attention_mask) + ) # keys are cached so no offset + + if self.cfg.dtype not in [torch.float32, torch.float64]: + # If using 16 bits, increase the precision to avoid numerical instabilities + q = q.to(torch.float32) + k = k.to(torch.float32) + + attn_scores = self.calculate_attention_scores( + q, k + ) # [batch, head_index, query_pos, key_pos] + + if position_bias is None: + if self.has_relative_attention_bias: + raise ValueError("Positional bias is required for relative_positional_bias") + else: + position_bias = torch.zeros( + 1, + self.cfg.n_heads, + attn_scores.shape[2], + attn_scores.shape[3], + device=attn_scores.device, + ) + + attn_scores += position_bias + if self.cfg.attention_dir == "causal": + # If causal attention, we mask it to only attend backwards. If bidirectional, we don't mask. + attn_scores = self.apply_causal_mask( + attn_scores, kv_cache_pos_offset, attention_mask + ) # [batch, head_index, query_pos, key_pos] + if additive_attention_mask is not None: + attn_scores += additive_attention_mask + + attn_scores = self.hook_attn_scores(attn_scores) + pattern = F.softmax(attn_scores, dim=-1) + pattern = torch.where(torch.isnan(pattern), torch.zeros_like(pattern), pattern) + pattern = self.hook_pattern(pattern) # [batch, head_index, query_pos, key_pos] + pattern = pattern.to(self.cfg.dtype) + pattern = pattern.to(v.device) + z = self.calculate_z_scores(v, pattern) # [batch, pos, head_index, d_head] + if not self.cfg.use_attn_result: + if self.cfg.load_in_4bit: + # call bitsandbytes method to dequantize and multiply + out = ( + bnb.matmul_4bit( + z.reshape(z.shape[0], z.shape[1], self.cfg.d_head * self.cfg.n_heads), + self.W_O.t(), + # bias=self.W_O.t(), + bias=None, + quant_state=self.W_O.quant_state, + ) + + self.b_O + ) + else: + w = einops.rearrange( + self.W_O, "head_index d_head d_model -> d_model (head_index d_head)" + ) + + if self.b_O.device != w.device: + w = w.to(self.b_O.device) + if self.b_O.device != z.device: + z = z.to(self.b_O.device) + + out = F.linear( + z.reshape(z.shape[0], z.shape[1], self.cfg.d_head * self.cfg.n_heads), + w, + self.b_O, + ) + else: + # Explicitly calculate the attention result so it can be accessed by a hook + # This is off by default because it can easily eat through your GPU memory. 
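+                # (The per-head result tensor below is [batch, pos, head_index, d_model], i.e.
+                # n_heads times larger than the summed attention output, which is why it is gated
+                # behind cfg.use_attn_result.)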
+ if self.cfg.load_in_4bit: + result = self.hook_result( + bnb.matmul_4bit( + z.reshape(z.shape[0], z.shape[1], self.cfg.d_head * self.cfg.n_heads), + self.W_O.t(), + bias=None, + quant_state=self.W_O.quant_state, + ) + ) + else: + # Add singleton dimensions to make shapes compatible for broadcasting: + w = einops.rearrange( + self.W_O, + "head_index d_head d_model -> 1 1 head_index d_head d_model", + ) + z = einops.rearrange( + z, "batch pos head_index d_head -> batch pos head_index d_head 1" + ) + + # Multiply the z tensor by the W_O tensor, summing over the d_head dimension + unhooked_result = (z * w).sum(-2) + + result = self.hook_result(unhooked_result) # [batch, pos, head_index, d_model] + out = ( + einops.reduce(result, "batch position index model->batch position model", "sum") + + self.b_O + ) # [batch, pos, d_model] + return out + + def calculate_qkv_matrices( + self, + query_input: Union[ + Float[torch.Tensor, "batch pos d_model"], + Float[torch.Tensor, "batch pos head_index d_model"], + ], + key_input: Union[ + Float[torch.Tensor, "batch kv_pos d_model"], + Float[torch.Tensor, "batch kv_pos head_index d_model"], + ], + value_input: Union[ + Float[torch.Tensor, "batch kv_pos d_model"], + Float[torch.Tensor, "batch kv_pos head_index d_model"], + ], + ) -> Tuple[ + Float[torch.Tensor, "batch pos head_index d_head"], + Float[torch.Tensor, "batch kv_pos head_index d_head"], + Float[torch.Tensor, "batch kv_pos head_index d_head"], + ]: + attn_fn = ( + complex_attn_linear + if self.cfg.use_split_qkv_input or self.cfg.use_attn_in + else simple_attn_linear + ) + if self.cfg.load_in_4bit: + q = self.hook_q( + # call bitsandbytes method to dequantize and multiply + bnb.matmul_4bit( + query_input, + self.W_Q.t(), + bias=None, + quant_state=self.W_Q.quant_state, + ).reshape( + query_input.shape[0], + query_input.shape[1], + self.cfg.n_heads, + self.cfg.d_head, + ) + + self.b_Q + ) + else: + q = self.hook_q(attn_fn(query_input, self.W_Q, self.b_Q)) + if self.cfg.load_in_4bit: + if not isinstance(self.W_K, Params4bit): + raise ValueError("W_K must be a Params4bit object if load_in_4bit is True") + k = self.hook_k( + # call bitsandbytes method to dequantize and multiply + bnb.matmul_4bit( + key_input, self.W_K.t(), bias=None, quant_state=self.W_K.quant_state + ).reshape( + key_input.shape[0], + key_input.shape[1], + self.cfg.n_heads, + self.cfg.d_head, + ) + + self.b_K + ) + else: + k = self.hook_k(attn_fn(key_input, self.W_K, self.b_K)) + + if self.cfg.load_in_4bit: + if not isinstance(self.W_V, Params4bit): + raise ValueError("W_V must be a Params4bit object if load_in_4bit is True") + v = self.hook_v( + # call bitsandbytes method to dequantize and multiply + bnb.matmul_4bit( + value_input, + self.W_V.t(), + bias=None, + quant_state=self.W_V.quant_state, + ).reshape( + value_input.shape[0], + value_input.shape[1], + self.cfg.n_heads, + self.cfg.d_head, + ) + + self.b_V + ) + else: + v = self.hook_v(attn_fn(value_input, self.W_V, self.b_V)) + + return q, k, v + + def calculate_attention_scores( + self, + q: Float[torch.Tensor, "batch query_pos head_index d_head"], + k: Float[torch.Tensor, "batch key_pos head_index d_head"], + ) -> Float[torch.Tensor, "batch head_index query_pos key_pos"]: + q_ = einops.rearrange( + q, "batch query_pos head_index d_head -> batch head_index query_pos d_head" + ) + k_ = einops.rearrange( + k, "batch key_pos head_index d_head -> batch head_index d_head key_pos" + ) + attn_scores = q_ @ k_ / self.attn_scale + if self.cfg.attn_scores_soft_cap > 0: + attn_scores = 
self.cfg.attn_scores_soft_cap * F.tanh( + attn_scores / self.cfg.attn_scores_soft_cap + ) + return attn_scores + + def calculate_z_scores( + self, + v: Float[torch.Tensor, "batch key_pos head_index d_head"], + pattern: Float[torch.Tensor, "batch head_index query_pos key_pos"], + ) -> Float[torch.Tensor, "batch query_pos head_index d_head"]: + v_ = einops.rearrange( + v, "batch key_pos head_index d_head -> batch head_index key_pos d_head" + ) + pattern_ = einops.rearrange( + pattern, + "batch head_index query_pos key_pos -> batch head_index query_pos key_pos", + ) + z = self.hook_z( + einops.rearrange( + pattern_ @ v_, + "batch head_index query_pos d_head -> batch query_pos head_index d_head", + ) + ) + return z + + def apply_causal_mask( + self, + attn_scores: Float[torch.Tensor, "batch head_index pos pos_plus_past_kv_pos_offset"], + past_kv_pos_offset: int = 0, + attention_mask: Optional[Int[torch.Tensor, "batch offset_pos"]] = None, + ): + # The query context length is the number of positions we take queries from - if not using a past_kv_cache this is just the context length (for the current prompt), but if we're caching it can be different. + query_ctx_length = attn_scores.size(-2) + # The key context length is the number of positions in the past - this includes all positions in the cache + # If not caching, query_ctx_length == key_ctx_length + key_ctx_length = attn_scores.size(-1) + + if query_ctx_length + past_kv_pos_offset != key_ctx_length: + raise ValueError( + f"query_ctx_length {query_ctx_length} + past_kv_pos_offset {past_kv_pos_offset} != key_ctx_length {key_ctx_length} - you likely have a bug." + ) + + # Index back to front to ensure local attention works + final_mask = self.mask[None, None, -query_ctx_length:, -key_ctx_length:] # [1, 1, pos, pos] + if attention_mask is not None: + # Apply a causal mask to the attention scores considering the padding + + # Add singleton dimensions to the attention mask to match the shape of the final mask + attention_mask = einops.rearrange( + attention_mask, "batch offset_pos -> batch 1 1 offset_pos" + ) + + final_mask = final_mask.to(attention_mask.device) + + # Element-wise multiplication of the final mask and the attention mask and cast to boolean + final_mask = (final_mask * attention_mask).bool() # [batch, head, pos, offset_pos] + + attn_scores = attn_scores.to(final_mask.device) + return torch.where(final_mask, attn_scores, self.IGNORE) + + def calculate_sin_cos_rotary( + self, + rotary_dim: int, + n_ctx: int, + base: int = 10000, + dtype: torch.dtype = torch.float32, + ) -> Tuple[Float[torch.Tensor, "n_ctx rotary_dim"], Float[torch.Tensor, "n_ctx rotary_dim"]]: + """ + Calculate the sine and cosine waves to use in a rotary embedding. See https://blog.eleuther.ai/rotary-embeddings/ for details + + Note: For some inexplicable reason, in GPT-J each ADJACENT pair of elements in k and q are rotated, in GPT-NeoX the pair of elements at k and k+n//2 are rotated (ie folding the full length in half, and then looking at pairs accordingly). I have absolutely no clue why, it should be completely equivalent. + To resolve this, I've coded it to default to the GPT-J mode, but to explicitly check whether it's GPT-NeoX and then do the GPT-NeoX thing if it is. 
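+
+        Illustrative example (rotary_dim=4): with rotary_adjacent_pairs=True the per-dimension
+        frequencies are laid out as [f0, f0, f1, f1] (each adjacent pair shares a frequency),
+        whereas with rotary_adjacent_pairs=False they are [f0, f1, f0, f1], matching how
+        rotate_every_two pairs up dimensions in each mode.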
+ """ + high_precision = torch.float32 if dtype != torch.float64 else torch.float64 + pos = torch.arange(n_ctx, dtype=high_precision) + dim = torch.arange(rotary_dim // 2, dtype=high_precision) + + # Llama-3.1 uses NTK-by-Parts Rotary Embedding introduced in Section 3.2 in https://arxiv.org/pdf/2309.00071 + # Implementation copied from https://github.com/huggingface/transformers/blob/v4.46.0/src/transformers/modeling_rope_utils.py#L310 + if self.cfg.use_NTK_by_parts_rope: + inv_freq = 1.0 / ( + base ** (torch.arange(0, rotary_dim, 2, dtype=torch.int64).float() / rotary_dim) + ) + factor = self.cfg.NTK_by_parts_factor + low_freq_factor = self.cfg.NTK_by_parts_low_freq_factor + high_freq_factor = self.cfg.NTK_by_parts_high_freq_factor + old_context_len = self.cfg.NTK_original_ctx_len + + low_freq_wavelen = old_context_len / low_freq_factor + high_freq_wavelen = old_context_len / high_freq_factor + + wavelen = 2 * math.pi / inv_freq + inv_freq_llama = torch.where(wavelen > low_freq_wavelen, inv_freq / factor, inv_freq) + smooth_factor = (old_context_len / wavelen - low_freq_factor) / ( + high_freq_factor - low_freq_factor + ) + smoothed_inv_freq = ( + 1 - smooth_factor + ) * inv_freq_llama / factor + smooth_factor * inv_freq_llama + is_medium_freq = ~(wavelen < high_freq_wavelen) * ~(wavelen > low_freq_wavelen) + inv_freq_llama = torch.where(is_medium_freq, smoothed_inv_freq, inv_freq_llama) + freq = 1 / inv_freq_llama + else: + freq = base ** (dim / (rotary_dim / 2)) + if self.cfg.rotary_adjacent_pairs: + freq = einops.repeat(freq, "d -> (d 2)") + else: + freq = einops.repeat(freq, "d -> (2 d)") + # Create a n_ctx x rotary_dim tensor, where each column is an arithmetic sequence of angles in that frequency + angles = pos[:, None] / freq[None, :] + return torch.sin(angles).to(dtype), torch.cos(angles).to(dtype) + + def rotate_every_two( + self, x: Float[torch.Tensor, "... rotary_dim"] + ) -> Float[torch.Tensor, "... rotary_dim"]: + """ + Rotary helper function, splits x into blocks of size 2 along the final axis and maps [x0, x1] to [-x1, x0] + + The final axis of x must have even length. + + GPT-NeoX and GPT-J do rotary subtly differently, see calculate_sin_cos_rotary for details. 
+ """ + rot_x = x.clone() + if self.cfg.rotary_adjacent_pairs: + rot_x[..., ::2] = -x[..., 1::2] + rot_x[..., 1::2] = x[..., ::2] + else: + n = x.size(-1) // 2 + rot_x[..., :n] = -x[..., n:] + rot_x[..., n:] = x[..., :n] + + return rot_x + + def apply_rotary( + self, + x: Float[torch.Tensor, "batch pos head_index d_head"], + past_kv_pos_offset=0, + attention_mask: Optional[Int[torch.Tensor, "batch offset_pos"]] = None, + ) -> Float[torch.Tensor, "batch pos head_index d_head"]: + # Only apply rotary to first rotary_dim dimensions (eg, if rotary_dim=64 and d_head=256, only apply to first 1/4 of dimensions) + + if x.device != self.rotary_sin.device: + x = x.to(self.rotary_sin.device) + + x_pos = x.size(1) + x_rot = x[..., : self.cfg.rotary_dim] + x_pass = x[..., self.cfg.rotary_dim :] + x_flip = self.rotate_every_two(x_rot) + + if attention_mask is None: + rotary_cos = self.rotary_cos[ + None, past_kv_pos_offset : past_kv_pos_offset + x_pos, None, : + ] + rotary_sin = self.rotary_sin[ + None, past_kv_pos_offset : past_kv_pos_offset + x_pos, None, : + ] + x_rotated = x_rot * rotary_cos + x_flip * rotary_sin + else: + offset_position_ids = get_offset_position_ids(past_kv_pos_offset, attention_mask) + offset_position_ids = offset_position_ids.to(self.rotary_cos.device) + mask_rotary_cos = self.rotary_cos[offset_position_ids, None, :] + mask_rotary_sin = self.rotary_sin[offset_position_ids, None, :] + x_rotated = x_rot * mask_rotary_cos + x_flip * mask_rotary_sin + + return torch.cat([x_rotated, x_pass], dim=-1) + +class ModenBertBlock(nn.Module): + """ + BERT Block. Similar to the TransformerBlock, except that the LayerNorms are applied after the attention and MLP, rather than before. + """ + + def __init__(self, cfg: HookedTransformerConfig, layer_id: Optional[int] = None): + super().__init__() + self.cfg = cfg + + self.attn = ModernBertAttention(cfg, layer_id=layer_id) + self.ln1 = LayerNorm(cfg) if layer_id != 0 else nn.Identity() # No LayerNorm on the first layer + self.mlp = MLPFactory.create_mlp(self.cfg) + self.ln2 = LayerNorm(cfg) + + self.hook_q_input = HookPoint() # [batch, pos, n_heads, d_model] + self.hook_k_input = HookPoint() # [batch, pos, n_heads, d_model] + self.hook_v_input = HookPoint() # [batch, pos, n_heads, d_model] + + self.hook_attn_out = HookPoint() # [batch, pos, d_model] + self.hook_mlp_in = HookPoint() # [batch, pos, d_model] + self.hook_mlp_out = HookPoint() # [batch, pos, d_model] + self.hook_resid_pre = HookPoint() # [batch, pos, d_model] + self.hook_resid_mid = HookPoint() # [batch, pos, d_model] + self.hook_resid_post = HookPoint() # [batch, pos, d_model] + self.hook_normalized_resid_post = HookPoint() # [batch, pos, d_model] + + def forward( + self, + self, + resid_pre: Float[torch.Tensor, "batch pos d_model"], + past_kv_cache_entry: Optional[HookedTransformerKeyValueCacheEntry] = None, + additive_attention_mask: Optional[Float[torch.Tensor, "batch 1 1 kv_pos"]] = None, + attention_mask: Optional[Int[torch.Tensor, "batch offset_pos"]] = None, + position_bias: Optional[Float[torch.Tensor, "1 head_index pos kv_pos"]] = None, + ) -> Float[torch.Tensor, "batch pos d_model"]: + resid_pre = self.hook_resid_pre(resid_pre) + + query_input = resid_pre + key_input = resid_pre + value_input = resid_pre + + if self.cfg.use_split_qkv_input: + n_heads = self.cfg.n_heads + query_input = self.hook_q_input(repeat_along_head_dimension(query_input, n_heads)) + key_input = self.hook_k_input(repeat_along_head_dimension(key_input, n_heads)) + value_input = 
+
+class ModernBertBlock(nn.Module):
+    """
+    ModernBERT block. Similar to the TransformerBlock, except that the LayerNorms are applied after the attention and MLP, rather than before.
+    """
+
+    def __init__(self, cfg: HookedTransformerConfig, layer_id: Optional[int] = None):
+        super().__init__()
+        self.cfg = cfg
+
+        self.attn = ModernBertAttention(cfg, layer_id=layer_id)
+        self.ln1 = LayerNorm(cfg) if layer_id != 0 else nn.Identity()  # No LayerNorm on the first layer
+        self.mlp = MLPFactory.create_mlp(self.cfg)
+        self.ln2 = LayerNorm(cfg)
+
+        self.hook_q_input = HookPoint()  # [batch, pos, n_heads, d_model]
+        self.hook_k_input = HookPoint()  # [batch, pos, n_heads, d_model]
+        self.hook_v_input = HookPoint()  # [batch, pos, n_heads, d_model]
+
+        self.hook_attn_out = HookPoint()  # [batch, pos, d_model]
+        self.hook_mlp_in = HookPoint()  # [batch, pos, d_model]
+        self.hook_mlp_out = HookPoint()  # [batch, pos, d_model]
+        self.hook_resid_pre = HookPoint()  # [batch, pos, d_model]
+        self.hook_resid_mid = HookPoint()  # [batch, pos, d_model]
+        self.hook_resid_post = HookPoint()  # [batch, pos, d_model]
+        self.hook_normalized_resid_post = HookPoint()  # [batch, pos, d_model]
+
+    def forward(
+        self,
+        resid_pre: Float[torch.Tensor, "batch pos d_model"],
+        past_kv_cache_entry: Optional[HookedTransformerKeyValueCacheEntry] = None,
+        additive_attention_mask: Optional[Float[torch.Tensor, "batch 1 1 kv_pos"]] = None,
+        attention_mask: Optional[Int[torch.Tensor, "batch offset_pos"]] = None,
+        position_bias: Optional[Float[torch.Tensor, "1 head_index pos kv_pos"]] = None,
+    ) -> Float[torch.Tensor, "batch pos d_model"]:
+        resid_pre = self.hook_resid_pre(resid_pre)
+
+        query_input = resid_pre
+        key_input = resid_pre
+        value_input = resid_pre
+
+        if self.cfg.use_split_qkv_input:
+            n_heads = self.cfg.n_heads
+            query_input = self.hook_q_input(repeat_along_head_dimension(query_input, n_heads))
+            key_input = self.hook_k_input(repeat_along_head_dimension(key_input, n_heads))
+            value_input = self.hook_v_input(repeat_along_head_dimension(value_input, n_heads))
+
+        attn_out = self.hook_attn_out(
+            self.attn(
+                query_input,
+                key_input,
+                value_input,
+                additive_attention_mask=additive_attention_mask,
+            )
+        )
+        resid_mid = self.hook_resid_mid(resid_pre + attn_out)
+
+        mlp_in = resid_mid if not self.cfg.use_hook_mlp_in else self.hook_mlp_in(resid_mid.clone())
+        normalized_resid_mid = self.ln1(mlp_in)
+        mlp_out = self.hook_mlp_out(self.mlp(normalized_resid_mid))
+        resid_post = self.hook_resid_post(normalized_resid_mid + mlp_out)
+        normalized_resid_post = self.hook_normalized_resid_post(self.ln2(resid_post))
+
+        return normalized_resid_post
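The residual-stream bookkeeping in the forward pass above can be condensed into the following sketch (not part of the diff), with hooks, masks and the split-QKV path stripped out; the block_forward name and the nn.Linear stand-ins for the real attention and MLP modules are illustrative assumptions.

# Post-norm residual flow of the block's forward pass (assumed simplification).
import torch
from torch import nn

def block_forward(resid_pre, attn, ln1, mlp, ln2):
    attn_out = attn(resid_pre)                    # attention reads the raw residual stream
    resid_mid = resid_pre + attn_out              # first residual add
    normalized_resid_mid = ln1(resid_mid)         # post-attention LayerNorm (Identity on layer 0)
    mlp_out = mlp(normalized_resid_mid)
    resid_post = normalized_resid_mid + mlp_out   # second residual add, off the normalized stream
    return ln2(resid_post)                        # post-MLP LayerNorm

d_model = 16
x = torch.randn(2, 5, d_model)
out = block_forward(
    x,
    attn=nn.Linear(d_model, d_model),
    ln1=nn.LayerNorm(d_model),
    mlp=nn.Linear(d_model, d_model),
    ln2=nn.LayerNorm(d_model),
)
print(out.shape)  # torch.Size([2, 5, 16])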
""" - - n_layers: int - d_model: int - n_ctx: int - d_head: int - model_name: str = "custom" - n_heads: int = -1 - d_mlp: Optional[int] = None - act_fn: Optional[str] = None - d_vocab: int = -1 - eps: float = 1e-5 - use_attn_result: bool = False - use_attn_scale: bool = True - attn_scale: float = -1.0 - use_split_qkv_input: bool = False - use_hook_mlp_in: bool = False - use_attn_in: bool = False - use_local_attn: bool = False - original_architecture: Optional[str] = None - from_checkpoint: bool = False - checkpoint_index: Optional[int] = None - checkpoint_label_type: Optional[str] = None - checkpoint_value: Optional[int] = None - tokenizer_name: Optional[str] = None - window_size: Optional[int] = None - attn_types: Optional[List] = None - init_mode: str = "gpt2" - normalization_type: Optional[str] = "LN" - device: Optional[str] = None - n_devices: int = 1 - attention_dir: str = "causal" - attn_only: bool = False - seed: Optional[int] = None - initializer_range: float = -1.0 - init_weights: bool = True - scale_attn_by_inverse_layer_idx: bool = False - positional_embedding_type: str = "standard" - final_rms: bool = False - d_vocab_out: int = -1 - parallel_attn_mlp: bool = False - rotary_dim: Optional[int] = None - n_params: Optional[int] = None - use_hook_tokens: bool = False - gated_mlp: bool = False - default_prepend_bos: bool = True - dtype: torch.dtype = torch.float32 - tokenizer_prepends_bos: Optional[bool] = None - n_key_value_heads: Optional[int] = None - post_embedding_ln: bool = False - rotary_base: int = 10000 - trust_remote_code: bool = False - rotary_adjacent_pairs: bool = False - load_in_4bit: bool = False - num_experts: Optional[int] = None - experts_per_token: Optional[int] = None - relative_attention_max_distance: Optional[int] = None - relative_attention_num_buckets: Optional[int] = None - decoder_start_token_id: Optional[int] = None - tie_word_embeddings: bool = False - use_normalization_before_and_after: bool = False - attn_scores_soft_cap: float = -1.0 - output_logits_soft_cap: float = -1.0 num_labels: int = 1 use_token_type_ids: bool = True use_mlp_head: bool = False diff --git a/src/mechir/modelling/hooked/conversion.py b/src/mechir/modelling/hooked/conversion.py index 1602324..f49c57f 100644 --- a/src/mechir/modelling/hooked/conversion.py +++ b/src/mechir/modelling/hooked/conversion.py @@ -92,10 +92,11 @@ def convert_bert_based_weights( model_name: str = "bert", ): if not hasattr(bert, "embeddings"): - bert = getattr(bert, model_name) - embeddings = bert.embeddings + encoder = getattr(bert, model_name) + embeddings = encoder.embeddings else: embeddings = bert.embeddings + encoder = bert state_dict = { "embed.embed.W_E": embeddings.word_embeddings.weight, "embed.pos_embed.W_pos": embeddings.position_embeddings.weight, diff --git a/test/acceptance/test_hookedencoder.py b/test/acceptance/test_hookedencoder.py index 4b1b90b..b377c0b 100644 --- a/test/acceptance/test_hookedencoder.py +++ b/test/acceptance/test_hookedencoder.py @@ -5,7 +5,7 @@ import torch.nn.functional as F from jaxtyping import Float from torch.testing import assert_close -from transformers import AutoTokenizer, AutoModel, BertForPreTraining +from transformers import AutoTokenizer, BertForPreTraining from mechir.modelling.architectures import HookedEncoder