-
Notifications
You must be signed in to change notification settings - Fork 3
Add OpenFold3 inference #81
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,132 @@ | ||
| import gc | ||
| import os | ||
| import tempfile | ||
| import pandas as pd | ||
| from pathlib import Path | ||
| from shutil import copyfile | ||
| from typing import Optional, Union | ||
| import subprocess | ||
|
|
||
| import numpy as np | ||
| import torch | ||
| from loguru import logger | ||
|
|
||
| from pydantic import Field, field_validator | ||
|
|
||
| from openadmet.toolkit.cofolding.cofold_base import CoFoldingEngine | ||
|
|
||
|
|
||
| try: | ||
| from openfold3.core.config import config_utils | ||
| from openfold3.core.data.pipelines.preprocessing.template import TemplatePreprocessor | ||
| from openfold3.core.data.tools.colabfold_msa_server import preprocess_colabfold_msas | ||
| from openfold3.entry_points.experiment_runner import ( | ||
| InferenceExperimentRunner, | ||
| TrainingExperimentRunner, | ||
| ) | ||
| from openfold3.entry_points.validator import ( | ||
| InferenceExperimentConfig, | ||
| TrainingExperimentConfig, | ||
| ) | ||
| from openfold3.projects.of3_all_atom.config.dataset_config_components import ( | ||
| colabfold_msa_settings, | ||
| ) | ||
| from openfold3.projects.of3_all_atom.config.inference_query_format import ( | ||
| InferenceQuerySet, | ||
| ) | ||
| except ImportError: | ||
| HAS_OPENFOLD3 = False | ||
| else: | ||
| HAS_OPENFOLD3 = True | ||
|
|
||
|
|
||
class OpenFold3CofoldingEngine(CoFoldingEngine):
    """Co-folding engine that runs OpenFold3 inference.

    Builds an OpenFold3 ``InferenceExperimentConfig`` and
    ``InferenceExperimentRunner`` from the fields below, optionally computes
    MSAs through the ColabFold server and preprocesses templates, and then
    runs the forward pass on an inference query set.

    OpenFold3 is an optional dependency: the class can be defined without it,
    but ``inference`` raises ``ImportError`` when it could not be imported.
    """

    use_msa_server: bool = Field(
        True, description="Use MSA server for multiple sequence alignment"
    )
    # None is forwarded unchanged to InferenceExperimentRunner.
    num_diffusion_samples: int | None = Field(
        None, description="Number of diffusion samples to generate for each query"
    )
    num_model_seeds: int | None = Field(
        None, description="Number of model seeds to use for each query"
    )
    use_templates: bool = Field(
        False, description="Use templates for structure prediction"
    )
    inference_ckpt_path: Path = Field(description="Path to the inference checkpoint")

    @field_validator("inference_ckpt_path")
    @classmethod
    def check_inference_ckpt_path(cls, v: Path) -> Path:
        """Validate that the inference checkpoint path exists.

        Raises
        ------
        ValueError
            If ``v`` does not exist on disk.
        """
        if not v.exists():
            raise ValueError(f"inference_ckpt_path must exist, got {v}")
        return v

    def inference(
        self,
        query_json: Path,
        runner_yaml: Path | None = None,
    ):
        """Run OpenFold3 inference on the queries described in ``query_json``.

        Parameters
        ----------
        query_json : Path
            JSON file loaded with ``InferenceQuerySet.from_json``.
        runner_yaml : Path | None, optional
            YAML file whose contents are passed as keyword arguments to
            ``InferenceExperimentConfig``. When omitted, only
            ``inference_ckpt_path`` is passed.

        Raises
        ------
        ImportError
            If OpenFold3 is not installed.

        Notes
        -----
        Writes ``experiment_config.json`` into ``self.output_dir``
        (``output_dir`` is not declared in this class; presumably it is
        provided by ``CoFoldingEngine`` -- confirm).
        """
        if not HAS_OPENFOLD3:
            raise ImportError("OpenFold3 is not installed.")

        runner_args = config_utils.load_yaml(runner_yaml) if runner_yaml else {}

        expt_config = InferenceExperimentConfig(
            inference_ckpt_path=self.inference_ckpt_path, **runner_args
        )
        expt_runner = InferenceExperimentRunner(
            expt_config,
            self.num_diffusion_samples,
            self.num_model_seeds,
            self.use_msa_server,
            self.use_templates,
            self.output_dir,
        )

        # Dump the experiment configuration alongside the outputs.
        # model_dump_json() already returns serialized JSON text, so it is
        # written as-is; routing it through json.dump() would encode it a
        # second time and leave a single escaped JSON string in the file.
        config_path = self.output_dir / "experiment_config.json"
        with open(config_path, "w", encoding="utf-8") as f:
            f.write(expt_config.model_dump_json(indent=2))

        # Load inference query set
        query_set = InferenceQuerySet.from_json(query_json)

        # Perform MSA computation if selected; the returned query set replaces
        # the loaded one (per the upstream comment, it carries the MSA paths).
        if expt_runner.use_msa_server:
            logger.info("Using ColabFold MSA server for alignments.")
            query_set = preprocess_colabfold_msas(
                inference_query_set=query_set,
                compute_settings=expt_config.msa_computation_settings,
            )

            # Swap the "msa" entry of the dataset config for the ColabFold MSA
            # settings. model_copy(update=...) returns new objects, so the
            # config is rebuilt in two steps: first the nested
            # dataset_config_kwargs, then the experiment config holding it.
            # NOTE(review): expt_runner was constructed above with the
            # original expt_config object; confirm whether the runner needs
            # to see this updated copy.
            updated_dataset_config_kwargs = (
                expt_config.dataset_config_kwargs.model_copy(
                    update={"msa": colabfold_msa_settings}
                )
            )
            expt_config = expt_config.model_copy(
                update={"dataset_config_kwargs": updated_dataset_config_kwargs}
            )
        else:
            # No server-side MSA computation was done, so turn off MSA
            # directory cleanup in the computation settings.
            expt_config.msa_computation_settings.cleanup_msa_dir = False

        # Preprocess template alignments and optionally template structures
        if expt_runner.use_templates:
            logger.info("Using templates for inference.")
            template_preprocessor = TemplatePreprocessor(
                input_set=query_set,
                config=expt_config.dataset_config_kwargs.template_preprocessor,
            )
            template_preprocessor()
        else:
            logger.info("Not using templates for inference.")

        # Run the forward pass
        expt_runner.setup()
        expt_runner.run(query_set)
        expt_runner.cleanup()

        logger.info("Inference completed successfully.")
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you also add corresponding updates in cli/cofolding.py?