Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 91 additions & 0 deletions openadmet/models/_registry_loader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
"""
Lazy registry loader — zero heavy imports at module level.

Call ``load_group(name)`` to load a specific registry group, or ``load_all()``
to populate every registry at once. Both are idempotent.
"""

import importlib

_MODELS = [
"openadmet.models.architecture.catboost",
"openadmet.models.architecture.chemprop",
"openadmet.models.architecture.dummy",
"openadmet.models.architecture.lgbm",
"openadmet.models.architecture.nepare",
"openadmet.models.architecture.rf",
"openadmet.models.architecture.svm",
"openadmet.models.architecture.tabpfn",
"openadmet.models.architecture.xgboost",
]

_EVALUATORS = [
"openadmet.models.eval.classification",
"openadmet.models.eval.cross_validation",
"openadmet.models.eval.regression",
"openadmet.models.eval.uncertainty",
]

_FEATURIZERS = [
"openadmet.models.features.chemprop",
"openadmet.models.features.combine",
"openadmet.models.features.molfeat_fingerprint",
"openadmet.models.features.molfeat_properties",
]

_SPLITTERS = [
"openadmet.models.split.scaffold",
"openadmet.models.split.sklearn",
"openadmet.models.split.cluster",
]

_TRAINERS = [
"openadmet.models.trainer.lightning",
"openadmet.models.trainer.sklearn",
]

_TRANSFORMS = [
"openadmet.models.transforms.impute",
"openadmet.models.transforms.transform_base",
]

_ACTIVE_LEARNING = [
"openadmet.models.active_learning.committee",
]

_GROUPS: dict[str, list[str]] = {
"models": _MODELS,
"evaluators": _EVALUATORS,
"featurizers": _FEATURIZERS,
"splitters": _SPLITTERS,
"trainers": _TRAINERS,
"transforms": _TRANSFORMS,
"active_learning": _ACTIVE_LEARNING,
}

_loaded: set[str] = set()


def load_group(name: str) -> None:
"""
Import all modules in the named registry group (idempotent).

Parameters
----------
name : str
Registry group key. Must be one of: ``"models"``, ``"evaluators"``,
``"featurizers"``, ``"splitters"``, ``"trainers"``, ``"transforms"``,
``"active_learning"``.

"""
if name in _loaded:
return
for mod in _GROUPS[name]:
importlib.import_module(mod)
_loaded.add(name)


def load_all() -> None:
"""Import all registry groups, making every registered class available (idempotent)."""
for name in _GROUPS:
load_group(name)
23 changes: 22 additions & 1 deletion openadmet/models/active_learning/ensemble_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,28 @@


def get_ensemble_class(ensemble_type):
"""Get the ensemble class."""
"""
Get the ensemble class from the registry by type.

Parameters
----------
ensemble_type : str
The registered key for the ensemble (e.g., ``"QueryByCommittee"``).

Returns
-------
type
The ensemble class corresponding to the given type.

Raises
------
ValueError
If ``ensemble_type`` is not found in the ensemble registry.

"""
from openadmet.models._registry_loader import load_group

