diff --git a/magi_compiler/_magi_register_custom_op.py b/magi_compiler/_magi_register_custom_op.py index 3f770ec..2fa1d7d 100644 --- a/magi_compiler/_magi_register_custom_op.py +++ b/magi_compiler/_magi_register_custom_op.py @@ -1,4 +1,4 @@ -# Copyright (c) 2025 SandAI. All Rights Reserved. +# Copyright (c) 2026 SandAI. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,34 +12,575 @@ # See the License for the specific language governing permissions and # limitations under the License. +""" +Magi custom-op registration: dataclass-aware wrapper around ``torch.library``. + +This module implements ``magi_register_custom_op`` -- a decorator that takes +a plain Python function and registers it as a real custom op while letting +the user keep calling it with their original signature. + + +File layout +=========== + +The file has five blocks. Each block groups its own helpers (private, +above) with the one core piece it exists to support (below). Block +boundaries follow the 5-stage pipeline. 
+ + Block 0 -- VALIDATE op signature constraints (registration-time) + helpers: assertion predicates + validation primitives + core: _validate_op_signature_constraints + + Block 1 -- LOWER (registration-time) + helpers: type resolution, default scrubbing, param-mapping-tree construction + core: _lower_op_signature (produces slot 1) + + Block 2 -- REGISTER (registration-time) + helpers: op-name generation, meta/fake-fn synthesis + core: _register_torch_op (produces slot 2) + + Block 3 -- RUNTIME ADAPTER (runtime) + helpers: flatten / unflatten primitives + signature-bound wrappers + core: _DataclassRuntimeAdapter (used by slot 3) + + Block 4 -- MAIN PIPELINE + core: _magi_register_custom_op_impl (the decorator; + orchestrates blocks 0-3, produces slot 3 on the flatten path) +""" + +from __future__ import annotations + +import dataclasses +import functools import inspect -from typing import Callable, get_args, get_origin +from typing import Any, Callable, get_args, get_origin import torch +import torch.utils._pytree as pytree from .config import get_compile_config +from .utils.logger import magi_logger +# ============================================================================== +# BLOCK 0 -- VALIDATE op signature constraints +# +# Helpers: +# - type predicates: +# _is_frozen_dataclass +# - assertion primitives: +# _assert_not_unsupported_container, _assert_not_dataclass_return, +# _assert_not_mutable_dataclass, _assert_has_annotation, +# _assert_no_var_args, _assert_resolved_field_type +# Core: _validate_op_signature_constraints +# ============================================================================== + + +def _is_frozen_dataclass(tp) -> bool: + """Return True if ``tp`` is a frozen dataclass type.""" + return ( + isinstance(tp, type) + and dataclasses.is_dataclass(tp) + and getattr(tp, "__dataclass_params__", None) is not None + and tp.__dataclass_params__.frozen + ) + + +def _assert_not_unsupported_container(tp, *, where: str) -> None: + """Reject 
``tuple[...]`` / ``dict[...]`` annotations (schema only models ``list``).""" + origin = get_origin(tp) + if origin is tuple: + raise TypeError( + f"@magi_register_custom_op: {where} has tuple annotation {tp!r}; " + f"use ``list[...]`` or split into separate fields." + ) + if origin is dict: + raise TypeError( + f"@magi_register_custom_op: {where} has dict-typed annotation {tp!r}; " f"promote the values to explicit fields." + ) + + +def _assert_not_dataclass_return(tp, *, fn_name: str) -> None: + """Reject dataclass return annotations (schema only returns Tensor / tuple / list / None).""" + if isinstance(tp, type) and dataclasses.is_dataclass(tp): + raise TypeError( + f"@magi_register_custom_op: {fn_name!r} returns dataclass " + f"{tp.__name__!r}; only Tensor / tuple[Tensor, ...] / list[Tensor] " + f"are supported -- destructure into a tuple at the op boundary." + ) + + +def _assert_not_mutable_dataclass(tp, *, where: str) -> None: + """Reject non-frozen dataclasses (schema needs hashable, stable inputs).""" + if ( + isinstance(tp, type) + and dataclasses.is_dataclass(tp) + and getattr(tp, "__dataclass_params__", None) is not None + and not tp.__dataclass_params__.frozen + ): + raise TypeError( + f"@magi_register_custom_op: {where} has mutable dataclass " + f"{tp.__name__!r}; add ``frozen=True`` to {tp.__name__}." + ) + + +def _assert_has_annotation(annotation, *, where: str) -> None: + """Require an annotation on every parameter / field / return value (needed + to recognise dataclasses and to feed ``infer_schema``).""" + if annotation is inspect.Parameter.empty or annotation is inspect.Signature.empty: + raise TypeError( + f"@magi_register_custom_op: {where} has no type annotation " f"(e.g. ``x: torch.Tensor`` or ``cfg: MyFrozenCfg``)." 
+ ) + + +def _assert_no_var_args(param: inspect.Parameter, *, fn_name: str) -> None: + """Reject ``*args`` / ``**kwargs`` (op schemas are positional-or-keyword only).""" + if param.kind is inspect.Parameter.VAR_POSITIONAL: + raise TypeError( + f"@magi_register_custom_op: {fn_name!r} declares ``*{param.name}``; " + f"variadics aren't supported -- replace with explicit annotated parameters." + ) + if param.kind is inspect.Parameter.VAR_KEYWORD: + raise TypeError( + f"@magi_register_custom_op: {fn_name!r} declares ``**{param.name}``; " + f"variadics aren't supported -- replace with explicit annotated parameters." + ) + + +def _assert_resolved_field_type(f_type, *, where: str) -> None: + """Reject unresolved string annotations -- typically a local class combined + with stringified annotations that ``get_type_hints`` could not eval.""" + if isinstance(f_type, str): + raise TypeError( + f"@magi_register_custom_op: {where} has unresolved string " + f"annotation {f_type!r}; move the type to module scope so " + f"``get_type_hints`` can resolve it." 
+ ) + + +def _validate_op_signature_constraints(fn: Callable) -> None: + """Validate fn parameters/return and recursively validate frozen dataclass subtrees.""" + original_sig = inspect.signature(fn) + resolved = _resolve_annotations(fn) + + def _validate_through( + params_or_fields, *, owner_name: str, is_fn_params: bool, field_types: dict[str, Any] | None = None + ) -> None: + for item in params_or_fields: + if is_fn_params: + name, param = item + _assert_no_var_args(param, fn_name=fn.__name__) + annotation = resolved.get(name, param.annotation) + where = f"parameter {name!r} of {owner_name}" + else: + field = item + name = field.name + annotation = (field_types or {}).get(name, field.type) + where = f"field {owner_name}.{name}" + _assert_resolved_field_type(annotation, where=where) + + _assert_has_annotation(annotation, where=where) + _assert_not_mutable_dataclass(annotation, where=where) + + if _is_frozen_dataclass(annotation): + _validate_through( + dataclasses.fields(annotation), + owner_name=annotation.__name__, + is_fn_params=False, + field_types=_resolve_dataclass_field_types(annotation), + ) + else: + _assert_not_unsupported_container(annotation, where=where) -def _get_num_outputs_from_return_annotation(fn: Callable) -> int: - """ - Get the number of outputs from the function's return type annotation. 
+ _validate_through(original_sig.parameters.items(), owner_name=f"{fn.__name__!r}", is_fn_params=True) + + return_annotation = resolved.get("return", original_sig.return_annotation) + _assert_has_annotation(return_annotation, where=f"return value of {fn.__name__!r}") + _assert_not_dataclass_return(return_annotation, fn_name=fn.__name__) + + +# ============================================================================== +# BLOCK 1 -- LOWER fn signature +# +# Helpers: +# - type resolution & default scrubbing: +# _resolve_annotations, _resolve_dataclass_field_types, +# _maybe_downgrade_literal_or_enum, +# _schema_compatible_field_default, _schema_compatible_param_default +# - param-mapping-tree construction: +# _register_dataclass_pytree, +# _expand_mutates_args +# - lowered-signature utilities: +# _apply_lowered_signature, +# _make_lowered_signature_wrapper +# Core: _lower_op_signature +# ============================================================================== + + +# ------------------------------------------------------------------------------ +# helpers: resolve types & sanitise defaults for infer_schema +# ------------------------------------------------------------------------------ + + +def _resolve_annotations(fn: Callable) -> dict[str, Any]: + """Return ``fn`` annotations as real types, with globals+closure eval fallback.""" + import typing + + try: + return typing.get_type_hints(fn) + except Exception: + pass + + # Build an eval namespace from module globals + closure nonlocals. + # ``__globals__`` covers the common case; closure vars from + # ``getclosurevars`` cover annotations that name enclosing locals. 
+ fn_globals = getattr(fn, "__globals__", {}) or {} + namespace: dict[str, Any] = dict(fn_globals) + try: + cv = inspect.getclosurevars(fn) + namespace.update(cv.builtins) + namespace.update(cv.globals) + namespace.update(cv.nonlocals) + except Exception as e: + magi_logger.debug( + "inspect.getclosurevars(%s) failed: %s; falling back to module globals only", + getattr(fn, "__qualname__", fn), + e, + rank="all", + ) + + anns: dict[str, Any] = {} + raw = getattr(fn, "__annotations__", {}) or {} + for k, v in raw.items(): + if isinstance(v, str): + try: + anns[k] = eval(v, namespace, None) + except Exception: + anns[k] = v + else: + anns[k] = v + return anns + + +def _resolve_dataclass_field_types(cls: type) -> dict[str, Any]: + """Return ``cls``'s field name -> resolved type (best-effort).""" + import sys + import typing as _typing + + try: + return _typing.get_type_hints(cls) + except Exception: + pass + + # ``get_type_hints(cls)`` is all-or-nothing; fall back to per-field eval so + # one unresolved annotation does not poison the whole dataclass. + namespace: dict[str, Any] = {} + module = sys.modules.get(getattr(cls, "__module__", "")) + if module is not None: + namespace.update(vars(module)) + namespace.update(getattr(cls, "__dict__", {})) + namespace.setdefault(cls.__name__, cls) + + resolved: dict[str, Any] = {} + for f in dataclasses.fields(cls): + tp = f.type + if isinstance(tp, str): + try: + resolved[f.name] = eval(tp, namespace, None) + except Exception: + resolved[f.name] = tp + else: + resolved[f.name] = tp + return resolved + + +def _maybe_downgrade_literal_or_enum(annotation, *, where: str): + """Downgrade ``Literal[str,...]`` and string-valued Enums to ``str`` or raise.""" + import enum + import typing + + _LITERAL_STRING_DOWNGRADE_HINT = ( + "Use ``str`` and validate the value inside the op body, e.g. " "``assert mode in ('a', 'b')``." 
+ ) + origin = get_origin(annotation) + if origin is typing.Literal: + choices = get_args(annotation) + if choices and all(isinstance(c, str) for c in choices): + return str + raise TypeError( + f"@magi_register_custom_op: {where} has Literal {annotation!r}; " + f"only ``Literal[str, ...]`` is auto-downgraded. " + f"{_LITERAL_STRING_DOWNGRADE_HINT}" + ) + if isinstance(annotation, type) and issubclass(annotation, enum.Enum): + members = list(annotation) + if members and all(isinstance(m.value, str) for m in members): + return str + raise TypeError( + f"@magi_register_custom_op: {where} has non-string Enum " + f"{annotation.__name__!r}. {_LITERAL_STRING_DOWNGRADE_HINT}" + ) + return annotation + + +_SCHEMA_DEFAULT_TYPES: tuple[type, ...] = (int, float, bool, str, torch.device, torch.dtype) + + +def _schema_compatible_field_default(f: "dataclasses.Field") -> Any: + """Return schema-safe field default (including resolved ``default_factory``) or empty.""" + if f.default is not dataclasses.MISSING: + d = f.default + if d is None or isinstance(d, _SCHEMA_DEFAULT_TYPES): + return d + return inspect.Parameter.empty + if f.default_factory is not dataclasses.MISSING: # type: ignore[misc] + try: + d = f.default_factory() + except Exception: + return inspect.Parameter.empty + if d is None or isinstance(d, _SCHEMA_DEFAULT_TYPES): + return d + return inspect.Parameter.empty + return inspect.Parameter.empty + + +def _schema_compatible_param_default(default: Any) -> Any: + """Return schema-safe parameter default or ``inspect.Parameter.empty``.""" + if default is inspect.Parameter.empty: + return inspect.Parameter.empty + if default is None or isinstance(default, _SCHEMA_DEFAULT_TYPES): + return default + return inspect.Parameter.empty + + +# ------------------------------------------------------------------------------ +# helpers: build & query the param mapping tree +# ------------------------------------------------------------------------------ + +_DATACLASS_PYTREE_REGISTERED: 
set[type] = set() + + +def _register_dataclass_pytree(cls: type) -> None: + """Idempotently register dataclass ``cls`` as a pytree node for tracing.""" + if cls in _DATACLASS_PYTREE_REGISTERED: + return + + field_names = tuple(f.name for f in dataclasses.fields(cls)) + + def _flatten(obj): + return [getattr(obj, n) for n in field_names], field_names + + def _unflatten(values, ctx): + return cls(**dict(zip(ctx, values))) + + try: + pytree.register_pytree_node(cls, _flatten, _unflatten) + except ValueError: + # Already registered elsewhere (e.g. user code). + pass + _DATACLASS_PYTREE_REGISTERED.add(cls) + + +def _expand_mutates_args(param_mapping_tree: list[tuple], mutates_args: tuple[str, ...] | list[str]) -> tuple[str, ...]: + """Expand/validate ``mutates_args`` from original names into lowered leaf names.""" + + def _collect_tensor_leaf_lowered_attr_names(node: tuple) -> list[str]: + if node[0] == "primitive": + _, _attr, lowered_attr_name, _ = node + return [lowered_attr_name] + out: list[str] = [] + for child in node[3]: + out.extend(_collect_tensor_leaf_lowered_attr_names(child)) + return out + + if not mutates_args: + return tuple(mutates_args) + by_attr: dict[str, tuple] = {node[1]: node for node in param_mapping_tree} + valid_lowered: set[str] = set() + for node in param_mapping_tree: + valid_lowered.update(_collect_tensor_leaf_lowered_attr_names(node)) + out: list[str] = [] + for name in mutates_args: + if name in by_attr: + node = by_attr[name] + if node[0] == "primitive": + out.append(node[2]) + else: + out.extend(_collect_tensor_leaf_lowered_attr_names(node)) + elif name in valid_lowered: + out.append(name) + else: + raise ValueError( + f"@magi_register_custom_op: mutates_args entry {name!r} does " + f"not match any parameter. Valid: {sorted(by_attr.keys())} " + f"(or lowered: {sorted(valid_lowered)})." 
+ ) + seen: set[str] = set() + deduped: list[str] = [] + for n in out: + if n not in seen: + seen.add(n) + deduped.append(n) + return tuple(deduped) + + +# ------------------------------------------------------------------------------ +# core: _lower_op_signature (and its lowered-signature wrapper utilities) +# ------------------------------------------------------------------------------ + + +def _apply_lowered_signature(lowered_sig: inspect.Signature, wrapper: Callable) -> None: + """Stamp wrapper signature/annotations with ``lowered_sig`` and clear ``__wrapped__``.""" + wrapper.__signature__ = lowered_sig + lowered_annotations = { + p.name: p.annotation for p in lowered_sig.parameters.values() if p.annotation is not inspect.Parameter.empty + } + if lowered_sig.return_annotation is not inspect.Signature.empty: + lowered_annotations["return"] = lowered_sig.return_annotation + wrapper.__annotations__ = lowered_annotations + # ``functools.wraps`` sets ``__wrapped__`` -> ``fn``; strip it so + # introspection cannot bypass our ``__signature__``. + try: + del wrapper.__wrapped__ + except AttributeError: + pass + + +def _lower_op_signature(fn: Callable): + """Lower ``fn``'s signature into a form ``torch.library.infer_schema`` accepts. + + Pipeline: + 1. RESOLVE -- turn stringified annotations/dataclass fields into real types (best-effort). + If that fails, fall back to globals+closure evaluation per annotation/dataclass field. + 2. FLATTEN -- recursively flatten frozen dataclass parameters into primitive leaves + and register the dataclass types as pytree nodes for Dynamo/AOTAutograd tracing. + 3. NORMALIZE -- collapse parameter kinds to POSITIONAL_OR_KEYWORD, + downgrade Literal/Enum to ``str``, scrub unsupported defaults. + 4. EMIT -- assemble ``(original_sig, lowered_sig, param_mapping_tree)``. Returns: - - 1 if the return type is a single Tensor - - N if the return type is tuple[Tensor, Tensor, ...] 
with N elements - - 1 if no annotation or unrecognized annotation (default to single output) + original_sig (inspect.Signature): the user's original input signature. + lowered_sig (inspect.Signature): schema-compatible signature for ``infer_schema``. + param_mapping_tree (list[tuple]): the bridge between the two; a list + of root nodes (one per original parameter), each of which is: + * ``("primitive", attr_name, lowered_attr_name, None)``, or + * ``("dataclass", attr_name, dataclass_cls_type, [child_nodes...])``. + + Example: + ``fn(x: Tensor, cfg: Outer(inner: Inner(scale: float, bias: Tensor), mode: str))`` + -> ``original_sig(x: Tensor, cfg: Outer)`` + -> ``lowered_sig(x: Tensor, cfg__inner__scale: float, cfg__inner__bias: Tensor, cfg__mode: str)`` + -> ``param_mapping_tree = [("primitive", "x", "x", None), + ("dataclass", "cfg", Outer, [ + ("dataclass", "inner", Inner, [ + ("primitive", "scale", "cfg__inner__scale", None), + ("primitive", "bias", "cfg__inner__bias", None), + ]), + ("primitive", "mode", "cfg__mode", None), + ])]`` """ + + original_sig = inspect.signature(fn) + resolved = _resolve_annotations(fn) + + def _lower_through( + params_or_fields, + *, + is_fn_params: bool, + owner_name: str, + flat_prefix: str | None = None, + field_types: dict[str, Any] | None = None, + ) -> tuple[list[tuple], list[inspect.Parameter]]: + nodes: list[tuple] = [] + lowered: list[inspect.Parameter] = [] + + if is_fn_params: + iterator = ((name, resolved.get(name, param.annotation), param) for name, param in params_or_fields) + else: + resolved_fields = field_types or {} + iterator = ((field.name, resolved_fields.get(field.name, field.type), field) for field in params_or_fields) + + for name, annotation, source in iterator: + leaf_flat_name = name if is_fn_params else f"{flat_prefix}__{name}" + where = f"parameter {name!r}" if is_fn_params else f"field {owner_name}.{name}" + + if _is_frozen_dataclass(annotation): + _register_dataclass_pytree(annotation) + child_nodes, 
child_params = _lower_through( + dataclasses.fields(annotation), + is_fn_params=False, + owner_name=annotation.__name__, + flat_prefix=leaf_flat_name, + field_types=_resolve_dataclass_field_types(annotation), + ) + nodes.append(("dataclass", name, annotation, child_nodes)) + lowered.extend(child_params) + else: + annotation = _maybe_downgrade_literal_or_enum(annotation, where=where) + if is_fn_params: + param = source + lowered.append( + param.replace( + kind=inspect.Parameter.POSITIONAL_OR_KEYWORD, + annotation=annotation, + default=_schema_compatible_param_default(param.default), + ) + ) + else: + field = source + lowered.append( + inspect.Parameter( + leaf_flat_name, + # POSITIONAL_OR_KEYWORD: torch.library.custom_op does not yet + # support kwarg-only Tensor arguments. + inspect.Parameter.POSITIONAL_OR_KEYWORD, + annotation=annotation, + default=_schema_compatible_field_default(field), + ) + ) + nodes.append(("primitive", name, leaf_flat_name, None)) + + return nodes, lowered + + param_mapping_tree, lowered_params = _lower_through( + original_sig.parameters.items(), is_fn_params=True, owner_name=f"{fn.__name__!r}" + ) + + return_annotation = resolved.get("return", original_sig.return_annotation) + lowered_sig = inspect.Signature(lowered_params, return_annotation=return_annotation) + return original_sig, lowered_sig, param_mapping_tree + + +# ============================================================================== +# BLOCK 2 -- REGISTER torch op +# +# Helpers: +# - meta/fake-fn synthesis: +# _get_num_outputs_from_return_annotation, +# _create_identity_meta_fn, _create_meta_fn_from_param_names +# - op-name generation: +# _generate_op_name +# Core: _register_torch_op +# ============================================================================== + + +# ------------------------------------------------------------------------------ +# helpers: synthesise the meta/fake function +# ------------------------------------------------------------------------------ 
+ + +def _get_num_outputs_from_return_annotation(fn: Callable) -> int: + """Output count from ``fn``'s return annotation: ``N`` for + ``tuple[T1, ..., TN]``, else ``1`` (default / unrecognized).""" sig = inspect.signature(fn) return_annotation = sig.return_annotation if return_annotation is inspect.Parameter.empty: return 1 - # Check if it's a tuple type (e.g., tuple[Tensor, Tensor]) origin = get_origin(return_annotation) if origin is tuple: args = get_args(return_annotation) - # Filter out ellipsis (for variable-length tuples like tuple[Tensor, ...]) + # tuple[T, ...] (variable-length) collapses to a single output. if args and args[-1] is not ...: return len(args) return 1 @@ -47,56 +588,17 @@ def _get_num_outputs_from_return_annotation(fn: Callable) -> int: return 1 -def _generate_op_name(fn: Callable) -> str: - """ - Generate a unique operator name from function's name and source file. - - Format: {filename_stem}::{function_name} - Example: my_module.py with function `my_op` -> "my_module::my_op" - - Falls back to "magi_custom::{function_name}" if source file cannot be determined. - """ - import re - from pathlib import Path - - func_name = fn.__name__ - - # Get the source file path - try: - source_file = inspect.getfile(fn) - # Extract the file stem (without extension) as namespace - namespace = Path(source_file).stem - # Clean up namespace: replace invalid characters with underscores - namespace = re.sub(r"[^a-zA-Z0-9_]", "_", namespace) - except (TypeError, OSError): - # If we can't get the source file, use a default namespace - namespace = "magi_custom" - - return f"{namespace}::{func_name}" - - def _create_identity_meta_fn(fn: Callable) -> Callable: - """ - Create a default identity meta function for the given function. 
- - The generated meta function: - - Determines number of outputs from return type annotation - - Uses first N tensor inputs to infer output metadata - - Returns torch.empty_like() tensors with matching shape/dtype/device - - Raises ValueError if not enough tensor inputs are provided. - """ + """Default meta/fake: copy shape/dtype/device of the first N tensor inputs + to N outputs (N from the return annotation).""" num_outputs = _get_num_outputs_from_return_annotation(fn) sig = inspect.signature(fn) - # Get parameter names, excluding 'self' if present param_names = [name for name in sig.parameters.keys() if name != "self"] def identity_meta_fn(*args, **kwargs): - # Bind arguments to get a mapping of param_name -> value bound = sig.bind(*args, **kwargs) bound.apply_defaults() - # Collect the first `num_outputs` tensor arguments tensor_args = [] for name in param_names: arg = bound.arguments.get(name) @@ -107,12 +609,11 @@ def identity_meta_fn(*args, **kwargs): if len(tensor_args) < num_outputs: raise ValueError( - f"identity_meta_fn requires at least {num_outputs} tensor inputs to match " - f"{num_outputs} outputs, but only found {len(tensor_args)} tensor inputs. " - f"Please provide a custom infer_output_meta_fn." + f"@magi_register_custom_op: identity_meta_fn needs {num_outputs} " + f"tensor input(s) but found {len(tensor_args)}; provide a custom " + f"infer_output_meta_fn." ) - # Return outputs with same metadata as the first N inputs if num_outputs == 1: return torch.empty_like(tensor_args[0]) return tuple(torch.empty_like(t) for t in tensor_args[:num_outputs]) @@ -121,43 +622,32 @@ def identity_meta_fn(*args, **kwargs): def _create_meta_fn_from_param_names(fn: Callable, param_names: list[str]) -> Callable: - """ - Create a meta function that returns torch.empty_like() for each specified parameter. 
- - Args: - fn: Target function to inspect - param_names: List of parameter names to use as output templates - - Returns: - Meta function that maps specified input params to output tensors - - Raises: - ValueError: If parameter name doesn't exist or isn't a Tensor - """ + """Meta/fake that echoes the listed tensor parameters as outputs + (``torch.empty_like`` each). Raises ``ValueError`` for unknown or + non-Tensor names.""" sig = inspect.signature(fn) def meta_fn(*args, **kwargs): - # Bind arguments to get a mapping of param_name -> value bound = sig.bind(*args, **kwargs) bound.apply_defaults() - # Collect tensors for each specified parameter name tensor_outputs = [] for name in param_names: if name not in bound.arguments: raise ValueError( - f"Parameter '{name}' not found in function signature. " - f"Available parameters: {list(bound.arguments.keys())}" + f"@magi_register_custom_op: infer_output_meta_fn references " + f"unknown parameter {name!r}; available: " + f"{list(bound.arguments.keys())}." ) arg = bound.arguments[name] if not isinstance(arg, torch.Tensor): raise ValueError( - f"Parameter '{name}' is not a Tensor (got {type(arg).__name__}). " - f"infer_output_meta_fn list should only contain tensor parameter names." + f"@magi_register_custom_op: infer_output_meta_fn entry " + f"{name!r} is not a Tensor (got {type(arg).__name__}); " + f"list must contain only tensor parameter names." 
) tensor_outputs.append(torch.empty_like(arg)) - # Return single tensor or tuple based on number of outputs if len(tensor_outputs) == 1: return tensor_outputs[0] return tuple(tensor_outputs) @@ -165,6 +655,237 @@ def meta_fn(*args, **kwargs): return meta_fn +# ------------------------------------------------------------------------------ +# helpers: generate op name +# ------------------------------------------------------------------------------ + + +def _generate_op_name(fn: Callable) -> str: + """Op name ``{filename_stem}::{fn.__name__}``, falling back to + ``magi_custom::`` if the source file isn't available.""" + import re + from pathlib import Path + + func_name = fn.__name__ + try: + source_file = inspect.getfile(fn) + namespace = Path(source_file).stem + namespace = re.sub(r"[^a-zA-Z0-9_]", "_", namespace) + except (TypeError, OSError): + namespace = "magi_custom" + + return f"{namespace}::{func_name}" + + +# ------------------------------------------------------------------------------ +# core: _register_torch_op +# +# Forward reference: ``_DataclassRuntimeAdapter`` using ``from __future__ import annotations``. +# ------------------------------------------------------------------------------ + + +def _register_torch_op( + op_name: str, + fn: Callable, + mutates_args: tuple[str, ...], + infer_output_meta_fn: Callable | list[str] | None, + setup_context_fn: Callable | None, + backward_fn: Callable | None, + dataclass_runtime_adapter: _DataclassRuntimeAdapter | None = None, +): + """Register the op in torch.library.custom_op.""" + effective_mutates_args = ( + dataclass_runtime_adapter.expand_mutates_args(mutates_args) if dataclass_runtime_adapter is not None else mutates_args + ) + torch_registered_op = torch.library.custom_op(op_name, mutates_args=effective_mutates_args)(fn) + + # Build & register the meta/fake function. 
+ if infer_output_meta_fn is None: + meta_fn = _create_identity_meta_fn(fn) + elif isinstance(infer_output_meta_fn, list): + meta_fn = _create_meta_fn_from_param_names(fn, infer_output_meta_fn) + elif dataclass_runtime_adapter is None: # No flattening scenario + meta_fn = infer_output_meta_fn + else: # Flattening scenario + user_meta = infer_output_meta_fn + + def _bridged_meta_fn(*args, **kwargs): + return user_meta(**dataclass_runtime_adapter.unflatten_call_args(args, kwargs)) + + _bridged_meta_fn.__signature__ = inspect.signature(fn) + meta_fn = _bridged_meta_fn + torch.library.register_fake(op_name)(meta_fn) + + # Register autograd. + if backward_fn is not None: + if dataclass_runtime_adapter is None: # No flattening scenario + torch_registered_op.register_autograd(backward_fn, setup_context=setup_context_fn) + else: # Flattening scenario + + def _bridged_setup_context(ctx, inputs, output): + if setup_context_fn is None: + return None + original_inputs = dataclass_runtime_adapter.unflatten_setup_ctx_inputs(inputs) + return setup_context_fn(ctx, original_inputs, output) + + def _bridged_backward(ctx, *grads): + original_grads = backward_fn(ctx, *grads) + if not isinstance(original_grads, tuple): + original_grads = (original_grads,) + return dataclass_runtime_adapter.flatten_input_grads(original_grads) + + torch_registered_op.register_autograd(_bridged_backward, setup_context=_bridged_setup_context) + + return torch_registered_op + + +# ============================================================================== +# BLOCK 3 -- RUNTIME ADAPTER +# +# Helpers (adapter field <- bound function): +# original -> lowered: flatten_call_args <- _flatten_call_args +# flatten_input_grads <- _flatten_input_grads +# lowered -> original: unflatten_call_args <- _unflatten_call_args +# unflatten_setup_ctx_inputs <- _unflatten_setup_ctx_inputs +# _reassemble_kwargs (internal primitive) +# mutates_args expand: expand_mutates_args <- _expand_mutates_args +# signature stamping: 
apply_lowered_signature <- _apply_lowered_signature +# Core: _DataclassRuntimeAdapter +# ============================================================================== + + +# ---- flatten_call_args ---- + + +def _flatten_call_args(param_mapping_tree: list[tuple], original_sig: inspect.Signature, args: tuple, kwargs: dict) -> list: + """User-side call -> flat positional list matching the lowered signature + (the ``original -> lowered`` walk).""" + + def _flatten_value_into(node: tuple, value: Any, out: list) -> None: + """Append leaves of ``value`` to ``out`` in DFS order (no isinstance check + on ``cls`` -- duck-typed via ``getattr`` so mocks / SimpleNamespace work).""" + kind = node[0] + if kind == "primitive": + out.append(value) + return + _, _attr, _cls, children = node + for child in children: + field_name = child[1] + _flatten_value_into(child, getattr(value, field_name), out) + + bound = original_sig.bind(*args, **kwargs) + bound.apply_defaults() + flat: list = [] + for node in param_mapping_tree: + _flatten_value_into(node, bound.arguments[node[1]], flat) + return flat + + +# ---- flatten_input_grads ---- + + +def _flatten_input_grads(param_mapping_tree: list[tuple], original_grads: tuple) -> tuple: + """Original-space input grads -> lowered-space input grads.""" + + def _count_leaves(node: tuple) -> int: + if node[0] == "primitive": + return 1 + return sum(_count_leaves(c) for c in node[3]) + + def _flatten_grad_into(node: tuple, grad: Any, out: list) -> None: + """Spread a user-returned grad across lowered slots for one original input.""" + if node[0] == "primitive": + out.append(grad) + return + _, _attr, _cls, children = node + if grad is None: + for child in children: + for _ in range(_count_leaves(child)): + out.append(None) + return + is_mapping = isinstance(grad, dict) + for child in children: + field_name = child[1] + if is_mapping: + sub = grad.get(field_name) + else: + sub = getattr(grad, field_name, None) + _flatten_grad_into(child, sub, out) + 
# ---- unflatten_call_args / unflatten_setup_ctx_inputs ----


def _reassemble_kwargs(param_mapping_tree: list[tuple], lowered_kwargs: dict) -> dict:
    """Rebuild the original keyword arguments from lowered (flat) ones.

    This is the ``lowered -> original`` walk over ``param_mapping_tree``.

    Args:
        param_mapping_tree: One node per original parameter. A node is either
            ``("primitive", name, lowered_name, _)`` or
            ``("dataclass", name, cls, children)``; ``children`` is a list of
            nodes of the same two shapes whose ``name`` is the field name.
        lowered_kwargs: Mapping ``lowered_name -> value``.

    Returns:
        ``{original_param_name: value}``; every dataclass node is
        re-instantiated via ``cls(**fields)``.

    Raises:
        KeyError: If a primitive leaf's lowered name is absent from
            ``lowered_kwargs``.
    """

    def _build_value_from_node(node: tuple):
        if node[0] == "primitive":
            _kind, _attr, lowered_attr_name, _extra = node
            return lowered_kwargs[lowered_attr_name]
        # Every non-primitive node is handled as a dataclass: rebuild each
        # field recursively, then construct the instance from those fields.
        _kind, _attr, cls, children = node
        return cls(**{child[1]: _build_value_from_node(child) for child in children})

    return {node[1]: _build_value_from_node(node) for node in param_mapping_tree}


def _unflatten_call_args(lowered_sig: inspect.Signature, param_mapping_tree: list[tuple], args: tuple, kwargs: dict) -> dict:
    """Convert a lowered call (``args``/``kwargs``) into original kwargs.

    The call is bound against ``lowered_sig`` and signature defaults are
    applied before reassembly, so omitted lowered parameters that have a
    default are still available to the tree walk.

    Returns:
        A dict suitable for ``fn(**result)`` on the original function.

    Raises:
        TypeError: If ``args``/``kwargs`` do not bind to ``lowered_sig``.
    """
    bound = lowered_sig.bind(*args, **kwargs)
    bound.apply_defaults()
    return _reassemble_kwargs(param_mapping_tree, bound.arguments)


def _unflatten_setup_ctx_inputs(
    lowered_sig: inspect.Signature, original_sig: inspect.Signature, param_mapping_tree: list[tuple], inputs: tuple
) -> tuple:
    """Convert the lowered positional ``inputs`` tuple into the original one.

    Used for ``setup_context_fn(ctx, inputs, output)``: ``inputs`` is paired
    positionally with the parameters of ``lowered_sig``, reassembled into the
    original values, and returned in ``original_sig`` parameter order.

    Raises:
        ValueError: If ``inputs`` does not have exactly one entry per lowered
            parameter. Without this check a short tuple fails later with a
            bare ``KeyError`` and a long one is silently truncated by ``zip``.
    """
    lowered_names = list(lowered_sig.parameters)
    if len(inputs) != len(lowered_names):
        raise ValueError(
            f"@magi_register_custom_op: setup_context received {len(inputs)} lowered "
            f"input(s) but the lowered signature has {len(lowered_names)} parameter(s)."
        )
    lowered_kwargs = dict(zip(lowered_names, inputs))
    original_kwargs = _reassemble_kwargs(param_mapping_tree, lowered_kwargs)
    return tuple(original_kwargs[name] for name in original_sig.parameters)


# ---- core ----


@dataclasses.dataclass(frozen=True)
class _DataclassRuntimeAdapter:
    """Runtime conversion adapter.

    An immutable bundle of six callables that translate between the original
    (dataclass-bearing) calling convention and the lowered (flat) one. Each
    field is expected to be pre-bound to the signature / mapping tree it
    needs, so callers pass only the per-call values named in the annotations.
    """

    # (args, kwargs) of an original-signature call -> flat list of lowered values.
    flatten_call_args: Callable[[tuple, dict], list]
    # Grads returned per original parameter -> flat tuple of grads.
    flatten_input_grads: Callable[[tuple], tuple]
    # (args, kwargs) of a lowered call -> kwargs dict for the original function.
    unflatten_call_args: Callable[[tuple, dict], dict]
    # Lowered positional inputs tuple -> original positional inputs tuple.
    unflatten_setup_ctx_inputs: Callable[[tuple], tuple]
    # Argument names in, argument names out (presumably original -> lowered
    # names; confirm against ``_expand_mutates_args``).
    expand_mutates_args: Callable[[tuple[str, ...]], tuple[str, ...]]
    # Applies the lowered signature to the given function; returns nothing.
    apply_lowered_signature: Callable[[Callable], None]
+ if original_sig == lowered_sig: + fn_to_register = fn + else: # Signatures differ, need to wrap the function + + @functools.wraps(fn) + def lowered_fn(*args, **kwargs): + return fn(*args, **kwargs) + + _apply_lowered_signature(lowered_sig, lowered_fn) + fn_to_register = lowered_fn + + # Step 2: Register the op in torch and get ``torch_registered_op``. + torch_registered_op = _register_torch_op( + op_name=op_name, + fn=fn_to_register, + mutates_args=mutates_args, + infer_output_meta_fn=infer_output_meta_fn, + setup_context_fn=setup_context_fn, + backward_fn=backward_fn, + dataclass_runtime_adapter=None, + ) + + # Return bare torch-level op (slot 2). + return torch_registered_op - # Step 2: Register the output meta inference function - # Determine meta_fn based on the type of infer_output_meta_fn - if infer_output_meta_fn is None: - meta_fn = _create_identity_meta_fn(fn) - elif isinstance(infer_output_meta_fn, list): - meta_fn = _create_meta_fn_from_param_names(fn, infer_output_meta_fn) else: - meta_fn = infer_output_meta_fn - torch.library.register_fake(op_name)(meta_fn) + # ----- Flattening scenario ----- + # Path: fn -> lowered_fn -> torch_registered_op -> magi_exposed_op + + # Step 0 (only in the flattening scenario): Build the scenario-wide adapter. + dataclass_runtime_adapter = _DataclassRuntimeAdapter( + flatten_call_args=functools.partial(_flatten_call_args, param_mapping_tree, original_sig), + flatten_input_grads=functools.partial(_flatten_input_grads, param_mapping_tree), + unflatten_call_args=functools.partial(_unflatten_call_args, lowered_sig, param_mapping_tree), + unflatten_setup_ctx_inputs=functools.partial( + _unflatten_setup_ctx_inputs, lowered_sig, original_sig, param_mapping_tree + ), + expand_mutates_args=functools.partial(_expand_mutates_args, param_mapping_tree), + apply_lowered_signature=functools.partial(_apply_lowered_signature, lowered_sig), + ) + + # Step 1: Build ``lowered_fn``. 
+ @functools.wraps(fn) + def lowered_fn(*args, **kwargs): + return fn(**dataclass_runtime_adapter.unflatten_call_args(args, kwargs)) + + dataclass_runtime_adapter.apply_lowered_signature(lowered_fn) + + # Step 2: Register the op in torch and get ``torch_registered_op``. + torch_registered_op = _register_torch_op( + op_name=op_name, + fn=lowered_fn, + mutates_args=mutates_args, + infer_output_meta_fn=infer_output_meta_fn, + setup_context_fn=setup_context_fn, + backward_fn=backward_fn, + dataclass_runtime_adapter=dataclass_runtime_adapter, + ) + + # Step 3 (only in the flattening scenario): Wrap the torch-level op and get ``magi_exposed_op``. + @functools.wraps(fn) + def magi_exposed_op(*args, **kwargs): + flat = dataclass_runtime_adapter.flatten_call_args(args, kwargs) + return torch_registered_op(*flat) - # Step 3: Register autograd if backward_fn is provided - if backward_fn is not None: - registered_op.register_autograd(backward_fn, setup_context=setup_context_fn) + magi_exposed_op._magi_torch_registered_op = torch_registered_op + magi_exposed_op._magi_param_mapping_tree = param_mapping_tree - return registered_op + # Return magi-level op. + return magi_exposed_op return decorator diff --git a/magi_compiler/api.py b/magi_compiler/api.py index 996657a..3f94989 100644 --- a/magi_compiler/api.py +++ b/magi_compiler/api.py @@ -206,38 +206,36 @@ def magi_register_custom_op( """ A unified decorator to register a custom operator with PyTorch's library. - This decorator combines the functionality of: + It supports advanced features like frozen-dataclass param and combines the + functionality of: - @torch.library.custom_op - @torch.library.register_fake - - fn.register_autograd + - @torch.library.register_autograd + and extends it to support nested dataclass parameters. Arguments: - name: The fully qualified name of the operator (e.g., "namespace::op_name"). - If None, auto-generated from the function name and source file. 
- mutates_args: Tuple of argument names that are mutated by the operator. - infer_output_meta_fn: Specifies output tensor metadata (shape, dtype, device) for tracing. - - None (default): Assumes each output has the same metadata as the corresponding - input tensor (1st output matches 1st tensor input, 2nd matches 2nd, etc.). - - list[str]: Parameter names whose metadata to use for outputs. - E.g., ["weight", "bias"] means output[0] has same shape as `weight`, - output[1] has same shape as `bias`. - - Callable: Custom function with same signature as the op, returns torch.empty_like() - tensors matching the expected output shapes. + name: Fully qualified op name (e.g. ``"namespace::op_name"``). If + ``None``, auto-generated from the function's source file and name. + mutates_args: Argument names that the op mutates. + infer_output_meta_fn: How to propagate output metadata at trace time. + - ``None`` (default): output ``i`` copies the ``i``-th tensor input. + - ``list[str]``: parameter names whose metadata to copy. E.g. + ``["weight", "bias"]`` makes ``output[0]`` shape-match ``weight`` + and ``output[1]`` shape-match ``bias``. + - ``Callable``: a function with the same signature as the op that + returns ``torch.empty_like(...)`` tensors of the expected shapes. setup_context_fn: Function to save tensors/values for backward. Signature: setup_context_fn(ctx, inputs, output) - backward_fn: Function to compute gradients. - Signature: backward_fn(ctx, *grad_outputs) -> tuple of gradients - is_compute_sensitive: If True, marks this operator as compute-intensive (e.g., MatMul, - Attention). During activation recomputation (rematerialization), outputs of - compute-sensitive ops are prioritized for saving rather than recomputing, - since recomputing them would be expensive. - is_subgraph_boundary: If True, the FX graph will be split at this operator during - compilation. 
Each sub-graph between boundary operators is compiled independently - by Inductor, enabling piecewise compilation and more flexible scheduling - (e.g., for CPU offloading or overlapping computation with data transfer). + backward_fn: Gradient computation; signature + ``backward_fn(ctx, *grad_outputs) -> tuple of grads``. + is_compute_sensitive: marks this operator as compute-intensive (e.g., + MatMul, Attention). During training, outputs of compute-sensitive + ops are prioritised for saving rather than recomputing. + is_subgraph_boundary: Split the FX graph at this op during compilation. + Each sub-graph between boundary ops is compiled independently. Returns: - The registered custom operator function. + A callable with the user's original signature. Examples: 1. Basic usage (forward only, auto-generated name and meta function): @@ -248,9 +246,7 @@ def magi_register_custom_op( 2. Multiple outputs with explicit output metadata via parameter names: - >>> @magi_register_custom_op( - ... infer_output_meta_fn=["weight", "bias"], # output shapes match weight and bias - ... ) + >>> @magi_register_custom_op(infer_output_meta_fn=["weight", "bias"]) ... def compute_gradients( ... grad_output: torch.Tensor, ... weight: torch.Tensor, @@ -260,27 +256,54 @@ def magi_register_custom_op( ... grad_bias = grad_output.sum(dim=0).view_as(bias) ... return grad_weight, grad_bias - 3. Full custom op with autograd support: + 3. Frozen-dataclass parameter (grouped config): - >>> def _square_meta(x: torch.Tensor) -> torch.Tensor: - ... return torch.empty_like(x) + ... @dataclasses.dataclass(frozen=True) + ... class AttnCfg: + ... scale: float + ... causal: bool = False ... - >>> def _square_setup_context(ctx, inputs, output): - ... (x,) = inputs - ... ctx.save_for_backward(x) + >>> @magi_register_custom_op() + ... def attn(q: torch.Tensor, k: torch.Tensor, cfg: AttnCfg) -> torch.Tensor: + ... pass + + 4. 
Backward with a nested-dataclass input (per-field grads by field name): + + For a dataclass input slot, the easiest way to express per-field + grads is to return a ``dict`` keyed by field name (nested for + nested dataclasses). Tensor fields carry their grads; non- + differentiable fields (or absent keys) use ``None``. The bridge + matches by field **name** (not by type), so any object exposing + the same names works; ``dict`` is just the most convenient. + + ... @dataclasses.dataclass(frozen=True) + ... class Inner: + ... w: torch.Tensor + ... b: torch.Tensor + ... + ... @dataclasses.dataclass(frozen=True) + ... class WeightCfg: + ... inner: Inner + ... scale: float + ... + ... def setup(ctx, inputs, output): + ... x, cfg = inputs + ... ctx.save_for_backward(x, cfg.inner.w) + ... ctx.scale = cfg.scale ... - >>> def _square_backward(ctx, grad_output): - ... (x,) = ctx.saved_tensors - ... return grad_output * 2 * x + ... def bwd(ctx, gy): + ... x, w = ctx.saved_tensors + ... s = ctx.scale + ... # Slot order matches the ORIGINAL signature of ``op`` below. + ... return ( + ... gy * w * s, # grad for x (Tensor) + ... {"inner": {"w": gy * x * s, "b": None}, "scale": None}, # grad for cfg (nested dict) + ... # equivalently: WeightCfg(inner=Inner(w=gy * x * s, b=None), scale=None) + ... ) ... - >>> @magi_register_custom_op( - ... name="my_ops::square", - ... infer_output_meta_fn=_square_meta, - ... setup_context_fn=_square_setup_context, - ... backward_fn=_square_backward, - ... ) - ... def square(x: torch.Tensor) -> torch.Tensor: - ... return x * x + >>> @magi_register_custom_op(setup_context_fn=setup, backward_fn=bwd) + ... def op(x: torch.Tensor, cfg: WeightCfg) -> torch.Tensor: + ... 
return (x * cfg.inner.w + cfg.inner.b) * cfg.scale """ return _magi_register_custom_op_impl( name=name, diff --git a/tests/api_tests/test_register_custom_op.py b/tests/api_tests/test_register_custom_op.py index 9872e68..071cee0 100644 --- a/tests/api_tests/test_register_custom_op.py +++ b/tests/api_tests/test_register_custom_op.py @@ -24,6 +24,7 @@ - Integration with magi_compile decorator """ +import dataclasses import tempfile from unittest.mock import patch @@ -113,126 +114,6 @@ def _split_op(x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: assert_close(out2, x[..., 4:]) -class TestAutograd: - """Tests for custom op with autograd support.""" - - def test_with_autograd(self): - """Test custom op with setup_context and backward functions.""" - - def _square_infer_output_meta(x: torch.Tensor) -> torch.Tensor: - return torch.empty_like(x) - - def _square_setup_context(ctx, inputs, output): - (x,) = inputs - ctx.save_for_backward(x) - - def _square_backward(ctx, grad_output): - (x,) = ctx.saved_tensors - return grad_output * 2 * x - - @magi_register_custom_op( - name="test::square_op", - mutates_args=(), - infer_output_meta_fn=_square_infer_output_meta, - setup_context_fn=_square_setup_context, - backward_fn=_square_backward, - ) - def _square_op(x: torch.Tensor) -> torch.Tensor: - return x * x - - x = torch.randn(4, 8, requires_grad=True) - output = _square_op(x) - loss = output.sum() - loss.backward() - - # Gradient of x^2 is 2x - expected_grad = 2 * x - assert_close(x.grad, expected_grad) - - def test_autograd_multiple_inputs(self): - """Test autograd with multiple input tensors.""" - - def _weighted_sum_infer_output_meta(a: torch.Tensor, b: torch.Tensor, weight: float) -> torch.Tensor: - return torch.empty_like(a) - - def _weighted_sum_setup_context(ctx, inputs, output): - a, b, weight = inputs - ctx.save_for_backward(a, b) - ctx.weight = weight - - def _weighted_sum_backward(ctx, grad_output): - a, b = ctx.saved_tensors - weight = ctx.weight - grad_a = 
grad_output * weight - grad_b = grad_output * (1 - weight) - return grad_a, grad_b, None # None for non-tensor input - - @magi_register_custom_op( - name="test::weighted_sum_op", - mutates_args=(), - infer_output_meta_fn=_weighted_sum_infer_output_meta, - setup_context_fn=_weighted_sum_setup_context, - backward_fn=_weighted_sum_backward, - ) - def _weighted_sum_op(a: torch.Tensor, b: torch.Tensor, weight: float) -> torch.Tensor: - return a * weight + b * (1 - weight) - - a = torch.randn(4, 8, requires_grad=True) - b = torch.randn(4, 8, requires_grad=True) - weight = 0.7 - - output = _weighted_sum_op(a, b, weight) - loss = output.sum() - loss.backward() - - expected_grad_a = torch.ones_like(a) * weight - expected_grad_b = torch.ones_like(b) * (1 - weight) - - assert_close(a.grad, expected_grad_a) - assert_close(b.grad, expected_grad_b) - - def test_autograd_multiple_outputs(self): - """Test autograd with multiple output tensors.""" - - def _split_scale_infer_output_meta(x: torch.Tensor, scale: float) -> tuple[torch.Tensor, torch.Tensor]: - half = x.shape[-1] // 2 - return (x.new_empty((*x.shape[:-1], half)), x.new_empty((*x.shape[:-1], half))) - - def _split_scale_setup_context(ctx, inputs, output): - x, scale = inputs - ctx.save_for_backward(x) - ctx.scale = scale - ctx.half = x.shape[-1] // 2 - - def _split_scale_backward(ctx, grad_out1, grad_out2): - (x,) = ctx.saved_tensors - scale = ctx.scale - # Reconstruct gradient for x - grad_x = torch.cat([grad_out1 * scale, grad_out2 * scale], dim=-1) - return grad_x, None - - @magi_register_custom_op( - name="test::split_scale_op", - mutates_args=(), - infer_output_meta_fn=_split_scale_infer_output_meta, - setup_context_fn=_split_scale_setup_context, - backward_fn=_split_scale_backward, - ) - def _split_scale_op(x: torch.Tensor, scale: float) -> tuple[torch.Tensor, torch.Tensor]: - half = x.shape[-1] // 2 - return x[..., :half] * scale, x[..., half:] * scale - - x = torch.randn(4, 8, requires_grad=True) - scale = 2.0 - - 
out1, out2 = _split_scale_op(x, scale) - loss = out1.sum() + out2.sum() - loss.backward() - - expected_grad = torch.ones_like(x) * scale - assert_close(x.grad, expected_grad) - - class TestAutoGeneratedName: """Tests for auto-generated operator name when name is not provided.""" @@ -655,5 +536,1596 @@ def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: assert_close(out2, x[..., 4:] * 2) +@dataclasses.dataclass(frozen=True) +class _CSCfg: + s: float + + +class TestDataclassWithComputeSensitiveSmoke: + """Dataclass input + ``is_compute_sensitive=True`` should register cleanly + and the op name should land in ``custom_compute_sensitive_ops``. + """ + + def test_dataclass_compute_sensitive_smoke(self): + from magi_compiler.config import get_compile_config + + @magi_register_custom_op(name="test::dc_compute_sensitive", is_compute_sensitive=True) + def _op(x: torch.Tensor, cfg: _CSCfg) -> torch.Tensor: + return x * cfg.s + + out = _op(torch.ones(2), _CSCfg(s=2.0)) + assert_close(out, torch.full((2,), 2.0)) + assert "test::dc_compute_sensitive" in get_compile_config().recompute_config.custom_compute_sensitive_ops + + +class TestFrozenDataclassInput: + """Tests for frozen-dataclass input support via the auto pytree path.""" + + def test_forward_with_dataclass_only(self): + """Custom op whose only argument is a frozen dataclass.""" + from dataclasses import dataclass + + @dataclass(frozen=True) + class _ScaleCfg: + scale: float + bias: float + + @magi_register_custom_op(name="test::dc_only_op", mutates_args=()) + def _dc_only_op(cfg: _ScaleCfg) -> torch.Tensor: + return torch.full((2, 3), cfg.scale + cfg.bias) + + cfg = _ScaleCfg(scale=2.0, bias=1.0) + out = _dc_only_op(cfg) + assert_close(out, torch.full((2, 3), 3.0)) + + def test_forward_with_mixed_args(self): + """Custom op with mixed dataclass and tensor inputs.""" + from dataclasses import dataclass + + @dataclass(frozen=True) + class _AttnCfg: + scale: float + causal: bool + + 
@magi_register_custom_op(name="test::dc_mixed_op", mutates_args=()) + def _dc_mixed_op(q: torch.Tensor, k: torch.Tensor, cfg: _AttnCfg) -> torch.Tensor: + out = (q @ k.transpose(-1, -2)) * cfg.scale + if cfg.causal: + mask = torch.tril(torch.ones_like(out, dtype=torch.bool)) + out = out.masked_fill(~mask, 0.0) + return out + + q = torch.randn(2, 4, 8) + k = torch.randn(2, 4, 8) + cfg = _AttnCfg(scale=0.25, causal=False) + out = _dc_mixed_op(q, k, cfg) + expected = (q @ k.transpose(-1, -2)) * 0.25 + assert_close(out, expected) + + # Try positional + keyword call form to ensure outer wrapper handles both. + out_kw = _dc_mixed_op(q, k=k, cfg=cfg) + assert_close(out_kw, expected) + + def test_forward_with_custom_meta_fn(self): + """Custom op with a user-provided meta function expressed in dataclass terms.""" + from dataclasses import dataclass + + @dataclass(frozen=True) + class _ProjCfg: + out_dim: int + + def _proj_meta(x: torch.Tensor, cfg: _ProjCfg) -> torch.Tensor: + return x.new_empty((*x.shape[:-1], cfg.out_dim)) + + @magi_register_custom_op(name="test::dc_meta_op", mutates_args=(), infer_output_meta_fn=_proj_meta) + def _dc_meta_op(x: torch.Tensor, cfg: _ProjCfg) -> torch.Tensor: + return x[..., : cfg.out_dim].clone() + + x = torch.randn(2, 8) + cfg = _ProjCfg(out_dim=3) + out = _dc_meta_op(x, cfg) + assert out.shape == (2, 3) + assert_close(out, x[..., :3]) + + +class TestNestedDataclassInput: + """Tests for *recursively* flattened nested-dataclass inputs.""" + + def test_nested_dataclass_only(self): + """A single dataclass argument that itself contains a dataclass field.""" + from dataclasses import dataclass + + @dataclass(frozen=True) + class _Inner: + scale: float + bias: float + + @dataclass(frozen=True) + class _Outer: + inner: _Inner + offset: float + + @magi_register_custom_op(name="test::nested_dc_only_op", mutates_args=()) + def _nested_dc_only_op(cfg: _Outer) -> torch.Tensor: + return torch.full((2, 3), cfg.inner.scale * cfg.inner.bias + cfg.offset) 
+ + cfg = _Outer(inner=_Inner(scale=2.0, bias=3.0), offset=1.5) + out = _nested_dc_only_op(cfg) + assert_close(out, torch.full((2, 3), 7.5)) + + # The param mapping tree should fully expand the nested dataclass into leaves. + param_mapping_tree = _nested_dc_only_op._magi_param_mapping_tree + assert param_mapping_tree[0][0] == "dataclass" + assert param_mapping_tree[0][1] == "cfg" + children = param_mapping_tree[0][3] + kinds = [c[0] for c in children] + assert "dataclass" in kinds, "inner dataclass field must remain a dataclass node" + # Find the leaf lowered names. + lowered_names: list[str] = [] + + def _collect(node): + if node[0] == "primitive": + lowered_names.append(node[2]) + else: + for child in node[3]: + _collect(child) + + _collect(param_mapping_tree[0]) + assert "cfg__inner__scale" in lowered_names + assert "cfg__inner__bias" in lowered_names + assert "cfg__offset" in lowered_names + + def test_nested_dataclass_mixed_with_tensor(self): + """Tensor arg alongside a nested dataclass arg.""" + from dataclasses import dataclass + + @dataclass(frozen=True) + class _NormCfg: + eps: float + affine: bool + + @dataclass(frozen=True) + class _LayerCfg: + norm: _NormCfg + scale: float + + @magi_register_custom_op(name="test::nested_dc_mixed_op", mutates_args=()) + def _nested_dc_mixed_op(x: torch.Tensor, cfg: _LayerCfg) -> torch.Tensor: + y = x / (x.std() + cfg.norm.eps) + if cfg.norm.affine: + y = y * cfg.scale + return y + + x = torch.randn(4, 8) + cfg = _LayerCfg(norm=_NormCfg(eps=1e-3, affine=True), scale=2.0) + out = _nested_dc_mixed_op(x, cfg) + expected = x / (x.std() + 1e-3) * 2.0 + assert_close(out, expected) + + # Keyword form too. 
+ out_kw = _nested_dc_mixed_op(x=x, cfg=cfg) + assert_close(out_kw, expected) + + def test_deeply_nested_dataclass(self): + """Three levels of nesting plus a sibling primitive at the top level.""" + from dataclasses import dataclass + + @dataclass(frozen=True) + class _Leaf: + val: float + + @dataclass(frozen=True) + class _Mid: + leaf: _Leaf + extra: int + + @dataclass(frozen=True) + class _Root: + mid: _Mid + tag: float + + @magi_register_custom_op(name="test::deep_nested_dc_op", mutates_args=()) + def _deep_nested_dc_op(x: torch.Tensor, cfg: _Root, alpha: float) -> torch.Tensor: + return x * cfg.mid.leaf.val + cfg.mid.extra + cfg.tag + alpha + + x = torch.randn(2, 3) + cfg = _Root(mid=_Mid(leaf=_Leaf(val=2.0), extra=3), tag=0.5) + out = _deep_nested_dc_op(x, cfg, alpha=1.0) + assert_close(out, x * 2.0 + 3 + 0.5 + 1.0) + + def test_nested_dataclass_with_meta_fn(self): + """User-supplied meta function expressed in nested-dataclass terms.""" + from dataclasses import dataclass + + @dataclass(frozen=True) + class _ShapeCfg: + out_dim: int + + @dataclass(frozen=True) + class _ProjCfg: + shape: _ShapeCfg + scale: float + + def _proj_meta(x: torch.Tensor, cfg: _ProjCfg) -> torch.Tensor: + return x.new_empty((*x.shape[:-1], cfg.shape.out_dim)) + + @magi_register_custom_op(name="test::nested_dc_meta_op", mutates_args=(), infer_output_meta_fn=_proj_meta) + def _nested_dc_meta_op(x: torch.Tensor, cfg: _ProjCfg) -> torch.Tensor: + return x[..., : cfg.shape.out_dim].clone() * cfg.scale + + x = torch.randn(2, 8) + cfg = _ProjCfg(shape=_ShapeCfg(out_dim=3), scale=2.0) + out = _nested_dc_meta_op(x, cfg) + assert out.shape == (2, 3) + assert_close(out, x[..., :3] * 2.0) + + +class TestDataclassOptionalFields: + """Frozen dataclass fields annotated with ``Optional[...]`` should flow + through the recursive flattener and be accepted by + ``torch.library.infer_schema`` as long as the underlying type is one of + the schema's known optional types. 
+ + Coverage matrix (verified against torch 2.9.1): + + - ``Optional[Tensor]`` ✓ + - ``Optional[int]`` ✓ + - ``Optional[float]`` ✓ + - ``Optional[bool]`` ✓ + - ``Optional[str]`` ✓ + - ``Optional[list[int]]`` ✓ + - ``Tensor | None`` (PEP 604) ✓ + - ``Optional[list[Tensor]]`` ✗ (PyTorch limitation, see negative + test below; users should use ``list[Optional[Tensor]]`` instead) + """ + + def test_optional_scalar_and_tensor_fields(self): + from dataclasses import dataclass + from typing import Optional + + @dataclass(frozen=True) + class _Cfg: + bias: Optional[torch.Tensor] + scale: Optional[float] + mode: Optional[str] + block_sizes: Optional[list[int]] + flag: Optional[bool] + count: Optional[int] + + @magi_register_custom_op(name="test::dc_optional_mix") + def _op(x: torch.Tensor, cfg: _Cfg) -> torch.Tensor: + out = x.clone() + if cfg.bias is not None: + out = out + cfg.bias + if cfg.scale is not None: + out = out * cfg.scale + return out + + x = torch.ones(3) + # All-None: just the identity transform. + out_all_none = _op(x, _Cfg(None, None, None, None, None, None)) + assert_close(out_all_none, x) + # Some-None: bias + scale active, others ignored by the body. 
+ cfg = _Cfg(bias=torch.tensor([1.0, 2.0, 3.0]), scale=2.0, mode="a", block_sizes=[4, 8], flag=True, count=7) + out_some = _op(x, cfg) + assert_close(out_some, torch.tensor([4.0, 6.0, 8.0])) + + def test_pep604_optional_tensor_field(self): + """The ``X | None`` PEP 604 syntax should be equivalent to + ``Optional[X]`` for dataclass fields.""" + from dataclasses import dataclass + + @dataclass(frozen=True) + class _Cfg: + bias: torch.Tensor | None + + @magi_register_custom_op(name="test::dc_pep604_optional") + def _op(x: torch.Tensor, cfg: _Cfg) -> torch.Tensor: + return x.clone() if cfg.bias is None else x + cfg.bias + + x = torch.ones(2) + assert_close(_op(x, _Cfg(bias=None)), x) + assert_close(_op(x, _Cfg(bias=torch.tensor([10.0, 20.0]))), torch.tensor([11.0, 21.0])) + + def test_optional_list_of_tensors_unsupported_by_torch_library(self): + """Sanity-pin a known-broken case so a future torch upgrade that + relaxes ``infer_schema`` flips this from xfail to pass. + + ``Optional[list[Tensor]]`` is NOT in the torch.library schema's + accepted-type set (although ``list[Optional[Tensor]]`` is). The + magi flattener doesn't try to rewrite the user's annotation, so we + let the underlying ValueError propagate.""" + from dataclasses import dataclass + from typing import Optional + + @dataclass(frozen=True) + class _Cfg: + biases: Optional[list[torch.Tensor]] + + with pytest.raises(ValueError, match="unsupported type"): + + @magi_register_custom_op(name="test::dc_optional_list_tensor") + def _op(x: torch.Tensor, cfg: _Cfg) -> torch.Tensor: + return x + + +class TestDataclassListFields: + """Frozen dataclass fields annotated with ``list[...]`` flow through the + flattener and are accepted by ``torch.library.infer_schema`` for the + schema's supported scalar/tensor element types. 
+ + Coverage matrix (verified against torch 2.9.1): + + - ``list[Tensor]`` ✓ + - ``list[int]`` ✓ + - ``list[float]`` ✓ + - ``list[bool]`` ✓ + - ``list[Optional[Tensor]]`` ✓ + """ + + def test_list_of_tensors_and_scalars_in_dataclass(self): + from dataclasses import dataclass + + @dataclass(frozen=True) + class _Cfg: + biases: list[torch.Tensor] + scales: list[float] + block_sizes: list[int] + flags: list[bool] + + @magi_register_custom_op(name="test::dc_list_mix") + def _op(x: torch.Tensor, cfg: _Cfg) -> torch.Tensor: + out = x.clone() + for b in cfg.biases: + out = out + b + out = out * float(sum(cfg.scales)) + # ``flags`` and ``block_sizes`` are accepted on the schema side + # even though the body here doesn't depend on them. + assert all(isinstance(s, int) for s in cfg.block_sizes) + assert all(isinstance(f, bool) for f in cfg.flags) + return out + + x = torch.ones(3) + cfg = _Cfg( + biases=[torch.tensor([1.0, 2.0, 3.0]), torch.tensor([10.0, 20.0, 30.0])], + scales=[2.0, 0.5], + block_sizes=[4, 8], + flags=[True, False], + ) + out = _op(x, cfg) + # (1 + 1 + 10) * 2.5 = 30, (1 + 2 + 20) * 2.5 = 57.5, ... 
+ assert_close(out, torch.tensor([30.0, 57.5, 85.0])) + + def test_list_of_optional_tensors_in_dataclass(self): + """``list[Optional[Tensor]]`` is on the schema's supported list (in + contrast to the rejected ``Optional[list[Tensor]]``).""" + from dataclasses import dataclass + from typing import Optional + + @dataclass(frozen=True) + class _Cfg: + maybe_biases: list[Optional[torch.Tensor]] + + @magi_register_custom_op(name="test::dc_list_optional_tensor") + def _op(x: torch.Tensor, cfg: _Cfg) -> torch.Tensor: + out = x.clone() + for b in cfg.maybe_biases: + if b is not None: + out = out + b + return out + + x = torch.ones(2) + out = _op(x, _Cfg(maybe_biases=[None, torch.tensor([10.0, 20.0]), None])) + assert_close(out, torch.tensor([11.0, 21.0])) + + def test_empty_list_field(self): + """An empty ``list[Tensor]`` value should round-trip through the + flattener and run the body normally.""" + from dataclasses import dataclass + + @dataclass(frozen=True) + class _Cfg: + biases: list[torch.Tensor] + + @magi_register_custom_op(name="test::dc_empty_list") + def _op(x: torch.Tensor, cfg: _Cfg) -> torch.Tensor: + out = x.clone() + for b in cfg.biases: + out = out + b + return out + + x = torch.ones(3) + assert_close(_op(x, _Cfg(biases=[])), x) + + +class TestDataclassFieldDefaults: + """Dataclass field ``default`` / ``default_factory`` values should + propagate to the flat ``inspect.Signature`` so that calling the op + without those fields works (and ``torch.library.infer_schema`` records + them as optional).""" + + def test_dataclass_field_default_carried_into_signature(self): + from dataclasses import dataclass + + @dataclass(frozen=True) + class _Cfg: + scale: float = 2.0 + + @magi_register_custom_op(name="test::dc_field_default_scale") + def _op(x: torch.Tensor, cfg: _Cfg) -> torch.Tensor: + return x * cfg.scale + + x = torch.ones(3) + # User constructs the dataclass with the default value implicitly. 
+ assert_close(_op(x, _Cfg()), torch.full((3,), 2.0)) + # And the override path still works. + assert_close(_op(x, _Cfg(scale=10.0)), torch.full((3,), 10.0)) + + def test_dataclass_field_default_factory_for_list(self): + from dataclasses import dataclass, field + + @dataclass(frozen=True) + class _Cfg: + block_sizes: list[int] = field(default_factory=lambda: [4, 8]) + + @magi_register_custom_op(name="test::dc_field_default_factory") + def _op(x: torch.Tensor, cfg: _Cfg) -> torch.Tensor: + return x * float(sum(cfg.block_sizes)) + + x = torch.ones(2) + assert_close(_op(x, _Cfg()), torch.full((2,), 12.0)) + assert_close(_op(x, _Cfg(block_sizes=[1, 2, 3])), torch.full((2,), 6.0)) + + def test_nested_dataclass_with_default(self): + from dataclasses import dataclass + + @dataclass(frozen=True) + class _Inner: + val: float = 3.0 + + @dataclass(frozen=True) + class _Outer: + inner: _Inner = _Inner() + scale: float = 1.0 + + @magi_register_custom_op(name="test::dc_nested_default") + def _op(x: torch.Tensor, cfg: _Outer) -> torch.Tensor: + return x * cfg.inner.val * cfg.scale + + x = torch.ones(2) + assert_close(_op(x, _Outer()), torch.full((2,), 3.0)) + assert_close(_op(x, _Outer(inner=_Inner(val=5.0), scale=2.0)), torch.full((2,), 10.0)) + + +class TestDataclassUnsupportedContainerFields: + """``tuple[...]`` and ``dict[...]`` dataclass field annotations are not + accepted by ``torch.library.infer_schema`` (only ``list[...]`` is). 
We + reject them at registration time with an actionable hint pointing to the + field name and suggested fix.""" + + def test_tuple_field_rejected_with_hint(self): + from dataclasses import dataclass + + @dataclass(frozen=True) + class _Cfg: + qkv: tuple[torch.Tensor, torch.Tensor, torch.Tensor] + + with pytest.raises(TypeError, match=r"list\[\.\.\.\]"): + + @magi_register_custom_op(name="test::dc_tuple_field") + def _op(x: torch.Tensor, cfg: _Cfg) -> torch.Tensor: + return x + + def test_dict_field_rejected_with_hint(self): + from dataclasses import dataclass + + @dataclass(frozen=True) + class _Cfg: + weight_by_name: dict[str, torch.Tensor] + + with pytest.raises(TypeError, match="dict-typed"): + + @magi_register_custom_op(name="test::dc_dict_field") + def _op(x: torch.Tensor, cfg: _Cfg) -> torch.Tensor: + return x + + +class TestMutableDataclassRejected: + """``magi_register_custom_op`` only supports ``@dataclass(frozen=True)``. + + A non-frozen dataclass passed as an op input would otherwise produce a + confusing ``ValueError: Unsupported type annotation X. It is not a type`` + deep inside ``torch.library.infer_schema``. We surface a clear error + pointing at the fix. 
+ """ + + def test_top_level_mutable_dataclass_rejected(self): + from dataclasses import dataclass + + @dataclass # NOTE: missing frozen=True + class _MutableCfg: + scale: float + + with pytest.raises(TypeError, match="frozen=True"): + + @magi_register_custom_op(name="test::mutable_dc_top") + def _op(x: torch.Tensor, cfg: _MutableCfg) -> torch.Tensor: + return x * cfg.scale + + def test_nested_mutable_dataclass_rejected(self): + """Inner non-frozen dataclass nested inside an outer frozen one is + also detected; the error names the offending field.""" + from dataclasses import dataclass + + @dataclass # NOT frozen + class _MutableInner: + val: float + + @dataclass(frozen=True) + class _OuterCfg: + inner: _MutableInner + + with pytest.raises(TypeError, match="_MutableInner"): + + @magi_register_custom_op(name="test::mutable_dc_nested") + def _op(x: torch.Tensor, cfg: _OuterCfg) -> torch.Tensor: + return x * cfg.inner.val + + def test_frozen_dataclass_still_works(self): + """Sanity: the rejection path doesn't break the supported case.""" + from dataclasses import dataclass + + @dataclass(frozen=True) + class _OkCfg: + scale: float + + @magi_register_custom_op(name="test::frozen_dc_ok") + def _op(x: torch.Tensor, cfg: _OkCfg) -> torch.Tensor: + return x * cfg.scale + + out = _op(torch.ones(4), _OkCfg(scale=3.0)) + assert_close(out, torch.full((4,), 3.0)) + + +class TestDataclassReturnRejected: + """``torch.library`` cannot model dataclass returns; we should refuse with + an actionable error instead of letting ``infer_schema`` raise an opaque + ``ValueError`` from the depths. 
+ """ + + def test_dataclass_return_rejected(self): + from dataclasses import dataclass + + @dataclass(frozen=True) + class _Out: + a: torch.Tensor + + with pytest.raises(TypeError, match="dataclass"): + + @magi_register_custom_op(name="test::dc_return") + def _op(x: torch.Tensor) -> _Out: + return _Out(a=x.clone()) + + def test_tensor_return_still_works(self): + @magi_register_custom_op(name="test::tensor_return_ok") + def _op(x: torch.Tensor) -> torch.Tensor: + return x + 1 + + out = _op(torch.zeros(3)) + assert_close(out, torch.ones(3)) + + +class TestDataclassUnresolvableLocalField: + """A dataclass field annotated with a class defined inside another + function (a "local class") under ``from __future__ import annotations`` + cannot be resolved by ``typing.get_type_hints``. We surface a clear + error pointing at the fix instead of a confusing schema error. + """ + + def test_unresolvable_local_field_rejected(self): + # We synthesize the failure by hand-constructing a frozen dataclass + # whose field type is a *string* that doesn't resolve in any visible + # namespace. This mirrors what ``from __future__ import annotations`` + # produces for a local class. + import dataclasses + + @dataclasses.dataclass(frozen=True) + class _OuterCfg: + inner: "_DefinitelyNotInScope" # noqa: F821 + + with pytest.raises(TypeError, match="unresolved string annotation"): + + @magi_register_custom_op(name="test::dc_unresolvable_local") + def _op(x: torch.Tensor, cfg: _OuterCfg) -> torch.Tensor: + return x + + +class TestPep563LocalDataclassResolved: + """Happy-path counterpart of ``TestDataclassUnresolvableLocalField``: when + a function is defined under ``from __future__ import annotations`` and + references a *local* dataclass that's actually visible in the enclosing + closure, ``_resolve_annotations``'s fallback path (the one that combines + ``fn.__globals__`` with ``inspect.getclosurevars(fn).nonlocals``) must + succeed and the op must register & run normally. 
+ + Why this needs ``exec`` to set up: the test file as a whole does NOT + enable ``from __future__ import annotations`` (that would silently + string-ify every other test's annotations, a high-blast-radius change). + We therefore compile a small standalone snippet WITH that future flag + and inject a local frozen dataclass into its exec namespace. This + reproduces the real-world bug class -- "user wrote PEP 563 + factory + function returning an op" -- without contaminating the rest of the file. + + Without this test the resolver fallback (``_resolve_annotations``, + namespace-merging branch) has no regression net: dropping + ``cv.nonlocals`` from the merge would still leave every other test + green. + """ + + def test_local_dataclass_resolved_via_closure_fallback(self): + import dataclasses + import textwrap + + # Step 1: build a frozen dataclass that will live ONLY in the exec + # namespace -- it is intentionally NOT importable from any module, + # so ``typing.get_type_hints`` (which can only see module globals) + # MUST fail and we MUST fall through to the closurevars-based path. + @dataclasses.dataclass(frozen=True) + class _LocalCfg: + scale: float + bias: float + + # Step 2: source for a factory function whose annotations get + # stringified by ``from __future__ import annotations``. ``_LocalCfg`` + # is referenced both in the annotation (as a string) AND in the + # function body (as a real call) -- the body reference is what makes + # ``inspect.getclosurevars`` report ``_LocalCfg`` in ``.nonlocals``, + # which is exactly the namespace lookup we want to exercise. 
+ src = textwrap.dedent( + """ + from __future__ import annotations + + def make_op(register, name, _LocalCfg): + @register(name=name) + def _op(x: torch.Tensor, cfg: _LocalCfg) -> torch.Tensor: + # body reference forces _LocalCfg into cv.nonlocals + _ = _LocalCfg + return x * cfg.scale + cfg.bias + return _op + """ + ) + + ns: dict = {"torch": torch} + exec(compile(src, "", "exec"), ns) + + op = ns["make_op"](magi_register_custom_op, "test::dc_pep563_local_resolved", _LocalCfg) + + x = torch.tensor([1.0, 2.0, 3.0]) + out = op(x, _LocalCfg(scale=2.0, bias=1.0)) + assert_close(out, torch.tensor([3.0, 5.0, 7.0])) + + +@dataclasses.dataclass(frozen=True) +class _OptScalarCfg: + softcap: float | None = None + causal: bool | None = None + + +class TestDataclassOptionalScalarFields: + """Common attention/cfg pattern: dataclass field is ``Optional[scalar]`` + with a ``None`` default. Verify forward across explicit-None and + explicit-value calls. + """ + + def test_optional_scalar_defaults_and_overrides(self): + @magi_register_custom_op(name="test::dc_opt_scalar") + def _op(x: torch.Tensor, cfg: _OptScalarCfg) -> torch.Tensor: + y = x.clone() + if cfg.softcap is not None: + y = torch.tanh(y / cfg.softcap) * cfg.softcap + if cfg.causal is True: + y = y * 0.5 + return y + + x = torch.randn(4) + assert_close(_op(x, _OptScalarCfg()), x) + out = _op(x, _OptScalarCfg(softcap=2.0, causal=True)) + expected = torch.tanh(x / 2.0) * 2.0 * 0.5 + assert_close(out, expected) + + +@dataclasses.dataclass(frozen=True) +class _DtypeCfg: + out_dtype: torch.dtype = torch.float16 + out_device: torch.device | None = None + + +class TestDataclassDtypeDeviceFields: + """``torch.dtype`` / ``torch.device`` are the canonical schema-supported + "structured scalar" types. Make sure they survive the dataclass-flatten + round-trip with sensible defaults. 
+ """ + + def test_dtype_device_round_trip(self): + @magi_register_custom_op(name="test::dc_dtype_device") + def _op(x: torch.Tensor, cfg: _DtypeCfg) -> torch.Tensor: + return x.to(dtype=cfg.out_dtype) + + x = torch.randn(4) + out = _op(x, _DtypeCfg(out_dtype=torch.bfloat16)) + assert out.dtype == torch.bfloat16 + + +class TestTopLevelTupleDictRejected: + """Top-level parameters of type ``tuple[T, ...]`` / ``dict[K, V]`` are + rejected with the same actionable message as dataclass fields. + """ + + def test_top_level_tuple_rejected(self): + with pytest.raises(TypeError, match="tuple"): + + @magi_register_custom_op(name="test::top_tuple") + def _op(xs: tuple[torch.Tensor, ...]) -> torch.Tensor: + return xs[0] + + def test_top_level_dict_rejected(self): + with pytest.raises(TypeError, match="dict"): + + @magi_register_custom_op(name="test::top_dict") + def _op(xs: dict[str, torch.Tensor]) -> torch.Tensor: + return next(iter(xs.values())) + + +class TestMissingAnnotationRejected: + """Every parameter on the registered op (and its return value) must have + a type annotation; ``*args`` / ``**kwargs`` are not supported either. + + Without an explicit guard these mistakes still fail at registration time + but with a confusing low-level error from ``torch.library.infer_schema``. + ``magi_register_custom_op`` re-raises them as a clear ``TypeError`` with + an actionable hint and a ``where=`` locator. 
+ """ + + def test_op_parameter_without_annotation_rejected(self): + with pytest.raises(TypeError, match="no type annotation"): + + @magi_register_custom_op(name="test::missing_param_ann") + def _op(x: torch.Tensor, scale) -> torch.Tensor: + return x * scale + + def test_op_first_parameter_without_annotation_rejected(self): + with pytest.raises(TypeError, match=r"parameter 'x'"): + + @magi_register_custom_op(name="test::missing_first_param_ann") + def _op(x, y: torch.Tensor) -> torch.Tensor: + return x + y + + def test_op_return_annotation_missing_rejected(self): + with pytest.raises(TypeError, match="return value"): + + @magi_register_custom_op(name="test::missing_return_ann") + def _op(x: torch.Tensor): + return x + + def test_var_positional_rejected(self): + with pytest.raises(TypeError, match=r"\*xs"): + + @magi_register_custom_op(name="test::varargs_rejected") + def _op(*xs: torch.Tensor) -> torch.Tensor: + return xs[0] + + def test_var_keyword_rejected(self): + with pytest.raises(TypeError, match=r"\*\*kwargs"): + + @magi_register_custom_op(name="test::varkw_rejected") + def _op(x: torch.Tensor, **kwargs: float) -> torch.Tensor: + return x + + def test_fully_annotated_op_still_works(self): + """Sanity check: the new guards do not regress the happy path.""" + + @magi_register_custom_op(name="test::fully_annotated_ok") + def _op(x: torch.Tensor, scale: float) -> torch.Tensor: + return x * scale + + x = torch.tensor([1.0, 2.0]) + assert_close(_op(x, scale=2.0), torch.tensor([2.0, 4.0])) + + +class TestLiteralAndEnumDowngrade: + """``Literal[str, ...]`` and string-valued ``Enum`` annotations are + auto-downgraded to ``str`` at the schema boundary; the op body still + receives the raw string. Numeric/heterogeneous Literals are rejected + with a clear error. 
+ """ + + def test_literal_str_downgraded(self): + from typing import Literal + + @magi_register_custom_op(name="test::literal_str") + def _op(x: torch.Tensor, mode: Literal["a", "b"]) -> torch.Tensor: + return x * (2.0 if mode == "a" else 3.0) + + x = torch.ones(3) + assert_close(_op(x, "a"), torch.full((3,), 2.0)) + assert_close(_op(x, "b"), torch.full((3,), 3.0)) + + def test_string_enum_downgraded(self): + import enum + + class _Mode(enum.Enum): + A = "a" + B = "b" + + @magi_register_custom_op(name="test::enum_str") + def _op(x: torch.Tensor, mode: _Mode) -> torch.Tensor: + return x * (2.0 if mode == "a" else 3.0) + + # Note: caller passes the raw string value (the schema sees ``str``). + x = torch.ones(3) + assert_close(_op(x, "a"), torch.full((3,), 2.0)) + + def test_int_literal_rejected(self): + from typing import Literal + + with pytest.raises(TypeError, match="Literal"): + + @magi_register_custom_op(name="test::literal_int") + def _op(x: torch.Tensor, k: Literal[1, 2]) -> torch.Tensor: + return x + + def test_int_valued_enum_rejected(self): + import enum + + class _IntMode(enum.Enum): + A = 1 + B = 2 + + with pytest.raises(TypeError, match="Enum"): + + @magi_register_custom_op(name="test::enum_int") + def _op(x: torch.Tensor, mode: _IntMode) -> torch.Tensor: + return x + + +class TestTopLevelOptionalTensor: + """``Optional[Tensor]`` / ``Tensor | None`` is the most common attention / + bias / mask pattern. Verify both annotation flavours, with and without + a ``None`` default, and across forward + autograd paths. 
+ """ + + def test_optional_tensor_default_none(self): + @magi_register_custom_op(name="test::opt_tensor_default") + def _op(x: torch.Tensor, bias: torch.Tensor | None = None) -> torch.Tensor: + return x.clone() if bias is None else x + bias + + x = torch.randn(3, 4) + b = torch.randn(3, 4) + assert_close(_op(x), x) + assert_close(_op(x, b), x + b) + + def test_optional_tensor_no_default_explicit_none(self): + @magi_register_custom_op(name="test::opt_tensor_explicit") + def _op(x: torch.Tensor, bias: torch.Tensor | None) -> torch.Tensor: + return x.clone() if bias is None else x + bias + + x = torch.randn(3, 4) + assert_close(_op(x, None), x) + + +class TestReturnListTensor: + """``list[Tensor]`` is the canonical return shape for split / chunk / + expert-output ops. Validate forward and that a user-supplied + ``infer_output_meta_fn`` is required for tracing. + """ + + def test_list_tensor_return_runtime(self): + def meta(x, n): + return [torch.empty_like(x.chunk(n, dim=0)[0]) for _ in range(n)] + + @magi_register_custom_op(name="test::list_tensor_return", infer_output_meta_fn=meta) + def _op(x: torch.Tensor, n: int) -> list[torch.Tensor]: + # ``.chunk`` returns views that alias ``x``; clone each chunk so + # the op doesn't violate torch.library's no-aliasing invariant + # for outputs. + return [c.clone() for c in x.chunk(n, dim=0)] + + x = torch.randn(8, 4) + out = _op(x, 4) + assert isinstance(out, list) + assert len(out) == 4 + assert all(o.shape == (2, 4) for o in out) + + +class TestTopLevelListIntDefault: + """A top-level parameter typed ``list[int]`` with a non-None default + should not crash registration (the previous behavior was to forward the + list literal to ``infer_schema``, which rejects it). 
+ """ + + def test_list_int_default_does_not_crash(self): + @magi_register_custom_op(name="test::list_int_default") + def _op(x: torch.Tensor, dims: list[int] = [0]) -> torch.Tensor: + return x.sum(dim=dims, keepdim=True) + + x = torch.randn(3, 4) + # Caller still has to supply the list explicitly because the schema + # default was scrubbed; that's the documented price of using a list + # default with torch.library. + out = _op(x, [0]) + assert out.shape == (1, 4) + + +class TestAutograd: + """Tests for custom op with autograd support.""" + + def test_with_autograd(self): + """Test custom op with setup_context and backward functions.""" + + def _square_infer_output_meta(x: torch.Tensor) -> torch.Tensor: + return torch.empty_like(x) + + def _square_setup_context(ctx, inputs, output): + (x,) = inputs + ctx.save_for_backward(x) + + def _square_backward(ctx, grad_output): + (x,) = ctx.saved_tensors + return grad_output * 2 * x + + @magi_register_custom_op( + name="test::square_op", + mutates_args=(), + infer_output_meta_fn=_square_infer_output_meta, + setup_context_fn=_square_setup_context, + backward_fn=_square_backward, + ) + def _square_op(x: torch.Tensor) -> torch.Tensor: + return x * x + + x = torch.randn(4, 8, requires_grad=True) + output = _square_op(x) + loss = output.sum() + loss.backward() + + # Gradient of x^2 is 2x + expected_grad = 2 * x + assert_close(x.grad, expected_grad) + + def test_autograd_multiple_inputs(self): + """Test autograd with multiple input tensors.""" + + def _weighted_sum_infer_output_meta(a: torch.Tensor, b: torch.Tensor, weight: float) -> torch.Tensor: + return torch.empty_like(a) + + def _weighted_sum_setup_context(ctx, inputs, output): + a, b, weight = inputs + ctx.save_for_backward(a, b) + ctx.weight = weight + + def _weighted_sum_backward(ctx, grad_output): + a, b = ctx.saved_tensors + weight = ctx.weight + grad_a = grad_output * weight + grad_b = grad_output * (1 - weight) + return grad_a, grad_b, None # None for non-tensor input 
+ + @magi_register_custom_op( + name="test::weighted_sum_op", + mutates_args=(), + infer_output_meta_fn=_weighted_sum_infer_output_meta, + setup_context_fn=_weighted_sum_setup_context, + backward_fn=_weighted_sum_backward, + ) + def _weighted_sum_op(a: torch.Tensor, b: torch.Tensor, weight: float) -> torch.Tensor: + return a * weight + b * (1 - weight) + + a = torch.randn(4, 8, requires_grad=True) + b = torch.randn(4, 8, requires_grad=True) + weight = 0.7 + + output = _weighted_sum_op(a, b, weight) + loss = output.sum() + loss.backward() + + expected_grad_a = torch.ones_like(a) * weight + expected_grad_b = torch.ones_like(b) * (1 - weight) + + assert_close(a.grad, expected_grad_a) + assert_close(b.grad, expected_grad_b) + + def test_autograd_multiple_outputs(self): + """Test autograd with multiple output tensors.""" + + def _split_scale_infer_output_meta(x: torch.Tensor, scale: float) -> tuple[torch.Tensor, torch.Tensor]: + half = x.shape[-1] // 2 + return (x.new_empty((*x.shape[:-1], half)), x.new_empty((*x.shape[:-1], half))) + + def _split_scale_setup_context(ctx, inputs, output): + x, scale = inputs + ctx.save_for_backward(x) + ctx.scale = scale + ctx.half = x.shape[-1] // 2 + + def _split_scale_backward(ctx, grad_out1, grad_out2): + (x,) = ctx.saved_tensors + scale = ctx.scale + # Reconstruct gradient for x + grad_x = torch.cat([grad_out1 * scale, grad_out2 * scale], dim=-1) + return grad_x, None + + @magi_register_custom_op( + name="test::split_scale_op", + mutates_args=(), + infer_output_meta_fn=_split_scale_infer_output_meta, + setup_context_fn=_split_scale_setup_context, + backward_fn=_split_scale_backward, + ) + def _split_scale_op(x: torch.Tensor, scale: float) -> tuple[torch.Tensor, torch.Tensor]: + half = x.shape[-1] // 2 + return x[..., :half] * scale, x[..., half:] * scale + + x = torch.randn(4, 8, requires_grad=True) + scale = 2.0 + + out1, out2 = _split_scale_op(x, scale) + loss = out1.sum() + out2.sum() + loss.backward() + + expected_grad = 
torch.ones_like(x) * scale + assert_close(x.grad, expected_grad) + + +class TestDataclassInputWithBackward: + """``backward_fn`` is bridged so users can write it in the ORIGINAL + (dataclass-bearing) signature even though the registered op underneath + is flat. These tests pin every shape of bridging we promise. + """ + + def test_backward_with_tensor_and_dataclass_input(self): + """``backward`` returns ``(grad_x, None)`` against the original sig.""" + from dataclasses import dataclass + + @dataclass(frozen=True) + class _ScaleCfg: + scale: float + + def _setup(ctx, inputs, output): + x, cfg = inputs # NOTE: original-signature view, not flat + assert isinstance(cfg, _ScaleCfg) + ctx.save_for_backward(x) + ctx.scale = cfg.scale + + def _bwd(ctx, grad_out): + (_x,) = ctx.saved_tensors + return grad_out * ctx.scale, None + + @magi_register_custom_op(name="test::dc_bwd_basic", mutates_args=(), setup_context_fn=_setup, backward_fn=_bwd) + def _op(x: torch.Tensor, cfg: _ScaleCfg) -> torch.Tensor: + return x * cfg.scale + + x = torch.randn(4, 8, requires_grad=True) + cfg = _ScaleCfg(scale=3.0) + y = _op(x, cfg) + assert_close(y, x * 3.0) + (gx,) = torch.autograd.grad(y.sum(), x) + assert_close(gx, torch.full_like(x, 3.0)) + + def test_backward_returning_bare_grad_is_allowed(self): + """Single-input convenience: returning a bare tensor instead of a + 1-tuple should still work, mirroring stock PyTorch behaviour.""" + from dataclasses import dataclass + + @dataclass(frozen=True) + class _Cfg: + alpha: float + + def _setup(ctx, inputs, output): + (cfg,) = inputs + ctx.alpha = cfg.alpha + + def _bwd(ctx, grad_out): + # Op has zero tensor inputs; nothing to backprop into. Returning + # ``None`` (or even a bare value) for the lone non-tensor input + # must be accepted by the bridge. 
+ return None + + @magi_register_custom_op(name="test::dc_bwd_bare_grad", mutates_args=(), setup_context_fn=_setup, backward_fn=_bwd) + def _op(cfg: _Cfg) -> torch.Tensor: + return torch.full((2, 3), cfg.alpha) + + out = _op(_Cfg(alpha=1.5)) + assert_close(out, torch.full((2, 3), 1.5)) + + def test_backward_with_per_field_grads(self): + """If the user wants to express grads at field granularity, returning + a same-shape dataclass (or a dict) for the dataclass slot must be + spread to the underlying flat slots. The bridge accepts ``None`` + leaves to stand in for non-differentiable fields.""" + from dataclasses import dataclass + + @dataclass(frozen=True) + class _Cfg: + scale: float + offset: float + + def _setup(ctx, inputs, output): + x, cfg = inputs + ctx.save_for_backward(x) + ctx.scale = cfg.scale + + def _bwd(ctx, grad_out): + # Express the dataclass grad as a same-shape dataclass with + # ``None`` leaves; the bridge must spread these to flat slots. + return grad_out * ctx.scale, _Cfg(scale=None, offset=None) + + @magi_register_custom_op(name="test::dc_bwd_per_field", mutates_args=(), setup_context_fn=_setup, backward_fn=_bwd) + def _op(x: torch.Tensor, cfg: _Cfg) -> torch.Tensor: + return x * cfg.scale + cfg.offset + + x = torch.randn(3, requires_grad=True) + y = _op(x, _Cfg(scale=2.0, offset=0.5)) + (gx,) = torch.autograd.grad(y.sum(), x) + assert_close(gx, torch.full_like(x, 2.0)) + + def test_backward_with_dict_shaped_dc_grad(self): + """The bridge should also accept a plain ``dict`` for the dataclass + slot (handy when the user does not want to construct another instance).""" + from dataclasses import dataclass + + @dataclass(frozen=True) + class _Cfg: + scale: float + + def _setup(ctx, inputs, output): + x, cfg = inputs + ctx.save_for_backward(x) + ctx.scale = cfg.scale + + def _bwd(ctx, grad_out): + return grad_out * ctx.scale, {"scale": None} + + @magi_register_custom_op(name="test::dc_bwd_dict_grad", mutates_args=(), setup_context_fn=_setup, 
backward_fn=_bwd) + def _op(x: torch.Tensor, cfg: _Cfg) -> torch.Tensor: + return x * cfg.scale + + x = torch.randn(5, requires_grad=True) + y = _op(x, _Cfg(scale=4.0)) + (gx,) = torch.autograd.grad(y.sum(), x) + assert_close(gx, torch.full_like(x, 4.0)) + + def test_backward_with_nested_dataclass(self): + """Backward bridging must descend through nested dataclasses (the + flat slot count for a nested dc with ``None`` grad equals the total + number of leaf fields).""" + from dataclasses import dataclass + + @dataclass(frozen=True) + class _Inner: + scale: float + bias: float + + @dataclass(frozen=True) + class _Outer: + inner: _Inner + tag: float + + def _setup(ctx, inputs, output): + x, cfg = inputs + assert isinstance(cfg, _Outer) and isinstance(cfg.inner, _Inner) + ctx.save_for_backward(x) + ctx.scale = cfg.inner.scale + + def _bwd(ctx, grad_out): + return grad_out * ctx.scale, None # whole nested dc => None + + @magi_register_custom_op(name="test::dc_bwd_nested", mutates_args=(), setup_context_fn=_setup, backward_fn=_bwd) + def _op(x: torch.Tensor, cfg: _Outer) -> torch.Tensor: + return x * cfg.inner.scale + cfg.inner.bias + cfg.tag + + x = torch.randn(2, 4, requires_grad=True) + cfg = _Outer(inner=_Inner(scale=1.5, bias=0.25), tag=0.0) + y = _op(x, cfg) + (gx,) = torch.autograd.grad(y.sum(), x) + assert_close(gx, torch.full_like(x, 1.5)) + + def test_backward_two_tensor_inputs_around_dataclass(self): + """Two tensor inputs sandwiching a dataclass: both gradients must + land in the right flat slots.""" + from dataclasses import dataclass + + @dataclass(frozen=True) + class _Cfg: + alpha: float + beta: float + + def _setup(ctx, inputs, output): + a, cfg, b = inputs + ctx.save_for_backward(a, b) + ctx.alpha = cfg.alpha + ctx.beta = cfg.beta + + def _bwd(ctx, grad_out): + a, b = ctx.saved_tensors + return grad_out * ctx.alpha, None, grad_out * ctx.beta + + @magi_register_custom_op(name="test::dc_bwd_sandwich", mutates_args=(), setup_context_fn=_setup, 
backward_fn=_bwd) + def _op(a: torch.Tensor, cfg: _Cfg, b: torch.Tensor) -> torch.Tensor: + return a * cfg.alpha + b * cfg.beta + + a = torch.randn(4, requires_grad=True) + b = torch.randn(4, requires_grad=True) + cfg = _Cfg(alpha=2.0, beta=-3.0) + y = _op(a, cfg, b) + ga, gb = torch.autograd.grad(y.sum(), [a, b]) + assert_close(ga, torch.full_like(a, 2.0)) + assert_close(gb, torch.full_like(b, -3.0)) + + def test_backward_wrong_grad_count_raises(self): + """If ``backward_fn`` returns the wrong number of grads (relative to + the ORIGINAL signature), the user gets a clear error rather than an + opaque autograd shape mismatch.""" + from dataclasses import dataclass + + @dataclass(frozen=True) + class _Cfg: + scale: float + + def _setup(ctx, inputs, output): + x, cfg = inputs + ctx.save_for_backward(x) + + def _bad_bwd(ctx, grad_out): + return (grad_out,) # missing the dataclass slot + + @magi_register_custom_op( + name="test::dc_bwd_wrong_count", mutates_args=(), setup_context_fn=_setup, backward_fn=_bad_bwd + ) + def _op(x: torch.Tensor, cfg: _Cfg) -> torch.Tensor: + return x * cfg.scale + + x = torch.randn(3, requires_grad=True) + y = _op(x, _Cfg(scale=2.0)) + with pytest.raises(ValueError, match=r"backward_fn returned 1 grad"): + torch.autograd.grad(y.sum(), x) + + +@dataclasses.dataclass(frozen=True) +class _BwTupleCfg: + s: float + + +class TestTupleOutputBackward: + """``backward_fn`` with ``tuple`` outputs and a dataclass input.""" + + def test_two_output_backward_with_dataclass(self): + def setup(ctx, inputs, output): + x, _cfg = inputs + ctx.save_for_backward(x) + ctx.s = _cfg.s + + def bwd(ctx, gy0, gy1): + (x,) = ctx.saved_tensors + # d(out0)/dx = s, d(out1)/dx = 1 + return gy0 * ctx.s + gy1, None + + @magi_register_custom_op(name="test::tuple_bwd_dc", setup_context_fn=setup, backward_fn=bwd) + def _op(x: torch.Tensor, cfg: _BwTupleCfg) -> tuple[torch.Tensor, torch.Tensor]: + return x * cfg.s, x.clone() + + x = torch.randn(4, requires_grad=True) + y0, y1 = 
_op(x, _BwTupleCfg(s=3.0)) + (y0.sum() + 2.0 * y1.sum()).backward() + assert x.grad is not None + # gy0 = 1, gy1 = 2 -> grad = 1*3 + 2 = 5 + assert torch.allclose(x.grad, torch.full_like(x, 5.0)) + + +class TestPerFieldNoneGrad: + """The autograd bridge supports returning a partially-None grad for a + dataclass input (i.e. some Tensor fields differentiable, others not). + """ + + def test_partial_none_grad_for_dataclass(self): + @dataclasses.dataclass(frozen=True) + class _Cfg: + w: torch.Tensor + b: torch.Tensor # we'll mark this non-differentiable + + def setup(ctx, inputs, output): + x, cfg = inputs + ctx.save_for_backward(x, cfg.w) + + def bwd(ctx, gy): + x, w = ctx.saved_tensors + # x: gx = gy * w; cfg.w: gw = gy * x; cfg.b: None. + return gy * w, {"w": gy * x, "b": None} + + @magi_register_custom_op(name="test::partial_none_grad", setup_context_fn=setup, backward_fn=bwd) + def _op(x: torch.Tensor, cfg: _Cfg) -> torch.Tensor: + return x * cfg.w + cfg.b + + x = torch.randn(4, requires_grad=True) + w = torch.randn(4, requires_grad=True) + b = torch.randn(4, requires_grad=False) + cfg = _Cfg(w=w, b=b) + out = _op(x, cfg) + out.sum().backward() + assert torch.allclose(x.grad, w) + assert torch.allclose(w.grad, x) + # b did not require grad; nothing to assert beyond "no exception". + + +class TestStructuralDataclassGrad: + """The dataclass-grad slot is matched **structurally by field name** -- the + returned object does NOT have to be an instance of the input dataclass. + Any object that exposes the same field names (the input dataclass itself, + a different dataclass, a ``SimpleNamespace``, etc.) is accepted, and + Tensor leaves are routed to the corresponding flat slots. + """ + + def test_grad_returned_as_same_dataclass(self): + """Baseline: backward_fn returns an instance of the *same* dataclass + type as the input. 
Tensor fields carry real grads, demonstrating that + the dataclass-grad path works end-to-end before we exercise the + structural (foreign-type) variants below.""" + + @dataclasses.dataclass(frozen=True) + class _Cfg: + w: torch.Tensor + b: torch.Tensor + + def setup(ctx, inputs, output): + x, cfg = inputs + ctx.save_for_backward(x, cfg.w) + + def bwd(ctx, gy): + x, w = ctx.saved_tensors + # Use the SAME dataclass type _Cfg to carry the field-level grads. + return gy * w, _Cfg(w=gy * x, b=torch.zeros_like(w)) + + @magi_register_custom_op(name="test::dc_grad_same_type", setup_context_fn=setup, backward_fn=bwd) + def _op(x: torch.Tensor, cfg: _Cfg) -> torch.Tensor: + return x * cfg.w + cfg.b + + x = torch.randn(4, requires_grad=True) + w = torch.randn(4, requires_grad=True) + b = torch.randn(4, requires_grad=True) + cfg = _Cfg(w=w, b=b) + out = _op(x, cfg) + out.sum().backward() + assert torch.allclose(x.grad, w) + assert torch.allclose(w.grad, x) + # b's grad came from _Cfg.b = zeros_like(w), so it should be 0. + assert torch.allclose(b.grad, torch.zeros_like(b)) + + def test_grad_returned_as_different_dataclass(self): + """``backward_fn`` returns a *different* dataclass type that just + happens to share field names with the input dataclass. The bridge + must spread its Tensor fields to the correct flat grad slots based + purely on field-name matching.""" + + @dataclasses.dataclass(frozen=True) + class _Cfg: + w: torch.Tensor + b: torch.Tensor + + @dataclasses.dataclass(frozen=True) + class _CfgGrad: + # Same field names as _Cfg, different class. Distinct identity + # to prove name-only matching (no isinstance check). + w: torch.Tensor + b: torch.Tensor + + def setup(ctx, inputs, output): + x, cfg = inputs + ctx.save_for_backward(x, cfg.w) + + def bwd(ctx, gy): + x, w = ctx.saved_tensors + # Express grads via a foreign dataclass type with matching field + # names. The bridge must NOT require _Cfg specifically. 
+ return gy * w, _CfgGrad(w=gy * x, b=torch.zeros_like(w)) + + @magi_register_custom_op(name="test::dc_grad_foreign_type", setup_context_fn=setup, backward_fn=bwd) + def _op(x: torch.Tensor, cfg: _Cfg) -> torch.Tensor: + return x * cfg.w + cfg.b + + x = torch.randn(4, requires_grad=True) + w = torch.randn(4, requires_grad=True) + b = torch.randn(4, requires_grad=True) + cfg = _Cfg(w=w, b=b) + out = _op(x, cfg) + out.sum().backward() + assert torch.allclose(x.grad, w) + assert torch.allclose(w.grad, x) + # b's grad came from _CfgGrad.b = zeros_like(w), so it should be 0. + assert torch.allclose(b.grad, torch.zeros_like(b)) + + def test_grad_returned_as_simple_namespace(self): + """A non-dataclass object (``types.SimpleNamespace``) that exposes + the dataclass field names should also be accepted.""" + import types as _types + + @dataclasses.dataclass(frozen=True) + class _Cfg: + w: torch.Tensor + b: torch.Tensor + + def setup(ctx, inputs, output): + x, cfg = inputs + ctx.save_for_backward(x, cfg.w) + + def bwd(ctx, gy): + x, w = ctx.saved_tensors + return gy * w, _types.SimpleNamespace(w=gy * x, b=torch.zeros_like(w)) + + @magi_register_custom_op(name="test::dc_grad_namespace", setup_context_fn=setup, backward_fn=bwd) + def _op(x: torch.Tensor, cfg: _Cfg) -> torch.Tensor: + return x * cfg.w + cfg.b + + x = torch.randn(4, requires_grad=True) + w = torch.randn(4, requires_grad=True) + b = torch.randn(4, requires_grad=True) + cfg = _Cfg(w=w, b=b) + out = _op(x, cfg) + out.sum().backward() + assert torch.allclose(x.grad, w) + assert torch.allclose(w.grad, x) + assert torch.allclose(b.grad, torch.zeros_like(b)) + + +class TestBackwardCallsAnotherOp: + """``backward_fn`` is allowed to call other registered ops to compute the + gradient. This is the FlashAttention-style "forward op + backward op" + decomposition. Verify the gradient is correct end-to-end. 
+ """ + + def test_backward_dispatches_to_another_op(self): + @magi_register_custom_op(name="test::matmul_grad_helper") + def matmul_grad_x(gy: torch.Tensor, w: torch.Tensor) -> torch.Tensor: + return gy @ w.t() + + def setup(ctx, inputs, output): + x, w = inputs + ctx.save_for_backward(x, w) + + def bwd(ctx, gy): + x, w = ctx.saved_tensors + return matmul_grad_x(gy, w), x.t() @ gy + + @magi_register_custom_op(name="test::matmul_with_op_bwd", setup_context_fn=setup, backward_fn=bwd) + def matmul(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor: + return x @ w + + x = torch.randn(3, 4, requires_grad=True) + w = torch.randn(4, 5, requires_grad=True) + out = matmul(x, w) + out.sum().backward() + gy = torch.ones_like(out) + assert torch.allclose(x.grad, gy @ w.t()) + assert torch.allclose(w.grad, x.t() @ gy) + + +@dataclasses.dataclass(frozen=True) +class _KwCfg: + s: float + + +class TestKwargOnlyDataclass: + """``cfg`` passed as a kwarg-only argument should still be flattened and + routed correctly through the dataclass-aware bridge. 
+ """ + + def test_kwarg_only_dataclass_call(self): + @magi_register_custom_op(name="test::kwonly_dc") + def _op(x: torch.Tensor, *, cfg: _KwCfg) -> torch.Tensor: + return x * cfg.s + + x = torch.randn(2, 3) + out = _op(x, cfg=_KwCfg(s=4.0)) + assert_close(out, x * 4.0) + + def test_kwarg_only_dataclass_with_backward(self): + def setup(ctx, inputs, output): + x, cfg = inputs + ctx.s = cfg.s + + def bwd(ctx, gy): + return gy * ctx.s, None + + @magi_register_custom_op(name="test::kwonly_dc_bwd", setup_context_fn=setup, backward_fn=bwd) + def _op(x: torch.Tensor, *, cfg: _KwCfg) -> torch.Tensor: + return x * cfg.s + + x = torch.randn(3, requires_grad=True) + out = _op(x, cfg=_KwCfg(s=2.5)) + out.sum().backward() + assert torch.allclose(x.grad, torch.full_like(x, 2.5)) + + +@dataclasses.dataclass(frozen=True) +class _MutCfg: + a: torch.Tensor + b: torch.Tensor + + +class TestMutatesArgsOnDataclass: + """``mutates_args=('cfg',)`` on a dataclass parameter should expand to the + flat leaf names automatically; an unknown name should raise. + """ + + def test_mutates_args_dataclass_expands(self): + # We don't actually mutate (frozen dataclass) -- we just check the + # registration succeeds, which it would not if the name failed to + # expand to flat tensor leaves. 
+ @magi_register_custom_op(name="test::mutates_dc_expand", mutates_args=("cfg",)) + def _op(x: torch.Tensor, cfg: _MutCfg) -> torch.Tensor: + cfg.a.add_(x) # in-place on a tensor field + cfg.b.add_(x) + return cfg.a + cfg.b + + x = torch.ones(3) + cfg = _MutCfg(a=torch.zeros(3), b=torch.zeros(3)) + out = _op(x, cfg) + assert_close(out, torch.full((3,), 2.0)) + + def test_mutates_args_unknown_name_rejected(self): + with pytest.raises(ValueError, match="does not match"): + + @magi_register_custom_op(name="test::mutates_dc_unknown", mutates_args=("does_not_exist",)) + def _op(x: torch.Tensor, cfg: _MutCfg) -> torch.Tensor: + return x + + def test_mutates_args_flat_name_passthrough(self): + """Users may also use the flat name (``cfg__a``) directly.""" + + @magi_register_custom_op(name="test::mutates_dc_flat", mutates_args=("cfg__a",)) + def _op(x: torch.Tensor, cfg: _MutCfg) -> torch.Tensor: + cfg.a.add_(x) + return cfg.a + cfg.b + + x = torch.ones(3) + cfg = _MutCfg(a=torch.zeros(3), b=torch.zeros(3)) + out = _op(x, cfg) + assert_close(out, torch.ones(3)) + + +class TestNestedMagiOpCall: + """Calling one ``@magi_register_custom_op``-registered op from inside + another is the bread-and-butter compositional pattern (e.g. a fused + attention op that internally calls a registered softmax). Verify both + plain-input and dataclass-input compositions. 
+ """ + + def test_plain_op_inside_plain_op(self): + @magi_register_custom_op(name="test::nested_inner_plain") + def inner(x: torch.Tensor) -> torch.Tensor: + return x * 2 + + @magi_register_custom_op(name="test::nested_outer_plain") + def outer(x: torch.Tensor) -> torch.Tensor: + return inner(x) + 1 + + x = torch.randn(4) + assert_close(outer(x), x * 2 + 1) + + def test_dataclass_op_inside_plain_op(self): + @dataclasses.dataclass(frozen=True) + class _Cfg: + s: float + + @magi_register_custom_op(name="test::nested_inner_dc") + def inner(x: torch.Tensor, cfg: _Cfg) -> torch.Tensor: + return x * cfg.s + + @magi_register_custom_op(name="test::nested_outer_calls_dc") + def outer(x: torch.Tensor) -> torch.Tensor: + return inner(x, _Cfg(s=3.0)) + 1 + + x = torch.randn(4) + assert_close(outer(x), x * 3.0 + 1) + + if __name__ == "__main__": pytest.main([__file__, "-v"])