Skip to content

Commit c81d6b6

Browse files
authored
Merge pull request #366 from bigict/head
fix: symmetric_ground_truth_renaming
2 parents b60e5ab + 37be960 commit c81d6b6

1 file changed

Lines changed: 42 additions & 23 deletions

File tree

profold2/model/functional.py

Lines changed: 42 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1278,7 +1278,7 @@ def symmetric_ground_truth_find_optimal(
12781278
assert coord_exists.shape == coord_alt_mask.shape
12791279
assert coord_exists.shape == coord_is_symmetric.shape
12801280

1281-
def to_distance(point):
1281+
def to_distance(points):
12821282
return torch.sqrt(
12831283
epsilon + torch.sum(
12841284
(points[..., :, None, :, None, :] - points[..., None, :, None, :, :])**2,
@@ -1306,9 +1306,9 @@ def to_lddt(x, y):
13061306
# in cols.
13071307
# shape (N ,N, 14, 14)
13081308
mask = (
1309-
rearrange(coord_mask * coord_is_symmetric, '... i c -> ... i () c ()') * # rows
1310-
rearrange(coord_mask * (~coord_is_symmetric), '... j d -> ... () j () d')
1311-
) # cols
1309+
rearrange(coord_mask * coord_is_symmetric, '... i c -> ... i () c ()') * # rows
1310+
rearrange(coord_mask * (~coord_is_symmetric), '... j d -> ... () j () d') # cols
1311+
)
13121312

13131313
# Aggregate distances for each residue to the non-amibuguous atoms.
13141314
# shape (N)
@@ -1320,7 +1320,7 @@ def to_lddt(x, y):
13201320
return per_res_lddt_alt < per_res_lddt # alt_naming_is_better
13211321

13221322

1323-
def symmetric_ground_truth_rename(
1323+
def symmetric_ground_truth_renaming(
13241324
coord_pred,
13251325
coord_exists,
13261326
coord,
@@ -1341,21 +1341,20 @@ def symmetric_ground_truth_rename(
13411341
coord_is_symmetric,
13421342
epsilon=epsilon
13431343
)
1344-
coord_renamed = (
1345-
rearrange(~alt_naming_is_better, '... i -> ... i () ()') * coord +
1346-
rearrange(alt_naming_is_better, '... i -> ... i () ()') * coord_alt
1347-
)
1348-
coord_renamed_mask = (
1349-
rearrange(~alt_naming_is_better, '... i -> ... i () ()') * coord_mask +
1350-
rearrange(alt_naming_is_better, '... i -> ... i () ()') * coord_alt_mask
1351-
)
13521344

1353-
return dict(
1354-
alt_naming_is_better=alt_naming_is_better, # pylint: disable=use-dict-literal
1355-
coord_renamed=coord_renamed,
1356-
coord_renamed_mask=coord_renamed_mask
1345+
def renaming(m, x, y):
1346+
return (~m) * x + m * y
1347+
coord_renamed = renaming(alt_naming_is_better[..., None, None], coord, coord_alt)
1348+
coord_renamed_mask = renaming(
1349+
alt_naming_is_better[..., None], coord_mask, coord_alt_mask
13571350
)
13581351

1352+
return {
1353+
'alt_naming_is_better': alt_naming_is_better,
1354+
'coord_renamed': coord_renamed,
1355+
'coord_renamed_mask': coord_renamed_mask
1356+
}
1357+
13591358

13601359
def contact_precision(
13611360
pred: torch.Tensor,
@@ -1676,10 +1675,30 @@ def multi_chain_permutation_alignment(value, batch):
16761675

16771676
# Apply the optimal coordinates
16781677
batch['coord'][bdx], batch['coord_mask'][bdx] = coord_opt, coord_mask_opt
1679-
if torch.any(batch['seq_anchor'] > 0) and 'coord_alt' in batch:
1680-
batch.update(
1681-
symmetric_ground_truth_create_alt(
1682-
batch['seq'], batch['coord'], batch['coord_mask']
1683-
)
1684-
)
1678+
if torch.any(batch['seq_anchor'] > 0):
1679+
if 'coord_alt' in batch:
1680+
batch.update(
1681+
symmetric_ground_truth_create_alt(
1682+
batch['seq'], batch['coord'], batch['coord_mask']
1683+
)
1684+
)
1685+
if 'backbone_affine' in batch:
1686+
n_idx = residue_constants.atom_order['N']
1687+
ca_idx = residue_constants.atom_order['CA']
1688+
c_idx = residue_constants.atom_order['C']
1689+
1690+
batch['backbone_affine'] = rigids_from_3x3(
1691+
batch['coord'], indices=(c_idx, ca_idx, n_idx)
1692+
)
1693+
1694+
coord_mask = batch['coord_mask']
1695+
coord_mask = torch.stack(
1696+
(coord_mask[..., c_idx], coord_mask[..., ca_idx], coord_mask[..., n_idx]),
1697+
dim=-1
1698+
)
1699+
batch['backbone_affine_mask'] = torch.all(coord_mask != 0, dim=-1)
1700+
if 'atom_affine' in batch:
1701+
batch.update(
1702+
rigids_from_positions(batch['seq'], batch['coord'], batch['coord_mask'])
1703+
)
16851704
return batch

0 commit comments

Comments
 (0)