Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/winml/modelkit/analyze/core/runtime_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ def op_support(
if not self._has_any_rule_data:
logger.warning(
"No runtime check data found. Follow "
"https://github.com/microsoft/WinML-ModelKit/blob/main/CONTRIBUTING.md "
"https://github.com/microsoft/winml-cli/blob/main/CONTRIBUTING.md "
"to set up runtime check files."
)
else:
Expand Down
22 changes: 8 additions & 14 deletions src/winml/modelkit/commands/analyze.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@

from ..utils import cli as cli_utils
from ..utils.constants import (
ALL_EP_NAMES,
DEVICE_TYPE_TO_DEVICE,
EP_SUPPORTED_DEVICES,
SUPPORTED_DEVICES,
Expand Down Expand Up @@ -716,27 +715,22 @@ def _build_runtime_debug_output_path(model_path: Path, ep_name: str, device_name

@click.command(name="analyze")
@cli_utils.model_path_option(required=True)
@click.option(
"--ep",
"--execution-provider",
@cli_utils.ep_option(
required=False,
default="auto",
show_default=True,
type=click.Choice([*ALL_EP_NAMES, "all", "auto"], case_sensitive=False),
help=(
"Target execution provider. Supports canonical names, aliases, and all/auto. "
include_auto=True,
include_all=True,
optional_message=(
"all = evaluate all rule-data-backed EPs; "
"auto = infer a single best target from local availability"
),
)
@click.option(
"--device",
@cli_utils.device_option(
required=False,
default="auto",
show_default=True,
type=click.Choice([*SUPPORTED_DEVICES, "all", "auto"], case_sensitive=False),
help=(
"Target device type. Supports CPU/GPU/NPU and all/auto. "
include_auto=True,
include_all=True,
optional_message=(
"all = all rule-data-backed devices; "
"auto = infer a single best target from local availability"
),
Expand Down
9 changes: 3 additions & 6 deletions src/winml/modelkit/commands/compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,13 +57,10 @@
default=None,
help="Output directory (default: same as input model)",
)
@click.option(
"--device",
"-d",
type=click.Choice(["auto", "npu", "gpu", "cpu"], case_sensitive=False),
@cli_utils.device_option(
required=False,
default="auto",
show_default=True,
help="Target device",
include_auto=True,
)
@cli_utils.ep_option(
required=False,
Expand Down
8 changes: 3 additions & 5 deletions src/winml/modelkit/commands/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,12 +74,10 @@
default=None,
help="Task (e.g. 'image-classification'). Auto-detected from --model-id.",
)
@click.option(
"--device",
type=click.Choice(["auto", "cpu", "gpu", "npu"], case_sensitive=False),
@cli_utils.device_option(
required=False,
default="auto",
show_default=True,
help="Device to run on. 'auto' detects the best available device.",
include_auto=True,
)
@cli_utils.ep_option(required=False)
@cli_utils.precision_option(
Expand Down
2 changes: 1 addition & 1 deletion src/winml/modelkit/commands/optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,7 +323,7 @@ def optimize(
else:
console.print(f" source: {group.sources[0]}")
console.print(f" target: {group.target_class}")
console.print(f" rule: modelkit/pattern/rules/{rule_file}")
console.print(f" rule: winml/modelkit/pattern/rules/{rule_file}")
if is_multi:
for src in group.sources:
src_flag = f"--enable-{source_flag_name(src, group.target_class)}"
Expand Down
24 changes: 21 additions & 3 deletions src/winml/modelkit/utils/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,14 +155,25 @@ def format_option(
)


def ep_option(required: bool = True, optional_message: str | None = None) -> Callable[[F], F]:
def ep_option(
required: bool = True,
optional_message: str | None = None,
default: str | None = None,
include_auto: bool = False,
include_all: bool = False,
) -> Callable[[F], F]:
"""Add --ep (execution provider) option to a Click command.

Args:
required: Whether the EP option is required (default: True)
optional_message: Message to append to help text when
optional (e.g., "If not specified, analyzes all
supported EPs.")
default: Default value when optional (default: None)
include_auto: Whether to include "auto" as a valid choice
(default: False).
include_all: Whether to include "all" as a valid choice
(default: False).

Returns:
Decorator function
Expand All @@ -176,13 +187,16 @@ def ep_option(required: bool = True, optional_message: str | None = None) -> Cal
help_text = f"{help_text}. {optional_message}"

ep_choices = [name for name in ALL_EP_NAMES if name not in ("cuda", "CUDAExecutionProvider")]
choices = ["auto", *ep_choices] if include_auto else ep_choices
choices = ["all", *choices] if include_all else choices

return click.option(
"--ep",
"--execution-provider",
required=required,
default=None,
type=click.Choice(ep_choices, case_sensitive=False),
default=default if not required else None,
show_default=True,
type=click.Choice(choices, case_sensitive=False),
help=help_text,
)

Expand Down Expand Up @@ -262,6 +276,7 @@ def device_option(
optional_message: str | None = None,
default: str | None = "NPU",
include_auto: bool = False,
include_all: bool = False,
) -> Callable[[F], F]:
"""Add --device option to a Click command.

Expand All @@ -273,12 +288,15 @@ def device_option(
default: Default value when optional (default: "NPU")
include_auto: Whether to include "auto" as a valid choice
(default: False).
include_all: Whether to include "all" as a valid choice
(default: False).

Returns:
Decorator function
"""
device_choices = [device.lower() for device in SUPPORTED_DEVICES]
choices = ["auto", *device_choices] if include_auto else device_choices
choices = ["all", *choices] if include_all else choices
help_text = f"Target device type ({', '.join(choices)})"
if optional_message:
help_text = f"{help_text}. {optional_message}"
Expand Down
Loading