diff --git a/.github/copilot-instructions.md b/.github/copilot-instructions.md index 6c194a4b..96c2d0fd 100644 --- a/.github/copilot-instructions.md +++ b/.github/copilot-instructions.md @@ -59,3 +59,13 @@ The CLI entry point is `openadmet` (`openadmet/models/cli/cli.py`), with subcomm - Ruff + Black formatting; isort with Black-compatible profile - Sentence case in comments and print statements; acronyms (MPNN, MVE, ADMET, FFN) stay capitalized - Do not number steps in comments; do not end comments with a period + +## Unit Testing & Refactoring Rules + +When writing or refactoring tests, you must strictly adhere to the following guidelines to ensure tests are mathematically sound, robust, and non-tautological: + +* **Avoid Tautological Mocks:** Do not mock the system under test. Mock heavy I/O, external dependencies, or heavy data loading, but ensure the core logic of the target function is actually executed. Use lightweight synthetic datasets (e.g., small tensors or pandas DataFrames) instead of bypassing the execution entirely. +* **Standard Mocking:** Never write custom nested dummy classes or custom mock fixtures. Always use the standard `pytest-mock` library (the `mocker` fixture) to patch objects and verify calls. +* **No Lazy Assertions:** Never use `assert True`. Assert actual state changes, specific dictionary keys, object types (e.g., `isinstance(obj, matplotlib.figure.Figure)`), or verify file creation via the `tmp_path` fixture. +* **Robust ML Data Testing:** When testing data splitters or clustering algorithms, you must explicitly assert that the resulting train/validation/test sets are mutually exclusive (e.g., checking that set intersections of indices or arrays are empty). Ensure synthetic testing data has enough variance (e.g., diverse SMILES scaffolds) to meaningfully test the algorithm. +* **Safe Floating-Point Math:** Never use strict equality (`==`) to compare floating-point numbers. Always use `pytest.approx()` or `numpy.testing.assert_almost_equal()` to prevent cross-platform precision failures. Assert the actual math (e.g., UQ or metric calculations), not just the existence of the output. diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 1074358f..e80f57d4 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -17,7 +17,7 @@ repos: - id: black files: ^openadmet_models - repo: https://github.com/PyCQA/isort - rev: 8.0.0 + rev: 8.0.1 hooks: - id: isort files: ^openadmet_models @@ -37,7 +37,7 @@ repos: - --py39-plus - repo: https://github.com/astral-sh/ruff-pre-commit # Ruff version. - rev: v0.15.2 + rev: v0.15.4 hooks: # Run the linter. - id: ruff-check diff --git a/devtools/conda-envs/openadmet-models-gpu.yaml b/devtools/conda-envs/openadmet-models-gpu.yaml index f67fff20..287d3539 100644 --- a/devtools/conda-envs/openadmet-models-gpu.yaml +++ b/devtools/conda-envs/openadmet-models-gpu.yaml @@ -33,6 +33,7 @@ dependencies: - pytorch_scatter - pytorch_sparse - pytest + - pytest-mock - pytest-cov - pytest-xdist - rdkit diff --git a/devtools/conda-envs/openadmet-models.yaml b/devtools/conda-envs/openadmet-models.yaml index 776d7987..bebe702c 100644 --- a/devtools/conda-envs/openadmet-models.yaml +++ b/devtools/conda-envs/openadmet-models.yaml @@ -33,6 +33,7 @@ dependencies: - pytorch_scatter - pytorch_sparse - pytest + - pytest-mock - pytest-cov - pytest-xdist - rdkit diff --git a/openadmet/models/active_learning/committee.py b/openadmet/models/active_learning/committee.py index 4f8bf7fd..5402a70e 100644 --- a/openadmet/models/active_learning/committee.py +++ b/openadmet/models/active_learning/committee.py @@ -382,8 +382,8 @@ def _predict(self, X, return_std=False, **kwargs): if return_std is False: return mean - # Compute standard deviation - std = np.std(preds, axis=-1) + # Compute standard deviation, guard against zero std + std = np.maximum(np.std(preds, axis=-1), 1e-8) # Calibrate std if calibration model is available if self.calibrated: diff --git a/openadmet/models/anvil/specification.py b/openadmet/models/anvil/specification.py index 830a1a36..9c1f6fe7 100644 --- a/openadmet/models/anvil/specification.py +++ b/openadmet/models/anvil/specification.py @@ -15,8 +15,8 @@ from openadmet.models.active_learning.ensemble_base import ( get_ensemble_class, ) -from openadmet.models.drivers import DriverType from openadmet.models.architecture.model_base import get_mod_class +from openadmet.models.drivers import DriverType from openadmet.models.eval.eval_base import get_eval_class from openadmet.models.features.feature_base import get_featurizer_class from openadmet.models.registries import * # noqa: F401, F403 @@ -449,9 +449,23 @@ class SplitSpec(AnvilSection): class FeatureSpec(AnvilSection): - """Featurization specification.""" + """ + Featurization specification. + + Attributes + ---------- + section_name : ClassVar[str] + The name of the section. + type : Optional[str] + The type of featurizer to use. + params : dict + The parameters for the featurizer. + + """ section_name: ClassVar[str] = "feat" + type: str | None = None + params: dict = Field(default_factory=dict) class ModelSpec(AnvilSection): @@ -546,8 +560,8 @@ class EnsembleSpec(AnvilSection): section_name: ClassVar[str] = "ensemble" n_models: int - calibration_method: str | None = "isotonic-regression" - use_bagging: bool = True + calibration_method: str | None = None + use_bagging: bool = False param_paths: list[str] | None = None serial_paths: list[str] | None = None @@ -729,6 +743,26 @@ def to_workflow(self): # Pull driver from associated trainer to choose the correct workflow trainer_class = self.procedure.train.to_class() driver = _DRIVER_TO_CLASS[trainer_class._driver_type] + model_kwargs = { + "param_path": self.procedure.model.param_path, + "serial_path": self.procedure.model.serial_path, + "freeze_weights": self.procedure.model.freeze_weights, + } + ensemble_kwargs = ( + { + "n_models": self.procedure.ensemble.n_models, + "calibration_method": self.procedure.ensemble.calibration_method, + "param_paths": self.procedure.ensemble.param_paths, + "serial_paths": self.procedure.ensemble.serial_paths, + "use_bagging": self.procedure.ensemble.use_bagging, + } + if self.procedure.ensemble + else {} + ) + feat_kwargs = { + "type": self.procedure.feat.type, + "params": self.procedure.feat.params, + } return driver( metadata=self.metadata, @@ -744,5 +778,34 @@ def to_workflow(self): feat=self.procedure.feat.to_class(), trainer=self.procedure.train.to_class(), evals=[eval.to_class() for eval in self.report.eval], - parent_spec=self, + model_kwargs=model_kwargs, + ensemble_kwargs=ensemble_kwargs, + feat_kwargs=feat_kwargs, + ) + + def run( + self, + output_dir: PathLike = "anvil_training", + debug: bool = False, + tag: str = None, + ): + """Run the Anvil workflow from this specification.""" + workflow = self.to_workflow() + result = workflow.run(output_dir=output_dir, debug=debug, tag=tag) + + resolved_output_dir = workflow.resolved_output_dir or Path(output_dir) + resolved_output_dir.mkdir(parents=True, exist_ok=True) + provenance_spec = self.model_copy(deep=True) + if tag is not None: + provenance_spec.metadata.tag = tag + + provenance_spec.to_recipe(resolved_output_dir / "anvil_recipe.yaml") + recipe_components = resolved_output_dir / "recipe_components" + recipe_components.mkdir(parents=True, exist_ok=True) + provenance_spec.to_multi_yaml( + metadata_yaml=recipe_components / "metadata.yaml", + procedure_yaml=recipe_components / "procedure.yaml", + data_yaml=recipe_components / "data.yaml", + report_yaml=recipe_components / "eval.yaml", ) + return result diff --git a/openadmet/models/anvil/workflow.py b/openadmet/models/anvil/workflow.py index b8c50d78..5b4c5cbf 100644 --- a/openadmet/models/anvil/workflow.py +++ b/openadmet/models/anvil/workflow.py @@ -5,7 +5,7 @@ from datetime import datetime from os import PathLike from pathlib import Path -from typing import Any, ClassVar, Literal, Optional +from typing import Any import numpy as np import pandas as pd @@ -71,8 +71,8 @@ def check_no_finetuning(self): # Ensemble specified if self.ensemble: # Fine-tuning paths specified - if (self.parent_spec.procedure.ensemble.param_paths is not None) or ( - self.parent_spec.procedure.ensemble.serial_paths is not None + if (self.ensemble_kwargs.get("param_paths") is not None) or ( + self.ensemble_kwargs.get("serial_paths") is not None ): raise ValueError( "Finetuning from serialized ensemble models is not supported in this workflow." @@ -81,8 +81,8 @@ def check_no_finetuning(self): # No ensemble else: # Fine-tuning paths supplied - if (self.parent_spec.procedure.model.param_path is not None) or ( - self.parent_spec.procedure.model.serial_path is not None + if (self.model_kwargs.get("param_path") is not None) or ( + self.model_kwargs.get("serial_path") is not None ): raise ValueError( "Finetuning from serialized model is not supported in this workflow." @@ -114,19 +114,28 @@ def _train_ensemble(self, X_train_feat, y_train, output_dir, **kwargs): X_train_feat = _safe_to_numpy(X_train_feat) y_train = _safe_to_numpy(y_train) - # Bootstrap iterations - models = [] # Get bagging setting - use_bagging = self.parent_spec.procedure.ensemble.use_bagging + use_bagging = self.ensemble_kwargs.get("use_bagging") + # Get global seed + # Currently grabbing from `split`, should this be set separately? global_seed = self.split.random_state - for i in range(self.parent_spec.procedure.ensemble.n_models): + # Bootstrap iterations + models = [] + for i in range(self.ensemble_kwargs["n_models"]): # Manage bootstrap directory bootstrap_dir = output_dir / f"bootstrap_{i}" bootstrap_dir.mkdir(parents=True, exist_ok=True) + # Bootstrap data if using bagging, if not specified default False if use_bagging: + # Set seed for bootstrapping + logger.info( + f"Using incremented seed={global_seed + i} for bootstrapping" + ) + np.random.seed(global_seed + i) + # Bootstrap train data logger.info("Bootstrapping train data") bootstrap_indices = np.random.choice( @@ -140,7 +149,9 @@ def _train_ensemble(self, X_train_feat, y_train, output_dir, **kwargs): y_train_bootstrap = y_train # Build model from scratch - logger.info(f"Building model {i}") + logger.info( + f"Building model {i} using incremented seed={global_seed + i} to vary model initialization" + ) bootstrap_model = self.model.make_new() # Set seed for model @@ -148,7 +159,7 @@ def _train_ensemble(self, X_train_feat, y_train, output_dir, **kwargs): bootstrap_model.random_state = global_seed + i else: logger.warning( - f"Model {bootstrap_model} does not support random_state seeding." + f"Model {bootstrap_model} does not support random_state seeding" ) bootstrap_model.build() @@ -227,24 +238,12 @@ def run( # Create the output directory output_dir.mkdir(parents=True, exist_ok=True) + self.resolved_output_dir = output_dir # Create data subdirectory data_dir = output_dir / "data" data_dir.mkdir(parents=True, exist_ok=True) - # Write recipe to output directory - self.parent_spec.to_recipe(output_dir / "anvil_recipe.yaml") - - # Split recipe into components and save - recipe_components = Path(output_dir / "recipe_components") - recipe_components.mkdir(parents=True, exist_ok=True) - self.parent_spec.to_multi_yaml( - metadata_yaml=recipe_components / "metadata.yaml", - procedure_yaml=recipe_components / "procedure.yaml", - data_yaml=recipe_components / "data.yaml", - report_yaml=recipe_components / "eval.yaml", - ) - # Log output directory information logger.info(f"Running workflow from directory {output_dir}") @@ -339,7 +338,7 @@ def run( self.model.calibrate_uncertainty( X_val_feat, y_val, - method=self.parent_spec.procedure.ensemble.calibration_method, + method=self.ensemble_kwargs.get("calibration_method"), ) # Save @@ -462,18 +461,67 @@ def check_if_val_needed(self): return self + @model_validator(mode="after") + def check_finetuning_paths(self): + """ + Check that finetuning path pairs are consistent and exist on disk. + + Both ``param_path`` and ``serial_path`` must be provided together (or + neither). When both are provided, both paths must exist before training + begins. The same requirement applies to ``param_paths`` / ``serial_paths`` + for ensemble workflows, which must additionally be equal-length lists. + + Raises + ------ + ValueError + If exactly one of the path pair is provided, if provided paths do + not exist on disk, or if ensemble path lists have unequal length. + + """ + if not self.ensemble: + param_path = self.model_kwargs.get("param_path") + serial_path = self.model_kwargs.get("serial_path") + if (param_path is None) != (serial_path is None): + raise ValueError( + "Both param_path and serial_path must be provided together for finetuning." + ) + if param_path is not None: + if not Path(param_path).exists(): + raise ValueError(f"param_path '{param_path}' does not exist.") + if not Path(serial_path).exists(): + raise ValueError(f"serial_path '{serial_path}' does not exist.") + else: + param_paths = self.ensemble_kwargs.get("param_paths") + serial_paths = self.ensemble_kwargs.get("serial_paths") + if (param_paths is None) != (serial_paths is None): + raise ValueError( + "Both param_paths and serial_paths must be provided together for ensemble finetuning." + ) + if param_paths is not None: + if len(param_paths) != len(serial_paths): + raise ValueError( + "param_paths and serial_paths must have equal length." + ) + for p in param_paths: + if not Path(p).exists(): + raise ValueError(f"param_path '{p}' does not exist.") + for s in serial_paths: + if not Path(s).exists(): + raise ValueError(f"serial_path '{s}' does not exist.") + return self + def _train( self, train_dataloader, val_dataloader, train_scaler, output_dir, **kwargs ): # Load model from disk if ( - self.parent_spec.procedure.model.param_path is not None - and self.parent_spec.procedure.model.serial_path is not None + self.model_kwargs.get("param_path") is not None + and self.model_kwargs.get("serial_path") is not None ): logger.info("Loading model from disk, overrides any specified parameters.") self.model = self.model.deserialize( - self.parent_spec.procedure.model.param_path, - self.parent_spec.procedure.model.serial_path, + self.model_kwargs.get("param_path"), + self.model_kwargs.get("serial_path"), scaler=train_scaler, **kwargs, ) @@ -481,16 +529,14 @@ def _train( logger.info("Model loaded") # Optionally freeze weights - if self.parent_spec.procedure.model.freeze_weights is not None: + if self.model_kwargs.get("freeze_weights") is not None: logger.info(f"Freezing model weights") - self.model.freeze_weights( - **self.parent_spec.procedure.model.freeze_weights - ) + self.model.freeze_weights(**self.model_kwargs.get("freeze_weights")) logger.info(f"Model weights frozen") # Build model from scratch else: - logger.info("Building model") + logger.info(f"Building model") self.model.build(scaler=train_scaler, **kwargs) logger.info("Model built") @@ -522,16 +568,16 @@ def _train_ensemble(self, X_train, y_train, val_dataloader, output_dir, **kwargs if not self.trainer.output_dir: self.trainer.output_dir = output_dir - # Bootstrap iterations - models = [] - # Get bagging setting - use_bagging = self.parent_spec.procedure.ensemble.use_bagging + use_bagging = self.ensemble_kwargs.get("use_bagging") # Get global seed + # Currently grabbing from `split`, should this be set separately? global_seed = self.split.random_state - for i in range(self.parent_spec.procedure.ensemble.n_models): + # Bootstrap iterations + models = [] + for i in range(self.ensemble_kwargs["n_models"]): # Manage bootstrap directory bootstrap_dir = output_dir / f"bootstrap_{i}" bootstrap_dir.mkdir(parents=True, exist_ok=True) @@ -540,12 +586,14 @@ def _train_ensemble(self, X_train, y_train, val_dataloader, output_dir, **kwargs self.feat = self.feat.make_new() self.trainer = self.trainer.make_new() - # Seed everything for reproducibility - pl.seed_everything(global_seed + i) - - # Bootstrap data if using bagging + # Bootstrap data if using bagging, if not specified default False if use_bagging: - logger.info("Bootstrapping train data") + # Set seed for bootstrapping + logger.info( + f"Bootstrapping train data with incremented seed={global_seed + i}" + ) + np.random.seed(global_seed + i) + bootstrap_indices = np.random.choice( np.arange(len(X_train)), size=len(X_train), replace=True ) @@ -575,31 +623,34 @@ def _train_ensemble(self, X_train, y_train, val_dataloader, output_dir, **kwargs logger.info("Data featurized") # Load model from disk - if (self.parent_spec.procedure.ensemble.param_paths is not None) and ( - self.parent_spec.procedure.ensemble.serial_paths is not None + if (self.ensemble_kwargs.get("param_paths") is not None) and ( + self.ensemble_kwargs.get("serial_paths") is not None ): logger.info( f"Loading model {i} from disk, overrides any specified parameters." ) self.model = self.model.deserialize( - self.parent_spec.procedure.ensemble.param_paths[i], - self.parent_spec.procedure.ensemble.serial_paths[i], + self.ensemble_kwargs.get("param_paths")[i], + self.ensemble_kwargs.get("serial_paths")[i], scaler=bootstrap_scaler, **kwargs, ) logger.info(f"Model {i} loaded") # Optionally freeze weights - if self.parent_spec.procedure.model.freeze_weights is not None: + if self.model_kwargs.get("freeze_weights") is not None: logger.info(f"Freezing weights for model {i}") - self.model.freeze_weights( - **self.parent_spec.procedure.model.freeze_weights - ) + self.model.freeze_weights(**self.model_kwargs.get("freeze_weights")) logger.info(f"Model {i} frozen") # Build model from scratch else: - logger.info(f"Building model {i}") + # Set seed for bootstrap model + logger.info( + f"Building model {i} with incremented seed={global_seed + i} to vary model initialization" + ) + pl.seed_everything(global_seed + i) + self.model = self.model.make_new() self.model.build(scaler=bootstrap_scaler, **kwargs) logger.info(f"Model {i} built") @@ -684,24 +735,12 @@ def run( # Create the output directory output_dir.mkdir(parents=True, exist_ok=True) + self.resolved_output_dir = output_dir # Create data subdirectory data_dir = output_dir / "data" data_dir.mkdir(parents=True, exist_ok=True) - # Write recipe to output directory - self.parent_spec.to_recipe(output_dir / "anvil_recipe.yaml") - - # Split recipe into components and save - recipe_components = Path(output_dir / "recipe_components") - recipe_components.mkdir(parents=True, exist_ok=True) - self.parent_spec.to_multi_yaml( - metadata_yaml=recipe_components / "metadata.yaml", - procedure_yaml=recipe_components / "procedure.yaml", - data_yaml=recipe_components / "data.yaml", - report_yaml=recipe_components / "eval.yaml", - ) - # Log output directory information logger.info(f"Running workflow from directory {output_dir}") @@ -768,7 +807,7 @@ def run( logger.info("Data featurized") kwargs = {} - if self.parent_spec.procedure.feat.type == "PairwiseFeaturizer": + if self.feat_kwargs.get("type") == "PairwiseFeaturizer": kwargs["input_dim"] = train_dataset[0][0].shape[ -1 ] # this is the dimension of # of features, e.g. 1024 for ECFP4, variable for descriptors @@ -791,7 +830,7 @@ def run( self.model.calibrate_uncertainty( val_dataloader, y_val, - method=self.parent_spec.procedure.ensemble.calibration_method, + method=self.ensemble_kwargs.get("calibration_method"), accelerator=self.trainer.accelerator, devices=self.trainer.devices, ) diff --git a/openadmet/models/anvil/workflow_base.py b/openadmet/models/anvil/workflow_base.py index 77ee0922..7d116b7f 100644 --- a/openadmet/models/anvil/workflow_base.py +++ b/openadmet/models/anvil/workflow_base.py @@ -2,14 +2,15 @@ from abc import abstractmethod from os import PathLike +from pathlib import Path from typing import Any, Optional -from pydantic import BaseModel, model_validator +from pydantic import BaseModel, Field, model_validator from openadmet.models.active_learning.ensemble_base import ( EnsembleBase, ) -from openadmet.models.anvil.specification import AnvilSpecification, DataSpec, Metadata +from openadmet.models.anvil.specification import DataSpec, Metadata from openadmet.models.architecture.model_base import ModelBase from openadmet.models.eval.eval_base import EvalBase from openadmet.models.features.feature_base import FeaturizerBase @@ -45,8 +46,12 @@ class AnvilWorkflowBase(BaseModel): The trainer for the model. evals : list[EvalBase] List of evaluation metrics. - parent_spec : AnvilSpecification - The parent specification for the workflow. + model_kwargs : dict + Runtime model settings from the specification domain. + ensemble_kwargs : dict + Runtime ensemble settings from the specification domain. + feat_kwargs : dict + Runtime feature settings from the specification domain. debug : bool Whether to run in debug mode. @@ -61,8 +66,11 @@ class AnvilWorkflowBase(BaseModel): ensemble: EnsembleBase | None = None trainer: TrainerBase evals: list[EvalBase] - parent_spec: AnvilSpecification + model_kwargs: dict = Field(default_factory=dict) + ensemble_kwargs: dict = Field(default_factory=dict) + feat_kwargs: dict = Field(default_factory=dict) debug: bool = False + resolved_output_dir: Path | None = None @abstractmethod def run(self, output_dir: PathLike = "anvil_training", debug: bool = False) -> Any: @@ -97,7 +105,7 @@ def check_multitask_compatibility(self) -> None: """ if self.model._n_tasks != len(self.data_spec.target_cols): raise ValueError( - f"The model has {self.model.n_tasks} tasks but the data specification has {len(self.data_spec.target_cols)} target columns." + f"The model has {self.model._n_tasks} tasks but the data specification has {len(self.data_spec.target_cols)} target columns." ) return self diff --git a/openadmet/models/architecture/dummy.py b/openadmet/models/architecture/dummy.py index 4156b4bd..aa6e6aab 100644 --- a/openadmet/models/architecture/dummy.py +++ b/openadmet/models/architecture/dummy.py @@ -4,7 +4,6 @@ import numpy as np from loguru import logger -from pydantic import ConfigDict from sklearn.dummy import DummyClassifier, DummyRegressor from openadmet.models.architecture.model_base import PickleableModelBase, models @@ -37,7 +36,10 @@ def train(self, X: np.ndarray, y: np.ndarray): """ self.build() - self.estimator = self.estimator.fit(X, y, verbose=True) + y_arr = np.asarray(y) + if y_arr.ndim == 2 and y_arr.shape[1] == 1: + y_arr = y_arr.ravel() + self.estimator = self.estimator.fit(X, y_arr) def predict(self, X: np.ndarray, **kwargs) -> np.ndarray: """ @@ -58,7 +60,10 @@ def predict(self, X: np.ndarray, **kwargs) -> np.ndarray: """ if not self.estimator: raise ValueError("Model not trained") - return np.expand_dims(self.estimator.predict(X), axis=1) + pred = self.estimator.predict(X) + if pred.ndim == 1: + pred = np.expand_dims(pred, axis=1) + return pred @models.register("DummyRegressorModel") @@ -76,8 +81,8 @@ class DummyRegressorModel(DummyModelBase): # DummyRegressor parameters strategy: str = "mean" # Default strategy for dummy models - constant: float = None # Default constant value for dummy models - quantile: float = None # Default quantile value for dummy models + constant: float | None = None # Default constant value for dummy models + quantile: float | None = None # Default quantile value for dummy models @models.register("DummyClassifierModel") @@ -95,5 +100,5 @@ class DummyClassifierModel(DummyModelBase): # DummyClassifier parameters strategy: str = "most_frequent" # Default strategy for dummy models - random_state: int = None # Default random state for dummy models + random_state: int | None = None # Default random state for dummy models constant: int | str = None # Default constant value for dummy models diff --git a/openadmet/models/cli/anvil.py b/openadmet/models/cli/anvil.py index 70de5259..5ffc4728 100644 --- a/openadmet/models/cli/anvil.py +++ b/openadmet/models/cli/anvil.py @@ -40,7 +40,6 @@ def anvil(recipe_path, tag, debug, output_dir): """ spec = AnvilSpecification.from_recipe(recipe_path) - wf = spec.to_workflow() click.echo(f"Workflow initialized successfully with recipe: {recipe_path}") - wf.run(tag=tag, debug=debug, output_dir=output_dir) + spec.run(tag=tag, debug=debug, output_dir=output_dir) click.echo("Workflow completed successfully") diff --git a/openadmet/models/eval/regression.py b/openadmet/models/eval/regression.py index b24b0655..fa73f795 100644 --- a/openadmet/models/eval/regression.py +++ b/openadmet/models/eval/regression.py @@ -17,7 +17,7 @@ evaluators, get_t_true_and_t_pred, ) -from openadmet.models.eval.utils import _make_stat_caption, _make_stat_dict +from openadmet.models.eval.utils import _make_stat_caption, _make_stat_dict, ensure_2d # create partial functions for the scipy stats nan_omit_ktau = partial(kendalltau, nan_policy="omit") @@ -96,10 +96,8 @@ def evaluate( y_true = y_true.to_numpy() # Ensure y_pred and y_true are 2D arrays for consistency - if y_pred.ndim == 1: - y_pred = y_pred.reshape(-1, 1) - if y_true.ndim == 1: - y_true = y_true.reshape(-1, 1) + y_pred = ensure_2d(y_pred) + y_true = ensure_2d(y_true) n_tasks = y_true.shape[1] if not (n_tasks == y_pred.shape[1]): @@ -380,10 +378,8 @@ def evaluate( y_true = y_true.to_numpy() # Ensure y_pred and y_true are 2D arrays for consistency - if y_pred.ndim == 1: - y_pred = y_pred.reshape(-1, 1) - if y_true.ndim == 1: - y_true = y_true.reshape(-1, 1) + y_pred = ensure_2d(y_pred) + y_true = ensure_2d(y_true) n_tasks = y_true.shape[1] if not (n_tasks == y_pred.shape[1]): diff --git a/openadmet/models/eval/uncertainty.py b/openadmet/models/eval/uncertainty.py index ade21252..97fc0d68 100644 --- a/openadmet/models/eval/uncertainty.py +++ b/openadmet/models/eval/uncertainty.py @@ -9,6 +9,7 @@ from pydantic import Field from openadmet.models.eval.eval_base import EvalBase, evaluators, mask_nans_std +from openadmet.models.eval.utils import ensure_2d @evaluators.register("UncertaintyMetrics") @@ -122,12 +123,9 @@ def evaluate( y_true = y_true.to_numpy() # Ensure 2D arrays for consistency - if y_pred.ndim == 1: - y_pred = y_pred.reshape(-1, 1) - if y_true.ndim == 1: - y_true = y_true.reshape(-1, 1) - if y_std.ndim == 1: - y_std = y_std.reshape(-1, 1) + y_pred = ensure_2d(y_pred) + y_true = ensure_2d(y_true) + y_std = ensure_2d(y_std) # Verify number of tasks n_tasks = y_true.shape[1] @@ -158,7 +156,6 @@ def evaluate( t_pred, t_true, False ) - # Calibration calibration_metrics = uct.metrics.get_all_average_calibration( t_pred, t_std, t_true, bins, False ) @@ -322,12 +319,9 @@ def evaluate(self, y_true, y_pred, y_std, target_labels=None, **kwargs): y_true = y_true.to_numpy() # Ensure 2D arrays for consistency - if y_pred.ndim == 1: - y_pred = y_pred.reshape(-1, 1) - if y_true.ndim == 1: - y_true = y_true.reshape(-1, 1) - if y_std.ndim == 1: - y_std = y_std.reshape(-1, 1) + y_pred = ensure_2d(y_pred) + y_true = ensure_2d(y_true) + y_std = ensure_2d(y_std) # Verify number of tasks n_tasks = y_true.shape[1] diff --git a/openadmet/models/eval/utils.py b/openadmet/models/eval/utils.py index 6b50ebee..ec1ce02c 100644 --- a/openadmet/models/eval/utils.py +++ b/openadmet/models/eval/utils.py @@ -1,5 +1,27 @@ """Utility functions for evaluation modules.""" +import numpy as np + + +def ensure_2d(arr: np.ndarray) -> np.ndarray: + """ + Coerce a 1D array to column-vector shape (N, 1); leave 2D arrays unchanged. + + Parameters + ---------- + arr : np.ndarray + Input array of shape (N,) or (N, M). + + Returns + ------- + np.ndarray + Array of shape (N, 1) if input was 1D, otherwise the original array. + + """ + if arr.ndim == 1: + return arr.reshape(-1, 1) + return arr + def _make_stat_caption( data: dict, diff --git a/openadmet/models/features/combine.py b/openadmet/models/features/combine.py index a320ee41..8bd93d51 100644 --- a/openadmet/models/features/combine.py +++ b/openadmet/models/features/combine.py @@ -114,9 +114,16 @@ def concatenate(feats: list[ArrayLike], indices: list[np.ndarray]) -> np.ndarray # use indices to mask out the features that are not present in all datasets common_indices = reduce(np.intersect1d, indices) + # filter features to only include common indices + filtered_feats = [] + for feat, idx in zip(feats, indices): + # find where common_indices are in idx + mask = np.isin(idx, common_indices) + filtered_feats.append(feat[mask]) + # handle 1d features from single input by making them 2d # concatenate the features column wise - concat_feats = np.concatenate(feats, axis=1) + concat_feats = np.concatenate(filtered_feats, axis=1) return ( concat_feats, common_indices, diff --git a/openadmet/models/tests/integration/test_data/chemeleon_MT_ensemble.yaml b/openadmet/models/tests/integration/test_data/chemeleon_MT_ensemble.yaml index d22fa76b..687efb8c 100644 --- a/openadmet/models/tests/integration/test_data/chemeleon_MT_ensemble.yaml +++ b/openadmet/models/tests/integration/test_data/chemeleon_MT_ensemble.yaml @@ -41,7 +41,7 @@ procedure: from_chemeleon: True ensemble: - n_models: 3 + n_models: 2 type: CommitteeRegressor split: diff --git a/openadmet/models/tests/integration/test_data/lgbm_fp_ensemble.yaml b/openadmet/models/tests/integration/test_data/lgbm_fp_ensemble.yaml index a1739321..f714513f 100644 --- a/openadmet/models/tests/integration/test_data/lgbm_fp_ensemble.yaml +++ b/openadmet/models/tests/integration/test_data/lgbm_fp_ensemble.yaml @@ -50,7 +50,7 @@ procedure: learning_rate: 0.05 ensemble: - n_models: 3 + n_models: 2 type: CommitteeRegressor # Specify data splits diff --git a/openadmet/models/tests/unit/active_learning/test_acquisition.py b/openadmet/models/tests/unit/active_learning/test_acquisition.py new file mode 100644 index 00000000..a3289db3 --- /dev/null +++ b/openadmet/models/tests/unit/active_learning/test_acquisition.py @@ -0,0 +1,81 @@ +import numpy as np +from numpy.testing import assert_allclose +from scipy.stats import norm + +from openadmet.models.active_learning.acquisition import ( + _ACQUISITION_FUNCTIONS, + expected_improvement, + exploitation, + max_uncertainty_reduction, + probability_improvement, + upper_confidence_bound, +) + + +def test_basic_acquisition_functions_passthrough(): + """ + Validate that basic acquisition functions return expected values based on mean and standard deviation. + + This verifies that: + - Max uncertainty reduction returns standard deviation (uncertainty). + - Exploitation returns the mean prediction. + - UCB correctly combines mean and uncertainty with the beta parameter. + """ + mean = np.array([[1.0], [2.0]]) + std = np.array([[0.1], [0.2]]) + assert_allclose(max_uncertainty_reduction(mean, std), std) + assert_allclose(exploitation(mean, std), mean) + assert_allclose(upper_confidence_bound(mean, std, beta=3.0), mean + 3.0 * std) + + +def test_probability_improvement_matches_formula(): + """ + Verify Probability of Improvement (PI) calculation against the explicit mathematical formula. + + This ensures that the implementation correctly computes the cumulative distribution function (CDF) + of the improvement over the best observed value, accounting for the exploration parameter xi. + """ + mean = np.array([[1.0], [2.0]]) + std = np.array([[0.5], [1e-12]]) + best_y = 1.2 + xi = 0.1 + expected = norm.cdf((mean - best_y - xi) / std.clip(min=1e-9)) + assert_allclose(probability_improvement(mean, std, best_y=best_y, xi=xi), expected) + + +def test_expected_improvement_matches_formula(): + """ + Verify Expected Improvement (EI) calculation against the explicit mathematical formula. + + This ensures that EI correctly balances exploration and exploitation using both the CDF and PDF + of the normal distribution, which is critical for efficient active learning query strategies. + """ + mean = np.array([[1.0], [1.5]]) + std = np.array([[0.2], [1e-12]]) + best_y = 0.8 + xi = 0.01 + std_clip = std.clip(min=1e-9) + improvement = mean - best_y - xi + z_score = improvement / std_clip + expected = improvement * norm.cdf(z_score) + std_clip * norm.pdf(z_score) + assert_allclose(expected_improvement(mean, std, best_y=best_y, xi=xi), expected) + + +def test_acquisition_aliases_map_to_same_function(): + """Ensure that shorthand aliases for acquisition functions map to the correct implementation functions.""" + assert ( + _ACQUISITION_FUNCTIONS["ur"] + is _ACQUISITION_FUNCTIONS["max-uncertainty-reduction"] + ) + assert _ACQUISITION_FUNCTIONS["exp"] is _ACQUISITION_FUNCTIONS["exploitation"] + assert ( + _ACQUISITION_FUNCTIONS["ucb"] + is _ACQUISITION_FUNCTIONS["upper-confidence-bound"] + ) + assert ( + _ACQUISITION_FUNCTIONS["ei"] is _ACQUISITION_FUNCTIONS["expected-improvement"] + ) + assert ( + _ACQUISITION_FUNCTIONS["pi"] + is _ACQUISITION_FUNCTIONS["probability-improvement"] + ) diff --git a/openadmet/models/tests/unit/active_learning/test_active_learning.py b/openadmet/models/tests/unit/active_learning/test_active_learning.py index 362d20a6..f560c967 100644 --- a/openadmet/models/tests/unit/active_learning/test_active_learning.py +++ b/openadmet/models/tests/unit/active_learning/test_active_learning.py @@ -1,476 +1,200 @@ -from itertools import product -from pathlib import Path - import numpy as np -import pandas as pd import pytest from numpy.testing import assert_allclose from openadmet.models.active_learning.acquisition import _ACQUISITION_FUNCTIONS -from openadmet.models.active_learning.committee import ( - CommitteeRegressor, -) -from openadmet.models.architecture.lgbm import LGBMRegressorModel -from openadmet.models.architecture.model_base import ModelBase -from openadmet.models.inference.inference import load_anvil_model_and_metadata -from openadmet.models.split.sklearn import ShuffleSplitter -from openadmet.models.tests.unit.datafiles import ( - ACEH_chembl_pchembl, # chemprop - CYP3A4_chembl_pchembl, # lgbm - anvil_chemprop_trained_model_dir, - anvil_lgbm_trained_model_dir, -) - -# Remove redundant for testing -_ACQUISITION_FUNCTIONS_SHORTLIST = [ - x for x in _ACQUISITION_FUNCTIONS.keys() if "-" in x -] - - -@pytest.fixture -def chemprop_models(): - # Load the model and metadata - model_list = [] - for i in range(5): - model, feat, _, _ = load_anvil_model_and_metadata( - Path(anvil_chemprop_trained_model_dir) - ) - model_list.append(model) - - # Load data - data = pd.read_csv(ACEH_chembl_pchembl).iloc[:100, :] - X = data["OPENADMET_SMILES"].values - y = data["pchembl_value_mean"].values - - # Featurize - X_feat = feat.featurize(X)[0] - - return model_list, X_feat, y.reshape(-1, 1) - - -@pytest.fixture -def lgbm_models(): - model_list = [] - for i in range(5): - model, feat, _, _ = load_anvil_model_and_metadata( - Path(anvil_lgbm_trained_model_dir) - ) - model_list.append(model) - - # Load data - data = pd.read_csv(CYP3A4_chembl_pchembl).iloc[:100, :] - X = data["CANONICAL_SMILES"].values - y = data["pChEMBL mean"].values - - # Featurize - X_feat = feat.featurize(X)[0] - - return model_list, X_feat, y.reshape(-1, 1) +from openadmet.models.active_learning.committee import CommitteeRegressor +from openadmet.models.architecture.dummy import DummyRegressorModel @pytest.fixture def toy_data(): - # Set random seed for reproducibility - np.random.seed(42) - - # Number of samples - n_samples = 2000 - - # Features - X = np.column_stack( - [ - np.linspace(0, 10, n_samples), - np.random.uniform(0, 5, n_samples), - np.random.normal(5, 2, n_samples), - ] - ) - - # Targets - y = np.column_stack( - [ - 3 * np.sin(X[:, 0]) - + 0.5 * X[:, 1] ** 2 - - 0.8 * X[:, 2] - + np.random.normal(0, 0.1, n_samples), - 2 * np.cos(X[:, 0]) - + 0.3 * X[:, 1] ** 2 - + 0.5 * X[:, 2] - + np.random.normal(0, 0.1, n_samples), - ] - ) - - # Split the data - splitter = ShuffleSplitter(train_size=0.7, val_size=0.1, test_size=0.2) - X_train, X_val, X_test, y_train, y_val, y_test, _ = splitter.split( - X, y[:, 0].reshape(-1, 1) - ) - - return X_train, X_val, X_test, y_train, y_val, y_test - - -class MockCommitteeModel(ModelBase): - """ - Mock model for testing CommitteeRegressor. - Tracks random_state and training data shape. """ + Generate synthetic regression data for testing committee models. - random_state: int | None = None - _trained_data_shape: tuple | None = None - _trained_unique_samples: int | None = None - - def __init__(self, **kwargs): - super().__init__(**kwargs) - - def build(self): - pass - - def train(self, X, y): - self._trained_data_shape = X.shape - # Count unique samples to detect bagging (duplicates) - # Using simple row hashing for uniqueness check - if X.ndim > 1: - X_view = np.ascontiguousarray(X).view( - np.dtype((np.void, X.dtype.itemsize * X.shape[1])) - ) - self._trained_unique_samples = len(np.unique(X_view)) - else: - self._trained_unique_samples = len(np.unique(X)) - - def predict(self, input: np.ndarray, **kwargs): - # Return dummy predictions - n_samples = input.shape[0] - # Return (n_samples, 1) to match regression output shape - return np.zeros((n_samples, 1)) - - def save(self, path): - pass - - def load(self, path): - pass - - def serialize(self, param_path, serial_path): - pass - - def deserialize(self, param_path, serial_path): - pass - - -def test_committee_bagging_logic(toy_data): - """Test that use_bagging flag correctly controls bootstrap aggregation.""" - X_train, _, _, y_train, _, _ = toy_data - n_samples = X_train.shape[0] - - # Test use_bagging=True (Default) - committee = CommitteeRegressor.train( - X_train, - y_train, - mod_class=MockCommitteeModel, - mod_params={"random_state": 42}, - n_models=3, - use_bagging=True, + This fixture creates a simple linear relationship with noise to verify that the + ensemble can learn and predict reasonable values. + """ + rng = np.random.default_rng(42) + X = rng.normal(size=(120, 3)) + y = ( + 1.2 * X[:, [0]] + - 0.8 * X[:, [1]] + + 0.3 * X[:, [2]] + + 0.1 * rng.normal(size=(120, 1)) ) + return X[:80], X[80:100], X[100:], y[:80], y[80:100], y[100:] - for i, model in enumerate(committee.models): - # Verify random_state increment - assert model.random_state == 42 + i - - # Verify bagging occurred: - # With replacement sampling, unique samples should be approx ~63.2% of total - # Definitely should be less than total n_samples for large N - assert model._trained_unique_samples < n_samples - assert model._trained_data_shape[0] == n_samples - - # Test use_bagging=False - committee_nb = CommitteeRegressor.train( - X_train, - y_train, - mod_class=MockCommitteeModel, - mod_params={"random_state": 10}, - n_models=2, - use_bagging=False, - ) - for i, model in enumerate(committee_nb.models): - # Verify random_state increment - assert model.random_state == 10 + i +@pytest.fixture +def dummy_models(): + """Create a list of initialized DummyRegressorModel instances for building a committee.""" + models = [] + for _ in range(5): + model = DummyRegressorModel(strategy="mean") + models.append(model) + return models - # Verify NO bagging occurred: - # Unique samples should equal total samples (training on full set) - assert model._trained_unique_samples == n_samples - assert model._trained_data_shape[0] == n_samples - # Test random_state=None (should handle gracefully) - committee_none = CommitteeRegressor.train( - X_train, - y_train, - mod_class=MockCommitteeModel, - mod_params={"random_state": None}, - n_models=2, - ) +@pytest.fixture +def trained_committee(dummy_models, toy_data): + """ + Create a trained CommitteeRegressor using bootstrapped data. - for model in committee_none.models: - assert model.random_state is None + This fixture trains multiple dummy models on bootstrapped subsets of the training data + to simulate a real ensemble training process, allowing for testing of prediction and uncertainty estimation. + """ + X_train, X_val, _, y_train, y_val, _ = toy_data + rng = np.random.default_rng(123) + for model in dummy_models: + bootstrap_idx = rng.choice( + X_train.shape[0], size=X_train.shape[0], replace=True + ) + model.train(X_train[bootstrap_idx], y_train[bootstrap_idx]) + return CommitteeRegressor.from_models(models=dummy_models), X_val, y_val -@pytest.mark.parametrize( - "model_list, calibration_method, query_strategy", - product( - ["lgbm_models", "chemprop_models"], - ["isotonic-regression", "scaling-factor", None], - _ACQUISITION_FUNCTIONS_SHORTLIST, - ), -) -def test_committee(request, model_list, calibration_method, query_strategy): - # Skip calibration on real data until we have true ensemble - if calibration_method is not None: - # Skip the test - pytest.skip("Skipping calibration test") +@pytest.mark.parametrize("query_strategy", sorted(_ACQUISITION_FUNCTIONS.keys())) +def test_committee_query_predict(trained_committee, query_strategy): + """ + Validate that the committee can query samples using all registered acquisition functions. - # Unpack models, features - _model_list, X_feat, y = request.getfixturevalue(model_list) + This ensures that the query interface works consistent across strategies and that + predictions and uncertainty estimates are returned in the correct shape. + """ + committee, X_val, _ = trained_committee + y_query = committee.query(X_val, query_strategy=query_strategy) + y_pred, y_pred_std = committee.predict(X_val, return_std=True) + assert y_query.shape == (X_val.shape[0], 1) + assert y_pred.shape == (X_val.shape[0], 1) + assert y_pred_std.shape == (X_val.shape[0], 1) + assert np.isfinite(y_query).all() - # Create committee - committee = CommitteeRegressor.from_models(models=_model_list) - # Calibrate uncertainty - if calibration_method is not None: - committee.calibrate_uncertainty( - X_feat, y, method=calibration_method, accelerator="cpu" - ) +def test_invalid_query_strategy_raises(trained_committee): + """Ensure ValueError is raised when an invalid query strategy is requested.""" + committee, X_val, _ = trained_committee + with pytest.raises(ValueError): + committee.query(X_val, query_strategy="not-a-strategy") - # Check model is calibrated - assert committee.calibrated - # Query - y_query = committee.query(X_feat, query_strategy=query_strategy, accelerator="cpu") - - # Predict - y_pred, y_pred_std = committee.predict(X_feat, return_std=True, accelerator="cpu") +def test_invalid_calibration_method_raises(trained_committee): + """Ensure ValueError is raised when an invalid uncertainty calibration method is requested.""" + committee, X_val, y_val = trained_committee + with pytest.raises(ValueError): + committee.calibrate_uncertainty(X_val, y_val, method="not-a-method") @pytest.mark.parametrize( - "model_list, calibration_method", - product( - ["lgbm_models", "chemprop_models"], - ["isotonic-regression", "scaling-factor", None], - ), + "calibration_method", ["isotonic-regression", "scaling-factor"] ) -def test_save_load(request, tmp_path, model_list, calibration_method): - # Skip calibration on real data until we have true ensemble - if calibration_method is not None: - calibration_model_path = tmp_path / "calibration_model.pkl" - # Skip the test - pytest.skip("Skipping calibration test") - else: - calibration_model_path = None - - # Unpack models, features - model_list, X_feat, y = request.getfixturevalue(model_list) - - # Create committee - committee = CommitteeRegressor.from_models(models=model_list) - - # Calibrate uncertainty - if calibration_method is not None: - committee.calibrate_uncertainty( - X_feat, y, method=calibration_method, accelerator="cpu" - ) - - # Check model is calibrated - assert committee.calibrated - - # Predict before saving - y_pred_mean, y_pred_std = committee.predict( - X_feat, return_std=True, accelerator="cpu" - ) +def test_calibration_paths(trained_committee, calibration_method): + """ + Verify that uncertainty calibration methods can be applied successfully. - # Save - save_paths = [tmp_path / "committee_model_{i}.pkl" for i in range(len(model_list))] - committee.save(save_paths, calibration_path=calibration_model_path) + This checks that the calibration state is updated and that predictions remain valid + after calibration (shapes are preserved). + """ + committee, X_val, y_val = trained_committee + committee.calibrate_uncertainty(X_val, y_val, method=calibration_method) + assert committee.calibrated + y_pred, y_std = committee.predict(X_val, return_std=True) + assert y_pred.shape == y_std.shape == y_val.shape - # Instantiate empty models to "fill" - models_new = [model.make_new() for model in model_list] - [model.build() for model in models_new] - # Load - committee.load( - save_paths, - models=models_new, - calibration_path=calibration_model_path, - ) +def test_train_and_train_validation(toy_data): + """ + Validate the high-level train method for creating a CommitteeRegressor. - # Predict after loading - y_pred_mean2, y_pred_std2 = committee.predict( - X_feat, return_std=True, accelerator="cpu" + This tests the end-to-end training process, ensuring that the correct number of models + are created and that the resulting ensemble can make predictions. + """ + X_train, _, X_test, y_train, _, _ = toy_data + committee = CommitteeRegressor.train( + X_train, y_train, mod_class=DummyRegressorModel, n_models=4 ) - - # Check that predictions are the same - assert_allclose(y_pred_mean, y_pred_mean2) - assert_allclose(y_pred_std, y_pred_std2) - - # Check that we successfully loaded calibration models - if calibration_method is not None: - assert committee.calibrated + mean, std = committee.predict(X_test, return_std=True) + assert committee.n_models == 4 + assert mean.shape == std.shape == (X_test.shape[0], 1) + with pytest.raises(ValueError): + CommitteeRegressor.train(X_train, y_train, mod_class=None, n_models=2) @pytest.mark.parametrize( - "model_list, calibration_method", - product( - ["lgbm_models", "chemprop_models"], - ["isotonic-regression", "scaling-factor", None], - ), + "calibration_method", ["isotonic-regression", "scaling-factor", None] ) -def test_serialization(request, tmp_path, model_list, calibration_method): - # Skip calibration on real data until we have true ensemble - if calibration_method is not None: - calibration_model_path = tmp_path / "calibration_model.pkl" - # Skip the test - pytest.skip("Skipping calibration test") - else: - calibration_model_path = None - - # Unpack models, features - model_list, X_feat, y = request.getfixturevalue(model_list) - - # Create committee - committee = CommitteeRegressor.from_models(models=model_list) - - # Calibrate uncertainty - if calibration_method is not None: - committee.calibrate_uncertainty( - X_feat, y, method=calibration_method, accelerator="cpu" - ) - - # Check model is calibrated - assert committee.calibrated +def test_save_load_roundtrip(tmp_path, trained_committee, calibration_method): + """ + Verify that a CommitteeRegressor can be saved and loaded correctly. - # Predict before saving - y_pred_mean, y_pred_std = committee.predict( - X_feat, return_std=True, accelerator="cpu" + This ensures persistence of the ensemble, including individual models and calibration state. + It verifies that predictions before save and after load are identical. + """ + committee, X_val, y_val = trained_committee + calibration_model_path = ( + tmp_path / "calibration_model.pkl" if calibration_method is not None else None ) - - # Serialize/deserialize - param_paths = [ - tmp_path / "committee_model_{i}.json" for i in range(len(model_list)) - ] - serial_paths = [ - tmp_path / "committee_model_{i}.pkl" for i in range(len(model_list)) + if calibration_method is not None: + committee.calibrate_uncertainty(X_val, y_val, method=calibration_method) + y_pred_mean, y_pred_std = committee.predict(X_val, return_std=True) + save_paths = [ + tmp_path / f"committee_model_{i}.pkl" for i in range(committee.n_models) ] - committee.serialize( - param_paths, serial_paths, calibration_path=calibration_model_path - ) - committee.deserialize( - param_paths, - serial_paths, - mod_class=model_list[0].__class__, - calibration_path=calibration_model_path, - ) - - # Predict after loading - y_pred_mean2, y_pred_std2 = committee.predict( - X_feat, return_std=True, accelerator="cpu" + committee.save(save_paths, calibration_path=calibration_model_path) + models_new = [model.make_new() for model in committee.models] + [model.build() for model in models_new] + committee = committee.load( + save_paths, models=models_new, calibration_path=calibration_model_path ) - - # Check that predictions are the same + y_pred_mean2, y_pred_std2 = committee.predict(X_val, return_std=True) assert_allclose(y_pred_mean, y_pred_mean2) assert_allclose(y_pred_std, y_pred_std2) - - # Check that we successfully loaded calibration models - if calibration_method is not None: - assert committee.calibrated + assert committee.calibrated is (calibration_method is not None) -# This test is somewhat redundant and a catch-all, but useful until we have ability to test real-world ensembles -# more explicitly @pytest.mark.parametrize( "calibration_method", ["isotonic-regression", "scaling-factor", None] ) -def test_calibration(tmp_path, toy_data, calibration_method): - # Unpack data - X_train, X_val, X_test, y_train, y_val, y_test = toy_data - - # Parameters - mod_params = {"alpha": 0.005, "learning_rate": 0.05, "force_col_wise": True} +def test_serialize_deserialize_roundtrip( + tmp_path, trained_committee, calibration_method +): + """ + Verify that a CommitteeRegressor can be serialized and deserialized via JSON/pickle. - # Train committee - committee = CommitteeRegressor.train( - X_train, - y_train, - mod_class=LGBMRegressorModel, - mod_params=mod_params, - n_models=5, + This tests the serialization pathway used for distributed training or registry storage, + ensuring full state recovery including calibration. + """ + committee, X_val, y_val = trained_committee + calibration_model_path = ( + tmp_path / "calibration_model.pkl" if calibration_method is not None else None ) - - # Calibrate uncertainty if calibration_method is not None: committee.calibrate_uncertainty(X_val, y_val, method=calibration_method) - calibration_model_path = tmp_path / "calibration_model.pkl" - - # Check model is calibrated - assert committee.calibrated - else: - calibration_model_path = None - - # Evaluate on test set - y_pred_mean, y_pred_std = committee.predict(X_test, return_std=True) - - # Generate plot - committee.plot_uncertainty_calibration(X_test, y_test) - - # Serialize/deserialize + y_pred_mean, y_pred_std = committee.predict(X_val, return_std=True) param_paths = [ - tmp_path / "committee_model_{i}.json" for i in range(len(committee.models)) + tmp_path / f"committee_model_{i}.json" for i in range(committee.n_models) ] serial_paths = [ - tmp_path / "committee_model_{i}.pkl" for i in range(len(committee.models)) + tmp_path / f"committee_model_{i}.pkl" for i in range(committee.n_models) ] committee.serialize( param_paths, serial_paths, calibration_path=calibration_model_path ) - committee.deserialize( + committee = committee.deserialize( param_paths, serial_paths, - mod_class=LGBMRegressorModel, + mod_class=DummyRegressorModel, calibration_path=calibration_model_path, ) - - # Evaluate on test set again - y_pred_mean2, y_pred_std2 = committee.predict(X_test, return_std=True) - - # Check results match original + y_pred_mean2, y_pred_std2 = committee.predict(X_val, return_std=True) assert_allclose(y_pred_mean, y_pred_mean2) assert_allclose(y_pred_std, y_pred_std2) + assert committee.calibrated is (calibration_method is not None) - # Check that we successfully loaded calibration models - if calibration_method is not None: - assert committee.calibrated - - # Save/load - save_paths = [ - tmp_path / "committee_model_{i}.pkl" for i in range(len(committee.models)) - ] - committee.save(save_paths, calibration_path=calibration_model_path) - - # Instantiate empty models to "fill" - models_new = [ - LGBMRegressorModel(**mod_params) for _ in range(len(committee.models)) - ] - [model.build() for model in models_new] - # Load - committee.load( - save_paths, - models=models_new, - calibration_path=calibration_model_path, - ) - - # Evaluate on test set again - y_pred_mean2, y_pred_std2 = committee.predict(X_test, return_std=True) - - # Check results match original - assert_allclose(y_pred_mean, y_pred_mean2) - assert_allclose(y_pred_std, y_pred_std2) - - # Check that we successfully loaded calibration models - if calibration_method is not None: - assert committee.calibrated +def test_plot_uncertainty_calibration(trained_committee): + """Check that the uncertainty calibration plotting function runs without error.""" + committee, X_val, y_val = trained_committee + committee.calibrate_uncertainty(X_val, y_val, method="scaling-factor") + plot = committee.plot_uncertainty_calibration(X_val, y_val) + assert plot is not None diff --git a/openadmet/models/tests/unit/active_learning/test_ensemble_base.py b/openadmet/models/tests/unit/active_learning/test_ensemble_base.py new file mode 100644 index 00000000..29b5759a --- /dev/null +++ b/openadmet/models/tests/unit/active_learning/test_ensemble_base.py @@ -0,0 +1,15 @@ +import pytest + +from openadmet.models.active_learning.committee import CommitteeRegressor +from openadmet.models.active_learning.ensemble_base import get_ensemble_class + + +def test_get_ensemble_class_success(): + """Verify that get_ensemble_class returns the correct class for a valid ensemble type.""" + assert get_ensemble_class("CommitteeRegressor") is CommitteeRegressor + + +def test_get_ensemble_class_raises_for_invalid_type(): + """Ensure get_ensemble_class raises ValueError when requested for a non-existent ensemble type.""" + with pytest.raises(ValueError, match="Ensemble type not-real not found"): + get_ensemble_class("not-real") diff --git a/openadmet/models/tests/unit/anvil/test_anvil.py b/openadmet/models/tests/unit/anvil/test_anvil.py deleted file mode 100644 index 3c87cce5..00000000 --- a/openadmet/models/tests/unit/anvil/test_anvil.py +++ /dev/null @@ -1,113 +0,0 @@ -from pathlib import Path - -import pytest - -from openadmet.models.anvil.specification import ( - AnvilSpecification, -) -from openadmet.models.tests.unit.datafiles import ( - acetylcholinesterase_anvil_chemprop_yaml, - anvil_yaml_featconcat, - anvil_yaml_gridsearch, - anvil_yaml_xgboost_cv, - basic_anvil_yaml, - basic_anvil_yaml_classification, - basic_anvil_yaml_cv, - tabpfn_anvil_classification_yaml, -) - - -def all_anvil_full_recipes(): - return [ - basic_anvil_yaml, - # anvil_yaml_featconcat, # skipping as slow, redundant with integration tests - anvil_yaml_gridsearch, - # anvil_yaml_xgboost_cv, # skipping as slow, redundant with integration tests - ] - - -def test_anvil_spec_create(): - anvil_spec = AnvilSpecification.from_recipe(basic_anvil_yaml) - assert anvil_spec - - -def test_anvil_spec_create_from_recipe_roundtrip(tmp_path): - anvil_spec = AnvilSpecification.from_recipe(basic_anvil_yaml) - assert anvil_spec - anvil_spec.to_recipe(tmp_path / "tst.yaml") - anvil_spec2 = AnvilSpecification.from_recipe(tmp_path / "tst.yaml") - # these were created from different directories, so the anvil_dir will be different - anvil_spec.data.anvil_dir = None - anvil_spec2.data.anvil_dir = None - - assert anvil_spec == anvil_spec2 - - -def test_anvil_spec_create_to_workflow(): - anvil_spec = AnvilSpecification.from_recipe(basic_anvil_yaml) - anvil_workflow = anvil_spec.to_workflow() - assert anvil_workflow - - -@pytest.mark.parametrize("anvil_full_recipie", all_anvil_full_recipes()) -def test_anvil_workflow_run(tmp_path, anvil_full_recipie): - anvil_workflow = AnvilSpecification.from_recipe(anvil_full_recipie).to_workflow() - anvil_workflow.run(output_dir=tmp_path / "tst") - assert Path(tmp_path / "tst" / "model.json").exists() - assert Path(tmp_path / "tst" / "regression_metrics.json").exists() - assert any((tmp_path / "tst").glob("*regplot.png")) - - -def test_anvil_multiyaml(tmp_path): - anvil_spec = AnvilSpecification.from_recipe(basic_anvil_yaml) - anvil_spec.to_multi_yaml( - metadata_yaml=tmp_path / "metadata.yaml", - procedure_yaml=tmp_path / "procedure.yaml", - data_yaml=tmp_path / "data.yaml", - report_yaml=tmp_path / "eval.yaml", - ) - anvil_spec2 = AnvilSpecification.from_multi_yaml( - metadata_yaml=tmp_path / "metadata.yaml", - procedure_yaml=tmp_path / "procedure.yaml", - data_yaml=tmp_path / "data.yaml", - report_yaml=tmp_path / "eval.yaml", - ) - assert anvil_spec.data.anvil_dir == anvil_spec2.data.anvil_dir - assert anvil_spec.dict() == anvil_spec2.dict() - - -def test_anvil_cross_val_run(tmp_path): - anvil_spec = AnvilSpecification.from_recipe(basic_anvil_yaml_cv) - anvil_workflow = anvil_spec.to_workflow() - anvil_workflow.run(output_dir=tmp_path / "tst") - - -def test_anvil_classification_run(tmp_path): - anvil_spec = AnvilSpecification.from_recipe(basic_anvil_yaml_classification) - anvil_workflow = anvil_spec.to_workflow() - anvil_workflow.run(output_dir=tmp_path / "tst") - - assert Path(tmp_path / "tst" / "anvil_recipe.yaml").exists() - assert Path(tmp_path / "tst" / "model.json").exists() - assert Path(tmp_path / "tst" / "classification_metrics.json").exists() - assert Path(tmp_path / "tst" / "pr_curve.png").exists() - assert Path(tmp_path / "tst" / "roc_curve.png").exists() - - -# skip on MacOS runner? -def test_anvil_chemprop_cpu_regression(tmp_path): - anvil_spec = AnvilSpecification.from_recipe( - acetylcholinesterase_anvil_chemprop_yaml - ) - anvil_workflow = anvil_spec.to_workflow() - anvil_workflow.run(output_dir=tmp_path / "tst") - assert Path(tmp_path / "tst" / "model.json").exists() - assert Path(tmp_path / "tst" / "regression_metrics.json").exists() - assert any((tmp_path / "tst").glob("*regplot.png")) - - -@pytest.mark.skip(reason="TabPFN requires GPU and is not supported on MacOS runners") -def test_anvil_tabpfn_classification(tmp_path): - anvil_spec = AnvilSpecification.from_recipe(tabpfn_anvil_classification_yaml) - anvil_workflow = anvil_spec.to_workflow() - anvil_workflow.run(output_dir=tmp_path / "tst") diff --git a/openadmet/models/tests/unit/anvil/test_specification.py b/openadmet/models/tests/unit/anvil/test_specification.py new file mode 100644 index 00000000..59f3fbcc --- /dev/null +++ b/openadmet/models/tests/unit/anvil/test_specification.py @@ -0,0 +1,645 @@ +from pathlib import Path + +import numpy as np +import pandas as pd +import pytest +import yaml + +from openadmet.models.anvil.specification import ( + AnvilSpecification, + DataSpec, + EnsembleSpec, + EvalSpec, + FeatureSpec, + Metadata, + ModelSpec, + ProcedureSpec, + ReportSpec, + SplitSpec, + TrainerSpec, + TransformSpec, +) +from openadmet.models.anvil.workflow import AnvilDeepLearningWorkflow, AnvilWorkflow +from openadmet.models.architecture.model_base import LightningModelBase + +# --- DataSpec Tests --- + + +def test_dataspec_resource_and_train_test_mutually_exclusive(): + """Test that specifying both resource and train_resource raises ValueError.""" + with pytest.raises(ValueError, match="Specify either `resource` or"): + DataSpec( + type="csv", + input_col="smiles", + target_cols="target", + resource="data.csv", + train_resource="train.csv", + ) + + +def test_dataspec_requires_train_and_test_together(): + """Test that specifying train_resource without test_resource raises ValueError.""" + with pytest.raises(ValueError, match="must both be specified"): + DataSpec( + type="csv", + input_col="smiles", + target_cols="target", + train_resource="train.csv", + ) + + +def test_dataspec_target_cols_string_normalized_to_list(): + """Test that a string target_cols is converted to a list.""" + spec = DataSpec( + type="csv", + input_col="smiles", + target_cols="activity", + resource="data.csv", + ) + assert spec.target_cols == ["activity"] + + +def test_dataspec_template_anvil_dir_replaces_placeholder(tmp_path): + """Test that {{ ANVIL_DIR }} is replaced in resource path.""" + spec = DataSpec( + type="csv", + input_col="smiles", + target_cols="target", + resource="{{ ANVIL_DIR }}/data.csv", + anvil_dir="/tmp/mydir", + ) + # The validator runs automatically if anvil_dir is set at init + assert spec.resource == "/tmp/mydir/data.csv" + + # Test explicit method call + spec2 = DataSpec( + type="csv", + input_col="smiles", + target_cols="target", + resource="{{ ANVIL_DIR }}/data.csv", + ) + spec2.template_anvil_dir(Path("/other/dir")) + assert spec2.resource == "/other/dir/data.csv" + + +def test_dataspec_read_single_resource_csv(tmp_path): + """Test reading a single CSV resource.""" + csv_path = tmp_path / "data.csv" + df = pd.DataFrame( + { + "smiles": ["CCO", "CC(C)O", "c1ccccc1"], + "target": [1.0, 2.0, 3.0], + "extra": ["a", "b", "c"], + } + ) + df.to_csv(csv_path, index=False) + + spec = DataSpec( + type="csv", + input_col="smiles", + target_cols="target", + resource=str(csv_path), + ) + X, y = spec.read() + + assert isinstance(X, pd.Series) + assert isinstance(y, pd.DataFrame) + assert len(X) == 3 + assert len(y) == 3 + assert list(y.columns) == ["target"] + assert X.iloc[0] == "CCO" + assert y.iloc[0, 0] == 1.0 + + +def test_dataspec_read_single_resource_dropna(tmp_path): + """Test that rows with NaNs in target columns are dropped.""" + csv_path = tmp_path / "data_nan.csv" + df = pd.DataFrame( + { + "smiles": ["CCO", "CC(C)O", "c1ccccc1", "C"], + "target": [1.0, np.nan, 3.0, 4.0], + } + ) + df.to_csv(csv_path, index=False) + + spec = DataSpec( + type="csv", + input_col="smiles", + target_cols="target", + resource=str(csv_path), + dropna=True, + ) + X, y = spec.read() + + assert len(X) == 3 + assert len(y) == 3 + assert "CC(C)O" not in X.values + + +def test_dataspec_read_train_test_val_returns_correct_splits(tmp_path): + """Test reading separate train, test, and val resources.""" + train_path = tmp_path / "train.csv" + test_path = tmp_path / "test.csv" + val_path = tmp_path / "val.csv" + + pd.DataFrame({"smiles": ["A", "B", "C"], "target": [1, 2, 3]}).to_csv( + train_path, index=False + ) + + pd.DataFrame({"smiles": ["D", "E"], "target": [4, 5]}).to_csv( + test_path, index=False + ) + + pd.DataFrame({"smiles": ["F"], "target": [6]}).to_csv(val_path, index=False) + + spec = DataSpec( + type="csv", + input_col="smiles", + target_cols="target", + train_resource=str(train_path), + test_resource=str(test_path), + val_resource=str(val_path), + ) + + # Returns: X_train, X_val, X_test, y_train, y_val, y_test, X, y + X_train, X_val, X_test, y_train, y_val, y_test, X, y = spec.read() + + assert len(X_train) == 3 + assert len(X_test) == 2 + assert len(X_val) == 1 + assert len(X) == 6 + assert len(y) == 6 + + # Verify content + assert X_train.tolist() == ["A", "B", "C"] + assert X_test.tolist() == ["D", "E"] + assert X_val.tolist() == ["F"] + + +def test_dataspec_read_train_test_raises_on_split_column_in_file(tmp_path): + """Test that a ValueError is raised if input files contain a '_split' column.""" + train_path = tmp_path / "train_bad.csv" + test_path = tmp_path / "test_bad.csv" + + pd.DataFrame({"smiles": ["A"], "target": [1], "_split": ["train"]}).to_csv( + train_path, index=False + ) + + pd.DataFrame({"smiles": ["B"], "target": [2]}).to_csv(test_path, index=False) + + spec = DataSpec( + type="csv", + input_col="smiles", + target_cols="target", + train_resource=str(train_path), + test_resource=str(test_path), + ) + + with pytest.raises(ValueError, match="should not contain a '_split' column"): + spec.read() + + +def test_dataspec_to_yaml_from_yaml_roundtrip(tmp_path): + """Test roundtrip YAML serialization for DataSpec.""" + spec = DataSpec( + type="csv", + input_col="smiles", + target_cols=["target1", "target2"], + resource="data.csv", + dropna=True, + ) + yaml_path = tmp_path / "spec.yaml" + spec.to_yaml(yaml_path) + + loaded_spec = DataSpec.from_yaml(yaml_path) + assert loaded_spec.input_col == spec.input_col + assert loaded_spec.target_cols == spec.target_cols + assert loaded_spec.resource == spec.resource + assert loaded_spec.dropna == spec.dropna + + +# --- Metadata Tests --- + + +def test_metadata_to_yaml_from_yaml_roundtrip(tmp_path): + """Test roundtrip YAML serialization for Metadata.""" + meta = Metadata( + version="v1", + name="test-workflow", + build_number=1, + description="A test workflow", + tag="v1.0.0", + authors="Test Author", + email="test@example.com", + biotargets=["TargetA"], + tags=["tag1", "tag2"], + ) + yaml_path = tmp_path / "metadata.yaml" + meta.to_yaml(yaml_path) + + loaded_meta = Metadata.from_yaml(yaml_path) + assert loaded_meta.name == meta.name + assert loaded_meta.biotargets == meta.biotargets + assert loaded_meta.version == "v1" + + +# --- AnvilSection Tests --- + + +def test_anvilsection_to_class_dispatches_correctly(): + """Test that to_class returns the correct class instance.""" + # Using SplitSpec as a concrete example + spec = SplitSpec( + type="ShuffleSplitter", params={"train_size": 0.8, "test_size": 0.2} + ) + splitter = spec.to_class() + # Check if it has the attributes we expect from a splitter + assert hasattr(splitter, "split") + assert splitter.train_size == 0.8 + + +# --- ModelSpec Tests --- + + +def test_modelspec_path_pairs_validation(): + """Test validation of param_path and serial_path pairs.""" + # Success cases + ModelSpec(type="MyModel", param_path="p.pt", serial_path="s.pt") + ModelSpec(type="MyModel") + + # Failure cases + with pytest.raises(ValueError, match="must be provided together"): + ModelSpec(type="MyModel", param_path="p.pt") + + with pytest.raises(ValueError, match="must be provided together"): + ModelSpec(type="MyModel", serial_path="s.pt") + + +# --- EnsembleSpec Tests --- + + +def test_ensemblespec_n_models_minimum(): + """Test validation of n_models.""" + with pytest.raises(ValueError, match="Ensemble must have more than one model"): + EnsembleSpec(type="Ensemble", n_models=1) + + EnsembleSpec(type="Ensemble", n_models=2) + + +def test_ensemblespec_path_count_validation(): + """Test validation of param_paths and serial_paths lengths.""" + # Length mismatch between paths + with pytest.raises(ValueError, match="same length"): + EnsembleSpec( + type="Ensemble", n_models=2, param_paths=["p1", "p2"], serial_paths=["s1"] + ) + + # Length mismatch with n_models + with pytest.raises(ValueError, match="match the number of models"): + EnsembleSpec( + type="Ensemble", + n_models=3, + param_paths=["p1", "p2"], + serial_paths=["s1", "s2"], + ) + + # Success + EnsembleSpec( + type="Ensemble", n_models=2, param_paths=["p1", "p2"], serial_paths=["s1", "s2"] + ) + + +# --- AnvilSpecification Tests --- + + +def test_anvilspecification_from_recipe_resolves_anvil_dir(tmp_path): + """Test that loading from a recipe resolves {{ ANVIL_DIR }}.""" + workflow_dir = tmp_path / "myworkflow" + workflow_dir.mkdir() + recipe_path = workflow_dir / "recipe.yaml" + + # Create minimal valid YAML + recipe_content = { + "metadata": { + "version": "v1", + "name": "test", + "build_number": 0, + "description": "d", + "tag": "t", + "authors": "a", + "email": "a@b.com", + "biotargets": [], + "tags": [], + }, + "data": { + "type": "csv", + "resource": "{{ ANVIL_DIR }}/data.csv", + "input_col": "s", + "target_cols": "t", + }, + "procedure": { + "split": {"type": "RandomSplitter"}, + "feat": {"type": "FingerprintFeaturizer"}, + "model": {"type": "LGBMRegressorModel"}, + "train": {"type": "SKLearnBasicTrainer"}, + }, + "report": {"eval": []}, + } + + with open(recipe_path, "w") as f: + yaml.dump(recipe_content, f) + + spec = AnvilSpecification.from_recipe(recipe_path) + # The resolved path should contain the temp dir path (fsspec adds file:// scheme) + expected_path = (workflow_dir / "data.csv").as_uri() + assert spec.data.resource == expected_path + + +def test_anvilspecification_to_multi_yaml_from_multi_yaml_roundtrip(tmp_path): + """Test splitting spec into multiple YAMLs and reloading.""" + meta = Metadata( + version="v1", + name="test", + build_number=0, + description="d", + tag="t", + authors="a", + email="a@b.com", + biotargets=[], + tags=[], + ) + data = DataSpec(type="csv", resource="data.csv", input_col="s", target_cols="t") + proc = ProcedureSpec( + split=SplitSpec(type="RandomSplitter"), + feat=FeatureSpec(type="FingerprintFeaturizer"), + model=ModelSpec(type="LGBMRegressorModel"), + train=TrainerSpec(type="SKLearnBasicTrainer"), + ) + report = ReportSpec(eval=[]) + + spec = AnvilSpecification(metadata=meta, data=data, procedure=proc, report=report) + + spec.to_multi_yaml( + metadata_yaml=tmp_path / "meta.yaml", + procedure_yaml=tmp_path / "proc.yaml", + data_yaml=tmp_path / "data.yaml", + report_yaml=tmp_path / "eval.yaml", + ) + + assert (tmp_path / "meta.yaml").exists() + assert (tmp_path / "proc.yaml").exists() + + reloaded = AnvilSpecification.from_multi_yaml( + metadata_yaml=tmp_path / "meta.yaml", + procedure_yaml=tmp_path / "proc.yaml", + data_yaml=tmp_path / "data.yaml", + report_yaml=tmp_path / "eval.yaml", + ) + + assert reloaded.metadata.name == spec.metadata.name + assert reloaded.data.resource == spec.data.resource + + +def test_anvilspecification_to_workflow_returns_correct_driver_type(mocker): + """Test that to_workflow returns correct workflow class based on trainer driver.""" + + def make_spec(trainer_type, feat_params=None): + return AnvilSpecification( + metadata=Metadata( + version="v1", + name="t", + build_number=0, + description="d", + tag="t", + authors="a", + email="a@b.com", + biotargets=[], + tags=[], + ), + data=DataSpec(type="csv", resource="d.csv", input_col="s", target_cols="t"), + procedure=ProcedureSpec( + split=SplitSpec(type="ShuffleSplitter"), + feat=FeatureSpec( + type="FingerprintFeaturizer", + params=feat_params or {"fp_type": "ecfp:4"}, + ), + model=ModelSpec(type="LGBMRegressorModel"), + train=TrainerSpec(type=trainer_type), + ), + report=ReportSpec(eval=[]), + ) + + # Case 1: SKLEARN driver — use real registered types; no mocking needed + spec_sklearn = make_spec("SKLearnBasicTrainer") + workflow_sklearn = spec_sklearn.to_workflow() + assert isinstance(workflow_sklearn, AnvilWorkflow) + + # Case 2: LIGHTNING driver — mock section.to_class() at class level since no DL model is registered + from openadmet.models.drivers import DriverType as _DriverType + from openadmet.models.trainer.lightning import LightningTrainer as _LightningTrainer + + dl_model = mocker.create_autospec(LightningModelBase, instance=True) + dl_model._n_tasks = 1 + dl_model._driver_type = _DriverType.LIGHTNING + dl_trainer = mocker.create_autospec(_LightningTrainer, instance=True) + dl_trainer._driver_type = _DriverType.LIGHTNING + + spec_dl = make_spec("LightningTrainer") + + # Patch only model and trainer to_class; split/feat use real registered types + mocker.patch.object(ModelSpec, "to_class", autospec=True, return_value=dl_model) + mocker.patch.object(TrainerSpec, "to_class", autospec=True, return_value=dl_trainer) + + workflow_dl = spec_dl.to_workflow() + assert isinstance(workflow_dl, AnvilDeepLearningWorkflow) + assert workflow_dl.model_kwargs == { + "param_path": None, + "serial_path": None, + "freeze_weights": None, + } + assert workflow_dl.feat_kwargs == { + "type": "FingerprintFeaturizer", + "params": {"fp_type": "ecfp:4"}, + } + + +def test_anvilspecification_run_writes_provenance_to_resolved_output_dir( + tmp_path, mocker +): + """Test that run() writes the recipe to the output directory.""" + spec = AnvilSpecification( + metadata=Metadata( + version="v1", + name="t", + build_number=0, + description="d", + tag="tag_original", + authors="a", + email="a@b.com", + biotargets=[], + tags=[], + ), + data=DataSpec(type="csv", resource="d.csv", input_col="s", target_cols="t"), + procedure=ProcedureSpec( + split=SplitSpec(type="S"), + feat=FeatureSpec(type="F"), + model=ModelSpec(type="M"), + train=TrainerSpec(type="SKLearnBasicTrainer"), + ), + report=ReportSpec(eval=[]), + ) + + # Mock workflow run to avoid real execution + mock_workflow = mocker.Mock() + mock_workflow.resolved_output_dir = tmp_path / "resolved" + mock_workflow.run.return_value = None + + mocker.patch.object( + AnvilSpecification, "to_workflow", autospec=True, return_value=mock_workflow + ) + + spec.run(output_dir=tmp_path / "out") + + # Check that provenance files were written + assert (tmp_path / "resolved" / "anvil_recipe.yaml").exists() + assert (tmp_path / "resolved" / "recipe_components" / "metadata.yaml").exists() + + +def test_anvilspecification_run_tag_override(tmp_path, mocker): + """Test that providing a tag to run() overrides the metadata tag in provenance.""" + spec = AnvilSpecification( + metadata=Metadata( + version="v1", + name="t", + build_number=0, + description="d", + tag="tag_original", + authors="a", + email="a@b.com", + biotargets=[], + tags=[], + ), + data=DataSpec(type="csv", resource="d.csv", input_col="s", target_cols="t"), + procedure=ProcedureSpec( + split=SplitSpec(type="S"), + feat=FeatureSpec(type="F"), + model=ModelSpec(type="M"), + train=TrainerSpec(type="SKLearnBasicTrainer"), + ), + report=ReportSpec(eval=[]), + ) + + mock_workflow = mocker.Mock() + mock_workflow.resolved_output_dir = tmp_path / "resolved" + mocker.patch.object( + AnvilSpecification, "to_workflow", autospec=True, return_value=mock_workflow + ) + + spec.run(output_dir=tmp_path / "out", tag="new_tag") + + # Check the saved yaml has the new tag + saved_yaml = tmp_path / "resolved" / "anvil_recipe.yaml" + with open(saved_yaml) as f: + saved_data = yaml.safe_load(f) + assert saved_data["metadata"]["tag"] == "new_tag" + + # Ensure original object is not mutated + assert spec.metadata.tag == "tag_original" + + +# --- DataSpec format/catalog tests (Refinement 5) --- + + +def test_dataspec_read_single_resource_yaml_raises_without_cat_entry(tmp_path): + """Test that reading a YAML resource without cat_entry raises ValueError.""" + yaml_path = tmp_path / "catalog.yaml" + yaml_path.write_text("sources: {}\n") + + spec = DataSpec( + type="yaml", + input_col="smiles", + target_cols="target", + resource=str(yaml_path), + ) + with pytest.raises(ValueError, match="cat_entry must be specified"): + spec.read() + + +def test_dataspec_read_single_resource_parquet(tmp_path): + """Test reading a single Parquet resource returns correct data.""" + pq_path = tmp_path / "data.parquet" + df = pd.DataFrame( + { + "smiles": ["CCO", "CC(C)O", "c1ccccc1"], + "activity": [0.1, 0.5, 0.9], + } + ) + df.to_parquet(pq_path, index=False) + + spec = DataSpec( + type="parquet", + input_col="smiles", + target_cols="activity", + resource=str(pq_path), + ) + X, y = spec.read() + + assert len(X) == 3 + assert len(y) == 3 + assert list(y.columns) == ["activity"] + assert X.iloc[0] == "CCO" + assert y.iloc[0, 0] == pytest.approx(0.1) + + +def test_dataspec_read_single_resource_unsupported_extension(): + """Test that reading a resource with unsupported extension raises ValueError.""" + spec = DataSpec( + type="json", + input_col="smiles", + target_cols="target", + resource="/some/file.json", + ) + with pytest.raises(ValueError, match="Unsupported resource type"): + spec.read() + + +def test_dataspec_read_train_test_yaml_raises(): + """Test that YAML resources raise ValueError for train/test split reads.""" + spec = DataSpec( + type="yaml", + input_col="smiles", + target_cols="target", + train_resource="data.yaml", + test_resource="data2.yaml", + ) + with pytest.raises(ValueError, match="YAML catalogs not supported"): + spec.read() + + +# --- ModelSpec freeze_weights tests (Refinement 6) --- + + +def test_modelspec_freeze_weights_succeeds_when_supported(mocker): + """Test ModelSpec instantiates without error when freeze_weights is supported.""" + mock_model = mocker.create_autospec(LightningModelBase, instance=True) + mock_model.build.return_value = None + mock_model.freeze_weights.return_value = None + + mocker.patch.object(ModelSpec, "to_class", autospec=True, return_value=mock_model) + + spec = ModelSpec(type="SomeModel", freeze_weights={"layer": "encoder"}) + assert spec is not None + mock_model.build.assert_called_once() + mock_model.freeze_weights.assert_called_once() + + +def test_modelspec_freeze_weights_raises_when_not_implemented(mocker): + """Test ModelSpec raises ValueError when freeze_weights is not implemented.""" + mock_model = mocker.create_autospec(LightningModelBase, instance=True) + mock_model.build.return_value = None + mock_model.freeze_weights.side_effect = NotImplementedError("not implemented") + + mocker.patch.object(ModelSpec, "to_class", autospec=True, return_value=mock_model) + + with pytest.raises(ValueError, match="Weight freezing not implemented"): + ModelSpec(type="SomeModel", freeze_weights={"layer": "encoder"}) diff --git a/openadmet/models/tests/unit/anvil/test_workflow.py b/openadmet/models/tests/unit/anvil/test_workflow.py new file mode 100644 index 00000000..92c87428 --- /dev/null +++ b/openadmet/models/tests/unit/anvil/test_workflow.py @@ -0,0 +1,577 @@ +"""Unit tests for anvil/workflow.py — utility functions, class instantiation, Pydantic validators, and driver routing. + +Scope is intentionally limited to construction-time behavior. No `.run()`, `_train()`, or execution +paths are exercised here; those belong in integration tests. +""" + +import numpy as np +import pandas as pd +import pytest + +from openadmet.models.active_learning.committee import CommitteeRegressor +from openadmet.models.anvil.specification import DataSpec, Metadata +from openadmet.models.anvil.workflow import ( + _DRIVER_TO_CLASS, + AnvilDeepLearningWorkflow, + AnvilWorkflow, + _safe_to_numpy, +) +from openadmet.models.architecture.chemprop import ChemPropModel +from openadmet.models.architecture.dummy import DummyRegressorModel +from openadmet.models.drivers import DriverType +from openadmet.models.features.molfeat_fingerprint import FingerprintFeaturizer +from openadmet.models.split.sklearn import ShuffleSplitter +from openadmet.models.trainer.lightning import LightningTrainer +from openadmet.models.trainer.sklearn import SKlearnBasicTrainer +from openadmet.models.transforms.impute import ImputeTransform + +# --------------------------------------------------------------------------- +# Module-scoped fixtures — constructed once per test session for performance +# --------------------------------------------------------------------------- + + +@pytest.fixture(scope="module") +def metadata(): + """Return a minimal real Metadata instance.""" + return Metadata( + version="v1", + driver="sklearn", + name="test-workflow", + build_number=0, + description="Unit test workflow", + tag="test-tag", + authors="Test Author", + email="test@example.com", + biotargets=["target1"], + tags=["unit-test"], + ) + + +@pytest.fixture(scope="module") +def data_spec(): + """Return a minimal real DataSpec instance with one target column.""" + return DataSpec( + type="csv", + input_col="smiles", + target_cols=["target"], + resource="data.csv", + ) + + +@pytest.fixture(scope="module") +def sklearn_feat(): + """Return a real FingerprintFeaturizer using ECFP4 (RDKit-only, no downloads).""" + return FingerprintFeaturizer(fp_type="ecfp:4") + + +# --------------------------------------------------------------------------- +# Factory helpers — build workflows from fully real Pydantic components +# --------------------------------------------------------------------------- + + +def _make_anvil_workflow( + metadata, + data_spec, + feat, + *, + split=None, + ensemble=None, + model_kwargs=None, + ensemble_kwargs=None, + feat_kwargs=None, + transform=None, +): + """Construct an AnvilWorkflow from real lightweight production components.""" + if split is None: + split = ShuffleSplitter(train_size=0.8, val_size=0.0, test_size=0.2) + return AnvilWorkflow( + metadata=metadata, + data_spec=data_spec, + split=split, + feat=feat, + model=DummyRegressorModel(), + trainer=SKlearnBasicTrainer(), + evals=[], + ensemble=ensemble, + transform=transform, + model_kwargs=model_kwargs or {}, + ensemble_kwargs=ensemble_kwargs or {}, + feat_kwargs=feat_kwargs or {}, + ) + + +def _make_dl_workflow( + metadata, + data_spec, + feat, + *, + split=None, + ensemble=None, + transform=None, + model_kwargs=None, + ensemble_kwargs=None, +): + """Construct an AnvilDeepLearningWorkflow from real lightweight production components.""" + if split is None: + split = ShuffleSplitter(train_size=0.8, val_size=0.0, test_size=0.2) + return AnvilDeepLearningWorkflow( + metadata=metadata, + data_spec=data_spec, + split=split, + feat=feat, + model=ChemPropModel(), + trainer=LightningTrainer(), + evals=[], + ensemble=ensemble, + transform=transform, + model_kwargs=model_kwargs or {}, + ensemble_kwargs=ensemble_kwargs or {}, + ) + + +# --------------------------------------------------------------------------- +# Section 1: _safe_to_numpy utility +# --------------------------------------------------------------------------- + + +def test_safe_to_numpy_series(): + """Test that _safe_to_numpy converts a pd.Series to a np.ndarray with correct values.""" + s = pd.Series([1.0, 2.0, 3.0]) + result = _safe_to_numpy(s) + assert isinstance(result, np.ndarray) + assert result.shape == (3,) + np.testing.assert_array_equal(result, np.array([1.0, 2.0, 3.0])) + + +def test_safe_to_numpy_dataframe(): + """Test that _safe_to_numpy converts a pd.DataFrame to a np.ndarray with correct shape and values.""" + df = pd.DataFrame({"a": [1, 2], "b": [3, 4]}) + result = _safe_to_numpy(df) + assert isinstance(result, np.ndarray) + assert result.shape == (2, 2) + np.testing.assert_array_equal(result, df.to_numpy()) + + +def test_safe_to_numpy_ndarray_passthrough(): + """Test that _safe_to_numpy returns a np.ndarray unchanged via identity check.""" + arr = np.array([1.0, 2.0, 3.0]) + result = _safe_to_numpy(arr) + assert result is arr + + +# --------------------------------------------------------------------------- +# Section 2: _DRIVER_TO_CLASS routing dictionary +# --------------------------------------------------------------------------- + + +def test_driver_to_class_sklearn_routes_to_anvil_workflow(): + """Test that DriverType.SKLEARN maps to AnvilWorkflow.""" + assert _DRIVER_TO_CLASS[DriverType.SKLEARN] is AnvilWorkflow + + +def test_driver_to_class_lightning_routes_to_dl_workflow(): + """Test that DriverType.LIGHTNING maps to AnvilDeepLearningWorkflow.""" + assert _DRIVER_TO_CLASS[DriverType.LIGHTNING] is AnvilDeepLearningWorkflow + + +def test_driver_to_class_has_exactly_two_entries(): + """Test that _DRIVER_TO_CLASS contains exactly the two expected driver keys.""" + assert set(_DRIVER_TO_CLASS.keys()) == {DriverType.SKLEARN, DriverType.LIGHTNING} + + +# --------------------------------------------------------------------------- +# Section 3: AnvilWorkflow happy-path construction +# --------------------------------------------------------------------------- + + +def test_anvil_workflow_constructs_with_real_components( + metadata, data_spec, sklearn_feat +): + """Test that AnvilWorkflow can be constructed from real lightweight registered components.""" + wf = _make_anvil_workflow(metadata, data_spec, sklearn_feat) + assert isinstance(wf, AnvilWorkflow) + + +def test_anvil_workflow_driver_type_is_sklearn(metadata, data_spec, sklearn_feat): + """Test that AnvilWorkflow correctly exposes the SKLEARN driver type.""" + wf = _make_anvil_workflow(metadata, data_spec, sklearn_feat) + assert wf._driver_type == DriverType.SKLEARN + + +# --------------------------------------------------------------------------- +# Section 4: AnvilWorkflow.check_if_val_needed validator +# --------------------------------------------------------------------------- + + +def test_anvil_workflow_ensemble_without_val_raises(metadata, data_spec, sklearn_feat): + """Test that constructing an ensemble AnvilWorkflow without a validation split raises ValueError.""" + split = ShuffleSplitter(train_size=0.8, val_size=0.0, test_size=0.2) + with pytest.raises(ValueError, match="Ensemble models require a validation set"): + _make_anvil_workflow( + metadata, + data_spec, + sklearn_feat, + split=split, + ensemble=CommitteeRegressor(), + ensemble_kwargs={"n_models": 2}, + ) + + +def test_anvil_workflow_val_without_ensemble_raises(metadata, data_spec, sklearn_feat): + """Test that requesting a validation split without an ensemble raises ValueError.""" + split = ShuffleSplitter(train_size=0.7, val_size=0.1, test_size=0.2) + with pytest.raises(ValueError, match="Validation set requested, but not used"): + _make_anvil_workflow( + metadata, data_spec, sklearn_feat, split=split, ensemble=None + ) + + +def test_anvil_workflow_ensemble_with_val_succeeds(metadata, data_spec, sklearn_feat): + """Test that an ensemble AnvilWorkflow with a validation split constructs without error.""" + split = ShuffleSplitter(train_size=0.7, val_size=0.1, test_size=0.2) + wf = _make_anvil_workflow( + metadata, + data_spec, + sklearn_feat, + split=split, + ensemble=CommitteeRegressor(), + ensemble_kwargs={"n_models": 2}, + ) + assert isinstance(wf, AnvilWorkflow) + assert wf.ensemble is not None + + +# --------------------------------------------------------------------------- +# Section 5: AnvilWorkflow.check_no_finetuning validator +# --------------------------------------------------------------------------- + + +# Single-model branch: all triggering combinations of path kwargs +@pytest.mark.parametrize( + "model_kwargs", + [ + {"param_path": "model.json"}, + {"serial_path": "model.pkl"}, + {"param_path": "model.json", "serial_path": "model.pkl"}, + ], + ids=["param-path-only", "serial-path-only", "both-paths"], +) +def test_anvil_workflow_single_model_finetuning_raises( + metadata, data_spec, sklearn_feat, model_kwargs +): + """Test that any finetuning path kwarg for a single model raises ValueError.""" + with pytest.raises( + ValueError, match="Finetuning from serialized model is not supported" + ): + _make_anvil_workflow( + metadata, data_spec, sklearn_feat, model_kwargs=model_kwargs + ) + + +# Single-model branch: safe kwargs that must never trigger the validator +@pytest.mark.parametrize( + "model_kwargs", + [ + {}, + {"n_estimators": 100}, + ], + ids=["empty-kwargs", "unrelated-key"], +) +def test_anvil_workflow_single_model_no_finetuning_succeeds( + metadata, data_spec, sklearn_feat, model_kwargs +): + """Test that empty or unrelated model_kwargs do not trigger the finetuning validator.""" + wf = _make_anvil_workflow( + metadata, data_spec, sklearn_feat, model_kwargs=model_kwargs + ) + assert isinstance(wf, AnvilWorkflow) + assert wf.model_kwargs == model_kwargs + + +# Ensemble branch: all triggering combinations of path kwargs +@pytest.mark.parametrize( + "path_kwargs", + [ + {"param_paths": ["p1.json", "p2.json"]}, + {"serial_paths": ["s1.pkl", "s2.pkl"]}, + {"param_paths": ["p1.json", "p2.json"], "serial_paths": ["s1.pkl", "s2.pkl"]}, + ], + ids=["param-paths-only", "serial-paths-only", "both-path-types"], +) +def test_anvil_workflow_ensemble_finetuning_raises( + metadata, data_spec, sklearn_feat, path_kwargs +): + """Test that any finetuning path kwarg for an ensemble raises ValueError.""" + split = ShuffleSplitter(train_size=0.7, val_size=0.1, test_size=0.2) + ensemble_kwargs = {"n_models": 2, **path_kwargs} + with pytest.raises( + ValueError, match="Finetuning from serialized ensemble models is not supported" + ): + _make_anvil_workflow( + metadata, + data_spec, + sklearn_feat, + split=split, + ensemble=CommitteeRegressor(), + ensemble_kwargs=ensemble_kwargs, + ) + + +# Ensemble branch: non-path ensemble_kwargs that must never trigger the validator +@pytest.mark.parametrize( + "ensemble_kwargs", + [ + {"n_models": 2}, + {"n_models": 2, "calibration_method": "isotonic-regression"}, + ], + ids=["n-models-only", "with-calibration-method"], +) +def test_anvil_workflow_ensemble_no_finetuning_succeeds( + metadata, data_spec, sklearn_feat, ensemble_kwargs +): + """Test that ensemble_kwargs containing only non-path keys do not trigger the finetuning validator.""" + split = ShuffleSplitter(train_size=0.7, val_size=0.1, test_size=0.2) + wf = _make_anvil_workflow( + metadata, + data_spec, + sklearn_feat, + split=split, + ensemble=CommitteeRegressor(), + ensemble_kwargs=ensemble_kwargs, + ) + assert isinstance(wf, AnvilWorkflow) + assert wf.ensemble_kwargs == ensemble_kwargs + + +# feat_kwargs default_factory coverage — any content must pass construction +@pytest.mark.parametrize( + "feat_kwargs", + [ + {}, + {"type": "FingerprintFeaturizer", "params": {"fp_type": "ecfp:4"}}, + ], + ids=["empty-feat-kwargs", "with-type-and-params"], +) +def test_anvil_workflow_feat_kwargs_passthrough( + metadata, data_spec, sklearn_feat, feat_kwargs +): + """Test that arbitrary feat_kwargs content does not affect workflow construction.""" + wf = _make_anvil_workflow( + metadata, data_spec, sklearn_feat, feat_kwargs=feat_kwargs + ) + assert isinstance(wf, AnvilWorkflow) + assert wf.feat_kwargs == feat_kwargs + + +# --------------------------------------------------------------------------- +# Section 6: AnvilDeepLearningWorkflow happy-path construction +# --------------------------------------------------------------------------- + + +def test_dl_workflow_constructs_with_real_components(metadata, data_spec, sklearn_feat): + """Test that AnvilDeepLearningWorkflow can be constructed from real lightweight registered components.""" + wf = _make_dl_workflow(metadata, data_spec, sklearn_feat) + assert isinstance(wf, AnvilDeepLearningWorkflow) + + +def test_dl_workflow_driver_type_is_lightning(metadata, data_spec, sklearn_feat): + """Test that AnvilDeepLearningWorkflow correctly exposes the LIGHTNING driver type.""" + wf = _make_dl_workflow(metadata, data_spec, sklearn_feat) + assert wf._driver_type == DriverType.LIGHTNING + + +# --------------------------------------------------------------------------- +# Section 7: AnvilDeepLearningWorkflow.check_no_transform validator +# --------------------------------------------------------------------------- + + +def test_dl_workflow_rejects_transform(metadata, data_spec, sklearn_feat): + """Test that specifying a transform step in a DL workflow raises ValueError.""" + with pytest.raises(ValueError, match="Transform step is not supported"): + _make_dl_workflow( + metadata, data_spec, sklearn_feat, transform=ImputeTransform() + ) + + +def test_dl_workflow_accepts_no_transform(metadata, data_spec, sklearn_feat): + """Test that a DL workflow without a transform step constructs successfully.""" + wf = _make_dl_workflow(metadata, data_spec, sklearn_feat, transform=None) + assert wf.transform is None + + +# --------------------------------------------------------------------------- +# Section 8: AnvilDeepLearningWorkflow.check_if_val_needed validator +# --------------------------------------------------------------------------- + + +def test_dl_workflow_ensemble_requires_val_raises(metadata, data_spec, sklearn_feat): + """Test that a DL ensemble workflow without a validation split raises ValueError.""" + split = ShuffleSplitter(train_size=0.8, val_size=0.0, test_size=0.2) + with pytest.raises(ValueError, match="Ensemble models require a validation set"): + _make_dl_workflow( + metadata, + data_spec, + sklearn_feat, + split=split, + ensemble=CommitteeRegressor(), + ) + + +def test_dl_workflow_ensemble_with_val_succeeds(metadata, data_spec, sklearn_feat): + """Test that a DL ensemble workflow with a validation split constructs successfully.""" + split = ShuffleSplitter(train_size=0.7, val_size=0.1, test_size=0.2) + wf = _make_dl_workflow( + metadata, + data_spec, + sklearn_feat, + split=split, + ensemble=CommitteeRegressor(), + ) + assert isinstance(wf, AnvilDeepLearningWorkflow) + assert wf.ensemble is not None + + +# --------------------------------------------------------------------------- +# Section 9: AnvilDeepLearningWorkflow.check_finetuning_paths validator +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize( + "model_kwargs,match", + [ + ({"param_path": "/nonexistent/model.json"}, "must be provided together"), + ({"serial_path": "/nonexistent/model.pth"}, "must be provided together"), + ( + { + "param_path": "/nonexistent/model.json", + "serial_path": "/nonexistent/model.pth", + }, + "does not exist", + ), + ], + ids=["param-path-only", "serial-path-only", "both-nonexistent"], +) +def test_dl_workflow_single_model_finetuning_path_raises( + metadata, data_spec, sklearn_feat, model_kwargs, match +): + """Test that mismatched or nonexistent single-model finetuning paths raise ValueError.""" + with pytest.raises(ValueError, match=match): + _make_dl_workflow(metadata, data_spec, sklearn_feat, model_kwargs=model_kwargs) + + +def test_dl_workflow_single_model_finetuning_path_succeeds_no_paths( + metadata, data_spec, sklearn_feat +): + """Test that empty model_kwargs passes finetuning path validation.""" + wf = _make_dl_workflow(metadata, data_spec, sklearn_feat, model_kwargs={}) + assert isinstance(wf, AnvilDeepLearningWorkflow) + assert wf.model_kwargs == {} + + +def test_dl_workflow_single_model_finetuning_path_succeeds_both_exist( + metadata, data_spec, sklearn_feat, tmp_path +): + """Test that both finetuning paths pointing to real files passes validation.""" + param_file = tmp_path / "model.json" + serial_file = tmp_path / "model.pth" + param_file.touch() + serial_file.touch() + + model_kwargs = {"param_path": str(param_file), "serial_path": str(serial_file)} + wf = _make_dl_workflow(metadata, data_spec, sklearn_feat, model_kwargs=model_kwargs) + assert isinstance(wf, AnvilDeepLearningWorkflow) + assert wf.model_kwargs == model_kwargs + + +@pytest.mark.parametrize( + "path_kwargs,match", + [ + ( + {"param_paths": ["/nonexistent/p1.json", "/nonexistent/p2.json"]}, + "must be provided together", + ), + ( + {"serial_paths": ["/nonexistent/s1.pth", "/nonexistent/s2.pth"]}, + "must be provided together", + ), + ( + { + "param_paths": ["/nonexistent/p1.json", "/nonexistent/p2.json"], + "serial_paths": ["/nonexistent/s1.pth"], + }, + "equal length", + ), + ( + { + "param_paths": ["/nonexistent/p1.json", "/nonexistent/p2.json"], + "serial_paths": ["/nonexistent/s1.pth", "/nonexistent/s2.pth"], + }, + "does not exist", + ), + ], + ids=[ + "param-paths-only", + "serial-paths-only", + "unequal-lengths", + "both-nonexistent", + ], +) +def test_dl_workflow_ensemble_finetuning_path_raises( + metadata, data_spec, sklearn_feat, path_kwargs, match +): + """Test that mismatched, unequal-length, or nonexistent ensemble finetuning paths raise ValueError.""" + split = ShuffleSplitter(train_size=0.7, val_size=0.1, test_size=0.2) + ensemble_kwargs = {"n_models": 2, **path_kwargs} + with pytest.raises(ValueError, match=match): + _make_dl_workflow( + metadata, + data_spec, + sklearn_feat, + split=split, + ensemble=CommitteeRegressor(), + ensemble_kwargs=ensemble_kwargs, + ) + + +def test_dl_workflow_ensemble_finetuning_path_succeeds_no_paths( + metadata, data_spec, sklearn_feat +): + """Test that ensemble_kwargs with no path keys passes finetuning path validation.""" + split = ShuffleSplitter(train_size=0.7, val_size=0.1, test_size=0.2) + wf = _make_dl_workflow( + metadata, + data_spec, + sklearn_feat, + split=split, + ensemble=CommitteeRegressor(), + ensemble_kwargs={"n_models": 2}, + ) + assert isinstance(wf, AnvilDeepLearningWorkflow) + assert wf.ensemble is not None + + +def test_dl_workflow_ensemble_finetuning_path_succeeds_both_exist( + metadata, data_spec, sklearn_feat, tmp_path +): + """Test that ensemble finetuning paths pointing to real files passes validation.""" + p1, p2 = tmp_path / "m0.json", tmp_path / "m1.json" + s1, s2 = tmp_path / "m0.pth", tmp_path / "m1.pth" + for f in [p1, p2, s1, s2]: + f.touch() + + split = ShuffleSplitter(train_size=0.7, val_size=0.1, test_size=0.2) + ensemble_kwargs = { + "n_models": 2, + "param_paths": [str(p1), str(p2)], + "serial_paths": [str(s1), str(s2)], + } + wf = _make_dl_workflow( + metadata, + data_spec, + sklearn_feat, + split=split, + ensemble=CommitteeRegressor(), + ensemble_kwargs=ensemble_kwargs, + ) + assert isinstance(wf, AnvilDeepLearningWorkflow) + assert wf.ensemble_kwargs == ensemble_kwargs diff --git a/openadmet/models/tests/unit/anvil/test_workflow_base.py b/openadmet/models/tests/unit/anvil/test_workflow_base.py new file mode 100644 index 00000000..4c8c361b --- /dev/null +++ b/openadmet/models/tests/unit/anvil/test_workflow_base.py @@ -0,0 +1,145 @@ +import pytest +from pydantic import ConfigDict + +from openadmet.models.active_learning.committee import CommitteeRegressor +from openadmet.models.anvil.specification import DataSpec, Metadata +from openadmet.models.anvil.workflow_base import AnvilWorkflowBase +from openadmet.models.architecture.dummy import DummyRegressorModel +from openadmet.models.drivers import DriverType +from openadmet.models.eval.cross_validation import ( + PytorchLightningRepeatedKFoldCrossValidation, + SKLearnRepeatedKFoldCrossValidation, +) +from openadmet.models.eval.regression import RegressionMetrics +from openadmet.models.features.molfeat_fingerprint import FingerprintFeaturizer +from openadmet.models.split.sklearn import ShuffleSplitter +from openadmet.models.trainer.lightning import LightningTrainer +from openadmet.models.trainer.sklearn import SKlearnBasicTrainer + + +# Concrete workflow used to test the abstract base validation logic +class ConcreteWorkflow(AnvilWorkflowBase): + model_config = ConfigDict(arbitrary_types_allowed=True) + + def run(self, output_dir="anvil_training", debug=False): + return "ran" + + +# Minimal metadata for testing +def get_minimal_metadata(): + return Metadata( + version="v1", + driver="sklearn", + name="test", + build_number=0, + description="desc", + tag="tag", + authors="auth", + email="a@b.com", + biotargets=[], + tags=[], + ) + + +# Helper to build a workflow with real lightweight components as defaults +def build_workflow( + *, + model=None, + trainer=None, + evals=None, + ensemble=None, + target_cols=["target"], +): + if model is None: + model = DummyRegressorModel() + if trainer is None: + trainer = SKlearnBasicTrainer() + if evals is None: + evals = [RegressionMetrics()] + return ConcreteWorkflow( + metadata=get_minimal_metadata(), + data_spec=DataSpec( + type="csv", + input_col="smiles", + target_cols=target_cols, + resource="data.csv", + ), + split=ShuffleSplitter(), + feat=FingerprintFeaturizer(fp_type="ecfp"), + model=model, + trainer=trainer, + evals=evals, + ensemble=ensemble, + ) + + +# --- Tests --- + + +def test_multitask_check_passes_when_counts_match(): + """Test that validation passes when model n_tasks matches data target_cols.""" + model = DummyRegressorModel() + model._n_tasks = 2 + workflow = build_workflow(model=model, target_cols=["t1", "t2"]) + assert workflow + + +def test_multitask_check_raises_when_counts_mismatch(): + """Test that validation raises ValueError when n_tasks does not match target_cols.""" + model = DummyRegressorModel() + model._n_tasks = 2 + with pytest.raises(ValueError, match="tasks but the data specification has"): + build_workflow(model=model, target_cols=["t1", "t2", "t3"]) + + +def test_no_ensemble_cross_val_raises_when_both_present(): + """Test that using ensemble with cross-validation raises ValueError.""" + with pytest.raises( + ValueError, match="Ensemble models cannot be used with cross-validation" + ): + build_workflow( + ensemble=CommitteeRegressor(), + evals=[SKLearnRepeatedKFoldCrossValidation()], + ) + + +def test_no_ensemble_cross_val_allows_cv_without_ensemble(): + """Test that cross-validation is allowed if no ensemble is present.""" + workflow = build_workflow( + evals=[SKLearnRepeatedKFoldCrossValidation()], ensemble=None + ) + assert workflow + + +def test_model_trainer_driver_mismatch_raises(): + """Test that mismatched model and trainer drivers raise ValueError.""" + with pytest.raises(ValueError, match="Model driver type .* does not match trainer"): + build_workflow(trainer=LightningTrainer()) + + +def test_model_trainer_driver_match_succeeds(): + """Test that matching model and trainer drivers succeed.""" + workflow = build_workflow( + model=DummyRegressorModel(), trainer=SKlearnBasicTrainer() + ) + assert workflow + + +def test_cv_trainer_compatibility_raises_on_driver_mismatch(): + """Test that a CV evaluator with a mismatched trainer driver raises ValueError.""" + with pytest.raises( + ValueError, match="Trainer driver type .* does not match evaluation" + ): + build_workflow( + trainer=SKlearnBasicTrainer(), + evals=[PytorchLightningRepeatedKFoldCrossValidation()], + ) + + +def test_cv_trainer_compatibility_ignores_non_cv_evals(): + """Test that non-CV evaluators do not trigger driver mismatch checks.""" + workflow = build_workflow( + trainer=SKlearnBasicTrainer(), + evals=[RegressionMetrics()], + ) + assert workflow diff --git a/openadmet/models/tests/unit/cli/test_cli.py b/openadmet/models/tests/unit/cli/test_cli.py index 73da2cec..8e430025 100644 --- a/openadmet/models/tests/unit/cli/test_cli.py +++ b/openadmet/models/tests/unit/cli/test_cli.py @@ -1,59 +1,87 @@ -from openadmet.models.cli.cli import cli -from openadmet.models.tests.test_utils import click_success -from openadmet.models.tests.unit.datafiles import ( - anvil_lgbm_trained_model_dir, - pred_test_data_csv, - basic_anvil_yaml_cv, -) +from pathlib import Path + import pytest from click.testing import CliRunner +from openadmet.models.cli import anvil as anvil_cli_module +from openadmet.models.cli import predict as predict_cli_module +from openadmet.models.cli.cli import cli +from openadmet.models.tests.test_utils import click_success +from openadmet.models.tests.unit.datafiles import basic_anvil_yaml_cv + + +@pytest.fixture +def runner(): + """Provide a Click CliRunner for testing CLI commands in isolation.""" + return CliRunner() -def test_toplevel_runnable(): - """Test the top-level CLI command""" - runner = CliRunner() + +def test_toplevel_runnable(runner): + """Ensure the top-level 'openadmet' command runs and displays help without error.""" result = runner.invoke(cli, ["--help"]) assert click_success(result) @pytest.mark.parametrize( - "args", - [ - ["anvil", "--help"], - ["compare", "--help"], - ["predict", "--help"], - ], + "args", [["anvil", "--help"], ["compare", "--help"], ["predict", "--help"]] ) -def test_subcommand_runnable(args): - """Test the subcommands""" - runner = CliRunner() +def test_subcommand_runnable(runner, args): + """Verify that all major subcommands (anvil, compare, predict) are registered and runnable.""" result = runner.invoke(cli, args) assert click_success(result) -def test_predict_cli(tmp_path): - """Test the predict CLI command""" - runner = CliRunner() +def test_predict_cli_invokes_inference(tmp_path, runner, mocker): + """ + Validate that the 'predict' subcommand correctly parses arguments and calls the underlying inference function. + + We mock `inference_func` to avoid loading real models (which is heavy and requires trained artifacts). + This ensures that the CLI layer correctly passes paths, column names, and flags to the logic layer. + """ + input_csv = tmp_path / "input.csv" + input_csv.write_text("MY_SMILES\nCCO\n") + model_dir = tmp_path / "model_dir" + model_dir.mkdir() + + mock_inference = mocker.patch.object(predict_cli_module, "inference_func") + result = runner.invoke( cli, [ "predict", "--input-path", - pred_test_data_csv, + input_csv, "--input-col", "MY_SMILES", "--model-dir", - anvil_lgbm_trained_model_dir, + model_dir, "--output-csv", tmp_path / "predictions.csv", + "--accelerator", + "cpu", ], ) assert click_success(result) + mock_inference.assert_called_once() + called = mock_inference.call_args.kwargs + assert called["input_col"] == "MY_SMILES" + assert called["accelerator"] == "cpu" + assert called["write_csv"] is True + assert list(called["model_dir"]) == [model_dir] -def test_anvil_cli(tmp_path): - """Test the anvil CLI command""" - runner = CliRunner() +def test_anvil_cli_invokes_workflow(tmp_path, runner, mocker): + """ + Validate that the 'anvil' subcommand correctly initializes and runs a workflow from a recipe. + + We mock the `AnvilSpecification` and workflow execution to verify that the CLI correctly handles + recipe paths and output directories without actually running a full ML training job. + """ + mock_spec = mocker.Mock() + mock_from_recipe = mocker.patch.object( + anvil_cli_module.AnvilSpecification, "from_recipe", return_value=mock_spec + ) + result = runner.invoke( cli, [ @@ -66,3 +94,48 @@ def test_anvil_cli(tmp_path): ) assert click_success(result) + mock_from_recipe.assert_called_once_with(basic_anvil_yaml_cv) + mock_spec.run.assert_called_once() + called = mock_spec.run.call_args.kwargs + assert Path(called["output_dir"]) == tmp_path / "anvil_output" + assert called["debug"] is False + + +@pytest.mark.parametrize( + "aq_fxns,beta,best_y,xi,expected", + [ + (("ucb",), (2.0,), (), (), {"ucb": {"beta": 2.0}}), + ( + ("ei", "pi"), + (), + (1.0, 2.0), + (0.1, 0.2), + {"ei": {"xi": 0.1, "best_y": 1.0}, "pi": {"xi": 0.2, "best_y": 2.0}}, + ), + ], +) +def test_validate_aq_fxns_success(aq_fxns, beta, best_y, xi, expected): + """ + Verify that valid combinations of acquisition function arguments are correctly parsed into a configuration dict. + + This tests the CLI argument validation logic for active learning parameters. + """ + assert predict_cli_module._validate_aq_fxns(aq_fxns, beta, best_y, xi) == expected + + +@pytest.mark.parametrize( + "aq_fxns,beta,best_y,xi,error_message", + [ + (("ucb", "ucb"), (1.0, 2.0), (), (), "UCB can only be specified once"), + (("ei",), (), (), (), "must be specified once per EI and/or PI acquisition"), + (("ucb",), (), (), (), "Field `beta` must be specified for UCB acquisition"), + ], +) +def test_validate_aq_fxns_errors(aq_fxns, beta, best_y, xi, error_message): + """ + Ensure that invalid acquisition function arguments trigger appropriate validation errors. + + This prevents users from running predictions with ambiguous or incomplete active learning settings. + """ + with pytest.raises(ValueError, match=error_message): + predict_cli_module._validate_aq_fxns(aq_fxns, beta, best_y, xi) diff --git a/openadmet/models/tests/unit/comparison/test_comparison.py b/openadmet/models/tests/unit/comparison/test_comparison.py index 27d93900..cd9949cc 100644 --- a/openadmet/models/tests/unit/comparison/test_comparison.py +++ b/openadmet/models/tests/unit/comparison/test_comparison.py @@ -1,26 +1,32 @@ +import numpy as np import pytest from numpy.testing import assert_almost_equal -import numpy as np + from openadmet.models.comparison.compare_base import get_comparison_class from openadmet.models.comparison.posthoc import PostHocComparison from openadmet.models.tests.unit.datafiles import ( - cyp2c9_json, + anvil_lgbm_trained_model_dir, cyp1a2_json, + cyp2c9_json, cyp3a4_json, multi_task_json, - anvil_lgbm_trained_model_dir, ) def test_get_comparison_class(): - """Test getting comparison class.""" + """ + Test dynamic retrieval of comparison classes from the registry. + + Verifies that valid class names return the class and invalid names raise ValueError. + """ get_comparison_class("PostHoc") with pytest.raises(ValueError): get_comparison_class("NotARealClass") def test_posthoc_fails_on_incorrect_inputs(): - """Test that posthoc comparison fails when given incorrect inputs. + """ + Test that posthoc comparison fails when given incorrect inputs. Inputs include: - No inputs @@ -28,6 +34,8 @@ def test_posthoc_fails_on_incorrect_inputs(): - Mismatched lengths of model_stats_fns, labels, and task_names - Repeated labels - Incorrect labels and task_names for model_stats_fns + + This validation is critical to ensure that comparison tables and plots match models to their correct metadata. """ comp_obj = PostHocComparison() with pytest.raises(ValueError): @@ -69,7 +77,12 @@ def test_posthoc_repeat_label_error(): def test_posthoc_comparison(): - """Test that posthoc comparison works when given correct inputs.""" + """ + Test that posthoc comparison works correctly when given valid inputs. + + This verifies the calculation of statistical tests (Levene's test for equality of variances, + Tukey's HSD for pairwise mean differences) based on loaded model metrics. + """ model_stats = [cyp2c9_json, cyp3a4_json, cyp1a2_json] model_tags = [ "openadmet-CYP2C9-pchembl-regression-testing-cv", @@ -97,7 +110,12 @@ def test_posthoc_comparison(): def test_posthoc_comparison_anvil_reader_and_feature_label( label_types, expected_labels ): - """Test that posthoc comparison can read from anvil-trained model directories and features.""" + """ + Test that posthoc comparison can automatically extract labels from anvil-trained model directories. + + This ensures that metadata stored in `metadata.yaml` within model directories can be correctly + parsed to generate readable labels for comparison plots. + """ comp_obj = PostHocComparison() model_stats_fns, labels, task_names = comp_obj.label_and_task_name_from_anvil( model_dirs=[anvil_lgbm_trained_model_dir], label_types=label_types @@ -121,7 +139,12 @@ def test_posthoc_comparison_json_reader_fails(label_types): def test_posthoc_comparison_json_reader(): - """Test that posthoc comparison can read multi vs single task from anvil file.""" + """ + Test that posthoc comparison handles both multi-task and single-task JSON result files. + + This verifies that the system can normalize results from different task types into a common + format for statistical comparison. + """ model_stats = [multi_task_json, cyp3a4_json] model_tags = ["multitask", "single_task"] task_tags = ["cyp3a4_pchembl_value_mean", "pchembl_value_mean"] @@ -130,14 +153,18 @@ def test_posthoc_comparison_json_reader(): levene, tukeys_df = comp_obj.compare( model_stats_fns=model_stats, labels=model_tags, task_names=task_tags ) - assert levene["mse"][0] == 2.483488460351842 - assert levene["ktau"][0] == 1.0392615736603197 - assert tukeys_df["metric_val"][0] == -0.01037444780666702 - assert tukeys_df["pvalue"][0] == 0.2488307785417857 + assert levene["mse"][0] == pytest.approx(2.483, abs=0.001) + assert levene["ktau"][0] == pytest.approx(1.039, abs=0.001) + assert tukeys_df["metric_val"][0] == pytest.approx(-0.010, abs=0.001) + assert tukeys_df["pvalue"][0] == pytest.approx(0.248, abs=0.001) def test_posthoc_comparison_printing(capsys): - """Test that posthoc comparison prints results to console.""" + """ + Test that posthoc comparison prints results to console in a readable format. + + We capture stdout to verify that Levene's test and Tukey's HSD results are actually displayed to the user. + """ model_stats = [cyp2c9_json, cyp3a4_json, cyp1a2_json] model_tags = [ "openadmet-CYP2C9-pchembl-regression-testing-cv", diff --git a/openadmet/models/tests/unit/data/test_data.py b/openadmet/models/tests/unit/data/test_data.py index d714bc7e..fa310a66 100644 --- a/openadmet/models/tests/unit/data/test_data.py +++ b/openadmet/models/tests/unit/data/test_data.py @@ -5,6 +5,12 @@ def test_data_spec_from_csv(): + """ + Validate loading data from a CSV file via DataSpec. + + Ensures that the data loader correctly reads the specified CSV, extracts the target and SMILES columns, + and returns them as expected. + """ data_spec = DataSpec( type="intake", resource=test_csv, @@ -18,6 +24,12 @@ def test_data_spec_from_csv(): def test_data_spec_from_intake(): + """ + Validate loading data from an Intake catalog. + + Intake allows for declarative data loading. This test checks that DataSpec can correctly interface + with an Intake catalog to retrieve data. + """ data_spec = DataSpec( type="intake", resource=intake_cat, @@ -32,6 +44,12 @@ def test_data_spec_from_intake(): @pytest.mark.parametrize("dropna, expected_length", [(True, 3333), (False, 7196)]) def test_data_spec_dropna(dropna, expected_length): + """ + Test the `dropna` functionality in DataSpec. + + Verifies that rows with missing values in target columns are dropped when dropna=True, + and preserved when dropna=False. This is critical for handling real-world datasets which often contain gaps. + """ data_spec = DataSpec( type="intake", resource=nan_data, diff --git a/openadmet/models/tests/unit/eval/test_eval.py b/openadmet/models/tests/unit/eval/test_eval.py index 1e89053e..9de3ef17 100644 --- a/openadmet/models/tests/unit/eval/test_eval.py +++ b/openadmet/models/tests/unit/eval/test_eval.py @@ -1,6 +1,8 @@ +import matplotlib.figure +import numpy as np import pytest +import seaborn as sns -import numpy as np from openadmet.models.eval.binary import PosthocBinaryMetrics from openadmet.models.eval.classification import ( ClassificationMetrics, @@ -11,34 +13,57 @@ def test_get_eval_class(): + """Verify that evaluation classes can be retrieved by name from the registry.""" get_eval_class("RegressionMetrics") get_eval_class("PosthocBinaryMetrics") get_eval_class("ClassificationMetrics") def test_regression_metrics(): + """ + Validate calculation of standard regression metrics (MSE, MAE, R2). + + This test uses simple synthetic data to ensure that the mathematical implementations + of these metrics are correct and return the expected values. + """ y_true = np.array([3, -0.5, 2, 7]).reshape(-1, 1) y_pred = np.array([2.5, 0.0, 2, 8]).reshape(-1, 1) rm = RegressionMetrics() metrics = rm.evaluate(y_true, y_pred) - assert metrics["task_0"]["mse"]["value"] == 0.375 - assert metrics["task_0"]["mae"]["value"] == 0.5 - assert metrics["task_0"]["r2"]["value"] == 0.9486081370449679 + assert metrics["task_0"]["mse"]["value"] == pytest.approx(0.375, abs=0.001) + assert metrics["task_0"]["mae"]["value"] == pytest.approx(0.5, abs=0.001) + assert metrics["task_0"]["r2"]["value"] == pytest.approx(0.94860, abs=0.001) def test_regression_plots(): + """ + Verify that regression plotting functions return valid figure objects. + + This ensures that regression plots (JointGrid for parity, Figure for CI) are generated + without error, which is important for model reporting. + """ y_true = np.array([3, -0.5, 2, 7]).reshape(-1, 1) y_pred = np.array([2.5, 0.0, 2, 8]).reshape(-1, 1) rm = RegressionPlots() - rm.evaluate(y_true, y_pred) + plot_data = rm.evaluate(y_true, y_pred) - assert True + assert isinstance(plot_data, dict) + assert "task_0_regplot" in plot_data + assert "task_0_ciplot" in plot_data + assert isinstance(plot_data["task_0_regplot"], sns.axisgrid.JointGrid) + assert isinstance(plot_data["task_0_ciplot"], matplotlib.figure.Figure) def test_classification_metrics(): + """ + Validate calculation of classification metrics (Accuracy, Precision, Recall, F1, AUC). + + This ensures that for binary classification tasks, the metrics are computed correctly based on + predicted probabilities and ground truth labels. + """ y_true = [0, 1, 1, 0] # We pass probabilities of the class, not the class itself @@ -48,25 +73,38 @@ def test_classification_metrics(): cm = ClassificationMetrics() metrics = cm.evaluate(y_true, y_pred) - assert metrics["accuracy"]["value"] == 0.75 + assert metrics["accuracy"]["value"] == pytest.approx(0.75) assert metrics["precision"]["value"] == pytest.approx(0.667, abs=0.001) - assert metrics["recall"]["value"] == 1.0 - assert metrics["f1"]["value"] == 0.8 - assert metrics["roc_auc"]["value"] == 0.75 + assert metrics["recall"]["value"] == pytest.approx(1.0) + assert metrics["f1"]["value"] == pytest.approx(0.8) + assert metrics["roc_auc"]["value"] == pytest.approx(0.75) assert metrics["pr_auc"]["value"] == pytest.approx(0.833, abs=0.001) def test_classification_plots(): + """ + Verify that classification plotting functions (ROC, PR curves) return valid figure objects. + """ y_true = [0, 1, 1, 0] y_pred = [[1, 0], [0, 1], [0, 1], [0, 1]] cp = ClassificationPlots() cp.evaluate(y_true, y_pred) - assert True + assert isinstance(cp.plot_data, dict) + assert "roc_curve" in cp.plot_data + assert "pr_curve" in cp.plot_data + assert isinstance(cp.plot_data["roc_curve"], matplotlib.figure.Figure) + assert isinstance(cp.plot_data["pr_curve"], matplotlib.figure.Figure) def test_posthoc_eval_metrics(): + """ + Test post-hoc binary metrics utility functions. + + Verifies that we can calculate precision and recall at a specific cutoff threshold from + regression-like outputs (or probabilities). + """ y_true = [3, -0.5, 2, 7] y_pred = [2.5, 0.0, 2, 8] cutoff = 4.0 diff --git a/openadmet/models/tests/unit/features/test_features.py b/openadmet/models/tests/unit/features/test_features.py index b3c6c04e..ea13a1be 100644 --- a/openadmet/models/tests/unit/features/test_features.py +++ b/openadmet/models/tests/unit/features/test_features.py @@ -10,17 +10,25 @@ @pytest.fixture() def smiles(): + """Provide a list of valid SMILES strings for testing featurization.""" return ["CCO", "CCN", "CCO"] @pytest.fixture() def one_invalid_smi(): + """Provide a list of SMILES strings containing one invalid entry to test error handling.""" return ["CCO", "CCN", "invalid", "CCO"] @pytest.mark.parametrize("dtype", (np.float32, np.float64)) @pytest.mark.parametrize("descr_type", ["mordred", "desc2d"]) def test_descriptor_featurizer(descr_type, dtype): + """ + Validate DescriptorFeaturizer for different descriptor types and floating point precisions. + + This ensures that physical-chemical descriptors (like Mordred or RDKit 2D) are correctly generated + and returned with the requested data type, which is important for downstream model compatibility. + """ featurizer = DescriptorFeaturizer(descr_type=descr_type, dtype=dtype) X, idx = featurizer.featurize(["CCO", "CCN", "CCO"]) assert X.dtype == dtype @@ -28,6 +36,12 @@ def test_descriptor_featurizer(descr_type, dtype): def test_descriptor_one_invalid(one_invalid_smi): + """ + Ensure DescriptorFeaturizer robustly handles invalid SMILES strings. + + The featurizer should skip invalid molecules and return indices corresponding only to the valid ones. + This prevents the entire pipeline from crashing due to a single bad input. + """ featurizer = DescriptorFeaturizer(descr_type="mordred") X, idx = featurizer.featurize(one_invalid_smi) assert X.shape == (3, 1613) @@ -38,6 +52,12 @@ def test_descriptor_one_invalid(one_invalid_smi): @pytest.mark.parametrize("dtype", (np.float32, np.float64)) @pytest.mark.parametrize("fp_type", ("ecfp", "fcfp")) def test_fingerprint_featurizer(smiles, fp_type, dtype): + """ + Validate FingerprintFeaturizer for different fingerprint types (ECFP, FCFP) and precisions. + + This verifies that structural fingerprints are correctly generated with the expected vector size (2000) + and data type. + """ featurizer = FingerprintFeaturizer(fp_type=fp_type, dtype=dtype) X, idx = featurizer.featurize(smiles) assert X.shape == (3, 2000) @@ -46,6 +66,11 @@ def test_fingerprint_featurizer(smiles, fp_type, dtype): def test_fingerprint_one_invalid(one_invalid_smi): + """ + Ensure FingerprintFeaturizer robustly handles invalid SMILES strings. + + Similar to descriptors, it should filter out invalid entries and return correct indices for valid ones. + """ featurizer = FingerprintFeaturizer(fp_type="ecfp") X, idx = featurizer.featurize(one_invalid_smi) assert X.shape == (3, 2000) @@ -54,6 +79,12 @@ def test_fingerprint_one_invalid(one_invalid_smi): def test_feature_concatenator(smiles): + """ + Validate that FeatureConcatenator correctly combines multiple feature sets (descriptors + fingerprints). + + This ensures that different feature representations can be stacked horizontally for the same molecules, + providing a richer feature set for training. + """ desc_featurizer = DescriptorFeaturizer(descr_type="mordred") fp_featurizer = FingerprintFeaturizer(fp_type="ecfp") concat = FeatureConcatenator(featurizers=[desc_featurizer, fp_featurizer]) @@ -62,17 +93,54 @@ def test_feature_concatenator(smiles): assert_array_equal(idx, np.arange(3)) -def test_feature_concatenator_failed_diff_positions(one_invalid_smi): +def test_feature_concatenator_drops_intersection(mocker): + """ + Verify that FeatureConcatenator only keeps molecules valid across ALL featurizers. + + If one featurizer fails for molecule A and another fails for molecule B, the concatenator + must drop both A and B to maintain feature alignment. + + We mock the underlying featurizers to control which indices fail, avoiding the need for + complex real-world molecules that fail specific featurizers. This isolates the intersection logic. + """ + # Arrange desc_featurizer = DescriptorFeaturizer(descr_type="mordred") fp_featurizer = FingerprintFeaturizer(fp_type="ecfp") concat = FeatureConcatenator(featurizers=[desc_featurizer, fp_featurizer]) - X, idx = concat.featurize(one_invalid_smi) - assert X.shape == (3, 3613) - # index 2 is invalid, so the shape should be 3 - assert_array_equal(idx, np.asarray([0, 1, 3])) + + # Mock descriptor featurizer to return 3 valid outputs (fails on index 1) + # Indices: 0, 2, 3 (skips 1) + desc_features = np.zeros((3, 1613)) + desc_indices = np.array([0, 2, 3]) + mocker.patch.object( + DescriptorFeaturizer, "featurize", return_value=(desc_features, desc_indices) + ) + + # Mock fingerprint featurizer to return 3 valid outputs (fails on index 2) + # Indices: 0, 1, 3 (skips 2) + # Note: ECFP size is 2000 in this codebase + fp_features = np.zeros((3, 2000)) + fp_indices = np.array([0, 1, 3]) + mocker.patch.object( + FingerprintFeaturizer, "featurize", return_value=(fp_features, fp_indices) + ) + + smiles = ["SMI0", "SMI1", "SMI2", "SMI3"] + + # Act + X, idx = concat.featurize(smiles) + + # Assert + # Intersection of [0, 2, 3] and [0, 1, 3] is [0, 3] + # Expected shape: (2, 1613 + 2000) = (2, 3613) + assert X.shape == (2, 3613) + assert_array_equal(idx, np.array([0, 3])) def test_feature_concatenator_order_independence(smiles): + """ + Ensure that changing the order of featurizers in the list results in the same outcome due to sorting. + """ desc_featurizer = DescriptorFeaturizer(descr_type="mordred") fp_featurizer = FingerprintFeaturizer(fp_type="ecfp") @@ -87,6 +155,12 @@ def test_feature_concatenator_order_independence(smiles): def test_pairwise_featurizer(smiles): + """ + Validate PairwiseFeaturizer in 'full' mode (all-pairs). + + This tests that features are generated for every pair of molecules and that target values + (differences) are correctly computed. + """ featurizer = PairwiseFeaturizer( featurizer={"FingerprintFeaturizer": {"fp_type": "ecfp", "dtype": np.float32}}, how_to_pair="full", diff --git a/openadmet/models/tests/unit/features/test_mtenn.py b/openadmet/models/tests/unit/features/test_mtenn.py index a8d89a3e..d108f06e 100644 --- a/openadmet/models/tests/unit/features/test_mtenn.py +++ b/openadmet/models/tests/unit/features/test_mtenn.py @@ -1,58 +1,85 @@ import numpy as np -import pytest import pandas as pd -from openadmet.models.features.mtenn import MTENNFeaturizer, MTENNDataset -from openadmet.models.tests.unit.datafiles import ligand_pose +import pytest +import torch +from openadmet.models.features.mtenn import MTENNDataset, MTENNFeaturizer -@pytest.fixture() -def cyp3a4_pose(): - """Fixture for ligand pose""" - return ligand_pose +@pytest.fixture +def mock_complex_features(mocker): + """ + Patch MTENN complex loading with lightweight synthetic tensors. -def test_mtenn_dataset(cyp3a4_pose): - """Test MTENNDataset class for basic functionality""" - # Create a mock dataset, with two identical complexes and a single target value - complexes = [cyp3a4_pose, cyp3a4_pose] - y = np.asarray([42, 43]) - dataset = MTENNDataset(complexes, y, ligand_resname="X5Y", ignore_h=True) + We mock `_load_complexes` to avoid needing actual PDB/SDF files and heavy RDKit/OpenBabel parsing. + This isolates the MTENNDataset and MTENNFeaturizer logic, allowing us to verify data structuring + and tensor shapes without file I/O overhead. + """ + pos = torch.randn(5, 3) + z = torch.tensor([6, 6, 8, 1, 1], dtype=torch.int32) + b = torch.ones(5, dtype=torch.float32) + lig_mask = torch.tensor([False, False, True, True, True], dtype=torch.bool) - # Check the length of the dataset - assert len(dataset) == 2 - # Check the shape of the features + def _mock_load_complexes(complexes, ligand_resname, ignore_h=True): + n = len(complexes) + return ( + [pos.clone() for _ in range(n)], + [z.clone() for _ in range(n)], + [b.clone() for _ in range(n)], + [lig_mask.clone() for _ in range(n)], + ) - feats = next(iter(dataset)) + mocker.patch.object( + MTENNDataset, "_load_complexes", side_effect=_mock_load_complexes + ) - assert feats["Y"] == 42 - assert feats["lig_mask"].numpy().shape == (3695,) - assert feats["pos"].numpy().shape == (3695, 3) - assert feats["Z"].numpy().shape == (3695,) - assert feats["B"].numpy().shape == (3695,) + return pos, z, b, lig_mask - # check the ligand mask, 38 atoms in the ligand - assert feats["lig_mask"].numpy().sum() == 38 +def test_mtenn_dataset(mock_complex_features): + """ + Validate that MTENNDataset correctly constructs data items from complex features. -def test_mtenn_featurizer(cyp3a4_pose): - ft = MTENNFeaturizer( - ligand_resname="X5Y", + This ensures that the dataset class properly organizes positions, atomic numbers, + and masks into the dictionary format expected by MTENN models. + """ + pos, z, b, lig_mask = mock_complex_features + dataset = MTENNDataset( + ["complex_a", "complex_b"], + np.asarray([42, 43]), + ligand_resname="LIG", ignore_h=True, ) - dataloader, _, _, _ = ft.featurize([cyp3a4_pose], pd.Series([42])) + assert len(dataset) == 2 + feats = dataset[0] - # Check the length of the dataloader - assert len(dataloader) == 1 - # Check the shape of the features - feats, y = next(iter(dataloader)) + assert feats["Y"] == 42 + assert feats["pos"].shape == pos.shape + assert torch.equal(feats["Z"], z) + assert torch.equal(feats["B"], b) + assert torch.equal(feats["lig_mask"], lig_mask) + + +def test_mtenn_featurizer(mock_complex_features): + """ + Validate the MTENNFeaturizer high-level interface. + + This checks that the featurizer correctly instantiates the dataset and data loader, + returning formatted batches ready for training. + """ + ft = MTENNFeaturizer(ligand_resname="LIG", ignore_h=True, batch_size=2, n_jobs=0) + dataloader, idx, scaler, dataset = ft.featurize( + ["complex_a", "complex_b"], pd.Series([42.0, 43.0]) + ) - assert y.item() == 42 - assert feats[0]["lig"].numpy().shape == (3695,) - assert feats[0]["pos"].numpy().shape == (3695, 3) - assert feats[0]["z"].numpy().shape == (3695,) + assert len(dataset) == 2 + assert len(dataloader) == 1 + assert np.array_equal(idx, np.array([0, 1])) + assert scaler is None - ##The following are not returned from featurizer - # assert feats["B"].numpy().shape == (1, 3695) - # check the ligand mask, 38 atoms in the ligand - # assert feats["lig_mask"].numpy().sum() == 38 + feats, y = next(iter(dataloader)) + assert y.shape == (2, 1) + assert feats[0]["pos"].shape[1] == 3 + assert feats[0]["z"].ndim == 1 + assert feats[0]["lig"].dtype == torch.bool diff --git a/openadmet/models/tests/unit/features/test_nepare.py b/openadmet/models/tests/unit/features/test_nepare.py index 61cfa787..2e926b0b 100644 --- a/openadmet/models/tests/unit/features/test_nepare.py +++ b/openadmet/models/tests/unit/features/test_nepare.py @@ -9,6 +9,12 @@ def test_pairwise_make_new(): + """ + Verify that PairwiseFeaturizer can create a new independent instance via make_new(). + + This is important for factory-like creation patterns in the registry or during cross-validation + where fresh featurizers are needed. + """ featurizer = PairwiseFeaturizer( how_to_pair="full", featurizer=FingerprintFeaturizer(fp_type="ecfp:4") ) diff --git a/openadmet/models/tests/unit/inference/test_inference.py b/openadmet/models/tests/unit/inference/test_inference.py index df341b21..7a4c369c 100644 --- a/openadmet/models/tests/unit/inference/test_inference.py +++ b/openadmet/models/tests/unit/inference/test_inference.py @@ -1,50 +1,204 @@ +"""Tests for the inference orchestration pipeline using real, lightweight components.""" + from pathlib import Path + +import numpy as np import pandas as pd -import os import pytest -from openadmet.models.inference.inference import predict -from openadmet.models.tests.unit.datafiles import ( - pred_test_data_csv, - anvil_lgbm_trained_model_dir, - anvil_chemprop_trained_model_dir, -) +from openadmet.models.active_learning.committee import CommitteeRegressor +from openadmet.models.anvil.specification import DataSpec, Metadata +from openadmet.models.architecture.dummy import DummyRegressorModel +from openadmet.models.features.molfeat_fingerprint import FingerprintFeaturizer +from openadmet.models.inference import inference as inference_module @pytest.fixture -def anvil_lgbm(): - return anvil_lgbm_trained_model_dir +def input_df(): + """Provide a simple DataFrame with SMILES for testing inference inputs.""" + return pd.DataFrame({"MY_SMILES": ["CCO", "CCN"]}) -@pytest.fixture -def anvil_chemprop(): - return anvil_chemprop_trained_model_dir - - -@pytest.mark.skipif( - os.getenv("RUNNER_OS") == "macOS", reason="MacOS runner not enough memory" -) -@pytest.mark.parametrize("model_dir", ["anvil_lgbm", "anvil_chemprop"]) -def test_predict(model_dir, request): - # Use the fixture to get the model directory - model_dir = request.getfixturevalue(model_dir) - # Test the predict function with a sample input - input_path = pred_test_data_csv - input_col = "MY_SMILES" - model_dir = [model_dir] - write_csv = False - output_path = None - debug = False - - result = predict( - input_path, - input_col, - model_dir, - write_csv, - output_path, - debug=False, +@pytest.fixture(scope="module") +def real_featurizer(): + """Return a real FingerprintFeaturizer using ECFP4 fingerprints.""" + return FingerprintFeaturizer(fp_type="ecfp:4") + + +@pytest.fixture(scope="module") +def real_data_spec(): + """Return a real DataSpec with a single regression target.""" + return DataSpec(type="csv", target_cols=["task_0"], input_col="MY_SMILES") + + +@pytest.fixture(scope="module") +def real_metadata_single(): + """Return real Metadata with tag UNIT for single-model tests.""" + return Metadata( + version="v1", + driver="sklearn", + name="unit-test", + build_number=0, + description="Unit test model", + tag="UNIT", + authors="Test Author", + email="test@example.com", + biotargets=["test"], + tags=["test"], + ) + + +@pytest.fixture(scope="module") +def real_metadata_ensemble(): + """Return real Metadata with tag ENS for ensemble tests.""" + return Metadata( + version="v1", + driver="sklearn", + name="ens-test", + build_number=0, + description="Ensemble test model", + tag="ENS", + authors="Test Author", + email="test@example.com", + biotargets=["test"], + tags=["test"], + ) + + +@pytest.fixture(scope="module") +def trained_single_model(): + """Return a DummyRegressorModel trained to always predict 1.0 regardless of input features.""" + X_train = np.zeros((3, 2)) + y_train = np.array([[1.0], [1.0], [1.0]]) + model = DummyRegressorModel() + model.train(X_train, y_train) + return model + + +@pytest.fixture(scope="module") +def trained_ensemble(): + """Return a CommitteeRegressor whose two members predict 1.0 and 3.0 respectively. + + The ensemble mean is 2.0 and the standard deviation is 1.0 for any input, + making the UCB score with beta=2.0 equal to 4.0. + """ + X_train = np.zeros((3, 2)) + + model1 = DummyRegressorModel() + model1.train(X_train, np.array([[1.0], [1.0], [1.0]])) + + model2 = DummyRegressorModel() + model2.train(X_train, np.array([[3.0], [3.0], [3.0]])) + + return CommitteeRegressor.from_models([model1, model2]) + + +def test_predict_with_real_single_model( + mocker, + input_df, + real_featurizer, + real_metadata_single, + real_data_spec, + trained_single_model, +): + """Test the inference pipeline with a real DummyRegressorModel. + + SMILES strings flow through a real FingerprintFeaturizer and a real DummyRegressorModel + to verify internal data plumbing. Because DummyRegressorModel always predicts the + training mean, PRED values must equal 1.0 for both inputs. The STD column must be NaN + because non-ensemble models produce no uncertainty estimate. + """ + mock_loader = mocker.patch.object( + inference_module, + "load_anvil_model_and_metadata", + return_value=( + trained_single_model, + real_featurizer, + real_metadata_single, + real_data_spec, + ), + ) + + result = inference_module.predict( + input_path=input_df, + input_col="MY_SMILES", + model_dir=["unused-model-dir"], accelerator="cpu", + log=False, ) - # Check if the result is a DataFrame assert isinstance(result, pd.DataFrame) + assert "OADMET_PRED_UNIT_task_0" in result.columns + assert "OADMET_STD_UNIT_task_0" in result.columns + assert result["OADMET_PRED_UNIT_task_0"].tolist() == pytest.approx([1.0, 1.0]) + assert result["OADMET_STD_UNIT_task_0"].isna().all() + mock_loader.assert_called_once_with(Path("unused-model-dir")) + + +def test_predict_with_real_ensemble_and_acquisition( + mocker, + input_df, + real_featurizer, + real_metadata_ensemble, + real_data_spec, + trained_ensemble, +): + """Test the inference pipeline with a real CommitteeRegressor and UCB acquisition. + + Two DummyRegressorModel members predict 1.0 and 3.0 respectively, yielding a committee + mean of 2.0 and standard deviation of 1.0 for any input. With beta=2.0, + UCB = mean + beta * std = 2.0 + 2.0 * 1.0 = 4.0. + """ + mock_loader = mocker.patch.object( + inference_module, + "load_anvil_model_and_metadata", + return_value=( + trained_ensemble, + real_featurizer, + real_metadata_ensemble, + real_data_spec, + ), + ) + + result = inference_module.predict( + input_path=input_df, + input_col="MY_SMILES", + model_dir=["unused-model-dir"], + accelerator="cpu", + log=False, + aq_fxn_args={"ucb": {"beta": 2.0}}, + ) + + assert result["OADMET_PRED_ENS_task_0"].tolist() == pytest.approx([2.0, 2.0]) + assert result["OADMET_STD_ENS_task_0"].tolist() == pytest.approx([1.0, 1.0]) + assert result["OADMET_UCB_ENS_task_0"].tolist() == pytest.approx([4.0, 4.0]) + mock_loader.assert_called_once_with(Path("unused-model-dir")) + + +def test_predict_raises_when_input_column_missing(input_df): + """Ensure that the inference function validates the existence of the specified SMILES column.""" + with pytest.raises(ValueError, match="Column OTHER not found"): + inference_module.predict( + input_path=input_df, + input_col="OTHER", + model_dir=["unused-model-dir"], + log=False, + ) + + +def test_load_anvil_model_and_metadata_missing_recipe_components(tmp_path): + """Ensure correct error is raised when the model directory structure is invalid.""" + with pytest.raises(FileNotFoundError, match="does not contain recipe components"): + inference_module.load_anvil_model_and_metadata(tmp_path) + + +def test_load_anvil_model_and_metadata_missing_procedure_yaml(tmp_path): + """Ensure correct error is raised when critical YAML metadata files are missing.""" + model_dir = tmp_path / "model" + recipe_components = model_dir / "recipe_components" + recipe_components.mkdir(parents=True) + (recipe_components / "metadata.yaml").write_text("metadata") + (recipe_components / "data.yaml").write_text("data") + + with pytest.raises(FileNotFoundError, match="does not contain procedure.yaml"): + inference_module.load_anvil_model_and_metadata(model_dir) diff --git a/openadmet/models/tests/unit/models/test_base.py b/openadmet/models/tests/unit/models/test_base.py index 13aff76a..d3601e3d 100644 --- a/openadmet/models/tests/unit/models/test_base.py +++ b/openadmet/models/tests/unit/models/test_base.py @@ -9,6 +9,13 @@ @pytest.mark.parametrize("mclass", models.classes()) def test_save_load_pickleable(mclass, tmp_path): + """ + Verify save/load mechanics for all registered pickleable models (e.g., sklearn-based). + + This iterates through the model registry and tests that any model inheriting from PickleableModelBase + can be instantiated, built, saved, and loaded without error. This is a crucial contract test + ensuring all registered models comply with the persistence interface. + """ if not issubclass(mclass, PickleableModelBase): pytest.skip(f"Skipping non-pickleable model {mclass.__name__}") model = mclass() @@ -21,6 +28,12 @@ def test_save_load_pickleable(mclass, tmp_path): @pytest.mark.parametrize("mclass", models.classes()) def test_save_load_torch_model(mclass, tmp_path): + """ + Verify save/load mechanics for all registered PyTorch Lightning models. + + Similar to the pickleable test, this ensures that deep learning models (inheriting from LightningModelBase) + implement the correct save/load logic for their weights and configurations. + """ if not issubclass(mclass, LightningModelBase): pytest.skip(f"Skipping non-torch model {mclass.__name__}") model = mclass() diff --git a/openadmet/models/tests/unit/models/test_lgbm.py b/openadmet/models/tests/unit/models/test_lgbm.py index d6a76d71..7634f755 100644 --- a/openadmet/models/tests/unit/models/test_lgbm.py +++ b/openadmet/models/tests/unit/models/test_lgbm.py @@ -6,17 +6,24 @@ @pytest.fixture def X_y(): + """Provide simple synthetic data for basic model training tests.""" X = [[1, 2, 3], [4, 5, 6]] y = [1, 2] return X, y def test_lgbm(): + """Verify that LGBMRegressorModel initializes with the correct type identifier.""" lgbm_model = LGBMRegressorModel() assert lgbm_model.type == "LGBMRegressorModel" def test_lgbm_from_params(): + """ + Validate that hyperparameters passed to the constructor are correctly applied to the underlying estimator. + + This ensures that user configurations (like n_estimators) are respected by the model. + """ lgbm_model = LGBMRegressorModel(n_estimators=100, boosting_type="rf") lgbm_model.build() assert lgbm_model.type == "LGBMRegressorModel" @@ -25,6 +32,11 @@ def test_lgbm_from_params(): def test_lgbm_train_predict(X_y): + """ + Verify the train and predict lifecycle of LGBMRegressorModel. + + This checks that the model can fit to data and generate predictions with the expected shape and values. + """ lgbm_model = LGBMRegressorModel(n_estimators=100) lgbm_model.build() X, y = X_y @@ -41,6 +53,11 @@ def test_lgbm_train_predict(X_y): def test_lgbm_save_load(tmp_path, X_y): + """ + Validate persistence of the LGBM model to disk. + + Ensures that saving and reloading the model preserves its learned state and prediction behavior. + """ lgbm_model = LGBMRegressorModel(n_estimators=100) lgbm_model.build() X, y = X_y @@ -54,6 +71,12 @@ def test_lgbm_save_load(tmp_path, X_y): def test_serialization(tmp_path, X_y): + """ + Validate JSON/pickle serialization workflow for LGBM models. + + This tests the separate storage of hyperparameters (JSON) and model weights (pickle), + which is used for model registry and versioning. + """ lgbm_model = LGBMRegressorModel(n_estimators=100) lgbm_model.build() X, y = X_y diff --git a/openadmet/models/tests/unit/models/test_nepare.py b/openadmet/models/tests/unit/models/test_nepare.py index 17a00f5f..113653e7 100644 --- a/openadmet/models/tests/unit/models/test_nepare.py +++ b/openadmet/models/tests/unit/models/test_nepare.py @@ -6,11 +6,13 @@ @pytest.fixture def X_y(): + """Provide synthetic data pairs for testing pairwise regression.""" X = [[1, 2, 3], [4, 5, 6]] y = [1, 2] return X, y def test_nepare(): + """Verify initialization of the NeuralPairwiseRegressorModel.""" nepare_model = NeuralPairwiseRegressorModel() assert nepare_model.type == "NeuralPairwiseRegressorModel" diff --git a/openadmet/models/tests/unit/split/test_splitters.py b/openadmet/models/tests/unit/split/test_splitters.py index 4c375786..f0602492 100644 --- a/openadmet/models/tests/unit/split/test_splitters.py +++ b/openadmet/models/tests/unit/split/test_splitters.py @@ -5,10 +5,10 @@ from openadmet.models.split.sklearn import ShuffleSplitter from openadmet.models.split.cluster import ClusterSplitter from openadmet.models.split.split_base import splitters -from openadmet.models.tests.unit.datafiles import CYP3A4_chembl_pchembl def test_in_splitters(): + """Verify that concrete splitter implementations are correctly registered in the splitters registry.""" assert "ShuffleSplitter" in splitters assert "ClusterSplitter" in splitters @@ -34,10 +34,15 @@ def test_in_splitters(): def test_simple_split( train_size, val_size, test_size, expected_train, expected_val, expected_test, error ): - # Error expected + """ + Validate that ShuffleSplitter correctly partitions data according to specified ratios. + + This test verifies both successful splits and error handling for invalid configurations. + Correct splitting ensures that training, validation, and test sets are of the expected size + and are mutually exclusive, which is critical for valid model evaluation. + """ if error is True: with pytest.raises(ValueError): - # Initialize splitter splitter = ShuffleSplitter( train_size=train_size, val_size=val_size, @@ -46,142 +51,239 @@ def test_simple_split( ) return - # Initialize splitter splitter = ShuffleSplitter( train_size=train_size, val_size=val_size, test_size=test_size, random_state=42 ) - # Generate random data + # Generate synthetic random data for testing split logic X = np.random.rand(100, 10) y = np.random.rand(100) - # Error is expected - if error is True: - with pytest.raises(ValueError): - splitter.split(X, y) - return - - # Perform the split X_train, X_val, X_test, y_train, y_val, y_test, groups = splitter.split(X, y) - # Check train assert X_train.shape[0] == expected_train assert y_train.shape[0] == expected_train - # Validation set requested if val_size > 0: assert X_val.shape[0] == expected_val assert y_val.shape[0] == expected_val + # Assert X_train and X_val are mutually exclusive + train_set = set(map(tuple, X_train)) + val_set = set(map(tuple, X_val)) + assert len(train_set.intersection(val_set)) == 0 - # Validation set not requested else: assert X_val is None assert y_val is None - # Test set requested if test_size > 0: assert X_test.shape[0] == expected_test assert y_test.shape[0] == expected_test + # Assert X_train and X_test are mutually exclusive + train_set = set(map(tuple, X_train)) + test_set = set(map(tuple, X_test)) + assert len(train_set.intersection(test_set)) == 0 + + if val_size > 0: + # Assert X_val and X_test are mutually exclusive + val_set = set(map(tuple, X_val)) + assert len(val_set.intersection(test_set)) == 0 - # Test set not requested else: assert X_test is None assert y_test is None +@pytest.fixture +def synthetic_cluster_data(): + """ + Provide a synthetic dataset with structural diversity for testing cluster splitting. + + This fixture returns a set of SMILES strings representing different chemical scaffolds + (benzenes, pyridines, cyclohexanes, furans, thiophenes) and corresponding target values. + Using diverse scaffolds ensures that clustering algorithms (like Butina or Bemis-Murcko) + can meaningfully group molecules, allowing verification that splits respect cluster boundaries. + """ + base_smiles = [ + "Cc1ccccc1", + "CCc1ccccc1", + "Oc1ccccc1", + "Nc1ccccc1", + "Clc1ccccc1", + "Fc1ccccc1", + "C(=O)Oc1ccccc1", + "C(=O)Cc1ccccc1", + "c1ccccc1C#N", + "COc1ccccc1", + "Cc1ccncc1", + "CCc1ccncc1", + "Oc1ccncc1", + "Nc1ccncc1", + "Clc1ccncc1", + "Fc1ccncc1", + "C(=O)Oc1ccncc1", + "C(=O)Cc1ccncc1", + "c1ccncc1C#N", + "COc1ccncc1", + "CC1CCCCC1", + "CCC1CCCCC1", + "OC1CCCCC1", + "NC1CCCCC1", + "ClC1CCCCC1", + "FC1CCCCC1", + "C(=O)OC1CCCCC1", + "C(=O)CC1CCCCC1", + "C1CCCCC1C#N", + "COC1CCCCC1", + "Cc1ccoc1", + "CCc1ccoc1", + "Oc1ccoc1", + "Nc1ccoc1", + "Clc1ccoc1", + "Fc1ccoc1", + "C(=O)Oc1ccoc1", + "C(=O)Cc1ccoc1", + "c1ccoc1C#N", + "COc1ccoc1", + "Cc1ccsc1", + "CCc1ccsc1", + "Oc1ccsc1", + "Nc1ccsc1", + "Clc1ccsc1", + "Fc1ccsc1", + "C(=O)Oc1ccsc1", + "C(=O)Cc1ccsc1", + "c1ccsc1C#N", + "COc1ccsc1", + "Cc1ccc2ccccc2c1", + "CCc1ccc2ccccc2c1", + "Oc1ccc2ccccc2c1", + "Nc1ccc2ccccc2c1", + "Clc1ccc2ccccc2c1", + "Fc1ccc2ccccc2c1", + "C(=O)Oc1ccc2ccccc2c1", + "C(=O)Cc1ccc2ccccc2c1", + "c1ccc2ccccc2c1C#N", + "COc1ccc2ccccc2c1", + "Cc1ccc2[nH]ccc2c1", + "CCc1ccc2[nH]ccc2c1", + "Oc1ccc2[nH]ccc2c1", + "Nc1ccc2[nH]ccc2c1", + "Clc1ccc2[nH]ccc2c1", + "Fc1ccc2[nH]ccc2c1", + "C(=O)Oc1ccc2[nH]ccc2c1", + "C(=O)Cc1ccc2[nH]ccc2c1", + "c1ccc2[nH]ccc2c1C#N", + "COc1ccc2[nH]ccc2c1", + "Cc1ccc2ncccc2c1", + "CCc1ccc2ncccc2c1", + "Oc1ccc2ncccc2c1", + "Nc1ccc2ncccc2c1", + "Clc1ccc2ncccc2c1", + "Fc1ccc2ncccc2c1", + "C(=O)Oc1ccc2ncccc2c1", + "C(=O)Cc1ccc2ncccc2c1", + "c1ccc2ncccc2c1C#N", + "COc1ccc2ncccc2c1", + "CC1CCCC1", + "CCC1CCCC1", + "OC1CCCC1", + "NC1CCCC1", + "ClC1CCCC1", + "FC1CCCC1", + "C(=O)OC1CCCC1", + "C(=O)CC1CCCC1", + "C1CCCC1C#N", + "COC1CCCC1", + "CC1CCNCC1", + "CCC1CCNCC1", + "OC1CCNCC1", + "NC1CCNCC1", + "ClC1CCNCC1", + "FC1CCNCC1", + "C(=O)OC1CCNCC1", + "C(=O)CC1CCNCC1", + "C1CCNCC1C#N", + "COC1CCNCC1", + ] + smiles = pd.Series(base_smiles) + y = pd.Series(np.linspace(0.0, 1.0, len(smiles))) + return smiles, y + + @pytest.mark.parametrize( - "train_size, val_size, test_size, expected_train, expected_val, expected_test, error, method", + "method", [ - # Test cases for kmeans - (0.8, 0.0, 0.2, 1600, 0, 400, False, "kmeans"), - (0.7, 0.3, 0.0, 1400, 600, 0, False, "kmeans"), - (0.7, 0.1, 0.2, 1400, 200, 400, False, "kmeans"), - (0.6, 0.2, 0.2, 1200, 400, 400, False, "kmeans"), - # Test cases for butina - (0.8, 0.0, 0.2, 1600, 0, 400, False, "butina"), - (0.7, 0.3, 0.0, 1400, 600, 0, False, "butina"), - (0.7, 0.1, 0.2, 1400, 200, 400, False, "butina"), - (0.6, 0.2, 0.2, 1200, 400, 400, False, "butina"), - # Test cases for bemis-murcko - (0.8, 0.0, 0.2, 1600, 0, 400, False, "bemis-murcko"), - (0.7, 0.3, 0.0, 1400, 600, 0, False, "bemis-murcko"), - (0.7, 0.1, 0.2, 1400, 200, 400, False, "bemis-murcko"), - (0.6, 0.2, 0.2, 1200, 400, 400, False, "bemis-murcko"), - # Error cases - (1.0, 0.0, 0.0, 200, 0, 0, True, "kmeans"), - (0.5, 0.5, 0.5, -1, -1, -1, True, "kmeans"), + "kmeans", + "butina", + "bemis-murcko", ], ) -def test_cluster_split( - train_size, - val_size, - test_size, - expected_train, - expected_val, - expected_test, - error, - method, -): - df = pd.read_csv(CYP3A4_chembl_pchembl) - X = df["CANONICAL_SMILES"][:2000] - y = df["pChEMBL mean"][:2000] - - # Error expected - if error is True: - with pytest.raises(ValueError): - # Initialize splitter - splitter = ClusterSplitter( - train_size=train_size, - val_size=val_size, - test_size=test_size, - random_state=42, - method=method, - k_clusters=100, - ) - return +def test_cluster_split_synthetic_data(method, synthetic_cluster_data): + """ + Validate ClusterSplitter functionality with different clustering methods. - # Initialize splitter + This test ensures that molecular data is split such that training, validation, and test sets + contain mutually exclusive molecules (no data leakage). It verifies split sizes are approximately + correct and that structural separation is maintained. + """ + X, y = synthetic_cluster_data splitter = ClusterSplitter( - train_size=train_size, - val_size=val_size, - test_size=test_size, + train_size=0.7, + val_size=0.1, + test_size=0.2, random_state=42, method=method, - k_clusters=100, + k_clusters=10, + ) + X_train, X_val, X_test, y_train, y_val, y_test, groups = splitter.split( + X, y, num_iters=50 ) - # Perform the split - X_train, X_val, X_test, y_train, y_val, y_test, groups = splitter.split(X, y) - - # Check type preservation for obj in [X_train, X_val, X_test]: if obj is not None: - assert isinstance(obj, pd.Series), "X split must preserve pandas Series" + assert isinstance(obj, pd.Series) for obj in [y_train, y_val, y_test]: if obj is not None: - assert isinstance(obj, pd.Series), "y split must preserve pandas Series" + assert isinstance(obj, pd.Series) - # Check train - assert abs(X_train.shape[0] - expected_train) <= 10 - assert abs(y_train.shape[0] - expected_train) <= 10 + total = len(X) + assert abs(len(X_train) - int(0.7 * total)) <= 5 + assert abs(len(X_val) - int(0.1 * total)) <= 5 + assert abs(len(X_test) - int(0.2 * total)) <= 5 + assert len(groups) == total - # Validation set requested - if val_size > 0: - assert abs(X_val.shape[0] - expected_val) <= 10 - assert abs(y_val.shape[0] - expected_val) <= 10 + # Check for data leakage + # Assert X_train, X_val, and X_test are mutually exclusive by index + train_idx = set(X_train.index) + val_idx = set(X_val.index) + test_idx = set(X_test.index) - # Validation set not requested - else: - assert X_val is None - assert y_val is None + assert len(train_idx.intersection(val_idx)) == 0 + assert len(train_idx.intersection(test_idx)) == 0 + assert len(val_idx.intersection(test_idx)) == 0 - # Test set requested - if test_size > 0: - assert abs(X_test.shape[0] - expected_test) <= 10 - assert abs(y_test.shape[0] - expected_test) <= 10 - # Test set not requested - else: - assert X_test is None - assert y_test is None +def test_cluster_split_invalid_size_configuration(): + """Ensure ClusterSplitter raises ValueError for invalid split size configurations (e.g., sum != 1.0).""" + with pytest.raises(ValueError): + ClusterSplitter( + train_size=1.0, + val_size=0.0, + test_size=0.0, + random_state=42, + method="kmeans", + ) + + +def test_cluster_split_invalid_method(): + """Ensure ClusterSplitter raises ValueError when initialized with an unknown clustering method.""" + with pytest.raises(ValueError): + ClusterSplitter( + train_size=0.7, + val_size=0.1, + test_size=0.2, + random_state=42, + method="not-a-method", + ) diff --git a/openadmet/models/tests/unit/test_utils.py b/openadmet/models/tests/unit/test_utils.py index fbd17e31..9e65ee49 100644 --- a/openadmet/models/tests/unit/test_utils.py +++ b/openadmet/models/tests/unit/test_utils.py @@ -2,6 +2,12 @@ def click_success(result): + """ + Helper function to verify that a Click command executed successfully (exit code 0). + + If the command failed, this function prints the output and traceback to aid in debugging + before returning False. + """ if result.exit_code != 0: # -no-cov- (only occurs on test error) print(result.output) traceback.print_tb(result.exc_info[2]) diff --git a/pyproject.toml b/pyproject.toml index 48996cef..0ed1c4c3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,6 +31,7 @@ requires-python = ">=3.10" [project.optional-dependencies] test = [ "pytest>=6.1.2", + "pytest-mock", ] [project.scripts]