Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
134 changes: 107 additions & 27 deletions openadmet/models/tests/unit/cli/test_cli.py
Original file line number Diff line number Diff line change
@@ -1,59 +1,94 @@
from openadmet.models.cli.cli import cli
from openadmet.models.tests.test_utils import click_success
from openadmet.models.tests.unit.datafiles import (
anvil_lgbm_trained_model_dir,
pred_test_data_csv,
basic_anvil_yaml_cv,
)
from pathlib import Path

import pytest
from click.testing import CliRunner

from openadmet.models.cli import anvil as anvil_cli_module
from openadmet.models.cli import predict as predict_cli_module
from openadmet.models.cli.cli import cli
from openadmet.models.tests.test_utils import click_success
from openadmet.models.tests.unit.datafiles import basic_anvil_yaml_cv


@pytest.fixture
def runner():
    """Build a fresh Click ``CliRunner`` for each test that requests it."""
    cli_runner = CliRunner()
    return cli_runner


def test_toplevel_runnable(runner):
    """Check that invoking the top-level CLI group with ``--help`` succeeds.

    Uses the shared ``runner`` fixture instead of building a ``CliRunner``
    inline, so all CLI tests in this module obtain their runner the same way.
    """
    result = runner.invoke(cli, ["--help"])
    assert click_success(result)


@pytest.mark.parametrize(
    "args", [["anvil", "--help"], ["compare", "--help"], ["predict", "--help"]]
)
def test_subcommand_runnable(runner, args):
    """Check that each listed subcommand is registered and answers ``--help``.

    ``args`` is the argument vector passed to the top-level ``cli`` group; one
    test case is generated per subcommand (anvil, compare, predict).
    """
    result = runner.invoke(cli, args)
    assert click_success(result)


def test_predict_cli_invokes_inference(tmp_path, runner, mocker):
    """
    Check that the 'predict' subcommand forwards its arguments to ``inference_func``.

    ``inference_func`` is patched on the predict CLI module (with ``autospec``),
    so no real model is loaded; the model directory created here is empty.
    The test only verifies what the CLI layer hands to the patched function:
    the input column, the accelerator, the ``write_csv`` flag and the model
    directory list.
    """
    # Minimal on-disk inputs: a one-row CSV and an empty model directory.
    input_csv = tmp_path / "input.csv"
    input_csv.write_text("MY_SMILES\nCCO\n")
    model_dir = tmp_path / "model_dir"
    model_dir.mkdir()

    mock_inference = mocker.patch.object(
        predict_cli_module, "inference_func", autospec=True
    )

    # Each option appears exactly once with its value.
    result = runner.invoke(
        cli,
        [
            "predict",
            "--input-path",
            input_csv,
            "--input-col",
            "MY_SMILES",
            "--model-dir",
            model_dir,
            "--output-csv",
            tmp_path / "predictions.csv",
            "--accelerator",
            "cpu",
        ],
    )

    assert click_success(result)
    mock_inference.assert_called_once()

    # The CLI is expected to call the inference function with keyword arguments.
    called = mock_inference.call_args.kwargs
    assert called["input_col"] == "MY_SMILES"
    assert called["accelerator"] == "cpu"
    assert called["write_csv"] is True
    assert list(called["model_dir"]) == [model_dir]


def test_anvil_cli(tmp_path):
"""Test the anvil CLI command"""
runner = CliRunner()
def test_anvil_cli_invokes_workflow(tmp_path, runner, mocker):
"""
Validate that the 'anvil' subcommand correctly initializes and runs a workflow from a recipe.

We mock the `AnvilSpecification` and workflow execution to verify that the CLI correctly handles
recipe paths and output directories without actually running a full ML training job.
"""
mock_spec = mocker.create_autospec(
anvil_cli_module.AnvilSpecification, instance=True
)
mock_from_recipe = mocker.patch.object(
anvil_cli_module.AnvilSpecification,
"from_recipe",
autospec=True,
return_value=mock_spec,
)

