Skip to content

Commit e4601e4

Browse files
authored
Merge pull request #399 from bigict/model
feat: predict with conditional coords
2 parents ea08dc2 + 7d0c375 commit e4601e4

2 files changed

Lines changed: 68 additions & 13 deletions

File tree

profold2/command/trainer.py

Lines changed: 8 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -214,6 +214,9 @@ def create_cycling_data(
214214
accept_msa_attn=args.model_evoformer_accept_msa_attn,
215215
accept_frame_attn=args.model_evoformer_accept_frame_attn,
216216
accept_frame_update=args.model_evoformer_accept_frame_update,
217+
conditional_pos=args.model_conditional_pos,
218+
recycling_frames=args.model_recycling_frames,
219+
recycling_pos=args.model_recycling_pos,
217220
headers=headers
218221
)
219222
####
@@ -743,6 +746,11 @@ def add_arguments(parser): # pylint: disable=redefined-outer-name
743746
default='model_headers_main.json',
744747
help='json format headers of model.'
745748
)
749+
parser.add_argument(
750+
'--model_conditional_pos',
751+
action='store_true',
752+
help='enable backbone atom position conditional.'
753+
)
746754
parser.add_argument(
747755
'--model_recycles', type=int, default=2, help='number of recycles in model.'
748756
)

profold2/model/alphafold2.py

Lines changed: 60 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -144,6 +144,10 @@ def __init__(
144144
accept_msa_attn=True,
145145
accept_frame_attn=False,
146146
accept_frame_update=False,
147+
conditional_pos=False,
148+
conditional_pos_min_bin=3.25,
149+
conditional_pos_max_bin=20.75,
150+
conditional_pos_num_bin=15,
147151
recycling_single_repr=True,
148152
recycling_frames=False,
149153
recycling_pos=False,
@@ -186,6 +190,18 @@ def __init__(
186190
# msa to single activations
187191
self.to_single_repr = nn.Linear(dim_msa, dim_single)
188192

193+
# conditional params
194+
self.conditional_pos_linear = nn.Linear(
195+
conditional_pos_num_bin, dim_pairwise
196+
) if conditional_pos else None
197+
if conditional_pos:
198+
conditional_pos_breaks = torch.linspace(
199+
conditional_pos_min_bin,
200+
conditional_pos_max_bin,
201+
steps=conditional_pos_num_bin
202+
)
203+
self.register_buffer('conditional_pos_breaks', conditional_pos_breaks)
204+
189205
# recycling params
190206
self.recycling_to_msa_repr = nn.Linear(
191207
dim_single, dim_msa
@@ -227,7 +243,48 @@ def forward(
227243
else:
228244
msa, msa_mask, msa_embed = None, None, None # msa as features disabled
229245
del seq_embed, msa_embed
230-
recyclables, = map(batch.get, ('recyclables', ))
246+
# FIXME: fake recyclables
247+
if 'recyclables' not in batch:
248+
b, n, device = seq.shape[:-1], seq.shape[-1], seq.device
249+
_, dim_msa, dim_pairwise = self.dim # embedd_dim_get(self.dim)
250+
if exists(self.conditional_pos_linear):
251+
# assert all(key in batch for key in ('cond_mask', 'coord', 'coord_mask'))
252+
if 'cond_mask' in batch:
253+
cond_mask = batch['cond_mask']
254+
else:
255+
cond_mask = torch.zeros(b + (n, ), device=device)
256+
cond_mask = cond_mask[..., :, None] * cond_mask[..., None, :]
257+
258+
if 'coord' in batch:
259+
coord = batch['coord']
260+
else:
261+
coord = torch.zeros(
262+
b + (n, residue_constants.atom14_type_num, 3), device=device
263+
)
264+
if 'coord_mask' in batch:
265+
coord_mask = batch['coord_mask']
266+
else:
267+
coord_mask = torch.zeros(
268+
b + (n, residue_constants.atom14_type_num), device=device
269+
)
270+
pseudo_beta, pseudo_beta_mask = functional.pseudo_beta_fn(
271+
seq, coord, coord_mask
272+
)
273+
dgram = functional.distogram_from_positions(
274+
self.conditional_pos_breaks, pseudo_beta
275+
)
276+
pairwise_repr = self.conditional_pos_linear(dgram) * (pseudo_beta_mask[..., :, None] * pseudo_beta_mask[..., None, :])[..., None]
277+
278+
pairwise_repr = pairwise_repr * cond_mask[..., None]
279+
else:
280+
pairwise_repr = torch.zeros(b + (n, n, dim_pairwise), device=device)
281+
282+
batch['recyclables'] = Recyclables(
283+
msa_first_row_repr=torch.zeros(b + (n, dim_msa), device=device),
284+
pairwise_repr=pairwise_repr,
285+
coords=torch.zeros(b + (n, residue_constants.atom_type_num, 3), device=device)
286+
)
287+
recyclables = batch['recyclables']
231288

232289
representations = {'recycling': return_recyclables}
233290

@@ -350,6 +407,8 @@ def from_config(config):
350407
'template_depth',
351408
'num_tokens',
352409
'num_msa_tokens',
410+
'conditional_pos',
411+
'recycling_frames',
353412
'recycling_single_repr',
354413
'recycling_pos',
355414
):
@@ -367,18 +426,6 @@ def embeddings(self):
367426
def forward(self, batch, *, num_recycle=0, **kwargs):
368427
assert num_recycle >= 0
369428

370-
# variables
371-
seq = batch['seq']
372-
b, n, device = seq.shape[:-1], seq.shape[-1], seq.device
373-
# FIXME: fake recyclables
374-
if 'recyclables' not in batch:
375-
_, dim_msa, dim_pairwise = self.impl.dim # embedd_dim_get(self.impl.dim)
376-
batch['recyclables'] = Recyclables(
377-
msa_first_row_repr=torch.zeros(b + (n, dim_msa), device=device),
378-
pairwise_repr=torch.zeros(b + (n, n, dim_pairwise), device=device),
379-
coords=torch.zeros(b + (n, residue_constants.atom_type_num, 3), device=device)
380-
)
381-
382429
if self.training:
383430
num_recycle = random.randint(0, num_recycle)
384431
cycling_function = functools.partial(

0 commit comments

Comments
 (0)