@@ -144,6 +144,10 @@ def __init__(
144144 accept_msa_attn = True ,
145145 accept_frame_attn = False ,
146146 accept_frame_update = False ,
147+ conditional_pos = False ,
148+ conditional_pos_min_bin = 3.25 ,
149+ conditional_pos_max_bin = 20.75 ,
150+ conditional_pos_num_bin = 15 ,
147151 recycling_single_repr = True ,
148152 recycling_frames = False ,
149153 recycling_pos = False ,
@@ -186,6 +190,18 @@ def __init__(
186190 # msa to single activations
187191 self .to_single_repr = nn .Linear (dim_msa , dim_single )
188192
193+ # conditional params
194+ self .conditional_pos_linear = nn .Linear (
195+ recycling_pos_num_bin , dim_pairwise
196+ ) if conditional_pos else None
197+ if conditional_pos :
198+ conditional_pos_breaks = torch .linspace (
199+ conditional_pos_min_bin ,
200+ conditional_pos_max_bin ,
201+ steps = conditional_pos_num_bin
202+ )
203+ self .register_buffer ('conditional_pos_breaks' , conditional_pos_breaks )
204+
189205 # recycling params
190206 self .recycling_to_msa_repr = nn .Linear (
191207 dim_single , dim_msa
@@ -227,7 +243,48 @@ def forward(
227243 else :
228244 msa , msa_mask , msa_embed = None , None , None # msa as features disabled
229245 del seq_embed , msa_embed
230- recyclables , = map (batch .get , ('recyclables' , ))
246+ # FIXME: fake recyclables
247+ if 'recyclables' not in batch :
248+ b , n , device = seq .shape [:- 1 ], seq .shape [- 1 ], seq .device
249+ _ , dim_msa , dim_pairwise = self .dim # embedd_dim_get(self.dim)
250+ if exists (self .conditional_pos_linear ):
251+ # assert all(key in batch for key in ('cond_mask', 'coord', 'coord_mask'))
252+ if 'cond_mask' in batch :
253+ cond_mask = batch ['cond_mask' ]
254+ else :
255+ cond_mask = torch .zeros (b + (n , ), device = device )
256+ cond_mask = cond_mask [..., :, None ] * cond_mask [..., None , :]
257+
258+ if 'coord' in batch :
259+ coord = batch ['coord' ]
260+ else :
261+ coord = torch .zeros (
262+ b + (n , residue_constants .atom14_type_num , 3 ), device = device
263+ )
264+ if 'coord_mask' in batch :
265+ coord_mask = batch ['coord_mask' ]
266+ else :
267+ coord_mask = torch .zeros (
268+ b + (n , residue_constants .atom14_type_num ), device = device
269+ )
270+ pseudo_beta , pseudo_beta_mask = functional .pseudo_beta_fn (
271+ seq , coord , coord_mask
272+ )
273+ dgram = functional .distogram_from_positions (
274+ self .conditional_pos_breaks , pseudo_beta
275+ )
276+ pairwise_repr = self .conditional_pos_linear (dgram ) * pseudo_beta_mask [..., None ]
277+
278+ pairwise_repr = pairwise_repr * cond_mask [..., None ]
279+ else :
280+ pairwise_repr = torch .zeros (b + (n , n , dim_pairwise ), device = device )
281+
282+ batch ['recyclables' ] = Recyclables (
283+ msa_first_row_repr = torch .zeros (b + (n , dim_msa ), device = device ),
284+ pairwise_repr = pairwise_repr ,
285+ coords = torch .zeros (b + (n , residue_constants .atom_type_num , 3 ), device = device )
286+ )
287+ recyclables = batch ['recyclables' ]
231288
232289 representations = {'recycling' : return_recyclables }
233290
@@ -350,6 +407,8 @@ def from_config(config):
350407 'template_depth' ,
351408 'num_tokens' ,
352409 'num_msa_tokens' ,
410+ 'conditional_pos' ,
411+ 'recycling_frames' ,
353412 'recycling_single_repr' ,
354413 'recycling_pos' ,
355414 ):
@@ -367,18 +426,6 @@ def embeddings(self):
367426 def forward (self , batch , * , num_recycle = 0 , ** kwargs ):
368427 assert num_recycle >= 0
369428
370- # variables
371- seq = batch ['seq' ]
372- b , n , device = seq .shape [:- 1 ], seq .shape [- 1 ], seq .device
373- # FIXME: fake recyclables
374- if 'recyclables' not in batch :
375- _ , dim_msa , dim_pairwise = self .impl .dim # embedd_dim_get(self.impl.dim)
376- batch ['recyclables' ] = Recyclables (
377- msa_first_row_repr = torch .zeros (b + (n , dim_msa ), device = device ),
378- pairwise_repr = torch .zeros (b + (n , n , dim_pairwise ), device = device ),
379- coords = torch .zeros (b + (n , residue_constants .atom_type_num , 3 ), device = device )
380- )
381-
382429 if self .training :
383430 num_recycle = random .randint (0 , num_recycle )
384431 cycling_function = functools .partial (
0 commit comments