@@ -1278,7 +1278,7 @@ def symmetric_ground_truth_find_optimal(
12781278 assert coord_exists .shape == coord_alt_mask .shape
12791279 assert coord_exists .shape == coord_is_symmetric .shape
12801280
1281- def to_distance (point ):
1281+ def to_distance (points ):
12821282 return torch .sqrt (
12831283 epsilon + torch .sum (
12841284 (points [..., :, None , :, None , :] - points [..., None , :, None , :, :])** 2 ,
@@ -1306,9 +1306,9 @@ def to_lddt(x, y):
13061306 # in cols.
13071307 # shape (N ,N, 14, 14)
13081308 mask = (
1309- rearrange (coord_mask * coord_is_symmetric , '... i c -> ... i () c ()' ) * # rows
1310- rearrange (coord_mask * (~ coord_is_symmetric ), '... j d -> ... () j () d' )
1311- ) # cols
1309+ rearrange (coord_mask * coord_is_symmetric , '... i c -> ... i () c ()' ) * # rows
1310+ rearrange (coord_mask * (~ coord_is_symmetric ), '... j d -> ... () j () d' ) # cols
1311+ )
13121312
13131313 # Aggregate distances for each residue to the non-amibuguous atoms.
13141314 # shape (N)
@@ -1320,7 +1320,7 @@ def to_lddt(x, y):
13201320 return per_res_lddt_alt < per_res_lddt # alt_naming_is_better
13211321
13221322
1323- def symmetric_ground_truth_rename (
1323+ def symmetric_ground_truth_renaming (
13241324 coord_pred ,
13251325 coord_exists ,
13261326 coord ,
@@ -1341,21 +1341,20 @@ def symmetric_ground_truth_rename(
13411341 coord_is_symmetric ,
13421342 epsilon = epsilon
13431343 )
1344- coord_renamed = (
1345- rearrange (~ alt_naming_is_better , '... i -> ... i () ()' ) * coord +
1346- rearrange (alt_naming_is_better , '... i -> ... i () ()' ) * coord_alt
1347- )
1348- coord_renamed_mask = (
1349- rearrange (~ alt_naming_is_better , '... i -> ... i () ()' ) * coord_mask +
1350- rearrange (alt_naming_is_better , '... i -> ... i () ()' ) * coord_alt_mask
1351- )
13521344
1353- return dict (
1354- alt_naming_is_better = alt_naming_is_better , # pylint: disable=use-dict-literal
1355- coord_renamed = coord_renamed ,
1356- coord_renamed_mask = coord_renamed_mask
1345+ def renaming (m , x , y ):
1346+ return (~ m ) * x + m * y
1347+ coord_renamed = renaming (alt_naming_is_better [..., None , None ], coord , coord_alt )
1348+ coord_renamed_mask = renaming (
1349+ alt_naming_is_better [..., None ], coord_mask , coord_alt_mask
13571350 )
13581351
1352+ return {
1353+ 'alt_naming_is_better' : alt_naming_is_better ,
1354+ 'coord_renamed' : coord_renamed ,
1355+ 'coord_renamed_mask' : coord_renamed_mask
1356+ }
1357+
13591358
13601359def contact_precision (
13611360 pred : torch .Tensor ,
@@ -1676,10 +1675,30 @@ def multi_chain_permutation_alignment(value, batch):
16761675
16771676 # Apply the optimal coordinates
16781677 batch ['coord' ][bdx ], batch ['coord_mask' ][bdx ] = coord_opt , coord_mask_opt
1679- if torch .any (batch ['seq_anchor' ] > 0 ) and 'coord_alt' in batch :
1680- batch .update (
1681- symmetric_ground_truth_create_alt (
1682- batch ['seq' ], batch ['coord' ], batch ['coord_mask' ]
1683- )
1684- )
1678+ if torch .any (batch ['seq_anchor' ] > 0 ):
1679+ if 'coord_alt' in batch :
1680+ batch .update (
1681+ symmetric_ground_truth_create_alt (
1682+ batch ['seq' ], batch ['coord' ], batch ['coord_mask' ]
1683+ )
1684+ )
1685+ if 'backbone_affine' in batch :
1686+ n_idx = residue_constants .atom_order ['N' ]
1687+ ca_idx = residue_constants .atom_order ['CA' ]
1688+ c_idx = residue_constants .atom_order ['C' ]
1689+
1690+ batch ['backbone_affine' ] = rigids_from_3x3 (
1691+ batch ['coord' ], indices = (c_idx , ca_idx , n_idx )
1692+ )
1693+
1694+ coord_mask = batch ['coord_mask' ]
1695+ coord_mask = torch .stack (
1696+ (coord_mask [..., c_idx ], coord_mask [..., ca_idx ], coord_mask [..., n_idx ]),
1697+ dim = - 1
1698+ )
1699+ batch ['backbone_affine_mask' ] = torch .all (coord_mask != 0 , dim = - 1 )
1700+ if 'atom_affine' in batch :
1701+ batch .update (
1702+ rigids_from_positions (batch ['seq' ], batch ['coord' ], batch ['coord_mask' ])
1703+ )
16851704 return batch
0 commit comments