From aa57ff2ddae8c4f4e42195940b1367263848c27a Mon Sep 17 00:00:00 2001 From: Geoffrey Negiar Date: Fri, 29 May 2026 17:13:13 +0200 Subject: [PATCH] ENH add model capability flags + per-capability deactivation Declare each forecasting model's capabilities (multivariate, hist/future covariates) and let users deactivate covariate capabilities per run to benchmark the lift each provides. - benchmark_utils/capabilities.py: flag vocabulary + mask_covariates helper - BaseTSFMAdapter.covariate_capabilities: effective active set (default empty => univariate, safe by default) - Objective._eval_forecasting masks the covariate payload to the adapter's capabilities before predict() -- single, guaranteed enforcement point - Solvers declare `capabilities`; TFC-API exposes use_hist_covars / use_future_covars toggles and now threads covariates to the SDK via historical_variables / future_variables (multivariate stays declarative) Co-Authored-By: Claude Opus 4.8 (1M context) --- benchmark_utils/adapters/base.py | 13 +++ benchmark_utils/capabilities.py | 75 +++++++++++++ objective.py | 8 +- solvers/chronos.py | 3 + solvers/chronos2.py | 6 + solvers/naive.py | 3 + solvers/seasonal_naive.py | 3 + solvers/tfc_api.py | 121 ++++++++++++++++++--- solvers/toto2.py | 6 + tests/benchmark_utils/test_capabilities.py | 53 +++++++++ tests/solvers/test_tfc_api_covariates.py | 88 +++++++++++++++ tests/test_objective_capability_masking.py | 86 +++++++++++++++ 12 files changed, 447 insertions(+), 18 deletions(-) create mode 100644 benchmark_utils/capabilities.py create mode 100644 tests/benchmark_utils/test_capabilities.py create mode 100644 tests/solvers/test_tfc_api_covariates.py create mode 100644 tests/test_objective_capability_masking.py diff --git a/benchmark_utils/adapters/base.py b/benchmark_utils/adapters/base.py index f7e2b69..3ec66d5 100644 --- a/benchmark_utils/adapters/base.py +++ b/benchmark_utils/adapters/base.py @@ -47,8 +47,21 @@ class BaseTSFMAdapter(ABC): Subclasses must implement ``predict``. ``fit`` is optional (used by supervised adaptations such as linear probe or fine-tuning). + + Attributes + ---------- + covariate_capabilities : frozenset[str] + The *effective* covariate capabilities this adapter consumes for the + current run (a subset of + :data:`benchmark_utils.capabilities.COVARIATE_CAPABILITIES`). The + forecasting objective reads this and masks the covariate payload down + to it before calling :meth:`predict`, so an adapter only ever sees + covariates it both declares and has enabled. Defaults to empty — + univariate, no covariates — so a new adapter is safe by default. """ + covariate_capabilities: frozenset = frozenset() + def fit(self, X_train, y_train, **kwargs): """Optional supervised fitting step (called inside Solver.run()).""" return self diff --git a/benchmark_utils/capabilities.py b/benchmark_utils/capabilities.py new file mode 100644 index 0000000..d71c481 --- /dev/null +++ b/benchmark_utils/capabilities.py @@ -0,0 +1,75 @@ +"""Model capability flags and covariate masking. + +Vocabulary +---------- +A forecasting solver declares a ``capabilities`` set drawn from: + +- :data:`MULTIVARIATE` — the model treats target channels jointly. + *Declarative only*: targets are always passed whole (no channel + splitting), so there is no behavioural toggle for this yet — it exists to + describe the model until a multivariate-*target* dataset and the matching + masking land. +- :data:`HIST_COVARIATES` — the model consumes history-only (past) covariates. +- :data:`FUTURE_COVARIATES` — the model consumes known-ahead (future) covariates. + +``univariate`` is deliberately **not** a flag — it is the floor every model +gets. A model that declares (or has enabled) none of the covariate +capabilities runs univariate. + +Deactivation / lift +------------------- +The covariate capabilities are independently switchable per run (exposed as +benchopt parameters by the consuming solver), so the lift each one provides +can be benchmarked. Enforcement is central: the objective masks the +:class:`~benchmark_utils.covariates.Covariates` payload down to the adapter's +*effective* active set (``BaseTSFMAdapter.covariate_capabilities``) via +:func:`mask_covariates` before calling ``predict``. A model therefore only +ever sees covariates it both declares and has enabled. Targets are never +masked. +""" + +from benchmark_utils.covariates import Covariates + +MULTIVARIATE = "multivariate" +HIST_COVARIATES = "hist_covariates" +FUTURE_COVARIATES = "future_covariates" + +#: Capabilities whose covariate payload :func:`mask_covariates` acts on. +COVARIATE_CAPABILITIES = frozenset({HIST_COVARIATES, FUTURE_COVARIATES}) + +#: Every capability in the vocabulary. +ALL_CAPABILITIES = frozenset({MULTIVARIATE, HIST_COVARIATES, FUTURE_COVARIATES}) + + +def mask_covariates(covariates: Covariates, active) -> Covariates: + """Return a copy of ``covariates`` with disabled covariate fields emptied. + + ``hist_covars`` is cleared unless :data:`HIST_COVARIATES` is in ``active``, + and ``future_covars`` unless :data:`FUTURE_COVARIATES` is in ``active``. + ``static_covars`` is passed through unchanged — it is not yet part of the + capability vocabulary. Targets live in ``ForecastInput.x`` and are never + touched here. + + Parameters + ---------- + covariates : Covariates + The dataset's full covariate payload. + active : Iterable[str] + The effective active capability names (typically an adapter's + ``covariate_capabilities``). + + Returns + ------- + Covariates + A new (frozen) instance; the input is not mutated. + """ + active = frozenset(active) + return Covariates( + static_covars=covariates.static_covars, + hist_covars=( + covariates.hist_covars if HIST_COVARIATES in active else [] + ), + future_covars=( + covariates.future_covars if FUTURE_COVARIATES in active else [] + ), + ) diff --git a/objective.py b/objective.py index 1f09243..52e9468 100644 --- a/objective.py +++ b/objective.py @@ -117,13 +117,19 @@ def evaluate_result(self, model): # --- forecasting --------------------------------------------------- def _eval_forecasting(self, model): + from benchmark_utils.capabilities import mask_covariates from benchmark_utils.inputs import ForecastInput + # Mask the covariate payload down to what this model declares it can + # use and has enabled. A model that consumes no covariates (the + # default) thus runs univariate; toggling a capability off here is + # what makes its lift measurable. Targets are never masked. + active = getattr(model, "covariate_capabilities", frozenset()) forecast = model.predict( ForecastInput( x=self.X_test, cutoff_indexes=self.cutoff_indexes, - covariates=self.covariates, + covariates=mask_covariates(self.covariates, active), ) ).flatten() # canonical (M, Q, H, C) shape for metrics diff --git a/solvers/chronos.py b/solvers/chronos.py index 3c5bdad..f79ce30 100644 --- a/solvers/chronos.py +++ b/solvers/chronos.py @@ -253,6 +253,9 @@ class Solver(BaseSolver): sampling_strategy = "run_once" + # Chronos (v1) is univariate and consumes no covariates. + capabilities = frozenset() + parameters = { "model_size": ["small"], "layer": [None], diff --git a/solvers/chronos2.py b/solvers/chronos2.py index d0e8234..c19496c 100644 --- a/solvers/chronos2.py +++ b/solvers/chronos2.py @@ -27,6 +27,7 @@ UnpooledEncoder, ) from benchmark_utils.adapters.forecast_residual import ForecastResidualAdapter +from benchmark_utils.capabilities import MULTIVARIATE from benchmark_utils.outputs import ForecastOutput from .chronos import ( @@ -174,6 +175,11 @@ class Solver(BaseSolver): sampling_strategy = "run_once" + # Chronos-2 models channels jointly. ``multivariate`` is declarative + # metadata only — targets are always passed whole, so there is no + # behavioural toggle yet. This solver does not consume covariates. + capabilities = frozenset({MULTIVARIATE}) + parameters = { "model_size": ["small"], "layer": [None], diff --git a/solvers/naive.py b/solvers/naive.py index 3b7edb1..7f04e65 100644 --- a/solvers/naive.py +++ b/solvers/naive.py @@ -94,6 +94,9 @@ class Solver(BaseSolver): sampling_strategy = "run_once" + # Per-channel univariate baseline; consumes no covariates. + capabilities = frozenset() + parameters = { "seasonality": [1], } diff --git a/solvers/seasonal_naive.py b/solvers/seasonal_naive.py index 176691d..8953dd8 100644 --- a/solvers/seasonal_naive.py +++ b/solvers/seasonal_naive.py @@ -64,6 +64,9 @@ class Solver(BaseSolver): sampling_strategy = "run_once" + # Per-channel univariate baseline; consumes no covariates. + capabilities = frozenset() + parameters = { "season_length": [1, 7, 12, 24], } diff --git a/solvers/tfc_api.py b/solvers/tfc_api.py index a507c9c..78f3013 100644 --- a/solvers/tfc_api.py +++ b/solvers/tfc_api.py @@ -32,6 +32,12 @@ from benchopt import BaseSolver from benchmark_utils.adapters.base import BaseTSFMAdapter +from benchmark_utils.capabilities import ( + FUTURE_COVARIATES, + HIST_COVARIATES, + MULTIVARIATE, +) +from benchmark_utils.covariates import Covariates from benchmark_utils.inputs import ForecastInput from benchmark_utils.outputs import ForecastOutput @@ -69,6 +75,47 @@ def _shared_offsets_from_end(x, cutoff_indexes): return reference +def _as_2d(arr) -> np.ndarray: + """Normalise a covariate cell to ``(T, n)``.""" + arr = np.asarray(arr, dtype=np.float32) + return arr[:, None] if arr.ndim == 1 else arr + + +def _covar_var_names(covariates: Covariates) -> tuple[list[str], list[str]]: + """Column names for the SDK's ``historical_variables`` / ``future_variables``. + + Derived from the per-series covariate width (assumed homogeneous across + series). Empty lists when a covariate kind is absent — which is exactly + what the objective produces after masking off a deactivated capability. + """ + hist_names, future_names = [], [] + if covariates.hist_covars: + n = _as_2d(covariates.hist_covars[0]).shape[1] + hist_names = [f"hist_{j}" for j in range(n)] + if covariates.future_covars: + n = _as_2d(covariates.future_covars[0]).shape[1] + future_names = [f"future_{j}" for j in range(n)] + return hist_names, future_names + + +def _attach_covars(frame, covariates: Covariates, series_idx: int): + """Add this series' covariate columns to a per-``unique_id`` frame. + + Covariates are series-level, so every channel frame of a series gets the + same columns. Arrays span the full series length ``T`` (history *and* + horizon), so future-covariate values for each cutoff's horizon are present. + """ + if covariates.hist_covars: + arr = _as_2d(covariates.hist_covars[series_idx]) + for j in range(arr.shape[1]): + frame[f"hist_{j}"] = arr[:, j] + if covariates.future_covars: + arr = _as_2d(covariates.future_covars[series_idx]) + for j in range(arr.shape[1]): + frame[f"future_{j}"] = arr[:, j] + return frame + + class _TFCAPIForecaster(BaseTSFMAdapter): """Adapter calling the TFC SDK. @@ -106,24 +153,30 @@ def __init__( self.batch_size = batch_size def predict(self, x: ForecastInput) -> ForecastOutput: - # TODO: thread ``x.covariates`` (static/hist/future) through to the SDK - # once the benchmark datasets populate them. Monash currently - # carries none, so the dataclass arrives with empty sequences. + # ``x.covariates`` is already masked by the objective down to this + # adapter's ``covariate_capabilities`` — a deactivated (or + # undeclared) covariate kind arrives as an empty sequence, so the + # column/variable wiring below simply produces nothing for it. series_list, cutoff_indexes = x.x, x.cutoff_indexes + covariates = x.covariates + hist_names, future_names = _covar_var_names(covariates) pd_freq = _to_pandas_freq(self.freq) offsets = _shared_offsets_from_end(series_list, cutoff_indexes) if getattr(self.model, "supports_batching", False) and offsets is not None: per_series, levels = self._predict_batched( - series_list, cutoff_indexes, pd_freq, offsets + series_list, cutoff_indexes, pd_freq, offsets, + covariates, hist_names, future_names, ) else: per_series, levels = self._predict_per_series( - series_list, cutoff_indexes, pd_freq + series_list, cutoff_indexes, pd_freq, + covariates, hist_names, future_names, ) return ForecastOutput(quantiles=per_series, quantile_levels=levels) - def _predict_per_series(self, x, cutoff_indexes, pd_freq): + def _predict_per_series(self, x, cutoff_indexes, pd_freq, + covariates, hist_names, future_names): per_series = [] levels = None for series_idx, (series, cutoffs) in enumerate(zip(x, cutoff_indexes)): @@ -134,11 +187,14 @@ def _predict_per_series(self, x, cutoff_indexes, pd_freq): index = pd.date_range("2000-01-01", periods=T, freq=pd_freq) frames = [ - pd.DataFrame({ - "unique_id": f"s{series_idx}_c{c}", - "ds": index, - "target": series[:, c], - }) + _attach_covars( + pd.DataFrame({ + "unique_id": f"s{series_idx}_c{c}", + "ds": index, + "target": series[:, c], + }), + covariates, series_idx, + ) for c in range(C) ] train_df = pd.concat(frames, ignore_index=True) @@ -155,6 +211,8 @@ def _predict_per_series(self, x, cutoff_indexes, pd_freq): add_holidays=self.add_holidays, add_events=self.add_events, country_isocode=self.country_isocode, + historical_variables=hist_names or None, + future_variables=future_names or None, batch_size=self.batch_size, ) @@ -165,7 +223,8 @@ def _predict_per_series(self, x, cutoff_indexes, pd_freq): levels = series_levels return per_series, (levels if levels is not None else (0.5,)) - def _predict_batched(self, x, cutoff_indexes, pd_freq, offsets): + def _predict_batched(self, x, cutoff_indexes, pd_freq, offsets, + covariates, hist_names, future_names): """One ``cross_validate`` call covering every series in ``x``. Series are aligned to share an end date so all cutoffs collapse to @@ -183,11 +242,14 @@ def _predict_batched(self, x, cutoff_indexes, pd_freq, offsets): index = pd.date_range(end=end, periods=T, freq=pd_freq) for c in range(C): frames.append( - pd.DataFrame({ - "unique_id": f"s{series_idx}_c{c}", - "ds": index, - "target": series[:, c], - }) + _attach_covars( + pd.DataFrame({ + "unique_id": f"s{series_idx}_c{c}", + "ds": index, + "target": series[:, c], + }), + covariates, series_idx, + ) ) per_series_meta.append((series_idx, C, index, cutoffs)) @@ -209,6 +271,8 @@ def _predict_batched(self, x, cutoff_indexes, pd_freq, offsets): add_holidays=self.add_holidays, add_events=self.add_events, country_isocode=self.country_isocode, + historical_variables=hist_names or None, + future_variables=future_names or None, batch_size=self.batch_size, ) @@ -271,6 +335,12 @@ class Solver(BaseSolver): ``country_isocode`` to be set. country_isocode : str or None ISO country code (e.g. ``"US"``) used by the holiday/event lookup. + use_hist_covars, use_future_covars : bool + Whether to feed the dataset's historical / future covariates to the + model. Default ``True``; sweep over ``[True, False]`` to benchmark the + lift each covariate kind provides. Deactivating both runs the model + univariate. (The objective enforces this by masking the covariate + payload — see :mod:`benchmark_utils.capabilities`.) batch_size : int Series-per-batch for batching-enabled models (chronos-2, moirai-2). """ @@ -281,12 +351,20 @@ class Solver(BaseSolver): sampling_strategy = "run_once" + # Declared capabilities (metadata). ``multivariate`` is declarative only — + # targets are always passed whole, so there is no behavioural toggle for + # it yet. The two covariate capabilities are wired end-to-end and + # switchable via ``use_hist_covars`` / ``use_future_covars``. + capabilities = frozenset({MULTIVARIATE, HIST_COVARIATES, FUTURE_COVARIATES}) + parameters = { "model": ["chronos-2"], "context": [None], "add_holidays": [False], "add_events": [False], "country_isocode": [None], + "use_hist_covars": [True], + "use_future_covars": [True], "batch_size": [256], } @@ -329,6 +407,15 @@ def run(self, _): country_isocode=self.country_isocode, batch_size=self.batch_size, ) + # Effective active covariate capabilities for this run = the toggled-on + # ones, intersected with what the model declares. The objective reads + # this to mask the covariate payload before calling predict(). + active = set() + if self.use_hist_covars: + active.add(HIST_COVARIATES) + if self.use_future_covars: + active.add(FUTURE_COVARIATES) + self._adapter.covariate_capabilities = frozenset(active & self.capabilities) def get_result(self): return {"model": self._adapter} diff --git a/solvers/toto2.py b/solvers/toto2.py index 207797f..6cd5d22 100644 --- a/solvers/toto2.py +++ b/solvers/toto2.py @@ -22,6 +22,7 @@ ) from benchmark_utils.adapters.base import BaseTSFMAdapter from benchmark_utils.adapters.forecast_residual import ForecastResidualAdapter +from benchmark_utils.capabilities import MULTIVARIATE from benchmark_utils.inputs import ForecastInput from benchmark_utils.outputs import ForecastOutput @@ -265,6 +266,11 @@ class Solver(BaseSolver): sampling_strategy = "run_once" + # Toto models channels jointly. ``multivariate`` is declarative metadata + # only — targets are always passed whole, so there is no behavioural + # toggle yet. This solver does not consume covariates. + capabilities = frozenset({MULTIVARIATE}) + parameters = { "checkpoint": ["Datadog/Toto-2.0-2.5B"], "context_length": [512], diff --git a/tests/benchmark_utils/test_capabilities.py b/tests/benchmark_utils/test_capabilities.py new file mode 100644 index 0000000..9f19187 --- /dev/null +++ b/tests/benchmark_utils/test_capabilities.py @@ -0,0 +1,53 @@ +"""Tests for the covariate-masking helper.""" + +import numpy as np +import pytest + +from benchmark_utils.capabilities import ( + FUTURE_COVARIATES, + HIST_COVARIATES, + mask_covariates, +) +from benchmark_utils.covariates import Covariates + + +@pytest.fixture +def covariates(): + return Covariates( + static_covars=[np.array([1.0, 2.0])], + hist_covars=[np.zeros((10, 1))], + future_covars=[np.ones((10, 2))], + ) + + +def test_empty_active_drops_both_covariates(covariates): + masked = mask_covariates(covariates, frozenset()) + assert masked.hist_covars == [] + assert masked.future_covars == [] + # static is not part of the vocabulary — passed through untouched. + assert masked.static_covars is covariates.static_covars + + +def test_hist_only_keeps_hist_drops_future(covariates): + masked = mask_covariates(covariates, {HIST_COVARIATES}) + assert masked.hist_covars is covariates.hist_covars + assert masked.future_covars == [] + + +def test_future_only_keeps_future_drops_hist(covariates): + masked = mask_covariates(covariates, {FUTURE_COVARIATES}) + assert masked.future_covars is covariates.future_covars + assert masked.hist_covars == [] + + +def test_both_active_preserves_all(covariates): + masked = mask_covariates(covariates, {HIST_COVARIATES, FUTURE_COVARIATES}) + assert masked.hist_covars is covariates.hist_covars + assert masked.future_covars is covariates.future_covars + + +def test_input_not_mutated(covariates): + mask_covariates(covariates, frozenset()) + # Original still intact. + assert len(covariates.hist_covars) == 1 + assert len(covariates.future_covars) == 1 diff --git a/tests/solvers/test_tfc_api_covariates.py b/tests/solvers/test_tfc_api_covariates.py new file mode 100644 index 0000000..d352d21 --- /dev/null +++ b/tests/solvers/test_tfc_api_covariates.py @@ -0,0 +1,88 @@ +"""TFC-API adapter threads covariates through to the SDK. + +Uses a fake ``cross_validate`` that records its kwargs and returns a minimal +forecast frame — no network, no API key. The adapter receives whatever +covariates the objective hands it (already masked), so passing an empty +``Covariates`` here exercises the deactivated path. +""" + +import importlib + +import numpy as np +import pandas as pd + +from benchmark_utils.covariates import Covariates +from benchmark_utils.inputs import ForecastInput + +tfc_api = importlib.import_module("solvers.tfc_api") +_TFCAPIForecaster = tfc_api._TFCAPIForecaster + +MODEL = "mymodel" # a str has no ``supports_batching`` → per-series path +H = 3 + + +class _FakeClient: + def __init__(self): + self.calls = [] + + def cross_validate(self, train_df, *, model, horizon, freq, fcds, quantiles, + context, add_holidays, add_events, country_isocode, + historical_variables, future_variables, batch_size): + self.calls.append({ + "columns": list(train_df.columns), + "historical_variables": historical_variables, + "future_variables": future_variables, + }) + rows = [] + for uid in train_df["unique_id"].unique(): + ds_vals = sorted(train_df.loc[train_df["unique_id"] == uid, "ds"].unique()) + for fcd in fcds: + future_ds = [d for d in ds_vals if d > fcd][:horizon] + for d in future_ds: + rows.append( + {"unique_id": uid, "ds": d, "fcd": fcd, str(model): 0.0} + ) + return pd.DataFrame(rows) + + +def _adapter(client): + return _TFCAPIForecaster( + client=client, model=MODEL, prediction_length=H, freq="D", + context=None, quantiles=None, add_holidays=False, add_events=False, + country_isocode=None, batch_size=256, + ) + + +def _forecast_input(covariates): + series = np.arange(10, dtype=np.float32)[:, None] # (10, 1) + return ForecastInput( + x=[series], cutoff_indexes=[[5]], covariates=covariates + ) + + +def test_covariates_are_passed_to_sdk(): + client = _FakeClient() + out = _adapter(client).predict(_forecast_input(Covariates( + static_covars=[], + hist_covars=[np.zeros((10, 1), dtype=np.float32)], + future_covars=[np.ones((10, 2), dtype=np.float32)], + ))) + + call = client.calls[0] + assert call["historical_variables"] == ["hist_0"] + assert call["future_variables"] == ["future_0", "future_1"] + assert {"hist_0", "future_0", "future_1"}.issubset(call["columns"]) + # Sanity: a well-formed forecast came back. + assert out.point[0].shape == (1, H, 1) + + +def test_no_covariates_passes_none(): + client = _FakeClient() + _adapter(client).predict(_forecast_input(Covariates())) + + call = client.calls[0] + assert call["historical_variables"] is None + assert call["future_variables"] is None + assert not any( + c.startswith(("hist_", "future_")) for c in call["columns"] + ) diff --git a/tests/test_objective_capability_masking.py b/tests/test_objective_capability_masking.py new file mode 100644 index 0000000..eff1e50 --- /dev/null +++ b/tests/test_objective_capability_masking.py @@ -0,0 +1,86 @@ +"""The objective masks covariates to the adapter's declared capabilities. + +Enforcement lives in ``Objective._eval_forecasting`` so it is guaranteed for +every forecasting model, not reimplemented per solver. +""" + +import numpy as np +import pytest + +from benchmark_utils.adapters.base import BaseTSFMAdapter +from benchmark_utils.capabilities import FUTURE_COVARIATES, HIST_COVARIATES +from benchmark_utils.covariates import Covariates +from benchmark_utils.outputs import ForecastOutput + +H, C = 3, 1 + + +class _RecordingAdapter(BaseTSFMAdapter): + """Captures the covariates it is handed, returns a valid zero forecast.""" + + def __init__(self, covariate_capabilities=frozenset()): + self.covariate_capabilities = frozenset(covariate_capabilities) + self.seen = None + + def predict(self, x): + self.seen = x.covariates + qs = [ + np.zeros((len(cutoffs), 1, H, C), dtype=np.float32) + for cutoffs in x.cutoff_indexes + ] + return ForecastOutput(quantiles=qs, quantile_levels=(0.5,)) + + +def _make_objective(): + from objective import Objective + + obj = Objective.get_instance() + series = np.arange(10, dtype=np.float32)[:, None] # (10, 1) + cutoffs = [5] + covariates = Covariates( + static_covars=[], + hist_covars=[np.zeros((10, 1), dtype=np.float32)], + future_covars=[np.ones((10, 2), dtype=np.float32)], + ) + obj.set_data( + X_train=[series[:5]], + y_train=[series[5:8]], + X_test=[series], + y_test=[np.zeros((1, H, C), dtype=np.float32)], + cutoff_indexes=[cutoffs], + covariates=covariates, + task="forecasting", + metrics=["mae"], + prediction_length=H, + ) + return obj + + +@pytest.mark.parametrize( + "active, expect_hist, expect_future", + [ + (frozenset(), False, False), + ({HIST_COVARIATES}, True, False), + ({FUTURE_COVARIATES}, False, True), + ({HIST_COVARIATES, FUTURE_COVARIATES}, True, True), + ], +) +def test_objective_masks_to_adapter_capabilities(active, expect_hist, expect_future): + obj = _make_objective() + adapter = _RecordingAdapter(active) + + obj.evaluate_result(adapter) + + assert (len(adapter.seen.hist_covars) > 0) is expect_hist + assert (len(adapter.seen.future_covars) > 0) is expect_future + + +def test_default_adapter_sees_no_covariates(): + """An adapter that declares nothing (base default) runs univariate.""" + obj = _make_objective() + adapter = _RecordingAdapter() # inherits empty covariate_capabilities + + obj.evaluate_result(adapter) + + assert adapter.seen.hist_covars == [] + assert adapter.seen.future_covars == []