From a3cf39fdd235cb6259b98aa4f21bf3a3f65bb25c Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Fri, 26 Jun 2026 13:34:05 +0800 Subject: [PATCH 1/6] feat(quant): Quantizer class with BaseQuantPass pipeline (#964) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Add passes/ sub-package with BaseQuantPass ABC - Implement FP16Pass, RTNPass, QDQPass — each accepts WinMLQuantizationConfig and reads only the fields relevant to that pass - Add Quantizer class: chains passes sequentially, uses tempfile for intermediates, merges QuantizeResult stats across passes - Add expand_precision(mode, config) to map precision strings to pass lists (supports 'fp16', 'rtn', 'static', 'dynamic', 'w4a16') - Keep quantize_onnx() as backward-compatible entry point - Add tests/unit/test_quant_passes.py (19 tests, all passing) --- src/winml/modelkit/quant/__init__.py | 20 +- src/winml/modelkit/quant/passes/__init__.py | 18 + src/winml/modelkit/quant/passes/base.py | 60 +++ src/winml/modelkit/quant/passes/fp16.py | 87 +++ src/winml/modelkit/quant/passes/qdq.py | 209 ++++++++ src/winml/modelkit/quant/passes/rtn.py | 106 ++++ src/winml/modelkit/quant/quantizer.py | 556 ++++++++------------ tests/unit/test_quant_passes.py | 398 ++++++++++++++ 8 files changed, 1104 insertions(+), 350 deletions(-) create mode 100644 src/winml/modelkit/quant/passes/__init__.py create mode 100644 src/winml/modelkit/quant/passes/base.py create mode 100644 src/winml/modelkit/quant/passes/fp16.py create mode 100644 src/winml/modelkit/quant/passes/qdq.py create mode 100644 src/winml/modelkit/quant/passes/rtn.py create mode 100644 tests/unit/test_quant_passes.py diff --git a/src/winml/modelkit/quant/__init__.py b/src/winml/modelkit/quant/__init__.py index bc8e6ee06..070ecf86f 100644 --- a/src/winml/modelkit/quant/__init__.py +++ b/src/winml/modelkit/quant/__init__.py @@ -7,29 +7,47 @@ Provides QDQ (Quantize-Dequantize) quantization for ONNX models. Usage: - from winml.modelkit.quant import quantize_onnx, WinMLQuantizationConfig + from winml.modelkit.quant import ( + quantize_onnx, + Quantizer, + expand_precision, + WinMLQuantizationConfig, + ) # Quick quantize with defaults (10 samples, uint8) result = quantize_onnx("model.onnx") # Custom config result = quantize_onnx("model.onnx", WinMLQuantizationConfig(samples=100)) + + # Pipeline: RTN int4 followed by FP16 (w4a16) + config = WinMLQuantizationConfig(mode="w4a16", rtn_bits=4) + result = Quantizer(expand_precision("w4a16", config)).run("model.onnx", "out.onnx") """ from typing import Any from .config import QuantizeResult, WinMLQuantizationConfig +from .passes import BaseQuantPass, FP16Pass, QDQPass, RTNPass __all__ = [ + "BaseQuantPass", + "FP16Pass", + "QDQPass", "QuantizeResult", + "Quantizer", + "RTNPass", "WinMLQuantizationConfig", + "expand_precision", "quantize_onnx", ] _LAZY_IMPORTS: dict[str, tuple[str, str]] = { "quantize_onnx": (".quantizer", "quantize_onnx"), + "Quantizer": (".quantizer", "Quantizer"), + "expand_precision": (".quantizer", "expand_precision"), } diff --git a/src/winml/modelkit/quant/passes/__init__.py b/src/winml/modelkit/quant/passes/__init__.py new file mode 100644 index 000000000..9f2910a7d --- /dev/null +++ b/src/winml/modelkit/quant/passes/__init__.py @@ -0,0 +1,18 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +"""Quantization passes sub-package.""" + +from .base import BaseQuantPass +from .fp16 import FP16Pass +from .qdq import QDQPass +from .rtn import RTNPass + + +__all__ = [ + "BaseQuantPass", + "FP16Pass", + "QDQPass", + "RTNPass", +] diff --git a/src/winml/modelkit/quant/passes/base.py b/src/winml/modelkit/quant/passes/base.py new file mode 100644 index 000000000..611c490c2 --- /dev/null +++ b/src/winml/modelkit/quant/passes/base.py @@ -0,0 +1,60 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +"""Base class for quantization passes.""" + +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING + + +if TYPE_CHECKING: + from pathlib import Path + + from ..config import QuantizeResult, WinMLQuantizationConfig + + +class BaseQuantPass(ABC): + """Abstract base class for a single quantization pass. + + Each pass is constructed with a ``WinMLQuantizationConfig`` that provides + all settings. Passes read only the fields relevant to them and ignore the + rest, so a single shared config object can be threaded through every pass + in a :class:`~winml.modelkit.quant.quantizer.Quantizer` pipeline. + + Example:: + + pass_ = FP16Pass(config) + result = pass_.run(model_path, output_path) + """ + + def __init__(self, config: WinMLQuantizationConfig) -> None: + self._config = config + + @property + def config(self) -> WinMLQuantizationConfig: + """Return the shared quantization configuration.""" + return self._config + + @abstractmethod + def run( + self, + model_path: Path, + output_path: Path, + *, + use_external_data: bool = True, + ) -> QuantizeResult: + """Run this quantization pass. + + Args: + model_path: Path to the input ONNX model. + output_path: Path where the output ONNX model should be written. + use_external_data: Whether to write large tensors as external data. + + Returns: + :class:`~winml.modelkit.quant.config.QuantizeResult` describing + the outcome of this pass. + """ + ... diff --git a/src/winml/modelkit/quant/passes/fp16.py b/src/winml/modelkit/quant/passes/fp16.py new file mode 100644 index 000000000..77cae509f --- /dev/null +++ b/src/winml/modelkit/quant/passes/fp16.py @@ -0,0 +1,87 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +"""FP16 conversion pass.""" + +from __future__ import annotations + +import logging +import time +from typing import TYPE_CHECKING + +from .base import BaseQuantPass + + +if TYPE_CHECKING: + from pathlib import Path + + from ..config import QuantizeResult, WinMLQuantizationConfig + + +logger = logging.getLogger(__name__) + + +class FP16Pass(BaseQuantPass): + """Convert an ONNX model to FP16. + + Reads from :class:`~winml.modelkit.quant.config.WinMLQuantizationConfig`: + + - ``fp16_keep_io_types`` — keep model inputs/outputs in their original dtype + - ``fp16_op_block_list`` — op types that must not be cast to FP16 + + Example:: + + pass_ = FP16Pass(config) + result = pass_.run("model.onnx", "model_fp16.onnx") + """ + + def __init__(self, config: WinMLQuantizationConfig) -> None: + super().__init__(config) + + def run( + self, + model_path: Path, + output_path: Path, + *, + use_external_data: bool = True, + ) -> QuantizeResult: + """Convert *model_path* to FP16 and write the result to *output_path*.""" + from ...onnx import load_onnx, save_onnx + from ..config import QuantizeResult + from ..fp16 import convert_to_fp16 + + if self._config.calibration_data is not None: + logger.warning( + "calibration_data is set but this is an FP16Pass" + " — calibration data will be ignored." + ) + + start_time = time.perf_counter() + errors: list[str] = [] + warnings: list[str] = [] + + logger.info("Running FP16-only conversion (no quantization)...") + model = load_onnx(model_path, validate=False) + model = convert_to_fp16( + model, + keep_io_types=self._config.fp16_keep_io_types, + op_block_list=self._config.fp16_op_block_list, + ) + output_path.parent.mkdir(parents=True, exist_ok=True) + save_onnx(model, output_path, use_external_data=use_external_data) + + total_time = time.perf_counter() - start_time + logger.info( + "FP16 conversion complete: %s -> %s (%.2fs)", + model_path.name, + output_path.name, + total_time, + ) + return QuantizeResult( + success=True, + output_path=output_path, + total_time_seconds=total_time, + errors=errors, + warnings=warnings, + ) diff --git a/src/winml/modelkit/quant/passes/qdq.py b/src/winml/modelkit/quant/passes/qdq.py new file mode 100644 index 000000000..ed2a68598 --- /dev/null +++ b/src/winml/modelkit/quant/passes/qdq.py @@ -0,0 +1,209 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +"""QDQ (Quantize-Dequantize) calibrated quantization pass.""" + +from __future__ import annotations + +import logging +import os +import time +from pathlib import Path +from typing import TYPE_CHECKING + +from .base import BaseQuantPass + + +if TYPE_CHECKING: + from ..config import QuantizeResult, WinMLQuantizationConfig + + +logger = logging.getLogger(__name__) + + +class QDQPass(BaseQuantPass): + """QDQ (static/dynamic) calibrated quantization pass. + + Reads all QDQ-relevant fields from + :class:`~winml.modelkit.quant.config.WinMLQuantizationConfig`: + ``samples``, ``calibration_method``, ``calibration_data``, ``task``, + ``model_name``, ``dataset_name``, ``weight_type``, ``activation_type``, + ``per_channel``, ``symmetric``, ``op_types_to_quantize``, + ``nodes_to_exclude``. + + Example:: + + pass_ = QDQPass(config) + result = pass_.run("model.onnx", "model_qdq.onnx") + """ + + def __init__(self, config: WinMLQuantizationConfig) -> None: + super().__init__(config) + + def run( + self, + model_path: Path, + output_path: Path, + *, + use_external_data: bool = True, + ) -> QuantizeResult: + """Apply QDQ calibrated quantization to *model_path*.""" + from onnxruntime.quantization import ( + CalibrationMethod, + QuantType, + get_qdq_config, + quantize, + ) + + from ..config import QuantizeResult + + weight_type_map = { + "uint8": QuantType.QUInt8, + "int8": QuantType.QInt8, + "uint16": QuantType.QUInt16, + "int16": QuantType.QInt16, + } + activation_type_map = { + "uint8": QuantType.QUInt8, + "int8": QuantType.QInt8, + "uint16": QuantType.QUInt16, + "int16": QuantType.QInt16, + } + calibration_method_map = { + "minmax": CalibrationMethod.MinMax, + "entropy": CalibrationMethod.Entropy, + "percentile": CalibrationMethod.Percentile, + } + + start_time = time.perf_counter() + errors: list[str] = [] + warnings: list[str] = [] + + cal_start = time.perf_counter() + + if self._config.calibration_data is not None: + data_reader = self._config.calibration_data + logger.info("Using custom calibration data") + else: + from ...datasets import DatasetCalibrationReader + + task = self._config.task or "random" + data_reader = DatasetCalibrationReader( + model_name=self._config.model_name or "random", + task=task, + max_samples=self._config.samples, + dataset_name=self._config.dataset_name, + model_path=model_path, + ) + logger.info( + "Using calibration: task=%s, samples=%d", + task, + self._config.samples, + ) + + cal_time = time.perf_counter() - cal_start + + qdq_start = time.perf_counter() + + weight_type = weight_type_map[self._config.weight_type] + activation_type = activation_type_map[self._config.activation_type] + calibrate_method = calibration_method_map[self._config.calibration_method] + + extra_options = { + "ActivationSymmetric": self._config.symmetric, + "WeightSymmetric": self._config.symmetric, + } + + logger.info("Generating QDQ config...") + qdq_config = get_qdq_config( + model_input=str(model_path), + calibration_data_reader=data_reader, + weight_type=weight_type, + activation_type=activation_type, + per_channel=self._config.per_channel, + calibrate_method=calibrate_method, + op_types_to_quantize=self._config.op_types_to_quantize, + nodes_to_exclude=self._config.nodes_to_exclude or [], + extra_options=extra_options, + ) + + from onnxruntime.quantization.quant_utils import add_pre_process_metadata + + from ...onnx import capture_metadata, load_onnx, restore_metadata, save_onnx + from ..qdq_fix import fix_qdq_dtype_info + + input_model = load_onnx(model_path, validate=False) + metadata_snapshot = capture_metadata(input_model) + add_pre_process_metadata(input_model) + + if use_external_data: + qdq_config.use_external_data_format = True + logger.info("Applying quantization...") + abs_model_output = str(Path(output_path).resolve()) + if output_path.exists(): + output_path.unlink() + stale_sidecar = output_path.parent / f"{output_path.name}.data" + if stale_sidecar.exists(): + stale_sidecar.unlink() + original_cwd = Path.cwd() + output_path.parent.mkdir(parents=True, exist_ok=True) + try: + os.chdir(output_path.parent) + quantize( + model_input=input_model, + model_output=abs_model_output, + quant_config=qdq_config, + ) + finally: + os.chdir(original_cwd) + + qdq_time = time.perf_counter() - qdq_start + + postproc_start = time.perf_counter() + + quantized_model = load_onnx(output_path, validate=False) + + logger.info("Fixing QDQ node dtype info...") + fix_result = fix_qdq_dtype_info(quantized_model) + warnings.extend(fix_result.warnings) + + from ...onnx import infer_shapes + + logger.info("Running shape inference on quantized model...") + quantized_model = infer_shapes(quantized_model) + + if metadata_snapshot.node_count > 0: + logger.info("Restoring metadata from pre-quantization model...") + restore_metadata(quantized_model, metadata_snapshot) + + postproc_time = time.perf_counter() - postproc_start + + from ...compiler import QDQ_OP_TYPES + + nodes_quantized = sum( + 1 for node in quantized_model.graph.node if node.op_type in QDQ_OP_TYPES + ) + + save_onnx(quantized_model, output_path) + + total_time = time.perf_counter() - start_time + + logger.info( + "Quantization complete: %s -> %s (%.2fs)", + model_path.name, + output_path.name, + total_time, + ) + + return QuantizeResult( + success=True, + output_path=output_path, + calibration_time_seconds=cal_time, + qdq_insertion_time_seconds=qdq_time, + postproc_time_seconds=postproc_time, + total_time_seconds=total_time, + nodes_quantized=nodes_quantized, + errors=errors, + warnings=warnings, + ) diff --git a/src/winml/modelkit/quant/passes/rtn.py b/src/winml/modelkit/quant/passes/rtn.py new file mode 100644 index 000000000..80b12c9b0 --- /dev/null +++ b/src/winml/modelkit/quant/passes/rtn.py @@ -0,0 +1,106 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +"""RTN (Round-To-Nearest) weight-only quantization pass.""" + +from __future__ import annotations + +import logging +import time +from typing import TYPE_CHECKING + +from .base import BaseQuantPass + + +if TYPE_CHECKING: + from pathlib import Path + + from ..config import QuantizeResult, WinMLQuantizationConfig + + +logger = logging.getLogger(__name__) + + +class RTNPass(BaseQuantPass): + """RTN weight-only quantization pass. + + Reads from :class:`~winml.modelkit.quant.config.WinMLQuantizationConfig`: + + - ``rtn_bits`` — quantization bit-width (default 4) + - ``rtn_block_size`` — block size for quantization (default 128) + - ``rtn_symmetric`` — symmetric quantization (default True) + - ``rtn_accuracy_level`` — ORT accuracy level 0-4 (0 = disabled) + - ``nodes_to_exclude`` — node names to skip + + Example:: + + pass_ = RTNPass(config) + result = pass_.run("model.onnx", "model_rtn.onnx") + """ + + def __init__(self, config: WinMLQuantizationConfig) -> None: + super().__init__(config) + + def run( + self, + model_path: Path, + output_path: Path, + *, + use_external_data: bool = True, + ) -> QuantizeResult: + """Apply RTN weight-only quantization to *model_path*.""" + from onnxruntime.quantization.matmul_nbits_quantizer import MatMulNBitsQuantizer + + from ...onnx import save_onnx + from ..config import QuantizeResult + + if self._config.calibration_data is not None: + logger.warning( + "calibration_data is set but this is an RTNPass — calibration data will be ignored." + ) + + start_time = time.perf_counter() + errors: list[str] = [] + warnings: list[str] = [] + + logger.info( + "Running RTN %d-bit weight-only quantization (block_size=%d, symmetric=%s)...", + self._config.rtn_bits, + self._config.rtn_block_size, + self._config.rtn_symmetric, + ) + + accuracy_level = ( + self._config.rtn_accuracy_level if self._config.rtn_accuracy_level != 0 else None + ) + + quantizer = MatMulNBitsQuantizer( + model=str(model_path), + bits=self._config.rtn_bits, + block_size=self._config.rtn_block_size, + is_symmetric=self._config.rtn_symmetric, + accuracy_level=accuracy_level, + nodes_to_exclude=self._config.nodes_to_exclude, + ) + quantizer.process() + + output_path.parent.mkdir(parents=True, exist_ok=True) + quantized_model = quantizer.model.model + + save_onnx(quantized_model, output_path, use_external_data=use_external_data) + + total_time = time.perf_counter() - start_time + logger.info( + "RTN quantization complete: %s -> %s (%.2fs)", + model_path.name, + output_path.name, + total_time, + ) + return QuantizeResult( + success=True, + output_path=output_path, + total_time_seconds=total_time, + errors=errors, + warnings=warnings, + ) diff --git a/src/winml/modelkit/quant/quantizer.py b/src/winml/modelkit/quant/quantizer.py index 8b9ba033f..bd45cb0b9 100644 --- a/src/winml/modelkit/quant/quantizer.py +++ b/src/winml/modelkit/quant/quantizer.py @@ -7,398 +7,256 @@ from __future__ import annotations import logging -import os -import time +import tempfile +import traceback from pathlib import Path -from typing import TYPE_CHECKING, Any +from typing import Any from .config import QuantizeResult, WinMLQuantizationConfig - - -if TYPE_CHECKING: - from collections.abc import Callable +from .passes import BaseQuantPass, FP16Pass, QDQPass, RTNPass logger = logging.getLogger(__name__) +# Precision strings that expand to multiple sequential passes. +_COMPOSITE_PRECISIONS: dict[str, list[str]] = { + "w4a16": ["rtn", "fp16"], +} -def quantize_onnx( - model_path: str | Path, - output_path: str | Path | None = None, + +def expand_precision( + mode: str, config: WinMLQuantizationConfig | None = None, - **kwargs: Any, -) -> QuantizeResult: - """Quantize an ONNX model using a single quantization pass. +) -> list[BaseQuantPass]: + """Expand a precision string into an ordered list of quantization passes. - The quantization mode is driven by ``config.mode``: - - "fp16": FP16 conversion (no quantization) - - "rtn": RTN weight-only quantization - - "static"/"dynamic": QDQ calibrated quantization + All passes share the same ``config`` so every pass can read the fields + relevant to it. + + Supported values: - Note: Composite precisions like "w4a16" (requiring multiple sequential - passes) are not yet supported here — see #964 for the planned - Quantizer pipeline that will handle multi-pass orchestration. + ========= ======================= + mode passes + ========= ======================= + ``fp16`` ``[FP16Pass(config)]`` + ``rtn`` ``[RTNPass(config)]`` + ``static`` ``[QDQPass(config)]`` + ``dynamic`` ``[QDQPass(config)]`` + ``w4a16`` ``[RTNPass(config), FP16Pass(config)]`` + ========= ======================= Args: - model_path: Path to input ONNX model. - output_path: Path for output model (defaults to {model_stem}_qdq.onnx). - config: Quantization configuration (uses defaults if None). + mode: Precision string (e.g. ``"w4a16"``). + config: Shared quantization configuration. If *None*, a default + :class:`WinMLQuantizationConfig` is used. Returns: - QuantizeResult with path to final output model and metrics. - - Examples: - # Single-pass RTN int4 - result = quantize_onnx("model.onnx", config=WinMLQuantizationConfig(mode="rtn")) + Ordered list of :class:`~winml.modelkit.quant.passes.BaseQuantPass` + instances ready to be executed by :class:`Quantizer`. - # Single-pass FP16 only - result = quantize_onnx("model.onnx", config=WinMLQuantizationConfig(mode="fp16")) - - # QDQ with defaults - result = quantize_onnx("model.onnx", config=WinMLQuantizationConfig(samples=100)) + Raises: + ValueError: If *mode* is not recognised. """ - model_path = Path(model_path) config = config or WinMLQuantizationConfig() - if output_path is not None: - output_path = Path(output_path) - else: - output_path = model_path.parent / f"{model_path.stem}_qdq.onnx" - - return _quantize_single_pass( - model_path=model_path, - output_path=output_path, - config=config, - **kwargs, - ) - + _pass_factories: dict[str, BaseQuantPass] = { + "fp16": FP16Pass(config), + "rtn": RTNPass(config), + "static": QDQPass(config), + "dynamic": QDQPass(config), + } -def _quantize_single_pass( - *, - model_path: Path, - output_path: Path, - config: WinMLQuantizationConfig, - **kwargs: Any, -) -> QuantizeResult: - """Run a single quantization pass (FP16, RTN, or QDQ). + if mode in _pass_factories: + return [_pass_factories[mode]] - This is the internal workhorse — callers should use ``quantize_onnx()`` - which handles multi-pass expansion and path resolution. - """ - use_external_data: bool = kwargs.pop("use_external_data", True) + if mode in _COMPOSITE_PRECISIONS: + return [_pass_factories[step] for step in _COMPOSITE_PRECISIONS[mode]] - start_time = time.perf_counter() - - # Validate input - if not model_path.exists(): - return QuantizeResult( - success=False, - output_path=None, - errors=[f"Model not found: {model_path}"], - ) - - errors: list[str] = [] - warnings: list[str] = [] - - try: - # Dispatch to the appropriate single-mode handler - _mode_handlers: dict[str, Callable[..., QuantizeResult]] = { - "fp16": _quantize_fp16, - "rtn": _quantize_rtn, - } - handler = _mode_handlers.get(config.mode, _quantize_qdq) - return handler( - model_path=model_path, - output_path=output_path, - config=config, - start_time=start_time, - use_external_data=use_external_data, - errors=errors, - warnings=warnings, - ) - - except Exception: - total_time = time.perf_counter() - start_time - logger.exception("Quantization failed") - - import traceback - - return QuantizeResult( - success=False, - output_path=None, - total_time_seconds=total_time, - errors=[traceback.format_exc()], - warnings=warnings, - ) - - -def _quantize_fp16( - *, - model_path: Path, - output_path: Path, - config: WinMLQuantizationConfig, - start_time: float, - use_external_data: bool, - errors: list[str], - warnings: list[str], -) -> QuantizeResult: - """Run FP16 conversion (no quantization).""" - from ..onnx import load_onnx, save_onnx - from .fp16 import convert_to_fp16 - - if config.calibration_data is not None: - logger.warning( - "calibration_data is set but mode='fp16' — calibration data will be ignored." - ) - - logger.info("Running FP16-only conversion (no quantization)...") - model = load_onnx(model_path, validate=False) - model = convert_to_fp16( - model, - keep_io_types=config.fp16_keep_io_types, - op_block_list=config.fp16_op_block_list, - ) - output_path.parent.mkdir(parents=True, exist_ok=True) - save_onnx(model, output_path, use_external_data=use_external_data) - - total_time = time.perf_counter() - start_time - logger.info( - "FP16 conversion complete: %s -> %s (%.2fs)", - model_path.name, - output_path.name, - total_time, - ) - return QuantizeResult( - success=True, - output_path=output_path, - total_time_seconds=total_time, - errors=errors, - warnings=warnings, + raise ValueError( + f"Unknown precision mode {mode!r}. " + f"Valid values: {sorted(_pass_factories) + sorted(_COMPOSITE_PRECISIONS)}" ) -def _quantize_rtn( - *, - model_path: Path, - output_path: Path, - config: WinMLQuantizationConfig, - start_time: float, - use_external_data: bool, - errors: list[str], - warnings: list[str], -) -> QuantizeResult: - """Run RTN weight-only quantization.""" - from onnxruntime.quantization.matmul_nbits_quantizer import MatMulNBitsQuantizer +class Quantizer: + """Orchestrate a sequential pipeline of quantization passes. - from ..onnx import save_onnx + Each pass receives the output of the previous pass as its input. For a + single-pass pipeline no temporary files are created. For multi-pass + pipelines, intermediate models are written to a ``tempfile.TemporaryDirectory`` + that is cleaned up automatically on success *or* failure. - if config.calibration_data is not None: - logger.warning("calibration_data is set but mode='rtn' — calibration data will be ignored.") + :class:`QuantizeResult` fields are merged across passes: - logger.info( - "Running RTN %d-bit weight-only quantization (block_size=%d, symmetric=%s)...", - config.rtn_bits, - config.rtn_block_size, - config.rtn_symmetric, - ) + - ``success`` — logical AND of all pass results + - ``output_path`` — path written by the final pass + - Timing fields — summed across passes + - ``nodes_quantized`` — summed across passes + - ``errors`` / ``warnings`` — concatenated across passes - accuracy_level = config.rtn_accuracy_level if config.rtn_accuracy_level != 0 else None + Example:: - quantizer = MatMulNBitsQuantizer( - model=str(model_path), - bits=config.rtn_bits, - block_size=config.rtn_block_size, - is_symmetric=config.rtn_symmetric, - accuracy_level=accuracy_level, - nodes_to_exclude=config.nodes_to_exclude, - ) - quantizer.process() + from winml.modelkit.quant import Quantizer, expand_precision, WinMLQuantizationConfig - output_path.parent.mkdir(parents=True, exist_ok=True) - quantized_model = quantizer.model.model + config = WinMLQuantizationConfig(mode="w4a16", rtn_bits=4) + quantizer = Quantizer(expand_precision("w4a16", config)) + result = quantizer.run("model.onnx", "model_w4a16.onnx") + """ - save_onnx(quantized_model, output_path, use_external_data=use_external_data) + def __init__(self, passes: list[BaseQuantPass]) -> None: + if not passes: + raise ValueError("Quantizer requires at least one pass.") + self._passes = passes + + @property + def passes(self) -> list[BaseQuantPass]: + """Return a copy of the pass list.""" + return list(self._passes) + + def run( + self, + model_path: str | Path, + output_path: str | Path, + *, + use_external_data: bool = True, + ) -> QuantizeResult: + """Run the quantization pipeline. + + Args: + model_path: Path to the input ONNX model. + output_path: Path for the final output model. + use_external_data: Whether to write large tensors as external data. + + Returns: + Merged :class:`~winml.modelkit.quant.config.QuantizeResult`. + """ + model_path = Path(model_path) + output_path = Path(output_path) - total_time = time.perf_counter() - start_time - logger.info( - "RTN quantization complete: %s -> %s (%.2fs)", - model_path.name, - output_path.name, - total_time, - ) + if not model_path.exists(): + return QuantizeResult( + success=False, + output_path=None, + errors=[f"Model not found: {model_path}"], + ) + + if len(self._passes) == 1: + return self._run_pass(self._passes[0], model_path, output_path, use_external_data) + + return self._run_multi_pass(model_path, output_path, use_external_data) + + def _run_pass( + self, + pass_: BaseQuantPass, + model_path: Path, + output_path: Path, + use_external_data: bool, + ) -> QuantizeResult: + try: + return pass_.run(model_path, output_path, use_external_data=use_external_data) + except Exception: + logger.exception("Pass %s failed", type(pass_).__name__) + return QuantizeResult( + success=False, + output_path=None, + errors=[traceback.format_exc()], + ) + + def _run_multi_pass( + self, + model_path: Path, + output_path: Path, + use_external_data: bool, + ) -> QuantizeResult: + accumulated = QuantizeResult(success=True, output_path=None) + + with tempfile.TemporaryDirectory(prefix="winml_quant_") as tmp_dir: + current_input = model_path + + for i, pass_ in enumerate(self._passes): + is_last = i == len(self._passes) - 1 + if is_last: + current_output = output_path + else: + current_output = Path(tmp_dir) / f"pass_{i}_{type(pass_).__name__}.onnx" + + logger.info( + "Pass %d/%d: %s %s -> %s", + i + 1, + len(self._passes), + type(pass_).__name__, + current_input.name, + current_output.name, + ) + + result = self._run_pass(pass_, current_input, current_output, use_external_data) + accumulated = _merge_results(accumulated, result) + + if not result.success: + logger.error("Pass %s failed — aborting pipeline.", type(pass_).__name__) + break + + current_input = current_output + + return accumulated + + +def _merge_results(base: QuantizeResult, new: QuantizeResult) -> QuantizeResult: + """Merge two QuantizeResult objects, accumulating stats.""" return QuantizeResult( - success=True, - output_path=output_path, - total_time_seconds=total_time, - errors=errors, - warnings=warnings, + success=base.success and new.success, + output_path=new.output_path if new.output_path is not None else base.output_path, + calibration_path=new.calibration_path or base.calibration_path, + calibration_time_seconds=base.calibration_time_seconds + new.calibration_time_seconds, + qdq_insertion_time_seconds=base.qdq_insertion_time_seconds + new.qdq_insertion_time_seconds, + postproc_time_seconds=base.postproc_time_seconds + new.postproc_time_seconds, + total_time_seconds=base.total_time_seconds + new.total_time_seconds, + nodes_quantized=base.nodes_quantized + new.nodes_quantized, + nodes_skipped=base.nodes_skipped + new.nodes_skipped, + errors=base.errors + new.errors, + warnings=base.warnings + new.warnings, ) -def _quantize_qdq( - *, - model_path: Path, - output_path: Path, - config: WinMLQuantizationConfig, - start_time: float, - use_external_data: bool, - errors: list[str], - warnings: list[str], +def quantize_onnx( + model_path: str | Path, + output_path: str | Path | None = None, + config: WinMLQuantizationConfig | None = None, + **kwargs: Any, ) -> QuantizeResult: - """Run QDQ (static/dynamic) calibrated quantization.""" - from onnxruntime.quantization import ( - CalibrationMethod, - QuantType, - get_qdq_config, - quantize, - ) - - weight_type_map = { - "uint8": QuantType.QUInt8, - "int8": QuantType.QInt8, - "uint16": QuantType.QUInt16, - "int16": QuantType.QInt16, - } - activation_type_map = { - "uint8": QuantType.QUInt8, - "int8": QuantType.QInt8, - "uint16": QuantType.QUInt16, - "int16": QuantType.QInt16, - } - calibration_method_map = { - "minmax": CalibrationMethod.MinMax, - "entropy": CalibrationMethod.Entropy, - "percentile": CalibrationMethod.Percentile, - } - - cal_start = time.perf_counter() - - if config.calibration_data is not None: - data_reader = config.calibration_data - logger.info("Using custom calibration data") - else: - from ..datasets import DatasetCalibrationReader - - task = config.task or "random" - data_reader = DatasetCalibrationReader( - model_name=config.model_id or "random", - task=task, - max_samples=config.samples, - dataset_name=config.dataset_name, - model_path=model_path, - ) - logger.info( - "Using calibration: task=%s, samples=%d", - task, - config.samples, - ) - - cal_time = time.perf_counter() - cal_start - - qdq_start = time.perf_counter() - - weight_type = weight_type_map[config.weight_type] - activation_type = activation_type_map[config.activation_type] - calibrate_method = calibration_method_map[config.calibration_method] - - extra_options = { - "ActivationSymmetric": config.symmetric, - "WeightSymmetric": config.symmetric, - } - - logger.info("Generating QDQ config...") - qdq_config = get_qdq_config( - model_input=str(model_path), - calibration_data_reader=data_reader, - weight_type=weight_type, - activation_type=activation_type, - per_channel=config.per_channel, - calibrate_method=calibrate_method, - op_types_to_quantize=config.op_types_to_quantize, - nodes_to_exclude=config.nodes_to_exclude or [], - extra_options=extra_options, - ) - - # Load the input model, capture its metadata snapshot (ORT rebuilds the - # graph during quantization, so we restore afterwards), and tag it as - # pre-processed so quantize_static() does not emit a warning. - from onnxruntime.quantization.quant_utils import add_pre_process_metadata - - from ..onnx import capture_metadata, load_onnx, restore_metadata, save_onnx - from .qdq_fix import fix_qdq_dtype_info - - input_model = load_onnx(model_path, validate=False) - metadata_snapshot = capture_metadata(input_model) - add_pre_process_metadata(input_model) - - if use_external_data: - qdq_config.use_external_data_format = True - logger.info("Applying quantization...") - # Temporarily change CWD to output directory so ORT's save_model_to_file() - # resolves its CWD-relative os.path.exists() check correctly. - abs_model_output = str(Path(output_path).resolve()) - # Remove stale output artifacts from a previous build - if output_path.exists(): - output_path.unlink() - stale_sidecar = output_path.parent / f"{output_path.name}.data" - if stale_sidecar.exists(): - stale_sidecar.unlink() - original_cwd = Path.cwd() - try: - os.chdir(output_path.parent) - quantize( - model_input=input_model, - model_output=abs_model_output, - quant_config=qdq_config, - ) - finally: - os.chdir(original_cwd) - - qdq_time = time.perf_counter() - qdq_start - - # Post-processing: fix QDQ dtype + shape inference + restore metadata - postproc_start = time.perf_counter() - - quantized_model = load_onnx(output_path, validate=False) + """Quantize an ONNX model. - logger.info("Fixing QDQ node dtype info...") - fix_result = fix_qdq_dtype_info(quantized_model) - warnings.extend(fix_result.warnings) + Backward-compatible entry point. Internally builds a :class:`Quantizer` + pipeline from ``config.mode`` via :func:`expand_precision`. - from ..onnx import infer_shapes - - logger.info("Running shape inference on quantized model...") - quantized_model = infer_shapes(quantized_model) - - if metadata_snapshot.node_count > 0: - logger.info("Restoring metadata from pre-quantization model...") - restore_metadata(quantized_model, metadata_snapshot) - - postproc_time = time.perf_counter() - postproc_start + The quantization mode is driven by ``config.mode``: - from ..compiler import QDQ_OP_TYPES + - ``"fp16"`` — FP16 conversion (no quantization) + - ``"rtn"`` — RTN weight-only quantization + - ``"static"`` / ``"dynamic"`` — QDQ calibrated quantization + - ``"w4a16"`` — RTN int4 followed by FP16 conversion - nodes_quantized = sum(1 for node in quantized_model.graph.node if node.op_type in QDQ_OP_TYPES) + Args: + model_path: Path to input ONNX model. + output_path: Path for output model (defaults to ``{model_stem}_qdq.onnx``). + config: Quantization configuration (uses defaults if *None*). - save_onnx(quantized_model, output_path) + Returns: + :class:`QuantizeResult` with path to final output model and metrics. - total_time = time.perf_counter() - start_time + Examples: + >>> result = quantize_onnx("model.onnx", config=WinMLQuantizationConfig(mode="rtn")) + >>> result = quantize_onnx("model.onnx", config=WinMLQuantizationConfig(mode="fp16")) + >>> result = quantize_onnx("model.onnx", config=WinMLQuantizationConfig(mode="w4a16")) + """ + model_path = Path(model_path) + config = config or WinMLQuantizationConfig() - logger.info( - "Quantization complete: %s -> %s (%.2fs)", - model_path.name, - output_path.name, - total_time, - ) + if output_path is not None: + output_path = Path(output_path) + else: + output_path = model_path.parent / f"{model_path.stem}_qdq.onnx" - return QuantizeResult( - success=True, - output_path=output_path, - calibration_time_seconds=cal_time, - qdq_insertion_time_seconds=qdq_time, - postproc_time_seconds=postproc_time, - total_time_seconds=total_time, - nodes_quantized=nodes_quantized, - errors=errors, - warnings=warnings, - ) + use_external_data: bool = kwargs.pop("use_external_data", True) + passes = expand_precision(config.mode, config) + return Quantizer(passes).run(model_path, output_path, use_external_data=use_external_data) diff --git a/tests/unit/test_quant_passes.py b/tests/unit/test_quant_passes.py new file mode 100644 index 000000000..774c2b510 --- /dev/null +++ b/tests/unit/test_quant_passes.py @@ -0,0 +1,398 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +"""Unit tests for quantization passes and the Quantizer pipeline.""" + +from __future__ import annotations + +import sys +from types import ModuleType, SimpleNamespace +from typing import TYPE_CHECKING, Any + +import pytest + +from winml.modelkit.quant import WinMLQuantizationConfig +from winml.modelkit.quant.config import QuantizeResult +from winml.modelkit.quant.passes import BaseQuantPass, FP16Pass, QDQPass, RTNPass + + +if TYPE_CHECKING: + from pathlib import Path + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _ok_result(output_path: Path, *, total_time: float = 0.1) -> QuantizeResult: + return QuantizeResult(success=True, output_path=output_path, total_time_seconds=total_time) + + +def _fail_result() -> QuantizeResult: + return QuantizeResult(success=False, output_path=None, errors=["boom"]) + + +class _StubPass(BaseQuantPass): + """Configurable stub pass for Quantizer pipeline tests.""" + + def __init__(self, config: WinMLQuantizationConfig, *, succeed: bool = True) -> None: + super().__init__(config) + self.called_with: list[tuple[Path, Path]] = [] + self._succeed = succeed + + def run( + self, + model_path: Path, + output_path: Path, + *, + use_external_data: bool = True, + ) -> QuantizeResult: + self.called_with.append((model_path, output_path)) + if self._succeed: + output_path.parent.mkdir(parents=True, exist_ok=True) + output_path.write_text("model") + return _ok_result(output_path) + return _fail_result() + + +# --------------------------------------------------------------------------- +# expand_precision +# --------------------------------------------------------------------------- + + +class TestExpandPrecision: + def test_fp16_returns_fp16_pass(self) -> None: + from winml.modelkit.quant.quantizer import expand_precision + + config = WinMLQuantizationConfig(mode="fp16") + passes = expand_precision("fp16", config) + assert len(passes) == 1 + assert isinstance(passes[0], FP16Pass) + assert passes[0].config is config + + def test_rtn_returns_rtn_pass(self) -> None: + from winml.modelkit.quant.quantizer import expand_precision + + config = WinMLQuantizationConfig(mode="rtn") + passes = expand_precision("rtn", config) + assert len(passes) == 1 + assert isinstance(passes[0], RTNPass) + assert passes[0].config is config + + def test_static_returns_qdq_pass(self) -> None: + from winml.modelkit.quant.quantizer import expand_precision + + config = WinMLQuantizationConfig(mode="static") + passes = expand_precision("static", config) + assert len(passes) == 1 + assert isinstance(passes[0], QDQPass) + + def test_dynamic_returns_qdq_pass(self) -> None: + from winml.modelkit.quant.quantizer import expand_precision + + config = WinMLQuantizationConfig(mode="static") + passes = expand_precision("dynamic", config) + assert len(passes) == 1 + assert isinstance(passes[0], QDQPass) + + def test_w4a16_returns_rtn_then_fp16(self) -> None: + from winml.modelkit.quant.quantizer import expand_precision + + config = WinMLQuantizationConfig(mode="rtn", rtn_bits=4) + passes = expand_precision("w4a16", config) + assert len(passes) == 2 + assert isinstance(passes[0], RTNPass) + assert isinstance(passes[1], FP16Pass) + # All passes share the same config object + assert passes[0].config is config + assert passes[1].config is config + + def test_unknown_mode_raises(self) -> None: + from winml.modelkit.quant.quantizer import expand_precision + + with pytest.raises(ValueError, match="Unknown precision mode"): + expand_precision("int8_only") + + def test_none_config_uses_default(self) -> None: + from winml.modelkit.quant.quantizer import expand_precision + + passes = expand_precision("fp16") + assert len(passes) == 1 + assert isinstance(passes[0].config, WinMLQuantizationConfig) + + +# --------------------------------------------------------------------------- +# Quantizer — single pass +# --------------------------------------------------------------------------- + + +class TestQuantizerSinglePass: + def test_single_pass_calls_run_with_correct_paths(self, tmp_path: Path) -> None: + from winml.modelkit.quant.quantizer import Quantizer + + config = WinMLQuantizationConfig() + model = tmp_path / "model.onnx" + model.write_text("x") + out = tmp_path / "out.onnx" + + stub = _StubPass(config) + result = Quantizer([stub]).run(model, out) + + assert result.success + assert result.output_path == out + assert stub.called_with == [(model, out)] + + def test_single_pass_missing_model_returns_failure(self, tmp_path: Path) -> None: + from winml.modelkit.quant.quantizer import Quantizer + + config = WinMLQuantizationConfig() + stub = _StubPass(config) + result = Quantizer([stub]).run(tmp_path / "missing.onnx", tmp_path / "out.onnx") + + assert not result.success + assert "not found" in result.errors[0].lower() + assert stub.called_with == [] + + def test_single_pass_exception_returns_failure(self, tmp_path: Path) -> None: + from winml.modelkit.quant.quantizer import Quantizer + + config = WinMLQuantizationConfig() + model = tmp_path / "model.onnx" + model.write_text("x") + + class _ExplodingPass(BaseQuantPass): + def run(self, model_path, output_path, *, use_external_data=True): + raise RuntimeError("kaboom") + + result = Quantizer([_ExplodingPass(config)]).run(model, tmp_path / "out.onnx") + assert not result.success + assert any("kaboom" in e for e in result.errors) + + def test_empty_passes_raises(self) -> None: + from winml.modelkit.quant.quantizer import Quantizer + + with pytest.raises(ValueError, match="at least one pass"): + Quantizer([]) + + +# --------------------------------------------------------------------------- +# Quantizer — multi-pass chaining +# --------------------------------------------------------------------------- + + +class TestQuantizerMultiPass: + def test_multi_pass_chains_input_output(self, tmp_path: Path) -> None: + from winml.modelkit.quant.quantizer import Quantizer + + config = WinMLQuantizationConfig() + model = tmp_path / "model.onnx" + model.write_text("x") + final_out = tmp_path / "final.onnx" + + p1 = _StubPass(config) + p2 = _StubPass(config) + result = Quantizer([p1, p2]).run(model, final_out) + + assert result.success + assert result.output_path == final_out + + # p1 receives the original model; p2 receives p1's output (a temp file) + p1_input, p1_output = p1.called_with[0] + p2_input, p2_output = p2.called_with[0] + + assert p1_input == model + assert p1_output != final_out # intermediate temp file + assert p2_input == p1_output # chained correctly + assert p2_output == final_out + + def test_multi_pass_stats_are_merged(self, tmp_path: Path) -> None: + from winml.modelkit.quant.quantizer import Quantizer + + config = WinMLQuantizationConfig() + model = tmp_path / "model.onnx" + model.write_text("x") + + class _TimedPass(BaseQuantPass): + def __init__(self, cfg, *, nodes: int, time_s: float) -> None: + super().__init__(cfg) + self._nodes = nodes + self._time = time_s + + def run(self, model_path, output_path, *, use_external_data=True): + output_path.parent.mkdir(parents=True, exist_ok=True) + output_path.write_text("x") + return QuantizeResult( + success=True, + output_path=output_path, + nodes_quantized=self._nodes, + total_time_seconds=self._time, + ) + + p1 = _TimedPass(config, nodes=10, time_s=1.0) + p2 = _TimedPass(config, nodes=5, time_s=2.0) + result = Quantizer([p1, p2]).run(model, tmp_path / "out.onnx") + + assert result.nodes_quantized == 15 + assert abs(result.total_time_seconds - 3.0) < 1e-9 + + def test_multi_pass_aborts_on_failure(self, tmp_path: Path) -> None: + from winml.modelkit.quant.quantizer import Quantizer + + config = WinMLQuantizationConfig() + model = tmp_path / "model.onnx" + model.write_text("x") + + p1 = _StubPass(config, succeed=False) + p2 = _StubPass(config, succeed=True) + + result = Quantizer([p1, p2]).run(model, tmp_path / "out.onnx") + + assert not result.success + assert p2.called_with == [] # p2 never called + + def test_multi_pass_warnings_concatenated(self, tmp_path: Path) -> None: + from winml.modelkit.quant.quantizer import Quantizer + + config = WinMLQuantizationConfig() + model = tmp_path / "model.onnx" + model.write_text("x") + + class _WarnPass(BaseQuantPass): + def __init__(self, cfg, msg: str) -> None: + super().__init__(cfg) + self._msg = msg + + def run(self, model_path, output_path, *, use_external_data=True): + output_path.parent.mkdir(parents=True, exist_ok=True) + output_path.write_text("x") + return QuantizeResult(success=True, output_path=output_path, warnings=[self._msg]) + + result = Quantizer( + [ + _WarnPass(config, "w1"), + _WarnPass(config, "w2"), + ] + ).run(model, tmp_path / "out.onnx") + + assert result.warnings == ["w1", "w2"] + + +# --------------------------------------------------------------------------- +# FP16Pass — config field wiring +# --------------------------------------------------------------------------- + + +class TestFP16PassConfig: + def test_reads_fp16_fields_from_config( + self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch + ) -> None: + """FP16Pass should pass fp16_keep_io_types and fp16_op_block_list to convert_to_fp16.""" + config = WinMLQuantizationConfig( + mode="fp16", + fp16_keep_io_types=False, + fp16_op_block_list=["Gather"], + ) + model_path = tmp_path / "model.onnx" + model_path.write_text("x") + output_path = tmp_path / "out.onnx" + + calls: list[dict] = [] + fake_model = SimpleNamespace() + + def fake_convert(model, *, keep_io_types, op_block_list): + calls.append({"keep_io_types": keep_io_types, "op_block_list": op_block_list}) + return model + + # Patch the source modules that are lazily imported inside run() + fake_onnx_mod = ModuleType("winml.modelkit.onnx") + fake_onnx_mod.load_onnx = lambda *a, **k: fake_model # type: ignore[attr-defined] + fake_onnx_mod.save_onnx = lambda *a, **k: None # type: ignore[attr-defined] + monkeypatch.setitem(sys.modules, "winml.modelkit.onnx", fake_onnx_mod) + + fake_fp16_mod = ModuleType("winml.modelkit.quant.fp16") + fake_fp16_mod.convert_to_fp16 = fake_convert # type: ignore[attr-defined] + monkeypatch.setitem(sys.modules, "winml.modelkit.quant.fp16", fake_fp16_mod) + + result = FP16Pass(config).run(model_path, output_path) + + assert result.success + assert calls == [{"keep_io_types": False, "op_block_list": ["Gather"]}] + + +# --------------------------------------------------------------------------- +# RTNPass — config field wiring +# --------------------------------------------------------------------------- + + +def _install_fake_ort_nbits( + monkeypatch: pytest.MonkeyPatch, + fake_quantized_model: Any, + init_kwargs: list[dict], +) -> None: + """Install a minimal fake MatMulNBitsQuantizer into sys.modules.""" + + class FakeMatMulNBitsQuantizer: + def __init__(self, **kwargs: Any) -> None: + init_kwargs.append(kwargs) + + def process(self) -> None: + pass + + model = SimpleNamespace(model=fake_quantized_model) + + fake_ort_quant = ModuleType("onnxruntime.quantization.matmul_nbits_quantizer") + fake_ort_quant.MatMulNBitsQuantizer = FakeMatMulNBitsQuantizer # type: ignore[attr-defined] + monkeypatch.setitem( + sys.modules, + "onnxruntime.quantization.matmul_nbits_quantizer", + fake_ort_quant, + ) + + fake_onnx_mod = ModuleType("winml.modelkit.onnx") + fake_onnx_mod.save_onnx = lambda *a, **k: None # type: ignore[attr-defined] + monkeypatch.setitem(sys.modules, "winml.modelkit.onnx", fake_onnx_mod) + + +class TestRTNPassConfig: + def test_reads_rtn_fields_from_config( + self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch + ) -> None: + """RTNPass should forward all rtn_* fields to MatMulNBitsQuantizer.""" + config = WinMLQuantizationConfig( + mode="rtn", + rtn_bits=8, + rtn_block_size=64, + rtn_symmetric=False, + rtn_accuracy_level=2, + ) + model_path = tmp_path / "model.onnx" + model_path.write_text("x") + output_path = tmp_path / "out.onnx" + + init_kwargs: list[dict] = [] + fake_quantized_model = SimpleNamespace(graph=SimpleNamespace(node=[])) + _install_fake_ort_nbits(monkeypatch, fake_quantized_model, init_kwargs) + + result = RTNPass(config).run(model_path, output_path) + + assert result.success + assert init_kwargs[0]["bits"] == 8 + assert init_kwargs[0]["block_size"] == 64 + assert init_kwargs[0]["is_symmetric"] is False + assert init_kwargs[0]["accuracy_level"] == 2 + + def test_accuracy_level_zero_maps_to_none( + self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch + ) -> None: + config = WinMLQuantizationConfig(mode="rtn", rtn_accuracy_level=0) + model_path = tmp_path / "model.onnx" + model_path.write_text("x") + + init_kwargs: list[dict] = [] + fake_quantized_model = SimpleNamespace() + _install_fake_ort_nbits(monkeypatch, fake_quantized_model, init_kwargs) + + RTNPass(config).run(model_path, tmp_path / "out.onnx") + assert init_kwargs[0]["accuracy_level"] is None From 4d117c0f0fa3f461baaf8391ff2eb760125c3652 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Fri, 26 Jun 2026 13:41:13 +0800 Subject: [PATCH 2/6] fix(quant): address review findings - QDQPass.run(): forward use_external_data to final save_onnx call - WinMLQuantizationConfig: add 'w4a16' to mode Literal; to_dict() now serialises rtn_* and fp16_* fields when mode is 'w4a16' - quantize_onnx(): raise TypeError on unrecognised kwargs instead of silently discarding them - Tests: add TestW4a16Config (3 cases) and TestQuantizeOnnxKwargsGuard (1 case) --- src/winml/modelkit/quant/config.py | 11 ++++-- src/winml/modelkit/quant/passes/qdq.py | 2 +- src/winml/modelkit/quant/quantizer.py | 2 + tests/unit/test_quant_passes.py | 52 +++++++++++++++++++++++++- 4 files changed, 62 insertions(+), 5 deletions(-) diff --git a/src/winml/modelkit/quant/config.py b/src/winml/modelkit/quant/config.py index 3be32ed0a..cbf92281e 100644 --- a/src/winml/modelkit/quant/config.py +++ b/src/winml/modelkit/quant/config.py @@ -51,14 +51,19 @@ class WinMLQuantizationConfig: # FP16 conversion (pure FP16, no quantization) config = WinMLQuantizationConfig(mode="fp16") result = quantize_onnx("model.onnx", config) + + # RTN int4 + FP16 (w4a16) + config = WinMLQuantizationConfig(mode="w4a16", rtn_bits=4) + result = quantize_onnx("model.onnx", config) """ # Quantization mode - mode: Literal["static", "dynamic", "rtn", "fp16"] = "static" + mode: Literal["static", "dynamic", "rtn", "fp16", "w4a16"] = "static" # "static" — Calibrated QDQ quantization (requires calibration data) # "dynamic" — Dynamic quantization (no calibration) [planned, not yet wired] # "rtn" — Round-To-Nearest weight-only (no calibration, block-wise) # "fp16" — Pure FP16 conversion only (no quantization) + # "w4a16" — RTN int4 weight quantization followed by FP16 conversion # Calibration settings (static/dynamic) samples: int = 10 @@ -135,12 +140,12 @@ def to_dict(self) -> dict: result["model_id"] = self.model_id if self.dataset_name is not None: result["dataset_name"] = self.dataset_name - if self.mode == "rtn": + if self.mode in ("rtn", "w4a16"): result["rtn_bits"] = self.rtn_bits result["rtn_block_size"] = self.rtn_block_size result["rtn_symmetric"] = self.rtn_symmetric result["rtn_accuracy_level"] = self.rtn_accuracy_level - if self.mode == "fp16": + if self.mode in ("fp16", "w4a16"): result["fp16_keep_io_types"] = self.fp16_keep_io_types result["fp16_op_block_list"] = self.fp16_op_block_list return result diff --git a/src/winml/modelkit/quant/passes/qdq.py b/src/winml/modelkit/quant/passes/qdq.py index ed2a68598..7c2be2de4 100644 --- a/src/winml/modelkit/quant/passes/qdq.py +++ b/src/winml/modelkit/quant/passes/qdq.py @@ -185,7 +185,7 @@ def run( 1 for node in quantized_model.graph.node if node.op_type in QDQ_OP_TYPES ) - save_onnx(quantized_model, output_path) + save_onnx(quantized_model, output_path, use_external_data=use_external_data) total_time = time.perf_counter() - start_time diff --git a/src/winml/modelkit/quant/quantizer.py b/src/winml/modelkit/quant/quantizer.py index bd45cb0b9..16914eb78 100644 --- a/src/winml/modelkit/quant/quantizer.py +++ b/src/winml/modelkit/quant/quantizer.py @@ -258,5 +258,7 @@ def quantize_onnx( output_path = model_path.parent / f"{model_path.stem}_qdq.onnx" use_external_data: bool = kwargs.pop("use_external_data", True) + if kwargs: + raise TypeError(f"quantize_onnx() got unexpected keyword arguments: {sorted(kwargs)}") passes = expand_precision(config.mode, config) return Quantizer(passes).run(model_path, output_path, use_external_data=use_external_data) diff --git a/tests/unit/test_quant_passes.py b/tests/unit/test_quant_passes.py index 774c2b510..bc9c54592 100644 --- a/tests/unit/test_quant_passes.py +++ b/tests/unit/test_quant_passes.py @@ -100,7 +100,7 @@ def test_dynamic_returns_qdq_pass(self) -> None: def test_w4a16_returns_rtn_then_fp16(self) -> None: from winml.modelkit.quant.quantizer import expand_precision - config = WinMLQuantizationConfig(mode="rtn", rtn_bits=4) + config = WinMLQuantizationConfig(mode="w4a16", rtn_bits=4) passes = expand_precision("w4a16", config) assert len(passes) == 2 assert isinstance(passes[0], RTNPass) @@ -123,6 +123,39 @@ def test_none_config_uses_default(self) -> None: assert isinstance(passes[0].config, WinMLQuantizationConfig) +# --------------------------------------------------------------------------- +# WinMLQuantizationConfig — w4a16 mode +# --------------------------------------------------------------------------- + + +class TestW4a16Config: + def test_w4a16_mode_is_valid(self) -> None: + """WinMLQuantizationConfig should accept mode='w4a16' without type error.""" + config = WinMLQuantizationConfig(mode="w4a16", rtn_bits=4, fp16_keep_io_types=False) + assert config.mode == "w4a16" + + def test_to_dict_includes_rtn_and_fp16_fields(self) -> None: + """to_dict() must serialise both rtn_* and fp16_* fields when mode='w4a16'.""" + config = WinMLQuantizationConfig( + mode="w4a16", + rtn_bits=8, + rtn_block_size=64, + fp16_keep_io_types=False, + fp16_op_block_list=["Gather"], + ) + d = config.to_dict() + assert d["rtn_bits"] == 8 + assert d["rtn_block_size"] == 64 + assert d["fp16_keep_io_types"] is False + assert d["fp16_op_block_list"] == ["Gather"] + + def test_two_w4a16_configs_produce_distinct_dicts(self) -> None: + """Different rtn_bits must produce different to_dict() output (cache key stability).""" + c1 = WinMLQuantizationConfig(mode="w4a16", rtn_bits=4) + c2 = WinMLQuantizationConfig(mode="w4a16", rtn_bits=8) + assert c1.to_dict()["rtn_bits"] != c2.to_dict()["rtn_bits"] + + # --------------------------------------------------------------------------- # Quantizer — single pass # --------------------------------------------------------------------------- @@ -177,6 +210,23 @@ def test_empty_passes_raises(self) -> None: Quantizer([]) +# --------------------------------------------------------------------------- +# quantize_onnx — kwargs guard +# --------------------------------------------------------------------------- + + +class TestQuantizeOnnxKwargsGuard: + def test_unexpected_kwarg_raises_type_error(self, tmp_path: Path) -> None: + """quantize_onnx must raise TypeError on unrecognised kwargs.""" + from winml.modelkit.quant import quantize_onnx + + model_path = tmp_path / "model.onnx" + model_path.write_text("x") + + with pytest.raises(TypeError, match="unexpected keyword arguments"): + quantize_onnx(model_path, use_external_data_format=False) + + # --------------------------------------------------------------------------- # Quantizer — multi-pass chaining # --------------------------------------------------------------------------- From 36cc1da500258f0fc47a9dc0540b5a2d44a54649 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Fri, 26 Jun 2026 15:52:01 +0800 Subject: [PATCH 3/6] fix: resolve mypy and CodeQL issues in quant __init__ - Add TYPE_CHECKING import block for Quantizer, expand_precision, and quantize_onnx so mypy resolves their types instead of falling back to Any? (fixes 'Any? not callable [misc]' in hf.py and onnx.py) - Same TYPE_CHECKING imports satisfy CodeQL's 'Explicit export is not defined' alerts for those names in __all__ - Remove trailing ... after docstring in BaseQuantPass.run() to fix CodeQL 'Statement has no effect' alert --- src/winml/modelkit/quant/__init__.py | 8 ++++++-- src/winml/modelkit/quant/passes/base.py | 1 - 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/src/winml/modelkit/quant/__init__.py b/src/winml/modelkit/quant/__init__.py index 070ecf86f..c0ab0d7de 100644 --- a/src/winml/modelkit/quant/__init__.py +++ b/src/winml/modelkit/quant/__init__.py @@ -25,12 +25,16 @@ result = Quantizer(expand_precision("w4a16", config)).run("model.onnx", "out.onnx") """ -from typing import Any +from typing import TYPE_CHECKING, Any from .config import QuantizeResult, WinMLQuantizationConfig from .passes import BaseQuantPass, FP16Pass, QDQPass, RTNPass +if TYPE_CHECKING: + from .quantizer import Quantizer, expand_precision, quantize_onnx + + __all__ = [ "BaseQuantPass", "FP16Pass", @@ -52,7 +56,7 @@ def __getattr__(name: str) -> Any: - """Lazy-load quantizer (imports onnxruntime.quantization).""" + """Lazy-load quantizer module (avoids importing onnxruntime at package import time).""" if name in _LAZY_IMPORTS: module_path, attr_name = _LAZY_IMPORTS[name] import importlib diff --git a/src/winml/modelkit/quant/passes/base.py b/src/winml/modelkit/quant/passes/base.py index 611c490c2..3d3b4a23f 100644 --- a/src/winml/modelkit/quant/passes/base.py +++ b/src/winml/modelkit/quant/passes/base.py @@ -57,4 +57,3 @@ def run( :class:`~winml.modelkit.quant.config.QuantizeResult` describing the outcome of this pass. """ - ... From 519b5626b1857555437b8e77e6fe2005b5d67840 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Fri, 26 Jun 2026 15:56:20 +0800 Subject: [PATCH 4/6] fix(quant): remove w4a16 from WinMLQuantizationConfig.mode and rebase onto main w4a16 is a composite pipeline concept, not a single-pass quantization mode. Removing it from the mode Literal keeps config.py focused on atomic pass modes (static, dynamic, rtn, fp16). Multi-pass pipelines are expressed through Quantizer + expand_precision at a higher level. Changes: - config.py: revert mode Literal to [static, dynamic, rtn, fp16], revert to_dict() guards back to equality checks, remove w4a16 docstring example - quantizer.py: remove w4a16 from _COMPOSITE_PRECISIONS and docstrings - __init__.py: update module docstring example - commands/build.py: fix stale 'single-pass' comment - tests: remove TestW4a16Config and test_w4a16_returns_rtn_then_fp16 --- src/winml/modelkit/commands/build.py | 2 +- src/winml/modelkit/quant/__init__.py | 6 ++-- src/winml/modelkit/quant/config.py | 11 ++----- src/winml/modelkit/quant/quantizer.py | 15 +++------ tests/unit/test_quant_passes.py | 45 --------------------------- 5 files changed, 12 insertions(+), 67 deletions(-) diff --git a/src/winml/modelkit/commands/build.py b/src/winml/modelkit/commands/build.py index 288ffc3fe..cf0de6364 100644 --- a/src/winml/modelkit/commands/build.py +++ b/src/winml/modelkit/commands/build.py @@ -1143,7 +1143,7 @@ def _run_quantize_stage( ) -> Path: """Run the quantize stage (if quant is configured). - Delegates single-pass quantization to ``quantize_onnx(config=...)``. + Delegates quantization to ``quantize_onnx(config=...)``. The cmd layer only handles UI display and the QDQ skip check. Args: diff --git a/src/winml/modelkit/quant/__init__.py b/src/winml/modelkit/quant/__init__.py index c0ab0d7de..202d566bf 100644 --- a/src/winml/modelkit/quant/__init__.py +++ b/src/winml/modelkit/quant/__init__.py @@ -20,9 +20,9 @@ # Custom config result = quantize_onnx("model.onnx", WinMLQuantizationConfig(samples=100)) - # Pipeline: RTN int4 followed by FP16 (w4a16) - config = WinMLQuantizationConfig(mode="w4a16", rtn_bits=4) - result = Quantizer(expand_precision("w4a16", config)).run("model.onnx", "out.onnx") + # Pipeline: RTN int4 weight-only + config = WinMLQuantizationConfig(mode="rtn", rtn_bits=4) + result = Quantizer(expand_precision("rtn", config)).run("model.onnx", "out.onnx") """ from typing import TYPE_CHECKING, Any diff --git a/src/winml/modelkit/quant/config.py b/src/winml/modelkit/quant/config.py index cbf92281e..3be32ed0a 100644 --- a/src/winml/modelkit/quant/config.py +++ b/src/winml/modelkit/quant/config.py @@ -51,19 +51,14 @@ class WinMLQuantizationConfig: # FP16 conversion (pure FP16, no quantization) config = WinMLQuantizationConfig(mode="fp16") result = quantize_onnx("model.onnx", config) - - # RTN int4 + FP16 (w4a16) - config = WinMLQuantizationConfig(mode="w4a16", rtn_bits=4) - result = quantize_onnx("model.onnx", config) """ # Quantization mode - mode: Literal["static", "dynamic", "rtn", "fp16", "w4a16"] = "static" + mode: Literal["static", "dynamic", "rtn", "fp16"] = "static" # "static" — Calibrated QDQ quantization (requires calibration data) # "dynamic" — Dynamic quantization (no calibration) [planned, not yet wired] # "rtn" — Round-To-Nearest weight-only (no calibration, block-wise) # "fp16" — Pure FP16 conversion only (no quantization) - # "w4a16" — RTN int4 weight quantization followed by FP16 conversion # Calibration settings (static/dynamic) samples: int = 10 @@ -140,12 +135,12 @@ def to_dict(self) -> dict: result["model_id"] = self.model_id if self.dataset_name is not None: result["dataset_name"] = self.dataset_name - if self.mode in ("rtn", "w4a16"): + if self.mode == "rtn": result["rtn_bits"] = self.rtn_bits result["rtn_block_size"] = self.rtn_block_size result["rtn_symmetric"] = self.rtn_symmetric result["rtn_accuracy_level"] = self.rtn_accuracy_level - if self.mode in ("fp16", "w4a16"): + if self.mode == "fp16": result["fp16_keep_io_types"] = self.fp16_keep_io_types result["fp16_op_block_list"] = self.fp16_op_block_list return result diff --git a/src/winml/modelkit/quant/quantizer.py b/src/winml/modelkit/quant/quantizer.py index 16914eb78..83752e5ca 100644 --- a/src/winml/modelkit/quant/quantizer.py +++ b/src/winml/modelkit/quant/quantizer.py @@ -19,9 +19,7 @@ logger = logging.getLogger(__name__) # Precision strings that expand to multiple sequential passes. -_COMPOSITE_PRECISIONS: dict[str, list[str]] = { - "w4a16": ["rtn", "fp16"], -} +_COMPOSITE_PRECISIONS: dict[str, list[str]] = {} def expand_precision( @@ -42,11 +40,10 @@ def expand_precision( ``rtn`` ``[RTNPass(config)]`` ``static`` ``[QDQPass(config)]`` ``dynamic`` ``[QDQPass(config)]`` - ``w4a16`` ``[RTNPass(config), FP16Pass(config)]`` ========= ======================= Args: - mode: Precision string (e.g. ``"w4a16"``). + mode: Precision string (e.g. ``"fp16"``). config: Shared quantization configuration. If *None*, a default :class:`WinMLQuantizationConfig` is used. @@ -98,9 +95,9 @@ class Quantizer: from winml.modelkit.quant import Quantizer, expand_precision, WinMLQuantizationConfig - config = WinMLQuantizationConfig(mode="w4a16", rtn_bits=4) - quantizer = Quantizer(expand_precision("w4a16", config)) - result = quantizer.run("model.onnx", "model_w4a16.onnx") + config = WinMLQuantizationConfig(mode="rtn", rtn_bits=4) + quantizer = Quantizer(expand_precision("rtn", config)) + result = quantizer.run("model.onnx", "model_rtn.onnx") """ def __init__(self, passes: list[BaseQuantPass]) -> None: @@ -234,7 +231,6 @@ def quantize_onnx( - ``"fp16"`` — FP16 conversion (no quantization) - ``"rtn"`` — RTN weight-only quantization - ``"static"`` / ``"dynamic"`` — QDQ calibrated quantization - - ``"w4a16"`` — RTN int4 followed by FP16 conversion Args: model_path: Path to input ONNX model. @@ -247,7 +243,6 @@ def quantize_onnx( Examples: >>> result = quantize_onnx("model.onnx", config=WinMLQuantizationConfig(mode="rtn")) >>> result = quantize_onnx("model.onnx", config=WinMLQuantizationConfig(mode="fp16")) - >>> result = quantize_onnx("model.onnx", config=WinMLQuantizationConfig(mode="w4a16")) """ model_path = Path(model_path) config = config or WinMLQuantizationConfig() diff --git a/tests/unit/test_quant_passes.py b/tests/unit/test_quant_passes.py index bc9c54592..22e10e868 100644 --- a/tests/unit/test_quant_passes.py +++ b/tests/unit/test_quant_passes.py @@ -97,18 +97,6 @@ def test_dynamic_returns_qdq_pass(self) -> None: assert len(passes) == 1 assert isinstance(passes[0], QDQPass) - def test_w4a16_returns_rtn_then_fp16(self) -> None: - from winml.modelkit.quant.quantizer import expand_precision - - config = WinMLQuantizationConfig(mode="w4a16", rtn_bits=4) - passes = expand_precision("w4a16", config) - assert len(passes) == 2 - assert isinstance(passes[0], RTNPass) - assert isinstance(passes[1], FP16Pass) - # All passes share the same config object - assert passes[0].config is config - assert passes[1].config is config - def test_unknown_mode_raises(self) -> None: from winml.modelkit.quant.quantizer import expand_precision @@ -123,39 +111,6 @@ def test_none_config_uses_default(self) -> None: assert isinstance(passes[0].config, WinMLQuantizationConfig) -# --------------------------------------------------------------------------- -# WinMLQuantizationConfig — w4a16 mode -# --------------------------------------------------------------------------- - - -class TestW4a16Config: - def test_w4a16_mode_is_valid(self) -> None: - """WinMLQuantizationConfig should accept mode='w4a16' without type error.""" - config = WinMLQuantizationConfig(mode="w4a16", rtn_bits=4, fp16_keep_io_types=False) - assert config.mode == "w4a16" - - def test_to_dict_includes_rtn_and_fp16_fields(self) -> None: - """to_dict() must serialise both rtn_* and fp16_* fields when mode='w4a16'.""" - config = WinMLQuantizationConfig( - mode="w4a16", - rtn_bits=8, - rtn_block_size=64, - fp16_keep_io_types=False, - fp16_op_block_list=["Gather"], - ) - d = config.to_dict() - assert d["rtn_bits"] == 8 - assert d["rtn_block_size"] == 64 - assert d["fp16_keep_io_types"] is False - assert d["fp16_op_block_list"] == ["Gather"] - - def test_two_w4a16_configs_produce_distinct_dicts(self) -> None: - """Different rtn_bits must produce different to_dict() output (cache key stability).""" - c1 = WinMLQuantizationConfig(mode="w4a16", rtn_bits=4) - c2 = WinMLQuantizationConfig(mode="w4a16", rtn_bits=8) - assert c1.to_dict()["rtn_bits"] != c2.to_dict()["rtn_bits"] - - # --------------------------------------------------------------------------- # Quantizer — single pass # --------------------------------------------------------------------------- From 5e8bdad83f11849d2cb2c91258b8a023189217f8 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Fri, 26 Jun 2026 16:16:14 +0800 Subject: [PATCH 5/6] feat(quant): rename QDQPass to StaticPass and add multi-precision CLI pipeline MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Rename: - passes/qdq.py → passes/static.py; QDQPass → StaticPass throughout - Update all imports, __all__, quantizer.py pass_factories, and tests Multi-precision --precision: - precision_option() gains multiple=True support - quantize command accepts repeated --precision flags; len > 1 routes to _run_multi_precision() which chains expand_precision() calls into a single Quantizer pipeline - Default output path for multi-pass: {stem}_{p1}_{p2}.onnx - Calibration-unused warning emitted when no static pass is in pipeline E2E tests (TestMultiPrecision): - test_int4_then_fp16_pipeline: verifies MatMulNBits nodes (RTN) and FLOAT16 initializers (FP16 pass) are both present in output - test_pipeline_default_output_path: verifies auto-named output file --- src/winml/modelkit/commands/quantize.py | 151 ++++++++++++++++-- src/winml/modelkit/quant/__init__.py | 4 +- src/winml/modelkit/quant/passes/__init__.py | 4 +- .../quant/passes/{qdq.py => static.py} | 8 +- src/winml/modelkit/quant/quantizer.py | 10 +- src/winml/modelkit/utils/cli.py | 8 +- tests/e2e/test_quantize_e2e.py | 95 +++++++++++ tests/unit/test_quant_passes.py | 6 +- 8 files changed, 256 insertions(+), 30 deletions(-) rename src/winml/modelkit/quant/passes/{qdq.py => static.py} (97%) diff --git a/src/winml/modelkit/commands/quantize.py b/src/winml/modelkit/commands/quantize.py index 079b37ea8..c56ebcced 100644 --- a/src/winml/modelkit/commands/quantize.py +++ b/src/winml/modelkit/commands/quantize.py @@ -50,11 +50,13 @@ @cli_utils.output_option("Output path (default: {input}_qdq.onnx)") @cli_utils.overwrite_option() @cli_utils.precision_option( - default=None, - help_text="Quantization precision: auto, fp16, int4, int8, int16, or w{x}a{y} where " - "x in {4,8,16}, y in {8,16} (e.g., w4a16, w8a8, w8a16). " - "int4/w4a16 uses RTN weight-only quantization; " - "fp16 converts all FP32 tensors to FP16 (no QDQ)", + default=(), + multiple=True, + help_text="Quantization precision: fp16, int4, int8, int16, or w{x}a{y} where " + "x in {4,8,16}, y in {8,16} (e.g., w8a8, w8a16). " + "int4 uses RTN weight-only quantization; " + "fp16 converts all FP32 tensors to FP16 (no QDQ). " + "Repeat to chain passes in order (e.g. -p int4 -p fp16)", optional_message="Overridden by explicit --weight-type/--activation-type", ) @click.option( @@ -111,7 +113,7 @@ def quantize( model: Path, output: Path | None, overwrite: bool, - precision: str | None, + precision: tuple[str, ...], samples: int, method: str, weight_type: str | None, @@ -127,9 +129,12 @@ def quantize( r"""Quantize ONNX model by inserting QDQ nodes, RTN weight-only, or convert to FP16. This command applies quantization to an ONNX model. The algorithm is - auto-selected from the precision: int4/w4a16 → RTN weight-only, + auto-selected from the precision: int4 → RTN weight-only, int8/int16/w8a8 → static QDQ, fp16 → FP16 conversion. + Repeat --precision to chain passes in order: + ``-p int4 -p fp16`` runs RTN int4 quantization then FP16 conversion. + \b Examples: # Basic quantization with defaults (10 samples, uint8) @@ -141,6 +146,9 @@ def quantize( # RTN 4-bit weight-only quantization (no calibration data needed) winml quantize -m model.onnx --precision int4 + # RTN int4 followed by FP16 conversion (two-pass pipeline) + winml quantize -m model.onnx --precision int4 --precision fp16 + # Int16 quantization winml quantize -m model.onnx --precision int16 @@ -182,8 +190,29 @@ def quantize( # Import quantizer (late import to speed up CLI) from ..quant import WinMLQuantizationConfig, quantize_onnx - # ── Build config based on precision ────────────────────────── - precision_lower = precision.lower() if precision else None + # ── Multi-pass pipeline ─────────────────────────────────────── + if len(precision) > 1: + _run_multi_precision( + ctx=ctx, + model=model, + output=output, + overwrite=overwrite, + precision=precision, + samples=samples, + method=method, + weight_type=weight_type, + activation_type=activation_type, + per_channel=per_channel, + symmetric=symmetric, + task=task, + model_id=model_id, + console=console, + ) + return + + # ── Single-precision (or default) path ─────────────────────── + single = precision[0] if precision else None + precision_lower = single.lower() if single else None if precision_lower == "fp16": # FP16 conversion @@ -211,7 +240,7 @@ def quantize( else: # QDQ calibrated quantization resolved_weight, resolved_activation = _resolve_quant_types( - precision, weight_type, activation_type + single, weight_type, activation_type ) if output is None: output = model.parent / f"{model.stem}_qdq.onnx" @@ -243,13 +272,11 @@ def quantize( console.print(f"[bold blue]Dataset:[/bold blue] {_dataset_display}") # ── Shared execution: print header, run, report ────────────── - # Refuse to clobber an existing output unless the user opted in. Runs after - # the per-precision default path is resolved, before any mkdir/work. cli_utils.guard_output(output, overwrite) output.parent.mkdir(parents=True, exist_ok=True) console.print(f"[bold blue]Input:[/bold blue] {model}") console.print(f"[bold blue]Output:[/bold blue] {output}") - console.print(f"[bold blue]Precision:[/bold blue] {precision or 'auto'}") + console.print(f"[bold blue]Precision:[/bold blue] {single or 'auto'}") try: console.print(f"\n[bold]Running {label.lower()}...[/bold]") @@ -275,6 +302,104 @@ def quantize( raise click.ClickException(f"{label} failed: {e}") from e +def _cli_precision_to_mode(precision: str) -> str: + """Map a CLI precision string to a quantizer pass mode.""" + p = precision.lower() + if p == "fp16": + return "fp16" + if is_weight_only_precision(p): + return "rtn" + return "static" + + +def _run_multi_precision( + *, + ctx: click.Context, + model: Path, + output: Path | None, + overwrite: bool, + precision: tuple[str, ...], + samples: int, + method: str, + weight_type: str | None, + activation_type: str | None, + per_channel: bool, + symmetric: bool, + task: str | None, + model_id: str | None, + console: Console, +) -> None: + """Execute a multi-pass quantization pipeline from ordered precision strings.""" + from typing import cast + + from ..config.precision import extract_weight_bits + from ..quant import Quantizer, WinMLQuantizationConfig, expand_precision + + modes = [_cli_precision_to_mode(p) for p in precision] + has_calibration_pass = any(m == "static" for m in modes) + + if not has_calibration_pass: + cli_utils.warn_ignored_calibration_options( + ctx, "No selected pass uses calibration data.", console=console + ) + + # Extract rtn_bits from the first weight-only precision in the list. + rtn_bits = next( + (extract_weight_bits(p.lower()) for p in precision if is_weight_only_precision(p.lower())), + 4, + ) + + config = WinMLQuantizationConfig( + rtn_bits=rtn_bits, + samples=samples, + calibration_method=cast('Literal["minmax", "entropy", "percentile"]', method), + weight_type=cast('Literal["uint8", "int8", "uint16", "int16"]', weight_type or "uint8"), + activation_type=cast( + 'Literal["uint8", "int8", "uint16", "int16"]', activation_type or "uint8" + ), + per_channel=per_channel, + symmetric=symmetric, + task=task, + model_id=model_id, + ) + + passes = [] + for mode in modes: + passes.extend(expand_precision(mode, config)) + + label = " → ".join(p.lower() for p in precision) + if output is None: + suffix = "_".join(p.lower() for p in precision) + output = model.parent / f"{model.stem}_{suffix}.onnx" + + cli_utils.guard_output(output, overwrite) + output.parent.mkdir(parents=True, exist_ok=True) + console.print(f"[bold blue]Input:[/bold blue] {model}") + console.print(f"[bold blue]Output:[/bold blue] {output}") + console.print(f"[bold blue]Pipeline:[/bold blue] {label}") + + try: + console.print(f"\n[bold]Running pipeline: {label}...[/bold]") + result = Quantizer(passes).run(model, output) + + if result.success: + console.print("\n[bold green]Success![/bold green] Pipeline complete") + console.print(f"[dim]Output: {result.output_path}[/dim]") + console.print(f"[dim]Total time: {result.total_time_seconds:.2f}s[/dim]") + else: + console.print("\n[bold red]Pipeline failed:[/bold red]") + for error in result.errors: + console.print(f" {error}") + raise click.ClickException("Pipeline failed") + + except click.ClickException: + raise + except Exception as e: + console.print(f"\n[bold red]Pipeline failed:[/bold red] {e}") + logger.exception("Pipeline failed") + raise click.ClickException(f"Pipeline failed: {e}") from e + + def _resolve_quant_types( precision: str | None, weight_type: str | None, diff --git a/src/winml/modelkit/quant/__init__.py b/src/winml/modelkit/quant/__init__.py index 202d566bf..69a69754a 100644 --- a/src/winml/modelkit/quant/__init__.py +++ b/src/winml/modelkit/quant/__init__.py @@ -28,7 +28,7 @@ from typing import TYPE_CHECKING, Any from .config import QuantizeResult, WinMLQuantizationConfig -from .passes import BaseQuantPass, FP16Pass, QDQPass, RTNPass +from .passes import BaseQuantPass, FP16Pass, RTNPass, StaticPass if TYPE_CHECKING: @@ -38,10 +38,10 @@ __all__ = [ "BaseQuantPass", "FP16Pass", - "QDQPass", "QuantizeResult", "Quantizer", "RTNPass", + "StaticPass", "WinMLQuantizationConfig", "expand_precision", "quantize_onnx", diff --git a/src/winml/modelkit/quant/passes/__init__.py b/src/winml/modelkit/quant/passes/__init__.py index 9f2910a7d..b7f571660 100644 --- a/src/winml/modelkit/quant/passes/__init__.py +++ b/src/winml/modelkit/quant/passes/__init__.py @@ -6,13 +6,13 @@ from .base import BaseQuantPass from .fp16 import FP16Pass -from .qdq import QDQPass from .rtn import RTNPass +from .static import StaticPass __all__ = [ "BaseQuantPass", "FP16Pass", - "QDQPass", "RTNPass", + "StaticPass", ] diff --git a/src/winml/modelkit/quant/passes/qdq.py b/src/winml/modelkit/quant/passes/static.py similarity index 97% rename from src/winml/modelkit/quant/passes/qdq.py rename to src/winml/modelkit/quant/passes/static.py index 7c2be2de4..16cd8b24b 100644 --- a/src/winml/modelkit/quant/passes/qdq.py +++ b/src/winml/modelkit/quant/passes/static.py @@ -2,7 +2,7 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. # -------------------------------------------------------------------------- -"""QDQ (Quantize-Dequantize) calibrated quantization pass.""" +"""Static (calibrated QDQ) quantization pass.""" from __future__ import annotations @@ -22,8 +22,8 @@ logger = logging.getLogger(__name__) -class QDQPass(BaseQuantPass): - """QDQ (static/dynamic) calibrated quantization pass. +class StaticPass(BaseQuantPass): + """Static calibrated QDQ quantization pass. Reads all QDQ-relevant fields from :class:`~winml.modelkit.quant.config.WinMLQuantizationConfig`: @@ -34,7 +34,7 @@ class QDQPass(BaseQuantPass): Example:: - pass_ = QDQPass(config) + pass_ = StaticPass(config) result = pass_.run("model.onnx", "model_qdq.onnx") """ diff --git a/src/winml/modelkit/quant/quantizer.py b/src/winml/modelkit/quant/quantizer.py index 83752e5ca..8aa9a467c 100644 --- a/src/winml/modelkit/quant/quantizer.py +++ b/src/winml/modelkit/quant/quantizer.py @@ -13,7 +13,7 @@ from typing import Any from .config import QuantizeResult, WinMLQuantizationConfig -from .passes import BaseQuantPass, FP16Pass, QDQPass, RTNPass +from .passes import BaseQuantPass, FP16Pass, RTNPass, StaticPass logger = logging.getLogger(__name__) @@ -38,8 +38,8 @@ def expand_precision( ========= ======================= ``fp16`` ``[FP16Pass(config)]`` ``rtn`` ``[RTNPass(config)]`` - ``static`` ``[QDQPass(config)]`` - ``dynamic`` ``[QDQPass(config)]`` + ``static`` ``[StaticPass(config)]`` + ``dynamic`` ``[StaticPass(config)]`` (placeholder until DynamicPass is implemented) ========= ======================= Args: @@ -59,8 +59,8 @@ def expand_precision( _pass_factories: dict[str, BaseQuantPass] = { "fp16": FP16Pass(config), "rtn": RTNPass(config), - "static": QDQPass(config), - "dynamic": QDQPass(config), + "static": StaticPass(config), + "dynamic": StaticPass(config), } if mode in _pass_factories: diff --git a/src/winml/modelkit/utils/cli.py b/src/winml/modelkit/utils/cli.py index c62ede36b..4c0f42a41 100644 --- a/src/winml/modelkit/utils/cli.py +++ b/src/winml/modelkit/utils/cli.py @@ -438,6 +438,7 @@ def precision_option( optional_message: str | None = None, include_short: bool = True, help_text: str | None = None, + multiple: bool = False, ) -> Callable[[F], F]: """Add --precision option to a Click command. @@ -459,6 +460,10 @@ def precision_option( values differ from the default float+int set (e.g. ``quantize``, which has no fp16/fp32) supply their own; ``optional_message`` is still appended to it. + multiple: Allow the flag to be specified multiple times to compose a + pass pipeline (e.g. ``-p int4 -p fp16``). When True the parameter + receives a ``tuple[str, ...]`` and ``default`` should be ``()`` + (default: False). Returns: Decorator function. @@ -478,7 +483,8 @@ def precision_option( *param_decls, type=str, default=default, - show_default=True, + multiple=multiple, + show_default=not multiple, help=base_help, ) diff --git a/tests/e2e/test_quantize_e2e.py b/tests/e2e/test_quantize_e2e.py index 94882dde1..8304e5730 100644 --- a/tests/e2e/test_quantize_e2e.py +++ b/tests/e2e/test_quantize_e2e.py @@ -933,3 +933,98 @@ def test_verbose_emits_more_output(self, runner: CliRunner, tiny_onnx: Path, tmp f"verbose did not increase output\n--- quiet ---\n{r_quiet.output}\n" f"--- verbose ---\n{r_verbose.output}" ) + + +# =========================================================================== +# Multi-precision pipeline +# =========================================================================== + + +def _build_rtn_onnx(path: Path) -> None: + """Build an ONNX with MatMul weights large enough for RTN int4 (K >= block_size=128).""" + rng = np.random.default_rng(77) + x = onnx.helper.make_tensor_value_info("input", onnx.TensorProto.FLOAT, [1, 256]) + y = onnx.helper.make_tensor_value_info("output", onnx.TensorProto.FLOAT, [1, 64]) + w1 = onnx.numpy_helper.from_array(rng.standard_normal((256, 128)).astype(np.float32), "W1") + b1 = onnx.numpy_helper.from_array(rng.standard_normal((128,)).astype(np.float32), "B1") + w2 = onnx.numpy_helper.from_array(rng.standard_normal((128, 64)).astype(np.float32), "W2") + b2 = onnx.numpy_helper.from_array(rng.standard_normal((64,)).astype(np.float32), "B2") + nodes = [ + onnx.helper.make_node("MatMul", ["input", "W1"], ["mm1"]), + onnx.helper.make_node("Add", ["mm1", "B1"], ["add1"]), + onnx.helper.make_node("MatMul", ["add1", "W2"], ["mm2"]), + onnx.helper.make_node("Add", ["mm2", "B2"], ["output"]), + ] + graph = onnx.helper.make_graph(nodes, "rtn_quantizable", [x], [y], [w1, b1, w2, b2]) + model = onnx.helper.make_model(graph, opset_imports=[onnx.helper.make_opsetid("", 17)]) + model.ir_version = 8 + onnx.checker.check_model(model) + onnx.save(model, str(path)) + + +@pytest.fixture(scope="session") +def rtn_onnx(tmp_path_factory: pytest.TempPathFactory) -> Path: + """ONNX with large enough MatMul weights to be quantized by RTN int4.""" + d = tmp_path_factory.mktemp("rtn_quant") + p = d / "rtn.onnx" + _build_rtn_onnx(p) + return p + + +class TestMultiPrecision: + def test_int4_then_fp16_pipeline(self, runner: CliRunner, rtn_onnx: Path, tmp_path: Path): + """--precision int4 --precision fp16: RTN pass then FP16 conversion. + + Verifies that: + - The pipeline completes successfully + - RTN pass applied: model contains MatMulNBits nodes + - FP16 pass applied: bias initializers converted to FLOAT16 + """ + out = tmp_path / "multi_int4_fp16.onnx" + r = _invoke( + runner, + [ + "-m", + str(rtn_onnx), + "-o", + str(out), + "--precision", + "int4", + "--precision", + "fp16", + ], + ) + assert r.exit_code == 0, f"pipeline exited {r.exit_code}\n{r.output}" + assert out.exists() + + model = onnx.load(str(out)) + + # RTN pass: MatMul nodes replaced by MatMulNBits + op_types = {n.op_type for n in model.graph.node} + assert "MatMulNBits" in op_types, ( + f"expected MatMulNBits after RTN pass, got ops: {op_types}" + ) + + # FP16 pass: bias initializers converted from FLOAT to FLOAT16 + float16_inits = [ + i for i in model.graph.initializer if i.data_type == onnx.TensorProto.FLOAT16 + ] + assert float16_inits, ( + "expected at least one FLOAT16 initializer after FP16 pass; " + f"dtypes: {[i.data_type for i in model.graph.initializer]}" + ) + + # Pipeline label appears in stdout + assert "int4" in r.output.lower() + assert "fp16" in r.output.lower() + + def test_pipeline_default_output_path(self, runner: CliRunner, rtn_onnx: Path, tmp_path: Path): + """Multi-precision without -o should produce {stem}_int4_fp16.onnx next to input.""" + local = tmp_path / "model.onnx" + local.write_bytes(rtn_onnx.read_bytes()) + r = _invoke( + runner, + ["-m", str(local), "--precision", "int4", "--precision", "fp16"], + ) + assert r.exit_code == 0, r.output + assert (tmp_path / "model_int4_fp16.onnx").exists() diff --git a/tests/unit/test_quant_passes.py b/tests/unit/test_quant_passes.py index 22e10e868..b76457e8b 100644 --- a/tests/unit/test_quant_passes.py +++ b/tests/unit/test_quant_passes.py @@ -14,7 +14,7 @@ from winml.modelkit.quant import WinMLQuantizationConfig from winml.modelkit.quant.config import QuantizeResult -from winml.modelkit.quant.passes import BaseQuantPass, FP16Pass, QDQPass, RTNPass +from winml.modelkit.quant.passes import BaseQuantPass, FP16Pass, RTNPass, StaticPass if TYPE_CHECKING: @@ -87,7 +87,7 @@ def test_static_returns_qdq_pass(self) -> None: config = WinMLQuantizationConfig(mode="static") passes = expand_precision("static", config) assert len(passes) == 1 - assert isinstance(passes[0], QDQPass) + assert isinstance(passes[0], StaticPass) def test_dynamic_returns_qdq_pass(self) -> None: from winml.modelkit.quant.quantizer import expand_precision @@ -95,7 +95,7 @@ def test_dynamic_returns_qdq_pass(self) -> None: config = WinMLQuantizationConfig(mode="static") passes = expand_precision("dynamic", config) assert len(passes) == 1 - assert isinstance(passes[0], QDQPass) + assert isinstance(passes[0], StaticPass) def test_unknown_mode_raises(self) -> None: from winml.modelkit.quant.quantizer import expand_precision From 0b6098a3a4ddfaa665aa4a94d89a6187fd91aae3 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Fri, 26 Jun 2026 17:26:46 +0800 Subject: [PATCH 6/6] fix(quant): address PR review comments and fix mypy lint errors MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - static.py: fix model_name -> model_id (WinMLQuantizationConfig has no model_name field; correct field is model_id) — fixes mypy [attr-defined] - cli.py: widen precision_option default type to str | tuple[str,...] | None so passing default=() for multiple=True passes mypy — fixes [arg-type] - quantizer.py: make expand_precision mode optional, falling back to config.mode when not provided; removes redundant arg from quantize_onnx caller (addresses reviewer: 'why have mode param when config has it') - quantize.py: remove redundant 'from typing import cast' inside _run_multi_precision (cast already imported at module level) - passes/base.py: add note in run() docstring explaining why file-based I/O is used (addresses reviewer suggestion about in-memory model proto) - tests: add test_no_mode_uses_config_mode to cover expand_precision(config=) path --- src/winml/modelkit/commands/quantize.py | 2 -- src/winml/modelkit/quant/passes/base.py | 6 ++++++ src/winml/modelkit/quant/passes/static.py | 2 +- src/winml/modelkit/quant/quantizer.py | 21 ++++++++++++--------- src/winml/modelkit/utils/cli.py | 2 +- tests/unit/test_quant_passes.py | 10 ++++++++++ 6 files changed, 30 insertions(+), 13 deletions(-) diff --git a/src/winml/modelkit/commands/quantize.py b/src/winml/modelkit/commands/quantize.py index c56ebcced..80988c4ce 100644 --- a/src/winml/modelkit/commands/quantize.py +++ b/src/winml/modelkit/commands/quantize.py @@ -330,8 +330,6 @@ def _run_multi_precision( console: Console, ) -> None: """Execute a multi-pass quantization pipeline from ordered precision strings.""" - from typing import cast - from ..config.precision import extract_weight_bits from ..quant import Quantizer, WinMLQuantizationConfig, expand_precision diff --git a/src/winml/modelkit/quant/passes/base.py b/src/winml/modelkit/quant/passes/base.py index 3d3b4a23f..2503a0dc4 100644 --- a/src/winml/modelkit/quant/passes/base.py +++ b/src/winml/modelkit/quant/passes/base.py @@ -56,4 +56,10 @@ def run( Returns: :class:`~winml.modelkit.quant.config.QuantizeResult` describing the outcome of this pass. + + Note: + Passes use file-based I/O because ORT's calibration and RTN APIs + operate on paths, and external-data models cannot be held fully in + memory. A future enhancement could add an optional in-memory + fast-path for small single-pass models. """ diff --git a/src/winml/modelkit/quant/passes/static.py b/src/winml/modelkit/quant/passes/static.py index 16cd8b24b..0ef080026 100644 --- a/src/winml/modelkit/quant/passes/static.py +++ b/src/winml/modelkit/quant/passes/static.py @@ -90,7 +90,7 @@ def run( task = self._config.task or "random" data_reader = DatasetCalibrationReader( - model_name=self._config.model_name or "random", + model_name=self._config.model_id or "random", task=task, max_samples=self._config.samples, dataset_name=self._config.dataset_name, diff --git a/src/winml/modelkit/quant/quantizer.py b/src/winml/modelkit/quant/quantizer.py index 8aa9a467c..a06a4b399 100644 --- a/src/winml/modelkit/quant/quantizer.py +++ b/src/winml/modelkit/quant/quantizer.py @@ -23,13 +23,14 @@ def expand_precision( - mode: str, + mode: str | None = None, config: WinMLQuantizationConfig | None = None, ) -> list[BaseQuantPass]: """Expand a precision string into an ordered list of quantization passes. All passes share the same ``config`` so every pass can read the fields - relevant to it. + relevant to it. When *mode* is omitted, ``config.mode`` is used so that + ``expand_precision(config=cfg)`` works as a single-precision convenience. Supported values: @@ -43,7 +44,8 @@ def expand_precision( ========= ======================= Args: - mode: Precision string (e.g. ``"fp16"``). + mode: Precision string (e.g. ``"fp16"``). When *None*, falls back to + ``config.mode`` (or ``"static"`` if *config* is also *None*). config: Shared quantization configuration. If *None*, a default :class:`WinMLQuantizationConfig` is used. @@ -55,6 +57,7 @@ def expand_precision( ValueError: If *mode* is not recognised. """ config = config or WinMLQuantizationConfig() + effective_mode = mode if mode is not None else config.mode _pass_factories: dict[str, BaseQuantPass] = { "fp16": FP16Pass(config), @@ -63,14 +66,14 @@ def expand_precision( "dynamic": StaticPass(config), } - if mode in _pass_factories: - return [_pass_factories[mode]] + if effective_mode in _pass_factories: + return [_pass_factories[effective_mode]] - if mode in _COMPOSITE_PRECISIONS: - return [_pass_factories[step] for step in _COMPOSITE_PRECISIONS[mode]] + if effective_mode in _COMPOSITE_PRECISIONS: + return [_pass_factories[step] for step in _COMPOSITE_PRECISIONS[effective_mode]] raise ValueError( - f"Unknown precision mode {mode!r}. " + f"Unknown precision mode {effective_mode!r}. " f"Valid values: {sorted(_pass_factories) + sorted(_COMPOSITE_PRECISIONS)}" ) @@ -255,5 +258,5 @@ def quantize_onnx( use_external_data: bool = kwargs.pop("use_external_data", True) if kwargs: raise TypeError(f"quantize_onnx() got unexpected keyword arguments: {sorted(kwargs)}") - passes = expand_precision(config.mode, config) + passes = expand_precision(config=config) return Quantizer(passes).run(model_path, output_path, use_external_data=use_external_data) diff --git a/src/winml/modelkit/utils/cli.py b/src/winml/modelkit/utils/cli.py index 4c0f42a41..33fbf7fd4 100644 --- a/src/winml/modelkit/utils/cli.py +++ b/src/winml/modelkit/utils/cli.py @@ -434,7 +434,7 @@ def device_option( def precision_option( - default: str | None = "auto", + default: str | tuple[str, ...] | None = "auto", optional_message: str | None = None, include_short: bool = True, help_text: str | None = None, diff --git a/tests/unit/test_quant_passes.py b/tests/unit/test_quant_passes.py index b76457e8b..bb87a8af3 100644 --- a/tests/unit/test_quant_passes.py +++ b/tests/unit/test_quant_passes.py @@ -110,6 +110,16 @@ def test_none_config_uses_default(self) -> None: assert len(passes) == 1 assert isinstance(passes[0].config, WinMLQuantizationConfig) + def test_no_mode_uses_config_mode(self) -> None: + """expand_precision(config=cfg) should use cfg.mode when mode is not given.""" + from winml.modelkit.quant.quantizer import expand_precision + + config = WinMLQuantizationConfig(mode="rtn", rtn_bits=4) + passes = expand_precision(config=config) + assert len(passes) == 1 + assert isinstance(passes[0], RTNPass) + assert passes[0].config is config + # --------------------------------------------------------------------------- # Quantizer — single pass