11 changes: 11 additions & 0 deletions src/mechir/modelling/architectures/base/__init__.py
@@ -0,0 +1,11 @@
+from .linear import ClassificationHead, HiddenLinear
+from .components import BertEmbed
+from ._model import HookedEncoder, HookedEncoderForSequenceClassification
+
+__all__ = [
+    "HookedEncoder",
+    "HookedEncoderForSequenceClassification",
+    "ClassificationHead",
+    "HiddenLinear",
+    "BertEmbed",
+]
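
Note: with these re-exports, downstream code can import the shared encoder pieces from the subpackage root rather than from the private modules. A minimal sketch of the import surface this `__init__.py` exposes (only the five re-exported names are confirmed by the diff; everything else is assumption):

```python
# Public imports via the subpackage root, as enabled by the re-exports above.
from mechir.modelling.architectures.base import (
    HookedEncoder,
    HookedEncoderForSequenceClassification,
    ClassificationHead,
)

# Because __all__ is defined, a star-import exposes exactly the five
# re-exported names and nothing else from the private modules.
from mechir.modelling.architectures.base import *  # noqa: F401,F403
```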
606 changes: 606 additions & 0 deletions src/mechir/modelling/architectures/base/_model.py

Large diffs are not rendered by default.

6 changes: 6 additions & 0 deletions src/mechir/modelling/architectures/distilbert/__init__.py
@@ -0,0 +1,6 @@
+from ._model import HookedDistilBert, HookedDistilBertForSequenceClassification
+
+__all__ = [
+    "HookedDistilBert",
+    "HookedDistilBertForSequenceClassification",
+]
src/mechir/modelling/architectures/distilbert/_model.py
@@ -12,20 +12,22 @@
 from einops import repeat
 from jaxtyping import Float, Int
 from torch import nn
-from transformers import AutoTokenizer
+from transformers import AutoTokenizer, DistilBertModel, DistilBertForSequenceClassification
 from typing_extensions import Literal
 
 from transformer_lens.components import BertBlock, BertMLMHead, Unembed
 from transformer_lens.hook_points import HookPoint
-from mechir.modelling.hooked.components import BertEmbed
-from mechir.modelling.hooked.linear import MLPClassificationHead
-from mechir.modelling.architectures.base import HookedEncoder
+from mechir.modelling.architectures.base.components import BertEmbed
+from mechir.modelling.architectures.base.linear import MLPClassificationHead
+from mechir.modelling.architectures.base._model import HookedEncoder
 from mechir.modelling.hooked.config import HookedTransformerConfig
 
 
-HookedDistilBert = HookedEncoder
+class HookedDistilBert(HookedEncoder):
+    _hf_class = DistilBertModel
 
 class HookedDistilBertForSequenceClassification(HookedDistilBert):
+    _hf_class = DistilBertForSequenceClassification
     """
     This class implements a BERT-style encoder for sequence classification using the components in ./components.py, with HookPoints on every interesting activation. It inherits from HookedDistilBert.
 
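
The recurring `_hf_class` attribute appears to be how each hooked subclass declares which Hugging Face class its pretrained checkpoints come from; the actual consumer lives in the unrendered `base/_model.py`. A hedged sketch of how such a hook might be consumed (the `HookedEncoderSketch` class and its `load_hf_weights` flow are illustrative assumptions, not the repository's real implementation):

```python
from transformers import DistilBertModel


class HookedEncoderSketch:
    """Stand-in for the real HookedEncoder in base/_model.py (not rendered)."""

    # Subclasses point this at the Hugging Face class whose checkpoints
    # they can load -- mirroring the `_hf_class` pattern in this diff.
    _hf_class = None

    @classmethod
    def load_hf_weights(cls, model_name: str) -> dict:
        # Assumed flow: instantiate the HF model via the class attribute,
        # then hand its state dict to an (omitted) conversion step.
        if cls._hf_class is None:
            raise NotImplementedError("subclass must set _hf_class")
        hf_model = cls._hf_class.from_pretrained(model_name)
        return hf_model.state_dict()


class HookedDistilBertSketch(HookedEncoderSketch):
    _hf_class = DistilBertModel
```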
6 changes: 6 additions & 0 deletions src/mechir/modelling/architectures/electra/__init__.py
@@ -0,0 +1,6 @@
+from ._model import HookedElectra, HookedElectraForSequenceClassification
+
+__all__ = [
+    "HookedElectra",
+    "HookedElectraForSequenceClassification",
+]
src/mechir/modelling/architectures/electra/_model.py
@@ -9,12 +9,13 @@
 import logging
 from typing import Dict, Optional, Union
 
+from transformers import ElectraModel, ElectraForSequenceClassification
 import torch
 from jaxtyping import Float, Int
 from torch import nn
 from transformer_lens.hook_points import HookPoint
-from mechir.modelling.hooked.linear import ClassificationHead, HiddenLinear
-from mechir.modelling.architectures.base import HookedEncoder
+from mechir.modelling.architectures.base.linear import ClassificationHead, HiddenLinear
+from mechir.modelling.architectures.base._model import HookedEncoder
 from mechir.modelling.hooked.config import HookedTransformerConfig
 
 
@@ -38,7 +39,10 @@ def forward(self, resid: Float[torch.Tensor, "batch d_model"]) -> torch.Tensor:
         post_act = self.hook_post(self.activation(pre_act))
         return self.out_proj(post_act)
 
-HookedElectra = HookedEncoder
+
+class HookedElectra(HookedEncoder):
+    _hf_class = ElectraModel
+
 
 class HookedElectraForSequenceClassification(HookedEncoder):
     """
@@ -49,6 +53,7 @@ class HookedElectraForSequenceClassification(HookedEncoder):
     - There is no preprocessing (e.g. LayerNorm folding) when loading a pretrained model
     - The model only accepts tokens as inputs, and not strings, or lists of strings
     """
+    _hf_class = ElectraForSequenceClassification
 
     def __init__(self, cfg, tokenizer=None, move_to_device=True, **kwargs):
         super().__init__(cfg, tokenizer, move_to_device, **kwargs)
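
Given the docstring's constraints (no LayerNorm folding, tokens-only inputs), usage presumably looks like the sketch below. The `from_pretrained` constructor and its signature are assumptions borrowed from TransformerLens conventions; only the class and subpackage names are confirmed by this diff:

```python
import torch
from transformers import AutoTokenizer

from mechir.modelling.architectures.electra import (
    HookedElectraForSequenceClassification,
)

tokenizer = AutoTokenizer.from_pretrained("google/electra-base-discriminator")
# The model accepts token IDs only (no raw strings), so tokenize first.
tokens = tokenizer("an example input", return_tensors="pt")["input_ids"]

# Assumed loader, mirroring TransformerLens's HookedEncoder API.
model = HookedElectraForSequenceClassification.from_pretrained(
    "google/electra-base-discriminator"
)
with torch.no_grad():
    logits = model(tokens)  # hooked forward pass over token IDs
```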