-
Notifications
You must be signed in to change notification settings - Fork 4
Add ViLT (dandelin/vilt-b32-finetuned-vqa) visual-question-answering support #951
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,78 @@ | ||
| { | ||
| "export": { | ||
| "opset_version": 17, | ||
| "batch_size": 1, | ||
| "export_params": true, | ||
| "do_constant_folding": true, | ||
| "verbose": false, | ||
| "dynamo": false, | ||
| "enable_hierarchy_tags": true, | ||
| "clean_onnx": false, | ||
| "hierarchy_tag_format": "full", | ||
| "input_tensors": [ | ||
| { | ||
| "name": "input_ids", | ||
| "dtype": "int32", | ||
| "shape": [ | ||
| 1, | ||
| 40 | ||
| ], | ||
| "value_range": [ | ||
| 0, | ||
| 30522 | ||
| ] | ||
| }, | ||
| { | ||
| "name": "attention_mask", | ||
| "dtype": "int32", | ||
| "shape": [ | ||
| 1, | ||
| 40 | ||
| ], | ||
| "value_range": [ | ||
| 0, | ||
| 2 | ||
| ] | ||
| }, | ||
| { | ||
| "name": "token_type_ids", | ||
| "dtype": "int32", | ||
| "shape": [ | ||
| 1, | ||
| 40 | ||
| ], | ||
| "value_range": [ | ||
| 0, | ||
| 2 | ||
| ] | ||
| }, | ||
| { | ||
| "name": "pixel_values", | ||
| "dtype": "float32", | ||
| "shape": [ | ||
| 1, | ||
| 3, | ||
| 384, | ||
| 384 | ||
| ], | ||
| "value_range": [ | ||
| 0, | ||
| 1 | ||
| ] | ||
| } | ||
| ], | ||
| "output_tensors": [ | ||
| { | ||
| "name": "logits" | ||
| } | ||
| ] | ||
| }, | ||
| "optim": {}, | ||
| "quant": null, | ||
| "compile": null, | ||
| "loader": { | ||
| "task": "visual-question-answering", | ||
| "model_class": "ViltForQuestionAnswering", | ||
| "model_type": "vilt" | ||
| } | ||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -75,6 +75,8 @@ | |
| VisionDecoderIOConfig as _VisionDecoderIOConfig, # triggers registration | ||
| ) | ||
| from .vision_encoder_decoder import VisionEncoderIOConfig as _VisionEncoderIOConfig | ||
| from .vilt import MODEL_CLASS_MAPPING as _VILT_CLASS_MAPPING | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This will fail CI lint ( |
||
| from .vilt import ViltVqaOnnxConfig as _ViltVqaOnnxConfig # triggers registration | ||
| from .zoedepth import ZoeDepthIOConfig as _ZoeDepthIOConfig # triggers registration | ||
|
|
||
|
|
||
|
|
@@ -97,6 +99,7 @@ | |
| **_SIGLIP_CLASS_MAPPING, | ||
| **_T5_CLASS_MAPPING, | ||
| **_VED_CLASS_MAPPING, | ||
| **_VILT_CLASS_MAPPING, | ||
| } | ||
|
|
||
| # Registry: model_type -> WinMLBuildConfig | ||
|
|
||
| Original file line number | Diff line number | Diff line change | |||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,242 @@ | |||||||||||||||||
| # ------------------------------------------------------------------------- | |||||||||||||||||
| # Copyright (c) Microsoft Corporation. All rights reserved. | |||||||||||||||||
| # Licensed under the MIT License. | |||||||||||||||||
| # -------------------------------------------------------------------------- | |||||||||||||||||
| """ViLT (Vision-and-Language Transformer) HuggingFace Model Configuration. | |||||||||||||||||
|
|
|||||||||||||||||
| ViLT is a single-stream multi-modal transformer that processes text + image | |||||||||||||||||
| in a unified attention stack. The ``ViltForQuestionAnswering`` head produces | |||||||||||||||||
| classification logits over a fixed VQAv2 answer vocabulary (3129 labels for | |||||||||||||||||
| ``dandelin/vilt-b32-finetuned-vqa``). | |||||||||||||||||
|
|
|||||||||||||||||
| Optimum has NO vendor ``ViltOnnxConfig`` (verified 2026-06-24: ``vilt`` is | |||||||||||||||||
| absent from ``TasksManager._SUPPORTED_MODEL_TYPE`` for the transformers | |||||||||||||||||
| library). This module writes the export config from scratch. | |||||||||||||||||
|
|
|||||||||||||||||
| The forward takes 4 required tensors (pixel_mask is omitted — see Notes): | |||||||||||||||||
| - ``pixel_values`` [B, 3, 384, 384] RGB image | |||||||||||||||||
| - ``input_ids`` [B, 40] tokenized question | |||||||||||||||||
| - ``attention_mask`` [B, 40] text padding mask | |||||||||||||||||
| - ``token_type_ids`` [B, 40] BERT segment IDs (modality) | |||||||||||||||||
|
|
|||||||||||||||||
| Output: ``logits`` [B, num_labels] over the answer vocabulary. | |||||||||||||||||
|
|
|||||||||||||||||
| Notes | |||||||||||||||||
| ----- | |||||||||||||||||
| ViLT's stock ``visual_embed`` is fundamentally NOT ONNX-traceable: it iterates | |||||||||||||||||
| Python-level over tensor values (``for h, w in zip(x_h, x_w)``), uses | |||||||||||||||||
| ``torch.multinomial`` (random + non-exportable), and does per-row Python loops | |||||||||||||||||
| over ``nonzero()`` results. We replace it during export with a statically- | |||||||||||||||||
| shaped equivalent (see ``_patched_visual_embed`` + ``_ViltVisualEmbedPatcher``) | |||||||||||||||||
| that assumes an all-ones ``pixel_mask`` — which is exactly what ``ViltProcessor`` | |||||||||||||||||
| emits in production (the processor pre-pads images to 384×384). Because the | |||||||||||||||||
| patched path ignores ``pixel_mask``, we drop it from the exported ONNX graph. | |||||||||||||||||
| Verified numerically equivalent: ``cos=1.000000``, same argmax, | |||||||||||||||||
| max_abs_diff≈1.2e-5. | |||||||||||||||||
|
|
|||||||||||||||||
| This is an Effort-L1 contribution per the `adding-model-support` skill: | |||||||||||||||||
| new OnnxConfig from scratch + custom model patcher. | |||||||||||||||||
| """ | |||||||||||||||||
|
|
|||||||||||||||||
| from __future__ import annotations | |||||||||||||||||
|
|
|||||||||||||||||
| import types | |||||||||||||||||
|
|
|||||||||||||||||
| from optimum.exporters.onnx import OnnxConfig | |||||||||||||||||
| from optimum.exporters.onnx.model_patcher import ModelPatcher | |||||||||||||||||
| from optimum.utils import NormalizedTextConfig | |||||||||||||||||
| from optimum.utils.input_generators import DummyVisionInputGenerator | |||||||||||||||||
| from transformers import ViltForQuestionAnswering | |||||||||||||||||
|
|
|||||||||||||||||
| from ...export import MaxLengthTextInputGenerator, register_onnx_overwrite | |||||||||||||||||
|
|
|||||||||||||||||
|
|
|||||||||||||||||
| # ============================================================================= | |||||||||||||||||
| # Export-time patch for ``ViltEmbeddings.visual_embed`` | |||||||||||||||||
| # ============================================================================= | |||||||||||||||||
| # Upstream ``visual_embed`` is **not ONNX-traceable** as written: | |||||||||||||||||
| # * ``for h, w in zip(x_h, x_w)`` iterates Python-level over tensor values | |||||||||||||||||
| # * ``nonzero()`` + ``unique()`` + per-row Python list-comprehension subset | |||||||||||||||||
| # selection over a dynamic ``valid_idx`` | |||||||||||||||||
| # * ``torch.multinomial`` random sampling (non-deterministic, not exportable) | |||||||||||||||||
| # The eager path silently "works" only when ``pixel_mask`` is all-ones and the | |||||||||||||||||
| # batch is 1, because the Python loop runs once with a concrete value. Under | |||||||||||||||||
| # legacy ``torch.onnx.export`` tracing the shape resolves to 0 and PyTorch's | |||||||||||||||||
| # ``F.interpolate`` aborts with ``input (H: 12, W: 12) output (H: 0, W: 0)``. | |||||||||||||||||
| # | |||||||||||||||||
| # For the production ``visual-question-answering`` inference path the | |||||||||||||||||
| # ``ViltProcessor`` ALWAYS pads to 384×384 and emits an all-ones ``pixel_mask``, | |||||||||||||||||
| # so the per-sample subset selection is a no-op. We replace ``visual_embed`` | |||||||||||||||||
| # during export with a simplified, statically-shaped implementation that: | |||||||||||||||||
| # * uses ``x.shape[2], x.shape[3]`` (static) for position-embed interpolation | |||||||||||||||||
| # * skips ``multinomial`` / nonzero / Python-level batch loops | |||||||||||||||||
| # * returns an all-ones token mask of length ``H*W + 1`` (patches + CLS) | |||||||||||||||||
| # | |||||||||||||||||
| # Verified numerically equivalent on ``dandelin/vilt-b32-finetuned-vqa`` with | |||||||||||||||||
| # fixed seed: ``cos=1.000000``, same ``argmax`` class, ``max_abs_diff≈1.2e-5`` | |||||||||||||||||
| # (within fp tolerance from interpolation operation ordering). | |||||||||||||||||
|
|
|||||||||||||||||
|
|
|||||||||||||||||
| def _patched_visual_embed(self, pixel_values, pixel_mask, max_image_length=200): | |||||||||||||||||
| """Static-shape, ONNX-traceable replacement for ``ViltEmbeddings.visual_embed``.""" | |||||||||||||||||
| import torch | |||||||||||||||||
| from torch import nn | |||||||||||||||||
|
|
|||||||||||||||||
| x = self.patch_embeddings(pixel_values) | |||||||||||||||||
| batch_size, num_channels, height, width = x.shape | |||||||||||||||||
|
|
|||||||||||||||||
| patch_dim = self.config.image_size // self.config.patch_size | |||||||||||||||||
| spatial_pos = self.position_embeddings[:, 1:, :].transpose(1, 2).view( | |||||||||||||||||
| 1, num_channels, patch_dim, patch_dim | |||||||||||||||||
| ) | |||||||||||||||||
| pos_embed = nn.functional.interpolate( | |||||||||||||||||
| spatial_pos, size=(height, width), mode="bilinear", align_corners=True | |||||||||||||||||
| ) | |||||||||||||||||
| pos_embed = pos_embed.flatten(2).transpose(1, 2).expand(batch_size, -1, -1) | |||||||||||||||||
|
|
|||||||||||||||||
| x = x.flatten(2).transpose(1, 2) | |||||||||||||||||
|
|
|||||||||||||||||
| cls_tokens = self.cls_token.expand(batch_size, -1, -1) | |||||||||||||||||
| x = torch.cat((cls_tokens, x), dim=1) | |||||||||||||||||
| pos_cls = self.position_embeddings[:, 0:1, :].expand(batch_size, -1, -1) | |||||||||||||||||
| pos_embed = torch.cat((pos_cls, pos_embed), dim=1) | |||||||||||||||||
| x = x + pos_embed | |||||||||||||||||
| x = self.dropout(x) | |||||||||||||||||
|
|
|||||||||||||||||
| num_tokens = height * width + 1 # patches + CLS | |||||||||||||||||
| x_mask = torch.ones(batch_size, num_tokens, dtype=torch.long, device=x.device) | |||||||||||||||||
| return x, x_mask, None | |||||||||||||||||
|
|
|||||||||||||||||
|
|
|||||||||||||||||
| class _ViltVisualEmbedPatcher(ModelPatcher): | |||||||||||||||||
| """Swaps ``ViltEmbeddings.visual_embed`` for the duration of ONNX export.""" | |||||||||||||||||
|
|
|||||||||||||||||
| def __enter__(self): | |||||||||||||||||
| super().__enter__() | |||||||||||||||||
| emb = self._model.vilt.embeddings if hasattr(self._model, "vilt") else self._model.embeddings | |||||||||||||||||
| self._emb_ref = emb | |||||||||||||||||
| self._orig_visual_embed = emb.visual_embed | |||||||||||||||||
| emb.visual_embed = types.MethodType(_patched_visual_embed, emb) | |||||||||||||||||
| return self | |||||||||||||||||
|
|
|||||||||||||||||
| def __exit__(self, exc_type, exc_value, traceback): | |||||||||||||||||
| self._emb_ref.visual_embed = self._orig_visual_embed | |||||||||||||||||
| super().__exit__(exc_type, exc_value, traceback) | |||||||||||||||||
|
|
|||||||||||||||||
|
|
|||||||||||||||||
| # ============================================================================= | |||||||||||||||||
| # Optimum ONNX Export Config Registration | |||||||||||||||||
| # ============================================================================= | |||||||||||||||||
| @register_onnx_overwrite("vilt", "visual-question-answering", library_name="transformers") | |||||||||||||||||
| class ViltVqaOnnxConfig(OnnxConfig): | |||||||||||||||||
| """ONNX export config for ``ViltForQuestionAnswering``. | |||||||||||||||||
|
|
|||||||||||||||||
| Declares 4 multi-modal inputs (text triple + pixel_values) and the single | |||||||||||||||||
| classification output. ``pixel_mask`` is deliberately omitted — see | |||||||||||||||||
| ``inputs`` property docstring and the module-level ``Notes`` section for | |||||||||||||||||
| the full rationale. | |||||||||||||||||
|
|
|||||||||||||||||
| Inputs: | |||||||||||||||||
| - ``input_ids``: [B, 40] int64 | |||||||||||||||||
| - ``attention_mask``: [B, 40] int64 | |||||||||||||||||
| - ``token_type_ids``: [B, 40] int64 | |||||||||||||||||
| - ``pixel_values``: [B, 3, 384, 384] float32 | |||||||||||||||||
|
|
|||||||||||||||||
| Outputs: | |||||||||||||||||
| - ``logits``: [B, num_labels=3129] float32 | |||||||||||||||||
|
|
|||||||||||||||||
| Notes: | |||||||||||||||||
| - ``num_labels`` (3129 for VQAv2) is a config-time fact, not declared | |||||||||||||||||
| dynamic in the symbolic axes — it's a static dim of ``logits``. | |||||||||||||||||
| - ``sequence_length`` resolves to ``max_position_embeddings`` (40 for | |||||||||||||||||
| ViLT-B/32) via ``NORMALIZED_CONFIG_CLASS``; the | |||||||||||||||||
| ``MaxLengthTextInputGenerator`` reads this for dummy tokens. | |||||||||||||||||
| - Chained ``DummyVisionInputGenerator`` + ``MaxLengthTextInputGenerator`` | |||||||||||||||||
| produce ``pixel_values`` + ``input_ids``/``attention_mask``/ | |||||||||||||||||
| ``token_type_ids``. The patched ``visual_embed`` (see module-level | |||||||||||||||||
| ``_ViltVisualEmbedPatcher``) synthesizes an all-ones token mask | |||||||||||||||||
| internally, so no ``pixel_mask`` input is required. | |||||||||||||||||
| """ | |||||||||||||||||
|
|
|||||||||||||||||
| NORMALIZED_CONFIG_CLASS = NormalizedTextConfig.with_args( | |||||||||||||||||
| sequence_length="max_position_embeddings", | |||||||||||||||||
| num_channels="num_channels", | |||||||||||||||||
| image_size="image_size", | |||||||||||||||||
| patch_size="patch_size", | |||||||||||||||||
| allow_new=True, | |||||||||||||||||
| ) | |||||||||||||||||
|
|
|||||||||||||||||
| DUMMY_INPUT_GENERATOR_CLASSES = ( | |||||||||||||||||
| DummyVisionInputGenerator, | |||||||||||||||||
| MaxLengthTextInputGenerator, | |||||||||||||||||
| ) | |||||||||||||||||
|
|
|||||||||||||||||
| DEFAULT_ONNX_OPSET = 17 | |||||||||||||||||
|
|
|||||||||||||||||
| @property | |||||||||||||||||
| def inputs(self) -> dict[str, dict[int, str]]: | |||||||||||||||||
| """Declare 4 model inputs (insertion order matches forward). | |||||||||||||||||
|
|
|||||||||||||||||
| ``pixel_values`` H,W is kept STATIC — ViLT interpolates position | |||||||||||||||||
| embeddings from the actual H,W, and exposing those as dynamic symbols | |||||||||||||||||
| trips the ONNX ``Resize`` shape-inference (``input (H:12 W:12) output | |||||||||||||||||
| (H:0 W:0)``). Pinning H,W matches all known production usage (always | |||||||||||||||||
| 384×384 input via ``ViltProcessor``). | |||||||||||||||||
|
|
|||||||||||||||||
| Note: ViLT's ``forward`` also takes a ``pixel_mask`` parameter, but | |||||||||||||||||
| this contribution exports without it. The ``ViltProcessor`` always | |||||||||||||||||
| emits an all-ones mask (the image is padded to 384×384 before the | |||||||||||||||||
| model sees it), and our export-time ``ModelPatcher`` replaces the | |||||||||||||||||
| original ``visual_embed`` with a statically-shaped version that | |||||||||||||||||
| synthesizes an all-ones token mask internally. Including ``pixel_mask`` | |||||||||||||||||
| as an ONNX input would dead-code-eliminate (since the patched path | |||||||||||||||||
| doesn't reference it) and confuse runtime callers. | |||||||||||||||||
| """ | |||||||||||||||||
| return { | |||||||||||||||||
| "input_ids": {0: "batch_size", 1: "sequence_length"}, | |||||||||||||||||
| "attention_mask": {0: "batch_size", 1: "sequence_length"}, | |||||||||||||||||
| "token_type_ids": {0: "batch_size", 1: "sequence_length"}, | |||||||||||||||||
| "pixel_values": {0: "batch_size"}, | |||||||||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Static I checked
So the all-ones Also: the cited reason for pinning — "dynamic symbols trip Resize shape-inference (H:12 W:12 → H:0 W:0)" (L181–183) — describes the original Suggest either (a) make H/W dynamic — the patch already interpolates to actual |
|||||||||||||||||
| } | |||||||||||||||||
|
|
|||||||||||||||||
| @property | |||||||||||||||||
| def outputs(self) -> dict[str, dict[int, str]]: | |||||||||||||||||
| """Single classification output over fixed answer vocabulary.""" | |||||||||||||||||
| return { | |||||||||||||||||
| "logits": {0: "batch_size"}, | |||||||||||||||||
| } | |||||||||||||||||
|
|
|||||||||||||||||
| def generate_dummy_inputs(self, framework: str = "pt", **kwargs): # type: ignore[override] | |||||||||||||||||
| """Generate the 4 declared inputs via the chained vendor generators. | |||||||||||||||||
|
|
|||||||||||||||||
| ``pixel_mask`` is intentionally NOT generated — see ``inputs`` docstring. | |||||||||||||||||
| Our model patcher's replacement ``visual_embed`` synthesizes an | |||||||||||||||||
| all-ones token mask internally, so the model can be called with the | |||||||||||||||||
| 4 declared inputs. | |||||||||||||||||
| """ | |||||||||||||||||
| dummy = super().generate_dummy_inputs(framework=framework, **kwargs) | |||||||||||||||||
| # Drop any pixel_mask the generators may have produced — the patched | |||||||||||||||||
| # visual_embed ignores it (and including it would error at sess.run | |||||||||||||||||
| # since it isn't in the exported ONNX graph). | |||||||||||||||||
| dummy.pop("pixel_mask", None) | |||||||||||||||||
| return dummy | |||||||||||||||||
|
|
|||||||||||||||||
| def patch_model_for_export(self, model, model_kwargs=None): # type: ignore[override] | |||||||||||||||||
| """Install the ``visual_embed`` patcher for the export context.""" | |||||||||||||||||
| return _ViltVisualEmbedPatcher(self, model, model_kwargs=model_kwargs) | |||||||||||||||||
|
|
|||||||||||||||||
|
|
|||||||||||||||||
| # ============================================================================= | |||||||||||||||||
| # HuggingFace Model Class Mapping | |||||||||||||||||
| # ============================================================================= | |||||||||||||||||
| # ``visual-question-answering`` has no default AutoModel routing for ViLT; | |||||||||||||||||
| # bind the (model_type, task) tuple directly to the head-bearing HF class. | |||||||||||||||||
| MODEL_CLASS_MAPPING: dict[tuple[str, str], type] = { | |||||||||||||||||
| ("vilt", "visual-question-answering"): ViltForQuestionAnswering, | |||||||||||||||||
| } | |||||||||||||||||
|
|
|||||||||||||||||
|
|
|||||||||||||||||
| __all__ = [ | |||||||||||||||||
| "ViltVqaOnnxConfig", | |||||||||||||||||
| "MODEL_CLASS_MAPPING", | |||||||||||||||||
| ] | |||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Recipe contradicts the PR's own reviewer note on mask
value_range. The PR description says "Recipevalue_rangefor mask-of-ones inputs must be[1, 2]not[0, 1]becauserandinthighis exclusive." But hereattention_mask(andtoken_type_idsbelow) use[0, 2]. It's functionally harmless for export — tracing doesn't depend on input values and the PT-vs-ORT check feeds identical inputs to both — but the note and the file disagree and will mislead anyone re-deriving this recipe. Please reconcile (fix the note or the values).