load_group("active_learning")
try:
ensemble_class = ensemblers.get_class(ensemble_type)
except RegistryKeyError:
Expand Down
2 changes: 1 addition & 1 deletion openadmet/models/anvil/specification.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from openadmet.models.drivers import DriverType
from openadmet.models.eval.eval_base import get_eval_class
from openadmet.models.features.feature_base import get_featurizer_class
from openadmet.models.registries import * # noqa: F401, F403
from openadmet.models.registries import load_all # noqa: F401
from openadmet.models.split.split_base import get_splitter_class
from openadmet.models.trainer.trainer_base import get_trainer_class
from openadmet.models.transforms.transform_base import (
Expand Down
2 changes: 1 addition & 1 deletion openadmet/models/anvil/workflow_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from openadmet.models.architecture.model_base import ModelBase
from openadmet.models.eval.eval_base import EvalBase
from openadmet.models.features.feature_base import FeaturizerBase
from openadmet.models.registries import * # noqa: F401, F403
from openadmet.models.registries import load_all # noqa: F401
from openadmet.models.split.split_base import SplitterBase
from openadmet.models.trainer.trainer_base import TrainerBase
from openadmet.models.transforms.transform_base import (
Expand Down
26 changes: 18 additions & 8 deletions openadmet/models/architecture/catboost.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from typing import ClassVar

import numpy as np
from catboost import CatBoostClassifier, CatBoostRegressor
from loguru import logger
from pydantic import ConfigDict

Expand All @@ -18,9 +17,6 @@ class CatBoostModelBase(PickleableModelBase):
----------
type : ClassVar[str]
The type of the model.
mod_class : ClassVar[type]
To specify the CatBoost model class (e.g., CatBoostRegressor or CatBoost
Classifier)

"""

Expand All @@ -29,12 +25,16 @@ class CatBoostModelBase(PickleableModelBase):

# Meta parameters for this class
type: ClassVar[str]
mod_class: ClassVar[type]

@classmethod
def _get_estimator_class(cls) -> type:
"""Return the CatBoost estimator class (deferred import)."""
raise NotImplementedError

def build(self):
"""Prepare the model."""
if not self.estimator:
self.estimator = self.mod_class(**self.model_dump())
self.estimator = self._get_estimator_class()(**self.model_dump())
else:
logger.warning("Model already exists, skipping build")

Expand Down Expand Up @@ -95,7 +95,12 @@ class CatBoostRegressorModel(CatBoostModelBase):

# Meta parameters for this class
type: ClassVar[str] = "CatBoostRegressorModel"
mod_class: ClassVar[type] = CatBoostRegressor

@classmethod
def _get_estimator_class(cls) -> type:
from catboost import CatBoostRegressor

return CatBoostRegressor


@models.register("CatBoostClassifierModel")
Expand All @@ -109,7 +114,12 @@ class CatBoostClassifierModel(CatBoostModelBase):

# Meta parameters for this class
type: ClassVar[str] = "CatBoostClassifierModel"
mod_class: ClassVar[type] = CatBoostClassifier

@classmethod
def _get_estimator_class(cls) -> type:
from catboost import CatBoostClassifier

return CatBoostClassifier

def predict_proba(self, X: np.ndarray) -> np.ndarray:
"""
Expand Down
13 changes: 6 additions & 7 deletions openadmet/models/architecture/chemprop.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,9 @@
from loguru import logger
from pydantic import PrivateAttr, field_validator, model_validator

from openadmet.models.architecture.model_base import LightningModelBase
from openadmet.models.architecture.lightning_model_base import LightningModelBase
from openadmet.models.architecture.model_base import models as model_registry

_METRIC_TO_LOSS = {
"mse": nn.metrics.MSE(),
"mae": nn.metrics.MAE(),
"rmse": nn.metrics.RMSE(),
}


def configure_optimizers(self):
"""
Expand Down Expand Up @@ -445,6 +439,11 @@ def build(self, scaler=None):

"""
if not self.estimator:
_METRIC_TO_LOSS = {
"mse": nn.metrics.MSE(),
"mae": nn.metrics.MAE(),
"rmse": nn.metrics.RMSE(),
}
metric_list = [_METRIC_TO_LOSS[metric] for metric in self.metric_list]
if self.from_foundation:
if self.from_foundation == "chemeleon":
Expand Down
23 changes: 18 additions & 5 deletions openadmet/models/architecture/dummy.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

import numpy as np
from loguru import logger
from sklearn.dummy import DummyClassifier, DummyRegressor

from openadmet.models.architecture.model_base import PickleableModelBase, models

Expand All @@ -14,12 +13,16 @@ class DummyModelBase(PickleableModelBase):

# Meta parameters for this class
type: ClassVar[str]
mod_class: ClassVar[type]

@classmethod
def _get_estimator_class(cls) -> type:
"""Return the sklearn dummy estimator class (deferred import)."""
raise NotImplementedError

def build(self):
"""Prepare the model."""
if not self.estimator:
self.estimator = self.mod_class(**self.model_dump())
self.estimator = self._get_estimator_class()(**self.model_dump())
else:
logger.warning("Model already exists, skipping build")

Expand Down Expand Up @@ -77,7 +80,12 @@ class DummyRegressorModel(DummyModelBase):

# Meta parameters for this class
type: ClassVar[str] = "DummyRegressorModel"
mod_class: ClassVar[type] = DummyRegressor

@classmethod
def _get_estimator_class(cls) -> type:
from sklearn.dummy import DummyRegressor

return DummyRegressor

# DummyRegressor parameters
strategy: str = "mean" # Default strategy for dummy models
Expand All @@ -96,7 +104,12 @@ class DummyClassifierModel(DummyModelBase):

# Meta parameters for this class
type: ClassVar[str] = "DummyClassifierModel"
mod_class: ClassVar[type] = DummyClassifier

@classmethod
def _get_estimator_class(cls) -> type:
from sklearn.dummy import DummyClassifier

return DummyClassifier

# DummyClassifier parameters
strategy: str = "most_frequent" # Default strategy for dummy models
Expand Down
23 changes: 18 additions & 5 deletions openadmet/models/architecture/lgbm.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

from typing import ClassVar

import lightgbm as lgb
import numpy as np
from loguru import logger

Expand All @@ -14,7 +13,11 @@ class LGBMModelBase(PickleableModelBase):

# Meta parameters for this class
type: ClassVar[str]
mod_class: ClassVar[type]

@classmethod
def _get_estimator_class(cls) -> type:
"""Return the LightGBM estimator class (deferred import)."""
raise NotImplementedError

# LGBM parameters
boosting_type: str = "gbdt"
Expand All @@ -41,7 +44,7 @@ class LGBMModelBase(PickleableModelBase):
def build(self):
"""Prepare the model."""
if not self.estimator:
self.estimator = self.mod_class(**self.model_dump())
self.estimator = self._get_estimator_class()(**self.model_dump())
else:
logger.warning("Model already exists, skipping build")

Expand Down Expand Up @@ -88,7 +91,12 @@ class LGBMRegressorModel(LGBMModelBase):

# Meta parameters for this class
type: ClassVar[str] = "LGBMRegressorModel"
mod_class: ClassVar[type] = lgb.LGBMRegressor

@classmethod
def _get_estimator_class(cls) -> type:
from lightgbm import LGBMRegressor

return LGBMRegressor


@models.register("LGBMClassifierModel")
Expand All @@ -97,7 +105,12 @@ class LGBMClassifierModel(LGBMModelBase):

# Meta parameters for this class
type: ClassVar[str] = "LGBMClassifierModel"
mod_class: ClassVar[type] = lgb.LGBMClassifier

@classmethod
def _get_estimator_class(cls) -> type:
from lightgbm import LGBMClassifier

return LGBMClassifier

def predict_proba(self, X: np.ndarray) -> np.ndarray:
"""Predict using the model."""
Expand Down
Loading
Loading