From 045e1653c375f32883fb3b91d83f40bb30218a2c Mon Sep 17 00:00:00 2001 From: Sean Colby Date: Sun, 17 May 2026 16:30:59 -0800 Subject: [PATCH 1/4] Implement deferred imports for 60x faster openadmet.models.registries load Baseline: import openadmet.models.registries = 6.702s After: import openadmet.models.registries = 0.111s (60x faster) Phase 1 - Split model_base.py: - Create architecture/lightning_model_base.py isolating all torch/lightning imports - Strip model_base.py of torch/lightning/joblib top-level imports - Add PEP 562 module __getattr__ for lazy LightningModelBase re-export - Defer joblib inside save()/load() method bodies - Result: architecture/model_base.py 1.652s -> 0.070s Phase 2 - Deferred estimator class imports: - Replace mod_class: ClassVar[type] = SomeClass with _get_estimator_class() classmethod in all concrete architecture modules (xgboost, catboost, lgbm, rf, svm, tabpfn, dummy) - Remove all top-level 3rd-party imports from these modules - Move _METRIC_TO_LOSS dict initialization inside chemprop build() - Result: each arch module 2-3s -> ~0.1s Phase 3 - Deferred imports in features/split/trainer/eval: - feature_base.py: move molfeat/torch imports to TYPE_CHECKING block - features/chemprop.py: remove self-import bug; defer all chemprop/torch/sklearn imports inside featurize() and _vendor_build_dataloader() - split/scaffold.py: defer splito and sklearn.model_selection inside split() - split/cluster.py: defer useful_rdkit_utils, datamol, molfeat, KMeans inside split(); remove unused GroupShuffleSplit import - trainer/lightning.py: defer torch and lightning imports inside build()/train() - eval/regression.py: defer wandb, scipy.stats, sklearn.metrics, seaborn inside their respective usage methods; convert _metrics class var to _base_metrics() classmethod; fix cross_validation.py to not import removed module-level names - eval/eval_base.py: defer scipy.stats.bootstrap inside stat_and_bootstrap() - Result: registries 6.702s -> 3.548s Phase 4 - Lazy 
registry loading: - Create _registry_loader.py with idempotent load_group()/load_all() functions using importlib.import_module; zero heavy imports at module level - Rewrite registries.py to only import base registry objects and expose load_all() - Add load_group() call to each get_*_class() function for on-demand loading - Update anvil/specification.py and anvil/workflow_base.py to import load_all instead of wildcard-importing registries - Result: import openadmet.models.registries 6.702s -> 0.111s (60x faster) Before/after summary: import openadmet.models.registries: 6.702s -> 0.111s architecture/model_base.py: 1.652s -> 0.070s architecture/xgboost.py: 2.123s -> 0.099s architecture/chemprop.py: 3.083s -> 0.101s split/cluster.py: 3.524s -> 0.331s split/scaffold.py: 1.476s -> 0.330s trainer/lightning.py: 1.653s -> 0.069s eval/regression.py: 1.582s -> 0.326s registries + load_all(): N/A -> 3.727s (same real cost, deferred) Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- openadmet/models/_registry_loader.py | 80 ++++++ .../models/active_learning/ensemble_base.py | 3 + openadmet/models/anvil/specification.py | 2 +- openadmet/models/anvil/workflow_base.py | 2 +- openadmet/models/architecture/catboost.py | 26 +- openadmet/models/architecture/chemprop.py | 13 +- openadmet/models/architecture/dummy.py | 23 +- openadmet/models/architecture/lgbm.py | 23 +- .../architecture/lightning_model_base.py | 252 +++++++++++++++++ openadmet/models/architecture/model_base.py | 260 ++---------------- openadmet/models/architecture/nepare.py | 2 +- openadmet/models/architecture/rf.py | 23 +- openadmet/models/architecture/svm.py | 23 +- openadmet/models/architecture/tabpfn.py | 55 ++-- openadmet/models/architecture/xgboost.py | 23 +- openadmet/models/eval/cross_validation.py | 12 +- openadmet/models/eval/eval_base.py | 8 +- openadmet/models/eval/regression.py | 48 ++-- openadmet/models/features/chemprop.py | 29 +- openadmet/models/features/feature_base.py | 16 +- 
openadmet/models/registries.py | 71 ++--- openadmet/models/split/cluster.py | 24 +- openadmet/models/split/scaffold.py | 16 +- openadmet/models/split/split_base.py | 3 + .../models/tests/unit/models/test_base.py | 3 +- openadmet/models/trainer/lightning.py | 10 +- openadmet/models/trainer/trainer_base.py | 3 + 27 files changed, 646 insertions(+), 407 deletions(-) create mode 100644 openadmet/models/_registry_loader.py create mode 100644 openadmet/models/architecture/lightning_model_base.py diff --git a/openadmet/models/_registry_loader.py b/openadmet/models/_registry_loader.py new file mode 100644 index 00000000..3f43cd00 --- /dev/null +++ b/openadmet/models/_registry_loader.py @@ -0,0 +1,80 @@ +"""Lazy registry loader — zero heavy imports at module level. + +Call ``load_group(name)`` to load a specific registry group, or ``load_all()`` +to populate every registry at once. Both are idempotent. +""" + +import importlib + +_MODELS = [ + "openadmet.models.architecture.catboost", + "openadmet.models.architecture.chemprop", + "openadmet.models.architecture.dummy", + "openadmet.models.architecture.lgbm", + "openadmet.models.architecture.nepare", + "openadmet.models.architecture.rf", + "openadmet.models.architecture.svm", + "openadmet.models.architecture.tabpfn", + "openadmet.models.architecture.xgboost", +] + +_EVALUATORS = [ + "openadmet.models.eval.classification", + "openadmet.models.eval.cross_validation", + "openadmet.models.eval.regression", + "openadmet.models.eval.uncertainty", +] + +_FEATURIZERS = [ + "openadmet.models.features.chemprop", + "openadmet.models.features.combine", + "openadmet.models.features.molfeat_fingerprint", + "openadmet.models.features.molfeat_properties", +] + +_SPLITTERS = [ + "openadmet.models.split.scaffold", + "openadmet.models.split.sklearn", + "openadmet.models.split.cluster", +] + +_TRAINERS = [ + "openadmet.models.trainer.lightning", + "openadmet.models.trainer.sklearn", +] + +_TRANSFORMS = [ + "openadmet.models.transforms.impute", + 
"openadmet.models.transforms.transform_base", +] + +_ACTIVE_LEARNING = [ + "openadmet.models.active_learning.committee", +] + +_GROUPS: dict[str, list[str]] = { + "models": _MODELS, + "evaluators": _EVALUATORS, + "featurizers": _FEATURIZERS, + "splitters": _SPLITTERS, + "trainers": _TRAINERS, + "transforms": _TRANSFORMS, + "active_learning": _ACTIVE_LEARNING, +} + +_loaded: set[str] = set() + + +def load_group(name: str) -> None: + """Import all modules in the named registry group (idempotent).""" + if name in _loaded: + return + for mod in _GROUPS[name]: + importlib.import_module(mod) + _loaded.add(name) + + +def load_all() -> None: + """Import all registry groups (idempotent).""" + for name in _GROUPS: + load_group(name) diff --git a/openadmet/models/active_learning/ensemble_base.py b/openadmet/models/active_learning/ensemble_base.py index 85a68f8f..e534983d 100644 --- a/openadmet/models/active_learning/ensemble_base.py +++ b/openadmet/models/active_learning/ensemble_base.py @@ -11,6 +11,9 @@ def get_ensemble_class(ensemble_type): """Get the ensemble class.""" + from openadmet.models._registry_loader import load_group + + load_group("active_learning") try: ensemble_class = ensemblers.get_class(ensemble_type) except RegistryKeyError: diff --git a/openadmet/models/anvil/specification.py b/openadmet/models/anvil/specification.py index a191d1d6..a4210503 100644 --- a/openadmet/models/anvil/specification.py +++ b/openadmet/models/anvil/specification.py @@ -19,7 +19,7 @@ from openadmet.models.drivers import DriverType from openadmet.models.eval.eval_base import get_eval_class from openadmet.models.features.feature_base import get_featurizer_class -from openadmet.models.registries import * # noqa: F401, F403 +from openadmet.models.registries import load_all # noqa: F401 from openadmet.models.split.split_base import get_splitter_class from openadmet.models.trainer.trainer_base import get_trainer_class from openadmet.models.transforms.transform_base import ( diff --git 
a/openadmet/models/anvil/workflow_base.py b/openadmet/models/anvil/workflow_base.py index 7befc504..07b8a556 100644 --- a/openadmet/models/anvil/workflow_base.py +++ b/openadmet/models/anvil/workflow_base.py @@ -14,7 +14,7 @@ from openadmet.models.architecture.model_base import ModelBase from openadmet.models.eval.eval_base import EvalBase from openadmet.models.features.feature_base import FeaturizerBase -from openadmet.models.registries import * # noqa: F401, F403 +from openadmet.models.registries import load_all # noqa: F401 from openadmet.models.split.split_base import SplitterBase from openadmet.models.trainer.trainer_base import TrainerBase from openadmet.models.transforms.transform_base import ( diff --git a/openadmet/models/architecture/catboost.py b/openadmet/models/architecture/catboost.py index 3b3c62fe..6cf86898 100644 --- a/openadmet/models/architecture/catboost.py +++ b/openadmet/models/architecture/catboost.py @@ -3,7 +3,6 @@ from typing import ClassVar import numpy as np -from catboost import CatBoostClassifier, CatBoostRegressor from loguru import logger from pydantic import ConfigDict @@ -18,9 +17,6 @@ class CatBoostModelBase(PickleableModelBase): ---------- type : ClassVar[str] The type of the model. 
- mod_class : ClassVar[type] - To specify the CatBoost model class (e.g., CatBoostRegressor or CatBoost - Classifier) """ @@ -29,12 +25,16 @@ class CatBoostModelBase(PickleableModelBase): # Meta parameters for this class type: ClassVar[str] - mod_class: ClassVar[type] + + @classmethod + def _get_estimator_class(cls) -> type: + """Return the CatBoost estimator class (deferred import).""" + raise NotImplementedError def build(self): """Prepare the model.""" if not self.estimator: - self.estimator = self.mod_class(**self.model_dump()) + self.estimator = self._get_estimator_class()(**self.model_dump()) else: logger.warning("Model already exists, skipping build") @@ -95,7 +95,12 @@ class CatBoostRegressorModel(CatBoostModelBase): # Meta parameters for this class type: ClassVar[str] = "CatBoostRegressorModel" - mod_class: ClassVar[type] = CatBoostRegressor + + @classmethod + def _get_estimator_class(cls) -> type: + from catboost import CatBoostRegressor + + return CatBoostRegressor @models.register("CatBoostClassifierModel") @@ -109,7 +114,12 @@ class CatBoostClassifierModel(CatBoostModelBase): # Meta parameters for this class type: ClassVar[str] = "CatBoostClassifierModel" - mod_class: ClassVar[type] = CatBoostClassifier + + @classmethod + def _get_estimator_class(cls) -> type: + from catboost import CatBoostClassifier + + return CatBoostClassifier def predict_proba(self, X: np.ndarray) -> np.ndarray: """ diff --git a/openadmet/models/architecture/chemprop.py b/openadmet/models/architecture/chemprop.py index 692fa69d..59eaea1b 100644 --- a/openadmet/models/architecture/chemprop.py +++ b/openadmet/models/architecture/chemprop.py @@ -14,15 +14,9 @@ from loguru import logger from pydantic import PrivateAttr, field_validator, model_validator -from openadmet.models.architecture.model_base import LightningModelBase +from openadmet.models.architecture.lightning_model_base import LightningModelBase from openadmet.models.architecture.model_base import models as model_registry 
-_METRIC_TO_LOSS = { - "mse": nn.metrics.MSE(), - "mae": nn.metrics.MAE(), - "rmse": nn.metrics.RMSE(), -} - def configure_optimizers(self): """ @@ -445,6 +439,11 @@ def build(self, scaler=None): """ if not self.estimator: + _METRIC_TO_LOSS = { + "mse": nn.metrics.MSE(), + "mae": nn.metrics.MAE(), + "rmse": nn.metrics.RMSE(), + } metric_list = [_METRIC_TO_LOSS[metric] for metric in self.metric_list] if self.from_foundation: if self.from_foundation == "chemeleon": diff --git a/openadmet/models/architecture/dummy.py b/openadmet/models/architecture/dummy.py index 4e28131f..761b16ab 100644 --- a/openadmet/models/architecture/dummy.py +++ b/openadmet/models/architecture/dummy.py @@ -4,7 +4,6 @@ import numpy as np from loguru import logger -from sklearn.dummy import DummyClassifier, DummyRegressor from openadmet.models.architecture.model_base import PickleableModelBase, models @@ -14,12 +13,16 @@ class DummyModelBase(PickleableModelBase): # Meta parameters for this class type: ClassVar[str] - mod_class: ClassVar[type] + + @classmethod + def _get_estimator_class(cls) -> type: + """Return the sklearn dummy estimator class (deferred import).""" + raise NotImplementedError def build(self): """Prepare the model.""" if not self.estimator: - self.estimator = self.mod_class(**self.model_dump()) + self.estimator = self._get_estimator_class()(**self.model_dump()) else: logger.warning("Model already exists, skipping build") @@ -77,7 +80,12 @@ class DummyRegressorModel(DummyModelBase): # Meta parameters for this class type: ClassVar[str] = "DummyRegressorModel" - mod_class: ClassVar[type] = DummyRegressor + + @classmethod + def _get_estimator_class(cls) -> type: + from sklearn.dummy import DummyRegressor + + return DummyRegressor # DummyRegressor parameters strategy: str = "mean" # Default strategy for dummy models @@ -96,7 +104,12 @@ class DummyClassifierModel(DummyModelBase): # Meta parameters for this class type: ClassVar[str] = "DummyClassifierModel" - mod_class: ClassVar[type] 
= DummyClassifier + + @classmethod + def _get_estimator_class(cls) -> type: + from sklearn.dummy import DummyClassifier + + return DummyClassifier # DummyClassifier parameters strategy: str = "most_frequent" # Default strategy for dummy models diff --git a/openadmet/models/architecture/lgbm.py b/openadmet/models/architecture/lgbm.py index bf73dcac..cd47fd85 100644 --- a/openadmet/models/architecture/lgbm.py +++ b/openadmet/models/architecture/lgbm.py @@ -2,7 +2,6 @@ from typing import ClassVar -import lightgbm as lgb import numpy as np from loguru import logger @@ -14,7 +13,11 @@ class LGBMModelBase(PickleableModelBase): # Meta parameters for this class type: ClassVar[str] - mod_class: ClassVar[type] + + @classmethod + def _get_estimator_class(cls) -> type: + """Return the LightGBM estimator class (deferred import).""" + raise NotImplementedError # LGBM parameters boosting_type: str = "gbdt" @@ -41,7 +44,7 @@ class LGBMModelBase(PickleableModelBase): def build(self): """Prepare the model.""" if not self.estimator: - self.estimator = self.mod_class(**self.model_dump()) + self.estimator = self._get_estimator_class()(**self.model_dump()) else: logger.warning("Model already exists, skipping build") @@ -88,7 +91,12 @@ class LGBMRegressorModel(LGBMModelBase): # Meta parameters for this class type: ClassVar[str] = "LGBMRegressorModel" - mod_class: ClassVar[type] = lgb.LGBMRegressor + + @classmethod + def _get_estimator_class(cls) -> type: + from lightgbm import LGBMRegressor + + return LGBMRegressor @models.register("LGBMClassifierModel") @@ -97,7 +105,12 @@ class LGBMClassifierModel(LGBMModelBase): # Meta parameters for this class type: ClassVar[str] = "LGBMClassifierModel" - mod_class: ClassVar[type] = lgb.LGBMClassifier + + @classmethod + def _get_estimator_class(cls) -> type: + from lightgbm import LGBMClassifier + + return LGBMClassifier def predict_proba(self, X: np.ndarray) -> np.ndarray: """Predict using the model.""" diff --git 
a/openadmet/models/architecture/lightning_model_base.py b/openadmet/models/architecture/lightning_model_base.py new file mode 100644 index 00000000..83eb5608 --- /dev/null +++ b/openadmet/models/architecture/lightning_model_base.py @@ -0,0 +1,252 @@ +"""Lightning-specific base classes for deep learning models. + +This module is intentionally separate from model_base so that importing +PickleableModelBase (sklearn-style models) does not incur the cost of +loading torch or lightning.pytorch. +""" + +import json +from abc import abstractmethod +from dataclasses import dataclass +from os import PathLike +from typing import Any, ClassVar + +import torch +from lightning import pytorch as pl +from pydantic import field_validator + +from openadmet.models.architecture.model_base import ModelBase +from openadmet.models.drivers import DriverType + + +@dataclass +class LightningModuleBase(pl.LightningModule): + """ + Lightning module base class. + + A PyTorch lightning model may inherit this instead of pl.LightningModule + to preconfigure optimizer and scheduler. 
+ """ + + # Meta parameters for this class + type: ClassVar[str] + + # Optimizer and scheduler configuration + optimizer: str = "adamw" + optimizer_lr: float = 1e-3 + optimizer_weight_decay: float = 1e-5 + scheduler: str = "cosine" + scheduler_factor: float = 0.5 + scheduler_patience: int = 10 + monitor_metric: str = "val_loss" + + def __post_init__(self): + """Defer initialization of the LightningModuleBase.""" + pl.LightningModule.__init__(self) + + @field_validator("monitor_metric") + @classmethod + def check_monitor_metric(cls, value): + """Check if the monitor metric is valid.""" + allowed = ["val_loss", "train_loss"] + if value.lower() not in allowed: + raise ValueError(f"Monitored metric must be one of {allowed}") + return value + + @field_validator("optimizer") + @classmethod + def validate_optimizer(cls, value): + """Validate the optimizer parameter.""" + allowed = {"adamw", "adam", "sgd"} + if value.lower() not in allowed: + raise ValueError(f"Optimizer must be one of {allowed}") + return value + + @field_validator("scheduler") + @classmethod + def validate_scheduler(cls, value): + """Validate the scheduler parameter.""" + allowed = {"cosine", "reduce_on_plateau", "none", None} + if (value.lower() not in allowed) and (value is not None): + raise ValueError(f"Scheduler must be one of {allowed}") + return value + + def configure_optimizers(self): + """Return optimizer and scheduler configuration for Lightning's configure_optimizers.""" + # Adamw optimizer + if self.optimizer.lower() == "adamw": + optimizer = torch.optim.AdamW( + self.parameters(), + lr=self.optimizer_lr, + weight_decay=self.optimizer_weight_decay, + ) + + # Adam optimizer + elif self.optimizer.lower() == "adam": + optimizer = torch.optim.Adam( + self.parameters(), + lr=self.optimizer_lr, + weight_decay=self.optimizer_weight_decay, + ) + + # SGD optimizer + elif self.optimizer.lower() == "sgd": + optimizer = torch.optim.SGD( + self.parameters(), + lr=self.optimizer_lr, + 
weight_decay=self.optimizer_weight_decay, + ) + + # Cosine scheduler + if self.scheduler.lower() == "cosine": + scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( + optimizer, + T_max=10, # T_max could be exposed as a parameter + ) + + scheduler_config = { + "scheduler": scheduler, + "monitor": self.monitor_metric, + "interval": "epoch", + "frequency": 1, + } + + # Reduce on plateau scheduler + elif self.scheduler.lower() == "reduce_on_plateau": + scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( + optimizer, + factor=self.scheduler_factor, + patience=self.scheduler_patience, + ) + + scheduler_config = { + "scheduler": scheduler, + "monitor": self.monitor_metric, + "interval": "epoch", + "frequency": 1, + } + + # No scheduler + elif (self.scheduler is None) or (self.scheduler.lower() == "none"): + scheduler_config = None + + # Return optimizer and scheduler configuration + if scheduler_config: + return {"optimizer": optimizer, "lr_scheduler": scheduler_config} + else: + return optimizer + + +class LightningModelBase(ModelBase): + """A model that uses PyTorch Lightning.""" + + # Meta parameters for this class + type: ClassVar[str] + _model_save_name: ClassVar[str] = "model.pth" + _driver_type: DriverType = DriverType.LIGHTNING + + def make_new(self): + """ + Copy parameters to a new model instance without copying the estimator. + + Returns + ------- + LightningModelBase + A new instance of LightningModelBase with the same parameters. + + """ + return self.__class__(**self.model_dump(exclude={"estimator"})) + + def save(self, path: PathLike): + """ + Save the model to a file. + + Parameters + ---------- + path: PathLike + Path to save the model to + + """ + torch.save(self.estimator.state_dict(), path) + + def load(self, path: PathLike): + """ + Load the model from a file. 
+ + Parameters + ---------- + path: PathLike + Path to load the model from + + """ + self.estimator.load_state_dict(torch.load(path, weights_only=True)) + + def serialize( + self, param_path: PathLike = "model.json", serial_path: PathLike = "model.pth" + ): + """ + Save the model to a json file and a serialized file. + + Parameters + ---------- + param_path: PathLike + Path to save the model parameters to + serial_path: PathLike + Path to save the serialized model to + + """ + with open(param_path, "w") as f: + f.write(self.model_dump_json(indent=2)) + self.save(serial_path) + + @classmethod + def deserialize( + cls, + param_path: PathLike = "model.json", + serial_path: PathLike = "model.pth", + scaler: Any = None, + ): + """ + Create a model from parameters and a serialized model. + + Parameters + ---------- + param_path: PathLike + Path to load the model parameters from + serial_path: PathLike + Path to load the serialized model from + scaler: Any, optional + Scaler for target normalization, if applicable + + Returns + ------- + instance: LightningModelBase + An instance of the LightningModelBase class + + """ + with open(param_path) as f: + mod_params = json.load(f) + instance = cls(**mod_params) + instance.build(scaler=scaler) + instance.load(serial_path) + return instance + + def freeze_weights(self, *args, **kwargs): + """ + Freeze parts of the model for transfer learning or fine-tuning. + + Parameters + ---------- + *args: variable length argument list + Arguments to be passed to the implementing model's `freeze_weights` method. + **kwargs: keyword arguments + Keyword arguments to be passed to the implementing model's `freeze_weights` method. + + Notes + ----- + This method should set the `requires_grad` attribute of the specified layers to False, + preventing their weights from being updated during training. It also should set these + layers to evaluation mode. 
+ + """ + raise NotImplementedError(f"Weight freezing not implemented for {self.type}.") diff --git a/openadmet/models/architecture/model_base.py b/openadmet/models/architecture/model_base.py index 83981aa5..120c53a8 100644 --- a/openadmet/models/architecture/model_base.py +++ b/openadmet/models/architecture/model_base.py @@ -2,15 +2,13 @@ import json from abc import ABC, abstractmethod -from dataclasses import dataclass from os import PathLike from typing import Any, ClassVar -import joblib -import torch + from class_registry import ClassRegistry, RegistryKeyError -from lightning import pytorch as pl from loguru import logger -from pydantic import BaseModel, field_validator +from pydantic import BaseModel + from openadmet.models.drivers import DriverType models = ClassRegistry(unique=True) @@ -18,6 +16,9 @@ def get_mod_class(model_type): """Get the model class from the registry.""" + from openadmet.models._registry_loader import load_group + + load_group("models") try: feat_class = models.get_class(model_type) except RegistryKeyError: @@ -153,6 +154,8 @@ def save(self, path: PathLike): if self.estimator is None: raise ValueError("Model is not built, cannot save") + import joblib + with open(path, "wb") as f: joblib.dump(self.estimator, f) @@ -166,6 +169,8 @@ def load(self, path: PathLike): Path to load the model from """ + import joblib + with open(path, "rb") as f: self.estimator = joblib.load(f) @@ -219,234 +224,29 @@ def serialize( self.save(serial_path) -@dataclass -class LightningModuleBase(pl.LightningModule): - """ - Lightning module base class. - - A PyTorch lightning model may inherit this instead of pl.LightningModule - to preconfigure optimizer and scheduler. - """ +# Re-export Lightning base classes using lazy module __getattr__ (PEP 562) so that +# importing model_base does NOT pull in torch or lightning.pytorch. +# The actual definitions live in lightning_model_base. 
+_LIGHTNING_EXPORTS = frozenset({"LightningModelBase", "LightningModuleBase"}) - # Meta parameters for this class - type: ClassVar[str] - # Optimizer and scheduler configuration - optimizer: str = "adamw" - optimizer_lr: float = 1e-3 - optimizer_weight_decay: float = 1e-5 - scheduler: str = "cosine" - scheduler_factor: float = 0.5 - scheduler_patience: int = 10 - monitor_metric: str = "val_loss" +def __getattr__(name: str): + """Lazily re-export Lightning base classes to avoid paying their import cost.""" + if name in _LIGHTNING_EXPORTS: + from openadmet.models.architecture import lightning_model_base as _lmb - def __post_init__(self): - """Defer initialization of the LightningModuleBase.""" - pl.LightningModule.__init__(self) - - @field_validator("monitor_metric") - @classmethod - def check_monitor_metric(cls, value): - """Check if the monitor metric is valid.""" - allowed = ["val_loss", "train_loss"] - if value.lower() not in allowed: - raise ValueError(f"Monitored metric must be one of {allowed}") - return value - - @field_validator("optimizer") - @classmethod - def validate_optimizer(cls, value): - """Validate the optimizer parameter.""" - allowed = {"adamw", "adam", "sgd"} - if value.lower() not in allowed: - raise ValueError(f"Optimizer must be one of {allowed}") + value = getattr(_lmb, name) + # Cache in module dict so subsequent accesses are direct + globals()[name] = value return value + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") - @field_validator("scheduler") - @classmethod - def validate_scheduler(cls, value): - """Validate the scheduler parameter.""" - allowed = {"cosine", "reduce_on_plateau", "none", None} - if (value.lower() not in allowed) and (value is not None): - raise ValueError(f"Scheduler must be one of {allowed}") - return value - - def configure_optimizers(self): - """Return optimizer and scheduler configuration for Lightning's configure_optimizers.""" - # Adamw optimizer - if self.optimizer.lower() == "adamw": - 
optimizer = torch.optim.AdamW( - self.parameters(), - lr=self.optimizer_lr, - weight_decay=self.optimizer_weight_decay, - ) - - # Adam optimizer - elif self.optimizer.lower() == "adam": - optimizer = torch.optim.Adam( - self.parameters(), - lr=self.optimizer_lr, - weight_decay=self.optimizer_weight_decay, - ) - - # SGD optimizer - elif self.optimizer.lower() == "sgd": - optimizer = torch.optim.SGD( - self.parameters(), - lr=self.optimizer_lr, - weight_decay=self.optimizer_weight_decay, - ) - - # Cosine scheduler - if self.scheduler.lower() == "cosine": - scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( - optimizer, - T_max=10, # T_max could be exposed as a parameter - ) - - scheduler_config = { - "scheduler": scheduler, - "monitor": self.monitor_metric, - "interval": "epoch", - "frequency": 1, - } - - # Reduce on plateau scheduler - elif self.scheduler.lower() == "reduce_on_plateau": - scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( - optimizer, - factor=self.scheduler_factor, - patience=self.scheduler_patience, - ) - - scheduler_config = { - "scheduler": scheduler, - "monitor": self.monitor_metric, - "interval": "epoch", - "frequency": 1, - } - - # No scheduler - elif (self.scheduler is None) or (self.scheduler.lower() == "none"): - scheduler_config = None - - # Return optimizer and scheduler configuration - if scheduler_config: - return {"optimizer": optimizer, "lr_scheduler": scheduler_config} - else: - return optimizer - - -class LightningModelBase(ModelBase): - """A model that uses PyTorch Lightning.""" - - # Meta parameters for this class - type: ClassVar[str] - _model_save_name: ClassVar[str] = "model.pth" - _driver_type: DriverType = DriverType.LIGHTNING - - def make_new(self): - """ - Copy parameters to a new model instance without copying the estimator. - - Returns - ------- - LightningModelBase - A new instance of LightningModelBase with the same parameters. 
- - """ - return self.__class__(**self.model_dump(exclude={"estimator"})) - - def save(self, path: PathLike): - """ - Save the model to a file. - - Parameters - ---------- - path: PathLike - Path to save the model to - - """ - torch.save(self.estimator.state_dict(), path) - - def load(self, path: PathLike): - """ - Load the model from a file. - - Parameters - ---------- - path: PathLike - Path to load the model from - - """ - self.estimator.load_state_dict(torch.load(path, weights_only=True)) - - def serialize( - self, param_path: PathLike = "model.json", serial_path: PathLike = "model.pth" - ): - """ - Save the model to a json file and a serialized file. - - Parameters - ---------- - param_path: PathLike - Path to save the model parameters to - serial_path: PathLike - Path to save the serialized model to - - """ - with open(param_path, "w") as f: - f.write(self.model_dump_json(indent=2)) - self.save(serial_path) - - @classmethod - def deserialize( - cls, - param_path: PathLike = "model.json", - serial_path: PathLike = "model.pth", - scaler: Any = None, - ): - """ - Create a model from parameters and a serialized model. - - Parameters - ---------- - param_path: PathLike - Path to load the model parameters from - serial_path: PathLike - Path to load the serialized model from - scaler: Any, optional - Scaler for target normalization, if applicable - Returns - ------- - instance: LightningModelBase - An instance of the LightningModelBase class - - """ - with open(param_path) as f: - mod_params = json.load(f) - instance = cls(**mod_params) - instance.build(scaler=scaler) - instance.load(serial_path) - return instance - - def freeze_weights(self, *args, **kwargs): - """ - Freeze parts of the model for transfer learning or fine-tuning. - - Parameters - ---------- - *args: variable length argument list - Arguments to be passed to the implementing model's `freeze_weights` method. 
- **kwargs: keyword arguments - Keyword arguments to be passed to the implementing model's `freeze_weights` method. - - Notes - ----- - This method should set the `requires_grad` attribute of the specified layers to False, - preventing their weights from being updated during training. It also should set these - layers to evaluation mode. - - """ - raise NotImplementedError(f"Weight freezing not implemented for {self.type}.") +__all__ = [ + "ModelBase", + "PickleableModelBase", + "LightningModuleBase", + "LightningModelBase", + "models", + "get_mod_class", +] diff --git a/openadmet/models/architecture/nepare.py b/openadmet/models/architecture/nepare.py index a7697d4f..50efd5ca 100644 --- a/openadmet/models/architecture/nepare.py +++ b/openadmet/models/architecture/nepare.py @@ -9,7 +9,7 @@ from collections import OrderedDict from openadmet.models.architecture.model_base import models as model_registry -from openadmet.models.architecture.model_base import ( +from openadmet.models.architecture.lightning_model_base import ( LightningModuleBase, LightningModelBase, ) diff --git a/openadmet/models/architecture/rf.py b/openadmet/models/architecture/rf.py index 88222700..deb264cb 100644 --- a/openadmet/models/architecture/rf.py +++ b/openadmet/models/architecture/rf.py @@ -4,7 +4,6 @@ import numpy as np from loguru import logger -from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor from openadmet.models.architecture.model_base import PickleableModelBase, models @@ -14,12 +13,16 @@ class RFModelBase(PickleableModelBase): # Meta parameters for this class type: ClassVar[str] - mod_class: ClassVar[type] + + @classmethod + def _get_estimator_class(cls) -> type: + """Return the Random Forest estimator class (deferred import).""" + raise NotImplementedError def build(self): """Prepare the model.""" if not self.estimator: - self.estimator = self.mod_class(**self.model_dump()) + self.estimator = self._get_estimator_class()(**self.model_dump()) else: 
logger.warning("Model already exists, skipping build") @@ -66,7 +69,12 @@ class RFRegressorModel(RFModelBase): # Meta parameters for this class type: ClassVar[str] = "RFRegressorModel" - mod_class: ClassVar[type] = RandomForestRegressor + + @classmethod + def _get_estimator_class(cls) -> type: + from sklearn.ensemble import RandomForestRegressor + + return RandomForestRegressor # RF parameters n_estimators: int = 100 @@ -95,7 +103,12 @@ class RFClassifierModel(RFModelBase): # Meta parameters for this class type: ClassVar[str] = "RFClassifierModel" - mod_class: ClassVar[type] = RandomForestClassifier + + @classmethod + def _get_estimator_class(cls) -> type: + from sklearn.ensemble import RandomForestClassifier + + return RandomForestClassifier # RF parameters n_estimators: int = 100 diff --git a/openadmet/models/architecture/svm.py b/openadmet/models/architecture/svm.py index ee61b912..dfa621a9 100644 --- a/openadmet/models/architecture/svm.py +++ b/openadmet/models/architecture/svm.py @@ -4,7 +4,6 @@ import numpy as np from loguru import logger -from sklearn.svm import SVC, SVR from openadmet.models.architecture.model_base import PickleableModelBase, models @@ -14,12 +13,16 @@ class SVMModelBase(PickleableModelBase): # Meta parameters for this class type: ClassVar[str] - mod_class: ClassVar[type] + + @classmethod + def _get_estimator_class(cls) -> type: + """Return the SVM estimator class (deferred import).""" + raise NotImplementedError def build(self): """Prepare the model.""" if not self.estimator: - self.estimator = self.mod_class(**self.model_dump()) + self.estimator = self._get_estimator_class()(**self.model_dump()) else: logger.warning("Model already exists, skipping build") @@ -80,7 +83,12 @@ class SVMRegressorModel(SVMModelBase): # Meta parameters for this class type: ClassVar[str] = "SVMRegressorModel" - mod_class: ClassVar[type] = SVR + + @classmethod + def _get_estimator_class(cls) -> type: + from sklearn.svm import SVR + + return SVR # SVR parameters 
kernel: str = "rbf" @@ -115,7 +123,12 @@ class SVMClassifierModel(SVMModelBase): # Meta parameters for this class type: ClassVar[str] = "SVMClassifierModel" - mod_class: ClassVar[type] = SVC + + @classmethod + def _get_estimator_class(cls) -> type: + from sklearn.svm import SVC + + return SVC # SVC parameters C: float = 1.0 diff --git a/openadmet/models/architecture/tabpfn.py b/openadmet/models/architecture/tabpfn.py index d9a706d8..eab3d30c 100644 --- a/openadmet/models/architecture/tabpfn.py +++ b/openadmet/models/architecture/tabpfn.py @@ -1,16 +1,11 @@ """TabPFN model implementations.""" -from typing import ClassVar, Literal, Optional, Union +import warnings +from typing import ClassVar, Literal, Optional import numpy as np from loguru import logger from pydantic import Field, field_validator -from tabpfn import TabPFNClassifier, TabPFNRegressor -from tabpfn_extensions.post_hoc_ensembles.sklearn_interface import ( - AutoTabPFNClassifier, - AutoTabPFNRegressor, -) -import warnings from openadmet.models.architecture.model_base import PickleableModelBase, models @@ -26,8 +21,6 @@ class TabPFNExtensionModelBase(PickleableModelBase): ---------- type : ClassVar[str] Model type identifier. - mod_class : ClassVar[type] - The TabPFN model class (e.g., AutoTabPFNRegressor or AutoTabPFNClassifier). max_time : Optional[int] Maximum time to spend on fitting the post hoc ensemble. accelerator : Literal["cpu", "gpu", "auto"] @@ -43,7 +36,11 @@ class TabPFNExtensionModelBase(PickleableModelBase): # Meta parameters for this class type: ClassVar[str] - mod_class: ClassVar[type] + + @classmethod + def _get_estimator_class(cls) -> type: + """Return the TabPFN extension estimator class (deferred import).""" + raise NotImplementedError # TabPFN parameters max_time: Optional[int] = Field( @@ -96,7 +93,7 @@ def build(self): "TabPFN 2.5 is distributed under the TabPFN 2.5 License: https://priorlabs.ai/tabpfn-license which prohibits commercial use. 
Review the license and ensure you are compliant before using this model. A commercial license can be obtained from the TabPFN team." ) if not self.estimator: - self.estimator = self.mod_class( + self.estimator = self._get_estimator_class()( max_time=self.max_time, device=accelerator, random_state=self.random_state, @@ -149,7 +146,12 @@ class TabPFNPostHocRegressorModel(TabPFNExtensionModelBase): # Meta parameters for this class type: ClassVar[str] = "TabPFNPostHocRegressorModel" - mod_class: ClassVar[type] = AutoTabPFNRegressor + + @classmethod + def _get_estimator_class(cls) -> type: + from tabpfn_extensions.post_hoc_ensembles.sklearn_interface import AutoTabPFNRegressor + + return AutoTabPFNRegressor @models.register("TabPFNPostHocClassifierModel") @@ -158,7 +160,12 @@ class TabPFNPostHocClassifierModel(TabPFNExtensionModelBase): # Meta parameters for this class type: ClassVar[str] = "TabPFNPostHocClassifierModel" - mod_class: ClassVar[type] = AutoTabPFNClassifier + + @classmethod + def _get_estimator_class(cls) -> type: + from tabpfn_extensions.post_hoc_ensembles.sklearn_interface import AutoTabPFNClassifier + + return AutoTabPFNClassifier def predict_proba(self, X: np.ndarray) -> np.ndarray: """ @@ -197,7 +204,11 @@ class TabPFNModelBase(PickleableModelBase): # Meta parameters for this class type: ClassVar[str] - mod_class: ClassVar[type] + + @classmethod + def _get_estimator_class(cls) -> type: + """Return the basic TabPFN estimator class (deferred import).""" + raise NotImplementedError # TabPFN parameters accelerator: Literal["cpu", "cuda", "auto"] = Field(default="auto") @@ -208,7 +219,7 @@ def build(self): """Prepare and build the model instance.""" accelerator = self.accelerator if self.accelerator != "gpu" else "cuda" if not self.estimator: - self.estimator = self.mod_class( + self.estimator = self._get_estimator_class()( device=accelerator, random_state=self.random_state, ignore_pretraining_limits=self.ignore_pretraining_limits, @@ -264,7 +275,12 @@ 
class TabPFNRegressorModel(TabPFNModelBase): # Meta parameters for this class type: ClassVar[str] = "TabPFNRegressorModel" - mod_class: ClassVar[type] = TabPFNRegressor + + @classmethod + def _get_estimator_class(cls) -> type: + from tabpfn import TabPFNRegressor + + return TabPFNRegressor @models.register("TabPFNClassifierModel") @@ -273,4 +289,9 @@ class TabPFNClassifierModel(TabPFNModelBase): # Meta parameters for this class type: ClassVar[str] = "TabPFNClassifierModel" - mod_class: ClassVar[type] = TabPFNClassifier + + @classmethod + def _get_estimator_class(cls) -> type: + from tabpfn import TabPFNClassifier + + return TabPFNClassifier diff --git a/openadmet/models/architecture/xgboost.py b/openadmet/models/architecture/xgboost.py index d45617c3..5ebaf296 100644 --- a/openadmet/models/architecture/xgboost.py +++ b/openadmet/models/architecture/xgboost.py @@ -5,7 +5,6 @@ import numpy as np from loguru import logger from pydantic import ConfigDict -from xgboost import XGBClassifier, XGBRegressor from openadmet.models.architecture.model_base import PickleableModelBase, models @@ -18,12 +17,16 @@ class XGBoostModelBase(PickleableModelBase): # Meta-parameters for this class type: ClassVar[str] - mod_class: ClassVar[type] + + @classmethod + def _get_estimator_class(cls) -> type: + """Return the XGBoost estimator class (deferred import).""" + raise NotImplementedError def build(self): """Prepare the model.""" if not self.estimator: - self.estimator = self.mod_class(**self.model_dump()) + self.estimator = self._get_estimator_class()(**self.model_dump()) else: logger.warning("Model already exists, skipping build") @@ -83,7 +86,12 @@ class XGBRegressorModel(XGBoostModelBase): """ type: ClassVar[str] = "XGBRegressorModel" - mod_class: ClassVar[type] = XGBRegressor + + @classmethod + def _get_estimator_class(cls) -> type: + from xgboost import XGBRegressor + + return XGBRegressor @models.register("XGBClassifierModel") @@ -104,7 +112,12 @@ class 
XGBClassifierModel(XGBoostModelBase): """ type: ClassVar[str] = "XGBoostClaXGBClassifierModelssifierModel" - mod_class: ClassVar[type] = XGBClassifier + + @classmethod + def _get_estimator_class(cls) -> type: + from xgboost import XGBClassifier + + return XGBClassifier def predict_proba(self, X: np.ndarray) -> np.ndarray: """ diff --git a/openadmet/models/eval/cross_validation.py b/openadmet/models/eval/cross_validation.py index b20068c6..60239328 100644 --- a/openadmet/models/eval/cross_validation.py +++ b/openadmet/models/eval/cross_validation.py @@ -20,8 +20,6 @@ from openadmet.models.eval.eval_base import EvalBase, evaluators, get_t_true_and_t_pred from openadmet.models.eval.regression import ( RegressionPlots, - nan_omit_ktau, - nan_omit_spearmanr, pct_within_1_log_unit, relative_absolute_error, ) @@ -32,11 +30,21 @@ def wrap_ktau(y_true, y_pred): """Wrap ktau nan omission.""" + from functools import partial + + from scipy.stats import kendalltau + + nan_omit_ktau = partial(kendalltau, nan_policy="omit") return nan_omit_ktau(y_true, y_pred).statistic def wrap_spearmanr(y_true, y_pred): """Wrap spearmanR nan omission.""" + from functools import partial + + from scipy.stats import spearmanr + + nan_omit_spearmanr = partial(spearmanr, nan_policy="omit") return nan_omit_spearmanr(y_true, y_pred).correlation diff --git a/openadmet/models/eval/eval_base.py b/openadmet/models/eval/eval_base.py index 42ad7304..82aa90bc 100644 --- a/openadmet/models/eval/eval_base.py +++ b/openadmet/models/eval/eval_base.py @@ -3,11 +3,10 @@ from abc import abstractmethod from typing import Callable, ClassVar -from loguru import logger import numpy as np from class_registry import ClassRegistry, RegistryKeyError +from loguru import logger from pydantic import BaseModel -from scipy.stats import bootstrap evaluators = ClassRegistry(unique=True) @@ -32,6 +31,9 @@ def get_eval_class(eval_type): If the evaluation type is not found in the registry. 
""" + from openadmet.models._registry_loader import load_group + + load_group("evaluators") try: eval_class = evaluators.get_class(eval_type) except RegistryKeyError: @@ -237,6 +239,8 @@ def stat_and_bootstrap( """ # calculate the metric and confidence intervals + from scipy.stats import bootstrap + if is_scipy_statistic: metric = statistic(y_true, y_pred).statistic conf_interval = bootstrap( diff --git a/openadmet/models/eval/regression.py b/openadmet/models/eval/regression.py index 58057834..78dc72ef 100644 --- a/openadmet/models/eval/regression.py +++ b/openadmet/models/eval/regression.py @@ -1,16 +1,10 @@ """Regression metrics and plots for model evaluation.""" import json -from functools import partial import numpy as np import pandas as pd -import seaborn as sns -import wandb -from matplotlib import pyplot as plt from pydantic import Field -from scipy.stats import kendalltau, spearmanr -from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score from openadmet.models.eval.eval_base import ( EvalBase, @@ -19,10 +13,6 @@ ) from openadmet.models.eval.utils import _make_stat_caption, _make_stat_dict -# create partial functions for the scipy stats -nan_omit_ktau = partial(kendalltau, nan_policy="omit") -nan_omit_spearmanr = partial(spearmanr, nan_policy="omit") - def relative_absolute_error(y_true, y_pred): """ @@ -104,19 +94,28 @@ class RegressionMetrics(EvalBase): ) _evaluated: bool = False - _metrics: dict = { - "mse": (mean_squared_error, False, "MSE"), - "mae": (mean_absolute_error, False, "MAE"), - "r2": (r2_score, False, "$R^2$"), - "ktau": (nan_omit_ktau, True, "Kendall's $\\tau$"), - "spearmanr": (nan_omit_spearmanr, True, "Spearman's $\\rho$"), - "rae": (relative_absolute_error, False, "RAE"), - } + @classmethod + def _base_metrics(cls) -> dict: + from functools import partial + + from scipy.stats import kendalltau, spearmanr + from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score + + nan_omit_ktau = 
partial(kendalltau, nan_policy="omit") + nan_omit_spearmanr = partial(spearmanr, nan_policy="omit") + return { + "mse": (mean_squared_error, False, "MSE"), + "mae": (mean_absolute_error, False, "MAE"), + "r2": (r2_score, False, "$R^2$"), + "ktau": (nan_omit_ktau, True, "Kendall's $\\tau$"), + "spearmanr": (nan_omit_spearmanr, True, "Spearman's $\\rho$"), + "rae": (relative_absolute_error, False, "RAE"), + } @property def active_metrics(self): """Return metrics applicable to the current target scale.""" - metrics = dict(self._metrics) + metrics = self._base_metrics() if self.pXC50: metrics["pct_within_1_log"] = ( pct_within_1_log_unit, @@ -205,6 +204,8 @@ def evaluate( } if self.use_wandb: + import wandb + for t_label in target_labels: # make a table for the metrics table = wandb.Table( @@ -294,6 +295,8 @@ def write_report(self, output_dir): # also log the json to wandb if self.use_wandb: + import wandb + artifact = wandb.Artifact(name="metrics_json", type="metric_json") # Add a file to the artifact artifact.add_file(json_path) @@ -351,7 +354,7 @@ def get_stat_dict(self, t_label): data=self.data, task_name=t_label, metric_names=self.metric_names, - metrics=self._metrics, + metrics=self._base_metrics(), confidence_level=self.bootstrap_confidence_level, cv=False, ) @@ -574,6 +577,7 @@ def regplot( else: max_ax = max_val # set the limits to be the same for both axes + import seaborn as sns g = sns.jointplot( x=np.ravel(y_true), @@ -718,6 +722,8 @@ def ciplot(stat_dict={}): } n_metrics = len(metrics) + from matplotlib import pyplot as plt + fig, axes = plt.subplots(1, n_metrics, figsize=(8, n_metrics), sharex=False) if n_metrics == 1: @@ -783,4 +789,6 @@ def write_report(self, output_dir): plot_path = output_dir / f"{plot_tag}.png" plot.savefig(plot_path, dpi=self.dpi) if self.use_wandb: + import wandb + wandb.log({plot_tag: wandb.Image(str(plot_path))}) diff --git a/openadmet/models/features/chemprop.py b/openadmet/models/features/chemprop.py index cb26d859..2935da84 
100644 --- a/openadmet/models/features/chemprop.py +++ b/openadmet/models/features/chemprop.py @@ -1,32 +1,20 @@ """ChemProp featurizer implementation.""" +from __future__ import annotations + from collections.abc import Iterable from typing import Any, Union + import numpy as np import pandas as pd -from chemprop.data import ( - MoleculeDatapoint, - MoleculeDataset, - MulticomponentDataset, - ReactionDataset, -) -from chemprop.data.collate import collate_batch, collate_multicomponent -from chemprop.data.samplers import ClassBalanceSampler, SeededSampler -from sklearn.preprocessing import StandardScaler -from torch.utils.data import DataLoader - -from openadmet.models.features.chemprop import ( - MoleculeDataset, - MulticomponentDataset, - ReactionDataset, -) + from openadmet.models.features.feature_base import DeepLearningFeaturizer, featurizers # we vendor this from chemprop so that we can pass custom samplers # taken directly from https://github.com/chemprop/chemprop/blob/main/chemprop/data/dataloader.py def _vendor_build_dataloader( - dataset: MoleculeDataset | ReactionDataset | MulticomponentDataset, + dataset, batch_size: int = 64, num_workers: int = 0, class_balance: bool = False, @@ -68,6 +56,11 @@ def _vendor_build_dataloader( A PyTorch DataLoader for the given MoleculeDataset, ReactionDataset, or MulticomponentDataset. """ + from chemprop.data import MulticomponentDataset + from chemprop.data.collate import collate_batch, collate_multicomponent + from chemprop.data.samplers import ClassBalanceSampler, SeededSampler + from torch.utils.data import DataLoader + if sampler is not None: if class_balance: sampler = ClassBalanceSampler(dataset.Y, seed, shuffle) @@ -153,6 +146,8 @@ def featurize( - Union[MoleculeDataset, ReactionDataset, MulticomponentDataset]: PyTorch Dataset containing the features and targets. 
""" + from chemprop.data import MoleculeDatapoint, MoleculeDataset + if y is not None: # if a pandas dataframe or series if isinstance(y, pd.DataFrame) or isinstance(y, pd.Series): diff --git a/openadmet/models/features/feature_base.py b/openadmet/models/features/feature_base.py index 5d333452..accc7e15 100644 --- a/openadmet/models/features/feature_base.py +++ b/openadmet/models/features/feature_base.py @@ -1,20 +1,28 @@ """Base classes and utilities for molecular featurizers.""" +from __future__ import annotations + from abc import ABC, abstractmethod from collections.abc import Iterable +from typing import TYPE_CHECKING, Any import numpy as np from class_registry import ClassRegistry, RegistryKeyError -from molfeat.trans import MoleculeTransformer from pydantic import BaseModel -from sklearn.preprocessing import StandardScaler -from torch.utils.data import DataLoader, Dataset + +if TYPE_CHECKING: + from molfeat.trans import MoleculeTransformer + from sklearn.preprocessing import StandardScaler + from torch.utils.data import DataLoader, Dataset featurizers = ClassRegistry(unique=True) def get_featurizer_class(feat_type): """Retrieve a featurizer class from the registry by type.""" + from openadmet.models._registry_loader import load_group + + load_group("featurizers") try: feat_class = featurizers.get_class(feat_type) except RegistryKeyError: @@ -107,7 +115,7 @@ class MolfeatFeaturizer(FeaturizerBase): """ - _transformer: MoleculeTransformer = None + _transformer: Any = None def __init__(self, *args, **kwargs): """ diff --git a/openadmet/models/registries.py b/openadmet/models/registries.py index 66b11333..dcccaaab 100644 --- a/openadmet/models/registries.py +++ b/openadmet/models/registries.py @@ -1,53 +1,18 @@ -"""Registry imports for all model components.""" - -# IMPORTANT: order matters here, make sure to import all the registered classes first -# before importing the registry classes -# there must be a better way to do this, but for now this works - -# 
active learning -from openadmet.models.active_learning.committee import * # noqa: F401 F403 -from openadmet.models.active_learning.ensemble_base import ensemblers # noqa: F401 F403 - -# models -from openadmet.models.architecture.catboost import * # noqa: F401 F403 -from openadmet.models.architecture.chemprop import * # noqa: F401 F403 # noqa: F401 F403 -from openadmet.models.architecture.dummy import * # noqa: F401 F403 -from openadmet.models.architecture.lgbm import * # noqa: F401 F403 # noqa: F401 F403 -from openadmet.models.architecture.nepare import * # noqa: F401 F403 -from openadmet.models.architecture.rf import * # noqa: F401 F403 -from openadmet.models.architecture.svm import * # noqa: F401 F403 -from openadmet.models.architecture.tabpfn import * # noqa: F401 F403 -from openadmet.models.architecture.xgboost import * # noqa: F401 F403 -from openadmet.models.architecture.model_base import models # noqa: F401 F403 - -# evaluators -from openadmet.models.eval.classification import * # noqa: F401 F403 -from openadmet.models.eval.cross_validation import * # noqa: F401 F403 -from openadmet.models.eval.regression import * # noqa: F401 F403 -from openadmet.models.eval.uncertainty import * # noqa: F401 F403 -from openadmet.models.eval.eval_base import evaluators # noqa: F401 F403 - -# featurizers -from openadmet.models.features.chemprop import * # noqa: F401 F403 -from openadmet.models.features.combine import * # noqa: F401 F403 -from openadmet.models.features.molfeat_fingerprint import * # noqa: F401 F403 -from openadmet.models.features.molfeat_properties import * # noqa: F401 F403 -from openadmet.models.features.feature_base import featurizers # noqa: F401 F403 - -# util -from openadmet.models.log import logger # noqa: F401 F403 - -# splitters -from openadmet.models.split.scaffold import * # noqa: F401 F403 -from openadmet.models.split.sklearn import * # noqa: F401 F403 -from openadmet.models.split.split_base import splitters # noqa: F401 F403 -from 
openadmet.models.split.cluster import * # noqa: F401 F403 - -# trainers -from openadmet.models.trainer.lightning import * # noqa: F401 F403 -from openadmet.models.trainer.sklearn import * # noqa: F401 F403 -from openadmet.models.trainer.trainer_base import trainers # noqa: F401 F403 - -# transforms -from openadmet.models.transforms.impute import * # noqa: F401 F403 -from openadmet.models.transforms.transform_base import * # noqa: F401 F403 +"""Registry objects and lazy loader for all model components. + +Importing this module is intentionally cheap. Concrete classes are registered +only when the relevant group is first accessed. To eagerly load everything +(e.g. for CLI tools or the Anvil workflow), call ``load_all()``: + + from openadmet.models.registries import load_all + load_all() +""" + +from openadmet.models._registry_loader import load_all # noqa: F401 +from openadmet.models.active_learning.ensemble_base import ensemblers # noqa: F401 +from openadmet.models.architecture.model_base import models # noqa: F401 +from openadmet.models.eval.eval_base import evaluators # noqa: F401 +from openadmet.models.features.feature_base import featurizers # noqa: F401 +from openadmet.models.log import logger # noqa: F401 +from openadmet.models.split.split_base import splitters # noqa: F401 +from openadmet.models.trainer.trainer_base import trainers # noqa: F401 diff --git a/openadmet/models/split/cluster.py b/openadmet/models/split/cluster.py index 2a39dcd6..55daa3c9 100644 --- a/openadmet/models/split/cluster.py +++ b/openadmet/models/split/cluster.py @@ -1,22 +1,13 @@ """Cluster-based data splitting implementations.""" import logging -from pydantic import BaseModel, field_validator, model_validator from typing import Literal -from sklearn.model_selection import GroupShuffleSplit -from sklearn.cluster import KMeans -from molfeat.trans import MoleculeTransformer -from molfeat.trans.fp import FPVecTransformer -import datamol as dm + import numpy as np import pandas as pd +from 
pydantic import BaseModel, field_validator, model_validator + from openadmet.models.split.split_base import SplitterBase, splitters -from useful_rdkit_utils import ( - get_butina_clusters, - get_bemis_murcko_clusters, - get_scaffold, - smi2numpy_fp, -) @splitters.register("ClusterSplitter") @@ -78,13 +69,22 @@ def split(self, X, y, num_iters=1000): """ # Get clusters based on the selected method if self.method == "butina": + from useful_rdkit_utils import get_butina_clusters + clusters = get_butina_clusters(X, cutoff=self.butina_cutoff) elif self.method == "bemis-murcko": + from useful_rdkit_utils import get_bemis_murcko_clusters + clusters = get_bemis_murcko_clusters(X) elif self.method == "kmeans": logging.warning( "KMeans clustering is NOT DETERMINISTIC with random seed across platforms." ) + import datamol as dm + from molfeat.trans import MoleculeTransformer + from molfeat.trans.fp import FPVecTransformer + from sklearn.cluster import KMeans + km = KMeans( n_clusters=self.k_clusters, n_init=1, diff --git a/openadmet/models/split/scaffold.py b/openadmet/models/split/scaffold.py index d5dda0e1..07349986 100644 --- a/openadmet/models/split/scaffold.py +++ b/openadmet/models/split/scaffold.py @@ -1,10 +1,10 @@ """Scaffold-based data splitting implementations.""" import logging -from sklearn.model_selection import train_test_split -from splito import MaxDissimilaritySplit, PerimeterSplit, ScaffoldSplit + import numpy as np import pandas as pd + from openadmet.models.split.split_base import SplitterBase, splitters @@ -40,6 +40,8 @@ def split(self, X, y): """ logging.warning("ScaffoldSplitter is not available for cross-validation.") + from splito import ScaffoldSplit + # No test set requested if self.test_size == 0: # Split into train and val @@ -87,6 +89,8 @@ def split(self, X, y): ) # Split train+val into train and val sets + from sklearn.model_selection import train_test_split + X_train, X_val, y_train, y_val = train_test_split( safe_index(X, train_val_idx), 
safe_index(y, train_val_idx), @@ -135,6 +139,8 @@ def split(self, X, y): """ logging.warning("PerimeterSplitter is not available for cross-validation.") + from splito import PerimeterSplit + # No test set requested if self.test_size == 0: # Split into train and val @@ -181,6 +187,8 @@ def split(self, X, y): ) # Split train+val into train and val sets using sklearn + from sklearn.model_selection import train_test_split + X_train, X_val, y_train, y_val = train_test_split( safe_index(X, train_val_idx), safe_index(y, train_val_idx), @@ -231,6 +239,8 @@ def split(self, X, y): logging.warning( "MaxDissimilaritySplitter is not available for cross-validation." ) + from splito import MaxDissimilaritySplit + # No test set requested if self.test_size == 0: # Split into train and val @@ -277,6 +287,8 @@ def split(self, X, y): ) # Split train+val into train and val sets using sklearn + from sklearn.model_selection import train_test_split + X_train, X_val, y_train, y_val = train_test_split( safe_index(X, train_val_idx), safe_index(y, train_val_idx), diff --git a/openadmet/models/split/split_base.py b/openadmet/models/split/split_base.py index 7523cfa2..832e5d28 100644 --- a/openadmet/models/split/split_base.py +++ b/openadmet/models/split/split_base.py @@ -25,6 +25,9 @@ def get_splitter_class(feat_type): The splitter class corresponding to the given type. 
""" + from openadmet.models._registry_loader import load_group + + load_group("splitters") try: split_class = splitters.get_class(feat_type) except RegistryKeyError: diff --git a/openadmet/models/tests/unit/models/test_base.py b/openadmet/models/tests/unit/models/test_base.py index ed466fef..cc7be2ec 100644 --- a/openadmet/models/tests/unit/models/test_base.py +++ b/openadmet/models/tests/unit/models/test_base.py @@ -3,6 +3,7 @@ import pytest +import openadmet.models.architecture.lightning_model_base as lightning_model_base import openadmet.models.architecture.model_base as model_base from openadmet.models.architecture.model_base import ( LightningModelBase, @@ -38,7 +39,7 @@ def test_save_load_torch_model(mclass, tmp_path): def test_lightning_model_load_uses_weights_only(monkeypatch, tmp_path): state_dict = {"layer.weight": "dummy"} torch_load = Mock(return_value=state_dict) - monkeypatch.setattr(model_base.torch, "load", torch_load) + monkeypatch.setattr(lightning_model_base.torch, "load", torch_load) estimator = Mock() model = SimpleNamespace(estimator=estimator) diff --git a/openadmet/models/trainer/lightning.py b/openadmet/models/trainer/lightning.py index 2f89cac2..fef0197a 100644 --- a/openadmet/models/trainer/lightning.py +++ b/openadmet/models/trainer/lightning.py @@ -3,10 +3,6 @@ from pathlib import Path # it is used in the main therefore i do not remove it from typing import Any -import torch -from lightning import pytorch as pl -from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint -from lightning.pytorch.loggers import CSVLogger, WandbLogger from loguru import logger from openadmet.models.drivers import DriverType from openadmet.models.trainer.trainer_base import TrainerBase, trainers @@ -106,6 +102,10 @@ def build(self, no_val: bool = False): # Initialize logging container self._logger = [] + from lightning import pytorch as pl + from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint + from lightning.pytorch.loggers 
import CSVLogger, WandbLogger + # Initialize the callbacks dict self._callbacks = {} @@ -191,6 +191,8 @@ def train(self, train_dataloader, val_dataloader): # Indicate that the model is being trained logger.debug(f"Training model {self.model.estimator}") + import torch + # Fit model self._trainer.fit(self.model.estimator, train_dataloader, val_dataloader) diff --git a/openadmet/models/trainer/trainer_base.py b/openadmet/models/trainer/trainer_base.py index 8632ab12..896076ce 100644 --- a/openadmet/models/trainer/trainer_base.py +++ b/openadmet/models/trainer/trainer_base.py @@ -26,6 +26,9 @@ def get_trainer_class(model_type): The trainer class corresponding to the given type. """ + from openadmet.models._registry_loader import load_group + + load_group("trainers") try: feat_class = trainers.get_class(model_type) except RegistryKeyError: From 238add436fc0af782174caf4acd89aca8b13fce8 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 18 May 2026 02:54:59 +0000 Subject: [PATCH 2/4] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- openadmet/models/_registry_loader.py | 3 ++- openadmet/models/architecture/lightning_model_base.py | 3 ++- openadmet/models/architecture/tabpfn.py | 8 ++++++-- openadmet/models/features/chemprop.py | 4 ++-- openadmet/models/registries.py | 3 ++- 5 files changed, 14 insertions(+), 7 deletions(-) diff --git a/openadmet/models/_registry_loader.py b/openadmet/models/_registry_loader.py index 3f43cd00..edc95c32 100644 --- a/openadmet/models/_registry_loader.py +++ b/openadmet/models/_registry_loader.py @@ -1,4 +1,5 @@ -"""Lazy registry loader — zero heavy imports at module level. +""" +Lazy registry loader — zero heavy imports at module level. Call ``load_group(name)`` to load a specific registry group, or ``load_all()`` to populate every registry at once. Both are idempotent. 
diff --git a/openadmet/models/architecture/lightning_model_base.py b/openadmet/models/architecture/lightning_model_base.py index 83eb5608..e8d20f51 100644 --- a/openadmet/models/architecture/lightning_model_base.py +++ b/openadmet/models/architecture/lightning_model_base.py @@ -1,4 +1,5 @@ -"""Lightning-specific base classes for deep learning models. +""" +Lightning-specific base classes for deep learning models. This module is intentionally separate from model_base so that importing PickleableModelBase (sklearn-style models) does not incur the cost of diff --git a/openadmet/models/architecture/tabpfn.py b/openadmet/models/architecture/tabpfn.py index eab3d30c..ac937e6a 100644 --- a/openadmet/models/architecture/tabpfn.py +++ b/openadmet/models/architecture/tabpfn.py @@ -149,7 +149,9 @@ class TabPFNPostHocRegressorModel(TabPFNExtensionModelBase): @classmethod def _get_estimator_class(cls) -> type: - from tabpfn_extensions.post_hoc_ensembles.sklearn_interface import AutoTabPFNRegressor + from tabpfn_extensions.post_hoc_ensembles.sklearn_interface import ( + AutoTabPFNRegressor, + ) return AutoTabPFNRegressor @@ -163,7 +165,9 @@ class TabPFNPostHocClassifierModel(TabPFNExtensionModelBase): @classmethod def _get_estimator_class(cls) -> type: - from tabpfn_extensions.post_hoc_ensembles.sklearn_interface import AutoTabPFNClassifier + from tabpfn_extensions.post_hoc_ensembles.sklearn_interface import ( + AutoTabPFNClassifier, + ) return AutoTabPFNClassifier diff --git a/openadmet/models/features/chemprop.py b/openadmet/models/features/chemprop.py index 2935da84..1866b0c6 100644 --- a/openadmet/models/features/chemprop.py +++ b/openadmet/models/features/chemprop.py @@ -124,7 +124,7 @@ def featurize( DataLoader, np.ndarray, StandardScaler, - Union[MoleculeDataset, ReactionDataset, MulticomponentDataset], + MoleculeDataset | ReactionDataset | MulticomponentDataset, ]: """ Featurize a list of SMILES strings. 
@@ -217,6 +217,6 @@ def dataset_to_dataloader( **kwargs, ) - def make_new(self) -> "ChemPropFeaturizer": + def make_new(self) -> ChemPropFeaturizer: """Copy parameters to a new ChemPropFeaturizer instance.""" return self.__class__(**self.dict()) diff --git a/openadmet/models/registries.py b/openadmet/models/registries.py index dcccaaab..e4983312 100644 --- a/openadmet/models/registries.py +++ b/openadmet/models/registries.py @@ -1,4 +1,5 @@ -"""Registry objects and lazy loader for all model components. +""" +Registry objects and lazy loader for all model components. Importing this module is intentionally cheap. Concrete classes are registered only when the relevant group is first accessed. To eagerly load everything From 71b92c3203c5708307a19d75cabf3fb8f213e3f3 Mon Sep 17 00:00:00 2001 From: Sean Colby Date: Sun, 17 May 2026 19:02:18 -0800 Subject: [PATCH 3/4] docs: expand numpy docstrings on all new public functions - load_group(): add Parameters section with valid group keys - load_all(): expand summary line - get_mod_class(): add full Parameters/Returns/Raises sections - get_featurizer_class(): add full Parameters/Returns/Raises sections - get_ensemble_class(): add full Parameters/Returns/Raises sections - RegressionMetrics._base_metrics(): add Returns section All 5 D413 blank-line-after-section violations auto-fixed. 
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- openadmet/models/_registry_loader.py | 14 +++++++++++-- .../models/active_learning/ensemble_base.py | 20 ++++++++++++++++++- openadmet/models/architecture/model_base.py | 20 ++++++++++++++++++- openadmet/models/eval/regression.py | 10 ++++++++++ openadmet/models/features/feature_base.py | 20 ++++++++++++++++++- 5 files changed, 79 insertions(+), 5 deletions(-) diff --git a/openadmet/models/_registry_loader.py b/openadmet/models/_registry_loader.py index edc95c32..2ce8a4e8 100644 --- a/openadmet/models/_registry_loader.py +++ b/openadmet/models/_registry_loader.py @@ -67,7 +67,17 @@ def load_group(name: str) -> None: - """Import all modules in the named registry group (idempotent).""" + """ + Import all modules in the named registry group (idempotent). + + Parameters + ---------- + name : str + Registry group key. Must be one of: ``"models"``, ``"evaluators"``, + ``"featurizers"``, ``"splitters"``, ``"trainers"``, ``"transforms"``, + ``"active_learning"``. + + """ if name in _loaded: return for mod in _GROUPS[name]: @@ -76,6 +86,6 @@ def load_group(name: str) -> None: def load_all() -> None: - """Import all registry groups (idempotent).""" + """Import all registry groups, making every registered class available (idempotent).""" for name in _GROUPS: load_group(name) diff --git a/openadmet/models/active_learning/ensemble_base.py b/openadmet/models/active_learning/ensemble_base.py index e534983d..1e85f85b 100644 --- a/openadmet/models/active_learning/ensemble_base.py +++ b/openadmet/models/active_learning/ensemble_base.py @@ -10,7 +10,25 @@ def get_ensemble_class(ensemble_type): - """Get the ensemble class.""" + """ + Get the ensemble class from the registry by type. + + Parameters + ---------- + ensemble_type : str + The registered key for the ensemble (e.g., ``"QueryByCommittee"``). + + Returns + ------- + type + The ensemble class corresponding to the given type. 
+ + Raises + ------ + ValueError + If ``ensemble_type`` is not found in the ensemble registry. + + """ from openadmet.models._registry_loader import load_group load_group("active_learning") diff --git a/openadmet/models/architecture/model_base.py b/openadmet/models/architecture/model_base.py index 120c53a8..88172541 100644 --- a/openadmet/models/architecture/model_base.py +++ b/openadmet/models/architecture/model_base.py @@ -15,7 +15,25 @@ def get_mod_class(model_type): - """Get the model class from the registry.""" + """ + Get the model class from the registry by type. + + Parameters + ---------- + model_type : str + The registered key for the model (e.g., ``"XGBRegressorModel"``). + + Returns + ------- + type + The model class corresponding to the given type. + + Raises + ------ + ValueError + If ``model_type`` is not found in the model registry. + + """ from openadmet.models._registry_loader import load_group load_group("models") diff --git a/openadmet/models/eval/regression.py b/openadmet/models/eval/regression.py index 78dc72ef..f9a83f31 100644 --- a/openadmet/models/eval/regression.py +++ b/openadmet/models/eval/regression.py @@ -96,6 +96,16 @@ class RegressionMetrics(EvalBase): @classmethod def _base_metrics(cls) -> dict: + """ + Build the base metrics dictionary with deferred 3rd-party imports. + + Returns + ------- + dict + Mapping of metric key to ``(callable, is_scipy_statistic, display_label)`` + tuples for MSE, MAE, R², Kendall's τ, Spearman's ρ, and RAE. + + """ from functools import partial from scipy.stats import kendalltau, spearmanr diff --git a/openadmet/models/features/feature_base.py b/openadmet/models/features/feature_base.py index accc7e15..e0464a75 100644 --- a/openadmet/models/features/feature_base.py +++ b/openadmet/models/features/feature_base.py @@ -19,7 +19,25 @@ def get_featurizer_class(feat_type): - """Retrieve a featurizer class from the registry by type.""" + """ + Retrieve a featurizer class from the registry by type. 
+ + Parameters + ---------- + feat_type : str + The registered key for the featurizer (e.g., ``"MolfeatFeaturizer"``). + + Returns + ------- + type + The featurizer class corresponding to the given type. + + Raises + ------ + ValueError + If ``feat_type`` is not found in the featurizer registry. + + """ from openadmet.models._registry_loader import load_group load_group("featurizers") From 2d1023b934b0807ed0e25dd00782797cd36d72fc Mon Sep 17 00:00:00 2001 From: Sean Colby Date: Mon, 18 May 2026 08:10:13 -0800 Subject: [PATCH 4/4] fix: add load_group('transforms') to get_transform_class() get_transform_class() was the only get_*_class() function not calling load_group() before the registry lookup. This caused 'ImputeTransform not found in transform catalogue' in integration tests because the transforms group was never eagerly loaded under the new lazy registry. Also adds the missing Raises section to the docstring. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- openadmet/models/transforms/transform_base.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/openadmet/models/transforms/transform_base.py b/openadmet/models/transforms/transform_base.py index 10628abb..062bbf24 100644 --- a/openadmet/models/transforms/transform_base.py +++ b/openadmet/models/transforms/transform_base.py @@ -27,7 +27,15 @@ def get_transform_class(trans_type): TransformBase The transform class corresponding to the given type. + Raises + ------ + ValueError + If ``trans_type`` is not found in the transform registry. + """ + from openadmet.models._registry_loader import load_group + + load_group("transforms") try: transf_class = transforms.get_class(trans_type) except RegistryKeyError: