Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions examples/recipes/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ Total: **75** (model, task) tuples that pass fp16 eval on all 10 (EP, device) bu
| ahotrod/electra_large_discriminator_squad2_512 | question-answering |
| apple/mobilevit-small | image-classification |
| cardiffnlp/twitter-roberta-base-sentiment-latest | text-classification |
| dandelin/vilt-b32-finetuned-vqa | visual-question-answering |
| dbmdz/bert-large-cased-finetuned-conll03-english | token-classification |
| deepset/bert-large-uncased-whole-word-masking-squad2 | question-answering |
| deepset/roberta-base-squad2 | question-answering |
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
{
"export": {
"opset_version": 17,
"batch_size": 1,
"export_params": true,
"do_constant_folding": true,
"verbose": false,
"dynamo": false,
"enable_hierarchy_tags": true,
"clean_onnx": false,
"hierarchy_tag_format": "full",
"input_tensors": [
{
"name": "input_ids",
"dtype": "int32",
"shape": [
1,
40
],
"value_range": [
0,
30522
]
},
{
"name": "attention_mask",
"dtype": "int32",
"shape": [
1,
40
],
"value_range": [

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Recipe contradicts the PR's own reviewer note on mask value_range. The PR description says "Recipe value_range for mask-of-ones inputs must be [1, 2] not [0, 1] because randint high is exclusive." But here attention_mask (and token_type_ids below) use [0, 2]. It's functionally harmless for export — tracing doesn't depend on input values and the PT-vs-ORT check feeds identical inputs to both — but the note and the file disagree and will mislead anyone re-deriving this recipe. Please reconcile (fix the note or the values).

0,
2
]
},
{
"name": "token_type_ids",
"dtype": "int32",
"shape": [
1,
40
],
"value_range": [
0,
2
]
},
{
"name": "pixel_values",
"dtype": "float32",
"shape": [
1,
3,
384,
384
],
"value_range": [
0,
1
]
}
],
"output_tensors": [
{
"name": "logits"
}
]
},
"optim": {},
"quant": null,
"compile": null,
"loader": {
"task": "visual-question-answering",
"model_class": "ViltForQuestionAnswering",
"model_type": "vilt"
}
}
3 changes: 3 additions & 0 deletions src/winml/modelkit/models/hf/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,8 @@
VisionDecoderIOConfig as _VisionDecoderIOConfig, # triggers registration
)
from .vision_encoder_decoder import VisionEncoderIOConfig as _VisionEncoderIOConfig
from .vilt import MODEL_CLASS_MAPPING as _VILT_CLASS_MAPPING

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This will fail CI lint (ruff I001). The two new .vilt imports are placed after .vision_encoder_decoder, but vilt sorts before vision_encoder_decoder (vil < vis). Ruff's I rule is enabled and this file's per-file-ignores (D104, E402, F401, F403) don't exempt I, so ruff check errors with I001 Import block is un-sorted (confirmed against this branch). Per the repo CLAUDE.md ("Run uv run ruff check --fix after revising Python code"), running that reorders the block and resolves it.

from .vilt import ViltVqaOnnxConfig as _ViltVqaOnnxConfig # triggers registration
from .zoedepth import ZoeDepthIOConfig as _ZoeDepthIOConfig # triggers registration


Expand All @@ -97,6 +99,7 @@
**_SIGLIP_CLASS_MAPPING,
**_T5_CLASS_MAPPING,
**_VED_CLASS_MAPPING,
**_VILT_CLASS_MAPPING,
}

# Registry: model_type -> WinMLBuildConfig
Expand Down
242 changes: 242 additions & 0 deletions src/winml/modelkit/models/hf/vilt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,242 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
"""ViLT (Vision-and-Language Transformer) HuggingFace Model Configuration.

ViLT is a single-stream multi-modal transformer that processes text + image
in a unified attention stack. The ``ViltForQuestionAnswering`` head produces
classification logits over a fixed VQAv2 answer vocabulary (3129 labels for
``dandelin/vilt-b32-finetuned-vqa``).

Optimum has NO vendor ``ViltOnnxConfig`` (verified 2026-06-24: ``vilt`` is
absent from ``TasksManager._SUPPORTED_MODEL_TYPE`` for the transformers
library). This module writes the export config from scratch.

The forward takes 4 required tensors (pixel_mask is omitted — see Notes):
- ``pixel_values`` [B, 3, 384, 384] RGB image
- ``input_ids`` [B, 40] tokenized question
- ``attention_mask`` [B, 40] text padding mask
- ``token_type_ids`` [B, 40] BERT segment IDs (modality)

Output: ``logits`` [B, num_labels] over the answer vocabulary.

Notes
-----
ViLT's stock ``visual_embed`` is fundamentally NOT ONNX-traceable: it iterates
Python-level over tensor values (``for h, w in zip(x_h, x_w)``), uses
``torch.multinomial`` (random + non-exportable), and does per-row Python loops
over ``nonzero()`` results. We replace it during export with a statically-
shaped equivalent (see ``_patched_visual_embed`` + ``_ViltVisualEmbedPatcher``)
that assumes an all-ones ``pixel_mask`` — which is exactly what ``ViltProcessor``
emits in production (the processor pre-pads images to 384×384). Because the
patched path ignores ``pixel_mask``, we drop it from the exported ONNX graph.
Verified numerically equivalent: ``cos=1.000000``, same argmax,
max_abs_diff≈1.2e-5.

This is an Effort-L1 contribution per the `adding-model-support` skill:
new OnnxConfig from scratch + custom model patcher.
"""

from __future__ import annotations

import types

from optimum.exporters.onnx import OnnxConfig
from optimum.exporters.onnx.model_patcher import ModelPatcher
from optimum.utils import NormalizedTextConfig
from optimum.utils.input_generators import DummyVisionInputGenerator
from transformers import ViltForQuestionAnswering

from ...export import MaxLengthTextInputGenerator, register_onnx_overwrite


# =============================================================================
# Export-time patch for ``ViltEmbeddings.visual_embed``
# =============================================================================
# Upstream ``visual_embed`` is **not ONNX-traceable** as written:
# * ``for h, w in zip(x_h, x_w)`` iterates Python-level over tensor values
# * ``nonzero()`` + ``unique()`` + per-row Python list-comprehension subset
# selection over a dynamic ``valid_idx``
# * ``torch.multinomial`` random sampling (non-deterministic, not exportable)
# The eager path silently "works" only when ``pixel_mask`` is all-ones and the
# batch is 1, because the Python loop runs once with a concrete value. Under
# legacy ``torch.onnx.export`` tracing the shape resolves to 0 and PyTorch's
# ``F.interpolate`` aborts with ``input (H: 12, W: 12) output (H: 0, W: 0)``.
#
# For the production ``visual-question-answering`` inference path the
# ``ViltProcessor`` ALWAYS pads to 384×384 and emits an all-ones ``pixel_mask``,
# so the per-sample subset selection is a no-op. We replace ``visual_embed``
# during export with a simplified, statically-shaped implementation that:
# * uses ``x.shape[2], x.shape[3]`` (static) for position-embed interpolation
# * skips ``multinomial`` / nonzero / Python-level batch loops
# * returns an all-ones token mask of length ``H*W + 1`` (patches + CLS)
#
# Verified numerically equivalent on ``dandelin/vilt-b32-finetuned-vqa`` with
# fixed seed: ``cos=1.000000``, same ``argmax`` class, ``max_abs_diff≈1.2e-5``
# (within fp tolerance from interpolation operation ordering).


def _patched_visual_embed(self, pixel_values, pixel_mask, max_image_length=200):
"""Static-shape, ONNX-traceable replacement for ``ViltEmbeddings.visual_embed``."""
import torch
from torch import nn

x = self.patch_embeddings(pixel_values)
batch_size, num_channels, height, width = x.shape