result = runner.invoke(
cli,
[
Expand All @@ -66,3 +101,48 @@ def test_anvil_cli(tmp_path):
)

assert click_success(result)
mock_from_recipe.assert_called_once_with(basic_anvil_yaml_cv)
mock_spec.run.assert_called_once()
called = mock_spec.run.call_args.kwargs
assert Path(called["output_dir"]) == tmp_path / "anvil_output"
assert called["debug"] is False


@pytest.mark.parametrize(
    "aq_fxns,beta,best_y,xi,expected",
    [
        (("ucb",), (2.0,), (), (), {"ucb": {"beta": 2.0}}),
        (
            ("ei", "pi"),
            (),
            (1.0, 2.0),
            (0.1, 0.2),
            {"ei": {"xi": 0.1, "best_y": 1.0}, "pi": {"xi": 0.2, "best_y": 2.0}},
        ),
    ],
)
def test_validate_aq_fxns_success(aq_fxns, beta, best_y, xi, expected):
    """
    Check that valid acquisition-function arguments produce the expected dict.

    Each case passes tuples of acquisition function names and their
    ``beta`` / ``best_y`` / ``xi`` values to ``_validate_aq_fxns`` and compares
    the returned mapping against ``expected``.
    """
    parsed = predict_cli_module._validate_aq_fxns(aq_fxns, beta, best_y, xi)
    assert parsed == expected


@pytest.mark.parametrize(
    "aq_fxns,beta,best_y,xi,error_message",
    [
        (("ucb", "ucb"), (1.0, 2.0), (), (), "UCB can only be specified once"),
        (("ei",), (), (), (), "must be specified once per EI and/or PI acquisition"),
        (("ucb",), (), (), (), "Field `beta` must be specified for UCB acquisition"),
    ],
)
def test_validate_aq_fxns_errors(aq_fxns, beta, best_y, xi, error_message):
    """
    Check that invalid acquisition-function arguments raise ``ValueError``.

    ``error_message`` is handed to ``pytest.raises(match=...)``, so it is
    searched for as a regular expression within the raised error's message.
    """
    validate = predict_cli_module._validate_aq_fxns
    with pytest.raises(ValueError, match=error_message):
        validate(aq_fxns, beta, best_y, xi)
53 changes: 40 additions & 13 deletions openadmet/models/tests/unit/comparison/test_comparison.py
Original file line number Diff line number Diff line change
@@ -1,33 +1,41 @@
import numpy as np
import pytest
from numpy.testing import assert_almost_equal
import numpy as np

from openadmet.models.comparison.compare_base import get_comparison_class
from openadmet.models.comparison.posthoc import PostHocComparison
from openadmet.models.tests.unit.datafiles import (
cyp2c9_json,
anvil_lgbm_trained_model_dir,
cyp1a2_json,
cyp2c9_json,
cyp3a4_json,
multi_task_json,
anvil_lgbm_trained_model_dir,
)


def test_get_comparison_class():
    """
    Test lookup of comparison classes by name.

    A known name ("PostHoc") must be accepted without raising, and an unknown
    name must raise ``ValueError``.
    """
    # Known name: only checks that the lookup does not raise.
    get_comparison_class("PostHoc")
    with pytest.raises(ValueError):
        get_comparison_class("NotARealClass")


def test_posthoc_fails_on_incorrect_inputs():
"""Test that posthoc comparison fails when given incorrect inputs.
"""
Test that posthoc comparison fails when given incorrect inputs.

Inputs include:
- No inputs
- Only one of model_stats_fns, labels, or task_names
- Mismatched lengths of model_stats_fns, labels, and task_names
- Repeated labels
- Incorrect labels and task_names for model_stats_fns

This validation is critical to ensure that comparison tables and plots match models to their correct metadata.
"""
comp_obj = PostHocComparison()
with pytest.raises(ValueError):
Expand Down Expand Up @@ -69,7 +77,12 @@ def test_posthoc_repeat_label_error():


def test_posthoc_comparison():
"""Test that posthoc comparison works when given correct inputs."""
"""
Test that posthoc comparison works correctly when given valid inputs.

This verifies the calculation of statistical tests (Levene's test for equality of variances,
Tukey's HSD for pairwise mean differences) based on loaded model metrics.
"""
model_stats = [cyp2c9_json, cyp3a4_json, cyp1a2_json]
model_tags = [
"openadmet-CYP2C9-pchembl-regression-testing-cv",
Expand Down Expand Up @@ -97,7 +110,12 @@ def test_posthoc_comparison():
def test_posthoc_comparison_anvil_reader_and_feature_label(
label_types, expected_labels
):
"""Test that posthoc comparison can read from anvil-trained model directories and features."""
"""
Test that posthoc comparison can automatically extract labels from anvil-trained model directories.

This ensures that metadata stored in `metadata.yaml` within model directories can be correctly
parsed to generate readable labels for comparison plots.
"""
comp_obj = PostHocComparison()
model_stats_fns, labels, task_names = comp_obj.label_and_task_name_from_anvil(
model_dirs=[anvil_lgbm_trained_model_dir], label_types=label_types
Expand All @@ -121,7 +139,12 @@ def test_posthoc_comparison_json_reader_fails(label_types):


def test_posthoc_comparison_json_reader():
"""Test that posthoc comparison can read multi vs single task from anvil file."""
"""
Test that posthoc comparison handles both multi-task and single-task JSON result files.

This verifies that the system can normalize results from different task types into a common
format for statistical comparison.
"""
model_stats = [multi_task_json, cyp3a4_json]
model_tags = ["multitask", "single_task"]
task_tags = ["cyp3a4_pchembl_value_mean", "pchembl_value_mean"]
Expand All @@ -130,14 +153,18 @@ def test_posthoc_comparison_json_reader():
levene, tukeys_df = comp_obj.compare(
model_stats_fns=model_stats, labels=model_tags, task_names=task_tags
)
assert levene["mse"]["stat"] == 2.483488460351842
assert levene["ktau"]["stat"] == 1.0392615736603197
assert tukeys_df["metric_val"][0] == -0.01037444780666702
assert tukeys_df["pvalue"][0] == 0.2488307785417857
assert levene["mse"]["stat"] == pytest.approx(2.483, abs=0.001)
assert levene["ktau"]["stat"] == pytest.approx(1.039, abs=0.001)
assert tukeys_df["metric_val"][0] == pytest.approx(-0.010, abs=0.001)
assert tukeys_df["pvalue"][0] == pytest.approx(0.248, abs=0.001)


def test_posthoc_comparison_printing(capsys):
"""Test that posthoc comparison prints results to console."""
"""
Test that posthoc comparison prints results to console in a readable format.

We capture stdout to verify that Levene's test and Tukey's HSD results are actually displayed to the user.
"""
model_stats = [cyp2c9_json, cyp3a4_json, cyp1a2_json]
model_tags = [
"openadmet-CYP2C9-pchembl-regression-testing-cv",
Expand Down
18 changes: 18 additions & 0 deletions openadmet/models/tests/unit/data/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,12 @@


def test_data_spec_from_csv():
"""
Validate loading data from a CSV file via DataSpec.

Ensures that the data loader correctly reads the specified CSV, extracts the target and SMILES columns,
and returns them as expected.
"""
data_spec = DataSpec(
type="intake",
resource=test_csv,
Expand All @@ -18,6 +24,12 @@ def test_data_spec_from_csv():


def test_data_spec_from_intake():
"""
Validate loading data from an Intake catalog.

Intake allows for declarative data loading. This test checks that DataSpec can correctly interface
with an Intake catalog to retrieve data.
"""
data_spec = DataSpec(
type="intake",
resource=intake_cat,
Expand All @@ -32,6 +44,12 @@ def test_data_spec_from_intake():

@pytest.mark.parametrize("dropna, expected_length", [(True, 3333), (False, 7196)])
def test_data_spec_dropna(dropna, expected_length):
"""
Test the `dropna` functionality in DataSpec.

Verifies that rows with missing values in target columns are dropped when dropna=True,
and preserved when dropna=False. This is critical for handling real-world datasets which often contain gaps.
"""
data_spec = DataSpec(
type="intake",
resource=nan_data,
Expand Down
Loading
Loading