Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/winml/modelkit/commands/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -1143,7 +1143,7 @@ def _run_quantize_stage(
) -> Path:
"""Run the quantize stage (if quant is configured).

Delegates single-pass quantization to ``quantize_onnx(config=...)``.
Delegates quantization to ``quantize_onnx(config=...)``.
The cmd layer only handles UI display and the QDQ skip check.

Args:
Expand Down
149 changes: 136 additions & 13 deletions src/winml/modelkit/commands/quantize.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,13 @@
@cli_utils.output_option("Output path (default: {input}_qdq.onnx)")
@cli_utils.overwrite_option()
@cli_utils.precision_option(
default=None,
help_text="Quantization precision: auto, fp16, int4, int8, int16, or w{x}a{y} where "
"x in {4,8,16}, y in {8,16} (e.g., w4a16, w8a8, w8a16). "
"int4/w4a16 uses RTN weight-only quantization; "
"fp16 converts all FP32 tensors to FP16 (no QDQ)",
default=(),
multiple=True,
help_text="Quantization precision: fp16, int4, int8, int16, or w{x}a{y} where "
"x in {4,8,16}, y in {8,16} (e.g., w8a8, w8a16). "
"int4 uses RTN weight-only quantization; "
"fp16 converts all FP32 tensors to FP16 (no QDQ). "
"Repeat to chain passes in order (e.g. -p int4 -p fp16)",
optional_message="Overridden by explicit --weight-type/--activation-type",
)
@click.option(
Expand Down Expand Up @@ -111,7 +113,7 @@ def quantize(
model: Path,
output: Path | None,
overwrite: bool,
precision: str | None,
precision: tuple[str, ...],
samples: int,
method: str,
weight_type: str | None,
Expand All @@ -127,9 +129,12 @@ def quantize(
r"""Quantize ONNX model by inserting QDQ nodes, RTN weight-only, or convert to FP16.

This command applies quantization to an ONNX model. The algorithm is
auto-selected from the precision: int4/w4a16 → RTN weight-only,
auto-selected from the precision: int4 → RTN weight-only,
int8/int16/w8a8 → static QDQ, fp16 → FP16 conversion.

Repeat --precision to chain passes in order:
``-p int4 -p fp16`` runs RTN int4 quantization then FP16 conversion.