patch_dim = self.config.image_size // self.config.patch_size
spatial_pos = self.position_embeddings[:, 1:, :].transpose(1, 2).view(
1, num_channels, patch_dim, patch_dim
)
pos_embed = nn.functional.interpolate(
spatial_pos, size=(height, width), mode="bilinear", align_corners=True
)
pos_embed = pos_embed.flatten(2).transpose(1, 2).expand(batch_size, -1, -1)

x = x.flatten(2).transpose(1, 2)

cls_tokens = self.cls_token.expand(batch_size, -1, -1)
x = torch.cat((cls_tokens, x), dim=1)
pos_cls = self.position_embeddings[:, 0:1, :].expand(batch_size, -1, -1)
pos_embed = torch.cat((pos_cls, pos_embed), dim=1)
x = x + pos_embed
x = self.dropout(x)

num_tokens = height * width + 1 # patches + CLS
x_mask = torch.ones(batch_size, num_tokens, dtype=torch.long, device=x.device)
return x, x_mask, None


class _ViltVisualEmbedPatcher(ModelPatcher):
"""Swaps ``ViltEmbeddings.visual_embed`` for the duration of ONNX export."""

def __enter__(self):
super().__enter__()
emb = self._model.vilt.embeddings if hasattr(self._model, "vilt") else self._model.embeddings
self._emb_ref = emb
self._orig_visual_embed = emb.visual_embed
emb.visual_embed = types.MethodType(_patched_visual_embed, emb)
return self

def __exit__(self, exc_type, exc_value, traceback):
self._emb_ref.visual_embed = self._orig_visual_embed
super().__exit__(exc_type, exc_value, traceback)


# =============================================================================
# Optimum ONNX Export Config Registration
# =============================================================================
@register_onnx_overwrite("vilt", "visual-question-answering", library_name="transformers")
class ViltVqaOnnxConfig(OnnxConfig):
"""ONNX export config for ``ViltForQuestionAnswering``.

Declares 4 multi-modal inputs (text triple + pixel_values) and the single
classification output. ``pixel_mask`` is deliberately omitted — see
``inputs`` property docstring and the module-level ``Notes`` section for
the full rationale.

Inputs:
- ``input_ids``: [B, 40] int64
- ``attention_mask``: [B, 40] int64
- ``token_type_ids``: [B, 40] int64
- ``pixel_values``: [B, 3, 384, 384] float32

Outputs:
- ``logits``: [B, num_labels=3129] float32

Notes:
- ``num_labels`` (3129 for VQAv2) is a config-time fact, not declared
dynamic in the symbolic axes — it's a static dim of ``logits``.
- ``sequence_length`` resolves to ``max_position_embeddings`` (40 for
ViLT-B/32) via ``NORMALIZED_CONFIG_CLASS``; the
``MaxLengthTextInputGenerator`` reads this for dummy tokens.
- Chained ``DummyVisionInputGenerator`` + ``MaxLengthTextInputGenerator``
produce ``pixel_values`` + ``input_ids``/``attention_mask``/
``token_type_ids``. The patched ``visual_embed`` (see module-level
``_ViltVisualEmbedPatcher``) synthesizes an all-ones token mask
internally, so no ``pixel_mask`` input is required.
"""

NORMALIZED_CONFIG_CLASS = NormalizedTextConfig.with_args(
sequence_length="max_position_embeddings",
num_channels="num_channels",
image_size="image_size",
patch_size="patch_size",
allow_new=True,
)

DUMMY_INPUT_GENERATOR_CLASSES = (
DummyVisionInputGenerator,
MaxLengthTextInputGenerator,
)

DEFAULT_ONNX_OPSET = 17

