Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 19 additions & 7 deletions benchmark_utils/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,16 @@
"""

import numpy as np
from sklearn.metrics import (
accuracy_score,
average_precision_score,
balanced_accuracy_score,
f1_score,
roc_auc_score,
)

from benchmark_utils.outputs import ForecastOutput

# NB: sklearn is imported lazily inside the classification / anomaly-detection
# metrics below, so this module (and the registries at the bottom) can be
# imported with only numpy installed. benchopt imports objective.py and the
# dataset files just to read metadata in envs without the objective's
# requirements; a top-level sklearn import would break that (and ``benchopt
# test`` env setup). The metric functions only run when sklearn is present.

# ---------------------------------------------------------------------------
# Forecasting — internal helpers
# ---------------------------------------------------------------------------
Expand Down Expand Up @@ -172,14 +172,20 @@ def mcis(y_true, forecast: ForecastOutput, alpha=0.05, **_):


def accuracy(y_true, y_pred):
    """Return scikit-learn's accuracy score as a plain Python ``float``.

    scikit-learn is imported inside the function so that importing this
    module does not require scikit-learn to be installed.

    Parameters
    ----------
    y_true : ground-truth labels, forwarded unchanged to ``accuracy_score``
    y_pred : predicted labels, forwarded unchanged to ``accuracy_score``
    """
    from sklearn.metrics import accuracy_score

    score = accuracy_score(y_true, y_pred)
    # Cast so callers get a builtin float rather than a numpy scalar.
    return float(score)


def balanced_accuracy(y_true, y_pred):
    """Return scikit-learn's balanced accuracy score as a Python ``float``.

    scikit-learn is imported inside the function so that importing this
    module does not require scikit-learn to be installed.

    Parameters
    ----------
    y_true : ground-truth labels, forwarded to ``balanced_accuracy_score``
    y_pred : predicted labels, forwarded to ``balanced_accuracy_score``
    """
    from sklearn.metrics import balanced_accuracy_score

    score = balanced_accuracy_score(y_true, y_pred)
    # Cast so callers get a builtin float rather than a numpy scalar.
    return float(score)


def f1_weighted(y_true, y_pred):
    """Return the weighted-average F1 score as a plain Python ``float``.

    Calls scikit-learn's ``f1_score`` with ``average="weighted"`` and
    ``zero_division=0``. scikit-learn is imported inside the function so
    that importing this module does not require scikit-learn to be
    installed.

    Parameters
    ----------
    y_true : ground-truth labels, forwarded unchanged to ``f1_score``
    y_pred : predicted labels, forwarded unchanged to ``f1_score``
    """
    from sklearn.metrics import f1_score

    score = f1_score(
        y_true,
        y_pred,
        average="weighted",
        zero_division=0,
    )
    # Cast so callers get a builtin float rather than a numpy scalar.
    return float(score)


Expand All @@ -196,6 +202,8 @@ def auc_roc(y_true, y_score):
y_true : list of (T_j,) int arrays, concatenated
y_score : list of (T_j,) float arrays, concatenated
"""
from sklearn.metrics import roc_auc_score

y_true = np.concatenate([np.asarray(y) for y in y_true])
y_score = np.concatenate([np.asarray(y) for y in y_score])
if y_true.sum() == 0:
Expand All @@ -205,6 +213,8 @@ def auc_roc(y_true, y_score):

def auc_pr(y_true, y_score):
"""Area under Precision-Recall curve."""
from sklearn.metrics import average_precision_score

y_true = np.concatenate([np.asarray(y) for y in y_true])
y_score = np.concatenate([np.asarray(y) for y in y_score])
if y_true.sum() == 0:
Expand All @@ -224,6 +234,8 @@ def f1_pa(y_true, y_score, threshold=None):
If None, the threshold is chosen to maximise F1 on the test set
(oracle threshold — for benchmarking purposes only).
"""
from sklearn.metrics import f1_score

y_true_cat = np.concatenate([np.asarray(y) for y in y_true])
y_score_cat = np.concatenate([np.asarray(y) for y in y_score])

Expand Down
9 changes: 8 additions & 1 deletion objective.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,14 @@ class Objective(BaseObjective):

sampling_strategy = "run_once"

# Minimal config for ``benchopt test``
# Test dataset for ``benchopt test``. When benchopt builds the test env under
# ``skip_import_ctx`` it does not import this module; it only reads attributes
# statically (via AST), and ``test_dataset_name`` is the only selector it can
# read that way. ``test_config`` is not a base-class attribute, so it is
# invisible there, and without ``test_dataset_name`` resolution would fall
# back to the non-existent ``simulated`` dataset.
test_dataset_name = "monash"
# Richer config used once the objective is actually imported: exercise more
# datasets, all in debug mode for speed.
test_config = {
"dataset": {
# Skipping MITDB for now due to timeout in download
Expand Down
Loading