From a0addabda04d41334fbb6807837597a58aa411fb Mon Sep 17 00:00:00 2001 From: Yunbo Zhang Date: Wed, 13 May 2026 19:49:07 +0800 Subject: [PATCH 1/6] [Feat] Support dataclass in magi_register_custom_op --- magi_compiler/_magi_register_custom_op.py | 971 ++++++++++-- magi_compiler/api.py | 98 +- tests/api_tests/test_register_custom_op.py | 1593 ++++++++++++++++++-- 3 files changed, 2417 insertions(+), 245 deletions(-) diff --git a/magi_compiler/_magi_register_custom_op.py b/magi_compiler/_magi_register_custom_op.py index 3f770ec..f50e668 100644 --- a/magi_compiler/_magi_register_custom_op.py +++ b/magi_compiler/_magi_register_custom_op.py @@ -1,4 +1,4 @@ -# Copyright (c) 2025 SandAI. All Rights Reserved. +# Copyright (c) 2026 SandAI. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,34 +12,626 @@ # See the License for the specific language governing permissions and # limitations under the License. +"""Magi custom-op registration: dataclass-aware wrapper around ``torch.library``. + +This module implements ``magi_register_custom_op`` -- a decorator that takes a +plain Python function (possibly with frozen-dataclass parameters, +``Literal[str, ...]`` / string-Enum annotations, or other Python-rich +signatures that ``torch.library.infer_schema`` cannot consume) and registers +it as a real custom op while letting the user keep calling it with their +original, ergonomic signature. + + +Part A. Registration-time pipeline -- the four slots +==================================================== + +When ``@magi_register_custom_op(...)`` is applied to a user function, up to +four named slots are produced. Each slot is a concrete callable object. + + slot 0 -- fn + The user's original function. Always present. + + slot 1 -- lowered_fn + A thin wrapper around ``fn`` whose ``__signature__`` / + ``__annotations__`` have been *lowered* (Literal/Enum -> str, + unsupported defaults scrubbed, dataclasses flattened into primitive + leaves) so that ``torch.library.infer_schema`` accepts it. + Skipped when ``fn``'s signature is already schema-compatible. + + slot 2 -- torch_registered_op + The ``OpOverload`` returned by ``torch.library.custom_op`` / + ``register_fake`` after registering whichever of ``fn`` / + ``lowered_fn`` reached this point. Always present. + + slot 3 -- magi_exposed_op + A magi-level Python wrapper around ``torch_registered_op`` that + preserves the user's ORIGINAL (dataclass-bearing) calling + convention. At call time it flattens incoming args via the static + ``param_mapping_tree`` and dispatches into slot 2. Only created + on the dataclass-flatten path. + +The naming is a deliberate dual: ``torch_registered_op`` is *registered +into* torch.library's dispatcher; ``magi_exposed_op`` is *exposed out of* +magi to user code. + + +Part B. Runtime paths -- the three pipelines +============================================ + +Three pipelines are possible; the decorator returns whichever object sits +at the end of the path: + + 1. simple fn -> torch_registered_op + Returned: ``torch._ops.OpOverload`` (slot 2). + Runtime: zero magi-level overhead -- straight into torch.library's + dispatcher. + + 2. sig-only-rewrite fn -> lowered_fn -> torch_registered_op + Returned: ``torch._ops.OpOverload`` (slot 2). + Runtime: same as simple -- ``lowered_fn`` is a transparent + forwarding shim (the rewrite is registration-time only). + + 3. dataclass-flatten fn -> lowered_fn -> torch_registered_op + -> magi_exposed_op + Returned: a Python callable carrying the + ``_magi_torch_registered_op`` attribute (slot 3). + Runtime forward (per call): + user code calls magi_exposed_op(x, cfg=...) + -> _flatten_call_args (original kwargs -> flat tuple) + -> _flatten_value_into (DFS over param_mapping_tree) + -> torch_registered_op(*flat) (slot 2 -- enters dispatcher) + -> lowered_fn(*flat) (slot 1 -- still in lowered shape) + -> _reassemble_kwargs (flat tuple -> original kwargs) + -> _build_value_from_node (rebuilds dataclass instances) + -> fn(**original_kwargs) (slot 0 -- user code finally sees + its original dataclass-bearing + signature) + Runtime backward (when backward_fn is supplied): + autograd calls _bridged_backward(ctx, *grads) + -> user_backward(ctx, *grads) (returns one grad per ORIGINAL + input, possibly a dataclass-shaped + grad object) + -> _flatten_grads (original grads -> flat grads) + -> _flatten_grad_into (DFS over param_mapping_tree) + +You can tell at runtime which pipeline an op went through by inspecting +the decorator's return value: an ``OpOverload`` means simple/sig-rewrite; +a Python callable carrying ``_magi_torch_registered_op`` means +dataclass-flatten. + + +File layout +=========== + + -- registration-time helpers (executed once) -- + 1. Validate the user's fn signature + 2. Resolve types & sanitise defaults for infer_schema + 3. Build & query the param mapping tree (used by sec 4 and sec 7) + 4. Lower fn's signature (produces slot 1) + 5. Synthesise the meta/fake function (input to slot 2) + 6. Register the op (produces slot 2) + + -- runtime helpers (executed on every call) -- + 7. Runtime bridge: flatten / unflatten on every call + + -- main pipeline -- + 8. The decorator: orchestrates sec 1-6 and builds the runtime + closures from sec 7 (produces slot 3 on the flatten path) +""" + +import dataclasses +import functools import inspect -from typing import Callable, get_args, get_origin +from typing import Any, Callable, get_args, get_origin import torch +import torch.utils._pytree as pytree from .config import get_compile_config +from .utils.logger import magi_logger + +# ============================================================================== +# 1. Validate the user's fn signature +# ------------------------------------------------------------------------------ +# Predicate + assert helpers that reject `fn` signatures torch.library cannot +# consume, each raising a clear `TypeError` instead of the opaque error that +# would otherwise surface deep inside `infer_schema`. Called from +# `_lower_op_signature` (sec 4) and `_build_dataclass_sub_mapping_tree` (sec 3). +# ============================================================================== + + +def _is_frozen_dataclass(tp) -> bool: + """Return True if ``tp`` is a frozen dataclass type.""" + return ( + isinstance(tp, type) + and dataclasses.is_dataclass(tp) + and getattr(tp, "__dataclass_params__", None) is not None + and tp.__dataclass_params__.frozen + ) + + +def _assert_not_unsupported_container(tp, *, where: str) -> None: + """Reject ``tuple[...]`` / ``dict[...]`` annotations (schema only models ``list``).""" + origin = get_origin(tp) + if origin is tuple: + raise TypeError( + f"@magi_register_custom_op: {where} has tuple annotation {tp!r}; " + f"use ``list[...]`` or split into separate fields." + ) + if origin is dict: + raise TypeError( + f"@magi_register_custom_op: {where} has dict-typed annotation {tp!r}; " f"promote the values to explicit fields." + ) + + +def _assert_not_dataclass_return(tp, *, fn_name: str) -> None: + """Reject dataclass return annotations (schema only returns Tensor / tuple / list / None).""" + if isinstance(tp, type) and dataclasses.is_dataclass(tp): + raise TypeError( + f"@magi_register_custom_op: {fn_name!r} returns dataclass " + f"{tp.__name__!r}; only Tensor / tuple[Tensor, ...] / list[Tensor] " + f"are supported -- destructure into a tuple at the op boundary." + ) + + +def _assert_not_mutable_dataclass(tp, *, where: str) -> None: + """Reject non-frozen dataclasses (schema needs hashable, stable inputs).""" + if ( + isinstance(tp, type) + and dataclasses.is_dataclass(tp) + and getattr(tp, "__dataclass_params__", None) is not None + and not tp.__dataclass_params__.frozen + ): + raise TypeError( + f"@magi_register_custom_op: {where} has mutable dataclass " + f"{tp.__name__!r}; add ``frozen=True`` to {tp.__name__}." + ) + + +def _assert_has_annotation(annotation, *, where: str) -> None: + """Require an annotation on every parameter / field / return value (needed + to recognise dataclasses and to feed ``infer_schema``).""" + if annotation is inspect.Parameter.empty or annotation is inspect.Signature.empty: + raise TypeError( + f"@magi_register_custom_op: {where} has no type annotation " f"(e.g. ``x: torch.Tensor`` or ``cfg: MyFrozenCfg``)." + ) + + +def _assert_no_var_args(param: inspect.Parameter, *, fn_name: str) -> None: + """Reject ``*args`` / ``**kwargs`` (op schemas are positional-or-keyword only).""" + if param.kind is inspect.Parameter.VAR_POSITIONAL: + raise TypeError( + f"@magi_register_custom_op: {fn_name!r} declares ``*{param.name}``; " + f"variadics aren't supported -- replace with explicit annotated parameters." + ) + if param.kind is inspect.Parameter.VAR_KEYWORD: + raise TypeError( + f"@magi_register_custom_op: {fn_name!r} declares ``**{param.name}``; " + f"variadics aren't supported -- replace with explicit annotated parameters." + ) + + +def _assert_resolved_field_type(f_type, *, where: str) -> None: + """Reject unresolved string annotations -- typically a local class combined + with stringified annotations that ``get_type_hints`` could not eval.""" + if isinstance(f_type, str): + raise TypeError( + f"@magi_register_custom_op: {where} has unresolved string " + f"annotation {f_type!r}; move the type to module scope so " + f"``get_type_hints`` can resolve it." + ) + + +# ============================================================================== +# 2. Resolve types & sanitise defaults for infer_schema +# ------------------------------------------------------------------------------ +# Resolve stringified annotations to real types, downgrade Literal/string-Enum +# to `str`, and scrub defaults that `infer_schema` cannot render. Called by +# `_lower_op_signature` (sec 4) and `_build_dataclass_sub_mapping_tree` (sec 3). +# ============================================================================== + + +def _resolve_annotations(fn: Callable) -> dict[str, Any]: + """Return ``fn``'s annotations as real types, resolving stringified ones. + + Falls back to per-annotation eval against ``globals + closure nonlocals`` + when ``get_type_hints`` can't resolve atomically (typical for functions + defined inside another function whose annotations reference enclosing + names). + """ + import typing + try: + return typing.get_type_hints(fn) + except Exception: + pass + + # Build an eval namespace from module globals + closure nonlocals. + # ``__globals__`` covers the common case; closure vars from + # ``getclosurevars`` cover annotations that name enclosing locals. + fn_globals = getattr(fn, "__globals__", {}) or {} + namespace: dict[str, Any] = dict(fn_globals) + try: + cv = inspect.getclosurevars(fn) + namespace.update(cv.builtins) + namespace.update(cv.globals) + namespace.update(cv.nonlocals) + except Exception as e: + magi_logger.debug( + "inspect.getclosurevars(%s) failed: %s; falling back to module globals only", + getattr(fn, "__qualname__", fn), + e, + rank="all", + ) + + anns: dict[str, Any] = {} + raw = getattr(fn, "__annotations__", {}) or {} + for k, v in raw.items(): + if isinstance(v, str): + try: + anns[k] = eval(v, namespace, None) + except Exception: + anns[k] = v + else: + anns[k] = v + return anns -def _get_num_outputs_from_return_annotation(fn: Callable) -> int: + +def _resolve_dataclass_field_types(cls: type) -> dict[str, Any]: + """Return ``cls``'s field name -> resolved type (best-effort).""" + import typing as _typing + + try: + return _typing.get_type_hints(cls) + except Exception: + return {f.name: f.type for f in dataclasses.fields(cls)} + + +def _maybe_downgrade_literal_or_enum(annotation, *, where: str): + """Collapse ``Literal[str, ...]`` and string-Enum annotations to plain ``str``. + + Lossless because the op body still receives the original string value. + Mixed/numeric Literals and non-string Enums raise (no safe downgrade). + """ + import enum + import typing + + _LITERAL_STRING_DOWNGRADE_HINT = ( + "Use ``str`` and validate the value inside the op body, e.g. " "``assert mode in ('a', 'b')``." + ) + origin = get_origin(annotation) + if origin is typing.Literal: + choices = get_args(annotation) + if choices and all(isinstance(c, str) for c in choices): + return str + raise TypeError( + f"@magi_register_custom_op: {where} has Literal {annotation!r}; " + f"only ``Literal[str, ...]`` is auto-downgraded. " + f"{_LITERAL_STRING_DOWNGRADE_HINT}" + ) + if isinstance(annotation, type) and issubclass(annotation, enum.Enum): + members = list(annotation) + if members and all(isinstance(m.value, str) for m in members): + return str + raise TypeError( + f"@magi_register_custom_op: {where} has non-string Enum " + f"{annotation.__name__!r}. {_LITERAL_STRING_DOWNGRADE_HINT}" + ) + return annotation + + +_SCHEMA_DEFAULT_TYPES: tuple[type, ...] = (int, float, bool, str, torch.device, torch.dtype) + + +def _schema_compatible_param_default(default: Any) -> Any: + """Scrub a top-level parameter default that ``infer_schema`` cannot render. + + Same rules as :func:`_schema_compatible_default`, but for raw values + rather than ``dataclasses.Field``. """ - Get the number of outputs from the function's return type annotation. + if default is inspect.Parameter.empty: + return inspect.Parameter.empty + if default is None or isinstance(default, _SCHEMA_DEFAULT_TYPES): + return default + return inspect.Parameter.empty + + +def _schema_compatible_default(f: "dataclasses.Field") -> Any: + """Lowered default for dataclass field ``f``: keep ``None`` / int / float / + bool / str / torch.device / torch.dtype; drop everything else (the user- + constructed dataclass instance still carries the real default at runtime).""" + if f.default is not dataclasses.MISSING: + d = f.default + if d is None or isinstance(d, _SCHEMA_DEFAULT_TYPES): + return d + return inspect.Parameter.empty + if f.default_factory is not dataclasses.MISSING: # type: ignore[misc] + try: + d = f.default_factory() + except Exception: + return inspect.Parameter.empty + if d is None or isinstance(d, _SCHEMA_DEFAULT_TYPES): + return d + return inspect.Parameter.empty + return inspect.Parameter.empty + + +# ============================================================================== +# 3. Build & query the param mapping tree +# ------------------------------------------------------------------------------ +# The `param_mapping_tree` is the single source of truth bridging the user's +# (possibly nested-dataclass) signature and the lowered primitive signature. +# Built once at registration time and consumed twice afterwards: by +# `_expand_mutates_args` (statically, sec 6) and by the runtime bridge (sec 7). +# ============================================================================== + +_DATACLASS_PYTREE_REGISTERED: set[type] = set() + + +def _register_dataclass_pytree(cls: type) -> None: + """Register ``cls`` as a pytree node (idempotent) so Dynamo / AOTAutograd + can flatten and unflatten dataclass instances during tracing.""" + if cls in _DATACLASS_PYTREE_REGISTERED: + return + + field_names = tuple(f.name for f in dataclasses.fields(cls)) + + def _flatten(obj): + return [getattr(obj, n) for n in field_names], field_names + + def _unflatten(values, ctx): + return cls(**dict(zip(ctx, values))) + + try: + pytree.register_pytree_node(cls, _flatten, _unflatten) + except ValueError: + # Already registered elsewhere (e.g. user code). + pass + _DATACLASS_PYTREE_REGISTERED.add(cls) + + +def _build_dataclass_sub_mapping_tree(cls: type, attr_name: str, flat_prefix: str) -> tuple[tuple, list[inspect.Parameter]]: + """Recursively expand a frozen-dataclass type into one ``param_mapping_tree`` + subtree plus its flat list of leaf ``inspect.Parameter`` objects (DFS order). + + ``attr_name`` is the field name on the parent dataclass (or the parameter + name on ``fn`` for a top-level dataclass arg). ``flat_prefix`` builds the + leaf parameter name; e.g. ``cfg: OuterCfg(inner: InnerCfg(val: float))`` + becomes a lowered leaf parameter ``cfg__inner__val``. + """ + _register_dataclass_pytree(cls) + + field_types = _resolve_dataclass_field_types(cls) + children: list[tuple] = [] + lowered_params: list[inspect.Parameter] = [] + + for f in dataclasses.fields(cls): + f_type = field_types.get(f.name, f.type) + child_flat_name = f"{flat_prefix}__{f.name}" + _assert_has_annotation(f_type, where=f"field {cls.__name__}.{f.name}") + _assert_resolved_field_type(f_type, where=f"field {cls.__name__}.{f.name}") + _assert_not_mutable_dataclass(f_type, where=f"field {cls.__name__}.{f.name}") + if _is_frozen_dataclass(f_type): + sub_node, sub_params = _build_dataclass_sub_mapping_tree(f_type, attr_name=f.name, flat_prefix=child_flat_name) + children.append(sub_node) + lowered_params.extend(sub_params) + else: + _assert_not_unsupported_container(f_type, where=f"field {cls.__name__}.{f.name}") + f_type = _maybe_downgrade_literal_or_enum(f_type, where=f"field {cls.__name__}.{f.name}") + children.append(("primitive", f.name, child_flat_name, None)) + lowered_params.append( + inspect.Parameter( + child_flat_name, + # POSITIONAL_OR_KEYWORD: torch.library.custom_op does not yet + # support kwarg-only Tensor arguments. + inspect.Parameter.POSITIONAL_OR_KEYWORD, + default=_schema_compatible_default(f), + annotation=f_type, + ) + ) + + return ("dataclass", attr_name, cls, children), lowered_params + + +def _count_leaves(node: tuple) -> int: + """Number of lowered parameter slots a ``param_mapping_tree`` ``node`` occupies.""" + if node[0] == "primitive": + return 1 + return sum(_count_leaves(c) for c in node[3]) + + +def _collect_tensor_leaf_lowered_names(node: tuple) -> list[str]: + """Lowered names of every leaf under ``node``. Used to expand a dataclass + parameter referenced in ``mutates_args`` (torch.library does its own + Tensor-type validation, so over-specifying non-Tensor leaves is fine).""" + if node[0] == "primitive": + _, _attr, lowered_name, _ = node + return [lowered_name] + out: list[str] = [] + for child in node[3]: + out.extend(_collect_tensor_leaf_lowered_names(child)) + return out + + +def _expand_mutates_args(mutates_args: tuple[str, ...] | list[str], param_mapping_tree: list[tuple]) -> tuple[str, ...]: + """Translate ``mutates_args`` from the original parameter space to the + lowered space: top-level dataclass names expand to all their leaves; + primitive top-level names and already-lowered names pass through; unknown + names raise ``ValueError`` listing valid choices.""" + if not mutates_args: + return tuple(mutates_args) + by_attr: dict[str, tuple] = {node[1]: node for node in param_mapping_tree} + valid_lowered: set[str] = set() + for node in param_mapping_tree: + valid_lowered.update(_collect_tensor_leaf_lowered_names(node)) + out: list[str] = [] + for name in mutates_args: + if name in by_attr: + node = by_attr[name] + if node[0] == "primitive": + out.append(node[2]) + else: + out.extend(_collect_tensor_leaf_lowered_names(node)) + elif name in valid_lowered: + out.append(name) + else: + raise ValueError( + f"@magi_register_custom_op: mutates_args entry {name!r} does " + f"not match any parameter. Valid: {sorted(by_attr.keys())} " + f"(or lowered: {sorted(valid_lowered)})." + ) + seen: set[str] = set() + deduped: list[str] = [] + for n in out: + if n not in seen: + seen.add(n) + deduped.append(n) + return tuple(deduped) + + +# ============================================================================== +# 4. Lower fn's signature (produces slot 1) +# ------------------------------------------------------------------------------ +# Produces slot 1 (`lowered_fn`) in two stages: +# data: `_lower_op_signature` walks `fn`'s parameters once and emits +# `(original_sig, lowered_sig, param_mapping_tree)`, calling into +# sec 1 (validate), sec 2 (resolve/sanitise), sec 3 (tree build). +# object: `_make_lowered_signature_wrapper` stamps the lowered signature +# onto a forwarding wrapper of `fn`. +# `_signatures_differ` lets the decorator (sec 6) skip the wrapper entirely +# when lowering was a no-op (zero-overhead path). +# ============================================================================== + + +def _signatures_differ(original: inspect.Signature, lowered: inspect.Signature) -> bool: + """True iff ``lowered`` differs from ``original`` on parameter names, + annotations, defaults, kinds, or return annotation. The decorator uses + this to skip slot 1 entirely when lowering was a no-op (zero-overhead path).""" + return original != lowered + + +def _apply_lowered_signature_metadata(wrapper: Callable, lowered_sig: inspect.Signature) -> None: + """In-place: stamp ``wrapper`` with ``lowered_sig`` as its ``__signature__`` + / ``__annotations__``, and strip ``__wrapped__`` so ``inspect.signature`` + cannot fall back to the original (un-lowered) signature on ``fn``.""" + wrapper.__signature__ = lowered_sig + lowered_annotations = { + p.name: p.annotation for p in lowered_sig.parameters.values() if p.annotation is not inspect.Parameter.empty + } + if lowered_sig.return_annotation is not inspect.Signature.empty: + lowered_annotations["return"] = lowered_sig.return_annotation + wrapper.__annotations__ = lowered_annotations + # ``functools.wraps`` sets ``__wrapped__`` -> ``fn``; strip it so + # introspection cannot bypass our ``__signature__``. + try: + del wrapper.__wrapped__ + except AttributeError: + pass + + +def _make_lowered_signature_wrapper(fn: Callable, lowered_sig: inspect.Signature) -> Callable: + """Forwarding wrapper around ``fn`` carrying ``lowered_sig`` as metadata. + Used on the no-flattening path so ``infer_schema`` sees the cleaned-up + signature instead of ``fn``'s original annotations.""" + + @functools.wraps(fn) + def _wrapped(*args, **kwargs): + return fn(*args, **kwargs) + + _apply_lowered_signature_metadata(_wrapped, lowered_sig) + return _wrapped + + +def _lower_op_signature(fn: Callable): + """Lower ``fn``'s signature into a form ``torch.library.infer_schema`` accepts. + + "Lower" is used in the compiler sense (high-level -> low-level): we walk + ``fn``'s parameters once and do six things at the same time -- they all + need the same resolved annotations and the same iteration: + + 1. VALIDATE -- reject variadics, missing annotations, mutable dataclasses, + unsupported containers, dataclass returns (sec 1). + 2. RESOLVE -- turn stringified annotations into real types via + ``_resolve_annotations``, so dataclass detection works. + 3. NORMALIZE -- collapse parameter kinds to POSITIONAL_OR_KEYWORD, + downgrade Literal/Enum to ``str``, scrub unsupported defaults. + 4. FLATTEN -- expand each frozen-dataclass parameter (recursively) into + its primitive leaves via ``_build_dataclass_sub_mapping_tree``. + 5. PYTREE -- side effect of step 4: register every dataclass as a pytree + node so Dynamo / AOTAutograd can trace through it. + 6. EMIT -- assemble ``(original_sig, lowered_sig, param_mapping_tree)``. + + A single pass is intentional: splitting concerns would force re-resolving + annotations and threading accumulator state. When the input is already + schema-compatible the lowered signature is bit-identical to the original, + and the caller's ``_signatures_differ`` check restores the zero-overhead path. Returns: - - 1 if the return type is a single Tensor - - N if the return type is tuple[Tensor, Tensor, ...] with N elements - - 1 if no annotation or unrecognized annotation (default to single output) + original_sig (inspect.Signature): the user's un-flattened signature. + lowered_sig (inspect.Signature): what ``infer_schema`` will see. + param_mapping_tree (list[tuple]): the bridge between the two; a list + of root nodes (one per original parameter), each of which is: + * ``("primitive", attr_name, lowered_name, None)``, or + * ``("dataclass", attr_name, cls, [child_nodes...])``. + ``attr_name`` is the parameter name at top level / field name + deeper down. The same tree drives both runtime translation + directions (sec 7). """ + original_sig = inspect.signature(fn) + resolved = _resolve_annotations(fn) + lowered_params: list[inspect.Parameter] = [] + param_mapping_tree: list[tuple] = [] + + for name, param in original_sig.parameters.items(): + _assert_no_var_args(param, fn_name=fn.__name__) + annotation = resolved.get(name, param.annotation) + _assert_has_annotation(annotation, where=f"parameter {name!r} of {fn.__name__!r}") + _assert_not_mutable_dataclass(annotation, where=f"parameter {name!r}") + if _is_frozen_dataclass(annotation): + node, sub_params = _build_dataclass_sub_mapping_tree(annotation, attr_name=name, flat_prefix=name) + param_mapping_tree.append(node) + lowered_params.extend(sub_params) + else: + _assert_not_unsupported_container(annotation, where=f"parameter {name!r}") + annotation = _maybe_downgrade_literal_or_enum(annotation, where=f"parameter {name!r}") + new_param = param.replace( + kind=inspect.Parameter.POSITIONAL_OR_KEYWORD, + annotation=annotation, + default=_schema_compatible_param_default(param.default), + ) + lowered_params.append(new_param) + param_mapping_tree.append(("primitive", name, name, None)) + + return_annotation = resolved.get("return", original_sig.return_annotation) + _assert_has_annotation(return_annotation, where=f"return value of {fn.__name__!r}") + _assert_not_dataclass_return(return_annotation, fn_name=fn.__name__) + lowered_sig = inspect.Signature(lowered_params, return_annotation=return_annotation) + return original_sig, lowered_sig, param_mapping_tree + + +# ============================================================================== +# 5. Synthesise the meta/fake function (input to slot 2) +# ------------------------------------------------------------------------------ +# Constructors for the meta ("fake") function torch.library uses for shape +# propagation during tracing. Fallbacks: identity meta when the user passes +# no `infer_output_meta_fn`; param-name-echoing meta when they pass a +# `list[str]`. The result is handed to `register_fake` by sec 6. +# ============================================================================== + + +def _get_num_outputs_from_return_annotation(fn: Callable) -> int: + """Output count from ``fn``'s return annotation: ``N`` for + ``tuple[T1, ..., TN]``, else ``1`` (default / unrecognized).""" sig = inspect.signature(fn) return_annotation = sig.return_annotation if return_annotation is inspect.Parameter.empty: return 1 - # Check if it's a tuple type (e.g., tuple[Tensor, Tensor]) origin = get_origin(return_annotation) if origin is tuple: args = get_args(return_annotation) - # Filter out ellipsis (for variable-length tuples like tuple[Tensor, ...]) + # tuple[T, ...] (variable-length) collapses to a single output. if args and args[-1] is not ...: return len(args) return 1 @@ -47,56 +639,17 @@ def _get_num_outputs_from_return_annotation(fn: Callable) -> int: return 1 -def _generate_op_name(fn: Callable) -> str: - """ - Generate a unique operator name from function's name and source file. - - Format: {filename_stem}::{function_name} - Example: my_module.py with function `my_op` -> "my_module::my_op" - - Falls back to "magi_custom::{function_name}" if source file cannot be determined. - """ - import re - from pathlib import Path - - func_name = fn.__name__ - - # Get the source file path - try: - source_file = inspect.getfile(fn) - # Extract the file stem (without extension) as namespace - namespace = Path(source_file).stem - # Clean up namespace: replace invalid characters with underscores - namespace = re.sub(r"[^a-zA-Z0-9_]", "_", namespace) - except (TypeError, OSError): - # If we can't get the source file, use a default namespace - namespace = "magi_custom" - - return f"{namespace}::{func_name}" - - def _create_identity_meta_fn(fn: Callable) -> Callable: - """ - Create a default identity meta function for the given function. - - The generated meta function: - - Determines number of outputs from return type annotation - - Uses first N tensor inputs to infer output metadata - - Returns torch.empty_like() tensors with matching shape/dtype/device - - Raises ValueError if not enough tensor inputs are provided. - """ + """Default meta/fake: copy shape/dtype/device of the first N tensor inputs + to N outputs (N from the return annotation).""" num_outputs = _get_num_outputs_from_return_annotation(fn) sig = inspect.signature(fn) - # Get parameter names, excluding 'self' if present param_names = [name for name in sig.parameters.keys() if name != "self"] def identity_meta_fn(*args, **kwargs): - # Bind arguments to get a mapping of param_name -> value bound = sig.bind(*args, **kwargs) bound.apply_defaults() - # Collect the first `num_outputs` tensor arguments tensor_args = [] for name in param_names: arg = bound.arguments.get(name) @@ -107,12 +660,11 @@ def identity_meta_fn(*args, **kwargs): if len(tensor_args) < num_outputs: raise ValueError( - f"identity_meta_fn requires at least {num_outputs} tensor inputs to match " - f"{num_outputs} outputs, but only found {len(tensor_args)} tensor inputs. " - f"Please provide a custom infer_output_meta_fn." + f"@magi_register_custom_op: identity_meta_fn needs {num_outputs} " + f"tensor input(s) but found {len(tensor_args)}; provide a custom " + f"infer_output_meta_fn." ) - # Return outputs with same metadata as the first N inputs if num_outputs == 1: return torch.empty_like(tensor_args[0]) return tuple(torch.empty_like(t) for t in tensor_args[:num_outputs]) @@ -121,43 +673,32 @@ def identity_meta_fn(*args, **kwargs): def _create_meta_fn_from_param_names(fn: Callable, param_names: list[str]) -> Callable: - """ - Create a meta function that returns torch.empty_like() for each specified parameter. - - Args: - fn: Target function to inspect - param_names: List of parameter names to use as output templates - - Returns: - Meta function that maps specified input params to output tensors - - Raises: - ValueError: If parameter name doesn't exist or isn't a Tensor - """ + """Meta/fake that echoes the listed tensor parameters as outputs + (``torch.empty_like`` each). Raises ``ValueError`` for unknown or + non-Tensor names.""" sig = inspect.signature(fn) def meta_fn(*args, **kwargs): - # Bind arguments to get a mapping of param_name -> value bound = sig.bind(*args, **kwargs) bound.apply_defaults() - # Collect tensors for each specified parameter name tensor_outputs = [] for name in param_names: if name not in bound.arguments: raise ValueError( - f"Parameter '{name}' not found in function signature. " - f"Available parameters: {list(bound.arguments.keys())}" + f"@magi_register_custom_op: infer_output_meta_fn references " + f"unknown parameter {name!r}; available: " + f"{list(bound.arguments.keys())}." ) arg = bound.arguments[name] if not isinstance(arg, torch.Tensor): raise ValueError( - f"Parameter '{name}' is not a Tensor (got {type(arg).__name__}). " - f"infer_output_meta_fn list should only contain tensor parameter names." + f"@magi_register_custom_op: infer_output_meta_fn entry " + f"{name!r} is not a Tensor (got {type(arg).__name__}); " + f"list must contain only tensor parameter names." ) tensor_outputs.append(torch.empty_like(arg)) - # Return single tensor or tuple based on number of outputs if len(tensor_outputs) == 1: return tensor_outputs[0] return tuple(tensor_outputs) @@ -165,6 +706,154 @@ def meta_fn(*args, **kwargs): return meta_fn +# ============================================================================== +# 6. Register the op (produces slot 2) +# ------------------------------------------------------------------------------ +# `_register_torch_op` calls `custom_op` + `register_fake`, yielding slot 2 +# (`torch_registered_op`). `_generate_op_name` derives a default op name +# from the user's `fn` when one isn't supplied. The orchestrator that +# stitches these together with sec 1-5 (and the runtime closures from +# sec 7) lives in sec 8. +# ============================================================================== + + +def _generate_op_name(fn: Callable) -> str: + """Op name ``{filename_stem}::{fn.__name__}``, falling back to + ``magi_custom::`` if the source file isn't available.""" + import re + from pathlib import Path + + func_name = fn.__name__ + try: + source_file = inspect.getfile(fn) + namespace = Path(source_file).stem + namespace = re.sub(r"[^a-zA-Z0-9_]", "_", namespace) + except (TypeError, OSError): + namespace = "magi_custom" + + return f"{namespace}::{func_name}" + + +def _register_torch_op(op_name: str, fn: Callable, mutates_args: tuple[str, ...], meta_fn: Callable): + """``custom_op`` + ``register_fake``; returns slot 2 (``torch_registered_op``).""" + torch_registered_op = torch.library.custom_op(op_name, mutates_args=mutates_args)(fn) + torch.library.register_fake(op_name)(meta_fn) + return torch_registered_op + + +# ============================================================================== +# 7. Runtime bridge: flatten / unflatten on every call +# ------------------------------------------------------------------------------ +# Executed on every call to slot 1 (`lowered_fn`) or slot 3 (`magi_exposed_op`), +# consuming the static `param_mapping_tree` from sec 3 to translate between +# the original (dataclass) and lowered (primitive) call shapes. See Part B +# of the module docstring for the full call-stack picture. +# original -> lowered: `_flatten_value_into`, `_flatten_call_args` +# lowered -> original: `_build_value_from_node`, `_reassemble_kwargs` +# grad bridge: `_flatten_grad_into`, `_flatten_grads` +# ============================================================================== + + +def _build_value_from_node(node: tuple, lowered_kwargs: dict): + """``lowered_kwargs`` -> one original-shaped value (recursive).""" + kind = node[0] + if kind == "primitive": + _, _attr, lowered_name, _ = node + return lowered_kwargs[lowered_name] + _, _attr, cls, children = node + init_kwargs: dict[str, Any] = {} + for child in children: + field_name = child[1] + init_kwargs[field_name] = _build_value_from_node(child, lowered_kwargs) + return cls(**init_kwargs) + + +def _reassemble_kwargs(param_mapping_tree: list[tuple], lowered_kwargs: dict) -> dict: + """``lowered_kwargs`` -> original kwargs (the ``lowered -> original`` walk).""" + out: dict[str, Any] = {} + for node in param_mapping_tree: + out[node[1]] = _build_value_from_node(node, lowered_kwargs) + return out + + +def _flatten_value_into(node: tuple, value: Any, out: list) -> None: + """Append leaves of ``value`` to ``out`` in DFS order (no isinstance check + on ``cls`` -- duck-typed via ``getattr`` so mocks / SimpleNamespace work).""" + kind = node[0] + if kind == "primitive": + out.append(value) + return + _, _attr, cls, children = node + for child in children: + field_name = child[1] + _flatten_value_into(child, getattr(value, field_name), out) + + +def _flatten_call_args(param_mapping_tree: list[tuple], original_sig: inspect.Signature, args: tuple, kwargs: dict) -> list: + """User-side call -> flat positional list matching the lowered signature + (the ``original -> lowered`` walk).""" + bound = original_sig.bind(*args, **kwargs) + bound.apply_defaults() + flat: list = [] + for node in param_mapping_tree: + _flatten_value_into(node, bound.arguments[node[1]], flat) + return flat + + +def _flatten_grad_into(node: tuple, grad: Any, out: list) -> None: + """Spread a user-returned grad across the lowered slots of one original input. + + ``primitive`` -> append ``grad`` as-is. ``dataclass`` -> if ``grad`` is + ``None`` fill every leaf with ``None`` (the common whole-dataclass-not- + differentiable case); otherwise descend with ``dict``-aware lookup so + users can return dict / SimpleNamespace / dataclass-shaped objects. + """ + if node[0] == "primitive": + out.append(grad) + return + _, _attr, _cls, children = node + if grad is None: + for child in children: + for _ in range(_count_leaves(child)): + out.append(None) + return + is_mapping = isinstance(grad, dict) + for child in children: + field_name = child[1] + if is_mapping: + sub = grad.get(field_name) + else: + sub = getattr(grad, field_name, None) + _flatten_grad_into(child, sub, out) + + +def _flatten_grads(param_mapping_tree: list[tuple], original_grads: tuple | list) -> list: + """Grads keyed by original-parameter order -> grads keyed by lowered order.""" + if len(original_grads) != len(param_mapping_tree): + raise ValueError( + f"@magi_register_custom_op: backward_fn returned {len(original_grads)} " + f"grad(s) but the function has {len(param_mapping_tree)} input(s); " + f"return one grad per ORIGINAL parameter (``None`` for non-differentiable " + f"or whole-dataclass args)." + ) + flat: list = [] + for node, g in zip(param_mapping_tree, original_grads): + _flatten_grad_into(node, g, flat) + return flat + + +# ============================================================================== +# 8. The decorator: main pipeline (produces slot 3) +# ------------------------------------------------------------------------------ +# The single public entry point. Its inner `decorator` closure orchestrates the +# full 4-slot pipeline (see module docstring): it calls sec 4 to lower the user's +# signature (slot 1), sec 5 to synthesise the meta function, sec 6 to register +# the op with torch.library (slot 2), and -- on the dataclass-flatten path -- +# additionally builds the user-facing wrapper (slot 3) plus the runtime closures +# that drive sec 7 on each call. +# ============================================================================== + + def _magi_register_custom_op_impl( name: str | None = None, mutates_args: tuple[str, ...] = (), @@ -175,30 +864,132 @@ def _magi_register_custom_op_impl( is_subgraph_boundary: bool = False, ): def decorator(fn: Callable) -> Callable: - # Auto-generate name if not provided + # See the module docstring for the 4-slot pipeline / 3-runtime-path + # picture; the body below just walks slot 0 -> 1 -> 2 (-> 3 if needed). + op_name = name if name is not None else _generate_op_name(fn) if is_compute_sensitive: get_compile_config().recompute_config.custom_compute_sensitive_ops.append(op_name) if is_subgraph_boundary: get_compile_config().splitting_ops.append(op_name) - # Step 1: Register the custom op with torch.library.custom_op - registered_op = torch.library.custom_op(op_name, mutates_args=mutates_args)(fn) + # Dataclass parameters are the only thing forcing slot 3; other lowering + # (Literal/Enum/default scrub) is handled by slot 1 alone at zero + # per-call cost. + original_sig, lowered_sig, param_mapping_tree = _lower_op_signature(fn) + needs_flattening = any(kind == "dataclass" for kind, *_ in param_mapping_tree) + + if not needs_flattening: + # ----- No-flattening path: fn -> [lowered_fn?] -> torch_registered_op ----- + # Step 1 (slot 1): only when the lowering actually rewrote the + # signature -- otherwise register ``fn`` directly (zero-overhead). + if _signatures_differ(original_sig, lowered_sig): + lowered_fn = _make_lowered_signature_wrapper(fn, lowered_sig) + fn_to_register = lowered_fn + else: + fn_to_register = fn + + # Step 2: meta/fake function. + if infer_output_meta_fn is None: + meta_fn = _create_identity_meta_fn(fn_to_register) + elif isinstance(infer_output_meta_fn, list): + meta_fn = _create_meta_fn_from_param_names(fn_to_register, infer_output_meta_fn) + else: + meta_fn = infer_output_meta_fn + + # Step 3 (slot 2): custom_op + register_fake. + torch_registered_op = _register_torch_op( + op_name=op_name, fn=fn_to_register, mutates_args=mutates_args, meta_fn=meta_fn + ) - # Step 2: Register the output meta inference function - # Determine meta_fn based on the type of infer_output_meta_fn - if infer_output_meta_fn is None: - meta_fn = _create_identity_meta_fn(fn) - elif isinstance(infer_output_meta_fn, list): - meta_fn = _create_meta_fn_from_param_names(fn, infer_output_meta_fn) - else: - meta_fn = infer_output_meta_fn - torch.library.register_fake(op_name)(meta_fn) + # Step 4: autograd. + if backward_fn is not None: + torch_registered_op.register_autograd(backward_fn, setup_context=setup_context_fn) - # Step 3: Register autograd if backward_fn is provided - if backward_fn is not None: - registered_op.register_autograd(backward_fn, setup_context=setup_context_fn) + # No slot 3 needed: the user's calling convention already matches + # the lowered one, so ``torch_registered_op`` is itself returned. + return torch_registered_op + + else: + # ----- Flattening path: fn -> lowered_fn -> torch_registered_op -> magi_exposed_op ----- + # Step 1 (slot 1): build the lowered-signature bridge. ``lowered_fn`` + # speaks the flat primitive signature; it rebinds args, reassembles + # dataclasses, then dispatches to the user's ``fn``. + def _bind_to_original_kwargs(args, kwargs): + bound = lowered_sig.bind(*args, **kwargs) + bound.apply_defaults() + return _reassemble_kwargs(param_mapping_tree, bound.arguments) + + @functools.wraps(fn) + def lowered_fn(*args, **kwargs): + return fn(**_bind_to_original_kwargs(args, kwargs)) + + _apply_lowered_signature_metadata(lowered_fn, lowered_sig) + + # Step 2: meta/fake function. User-supplied meta_fn is bridged so + # it sees the original (dataclass-bearing) signature it was + # written against. + if infer_output_meta_fn is None: + meta_fn = _create_identity_meta_fn(lowered_fn) + elif isinstance(infer_output_meta_fn, list): + meta_fn = _create_meta_fn_from_param_names(lowered_fn, infer_output_meta_fn) + else: + user_meta = infer_output_meta_fn + + def _bridged_meta_fn(*args, **kwargs): + return user_meta(**_bind_to_original_kwargs(args, kwargs)) + + _bridged_meta_fn.__signature__ = lowered_sig + meta_fn = _bridged_meta_fn + + # Step 3 (slot 2): custom_op + register_fake. ``mutates_args`` is + # expanded from original-space to lowered-space so torch.library + # sees the leaf parameter names it actually owns. + flat_mutates_args = _expand_mutates_args(mutates_args, param_mapping_tree) + torch_registered_op = _register_torch_op( + op_name=op_name, fn=lowered_fn, mutates_args=flat_mutates_args, meta_fn=meta_fn + ) - return registered_op + # Step 4: autograd. The user's hooks speak the ORIGINAL signature, + # but torch.library passes/expects LOWERED inputs and grads, so we + # wrap both ends. + if backward_fn is not None: + user_setup = setup_context_fn + user_backward = backward_fn + + def _bridged_setup_context(ctx, inputs, output): + if user_setup is None: + return None + # Reassemble the lowered positional tuple into the user's + # original (possibly nested-dataclass) shape, preserving + # original positional order so ``x, cfg = inputs`` works. + lowered_kwargs = {p.name: v for p, v in zip(lowered_sig.parameters.values(), inputs)} + original_kwargs = _reassemble_kwargs(param_mapping_tree, lowered_kwargs) + original_inputs = tuple(original_kwargs[p] for p in original_sig.parameters) + return user_setup(ctx, original_inputs, output) + + def _bridged_backward(ctx, *grads): + original_grads = user_backward(ctx, *grads) + if not isinstance(original_grads, tuple): + # Single-input convenience: PyTorch allows a bare grad + # when the op has one input. + original_grads = (original_grads,) + return tuple(_flatten_grads(param_mapping_tree, original_grads)) + + torch_registered_op.register_autograd(_bridged_backward, setup_context=_bridged_setup_context) + + # Step 5 (slot 3, flattening-only): the user-facing op that + # preserves the original signature, flattens at entry, and + # dispatches to ``torch_registered_op``. + @functools.wraps(fn) + def magi_exposed_op(*args, **kwargs): + flat = _flatten_call_args(param_mapping_tree, original_sig, args, kwargs) + return torch_registered_op(*flat) + + # Internal handles so downstream tooling can drop one slot lower + # (e.g. dispatch the OpOverload directly with pre-flattened args). + magi_exposed_op._magi_torch_registered_op = torch_registered_op + magi_exposed_op._magi_param_mapping_tree = param_mapping_tree + return magi_exposed_op return decorator diff --git a/magi_compiler/api.py b/magi_compiler/api.py index 996657a..a712a5b 100644 --- a/magi_compiler/api.py +++ b/magi_compiler/api.py @@ -203,41 +203,57 @@ def magi_register_custom_op( is_compute_sensitive: bool = False, is_subgraph_boundary: bool = False, ): - """ - A unified decorator to register a custom operator with PyTorch's library. - - This decorator combines the functionality of: - - @torch.library.custom_op - - @torch.library.register_fake - - fn.register_autograd + """Register a Python function as a custom op for ``torch.library`` / ``torch.compile``. + + Combines ``@torch.library.custom_op`` + ``@torch.library.register_fake`` + + ``fn.register_autograd`` into one decorator, plus the following ergonomic + affordances on top of bare ``torch.library``: + + - **Frozen-dataclass parameters** (recursively nested) are flattened into + primitive leaves before being handed to ``infer_schema``, and reassembled + inside ``fn`` so the op body still sees the original dataclass. + - **Literal / string-Enum annotations** are auto-downgraded to ``str``; + the op body still receives the original value. + - **Unsupported defaults** (mutable, dataclass instances, ...) are scrubbed + from the lowered signature only; user-facing calls keep the original default. + - **Auto-generated op name** when ``name`` is omitted: derived from the + function's source file and ``__name__``. + - **Auto-generated meta function** when ``infer_output_meta_fn`` is omitted: + output ``i`` copies shape/dtype/device of the ``i``-th tensor input. Arguments: - name: The fully qualified name of the operator (e.g., "namespace::op_name"). - If None, auto-generated from the function name and source file. - mutates_args: Tuple of argument names that are mutated by the operator. - infer_output_meta_fn: Specifies output tensor metadata (shape, dtype, device) for tracing. - - None (default): Assumes each output has the same metadata as the corresponding - input tensor (1st output matches 1st tensor input, 2nd matches 2nd, etc.). - - list[str]: Parameter names whose metadata to use for outputs. - E.g., ["weight", "bias"] means output[0] has same shape as `weight`, - output[1] has same shape as `bias`. - - Callable: Custom function with same signature as the op, returns torch.empty_like() - tensors matching the expected output shapes. - setup_context_fn: Function to save tensors/values for backward. - Signature: setup_context_fn(ctx, inputs, output) - backward_fn: Function to compute gradients. - Signature: backward_fn(ctx, *grad_outputs) -> tuple of gradients - is_compute_sensitive: If True, marks this operator as compute-intensive (e.g., MatMul, - Attention). During activation recomputation (rematerialization), outputs of - compute-sensitive ops are prioritized for saving rather than recomputing, - since recomputing them would be expensive. - is_subgraph_boundary: If True, the FX graph will be split at this operator during - compilation. Each sub-graph between boundary operators is compiled independently - by Inductor, enabling piecewise compilation and more flexible scheduling - (e.g., for CPU offloading or overlapping computation with data transfer). + name: Fully qualified op name (e.g. ``"namespace::op_name"``). If ``None``, + auto-generated from the function's source file and name. + mutates_args: Argument names that the op mutates. For a frozen-dataclass + argument, listing the dataclass parameter expands to every Tensor leaf + under it; lowered leaf names (e.g. ``"cfg__weight"``) are also accepted. + infer_output_meta_fn: How to propagate output metadata at trace time. + - ``None`` (default): output ``i`` copies the ``i``-th tensor input. + - ``list[str]``: parameter names whose metadata to copy. E.g. + ``["weight", "bias"]`` makes ``output[0]`` shape-match ``weight`` + and ``output[1]`` shape-match ``bias``. + - ``Callable``: a function with the same signature as the op that + returns ``torch.empty_like(...)`` tensors of the expected shapes. + setup_context_fn: Forward-context setup; signature + ``setup_context_fn(ctx, inputs, output)``. ``inputs`` is the + user-side (original-shape) tuple, including dataclass instances. + backward_fn: Gradient computation; signature + ``backward_fn(ctx, *grad_outputs) -> tuple of grads``. Return **one + grad per original parameter** (not per lowered leaf); use ``None`` + for non-differentiable parameters, including whole dataclass args. + is_compute_sensitive: Mark as compute-intensive. During activation + recomputation, outputs of compute-sensitive ops are prioritised for + saving rather than recomputing. + is_subgraph_boundary: Split the FX graph at this op during compilation. + Each sub-graph between boundary ops is compiled independently. Returns: - The registered custom operator function. + A callable with the user's original signature. + - If ``fn`` has no dataclass parameter, returns a ``torch._ops.OpOverload`` + directly (zero per-call overhead). + - If ``fn`` has a frozen-dataclass parameter, returns a Python wrapper + that flattens/unflattens on each call and dispatches to the underlying + ``OpOverload`` (accessible via ``op._magi_torch_registered_op``). Examples: 1. Basic usage (forward only, auto-generated name and meta function): @@ -248,9 +264,7 @@ def magi_register_custom_op( 2. Multiple outputs with explicit output metadata via parameter names: - >>> @magi_register_custom_op( - ... infer_output_meta_fn=["weight", "bias"], # output shapes match weight and bias - ... ) + >>> @magi_register_custom_op(infer_output_meta_fn=["weight", "bias"]) ... def compute_gradients( ... grad_output: torch.Tensor, ... weight: torch.Tensor, @@ -260,7 +274,21 @@ def magi_register_custom_op( ... grad_bias = grad_output.sum(dim=0).view_as(bias) ... return grad_weight, grad_bias - 3. Full custom op with autograd support: + 3. Frozen-dataclass parameter (grouped config): + + >>> @dataclasses.dataclass(frozen=True) + ... class AttnCfg: + ... scale: float + ... causal: bool = False + ... + >>> @magi_register_custom_op() + ... def attn(q: torch.Tensor, k: torch.Tensor, cfg: AttnCfg) -> torch.Tensor: + ... scores = (q @ k.transpose(-1, -2)) * cfg.scale + ... if cfg.causal: + ... scores = scores.tril() + ... return scores + + 4. Full custom op with autograd support: >>> def _square_meta(x: torch.Tensor) -> torch.Tensor: ... return torch.empty_like(x) diff --git a/tests/api_tests/test_register_custom_op.py b/tests/api_tests/test_register_custom_op.py index 9872e68..81e7656 100644 --- a/tests/api_tests/test_register_custom_op.py +++ b/tests/api_tests/test_register_custom_op.py @@ -24,6 +24,7 @@ - Integration with magi_compile decorator """ +import dataclasses import tempfile from unittest.mock import patch @@ -113,126 +114,6 @@ def _split_op(x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: assert_close(out2, x[..., 4:]) -class TestAutograd: - """Tests for custom op with autograd support.""" - - def test_with_autograd(self): - """Test custom op with setup_context and backward functions.""" - - def _square_infer_output_meta(x: torch.Tensor) -> torch.Tensor: - return torch.empty_like(x) - - def _square_setup_context(ctx, inputs, output): - (x,) = inputs - ctx.save_for_backward(x) - - def _square_backward(ctx, grad_output): - (x,) = ctx.saved_tensors - return grad_output * 2 * x - - @magi_register_custom_op( - name="test::square_op", - mutates_args=(), - infer_output_meta_fn=_square_infer_output_meta, - setup_context_fn=_square_setup_context, - backward_fn=_square_backward, - ) - def _square_op(x: torch.Tensor) -> torch.Tensor: - return x * x - - x = torch.randn(4, 8, requires_grad=True) - output = _square_op(x) - loss = output.sum() - loss.backward() - - # Gradient of x^2 is 2x - expected_grad = 2 * x - assert_close(x.grad, expected_grad) - - def test_autograd_multiple_inputs(self): - """Test autograd with multiple input tensors.""" - - def _weighted_sum_infer_output_meta(a: torch.Tensor, b: torch.Tensor, weight: float) -> torch.Tensor: - return torch.empty_like(a) - - def _weighted_sum_setup_context(ctx, inputs, output): - a, b, weight = inputs - ctx.save_for_backward(a, b) - ctx.weight = weight - - def _weighted_sum_backward(ctx, grad_output): - a, b = ctx.saved_tensors - weight = ctx.weight - grad_a = grad_output * weight - grad_b = grad_output * (1 - weight) - return grad_a, grad_b, None # None for non-tensor input - - @magi_register_custom_op( - name="test::weighted_sum_op", - mutates_args=(), - infer_output_meta_fn=_weighted_sum_infer_output_meta, - setup_context_fn=_weighted_sum_setup_context, - backward_fn=_weighted_sum_backward, - ) - def _weighted_sum_op(a: torch.Tensor, b: torch.Tensor, weight: float) -> torch.Tensor: - return a * weight + b * (1 - weight) - - a = torch.randn(4, 8, requires_grad=True) - b = torch.randn(4, 8, requires_grad=True) - weight = 0.7 - - output = _weighted_sum_op(a, b, weight) - loss = output.sum() - loss.backward() - - expected_grad_a = torch.ones_like(a) * weight - expected_grad_b = torch.ones_like(b) * (1 - weight) - - assert_close(a.grad, expected_grad_a) - assert_close(b.grad, expected_grad_b) - - def test_autograd_multiple_outputs(self): - """Test autograd with multiple output tensors.""" - - def _split_scale_infer_output_meta(x: torch.Tensor, scale: float) -> tuple[torch.Tensor, torch.Tensor]: - half = x.shape[-1] // 2 - return (x.new_empty((*x.shape[:-1], half)), x.new_empty((*x.shape[:-1], half))) - - def _split_scale_setup_context(ctx, inputs, output): - x, scale = inputs - ctx.save_for_backward(x) - ctx.scale = scale - ctx.half = x.shape[-1] // 2 - - def _split_scale_backward(ctx, grad_out1, grad_out2): - (x,) = ctx.saved_tensors - scale = ctx.scale - # Reconstruct gradient for x - grad_x = torch.cat([grad_out1 * scale, grad_out2 * scale], dim=-1) - return grad_x, None - - @magi_register_custom_op( - name="test::split_scale_op", - mutates_args=(), - infer_output_meta_fn=_split_scale_infer_output_meta, - setup_context_fn=_split_scale_setup_context, - backward_fn=_split_scale_backward, - ) - def _split_scale_op(x: torch.Tensor, scale: float) -> tuple[torch.Tensor, torch.Tensor]: - half = x.shape[-1] // 2 - return x[..., :half] * scale, x[..., half:] * scale - - x = torch.randn(4, 8, requires_grad=True) - scale = 2.0 - - out1, out2 = _split_scale_op(x, scale) - loss = out1.sum() + out2.sum() - loss.backward() - - expected_grad = torch.ones_like(x) * scale - assert_close(x.grad, expected_grad) - - class TestAutoGeneratedName: """Tests for auto-generated operator name when name is not provided.""" @@ -655,5 +536,1477 @@ def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: assert_close(out2, x[..., 4:] * 2) +@dataclasses.dataclass(frozen=True) +class _CSCfg: + s: float + + +class TestDataclassWithComputeSensitiveSmoke: + """Dataclass input + ``is_compute_sensitive=True`` should register cleanly + and the op name should land in ``custom_compute_sensitive_ops``. + """ + + def test_dataclass_compute_sensitive_smoke(self): + from magi_compiler.config import get_compile_config + + @magi_register_custom_op(name="test::dc_compute_sensitive", is_compute_sensitive=True) + def _op(x: torch.Tensor, cfg: _CSCfg) -> torch.Tensor: + return x * cfg.s + + out = _op(torch.ones(2), _CSCfg(s=2.0)) + assert_close(out, torch.full((2,), 2.0)) + assert "test::dc_compute_sensitive" in get_compile_config().recompute_config.custom_compute_sensitive_ops + + +class TestFrozenDataclassInput: + """Tests for frozen-dataclass input support via the auto pytree path.""" + + def test_forward_with_dataclass_only(self): + """Custom op whose only argument is a frozen dataclass.""" + from dataclasses import dataclass + + @dataclass(frozen=True) + class _ScaleCfg: + scale: float + bias: float + + @magi_register_custom_op(name="test::dc_only_op", mutates_args=()) + def _dc_only_op(cfg: _ScaleCfg) -> torch.Tensor: + return torch.full((2, 3), cfg.scale + cfg.bias) + + cfg = _ScaleCfg(scale=2.0, bias=1.0) + out = _dc_only_op(cfg) + assert_close(out, torch.full((2, 3), 3.0)) + + def test_forward_with_mixed_args(self): + """Custom op with mixed dataclass and tensor inputs.""" + from dataclasses import dataclass + + @dataclass(frozen=True) + class _AttnCfg: + scale: float + causal: bool + + @magi_register_custom_op(name="test::dc_mixed_op", mutates_args=()) + def _dc_mixed_op(q: torch.Tensor, k: torch.Tensor, cfg: _AttnCfg) -> torch.Tensor: + out = (q @ k.transpose(-1, -2)) * cfg.scale + if cfg.causal: + mask = torch.tril(torch.ones_like(out, dtype=torch.bool)) + out = out.masked_fill(~mask, 0.0) + return out + + q = torch.randn(2, 4, 8) + k = torch.randn(2, 4, 8) + cfg = _AttnCfg(scale=0.25, causal=False) + out = _dc_mixed_op(q, k, cfg) + expected = (q @ k.transpose(-1, -2)) * 0.25 + assert_close(out, expected) + + # Try positional + keyword call form to ensure outer wrapper handles both. + out_kw = _dc_mixed_op(q, k=k, cfg=cfg) + assert_close(out_kw, expected) + + def test_forward_with_custom_meta_fn(self): + """Custom op with a user-provided meta function expressed in dataclass terms.""" + from dataclasses import dataclass + + @dataclass(frozen=True) + class _ProjCfg: + out_dim: int + + def _proj_meta(x: torch.Tensor, cfg: _ProjCfg) -> torch.Tensor: + return x.new_empty((*x.shape[:-1], cfg.out_dim)) + + @magi_register_custom_op(name="test::dc_meta_op", mutates_args=(), infer_output_meta_fn=_proj_meta) + def _dc_meta_op(x: torch.Tensor, cfg: _ProjCfg) -> torch.Tensor: + return x[..., : cfg.out_dim].clone() + + x = torch.randn(2, 8) + cfg = _ProjCfg(out_dim=3) + out = _dc_meta_op(x, cfg) + assert out.shape == (2, 3) + assert_close(out, x[..., :3]) + + +class TestNestedDataclassInput: + """Tests for *recursively* flattened nested-dataclass inputs.""" + + def test_nested_dataclass_only(self): + """A single dataclass argument that itself contains a dataclass field.""" + from dataclasses import dataclass + + @dataclass(frozen=True) + class _Inner: + scale: float + bias: float + + @dataclass(frozen=True) + class _Outer: + inner: _Inner + offset: float + + @magi_register_custom_op(name="test::nested_dc_only_op", mutates_args=()) + def _nested_dc_only_op(cfg: _Outer) -> torch.Tensor: + return torch.full((2, 3), cfg.inner.scale * cfg.inner.bias + cfg.offset) + + cfg = _Outer(inner=_Inner(scale=2.0, bias=3.0), offset=1.5) + out = _nested_dc_only_op(cfg) + assert_close(out, torch.full((2, 3), 7.5)) + + # The param mapping tree should fully expand the nested dataclass into leaves. + param_mapping_tree = _nested_dc_only_op._magi_param_mapping_tree + assert param_mapping_tree[0][0] == "dataclass" + assert param_mapping_tree[0][1] == "cfg" + children = param_mapping_tree[0][3] + kinds = [c[0] for c in children] + assert "dataclass" in kinds, "inner dataclass field must remain a dataclass node" + # Find the leaf lowered names. + lowered_names: list[str] = [] + + def _collect(node): + if node[0] == "primitive": + lowered_names.append(node[2]) + else: + for child in node[3]: + _collect(child) + + _collect(param_mapping_tree[0]) + assert "cfg__inner__scale" in lowered_names + assert "cfg__inner__bias" in lowered_names + assert "cfg__offset" in lowered_names + + def test_nested_dataclass_mixed_with_tensor(self): + """Tensor arg alongside a nested dataclass arg.""" + from dataclasses import dataclass + + @dataclass(frozen=True) + class _NormCfg: + eps: float + affine: bool + + @dataclass(frozen=True) + class _LayerCfg: + norm: _NormCfg + scale: float + + @magi_register_custom_op(name="test::nested_dc_mixed_op", mutates_args=()) + def _nested_dc_mixed_op(x: torch.Tensor, cfg: _LayerCfg) -> torch.Tensor: + y = x / (x.std() + cfg.norm.eps) + if cfg.norm.affine: + y = y * cfg.scale + return y + + x = torch.randn(4, 8) + cfg = _LayerCfg(norm=_NormCfg(eps=1e-3, affine=True), scale=2.0) + out = _nested_dc_mixed_op(x, cfg) + expected = x / (x.std() + 1e-3) * 2.0 + assert_close(out, expected) + + # Keyword form too. + out_kw = _nested_dc_mixed_op(x=x, cfg=cfg) + assert_close(out_kw, expected) + + def test_deeply_nested_dataclass(self): + """Three levels of nesting plus a sibling primitive at the top level.""" + from dataclasses import dataclass + + @dataclass(frozen=True) + class _Leaf: + val: float + + @dataclass(frozen=True) + class _Mid: + leaf: _Leaf + extra: int + + @dataclass(frozen=True) + class _Root: + mid: _Mid + tag: float + + @magi_register_custom_op(name="test::deep_nested_dc_op", mutates_args=()) + def _deep_nested_dc_op(x: torch.Tensor, cfg: _Root, alpha: float) -> torch.Tensor: + return x * cfg.mid.leaf.val + cfg.mid.extra + cfg.tag + alpha + + x = torch.randn(2, 3) + cfg = _Root(mid=_Mid(leaf=_Leaf(val=2.0), extra=3), tag=0.5) + out = _deep_nested_dc_op(x, cfg, alpha=1.0) + assert_close(out, x * 2.0 + 3 + 0.5 + 1.0) + + def test_nested_dataclass_with_meta_fn(self): + """User-supplied meta function expressed in nested-dataclass terms.""" + from dataclasses import dataclass + + @dataclass(frozen=True) + class _ShapeCfg: + out_dim: int + + @dataclass(frozen=True) + class _ProjCfg: + shape: _ShapeCfg + scale: float + + def _proj_meta(x: torch.Tensor, cfg: _ProjCfg) -> torch.Tensor: + return x.new_empty((*x.shape[:-1], cfg.shape.out_dim)) + + @magi_register_custom_op(name="test::nested_dc_meta_op", mutates_args=(), infer_output_meta_fn=_proj_meta) + def _nested_dc_meta_op(x: torch.Tensor, cfg: _ProjCfg) -> torch.Tensor: + return x[..., : cfg.shape.out_dim].clone() * cfg.scale + + x = torch.randn(2, 8) + cfg = _ProjCfg(shape=_ShapeCfg(out_dim=3), scale=2.0) + out = _nested_dc_meta_op(x, cfg) + assert out.shape == (2, 3) + assert_close(out, x[..., :3] * 2.0) + + +class TestDataclassOptionalFields: + """Frozen dataclass fields annotated with ``Optional[...]`` should flow + through the recursive flattener and be accepted by + ``torch.library.infer_schema`` as long as the underlying type is one of + the schema's known optional types. + + Coverage matrix (verified against torch 2.9.1): + + - ``Optional[Tensor]`` ✓ + - ``Optional[int]`` ✓ + - ``Optional[float]`` ✓ + - ``Optional[bool]`` ✓ + - ``Optional[str]`` ✓ + - ``Optional[list[int]]`` ✓ + - ``Tensor | None`` (PEP 604) ✓ + - ``Optional[list[Tensor]]`` ✗ (PyTorch limitation, see negative + test below; users should use ``list[Optional[Tensor]]`` instead) + """ + + def test_optional_scalar_and_tensor_fields(self): + from dataclasses import dataclass + from typing import Optional + + @dataclass(frozen=True) + class _Cfg: + bias: Optional[torch.Tensor] + scale: Optional[float] + mode: Optional[str] + block_sizes: Optional[list[int]] + flag: Optional[bool] + count: Optional[int] + + @magi_register_custom_op(name="test::dc_optional_mix") + def _op(x: torch.Tensor, cfg: _Cfg) -> torch.Tensor: + out = x.clone() + if cfg.bias is not None: + out = out + cfg.bias + if cfg.scale is not None: + out = out * cfg.scale + return out + + x = torch.ones(3) + # All-None: just the identity transform. + out_all_none = _op(x, _Cfg(None, None, None, None, None, None)) + assert_close(out_all_none, x) + # Some-None: bias + scale active, others ignored by the body. + cfg = _Cfg(bias=torch.tensor([1.0, 2.0, 3.0]), scale=2.0, mode="a", block_sizes=[4, 8], flag=True, count=7) + out_some = _op(x, cfg) + assert_close(out_some, torch.tensor([4.0, 6.0, 8.0])) + + def test_pep604_optional_tensor_field(self): + """The ``X | None`` PEP 604 syntax should be equivalent to + ``Optional[X]`` for dataclass fields.""" + from dataclasses import dataclass + + @dataclass(frozen=True) + class _Cfg: + bias: torch.Tensor | None + + @magi_register_custom_op(name="test::dc_pep604_optional") + def _op(x: torch.Tensor, cfg: _Cfg) -> torch.Tensor: + return x.clone() if cfg.bias is None else x + cfg.bias + + x = torch.ones(2) + assert_close(_op(x, _Cfg(bias=None)), x) + assert_close(_op(x, _Cfg(bias=torch.tensor([10.0, 20.0]))), torch.tensor([11.0, 21.0])) + + def test_optional_list_of_tensors_unsupported_by_torch_library(self): + """Sanity-pin a known-broken case so a future torch upgrade that + relaxes ``infer_schema`` flips this from xfail to pass. + + ``Optional[list[Tensor]]`` is NOT in the torch.library schema's + accepted-type set (although ``list[Optional[Tensor]]`` is). The + magi flattener doesn't try to rewrite the user's annotation, so we + let the underlying ValueError propagate.""" + from dataclasses import dataclass + from typing import Optional + + @dataclass(frozen=True) + class _Cfg: + biases: Optional[list[torch.Tensor]] + + with pytest.raises(ValueError, match="unsupported type"): + + @magi_register_custom_op(name="test::dc_optional_list_tensor") + def _op(x: torch.Tensor, cfg: _Cfg) -> torch.Tensor: + return x + + +class TestDataclassListFields: + """Frozen dataclass fields annotated with ``list[...]`` flow through the + flattener and are accepted by ``torch.library.infer_schema`` for the + schema's supported scalar/tensor element types. + + Coverage matrix (verified against torch 2.9.1): + + - ``list[Tensor]`` ✓ + - ``list[int]`` ✓ + - ``list[float]`` ✓ + - ``list[bool]`` ✓ + - ``list[Optional[Tensor]]`` ✓ + """ + + def test_list_of_tensors_and_scalars_in_dataclass(self): + from dataclasses import dataclass + + @dataclass(frozen=True) + class _Cfg: + biases: list[torch.Tensor] + scales: list[float] + block_sizes: list[int] + flags: list[bool] + + @magi_register_custom_op(name="test::dc_list_mix") + def _op(x: torch.Tensor, cfg: _Cfg) -> torch.Tensor: + out = x.clone() + for b in cfg.biases: + out = out + b + out = out * float(sum(cfg.scales)) + # ``flags`` and ``block_sizes`` are accepted on the schema side + # even though the body here doesn't depend on them. + assert all(isinstance(s, int) for s in cfg.block_sizes) + assert all(isinstance(f, bool) for f in cfg.flags) + return out + + x = torch.ones(3) + cfg = _Cfg( + biases=[torch.tensor([1.0, 2.0, 3.0]), torch.tensor([10.0, 20.0, 30.0])], + scales=[2.0, 0.5], + block_sizes=[4, 8], + flags=[True, False], + ) + out = _op(x, cfg) + # (1 + 1 + 10) * 2.5 = 30, (1 + 2 + 20) * 2.5 = 57.5, ... + assert_close(out, torch.tensor([30.0, 57.5, 85.0])) + + def test_list_of_optional_tensors_in_dataclass(self): + """``list[Optional[Tensor]]`` is on the schema's supported list (in + contrast to the rejected ``Optional[list[Tensor]]``).""" + from dataclasses import dataclass + from typing import Optional + + @dataclass(frozen=True) + class _Cfg: + maybe_biases: list[Optional[torch.Tensor]] + + @magi_register_custom_op(name="test::dc_list_optional_tensor") + def _op(x: torch.Tensor, cfg: _Cfg) -> torch.Tensor: + out = x.clone() + for b in cfg.maybe_biases: + if b is not None: + out = out + b + return out + + x = torch.ones(2) + out = _op(x, _Cfg(maybe_biases=[None, torch.tensor([10.0, 20.0]), None])) + assert_close(out, torch.tensor([11.0, 21.0])) + + def test_empty_list_field(self): + """An empty ``list[Tensor]`` value should round-trip through the + flattener and run the body normally.""" + from dataclasses import dataclass + + @dataclass(frozen=True) + class _Cfg: + biases: list[torch.Tensor] + + @magi_register_custom_op(name="test::dc_empty_list") + def _op(x: torch.Tensor, cfg: _Cfg) -> torch.Tensor: + out = x.clone() + for b in cfg.biases: + out = out + b + return out + + x = torch.ones(3) + assert_close(_op(x, _Cfg(biases=[])), x) + + +class TestDataclassFieldDefaults: + """Dataclass field ``default`` / ``default_factory`` values should + propagate to the flat ``inspect.Signature`` so that calling the op + without those fields works (and ``torch.library.infer_schema`` records + them as optional).""" + + def test_dataclass_field_default_carried_into_signature(self): + from dataclasses import dataclass + + @dataclass(frozen=True) + class _Cfg: + scale: float = 2.0 + + @magi_register_custom_op(name="test::dc_field_default_scale") + def _op(x: torch.Tensor, cfg: _Cfg) -> torch.Tensor: + return x * cfg.scale + + x = torch.ones(3) + # User constructs the dataclass with the default value implicitly. + assert_close(_op(x, _Cfg()), torch.full((3,), 2.0)) + # And the override path still works. + assert_close(_op(x, _Cfg(scale=10.0)), torch.full((3,), 10.0)) + + def test_dataclass_field_default_factory_for_list(self): + from dataclasses import dataclass, field + + @dataclass(frozen=True) + class _Cfg: + block_sizes: list[int] = field(default_factory=lambda: [4, 8]) + + @magi_register_custom_op(name="test::dc_field_default_factory") + def _op(x: torch.Tensor, cfg: _Cfg) -> torch.Tensor: + return x * float(sum(cfg.block_sizes)) + + x = torch.ones(2) + assert_close(_op(x, _Cfg()), torch.full((2,), 12.0)) + assert_close(_op(x, _Cfg(block_sizes=[1, 2, 3])), torch.full((2,), 6.0)) + + def test_nested_dataclass_with_default(self): + from dataclasses import dataclass + + @dataclass(frozen=True) + class _Inner: + val: float = 3.0 + + @dataclass(frozen=True) + class _Outer: + inner: _Inner = _Inner() + scale: float = 1.0 + + @magi_register_custom_op(name="test::dc_nested_default") + def _op(x: torch.Tensor, cfg: _Outer) -> torch.Tensor: + return x * cfg.inner.val * cfg.scale + + x = torch.ones(2) + assert_close(_op(x, _Outer()), torch.full((2,), 3.0)) + assert_close(_op(x, _Outer(inner=_Inner(val=5.0), scale=2.0)), torch.full((2,), 10.0)) + + +class TestDataclassUnsupportedContainerFields: + """``tuple[...]`` and ``dict[...]`` dataclass field annotations are not + accepted by ``torch.library.infer_schema`` (only ``list[...]`` is). We + reject them at registration time with an actionable hint pointing to the + field name and suggested fix.""" + + def test_tuple_field_rejected_with_hint(self): + from dataclasses import dataclass + + @dataclass(frozen=True) + class _Cfg: + qkv: tuple[torch.Tensor, torch.Tensor, torch.Tensor] + + with pytest.raises(TypeError, match=r"list\[\.\.\.\]"): + + @magi_register_custom_op(name="test::dc_tuple_field") + def _op(x: torch.Tensor, cfg: _Cfg) -> torch.Tensor: + return x + + def test_dict_field_rejected_with_hint(self): + from dataclasses import dataclass + + @dataclass(frozen=True) + class _Cfg: + weight_by_name: dict[str, torch.Tensor] + + with pytest.raises(TypeError, match="dict-typed"): + + @magi_register_custom_op(name="test::dc_dict_field") + def _op(x: torch.Tensor, cfg: _Cfg) -> torch.Tensor: + return x + + +class TestMutableDataclassRejected: + """``magi_register_custom_op`` only supports ``@dataclass(frozen=True)``. + + A non-frozen dataclass passed as an op input would otherwise produce a + confusing ``ValueError: Unsupported type annotation X. It is not a type`` + deep inside ``torch.library.infer_schema``. We surface a clear error + pointing at the fix. + """ + + def test_top_level_mutable_dataclass_rejected(self): + from dataclasses import dataclass + + @dataclass # NOTE: missing frozen=True + class _MutableCfg: + scale: float + + with pytest.raises(TypeError, match="frozen=True"): + + @magi_register_custom_op(name="test::mutable_dc_top") + def _op(x: torch.Tensor, cfg: _MutableCfg) -> torch.Tensor: + return x * cfg.scale + + def test_nested_mutable_dataclass_rejected(self): + """Inner non-frozen dataclass nested inside an outer frozen one is + also detected; the error names the offending field.""" + from dataclasses import dataclass + + @dataclass # NOT frozen + class _MutableInner: + val: float + + @dataclass(frozen=True) + class _OuterCfg: + inner: _MutableInner + + with pytest.raises(TypeError, match="_MutableInner"): + + @magi_register_custom_op(name="test::mutable_dc_nested") + def _op(x: torch.Tensor, cfg: _OuterCfg) -> torch.Tensor: + return x * cfg.inner.val + + def test_frozen_dataclass_still_works(self): + """Sanity: the rejection path doesn't break the supported case.""" + from dataclasses import dataclass + + @dataclass(frozen=True) + class _OkCfg: + scale: float + + @magi_register_custom_op(name="test::frozen_dc_ok") + def _op(x: torch.Tensor, cfg: _OkCfg) -> torch.Tensor: + return x * cfg.scale + + out = _op(torch.ones(4), _OkCfg(scale=3.0)) + assert_close(out, torch.full((4,), 3.0)) + + +class TestDataclassReturnRejected: + """``torch.library`` cannot model dataclass returns; we should refuse with + an actionable error instead of letting ``infer_schema`` raise an opaque + ``ValueError`` from the depths. + """ + + def test_dataclass_return_rejected(self): + from dataclasses import dataclass + + @dataclass(frozen=True) + class _Out: + a: torch.Tensor + + with pytest.raises(TypeError, match="dataclass"): + + @magi_register_custom_op(name="test::dc_return") + def _op(x: torch.Tensor) -> _Out: + return _Out(a=x.clone()) + + def test_tensor_return_still_works(self): + @magi_register_custom_op(name="test::tensor_return_ok") + def _op(x: torch.Tensor) -> torch.Tensor: + return x + 1 + + out = _op(torch.zeros(3)) + assert_close(out, torch.ones(3)) + + +class TestDataclassUnresolvableLocalField: + """A dataclass field annotated with a class defined inside another + function (a "local class") under ``from __future__ import annotations`` + cannot be resolved by ``typing.get_type_hints``. We surface a clear + error pointing at the fix instead of a confusing schema error. + """ + + def test_unresolvable_local_field_rejected(self): + # We synthesize the failure by hand-constructing a frozen dataclass + # whose field type is a *string* that doesn't resolve in any visible + # namespace. This mirrors what ``from __future__ import annotations`` + # produces for a local class. + import dataclasses + + @dataclasses.dataclass(frozen=True) + class _OuterCfg: + inner: "_DefinitelyNotInScope" # noqa: F821 + + with pytest.raises(TypeError, match="unresolved string annotation"): + + @magi_register_custom_op(name="test::dc_unresolvable_local") + def _op(x: torch.Tensor, cfg: _OuterCfg) -> torch.Tensor: + return x + + +class TestPep563LocalDataclassResolved: + """Happy-path counterpart of ``TestDataclassUnresolvableLocalField``: when + a function is defined under ``from __future__ import annotations`` and + references a *local* dataclass that's actually visible in the enclosing + closure, ``_resolve_annotations``'s fallback path (the one that combines + ``fn.__globals__`` with ``inspect.getclosurevars(fn).nonlocals``) must + succeed and the op must register & run normally. + + Why this needs ``exec`` to set up: the test file as a whole does NOT + enable ``from __future__ import annotations`` (that would silently + string-ify every other test's annotations, a high-blast-radius change). + We therefore compile a small standalone snippet WITH that future flag + and inject a local frozen dataclass into its exec namespace. This + reproduces the real-world bug class -- "user wrote PEP 563 + factory + function returning an op" -- without contaminating the rest of the file. + + Without this test the resolver fallback (``_resolve_annotations``, + namespace-merging branch) has no regression net: dropping + ``cv.nonlocals`` from the merge would still leave every other test + green. + """ + + def test_local_dataclass_resolved_via_closure_fallback(self): + import dataclasses + import textwrap + + # Step 1: build a frozen dataclass that will live ONLY in the exec + # namespace -- it is intentionally NOT importable from any module, + # so ``typing.get_type_hints`` (which can only see module globals) + # MUST fail and we MUST fall through to the closurevars-based path. + @dataclasses.dataclass(frozen=True) + class _LocalCfg: + scale: float + bias: float + + # Step 2: source for a factory function whose annotations get + # stringified by ``from __future__ import annotations``. ``_LocalCfg`` + # is referenced both in the annotation (as a string) AND in the + # function body (as a real call) -- the body reference is what makes + # ``inspect.getclosurevars`` report ``_LocalCfg`` in ``.nonlocals``, + # which is exactly the namespace lookup we want to exercise. + src = textwrap.dedent( + """ + from __future__ import annotations + + def make_op(register, name, _LocalCfg): + @register(name=name) + def _op(x: torch.Tensor, cfg: _LocalCfg) -> torch.Tensor: + # body reference forces _LocalCfg into cv.nonlocals + _ = _LocalCfg + return x * cfg.scale + cfg.bias + return _op + """ + ) + + ns: dict = {"torch": torch} + exec(compile(src, "", "exec"), ns) + + op = ns["make_op"](magi_register_custom_op, "test::dc_pep563_local_resolved", _LocalCfg) + + x = torch.tensor([1.0, 2.0, 3.0]) + out = op(x, _LocalCfg(scale=2.0, bias=1.0)) + assert_close(out, torch.tensor([3.0, 5.0, 7.0])) + + +@dataclasses.dataclass(frozen=True) +class _OptScalarCfg: + softcap: float | None = None + causal: bool | None = None + + +class TestDataclassOptionalScalarFields: + """Common attention/cfg pattern: dataclass field is ``Optional[scalar]`` + with a ``None`` default. Verify forward across explicit-None and + explicit-value calls. + """ + + def test_optional_scalar_defaults_and_overrides(self): + @magi_register_custom_op(name="test::dc_opt_scalar") + def _op(x: torch.Tensor, cfg: _OptScalarCfg) -> torch.Tensor: + y = x.clone() + if cfg.softcap is not None: + y = torch.tanh(y / cfg.softcap) * cfg.softcap + if cfg.causal is True: + y = y * 0.5 + return y + + x = torch.randn(4) + assert_close(_op(x, _OptScalarCfg()), x) + out = _op(x, _OptScalarCfg(softcap=2.0, causal=True)) + expected = torch.tanh(x / 2.0) * 2.0 * 0.5 + assert_close(out, expected) + + +@dataclasses.dataclass(frozen=True) +class _DtypeCfg: + out_dtype: torch.dtype = torch.float16 + out_device: torch.device | None = None + + +class TestDataclassDtypeDeviceFields: + """``torch.dtype`` / ``torch.device`` are the canonical schema-supported + "structured scalar" types. Make sure they survive the dataclass-flatten + round-trip with sensible defaults. + """ + + def test_dtype_device_round_trip(self): + @magi_register_custom_op(name="test::dc_dtype_device") + def _op(x: torch.Tensor, cfg: _DtypeCfg) -> torch.Tensor: + return x.to(dtype=cfg.out_dtype) + + x = torch.randn(4) + out = _op(x, _DtypeCfg(out_dtype=torch.bfloat16)) + assert out.dtype == torch.bfloat16 + + +class TestTopLevelTupleDictRejected: + """Top-level parameters of type ``tuple[T, ...]`` / ``dict[K, V]`` are + rejected with the same actionable message as dataclass fields. + """ + + def test_top_level_tuple_rejected(self): + with pytest.raises(TypeError, match="tuple"): + + @magi_register_custom_op(name="test::top_tuple") + def _op(xs: tuple[torch.Tensor, ...]) -> torch.Tensor: + return xs[0] + + def test_top_level_dict_rejected(self): + with pytest.raises(TypeError, match="dict"): + + @magi_register_custom_op(name="test::top_dict") + def _op(xs: dict[str, torch.Tensor]) -> torch.Tensor: + return next(iter(xs.values())) + + +class TestMissingAnnotationRejected: + """Every parameter on the registered op (and its return value) must have + a type annotation; ``*args`` / ``**kwargs`` are not supported either. + + Without an explicit guard these mistakes still fail at registration time + but with a confusing low-level error from ``torch.library.infer_schema``. + ``magi_register_custom_op`` re-raises them as a clear ``TypeError`` with + an actionable hint and a ``where=`` locator. + """ + + def test_op_parameter_without_annotation_rejected(self): + with pytest.raises(TypeError, match="no type annotation"): + + @magi_register_custom_op(name="test::missing_param_ann") + def _op(x: torch.Tensor, scale) -> torch.Tensor: + return x * scale + + def test_op_first_parameter_without_annotation_rejected(self): + with pytest.raises(TypeError, match=r"parameter 'x'"): + + @magi_register_custom_op(name="test::missing_first_param_ann") + def _op(x, y: torch.Tensor) -> torch.Tensor: + return x + y + + def test_op_return_annotation_missing_rejected(self): + with pytest.raises(TypeError, match="return value"): + + @magi_register_custom_op(name="test::missing_return_ann") + def _op(x: torch.Tensor): + return x + + def test_var_positional_rejected(self): + with pytest.raises(TypeError, match=r"\*xs"): + + @magi_register_custom_op(name="test::varargs_rejected") + def _op(*xs: torch.Tensor) -> torch.Tensor: + return xs[0] + + def test_var_keyword_rejected(self): + with pytest.raises(TypeError, match=r"\*\*kwargs"): + + @magi_register_custom_op(name="test::varkw_rejected") + def _op(x: torch.Tensor, **kwargs: float) -> torch.Tensor: + return x + + def test_fully_annotated_op_still_works(self): + """Sanity check: the new guards do not regress the happy path.""" + + @magi_register_custom_op(name="test::fully_annotated_ok") + def _op(x: torch.Tensor, scale: float) -> torch.Tensor: + return x * scale + + x = torch.tensor([1.0, 2.0]) + assert_close(_op(x, scale=2.0), torch.tensor([2.0, 4.0])) + + +class TestLiteralAndEnumDowngrade: + """``Literal[str, ...]`` and string-valued ``Enum`` annotations are + auto-downgraded to ``str`` at the schema boundary; the op body still + receives the raw string. Numeric/heterogeneous Literals are rejected + with a clear error. + """ + + def test_literal_str_downgraded(self): + from typing import Literal + + @magi_register_custom_op(name="test::literal_str") + def _op(x: torch.Tensor, mode: Literal["a", "b"]) -> torch.Tensor: + return x * (2.0 if mode == "a" else 3.0) + + x = torch.ones(3) + assert_close(_op(x, "a"), torch.full((3,), 2.0)) + assert_close(_op(x, "b"), torch.full((3,), 3.0)) + + def test_string_enum_downgraded(self): + import enum + + class _Mode(enum.Enum): + A = "a" + B = "b" + + @magi_register_custom_op(name="test::enum_str") + def _op(x: torch.Tensor, mode: _Mode) -> torch.Tensor: + return x * (2.0 if mode == "a" else 3.0) + + # Note: caller passes the raw string value (the schema sees ``str``). + x = torch.ones(3) + assert_close(_op(x, "a"), torch.full((3,), 2.0)) + + def test_int_literal_rejected(self): + from typing import Literal + + with pytest.raises(TypeError, match="Literal"): + + @magi_register_custom_op(name="test::literal_int") + def _op(x: torch.Tensor, k: Literal[1, 2]) -> torch.Tensor: + return x + + def test_int_valued_enum_rejected(self): + import enum + + class _IntMode(enum.Enum): + A = 1 + B = 2 + + with pytest.raises(TypeError, match="Enum"): + + @magi_register_custom_op(name="test::enum_int") + def _op(x: torch.Tensor, mode: _IntMode) -> torch.Tensor: + return x + + +class TestTopLevelOptionalTensor: + """``Optional[Tensor]`` / ``Tensor | None`` is the most common attention / + bias / mask pattern. Verify both annotation flavours, with and without + a ``None`` default, and across forward + autograd paths. + """ + + def test_optional_tensor_default_none(self): + @magi_register_custom_op(name="test::opt_tensor_default") + def _op(x: torch.Tensor, bias: torch.Tensor | None = None) -> torch.Tensor: + return x.clone() if bias is None else x + bias + + x = torch.randn(3, 4) + b = torch.randn(3, 4) + assert_close(_op(x), x) + assert_close(_op(x, b), x + b) + + def test_optional_tensor_no_default_explicit_none(self): + @magi_register_custom_op(name="test::opt_tensor_explicit") + def _op(x: torch.Tensor, bias: torch.Tensor | None) -> torch.Tensor: + return x.clone() if bias is None else x + bias + + x = torch.randn(3, 4) + assert_close(_op(x, None), x) + + +class TestReturnListTensor: + """``list[Tensor]`` is the canonical return shape for split / chunk / + expert-output ops. Validate forward and that a user-supplied + ``infer_output_meta_fn`` is required for tracing. + """ + + def test_list_tensor_return_runtime(self): + def meta(x, n): + return [torch.empty_like(x.chunk(n, dim=0)[0]) for _ in range(n)] + + @magi_register_custom_op(name="test::list_tensor_return", infer_output_meta_fn=meta) + def _op(x: torch.Tensor, n: int) -> list[torch.Tensor]: + # ``.chunk`` returns views that alias ``x``; clone each chunk so + # the op doesn't violate torch.library's no-aliasing invariant + # for outputs. + return [c.clone() for c in x.chunk(n, dim=0)] + + x = torch.randn(8, 4) + out = _op(x, 4) + assert isinstance(out, list) + assert len(out) == 4 + assert all(o.shape == (2, 4) for o in out) + + +class TestTopLevelListIntDefault: + """A top-level parameter typed ``list[int]`` with a non-None default + should not crash registration (the previous behavior was to forward the + list literal to ``infer_schema``, which rejects it). + """ + + def test_list_int_default_does_not_crash(self): + @magi_register_custom_op(name="test::list_int_default") + def _op(x: torch.Tensor, dims: list[int] = [0]) -> torch.Tensor: + return x.sum(dim=dims, keepdim=True) + + x = torch.randn(3, 4) + # Caller still has to supply the list explicitly because the schema + # default was scrubbed; that's the documented price of using a list + # default with torch.library. + out = _op(x, [0]) + assert out.shape == (1, 4) + + +class TestAutograd: + """Tests for custom op with autograd support.""" + + def test_with_autograd(self): + """Test custom op with setup_context and backward functions.""" + + def _square_infer_output_meta(x: torch.Tensor) -> torch.Tensor: + return torch.empty_like(x) + + def _square_setup_context(ctx, inputs, output): + (x,) = inputs + ctx.save_for_backward(x) + + def _square_backward(ctx, grad_output): + (x,) = ctx.saved_tensors + return grad_output * 2 * x + + @magi_register_custom_op( + name="test::square_op", + mutates_args=(), + infer_output_meta_fn=_square_infer_output_meta, + setup_context_fn=_square_setup_context, + backward_fn=_square_backward, + ) + def _square_op(x: torch.Tensor) -> torch.Tensor: + return x * x + + x = torch.randn(4, 8, requires_grad=True) + output = _square_op(x) + loss = output.sum() + loss.backward() + + # Gradient of x^2 is 2x + expected_grad = 2 * x + assert_close(x.grad, expected_grad) + + def test_autograd_multiple_inputs(self): + """Test autograd with multiple input tensors.""" + + def _weighted_sum_infer_output_meta(a: torch.Tensor, b: torch.Tensor, weight: float) -> torch.Tensor: + return torch.empty_like(a) + + def _weighted_sum_setup_context(ctx, inputs, output): + a, b, weight = inputs + ctx.save_for_backward(a, b) + ctx.weight = weight + + def _weighted_sum_backward(ctx, grad_output): + a, b = ctx.saved_tensors + weight = ctx.weight + grad_a = grad_output * weight + grad_b = grad_output * (1 - weight) + return grad_a, grad_b, None # None for non-tensor input + + @magi_register_custom_op( + name="test::weighted_sum_op", + mutates_args=(), + infer_output_meta_fn=_weighted_sum_infer_output_meta, + setup_context_fn=_weighted_sum_setup_context, + backward_fn=_weighted_sum_backward, + ) + def _weighted_sum_op(a: torch.Tensor, b: torch.Tensor, weight: float) -> torch.Tensor: + return a * weight + b * (1 - weight) + + a = torch.randn(4, 8, requires_grad=True) + b = torch.randn(4, 8, requires_grad=True) + weight = 0.7 + + output = _weighted_sum_op(a, b, weight) + loss = output.sum() + loss.backward() + + expected_grad_a = torch.ones_like(a) * weight + expected_grad_b = torch.ones_like(b) * (1 - weight) + + assert_close(a.grad, expected_grad_a) + assert_close(b.grad, expected_grad_b) + + def test_autograd_multiple_outputs(self): + """Test autograd with multiple output tensors.""" + + def _split_scale_infer_output_meta(x: torch.Tensor, scale: float) -> tuple[torch.Tensor, torch.Tensor]: + half = x.shape[-1] // 2 + return (x.new_empty((*x.shape[:-1], half)), x.new_empty((*x.shape[:-1], half))) + + def _split_scale_setup_context(ctx, inputs, output): + x, scale = inputs + ctx.save_for_backward(x) + ctx.scale = scale + ctx.half = x.shape[-1] // 2 + + def _split_scale_backward(ctx, grad_out1, grad_out2): + (x,) = ctx.saved_tensors + scale = ctx.scale + # Reconstruct gradient for x + grad_x = torch.cat([grad_out1 * scale, grad_out2 * scale], dim=-1) + return grad_x, None + + @magi_register_custom_op( + name="test::split_scale_op", + mutates_args=(), + infer_output_meta_fn=_split_scale_infer_output_meta, + setup_context_fn=_split_scale_setup_context, + backward_fn=_split_scale_backward, + ) + def _split_scale_op(x: torch.Tensor, scale: float) -> tuple[torch.Tensor, torch.Tensor]: + half = x.shape[-1] // 2 + return x[..., :half] * scale, x[..., half:] * scale + + x = torch.randn(4, 8, requires_grad=True) + scale = 2.0 + + out1, out2 = _split_scale_op(x, scale) + loss = out1.sum() + out2.sum() + loss.backward() + + expected_grad = torch.ones_like(x) * scale + assert_close(x.grad, expected_grad) + + +class TestDataclassInputWithBackward: + """``backward_fn`` is bridged so users can write it in the ORIGINAL + (dataclass-bearing) signature even though the registered op underneath + is flat. These tests pin every shape of bridging we promise. + """ + + def test_backward_with_tensor_and_dataclass_input(self): + """``backward`` returns ``(grad_x, None)`` against the original sig.""" + from dataclasses import dataclass + + @dataclass(frozen=True) + class _ScaleCfg: + scale: float + + def _setup(ctx, inputs, output): + x, cfg = inputs # NOTE: original-signature view, not flat + assert isinstance(cfg, _ScaleCfg) + ctx.save_for_backward(x) + ctx.scale = cfg.scale + + def _bwd(ctx, grad_out): + (_x,) = ctx.saved_tensors + return grad_out * ctx.scale, None + + @magi_register_custom_op(name="test::dc_bwd_basic", mutates_args=(), setup_context_fn=_setup, backward_fn=_bwd) + def _op(x: torch.Tensor, cfg: _ScaleCfg) -> torch.Tensor: + return x * cfg.scale + + x = torch.randn(4, 8, requires_grad=True) + cfg = _ScaleCfg(scale=3.0) + y = _op(x, cfg) + assert_close(y, x * 3.0) + (gx,) = torch.autograd.grad(y.sum(), x) + assert_close(gx, torch.full_like(x, 3.0)) + + def test_backward_returning_bare_grad_is_allowed(self): + """Single-input convenience: returning a bare tensor instead of a + 1-tuple should still work, mirroring stock PyTorch behaviour.""" + from dataclasses import dataclass + + @dataclass(frozen=True) + class _Cfg: + alpha: float + + def _setup(ctx, inputs, output): + (cfg,) = inputs + ctx.alpha = cfg.alpha + + def _bwd(ctx, grad_out): + # Op has zero tensor inputs; nothing to backprop into. Returning + # ``None`` (or even a bare value) for the lone non-tensor input + # must be accepted by the bridge. + return None + + @magi_register_custom_op(name="test::dc_bwd_bare_grad", mutates_args=(), setup_context_fn=_setup, backward_fn=_bwd) + def _op(cfg: _Cfg) -> torch.Tensor: + return torch.full((2, 3), cfg.alpha) + + out = _op(_Cfg(alpha=1.5)) + assert_close(out, torch.full((2, 3), 1.5)) + + def test_backward_with_per_field_grads(self): + """If the user wants to express grads at field granularity, returning + a same-shape dataclass (or a dict) for the dataclass slot must be + spread to the underlying flat slots. The bridge accepts ``None`` + leaves to stand in for non-differentiable fields.""" + from dataclasses import dataclass + + @dataclass(frozen=True) + class _Cfg: + scale: float + offset: float + + def _setup(ctx, inputs, output): + x, cfg = inputs + ctx.save_for_backward(x) + ctx.scale = cfg.scale + + def _bwd(ctx, grad_out): + # Express the dataclass grad as a same-shape dataclass with + # ``None`` leaves; the bridge must spread these to flat slots. + return grad_out * ctx.scale, _Cfg(scale=None, offset=None) + + @magi_register_custom_op(name="test::dc_bwd_per_field", mutates_args=(), setup_context_fn=_setup, backward_fn=_bwd) + def _op(x: torch.Tensor, cfg: _Cfg) -> torch.Tensor: + return x * cfg.scale + cfg.offset + + x = torch.randn(3, requires_grad=True) + y = _op(x, _Cfg(scale=2.0, offset=0.5)) + (gx,) = torch.autograd.grad(y.sum(), x) + assert_close(gx, torch.full_like(x, 2.0)) + + def test_backward_with_dict_shaped_dc_grad(self): + """The bridge should also accept a plain ``dict`` for the dataclass + slot (handy when the user does not want to construct another instance).""" + from dataclasses import dataclass + + @dataclass(frozen=True) + class _Cfg: + scale: float + + def _setup(ctx, inputs, output): + x, cfg = inputs + ctx.save_for_backward(x) + ctx.scale = cfg.scale + + def _bwd(ctx, grad_out): + return grad_out * ctx.scale, {"scale": None} + + @magi_register_custom_op(name="test::dc_bwd_dict_grad", mutates_args=(), setup_context_fn=_setup, backward_fn=_bwd) + def _op(x: torch.Tensor, cfg: _Cfg) -> torch.Tensor: + return x * cfg.scale + + x = torch.randn(5, requires_grad=True) + y = _op(x, _Cfg(scale=4.0)) + (gx,) = torch.autograd.grad(y.sum(), x) + assert_close(gx, torch.full_like(x, 4.0)) + + def test_backward_with_nested_dataclass(self): + """Backward bridging must descend through nested dataclasses (the + flat slot count for a nested dc with ``None`` grad equals the total + number of leaf fields).""" + from dataclasses import dataclass + + @dataclass(frozen=True) + class _Inner: + scale: float + bias: float + + @dataclass(frozen=True) + class _Outer: + inner: _Inner + tag: float + + def _setup(ctx, inputs, output): + x, cfg = inputs + assert isinstance(cfg, _Outer) and isinstance(cfg.inner, _Inner) + ctx.save_for_backward(x) + ctx.scale = cfg.inner.scale + + def _bwd(ctx, grad_out): + return grad_out * ctx.scale, None # whole nested dc => None + + @magi_register_custom_op(name="test::dc_bwd_nested", mutates_args=(), setup_context_fn=_setup, backward_fn=_bwd) + def _op(x: torch.Tensor, cfg: _Outer) -> torch.Tensor: + return x * cfg.inner.scale + cfg.inner.bias + cfg.tag + + x = torch.randn(2, 4, requires_grad=True) + cfg = _Outer(inner=_Inner(scale=1.5, bias=0.25), tag=0.0) + y = _op(x, cfg) + (gx,) = torch.autograd.grad(y.sum(), x) + assert_close(gx, torch.full_like(x, 1.5)) + + def test_backward_two_tensor_inputs_around_dataclass(self): + """Two tensor inputs sandwiching a dataclass: both gradients must + land in the right flat slots.""" + from dataclasses import dataclass + + @dataclass(frozen=True) + class _Cfg: + alpha: float + beta: float + + def _setup(ctx, inputs, output): + a, cfg, b = inputs + ctx.save_for_backward(a, b) + ctx.alpha = cfg.alpha + ctx.beta = cfg.beta + + def _bwd(ctx, grad_out): + a, b = ctx.saved_tensors + return grad_out * ctx.alpha, None, grad_out * ctx.beta + + @magi_register_custom_op(name="test::dc_bwd_sandwich", mutates_args=(), setup_context_fn=_setup, backward_fn=_bwd) + def _op(a: torch.Tensor, cfg: _Cfg, b: torch.Tensor) -> torch.Tensor: + return a * cfg.alpha + b * cfg.beta + + a = torch.randn(4, requires_grad=True) + b = torch.randn(4, requires_grad=True) + cfg = _Cfg(alpha=2.0, beta=-3.0) + y = _op(a, cfg, b) + ga, gb = torch.autograd.grad(y.sum(), [a, b]) + assert_close(ga, torch.full_like(a, 2.0)) + assert_close(gb, torch.full_like(b, -3.0)) + + def test_backward_wrong_grad_count_raises(self): + """If ``backward_fn`` returns the wrong number of grads (relative to + the ORIGINAL signature), the user gets a clear error rather than an + opaque autograd shape mismatch.""" + from dataclasses import dataclass + + @dataclass(frozen=True) + class _Cfg: + scale: float + + def _setup(ctx, inputs, output): + x, cfg = inputs + ctx.save_for_backward(x) + + def _bad_bwd(ctx, grad_out): + return (grad_out,) # missing the dataclass slot + + @magi_register_custom_op( + name="test::dc_bwd_wrong_count", mutates_args=(), setup_context_fn=_setup, backward_fn=_bad_bwd + ) + def _op(x: torch.Tensor, cfg: _Cfg) -> torch.Tensor: + return x * cfg.scale + + x = torch.randn(3, requires_grad=True) + y = _op(x, _Cfg(scale=2.0)) + with pytest.raises(ValueError, match=r"backward_fn returned 1 grad"): + torch.autograd.grad(y.sum(), x) + + +@dataclasses.dataclass(frozen=True) +class _BwTupleCfg: + s: float + + +class TestTupleOutputBackward: + """``backward_fn`` with ``tuple`` outputs and a dataclass input.""" + + def test_two_output_backward_with_dataclass(self): + def setup(ctx, inputs, output): + x, _cfg = inputs + ctx.save_for_backward(x) + ctx.s = _cfg.s + + def bwd(ctx, gy0, gy1): + (x,) = ctx.saved_tensors + # d(out0)/dx = s, d(out1)/dx = 1 + return gy0 * ctx.s + gy1, None + + @magi_register_custom_op(name="test::tuple_bwd_dc", setup_context_fn=setup, backward_fn=bwd) + def _op(x: torch.Tensor, cfg: _BwTupleCfg) -> tuple[torch.Tensor, torch.Tensor]: + return x * cfg.s, x.clone() + + x = torch.randn(4, requires_grad=True) + y0, y1 = _op(x, _BwTupleCfg(s=3.0)) + (y0.sum() + 2.0 * y1.sum()).backward() + assert x.grad is not None + # gy0 = 1, gy1 = 2 -> grad = 1*3 + 2 = 5 + assert torch.allclose(x.grad, torch.full_like(x, 5.0)) + + +class TestPerFieldNoneGrad: + """The autograd bridge supports returning a partially-None grad for a + dataclass input (i.e. some Tensor fields differentiable, others not). + """ + + def test_partial_none_grad_for_dataclass(self): + @dataclasses.dataclass(frozen=True) + class _Cfg: + w: torch.Tensor + b: torch.Tensor # we'll mark this non-differentiable + + def setup(ctx, inputs, output): + x, cfg = inputs + ctx.save_for_backward(x, cfg.w) + + def bwd(ctx, gy): + x, w = ctx.saved_tensors + # x: gx = gy * w; cfg.w: gw = gy * x; cfg.b: None. + return gy * w, {"w": gy * x, "b": None} + + @magi_register_custom_op(name="test::partial_none_grad", setup_context_fn=setup, backward_fn=bwd) + def _op(x: torch.Tensor, cfg: _Cfg) -> torch.Tensor: + return x * cfg.w + cfg.b + + x = torch.randn(4, requires_grad=True) + w = torch.randn(4, requires_grad=True) + b = torch.randn(4, requires_grad=False) + cfg = _Cfg(w=w, b=b) + out = _op(x, cfg) + out.sum().backward() + assert torch.allclose(x.grad, w) + assert torch.allclose(w.grad, x) + # b did not require grad; nothing to assert beyond "no exception". + + +class TestBackwardCallsAnotherOp: + """``backward_fn`` is allowed to call other registered ops to compute the + gradient. This is the FlashAttention-style "forward op + backward op" + decomposition. Verify the gradient is correct end-to-end. + """ + + def test_backward_dispatches_to_another_op(self): + @magi_register_custom_op(name="test::matmul_grad_helper") + def matmul_grad_x(gy: torch.Tensor, w: torch.Tensor) -> torch.Tensor: + return gy @ w.t() + + def setup(ctx, inputs, output): + x, w = inputs + ctx.save_for_backward(x, w) + + def bwd(ctx, gy): + x, w = ctx.saved_tensors + return matmul_grad_x(gy, w), x.t() @ gy + + @magi_register_custom_op(name="test::matmul_with_op_bwd", setup_context_fn=setup, backward_fn=bwd) + def matmul(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor: + return x @ w + + x = torch.randn(3, 4, requires_grad=True) + w = torch.randn(4, 5, requires_grad=True) + out = matmul(x, w) + out.sum().backward() + gy = torch.ones_like(out) + assert torch.allclose(x.grad, gy @ w.t()) + assert torch.allclose(w.grad, x.t() @ gy) + + +@dataclasses.dataclass(frozen=True) +class _KwCfg: + s: float + + +class TestKwargOnlyDataclass: + """``cfg`` passed as a kwarg-only argument should still be flattened and + routed correctly through the dataclass-aware bridge. + """ + + def test_kwarg_only_dataclass_call(self): + @magi_register_custom_op(name="test::kwonly_dc") + def _op(x: torch.Tensor, *, cfg: _KwCfg) -> torch.Tensor: + return x * cfg.s + + x = torch.randn(2, 3) + out = _op(x, cfg=_KwCfg(s=4.0)) + assert_close(out, x * 4.0) + + def test_kwarg_only_dataclass_with_backward(self): + def setup(ctx, inputs, output): + x, cfg = inputs + ctx.s = cfg.s + + def bwd(ctx, gy): + return gy * ctx.s, None + + @magi_register_custom_op(name="test::kwonly_dc_bwd", setup_context_fn=setup, backward_fn=bwd) + def _op(x: torch.Tensor, *, cfg: _KwCfg) -> torch.Tensor: + return x * cfg.s + + x = torch.randn(3, requires_grad=True) + out = _op(x, cfg=_KwCfg(s=2.5)) + out.sum().backward() + assert torch.allclose(x.grad, torch.full_like(x, 2.5)) + + +@dataclasses.dataclass(frozen=True) +class _MutCfg: + a: torch.Tensor + b: torch.Tensor + + +class TestMutatesArgsOnDataclass: + """``mutates_args=('cfg',)`` on a dataclass parameter should expand to the + flat leaf names automatically; an unknown name should raise. + """ + + def test_mutates_args_dataclass_expands(self): + # We don't actually mutate (frozen dataclass) -- we just check the + # registration succeeds, which it would not if the name failed to + # expand to flat tensor leaves. + @magi_register_custom_op(name="test::mutates_dc_expand", mutates_args=("cfg",)) + def _op(x: torch.Tensor, cfg: _MutCfg) -> torch.Tensor: + cfg.a.add_(x) # in-place on a tensor field + cfg.b.add_(x) + return cfg.a + cfg.b + + x = torch.ones(3) + cfg = _MutCfg(a=torch.zeros(3), b=torch.zeros(3)) + out = _op(x, cfg) + assert_close(out, torch.full((3,), 2.0)) + + def test_mutates_args_unknown_name_rejected(self): + with pytest.raises(ValueError, match="does not match"): + + @magi_register_custom_op(name="test::mutates_dc_unknown", mutates_args=("does_not_exist",)) + def _op(x: torch.Tensor, cfg: _MutCfg) -> torch.Tensor: + return x + + def test_mutates_args_flat_name_passthrough(self): + """Users may also use the flat name (``cfg__a``) directly.""" + + @magi_register_custom_op(name="test::mutates_dc_flat", mutates_args=("cfg__a",)) + def _op(x: torch.Tensor, cfg: _MutCfg) -> torch.Tensor: + cfg.a.add_(x) + return cfg.a + cfg.b + + x = torch.ones(3) + cfg = _MutCfg(a=torch.zeros(3), b=torch.zeros(3)) + out = _op(x, cfg) + assert_close(out, torch.ones(3)) + + +class TestNestedMagiOpCall: + """Calling one ``@magi_register_custom_op``-registered op from inside + another is the bread-and-butter compositional pattern (e.g. a fused + attention op that internally calls a registered softmax). Verify both + plain-input and dataclass-input compositions. + """ + + def test_plain_op_inside_plain_op(self): + @magi_register_custom_op(name="test::nested_inner_plain") + def inner(x: torch.Tensor) -> torch.Tensor: + return x * 2 + + @magi_register_custom_op(name="test::nested_outer_plain") + def outer(x: torch.Tensor) -> torch.Tensor: + return inner(x) + 1 + + x = torch.randn(4) + assert_close(outer(x), x * 2 + 1) + + def test_dataclass_op_inside_plain_op(self): + @dataclasses.dataclass(frozen=True) + class _Cfg: + s: float + + @magi_register_custom_op(name="test::nested_inner_dc") + def inner(x: torch.Tensor, cfg: _Cfg) -> torch.Tensor: + return x * cfg.s + + @magi_register_custom_op(name="test::nested_outer_calls_dc") + def outer(x: torch.Tensor) -> torch.Tensor: + return inner(x, _Cfg(s=3.0)) + 1 + + x = torch.randn(4) + assert_close(outer(x), x * 3.0 + 1) + + if __name__ == "__main__": pytest.main([__file__, "-v"]) From 866ea917062dfec4a5b595ccf823bf0e7b5a94bb Mon Sep 17 00:00:00 2001 From: HongyuJia Date: Sun, 17 May 2026 14:38:36 +0800 Subject: [PATCH 2/6] polish docs --- magi_compiler/_magi_register_custom_op.py | 14 ++-- magi_compiler/api.py | 80 +++++------------------ 2 files changed, 24 insertions(+), 70 deletions(-) diff --git a/magi_compiler/_magi_register_custom_op.py b/magi_compiler/_magi_register_custom_op.py index f50e668..adcec49 100644 --- a/magi_compiler/_magi_register_custom_op.py +++ b/magi_compiler/_magi_register_custom_op.py @@ -12,14 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Magi custom-op registration: dataclass-aware wrapper around ``torch.library``. - -This module implements ``magi_register_custom_op`` -- a decorator that takes a -plain Python function (possibly with frozen-dataclass parameters, -``Literal[str, ...]`` / string-Enum annotations, or other Python-rich -signatures that ``torch.library.infer_schema`` cannot consume) and registers -it as a real custom op while letting the user keep calling it with their -original, ergonomic signature. +""" +Magi custom-op registration: dataclass-aware wrapper around ``torch.library``. + +This module implements ``magi_register_custom_op`` -- a decorator that takes +a plain Python function and registers it as a real custom op while letting +the user keep calling it with their original signature. Part A. Registration-time pipeline -- the four slots diff --git a/magi_compiler/api.py b/magi_compiler/api.py index a712a5b..5821550 100644 --- a/magi_compiler/api.py +++ b/magi_compiler/api.py @@ -203,30 +203,19 @@ def magi_register_custom_op( is_compute_sensitive: bool = False, is_subgraph_boundary: bool = False, ): - """Register a Python function as a custom op for ``torch.library`` / ``torch.compile``. - - Combines ``@torch.library.custom_op`` + ``@torch.library.register_fake`` + - ``fn.register_autograd`` into one decorator, plus the following ergonomic - affordances on top of bare ``torch.library``: - - - **Frozen-dataclass parameters** (recursively nested) are flattened into - primitive leaves before being handed to ``infer_schema``, and reassembled - inside ``fn`` so the op body still sees the original dataclass. - - **Literal / string-Enum annotations** are auto-downgraded to ``str``; - the op body still receives the original value. - - **Unsupported defaults** (mutable, dataclass instances, ...) are scrubbed - from the lowered signature only; user-facing calls keep the original default. - - **Auto-generated op name** when ``name`` is omitted: derived from the - function's source file and ``__name__``. - - **Auto-generated meta function** when ``infer_output_meta_fn`` is omitted: - output ``i`` copies shape/dtype/device of the ``i``-th tensor input. + """ + A unified decorator to register a custom operator with PyTorch's library. + + It supports advanced features like frozen-dataclass param and combines the + functionality of: + - @torch.library.custom_op + - @torch.library.register_fake + - fn.register_autograd Arguments: - name: Fully qualified op name (e.g. ``"namespace::op_name"``). If ``None``, - auto-generated from the function's source file and name. - mutates_args: Argument names that the op mutates. For a frozen-dataclass - argument, listing the dataclass parameter expands to every Tensor leaf - under it; lowered leaf names (e.g. ``"cfg__weight"``) are also accepted. + name: Fully qualified op name (e.g. ``"namespace::op_name"``). If + ``None``, auto-generated from the function's source file and name. + mutates_args: Argument names that the op mutates. infer_output_meta_fn: How to propagate output metadata at trace time. - ``None`` (default): output ``i`` copies the ``i``-th tensor input. - ``list[str]``: parameter names whose metadata to copy. E.g. @@ -234,26 +223,18 @@ def magi_register_custom_op( and ``output[1]`` shape-match ``bias``. - ``Callable``: a function with the same signature as the op that returns ``torch.empty_like(...)`` tensors of the expected shapes. - setup_context_fn: Forward-context setup; signature - ``setup_context_fn(ctx, inputs, output)``. ``inputs`` is the - user-side (original-shape) tuple, including dataclass instances. + setup_context_fn: Function to save tensors/values for backward. + Signature: setup_context_fn(ctx, inputs, output) backward_fn: Gradient computation; signature - ``backward_fn(ctx, *grad_outputs) -> tuple of grads``. Return **one - grad per original parameter** (not per lowered leaf); use ``None`` - for non-differentiable parameters, including whole dataclass args. - is_compute_sensitive: Mark as compute-intensive. During activation - recomputation, outputs of compute-sensitive ops are prioritised for - saving rather than recomputing. + ``backward_fn(ctx, *grad_outputs) -> tuple of grads``. + is_compute_sensitive: marks this operator as compute-intensive (e.g., + MatMul, Attention). During training, outputs of compute-sensitive + ops are prioritised for saving rather than recomputing. is_subgraph_boundary: Split the FX graph at this op during compilation. Each sub-graph between boundary ops is compiled independently. Returns: A callable with the user's original signature. - - If ``fn`` has no dataclass parameter, returns a ``torch._ops.OpOverload`` - directly (zero per-call overhead). - - If ``fn`` has a frozen-dataclass parameter, returns a Python wrapper - that flattens/unflattens on each call and dispatches to the underlying - ``OpOverload`` (accessible via ``op._magi_torch_registered_op``). Examples: 1. Basic usage (forward only, auto-generated name and meta function): @@ -283,32 +264,7 @@ def magi_register_custom_op( ... >>> @magi_register_custom_op() ... def attn(q: torch.Tensor, k: torch.Tensor, cfg: AttnCfg) -> torch.Tensor: - ... scores = (q @ k.transpose(-1, -2)) * cfg.scale - ... if cfg.causal: - ... scores = scores.tril() - ... return scores - - 4. Full custom op with autograd support: - - >>> def _square_meta(x: torch.Tensor) -> torch.Tensor: - ... return torch.empty_like(x) - ... - >>> def _square_setup_context(ctx, inputs, output): - ... (x,) = inputs - ... ctx.save_for_backward(x) - ... - >>> def _square_backward(ctx, grad_output): - ... (x,) = ctx.saved_tensors - ... return grad_output * 2 * x - ... - >>> @magi_register_custom_op( - ... name="my_ops::square", - ... infer_output_meta_fn=_square_meta, - ... setup_context_fn=_square_setup_context, - ... backward_fn=_square_backward, - ... ) - ... def square(x: torch.Tensor) -> torch.Tensor: - ... return x * x + ... pass """ return _magi_register_custom_op_impl( name=name, From 6f9a17e897a0e10c9c2fe7b9d1ce5625bdf9bb58 Mon Sep 17 00:00:00 2001 From: Yunbo Zhang Date: Tue, 19 May 2026 12:14:50 +0800 Subject: [PATCH 3/6] Reconstruct _magi_register_custom_op.py --- magi_compiler/_magi_register_custom_op.py | 946 +++++++++++----------- 1 file changed, 470 insertions(+), 476 deletions(-) diff --git a/magi_compiler/_magi_register_custom_op.py b/magi_compiler/_magi_register_custom_op.py index adcec49..2fa1d7d 100644 --- a/magi_compiler/_magi_register_custom_op.py +++ b/magi_compiler/_magi_register_custom_op.py @@ -20,103 +20,36 @@ the user keep calling it with their original signature. -Part A. Registration-time pipeline -- the four slots -==================================================== - -When ``@magi_register_custom_op(...)`` is applied to a user function, up to -four named slots are produced. Each slot is a concrete callable object. - - slot 0 -- fn - The user's original function. Always present. - - slot 1 -- lowered_fn - A thin wrapper around ``fn`` whose ``__signature__`` / - ``__annotations__`` have been *lowered* (Literal/Enum -> str, - unsupported defaults scrubbed, dataclasses flattened into primitive - leaves) so that ``torch.library.infer_schema`` accepts it. - Skipped when ``fn``'s signature is already schema-compatible. - - slot 2 -- torch_registered_op - The ``OpOverload`` returned by ``torch.library.custom_op`` / - ``register_fake`` after registering whichever of ``fn`` / - ``lowered_fn`` reached this point. Always present. - - slot 3 -- magi_exposed_op - A magi-level Python wrapper around ``torch_registered_op`` that - preserves the user's ORIGINAL (dataclass-bearing) calling - convention. At call time it flattens incoming args via the static - ``param_mapping_tree`` and dispatches into slot 2. Only created - on the dataclass-flatten path. - -The naming is a deliberate dual: ``torch_registered_op`` is *registered -into* torch.library's dispatcher; ``magi_exposed_op`` is *exposed out of* -magi to user code. - - -Part B. Runtime paths -- the three pipelines -============================================ - -Three pipelines are possible; the decorator returns whichever object sits -at the end of the path: - - 1. simple fn -> torch_registered_op - Returned: ``torch._ops.OpOverload`` (slot 2). - Runtime: zero magi-level overhead -- straight into torch.library's - dispatcher. - - 2. sig-only-rewrite fn -> lowered_fn -> torch_registered_op - Returned: ``torch._ops.OpOverload`` (slot 2). - Runtime: same as simple -- ``lowered_fn`` is a transparent - forwarding shim (the rewrite is registration-time only). - - 3. dataclass-flatten fn -> lowered_fn -> torch_registered_op - -> magi_exposed_op - Returned: a Python callable carrying the - ``_magi_torch_registered_op`` attribute (slot 3). - Runtime forward (per call): - user code calls magi_exposed_op(x, cfg=...) - -> _flatten_call_args (original kwargs -> flat tuple) - -> _flatten_value_into (DFS over param_mapping_tree) - -> torch_registered_op(*flat) (slot 2 -- enters dispatcher) - -> lowered_fn(*flat) (slot 1 -- still in lowered shape) - -> _reassemble_kwargs (flat tuple -> original kwargs) - -> _build_value_from_node (rebuilds dataclass instances) - -> fn(**original_kwargs) (slot 0 -- user code finally sees - its original dataclass-bearing - signature) - Runtime backward (when backward_fn is supplied): - autograd calls _bridged_backward(ctx, *grads) - -> user_backward(ctx, *grads) (returns one grad per ORIGINAL - input, possibly a dataclass-shaped - grad object) - -> _flatten_grads (original grads -> flat grads) - -> _flatten_grad_into (DFS over param_mapping_tree) - -You can tell at runtime which pipeline an op went through by inspecting -the decorator's return value: an ``OpOverload`` means simple/sig-rewrite; -a Python callable carrying ``_magi_torch_registered_op`` means -dataclass-flatten. - - File layout =========== - -- registration-time helpers (executed once) -- - 1. Validate the user's fn signature - 2. Resolve types & sanitise defaults for infer_schema - 3. Build & query the param mapping tree (used by sec 4 and sec 7) - 4. Lower fn's signature (produces slot 1) - 5. Synthesise the meta/fake function (input to slot 2) - 6. Register the op (produces slot 2) +The file has five blocks. Each block groups its own helpers (private, +above) with the one core piece it exists to support (below). Block +boundaries follow the 5-stage pipeline. + + Block 0 -- VALIDATE op signature constraints (registration-time) + helpers: assertion predicates + validation primitives + core: _validate_op_signature_constraints - -- runtime helpers (executed on every call) -- - 7. Runtime bridge: flatten / unflatten on every call + Block 1 -- LOWER (registration-time) + helpers: type resolution, default scrubbing, param-mapping-tree construction + core: _lower_op_signature (produces slot 1) - -- main pipeline -- - 8. The decorator: orchestrates sec 1-6 and builds the runtime - closures from sec 7 (produces slot 3 on the flatten path) + Block 2 -- REGISTER (registration-time) + helpers: op-name generation, meta/fake-fn synthesis + core: _register_torch_op (produces slot 2) + + Block 3 -- RUNTIME ADAPTER (runtime) + helpers: flatten / unflatten primitives + signature-bound wrappers + core: _DataclassRuntimeAdapter (used by slot 3) + + Block 4 -- MAIN PIPELINE + core: _magi_register_custom_op_impl (the decorator; + orchestrates blocks 0-3, produces slot 3 on the flatten path) """ +from __future__ import annotations + import dataclasses import functools import inspect @@ -129,12 +62,16 @@ from .utils.logger import magi_logger # ============================================================================== -# 1. Validate the user's fn signature -# ------------------------------------------------------------------------------ -# Predicate + assert helpers that reject `fn` signatures torch.library cannot -# consume, each raising a clear `TypeError` instead of the opaque error that -# would otherwise surface deep inside `infer_schema`. Called from -# `_lower_op_signature` (sec 4) and `_build_dataclass_sub_mapping_tree` (sec 3). +# BLOCK 0 -- VALIDATE op signature constraints +# +# Helpers: +# - type predicates: +# _is_frozen_dataclass +# - assertion primitives: +# _assert_not_unsupported_container, _assert_not_dataclass_return, +# _assert_not_mutable_dataclass, _assert_has_annotation, +# _assert_no_var_args, _assert_resolved_field_type +# Core: _validate_op_signature_constraints # ============================================================================== @@ -220,23 +157,72 @@ def _assert_resolved_field_type(f_type, *, where: str) -> None: ) +def _validate_op_signature_constraints(fn: Callable) -> None: + """Validate fn parameters/return and recursively validate frozen dataclass subtrees.""" + original_sig = inspect.signature(fn) + resolved = _resolve_annotations(fn) + + def _validate_through( + params_or_fields, *, owner_name: str, is_fn_params: bool, field_types: dict[str, Any] | None = None + ) -> None: + for item in params_or_fields: + if is_fn_params: + name, param = item + _assert_no_var_args(param, fn_name=fn.__name__) + annotation = resolved.get(name, param.annotation) + where = f"parameter {name!r} of {owner_name}" + else: + field = item + name = field.name + annotation = (field_types or {}).get(name, field.type) + where = f"field {owner_name}.{name}" + _assert_resolved_field_type(annotation, where=where) + + _assert_has_annotation(annotation, where=where) + _assert_not_mutable_dataclass(annotation, where=where) + + if _is_frozen_dataclass(annotation): + _validate_through( + dataclasses.fields(annotation), + owner_name=annotation.__name__, + is_fn_params=False, + field_types=_resolve_dataclass_field_types(annotation), + ) + else: + _assert_not_unsupported_container(annotation, where=where) + + _validate_through(original_sig.parameters.items(), owner_name=f"{fn.__name__!r}", is_fn_params=True) + + return_annotation = resolved.get("return", original_sig.return_annotation) + _assert_has_annotation(return_annotation, where=f"return value of {fn.__name__!r}") + _assert_not_dataclass_return(return_annotation, fn_name=fn.__name__) + + # ============================================================================== -# 2. Resolve types & sanitise defaults for infer_schema -# ------------------------------------------------------------------------------ -# Resolve stringified annotations to real types, downgrade Literal/string-Enum -# to `str`, and scrub defaults that `infer_schema` cannot render. Called by -# `_lower_op_signature` (sec 4) and `_build_dataclass_sub_mapping_tree` (sec 3). +# BLOCK 1 -- LOWER fn signature +# +# Helpers: +# - type resolution & default scrubbing: +# _resolve_annotations, _resolve_dataclass_field_types, +# _maybe_downgrade_literal_or_enum, +# _schema_compatible_field_default, _schema_compatible_param_default +# - param-mapping-tree construction: +# _register_dataclass_pytree, +# _expand_mutates_args +# - lowered-signature utilities: +# _apply_lowered_signature, +# _make_lowered_signature_wrapper +# Core: _lower_op_signature # ============================================================================== -def _resolve_annotations(fn: Callable) -> dict[str, Any]: - """Return ``fn``'s annotations as real types, resolving stringified ones. +# ------------------------------------------------------------------------------ +# helpers: resolve types & sanitise defaults for infer_schema +# ------------------------------------------------------------------------------ - Falls back to per-annotation eval against ``globals + closure nonlocals`` - when ``get_type_hints`` can't resolve atomically (typical for functions - defined inside another function whose annotations reference enclosing - names). - """ + +def _resolve_annotations(fn: Callable) -> dict[str, Any]: + """Return ``fn`` annotations as real types, with globals+closure eval fallback.""" import typing try: @@ -277,20 +263,38 @@ def _resolve_annotations(fn: Callable) -> dict[str, Any]: def _resolve_dataclass_field_types(cls: type) -> dict[str, Any]: """Return ``cls``'s field name -> resolved type (best-effort).""" + import sys import typing as _typing try: return _typing.get_type_hints(cls) except Exception: - return {f.name: f.type for f in dataclasses.fields(cls)} + pass + # ``get_type_hints(cls)`` is all-or-nothing; fall back to per-field eval so + # one unresolved annotation does not poison the whole dataclass. + namespace: dict[str, Any] = {} + module = sys.modules.get(getattr(cls, "__module__", "")) + if module is not None: + namespace.update(vars(module)) + namespace.update(getattr(cls, "__dict__", {})) + namespace.setdefault(cls.__name__, cls) -def _maybe_downgrade_literal_or_enum(annotation, *, where: str): - """Collapse ``Literal[str, ...]`` and string-Enum annotations to plain ``str``. + resolved: dict[str, Any] = {} + for f in dataclasses.fields(cls): + tp = f.type + if isinstance(tp, str): + try: + resolved[f.name] = eval(tp, namespace, None) + except Exception: + resolved[f.name] = tp + else: + resolved[f.name] = tp + return resolved - Lossless because the op body still receives the original string value. - Mixed/numeric Literals and non-string Enums raise (no safe downgrade). - """ + +def _maybe_downgrade_literal_or_enum(annotation, *, where: str): + """Downgrade ``Literal[str,...]`` and string-valued Enums to ``str`` or raise.""" import enum import typing @@ -321,23 +325,8 @@ def _maybe_downgrade_literal_or_enum(annotation, *, where: str): _SCHEMA_DEFAULT_TYPES: tuple[type, ...] = (int, float, bool, str, torch.device, torch.dtype) -def _schema_compatible_param_default(default: Any) -> Any: - """Scrub a top-level parameter default that ``infer_schema`` cannot render. - - Same rules as :func:`_schema_compatible_default`, but for raw values - rather than ``dataclasses.Field``. - """ - if default is inspect.Parameter.empty: - return inspect.Parameter.empty - if default is None or isinstance(default, _SCHEMA_DEFAULT_TYPES): - return default - return inspect.Parameter.empty - - -def _schema_compatible_default(f: "dataclasses.Field") -> Any: - """Lowered default for dataclass field ``f``: keep ``None`` / int / float / - bool / str / torch.device / torch.dtype; drop everything else (the user- - constructed dataclass instance still carries the real default at runtime).""" +def _schema_compatible_field_default(f: "dataclasses.Field") -> Any: + """Return schema-safe field default (including resolved ``default_factory``) or empty.""" if f.default is not dataclasses.MISSING: d = f.default if d is None or isinstance(d, _SCHEMA_DEFAULT_TYPES): @@ -354,21 +343,24 @@ def _schema_compatible_default(f: "dataclasses.Field") -> Any: return inspect.Parameter.empty -# ============================================================================== -# 3. Build & query the param mapping tree +def _schema_compatible_param_default(default: Any) -> Any: + """Return schema-safe parameter default or ``inspect.Parameter.empty``.""" + if default is inspect.Parameter.empty: + return inspect.Parameter.empty + if default is None or isinstance(default, _SCHEMA_DEFAULT_TYPES): + return default + return inspect.Parameter.empty + + +# ------------------------------------------------------------------------------ +# helpers: build & query the param mapping tree # ------------------------------------------------------------------------------ -# The `param_mapping_tree` is the single source of truth bridging the user's -# (possibly nested-dataclass) signature and the lowered primitive signature. -# Built once at registration time and consumed twice afterwards: by -# `_expand_mutates_args` (statically, sec 6) and by the runtime bridge (sec 7). -# ============================================================================== _DATACLASS_PYTREE_REGISTERED: set[type] = set() def _register_dataclass_pytree(cls: type) -> None: - """Register ``cls`` as a pytree node (idempotent) so Dynamo / AOTAutograd - can flatten and unflatten dataclass instances during tracing.""" + """Idempotently register dataclass ``cls`` as a pytree node for tracing.""" if cls in _DATACLASS_PYTREE_REGISTERED: return @@ -388,80 +380,24 @@ def _unflatten(values, ctx): _DATACLASS_PYTREE_REGISTERED.add(cls) -def _build_dataclass_sub_mapping_tree(cls: type, attr_name: str, flat_prefix: str) -> tuple[tuple, list[inspect.Parameter]]: - """Recursively expand a frozen-dataclass type into one ``param_mapping_tree`` - subtree plus its flat list of leaf ``inspect.Parameter`` objects (DFS order). - - ``attr_name`` is the field name on the parent dataclass (or the parameter - name on ``fn`` for a top-level dataclass arg). ``flat_prefix`` builds the - leaf parameter name; e.g. ``cfg: OuterCfg(inner: InnerCfg(val: float))`` - becomes a lowered leaf parameter ``cfg__inner__val``. - """ - _register_dataclass_pytree(cls) +def _expand_mutates_args(param_mapping_tree: list[tuple], mutates_args: tuple[str, ...] | list[str]) -> tuple[str, ...]: + """Expand/validate ``mutates_args`` from original names into lowered leaf names.""" - field_types = _resolve_dataclass_field_types(cls) - children: list[tuple] = [] - lowered_params: list[inspect.Parameter] = [] - - for f in dataclasses.fields(cls): - f_type = field_types.get(f.name, f.type) - child_flat_name = f"{flat_prefix}__{f.name}" - _assert_has_annotation(f_type, where=f"field {cls.__name__}.{f.name}") - _assert_resolved_field_type(f_type, where=f"field {cls.__name__}.{f.name}") - _assert_not_mutable_dataclass(f_type, where=f"field {cls.__name__}.{f.name}") - if _is_frozen_dataclass(f_type): - sub_node, sub_params = _build_dataclass_sub_mapping_tree(f_type, attr_name=f.name, flat_prefix=child_flat_name) - children.append(sub_node) - lowered_params.extend(sub_params) - else: - _assert_not_unsupported_container(f_type, where=f"field {cls.__name__}.{f.name}") - f_type = _maybe_downgrade_literal_or_enum(f_type, where=f"field {cls.__name__}.{f.name}") - children.append(("primitive", f.name, child_flat_name, None)) - lowered_params.append( - inspect.Parameter( - child_flat_name, - # POSITIONAL_OR_KEYWORD: torch.library.custom_op does not yet - # support kwarg-only Tensor arguments. - inspect.Parameter.POSITIONAL_OR_KEYWORD, - default=_schema_compatible_default(f), - annotation=f_type, - ) - ) + def _collect_tensor_leaf_lowered_attr_names(node: tuple) -> list[str]: + if node[0] == "primitive": + _, _attr, lowered_attr_name, _ = node + return [lowered_attr_name] + out: list[str] = [] + for child in node[3]: + out.extend(_collect_tensor_leaf_lowered_attr_names(child)) + return out - return ("dataclass", attr_name, cls, children), lowered_params - - -def _count_leaves(node: tuple) -> int: - """Number of lowered parameter slots a ``param_mapping_tree`` ``node`` occupies.""" - if node[0] == "primitive": - return 1 - return sum(_count_leaves(c) for c in node[3]) - - -def _collect_tensor_leaf_lowered_names(node: tuple) -> list[str]: - """Lowered names of every leaf under ``node``. Used to expand a dataclass - parameter referenced in ``mutates_args`` (torch.library does its own - Tensor-type validation, so over-specifying non-Tensor leaves is fine).""" - if node[0] == "primitive": - _, _attr, lowered_name, _ = node - return [lowered_name] - out: list[str] = [] - for child in node[3]: - out.extend(_collect_tensor_leaf_lowered_names(child)) - return out - - -def _expand_mutates_args(mutates_args: tuple[str, ...] | list[str], param_mapping_tree: list[tuple]) -> tuple[str, ...]: - """Translate ``mutates_args`` from the original parameter space to the - lowered space: top-level dataclass names expand to all their leaves; - primitive top-level names and already-lowered names pass through; unknown - names raise ``ValueError`` listing valid choices.""" if not mutates_args: return tuple(mutates_args) by_attr: dict[str, tuple] = {node[1]: node for node in param_mapping_tree} valid_lowered: set[str] = set() for node in param_mapping_tree: - valid_lowered.update(_collect_tensor_leaf_lowered_names(node)) + valid_lowered.update(_collect_tensor_leaf_lowered_attr_names(node)) out: list[str] = [] for name in mutates_args: if name in by_attr: @@ -469,7 +405,7 @@ def _expand_mutates_args(mutates_args: tuple[str, ...] | list[str], param_mappin if node[0] == "primitive": out.append(node[2]) else: - out.extend(_collect_tensor_leaf_lowered_names(node)) + out.extend(_collect_tensor_leaf_lowered_attr_names(node)) elif name in valid_lowered: out.append(name) else: @@ -487,31 +423,13 @@ def _expand_mutates_args(mutates_args: tuple[str, ...] | list[str], param_mappin return tuple(deduped) -# ============================================================================== -# 4. Lower fn's signature (produces slot 1) # ------------------------------------------------------------------------------ -# Produces slot 1 (`lowered_fn`) in two stages: -# data: `_lower_op_signature` walks `fn`'s parameters once and emits -# `(original_sig, lowered_sig, param_mapping_tree)`, calling into -# sec 1 (validate), sec 2 (resolve/sanitise), sec 3 (tree build). -# object: `_make_lowered_signature_wrapper` stamps the lowered signature -# onto a forwarding wrapper of `fn`. -# `_signatures_differ` lets the decorator (sec 6) skip the wrapper entirely -# when lowering was a no-op (zero-overhead path). -# ============================================================================== - - -def _signatures_differ(original: inspect.Signature, lowered: inspect.Signature) -> bool: - """True iff ``lowered`` differs from ``original`` on parameter names, - annotations, defaults, kinds, or return annotation. The decorator uses - this to skip slot 1 entirely when lowering was a no-op (zero-overhead path).""" - return original != lowered +# core: _lower_op_signature (and its lowered-signature wrapper utilities) +# ------------------------------------------------------------------------------ -def _apply_lowered_signature_metadata(wrapper: Callable, lowered_sig: inspect.Signature) -> None: - """In-place: stamp ``wrapper`` with ``lowered_sig`` as its ``__signature__`` - / ``__annotations__``, and strip ``__wrapped__`` so ``inspect.signature`` - cannot fall back to the original (un-lowered) signature on ``fn``.""" +def _apply_lowered_signature(lowered_sig: inspect.Signature, wrapper: Callable) -> None: + """Stamp wrapper signature/annotations with ``lowered_sig`` and clear ``__wrapped__``.""" wrapper.__signature__ = lowered_sig lowered_annotations = { p.name: p.annotation for p in lowered_sig.parameters.values() if p.annotation is not inspect.Parameter.empty @@ -527,96 +445,129 @@ def _apply_lowered_signature_metadata(wrapper: Callable, lowered_sig: inspect.Si pass -def _make_lowered_signature_wrapper(fn: Callable, lowered_sig: inspect.Signature) -> Callable: - """Forwarding wrapper around ``fn`` carrying ``lowered_sig`` as metadata. - Used on the no-flattening path so ``infer_schema`` sees the cleaned-up - signature instead of ``fn``'s original annotations.""" - - @functools.wraps(fn) - def _wrapped(*args, **kwargs): - return fn(*args, **kwargs) - - _apply_lowered_signature_metadata(_wrapped, lowered_sig) - return _wrapped - - def _lower_op_signature(fn: Callable): """Lower ``fn``'s signature into a form ``torch.library.infer_schema`` accepts. - "Lower" is used in the compiler sense (high-level -> low-level): we walk - ``fn``'s parameters once and do six things at the same time -- they all - need the same resolved annotations and the same iteration: - - 1. VALIDATE -- reject variadics, missing annotations, mutable dataclasses, - unsupported containers, dataclass returns (sec 1). - 2. RESOLVE -- turn stringified annotations into real types via - ``_resolve_annotations``, so dataclass detection works. + Pipeline: + 1. RESOLVE -- turn stringified annotations/dataclass fields into real types (best-effort). + If failed, trying globals+closure evalining per annotation/dataclass field. + 2. FLATTEN -- recursively flatten frozen dataclass parameters into primitive leaves + and register the dataclass types as pytree nodes for Dynamo/AOTAutograd tracing. 3. NORMALIZE -- collapse parameter kinds to POSITIONAL_OR_KEYWORD, downgrade Literal/Enum to ``str``, scrub unsupported defaults. - 4. FLATTEN -- expand each frozen-dataclass parameter (recursively) into - its primitive leaves via ``_build_dataclass_sub_mapping_tree``. - 5. PYTREE -- side effect of step 4: register every dataclass as a pytree - node so Dynamo / AOTAutograd can trace through it. - 6. EMIT -- assemble ``(original_sig, lowered_sig, param_mapping_tree)``. - - A single pass is intentional: splitting concerns would force re-resolving - annotations and threading accumulator state. When the input is already - schema-compatible the lowered signature is bit-identical to the original, - and the caller's ``_signatures_differ`` check restores the zero-overhead path. + 4. EMIT -- assemble ``(original_sig, lowered_sig, param_mapping_tree)``. Returns: - original_sig (inspect.Signature): the user's un-flattened signature. - lowered_sig (inspect.Signature): what ``infer_schema`` will see. + original_sig (inspect.Signature): the user's original input signature. + lowered_sig (inspect.Signature): schema-compatible signature for ``infer_schema``. param_mapping_tree (list[tuple]): the bridge between the two; a list of root nodes (one per original parameter), each of which is: - * ``("primitive", attr_name, lowered_name, None)``, or - * ``("dataclass", attr_name, cls, [child_nodes...])``. - ``attr_name`` is the parameter name at top level / field name - deeper down. The same tree drives both runtime translation - directions (sec 7). + * ``("primitive", attr_name, lowered_attr_name, None)``, or + * ``("dataclass", attr_name, dataclass_cls_type, [child_nodes...])``. + + Example: + ``fn(x: Tensor, cfg: Outer(inner: Inner(scale: float, bias: Tensor), mode: str))`` + -> ``original_sig(x: Tensor, cfg: Outer)`` + -> ``lowered_sig(x: Tensor, cfg__inner__scale: float, cfg__inner__bias: Tensor, cfg__mode: str)`` + -> ``param_mapping_tree = [("primitive", "x", "x", None), + ("dataclass", "cfg", Outer, [ + ("dataclass", "inner", Inner, [ + ("primitive", "scale", "cfg__inner__scale", None), + ("primitive", "bias", "cfg__inner__bias", None), + ]), + ("primitive", "mode", "cfg__mode", None), + ])]`` """ + original_sig = inspect.signature(fn) resolved = _resolve_annotations(fn) - lowered_params: list[inspect.Parameter] = [] - param_mapping_tree: list[tuple] = [] - - for name, param in original_sig.parameters.items(): - _assert_no_var_args(param, fn_name=fn.__name__) - annotation = resolved.get(name, param.annotation) - _assert_has_annotation(annotation, where=f"parameter {name!r} of {fn.__name__!r}") - _assert_not_mutable_dataclass(annotation, where=f"parameter {name!r}") - if _is_frozen_dataclass(annotation): - node, sub_params = _build_dataclass_sub_mapping_tree(annotation, attr_name=name, flat_prefix=name) - param_mapping_tree.append(node) - lowered_params.extend(sub_params) + + def _lower_through( + params_or_fields, + *, + is_fn_params: bool, + owner_name: str, + flat_prefix: str | None = None, + field_types: dict[str, Any] | None = None, + ) -> tuple[list[tuple], list[inspect.Parameter]]: + nodes: list[tuple] = [] + lowered: list[inspect.Parameter] = [] + + if is_fn_params: + iterator = ((name, resolved.get(name, param.annotation), param) for name, param in params_or_fields) else: - _assert_not_unsupported_container(annotation, where=f"parameter {name!r}") - annotation = _maybe_downgrade_literal_or_enum(annotation, where=f"parameter {name!r}") - new_param = param.replace( - kind=inspect.Parameter.POSITIONAL_OR_KEYWORD, - annotation=annotation, - default=_schema_compatible_param_default(param.default), - ) - lowered_params.append(new_param) - param_mapping_tree.append(("primitive", name, name, None)) + resolved_fields = field_types or {} + iterator = ((field.name, resolved_fields.get(field.name, field.type), field) for field in params_or_fields) + + for name, annotation, source in iterator: + leaf_flat_name = name if is_fn_params else f"{flat_prefix}__{name}" + where = f"parameter {name!r}" if is_fn_params else f"field {owner_name}.{name}" + + if _is_frozen_dataclass(annotation): + _register_dataclass_pytree(annotation) + child_nodes, child_params = _lower_through( + dataclasses.fields(annotation), + is_fn_params=False, + owner_name=annotation.__name__, + flat_prefix=leaf_flat_name, + field_types=_resolve_dataclass_field_types(annotation), + ) + nodes.append(("dataclass", name, annotation, child_nodes)) + lowered.extend(child_params) + else: + annotation = _maybe_downgrade_literal_or_enum(annotation, where=where) + if is_fn_params: + param = source + lowered.append( + param.replace( + kind=inspect.Parameter.POSITIONAL_OR_KEYWORD, + annotation=annotation, + default=_schema_compatible_param_default(param.default), + ) + ) + else: + field = source + lowered.append( + inspect.Parameter( + leaf_flat_name, + # POSITIONAL_OR_KEYWORD: torch.library.custom_op does not yet + # support kwarg-only Tensor arguments. + inspect.Parameter.POSITIONAL_OR_KEYWORD, + annotation=annotation, + default=_schema_compatible_field_default(field), + ) + ) + nodes.append(("primitive", name, leaf_flat_name, None)) + + return nodes, lowered + + param_mapping_tree, lowered_params = _lower_through( + original_sig.parameters.items(), is_fn_params=True, owner_name=f"{fn.__name__!r}" + ) return_annotation = resolved.get("return", original_sig.return_annotation) - _assert_has_annotation(return_annotation, where=f"return value of {fn.__name__!r}") - _assert_not_dataclass_return(return_annotation, fn_name=fn.__name__) lowered_sig = inspect.Signature(lowered_params, return_annotation=return_annotation) return original_sig, lowered_sig, param_mapping_tree # ============================================================================== -# 5. Synthesise the meta/fake function (input to slot 2) -# ------------------------------------------------------------------------------ -# Constructors for the meta ("fake") function torch.library uses for shape -# propagation during tracing. Fallbacks: identity meta when the user passes -# no `infer_output_meta_fn`; param-name-echoing meta when they pass a -# `list[str]`. The result is handed to `register_fake` by sec 6. +# BLOCK 2 -- REGISTER torch op +# +# Helpers: +# - meta/fake-fn synthesis: +# _get_num_outputs_from_return_annotation, +# _create_identity_meta_fn, _create_meta_fn_from_param_names +# - op-name generation: +# _generate_op_name +# Core: _register_torch_op # ============================================================================== +# ------------------------------------------------------------------------------ +# helpers: synthesise the meta/fake function +# ------------------------------------------------------------------------------ + + def _get_num_outputs_from_return_annotation(fn: Callable) -> int: """Output count from ``fn``'s return annotation: ``N`` for ``tuple[T1, ..., TN]``, else ``1`` (default / unrecognized).""" @@ -704,15 +655,9 @@ def meta_fn(*args, **kwargs): return meta_fn -# ============================================================================== -# 6. Register the op (produces slot 2) # ------------------------------------------------------------------------------ -# `_register_torch_op` calls `custom_op` + `register_fake`, yielding slot 2 -# (`torch_registered_op`). `_generate_op_name` derives a default op name -# from the user's `fn` when one isn't supplied. The orchestrator that -# stitches these together with sec 1-5 (and the runtime closures from -# sec 7) lives in sec 8. -# ============================================================================== +# helpers: generate op name +# ------------------------------------------------------------------------------ def _generate_op_name(fn: Callable) -> str: @@ -732,64 +677,102 @@ def _generate_op_name(fn: Callable) -> str: return f"{namespace}::{func_name}" -def _register_torch_op(op_name: str, fn: Callable, mutates_args: tuple[str, ...], meta_fn: Callable): - """``custom_op`` + ``register_fake``; returns slot 2 (``torch_registered_op``).""" - torch_registered_op = torch.library.custom_op(op_name, mutates_args=mutates_args)(fn) +# ------------------------------------------------------------------------------ +# core: _register_torch_op +# +# Forward reference: ``_DataclassRuntimeAdapter`` using ``from __future__ import annotations``. +# ------------------------------------------------------------------------------ + + +def _register_torch_op( + op_name: str, + fn: Callable, + mutates_args: tuple[str, ...], + infer_output_meta_fn: Callable | list[str] | None, + setup_context_fn: Callable | None, + backward_fn: Callable | None, + dataclass_runtime_adapter: _DataclassRuntimeAdapter | None = None, +): + """Register the op in torch.library.custom_op.""" + effective_mutates_args = ( + dataclass_runtime_adapter.expand_mutates_args(mutates_args) if dataclass_runtime_adapter is not None else mutates_args + ) + torch_registered_op = torch.library.custom_op(op_name, mutates_args=effective_mutates_args)(fn) + + # Build & register the meta/fake function. + if infer_output_meta_fn is None: + meta_fn = _create_identity_meta_fn(fn) + elif isinstance(infer_output_meta_fn, list): + meta_fn = _create_meta_fn_from_param_names(fn, infer_output_meta_fn) + elif dataclass_runtime_adapter is None: # No flattening scenario + meta_fn = infer_output_meta_fn + else: # Flattening scenario + user_meta = infer_output_meta_fn + + def _bridged_meta_fn(*args, **kwargs): + return user_meta(**dataclass_runtime_adapter.unflatten_call_args(args, kwargs)) + + _bridged_meta_fn.__signature__ = inspect.signature(fn) + meta_fn = _bridged_meta_fn torch.library.register_fake(op_name)(meta_fn) - return torch_registered_op + # Register autograd. + if backward_fn is not None: + if dataclass_runtime_adapter is None: # No flattening scenario + torch_registered_op.register_autograd(backward_fn, setup_context=setup_context_fn) + else: # Flattening scenario -# ============================================================================== -# 7. Runtime bridge: flatten / unflatten on every call -# ------------------------------------------------------------------------------ -# Executed on every call to slot 1 (`lowered_fn`) or slot 3 (`magi_exposed_op`), -# consuming the static `param_mapping_tree` from sec 3 to translate between -# the original (dataclass) and lowered (primitive) call shapes. See Part B -# of the module docstring for the full call-stack picture. -# original -> lowered: `_flatten_value_into`, `_flatten_call_args` -# lowered -> original: `_build_value_from_node`, `_reassemble_kwargs` -# grad bridge: `_flatten_grad_into`, `_flatten_grads` -# ============================================================================== + def _bridged_setup_context(ctx, inputs, output): + if setup_context_fn is None: + return None + original_inputs = dataclass_runtime_adapter.unflatten_setup_ctx_inputs(inputs) + return setup_context_fn(ctx, original_inputs, output) + def _bridged_backward(ctx, *grads): + original_grads = backward_fn(ctx, *grads) + if not isinstance(original_grads, tuple): + original_grads = (original_grads,) + return dataclass_runtime_adapter.flatten_input_grads(original_grads) -def _build_value_from_node(node: tuple, lowered_kwargs: dict): - """``lowered_kwargs`` -> one original-shaped value (recursive).""" - kind = node[0] - if kind == "primitive": - _, _attr, lowered_name, _ = node - return lowered_kwargs[lowered_name] - _, _attr, cls, children = node - init_kwargs: dict[str, Any] = {} - for child in children: - field_name = child[1] - init_kwargs[field_name] = _build_value_from_node(child, lowered_kwargs) - return cls(**init_kwargs) + torch_registered_op.register_autograd(_bridged_backward, setup_context=_bridged_setup_context) + return torch_registered_op -def _reassemble_kwargs(param_mapping_tree: list[tuple], lowered_kwargs: dict) -> dict: - """``lowered_kwargs`` -> original kwargs (the ``lowered -> original`` walk).""" - out: dict[str, Any] = {} - for node in param_mapping_tree: - out[node[1]] = _build_value_from_node(node, lowered_kwargs) - return out +# ============================================================================== +# BLOCK 3 -- RUNTIME ADAPTER +# +# Helpers (adapter field <- bound function): +# original -> lowered: flatten_call_args <- _flatten_call_args +# flatten_input_grads <- _flatten_input_grads +# lowered -> original: unflatten_call_args <- _unflatten_call_args +# unflatten_setup_ctx_inputs <- _unflatten_setup_ctx_inputs +# _reassemble_kwargs (internal primitive) +# mutates_args expand: expand_mutates_args <- _expand_mutates_args +# signature stamping: apply_lowered_signature <- _apply_lowered_signature +# Core: _DataclassRuntimeAdapter +# ============================================================================== -def _flatten_value_into(node: tuple, value: Any, out: list) -> None: - """Append leaves of ``value`` to ``out`` in DFS order (no isinstance check - on ``cls`` -- duck-typed via ``getattr`` so mocks / SimpleNamespace work).""" - kind = node[0] - if kind == "primitive": - out.append(value) - return - _, _attr, cls, children = node - for child in children: - field_name = child[1] - _flatten_value_into(child, getattr(value, field_name), out) + +# ---- flatten_call_args ---- def _flatten_call_args(param_mapping_tree: list[tuple], original_sig: inspect.Signature, args: tuple, kwargs: dict) -> list: """User-side call -> flat positional list matching the lowered signature (the ``original -> lowered`` walk).""" + + def _flatten_value_into(node: tuple, value: Any, out: list) -> None: + """Append leaves of ``value`` to ``out`` in DFS order (no isinstance check + on ``cls`` -- duck-typed via ``getattr`` so mocks / SimpleNamespace work).""" + kind = node[0] + if kind == "primitive": + out.append(value) + return + _, _attr, _cls, children = node + for child in children: + field_name = child[1] + _flatten_value_into(child, getattr(value, field_name), out) + bound = original_sig.bind(*args, **kwargs) bound.apply_defaults() flat: list = [] @@ -798,35 +781,37 @@ def _flatten_call_args(param_mapping_tree: list[tuple], original_sig: inspect.Si return flat -def _flatten_grad_into(node: tuple, grad: Any, out: list) -> None: - """Spread a user-returned grad across the lowered slots of one original input. +# ---- flatten_input_grads ---- - ``primitive`` -> append ``grad`` as-is. ``dataclass`` -> if ``grad`` is - ``None`` fill every leaf with ``None`` (the common whole-dataclass-not- - differentiable case); otherwise descend with ``dict``-aware lookup so - users can return dict / SimpleNamespace / dataclass-shaped objects. - """ - if node[0] == "primitive": - out.append(grad) - return - _, _attr, _cls, children = node - if grad is None: - for child in children: - for _ in range(_count_leaves(child)): - out.append(None) - return - is_mapping = isinstance(grad, dict) - for child in children: - field_name = child[1] - if is_mapping: - sub = grad.get(field_name) - else: - sub = getattr(grad, field_name, None) - _flatten_grad_into(child, sub, out) +def _flatten_input_grads(param_mapping_tree: list[tuple], original_grads: tuple) -> tuple: + """Original-space input grads -> lowered-space input grads.""" + + def _count_leaves(node: tuple) -> int: + if node[0] == "primitive": + return 1 + return sum(_count_leaves(c) for c in node[3]) + + def _flatten_grad_into(node: tuple, grad: Any, out: list) -> None: + """Spread a user-returned grad across lowered slots for one original input.""" + if node[0] == "primitive": + out.append(grad) + return + _, _attr, _cls, children = node + if grad is None: + for child in children: + for _ in range(_count_leaves(child)): + out.append(None) + return + is_mapping = isinstance(grad, dict) + for child in children: + field_name = child[1] + if is_mapping: + sub = grad.get(field_name) + else: + sub = getattr(grad, field_name, None) + _flatten_grad_into(child, sub, out) -def _flatten_grads(param_mapping_tree: list[tuple], original_grads: tuple | list) -> list: - """Grads keyed by original-parameter order -> grads keyed by lowered order.""" if len(original_grads) != len(param_mapping_tree): raise ValueError( f"@magi_register_custom_op: backward_fn returned {len(original_grads)} " @@ -837,18 +822,67 @@ def _flatten_grads(param_mapping_tree: list[tuple], original_grads: tuple | list flat: list = [] for node, g in zip(param_mapping_tree, original_grads): _flatten_grad_into(node, g, flat) - return flat + return tuple(flat) + + +# ---- unflatten_call_args / unflatten_setup_ctx_inputs ---- + + +def _reassemble_kwargs(param_mapping_tree: list[tuple], lowered_kwargs: dict) -> dict: + """``lowered_kwargs`` -> original kwargs (the ``lowered -> original`` walk).""" + + def _build_value_from_node(node: tuple): + kind = node[0] + if kind == "primitive": + _, _attr, lowered_attr_name, _ = node + return lowered_kwargs[lowered_attr_name] + _, _attr, cls, children = node + init_kwargs: dict[str, Any] = {} + for child in children: + field_name = child[1] + init_kwargs[field_name] = _build_value_from_node(child) + return cls(**init_kwargs) + + out: dict[str, Any] = {} + for node in param_mapping_tree: + out[node[1]] = _build_value_from_node(node) + return out + + +def _unflatten_call_args(lowered_sig: inspect.Signature, param_mapping_tree: list[tuple], args: tuple, kwargs: dict) -> dict: + """Lowered call args/kwargs -> original kwargs (dict for ``fn(**dict)``).""" + bound = lowered_sig.bind(*args, **kwargs) + bound.apply_defaults() + return _reassemble_kwargs(param_mapping_tree, bound.arguments) + + +def _unflatten_setup_ctx_inputs( + lowered_sig: inspect.Signature, original_sig: inspect.Signature, param_mapping_tree: list[tuple], inputs: tuple +) -> tuple: + """Lowered positional inputs tuple -> original positional inputs tuple + (for ``setup_context_fn(ctx, inputs, output)``).""" + lowered_kwargs = {p.name: v for p, v in zip(lowered_sig.parameters.values(), inputs)} + original_kwargs = _reassemble_kwargs(param_mapping_tree, lowered_kwargs) + return tuple(original_kwargs[p] for p in original_sig.parameters) + + +# ---- core ---- + + +@dataclasses.dataclass(frozen=True) +class _DataclassRuntimeAdapter: + """Runtime conversion adapter.""" + + flatten_call_args: Callable[[tuple, dict], list] + flatten_input_grads: Callable[[tuple], tuple] + unflatten_call_args: Callable[[tuple, dict], dict] + unflatten_setup_ctx_inputs: Callable[[tuple], tuple] + expand_mutates_args: Callable[[tuple[str, ...]], tuple[str, ...]] + apply_lowered_signature: Callable[[Callable], None] # ============================================================================== -# 8. The decorator: main pipeline (produces slot 3) -# ------------------------------------------------------------------------------ -# The single public entry point. Its inner `decorator` closure orchestrates the -# full 4-slot pipeline (see module docstring): it calls sec 4 to lower the user's -# signature (slot 1), sec 5 to synthesise the meta function, sec 6 to register -# the op with torch.library (slot 2), and -- on the dataclass-flatten path -- -# additionally builds the user-facing wrapper (slot 3) plus the runtime closures -# that drive sec 7 on each call. +# BLOCK 4 -- MAIN PIPELINE # ============================================================================== @@ -862,8 +896,7 @@ def _magi_register_custom_op_impl( is_subgraph_boundary: bool = False, ): def decorator(fn: Callable) -> Callable: - # See the module docstring for the 4-slot pipeline / 3-runtime-path - # picture; the body below just walks slot 0 -> 1 -> 2 (-> 3 if needed). + # A 4-slot pipeline. op_name = name if name is not None else _generate_op_name(fn) if is_compute_sensitive: @@ -871,123 +904,84 @@ def decorator(fn: Callable) -> Callable: if is_subgraph_boundary: get_compile_config().splitting_ops.append(op_name) - # Dataclass parameters are the only thing forcing slot 3; other lowering - # (Literal/Enum/default scrub) is handled by slot 1 alone at zero - # per-call cost. + _validate_op_signature_constraints(fn) original_sig, lowered_sig, param_mapping_tree = _lower_op_signature(fn) needs_flattening = any(kind == "dataclass" for kind, *_ in param_mapping_tree) if not needs_flattening: - # ----- No-flattening path: fn -> [lowered_fn?] -> torch_registered_op ----- - # Step 1 (slot 1): only when the lowering actually rewrote the - # signature -- otherwise register ``fn`` directly (zero-overhead). - if _signatures_differ(original_sig, lowered_sig): - lowered_fn = _make_lowered_signature_wrapper(fn, lowered_sig) - fn_to_register = lowered_fn - else: + # ----- No-flattening scenario ----- + # Path: fn -> [lowered_fn ->] torch_registered_op + + # Step 1: Build ``lowered_fn`` iff the signature was rewritten. + if original_sig == lowered_sig: fn_to_register = fn + else: # Signatures differ, need to wrap the function - # Step 2: meta/fake function. - if infer_output_meta_fn is None: - meta_fn = _create_identity_meta_fn(fn_to_register) - elif isinstance(infer_output_meta_fn, list): - meta_fn = _create_meta_fn_from_param_names(fn_to_register, infer_output_meta_fn) - else: - meta_fn = infer_output_meta_fn + @functools.wraps(fn) + def lowered_fn(*args, **kwargs): + return fn(*args, **kwargs) + + _apply_lowered_signature(lowered_sig, lowered_fn) + fn_to_register = lowered_fn - # Step 3 (slot 2): custom_op + register_fake. + # Step 2: Register the op in torch and get ``torch_registered_op``. torch_registered_op = _register_torch_op( - op_name=op_name, fn=fn_to_register, mutates_args=mutates_args, meta_fn=meta_fn + op_name=op_name, + fn=fn_to_register, + mutates_args=mutates_args, + infer_output_meta_fn=infer_output_meta_fn, + setup_context_fn=setup_context_fn, + backward_fn=backward_fn, + dataclass_runtime_adapter=None, ) - # Step 4: autograd. - if backward_fn is not None: - torch_registered_op.register_autograd(backward_fn, setup_context=setup_context_fn) - - # No slot 3 needed: the user's calling convention already matches - # the lowered one, so ``torch_registered_op`` is itself returned. + # Return bare torch-level op (slot 2). return torch_registered_op else: - # ----- Flattening path: fn -> lowered_fn -> torch_registered_op -> magi_exposed_op ----- - # Step 1 (slot 1): build the lowered-signature bridge. ``lowered_fn`` - # speaks the flat primitive signature; it rebinds args, reassembles - # dataclasses, then dispatches to the user's ``fn``. - def _bind_to_original_kwargs(args, kwargs): - bound = lowered_sig.bind(*args, **kwargs) - bound.apply_defaults() - return _reassemble_kwargs(param_mapping_tree, bound.arguments) + # ----- Flattening scenario ----- + # Path: fn -> lowered_fn -> torch_registered_op -> magi_exposed_op + + # Step 0 (only in the flattening scenario): Build the scenario-wide adapter. + dataclass_runtime_adapter = _DataclassRuntimeAdapter( + flatten_call_args=functools.partial(_flatten_call_args, param_mapping_tree, original_sig), + flatten_input_grads=functools.partial(_flatten_input_grads, param_mapping_tree), + unflatten_call_args=functools.partial(_unflatten_call_args, lowered_sig, param_mapping_tree), + unflatten_setup_ctx_inputs=functools.partial( + _unflatten_setup_ctx_inputs, lowered_sig, original_sig, param_mapping_tree + ), + expand_mutates_args=functools.partial(_expand_mutates_args, param_mapping_tree), + apply_lowered_signature=functools.partial(_apply_lowered_signature, lowered_sig), + ) + # Step 1: Build ``lowered_fn``. @functools.wraps(fn) def lowered_fn(*args, **kwargs): - return fn(**_bind_to_original_kwargs(args, kwargs)) + return fn(**dataclass_runtime_adapter.unflatten_call_args(args, kwargs)) - _apply_lowered_signature_metadata(lowered_fn, lowered_sig) + dataclass_runtime_adapter.apply_lowered_signature(lowered_fn) - # Step 2: meta/fake function. User-supplied meta_fn is bridged so - # it sees the original (dataclass-bearing) signature it was - # written against. - if infer_output_meta_fn is None: - meta_fn = _create_identity_meta_fn(lowered_fn) - elif isinstance(infer_output_meta_fn, list): - meta_fn = _create_meta_fn_from_param_names(lowered_fn, infer_output_meta_fn) - else: - user_meta = infer_output_meta_fn - - def _bridged_meta_fn(*args, **kwargs): - return user_meta(**_bind_to_original_kwargs(args, kwargs)) - - _bridged_meta_fn.__signature__ = lowered_sig - meta_fn = _bridged_meta_fn - - # Step 3 (slot 2): custom_op + register_fake. ``mutates_args`` is - # expanded from original-space to lowered-space so torch.library - # sees the leaf parameter names it actually owns. - flat_mutates_args = _expand_mutates_args(mutates_args, param_mapping_tree) + # Step 2: Register the op in torch and get ``torch_registered_op``. torch_registered_op = _register_torch_op( - op_name=op_name, fn=lowered_fn, mutates_args=flat_mutates_args, meta_fn=meta_fn + op_name=op_name, + fn=lowered_fn, + mutates_args=mutates_args, + infer_output_meta_fn=infer_output_meta_fn, + setup_context_fn=setup_context_fn, + backward_fn=backward_fn, + dataclass_runtime_adapter=dataclass_runtime_adapter, ) - # Step 4: autograd. The user's hooks speak the ORIGINAL signature, - # but torch.library passes/expects LOWERED inputs and grads, so we - # wrap both ends. - if backward_fn is not None: - user_setup = setup_context_fn - user_backward = backward_fn - - def _bridged_setup_context(ctx, inputs, output): - if user_setup is None: - return None - # Reassemble the lowered positional tuple into the user's - # original (possibly nested-dataclass) shape, preserving - # original positional order so ``x, cfg = inputs`` works. - lowered_kwargs = {p.name: v for p, v in zip(lowered_sig.parameters.values(), inputs)} - original_kwargs = _reassemble_kwargs(param_mapping_tree, lowered_kwargs) - original_inputs = tuple(original_kwargs[p] for p in original_sig.parameters) - return user_setup(ctx, original_inputs, output) - - def _bridged_backward(ctx, *grads): - original_grads = user_backward(ctx, *grads) - if not isinstance(original_grads, tuple): - # Single-input convenience: PyTorch allows a bare grad - # when the op has one input. - original_grads = (original_grads,) - return tuple(_flatten_grads(param_mapping_tree, original_grads)) - - torch_registered_op.register_autograd(_bridged_backward, setup_context=_bridged_setup_context) - - # Step 5 (slot 3, flattening-only): the user-facing op that - # preserves the original signature, flattens at entry, and - # dispatches to ``torch_registered_op``. + # Step 3 (only in the flattening scenario): Wrap the torch-level op and get ``magi_exposed_op``. @functools.wraps(fn) def magi_exposed_op(*args, **kwargs): - flat = _flatten_call_args(param_mapping_tree, original_sig, args, kwargs) + flat = dataclass_runtime_adapter.flatten_call_args(args, kwargs) return torch_registered_op(*flat) - # Internal handles so downstream tooling can drop one slot lower - # (e.g. dispatch the OpOverload directly with pre-flattened args). magi_exposed_op._magi_torch_registered_op = torch_registered_op magi_exposed_op._magi_param_mapping_tree = param_mapping_tree + + # Return magi-level op. return magi_exposed_op return decorator From 084d163542f01f7c9aa57d65326b6565dcd5ffda Mon Sep 17 00:00:00 2001 From: Yunbo Zhang Date: Tue, 19 May 2026 14:42:06 +0800 Subject: [PATCH 4/6] add unittest for input grads in dataclass format --- tests/api_tests/test_register_custom_op.py | 119 +++++++++++++++++++++ 1 file changed, 119 insertions(+) diff --git a/tests/api_tests/test_register_custom_op.py b/tests/api_tests/test_register_custom_op.py index 81e7656..071cee0 100644 --- a/tests/api_tests/test_register_custom_op.py +++ b/tests/api_tests/test_register_custom_op.py @@ -1856,6 +1856,125 @@ def _op(x: torch.Tensor, cfg: _Cfg) -> torch.Tensor: # b did not require grad; nothing to assert beyond "no exception". +class TestStructuralDataclassGrad: + """The dataclass-grad slot is matched **structurally by field name** -- the + returned object does NOT have to be an instance of the input dataclass. + Any object that exposes the same field names (the input dataclass itself, + a different dataclass, a ``SimpleNamespace``, etc.) is accepted, and + Tensor leaves are routed to the corresponding flat slots. + """ + + def test_grad_returned_as_same_dataclass(self): + """Baseline: backward_fn returns an instance of the *same* dataclass + type as the input. Tensor fields carry real grads, demonstrating that + the dataclass-grad path works end-to-end before we exercise the + structural (foreign-type) variants below.""" + + @dataclasses.dataclass(frozen=True) + class _Cfg: + w: torch.Tensor + b: torch.Tensor + + def setup(ctx, inputs, output): + x, cfg = inputs + ctx.save_for_backward(x, cfg.w) + + def bwd(ctx, gy): + x, w = ctx.saved_tensors + # Use the SAME dataclass type _Cfg to carry the field-level grads. + return gy * w, _Cfg(w=gy * x, b=torch.zeros_like(w)) + + @magi_register_custom_op(name="test::dc_grad_same_type", setup_context_fn=setup, backward_fn=bwd) + def _op(x: torch.Tensor, cfg: _Cfg) -> torch.Tensor: + return x * cfg.w + cfg.b + + x = torch.randn(4, requires_grad=True) + w = torch.randn(4, requires_grad=True) + b = torch.randn(4, requires_grad=True) + cfg = _Cfg(w=w, b=b) + out = _op(x, cfg) + out.sum().backward() + assert torch.allclose(x.grad, w) + assert torch.allclose(w.grad, x) + # b's grad came from _Cfg.b = zeros_like(w), so it should be 0. + assert torch.allclose(b.grad, torch.zeros_like(b)) + + def test_grad_returned_as_different_dataclass(self): + """``backward_fn`` returns a *different* dataclass type that just + happens to share field names with the input dataclass. The bridge + must spread its Tensor fields to the correct flat grad slots based + purely on field-name matching.""" + + @dataclasses.dataclass(frozen=True) + class _Cfg: + w: torch.Tensor + b: torch.Tensor + + @dataclasses.dataclass(frozen=True) + class _CfgGrad: + # Same field names as _Cfg, different class. Distinct identity + # to prove name-only matching (no isinstance check). + w: torch.Tensor + b: torch.Tensor + + def setup(ctx, inputs, output): + x, cfg = inputs + ctx.save_for_backward(x, cfg.w) + + def bwd(ctx, gy): + x, w = ctx.saved_tensors + # Express grads via a foreign dataclass type with matching field + # names. The bridge must NOT require _Cfg specifically. + return gy * w, _CfgGrad(w=gy * x, b=torch.zeros_like(w)) + + @magi_register_custom_op(name="test::dc_grad_foreign_type", setup_context_fn=setup, backward_fn=bwd) + def _op(x: torch.Tensor, cfg: _Cfg) -> torch.Tensor: + return x * cfg.w + cfg.b + + x = torch.randn(4, requires_grad=True) + w = torch.randn(4, requires_grad=True) + b = torch.randn(4, requires_grad=True) + cfg = _Cfg(w=w, b=b) + out = _op(x, cfg) + out.sum().backward() + assert torch.allclose(x.grad, w) + assert torch.allclose(w.grad, x) + # b's grad came from _CfgGrad.b = zeros_like(w), so it should be 0. + assert torch.allclose(b.grad, torch.zeros_like(b)) + + def test_grad_returned_as_simple_namespace(self): + """A non-dataclass object (``types.SimpleNamespace``) that exposes + the dataclass field names should also be accepted.""" + import types as _types + + @dataclasses.dataclass(frozen=True) + class _Cfg: + w: torch.Tensor + b: torch.Tensor + + def setup(ctx, inputs, output): + x, cfg = inputs + ctx.save_for_backward(x, cfg.w) + + def bwd(ctx, gy): + x, w = ctx.saved_tensors + return gy * w, _types.SimpleNamespace(w=gy * x, b=torch.zeros_like(w)) + + @magi_register_custom_op(name="test::dc_grad_namespace", setup_context_fn=setup, backward_fn=bwd) + def _op(x: torch.Tensor, cfg: _Cfg) -> torch.Tensor: + return x * cfg.w + cfg.b + + x = torch.randn(4, requires_grad=True) + w = torch.randn(4, requires_grad=True) + b = torch.randn(4, requires_grad=True) + cfg = _Cfg(w=w, b=b) + out = _op(x, cfg) + out.sum().backward() + assert torch.allclose(x.grad, w) + assert torch.allclose(w.grad, x) + assert torch.allclose(b.grad, torch.zeros_like(b)) + + class TestBackwardCallsAnotherOp: """``backward_fn`` is allowed to call other registered ops to compute the gradient. This is the FlashAttention-style "forward op + backward op" From aac4e1bd0568e469107930a14996398b4cbf9fe9 Mon Sep 17 00:00:00 2001 From: Yunbo Zhang Date: Tue, 19 May 2026 15:13:46 +0800 Subject: [PATCH 5/6] add api example --- magi_compiler/api.py | 33 ++++++++++++++++++++++++++++++++- 1 file changed, 32 insertions(+), 1 deletion(-) diff --git a/magi_compiler/api.py b/magi_compiler/api.py index 5821550..8a92d0b 100644 --- a/magi_compiler/api.py +++ b/magi_compiler/api.py @@ -210,7 +210,7 @@ def magi_register_custom_op( functionality of: - @torch.library.custom_op - @torch.library.register_fake - - fn.register_autograd + - @torch.library.register_autograd Arguments: name: Fully qualified op name (e.g. ``"namespace::op_name"``). If @@ -265,6 +265,37 @@ def magi_register_custom_op( >>> @magi_register_custom_op() ... def attn(q: torch.Tensor, k: torch.Tensor, cfg: AttnCfg) -> torch.Tensor: ... pass + + 4. Backward with a dataclass input (per-field grads by field name): + + For a dataclass input slot, the easiest way to express per-field + grads is to return a ``dict`` keyed by field name. Tensor fields + carry their grads; non-differentiable fields (or absent keys) use + ``None``. The bridge matches by field **name** (not by type), so + any object exposing the same names works; ``dict`` is just the + most convenient. + + >>> @dataclasses.dataclass(frozen=True) + ... class WeightCfg: + ... w: torch.Tensor + ... b: torch.Tensor + ... + >>> def setup(ctx, inputs, output): + ... x, cfg = inputs + ... ctx.save_for_backward(x, cfg.w) + ... + >>> def bwd(ctx, gy): + ... x, w = ctx.saved_tensors + ... # Slot order matches the ORIGINAL signature of ``op`` below. + ... return ( + ... gy * w, # grad for x (Tensor) + ... {"w": gy * x, "b": None}, # grad for cfg (dict: per-field) + ... # equivalently: WeightCfg(w=gy * x, b=None) + ... ) + ... + >>> @magi_register_custom_op(setup_context_fn=setup, backward_fn=bwd) + ... def op(x: torch.Tensor, cfg: WeightCfg) -> torch.Tensor: + ... return x * cfg.w + cfg.b """ return _magi_register_custom_op_impl( name=name, From 9709d0c9821038d316d0a791eaa96039b56c65ba Mon Sep 17 00:00:00 2001 From: Yunbo Zhang Date: Tue, 19 May 2026 17:28:36 +0800 Subject: [PATCH 6/6] fix docs --- magi_compiler/api.py | 42 +++++++++++++++++++++++++----------------- 1 file changed, 25 insertions(+), 17 deletions(-) diff --git a/magi_compiler/api.py b/magi_compiler/api.py index 8a92d0b..3f94989 100644 --- a/magi_compiler/api.py +++ b/magi_compiler/api.py @@ -211,6 +211,7 @@ def magi_register_custom_op( - @torch.library.custom_op - @torch.library.register_fake - @torch.library.register_autograd + and extends it to support nested dataclass parameters. Arguments: name: Fully qualified op name (e.g. ``"namespace::op_name"``). If @@ -257,7 +258,7 @@ def magi_register_custom_op( 3. Frozen-dataclass parameter (grouped config): - >>> @dataclasses.dataclass(frozen=True) + ... @dataclasses.dataclass(frozen=True) ... class AttnCfg: ... scale: float ... causal: bool = False @@ -266,36 +267,43 @@ def magi_register_custom_op( ... def attn(q: torch.Tensor, k: torch.Tensor, cfg: AttnCfg) -> torch.Tensor: ... pass - 4. Backward with a dataclass input (per-field grads by field name): + 4. Backward with a nested-dataclass input (per-field grads by field name): For a dataclass input slot, the easiest way to express per-field - grads is to return a ``dict`` keyed by field name. Tensor fields - carry their grads; non-differentiable fields (or absent keys) use - ``None``. The bridge matches by field **name** (not by type), so - any object exposing the same names works; ``dict`` is just the - most convenient. - - >>> @dataclasses.dataclass(frozen=True) - ... class WeightCfg: + grads is to return a ``dict`` keyed by field name (nested for + nested dataclasses). Tensor fields carry their grads; non- + differentiable fields (or absent keys) use ``None``. The bridge + matches by field **name** (not by type), so any object exposing + the same names works; ``dict`` is just the most convenient. + + ... @dataclasses.dataclass(frozen=True) + ... class Inner: ... w: torch.Tensor ... b: torch.Tensor ... - >>> def setup(ctx, inputs, output): + ... @dataclasses.dataclass(frozen=True) + ... class WeightCfg: + ... inner: Inner + ... scale: float + ... + ... def setup(ctx, inputs, output): ... x, cfg = inputs - ... ctx.save_for_backward(x, cfg.w) + ... ctx.save_for_backward(x, cfg.inner.w) + ... ctx.scale = cfg.scale ... - >>> def bwd(ctx, gy): + ... def bwd(ctx, gy): ... x, w = ctx.saved_tensors + ... s = ctx.scale ... # Slot order matches the ORIGINAL signature of ``op`` below. ... return ( - ... gy * w, # grad for x (Tensor) - ... {"w": gy * x, "b": None}, # grad for cfg (dict: per-field) - ... # equivalently: WeightCfg(w=gy * x, b=None) + ... gy * w * s, # grad for x (Tensor) + ... {"inner": {"w": gy * x * s, "b": None}, "scale": None}, # grad for cfg (nested dict) + ... # equivalently: WeightCfg(inner=Inner(w=gy * x * s, b=None), scale=None) ... ) ... >>> @magi_register_custom_op(setup_context_fn=setup, backward_fn=bwd) ... def op(x: torch.Tensor, cfg: WeightCfg) -> torch.Tensor: - ... return x * cfg.w + cfg.b + ... return (x * cfg.inner.w + cfg.inner.b) * cfg.scale """ return _magi_register_custom_op_impl( name=name,