@property
def inputs(self) -> dict[str, dict[int, str]]:
"""Declare 4 model inputs (insertion order matches forward).

``pixel_values`` H,W is kept STATIC — ViLT interpolates position
embeddings from the actual H,W, and exposing those as dynamic symbols
trips the ONNX ``Resize`` shape-inference (``input (H:12 W:12) output
(H:0 W:0)``). Pinning H,W matches all known production usage (always
384×384 input via ``ViltProcessor``).

Note: ViLT's ``forward`` also takes a ``pixel_mask`` parameter, but
this contribution exports without it. The ``ViltProcessor`` always
emits an all-ones mask (the image is padded to 384×384 before the
model sees it), and our export-time ``ModelPatcher`` replaces the
original ``visual_embed`` with a statically-shaped version that
synthesizes an all-ones token mask internally. Including ``pixel_mask``
as an ONNX input would dead-code-eliminate (since the patched path
doesn't reference it) and confuse runtime callers.
"""
return {
"input_ids": {0: "batch_size", 1: "sequence_length"},
"attention_mask": {0: "batch_size", 1: "sequence_length"},
"token_type_ids": {0: "batch_size", 1: "sequence_length"},
"pixel_values": {0: "batch_size"},

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Static pixel_values H/W silently restricts the export to square images, and the justification here is inaccurate.

I checked ViltImageProcessor against the runtime. The default is size={"shortest_edge": 384}, size_divisor=32 — it pins the shorter edge to 384 and preserves aspect ratio (longer edge ≈ up to int(1333/800*384), floored to a multiple of 32), then pads to the per-batch max. It does not pad to 384×384. Empirically (batch=1, defaults):

input H×W pixel_values pixel_mask all-ones
384×384 (1,3,384,384)
480×640 (1,3,384,512)
640×480 (1,3,512,384)
800×600 (1,3,512,384)

So the all-ones pixel_mask assumption is correct (good), but "always 384×384 via ViltProcessor" (docstring L183) is not — only square inputs land on 384×384. Because inputs marks only batch_size dynamic, the exported ONNX accepts only 384×384; a standard non-square ViltProcessor output (e.g. 384×512) fails at session.run, or forces callers to square-resize (distorts aspect ratio → VQA-accuracy risk). L2 numerics likely passed because validation used a 384×384 input.

Also: the cited reason for pinning — "dynamic symbols trip Resize shape-inference (H:12 W:12 → H:0 W:0)" (L181–183) — describes the original visual_embed (pixel_mask.sum()→0 under tracing), not the patched path. The patch replaced that with a static-grid bilinear interpolate on real dims, so the 0×0 premise no longer applies and dynamic H/W may export fine now.

Suggest either (a) make H/W dynamic — the patch already interpolates to actual x.shape[2], x.shape[3], so re-test the Resize export — or (b) keep it static but document the real constraint honestly (square-384-only + the preprocessing/accuracy caveat) instead of asserting the processor always emits 384×384.

}

@property
def outputs(self) -> dict[str, dict[int, str]]:
"""Single classification output over fixed answer vocabulary."""
return {
"logits": {0: "batch_size"},
}

def generate_dummy_inputs(self, framework: str = "pt", **kwargs): # type: ignore[override]
"""Generate the 4 declared inputs via the chained vendor generators.

``pixel_mask`` is intentionally NOT generated — see ``inputs`` docstring.
Our model patcher's replacement ``visual_embed`` synthesizes an
all-ones token mask internally, so the model can be called with the
4 declared inputs.
"""
dummy = super().generate_dummy_inputs(framework=framework, **kwargs)
# Drop any pixel_mask the generators may have produced — the patched
# visual_embed ignores it (and including it would error at sess.run
# since it isn't in the exported ONNX graph).
dummy.pop("pixel_mask", None)
return dummy

def patch_model_for_export(self, model, model_kwargs=None): # type: ignore[override]
"""Install the ``visual_embed`` patcher for the export context."""
return _ViltVisualEmbedPatcher(self, model, model_kwargs=model_kwargs)


# =============================================================================
# HuggingFace Model Class Mapping
# =============================================================================
# ``visual-question-answering`` has no default AutoModel routing for ViLT;
# bind the (model_type, task) tuple directly to the head-bearing HF class.
MODEL_CLASS_MAPPING: dict[tuple[str, str], type] = {
("vilt", "visual-question-answering"): ViltForQuestionAnswering,
}


__all__ = [
"ViltVqaOnnxConfig",
"MODEL_CLASS_MAPPING",
]
Loading