8 changes: 8 additions & 0 deletions profold2/command/trainer.py
@@ -214,6 +214,9 @@ def create_cycling_data(
       accept_msa_attn=args.model_evoformer_accept_msa_attn,
       accept_frame_attn=args.model_evoformer_accept_frame_attn,
       accept_frame_update=args.model_evoformer_accept_frame_update,
+      conditional_pos=args.model_conditional_pos,
+      recycling_frames=args.model_recycling_frames,
+      recycling_pos=args.model_recycling_pos,
       headers=headers
   )
   ####
@@ -743,6 +746,11 @@ def add_arguments(parser):  # pylint: disable=redefined-outer-name
       default='model_headers_main.json',
       help='json format headers of model.'
   )
+  parser.add_argument(
+      '--model_conditional_pos',
+      action='store_true',
+      help='enable backbone atom position conditional.'
+  )
   parser.add_argument(
       '--model_recycles', type=int, default=2, help='number of recycles in model.'
   )
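Note: with `action='store_true'` the flag defaults to `False`, so conditional position embedding stays off unless explicitly requested, and hunk 1 above forwards it to the model as `conditional_pos=`. A minimal sketch of the round trip (the standalone parser here is a stand-in for the real `add_arguments` wiring):

```python
import argparse

# Sketch: how the new flag parses; mirrors the add_argument call in this hunk.
parser = argparse.ArgumentParser()
parser.add_argument(
    '--model_conditional_pos',
    action='store_true',
    help='enable backbone atom position conditional.'
)

args = parser.parse_args([])
assert not args.model_conditional_pos  # off by default
args = parser.parse_args(['--model_conditional_pos'])
assert args.model_conditional_pos  # forwarded as conditional_pos= to the model
```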
73 changes: 60 additions & 13 deletions profold2/model/alphafold2.py
@@ -144,6 +144,10 @@ def __init__(
       accept_msa_attn=True,
       accept_frame_attn=False,
       accept_frame_update=False,
+      conditional_pos=False,
+      conditional_pos_min_bin=3.25,
+      conditional_pos_max_bin=20.75,
+      conditional_pos_num_bin=15,
       recycling_single_repr=True,
       recycling_frames=False,
       recycling_pos=False,
@@ -186,6 +190,18 @@ def __init__(
     # msa to single activations
     self.to_single_repr = nn.Linear(dim_msa, dim_single)
 
+    # conditional params
+    self.conditional_pos_linear = nn.Linear(
+        conditional_pos_num_bin, dim_pairwise
+    ) if conditional_pos else None
+    if conditional_pos:
+      conditional_pos_breaks = torch.linspace(
+          conditional_pos_min_bin,
+          conditional_pos_max_bin,
+          steps=conditional_pos_num_bin
+      )
+      self.register_buffer('conditional_pos_breaks', conditional_pos_breaks)
+
     # recycling params
     self.recycling_to_msa_repr = nn.Linear(
         dim_single, dim_msa
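The breaks registered above span 3.25 to 20.75 angstroms in 15 steps, matching AlphaFold's recycling distogram bins. `functional.distogram_from_positions` is not part of this diff; assuming it follows the usual AlphaFold-style binning (each break is a lower bin edge, last bin open-ended), it would look roughly like this sketch:

```python
import torch

def distogram_from_positions(breaks, pos):
  """Sketch of an AlphaFold-style distogram (an assumption; the real
  profold2 functional implementation is not shown in this diff).

  breaks: (num_bin,) lower bin edges in angstroms.
  pos: (..., n, 3) pseudo-beta coordinates.
  Returns (..., n, n, num_bin) one-hot distance features.
  """
  sq_breaks = breaks**2
  # squared pairwise distances, kept as (..., n, n, 1) for broadcasting
  dist2 = torch.sum(
      (pos[..., :, None, :] - pos[..., None, :, :])**2, dim=-1, keepdim=True
  )
  # each bin is (breaks[i], breaks[i + 1]); the last bin is open-ended
  upper = torch.cat([sq_breaks[1:], sq_breaks.new_tensor([1e9])])
  return ((dist2 > sq_breaks) & (dist2 < upper)).float()
```

With `num_bin=15` the output width matches the `nn.Linear(conditional_pos_num_bin, dim_pairwise)` projection above.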
@@ -227,7 +243,48 @@ def forward(
     else:
       msa, msa_mask, msa_embed = None, None, None  # msa as features disabled
     del seq_embed, msa_embed
-    recyclables, = map(batch.get, ('recyclables', ))
+    # FIXME: fake recyclables
+    if 'recyclables' not in batch:
+      b, n, device = seq.shape[:-1], seq.shape[-1], seq.device
+      _, dim_msa, dim_pairwise = self.dim  # embedd_dim_get(self.dim)
+      if exists(self.conditional_pos_linear):
+        # assert all(key in batch for key in ('cond_mask', 'coord', 'coord_mask'))
+        if 'cond_mask' in batch:
+          cond_mask = batch['cond_mask']
+        else:
+          cond_mask = torch.zeros(b + (n,), device=device)
+        cond_mask = cond_mask[..., :, None] * cond_mask[..., None, :]
+
+        if 'coord' in batch:
+          coord = batch['coord']
+        else:
+          coord = torch.zeros(
+              b + (n, residue_constants.atom14_type_num, 3), device=device
+          )
+        if 'coord_mask' in batch:
+          coord_mask = batch['coord_mask']
+        else:
+          coord_mask = torch.zeros(
+              b + (n, residue_constants.atom14_type_num), device=device
+          )
+        pseudo_beta, pseudo_beta_mask = functional.pseudo_beta_fn(
+            seq, coord, coord_mask
+        )
+        dgram = functional.distogram_from_positions(
+            self.conditional_pos_breaks, pseudo_beta
+        )
+        # mask both residues of each pair; a row-only mask would not broadcast
+        pair_mask = pseudo_beta_mask[..., :, None] * pseudo_beta_mask[..., None, :]
+        pairwise_repr = self.conditional_pos_linear(dgram) * pair_mask[..., None]
+
+        pairwise_repr = pairwise_repr * cond_mask[..., None]
+      else:
+        pairwise_repr = torch.zeros(b + (n, n, dim_pairwise), device=device)
+
+      batch['recyclables'] = Recyclables(
+          msa_first_row_repr=torch.zeros(b + (n, dim_msa), device=device),
+          pairwise_repr=pairwise_repr,
+          coords=torch.zeros(b + (n, residue_constants.atom_type_num, 3), device=device)
+      )
+    recyclables = batch['recyclables']
 
     representations = {'recycling': return_recyclables}
 
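The conditioning gate above turns the per-residue `cond_mask` into a pairwise mask via an outer product, so the distogram signal survives only for pairs where both residues are conditioned on; everything else starts recycling from zeros. A toy illustration:

```python
import torch

n = 4
cond_mask = torch.tensor([1., 1., 0., 0.])  # condition on residues 0 and 1 only
# outer product: pair (i, j) is kept iff both i and j are conditioned
pair_mask = cond_mask[..., :, None] * cond_mask[..., None, :]  # (n, n)

pairwise_repr = torch.randn(n, n, 8)  # stand-in pair features
pairwise_repr = pairwise_repr * pair_mask[..., None]
assert pairwise_repr[0, 1].abs().sum() > 0  # conditioned pair kept
assert pairwise_repr[2].abs().sum() == 0    # pairs touching residue 2 zeroed
```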
@@ -350,6 +407,8 @@ def from_config(config):
         'template_depth',
         'num_tokens',
         'num_msa_tokens',
+        'conditional_pos',
+        'recycling_frames',
         'recycling_single_repr',
         'recycling_pos',
     ):
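For completeness, `from_config` plucks the whitelisted keys from the model config when present; a hypothetical config fragment (the values and exact schema are illustrative, not taken from this diff):

```python
# Hypothetical config fragment consumed by from_config; only the keys
# whitelisted above are forwarded to the constructor.
model_config = dict(
    template_depth=2,        # illustrative values
    num_tokens=21,
    num_msa_tokens=23,
    conditional_pos=True,    # new in this PR
    recycling_frames=False,  # new in this PR
    recycling_single_repr=True,
    recycling_pos=False,
)
```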
@@ -367,18 +426,6 @@ def embeddings(self):
   def forward(self, batch, *, num_recycle=0, **kwargs):
     assert num_recycle >= 0
 
-    # variables
-    seq = batch['seq']
-    b, n, device = seq.shape[:-1], seq.shape[-1], seq.device
-    # FIXME: fake recyclables
-    if 'recyclables' not in batch:
-      _, dim_msa, dim_pairwise = self.impl.dim  # embedd_dim_get(self.impl.dim)
-      batch['recyclables'] = Recyclables(
-          msa_first_row_repr=torch.zeros(b + (n, dim_msa), device=device),
-          pairwise_repr=torch.zeros(b + (n, n, dim_pairwise), device=device),
-          coords=torch.zeros(b + (n, residue_constants.atom_type_num, 3), device=device)
-      )
-
     if self.training:
       num_recycle = random.randint(0, num_recycle)
     cycling_function = functools.partial(
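One detail worth noting in the surviving context: `random.randint` is inclusive on both ends, so with the default `--model_recycles 2` training samples 0, 1, or 2 recycling iterations uniformly:

```python
import random

num_recycle = 2  # --model_recycles default
counts = {k: 0 for k in range(num_recycle + 1)}
for _ in range(10000):
  counts[random.randint(0, num_recycle)] += 1
print(counts)  # roughly uniform over {0, 1, 2}
```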