[quantization] Introduce wrapper for Qwen3VLModel#555

Open
dvsav wants to merge 1 commit into Samsung:main from dvsav:quant_model

Conversation

@dvsav
Contributor

@dvsav dvsav commented Mar 16, 2026

This change introduces the QuantQwen3VLModel wrapper to support post-training quantization of the Qwen3VLModel module.

Why?

Qwen3VLModel is an essential part of the Qwen model.
Trying to quantize Qwen3VLModel via PTQ raises the exception PTQQuantizer: no quantization wrapper for Qwen3VLModel.

What?

This change introduces:

  • Class QuantQwen3VLModel (tico/quantization/wrapq/wrappers/qwen_vl/quant_model.py).
  • Unit tests: class TestQuantQwen3VLModel (test/quantization/wrapq/wrappers/qwen_vl/test_quant_model.py), skipped if the transformers package is not installed.
  • A new entry in _CORE_MODULES (tico/quantization/wrapq/wrappers/registry.py).
  • An example of Qwen3VLModel quantization and conversion to Circle (tico/quantization/wrapq/examples/qwen/quantize_qwen_model.py).

Unit Tests

Unit test results with coverage information:

$ coverage run -m pytest test/quantization/wrapq/wrappers/qwen_vl/test_quant_model.py -v
================================================================== test session starts ==================================================================
platform linux -- Python 3.10.12, pytest-8.4.0, pluggy-1.6.0 -- /home/d.savchenkov/myenv/bin/python3
cachedir: .pytest_cache
rootdir: /home/d.savchenkov/TICO
configfile: pyproject.toml
plugins: anyio-4.12.0, mock-3.15.1, xdist-3.7.0, cov-6.2.1
collected 22 items                                                                                                                                      

test/quantization/wrapq/wrappers/qwen_vl/test_quant_model.py::TestQuantQwen3VLModel::test_activation_stats_collected_text_only PASSED             [  4%]
test/quantization/wrapq/wrappers/qwen_vl/test_quant_model.py::TestQuantQwen3VLModel::test_activation_stats_collected_with_images PASSED           [  9%]
test/quantization/wrapq/wrappers/qwen_vl/test_quant_model.py::TestQuantQwen3VLModel::test_compute_3d_position_ids_reuses_cached_rope_deltas PASSED [ 13%]
test/quantization/wrapq/wrappers/qwen_vl/test_quant_model.py::TestQuantQwen3VLModel::test_different_batch_sizes_text_only PASSED                  [ 18%]
test/quantization/wrapq/wrappers/qwen_vl/test_quant_model.py::TestQuantQwen3VLModel::test_dtype_override PASSED                                   [ 22%]
test/quantization/wrapq/wrappers/qwen_vl/test_quant_model.py::TestQuantQwen3VLModel::test_forward_diff_text_only PASSED                           [ 27%]
test/quantization/wrapq/wrappers/qwen_vl/test_quant_model.py::TestQuantQwen3VLModel::test_forward_input_validation PASSED                         [ 31%]
test/quantization/wrapq/wrappers/qwen_vl/test_quant_model.py::TestQuantQwen3VLModel::test_forward_text_only PASSED                                [ 36%]
test/quantization/wrapq/wrappers/qwen_vl/test_quant_model.py::TestQuantQwen3VLModel::test_forward_with_both_images_and_videos PASSED              [ 40%]
test/quantization/wrapq/wrappers/qwen_vl/test_quant_model.py::TestQuantQwen3VLModel::test_forward_with_images PASSED                              [ 45%]
test/quantization/wrapq/wrappers/qwen_vl/test_quant_model.py::TestQuantQwen3VLModel::test_forward_with_inputs_embeds PASSED                       [ 50%]
test/quantization/wrapq/wrappers/qwen_vl/test_quant_model.py::TestQuantQwen3VLModel::test_forward_with_inputs_embeds_and_images PASSED            [ 54%]
test/quantization/wrapq/wrappers/qwen_vl/test_quant_model.py::TestQuantQwen3VLModel::test_forward_with_videos PASSED                              [ 59%]
test/quantization/wrapq/wrappers/qwen_vl/test_quant_model.py::TestQuantQwen3VLModel::test_get_rope_index_with_images_and_videos PASSED            [ 63%]
test/quantization/wrapq/wrappers/qwen_vl/test_quant_model.py::TestQuantQwen3VLModel::test_graph_tracing_behavior_with_images PASSED               [ 68%]
test/quantization/wrapq/wrappers/qwen_vl/test_quant_model.py::TestQuantQwen3VLModel::test_graph_tracing_behavior_with_videos PASSED               [ 72%]
test/quantization/wrapq/wrappers/qwen_vl/test_quant_model.py::TestQuantQwen3VLModel::test_mode_transitions PASSED                                 [ 77%]
test/quantization/wrapq/wrappers/qwen_vl/test_quant_model.py::TestQuantQwen3VLModel::test_multiple_calibration_steps_text_only PASSED             [ 81%]
test/quantization/wrapq/wrappers/qwen_vl/test_quant_model.py::TestQuantQwen3VLModel::test_observer_count PASSED                                   [ 86%]
test/quantization/wrapq/wrappers/qwen_vl/test_quant_model.py::TestQuantQwen3VLModel::test_registration_in_registry PASSED                         [ 90%]
test/quantization/wrapq/wrappers/qwen_vl/test_quant_model.py::TestQuantQwen3VLModel::test_rope_deltas_computed_after_forward PASSED               [ 95%]
test/quantization/wrapq/wrappers/qwen_vl/test_quant_model.py::TestQuantQwen3VLModel::test_wraps_submodules PASSED                                 [100%]

============================================================ 22 passed, 2 warnings in 42.85s ============================================================

Coverage info (irrelevant files skipped):

$ coverage report -m
Name                                                                   Stmts   Miss  Cover   Missing
----------------------------------------------------------------------------------------------------
...
tico/quantization/wrapq/wrappers/qwen_vl/quant_model.py                   215      2    99%   434, 507
tico/quantization/wrapq/wrappers/qwen_vl/quant_text_attn.py               136      5    96%   196-197, 201-203
tico/quantization/wrapq/wrappers/qwen_vl/quant_text_decoder_layer.py       42      0   100%
tico/quantization/wrapq/wrappers/qwen_vl/quant_text_mlp.py                 43      0   100%
tico/quantization/wrapq/wrappers/qwen_vl/quant_text_model.py              130      8    94%   248, 254-256, 260, 278, 282, 285-286
tico/quantization/wrapq/wrappers/qwen_vl/quant_vision_attn.py             105      0   100%
tico/quantization/wrapq/wrappers/qwen_vl/quant_vision_block.py             42      0   100%
tico/quantization/wrapq/wrappers/qwen_vl/quant_vision_mlp.py               33      0   100%
tico/quantization/wrapq/wrappers/qwen_vl/quant_vision_model.py            173      6    97%   166, 173, 180, 195, 279, 452
tico/quantization/wrapq/wrappers/qwen_vl/quant_vision_patch_embed.py       25      0   100%
tico/quantization/wrapq/wrappers/qwen_vl/quant_vision_patch_merger.py      36      0   100%
tico/quantization/wrapq/wrappers/registry.py                               36      1    97%   259
...
----------------------------------------------------------------------------------------------------
TOTAL                                                                  10170   6520    36%

Script for testing quantization and conversion to Circle

$ python3 tico/quantization/wrapq/examples/qwen/quantize_qwen_model.py

┌───────────── Quantization Error Summary ─────────────
│ Mean |diff|: 0.149582
│ PEIR       : 15.490412 %
└──────────────────────────────────────────────────────
[scatter plot: quantized vs. reference output values, both axes spanning roughly -3.8 to 3.5, points clustered along the diagonal]

[QuantCheck] WARNING: 34 nodes without qparam detected (see logs).
Circle model saved as 'qwen3vl_model.q.circle'

@dvsav dvsav force-pushed the quant_model branch 7 times, most recently from 0b92491 to cec58ba on March 17, 2026 11:31
@dvsav dvsav force-pushed the quant_model branch 22 times, most recently from 0e7e67d to 6bafa6a on April 1, 2026 08:36
@dvsav dvsav force-pushed the quant_model branch 2 times, most recently from 602883c to ab81016 on April 1, 2026 11:51
@dvsav dvsav marked this pull request as ready for review April 1, 2026 12:04
Contributor


Please remove this change.

Contributor Author


👍 done

# position_ids=text_position_ids,
# )
attention_mask = self._fq(attention_mask, self.obs_attention_mask)
if torch.is_floating_point(attention_mask):
Contributor


When do we need to check float?

Contributor Author

@dvsav dvsav Apr 1, 2026


Example where attention_mask is long:

transformers/models/qwen3_vl/modeling_qwen3_vl.py

def get_rope_index(
    ...
    input_ids: torch.LongTensor | None = None,
    ...
):
    ...
    total_input_ids = input_ids
    attention_mask = torch.ones_like(total_input_ids)  # attention_mask.dtype is long
    ...

Example where attention_mask is float:

tico/quantization/wrapq/wrappers/qwen_vl/quant_text_attn.py

        mask = torch.full((1, 1, max_seq, max_seq), float("-120"))  # type: ignore[arg-type]
        mask.triu_(1)
        self.register_buffer("causal_mask_template", mask, persistent=False)
        
...
            attention_mask = self.causal_mask_template[..., :q_len, :k_len].to(
                hidden.device
            )

Why check for float?

Fake quantization is not supported for integer tensors and raises an exception in tico/quantization/wrapq/observers/affine_base.py:144:

RuntimeError: "fake_quantize_tensor_cachemask_kernel_type_handling" not implemented for 'Long'
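The dtype guard described above can be sketched as follows. This is a minimal illustration, not the actual _fq implementation; fake_quant_if_float is a hypothetical helper, and the scale/zero-point values are chosen only for demonstration:

```python
import torch

def fake_quant_if_float(t: torch.Tensor, scale: float = 0.05, zero_point: int = 0,
                        qmin: int = -128, qmax: int = 127) -> torch.Tensor:
    """Apply fake quantization only to floating-point tensors.

    torch.fake_quantize_per_tensor_affine raises a RuntimeError for integer
    inputs, so integer masks (e.g. the long attention_mask produced by
    torch.ones_like(input_ids)) must pass through untouched.
    """
    if torch.is_floating_point(t):
        return torch.fake_quantize_per_tensor_affine(t, scale, zero_point, qmin, qmax)
    return t

long_mask = torch.ones(4, dtype=torch.long)   # integer mask: passes through unchanged
float_mask = torch.full((4,), -120.0)         # additive float mask: gets fake-quantized
assert fake_quant_if_float(long_mask).dtype == torch.long
print(fake_quant_if_float(float_mask))        # values clamped to the quantization grid
```

Without the torch.is_floating_point check, the long mask would hit the RuntimeError quoted above.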

Comment on lines +275 to +292
@staticmethod
def _fuse_text_n_image(inputs_embeds, visual_embeds):
    num_visual_tokens = visual_embeds.shape[0]
    flat_inputs = inputs_embeds.view(-1, inputs_embeds.shape[-1])
    flat_inputs[:num_visual_tokens] = visual_embeds
    inputs_embeds = flat_inputs.view_as(inputs_embeds)
    return inputs_embeds

@staticmethod
def _masked_scatter(inputs_embeds, visual_mask, visual_embeds):
    # Use indexing assignment instead of masked_scatter for better Circle support
    # (TICO can't convert the torch.masked_scatter operator)
    flat_inputs = inputs_embeds.view(-1, inputs_embeds.shape[-1])
    mask_2d = visual_mask[..., 0]  # Get mask for the first dimension only
    _, indices = torch.nonzero(mask_2d, as_tuple=True)
    flat_inputs[indices] = visual_embeds
    inputs_embeds = flat_inputs.view_as(inputs_embeds)
    return inputs_embeds
Contributor Author

@dvsav dvsav Apr 1, 2026


Note for Reviewers

torch.Tensor.masked_scatter is inherently data-dependent (the number and positions of True values in the mask vary at runtime depending on where visual data appears in the prompt). This leads to errors during conversion to Circle like the following:

ERROR:tico.utils.convert:NOT SUPPORTED OPERATOR
        (op) masked_scatter.default
        (trace)   File "tico/quantization/wrapq/wrappers/ptq_wrapper.py", line 68, in forward
    return self.wrapped(*args, **kwargs)
  File "tico/quantization/wrapq/wrappers/qwen_vl/quant_model.py", line 145, in forward
    inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)

Traceback (most recent call last):
  File "tico/quantization/wrapq/examples/qwen/quantize_model.py", line 357, in <module>
    main()
  File "tico/quantization/wrapq/examples/qwen/quantize_model.py", line 348, in main
    circle_model = tico.convert(quantized_model.eval(), example_input)
  File "tico/utils/convert.py", line 351, in convert
    circle_binary = convert_exported_module_to_circle(
  File "tico/utils/convert.py", line 324, in convert_exported_module_to_circle
    check_unsupported_target(exported_program)
  File "tico/utils/convert.py", line 177, in check_unsupported_target
    raise NotYetSupportedError("NOT SUPPORTED OPERATOR IN GRAPH MODULE")
tico.utils.errors.NotYetSupportedError: NOT SUPPORTED OPERATOR IN GRAPH MODULE

The _masked_scatter method was created in an attempt to replace torch.Tensor.masked_scatter with flattening and direct indexing. Nevertheless, the torch.nonzero(mask_2d, as_tuple=True) call used there produced another Circle conversion error (see the evidence below):

Traceback (most recent call last):
  File "tico/quantization/wrapq/examples/qwen/quantize_model.py", line 357, in <module>
    main()
  File "tico/quantization/wrapq/examples/qwen/quantize_model.py", line 348, in main
    circle_model = tico.convert(quantized_model.eval(), example_input)
  File "tico/utils/convert.py", line 351, in convert
    circle_binary = convert_exported_module_to_circle(
  File "tico/utils/convert.py", line 284, in convert_exported_module_to_circle
    circle_legalize.run(exported_program)
  File "tico/utils/passes.py", line 65, in run
    result = _pass.call(exported_program)
  File "tico/utils/trace_decorators.py", line 63, in wrapped
    ret = fn(*args)
  File "tico/passes/remove_redundant_reshape.py", line 359, in call
    assert isinstance(s, int), type(s)
AssertionError: <class 'torch.SymInt'>

Debugging revealed that the culprit was a reshape node in the exported program's graph:

(Pdb) list
354                 reshape1_args = ReshapeArgs(*reshape1.args, **reshape1.kwargs)  # type: ignore[arg-type]
355                 reshape1_input, size = reshape1_args.input, reshape1_args.shape
356                 assert isinstance(reshape1_input, torch.fx.Node), type(reshape1_input)
357                 assert isinstance(size, list), type(size)
358                 for s in size:
359  ->                 assert isinstance(s, int), type(s)
360  
361                 if not len(reshape1.users) == 1:
362                     continue
363  
364                 # reshape_2

(Pdb) pp reshape1.stack_trace
('  File '
 '"tico/quantization/wrapq/wrappers/ptq_wrapper.py", '
 'line 68, in forward\n'
 '    return self.wrapped(*args, **kwargs)\n'
 '  File '
 '"tico/quantization/wrapq/wrappers/qwen_vl/quant_model.py", '
 'line 145, in forward\n'
 '    self._masked_scatter(inputs_embeds, image_mask, image_embeds)\n'
 '  File '
 '"tico/quantization/wrapq/wrappers/qwen_vl/quant_model.py", '
 'line 246, in _masked_scatter\n'
 '    _, indices = torch.nonzero(mask_2d, as_tuple=True)\n')

All this led me to the conclusion that we have to fix the position of visual data in the prompt, just like we fixed grid_thw in PTQConfig previously (see #560). A rough way of doing this is implemented as the _fuse_text_n_image method, used as an export-time alternative to _masked_scatter. _fuse_text_n_image simply assumes that all visual tokens are located at the beginning of the prompt.
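The fixed-position assumption can be checked with a small self-contained sketch: when all visual tokens sit at the start of the prompt, the data-dependent masked_scatter reduces to a static slice write that exports cleanly. The shapes below are illustrative, not taken from the actual model:

```python
import torch

# Illustrative shapes: batch 1, 6 tokens, hidden size 4,
# with 3 visual tokens assumed to sit at the start of the prompt.
torch.manual_seed(0)
inputs_embeds = torch.randn(1, 6, 4)
visual_embeds = torch.randn(3, 4)
visual_mask = torch.zeros(1, 6, 4, dtype=torch.bool)
visual_mask[:, :3, :] = True

# Data-dependent version (not convertible to Circle):
ref = inputs_embeds.masked_scatter(visual_mask, visual_embeds)

# Export-time version: a fixed position turns fusion into a static slice write,
# which traces to plain view/copy ops with no data-dependent shapes.
out = inputs_embeds.clone()
flat = out.view(-1, out.shape[-1])
flat[: visual_embeds.shape[0]] = visual_embeds

assert torch.equal(out, ref)
```

The equivalence only holds under the stated assumption; if visual tokens appear elsewhere in the prompt, the two versions diverge.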

This change introduces QuantQwen3VLModel wrapper to support post-training quantization of Qwen3VLModel module.

TICO-DCO-1.0-Signed-off-by: d.savchenkov <d.savchenkov@partner.samsung.com>
@dvsav
Contributor Author

dvsav commented Apr 2, 2026

Note for me

[QuantCheck] WARNING: 34 nodes without qparam detected (see logs).

Debugging check_missing_qparam() (tico/quantization/wrapq/utils/check_missing_qparam.py:108) shows that the nodes without qparam stem from the following line in tico/quantization/wrapq/wrappers/qwen_vl/quant_text_model.py:

position_embeddings = self.rotary_emb(hidden_states, position_ids)

The thing is that self.rotary_emb is the original, unwrapped Qwen3VLTextRotaryEmbedding.

I carelessly picked that idea from PR #535, assuming that self.rotary_emb(hidden_states, position_ids) would be folded into static tensor values during model export. But that is not the case: the actual computations of Qwen3VLTextRotaryEmbedding.forward are exported as-is (and not quantized).

This issue should probably be addressed (in a separate PR) by fixing position_ids for inference time, just like we did for grid_thw (see #560).
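The fix sketched above would let the rotary tables be precomputed as constants. Here is a minimal, hypothetical illustration of the idea (not the actual Qwen3VLTextRotaryEmbedding code; head_dim, base, and max_seq are made-up values): with position_ids fixed at export time, cos/sin become static tensors that never enter the traced graph as computations needing qparams:

```python
import torch

# With position_ids fixed at export time, the rotary cos/sin tables can be
# computed once (here, the standard RoPE inverse-frequency formulation) and
# registered as constant buffers instead of being recomputed in forward().
head_dim, base, max_seq = 8, 10000.0, 16

inv_freq = 1.0 / base ** (torch.arange(0, head_dim, 2).float() / head_dim)
position_ids = torch.arange(max_seq).float()   # fixed, known at export time
freqs = torch.outer(position_ids, inv_freq)    # (max_seq, head_dim // 2)
emb = torch.cat((freqs, freqs), dim=-1)        # (max_seq, head_dim)
cos_cached, sin_cached = emb.cos(), emb.sin()  # static tensors, no missing qparam

assert cos_cached.shape == (max_seq, head_dim)
```

In the wrapper, such tables could be registered via register_buffer, so the exported graph reads constants instead of running the rotary forward pass.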
