@@ -797,46 +797,24 @@ def forward(
797797 )
798798
799799
800- class DPLMModel (DPLMPreTrainedModel , EmbeddingMixin ):
801- config_class = DPLMConfig
802-
803- def get_input_embeddings (self ) -> nn .Module :
804- return self .embeddings .word_embeddings
800+ class FAST_DPLM_ENCODER (DPLMPreTrainedModel , EmbeddingMixin ):
801+ """Inner encoder class that holds the actual ESM-style weights (embeddings, encoder,
802+ contact_head) so that the weight keys are prefixed with 'esm.' in the outer DPLMModel,
803+ matching pretrained DPLM checkpoints."""
805804
806- def __init__ (self , config , add_pooling_layer = True ):
807- DPLMPreTrainedModel .__init__ (self , config )
805+ def __init__ (self , config , ** kwargs ):
806+ DPLMPreTrainedModel .__init__ (self , config , ** kwargs )
808807 self .config = config
809808 self .embeddings = EsmEmbeddings (config )
810809 self .encoder = ModifiedEsmEncoder (config )
811- self .pooler = EsmPooler (config ) if add_pooling_layer else None
812810 self .contact_head = EsmContactPredictionHead (
813811 in_features = config .num_hidden_layers * config .num_attention_heads ,
814812 bias = True ,
815813 )
816814 self .post_init ()
817815
818- def _convert_head_mask_to_5d (self , head_mask : torch .Tensor , num_hidden_layers : int ) -> torch .Tensor :
819- if head_mask .dim () == 1 :
820- head_mask = head_mask .unsqueeze (0 ).unsqueeze (0 ).unsqueeze (- 1 ).unsqueeze (- 1 )
821- head_mask = head_mask .expand (num_hidden_layers , - 1 , - 1 , - 1 , - 1 )
822- elif head_mask .dim () == 2 :
823- head_mask = head_mask .unsqueeze (1 ).unsqueeze (- 1 ).unsqueeze (- 1 )
824- assert head_mask .dim () == 5 , f"head_mask.dim != 5, got { head_mask .dim ()} "
825- head_mask = head_mask .to (dtype = self .dtype )
826- return head_mask
827-
828- def get_head_mask (
829- self ,
830- head_mask : Optional [torch .Tensor ],
831- num_hidden_layers : int ,
832- is_attention_chunked : bool = False ,
833- ) -> Union [torch .Tensor , List [None ]]:
834- if head_mask is None :
835- return [None ] * num_hidden_layers
836- head_mask = self ._convert_head_mask_to_5d (head_mask , num_hidden_layers )
837- if is_attention_chunked :
838- head_mask = head_mask .unsqueeze (- 1 )
839- return head_mask
816+ def get_input_embeddings (self ) -> nn .Module :
817+ return self .embeddings .word_embeddings
840818
841819 def set_input_embeddings (self , value ):
842820 self .embeddings .word_embeddings = value
@@ -860,6 +838,29 @@ def predict_contacts(self, input_ids: torch.Tensor, attention_mask: torch.Tensor
860838 attns *= attention_mask .unsqueeze (1 ).unsqueeze (2 ).unsqueeze (4 )
861839 return self .contact_head (input_ids , attns )
862840
841+ def _convert_head_mask_to_5d (self , head_mask : torch .Tensor , num_hidden_layers : int ) -> torch .Tensor :
842+ if head_mask .dim () == 1 :
843+ head_mask = head_mask .unsqueeze (0 ).unsqueeze (0 ).unsqueeze (- 1 ).unsqueeze (- 1 )
844+ head_mask = head_mask .expand (num_hidden_layers , - 1 , - 1 , - 1 , - 1 )
845+ elif head_mask .dim () == 2 :
846+ head_mask = head_mask .unsqueeze (1 ).unsqueeze (- 1 ).unsqueeze (- 1 )
847+ assert head_mask .dim () == 5 , f"head_mask.dim != 5, got { head_mask .dim ()} "
848+ head_mask = head_mask .to (dtype = self .dtype )
849+ return head_mask
850+
851+ def get_head_mask (
852+ self ,
853+ head_mask : Optional [torch .Tensor ],
854+ num_hidden_layers : int ,
855+ is_attention_chunked : bool = False ,
856+ ) -> Union [torch .Tensor , List [None ]]:
857+ if head_mask is None :
858+ return [None ] * num_hidden_layers
859+ head_mask = self ._convert_head_mask_to_5d (head_mask , num_hidden_layers )
860+ if is_attention_chunked :
861+ head_mask = head_mask .unsqueeze (- 1 )
862+ return head_mask
863+
863864 def forward (
864865 self ,
865866 input_ids : Optional [torch .Tensor ] = None ,
@@ -953,21 +954,87 @@ def forward(
953954 flex_block_mask = flex_block_mask ,
954955 )
955956 sequence_output = encoder_outputs [0 ]
956- pooled_output = self .pooler (sequence_output ) if self .pooler is not None else None
957957
958958 if return_dict is False :
959- return (sequence_output , pooled_output ) + encoder_outputs [1 :]
959+ return (sequence_output ,) + encoder_outputs [1 :]
960960
961961 return BaseModelOutputWithPoolingAndCrossAttentions (
962962 last_hidden_state = sequence_output ,
963- pooler_output = pooled_output ,
964963 past_key_values = None ,
965964 hidden_states = encoder_outputs .hidden_states ,
966965 attentions = encoder_outputs .attentions ,
967966 cross_attentions = encoder_outputs .cross_attentions ,
968967 )
969968
970969
class DPLMModel(DPLMPreTrainedModel, EmbeddingMixin):
    """Outer DPLM model.

    Wraps FAST_DPLM_ENCODER under the attribute name ``esm`` so checkpoint
    weight keys carry the 'esm.' prefix, and adds an optional EsmPooler on
    top of the encoder's last hidden state.
    """

    config_class = DPLMConfig

    def __init__(self, config, add_pooling_layer=True):
        # Direct base-class init, matching the other DPLM classes in this file.
        DPLMPreTrainedModel.__init__(self, config)
        self.config = config
        self.esm = FAST_DPLM_ENCODER(config)
        self.pooler = EsmPooler(config) if add_pooling_layer else None
        self.post_init()

    def get_input_embeddings(self) -> nn.Module:
        """Return the inner encoder's word-embedding table."""
        return self.esm.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        """Replace the inner encoder's word-embedding table."""
        self.esm.embeddings.word_embeddings = value

    def _embed(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        # Delegate the EmbeddingMixin hook to the inner encoder.
        return self.esm._embed(input_ids, attention_mask)

    def predict_contacts(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        """Delegate contact prediction to the inner encoder's contact head."""
        return self.esm.predict_contacts(input_ids, attention_mask)

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
        """Run the inner encoder, then the optional pooler.

        Returns a plain tuple when ``return_dict`` resolves to False,
        otherwise a BaseModelOutputWithPoolingAndCrossAttentions.
        """
        # FIX: resolve return_dict against the config BEFORE calling the
        # inner module. Previously it was resolved only after the call, so
        # the inner module could return a ModelOutput while the outer wrapper
        # took the tuple path (or vice versa), making the tuple-slicing
        # behavior depend on the unresolved None.
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        outputs = self.esm(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = outputs[0]
        pooled_output = self.pooler(sequence_output) if self.pooler is not None else None

        if return_dict is False:
            return (sequence_output, pooled_output) + outputs[1:]

        return BaseModelOutputWithPoolingAndCrossAttentions(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            past_key_values=None,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            cross_attentions=outputs.cross_attentions,
        )
1036+
1037+
9711038class DPLMForMaskedLM (DPLMPreTrainedModel , EmbeddingMixin ):
9721039 config_class = DPLMConfig
9731040
@@ -994,7 +1061,7 @@ def __init__(self, config, dropout: float = 0.1):
9941061 self .contact_head = None
9951062
9961063 def get_input_embeddings (self ) -> nn .Module :
997- return self .esm .embeddings . word_embeddings
1064+ return self .esm .get_input_embeddings ()
9981065
9991066 def get_output_embeddings (self ):
10001067 return self .lm_head .decoder
@@ -1064,7 +1131,7 @@ class DPLMForSequenceClassification(DPLMPreTrainedModel, EmbeddingMixin):
10641131 config_class = DPLMConfig
10651132
10661133 def get_input_embeddings (self ) -> nn .Module :
1067- return self .esm .embeddings . word_embeddings
1134+ return self .esm .get_input_embeddings ()
10681135
10691136 def __init__ (self , config ):
10701137 DPLMPreTrainedModel .__init__ (self , config )
@@ -1134,7 +1201,7 @@ class DPLMForTokenClassification(DPLMPreTrainedModel, EmbeddingMixin):
11341201 config_class = DPLMConfig
11351202
11361203 def get_input_embeddings (self ) -> nn .Module :
1137- return self .esm .embeddings . word_embeddings
1204+ return self .esm .get_input_embeddings ()
11381205
11391206 def __init__ (self , config ):
11401207 DPLMPreTrainedModel .__init__ (self , config )
0 commit comments