diff --git a/test/quantization/wrapq/wrappers/qwen_vl/test_quant_for_conditional_generation.py b/test/quantization/wrapq/wrappers/qwen_vl/test_quant_for_conditional_generation.py
new file mode 100644
index 00000000..b89b015c
--- /dev/null
+++ b/test/quantization/wrapq/wrappers/qwen_vl/test_quant_for_conditional_generation.py
@@ -0,0 +1,198 @@
+# Copyright (c) 2026 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import pathlib
+import tempfile
+import unittest
+import warnings
+
+import tico
+
+import torch
+from tico.quantization.config.ptq import PTQConfig
+from tico.quantization.wrapq.dtypes import DType
+from tico.quantization.wrapq.mode import Mode
+from tico.quantization.wrapq.utils.version import has_transformers_for
+from tico.quantization.wrapq.wrappers.nn.quant_linear import QuantLinear
+from tico.quantization.wrapq.wrappers.ptq_wrapper import PTQWrapper
+from tico.quantization.wrapq.wrappers.qwen_vl.quant_for_conditional_generation import (
+    QuantQwen3VLForConditionalGeneration,
+)
+
+
+skip_msg = "required transformers not installed — skipping Qwen3VLForConditionalGeneration tests"
+
+
+@unittest.skipUnless(has_transformers_for("qwen3-vl"), skip_msg)
+class TestQuantQwen3VLForConditionalGeneration(unittest.TestCase):
+    model_fp: torch.nn.Module
+    config: object  # Will be Qwen3VLConfig but we don't want to import it at module level
+
+    @classmethod
+    def setUpClass(cls):
+        from transformers.models.qwen3_vl.configuration_qwen3_vl import Qwen3VLConfig
+
+        # Import the original model class
+        from transformers.models.qwen3_vl.modeling_qwen3_vl import (
+            Qwen3VLForConditionalGeneration,
+        )
+
+        # Create a small config for testing
+        config = Qwen3VLConfig(
+            vision_config={
+                "hidden_size": 32,
+                "intermediate_size": 64,
+                "depth": 2,
+                "num_heads": 4,
+                "patch_size": 14,
+                "temporal_patch_size": 1,
+                "in_channels": 3,
+                "num_position_embeddings": 144,  # 12*12
+                "spatial_merge_size": 2,
+                "deepstack_visual_indexes": [],
+            },
+            text_config={
+                "hidden_size": 32,
+                "intermediate_size": 64,
+                "num_hidden_layers": 2,
+                "num_attention_heads": 4,
+                "num_key_value_heads": 2,
+                "max_position_embeddings": 128,
+                "vocab_size": 1000,
+                "pad_token_id": 0,
+                "rope_scaling": {},
+            },
+            image_token_id=1,
+            video_token_id=2,
+            vision_start_token_id=3,
+        )
+
+        cls.model_fp = Qwen3VLForConditionalGeneration(config)
+        cls.config = config
+
+    def test_mode_transitions(self):
+        qmodel = QuantQwen3VLForConditionalGeneration(self.model_fp)
+        self.assertIs(qmodel._mode, Mode.NO_QUANT)
+
+        qmodel.enable_calibration()
+        self.assertIs(qmodel._mode, Mode.CALIB)
+
+        # Create dummy inputs
+        input_ids = torch.randint(0, self.config.text_config.vocab_size, (1, 10))
+        # For simplicity, not providing pixel_values, so no vision processing
+
+        _ = qmodel(input_ids=input_ids)
+
+        # For simplicity, not providing pixel_values, so no vision processing
+
+        qmodel.freeze_qparams()
+        self.assertIs(qmodel._mode, Mode.QUANT)
+
+    def test_forward_diff(self):
+        qmodel = QuantQwen3VLForConditionalGeneration(self.model_fp)
+        qmodel.enable_calibration()
+        for _ in range(2):
+            inp = torch.randint(0, self.config.text_config.vocab_size, (1, 10))
+            _ = qmodel(input_ids=inp)
+        qmodel.freeze_qparams()
+
+        x = torch.randint(0, self.config.text_config.vocab_size, (1, 10))
+        with torch.no_grad():
+            q_out = qmodel(input_ids=x)
+            fp_out = self.model_fp(input_ids=x)
+
+        # Check that outputs are close but not identical (due to quantization)
+        diff = (fp_out.logits - q_out.logits).abs().mean().item()
+        self.assertGreater(diff, 0.0)
+        # The threshold might need adjustment based on actual behavior
+        self.assertLess(diff, 1.0)
+
+    def test_lm_head_override(self):
+        cfg = PTQConfig(
+            default_dtype=DType.uint(8),
+            overrides={
+                "lm_head": {
+                    "act_in": {"dtype": DType.uint(4)},
+                    "act_out": {"dtype": DType.uint(4)},
+                }
+            },
+        )
+        qmodel = QuantQwen3VLForConditionalGeneration(self.model_fp, qcfg=cfg)
+        # We know qmodel.lm_head is a PTQWrapper wrapping a QuantLinear
+        assert isinstance(qmodel.lm_head, PTQWrapper)
+        q_lin = qmodel.lm_head.wrapped
+
+        self.assertIsInstance(q_lin, QuantLinear)
+        # type: ignore below because obs_act_in and obs_act_out are not in the base class interface
+        self.assertEqual(q_lin.obs_act_in.dtype, DType.uint(4))
+        self.assertEqual(q_lin.obs_act_out.dtype, DType.uint(4))
+
+
+class TestSubgraphExport(unittest.TestCase):
+    def setUp(self):
+        torch.manual_seed(0)
+        from transformers.models.qwen3_vl.configuration_qwen3_vl import Qwen3VLConfig
+        from transformers.models.qwen3_vl.modeling_qwen3_vl import (
+            Qwen3VLForConditionalGeneration,
+        )
+
+        config = Qwen3VLConfig(
+            vision_config={
+                "hidden_size": 16,
+                "intermediate_size": 32,
+                "depth": 1,
+                "num_heads": 2,
+                "patch_size": 4,
+                "temporal_patch_size": 1,
+                "in_channels": 3,
+                "num_position_embeddings": 16,  # 4*4
+                "spatial_merge_size": 2,
+                "deepstack_visual_indexes": [],
+            },
+            text_config={
+                "hidden_size": 16,
+                "intermediate_size": 32,
+                "num_hidden_layers": 1,
+                "num_attention_heads": 2,
+                "num_key_value_heads": 1,
+                "max_position_embeddings": 32,
+                "vocab_size": 100,
+                "pad_token_id": 0,
+                "rope_scaling": {},
+            },
+            image_token_id=1,
+            video_token_id=2,
+            vision_start_token_id=3,
+        )
+
+        model_fp = Qwen3VLForConditionalGeneration(config)
+        self.model_int8 = QuantQwen3VLForConditionalGeneration(model_fp).eval()
+        self.input_ids = torch.randint(0, config.text_config.vocab_size, (1, 8))
+
+    def test_calib_quant_export(self):
+        # calib
+        self.model_int8.enable_calibration()
+        _ = self.model_int8(self.input_ids)
+        self.model_int8.freeze_qparams()
+
+        self.assertIs(self.model_int8._mode, Mode.QUANT)
+
+        # export
+        with tempfile.TemporaryDirectory() as td:
+            path = pathlib.Path(td) / "model.circle"
+            with warnings.catch_warnings():
+                warnings.filterwarnings("ignore", category=UserWarning)
+                exported = tico.convert(self.model_int8, (self.input_ids,))
+                exported.save(path)
+            # NOTE(review): checked inside the TemporaryDirectory block so the
+            # file still exists; indentation inferred from semantics — confirm.
+            self.assertTrue(path.exists())
diff --git a/tico/quantization/wrapq/examples/qwen/quantize_for_conditional_generation.py b/tico/quantization/wrapq/examples/qwen/quantize_for_conditional_generation.py
new file mode 100644
index 00000000..15b72952
--- /dev/null
+++ b/tico/quantization/wrapq/examples/qwen/quantize_for_conditional_generation.py
@@ -0,0 +1,193 @@
+# Copyright (c) 2026 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Example of using QuantQwen3VLForConditionalGeneration wrapper.
+
+This script demonstrates how to:
+1. Create a small Qwen3VLForConditionalGeneration model.
+2. Wrap it with QuantQwen3VLForConditionalGeneration using `prepare`.
+3. Perform calibration with synthetic data.
+4. Freeze quantization parameters using `convert`.
+5. Run forward pass and compare results.
+6. Export the quantized model to .circle format.
+""" + +import pathlib + +import torch + +# Import necessary modules from tico +from tico.quantization import convert, prepare +from tico.quantization.config.ptq import PTQConfig +from tico.quantization.evaluation.metric import compute_peir +from tico.quantization.evaluation.utils import plot_two_outputs +from tico.quantization.wrapq.mode import Mode + +# Check if transformers library is available for Qwen3-VL +from tico.quantization.wrapq.utils.version import has_transformers_for +from tico.utils.utils import SuppressWarning + +# Set random seed for reproducibility +torch.manual_seed(123) + + +def main(): + if not has_transformers_for("qwen3-vl"): + print("Required transformers not installed — skipping example") + return + + from transformers.models.qwen3_vl.configuration_qwen3_vl import Qwen3VLConfig + + # Import the original model class + from transformers.models.qwen3_vl.modeling_qwen3_vl import ( + Qwen3VLForConditionalGeneration, + ) + + print("Creating a small Qwen3VLForConditionalGeneration model for testing...") + + # Create a small config for testing to make the example lightweight + config = Qwen3VLConfig( + vision_config={ + "hidden_size": 32, + "intermediate_size": 64, + "depth": 2, + "num_heads": 4, + "patch_size": 14, + "temporal_patch_size": 1, + "in_channels": 3, + "num_position_embeddings": 144, # 12*12 + "spatial_merge_size": 2, + "deepstack_visual_indexes": [], + }, + text_config={ + "hidden_size": 32, + "intermediate_size": 64, + "num_hidden_layers": 2, + "num_attention_heads": 4, + "num_key_value_heads": 2, + "max_position_embeddings": 128, + "vocab_size": 1000, + "pad_token_id": 0, + }, + image_token_id=1, + video_token_id=2, + vision_start_token_id=3, + ) + + # Create the original model + model_fp = Qwen3VLForConditionalGeneration(config) + print(f"Original model created: {type(model_fp).__name__}") + + # Wrap the model with QuantQwen3VLForConditionalGeneration using `prepare` + # This is the standard way to wrap models in tico + 
print("\nWrapping the model with QuantQwen3VLForConditionalGeneration...") + qmodel = prepare(model_fp, PTQConfig()) + qmodel.eval() # Set to evaluation mode + + print(f"Quantized model created: {type(qmodel).__name__}") + print(f"Wrapped module type: {type(qmodel.wrapped).__name__}") + print(f"Initial mode: {qmodel._mode.name}") + + # Check that the model is in NO_QUANT mode + assert qmodel._mode is Mode.NO_QUANT + + # Enable calibration mode (this is done internally by `prepare`, but we can check) + print("\nModel is ready for calibration.") + + # ------------------------------------------------------------------------- + # 2. Calibration with synthetic data + # ------------------------------------------------------------------------- + print("\nPerforming calibration with synthetic data...") + + # Create dummy inputs for calibration + # For simplicity, we will not provide pixel_values, so no vision processing + # This means the vision components will not be calibrated, but the text part will be + BATCH_SIZE = 2 + SEQ_LEN = 10 + VOCAB_SIZE = config.text_config.vocab_size + + CALIB_INPUTS = [] + for _ in range(5): # 5 calibration samples + input_ids = torch.randint(0, VOCAB_SIZE, (BATCH_SIZE, SEQ_LEN)) + CALIB_INPUTS.append({"input_ids": input_ids}) + + # Run calibration + with torch.no_grad(): + for inp in CALIB_INPUTS: + _ = qmodel(**inp) + + print("Calibration completed.") + + # ------------------------------------------------------------------------- + # 3. Freeze quantization parameters + # ------------------------------------------------------------------------- + print("\nFreezing quantization parameters...") + convert(qmodel) + + print(f"Mode after convert: {qmodel._mode.name}") + assert qmodel._mode is Mode.QUANT, "Quantization mode should be active now." + print("Quantization parameters frozen.") + + # ------------------------------------------------------------------------- + # 4. 
Quick diff check (INT-sim vs FP32) + # ------------------------------------------------------------------------- + print("\nComparing quantized and original model outputs...") + test_input = {"input_ids": torch.randint(0, VOCAB_SIZE, (BATCH_SIZE, SEQ_LEN))} + + with torch.no_grad(): + q_out = qmodel(**test_input) + fp_out = model_fp(**test_input) + + # For Qwen3VLForConditionalGeneration, the output is typically a CausalLMOutputWithPast + # which has a `logits` attribute + logits_quant = q_out.logits + logits_fp = fp_out.logits + + diff_mean = (logits_quant - logits_fp).abs().mean().item() + peir = compute_peir(logits_fp, logits_quant) * 100 + + print("┌───────────── Quantization Error Summary ─────────────") + print(f"│ Mean |diff|: {diff_mean:.6f}") + print(f"│ PEIR : {peir:.6f} %") + print("└──────────────────────────────────────────────────────") + + # Optionally, plot the outputs (this might not be very informative for high-dim tensors) + # print(plot_two_outputs(logits_fp, logits_quant)) + + # ------------------------------------------------------------------------- + # 5. Export the quantized model + # ------------------------------------------------------------------------- + print("\nExporting the quantized model...") + save_path = pathlib.Path("qwen3vl_conditional_generation.q.circle") + + # Example input for export + example_input = (test_input["input_ids"],) + + with SuppressWarning(UserWarning, ".*"): + try: + import tico + + cm = tico.convert(qmodel, example_input) + cm.save(save_path) + print(f"Quantized Circle model saved to {save_path.resolve()}") + except Exception as e: + print(f"Export failed: {e}") + print( + "This might be expected if the model is not fully supported for export yet." 
+            )
+
+
+if __name__ == "__main__":
+    main()
diff --git a/tico/quantization/wrapq/wrappers/qwen_vl/quant_for_conditional_generation.py b/tico/quantization/wrapq/wrappers/qwen_vl/quant_for_conditional_generation.py
new file mode 100644
index 00000000..95ead728
--- /dev/null
+++ b/tico/quantization/wrapq/wrappers/qwen_vl/quant_for_conditional_generation.py
@@ -0,0 +1,127 @@
+# Copyright (c) 2026 Samsung Electronics Co., Ltd. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Iterable, Optional, TYPE_CHECKING
+
+import torch
+import torch.nn as nn
+
+from tico.quantization.config.ptq import PTQConfig
+from tico.quantization.wrapq.wrappers.ptq_wrapper import PTQWrapper
+from tico.quantization.wrapq.wrappers.quant_module_base import QuantModuleBase
+from tico.quantization.wrapq.wrappers.registry import try_register
+from transformers.cache_utils import Cache
+from transformers.modeling_outputs import CausalLMOutputWithPast
+
+if TYPE_CHECKING:
+    from transformers.models.qwen3_vl.modeling_qwen3_vl import (
+        Qwen3VLForConditionalGeneration,
+    )
+
+
+@try_register(
+    "transformers.models.qwen3_vl.modeling_qwen3_vl.Qwen3VLForConditionalGeneration",
+)
+class QuantQwen3VLForConditionalGeneration(QuantModuleBase):
+    def __init__(
+        self,
+        model_fp: nn.Module,  # This will be an instance of Qwen3VLForConditionalGeneration
+        *,
+        qcfg: Optional[PTQConfig] = None,
+        fp_name: Optional[str] = None,
+    ):
+        super().__init__(qcfg, fp_name=fp_name)
+
+        # Store reference to original model for accessing its attributes
+        self.wrapped = model_fp
+
+        # Wrap self.model_fp.model (an instance of Qwen3VLModel)
+        model_cfg = qcfg.child("model") if qcfg else None
+        # Use type: ignore for the model attribute since we know it's a Module
+        self.model = PTQWrapper(
+            model_fp.model, qcfg=model_cfg, fp_name=f"{fp_name}.model"  # type: ignore[arg-type]
+        )
+
+        # Wrap self.model_fp.lm_head (an instance of nn.Linear)
+        lm_head_cfg = qcfg.child("lm_head") if qcfg else None
+        # Use type: ignore for the lm_head attribute since we know it's a Module
+        self.lm_head = PTQWrapper(
+            model_fp.lm_head, qcfg=lm_head_cfg, fp_name=f"{fp_name}.lm_head"  # type: ignore[arg-type]
+        )
+
+    def forward(
+        self,
+        input_ids: torch.LongTensor | None = None,
+        attention_mask: torch.Tensor | None = None,
+        position_ids: torch.LongTensor | None = None,
+        past_key_values: Cache | None = None,
+        inputs_embeds: torch.FloatTensor | None = None,
+        labels: torch.LongTensor | None = None,
+        use_cache: bool | None = None,
+        cache_position: torch.LongTensor | None = None,
+        **kwargs,
+    ) -> CausalLMOutputWithPast:
+        # Get config from wrapped model
+        # Type ignore is needed because Pylance infers self.wrapped incorrectly
+        config = self.wrapped.config  # type: ignore[attr-defined]
+
+        output_attentions = config.output_attentions  # type: ignore[attr-defined]
+        output_hidden_states = config.output_hidden_states  # type: ignore[attr-defined]
+        return_dict = config.use_return_dict  # type: ignore[attr-defined]
+
+        # Call the wrapped model to get hidden states
+        outputs = self.model(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            cache_position=cache_position,
+            **kwargs,
+        )
+
+        hidden_states = outputs[0]
+        logits = self.lm_head(hidden_states)
+
+        loss = None
+        if labels is not None:
+            # Get loss function from wrapped model
+            loss_fct = self.wrapped.loss_function  # type: ignore[attr-defined]
+            loss = loss_fct(
+                logits=logits,
+                labels=labels,
+                vocab_size=config.vocab_size,  # type: ignore[attr-defined]
+                **kwargs,
+            )
+
+        if not return_dict:
+            output = (logits,) + outputs[1:]
+            return (loss,) + output if loss is not None else output
+
+        return CausalLMOutputWithPast(
+            loss=loss,
+            logits=logits,
+            past_key_values=outputs.past_key_values,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
+
+    def _all_observers(self) -> Iterable:
+        # Recursively return observers from subcomponents
+        yield from self.model._all_observers()
+        yield from self.lm_head._all_observers()