Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
3796b7e
Add quantization for transformers for qwen0.6B
spalne Jun 8, 2026
1ee316c
Quantize transformer-only with fused GQA + GSM8k calibration
spalne Jun 16, 2026
78815fd
Fix Qwen3 w8a16 quant: symmetric int8 weights + exclude GQA from QDQ
spalne Jun 22, 2026
95d45d9
refactor(qwen): register transformer-only path as a declarative model…
spalne Jun 22, 2026
9cecb03
fix(qwen): calibrate transformer-only decode model on real trajectory
spalne Jun 23, 2026
08f05d7
Fixed small bugs
spalne Jun 23, 2026
818cfe4
refactor(qwen): config-driven transformer-only quant + pytest
github-actions[bot] Jun 24, 2026
362c778
Merge remote-tracking branch 'origin/main' into feature/qwen3-quant
github-actions[bot] Jun 24, 2026
a7f518e
fix(qwen): clean lint + persist finalized quant config + guard dynami…
github-actions[bot] Jun 24, 2026
caada38
fix(qwen): address review comments (LpNorm eager norm, CodeQL lint, e…
github-actions[bot] Jun 24, 2026
c97373c
fix(build): resolve quant-finalize hook on model class + update model…
github-actions[bot] Jun 24, 2026
752e6c9
refactor(quant): move qwen3 calibration logic into quant registry
github-actions[bot] Jun 24, 2026
e9dbe2a
fix(quant): satisfy mypy + CodeQL on calibration registry
github-actions[bot] Jun 24, 2026
5274563
Thread model_type + quant finalizer through CLI HF build; move qwen3 …
github-actions[bot] Jun 24, 2026
bfc831b
Make transformer-only composite handle returnable + add one-shot expo…
github-actions[bot] Jun 25, 2026
0cb9057
Merge remote-tracking branch 'origin/main' into feature/qwen3-quant
github-actions[bot] Jun 25, 2026
60d2dfc
merge: resolve conflicts with main (#872 precision-driven quant)
github-actions[bot] Jun 25, 2026
cf1c336
Merge remote-tracking branch 'origin/main' into feature/qwen3-quant
github-actions[bot] Jun 25, 2026
4049631
test(qwen3): fix NPU quant test EP detection and decoder path lookup
spalne Jun 25, 2026
44a68d4
Address PR review: dedup quant finalizer dispatch, plain finalizer re…
github-actions[bot] Jun 26, 2026
565d792
Merge remote-tracking branch 'origin/main' into feature/qwen3-quant
github-actions[bot] Jun 26, 2026
da6006b
Make qwen3 transformer-only subpackage mypy-clean for whole-package t…
github-actions[bot] Jun 26, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
171 changes: 171 additions & 0 deletions scripts/export_qwen3_transformer_only.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,171 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
"""One-shot export of the Qwen3 transformer-only prefill + decode pair.

Leverages the registered ``WinMLQwen3TransformerOnlyModel`` composite to build
BOTH transformer-only sub-models in a single call:

- ``decoder_prefill`` — context graph, ``seq_len`` = --prefill-seq-len (64)
- ``decoder_gen`` — iteration graph, ``seq_len`` = 1

Each sub-model is built through the standard ``build_hf_model`` pipeline, so the
model-type quant finalizer is applied (int8 weight / uint16 activation, GQA
excluded from QDQ). Embeddings and the LM head are NOT part of this graph — they
run separately (e.g. from the bundle).

Usage::

# Build (or reuse cached) both ONNX, print their paths + node summary:
uv run python scripts/export_qwen3_transformer_only.py

# Copy the two ONNX (with external data) into a folder:
uv run python scripts/export_qwen3_transformer_only.py --output-dir out/qwen3

# Different model / device / cache geometry, force a rebuild:
uv run python scripts/export_qwen3_transformer_only.py \
--model-id Qwen/Qwen3-0.6B --device npu \
--max-cache-len 256 --prefill-seq-len 64 --force-rebuild
"""

from __future__ import annotations

import argparse
import collections
import sys
import time
from pathlib import Path

import onnx

from winml.modelkit.models.hf.qwen3.qwen_transformer_only import (
WinMLQwen3TransformerOnlyModel,
)
from winml.modelkit.onnx import copy_onnx_model


# Component name -> output file stem used when --output-dir is given.
# Looked up with ``.get(name, name)`` when copying, so a component that is not
# listed here is copied under its own component name.
_OUTPUT_STEMS: dict[str, str] = {
    "decoder_prefill": "prefill",
    "decoder_gen": "decode",
}

# Default EP per device; CPU/NPU/GPU map to their canonical providers.
# The keys double as the accepted ``--device`` choices, and the selected value
# is forwarded as ``ep=`` to the composite model build.
_DEVICE_TO_EP: dict[str, str] = {
    "cpu": "CPUExecutionProvider",
    "npu": "QNNExecutionProvider",
    "gpu": "DmlExecutionProvider",
}


def node_summary(path: str | Path) -> str:
    """Return a one-line QDQ/GQA structural summary of an ONNX graph.

    The summary reports how many ``QuantizeLinear`` / ``DequantizeLinear`` /
    ``GroupQueryAttention`` nodes the graph contains, plus how many Q/DQ nodes
    share a tensor name with any ``GroupQueryAttention`` input or output.

    Args:
        path: Location of the ``.onnx`` file to inspect.

    Returns:
        A string of the form ``"Q=<n> DQ=<n> GQA=<n> QDQ-touching-GQA=<n>"``.
    """
    # Only graph structure is inspected, so external tensor data is not loaded.
    model = onnx.load(str(path), load_external_data=False)
    nodes = list(model.graph.node)
    counts = collections.Counter(n.op_type for n in nodes)

    # Every tensor name that appears as an input or output of a GQA node.
    gqa_io: set[str] = set()
    for node in nodes:
        if node.op_type == "GroupQueryAttention":
            gqa_io.update(node.input)
            gqa_io.update(node.output)
    # ONNX marks an omitted optional input/output with an empty name. That
    # placeholder is not a tensor, so it must never count as a shared edge
    # between a Q/DQ node and a GQA node.
    gqa_io.discard("")

    qdq_ops = ("QuantizeLinear", "DequantizeLinear")
    qdq_touching_gqa = sum(
        1
        for n in nodes
        if n.op_type in qdq_ops
        and not (gqa_io.isdisjoint(n.input) and gqa_io.isdisjoint(n.output))
    )
    return (
        f"Q={counts['QuantizeLinear']} DQ={counts['DequantizeLinear']} "
        f"GQA={counts['GroupQueryAttention']} QDQ-touching-GQA={qdq_touching_gqa}"
    )


def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
"""Parse command-line arguments."""
p = argparse.ArgumentParser(
description=__doc__,
formatter_class=argparse.RawDescriptionHelpFormatter,
)
p.add_argument("--model-id", default="Qwen/Qwen3-0.6B", help="HF model id or local path.")
p.add_argument(
"--device",
default="cpu",
choices=sorted(_DEVICE_TO_EP),
help="Target device (selects the canonical EP). Default: cpu.",
)
p.add_argument("--precision", default="w8a16", help="Build precision. Default: w8a16.")
p.add_argument("--max-cache-len", type=int, default=256, help="Static KV cache length.")
p.add_argument(
"--prefill-seq-len",
type=int,
default=64,
help="Prefill/context sequence length.",
)
p.add_argument(
"--no-compile",
dest="no_compile",
action="store_true",
default=True,
help="Skip EPContext compilation (default; transformer-only is consumed pre-compile).",
)
p.add_argument(
"--compile",
dest="no_compile",
action="store_false",
help="Enable EPContext compilation (requires the device's compiler/SDK).",
)
p.add_argument("--force-rebuild", action="store_true", help="Rebuild even if cached.")
p.add_argument(
"--output-dir",
type=Path,
default=None,
help="If set, copy the two ONNX (with external data) here as prefill.onnx / decode.onnx.",
)
return p.parse_args(argv)


def main(argv: list[str] | None = None) -> int:
    """Build (or reuse) both transformer-only ONNX models, then report and optionally copy them.

    Args:
        argv: Command-line arguments forwarded to ``parse_args``.

    Returns:
        Always ``0``; failures surface as exceptions from the build or copy calls.
    """
    args = parse_args(argv)

    # Per-component static shapes: both share the cache length, and only the
    # sequence length differs (prefill length vs. a single decode token).
    prefill_shape = {
        "max_cache_len": args.max_cache_len,
        "seq_len": args.prefill_seq_len,
    }
    decode_shape = {"max_cache_len": args.max_cache_len, "seq_len": 1}
    per_component = {
        "decoder_prefill": {"shape_config": prefill_shape},
        "decoder_gen": {"shape_config": decode_shape},
    }

    started = time.monotonic()
    model = WinMLQwen3TransformerOnlyModel.from_pretrained(
        args.model_id,
        device=args.device,
        precision=args.precision,
        ep=_DEVICE_TO_EP[args.device],
        no_compile=args.no_compile,
        use_cache=True,
        force_rebuild=args.force_rebuild,
        sub_model_kwargs=per_component,
    )
    elapsed = time.monotonic() - started

    print(f"\n=== transformer-only build done in {elapsed:.1f}s ===")

    copy_root: Path | None = args.output_dir
    if copy_root is not None:
        copy_root.mkdir(parents=True, exist_ok=True)

    for name, sub in model.sub_models.items():
        src = Path(sub.onnx_path)
        print(f"\n[{name}] {src}")
        print(f" {node_summary(src)}")
        if copy_root is None:
            # Report-only mode: nothing to copy.
            continue
        # Components without a stem mapping keep their component name.
        dst = copy_root / f"{_OUTPUT_STEMS.get(name, name)}.onnx"
        copy_onnx_model(src, dst)
        print(f" -> copied to {dst}")

    return 0


if __name__ == "__main__":
sys.exit(main())
12 changes: 12 additions & 0 deletions src/winml/modelkit/build/hf.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ def build_hf_model(
cache_key: str | None = None,
ep: EPNameOrAlias | None = None,
device: str | None = None,
model_type: str | None = None,
**kwargs: Any,
) -> BuildResult:
"""Build an ONNX model from a HuggingFace model architecture.
Expand Down Expand Up @@ -211,6 +212,7 @@ def _name(base: str) -> str:
model_id,
trust_remote_code,
random_init=random_init,
model_type=model_type,
)

# =========================================================================
Expand Down Expand Up @@ -315,6 +317,14 @@ def _name(base: str) -> str:
else:
logger.info("Quantizing model...")
t0 = time.monotonic()
# A model-type-specific quant policy (e.g. the qwen3_transformer_only
# w8a16 finalizer) is resolved and applied inside ``quantize_onnx``
# from ``config.quant.model_type``. Ensure it carries the resolved
# variant so hand-built configs (that skipped assemble_build_config)
# still trigger the right policy; ``quantize_onnx`` no-ops for
# model types without a registered finalizer.
if config.quant.model_type is None:
config.quant.model_type = config.loader.model_type
quant_result = quantize_onnx(
model_path=current_path,
output_path=quantized_path,
Expand Down Expand Up @@ -443,6 +453,7 @@ def _load_model(
trust_remote_code: bool,
random_init: bool = False,
hf_config: Any | None = None,
model_type: str | None = None,
) -> Any:
"""Load PyTorch model — pretrained or random weights.

Expand Down Expand Up @@ -518,6 +529,7 @@ def _load_model(
task=task,
trust_remote_code=effective_trust,
hf_config=hf_config,
model_type=model_type,
)
return pytorch_model

Expand Down
16 changes: 14 additions & 2 deletions src/winml/modelkit/commands/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -1394,7 +1394,11 @@ def _name(base: str) -> str:

# Load + export (blocking)
pytorch_model = _load_model(
config, model_id, trust_remote_code=False, hf_config=preloaded_hf_config
config,
model_id,
trust_remote_code=False,
hf_config=preloaded_hf_config,
model_type=config.loader.model_type,
)
t0 = time.monotonic()
# config.export is None only for the ONNX build path; this is the HF path.
Expand Down Expand Up @@ -1438,7 +1442,15 @@ def _name(base: str) -> str:
# Persist config after autoconf
config_path.write_text(json.dumps(config.to_dict(), indent=2))

# ── Quantize stage ──────
# ── Quantize stage ───────────────────────────────────────────
# A model-type-specific quant policy (e.g. the qwen3_transformer_only w8a16
# finalizer) is resolved and applied inside ``quantize_onnx`` from
# ``config.quant.model_type``; no per-call-site dispatch needed here. Carry
# the resolved variant onto the quant config so configs that were hand-built
# or loaded from JSON (skipping assemble_build_config) still trigger it.
if config.quant is not None and config.quant.model_type is None:
config.quant.model_type = config.loader.model_type

current_path = _run_quantize_stage(
config=config,
current_path=current_path,
Expand Down
4 changes: 4 additions & 0 deletions src/winml/modelkit/config/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -1071,6 +1071,10 @@ def _assemble_config(
model_type,
)
quant_config.model_id = model_id or model_type
# Carry the resolved model_type so quantize_onnx can resolve a
# model-type-specific quant policy (e.g. the qwen3_transformer_only
# w8a16 finalizer) from the exported graph.
quant_config.model_type = loader_config.model_type

return WinMLBuildConfig(
loader=loader_config,
Expand Down
32 changes: 31 additions & 1 deletion src/winml/modelkit/loader/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,9 +220,33 @@ def resolve_loader_config(
f"attribute. Cannot proceed with config generation."
)

# Explicit model_type override alongside a model_id: thread the requested
# variant through downstream resolution (task / class / composite tag and
# the loader config's model_type) WITHOUT mutating the loaded HF config. The
# exported graph, the htp Optimum patcher and every other consumer must keep
# seeing the architecture's native type; only the resolved build-variant tag
# changes. The model_type-only path above (AutoConfig.for_model) is
# unaffected because it only runs when model_id is None.
model_type_override = (
model_type
if (model_id is not None and model_type is not None and hf_config.model_type != model_type)
else None
)
if model_type_override is not None:
logger.info(
"Applying model_type override '%s' -> '%s' (explicit request)",
hf_config.model_type,
model_type_override,
)

# 2-3. Unified resolution. Task detection — including the no-architectures
# --model-type fallback (first supported task) — now lives in resolve_task.
resolution = resolve_task(hf_config, task=task, model_class=model_class)
resolution = resolve_task(
hf_config,
task=task,
model_class=model_class,
model_type_override=model_type_override,
)
resolved_task, resolved_class = resolution.task, resolution.model_class
logger.info("Resolved: task=%s, model_class=%s", resolved_task, resolved_class.__name__)

Expand All @@ -232,6 +256,12 @@ def resolve_loader_config(
resolved_class,
)

# The explicit variant tag wins over the architecture's native type for the
# loader config (drives downstream build-variant dispatch), while
# resolved_hf_config keeps its native model_type.
if model_type_override is not None:
resolved_model_type = model_type_override

# 5. Build loader config
loader_config = WinMLLoaderConfig(
task=resolved_task,
Expand Down
27 changes: 25 additions & 2 deletions src/winml/modelkit/loader/hf.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,7 @@ def load_hf_model(
user_script: str | None = None,
trust_remote_code: bool = False,
hf_config: PretrainedConfig | None = None,
model_type: str | None = None,
) -> tuple[nn.Module, PretrainedConfig, str]:
"""Load, detect task, and prepare HuggingFace model.
Expand Down Expand Up @@ -218,6 +219,23 @@ def load_hf_model(
trust_remote_code=trust_remote_code,
)

# Explicit model_type override: thread the requested build variant (e.g.
# "qwen3_transformer_only") into task resolution WITHOUT mutating the
# freshly-loaded HF config. The torch model is instantiated from its own
# native config below, so export/patcher consumers keep the native type;
# only class/task resolution sees the variant.
model_type_override = (
model_type
if model_type is not None and getattr(hf_config, "model_type", None) != model_type
else None
)
if model_type_override is not None:
logger.info(
"Applying model_type override '%s' -> '%s' (explicit request)",
getattr(hf_config, "model_type", None),
model_type_override,
)

# [2] Task & Model Class Resolution
from .resolution import resolve_task

Expand All @@ -228,10 +246,15 @@ def load_hf_model(
resolved_class = _load_class_from_script(user_script, model_class)
logger.info("Using custom model class from script: %s", model_class)
# Surfaced modality-aware task (consistent with the non-script branch).
task = resolve_task(hf_config, task=task).task
task = resolve_task(hf_config, task=task, model_type_override=model_type_override).task
else:
try:
resolution = resolve_task(hf_config, task=task, model_class=model_class)
resolution = resolve_task(
hf_config,
task=task,
model_class=model_class,
model_type_override=model_type_override,
)
task, resolved_class = resolution.task, resolution.model_class
except ValueError as e:
raise ValueError(
Expand Down
8 changes: 7 additions & 1 deletion src/winml/modelkit/loader/resolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,7 @@ def _get_custom_model_class(model_type: str, task: str) -> type | None:

return None


# Component-name -> sub-task, e.g. {"encoder": "feature-extraction",
# "decoder": "text2text-generation"} (the composite ``_SUB_MODEL_CONFIG`` shape).
CompositeComponents = dict[str, str]
Expand Down Expand Up @@ -372,16 +373,21 @@ def resolve_task(
*,
task: str | None = None,
model_class: str | None = None,
model_type_override: str | None = None,
) -> TaskResolution:
"""Resolve a single model's task + class from an HF config.

Stages: 0 user override -> 1 detect (override / no-architectures /
TasksManager / default) -> 2 model class -> 3 modality upgrade
(detection path only) -> 4 composite tag.

``model_type_override`` lets a caller drive resolution with a build variant
(e.g. ``qwen3_transformer_only``) without mutating the loaded HF config; when
``None`` the architecture's native ``config.model_type`` is used.
"""
from optimum.exporters.tasks import TasksManager

model_type = getattr(config, "model_type", None)
model_type = model_type_override or getattr(config, "model_type", None)
model_type_norm = model_type.lower().replace("_", "-") if model_type else ""
model_id = getattr(config, "_name_or_path", "") or None

Expand Down
Loading
Loading