diff --git a/profold2/model/alphafold2.py b/profold2/model/alphafold2.py index d5e586e7..5bd9c7d7 100644 --- a/profold2/model/alphafold2.py +++ b/profold2/model/alphafold2.py @@ -243,9 +243,9 @@ def forward( else: msa, msa_mask, msa_embed = None, None, None # msa as features disabled del seq_embed, msa_embed + b, n, device = seq.shape[:-1], seq.shape[-1], seq.device # FIXME: fake recyclables if 'recyclables' not in batch: - b, n, device = seq.shape[:-1], seq.shape[-1], seq.device _, dim_msa, dim_pairwise = self.dim # embedd_dim_get(self.dim) batch['recyclables'] = Recyclables( msa_first_row_repr=torch.zeros(b + (n, dim_msa), device=device),