Skip to content

Commit f5d54ff

Browse files
author
d.savchenkov
committed
[quantization] Introduce wrapper for Qwen3VLVisionRotaryEmbedding
This change introduces QuantQwen3VLVisionRotaryEmbedding wrapper to support post-training quantization of Qwen3VLVisionRotaryEmbedding module. TICO-DCO-1.0-Signed-off-by: d.savchenkov <d.savchenkov@partner.samsung.com>
1 parent 446dafb commit f5d54ff

4 files changed

Lines changed: 404 additions & 0 deletions

File tree

Lines changed: 224 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,224 @@
1+
# Copyright (c) 2026 Samsung Electronics Co., Ltd. All Rights Reserved
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import importlib.util
16+
import unittest
17+
18+
import torch
19+
from tico.quantization.config.ptq import PTQConfig
20+
from tico.quantization.wrapq.dtypes import DType
21+
from tico.quantization.wrapq.mode import Mode
22+
from tico.quantization.wrapq.wrappers.qwen_vl.quant_vision_rotary_embedding import (
23+
QuantQwen3VLVisionRotaryEmbedding,
24+
)
25+
26+
27+
trans_spec = importlib.util.find_spec("transformers")
28+
skip_msg = "transformers not installed — skipping Qwen3VLVisionRotaryEmbedding tests"
29+
30+
31+
@unittest.skipUnless(trans_spec, skip_msg)
32+
class TestQuantQwen3VLVisionRotaryEmbedding(unittest.TestCase):
33+
fp_rope: torch.nn.Module
34+
dim: int
35+
theta: float
36+
37+
@classmethod
38+
def setUpClass(cls):
39+
from transformers.models.qwen3_vl.modeling_qwen3_vl import (
40+
Qwen3VLVisionRotaryEmbedding,
41+
)
42+
43+
# Use smaller dim for testing (typically 128 for head_dim=64)
44+
cls.fp_rope = Qwen3VLVisionRotaryEmbedding(dim=64)
45+
cls.dim = 64
46+
cls.theta = 10000.0
47+
48+
def test_mode_transitions(self):
49+
"""Test quantization mode transitions: NO_QUANT → CALIB → QUANT"""
50+
q_rope = QuantQwen3VLVisionRotaryEmbedding(self.fp_rope)
51+
self.assertIs(q_rope._mode, Mode.NO_QUANT)
52+
53+
q_rope.enable_calibration()
54+
self.assertIs(q_rope._mode, Mode.CALIB)
55+
56+
# Run forward pass during calibration
57+
seqlen = 128
58+
_ = q_rope(seqlen)
59+
60+
q_rope.freeze_qparams()
61+
self.assertIs(q_rope._mode, Mode.QUANT)
62+
63+
def test_quantised_output_close(self):
64+
"""
65+
Test that quantized output is acceptably close to FP32 reference.
66+
"""
67+
torch.manual_seed(42)
68+
q_rope = QuantQwen3VLVisionRotaryEmbedding(self.fp_rope)
69+
q_rope.enable_calibration()
70+
71+
# Calibrate with different sequence lengths
72+
for seqlen in [64, 128, 256]:
73+
_ = q_rope(seqlen)
74+
75+
q_rope.freeze_qparams()
76+
77+
seqlen = 128
78+
with torch.no_grad():
79+
q_out = q_rope(seqlen)
80+
fp_out = self.fp_rope(seqlen)
81+
82+
diff = (fp_out - q_out).abs().mean().item()
83+
self.assertGreater(diff, 0.0) # not identical
84+
self.assertLess(diff, 0.4) # acceptably close
85+
self.assertEqual(fp_out.shape, q_out.shape)
86+
87+
def test_output_shape(self):
88+
"""
89+
Test that output shape is correct: (seqlen, dim/2)
90+
"""
91+
q_rope = QuantQwen3VLVisionRotaryEmbedding(self.fp_rope)
92+
q_rope.enable_calibration()
93+
94+
for seqlen in [64, 128, 256]:
95+
q_rope.enable_calibration()
96+
_ = q_rope(seqlen)
97+
98+
q_rope.freeze_qparams()
99+
100+
seqlen = 128
101+
with torch.no_grad():
102+
q_out = q_rope(seqlen)
103+
104+
expected_shape = (seqlen, self.dim // 2)
105+
self.assertEqual(q_out.shape, expected_shape)
106+
107+
def test_different_sequence_lengths(self):
108+
"""
109+
Test that quantization works correctly with different sequence lengths.
110+
"""
111+
q_rope = QuantQwen3VLVisionRotaryEmbedding(self.fp_rope)
112+
q_rope.enable_calibration()
113+
114+
# Calibrate with one length
115+
for _ in range(3):
116+
_ = q_rope(256)
117+
118+
q_rope.freeze_qparams()
119+
120+
# Test with different lengths
121+
for seqlen in [2, 4, 8, 16, 32, 64, 128, 256]:
122+
with torch.no_grad():
123+
q_out = q_rope(seqlen)
124+
fp_out = self.fp_rope(seqlen)
125+
126+
diff = (fp_out - q_out).abs().mean().item()
127+
self.assertLess(diff, 0.4)
128+
self.assertEqual(q_out.shape[0], seqlen)
129+
self.assertEqual(q_out.shape[1], self.dim // 2)
130+
131+
def test_dtype_override(self):
132+
"""
133+
PTQConfig overrides should affect the output observer.
134+
"""
135+
cfg = PTQConfig(
136+
default_dtype=DType.uint(8),
137+
overrides={
138+
"output": {"dtype": DType.uint(4)},
139+
},
140+
)
141+
q_rope = QuantQwen3VLVisionRotaryEmbedding(self.fp_rope, qcfg=cfg)
142+
143+
self.assertEqual(q_rope.obs_output.dtype, DType.uint(4))
144+
145+
def test_activation_stats_collected(self):
146+
"""
147+
Test that activation statistics are properly collected during calibration.
148+
"""
149+
q_rope = QuantQwen3VLVisionRotaryEmbedding(self.fp_rope)
150+
q_rope.enable_calibration()
151+
152+
# Run forward pass to collect stats
153+
seqlen = 128
154+
_ = q_rope(seqlen)
155+
156+
# Check that observer has collected stats
157+
self.assertTrue(q_rope.obs_output.min_val.numel() > 0)
158+
159+
# Freeze and check qparams exist
160+
q_rope.freeze_qparams()
161+
self.assertTrue(q_rope.obs_output.has_qparams)
162+
163+
def test_observer_count(self):
164+
"""
165+
Test that the wrapper has the correct number of observers.
166+
Only 1 observer (output) since there are no learnable parameters.
167+
"""
168+
q_rope = QuantQwen3VLVisionRotaryEmbedding(self.fp_rope)
169+
170+
observers = list(q_rope._all_observers())
171+
self.assertEqual(len(observers), 1)
172+
173+
def test_registration_in_registry(self):
174+
"""
175+
Test that Qwen3VLVisionRotaryEmbedding is properly registered.
176+
"""
177+
from tico.quantization.wrapq.wrappers.qwen_vl.quant_vision_rotary_embedding import (
178+
QuantQwen3VLVisionRotaryEmbedding,
179+
)
180+
from tico.quantization.wrapq.wrappers.registry import lookup
181+
from transformers.models.qwen3_vl.modeling_qwen3_vl import (
182+
Qwen3VLVisionRotaryEmbedding,
183+
)
184+
185+
wrapper_cls = lookup(Qwen3VLVisionRotaryEmbedding)
186+
self.assertIs(wrapper_cls, QuantQwen3VLVisionRotaryEmbedding)
187+
188+
def test_no_learnable_parameters(self):
189+
"""
190+
Test that the wrapper has no learnable parameters (only buffers).
191+
"""
192+
q_rope = QuantQwen3VLVisionRotaryEmbedding(self.fp_rope)
193+
194+
# Check that there are no parameters
195+
params = list(q_rope.parameters())
196+
self.assertEqual(len(params), 0)
197+
198+
# Check that inv_freq is a buffer, not a parameter
199+
self.assertIsInstance(q_rope.inv_freq, torch.Tensor)
200+
self.assertIn("inv_freq", q_rope._buffers)
201+
202+
def test_frequency_values_correct(self):
203+
"""
204+
Test that the computed frequency values are mathematically correct.
205+
Formula: freqs[i, j] = i * theta^(-2j/dim)
206+
"""
207+
q_rope = QuantQwen3VLVisionRotaryEmbedding(self.fp_rope)
208+
q_rope.enable_calibration()
209+
q_rope.freeze_qparams()
210+
211+
seqlen = 4
212+
with torch.no_grad():
213+
freqs = q_rope(seqlen)
214+
215+
# Manually compute expected values
216+
expected = torch.outer(
217+
torch.arange(seqlen, dtype=torch.float32),
218+
self.fp_rope.inv_freq,
219+
)
220+
221+
# The quantized output should still have the same pattern
222+
# (quantization changes precision but not the mathematical relationship)
223+
torch.testing.assert_close(freqs.shape, expected.shape)
224+
self.assertEqual(freqs.shape, (seqlen, self.dim // 2))
Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
#!/usr/bin/env python3
2+
# Copyright (c) 2026 Samsung Electronics Co., Ltd. All Rights Reserved
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
import importlib.util
17+
import sys
18+
19+
import torch
20+
import torch.nn as nn
21+
22+
import tico
23+
import tico.quantization
24+
import tico.quantization.config.ptq
25+
26+
# Check if transformers is available
27+
trans_spec = importlib.util.find_spec("transformers")
28+
if trans_spec is None:
29+
print(
30+
"Error: transformers package not installed. Cannot test Qwen3VLVisionRotaryEmbedding."
31+
)
32+
sys.exit(1)
33+
34+
from transformers.models.qwen3_vl.modeling_qwen3_vl import Qwen3VLVisionRotaryEmbedding
35+
36+
37+
def generate_calibration_data(batch_size: int, sequence_lengths: list) -> list:
38+
"""Generate calibration data for PTQ"""
39+
calibration_data = []
40+
for _ in range(batch_size):
41+
for seqlen in sequence_lengths:
42+
calibration_data.append(seqlen)
43+
return calibration_data
44+
45+
46+
def main():
47+
# Create the vision rotary embedding model
48+
# dim=128 is typical for head_dim=64 in Qwen3-VL
49+
dim = 128
50+
theta = 10000.0
51+
model = Qwen3VLVisionRotaryEmbedding(dim=dim, theta=theta)
52+
model.eval()
53+
54+
# Qwen3VLVisionRotaryEmbedding(
55+
# (inv_freq): Buffer [64] # dim/2 frequency bands
56+
# )
57+
assert model.dim == dim
58+
assert model.theta == theta
59+
assert model.inv_freq.shape == (dim // 2,)
60+
61+
# Generate calibration data
62+
# Calibrate with various sequence lengths to capture full dynamic range
63+
calibration_data = generate_calibration_data(
64+
batch_size=20, sequence_lengths=[64, 128, 256, 512]
65+
)
66+
67+
# Configure PTQ
68+
ptq_config = tico.quantization.config.ptq.PTQConfig()
69+
70+
# Prepare the model for quantization
71+
prepared_model = tico.quantization.prepare(
72+
model, ptq_config, inplace=True # Transform the model in place
73+
)
74+
75+
# Calibrate the model (collect statistics)
76+
with torch.no_grad():
77+
for i, seqlen in enumerate(calibration_data):
78+
_ = prepared_model(seqlen)
79+
80+
# Convert to quantized model
81+
quantized_model = tico.quantization.convert(prepared_model, inplace=True)
82+
83+
# Convert to Circle format
84+
# example_inputs: seqlen as an integer
85+
example_seqlen = 256
86+
circle_model = tico.convert(quantized_model, (example_seqlen,))
87+
88+
# Save the Circle model
89+
filename = "quantized_vision_rotary_embedding.circle"
90+
circle_model.save(filename)
91+
print(f"Circle model saved as '{filename}'")
92+
93+
94+
if __name__ == "__main__":
95+
main()

0 commit comments

Comments
 (0)