From 3796b7eabe620c9712fb65941ac5d7f441b6e79f Mon Sep 17 00:00:00 2001 From: spalne Date: Mon, 8 Jun 2026 11:17:08 -0700 Subject: [PATCH 01/17] Add qauntization for transformers for qwen0.6B --- qwen3_quantize.py | 256 ++++++++++++++++++++++++ src/winml/modelkit/onnx/__init__.py | 2 + src/winml/modelkit/onnx/qwen_surgery.py | 186 +++++++++++++++++ test_qwen 2.py | 70 +++++++ 4 files changed, 514 insertions(+) create mode 100644 qwen3_quantize.py create mode 100644 src/winml/modelkit/onnx/qwen_surgery.py create mode 100644 test_qwen 2.py diff --git a/qwen3_quantize.py b/qwen3_quantize.py new file mode 100644 index 000000000..655c65e6a --- /dev/null +++ b/qwen3_quantize.py @@ -0,0 +1,256 @@ +"""Qwen3 transformer-only quantization. + +Must be called after the composite Qwen3 model has been built (e.g. by +``test_qwen 2.py``) so that ``decoder_prefill`` / ``decoder_gen`` ONNX files +exist in the winml cache. + +Pipeline: + + 1. Apply ``make_transformer_only`` surgery to each sub-model, producing + ``*_transformer.onnx`` with ``inputs_embeds`` input and + ``output_hidden_states`` output — embeddings and lm_head are stripped + out (ignored, not quantized). + 2. Quantize those transformer-only files via winml-cli's ``quantize_onnx`` + using a calibration reader that runs ``embed_tokens`` in PyTorch on + real text samples. +""" + +from __future__ import annotations + +import logging +from pathlib import Path +from typing import Any, Iterator + +import numpy as np +import torch +from transformers import AutoModelForCausalLM, AutoTokenizer + +from winml.modelkit.models.winml.composite_model import WinMLCompositeModel +from winml.modelkit.onnx import make_transformer_only +from winml.modelkit.quant import WinMLQuantizationConfig, quantize_onnx + + +logger = logging.getLogger(__name__) + +DEFAULT_MODEL_ID = "Qwen/Qwen3-0.6B" +DEFAULT_MAX_CACHE = 256 +DEFAULT_PREFILL_SEQ = 64 +DEFAULT_GEN_SEQ = 1 +DEFAULT_NUM_SAMPLES = 16 +DEFAULT_PROMPTS = [ + "Solve: 8 * 7 = ?", + "Translate to French: The weather is nice today.", + "Write a short poem about the ocean.", + "Explain gradient descent in one paragraph.", + "What is the capital of Japan?", + "List three uses of magnesium.", + "Summarize the plot of Hamlet in two sentences.", + "Give a Python one-liner to reverse a string.", +] + + +# --------------------------------------------------------------------------- +# Calibration data reader +# --------------------------------------------------------------------------- + + +class Qwen3TransformerCalibReader: + """Yields calibration feeds for the transformer-only Qwen3 ONNX. + + Runs HF ``embed_tokens`` in PyTorch to produce ``inputs_embeds`` since the + embedding layer was stripped from the ONNX graph. All other inputs + (attention_mask, position_ids, past_{i}_key/value) follow the conventions + used by winml-cli's ``WinMLQwen3Model`` runtime. + """ + + def __init__( + self, + embed_tokens: torch.nn.Module, + config: Any, + token_ids_list: list[torch.Tensor], + *, + seq_len: int, + max_cache_len: int, + ) -> None: + self.embed = embed_tokens + self.cfg = config + self.seq_len = seq_len + self.max_cache_len = max_cache_len + self.num_layers = config.num_hidden_layers + self.num_kv_heads = config.num_key_value_heads + self.head_dim = getattr( + config, "head_dim", config.hidden_size // config.num_attention_heads + ) + self._samples = list(self._build_samples(token_ids_list)) + self._iter: Iterator[dict[str, np.ndarray]] | None = None + self.rewind() + + def _build_samples( + self, token_ids_list: list[torch.Tensor] + ) -> Iterator[dict[str, np.ndarray]]: + for ids in token_ids_list: + # Right-truncate / pad to seq_len so we feed the static graph shape. + ids = ids[:, : self.seq_len] + real_len = ids.shape[1] + if real_len < self.seq_len: + pad = torch.zeros( + (1, self.seq_len - real_len), dtype=ids.dtype, device=ids.device + ) + ids = torch.cat([ids, pad], dim=1) + + with torch.no_grad(): + embeds = self.embed(ids).to(torch.float32).cpu().numpy() + + # attention_mask: ones for real prompt positions placed at the + # END of the max_cache buffer (sliding-window cache convention), + # zeros elsewhere. + attn_mask = np.zeros((1, self.max_cache_len), dtype=np.int64) + attn_mask[0, -real_len:] = 1 + + # position_ids: 0..seq_len-1 (clamped for padding). + position_ids = np.arange(self.seq_len, dtype=np.int64)[None, :] + + feed: dict[str, np.ndarray] = { + "inputs_embeds": embeds.astype(np.float32), + "attention_mask": attn_mask, + "position_ids": position_ids, + } + kv_shape = (1, self.num_kv_heads, self.max_cache_len, self.head_dim) + zeros = np.zeros(kv_shape, dtype=np.float32) + for i in range(self.num_layers): + feed[f"past_{i}_key"] = zeros + feed[f"past_{i}_value"] = zeros + yield feed + + def get_next(self) -> dict[str, np.ndarray] | None: + try: + return next(self._iter) if self._iter is not None else None + except StopIteration: + return None + + def rewind(self) -> None: + self._iter = iter(self._samples) + + +# --------------------------------------------------------------------------- +# Pipeline +# --------------------------------------------------------------------------- + + +def _tokenize_prompts( + tokenizer: Any, prompts: list[str], num_samples: int +) -> list[torch.Tensor]: + # Cycle through prompts up to num_samples; apply chat template like the + # runtime so calibration distribution matches inference inputs. + out: list[torch.Tensor] = [] + for i in range(num_samples): + prompt = prompts[i % len(prompts)] + text = tokenizer.apply_chat_template( + [{"role": "user", "content": prompt}], + tokenize=False, + add_generation_prompt=True, + enable_thinking=False, + ) + ids = tokenizer([text], return_tensors="pt").input_ids + out.append(ids) + return out + + +def quantize_built_model( + model: WinMLCompositeModel, + *, + model_id: str = DEFAULT_MODEL_ID, + max_cache_len: int = DEFAULT_MAX_CACHE, + prefill_seq: int = DEFAULT_PREFILL_SEQ, + num_samples: int = DEFAULT_NUM_SAMPLES, + weight_type: str = "uint8", + activation_type: str = "uint16", +) -> dict[str, Path]: + """Run surgery + transformer-only quantization on an already-built composite. + + Reuses the ONNX files produced by ``WinMLCompositeModel.from_pretrained`` + so this can be called after a build step without re-exporting. + + Returns: mapping of sub-model name → quantized ONNX path. + """ + sub_paths: dict[str, Path] = {} + for name, sub in model.sub_models.items(): + final_path = Path(sub._onnx_path) + # ``_model.onnx`` is the *compiled* QNN EPContext blob — surgery needs + # the uncompiled fp16 graph. ``build.hf`` emits ``{cache_key}_optimized.onnx`` + # alongside it in the same artifacts directory. + if final_path.name.endswith("_model.onnx"): + stem = final_path.name[: -len("_model.onnx")] + optimized = final_path.with_name(f"{stem}_optimized.onnx") + if optimized.exists(): + sub_paths[name] = optimized + continue + print( + f"WARNING: {optimized.name} not found next to {final_path.name}; " + "falling back to the compiled model (surgery will likely fail)." + ) + sub_paths[name] = final_path + + for name, p in sub_paths.items(): + print(f" {name}: {p}") + + print("\n=== Loading HF embed_tokens for calibration ===") + hf_model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float32) + hf_model.eval() + embed_tokens = hf_model.get_input_embeddings() + tokenizer = AutoTokenizer.from_pretrained(model_id) + token_ids_list = _tokenize_prompts(tokenizer, DEFAULT_PROMPTS, num_samples) + + seq_by_sub = { + "decoder_prefill": prefill_seq, + "decoder_gen": DEFAULT_GEN_SEQ, + } + + quant_paths: dict[str, Path] = {} + for sub_name, fused_path in sub_paths.items(): + if sub_name not in seq_by_sub: + print(f"\n--- Skipping unknown sub-model {sub_name!r} ---") + continue + + seq_len = seq_by_sub[sub_name] + transformer_path = fused_path.with_name(fused_path.stem + "_transformer.onnx") + quant_path = transformer_path.with_name( + transformer_path.stem + f"_w{weight_type[-1]}a{activation_type[-2:]}.quant.onnx" + ) + + print(f"\n=== Surgery: {sub_name} (seq_len={seq_len}) ===") + print(f" in : {fused_path}") + print(f" out: {transformer_path}") + make_transformer_only(fused_path, transformer_path) + + print(f"\n=== Quantize (transformer only): {sub_name} ===") + print(f" out: {quant_path}") + reader = Qwen3TransformerCalibReader( + embed_tokens, + hf_model.config, + token_ids_list, + seq_len=seq_len, + max_cache_len=max_cache_len, + ) + cfg = WinMLQuantizationConfig( + samples=num_samples, + weight_type=weight_type, # type: ignore[arg-type] + activation_type=activation_type, # type: ignore[arg-type] + calibration_method="minmax", + calibration_data=reader, + ) + result = quantize_onnx(transformer_path, output_path=quant_path, config=cfg) + if not result.success: + print(" FAILED:") + for err in result.errors: + print(f" {err}") + raise SystemExit(1) + print( + f" ok — {result.nodes_quantized} QDQ nodes inserted in " + f"{result.total_time_seconds:.1f}s" + ) + quant_paths[sub_name] = quant_path + + print("\n=== Done ===") + return quant_paths + diff --git a/src/winml/modelkit/onnx/__init__.py b/src/winml/modelkit/onnx/__init__.py index a3bc49d51..0287a2ff7 100644 --- a/src/winml/modelkit/onnx/__init__.py +++ b/src/winml/modelkit/onnx/__init__.py @@ -19,6 +19,7 @@ from .io import InputTensorSpec, OutputTensorSpec, generate_inputs_from_onnx, get_io_config from .metadata import capture_metadata, restore_metadata from .persistence import cleanup_onnx, load_onnx, save_onnx +from .qwen_surgery import make_transformer_only from .shape import infer_onnx_shapes, infer_shapes from .utils import EXTERNAL_DATA_THRESHOLD, check_onnx_model, get_model_size @@ -41,6 +42,7 @@ "is_compiled_onnx", "is_quantized_onnx", "load_onnx", + "make_transformer_only", "remove_optional_from_type_annotation", "restore_metadata", "save_onnx", diff --git a/src/winml/modelkit/onnx/qwen_surgery.py b/src/winml/modelkit/onnx/qwen_surgery.py new file mode 100644 index 000000000..cd49ee5ec --- /dev/null +++ b/src/winml/modelkit/onnx/qwen_surgery.py @@ -0,0 +1,186 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +"""Ad-hoc ONNX surgery to turn a Qwen3 decoder ONNX into a transformer-only graph. + +Applied as a post-export surgery on the fused decoder ONNX produced by +``WinMLQwen3Model`` (``decoder_prefill.onnx`` / ``decoder_gen.onnx``). + +The resulting transformer-only ONNX has: + - ``input_ids`` graph input replaced by ``inputs_embeds`` (FLOAT, + ``[batch, seq, hidden_size]``) — the upstream embedding Gather is + removed. + - ``logits`` graph output replaced by ``output_hidden_states`` + (FLOAT, ``[batch, seq, hidden_size]``) — the final ``lm_head`` MatMul + is removed. +""" + +from __future__ import annotations + +import logging +from pathlib import Path + +import onnx +from onnx import TensorProto, helper + +from .persistence import load_onnx, save_onnx + + +logger = logging.getLogger(__name__) + + +def _dim(d: onnx.TensorShapeProto.Dimension) -> int | str: + if d.HasField("dim_value"): + return d.dim_value + return d.dim_param or "?" + + +def make_transformer_only( + model_path: str | Path, + output_path: str | Path, + *, + input_ids_name: str = "input_ids", + logits_name: str = "logits", + inputs_embeds_name: str = "inputs_embeds", + output_hidden_states_name: str = "output_hidden_states", +) -> Path: + """Strip the embedding Gather and the lm_head MatMul from a Qwen3 ONNX. + + Args: + model_path: Path to the fused decoder ONNX (logits output, input_ids input). + output_path: Destination for the transformer-only ONNX. + input_ids_name: Name of the input_ids graph input to drop. + logits_name: Name of the logits graph output to drop. + inputs_embeds_name: Display name for the new embeddings input + (used only for logging; the actual tensor keeps its existing + internal name so downstream nodes need no rewiring). + output_hidden_states_name: Display name for the new hidden-state output. + + Returns: + The output path. + """ + model_path = Path(model_path) + output_path = Path(output_path) + output_path.parent.mkdir(parents=True, exist_ok=True) + + model = load_onnx(model_path, load_weights=True, validate=False) + graph = model.graph + init_by_name = {init.name: init for init in graph.initializer} + + # -------------------- Embedding removal -------------------- + embed_idx = next( + (i for i, n in enumerate(graph.node) if input_ids_name in n.input), + None, + ) + if embed_idx is None: + msg = f"No node consumes graph input {input_ids_name!r}" + raise RuntimeError(msg) + + embed_node = graph.node[embed_idx] + embed_out_name = embed_node.output[0] + + embed_weight = None + for ipt in embed_node.input: + init = init_by_name.get(ipt) + if init is not None and len(init.dims) == 2: + embed_weight = init + break + if embed_weight is None: + msg = f"Could not find 2-D embedding weight initializer on node {embed_node.name!r}" + raise RuntimeError(msg) + hidden_size = int(embed_weight.dims[1]) + + ids_input = next(i for i in graph.input if i.name == input_ids_name) + batch_dim = _dim(ids_input.type.tensor_type.shape.dim[0]) + seq_dim = _dim(ids_input.type.tensor_type.shape.dim[1]) + + logger.info( + "Removing embedding node %r (%s) — exposing %r as new input %r [%s, %s, %d]", + embed_node.name, + embed_node.op_type, + embed_out_name, + inputs_embeds_name, + batch_dim, + seq_dim, + hidden_size, + ) + + new_embed_input = helper.make_tensor_value_info( + inputs_embeds_name, + TensorProto.FLOAT, + [batch_dim, seq_dim, hidden_size], + ) + + del graph.node[embed_idx] + graph.input.remove(ids_input) + graph.input.append(new_embed_input) + graph.initializer.remove(embed_weight) + + # Rewire any consumer of the removed embedding output to the new input. + for n in graph.node: + for i, name in enumerate(n.input): + if name == embed_out_name: + n.input[i] = inputs_embeds_name + + # -------------------- lm_head removal -------------------- + lmh_idx = next( + (i for i, n in enumerate(graph.node) if logits_name in n.output), + None, + ) + if lmh_idx is None: + msg = f"No node produces graph output {logits_name!r}" + raise RuntimeError(msg) + + lmh_node = graph.node[lmh_idx] + init_names = {init.name for init in graph.initializer} + hidden_in: str | None = None + weight_in: str | None = None + for ipt in lmh_node.input: + if ipt in init_names: + weight_in = ipt + else: + hidden_in = ipt + if hidden_in is None: + msg = f"lm_head node {lmh_node.name!r} has no non-initializer input ({list(lmh_node.input)})" + raise RuntimeError(msg) + + logger.info( + "Removing lm_head node %r (%s) — exposing %r as new output %r", + lmh_node.name, + lmh_node.op_type, + hidden_in, + output_hidden_states_name, + ) + + logits_output = next(o for o in graph.output if o.name == logits_name) + new_hidden_output = helper.make_tensor_value_info( + output_hidden_states_name, + TensorProto.FLOAT, + [batch_dim, seq_dim, hidden_size], + ) + + del graph.node[lmh_idx] + graph.output.remove(logits_output) + # Put hidden states first so it mirrors the original logits position. + graph.output.insert(0, new_hidden_output) + + # Rename the producer of ``hidden_in`` to emit the new graph output name. + for n in graph.node: + for i, name in enumerate(n.output): + if name == hidden_in: + n.output[i] = output_hidden_states_name + for i, name in enumerate(n.input): + if name == hidden_in: + n.input[i] = output_hidden_states_name + + if weight_in is not None and not any(weight_in in n.input for n in graph.node): + wi = next(init for init in graph.initializer if init.name == weight_in) + graph.initializer.remove(wi) + + save_onnx(model, output_path) + logger.info("Wrote transformer-only ONNX → %s", output_path) + return output_path + + +__all__ = ["make_transformer_only"] diff --git a/test_qwen 2.py b/test_qwen 2.py new file mode 100644 index 000000000..6a52dee72 --- /dev/null +++ b/test_qwen 2.py @@ -0,0 +1,70 @@ +"""E2E test for Qwen3 decoder-only pipeline. + +Uses sub_model_kwargs to set per-component shape_config: + - decoder_prefill: max_cache_len=256, seq_len=64 + - decoder_gen: max_cache_len=256, seq_len=1 + +Set env var ``QUANTIZE=1`` to also run the MOPS-style Step 3: +transformer-only surgery + winml quantize on both sub-models +(embeddings and lm_head are stripped and not quantized). +""" + +import os + +from transformers import AutoTokenizer + +from winml.modelkit.config import WinMLBuildConfig +from winml.modelkit.models.winml.composite_model import WinMLCompositeModel + +model_id = "Qwen/Qwen3-0.6B" + +model = WinMLCompositeModel.from_pretrained( + model_id, + task="text-generation", + # config=WinMLBuildConfig(quant=None, compile=None), + config=WinMLBuildConfig(quant=None), + precision="fp16", + device="npu", + ep="qnn", + force_rebuild=False, + sub_model_kwargs={ + "decoder_prefill": {"shape_config": {"max_cache_len": 256, "seq_len": 64}}, + "decoder_gen": {"shape_config": {"max_cache_len": 256, "seq_len": 1}}, + }, +) + +# Verify ONNX I/O shapes +for name, sub in model.sub_models.items(): + io = sub.io_config + shapes = dict(zip(io["input_names"], io["input_shapes"])) + print(f"\n=== {name} ===") + for k, v in shapes.items(): + print(f" {k}: {v}") + +tokenizer = AutoTokenizer.from_pretrained(model_id) + +prompt = "8 * 7 = ?" +messages = [{"role": "user", "content": prompt}] +text = tokenizer.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True, enable_thinking=False, +) +model_inputs = tokenizer([text], return_tensors="pt").to(model.device) + +generated_ids = model.generate(**model_inputs) + +output_ids = generated_ids[0][len(model_inputs.input_ids[0]) :].tolist() +content = tokenizer.decode(output_ids, skip_special_tokens=True) +print("\nAnswer:", content) + +if os.environ.get("QUANTIZE") == "1": + # Reuse the already-built decoder_prefill/decoder_gen ONNX files: + # surgery (strip embed + lm_head) + transformer-only quantize. + print("\n=== QUANTIZE=1 — running transformer-only quantization ===") + from qwen3_quantize import quantize_built_model + + quantize_built_model( + model, + model_id=model_id, + max_cache_len=256, + prefill_seq=64, + ) From 1ee316c8350d9a904e5e06a51dacb1a7186658d0 Mon Sep 17 00:00:00 2001 From: spalne Date: Tue, 16 Jun 2026 15:07:46 -0700 Subject: [PATCH 02/17] Quantize transformer-only with fused GQA + GSM8k calibration --- ...e.py => qwen3_transformer_only_quantize.py | 152 ++++---- .../modelkit/models/hf/qwen3_export_ops.py | 211 +++++++++++ .../modelkit/models/hf/qwen3_modeling.py | 237 ++++++++++++ .../models/hf/qwen_transformer_only.py | 354 ++++++++++++++++++ src/winml/modelkit/onnx/__init__.py | 2 - src/winml/modelkit/onnx/qwen_surgery.py | 186 --------- test_qwen 2.py | 70 ---- test_qwen.py | 235 ++++++++++++ 8 files changed, 1100 insertions(+), 347 deletions(-) rename qwen3_quantize.py => qwen3_transformer_only_quantize.py (54%) create mode 100644 src/winml/modelkit/models/hf/qwen3_export_ops.py create mode 100644 src/winml/modelkit/models/hf/qwen3_modeling.py create mode 100644 src/winml/modelkit/models/hf/qwen_transformer_only.py delete mode 100644 src/winml/modelkit/onnx/qwen_surgery.py delete mode 100644 test_qwen 2.py create mode 100644 test_qwen.py diff --git a/qwen3_quantize.py b/qwen3_transformer_only_quantize.py similarity index 54% rename from qwen3_quantize.py rename to qwen3_transformer_only_quantize.py index 655c65e6a..8b4efa9b7 100644 --- a/qwen3_quantize.py +++ b/qwen3_transformer_only_quantize.py @@ -1,18 +1,15 @@ -"""Qwen3 transformer-only quantization. +"""Transformer-only w8a16 quantization for Qwen3. -Must be called after the composite Qwen3 model has been built (e.g. by -``test_qwen 2.py``) so that ``decoder_prefill`` / ``decoder_gen`` ONNX files -exist in the winml cache. +Targets the transformer-only ONNX produced by +``qwen_transformer_only.install() + test_qwen.py``: -Pipeline: + - **No embedding/lm_head surgery.** The export already excludes both, + so we feed ``WinMLQuantization`` the file directly. + - **Transformer-shaped calibration feeds.** ``input_hidden_states`` (FP32), + ``past_seq_len`` / ``total_seq_len`` (INT32), ``past_keys_{i}`` / + ``past_values_{i}`` (FP16) — names + dtypes match the exported graph. - 1. Apply ``make_transformer_only`` surgery to each sub-model, producing - ``*_transformer.onnx`` with ``inputs_embeds`` input and - ``output_hidden_states`` output — embeddings and lm_head are stripped - out (ignored, not quantized). - 2. Quantize those transformer-only files via winml-cli's ``quantize_onnx`` - using a calibration reader that runs ``embed_tokens`` in PyTorch on - real text samples. +Run via ``test_qwen.py``. """ from __future__ import annotations @@ -26,7 +23,6 @@ from transformers import AutoModelForCausalLM, AutoTokenizer from winml.modelkit.models.winml.composite_model import WinMLCompositeModel -from winml.modelkit.onnx import make_transformer_only from winml.modelkit.quant import WinMLQuantizationConfig, quantize_onnx @@ -36,31 +32,28 @@ DEFAULT_MAX_CACHE = 256 DEFAULT_PREFILL_SEQ = 64 DEFAULT_GEN_SEQ = 1 -DEFAULT_NUM_SAMPLES = 16 -DEFAULT_PROMPTS = [ - "Solve: 8 * 7 = ?", - "Translate to French: The weather is nice today.", - "Write a short poem about the ocean.", - "Explain gradient descent in one paragraph.", - "What is the capital of Japan?", - "List three uses of magnesium.", - "Summarize the plot of Hamlet in two sentences.", - "Give a Python one-liner to reverse a string.", -] - - -# --------------------------------------------------------------------------- -# Calibration data reader -# --------------------------------------------------------------------------- - - -class Qwen3TransformerCalibReader: - """Yields calibration feeds for the transformer-only Qwen3 ONNX. - - Runs HF ``embed_tokens`` in PyTorch to produce ``inputs_embeds`` since the - embedding layer was stripped from the ONNX graph. All other inputs - (attention_mask, position_ids, past_{i}_key/value) follow the conventions - used by winml-cli's ``WinMLQwen3Model`` runtime. +DEFAULT_NUM_SAMPLES = 30 +DEFAULT_CALIB_DATASET = "openai/gsm8k" +DEFAULT_CALIB_DATASET_CONFIG = "main" +DEFAULT_CALIB_SPLIT = "train" +DEFAULT_CALIB_SEED = 42 + + +def _load_gsm8k_prompts(num_samples: int) -> list[str]: + """GSM8K train split, shuffled seed=42 for reproducible calibration.""" + from datasets import load_dataset + + ds = load_dataset(DEFAULT_CALIB_DATASET, DEFAULT_CALIB_DATASET_CONFIG) + split = ds[DEFAULT_CALIB_SPLIT].shuffle(seed=DEFAULT_CALIB_SEED) + return [row["question"] for row in split.select(range(num_samples))] + + +class Qwen3TransformerOnlyCalibReader: + """Yields calibration feeds for the transformer-only ONNX. + + Feeds match the exported graph exactly: ``input_hidden_states`` (FP32), + ``past_seq_len`` (INT32 ``[1,1]``), ``total_seq_len`` (INT32 ``[1]``), + and ``past_keys_{i}`` / ``past_values_{i}`` (FP16, full cache buffer). """ def __init__( @@ -73,7 +66,6 @@ def __init__( max_cache_len: int, ) -> None: self.embed = embed_tokens - self.cfg = config self.seq_len = seq_len self.max_cache_len = max_cache_len self.num_layers = config.num_hidden_layers @@ -85,11 +77,8 @@ def __init__( self._iter: Iterator[dict[str, np.ndarray]] | None = None self.rewind() - def _build_samples( - self, token_ids_list: list[torch.Tensor] - ) -> Iterator[dict[str, np.ndarray]]: + def _build_samples(self, token_ids_list: list[torch.Tensor]) -> Iterator[dict[str, np.ndarray]]: for ids in token_ids_list: - # Right-truncate / pad to seq_len so we feed the static graph shape. ids = ids[:, : self.seq_len] real_len = ids.shape[1] if real_len < self.seq_len: @@ -101,25 +90,22 @@ def _build_samples( with torch.no_grad(): embeds = self.embed(ids).to(torch.float32).cpu().numpy() - # attention_mask: ones for real prompt positions placed at the - # END of the max_cache buffer (sliding-window cache convention), - # zeros elsewhere. - attn_mask = np.zeros((1, self.max_cache_len), dtype=np.int64) - attn_mask[0, -real_len:] = 1 - - # position_ids: 0..seq_len-1 (clamped for padding). - position_ids = np.arange(self.seq_len, dtype=np.int64)[None, :] - feed: dict[str, np.ndarray] = { - "inputs_embeds": embeds.astype(np.float32), - "attention_mask": attn_mask, - "position_ids": position_ids, + "input_hidden_states": embeds.astype(np.float32), + # seqlens_k for GQA = (valid context length - 1), i.e. + # ``embeddings.shape[1] - 1``. We pad to seq_len, so the query + # has seq_len valid positions → past_seq_len = seq_len - 1. + # (Using 0 here declares only 1 valid token while feeding a + # seq_len-token query, which makes the GQA prefill kernel read + # out of bounds → native access violation.) + "past_seq_len": np.array([[self.seq_len - 1]], dtype=np.int32), + "total_seq_len": np.array([self.max_cache_len], dtype=np.int32), } kv_shape = (1, self.num_kv_heads, self.max_cache_len, self.head_dim) - zeros = np.zeros(kv_shape, dtype=np.float32) + zeros = np.zeros(kv_shape, dtype=np.float16) for i in range(self.num_layers): - feed[f"past_{i}_key"] = zeros - feed[f"past_{i}_value"] = zeros + feed[f"past_keys_{i}"] = zeros + feed[f"past_values_{i}"] = zeros yield feed def get_next(self) -> dict[str, np.ndarray] | None: @@ -132,16 +118,7 @@ def rewind(self) -> None: self._iter = iter(self._samples) -# --------------------------------------------------------------------------- -# Pipeline -# --------------------------------------------------------------------------- - - -def _tokenize_prompts( - tokenizer: Any, prompts: list[str], num_samples: int -) -> list[torch.Tensor]: - # Cycle through prompts up to num_samples; apply chat template like the - # runtime so calibration distribution matches inference inputs. +def _tokenize_prompts(tokenizer: Any, prompts: list[str], num_samples: int) -> list[torch.Tensor]: out: list[torch.Tensor] = [] for i in range(num_samples): prompt = prompts[i % len(prompts)] @@ -166,19 +143,15 @@ def quantize_built_model( weight_type: str = "uint8", activation_type: str = "uint16", ) -> dict[str, Path]: - """Run surgery + transformer-only quantization on an already-built composite. - - Reuses the ONNX files produced by ``WinMLCompositeModel.from_pretrained`` - so this can be called after a build step without re-exporting. + """Quantize the transformer-only ONNX files in-place. - Returns: mapping of sub-model name → quantized ONNX path. + Returns ``{sub_model_name: quantized_path}``. """ + # Locate the un-compiled ONNX for each sub-model (no surgery — file is + # already transformer-only). sub_paths: dict[str, Path] = {} for name, sub in model.sub_models.items(): final_path = Path(sub._onnx_path) - # ``_model.onnx`` is the *compiled* QNN EPContext blob — surgery needs - # the uncompiled fp16 graph. ``build.hf`` emits ``{cache_key}_optimized.onnx`` - # alongside it in the same artifacts directory. if final_path.name.endswith("_model.onnx"): stem = final_path.name[: -len("_model.onnx")] optimized = final_path.with_name(f"{stem}_optimized.onnx") @@ -187,7 +160,7 @@ def quantize_built_model( continue print( f"WARNING: {optimized.name} not found next to {final_path.name}; " - "falling back to the compiled model (surgery will likely fail)." + "falling back to the compiled model." ) sub_paths[name] = final_path @@ -199,7 +172,14 @@ def quantize_built_model( hf_model.eval() embed_tokens = hf_model.get_input_embeddings() tokenizer = AutoTokenizer.from_pretrained(model_id) - token_ids_list = _tokenize_prompts(tokenizer, DEFAULT_PROMPTS, num_samples) + + print( + f"=== Loading {num_samples} GSM8K calibration prompts " + f"({DEFAULT_CALIB_DATASET}/{DEFAULT_CALIB_DATASET_CONFIG}, " + f"split={DEFAULT_CALIB_SPLIT}, seed={DEFAULT_CALIB_SEED}) ===" + ) + prompts = _load_gsm8k_prompts(num_samples) + token_ids_list = _tokenize_prompts(tokenizer, prompts, num_samples) seq_by_sub = { "decoder_prefill": prefill_seq, @@ -213,19 +193,14 @@ def quantize_built_model( continue seq_len = seq_by_sub[sub_name] - transformer_path = fused_path.with_name(fused_path.stem + "_transformer.onnx") - quant_path = transformer_path.with_name( - transformer_path.stem + f"_w{weight_type[-1]}a{activation_type[-2:]}.quant.onnx" + quant_path = fused_path.with_name( + fused_path.stem + f"_w{weight_type[-1]}a{activation_type[-2:]}.quant.onnx" ) - print(f"\n=== Surgery: {sub_name} (seq_len={seq_len}) ===") + print(f"\n=== Quantize (transformer-only): {sub_name} (seq_len={seq_len}) ===") print(f" in : {fused_path}") - print(f" out: {transformer_path}") - make_transformer_only(fused_path, transformer_path) - - print(f"\n=== Quantize (transformer only): {sub_name} ===") print(f" out: {quant_path}") - reader = Qwen3TransformerCalibReader( + reader = Qwen3TransformerOnlyCalibReader( embed_tokens, hf_model.config, token_ids_list, @@ -239,7 +214,7 @@ def quantize_built_model( calibration_method="minmax", calibration_data=reader, ) - result = quantize_onnx(transformer_path, output_path=quant_path, config=cfg) + result = quantize_onnx(fused_path, output_path=quant_path, config=cfg) if not result.success: print(" FAILED:") for err in result.errors: @@ -253,4 +228,3 @@ def quantize_built_model( print("\n=== Done ===") return quant_paths - diff --git a/src/winml/modelkit/models/hf/qwen3_export_ops.py b/src/winml/modelkit/models/hf/qwen3_export_ops.py new file mode 100644 index 000000000..61d45f0ef --- /dev/null +++ b/src/winml/modelkit/models/hf/qwen3_export_ops.py @@ -0,0 +1,211 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +"""Custom ONNX export ops + the entry point that reshapes HF's Qwen3 modules +for the transformer-only export. + +These reshape the standard HF Qwen3 modules so winml-cli can produce a +QNN-friendly, transformer-only graph: + +- ``LpNormalization`` replaces the eager RMSNorm Mul/Pow/ReduceMean chain. +- ``com.microsoft::GroupQueryAttention`` replaces the eager QKV MatMul + + Softmax + KV-update path (with built-in rotary). +- 1x1 ``Conv`` (NHWC<->NCHW) replaces ``nn.Linear`` for QNN-friendly + projections. + +Everything here operates only on the standard ``transformers.models.qwen3`` +module attributes. +""" + +from __future__ import annotations + +import torch +import torch.nn as nn +from torch.onnx import symbolic_helper + + +# ============================================================================= +# Custom ONNX symbolic functions +# ============================================================================= + + +class LpNormOnnxExport(torch.autograd.Function): + """RMSNorm body → ONNX ``LpNormalization`` (p=2 along last dim).""" + + @staticmethod + def symbolic(g, input, axis, p): # noqa: D401 + output_type = input.type().with_sizes(symbolic_helper._get_tensor_sizes(input)) + output = g.op( + "onnx::LpNormalization", + input, + axis_i=int(axis), + p_i=int(p), + ) + return output.setType(output_type) + + @staticmethod + def forward(ctx, input, axis, p): # noqa: ARG004 + return input # placeholder — real compute happens in symbolic + + +class GroupQueryAttentionOnnxExport(torch.autograd.Function): + """Fused Q/K/V + KV-cache + rotary → ``com.microsoft::GroupQueryAttention``.""" + + @staticmethod + def symbolic( + g, + query, + key, + value, + past_key, + past_value, + seqlens_k, + total_sequence_length, + cos_cache, + sin_cache, + do_rotary, + kv_num_heads, + num_heads, + ): + args = [query, key, value, past_key, past_value, seqlens_k, total_sequence_length, cos_cache, sin_cache] + attention_output, present_keys, present_values = g.op( + "com.microsoft::GroupQueryAttention", + *args, + do_rotary_i=int(do_rotary), + kv_num_heads_i=int(kv_num_heads), + num_heads_i=int(num_heads), + outputs=3, + ) + + query_sizes = symbolic_helper._get_tensor_sizes(query) + attention_output.setType(query.type().with_sizes(query_sizes)) + present_keys.setType(past_key.type().with_sizes(symbolic_helper._get_tensor_sizes(past_key))) + present_values.setType(past_value.type().with_sizes(symbolic_helper._get_tensor_sizes(past_value))) + return attention_output, present_keys, present_values + + @staticmethod + def forward( + ctx, + query, + key, + value, + past_key, + past_value, + seqlens_k, + total_sequence_length, + cos_cache, + sin_cache, + do_rotary, + kv_num_heads, + num_heads, + ): # noqa: ARG004 + return query, past_key, past_value # placeholder shapes + + +# ============================================================================= +# 1x1 Conv replacement for nn.Linear +# ============================================================================= + + +class TransposeConv2d1x1Transpose(nn.Module): + """``nn.Linear`` → 1x1 ``Conv2d`` with NHWC<->NCHW permutes.""" + + def __init__( + self, + in_channels: int, + out_channels: int, + weight: torch.nn.Parameter, + bias: torch.nn.Parameter | None = None, + ) -> None: + super().__init__() + # Linear weight is (out, in); Conv2d weight is (out, in, 1, 1). + self.weight = nn.Parameter(weight.data.view(out_channels, in_channels, 1, 1)) + self.bias = bias + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = x.permute(0, 3, 1, 2) # NHWC -> NCHW + x = torch.nn.functional.conv2d(x, self.weight) + x = x.permute(0, 2, 3, 1) # NCHW -> NHWC + if self.bias is not None: + x = x + self.bias + return x + + @classmethod + def from_linear_module(cls, linear: nn.Linear) -> TransposeConv2d1x1Transpose: + return cls(linear.in_features, linear.out_features, linear.weight, linear.bias) + + +# ============================================================================= +# Apply export prep: bind winml Qwen3 export methods onto a loaded model +# ============================================================================= + + +def apply_transformer_only_export_prep(causal_lm: nn.Module, *, matmul_to_conv: bool = True) -> None: + """Mutate ``Qwen3ForCausalLM`` in-place into the export topology. + + Binds the winml-owned export behaviour from :mod:`.qwen3_modeling` onto each + Qwen3 submodule (runs ``prepare_for_onnx_export`` and rebinds ``forward``). + After this call, ``causal_lm.model(inputs_embeds, past_key_values, + past_seq_len, total_seq_len)`` runs the transformer-only forward. + + Args: + causal_lm: A ``transformers.Qwen3ForCausalLM`` instance. + matmul_to_conv: Swap ``nn.Linear`` projections to 1x1 ``Conv2d`` so + QNN sees them as Conv. + """ + from .qwen3_modeling import ( + WinMLQwen3Attention, + WinMLQwen3DecoderLayer, + WinMLQwen3MLP, + WinMLQwen3Model, + WinMLQwen3RMSNorm, + ) + + def _bind(module: nn.Module, owner: type) -> None: + module.forward = owner.forward.__get__(module, type(module)) + + # Identify Qwen3 submodules by their (stock HF) class name so we don't + # depend on importing ``transformers.models.qwen3`` here. + def _is(module: nn.Module, name: str) -> bool: + return type(module).__name__ == name + + # Patch every RMSNorm first (Qwen3RMSNorm appears at top, in q_norm/k_norm, + # in input/post_attention layernorms). + for mod in causal_lm.modules(): + if _is(mod, "Qwen3RMSNorm"): + WinMLQwen3RMSNorm.prepare_for_onnx_export(mod) + _bind(mod, WinMLQwen3RMSNorm) + + for mod in causal_lm.modules(): + if _is(mod, "Qwen3Attention"): + WinMLQwen3Attention.prepare_for_onnx_export(mod, matmul_to_conv=matmul_to_conv) + _bind(mod, WinMLQwen3Attention) + elif _is(mod, "Qwen3MLP"): + # MLP forward is unchanged; only the projections are swapped to Conv. + WinMLQwen3MLP.prepare_for_onnx_export(mod, matmul_to_conv=matmul_to_conv) + + # HF moved ``rotary_emb`` from ``Qwen3Attention`` up to ``Qwen3Model``; + # the export forward invokes ``self.rotary_emb`` on the attention module, + # so re-attach a reference from the parent model. + for mod in causal_lm.modules(): + if _is(mod, "Qwen3Model") and hasattr(mod, "rotary_emb"): + for layer in mod.layers: + layer.self_attn.rotary_emb = mod.rotary_emb + + for mod in causal_lm.modules(): + if _is(mod, "Qwen3DecoderLayer"): + _bind(mod, WinMLQwen3DecoderLayer) + + for mod in causal_lm.modules(): + if _is(mod, "Qwen3Model"): + WinMLQwen3Model.prepare_for_onnx_export(mod, matmul_to_conv=matmul_to_conv) + _bind(mod, WinMLQwen3Model) + + +__all__ = [ + "GroupQueryAttentionOnnxExport", + "LpNormOnnxExport", + "TransposeConv2d1x1Transpose", + "apply_transformer_only_export_prep", +] diff --git a/src/winml/modelkit/models/hf/qwen3_modeling.py b/src/winml/modelkit/models/hf/qwen3_modeling.py new file mode 100644 index 000000000..05a70adfe --- /dev/null +++ b/src/winml/modelkit/models/hf/qwen3_modeling.py @@ -0,0 +1,237 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +"""winml-owned Qwen3 model definitions for the transformer-only ONNX export. + +Each class is a plain ``nn.Module`` that carries the export-time behaviour +directly (``prepare_for_onnx_export`` + ``forward``). The export entry point +binds these ``forward`` methods onto the corresponding live Qwen3 submodules, +so the stock eager model is left untouched. + +What each class emits: + +- ``WinMLQwen3RMSNorm`` -> ``onnx::LpNormalization`` body. +- ``WinMLQwen3Attention`` -> ``com.microsoft::GroupQueryAttention`` (built-in + rotary) with optional 1x1 ``Conv`` projections. +- ``WinMLQwen3MLP`` -> 1x1 ``Conv`` projections (NHWC). +- ``WinMLQwen3DecoderLayer`` / ``WinMLQwen3Model`` -> transformer-only forward + that threads the KV cache + seq-len tensors and omits embeddings / lm_head. + +``apply_transformer_only_export_prep`` (in ``qwen3_export_ops``) walks a loaded +``Qwen3ForCausalLM``, calls ``prepare_for_onnx_export`` on each submodule, and +binds the matching ``forward`` from these classes onto it. +""" + +from __future__ import annotations + +from typing import Any + +import torch +import torch.nn as nn + +from .qwen3_export_ops import ( + GroupQueryAttentionOnnxExport, + LpNormOnnxExport, + TransposeConv2d1x1Transpose, +) + + +class WinMLQwen3RMSNorm(nn.Module): + """RMSNorm export variant — ``onnx::LpNormalization`` body.""" + + def prepare_for_onnx_export(self) -> None: + # Pre-multiply the gain into the weight (LpNorm has unit gain). + n = self.weight.numel() + scale = torch.sqrt( + torch.tensor([n], device=self.weight.device, dtype=self.weight.dtype) + ) + if torch.any(self.weight.data != torch.ones_like(self.weight)).item(): + new_w = scale * self.weight + else: + new_w = scale + self.weight = nn.Parameter(new_w) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + out = LpNormOnnxExport.apply(hidden_states, -1, 2) + return self.weight * out + + +class WinMLQwen3MLP(nn.Module): + """MLP export variant — 1x1 Conv projections (forward unchanged).""" + + def prepare_for_onnx_export(self, *, matmul_to_conv: bool) -> None: + if not matmul_to_conv: + return + self.gate_proj = TransposeConv2d1x1Transpose.from_linear_module(self.gate_proj) + self.up_proj = TransposeConv2d1x1Transpose.from_linear_module(self.up_proj) + self.down_proj = TransposeConv2d1x1Transpose.from_linear_module(self.down_proj) + + +class WinMLQwen3Attention(nn.Module): + """Attention export variant — fused ``GroupQueryAttention`` op.""" + + def prepare_for_onnx_export(self, *, matmul_to_conv: bool) -> None: + if matmul_to_conv: + self.q_proj = TransposeConv2d1x1Transpose.from_linear_module(self.q_proj) + self.k_proj = TransposeConv2d1x1Transpose.from_linear_module(self.k_proj) + self.v_proj = TransposeConv2d1x1Transpose.from_linear_module(self.v_proj) + self.o_proj = TransposeConv2d1x1Transpose.from_linear_module(self.o_proj) + self._matmul_to_conv = matmul_to_conv # noqa: SLF001 + + def forward( + self, + hidden_states: torch.Tensor, + past_key_value: tuple[torch.Tensor, torch.Tensor] | None = None, + past_seq_len: torch.Tensor | None = None, + total_seq_len: torch.Tensor | None = None, + **kwargs: Any, # noqa: ARG002 + ) -> tuple[torch.Tensor, None, tuple[torch.Tensor, torch.Tensor]]: + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + input_shape = hidden_states.shape[1:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + query_states = self.q_norm(query_states.view(hidden_shape)) + key_states = self.k_norm(key_states.view(hidden_shape)) + + num_heads = self.config.num_attention_heads + num_kv_heads = self.config.num_key_value_heads + query_dim = num_heads * self.head_dim + key_dim = num_kv_heads * self.head_dim + query_states = query_states.reshape(1, -1, query_dim) + key_states = key_states.reshape(1, -1, key_dim) + + if self._matmul_to_conv: + value_states = value_states.squeeze(0) + + past_keys, past_values = past_key_value + + # GroupQueryAttention requires Q/K/V/past_K/past_V to share dtype. + # The KV cache is FP16, so cast Q/K/V to the same dtype; otherwise ORT + # type inference rejects the node. + kv_dtype = past_keys.dtype + if query_states.dtype != kv_dtype: + query_states = query_states.to(kv_dtype) + key_states = key_states.to(kv_dtype) + value_states = value_states.to(kv_dtype) + + cos, sin = self.rotary_emb( + value_states, + torch.arange(self.config.max_position_embeddings).unsqueeze(0), + ) + cos = cos.squeeze(0)[:, : cos.shape[-1] // 2] + sin = sin.squeeze(0)[:, : sin.shape[-1] // 2] + if cos.dtype != kv_dtype: + cos = cos.to(kv_dtype) + sin = sin.to(kv_dtype) + + if isinstance(past_seq_len, int): + past_seq_len = torch.tensor(past_seq_len) + past_seq_len = torch.atleast_2d(past_seq_len) + + attention_output, present_keys, present_values = GroupQueryAttentionOnnxExport.apply( + query_states, + key_states, + value_states, + past_keys, + past_values, + past_seq_len, + total_seq_len, + cos, + sin, + 1, # do_rotary + num_kv_heads, + num_heads, + ) + + # Cast back to the residual-stream dtype so the downstream Conv + # (o_proj) sees its expected weight dtype. + if attention_output.dtype != hidden_states.dtype: + attention_output = attention_output.to(hidden_states.dtype) + + if self._matmul_to_conv: + attention_output = attention_output.unsqueeze(0) + + attention_output = self.o_proj(attention_output) + return attention_output, None, (present_keys, present_values) + + +class WinMLQwen3DecoderLayer(nn.Module): + """Decoder-layer export variant — threads KV cache + seq-len kwargs.""" + + def forward( + self, + hidden_states: torch.Tensor, + past_key_value: tuple[torch.Tensor, torch.Tensor] | None = None, + past_seq_len: torch.Tensor | None = None, + total_seq_len: torch.Tensor | None = None, + use_cache: bool = True, + **kwargs: Any, # noqa: ARG002 + ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + attn_out, _, present_kv = self.self_attn( + hidden_states=hidden_states, + past_key_value=past_key_value, + past_seq_len=past_seq_len, + total_seq_len=total_seq_len, + ) + hidden_states = residual + attn_out + + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + if use_cache: + outputs += (present_kv,) + return outputs + + +class WinMLQwen3Model(nn.Module): + """Model export variant — transformer-only body (no embeddings / lm_head).""" + + def prepare_for_onnx_export(self, *, matmul_to_conv: bool) -> None: + self._matmul_to_conv = matmul_to_conv # noqa: SLF001 + + def forward( + self, + inputs_embeds: torch.Tensor, + past_key_values: list[tuple[torch.Tensor, torch.Tensor]], + past_seq_len: torch.Tensor, + total_seq_len: torch.Tensor, + use_cache: bool = True, + ) -> tuple[torch.Tensor, tuple[tuple[torch.Tensor, torch.Tensor], ...]]: + hidden_states = inputs_embeds + if self._matmul_to_conv: + hidden_states = hidden_states.unsqueeze(0) # NHWC for Conv path + + present_kvs: tuple[tuple[torch.Tensor, torch.Tensor], ...] = () + for idx, layer in enumerate(self.layers): + out = layer( + hidden_states, + past_key_value=past_key_values[idx], + past_seq_len=past_seq_len, + total_seq_len=total_seq_len, + use_cache=use_cache, + ) + hidden_states = out[0] + if use_cache: + present_kvs += (out[1],) + + hidden_states = self.norm(hidden_states) + if self._matmul_to_conv: + hidden_states = hidden_states.squeeze(0) + return hidden_states, present_kvs + + +__all__ = [ + "WinMLQwen3Attention", + "WinMLQwen3DecoderLayer", + "WinMLQwen3MLP", + "WinMLQwen3Model", + "WinMLQwen3RMSNorm", +] diff --git a/src/winml/modelkit/models/hf/qwen_transformer_only.py b/src/winml/modelkit/models/hf/qwen_transformer_only.py new file mode 100644 index 000000000..8e30b1fb6 --- /dev/null +++ b/src/winml/modelkit/models/hf/qwen_transformer_only.py @@ -0,0 +1,354 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +"""Parallel ``qwen3`` build path that produces a transformer-only ONNX. + +Opt-in via ``install()`` — calling it hot-patches the WinML registries so +that the next ``WinMLAutoModel.from_pretrained("Qwen/Qwen3-*", task="text-generation")`` +exports two transformer-only ONNX files (a prefill/context graph and an +iteration/decode graph) with this I/O: + + Inputs : past_keys_{i}, past_values_{i} (FP16, ``[1, kv_heads, max_cache, head_dim]``), + input_hidden_states (FP32, ``[1, seq_len, hidden]``), + past_seq_len (INT32, ``[1, 1]``), total_seq_len (INT32, ``[1]``) + Outputs: output_hidden_states (FP32), present_keys_{i}, present_values_{i} (FP16) + Ops : ``com.microsoft::GroupQueryAttention`` (do_rotary=1), + ``onnx::LpNormalization`` (RMSNorm), 1x1 ``Conv`` projections. + +The original eager-export path in ``qwen.py`` is left intact — only the +qwen3 entries in the registries are replaced. ``install()`` is idempotent. +""" + +from __future__ import annotations + +import logging +from typing import Any, ClassVar + +import torch +import torch.nn as nn +from optimum.exporters.onnx import OnnxConfig +from optimum.utils import NormalizedConfig +from optimum.utils.input_generators import DummyInputGenerator +from transformers import AutoModelForCausalLM + +from ...config import WinMLBuildConfig +from ...export import register_onnx_overwrite +from ...export.config import WinMLExportConfig +from ..winml import register_specialization +from ..winml.decoder_only import WinMLDecoderOnlyModel +from ..winml.kv_cache import WinMLSlidingWindowCache +from .qwen3_export_ops import apply_transformer_only_export_prep + + +logger = logging.getLogger(__name__) + + +# ============================================================================= +# Wrapper module +# ============================================================================= + + +class QwenTransformerOnlyDecoderWrapper(nn.Module): + """Wraps ``Qwen3ForCausalLM`` for transformer-only export. + + The wrapper applies the export prep (LpNorm RMSNorm, GQA op, 1x1 + Conv projections) in ``__init__`` and exposes a positional ``forward`` + whose argument order matches :class:`QwenTransformerOnlyPrefillIOConfig.inputs`. + Only ``self.model.model`` (the inner ``Qwen3Model``) is invoked at + export time — embedding lookup and ``lm_head`` stay out of the graph. + """ + + def __init__(self, model: nn.Module, num_layers: int) -> None: + super().__init__() + self.model = model + self.num_layers = num_layers + self.config = model.config + apply_transformer_only_export_prep(model, matmul_to_conv=True) + + @classmethod + def from_pretrained(cls, model_name_or_path: str, **kwargs: Any) -> QwenTransformerOnlyDecoderWrapper: + kwargs.setdefault("torch_dtype", torch.float32) + model = AutoModelForCausalLM.from_pretrained(model_name_or_path, **kwargs) + model.config._attn_implementation = "eager" + wrapper = cls(model, model.config.num_hidden_layers) + wrapper.eval() + return wrapper + + def get_export_args(self, inputs: dict[str, torch.Tensor]) -> tuple[torch.Tensor, ...]: + return tuple(inputs.values()) + + def forward(self, *args: torch.Tensor) -> tuple[torch.Tensor, ...]: + """Positional inputs (order matches OnnxConfig.inputs): + + past_keys_0, past_values_0, ..., past_keys_{L-1}, past_values_{L-1}, + input_hidden_states, past_seq_len, total_seq_len + + Returns ``(output_hidden_states, present_keys_0, present_values_0, ...)``. + """ + kv_args = args[: 2 * self.num_layers] + input_hidden_states = args[2 * self.num_layers] + past_seq_len = args[2 * self.num_layers + 1] + total_seq_len = args[2 * self.num_layers + 2] + + past_key_values = [ + (kv_args[2 * i], kv_args[2 * i + 1]) for i in range(self.num_layers) + ] + + hidden_states, present_kvs = self.model.model( + inputs_embeds=input_hidden_states, + past_key_values=past_key_values, + past_seq_len=past_seq_len, + total_seq_len=total_seq_len, + use_cache=True, + ) + + out: list[torch.Tensor] = [hidden_states] + for k, v in present_kvs: + out.extend([k, v]) + return tuple(out) + + +# ============================================================================= +# Dummy input generators (transformer-only I/O) +# ============================================================================= + + +class _TransformerOnlyHiddenStateGenerator(DummyInputGenerator): + """Generates ``input_hidden_states`` (FP32, ``[1, seq_len, hidden]``).""" + + SUPPORTED_INPUT_NAMES = ("input_hidden_states",) + + _default_seq_len: ClassVar[int] = 1 + + def __init__( + self, + task: str, + normalized_config: Any, + batch_size: int = 1, + seq_len: int | None = None, + **kwargs: Any, + ) -> None: + self.batch_size = batch_size + self.hidden_size = normalized_config.hidden_size + self.seq_len = seq_len or getattr(normalized_config, "seq_len", self._default_seq_len) + + def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32") -> torch.Tensor: # noqa: ARG002 + if input_name == "input_hidden_states": + return torch.randn(self.batch_size, self.seq_len, self.hidden_size, dtype=torch.float32) + raise ValueError(f"Unknown input: {input_name}") + + +class _TransformerOnlyHiddenStatePrefillGenerator(_TransformerOnlyHiddenStateGenerator): + _default_seq_len = 64 + + +class _TransformerOnlySeqLenGenerator(DummyInputGenerator): + """Generates ``past_seq_len`` (INT32 ``[1,1]``) and ``total_seq_len`` (INT32 ``[1]``).""" + + SUPPORTED_INPUT_NAMES = ("past_seq_len", "total_seq_len") + + def __init__(self, task: str, normalized_config: Any, **kwargs: Any) -> None: # noqa: ARG002 + self.max_cache_len = normalized_config.max_cache_len + + def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32") -> torch.Tensor: # noqa: ARG002 + if input_name == "past_seq_len": + return torch.zeros((1, 1), dtype=torch.int32) + if input_name == "total_seq_len": + return torch.tensor([self.max_cache_len], dtype=torch.int32) + raise ValueError(f"Unknown input: {input_name}") + + +class _TransformerOnlyKvCacheGenerator(DummyInputGenerator): + """Generates ``past_keys_{i}`` / ``past_values_{i}`` (FP16).""" + + SUPPORTED_INPUT_NAMES = () # built dynamically in __init__ + + def __init__( + self, + task: str, + normalized_config: NormalizedConfig, + batch_size: int = 1, + max_cache_len: int | None = None, + **kwargs: Any, + ) -> None: + self.batch_size = batch_size + self.num_layers: int = normalized_config.num_layers + self.num_heads: int = normalized_config.num_attention_heads # KV heads (NormalizedConfig maps it) + self.head_dim: int = normalized_config.head_dim + self.max_cache_len: int = max_cache_len or normalized_config.max_cache_len + self.SUPPORTED_INPUT_NAMES = tuple( + name for i in range(self.num_layers) for name in (f"past_keys_{i}", f"past_values_{i}") + ) + + def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32") -> torch.Tensor: # noqa: ARG002 + shape = (self.batch_size, self.num_heads, self.max_cache_len, self.head_dim) + return torch.zeros(shape, dtype=torch.float16) + + +# ============================================================================= +# OnnxConfigs — transformer-only I/O layout +# ============================================================================= + + +_QWEN_TRANSFORMER_ONLY_NORMALIZED = NormalizedConfig.with_args( + hidden_size="hidden_size", + num_layers="num_hidden_layers", + num_attention_heads="num_key_value_heads", # KV heads (GQA) + head_dim="head_dim", + max_cache_len="max_position_embeddings", + vocab_size="vocab_size", + allow_new=True, +) + + +def _transformer_only_inputs(num_layers: int, kv_seq_axis: str = "max_seq_len") -> dict[str, dict[int, str]]: + """Input ordering: past KV pairs, then hidden states, then seq lens.""" + result: dict[str, dict[int, str]] = {} + for i in range(num_layers): + result[f"past_keys_{i}"] = {2: kv_seq_axis} + result[f"past_values_{i}"] = {2: kv_seq_axis} + result["input_hidden_states"] = {1: "seq_len"} + result["past_seq_len"] = {} + result["total_seq_len"] = {} + return result + + +def _transformer_only_outputs(num_layers: int, kv_seq_axis: str = "max_seq_len") -> dict[str, dict[int, str]]: + result: dict[str, dict[int, str]] = {"output_hidden_states": {1: "seq_len"}} + for i in range(num_layers): + result[f"present_keys_{i}"] = {2: kv_seq_axis} + result[f"present_values_{i}"] = {2: kv_seq_axis} + return result + + +class QwenTransformerOnlyPrefillIOConfig(OnnxConfig): + """Prefill (seq=64) — transformer-only I/O.""" + + NORMALIZED_CONFIG_CLASS = _QWEN_TRANSFORMER_ONLY_NORMALIZED + DUMMY_INPUT_GENERATOR_CLASSES = ( + _TransformerOnlyKvCacheGenerator, + _TransformerOnlyHiddenStatePrefillGenerator, + _TransformerOnlySeqLenGenerator, + ) + + @property + def inputs(self) -> dict[str, dict[int, str]]: + return _transformer_only_inputs(self._normalized_config.num_layers) + + @property + def outputs(self) -> dict[str, dict[int, str]]: + return _transformer_only_outputs(self._normalized_config.num_layers) + + +class QwenTransformerOnlyGenIOConfig(OnnxConfig): + """Generation (seq=1) — transformer-only I/O.""" + + NORMALIZED_CONFIG_CLASS = _QWEN_TRANSFORMER_ONLY_NORMALIZED + DUMMY_INPUT_GENERATOR_CLASSES = ( + _TransformerOnlyKvCacheGenerator, + _TransformerOnlyHiddenStateGenerator, + _TransformerOnlySeqLenGenerator, + ) + + @property + def inputs(self) -> dict[str, dict[int, str]]: + return _transformer_only_inputs(self._normalized_config.num_layers) + + @property + def outputs(self) -> dict[str, dict[int, str]]: + return _transformer_only_outputs(self._normalized_config.num_layers) + + +# ============================================================================= +# Build config — TorchScript exporter required for the custom autograd ops +# ============================================================================= + + +QWEN_TRANSFORMER_ONLY_CONFIG = WinMLBuildConfig( + export=WinMLExportConfig(dynamo=False, opset_version=18), + # Pure graph (no post-export RMSNorm fusion / matmul-add fusion). + optim=None, +) + + +# ============================================================================= +# Composite inference wrapper (placeholder so the build pipeline finds a +# composite class — generation isn't yet wired for the transformer-only +# I/O signature). +# ============================================================================= + + +class WinMLQwen3TransformerOnlyModel(WinMLDecoderOnlyModel): + """Composite handle for the transformer-only Qwen3 build (export only). + + ``generate()`` is **not** functional with this build path — the inference + feeds and KV update logic still target the eager I/O signature. Use the + eager :class:`WinMLQwen3Model` for generation; use this class to produce + the transformer-only ONNX for downstream quantization. + """ + + _SUB_MODEL_CONFIG: ClassVar[dict[str, str]] = { + "decoder_prefill": "feature-extraction", + "decoder_gen": "text2text-generation", + } + + @classmethod + def get_cache_class(cls) -> type: + return WinMLSlidingWindowCache + + +# ============================================================================= +# install() — hot-patch the registries +# ============================================================================= + + +_INSTALLED = False + + +def install() -> None: + """Replace the qwen3 entries in WinML registries with the transformer-only variants. + + Idempotent. After this call, building any qwen3 model via + :class:`~winml.modelkit.models.winml.composite_model.WinMLCompositeModel` + or :class:`~winml.modelkit.models.auto.WinMLAutoModel` produces + transformer-only ONNX files. + """ + global _INSTALLED + if _INSTALLED: + return + + # 1) Per-model build config + wrapper-class lookup live on the parent + # ``models.hf`` package as module-level dicts; mutating them is the + # documented hook for adding/overriding a model_type. + from .. import hf as _hf_pkg # noqa: PLC0415 + + _hf_pkg.MODEL_BUILD_CONFIGS["qwen3"] = QWEN_TRANSFORMER_ONLY_CONFIG + _hf_pkg.MODEL_CLASS_MAPPING[("qwen3", "feature-extraction")] = QwenTransformerOnlyDecoderWrapper + _hf_pkg.MODEL_CLASS_MAPPING[("qwen3", "text2text-generation")] = QwenTransformerOnlyDecoderWrapper + + # 2) Optimum OnnxConfig (overwrites existing registration). + register_onnx_overwrite("qwen3", "feature-extraction", library_name="transformers")(QwenTransformerOnlyPrefillIOConfig) + register_onnx_overwrite("qwen3", "text2text-generation", library_name="transformers")(QwenTransformerOnlyGenIOConfig) + + # 3) Inference specialization (still GenericTask — wrapper returns raw KV). + register_specialization("qwen3", "feature-extraction", "WinMLModelForGenericTask") + register_specialization("qwen3", "text2text-generation", "WinMLModelForGenericTask") + + # 4) Composite registry — swap to the transformer-only handle. + from ..winml.composite_model import COMPOSITE_MODEL_REGISTRY + + COMPOSITE_MODEL_REGISTRY[("qwen3", "text-generation")] = WinMLQwen3TransformerOnlyModel + + _INSTALLED = True + logger.info("qwen_transformer_only: transformer-only export path installed for qwen3.") + + +__all__ = [ + "QWEN_TRANSFORMER_ONLY_CONFIG", + "QwenTransformerOnlyDecoderWrapper", + "QwenTransformerOnlyGenIOConfig", + "QwenTransformerOnlyPrefillIOConfig", + "WinMLQwen3TransformerOnlyModel", + "install", +] diff --git a/src/winml/modelkit/onnx/__init__.py b/src/winml/modelkit/onnx/__init__.py index 0287a2ff7..a3bc49d51 100644 --- a/src/winml/modelkit/onnx/__init__.py +++ b/src/winml/modelkit/onnx/__init__.py @@ -19,7 +19,6 @@ from .io import InputTensorSpec, OutputTensorSpec, generate_inputs_from_onnx, get_io_config from .metadata import capture_metadata, restore_metadata from .persistence import cleanup_onnx, load_onnx, save_onnx -from .qwen_surgery import make_transformer_only from .shape import infer_onnx_shapes, infer_shapes from .utils import EXTERNAL_DATA_THRESHOLD, check_onnx_model, get_model_size @@ -42,7 +41,6 @@ "is_compiled_onnx", "is_quantized_onnx", "load_onnx", - "make_transformer_only", "remove_optional_from_type_annotation", "restore_metadata", "save_onnx", diff --git a/src/winml/modelkit/onnx/qwen_surgery.py b/src/winml/modelkit/onnx/qwen_surgery.py deleted file mode 100644 index cd49ee5ec..000000000 --- a/src/winml/modelkit/onnx/qwen_surgery.py +++ /dev/null @@ -1,186 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. -# -------------------------------------------------------------------------- -"""Ad-hoc ONNX surgery to turn a Qwen3 decoder ONNX into a transformer-only graph. - -Applied as a post-export surgery on the fused decoder ONNX produced by -``WinMLQwen3Model`` (``decoder_prefill.onnx`` / ``decoder_gen.onnx``). - -The resulting transformer-only ONNX has: - - ``input_ids`` graph input replaced by ``inputs_embeds`` (FLOAT, - ``[batch, seq, hidden_size]``) — the upstream embedding Gather is - removed. - - ``logits`` graph output replaced by ``output_hidden_states`` - (FLOAT, ``[batch, seq, hidden_size]``) — the final ``lm_head`` MatMul - is removed. -""" - -from __future__ import annotations - -import logging -from pathlib import Path - -import onnx -from onnx import TensorProto, helper - -from .persistence import load_onnx, save_onnx - - -logger = logging.getLogger(__name__) - - -def _dim(d: onnx.TensorShapeProto.Dimension) -> int | str: - if d.HasField("dim_value"): - return d.dim_value - return d.dim_param or "?" - - -def make_transformer_only( - model_path: str | Path, - output_path: str | Path, - *, - input_ids_name: str = "input_ids", - logits_name: str = "logits", - inputs_embeds_name: str = "inputs_embeds", - output_hidden_states_name: str = "output_hidden_states", -) -> Path: - """Strip the embedding Gather and the lm_head MatMul from a Qwen3 ONNX. - - Args: - model_path: Path to the fused decoder ONNX (logits output, input_ids input). - output_path: Destination for the transformer-only ONNX. - input_ids_name: Name of the input_ids graph input to drop. - logits_name: Name of the logits graph output to drop. - inputs_embeds_name: Display name for the new embeddings input - (used only for logging; the actual tensor keeps its existing - internal name so downstream nodes need no rewiring). - output_hidden_states_name: Display name for the new hidden-state output. - - Returns: - The output path. - """ - model_path = Path(model_path) - output_path = Path(output_path) - output_path.parent.mkdir(parents=True, exist_ok=True) - - model = load_onnx(model_path, load_weights=True, validate=False) - graph = model.graph - init_by_name = {init.name: init for init in graph.initializer} - - # -------------------- Embedding removal -------------------- - embed_idx = next( - (i for i, n in enumerate(graph.node) if input_ids_name in n.input), - None, - ) - if embed_idx is None: - msg = f"No node consumes graph input {input_ids_name!r}" - raise RuntimeError(msg) - - embed_node = graph.node[embed_idx] - embed_out_name = embed_node.output[0] - - embed_weight = None - for ipt in embed_node.input: - init = init_by_name.get(ipt) - if init is not None and len(init.dims) == 2: - embed_weight = init - break - if embed_weight is None: - msg = f"Could not find 2-D embedding weight initializer on node {embed_node.name!r}" - raise RuntimeError(msg) - hidden_size = int(embed_weight.dims[1]) - - ids_input = next(i for i in graph.input if i.name == input_ids_name) - batch_dim = _dim(ids_input.type.tensor_type.shape.dim[0]) - seq_dim = _dim(ids_input.type.tensor_type.shape.dim[1]) - - logger.info( - "Removing embedding node %r (%s) — exposing %r as new input %r [%s, %s, %d]", - embed_node.name, - embed_node.op_type, - embed_out_name, - inputs_embeds_name, - batch_dim, - seq_dim, - hidden_size, - ) - - new_embed_input = helper.make_tensor_value_info( - inputs_embeds_name, - TensorProto.FLOAT, - [batch_dim, seq_dim, hidden_size], - ) - - del graph.node[embed_idx] - graph.input.remove(ids_input) - graph.input.append(new_embed_input) - graph.initializer.remove(embed_weight) - - # Rewire any consumer of the removed embedding output to the new input. - for n in graph.node: - for i, name in enumerate(n.input): - if name == embed_out_name: - n.input[i] = inputs_embeds_name - - # -------------------- lm_head removal -------------------- - lmh_idx = next( - (i for i, n in enumerate(graph.node) if logits_name in n.output), - None, - ) - if lmh_idx is None: - msg = f"No node produces graph output {logits_name!r}" - raise RuntimeError(msg) - - lmh_node = graph.node[lmh_idx] - init_names = {init.name for init in graph.initializer} - hidden_in: str | None = None - weight_in: str | None = None - for ipt in lmh_node.input: - if ipt in init_names: - weight_in = ipt - else: - hidden_in = ipt - if hidden_in is None: - msg = f"lm_head node {lmh_node.name!r} has no non-initializer input ({list(lmh_node.input)})" - raise RuntimeError(msg) - - logger.info( - "Removing lm_head node %r (%s) — exposing %r as new output %r", - lmh_node.name, - lmh_node.op_type, - hidden_in, - output_hidden_states_name, - ) - - logits_output = next(o for o in graph.output if o.name == logits_name) - new_hidden_output = helper.make_tensor_value_info( - output_hidden_states_name, - TensorProto.FLOAT, - [batch_dim, seq_dim, hidden_size], - ) - - del graph.node[lmh_idx] - graph.output.remove(logits_output) - # Put hidden states first so it mirrors the original logits position. - graph.output.insert(0, new_hidden_output) - - # Rename the producer of ``hidden_in`` to emit the new graph output name. - for n in graph.node: - for i, name in enumerate(n.output): - if name == hidden_in: - n.output[i] = output_hidden_states_name - for i, name in enumerate(n.input): - if name == hidden_in: - n.input[i] = output_hidden_states_name - - if weight_in is not None and not any(weight_in in n.input for n in graph.node): - wi = next(init for init in graph.initializer if init.name == weight_in) - graph.initializer.remove(wi) - - save_onnx(model, output_path) - logger.info("Wrote transformer-only ONNX → %s", output_path) - return output_path - - -__all__ = ["make_transformer_only"] diff --git a/test_qwen 2.py b/test_qwen 2.py deleted file mode 100644 index 6a52dee72..000000000 --- a/test_qwen 2.py +++ /dev/null @@ -1,70 +0,0 @@ -"""E2E test for Qwen3 decoder-only pipeline. - -Uses sub_model_kwargs to set per-component shape_config: - - decoder_prefill: max_cache_len=256, seq_len=64 - - decoder_gen: max_cache_len=256, seq_len=1 - -Set env var ``QUANTIZE=1`` to also run the MOPS-style Step 3: -transformer-only surgery + winml quantize on both sub-models -(embeddings and lm_head are stripped and not quantized). -""" - -import os - -from transformers import AutoTokenizer - -from winml.modelkit.config import WinMLBuildConfig -from winml.modelkit.models.winml.composite_model import WinMLCompositeModel - -model_id = "Qwen/Qwen3-0.6B" - -model = WinMLCompositeModel.from_pretrained( - model_id, - task="text-generation", - # config=WinMLBuildConfig(quant=None, compile=None), - config=WinMLBuildConfig(quant=None), - precision="fp16", - device="npu", - ep="qnn", - force_rebuild=False, - sub_model_kwargs={ - "decoder_prefill": {"shape_config": {"max_cache_len": 256, "seq_len": 64}}, - "decoder_gen": {"shape_config": {"max_cache_len": 256, "seq_len": 1}}, - }, -) - -# Verify ONNX I/O shapes -for name, sub in model.sub_models.items(): - io = sub.io_config - shapes = dict(zip(io["input_names"], io["input_shapes"])) - print(f"\n=== {name} ===") - for k, v in shapes.items(): - print(f" {k}: {v}") - -tokenizer = AutoTokenizer.from_pretrained(model_id) - -prompt = "8 * 7 = ?" -messages = [{"role": "user", "content": prompt}] -text = tokenizer.apply_chat_template( - messages, tokenize=False, add_generation_prompt=True, enable_thinking=False, -) -model_inputs = tokenizer([text], return_tensors="pt").to(model.device) - -generated_ids = model.generate(**model_inputs) - -output_ids = generated_ids[0][len(model_inputs.input_ids[0]) :].tolist() -content = tokenizer.decode(output_ids, skip_special_tokens=True) -print("\nAnswer:", content) - -if os.environ.get("QUANTIZE") == "1": - # Reuse the already-built decoder_prefill/decoder_gen ONNX files: - # surgery (strip embed + lm_head) + transformer-only quantize. - print("\n=== QUANTIZE=1 — running transformer-only quantization ===") - from qwen3_quantize import quantize_built_model - - quantize_built_model( - model, - model_id=model_id, - max_cache_len=256, - prefill_seq=64, - ) diff --git a/test_qwen.py b/test_qwen.py new file mode 100644 index 000000000..f958c2932 --- /dev/null +++ b/test_qwen.py @@ -0,0 +1,235 @@ +"""E2E test for the transformer-only Qwen3 export path. + +Produces two transformer-only ONNX files whose I/O matches +``qwen3_gqa_fp16_ctx.onnx`` / ``qwen3_gqa_fp16_iter.onnx``: + + decoder_prefill: input_hidden_states [1, 64, 1024] → output_hidden_states + KV + decoder_gen : input_hidden_states [1, 1, 1024] → output_hidden_states + KV + +with FP16 past/present KV named ``past_keys_{i}`` / ``past_values_{i}``, +``com.microsoft::GroupQueryAttention``, ``LpNormalization``, and 1x1 Conv +projections. + +Important: ``install()`` MUST be called before importing the composite model +machinery so the registry hot-patches take effect. + +Generation (``model.generate(...)``) is NOT supported by this build path — +the inference feeds in ``WinMLDecoderOnlyModel`` still target the eager +I/O signature. Use the eager ``WinMLQwen3Model`` build path for end-to-end +generation. + +Run:: + + python test_qwen_transformer_only.py + +This builds each transformer sub-model and then runs the w8a16 +quantization on the exported transformer ONNX files (no surgery needed — +files are already transformer-only). +""" + +import os +import sys +import pathlib +import subprocess + +# Put the in-repo `src/` ahead of site-packages so `import winml` always +# resolves to the editable source tree — no manual copy-to-venv needed. +_repo_root = pathlib.Path(__file__).resolve().parent +sys.path.insert(0, str(_repo_root / "src")) +sys.path.insert(0, str(_repo_root)) + +model_id = "Qwen/Qwen3-0.6B" +MAX_CACHE = 256 + +# component name -> (HF task, seq_len, artifact prefix). Order matters +# (prefill first). The prefix is how the built npu_ctx file is named so the +# parent can verify success by artifact appearance (the build segfaults on +# native QNN/ORT teardown AFTER writing the file, so exit codes are unreliable). +SUB_MODELS = { + "decoder_prefill": ("feature-extraction", 64, "feat_"), + "decoder_gen": ("text2text-generation", 1, "txt2txt_"), +} + +ARTIFACTS_DIR = ( + pathlib.Path.home() / ".cache" / "winml" / "artifacts" / model_id.replace("/", "_") +) + + +def _latest_ctx_mtime(prefix: str) -> float: + """Newest mtime of a ``{prefix}*_optimized_npu_ctx.onnx`` artifact, or 0.""" + files = list(ARTIFACTS_DIR.glob(f"{prefix}*_optimized_npu_ctx.onnx")) + return max((f.stat().st_mtime for f in files), default=0.0) + + +def _build_one(task: str, seq_len: int) -> None: + """Build a SINGLE transformer sub-model in this (fresh) process. + + Invoked as a subprocess by ``main()`` so each sub-model exports in a + clean interpreter — building both in one process leaves PyTorch/ORT + state from the first build that corrupts/kills the second. + """ + from winml.modelkit.models.hf.qwen_transformer_only import install as install_qwen_transformer_only + + install_qwen_transformer_only() + + from winml.modelkit.config import WinMLBuildConfig + from winml.modelkit.models.auto import WinMLAutoModel + + WinMLAutoModel.from_pretrained( + model_id, + task=task, + config=WinMLBuildConfig(quant=None, compile=None), + precision="fp16", + device="npu", + ep="qnn", + force_rebuild=True, + shape_config={"max_cache_len": MAX_CACHE, "seq_len": seq_len}, + ) + # The QNN/ORT teardown segfaults (0xC0000005) on interpreter shutdown + # AFTER the artifact is fully written. Skip the buggy cleanup with a hard + # exit so the parent sees a clean exit code 0. + print(f"BUILD COMPLETE: task={task} seq_len={seq_len}", flush=True) + sys.stdout.flush() + sys.stderr.flush() + os._exit(0) + + +def _find_optimized(prefix: str) -> pathlib.Path: + """Locate the cached transformer-only ``{prefix}*_optimized.onnx`` file.""" + cands = [ + p for p in ARTIFACTS_DIR.glob(f"{prefix}*_optimized.onnx") + if not p.name.endswith("_optimized_npu_ctx.onnx") + ] + if not cands: + raise FileNotFoundError( + f"No {prefix}*_optimized.onnx in {ARTIFACTS_DIR} — build the sub-model first." + ) + return max(cands, key=lambda p: p.stat().st_mtime) + + +class _SubShim: + """Minimal stand-in exposing the ``_onnx_path`` quant needs.""" + + def __init__(self, onnx_path: pathlib.Path): + self._onnx_path = str(onnx_path) + + +class _ModelShim: + """Minimal stand-in exposing ``sub_models`` for ``quantize_built_model``.""" + + def __init__(self, sub_models: dict): + self.sub_models = sub_models + + +def _run_quant() -> None: + """Quantize the cached transformer ONNX files (no composite/QNN load). + + Runs as its own subprocess so any ORT teardown crash can't poison the + parent. Builds a shim ``model`` whose ``sub_models[name]._onnx_path`` + point straight at the cached ``*_optimized.onnx`` files. + """ + # Dump a native C-stack if the calibration InferenceSession segfaults + # (otherwise the crash is silent — no Python traceback). + import faulthandler + faulthandler.enable() + + from qwen3_transformer_only_quantize import quantize_built_model + + sub_models = { + name: _SubShim(_find_optimized(prefix)) + for name, (_task, _seq, prefix) in SUB_MODELS.items() + } + model = _ModelShim(sub_models) + print("=== Running transformer w8a16 quantization ===", flush=True) + for name, sub in sub_models.items(): + print(f" {name}: {sub._onnx_path}", flush=True) + + try: + quantize_built_model( + model, + model_id=model_id, + max_cache_len=MAX_CACHE, + prefill_seq=64, + ) + except BaseException: + import traceback + print("QUANT FAILED with exception:", flush=True) + traceback.print_exc() + sys.stdout.flush() + sys.stderr.flush() + raise + print("QUANT COMPLETE", flush=True) + sys.stdout.flush() + sys.stderr.flush() + os._exit(0) + + +def main() -> None: + # 1) Build each sub-model in its OWN subprocess (fresh state each time). + # Judge success by whether a FRESH npu_ctx artifact appeared, NOT by the + # subprocess exit code: the native QNN/ORT layer segfaults (0xC0000005) + # on teardown AFTER the artifact is fully written to disk. + import time as _time + + for name, (task, seq_len, prefix) in SUB_MODELS.items(): + print(f"\n########## BUILD {name} (task={task}, seq_len={seq_len}) ##########", flush=True) + before = _latest_ctx_mtime(prefix) + start = _time.time() + rc = subprocess.run( + [sys.executable, "-u", str(pathlib.Path(__file__).resolve()), + "--build-sub", task, str(seq_len)], + cwd=str(_repo_root), + ).returncode + + after = _latest_ctx_mtime(prefix) + if after > before and after >= start - 1: + status = "OK" if rc == 0 else f"OK (ignored teardown exit {rc})" + print(f"########## {name} {status}: fresh {prefix}*_optimized_npu_ctx.onnx ##########", flush=True) + else: + raise SystemExit( + f"Sub-model build failed for {name} (exit {rc}) — " + f"no fresh {prefix}*_optimized_npu_ctx.onnx in {ARTIFACTS_DIR}" + ) + + # 2) Report the built transformer-only ONNX files (no composite/QNN load — + # that creates QNN EP sessions that segfault the parent on teardown). + for name, (_task, _seq, prefix) in SUB_MODELS.items(): + print(f"\n=== {name} ===") + print(f" optimized : {_find_optimized(prefix).name}") + ctx = sorted(ARTIFACTS_DIR.glob(f"{prefix}*_optimized_npu_ctx.onnx")) + if ctx: + print(f" npu_ctx : {ctx[-1].name}") + + # 3) Quantization — run in its OWN subprocess for the same teardown-crash + # isolation. Judge by whether quant files appeared. + print("\n########## QUANTIZE ##########", flush=True) + before = max( + (p.stat().st_mtime for p in ARTIFACTS_DIR.glob("*quant.onnx")), + default=0.0, + ) + qstart = _time.time() + rc = subprocess.run( + [sys.executable, "-u", str(pathlib.Path(__file__).resolve()), "--quant"], + cwd=str(_repo_root), + ).returncode + after_files = list(ARTIFACTS_DIR.glob("*quant.onnx")) + after = max((p.stat().st_mtime for p in after_files), default=0.0) + if after > before and after >= qstart - 1: + status = "OK" if rc == 0 else f"OK (ignored teardown exit {rc})" + print(f"########## QUANTIZE {status} ##########", flush=True) + for p in sorted(after_files, key=lambda x: x.stat().st_mtime)[-len(SUB_MODELS):]: + print(f" {p.name}", flush=True) + else: + raise SystemExit( + f"Quantization failed (exit {rc}) — no fresh *quant.onnx in {ARTIFACTS_DIR}" + ) + + +if __name__ == "__main__": + if len(sys.argv) >= 4 and sys.argv[1] == "--build-sub": + _build_one(sys.argv[2], int(sys.argv[3])) + elif len(sys.argv) >= 2 and sys.argv[1] == "--quant": + _run_quant() + else: + main() + From 78815fd97d7458edc745185d65c8aefdb7b82d67 Mon Sep 17 00:00:00 2001 From: spalne Date: Mon, 22 Jun 2026 10:44:41 -0700 Subject: [PATCH 03/17] Fix Qwen3 w8a16 quant: symmetric int8 weights + exclude GQA from QDQ --- qwen3_transformer_only_quantize.py | 33 ++++++++++++++++++++++++++- src/winml/modelkit/quant/config.py | 9 ++++++++ src/winml/modelkit/quant/quantizer.py | 19 ++++++++++++--- 3 files changed, 57 insertions(+), 4 deletions(-) diff --git a/qwen3_transformer_only_quantize.py b/qwen3_transformer_only_quantize.py index 8b4efa9b7..3ae895ae2 100644 --- a/qwen3_transformer_only_quantize.py +++ b/qwen3_transformer_only_quantize.py @@ -133,6 +133,23 @@ def _tokenize_prompts(tokenizer: Any, prompts: list[str], num_samples: int) -> l return out +def _gqa_node_names(onnx_path: Path) -> list[str]: + """Return the names of every GroupQueryAttention node in ``onnx_path``. + + These nodes are excluded from quantization so ORT leaves both their + inputs and output in float (``... -> Cast -> GQA -> Cast``), matching + the reference graph which keeps attention entirely out of QDQ. + """ + import onnx + + model = onnx.load(str(onnx_path), load_external_data=False) + return [ + n.name + for n in model.graph.node + if n.op_type == "GroupQueryAttention" and n.name + ] + + def quantize_built_model( model: WinMLCompositeModel, *, @@ -140,7 +157,7 @@ def quantize_built_model( max_cache_len: int = DEFAULT_MAX_CACHE, prefill_seq: int = DEFAULT_PREFILL_SEQ, num_samples: int = DEFAULT_NUM_SAMPLES, - weight_type: str = "uint8", + weight_type: str = "int8", activation_type: str = "uint16", ) -> dict[str, Path]: """Quantize the transformer-only ONNX files in-place. @@ -200,6 +217,11 @@ def quantize_built_model( print(f"\n=== Quantize (transformer-only): {sub_name} (seq_len={seq_len}) ===") print(f" in : {fused_path}") print(f" out: {quant_path}") + gqa_nodes = _gqa_node_names(fused_path) + print( + f" excluding {len(gqa_nodes)} GroupQueryAttention nodes from " + "quantization (inputs + output stay float, Cast -> GQA -> Cast)" + ) reader = Qwen3TransformerOnlyCalibReader( embed_tokens, hf_model.config, @@ -213,6 +235,15 @@ def quantize_built_model( activation_type=activation_type, # type: ignore[arg-type] calibration_method="minmax", calibration_data=reader, + # w8a16: symmetric int8 weights (zp=0) + asymmetric uint16 + # activations, matching the reference quantization. + weight_symmetric=True, + activation_symmetric=False, + # ORT treats GroupQueryAttention as quantizable and wraps both its + # inputs and output in QDQ. The reference keeps attention entirely + # in float (Cast -> GQA -> Cast), so exclude the GQA nodes from + # quantization so no QDQ is inserted around them. + nodes_to_exclude=gqa_nodes, ) result = quantize_onnx(fused_path, output_path=quant_path, config=cfg) if not result.success: diff --git a/src/winml/modelkit/quant/config.py b/src/winml/modelkit/quant/config.py index b9709cc0e..6132be599 100644 --- a/src/winml/modelkit/quant/config.py +++ b/src/winml/modelkit/quant/config.py @@ -68,6 +68,11 @@ class WinMLQuantizationConfig: # Quantization options per_channel: bool = False symmetric: bool = False + # Optional per-target symmetry overrides. When None, fall back to + # ``symmetric``. Lets w8a16 use symmetric weights (int8, zp=0) together + # with asymmetric activations (uint16). + weight_symmetric: bool | None = None + activation_symmetric: bool | None = None # Output settings save_calibration: bool = False @@ -98,6 +103,8 @@ def to_dict(self) -> dict: "activation_type": self.activation_type, "per_channel": self.per_channel, "symmetric": self.symmetric, + "weight_symmetric": self.weight_symmetric, + "activation_symmetric": self.activation_symmetric, "save_calibration": self.save_calibration, "distribution": self.distribution, "seed": self.seed, @@ -139,6 +146,8 @@ def from_dict(cls, data: dict) -> WinMLQuantizationConfig: activation_type=data.get("activation_type", "uint8"), per_channel=data.get("per_channel", False), symmetric=data.get("symmetric", False), + weight_symmetric=data.get("weight_symmetric"), + activation_symmetric=data.get("activation_symmetric"), save_calibration=data.get("save_calibration", False), distribution=data.get("distribution", "uniform"), seed=data.get("seed"), diff --git a/src/winml/modelkit/quant/quantizer.py b/src/winml/modelkit/quant/quantizer.py index c562599de..e5fd30df3 100644 --- a/src/winml/modelkit/quant/quantizer.py +++ b/src/winml/modelkit/quant/quantizer.py @@ -132,10 +132,23 @@ def quantize_onnx( activation_type = activation_type_map[config.activation_type] calibrate_method = calibration_method_map[config.calibration_method] - # Build extra options + # Build extra options. Weight/activation symmetry can be controlled + # independently (e.g. w8a16 = symmetric int8 weights + asymmetric + # uint16 activations); fall back to the single ``symmetric`` flag when + # the per-target override is unset. + weight_symmetric = ( + config.weight_symmetric + if config.weight_symmetric is not None + else config.symmetric + ) + activation_symmetric = ( + config.activation_symmetric + if config.activation_symmetric is not None + else config.symmetric + ) extra_options = { - "ActivationSymmetric": config.symmetric, - "WeightSymmetric": config.symmetric, + "ActivationSymmetric": activation_symmetric, + "WeightSymmetric": weight_symmetric, } # Step 1: Generate QDQ config From 95d45d9ad9a9baab2576e2b88d7c3999a60ca3f4 Mon Sep 17 00:00:00 2001 From: spalne Date: Mon, 22 Jun 2026 14:51:23 -0700 Subject: [PATCH 04/17] refactor(qwen): register transformer-only path as a declarative model_type variant --- qwen3_transformer_only_quantize.py | 7 +- src/winml/modelkit/build/hf.py | 4 + src/winml/modelkit/loader/config.py | 13 +++ src/winml/modelkit/loader/hf.py | 13 +++ src/winml/modelkit/models/auto.py | 16 +++- src/winml/modelkit/models/hf/__init__.py | 10 ++ .../models/hf/qwen_transformer_only.py | 93 +++++++++---------- test_qwen.py | 8 +- 8 files changed, 105 insertions(+), 59 deletions(-) diff --git a/qwen3_transformer_only_quantize.py b/qwen3_transformer_only_quantize.py index 3ae895ae2..0b90c8bd0 100644 --- a/qwen3_transformer_only_quantize.py +++ b/qwen3_transformer_only_quantize.py @@ -1,7 +1,7 @@ """Transformer-only w8a16 quantization for Qwen3. -Targets the transformer-only ONNX produced by -``qwen_transformer_only.install() + test_qwen.py``: +Targets the transformer-only ONNX produced by the +``qwen3_transformer_only`` build variant (see ``test_qwen.py``): - **No embedding/lm_head surgery.** The export already excludes both, so we feed ``WinMLQuantization`` the file directly. @@ -24,6 +24,7 @@ from winml.modelkit.models.winml.composite_model import WinMLCompositeModel from winml.modelkit.quant import WinMLQuantizationConfig, quantize_onnx +from winml.modelkit.quant.config import CalibrationDataReader logger = logging.getLogger(__name__) @@ -48,7 +49,7 @@ def _load_gsm8k_prompts(num_samples: int) -> list[str]: return [row["question"] for row in split.select(range(num_samples))] -class Qwen3TransformerOnlyCalibReader: +class Qwen3TransformerOnlyCalibReader(CalibrationDataReader): """Yields calibration feeds for the transformer-only ONNX. Feeds match the exported graph exactly: ``input_hidden_states`` (FP32), diff --git a/src/winml/modelkit/build/hf.py b/src/winml/modelkit/build/hf.py index 26356a6eb..dc2661afa 100644 --- a/src/winml/modelkit/build/hf.py +++ b/src/winml/modelkit/build/hf.py @@ -91,6 +91,7 @@ def build_hf_model( cache_key: str | None = None, ep: EPNameOrAlias | None = None, device: str | None = None, + model_type: str | None = None, **kwargs: Any, ) -> BuildResult: """Build an ONNX model from a HuggingFace model architecture. @@ -208,6 +209,7 @@ def _name(base: str) -> str: model_id, trust_remote_code, random_init=random_init, + model_type=model_type, ) # ========================================================================= @@ -436,6 +438,7 @@ def _load_model( trust_remote_code: bool, random_init: bool = False, hf_config: Any | None = None, + model_type: str | None = None, ) -> Any: """Load PyTorch model — pretrained or random weights. @@ -511,6 +514,7 @@ def _load_model( task=task, trust_remote_code=effective_trust, hf_config=hf_config, + model_type=model_type, ) return pytorch_model diff --git a/src/winml/modelkit/loader/config.py b/src/winml/modelkit/loader/config.py index cb6cb9af1..b533c1636 100644 --- a/src/winml/modelkit/loader/config.py +++ b/src/winml/modelkit/loader/config.py @@ -218,6 +218,19 @@ def resolve_loader_config( f"attribute. Cannot proceed with config generation." ) + # Explicit model_type override alongside a model_id: honor the requested + # type so downstream class / build-config / export resolution selects the + # variant (e.g. "qwen3_transformer_only") rather than the architecture's + # native type. The model_type-only path above (AutoConfig.for_model) is + # unaffected because it only runs when model_id is None. + if model_id is not None and model_type is not None and hf_config.model_type != model_type: + logger.info( + "Overriding resolved model_type '%s' -> '%s' (explicit request)", + hf_config.model_type, + model_type, + ) + hf_config.model_type = model_type + # 2. Infer task (depends on: model_type param or hf_config.architectures) if task is None and model_type is not None: supported = get_supported_tasks(model_type, library_name=library_name) diff --git a/src/winml/modelkit/loader/hf.py b/src/winml/modelkit/loader/hf.py index 5a90b5828..7c40c5fee 100644 --- a/src/winml/modelkit/loader/hf.py +++ b/src/winml/modelkit/loader/hf.py @@ -150,6 +150,7 @@ def load_hf_model( user_script: str | None = None, trust_remote_code: bool = False, hf_config: PretrainedConfig | None = None, + model_type: str | None = None, ) -> tuple[nn.Module, PretrainedConfig, str]: """Load, detect task, and prepare HuggingFace model. @@ -224,6 +225,18 @@ def load_hf_model( trust_remote_code=trust_remote_code, ) + # Explicit model_type override: select a registered build variant (e.g. + # "qwen3_transformer_only") rather than the architecture's native type. + # Mutates the freshly-loaded config only; gated on an explicit request so + # normal loading is unaffected. + if model_type is not None and getattr(hf_config, "model_type", None) != model_type: + logger.info( + "Overriding model_type '%s' -> '%s' (explicit request)", + getattr(hf_config, "model_type", None), + model_type, + ) + hf_config.model_type = model_type + # [2] Task & Model Class Resolution if user_script is not None: resolved_class = _load_class_from_script(user_script, model_class) diff --git a/src/winml/modelkit/models/auto.py b/src/winml/modelkit/models/auto.py index 78f944b36..4767b97db 100644 --- a/src/winml/modelkit/models/auto.py +++ b/src/winml/modelkit/models/auto.py @@ -247,6 +247,7 @@ def from_pretrained( trust_remote_code: bool = False, shape_config: dict | None = None, no_compile: bool = False, + model_type: str | None = None, **kwargs: Any, ) -> WinMLPreTrainedModel: """Load appropriate WinML model based on task detection. @@ -278,6 +279,10 @@ def from_pretrained( shape_config: Shape overrides passed to generate_build_config(). Valid keys -- text: sequence_length; vision: height, width; audio: feature_size, nb_max_frames, audio_sequence_length. + model_type: Explicit model_type override. When provided alongside a + HF model_id, selects a registered build variant (e.g. + ``"qwen3_transformer_only"``) instead of the architecture's + native model_type. Leave ``None`` for normal auto-detection. **kwargs: Additional arguments Returns: @@ -334,6 +339,11 @@ def from_pretrained( else: _model_type = None + # Explicit override wins so a variant composite (e.g. + # "qwen3_transformer_only") can be selected over the native type. + if model_type is not None: + _model_type = model_type + if _model_type is not None and (_model_type, task) in COMPOSITE_MODEL_REGISTRY: from .winml.composite_model import WinMLCompositeModel @@ -368,6 +378,7 @@ def from_pretrained( trust_remote_code=trust_remote_code, ep=kwargs.get("ep"), no_compile=no_compile, + model_type=model_type, ) resolved_task = build_config.loader.task @@ -402,7 +413,9 @@ def from_pretrained( from transformers import AutoConfig hf_config = AutoConfig.from_pretrained(model_id, trust_remote_code=effective_trust) - model_type = getattr(hf_config, "model_type", "unknown") + # Honor an explicit model_type override; otherwise probe from the config. + if model_type is None: + model_type = getattr(hf_config, "model_type", "unknown") logger.debug("Model type: %s, task: %s", model_type, resolved_task) # ===================================================================== @@ -431,6 +444,7 @@ def from_pretrained( cache_key=cache_key, ep=resolved_ep, device=device, + model_type=model_type, ) onnx_path = result.final_onnx_path diff --git a/src/winml/modelkit/models/hf/__init__.py b/src/winml/modelkit/models/hf/__init__.py index c6f4c9520..0d2e538a3 100644 --- a/src/winml/modelkit/models/hf/__init__.py +++ b/src/winml/modelkit/models/hf/__init__.py @@ -56,6 +56,14 @@ from .qwen import QWEN_CONFIG from .qwen import QwenGenIOConfig as _QwenGenIOConfig from .qwen import QwenPrefillIOConfig as _QwenPrefillIOConfig +from .qwen_transformer_only import MODEL_CLASS_MAPPING as _QWEN_TO_CLASS_MAPPING +from .qwen_transformer_only import QWEN_TRANSFORMER_ONLY_CONFIG +from .qwen_transformer_only import ( + QwenTransformerOnlyGenIOConfig as _QwenTransformerOnlyGenIOConfig, # triggers registration +) +from .qwen_transformer_only import ( + QwenTransformerOnlyPrefillIOConfig as _QwenTransformerOnlyPrefillIOConfig, # triggers registration +) from .roberta import ROBERTA_FAMILY_CONFIG from .roberta import RobertaIOConfig as _RobertaIOConfig # triggers registration from .sam import MODEL_CLASS_MAPPING as _SAM2_CLASS_MAPPING @@ -92,6 +100,7 @@ **_MARIAN_CLASS_MAPPING, **_MU2_CLASS_MAPPING, **_QWEN_CLASS_MAPPING, + **_QWEN_TO_CLASS_MAPPING, **_SAM2_CLASS_MAPPING, **_SEGFORMER_CLASS_MAPPING, **_SIGLIP_CLASS_MAPPING, @@ -115,6 +124,7 @@ "roberta": ROBERTA_FAMILY_CONFIG, "mu2": MU2_CONFIG, "qwen3": QWEN_CONFIG, + "qwen3-transformer-only": QWEN_TRANSFORMER_ONLY_CONFIG, "siglip": SIGLIP_CONFIG, "siglip-text-model": SIGLIP_CONFIG, "siglip-vision-model": SIGLIP_CONFIG, diff --git a/src/winml/modelkit/models/hf/qwen_transformer_only.py b/src/winml/modelkit/models/hf/qwen_transformer_only.py index 8e30b1fb6..614267df4 100644 --- a/src/winml/modelkit/models/hf/qwen_transformer_only.py +++ b/src/winml/modelkit/models/hf/qwen_transformer_only.py @@ -2,12 +2,17 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. # -------------------------------------------------------------------------- -"""Parallel ``qwen3`` build path that produces a transformer-only ONNX. +"""Transformer-only ``qwen3`` build variant, registered as a distinct model_type. -Opt-in via ``install()`` — calling it hot-patches the WinML registries so -that the next ``WinMLAutoModel.from_pretrained("Qwen/Qwen3-*", task="text-generation")`` -exports two transformer-only ONNX files (a prefill/context graph and an -iteration/decode graph) with this I/O: +This module registers a self-contained build path under the model_type +``"qwen3_transformer_only"`` (distinct from the stock ``"qwen3"`` path in +``qwen.py``). Selecting it is explicit — pass ``model_type="qwen3_transformer_only"`` +to ``WinMLAutoModel.from_pretrained(...)`` (or the underlying +``generate_hf_build_config(...)``). Both paths coexist; neither overrides the +other, and there is no import-ordering requirement. + +The variant exports two transformer-only ONNX files (a prefill/context graph +and an iteration/decode graph) with this I/O: Inputs : past_keys_{i}, past_values_{i} (FP16, ``[1, kv_heads, max_cache, head_dim]``), input_hidden_states (FP32, ``[1, seq_len, hidden]``), @@ -16,8 +21,9 @@ Ops : ``com.microsoft::GroupQueryAttention`` (do_rotary=1), ``onnx::LpNormalization`` (RMSNorm), 1x1 ``Conv`` projections. -The original eager-export path in ``qwen.py`` is left intact — only the -qwen3 entries in the registries are replaced. ``install()`` is idempotent. +Registration happens at import time via decorators and module-level mappings, +mirroring ``qwen.py``. The aggregating ``models.hf`` package imports this +module so the entries land in ``MODEL_CLASS_MAPPING`` / ``MODEL_BUILD_CONFIGS``. """ from __future__ import annotations @@ -36,6 +42,7 @@ from ...export import register_onnx_overwrite from ...export.config import WinMLExportConfig from ..winml import register_specialization +from ..winml.composite_model import register_composite_model from ..winml.decoder_only import WinMLDecoderOnlyModel from ..winml.kv_cache import WinMLSlidingWindowCache from .qwen3_export_ops import apply_transformer_only_export_prep @@ -43,6 +50,13 @@ logger = logging.getLogger(__name__) +# Distinct model_type for this variant. The underscore form is what the +# exporter sees on ``model.config.model_type`` and what Optimum's TasksManager +# and ``register_specialization`` are keyed on; the hyphenated form is used for +# the ``MODEL_CLASS_MAPPING`` / ``MODEL_BUILD_CONFIGS`` lookups (those callers +# normalize ``_`` -> ``-``). +TRANSFORMER_ONLY_MODEL_TYPE = "qwen3_transformer_only" + # ============================================================================= # Wrapper module @@ -65,6 +79,10 @@ def __init__(self, model: nn.Module, num_layers: int) -> None: self.num_layers = num_layers self.config = model.config apply_transformer_only_export_prep(model, matmul_to_conv=True) + # Tag the config so the exporter resolves this variant's OnnxConfig + # (registered under ``TRANSFORMER_ONLY_MODEL_TYPE``) rather than the + # stock qwen3 one. Mirrors the CLIP/zoedepth sub-model precedent. + self.config.model_type = TRANSFORMER_ONLY_MODEL_TYPE @classmethod def from_pretrained(cls, model_name_or_path: str, **kwargs: Any) -> QwenTransformerOnlyDecoderWrapper: @@ -222,6 +240,9 @@ def _transformer_only_outputs(num_layers: int, kv_seq_axis: str = "max_seq_len") return result +@register_onnx_overwrite( + TRANSFORMER_ONLY_MODEL_TYPE, "feature-extraction", library_name="transformers" +) class QwenTransformerOnlyPrefillIOConfig(OnnxConfig): """Prefill (seq=64) — transformer-only I/O.""" @@ -241,6 +262,9 @@ def outputs(self) -> dict[str, dict[int, str]]: return _transformer_only_outputs(self._normalized_config.num_layers) +@register_onnx_overwrite( + TRANSFORMER_ONLY_MODEL_TYPE, "text2text-generation", library_name="transformers" +) class QwenTransformerOnlyGenIOConfig(OnnxConfig): """Generation (seq=1) — transformer-only I/O.""" @@ -279,6 +303,7 @@ def outputs(self) -> dict[str, dict[int, str]]: # ============================================================================= +@register_composite_model(TRANSFORMER_ONLY_MODEL_TYPE, "text-generation") class WinMLQwen3TransformerOnlyModel(WinMLDecoderOnlyModel): """Composite handle for the transformer-only Qwen3 build (export only). @@ -299,56 +324,28 @@ def get_cache_class(cls) -> type: # ============================================================================= -# install() — hot-patch the registries +# Declarative registration (import-time) # ============================================================================= +# Wrapper-class lookup keyed by (model_type, task). Keys use the hyphenated +# model_type form because ``_get_custom_model_class`` normalizes ``_`` -> ``-`` +# before lookup. Merged into the aggregate mapping by ``models.hf.__init__``. +MODEL_CLASS_MAPPING: dict[tuple[str, str], type] = { + ("qwen3-transformer-only", "feature-extraction"): QwenTransformerOnlyDecoderWrapper, + ("qwen3-transformer-only", "text2text-generation"): QwenTransformerOnlyDecoderWrapper, +} -_INSTALLED = False - - -def install() -> None: - """Replace the qwen3 entries in WinML registries with the transformer-only variants. - - Idempotent. After this call, building any qwen3 model via - :class:`~winml.modelkit.models.winml.composite_model.WinMLCompositeModel` - or :class:`~winml.modelkit.models.auto.WinMLAutoModel` produces - transformer-only ONNX files. - """ - global _INSTALLED - if _INSTALLED: - return - - # 1) Per-model build config + wrapper-class lookup live on the parent - # ``models.hf`` package as module-level dicts; mutating them is the - # documented hook for adding/overriding a model_type. - from .. import hf as _hf_pkg # noqa: PLC0415 - - _hf_pkg.MODEL_BUILD_CONFIGS["qwen3"] = QWEN_TRANSFORMER_ONLY_CONFIG - _hf_pkg.MODEL_CLASS_MAPPING[("qwen3", "feature-extraction")] = QwenTransformerOnlyDecoderWrapper - _hf_pkg.MODEL_CLASS_MAPPING[("qwen3", "text2text-generation")] = QwenTransformerOnlyDecoderWrapper - - # 2) Optimum OnnxConfig (overwrites existing registration). - register_onnx_overwrite("qwen3", "feature-extraction", library_name="transformers")(QwenTransformerOnlyPrefillIOConfig) - register_onnx_overwrite("qwen3", "text2text-generation", library_name="transformers")(QwenTransformerOnlyGenIOConfig) - - # 3) Inference specialization (still GenericTask — wrapper returns raw KV). - register_specialization("qwen3", "feature-extraction", "WinMLModelForGenericTask") - register_specialization("qwen3", "text2text-generation", "WinMLModelForGenericTask") - - # 4) Composite registry — swap to the transformer-only handle. - from ..winml.composite_model import COMPOSITE_MODEL_REGISTRY - - COMPOSITE_MODEL_REGISTRY[("qwen3", "text-generation")] = WinMLQwen3TransformerOnlyModel - - _INSTALLED = True - logger.info("qwen_transformer_only: transformer-only export path installed for qwen3.") +# Inference specialization (GenericTask — the wrapper returns raw hidden states / KV). +register_specialization(TRANSFORMER_ONLY_MODEL_TYPE, "feature-extraction", "WinMLModelForGenericTask") +register_specialization(TRANSFORMER_ONLY_MODEL_TYPE, "text2text-generation", "WinMLModelForGenericTask") __all__ = [ + "MODEL_CLASS_MAPPING", "QWEN_TRANSFORMER_ONLY_CONFIG", + "TRANSFORMER_ONLY_MODEL_TYPE", "QwenTransformerOnlyDecoderWrapper", "QwenTransformerOnlyGenIOConfig", "QwenTransformerOnlyPrefillIOConfig", "WinMLQwen3TransformerOnlyModel", - "install", ] diff --git a/test_qwen.py b/test_qwen.py index f958c2932..da23f4481 100644 --- a/test_qwen.py +++ b/test_qwen.py @@ -10,9 +10,6 @@ ``com.microsoft::GroupQueryAttention``, ``LpNormalization``, and 1x1 Conv projections. -Important: ``install()`` MUST be called before importing the composite model -machinery so the registry hot-patches take effect. - Generation (``model.generate(...)``) is NOT supported by this build path — the inference feeds in ``WinMLDecoderOnlyModel`` still target the eager I/O signature. Use the eager ``WinMLQwen3Model`` build path for end-to-end @@ -68,16 +65,13 @@ def _build_one(task: str, seq_len: int) -> None: clean interpreter — building both in one process leaves PyTorch/ORT state from the first build that corrupts/kills the second. """ - from winml.modelkit.models.hf.qwen_transformer_only import install as install_qwen_transformer_only - - install_qwen_transformer_only() - from winml.modelkit.config import WinMLBuildConfig from winml.modelkit.models.auto import WinMLAutoModel WinMLAutoModel.from_pretrained( model_id, task=task, + model_type="qwen3_transformer_only", config=WinMLBuildConfig(quant=None, compile=None), precision="fp16", device="npu", From 9cecb03913d9d19a10e1d2c27934414bcdd5ee3f Mon Sep 17 00:00:00 2001 From: spalne Date: Tue, 23 Jun 2026 11:52:52 -0700 Subject: [PATCH 05/17] fix(qwen): calibrate transformer-only decode model on real trajectory --- qwen3_transformer_only_quantize.py | 170 +++++++++++++++++++++++++++-- 1 file changed, 163 insertions(+), 7 deletions(-) diff --git a/qwen3_transformer_only_quantize.py b/qwen3_transformer_only_quantize.py index 0b90c8bd0..81bcb780f 100644 --- a/qwen3_transformer_only_quantize.py +++ b/qwen3_transformer_only_quantize.py @@ -34,6 +34,7 @@ DEFAULT_PREFILL_SEQ = 64 DEFAULT_GEN_SEQ = 1 DEFAULT_NUM_SAMPLES = 30 +DEFAULT_DECODE_STEPS = 16 DEFAULT_CALIB_DATASET = "openai/gsm8k" DEFAULT_CALIB_DATASET_CONFIG = "main" DEFAULT_CALIB_SPLIT = "train" @@ -119,6 +120,140 @@ def rewind(self) -> None: self._iter = iter(self._samples) +def _layer_kv(past: Any, i: int) -> tuple[torch.Tensor, torch.Tensor]: + """Extract layer ``i``'s (key, value) from an HF cache, version-agnostic. + + Handles the legacy tuple-of-tuples cache, the older ``DynamicCache`` + (``.key_cache`` / ``.value_cache``), and the newer per-layer + ``DynamicCache`` (``.layers[i].keys`` / ``.values``). + """ + if hasattr(past, "key_cache") and hasattr(past, "value_cache"): + return past.key_cache[i], past.value_cache[i] + if hasattr(past, "layers"): + layer = past.layers[i] + return layer.keys, layer.values + return past[i][0], past[i][1] + + +class Qwen3DecodeTrajectoryCalibReader(CalibrationDataReader): + """Calibrate the iter (seq_len=1) model on REAL decode-step states. + + The naive reader feeds one (repeated) token with a zeroed KV cache and + ``past_seq_len=0`` — a state the model never sees during generation. With + MinMax calibration this collapses the observed activation ranges far below + the real decode distribution, so the resulting w8a16 model degenerates + (e.g. ``Paris -> Parisammedammed...``). + + Instead, drive the HF FP reference model through a real prefill + decode + trajectory and capture, at each decode step, the exact feed the iter ONNX + would receive: the embedding of the *actually generated* token, the real + accumulated KV cache (copied into the fixed ``[1, kv_heads, max_cache, + head_dim]`` FP16 buffer), and the growing ``past_seq_len``. Token + selection uses the HF model's true logits, so the trajectory matches + greedy generation. The QDQ scheme is unchanged — only the calibration + statistics become representative. + """ + + def __init__( + self, + hf_model: torch.nn.Module, + embed_tokens: torch.nn.Module, + config: Any, + token_ids_list: list[torch.Tensor], + *, + prefill_seq: int, + max_cache_len: int, + decode_steps: int = 16, + ) -> None: + self.num_layers = config.num_hidden_layers + self.num_kv_heads = config.num_key_value_heads + self.head_dim = getattr( + config, "head_dim", config.hidden_size // config.num_attention_heads + ) + self.max_cache_len = max_cache_len + self._samples = list( + self._build_samples( + hf_model, + embed_tokens, + token_ids_list, + prefill_seq=prefill_seq, + decode_steps=decode_steps, + ) + ) + self._iter: Iterator[dict[str, np.ndarray]] | None = None + self.rewind() + + def _kv_buffers(self, past: Any, cur_len: int) -> dict[str, np.ndarray]: + """Copy the ``cur_len`` valid KV positions into fixed FP16 buffers.""" + feed: dict[str, np.ndarray] = {} + for i in range(self.num_layers): + k, v = _layer_kv(past, i) + kbuf = np.zeros( + (1, self.num_kv_heads, self.max_cache_len, self.head_dim), np.float16 + ) + vbuf = np.zeros_like(kbuf) + kbuf[:, :, :cur_len, :] = k[:, :, :cur_len, :].to(torch.float16).cpu().numpy() + vbuf[:, :, :cur_len, :] = v[:, :, :cur_len, :].to(torch.float16).cpu().numpy() + feed[f"past_keys_{i}"] = kbuf + feed[f"past_values_{i}"] = vbuf + return feed + + def _build_samples( + self, + hf_model: torch.nn.Module, + embed_tokens: torch.nn.Module, + token_ids_list: list[torch.Tensor], + *, + prefill_seq: int, + decode_steps: int, + ) -> Iterator[dict[str, np.ndarray]]: + for ids in token_ids_list: + ids = ids[:, :prefill_seq] # real prompt prefix (no pad-token KV) + cur_len = ids.shape[1] + + # FP prefill once to seed a realistic KV cache + first token. + with torch.no_grad(): + out = hf_model(input_ids=ids, use_cache=True) + past = out.past_key_values + tok = int(out.logits[:, -1, :].argmax(-1)) + + for _ in range(decode_steps): + if cur_len >= self.max_cache_len: + break + # The feed the iter model sees for THIS token: embedding of the + # token to process, the KV of the `cur_len` preceding tokens, + # and seqlens_k = (cur_len + 1) - 1 = cur_len. + with torch.no_grad(): + emb = embed_tokens(torch.tensor([[tok]])).to(torch.float32).cpu().numpy() + feed: dict[str, np.ndarray] = { + "input_hidden_states": emb.astype(np.float32), + "past_seq_len": np.array([[cur_len]], dtype=np.int32), + "total_seq_len": np.array([self.max_cache_len], dtype=np.int32), + } + feed.update(self._kv_buffers(past, cur_len)) + yield feed + + # Advance the reference model one real decode step. + with torch.no_grad(): + out = hf_model( + input_ids=torch.tensor([[tok]]), + past_key_values=past, + use_cache=True, + ) + past = out.past_key_values + tok = int(out.logits[:, -1, :].argmax(-1)) + cur_len += 1 + + def get_next(self) -> dict[str, np.ndarray] | None: + try: + return next(self._iter) if self._iter is not None else None + except StopIteration: + return None + + def rewind(self) -> None: + self._iter = iter(self._samples) + + def _tokenize_prompts(tokenizer: Any, prompts: list[str], num_samples: int) -> list[torch.Tensor]: out: list[torch.Tensor] = [] for i in range(num_samples): @@ -160,6 +295,7 @@ def quantize_built_model( num_samples: int = DEFAULT_NUM_SAMPLES, weight_type: str = "int8", activation_type: str = "uint16", + decode_steps: int = DEFAULT_DECODE_STEPS, ) -> dict[str, Path]: """Quantize the transformer-only ONNX files in-place. @@ -223,13 +359,33 @@ def quantize_built_model( f" excluding {len(gqa_nodes)} GroupQueryAttention nodes from " "quantization (inputs + output stay float, Cast -> GQA -> Cast)" ) - reader = Qwen3TransformerOnlyCalibReader( - embed_tokens, - hf_model.config, - token_ids_list, - seq_len=seq_len, - max_cache_len=max_cache_len, - ) + if sub_name == "decoder_gen": + # The iter model only sees mid-generation states. Calibrate it on a + # real prefill+decode trajectory (true tokens, accumulated KV, + # growing past_seq_len) instead of one token + zeroed KV, which + # would under-range the MinMax activation scales and collapse + # generation. + print( + f" calibrating on decode trajectory ({decode_steps} steps/prompt, " + f"prefill_seq={prefill_seq})" + ) + reader: CalibrationDataReader = Qwen3DecodeTrajectoryCalibReader( + hf_model, + embed_tokens, + hf_model.config, + token_ids_list, + prefill_seq=prefill_seq, + max_cache_len=max_cache_len, + decode_steps=decode_steps, + ) + else: + reader = Qwen3TransformerOnlyCalibReader( + embed_tokens, + hf_model.config, + token_ids_list, + seq_len=seq_len, + max_cache_len=max_cache_len, + ) cfg = WinMLQuantizationConfig( samples=num_samples, weight_type=weight_type, # type: ignore[arg-type] From 08f05d7c399dafe2f60dfccf3cbd3348355ab721 Mon Sep 17 00:00:00 2001 From: spalne Date: Tue, 23 Jun 2026 13:18:25 -0700 Subject: [PATCH 06/17] Fixed small bugs --- qwen3_transformer_only_quantize.py | 18 +++- .../modelkit/models/hf/qwen3_export_ops.py | 81 +++----------- .../modelkit/models/hf/qwen3_modeling.py | 101 ++++++++++++++++-- .../models/hf/qwen_transformer_only.py | 2 +- test_qwen.py | 8 +- 5 files changed, 132 insertions(+), 78 deletions(-) diff --git a/qwen3_transformer_only_quantize.py b/qwen3_transformer_only_quantize.py index 81bcb780f..559620973 100644 --- a/qwen3_transformer_only_quantize.py +++ b/qwen3_transformer_only_quantize.py @@ -15,6 +15,7 @@ from __future__ import annotations import logging +import gc from pathlib import Path from typing import Any, Iterator @@ -40,6 +41,16 @@ DEFAULT_CALIB_SPLIT = "train" DEFAULT_CALIB_SEED = 42 +# Map an ONNX quantization dtype to the bit-width suffix used in artifact +# filenames (e.g. int8 -> "8", uint16 -> "16"), instead of brittle string +# slicing of the dtype name. +_DTYPE_BITS = { + "int8": "8", + "uint8": "8", + "int16": "16", + "uint16": "16", +} + def _load_gsm8k_prompts(num_samples: int) -> list[str]: """GSM8K train split, shuffled seed=42 for reproducible calibration.""" @@ -348,7 +359,8 @@ def quantize_built_model( seq_len = seq_by_sub[sub_name] quant_path = fused_path.with_name( - fused_path.stem + f"_w{weight_type[-1]}a{activation_type[-2:]}.quant.onnx" + fused_path.stem + + f"_w{_DTYPE_BITS[weight_type]}a{_DTYPE_BITS[activation_type]}.quant.onnx" ) print(f"\n=== Quantize (transformer-only): {sub_name} (seq_len={seq_len}) ===") @@ -414,5 +426,9 @@ def quantize_built_model( ) quant_paths[sub_name] = quant_path + # Free the FP reference model now that calibration is done. + del hf_model, embed_tokens + gc.collect() + print("\n=== Done ===") return quant_paths diff --git a/src/winml/modelkit/models/hf/qwen3_export_ops.py b/src/winml/modelkit/models/hf/qwen3_export_ops.py index 61d45f0ef..5fd3edb68 100644 --- a/src/winml/modelkit/models/hf/qwen3_export_ops.py +++ b/src/winml/modelkit/models/hf/qwen3_export_ops.py @@ -46,7 +46,12 @@ def symbolic(g, input, axis, p): # noqa: D401 @staticmethod def forward(ctx, input, axis, p): # noqa: ARG004 - return input # placeholder — real compute happens in symbolic + # Shape-only tracing placeholder. The real op is emitted by + # ``symbolic`` during ONNX export; ``forward`` exists solely so the + # TorchScript exporter (and Optimum's pre-export dry run) can trace + # output shapes. It returns ``input`` unchanged on purpose and is NOT a + # correct eager RMSNorm — do not call this module for real inference. + return input class GroupQueryAttentionOnnxExport(torch.autograd.Function): @@ -100,6 +105,12 @@ def forward( kv_num_heads, num_heads, ): # noqa: ARG004 + # Shape-only tracing placeholder. The real op is emitted by + # ``symbolic`` during ONNX export; ``forward`` exists solely so the + # TorchScript exporter (and Optimum's pre-export dry run) can trace + # output shapes. It returns the inputs as stand-in present-KV on + # purpose and is NOT correct attention — do not call this module for + # real inference. return query, past_key, past_value # placeholder shapes @@ -136,76 +147,8 @@ def from_linear_module(cls, linear: nn.Linear) -> TransposeConv2d1x1Transpose: return cls(linear.in_features, linear.out_features, linear.weight, linear.bias) -# ============================================================================= -# Apply export prep: bind winml Qwen3 export methods onto a loaded model -# ============================================================================= - - -def apply_transformer_only_export_prep(causal_lm: nn.Module, *, matmul_to_conv: bool = True) -> None: - """Mutate ``Qwen3ForCausalLM`` in-place into the export topology. - - Binds the winml-owned export behaviour from :mod:`.qwen3_modeling` onto each - Qwen3 submodule (runs ``prepare_for_onnx_export`` and rebinds ``forward``). - After this call, ``causal_lm.model(inputs_embeds, past_key_values, - past_seq_len, total_seq_len)`` runs the transformer-only forward. - - Args: - causal_lm: A ``transformers.Qwen3ForCausalLM`` instance. - matmul_to_conv: Swap ``nn.Linear`` projections to 1x1 ``Conv2d`` so - QNN sees them as Conv. - """ - from .qwen3_modeling import ( - WinMLQwen3Attention, - WinMLQwen3DecoderLayer, - WinMLQwen3MLP, - WinMLQwen3Model, - WinMLQwen3RMSNorm, - ) - - def _bind(module: nn.Module, owner: type) -> None: - module.forward = owner.forward.__get__(module, type(module)) - - # Identify Qwen3 submodules by their (stock HF) class name so we don't - # depend on importing ``transformers.models.qwen3`` here. - def _is(module: nn.Module, name: str) -> bool: - return type(module).__name__ == name - - # Patch every RMSNorm first (Qwen3RMSNorm appears at top, in q_norm/k_norm, - # in input/post_attention layernorms). - for mod in causal_lm.modules(): - if _is(mod, "Qwen3RMSNorm"): - WinMLQwen3RMSNorm.prepare_for_onnx_export(mod) - _bind(mod, WinMLQwen3RMSNorm) - - for mod in causal_lm.modules(): - if _is(mod, "Qwen3Attention"): - WinMLQwen3Attention.prepare_for_onnx_export(mod, matmul_to_conv=matmul_to_conv) - _bind(mod, WinMLQwen3Attention) - elif _is(mod, "Qwen3MLP"): - # MLP forward is unchanged; only the projections are swapped to Conv. - WinMLQwen3MLP.prepare_for_onnx_export(mod, matmul_to_conv=matmul_to_conv) - - # HF moved ``rotary_emb`` from ``Qwen3Attention`` up to ``Qwen3Model``; - # the export forward invokes ``self.rotary_emb`` on the attention module, - # so re-attach a reference from the parent model. - for mod in causal_lm.modules(): - if _is(mod, "Qwen3Model") and hasattr(mod, "rotary_emb"): - for layer in mod.layers: - layer.self_attn.rotary_emb = mod.rotary_emb - - for mod in causal_lm.modules(): - if _is(mod, "Qwen3DecoderLayer"): - _bind(mod, WinMLQwen3DecoderLayer) - - for mod in causal_lm.modules(): - if _is(mod, "Qwen3Model"): - WinMLQwen3Model.prepare_for_onnx_export(mod, matmul_to_conv=matmul_to_conv) - _bind(mod, WinMLQwen3Model) - - __all__ = [ "GroupQueryAttentionOnnxExport", "LpNormOnnxExport", "TransposeConv2d1x1Transpose", - "apply_transformer_only_export_prep", ] diff --git a/src/winml/modelkit/models/hf/qwen3_modeling.py b/src/winml/modelkit/models/hf/qwen3_modeling.py index 05a70adfe..d3c538df5 100644 --- a/src/winml/modelkit/models/hf/qwen3_modeling.py +++ b/src/winml/modelkit/models/hf/qwen3_modeling.py @@ -18,7 +18,7 @@ - ``WinMLQwen3DecoderLayer`` / ``WinMLQwen3Model`` -> transformer-only forward that threads the KV cache + seq-len tensors and omits embeddings / lm_head. -``apply_transformer_only_export_prep`` (in ``qwen3_export_ops``) walks a loaded +``apply_transformer_only_export_prep`` (defined below) walks a loaded ``Qwen3ForCausalLM``, calls ``prepare_for_onnx_export`` on each submodule, and binds the matching ``forward`` from these classes onto it. """ @@ -42,15 +42,14 @@ class WinMLQwen3RMSNorm(nn.Module): def prepare_for_onnx_export(self) -> None: # Pre-multiply the gain into the weight (LpNorm has unit gain). + # ``scale`` is shape ``[1]`` and broadcasts over ``self.weight`` + # (shape ``[hidden_size]``), so the result keeps the per-channel + # shape even when the original weights are all ones. n = self.weight.numel() scale = torch.sqrt( torch.tensor([n], device=self.weight.device, dtype=self.weight.dtype) ) - if torch.any(self.weight.data != torch.ones_like(self.weight)).item(): - new_w = scale * self.weight - else: - new_w = scale - self.weight = nn.Parameter(new_w) + self.weight = nn.Parameter(scale * self.weight) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: out = LpNormOnnxExport.apply(hidden_states, -1, 2) @@ -228,10 +227,100 @@ def forward( return hidden_states, present_kvs +# ============================================================================= +# Apply export prep: bind winml Qwen3 export methods onto a loaded model +# ============================================================================= + + +def apply_transformer_only_export_prep( + causal_lm: nn.Module, *, matmul_to_conv: bool = True +) -> None: + """Mutate ``Qwen3ForCausalLM`` in-place into the export topology. + + Binds the winml-owned export behaviour (the ``WinMLQwen3*`` classes in this + module) onto each Qwen3 submodule (runs ``prepare_for_onnx_export`` and + rebinds ``forward``). After this call, ``causal_lm.model(inputs_embeds, + past_key_values, past_seq_len, total_seq_len)`` runs the transformer-only + forward. + + Args: + causal_lm: A ``transformers.Qwen3ForCausalLM`` instance. + matmul_to_conv: Swap ``nn.Linear`` projections to 1x1 ``Conv2d`` so + QNN sees them as Conv. + + Raises: + RuntimeError: If any expected Qwen3 submodule class is not found, + meaning the loaded model does not match the expected topology + (e.g. the stock HF class names changed). + """ + + def _bind(module: nn.Module, owner: type) -> None: + module.forward = owner.forward.__get__(module, type(module)) + + # Identify Qwen3 submodules by their (stock HF) class name so we don't + # depend on importing ``transformers.models.qwen3`` here. + def _is(module: nn.Module, name: str) -> bool: + return type(module).__name__ == name + + patched = { + "Qwen3RMSNorm": 0, + "Qwen3Attention": 0, + "Qwen3MLP": 0, + "Qwen3DecoderLayer": 0, + "Qwen3Model": 0, + } + + # Patch every RMSNorm first (Qwen3RMSNorm appears at top, in q_norm/k_norm, + # in input/post_attention layernorms). + for mod in causal_lm.modules(): + if _is(mod, "Qwen3RMSNorm"): + WinMLQwen3RMSNorm.prepare_for_onnx_export(mod) + _bind(mod, WinMLQwen3RMSNorm) + patched["Qwen3RMSNorm"] += 1 + + for mod in causal_lm.modules(): + if _is(mod, "Qwen3Attention"): + WinMLQwen3Attention.prepare_for_onnx_export(mod, matmul_to_conv=matmul_to_conv) + _bind(mod, WinMLQwen3Attention) + patched["Qwen3Attention"] += 1 + elif _is(mod, "Qwen3MLP"): + # MLP forward is unchanged; only the projections are swapped to Conv. + WinMLQwen3MLP.prepare_for_onnx_export(mod, matmul_to_conv=matmul_to_conv) + patched["Qwen3MLP"] += 1 + + # HF moved ``rotary_emb`` from ``Qwen3Attention`` up to ``Qwen3Model``; + # the export forward invokes ``self.rotary_emb`` on the attention module, + # so re-attach a reference from the parent model. + for mod in causal_lm.modules(): + if _is(mod, "Qwen3Model") and hasattr(mod, "rotary_emb"): + for layer in mod.layers: + layer.self_attn.rotary_emb = mod.rotary_emb + + for mod in causal_lm.modules(): + if _is(mod, "Qwen3DecoderLayer"): + _bind(mod, WinMLQwen3DecoderLayer) + patched["Qwen3DecoderLayer"] += 1 + + for mod in causal_lm.modules(): + if _is(mod, "Qwen3Model"): + WinMLQwen3Model.prepare_for_onnx_export(mod, matmul_to_conv=matmul_to_conv) + _bind(mod, WinMLQwen3Model) + patched["Qwen3Model"] += 1 + + missing = [name for name, count in patched.items() if count == 0] + if missing: + raise RuntimeError( + "transformer-only export prep found no " + f"{missing} submodule(s) to patch; the loaded model does not match " + "the expected Qwen3 topology (stock HF class names may have changed)." + ) + + __all__ = [ "WinMLQwen3Attention", "WinMLQwen3DecoderLayer", "WinMLQwen3MLP", "WinMLQwen3Model", "WinMLQwen3RMSNorm", + "apply_transformer_only_export_prep", ] diff --git a/src/winml/modelkit/models/hf/qwen_transformer_only.py b/src/winml/modelkit/models/hf/qwen_transformer_only.py index 614267df4..6ac9d0852 100644 --- a/src/winml/modelkit/models/hf/qwen_transformer_only.py +++ b/src/winml/modelkit/models/hf/qwen_transformer_only.py @@ -45,7 +45,7 @@ from ..winml.composite_model import register_composite_model from ..winml.decoder_only import WinMLDecoderOnlyModel from ..winml.kv_cache import WinMLSlidingWindowCache -from .qwen3_export_ops import apply_transformer_only_export_prep +from .qwen3_modeling import apply_transformer_only_export_prep logger = logging.getLogger(__name__) diff --git a/test_qwen.py b/test_qwen.py index da23f4481..14cf4656d 100644 --- a/test_qwen.py +++ b/test_qwen.py @@ -17,7 +17,7 @@ Run:: - python test_qwen_transformer_only.py + python test_qwen.py This builds each transformer sub-model and then runs the w8a16 quantization on the exported transformer ONNX files (no surgery needed — @@ -85,6 +85,8 @@ def _build_one(task: str, seq_len: int) -> None: print(f"BUILD COMPLETE: task={task} seq_len={seq_len}", flush=True) sys.stdout.flush() sys.stderr.flush() + # TODO(winml-cli#836): replace the hard exit once the native QNN/ORT + # teardown segfault (0xC0000005) on interpreter shutdown is fixed upstream. os._exit(0) @@ -155,6 +157,8 @@ def _run_quant() -> None: print("QUANT COMPLETE", flush=True) sys.stdout.flush() sys.stderr.flush() + # TODO(winml-cli#836): replace the hard exit once the native QNN/ORT + # teardown segfault (0xC0000005) on interpreter shutdown is fixed upstream. os._exit(0) @@ -173,6 +177,7 @@ def main() -> None: [sys.executable, "-u", str(pathlib.Path(__file__).resolve()), "--build-sub", task, str(seq_len)], cwd=str(_repo_root), + timeout=1800, ).returncode after = _latest_ctx_mtime(prefix) @@ -205,6 +210,7 @@ def main() -> None: rc = subprocess.run( [sys.executable, "-u", str(pathlib.Path(__file__).resolve()), "--quant"], cwd=str(_repo_root), + timeout=1800, ).returncode after_files = list(ARTIFACTS_DIR.glob("*quant.onnx")) after = max((p.stat().st_mtime for p in after_files), default=0.0) From 818cfe47fbf30381c67cc7a7fefb99dd04edc509 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Wed, 24 Jun 2026 11:53:25 +0800 Subject: [PATCH 07/17] refactor(qwen): config-driven transformer-only quant + pytest Replace the standalone root-level quant driver and __main__/subprocess test runner with the regular build pipeline and pytest. - Move calibration logic into src/.../hf/qwen_transformer_only_quant.py; the decode wrapper exposes winml_finalize_quant_config, invoked generically from build/hf.py just before quantize_onnx. The build now quantizes via precision=w8a16 + config.quant instead of a separate script. - The hook reads seq_len / max_cache / GQA node names from the exported ONNX and selects the prefill vs decode-trajectory calibration reader, keeping the verified-good scheme (int8-symmetric weights, uint16 activations, minmax, GQA excluded from QDQ). - Delete root qwen3_transformer_only_quantize.py and test_qwen.py. - Add tests/unit/models/qwen_transformer_only (fast, offline) and tests/e2e/models/test_qwen3_transformer_only_quant.py (build+quant+decode-parity, QNN-gated NPU). --- src/winml/modelkit/build/hf.py | 8 + .../models/hf/qwen_transformer_only.py | 92 ++++- .../models/hf/qwen_transformer_only_quant.py | 387 +++++++++--------- test_qwen.py | 235 ----------- .../test_qwen3_transformer_only_quant.py | 248 +++++++++++ .../models/qwen_transformer_only/__init__.py | 4 + .../test_quant_calibration.py | 234 +++++++++++ 7 files changed, 753 insertions(+), 455 deletions(-) rename qwen3_transformer_only_quantize.py => src/winml/modelkit/models/hf/qwen_transformer_only_quant.py (60%) delete mode 100644 test_qwen.py create mode 100644 tests/e2e/models/test_qwen3_transformer_only_quant.py create mode 100644 tests/unit/models/qwen_transformer_only/__init__.py create mode 100644 tests/unit/models/qwen_transformer_only/test_quant_calibration.py diff --git a/src/winml/modelkit/build/hf.py b/src/winml/modelkit/build/hf.py index dc2661afa..4dcf09b5b 100644 --- a/src/winml/modelkit/build/hf.py +++ b/src/winml/modelkit/build/hf.py @@ -310,6 +310,14 @@ def _name(base: str) -> str: else: logger.info("Quantizing model...") t0 = time.monotonic() + # Some model wrappers can only finalize their quant config once the + # exported ONNX exists (e.g. calibration feeds / nodes-to-exclude + # derived from the graph). Give the wrapper a chance to populate + # those runtime-only fields here. + if pytorch_model is not None and hasattr(pytorch_model, "winml_finalize_quant_config"): + config.quant = pytorch_model.winml_finalize_quant_config( + config.quant, onnx_path=current_path, model_id=model_id + ) quant_result = quantize_onnx( model_path=current_path, output_path=quantized_path, diff --git a/src/winml/modelkit/models/hf/qwen_transformer_only.py b/src/winml/modelkit/models/hf/qwen_transformer_only.py index 6ac9d0852..fda69495f 100644 --- a/src/winml/modelkit/models/hf/qwen_transformer_only.py +++ b/src/winml/modelkit/models/hf/qwen_transformer_only.py @@ -85,7 +85,10 @@ def __init__(self, model: nn.Module, num_layers: int) -> None: self.config.model_type = TRANSFORMER_ONLY_MODEL_TYPE @classmethod - def from_pretrained(cls, model_name_or_path: str, **kwargs: Any) -> QwenTransformerOnlyDecoderWrapper: + def from_pretrained( + cls, model_name_or_path: str, **kwargs: Any + ) -> QwenTransformerOnlyDecoderWrapper: + """Load the HF model and wrap it for transformer-only export.""" kwargs.setdefault("torch_dtype", torch.float32) model = AutoModelForCausalLM.from_pretrained(model_name_or_path, **kwargs) model.config._attn_implementation = "eager" @@ -94,24 +97,23 @@ def from_pretrained(cls, model_name_or_path: str, **kwargs: Any) -> QwenTransfor return wrapper def get_export_args(self, inputs: dict[str, torch.Tensor]) -> tuple[torch.Tensor, ...]: + """Flatten the dummy-input dict into positional export args.""" return tuple(inputs.values()) def forward(self, *args: torch.Tensor) -> tuple[torch.Tensor, ...]: - """Positional inputs (order matches OnnxConfig.inputs): + """Run the decoder stack on positional inputs (order matches OnnxConfig.inputs). - past_keys_0, past_values_0, ..., past_keys_{L-1}, past_values_{L-1}, - input_hidden_states, past_seq_len, total_seq_len - - Returns ``(output_hidden_states, present_keys_0, present_values_0, ...)``. + Positional inputs are ``past_keys_0, past_values_0, ..., + past_keys_{L-1}, past_values_{L-1}, input_hidden_states, past_seq_len, + total_seq_len``. Returns ``(output_hidden_states, present_keys_0, + present_values_0, ...)``. """ kv_args = args[: 2 * self.num_layers] input_hidden_states = args[2 * self.num_layers] past_seq_len = args[2 * self.num_layers + 1] total_seq_len = args[2 * self.num_layers + 2] - past_key_values = [ - (kv_args[2 * i], kv_args[2 * i + 1]) for i in range(self.num_layers) - ] + past_key_values = [(kv_args[2 * i], kv_args[2 * i + 1]) for i in range(self.num_layers)] hidden_states, present_kvs = self.model.model( inputs_embeds=input_hidden_states, @@ -126,6 +128,27 @@ def forward(self, *args: torch.Tensor) -> tuple[torch.Tensor, ...]: out.extend([k, v]) return tuple(out) + def winml_finalize_quant_config( + self, quant: Any, *, onnx_path: Any, model_id: str | None = None + ) -> Any: + """Build-pipeline hook: attach the calibration reader + GQA exclusions. + + Called by ``build_hf_model`` just before ``quantize_onnx`` (see + ``build/hf.py``). The exported transformer-only graph determines the + calibration feeds (shapes, KV buffers) and which GroupQueryAttention + nodes stay in float, so the live :class:`WinMLQuantizationConfig` can + only be finalized here — not at config-construction time. + """ + from .qwen_transformer_only_quant import ( + DEFAULT_MODEL_ID, + finalize_transformer_only_quant_config, + ) + + resolved_id = model_id or getattr(self.config, "_name_or_path", None) or DEFAULT_MODEL_ID + return finalize_transformer_only_quant_config( + quant, onnx_path=onnx_path, model_id=resolved_id + ) + # ============================================================================= # Dummy input generators (transformer-only I/O) @@ -151,7 +174,13 @@ def __init__( self.hidden_size = normalized_config.hidden_size self.seq_len = seq_len or getattr(normalized_config, "seq_len", self._default_seq_len) - def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32") -> torch.Tensor: # noqa: ARG002 + def generate( + self, + input_name: str, + framework: str = "pt", + int_dtype: str = "int64", + float_dtype: str = "fp32", + ) -> torch.Tensor: if input_name == "input_hidden_states": return torch.randn(self.batch_size, self.seq_len, self.hidden_size, dtype=torch.float32) raise ValueError(f"Unknown input: {input_name}") @@ -166,10 +195,16 @@ class _TransformerOnlySeqLenGenerator(DummyInputGenerator): SUPPORTED_INPUT_NAMES = ("past_seq_len", "total_seq_len") - def __init__(self, task: str, normalized_config: Any, **kwargs: Any) -> None: # noqa: ARG002 + def __init__(self, task: str, normalized_config: Any, **kwargs: Any) -> None: self.max_cache_len = normalized_config.max_cache_len - def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32") -> torch.Tensor: # noqa: ARG002 + def generate( + self, + input_name: str, + framework: str = "pt", + int_dtype: str = "int64", + float_dtype: str = "fp32", + ) -> torch.Tensor: if input_name == "past_seq_len": return torch.zeros((1, 1), dtype=torch.int32) if input_name == "total_seq_len": @@ -192,14 +227,22 @@ def __init__( ) -> None: self.batch_size = batch_size self.num_layers: int = normalized_config.num_layers - self.num_heads: int = normalized_config.num_attention_heads # KV heads (NormalizedConfig maps it) + self.num_heads: int = ( + normalized_config.num_attention_heads + ) # KV heads (NormalizedConfig maps it) self.head_dim: int = normalized_config.head_dim self.max_cache_len: int = max_cache_len or normalized_config.max_cache_len self.SUPPORTED_INPUT_NAMES = tuple( name for i in range(self.num_layers) for name in (f"past_keys_{i}", f"past_values_{i}") ) - def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32") -> torch.Tensor: # noqa: ARG002 + def generate( + self, + input_name: str, + framework: str = "pt", + int_dtype: str = "int64", + float_dtype: str = "fp32", + ) -> torch.Tensor: shape = (self.batch_size, self.num_heads, self.max_cache_len, self.head_dim) return torch.zeros(shape, dtype=torch.float16) @@ -220,7 +263,9 @@ def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int ) -def _transformer_only_inputs(num_layers: int, kv_seq_axis: str = "max_seq_len") -> dict[str, dict[int, str]]: +def _transformer_only_inputs( + num_layers: int, kv_seq_axis: str = "max_seq_len" +) -> dict[str, dict[int, str]]: """Input ordering: past KV pairs, then hidden states, then seq lens.""" result: dict[str, dict[int, str]] = {} for i in range(num_layers): @@ -232,7 +277,9 @@ def _transformer_only_inputs(num_layers: int, kv_seq_axis: str = "max_seq_len") return result -def _transformer_only_outputs(num_layers: int, kv_seq_axis: str = "max_seq_len") -> dict[str, dict[int, str]]: +def _transformer_only_outputs( + num_layers: int, kv_seq_axis: str = "max_seq_len" +) -> dict[str, dict[int, str]]: result: dict[str, dict[int, str]] = {"output_hidden_states": {1: "seq_len"}} for i in range(num_layers): result[f"present_keys_{i}"] = {2: kv_seq_axis} @@ -255,10 +302,12 @@ class QwenTransformerOnlyPrefillIOConfig(OnnxConfig): @property def inputs(self) -> dict[str, dict[int, str]]: + """ONNX input axes (past KV pairs, hidden states, seq lengths).""" return _transformer_only_inputs(self._normalized_config.num_layers) @property def outputs(self) -> dict[str, dict[int, str]]: + """ONNX output axes (hidden states then present KV pairs).""" return _transformer_only_outputs(self._normalized_config.num_layers) @@ -277,10 +326,12 @@ class QwenTransformerOnlyGenIOConfig(OnnxConfig): @property def inputs(self) -> dict[str, dict[int, str]]: + """ONNX input axes (past KV pairs, hidden states, seq lengths).""" return _transformer_only_inputs(self._normalized_config.num_layers) @property def outputs(self) -> dict[str, dict[int, str]]: + """ONNX output axes (hidden states then present KV pairs).""" return _transformer_only_outputs(self._normalized_config.num_layers) @@ -320,6 +371,7 @@ class WinMLQwen3TransformerOnlyModel(WinMLDecoderOnlyModel): @classmethod def get_cache_class(cls) -> type: + """Return the KV-cache class used during generation.""" return WinMLSlidingWindowCache @@ -336,8 +388,12 @@ def get_cache_class(cls) -> type: } # Inference specialization (GenericTask — the wrapper returns raw hidden states / KV). -register_specialization(TRANSFORMER_ONLY_MODEL_TYPE, "feature-extraction", "WinMLModelForGenericTask") -register_specialization(TRANSFORMER_ONLY_MODEL_TYPE, "text2text-generation", "WinMLModelForGenericTask") +register_specialization( + TRANSFORMER_ONLY_MODEL_TYPE, "feature-extraction", "WinMLModelForGenericTask" +) +register_specialization( + TRANSFORMER_ONLY_MODEL_TYPE, "text2text-generation", "WinMLModelForGenericTask" +) __all__ = [ diff --git a/qwen3_transformer_only_quantize.py b/src/winml/modelkit/models/hf/qwen_transformer_only_quant.py similarity index 60% rename from qwen3_transformer_only_quantize.py rename to src/winml/modelkit/models/hf/qwen_transformer_only_quant.py index 559620973..f01de2f71 100644 --- a/qwen3_transformer_only_quantize.py +++ b/src/winml/modelkit/models/hf/qwen_transformer_only_quant.py @@ -1,37 +1,58 @@ -"""Transformer-only w8a16 quantization for Qwen3. - -Targets the transformer-only ONNX produced by the -``qwen3_transformer_only`` build variant (see ``test_qwen.py``): - - - **No embedding/lm_head surgery.** The export already excludes both, - so we feed ``WinMLQuantization`` the file directly. - - **Transformer-shaped calibration feeds.** ``input_hidden_states`` (FP32), - ``past_seq_len`` / ``total_seq_len`` (INT32), ``past_keys_{i}`` / - ``past_values_{i}`` (FP16) — names + dtypes match the exported graph. - -Run via ``test_qwen.py``. +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- + +"""Config-driven w8a16 calibration for the transformer-only Qwen3 build. + +The transformer-only export (:mod:`qwen_transformer_only`) emits a graph whose +only quantization-relevant runtime inputs (the calibration feeds and the +``GroupQueryAttention`` node names to keep in float) can't be known until the +ONNX exists. Rather than a standalone post-build script that reaches into +``composite.sub_models[...]._onnx_path``, this module plugs into the normal +build pipeline: :meth:`QwenTransformerOnlyDecoderWrapper.winml_finalize_quant_config` +calls :func:`finalize_transformer_only_quant_config` just before +``quantize_onnx`` runs (see ``build/hf.py``), populating the live +:class:`WinMLQuantizationConfig` with the right +:class:`~winml.modelkit.quant.config.CalibrationDataReader` and +``nodes_to_exclude``. + +The two readers match the exported graph exactly: + + - ``input_hidden_states`` (FP32), ``past_seq_len`` / ``total_seq_len`` + (INT32), ``past_keys_{i}`` / ``past_values_{i}`` (FP16, full cache buffer). + - The prefill reader (``seq_len > 1``) embeds real prompt prefixes. + - The decode reader (``seq_len == 1``) drives a fresh FP reference model + through a real prefill + decode trajectory so MinMax sees representative + mid-generation activation ranges (a single repeated token + zeroed KV + collapses the ranges and degenerates generation). + +The export wrapper surgically replaces its own ``self.model`` (RMSNorm -> +LpNorm-identity, attention -> GQA placeholder, Linear -> 1x1 Conv), so it can't +run real inference; calibration loads a *fresh* ``AutoModelForCausalLM``. """ from __future__ import annotations -import logging import gc +import logging from pathlib import Path -from typing import Any, Iterator +from typing import TYPE_CHECKING, Any import numpy as np import torch from transformers import AutoModelForCausalLM, AutoTokenizer -from winml.modelkit.models.winml.composite_model import WinMLCompositeModel -from winml.modelkit.quant import WinMLQuantizationConfig, quantize_onnx -from winml.modelkit.quant.config import CalibrationDataReader +from ...quant.config import CalibrationDataReader, WinMLQuantizationConfig + + +if TYPE_CHECKING: + from collections.abc import Iterator logger = logging.getLogger(__name__) DEFAULT_MODEL_ID = "Qwen/Qwen3-0.6B" -DEFAULT_MAX_CACHE = 256 DEFAULT_PREFILL_SEQ = 64 DEFAULT_GEN_SEQ = 1 DEFAULT_NUM_SAMPLES = 30 @@ -41,16 +62,6 @@ DEFAULT_CALIB_SPLIT = "train" DEFAULT_CALIB_SEED = 42 -# Map an ONNX quantization dtype to the bit-width suffix used in artifact -# filenames (e.g. int8 -> "8", uint16 -> "16"), instead of brittle string -# slicing of the dtype name. -_DTYPE_BITS = { - "int8": "8", - "uint8": "8", - "int16": "16", - "uint16": "16", -} - def _load_gsm8k_prompts(num_samples: int) -> list[str]: """GSM8K train split, shuffled seed=42 for reproducible calibration.""" @@ -61,8 +72,79 @@ def _load_gsm8k_prompts(num_samples: int) -> list[str]: return [row["question"] for row in split.select(range(num_samples))] +def _tokenize_prompts(tokenizer: Any, prompts: list[str], num_samples: int) -> list[torch.Tensor]: + out: list[torch.Tensor] = [] + for i in range(num_samples): + prompt = prompts[i % len(prompts)] + text = tokenizer.apply_chat_template( + [{"role": "user", "content": prompt}], + tokenize=False, + add_generation_prompt=True, + enable_thinking=False, + ) + ids = tokenizer([text], return_tensors="pt").input_ids + out.append(ids) + return out + + +def _gqa_node_names(onnx_path: Path) -> list[str]: + """Return the names of every GroupQueryAttention node in ``onnx_path``. + + These nodes are excluded from quantization so ORT leaves both their + inputs and output in float (``... -> Cast -> GQA -> Cast``), matching + the reference graph which keeps attention entirely out of QDQ. + """ + import onnx + + model = onnx.load(str(onnx_path), load_external_data=False) + return [n.name for n in model.graph.node if n.op_type == "GroupQueryAttention" and n.name] + + +def _graph_shapes(onnx_path: Path) -> tuple[int, int]: + """Read ``(seq_len, max_cache_len)`` from the exported graph's static inputs. + + ``seq_len`` is the query length (``input_hidden_states`` dim 1) and + ``max_cache_len`` is the KV buffer length (``past_keys_0`` dim 2). The + transformer-only export keeps both axes static, so these fully determine + whether the sub-model is prefill (``seq_len > 1``) or decode (``seq_len == 1``) + and the size of the fixed KV buffers the calibration feeds must match. + """ + import onnx + + model = onnx.load(str(onnx_path), load_external_data=False) + seq_len: int | None = None + max_cache_len: int | None = None + for inp in model.graph.input: + dims = inp.type.tensor_type.shape.dim + if inp.name == "input_hidden_states" and len(dims) >= 2: + seq_len = dims[1].dim_value + elif inp.name == "past_keys_0" and len(dims) >= 3: + max_cache_len = dims[2].dim_value + if seq_len is None or max_cache_len is None: + raise ValueError( + f"Could not read seq_len/max_cache_len from {onnx_path.name}; " + f"found seq_len={seq_len}, max_cache_len={max_cache_len}" + ) + return seq_len, max_cache_len + + +def _layer_kv(past: Any, i: int) -> tuple[torch.Tensor, torch.Tensor]: + """Extract layer ``i``'s (key, value) from an HF cache, version-agnostic. + + Handles the legacy tuple-of-tuples cache, the older ``DynamicCache`` + (``.key_cache`` / ``.value_cache``), and the newer per-layer + ``DynamicCache`` (``.layers[i].keys`` / ``.values``). + """ + if hasattr(past, "key_cache") and hasattr(past, "value_cache"): + return past.key_cache[i], past.value_cache[i] + if hasattr(past, "layers"): + layer = past.layers[i] + return layer.keys, layer.values + return past[i][0], past[i][1] + + class Qwen3TransformerOnlyCalibReader(CalibrationDataReader): - """Yields calibration feeds for the transformer-only ONNX. + """Prefill calibration feeds for the transformer-only ONNX. Feeds match the exported graph exactly: ``input_hidden_states`` (FP32), ``past_seq_len`` (INT32 ``[1,1]``), ``total_seq_len`` (INT32 ``[1]``), @@ -95,9 +177,7 @@ def _build_samples(self, token_ids_list: list[torch.Tensor]) -> Iterator[dict[st ids = ids[:, : self.seq_len] real_len = ids.shape[1] if real_len < self.seq_len: - pad = torch.zeros( - (1, self.seq_len - real_len), dtype=ids.dtype, device=ids.device - ) + pad = torch.zeros((1, self.seq_len - real_len), dtype=ids.dtype, device=ids.device) ids = torch.cat([ids, pad], dim=1) with torch.no_grad(): @@ -107,10 +187,10 @@ def _build_samples(self, token_ids_list: list[torch.Tensor]) -> Iterator[dict[st "input_hidden_states": embeds.astype(np.float32), # seqlens_k for GQA = (valid context length - 1), i.e. # ``embeddings.shape[1] - 1``. We pad to seq_len, so the query - # has seq_len valid positions → past_seq_len = seq_len - 1. + # has seq_len valid positions -> past_seq_len = seq_len - 1. # (Using 0 here declares only 1 valid token while feeding a # seq_len-token query, which makes the GQA prefill kernel read - # out of bounds → native access violation.) + # out of bounds -> native access violation.) "past_seq_len": np.array([[self.seq_len - 1]], dtype=np.int32), "total_seq_len": np.array([self.max_cache_len], dtype=np.int32), } @@ -122,30 +202,17 @@ def _build_samples(self, token_ids_list: list[torch.Tensor]) -> Iterator[dict[st yield feed def get_next(self) -> dict[str, np.ndarray] | None: + """Return the next calibration feed, or None when exhausted.""" try: return next(self._iter) if self._iter is not None else None except StopIteration: return None def rewind(self) -> None: + """Reset the iterator so calibration can run another pass.""" self._iter = iter(self._samples) -def _layer_kv(past: Any, i: int) -> tuple[torch.Tensor, torch.Tensor]: - """Extract layer ``i``'s (key, value) from an HF cache, version-agnostic. - - Handles the legacy tuple-of-tuples cache, the older ``DynamicCache`` - (``.key_cache`` / ``.value_cache``), and the newer per-layer - ``DynamicCache`` (``.layers[i].keys`` / ``.values``). - """ - if hasattr(past, "key_cache") and hasattr(past, "value_cache"): - return past.key_cache[i], past.value_cache[i] - if hasattr(past, "layers"): - layer = past.layers[i] - return layer.keys, layer.values - return past[i][0], past[i][1] - - class Qwen3DecodeTrajectoryCalibReader(CalibrationDataReader): """Calibrate the iter (seq_len=1) model on REAL decode-step states. @@ -174,7 +241,7 @@ def __init__( *, prefill_seq: int, max_cache_len: int, - decode_steps: int = 16, + decode_steps: int = DEFAULT_DECODE_STEPS, ) -> None: self.num_layers = config.num_hidden_layers self.num_kv_heads = config.num_key_value_heads @@ -199,9 +266,7 @@ def _kv_buffers(self, past: Any, cur_len: int) -> dict[str, np.ndarray]: feed: dict[str, np.ndarray] = {} for i in range(self.num_layers): k, v = _layer_kv(past, i) - kbuf = np.zeros( - (1, self.num_kv_heads, self.max_cache_len, self.head_dim), np.float16 - ) + kbuf = np.zeros((1, self.num_kv_heads, self.max_cache_len, self.head_dim), np.float16) vbuf = np.zeros_like(kbuf) kbuf[:, :, :cur_len, :] = k[:, :, :cur_len, :].to(torch.float16).cpu().numpy() vbuf[:, :, :cur_len, :] = v[:, :, :cur_len, :].to(torch.float16).cpu().numpy() @@ -256,179 +321,97 @@ def _build_samples( cur_len += 1 def get_next(self) -> dict[str, np.ndarray] | None: + """Return the next calibration feed, or None when exhausted.""" try: return next(self._iter) if self._iter is not None else None except StopIteration: return None def rewind(self) -> None: + """Reset the iterator so calibration can run another pass.""" self._iter = iter(self._samples) -def _tokenize_prompts(tokenizer: Any, prompts: list[str], num_samples: int) -> list[torch.Tensor]: - out: list[torch.Tensor] = [] - for i in range(num_samples): - prompt = prompts[i % len(prompts)] - text = tokenizer.apply_chat_template( - [{"role": "user", "content": prompt}], - tokenize=False, - add_generation_prompt=True, - enable_thinking=False, - ) - ids = tokenizer([text], return_tensors="pt").input_ids - out.append(ids) - return out - - -def _gqa_node_names(onnx_path: Path) -> list[str]: - """Return the names of every GroupQueryAttention node in ``onnx_path``. - - These nodes are excluded from quantization so ORT leaves both their - inputs and output in float (``... -> Cast -> GQA -> Cast``), matching - the reference graph which keeps attention entirely out of QDQ. - """ - import onnx - - model = onnx.load(str(onnx_path), load_external_data=False) - return [ - n.name - for n in model.graph.node - if n.op_type == "GroupQueryAttention" and n.name - ] - - -def quantize_built_model( - model: WinMLCompositeModel, +def finalize_transformer_only_quant_config( + quant: WinMLQuantizationConfig, *, + onnx_path: Path, model_id: str = DEFAULT_MODEL_ID, - max_cache_len: int = DEFAULT_MAX_CACHE, prefill_seq: int = DEFAULT_PREFILL_SEQ, - num_samples: int = DEFAULT_NUM_SAMPLES, - weight_type: str = "int8", - activation_type: str = "uint16", decode_steps: int = DEFAULT_DECODE_STEPS, -) -> dict[str, Path]: - """Quantize the transformer-only ONNX files in-place. - - Returns ``{sub_model_name: quantized_path}``. +) -> WinMLQuantizationConfig: + """Populate ``quant`` with the transformer-only w8a16 scheme + runtime fields. + + The build pipeline's device/precision policy only enables quantization and + picks generic dtypes; the transformer-only scheme is fixed and reference- + matched, so this hook is authoritative: + + - **int8-symmetric weights** (zp=0) + **uint16 asymmetric activations**, + - **MinMax** calibration, + - GroupQueryAttention nodes excluded from QDQ (read from the graph), + - the matching :class:`CalibrationDataReader` (prefill vs. decode-trajectory, + chosen by the graph's ``seq_len``). + + Reads static shapes + GQA nodes from ``onnx_path`` and loads a fresh FP + reference model for calibration (the export wrapper's own weights are + surgically replaced and can't run real inference). """ - # Locate the un-compiled ONNX for each sub-model (no surgery — file is - # already transformer-only). - sub_paths: dict[str, Path] = {} - for name, sub in model.sub_models.items(): - final_path = Path(sub._onnx_path) - if final_path.name.endswith("_model.onnx"): - stem = final_path.name[: -len("_model.onnx")] - optimized = final_path.with_name(f"{stem}_optimized.onnx") - if optimized.exists(): - sub_paths[name] = optimized - continue - print( - f"WARNING: {optimized.name} not found next to {final_path.name}; " - "falling back to the compiled model." - ) - sub_paths[name] = final_path - - for name, p in sub_paths.items(): - print(f" {name}: {p}") + onnx_path = Path(onnx_path) + seq_len, max_cache_len = _graph_shapes(onnx_path) + gqa_nodes = _gqa_node_names(onnx_path) + + # Fixed, reference-matched w8a16 scheme (authoritative over policy dtypes). + quant.weight_type = "int8" + quant.activation_type = "uint16" + quant.weight_symmetric = True + quant.activation_symmetric = False + quant.calibration_method = "minmax" + num_samples = quant.samples or DEFAULT_NUM_SAMPLES + + logger.info( + "Finalizing transformer-only quant config for %s " + "(seq_len=%d, max_cache_len=%d, %d GQA nodes excluded, %d samples)", + onnx_path.name, + seq_len, + max_cache_len, + len(gqa_nodes), + num_samples, + ) - print("\n=== Loading HF embed_tokens for calibration ===") - hf_model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float32) + hf_model = AutoModelForCausalLM.from_pretrained(model_id, dtype=torch.float32) hf_model.eval() embed_tokens = hf_model.get_input_embeddings() tokenizer = AutoTokenizer.from_pretrained(model_id) - - print( - f"=== Loading {num_samples} GSM8K calibration prompts " - f"({DEFAULT_CALIB_DATASET}/{DEFAULT_CALIB_DATASET_CONFIG}, " - f"split={DEFAULT_CALIB_SPLIT}, seed={DEFAULT_CALIB_SEED}) ===" - ) prompts = _load_gsm8k_prompts(num_samples) token_ids_list = _tokenize_prompts(tokenizer, prompts, num_samples) - seq_by_sub = { - "decoder_prefill": prefill_seq, - "decoder_gen": DEFAULT_GEN_SEQ, - } - - quant_paths: dict[str, Path] = {} - for sub_name, fused_path in sub_paths.items(): - if sub_name not in seq_by_sub: - print(f"\n--- Skipping unknown sub-model {sub_name!r} ---") - continue - - seq_len = seq_by_sub[sub_name] - quant_path = fused_path.with_name( - fused_path.stem - + f"_w{_DTYPE_BITS[weight_type]}a{_DTYPE_BITS[activation_type]}.quant.onnx" + reader: CalibrationDataReader + if seq_len == DEFAULT_GEN_SEQ: + # Decode sub-model: calibrate on a real prefill+decode trajectory. + reader = Qwen3DecodeTrajectoryCalibReader( + hf_model, + embed_tokens, + hf_model.config, + token_ids_list, + prefill_seq=prefill_seq, + max_cache_len=max_cache_len, + decode_steps=decode_steps, ) - - print(f"\n=== Quantize (transformer-only): {sub_name} (seq_len={seq_len}) ===") - print(f" in : {fused_path}") - print(f" out: {quant_path}") - gqa_nodes = _gqa_node_names(fused_path) - print( - f" excluding {len(gqa_nodes)} GroupQueryAttention nodes from " - "quantization (inputs + output stay float, Cast -> GQA -> Cast)" + else: + reader = Qwen3TransformerOnlyCalibReader( + embed_tokens, + hf_model.config, + token_ids_list, + seq_len=seq_len, + max_cache_len=max_cache_len, ) - if sub_name == "decoder_gen": - # The iter model only sees mid-generation states. Calibrate it on a - # real prefill+decode trajectory (true tokens, accumulated KV, - # growing past_seq_len) instead of one token + zeroed KV, which - # would under-range the MinMax activation scales and collapse - # generation. - print( - f" calibrating on decode trajectory ({decode_steps} steps/prompt, " - f"prefill_seq={prefill_seq})" - ) - reader: CalibrationDataReader = Qwen3DecodeTrajectoryCalibReader( - hf_model, - embed_tokens, - hf_model.config, - token_ids_list, - prefill_seq=prefill_seq, - max_cache_len=max_cache_len, - decode_steps=decode_steps, - ) - else: - reader = Qwen3TransformerOnlyCalibReader( - embed_tokens, - hf_model.config, - token_ids_list, - seq_len=seq_len, - max_cache_len=max_cache_len, - ) - cfg = WinMLQuantizationConfig( - samples=num_samples, - weight_type=weight_type, # type: ignore[arg-type] - activation_type=activation_type, # type: ignore[arg-type] - calibration_method="minmax", - calibration_data=reader, - # w8a16: symmetric int8 weights (zp=0) + asymmetric uint16 - # activations, matching the reference quantization. - weight_symmetric=True, - activation_symmetric=False, - # ORT treats GroupQueryAttention as quantizable and wraps both its - # inputs and output in QDQ. The reference keeps attention entirely - # in float (Cast -> GQA -> Cast), so exclude the GQA nodes from - # quantization so no QDQ is inserted around them. - nodes_to_exclude=gqa_nodes, - ) - result = quantize_onnx(fused_path, output_path=quant_path, config=cfg) - if not result.success: - print(" FAILED:") - for err in result.errors: - print(f" {err}") - raise SystemExit(1) - print( - f" ok — {result.nodes_quantized} QDQ nodes inserted in " - f"{result.total_time_seconds:.1f}s" - ) - quant_paths[sub_name] = quant_path - # Free the FP reference model now that calibration is done. + quant.calibration_data = reader + quant.nodes_to_exclude = gqa_nodes + + # Readers materialize all samples eagerly, so the FP reference is no longer + # needed once they're built. del hf_model, embed_tokens gc.collect() - print("\n=== Done ===") - return quant_paths + return quant diff --git a/test_qwen.py b/test_qwen.py deleted file mode 100644 index 14cf4656d..000000000 --- a/test_qwen.py +++ /dev/null @@ -1,235 +0,0 @@ -"""E2E test for the transformer-only Qwen3 export path. - -Produces two transformer-only ONNX files whose I/O matches -``qwen3_gqa_fp16_ctx.onnx`` / ``qwen3_gqa_fp16_iter.onnx``: - - decoder_prefill: input_hidden_states [1, 64, 1024] → output_hidden_states + KV - decoder_gen : input_hidden_states [1, 1, 1024] → output_hidden_states + KV - -with FP16 past/present KV named ``past_keys_{i}`` / ``past_values_{i}``, -``com.microsoft::GroupQueryAttention``, ``LpNormalization``, and 1x1 Conv -projections. - -Generation (``model.generate(...)``) is NOT supported by this build path — -the inference feeds in ``WinMLDecoderOnlyModel`` still target the eager -I/O signature. Use the eager ``WinMLQwen3Model`` build path for end-to-end -generation. - -Run:: - - python test_qwen.py - -This builds each transformer sub-model and then runs the w8a16 -quantization on the exported transformer ONNX files (no surgery needed — -files are already transformer-only). -""" - -import os -import sys -import pathlib -import subprocess - -# Put the in-repo `src/` ahead of site-packages so `import winml` always -# resolves to the editable source tree — no manual copy-to-venv needed. -_repo_root = pathlib.Path(__file__).resolve().parent -sys.path.insert(0, str(_repo_root / "src")) -sys.path.insert(0, str(_repo_root)) - -model_id = "Qwen/Qwen3-0.6B" -MAX_CACHE = 256 - -# component name -> (HF task, seq_len, artifact prefix). Order matters -# (prefill first). The prefix is how the built npu_ctx file is named so the -# parent can verify success by artifact appearance (the build segfaults on -# native QNN/ORT teardown AFTER writing the file, so exit codes are unreliable). -SUB_MODELS = { - "decoder_prefill": ("feature-extraction", 64, "feat_"), - "decoder_gen": ("text2text-generation", 1, "txt2txt_"), -} - -ARTIFACTS_DIR = ( - pathlib.Path.home() / ".cache" / "winml" / "artifacts" / model_id.replace("/", "_") -) - - -def _latest_ctx_mtime(prefix: str) -> float: - """Newest mtime of a ``{prefix}*_optimized_npu_ctx.onnx`` artifact, or 0.""" - files = list(ARTIFACTS_DIR.glob(f"{prefix}*_optimized_npu_ctx.onnx")) - return max((f.stat().st_mtime for f in files), default=0.0) - - -def _build_one(task: str, seq_len: int) -> None: - """Build a SINGLE transformer sub-model in this (fresh) process. - - Invoked as a subprocess by ``main()`` so each sub-model exports in a - clean interpreter — building both in one process leaves PyTorch/ORT - state from the first build that corrupts/kills the second. - """ - from winml.modelkit.config import WinMLBuildConfig - from winml.modelkit.models.auto import WinMLAutoModel - - WinMLAutoModel.from_pretrained( - model_id, - task=task, - model_type="qwen3_transformer_only", - config=WinMLBuildConfig(quant=None, compile=None), - precision="fp16", - device="npu", - ep="qnn", - force_rebuild=True, - shape_config={"max_cache_len": MAX_CACHE, "seq_len": seq_len}, - ) - # The QNN/ORT teardown segfaults (0xC0000005) on interpreter shutdown - # AFTER the artifact is fully written. Skip the buggy cleanup with a hard - # exit so the parent sees a clean exit code 0. - print(f"BUILD COMPLETE: task={task} seq_len={seq_len}", flush=True) - sys.stdout.flush() - sys.stderr.flush() - # TODO(winml-cli#836): replace the hard exit once the native QNN/ORT - # teardown segfault (0xC0000005) on interpreter shutdown is fixed upstream. - os._exit(0) - - -def _find_optimized(prefix: str) -> pathlib.Path: - """Locate the cached transformer-only ``{prefix}*_optimized.onnx`` file.""" - cands = [ - p for p in ARTIFACTS_DIR.glob(f"{prefix}*_optimized.onnx") - if not p.name.endswith("_optimized_npu_ctx.onnx") - ] - if not cands: - raise FileNotFoundError( - f"No {prefix}*_optimized.onnx in {ARTIFACTS_DIR} — build the sub-model first." - ) - return max(cands, key=lambda p: p.stat().st_mtime) - - -class _SubShim: - """Minimal stand-in exposing the ``_onnx_path`` quant needs.""" - - def __init__(self, onnx_path: pathlib.Path): - self._onnx_path = str(onnx_path) - - -class _ModelShim: - """Minimal stand-in exposing ``sub_models`` for ``quantize_built_model``.""" - - def __init__(self, sub_models: dict): - self.sub_models = sub_models - - -def _run_quant() -> None: - """Quantize the cached transformer ONNX files (no composite/QNN load). - - Runs as its own subprocess so any ORT teardown crash can't poison the - parent. Builds a shim ``model`` whose ``sub_models[name]._onnx_path`` - point straight at the cached ``*_optimized.onnx`` files. - """ - # Dump a native C-stack if the calibration InferenceSession segfaults - # (otherwise the crash is silent — no Python traceback). - import faulthandler - faulthandler.enable() - - from qwen3_transformer_only_quantize import quantize_built_model - - sub_models = { - name: _SubShim(_find_optimized(prefix)) - for name, (_task, _seq, prefix) in SUB_MODELS.items() - } - model = _ModelShim(sub_models) - print("=== Running transformer w8a16 quantization ===", flush=True) - for name, sub in sub_models.items(): - print(f" {name}: {sub._onnx_path}", flush=True) - - try: - quantize_built_model( - model, - model_id=model_id, - max_cache_len=MAX_CACHE, - prefill_seq=64, - ) - except BaseException: - import traceback - print("QUANT FAILED with exception:", flush=True) - traceback.print_exc() - sys.stdout.flush() - sys.stderr.flush() - raise - print("QUANT COMPLETE", flush=True) - sys.stdout.flush() - sys.stderr.flush() - # TODO(winml-cli#836): replace the hard exit once the native QNN/ORT - # teardown segfault (0xC0000005) on interpreter shutdown is fixed upstream. - os._exit(0) - - -def main() -> None: - # 1) Build each sub-model in its OWN subprocess (fresh state each time). - # Judge success by whether a FRESH npu_ctx artifact appeared, NOT by the - # subprocess exit code: the native QNN/ORT layer segfaults (0xC0000005) - # on teardown AFTER the artifact is fully written to disk. - import time as _time - - for name, (task, seq_len, prefix) in SUB_MODELS.items(): - print(f"\n########## BUILD {name} (task={task}, seq_len={seq_len}) ##########", flush=True) - before = _latest_ctx_mtime(prefix) - start = _time.time() - rc = subprocess.run( - [sys.executable, "-u", str(pathlib.Path(__file__).resolve()), - "--build-sub", task, str(seq_len)], - cwd=str(_repo_root), - timeout=1800, - ).returncode - - after = _latest_ctx_mtime(prefix) - if after > before and after >= start - 1: - status = "OK" if rc == 0 else f"OK (ignored teardown exit {rc})" - print(f"########## {name} {status}: fresh {prefix}*_optimized_npu_ctx.onnx ##########", flush=True) - else: - raise SystemExit( - f"Sub-model build failed for {name} (exit {rc}) — " - f"no fresh {prefix}*_optimized_npu_ctx.onnx in {ARTIFACTS_DIR}" - ) - - # 2) Report the built transformer-only ONNX files (no composite/QNN load — - # that creates QNN EP sessions that segfault the parent on teardown). - for name, (_task, _seq, prefix) in SUB_MODELS.items(): - print(f"\n=== {name} ===") - print(f" optimized : {_find_optimized(prefix).name}") - ctx = sorted(ARTIFACTS_DIR.glob(f"{prefix}*_optimized_npu_ctx.onnx")) - if ctx: - print(f" npu_ctx : {ctx[-1].name}") - - # 3) Quantization — run in its OWN subprocess for the same teardown-crash - # isolation. Judge by whether quant files appeared. - print("\n########## QUANTIZE ##########", flush=True) - before = max( - (p.stat().st_mtime for p in ARTIFACTS_DIR.glob("*quant.onnx")), - default=0.0, - ) - qstart = _time.time() - rc = subprocess.run( - [sys.executable, "-u", str(pathlib.Path(__file__).resolve()), "--quant"], - cwd=str(_repo_root), - timeout=1800, - ).returncode - after_files = list(ARTIFACTS_DIR.glob("*quant.onnx")) - after = max((p.stat().st_mtime for p in after_files), default=0.0) - if after > before and after >= qstart - 1: - status = "OK" if rc == 0 else f"OK (ignored teardown exit {rc})" - print(f"########## QUANTIZE {status} ##########", flush=True) - for p in sorted(after_files, key=lambda x: x.stat().st_mtime)[-len(SUB_MODELS):]: - print(f" {p.name}", flush=True) - else: - raise SystemExit( - f"Quantization failed (exit {rc}) — no fresh *quant.onnx in {ARTIFACTS_DIR}" - ) - - -if __name__ == "__main__": - if len(sys.argv) >= 4 and sys.argv[1] == "--build-sub": - _build_one(sys.argv[2], int(sys.argv[3])) - elif len(sys.argv) >= 2 and sys.argv[1] == "--quant": - _run_quant() - else: - main() - diff --git a/tests/e2e/models/test_qwen3_transformer_only_quant.py b/tests/e2e/models/test_qwen3_transformer_only_quant.py new file mode 100644 index 000000000..7c6499e51 --- /dev/null +++ b/tests/e2e/models/test_qwen3_transformer_only_quant.py @@ -0,0 +1,248 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +"""End-to-end coverage for the transformer-only Qwen3 w8a16 build. + +Replaces the former root-level ``test_qwen.py`` / ``qwen3_transformer_only_quantize.py`` +scripts. Quantization is now driven entirely through the standard build +pipeline (``WinMLAutoModel.from_pretrained(..., precision="w8a16")``): the +device/precision policy enables the quantize stage, and +``QwenTransformerOnlyDecoderWrapper.winml_finalize_quant_config`` finalizes the +reference-matched scheme (int8-symmetric weights, uint16 activations, +GroupQueryAttention excluded from QDQ) plus the decode-trajectory calibration +reader. + +These tests download Qwen3-0.6B from HuggingFace and run a full CPU export + +quantize, so they are gated behind ``slow`` + ``network`` and excluded from the +default lane. The QNN/NPU build is additionally gated on a real NPU. + +All expectations are generated in-code (FP reference greedy decode), never +hardcoded from a prior model run. +""" + +from __future__ import annotations + +import numpy as np +import onnx +import onnxruntime as ort +import pytest +import torch +from transformers import AutoModelForCausalLM, AutoTokenizer + +from winml.modelkit.config import WinMLBuildConfig +from winml.modelkit.models.auto import WinMLAutoModel +from winml.modelkit.quant import WinMLQuantizationConfig + + +pytestmark = [pytest.mark.e2e, pytest.mark.slow, pytest.mark.network] + +MODEL_ID = "Qwen/Qwen3-0.6B" +MAX_CACHE = 256 +PARITY_TOKENS = 8 +DECODE_STEPS = 12 +# Keep CPU calibration cheap: the decode reader emits ``samples * 16`` feeds. +CALIB_SAMPLES = 4 + + +def _qnn_available() -> bool: + """True when ONNX Runtime exposes the QNN execution provider (real NPU).""" + return "QNNExecutionProvider" in ort.get_available_providers() + + +def _decoder_onnx_path(model) -> str: + """Locate the quantized decode ONNX behind the composite handle.""" + sub = model.sub_models["decoder_gen"] + return str(sub._onnx_path) + + +def _qdq_counts(onnx_path: str) -> dict[str, int]: + graph = onnx.load(onnx_path, load_external_data=False).graph + counts: dict[str, int] = {} + for node in graph.node: + counts[node.op_type] = counts.get(node.op_type, 0) + 1 + return counts + + +def _gqa_tensor_set(graph) -> set[str]: + tensors: set[str] = set() + for node in graph.node: + if node.op_type == "GroupQueryAttention": + tensors.update(node.input) + tensors.update(node.output) + return tensors + + +@pytest.fixture(scope="module") +def decode_quant_model(tmp_path_factory): + """Build + quantize the decode (seq_len=1) sub-model once on CPU.""" + cache_dir = tmp_path_factory.mktemp("qwen3_w8a16") + return WinMLAutoModel.from_pretrained( + MODEL_ID, + task="text2text-generation", + model_type="qwen3_transformer_only", + config=WinMLBuildConfig(quant=WinMLQuantizationConfig(samples=CALIB_SAMPLES)), + precision="w8a16", + device="cpu", + ep="cpu", + force_rebuild=True, + shape_config={"max_cache_len": MAX_CACHE, "seq_len": 1}, + cache_dir=str(cache_dir), + ) + + +@pytest.mark.timeout(2400) +def test_decode_model_is_quantized_with_gqa_excluded(decode_quant_model): + onnx_path = _decoder_onnx_path(decode_quant_model) + counts = _qdq_counts(onnx_path) + + # QDQ nodes were inserted via the config-driven pipeline. + assert counts.get("QuantizeLinear", 0) > 0 + assert counts.get("DequantizeLinear", 0) > 0 + # GroupQueryAttention survives in float (not quantized away). + assert counts.get("GroupQueryAttention", 0) > 0 + + # GQA exclusion contract: no QuantizeLinear/DequantizeLinear touches a GQA + # input or output tensor (attention stays Cast -> GQA -> Cast). + graph = onnx.load(onnx_path, load_external_data=False).graph + gqa_tensors = _gqa_tensor_set(graph) + touching = [ + node.name + for node in graph.node + if node.op_type in ("QuantizeLinear", "DequantizeLinear") + and (set(node.input) & gqa_tensors or set(node.output) & gqa_tensors) + ] + assert touching == [] + + +def _carry_kv(kv: dict[str, np.ndarray], out: dict[str, np.ndarray], num_layers: int) -> None: + for i in range(num_layers): + kv[f"past_keys_{i}"] = out[f"present_keys_{i}"] + kv[f"past_values_{i}"] = out[f"present_values_{i}"] + + +def _seed_kv_from_fp(past, num_layers, num_kv_heads, head_dim, cur_len): + """Copy an HF FP prefill cache into the decode model's fixed FP16 buffers.""" + kv: dict[str, np.ndarray] = {} + for i in range(num_layers): + layer = past[i] if not hasattr(past, "layers") else None + if layer is not None: + k, v = past[i][0], past[i][1] + else: # newer per-layer DynamicCache + k, v = past.layers[i].keys, past.layers[i].values + kbuf = np.zeros((1, num_kv_heads, MAX_CACHE, head_dim), np.float16) + vbuf = np.zeros_like(kbuf) + kbuf[:, :, :cur_len, :] = k[:, :, :cur_len, :].to(torch.float16).cpu().numpy() + vbuf[:, :, :cur_len, :] = v[:, :, :cur_len, :].to(torch.float16).cpu().numpy() + kv[f"past_keys_{i}"] = kbuf + kv[f"past_values_{i}"] = vbuf + return kv + + +@pytest.mark.timeout(2400) +def test_decode_parity_against_fp_reference(decode_quant_model): + """The w8a16 decode model must track the FP reference token-for-token. + + This is the regression guard against the historical "decode collapse": + a degenerate calibration (single repeated token + zeroed KV) made the + quantized decode model diverge into garbage after ~1 token. With the + decode-trajectory reader the quantized greedy trajectory must match the + FP reference for the first ``PARITY_TOKENS`` tokens. + """ + onnx_path = _decoder_onnx_path(decode_quant_model) + session = ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"]) + want = {i.name for i in session.get_inputs()} + + hf = AutoModelForCausalLM.from_pretrained(MODEL_ID, dtype=torch.float32) + hf.eval() + cfg = hf.config + embed = hf.get_input_embeddings() + lm_head = hf.lm_head + tokenizer = AutoTokenizer.from_pretrained(MODEL_ID) + num_layers = cfg.num_hidden_layers + num_kv_heads = cfg.num_key_value_heads + head_dim = getattr(cfg, "head_dim", cfg.hidden_size // cfg.num_attention_heads) + + text = tokenizer.apply_chat_template( + [{"role": "user", "content": "What is the capital of France?"}], + tokenize=False, + add_generation_prompt=True, + enable_thinking=False, + ) + ids = tokenizer([text], return_tensors="pt").input_ids + cur_len = ids.shape[1] + assert cur_len < MAX_CACHE + + # --- FP reference greedy decode (generates the expected tokens) --- + with torch.no_grad(): + out = hf(input_ids=ids, use_cache=True) + fp_past = out.past_key_values + first_tok = int(out.logits[:, -1, :].argmax(-1)) + fp_tokens: list[int] = [] + tok, past = first_tok, fp_past + for _ in range(DECODE_STEPS): + with torch.no_grad(): + out = hf(input_ids=torch.tensor([[tok]]), past_key_values=past, use_cache=True) + past = out.past_key_values + tok = int(out.logits[:, -1, :].argmax(-1)) + fp_tokens.append(tok) + + # --- Quantized decode model greedy decode (own KV, FP embed + lm_head) --- + with torch.no_grad(): + seed = hf(input_ids=ids, use_cache=True) + kv = _seed_kv_from_fp(seed.past_key_values, num_layers, num_kv_heads, head_dim, cur_len) + quant_tokens: list[int] = [] + tok, past_len = first_tok, cur_len + for _ in range(DECODE_STEPS): + with torch.no_grad(): + emb = embed(torch.tensor([[tok]])).to(torch.float32).cpu().numpy() + feeds = { + "input_hidden_states": emb.astype(np.float32), + "past_seq_len": np.array([[past_len]], np.int32), + "total_seq_len": np.array([MAX_CACHE], np.int32), + **kv, + } + feeds = {k: v for k, v in feeds.items() if k in want} + names = [o.name for o in session.get_outputs()] + outs = dict(zip(names, session.run(None, feeds), strict=False)) + _carry_kv(kv, outs, num_layers) + hidden = torch.tensor(outs["output_hidden_states"][:, 0, :]) + with torch.no_grad(): + tok = int(lm_head(hidden).numpy()[0].argmax()) + quant_tokens.append(tok) + past_len += 1 + + assert quant_tokens[:PARITY_TOKENS] == fp_tokens[:PARITY_TOKENS], ( + f"w8a16 decode diverged from FP reference:\n" + f" fp : {fp_tokens[:PARITY_TOKENS]}\n" + f" quant: {quant_tokens[:PARITY_TOKENS]}" + ) + + +@pytest.mark.npu +@pytest.mark.qnn +@pytest.mark.timeout(2400) +@pytest.mark.skipif(not _qnn_available(), reason="requires QNN execution provider (NPU)") +@pytest.mark.parametrize( + ("task", "seq_len"), + [("feature-extraction", 64), ("text2text-generation", 1)], +) +def test_npu_build_quantizes(task, seq_len, tmp_path): + """On real NPU hardware, the w8a16 pipeline produces a quantized graph.""" + model = WinMLAutoModel.from_pretrained( + MODEL_ID, + task=task, + model_type="qwen3_transformer_only", + precision="w8a16", + device="npu", + ep="qnn", + no_compile=True, + force_rebuild=True, + shape_config={"max_cache_len": MAX_CACHE, "seq_len": seq_len}, + cache_dir=str(tmp_path), + ) + sub_name = "decoder_prefill" if seq_len == 64 else "decoder_gen" + onnx_path = str(model.sub_models[sub_name]._onnx_path) + counts = _qdq_counts(onnx_path) + assert counts.get("QuantizeLinear", 0) > 0 + assert counts.get("GroupQueryAttention", 0) > 0 diff --git a/tests/unit/models/qwen_transformer_only/__init__.py b/tests/unit/models/qwen_transformer_only/__init__.py new file mode 100644 index 000000000..862c45ce3 --- /dev/null +++ b/tests/unit/models/qwen_transformer_only/__init__.py @@ -0,0 +1,4 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- diff --git a/tests/unit/models/qwen_transformer_only/test_quant_calibration.py b/tests/unit/models/qwen_transformer_only/test_quant_calibration.py new file mode 100644 index 000000000..f1b160433 --- /dev/null +++ b/tests/unit/models/qwen_transformer_only/test_quant_calibration.py @@ -0,0 +1,234 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +"""Unit tests for the transformer-only Qwen3 quant calibration readers. + +These are fast, offline tests (no model download, no ONNX Runtime): they +exercise the graph-shape introspection, GroupQueryAttention node discovery, +and the exact feed contract (names / dtypes / shapes) the two calibration +readers must satisfy. All expectations are derived in-code from the inputs, +never hardcoded from a model run. +""" + +from __future__ import annotations + +from types import SimpleNamespace + +import numpy as np +import onnx +import torch +from onnx import TensorProto, helper + +from winml.modelkit.models.hf.qwen_transformer_only_quant import ( + Qwen3DecodeTrajectoryCalibReader, + Qwen3TransformerOnlyCalibReader, + _gqa_node_names, + _graph_shapes, +) + + +NUM_LAYERS = 2 +NUM_KV_HEADS = 2 +HEAD_DIM = 4 +HIDDEN = NUM_KV_HEADS * HEAD_DIM +VOCAB = 16 + + +def _fake_config() -> SimpleNamespace: + return SimpleNamespace( + num_hidden_layers=NUM_LAYERS, + num_key_value_heads=NUM_KV_HEADS, + head_dim=HEAD_DIM, + hidden_size=HIDDEN, + num_attention_heads=NUM_KV_HEADS, + ) + + +def _build_tiny_onnx(path, *, seq_len: int, max_cache_len: int) -> None: + """Write a minimal graph carrying the inputs the readers introspect.""" + inputs = [ + helper.make_tensor_value_info( + "input_hidden_states", TensorProto.FLOAT, [1, seq_len, HIDDEN] + ), + helper.make_tensor_value_info( + "past_keys_0", TensorProto.FLOAT16, [1, NUM_KV_HEADS, max_cache_len, HEAD_DIM] + ), + ] + out = helper.make_tensor_value_info( + "output_hidden_states", TensorProto.FLOAT, [1, seq_len, HIDDEN] + ) + gqa = helper.make_node( + "GroupQueryAttention", + ["input_hidden_states"], + ["attn_out"], + name="gqa_layer_0", + domain="com.microsoft", + ) + identity = helper.make_node("Identity", ["attn_out"], ["output_hidden_states"]) + graph = helper.make_graph([gqa, identity], "tiny", inputs, [out]) + model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 18)]) + onnx.save(model, str(path)) + + +def test_graph_shapes_and_gqa_nodes(tmp_path): + p = tmp_path / "tiny.onnx" + _build_tiny_onnx(p, seq_len=1, max_cache_len=16) + + assert _graph_shapes(p) == (1, 16) + assert _gqa_node_names(p) == ["gqa_layer_0"] + + +def test_graph_shapes_prefill(tmp_path): + p = tmp_path / "tiny_prefill.onnx" + _build_tiny_onnx(p, seq_len=64, max_cache_len=256) + + assert _graph_shapes(p) == (64, 256) + + +def _drain(reader) -> list[dict[str, np.ndarray]]: + feeds = [] + while (feed := reader.get_next()) is not None: + feeds.append(feed) + return feeds + + +def test_prefill_reader_feed_contract(): + seq_len, max_cache_len = 4, 16 + embed = torch.nn.Embedding(VOCAB, HIDDEN) + token_ids = [torch.tensor([[1, 2, 3, 4, 5]])] + + reader = Qwen3TransformerOnlyCalibReader( + embed, + _fake_config(), + token_ids, + seq_len=seq_len, + max_cache_len=max_cache_len, + ) + feeds = _drain(reader) + + assert len(feeds) == len(token_ids) + feed = feeds[0] + + # input_hidden_states: FP32, truncated to seq_len. + assert feed["input_hidden_states"].dtype == np.float32 + assert feed["input_hidden_states"].shape == (1, seq_len, HIDDEN) + + # seqlens_k contract: past_seq_len = seq_len - 1 (INT32 [1,1]). + assert feed["past_seq_len"].dtype == np.int32 + np.testing.assert_array_equal(feed["past_seq_len"], [[seq_len - 1]]) + + # total_seq_len: full cache (INT32 [1]). + assert feed["total_seq_len"].dtype == np.int32 + np.testing.assert_array_equal(feed["total_seq_len"], [max_cache_len]) + + # KV buffers: FP16, full cache shape, present for every layer. + for i in range(NUM_LAYERS): + for prefix in ("past_keys_", "past_values_"): + kv = feed[f"{prefix}{i}"] + assert kv.dtype == np.float16 + assert kv.shape == (1, NUM_KV_HEADS, max_cache_len, HEAD_DIM) + + # rewind() replays the same samples. + reader.rewind() + assert len(_drain(reader)) == len(token_ids) + + +def test_prefill_reader_pads_short_prompts(): + seq_len = 6 # longer than the 3-token prompt -> must pad + embed = torch.nn.Embedding(VOCAB, HIDDEN) + token_ids = [torch.tensor([[1, 2, 3]])] + + reader = Qwen3TransformerOnlyCalibReader( + embed, _fake_config(), token_ids, seq_len=seq_len, max_cache_len=16 + ) + feed = _drain(reader)[0] + assert feed["input_hidden_states"].shape == (1, seq_len, HIDDEN) + + +class _StubCausalLM: + """Minimal HF-like model: grows a tuple-of-tuples KV cache by 1 each call. + + Always predicts ``next_token`` so the trajectory is deterministic. + """ + + def __init__(self, next_token: int) -> None: + self.next_token = next_token + + def _cache(self, length: int): + return tuple( + ( + torch.randn(1, NUM_KV_HEADS, length, HEAD_DIM), + torch.randn(1, NUM_KV_HEADS, length, HEAD_DIM), + ) + for _ in range(NUM_LAYERS) + ) + + def __call__(self, input_ids=None, past_key_values=None, use_cache=True): + if past_key_values is None: + length = input_ids.shape[1] + query_len = length + else: + length = past_key_values[0][0].shape[2] + input_ids.shape[1] + query_len = input_ids.shape[1] + logits = torch.full((1, query_len, VOCAB), -10.0) + logits[..., self.next_token] = 10.0 + return SimpleNamespace(past_key_values=self._cache(length), logits=logits) + + +def test_decode_trajectory_reader_grows_past_seq_len(): + prefill_seq, decode_steps, max_cache_len = 2, 3, 16 + embed = torch.nn.Embedding(VOCAB, HIDDEN) + hf_model = _StubCausalLM(next_token=5) + token_ids = [torch.tensor([[1, 2, 3, 4]])] # truncated to prefill_seq=2 + + reader = Qwen3DecodeTrajectoryCalibReader( + hf_model, + embed, + _fake_config(), + token_ids, + prefill_seq=prefill_seq, + max_cache_len=max_cache_len, + decode_steps=decode_steps, + ) + feeds = _drain(reader) + + assert len(feeds) == len(token_ids) * decode_steps + + # past_seq_len must grow monotonically from prefill_seq (real decode), not + # stay pinned at 0 like the degenerate single-token reader. + seq_lens = [int(f["past_seq_len"][0, 0]) for f in feeds] + assert seq_lens == [prefill_seq, prefill_seq + 1, prefill_seq + 2] + + for f in feeds: + # One token per decode step. + assert f["input_hidden_states"].shape == (1, 1, HIDDEN) + assert f["input_hidden_states"].dtype == np.float32 + cur_len = int(f["past_seq_len"][0, 0]) + for i in range(NUM_LAYERS): + kv = f[f"past_keys_{i}"] + assert kv.dtype == np.float16 + assert kv.shape == (1, NUM_KV_HEADS, max_cache_len, HEAD_DIM) + # Positions beyond the valid context stay zero-padded. + assert np.all(kv[:, :, cur_len:, :] == 0) + + +def test_decode_trajectory_reader_respects_max_cache(): + prefill_seq, decode_steps, max_cache_len = 4, 10, 6 + embed = torch.nn.Embedding(VOCAB, HIDDEN) + hf_model = _StubCausalLM(next_token=2) + token_ids = [torch.tensor([[1, 2, 3, 4, 5, 6]])] + + reader = Qwen3DecodeTrajectoryCalibReader( + hf_model, + embed, + _fake_config(), + token_ids, + prefill_seq=prefill_seq, + max_cache_len=max_cache_len, + decode_steps=decode_steps, + ) + feeds = _drain(reader) + # Trajectory must stop once the cache is full (cur_len reaches max_cache_len). + assert len(feeds) == max_cache_len - prefill_seq + assert max(int(f["past_seq_len"][0, 0]) for f in feeds) == max_cache_len - 1 From a7f518e6ea61138a128eed8473a3fc89783caa8c Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Wed, 24 Jun 2026 13:42:32 +0800 Subject: [PATCH 08/17] fix(qwen): clean lint + persist finalized quant config + guard dynamic shapes - Add missing docstrings / return-type annotations and drop dead noqa directives across qwen3_export_ops.py, qwen3_modeling.py and the transformer-only registration so 'ruff check src/ tests/' (CI lint) passes. - build/hf.py: re-persist config.json after winml_finalize_quant_config runs, so the saved metadata reflects the actually-applied w8a16 scheme (int8/uint16/symmetry + GQA nodes_to_exclude) rather than the pre-finalize policy dtypes. - qwen_transformer_only_quant._graph_shapes: treat a non-positive dim_value (symbolic/dynamic axis) as a hard error instead of silently returning a zero-length shape. --- src/winml/modelkit/build/hf.py | 4 ++ src/winml/modelkit/models/hf/__init__.py | 3 +- .../modelkit/models/hf/qwen3_export_ops.py | 62 +++++++++++++------ .../modelkit/models/hf/qwen3_modeling.py | 20 +++--- .../models/hf/qwen_transformer_only_quant.py | 7 ++- 5 files changed, 66 insertions(+), 30 deletions(-) diff --git a/src/winml/modelkit/build/hf.py b/src/winml/modelkit/build/hf.py index 2faa3eec2..08698d125 100644 --- a/src/winml/modelkit/build/hf.py +++ b/src/winml/modelkit/build/hf.py @@ -325,6 +325,10 @@ def _name(base: str) -> str: config.quant = pytorch_model.winml_finalize_quant_config( config.quant, onnx_path=current_path, model_id=model_id ) + # The hook may overwrite the quant scheme (dtypes, symmetry, + # nodes-to-exclude) authoritatively, so re-persist the config + # to keep config.json consistent with what was actually applied. + config_path.write_text(json.dumps(config.to_dict(), indent=2)) quant_result = quantize_onnx( model_path=current_path, output_path=quantized_path, diff --git a/src/winml/modelkit/models/hf/__init__.py b/src/winml/modelkit/models/hf/__init__.py index 0d2e538a3..458bc8e34 100644 --- a/src/winml/modelkit/models/hf/__init__.py +++ b/src/winml/modelkit/models/hf/__init__.py @@ -62,7 +62,8 @@ QwenTransformerOnlyGenIOConfig as _QwenTransformerOnlyGenIOConfig, # triggers registration ) from .qwen_transformer_only import ( - QwenTransformerOnlyPrefillIOConfig as _QwenTransformerOnlyPrefillIOConfig, # triggers registration + # triggers registration + QwenTransformerOnlyPrefillIOConfig as _QwenTransformerOnlyPrefillIOConfig, ) from .roberta import ROBERTA_FAMILY_CONFIG from .roberta import RobertaIOConfig as _RobertaIOConfig # triggers registration diff --git a/src/winml/modelkit/models/hf/qwen3_export_ops.py b/src/winml/modelkit/models/hf/qwen3_export_ops.py index 5fd3edb68..e1eba87c3 100644 --- a/src/winml/modelkit/models/hf/qwen3_export_ops.py +++ b/src/winml/modelkit/models/hf/qwen3_export_ops.py @@ -2,8 +2,7 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. # -------------------------------------------------------------------------- -"""Custom ONNX export ops + the entry point that reshapes HF's Qwen3 modules -for the transformer-only export. +"""Custom ONNX export ops that reshape HF's Qwen3 modules for export. These reshape the standard HF Qwen3 modules so winml-cli can produce a QNN-friendly, transformer-only graph: @@ -20,6 +19,8 @@ from __future__ import annotations +from typing import Any + import torch import torch.nn as nn from torch.onnx import symbolic_helper @@ -34,7 +35,8 @@ class LpNormOnnxExport(torch.autograd.Function): """RMSNorm body → ONNX ``LpNormalization`` (p=2 along last dim).""" @staticmethod - def symbolic(g, input, axis, p): # noqa: D401 + def symbolic(g, input, axis, p) -> Any: + """Emit the ONNX ``LpNormalization`` node during export.""" output_type = input.type().with_sizes(symbolic_helper._get_tensor_sizes(input)) output = g.op( "onnx::LpNormalization", @@ -45,12 +47,14 @@ def symbolic(g, input, axis, p): # noqa: D401 return output.setType(output_type) @staticmethod - def forward(ctx, input, axis, p): # noqa: ARG004 - # Shape-only tracing placeholder. The real op is emitted by - # ``symbolic`` during ONNX export; ``forward`` exists solely so the - # TorchScript exporter (and Optimum's pre-export dry run) can trace - # output shapes. It returns ``input`` unchanged on purpose and is NOT a - # correct eager RMSNorm — do not call this module for real inference. + def forward(ctx, input, axis, p) -> Any: + """Shape-only tracing placeholder; returns ``input`` unchanged. + + The real op is emitted by ``symbolic`` during ONNX export; ``forward`` + exists solely so the TorchScript exporter (and Optimum's pre-export dry + run) can trace output shapes. It is NOT a correct eager RMSNorm — do + not call this module for real inference. + """ return input @@ -72,8 +76,19 @@ def symbolic( do_rotary, kv_num_heads, num_heads, - ): - args = [query, key, value, past_key, past_value, seqlens_k, total_sequence_length, cos_cache, sin_cache] + ) -> Any: + """Emit the fused ``com.microsoft::GroupQueryAttention`` node.""" + args = [ + query, + key, + value, + past_key, + past_value, + seqlens_k, + total_sequence_length, + cos_cache, + sin_cache, + ] attention_output, present_keys, present_values = g.op( "com.microsoft::GroupQueryAttention", *args, @@ -85,8 +100,12 @@ def symbolic( query_sizes = symbolic_helper._get_tensor_sizes(query) attention_output.setType(query.type().with_sizes(query_sizes)) - present_keys.setType(past_key.type().with_sizes(symbolic_helper._get_tensor_sizes(past_key))) - present_values.setType(past_value.type().with_sizes(symbolic_helper._get_tensor_sizes(past_value))) + present_keys.setType( + past_key.type().with_sizes(symbolic_helper._get_tensor_sizes(past_key)) + ) + present_values.setType( + past_value.type().with_sizes(symbolic_helper._get_tensor_sizes(past_value)) + ) return attention_output, present_keys, present_values @staticmethod @@ -104,13 +123,14 @@ def forward( do_rotary, kv_num_heads, num_heads, - ): # noqa: ARG004 - # Shape-only tracing placeholder. The real op is emitted by - # ``symbolic`` during ONNX export; ``forward`` exists solely so the - # TorchScript exporter (and Optimum's pre-export dry run) can trace - # output shapes. It returns the inputs as stand-in present-KV on - # purpose and is NOT correct attention — do not call this module for - # real inference. + ) -> Any: + """Shape-only tracing placeholder; returns stand-in (output, KV). + + The real op is emitted by ``symbolic`` during ONNX export; ``forward`` + exists solely so the TorchScript exporter (and Optimum's pre-export dry + run) can trace output shapes. It is NOT correct attention — do not call + this module for real inference. + """ return query, past_key, past_value # placeholder shapes @@ -135,6 +155,7 @@ def __init__( self.bias = bias def forward(self, x: torch.Tensor) -> torch.Tensor: + """Apply the 1x1 conv with NHWC<->NCHW permutes (+ optional bias).""" x = x.permute(0, 3, 1, 2) # NHWC -> NCHW x = torch.nn.functional.conv2d(x, self.weight) x = x.permute(0, 2, 3, 1) # NCHW -> NHWC @@ -144,6 +165,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: @classmethod def from_linear_module(cls, linear: nn.Linear) -> TransposeConv2d1x1Transpose: + """Build a 1x1-conv replacement from an existing ``nn.Linear``.""" return cls(linear.in_features, linear.out_features, linear.weight, linear.bias) diff --git a/src/winml/modelkit/models/hf/qwen3_modeling.py b/src/winml/modelkit/models/hf/qwen3_modeling.py index d3c538df5..f5207d797 100644 --- a/src/winml/modelkit/models/hf/qwen3_modeling.py +++ b/src/winml/modelkit/models/hf/qwen3_modeling.py @@ -41,17 +41,17 @@ class WinMLQwen3RMSNorm(nn.Module): """RMSNorm export variant — ``onnx::LpNormalization`` body.""" def prepare_for_onnx_export(self) -> None: + """Fold the RMSNorm gain into the weight (LpNorm has unit gain).""" # Pre-multiply the gain into the weight (LpNorm has unit gain). # ``scale`` is shape ``[1]`` and broadcasts over ``self.weight`` # (shape ``[hidden_size]``), so the result keeps the per-channel # shape even when the original weights are all ones. n = self.weight.numel() - scale = torch.sqrt( - torch.tensor([n], device=self.weight.device, dtype=self.weight.dtype) - ) + scale = torch.sqrt(torch.tensor([n], device=self.weight.device, dtype=self.weight.dtype)) self.weight = nn.Parameter(scale * self.weight) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + """Apply the LpNormalization-based RMSNorm body.""" out = LpNormOnnxExport.apply(hidden_states, -1, 2) return self.weight * out @@ -60,6 +60,7 @@ class WinMLQwen3MLP(nn.Module): """MLP export variant — 1x1 Conv projections (forward unchanged).""" def prepare_for_onnx_export(self, *, matmul_to_conv: bool) -> None: + """Optionally swap the MLP's linear projections for 1x1 convs.""" if not matmul_to_conv: return self.gate_proj = TransposeConv2d1x1Transpose.from_linear_module(self.gate_proj) @@ -71,12 +72,13 @@ class WinMLQwen3Attention(nn.Module): """Attention export variant — fused ``GroupQueryAttention`` op.""" def prepare_for_onnx_export(self, *, matmul_to_conv: bool) -> None: + """Optionally swap the Q/K/V/O projections for 1x1 convs.""" if matmul_to_conv: self.q_proj = TransposeConv2d1x1Transpose.from_linear_module(self.q_proj) self.k_proj = TransposeConv2d1x1Transpose.from_linear_module(self.k_proj) self.v_proj = TransposeConv2d1x1Transpose.from_linear_module(self.v_proj) self.o_proj = TransposeConv2d1x1Transpose.from_linear_module(self.o_proj) - self._matmul_to_conv = matmul_to_conv # noqa: SLF001 + self._matmul_to_conv = matmul_to_conv def forward( self, @@ -84,8 +86,9 @@ def forward( past_key_value: tuple[torch.Tensor, torch.Tensor] | None = None, past_seq_len: torch.Tensor | None = None, total_seq_len: torch.Tensor | None = None, - **kwargs: Any, # noqa: ARG002 + **kwargs: Any, ) -> tuple[torch.Tensor, None, tuple[torch.Tensor, torch.Tensor]]: + """Run fused GQA attention and return (output, None, present_kv).""" query_states = self.q_proj(hidden_states) key_states = self.k_proj(hidden_states) value_states = self.v_proj(hidden_states) @@ -167,8 +170,9 @@ def forward( past_seq_len: torch.Tensor | None = None, total_seq_len: torch.Tensor | None = None, use_cache: bool = True, - **kwargs: Any, # noqa: ARG002 + **kwargs: Any, ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: + """Run the decoder layer (attention + MLP) with residual adds.""" residual = hidden_states hidden_states = self.input_layernorm(hidden_states) attn_out, _, present_kv = self.self_attn( @@ -194,7 +198,8 @@ class WinMLQwen3Model(nn.Module): """Model export variant — transformer-only body (no embeddings / lm_head).""" def prepare_for_onnx_export(self, *, matmul_to_conv: bool) -> None: - self._matmul_to_conv = matmul_to_conv # noqa: SLF001 + """Record whether projections use the 1x1-conv (NHWC) path.""" + self._matmul_to_conv = matmul_to_conv def forward( self, @@ -204,6 +209,7 @@ def forward( total_seq_len: torch.Tensor, use_cache: bool = True, ) -> tuple[torch.Tensor, tuple[tuple[torch.Tensor, torch.Tensor], ...]]: + """Run the transformer-only body, returning hidden states + KV.""" hidden_states = inputs_embeds if self._matmul_to_conv: hidden_states = hidden_states.unsqueeze(0) # NHWC for Conv path diff --git a/src/winml/modelkit/models/hf/qwen_transformer_only_quant.py b/src/winml/modelkit/models/hf/qwen_transformer_only_quant.py index f01de2f71..b52dfd85e 100644 --- a/src/winml/modelkit/models/hf/qwen_transformer_only_quant.py +++ b/src/winml/modelkit/models/hf/qwen_transformer_only_quant.py @@ -120,9 +120,12 @@ def _graph_shapes(onnx_path: Path) -> tuple[int, int]: seq_len = dims[1].dim_value elif inp.name == "past_keys_0" and len(dims) >= 3: max_cache_len = dims[2].dim_value - if seq_len is None or max_cache_len is None: + # A symbolic/dynamic axis yields dim_value == 0 (not None), so treat any + # non-positive value as "not a usable static shape" and fail loudly rather + # than silently building zero-length calibration feeds. + if not seq_len or not max_cache_len: raise ValueError( - f"Could not read seq_len/max_cache_len from {onnx_path.name}; " + f"Could not read static seq_len/max_cache_len from {onnx_path.name}; " f"found seq_len={seq_len}, max_cache_len={max_cache_len}" ) return seq_len, max_cache_len From caada38336eb671fa3809a4cd19d983ca8f50d3b Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Wed, 24 Jun 2026 15:18:36 +0800 Subject: [PATCH 09/17] fix(qwen): address review comments (LpNorm eager norm, CodeQL lint, e2e helper) - LpNormOnnxExport.forward now computes the real L2 normalization instead of a silent identity; export-invariant (node comes from symbolic) and correct in eager. - GroupQueryAttentionOnnxExport.forward keeps the non-raising placeholder, with a docstring explaining why raising is impossible (HTP hierarchy capture runs an eager forward outside trace/export). - Remove unused module-level logger in qwen_transformer_only.py (CodeQL). - Use a single onnx import form in test_quant_calibration.py (CodeQL). - Fix e2e _decoder_onnx_path helper to handle the single-model WinMLModelForGenericTask (.onnx_path) build, not just composite .sub_models. --- .../modelkit/models/hf/qwen3_export_ops.py | 35 ++++++++++++------- .../models/hf/qwen_transformer_only.py | 3 -- .../test_qwen3_transformer_only_quant.py | 14 ++++++-- .../test_quant_calibration.py | 21 ++++++----- 4 files changed, 43 insertions(+), 30 deletions(-) diff --git a/src/winml/modelkit/models/hf/qwen3_export_ops.py b/src/winml/modelkit/models/hf/qwen3_export_ops.py index e1eba87c3..aed592fa7 100644 --- a/src/winml/modelkit/models/hf/qwen3_export_ops.py +++ b/src/winml/modelkit/models/hf/qwen3_export_ops.py @@ -48,14 +48,15 @@ def symbolic(g, input, axis, p) -> Any: @staticmethod def forward(ctx, input, axis, p) -> Any: - """Shape-only tracing placeholder; returns ``input`` unchanged. + """Real ``LpNormalization`` (``input / ||input||_p`` along ``axis``). - The real op is emitted by ``symbolic`` during ONNX export; ``forward`` - exists solely so the TorchScript exporter (and Optimum's pre-export dry - run) can trace output shapes. It is NOT a correct eager RMSNorm — do - not call this module for real inference. + The exported node comes from ``symbolic``; this eager body computes the + same value so any eager execution (unit tests, calibration debug runs, + the exporter's own shape-tracing pass) gets correctly normalized output + instead of a silent identity. It matches the ONNX op faithfully (no + RMSNorm epsilon), since that is exactly what ``symbolic`` emits. """ - return input + return input / torch.linalg.vector_norm(input, ord=p, dim=axis, keepdim=True) class GroupQueryAttentionOnnxExport(torch.autograd.Function): @@ -124,14 +125,22 @@ def forward( kv_num_heads, num_heads, ) -> Any: - """Shape-only tracing placeholder; returns stand-in (output, KV). - - The real op is emitted by ``symbolic`` during ONNX export; ``forward`` - exists solely so the TorchScript exporter (and Optimum's pre-export dry - run) can trace output shapes. It is NOT correct attention — do not call - this module for real inference. + """Shape-only tracing placeholder; returns a stand-in ``(output, KV)``. + + The real op is emitted by ``symbolic`` during ONNX export; this body + only needs to return tensors of the right shape/dtype. It deliberately + does NOT raise on eager execution, even though that yields a stale + (never-advanced) KV cache: the HTP export pipeline runs a real eager + ``forward`` pass to capture the module hierarchy (see + ``export/htp/hierarchy.py::trace_model_execution``), and that pass is + indistinguishable from misuse — ``torch.jit.is_tracing()`` and + ``torch.onnx.is_in_onnx_export()`` are both False there — so raising + would break the actual build. There is also no cheap faithful eager + equivalent (correct attention would grow the sequence axis that the + static-shape export freezes). This module is export-only by design and + is never run for real inference; calibration loads a fresh real model. """ - return query, past_key, past_value # placeholder shapes + return query, past_key, past_value # placeholder shapes (export-only) # ============================================================================= diff --git a/src/winml/modelkit/models/hf/qwen_transformer_only.py b/src/winml/modelkit/models/hf/qwen_transformer_only.py index fda69495f..28e394a4a 100644 --- a/src/winml/modelkit/models/hf/qwen_transformer_only.py +++ b/src/winml/modelkit/models/hf/qwen_transformer_only.py @@ -28,7 +28,6 @@ from __future__ import annotations -import logging from typing import Any, ClassVar import torch @@ -48,8 +47,6 @@ from .qwen3_modeling import apply_transformer_only_export_prep -logger = logging.getLogger(__name__) - # Distinct model_type for this variant. The underscore form is what the # exporter sees on ``model.config.model_type`` and what Optimum's TasksManager # and ``register_specialization`` are keyed on; the hyphenated form is used for diff --git a/tests/e2e/models/test_qwen3_transformer_only_quant.py b/tests/e2e/models/test_qwen3_transformer_only_quant.py index 7c6499e51..cf5b34132 100644 --- a/tests/e2e/models/test_qwen3_transformer_only_quant.py +++ b/tests/e2e/models/test_qwen3_transformer_only_quant.py @@ -51,9 +51,17 @@ def _qnn_available() -> bool: def _decoder_onnx_path(model) -> str: - """Locate the quantized decode ONNX behind the composite handle.""" - sub = model.sub_models["decoder_gen"] - return str(sub._onnx_path) + """Locate the quantized decode ONNX behind the model handle. + + The decode-only build (``seq_len=1``) returns a single + ``WinMLModelForGenericTask`` whose ``onnx_path`` is the quantized graph; a + full composite build instead exposes it under ``sub_models["decoder_gen"]``. + Handle both so the test does not depend on which wrapper the build picks. + """ + sub_models = getattr(model, "sub_models", None) + if sub_models and "decoder_gen" in sub_models: + return str(sub_models["decoder_gen"].onnx_path) + return str(model.onnx_path) def _qdq_counts(onnx_path: str) -> dict[str, int]: diff --git a/tests/unit/models/qwen_transformer_only/test_quant_calibration.py b/tests/unit/models/qwen_transformer_only/test_quant_calibration.py index f1b160433..75933962e 100644 --- a/tests/unit/models/qwen_transformer_only/test_quant_calibration.py +++ b/tests/unit/models/qwen_transformer_only/test_quant_calibration.py @@ -18,7 +18,6 @@ import numpy as np import onnx import torch -from onnx import TensorProto, helper from winml.modelkit.models.hf.qwen_transformer_only_quant import ( Qwen3DecodeTrajectoryCalibReader, @@ -48,26 +47,26 @@ def _fake_config() -> SimpleNamespace: def _build_tiny_onnx(path, *, seq_len: int, max_cache_len: int) -> None: """Write a minimal graph carrying the inputs the readers introspect.""" inputs = [ - helper.make_tensor_value_info( - "input_hidden_states", TensorProto.FLOAT, [1, seq_len, HIDDEN] + onnx.helper.make_tensor_value_info( + "input_hidden_states", onnx.TensorProto.FLOAT, [1, seq_len, HIDDEN] ), - helper.make_tensor_value_info( - "past_keys_0", TensorProto.FLOAT16, [1, NUM_KV_HEADS, max_cache_len, HEAD_DIM] + onnx.helper.make_tensor_value_info( + "past_keys_0", onnx.TensorProto.FLOAT16, [1, NUM_KV_HEADS, max_cache_len, HEAD_DIM] ), ] - out = helper.make_tensor_value_info( - "output_hidden_states", TensorProto.FLOAT, [1, seq_len, HIDDEN] + out = onnx.helper.make_tensor_value_info( + "output_hidden_states", onnx.TensorProto.FLOAT, [1, seq_len, HIDDEN] ) - gqa = helper.make_node( + gqa = onnx.helper.make_node( "GroupQueryAttention", ["input_hidden_states"], ["attn_out"], name="gqa_layer_0", domain="com.microsoft", ) - identity = helper.make_node("Identity", ["attn_out"], ["output_hidden_states"]) - graph = helper.make_graph([gqa, identity], "tiny", inputs, [out]) - model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 18)]) + identity = onnx.helper.make_node("Identity", ["attn_out"], ["output_hidden_states"]) + graph = onnx.helper.make_graph([gqa, identity], "tiny", inputs, [out]) + model = onnx.helper.make_model(graph, opset_imports=[onnx.helper.make_opsetid("", 18)]) onnx.save(model, str(path)) From c97373c3a663365e6f7f2316d2d14c9be0e4f9de Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Wed, 24 Jun 2026 16:14:40 +0800 Subject: [PATCH 10/17] fix(build): resolve quant-finalize hook on model class + update model_type-override test - build_hf_model: look up winml_finalize_quant_config on type(pytorch_model) instead of the instance, and call it with explicit self. Fixes the mypy 'Tensor not callable' error (getattr yields Any) and stops the hook firing on raw HF models / MagicMock test doubles (whose attributes are instance-synthesized), which was serializing a MagicMock into config.json. - test_resolve_loader_config: replace the obsolete 'never mutated' test with one asserting the intended explicit-model_type override (needed for variants like qwen3_transformer_only). --- src/winml/modelkit/build/hf.py | 13 ++++++++++--- tests/unit/loader/test_resolve_loader_config.py | 17 +++++++++++------ 2 files changed, 21 insertions(+), 9 deletions(-) diff --git a/src/winml/modelkit/build/hf.py b/src/winml/modelkit/build/hf.py index 08698d125..60c27a02b 100644 --- a/src/winml/modelkit/build/hf.py +++ b/src/winml/modelkit/build/hf.py @@ -321,9 +321,16 @@ def _name(base: str) -> str: # exported ONNX exists (e.g. calibration feeds / nodes-to-exclude # derived from the graph). Give the wrapper a chance to populate # those runtime-only fields here. - if pytorch_model is not None and hasattr(pytorch_model, "winml_finalize_quant_config"): - config.quant = pytorch_model.winml_finalize_quant_config( - config.quant, onnx_path=current_path, model_id=model_id + # Resolve the optional hook on the model's *class* (not the + # instance): a genuine wrapper defines it at class scope, whereas a + # raw HF model — or a test double whose attributes are synthesized + # per-instance — does not, so this avoids firing spuriously. + finalize_quant_config = getattr( + type(pytorch_model), "winml_finalize_quant_config", None + ) + if callable(finalize_quant_config): + config.quant = finalize_quant_config( + pytorch_model, config.quant, onnx_path=current_path, model_id=model_id ) # The hook may overwrite the quant scheme (dtypes, symmetry, # nodes-to-exclude) authoritatively, so re-persist the config diff --git a/tests/unit/loader/test_resolve_loader_config.py b/tests/unit/loader/test_resolve_loader_config.py index ea26e6cff..491af63ce 100644 --- a/tests/unit/loader/test_resolve_loader_config.py +++ b/tests/unit/loader/test_resolve_loader_config.py @@ -142,8 +142,13 @@ def test_model_type_only_creates_default_config(self) -> None: mock_create.assert_called_once_with("bert") assert loader_config.task == "feature-extraction" - def test_hf_config_never_mutated(self) -> None: - """hf_config is never mutated — model_type param does not override it.""" + def test_explicit_model_type_overrides_hf_config(self) -> None: + """An explicit model_type (with a model_id) overrides the resolved type. + + Needed so a variant model_type such as ``qwen3_transformer_only`` selects + the variant rather than the architecture's native type. The override only + applies when a model_id is present and the requested type differs. + """ mock_config = MagicMock() mock_config.model_type = "original_type" mock_class = MagicMock(spec=[]) @@ -164,10 +169,10 @@ def test_hf_config_never_mutated(self) -> None: "some-model", model_type="gpt2", task="text-generation" ) - # hf_config retains its original model_type — never mutated - assert hf_config.model_type == "original_type" - # loader_config.model_type reflects the REAL hf_config, not the param - assert loader_config.model_type == "original_type" + # The explicit model_type wins over the architecture's native type. + assert hf_config.model_type == "gpt2" + # loader_config.model_type reflects the overridden type. + assert loader_config.model_type == "gpt2" def test_auto_detect_task_from_model_type(self) -> None: """model_type without task auto-detects first supported task.""" From 752e6c990b17ca8e8b6b3b8b72ae84443312b2d9 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Wed, 24 Jun 2026 18:20:00 +0800 Subject: [PATCH 11/17] refactor(quant): move qwen3 calibration logic into quant registry Relocate the model-specific transformer-only calibration/quant logic out of models/hf (an export-only package) into a new quant/calibration/ subpackage, dispatched via a model_type-keyed registry that mirrors COMPOSITE_MODEL_REGISTRY. - Add quant/calibration/{base,registry}.py: QuantConfigFinalizer protocol + register_quant_finalizer / get_quant_finalizer (lazy, torch-free import). - git mv qwen_transformer_only_quant.py -> quant/calibration/qwen3_transformer_only.py and register Qwen3TransformerOnlyQuantFinalizer for 'qwen3_transformer_only'. - build/hf.py: replace the winml_finalize_quant_config wrapper hook with explicit registry dispatch keyed on config.model_type; unregistered types fall back to the default DatasetCalibrationReader. Preserve the model_id/_name_or_path fallback (now model-agnostic in the build layer). - Remove the hook from the export wrapper (back to export-only). - Relocate unit tests to tests/unit/quant/calibration/ and add test_registry.py. w8a16 scheme unchanged; CPU e2e (quantized-graph + GQA-exclusion + FP-parity) and 86 build/quant unit tests pass. --- src/winml/modelkit/build/hf.py | 34 +++++---- .../models/hf/qwen_transformer_only.py | 21 ------ src/winml/modelkit/quant/__init__.py | 4 + .../modelkit/quant/calibration/__init__.py | 23 ++++++ src/winml/modelkit/quant/calibration/base.py | 42 +++++++++++ .../calibration/qwen3_transformer_only.py} | 36 +++++++-- .../modelkit/quant/calibration/registry.py | 73 +++++++++++++++++++ .../test_qwen3_transformer_only_quant.py | 11 +-- .../calibration}/__init__.py | 0 .../calibration/test_qwen3_calibration.py} | 2 +- tests/unit/quant/calibration/test_registry.py | 38 ++++++++++ 11 files changed, 238 insertions(+), 46 deletions(-) create mode 100644 src/winml/modelkit/quant/calibration/__init__.py create mode 100644 src/winml/modelkit/quant/calibration/base.py rename src/winml/modelkit/{models/hf/qwen_transformer_only_quant.py => quant/calibration/qwen3_transformer_only.py} (92%) create mode 100644 src/winml/modelkit/quant/calibration/registry.py rename tests/unit/{models/qwen_transformer_only => quant/calibration}/__init__.py (100%) rename tests/unit/{models/qwen_transformer_only/test_quant_calibration.py => quant/calibration/test_qwen3_calibration.py} (99%) create mode 100644 tests/unit/quant/calibration/test_registry.py diff --git a/src/winml/modelkit/build/hf.py b/src/winml/modelkit/build/hf.py index 60c27a02b..ef8e794ee 100644 --- a/src/winml/modelkit/build/hf.py +++ b/src/winml/modelkit/build/hf.py @@ -317,22 +317,28 @@ def _name(base: str) -> str: else: logger.info("Quantizing model...") t0 = time.monotonic() - # Some model wrappers can only finalize their quant config once the - # exported ONNX exists (e.g. calibration feeds / nodes-to-exclude - # derived from the graph). Give the wrapper a chance to populate - # those runtime-only fields here. - # Resolve the optional hook on the model's *class* (not the - # instance): a genuine wrapper defines it at class scope, whereas a - # raw HF model — or a test double whose attributes are synthesized - # per-instance — does not, so this avoids firing spuriously. - finalize_quant_config = getattr( - type(pytorch_model), "winml_finalize_quant_config", None + # Some model types finalize their quant config only once the + # exported ONNX exists (calibration feeds / nodes-to-exclude derived + # from the graph). Resolve the model-type-specific quant policy from + # the quant registry, keyed on the live ``model_type``. Unregistered + # types return None → the quantizer uses its standard task-aware + # DatasetCalibrationReader. + from ..quant import get_quant_finalizer + + resolved_model_type = ( + getattr(getattr(pytorch_model, "config", None), "model_type", None) or model_type ) - if callable(finalize_quant_config): - config.quant = finalize_quant_config( - pytorch_model, config.quant, onnx_path=current_path, model_id=model_id + quant_finalizer = get_quant_finalizer(resolved_model_type) + if quant_finalizer is not None: + # Generic id fallback: the policy loads a fresh reference model + # for calibration, so feed it the best-known HF id/path. + resolved_model_id = model_id or getattr( + getattr(pytorch_model, "config", None), "_name_or_path", None ) - # The hook may overwrite the quant scheme (dtypes, symmetry, + config.quant = quant_finalizer.finalize( + config.quant, onnx_path=current_path, model_id=resolved_model_id + ) + # The policy may overwrite the quant scheme (dtypes, symmetry, # nodes-to-exclude) authoritatively, so re-persist the config # to keep config.json consistent with what was actually applied. config_path.write_text(json.dumps(config.to_dict(), indent=2)) diff --git a/src/winml/modelkit/models/hf/qwen_transformer_only.py b/src/winml/modelkit/models/hf/qwen_transformer_only.py index 28e394a4a..bff3cc5c7 100644 --- a/src/winml/modelkit/models/hf/qwen_transformer_only.py +++ b/src/winml/modelkit/models/hf/qwen_transformer_only.py @@ -125,27 +125,6 @@ def forward(self, *args: torch.Tensor) -> tuple[torch.Tensor, ...]: out.extend([k, v]) return tuple(out) - def winml_finalize_quant_config( - self, quant: Any, *, onnx_path: Any, model_id: str | None = None - ) -> Any: - """Build-pipeline hook: attach the calibration reader + GQA exclusions. - - Called by ``build_hf_model`` just before ``quantize_onnx`` (see - ``build/hf.py``). The exported transformer-only graph determines the - calibration feeds (shapes, KV buffers) and which GroupQueryAttention - nodes stay in float, so the live :class:`WinMLQuantizationConfig` can - only be finalized here — not at config-construction time. - """ - from .qwen_transformer_only_quant import ( - DEFAULT_MODEL_ID, - finalize_transformer_only_quant_config, - ) - - resolved_id = model_id or getattr(self.config, "_name_or_path", None) or DEFAULT_MODEL_ID - return finalize_transformer_only_quant_config( - quant, onnx_path=onnx_path, model_id=resolved_id - ) - # ============================================================================= # Dummy input generators (transformer-only I/O) diff --git a/src/winml/modelkit/quant/__init__.py b/src/winml/modelkit/quant/__init__.py index bc8e6ee06..b7bc8bf38 100644 --- a/src/winml/modelkit/quant/__init__.py +++ b/src/winml/modelkit/quant/__init__.py @@ -24,12 +24,16 @@ __all__ = [ "QuantizeResult", "WinMLQuantizationConfig", + "get_quant_finalizer", "quantize_onnx", + "register_quant_finalizer", ] _LAZY_IMPORTS: dict[str, tuple[str, str]] = { "quantize_onnx": (".quantizer", "quantize_onnx"), + "get_quant_finalizer": (".calibration", "get_quant_finalizer"), + "register_quant_finalizer": (".calibration", "register_quant_finalizer"), } diff --git a/src/winml/modelkit/quant/calibration/__init__.py b/src/winml/modelkit/quant/calibration/__init__.py new file mode 100644 index 000000000..88b1434c5 --- /dev/null +++ b/src/winml/modelkit/quant/calibration/__init__.py @@ -0,0 +1,23 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +"""Model-type-specific quantization policies (calibration readers + schemes). + +This subpackage stays import-light on purpose: it exposes only the registry +API. The individual finalizer modules (which pull in torch/transformers) are +imported lazily by :func:`get_quant_finalizer` when their ``model_type`` is +quantized. +""" + +from __future__ import annotations + +from .base import QuantConfigFinalizer +from .registry import get_quant_finalizer, register_quant_finalizer + + +__all__ = [ + "QuantConfigFinalizer", + "get_quant_finalizer", + "register_quant_finalizer", +] diff --git a/src/winml/modelkit/quant/calibration/base.py b/src/winml/modelkit/quant/calibration/base.py new file mode 100644 index 000000000..895c48b63 --- /dev/null +++ b/src/winml/modelkit/quant/calibration/base.py @@ -0,0 +1,42 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +"""Base protocol for model-type-specific quantization policies.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Protocol, runtime_checkable + + +if TYPE_CHECKING: + from pathlib import Path + + from ..config import WinMLQuantizationConfig + + +@runtime_checkable +class QuantConfigFinalizer(Protocol): + """Model-type-specific quant policy. + + Given the freshly exported ONNX, a finalizer populates the live + :class:`WinMLQuantizationConfig` with the fields that can only be known + once the graph exists — the calibration data reader, ``nodes_to_exclude``, + and (where the scheme is fixed and reference-matched) the dtype/symmetry + settings. + + Finalizers are registered per ``model_type`` (see + :func:`.registry.register_quant_finalizer`). Model types without a + registered policy fall back to the quantizer's default + ``DatasetCalibrationReader``. + """ + + def finalize( + self, + quant: WinMLQuantizationConfig, + *, + onnx_path: Path, + model_id: str | None = None, + ) -> WinMLQuantizationConfig: + """Return ``quant`` populated with the graph-derived quant settings.""" + ... diff --git a/src/winml/modelkit/models/hf/qwen_transformer_only_quant.py b/src/winml/modelkit/quant/calibration/qwen3_transformer_only.py similarity index 92% rename from src/winml/modelkit/models/hf/qwen_transformer_only_quant.py rename to src/winml/modelkit/quant/calibration/qwen3_transformer_only.py index b52dfd85e..5abb7e4ce 100644 --- a/src/winml/modelkit/models/hf/qwen_transformer_only_quant.py +++ b/src/winml/modelkit/quant/calibration/qwen3_transformer_only.py @@ -5,12 +5,13 @@ """Config-driven w8a16 calibration for the transformer-only Qwen3 build. -The transformer-only export (:mod:`qwen_transformer_only`) emits a graph whose -only quantization-relevant runtime inputs (the calibration feeds and the +The transformer-only export (``models.hf.qwen_transformer_only``) emits a graph +whose only quantization-relevant runtime inputs (the calibration feeds and the ``GroupQueryAttention`` node names to keep in float) can't be known until the ONNX exists. Rather than a standalone post-build script that reaches into -``composite.sub_models[...]._onnx_path``, this module plugs into the normal -build pipeline: :meth:`QwenTransformerOnlyDecoderWrapper.winml_finalize_quant_config` +``composite.sub_models[...]._onnx_path``, this module registers a quant policy +keyed on ``model_type`` (:class:`Qwen3TransformerOnlyQuantFinalizer`). The build +pipeline resolves it via :func:`~winml.modelkit.quant.get_quant_finalizer` and calls :func:`finalize_transformer_only_quant_config` just before ``quantize_onnx`` runs (see ``build/hf.py``), populating the live :class:`WinMLQuantizationConfig` with the right @@ -43,7 +44,8 @@ import torch from transformers import AutoModelForCausalLM, AutoTokenizer -from ...quant.config import CalibrationDataReader, WinMLQuantizationConfig +from ..config import CalibrationDataReader, WinMLQuantizationConfig +from .registry import register_quant_finalizer if TYPE_CHECKING: @@ -418,3 +420,27 @@ def finalize_transformer_only_quant_config( gc.collect() return quant + + +@register_quant_finalizer("qwen3_transformer_only") +class Qwen3TransformerOnlyQuantFinalizer: + """Registered quant policy for the ``qwen3_transformer_only`` model_type. + + Adapts :func:`finalize_transformer_only_quant_config` to the + :class:`~winml.modelkit.quant.calibration.base.QuantConfigFinalizer` + protocol so the build pipeline resolves the model-specific w8a16 scheme + + calibration reader through the quant registry (keyed on ``model_type``) + rather than a hardcoded hook on the export wrapper. + """ + + def finalize( + self, + quant: WinMLQuantizationConfig, + *, + onnx_path: Path, + model_id: str | None = None, + ) -> WinMLQuantizationConfig: + """Populate ``quant`` with the transformer-only w8a16 scheme + reader.""" + return finalize_transformer_only_quant_config( + quant, onnx_path=onnx_path, model_id=model_id or DEFAULT_MODEL_ID + ) diff --git a/src/winml/modelkit/quant/calibration/registry.py b/src/winml/modelkit/quant/calibration/registry.py new file mode 100644 index 000000000..47698da63 --- /dev/null +++ b/src/winml/modelkit/quant/calibration/registry.py @@ -0,0 +1,73 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +"""Registry mapping ``model_type`` to its quantization policy. + +Mirrors the project's other ``model_type``-keyed registries (e.g. +``COMPOSITE_MODEL_REGISTRY``): a finalizer registers itself with +``@register_quant_finalizer(model_type)`` and the build pipeline resolves it +with :func:`get_quant_finalizer`. + +The registry is intentionally lazy. Importing :mod:`winml.modelkit.quant` +must stay free of heavy deps (torch/transformers); the per-model finalizer +modules — which do pull those in — are only imported the first time their +``model_type`` is actually quantized. +""" + +from __future__ import annotations + +import importlib +from typing import TYPE_CHECKING + + +if TYPE_CHECKING: + from .base import QuantConfigFinalizer + + +# Populated by the ``@register_quant_finalizer`` decorator at import time. +_QUANT_FINALIZER_REGISTRY: dict[str, type[QuantConfigFinalizer]] = {} + +# ``model_type`` -> submodule that defines (and self-registers) its finalizer. +# Looked up lazily so the heavy module loads only when needed. Keys must match +# the live ``model_type`` string verbatim (no ``_`` -> ``-`` normalization), +# since lookup is keyed on the exported model's ``config.model_type``. +_KNOWN_FINALIZER_MODULES: dict[str, str] = { + "qwen3_transformer_only": ".qwen3_transformer_only", +} + + +def register_quant_finalizer(model_type: str): + """Class decorator registering a :class:`QuantConfigFinalizer` for ``model_type``.""" + + def decorator(cls: type) -> type: + if not hasattr(cls, "finalize"): + raise TypeError( + f"{cls.__name__} cannot register as a quant finalizer for " + f"{model_type!r}: it must define a ``finalize`` method." + ) + if model_type in _QUANT_FINALIZER_REGISTRY: + raise ValueError( + f"Quant finalizer already registered for {model_type!r}: " + f"{_QUANT_FINALIZER_REGISTRY[model_type].__name__}. " + f"Cannot register {cls.__name__}." + ) + _QUANT_FINALIZER_REGISTRY[model_type] = cls + return cls + + return decorator + + +def get_quant_finalizer(model_type: str | None) -> QuantConfigFinalizer | None: + """Return a finalizer instance for ``model_type``, or ``None`` if unregistered. + + ``None`` means "no model-specific policy" — the quantizer then uses its + standard task-aware ``DatasetCalibrationReader``. + """ + if not model_type: + return None + if model_type not in _QUANT_FINALIZER_REGISTRY and model_type in _KNOWN_FINALIZER_MODULES: + # Triggers the module's ``@register_quant_finalizer`` side effect. + importlib.import_module(_KNOWN_FINALIZER_MODULES[model_type], __package__) + cls = _QUANT_FINALIZER_REGISTRY.get(model_type) + return cls() if cls is not None else None diff --git a/tests/e2e/models/test_qwen3_transformer_only_quant.py b/tests/e2e/models/test_qwen3_transformer_only_quant.py index cf5b34132..831a640e8 100644 --- a/tests/e2e/models/test_qwen3_transformer_only_quant.py +++ b/tests/e2e/models/test_qwen3_transformer_only_quant.py @@ -7,11 +7,12 @@ Replaces the former root-level ``test_qwen.py`` / ``qwen3_transformer_only_quantize.py`` scripts. Quantization is now driven entirely through the standard build pipeline (``WinMLAutoModel.from_pretrained(..., precision="w8a16")``): the -device/precision policy enables the quantize stage, and -``QwenTransformerOnlyDecoderWrapper.winml_finalize_quant_config`` finalizes the -reference-matched scheme (int8-symmetric weights, uint16 activations, -GroupQueryAttention excluded from QDQ) plus the decode-trajectory calibration -reader. +device/precision policy enables the quantize stage, and the +``qwen3_transformer_only`` quant policy registered in +``winml.modelkit.quant.calibration`` (resolved via ``get_quant_finalizer``) +finalizes the reference-matched scheme (int8-symmetric weights, uint16 +activations, GroupQueryAttention excluded from QDQ) plus the decode-trajectory +calibration reader. These tests download Qwen3-0.6B from HuggingFace and run a full CPU export + quantize, so they are gated behind ``slow`` + ``network`` and excluded from the diff --git a/tests/unit/models/qwen_transformer_only/__init__.py b/tests/unit/quant/calibration/__init__.py similarity index 100% rename from tests/unit/models/qwen_transformer_only/__init__.py rename to tests/unit/quant/calibration/__init__.py diff --git a/tests/unit/models/qwen_transformer_only/test_quant_calibration.py b/tests/unit/quant/calibration/test_qwen3_calibration.py similarity index 99% rename from tests/unit/models/qwen_transformer_only/test_quant_calibration.py rename to tests/unit/quant/calibration/test_qwen3_calibration.py index 75933962e..5c8bd9d69 100644 --- a/tests/unit/models/qwen_transformer_only/test_quant_calibration.py +++ b/tests/unit/quant/calibration/test_qwen3_calibration.py @@ -19,7 +19,7 @@ import onnx import torch -from winml.modelkit.models.hf.qwen_transformer_only_quant import ( +from winml.modelkit.quant.calibration.qwen3_transformer_only import ( Qwen3DecodeTrajectoryCalibReader, Qwen3TransformerOnlyCalibReader, _gqa_node_names, diff --git a/tests/unit/quant/calibration/test_registry.py b/tests/unit/quant/calibration/test_registry.py new file mode 100644 index 000000000..b60f74b9b --- /dev/null +++ b/tests/unit/quant/calibration/test_registry.py @@ -0,0 +1,38 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +"""Unit tests for the quant finalizer registry. + +Fast, offline: no model download, no ONNX Runtime. Verifies that the +``model_type`` -> quant policy dispatch (lazy import + decorator registration) +resolves the registered Qwen3 finalizer and falls back to ``None`` (the +quantizer's default DatasetCalibrationReader path) for everything else. +""" + +from __future__ import annotations + +from winml.modelkit.quant import get_quant_finalizer +from winml.modelkit.quant.calibration import QuantConfigFinalizer + + +def test_registered_model_type_resolves_finalizer(): + """The qwen3_transformer_only policy is found via lazy registry import.""" + finalizer = get_quant_finalizer("qwen3_transformer_only") + assert finalizer is not None + assert isinstance(finalizer, QuantConfigFinalizer) + assert hasattr(finalizer, "finalize") + # Registry returns the concrete policy class, not the generic protocol. + assert type(finalizer).__name__ == "Qwen3TransformerOnlyQuantFinalizer" + + +def test_unregistered_model_type_returns_none(): + """Unknown / native model types have no policy -> default reader path.""" + assert get_quant_finalizer("resnet") is None + assert get_quant_finalizer("qwen3") is None + + +def test_none_model_type_returns_none(): + """A missing model_type must not raise and must not dispatch a policy.""" + assert get_quant_finalizer(None) is None + assert get_quant_finalizer("") is None From e9dbe2a2770df108c425078c83068a3dd2803998 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Wed, 24 Jun 2026 19:09:17 +0800 Subject: [PATCH 12/17] fix(quant): satisfy mypy + CodeQL on calibration registry - annotate register_quant_finalizer return type (mypy no-untyped-def) - add TYPE_CHECKING re-imports so static analyzers see lazy __all__ exports (CodeQL py/undefined-export) - drop bare ... from finalizer Protocol; docstring is the body (CodeQL ineffectual-statement) --- src/winml/modelkit/quant/__init__.py | 11 ++++++++++- src/winml/modelkit/quant/calibration/base.py | 1 - src/winml/modelkit/quant/calibration/registry.py | 4 +++- 3 files changed, 13 insertions(+), 3 deletions(-) diff --git a/src/winml/modelkit/quant/__init__.py b/src/winml/modelkit/quant/__init__.py index b7bc8bf38..e43a69068 100644 --- a/src/winml/modelkit/quant/__init__.py +++ b/src/winml/modelkit/quant/__init__.py @@ -16,7 +16,7 @@ result = quantize_onnx("model.onnx", WinMLQuantizationConfig(samples=100)) """ -from typing import Any +from typing import TYPE_CHECKING, Any from .config import QuantizeResult, WinMLQuantizationConfig @@ -30,6 +30,15 @@ ] +# Names below are loaded lazily via ``__getattr__`` to avoid pulling in +# onnxruntime.quantization/torch at import time. The TYPE_CHECKING re-imports +# give static analyzers (mypy, CodeQL) visibility into what ``__all__`` exports +# without triggering the heavy imports at runtime. +if TYPE_CHECKING: + from .calibration import get_quant_finalizer, register_quant_finalizer + from .quantizer import quantize_onnx + + _LAZY_IMPORTS: dict[str, tuple[str, str]] = { "quantize_onnx": (".quantizer", "quantize_onnx"), "get_quant_finalizer": (".calibration", "get_quant_finalizer"), diff --git a/src/winml/modelkit/quant/calibration/base.py b/src/winml/modelkit/quant/calibration/base.py index 895c48b63..d62ba4322 100644 --- a/src/winml/modelkit/quant/calibration/base.py +++ b/src/winml/modelkit/quant/calibration/base.py @@ -39,4 +39,3 @@ def finalize( model_id: str | None = None, ) -> WinMLQuantizationConfig: """Return ``quant`` populated with the graph-derived quant settings.""" - ... diff --git a/src/winml/modelkit/quant/calibration/registry.py b/src/winml/modelkit/quant/calibration/registry.py index 47698da63..78b321ae4 100644 --- a/src/winml/modelkit/quant/calibration/registry.py +++ b/src/winml/modelkit/quant/calibration/registry.py @@ -22,6 +22,8 @@ if TYPE_CHECKING: + from collections.abc import Callable + from .base import QuantConfigFinalizer @@ -37,7 +39,7 @@ } -def register_quant_finalizer(model_type: str): +def register_quant_finalizer(model_type: str) -> Callable[[type], type]: """Class decorator registering a :class:`QuantConfigFinalizer` for ``model_type``.""" def decorator(cls: type) -> type: From 52745638d9b4e887fe10d4d36a7433813592a3ef Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Wed, 24 Jun 2026 22:12:32 +0800 Subject: [PATCH 13/17] Thread model_type + quant finalizer through CLI HF build; move qwen3 transformer-only into subpackage The CLI-only _build_hf_pipeline did not pass loader.model_type to _load_model, so a config requesting qwen3_transformer_only was silently loaded as native qwen3 and crashed at export (embedding got HalfTensor). It also skipped the model-type quant finalizer, producing the default uint8/uint16 minmax scheme instead of the registered int8-sym / GQA-excluded policy. Both gaps existed only in the CLI path; the library build_hf_model already handled them. Mirror that logic so winml build produces the verified w8a16 graph (985 Q / 1294 DQ / 28 GQA / 0 QDQ-touching-GQA) end-to-end. Also move qwen3_export_ops, qwen3_modeling and qwen_transformer_only into a models/hf/qwen3/ subpackage and add regression tests for both fixes. --- src/winml/modelkit/commands/build.py | 32 ++++- src/winml/modelkit/models/hf/__init__.py | 8 +- .../modelkit/models/hf/qwen3/__init__.py | 6 + .../models/hf/{ => qwen3}/qwen3_export_ops.py | 0 .../models/hf/{ => qwen3}/qwen3_modeling.py | 0 .../hf/{ => qwen3}/qwen_transformer_only.py | 14 +- .../calibration/qwen3_transformer_only.py | 2 +- tests/unit/commands/test_build.py | 121 ++++++++++++++++++ 8 files changed, 170 insertions(+), 13 deletions(-) create mode 100644 src/winml/modelkit/models/hf/qwen3/__init__.py rename src/winml/modelkit/models/hf/{ => qwen3}/qwen3_export_ops.py (100%) rename src/winml/modelkit/models/hf/{ => qwen3}/qwen3_modeling.py (100%) rename src/winml/modelkit/models/hf/{ => qwen3}/qwen_transformer_only.py (97%) diff --git a/src/winml/modelkit/commands/build.py b/src/winml/modelkit/commands/build.py index c3ffc660d..0d10ebf67 100644 --- a/src/winml/modelkit/commands/build.py +++ b/src/winml/modelkit/commands/build.py @@ -1339,7 +1339,11 @@ def _name(base: str) -> str: # Load + export (blocking) pytorch_model = _load_model( - config, model_id, trust_remote_code=False, hf_config=preloaded_hf_config + config, + model_id, + trust_remote_code=False, + hf_config=preloaded_hf_config, + model_type=config.loader.model_type, ) t0 = time.monotonic() # config.export is None only for the ONNX build path; this is the HF path. @@ -1384,6 +1388,32 @@ def _name(base: str) -> str: config_path.write_text(json.dumps(config.to_dict(), indent=2)) # ── Quantize stage ─────────────────────────────────────────── + # Some model types finalize their quant config only once the exported ONNX + # exists (calibration feeds / nodes-to-exclude derived from the graph). + # Resolve the model-type-specific quant policy from the quant registry, + # keyed on the live ``model_type`` — mirrors build.hf.build_hf_model so the + # CLI and library pipelines apply the same scheme. Unregistered types return + # None → the quantizer uses its standard task-aware DatasetCalibrationReader. + if config.quant is not None: + from ..quant import get_quant_finalizer + + resolved_model_type = ( + getattr(getattr(pytorch_model, "config", None), "model_type", None) + or config.loader.model_type + ) + quant_finalizer = get_quant_finalizer(resolved_model_type) + if quant_finalizer is not None: + resolved_model_id = model_id or getattr( + getattr(pytorch_model, "config", None), "_name_or_path", None + ) + config.quant = quant_finalizer.finalize( + config.quant, onnx_path=current_path, model_id=resolved_model_id + ) + # The policy may overwrite the quant scheme (dtypes, symmetry, + # nodes-to-exclude) authoritatively, so re-persist the config to keep + # config.json consistent with what was actually applied. + config_path.write_text(json.dumps(config.to_dict(), indent=2)) + current_path = _run_quantize_stage( config=config, current_path=current_path, diff --git a/src/winml/modelkit/models/hf/__init__.py b/src/winml/modelkit/models/hf/__init__.py index 458bc8e34..5c854bb60 100644 --- a/src/winml/modelkit/models/hf/__init__.py +++ b/src/winml/modelkit/models/hf/__init__.py @@ -56,12 +56,12 @@ from .qwen import QWEN_CONFIG from .qwen import QwenGenIOConfig as _QwenGenIOConfig from .qwen import QwenPrefillIOConfig as _QwenPrefillIOConfig -from .qwen_transformer_only import MODEL_CLASS_MAPPING as _QWEN_TO_CLASS_MAPPING -from .qwen_transformer_only import QWEN_TRANSFORMER_ONLY_CONFIG -from .qwen_transformer_only import ( +from .qwen3.qwen_transformer_only import MODEL_CLASS_MAPPING as _QWEN_TO_CLASS_MAPPING +from .qwen3.qwen_transformer_only import QWEN_TRANSFORMER_ONLY_CONFIG +from .qwen3.qwen_transformer_only import ( QwenTransformerOnlyGenIOConfig as _QwenTransformerOnlyGenIOConfig, # triggers registration ) -from .qwen_transformer_only import ( +from .qwen3.qwen_transformer_only import ( # triggers registration QwenTransformerOnlyPrefillIOConfig as _QwenTransformerOnlyPrefillIOConfig, ) diff --git a/src/winml/modelkit/models/hf/qwen3/__init__.py b/src/winml/modelkit/models/hf/qwen3/__init__.py new file mode 100644 index 000000000..332fb9234 --- /dev/null +++ b/src/winml/modelkit/models/hf/qwen3/__init__.py @@ -0,0 +1,6 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- + +"""Qwen3 transformer-only export support (modeling, export ops, IO configs).""" diff --git a/src/winml/modelkit/models/hf/qwen3_export_ops.py b/src/winml/modelkit/models/hf/qwen3/qwen3_export_ops.py similarity index 100% rename from src/winml/modelkit/models/hf/qwen3_export_ops.py rename to src/winml/modelkit/models/hf/qwen3/qwen3_export_ops.py diff --git a/src/winml/modelkit/models/hf/qwen3_modeling.py b/src/winml/modelkit/models/hf/qwen3/qwen3_modeling.py similarity index 100% rename from src/winml/modelkit/models/hf/qwen3_modeling.py rename to src/winml/modelkit/models/hf/qwen3/qwen3_modeling.py diff --git a/src/winml/modelkit/models/hf/qwen_transformer_only.py b/src/winml/modelkit/models/hf/qwen3/qwen_transformer_only.py similarity index 97% rename from src/winml/modelkit/models/hf/qwen_transformer_only.py rename to src/winml/modelkit/models/hf/qwen3/qwen_transformer_only.py index bff3cc5c7..cc4985de0 100644 --- a/src/winml/modelkit/models/hf/qwen_transformer_only.py +++ b/src/winml/modelkit/models/hf/qwen3/qwen_transformer_only.py @@ -37,13 +37,13 @@ from optimum.utils.input_generators import DummyInputGenerator from transformers import AutoModelForCausalLM -from ...config import WinMLBuildConfig -from ...export import register_onnx_overwrite -from ...export.config import WinMLExportConfig -from ..winml import register_specialization -from ..winml.composite_model import register_composite_model -from ..winml.decoder_only import WinMLDecoderOnlyModel -from ..winml.kv_cache import WinMLSlidingWindowCache +from ....config import WinMLBuildConfig +from ....export import register_onnx_overwrite +from ....export.config import WinMLExportConfig +from ...winml import register_specialization +from ...winml.composite_model import register_composite_model +from ...winml.decoder_only import WinMLDecoderOnlyModel +from ...winml.kv_cache import WinMLSlidingWindowCache from .qwen3_modeling import apply_transformer_only_export_prep diff --git a/src/winml/modelkit/quant/calibration/qwen3_transformer_only.py b/src/winml/modelkit/quant/calibration/qwen3_transformer_only.py index 5abb7e4ce..a4dc0c61c 100644 --- a/src/winml/modelkit/quant/calibration/qwen3_transformer_only.py +++ b/src/winml/modelkit/quant/calibration/qwen3_transformer_only.py @@ -5,7 +5,7 @@ """Config-driven w8a16 calibration for the transformer-only Qwen3 build. -The transformer-only export (``models.hf.qwen_transformer_only``) emits a graph +The transformer-only export (``models.hf.qwen3.qwen_transformer_only``) emits a graph whose only quantization-relevant runtime inputs (the calibration feeds and the ``GroupQueryAttention`` node names to keep in float) can't be known until the ONNX exists. Rather than a standalone post-build script that reaches into diff --git a/tests/unit/commands/test_build.py b/tests/unit/commands/test_build.py index 00f54fc23..400d3cccd 100644 --- a/tests/unit/commands/test_build.py +++ b/tests/unit/commands/test_build.py @@ -1711,3 +1711,124 @@ def test_returns_compiled_path_when_file_exists( # current_path should be updated to compiled_path assert result == compiled_path + + +class TestBuildHfPipelineModelType: + """Regression: the CLI HF pipeline must thread loader.model_type into _load_model. + + Without this, a config requesting a derived model_type (e.g. + ``qwen3_transformer_only``) is silently loaded as its native type, so the + wrong model class is exported. See _build_hf_pipeline. + """ + + @patch("winml.modelkit.utils.console.StageLive") + @patch("winml.modelkit.export.export_onnx") + @patch("winml.modelkit.build.hf._load_model") + def test_load_model_receives_config_model_type( + self, + mock_load_model: MagicMock, + mock_export_onnx: MagicMock, + mock_stage_live: MagicMock, + tmp_path: Path, + ) -> None: + from winml.modelkit.commands.build import _build_hf_pipeline + + mock_stage_live.return_value.__enter__ = MagicMock(return_value=MagicMock()) + mock_stage_live.return_value.__exit__ = MagicMock(return_value=False) + + # Stop the pipeline right after export so we only exercise the load call. + sentinel = RuntimeError("stop-after-export") + mock_export_onnx.side_effect = sentinel + + config = MagicMock() + config.loader.model_type = "qwen3_transformer_only" + config.loader.task = "feature-extraction" + config.export = MagicMock() + + with pytest.raises(RuntimeError, match="stop-after-export"): + _build_hf_pipeline( + config=config, + model_id="Qwen/Qwen3-0.6B", + output_dir=tmp_path / "out", + rebuild=True, + cache_key=None, + ep=None, + device="cpu", + extra_kwargs={}, + preloaded_hf_config=None, + ) + + mock_load_model.assert_called_once() + assert mock_load_model.call_args.kwargs["model_type"] == "qwen3_transformer_only" + + @patch("winml.modelkit.commands.build._run_compile_stage") + @patch("winml.modelkit.commands.build._run_quantize_stage") + @patch("winml.modelkit.quant.get_quant_finalizer") + @patch("winml.modelkit.commands.build._run_optimize_stage") + @patch("winml.modelkit.commands.build._show_io") + @patch("winml.modelkit.utils.console.StageLive") + @patch("winml.modelkit.export.export_onnx") + @patch("winml.modelkit.build.hf._load_model") + def test_quant_finalizer_applied_for_registered_model_type( + self, + mock_load_model: MagicMock, + mock_export_onnx: MagicMock, + mock_stage_live: MagicMock, + mock_show_io: MagicMock, + mock_optimize: MagicMock, + mock_get_finalizer: MagicMock, + mock_quantize: MagicMock, + mock_compile: MagicMock, + tmp_path: Path, + ) -> None: + """The CLI HF pipeline must apply the registered quant finalizer. + + Mirrors build.hf.build_hf_model: without this the CLI quantizes with the + default task-aware scheme instead of the model-type-specific policy. + """ + from winml.modelkit.commands.build import _build_hf_pipeline + + mock_stage_live.return_value.__enter__ = MagicMock(return_value=MagicMock()) + mock_stage_live.return_value.__exit__ = MagicMock(return_value=False) + + pytorch_model = MagicMock() + pytorch_model.config.model_type = "qwen3_transformer_only" + mock_load_model.return_value = pytorch_model + + optimized = tmp_path / "optimized.onnx" + mock_optimize.return_value = (optimized, None) + + finalized_quant = MagicMock(name="finalized_quant_config") + finalizer = MagicMock() + finalizer.finalize.return_value = finalized_quant + mock_get_finalizer.return_value = finalizer + + # Stop right after the quantize stage so we don't exercise compile. + mock_quantize.side_effect = RuntimeError("stop-after-quantize") + + config = MagicMock() + config.loader.model_type = "qwen3_transformer_only" + config.loader.task = "text2text-generation" + config.loader.model_class = None + config.export = MagicMock() + config.quant = MagicMock(name="initial_quant_config") + config.to_dict.return_value = {} + + with pytest.raises(RuntimeError, match="stop-after-quantize"): + _build_hf_pipeline( + config=config, + model_id="Qwen/Qwen3-0.6B", + output_dir=tmp_path / "out", + rebuild=True, + cache_key=None, + ep=None, + device="cpu", + extra_kwargs={}, + preloaded_hf_config=None, + ) + + mock_get_finalizer.assert_called_once_with("qwen3_transformer_only") + finalizer.finalize.assert_called_once() + assert finalizer.finalize.call_args.kwargs["model_id"] == "Qwen/Qwen3-0.6B" + # config.quant must be replaced with the finalized scheme before quantize. + assert config.quant is finalized_quant From bfc831b2af843be4c4cb192fbe68c3eb94fd3a3b Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Thu, 25 Jun 2026 09:06:04 +0800 Subject: [PATCH 14/17] Make transformer-only composite handle returnable + add one-shot export script Reparent WinMLQwen3TransformerOnlyModel from WinMLDecoderOnlyModel to the plain WinMLCompositeModel. The decoder-only base wires a generation runtime from the eager KV signature (past_0_key) in __init__, which the transformer-only graph (past_keys_0 + symbolic axes) lacks, so from_pretrained crashed while constructing the handle even though both sub-model ONNX built fine. The build is export-only, so the plain composite base (which just stores the built sub-models) is the correct parent. Add a from_pretrained override that injects model_type=qwen3_transformer_only for every sub-model, so omitting model_type no longer silently builds the native (full) qwen3 architecture. Add scripts/export_qwen3_transformer_only.py to export the prefill + decode transformer-only pair in one call, with optional --output-dir copy. Add tests/unit/models/qwen3/test_transformer_only_composite.py covering the reparent, registry entry, and model_type injection. --- scripts/export_qwen3_transformer_only.py | 171 ++++++++++++++++++ .../models/hf/qwen3/qwen_transformer_only.py | 55 ++++-- .../qwen3/test_transformer_only_composite.py | 78 ++++++++ 3 files changed, 293 insertions(+), 11 deletions(-) create mode 100644 scripts/export_qwen3_transformer_only.py create mode 100644 tests/unit/models/qwen3/test_transformer_only_composite.py diff --git a/scripts/export_qwen3_transformer_only.py b/scripts/export_qwen3_transformer_only.py new file mode 100644 index 000000000..6894af518 --- /dev/null +++ b/scripts/export_qwen3_transformer_only.py @@ -0,0 +1,171 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +"""One-shot export of the Qwen3 transformer-only prefill + decode pair. + +Leverages the registered ``WinMLQwen3TransformerOnlyModel`` composite to build +BOTH transformer-only sub-models in a single call: + + - ``decoder_prefill`` — context graph, ``seq_len`` = --prefill-seq-len (64) + - ``decoder_gen`` — iteration graph, ``seq_len`` = 1 + +Each sub-model is built through the standard ``build_hf_model`` pipeline, so the +model-type quant finalizer is applied (int8 weight / uint16 activation, GQA +excluded from QDQ). Embeddings and the LM head are NOT part of this graph — they +run separately (e.g. from the bundle). + +Usage:: + + # Build (or reuse cached) both ONNX, print their paths + node summary: + uv run python scripts/export_qwen3_transformer_only.py + + # Copy the two ONNX (with external data) into a folder: + uv run python scripts/export_qwen3_transformer_only.py --output-dir out/qwen3 + + # Different model / device / cache geometry, force a rebuild: + uv run python scripts/export_qwen3_transformer_only.py \ + --model-id Qwen/Qwen3-0.6B --device npu \ + --max-cache-len 256 --prefill-seq-len 64 --force-rebuild +""" + +from __future__ import annotations + +import argparse +import collections +import sys +import time +from pathlib import Path + +import onnx + +from winml.modelkit.models.hf.qwen3.qwen_transformer_only import ( + WinMLQwen3TransformerOnlyModel, +) +from winml.modelkit.onnx import copy_onnx_model + + +# Component name -> output file stem used when --output-dir is given. +_OUTPUT_STEMS = { + "decoder_prefill": "prefill", + "decoder_gen": "decode", +} + +# Default EP per device; CPU/NPU/GPU map to their canonical providers. +_DEVICE_TO_EP = { + "cpu": "CPUExecutionProvider", + "npu": "QNNExecutionProvider", + "gpu": "DmlExecutionProvider", +} + + +def node_summary(path: str | Path) -> str: + """Return a one-line QDQ/GQA structural summary of an ONNX graph.""" + model = onnx.load(str(path), load_external_data=False) + counts = collections.Counter(n.op_type for n in model.graph.node) + gqa_io: set[str] = set() + for node in model.graph.node: + if node.op_type == "GroupQueryAttention": + gqa_io.update(node.input) + gqa_io.update(node.output) + qdq_touching_gqa = sum( + 1 + for n in model.graph.node + if n.op_type in ("QuantizeLinear", "DequantizeLinear") + and (set(n.input) & gqa_io or set(n.output) & gqa_io) + ) + return ( + f"Q={counts['QuantizeLinear']} DQ={counts['DequantizeLinear']} " + f"GQA={counts['GroupQueryAttention']} QDQ-touching-GQA={qdq_touching_gqa}" + ) + + +def parse_args(argv: list[str] | None = None) -> argparse.Namespace: + """Parse command-line arguments.""" + p = argparse.ArgumentParser( + description=__doc__, + formatter_class=argparse.RawDescriptionHelpFormatter, + ) + p.add_argument("--model-id", default="Qwen/Qwen3-0.6B", help="HF model id or local path.") + p.add_argument( + "--device", + default="cpu", + choices=sorted(_DEVICE_TO_EP), + help="Target device (selects the canonical EP). Default: cpu.", + ) + p.add_argument("--precision", default="w8a16", help="Build precision. Default: w8a16.") + p.add_argument("--max-cache-len", type=int, default=256, help="Static KV cache length.") + p.add_argument( + "--prefill-seq-len", + type=int, + default=64, + help="Prefill/context sequence length.", + ) + p.add_argument( + "--no-compile", + dest="no_compile", + action="store_true", + default=True, + help="Skip EPContext compilation (default; transformer-only is consumed pre-compile).", + ) + p.add_argument( + "--compile", + dest="no_compile", + action="store_false", + help="Enable EPContext compilation (requires the device's compiler/SDK).", + ) + p.add_argument("--force-rebuild", action="store_true", help="Rebuild even if cached.") + p.add_argument( + "--output-dir", + type=Path, + default=None, + help="If set, copy the two ONNX (with external data) here as prefill.onnx / decode.onnx.", + ) + return p.parse_args(argv) + + +def main(argv: list[str] | None = None) -> int: + """Build (or reuse) both transformer-only ONNX and report/copy them.""" + args = parse_args(argv) + + t0 = time.monotonic() + model = WinMLQwen3TransformerOnlyModel.from_pretrained( + args.model_id, + device=args.device, + precision=args.precision, + ep=_DEVICE_TO_EP[args.device], + no_compile=args.no_compile, + use_cache=True, + force_rebuild=args.force_rebuild, + sub_model_kwargs={ + "decoder_prefill": { + "shape_config": { + "max_cache_len": args.max_cache_len, + "seq_len": args.prefill_seq_len, + } + }, + "decoder_gen": {"shape_config": {"max_cache_len": args.max_cache_len, "seq_len": 1}}, + }, + ) + elapsed = time.monotonic() - t0 + + print(f"\n=== transformer-only build done in {elapsed:.1f}s ===") + + output_dir: Path | None = args.output_dir + if output_dir is not None: + output_dir.mkdir(parents=True, exist_ok=True) + + for name, sub in model.sub_models.items(): + src = Path(sub.onnx_path) + print(f"\n[{name}] {src}") + print(f" {node_summary(src)}") + if output_dir is not None: + dst = output_dir / f"{_OUTPUT_STEMS.get(name, name)}.onnx" + copy_onnx_model(src, dst) + print(f" -> copied to {dst}") + + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/src/winml/modelkit/models/hf/qwen3/qwen_transformer_only.py b/src/winml/modelkit/models/hf/qwen3/qwen_transformer_only.py index cc4985de0..fc26e6070 100644 --- a/src/winml/modelkit/models/hf/qwen3/qwen_transformer_only.py +++ b/src/winml/modelkit/models/hf/qwen3/qwen_transformer_only.py @@ -41,8 +41,7 @@ from ....export import register_onnx_overwrite from ....export.config import WinMLExportConfig from ...winml import register_specialization -from ...winml.composite_model import register_composite_model -from ...winml.decoder_only import WinMLDecoderOnlyModel +from ...winml.composite_model import WinMLCompositeModel, register_composite_model from ...winml.kv_cache import WinMLSlidingWindowCache from .qwen3_modeling import apply_transformer_only_export_prep @@ -324,20 +323,32 @@ def outputs(self) -> dict[str, dict[int, str]]: # ============================================================================= -# Composite inference wrapper (placeholder so the build pipeline finds a -# composite class — generation isn't yet wired for the transformer-only -# I/O signature). +# Composite handle for the transformer-only build. +# +# This variant is EXPORT-ONLY: it produces the prefill + decode transformer +# ONNX for downstream quantization. It extends ``WinMLCompositeModel`` (not +# ``WinMLDecoderOnlyModel``) on purpose — the decoder-only base wires a full +# generation runtime from the eager KV I/O signature (``past_0_key`` etc.), +# which does not match the transformer-only graph (``past_keys_0`` + symbolic +# axes). Inheriting it would make ``from_pretrained`` crash while constructing +# the handle, even though both sub-model ONNX built fine. The plain composite +# base just stores the built sub-models, so ``from_pretrained`` returns a usable +# handle exposing ``.sub_models[name].onnx_path``. # ============================================================================= @register_composite_model(TRANSFORMER_ONLY_MODEL_TYPE, "text-generation") -class WinMLQwen3TransformerOnlyModel(WinMLDecoderOnlyModel): - """Composite handle for the transformer-only Qwen3 build (export only). +class WinMLQwen3TransformerOnlyModel(WinMLCompositeModel): + """Export-only composite handle for the transformer-only Qwen3 build. + + ``from_pretrained`` builds both sub-models (``decoder_prefill`` seq=64, + ``decoder_gen`` seq=1) and returns this handle; the built ONNX paths are + available via ``self.sub_models[name].onnx_path``. ``generate()`` is **not** functional with this build path — the inference - feeds and KV update logic still target the eager I/O signature. Use the - eager :class:`WinMLQwen3Model` for generation; use this class to produce - the transformer-only ONNX for downstream quantization. + feeds and KV update logic target the eager I/O signature. Use the eager + :class:`WinMLQwen3Model` for generation; use this class to produce the + transformer-only ONNX for downstream quantization. """ _SUB_MODEL_CONFIG: ClassVar[dict[str, str]] = { @@ -347,9 +358,31 @@ class WinMLQwen3TransformerOnlyModel(WinMLDecoderOnlyModel): @classmethod def get_cache_class(cls) -> type: - """Return the KV-cache class used during generation.""" + """KV-cache class the decode model targets at runtime (sliding window).""" return WinMLSlidingWindowCache + @classmethod + def from_pretrained( # type: ignore[override] + cls, + model_id: str, + task: str = "text-generation", + *, + model_type: str | None = None, + **kwargs: Any, + ) -> WinMLCompositeModel: + """Build both transformer-only sub-models and return the composite handle. + + Forces ``model_type="qwen3_transformer_only"`` for every sub-model so the + composite builds the transformer-only variant instead of silently falling + back to the native (full) ``qwen3`` architecture when the caller omits it. + """ + return super().from_pretrained( + model_id, + task, + model_type=model_type or TRANSFORMER_ONLY_MODEL_TYPE, + **kwargs, + ) + # ============================================================================= # Declarative registration (import-time) diff --git a/tests/unit/models/qwen3/test_transformer_only_composite.py b/tests/unit/models/qwen3/test_transformer_only_composite.py new file mode 100644 index 000000000..d4456c141 --- /dev/null +++ b/tests/unit/models/qwen3/test_transformer_only_composite.py @@ -0,0 +1,78 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +"""Unit tests for the Qwen3 transformer-only composite handle. + +The transformer-only build is EXPORT-ONLY. Its composite handle must: + +1. Extend the plain ``WinMLCompositeModel`` (NOT ``WinMLDecoderOnlyModel``) so + ``from_pretrained`` can return after building the sub-models. The decoder-only + base wires a generation runtime from the eager KV name ``past_0_key`` in + ``__init__``, which the transformer-only graph (``past_keys_0``) lacks — and + would crash handle construction even though both ONNX built fine. +2. Inject ``model_type="qwen3_transformer_only"`` for every sub-model so the + composite builds the transformer-only variant rather than the native (full) + ``qwen3`` architecture when the caller omits ``model_type``. +""" + +from __future__ import annotations + +from unittest.mock import patch + +from winml.modelkit.models.hf.qwen3.qwen_transformer_only import ( + WinMLQwen3TransformerOnlyModel, +) +from winml.modelkit.models.winml import WinMLCompositeModel +from winml.modelkit.models.winml.composite_model import COMPOSITE_MODEL_REGISTRY +from winml.modelkit.models.winml.decoder_only import WinMLDecoderOnlyModel + + +class TestTransformerOnlyCompositeHandle: + def test_registered_for_text_generation(self) -> None: + assert ( + COMPOSITE_MODEL_REGISTRY.get(("qwen3_transformer_only", "text-generation")) + is WinMLQwen3TransformerOnlyModel + ) + + def test_is_plain_composite_not_decoder_runtime(self) -> None: + # Export-only: must not inherit the decoder-only generation runtime whose + # __init__ assumes the eager KV signature and crashes on this graph. + assert issubclass(WinMLQwen3TransformerOnlyModel, WinMLCompositeModel) + assert not issubclass(WinMLQwen3TransformerOnlyModel, WinMLDecoderOnlyModel) + + def test_sub_model_config(self) -> None: + assert WinMLQwen3TransformerOnlyModel._SUB_MODEL_CONFIG == { + "decoder_prefill": "feature-extraction", + "decoder_gen": "text2text-generation", + } + + def test_from_pretrained_injects_transformer_only_model_type(self) -> None: + recorded: dict[str, object] = {} + + def _fake(cls, model_id, task="text-generation", **kwargs): + recorded["model_id"] = model_id + recorded["task"] = task + recorded["model_type"] = kwargs.get("model_type") + return "SENTINEL" + + with patch.object(WinMLCompositeModel, "from_pretrained", classmethod(_fake)): + result = WinMLQwen3TransformerOnlyModel.from_pretrained("Qwen/Qwen3-0.6B") + + assert result == "SENTINEL" + assert recorded["model_id"] == "Qwen/Qwen3-0.6B" + assert recorded["model_type"] == "qwen3_transformer_only" + + def test_from_pretrained_preserves_explicit_model_type(self) -> None: + recorded: dict[str, object] = {} + + def _fake(cls, model_id, task="text-generation", **kwargs): + recorded["model_type"] = kwargs.get("model_type") + return "SENTINEL" + + with patch.object(WinMLCompositeModel, "from_pretrained", classmethod(_fake)): + WinMLQwen3TransformerOnlyModel.from_pretrained( + "Qwen/Qwen3-0.6B", model_type="custom-variant" + ) + + assert recorded["model_type"] == "custom-variant" From 4049631cfa4e578bd10be2f216dfbb8bf632d317 Mon Sep 17 00:00:00 2001 From: spalne Date: Thu, 25 Jun 2026 14:57:10 -0700 Subject: [PATCH 15/17] test(qwen3): fix NPU quant test EP detection and decoder path lookup --- .../test_qwen3_transformer_only_quant.py | 33 ++++++++++++++----- 1 file changed, 24 insertions(+), 9 deletions(-) diff --git a/tests/e2e/models/test_qwen3_transformer_only_quant.py b/tests/e2e/models/test_qwen3_transformer_only_quant.py index 831a640e8..e84b5d4c3 100644 --- a/tests/e2e/models/test_qwen3_transformer_only_quant.py +++ b/tests/e2e/models/test_qwen3_transformer_only_quant.py @@ -47,21 +47,36 @@ def _qnn_available() -> bool: - """True when ONNX Runtime exposes the QNN execution provider (real NPU).""" - return "QNNExecutionProvider" in ort.get_available_providers() - - -def _decoder_onnx_path(model) -> str: + """True when a QNN NPU device is reachable via the WinML autoEP path.""" + try: + from winml.modelkit.winml import get_registered_ep_devices + except Exception: + return "QNNExecutionProvider" in ort.get_available_providers() + + try: + devices = get_registered_ep_devices() + except Exception: + return False + + for device in devices: + ep_name = str(getattr(device, "ep_name", "")) + device_type = getattr(getattr(device, "device", None), "type", None) + if ep_name == "QNNExecutionProvider" and str(device_type).endswith("NPU"): + return True + return False + + +def _decoder_onnx_path(model, sub_name: str = "decoder_gen") -> str: """Locate the quantized decode ONNX behind the model handle. The decode-only build (``seq_len=1``) returns a single ``WinMLModelForGenericTask`` whose ``onnx_path`` is the quantized graph; a - full composite build instead exposes it under ``sub_models["decoder_gen"]``. + full composite build instead exposes it under ``sub_models[sub_name]``. Handle both so the test does not depend on which wrapper the build picks. """ sub_models = getattr(model, "sub_models", None) - if sub_models and "decoder_gen" in sub_models: - return str(sub_models["decoder_gen"].onnx_path) + if sub_models and sub_name in sub_models: + return str(sub_models[sub_name].onnx_path) return str(model.onnx_path) @@ -251,7 +266,7 @@ def test_npu_build_quantizes(task, seq_len, tmp_path): cache_dir=str(tmp_path), ) sub_name = "decoder_prefill" if seq_len == 64 else "decoder_gen" - onnx_path = str(model.sub_models[sub_name]._onnx_path) + onnx_path = _decoder_onnx_path(model, sub_name) counts = _qdq_counts(onnx_path) assert counts.get("QuantizeLinear", 0) > 0 assert counts.get("GroupQueryAttention", 0) > 0 From 44a68d40165f4959fa02c663223bec912b4c6d4e Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Fri, 26 Jun 2026 14:39:19 +0800 Subject: [PATCH 16/17] Address PR review: dedup quant finalizer dispatch, plain finalizer registry, non-mutating model_type override - registry: replace decorator/register API with a plain QUANT_FINALIZERS dict; get_quant_finalizer lazily imports + instantiates (review #1). - quantizer: resolve+apply the model-type-specific quant policy inside quantize_onnx from config.model_type, a single seam shared by all callers; drop the duplicated dispatch blocks in commands/build.py and build/hf.py (review #2). - loader: thread an explicit model_type as model_type_override through resolve_task instead of mutating hf_config.model_type, so exporters/patchers keep the architecture's native type while the loader config surfaces the build variant (review #4). --- src/winml/modelkit/build/hf.py | 33 ++----- src/winml/modelkit/commands/build.py | 32 ++----- src/winml/modelkit/config/build.py | 4 + src/winml/modelkit/loader/config.py | 35 +++++-- src/winml/modelkit/loader/hf.py | 30 ++++-- src/winml/modelkit/loader/resolution.py | 8 +- src/winml/modelkit/quant/__init__.py | 4 +- .../modelkit/quant/calibration/__init__.py | 4 +- src/winml/modelkit/quant/calibration/base.py | 7 +- .../calibration/qwen3_transformer_only.py | 19 ++-- .../modelkit/quant/calibration/registry.py | 71 +++++--------- src/winml/modelkit/quant/config.py | 11 +++ src/winml/modelkit/quant/quantizer.py | 17 ++++ tests/unit/commands/test_build.py | 28 +++--- tests/unit/config/test_build.py | 2 +- tests/unit/loader/test_load_hf_model.py | 6 +- .../unit/loader/test_resolve_loader_config.py | 16 ++-- tests/unit/test_quantizer.py | 93 +++++++++++++++++++ 18 files changed, 259 insertions(+), 161 deletions(-) diff --git a/src/winml/modelkit/build/hf.py b/src/winml/modelkit/build/hf.py index ef8e794ee..b2897b952 100644 --- a/src/winml/modelkit/build/hf.py +++ b/src/winml/modelkit/build/hf.py @@ -317,31 +317,14 @@ def _name(base: str) -> str: else: logger.info("Quantizing model...") t0 = time.monotonic() - # Some model types finalize their quant config only once the - # exported ONNX exists (calibration feeds / nodes-to-exclude derived - # from the graph). Resolve the model-type-specific quant policy from - # the quant registry, keyed on the live ``model_type``. Unregistered - # types return None → the quantizer uses its standard task-aware - # DatasetCalibrationReader. - from ..quant import get_quant_finalizer - - resolved_model_type = ( - getattr(getattr(pytorch_model, "config", None), "model_type", None) or model_type - ) - quant_finalizer = get_quant_finalizer(resolved_model_type) - if quant_finalizer is not None: - # Generic id fallback: the policy loads a fresh reference model - # for calibration, so feed it the best-known HF id/path. - resolved_model_id = model_id or getattr( - getattr(pytorch_model, "config", None), "_name_or_path", None - ) - config.quant = quant_finalizer.finalize( - config.quant, onnx_path=current_path, model_id=resolved_model_id - ) - # The policy may overwrite the quant scheme (dtypes, symmetry, - # nodes-to-exclude) authoritatively, so re-persist the config - # to keep config.json consistent with what was actually applied. - config_path.write_text(json.dumps(config.to_dict(), indent=2)) + # A model-type-specific quant policy (e.g. the qwen3_transformer_only + # w8a16 finalizer) is resolved and applied inside ``quantize_onnx`` + # from ``config.quant.model_type``. Ensure it carries the resolved + # variant so hand-built configs (that skipped assemble_build_config) + # still trigger the right policy; ``quantize_onnx`` no-ops for + # model types without a registered finalizer. + if config.quant.model_type is None: + config.quant.model_type = config.loader.model_type quant_result = quantize_onnx( model_path=current_path, output_path=quantized_path, diff --git a/src/winml/modelkit/commands/build.py b/src/winml/modelkit/commands/build.py index 88ab06967..74fe750e9 100644 --- a/src/winml/modelkit/commands/build.py +++ b/src/winml/modelkit/commands/build.py @@ -1443,31 +1443,13 @@ def _name(base: str) -> str: config_path.write_text(json.dumps(config.to_dict(), indent=2)) # ── Quantize stage ─────────────────────────────────────────── - # Some model types finalize their quant config only once the exported ONNX - # exists (calibration feeds / nodes-to-exclude derived from the graph). - # Resolve the model-type-specific quant policy from the quant registry, - # keyed on the live ``model_type`` — mirrors build.hf.build_hf_model so the - # CLI and library pipelines apply the same scheme. Unregistered types return - # None → the quantizer uses its standard task-aware DatasetCalibrationReader. - if config.quant is not None: - from ..quant import get_quant_finalizer - - resolved_model_type = ( - getattr(getattr(pytorch_model, "config", None), "model_type", None) - or config.loader.model_type - ) - quant_finalizer = get_quant_finalizer(resolved_model_type) - if quant_finalizer is not None: - resolved_model_id = model_id or getattr( - getattr(pytorch_model, "config", None), "_name_or_path", None - ) - config.quant = quant_finalizer.finalize( - config.quant, onnx_path=current_path, model_id=resolved_model_id - ) - # The policy may overwrite the quant scheme (dtypes, symmetry, - # nodes-to-exclude) authoritatively, so re-persist the config to keep - # config.json consistent with what was actually applied. - config_path.write_text(json.dumps(config.to_dict(), indent=2)) + # A model-type-specific quant policy (e.g. the qwen3_transformer_only w8a16 + # finalizer) is resolved and applied inside ``quantize_onnx`` from + # ``config.quant.model_type``; no per-call-site dispatch needed here. Carry + # the resolved variant onto the quant config so configs that were hand-built + # or loaded from JSON (skipping assemble_build_config) still trigger it. + if config.quant is not None and config.quant.model_type is None: + config.quant.model_type = config.loader.model_type current_path = _run_quantize_stage( config=config, diff --git a/src/winml/modelkit/config/build.py b/src/winml/modelkit/config/build.py index 6ca550d15..80647d2a3 100644 --- a/src/winml/modelkit/config/build.py +++ b/src/winml/modelkit/config/build.py @@ -1071,6 +1071,10 @@ def _assemble_config( model_type, ) quant_config.model_name = model_id or model_type + # Carry the resolved model_type so quantize_onnx can resolve a + # model-type-specific quant policy (e.g. the qwen3_transformer_only + # w8a16 finalizer) from the exported graph. + quant_config.model_type = loader_config.model_type return WinMLBuildConfig( loader=loader_config, diff --git a/src/winml/modelkit/loader/config.py b/src/winml/modelkit/loader/config.py index b97a13825..5e9ca0bff 100644 --- a/src/winml/modelkit/loader/config.py +++ b/src/winml/modelkit/loader/config.py @@ -220,22 +220,33 @@ def resolve_loader_config( f"attribute. Cannot proceed with config generation." ) - # Explicit model_type override alongside a model_id: honor the requested - # type so downstream class / build-config / export resolution selects the - # variant (e.g. "qwen3_transformer_only") rather than the architecture's - # native type. The model_type-only path above (AutoConfig.for_model) is + # Explicit model_type override alongside a model_id: thread the requested + # variant through downstream resolution (task / class / composite tag and + # the loader config's model_type) WITHOUT mutating the loaded HF config. The + # exported graph, the htp Optimum patcher and every other consumer must keep + # seeing the architecture's native type; only the resolved build-variant tag + # changes. The model_type-only path above (AutoConfig.for_model) is # unaffected because it only runs when model_id is None. - if model_id is not None and model_type is not None and hf_config.model_type != model_type: + model_type_override = ( + model_type + if (model_id is not None and model_type is not None and hf_config.model_type != model_type) + else None + ) + if model_type_override is not None: logger.info( - "Overriding resolved model_type '%s' -> '%s' (explicit request)", + "Applying model_type override '%s' -> '%s' (explicit request)", hf_config.model_type, - model_type, + model_type_override, ) - hf_config.model_type = model_type # 2-3. Unified resolution. Task detection — including the no-architectures # --model-type fallback (first supported task) — now lives in resolve_task. - resolution = resolve_task(hf_config, task=task, model_class=model_class) + resolution = resolve_task( + hf_config, + task=task, + model_class=model_class, + model_type_override=model_type_override, + ) resolved_task, resolved_class = resolution.task, resolution.model_class logger.info("Resolved: task=%s, model_class=%s", resolved_task, resolved_class.__name__) @@ -245,6 +256,12 @@ def resolve_loader_config( resolved_class, ) + # The explicit variant tag wins over the architecture's native type for the + # loader config (drives downstream build-variant dispatch), while + # resolved_hf_config keeps its native model_type. + if model_type_override is not None: + resolved_model_type = model_type_override + # 5. Build loader config loader_config = WinMLLoaderConfig( task=resolved_task, diff --git a/src/winml/modelkit/loader/hf.py b/src/winml/modelkit/loader/hf.py index 619109368..4a03bd763 100644 --- a/src/winml/modelkit/loader/hf.py +++ b/src/winml/modelkit/loader/hf.py @@ -219,17 +219,22 @@ def load_hf_model( trust_remote_code=trust_remote_code, ) - # Explicit model_type override: select a registered build variant (e.g. - # "qwen3_transformer_only") rather than the architecture's native type. - # Mutates the freshly-loaded config only; gated on an explicit request so - # normal loading is unaffected. - if model_type is not None and getattr(hf_config, "model_type", None) != model_type: + # Explicit model_type override: thread the requested build variant (e.g. + # "qwen3_transformer_only") into task resolution WITHOUT mutating the + # freshly-loaded HF config. The torch model is instantiated from its own + # native config below, so export/patcher consumers keep the native type; + # only class/task resolution sees the variant. + model_type_override = ( + model_type + if model_type is not None and getattr(hf_config, "model_type", None) != model_type + else None + ) + if model_type_override is not None: logger.info( - "Overriding model_type '%s' -> '%s' (explicit request)", + "Applying model_type override '%s' -> '%s' (explicit request)", getattr(hf_config, "model_type", None), - model_type, + model_type_override, ) - hf_config.model_type = model_type # [2] Task & Model Class Resolution from .resolution import resolve_task @@ -241,10 +246,15 @@ def load_hf_model( resolved_class = _load_class_from_script(user_script, model_class) logger.info("Using custom model class from script: %s", model_class) # Surfaced modality-aware task (consistent with the non-script branch). - task = resolve_task(hf_config, task=task).task + task = resolve_task(hf_config, task=task, model_type_override=model_type_override).task else: try: - resolution = resolve_task(hf_config, task=task, model_class=model_class) + resolution = resolve_task( + hf_config, + task=task, + model_class=model_class, + model_type_override=model_type_override, + ) task, resolved_class = resolution.task, resolution.model_class except ValueError as e: raise ValueError( diff --git a/src/winml/modelkit/loader/resolution.py b/src/winml/modelkit/loader/resolution.py index fce6448fb..f700e19e7 100644 --- a/src/winml/modelkit/loader/resolution.py +++ b/src/winml/modelkit/loader/resolution.py @@ -238,6 +238,7 @@ def _get_custom_model_class(model_type: str, task: str) -> type | None: return None + # Component-name -> sub-task, e.g. {"encoder": "feature-extraction", # "decoder": "text2text-generation"} (the composite ``_SUB_MODEL_CONFIG`` shape). CompositeComponents = dict[str, str] @@ -372,16 +373,21 @@ def resolve_task( *, task: str | None = None, model_class: str | None = None, + model_type_override: str | None = None, ) -> TaskResolution: """Resolve a single model's task + class from an HF config. Stages: 0 user override -> 1 detect (override / no-architectures / TasksManager / default) -> 2 model class -> 3 modality upgrade (detection path only) -> 4 composite tag. + + ``model_type_override`` lets a caller drive resolution with a build variant + (e.g. ``qwen3_transformer_only``) without mutating the loaded HF config; when + ``None`` the architecture's native ``config.model_type`` is used. """ from optimum.exporters.tasks import TasksManager - model_type = getattr(config, "model_type", None) + model_type = model_type_override or getattr(config, "model_type", None) model_type_norm = model_type.lower().replace("_", "-") if model_type else "" model_id = getattr(config, "_name_or_path", "") or None diff --git a/src/winml/modelkit/quant/__init__.py b/src/winml/modelkit/quant/__init__.py index e43a69068..9a2c0b34c 100644 --- a/src/winml/modelkit/quant/__init__.py +++ b/src/winml/modelkit/quant/__init__.py @@ -26,7 +26,6 @@ "WinMLQuantizationConfig", "get_quant_finalizer", "quantize_onnx", - "register_quant_finalizer", ] @@ -35,14 +34,13 @@ # give static analyzers (mypy, CodeQL) visibility into what ``__all__`` exports # without triggering the heavy imports at runtime. if TYPE_CHECKING: - from .calibration import get_quant_finalizer, register_quant_finalizer + from .calibration import get_quant_finalizer from .quantizer import quantize_onnx _LAZY_IMPORTS: dict[str, tuple[str, str]] = { "quantize_onnx": (".quantizer", "quantize_onnx"), "get_quant_finalizer": (".calibration", "get_quant_finalizer"), - "register_quant_finalizer": (".calibration", "register_quant_finalizer"), } diff --git a/src/winml/modelkit/quant/calibration/__init__.py b/src/winml/modelkit/quant/calibration/__init__.py index 88b1434c5..f8a769e41 100644 --- a/src/winml/modelkit/quant/calibration/__init__.py +++ b/src/winml/modelkit/quant/calibration/__init__.py @@ -13,11 +13,11 @@ from __future__ import annotations from .base import QuantConfigFinalizer -from .registry import get_quant_finalizer, register_quant_finalizer +from .registry import QUANT_FINALIZERS, get_quant_finalizer __all__ = [ + "QUANT_FINALIZERS", "QuantConfigFinalizer", "get_quant_finalizer", - "register_quant_finalizer", ] diff --git a/src/winml/modelkit/quant/calibration/base.py b/src/winml/modelkit/quant/calibration/base.py index d62ba4322..39b9543c5 100644 --- a/src/winml/modelkit/quant/calibration/base.py +++ b/src/winml/modelkit/quant/calibration/base.py @@ -25,10 +25,9 @@ class QuantConfigFinalizer(Protocol): and (where the scheme is fixed and reference-matched) the dtype/symmetry settings. - Finalizers are registered per ``model_type`` (see - :func:`.registry.register_quant_finalizer`). Model types without a - registered policy fall back to the quantizer's default - ``DatasetCalibrationReader``. + Finalizers are named per ``model_type`` in + :data:`.registry.QUANT_FINALIZERS`. Model types without a registered policy + fall back to the quantizer's default ``DatasetCalibrationReader``. """ def finalize( diff --git a/src/winml/modelkit/quant/calibration/qwen3_transformer_only.py b/src/winml/modelkit/quant/calibration/qwen3_transformer_only.py index 68504504b..04901ca7c 100644 --- a/src/winml/modelkit/quant/calibration/qwen3_transformer_only.py +++ b/src/winml/modelkit/quant/calibration/qwen3_transformer_only.py @@ -9,8 +9,9 @@ whose only quantization-relevant runtime inputs (the calibration feeds and the ``GroupQueryAttention`` node names to keep in float) can't be known until the ONNX exists. Rather than a standalone post-build script that reaches into -``composite.sub_models[...]._onnx_path``, this module registers a quant policy -keyed on ``model_type`` (:class:`Qwen3TransformerOnlyQuantFinalizer`). The build +``composite.sub_models[...]._onnx_path``, this module defines a quant policy +keyed on ``model_type`` (:class:`Qwen3TransformerOnlyQuantFinalizer`, named in +:data:`~winml.modelkit.quant.calibration.registry.QUANT_FINALIZERS`). The build pipeline resolves it via :func:`~winml.modelkit.quant.get_quant_finalizer` and calls :func:`finalize_transformer_only_quant_config` just before ``quantize_onnx`` runs (see ``build/hf.py``), populating the live @@ -45,7 +46,6 @@ from transformers import AutoModelForCausalLM, AutoTokenizer from ..config import CalibrationDataReader, WinMLQuantizationConfig -from .registry import register_quant_finalizer if TYPE_CHECKING: @@ -427,15 +427,16 @@ def finalize_transformer_only_quant_config( return quant -@register_quant_finalizer("qwen3_transformer_only") class Qwen3TransformerOnlyQuantFinalizer: - """Registered quant policy for the ``qwen3_transformer_only`` model_type. + """Quant policy for the ``qwen3_transformer_only`` model_type. - Adapts :func:`finalize_transformer_only_quant_config` to the + Named in :data:`~winml.modelkit.quant.calibration.registry.QUANT_FINALIZERS` + and resolved by :func:`~winml.modelkit.quant.get_quant_finalizer`. Adapts + :func:`finalize_transformer_only_quant_config` to the :class:`~winml.modelkit.quant.calibration.base.QuantConfigFinalizer` - protocol so the build pipeline resolves the model-specific w8a16 scheme + - calibration reader through the quant registry (keyed on ``model_type``) - rather than a hardcoded hook on the export wrapper. + protocol so the build pipeline applies the model-specific w8a16 scheme + + calibration reader (keyed on ``model_type``) rather than a hardcoded hook on + the export wrapper. """ def finalize( diff --git a/src/winml/modelkit/quant/calibration/registry.py b/src/winml/modelkit/quant/calibration/registry.py index 78b321ae4..3212c2002 100644 --- a/src/winml/modelkit/quant/calibration/registry.py +++ b/src/winml/modelkit/quant/calibration/registry.py @@ -4,15 +4,16 @@ # -------------------------------------------------------------------------- """Registry mapping ``model_type`` to its quantization policy. -Mirrors the project's other ``model_type``-keyed registries (e.g. -``COMPOSITE_MODEL_REGISTRY``): a finalizer registers itself with -``@register_quant_finalizer(model_type)`` and the build pipeline resolves it -with :func:`get_quant_finalizer`. - -The registry is intentionally lazy. Importing :mod:`winml.modelkit.quant` -must stay free of heavy deps (torch/transformers); the per-model finalizer -modules — which do pull those in — are only imported the first time their -``model_type`` is actually quantized. +A model type with a fixed, reference-matched quant scheme (calibration reader, +``nodes_to_exclude``, dtypes) names its :class:`QuantConfigFinalizer` in the +plain ``QUANT_FINALIZERS`` dict below; the build pipeline resolves it with +:func:`get_quant_finalizer`. This mirrors the other ``model_type``-keyed tables +in the repo — a simple dict, no decorator/self-registration machinery. + +The lookup is intentionally lazy. Importing :mod:`winml.modelkit.quant` must +stay free of heavy deps (torch/transformers); the per-model finalizer modules — +which do pull those in — are only imported the first time their ``model_type`` +is actually quantized. """ from __future__ import annotations @@ -22,44 +23,19 @@ if TYPE_CHECKING: - from collections.abc import Callable - from .base import QuantConfigFinalizer -# Populated by the ``@register_quant_finalizer`` decorator at import time. -_QUANT_FINALIZER_REGISTRY: dict[str, type[QuantConfigFinalizer]] = {} - -# ``model_type`` -> submodule that defines (and self-registers) its finalizer. -# Looked up lazily so the heavy module loads only when needed. Keys must match -# the live ``model_type`` string verbatim (no ``_`` -> ``-`` normalization), -# since lookup is keyed on the exported model's ``config.model_type``. -_KNOWN_FINALIZER_MODULES: dict[str, str] = { - "qwen3_transformer_only": ".qwen3_transformer_only", +# ``model_type`` -> ``(submodule, class name)`` of its QuantConfigFinalizer. +# Imported lazily by ``get_quant_finalizer`` so the heavy module loads only when +# needed. Keys must match the live ``model_type`` string verbatim (no ``_`` -> +# ``-`` normalization), since lookup is keyed on the exported model's +# ``config.model_type``. +QUANT_FINALIZERS: dict[str, tuple[str, str]] = { + "qwen3_transformer_only": (".qwen3_transformer_only", "Qwen3TransformerOnlyQuantFinalizer"), } -def register_quant_finalizer(model_type: str) -> Callable[[type], type]: - """Class decorator registering a :class:`QuantConfigFinalizer` for ``model_type``.""" - - def decorator(cls: type) -> type: - if not hasattr(cls, "finalize"): - raise TypeError( - f"{cls.__name__} cannot register as a quant finalizer for " - f"{model_type!r}: it must define a ``finalize`` method." - ) - if model_type in _QUANT_FINALIZER_REGISTRY: - raise ValueError( - f"Quant finalizer already registered for {model_type!r}: " - f"{_QUANT_FINALIZER_REGISTRY[model_type].__name__}. " - f"Cannot register {cls.__name__}." - ) - _QUANT_FINALIZER_REGISTRY[model_type] = cls - return cls - - return decorator - - def get_quant_finalizer(model_type: str | None) -> QuantConfigFinalizer | None: """Return a finalizer instance for ``model_type``, or ``None`` if unregistered. @@ -68,8 +44,11 @@ def get_quant_finalizer(model_type: str | None) -> QuantConfigFinalizer | None: """ if not model_type: return None - if model_type not in _QUANT_FINALIZER_REGISTRY and model_type in _KNOWN_FINALIZER_MODULES: - # Triggers the module's ``@register_quant_finalizer`` side effect. - importlib.import_module(_KNOWN_FINALIZER_MODULES[model_type], __package__) - cls = _QUANT_FINALIZER_REGISTRY.get(model_type) - return cls() if cls is not None else None + entry = QUANT_FINALIZERS.get(model_type) + if entry is None: + return None + module_name, class_name = entry + module = importlib.import_module(module_name, __package__) + finalizer_cls = getattr(module, class_name) + finalizer: QuantConfigFinalizer = finalizer_cls() + return finalizer diff --git a/src/winml/modelkit/quant/config.py b/src/winml/modelkit/quant/config.py index 3465e62d2..a4ef2250b 100644 --- a/src/winml/modelkit/quant/config.py +++ b/src/winml/modelkit/quant/config.py @@ -70,6 +70,14 @@ class WinMLQuantizationConfig: model_name: str | None = None # e.g., "microsoft/resnet-50" dataset_name: str | None = None # Optional: override default dataset + # Model-type-specific quant policy selector. When set to a model_type that + # has a registered finalizer (see ``quant.calibration.QUANT_FINALIZERS``), + # ``quantize_onnx`` resolves and applies that policy — populating the + # calibration reader / nodes-to-exclude / fixed dtypes from the exported + # graph — before running the quantization pass. None = no model-specific + # policy (use the default task-aware calibration). + model_type: str | None = None + # Quantization types (static/dynamic) weight_type: Literal["uint8", "int8", "uint16", "int16"] = "uint8" activation_type: Literal["uint8", "int8", "uint16", "int16"] = "uint8" @@ -142,6 +150,8 @@ def to_dict(self) -> dict: result["model_name"] = self.model_name if self.dataset_name is not None: result["dataset_name"] = self.dataset_name + if self.model_type is not None: + result["model_type"] = self.model_type if self.mode == "rtn": result["rtn_bits"] = self.rtn_bits result["rtn_block_size"] = self.rtn_block_size @@ -174,6 +184,7 @@ def from_dict(cls, data: dict) -> WinMLQuantizationConfig: task=data.get("task"), model_name=data.get("model_name"), dataset_name=data.get("dataset_name"), + model_type=data.get("model_type"), weight_type=data.get("weight_type", "uint8"), activation_type=data.get("activation_type", "uint8"), per_channel=data.get("per_channel", False), diff --git a/src/winml/modelkit/quant/quantizer.py b/src/winml/modelkit/quant/quantizer.py index eb340d962..3dab021a0 100644 --- a/src/winml/modelkit/quant/quantizer.py +++ b/src/winml/modelkit/quant/quantizer.py @@ -101,6 +101,23 @@ def _quantize_single_pass( warnings: list[str] = [] try: + # Model-type-specific quant policy. Some model types finalize their + # scheme (calibration reader / nodes-to-exclude / fixed dtypes + mode) + # only once the exported ONNX exists. Resolve it from the quant registry + # keyed on ``config.model_type`` and apply it here — a single seam shared + # by every caller (CLI build, library build, standalone quantize) instead + # of duplicated dispatch at each call site. Unregistered types (and a + # caller-supplied ``calibration_data``) leave the config untouched, so + # the quantizer falls back to its default task-aware reader. + if config.model_type and config.calibration_data is None: + from .calibration import get_quant_finalizer + + finalizer = get_quant_finalizer(config.model_type) + if finalizer is not None: + config = finalizer.finalize( + config, onnx_path=model_path, model_id=config.model_name + ) + # Dispatch to the appropriate single-mode handler _mode_handlers: dict[str, Callable[..., QuantizeResult]] = { "fp16": _quantize_fp16, diff --git a/tests/unit/commands/test_build.py b/tests/unit/commands/test_build.py index 87d547291..5fc1e520f 100644 --- a/tests/unit/commands/test_build.py +++ b/tests/unit/commands/test_build.py @@ -1768,28 +1768,27 @@ def test_load_model_receives_config_model_type( @patch("winml.modelkit.commands.build._run_compile_stage") @patch("winml.modelkit.commands.build._run_quantize_stage") - @patch("winml.modelkit.quant.get_quant_finalizer") @patch("winml.modelkit.commands.build._run_optimize_stage") @patch("winml.modelkit.commands.build._show_io") @patch("winml.modelkit.utils.console.StageLive") @patch("winml.modelkit.export.export_onnx") @patch("winml.modelkit.build.hf._load_model") - def test_quant_finalizer_applied_for_registered_model_type( + def test_quant_model_type_carried_into_quantize_stage( self, mock_load_model: MagicMock, mock_export_onnx: MagicMock, mock_stage_live: MagicMock, mock_show_io: MagicMock, mock_optimize: MagicMock, - mock_get_finalizer: MagicMock, mock_quantize: MagicMock, mock_compile: MagicMock, tmp_path: Path, ) -> None: - """The CLI HF pipeline must apply the registered quant finalizer. + """The CLI HF pipeline must hand the model_type to the quantize stage. - Mirrors build.hf.build_hf_model: without this the CLI quantizes with the - default task-aware scheme instead of the model-type-specific policy. + The model-type-specific quant policy is resolved inside ``quantize_onnx`` + from ``config.quant.model_type``; the pipeline is responsible for carrying + the resolved variant onto the quant config so the policy fires. """ from winml.modelkit.commands.build import _build_hf_pipeline @@ -1803,11 +1802,6 @@ def test_quant_finalizer_applied_for_registered_model_type( optimized = tmp_path / "optimized.onnx" mock_optimize.return_value = (optimized, None) - finalized_quant = MagicMock(name="finalized_quant_config") - finalizer = MagicMock() - finalizer.finalize.return_value = finalized_quant - mock_get_finalizer.return_value = finalizer - # Stop right after the quantize stage so we don't exercise compile. mock_quantize.side_effect = RuntimeError("stop-after-quantize") @@ -1816,7 +1810,8 @@ def test_quant_finalizer_applied_for_registered_model_type( config.loader.task = "text2text-generation" config.loader.model_class = None config.export = MagicMock() - config.quant = MagicMock(name="initial_quant_config") + config.quant = MagicMock(name="quant_config") + config.quant.model_type = None config.to_dict.return_value = {} with pytest.raises(RuntimeError, match="stop-after-quantize"): @@ -1832,11 +1827,10 @@ def test_quant_finalizer_applied_for_registered_model_type( preloaded_hf_config=None, ) - mock_get_finalizer.assert_called_once_with("qwen3_transformer_only") - finalizer.finalize.assert_called_once() - assert finalizer.finalize.call_args.kwargs["model_id"] == "Qwen/Qwen3-0.6B" - # config.quant must be replaced with the finalized scheme before quantize. - assert config.quant is finalized_quant + # The resolved variant must be carried onto the quant config so that + # quantize_onnx can resolve + apply the model-type-specific policy. + assert config.quant.model_type == "qwen3_transformer_only" + mock_quantize.assert_called_once() class TestBuildEpResolution: diff --git a/tests/unit/config/test_build.py b/tests/unit/config/test_build.py index ce7426ccd..c0d569515 100644 --- a/tests/unit/config/test_build.py +++ b/tests/unit/config/test_build.py @@ -1055,7 +1055,7 @@ def test_model_type_with_task( mock_model_class: MagicMock, mock_export_config: WinMLExportConfig, ) -> None: - """model_type + task: overrides hf_config.model_type, uses given task.""" + """model_type + task: threads variant model_type through, uses given task.""" gpt2_loader_config = WinMLLoaderConfig( task="text-generation", model_class="GPT2LMHeadModel", diff --git a/tests/unit/loader/test_load_hf_model.py b/tests/unit/loader/test_load_hf_model.py index b2a53bb30..0c819d79f 100644 --- a/tests/unit/loader/test_load_hf_model.py +++ b/tests/unit/loader/test_load_hf_model.py @@ -116,7 +116,7 @@ def test_model_class_without_user_script_uses_tasks_manager(self, monkeypatch): # Track calls to resolve_task resolve_calls = [] - def mock_resolve(config, *, task=None, model_class=None): + def mock_resolve(config, *, task=None, model_class=None, model_type_override=None): resolve_calls.append({"task": task, "model_class": model_class}) mock_class = MagicMock() mock_class.__name__ = "MockModel" @@ -158,7 +158,7 @@ def test_auto_detect_when_no_model_class(self, monkeypatch): # Track calls to resolve_task resolve_calls = [] - def mock_resolve(config, *, task=None, model_class=None): + def mock_resolve(config, *, task=None, model_class=None, model_type_override=None): resolve_calls.append({"task": task, "model_class": model_class}) mock_class = MagicMock() mock_class.__name__ = "AutoDetectedModel" @@ -195,7 +195,7 @@ def test_bert_tiny_uses_model_specific_default_task(self, monkeypatch): resolve_calls = [] - def mock_resolve(config, *, task=None, model_class=None): + def mock_resolve(config, *, task=None, model_class=None, model_type_override=None): resolved_task = task or "feature-extraction" resolve_calls.append({"task": resolved_task, "model_class": model_class}) mock_class = MagicMock() diff --git a/tests/unit/loader/test_resolve_loader_config.py b/tests/unit/loader/test_resolve_loader_config.py index 491af63ce..65850bd32 100644 --- a/tests/unit/loader/test_resolve_loader_config.py +++ b/tests/unit/loader/test_resolve_loader_config.py @@ -146,8 +146,10 @@ def test_explicit_model_type_overrides_hf_config(self) -> None: """An explicit model_type (with a model_id) overrides the resolved type. Needed so a variant model_type such as ``qwen3_transformer_only`` selects - the variant rather than the architecture's native type. The override only - applies when a model_id is present and the requested type differs. + the variant rather than the architecture's native type. The override is + threaded into resolution as ``model_type_override`` and surfaces on the + loader config WITHOUT mutating the loaded HF config — export/patcher + consumers must keep seeing the native type. """ mock_config = MagicMock() mock_config.model_type = "original_type" @@ -163,16 +165,18 @@ def test_explicit_model_type_overrides_hf_config(self) -> None: patch( "winml.modelkit.loader.resolution.resolve_task", return_value=_make_resolution("text-generation", mock_class), - ), + ) as mock_resolve, ): loader_config, hf_config, _, _resolution = resolve_loader_config( "some-model", model_type="gpt2", task="text-generation" ) - # The explicit model_type wins over the architecture's native type. - assert hf_config.model_type == "gpt2" - # loader_config.model_type reflects the overridden type. + # The loaded HF config is NOT mutated — it keeps its native type. + assert hf_config.model_type == "original_type" + # loader_config.model_type reflects the overridden (variant) type. assert loader_config.model_type == "gpt2" + # The override is threaded into resolve_task rather than mutated in place. + assert mock_resolve.call_args.kwargs.get("model_type_override") == "gpt2" def test_auto_detect_task_from_model_type(self) -> None: """model_type without task auto-detects first supported task.""" diff --git a/tests/unit/test_quantizer.py b/tests/unit/test_quantizer.py index 927956c87..995a4cc8a 100644 --- a/tests/unit/test_quantizer.py +++ b/tests/unit/test_quantizer.py @@ -143,3 +143,96 @@ def fake_quantize(*, model_input, model_output: str, quant_config) -> None: assert result.success is True assert extra_suffix_sidecar.exists() + + +def test_quantize_onnx_applies_model_type_finalizer( + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, +) -> None: + """A registered model_type finalizer is resolved + applied before dispatch. + + The model-type-specific quant policy used to be dispatched at each call site + (CLI build, library build). It now lives behind a single seam in + quantize_onnx, keyed on ``config.model_type``: the finalizer is resolved from + the calibration registry and its returned config is what the mode handler + receives. + """ + import winml.modelkit.quant.calibration as calibration_mod + import winml.modelkit.quant.quantizer as quantizer_mod + + model_path = tmp_path / "model.onnx" + model_path.write_text("input") + output_path = tmp_path / "quantized.onnx" + + finalized_config = WinMLQuantizationConfig( + model_type="dummy_type", + calibration_data=_FakeCalibrationReader(), + ) + + finalize_calls: list[dict[str, Any]] = [] + + class _StubFinalizer: + def finalize(self, config, *, onnx_path, model_id): # type: ignore[no-untyped-def] + finalize_calls.append({"config": config, "onnx_path": onnx_path, "model_id": model_id}) + return finalized_config + + monkeypatch.setattr(calibration_mod, "get_quant_finalizer", lambda model_type: _StubFinalizer()) + + handler_calls: list[WinMLQuantizationConfig] = [] + + def _fake_qdq(*, config, **_kwargs): # type: ignore[no-untyped-def] + handler_calls.append(config) + return SimpleNamespace(success=True, output_path=output_path, errors=[]) + + monkeypatch.setattr(quantizer_mod, "_quantize_qdq", _fake_qdq) + + result = quantize_onnx( + model_path, + output_path=output_path, + config=WinMLQuantizationConfig( + model_type="dummy_type", + model_name="some/model-id", + ), + ) + + assert result.success is True + # Finalizer was resolved + invoked with the exported graph + model id. + assert len(finalize_calls) == 1 + assert finalize_calls[0]["onnx_path"] == model_path + assert finalize_calls[0]["model_id"] == "some/model-id" + # The handler ran against the finalized config, not the original. + assert handler_calls == [finalized_config] + + +def test_quantize_onnx_skips_finalizer_when_calibration_data_provided( + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, +) -> None: + """A caller-supplied calibration reader bypasses the model_type finalizer.""" + import winml.modelkit.quant.calibration as calibration_mod + import winml.modelkit.quant.quantizer as quantizer_mod + + model_path = tmp_path / "model.onnx" + model_path.write_text("input") + output_path = tmp_path / "quantized.onnx" + + def _boom(_model_type): # type: ignore[no-untyped-def] + raise AssertionError("finalizer must not be resolved when calibration_data is set") + + monkeypatch.setattr(calibration_mod, "get_quant_finalizer", _boom) + + def _fake_qdq(*, config, **_kwargs): # type: ignore[no-untyped-def] + return SimpleNamespace(success=True, output_path=output_path, errors=[]) + + monkeypatch.setattr(quantizer_mod, "_quantize_qdq", _fake_qdq) + + result = quantize_onnx( + model_path, + output_path=output_path, + config=WinMLQuantizationConfig( + model_type="dummy_type", + calibration_data=_FakeCalibrationReader(), + ), + ) + + assert result.success is True From da6006b8acda2c57e0efacd4af828be841763f04 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Fri, 26 Jun 2026 20:07:09 +0800 Subject: [PATCH 17/17] Make qwen3 transformer-only subpackage mypy-clean for whole-package type check --- .../models/hf/qwen3/qwen3_export_ops.py | 60 ++++++------ .../models/hf/qwen3/qwen3_modeling.py | 96 +++++++++++++------ .../models/hf/qwen3/qwen_transformer_only.py | 34 ++++--- 3 files changed, 118 insertions(+), 72 deletions(-) diff --git a/src/winml/modelkit/models/hf/qwen3/qwen3_export_ops.py b/src/winml/modelkit/models/hf/qwen3/qwen3_export_ops.py index aed592fa7..5042b4aaf 100644 --- a/src/winml/modelkit/models/hf/qwen3/qwen3_export_ops.py +++ b/src/winml/modelkit/models/hf/qwen3/qwen3_export_ops.py @@ -35,7 +35,7 @@ class LpNormOnnxExport(torch.autograd.Function): """RMSNorm body → ONNX ``LpNormalization`` (p=2 along last dim).""" @staticmethod - def symbolic(g, input, axis, p) -> Any: + def symbolic(g: Any, input: Any, axis: Any, p: Any) -> Any: """Emit the ONNX ``LpNormalization`` node during export.""" output_type = input.type().with_sizes(symbolic_helper._get_tensor_sizes(input)) output = g.op( @@ -47,7 +47,7 @@ def symbolic(g, input, axis, p) -> Any: return output.setType(output_type) @staticmethod - def forward(ctx, input, axis, p) -> Any: + def forward(ctx: Any, input: Any, axis: Any, p: Any) -> Any: """Real ``LpNormalization`` (``input / ||input||_p`` along ``axis``). The exported node comes from ``symbolic``; this eager body computes the @@ -64,19 +64,19 @@ class GroupQueryAttentionOnnxExport(torch.autograd.Function): @staticmethod def symbolic( - g, - query, - key, - value, - past_key, - past_value, - seqlens_k, - total_sequence_length, - cos_cache, - sin_cache, - do_rotary, - kv_num_heads, - num_heads, + g: Any, + query: Any, + key: Any, + value: Any, + past_key: Any, + past_value: Any, + seqlens_k: Any, + total_sequence_length: Any, + cos_cache: Any, + sin_cache: Any, + do_rotary: Any, + kv_num_heads: Any, + num_heads: Any, ) -> Any: """Emit the fused ``com.microsoft::GroupQueryAttention`` node.""" args = [ @@ -111,19 +111,19 @@ def symbolic( @staticmethod def forward( - ctx, - query, - key, - value, - past_key, - past_value, - seqlens_k, - total_sequence_length, - cos_cache, - sin_cache, - do_rotary, - kv_num_heads, - num_heads, + ctx: Any, + query: Any, + key: Any, + value: Any, + past_key: Any, + past_value: Any, + seqlens_k: Any, + total_sequence_length: Any, + cos_cache: Any, + sin_cache: Any, + do_rotary: Any, + kv_num_heads: Any, + num_heads: Any, ) -> Any: """Shape-only tracing placeholder; returns a stand-in ``(output, KV)``. @@ -155,8 +155,8 @@ def __init__( self, in_channels: int, out_channels: int, - weight: torch.nn.Parameter, - bias: torch.nn.Parameter | None = None, + weight: torch.Tensor, + bias: torch.Tensor | None = None, ) -> None: super().__init__() # Linear weight is (out, in); Conv2d weight is (out, in, 1, 1). diff --git a/src/winml/modelkit/models/hf/qwen3/qwen3_modeling.py b/src/winml/modelkit/models/hf/qwen3/qwen3_modeling.py index f5207d797..140c2dac8 100644 --- a/src/winml/modelkit/models/hf/qwen3/qwen3_modeling.py +++ b/src/winml/modelkit/models/hf/qwen3/qwen3_modeling.py @@ -25,7 +25,7 @@ from __future__ import annotations -from typing import Any +from typing import Any, cast import torch import torch.nn as nn @@ -40,6 +40,10 @@ class WinMLQwen3RMSNorm(nn.Module): """RMSNorm export variant — ``onnx::LpNormalization`` body.""" + # Bound at runtime onto a live ``Qwen3RMSNorm`` module; declared so the + # type checker knows the attribute these methods rely on. + weight: torch.Tensor + def prepare_for_onnx_export(self) -> None: """Fold the RMSNorm gain into the weight (LpNorm has unit gain).""" # Pre-multiply the gain into the weight (LpNorm has unit gain). @@ -52,39 +56,68 @@ def prepare_for_onnx_export(self) -> None: def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: """Apply the LpNormalization-based RMSNorm body.""" - out = LpNormOnnxExport.apply(hidden_states, -1, 2) + out = cast("torch.Tensor", LpNormOnnxExport.apply(hidden_states, -1, 2)) return self.weight * out class WinMLQwen3MLP(nn.Module): """MLP export variant — 1x1 Conv projections (forward unchanged).""" + # Bound at runtime onto a live ``Qwen3MLP`` module; declared so the type + # checker has a non-circular type for the projections these methods swap. + gate_proj: nn.Module + up_proj: nn.Module + down_proj: nn.Module + def prepare_for_onnx_export(self, *, matmul_to_conv: bool) -> None: """Optionally swap the MLP's linear projections for 1x1 convs.""" if not matmul_to_conv: return - self.gate_proj = TransposeConv2d1x1Transpose.from_linear_module(self.gate_proj) - self.up_proj = TransposeConv2d1x1Transpose.from_linear_module(self.up_proj) - self.down_proj = TransposeConv2d1x1Transpose.from_linear_module(self.down_proj) + self.gate_proj = TransposeConv2d1x1Transpose.from_linear_module( + cast("nn.Linear", self.gate_proj) + ) + self.up_proj = TransposeConv2d1x1Transpose.from_linear_module( + cast("nn.Linear", self.up_proj) + ) + self.down_proj = TransposeConv2d1x1Transpose.from_linear_module( + cast("nn.Linear", self.down_proj) + ) class WinMLQwen3Attention(nn.Module): """Attention export variant — fused ``GroupQueryAttention`` op.""" + # Bound at runtime onto a live ``Qwen3Attention`` module; declared so the + # type checker knows the attributes these methods rely on. + config: Any + head_dim: int + q_proj: nn.Module + k_proj: nn.Module + v_proj: nn.Module + o_proj: nn.Module + def prepare_for_onnx_export(self, *, matmul_to_conv: bool) -> None: """Optionally swap the Q/K/V/O projections for 1x1 convs.""" if matmul_to_conv: - self.q_proj = TransposeConv2d1x1Transpose.from_linear_module(self.q_proj) - self.k_proj = TransposeConv2d1x1Transpose.from_linear_module(self.k_proj) - self.v_proj = TransposeConv2d1x1Transpose.from_linear_module(self.v_proj) - self.o_proj = TransposeConv2d1x1Transpose.from_linear_module(self.o_proj) + self.q_proj = TransposeConv2d1x1Transpose.from_linear_module( + cast("nn.Linear", self.q_proj) + ) + self.k_proj = TransposeConv2d1x1Transpose.from_linear_module( + cast("nn.Linear", self.k_proj) + ) + self.v_proj = TransposeConv2d1x1Transpose.from_linear_module( + cast("nn.Linear", self.v_proj) + ) + self.o_proj = TransposeConv2d1x1Transpose.from_linear_module( + cast("nn.Linear", self.o_proj) + ) self._matmul_to_conv = matmul_to_conv def forward( self, hidden_states: torch.Tensor, past_key_value: tuple[torch.Tensor, torch.Tensor] | None = None, - past_seq_len: torch.Tensor | None = None, + past_seq_len: torch.Tensor | int | None = None, total_seq_len: torch.Tensor | None = None, **kwargs: Any, ) -> tuple[torch.Tensor, None, tuple[torch.Tensor, torch.Tensor]]: @@ -95,8 +128,8 @@ def forward( input_shape = hidden_states.shape[1:-1] hidden_shape = (*input_shape, -1, self.head_dim) - query_states = self.q_norm(query_states.view(hidden_shape)) - key_states = self.k_norm(key_states.view(hidden_shape)) + query_states = cast("nn.Module", self.q_norm)(query_states.view(hidden_shape)) + key_states = cast("nn.Module", self.k_norm)(key_states.view(hidden_shape)) num_heads = self.config.num_attention_heads num_kv_heads = self.config.num_key_value_heads @@ -108,6 +141,7 @@ def forward( if self._matmul_to_conv: value_states = value_states.squeeze(0) + assert past_key_value is not None past_keys, past_values = past_key_value # GroupQueryAttention requires Q/K/V/past_K/past_V to share dtype. @@ -119,7 +153,7 @@ def forward( key_states = key_states.to(kv_dtype) value_states = value_states.to(kv_dtype) - cos, sin = self.rotary_emb( + cos, sin = cast("nn.Module", self.rotary_emb)( value_states, torch.arange(self.config.max_position_embeddings).unsqueeze(0), ) @@ -171,11 +205,11 @@ def forward( total_seq_len: torch.Tensor | None = None, use_cache: bool = True, **kwargs: Any, - ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: + ) -> tuple[Any, ...]: """Run the decoder layer (attention + MLP) with residual adds.""" residual = hidden_states - hidden_states = self.input_layernorm(hidden_states) - attn_out, _, present_kv = self.self_attn( + hidden_states = cast("nn.Module", self.input_layernorm)(hidden_states) + attn_out, _, present_kv = cast("nn.Module", self.self_attn)( hidden_states=hidden_states, past_key_value=past_key_value, past_seq_len=past_seq_len, @@ -184,11 +218,11 @@ def forward( hidden_states = residual + attn_out residual = hidden_states - hidden_states = self.post_attention_layernorm(hidden_states) - hidden_states = self.mlp(hidden_states) + hidden_states = cast("nn.Module", self.post_attention_layernorm)(hidden_states) + hidden_states = cast("nn.Module", self.mlp)(hidden_states) hidden_states = residual + hidden_states - outputs = (hidden_states,) + outputs: tuple[Any, ...] = (hidden_states,) if use_cache: outputs += (present_kv,) return outputs @@ -215,7 +249,7 @@ def forward( hidden_states = hidden_states.unsqueeze(0) # NHWC for Conv path present_kvs: tuple[tuple[torch.Tensor, torch.Tensor], ...] = () - for idx, layer in enumerate(self.layers): + for idx, layer in enumerate(cast("nn.ModuleList", self.layers)): out = layer( hidden_states, past_key_value=past_key_values[idx], @@ -227,7 +261,7 @@ def forward( if use_cache: present_kvs += (out[1],) - hidden_states = self.norm(hidden_states) + hidden_states = cast("nn.Module", self.norm)(hidden_states) if self._matmul_to_conv: hidden_states = hidden_states.squeeze(0) return hidden_states, present_kvs @@ -260,7 +294,7 @@ def apply_transformer_only_export_prep( (e.g. the stock HF class names changed). """ - def _bind(module: nn.Module, owner: type) -> None: + def _bind(module: nn.Module, owner: type[nn.Module]) -> None: module.forward = owner.forward.__get__(module, type(module)) # Identify Qwen3 submodules by their (stock HF) class name so we don't @@ -280,18 +314,22 @@ def _is(module: nn.Module, name: str) -> bool: # in input/post_attention layernorms). for mod in causal_lm.modules(): if _is(mod, "Qwen3RMSNorm"): - WinMLQwen3RMSNorm.prepare_for_onnx_export(mod) + WinMLQwen3RMSNorm.prepare_for_onnx_export(cast("WinMLQwen3RMSNorm", mod)) _bind(mod, WinMLQwen3RMSNorm) patched["Qwen3RMSNorm"] += 1 for mod in causal_lm.modules(): if _is(mod, "Qwen3Attention"): - WinMLQwen3Attention.prepare_for_onnx_export(mod, matmul_to_conv=matmul_to_conv) + WinMLQwen3Attention.prepare_for_onnx_export( + cast("WinMLQwen3Attention", mod), matmul_to_conv=matmul_to_conv + ) _bind(mod, WinMLQwen3Attention) patched["Qwen3Attention"] += 1 elif _is(mod, "Qwen3MLP"): # MLP forward is unchanged; only the projections are swapped to Conv. - WinMLQwen3MLP.prepare_for_onnx_export(mod, matmul_to_conv=matmul_to_conv) + WinMLQwen3MLP.prepare_for_onnx_export( + cast("WinMLQwen3MLP", mod), matmul_to_conv=matmul_to_conv + ) patched["Qwen3MLP"] += 1 # HF moved ``rotary_emb`` from ``Qwen3Attention`` up to ``Qwen3Model``; @@ -299,8 +337,8 @@ def _is(module: nn.Module, name: str) -> bool: # so re-attach a reference from the parent model. for mod in causal_lm.modules(): if _is(mod, "Qwen3Model") and hasattr(mod, "rotary_emb"): - for layer in mod.layers: - layer.self_attn.rotary_emb = mod.rotary_emb + for layer in cast("nn.ModuleList", mod.layers): + cast("nn.Module", layer.self_attn).rotary_emb = mod.rotary_emb for mod in causal_lm.modules(): if _is(mod, "Qwen3DecoderLayer"): @@ -309,7 +347,9 @@ def _is(module: nn.Module, name: str) -> bool: for mod in causal_lm.modules(): if _is(mod, "Qwen3Model"): - WinMLQwen3Model.prepare_for_onnx_export(mod, matmul_to_conv=matmul_to_conv) + WinMLQwen3Model.prepare_for_onnx_export( + cast("WinMLQwen3Model", mod), matmul_to_conv=matmul_to_conv + ) _bind(mod, WinMLQwen3Model) patched["Qwen3Model"] += 1 diff --git a/src/winml/modelkit/models/hf/qwen3/qwen_transformer_only.py b/src/winml/modelkit/models/hf/qwen3/qwen_transformer_only.py index fc26e6070..46cd13a6f 100644 --- a/src/winml/modelkit/models/hf/qwen3/qwen_transformer_only.py +++ b/src/winml/modelkit/models/hf/qwen3/qwen_transformer_only.py @@ -28,18 +28,19 @@ from __future__ import annotations -from typing import Any, ClassVar +from typing import Any, ClassVar, cast import torch import torch.nn as nn from optimum.exporters.onnx import OnnxConfig from optimum.utils import NormalizedConfig from optimum.utils.input_generators import DummyInputGenerator -from transformers import AutoModelForCausalLM +from transformers import AutoModelForCausalLM, PretrainedConfig from ....config import WinMLBuildConfig from ....export import register_onnx_overwrite from ....export.config import WinMLExportConfig +from ....optim import WinMLOptimizationConfig from ...winml import register_specialization from ...winml.composite_model import WinMLCompositeModel, register_composite_model from ...winml.kv_cache import WinMLSlidingWindowCache @@ -73,7 +74,7 @@ def __init__(self, model: nn.Module, num_layers: int) -> None: super().__init__() self.model = model self.num_layers = num_layers - self.config = model.config + self.config: PretrainedConfig = cast("PretrainedConfig", model.config) apply_transformer_only_export_prep(model, matmul_to_conv=True) # Tag the config so the exporter resolves this variant's OnnxConfig # (registered under ``TRANSFORMER_ONLY_MODEL_TYPE``) rather than the @@ -111,7 +112,7 @@ def forward(self, *args: torch.Tensor) -> tuple[torch.Tensor, ...]: past_key_values = [(kv_args[2 * i], kv_args[2 * i + 1]) for i in range(self.num_layers)] - hidden_states, present_kvs = self.model.model( + hidden_states, present_kvs = cast("nn.Module", self.model.model)( inputs_embeds=input_hidden_states, past_key_values=past_key_values, past_seq_len=past_seq_len, @@ -130,7 +131,7 @@ def forward(self, *args: torch.Tensor) -> tuple[torch.Tensor, ...]: # ============================================================================= -class _TransformerOnlyHiddenStateGenerator(DummyInputGenerator): +class _TransformerOnlyHiddenStateGenerator(DummyInputGenerator): # type: ignore[misc] # optimum base is untyped """Generates ``input_hidden_states`` (FP32, ``[1, seq_len, hidden]``).""" SUPPORTED_INPUT_NAMES = ("input_hidden_states",) @@ -147,7 +148,11 @@ def __init__( ) -> None: self.batch_size = batch_size self.hidden_size = normalized_config.hidden_size - self.seq_len = seq_len or getattr(normalized_config, "seq_len", self._default_seq_len) + self.seq_len = ( + int(seq_len) + if seq_len + else int(getattr(normalized_config, "seq_len", self._default_seq_len)) + ) def generate( self, @@ -165,7 +170,7 @@ class _TransformerOnlyHiddenStatePrefillGenerator(_TransformerOnlyHiddenStateGen _default_seq_len = 64 -class _TransformerOnlySeqLenGenerator(DummyInputGenerator): +class _TransformerOnlySeqLenGenerator(DummyInputGenerator): # type: ignore[misc] # optimum base is untyped """Generates ``past_seq_len`` (INT32 ``[1,1]``) and ``total_seq_len`` (INT32 ``[1]``).""" SUPPORTED_INPUT_NAMES = ("past_seq_len", "total_seq_len") @@ -187,10 +192,10 @@ def generate( raise ValueError(f"Unknown input: {input_name}") -class _TransformerOnlyKvCacheGenerator(DummyInputGenerator): +class _TransformerOnlyKvCacheGenerator(DummyInputGenerator): # type: ignore[misc] # optimum base is untyped """Generates ``past_keys_{i}`` / ``past_values_{i}`` (FP16).""" - SUPPORTED_INPUT_NAMES = () # built dynamically in __init__ + SUPPORTED_INPUT_NAMES: tuple[str, ...] = () # built dynamically in __init__ def __init__( self, @@ -265,7 +270,7 @@ def _transformer_only_outputs( @register_onnx_overwrite( TRANSFORMER_ONLY_MODEL_TYPE, "feature-extraction", library_name="transformers" ) -class QwenTransformerOnlyPrefillIOConfig(OnnxConfig): +class QwenTransformerOnlyPrefillIOConfig(OnnxConfig): # type: ignore[misc] # optimum base is untyped """Prefill (seq=64) — transformer-only I/O.""" NORMALIZED_CONFIG_CLASS = _QWEN_TRANSFORMER_ONLY_NORMALIZED @@ -289,7 +294,7 @@ def outputs(self) -> dict[str, dict[int, str]]: @register_onnx_overwrite( TRANSFORMER_ONLY_MODEL_TYPE, "text2text-generation", library_name="transformers" ) -class QwenTransformerOnlyGenIOConfig(OnnxConfig): +class QwenTransformerOnlyGenIOConfig(OnnxConfig): # type: ignore[misc] # optimum base is untyped """Generation (seq=1) — transformer-only I/O.""" NORMALIZED_CONFIG_CLASS = _QWEN_TRANSFORMER_ONLY_NORMALIZED @@ -317,8 +322,9 @@ def outputs(self) -> dict[str, dict[int, str]]: QWEN_TRANSFORMER_ONLY_CONFIG = WinMLBuildConfig( export=WinMLExportConfig(dynamo=False, opset_version=18), - # Pure graph (no post-export RMSNorm fusion / matmul-add fusion). - optim=None, + # Pure graph (no post-export RMSNorm fusion / matmul-add fusion): the default + # WinMLOptimizationConfig() leaves every fusion flag off. + optim=WinMLOptimizationConfig(), ) @@ -362,7 +368,7 @@ def get_cache_class(cls) -> type: return WinMLSlidingWindowCache @classmethod - def from_pretrained( # type: ignore[override] + def from_pretrained( cls, model_id: str, task: str = "text-generation",