diff --git a/openadmet/models/_registry_loader.py b/openadmet/models/_registry_loader.py new file mode 100644 index 00000000..2ce8a4e8 --- /dev/null +++ b/openadmet/models/_registry_loader.py @@ -0,0 +1,91 @@ +""" +Lazy registry loader — zero heavy imports at module level. + +Call ``load_group(name)`` to load a specific registry group, or ``load_all()`` +to populate every registry at once. Both are idempotent. +""" + +import importlib + +_MODELS = [ + "openadmet.models.architecture.catboost", + "openadmet.models.architecture.chemprop", + "openadmet.models.architecture.dummy", + "openadmet.models.architecture.lgbm", + "openadmet.models.architecture.nepare", + "openadmet.models.architecture.rf", + "openadmet.models.architecture.svm", + "openadmet.models.architecture.tabpfn", + "openadmet.models.architecture.xgboost", +] + +_EVALUATORS = [ + "openadmet.models.eval.classification", + "openadmet.models.eval.cross_validation", + "openadmet.models.eval.regression", + "openadmet.models.eval.uncertainty", +] + +_FEATURIZERS = [ + "openadmet.models.features.chemprop", + "openadmet.models.features.combine", + "openadmet.models.features.molfeat_fingerprint", + "openadmet.models.features.molfeat_properties", +] + +_SPLITTERS = [ + "openadmet.models.split.scaffold", + "openadmet.models.split.sklearn", + "openadmet.models.split.cluster", +] + +_TRAINERS = [ + "openadmet.models.trainer.lightning", + "openadmet.models.trainer.sklearn", +] + +_TRANSFORMS = [ + "openadmet.models.transforms.impute", + "openadmet.models.transforms.transform_base", +] + +_ACTIVE_LEARNING = [ + "openadmet.models.active_learning.committee", +] + +_GROUPS: dict[str, list[str]] = { + "models": _MODELS, + "evaluators": _EVALUATORS, + "featurizers": _FEATURIZERS, + "splitters": _SPLITTERS, + "trainers": _TRAINERS, + "transforms": _TRANSFORMS, + "active_learning": _ACTIVE_LEARNING, +} + +_loaded: set[str] = set() + + +def load_group(name: str) -> None: + """ + Import all modules in the named registry group (idempotent). + + Parameters + ---------- + name : str + Registry group key. Must be one of: ``"models"``, ``"evaluators"``, + ``"featurizers"``, ``"splitters"``, ``"trainers"``, ``"transforms"``, + ``"active_learning"``. + + """ + if name in _loaded: + return + for mod in _GROUPS[name]: + importlib.import_module(mod) + _loaded.add(name) + + +def load_all() -> None: + """Import all registry groups, making every registered class available (idempotent).""" + for name in _GROUPS: + load_group(name) diff --git a/openadmet/models/active_learning/ensemble_base.py b/openadmet/models/active_learning/ensemble_base.py index 85a68f8f..1e85f85b 100644 --- a/openadmet/models/active_learning/ensemble_base.py +++ b/openadmet/models/active_learning/ensemble_base.py @@ -10,7 +10,28 @@ def get_ensemble_class(ensemble_type): - """Get the ensemble class.""" + """ + Get the ensemble class from the registry by type. + + Parameters + ---------- + ensemble_type : str + The registered key for the ensemble (e.g., ``"QueryByCommittee"``). + + Returns + ------- + type + The ensemble class corresponding to the given type. + + Raises + ------ + ValueError + If ``ensemble_type`` is not found in the ensemble registry. + + """ + from openadmet.models._registry_loader import load_group + + load_group("active_learning") try: ensemble_class = ensemblers.get_class(ensemble_type) except RegistryKeyError: diff --git a/openadmet/models/anvil/specification.py b/openadmet/models/anvil/specification.py index a191d1d6..a4210503 100644 --- a/openadmet/models/anvil/specification.py +++ b/openadmet/models/anvil/specification.py @@ -19,7 +19,7 @@ from openadmet.models.drivers import DriverType from openadmet.models.eval.eval_base import get_eval_class from openadmet.models.features.feature_base import get_featurizer_class -from openadmet.models.registries import * # noqa: F401, F403 +from openadmet.models.registries import load_all # noqa: F401 from openadmet.models.split.split_base import get_splitter_class from openadmet.models.trainer.trainer_base import get_trainer_class from openadmet.models.transforms.transform_base import ( diff --git a/openadmet/models/anvil/workflow_base.py b/openadmet/models/anvil/workflow_base.py index 7befc504..07b8a556 100644 --- a/openadmet/models/anvil/workflow_base.py +++ b/openadmet/models/anvil/workflow_base.py @@ -14,7 +14,7 @@ from openadmet.models.architecture.model_base import ModelBase from openadmet.models.eval.eval_base import EvalBase from openadmet.models.features.feature_base import FeaturizerBase -from openadmet.models.registries import * # noqa: F401, F403 +from openadmet.models.registries import load_all # noqa: F401 from openadmet.models.split.split_base import SplitterBase from openadmet.models.trainer.trainer_base import TrainerBase from openadmet.models.transforms.transform_base import ( diff --git a/openadmet/models/architecture/catboost.py b/openadmet/models/architecture/catboost.py index 3b3c62fe..6cf86898 100644 --- a/openadmet/models/architecture/catboost.py +++ b/openadmet/models/architecture/catboost.py @@ -3,7 +3,6 @@ from typing import ClassVar import numpy as np -from catboost import CatBoostClassifier, CatBoostRegressor from loguru import logger from pydantic import ConfigDict @@ -18,9 +17,6 @@ class CatBoostModelBase(PickleableModelBase): ---------- type : ClassVar[str] The type of the model. - mod_class : ClassVar[type] - To specify the CatBoost model class (e.g., CatBoostRegressor or CatBoost - Classifier) """ @@ -29,12 +25,16 @@ class CatBoostModelBase(PickleableModelBase): # Meta parameters for this class type: ClassVar[str] - mod_class: ClassVar[type] + + @classmethod + def _get_estimator_class(cls) -> type: + """Return the CatBoost estimator class (deferred import).""" + raise NotImplementedError def build(self): """Prepare the model.""" if not self.estimator: - self.estimator = self.mod_class(**self.model_dump()) + self.estimator = self._get_estimator_class()(**self.model_dump()) else: logger.warning("Model already exists, skipping build") @@ -95,7 +95,12 @@ class CatBoostRegressorModel(CatBoostModelBase): # Meta parameters for this class type: ClassVar[str] = "CatBoostRegressorModel" - mod_class: ClassVar[type] = CatBoostRegressor + + @classmethod + def _get_estimator_class(cls) -> type: + from catboost import CatBoostRegressor + + return CatBoostRegressor @models.register("CatBoostClassifierModel") @@ -109,7 +114,12 @@ class CatBoostClassifierModel(CatBoostModelBase): # Meta parameters for this class type: ClassVar[str] = "CatBoostClassifierModel" - mod_class: ClassVar[type] = CatBoostClassifier + + @classmethod + def _get_estimator_class(cls) -> type: + from catboost import CatBoostClassifier + + return CatBoostClassifier def predict_proba(self, X: np.ndarray) -> np.ndarray: """ diff --git a/openadmet/models/architecture/chemprop.py b/openadmet/models/architecture/chemprop.py index 692fa69d..59eaea1b 100644 --- a/openadmet/models/architecture/chemprop.py +++ b/openadmet/models/architecture/chemprop.py @@ -14,15 +14,9 @@ from loguru import logger from pydantic import PrivateAttr, field_validator, model_validator -from openadmet.models.architecture.model_base import LightningModelBase +from openadmet.models.architecture.lightning_model_base import LightningModelBase from openadmet.models.architecture.model_base import models as model_registry -_METRIC_TO_LOSS = { - "mse": nn.metrics.MSE(), - "mae": nn.metrics.MAE(), - "rmse": nn.metrics.RMSE(), -} - def configure_optimizers(self): """ @@ -445,6 +439,11 @@ def build(self, scaler=None): """ if not self.estimator: + _METRIC_TO_LOSS = { + "mse": nn.metrics.MSE(), + "mae": nn.metrics.MAE(), + "rmse": nn.metrics.RMSE(), + } metric_list = [_METRIC_TO_LOSS[metric] for metric in self.metric_list] if self.from_foundation: if self.from_foundation == "chemeleon": diff --git a/openadmet/models/architecture/dummy.py b/openadmet/models/architecture/dummy.py index 4e28131f..761b16ab 100644 --- a/openadmet/models/architecture/dummy.py +++ b/openadmet/models/architecture/dummy.py @@ -4,7 +4,6 @@ import numpy as np from loguru import logger -from sklearn.dummy import DummyClassifier, DummyRegressor from openadmet.models.architecture.model_base import PickleableModelBase, models @@ -14,12 +13,16 @@ class DummyModelBase(PickleableModelBase): # Meta parameters for this class type: ClassVar[str] - mod_class: ClassVar[type] + + @classmethod + def _get_estimator_class(cls) -> type: + """Return the sklearn dummy estimator class (deferred import).""" + raise NotImplementedError def build(self): """Prepare the model.""" if not self.estimator: - self.estimator = self.mod_class(**self.model_dump()) + self.estimator = self._get_estimator_class()(**self.model_dump()) else: logger.warning("Model already exists, skipping build") @@ -77,7 +80,12 @@ class DummyRegressorModel(DummyModelBase): # Meta parameters for this class type: ClassVar[str] = "DummyRegressorModel" - mod_class: ClassVar[type] = DummyRegressor + + @classmethod + def _get_estimator_class(cls) -> type: + from sklearn.dummy import DummyRegressor + + return DummyRegressor # DummyRegressor parameters strategy: str = "mean" # Default strategy for dummy models @@ -96,7 +104,12 @@ class DummyClassifierModel(DummyModelBase): # Meta parameters for this class type: ClassVar[str] = "DummyClassifierModel" - mod_class: ClassVar[type] = DummyClassifier + + @classmethod + def _get_estimator_class(cls) -> type: + from sklearn.dummy import DummyClassifier + + return DummyClassifier # DummyClassifier parameters strategy: str = "most_frequent" # Default strategy for dummy models diff --git a/openadmet/models/architecture/lgbm.py b/openadmet/models/architecture/lgbm.py index bf73dcac..cd47fd85 100644 --- a/openadmet/models/architecture/lgbm.py +++ b/openadmet/models/architecture/lgbm.py @@ -2,7 +2,6 @@ from typing import ClassVar -import lightgbm as lgb import numpy as np from loguru import logger @@ -14,7 +13,11 @@ class LGBMModelBase(PickleableModelBase): # Meta parameters for this class type: ClassVar[str] - mod_class: ClassVar[type] + + @classmethod + def _get_estimator_class(cls) -> type: + """Return the LightGBM estimator class (deferred import).""" + raise NotImplementedError # LGBM parameters boosting_type: str = "gbdt" @@ -41,7 +44,7 @@ class LGBMModelBase(PickleableModelBase): def build(self): """Prepare the model.""" if not self.estimator: - self.estimator = self.mod_class(**self.model_dump()) + self.estimator = self._get_estimator_class()(**self.model_dump()) else: logger.warning("Model already exists, skipping build") @@ -88,7 +91,12 @@ class LGBMRegressorModel(LGBMModelBase): # Meta parameters for this class type: ClassVar[str] = "LGBMRegressorModel" - mod_class: ClassVar[type] = lgb.LGBMRegressor + + @classmethod + def _get_estimator_class(cls) -> type: + from lightgbm import LGBMRegressor + + return LGBMRegressor @models.register("LGBMClassifierModel") @@ -97,7 +105,12 @@ class LGBMClassifierModel(LGBMModelBase): # Meta parameters for this class type: ClassVar[str] = "LGBMClassifierModel" - mod_class: ClassVar[type] = lgb.LGBMClassifier + + @classmethod + def _get_estimator_class(cls) -> type: + from lightgbm import LGBMClassifier + + return LGBMClassifier def predict_proba(self, X: np.ndarray) -> np.ndarray: """Predict using the model.""" diff --git a/openadmet/models/architecture/lightning_model_base.py b/openadmet/models/architecture/lightning_model_base.py new file mode 100644 index 00000000..e8d20f51 --- /dev/null +++ b/openadmet/models/architecture/lightning_model_base.py @@ -0,0 +1,253 @@ +""" +Lightning-specific base classes for deep learning models. + +This module is intentionally separate from model_base so that importing +PickleableModelBase (sklearn-style models) does not incur the cost of +loading torch or lightning.pytorch. +""" + +import json +from abc import abstractmethod +from dataclasses import dataclass +from os import PathLike +from typing import Any, ClassVar + +import torch +from lightning import pytorch as pl +from pydantic import field_validator + +from openadmet.models.architecture.model_base import ModelBase +from openadmet.models.drivers import DriverType + + +@dataclass +class LightningModuleBase(pl.LightningModule): + """ + Lightning module base class. + + A PyTorch lightning model may inherit this instead of pl.LightningModule + to preconfigure optimizer and scheduler. + """ + + # Meta parameters for this class + type: ClassVar[str] + + # Optimizer and scheduler configuration + optimizer: str = "adamw" + optimizer_lr: float = 1e-3 + optimizer_weight_decay: float = 1e-5 + scheduler: str = "cosine" + scheduler_factor: float = 0.5 + scheduler_patience: int = 10 + monitor_metric: str = "val_loss" + + def __post_init__(self): + """Defer initialization of the LightningModuleBase.""" + pl.LightningModule.__init__(self) + + @field_validator("monitor_metric") + @classmethod + def check_monitor_metric(cls, value): + """Check if the monitor metric is valid.""" + allowed = ["val_loss", "train_loss"] + if value.lower() not in allowed: + raise ValueError(f"Monitored metric must be one of {allowed}") + return value + + @field_validator("optimizer") + @classmethod + def validate_optimizer(cls, value): + """Validate the optimizer parameter.""" + allowed = {"adamw", "adam", "sgd"} + if value.lower() not in allowed: + raise ValueError(f"Optimizer must be one of {allowed}") + return value + + @field_validator("scheduler") + @classmethod + def validate_scheduler(cls, value): + """Validate the scheduler parameter.""" + allowed = {"cosine", "reduce_on_plateau", "none", None} + if (value.lower() not in allowed) and (value is not None): + raise ValueError(f"Scheduler must be one of {allowed}") + return value + + def configure_optimizers(self): + """Return optimizer and scheduler configuration for Lightning's configure_optimizers.""" + # Adamw optimizer + if self.optimizer.lower() == "adamw": + optimizer = torch.optim.AdamW( + self.parameters(), + lr=self.optimizer_lr, + weight_decay=self.optimizer_weight_decay, + ) + + # Adam optimizer + elif self.optimizer.lower() == "adam": + optimizer = torch.optim.Adam( + self.parameters(), + lr=self.optimizer_lr, + weight_decay=self.optimizer_weight_decay, + ) + + # SGD optimizer + elif self.optimizer.lower() == "sgd": + optimizer = torch.optim.SGD( + self.parameters(), + lr=self.optimizer_lr, + weight_decay=self.optimizer_weight_decay, + ) + + # Cosine scheduler + if self.scheduler.lower() == "cosine": + scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( + optimizer, + T_max=10, # T_max could be exposed as a parameter + ) + + scheduler_config = { + "scheduler": scheduler, + "monitor": self.monitor_metric, + "interval": "epoch", + "frequency": 1, + } + + # Reduce on plateau scheduler + elif self.scheduler.lower() == "reduce_on_plateau": + scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( + optimizer, + factor=self.scheduler_factor, + patience=self.scheduler_patience, + ) + + scheduler_config = { + "scheduler": scheduler, + "monitor": self.monitor_metric, + "interval": "epoch", + "frequency": 1, + } + + # No scheduler + elif (self.scheduler is None) or (self.scheduler.lower() == "none"): + scheduler_config = None + + # Return optimizer and scheduler configuration + if scheduler_config: + return {"optimizer": optimizer, "lr_scheduler": scheduler_config} + else: + return optimizer + + +class LightningModelBase(ModelBase): + """A model that uses PyTorch Lightning.""" + + # Meta parameters for this class + type: ClassVar[str] + _model_save_name: ClassVar[str] = "model.pth" + _driver_type: DriverType = DriverType.LIGHTNING + + def make_new(self): + """ + Copy parameters to a new model instance without copying the estimator. + + Returns + ------- + LightningModelBase + A new instance of LightningModelBase with the same parameters. + + """ + return self.__class__(**self.model_dump(exclude={"estimator"})) + + def save(self, path: PathLike): + """ + Save the model to a file. + + Parameters + ---------- + path: PathLike + Path to save the model to + + """ + torch.save(self.estimator.state_dict(), path) + + def load(self, path: PathLike): + """ + Load the model from a file. + + Parameters + ---------- + path: PathLike + Path to load the model from + + """ + self.estimator.load_state_dict(torch.load(path, weights_only=True)) + + def serialize( + self, param_path: PathLike = "model.json", serial_path: PathLike = "model.pth" + ): + """ + Save the model to a json file and a serialized file. + + Parameters + ---------- + param_path: PathLike + Path to save the model parameters to + serial_path: PathLike + Path to save the serialized model to + + """ + with open(param_path, "w") as f: + f.write(self.model_dump_json(indent=2)) + self.save(serial_path) + + @classmethod + def deserialize( + cls, + param_path: PathLike = "model.json", + serial_path: PathLike = "model.pth", + scaler: Any = None, + ): + """ + Create a model from parameters and a serialized model. + + Parameters + ---------- + param_path: PathLike + Path to load the model parameters from + serial_path: PathLike + Path to load the serialized model from + scaler: Any, optional + Scaler for target normalization, if applicable + + Returns + ------- + instance: LightningModelBase + An instance of the LightningModelBase class + + """ + with open(param_path) as f: + mod_params = json.load(f) + instance = cls(**mod_params) + instance.build(scaler=scaler) + instance.load(serial_path) + return instance + + def freeze_weights(self, *args, **kwargs): + """ + Freeze parts of the model for transfer learning or fine-tuning. + + Parameters + ---------- + *args: variable length argument list + Arguments to be passed to the implementing model's `freeze_weights` method. + **kwargs: keyword arguments + Keyword arguments to be passed to the implementing model's `freeze_weights` method. + + Notes + ----- + This method should set the `requires_grad` attribute of the specified layers to False, + preventing their weights from being updated during training. It also should set these + layers to evaluation mode. + + """ + raise NotImplementedError(f"Weight freezing not implemented for {self.type}.") diff --git a/openadmet/models/architecture/model_base.py b/openadmet/models/architecture/model_base.py index 83981aa5..88172541 100644 --- a/openadmet/models/architecture/model_base.py +++ b/openadmet/models/architecture/model_base.py @@ -2,22 +2,41 @@ import json from abc import ABC, abstractmethod -from dataclasses import dataclass from os import PathLike from typing import Any, ClassVar -import joblib -import torch + from class_registry import ClassRegistry, RegistryKeyError -from lightning import pytorch as pl from loguru import logger -from pydantic import BaseModel, field_validator +from pydantic import BaseModel + from openadmet.models.drivers import DriverType models = ClassRegistry(unique=True) def get_mod_class(model_type): - """Get the model class from the registry.""" + """ + Get the model class from the registry by type. + + Parameters + ---------- + model_type : str + The registered key for the model (e.g., ``"XGBRegressorModel"``). + + Returns + ------- + type + The model class corresponding to the given type. + + Raises + ------ + ValueError + If ``model_type`` is not found in the model registry. + + """ + from openadmet.models._registry_loader import load_group + + load_group("models") try: feat_class = models.get_class(model_type) except RegistryKeyError: @@ -153,6 +172,8 @@ def save(self, path: PathLike): if self.estimator is None: raise ValueError("Model is not built, cannot save") + import joblib + with open(path, "wb") as f: joblib.dump(self.estimator, f) @@ -166,6 +187,8 @@ def load(self, path: PathLike): Path to load the model from """ + import joblib + with open(path, "rb") as f: self.estimator = joblib.load(f) @@ -219,234 +242,29 @@ def serialize( self.save(serial_path) -@dataclass -class LightningModuleBase(pl.LightningModule): - """ - Lightning module base class. +# Re-export Lightning base classes using lazy module __getattr__ (PEP 562) so that +# importing model_base does NOT pull in torch or lightning.pytorch. +# The actual definitions live in lightning_model_base. +_LIGHTNING_EXPORTS = frozenset({"LightningModelBase", "LightningModuleBase"}) - A PyTorch lightning model may inherit this instead of pl.LightningModule - to preconfigure optimizer and scheduler. - """ - - # Meta parameters for this class - type: ClassVar[str] - - # Optimizer and scheduler configuration - optimizer: str = "adamw" - optimizer_lr: float = 1e-3 - optimizer_weight_decay: float = 1e-5 - scheduler: str = "cosine" - scheduler_factor: float = 0.5 - scheduler_patience: int = 10 - monitor_metric: str = "val_loss" - def __post_init__(self): - """Defer initialization of the LightningModuleBase.""" - pl.LightningModule.__init__(self) +def __getattr__(name: str): + """Lazily re-export Lightning base classes to avoid paying their import cost.""" + if name in _LIGHTNING_EXPORTS: + from openadmet.models.architecture import lightning_model_base as _lmb - @field_validator("monitor_metric") - @classmethod - def check_monitor_metric(cls, value): - """Check if the monitor metric is valid.""" - allowed = ["val_loss", "train_loss"] - if value.lower() not in allowed: - raise ValueError(f"Monitored metric must be one of {allowed}") + value = getattr(_lmb, name) + # Cache in module dict so subsequent accesses are direct + globals()[name] = value return value + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") - @field_validator("optimizer") - @classmethod - def validate_optimizer(cls, value): - """Validate the optimizer parameter.""" - allowed = {"adamw", "adam", "sgd"} - if value.lower() not in allowed: - raise ValueError(f"Optimizer must be one of {allowed}") - return value - - @field_validator("scheduler") - @classmethod - def validate_scheduler(cls, value): - """Validate the scheduler parameter.""" - allowed = {"cosine", "reduce_on_plateau", "none", None} - if (value.lower() not in allowed) and (value is not None): - raise ValueError(f"Scheduler must be one of {allowed}") - return value - - def configure_optimizers(self): - """Return optimizer and scheduler configuration for Lightning's configure_optimizers.""" - # Adamw optimizer - if self.optimizer.lower() == "adamw": - optimizer = torch.optim.AdamW( - self.parameters(), - lr=self.optimizer_lr, - weight_decay=self.optimizer_weight_decay, - ) - - # Adam optimizer - elif self.optimizer.lower() == "adam": - optimizer = torch.optim.Adam( - self.parameters(), - lr=self.optimizer_lr, - weight_decay=self.optimizer_weight_decay, - ) - - # SGD optimizer - elif self.optimizer.lower() == "sgd": - optimizer = torch.optim.SGD( - self.parameters(), - lr=self.optimizer_lr, - weight_decay=self.optimizer_weight_decay, - ) - - # Cosine scheduler - if self.scheduler.lower() == "cosine": - scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( - optimizer, - T_max=10, # T_max could be exposed as a parameter - ) - - scheduler_config = { - "scheduler": scheduler, - "monitor": self.monitor_metric, - "interval": "epoch", - "frequency": 1, - } - - # Reduce on plateau scheduler - elif self.scheduler.lower() == "reduce_on_plateau": - scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( - optimizer, - factor=self.scheduler_factor, - patience=self.scheduler_patience, - ) - - scheduler_config = { - "scheduler": scheduler, - "monitor": self.monitor_metric, - "interval": "epoch", - "frequency": 1, - } - - # No scheduler - elif (self.scheduler is None) or (self.scheduler.lower() == "none"): - scheduler_config = None - - # Return optimizer and scheduler configuration - if scheduler_config: - return {"optimizer": optimizer, "lr_scheduler": scheduler_config} - else: - return optimizer - - -class LightningModelBase(ModelBase): - """A model that uses PyTorch Lightning.""" - - # Meta parameters for this class - type: ClassVar[str] - _model_save_name: ClassVar[str] = "model.pth" - _driver_type: DriverType = DriverType.LIGHTNING - - def make_new(self): - """ - Copy parameters to a new model instance without copying the estimator. - - Returns - ------- - LightningModelBase - A new instance of LightningModelBase with the same parameters. - - """ - return self.__class__(**self.model_dump(exclude={"estimator"})) - - def save(self, path: PathLike): - """ - Save the model to a file. - - Parameters - ---------- - path: PathLike - Path to save the model to - - """ - torch.save(self.estimator.state_dict(), path) - - def load(self, path: PathLike): - """ - Load the model from a file. - - Parameters - ---------- - path: PathLike - Path to load the model from - - """ - self.estimator.load_state_dict(torch.load(path, weights_only=True)) - - def serialize( - self, param_path: PathLike = "model.json", serial_path: PathLike = "model.pth" - ): - """ - Save the model to a json file and a serialized file. - - Parameters - ---------- - param_path: PathLike - Path to save the model parameters to - serial_path: PathLike - Path to save the serialized model to - - """ - with open(param_path, "w") as f: - f.write(self.model_dump_json(indent=2)) - self.save(serial_path) - - @classmethod - def deserialize( - cls, - param_path: PathLike = "model.json", - serial_path: PathLike = "model.pth", - scaler: Any = None, - ): - """ - Create a model from parameters and a serialized model. - - Parameters - ---------- - param_path: PathLike - Path to load the model parameters from - serial_path: PathLike - Path to load the serialized model from - scaler: Any, optional - Scaler for target normalization, if applicable - - Returns - ------- - instance: LightningModelBase - An instance of the LightningModelBase class - """ - with open(param_path) as f: - mod_params = json.load(f) - instance = cls(**mod_params) - instance.build(scaler=scaler) - instance.load(serial_path) - return instance - - def freeze_weights(self, *args, **kwargs): - """ - Freeze parts of the model for transfer learning or fine-tuning. - - Parameters - ---------- - *args: variable length argument list - Arguments to be passed to the implementing model's `freeze_weights` method. - **kwargs: keyword arguments - Keyword arguments to be passed to the implementing model's `freeze_weights` method. - - Notes - ----- - This method should set the `requires_grad` attribute of the specified layers to False, - preventing their weights from being updated during training. It also should set these - layers to evaluation mode. - - """ - raise NotImplementedError(f"Weight freezing not implemented for {self.type}.") +__all__ = [ + "ModelBase", + "PickleableModelBase", + "LightningModuleBase", + "LightningModelBase", + "models", + "get_mod_class", +] diff --git a/openadmet/models/architecture/nepare.py b/openadmet/models/architecture/nepare.py index a7697d4f..50efd5ca 100644 --- a/openadmet/models/architecture/nepare.py +++ b/openadmet/models/architecture/nepare.py @@ -9,7 +9,7 @@ from collections import OrderedDict from openadmet.models.architecture.model_base import models as model_registry -from openadmet.models.architecture.model_base import ( +from openadmet.models.architecture.lightning_model_base import ( LightningModuleBase, LightningModelBase, ) diff --git a/openadmet/models/architecture/rf.py b/openadmet/models/architecture/rf.py index 88222700..deb264cb 100644 --- a/openadmet/models/architecture/rf.py +++ b/openadmet/models/architecture/rf.py @@ -4,7 +4,6 @@ import numpy as np from loguru import logger -from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor from openadmet.models.architecture.model_base import PickleableModelBase, models @@ -14,12 +13,16 @@ class RFModelBase(PickleableModelBase): # Meta parameters for this class type: ClassVar[str] - mod_class: ClassVar[type] + + @classmethod + def _get_estimator_class(cls) -> type: + """Return the Random Forest estimator class (deferred import).""" + raise NotImplementedError def build(self): """Prepare the model.""" if not self.estimator: - self.estimator = self.mod_class(**self.model_dump()) + self.estimator = self._get_estimator_class()(**self.model_dump()) else: logger.warning("Model already exists, skipping build") @@ -66,7 +69,12 @@ class RFRegressorModel(RFModelBase): # Meta parameters for this class type: ClassVar[str] = "RFRegressorModel" - mod_class: ClassVar[type] = RandomForestRegressor + + @classmethod + def _get_estimator_class(cls) -> type: + from sklearn.ensemble import RandomForestRegressor + + return RandomForestRegressor # RF parameters n_estimators: int = 100 @@ -95,7 +103,12 @@ class RFClassifierModel(RFModelBase): # Meta parameters for this class type: ClassVar[str] = "RFClassifierModel" - mod_class: ClassVar[type] = RandomForestClassifier + + @classmethod + def _get_estimator_class(cls) -> type: + from sklearn.ensemble import RandomForestClassifier + + return RandomForestClassifier # RF parameters n_estimators: int = 100 diff --git a/openadmet/models/architecture/svm.py b/openadmet/models/architecture/svm.py index ee61b912..dfa621a9 100644 --- a/openadmet/models/architecture/svm.py +++ b/openadmet/models/architecture/svm.py @@ -4,7 +4,6 @@ import numpy as np from loguru import logger -from sklearn.svm import SVC, SVR from openadmet.models.architecture.model_base import PickleableModelBase, models @@ -14,12 +13,16 @@ class SVMModelBase(PickleableModelBase): # Meta parameters for this class type: ClassVar[str] - mod_class: ClassVar[type] + + @classmethod + def _get_estimator_class(cls) -> type: + """Return the SVM estimator class (deferred import).""" + raise NotImplementedError def build(self): """Prepare the model.""" if not self.estimator: - self.estimator = self.mod_class(**self.model_dump()) + self.estimator = self._get_estimator_class()(**self.model_dump()) else: logger.warning("Model already exists, skipping build") @@ -80,7 +83,12 @@ class SVMRegressorModel(SVMModelBase): # Meta parameters for this class type: ClassVar[str] = "SVMRegressorModel" - mod_class: ClassVar[type] = SVR + + @classmethod + def _get_estimator_class(cls) -> type: + from sklearn.svm import SVR + + return SVR # SVR parameters kernel: str = "rbf" @@ -115,7 +123,12 @@ class SVMClassifierModel(SVMModelBase): # Meta parameters for this class type: ClassVar[str] = "SVMClassifierModel" - mod_class: ClassVar[type] = SVC + + @classmethod + def _get_estimator_class(cls) -> type: + from sklearn.svm import SVC + + return SVC # SVC parameters C: float = 1.0 diff --git a/openadmet/models/architecture/tabpfn.py b/openadmet/models/architecture/tabpfn.py index d9a706d8..ac937e6a 100644 --- a/openadmet/models/architecture/tabpfn.py +++ b/openadmet/models/architecture/tabpfn.py @@ -1,16 +1,11 @@ """TabPFN model implementations.""" -from typing import ClassVar, Literal, Optional, Union +import warnings +from typing import ClassVar, Literal, Optional import numpy as np from loguru import logger from pydantic import Field, field_validator -from tabpfn import TabPFNClassifier, TabPFNRegressor -from tabpfn_extensions.post_hoc_ensembles.sklearn_interface import ( - AutoTabPFNClassifier, - AutoTabPFNRegressor, -) -import warnings from openadmet.models.architecture.model_base import PickleableModelBase, models @@ -26,8 +21,6 @@ class TabPFNExtensionModelBase(PickleableModelBase): ---------- type : ClassVar[str] Model type identifier. - mod_class : ClassVar[type] - The TabPFN model class (e.g., AutoTabPFNRegressor or AutoTabPFNClassifier). max_time : Optional[int] Maximum time to spend on fitting the post hoc ensemble. accelerator : Literal["cpu", "gpu", "auto"] @@ -43,7 +36,11 @@ class TabPFNExtensionModelBase(PickleableModelBase): # Meta parameters for this class type: ClassVar[str] - mod_class: ClassVar[type] + + @classmethod + def _get_estimator_class(cls) -> type: + """Return the TabPFN extension estimator class (deferred import).""" + raise NotImplementedError # TabPFN parameters max_time: Optional[int] = Field( @@ -96,7 +93,7 @@ def build(self): "TabPFN 2.5 is distributed under the TabPFN 2.5 License: https://priorlabs.ai/tabpfn-license which prohibits commercial use. Review the license and ensure you are compliant before using this model. A commercial license can be obtained from the TabPFN team." ) if not self.estimator: - self.estimator = self.mod_class( + self.estimator = self._get_estimator_class()( max_time=self.max_time, device=accelerator, random_state=self.random_state, @@ -149,7 +146,14 @@ class TabPFNPostHocRegressorModel(TabPFNExtensionModelBase): # Meta parameters for this class type: ClassVar[str] = "TabPFNPostHocRegressorModel" - mod_class: ClassVar[type] = AutoTabPFNRegressor + + @classmethod + def _get_estimator_class(cls) -> type: + from tabpfn_extensions.post_hoc_ensembles.sklearn_interface import ( + AutoTabPFNRegressor, + ) + + return AutoTabPFNRegressor @models.register("TabPFNPostHocClassifierModel") @@ -158,7 +162,14 @@ class TabPFNPostHocClassifierModel(TabPFNExtensionModelBase): # Meta parameters for this class type: ClassVar[str] = "TabPFNPostHocClassifierModel" - mod_class: ClassVar[type] = AutoTabPFNClassifier + + @classmethod + def _get_estimator_class(cls) -> type: + from tabpfn_extensions.post_hoc_ensembles.sklearn_interface import ( + AutoTabPFNClassifier, + ) + + return AutoTabPFNClassifier def predict_proba(self, X: np.ndarray) -> np.ndarray: """ @@ -197,7 +208,11 @@ class TabPFNModelBase(PickleableModelBase): # Meta parameters for this class type: ClassVar[str] - mod_class: ClassVar[type] + + @classmethod + def _get_estimator_class(cls) -> type: + """Return the basic TabPFN estimator class (deferred import).""" + raise NotImplementedError # TabPFN parameters accelerator: Literal["cpu", "cuda", "auto"] = Field(default="auto") @@ -208,7 +223,7 @@ def build(self): """Prepare and build the model instance.""" accelerator = self.accelerator if self.accelerator != "gpu" else "cuda" if not self.estimator: - self.estimator = self.mod_class( + self.estimator = self._get_estimator_class()( device=accelerator, random_state=self.random_state, ignore_pretraining_limits=self.ignore_pretraining_limits, @@ -264,7 +279,12 @@ class TabPFNRegressorModel(TabPFNModelBase): # Meta parameters for this class type: ClassVar[str] = "TabPFNRegressorModel" - mod_class: ClassVar[type] = TabPFNRegressor + + @classmethod + def _get_estimator_class(cls) -> type: + from tabpfn import TabPFNRegressor + + return TabPFNRegressor @models.register("TabPFNClassifierModel") @@ -273,4 +293,9 @@ class TabPFNClassifierModel(TabPFNModelBase): # Meta parameters for this class type: ClassVar[str] = "TabPFNClassifierModel" - mod_class: ClassVar[type] = TabPFNClassifier + + @classmethod + def _get_estimator_class(cls) -> type: + from tabpfn import TabPFNClassifier + + return TabPFNClassifier diff --git a/openadmet/models/architecture/xgboost.py b/openadmet/models/architecture/xgboost.py index d45617c3..5ebaf296 100644 --- a/openadmet/models/architecture/xgboost.py +++ b/openadmet/models/architecture/xgboost.py @@ -5,7 +5,6 @@ import numpy as np from loguru import logger from pydantic import ConfigDict -from xgboost import XGBClassifier, XGBRegressor from openadmet.models.architecture.model_base import PickleableModelBase, models @@ -18,12 +17,16 @@ class XGBoostModelBase(PickleableModelBase): # Meta-parameters for this class type: ClassVar[str] - mod_class: ClassVar[type] + + @classmethod + def _get_estimator_class(cls) -> type: + """Return the XGBoost estimator class (deferred import).""" + raise NotImplementedError def build(self): """Prepare the model.""" if not self.estimator: - self.estimator = self.mod_class(**self.model_dump()) + self.estimator = self._get_estimator_class()(**self.model_dump()) else: logger.warning("Model already exists, skipping build") @@ -83,7 +86,12 @@ class XGBRegressorModel(XGBoostModelBase): """ type: ClassVar[str] = "XGBRegressorModel" - mod_class: ClassVar[type] = XGBRegressor + + @classmethod + def _get_estimator_class(cls) -> type: + from xgboost import XGBRegressor + + return XGBRegressor @models.register("XGBClassifierModel") @@ -104,7 +112,12 @@ class XGBClassifierModel(XGBoostModelBase): """ type: ClassVar[str] = "XGBoostClaXGBClassifierModelssifierModel" - mod_class: ClassVar[type] = XGBClassifier + + @classmethod + def _get_estimator_class(cls) -> type: + from xgboost import XGBClassifier + + return XGBClassifier def predict_proba(self, X: np.ndarray) -> np.ndarray: """ diff --git a/openadmet/models/eval/cross_validation.py b/openadmet/models/eval/cross_validation.py index b20068c6..60239328 100644 --- a/openadmet/models/eval/cross_validation.py +++ b/openadmet/models/eval/cross_validation.py @@ -20,8 +20,6 @@ from openadmet.models.eval.eval_base import EvalBase, evaluators, get_t_true_and_t_pred from openadmet.models.eval.regression import ( RegressionPlots, - nan_omit_ktau, - nan_omit_spearmanr, pct_within_1_log_unit, relative_absolute_error, ) @@ -32,11 +30,21 @@ def wrap_ktau(y_true, y_pred): """Wrap ktau nan omission.""" + from functools import partial + + from scipy.stats import kendalltau + + nan_omit_ktau = partial(kendalltau, nan_policy="omit") return nan_omit_ktau(y_true, y_pred).statistic def wrap_spearmanr(y_true, y_pred): """Wrap spearmanR nan omission.""" + from functools import partial + + from scipy.stats import spearmanr + + nan_omit_spearmanr = partial(spearmanr, nan_policy="omit") return nan_omit_spearmanr(y_true, y_pred).correlation diff --git a/openadmet/models/eval/eval_base.py b/openadmet/models/eval/eval_base.py index 42ad7304..82aa90bc 100644 --- a/openadmet/models/eval/eval_base.py +++ b/openadmet/models/eval/eval_base.py @@ -3,11 +3,10 @@ from abc import abstractmethod from typing import Callable, ClassVar -from loguru import logger import numpy as np from class_registry import ClassRegistry, RegistryKeyError +from loguru import logger from pydantic import BaseModel -from scipy.stats import bootstrap evaluators = ClassRegistry(unique=True) @@ -32,6 +31,9 @@ def get_eval_class(eval_type): If the evaluation type is not found in the registry. """ + from openadmet.models._registry_loader import load_group + + load_group("evaluators") try: eval_class = evaluators.get_class(eval_type) except RegistryKeyError: @@ -237,6 +239,8 @@ def stat_and_bootstrap( """ # calculate the metric and confidence intervals + from scipy.stats import bootstrap + if is_scipy_statistic: metric = statistic(y_true, y_pred).statistic conf_interval = bootstrap( diff --git a/openadmet/models/eval/regression.py b/openadmet/models/eval/regression.py index 58057834..f9a83f31 100644 --- a/openadmet/models/eval/regression.py +++ b/openadmet/models/eval/regression.py @@ -1,16 +1,10 @@ """Regression metrics and plots for model evaluation.""" import json -from functools import partial import numpy as np import pandas as pd -import seaborn as sns -import wandb -from matplotlib import pyplot as plt from pydantic import Field -from scipy.stats import kendalltau, spearmanr -from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score from openadmet.models.eval.eval_base import ( EvalBase, @@ -19,10 +13,6 @@ ) from openadmet.models.eval.utils import _make_stat_caption, _make_stat_dict -# create partial functions for the scipy stats -nan_omit_ktau = partial(kendalltau, nan_policy="omit") -nan_omit_spearmanr = partial(spearmanr, nan_policy="omit") - def relative_absolute_error(y_true, y_pred): """ @@ -104,19 +94,38 @@ class RegressionMetrics(EvalBase): ) _evaluated: bool = False - _metrics: dict = { - "mse": (mean_squared_error, False, "MSE"), - "mae": (mean_absolute_error, False, "MAE"), - "r2": (r2_score, False, "$R^2$"), - "ktau": (nan_omit_ktau, True, "Kendall's $\\tau$"), - "spearmanr": (nan_omit_spearmanr, True, "Spearman's $\\rho$"), - "rae": (relative_absolute_error, False, "RAE"), - } + @classmethod + def _base_metrics(cls) -> dict: + """ + Build the base metrics dictionary with deferred 3rd-party imports. + + Returns + ------- + dict + Mapping of metric key to ``(callable, is_scipy_statistic, display_label)`` + tuples for MSE, MAE, R², Kendall's τ, Spearman's ρ, and RAE. + + """ + from functools import partial + + from scipy.stats import kendalltau, spearmanr + from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score + + nan_omit_ktau = partial(kendalltau, nan_policy="omit") + nan_omit_spearmanr = partial(spearmanr, nan_policy="omit") + return { + "mse": (mean_squared_error, False, "MSE"), + "mae": (mean_absolute_error, False, "MAE"), + "r2": (r2_score, False, "$R^2$"), + "ktau": (nan_omit_ktau, True, "Kendall's $\\tau$"), + "spearmanr": (nan_omit_spearmanr, True, "Spearman's $\\rho$"), + "rae": (relative_absolute_error, False, "RAE"), + } @property def active_metrics(self): """Return metrics applicable to the current target scale.""" - metrics = dict(self._metrics) + metrics = self._base_metrics() if self.pXC50: metrics["pct_within_1_log"] = ( pct_within_1_log_unit, @@ -205,6 +214,8 @@ def evaluate( } if self.use_wandb: + import wandb + for t_label in target_labels: # make a table for the metrics table = wandb.Table( @@ -294,6 +305,8 @@ def write_report(self, output_dir): # also log the json to wandb if self.use_wandb: + import wandb + artifact = wandb.Artifact(name="metrics_json", type="metric_json") # Add a file to the artifact artifact.add_file(json_path) @@ -351,7 +364,7 @@ def get_stat_dict(self, t_label): data=self.data, task_name=t_label, metric_names=self.metric_names, - metrics=self._metrics, + metrics=self._base_metrics(), confidence_level=self.bootstrap_confidence_level, cv=False, ) @@ -574,6 +587,7 @@ def regplot( else: max_ax = max_val # set the limits to be the same for both axes + import seaborn as sns g = sns.jointplot( x=np.ravel(y_true), @@ -718,6 +732,8 @@ def ciplot(stat_dict={}): } n_metrics = len(metrics) + from matplotlib import pyplot as plt + fig, axes = plt.subplots(1, n_metrics, figsize=(8, n_metrics), sharex=False) if n_metrics == 1: @@ -783,4 +799,6 @@ def write_report(self, output_dir): plot_path = output_dir / f"{plot_tag}.png" plot.savefig(plot_path, dpi=self.dpi) if self.use_wandb: + import wandb + wandb.log({plot_tag: wandb.Image(str(plot_path))}) diff --git a/openadmet/models/features/chemprop.py b/openadmet/models/features/chemprop.py index cb26d859..1866b0c6 100644 --- a/openadmet/models/features/chemprop.py +++ b/openadmet/models/features/chemprop.py @@ -1,32 +1,20 @@ """ChemProp featurizer implementation.""" +from __future__ import annotations + from collections.abc import Iterable from typing import Any, Union + import numpy as np import pandas as pd -from chemprop.data import ( - MoleculeDatapoint, - MoleculeDataset, - MulticomponentDataset, - ReactionDataset, -) -from chemprop.data.collate import collate_batch, collate_multicomponent -from chemprop.data.samplers import ClassBalanceSampler, SeededSampler -from sklearn.preprocessing import StandardScaler -from torch.utils.data import DataLoader - -from openadmet.models.features.chemprop import ( - MoleculeDataset, - MulticomponentDataset, - ReactionDataset, -) + from openadmet.models.features.feature_base import DeepLearningFeaturizer, featurizers # we vendor this from chemprop so that we can pass custom samplers # taken directly from https://github.com/chemprop/chemprop/blob/main/chemprop/data/dataloader.py def _vendor_build_dataloader( - dataset: MoleculeDataset | ReactionDataset | MulticomponentDataset, + dataset, batch_size: int = 64, num_workers: int = 0, class_balance: bool = False, @@ -68,6 +56,11 @@ def _vendor_build_dataloader( A PyTorch DataLoader for the given MoleculeDataset, ReactionDataset, or MulticomponentDataset. """ + from chemprop.data import MulticomponentDataset + from chemprop.data.collate import collate_batch, collate_multicomponent + from chemprop.data.samplers import ClassBalanceSampler, SeededSampler + from torch.utils.data import DataLoader + if sampler is not None: if class_balance: sampler = ClassBalanceSampler(dataset.Y, seed, shuffle) @@ -131,7 +124,7 @@ def featurize( DataLoader, np.ndarray, StandardScaler, - Union[MoleculeDataset, ReactionDataset, MulticomponentDataset], + MoleculeDataset | ReactionDataset | MulticomponentDataset, ]: """ Featurize a list of SMILES strings. @@ -153,6 +146,8 @@ def featurize( - Union[MoleculeDataset, ReactionDataset, MulticomponentDataset]: PyTorch Dataset containing the features and targets. """ + from chemprop.data import MoleculeDatapoint, MoleculeDataset + if y is not None: # if a pandas dataframe or series if isinstance(y, pd.DataFrame) or isinstance(y, pd.Series): @@ -222,6 +217,6 @@ def dataset_to_dataloader( **kwargs, ) - def make_new(self) -> "ChemPropFeaturizer": + def make_new(self) -> ChemPropFeaturizer: """Copy parameters to a new ChemPropFeaturizer instance.""" return self.__class__(**self.dict()) diff --git a/openadmet/models/features/feature_base.py b/openadmet/models/features/feature_base.py index 5d333452..e0464a75 100644 --- a/openadmet/models/features/feature_base.py +++ b/openadmet/models/features/feature_base.py @@ -1,20 +1,46 @@ """Base classes and utilities for molecular featurizers.""" +from __future__ import annotations + from abc import ABC, abstractmethod from collections.abc import Iterable +from typing import TYPE_CHECKING, Any import numpy as np from class_registry import ClassRegistry, RegistryKeyError -from molfeat.trans import MoleculeTransformer from pydantic import BaseModel -from sklearn.preprocessing import StandardScaler -from torch.utils.data import DataLoader, Dataset + +if TYPE_CHECKING: + from molfeat.trans import MoleculeTransformer + from sklearn.preprocessing import StandardScaler + from torch.utils.data import DataLoader, Dataset featurizers = ClassRegistry(unique=True) def get_featurizer_class(feat_type): - """Retrieve a featurizer class from the registry by type.""" + """ + Retrieve a featurizer class from the registry by type. + + Parameters + ---------- + feat_type : str + The registered key for the featurizer (e.g., ``"MolfeatFeaturizer"``). + + Returns + ------- + type + The featurizer class corresponding to the given type. + + Raises + ------ + ValueError + If ``feat_type`` is not found in the featurizer registry. + + """ + from openadmet.models._registry_loader import load_group + + load_group("featurizers") try: feat_class = featurizers.get_class(feat_type) except RegistryKeyError: @@ -107,7 +133,7 @@ class MolfeatFeaturizer(FeaturizerBase): """ - _transformer: MoleculeTransformer = None + _transformer: Any = None def __init__(self, *args, **kwargs): """ diff --git a/openadmet/models/registries.py b/openadmet/models/registries.py index 66b11333..e4983312 100644 --- a/openadmet/models/registries.py +++ b/openadmet/models/registries.py @@ -1,53 +1,19 @@ -"""Registry imports for all model components.""" - -# IMPORTANT: order matters here, make sure to import all the registered classes first -# before importing the registry classes -# there must be a better way to do this, but for now this works - -# active learning -from openadmet.models.active_learning.committee import * # noqa: F401 F403 -from openadmet.models.active_learning.ensemble_base import ensemblers # noqa: F401 F403 - -# models -from openadmet.models.architecture.catboost import * # noqa: F401 F403 -from openadmet.models.architecture.chemprop import * # noqa: F401 F403 # noqa: F401 F403 -from openadmet.models.architecture.dummy import * # noqa: F401 F403 -from openadmet.models.architecture.lgbm import * # noqa: F401 F403 # noqa: F401 F403 -from openadmet.models.architecture.nepare import * # noqa: F401 F403 -from openadmet.models.architecture.rf import * # noqa: F401 F403 -from openadmet.models.architecture.svm import * # noqa: F401 F403 -from openadmet.models.architecture.tabpfn import * # noqa: F401 F403 -from openadmet.models.architecture.xgboost import * # noqa: F401 F403 -from openadmet.models.architecture.model_base import models # noqa: F401 F403 - -# evaluators -from openadmet.models.eval.classification import * # noqa: F401 F403 -from openadmet.models.eval.cross_validation import * # noqa: F401 F403 -from openadmet.models.eval.regression import * # noqa: F401 F403 -from openadmet.models.eval.uncertainty import * # noqa: F401 F403 -from openadmet.models.eval.eval_base import evaluators # noqa: F401 F403 - -# featurizers -from openadmet.models.features.chemprop import * # noqa: F401 F403 -from openadmet.models.features.combine import * # noqa: F401 F403 -from openadmet.models.features.molfeat_fingerprint import * # noqa: F401 F403 -from openadmet.models.features.molfeat_properties import * # noqa: F401 F403 -from openadmet.models.features.feature_base import featurizers # noqa: F401 F403 - -# util -from openadmet.models.log import logger # noqa: F401 F403 - -# splitters -from openadmet.models.split.scaffold import * # noqa: F401 F403 -from openadmet.models.split.sklearn import * # noqa: F401 F403 -from openadmet.models.split.split_base import splitters # noqa: F401 F403 -from openadmet.models.split.cluster import * # noqa: F401 F403 - -# trainers -from openadmet.models.trainer.lightning import * # noqa: F401 F403 -from openadmet.models.trainer.sklearn import * # noqa: F401 F403 -from openadmet.models.trainer.trainer_base import trainers # noqa: F401 F403 - -# transforms -from openadmet.models.transforms.impute import * # noqa: F401 F403 -from openadmet.models.transforms.transform_base import * # noqa: F401 F403 +""" +Registry objects and lazy loader for all model components. + +Importing this module is intentionally cheap. Concrete classes are registered +only when the relevant group is first accessed. To eagerly load everything +(e.g. for CLI tools or the Anvil workflow), call ``load_all()``: + + from openadmet.models.registries import load_all + load_all() +""" + +from openadmet.models._registry_loader import load_all # noqa: F401 +from openadmet.models.active_learning.ensemble_base import ensemblers # noqa: F401 +from openadmet.models.architecture.model_base import models # noqa: F401 +from openadmet.models.eval.eval_base import evaluators # noqa: F401 +from openadmet.models.features.feature_base import featurizers # noqa: F401 +from openadmet.models.log import logger # noqa: F401 +from openadmet.models.split.split_base import splitters # noqa: F401 +from openadmet.models.trainer.trainer_base import trainers # noqa: F401 diff --git a/openadmet/models/split/cluster.py b/openadmet/models/split/cluster.py index 2a39dcd6..55daa3c9 100644 --- a/openadmet/models/split/cluster.py +++ b/openadmet/models/split/cluster.py @@ -1,22 +1,13 @@ """Cluster-based data splitting implementations.""" import logging -from pydantic import BaseModel, field_validator, model_validator from typing import Literal -from sklearn.model_selection import GroupShuffleSplit -from sklearn.cluster import KMeans -from molfeat.trans import MoleculeTransformer -from molfeat.trans.fp import FPVecTransformer -import datamol as dm + import numpy as np import pandas as pd +from pydantic import BaseModel, field_validator, model_validator + from openadmet.models.split.split_base import SplitterBase, splitters -from useful_rdkit_utils import ( - get_butina_clusters, - get_bemis_murcko_clusters, - get_scaffold, - smi2numpy_fp, -) @splitters.register("ClusterSplitter") @@ -78,13 +69,22 @@ def split(self, X, y, num_iters=1000): """ # Get clusters based on the selected method if self.method == "butina": + from useful_rdkit_utils import get_butina_clusters + clusters = get_butina_clusters(X, cutoff=self.butina_cutoff) elif self.method == "bemis-murcko": + from useful_rdkit_utils import get_bemis_murcko_clusters + clusters = get_bemis_murcko_clusters(X) elif self.method == "kmeans": logging.warning( "KMeans clustering is NOT DETERMINISTIC with random seed across platforms." ) + import datamol as dm + from molfeat.trans import MoleculeTransformer + from molfeat.trans.fp import FPVecTransformer + from sklearn.cluster import KMeans + km = KMeans( n_clusters=self.k_clusters, n_init=1, diff --git a/openadmet/models/split/scaffold.py b/openadmet/models/split/scaffold.py index d5dda0e1..07349986 100644 --- a/openadmet/models/split/scaffold.py +++ b/openadmet/models/split/scaffold.py @@ -1,10 +1,10 @@ """Scaffold-based data splitting implementations.""" import logging -from sklearn.model_selection import train_test_split -from splito import MaxDissimilaritySplit, PerimeterSplit, ScaffoldSplit + import numpy as np import pandas as pd + from openadmet.models.split.split_base import SplitterBase, splitters @@ -40,6 +40,8 @@ def split(self, X, y): """ logging.warning("ScaffoldSplitter is not available for cross-validation.") + from splito import ScaffoldSplit + # No test set requested if self.test_size == 0: # Split into train and val @@ -87,6 +89,8 @@ def split(self, X, y): ) # Split train+val into train and val sets + from sklearn.model_selection import train_test_split + X_train, X_val, y_train, y_val = train_test_split( safe_index(X, train_val_idx), safe_index(y, train_val_idx), @@ -135,6 +139,8 @@ def split(self, X, y): """ logging.warning("PerimeterSplitter is not available for cross-validation.") + from splito import PerimeterSplit + # No test set requested if self.test_size == 0: # Split into train and val @@ -181,6 +187,8 @@ def split(self, X, y): ) # Split train+val into train and val sets using sklearn + from sklearn.model_selection import train_test_split + X_train, X_val, y_train, y_val = train_test_split( safe_index(X, train_val_idx), safe_index(y, train_val_idx), @@ -231,6 +239,8 @@ def split(self, X, y): logging.warning( "MaxDissimilaritySplitter is not available for cross-validation." ) + from splito import MaxDissimilaritySplit + # No test set requested if self.test_size == 0: # Split into train and val @@ -277,6 +287,8 @@ def split(self, X, y): ) # Split train+val into train and val sets using sklearn + from sklearn.model_selection import train_test_split + X_train, X_val, y_train, y_val = train_test_split( safe_index(X, train_val_idx), safe_index(y, train_val_idx), diff --git a/openadmet/models/split/split_base.py b/openadmet/models/split/split_base.py index 7523cfa2..832e5d28 100644 --- a/openadmet/models/split/split_base.py +++ b/openadmet/models/split/split_base.py @@ -25,6 +25,9 @@ def get_splitter_class(feat_type): The splitter class corresponding to the given type. """ + from openadmet.models._registry_loader import load_group + + load_group("splitters") try: split_class = splitters.get_class(feat_type) except RegistryKeyError: diff --git a/openadmet/models/tests/unit/models/test_base.py b/openadmet/models/tests/unit/models/test_base.py index ed466fef..cc7be2ec 100644 --- a/openadmet/models/tests/unit/models/test_base.py +++ b/openadmet/models/tests/unit/models/test_base.py @@ -3,6 +3,7 @@ import pytest +import openadmet.models.architecture.lightning_model_base as lightning_model_base import openadmet.models.architecture.model_base as model_base from openadmet.models.architecture.model_base import ( LightningModelBase, @@ -38,7 +39,7 @@ def test_save_load_torch_model(mclass, tmp_path): def test_lightning_model_load_uses_weights_only(monkeypatch, tmp_path): state_dict = {"layer.weight": "dummy"} torch_load = Mock(return_value=state_dict) - monkeypatch.setattr(model_base.torch, "load", torch_load) + monkeypatch.setattr(lightning_model_base.torch, "load", torch_load) estimator = Mock() model = SimpleNamespace(estimator=estimator) diff --git a/openadmet/models/trainer/lightning.py b/openadmet/models/trainer/lightning.py index 2f89cac2..fef0197a 100644 --- a/openadmet/models/trainer/lightning.py +++ b/openadmet/models/trainer/lightning.py @@ -3,10 +3,6 @@ from pathlib import Path # it is used in the main therefore i do not remove it from typing import Any -import torch -from lightning import pytorch as pl -from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint -from lightning.pytorch.loggers import CSVLogger, WandbLogger from loguru import logger from openadmet.models.drivers import DriverType from openadmet.models.trainer.trainer_base import TrainerBase, trainers @@ -106,6 +102,10 @@ def build(self, no_val: bool = False): # Initialize logging container self._logger = [] + from lightning import pytorch as pl + from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint + from lightning.pytorch.loggers import CSVLogger, WandbLogger + # Initialize the callbacks dict self._callbacks = {} @@ -191,6 +191,8 @@ def train(self, train_dataloader, val_dataloader): # Indicate that the model is being trained logger.debug(f"Training model {self.model.estimator}") + import torch + # Fit model self._trainer.fit(self.model.estimator, train_dataloader, val_dataloader) diff --git a/openadmet/models/trainer/trainer_base.py b/openadmet/models/trainer/trainer_base.py index 8632ab12..896076ce 100644 --- a/openadmet/models/trainer/trainer_base.py +++ b/openadmet/models/trainer/trainer_base.py @@ -26,6 +26,9 @@ def get_trainer_class(model_type): The trainer class corresponding to the given type. """ + from openadmet.models._registry_loader import load_group + + load_group("trainers") try: feat_class = trainers.get_class(model_type) except RegistryKeyError: diff --git a/openadmet/models/transforms/transform_base.py b/openadmet/models/transforms/transform_base.py index 10628abb..062bbf24 100644 --- a/openadmet/models/transforms/transform_base.py +++ b/openadmet/models/transforms/transform_base.py @@ -27,7 +27,15 @@ def get_transform_class(trans_type): TransformBase The transform class corresponding to the given type. + Raises + ------ + ValueError + If ``trans_type`` is not found in the transform registry. + """ + from openadmet.models._registry_loader import load_group + + load_group("transforms") try: transf_class = transforms.get_class(trans_type) except RegistryKeyError: