Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 13 additions & 16 deletions src/winml/modelkit/commands/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -420,12 +420,9 @@ def _validate_loader_tasks_for_model(
default=None,
help="WinMLBuildConfig JSON file. If omitted, config is auto-generated from -m.",
)
@click.option(
"-m",
"--model",
"model_id",
default=None,
help="HuggingFace model ID or path to .onnx file. Omit for random-weight build.",
@cli_utils.model_option(
required=False,
help_text="HuggingFace model ID or path to .onnx file. Omit for random-weight build.",
)
@click.option(
"-o",
Expand Down Expand Up @@ -479,7 +476,7 @@ def _validate_loader_tasks_for_model(
def build(
ctx: click.Context,
config_file: str | None,
model_id: str | None,
model: str | None,
output_dir: str | None,
use_cache: bool,
rebuild: bool,
Expand Down Expand Up @@ -585,12 +582,12 @@ def build(
no_compile=no_compile,
)
else:
if not model_id:
if not model:
raise click.UsageError("-m/--model is required when -c is not provided.")
from ..config import generate_build_config

config_or_configs = generate_build_config(
model_id,
model,
trust_remote_code=trust_remote_code,
device=device,
precision=precision,
Expand Down Expand Up @@ -623,8 +620,8 @@ def _patch_device(cfg: WinMLBuildConfig) -> None:
# so the resulting config passes HF-build validation.
if cfg.loader is not None and cfg.loader.task:
resolved_quant.task = cfg.loader.task
if model_id:
resolved_quant.model_id = model_id
if model:
resolved_quant.model_id = model
cfg.quant = resolved_quant
else:
# Only update precision fields; preserve task/model_id
Expand Down Expand Up @@ -667,7 +664,7 @@ def _patch_device(cfg: WinMLBuildConfig) -> None:
raise click.UsageError(f"Config validation failed: {e}") from e

preloaded_hf_config = _validate_loader_tasks_for_model(
model_id=model_id,
model_id=model,
configs=_configs_to_validate,
trust_remote_code=trust_remote_code,
)
Expand Down Expand Up @@ -704,7 +701,7 @@ def _patch_device(cfg: WinMLBuildConfig) -> None:

print_setup(
console,
model=model_id or "random-init",
model=model or "random-init",
config=Path(config_file).name if config_file else "(auto)",
output=str(resolved_dir),
source="HuggingFace",
Expand Down Expand Up @@ -748,7 +745,7 @@ def _patch_device(cfg: WinMLBuildConfig) -> None:

write_module_summary(
output_path=resolved_dir / "module_summary.json",
model_id=model_id or "random-init",
model_id=model or "random-init",
module_class=configs[0].loader.model_class or "unknown",
instances=summary_instances,
)
Expand All @@ -768,7 +765,7 @@ def _patch_device(cfg: WinMLBuildConfig) -> None:

task = config.loader.task if config.loader else None
resolved_dir = get_model_dir(
model_id or "random-init",
model or "random-init",
cache_dir=get_cache_dir(),
)
if not task:
Expand All @@ -789,7 +786,7 @@ def _patch_device(cfg: WinMLBuildConfig) -> None:
_run_single_build(
config=config,
config_file=config_file,
model_id=model_id,
model_id=model,
resolved_dir=resolved_dir,
rebuild=rebuild,
cache_key=cache_key,
Expand Down
7 changes: 2 additions & 5 deletions src/winml/modelkit/commands/compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,13 +41,10 @@


@click.command()
@click.option(
"--model",
"-m",
@cli_utils.model_path_option(
required=False,
multiple=True,
type=click.Path(exists=True, path_type=Path),
help="Input ONNX model file. Repeat -m to compile multiple models with a shared "
help_text="Input ONNX model file. Repeat -m to compile multiple models with a shared "
"EP context (weight sharing). Required unless --list.",
)
@cli_utils.output_option("Output file path (e.g., model_compiled.onnx)")
Expand Down
9 changes: 3 additions & 6 deletions src/winml/modelkit/commands/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,10 @@


@click.command("eval")
@click.option(
"-m",
"--model",
type=str,
@cli_utils.model_option(
required=False,
multiple=True,
default=(),
help=(
help_text=(
"Model to evaluate. Accepts a HuggingFace model ID, an ONNX file path "
"(requires --model-id), or split-encoder role=path pairs (see --schema)."
),
Expand Down
7 changes: 2 additions & 5 deletions src/winml/modelkit/commands/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,12 +62,9 @@ def _delete_onnx_with_external_data(onnx_path: Path) -> None:


@click.command()
@click.option(
"--model",
"-m",
@cli_utils.model_option(
required=True,
type=str,
help="HuggingFace model name or local path (e.g., prajjwal1/bert-tiny)",
help_text="HuggingFace model name or local path (e.g., prajjwal1/bert-tiny)",
)
@cli_utils.output_option("Output ONNX file path (e.g., model.onnx)", required=True)
@cli_utils.overwrite_option()
Expand Down
24 changes: 10 additions & 14 deletions src/winml/modelkit/commands/inspect.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,13 +81,9 @@ def _looks_like_local_path(model_id: str) -> bool:


@click.command("inspect")
@click.option(
"-m",
"--model",
"model_id",
@cli_utils.model_option(
required=False,
default=None,
help="HuggingFace model ID (e.g., microsoft/resnet-50)",
help_text="HuggingFace model ID (e.g., microsoft/resnet-50)",
)
@cli_utils.format_option(choices=["table", "json"], default="table")
@click.option(
Expand Down Expand Up @@ -127,7 +123,7 @@ def _looks_like_local_path(model_id: str) -> bool:
@click.pass_context
def inspect(
ctx: click.Context,
model_id: str | None,
model: str | None,
output_format: cli_utils.OutputFormat,
verbose: int,
quiet: bool,
Expand Down Expand Up @@ -173,7 +169,7 @@ def inspect(
return

# Validate: need at least one of model_id, model_type, model_class
if model_id is None and model_type is None and model_class is None:
if model is None and model_type is None and model_class is None:
raise click.UsageError(
"At least one of -m/--model, --model-type, or --model-class is required. "
"Use --list-tasks to see available tasks."
Expand All @@ -182,17 +178,17 @@ def inspect(
# Classify the input before hitting HF Hub: local paths must exist.
# _looks_like_local_path uses a conservative allowlist to avoid misclassifying
# HF IDs with version dots (Phi-3.5, Qwen2.5, …) as local paths.
if model_id and _looks_like_local_path(model_id):
if model and _looks_like_local_path(model):
from pathlib import Path

_p = Path(model_id).expanduser()
_p = Path(model).expanduser()
if _p.suffix == ".onnx" and _p.is_file():
raise click.ClickException(
"ONNX file inspection is not yet supported. "
"Use 'winml config -m model.onnx' for ONNX build config."
)
if not _p.exists():
raise click.ClickException(f"Local path '{model_id}' does not exist.")
raise click.ClickException(f"Local path '{model}' does not exist.")

# Merge top-level -v/-q with subcommand-level flags so either position
# works, once and up front. The banner decision below needs the merged
Expand All @@ -207,7 +203,7 @@ def inspect(
# and in JSON mode (Click 8.4 mixes stderr into CliRunner.result.output,
# and JSON consumers expect clean stdout regardless).
json_mode = output_format == "json"
target = model_id or model_type or model_class
target = model or model_type or model_class
if not quiet and not json_mode:
_stderr_console.print(f"[dim]Inspecting [bold]{target}[/bold] …[/dim]")

Expand All @@ -219,7 +215,7 @@ def inspect(
try:
if quiet or json_mode:
result = _inspect_model_v2(
model_id=model_id,
model_id=model,
task_override=task,
model_type_override=model_type,
model_class_override=model_class,
Expand All @@ -231,7 +227,7 @@ def inspect(
spinner="dots",
):
result = _inspect_model_v2(
model_id=model_id,
model_id=model,
task_override=task,
model_type_override=model_type,
model_class_override=model_class,
Expand Down
10 changes: 4 additions & 6 deletions src/winml/modelkit/commands/optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,12 +165,10 @@ def capability_options(func: F) -> F:
default=False,
help="List available pattern rewrite families and exit",
)
@click.option(
"--model",
"-m",
required=False, # Not required when --list-capabilities/--list-rewrites is used
type=click.Path(exists=True, path_type=Path),
help="Input ONNX model file",
@cli_utils.model_path_option(
# Not required when --list-capabilities/--list-rewrites is used
required=False,
help_text="Input ONNX model file",
)
@cli_utils.output_option("Output path (default: {input}_opt.onnx)")
@cli_utils.overwrite_option()
Expand Down
10 changes: 2 additions & 8 deletions src/winml/modelkit/commands/quantize.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
from __future__ import annotations

import logging
from pathlib import Path
from typing import TYPE_CHECKING, cast

import click
Expand All @@ -32,6 +31,7 @@


if TYPE_CHECKING:
from pathlib import Path
from typing import Literal


Expand All @@ -40,13 +40,7 @@


@click.command()
@click.option(
"--model",
"-m",
required=True,
type=click.Path(exists=True, path_type=Path),
help="Input ONNX model file",
)
@cli_utils.model_path_option(required=True, help_text="Input ONNX model file")
@cli_utils.output_option("Output path (default: {input}_qdq.onnx)")
@cli_utils.overwrite_option()
@cli_utils.precision_option(
Expand Down
56 changes: 43 additions & 13 deletions src/winml/modelkit/utils/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,47 +92,77 @@ def warn_ignored_calibration_options(
out.print(f"[yellow]Warning:[/yellow] {', '.join(ignored)} ignored — {reason}")


def model_path_option(required: bool = True) -> Callable[[F], F]:
"""Add --model option that accepts a local ONNX file path.
def model_path_option(
required: bool = True,
multiple: bool = False,
help_text: str | None = None,
) -> Callable[[F], F]:
"""Add ``-m/--model`` option that accepts a local ONNX file path.

The path is validated for existence on disk.
The path is validated for existence on disk and delivered as a
:class:`pathlib.Path`. Shared by the ONNX-only commands (``analyze``,
``compile``, ``optimize``, ``quantize``) so the flag spelling, ``Path``
type, and existence check stay identical. The decorated function receives
the value as the ``model`` parameter (a tuple when ``multiple=True``).

Args:
required: Whether the model option is required (default: True)
required: Whether the model option is required (default: True).
multiple: Accept the flag repeatably; the value becomes a tuple
(default: False).
help_text: Override for the help string (default: a generic
ONNX-file description).

Returns:
Decorator function
Decorator function.
"""
return click.option(
"--model",
"-m",
required=required,
multiple=multiple,
type=click.Path(exists=True, path_type=Path),
help="Path to ONNX model file to analyze",
help=help_text or "Path to ONNX model file to analyze",
)


def model_option(required: bool = True, optional_message: str | None = None) -> Callable[[F], F]:
"""Add --model option that accepts any model reference.
def model_option(
required: bool = True,
optional_message: str | None = None,
multiple: bool = False,
help_text: str | None = None,
) -> Callable[[F], F]:
"""Add ``-m/--model`` option that accepts any model reference.

Accepts a HuggingFace model ID, build output directory, or .onnx file path.
No path existence validation is performed.
No path existence validation is performed. Shared by the flexible-input
commands (``build``, ``config``, ``eval``, ``export``, ``inspect``,
``perf``, ``run``, ``serve``) so the flag spelling stays identical. The
decorated function receives the value as the ``model`` parameter (a tuple
when ``multiple=True``).

Args:
required: Whether the model option is required (default: True)
required: Whether the model option is required (default: True).
optional_message: Command-specific note appended after the help text.
multiple: Accept the flag repeatably; the value becomes a tuple
(default: False).
help_text: Override for the base help string. Commands whose accepted
inputs are narrower (e.g. ``inspect`` takes only an HF ID) supply
their own; ``optional_message`` is still appended to it.

Returns:
Decorator function
Decorator function.
"""
help = "Model: HF model ID, build output directory, or .onnx file path"
help = help_text or "Model: HF model ID, build output directory, or .onnx file path"
if optional_message:
help = f"{help}. {optional_message}"
# ``multiple`` options default to an empty tuple; single-valued ones to None.
kwargs: dict[str, Any] = {"multiple": True} if multiple else {"default": None}
return click.option(
"--model",
"-m",
required=required,
default=None,
help=help,
**kwargs,
)


Expand Down
Loading