Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
125 changes: 61 additions & 64 deletions src/winml/modelkit/loader/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,26 @@

logger = logging.getLogger(__name__)

# Task abbreviations for cache keys (47 tasks from HuggingFace Transformers)
TASK_ABBREV: dict[str, str] = {
# Vision tasks
# =============================================================================
# Task registry — single source of truth
# =============================================================================
#
# ``_TASK_REGISTRY`` is the one authoritative table of canonical task names that
# ``winml`` recognizes, each paired with its cache-key abbreviation (``None`` ->
# ``get_task_abbrev`` falls back to an 8-char truncation). Both public views are
# *derived* from it, so they can no longer drift apart (the root cause of #724):
#
# * ``KNOWN_TASKS`` (validation / ``inspect --list-tasks``) = the names
# * ``TASK_ABBREV`` (cache-key abbreviations) = the (name, abbrev) pairs
#
# Kept hand-curated (not imported from optimum) so the ``--list-tasks`` fast path
# stays import-cheap — importing ``optimum.exporters`` would transitively pull in
# ``transformers`` and cost ~10s. ``tests/unit/loader/test_known_tasks.py`` guards
# this set against optimum's task list, ``HF_TASK_DEFAULTS``,
# ``HF_MODEL_CLASS_MAPPING`` and ``inference.tasks`` so any task added there fails
# CI until it is mirrored here.
_TASK_REGISTRY: dict[str, str | None] = {
# Vision
"image-classification": "imgcls",
"image-segmentation": "imgseg",
"image-feature-extraction": "imgfeat",
Expand All @@ -37,96 +54,76 @@
"image-text-to-text": "imgtxt2t",
"object-detection": "objdet",
"depth-estimation": "depth",
"instance-segmentation": "instseg",
"semantic-segmentation": "semseg",
"universal-segmentation": "uniseg",
"keypoint-detection": "kptdet",
"keypoint-matching": "kptmtch",
"mask-generation": "maskgen",
"masked-image-modeling": "mskim",
"masked-im": None,
"video-classification": "vidcls",
"zero-shot-image-classification": "zsimg",
"zero-shot-object-detection": "zsobj",
# NLP tasks
"inpainting": None,
"text-to-image": None,
# NLP
"text-classification": "txtcls",
"sequence-classification": "seqcls",
"token-classification": "tokcls",
"question-answering": "qa",
"text-generation": "txtgen",
"text2text-generation": "txt2txt",
"fill-mask": "mask",
"feature-extraction": "feat",
"text-encoding": "txtenc",
"summarization": "summ",
"translation": "transl",
"multiple-choice": "mltchs",
"next-sentence-prediction": "nsp",
"pretraining": "pretrain",
"table-question-answering": "tabqa",
"document-question-answering": "docqa",
"zero-shot-classification": "zscls",
# Audio tasks
"sentence-similarity": None,
# Audio
"audio-classification": "audiocls",
"audio-frame-classification": "audfrm",
"audio-tokenization": "audtok",
"audio-xvector": "audxvc",
"automatic-speech-recognition": "asr",
"text-to-audio": "txt2aud",
"zero-shot-audio-classification": "zsaud",
# Multimodal tasks
# Multimodal
"visual-question-answering": "vqa",
"any-to-any": "a2a",
"multimodal-lm": "mmlm",
# Other tasks
"backbone": "bkbone",
"time-series-prediction": "tseries",
# Other
"reinforcement-learning": None,
"time-series-forecasting": None,
}


# Canonical set of task names recognized by `winml inspect`.
# Hand-coded so that `winml inspect --list-tasks` does not need to import
# optimum.exporters (which transitively imports transformers and costs ~10s).
# Synced with optimum.exporters.tasks.TasksManager.get_all_tasks() plus our
# own HF_TASK_DEFAULTS entries; add new tasks here when optimum gains them.
KNOWN_TASKS: frozenset[str] = frozenset(
{
"audio-classification",
"audio-frame-classification",
"audio-xvector",
"automatic-speech-recognition",
"depth-estimation",
"document-question-answering",
"feature-extraction",
"fill-mask",
"image-classification",
"image-feature-extraction",
"image-segmentation",
"image-text-to-text",
"image-to-image",
"image-to-text",
"inpainting",
"keypoint-detection",
"mask-generation",
"masked-im",
"multiple-choice",
"next-sentence-prediction",
"object-detection",
"question-answering",
"reinforcement-learning",
"semantic-segmentation",
"sentence-similarity",
"text-classification",
"text-generation",
"text-to-audio",
"text-to-image",
"text2text-generation",
"time-series-forecasting",
"token-classification",
"visual-question-answering",
"zero-shot-image-classification",
"zero-shot-object-detection",
}
)
# Aliases that ``normalize_task()`` / ``to_optimum_task()`` collapse to canonical
# forms, so they are deliberately *excluded* from ``KNOWN_TASKS``. They keep a
# stable cache-key abbreviation because a few callers still use the alias name
# directly as the resolved task — composite models register ``summarization`` /
# ``translation`` (see ``models/hf/bart.py``, ``t5.py``, ``marian.py``) and
# ``inference.tasks`` defines a ``TaskInputSpec`` for ``zero-shot-classification``
# — so existing cache directories and the ``serve`` reverse-decode map
# (``app.py``: ``{v: k for k, v in TASK_ABBREV.items()}``) must round-trip.
_TASK_ALIAS_ABBREV: dict[str, str] = {
"pretraining": "pretrain",
"sequence-classification": "seqcls",
"summarization": "summ",
"translation": "transl",
"zero-shot-classification": "zscls",
}


# Canonical set of task names recognized by `winml inspect` (names only).
# Derived from `_TASK_REGISTRY` above — do not hand-edit; add tasks to the
# registry instead so `KNOWN_TASKS` and `TASK_ABBREV` stay in lockstep.
KNOWN_TASKS: frozenset[str] = frozenset(_TASK_REGISTRY)


# Task -> abbreviation for cache keys. Derived from `_TASK_REGISTRY`: canonical
# tasks whose abbreviation is `None` are omitted here (and truncated to 8 chars
# by `get_task_abbrev`); the collapsed aliases are appended for cache-key and
# `serve` reverse-decode stability.
TASK_ABBREV: dict[str, str] = {
**{task: abbrev for task, abbrev in _TASK_REGISTRY.items() if abbrev is not None},
**_TASK_ALIAS_ABBREV,
}


# =============================================================================
Expand Down
140 changes: 133 additions & 7 deletions tests/unit/loader/test_known_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,60 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
"""Tests for the hand-coded KNOWN_TASKS constant.
"""Tests for the hand-coded task registry and its derived views.

KNOWN_TASKS is hand-coded so that ``winml inspect --list-tasks`` does not
need to import ``optimum.exporters`` (which transitively imports
``transformers`` and adds ~10 s of startup latency).
``_TASK_REGISTRY`` (and the ``KNOWN_TASKS`` / ``TASK_ABBREV`` views derived
from it) is hand-coded so that ``winml inspect --list-tasks`` does not need to
import ``optimum.exporters`` (which transitively imports ``transformers`` and
adds ~10 s of startup latency).

These tests guard against drift:
* KNOWN_TASKS must be a superset of optimum's TasksManager task set, so
no canonical task disappears from ``--list-tasks`` when optimum adds one.
* KNOWN_TASKS must include every task registered in our own
HF_TASK_DEFAULTS / HF_MODEL_CLASS_MAPPING.
* KNOWN_TASKS and TASK_ABBREV must stay derived from the single
``_TASK_REGISTRY`` so the two views cannot desync (the root cause of #724).
* Every task wired into ``inference.tasks`` must be a known task, so a newly
added canonical task can never be silently rejected by ``validate_task``.
"""

from __future__ import annotations

import pytest

from winml.modelkit.loader import KNOWN_TASKS
from winml.modelkit.loader.task import (
_TASK_ALIAS_ABBREV,
_TASK_REGISTRY,
TASK_ABBREV,
)


# Tasks #724 confirmed are wired into the codebase (inference TaskInputSpec,
# composite-model registration, or e2e testset model tag) and therefore must be
# accepted by validate_task and advertised by ``--list-tasks``.
WIRED_TASKS = (
"video-classification",
"keypoint-matching",
"table-question-answering",
"zero-shot-audio-classification",
"any-to-any",
)

# Entries audited as stale in #724 — they only ever existed in the old
# TASK_ABBREV table (no inference spec, model registration, test, or doc) and
# were dropped. Locked here so they are not silently re-added.
DROPPED_TASKS = (
"instance-segmentation",
"universal-segmentation",
"masked-image-modeling",
"text-encoding",
"audio-tokenization",
"multimodal-lm",
"backbone",
"time-series-prediction",
)


class TestKnownTasksShape:
Expand Down Expand Up @@ -47,7 +85,7 @@ def test_covers_hf_task_defaults(self) -> None:
missing = set(HF_TASK_DEFAULTS) - KNOWN_TASKS
assert not missing, (
f"HF_TASK_DEFAULTS contains tasks not in KNOWN_TASKS: {sorted(missing)}. "
"Add them to KNOWN_TASKS in src/winml/modelkit/loader/task.py."
"Add them to _TASK_REGISTRY in src/winml/modelkit/loader/task.py."
)

def test_covers_hf_model_class_mapping(self) -> None:
Expand All @@ -57,7 +95,7 @@ def test_covers_hf_model_class_mapping(self) -> None:
missing = registered - KNOWN_TASKS
assert not missing, (
f"HF_MODEL_CLASS_MAPPING uses tasks not in KNOWN_TASKS: {sorted(missing)}. "
"Add them to KNOWN_TASKS in src/winml/modelkit/loader/task.py."
"Add them to _TASK_REGISTRY in src/winml/modelkit/loader/task.py."
)

def test_covers_optimum_tasks(self) -> None:
Expand All @@ -74,5 +112,93 @@ def test_covers_optimum_tasks(self) -> None:
missing = optimum_tasks - KNOWN_TASKS
assert not missing, (
f"optimum exposes tasks not in KNOWN_TASKS: {sorted(missing)}. "
"Add them to KNOWN_TASKS in src/winml/modelkit/loader/task.py."
"Add them to _TASK_REGISTRY in src/winml/modelkit/loader/task.py."
)


class TestTaskRegistrySingleSourceOfTruth:
"""KNOWN_TASKS and TASK_ABBREV must derive from one registry (#724).

The root cause of #724 was two hand-maintained tables (``TASK_ABBREV`` and
``KNOWN_TASKS``) that drifted apart. They now both derive from
``_TASK_REGISTRY``; these tests lock that invariant.
"""

def test_known_tasks_equals_registry_keys(self) -> None:
assert frozenset(_TASK_REGISTRY) == KNOWN_TASKS

def test_every_canonical_abbrev_key_is_known(self) -> None:
"""Every non-alias key in TASK_ABBREV must be a known task."""
non_alias = {task for task in TASK_ABBREV if task not in _TASK_ALIAS_ABBREV}
missing = non_alias - KNOWN_TASKS
assert not missing, f"TASK_ABBREV canonical keys not in KNOWN_TASKS: {sorted(missing)}"

def test_aliases_excluded_from_known_tasks(self) -> None:
"""Collapsed aliases keep a cache-key abbrev but are not canonical tasks."""
overlap = set(_TASK_ALIAS_ABBREV) & KNOWN_TASKS
assert not overlap, f"Aliases must not appear in KNOWN_TASKS: {sorted(overlap)}"

def test_abbreviations_are_unique(self) -> None:
"""serve/app.py inverts TASK_ABBREV to {abbrev: task}; collisions lose entries."""
values = list(TASK_ABBREV.values())
assert len(values) == len(set(values)), "Duplicate abbreviations in TASK_ABBREV"


class TestWiredTasksAccepted:
"""Regression for #724: the wired tasks must pass validation, not be rejected."""

@pytest.mark.parametrize("task", WIRED_TASKS)
def test_wired_task_in_known_tasks(self, task: str) -> None:
assert task in KNOWN_TASKS

@pytest.mark.parametrize("task", WIRED_TASKS)
def test_resolver_validate_task_accepts(self, task: str) -> None:
from winml.modelkit.inspect.resolver import validate_task

validate_task(task) # must not raise

@pytest.mark.parametrize("task", WIRED_TASKS)
def test_click_callback_accepts(self, task: str) -> None:
from winml.modelkit.commands.inspect import _validate_task

assert _validate_task(None, None, task) == task # must not raise


class TestStaleTasksDropped:
"""Audit lock for #724: the 8 stale names stay out of every task view."""

@pytest.mark.parametrize("task", DROPPED_TASKS)
def test_dropped_from_known_tasks(self, task: str) -> None:
assert task not in KNOWN_TASKS

@pytest.mark.parametrize("task", DROPPED_TASKS)
def test_dropped_from_task_abbrev(self, task: str) -> None:
assert task not in TASK_ABBREV


class TestInferenceTasksAreKnown:
"""Every inference task must be reachable from the loader task registry.

This is the cross-source guard #724 was missing: a ``TaskInputSpec`` for a
genuinely new canonical task (as ``video-classification`` was) added to
``inference.tasks`` but not to ``_TASK_REGISTRY`` would be silently rejected
by ``validate_task``.
"""

# Pipeline-sugar aliases that live only in inference.tasks (they share another
# task's input schema; there is no separate canonical task). A NEW canonical
# task belongs in _TASK_REGISTRY (loader/task.py), NOT in this allowlist.
INFERENCE_ONLY_ALIASES = frozenset(
{"sentiment-analysis", "ner", "vqa", "text-to-speech"}
)

def test_inference_tasks_are_known_or_alias(self) -> None:
from winml.modelkit.inference.tasks import TASK_REGISTRY as INFERENCE_TASKS

recognized = KNOWN_TASKS | set(_TASK_ALIAS_ABBREV) | self.INFERENCE_ONLY_ALIASES
missing = set(INFERENCE_TASKS) - recognized
assert not missing, (
f"inference.tasks defines tasks unknown to the loader registry: {sorted(missing)}. "
"If these are real canonical tasks, add them to _TASK_REGISTRY in "
"src/winml/modelkit/loader/task.py (do not extend INFERENCE_ONLY_ALIASES)."
)
Loading