
Add TinyViT and MobileViT#3

Merged
runwangdl merged 5 commits into pulp-platform:devel from runwangdl:VITs
Feb 14, 2026
Conversation

Collaborator

@runwangdl runwangdl commented Feb 11, 2026

Added

  • MobileViT model support (XXS, XS, S variants) - hybrid CNN-Transformer for mobile/edge
  • TinyViT model support - efficient hierarchical vision transformer
  • Prompt for adding a model

Changed

  • Optimized MobileViT ONNX export: replaced dynamic view(-1,...) with static reshape(batch_size,...)
  • Fixed dimension propagation in transformer blocks to eliminate Shape/Gather nodes
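The view(-1, ...) to reshape(batch_size, ...) change can be illustrated with a minimal sketch (PatchFlatten and its dimensions are hypothetical, not the PR's actual module): when every target dimension is a Python int known at trace time, the ONNX exporter folds the operation into a single static Reshape node instead of emitting a Shape/Gather/Concat chain to recover the batch dimension at runtime.

```python
import torch

class PatchFlatten(torch.nn.Module):
    """Hypothetical module illustrating the export-friendly pattern.

    x.view(-1, d) forces the exporter to emit Shape/Gather/Concat nodes
    to recover the batch dimension at runtime; reshape() with constant
    ints folds into a single static Reshape node in the ONNX graph.
    """

    def __init__(self, batch_size: int, num_patches: int, dim: int):
        super().__init__()
        self.batch_size = batch_size
        self.num_patches = num_patches
        self.dim = dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # All target dims are Python ints known at export time.
        return x.reshape(self.batch_size, self.num_patches, self.dim)

m = PatchFlatten(batch_size=1, num_patches=64, dim=96)
out = m(torch.randn(1, 64 * 96))
```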

Contributor

Copilot AI left a comment

Pull request overview

Adds TinyViT and MobileViT model support to the Onnx4Deeploy exporters, focusing on cleaner ONNX graphs by removing dynamic-shape operations and propagating fixed dimensions through transformer components.

Changes:

  • Added TinyViT exporter + CLI registration and updated project exports (__init__.py).
  • Refactored TinyViT and MobileViT PyTorch implementations to use fixed-dimension reshapes and ONNX-friendly attention blocks.
  • Updated changelog and Mamba export documentation references.

Reviewed changes

Copilot reviewed 8 out of 8 changed files in this pull request and generated 10 comments.

File summary:

  • onnx4deeploy/models/tinyvit_exporter.py: Adjusts TinyViT defaults and passes fixed batch_size into model creation.
  • onnx4deeploy/models/pytorch_models/tinyvit/tinyvit.py: Introduces fixed-dimension TinyViT implementation details (CLS-token handling, fixed attention reshapes).
  • onnx4deeploy/models/pytorch_models/mobilevit/mobilevit.py: Reworks MobileViT blocks/attention to avoid dynamic shapes and uses fixed patch dimensions for reshapes.
  • onnx4deeploy/models/mobilevit_exporter.py: Passes fixed batch_size/img_size to MobileViT variant constructors.
  • onnx4deeploy/models/__init__.py: Exposes TinyViTExporter from the models package.
  • docs/MAMBA_CLEAN_EXPORT.md: Updates Mamba custom-op documentation to point at the current implementation file.
  • Onnx4Deeploy.py: Registers TinyViT variants in the CLI model list and usage examples.
  • CHANGELOG.md: Documents the newly added models and ONNX export optimizations.


Comment on lines +268 to +269
self.cls_selector = nn.Parameter(torch.zeros(1, num_patches + 1), requires_grad=False)
self.cls_selector.data[0, 0] = 1.0 # Select only the first token (CLS)

Copilot AI Feb 11, 2026


cls_selector is stored as an nn.Parameter and initialized via .data[...] = 1.0. Since it’s a fixed constant used only for selection, it should be a registered buffer (so it’s not treated like a weight), and avoid .data assignment (use a with torch.no_grad(): ... assignment or construct the tensor with the 1.0 already set). This reduces state_dict noise and avoids unsafe .data usage.

Suggested change
- self.cls_selector = nn.Parameter(torch.zeros(1, num_patches + 1), requires_grad=False)
- self.cls_selector.data[0, 0] = 1.0  # Select only the first token (CLS)
+ cls_selector = torch.zeros(1, num_patches + 1)
+ cls_selector[0, 0] = 1.0  # Select only the first token (CLS)
+ self.register_buffer("cls_selector", cls_selector)
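A minimal runnable sketch of the buffer approach (ClsSelect and num_patches=4 are placeholder names/values, not the PR's code): a buffer travels with the state_dict and device moves but is excluded from parameters(), so optimizers and weight decay never touch it.

```python
import torch
import torch.nn as nn

class ClsSelect(nn.Module):
    """Placeholder module sketching the register_buffer suggestion."""

    def __init__(self, num_patches: int = 4):
        super().__init__()
        sel = torch.zeros(1, num_patches + 1)
        sel[0, 0] = 1.0  # select the CLS token; no .data access needed
        self.register_buffer("cls_selector", sel)

m = ClsSelect()
```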

Comment on lines 365 to 385
@@ -265,19 +378,21 @@ def tiny_vit_11m(num_classes: int = 1000, img_size: int = 224) -> TinyViT:
img_size=img_size,
patch_size=16,
embed_dim=256,
-        depth=12,
+        depth=1,  # Reduced from 12 for faster testing
num_heads=4,
num_classes=num_classes,
batch_size=batch_size,
)

Copilot AI Feb 11, 2026


tiny_vit_11m() hardcodes depth=1 but is labeled as the ~11M-parameter variant. This no longer corresponds to the intended model configuration. Restore the correct depth for the 11M variant (or make depth configurable and adjust naming/docs accordingly).

Comment on lines +482 to +492
# Compute spatial dimensions at each stage (for MobileViT blocks)
# After stem (stride=2): H/2, W/2 = 128x128
# After mv2_2 (stride=2): H/4, W/4 = 64x64
# After mv2_4 (stride=2): H/8, W/8 = 32x32 <- MobileViT block 1
# After mv2_6 (stride=2): H/16, W/16 = 16x16 <- MobileViT block 2
# After mv2_8 (stride=2): H/32, W/32 = 8x8 <- MobileViT block 3
self.mvit_patch_dims = [
(self.image_h // 8, self.image_w // 8), # MobileViT block 1: 32x32
(self.image_h // 16, self.image_w // 16), # MobileViT block 2: 16x16
(self.image_h // 32, self.image_w // 32), # MobileViT block 3: 8x8
]

Copilot AI Feb 11, 2026


mvit_patch_dims uses integer division (H//8, H//16, H//32) to derive the fixed spatial sizes used later for reshapes. If image_size is not divisible by 32 (or differs from the assumed downsampling pattern), this will silently compute incorrect patch sizes and cause reshape errors or incorrect exports. Consider validating image_size at init time (e.g., assert H % 32 == 0 and W % 32 == 0, and optionally that H,W are positive) since these are static deployment constraints anyway.
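The suggested init-time validation can be sketched as a small stdlib-only helper (the function name is illustrative, not the PR's code):

```python
def mobilevit_patch_dims(image_h: int, image_w: int) -> list:
    """Derive the fixed spatial sizes used for static reshapes.

    The three MobileViT blocks sit after total strides of 8, 16, and 32,
    so both sides must be positive and divisible by 32; otherwise the
    integer divisions below silently truncate and later reshapes fail.
    """
    if image_h <= 0 or image_w <= 0:
        raise ValueError("image size must be positive")
    if image_h % 32 != 0 or image_w % 32 != 0:
        raise ValueError(
            f"image size {image_h}x{image_w} must be divisible by 32"
        )
    return [(image_h // s, image_w // s) for s in (8, 16, 32)]
```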

Comment on lines 576 to +584
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Forward pass."""
"""
Forward pass with detailed dimension tracking.

Args:
x: Input tensor [B, 3, 256, 256]

Returns:
Output logits [B, num_classes]

Copilot AI Feb 11, 2026


The forward docstring hardcodes the input as [B, 3, 256, 256], but the constructor accepts image_size and the exporter passes a configurable img_size. To avoid documentation drift, update the docstring to reflect a generic fixed input size (e.g., [B, 3, H, W] with H/W from image_size).

Comment on lines 457 to 464
def __init__(
self,
batch_size: int = 1,
image_size: Tuple[int, int] = (256, 256),
num_classes: int = 1000,
-        dims: list = [64, 80, 96],
+        dims: list = [96, 120, 144],
channels: list = [16, 32, 48, 48, 64, 64, 80, 80, 96, 96, 384],
):

Copilot AI Feb 11, 2026


dims and channels use mutable list defaults in the MobileViT constructor. Even if they aren’t mutated today, this is a common Python footgun because the same list instance is shared across calls. Prefer using tuples, or defaulting to None and creating the lists inside __init__.
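The footgun is easy to demonstrate with a stdlib-only sketch (the function names are illustrative): the default list is evaluated once at definition time and shared across calls, whereas the None pattern builds a fresh list per call.

```python
def shared_default(dims: list = [64, 80, 96]):
    # The default list object is created once, when the function is
    # defined, and every call that omits the argument gets that object.
    return dims

def fresh_default(dims=None):
    if dims is None:
        dims = [96, 120, 144]  # a new list on every call
    return dims

# Same object on repeated calls vs. a fresh one each time:
shared = shared_default() is shared_default()   # True
fresh = fresh_default() is fresh_default()      # False
```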

Comment on lines +105 to +125
"TinyViT-5M": {
"class": TinyViTExporter,
"description": "TinyViT-5M (Compact Vision Transformer, ~5M params)",
"input_shape": "(B, 3, 64, 64)",
"classes": 10,
"config": {"variant": "tiny_vit_5m", "img_size": 64, "num_classes": 10},
},
"TinyViT-11M": {
"class": TinyViTExporter,
"description": "TinyViT-11M (Compact Vision Transformer, ~11M params)",
"input_shape": "(B, 3, 64, 64)",
"classes": 10,
"config": {"variant": "tiny_vit_11m", "img_size": 64, "num_classes": 10},
},
"TinyViT-21M": {
"class": TinyViTExporter,
"description": "TinyViT-21M (Compact Vision Transformer, ~21M params)",
"input_shape": "(B, 3, 64, 64)",
"classes": 10,
"config": {"variant": "tiny_vit_21m", "img_size": 64, "num_classes": 10},
},

Copilot AI Feb 11, 2026


The TinyViT-* entries rely on the per-model "config" overrides (variant/img_size/num_classes), but the current generate_model() flow updates exporter.config and then calls exporter.export_inference(), which reloads config via load_config() and discards those overrides. As a result, selecting "TinyViT-11M" / "TinyViT-21M" (and similarly MobileViT-XXS/-S) will still export the exporter defaults. Consider changing the exporter API to honor a pre-set exporter.config (e.g., skip load_config() when self.config is already set) or add a supported way to pass overrides into export_* so these model presets actually take effect.
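One way to honor the presets, sketched with stdlib-only placeholder code (Exporter, load_config, and the config keys are illustrative, not the project's actual API): let load_config() fill only missing keys via setdefault(), so values pre-set by the CLI model entry survive export.

```python
class Exporter:
    """Placeholder exporter sketching the override-preserving pattern."""

    def __init__(self):
        self.config = {}

    def load_config(self):
        # Defaults only fill gaps; keys set by a CLI preset are kept.
        defaults = {"variant": "tiny_vit_5m", "img_size": 224, "num_classes": 1000}
        for key, value in defaults.items():
            self.config.setdefault(key, value)

    def export_inference(self):
        self.load_config()
        return self.config

preset = Exporter()
preset.config.update({"variant": "tiny_vit_11m", "img_size": 64, "num_classes": 10})
resolved = preset.export_inference()
```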

self,
dim: int,
num_heads: int = 8,
qkv_bias: bool = False,

Copilot AI Feb 11, 2026


Attention.__init__ now defaults qkv_bias to False, but Block does not pass this flag, so this silently changes the model’s behavior compared to the previous default (and typical TinyViT/ViT implementations). If the bias removal is only for export cleanliness, consider keeping the default True and making the ONNX-friendly choice explicit via a constructor arg passed from the exporter/config.

Suggested change
-        qkv_bias: bool = False,
+        qkv_bias: bool = True,

Comment on lines 342 to 362
@@ -244,19 +355,21 @@ def tiny_vit_5m(num_classes: int = 1000, img_size: int = 224) -> TinyViT:
img_size=img_size,
patch_size=16,
embed_dim=192,
-        depth=12,
+        depth=1,  # Reduced from 12 for faster testing
num_heads=3,
num_classes=num_classes,
batch_size=batch_size,
)

Copilot AI Feb 11, 2026


tiny_vit_5m() hardcodes depth=1 while the function name/docstring claims this is the ~5M-parameter TinyViT variant. With depth reduced from 12 to 1, the architecture/parameter count no longer matches the advertised model and will likely break expectations for accuracy and benchmarking. Suggest restoring the canonical depth for the variant (or exposing depth as a configurable argument and renaming this helper to avoid implying it matches the published 5M model).

Comment on lines 388 to 408
@@ -286,7 +401,8 @@ def tiny_vit_21m(num_classes: int = 1000, img_size: int = 224) -> TinyViT:
img_size=img_size,
patch_size=16,
embed_dim=384,
-        depth=12,
+        depth=1,  # Reduced from 12 for faster testing
num_heads=6,
num_classes=num_classes,
batch_size=batch_size,
)

Copilot AI Feb 11, 2026


tiny_vit_21m() hardcodes depth=1 but is labeled as the ~21M-parameter variant. This helper no longer produces the advertised architecture/parameter count. Restore the intended depth for this variant (or make the simplification explicit via naming/config).

Comment on lines 14 to +30
@@ -26,4 +27,5 @@
"MobileViTExporter",
"MambaExporter",
"SleepConViTExporter",
"TinyViTExporter",

Copilot AI Feb 11, 2026


There is existing model-export test coverage under tests/models/ (e.g., test_cct.py) but no analogous tests added for the new/updated MobileViTExporter and the new TinyViTExporter. Adding basic inference export + ONNX Runtime execution tests would help catch shape/variant regressions (especially since these models rely on fixed reshape dimensions).

@runwangdl runwangdl merged commit 0a3f65e into pulp-platform:devel Feb 14, 2026
10 checks passed
@runwangdl runwangdl deleted the VITs branch March 4, 2026 10:42