Commit 393bf2c

better weight loading
1 parent 0586189 commit 393bf2c

File tree: 2 files changed, 197 additions (+), 64 deletions (−)
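Background for the diff below: PyTorch composes state_dict keys from attribute paths, so nesting the embeddings and encoder inside an inner module stored as self.esm prefixes every weight key with 'esm.', which is the key layout the pretrained DPLM/DPLM2 checkpoints ship with. A minimal sketch of that mechanism, using toy modules rather than anything from this repo:

import torch.nn as nn

class InnerEncoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.embeddings = nn.Embedding(10, 4)  # stands in for EsmEmbeddings

class OuterModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.esm = InnerEncoder()  # the attribute name becomes the key prefix

print(list(InnerEncoder().state_dict()))  # ['embeddings.weight']
print(list(OuterModel().state_dict()))    # ['esm.embeddings.weight']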

dplm2_fastplms/modeling_dplm2.py

Lines changed: 94 additions & 28 deletions
@@ -890,17 +890,38 @@ def forward(
         )


-class DPLM2Model(DPLM2PreTrainedModel, EmbeddingMixin):
-    config_class = DPLM2Config
+class FAST_DPLM2_ENCODER(DPLM2PreTrainedModel, EmbeddingMixin):
+    """Inner encoder class that holds the actual ESM-style weights (embeddings, encoder)
+    so that the weight keys are prefixed with 'esm.' in the outer DPLM2Model,
+    matching pretrained DPLM2 checkpoints."""

-    def __init__(self, config, add_pooling_layer=True):
-        DPLM2PreTrainedModel.__init__(self, config)
+    def __init__(self, config, **kwargs):
+        DPLM2PreTrainedModel.__init__(self, config, **kwargs)
         self.config = config
         self.embeddings = EsmEmbeddings(config)
         self.encoder = ModifiedEsmEncoder(config)
-        self.pooler = EsmPooler(config) if add_pooling_layer else None
         self.post_init()

+    def get_input_embeddings(self) -> nn.Module:
+        return self.embeddings.word_embeddings
+
+    def set_input_embeddings(self, value):
+        self.embeddings.word_embeddings = value
+
+    def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
+        if attention_mask is None:
+            attention_mask = input_ids.ne(self.config.pad_token_id)
+        type_ids = _infer_modality_type(input_ids, attention_mask)
+        outputs = self(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            type_ids=type_ids,
+            output_hidden_states=False,
+            output_attentions=False,
+            return_dict=True,
+        )
+        return outputs.last_hidden_state
+
     def _convert_head_mask_to_5d(self, head_mask: torch.Tensor, num_hidden_layers: int) -> torch.Tensor:
         if head_mask.dim() == 1:
             head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
@@ -924,26 +945,6 @@ def get_head_mask(
             head_mask = head_mask.unsqueeze(-1)
         return head_mask

-    def get_input_embeddings(self) -> nn.Module:
-        return self.embeddings.word_embeddings
-
-    def set_input_embeddings(self, value):
-        self.embeddings.word_embeddings = value
-
-    def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
-        if attention_mask is None:
-            attention_mask = input_ids.ne(self.config.pad_token_id)
-        type_ids = _infer_modality_type(input_ids, attention_mask)
-        outputs = self(
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            type_ids=type_ids,
-            output_hidden_states=False,
-            output_attentions=False,
-            return_dict=True,
-        )
-        return outputs.last_hidden_state
-
     def forward(
         self,
         input_ids: Optional[torch.Tensor] = None,
@@ -1039,21 +1040,86 @@ def forward(
             flex_block_mask=flex_block_mask,
         )
         sequence_output = encoder_outputs[0]
-        pooled_output = self.pooler(sequence_output) if self.pooler is not None else None

         if return_dict is False:
-            return (sequence_output, pooled_output) + encoder_outputs[1:]
+            return (sequence_output,) + encoder_outputs[1:]

         return BaseModelOutputWithPoolingAndCrossAttentions(
             last_hidden_state=sequence_output,
-            pooler_output=pooled_output,
             past_key_values=None,
             hidden_states=encoder_outputs.hidden_states,
             attentions=encoder_outputs.attentions,
             cross_attentions=encoder_outputs.cross_attentions,
         )


+class DPLM2Model(DPLM2PreTrainedModel, EmbeddingMixin):
+    config_class = DPLM2Config
+
+    def __init__(self, config, add_pooling_layer=True):
+        DPLM2PreTrainedModel.__init__(self, config)
+        self.config = config
+        self.esm = FAST_DPLM2_ENCODER(config)
+        self.pooler = EsmPooler(config) if add_pooling_layer else None
+        self.post_init()
+
+    def get_input_embeddings(self) -> nn.Module:
+        return self.esm.embeddings.word_embeddings
+
+    def set_input_embeddings(self, value):
+        self.esm.embeddings.word_embeddings = value
+
+    def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
+        return self.esm._embed(input_ids, attention_mask)
+
+    def forward(
+        self,
+        input_ids: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        encoder_attention_mask: Optional[torch.Tensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        type_ids: Optional[torch.Tensor] = None,
+    ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
+        outputs = self.esm(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            encoder_hidden_states=encoder_hidden_states,
+            encoder_attention_mask=encoder_attention_mask,
+            past_key_values=past_key_values,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            type_ids=type_ids,
+        )
+        sequence_output = outputs[0]
+        pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
+
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+        if return_dict is False:
+            return (sequence_output, pooled_output) + outputs[1:]
+
+        return BaseModelOutputWithPoolingAndCrossAttentions(
+            last_hidden_state=sequence_output,
+            pooler_output=pooled_output,
+            past_key_values=None,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+            cross_attentions=outputs.cross_attentions,
+        )
+
+
 class DPLM2ForMaskedLM(DPLM2PreTrainedModel, EmbeddingMixin):
     config_class = DPLM2Config

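After this refactor the pooler lives only on the outer DPLM2Model, while the inner FAST_DPLM2_ENCODER returns bare hidden states and _embed wraps a full forward pass with inferred masks and modality type ids. A hedged usage sketch; the checkpoint path and token ids are illustrative placeholders, not taken from this commit:

import torch

model = DPLM2Model.from_pretrained("path/to/dplm2-checkpoint")  # hypothetical location
input_ids = torch.tensor([[0, 5, 6, 7, 2]])                     # made-up token ids

hidden = model._embed(input_ids)  # attention mask and type_ids inferred internally
out = model(input_ids=input_ids, return_dict=True)
print(hidden.shape)                   # (1, 5, hidden_size)
print(out.pooler_output is not None)  # pooling now happens only in the outer model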
dplm_fastplms/modeling_dplm.py

Lines changed: 103 additions & 36 deletions
@@ -797,46 +797,24 @@ def forward(
         )


-class DPLMModel(DPLMPreTrainedModel, EmbeddingMixin):
-    config_class = DPLMConfig
-
-    def get_input_embeddings(self) -> nn.Module:
-        return self.embeddings.word_embeddings
+class FAST_DPLM_ENCODER(DPLMPreTrainedModel, EmbeddingMixin):
+    """Inner encoder class that holds the actual ESM-style weights (embeddings, encoder,
+    contact_head) so that the weight keys are prefixed with 'esm.' in the outer DPLMModel,
+    matching pretrained DPLM checkpoints."""

-    def __init__(self, config, add_pooling_layer=True):
-        DPLMPreTrainedModel.__init__(self, config)
+    def __init__(self, config, **kwargs):
+        DPLMPreTrainedModel.__init__(self, config, **kwargs)
         self.config = config
         self.embeddings = EsmEmbeddings(config)
         self.encoder = ModifiedEsmEncoder(config)
-        self.pooler = EsmPooler(config) if add_pooling_layer else None
         self.contact_head = EsmContactPredictionHead(
             in_features=config.num_hidden_layers * config.num_attention_heads,
             bias=True,
         )
         self.post_init()

-    def _convert_head_mask_to_5d(self, head_mask: torch.Tensor, num_hidden_layers: int) -> torch.Tensor:
-        if head_mask.dim() == 1:
-            head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
-            head_mask = head_mask.expand(num_hidden_layers, -1, -1, -1, -1)
-        elif head_mask.dim() == 2:
-            head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)
-        assert head_mask.dim() == 5, f"head_mask.dim != 5, got {head_mask.dim()}"
-        head_mask = head_mask.to(dtype=self.dtype)
-        return head_mask
-
-    def get_head_mask(
-        self,
-        head_mask: Optional[torch.Tensor],
-        num_hidden_layers: int,
-        is_attention_chunked: bool = False,
-    ) -> Union[torch.Tensor, List[None]]:
-        if head_mask is None:
-            return [None] * num_hidden_layers
-        head_mask = self._convert_head_mask_to_5d(head_mask, num_hidden_layers)
-        if is_attention_chunked:
-            head_mask = head_mask.unsqueeze(-1)
-        return head_mask
+    def get_input_embeddings(self) -> nn.Module:
+        return self.embeddings.word_embeddings

     def set_input_embeddings(self, value):
         self.embeddings.word_embeddings = value
@@ -860,6 +838,29 @@ def predict_contacts(self, input_ids: torch.Tensor, attention_mask: torch.Tensor
         attns *= attention_mask.unsqueeze(1).unsqueeze(2).unsqueeze(4)
         return self.contact_head(input_ids, attns)

+    def _convert_head_mask_to_5d(self, head_mask: torch.Tensor, num_hidden_layers: int) -> torch.Tensor:
+        if head_mask.dim() == 1:
+            head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
+            head_mask = head_mask.expand(num_hidden_layers, -1, -1, -1, -1)
+        elif head_mask.dim() == 2:
+            head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)
+        assert head_mask.dim() == 5, f"head_mask.dim != 5, got {head_mask.dim()}"
+        head_mask = head_mask.to(dtype=self.dtype)
+        return head_mask
+
+    def get_head_mask(
+        self,
+        head_mask: Optional[torch.Tensor],
+        num_hidden_layers: int,
+        is_attention_chunked: bool = False,
+    ) -> Union[torch.Tensor, List[None]]:
+        if head_mask is None:
+            return [None] * num_hidden_layers
+        head_mask = self._convert_head_mask_to_5d(head_mask, num_hidden_layers)
+        if is_attention_chunked:
+            head_mask = head_mask.unsqueeze(-1)
+        return head_mask
+
     def forward(
         self,
         input_ids: Optional[torch.Tensor] = None,
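The two head-mask helpers above are moved verbatim from the old DPLMModel; behaviorally they turn a per-head mask into a 5-D tensor broadcastable over (layers, batch, heads, seq, seq). A standalone sketch with made-up sizes, mirroring the dim() == 1 branch:

import torch

num_layers, num_heads = 2, 4
head_mask = torch.ones(num_heads)  # one entry per attention head
head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
head_mask = head_mask.expand(num_layers, -1, -1, -1, -1)
print(head_mask.shape)  # torch.Size([2, 1, 4, 1, 1])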
@@ -953,21 +954,87 @@ def forward(
             flex_block_mask=flex_block_mask,
         )
         sequence_output = encoder_outputs[0]
-        pooled_output = self.pooler(sequence_output) if self.pooler is not None else None

         if return_dict is False:
-            return (sequence_output, pooled_output) + encoder_outputs[1:]
+            return (sequence_output,) + encoder_outputs[1:]

         return BaseModelOutputWithPoolingAndCrossAttentions(
             last_hidden_state=sequence_output,
-            pooler_output=pooled_output,
             past_key_values=None,
             hidden_states=encoder_outputs.hidden_states,
             attentions=encoder_outputs.attentions,
             cross_attentions=encoder_outputs.cross_attentions,
         )


+class DPLMModel(DPLMPreTrainedModel, EmbeddingMixin):
+    config_class = DPLMConfig
+
+    def __init__(self, config, add_pooling_layer=True):
+        DPLMPreTrainedModel.__init__(self, config)
+        self.config = config
+        self.esm = FAST_DPLM_ENCODER(config)
+        self.pooler = EsmPooler(config) if add_pooling_layer else None
+        self.post_init()
+
+    def get_input_embeddings(self) -> nn.Module:
+        return self.esm.embeddings.word_embeddings
+
+    def set_input_embeddings(self, value):
+        self.esm.embeddings.word_embeddings = value
+
+    def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
+        return self.esm._embed(input_ids, attention_mask)
+
+    def predict_contacts(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
+        return self.esm.predict_contacts(input_ids, attention_mask)
+
+    def forward(
+        self,
+        input_ids: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        encoder_attention_mask: Optional[torch.Tensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
+        outputs = self.esm(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            head_mask=head_mask,
+            inputs_embeds=inputs_embeds,
+            encoder_hidden_states=encoder_hidden_states,
+            encoder_attention_mask=encoder_attention_mask,
+            past_key_values=past_key_values,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+        sequence_output = outputs[0]
+        pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
+
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+        if return_dict is False:
+            return (sequence_output, pooled_output) + outputs[1:]
+
+        return BaseModelOutputWithPoolingAndCrossAttentions(
+            last_hidden_state=sequence_output,
+            pooler_output=pooled_output,
+            past_key_values=None,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+            cross_attentions=outputs.cross_attentions,
+        )
+
+
 class DPLMForMaskedLM(DPLMPreTrainedModel, EmbeddingMixin):
     config_class = DPLMConfig

@@ -994,7 +1061,7 @@ def __init__(self, config, dropout: float = 0.1):
         self.contact_head = None

     def get_input_embeddings(self) -> nn.Module:
-        return self.esm.embeddings.word_embeddings
+        return self.esm.get_input_embeddings()

     def get_output_embeddings(self):
         return self.lm_head.decoder
@@ -1064,7 +1131,7 @@ class DPLMForSequenceClassification(DPLMPreTrainedModel, EmbeddingMixin):
     config_class = DPLMConfig

     def get_input_embeddings(self) -> nn.Module:
-        return self.esm.embeddings.word_embeddings
+        return self.esm.get_input_embeddings()

     def __init__(self, config):
         DPLMPreTrainedModel.__init__(self, config)
@@ -1134,7 +1201,7 @@ class DPLMForTokenClassification(DPLMPreTrainedModel, EmbeddingMixin):
     config_class = DPLMConfig

     def get_input_embeddings(self) -> nn.Module:
-        return self.esm.embeddings.word_embeddings
+        return self.esm.get_input_embeddings()

     def __init__(self, config):
         DPLMPreTrainedModel.__init__(self, config)
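Given the commit message, the intended effect of both files' changes is that pretrained checkpoints load without any key remapping. A hedged sanity check, assuming a DPLMConfig instance and a local checkpoint file (both placeholders):

import torch

model = DPLMModel(config)  # config: a DPLMConfig instance, assumed in scope
state_dict = torch.load("dplm_checkpoint.bin", map_location="cpu")  # hypothetical file
missing, unexpected = model.load_state_dict(state_dict, strict=False)
# With FAST_DPLM_ENCODER nested as `self.esm`, keys such as
# 'esm.embeddings.word_embeddings.weight' match the checkpoint directly, so
# both lists should come back (near) empty.
print("missing:", missing)
print("unexpected:", unexpected)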
