diff --git a/configs/models/tasks/is2re.yaml b/configs/models/tasks/is2re.yaml
index 4b1fe4304f..209b4b2567 100644
--- a/configs/models/tasks/is2re.yaml
+++ b/configs/models/tasks/is2re.yaml
@@ -1,6 +1,7 @@
 default:
   trainer: single
   logger: wandb
+  prevent_load: {}
 
   task:
     dataset: single_point_lmdb
diff --git a/configs/models/tasks/qm7x.yaml b/configs/models/tasks/qm7x.yaml
index defb898e91..7c5f590e55 100644
--- a/configs/models/tasks/qm7x.yaml
+++ b/configs/models/tasks/qm7x.yaml
@@ -1,6 +1,7 @@
 default:
   trainer: single
   logger: wandb
+  prevent_load: {}
   eval_on_test: True
 
   model:
diff --git a/configs/models/tasks/qm9.yaml b/configs/models/tasks/qm9.yaml
index b13954b393..76371e50d6 100644
--- a/configs/models/tasks/qm9.yaml
+++ b/configs/models/tasks/qm9.yaml
@@ -2,6 +2,7 @@ default:
   trainer: single
   logger: wandb
   eval_on_test: True
+  prevent_load: {}
 
   model:
     otf_graph: False
diff --git a/configs/models/tasks/s2ef.yaml b/configs/models/tasks/s2ef.yaml
index ef62591945..9a54254adc 100644
--- a/configs/models/tasks/s2ef.yaml
+++ b/configs/models/tasks/s2ef.yaml
@@ -1,6 +1,8 @@
 default:
   trainer: single
   logger: wandb
+  prevent_load: {}
+
   task:
     dataset: trajectory_lmdb
     description: "Regressing to energies and forces for DFT trajectories from OCP"
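Note: the `prevent_load: {}` key added to each task config defaults to an empty mapping, so every run carries the key that `BaseTrainer.load()` now reads (see `base_trainer.py` below); an empty dict means nothing is skipped. A minimal sketch of the intended semantics, assuming truthy values mark components to skip (key names taken from the `load()` docstring in this patch):

```python
# Sketch only: mirrors how BaseTrainer.load() interprets `prevent_load`.
# Truthy values skip the matching load_* step; missing keys load normally.
COMPONENTS = ["seed", "logger", "datasets", "task", "model", "loss", "optimizer", "extras"]

def components_to_load(prevent_load=None):
    prevent_load = prevent_load or {}
    return [c for c in COMPONENTS if not prevent_load.get(c)]

assert components_to_load({}) == COMPONENTS                  # empty mapping: load everything
assert "logger" not in components_to_load({"logger": True})  # truthy value: skipped
```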
+ + " To implement this for other tasks, modify how `base_path` is constructed" + " in `reset_data_paths()`" + ) + base_path = Path("/network/projects/ocp/oc20/is2re") + for name, ds_config in ds_configs.items(): + if not isinstance(ds_config, dict): + continue + if "slurm" in ds_config["src"].lower(): + ds_config["src"] = str( + base_path / ds_config["split"] / Path(ds_config["src"]).name + ) + config["dataset"][name] = ds_config + + return config + + def find_ckpt(ckpt_paths: dict, release: str) -> Path: """ Finds a checkpoint in a dictionary of paths, based on the current cluster name and @@ -223,7 +255,7 @@ def find_ckpt(ckpt_paths: dict, release: str) -> Path: if path.is_file(): return path path = path / release - ckpts = list(path.glob("**/*.ckpt")) + ckpts = list(path.glob("**/*.pt")) if len(ckpts) == 0: raise ValueError(f"No FAENet proxy checkpoint found at {str(path)}.") if len(ckpts) > 1: @@ -256,18 +288,22 @@ def prepare_for_gfn(ckpt_paths: dict, release: str) -> tuple: Returns: tuple: (model, loaders) where loaders is a dict of loaders for the model. """ + setup_imports() ckpt_path = find_ckpt(ckpt_paths, release) assert ckpt_path.exists(), f"Path {ckpt_path} does not exist." - trainer = make_trainer_from_dir( - ckpt_path, - mode="continue", - overrides={ - "is_debug": True, - "silent": True, - "cp_data_to_tmpdir": False, - }, - silent=True, - ) + config = torch.load(ckpt_path, map_location="cpu")["config"] + config["is_debug"] = True + config["silent"] = True + config["cp_data_to_tmpdir"] = False + config["prevent_load"] = { + "logger": True, + "loss": True, + "datasets": True, + "optimizer": True, + "extras": True, + } + config = reset_data_paths(config) + trainer = registry.get_trainer_class(config["trainer"])(**config) wrapper = FAENetWrapper( faenet=trainer.model, diff --git a/ocpmodels/common/utils.py b/ocpmodels/common/utils.py index b7a39d391b..bf27099853 100644 --- a/ocpmodels/common/utils.py +++ b/ocpmodels/common/utils.py @@ -755,7 +755,16 @@ def add_edge_distance_to_graph( # Copied from https://github.com/facebookresearch/mmf/blob/master/mmf/utils/env.py#L89. -def setup_imports(skip_imports=[]): +def setup_imports(skip_modules=[]): + """Automatically load all of the modules, so that they register within the registry. + + Parameters + ---------- + skip_modules : list, optional + List of modules (as ``str``) to skip while importing, by default []. Use module + names not paths, for instance, to skip ``ocpmodels.models.gemnet_oc.gemnet_oc``, + use ``skip_modules=["gemnet_oc"]``. 
+ """ from ocpmodels.common.registry import registry try: @@ -803,7 +812,7 @@ def setup_imports(skip_imports=[]): splits = f.split(os.sep) file_name = splits[-1] module_name = file_name[: file_name.find(".py")] - if module_name not in skip_imports: + if module_name not in skip_modules: importlib.import_module("ocpmodels.%s.%s" % (key[1:], module_name)) # manual model imports @@ -1191,7 +1200,7 @@ def build_config(args, args_override=[], dict_overrides={}, silent=None): # load config from `model-task-split` pattern config = load_config(args.config) - # overwride with command-line args, including default values + # override with command-line args, including default values config = merge_dicts(config, args_dict_with_defaults) # override with build_config()'s overrides config = merge_dicts(config, overrides) @@ -1801,7 +1810,7 @@ def make_script_trainer(str_args=[], overrides={}, silent=False, mode="train"): return trainer -def make_config_from_dir(path, mode, overrides={}, silent=None, skip_imports=[]): +def make_config_from_dir(path, mode, overrides={}, silent=None, skip_modules=[]): """ Make a config from a directory. This is useful when restarting or continuing from a previous run. @@ -1838,11 +1847,11 @@ def make_config_from_dir(path, mode, overrides={}, silent=None, skip_imports=[]) config = build_config(default_args, silent=silent) config = merge_dicts(config, overrides) - setup_imports(skip_imports=skip_imports) + setup_imports(skip_modules=skip_modules) return config -def make_trainer_from_dir(path, mode, overrides={}, silent=None, skip_imports=[]): +def make_trainer_from_dir(path, mode, overrides={}, silent=None, skip_modules=[]): """ Make a trainer from a directory. @@ -1858,7 +1867,7 @@ def make_trainer_from_dir(path, mode, overrides={}, silent=None, skip_imports=[] Returns: Trainer: The loaded trainer. """ - config = make_config_from_dir(path, mode, overrides, silent, skip_imports) + config = make_config_from_dir(path, mode, overrides, silent, skip_modules) return registry.get_trainer_class(config["trainer"])(**config) diff --git a/ocpmodels/datasets/qm7x.py b/ocpmodels/datasets/qm7x.py index 7ecb7e0bf3..1345333f5e 100644 --- a/ocpmodels/datasets/qm7x.py +++ b/ocpmodels/datasets/qm7x.py @@ -20,10 +20,20 @@ from torch_geometric.data import Data from tqdm import tqdm -from cosmosis.dataset import CDataset from ocpmodels.common.registry import registry from ocpmodels.common.utils import ROOT +CDataset = object +try: + from cosmosis.dataset import CDataset +except ImportError: + print( + "Warning: `cosmosis` is not installed. `QM7X` will not be available.", + "See https://github.com/icanswim/cosmosis", + ) + print(f"(message from {Path(__file__).resolve()})") + + try: import orjson as json # noqa: F401 except: # noqa: E722 @@ -33,6 +43,7 @@ "`orjson` is not installed. 
", "Consider `pip install orjson` to speed up json loading.", ) + print(f"(message from {Path(__file__).resolve()})") class Molecule: diff --git a/ocpmodels/models/comenet.py b/ocpmodels/models/comenet.py index ec8ada6e54..3aa8562a6e 100644 --- a/ocpmodels/models/comenet.py +++ b/ocpmodels/models/comenet.py @@ -1,9 +1,19 @@ -from dig.threedgraph.method import ComENet as DIGComENet -from ocpmodels.models.base_model import BaseModel +from copy import deepcopy + import torch + from ocpmodels.common.registry import registry from ocpmodels.common.utils import conditional_grad -from copy import deepcopy +from ocpmodels.models.base_model import BaseModel + +DIGComENet = None +try: + from dig.threedgraph.method import ComENet as DIGComENet +except ImportError: + from pathlib import Path + + print("Warning: `dig` is not installed. `SphereNet` will not be available.") + print(f"(message from {Path(__file__).resolve()})\n") @registry.register_model("comenet") diff --git a/ocpmodels/models/spherenet.py b/ocpmodels/models/spherenet.py index df0024fe8e..d0627a4661 100644 --- a/ocpmodels/models/spherenet.py +++ b/ocpmodels/models/spherenet.py @@ -1,9 +1,19 @@ -from dig.threedgraph.method import SphereNet as DIGSphereNet -from ocpmodels.models.base_model import BaseModel +from copy import deepcopy + import torch + from ocpmodels.common.registry import registry from ocpmodels.common.utils import conditional_grad -from copy import deepcopy +from ocpmodels.models.base_model import BaseModel + +DIGSphereNet = None +try: + from dig.threedgraph.method import SphereNet as DIGSphereNet +except ImportError: + from pathlib import Path + + print("Warning: `dig` is not installed. `SphereNet` will not be available.") + print(f"(message from {Path(__file__).resolve()})\n") @registry.register_model("spherenet") diff --git a/ocpmodels/modules/scheduler.py b/ocpmodels/modules/scheduler.py index af4107ebb9..4e9d0fd634 100644 --- a/ocpmodels/modules/scheduler.py +++ b/ocpmodels/modules/scheduler.py @@ -1,10 +1,11 @@ """scheduler.py """ + import inspect + import torch.optim.lr_scheduler as lr_scheduler from ocpmodels.common.utils import warmup_lr_lambda -import pytorch_warmup as warmup class LRScheduler: @@ -54,6 +55,8 @@ def scheduler_lambda_fn(x): if not self.silent: print(f"Using fidelity_max_steps for scheduler -> {T_max}") if self.optim_config["warmup_steps"] > 0: + import pytorch_warmup as warmup + self.warmup_scheduler = warmup.ExponentialWarmup( self.optimizer, warmup_period=self.optim_config["warmup_steps"] ) diff --git a/ocpmodels/trainers/base_trainer.py b/ocpmodels/trainers/base_trainer.py index e871027efe..40b446b2d8 100644 --- a/ocpmodels/trainers/base_trainer.py +++ b/ocpmodels/trainers/base_trainer.py @@ -4,6 +4,7 @@ This source code is licensed under the MIT license found in the LICENSE file in the root directory of this source tree. 
""" + import datetime import errno import logging @@ -56,7 +57,7 @@ @registry.register_trainer("base") class BaseTrainer(ABC): - def __init__(self, load=True, **kwargs): + def __init__(self, **kwargs): run_dir = kwargs["run_dir"] model_name = kwargs["model"].pop( @@ -76,9 +77,14 @@ def __init__(self, load=True, **kwargs): } self.sigterm = False - self.objective = None self.epoch = 0 self.step = 0 + self.objective = None + self.logger = None + self.parallel_collater = None + self.ema_decay = None + self.clip_grad_norm = None + self.scheduler = None self.cpu = self.config["cpu"] self.task_name = self.config["task"].get("name", self.config.get("name")) assert self.task_name, "Specify task name (got {})".format(self.task_name) @@ -90,6 +96,7 @@ def __init__(self, load=True, **kwargs): self.datasets = {} self.samplers = {} self.loaders = {} + self.normalizers = {} self.early_stopper = EarlyStopper( patience=self.config["optim"].get("es_patience") or 15, min_abs_change=self.config["optim"].get("es_min_abs_change") or 1e-5, @@ -189,21 +196,49 @@ def __init__(self, load=True, **kwargs): ) self.config["is_disconnected"] = True - self.load() + self.load(self.config.get("prevent_load")) self.evaluator = Evaluator( task=self.task_name, model_regresses_forces=self.config["model"].get("regress_forces", ""), ) - def load(self): - self.load_seed_from_config() - self.load_logger() - self.load_datasets() - self.load_task() - self.load_model() - self.load_loss() - self.load_optimizer() - self.load_extras() + def load(self, prevent_load={}): + """Load all components of the trainer. + + Arbitrary components can be prevented from loading by specifying them in the + ``prevent_load`` dictionary. Allowed keys are: + + - "seed" + - "logger" + - "datasets" + - "task" + - "model" + - "loss" + - "optimizer" + - "extras" + + Parameters + ---------- + prevent_load : dict, optional + Dictionary describing loading events that should be prevented, by default ``{}`` + """ + prevent_load = prevent_load or {} + if not prevent_load.get("seed"): + self.load_seed_from_config() + if not prevent_load.get("logger"): + self.load_logger() + if not prevent_load.get("datasets"): + self.load_datasets() + if not prevent_load.get("task"): + self.load_task() + if not prevent_load.get("model"): + self.load_model() + if not prevent_load.get("loss"): + self.load_loss() + if not prevent_load.get("optimizer"): + self.load_optimizer() + if not prevent_load.get("extras"): + self.load_extras() def load_seed_from_config(self): # https://pytorch.org/docs/stable/notes/randomness.html @@ -220,7 +255,6 @@ def load_seed_from_config(self): torch.backends.cudnn.benchmark = False def load_logger(self): - self.logger = None if not self.is_debug and dist_utils.is_master() and not self.is_hpo: assert self.config["logger"] is not None, "Specify logger in config" @@ -380,7 +414,6 @@ def load_datasets(self): # Normalizer for the dataset. # Compute mean, std of training set labels. 
diff --git a/pyproject.toml b/pyproject.toml
index fed6c97593..738fcd9d40 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,3 +1,57 @@
+[build-system]
+# Packages needed to build this package:
+requires = ["setuptools"] # REQUIRED if the [build-system] table is used
+# The Python object that build frontends use to perform the build:
+build-backend = "setuptools.build_meta" # if omitted, legacy setup.py behavior applies
+
+[project]
+
+name = "ocpmodels" # REQUIRED; the only field that cannot be marked as dynamic
+version = "0.1.0" # REQUIRED, although it can be dynamic
+description = "RolnickLab's OCP fork"
+readme = "README.md"
+requires-python = ">=3.8"
+license = { file = "LICENSE.md" }
+
+
+dependencies = [
+    "ase>=3.19.3",
+    "black>=23.1.0",
+    "CatKit @ git+https://github.com/vict0rsch/CatKit.git@df7f1aa7a47eb7b8022452fa77e4ce60cd006a7d",
+    "dive-into-graphs @ git+https://github.com/divelab/DIG.git@55c40a7b0938d3804d9f265193cb45a2fe80da8c",
+    "e3nn==0.5.1",
+    "flake8>=6.0.0",
+    "h5py>=3.8.0",
+    "lmdb>=1.4.0",
+    "matplotlib>=3.7.0",
+    "mendeleev>=0.12",
+    "minydra==0.1.6",
+    "orjson>=3.8",
+    "pytorch-warmup>=0.1",
+    "pymatgen>=2023.2",
+    "PyYAML>=6.0",
+    "rdkit>=2022.9.5",
+    "rich",
+    "ruamel.yaml",
+    "scikit-learn",
+    "scikit-optimize",
+    "tensorboard",
+    "torch>=1.12",
+    "torch_geometric==2.3.0",
+    "tqdm>=4.66",
+    "wandb",
+]
+
+[project.optional-dependencies]
+geom = [
+    "pyg_lib",
+    "torch_scatter",
+    "torch_sparse",
+    "torch_cluster",
+    "torch_spline_conv",
+]
+dev = ["ipdb", "ipython", "pytest"]
+
 [tool.black]
 line-length = 88
 include = '\.pyi?$'
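Note: with the packaging metadata in place, the fork is installable with pip (e.g. `pip install -e .`, or `pip install -e ".[geom,dev]"` for the extras). The `geom` extra holds the PyG companion packages, presumably kept optional because their wheels must match the local torch/CUDA build. A small sketch for detecting whether that extra is present:

```python
import importlib.util

# The five packages listed under [project.optional-dependencies] geom
GEOM = ["pyg_lib", "torch_scatter", "torch_sparse", "torch_cluster", "torch_spline_conv"]
missing = [pkg for pkg in GEOM if importlib.util.find_spec(pkg) is None]
if missing:
    print(f"Missing optional geom packages: {missing}; try `pip install -e '.[geom]'`")
```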