
Commit c832fad

Author: d-savchenkov
[quantization] Introduce wrapper for Qwen3VLVisionModel
This change introduces the QuantQwen3VLVisionModel wrapper to support post-training quantization of Qwen3VLVisionModel.

TICO-DCO-1.0-Signed-off-by: d.savchenkov <d.savchenkov@partner.samsung.com>
1 parent fc67e35 · commit c832fad

5 files changed: 1013 additions & 0 deletions
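
For context, the calibrate-then-freeze flow that the new tests exercise looks roughly like the sketch below. This is a minimal sketch assembled only from calls that appear in the test file that follows: the tiny Qwen3VLVisionConfig mirrors the one built in setUpClass, and the fp_name label "visual" is an arbitrary choice, not an API requirement.

import torch
from transformers.models.qwen3_vl.configuration_qwen3_vl import Qwen3VLVisionConfig
from transformers.models.qwen3_vl.modeling_qwen3_vl import Qwen3VLVisionModel

from tico.quantization.config.ptq import PTQConfig
from tico.quantization.wrapq.wrappers.qwen_vl.quant_vision_model import (
    QuantQwen3VLVisionModel,
)

# Tiny float model, mirroring the config used in setUpClass of the tests.
cfg = Qwen3VLVisionConfig(
    hidden_size=64, num_heads=4, depth=2, temporal_patch_size=2, patch_size=16
)
cfg._attn_implementation = "eager"  # deterministic CPU attention
fp_model = Qwen3VLVisionModel(cfg)

# The wrapper precomputes RoPE and position-embedding templates for a fixed
# vision grid, so the (t, h, w) grid must be supplied via the PTQ config.
qcfg = PTQConfig()
qcfg.vision_grid_thw = [1, 8, 8]  # t * h * w = 64 patches
q_model = QuantQwen3VLVisionModel(fp_model, qcfg=qcfg, fp_name="visual")

# Calibrate on inputs with the same grid, then freeze quantization params.
grid_thw = torch.tensor([[1, 8, 8]])
patches = torch.randn(64, 3 * 2 * 16 * 16)  # (num_patches, C * T * P * P)
q_model.enable_calibration()
_ = q_model(patches, grid_thw)
q_model.freeze_qparams()

with torch.no_grad():
    q_out = q_model(patches, grid_thw)
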

Lines changed: 372 additions & 0 deletions
@@ -0,0 +1,372 @@
# Copyright (c) 2026 Samsung Electronics Co., Ltd. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
import unittest
from typing import Tuple

import torch

from tico.quantization.config.ptq import PTQConfig
from tico.quantization.wrapq.mode import Mode
from tico.quantization.wrapq.utils.version import has_transformers_for
from tico.quantization.wrapq.wrappers.qwen_vl.quant_vision_model import (
    QuantQwen3VLVisionModel,
)


skip_msg = "transformers not installed — skipping Qwen3VLVisionModel tests"


@unittest.skipUnless(has_transformers_for("qwen3-vl"), skip_msg)
class TestQuantQwen3VLVisionModel(unittest.TestCase):
    fp_model: torch.nn.Module
    hidden_size: int
    num_heads: int
    head_dim: int
    theta: float

    @classmethod
    def setUpClass(cls):
        from transformers.models.qwen3_vl.configuration_qwen3_vl import (
            Qwen3VLVisionConfig,
        )
        from transformers.models.qwen3_vl.modeling_qwen3_vl import Qwen3VLVisionModel

        # Use smaller sizes for testing
        cfg = Qwen3VLVisionConfig(
            hidden_size=64,
            num_heads=4,
            depth=2,  # Smaller depth for faster testing
            temporal_patch_size=2,
            patch_size=16,
        )

        # Ensure eager attention implementation so outputs are deterministic
        # and do not require GPU flash attention kernels.
        # Some versions use `_attn_implementation`, others expose `attn_implementation`.
        if not hasattr(cfg, "_attn_implementation"):
            setattr(cfg, "_attn_implementation", "eager")
        else:
            cfg._attn_implementation = "eager"

        cls.fp_model = Qwen3VLVisionModel(cfg)
        cls.hidden_size = cfg.hidden_size
        cls.num_heads = cfg.num_heads
        cls.head_dim = cls.hidden_size // cls.num_heads
        cls.theta = (
            cls.fp_model.rotary_pos_emb.theta
            if hasattr(cls.fp_model.rotary_pos_emb, "theta")
            else 10000.0
        )

    def _create_test_inputs(
        self, grid_thw: Tuple[int, int, int] = (1, 8, 8)
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Helper to create test inputs for VisionModel."""
        t, h, w = grid_thw
        num_patches = t * h * w
        # Input shape: (seq_len, in_channels * temporal_patch_size * patch_size * patch_size)
        hidden_states = torch.randn(
            num_patches, 3 * 2 * 16 * 16
        )  # 3 channels, 2 temporal, 16x16 patches
        grid_tensor = torch.tensor([grid_thw])
        return hidden_states, grid_tensor

    def test_get_vision_grid_thw_from_config(self):
        """Test _get_vision_grid_thw static method with valid config."""
        # Test with valid config
        ptq_config = PTQConfig()
        setattr(ptq_config, "vision_grid_thw", [1, 8, 8])

        grid_thw = QuantQwen3VLVisionModel._get_vision_grid_thw(ptq_config)
        expected = torch.tensor([[1, 8, 8]])
        self.assertTrue(torch.equal(grid_thw, expected))
        self.assertEqual(grid_thw.shape, (1, 3))

    def test_get_vision_grid_thw_missing_config(self):
        """Test _get_vision_grid_thw raises error when config is missing."""
        # Test with None config
        with self.assertRaises(ValueError) as context:
            QuantQwen3VLVisionModel._get_vision_grid_thw(None)
        self.assertIn("vision_grid_thw must be specified", str(context.exception))

        # Test with config without vision_grid_thw
        ptq_config = PTQConfig()
        with self.assertRaises(ValueError) as context:
            QuantQwen3VLVisionModel._get_vision_grid_thw(ptq_config)
        self.assertIn("vision_grid_thw must be specified", str(context.exception))

    def test_precompute_rope_inv_freq(self):
        """Test _precompute_rope_inv_freq static method."""
        dim = 32
        theta = 10000.0
        inv_freq = QuantQwen3VLVisionModel._precompute_rope_inv_freq(dim, theta)

        self.assertEqual(inv_freq.shape, (dim // 2,))
        self.assertTrue(torch.all(inv_freq > 0))
        # Check that frequencies are decreasing
        self.assertTrue(torch.all(inv_freq[:-1] >= inv_freq[1:]))

    def test_precompute_cu_seqlens(self):
        """Test _precompute_cu_seqlens static method."""
        grid_thw = torch.tensor(
            [[1, 8, 8], [2, 4, 4]]
        )  # 1*8*8 + 2*4*4 = 96 total patches
        cu_seqlens = QuantQwen3VLVisionModel._precompute_cu_seqlens(grid_thw)

        self.assertEqual(cu_seqlens.shape, (4,))  # 3 images + 1 padding
        self.assertEqual(cu_seqlens[0].item(), 0)
        self.assertEqual(cu_seqlens[1].item(), 64)  # 1st image: 1*8*8 = 64 patches
        self.assertEqual(cu_seqlens[2].item(), 80)  # 2nd image: 1*4*4 = 16 patches
        self.assertEqual(
            cu_seqlens[3].item(), 96
        )  # 3rd image: 1*4*4 = 16 patches, total 96

    def test_precompute_rope_position_embeddings(self):
        """Test _precompute_rope_position_embeddings static method."""
        grid_thw = torch.tensor([[1, 8, 8]])
        inv_freq = QuantQwen3VLVisionModel._precompute_rope_inv_freq(
            dim=self.head_dim // 2,
            theta=self.theta,
        )

        cos_t, sin_t = QuantQwen3VLVisionModel._precompute_rope_position_embeddings(
            merge_size=2,
            rope_inv_freq=inv_freq,
            grid_thw=grid_thw,
        )

        expected_patches = math.prod(grid_thw[0].tolist())  # t * h * w = 1 * 8 * 8 = 64
        self.assertEqual(cos_t.shape, (expected_patches, self.head_dim))
        self.assertEqual(sin_t.shape, (expected_patches, self.head_dim))

    def test_rot_pos_emb(self):
        """Test _rot_pos_emb static method."""
        grid_thw = torch.tensor([[1, 8, 8]])
        inv_freq = QuantQwen3VLVisionModel._precompute_rope_inv_freq(
            dim=self.head_dim // 2,
            theta=self.theta,
        )

        rotary_pos_emb = QuantQwen3VLVisionModel._rot_pos_emb(2, inv_freq, grid_thw)

        expected_patches = math.prod(grid_thw[0].tolist())  # t * h * w = 1 * 8 * 8 = 64
        self.assertEqual(rotary_pos_emb.shape, (expected_patches, self.head_dim // 2))

    def test_create_freq_table(self):
        """Test _create_freq_table static method."""
        seqlen = 64
        inv_freq = torch.randn(16)  # dim//2 = 32//2 = 16
        freq_table = QuantQwen3VLVisionModel._create_freq_table(seqlen, inv_freq)

        self.assertEqual(freq_table.shape, (seqlen, inv_freq.shape[0]))

    def test_fast_pos_embed_interpolate(self):
        """Test _fast_pos_embed_interpolate static method."""
        grid_thw = torch.tensor([[1, 8, 8]])
        pos_embeds = QuantQwen3VLVisionModel._fast_pos_embed_interpolate(
            merge_size=2,
            num_grid_per_side=48,  # From model config
            pos_embedder=self.fp_model.pos_embed,
            grid_thw=grid_thw,
        )

        expected_patches = math.prod(grid_thw[0].tolist())  # t * h * w = 1 * 8 * 8 = 64
        self.assertEqual(pos_embeds.shape, (expected_patches, self.hidden_size))

    def test_init_with_valid_config(self):
        """Test successful initialization with valid config."""
        ptq_config = PTQConfig()
        setattr(ptq_config, "vision_grid_thw", [1, 8, 8])

        q_model = QuantQwen3VLVisionModel(
            self.fp_model, qcfg=ptq_config, fp_name="test_model"
        )

        # Check that buffers are registered
        self.assertTrue(hasattr(q_model, "cu_seqlens_template"))
        self.assertTrue(hasattr(q_model, "pos_embed_template"))
        self.assertTrue(hasattr(q_model, "rope_inv_freq"))
        self.assertTrue(hasattr(q_model, "rope_cos_template"))
        self.assertTrue(hasattr(q_model, "rope_sin_template"))

        # Check submodule wrapping
        self.assertIsNotNone(q_model.patch_embed)
        self.assertEqual(len(q_model.blocks), len(self.fp_model.blocks))
        self.assertIsNotNone(q_model.merger)
        self.assertEqual(
            len(q_model.deepstack_merger_list), len(self.fp_model.deepstack_merger_list)
        )

    def test_init_missing_vision_grid_thw(self):
        """Test initialization fails without vision_grid_thw."""
        ptq_config = PTQConfig()

        with self.assertRaises(ValueError) as context:
            QuantQwen3VLVisionModel(
                self.fp_model, qcfg=ptq_config, fp_name="test_model"
            )
        self.assertIn("vision_grid_thw must be specified", str(context.exception))

    def test_mode_transitions(self):
        """Test quantization mode transitions: NO_QUANT → CALIB → QUANT"""
        ptq_config = PTQConfig()
        setattr(ptq_config, "vision_grid_thw", [1, 8, 8])
        q_model = QuantQwen3VLVisionModel(
            self.fp_model, qcfg=ptq_config, fp_name="test_model"
        )
        self.assertIs(q_model._mode, Mode.NO_QUANT)

        q_model.enable_calibration()
        self.assertIs(q_model._mode, Mode.CALIB)

        # Run forward pass during calibration
        hidden_states, grid_thw = self._create_test_inputs((1, 8, 8))
        _ = q_model(hidden_states, grid_thw)

        q_model.freeze_qparams()
        self.assertIs(q_model._mode, Mode.QUANT)

    def test_forward_grid_mismatch_during_calibration(self):
        """Test forward pass fails with mismatched grid_thw during calibration."""
        ptq_config = PTQConfig()
        setattr(ptq_config, "vision_grid_thw", [1, 8, 8])
        q_model = QuantQwen3VLVisionModel(
            self.fp_model, qcfg=ptq_config, fp_name="test_model"
        )
        q_model.enable_calibration()

        # Try with different grid
        hidden_states, grid_thw = self._create_test_inputs((1, 4, 4))

        with self.assertRaises(AssertionError) as context:
            _ = q_model(hidden_states, grid_thw)
        self.assertIn("grid_thw", str(context.exception))

    def test_observer_count(self):
        """Test that the wrapper has the correct number of observers."""
        ptq_config = PTQConfig()
        setattr(ptq_config, "vision_grid_thw", [1, 8, 8])
        q_model = QuantQwen3VLVisionModel(
            self.fp_model, qcfg=ptq_config, fp_name="test_model"
        )

        observers = list(q_model._all_observers())
        # Should have 4 local observers: pos_embeds, pos_add, rope_cos, rope_sin
        self.assertEqual(len(observers), 4)

    def test_precomputed_embeddings_shape(self):
        """Test that precomputed embeddings have correct shapes."""
        ptq_config = PTQConfig()
        setattr(ptq_config, "vision_grid_thw", [1, 8, 8])
        q_model = QuantQwen3VLVisionModel(
            self.fp_model, qcfg=ptq_config, fp_name="test_model"
        )

        expected_patches = math.prod(
            getattr(ptq_config, "vision_grid_thw")
        )  # t * h * w = 1 * 8 * 8 = 64

        # Check position embeddings
        self.assertEqual(
            q_model.pos_embed_template.shape, (expected_patches, self.hidden_size)
        )

        # Check RoPE embeddings
        self.assertEqual(
            q_model.rope_cos_template.shape,
            (expected_patches, self.head_dim),
        )
        self.assertEqual(
            q_model.rope_sin_template.shape,
            (expected_patches, self.head_dim),
        )

        # Check cumulative sequence lengths
        self.assertEqual(q_model.cu_seqlens_template.shape, (2,))  # 1 image + 1 padding

    def test_registration_in_registry(self):
        """Test that Qwen3VLVisionModel is properly registered."""
        from tico.quantization.wrapq.wrappers.registry import lookup
        from transformers.models.qwen3_vl.modeling_qwen3_vl import Qwen3VLVisionModel

        wrapper_cls = lookup(Qwen3VLVisionModel)
        self.assertIs(wrapper_cls, QuantQwen3VLVisionModel)

    def test_output_structure(self):
        """Test that output has correct structure."""
        ptq_config = PTQConfig()
        setattr(ptq_config, "vision_grid_thw", [1, 8, 8])
        q_model = QuantQwen3VLVisionModel(
            self.fp_model, qcfg=ptq_config, fp_name="test_model"
        )
        q_model.enable_calibration()

        hidden_states, grid_thw = self._create_test_inputs((1, 8, 8))
        _ = q_model(hidden_states, grid_thw)

        q_model.freeze_qparams()

        with torch.no_grad():
            q_out = q_model(hidden_states, grid_thw)

        # Check shapes
        expected_patches = math.prod(
            getattr(ptq_config, "vision_grid_thw")
        )  # t * h * w = 1 * 8 * 8

        # The structure of q_out depends on transformers version
        merged_hidden_states = (
            q_out.pooler_output if q_model.has_deepstack_model_output else q_out[0]
        )

        self.assertEqual(merged_hidden_states.shape[0], expected_patches // 4)

    def test_different_grid_sizes(self):
        """Test with different grid sizes."""
        test_cases = [
            ((1, 4, 4), "small_image"),
            ((1, 6, 6), "medium_image"),
            ((1, 8, 8), "large_image"),
        ]

        grid_thw_list: tuple[int, int, int]
        description: str
        for grid_thw_list, description in test_cases:
            with self.subTest(description=description):
                ptq_config = PTQConfig()
                setattr(ptq_config, "vision_grid_thw", grid_thw_list)
                q_model = QuantQwen3VLVisionModel(
                    self.fp_model, qcfg=ptq_config, fp_name=f"test_model_{description}"
                )

                hidden_states, grid_thw = self._create_test_inputs(grid_thw_list)

                q_model.enable_calibration()
                _ = q_model(hidden_states, grid_thw)
                q_model.freeze_qparams()

                with torch.no_grad():
                    q_out = q_model(hidden_states, grid_thw)

                # The structure of q_out depends on transformers version
                merged_hidden_states = (
                    q_out.pooler_output
                    if q_model.has_deepstack_model_output
                    else q_out[0]
                )

                expected_patches = math.prod(grid_thw_list)  # t * h * w
                self.assertEqual(merged_hidden_states.shape[0], expected_patches // 4)
