-
Notifications
You must be signed in to change notification settings - Fork 21
ENH model capability flags + per-capability deactivation (covariate lift) #38
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,75 @@ | ||
| """Model capability flags and covariate masking. | ||
|
|
||
| Vocabulary | ||
| ---------- | ||
| A forecasting solver declares a ``capabilities`` set drawn from: | ||
|
|
||
| - :data:`MULTIVARIATE` — the model treats target channels jointly. | ||
| *Declarative only*: targets are always passed whole (no channel | ||
| splitting), so there is no behavioural toggle for this yet — it exists to | ||
| describe the model until a multivariate-*target* dataset and the matching | ||
| masking land. | ||
| - :data:`HIST_COVARIATES` — the model consumes history-only (past) covariates. | ||
| - :data:`FUTURE_COVARIATES` — the model consumes known-ahead (future) covariates. | ||
|
|
||
| ``univariate`` is deliberately **not** a flag — it is the floor every model | ||
| gets. A model that declares (or has enabled) none of the covariate | ||
| capabilities runs univariate. | ||
|
|
||
| Deactivation / lift | ||
| ------------------- | ||
| The covariate capabilities are independently switchable per run (exposed as | ||
| benchopt parameters by the consuming solver), so the lift each one provides | ||
| can be benchmarked. Enforcement is central: the objective masks the | ||
| :class:`~benchmark_utils.covariates.Covariates` payload down to the adapter's | ||
| *effective* active set (``BaseTSFMAdapter.covariate_capabilities``) via | ||
| :func:`mask_covariates` before calling ``predict``. A model therefore only | ||
| ever sees covariates it both declares and has enabled. Targets are never | ||
| masked. | ||
| """ | ||
|
|
||
| from benchmark_utils.covariates import Covariates | ||
|
|
||
| MULTIVARIATE = "multivariate" | ||
| HIST_COVARIATES = "hist_covariates" | ||
| FUTURE_COVARIATES = "future_covariates" | ||
|
Comment on lines
+33
to
+35
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Probably better modeled as an enum |
||
|
|
||
| #: Capabilities whose covariate payload :func:`mask_covariates` acts on. | ||
| COVARIATE_CAPABILITIES = frozenset({HIST_COVARIATES, FUTURE_COVARIATES}) | ||
|
|
||
| #: Every capability in the vocabulary. | ||
| ALL_CAPABILITIES = frozenset({MULTIVARIATE, HIST_COVARIATES, FUTURE_COVARIATES}) | ||
|
|
||
|
|
||
| def mask_covariates(covariates: Covariates, active) -> Covariates: | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Missing type for |
||
| """Return a copy of ``covariates`` with disabled covariate fields emptied. | ||
|
|
||
| ``hist_covars`` is cleared unless :data:`HIST_COVARIATES` is in ``active``, | ||
| and ``future_covars`` unless :data:`FUTURE_COVARIATES` is in ``active``. | ||
| ``static_covars`` is passed through unchanged — it is not yet part of the | ||
| capability vocabulary. Targets live in ``ForecastInput.x`` and are never | ||
| touched here. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| covariates : Covariates | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Types in doc duplicated with method signature. |
||
| The dataset's full covariate payload. | ||
| active : Iterable[str] | ||
| The effective active capability names (typically an adapter's | ||
| ``covariate_capabilities``). | ||
|
|
||
| Returns | ||
| ------- | ||
| Covariates | ||
| A new (frozen) instance; the input is not mutated. | ||
| """ | ||
| active = frozenset(active) | ||
| return Covariates( | ||
| static_covars=covariates.static_covars, | ||
| hist_covars=( | ||
| covariates.hist_covars if HIST_COVARIATES in active else [] | ||
| ), | ||
| future_covars=( | ||
| covariates.future_covars if FUTURE_COVARIATES in active else [] | ||
| ), | ||
| ) | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -117,13 +117,19 @@ def evaluate_result(self, model): | |
| # --- forecasting --------------------------------------------------- | ||
|
|
||
| def _eval_forecasting(self, model): | ||
| from benchmark_utils.capabilities import mask_covariates | ||
| from benchmark_utils.inputs import ForecastInput | ||
|
|
||
| # Mask the covariate payload down to what this model declares it can | ||
| # use and has enabled. A model that consumes no covariates (the | ||
| # default) thus runs univariate; toggling a capability off here is | ||
| # what makes its lift measurable. Targets are never masked. | ||
| active = getattr(model, "covariate_capabilities", frozenset()) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'd suggest not just silently using |
||
| forecast = model.predict( | ||
| ForecastInput( | ||
| x=self.X_test, | ||
| cutoff_indexes=self.cutoff_indexes, | ||
| covariates=self.covariates, | ||
| covariates=mask_covariates(self.covariates, active), | ||
| ) | ||
| ).flatten() # canonical (M, Q, H, C) shape for metrics | ||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can this be more specific, e.g.,
frozenset[str]or elements from an enum?