[quantization] Introduce wrapper for Qwen3VLModel #555
dvsav wants to merge 1 commit into Samsung:main
Conversation
Please remove this change.
# position_ids=text_position_ids,
# )
attention_mask = self._fq(attention_mask, self.obs_attention_mask)
if torch.is_floating_point(attention_mask):
When do we need to check float?
Example when attention_mask is long
transformers/models/qwen3_vl/modeling_qwen3_vl.py
def get_rope_index(
...
input_ids: torch.LongTensor | None = None,
...
)
...
total_input_ids = input_ids
attention_mask = torch.ones_like(total_input_ids) # attention_mask.dtype is long
...
Example when attention_mask is float
tico/quantization/wrapq/wrappers/qwen_vl/quant_text_attn.py
mask = torch.full((1, 1, max_seq, max_seq), float("-120")) # type: ignore[arg-type]
mask.triu_(1)
self.register_buffer("causal_mask_template", mask, persistent=False)
...
attention_mask = self.causal_mask_template[..., :q_len, :k_len].to(
hidden.device
)
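For reference, the mask template above can be reproduced standalone (a minimal sketch with an assumed `max_seq`): `torch.full` with a Python float fill produces a float32 tensor, which is why this path yields a floating-point `attention_mask`.

```python
import torch

# Sketch of the causal mask template construction quoted above.
# max_seq is an assumed placeholder value for illustration.
max_seq = 4
mask = torch.full((1, 1, max_seq, max_seq), float("-120"))  # float32 tensor
mask.triu_(1)  # keep -120 strictly above the diagonal, 0 on and below it
```

A later slice like `mask[..., :q_len, :k_len]` then gives a float additive causal mask for any query/key length up to `max_seq`.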
Why check for float?
Fake quantization is not supported for integer tensors and raises an exception in tico/quantization/wrapq/observers/affine_base.py:144:
RuntimeError: "fake_quantize_tensor_cachemask_kernel_type_handling" not implemented for 'Long'
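The guard can be illustrated with a minimal self-contained sketch (the helper name `maybe_fake_quantize` and the scale/zero-point values are assumptions, not the wrapper's actual code): PyTorch's fake-quantize kernels are only implemented for floating dtypes, so a long mask such as `torch.ones_like(input_ids)` must bypass fake quantization.

```python
import torch

def maybe_fake_quantize(x: torch.Tensor) -> torch.Tensor:
    """Fake-quantize only floating-point tensors; pass integer tensors through.

    Illustrative sketch: scale/zero_point are placeholder values, not the
    observer-derived parameters used by the real wrapper.
    """
    if torch.is_floating_point(x):
        return torch.fake_quantize_per_tensor_affine(x, 0.1, 0, -128, 127)
    return x  # e.g. a long attention_mask from torch.ones_like(input_ids)

float_mask = torch.zeros(1, 4)                  # float32 -> quantized path
long_mask = torch.ones(1, 4, dtype=torch.long)  # int64  -> passthrough
```

Calling `torch.fake_quantize_per_tensor_affine` directly on `long_mask` would raise the `RuntimeError` shown above; the dtype check avoids that path entirely.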
@staticmethod
def _fuse_text_n_image(inputs_embeds, visual_embeds):
    num_visual_tokens = visual_embeds.shape[0]
    flat_inputs = inputs_embeds.view(-1, inputs_embeds.shape[-1])
    flat_inputs[:num_visual_tokens] = visual_embeds
    inputs_embeds = flat_inputs.view_as(inputs_embeds)
    return inputs_embeds

@staticmethod
def _masked_scatter(inputs_embeds, visual_mask, visual_embeds):
    # Use indexing assignment instead of masked_scatter for better Circle support
    # (TICO can't convert torch.masked_scatter operator)
    flat_inputs = inputs_embeds.view(-1, inputs_embeds.shape[-1])
    mask_2d = visual_mask[..., 0]  # Get mask for the first dimension only
    _, indices = torch.nonzero(mask_2d, as_tuple=True)
    flat_inputs[indices] = visual_embeds
    inputs_embeds = flat_inputs.view_as(inputs_embeds)
    return inputs_embeds
Note for Reviewers
torch.Tensor.masked_scatter is inherently data-dependent (the number and positions of True values in the mask vary at runtime with the positions of visual data in the prompt). This leads to errors during conversion to Circle such as the following:
ERROR:tico.utils.convert:NOT SUPPORTED OPERATOR
(op) masked_scatter.default
(trace) File "tico/quantization/wrapq/wrappers/ptq_wrapper.py", line 68, in forward
return self.wrapped(*args, **kwargs)
File "tico/quantization/wrapq/wrappers/qwen_vl/quant_model.py", line 145, in forward
inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
Traceback (most recent call last):
File "tico/quantization/wrapq/examples/qwen/quantize_model.py", line 357, in <module>
main()
File "tico/quantization/wrapq/examples/qwen/quantize_model.py", line 348, in main
circle_model = tico.convert(quantized_model.eval(), example_input)
File "tico/utils/convert.py", line 351, in convert
circle_binary = convert_exported_module_to_circle(
File "tico/utils/convert.py", line 324, in convert_exported_module_to_circle
check_unsupported_target(exported_program)
File "tico/utils/convert.py", line 177, in check_unsupported_target
raise NotYetSupportedError("NOT SUPPORTED OPERATOR IN GRAPH MODULE")
tico.utils.errors.NotYetSupportedError: NOT SUPPORTED OPERATOR IN GRAPH MODULE
The _masked_scatter method was created in an attempt to replace torch.Tensor.masked_scatter with flattening and direct indexing. Nevertheless, the torch.nonzero(mask_2d, as_tuple=True) call used there produced another Circle conversion error (see the evidence below):
Traceback (most recent call last):
File "tico/quantization/wrapq/examples/qwen/quantize_model.py", line 357, in <module>
main()
File "tico/quantization/wrapq/examples/qwen/quantize_model.py", line 348, in main
circle_model = tico.convert(quantized_model.eval(), example_input)
File "tico/utils/convert.py", line 351, in convert
circle_binary = convert_exported_module_to_circle(
File "tico/utils/convert.py", line 284, in convert_exported_module_to_circle
circle_legalize.run(exported_program)
File "tico/utils/passes.py", line 65, in run
result = _pass.call(exported_program)
File "tico/utils/trace_decorators.py", line 63, in wrapped
ret = fn(*args)
File "tico/passes/remove_redundant_reshape.py", line 359, in call
assert isinstance(s, int), type(s)
AssertionError: <class 'torch.SymInt'>
Debugging revealed that the culprit was a reshape node in the exported program's graph:
(Pdb) list
354 reshape1_args = ReshapeArgs(*reshape1.args, **reshape1.kwargs) # type: ignore[arg-type]
355 reshape1_input, size = reshape1_args.input, reshape1_args.shape
356 assert isinstance(reshape1_input, torch.fx.Node), type(reshape1_input)
357 assert isinstance(size, list), type(size)
358 for s in size:
359 -> assert isinstance(s, int), type(s)
360
361 if not len(reshape1.users) == 1:
362 continue
363
364 # reshape_2
(Pdb) pp reshape1.stack_trace
(' File '
'"tico/quantization/wrapq/wrappers/ptq_wrapper.py", '
'line 68, in forward\n'
' return self.wrapped(*args, **kwargs)\n'
' File '
'"tico/quantization/wrapq/wrappers/qwen_vl/quant_model.py", '
'line 145, in forward\n'
' self._masked_scatter(inputs_embeds, image_mask, image_embeds)\n'
' File '
'"tico/quantization/wrapq/wrappers/qwen_vl/quant_model.py", '
'line 246, in _masked_scatter\n'
' _, indices = torch.nonzero(mask_2d, as_tuple=True)\n')
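The root cause is easy to reproduce in isolation: the length of `torch.nonzero`'s output depends on the tensor's contents, not just its shape, so a trace/export cannot assign it a static size (hence the `torch.SymInt` that trips the reshape pass). A small sketch with assumed example masks:

```python
import torch

# Two masks of identical shape but different contents.
mask_a = torch.tensor([[True, False, True, False]])
mask_b = torch.tensor([[True, True, True, False]])

# nonzero's result length tracks the number of True entries, which is
# only known at runtime -- a data-dependent (dynamic) output shape.
rows_a, cols_a = torch.nonzero(mask_a, as_tuple=True)
rows_b, cols_b = torch.nonzero(mask_b, as_tuple=True)
```

Any downstream reshape driven by `cols_a.shape[0]` therefore carries a symbolic size in the exported program, which the Circle legalization pass rejects.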
All of this led me to conclude that we have to fix the position of visual data in the prompt, just as we fixed grid_thw in PTQConfig previously (see #560). A rough way of doing this is implemented as the _fuse_text_n_image method, used as an export-time alternative to _masked_scatter. _fuse_text_n_image simply assumes that all visual tokens are located at the beginning of the prompt.
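A standalone sketch of this fixed-position fusion (function name and example shapes are illustrative, not the wrapper's exact code): when the mask really does mark only the leading positions, the result matches masked_scatter, but without any data-dependent operator.

```python
import torch

def fuse_text_n_image(inputs_embeds: torch.Tensor,
                      visual_embeds: torch.Tensor) -> torch.Tensor:
    """Overwrite the first num_visual_tokens rows of the flattened
    embeddings with the visual embeddings. Assumes all visual tokens
    sit at the start of the prompt, so no nonzero/masked_scatter
    (and no dynamic shapes) are needed at export time."""
    num_visual_tokens = visual_embeds.shape[0]
    flat = inputs_embeds.view(-1, inputs_embeds.shape[-1])
    flat[:num_visual_tokens] = visual_embeds
    return flat.view_as(inputs_embeds)

# Example: 2 visual tokens at the front of a 5-token prompt (hidden=3).
vis = torch.ones(2, 3)
mask = torch.tensor([[True, True, False, False, False]]) \
    .unsqueeze(-1).expand(1, 5, 3)
ref = torch.zeros(1, 5, 3).masked_scatter(mask, vis)   # data-dependent op
out = fuse_text_n_image(torch.zeros(1, 5, 3), vis)     # static-shape path
```

The trade-off is the stated assumption: if visual tokens appear mid-prompt, the two paths diverge, which is why this is only an export-time approximation.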
This change introduces QuantQwen3VLModel wrapper to support post-training quantization of Qwen3VLModel operation.
TICO-DCO-1.0-Signed-off-by: d.savchenkov <d.savchenkov@partner.samsung.com>
Note for me
Debugging position_embeddings = self.rotary_emb(hidden_states, position_ids)
The thing is that I carelessly picked that idea from PR #535 assuming that
This issue should probably be addressed (in a separate PR) by fixing
This change introduces QuantQwen3VLModel wrapper to support post-training quantization of Qwen3VLModel module.

Why?
Qwen3VLModel is an essential part of Qwen model. Trying to quantize Qwen3VLModel via PTQ generates exception PTQQuantizer: no quantization wrapper for Qwen3VLModel.

What
This change introduces:
- QuantQwen3VLModel (tico/quantization/wrapq/wrappers/qwen_vl/quant_model.py).
- class TestQuantQwen3VLModel (test/quantization/wrapq/wrappers/qwen_vl/test_quant_model.py) - skipped if transformers package is not installed.
- _CORE_MODULES (tico/quantization/wrapq/wrappers/registry.py).
- Qwen3VLModel quantization and conversion to Circle (tico/quantization/wrapq/examples/qwen/quantize_qwen_model.py).

Unit Tests
Unit test results with coverage information:
Coverage info (irrelevant files skipped):
Script for testing quantization and conversion to Circle