Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 20 additions & 2 deletions openadmet/models/eval/eval_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from loguru import logger
import numpy as np
from class_registry import ClassRegistry, RegistryKeyError
from pydantic import BaseModel
from pydantic import BaseModel, Field
from scipy.stats import bootstrap

evaluators = ClassRegistry(unique=True)
Expand Down Expand Up @@ -145,10 +145,26 @@ def get_t_true_and_t_pred(task_id, y_true, y_pred, y_val=None, y_pred_fold=None)


class EvalBase(BaseModel):
"""Abstract base class for evaluation modules."""
"""
Abstract base class for evaluation modules.

Attributes
----------
n_resamples : int
Number of bootstrap resamples used to estimate confidence intervals.
Defaults to 9999 (scipy default). Lower values (e.g. 100) are appropriate
for unit tests where CI precision is not required.

"""

is_cross_val: ClassVar[bool] = False

n_resamples: int = Field(
default=9999,
ge=1,
description="Number of bootstrap resamples for confidence interval estimation",
)

class Config:
"""Pydantic configuration for the EvalBase class."""

Expand Down Expand Up @@ -244,6 +260,7 @@ def stat_and_bootstrap(
statistic=lambda y_true, y_pred: statistic(y_true, y_pred).statistic,
method="basic",
confidence_level=confidence_level,
n_resamples=self.n_resamples,
paired=True,
).confidence_interval

Expand All @@ -254,6 +271,7 @@ def stat_and_bootstrap(
statistic=statistic,
method="basic",
confidence_level=confidence_level,
n_resamples=self.n_resamples,
paired=True,
).confidence_interval

Expand Down
4 changes: 2 additions & 2 deletions openadmet/models/tests/unit/eval/test_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def test_regression_metrics():
y_true = np.array([3, -0.5, 2, 7]).reshape(-1, 1)
y_pred = np.array([2.5, 0.0, 2, 8]).reshape(-1, 1)

rm = RegressionMetrics()
rm = RegressionMetrics(n_resamples=100)
metrics = rm.evaluate(y_true, y_pred)

assert metrics["task_0"]["mse"]["value"] == 0.375
Expand Down Expand Up @@ -110,7 +110,7 @@ def test_classification_metrics():
# Classes would be [0, 1, 1, 1]
y_pred = [[1, 0], [0, 1], [0, 1], [0, 1]]

cm = ClassificationMetrics()
cm = ClassificationMetrics(n_resamples=100)
metrics = cm.evaluate(y_true, y_pred)

assert metrics["accuracy"]["value"] == 0.75
Expand Down
Loading