\b
Examples:
# Basic quantization with defaults (10 samples, uint8)
Expand All @@ -141,6 +146,9 @@ def quantize(
# RTN 4-bit weight-only quantization (no calibration data needed)
winml quantize -m model.onnx --precision int4

# RTN int4 followed by FP16 conversion (two-pass pipeline)
winml quantize -m model.onnx --precision int4 --precision fp16

# Int16 quantization
winml quantize -m model.onnx --precision int16

Expand Down Expand Up @@ -182,8 +190,29 @@ def quantize(
# Import quantizer (late import to speed up CLI)
from ..quant import WinMLQuantizationConfig, quantize_onnx

# ── Build config based on precision ──────────────────────────
precision_lower = precision.lower() if precision else None
# ── Multi-pass pipeline ───────────────────────────────────────
if len(precision) > 1:
_run_multi_precision(
ctx=ctx,
model=model,
output=output,
overwrite=overwrite,
precision=precision,
samples=samples,
method=method,
weight_type=weight_type,
activation_type=activation_type,
per_channel=per_channel,
symmetric=symmetric,
task=task,
model_id=model_id,
console=console,
)
return

# ── Single-precision (or default) path ───────────────────────
single = precision[0] if precision else None
precision_lower = single.lower() if single else None

if precision_lower == "fp16":
# FP16 conversion
Expand Down Expand Up @@ -211,7 +240,7 @@ def quantize(
else:
# QDQ calibrated quantization
resolved_weight, resolved_activation = _resolve_quant_types(
precision, weight_type, activation_type
single, weight_type, activation_type
)
if output is None:
output = model.parent / f"{model.stem}_qdq.onnx"
Expand Down Expand Up @@ -243,13 +272,11 @@ def quantize(
console.print(f"[bold blue]Dataset:[/bold blue] {_dataset_display}")

# ── Shared execution: print header, run, report ──────────────
# Refuse to clobber an existing output unless the user opted in. Runs after
# the per-precision default path is resolved, before any mkdir/work.
cli_utils.guard_output(output, overwrite)
output.parent.mkdir(parents=True, exist_ok=True)
console.print(f"[bold blue]Input:[/bold blue] {model}")
console.print(f"[bold blue]Output:[/bold blue] {output}")
console.print(f"[bold blue]Precision:[/bold blue] {precision or 'auto'}")
console.print(f"[bold blue]Precision:[/bold blue] {single or 'auto'}")

try:
console.print(f"\n[bold]Running {label.lower()}...[/bold]")
Expand All @@ -275,6 +302,102 @@ def quantize(
raise click.ClickException(f"{label} failed: {e}") from e


def _cli_precision_to_mode(precision: str) -> str:
"""Map a CLI precision string to a quantizer pass mode."""
p = precision.lower()
if p == "fp16":
return "fp16"
if is_weight_only_precision(p):
return "rtn"
return "static"


def _run_multi_precision(
    *,
    ctx: click.Context,
    model: Path,
    output: Path | None,
    overwrite: bool,
    precision: tuple[str, ...],
    samples: int,
    method: str,
    weight_type: str | None,
    activation_type: str | None,
    per_channel: bool,
    symmetric: bool,
    task: str | None,
    model_id: str | None,
    console: Console,
) -> None:
    """Execute a multi-pass quantization pipeline from ordered precision strings.

    Each entry of ``precision`` is mapped to a pass mode with
    ``_cli_precision_to_mode``; the passes produced by ``expand_precision`` for
    every mode are concatenated in order and run by a single ``Quantizer``.

    All passes share one ``WinMLQuantizationConfig``:

    * ``rtn_bits`` comes from the first weight-only precision (default 4).
    * ``weight_type`` / ``activation_type`` come from the first static
      precision via ``_resolve_quant_types``, the same resolver the
      single-precision path uses, so ``-p int16 -p fp16`` no longer silently
      falls back to uint8. Without a static pass the previous ``"uint8"``
      fallback is kept.

    Args:
        ctx: Click context, used to warn about ignored calibration options.
        model: Path to the input ONNX model.
        output: Destination path; when ``None`` it defaults to
            ``{stem}_{p1}_{p2}....onnx`` next to the input.
        overwrite: Whether an existing output may be replaced.
        precision: Ordered precision strings, one per requested pass.
        samples: Number of calibration samples.
        method: Calibration method name.
        weight_type: Explicit weight type override, if any.
        activation_type: Explicit activation type override, if any.
        per_channel: Forwarded to the quantization config.
        symmetric: Forwarded to the quantization config.
        task: Forwarded to the quantization config.
        model_id: Forwarded to the quantization config.
        console: Rich console used for all user-facing output.

    Raises:
        click.ClickException: If the pipeline reports failure or raises.
    """
    from ..config.precision import extract_weight_bits
    from ..quant import Quantizer, WinMLQuantizationConfig, expand_precision

    # Normalise once; every later lookup and label uses the lowered form.
    names = [p.lower() for p in precision]
    modes = [_cli_precision_to_mode(name) for name in names]

    if "static" not in modes:
        cli_utils.warn_ignored_calibration_options(
            ctx, "No selected pass uses calibration data.", console=console
        )

    # Extract rtn_bits from the first weight-only precision in the list.
    rtn_bits = next(
        (extract_weight_bits(name) for name, mode in zip(names, modes) if mode == "rtn"),
        4,
    )

    # Derive QDQ types from the first static precision so its bit width is
    # honoured. The original-case string is passed, as the single-precision
    # path does. NOTE(review): one shared config means several static passes
    # with different precisions all use the first one's types.
    static_precision = next(
        (p for p, mode in zip(precision, modes) if mode == "static"),
        None,
    )
    if static_precision is not None:
        resolved_weight, resolved_activation = _resolve_quant_types(
            static_precision, weight_type, activation_type
        )
    else:
        resolved_weight = weight_type or "uint8"
        resolved_activation = activation_type or "uint8"

    config = WinMLQuantizationConfig(
        rtn_bits=rtn_bits,
        samples=samples,
        calibration_method=cast('Literal["minmax", "entropy", "percentile"]', method),
        weight_type=cast('Literal["uint8", "int8", "uint16", "int16"]', resolved_weight),
        activation_type=cast(
            'Literal["uint8", "int8", "uint16", "int16"]', resolved_activation
        ),
        per_channel=per_channel,
        symmetric=symmetric,
        task=task,
        model_id=model_id,
    )

    passes = []
    for mode in modes:
        passes.extend(expand_precision(mode, config))

    label = " → ".join(names)
    if output is None:
        suffix = "_".join(names)
        output = model.parent / f"{model.stem}_{suffix}.onnx"

    # Refuse to clobber an existing output before creating directories or
    # doing any work.
    cli_utils.guard_output(output, overwrite)
    output.parent.mkdir(parents=True, exist_ok=True)
    console.print(f"[bold blue]Input:[/bold blue] {model}")
    console.print(f"[bold blue]Output:[/bold blue] {output}")
    console.print(f"[bold blue]Pipeline:[/bold blue] {label}")

    try:
        console.print(f"\n[bold]Running pipeline: {label}...[/bold]")
        result = Quantizer(passes).run(model, output)

        if result.success:
            console.print("\n[bold green]Success![/bold green] Pipeline complete")
            console.print(f"[dim]Output: {result.output_path}[/dim]")
            console.print(f"[dim]Total time: {result.total_time_seconds:.2f}s[/dim]")
        else:
            console.print("\n[bold red]Pipeline failed:[/bold red]")
            for error in result.errors:
                console.print(f"  {error}")
            raise click.ClickException("Pipeline failed")

    except click.ClickException:
        # Already user-facing; do not wrap or log it a second time.
        raise
    except Exception as e:
        console.print(f"\n[bold red]Pipeline failed:[/bold red] {e}")
        logger.exception("Pipeline failed")
        raise click.ClickException(f"Pipeline failed: {e}") from e


def _resolve_quant_types(
precision: str | None,
weight_type: str | None,
Expand Down
28 changes: 25 additions & 3 deletions src/winml/modelkit/quant/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,34 +7,56 @@
Provides QDQ (Quantize-Dequantize) quantization for ONNX models.

Usage:
from winml.modelkit.quant import quantize_onnx, WinMLQuantizationConfig
from winml.modelkit.quant import (
quantize_onnx,
Quantizer,
expand_precision,
WinMLQuantizationConfig,
)

# Quick quantize with defaults (10 samples, uint8)
result = quantize_onnx("model.onnx")

# Custom config
result = quantize_onnx("model.onnx", WinMLQuantizationConfig(samples=100))

# Pipeline: RTN int4 weight-only
config = WinMLQuantizationConfig(mode="rtn", rtn_bits=4)
result = Quantizer(expand_precision("rtn", config)).run("model.onnx", "out.onnx")
"""

from typing import Any
from typing import TYPE_CHECKING, Any

from .config import QuantizeResult, WinMLQuantizationConfig
from .passes import BaseQuantPass, FP16Pass, RTNPass, StaticPass


if TYPE_CHECKING:
from .quantizer import Quantizer, expand_precision, quantize_onnx


__all__ = [
"BaseQuantPass",
"FP16Pass",
"QuantizeResult",
"Quantizer",
Comment thread
github-advanced-security[bot] marked this conversation as resolved.
Fixed
"RTNPass",
"StaticPass",
"WinMLQuantizationConfig",
"expand_precision",
Comment thread
github-advanced-security[bot] marked this conversation as resolved.
Fixed
"quantize_onnx",
]


# Public name -> (relative module path, attribute name). The module-level
# ``__getattr__`` looks names up here so the quantizer module is only
# imported when one of these names is first accessed.
_LAZY_IMPORTS: dict[str, tuple[str, str]] = {
"quantize_onnx": (".quantizer", "quantize_onnx"),
"Quantizer": (".quantizer", "Quantizer"),
"expand_precision": (".quantizer", "expand_precision"),
}


def __getattr__(name: str) -> Any:
"""Lazy-load quantizer (imports onnxruntime.quantization)."""
"""Lazy-load quantizer module (avoids importing onnxruntime at package import time)."""
if name in _LAZY_IMPORTS:
module_path, attr_name = _LAZY_IMPORTS[name]
import importlib
Expand Down
18 changes: 18 additions & 0 deletions src/winml/modelkit/quant/passes/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
"""Quantization passes sub-package."""

from .base import BaseQuantPass
from .fp16 import FP16Pass
from .rtn import RTNPass
from .static import StaticPass


# Public API of the passes sub-package: the abstract base class plus the
# concrete pass classes imported above.
__all__ = [
"BaseQuantPass",
"FP16Pass",
"RTNPass",
"StaticPass",
]
65 changes: 65 additions & 0 deletions src/winml/modelkit/quant/passes/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
"""Base class for quantization passes."""

from __future__ import annotations

from abc import ABC, abstractmethod
from typing import TYPE_CHECKING


if TYPE_CHECKING:
from pathlib import Path

from ..config import QuantizeResult, WinMLQuantizationConfig


class BaseQuantPass(ABC):
    """Common interface implemented by every quantization pass.

    A pass is built from one ``WinMLQuantizationConfig`` carrying all
    settings. Each pass consults only the fields it cares about and ignores
    the others, which lets the same config instance be handed to every pass
    of a :class:`~winml.modelkit.quant.quantizer.Quantizer` pipeline.

    Example::

        pass_ = FP16Pass(config)
        result = pass_.run(model_path, output_path)
    """

    def __init__(self, config: WinMLQuantizationConfig) -> None:
        # Stored as-is (no copy), so every pass sees the same config object.
        self._config = config

    @property
    def config(self) -> WinMLQuantizationConfig:
        """The quantization configuration this pass was constructed with."""
        return self._config

    @abstractmethod
    def run(
        self,
        model_path: Path,
        output_path: Path,
        *,
        use_external_data: bool = True,
    ) -> QuantizeResult:
        """Apply this pass to a model on disk.

        Args:
            model_path: Location of the input ONNX model.
            output_path: Location the resulting ONNX model is written to.
            use_external_data: Whether large tensors are written as external
                data.

        Returns:
            A :class:`~winml.modelkit.quant.config.QuantizeResult` describing
            how the pass went.

        Note:
            The interface is file-based because ORT's calibration and RTN
            APIs work on paths, and models with external data cannot be held
            entirely in memory. An optional in-memory fast-path for small
            single-pass models is a possible future enhancement.
        """
Loading
Loading