From c832fadbb084b06d529977a5b804dea483ffeea8 Mon Sep 17 00:00:00 2001 From: d-savchenkov Date: Wed, 4 Mar 2026 18:35:27 +0300 Subject: [PATCH] [quantization] Introduce wrapper for Qwen3VLVisionModel This change introduces QuantQwen3VLVisionModel wrapper to support post-training quantization of Qwen3VLVisionModel operation. TICO-DCO-1.0-Signed-off-by: d.savchenkov --- .../qwen_vl/test_quant_vision_model.py | 372 +++++++++++++++ .../quantization/utils/transformers_compat.py | 55 +++ .../examples/qwen/quantize_vision_model.py | 154 +++++++ .../wrappers/qwen_vl/quant_vision_model.py | 431 ++++++++++++++++++ tico/quantization/wrapq/wrappers/registry.py | 1 + 5 files changed, 1013 insertions(+) create mode 100644 test/quantization/wrapq/wrappers/qwen_vl/test_quant_vision_model.py create mode 100644 tico/quantization/utils/transformers_compat.py create mode 100644 tico/quantization/wrapq/examples/qwen/quantize_vision_model.py create mode 100644 tico/quantization/wrapq/wrappers/qwen_vl/quant_vision_model.py diff --git a/test/quantization/wrapq/wrappers/qwen_vl/test_quant_vision_model.py b/test/quantization/wrapq/wrappers/qwen_vl/test_quant_vision_model.py new file mode 100644 index 00000000..15d9cb68 --- /dev/null +++ b/test/quantization/wrapq/wrappers/qwen_vl/test_quant_vision_model.py @@ -0,0 +1,372 @@ +# Copyright (c) 2026 Samsung Electronics Co., Ltd. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+import math
+import unittest
+from typing import Tuple
+
+import torch
+
+from tico.quantization.config.ptq import PTQConfig
+from tico.quantization.wrapq.mode import Mode
+from tico.quantization.wrapq.utils.version import has_transformers_for
+from tico.quantization.wrapq.wrappers.qwen_vl.quant_vision_model import (
+    QuantQwen3VLVisionModel,
+)
+
+
+skip_msg = "transformers not installed — skipping Qwen3VLVisionModel tests"
+
+
+@unittest.skipUnless(has_transformers_for("qwen3-vl"), skip_msg)
+class TestQuantQwen3VLVisionModel(unittest.TestCase):
+    fp_model: torch.nn.Module
+    hidden_size: int
+    num_heads: int
+    head_dim: int
+    theta: float
+
+    @classmethod
+    def setUpClass(cls):
+        from transformers.models.qwen3_vl.configuration_qwen3_vl import (
+            Qwen3VLVisionConfig,
+        )
+        from transformers.models.qwen3_vl.modeling_qwen3_vl import Qwen3VLVisionModel
+
+        # Use smaller sizes for testing
+        cfg = Qwen3VLVisionConfig(
+            hidden_size=64,
+            num_heads=4,
+            depth=2,  # Smaller depth for faster testing
+            temporal_patch_size=2,
+            patch_size=16,
+        )
+
+        # Ensure eager attention implementation so outputs are deterministic
+        # and do not require GPU flash attention kernels.
+        # Some versions use `_attn_implementation`, others expose `attn_implementation`.
+        # A plain assignment covers both the "attribute missing" and the
+        # "attribute present" case, so there is no need to branch on hasattr():
+        # the two arms of such a branch would perform the same write.
+        cfg._attn_implementation = "eager"
+
+        cls.fp_model = Qwen3VLVisionModel(cfg)
+        cls.hidden_size = cfg.hidden_size
+        cls.num_heads = cfg.num_heads
+        cls.head_dim = cls.hidden_size // cls.num_heads
+        cls.theta = (
+            cls.fp_model.rotary_pos_emb.theta
+            if hasattr(cls.fp_model.rotary_pos_emb, "theta")
+            else 10000.0
+        )
+
+    def _create_test_inputs(
+        self, grid_thw: Tuple[int, int, int] = (1, 8, 8)
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Build a random (hidden_states, grid_tensor) pair for one (t, h, w) grid."""
+        t, h, w = grid_thw
+        num_patches = t * h * w
+        # Input shape: (seq_len, in_channels * temporal_patch_size * patch_size * patch_size)
+        hidden_states = torch.randn(
+            num_patches, 3 * 2 * 16 * 16
+        )  # 3 channels, 2 temporal, 16x16 patches (matches cfg in setUpClass)
+        grid_tensor = torch.tensor([grid_thw])
+        return hidden_states, grid_tensor
+
+    def test_get_vision_grid_thw_from_config(self):
+        """Test _get_vision_grid_thw static method with valid config."""
+        # Test with valid config
+        ptq_config = PTQConfig()
+        setattr(ptq_config, "vision_grid_thw", [1, 8, 8])
+
+        grid_thw = QuantQwen3VLVisionModel._get_vision_grid_thw(ptq_config)
+        expected = torch.tensor([[1, 8, 8]])
+        self.assertTrue(torch.equal(grid_thw, expected))
+        self.assertEqual(grid_thw.shape, (1, 3))
+
+    def test_get_vision_grid_thw_missing_config(self):
+        """Test _get_vision_grid_thw raises error when config is missing."""
+        # Test with None config
+        with self.assertRaises(ValueError) as context:
+            QuantQwen3VLVisionModel._get_vision_grid_thw(None)
+        self.assertIn("vision_grid_thw must be specified", str(context.exception))
+
+        # Test with config without vision_grid_thw
+        ptq_config = PTQConfig()
+        with self.assertRaises(ValueError) as context:
+            QuantQwen3VLVisionModel._get_vision_grid_thw(ptq_config)
+        self.assertIn("vision_grid_thw must be specified", str(context.exception))
+
+    def test_precompute_rope_inv_freq(self):
+        """Test _precompute_rope_inv_freq static method."""
+        dim = 32
+        theta = 10000.0
+        inv_freq = QuantQwen3VLVisionModel._precompute_rope_inv_freq(dim, theta)
+
+        self.assertEqual(inv_freq.shape, (dim // 2,))
+        self.assertTrue(torch.all(inv_freq > 0))
+        # Check that frequencies are non-increasing
+        self.assertTrue(torch.all(inv_freq[:-1] >= inv_freq[1:]))
+
+    def test_precompute_cu_seqlens(self):
+        """Test _precompute_cu_seqlens static method."""
+        grid_thw = torch.tensor(
+            [[1, 8, 8], [2, 4, 4]]
+        )  # 2 entries, 3 frames: 1*8*8 + 2*4*4 = 96 total patches
+        cu_seqlens = QuantQwen3VLVisionModel._precompute_cu_seqlens(grid_thw)
+
+        self.assertEqual(cu_seqlens.shape, (4,))  # 3 frames (1 + 2) + leading 0
+        self.assertEqual(cu_seqlens[0].item(), 0)
+        self.assertEqual(cu_seqlens[1].item(), 64)  # entry 0, frame 0: 8*8 = 64
+        self.assertEqual(cu_seqlens[2].item(), 80)  # entry 1, frame 0: 64 + 4*4
+        self.assertEqual(
+            cu_seqlens[3].item(), 96
+        )  # entry 1, frame 1: 80 + 4*4 = 96 (total)
+
+    def test_precompute_rope_position_embeddings(self):
+        """Test _precompute_rope_position_embeddings static method."""
+        grid_thw = torch.tensor([[1, 8, 8]])
+        inv_freq = QuantQwen3VLVisionModel._precompute_rope_inv_freq(
+            dim=self.head_dim // 2,
+            theta=self.theta,
+        )
+
+        cos_t, sin_t = QuantQwen3VLVisionModel._precompute_rope_position_embeddings(
+            merge_size=2,
+            rope_inv_freq=inv_freq,
+            grid_thw=grid_thw,
+        )
+
+        expected_patches = math.prod(grid_thw[0].tolist())  # t * h * w = 1 * 8 * 8 = 64
+        self.assertEqual(cos_t.shape, (expected_patches, self.head_dim))
+        self.assertEqual(sin_t.shape, (expected_patches, self.head_dim))
+
+    def test_rot_pos_emb(self):
+        """Test _rot_pos_emb static method."""
+        grid_thw = torch.tensor([[1, 8, 8]])
+        inv_freq = QuantQwen3VLVisionModel._precompute_rope_inv_freq(
+            dim=self.head_dim // 2,
+            theta=self.theta,
+        )
+
+        rotary_pos_emb = QuantQwen3VLVisionModel._rot_pos_emb(2, inv_freq, grid_thw)
+
+        expected_patches = math.prod(grid_thw[0].tolist())  # t * h * w = 1 * 8 * 8 = 64
+        self.assertEqual(rotary_pos_emb.shape, (expected_patches, self.head_dim // 2))
+
+    def test_create_freq_table(self):
+        """Test _create_freq_table static method."""
+        seqlen = 64
+        inv_freq = torch.randn(16)  # dim//2 = 32//2 = 16
+        freq_table = QuantQwen3VLVisionModel._create_freq_table(seqlen, inv_freq)
+
+        self.assertEqual(freq_table.shape, (seqlen, inv_freq.shape[0]))
+
+    def test_fast_pos_embed_interpolate(self):
+        """Test _fast_pos_embed_interpolate static method."""
+        grid_thw = torch.tensor([[1, 8, 8]])
+        pos_embeds = QuantQwen3VLVisionModel._fast_pos_embed_interpolate(
+            merge_size=2,
+            num_grid_per_side=int(self.fp_model.config.num_position_embeddings**0.5),
+            pos_embedder=self.fp_model.pos_embed,
+            grid_thw=grid_thw,
+        )
+
+        expected_patches = math.prod(grid_thw[0].tolist())  # t * h * w = 1 * 8 * 8 = 64
+        self.assertEqual(pos_embeds.shape, (expected_patches, self.hidden_size))
+
+    def test_init_with_valid_config(self):
+        """Test successful initialization with valid config."""
+        ptq_config = PTQConfig()
+        setattr(ptq_config, "vision_grid_thw", [1, 8, 8])
+
+        q_model = QuantQwen3VLVisionModel(
+            self.fp_model, qcfg=ptq_config, fp_name="test_model"
+        )
+
+        # Check that buffers are registered
+        self.assertTrue(hasattr(q_model, "cu_seqlens_template"))
+        self.assertTrue(hasattr(q_model, "pos_embed_template"))
+        self.assertTrue(hasattr(q_model, "rope_inv_freq"))
+        self.assertTrue(hasattr(q_model, "rope_cos_template"))
+        self.assertTrue(hasattr(q_model, "rope_sin_template"))
+
+        # Check submodule wrapping
+        self.assertIsNotNone(q_model.patch_embed)
+        self.assertEqual(len(q_model.blocks), len(self.fp_model.blocks))
+        self.assertIsNotNone(q_model.merger)
+        self.assertEqual(
+            len(q_model.deepstack_merger_list), len(self.fp_model.deepstack_merger_list)
+        )
+
+    def test_init_missing_vision_grid_thw(self):
+        """Test initialization fails without vision_grid_thw."""
+        ptq_config = PTQConfig()
+
+        with self.assertRaises(ValueError) as context:
+            QuantQwen3VLVisionModel(
+                self.fp_model, qcfg=ptq_config, fp_name="test_model"
+            )
+        self.assertIn("vision_grid_thw must be specified", str(context.exception))
+
+    def test_mode_transitions(self):
+        """Test quantization mode transitions: NO_QUANT → CALIB → QUANT"""
+        ptq_config = PTQConfig()
+        setattr(ptq_config, "vision_grid_thw", [1, 8, 8])
+        q_model = QuantQwen3VLVisionModel(
+            self.fp_model, qcfg=ptq_config, fp_name="test_model"
+        )
+        self.assertIs(q_model._mode, Mode.NO_QUANT)
+
+        q_model.enable_calibration()
+        self.assertIs(q_model._mode, Mode.CALIB)
+
+        # Run forward pass during calibration
+        hidden_states, grid_thw = self._create_test_inputs((1, 8, 8))
+        _ = q_model(hidden_states, grid_thw)
+
+        q_model.freeze_qparams()
+        self.assertIs(q_model._mode, Mode.QUANT)
+
+    def test_forward_grid_mismatch_during_calibration(self):
+        """Test forward pass fails with mismatched grid_thw during calibration."""
+        ptq_config = PTQConfig()
+        setattr(ptq_config, "vision_grid_thw", [1, 8, 8])
+        q_model = QuantQwen3VLVisionModel(
+            self.fp_model, qcfg=ptq_config, fp_name="test_model"
+        )
+        q_model.enable_calibration()
+
+        # Try with different grid
+        hidden_states, grid_thw = self._create_test_inputs((1, 4, 4))
+
+        with self.assertRaises(AssertionError) as context:
+            _ = q_model(hidden_states, grid_thw)
+        self.assertIn("grid_thw", str(context.exception))
+
+    def test_observer_count(self):
+        """Test that the wrapper has the correct number of observers."""
+        ptq_config = PTQConfig()
+        setattr(ptq_config, "vision_grid_thw", [1, 8, 8])
+        q_model = QuantQwen3VLVisionModel(
+            self.fp_model, qcfg=ptq_config, fp_name="test_model"
+        )
+
+        observers = list(q_model._all_observers())
+        # Should have 4 local observers: pos_embeds, pos_add, rope_cos, rope_sin
+        self.assertEqual(len(observers), 4)
+
+    def test_precomputed_embeddings_shape(self):
+        """Test that precomputed embeddings have correct shapes."""
+        ptq_config = PTQConfig()
+        setattr(ptq_config, "vision_grid_thw", [1, 8, 8])
+        q_model = QuantQwen3VLVisionModel(
+            self.fp_model, qcfg=ptq_config, fp_name="test_model"
+        )
+
+        expected_patches = math.prod(
+            getattr(ptq_config, "vision_grid_thw")
+        )  # t * h * w = 1 * 8 * 8 = 64
+
+        # Check position embeddings
+        self.assertEqual(
+            q_model.pos_embed_template.shape, (expected_patches, self.hidden_size)
+        )
+
+        # Check RoPE embeddings
+        self.assertEqual(
+            q_model.rope_cos_template.shape,
+            (expected_patches, self.head_dim),
+        )
+        self.assertEqual(
+            q_model.rope_sin_template.shape,
+            (expected_patches, self.head_dim),
+        )
+
+        # Check cumulative sequence lengths
+        self.assertEqual(q_model.cu_seqlens_template.shape, (2,))  # 1 frame + leading 0
+
+    def test_registration_in_registry(self):
+        """Test that Qwen3VLVisionModel is properly registered."""
+        from tico.quantization.wrapq.wrappers.registry import lookup
+        from transformers.models.qwen3_vl.modeling_qwen3_vl import Qwen3VLVisionModel
+
+        wrapper_cls = lookup(Qwen3VLVisionModel)
+        self.assertIs(wrapper_cls, QuantQwen3VLVisionModel)
+
+    def test_output_structure(self):
+        """Test that output has correct structure."""
+        ptq_config = PTQConfig()
+        setattr(ptq_config, "vision_grid_thw", [1, 8, 8])
+        q_model = QuantQwen3VLVisionModel(
+            self.fp_model, qcfg=ptq_config, fp_name="test_model"
+        )
+        q_model.enable_calibration()
+
+        hidden_states, grid_thw = self._create_test_inputs((1, 8, 8))
+        _ = q_model(hidden_states, grid_thw)
+
+        q_model.freeze_qparams()
+
+        with torch.no_grad():
+            q_out = q_model(hidden_states, grid_thw)
+
+        # Check shapes
+        expected_patches = math.prod(
+            getattr(ptq_config, "vision_grid_thw")
+        )  # t * h * w = 1 * 8 * 8
+
+        # The structure of q_out depends on transformers version
+        merged_hidden_states = (
+            q_out.pooler_output if q_model.has_deepstack_model_output else q_out[0]
+        )
+
+        self.assertEqual(merged_hidden_states.shape[0], expected_patches // 4)
+
+    def test_different_grid_sizes(self):
+        """Test with different grid sizes."""
+        test_cases = [
+            ((1, 4, 4), "small_image"),
+            ((1, 6, 6), "medium_image"),
+            ((1, 8, 8), "large_image"),
+        ]
+
+        grid_thw_list: Tuple[int, int, int]
+        description: str
+        for grid_thw_list, description in test_cases:
+            with self.subTest(description=description):
+                ptq_config = PTQConfig()
+                setattr(ptq_config, "vision_grid_thw", grid_thw_list)
+                q_model = QuantQwen3VLVisionModel(
+                    self.fp_model, qcfg=ptq_config, fp_name=f"test_model_{description}"
+                )
+
+                hidden_states, grid_thw = self._create_test_inputs(grid_thw_list)
+
+                q_model.enable_calibration()
+                _ = q_model(hidden_states, grid_thw)
+                q_model.freeze_qparams()
+
+                with torch.no_grad():
+                    q_out = q_model(hidden_states, grid_thw)
+
+                # The structure of q_out depends on transformers version
+                merged_hidden_states = (
+                    q_out.pooler_output
+                    if q_model.has_deepstack_model_output
+                    else q_out[0]
+                )
+
+                expected_patches = math.prod(grid_thw_list)  # t * h * w
+                self.assertEqual(merged_hidden_states.shape[0], expected_patches // 4)
diff --git a/tico/quantization/utils/transformers_compat.py b/tico/quantization/utils/transformers_compat.py
new file mode 100644
index 00000000..4a237e1e
--- /dev/null
+++ b/tico/quantization/utils/transformers_compat.py
@@ -0,0 +1,55 @@
+# Copyright (c) 2026 Samsung Electronics Co., Ltd. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Runtime capability-detection helpers for Hugging Face `transformers`.
+
+Instead of branching on specific package versions such as
+`transformers >= 5.x`, use these helpers to detect whether the exact
+symbol or behavior required by the code is available at runtime.
+
+Each probe is cached once per process with `functools.lru_cache`,
+so repeated checks have negligible overhead.
+"""
+
+import functools
+import importlib
+
+
+@functools.lru_cache(maxsize=None)
+def qwen3_vl_has_deepstack_model_output() -> bool:
+    """
+    Probe whether Qwen3-VL ships its structured vision output class.
+
+    The probe imports the Qwen3-VL modeling module and looks up
+    `BaseModelOutputWithDeepstackFeatures` on it. Checking for the
+    symbol itself, rather than comparing version numbers, stays
+    correct across backports, forward ports, and non-linear package
+    versioning.
+
+    Returns
+    -------
+    bool
+        ``True`` if `transformers.models.qwen3_vl.modeling_qwen3_vl`
+        can be imported and defines the symbol, otherwise ``False``.
+    """
+    module_name = "transformers.models.qwen3_vl.modeling_qwen3_vl"
+    symbol = "BaseModelOutputWithDeepstackFeatures"
+    try:
+        modeling = importlib.import_module(module_name)
+    except ImportError:
+        # The modeling module cannot be imported in this environment.
+        return False
+    else:
+        return hasattr(modeling, symbol)
diff --git a/tico/quantization/wrapq/examples/qwen/quantize_vision_model.py b/tico/quantization/wrapq/examples/qwen/quantize_vision_model.py
new file mode 100644
index 00000000..d0bf3fcf
--- /dev/null
+++ b/tico/quantization/wrapq/examples/qwen/quantize_vision_model.py
@@ -0,0 +1,154 @@
+#!/usr/bin/env python3
+# Copyright (c) 2026 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import copy
+import sys
+from collections import namedtuple
+
+import torch
+import torch.nn as nn
+
+import tico
+import tico.quantization
+import tico.quantization.config.ptq
+from tico.quantization.evaluation.metric import compute_peir
+from tico.quantization.evaluation.utils import plot_two_outputs
+from tico.quantization.wrapq.utils.version import has_transformers_for
+
+torch.manual_seed(123)
+
+
+# Check if transformers is available
+
+if not has_transformers_for("qwen3-vl"):
+    print("Error: transformers package not installed. Cannot test Qwen3VLVisionModel.")
+    sys.exit(1)
+
+from transformers.models.qwen3_vl.configuration_qwen3_vl import Qwen3VLVisionConfig
+from transformers.models.qwen3_vl.modeling_qwen3_vl import Qwen3VLVisionModel
+
+from tico.quantization.wrapq.wrappers.qwen_vl.quant_vision_model import (
+    QuantQwen3VLVisionModel,
+)
+
+
+def generate_calibration_data(batch_size: int, sample_shape: tuple) -> list:
+    """Generate `batch_size` random-normal calibration tensors for PTQ."""
+    return [torch.randn(sample_shape) for _ in range(batch_size)]
+
+
+def main():
+    """Quantize a reduced Qwen3VLVisionModel and export it to Circle."""
+    # Create the vision model configuration
+    # Based on Qwen3VLVisionModel structure:
+    # (patch_embed): Qwen3VLVisionPatchEmbed(
+    #   (proj): Conv3d(3, 1024, kernel_size=(2, 16, 16), stride=(2, 16, 16))
+    # )
+    # (pos_embed): Embedding(2304, 1024)
+    cfg = Qwen3VLVisionConfig(
+        hidden_size=1024,
+        num_position_embeddings=2304,  # 48x48 spatial grid
+        temporal_patch_size=2,
+        patch_size=16,
+        depth=2,  # Number of transformer blocks (reduced for example)
+    )
+    model = Qwen3VLVisionModel(cfg)
+    # Switch to eval mode *before* copying, so the FP reference used for the
+    # error report runs in the same mode as the model that gets quantized.
+    model.eval()
+    orig_model = copy.deepcopy(model)
+
+    # Define grid_thw for fixed input size
+    # grid_thw: (num_images, 3) with (temporal, height, width)
+    # Example: [1, 24, 24] means 1 video with 1 temporal patch, 24 vertical, 24 horizontal
+    # Total patches: 1 * 24 * 24 = 576
+    THW = namedtuple(
+        "THW", ["num_temporal_patches", "num_height_patches", "num_width_patches"]
+    )
+    vision_grid_thw = THW(1, 24, 24)
+    grid_thw = torch.tensor([vision_grid_thw])
+
+    # Input to patch_embed: (batch_size, in_channels, depth, height, width)
+    # Example: (1, 3, 2, 384, 384)
+    # - batch_size: 1
+    # - in_channels: 3 (RGB)
+    # - depth: frames = num_temporal_patches * temporal_patch_size = 1 * 2 = 2 frames
+    # - height: num_height_patches * patch_size = 24 * 16 = 384
+    # - width: num_width_patches * patch_size = 24 * 16 = 384
+    num_frames = vision_grid_thw.num_temporal_patches * cfg.temporal_patch_size
+    frame_height = vision_grid_thw.num_height_patches * cfg.patch_size
+    frame_width = vision_grid_thw.num_width_patches * cfg.patch_size
+    input_shape = (1, cfg.in_channels, num_frames, frame_height, frame_width)
+
+    print(f"Input shape: {input_shape}")
+    print(f"grid_thw: {grid_thw.tolist()}")
+
+    # Generate calibration data
+    calibration_data = generate_calibration_data(
+        batch_size=20, sample_shape=input_shape
+    )
+
+    # Configure PTQ with vision_grid_thw override
+    # This is required for QuantQwen3VLVisionModel to precompute RoPE embeddings
+    ptq_config = tico.quantization.config.ptq.PTQConfig()
+    setattr(ptq_config, "vision_grid_thw", vision_grid_thw)
+
+    # Prepare the model for quantization
+    prepared_model = tico.quantization.prepare(
+        model, ptq_config, inplace=True  # Transform the model in place
+    )
+
+    # Calibrate the model (collect statistics)
+    with torch.no_grad():
+        for batch in calibration_data:
+            prepared_model(batch, grid_thw)
+
+    # Convert to quantized model
+    quantized_model = tico.quantization.convert(prepared_model, inplace=True)
+
+    # Compute PEIR (Peak Error-to-Input Ratio) between quantized model and original model
+    with torch.no_grad():
+        # In-sample check: the probe input is the first calibration batch.
+        test_input = calibration_data[0]
+        quant_out = quantized_model(test_input, grid_thw)
+        fp_out = orig_model(test_input, grid_thw)
+
+    # The structure of quant_out depends on transformers version
+    if QuantQwen3VLVisionModel.has_deepstack_model_output:
+        quant_out = quant_out.pooler_output
+        fp_out = fp_out.pooler_output
+    else:
+        quant_out = quant_out[0]
+        fp_out = fp_out[0]
+
+    print("┌───────────── Quantization Error Summary ─────────────")
+    print(f"│ Mean |diff|: {(quant_out - fp_out).abs().mean().item():.6f}")
+    print(f"│ PEIR       : {compute_peir(fp_out, quant_out) * 100:.6f} %")
+    print("└──────────────────────────────────────────────────────")
+    print(plot_two_outputs(fp_out, quant_out))
+
+    # Convert to Circle format
+    # example_inputs: (hidden_states, grid_thw)
+    example_input = (calibration_data[0], grid_thw)
+    circle_model = tico.convert(quantized_model.eval(), example_input)
+
+    # Save the Circle model
+    filename = "qwen3vl_vision_model.q.circle"
+    circle_model.save(filename)
+    print(f"Circle model saved as '{filename}'")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/tico/quantization/wrapq/wrappers/qwen_vl/quant_vision_model.py b/tico/quantization/wrapq/wrappers/qwen_vl/quant_vision_model.py
new file mode 100644
index 00000000..33b0c0a7
--- /dev/null
+++ b/tico/quantization/wrapq/wrappers/qwen_vl/quant_vision_model.py
@@ -0,0 +1,431 @@
+# Copyright (c) 2026 Samsung Electronics Co., Ltd. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+ +from typing import Any, Iterable, List, Optional, Tuple, Union + +import torch +import torch.nn as nn + +from tico.quantization.config.ptq import PTQConfig +from tico.quantization.utils.transformers_compat import ( + qwen3_vl_has_deepstack_model_output, +) +from tico.quantization.wrapq.mode import Mode +from tico.quantization.wrapq.wrappers.ptq_wrapper import PTQWrapper +from tico.quantization.wrapq.wrappers.quant_module_base import QuantModuleBase +from tico.quantization.wrapq.wrappers.registry import try_register + + +@try_register("transformers.models.qwen3_vl.modeling_qwen3_vl.Qwen3VLVisionModel") +class QuantQwen3VLVisionModel(QuantModuleBase): + """ + Quantization wrapper for Qwen3VLVisionModel module. + + This is the main vision model that processes image/video patches through: + - Patch embedding + - Position embedding (spatial) + - Rotary position embedding (RoPE) + - Transformer blocks + - Patch merger + """ + + has_deepstack_model_output: bool = qwen3_vl_has_deepstack_model_output() + + def __init__( + self, + fp_model: nn.Module, + *, + qcfg: Optional[PTQConfig] = None, + fp_name: Optional[str] = None, + ): + super().__init__(qcfg, fp_name=fp_name) + self.module = fp_model + + cfg = fp_model.config + self.spatial_merge_size = cfg.spatial_merge_size + self.patch_size = cfg.patch_size + self.hidden_size = cfg.hidden_size + self.num_position_embeddings = cfg.num_position_embeddings + self.num_grid_per_side = int(cfg.num_position_embeddings**0.5) + self.deepstack_visual_indexes = cfg.deepstack_visual_indexes + + # Extract vision_grid_thw from config for precomputing RoPE embeddings + self.vision_grid_thw = QuantQwen3VLVisionModel._get_vision_grid_thw(qcfg) + + # Precompute cumulative sequence lengths + cu_seqlens = QuantQwen3VLVisionModel._precompute_cu_seqlens( + self.vision_grid_thw + ) + self.register_buffer("cu_seqlens_template", cu_seqlens, persistent=False) + + # Precompute fast position embeddings + pos_embeds = 
QuantQwen3VLVisionModel._fast_pos_embed_interpolate( + merge_size=self.spatial_merge_size, + num_grid_per_side=self.num_grid_per_side, + pos_embedder=fp_model.pos_embed, + grid_thw=self.vision_grid_thw, + ) + self.register_buffer("pos_embed_template", pos_embeds, persistent=False) + + # Precompute rotary frequency table for RoPE + dim = ( + fp_model.rotary_pos_emb.dim + if hasattr(fp_model.rotary_pos_emb, "dim") + else (cfg.hidden_size // cfg.num_heads) // 2 + ) + theta = ( + fp_model.rotary_pos_emb.theta + if hasattr(fp_model.rotary_pos_emb, "theta") + else 10000.0 + ) + inv_freq = QuantQwen3VLVisionModel._precompute_rope_inv_freq( + dim=dim, theta=theta + ) + self.register_buffer("rope_inv_freq", inv_freq, persistent=False) + + # Precompute RoPE position embeddings for fixed vision_grid_thw + cos_t, sin_t = QuantQwen3VLVisionModel._precompute_rope_position_embeddings( + merge_size=self.spatial_merge_size, + rope_inv_freq=self.rope_inv_freq, + grid_thw=self.vision_grid_thw, + ) + self.register_buffer("rope_cos_template", cos_t, persistent=False) + self.register_buffer("rope_sin_template", sin_t, persistent=False) + + # Wrap patch embedder + self.patch_embed = PTQWrapper( + fp_model.patch_embed, + qcfg=qcfg.child("patch_embed") if qcfg else None, + fp_name=f"{fp_name}.patch_embed", + ) + + # Wrap transformer blocks + self.blocks = nn.ModuleList() + blocks_cfg = qcfg.child("blocks") if qcfg else None + for i, blk in enumerate(fp_model.blocks): + self.blocks.append( + PTQWrapper( + blk, + qcfg=blocks_cfg.child(str(i)) if blocks_cfg else None, + fp_name=f"{fp_name}.blocks.{i}", + ) + ) + + # Wrap merger + self.merger = PTQWrapper( + fp_model.merger, + qcfg=qcfg.child("merger") if qcfg else None, + fp_name=f"{fp_name}.merger", + ) + + # Wrap deepstack merger list + self.deepstack_merger_list = nn.ModuleList() + deepstack_merger_cfg = qcfg.child("deepstack_merger_list") if qcfg else None + for i, merger in enumerate(fp_model.deepstack_merger_list): + 
self.deepstack_merger_list.append( + PTQWrapper( + merger, + qcfg=deepstack_merger_cfg.child(str(i)) + if deepstack_merger_cfg + else None, + fp_name=f"{fp_name}.deepstack_merger_list.{i}", + ) + ) + + # --- Observers for intermediate tensors -------------------------------- + mk = self._make_obs + + # Position embedding observers + self.obs_pos_embeds = mk("pos_embed") + self.obs_pos_add = mk("pos_add") + + # RoPE observers + self.obs_rope_cos = mk("rope_cos") + self.obs_rope_sin = mk("rope_sin") + + @staticmethod + def _get_vision_grid_thw(qcfg: Optional[PTQConfig]) -> torch.Tensor: + """Extract vision_grid_thw from config for precomputing RoPE embeddings""" + if qcfg and hasattr(qcfg, "vision_grid_thw"): + grid_thw = torch.tensor([getattr(qcfg, "vision_grid_thw")]) + else: + raise ValueError( + "vision_grid_thw must be specified in PTQConfig overrides for " + "QuantQwen3VLVisionModel. Example: ptq_cfg = PTQConfig(); ptq_cfg.vision_grid_thw = [8, 24, 24]" + ) + assert grid_thw.shape == (1, 3) # 1 row, 3 columns (T, H, W) + return grid_thw + + @staticmethod + def _precompute_rope_inv_freq(dim: int, theta: float) -> torch.Tensor: + """Precompute rotary frequency table for RoPE.""" + # Compute inverse frequencies + inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim)) + return inv_freq + + @staticmethod + def _precompute_cu_seqlens(grid_thw: torch.Tensor) -> torch.Tensor: + """Precompute cumulative sequence lengths for fixed vision_grid_thw.""" + # Compute cumulative sequence lengths + from torch.nn import functional as F + + cu_seqlens = torch.repeat_interleave( + grid_thw[:, 1] * grid_thw[:, 2], + grid_thw[:, 0], + ).cumsum(dim=0) + cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0) + return cu_seqlens + + @staticmethod + def _precompute_rope_position_embeddings( + merge_size: int, rope_inv_freq: torch.Tensor, grid_thw: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Precompute RoPE position embeddings (cos, sin) for fixed 
vision_grid_thw.""" + seq_len = int(torch.prod(grid_thw, dim=1).sum().item()) + rotary_pos_emb = QuantQwen3VLVisionModel._rot_pos_emb( + merge_size, rope_inv_freq, grid_thw + ) + rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1) + emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) + return emb.cos(), emb.sin() + + @staticmethod + def _rot_pos_emb( + merge_size: int, rope_inv_freq: torch.Tensor, grid_thw: torch.Tensor + ) -> torch.Tensor: + """Compute rotary position embeddings from grid dimensions.""" + max_hw = int(grid_thw[:, 1:].max().item()) + + # Create frequency table up to max_hw + freq_table = QuantQwen3VLVisionModel._create_freq_table( + seqlen=max_hw, rope_inv_freq=rope_inv_freq + ) + device = freq_table.device + + total_tokens = int(torch.prod(grid_thw, dim=1).sum().item()) + pos_ids = torch.empty((total_tokens, 2), dtype=torch.long, device=device) + + offset = 0 + for num_frames, height, width in grid_thw: + merged_h, merged_w = height // merge_size, width // merge_size + + block_rows = torch.arange(merged_h, device=device) + block_cols = torch.arange(merged_w, device=device) + intra_row = torch.arange(merge_size, device=device) + intra_col = torch.arange(merge_size, device=device) + + # Compute full-resolution positions + row_idx = ( + block_rows[:, None, None, None] * merge_size + + intra_row[None, None, :, None] + ) + col_idx = ( + block_cols[None, :, None, None] * merge_size + + intra_col[None, None, None, :] + ) + + row_idx = row_idx.expand( + merged_h, merged_w, merge_size, merge_size + ).reshape(-1) + col_idx = col_idx.expand( + merged_h, merged_w, merge_size, merge_size + ).reshape(-1) + + coords = torch.stack((row_idx, col_idx), dim=-1) + + if num_frames > 1: + coords = coords.repeat(num_frames, 1) + + num_tokens = coords.shape[0] + pos_ids[offset : offset + num_tokens] = coords + offset += num_tokens + + embeddings = freq_table[pos_ids] + embeddings = embeddings.flatten(1) + return embeddings + + @staticmethod + def 
_create_freq_table(seqlen: int, rope_inv_freq: torch.Tensor) -> torch.Tensor: + """Create rotary frequency table.""" + seq = torch.arange( + seqlen, device=rope_inv_freq.device, dtype=rope_inv_freq.dtype + ) + freqs = torch.outer(seq, rope_inv_freq) + return freqs + + @staticmethod + def _fast_pos_embed_interpolate( + merge_size: int, + num_grid_per_side: int, + pos_embedder: nn.Module, + grid_thw: torch.Tensor, + ) -> torch.Tensor: + """Compute interpolated position embeddings.""" + grid_ts, grid_hs, grid_ws = grid_thw[:, 0], grid_thw[:, 1], grid_thw[:, 2] + device = pos_embedder.weight.device + + idx_list: List[Any] = [[] for _ in range(4)] + weight_list: List[Any] = [[] for _ in range(4)] + + for t, h, w in zip(grid_ts, grid_hs, grid_ws): + h_idxs = torch.linspace(0, num_grid_per_side - 1, h) + w_idxs = torch.linspace(0, num_grid_per_side - 1, w) + + h_idxs_floor = h_idxs.int() + w_idxs_floor = w_idxs.int() + h_idxs_ceil = (h_idxs.int() + 1).clip(max=num_grid_per_side - 1) + w_idxs_ceil = (w_idxs.int() + 1).clip(max=num_grid_per_side - 1) + + dh = h_idxs - h_idxs_floor + dw = w_idxs - w_idxs_floor + + base_h = h_idxs_floor * num_grid_per_side + base_h_ceil = h_idxs_ceil * num_grid_per_side + + indices = [ + (base_h[None].T + w_idxs_floor[None]).flatten(), + (base_h[None].T + w_idxs_ceil[None]).flatten(), + (base_h_ceil[None].T + w_idxs_floor[None]).flatten(), + (base_h_ceil[None].T + w_idxs_ceil[None]).flatten(), + ] + + weights = [ + ((1 - dh)[None].T * (1 - dw)[None]).flatten(), + ((1 - dh)[None].T * dw[None]).flatten(), + (dh[None].T * (1 - dw)[None]).flatten(), + (dh[None].T * dw[None]).flatten(), + ] + + for i in range(4): + idx_list[i].extend(indices[i].tolist()) + weight_list[i].extend(weights[i].tolist()) + + idx_tensor = torch.tensor(idx_list, dtype=torch.long, device=device) + weight_tensor = torch.tensor( + weight_list, dtype=pos_embedder.weight.dtype, device=device + ) + pos_embeds = pos_embedder(idx_tensor).to(device) * weight_tensor[:, :, None] + 
patch_pos_embeds = pos_embeds[0] + pos_embeds[1] + pos_embeds[2] + pos_embeds[3] + + patch_pos_embeds = patch_pos_embeds.split( + [h * w for h, w in zip(grid_hs, grid_ws)] + ) + + patch_pos_embeds_permute = [] + for pos_embed, t, h, w in zip(patch_pos_embeds, grid_ts, grid_hs, grid_ws): + pos_embed = pos_embed.repeat(t, 1) + pos_embed = ( + pos_embed.view( + t, h // merge_size, merge_size, w // merge_size, merge_size, -1 + ) + .permute(0, 1, 3, 2, 4, 5) + .flatten(0, 4) + ) + patch_pos_embeds_permute.append(pos_embed) + patch_pos_embeds = torch.cat(patch_pos_embeds_permute) + return patch_pos_embeds + + def forward( + self, + hidden_states: torch.Tensor, + grid_thw: torch.Tensor, + **kwargs, + ) -> Union[torch.Tensor, tuple]: + """ + Forward pass with fake quantization. + + Args: + hidden_states: Input tensor of shape (seq_len, in_channels * T * H * W) + grid_thw: Grid dimensions (num_images, 3) with (temporal, height, width) + + Returns: + BaseModelOutputWithDeepstackFeatures or similar + """ + # Assert that grid_thw matches the precomputed vision_grid_thw + if self._mode is Mode.CALIB: + assert torch.equal(grid_thw, self.vision_grid_thw.to(grid_thw.device)), ( + f"grid_thw {grid_thw.tolist()} does not match the precomputed " + f"vision_grid_thw {self.vision_grid_thw.tolist()}" + ) + + # Patch embedding (already quantized by wrapper) + hidden_states = self.patch_embed(hidden_states) + + # Position embedding + pos_embeds = self.pos_embed_template.to( + dtype=hidden_states.dtype, device=hidden_states.device + ) + pos_embeds = self._fq(pos_embeds, self.obs_pos_embeds) + hidden_states = hidden_states + pos_embeds + hidden_states = self._fq(hidden_states, self.obs_pos_add) + + # Reshape hidden states + seq_len, _ = hidden_states.size() + hidden_states = hidden_states.reshape(seq_len, -1) + + # Use precomputed RoPE position embeddings (cos, sin) and quantize + cos = self.rope_cos_template.to( + dtype=hidden_states.dtype, device=hidden_states.device + ) + sin = 
def _all_observers(self) -> Iterable:
    """Iterate over the observers owned directly by this wrapper.

    Yields, in order: the position-embedding observer, the observer for the
    sum of patch and position embeddings, and the RoPE cosine and sine
    observers.
    """
    local_observers = (
        self.obs_pos_embeds,
        self.obs_pos_add,
        self.obs_rope_cos,
        self.obs_rope_sin,
    )
    for observer in local_observers:
        yield observer