diff --git a/test/quantization/wrapq/wrappers/qwen_vl/test_quant_model.py b/test/quantization/wrapq/wrappers/qwen_vl/test_quant_model.py new file mode 100644 index 00000000..87b76366 --- /dev/null +++ b/test/quantization/wrapq/wrappers/qwen_vl/test_quant_model.py @@ -0,0 +1,1116 @@ +# Copyright (c) 2026 Samsung Electronics Co., Ltd. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +import unittest +from typing import Tuple + +import torch +from tico.quantization.config.ptq import PTQConfig +from tico.quantization.wrapq.dtypes import DType +from tico.quantization.wrapq.mode import Mode +from tico.quantization.wrapq.utils.version import has_transformers_for +from tico.quantization.wrapq.wrappers.ptq_wrapper import PTQWrapper +from tico.quantization.wrapq.wrappers.qwen_vl.quant_model import QuantQwen3VLModel + + +skip_msg = "transformers not installed — skipping Qwen3VLModel tests" + + +@unittest.skipUnless(has_transformers_for("qwen3-vl"), skip_msg) +class TestQuantQwen3VLModel(unittest.TestCase): + fp_model: torch.nn.Module + hidden_size: int + vocab_size: int + patch_size: int + temporal_patch_size: int + video_token_id: int + image_token_id: int + spatial_merge_size: int + ptq_config: PTQConfig + + @classmethod + def setUpClass(cls): + from transformers.models.qwen3_vl.configuration_qwen3_vl import Qwen3VLConfig + from transformers.models.qwen3_vl.modeling_qwen3_vl import Qwen3VLModel + + # Use smaller sizes for testing + cfg = 
Qwen3VLConfig( + vision_config={ + "hidden_size": 64, + "num_heads": 4, + "depth": 2, # Smaller depth for faster testing + "temporal_patch_size": 2, + "patch_size": 16, + "out_hidden_size": 64, + "deepstack_visual_indexes": [0, 1], + }, + text_config={ + "hidden_size": 64, + "intermediate_size": 256, + "num_attention_heads": 2, + "num_key_value_heads": 2, + "head_dim": 32, + "num_hidden_layers": 2, + "attention_bias": False, + "attention_dropout": 0.0, + "max_position_embeddings": 1024, + "vocab_size": 1000, + "use_cache": False, + "rope_scaling": {"rope_type": "default", "mrope_section": [1, 1, 2]}, + }, + image_token_id=998, + video_token_id=999, + ) + + assert cfg.image_token_id < cfg.text_config.vocab_size + assert cfg.video_token_id < cfg.text_config.vocab_size + assert cfg.vision_config.out_hidden_size == cfg.text_config.hidden_size + + cls.fp_model = Qwen3VLModel(cfg) + cls.patch_size = cfg.vision_config.patch_size + cls.temporal_patch_size = cfg.vision_config.temporal_patch_size + cls.hidden_size = cfg.text_config.hidden_size + cls.vocab_size = cfg.text_config.vocab_size + cls.video_token_id = cfg.video_token_id + cls.image_token_id = cfg.image_token_id + cls.spatial_merge_size = cfg.vision_config.spatial_merge_size + + @staticmethod + def _make_ptq_config(grid_thw: Tuple[int, int, int]) -> PTQConfig: + return PTQConfig( + model_args={ + "vision": { + "grid_thw": grid_thw, + } + } + ) + + @staticmethod + def _compute_3d_position_ids( + input_ids: torch.Tensor, + thw: Tuple[int, int, int], + spatial_merge_size: int, + image_token_id: int, + ) -> torch.Tensor: + """ + Compute 3D position IDs for multimodal RoPE. + This function pre-computes position_ids to avoid tracing issues during model export. 
+ """ + batch_size, seq_len = input_ids.shape + device = input_ids.device + + position_ids = torch.ones( + 3, batch_size, seq_len, dtype=input_ids.dtype, device=device + ) + + for i in range(batch_size): + # Find positions of image tokens + image_mask = input_ids[i] == image_token_id + image_positions = torch.nonzero(image_mask, as_tuple=True)[0] + + llm_pos_ids_list: list[torch.tensor] = [] + st = 0 + + # Process visual tokens + if len(image_positions) > 0: + # Group consecutive placeholder tokens into a single visual object + # All consecutive image tokens represent ONE image/video + start_pos = image_positions[0].item() + + # Text position IDs (before first visual token) + text_len = start_pos - st + if text_len > 0: + st_idx = ( + llm_pos_ids_list[-1].max() + 1 + if len(llm_pos_ids_list) > 0 + else 0 + ) + llm_pos_ids_list.append( + torch.arange(text_len, device=device).view(1, -1).expand(3, -1) + + st_idx + ) + + # Vision position IDs (3D) + llm_grid_t = 1 # Always 1 for images + llm_grid_h = thw[1] // spatial_merge_size + llm_grid_w = thw[2] // spatial_merge_size + + t_index = ( + torch.arange(llm_grid_t, device=device) + .view(-1, 1) + .expand(-1, llm_grid_h * llm_grid_w) + .flatten() + ) + h_index = ( + torch.arange(llm_grid_h, device=device) + .view(1, -1, 1) + .expand(llm_grid_t, -1, llm_grid_w) + .flatten() + ) + w_index = ( + torch.arange(llm_grid_w, device=device) + .view(1, 1, -1) + .expand(llm_grid_t, llm_grid_h, -1) + .flatten() + ) + st_idx = ( + llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + ) + llm_pos_ids_list.append( + torch.stack([t_index, h_index, w_index]) + st_idx + ) + + # Update st to after all visual placeholder tokens + # The number of visual tokens is (thw[1] // spatial_merge_size) * (thw[2] // spatial_merge_size) + num_visual_tokens = (thw[1] // spatial_merge_size) * ( + thw[2] // spatial_merge_size + ) + st = start_pos + num_visual_tokens + + # Trailing text + if st < seq_len: + st_idx = ( + 
llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + ) + text_len = seq_len - st + llm_pos_ids_list.append( + torch.arange(text_len, device=device).view(1, -1).expand(3, -1) + + st_idx + ) + + llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) + position_ids[..., i, :] = llm_positions + + return position_ids + + def _create_text_only_input(self, batch_size=1, seq_len=10): + """Helper to create text-only input without images/videos.""" + input_ids = torch.randint( + low=0, high=self.vocab_size, size=(batch_size, seq_len), dtype=torch.long + ) + attention_mask = torch.ones_like(input_ids) + return input_ids, attention_mask + + def _create_visual_input( + self, + visual_token_id: int, + batch_size: int, + seq_len: int, + thw: Tuple[int, int, int], + ): + """Helper to create input with videos or images.""" + assert visual_token_id in (self.video_token_id, self.image_token_id) + + # Calculate number of visual placeholder tokens needed + # Each video is represented by multiple tokens after spatial merge + # Spatial merge reduces the grid size by spatial_merge_size in each dimension + num_video_tokens = (thw[1] // self.spatial_merge_size) * ( + thw[2] // self.spatial_merge_size + ) + assert ( + num_video_tokens <= seq_len + ), f"{num_video_tokens} video tokens can't fit into input sequence of length {seq_len}" + + # Create input_ids with random text tokens + input_ids = torch.randint( + low=0, + high=self.vocab_size - 2, + size=(batch_size, seq_len), + dtype=torch.long, + ) + attention_mask = torch.ones_like(input_ids) + + # Replace first tokens with video placeholder tokens + # This marks where the video features should be inserted + for i in range(batch_size): + input_ids[i, :num_video_tokens] = visual_token_id + + num_temporal_patches, num_spatial_patches_h, num_spatial_patches_w = thw + + # Create pixel values for videos + pixel_values = torch.randn( + batch_size, + 3, + num_temporal_patches * self.temporal_patch_size, + 
num_spatial_patches_h * self.patch_size, + num_spatial_patches_w * self.patch_size, + ) + video_grid_thw = torch.tensor([thw]) + + position_ids = self._compute_3d_position_ids( + input_ids=input_ids, + thw=thw, + spatial_merge_size=self.spatial_merge_size, + image_token_id=visual_token_id, + ) + + return input_ids, attention_mask, pixel_values, video_grid_thw, position_ids + + def _create_video_input( + self, + batch_size: int = 1, + seq_len: int = 64, + thw: Tuple[int, int, int] = (1, 8, 8), + ): + return self._create_visual_input(self.video_token_id, batch_size, seq_len, thw) + + def _create_image_input( + self, + batch_size: int = 1, + seq_len: int = 64, + thw: Tuple[int, int, int] = (1, 8, 8), + ): + return self._create_visual_input(self.image_token_id, batch_size, seq_len, thw) + + # ------------------------------------------------------------------------- + # Initialization tests + # ------------------------------------------------------------------------- + + def test_wraps_submodules(self): + """Test that __init__ wraps all submodules with PTQWrapper.""" + ptq_config = self._make_ptq_config(grid_thw=(1, 8, 8)) + q_model = QuantQwen3VLModel(self.fp_model, qcfg=ptq_config) + + # Check that submodules are wrapped + self.assertTrue(hasattr(q_model, "visual")) + self.assertIsInstance(q_model.visual, PTQWrapper) + + self.assertTrue(hasattr(q_model, "language_model")) + self.assertIsInstance(q_model.language_model, PTQWrapper) + + def test_mode_transitions(self): + """Test quantization mode transitions: NO_QUANT → CALIB → QUANT""" + ptq_config = self._make_ptq_config(grid_thw=(1, 8, 8)) + q_model = QuantQwen3VLModel(self.fp_model, qcfg=ptq_config) + self.assertIs(q_model._mode, Mode.NO_QUANT) + + q_model.enable_calibration() + self.assertIs(q_model._mode, Mode.CALIB) + + # Run forward pass during calibration (text-only) + input_ids, attention_mask = self._create_text_only_input() + _ = q_model(input_ids=input_ids, attention_mask=attention_mask) + + 
q_model.freeze_qparams() + self.assertIs(q_model._mode, Mode.QUANT) + + # ------------------------------------------------------------------------- + # Forward pass tests + # ------------------------------------------------------------------------- + + def test_forward_text_only(self): + """Test forward pass with text-only input.""" + ptq_config = self._make_ptq_config(grid_thw=(1, 8, 8)) + q_model = QuantQwen3VLModel(self.fp_model, qcfg=ptq_config) + q_model.enable_calibration() + + input_ids, attention_mask = self._create_text_only_input() + _ = q_model(input_ids=input_ids, attention_mask=attention_mask) + + q_model.freeze_qparams() + + with torch.no_grad(): + output = q_model(input_ids=input_ids, attention_mask=attention_mask) + + # Check output structure + self.assertTrue(hasattr(output, "last_hidden_state")) + self.assertTrue(hasattr(output, "past_key_values")) + self.assertTrue(hasattr(output, "rope_deltas")) + + # Check output shape + batch_size, seq_len = input_ids.shape + self.assertEqual( + output.last_hidden_state.shape, (batch_size, seq_len, self.hidden_size) + ) + + def test_forward_with_images(self): + """Test forward pass with image input.""" + thw = (1, 8, 8) + ptq_config = self._make_ptq_config(grid_thw=thw) + q_model = QuantQwen3VLModel(self.fp_model, qcfg=ptq_config) + q_model.enable_calibration() + + ( + input_ids, + attention_mask, + pixel_values, + image_grid_thw, + position_ids, + ) = self._create_image_input(thw=thw) + + _ = q_model( + input_ids=input_ids, + attention_mask=attention_mask, + pixel_values=pixel_values, + image_grid_thw=image_grid_thw, + ) + + q_model.freeze_qparams() + + with torch.no_grad(): + output = q_model( + input_ids=input_ids, + attention_mask=attention_mask, + pixel_values=pixel_values, + image_grid_thw=image_grid_thw, + ) + + # Check output structure + self.assertTrue(hasattr(output, "last_hidden_state")) + self.assertTrue(hasattr(output, "past_key_values")) + self.assertTrue(hasattr(output, "rope_deltas")) + + # 
Check output shape + batch_size, seq_len = input_ids.shape + self.assertEqual( + output.last_hidden_state.shape, (batch_size, seq_len, self.hidden_size) + ) + + def test_forward_with_videos(self): + """Test forward pass with video input.""" + thw = (1, 8, 8) + ptq_config = self._make_ptq_config(grid_thw=thw) + q_model = QuantQwen3VLModel(self.fp_model, qcfg=ptq_config) + q_model.enable_calibration() + + ( + input_ids, + attention_mask, + pixel_values_videos, + video_grid_thw, + position_ids, + ) = self._create_video_input(thw=thw) + + _ = q_model( + input_ids=input_ids, + attention_mask=attention_mask, + pixel_values_videos=pixel_values_videos, + video_grid_thw=video_grid_thw, + ) + + q_model.freeze_qparams() + + with torch.no_grad(): + output = q_model( + input_ids=input_ids, + attention_mask=attention_mask, + pixel_values_videos=pixel_values_videos, + video_grid_thw=video_grid_thw, + ) + + # Check output structure + self.assertTrue(hasattr(output, "last_hidden_state")) + self.assertTrue(hasattr(output, "past_key_values")) + self.assertTrue(hasattr(output, "rope_deltas")) + + # Check output shape + batch_size, seq_len = input_ids.shape + self.assertEqual( + output.last_hidden_state.shape, (batch_size, seq_len, self.hidden_size) + ) + + def test_forward_with_both_images_and_videos(self): + """Test forward pass with both image and video inputs (tests deepstack feature combination).""" + thw = (1, 8, 8) + ptq_config = self._make_ptq_config(grid_thw=thw) + q_model = QuantQwen3VLModel(self.fp_model, qcfg=ptq_config) + q_model.enable_calibration() + + # Calculate visual token count + num_visual_tokens = (thw[1] // self.spatial_merge_size) * ( + thw[2] // self.spatial_merge_size + ) # 16 tokens + + # Create input with both images and videos + batch_size = 1 + seq_len = 64 + + # Start with image tokens + input_ids = torch.randint( + low=0, + high=self.vocab_size - 2, + size=(batch_size, seq_len), + dtype=torch.long, + ) + input_ids[0, 0:num_visual_tokens] = 
self.image_token_id + + # Add video tokens later in the sequence (with some text in between) + video_start = num_visual_tokens + 10 + input_ids[ + 0, video_start : video_start + num_visual_tokens + ] = self.video_token_id + + # Create pixel values for images + pixel_values = torch.randn( + batch_size, + 3, + thw[0] * self.temporal_patch_size, + thw[1] * self.patch_size, + thw[2] * self.patch_size, + ) + image_grid_thw = torch.tensor([thw]) + + # Create pixel values for videos + pixel_values_videos = torch.randn( + batch_size, + 3, + thw[0] * self.temporal_patch_size, + thw[1] * self.patch_size, + thw[2] * self.patch_size, + ) + video_grid_thw = torch.tensor([thw]) + + # Run forward pass with both images and videos + _ = q_model( + input_ids=input_ids, + pixel_values=pixel_values, + image_grid_thw=image_grid_thw, + pixel_values_videos=pixel_values_videos, + video_grid_thw=video_grid_thw, + ) + + q_model.freeze_qparams() + + with torch.no_grad(): + output = q_model( + input_ids=input_ids, + pixel_values=pixel_values, + image_grid_thw=image_grid_thw, + pixel_values_videos=pixel_values_videos, + video_grid_thw=video_grid_thw, + ) + + # Check output structure + self.assertTrue(hasattr(output, "last_hidden_state")) + self.assertTrue(hasattr(output, "past_key_values")) + self.assertTrue(hasattr(output, "rope_deltas")) + + # Check output shape + batch_size, seq_len = input_ids.shape + self.assertEqual( + output.last_hidden_state.shape, (batch_size, seq_len, self.hidden_size) + ) + + def test_forward_with_inputs_embeds(self): + """Test forward pass with inputs_embeds (triggers embedding comparison logic).""" + ptq_config = self._make_ptq_config(grid_thw=(1, 8, 8)) + q_model = QuantQwen3VLModel(self.fp_model, qcfg=ptq_config) + q_model.enable_calibration() + + batch_size = 1 + seq_len = 20 + + # Create inputs_embeds with some tokens matching image/video token embeddings + inputs_embeds = torch.randn(batch_size, seq_len, self.hidden_size) + + # Create attention mask + 
attention_mask = torch.ones(batch_size, seq_len) + + # Run forward pass with inputs_embeds (not input_ids) + # This should trigger the embedding comparison logic in _get_placeholder_mask + _ = q_model( + input_ids=None, inputs_embeds=inputs_embeds, attention_mask=attention_mask + ) + + q_model.freeze_qparams() + + with torch.no_grad(): + output = q_model( + input_ids=None, + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + ) + + # Check output structure + self.assertTrue(hasattr(output, "last_hidden_state")) + self.assertTrue(hasattr(output, "past_key_values")) + self.assertTrue(hasattr(output, "rope_deltas")) + + # Check output shape + self.assertEqual( + output.last_hidden_state.shape, (batch_size, seq_len, self.hidden_size) + ) + + def test_forward_with_inputs_embeds_and_images(self): + """Test forward pass with inputs_embeds and images (triggers QuantQwen3VLModel._get_placeholder_mask).""" + thw = (1, 8, 8) + ptq_config = self._make_ptq_config(grid_thw=thw) + q_model = QuantQwen3VLModel(self.fp_model, qcfg=ptq_config) + q_model.enable_calibration() + + batch_size = 1 + seq_len = 64 + + # Create inputs_embeds with some tokens matching image/video token embeddings + inputs_embeds = torch.randn(batch_size, seq_len, self.hidden_size) + + # Calculate number of visual placeholder tokens needed + # Each video is represented by multiple tokens after spatial merge + # Spatial merge reduces the grid size by spatial_merge_size in each dimension + num_video_tokens = (thw[1] // self.spatial_merge_size) * ( + thw[2] // self.spatial_merge_size + ) + + # Replace first tokens with video placeholder tokens + # This marks where the video features should be inserted + embedder = q_model.language_model.wrapped.embed_tokens + img_tkn = torch.tensor( + self.image_token_id, + dtype=torch.long, + device=inputs_embeds.device, + ) + img_tkn_emb = embedder(img_tkn) + for i in range(batch_size): + inputs_embeds[i, :num_video_tokens] = img_tkn_emb + + pixel_values = 
torch.randn( + batch_size, + 3, + thw[0] * self.temporal_patch_size, + thw[1] * self.patch_size, + thw[2] * self.patch_size, + ) + + grid_thw = torch.tensor([thw]) + + # Create attention mask + attention_mask = torch.ones(batch_size, seq_len) + + # Run forward pass with inputs_embeds (not input_ids) + # This should trigger the embedding comparison logic in _get_placeholder_mask + _ = q_model( + input_ids=None, + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + pixel_values=pixel_values, + image_grid_thw=grid_thw, + ) + + q_model.freeze_qparams() + + # Recompute image token embedding as quantization may produce a different result for it + img_tkn_emb = embedder(img_tkn) + for i in range(batch_size): + inputs_embeds[i, :num_video_tokens] = img_tkn_emb + + with torch.no_grad(): + output = q_model( + input_ids=None, + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + pixel_values=pixel_values, + image_grid_thw=grid_thw, + ) + + # Check output structure + self.assertTrue(hasattr(output, "last_hidden_state")) + self.assertTrue(hasattr(output, "past_key_values")) + self.assertTrue(hasattr(output, "rope_deltas")) + + # Check output shape + self.assertEqual( + output.last_hidden_state.shape, (batch_size, seq_len, self.hidden_size) + ) + + def test_forward_input_validation(self): + """Test that forward validates input requirements.""" + ptq_config = self._make_ptq_config(grid_thw=(1, 8, 8)) + q_model = QuantQwen3VLModel(self.fp_model, qcfg=ptq_config) + + # Test: neither input_ids nor inputs_embeds + with self.assertRaises(ValueError) as context: + _ = q_model() + self.assertIn( + "exactly one of input_ids or inputs_embeds", str(context.exception) + ) + + # Test: both input_ids and inputs_embeds + input_ids, _ = self._create_text_only_input() + inputs_embeds = torch.randn(1, 10, self.hidden_size) + with self.assertRaises(ValueError) as context: + _ = q_model(input_ids=input_ids, inputs_embeds=inputs_embeds) + self.assertIn( + "exactly one of 
input_ids or inputs_embeds", str(context.exception) + ) + + # ------------------------------------------------------------------------- + # Output comparison tests + # ------------------------------------------------------------------------- + + def test_forward_diff_text_only(self): + """ + Test that quantized output is acceptably close to FP reference for text-only input. + """ + torch.manual_seed(42) + ptq_config = self._make_ptq_config(grid_thw=(1, 8, 8)) + q_model = QuantQwen3VLModel(self.fp_model, qcfg=ptq_config) + q_model.enable_calibration() + + # Calibrate with multiple inputs + for _ in range(4): + input_ids, attention_mask = self._create_text_only_input() + _ = q_model(input_ids=input_ids, attention_mask=attention_mask) + + q_model.freeze_qparams() + + input_ids, attention_mask = self._create_text_only_input() + with torch.no_grad(): + q_out = q_model(input_ids=input_ids, attention_mask=attention_mask) + fp_out = self.fp_model(input_ids=input_ids, attention_mask=attention_mask) + + self.assertEqual(q_out.last_hidden_state.shape, fp_out.last_hidden_state.shape) + diff = (fp_out.last_hidden_state - q_out.last_hidden_state).abs().mean().item() + self.assertGreater(diff, 0.0) # not identical + self.assertLess(diff, 0.7) # acceptably close + + # ------------------------------------------------------------------------- + # Registration tests + # ------------------------------------------------------------------------- + + def test_registration_in_registry(self): + """Test that Qwen3VLModel is properly registered.""" + from tico.quantization.wrapq.wrappers.qwen_vl.quant_model import ( + QuantQwen3VLModel, + ) + from tico.quantization.wrapq.wrappers.registry import lookup + from transformers.models.qwen3_vl.modeling_qwen3_vl import Qwen3VLModel + + wrapper_cls = lookup(Qwen3VLModel) + self.assertIs(wrapper_cls, QuantQwen3VLModel) + + # ------------------------------------------------------------------------- + # Observer tests + # 
------------------------------------------------------------------------- + + def test_observer_count(self): + """Test that the wrapper has the correct number of observers.""" + ptq_config = self._make_ptq_config(grid_thw=(1, 8, 8)) + q_model = QuantQwen3VLModel(self.fp_model, qcfg=ptq_config) + + observers = list(q_model._all_observers()) + # Should have 1 local observer (obs_mm_fusion) + self.assertEqual(len(observers), 1) + + def test_activation_stats_collected_text_only(self): + """Test that activation statistics are collected during calibration (text-only).""" + ptq_config = self._make_ptq_config(grid_thw=(1, 8, 8)) + q_model = QuantQwen3VLModel(self.fp_model, qcfg=ptq_config) + q_model.enable_calibration() + + # Run forward pass to collect stats + input_ids, attention_mask = self._create_text_only_input() + _ = q_model(input_ids=input_ids, attention_mask=attention_mask) + + # Freeze and check qparams exist for multimodal fusion observer + q_model.freeze_qparams() + self.assertTrue(q_model.obs_mm_fusion.has_qparams) + + def test_activation_stats_collected_with_images(self): + """Test that activation statistics are collected during calibration (with images).""" + thw = (1, 8, 8) + ptq_config = self._make_ptq_config(grid_thw=thw) + q_model = QuantQwen3VLModel(self.fp_model, qcfg=ptq_config) + q_model.enable_calibration() + + # Run forward pass with images + ( + input_ids, + attention_mask, + pixel_values, + image_grid_thw, + position_ids, + ) = self._create_image_input(thw=thw) + _ = q_model( + input_ids=input_ids, + attention_mask=attention_mask, + pixel_values=pixel_values, + image_grid_thw=image_grid_thw, + ) + + # Freeze and check qparams exist + q_model.freeze_qparams() + self.assertTrue(q_model.obs_mm_fusion.has_qparams) + + # ------------------------------------------------------------------------- + # Multiple calibration steps tests + # ------------------------------------------------------------------------- + + def 
test_multiple_calibration_steps_text_only(self): + """Test that running multiple calibration iterations works correctly.""" + ptq_config = self._make_ptq_config(grid_thw=(1, 8, 8)) + q_model = QuantQwen3VLModel(self.fp_model, qcfg=ptq_config) + q_model.enable_calibration() + + # Run multiple calibration steps + for i in range(5): + input_ids, attention_mask = self._create_text_only_input() + _ = q_model(input_ids=input_ids, attention_mask=attention_mask) + + q_model.freeze_qparams() + + # Verify that observer has quantization parameters + self.assertTrue(q_model.obs_mm_fusion.has_qparams) + + # ------------------------------------------------------------------------- + # Config override tests + # ------------------------------------------------------------------------- + + def test_dtype_override(self): + """ + PTQConfig overrides should propagate to observers created by QuantQwen3VLModel. + """ + ptq_config = self._make_ptq_config(grid_thw=(1, 8, 8)) + ptq_config.overrides = {"mm_fusion": {"dtype": DType.uint(4)}} + q_model = QuantQwen3VLModel(self.fp_model, qcfg=ptq_config) + + # Check that overrides were applied + self.assertEqual(q_model.obs_mm_fusion.dtype, DType.uint(4)) + + # ------------------------------------------------------------------------- + # Batch size tests + # ------------------------------------------------------------------------- + + def test_different_batch_sizes_text_only(self): + """Test that quantization works correctly with different batch sizes.""" + ptq_config = self._make_ptq_config(grid_thw=(1, 8, 8)) + q_model = QuantQwen3VLModel(self.fp_model, qcfg=ptq_config) + q_model.enable_calibration() + + # Calibrate with one batch size + input_ids, attention_mask = self._create_text_only_input(batch_size=2) + for _ in range(3): + _ = q_model(input_ids=input_ids, attention_mask=None) + q_model.freeze_qparams() + + # Test with different batch sizes + for batch_size in [1, 2, 4]: + input_ids, attention_mask = 
self._create_text_only_input(batch_size) + with torch.no_grad(): + output = q_model(input_ids=input_ids, attention_mask=None) + + expected_shape = (batch_size, input_ids.shape[1], self.hidden_size) + self.assertEqual(output.last_hidden_state.shape, expected_shape) + + # ------------------------------------------------------------------------- + # Rope deltas tests + # ------------------------------------------------------------------------- + + def test_rope_deltas_computed_after_forward(self): + """Test that rope_deltas are computed after forward pass.""" + ptq_config = self._make_ptq_config(grid_thw=(1, 8, 8)) + q_model = QuantQwen3VLModel(self.fp_model, qcfg=ptq_config) + q_model.enable_calibration() + + # Initially None + self.assertIsNone(q_model.rope_deltas) + + # After forward pass (text-only), rope_deltas should be computed + input_ids, attention_mask = self._create_text_only_input() + output = q_model(input_ids=input_ids, attention_mask=attention_mask) + + # rope_deltas should now be set (even for text-only input) + self.assertIsNotNone(q_model.rope_deltas) + self.assertIsNotNone(output.rope_deltas) + + # ------------------------------------------------------------------------- + # _get_rope_index tests + # ------------------------------------------------------------------------- + + def test_get_rope_index_with_images_and_videos(self): + """Test _get_rope_index generates correct 3D position IDs for mixed image/video input.""" + ptq_config = self._make_ptq_config(grid_thw=(1, 8, 8)) + q_model = QuantQwen3VLModel(self.fp_model, qcfg=ptq_config) + + # Create input with vision_start_token_id followed by image/video tokens + batch_size = 1 + seq_len = 64 + input_ids = torch.zeros(batch_size, seq_len, dtype=torch.long) + + # Add vision_start_token_id, image tokens, then text + idx = 0 + input_ids[0, idx] = self.fp_model.config.vision_start_token_id + idx += 1 + input_ids[0, idx] = self.image_token_id + idx += 1 + input_ids[0, idx : idx + 16] = 
self.image_token_id # 16 image tokens + idx += 16 + input_ids[0, idx : idx + 10] = torch.randint( + 0, self.vocab_size - 2, (10,), dtype=torch.long + ) # Text tokens + idx += 10 + input_ids[0, idx] = self.fp_model.config.vision_start_token_id + idx += 1 + input_ids[0, idx] = self.video_token_id + idx += 1 + input_ids[0, idx : idx + 32] = self.video_token_id # 32 video tokens + idx += 32 + input_ids[0, idx:] = torch.randint( + 0, self.vocab_size - 2, (seq_len - idx,), dtype=torch.long + ) + + attention_mask = torch.ones_like(input_ids) + + # Grid dimensions for images/videos + # 1 image: (1, 8, 8) -> after spatial merge: (1, 4, 4) -> 16 tokens + # 2 videos: (1, 4, 8) -> after spatial merge: (1, 2, 4) -> 16 tokens each + image_grid_thw = torch.tensor([[1, 8, 8]]) + video_grid_thw = torch.tensor([[1, 4, 8], [1, 4, 8]]) + + # Call _get_rope_index + position_ids, mrope_position_deltas = q_model._get_rope_index( + input_ids=input_ids, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + attention_mask=attention_mask, + ) + + # Verify output structure + self.assertEqual(position_ids.shape, (3, batch_size, seq_len)) + self.assertEqual(mrope_position_deltas.shape, (batch_size, 1)) + + # Check that position_ids are 3D (t, h, w) for vision tokens + # First 16 tokens should be image (t=0, h in [0,3], w in [0,3]) + image_pos_range = position_ids[:, :, 0:16] + self.assertEqual(image_pos_range.shape, (3, batch_size, 16)) + + # Next 10 text tokens should have 3D IDs (t=0, h=w=0 since no visual) + text_pos_range = position_ids[:, :, 16:26] + self.assertEqual(text_pos_range.shape, (3, batch_size, 10)) + + # Last 32 video tokens should have 3D IDs + video_pos_range = position_ids[:, :, 26:58] + self.assertEqual(video_pos_range.shape, (3, batch_size, 32)) + + def test_compute_3d_position_ids_reuses_cached_rope_deltas(self): + """Test that _compute_3d_position_ids reuses cached rope_deltas for subsequent passes.""" + ptq_config = self._make_ptq_config(grid_thw=(1, 8, 8)) 
+ q_model = QuantQwen3VLModel(self.fp_model, qcfg=ptq_config) + q_model.enable_calibration() + + # First forward pass to compute rope_deltas + input_ids, attention_mask = self._create_text_only_input( + batch_size=1, seq_len=10 + ) + _ = q_model(input_ids=input_ids, attention_mask=attention_mask) + + # rope_deltas should now be cached + self.assertIsNotNone(q_model.rope_deltas) + assert q_model.rope_deltas is not None # for mypy + cached_rope_deltas = q_model.rope_deltas.clone() + + # Simulate autoregressive generation with past_key_values + # This should trigger the else branch that reuses cached rope_deltas + seq_len_new = 15 # Generate 5 more tokens + input_ids_new = torch.randint( + 0, self.vocab_size, size=(1, seq_len_new), dtype=torch.long + ) + attention_mask_new = torch.ones_like(input_ids_new) + + # Create mock past_key_values that simulates previous cache + class MockPastKeyValues: + def __init__(self, seq_length): + self._seq_length = seq_length + + def get_seq_length(self): + return self._seq_length # Previous sequence length + + past_key_values = MockPastKeyValues(seq_length=10) + inputs_embeds_new = torch.randn(1, seq_len_new, self.hidden_size) + + # Call _compute_3d_position_ids with past_key_values (should reuse cached deltas) + position_ids = q_model._compute_3d_position_ids( + input_ids=None, + inputs_embeds=inputs_embeds_new, + image_grid_thw=None, + video_grid_thw=None, + attention_mask=None, + cache_position=None, + past_key_values=past_key_values, + ) + + # Verify rope_deltas were reused (not recomputed) + self.assertTrue(torch.equal(q_model.rope_deltas, cached_rope_deltas)) + + # Verify position_ids shape + assert position_ids is not None # for mypy + self.assertEqual(position_ids.shape, (3, 1, seq_len_new)) + + # Verify position_ids are monotonic (increasing) and properly offset + pos_ids_first = position_ids[0, 0, 0].item() + pos_ids_last = position_ids[0, 0, -1].item() + self.assertGreater(pos_ids_last, pos_ids_first) + 
self.assertGreaterEqual(pos_ids_first, past_key_values.get_seq_length()) + + # ------------------------------------------------------------------------- + # Circle conversion tests + # ------------------------------------------------------------------------- + + def test_graph_tracing_behavior_with_images(self): + """Test that QuantQwen3VLModel behavior in graph tracing mode.""" + import tico + + thw = (1, 8, 8) + ptq_config = self._make_ptq_config(grid_thw=thw) + + # Prepare the model for quantization + prepared_model = tico.quantization.prepare( + copy.deepcopy(self.fp_model), ptq_config, inplace=False + ) + prepared_model.eval() + + # Create example input + ( + input_ids, + attention_mask, + pixel_values, + grid_thw, + position_ids, + ) = self._create_image_input(batch_size=1, seq_len=64, thw=thw) + + # Calibrate with text-only input + with torch.no_grad(): + prepared_model( + input_ids=input_ids, + attention_mask=attention_mask, + pixel_values=pixel_values, + image_grid_thw=grid_thw, + position_ids=position_ids, + ) + + # Convert to quantized model + quantized_model = tico.quantization.convert(prepared_model, inplace=False) + + # Create example input as namedtuple + from collections import namedtuple + + ModelInput = namedtuple( + "ModelInput", + [ + "input_ids", + "attention_mask", + "position_ids", + "past_key_values", + "inputs_embeds", + "pixel_values", + "pixel_values_videos", + "image_grid_thw", + "video_grid_thw", + "cache_position", + ], + ) + + example_input = ModelInput( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=None, + inputs_embeds=None, + pixel_values=pixel_values, + pixel_values_videos=None, + image_grid_thw=grid_thw, + video_grid_thw=None, + cache_position=None, + ) + + # Compute quantization error + quantized_model.wrapped.force_export = True + with torch.no_grad(): + test_input = example_input._asdict() + quant_out = quantized_model(**test_input).last_hidden_state + fp_out = 
self.fp_model(**test_input).last_hidden_state + + err = (quant_out - fp_out).abs().mean().item() + self.assertLess(err, 1.0) + + def test_graph_tracing_behavior_with_videos(self): + """Test that QuantQwen3VLModel behavior in graph tracing mode.""" + import tico + + thw = (1, 8, 8) + ptq_config = self._make_ptq_config(grid_thw=thw) + + # Prepare the model for quantization + prepared_model = tico.quantization.prepare( + copy.deepcopy(self.fp_model), ptq_config, inplace=False + ) + prepared_model.eval() + + # Create example input + ( + input_ids, + attention_mask, + pixel_values_videos, + grid_thw, + position_ids, + ) = self._create_video_input(batch_size=1, seq_len=64, thw=thw) + + # Calibrate with text-only input + with torch.no_grad(): + prepared_model( + input_ids=input_ids, + attention_mask=attention_mask, + pixel_values_videos=pixel_values_videos, + video_grid_thw=grid_thw, + position_ids=position_ids, + ) + + # Convert to quantized model + quantized_model = tico.quantization.convert(prepared_model, inplace=False) + + # Create example input as namedtuple + from collections import namedtuple + + ModelInput = namedtuple( + "ModelInput", + [ + "input_ids", + "attention_mask", + "position_ids", + "past_key_values", + "inputs_embeds", + "pixel_values", + "pixel_values_videos", + "image_grid_thw", + "video_grid_thw", + "cache_position", + ], + ) + + example_input = ModelInput( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=None, + inputs_embeds=None, + pixel_values=None, + pixel_values_videos=pixel_values_videos, + image_grid_thw=None, + video_grid_thw=grid_thw, + cache_position=None, + ) + + # Compute quantization error + quantized_model.wrapped.force_export = True + with torch.no_grad(): + test_input = example_input._asdict() + quant_out = quantized_model(**test_input).last_hidden_state + fp_out = self.fp_model(**test_input).last_hidden_state + + err = (quant_out - fp_out).abs().mean().item() + 
self.assertLess(err, 1.0) + + +if __name__ == "__main__": + unittest.main() diff --git a/tico/quantization/wrapq/examples/qwen/quantize_model.py b/tico/quantization/wrapq/examples/qwen/quantize_model.py new file mode 100644 index 00000000..6e7d2ea4 --- /dev/null +++ b/tico/quantization/wrapq/examples/qwen/quantize_model.py @@ -0,0 +1,357 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Samsung Electronics Co., Ltd. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Example script for quantizing and converting Qwen3VLModel to Circle format. + +This script demonstrates: +1. Loading a Qwen3VL vision-language model +2. Preparing calibration data with text, images, and videos +3. Configuring PTQ (Post-Training Quantization) +4. Calibrating the model to collect statistics +5. Converting to quantized model +6. Evaluating quantization accuracy +7. 
Converting to Circle format for deployment + +Usage: + python quantize_model.py +""" + +import copy +import sys +from collections import namedtuple +from typing import Tuple + +import torch +import torch.nn as nn + +import tico +import tico.quantization +import tico.quantization.config.ptq +from tico.quantization.evaluation.metric import compute_peir +from tico.quantization.evaluation.utils import plot_two_outputs +from tico.quantization.wrapq.utils.version import has_transformers_for + +torch.manual_seed(123) + + +# Check if transformers is available +if not has_transformers_for("qwen3-vl"): + print( + "Error: Required transformers package not installed. Cannot test Qwen3VLModel." + ) + sys.exit(1) + +from transformers.models.qwen3_vl.configuration_qwen3_vl import Qwen3VLConfig +from transformers.models.qwen3_vl.modeling_qwen3_vl import Qwen3VLModel + +Modelnput = namedtuple( + "Modelnput", + [ + "input_ids", + "attention_mask", + "position_ids", + "past_key_values", + "inputs_embeds", + "pixel_values", + "pixel_values_videos", + "image_grid_thw", + "video_grid_thw", + "cache_position", + ], +) + + +def create_visual_input( + seq_len: int, + thw: Tuple[int, int, int], + spatial_merge_size: int, + temporal_patch_size: int, + spatial_patch_size: int, + vocab_size: int, + image_token_id: int, +): + """Helper to create input with videos or images.""" + assert ( + image_token_id >= vocab_size - 2 + ), f"Visual token Id {image_token_id} must be outside text vocabulary range 0...{vocab_size-2}." 
+ + batch_size = 1 + + # Calculate number of visual placeholder tokens needed + # Each video is represented by multiple tokens after spatial merge + # Spatial merge reduces the grid size by spatial_merge_size in each dimension + num_video_tokens = (thw[1] // spatial_merge_size) * (thw[2] // spatial_merge_size) + assert ( + num_video_tokens <= seq_len + ), f"{num_video_tokens} video tokens can't fit into input sequence of length {seq_len}" + + # Create input_ids with random text tokens + input_ids = torch.randint( + low=0, + high=vocab_size - 2, + size=(batch_size, seq_len), + dtype=torch.long, + ) + + # Replace first tokens with video placeholder tokens + # This marks where the video features should be inserted + for i in range(batch_size): + input_ids[i, :num_video_tokens] = image_token_id + + num_temporal_patches, num_spatial_patches_h, num_spatial_patches_w = thw + + # Create pixel values for videos + pixel_values = torch.randn( + batch_size, + 3, + num_temporal_patches * temporal_patch_size, + num_spatial_patches_h * spatial_patch_size, + num_spatial_patches_w * spatial_patch_size, + ) + grid_thw = torch.tensor([thw]) + + # Compute position_ids for 3D RoPE + # This replicates the logic from _get_rope_index but pre-computes it + position_ids = compute_3d_position_ids( + input_ids=input_ids, + thw=thw, + spatial_merge_size=spatial_merge_size, + image_token_id=image_token_id, + ) + + return Modelnput( + input_ids=input_ids, + attention_mask=None, + position_ids=position_ids, + past_key_values=None, + inputs_embeds=None, + pixel_values=pixel_values, + pixel_values_videos=None, + image_grid_thw=grid_thw, + video_grid_thw=None, + cache_position=None, + ) + + +def compute_3d_position_ids( + input_ids: torch.Tensor, + thw: Tuple[int, int, int], + spatial_merge_size: int, + image_token_id: int, +) -> torch.Tensor: + """ + Compute 3D position IDs for multimodal RoPE. + This function pre-computes position_ids to avoid tracing issues during model export. 
+ """ + batch_size, seq_len = input_ids.shape + device = input_ids.device + + position_ids = torch.ones( + 3, batch_size, seq_len, dtype=input_ids.dtype, device=device + ) + + for i in range(batch_size): + # Find positions of image tokens + image_mask = input_ids[i] == image_token_id + image_positions = torch.nonzero(image_mask, as_tuple=True)[0] + + llm_pos_ids_list: list[torch.tensor] = [] + st = 0 + + # Process visual tokens + if len(image_positions) > 0: + # Group consecutive placeholder tokens into a single visual object + # All consecutive image tokens represent ONE image/video + start_pos = image_positions[0].item() + + # Text position IDs (before first visual token) + text_len = start_pos - st + if text_len > 0: + st_idx = ( + llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + ) + llm_pos_ids_list.append( + torch.arange(text_len, device=device).view(1, -1).expand(3, -1) + + st_idx + ) + + # Vision position IDs (3D) + llm_grid_t = 1 # Always 1 for images + llm_grid_h = thw[1] // spatial_merge_size + llm_grid_w = thw[2] // spatial_merge_size + + t_index = ( + torch.arange(llm_grid_t, device=device) + .view(-1, 1) + .expand(-1, llm_grid_h * llm_grid_w) + .flatten() + ) + h_index = ( + torch.arange(llm_grid_h, device=device) + .view(1, -1, 1) + .expand(llm_grid_t, -1, llm_grid_w) + .flatten() + ) + w_index = ( + torch.arange(llm_grid_w, device=device) + .view(1, 1, -1) + .expand(llm_grid_t, llm_grid_h, -1) + .flatten() + ) + st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0 + llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + st_idx) + + # Update st to after all visual placeholder tokens + # The number of visual tokens is (thw[1] // spatial_merge_size) * (thw[2] // spatial_merge_size) + num_visual_tokens = (thw[1] // spatial_merge_size) * ( + thw[2] // spatial_merge_size + ) + st = start_pos + num_visual_tokens + + # Trailing text + if st < seq_len: + st_idx = llm_pos_ids_list[-1].max() + 1 if 
len(llm_pos_ids_list) > 0 else 0 + text_len = seq_len - st + llm_pos_ids_list.append( + torch.arange(text_len, device=device).view(1, -1).expand(3, -1) + st_idx + ) + + llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) + position_ids[..., i, :] = llm_positions + + return position_ids + + +def generate_calibration_data( + batch_size: int, + seq_len: int, + thw: Tuple[int, int, int], + spatial_merge_size: int, + temporal_patch_size: int, + spatial_patch_size: int, + vocab_size: int, + image_token_id: int, +): + calibration_data = [] + for i in range(batch_size): + x = create_visual_input( + seq_len, + thw, + spatial_merge_size, + temporal_patch_size, + spatial_patch_size, + vocab_size, + image_token_id, + ) + calibration_data.append(x) + return calibration_data + + +def main(): + # Create Qwen3VL configuration + cfg = Qwen3VLConfig( + vision_config={ + "hidden_size": 64, + "num_heads": 4, + "depth": 2, # Smaller depth for faster testing + "temporal_patch_size": 2, + "patch_size": 16, + "out_hidden_size": 64, + }, + text_config={ + "hidden_size": 64, + "intermediate_size": 256, + "num_attention_heads": 2, + "num_key_value_heads": 2, + "head_dim": 32, + "num_hidden_layers": 2, + "attention_bias": False, + "attention_dropout": 0.0, + "max_position_embeddings": 1024, + "vocab_size": 1000, + "use_cache": False, + "rope_scaling": {"rope_type": "default", "mrope_section": [1, 1, 2]}, + }, + image_token_id=998, + video_token_id=999, + ) + thw = (1, 8, 8) + + # Configure PTQ + ptq_config = tico.quantization.config.ptq.PTQConfig( + model_args={ + "vision": { + "grid_thw": thw, + } + } + ) + + # Load the model + model = Qwen3VLModel(cfg) + orig_model = copy.deepcopy(model) + model.eval() + + # Prepare the model for quantization + prepared_model = tico.quantization.prepare( + model, ptq_config, inplace=True # Transform the model in place + ) + + # Generate calibration data + calibration_data = generate_calibration_data( + batch_size=10, + seq_len=50, + thw=thw, + 
+        spatial_merge_size=cfg.vision_config.spatial_merge_size,
+        temporal_patch_size=cfg.vision_config.temporal_patch_size,
+        spatial_patch_size=cfg.vision_config.patch_size,
+        vocab_size=cfg.text_config.vocab_size,
+        image_token_id=cfg.image_token_id,
+    )
+
+    # Calibrate the model (collect statistics)
+    with torch.no_grad():
+        for calibration_input in calibration_data:
+            prepared_model(**calibration_input._asdict())
+
+    # Convert to quantized model
+    quantized_model = tico.quantization.convert(prepared_model, inplace=True)
+
+    # Compute quantization error metrics
+    with torch.no_grad():
+        test_input = calibration_data[0]._asdict()
+        test_input["position_ids"] = None
+        quant_out = quantized_model(**test_input).last_hidden_state
+        fp_out = orig_model(**test_input).last_hidden_state
+
+    print(f"┌───────────── Quantization Error Summary ─────────────")
+    print(f"│ Mean |diff|: {(quant_out - fp_out).abs().mean().item():.6f}")
+    print(f"│ PEIR : {compute_peir(fp_out, quant_out) * 100:.6f} %")
+    print(f"└──────────────────────────────────────────────────────")
+    print(plot_two_outputs(fp_out, quant_out))
+
+    # Convert to Circle format
+
+    example_input = calibration_data[0]
+    circle_model = tico.convert(quantized_model.eval(), example_input)
+
+    # Save the Circle model
+    filename = "qwen3vl_model.q.circle"
+    circle_model.save(filename)
+    print(f"Circle model saved as '{filename}'")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/tico/quantization/wrapq/wrappers/nn/quant_layernorm.py b/tico/quantization/wrapq/wrappers/nn/quant_layernorm.py
index cba5010d..c436959a 100644
--- a/tico/quantization/wrapq/wrappers/nn/quant_layernorm.py
+++ b/tico/quantization/wrapq/wrappers/nn/quant_layernorm.py
@@ -50,7 +50,15 @@ def __init__(
     ):
         super().__init__(qcfg, fp_name=fp_name)
         self.module = fp
-        self.eps = torch.tensor(self.module.eps)
+
+        # self.eps = torch.tensor(self.module.eps)
+        # Without registering eps as a buffer (above line) a bunch of warnings is emitted during
torch.export.export, e.g. + # UserWarning: Node wrapped_visual_wrapped_blocks_0_wrapped_norm1_wrapped_eps target + # wrapped.visual.wrapped.blocks.0.wrapped.norm1.wrapped.eps eps of + # wrapped.visual.wrapped.blocks.0.wrapped.norm1.wrapped does not reference + # an nn.Module, nn.Parameter, or buffer, which is what 'get_attr' Nodes typically target + self.register_buffer("eps", torch.tensor(self.module.eps)) + # Number of trailing dims participating in normalization # (PyTorch stores normalized_shape as a tuple even if an int was passed) self._norm_ndim: int = len(fp.normalized_shape) # safe for int→tuple diff --git a/tico/quantization/wrapq/wrappers/qwen_vl/quant_model.py b/tico/quantization/wrapq/wrappers/qwen_vl/quant_model.py new file mode 100644 index 00000000..e2ff167d --- /dev/null +++ b/tico/quantization/wrapq/wrappers/qwen_vl/quant_model.py @@ -0,0 +1,636 @@ +# Copyright (c) 2026 Samsung Electronics Co., Ltd. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from collections import namedtuple +from typing import Iterable, Optional, Union + +import torch +import torch.nn as nn + +from tico.quantization.config.ptq import PTQConfig +from tico.quantization.wrapq.wrappers.ptq_wrapper import PTQWrapper +from tico.quantization.wrapq.wrappers.quant_module_base import QuantModuleBase +from tico.quantization.wrapq.wrappers.registry import try_register + + +@try_register( + "transformers.models.qwen3_vl.modeling_qwen3_vl.Qwen3VLModel", +) +class QuantQwen3VLModel(QuantModuleBase): + """ + Quantization wrapper for Qwen3VLModel module. + + This is the main multimodal model that combines vision and language processing: + - Vision model (Qwen3VLVisionModel): Processes images/videos + - Language model (Qwen3VLTextModel): Processes text and generates outputs + - Multimodal fusion: Combines text and visual embeddings + """ + + # This boolean flag enforces model behavior that is only activated during model graph tracing (torch.export.export). + # This flag is used in unit tests only in order to check the behavior without actually exporting the model. 
+ force_export: bool = False + + rope_deltas: Optional[torch.Tensor] # Type annotation for registered buffer + + def __init__( + self, + fp_model: nn.Module, + *, + qcfg: Optional[PTQConfig] = None, + fp_name: Optional[str] = None, + ): + super().__init__(qcfg, fp_name=fp_name) + self.module = fp_model + + self.image_token_id = fp_model.config.image_token_id + self.video_token_id = fp_model.config.video_token_id + + # Wrap vision model + self.visual = PTQWrapper( + fp_model.visual, + qcfg=qcfg.child("visual") if qcfg else None, + fp_name=f"{fp_name}.visual", + ) + + # Wrap language model + self.language_model = PTQWrapper( + fp_model.language_model, + qcfg=qcfg.child("language_model") if qcfg else None, + fp_name=f"{fp_name}.language_model", + ) + + # Cache for rope_deltas - register as buffer for proper export handling + # persistent=False means it won't be saved in state_dict + self.register_buffer("rope_deltas", None, persistent=False) + + # Multimodal fusion observers (masked scatter results) + self.obs_mm_fusion = self._make_obs("mm_fusion") + + def forward( + self, + input_ids: torch.Tensor | None = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.Tensor | None = None, + past_key_values=None, + inputs_embeds: torch.Tensor | None = None, + pixel_values: torch.Tensor | None = None, + pixel_values_videos: torch.Tensor | None = None, + image_grid_thw: torch.Tensor | None = None, + video_grid_thw: torch.Tensor | None = None, + cache_position: torch.Tensor | None = None, + **kwargs, + ) -> Union[torch.Tensor, tuple]: + """ + Forward pass with fake quantization. 
+ + Args: + input_ids: Input token IDs of shape (batch_size, sequence_length) + attention_mask: Attention mask of shape (batch_size, sequence_length) + position_ids: Position IDs for RoPE + past_key_values: Past key-value caches for autoregressive generation + inputs_embeds: Input embeddings of shape (batch_size, sequence_length, hidden_size) + pixel_values: Image pixel values of shape (batch_size, C, H, W) + pixel_values_videos: Video pixel values + image_grid_thw: Grid dimensions for images of shape (num_images, 3) + video_grid_thw: Grid dimensions for videos + cache_position: Cache positions for generation + **kwargs: Additional keyword arguments + + Returns: + Model output containing last hidden state, past key values, etc. + """ + if torch.compiler.is_compiling() or self.force_export: + assert ( + position_ids is not None + ), "position_ids must be provided as an argument since it's computation cannot be converted to Circle" + + # Validate input + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError( + "You must specify exactly one of input_ids or inputs_embeds" + ) + + # Generate input embeddings from token IDs + if inputs_embeds is None: + assert hasattr(self.language_model.wrapped, "embed_tokens") and isinstance( + self.language_model.wrapped.embed_tokens, PTQWrapper + ) + inputs_embeds = self.language_model.wrapped.embed_tokens(input_ids) + + deepstack_image_embeds: list = None # type: ignore[assignment] + deepstack_video_embeds: list = None # type: ignore[assignment] + + image_mask = None + video_mask = None + + # Process images + if pixel_values is not None: + # Get image features from vision model + image_outputs = self._get_image_features( + pixel_values, image_grid_thw, return_dict=True + ) + image_embeds = image_outputs.pooler_output + deepstack_image_embeds = image_outputs.deepstack_features + + # Concatenate all image features + image_embeds = torch.cat(image_embeds, dim=0).to( + inputs_embeds.device, inputs_embeds.dtype + ) + + 
# Create mask for image placeholder tokens + image_mask, _ = self._get_placeholder_mask( + input_ids=input_ids, + inputs_embeds=inputs_embeds, + image_features=image_embeds, + ) + + # Replace image placeholders with actual visual features + if torch.compiler.is_compiling() or self.force_export: + # If we are exporting the model as a static graph. + # This assumes visual tokens are placed strictly in the beginning of the prompt + # Violating this requirement will lead to a corrupted prompt + inputs_embeds = self._fuse_text_n_image(inputs_embeds, image_embeds) + else: + # This operation cannot be converted to Circle because it produces data-dependent dynamic shapes + inputs_embeds = self._masked_scatter( + inputs_embeds, image_mask, image_embeds + ) + + # Quantize multimodal fusion result + inputs_embeds = self._fq(inputs_embeds, self.obs_mm_fusion) + + # Process videos + if pixel_values_videos is not None: + # Get video features from vision model + video_outputs = self._get_video_features( + pixel_values_videos, video_grid_thw, return_dict=True + ) + video_embeds = video_outputs.pooler_output + deepstack_video_embeds = video_outputs.deepstack_features + + # Concatenate all video features + video_embeds = torch.cat(video_embeds, dim=0).to( + inputs_embeds.device, inputs_embeds.dtype + ) + + # Create mask for video placeholder tokens + _, video_mask = self._get_placeholder_mask( + input_ids=input_ids, + inputs_embeds=inputs_embeds, + video_features=video_embeds, + ) + + # Replace video placeholders with actual visual features + if torch.compiler.is_compiling() or self.force_export: + # If we are exporting the model as a static graph + # This assumes visual tokens are placed strictly in the beginning of the prompt + # Violating this requirement will lead to a corrupted prompt + inputs_embeds = self._fuse_text_n_image(inputs_embeds, video_embeds) + else: + # This operation cannot be converted to Circle because it produces data-dependent dynamic shapes + inputs_embeds = 
self._masked_scatter( + inputs_embeds, video_mask, video_embeds + ) + + # Quantize multimodal fusion result + inputs_embeds = self._fq(inputs_embeds, self.obs_mm_fusion) + + # Combine deepstack features from images and videos + visual_pos_masks = None + deepstack_visual_embeds = None + if image_mask is not None and video_mask is not None: + # Aggregate visual masks and deepstack features + image_mask = image_mask[..., 0] + video_mask = video_mask[..., 0] + visual_pos_masks = image_mask | video_mask + deepstack_visual_embeds = [] + image_mask_joint = image_mask[visual_pos_masks] + video_mask_joint = video_mask[visual_pos_masks] + for img_embed, vid_embed in zip( + deepstack_image_embeds, deepstack_video_embeds + ): + embed_joint = img_embed.new_zeros( + visual_pos_masks.sum(), img_embed.shape[-1] + ).to(img_embed.device) + embed_joint[image_mask_joint, :] = img_embed + embed_joint[video_mask_joint, :] = vid_embed + deepstack_visual_embeds.append(embed_joint) + elif image_mask is not None: + image_mask = image_mask[..., 0] + visual_pos_masks = image_mask + deepstack_visual_embeds = deepstack_image_embeds + elif video_mask is not None: + video_mask = video_mask[..., 0] + visual_pos_masks = video_mask + deepstack_visual_embeds = deepstack_video_embeds + + # Compute 3D position IDs if not provided + # Note: This involves only integer operations, no quantization needed + if position_ids is None: + position_ids = self._compute_3d_position_ids( + input_ids=input_ids, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + cache_position=cache_position, + past_key_values=past_key_values, + ) + + # Pass through language model (wrapped with PTQWrapper) + outputs = self.language_model( + input_ids=None, + position_ids=position_ids, + attention_mask=attention_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + cache_position=cache_position, + visual_pos_masks=visual_pos_masks, + 
deepstack_visual_embeds=deepstack_visual_embeds, + **kwargs, + ) + + # Return output with rope_deltas + from transformers.models.qwen3_vl.modeling_qwen3_vl import ( + Qwen3VLModelOutputWithPast, + ) + + return Qwen3VLModelOutputWithPast( + last_hidden_state=outputs.last_hidden_state, + past_key_values=outputs.past_key_values, + rope_deltas=self.rope_deltas, + ) + + @staticmethod + def _fuse_text_n_image(inputs_embeds, visual_embeds): + num_visual_tokens = visual_embeds.shape[0] + flat_inputs = inputs_embeds.view(-1, inputs_embeds.shape[-1]) + flat_inputs[:num_visual_tokens] = visual_embeds + inputs_embeds = flat_inputs.view_as(inputs_embeds) + return inputs_embeds + + @staticmethod + def _masked_scatter(inputs_embeds, visual_mask, visual_embeds): + # Use indexing assignment instead of masked_scatter for better Circle support + # (TICO can't convert torch.masked_scatter operator) + flat_inputs = inputs_embeds.view(-1, inputs_embeds.shape[-1]) + mask_2d = visual_mask[..., 0] # Get mask for the first dimension only + _, indices = torch.nonzero(mask_2d, as_tuple=True) + flat_inputs[indices] = visual_embeds + inputs_embeds = flat_inputs.view_as(inputs_embeds) + return inputs_embeds + + def _get_image_features( + self, + pixel_values: torch.Tensor, + image_grid_thw: torch.Tensor, + **kwargs, + ): + """Get image features from vision model.""" + # Convert to vision model dtype + pixel_values = pixel_values.type(self.visual.wrapped.module.dtype) + + # Process through vision model + vision_output = self.visual(pixel_values, grid_thw=image_grid_thw, **kwargs) + + # Get pooled output + image_embeds, deepstack_features = ( + (vision_output.pooler_output, vision_output.deepstack_features) + if self.visual.wrapped.has_deepstack_model_output + else vision_output + ) + + # Split by image based on grid_thw + spatial_merge_size = self.visual.wrapped.module.spatial_merge_size + split_sizes = (image_grid_thw.prod(-1) // spatial_merge_size**2).tolist() + image_embeds = 
torch.split(image_embeds, split_sizes) + + OutputWithDeepstackFeatures = namedtuple( + "OutputWithDeepstackFeatures", ["pooler_output", "deepstack_features"] + ) + return OutputWithDeepstackFeatures( + pooler_output=image_embeds, deepstack_features=deepstack_features + ) + + def _get_video_features( + self, + pixel_values_videos: torch.Tensor, + video_grid_thw: torch.Tensor | None = None, + **kwargs, + ): + """Get video features from vision model (same as image processing).""" + return self._get_image_features(pixel_values_videos, video_grid_thw, **kwargs) + + def _get_placeholder_mask( + self, + input_ids: torch.Tensor | None, + inputs_embeds: torch.Tensor, + image_features: torch.Tensor | None = None, + video_features: torch.Tensor | None = None, + ): + """ + Obtain multimodal placeholder mask from input_ids or inputs_embeds. + Validates that placeholder token count matches feature length. + """ + if input_ids is None: + # Compare embeddings directly + embedder = self.language_model.wrapped.embed_tokens + img_tkn = torch.tensor( + self.image_token_id, + dtype=torch.long, + device=inputs_embeds.device, + ) + img_tkn_emb = embedder(img_tkn) + special_image_mask = (inputs_embeds == img_tkn_emb).all(-1) + + vid_tkn = torch.tensor( + self.video_token_id, + dtype=torch.long, + device=inputs_embeds.device, + ) + vid_tkn_emb = embedder(vid_tkn) + special_video_mask = (inputs_embeds == vid_tkn_emb).all(-1) + else: + # Compare token IDs + special_image_mask = input_ids == self.image_token_id + special_video_mask = input_ids == self.video_token_id + + # Count image tokens + n_image_tokens = special_image_mask.sum() + special_image_mask = ( + special_image_mask.unsqueeze(-1) + .expand_as(inputs_embeds) + .to(inputs_embeds.device) + ) + if image_features is not None: + # Validate image tokens count matches features + if not torch.compiler.is_compiling() and not self.force_export: + assert ( + inputs_embeds[special_image_mask].numel() == image_features.numel() + ), f"Image 
features ({image_features.shape[0]}) and image tokens ({n_image_tokens}) do not match" + + # Count video tokens + n_video_tokens = special_video_mask.sum() + special_video_mask = ( + special_video_mask.unsqueeze(-1) + .expand_as(inputs_embeds) + .to(inputs_embeds.device) + ) + if video_features is not None: + # Validate video tokens count matches features + assert ( + inputs_embeds[special_video_mask].numel() == video_features.numel() + ), f"Video features ({video_features.shape[0]}) and video tokens ({n_video_tokens}) do not match" + + return special_image_mask, special_video_mask + + def _compute_3d_position_ids( + self, + input_ids: torch.Tensor | None, + inputs_embeds: torch.Tensor | None, + image_grid_thw: torch.Tensor | None = None, + video_grid_thw: torch.Tensor | None = None, + attention_mask: torch.Tensor | None = None, + cache_position: torch.Tensor | None = None, + past_key_values=None, + ) -> torch.Tensor | None: + """ + Compute 3D position IDs for multimodal RoPE. + Note: This involves only integer operations, no quantization needed. 
+ """ + past_key_values_length = ( + 0 if past_key_values is None else past_key_values.get_seq_length() + ) + + if self.rope_deltas is None or past_key_values_length == 0: + position_ids, rope_deltas = self._get_rope_index( + input_ids=input_ids, + image_grid_thw=image_grid_thw, + video_grid_thw=video_grid_thw, + attention_mask=attention_mask, + ) + self.rope_deltas = rope_deltas + # then use the prev pre-calculated rope-deltas to get the correct position ids + else: + # Type narrowing for mypy: assign to local non-None variable + assert inputs_embeds is not None + _rope_deltas: torch.Tensor = self.rope_deltas # type: ignore[assignment] + batch_size, seq_length, _ = inputs_embeds.shape + delta = (past_key_values_length + _rope_deltas).to(inputs_embeds.device) + position_ids = torch.arange(seq_length, device=inputs_embeds.device) + position_ids = position_ids.view(1, -1).expand(batch_size, -1) + if cache_position is not None: # otherwise `deltas` is an int `0` + delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0) + position_ids = position_ids.add(delta) + position_ids = position_ids.unsqueeze(0).expand(3, -1, -1) + + return position_ids + + def _get_rope_index( + self, + input_ids: torch.Tensor, + image_grid_thw: torch.Tensor, + video_grid_thw: torch.Tensor, + attention_mask: torch.Tensor, + ): + """Calculate 3D rope index based on image and video sizes.""" + # Since we use timestamps to separate videos, video_grid_thw should be split + if video_grid_thw is not None: + video_grid_thw = torch.repeat_interleave( + video_grid_thw, video_grid_thw[:, 0], dim=0 + ) + video_grid_thw[:, 0] = 1 + + spatial_merge_size = self.visual.wrapped.module.spatial_merge_size + image_token_id = self.module.config.image_token_id + video_token_id = self.module.config.video_token_id + vision_start_token_id = self.module.config.vision_start_token_id + + mrope_position_deltas = [] + + if input_ids is not None and ( + image_grid_thw is not None or video_grid_thw is not None + 
): + total_input_ids = input_ids + if attention_mask is None: + attention_mask = torch.ones_like(total_input_ids) + + position_ids = torch.ones( + 3, + input_ids.shape[0], + input_ids.shape[1], + dtype=input_ids.dtype, + device=input_ids.device, + ) + + image_index, video_index = 0, 0 + attention_mask = attention_mask.to(total_input_ids.device) + + for i, current_input_ids in enumerate(total_input_ids): + input_ids = current_input_ids[attention_mask[i] == 1] + + # Count images and videos + image_nums, video_nums = 0, 0 + vision_start_indices = torch.argwhere( + input_ids == vision_start_token_id + ).squeeze(1) + vision_tokens = input_ids[vision_start_indices + 1] + image_nums = (vision_tokens == image_token_id).sum() + video_nums = (vision_tokens == video_token_id).sum() + + input_tokens = input_ids.tolist() + llm_pos_ids_list: list = [] + st = 0 + remain_images, remain_videos = image_nums, video_nums + + for _ in range(image_nums + video_nums): + # Find next image or video token + if image_token_id in input_tokens and remain_images > 0: + ed_image = input_tokens.index(image_token_id, st) + else: + ed_image = len(input_tokens) + 1 + + if video_token_id in input_tokens and remain_videos > 0: + ed_video = input_tokens.index(video_token_id, st) + else: + ed_video = len(input_tokens) + 1 + + if ed_image < ed_video: + # This is an image + t, h, w = ( + image_grid_thw[image_index][0], + image_grid_thw[image_index][1], + image_grid_thw[image_index][2], + ) + image_index += 1 + remain_images -= 1 + ed = ed_image + else: + # This is a video + t, h, w = ( + video_grid_thw[video_index][0], + video_grid_thw[video_index][1], + video_grid_thw[video_index][2], + ) + video_index += 1 + remain_videos -= 1 + ed = ed_video + + llm_grid_t, llm_grid_h, llm_grid_w = ( + t.item(), + h.item() // spatial_merge_size, + w.item() // spatial_merge_size, + ) + + text_len = ed - st + + # Text position IDs + st_idx = ( + llm_pos_ids_list[-1].max() + 1 + if len(llm_pos_ids_list) > 0 + else 0 + ) + 
llm_pos_ids_list.append( + torch.arange(text_len, device=input_ids.device) + .view(1, -1) + .expand(3, -1) + + st_idx + ) + + # Vision position IDs (3D) + # t_index is always 0 because llm_grid_t is always 1 + # (we use timestamps to encode temporal information for videos) + t_index = ( + torch.arange(llm_grid_t, device=input_ids.device) + .view(-1, 1) + .expand(-1, llm_grid_h * llm_grid_w) + .flatten() + ) + h_index = ( + torch.arange(llm_grid_h, device=input_ids.device) + .view(1, -1, 1) + .expand(llm_grid_t, -1, llm_grid_w) + .flatten() + ) + w_index = ( + torch.arange(llm_grid_w, device=input_ids.device) + .view(1, 1, -1) + .expand(llm_grid_t, llm_grid_h, -1) + .flatten() + ) + llm_pos_ids_list.append( + torch.stack([t_index, h_index, w_index]) + text_len + st_idx + ) + + st = ed + llm_grid_t * llm_grid_h * llm_grid_w + + # Trailing text after all images/videos + if st < len(input_tokens): + st_idx = ( + llm_pos_ids_list[-1].max() + 1 + if len(llm_pos_ids_list) > 0 + else 0 + ) + text_len = len(input_tokens) - st + llm_pos_ids_list.append( + torch.arange(text_len, device=input_ids.device) + .view(1, -1) + .expand(3, -1) + + st_idx + ) + + llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) + position_ids[..., i, attention_mask[i] == 1] = llm_positions.to( + position_ids.device + ) + mrope_position_deltas.append( + llm_positions.max() + 1 - len(total_input_ids[i]) + ) + + mrope_position_deltas = torch.tensor( + mrope_position_deltas, device=input_ids.device + ).unsqueeze(1) + return position_ids, mrope_position_deltas + else: + # Fallback for text-only input + if attention_mask is not None: + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + position_ids = ( + position_ids.unsqueeze(0) + .expand(3, -1, -1) + .to(attention_mask.device) + ) + max_position_ids = position_ids.max(0, keepdim=False)[0].max( + -1, keepdim=True + )[0] + mrope_position_deltas = max_position_ids + 1 - 
attention_mask.shape[-1] + else: + position_ids = ( + torch.arange(input_ids.shape[1], device=input_ids.device) + .view(1, 1, -1) + .expand(3, input_ids.shape[0], -1) + ) + mrope_position_deltas = torch.zeros( + [input_ids.shape[0], 1], + device=input_ids.device, + dtype=input_ids.dtype, + ) + + return position_ids, mrope_position_deltas + + def _all_observers(self) -> Iterable: + """Yield all observers from this module and wrapped submodules.""" + # Local observers + yield self.obs_mm_fusion diff --git a/tico/quantization/wrapq/wrappers/qwen_vl/quant_text_attn.py b/tico/quantization/wrapq/wrappers/qwen_vl/quant_text_attn.py index d9e09cf2..c3d81fda 100644 --- a/tico/quantization/wrapq/wrappers/qwen_vl/quant_text_attn.py +++ b/tico/quantization/wrapq/wrappers/qwen_vl/quant_text_attn.py @@ -203,7 +203,8 @@ def forward( attention_mask = self.causal_mask_template[..., :q_len, :k_len].to( hidden.device ) - attention_mask = self._fq(attention_mask, self.obs_causal_mask) + if torch.is_floating_point(attention_mask): + attention_mask = self._fq(attention_mask, self.obs_causal_mask) attn_weights_parts = [] attn_out_parts = [] diff --git a/tico/quantization/wrapq/wrappers/qwen_vl/quant_text_model.py b/tico/quantization/wrapq/wrappers/qwen_vl/quant_text_model.py index a3b72a3b..e60207e8 100644 --- a/tico/quantization/wrapq/wrappers/qwen_vl/quant_text_model.py +++ b/tico/quantization/wrapq/wrappers/qwen_vl/quant_text_model.py @@ -301,7 +301,8 @@ def forward( # past_key_values=past_key_values, # position_ids=text_position_ids, # ) - attention_mask = self._fq(attention_mask, self.obs_attention_mask) + if torch.is_floating_point(attention_mask): + attention_mask = self._fq(attention_mask, self.obs_attention_mask) hidden_states = inputs_embeds diff --git a/tico/quantization/wrapq/wrappers/registry.py b/tico/quantization/wrapq/wrappers/registry.py index e8c71da7..aa6d1cff 100644 --- a/tico/quantization/wrapq/wrappers/registry.py +++ b/tico/quantization/wrapq/wrappers/registry.py 
@@ -77,6 +77,7 @@ "tico.quantization.wrapq.wrappers.qwen_vl.quant_vision_patch_merger", "tico.quantization.wrapq.wrappers.qwen_vl.quant_vision_block", "tico.quantization.wrapq.wrappers.qwen_vl.quant_vision_model", + "tico.quantization.wrapq.wrappers.qwen_vl.quant_model", # add future core wrappers here )