From 16395d92832b2aecf58b8b7233f9a84a296e55c1 Mon Sep 17 00:00:00 2001 From: hmacdope Date: Wed, 27 Aug 2025 17:38:04 +1000 Subject: [PATCH 1/5] add in openfold inference pipeline copy --- openadmet/toolkit/cofolding/openfold.py | 125 ++++++++++++++++++++++++ 1 file changed, 125 insertions(+) create mode 100644 openadmet/toolkit/cofolding/openfold.py diff --git a/openadmet/toolkit/cofolding/openfold.py b/openadmet/toolkit/cofolding/openfold.py new file mode 100644 index 0000000..6f302f7 --- /dev/null +++ b/openadmet/toolkit/cofolding/openfold.py @@ -0,0 +1,125 @@ +import gc +import os +import tempfile +import pandas as pd +from pathlib import Path +from shutil import copyfile +from typing import Optional, Union +import subprocess + +import numpy as np +import torch +from loguru import logger + +from pydantic import Field + +from openadmet.toolkit.cofolding.cofold_base import CoFoldingEngine + + +try: + from openfold3.core.config import config_utils + from openfold3.core.data.pipelines.preprocessing.template import TemplatePreprocessor + from openfold3.core.data.tools.colabfold_msa_server import preprocess_colabfold_msas + from openfold3.entry_points.experiment_runner import ( + InferenceExperimentRunner, + TrainingExperimentRunner, + ) + from openfold3.entry_points.validator import ( + InferenceExperimentConfig, + TrainingExperimentConfig, + ) + from openfold3.projects.of3_all_atom.config.dataset_config_components import ( + colabfold_msa_settings, + ) + from openfold3.projects.of3_all_atom.config.inference_query_format import ( + InferenceQuerySet, + ) +except ImportError: + HAS_OPENFOLD3 = False +else: + HAS_OPENFOLD3 = True + + +class OpenFold3CofoldingEngine(CoFoldingEngine): + + use_msa_server: bool = Field( + False, description="Use MSA server for multiple sequence alignment" + ) + num_diffusion_samples: int = Field(None, description="Number of diffusion samples to generate for each query") + + num_model_seeds: int = Field(None, description="Number of model seeds to use for each query") + + use_msa_server: bool = Field(True, description="Use MSA server for multiple sequence alignment") + + use_templates: bool = Field(False, description="Use templates for structure prediction") + + + def inference( + self, + query_json: Path, + inference_ckpt_path: Path, + runner_yaml: Path | None = None, + output_dir: Path | None = None, + ): + if not HAS_OPENFOLD3: + raise ImportError("OpenFold3 is not installed.") + + runner_args = config_utils.load_yaml(runner_yaml) if runner_yaml else dict() + + expt_config = InferenceExperimentConfig( + inference_ckpt_path=inference_ckpt_path, **runner_args + ) + expt_runner = InferenceExperimentRunner( + expt_config, + self.num_diffusion_samples, + self.num_model_seeds, + self.use_msa_server, + self.use_templates, + output_dir, + ) + + # Dump experiment runner + import json + + with open(output_dir / "experiment_config.json", "w") as f: + json.dump(expt_config.model_dump_json(indent=2), f) + + # Load inference query set + query_set = InferenceQuerySet.from_json(query_json) + + # Perform MSA computation if selected + # update query_set with MSA paths + if expt_runner.use_msa_server: + logger.info("Using ColabFold MSA server for alignments.") + query_set = preprocess_colabfold_msas( + inference_query_set=query_set, + compute_settings=expt_config.msa_computation_settings, + ) + + # Update the msa dataset config settings + updated_dataset_config_kwargs = expt_config.dataset_config_kwargs.model_copy( + update={"msa": colabfold_msa_settings} + ) + expt_config = expt_config.model_copy( + update={"dataset_config_kwargs": updated_dataset_config_kwargs} + ) + else: + expt_config.msa_computation_settings.cleanup_msa_dir = False + + # Preprocess template alignments and optionally template structures + if expt_runner.use_templates: + logger.info("Using templates for inference.") + template_preprocessor = TemplatePreprocessor( + input_set=query_set, + config=expt_config.dataset_config_kwargs.template_preprocessor, + ) + template_preprocessor() + else: + logger.info("Not using templates for inference.") + + # Run the forward pass + expt_runner.setup() + expt_runner.run(query_set) + expt_runner.cleanup() + + logger.info("Inference completed successfully.") From fb3366cd6e809399ca125dcaa5694805eb9cf360 Mon Sep 17 00:00:00 2001 From: hmacdope Date: Wed, 27 Aug 2025 17:47:11 +1000 Subject: [PATCH 2/5] add OF3 --- openadmet/toolkit/cofolding/openfold.py | 1 + 1 file changed, 1 insertion(+) diff --git a/openadmet/toolkit/cofolding/openfold.py b/openadmet/toolkit/cofolding/openfold.py index 6f302f7..1ba8008 100644 --- a/openadmet/toolkit/cofolding/openfold.py +++ b/openadmet/toolkit/cofolding/openfold.py @@ -53,6 +53,7 @@ class OpenFold3CofoldingEngine(CoFoldingEngine): use_templates: bool = Field(False, description="Use templates for structure prediction") + inference_ckpt_path: Path = Field(None, description="Path to the inference checkpoint") def inference( self, From e4ec5cb278f43cab33a8f5f76118c4ed470b31c3 Mon Sep 17 00:00:00 2001 From: hmacdope Date: Thu, 28 Aug 2025 17:00:18 +1000 Subject: [PATCH 3/5] minor bugfixes --- openadmet/toolkit/cofolding/openfold.py | 20 +++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/openadmet/toolkit/cofolding/openfold.py b/openadmet/toolkit/cofolding/openfold.py index 1ba8008..fca1473 100644 --- a/openadmet/toolkit/cofolding/openfold.py +++ b/openadmet/toolkit/cofolding/openfold.py @@ -11,7 +11,7 @@ import torch from loguru import logger -from pydantic import Field +from pydantic import Field, field_validator from openadmet.toolkit.cofolding.cofold_base import CoFoldingEngine @@ -53,14 +53,20 @@ class OpenFold3CofoldingEngine(CoFoldingEngine): use_templates: bool = Field(False, description="Use templates for structure prediction") - inference_ckpt_path: Path = Field(None, description="Path to the inference checkpoint") + inference_ckpt_path: Path = Field(description="Path to the inference checkpoint") + + + @field_validator("inference_ckpt_path") + def check_inference_ckpt_path(cls, v): + # path must exist + if not v.exists(): + raise ValueError(f"inference_ckpt_path must exist, got {v}") + return v def inference( self, query_json: Path, - inference_ckpt_path: Path, runner_yaml: Path | None = None, - output_dir: Path | None = None, ): if not HAS_OPENFOLD3: raise ImportError("OpenFold3 is not installed.") @@ -68,7 +74,7 @@ def inference( runner_args = config_utils.load_yaml(runner_yaml) if runner_yaml else dict() expt_config = InferenceExperimentConfig( - inference_ckpt_path=inference_ckpt_path, **runner_args + inference_ckpt_path=self.inference_ckpt_path, **runner_args ) expt_runner = InferenceExperimentRunner( expt_config, @@ -76,13 +82,13 @@ def inference( self.num_model_seeds, self.use_msa_server, self.use_templates, - output_dir, + self.output_dir, ) # Dump experiment runner import json - with open(output_dir / "experiment_config.json", "w") as f: + with open(self.output_dir / "experiment_config.json", "w") as f: json.dump(expt_config.model_dump_json(indent=2), f) # Load inference query set From b4ba2fe23e15105a1c4bc0ddef28032a4cd1b1a0 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 28 Aug 2025 07:01:19 +0000 Subject: [PATCH 4/5] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- openadmet/toolkit/cofolding/openfold.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/openadmet/toolkit/cofolding/openfold.py b/openadmet/toolkit/cofolding/openfold.py index fca1473..ebca9ea 100644 --- a/openadmet/toolkit/cofolding/openfold.py +++ b/openadmet/toolkit/cofolding/openfold.py @@ -16,7 +16,7 @@ from openadmet.toolkit.cofolding.cofold_base import CoFoldingEngine -try: +try: from openfold3.core.config import config_utils from openfold3.core.data.pipelines.preprocessing.template import TemplatePreprocessor from openfold3.core.data.tools.colabfold_msa_server import preprocess_colabfold_msas @@ -64,13 +64,13 @@ def check_inference_ckpt_path(cls, v): return v def inference( - self, + self, query_json: Path, runner_yaml: Path | None = None, ): if not HAS_OPENFOLD3: raise ImportError("OpenFold3 is not installed.") - + runner_args = config_utils.load_yaml(runner_yaml) if runner_yaml else dict() expt_config = InferenceExperimentConfig( From 6b785d7f51022103e9a01643006f5921af767b5f Mon Sep 17 00:00:00 2001 From: hmacdope Date: Thu, 28 Aug 2025 17:03:52 +1000 Subject: [PATCH 5/5] fix dict use --- openadmet/toolkit/cofolding/openfold.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/openadmet/toolkit/cofolding/openfold.py b/openadmet/toolkit/cofolding/openfold.py index ebca9ea..b3cb207 100644 --- a/openadmet/toolkit/cofolding/openfold.py +++ b/openadmet/toolkit/cofolding/openfold.py @@ -71,7 +71,7 @@ def inference( if not HAS_OPENFOLD3: raise ImportError("OpenFold3 is not installed.") - runner_args = config_utils.load_yaml(runner_yaml) if runner_yaml else dict() + runner_args = config_utils.load_yaml(runner_yaml) if runner_yaml else {} expt_config = InferenceExperimentConfig( inference_ckpt_path=self.inference_ckpt_path, **runner_args