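The following pseudo-code wraps a Hugging Face causal language model (`AutoModelForCausalLM`) and splits its decoder layers at `cleave_point` into a lower and an upper half: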
class CleavedAutoModelForCausalLM(nn.Module):
    """Split a causal LM's decoder stack into a lower and an upper half.

    The wrapped model is expected to expose ``model.model.embed_tokens``,
    ``model.model.layers`` and ``model.lm_head`` (the Llama-style Hugging
    Face layout -- NOTE(review): confirm for the concrete model in use).
    If ``model.model.norm`` exists it is applied before ``lm_head`` so that
    logits are computed from normalized hidden states, as the full model does.

    Args:
        model: The wrapped causal language model.
        cleave_point: Index into ``model.model.layers`` at which to split.
            Layers ``[:cleave_point]`` form the lower half and
            ``[cleave_point:]`` the upper half; negative values count from
            the end, exactly like ordinary slicing.

    Raises:
        ValueError: If ``cleave_point`` lies outside ``[-n, n]`` where ``n``
            is the number of decoder layers (such a value would silently
            produce an empty half).
    """

    def __init__(self, model: "AutoModelForCausalLM", cleave_point: int):
        super().__init__()
        self.model = model
        decoder = self.model.model
        num_layers = len(decoder.layers)
        if not -num_layers <= cleave_point <= num_layers:
            raise ValueError(
                f"cleave_point must be in [{-num_layers}, {num_layers}], "
                f"got {cleave_point}"
            )
        self.embed_tokens = decoder.embed_tokens
        # Kept as nn.Sequential containers for attribute compatibility, but
        # they are iterated layer by layer (see _run_layers) instead of being
        # called as a chain: decoder layers may return tuples and may need
        # extra keyword arguments, which nn.Sequential cannot pass along.
        self.lower_half = nn.Sequential(*decoder.layers[:cleave_point])
        self.upper_half = nn.Sequential(*decoder.layers[cleave_point:])
        # Final normalization applied before the LM head; identity when the
        # wrapped decoder does not expose a ``norm`` attribute.
        final_norm = getattr(decoder, "norm", None)
        self.final_norm = final_norm if final_norm is not None else nn.Identity()
        self.lm_head = self.model.lm_head

    @staticmethod
    def _run_layers(layers, hidden_states, **layer_kwargs):
        """Apply ``layers`` in order, threading only the hidden states through."""
        for layer in layers:
            output = layer(hidden_states, **layer_kwargs)
            # Some transformers versions return a tuple whose first element is
            # the hidden states; others return the tensor directly.
            hidden_states = output[0] if isinstance(output, tuple) else output
        return hidden_states

    def _logits(self, hidden_states):
        """Project hidden states to logits through the final norm and LM head."""
        return self.lm_head(self.final_norm(hidden_states))

    def forward(self, x, **layer_kwargs):
        """Early-exit pass: embed ``x``, run the lower half, project to logits.

        Args:
            x: Token ids accepted by ``embed_tokens``.
            **layer_kwargs: Extra keyword arguments forwarded to every decoder
                layer (e.g. attention mask or position information, if the
                concrete layer implementation requires them).

        Returns:
            The output of ``lm_head`` applied to the normalized output of the
            lower half.
        """
        hidden_states = self.embed_tokens(x)
        # Feed the embeddings -- not the raw token ids -- into the lower half.
        hidden_states = self._run_layers(self.lower_half, hidden_states, **layer_kwargs)
        return self._logits(hidden_states)

    def full_forward(self, x):
        """Run the complete wrapped model on ``x``.

        Returns whatever ``self.model(x)`` returns (for Hugging Face models,
        presumably an output object rather than a bare logits tensor).
        """
        return self.model(x)

    def upper_forward(self, input_features, **layer_kwargs):
        """Run the upper half on ``input_features`` and project to logits.

        Args:
            input_features: Hidden states as produced by the lower half
                (before the final norm).
            **layer_kwargs: Extra keyword arguments forwarded to every decoder
                layer.

        Returns:
            The output of ``lm_head`` applied to the normalized output of the
            upper half.
        """
        hidden_states = self._run_layers(self.upper_half, input_features, **layer_kwargs)
        return self._logits(hidden_states)

    def verify_step(self, input_features, **layer_kwargs):
        """Verification pass; identical to :meth:`upper_forward`."""
        return self.upper_forward(input_features, **layer_kwargs)