diff --git a/profold2/model/head.py b/profold2/model/head.py index 7079dea7..bec8f091 100644 --- a/profold2/model/head.py +++ b/profold2/model/head.py @@ -479,7 +479,7 @@ def backbone_fape_loss(pred_frames_list, gt_frames, frames_mask): if 'seq_color' in batch: clamp_ratio = torch.where( batch['seq_color'][..., :, None] == batch['seq_color'][..., None, :], - clamp_ratio, 0. + clamp_ratio, self.params.get('fape_interchain_clamp_ratio', clamp_ratio) ) _, pred_points = pred_frames