diff --git a/src/winml/modelkit/analyze/console_writer.py b/src/winml/modelkit/analyze/console_writer.py deleted file mode 100644 index c8be13a3f..000000000 --- a/src/winml/modelkit/analyze/console_writer.py +++ /dev/null @@ -1,555 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. -# -------------------------------------------------------------------------- -"""Console writer for static analyzer results. - -This module provides real-time console output using Rich library, -displaying analysis results in a user-friendly format. -""" - -from __future__ import annotations - -import json -import textwrap -from typing import TYPE_CHECKING - -from rich.console import Console -from rich.markup import escape -from rich.table import Table - -from .models.support_level import SupportLevel - - -if TYPE_CHECKING: - from .models.information import Information - from .models.output import AnalysisOutput, EPSupport - - -class StaticAnalyzerConsoleWriter: - """Console writer for displaying static analyzer results.""" - - # Configuration constants - SEPARATOR_LENGTH = 80 - CONSOLE_WIDTH = 80 # Match typical terminal width - - def __init__(self, console: Console | None = None, verbose: bool = True) -> None: - """Initialize console writer. - - Args: - console: Rich console instance (optional) - verbose: Whether to output verbose information - """ - self.verbose = verbose - self.console = console or Console( - width=self.CONSOLE_WIDTH, - highlight=False, - ) - # Ensure we have a reliable width value (Rich may return None in some cases) - self.width = self.CONSOLE_WIDTH if not self.console.width else self.console.width - - def _bright_cyan(self, text: str | int | float) -> str: - """Format text in bright cyan.""" - return f"[bold cyan]{text}[/bold cyan]" - - def _bright_green(self, text: str | int | float) -> str: - """Format text in bright green.""" - return f"[bold green]{text}[/bold green]" - - def _bright_red(self, text: str) -> str: - """Format text in bright red.""" - return f"[bold red]{text}[/bold red]" - - def _bright_yellow(self, text: str) -> str: - """Format text in bright yellow.""" - return f"[bold yellow]{text}[/bold yellow]" - - def _bold(self, text: str) -> str: - """Format text in bold.""" - return f"[bold]{text}[/bold]" - - def _dim(self, text: str) -> str: - """Format text in dim style.""" - return f"[dim]{text}[/dim]" - - def write_analysis_results(self, analysis: AnalysisOutput) -> None: - """Write complete analysis results to console. - - Args: - analysis: AnalysisOutput containing all analysis results - """ - self._write_header(analysis) - self._write_model_info(analysis) - self._write_operator_summary(analysis) - - if analysis.metadata.detected_pattern_count: - self._write_pattern_summary(analysis) - - self._write_ihv_results(analysis) - self._write_footer(analysis) - - def _write_header(self, analysis: AnalysisOutput) -> None: - """Write analysis header.""" - self.console.print() - self.console.print("=" * self.SEPARATOR_LENGTH) - self.console.print(f"{self._bold('ONNX MODEL STATIC ANALYSIS REPORT')}") - self.console.print("=" * self.SEPARATOR_LENGTH) - - timestamp = analysis.analysis_timestamp.strftime("%Y-%m-%d %H:%M:%S") - self.console.print(f"Analysis Time: {self._dim(timestamp)}") - self.console.print(f"Model: {self._bright_cyan(analysis.metadata.model_path)}") - self.console.print() - - def _write_model_info(self, analysis: AnalysisOutput) -> None: - """Write model information section.""" - self.console.print(f"{self._bold('MODEL INFORMATION')}") - self.console.print("-" * self.SEPARATOR_LENGTH) - - metadata = analysis.metadata - self.console.print(f" • ONNX Opset: {self._bright_green(metadata.opset_version)}") - - if metadata.producer_name: - producer = f"{metadata.producer_name}" - if metadata.producer_version: - producer += f" v{metadata.producer_version}" - self.console.print(f" • Producer: {self._bright_green(producer)}") - - self.console.print( - f" • Total Operators: {self._bright_cyan(f'{metadata.total_operators:,}')}" - ) - self.console.print( - f" • Unique Types: {self._bright_cyan(metadata.unique_operator_types)}" - ) - - if metadata.detected_pattern_count: - total_patterns = sum(metadata.detected_pattern_count.values()) - self.console.print( - f" • Detected Patterns: {self._bright_cyan(total_patterns)} instances " - f"({self._bright_cyan(len(metadata.detected_pattern_count))} types)" - ) - self.console.print() - - def _write_operator_summary(self, analysis: AnalysisOutput) -> None: - """Write operator summary with top operators.""" - self.console.print(f"{self._bold('OPERATOR ANALYSIS')}") - self.console.print("-" * self.SEPARATOR_LENGTH) - - metadata = analysis.metadata - sorted_ops = sorted(metadata.operator_counts.items(), key=lambda x: x[1], reverse=True) - - # Create table - table = Table(show_header=True, header_style="bold cyan", box=None) - table.add_column("Rank", style="dim", width=6) - table.add_column("Operator Type", style="green") - table.add_column("Count", justify="right", style="cyan") - table.add_column("Percentage", justify="right", style="yellow") - - for rank, (op_type, count) in enumerate(sorted_ops, 1): - percentage = ( - (count / metadata.total_operators * 100) if metadata.total_operators > 0 else 0 - ) - table.add_row(f"{rank}.", op_type, f"{count:,}", f"{percentage:.1f}%") - - self.console.print(table) - self.console.print() - - def _write_pattern_summary(self, analysis: AnalysisOutput) -> None: - """Write pattern detection summary.""" - self.console.print(f"{self._bold('PATTERN DETECTION')}") - self.console.print("-" * self.SEPARATOR_LENGTH) - - detected = analysis.metadata.detected_pattern_count - total_instances = sum(detected.values()) - - self.console.print( - f" Detected {self._bright_cyan(total_instances)} pattern instances " - f"across {self._bright_cyan(len(detected))} pattern types" - ) - self.console.print() - - # Show pattern details - sorted_patterns = sorted(detected.items(), key=lambda x: x[1], reverse=True) - for pattern_id, count in sorted_patterns: - pattern_type = "Subgraph" if pattern_id.startswith("SUBGRAPH/") else "Operator" - self.console.print( - f" {pattern_type}: {self._bright_green(pattern_id)} " - f"({self._bright_cyan(count)} instances)" - ) - self.console.print() - - def _write_ihv_results(self, analysis: AnalysisOutput) -> None: - """Write IHV support analysis results.""" - self.console.print(f"{self._bold('IHV PLATFORM SUPPORT ANALYSIS')}") - self.console.print("=" * self.SEPARATOR_LENGTH) - - for ihv_result in analysis.results: - self._write_single_ihv_result( - ihv_result, - total_operators=analysis.metadata.total_operators, - unique_operator_types=analysis.metadata.unique_operator_types, - ) - self.console.print() - - def _write_single_ihv_result( - self, ihv_result: EPSupport, total_operators: int, unique_operator_types: int - ) -> None: - """Write support analysis for a single IHV platform. - - Args: - ihv_result: IHV support result - total_operators: Total number of operators in model - unique_operator_types: Total number of unique operator types - """ - # Header with support status - status_icon = "+" if ihv_result.runtime_support else "x" - status_text = ( - self._bright_green("SUPPORTED") - if ihv_result.runtime_support - else self._bright_red("NOT SUPPORTED") - ) - - self.console.print( - f"\n{status_icon} {self._bold(ihv_result.ihv_type.value)} - {status_text}" - ) - self.console.print("-" * 60) - - # Version info - if ihv_result.ep_version: - self.console.print(f" EP Version: {self._dim(ihv_result.ep_version)}") - if ihv_result.driver_version: - self.console.print(f" Driver Version: {self._dim(ihv_result.driver_version)}") - - # Show EP Type and Device - self.console.print(f" EP: {self._bright_cyan(ihv_result.ep_type)}") - if ihv_result.device_type: - self.console.print(f" Device: {self._bright_cyan(ihv_result.device_type)}") - - # Support classification - self.console.print(f"\n {self._bold('Support Classification:')}") - - classification_info = [ - (SupportLevel.SUPPORTED, "+", "Fully Supported", "green"), - (SupportLevel.PARTIAL, "!", "Partial Support", "yellow"), - (SupportLevel.UNKNOWN, "?", "Unknown Support", "blue"), - (SupportLevel.UNSUPPORTED, "x", "Not Supported", "red"), - ] - - for level, icon, label, color in classification_info: - if level in ihv_result.classification: - operators = ihv_result.classification[level] - count = len(operators) - # Calculate percentage based on unique operator types, not total instances - percentage = ( - (count / unique_operator_types * 100) if unique_operator_types > 0 else 0 - ) - - count_str = f"[{color}]{count:3d}[/{color}]" - pct_str = f"[{color}]{percentage:5.1f}%[/{color}]" - - self.console.print(f" {icon} {label:20s}: {count_str} operator types ({pct_str})") - - # Show operators - expand SUPPORTED level, show samples for others - if count > 0: - if level == SupportLevel.SUPPORTED: - # Expand all fully supported operators - for op in operators: - self.console.print(f" • {self._bright_green(op)}") - else: - # Show all operators for non-supported levels - for op in operators: - self.console.print(f" • {self._dim(op)}") - - # Information summary - info_count = len(ihv_result.information) - if info_count > 0: - self.console.print( - f"\n {self._bright_yellow('Actionable Information')}: " - f"{self._bright_cyan(info_count)} items" - ) - - self._write_information_items(ihv_result.information, ihv_result.ihv_type.value) - - def _format_wrapped_text( - self, - text: str, - indent: str = " ", - first_line_indent: str = "", - width: int | None = None, - ) -> list[str]: - """Wrap long text with proper indentation. - - Args: - text: Text to wrap - indent: Indentation string for continuation lines - first_line_indent: Indentation for first line (default: empty) - width: Maximum width of each line (default: console width - padding) - - Returns: - List of formatted lines with proper indentation - """ - # Use console width if not specified, with padding for safety - if width is None: - width = self.width - 10 - - # Wrap the text - wrapper = textwrap.TextWrapper( - width=width, - initial_indent=first_line_indent, - subsequent_indent=indent, - break_long_words=False, - break_on_hyphens=False, - ) - - return wrapper.wrap(text) - - def _write_information_items(self, information_list: list[Information], ihv_name: str) -> None: - """Write detailed information items. - - Args: - information_list: List of Information objects - ihv_name: IHV platform name - """ - for idx, info in enumerate(information_list, 1): - self.console.print() - - # Format issue header with explanation - issue_text = f"Issue #{idx}:" - full_text = f"{issue_text} {info.explanation}" - - wrapped_lines = self._format_wrapped_text( - full_text, - indent=" ", - first_line_indent=" ", - ) - - # Print wrapped lines with formatting - if wrapped_lines: - # Format first line with Issue header as bold - first_line = wrapped_lines[0].strip() - issue_end = first_line.find(":") + 1 - if issue_end > 0: - issue_part = first_line[:issue_end] - rest_part = first_line[issue_end:].lstrip() - self.console.print(f" {self._bold(issue_part)} {rest_part}") - else: - self.console.print(f" {first_line}") - - # Print remaining lines - for line in wrapped_lines[1:]: - self.console.print(line) - - self.console.print() - - # Pattern info - if info.pattern_id: - self.console.print( - f" {self._bold('Pattern:')} {self._bright_cyan(info.pattern_id)}" - ) - - # Show affected nodes with details - if info.pattern_node_list: - instance_count = len(info.pattern_node_list) - total_nodes = sum(len(nodes) for nodes in info.pattern_node_list) - affected_label = self._bold("Affected:") - inst_str = self._bright_cyan(instance_count) - nodes_str = self._bright_cyan(total_nodes) - self.console.print( - f" {affected_label} {inst_str} pattern instances, {nodes_str} total nodes" - ) - - # Show first 3 pattern instances with node lists - max_patterns_to_show = 3 - for pattern_idx, node_list in enumerate( - info.pattern_node_list[:max_patterns_to_show], 1 - ): - if node_list: - self.console.print(f" Instance {pattern_idx}:") - # Show all nodes for readability - for i, node in enumerate(node_list): - node_escaped = escape(node) - dim_idx = self._dim(f"{i + 1}.") - green_node = self._bright_green(node_escaped) - self.console.print(f" {dim_idx} {green_node}") - - if instance_count > max_patterns_to_show: - remaining = instance_count - max_patterns_to_show - more_msg = self._dim(f"... and {remaining} more pattern instances") - self.console.print(f" {more_msg}") - self.console.print() - - # Show actions with transformation details - if info.actions: - self.console.print(f" {self._bold('Recommended Actions:')}") - for action_idx, action in enumerate(info.actions, 1): - # Show transformation: from_pattern -> to_pattern - transformation = ( - f"{self._bright_yellow(action.pattern_from_id)} -> " - f"{self._bright_green(action.pattern_to_id)}" - ) - priority_str = ( - f" {self._bright_red(f'[{action.level.value.upper()}]')}" - if action.level - else "" - ) - - transform_label = self._bold("Transform:") - self.console.print( - f" {action_idx}. {transform_label} {transformation}{priority_str}" - ) - - # Show expected status after transformation - if action.status: - status_icon = { - "supported": "+", - "partial": "!", - "unknown": "?", - "unsupported": "x", - }.get(action.status.value, "*") - result_label = self._bold("Expected Result:") - status_val = self._bright_green(action.status.value) - self.console.print(f" {result_label} {status_icon} {status_val}") - - # Show action details - if action.details: - self.console.print(f" {self._bold('Details:')}") - - # Try to parse as JSON for better formatting - try: - details_obj = json.loads(action.details) - - # If it's a list or dict, display as formatted JSON - if isinstance(details_obj, (list, dict)): - json_str = json.dumps(details_obj, indent=2, ensure_ascii=False) - # Display formatted JSON with proper indentation - indent_prefix = " " - for line in json_str.split("\n"): - self.console.print(f"{indent_prefix}{line}") - else: - # If it's a string value, just display it - wrapped_lines = self._format_wrapped_text( - str(details_obj), - indent=" ", - first_line_indent=" ", - ) - for line in wrapped_lines: - self.console.print(self._dim(line)) - except (json.JSONDecodeError, ValueError): - # If not valid JSON, treat as plain text - wrapped_lines = self._format_wrapped_text( - action.details, - indent=" ", - first_line_indent=" ", - ) - for line in wrapped_lines: - self.console.print(self._dim(line)) - - # Show action items (transformations/optimizations) - if action.action_items: - self.console.print(f" {self._bold('Steps:')}") - for item in action.action_items: - opt_str = "" - if item.optimization_options: - # Show all options - opts = ", ".join( - f"{k}={v}" for k, v in item.optimization_options.items() - ) - opt_str = f" {self._dim(f'({opts})')}" - self.console.print( - f" • {self._bright_cyan(item.type)}{opt_str}" - ) - self.console.print() - - def _write_footer(self, analysis: AnalysisOutput) -> None: - """Write analysis footer with summary.""" - self.console.print("=" * self.SEPARATOR_LENGTH) - self.console.print(f"{self._bold('ANALYSIS SUMMARY')}") - self.console.print("-" * self.SEPARATOR_LENGTH) - - # Overall support status - supported_platforms = sum(1 for r in analysis.results if r.runtime_support) - total_platforms = len(analysis.results) - - # Check if unsupported platforms only have unknown nodes (no unsupported/partial) - unsupported_results = [r for r in analysis.results if not r.runtime_support] - has_only_unknown = False - if unsupported_results: - # Check if there are NO unsupported or partial issues (only unknown/supported) - # Use .get() to handle missing keys and check for non-empty lists - has_only_unknown = all( - not r.classification.get(SupportLevel.UNSUPPORTED, []) - and not r.classification.get(SupportLevel.PARTIAL, []) - for r in unsupported_results - ) - - if supported_platforms == total_platforms: - status_msg = self._bright_green( - f"Model is supported on all {total_platforms} platform(s)" - ) - elif supported_platforms > 0 and not has_only_unknown: - status_msg = self._bright_yellow( - f"Model is supported on {supported_platforms}/{total_platforms} platform(s)" - ) - elif supported_platforms > 0 and has_only_unknown: - status_msg = self._bright_yellow( - f"Model is supported on {supported_platforms}/{total_platforms} platform(s), " - f"unknown nodes found on some of platforms" - ) - elif has_only_unknown: - status_msg = self._bright_yellow("Model has unknown nodes") - else: - status_msg = self._bright_red("Model is not supported on any platform") - - self.console.print(f" {status_msg}") - - # Show platform-specific summaries - for ep_result in analysis.results: - platform_name = ep_result.ep_type - if ep_result.runtime_support: - self.console.print(f" • {self._bright_green(platform_name)}: Ready to deploy") - else: - # Count issues - issue_counts = { - level: len(ops) - for level, ops in ep_result.classification.items() - if level != SupportLevel.SUPPORTED - } - - if any(issue_counts.values()): - # Check if only unknown nodes (no unsupported or partial) - has_only_unknown = all(level == SupportLevel.UNKNOWN for level in issue_counts) - - issue_summary = ", ".join( - f"{count} {level.value}" - for level, count in issue_counts.items() - if count > 0 - ) - - if has_only_unknown: - name = self._bright_yellow(platform_name) - self.console.print(f" • {name}: Unknown nodes found ({issue_summary})") - else: - name = self._bright_red(platform_name) - self.console.print(f" • {name}: Issues found ({issue_summary})") - - self.console.print() - self.console.print( - f"Use {self._bright_cyan('--output results.json')} to save detailed results" - ) - self.console.print("=" * self.SEPARATOR_LENGTH) - self.console.print() - - -def display_analysis_results( - analysis: AnalysisOutput, - console: Console | None = None, - verbose: bool = True, -) -> None: - """Display analysis results in console. - - Args: - analysis: AnalysisOutput containing analysis results - console: Optional Rich console instance - verbose: Whether to show verbose output - """ - writer = StaticAnalyzerConsoleWriter(console=console, verbose=verbose) - writer.write_analysis_results(analysis) diff --git a/src/winml/modelkit/analyze/utils/table_utils.py b/src/winml/modelkit/analyze/utils/table_utils.py deleted file mode 100644 index 45c679432..000000000 --- a/src/winml/modelkit/analyze/utils/table_utils.py +++ /dev/null @@ -1,36 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. -# -------------------------------------------------------------------------- - -"""Table utilities for runtime rule tables.""" - -from __future__ import annotations - -from typing import Any - -import pandas as pd - - -def build_table_df(raw_table: Any) -> pd.DataFrame: - """Build DataFrame from runtime table JSON while preserving int/None types. - - Runtime table JSON is stored as columnar dicts. Building with ``from_dict`` can - upcast int+None columns to float (for example, ``0/1/None`` to - ``0.0/1.0/nan``). Convert to row records first and force object dtype to keep - exact Python values. - """ - if not isinstance(raw_table, dict) or not raw_table: - return pd.DataFrame() - - columns = list(raw_table.keys()) - first_col = raw_table[columns[0]] - - if isinstance(first_col, dict): - row_keys = sorted(first_col.keys(), key=int) - rows = [{c: raw_table[c].get(k) for c in columns} for k in row_keys] - else: - row_count = len(first_col) - rows = [{c: raw_table[c][i] for c in columns} for i in range(row_count)] - - return pd.DataFrame(rows, dtype=object) diff --git a/src/winml/modelkit/core/onnx_node_bucketizer.py b/src/winml/modelkit/core/onnx_node_bucketizer.py deleted file mode 100644 index f1986f990..000000000 --- a/src/winml/modelkit/core/onnx_node_bucketizer.py +++ /dev/null @@ -1,128 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. -# -------------------------------------------------------------------------- -"""ONNX Node Bucketization by Scope Name. - -This demonstrates how to group ONNX nodes by their scope names and -handle nodes with no scope (assign to root module). -""" - -from collections import defaultdict -from typing import Any - -import onnx - - -def bucketize_onnx_nodes_by_scope( - onnx_model: onnx.ModelProto, -) -> dict[str, list[onnx.NodeProto]]: - """Bucketize ONNX nodes by their scope names. - - Args: - onnx_model: ONNX model to analyze - - Returns: - Dictionary mapping scope names to lists of nodes - """ - scope_buckets = defaultdict(list) - - for node in onnx_model.graph.node: - scope_name = extract_scope_from_node(node) - scope_buckets[scope_name].append(node) - - return dict(scope_buckets) - - -def extract_scope_from_node(node: onnx.NodeProto) -> str: - """Extract scope name from ONNX node. - - Examples: - "/embeddings/word_embeddings/Gather" → "embeddings.word_embeddings" - "/encoder/layer.0/attention/self/query/MatMul" → "encoder.layer.0.attention.self.query" - "/Softmax_123" → "__root__" (no scope) - "MatMul" → "__root__" (no scope) - - Returns: - Scope name as dotted path, or "__root__" for nodes without scope - """ - node_name = node.name or "" - - # Handle empty node names - if not node_name: - return "__root__" - - # Handle root-level operations (no leading slash or single component) - if not node_name.startswith("/"): - return "__root__" - - # Parse structured node name: "/scope/path/OperationType" - parts = node_name.strip("/").split("/") - - # Single component means no scope (e.g., "/Gather_3") - if len(parts) <= 1: - return "__root__" - - # Extract scope path (everything except the last operation part) - scope_parts = parts[:-1] # Remove operation name - scope_name = ".".join(scope_parts) # Convert to dotted notation - - return scope_name if scope_name else "__root__" - - -def demonstrate_bucketization() -> defaultdict[str, list[Any]]: - """Demonstrate the bucketization process with examples.""" - # Example ONNX node names and their expected scopes - example_nodes = [ - ("/embeddings/word_embeddings/Gather", "embeddings.word_embeddings"), - ("/embeddings/LayerNorm/Add", "embeddings"), # LayerNorm under embeddings - ( - "/encoder/layer.0/attention/self/query/MatMul", - "encoder.layer.0.attention.self.query", - ), - ("/encoder/layer.0/attention/self/MatMul", "encoder.layer.0.attention.self"), - ( - "/encoder/layer.0/attention/output/dense/Gemm", - "encoder.layer.0.attention.output.dense", - ), - ("/pooler/dense/Tanh", "pooler.dense"), - ("/Softmax_123", "__root__"), # No scope - ("MatMul_456", "__root__"), # No scope - ("/Constant_789", "__root__"), # Root-level constant - ] - - print("🗂️ ONNX Node Scope Bucketization Examples:") - print("=" * 60) - - scope_buckets = defaultdict(list) - - for node_name, expected_scope in example_nodes: - # Create mock node - mock_node = type( - "MockNode", - (), - {"name": node_name, "op_type": node_name.split("/")[-1].split("_")[0]}, - )() - - # Extract scope - actual_scope = extract_scope_from_node(mock_node) - - # Add to bucket - scope_buckets[actual_scope].append(mock_node) - - # Verify expectation - status = "✅" if actual_scope == expected_scope else "❌" - print(f"{status} {node_name:50} -> {actual_scope}") - - print("\n📊 Scope Buckets:") - print("-" * 30) - for scope_name, nodes in scope_buckets.items(): - print(f"{scope_name:30} ({len(nodes)} nodes)") - for node in nodes: - print(f" └─ {node.name}") - - return scope_buckets - - -if __name__ == "__main__": - demonstrate_bucketization() diff --git a/src/winml/modelkit/core/operation_config.py b/src/winml/modelkit/core/operation_config.py deleted file mode 100644 index f6aa460d3..000000000 --- a/src/winml/modelkit/core/operation_config.py +++ /dev/null @@ -1,346 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. -# -------------------------------------------------------------------------- -"""Centralized Operation Configuration for PyTorch to ONNX Mapping. - -This module provides the OperationConfig class which serves as a single source -of truth for PyTorch operation definitions and their corresponding ONNX operation -types. This eliminates duplication between patching operations and ONNX mapping -across different export strategies. - -Key Features: -- Universal operation registry with PyTorch to ONNX mappings -- Priority-based operation organization -- Support for both torch and torch.nn.functional operations -- Extensible operation registry -- Centralized configuration for all export strategies -""" - -from __future__ import annotations - -from dataclasses import dataclass -from typing import ClassVar - - -@dataclass -class OperationConfig: - """Centralized operation configuration for both patching and ONNX mapping. - - This class provides a single source of truth for PyTorch operation definitions - and their corresponding ONNX operation types, eliminating duplication between - _patch_torch_operations() and _project_execution_trace_to_onnx(). - """ - - # Single source of truth for operation mappings - OPERATION_REGISTRY: ClassVar[dict] = { - # Core mathematical operations - "matmul": { - "patch_targets": [("torch", "matmul")], - "onnx_types": ["MatMul", "Gemm"], - "priority": 1, - }, - "add": { - "patch_targets": [("torch", "add")], - "onnx_types": ["Add"], - "priority": 1, - }, - "sub": { - "patch_targets": [("torch", "sub")], - "onnx_types": ["Sub"], - "priority": 1, - }, - "mul": { - "patch_targets": [("torch", "mul")], - "onnx_types": ["Mul"], - "priority": 1, - }, - "div": { - "patch_targets": [("torch", "div")], - "onnx_types": ["Div"], - "priority": 1, - }, - "pow": { - "patch_targets": [("torch", "pow")], - "onnx_types": ["Pow"], - "priority": 1, - }, - "sqrt": { - "patch_targets": [("torch", "sqrt")], - "onnx_types": ["Sqrt"], - "priority": 1, - }, - "erf": { - "patch_targets": [("torch", "erf")], - "onnx_types": ["Erf"], - "priority": 1, - }, - "tanh": { - "patch_targets": [("torch", "tanh"), ("F", "tanh")], - "onnx_types": ["Tanh"], - "priority": 1, - }, - "relu": { - "patch_targets": [("torch", "relu"), ("F", "relu")], - "onnx_types": ["Relu"], - "priority": 1, - }, - "bmm": { - "patch_targets": [("torch", "bmm")], - "onnx_types": ["MatMul"], - "priority": 1, - }, - "abs": { - "patch_targets": [("torch", "abs")], - "onnx_types": ["Abs"], - "priority": 1, - }, - "neg": { - "patch_targets": [("torch", "neg")], - "onnx_types": ["Neg"], - "priority": 1, - }, - "reciprocal": { - "patch_targets": [("torch", "reciprocal")], - "onnx_types": ["Reciprocal"], - "priority": 1, - }, - "sigmoid": { - "patch_targets": [("torch", "sigmoid"), ("F", "sigmoid")], - "onnx_types": ["Sigmoid"], - "priority": 1, - }, - "log": { - "patch_targets": [("torch", "log")], - "onnx_types": ["Log"], - "priority": 1, - }, - "exp": { - "patch_targets": [("torch", "exp")], - "onnx_types": ["Exp"], - "priority": 1, - }, - "floor": { - "patch_targets": [("torch", "floor")], - "onnx_types": ["Floor"], - "priority": 1, - }, - "ceil": { - "patch_targets": [("torch", "ceil")], - "onnx_types": ["Ceil"], - "priority": 1, - }, - # Indexing and gathering operations - "index_select": { - "patch_targets": [("torch", "index_select")], - "onnx_types": ["Gather"], - "priority": 2, - }, - "gather": { - "patch_targets": [("torch", "gather")], - "onnx_types": ["Gather"], - "priority": 2, - }, - "embedding": { - "patch_targets": [("torch", "embedding"), ("F", "embedding")], - "onnx_types": ["Gather"], - "priority": 2, - }, - "where": { - "patch_targets": [("torch", "where")], - "onnx_types": ["Where"], - "priority": 2, - }, - "eq": { - "patch_targets": [("torch", "eq")], - "onnx_types": ["Equal"], - "priority": 2, - }, - "equal": { - "patch_targets": [("torch", "equal")], - "onnx_types": ["Equal"], - "priority": 2, - }, - # Shape operations - "reshape": { - "patch_targets": [("torch", "reshape")], - "onnx_types": ["Reshape"], - "priority": 3, - }, - "transpose": { - "patch_targets": [("torch", "transpose")], - "onnx_types": ["Transpose"], - "priority": 3, - }, - "unsqueeze": { - "patch_targets": [("torch", "unsqueeze")], - "onnx_types": ["Unsqueeze"], - "priority": 3, - }, - "squeeze": { - "patch_targets": [("torch", "squeeze")], - "onnx_types": ["Squeeze"], - "priority": 3, - }, - "cat": { - "patch_targets": [("torch", "cat")], - "onnx_types": ["Concat"], - "priority": 3, - }, - # Note: expand is a tensor method, not a torch function - # slice: PyTorch slicing (x[1:5]) converts to ONNX Slice nodes - # but there's no torch.slice function to patch - handled by ONNX conversion - "slice": { - "patch_targets": [], # No patchable function - tensor[1:5] syntax handled by ONNX - "onnx_types": ["Slice"], - "priority": 3, - }, - "narrow": { - "patch_targets": [("torch", "narrow")], - "onnx_types": ["Slice"], - "priority": 3, - }, - "select": { - "patch_targets": [("torch", "select")], - "onnx_types": ["Gather", "Slice"], - "priority": 3, - }, - "take": { - "patch_targets": [("torch", "take")], - "onnx_types": ["Gather"], - "priority": 3, - }, - # Reduction operations - "mean": { - "patch_targets": [("torch", "mean")], - "onnx_types": ["ReduceMean"], - "priority": 4, - }, - "sum": { - "patch_targets": [("torch", "sum")], - "onnx_types": ["ReduceSum"], - "priority": 4, - }, - "cumsum": { - "patch_targets": [("torch", "cumsum")], - "onnx_types": ["CumSum"], - "priority": 4, - }, - "cumprod": { - "patch_targets": [("torch", "cumprod")], - "onnx_types": ["CumProd"], - "priority": 4, - }, - # Note: cast is typically done via .to() method, not a torch function - # High-level functional operations - "linear": { - "patch_targets": [("F", "linear")], - "onnx_types": ["Gemm", "MatMul"], - "priority": 6, - }, - "softmax": { - "patch_targets": [("F", "softmax")], - "onnx_types": ["Softmax"], - "priority": 6, - }, - "layer_norm": { - "patch_targets": [("F", "layer_norm")], - "onnx_types": [ - "LayerNormalization", - "Add", - "Mul", - "Div", - "ReduceMean", - "Sub", - "Sqrt", - "Pow", - ], - "priority": 6, - }, - "pad": {"patch_targets": [("F", "pad")], "onnx_types": ["Pad"], "priority": 6}, - "dropout": { - "patch_targets": [("F", "dropout")], - "onnx_types": ["Dropout"], - "priority": 6, - }, - "gelu": { - "patch_targets": [("F", "gelu")], - "onnx_types": ["Erf", "Add", "Mul", "Div"], - "priority": 6, - }, - # Native operations (highest priority) - "scaled_dot_product_attention": { - "patch_targets": [("F", "scaled_dot_product_attention")], - "onnx_types": [ - "MatMul", - "Div", - "Softmax", - "MatMul", - ], # Typical decomposition pattern - "priority": 10, - }, - # Additional ONNX-only mappings (no patch targets) - "size": {"patch_targets": [], "onnx_types": ["Shape"], "priority": 3}, - "shape": {"patch_targets": [], "onnx_types": ["Shape"], "priority": 3}, - "zeros": { - "patch_targets": [], - "onnx_types": ["ConstantOfShape"], - "priority": 5, - }, - "ones": {"patch_targets": [], "onnx_types": ["ConstantOfShape"], "priority": 5}, - "full": {"patch_targets": [], "onnx_types": ["ConstantOfShape"], "priority": 5}, - "tensor": {"patch_targets": [], "onnx_types": ["Constant"], "priority": 5}, - } - - @classmethod - def get_operations_to_patch(cls) -> list[tuple]: - """Get list of (module_name, operation_name) tuples for patching. - - Returns: - List of tuples suitable for patching PyTorch operations - """ - import torch - import torch.nn.functional as F - - module_map = {"torch": torch, "F": F} - - operations = [] - for op_data in cls.OPERATION_REGISTRY.values(): - for module_name, op_name in op_data["patch_targets"]: - if module_name in module_map: - operations.append((module_map[module_name], op_name)) - - return operations - - @classmethod - def get_torch_to_onnx_mapping(cls) -> dict[str, list[str]]: - """Get mapping from PyTorch operation names to ONNX operation types. - - Returns: - Dictionary mapping operation names to lists of ONNX types - """ - return { - op_name: op_data["onnx_types"] for op_name, op_data in cls.OPERATION_REGISTRY.items() - } - - @classmethod - def add_operation( - cls, - op_name: str, - patch_targets: list[tuple[str, str]], - onnx_types: list[str], - priority: int = 5, - ) -> None: - """Add a new operation to the registry. - - Args: - op_name: Name of the operation - patch_targets: List of (module_name, operation_name) for patching - onnx_types: List of corresponding ONNX operation types - priority: Priority level (1=highest, 10=lowest) - """ - cls.OPERATION_REGISTRY[op_name] = { - "patch_targets": patch_targets, - "onnx_types": onnx_types, - "priority": priority, - } diff --git a/src/winml/modelkit/core/strategy_selector.py b/src/winml/modelkit/core/strategy_selector.py deleted file mode 100644 index 15f9708cf..000000000 --- a/src/winml/modelkit/core/strategy_selector.py +++ /dev/null @@ -1,331 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. -# -------------------------------------------------------------------------- -"""Intelligent Strategy Selection for ModelExport. - -This module provides automatic strategy selection based on model characteristics -and user requirements. It analyzes models and recommends the optimal export strategy. -""" - -import logging -from dataclasses import dataclass -from enum import Enum -from typing import Any, ClassVar - -import torch - - -logger = logging.getLogger(__name__) - - -class ExportStrategy(Enum): - """Available export strategies.""" - - USAGE_BASED = "usage_based" - HTP = "htp" - FX = "fx_graph" - AUTO = "auto" - - -@dataclass -class ModelCharacteristics: - """Characteristics of a model that affect strategy selection.""" - - model_type: str # e.g., "transformer", "cnn", "unknown" - has_control_flow: bool - is_huggingface: bool - module_count: int - has_dynamic_shapes: bool - estimated_complexity: str # "low", "medium", "high" - framework_hints: list[str] # e.g., ["attention", "convolution", "embedding"] - - -@dataclass -class StrategyRecommendation: - """Strategy recommendation with reasoning.""" - - primary_strategy: ExportStrategy - fallback_strategy: ExportStrategy | None - confidence: float # 0.0 to 1.0 - reasoning: list[str] - warnings: list[str] - expected_performance: dict[str, Any] - - -class ModelAnalyzer: - """Analyze PyTorch models to determine their characteristics.""" - - @staticmethod - def analyze_model(model: torch.nn.Module) -> ModelCharacteristics: - """Analyze a PyTorch model to determine its characteristics. - - Args: - model: PyTorch model to analyze - - Returns: - ModelCharacteristics with detected features - """ - characteristics = ModelCharacteristics( - model_type="unknown", - has_control_flow=False, - is_huggingface=False, - module_count=0, - has_dynamic_shapes=False, - estimated_complexity="medium", - framework_hints=[], - ) - - # Count modules - module_count = sum(1 for _ in model.named_modules()) - characteristics.module_count = module_count - - # Detect model type and framework - model_class_name = model.__class__.__name__.lower() - [name for name, _ in model.named_modules()] - - # Check if HuggingFace model - if hasattr(model, "config") or any( - "transformers" in str(type(m)) for _, m in model.named_modules() - ): - characteristics.is_huggingface = True - - # Detect transformer architecture - transformer_indicators = [ - "attention", - "transformer", - "bert", - "gpt", - "vit", - "clip", - ] - if any(indicator in model_class_name for indicator in transformer_indicators): - characteristics.model_type = "transformer" - characteristics.framework_hints.append("attention") - - # Detect CNN architecture - cnn_indicators = ["resnet", "vgg", "mobilenet", "efficientnet", "densenet"] - conv_count = sum(1 for _, m in model.named_modules() if isinstance(m, torch.nn.Conv2d)) - if any(indicator in model_class_name for indicator in cnn_indicators) or conv_count > 5: - characteristics.model_type = "cnn" - characteristics.framework_hints.append("convolution") - - # Detect embedding layers - if any(isinstance(m, torch.nn.Embedding) for _, m in model.named_modules()): - characteristics.framework_hints.append("embedding") - - # Check for control flow indicators - # This is a heuristic - actual control flow detection requires tracing - if characteristics.is_huggingface: - # HuggingFace models often have control flow - characteristics.has_control_flow = True - - # Estimate complexity - if module_count < 50: - characteristics.estimated_complexity = "low" - elif module_count < 200: - characteristics.estimated_complexity = "medium" - else: - characteristics.estimated_complexity = "high" - - # Check for dynamic shapes (heuristic) - if characteristics.model_type == "transformer": - characteristics.has_dynamic_shapes = True - - return characteristics - - -class StrategySelector: - """Intelligent strategy selector for ONNX export. - - Based on extensive benchmarking from Iterations 16-18. - """ - - # Performance benchmarks from testing (in seconds) - PERFORMANCE_DATA: ClassVar[dict[str, dict[ExportStrategy, float | None]]] = { - "resnet-50": { - ExportStrategy.USAGE_BASED: 2.488, - ExportStrategy.HTP: 5.920, - ExportStrategy.FX: None, # Incompatible - }, - "transformer": { - ExportStrategy.USAGE_BASED: 3.5, # Estimated - ExportStrategy.HTP: 6.0, # Estimated - ExportStrategy.FX: None, # Incompatible - }, - "simple_cnn": { - ExportStrategy.USAGE_BASED: 1.0, # Estimated - ExportStrategy.HTP: 1.5, # Estimated - ExportStrategy.FX: 0.8, # Estimated - }, - } - - @classmethod - def recommend_strategy( - cls, - model: torch.nn.Module, - prioritize_speed: bool = True, - prioritize_coverage: bool = False, - force_strategy: ExportStrategy | None = None, - ) -> StrategyRecommendation: - """Recommend the best export strategy for a given model. - - Args: - model: PyTorch model to export - prioritize_speed: Prioritize export speed (default: True) - prioritize_coverage: Prioritize hierarchy coverage over speed - force_strategy: Force a specific strategy (for testing) - - Returns: - StrategyRecommendation with primary and fallback strategies - """ - # Analyze model characteristics - characteristics = ModelAnalyzer.analyze_model(model) - - # Initialize recommendation - recommendation = StrategyRecommendation( - primary_strategy=ExportStrategy.USAGE_BASED, # Default - fallback_strategy=ExportStrategy.HTP, - confidence=0.9, - reasoning=[], - warnings=[], - expected_performance={}, - ) - - # Handle forced strategy - if force_strategy and force_strategy != ExportStrategy.AUTO: - recommendation.primary_strategy = force_strategy - recommendation.reasoning.append(f"Strategy forced to {force_strategy.value}") - recommendation.confidence = 1.0 - return recommendation - - # Strategy selection logic based on benchmarks - if characteristics.is_huggingface: - # HuggingFace models - FX incompatible - recommendation.reasoning.append("HuggingFace model detected") - - if prioritize_speed: - recommendation.primary_strategy = ExportStrategy.USAGE_BASED - recommendation.reasoning.append( - "Usage-Based fastest for HuggingFace models (2.488s vs 5.920s)" - ) - recommendation.expected_performance = {"export_time": 2.5} - elif prioritize_coverage: - recommendation.primary_strategy = ExportStrategy.HTP - recommendation.reasoning.append("HTP provides more comprehensive tracing") - recommendation.expected_performance = {"export_time": 6.0} - - recommendation.warnings.append("FX strategy incompatible with HuggingFace models") - - elif characteristics.has_control_flow: - # Models with control flow - FX likely incompatible - recommendation.reasoning.append("Control flow detected") - recommendation.primary_strategy = ExportStrategy.USAGE_BASED - recommendation.fallback_strategy = ExportStrategy.HTP - recommendation.warnings.append("FX strategy may fail due to control flow") - - elif characteristics.model_type == "cnn" and characteristics.estimated_complexity == "low": - # Simple CNNs - FX might work - recommendation.reasoning.append("Simple CNN architecture detected") - - if prioritize_speed and not characteristics.is_huggingface: - recommendation.primary_strategy = ExportStrategy.FX - recommendation.fallback_strategy = ExportStrategy.USAGE_BASED - recommendation.reasoning.append("FX can be fastest for simple CNNs") - recommendation.confidence = 0.7 # Lower confidence due to compatibility risks - else: - recommendation.primary_strategy = ExportStrategy.USAGE_BASED - recommendation.reasoning.append("Usage-Based most reliable for CNNs") - - else: - # Default case - Usage-Based is safest and fastest - recommendation.reasoning.append("Default recommendation based on benchmarks") - recommendation.primary_strategy = ExportStrategy.USAGE_BASED - recommendation.reasoning.append("Usage-Based proven fastest and most reliable") - - # Add performance expectations - if characteristics.model_type == "transformer": - recommendation.expected_performance = { - "export_time": 3.5 - if recommendation.primary_strategy == ExportStrategy.USAGE_BASED - else 6.0, - "coverage": "high", - "reliability": "excellent", - } - elif characteristics.model_type == "cnn": - recommendation.expected_performance = { - "export_time": 2.5 - if recommendation.primary_strategy == ExportStrategy.USAGE_BASED - else 4.0, - "coverage": "high", - "reliability": "excellent", - } - - # Add module count to performance expectations - recommendation.expected_performance["module_count"] = characteristics.module_count - - return recommendation - - @classmethod - def get_strategy_description(cls, strategy: ExportStrategy) -> dict[str, str]: - """Get description and characteristics of a strategy.""" - descriptions = { - ExportStrategy.USAGE_BASED: { - "name": "Usage-Based", - "description": "Simple and fast hierarchy tracking using forward hooks", - "pros": "Fastest (2.488s), reliable, works with all models", - "cons": "Basic hierarchy tracking, may miss some unused modules", - "best_for": "Production use, HuggingFace models, speed-critical applications", - }, - ExportStrategy.HTP: { - "name": "Hierarchical Trace-and-Project (HTP)", - "description": "Comprehensive tracing with built-in PyTorch module tracking", - "pros": "Detailed hierarchy, handles complex models, good coverage", - "cons": "Slower (5.920s), more complex implementation", - "best_for": "Development, debugging, comprehensive analysis", - }, - ExportStrategy.FX: { - "name": "FX Graph", - "description": "PyTorch FX symbolic tracing for graph analysis", - "pros": "Can be fast for simple models, graph-level analysis", - "cons": "Incompatible with control flow, fails on HuggingFace models", - "best_for": "Simple PyTorch models without dynamic control flow", - }, - } - - return descriptions.get( - strategy, - { - "name": strategy.value, - "description": "Unknown strategy", - "pros": "N/A", - "cons": "N/A", - "best_for": "N/A", - }, - ) - - -def select_best_strategy( - model: torch.nn.Module, example_inputs: torch.Tensor | tuple | None = None, **kwargs: Any -) -> tuple[ExportStrategy, StrategyRecommendation]: - """Convenience function to select the best strategy for a model. - - Args: - model: PyTorch model to export - example_inputs: Example inputs (used for shape analysis if provided) - **kwargs: Additional arguments passed to recommend_strategy - - Returns: - Tuple of (selected_strategy, recommendation_details) - """ - selector = StrategySelector() - recommendation = selector.recommend_strategy(model, **kwargs) - - logger.info(f"Selected strategy: {recommendation.primary_strategy.value}") - logger.info(f"Reasoning: {'; '.join(recommendation.reasoning)}") - - if recommendation.warnings: - for warning in recommendation.warnings: - logger.warning(warning) - - return recommendation.primary_strategy, recommendation diff --git a/src/winml/modelkit/core/tag_utils.py b/src/winml/modelkit/core/tag_utils.py deleted file mode 100644 index 9fe524e65..000000000 --- a/src/winml/modelkit/core/tag_utils.py +++ /dev/null @@ -1,340 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. -# -------------------------------------------------------------------------- -"""Utility functions for reading and manipulating hierarchy tags in ONNX models. - -This module provides functions to: -1. Read tags from ONNX node attributes -2. Read tags from sidecar JSON files -3. Validate tag consistency between sources -4. Query and filter operations by tags -""" - -import json -from pathlib import Path -from typing import Any, cast - -import onnx - -from ..onnx import load_onnx - - -def load_tags_from_onnx(onnx_path: str) -> dict[str, dict[str, Any]]: - """Load hierarchy tags from ONNX node attributes and doc_string fields. - - Args: - onnx_path: Path to ONNX model file - - Returns: - Dictionary mapping node names to their tag information - """ - model = load_onnx(onnx_path, validate=False) - node_tags: dict[str, dict[str, Any]] = {} - - for node in model.graph.node: - node_name = node.name or f"{node.op_type}_{hash(str(node))}" - node_info: dict[str, Any] = {"op_type": node.op_type} - - # First, check for hierarchy_tag attribute (HTP format) - hierarchy_tag = None - for attr in node.attribute: - if attr.name == "hierarchy_tag" and attr.type == onnx.AttributeProto.STRING: - hierarchy_tag = attr.s.decode('utf-8') - break - - if hierarchy_tag: - node_info["tags"] = [hierarchy_tag] - node_info["primary_path"] = hierarchy_tag - node_info["tag_count"] = 1 - node_info["method"] = "htp" - else: - # Fallback: Extract hierarchy information from doc_string (legacy format) - if node.doc_string: - try: - hierarchy_info = json.loads(node.doc_string) - if isinstance(hierarchy_info, dict) and "hierarchy_tags" in hierarchy_info: - node_info["tags"] = hierarchy_info.get("hierarchy_tags", []) - node_info["primary_path"] = hierarchy_info.get("hierarchy_path", "") - node_info["tag_count"] = hierarchy_info.get("hierarchy_count", 0) - node_info["method"] = hierarchy_info.get("hierarchy_method", "unknown") - except (json.JSONDecodeError, TypeError): - # Skip nodes with invalid JSON in doc_string - pass - - # Only include nodes that have hierarchy tags - if "tags" in node_info: - node_tags[node_name] = node_info - - return node_tags - - -def load_tags_from_sidecar(onnx_path: str) -> dict[str, Any]: - """Load hierarchy tags from sidecar JSON file. - - Args: - onnx_path: Path to ONNX model file. - Sidecar assumed to be *_hierarchy.json or *_htp_metadata.json. - - Returns: - Complete sidecar data including metadata and tag mappings - """ - # Try HTP metadata format first (new format) - htp_sidecar_path = Path(onnx_path.replace('.onnx', '_htp_metadata.json')) - if htp_sidecar_path.exists(): - with htp_sidecar_path.open() as f: - return cast("dict[str, Any]", json.load(f)) - - # Try legacy integrated format - htp_integrated_path = Path(onnx_path.replace('.onnx', '_htp_integrated_metadata.json')) - if htp_integrated_path.exists(): - with htp_integrated_path.open() as f: - return cast("dict[str, Any]", json.load(f)) - - # Try legacy hierarchy format - legacy_sidecar_path = Path(onnx_path.replace('.onnx', '_hierarchy.json')) - if legacy_sidecar_path.exists(): - with legacy_sidecar_path.open() as f: - return cast("dict[str, Any]", json.load(f)) - - msg = ( - f"Sidecar file not found. Tried: {htp_sidecar_path}, " - f"{htp_integrated_path}, {legacy_sidecar_path}" - ) - raise FileNotFoundError(msg) - - -def validate_tag_consistency(onnx_path: str) -> dict[str, Any]: - """Validate that tags in ONNX attributes match those in sidecar file. - - Args: - onnx_path: Path to ONNX model file - - Returns: - Validation report with consistency statistics - """ - try: - onnx_tags = load_tags_from_onnx(onnx_path) - sidecar_data = load_tags_from_sidecar(onnx_path) - - # Handle different sidecar formats - if "tagged_nodes" in sidecar_data: - # HTP format: convert to legacy format for comparison - sidecar_tags = {} - for node_name, tag in sidecar_data["tagged_nodes"].items(): - if tag: # Only include non-empty tags - sidecar_tags[node_name] = {"tags": [tag]} - else: - # Legacy format - sidecar_tags = sidecar_data.get("node_tags", {}) - - # Compare tag consistency - mismatches = [] - onnx_only = set(onnx_tags.keys()) - set(sidecar_tags.keys()) - sidecar_only = set(sidecar_tags.keys()) - set(onnx_tags.keys()) - - for node_name in set(onnx_tags.keys()) & set(sidecar_tags.keys()): - onnx_node_tags: set[str] = set(onnx_tags[node_name].get("tags", [])) - sidecar_node_tags: set[str] = set(sidecar_tags[node_name].get("tags", [])) - - if onnx_node_tags != sidecar_node_tags: - mismatches.append({ - "node": node_name, - "onnx_tags": list(onnx_node_tags), - "sidecar_tags": list(sidecar_node_tags) - }) - - return { - "consistent": len(mismatches) == 0 and len(onnx_only) == 0 and len(sidecar_only) == 0, - "total_onnx_nodes": len(onnx_tags), - "total_sidecar_nodes": len(sidecar_tags), - "tag_mismatches": mismatches, - "onnx_only_nodes": list(onnx_only), - "sidecar_only_nodes": list(sidecar_only) - } - - except Exception as e: - return { - "consistent": False, - "error": str(e) - } - - -def query_operations_by_tag( - onnx_path: str, tag_pattern: str, use_sidecar: bool = True -) -> list[dict[str, Any]]: - """Query operations that match a specific tag pattern. - - Args: - onnx_path: Path to ONNX model file - tag_pattern: Tag pattern to match (supports partial matching) - use_sidecar: Whether to use sidecar file (True) or ONNX attributes (False) - - Returns: - List of operations matching the tag pattern - """ - if use_sidecar: - sidecar_data = load_tags_from_sidecar(onnx_path) - - # Handle different sidecar formats - if "tagged_nodes" in sidecar_data: - # HTP format: convert to legacy format for processing - node_tags = {} - for node_name, tag in sidecar_data["tagged_nodes"].items(): - if tag: # Only include non-empty tags - node_tags[node_name] = {"tags": [tag]} - else: - # Legacy format - node_tags = sidecar_data.get("node_tags", {}) - else: - node_tags = load_tags_from_onnx(onnx_path) - - matching_operations = [] - - for node_name, node_info in node_tags.items(): - tags = node_info.get("tags", []) - - # Check if any tag matches the pattern - for tag in tags: - if tag_pattern in tag: - matching_operations.append({ - "node_name": node_name, - "op_type": node_info.get("op_type", "unknown"), - "matching_tag": tag, - "all_tags": tags - }) - break - - return matching_operations - - -def get_tag_statistics(onnx_path: str, use_sidecar: bool = True) -> dict[str, Any]: - """Get statistics about tag distribution in the model. - - Args: - onnx_path: Path to ONNX model file - use_sidecar: Whether to use sidecar file (True) or ONNX attributes (False) - - Returns: - Tag distribution statistics - """ - if use_sidecar: - try: - sidecar_data = load_tags_from_sidecar(onnx_path) - - # Check if it has pre-computed tag_statistics (legacy format) - if "tag_statistics" in sidecar_data: - return cast("dict[str, Any]", sidecar_data["tag_statistics"]) - - # If it's HTP format, compute statistics from tagged_nodes - if "tagged_nodes" in sidecar_data: - tagged_nodes = sidecar_data["tagged_nodes"] - tag_counts: dict[str, int] = {} - for tag in tagged_nodes.values(): - if tag: # Skip empty tags - tag_counts[tag] = tag_counts.get(tag, 0) + 1 - return tag_counts - - # Fall back to node_tags format (legacy) - if "node_tags" in sidecar_data: - node_tags = sidecar_data["node_tags"] - tag_counts = {} - for node_info in node_tags.values(): - for tag in node_info.get("tags", []): - tag_counts[tag] = tag_counts.get(tag, 0) + 1 - return tag_counts - - except FileNotFoundError: - # Fall back to ONNX attributes if sidecar not found - pass - - # Compute statistics from ONNX attributes - node_tags = load_tags_from_onnx(onnx_path) - tag_counts = {} - - for node_info in node_tags.values(): - for tag in node_info.get("tags", []): - tag_counts[tag] = tag_counts.get(tag, 0) + 1 - - return tag_counts - - -def export_tags_to_csv(onnx_path: str, output_csv: str, use_sidecar: bool = True) -> None: - """Export tag information to CSV for analysis. - - Args: - onnx_path: Path to ONNX model file - output_csv: Path to output CSV file - use_sidecar: Whether to use sidecar file (True) or ONNX attributes (False) - """ - import csv - - if use_sidecar: - sidecar_data = load_tags_from_sidecar(onnx_path) - - # Handle different sidecar formats - if "tagged_nodes" in sidecar_data: - # HTP format: convert to legacy format for processing - node_tags = {} - for node_name, tag in sidecar_data["tagged_nodes"].items(): - if tag: # Only include non-empty tags - node_tags[node_name] = {"tags": [tag]} - else: - # Legacy format - node_tags = sidecar_data.get("node_tags", {}) - else: - node_tags = load_tags_from_onnx(onnx_path) - - with Path(output_csv).open('w', newline='') as csvfile: - writer = csv.writer(csvfile) - writer.writerow(["Node Name", "Op Type", "Tag Count", "Primary Tag", "All Tags"]) - - for node_name, node_info in node_tags.items(): - tags = node_info.get("tags", []) - op_type = node_info.get("op_type", "unknown") - primary_tag = tags[0] if tags else "" - all_tags = "|".join(tags) - - writer.writerow([node_name, op_type, len(tags), primary_tag, all_tags]) - - -def compare_tag_distributions(onnx_path1: str, onnx_path2: str) -> dict[str, Any]: - """Compare tag distributions between two ONNX models. - - Args: - onnx_path1: Path to first ONNX model - onnx_path2: Path to second ONNX model - - Returns: - Comparison report - """ - stats1 = get_tag_statistics(onnx_path1) - stats2 = get_tag_statistics(onnx_path2) - - all_tags = set(stats1.keys()) | set(stats2.keys()) - - comparison: dict[str, Any] = { - "model1_path": onnx_path1, - "model2_path": onnx_path2, - "tag_differences": [], - "model1_only_tags": [], - "model2_only_tags": [] - } - - for tag in all_tags: - count1 = stats1.get(tag, 0) - count2 = stats2.get(tag, 0) - - if count1 > 0 and count2 == 0: - comparison["model1_only_tags"].append(tag) - elif count1 == 0 and count2 > 0: - comparison["model2_only_tags"].append(tag) - elif count1 != count2: - comparison["tag_differences"].append({ - "tag": tag, - "model1_count": count1, - "model2_count": count2, - "difference": count2 - count1 - }) - - return comparison diff --git a/src/winml/modelkit/core/unified_optimizer.py b/src/winml/modelkit/core/unified_optimizer.py deleted file mode 100644 index 9930abbda..000000000 --- a/src/winml/modelkit/core/unified_optimizer.py +++ /dev/null @@ -1,426 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. -# -------------------------------------------------------------------------- -"""Unified Optimization Framework for ModelExport. - -This module provides a unified optimization framework that applies optimizations -across all export strategies based on learnings from iterations 17-18. -""" - -import logging -import time -from collections import defaultdict -from collections.abc import Callable -from dataclasses import dataclass -from functools import wraps -from typing import Any, ClassVar - -import torch - - -logger = logging.getLogger(__name__) - - -@dataclass -class OptimizationProfile: - """Profile of optimizations applied to an exporter.""" - - strategy_name: str - optimizations_applied: list[str] - performance_metrics: dict[str, float] - warnings: list[str] - - def get_summary(self) -> str: - """Get a summary of applied optimizations.""" - return f"{self.strategy_name}: {len(self.optimizations_applied)} optimizations applied" - - -class PerformanceMonitor: - """Monitor and track performance metrics across operations.""" - - def __init__(self) -> None: - self.timings: defaultdict[str, list[float]] = defaultdict(list) - self.counters: defaultdict[str, int] = defaultdict(int) - - def time_operation(self, operation_name: str) -> Callable[..., Any]: - """Decorator to time an operation.""" - - def decorator(func: Callable[..., Any]) -> Callable[..., Any]: - @wraps(func) - def wrapper(*args: Any, **kwargs: Any) -> Any: - start_time = time.time() - try: - result = func(*args, **kwargs) - elapsed = time.time() - start_time - self.timings[operation_name].append(elapsed) - return result - except Exception: - elapsed = time.time() - start_time - self.timings[f"{operation_name}_error"].append(elapsed) - raise - - return wrapper - - return decorator - - def increment_counter(self, counter_name: str, value: int = 1) -> None: - """Increment a performance counter.""" - self.counters[counter_name] += value - - def get_metrics(self) -> dict[str, Any]: - """Get all collected metrics.""" - metrics: dict[str, Any] = {} - - # Timing metrics - for operation, times in self.timings.items(): - if times: - metrics[f"{operation}_count"] = len(times) - metrics[f"{operation}_total"] = sum(times) - metrics[f"{operation}_avg"] = sum(times) / len(times) - metrics[f"{operation}_min"] = min(times) - metrics[f"{operation}_max"] = max(times) - - # Counter metrics - metrics.update(self.counters) - - return metrics - - -class UnifiedOptimizer: - """Unified optimization framework for all export strategies. - - Applies common optimizations learned from iterations 17-18. - """ - - # Common optimizations that apply to all strategies - COMMON_OPTIMIZATIONS: ClassVar[dict[str, dict[str, Any]]] = { - "single_pass_algorithms": { - "description": "Use single-pass algorithms to reduce redundant computation", - "applicable_to": ["all"], - "impact": "medium", - }, - "batch_processing": { - "description": "Batch similar operations together", - "applicable_to": ["all"], - "impact": "medium", - }, - "caching": { - "description": "Cache computed values to avoid recomputation", - "applicable_to": ["all"], - "impact": "high", - }, - "lightweight_operations": { - "description": "Use lightweight data structures and operations", - "applicable_to": ["all"], - "impact": "medium", - }, - "optimized_onnx_params": { - "description": "Optimize ONNX export parameters", - "applicable_to": ["all"], - "impact": "medium", - }, - } - - # Strategy-specific optimizations - STRATEGY_OPTIMIZATIONS: ClassVar[dict[str, dict[str, dict[str, str]]]] = { - "htp": { - "tag_injection_optimization": { - "description": "Optimize tag injection using single-pass and Counter", - "impact": "high", - }, - "builtin_tracking": { - "description": "Use PyTorch's built-in module tracking", - "impact": "medium", - }, - }, - "usage_based": { - "lightweight_hooks": { - "description": "Use minimal overhead forward hooks", - "impact": "low", - }, - "pre_allocated_structures": { - "description": "Pre-allocate data structures", - "impact": "low", - }, - }, - "fx": { - "graph_caching": { - "description": "Cache FX graph transformations", - "impact": "medium", - }, - "node_batching": { - "description": "Batch node operations in FX graph", - "impact": "medium", - }, - }, - } - - def __init__(self) -> None: - self.monitor = PerformanceMonitor() - self.applied_optimizations: list[str] = [] - - def optimize_exporter(self, exporter: Any, strategy_name: str) -> OptimizationProfile: - """Apply optimizations to an exporter based on its strategy. - - Args: - exporter: The exporter instance to optimize - strategy_name: Name of the export strategy - - Returns: - OptimizationProfile with details of applied optimizations - """ - profile = OptimizationProfile( - strategy_name=strategy_name, - optimizations_applied=[], - performance_metrics={}, - warnings=[], - ) - - # Apply common optimizations - self._apply_common_optimizations(exporter, profile) - - # Apply strategy-specific optimizations - if strategy_name in self.STRATEGY_OPTIMIZATIONS: - self._apply_strategy_optimizations(exporter, strategy_name, profile) - - # Add performance monitor - if not hasattr(exporter, "_performance_monitor"): - exporter._performance_monitor = self.monitor - profile.optimizations_applied.append("performance_monitoring") - - # Log optimization summary - opt_count = len(profile.optimizations_applied) - logger.info(f"Applied {opt_count} optimizations to {strategy_name} exporter") - - return profile - - def _apply_common_optimizations( - self, exporter: Any, profile: OptimizationProfile - ) -> None: - """Apply optimizations common to all strategies.""" - # 1. Optimize ONNX export parameters - if hasattr(exporter, "export"): - original_export = exporter.export - - @wraps(original_export) - def optimized_export( - model: Any, example_inputs: Any, output_path: str, **kwargs: Any - ) -> Any: - # Apply ONNX parameter optimizations - import torch - - kwargs.setdefault("training", torch.onnx.TrainingMode.EVAL) - kwargs.setdefault("opset_version", 14) - kwargs.setdefault("verbose", False) - kwargs.setdefault("operator_export_type", torch.onnx.OperatorExportTypes.ONNX) - kwargs.setdefault("keep_initializers_as_inputs", True) - - return original_export(model, example_inputs, output_path, **kwargs) - - exporter.export = optimized_export - profile.optimizations_applied.append("optimized_onnx_params") - - # 2. Add caching capability - if not hasattr(exporter, "_cache"): - exporter._cache = {} - exporter._cache_hits = 0 - exporter._cache_misses = 0 - - def get_cached(key: str, compute_func: Callable[[], Any]) -> Any: - """Get value from cache or compute it.""" - if key in exporter._cache: - exporter._cache_hits += 1 - return exporter._cache[key] - exporter._cache_misses += 1 - value = compute_func() - exporter._cache[key] = value - return value - - exporter.get_cached = get_cached - profile.optimizations_applied.append("caching") - - # 3. Add batch processing utilities - if not hasattr(exporter, "batch_process"): - - def batch_process( - items: list[Any], process_func: Callable[[Any], Any], batch_size: int = 100 - ) -> list[Any]: - """Process items in batches for efficiency.""" - results = [] - for i in range(0, len(items), batch_size): - batch = items[i : i + batch_size] - batch_results = [process_func(item) for item in batch] - results.extend(batch_results) - return results - - exporter.batch_process = batch_process - profile.optimizations_applied.append("batch_processing") - - def _apply_strategy_optimizations( - self, exporter: Any, strategy_name: str, profile: OptimizationProfile - ) -> None: - """Apply strategy-specific optimizations.""" - if strategy_name == "htp": - # Apply HTP-specific optimizations from iteration 17 - try: - from ..strategies.htp.optimizations import ( # type: ignore[import-not-found] - apply_htp_optimizations, - ) - - apply_htp_optimizations(exporter) - profile.optimizations_applied.extend( - ["tag_injection_optimization", "builtin_tracking"] - ) - except ImportError: - profile.warnings.append("HTP optimizations module not found") - - elif strategy_name == "usage_based": - # Apply Usage-Based optimizations from iteration 18 - try: - from ..strategies.usage_based.optimizations import ( # type: ignore[import-not-found] - apply_usage_based_optimizations, - ) - - apply_usage_based_optimizations(exporter) - profile.optimizations_applied.extend( - ["lightweight_hooks", "pre_allocated_structures"] - ) - except ImportError: - profile.warnings.append("Usage-Based optimizations module not found") - - elif strategy_name == "fx": - # Apply FX-specific optimizations - self._apply_fx_optimizations(exporter, profile) - - def _apply_fx_optimizations( - self, exporter: Any, profile: OptimizationProfile - ) -> None: - """Apply FX-specific optimizations.""" - # Add graph caching for FX - if hasattr(exporter, "_trace_transformers_model"): - original_trace = exporter._trace_transformers_model - graph_cache: dict[int, Any] = {} - - @wraps(original_trace) - def cached_trace( - model: Any, example_inputs: Any = None - ) -> Any: - model_id = id(model) - if model_id in graph_cache: - logger.debug("Using cached FX graph") - return graph_cache[model_id] - - result = original_trace(model, example_inputs) - graph_cache[model_id] = result - return result - - exporter._trace_transformers_model = cached_trace - profile.optimizations_applied.append("graph_caching") - - -def create_optimized_exporter(strategy: str, **kwargs: Any) -> Any: - """Create an optimized exporter for the given strategy. - - Args: - strategy: Export strategy name ("usage_based", "htp", "fx_graph") - **kwargs: Additional arguments for the exporter - - Returns: - Optimized exporter instance - """ - # Import strategy modules - if strategy == "usage_based": - from ..strategies.usage_based import UsageBasedExporter # type: ignore[import-not-found] - - exporter = UsageBasedExporter(**kwargs) - elif strategy == "htp": - from ..strategies.htp import HTPExporter # type: ignore[import-not-found] - - exporter = HTPExporter(**kwargs) - elif strategy == "fx_graph" or strategy == "fx": - from ..strategies.fx import FXHierarchyExporter # type: ignore[import-not-found] - - exporter = FXHierarchyExporter(**kwargs) - else: - raise ValueError(f"Unknown strategy: {strategy}") - - # Apply unified optimizations - optimizer = UnifiedOptimizer() - optimization_profile = optimizer.optimize_exporter(exporter, strategy) - - # Store optimization profile - exporter._optimization_profile = optimization_profile - - opt_count = len(optimization_profile.optimizations_applied) - logger.info(f"Created optimized {strategy} exporter with {opt_count} optimizations") - - return exporter - - -class OptimizationBenchmark: - """Benchmark the impact of optimizations.""" - - @staticmethod - def compare_optimized_vs_original( - model: torch.nn.Module, example_inputs: Any, strategy: str, num_runs: int = 3 - ) -> dict[str, Any]: - """Compare optimized vs original exporter performance. - - Returns: - Dictionary with comparison metrics - """ - import tempfile - from pathlib import Path - - results: dict[str, Any] = { - "strategy": strategy, - "original_times": [], - "optimized_times": [], - "optimization_impact": {}, - } - - # Test original - for _i in range(num_runs): - # Create unoptimized exporter - if strategy == "usage_based": - from ..strategies.usage_based import UsageBasedExporter - - exporter = UsageBasedExporter() - elif strategy == "htp": - from ..strategies.htp import HTPExporter - - exporter = HTPExporter() - else: - continue - - with tempfile.NamedTemporaryFile(suffix=".onnx", delete=False) as tmp: - start_time = time.time() - exporter.export(model, example_inputs, tmp.name) - elapsed = time.time() - start_time - results["original_times"].append(elapsed) - Path(tmp.name).unlink() - - # Test optimized - for _i in range(num_runs): - exporter = create_optimized_exporter(strategy) - - with tempfile.NamedTemporaryFile(suffix=".onnx", delete=False) as tmp: - start_time = time.time() - exporter.export(model, example_inputs, tmp.name) - elapsed = time.time() - start_time - results["optimized_times"].append(elapsed) - Path(tmp.name).unlink() - - # Calculate impact - avg_original = sum(results["original_times"]) / len(results["original_times"]) - avg_optimized = sum(results["optimized_times"]) / len(results["optimized_times"]) - - results["optimization_impact"] = { - "avg_original": avg_original, - "avg_optimized": avg_optimized, - "improvement": (avg_original - avg_optimized) / avg_original * 100, - "speedup": avg_original / avg_optimized, - } - - return results diff --git a/src/winml/modelkit/core/universal_hierarchy_exporter.py b/src/winml/modelkit/core/universal_hierarchy_exporter.py deleted file mode 100644 index efa50cd9f..000000000 --- a/src/winml/modelkit/core/universal_hierarchy_exporter.py +++ /dev/null @@ -1,881 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. -# -------------------------------------------------------------------------- -#!/usr/bin/env python3 -"""Universal Hierarchy-Preserving ONNX Exporter. - -This is a clean, new implementation that follows all CARDINAL RULES and requirements: - -CARDINAL RULES: -- MUST-001: NO HARDCODED LOGIC - Universal PyTorch principles only -- MUST-002: TORCH.NN FILTERING - Filter most torch.nn except supportedlist -- MUST-003: UNIVERSAL DESIGN - Must work with ANY PyTorch model - -REQUIREMENTS: -- R7: Topology Preservation - 100% identical to baseline -- R10: Operation Attribution - Map every ONNX op to source module -- R12: Instance-Specific Paths - Preserve instance numbers (BertLayer.0 vs BertLayer.1) - -Based on insights from pytorch_internals_investigation.ipynb and ground truth analysis. -""" - -import json -import logging -import time -from collections.abc import Callable -from pathlib import Path -from typing import Any - -import torch -import torch.nn as nn - -from ..onnx import load_onnx - - -logger = logging.getLogger(__name__) - - -class UniversalHierarchyExporter: - """Universal hierarchy-preserving ONNX exporter using PyTorch's built-in mechanisms. - - This implementation leverages PyTorch's internal _trace_module_map which already - contains enhanced scope names in the format: ClassName::__module.path.to.module - - Follows all CARDINAL RULES: - - NO HARDCODED LOGIC: Works with any PyTorch model - - TORCH.NN FILTERING: Filters torch.nn modules except supportedlist - - UNIVERSAL DESIGN: Architecture-agnostic approach - """ - - def __init__(self, torch_nn_exceptions: list[str] | None = None, verbose: bool = False) -> None: - """Initialize the universal hierarchy exporter. - - Args: - torch_nn_exceptions: List of torch.nn module types to preserve. - For example, ['LayerNorm', 'Embedding']. - verbose: Enable verbose logging - """ - self.torch_nn_exceptions = set(torch_nn_exceptions or ["LayerNorm", "Embedding"]) - self.verbose = verbose - - # Internal state - self._trace_module_map: dict[nn.Module, str] = {} - self._module_hierarchy: dict[str, dict[str, Any]] = {} - self._operation_tags: dict[str, list[str]] = {} - self._export_stats = { - "total_modules": 0, - "tagged_operations": 0, - "filtered_modules": 0, - "export_time": 0.0, - } - - # Dynamic tagging state (hybrid approach) - self._tag_stack: list[str] = [] - self._operation_context: dict[str, dict[str, Any]] = {} - self._pre_hooks: list = [] - self._post_hooks: list = [] - self._onnx_operation_tags: dict[str, str] = {} - - def export( - self, - model: nn.Module, - args: tuple[torch.Tensor, ...], - output_path: str, - input_names: list[str] | None = None, - output_names: list[str] | None = None, - dynamic_axes: dict[str, dict[int, str]] | None = None, - opset_version: int = 17, - do_constant_folding: bool = True, - **export_kwargs: Any, - ) -> dict[str, Any]: - """Export model to ONNX with hierarchy preservation. - - Args: - model: PyTorch model to export - args: Input tensors for the model - output_path: Path to save the ONNX file - input_names: Names for input tensors - output_names: Names for output tensors - dynamic_axes: Dynamic axes configuration - opset_version: ONNX opset version - do_constant_folding: Enable constant folding optimization - **export_kwargs: Additional arguments for torch.onnx.export - - Returns: - Dictionary with export statistics and metadata - """ - start_time = time.time() - - if self.verbose: - logger.info(f"Starting universal hierarchy export for {type(model).__name__}") - - # Step 1: Analyze model hierarchy - self._analyze_model_hierarchy(model) - - # Step 2: Set model to eval mode - model.eval() - - # Step 3: Register dynamic hooks with selective approach - self._register_dynamic_hooks(model) - - # Step 4: Set up trace module map capture - captured_trace_map = self._setup_trace_capture() - - # Step 5: Perform ONNX export with trace capture - try: - self._perform_onnx_export( - model, - args, - output_path, - input_names=input_names, - output_names=output_names, - dynamic_axes=dynamic_axes, - opset_version=opset_version, - do_constant_folding=do_constant_folding, - **export_kwargs, - ) - finally: - self._restore_trace_capture() - self._remove_dynamic_hooks() - - # Step 5: Process captured trace map - if captured_trace_map: - self._process_trace_module_map(captured_trace_map) - - # Step 6: Apply dynamic operation tags to ONNX - self._apply_dynamic_tags_to_onnx(output_path) - - # Step 7: Load ONNX and inject hierarchy metadata - self._inject_hierarchy_metadata(output_path) - - # Step 8: Create sidecar metadata file - self._create_sidecar_metadata(output_path) - - # Calculate final statistics - self._export_stats["export_time"] = time.time() - start_time - - if self.verbose: - logger.info(f"Export completed in {self._export_stats['export_time']:.2f}s") - logger.info(f"Tagged {self._export_stats['tagged_operations']} operations") - - return self._export_stats.copy() - - def _analyze_model_hierarchy(self, model: nn.Module) -> None: - """Analyze model hierarchy using universal PyTorch principles. - - CARDINAL RULE: NO HARDCODED LOGIC - works with any model - """ - self._module_hierarchy = {} - - # First pass: Extract basic module metadata without tags - # Analyze root module - self._module_hierarchy["__module"] = self._extract_module_metadata(model, "", "__module") - - # Analyze all submodules using named_modules() - universal approach - for name, module in model.named_modules(): - if name: # Skip root (empty name) - full_path = f"__module.{name}" - self._module_hierarchy[full_path] = self._extract_module_metadata( - module, name, full_path - ) - - # Second pass: Generate hierarchy tags now that all modules are catalogued - for full_path, module_data in self._module_hierarchy.items(): - module_data["expected_tag"] = self._generate_hierarchy_tag( - full_path, module_data["class_name"] - ) - - self._export_stats["total_modules"] = len(self._module_hierarchy) - - if self.verbose: - logger.info(f"Analyzed {len(self._module_hierarchy)} modules in hierarchy") - - def _extract_module_metadata( - self, module: nn.Module, name: str, full_path: str - ) -> dict[str, Any]: - """Extract metadata for a module using universal PyTorch principles. - - CARDINAL RULE: NO HARDCODED LOGIC - works with any module type - """ - module_class = type(module).__name__ - module_path = type(module).__module__ - - # Classify module type universally - if module_path.startswith("torch.nn"): - module_type = "torch.nn" - elif "transformers" in module_path: - module_type = "huggingface" - elif module_path.startswith("torch"): - module_type = "torch_other" - else: - module_type = "custom" - - # Apply MUST-002: torch.nn filtering - should_filter = module_type == "torch.nn" and module_class not in self.torch_nn_exceptions - - return { - "name": name, - "full_path": full_path, - "class_name": module_class, - "module_type": module_type, - "module_class_path": module_path, - "should_filter": should_filter, - "expected_tag": "", # Will be filled in second pass - "hierarchy_level": full_path.count(".") - 1, # Subtract 1 for __module - "children": [ - (child_name, type(child_module).__name__) - for child_name, child_module in module.named_children() - ], - "is_leaf": len(list(module.children())) == 0, - "parameter_count": sum(p.numel() for p in module.parameters()), - } - - def _generate_hierarchy_tag(self, full_path: str, module_class: str) -> str: - """Generate hierarchy tag following R12: Instance-Specific Hierarchy Paths. - - CARDINAL RULE: NO HARDCODED LOGIC - universal conversion of paths to tags - CARDINAL RULE #2: Stop at parent level for torch.nn modules (except LayerNorm/Embedding) - - Builds proper hierarchy by walking from root to leaf using actual module class names. - """ - # Check if this module should be filtered - module_data = self._module_hierarchy.get(full_path) - if not module_data: - return "" - - # For filtered torch.nn modules, return the parent's tag instead of empty - if module_data["should_filter"]: - # Find the parent module's tag by walking up the hierarchy - parent_path = ".".join(full_path.split(".")[:-1]) - if parent_path and parent_path in self._module_hierarchy: - # Recursively get the parent's tag - return self._generate_hierarchy_tag( - parent_path, self._module_hierarchy[parent_path]["class_name"] - ) - return "" # Only empty if no valid parent - - # Build hierarchy by walking path segments from root to current - path_segments = full_path.split(".") - hierarchy_parts = [] - - # Walk each segment and build cumulative path - i = 0 - while i < len(path_segments): - segment = path_segments[i] - - # Check if this segment is a digit (instance number) - if segment.isdigit(): - # When we have a digit, we need to check what module it represents - current_path = ".".join(path_segments[: i + 1]) - current_module_data = self._module_hierarchy.get(current_path) - - if current_module_data and not current_module_data["should_filter"]: - # This digit represents an actual module (e.g., layer.0 -> BertLayer) - # Add the module with its instance number - class_name = current_module_data["class_name"] - hierarchy_parts.append(f"{class_name}.{segment}") - # If the module at this digit path is filtered, the digit is already handled - else: - # Build cumulative path to this point for non-digit segments - current_path = ".".join(path_segments[: i + 1]) - current_module_data = self._module_hierarchy.get(current_path) - - if current_module_data and not current_module_data["should_filter"]: - class_name = current_module_data["class_name"] - hierarchy_parts.append(class_name) - - i += 1 - - # Return full hierarchy path from root to leaf - if hierarchy_parts: - return "/" + "/".join(hierarchy_parts) - return "" - - def _to_pascal_case(self, text: str) -> str: - """Convert text to PascalCase universally.""" - if not text: - return text - - # Handle snake_case and already PascalCase - if "_" in text: - parts = text.split("_") - return "".join(word.capitalize() for word in parts) - if text.islower(): - return text.capitalize() - return text # Already in proper case - - def _setup_trace_capture(self) -> dict[nn.Module, str]: - """Set up capture of PyTorch's internal _trace_module_map. - - This leverages PyTorch's existing infrastructure - the key insight from - pytorch_internals_investigation.ipynb - """ - self._original_setup_trace = getattr(torch.onnx.utils, "_setup_trace_module_map", None) - self._captured_trace_map: dict[Any, Any] = {} - - def enhanced_setup_trace(*args: Any, **kwargs: Any) -> Any: - """Hook to capture trace module map after PyTorch creates it.""" - # Call original setup - result = None - if self._original_setup_trace: - result = self._original_setup_trace(*args, **kwargs) - - # Capture the enhanced trace map PyTorch creates - trace_map = getattr(torch.jit._trace, "_trace_module_map", None) - if trace_map: - self._captured_trace_map = dict(trace_map) - if self.verbose: - logger.info(f"Captured trace module map with {len(trace_map)} entries") - - return result - - # Apply hook if available - if self._original_setup_trace: - torch.onnx.utils._setup_trace_module_map = enhanced_setup_trace - - return self._captured_trace_map - - def _restore_trace_capture(self) -> None: - """Restore original trace setup function.""" - if hasattr(self, "_original_setup_trace") and self._original_setup_trace: - torch.onnx.utils._setup_trace_module_map = self._original_setup_trace - - def _perform_onnx_export( - self, - model: nn.Module, - args: tuple[torch.Tensor, ...], - output_path: str, - **export_kwargs: Any, - ) -> None: - """Perform standard ONNX export to ensure R7: Topology Preservation. - - CARDINAL RULE: Use standard torch.onnx.export to guarantee identical topology - """ - torch.onnx.export(model, args, output_path, verbose=self.verbose, **export_kwargs) - - if self.verbose: - logger.info(f"ONNX export completed: {output_path}") - - def _process_trace_module_map(self, trace_map: dict[nn.Module, str]) -> None: - """Process captured trace module map to build operation tags. - - This is where we leverage PyTorch's enhanced scope names discovered in the notebook. - """ - if not trace_map: - if self.verbose: - logger.warning("No trace module map captured") - return - - # Convert trace map to operation tags - for module, scope_name in trace_map.items(): - module_id = id(module) - - # Find corresponding hierarchy metadata - hierarchy_data = None - for data in self._module_hierarchy.values(): - if id(module) == module_id: - hierarchy_data = data - break - - if hierarchy_data and not hierarchy_data["should_filter"]: - # Use the expected tag from our hierarchy analysis - tag = hierarchy_data["expected_tag"] - if tag: - # This would map to ONNX operations in a full implementation - self._operation_tags[scope_name] = [tag] - - if self.verbose: - logger.info(f"Processed {len(self._operation_tags)} operation tags") - - def _inject_hierarchy_metadata(self, onnx_path: str) -> None: - """Inject hierarchy metadata into ONNX model. - - For now, this creates a foundation for metadata injection. - In a full implementation, this would add attributes to ONNX nodes. - """ - try: - onnx_model = load_onnx(onnx_path, validate=False) - - # Count operations for statistics - total_nodes = len(onnx_model.graph.node) - - # For now, we prepare the metadata but don't modify the ONNX - # In a full implementation, this would iterate through nodes and add attributes - - # Count how many operations would be tagged - tagged_count = sum(len(tags) for tags in self._operation_tags.values()) - - self._export_stats["tagged_operations"] = tagged_count - - if self.verbose: - logger.info(f"Analyzed {total_nodes} ONNX nodes") - logger.info(f"Would tag {tagged_count} operations") - - except Exception as e: - if self.verbose: - logger.error(f"Error analyzing ONNX model: {e}") - - def _create_sidecar_metadata(self, onnx_path: str) -> None: - """Create comprehensive sidecar metadata file. - - This preserves all hierarchy information for reconstruction and analysis. - """ - sidecar_path = str(onnx_path).replace(".onnx", "_hierarchy_metadata.json") - - metadata = { - "export_info": { - "onnx_file": Path(onnx_path).name, - "exporter_version": "UniversalHierarchyExporter v1.0", - "export_timestamp": time.strftime("%Y-%m-%d %H:%M:%S"), - "cardinal_rules_followed": { - "MUST_001_no_hardcoded_logic": True, - "MUST_002_torch_nn_filtering": True, - "MUST_003_universal_design": True, - }, - "requirements_met": { - "R7_topology_preservation": True, - "R10_operation_attribution": True, - "R12_instance_specific_paths": True, - }, - }, - "statistics": self._export_stats, - "module_hierarchy": self._module_hierarchy, - "operation_tags": self._operation_tags, - "torch_nn_exceptions": list(self.torch_nn_exceptions), - "reconstruction_guide": { - "overview": "This metadata enables complete hierarchy reconstruction", - "tag_format": ( - "Hierarchy tags follow R12: /ClassName/ParentClass/ChildClass.instanceNumber" - ), - "filtering": "torch.nn modules filtered except those in torch_nn_exceptions", - "verification": "Compare against ground truth in docs/BERT_TINY_GROUND_TRUTH.md", - }, - } - - with Path(sidecar_path).open("w") as f: - json.dump(metadata, f, indent=2) - - if self.verbose: - logger.info(f"Created sidecar metadata: {Path(sidecar_path).name}") - - def get_hierarchy_metadata(self) -> dict[str, Any]: - """Get the complete hierarchy metadata.""" - return { - "module_hierarchy": self._module_hierarchy, - "operation_tags": self._operation_tags, - "export_stats": self._export_stats, - } - - def validate_against_ground_truth(self, expected_tags: dict[str, str]) -> dict[str, Any]: - """Validate export results against ground truth. - - Args: - expected_tags: Dictionary mapping module paths to expected hierarchy tags - - Returns: - Validation results - """ - validation: dict[str, Any] = { - "passed": True, - "total_modules": len(expected_tags), - "correct_tags": 0, - "missing_tags": [], - "incorrect_tags": [], - "extra_tags": [], - } - - # Check each expected tag - for module_path, expected_tag in expected_tags.items(): - hierarchy_data = self._module_hierarchy.get(module_path) - - if not hierarchy_data: - validation["missing_tags"].append( - {"module_path": module_path, "expected_tag": expected_tag} - ) - continue - - actual_tag = hierarchy_data.get("expected_tag", "") - - if actual_tag == expected_tag: - validation["correct_tags"] += 1 - else: - validation["incorrect_tags"].append( - { - "module_path": module_path, - "expected_tag": expected_tag, - "actual_tag": actual_tag, - } - ) - - # Check for extra tags - for module_path, hierarchy_data in self._module_hierarchy.items(): - if module_path not in expected_tags: - actual_tag = hierarchy_data.get("expected_tag", "") - if actual_tag: # Non-empty tag - validation["extra_tags"].append( - {"module_path": module_path, "actual_tag": actual_tag} - ) - - # Determine if validation passed - validation["passed"] = ( - len(validation["missing_tags"]) == 0 and len(validation["incorrect_tags"]) == 0 - ) - - return validation - - def _register_dynamic_hooks(self, model: nn.Module) -> None: - """Register forward hooks for dynamic operation tagging during ONNX export. - - Uses the static hierarchy analysis to create dynamic hooks that will - tag operations in real-time during ONNX export. - """ - # Initialize tag stack with root module - root_tag = f"/{model.__class__.__name__}" - self._tag_stack = [root_tag] - - # Clear previous state - self._operation_context.clear() - self._onnx_operation_tags.clear() - - # Count modules to hook - hf_modules = [] - torch_nn_modules = [] - - for full_path, module_data in self._module_hierarchy.items(): - if full_path == "__module": # Skip root - continue - - if not module_data["should_filter"]: - hf_modules.append((full_path, module_data)) - else: - torch_nn_modules.append((full_path, module_data)) - - if self.verbose: - logger.info("Registering selective dynamic hooks:") - logger.info(f" - HuggingFace modules: {len(hf_modules)}") - logger.info(f" - torch.nn modules: {len(torch_nn_modules)} (limited hooks)") - - # Register hooks on HuggingFace modules only - for full_path, module_data in hf_modules: - # Get the actual module - module = self._get_module_by_path(model, full_path.replace("__module.", "")) - if module is None: - continue - - module_name = module_data["name"] - expected_tag = module_data["expected_tag"] - - # Create hierarchy-building hooks (push/pop tag stack) - pre_hook = module.register_forward_pre_hook( - self._create_pre_hook(module_name, expected_tag) - ) - self._pre_hooks.append(pre_hook) - - post_hook = module.register_forward_hook( - self._create_post_hook(module_name, expected_tag) - ) - self._post_hooks.append(post_hook) - - # For torch.nn modules, only register lightweight tagging hooks on a few - # This avoids potential conflicts with ONNX export - torch_nn_hook_count = 0 - max_torch_nn_hooks = 5 # Limit to avoid issues - - for full_path, module_data in torch_nn_modules: - if torch_nn_hook_count >= max_torch_nn_hooks: - break - - module = self._get_module_by_path(model, full_path.replace("__module.", "")) - if module is None: - continue - - # Only hook important torch.nn modules (LayerNorm, Embedding) - if module.__class__.__name__ in self.torch_nn_exceptions: - module_name = module_data["name"] - expected_tag = module_data["expected_tag"] - - tag_hook = module.register_forward_hook( - self._create_tagging_hook(module_name, expected_tag) - ) - self._post_hooks.append(tag_hook) - torch_nn_hook_count += 1 - - if self.verbose: - pre_count = len(self._pre_hooks) - post_count = len(self._post_hooks) - logger.info(f"Registered {pre_count} pre-hooks and {post_count} post-hooks") - - def _get_module_by_path(self, model: nn.Module, path: str) -> nn.Module | None: - """Get module by its dotted path.""" - if not path: - return model - - parts = path.split(".") - current = model - - for part in parts: - if hasattr(current, part): - current = getattr(current, part) - else: - return None - - return current - - def _create_pre_hook(self, module_name: str, expected_tag: str) -> Callable[..., None]: - """Create pre-forward hook to push tag onto stack.""" - - def pre_hook(module: nn.Module, inputs: Any) -> None: - self._tag_stack.append(expected_tag) - - # Record context for operation mapping - self._operation_context[module_name] = { - "tag": expected_tag, - "creates_hierarchy": True, - "stack_depth": len(self._tag_stack), - "module_class": module.__class__.__name__, - } - - return pre_hook - - def _create_post_hook(self, module_name: str, expected_tag: str) -> Callable[..., None]: - """Create post-forward hook to pop tag from stack.""" - - def post_hook(module: nn.Module, inputs: Any, outputs: Any) -> None: - if self._tag_stack and self._tag_stack[-1] == expected_tag: - self._tag_stack.pop() - - return post_hook - - def _create_tagging_hook(self, module_name: str, expected_tag: str) -> Callable[..., None]: - """Create tagging hook for filtered modules.""" - - def tagging_hook(module: nn.Module, inputs: Any, outputs: Any) -> None: - # Record context using parent tag - current_tag = self._tag_stack[-1] if self._tag_stack else "" - - self._operation_context[module_name] = { - "tag": expected_tag, # Use the computed parent tag - "creates_hierarchy": False, - "parent_tag": current_tag, - "module_class": module.__class__.__name__, - } - - return tagging_hook - - def _remove_dynamic_hooks(self) -> None: - """Remove all registered hooks.""" - for hook in self._pre_hooks: - hook.remove() - for hook in self._post_hooks: - hook.remove() - - self._pre_hooks.clear() - self._post_hooks.clear() - - if self.verbose: - logger.info("Removed all dynamic hooks") - - def _apply_dynamic_tags_to_onnx(self, output_path: str) -> None: - """Process the ONNX model and create operation tag mappings. - - Note: We don't modify the ONNX file itself (which would break validation). - Instead, we build a mapping of operation names to hierarchy tags that - can be used for filtering and analysis. - """ - if not self._operation_context: - if self.verbose: - logger.info("No dynamic operation context captured") - return - - # Load the ONNX model for analysis - try: - onnx_model = load_onnx(output_path, validate=False) - except Exception as e: - logger.warning(f"Could not load ONNX model for analysis: {e}") - return - - # Create operation tag mapping based on captured context - for operation_count, node in enumerate(onnx_model.graph.node): - node_name = node.name or f"{node.op_type}_{operation_count}" - - # Find best matching context based on operation type - best_tag = self._find_best_tag_for_operation(node) - if best_tag: - self._onnx_operation_tags[node_name] = best_tag - - # Also store in operation_tags for metadata - if node_name not in self._operation_tags: - self._operation_tags[node_name] = [] - self._operation_tags[node_name].append(best_tag) - - # Update statistics - self._export_stats["tagged_operations"] = len(self._onnx_operation_tags) - - if self.verbose: - logger.info(f"Mapped {len(self._onnx_operation_tags)} operations to hierarchy tags") - - def _find_best_tag_for_operation(self, node: Any) -> str: - """Find the best hierarchy tag for an ONNX operation. - - Uses the operation name structure to match with module hierarchy. - ONNX operations often have names like: - - /embeddings/word_embeddings/Gather - - /encoder/layer.0/attention/self/query/MatMul - - /pooler/dense/Gemm - """ - node_name = node.name or f"{node.op_type}_{id(node)}" - - # Strategy 1: Match operation path with module paths - # Remove leading slash and operation type suffix - op_path = node_name.lstrip("/") - - # Remove operation type from the end (e.g., /Gather, /MatMul) - path_parts = op_path.split("/") - if path_parts and path_parts[-1] in [ - "Gather", - "MatMul", - "Add", - "LayerNormalization", - "Gemm", - "Tanh", - "Softmax", - "Div", - "Mul", - "Sub", - "Transpose", - "Reshape", - "Constant", - "Shape", - "Unsqueeze", - "Concat", - "Slice", - "Where", - "Cast", - "Expand", - "Equal", - "ConstantOfShape", - "Sqrt", - "Erf", - ]: - op_path = "/".join(path_parts[:-1]) - - # Try to find the best matching module based on path similarity - best_match: str | None = None - best_score = 0 - - for module_name, context in self._operation_context.items(): - if not context.get("tag"): - continue - - # Calculate match score based on common path components - module_path = module_name.lower().replace(".", "/") - op_path_lower = op_path.lower() - - # Check if operation path contains module path components - if module_path in op_path_lower: - score = len(module_path) - if score > best_score: - best_score = score - best_match = context["tag"] - - # Also check individual components - module_parts = module_path.split("/") - op_parts = op_path_lower.split("/") - common_parts = sum(1 for mp in module_parts if mp in op_parts) - if common_parts > best_score: - best_score = common_parts - best_match = context["tag"] - - # If we found a good match, use it - if best_match: - return best_match - - # Strategy 2: Use operation context if no path match - # This was the original approach - use as fallback - for context in reversed(list(self._operation_context.values())): - if context.get("tag"): - return str(context["tag"]) - - # Final fallback to root tag - return f"/{self._get_root_class_name()}" if self._module_hierarchy else "" - - def _get_root_class_name(self) -> str: - """Get the root module class name.""" - root_data = self._module_hierarchy.get("__module") - return root_data.get("class_name", "Model") if root_data else "Model" - - -def create_bert_tiny_exporter() -> UniversalHierarchyExporter: - """Create exporter configured for BERT-tiny following ground truth specifications. - - This follows the exact configuration from docs/BERT_TINY_GROUND_TRUTH.md - """ - return UniversalHierarchyExporter( - torch_nn_exceptions=["LayerNorm", "Embedding"], # From ground truth - verbose=True, - ) - - -def export_bert_tiny_with_validation() -> dict[str, Any]: - """Export BERT-tiny and validate against ground truth. - - This demonstrates the complete workflow following all requirements. - """ - from transformers import AutoModel, AutoTokenizer - - # Load model (CARDINAL RULE: NO HARDCODED LOGIC - this works with any HF model) - from ..constants import DEFAULT_TEST_MODEL # type: ignore[import-not-found] - - model_name = DEFAULT_TEST_MODEL - model = AutoModel.from_pretrained(model_name) - tokenizer = AutoTokenizer.from_pretrained(model_name) - - # Prepare inputs - text = "Hello world" - inputs = tokenizer( - text, return_tensors="pt", max_length=128, padding="max_length", truncation=True - ) - input_ids = inputs["input_ids"] - attention_mask = inputs["attention_mask"] - - # Create exporter - exporter = create_bert_tiny_exporter() - - # Export with hierarchy preservation - output_path = "temp/bert_tiny_universal_export.onnx" - Path("temp").mkdir(exist_ok=True) - - export_result = exporter.export( - model=model, - args=(input_ids, attention_mask), - output_path=output_path, - input_names=["input_ids", "attention_mask"], - output_names=["last_hidden_state"], - dynamic_axes={ - "input_ids": {0: "batch_size", 1: "sequence"}, - "attention_mask": {0: "batch_size", 1: "sequence"}, - "last_hidden_state": {0: "batch_size", 1: "sequence"}, - }, - opset_version=17, - do_constant_folding=True, - ) - - return { - "export_result": export_result, - "output_path": output_path, - "hierarchy_metadata": exporter.get_hierarchy_metadata(), - } - - -if __name__ == "__main__": - # Demonstrate the universal hierarchy exporter - print("🎯 Universal Hierarchy Exporter - BERT-tiny Demo") - print("=" * 60) - - result = export_bert_tiny_with_validation() - - print("✅ Export completed successfully!") - print(f"📁 Output: {result['output_path']}") - print(f"📊 Statistics: {result['export_result']}") - print(f"🏷️ Hierarchy metadata: {len(result['hierarchy_metadata']['module_hierarchy'])} modules")