From 9707cdf5279347f1aa006d94395d501088291e76 Mon Sep 17 00:00:00 2001 From: Yunbo Zhang Date: Mon, 27 Apr 2026 19:40:40 +0800 Subject: [PATCH 1/2] [feat] Smart magi-register-op --- magi_compiler/_magi_register_custom_op.py | 1163 +++++++++++- magi_compiler/_triton_introspect.py | 811 +++++++++ magi_compiler/api.py | 247 ++- tests/api_tests/_triton_external_helpers.py | 91 + tests/api_tests/test_register_custom_op.py | 1804 +++++++++++++++++-- tests/api_tests/test_register_triton_op.py | 1482 +++++++++++++++ 6 files changed, 5412 insertions(+), 186 deletions(-) create mode 100644 magi_compiler/_triton_introspect.py create mode 100644 tests/api_tests/_triton_external_helpers.py create mode 100644 tests/api_tests/test_register_triton_op.py diff --git a/magi_compiler/_magi_register_custom_op.py b/magi_compiler/_magi_register_custom_op.py index 3f770ec..fb3ea98 100644 --- a/magi_compiler/_magi_register_custom_op.py +++ b/magi_compiler/_magi_register_custom_op.py @@ -1,4 +1,4 @@ -# Copyright (c) 2025 SandAI. All Rights Reserved. +# Copyright (c) 2026 SandAI. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,13 +12,221 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import dataclasses +import functools import inspect -from typing import Callable, get_args, get_origin +import logging +from typing import Any, Callable, get_args, get_origin import torch +import torch.utils._pytree as pytree +from ._triton_introspect import ( + get_bare_triton_kernels, + get_inner_triton_kernels, + get_referenced_heuristics_kernels, + get_user_wrapped_triton_kernels, + rewrite_fn_with_wrap_triton, +) from .config import get_compile_config +logger = logging.getLogger(__name__) + +_DATACLASS_PYTREE_REGISTERED: set[type] = set() + +# ============================================================================== +# SECTION 1: Type Validation & Schema Error Prevention +# ------------------------------------------------------------------------------ +# These helpers intercept various unsupported edge cases (like tuples in +# dataclasses, Literal/Enum annotations, or returning a Dataclass) and replace +# opaque torch.library internal errors setup with clear, actionable messages. +# ============================================================================== + + +def _is_frozen_dataclass(tp) -> bool: + """Return True if ``tp`` is a frozen dataclass type.""" + return ( + isinstance(tp, type) + and dataclasses.is_dataclass(tp) + and getattr(tp, "__dataclass_params__", None) is not None + and tp.__dataclass_params__.frozen + ) + + +def _assert_not_unsupported_container(tp, *, where: str) -> None: + """Reject container annotations that ``torch.library.infer_schema`` cannot + consume in dataclass fields, with an actionable hint. + + Specifically: + + - ``tuple[T, ...]`` / ``Tuple[T, T]``: schema only models ``list``; + suggest splitting into independent fields or switching to ``list``. + - ``dict[K, V]`` / ``Dict[...]``: not supported by the schema at all. + """ + origin = get_origin(tp) + if origin is tuple: + raise TypeError( + f"@magi_register_custom_op: {where} is annotated with " + f"{tp!r}. 
torch.library does not accept tuple-typed parameters " + f"in op schemas. Either change the annotation to ``list[...]`` " + f"or split the tuple into separate dataclass fields / op args." + ) + if origin is dict: + raise TypeError( + f"@magi_register_custom_op: {where} is annotated with " + f"{tp!r}. torch.library does not accept dict-typed parameters " + f"in op schemas. Promote the values to explicit fields or " + f"split them into separate op arguments." + ) + + +def _assert_op_name_namespaced(op_name: str) -> None: + """Reject ``name=`` values that don't follow the ``namespace::op_name`` + convention. + + ``torch.library.custom_op("my_op", ...)`` (no ``::``) raises a relatively + opaque error that doesn't suggest the fix to a first-time user. We catch + it up front with a clear message that mirrors PyTorch's docs. + """ + if "::" not in op_name: + raise ValueError( + f"@magi_register_custom_op: op name {op_name!r} is missing a " + "namespace. Use ``namespace::op_name`` (e.g. " + "``my_lib::my_op``). Pick a unique namespace for your project to " + "avoid clashing with other libraries." + ) + + +_LITERAL_STRING_DOWNGRADE_HINT = ( + "Use ``str`` and validate the value inside the op body, e.g. " + "``assert mode in ('a', 'b')``." +) + + +def _maybe_downgrade_literal_or_enum(annotation, *, where: str): + """Return a schema-compatible annotation for ``Literal[str, ...]`` / + ``Enum``-of-str inputs, by collapsing them to plain ``str``. + + ``torch.library.infer_schema`` does not understand ``Literal`` or ``Enum`` + annotations, but the underlying op only ever receives the *value* — and + for the common "tag string" case (e.g. ``Literal["fp16", "bf16", "fp8"]`` + or ``class Quant(Enum): FP8 = "fp8"``) collapsing the annotation to + ``str`` is both safe and lossless. The op body still gets the original + string value at runtime. 
+ + Numeric Literals / heterogeneous Literals / Enums whose values aren't all + strings raise a clear ``TypeError`` instead, since those don't have an + obvious safe downgrade. + """ + import enum + import typing + + origin = get_origin(annotation) + # Handle ``Literal["a", "b"]``. + if origin is typing.Literal: + choices = get_args(annotation) + if choices and all(isinstance(c, str) for c in choices): + return str + raise TypeError( + f"@magi_register_custom_op: {where} is annotated with " + f"{annotation!r}. Only ``Literal[str, ...]`` is auto-downgraded " + f"to ``str``; mixed / numeric Literals are not supported by " + f"torch.library schemas. {_LITERAL_STRING_DOWNGRADE_HINT}" + ) + # Handle ``MyEnum`` whose members are all strings. + if isinstance(annotation, type) and issubclass(annotation, enum.Enum): + members = list(annotation) + if members and all(isinstance(m.value, str) for m in members): + return str + raise TypeError( + f"@magi_register_custom_op: {where} is annotated with Enum " + f"{annotation.__name__!r} whose values are not all strings. " + f"torch.library schemas don't support Enum directly. " + f"{_LITERAL_STRING_DOWNGRADE_HINT}" + ) + return annotation + + +def _assert_not_dataclass_return(tp, *, fn_name: str) -> None: + """Reject return-type annotations that ``torch.library.infer_schema`` cannot + consume (notably dataclasses), with an actionable hint. + + Returning a dataclass is a common mistake when users start grouping op + inputs into a config dataclass and want symmetric outputs, but the schema + layer only models ``Tensor`` / ``tuple[Tensor, ...]`` / ``list[Tensor]`` / + ``None``. Without this guard users get a cryptic + ``ValueError: Return has unsupported type`` deep inside ``infer_schema``. + """ + if isinstance(tp, type) and dataclasses.is_dataclass(tp): + raise TypeError( + f"@magi_register_custom_op: function {fn_name!r} is annotated to " + f"return dataclass {tp.__name__!r}. 
torch.library only supports " + "returning Tensor / tuple[Tensor, ...] / list[Tensor]. Either " + "destructure the dataclass into a tuple at the op boundary, or " + "wrap the dataclass-returning logic in a thin Python helper that " + "calls the registered op." + ) + + +def _assert_not_mutable_dataclass(tp, *, where: str) -> None: + """Raise a clear error if ``tp`` is a *non-frozen* dataclass type. + + ``magi_register_custom_op`` only supports ``frozen=True`` dataclasses as + op inputs (and as nested fields). Without this guard the user gets a + confusing ``ValueError: Unsupported type annotation X. It is not a type`` + from ``torch.library.infer_schema`` deep inside the registration call. + + Frozenness is required because ``torch.library`` / Inductor assume the + flattened scalar inputs are stable for the duration of a tracing call; + mutable dataclass instances would also break the pytree node hashing + used by AOTAutograd. + """ + if ( + isinstance(tp, type) + and dataclasses.is_dataclass(tp) + and getattr(tp, "__dataclass_params__", None) is not None + and not tp.__dataclass_params__.frozen + ): + raise TypeError( + f"@magi_register_custom_op: {where} is annotated with mutable " + f"dataclass {tp.__name__!r}. Only @dataclass(frozen=True) is " + f"supported (the schema needs a stable, hashable value). " + f"Add ``frozen=True`` to {tp.__name__}." + ) + + +def _register_dataclass_pytree(cls: type) -> None: + """ + Idempotently register ``cls`` as a pytree node so that TorchDynamo / + AOTAutograd can flatten/unflatten dataclass instances when tracing. + """ + if cls in _DATACLASS_PYTREE_REGISTERED: + return + + field_names = tuple(f.name for f in dataclasses.fields(cls)) + + def _flatten(obj): + return [getattr(obj, n) for n in field_names], field_names + + def _unflatten(values, ctx): + return cls(**dict(zip(ctx, values))) + + try: + pytree.register_pytree_node(cls, _flatten, _unflatten) + except ValueError: + # Already registered elsewhere (e.g. user code). 
Treat as success. + pass + _DATACLASS_PYTREE_REGISTERED.add(cls) + + +# ============================================================================== +# SECTION 2: Meta Function Auto-Generation +# ------------------------------------------------------------------------------ +# Helpers to generate the meta/fake implementation required by `torch.library`. +# Fallbacks to identity_meta_fn (copying input properties to outputs) when +# the user does not provide `infer_output_meta_fn`. +# ============================================================================== + def _get_num_outputs_from_return_annotation(fn: Callable) -> int: """ @@ -165,6 +373,750 @@ def meta_fn(*args, **kwargs): return meta_fn +# ============================================================================== +# SECTION 3: Dataclass Flattening & Schema Bridging +# ------------------------------------------------------------------------------ +# These functions implement the core "dataclass-aware" mapping: recursively +# unpacking `@dataclass(frozen=True)` inputs into primitive scalars/tensors that +# `torch.library` can trace, and reassembling them into Python objects before +# handing them to the user's `backward_fn` or inner forward implementation. +# ============================================================================== + + +def _resolve_annotations(fn: Callable) -> dict[str, Any]: + """Return ``fn``'s annotations, resolving any stringified ones (e.g. when + ``from __future__ import annotations`` is in effect) into real types. + + Falls back to per-annotation resolution if ``get_type_hints`` cannot + resolve every name atomically (which happens when the function is defined + in a local scope, e.g. inside a test method, so its annotations reference + names that live only in the enclosing closure). 
+ """ + import typing + + try: + return typing.get_type_hints(fn) + except Exception: + pass + + # Build a best-effort namespace that combines globals + nonlocal closure + # variables, so we can eval ``cfg: '_LocalDataclass'`` annotations from + # functions defined inside other functions. + fn_globals = getattr(fn, "__globals__", {}) or {} + namespace: dict[str, Any] = dict(fn_globals) + try: + cv = inspect.getclosurevars(fn) + namespace.update(cv.builtins) + namespace.update(cv.nonlocals) + namespace.update(cv.globals) + except Exception: + pass + + anns: dict[str, Any] = {} + raw = getattr(fn, "__annotations__", {}) or {} + for k, v in raw.items(): + if isinstance(v, str): + try: + anns[k] = eval(v, namespace, None) + except Exception: + anns[k] = v + else: + anns[k] = v + return anns + + +def _resolve_dataclass_field_types(cls: type) -> dict[str, Any]: + """Return ``cls``'s field name -> resolved type, with PEP 563 strings + resolved best-effort (so nested dataclass types are real classes). + """ + import typing as _typing + + try: + return _typing.get_type_hints(cls) + except Exception: + # Fallback: take whatever ``dataclasses.fields`` exposes (which may be + # a string under ``from __future__ import annotations``). Best-effort. + return {f.name: f.type for f in dataclasses.fields(cls)} + + +_SCHEMA_DEFAULT_TYPES: tuple[type, ...] = ( + int, + float, + bool, + str, + torch.device, + torch.dtype, +) + + +def _schema_compatible_param_default(default: Any) -> Any: + """Return a default value safe to attach to an ``inspect.Parameter`` that + will be handed to ``torch.library.infer_schema``. + + Same rules as :func:`_schema_compatible_default` but for raw values (used + on the top-level parameter path, where defaults come from the user's + function signature directly rather than from a ``dataclasses.Field``). 
+ """ + if default is inspect.Parameter.empty: + return inspect.Parameter.empty + if default is None or isinstance(default, _SCHEMA_DEFAULT_TYPES): + return default + return inspect.Parameter.empty + + +def _schema_compatible_default(f: "dataclasses.Field") -> Any: + """Return a value safe to attach as ``inspect.Parameter.default`` for the + flat parameter representing dataclass field ``f``. + + ``torch.library.infer_schema`` only renders defaults of ``None`` / + ``int`` / ``float`` / ``bool`` / ``str`` / ``torch.device`` / + ``torch.dtype``; anything else (e.g. a ``list`` from ``default_factory``) + triggers ``"unsupported default value type"``. We therefore drop unsupported + defaults — the outer dataclass instance still carries the real default for + the user, so behaviour is preserved. + """ + if f.default is not dataclasses.MISSING: + d = f.default + if d is None or isinstance(d, _SCHEMA_DEFAULT_TYPES): + return d + return inspect.Parameter.empty + if f.default_factory is not dataclasses.MISSING: # type: ignore[misc] + try: + d = f.default_factory() + except Exception: + return inspect.Parameter.empty + if d is None or isinstance(d, _SCHEMA_DEFAULT_TYPES): + return d + return inspect.Parameter.empty + return inspect.Parameter.empty + + +def _build_dataclass_subplan( + cls: type, attr_name: str, flat_prefix: str +) -> tuple[tuple, list[inspect.Parameter]]: + """Recursively build a (sub-)plan and the corresponding flat parameters + for one frozen-dataclass-typed value. + + ``attr_name`` is the attribute name on the parent dataclass (or the + parameter name on ``fn`` when called for a top-level argument). + + ``flat_prefix`` is the dot-replaced prefix used to build leaf parameter + names. For a top-level dataclass parameter ``cfg`` of type ``OuterCfg`` + with a nested ``inner: InnerCfg(val: float)``, the flat parameter name + for the leaf is ``cfg__inner__val``. 
+ + Returns ``(node, flat_params)`` where ``node`` is the recursive plan node + and ``flat_params`` is the list of leaf ``inspect.Parameter`` objects in + DFS order. + """ + _register_dataclass_pytree(cls) + + field_types = _resolve_dataclass_field_types(cls) + children: list[tuple] = [] + flat_params: list[inspect.Parameter] = [] + + for f in dataclasses.fields(cls): + f_type = field_types.get(f.name, f.type) + child_flat_name = f"{flat_prefix}__{f.name}" + if isinstance(f_type, str): + raise TypeError( + f"@magi_register_custom_op: field {cls.__name__}.{f.name} has " + f"an unresolved string annotation {f_type!r}. This usually " + "happens when the field's type is defined inside a function " + "body (a 'local class') combined with " + "``from __future__ import annotations``. Move the type to " + "module scope, or import it at module scope, so " + "``typing.get_type_hints`` can resolve it." + ) + _assert_not_mutable_dataclass(f_type, where=f"field {cls.__name__}.{f.name}") + if _is_frozen_dataclass(f_type): + sub_node, sub_params = _build_dataclass_subplan( + f_type, attr_name=f.name, flat_prefix=child_flat_name + ) + children.append(sub_node) + flat_params.extend(sub_params) + else: + _assert_not_unsupported_container( + f_type, where=f"field {cls.__name__}.{f.name}" + ) + f_type = _maybe_downgrade_literal_or_enum( + f_type, where=f"field {cls.__name__}.{f.name}" + ) + children.append(("primitive", f.name, child_flat_name, None)) + # Carry the dataclass field's default (or default_factory product) + # over to the flat parameter so torch.library.infer_schema records + # it as optional. ``infer_schema`` only accepts defaults of types + # ``None``, ``int``, ``float``, ``bool``, ``str``, ``torch.device``, + # ``torch.dtype``; other defaults (notably ``list``/``dict`` from + # ``default_factory``) are left as "required" on the flat param. 
+ # Either way the outer wrapper still gets the real default value + # via ``cls(**user_kwargs)`` since users construct the dataclass + # instance themselves. + flat_params.append( + inspect.Parameter( + child_flat_name, + # NOTE: POSITIONAL_OR_KEYWORD because torch.library.custom_op + # does not yet support kwarg-only Tensor arguments. + inspect.Parameter.POSITIONAL_OR_KEYWORD, + default=_schema_compatible_default(f), + annotation=f_type, + ) + ) + + return ("dataclass", attr_name, cls, children), flat_params + + +def _build_flat_signature(fn: Callable): + """ + Build a "flat" signature that only contains primitive-typed parameters by + recursively expanding any frozen-dataclass-typed parameter into its + individual leaf fields. Nested dataclasses (dataclass-of-dataclass) are + fully unwrapped: ``cfg: OuterCfg`` with ``inner: InnerCfg(val: float)`` + becomes a flat parameter ``cfg__inner__val: float``. + + Returns: + flat_sig (inspect.Signature): signature with all dataclass params + recursively expanded. + plan (list[tuple]): per-original-parameter plan describing how to + reassemble values from a flat kwargs dict. Each entry is one of: + * ``("primitive", attr_name, flat_name, None)`` for a leaf + whose runtime value is read from / written to the flat + ``flat_name`` slot; + * ``("dataclass", attr_name, cls, [sub_plan_nodes...])`` for a + dataclass node whose children follow the same recursive + structure. + ``attr_name`` is the parameter name on ``fn`` at the top level + and the field name on the parent dataclass deeper in the tree. + user_sig (inspect.Signature): the original (un-flattened) signature + of ``fn`` for binding user calls. + """ + user_sig = inspect.signature(fn) + # Resolve stringified annotations (PEP 563 / ``from __future__ import + # annotations``) so ``_is_frozen_dataclass`` can recognise dataclass-typed + # parameters and dataclass field types are real ``type`` objects. 
+ resolved = _resolve_annotations(fn) + flat_params: list[inspect.Parameter] = [] + plan: list[tuple] = [] + + for name, param in user_sig.parameters.items(): + annotation = resolved.get(name, param.annotation) + _assert_not_mutable_dataclass(annotation, where=f"parameter {name!r}") + if _is_frozen_dataclass(annotation): + node, sub_params = _build_dataclass_subplan( + annotation, attr_name=name, flat_prefix=name + ) + plan.append(node) + flat_params.extend(sub_params) + else: + _assert_not_unsupported_container(annotation, where=f"parameter {name!r}") + annotation = _maybe_downgrade_literal_or_enum( + annotation, where=f"parameter {name!r}" + ) + new_param = param.replace( + kind=inspect.Parameter.POSITIONAL_OR_KEYWORD, + annotation=annotation, + default=_schema_compatible_param_default(param.default), + ) + flat_params.append(new_param) + plan.append(("primitive", name, name, None)) + + return_annotation = resolved.get("return", user_sig.return_annotation) + _assert_not_dataclass_return(return_annotation, fn_name=fn.__name__) + flat_sig = inspect.Signature(flat_params, return_annotation=return_annotation) + return flat_sig, plan, user_sig + + +def _signatures_differ(a: inspect.Signature, b: inspect.Signature) -> bool: + """Return True iff two signatures disagree on at least one of: parameter + names, annotations, defaults, kinds, or return annotation. + + We use this to decide whether the flat signature meaningfully differs + from the user signature (and thus needs an ``inner_fn`` wrapper) on the + no-dataclass path. ``inspect.Signature.__eq__`` is conservative enough + for our purposes; we wrap it for clarity at the call site. + """ + return a != b + + +def _make_flat_signature_wrapper(fn: Callable, flat_sig: inspect.Signature) -> Callable: + """Return a thin wrapper around ``fn`` whose ``__signature__`` / + ``__annotations__`` reflect ``flat_sig`` (annotations downgraded / + defaults scrubbed). The body simply forwards every argument through. 
+ + The wrapper is required because ``torch.library.infer_schema`` reads + ``inspect.signature(fn)`` and would otherwise see the user's original + (un-scrubbed) annotations. We also strip ``__wrapped__`` so + ``inspect.signature`` doesn't unwrap back to the original. + """ + + @functools.wraps(fn) + def _wrapped(*args, **kwargs): + return fn(*args, **kwargs) + + _wrapped.__signature__ = flat_sig + flat_annotations = { + p.name: p.annotation + for p in flat_sig.parameters.values() + if p.annotation is not inspect.Parameter.empty + } + if flat_sig.return_annotation is not inspect.Signature.empty: + flat_annotations["return"] = flat_sig.return_annotation + _wrapped.__annotations__ = flat_annotations + try: + del _wrapped.__wrapped__ + except AttributeError: + pass + return _wrapped + + +def _build_value_from_node(node: tuple, flat_kwargs: dict): + """Recursively materialise the value described by a plan ``node`` from + the flat kwargs dict produced by ``_build_flat_signature``. + + Used by both the meta-side reassembly (when torch.library hands us back + flat kwargs) and any other site that needs the original-shaped value. + """ + kind = node[0] + if kind == "primitive": + _, _attr, flat_name, _ = node + return flat_kwargs[flat_name] + # ``("dataclass", attr, cls, children)`` + _, _attr, cls, children = node + init_kwargs: dict[str, Any] = {} + for child in children: + # ``child[1]`` is the field name on ``cls`` regardless of node kind. + field_name = child[1] + init_kwargs[field_name] = _build_value_from_node(child, flat_kwargs) + return cls(**init_kwargs) + + +def _reassemble_user_kwargs(plan: list[tuple], flat_kwargs: dict) -> dict: + """Reconstruct the original (possibly nested-dataclass-bearing) kwargs + from flat kwargs. Mirrors :func:`_build_flat_signature`. + """ + out: dict[str, Any] = {} + for node in plan: + # Top-level node: ``node[1]`` is the original parameter name on ``fn``. 
+ out[node[1]] = _build_value_from_node(node, flat_kwargs) + return out + + +def _flatten_value_into(node: tuple, value: Any, out: list) -> None: + """Recursively flatten ``value`` according to plan ``node``, appending + leaf primitives to ``out`` in DFS order. + """ + kind = node[0] + if kind == "primitive": + out.append(value) + return + _, _attr, cls, children = node + # We don't isinstance-check ``cls`` here on purpose: users may pass + # arbitrary objects that quack like the dataclass (e.g. mocks). We just + # rely on getattr for each declared field. + for child in children: + field_name = child[1] + _flatten_value_into(child, getattr(value, field_name), out) + + +def _flatten_call_args( + plan: list[tuple], user_sig: inspect.Signature, args: tuple, kwargs: dict +) -> list: + """ + Flatten a user-side call (which may pass nested dataclass instances) into + a positional list. The order matches the flat signature produced by + :func:`_build_flat_signature`. + """ + bound = user_sig.bind(*args, **kwargs) + bound.apply_defaults() + flat: list = [] + for node in plan: + # ``node[1]`` is the top-level parameter name; primitives are passed + # through unchanged, dataclasses are recursively unwrapped. + _flatten_value_into(node, bound.arguments[node[1]], flat) + return flat + + +def _count_leaves(node: tuple) -> int: + """Number of flat parameter slots a plan ``node`` occupies.""" + if node[0] == "primitive": + return 1 + return sum(_count_leaves(c) for c in node[3]) + + +def _flatten_grad_into(node: tuple, grad: Any, out: list) -> None: + """Spread a user-returned grad for one original-signature input across + the flat parameter slots described by ``node``. + + Rules: + * ``primitive`` node: append ``grad`` (whatever the user returned) as-is. + * ``dataclass`` node: + - If the user returned ``None``: every leaf slot under this node + gets ``None`` (whole-dataclass-not-differentiable, the common case). 
+ - Otherwise: the user is returning a dataclass-shaped grad object. + We descend recursively, reading each child via + ``getattr(grad, field_name)``. Missing fields are treated as + ``None``. This mirrors :func:`_flatten_value_into` but is more + forgiving so users can return a plain ``SimpleNamespace`` / + ``dict``-like object too (we accept ``dict`` via ``__getitem__``). + """ + if node[0] == "primitive": + out.append(grad) + return + _, _attr, _cls, children = node + if grad is None: + for child in children: + for _ in range(_count_leaves(child)): + out.append(None) + return + is_mapping = isinstance(grad, dict) + for child in children: + field_name = child[1] + if is_mapping: + sub = grad.get(field_name) + else: + sub = getattr(grad, field_name, None) + _flatten_grad_into(child, sub, out) + + +def _collect_tensor_leaf_flat_names(node: tuple) -> list[str]: + """Return the flat parameter names of every Tensor leaf under ``node``. + + Used to expand a top-level dataclass parameter referenced in + ``mutates_args`` into the actual flat parameter names that + ``torch.library`` sees. + """ + if node[0] == "primitive": + _, _attr, flat_name, _ = node + anno = node[3] # currently always ``None`` for primitive nodes + # We don't actually have the annotation cached here -- callers expand + # all leaves under a dataclass and torch.library will validate Tensor + # types itself, so over-specifying is acceptable. + return [flat_name] + out: list[str] = [] + for child in node[3]: + out.extend(_collect_tensor_leaf_flat_names(child)) + return out + + +def _expand_mutates_args( + mutates_args: tuple[str, ...] | list[str], + plan: list[tuple], +) -> tuple[str, ...]: + """Translate ``mutates_args`` from the *original* parameter space to the + *flat* parameter space. + + * Names that already match a primitive top-level parameter pass through. + * A name that matches a top-level dataclass parameter expands into every + flat leaf under it (``cfg`` -> ``cfg__a``, ``cfg__b__x``, ...). 
+ * Names of the form ``cfg__inner__x`` (already in flat space) pass through + unchanged so users can be precise if they want to. + * Unknown names raise ``ValueError`` with the list of valid choices. + """ + if not mutates_args: + return tuple(mutates_args) + by_attr: dict[str, tuple] = {node[1]: node for node in plan} + valid_flat: set[str] = set() + for node in plan: + valid_flat.update(_collect_tensor_leaf_flat_names(node)) + out: list[str] = [] + for name in mutates_args: + if name in by_attr: + node = by_attr[name] + if node[0] == "primitive": + out.append(node[2]) + else: + out.extend(_collect_tensor_leaf_flat_names(node)) + elif name in valid_flat: + out.append(name) + else: + raise ValueError( + f"@magi_register_custom_op: mutates_args entry {name!r} does " + f"not match any parameter. Valid original-space names: " + f"{sorted(by_attr.keys())}; valid flat-space names: " + f"{sorted(valid_flat)}." + ) + seen: set[str] = set() + deduped: list[str] = [] + for n in out: + if n not in seen: + seen.add(n) + deduped.append(n) + return tuple(deduped) + + +def _flatten_user_grads(plan: list[tuple], user_grads: tuple | list) -> list: + """Convert a tuple of grads keyed by *original* parameter order into the + flat tuple keyed by the flat-op parameter order. + + Length of ``user_grads`` MUST equal ``len(plan)`` (one grad per original + input). Raises ``ValueError`` otherwise so users get a clear message + instead of an opaque autograd shape error. + """ + if len(user_grads) != len(plan): + raise ValueError( + f"backward_fn returned {len(user_grads)} grad(s) but the original " + f"function had {len(plan)} input(s). When using a frozen-dataclass " + f"input, return one grad per ORIGINAL parameter (use ``None`` for " + f"non-differentiable ones, including whole dataclass arguments)." 
+ ) + flat: list = [] + for node, g in zip(plan, user_grads): + _flatten_grad_into(node, g, flat) + return flat + + +# ============================================================================== +# SECTION 4: Triton Introspection & Wrapping Intercept +# ------------------------------------------------------------------------------ +# These helpers interface with the `torch.library` triton registry and our custom +# introspector. They identify which kernels the op uses, detect unsupported +# heuristics, and optionally rebuild the function to shadow module-global kernel +# references with `wrap_triton(...)` so Inductor can trace through the op. +# ============================================================================== + + +def _assert_wrap_triton_compatible(kernels: list[Any]) -> None: + """Reject kernels whose outermost decorator is ``@triton.heuristics``. + + ``torch.library.wrap_triton`` (the only entry point into ``triton_op`` / + Inductor's traceable HOP path) hard-codes:: + + isinstance(triton_kernel, (JITFunction, Autotuner)) + + A bare ``@triton.heuristics`` (or ``@triton.heuristics`` wrapping + ``@triton.autotune``) produces a ``Heuristics`` object that fails this + check at runtime with a confusing + ``"wrap_triton only works on functions annotated with triton.jit or + triton.autotune"``. We surface a clearer error here pointing at the fix + (``@triton.autotune → @triton.heuristics → @triton.jit``). + """ + if not kernels: + return + try: + from triton.runtime.autotuner import Autotuner, Heuristics + from triton.runtime.jit import JITFunction + except ImportError: + return + for k in kernels: + if isinstance(k, (JITFunction, Autotuner)): + continue + if isinstance(k, Heuristics): + name = getattr(getattr(k, "fn", None), "__name__", repr(k)) + raise RuntimeError( + f"@magi_register_custom_op: triton kernel {name!r} has " + "@triton.heuristics as its outermost decorator. 
" + "torch.library.wrap_triton (and therefore triton_op / Inductor) " + "only accepts triton.jit or triton.autotune at the top level. " + "Either remove @triton.heuristics, or place @triton.autotune " + "outside it: @triton.autotune -> @triton.heuristics -> @triton.jit." + ) + + +def _resolve_triton_kernels( + fn: Callable, + extra_triton_kernels: list[Any] | tuple[Any, ...] | None, +) -> tuple[list[Any], list[Any], set[int]]: + """Best-effort: collect triton kernels referenced inside ``fn``. + + Returns ``(all_kernels, bare_kernels, user_wrapped_ids)``: + + * ``all_kernels`` is the union of user-supplied ``extra_triton_kernels`` + and *every* kernel discovered by source introspection. This is the list + we use to decide whether to take the ``triton_op`` registration path. + * ``bare_kernels`` is the subset that is invoked via the bare + ``kernel[grid](...)`` pattern (the user did NOT wrap them themselves). + Only these need to be shadowed by ``rewrite_fn_with_wrap_triton``; + shadowing kernels the user already wrapped would yield + ``wrap_triton(wrap_triton(kernel))`` and crash at runtime. + * ``user_wrapped_ids`` is the set of kernel object ids the user has + already wrapped explicitly via ``wrap_triton(k)`` in the source. We + forward this to the rewriter as an exclusion list so its blanket + "wrap every JITFunction in fn.__globals__" pass does not double-wrap + a kernel that's also referenced from a manual ``wrap_triton(k)`` call + in the same body. + + User-supplied ``extra_triton_kernels`` are treated as bare (the user is + asking us to wrap them on their behalf). Both lists are deduplicated by + ``id(kernel)`` and preserve user-supplied order first. 
+ """ + seen_all: set[int] = set() + all_kernels: list[Any] = [] + seen_bare: set[int] = set() + bare_kernels: list[Any] = [] + for k in extra_triton_kernels or (): + kid = id(k) + if kid not in seen_all: + seen_all.add(kid) + all_kernels.append(k) + if kid not in seen_bare: + seen_bare.add(kid) + bare_kernels.append(k) + try: + detected_all = get_inner_triton_kernels(fn) + except Exception: + logger.debug("get_inner_triton_kernels(%r) failed", fn, exc_info=True) + detected_all = [] + try: + detected_bare = get_bare_triton_kernels(fn) + except Exception: + logger.debug("get_bare_triton_kernels(%r) failed", fn, exc_info=True) + detected_bare = [] + for k in detected_all: + if id(k) not in seen_all: + seen_all.add(id(k)) + all_kernels.append(k) + for k in detected_bare: + if id(k) not in seen_bare: + seen_bare.add(id(k)) + bare_kernels.append(k) + + # Reject kernels whose outermost decorator is @triton.heuristics with a + # readable error before the user hits either (a) wrap_triton's opaque + # RuntimeError or (b) a silent fallback to plain custom_op (because the + # introspector unwraps Heuristics via ``obj.fn`` and may even drop it + # entirely when callable(Heuristics_instance) is False, hiding the + # outer layer from ``all_kernels``). 
# ==============================================================================
# SECTION 5: Core Registration & Main Decorator
# ------------------------------------------------------------------------------
# These functions implement the actual dispatch layer: rejecting duplicate op
# names, registering the op (falling back from `triton_op` to `custom_op`
# when needed), and the outermost `_magi_register_custom_op_impl`
# orchestration.
# ==============================================================================

# Fully-qualified op names recorded by this module's decorator.
_REGISTERED_OP_NAMES: set[str] = set()


def _assert_op_name_unused(op_name: str) -> None:
    """Fail early, with a readable message, when ``op_name`` is already taken.

    Two sources are consulted: the module-level ``_REGISTERED_OP_NAMES`` set
    and, for names of the form ``namespace::op``, the matching attribute on
    ``torch.ops.<namespace>``. Names without both parts around ``::`` only
    get the first check.

    Raises:
        RuntimeError: if the name is in ``_REGISTERED_OP_NAMES`` or is
            already visible on ``torch.ops``.
    """
    if op_name in _REGISTERED_OP_NAMES:
        raise RuntimeError(
            f"@magi_register_custom_op: op name {op_name!r} is already "
            "registered. Each magi op must use a unique "
            "``namespace::op_name``. If you really want to override, delete "
            "the previous registration with "
            "``torch.library._del_library_impl`` first, or pass an explicit "
            "``name=`` to disambiguate."
        )

    ns, _sep, short_name = op_name.partition("::")
    if not (ns and short_name):
        # Nothing to look up on torch.ops without both halves.
        return

    existing_ns = getattr(torch.ops, ns, None)
    if existing_ns is None or not hasattr(existing_ns, short_name):
        return

    raise RuntimeError(
        f"@magi_register_custom_op: op name {op_name!r} is already "
        f"defined on torch.ops.{ns}. Use a different name (or pass an "
        "explicit ``name=`` to your decorator) to avoid clashing with "
        "an existing operator."
    )
+ ) + ns, _, opname = op_name.partition("::") + if ns and opname: + ns_obj = getattr(torch.ops, ns, None) + if ns_obj is not None and hasattr(ns_obj, opname): + raise RuntimeError( + f"@magi_register_custom_op: op name {op_name!r} is already " + f"defined on torch.ops.{ns}. Use a different name (or pass an " + "explicit ``name=`` to your decorator) to avoid clashing with " + "an existing operator." + ) + + +def _register_op( + op_name: str, + fn: Callable, + mutates_args: tuple[str, ...], + meta_fn: Callable, + user_supplied_meta: bool, + triton_kernels: list[Any], + bare_triton_kernels: list[Any] | None = None, + signature_override: inspect.Signature | None = None, + excluded_kernel_ids: set[int] | None = None, +): + """Register ``fn`` either as a triton_op (when triton kernels are present) + or as a plain custom_op, with sensible fallback if triton_op registration + fails. + + ``bare_triton_kernels`` is the subset of ``triton_kernels`` that the user + did NOT already wrap in ``wrap_triton(...)`` themselves. Only those are + fed to :func:`rewrite_fn_with_wrap_triton` so we never produce a + ``wrap_triton(wrap_triton(kernel))``. Defaults to ``triton_kernels`` for + backwards compatibility. + + Returns the resulting ``CustomOpDef`` instance. + """ + if bare_triton_kernels is None: + bare_triton_kernels = triton_kernels + if triton_kernels: + try: + from torch.library import triton_op + except ImportError: + triton_op = None # type: ignore[assignment] + logger.warning( + "torch.library.triton_op not available; falling back to " + "torch.library.custom_op for op %s", + op_name, + ) + + if triton_op is not None: + try: + fn_for_register = rewrite_fn_with_wrap_triton( + fn, bare_triton_kernels, excluded_kernel_ids=excluded_kernel_ids + ) + # ``rewrite_fn_with_wrap_triton`` builds a fresh + # ``types.FunctionType`` from ``fn.__code__``; if ``fn`` is a + # thin signature-rewriting wrapper (e.g. 
for Literal / + # default-list scrubbing), the freshly built function has the + # wrapper's ``(*args, **kwargs)`` code object, so we need to + # re-attach the cleaned signature for ``infer_schema``. + if signature_override is not None: + fn_for_register.__signature__ = signature_override + fn_for_register.__annotations__ = { + p.name: p.annotation + for p in signature_override.parameters.values() + if p.annotation is not inspect.Parameter.empty + } + if ( + signature_override.return_annotation + is not inspect.Signature.empty + ): + fn_for_register.__annotations__["return"] = ( + signature_override.return_annotation + ) + registered_op = triton_op(op_name, mutates_args=mutates_args)( + fn_for_register + ) + # ``triton_op`` already registers ``fn`` as the fake/meta + # implementation. Only override when the user explicitly + # supplied an ``infer_output_meta_fn``. + if user_supplied_meta: + registered_op.register_fake(meta_fn) + return registered_op + except Exception: + logger.warning( + "triton_op registration failed for %s; falling back to " + "custom_op + register_fake. Inductor will not be able to " + "see through the op.", + op_name, + exc_info=True, + ) + + registered_op = torch.library.custom_op(op_name, mutates_args=mutates_args)(fn) + torch.library.register_fake(op_name)(meta_fn) + return registered_op + + def _magi_register_custom_op_impl( name: str | None = None, mutates_args: tuple[str, ...] = (), @@ -173,32 +1125,215 @@ def _magi_register_custom_op_impl( backward_fn: Callable | None = None, is_compute_sensitive: bool = False, is_subgraph_boundary: bool = False, + extra_triton_kernels: list[Any] | tuple[Any, ...] 
| None = None, ): def decorator(fn: Callable) -> Callable: # Auto-generate name if not provided op_name = name if name is not None else _generate_op_name(fn) + _assert_op_name_namespaced(op_name) + _assert_op_name_unused(op_name) if is_compute_sensitive: - get_compile_config().recompute_config.custom_compute_sensitive_ops.append(op_name) + get_compile_config().recompute_config.custom_compute_sensitive_ops.append( + op_name + ) if is_subgraph_boundary: get_compile_config().splitting_ops.append(op_name) - # Step 1: Register the custom op with torch.library.custom_op - registered_op = torch.library.custom_op(op_name, mutates_args=mutates_args)(fn) + # Detect whether any input is a frozen dataclass; if not, fall through + # to the original (zero-overhead) registration path. + flat_sig, plan, user_sig = _build_flat_signature(fn) + has_dataclass = any(kind == "dataclass" for kind, *_ in plan) + + if not has_dataclass: + # The flat signature may differ from the user's signature even + # without any dataclass input: we may have downgraded a Literal / + # Enum annotation to ``str`` or scrubbed a list/dict default that + # ``infer_schema`` cannot consume. In those cases we route through + # a thin wrapper whose ``__signature__`` is ``flat_sig`` so the + # schema sees the cleaned-up version. Otherwise we register ``fn`` + # directly to preserve the original zero-overhead path. + sig_was_rewritten = _signatures_differ(flat_sig, user_sig) + fn_for_register = ( + _make_flat_signature_wrapper(fn, flat_sig) if sig_was_rewritten else fn + ) + + # Step 1: Build the meta/fake function (used either as a + # register_fake override on the triton path, or as the regular + # fake implementation on the plain custom_op path). 
+ meta_target = fn_for_register + if infer_output_meta_fn is None: + meta_fn = _create_identity_meta_fn(meta_target) + user_supplied_meta = False + elif isinstance(infer_output_meta_fn, list): + meta_fn = _create_meta_fn_from_param_names( + meta_target, infer_output_meta_fn + ) + user_supplied_meta = True + else: + meta_fn = infer_output_meta_fn + user_supplied_meta = True + + # Step 2: Detect inner triton kernels and register the op via + # triton_op (if any kernels are present) or custom_op (otherwise). + triton_kernels, bare_triton_kernels, user_wrapped_ids = ( + _resolve_triton_kernels(fn, extra_triton_kernels) + ) + registered_op = _register_op( + op_name=op_name, + fn=fn_for_register, + mutates_args=mutates_args, + meta_fn=meta_fn, + user_supplied_meta=user_supplied_meta, + triton_kernels=triton_kernels, + bare_triton_kernels=bare_triton_kernels, + signature_override=flat_sig if sig_was_rewritten else None, + excluded_kernel_ids=user_wrapped_ids, + ) + + # Step 3: Register autograd if backward_fn is provided + if backward_fn is not None: + registered_op.register_autograd( + backward_fn, setup_context=setup_context_fn + ) + + _REGISTERED_OP_NAMES.add(op_name) + return registered_op - # Step 2: Register the output meta inference function - # Determine meta_fn based on the type of infer_output_meta_fn + # ----- Dataclass-aware path ----- + # Build inner_fn whose signature contains only primitive types so that + # torch.library.custom_op's schema validator is happy. The inner_fn + # accepts positional/keyword args following the flat signature; we + # bind them, reassemble dataclasses, then call the original ``fn``. + def _bind_to_user_kwargs(args, kwargs): + bound = flat_sig.bind(*args, **kwargs) + bound.apply_defaults() + return _reassemble_user_kwargs(plan, bound.arguments) + + # Detect triton kernels referenced from the original (dataclass-typed) + # fn. 
If any are present, route ``inner_fn`` through a wrap_triton-aware + # copy of ``fn`` so the eventual triton_op registration captures them. + triton_kernels, bare_triton_kernels, user_wrapped_ids = _resolve_triton_kernels( + fn, extra_triton_kernels + ) + fn_for_inner = ( + rewrite_fn_with_wrap_triton( + fn, bare_triton_kernels, excluded_kernel_ids=user_wrapped_ids + ) + if bare_triton_kernels + else fn + ) + + @functools.wraps(fn) + def inner_fn(*args, **kwargs): + return fn_for_inner(**_bind_to_user_kwargs(args, kwargs)) + + inner_fn.__signature__ = flat_sig + # ``functools.wraps`` set ``__wrapped__`` to ``fn``; that makes + # ``inspect.signature(inner_fn)`` follow back to ``fn`` (which still + # carries the dataclass-typed annotations) and bypass our flat + # ``__signature__`` override. ``triton_op`` / ``infer_schema`` rely on + # ``inspect.signature`` and would then choke on the dataclass type + # annotation. Strip the wrapper marker so the flat signature wins. + try: + del inner_fn.__wrapped__ + except AttributeError: + pass + # Replace the dataclass-typed annotations copied over by + # ``functools.wraps`` with the flat-signature annotations so that any + # tool reading ``__annotations__`` directly (e.g. ``get_type_hints``) + # also sees the primitive types torch.library expects. + flat_annotations = { + p.name: p.annotation + for p in flat_sig.parameters.values() + if p.annotation is not inspect.Parameter.empty + } + if flat_sig.return_annotation is not inspect.Signature.empty: + flat_annotations["return"] = flat_sig.return_annotation + inner_fn.__annotations__ = flat_annotations + + # Build the meta function based on the flat signature. 
if infer_output_meta_fn is None: - meta_fn = _create_identity_meta_fn(fn) + meta_fn = _create_identity_meta_fn(inner_fn) + user_supplied_meta = False elif isinstance(infer_output_meta_fn, list): - meta_fn = _create_meta_fn_from_param_names(fn, infer_output_meta_fn) + meta_fn = _create_meta_fn_from_param_names(inner_fn, infer_output_meta_fn) + user_supplied_meta = True else: - meta_fn = infer_output_meta_fn - torch.library.register_fake(op_name)(meta_fn) + user_meta = infer_output_meta_fn + + def meta_fn(*args, **kwargs): + return user_meta(**_bind_to_user_kwargs(args, kwargs)) + + meta_fn.__signature__ = flat_sig + user_supplied_meta = True - # Step 3: Register autograd if backward_fn is provided + flat_mutates_args = _expand_mutates_args(mutates_args, plan) + registered_op = _register_op( + op_name=op_name, + fn=inner_fn, + mutates_args=flat_mutates_args, + meta_fn=meta_fn, + signature_override=flat_sig, + user_supplied_meta=user_supplied_meta, + triton_kernels=triton_kernels, + # ``inner_fn`` already wraps a rewritten copy of ``fn``, so we do + # NOT want _register_op to rewrite a second time (that would + # introspect ``inner_fn`` and re-wrap kernels referenced via the + # original ``fn`` closure). Pass an empty list to short-circuit. + bare_triton_kernels=[], + ) + + # Bridge user-supplied autograd hooks (which speak the ORIGINAL + # dataclass signature) into the FLAT signature actually registered + # with torch.library. + # + # On the forward pass torch.library calls + # setup_context(ctx, inputs=, output=...) + # On the backward pass it expects the user's ``backward`` to return + # one grad per FLAT input. Users naturally want to write both in + # terms of the original (dataclass-bearing) signature, so we wrap + # both ends. 
if backward_fn is not None: - registered_op.register_autograd(backward_fn, setup_context=setup_context_fn) + user_setup = setup_context_fn + user_backward = backward_fn + + def _bridged_setup_context(ctx, inputs, output): + if user_setup is None: + return None + # ``inputs`` is the flat positional tuple in the order of + # ``flat_sig``. Reassemble it into the user's original + # (possibly nested-dataclass-bearing) shape. + flat_kwargs = { + p.name: v for p, v in zip(flat_sig.parameters.values(), inputs) + } + user_kwargs = _reassemble_user_kwargs(plan, flat_kwargs) + # Preserve original positional order so users can do + # ``x, cfg = inputs`` exactly like in the no-dataclass case. + user_inputs = tuple(user_kwargs[p] for p in user_sig.parameters) + return user_setup(ctx, user_inputs, output) + + def _bridged_backward(ctx, *grads): + user_grads = user_backward(ctx, *grads) + if not isinstance(user_grads, tuple): + # Single-input convenience: PyTorch allows returning a + # bare grad if the op has a single input. Mirror that. + user_grads = (user_grads,) + return tuple(_flatten_user_grads(plan, user_grads)) + + registered_op.register_autograd( + _bridged_backward, setup_context=_bridged_setup_context + ) + + # Outer wrapper preserves the original (dataclass-aware) signature for + # users while routing through the registered (flat) op underneath. + @functools.wraps(fn) + def outer_wrapper(*args, **kwargs): + flat = _flatten_call_args(plan, user_sig, args, kwargs) + return registered_op(*flat) - return registered_op + outer_wrapper._magi_inner_op = registered_op + outer_wrapper._magi_flat_plan = plan + _REGISTERED_OP_NAMES.add(op_name) + return outer_wrapper return decorator diff --git a/magi_compiler/_triton_introspect.py b/magi_compiler/_triton_introspect.py new file mode 100644 index 0000000..d5d10da --- /dev/null +++ b/magi_compiler/_triton_introspect.py @@ -0,0 +1,811 @@ +# Copyright (c) 2026 SandAI. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Triton kernel introspection utilities used by ``magi_register_custom_op``. + +This module vendors the (more capable) v2.11.0 implementation of +``torch._library.triton.get_inner_triton_kernels`` and adds a runtime +"globals/closure shadow rewrite" pass (``rewrite_fn_with_wrap_triton``) that +replaces every reference to a detected triton kernel inside a function (and +any helper functions it calls) with ``torch.library.wrap_triton(kernel)``, +without touching the source code. + +Only ``get_inner_triton_kernels`` and ``rewrite_fn_with_wrap_triton`` are +intended to be public (used by ``_magi_register_custom_op``). +""" + +from __future__ import annotations + +import ast +import functools +import inspect +import logging +import textwrap +import types +from typing import Any, Callable, Optional + +logger = logging.getLogger(__name__) + + +__all__ = [ + "get_inner_triton_kernels", + "get_bare_triton_kernels", + "get_referenced_heuristics_kernels", + "rewrite_fn_with_wrap_triton", +] + + +# ============================================================================== +# SECTION 1: Triton Kernel AST Introspection +# ------------------------------------------------------------------------------ +# Vendored from torch._library.triton (v2.11.0 / pytorch main) and extended. +# Local copy so that ``magi_register_custom_op`` works even if the user is on +# an older PyTorch whose helper is much weaker (e.g. 2.9.x). 
#
# These functions parse the AST of the decorated function (and its helpers)
# to discover `_cos_kernel[...]` usage, tracing globals and closure cells.
# ==============================================================================


def _find_triton_kernels_impl(
    fn: Callable[..., Any], only_bare: bool = False
) -> list[object]:
    """Shared driver for :func:`get_inner_triton_kernels` and
    :func:`get_bare_triton_kernels`.

    The source of ``fn`` is parsed, kernel-looking names are collected and
    resolved against the function's globals/closure, and helper functions
    called from ``fn`` are analysed recursively (bounded depth, cycle-safe).

    Args:
        fn: Callable whose source should be inspected.
        only_bare: When True, only kernels launched via the bare
            ``kernel[grid](...)`` pattern are returned. Otherwise names
            passed to ``wrap_triton`` / ``capture_triton`` and names that
            appear in ``return`` expressions are resolved as well.

    Returns:
        Kernel objects de-duplicated by identity. Empty when triton cannot
        be imported, when the source is unavailable, or when the source
        cannot be parsed.
    """

    # prevent infinite recursion
    MAX_RECURSION_DEPTH = 5

    def find_triton_kernels(
        fn: Callable[..., Any],
        visited_fns: set[int] | None = None,
        depth: int = 0,
    ) -> list[object]:
        try:
            from triton.runtime.autotuner import Autotuner
            from triton.runtime.jit import JITFunction
        except ImportError:
            logger.warning("Triton not available, find_triton_kernels = []")
            return []

        # unwrap decorated fn's (e.g., @lru_cache) to get the original
        fn = inspect.unwrap(fn)

        # init visited set and check for cycles/depth limit
        if visited_fns is None:
            visited_fns = set()

        fn_id = id(fn)
        if fn_id in visited_fns:
            return []
        if depth > MAX_RECURSION_DEPTH:
            logger.debug(
                "reached max recursion depth (%s) in find_triton_kernels",
                MAX_RECURSION_DEPTH,
            )
            return []

        visited_fns.add(fn_id)

        try:
            source = inspect.getsource(fn)
        except (OSError, TypeError):
            return []  # Source code not available

        # ``inspect.getsource`` can hand back text that is not a standalone
        # statement (e.g. the physical line holding a lambda). Treat that
        # like "no source" instead of aborting the whole detection, which
        # is what the other source walkers in this module do.
        try:
            tree = ast.parse(textwrap.dedent(source).strip())
        except SyntaxError:
            return []

        # Visitor to collect function calls, assignments, and triton kernels
        class Visitor(ast.NodeVisitor):
            def __init__(self) -> None:
                # Names referenced via wrap_triton(name) / capture_triton(name).
                # The user has *already* wrapped these, so the rewrite pass
                # must NOT shadow them (doing so would produce
                # wrap_triton(wrap_triton(kernel)) at runtime).
                self.wrapped_kernel_names: list[Any] = []
                # Names invoked via bare ``kernel[grid](...)`` syntax.
                # These are the ones we need to wrap_triton-shadow at runtime
                # so the resulting triton_op is traceable.
                self.bare_kernel_names: list[Any] = []
                # track local variable assignments: var_name -> list of RHS expressions
                self.assignments: dict[str, list[ast.expr]] = {}
                # track function calls
                self.called_functions: list[str] = []
                # track return statement expressions
                self.return_exprs: list[ast.expr] = []

            def visit_Assign(self, node: ast.Assign) -> None:
                for target in node.targets:
                    if isinstance(target, ast.Name):
                        self.assignments.setdefault(target.id, []).append(node.value)
                self.generic_visit(node)

            def visit_Return(self, node: ast.Return) -> None:
                if node.value is not None:
                    self.return_exprs.append(node.value)
                self.generic_visit(node)

            def visit_Call(self, node: ast.Call) -> None:
                triton_func_names = ("capture_triton", "wrap_triton")
                # Accept both the private ``torch._library.<fn>`` spelling and
                # the public ``torch.library.<fn>`` one.
                torch_library_modules = ("_library", "library")
                if isinstance(node.func, ast.Attribute):
                    attr = node.func
                    if isinstance(attr.value, ast.Attribute):
                        if (
                            isinstance(attr.value.value, ast.Name)
                            and attr.value.value.id == "torch"
                            and attr.value.attr in torch_library_modules
                            and attr.attr in triton_func_names
                        ):
                            if node.args and isinstance(node.args[0], ast.Name):
                                self.wrapped_kernel_names.append(node.args[0].id)
                        elif (
                            isinstance(attr.value.value, ast.Attribute)
                            and isinstance(attr.value.value.value, ast.Name)
                            and attr.value.value.value.id == "torch"
                            and attr.value.value.attr == "ops"
                        ):
                            # torch.ops.<ns>.<op>(...) -> "<ns>::<op>"
                            self.called_functions.append(
                                f"{attr.value.attr}::{attr.attr}"
                            )
                # Catch capture_triton, wrap_triton that's been
                # imported directly
                elif isinstance(node.func, ast.Name):
                    if node.func.id in triton_func_names:
                        if node.args and isinstance(node.args[0], ast.Name):
                            self.wrapped_kernel_names.append(node.args[0].id)
                    else:
                        # track regular function calls for recursive analysis
                        self.called_functions.append(node.func.id)

                # Also detect bare triton-style launches: ``kernel[grid](args)``.
                # Only a Subscript whose value is a plain Name is recognised;
                # subscripted attributes (e.g. ``self.kernel[grid](...)``)
                # need the ``extra_triton_kernels`` escape hatch.
                if isinstance(node.func, ast.Subscript) and isinstance(
                    node.func.value, ast.Name
                ):
                    self.bare_kernel_names.append(node.func.value.id)

                self.generic_visit(node)

        collector = Visitor()
        collector.visit(tree)

        def extract_names_from_expr(expr: ast.expr) -> list[str]:
            """Extract all Name references from an AST expression."""
            names: list[str] = []

            class NameExtractor(ast.NodeVisitor):
                def visit_Name(self, node: ast.Name) -> None:
                    names.append(node.id)

                def visit_Call(self, node: ast.Call) -> None:
                    # for function calls, visit the function and all args
                    self.generic_visit(node)

            NameExtractor().visit(expr)
            return names

        def resolve_to_kernel(obj: object) -> object | None:
            """Check if obj is a triton kernel or wrapper and return the kernel."""
            if isinstance(obj, (JITFunction, Autotuner)):
                return obj
            # handle wrappers that have a .fn attribute pointing to JITFunction
            if callable(obj) and hasattr(obj, "fn"):
                inner = obj.fn
                if isinstance(inner, JITFunction):
                    return inner
            return None

        def build_namespace(func_obj: object) -> dict[str, Any]:
            """Build a combined namespace from a function's globals and closures."""
            # unwrap decorated fns (e.g., @lru_cache)
            if callable(func_obj):
                try:
                    func_obj = inspect.unwrap(func_obj)
                except ValueError:
                    pass
            if not callable(func_obj) or not hasattr(func_obj, "__code__"):
                return {}
            try:
                func_closure_vars = inspect.getclosurevars(func_obj)
            except Exception:
                func_closure_vars = None
            namespace: dict[str, Any] = {}
            if func_closure_vars is not None:
                namespace.update(func_closure_vars.builtins)
                namespace.update(func_closure_vars.globals)
                namespace.update(func_closure_vars.nonlocals)
            # Applied last, so module globals win over same-named entries.
            if hasattr(func_obj, "__globals__"):
                namespace.update(func_obj.__globals__)
            return namespace

        all_names = build_namespace(fn)

        def resolve_names_to_kernels(
            names: list[str],
            namespace: dict[str, Any],
            assignments: dict[str, list[ast.expr]] | None = None,
            visited: set[str] | None = None,
        ) -> list[object]:
            """
            Resolve a list of names to triton kernels using the given namespace.
            """
            if visited is None:
                visited = set()

            results: list[object] = []
            for name in names:
                if name in visited:
                    continue
                visited.add(name)

                if name in namespace:
                    obj = namespace[name]
                    kernel = resolve_to_kernel(obj)
                    if kernel is not None:
                        results.append(kernel)
                        continue
                    # recurse into callable objects (factory fn's),
                    # unwrapping decorators if applicable
                    if callable(obj):
                        try:
                            unwrapped = inspect.unwrap(obj)
                        except ValueError:
                            unwrapped = obj
                        if hasattr(unwrapped, "__code__"):
                            nested = find_triton_kernels(
                                unwrapped,
                                visited_fns,
                                depth + 1,
                            )
                            if nested:
                                results.extend(nested)
                                continue
                    logger.debug("failed to resolve %s to a triton kernel", name)
                elif assignments is not None and name in assignments:
                    # trace through local assignments
                    for rhs_expr in assignments[name]:
                        referenced = extract_names_from_expr(rhs_expr)
                        traced = resolve_names_to_kernels(
                            referenced, namespace, assignments, visited
                        )
                        results.extend(traced)
                else:
                    logger.debug("%s not found in namespace or assignments", name)

            return results

        # resolve kernel names, tracing through local variables if needed
        resolved: list[object] = []
        seen_ids: set[int] = set()

        if only_bare:
            names_to_resolve: list[str] = list(collector.bare_kernel_names)
        else:
            names_to_resolve = list(collector.bare_kernel_names) + list(
                collector.wrapped_kernel_names
            )
            for expr in collector.return_exprs:
                names_to_resolve.extend(extract_names_from_expr(expr))

        for name in names_to_resolve:
            traced_objects = resolve_names_to_kernels(
                [name], all_names, collector.assignments
            )
            for obj in traced_objects:
                obj_id = id(obj)
                if obj_id not in seen_ids:
                    seen_ids.add(obj_id)
                    resolved.append(obj)

        for func_name in collector.called_functions:
            func_obj = all_names.get(func_name)

            if func_obj is None:
                # "<ns>::<op>" names recorded for torch.ops calls are looked
                # up in torch's custom-op registry.
                try:
                    from torch._library.custom_ops import OPDEFS

                    if func_name in OPDEFS:
                        func_obj = OPDEFS[func_name]._abstract_fn
                except Exception:
                    pass

            # skip if not a callable or if it's a triton kernel itself
            if func_obj is None or not callable(func_obj):
                continue

            # skip built-in functions and C extensions (they can't contain triton kernels)
            if not hasattr(func_obj, "__code__"):
                continue

            try:
                nested_kernels = find_triton_kernels(func_obj, visited_fns, depth + 1)
                for kernel in nested_kernels:
                    kernel_id = id(kernel)
                    if kernel_id not in seen_ids:
                        seen_ids.add(kernel_id)
                        resolved.append(kernel)
            except Exception:
                logger.debug(
                    "failed to analyze called function %s", func_name, exc_info=True
                )

        return resolved

    return find_triton_kernels(fn)


def get_user_wrapped_triton_kernels(fn: Callable[..., Any]) -> list[object]:
    """Return triton kernels that ``fn``'s source explicitly wraps in a
    ``wrap_triton(kernel)`` / ``capture_triton(kernel)`` call.

    These are the kernels the user has *already* taken responsibility for
    wrapping; :func:`rewrite_fn_with_wrap_triton` must not rewrite their
    module-globals references (doing so would produce
    ``wrap_triton(wrap_triton(kernel))`` at runtime). Exposed so the caller
    can build an ``excluded_kernel_ids`` set to pass into the rewriter.
    """
    return _find_user_wrapped_kernels_impl(fn)
def _find_user_wrapped_kernels_impl(fn: Callable[..., Any]) -> list[object]:
    """Resolve the kernels that ``fn`` passes to ``wrap_triton`` /
    ``capture_triton`` by name.

    The source of ``fn`` is parsed and every call of the form
    ``wrap_triton(name)``, ``capture_triton(name)``,
    ``torch.library.<fn>(name)`` or ``torch._library.<fn>(name)`` is
    recorded. Each recorded name is then looked up in the function's globals
    and closure cells and kept when it resolves to a triton kernel.

    Returns:
        Kernel objects de-duplicated by identity, in first-seen order. Empty
        when triton is unavailable, when the source cannot be read or parsed,
        or when no such call is present.
    """
    triton_types_pair = _try_import_triton_types()
    if triton_types_pair is None:
        return []
    kernel_types: tuple[type, ...] = triton_types_pair

    # ``inspect.getsource`` follows ``__wrapped__``; unwrap up front so the
    # namespace below belongs to the same function whose source we parse.
    try:
        fn = inspect.unwrap(fn)
    except ValueError:
        pass

    try:
        source = inspect.getsource(fn)
    except (OSError, TypeError):
        return []
    try:
        tree = ast.parse(textwrap.dedent(source))
    except SyntaxError:
        return []

    wrapped_names: list[str] = []
    triton_func_names = ("capture_triton", "wrap_triton")
    # ``torch.library`` is the public module, ``torch._library`` the private one.
    torch_library_modules = ("_library", "library")

    def _record_first_arg(node: ast.Call) -> None:
        # Only a plain name as first positional argument can be resolved.
        if node.args and isinstance(node.args[0], ast.Name):
            wrapped_names.append(node.args[0].id)

    class _Visitor(ast.NodeVisitor):
        def visit_Call(self, node: ast.Call) -> None:
            func = node.func
            if isinstance(func, ast.Name):
                if func.id in triton_func_names:
                    _record_first_arg(node)
            elif isinstance(func, ast.Attribute):
                owner = func.value
                if (
                    isinstance(owner, ast.Attribute)
                    and isinstance(owner.value, ast.Name)
                    and owner.value.id == "torch"
                    and owner.attr in torch_library_modules
                    and func.attr in triton_func_names
                ):
                    _record_first_arg(node)
            self.generic_visit(node)

    _Visitor().visit(tree)
    if not wrapped_names:
        return []

    # Globals first, then closure cells (which therefore take precedence).
    namespace: dict[str, Any] = {}
    namespace.update(getattr(fn, "__globals__", {}) or {})
    closure = getattr(fn, "__closure__", None)
    code = getattr(fn, "__code__", None)
    if closure is not None and code is not None:
        for name, cell in zip(code.co_freevars, closure):
            try:
                namespace[name] = cell.cell_contents
            except ValueError:
                # Cell not filled in yet; nothing to resolve.
                pass

    out: list[object] = []
    seen: set[int] = set()
    for name in wrapped_names:
        obj = namespace.get(name)
        if obj is None:
            continue
        resolved = (
            obj if isinstance(obj, kernel_types) else _resolve_kernel(obj, kernel_types)
        )
        if resolved is None or id(resolved) in seen:
            continue
        seen.add(id(resolved))
        out.append(resolved)
    return out
def get_inner_triton_kernels(fn: Callable[..., Any]) -> list[object]:
    """
    Inspect the source of an arbitrary callable, and grab all of the triton
    kernels that are wrapped inside of it.

    Traces local variable assignments, follows ``return`` expressions, and
    recursively descends into helper functions called from ``fn`` so that
    kernels hidden behind launcher wrappers are still detected.

    Returns an empty list if triton is not installed or no kernels are found.
    Best-effort: deeply recursive call graphs (>5 levels) are not followed.
    """
    return _find_triton_kernels_impl(fn, only_bare=False)


def get_referenced_heuristics_kernels(fn: Callable[..., Any]) -> list[object]:
    """Return ``triton.runtime.autotuner.Heuristics`` instances that ``fn``
    (or any plain Python helper it transitively references by name) reaches
    through its globals/closure.

    Only names that appear as a call target (``name(...)``) or as the value
    of a subscript (``name[...]``) in the parsed source are looked up.
    Helpers that are ``types.FunctionType`` instances are walked recursively,
    at most five levels deep, and each function is visited once.

    This surfaces ``@triton.heuristics`` placed at the *top* of a decorator
    stack so callers can reject it with a readable message.

    Returns ``[]`` if triton is not installed, if no source is available, or
    if no such kernels are found. Results are de-duplicated by identity.
    """
    try:
        from triton.runtime.autotuner import Heuristics
    except ImportError:
        return []

    MAX_DEPTH = 5
    found: list[object] = []
    seen_objs: set[int] = set()
    visited_fns: set[int] = set()

    def _maybe_add(obj: Any) -> None:
        if isinstance(obj, Heuristics) and id(obj) not in seen_objs:
            seen_objs.add(id(obj))
            found.append(obj)

    def _walk(func: Callable[..., Any], depth: int) -> None:
        try:
            f = inspect.unwrap(func)
        except ValueError:
            f = func
        if id(f) in visited_fns or depth > MAX_DEPTH:
            return
        visited_fns.add(id(f))

        try:
            source = inspect.getsource(f)
        except (OSError, TypeError):
            return
        # ``textwrap.dedent`` removes the indentation common to *all* lines.
        # ``inspect.cleandoc`` is meant for docstrings: it strips the first
        # line separately, which shifts a function body to column 0 and makes
        # the source unparsable whenever only the body lines are indented.
        try:
            tree = ast.parse(textwrap.dedent(source))
        except SyntaxError:
            return

        names: set[str] = set()

        class NameCollector(ast.NodeVisitor):
            def visit_Subscript(self, node: ast.Subscript) -> None:
                if isinstance(node.value, ast.Name):
                    names.add(node.value.id)
                self.generic_visit(node)

            def visit_Call(self, node: ast.Call) -> None:
                if isinstance(node.func, ast.Name):
                    names.add(node.func.id)
                self.generic_visit(node)

        NameCollector().visit(tree)

        # Combined namespace: closure variables first, module globals last.
        ns: dict[str, Any] = {}
        try:
            cv = inspect.getclosurevars(f)
            ns.update(cv.builtins)
            ns.update(cv.globals)
            ns.update(cv.nonlocals)
        except Exception:
            pass
        if hasattr(f, "__globals__"):
            ns.update(f.__globals__)

        # Sorted so the order of ``found`` does not depend on set iteration
        # order (which varies with string hash randomisation).
        for n in sorted(names):
            obj = ns.get(n)
            if obj is None:
                continue
            _maybe_add(obj)
            # Recurse into helper python functions so launchers nested
            # below ``fn`` also get inspected.
            if isinstance(obj, types.FunctionType):
                _walk(obj, depth + 1)

    _walk(fn, 0)
    return found


def get_bare_triton_kernels(fn: Callable[..., Any]) -> list[object]:
    """
    Like :func:`get_inner_triton_kernels`, but only returns kernels invoked
    via the bare ``kernel[grid](...)`` syntax (NOT via an explicit
    ``wrap_triton`` / ``capture_triton`` call).

    These are the kernels that :func:`rewrite_fn_with_wrap_triton` actually
    needs to shadow at runtime. Skipping the already-wrapped ones avoids
    producing a ``wrap_triton(wrap_triton(kernel))`` at runtime, which raises
    ``RuntimeError`` from ``torch.library.wrap_triton``.
    """
    return _find_triton_kernels_impl(fn, only_bare=True)
# ==============================================================================
# SECTION 2: Runtime "wrap_triton" Shadow Rewriter
# ------------------------------------------------------------------------------
# In order for Inductor to trace into bare `kernel[grid]` calls, they must be
# `torch.library.wrap_triton(kernel)[grid]`.
# Instead of forcing the user to rewrite their code, this pass builds a clone of
# the user's function where `__globals__` maps the kernel name to a wrapped
# version.
# ==============================================================================


def _try_import_triton_types() -> Optional[tuple[type, type]]:
    """Return ``(JITFunction, Autotuner)``, or ``None`` if triton is missing."""
    try:
        from triton.runtime.autotuner import Autotuner
        from triton.runtime.jit import JITFunction
    except ImportError:
        return None
    else:
        return (JITFunction, Autotuner)


def _resolve_kernel(obj: object, kernel_types: tuple[type, ...]) -> Optional[object]:
    """Return ``obj`` when it is, or directly wraps, a triton kernel.

    ``obj`` qualifies if it is an instance of ``kernel_types``, or if it is
    callable and its ``fn`` attribute is an instance of ``kernel_types``.
    In both cases ``obj`` itself is returned; every other input yields
    ``None``, including the case where reading ``obj.fn`` raises.

    NOTE(review): for the wrapper case the wrapper, not ``obj.fn``, is
    returned -- confirm this is what callers expect.
    """
    if isinstance(obj, kernel_types):
        return obj
    if not (callable(obj) and hasattr(obj, "fn")):
        return None
    try:
        wrapped = obj.fn
    except Exception:
        return None
    return obj if isinstance(wrapped, kernel_types) else None


def _is_user_helper(obj: object) -> bool:
    """Tell whether ``obj`` is a plain Python function that may be rebuilt.

    Only ``types.FunctionType`` instances with a code object qualify, and
    functions whose ``__module__`` starts with ``torch._library`` or
    ``triton.`` are rejected so library internals are left untouched.
    """
    if not isinstance(obj, types.FunctionType):
        return False
    if getattr(obj, "__code__", None) is None:
        return False
    owner = getattr(obj, "__module__", "") or ""
    # skip torch.library helpers and triton internals to avoid surprises
    return not owner.startswith(("torch._library", "triton."))
+ """ + if not isinstance(obj, types.FunctionType): + return False + code = getattr(obj, "__code__", None) + if code is None: + return False + # skip torch.library helpers and triton internals to avoid surprises + mod = getattr(obj, "__module__", "") or "" + if mod.startswith(("torch._library", "triton.")): + return False + return True + + +def rewrite_fn_with_wrap_triton( + fn: Callable[..., Any], + kernels: list[object], + excluded_kernel_ids: Optional[set[int]] = None, +) -> Callable[..., Any]: + """ + Return a copy of ``fn`` whose globals / closures are shadowed so that every + reference to a kernel in ``kernels`` resolves to + ``torch.library.wrap_triton(kernel)``. Helper functions called from ``fn`` + are also rebuilt the same way, so kernels referenced from launcher helpers + are wrapped too. + + The original ``fn`` (and the kernel objects) are not modified. This works + for kernels referenced via globals, closures, or factory functions, as + long as the reference can be reached through the function's + ``__globals__`` / ``__closure__``. + + If ``kernels`` is empty or triton is not installed, returns ``fn`` unchanged. + """ + if not kernels: + return fn + + triton_types_pair = _try_import_triton_types() + if triton_types_pair is None: + return fn + kernel_types: tuple[type, ...] = triton_types_pair + + try: + from torch.library import wrap_triton + except ImportError: + try: + from torch._library.triton import wrap_triton # type: ignore + except ImportError: + logger.debug("wrap_triton unavailable; skipping rewrite") + return fn + + # Map id(kernel_object) -> wrap_triton(kernel_object). Cached so the same + # kernel passed through multiple references shares one wrapper, and so we + # never wrap_triton(wrap_triton(k)). 
+ wrapped_cache: dict[int, Any] = {} + + def _wrap_once(k: object) -> Any: + kid = id(k) + if kid not in wrapped_cache: + wrapped_cache[kid] = wrap_triton(k) + return wrapped_cache[kid] + + # Pre-populate cache with the explicitly detected kernels so identical + # kernel objects encountered later resolve to the same wrapper. + target_ids: set[int] = set() + for k in kernels: + if isinstance(k, kernel_types): + _wrap_once(k) + target_ids.add(id(k)) + else: + # Allow callers to pass wrappers (e.g. objects with .fn) too. + resolved = _resolve_kernel(k, kernel_types) + if resolved is not None: + _wrap_once(resolved) + target_ids.add(id(resolved)) + + excluded_kernel_ids = set(excluded_kernel_ids or set()) + + def _maybe_wrap(obj: object) -> Optional[Any]: + """If ``obj`` is one of our target kernels, return the wrap_triton wrapper. + + Returns ``None`` if ``obj`` should be left alone. + """ + # Don't double-wrap something that already came out of wrap_triton. + # ``wrap_triton`` returns a TraceableTritonKernelWrapper; importing + # that class is brittle across torch versions, so we identity-check + # against the cache values instead. + if id(obj) in {id(v) for v in wrapped_cache.values()}: + return None + + resolved = _resolve_kernel(obj, kernel_types) + if resolved is None: + return None + # Caller has explicitly told us this kernel is already user-wrapped + # (e.g. via ``wrap_triton(kernel)`` in the op body); don't shadow its + # module-globals reference, otherwise the explicit ``wrap_triton(k)`` + # in the source becomes ``wrap_triton(wrap_triton(k))`` at runtime. + if id(resolved) in excluded_kernel_ids: + return None + if id(resolved) in target_ids or isinstance(resolved, kernel_types): + # Always wrap any encountered triton kernel (not just the + # initially-detected ones) so dynamically-resolved kernels in + # helper globals are also captured. 
+ return _wrap_once(resolved) + return None + + rebuilt_fns: dict[int, Callable[..., Any]] = {} + + # Per-module globals_dict cache: every function defined in the same module + # shares the same ``__globals__`` dict, so we only need to walk + rewrite + # that dict ONCE per module instead of once per helper. Without this cache + # the rebuilder is O(N_helpers * N_globals_per_module), which becomes + # catastrophic (multi-second per registration) for any module with many + # top-level functions / fixtures. + rebuilt_globals: dict[int, dict[str, Any]] = {} + + def _build_new_globals(old_globals: dict[str, Any]) -> dict[str, Any]: + gid = id(old_globals) + if gid in rebuilt_globals: + return rebuilt_globals[gid] + new_globals: dict[str, Any] = dict(old_globals) + # Register the partially-populated dict immediately so any reentrant + # _rebuild call (e.g. helper -> back-references the module) finds it + # and short-circuits without infinite recursion. + rebuilt_globals[gid] = new_globals + + for name, obj in list(old_globals.items()): + wrapped = _maybe_wrap(obj) + if wrapped is not None: + new_globals[name] = wrapped + continue + if _is_user_helper(obj): + try: + new_globals[name] = _rebuild(obj) + except Exception: + logger.debug("failed to rebuild helper %s", name, exc_info=True) + return new_globals + + def _rebuild(f: Callable[..., Any]) -> Callable[..., Any]: + if not isinstance(f, types.FunctionType): + return f + if id(f) in rebuilt_fns: + return rebuilt_fns[id(f)] + + # Pre-register a placeholder so cycles (helper that references back + # into ``f`` through globals or closures) don't recurse forever. + # We swap in the real new_fn at the bottom of this function. + rebuilt_fns[id(f)] = f + + new_globals = _build_new_globals(f.__globals__) + + # Rebuild closure cells. 
+ new_closure: Optional[tuple] = None + if f.__closure__ is not None: + new_cells = [] + for cell in f.__closure__: + try: + contents = cell.cell_contents + except ValueError: + # empty cell + new_cells.append(cell) + continue + + wrapped = _maybe_wrap(contents) + if wrapped is not None: + new_cells.append(types.CellType(wrapped)) + continue + if _is_user_helper(contents) and id(contents) != id(f): + try: + new_cells.append(types.CellType(_rebuild(contents))) + continue + except Exception: + logger.debug( + "failed to rebuild closure helper %s", + getattr(contents, "__name__", "?"), + exc_info=True, + ) + new_cells.append(cell) + new_closure = tuple(new_cells) + + new_fn = types.FunctionType( + f.__code__, + new_globals, + f.__name__, + f.__defaults__, + new_closure, + ) + # Preserve introspectable metadata so that downstream tooling + # (infer_schema, register_fake, etc.) continues to work. + try: + functools.update_wrapper(new_fn, f, updated=()) + except Exception: + pass + new_fn.__kwdefaults__ = f.__kwdefaults__ + new_fn.__module__ = f.__module__ + new_fn.__qualname__ = f.__qualname__ + # update_wrapper sets __wrapped__ which makes inspect.unwrap follow + # back to the original; that's actually undesirable here because the + # original function's globals do NOT have wrap_triton applied. Strip + # it so inspect.signature / unwrap stop at the rewritten function. + try: + del new_fn.__wrapped__ + except AttributeError: + pass + + rebuilt_fns[id(f)] = new_fn + return new_fn + + return _rebuild(fn) diff --git a/magi_compiler/api.py b/magi_compiler/api.py index 996657a..1dcf1d6 100644 --- a/magi_compiler/api.py +++ b/magi_compiler/api.py @@ -1,4 +1,4 @@ -# Copyright (c) 2025 SandAI. All Rights Reserved. +# Copyright (c) 2026 SandAI. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -15,7 +15,7 @@ import copy import functools import inspect -from typing import Callable, TypeVar +from typing import Any, Callable, TypeVar from ._api import ( _check_dynamic_arg_dims, @@ -202,6 +202,7 @@ def magi_register_custom_op( backward_fn: Callable | None = None, is_compute_sensitive: bool = False, is_subgraph_boundary: bool = False, + extra_triton_kernels: list[Any] | tuple[Any, ...] | None = None, ): """ A unified decorator to register a custom operator with PyTorch's library. @@ -210,6 +211,153 @@ def magi_register_custom_op( - @torch.library.custom_op - @torch.library.register_fake - fn.register_autograd + - @torch.library.triton_op (auto-detected: see "Triton kernels" below) + + plus two convenience layers on top: + - frozen-dataclass inputs (including arbitrarily nested dataclasses) + are transparently flattened into primitive parameters before being + handed to ``torch.library``, then reassembled at runtime so user code + sees the original signature; see "Frozen-dataclass inputs" below. + - autograd hooks expressed against the original signature keep + working when dataclass inputs are present: ``setup_context_fn`` and + ``backward_fn`` are bridged in/out of the flat parameter space + automatically; see "Autograd with dataclass inputs" below. + + Triton kernels + -------------- + If the decorated function (or any helper it calls) launches one or more + ``triton.jit`` kernels (with or without an explicit ``wrap_triton`` call), + the function is automatically registered via ``torch.library.triton_op`` + instead of ``torch.library.custom_op``. Detected kernel references are + transparently rewritten to go through ``torch.library.wrap_triton`` at + runtime, so the user does not need to add ``wrap_triton(...)`` manually. + Already-wrapped kernels are detected and not re-wrapped (the rewrite is + idempotent), and the same op may launch any number of triton kernels: + Inductor will see all of them and may inline / fuse them. 
+ + This makes the kernels visible to ``torch.compile`` / Inductor (instead of + keeping the op opaque), enabling kernel inlining and fusion in the + generated graph. Falls back to plain ``custom_op`` if no kernels are + detected or if ``triton_op`` registration fails. + + Detection is best-effort source introspection and recurses through helper + functions called from the decorated function (including helpers in other + modules and helpers reached via ``torch.ops.namespace.op_name(...)`` calls). For + pathological cases (kernels stored on instance attributes, kernels behind + user-defined conditional wrappers, kernels constructed only at runtime, + etc.), pass ``extra_triton_kernels`` to provide the list explicitly. + + Frozen-dataclass inputs + ----------------------- + Any parameter typed as a ``@dataclass(frozen=True)`` is recursively + flattened into its individual leaf fields before being registered with + ``torch.library``. Nested dataclasses (dataclass-of-dataclass, to any + depth) are fully unwrapped using ``__`` as the join separator, e.g. an + outer ``cfg: OuterCfg`` whose ``OuterCfg.inner: InnerCfg(val: float)`` + becomes a flat parameter named ``cfg__inner__val: float``. At call time + the user passes (and the body sees) the original dataclass instance; the + decorator handles the conversion in both directions. + + Requirements: + - Each dataclass MUST be ``frozen=True``; the leaf field types must be + types accepted by ``torch.library.custom_op``. Supported field types + include: + * ``torch.Tensor`` + * Scalars: ``int``, ``float``, ``bool``, ``str`` + * Structured scalars: ``torch.dtype``, ``torch.device`` + * Optional variants: ``Optional[Tensor]``, ``Optional[int]``, etc., + as well as PEP 604 syntax (e.g. ``Tensor | None``). + * Lists of scalars/tensors: ``list[int]``, ``list[float]``, + ``list[bool]``, ``list[Tensor]``, ``list[Optional[Tensor]]``. + Note that PyTorch does **not** support ``list[Optional[int]]`` + but it does support ``list[Optional[Tensor]]``.
Similarly, + ``Optional[list[Tensor]]`` is **not** supported (use + ``list[Optional[Tensor]]`` instead). + * ``Literal[str, ...]`` and ``Enum`` containing only string values + are automatically supported by downgrading them to ``str`` at the + schema boundary (but your op body still receives the original string). + - Returning a dataclass from the op body is not supported; only + ``torch.Tensor`` / ``tuple[torch.Tensor, ...]`` returns are allowed. + + Autograd with dataclass inputs + ------------------------------ + ``setup_context_fn`` and ``backward_fn`` are written against the + *original* (dataclass-bearing) signature, not the flat one: + - ``setup_context_fn(ctx, inputs, output)`` receives ``inputs`` in the + same positional order as ``fn``'s signature, with each dataclass + argument reassembled back into its original instance. + - ``backward_fn(ctx, *grad_outputs)`` must return one grad per + *original* input (dataclass arguments count as one slot). For a + dataclass slot the user may return any of: + * ``None`` -> equivalent to "no grad for any field" + (the bridge fills ``None`` into every + flat slot under that dataclass). + * a same-shape dataclass / namedtuple instance -> per-field grad, + ``None`` leaves are allowed and are spread to the corresponding + flat slots. + * a ``dict`` keyed by field name -> same as above but without + having to construct a new dataclass instance. + Returning the wrong number of top-level grads raises ``ValueError``. + + Limitations and known caveats + ----------------------------- + - Return type: only ``torch.Tensor`` / ``tuple[torch.Tensor, ...]`` / + ``list[torch.Tensor]`` / ``None`` are accepted by the underlying + ``torch.library`` schema. Returning a dataclass raises a clear + ``TypeError`` at registration time -- destructure the dataclass into a + tuple at the op boundary instead. 
+ - Top-level tuple/dict: parameters typed as ``tuple[...]`` or + ``dict[...]`` are not supported by the schema and will raise a + ``TypeError``. Wrap them in a ``@dataclass(frozen=True)`` instead. + - Local nested types: a dataclass field annotated with a class defined + inside another function body, combined with + ``from __future__ import annotations``, cannot be resolved by + ``typing.get_type_hints`` and produces a clear ``TypeError`` pointing at + the offending field. Move the type to module scope to fix. + - Double backward: not supported automatically. ``backward_fn`` runs + under autograd but does not get its own backward registered. If you need + higher-order derivatives, either compute them manually inside + ``backward_fn`` (using ``torch.autograd.grad(..., create_graph=True)`` + against differentiable building blocks), or split the op so the second + derivative comes from a separately registered op. + - vmap / functorch: there is no automatic ``vmap`` rule. Calling a + registered op under ``torch.vmap`` falls back to the default per-sample + loop. If you need a real batched implementation, register one with + ``torch.library.register_vmap`` against the *flat* inner op + (``op._magi_inner_op`` when dataclass inputs are present). + - Triton kernel imported inside the op body: ``import`` statements + executed at call time are not visible to source introspection. Either + hoist the import to module scope, or pass the kernel object explicitly + via ``extra_triton_kernels=``. + - Mixed wrapped/bare Triton kernels: If your op body uses both + ``wrap_triton(kernel)[grid]`` and bare ``kernel[grid]`` calls, avoid + using the same kernel function for both styles, or the automated wrapper + might double-wrap it. Standardize on one style (bare is recommended). + - dataclass field of type ``list[Dataclass]``: not supported. The flat + schema requires a static, finite leaf count; a runtime-sized list of + dataclass instances has no fixed shape. 
Restructure into parallel + ``list[Tensor]`` / ``list[int]`` fields, or split into per-element op + calls. + - Mixed-type tuple returns (e.g. ``tuple[Tensor, int]``): not + supported by the schema (only homogeneous ``tuple[Tensor, ...]`` / + ``list[Tensor]`` are accepted). Either return only the tensors, or + stash the scalar on ``ctx`` and recover it from the call site. + - Custom CUDA streams inside the op body (``with torch.cuda.stream(s):``): + not analysed. Inductor will treat the op as opaque w.r.t. the + alternate stream; do stream-overlap orchestration above the op + boundary, not inside it. + - 0-dim Tensor used as a scalar: works but goes through a Tensor + schema slot (not a ``Scalar`` slot), so the value enters the FX graph + as a tensor input and won't constant-fold. Pass an actual + ``int``/``float``/``bool`` if you want scalar semantics. + - CPU-only execution on the Triton path: a Triton-backed op only + registers a ``cuda`` kernel. Calling it on CPU tensors raises + ``"no kernel registered"`` from PyTorch; do CPU dispatch above the op + boundary. + - Decorating a function twice with magi_register_custom_op: the + second decoration receives the wrapper from the first, not the user's + original function, and produces a confusing schema error. Decorate at + most once per function object. Arguments: name: The fully qualified name of the operator (e.g., "namespace::op_name"). @@ -218,15 +366,26 @@ def magi_register_custom_op( infer_output_meta_fn: Specifies output tensor metadata (shape, dtype, device) for tracing. - None (default): Assumes each output has the same metadata as the corresponding input tensor (1st output matches 1st tensor input, 2nd matches 2nd, etc.). + On the triton path, when None is passed the decorated function itself is used as + the fake/meta implementation (must be make_fx-traceable, which it is once kernel + calls go through ``wrap_triton``). - list[str]: Parameter names whose metadata to use for outputs. 
E.g., ["weight", "bias"] means output[0] has same shape as `weight`, output[1] has same shape as `bias`. - - Callable: Custom function with same signature as the op, returns torch.empty_like() - tensors matching the expected output shapes. + - Callable: Custom function with same signature as the op (in the + *original* signature space, including dataclass arguments + -- the bridge handles flattening for you), returns + torch.empty_like() tensors matching the expected output shapes. setup_context_fn: Function to save tensors/values for backward. - Signature: setup_context_fn(ctx, inputs, output) + Signature: ``setup_context_fn(ctx, inputs, output)``. ``inputs`` + mirrors the *original* signature: dataclass arguments are + reassembled into their original instances rather than exposed as + flat fields. Safe to use both with and without dataclass inputs. backward_fn: Function to compute gradients. - Signature: backward_fn(ctx, *grad_outputs) -> tuple of gradients + Signature: ``backward_fn(ctx, *grad_outputs) -> tuple of grads``. + Return one grad per *original* parameter (use ``None`` for + non-differentiable / non-tensor parameters). For dataclass + parameters see "Autograd with dataclass inputs" above. is_compute_sensitive: If True, marks this operator as compute-intensive (e.g., MatMul, Attention). During activation recomputation (rematerialization), outputs of compute-sensitive ops are prioritized for saving rather than recomputing, @@ -235,6 +394,13 @@ def magi_register_custom_op( compilation. Each sub-graph between boundary operators is compiled independently by Inductor, enabling piecewise compilation and more flexible scheduling (e.g., for CPU offloading or overlapping computation with data transfer). + extra_triton_kernels: Optional explicit list of triton kernels (``triton.jit`` / + ``triton.autotune`` objects) referenced inside the decorated function. 
Use this + when automatic source-based detection fails to discover a kernel + (e.g., kernel stored on ``self``, kernel selected by a user-defined ``maybe_capture`` + wrapper, etc.). Kernels listed here are merged with the auto-detected + ones and deduplicated by object identity, so it is safe (and harmless) + to also list a kernel that is statically detectable. Returns: The registered custom operator function. @@ -281,6 +447,74 @@ def magi_register_custom_op( ... ) ... def square(x: torch.Tensor) -> torch.Tensor: ... return x * x + + 4. With a (nested) frozen dataclass argument (auto pytree-flattened): + + >>> from dataclasses import dataclass + >>> + >>> @dataclass(frozen=True) + ... class NormCfg: + ... eps: float + ... affine: bool + ... + >>> @dataclass(frozen=True) + ... class AttnCfg: + ... scale: float + ... norm: NormCfg # nested dataclass field + ... + >>> @magi_register_custom_op() + ... def my_attn(q: torch.Tensor, k: torch.Tensor, cfg: AttnCfg) -> torch.Tensor: + ... out = (q @ k.transpose(-1, -2)) * cfg.scale + ... return out / (out.std() + cfg.norm.eps) + + Internally the registered op has flat parameters + ``q, k, cfg__scale, cfg__norm__eps, cfg__norm__affine``; users still + call ``my_attn(q, k, AttnCfg(scale=..., norm=NormCfg(...)))``. + + 5. Dataclass input + custom backward (signature is the original one): + + >>> @dataclass(frozen=True) + ... class ScaleCfg: + ... scale: float + ... + >>> def _setup(ctx, inputs, output): + ... x, cfg = inputs # original signature view + ... ctx.save_for_backward(x) + ... ctx.scale = cfg.scale + ... + >>> def _bwd(ctx, grad_out): + ... # one grad per ORIGINAL input; dataclass slot -> ``None``. + ... return grad_out * ctx.scale, None + ... + >>> @magi_register_custom_op( + ... setup_context_fn=_setup, + ... backward_fn=_bwd, + ... ) + ... def scale_op(x: torch.Tensor, cfg: ScaleCfg) -> torch.Tensor: + ... return x * cfg.scale + + 6. 
Triton kernel inside the body, no manual ``wrap_triton`` needed: + + >>> import triton + >>> import triton.language as tl + >>> + >>> @triton.jit + ... def cos_kernel(in_ptr, out_ptr, n, BLOCK_SIZE: tl.constexpr): + ... pid = tl.program_id(axis=0) + ... offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + ... mask = offsets < n + ... x = tl.load(in_ptr + offsets, mask=mask) + ... tl.store(out_ptr + offsets, tl.cos(x), mask=mask) + ... + >>> @magi_register_custom_op() + ... def my_cos(x: torch.Tensor) -> torch.Tensor: + ... out = torch.empty_like(x) + ... n = x.numel() + ... # Plain ``kernel[grid](...)`` -- the decorator detects this and + ... # registers ``my_cos`` as a triton_op so torch.compile can + ... # inline ``cos_kernel``. + ... cos_kernel[((n + 127) // 128,)](x, out, n, BLOCK_SIZE=128) + ... return out """ return _magi_register_custom_op_impl( name=name, @@ -290,4 +524,5 @@ def magi_register_custom_op( backward_fn=backward_fn, is_compute_sensitive=is_compute_sensitive, is_subgraph_boundary=is_subgraph_boundary, + extra_triton_kernels=extra_triton_kernels, ) diff --git a/tests/api_tests/_triton_external_helpers.py b/tests/api_tests/_triton_external_helpers.py new file mode 100644 index 0000000..37cc253 --- /dev/null +++ b/tests/api_tests/_triton_external_helpers.py @@ -0,0 +1,91 @@ +# Copyright (c) 2026 SandAI. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""External helper module used by ``test_register_triton_op.py``. 
+ +The helpers below intentionally live in their own module so that, when the +test file imports them and calls them inside a ``magi_register_custom_op``- +decorated function, the helpers' ``__globals__`` are *this* module, not the +test module. That exercises the truly cross-module rebuild path in +``rewrite_fn_with_wrap_triton``. +""" + +from __future__ import annotations + +""" +External helper module for test_register_triton_op.py to verify +cross-module triton kernel introspection. +""" +import torch + +try: + import triton + import triton.language as tl + + HAS_TRITON = True +except ImportError: # pragma: no cover + HAS_TRITON = False + + +if HAS_TRITON: + + @triton.jit + def external_neg_kernel( + in_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr + ): + pid = tl.program_id(axis=0) + offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + x = tl.load(in_ptr + offsets, mask=mask) + tl.store(out_ptr + offsets, -x, mask=mask) + + @triton.jit + def external_double_kernel( + in_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr + ): + pid = tl.program_id(axis=0) + offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + x = tl.load(in_ptr + offsets, mask=mask) + tl.store(out_ptr + offsets, x * 2, mask=mask) + + def external_neg_launcher(x: torch.Tensor) -> torch.Tensor: + out = torch.empty_like(x) + n = x.numel() + external_neg_kernel[((n + 127) // 128,)](x, out, n, BLOCK_SIZE=128) + return out + + def external_double_launcher(x: torch.Tensor) -> torch.Tensor: + out = torch.empty_like(x) + n = x.numel() + external_double_kernel[((n + 127) // 128,)](x, out, n, BLOCK_SIZE=128) + return out + + def maybe_capture(kernel): + """Third-party-style thin wrapper around a triton kernel. 
+ + Some libraries return objects with a ``.fn`` attribute pointing back to + the underlying ``JITFunction``; we mimic that pattern here so the test + can confirm ``rewrite_fn_with_wrap_triton`` still recognises the + underlying kernel when users write ``maybe_capture(kernel)[grid](...)``. + """ + + class _Captured: + def __init__(self, k): + self.fn = k # introspector recognises objects with .fn + + def __getitem__(self, grid): + return self.fn[grid] + + return _Captured(kernel) diff --git a/tests/api_tests/test_register_custom_op.py b/tests/api_tests/test_register_custom_op.py index 9872e68..3b694f7 100644 --- a/tests/api_tests/test_register_custom_op.py +++ b/tests/api_tests/test_register_custom_op.py @@ -1,4 +1,4 @@ -# Copyright (c) 2025 SandAI. All Rights Reserved. +# Copyright (c) 2026 SandAI. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,17 +13,35 @@ # limitations under the License. """ -Tests for @magi_register_custom_op decorator functionality. - -This module tests: -- Basic custom op registration (forward only) -- Custom op with infer_output_meta_fn for torch.compile tracing -- Custom op with autograd support (setup_context + backward) -- Full custom op with all components -- Multiple outputs support -- Integration with magi_compile decorator +This test suite covers the plain ``@magi_register_custom_op`` path (without Triton kernels). 
+ +Coverage Matrix & Sections: +--------------------------- +SECTION 1: Core Registration & Meta Inference + - Basic usage & auto-generated namings + - Explicit / Identity / Param-name-based ``infer_output_meta_fn`` + - Name duplicate rejection & missing-namespace rejection + - Compile & Subgraph boundary / Compute-sensitive markings + +SECTION 2: Type Support & Schema Bridging + - **Dataclass flattening**: Single, nested, Optional fields (``Optional[Tensor/scalar]``, PEP 604), List fields (``list[int/bool/Tensor/Optional[Tensor]]``). + - **Downgrades & Scrubs**: ``Literal[str]``, ``Enum``, ``list[int]=[0]`` defaults. + - **Unsupported Rejections**: Mutable dataclasses, Return dataclasses, Top-level ``tuple/dict``, Unresolvable local types. + - **Special Inputs/Outputs**: ``dtype/device`` fields, Top-level ``Optional[Tensor]``, Returning ``list[Tensor]``. + +SECTION 3: Autograd Bridge + - Standard setup/backward for plain and dataclass inputs + - Tuple multi-output backward + - Per-field ``None`` gradients + - Backward function calling another magi op + +SECTION 4: Compositions & Python Semantics + - Kwarg-only dataclass calls + - ``mutates_args`` mapping + - Nested magi op calls (an op calling another op) """ +import dataclasses import tempfile from unittest.mock import patch @@ -33,7 +51,11 @@ from torch.testing import assert_close from magi_compiler.api import magi_compile, magi_register_custom_op -from magi_compiler.config import CompileConfig +from magi_compiler.config import CompileConfig, CompileMode + +# ============================================================================ +# SECTION 1: Core Registration & Meta Inference +# ============================================================================ class TestBasicRegistration: @@ -56,7 +78,9 @@ def test_multiple_inputs(self): """Test custom op with multiple input tensors.""" @magi_register_custom_op(name="test::multi_input_op", mutates_args=()) - def _multi_input_op(a: torch.Tensor, b: torch.Tensor, 
scale: float) -> torch.Tensor: + def _multi_input_op( + a: torch.Tensor, b: torch.Tensor, scale: float + ) -> torch.Tensor: return (a + b) * scale a = torch.randn(4, 8) @@ -74,13 +98,19 @@ class TestInferOutputMeta: def test_with_infer_output_meta(self): """Test that infer_output_meta_fn is correctly registered for tracing.""" - def _scaled_add_infer_output_meta(x: torch.Tensor, y: torch.Tensor, scale: float) -> torch.Tensor: + def _scaled_add_infer_output_meta( + x: torch.Tensor, y: torch.Tensor, scale: float + ) -> torch.Tensor: return torch.empty_like(x) @magi_register_custom_op( - name="test::scaled_add_op", mutates_args=(), infer_output_meta_fn=_scaled_add_infer_output_meta + name="test::scaled_add_op", + mutates_args=(), + infer_output_meta_fn=_scaled_add_infer_output_meta, ) - def _scaled_add_op(x: torch.Tensor, y: torch.Tensor, scale: float) -> torch.Tensor: + def _scaled_add_op( + x: torch.Tensor, y: torch.Tensor, scale: float + ) -> torch.Tensor: return (x + y) * scale x = torch.randn(4, 8) @@ -94,11 +124,20 @@ def _scaled_add_op(x: torch.Tensor, y: torch.Tensor, scale: float) -> torch.Tens def test_multiple_outputs_infer_meta(self): """Test infer_output_meta_fn with multiple outputs.""" - def _split_op_infer_output_meta(x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + def _split_op_infer_output_meta( + x: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: half_size = x.shape[-1] // 2 - return (x.new_empty((*x.shape[:-1], half_size)), x.new_empty((*x.shape[:-1], half_size))) + return ( + x.new_empty((*x.shape[:-1], half_size)), + x.new_empty((*x.shape[:-1], half_size)), + ) - @magi_register_custom_op(name="test::split_op", mutates_args=(), infer_output_meta_fn=_split_op_infer_output_meta) + @magi_register_custom_op( + name="test::split_op", + mutates_args=(), + infer_output_meta_fn=_split_op_infer_output_meta, + ) def _split_op(x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: half_size = x.shape[-1] // 2 # NOTE: Output cannot share the 
same memory with input @@ -113,126 +152,6 @@ def _split_op(x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: assert_close(out2, x[..., 4:]) -class TestAutograd: - """Tests for custom op with autograd support.""" - - def test_with_autograd(self): - """Test custom op with setup_context and backward functions.""" - - def _square_infer_output_meta(x: torch.Tensor) -> torch.Tensor: - return torch.empty_like(x) - - def _square_setup_context(ctx, inputs, output): - (x,) = inputs - ctx.save_for_backward(x) - - def _square_backward(ctx, grad_output): - (x,) = ctx.saved_tensors - return grad_output * 2 * x - - @magi_register_custom_op( - name="test::square_op", - mutates_args=(), - infer_output_meta_fn=_square_infer_output_meta, - setup_context_fn=_square_setup_context, - backward_fn=_square_backward, - ) - def _square_op(x: torch.Tensor) -> torch.Tensor: - return x * x - - x = torch.randn(4, 8, requires_grad=True) - output = _square_op(x) - loss = output.sum() - loss.backward() - - # Gradient of x^2 is 2x - expected_grad = 2 * x - assert_close(x.grad, expected_grad) - - def test_autograd_multiple_inputs(self): - """Test autograd with multiple input tensors.""" - - def _weighted_sum_infer_output_meta(a: torch.Tensor, b: torch.Tensor, weight: float) -> torch.Tensor: - return torch.empty_like(a) - - def _weighted_sum_setup_context(ctx, inputs, output): - a, b, weight = inputs - ctx.save_for_backward(a, b) - ctx.weight = weight - - def _weighted_sum_backward(ctx, grad_output): - a, b = ctx.saved_tensors - weight = ctx.weight - grad_a = grad_output * weight - grad_b = grad_output * (1 - weight) - return grad_a, grad_b, None # None for non-tensor input - - @magi_register_custom_op( - name="test::weighted_sum_op", - mutates_args=(), - infer_output_meta_fn=_weighted_sum_infer_output_meta, - setup_context_fn=_weighted_sum_setup_context, - backward_fn=_weighted_sum_backward, - ) - def _weighted_sum_op(a: torch.Tensor, b: torch.Tensor, weight: float) -> torch.Tensor: - return a * 
weight + b * (1 - weight) - - a = torch.randn(4, 8, requires_grad=True) - b = torch.randn(4, 8, requires_grad=True) - weight = 0.7 - - output = _weighted_sum_op(a, b, weight) - loss = output.sum() - loss.backward() - - expected_grad_a = torch.ones_like(a) * weight - expected_grad_b = torch.ones_like(b) * (1 - weight) - - assert_close(a.grad, expected_grad_a) - assert_close(b.grad, expected_grad_b) - - def test_autograd_multiple_outputs(self): - """Test autograd with multiple output tensors.""" - - def _split_scale_infer_output_meta(x: torch.Tensor, scale: float) -> tuple[torch.Tensor, torch.Tensor]: - half = x.shape[-1] // 2 - return (x.new_empty((*x.shape[:-1], half)), x.new_empty((*x.shape[:-1], half))) - - def _split_scale_setup_context(ctx, inputs, output): - x, scale = inputs - ctx.save_for_backward(x) - ctx.scale = scale - ctx.half = x.shape[-1] // 2 - - def _split_scale_backward(ctx, grad_out1, grad_out2): - (x,) = ctx.saved_tensors - scale = ctx.scale - # Reconstruct gradient for x - grad_x = torch.cat([grad_out1 * scale, grad_out2 * scale], dim=-1) - return grad_x, None - - @magi_register_custom_op( - name="test::split_scale_op", - mutates_args=(), - infer_output_meta_fn=_split_scale_infer_output_meta, - setup_context_fn=_split_scale_setup_context, - backward_fn=_split_scale_backward, - ) - def _split_scale_op(x: torch.Tensor, scale: float) -> tuple[torch.Tensor, torch.Tensor]: - half = x.shape[-1] // 2 - return x[..., :half] * scale, x[..., half:] * scale - - x = torch.randn(4, 8, requires_grad=True) - scale = 2.0 - - out1, out2 = _split_scale_op(x, scale) - loss = out1.sum() + out2.sum() - loss.backward() - - expected_grad = torch.ones_like(x) * scale - assert_close(x.grad, expected_grad) - - class TestAutoGeneratedName: """Tests for auto-generated operator name when name is not provided.""" @@ -258,7 +177,9 @@ def test_auto_name_multiple_outputs(self): """Test auto-generated name with multiple tensor outputs.""" @magi_register_custom_op() - def 
_auto_name_multi_out_op(a: torch.Tensor, b: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + def _auto_name_multi_out_op( + a: torch.Tensor, b: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor]: return torch.clone(a + 1), torch.clone(b + 2) def fn(a, b): @@ -284,7 +205,9 @@ def _auto_grad_backward(ctx, grad_output): (x,) = ctx.saved_tensors return grad_output * 2 * x - @magi_register_custom_op(setup_context_fn=_auto_grad_setup_context, backward_fn=_auto_grad_backward) + @magi_register_custom_op( + setup_context_fn=_auto_grad_setup_context, backward_fn=_auto_grad_backward + ) def _auto_name_square_op(x: torch.Tensor) -> torch.Tensor: return x * x @@ -322,7 +245,9 @@ def test_single_output_multiple_inputs_default_meta(self): """Test default meta function with multiple inputs but single tensor output.""" @magi_register_custom_op(name="test::default_meta_multi_in") - def _default_meta_multi_in_op(a: torch.Tensor, b: torch.Tensor, scale: float) -> torch.Tensor: + def _default_meta_multi_in_op( + a: torch.Tensor, b: torch.Tensor, scale: float + ) -> torch.Tensor: return (a + b) * scale def fn(a, b, scale): @@ -342,7 +267,9 @@ def test_multiple_outputs_default_meta(self): """Test default meta function with multiple tensor outputs.""" @magi_register_custom_op(name="test::default_meta_multi_out") - def _default_meta_multi_out_op(x: torch.Tensor, y: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + def _default_meta_multi_out_op( + x: torch.Tensor, y: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor]: # Clone to avoid aliasing issues return torch.clone(x * 2), torch.clone(y * 3) @@ -405,6 +332,26 @@ def fn(scale, x, offset, y): assert_close(out2, y * scale + offset) +@pytest.fixture() +def magi_compile_config(): + """Fixture to set up a clean compile configuration for magi_compile tests.""" + compile_config = CompileConfig( + compile_mode=CompileMode.TORCH_COMPILE, cache_root_dir=tempfile.mkdtemp() + ) + + with ( + 
patch("magi_compiler.api.get_compile_config") as mock_get_config, + patch("torch.distributed.get_rank") as mock_rank, + ): + mock_get_config.return_value = compile_config + mock_rank.return_value = 0 + yield compile_config + + import shutil + + shutil.rmtree(compile_config.cache_root_dir, ignore_errors=True) + + class TestTorchCompileIntegration: """Tests for integration with torch.compile.""" @@ -414,7 +361,11 @@ def test_custom_op_in_compiled_function(self): def _double_infer_output_meta(x: torch.Tensor) -> torch.Tensor: return torch.empty_like(x) - @magi_register_custom_op(name="test::double_op", mutates_args=(), infer_output_meta_fn=_double_infer_output_meta) + @magi_register_custom_op( + name="test::double_op", + mutates_args=(), + infer_output_meta_fn=_double_infer_output_meta, + ) def _double_op(x: torch.Tensor) -> torch.Tensor: return x * 2 @@ -469,21 +420,6 @@ def fn(x): assert_close(x.grad, expected_grad) -@pytest.fixture() -def magi_compile_config(): - """Fixture to set up a clean compile configuration for magi_compile tests.""" - compile_config = CompileConfig(cache_root_dir=tempfile.mkdtemp()) - - with patch("magi_compiler.api.get_compile_config") as mock_get_config, patch("torch.distributed.get_rank") as mock_rank: - mock_get_config.return_value = compile_config - mock_rank.return_value = 0 - yield compile_config - - import shutil - - shutil.rmtree(compile_config.cache_root_dir, ignore_errors=True) - - class TestMagiCompileIntegration: """Tests for integration with magi_compile decorator.""" @@ -493,7 +429,11 @@ def test_custom_op_in_magi_compiled_module(self, magi_compile_config): def _triple_infer_output_meta(x: torch.Tensor) -> torch.Tensor: return torch.empty_like(x) - @magi_register_custom_op(name="test::triple_op", mutates_args=(), infer_output_meta_fn=_triple_infer_output_meta) + @magi_register_custom_op( + name="test::triple_op", + mutates_args=(), + infer_output_meta_fn=_triple_infer_output_meta, + ) def _triple_op(x: torch.Tensor) -> 
torch.Tensor: return x * 3 @@ -568,7 +508,9 @@ def _relu_custom_infer_output_meta(x: torch.Tensor) -> torch.Tensor: return torch.empty_like(x) @magi_register_custom_op( - name="test::relu_custom_op", mutates_args=(), infer_output_meta_fn=_relu_custom_infer_output_meta + name="test::relu_custom_op", + mutates_args=(), + infer_output_meta_fn=_relu_custom_infer_output_meta, ) def _relu_custom_op(x: torch.Tensor) -> torch.Tensor: return torch.relu(x) @@ -599,11 +541,19 @@ def _add_one_infer_output_meta(x: torch.Tensor) -> torch.Tensor: def _mul_two_infer_output_meta(x: torch.Tensor) -> torch.Tensor: return torch.empty_like(x) - @magi_register_custom_op(name="test::add_one_op", mutates_args=(), infer_output_meta_fn=_add_one_infer_output_meta) + @magi_register_custom_op( + name="test::add_one_op", + mutates_args=(), + infer_output_meta_fn=_add_one_infer_output_meta, + ) def _add_one_op(x: torch.Tensor) -> torch.Tensor: return x + 1 - @magi_register_custom_op(name="test::mul_two_op", mutates_args=(), infer_output_meta_fn=_mul_two_infer_output_meta) + @magi_register_custom_op( + name="test::mul_two_op", + mutates_args=(), + infer_output_meta_fn=_mul_two_infer_output_meta, + ) def _mul_two_op(x: torch.Tensor) -> torch.Tensor: return x * 2 @@ -623,14 +573,25 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: expected = (x + 1) * 2 assert_close(output, expected) - def test_custom_op_multiple_outputs_in_magi_compiled_module(self, magi_compile_config): + def test_custom_op_multiple_outputs_in_magi_compiled_module( + self, magi_compile_config + ): """Test custom op with multiple outputs inside a magi_compile'd module.""" - def _split_v2_infer_output_meta(x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + def _split_v2_infer_output_meta( + x: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: half = x.shape[-1] // 2 - return (x.new_empty((*x.shape[:-1], half)), x.new_empty((*x.shape[:-1], half))) + return ( + x.new_empty((*x.shape[:-1], half)), + 
x.new_empty((*x.shape[:-1], half)), + ) - @magi_register_custom_op(name="test::split_v2_op", mutates_args=(), infer_output_meta_fn=_split_v2_infer_output_meta) + @magi_register_custom_op( + name="test::split_v2_op", + mutates_args=(), + infer_output_meta_fn=_split_v2_infer_output_meta, + ) def _split_v2_op(x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: half = x.shape[-1] // 2 return torch.clone(x[..., :half]), torch.clone(x[..., half:]) @@ -655,5 +616,1516 @@ def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: assert_close(out2, x[..., 4:] * 2) +class TestDuplicateOpNameRejected: + """Re-registering the same ``namespace::op_name`` should raise a clear + error instead of letting ``torch.library`` complain about schema + fingerprints. + """ + + def test_duplicate_name_rejected(self): + @magi_register_custom_op(name="test::dup_name_first") + def _op_a(x: torch.Tensor) -> torch.Tensor: + return x + 1 + + with pytest.raises(RuntimeError, match="already"): + + @magi_register_custom_op(name="test::dup_name_first") + def _op_b(x: torch.Tensor) -> torch.Tensor: + return x + 2 + + +class TestOpNameNamespaceRequired: + """``torch.library`` requires ``namespace::op_name``; a bare name causes a + confusing low-level error. We surface a clear, actionable message + pointing at the convention. 
+ """ + + def test_missing_namespace_rejected(self): + with pytest.raises(ValueError, match="namespace"): + + @magi_register_custom_op(name="missing_namespace_op") + def _op(x: torch.Tensor) -> torch.Tensor: + return x + + def test_namespaced_name_accepted(self): + @magi_register_custom_op(name="test::ns_ok") + def _op(x: torch.Tensor) -> torch.Tensor: + return x + 1 + + assert_close(_op(torch.zeros(2)), torch.ones(2)) + + +@dataclasses.dataclass(frozen=True) +class _CSCfg: + s: float + + +class TestDataclassWithComputeSensitiveSmoke: + """Dataclass input + ``is_compute_sensitive=True`` should register cleanly + and the op name should land in ``custom_compute_sensitive_ops``. + """ + + def test_dataclass_compute_sensitive_smoke(self): + from magi_compiler.config import get_compile_config + + @magi_register_custom_op( + name="test::dc_compute_sensitive", + is_compute_sensitive=True, + ) + def _op(x: torch.Tensor, cfg: _CSCfg) -> torch.Tensor: + return x * cfg.s + + out = _op(torch.ones(2), _CSCfg(s=2.0)) + assert_close(out, torch.full((2,), 2.0)) + assert ( + "test::dc_compute_sensitive" + in get_compile_config().recompute_config.custom_compute_sensitive_ops + ) + + +# ============================================================================ +# SECTION 2: Type Support & Schema Bridging +# ============================================================================ + + +class TestFrozenDataclassInput: + """Tests for frozen-dataclass input support via the auto pytree path.""" + + def test_forward_with_dataclass_only(self): + """Custom op whose only argument is a frozen dataclass.""" + from dataclasses import dataclass + + @dataclass(frozen=True) + class _ScaleCfg: + scale: float + bias: float + + @magi_register_custom_op(name="test::dc_only_op", mutates_args=()) + def _dc_only_op(cfg: _ScaleCfg) -> torch.Tensor: + return torch.full((2, 3), cfg.scale + cfg.bias) + + cfg = _ScaleCfg(scale=2.0, bias=1.0) + out = _dc_only_op(cfg) + assert_close(out, torch.full((2, 3), 
3.0)) + + def test_forward_with_mixed_args(self): + """Custom op with mixed dataclass and tensor inputs.""" + from dataclasses import dataclass + + @dataclass(frozen=True) + class _AttnCfg: + scale: float + causal: bool + + @magi_register_custom_op(name="test::dc_mixed_op", mutates_args=()) + def _dc_mixed_op( + q: torch.Tensor, k: torch.Tensor, cfg: _AttnCfg + ) -> torch.Tensor: + out = (q @ k.transpose(-1, -2)) * cfg.scale + if cfg.causal: + mask = torch.tril(torch.ones_like(out, dtype=torch.bool)) + out = out.masked_fill(~mask, 0.0) + return out + + q = torch.randn(2, 4, 8) + k = torch.randn(2, 4, 8) + cfg = _AttnCfg(scale=0.25, causal=False) + out = _dc_mixed_op(q, k, cfg) + expected = (q @ k.transpose(-1, -2)) * 0.25 + assert_close(out, expected) + + # Try positional + keyword call form to ensure outer wrapper handles both. + out_kw = _dc_mixed_op(q, k=k, cfg=cfg) + assert_close(out_kw, expected) + + def test_forward_with_custom_meta_fn(self): + """Custom op with a user-provided meta function expressed in dataclass terms.""" + from dataclasses import dataclass + + @dataclass(frozen=True) + class _ProjCfg: + out_dim: int + + def _proj_meta(x: torch.Tensor, cfg: _ProjCfg) -> torch.Tensor: + return x.new_empty((*x.shape[:-1], cfg.out_dim)) + + @magi_register_custom_op( + name="test::dc_meta_op", mutates_args=(), infer_output_meta_fn=_proj_meta + ) + def _dc_meta_op(x: torch.Tensor, cfg: _ProjCfg) -> torch.Tensor: + return x[..., : cfg.out_dim].clone() + + x = torch.randn(2, 8) + cfg = _ProjCfg(out_dim=3) + out = _dc_meta_op(x, cfg) + assert out.shape == (2, 3) + assert_close(out, x[..., :3]) + + +class TestNestedDataclassInput: + """Tests for *recursively* flattened nested-dataclass inputs.""" + + def test_nested_dataclass_only(self): + """A single dataclass argument that itself contains a dataclass field.""" + from dataclasses import dataclass + + @dataclass(frozen=True) + class _Inner: + scale: float + bias: float + + @dataclass(frozen=True) + class _Outer: + 
inner: _Inner + offset: float + + @magi_register_custom_op(name="test::nested_dc_only_op", mutates_args=()) + def _nested_dc_only_op(cfg: _Outer) -> torch.Tensor: + return torch.full((2, 3), cfg.inner.scale * cfg.inner.bias + cfg.offset) + + cfg = _Outer(inner=_Inner(scale=2.0, bias=3.0), offset=1.5) + out = _nested_dc_only_op(cfg) + assert_close(out, torch.full((2, 3), 7.5)) + + # The flat plan should fully expand the nested dataclass into leaves. + plan = _nested_dc_only_op._magi_flat_plan + assert plan[0][0] == "dataclass" + assert plan[0][1] == "cfg" + children = plan[0][3] + kinds = [c[0] for c in children] + assert ( + "dataclass" in kinds + ), "inner dataclass field must remain a dataclass node" + # Find the leaf flat names. + flat_names: list[str] = [] + + def _collect(node): + if node[0] == "primitive": + flat_names.append(node[2]) + else: + for child in node[3]: + _collect(child) + + _collect(plan[0]) + assert "cfg__inner__scale" in flat_names + assert "cfg__inner__bias" in flat_names + assert "cfg__offset" in flat_names + + def test_nested_dataclass_mixed_with_tensor(self): + """Tensor arg alongside a nested dataclass arg.""" + from dataclasses import dataclass + + @dataclass(frozen=True) + class _NormCfg: + eps: float + affine: bool + + @dataclass(frozen=True) + class _LayerCfg: + norm: _NormCfg + scale: float + + @magi_register_custom_op(name="test::nested_dc_mixed_op", mutates_args=()) + def _nested_dc_mixed_op(x: torch.Tensor, cfg: _LayerCfg) -> torch.Tensor: + y = x / (x.std() + cfg.norm.eps) + if cfg.norm.affine: + y = y * cfg.scale + return y + + x = torch.randn(4, 8) + cfg = _LayerCfg(norm=_NormCfg(eps=1e-3, affine=True), scale=2.0) + out = _nested_dc_mixed_op(x, cfg) + expected = x / (x.std() + 1e-3) * 2.0 + assert_close(out, expected) + + # Keyword form too. 
+ out_kw = _nested_dc_mixed_op(x=x, cfg=cfg) + assert_close(out_kw, expected) + + def test_deeply_nested_dataclass(self): + """Three levels of nesting plus a sibling primitive at the top level.""" + from dataclasses import dataclass + + @dataclass(frozen=True) + class _Leaf: + val: float + + @dataclass(frozen=True) + class _Mid: + leaf: _Leaf + extra: int + + @dataclass(frozen=True) + class _Root: + mid: _Mid + tag: float + + @magi_register_custom_op(name="test::deep_nested_dc_op", mutates_args=()) + def _deep_nested_dc_op( + x: torch.Tensor, cfg: _Root, alpha: float + ) -> torch.Tensor: + return x * cfg.mid.leaf.val + cfg.mid.extra + cfg.tag + alpha + + x = torch.randn(2, 3) + cfg = _Root(mid=_Mid(leaf=_Leaf(val=2.0), extra=3), tag=0.5) + out = _deep_nested_dc_op(x, cfg, alpha=1.0) + assert_close(out, x * 2.0 + 3 + 0.5 + 1.0) + + def test_nested_dataclass_with_meta_fn(self): + """User-supplied meta function expressed in nested-dataclass terms.""" + from dataclasses import dataclass + + @dataclass(frozen=True) + class _ShapeCfg: + out_dim: int + + @dataclass(frozen=True) + class _ProjCfg: + shape: _ShapeCfg + scale: float + + def _proj_meta(x: torch.Tensor, cfg: _ProjCfg) -> torch.Tensor: + return x.new_empty((*x.shape[:-1], cfg.shape.out_dim)) + + @magi_register_custom_op( + name="test::nested_dc_meta_op", + mutates_args=(), + infer_output_meta_fn=_proj_meta, + ) + def _nested_dc_meta_op(x: torch.Tensor, cfg: _ProjCfg) -> torch.Tensor: + return x[..., : cfg.shape.out_dim].clone() * cfg.scale + + x = torch.randn(2, 8) + cfg = _ProjCfg(shape=_ShapeCfg(out_dim=3), scale=2.0) + out = _nested_dc_meta_op(x, cfg) + assert out.shape == (2, 3) + assert_close(out, x[..., :3] * 2.0) + + +class TestDataclassOptionalFields: + """Frozen dataclass fields annotated with ``Optional[...]`` should flow + through the recursive flattener and be accepted by + ``torch.library.infer_schema`` as long as the underlying type is one of + the schema's known optional types. 
+ + Coverage matrix (verified against torch 2.9.1): + + - ``Optional[Tensor]`` ✓ + - ``Optional[int]`` ✓ + - ``Optional[float]`` ✓ + - ``Optional[bool]`` ✓ + - ``Optional[str]`` ✓ + - ``Optional[list[int]]`` ✓ + - ``Tensor | None`` (PEP 604) ✓ + - ``Optional[list[Tensor]]`` ✗ (PyTorch limitation, see negative + test below; users should use ``list[Optional[Tensor]]`` instead) + """ + + def test_optional_scalar_and_tensor_fields(self): + from dataclasses import dataclass + from typing import Optional + + @dataclass(frozen=True) + class _Cfg: + bias: Optional[torch.Tensor] + scale: Optional[float] + mode: Optional[str] + block_sizes: Optional[list[int]] + flag: Optional[bool] + count: Optional[int] + + @magi_register_custom_op(name="test::dc_optional_mix") + def _op(x: torch.Tensor, cfg: _Cfg) -> torch.Tensor: + out = x.clone() + if cfg.bias is not None: + out = out + cfg.bias + if cfg.scale is not None: + out = out * cfg.scale + return out + + x = torch.ones(3) + # All-None: just the identity transform. + out_all_none = _op(x, _Cfg(None, None, None, None, None, None)) + assert_close(out_all_none, x) + # Some-None: bias + scale active, others ignored by the body. 
+ cfg = _Cfg( + bias=torch.tensor([1.0, 2.0, 3.0]), + scale=2.0, + mode="a", + block_sizes=[4, 8], + flag=True, + count=7, + ) + out_some = _op(x, cfg) + assert_close(out_some, torch.tensor([4.0, 6.0, 8.0])) + + def test_pep604_optional_tensor_field(self): + """The ``X | None`` PEP 604 syntax should be equivalent to + ``Optional[X]`` for dataclass fields.""" + from dataclasses import dataclass + + @dataclass(frozen=True) + class _Cfg: + bias: torch.Tensor | None + + @magi_register_custom_op(name="test::dc_pep604_optional") + def _op(x: torch.Tensor, cfg: _Cfg) -> torch.Tensor: + return x.clone() if cfg.bias is None else x + cfg.bias + + x = torch.ones(2) + assert_close(_op(x, _Cfg(bias=None)), x) + assert_close( + _op(x, _Cfg(bias=torch.tensor([10.0, 20.0]))), + torch.tensor([11.0, 21.0]), + ) + + def test_optional_list_of_tensors_unsupported_by_torch_library(self): + """Sanity-pin a known-broken case so a future torch upgrade that + relaxes ``infer_schema`` flips this from xfail to pass. + + ``Optional[list[Tensor]]`` is NOT in the torch.library schema's + accepted-type set (although ``list[Optional[Tensor]]`` is). The + magi flattener doesn't try to rewrite the user's annotation, so we + let the underlying ValueError propagate.""" + from dataclasses import dataclass + from typing import Optional + + @dataclass(frozen=True) + class _Cfg: + biases: Optional[list[torch.Tensor]] + + with pytest.raises(ValueError, match="unsupported type"): + + @magi_register_custom_op(name="test::dc_optional_list_tensor") + def _op(x: torch.Tensor, cfg: _Cfg) -> torch.Tensor: + return x + + +class TestDataclassListFields: + """Frozen dataclass fields annotated with ``list[...]`` flow through the + flattener and are accepted by ``torch.library.infer_schema`` for the + schema's supported scalar/tensor element types. 
+ + Coverage matrix (verified against torch 2.9.1): + + - ``list[Tensor]`` ✓ + - ``list[int]`` ✓ + - ``list[float]`` ✓ + - ``list[bool]`` ✓ + - ``list[Optional[Tensor]]`` ✓ + """ + + def test_list_of_tensors_and_scalars_in_dataclass(self): + from dataclasses import dataclass + + @dataclass(frozen=True) + class _Cfg: + biases: list[torch.Tensor] + scales: list[float] + block_sizes: list[int] + flags: list[bool] + + @magi_register_custom_op(name="test::dc_list_mix") + def _op(x: torch.Tensor, cfg: _Cfg) -> torch.Tensor: + out = x.clone() + for b in cfg.biases: + out = out + b + out = out * float(sum(cfg.scales)) + # ``flags`` and ``block_sizes`` are accepted on the schema side + # even though the body here doesn't depend on them. + assert all(isinstance(s, int) for s in cfg.block_sizes) + assert all(isinstance(f, bool) for f in cfg.flags) + return out + + x = torch.ones(3) + cfg = _Cfg( + biases=[torch.tensor([1.0, 2.0, 3.0]), torch.tensor([10.0, 20.0, 30.0])], + scales=[2.0, 0.5], + block_sizes=[4, 8], + flags=[True, False], + ) + out = _op(x, cfg) + # (1 + 1 + 10) * 2.5 = 30, (1 + 2 + 20) * 2.5 = 57.5, ... 
+ assert_close(out, torch.tensor([30.0, 57.5, 85.0])) + + def test_list_of_optional_tensors_in_dataclass(self): + """``list[Optional[Tensor]]`` is on the schema's supported list (in + contrast to the rejected ``Optional[list[Tensor]]``).""" + from dataclasses import dataclass + from typing import Optional + + @dataclass(frozen=True) + class _Cfg: + maybe_biases: list[Optional[torch.Tensor]] + + @magi_register_custom_op(name="test::dc_list_optional_tensor") + def _op(x: torch.Tensor, cfg: _Cfg) -> torch.Tensor: + out = x.clone() + for b in cfg.maybe_biases: + if b is not None: + out = out + b + return out + + x = torch.ones(2) + out = _op( + x, + _Cfg(maybe_biases=[None, torch.tensor([10.0, 20.0]), None]), + ) + assert_close(out, torch.tensor([11.0, 21.0])) + + def test_empty_list_field(self): + """An empty ``list[Tensor]`` value should round-trip through the + flattener and run the body normally.""" + from dataclasses import dataclass + + @dataclass(frozen=True) + class _Cfg: + biases: list[torch.Tensor] + + @magi_register_custom_op(name="test::dc_empty_list") + def _op(x: torch.Tensor, cfg: _Cfg) -> torch.Tensor: + out = x.clone() + for b in cfg.biases: + out = out + b + return out + + x = torch.ones(3) + assert_close(_op(x, _Cfg(biases=[])), x) + + +class TestDataclassFieldDefaults: + """Dataclass field ``default`` / ``default_factory`` values should + propagate to the flat ``inspect.Signature`` so that calling the op + without those fields works (and ``torch.library.infer_schema`` records + them as optional).""" + + def test_dataclass_field_default_carried_into_signature(self): + from dataclasses import dataclass + + @dataclass(frozen=True) + class _Cfg: + scale: float = 2.0 + + @magi_register_custom_op(name="test::dc_field_default_scale") + def _op(x: torch.Tensor, cfg: _Cfg) -> torch.Tensor: + return x * cfg.scale + + x = torch.ones(3) + # User constructs the dataclass with the default value implicitly. 
+ assert_close(_op(x, _Cfg()), torch.full((3,), 2.0)) + # And the override path still works. + assert_close(_op(x, _Cfg(scale=10.0)), torch.full((3,), 10.0)) + + def test_dataclass_field_default_factory_for_list(self): + from dataclasses import dataclass, field + + @dataclass(frozen=True) + class _Cfg: + block_sizes: list[int] = field(default_factory=lambda: [4, 8]) + + @magi_register_custom_op(name="test::dc_field_default_factory") + def _op(x: torch.Tensor, cfg: _Cfg) -> torch.Tensor: + return x * float(sum(cfg.block_sizes)) + + x = torch.ones(2) + assert_close(_op(x, _Cfg()), torch.full((2,), 12.0)) + assert_close(_op(x, _Cfg(block_sizes=[1, 2, 3])), torch.full((2,), 6.0)) + + def test_nested_dataclass_with_default(self): + from dataclasses import dataclass + + @dataclass(frozen=True) + class _Inner: + val: float = 3.0 + + @dataclass(frozen=True) + class _Outer: + inner: _Inner = _Inner() + scale: float = 1.0 + + @magi_register_custom_op(name="test::dc_nested_default") + def _op(x: torch.Tensor, cfg: _Outer) -> torch.Tensor: + return x * cfg.inner.val * cfg.scale + + x = torch.ones(2) + assert_close(_op(x, _Outer()), torch.full((2,), 3.0)) + assert_close( + _op(x, _Outer(inner=_Inner(val=5.0), scale=2.0)), + torch.full((2,), 10.0), + ) + + +class TestDataclassUnsupportedContainerFields: + """``tuple[...]`` and ``dict[...]`` dataclass field annotations are not + accepted by ``torch.library.infer_schema`` (only ``list[...]`` is). 
We + reject them at registration time with an actionable hint pointing to the + field name and suggested fix.""" + + def test_tuple_field_rejected_with_hint(self): + from dataclasses import dataclass + + @dataclass(frozen=True) + class _Cfg: + qkv: tuple[torch.Tensor, torch.Tensor, torch.Tensor] + + with pytest.raises(TypeError, match=r"list\[\.\.\.\]"): + + @magi_register_custom_op(name="test::dc_tuple_field") + def _op(x: torch.Tensor, cfg: _Cfg) -> torch.Tensor: + return x + + def test_dict_field_rejected_with_hint(self): + from dataclasses import dataclass + + @dataclass(frozen=True) + class _Cfg: + weight_by_name: dict[str, torch.Tensor] + + with pytest.raises(TypeError, match="dict-typed"): + + @magi_register_custom_op(name="test::dc_dict_field") + def _op(x: torch.Tensor, cfg: _Cfg) -> torch.Tensor: + return x + + +class TestMutableDataclassRejected: + """``magi_register_custom_op`` only supports ``@dataclass(frozen=True)``. + + A non-frozen dataclass passed as an op input would otherwise produce a + confusing ``ValueError: Unsupported type annotation X. It is not a type`` + deep inside ``torch.library.infer_schema``. We surface a clear error + pointing at the fix. 
+ """ + + def test_top_level_mutable_dataclass_rejected(self): + from dataclasses import dataclass + + @dataclass # NOTE: missing frozen=True + class _MutableCfg: + scale: float + + with pytest.raises(TypeError, match="frozen=True"): + + @magi_register_custom_op(name="test::mutable_dc_top") + def _op(x: torch.Tensor, cfg: _MutableCfg) -> torch.Tensor: + return x * cfg.scale + + def test_nested_mutable_dataclass_rejected(self): + """Inner non-frozen dataclass nested inside an outer frozen one is + also detected; the error names the offending field.""" + from dataclasses import dataclass + + @dataclass # NOT frozen + class _MutableInner: + val: float + + @dataclass(frozen=True) + class _OuterCfg: + inner: _MutableInner + + with pytest.raises(TypeError, match="_MutableInner"): + + @magi_register_custom_op(name="test::mutable_dc_nested") + def _op(x: torch.Tensor, cfg: _OuterCfg) -> torch.Tensor: + return x * cfg.inner.val + + def test_frozen_dataclass_still_works(self): + """Sanity: the rejection path doesn't break the supported case.""" + from dataclasses import dataclass + + @dataclass(frozen=True) + class _OkCfg: + scale: float + + @magi_register_custom_op(name="test::frozen_dc_ok") + def _op(x: torch.Tensor, cfg: _OkCfg) -> torch.Tensor: + return x * cfg.scale + + out = _op(torch.ones(4), _OkCfg(scale=3.0)) + assert_close(out, torch.full((4,), 3.0)) + + +class TestDataclassReturnRejected: + """``torch.library`` cannot model dataclass returns; we should refuse with + an actionable error instead of letting ``infer_schema`` raise an opaque + ``ValueError`` from the depths. 
+ """ + + def test_dataclass_return_rejected(self): + from dataclasses import dataclass + + @dataclass(frozen=True) + class _Out: + a: torch.Tensor + + with pytest.raises(TypeError, match="dataclass"): + + @magi_register_custom_op(name="test::dc_return") + def _op(x: torch.Tensor) -> _Out: + return _Out(a=x.clone()) + + def test_tensor_return_still_works(self): + @magi_register_custom_op(name="test::tensor_return_ok") + def _op(x: torch.Tensor) -> torch.Tensor: + return x + 1 + + out = _op(torch.zeros(3)) + assert_close(out, torch.ones(3)) + + +class TestDataclassUnresolvableLocalField: + """A dataclass field annotated with a class defined inside another + function (a "local class") under ``from __future__ import annotations`` + cannot be resolved by ``typing.get_type_hints``. We surface a clear + error pointing at the fix instead of a confusing schema error. + """ + + def test_unresolvable_local_field_rejected(self): + # We synthesize the failure by hand-constructing a frozen dataclass + # whose field type is a *string* that doesn't resolve in any visible + # namespace. This mirrors what ``from __future__ import annotations`` + # produces for a local class. + import dataclasses + + @dataclasses.dataclass(frozen=True) + class _OuterCfg: + inner: "_DefinitelyNotInScope" # noqa: F821 + + with pytest.raises(TypeError, match="unresolved string annotation"): + + @magi_register_custom_op(name="test::dc_unresolvable_local") + def _op(x: torch.Tensor, cfg: _OuterCfg) -> torch.Tensor: + return x + + +@dataclasses.dataclass(frozen=True) +class _OptScalarCfg: + softcap: float | None = None + causal: bool | None = None + + +class TestDataclassOptionalScalarFields: + """Common attention/cfg pattern: dataclass field is ``Optional[scalar]`` + with a ``None`` default. Verify forward across explicit-None and + explicit-value calls. 
+ """ + + def test_optional_scalar_defaults_and_overrides(self): + @magi_register_custom_op(name="test::dc_opt_scalar") + def _op(x: torch.Tensor, cfg: _OptScalarCfg) -> torch.Tensor: + y = x.clone() + if cfg.softcap is not None: + y = torch.tanh(y / cfg.softcap) * cfg.softcap + if cfg.causal is True: + y = y * 0.5 + return y + + x = torch.randn(4) + assert_close(_op(x, _OptScalarCfg()), x) + out = _op(x, _OptScalarCfg(softcap=2.0, causal=True)) + expected = torch.tanh(x / 2.0) * 2.0 * 0.5 + assert_close(out, expected) + + +@dataclasses.dataclass(frozen=True) +class _DtypeCfg: + out_dtype: torch.dtype = torch.float16 + out_device: torch.device | None = None + + +class TestDataclassDtypeDeviceFields: + """``torch.dtype`` / ``torch.device`` are the canonical schema-supported + "structured scalar" types. Make sure they survive the dataclass-flatten + round-trip with sensible defaults. + """ + + def test_dtype_device_round_trip(self): + @magi_register_custom_op(name="test::dc_dtype_device") + def _op(x: torch.Tensor, cfg: _DtypeCfg) -> torch.Tensor: + return x.to(dtype=cfg.out_dtype) + + x = torch.randn(4) + out = _op(x, _DtypeCfg(out_dtype=torch.bfloat16)) + assert out.dtype == torch.bfloat16 + + +class TestTopLevelTupleDictRejected: + """Top-level parameters of type ``tuple[T, ...]`` / ``dict[K, V]`` are + rejected with the same actionable message as dataclass fields. 
+ """ + + def test_top_level_tuple_rejected(self): + with pytest.raises(TypeError, match="tuple"): + + @magi_register_custom_op(name="test::top_tuple") + def _op(xs: tuple[torch.Tensor, ...]) -> torch.Tensor: + return xs[0] + + def test_top_level_dict_rejected(self): + with pytest.raises(TypeError, match="dict"): + + @magi_register_custom_op(name="test::top_dict") + def _op(xs: dict[str, torch.Tensor]) -> torch.Tensor: + return next(iter(xs.values())) + + +class TestLiteralAndEnumDowngrade: + """``Literal[str, ...]`` and string-valued ``Enum`` annotations are + auto-downgraded to ``str`` at the schema boundary; the op body still + receives the raw string. Numeric/heterogeneous Literals are rejected + with a clear error. + """ + + def test_literal_str_downgraded(self): + from typing import Literal + + @magi_register_custom_op(name="test::literal_str") + def _op(x: torch.Tensor, mode: Literal["a", "b"]) -> torch.Tensor: + return x * (2.0 if mode == "a" else 3.0) + + x = torch.ones(3) + assert_close(_op(x, "a"), torch.full((3,), 2.0)) + assert_close(_op(x, "b"), torch.full((3,), 3.0)) + + def test_string_enum_downgraded(self): + import enum + + class _Mode(enum.Enum): + A = "a" + B = "b" + + @magi_register_custom_op(name="test::enum_str") + def _op(x: torch.Tensor, mode: _Mode) -> torch.Tensor: + return x * (2.0 if mode == "a" else 3.0) + + # Note: caller passes the raw string value (the schema sees ``str``). 
+ x = torch.ones(3) + assert_close(_op(x, "a"), torch.full((3,), 2.0)) + + def test_int_literal_rejected(self): + from typing import Literal + + with pytest.raises(TypeError, match="Literal"): + + @magi_register_custom_op(name="test::literal_int") + def _op(x: torch.Tensor, k: Literal[1, 2]) -> torch.Tensor: + return x + + def test_int_valued_enum_rejected(self): + import enum + + class _IntMode(enum.Enum): + A = 1 + B = 2 + + with pytest.raises(TypeError, match="Enum"): + + @magi_register_custom_op(name="test::enum_int") + def _op(x: torch.Tensor, mode: _IntMode) -> torch.Tensor: + return x + + +class TestTopLevelOptionalTensor: + """``Optional[Tensor]`` / ``Tensor | None`` is the most common attention / + bias / mask pattern. Verify both annotation flavours, with and without + a ``None`` default, and across forward + autograd paths. + """ + + def test_optional_tensor_default_none(self): + @magi_register_custom_op(name="test::opt_tensor_default") + def _op(x: torch.Tensor, bias: torch.Tensor | None = None) -> torch.Tensor: + return x.clone() if bias is None else x + bias + + x = torch.randn(3, 4) + b = torch.randn(3, 4) + assert_close(_op(x), x) + assert_close(_op(x, b), x + b) + + def test_optional_tensor_no_default_explicit_none(self): + @magi_register_custom_op(name="test::opt_tensor_explicit") + def _op(x: torch.Tensor, bias: torch.Tensor | None) -> torch.Tensor: + return x.clone() if bias is None else x + bias + + x = torch.randn(3, 4) + assert_close(_op(x, None), x) + + +class TestReturnListTensor: + """``list[Tensor]`` is the canonical return shape for split / chunk / + expert-output ops. Validate forward and that a user-supplied + ``infer_output_meta_fn`` is required for tracing. 
+ """ + + def test_list_tensor_return_runtime(self): + def meta(x, n): + return [torch.empty_like(x.chunk(n, dim=0)[0]) for _ in range(n)] + + @magi_register_custom_op( + name="test::list_tensor_return", + infer_output_meta_fn=meta, + ) + def _op(x: torch.Tensor, n: int) -> list[torch.Tensor]: + # ``.chunk`` returns views that alias ``x``; clone each chunk so + # the op doesn't violate torch.library's no-aliasing invariant + # for outputs. + return [c.clone() for c in x.chunk(n, dim=0)] + + x = torch.randn(8, 4) + out = _op(x, 4) + assert isinstance(out, list) + assert len(out) == 4 + assert all(o.shape == (2, 4) for o in out) + + +class TestTopLevelListIntDefault: + """A top-level parameter typed ``list[int]`` with a non-None default + should not crash registration (the previous behavior was to forward the + list literal to ``infer_schema``, which rejects it). + """ + + def test_list_int_default_does_not_crash(self): + @magi_register_custom_op(name="test::list_int_default") + def _op(x: torch.Tensor, dims: list[int] = [0]) -> torch.Tensor: + return x.sum(dim=dims, keepdim=True) + + x = torch.randn(3, 4) + # Caller still has to supply the list explicitly because the schema + # default was scrubbed; that's the documented price of using a list + # default with torch.library. 
+ out = _op(x, [0]) + assert out.shape == (1, 4) + + +# ============================================================================ +# SECTION 3: Autograd Bridge +# ============================================================================ + + +class TestAutograd: + """Tests for custom op with autograd support.""" + + def test_with_autograd(self): + """Test custom op with setup_context and backward functions.""" + + def _square_infer_output_meta(x: torch.Tensor) -> torch.Tensor: + return torch.empty_like(x) + + def _square_setup_context(ctx, inputs, output): + (x,) = inputs + ctx.save_for_backward(x) + + def _square_backward(ctx, grad_output): + (x,) = ctx.saved_tensors + return grad_output * 2 * x + + @magi_register_custom_op( + name="test::square_op", + mutates_args=(), + infer_output_meta_fn=_square_infer_output_meta, + setup_context_fn=_square_setup_context, + backward_fn=_square_backward, + ) + def _square_op(x: torch.Tensor) -> torch.Tensor: + return x * x + + x = torch.randn(4, 8, requires_grad=True) + output = _square_op(x) + loss = output.sum() + loss.backward() + + # Gradient of x^2 is 2x + expected_grad = 2 * x + assert_close(x.grad, expected_grad) + + def test_autograd_multiple_inputs(self): + """Test autograd with multiple input tensors.""" + + def _weighted_sum_infer_output_meta( + a: torch.Tensor, b: torch.Tensor, weight: float + ) -> torch.Tensor: + return torch.empty_like(a) + + def _weighted_sum_setup_context(ctx, inputs, output): + a, b, weight = inputs + ctx.save_for_backward(a, b) + ctx.weight = weight + + def _weighted_sum_backward(ctx, grad_output): + a, b = ctx.saved_tensors + weight = ctx.weight + grad_a = grad_output * weight + grad_b = grad_output * (1 - weight) + return grad_a, grad_b, None # None for non-tensor input + + @magi_register_custom_op( + name="test::weighted_sum_op", + mutates_args=(), + infer_output_meta_fn=_weighted_sum_infer_output_meta, + setup_context_fn=_weighted_sum_setup_context, + 
backward_fn=_weighted_sum_backward, + ) + def _weighted_sum_op( + a: torch.Tensor, b: torch.Tensor, weight: float + ) -> torch.Tensor: + return a * weight + b * (1 - weight) + + a = torch.randn(4, 8, requires_grad=True) + b = torch.randn(4, 8, requires_grad=True) + weight = 0.7 + + output = _weighted_sum_op(a, b, weight) + loss = output.sum() + loss.backward() + + expected_grad_a = torch.ones_like(a) * weight + expected_grad_b = torch.ones_like(b) * (1 - weight) + + assert_close(a.grad, expected_grad_a) + assert_close(b.grad, expected_grad_b) + + def test_autograd_multiple_outputs(self): + """Test autograd with multiple output tensors.""" + + def _split_scale_infer_output_meta( + x: torch.Tensor, scale: float + ) -> tuple[torch.Tensor, torch.Tensor]: + half = x.shape[-1] // 2 + return ( + x.new_empty((*x.shape[:-1], half)), + x.new_empty((*x.shape[:-1], half)), + ) + + def _split_scale_setup_context(ctx, inputs, output): + x, scale = inputs + ctx.save_for_backward(x) + ctx.scale = scale + ctx.half = x.shape[-1] // 2 + + def _split_scale_backward(ctx, grad_out1, grad_out2): + (x,) = ctx.saved_tensors + scale = ctx.scale + # Reconstruct gradient for x + grad_x = torch.cat([grad_out1 * scale, grad_out2 * scale], dim=-1) + return grad_x, None + + @magi_register_custom_op( + name="test::split_scale_op", + mutates_args=(), + infer_output_meta_fn=_split_scale_infer_output_meta, + setup_context_fn=_split_scale_setup_context, + backward_fn=_split_scale_backward, + ) + def _split_scale_op( + x: torch.Tensor, scale: float + ) -> tuple[torch.Tensor, torch.Tensor]: + half = x.shape[-1] // 2 + return x[..., :half] * scale, x[..., half:] * scale + + x = torch.randn(4, 8, requires_grad=True) + scale = 2.0 + + out1, out2 = _split_scale_op(x, scale) + loss = out1.sum() + out2.sum() + loss.backward() + + expected_grad = torch.ones_like(x) * scale + assert_close(x.grad, expected_grad) + + +class TestDataclassInputWithBackward: + """``backward_fn`` is bridged so users can write it in 
the ORIGINAL + (dataclass-bearing) signature even though the registered op underneath + is flat. These tests pin every shape of bridging we promise. + """ + + def test_backward_with_tensor_and_dataclass_input(self): + """``backward`` returns ``(grad_x, None)`` against the original sig.""" + from dataclasses import dataclass + + @dataclass(frozen=True) + class _ScaleCfg: + scale: float + + def _setup(ctx, inputs, output): + x, cfg = inputs # NOTE: original-signature view, not flat + assert isinstance(cfg, _ScaleCfg) + ctx.save_for_backward(x) + ctx.scale = cfg.scale + + def _bwd(ctx, grad_out): + (_x,) = ctx.saved_tensors + return grad_out * ctx.scale, None + + @magi_register_custom_op( + name="test::dc_bwd_basic", + mutates_args=(), + setup_context_fn=_setup, + backward_fn=_bwd, + ) + def _op(x: torch.Tensor, cfg: _ScaleCfg) -> torch.Tensor: + return x * cfg.scale + + x = torch.randn(4, 8, requires_grad=True) + cfg = _ScaleCfg(scale=3.0) + y = _op(x, cfg) + assert_close(y, x * 3.0) + (gx,) = torch.autograd.grad(y.sum(), x) + assert_close(gx, torch.full_like(x, 3.0)) + + def test_backward_returning_bare_grad_is_allowed(self): + """Single-input convenience: returning a bare tensor instead of a + 1-tuple should still work, mirroring stock PyTorch behaviour.""" + from dataclasses import dataclass + + @dataclass(frozen=True) + class _Cfg: + alpha: float + + def _setup(ctx, inputs, output): + (cfg,) = inputs + ctx.alpha = cfg.alpha + + def _bwd(ctx, grad_out): + # Op has zero tensor inputs; nothing to backprop into. Returning + # ``None`` (or even a bare value) for the lone non-tensor input + # must be accepted by the bridge. 
+ return None + + @magi_register_custom_op( + name="test::dc_bwd_bare_grad", + mutates_args=(), + setup_context_fn=_setup, + backward_fn=_bwd, + ) + def _op(cfg: _Cfg) -> torch.Tensor: + return torch.full((2, 3), cfg.alpha) + + out = _op(_Cfg(alpha=1.5)) + assert_close(out, torch.full((2, 3), 1.5)) + + def test_backward_with_per_field_grads(self): + """If the user wants to express grads at field granularity, returning + a same-shape dataclass (or a dict) for the dataclass slot must be + spread to the underlying flat slots. The bridge accepts ``None`` + leaves to stand in for non-differentiable fields.""" + from dataclasses import dataclass + + @dataclass(frozen=True) + class _Cfg: + scale: float + offset: float + + def _setup(ctx, inputs, output): + x, cfg = inputs + ctx.save_for_backward(x) + ctx.scale = cfg.scale + + def _bwd(ctx, grad_out): + # Express the dataclass grad as a same-shape dataclass with + # ``None`` leaves; the bridge must spread these to flat slots. + return grad_out * ctx.scale, _Cfg(scale=None, offset=None) + + @magi_register_custom_op( + name="test::dc_bwd_per_field", + mutates_args=(), + setup_context_fn=_setup, + backward_fn=_bwd, + ) + def _op(x: torch.Tensor, cfg: _Cfg) -> torch.Tensor: + return x * cfg.scale + cfg.offset + + x = torch.randn(3, requires_grad=True) + y = _op(x, _Cfg(scale=2.0, offset=0.5)) + (gx,) = torch.autograd.grad(y.sum(), x) + assert_close(gx, torch.full_like(x, 2.0)) + + def test_backward_with_dict_shaped_dc_grad(self): + """The bridge should also accept a plain ``dict`` for the dataclass + slot (handy when the user does not want to construct another instance).""" + from dataclasses import dataclass + + @dataclass(frozen=True) + class _Cfg: + scale: float + + def _setup(ctx, inputs, output): + x, cfg = inputs + ctx.save_for_backward(x) + ctx.scale = cfg.scale + + def _bwd(ctx, grad_out): + return grad_out * ctx.scale, {"scale": None} + + @magi_register_custom_op( + name="test::dc_bwd_dict_grad", + mutates_args=(), + 
setup_context_fn=_setup, + backward_fn=_bwd, + ) + def _op(x: torch.Tensor, cfg: _Cfg) -> torch.Tensor: + return x * cfg.scale + + x = torch.randn(5, requires_grad=True) + y = _op(x, _Cfg(scale=4.0)) + (gx,) = torch.autograd.grad(y.sum(), x) + assert_close(gx, torch.full_like(x, 4.0)) + + def test_backward_with_nested_dataclass(self): + """Backward bridging must descend through nested dataclasses (the + flat slot count for a nested dc with ``None`` grad equals the total + number of leaf fields).""" + from dataclasses import dataclass + + @dataclass(frozen=True) + class _Inner: + scale: float + bias: float + + @dataclass(frozen=True) + class _Outer: + inner: _Inner + tag: float + + def _setup(ctx, inputs, output): + x, cfg = inputs + assert isinstance(cfg, _Outer) and isinstance(cfg.inner, _Inner) + ctx.save_for_backward(x) + ctx.scale = cfg.inner.scale + + def _bwd(ctx, grad_out): + return grad_out * ctx.scale, None # whole nested dc => None + + @magi_register_custom_op( + name="test::dc_bwd_nested", + mutates_args=(), + setup_context_fn=_setup, + backward_fn=_bwd, + ) + def _op(x: torch.Tensor, cfg: _Outer) -> torch.Tensor: + return x * cfg.inner.scale + cfg.inner.bias + cfg.tag + + x = torch.randn(2, 4, requires_grad=True) + cfg = _Outer(inner=_Inner(scale=1.5, bias=0.25), tag=0.0) + y = _op(x, cfg) + (gx,) = torch.autograd.grad(y.sum(), x) + assert_close(gx, torch.full_like(x, 1.5)) + + def test_backward_two_tensor_inputs_around_dataclass(self): + """Two tensor inputs sandwiching a dataclass: both gradients must + land in the right flat slots.""" + from dataclasses import dataclass + + @dataclass(frozen=True) + class _Cfg: + alpha: float + beta: float + + def _setup(ctx, inputs, output): + a, cfg, b = inputs + ctx.save_for_backward(a, b) + ctx.alpha = cfg.alpha + ctx.beta = cfg.beta + + def _bwd(ctx, grad_out): + a, b = ctx.saved_tensors + return grad_out * ctx.alpha, None, grad_out * ctx.beta + + @magi_register_custom_op( + name="test::dc_bwd_sandwich", + 
mutates_args=(), + setup_context_fn=_setup, + backward_fn=_bwd, + ) + def _op(a: torch.Tensor, cfg: _Cfg, b: torch.Tensor) -> torch.Tensor: + return a * cfg.alpha + b * cfg.beta + + a = torch.randn(4, requires_grad=True) + b = torch.randn(4, requires_grad=True) + cfg = _Cfg(alpha=2.0, beta=-3.0) + y = _op(a, cfg, b) + ga, gb = torch.autograd.grad(y.sum(), [a, b]) + assert_close(ga, torch.full_like(a, 2.0)) + assert_close(gb, torch.full_like(b, -3.0)) + + def test_backward_wrong_grad_count_raises(self): + """If ``backward_fn`` returns the wrong number of grads (relative to + the ORIGINAL signature), the user gets a clear error rather than an + opaque autograd shape mismatch.""" + from dataclasses import dataclass + + @dataclass(frozen=True) + class _Cfg: + scale: float + + def _setup(ctx, inputs, output): + x, cfg = inputs + ctx.save_for_backward(x) + + def _bad_bwd(ctx, grad_out): + return (grad_out,) # missing the dataclass slot + + @magi_register_custom_op( + name="test::dc_bwd_wrong_count", + mutates_args=(), + setup_context_fn=_setup, + backward_fn=_bad_bwd, + ) + def _op(x: torch.Tensor, cfg: _Cfg) -> torch.Tensor: + return x * cfg.scale + + x = torch.randn(3, requires_grad=True) + y = _op(x, _Cfg(scale=2.0)) + with pytest.raises(ValueError, match=r"backward_fn returned 1 grad"): + torch.autograd.grad(y.sum(), x) + + +@dataclasses.dataclass(frozen=True) +class _BwTupleCfg: + s: float + + +class TestTupleOutputBackward: + """``backward_fn`` with ``tuple`` outputs and a dataclass input.""" + + def test_two_output_backward_with_dataclass(self): + def setup(ctx, inputs, output): + x, _cfg = inputs + ctx.save_for_backward(x) + ctx.s = _cfg.s + + def bwd(ctx, gy0, gy1): + (x,) = ctx.saved_tensors + # d(out0)/dx = s, d(out1)/dx = 1 + return gy0 * ctx.s + gy1, None + + @magi_register_custom_op( + name="test::tuple_bwd_dc", + setup_context_fn=setup, + backward_fn=bwd, + ) + def _op(x: torch.Tensor, cfg: _BwTupleCfg) -> tuple[torch.Tensor, torch.Tensor]: + return x * 
cfg.s, x.clone() + + x = torch.randn(4, requires_grad=True) + y0, y1 = _op(x, _BwTupleCfg(s=3.0)) + (y0.sum() + 2.0 * y1.sum()).backward() + assert x.grad is not None + # gy0 = 1, gy1 = 2 -> grad = 1*3 + 2 = 5 + assert torch.allclose(x.grad, torch.full_like(x, 5.0)) + + +class TestPerFieldNoneGrad: + """The autograd bridge supports returning a partially-None grad for a + dataclass input (i.e. some Tensor fields differentiable, others not). + """ + + def test_partial_none_grad_for_dataclass(self): + @dataclasses.dataclass(frozen=True) + class _Cfg: + w: torch.Tensor + b: torch.Tensor # we'll mark this non-differentiable + + def setup(ctx, inputs, output): + x, cfg = inputs + ctx.save_for_backward(x, cfg.w) + + def bwd(ctx, gy): + x, w = ctx.saved_tensors + # x: gx = gy * w; cfg.w: gw = gy * x; cfg.b: None. + return gy * w, {"w": gy * x, "b": None} + + @magi_register_custom_op( + name="test::partial_none_grad", + setup_context_fn=setup, + backward_fn=bwd, + ) + def _op(x: torch.Tensor, cfg: _Cfg) -> torch.Tensor: + return x * cfg.w + cfg.b + + x = torch.randn(4, requires_grad=True) + w = torch.randn(4, requires_grad=True) + b = torch.randn(4, requires_grad=False) + cfg = _Cfg(w=w, b=b) + out = _op(x, cfg) + out.sum().backward() + assert torch.allclose(x.grad, w) + assert torch.allclose(w.grad, x) + # b did not require grad; nothing to assert beyond "no exception". + + +class TestBackwardCallsAnotherOp: + """``backward_fn`` is allowed to call other registered ops to compute the + gradient. This is the FlashAttention-style "forward op + backward op" + decomposition. Verify the gradient is correct end-to-end. 
+ """ + + def test_backward_dispatches_to_another_op(self): + @magi_register_custom_op(name="test::matmul_grad_helper") + def matmul_grad_x(gy: torch.Tensor, w: torch.Tensor) -> torch.Tensor: + return gy @ w.t() + + def setup(ctx, inputs, output): + x, w = inputs + ctx.save_for_backward(x, w) + + def bwd(ctx, gy): + x, w = ctx.saved_tensors + return matmul_grad_x(gy, w), x.t() @ gy + + @magi_register_custom_op( + name="test::matmul_with_op_bwd", + setup_context_fn=setup, + backward_fn=bwd, + ) + def matmul(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor: + return x @ w + + x = torch.randn(3, 4, requires_grad=True) + w = torch.randn(4, 5, requires_grad=True) + out = matmul(x, w) + out.sum().backward() + gy = torch.ones_like(out) + assert torch.allclose(x.grad, gy @ w.t()) + assert torch.allclose(w.grad, x.t() @ gy) + + +# ============================================================================ +# SECTION 4: Compositions & Python Semantics +# ============================================================================ + + +@dataclasses.dataclass(frozen=True) +class _KwCfg: + s: float + + +class TestKwargOnlyDataclass: + """``cfg`` passed as a kwarg-only argument should still be flattened and + routed correctly through the dataclass-aware bridge. 
+ """ + + def test_kwarg_only_dataclass_call(self): + @magi_register_custom_op(name="test::kwonly_dc") + def _op(x: torch.Tensor, *, cfg: _KwCfg) -> torch.Tensor: + return x * cfg.s + + x = torch.randn(2, 3) + out = _op(x, cfg=_KwCfg(s=4.0)) + assert_close(out, x * 4.0) + + def test_kwarg_only_dataclass_with_backward(self): + def setup(ctx, inputs, output): + x, cfg = inputs + ctx.s = cfg.s + + def bwd(ctx, gy): + return gy * ctx.s, None + + @magi_register_custom_op( + name="test::kwonly_dc_bwd", + setup_context_fn=setup, + backward_fn=bwd, + ) + def _op(x: torch.Tensor, *, cfg: _KwCfg) -> torch.Tensor: + return x * cfg.s + + x = torch.randn(3, requires_grad=True) + out = _op(x, cfg=_KwCfg(s=2.5)) + out.sum().backward() + assert torch.allclose(x.grad, torch.full_like(x, 2.5)) + + +@dataclasses.dataclass(frozen=True) +class _MutCfg: + a: torch.Tensor + b: torch.Tensor + + +class TestMutatesArgsOnDataclass: + """``mutates_args=('cfg',)`` on a dataclass parameter should expand to the + flat leaf names automatically; an unknown name should raise. + """ + + def test_mutates_args_dataclass_expands(self): + # We don't actually mutate (frozen dataclass) -- we just check the + # registration succeeds, which it would not if the name failed to + # expand to flat tensor leaves. 
+ @magi_register_custom_op( + name="test::mutates_dc_expand", + mutates_args=("cfg",), + ) + def _op(x: torch.Tensor, cfg: _MutCfg) -> torch.Tensor: + cfg.a.add_(x) # in-place on a tensor field + cfg.b.add_(x) + return cfg.a + cfg.b + + x = torch.ones(3) + cfg = _MutCfg(a=torch.zeros(3), b=torch.zeros(3)) + out = _op(x, cfg) + assert_close(out, torch.full((3,), 2.0)) + + def test_mutates_args_unknown_name_rejected(self): + with pytest.raises(ValueError, match="does not match"): + + @magi_register_custom_op( + name="test::mutates_dc_unknown", + mutates_args=("does_not_exist",), + ) + def _op(x: torch.Tensor, cfg: _MutCfg) -> torch.Tensor: + return x + + def test_mutates_args_flat_name_passthrough(self): + """Users may also use the flat name (``cfg__a``) directly.""" + + @magi_register_custom_op( + name="test::mutates_dc_flat", + mutates_args=("cfg__a",), + ) + def _op(x: torch.Tensor, cfg: _MutCfg) -> torch.Tensor: + cfg.a.add_(x) + return cfg.a + cfg.b + + x = torch.ones(3) + cfg = _MutCfg(a=torch.zeros(3), b=torch.zeros(3)) + out = _op(x, cfg) + assert_close(out, torch.ones(3)) + + +class TestNestedMagiOpCall: + """Calling one ``@magi_register_custom_op``-registered op from inside + another is the bread-and-butter compositional pattern (e.g. a fused + attention op that internally calls a registered softmax). Verify both + plain-input and dataclass-input compositions. 
+ """ + + def test_plain_op_inside_plain_op(self): + @magi_register_custom_op(name="test::nested_inner_plain") + def inner(x: torch.Tensor) -> torch.Tensor: + return x * 2 + + @magi_register_custom_op(name="test::nested_outer_plain") + def outer(x: torch.Tensor) -> torch.Tensor: + return inner(x) + 1 + + x = torch.randn(4) + assert_close(outer(x), x * 2 + 1) + + def test_dataclass_op_inside_plain_op(self): + @dataclasses.dataclass(frozen=True) + class _Cfg: + s: float + + @magi_register_custom_op(name="test::nested_inner_dc") + def inner(x: torch.Tensor, cfg: _Cfg) -> torch.Tensor: + return x * cfg.s + + @magi_register_custom_op(name="test::nested_outer_calls_dc") + def outer(x: torch.Tensor) -> torch.Tensor: + return inner(x, _Cfg(s=3.0)) + 1 + + x = torch.randn(4) + assert_close(outer(x), x * 3.0 + 1) + + if __name__ == "__main__": pytest.main([__file__, "-v"]) diff --git a/tests/api_tests/test_register_triton_op.py b/tests/api_tests/test_register_triton_op.py new file mode 100644 index 0000000..af3db6d --- /dev/null +++ b/tests/api_tests/test_register_triton_op.py @@ -0,0 +1,1482 @@ +# Copyright (c) 2026 SandAI. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from __future__ import annotations + +""" +This test suite covers the ``triton_op`` auto-detection and registration path. 
+ +Coverage Matrix & Sections: +--------------------------- +SECTION 1: Direct Kernel Launch Patterns + - Flat direct kernel call (``kernel[grid](...)``) + - Multiple kernels in sequence + - Kernel launched inside a closure + - Multilevel nesting & Helper functions launching kernels + +SECTION 2: Wrapped, Dynamic & Exotic Retrievals + - Helper launchers (local, cross-module, 3rd party wrappers) + - ``wrap_triton`` idempotency (Mixing wrapped and bare kernels safely) + - Explicit ``extra_triton_kernels`` override & deduplication + - Staticmethod / Classmethod kernels + - Dynamically fetched / runtime-imported kernels + +SECTION 3: Autotune, Heuristics & Autograd in Triton + - ``@triton.autotune`` kernels (single & multiple configs) + - ``@triton.heuristics`` rejection & graceful fallback + - Autograd combined with Triton kernels + +SECTION 4: Dataclass & End-to-End Tracing + - Dataclass inputs navigating through to Triton kernels + - Nested dataclass & backward combination + - Pure Inductor see-through proof (AOT graph verification) +""" + +import pytest +import torch +from torch.testing import assert_close + +triton = pytest.importorskip("triton") +tl = pytest.importorskip("triton.language") + +from magi_compiler.api import magi_register_custom_op # noqa: E402 + +pytestmark = pytest.mark.skipif( + not torch.cuda.is_available(), reason="triton kernels require CUDA" +) + + +# --------------------------------------------------------------------------- +# Module-level kernels (so they live in fn.__globals__ for several scenarios) +# --------------------------------------------------------------------------- + + +@triton.jit +def _cos_kernel(in_ptr0, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr): + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + x = tl.load(in_ptr0 + offsets, mask=mask) + output = tl.cos(x) + tl.store(out_ptr + offsets, output, mask=mask) + + +@triton.jit 
+def _scale_kernel(in_ptr0, out_ptr, n_elements, scale, BLOCK_SIZE: tl.constexpr): + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + x = tl.load(in_ptr0 + offsets, mask=mask) + output = x * scale + tl.store(out_ptr + offsets, output, mask=mask) + + +@triton.jit +def _add_kernel(a_ptr, b_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr): + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + a = tl.load(a_ptr + offsets, mask=mask) + b = tl.load(b_ptr + offsets, mask=mask) + tl.store(out_ptr + offsets, a + b, mask=mask) + + +@triton.autotune( + configs=[ + triton.Config({"BLOCK_SIZE": 128}, num_warps=4), + triton.Config({"BLOCK_SIZE": 256}, num_warps=4), + ], + key=["n_elements"], +) +@triton.jit +def _autotuned_cos_kernel(in_ptr0, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr): + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + x = tl.load(in_ptr0 + offsets, mask=mask) + tl.store(out_ptr + offsets, tl.cos(x), mask=mask) + + +def _grid_1d(n: int): + return ((n + 127) // 128,) + + +# Module-level frozen dataclass used by the dataclass+triton test below. We +# declare it at module scope (not inside the test method) so that +# ``typing.get_type_hints`` / ``eval`` on the function's stringified +# annotations (PEP 563 / ``from __future__ import annotations``) can find it +# via ``fn.__globals__``. +from dataclasses import dataclass as _dc_dataclass # noqa: E402 + + +@_dc_dataclass(frozen=True) +class _DcCosCfg: + block_size: int + + +# Nested dataclass fixtures for the nested-dataclass + triton tests below. +# Same module-scope rationale as ``_DcCosCfg``. 
+ + + @_dc_dataclass(frozen=True) + class _DcKernelCfg: + block_size: int + extra_offset: float + + + @_dc_dataclass(frozen=True) + class _DcOuterCfg: + kernel: _DcKernelCfg + scale: float + + + @_dc_dataclass(frozen=True) + class _DcShapeCfg: + out_dim: int + + + @_dc_dataclass(frozen=True) + class _DcProjCfg: + shape: _DcShapeCfg + block_size: int + + + # Launcher helpers: each kernel launch is defined at module scope, but in its own helper that fn calls. + + + def _scale_launcher(x: torch.Tensor, factor: float) -> torch.Tensor: + out = torch.empty_like(x) + n = x.numel() + _scale_kernel[_grid_1d(n)](x, out, n, factor, BLOCK_SIZE=128) + return out + + + def _add_launcher(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: + out = torch.empty_like(a) + n = a.numel() + _add_kernel[_grid_1d(n)](a, b, out, n, BLOCK_SIZE=128) + return out + + + def _make_cos_kernel(): + @triton.jit + def _kernel(in_ptr, out_ptr, n_elements, BLOCK: tl.constexpr): + pid = tl.program_id(axis=0) + block_start = pid * BLOCK + offsets = block_start + tl.arange(0, BLOCK) + mask = offsets < n_elements + x = tl.load(in_ptr + offsets, mask=mask) + tl.store(out_ptr + offsets, tl.cos(x), mask=mask) + + return _kernel + + + def _inner_launcher(x: torch.Tensor) -> torch.Tensor: + out = torch.empty_like(x) + n = x.numel() + _cos_kernel[_grid_1d(n)](x, out, n, BLOCK_SIZE=128) + return out + + + def _dispatch_launcher(x: torch.Tensor) -> torch.Tensor: + return _inner_launcher(x) + + + @triton.heuristics({"BLOCK_SIZE": lambda args: 128}) + @triton.jit + def _heuristics_top_kernel(in_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr): + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + x = tl.load(in_ptr + offsets, mask=mask) + tl.store(out_ptr + offsets, x, mask=mask) + + + @triton.autotune( + configs=[triton.Config({}, num_warps=4)], + key=["n_elements"], + ) + @triton.heuristics({"BLOCK_SIZE": lambda args: 128}) + @triton.jit + def 
_autotune_then_heuristics_kernel( + in_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr +): + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + x = tl.load(in_ptr + offsets, mask=mask) + tl.store(out_ptr + offsets, x, mask=mask) + + +class _KernelHolder: + """Holder with kernel exposed via classmethod / staticmethod. The + introspector cannot statically follow ``Holder.get()`` to a kernel at + decoration time; users must use ``extra_triton_kernels=`` instead. + """ + + @staticmethod + def get_static(): + return _scale_kernel + + @classmethod + def get_class(cls): + return _scale_kernel + + +# ============================================================================ +# SECTION 1: Direct Kernel Launch Patterns +# ============================================================================ + + +class TestFlatDirectKernel: + def test_basic_cos(self): + @magi_register_custom_op(name="magi_test::flat_cos") + def mycos(x: torch.Tensor) -> torch.Tensor: + out = torch.empty_like(x) + n = x.numel() + _cos_kernel[_grid_1d(n)](x, out, n, BLOCK_SIZE=128) + return out + + x = torch.randn(1024, device="cuda") + out = mycos(x) + assert_close(out, torch.cos(x), atol=1e-5, rtol=1e-5) + + def test_op_is_triton_op(self): + """Sanity: the registered op should be a triton_op-style CustomOpDef + and torch.compile should be able to see through it.""" + + @magi_register_custom_op(name="magi_test::seethrough_cos") + def mycos(x: torch.Tensor) -> torch.Tensor: + out = torch.empty_like(x) + n = x.numel() + _cos_kernel[_grid_1d(n)](x, out, n, BLOCK_SIZE=128) + return out + + compiled = torch.compile(mycos, backend="inductor", fullgraph=True) + x = torch.randn(2048, device="cuda") + out = compiled(x) + assert_close(out, torch.cos(x), atol=1e-5, rtol=1e-5) + + +class TestMultiKernelSequence: + def test_chain(self): + @magi_register_custom_op(name="magi_test::cos_then_scale") + def fn(x: 
torch.Tensor, scale: float) -> torch.Tensor: + tmp = torch.empty_like(x) + out = torch.empty_like(x) + n = x.numel() + _cos_kernel[_grid_1d(n)](x, tmp, n, BLOCK_SIZE=128) + _scale_kernel[_grid_1d(n)](tmp, out, n, scale, BLOCK_SIZE=128) + return out + + x = torch.randn(1024, device="cuda") + out = fn(x, 2.5) + assert_close(out, torch.cos(x) * 2.5, atol=1e-5, rtol=1e-5) + + +class TestKernelInsideClosure: + def test_closure_kernel(self): + def make_op(kernel): + @magi_register_custom_op(name=f"magi_test::closure_{id(kernel)}") + def op(x: torch.Tensor) -> torch.Tensor: + out = torch.empty_like(x) + n = x.numel() + kernel[_grid_1d(n)](x, out, n, BLOCK=128) + return out + + return op + + kernel = _make_cos_kernel() + op = make_op(kernel) + x = torch.randn(2048, device="cuda") + assert_close(op(x), torch.cos(x), atol=1e-5, rtol=1e-5) + + +# extra_triton_kernels escape hatch (scenario 9-style: kernel hidden behind +# an attribute access the introspector cannot trace). + + +class TestMultiLevelNesting: + def test_fn_to_dispatch_to_launcher_to_kernel(self): + @magi_register_custom_op(name="magi_test::multi_level_cos") + def fn(x: torch.Tensor) -> torch.Tensor: + return _dispatch_launcher(x) + + x = torch.randn(2048, device="cuda") + assert_close(fn(x), torch.cos(x), atol=1e-5, rtol=1e-5) + + def test_introspection_walks_all_levels(self): + from torch._higher_order_ops.triton_kernel_wrap import ( + TraceableTritonKernelWrapper, + ) + + from magi_compiler._triton_introspect import ( + get_inner_triton_kernels, + rewrite_fn_with_wrap_triton, + ) + + def fn(x): + return _dispatch_launcher(x) + + kernels = get_inner_triton_kernels(fn) + assert _cos_kernel in kernels + + rewritten = rewrite_fn_with_wrap_triton(fn, kernels) + rebuilt_dispatch = rewritten.__globals__["_dispatch_launcher"] + rebuilt_inner = rebuilt_dispatch.__globals__["_inner_launcher"] + assert isinstance( + rebuilt_inner.__globals__["_cos_kernel"], TraceableTritonKernelWrapper + ) + + +# Third-party "thin 
wrapper" pattern: some libraries return objects with a +# ``.fn`` attribute pointing at the underlying triton kernel; the introspector +# already knows how to unwrap that, so kernels invoked via +# ``maybe_capture(kernel)[grid](...)`` should still register as a triton_op. + + +class TestFactoryInsideFn: + def test_factory_inside_fn_runtime(self): + @magi_register_custom_op(name="magi_test::factory_inside_fn") + def fn(x: torch.Tensor) -> torch.Tensor: + kernel = _make_cos_kernel() + out = torch.empty_like(x) + n = x.numel() + kernel[_grid_1d(n)](x, out, n, BLOCK=128) + return out + + x = torch.randn(1024, device="cuda") + assert_close(fn(x), torch.cos(x), atol=1e-5, rtol=1e-5) + + +# True cross-module launcher: helpers and kernels live in +# ``tests/api_tests/_triton_external_helpers.py``. The decorated function +# imports them, so ``_rebuild`` has to descend into a helper whose +# ``__globals__`` is a *different* module dict than ``fn.__globals__``. + + +class TestNnModuleSelfKernel: + def test_kernel_on_self(self): + from torch import nn + + class CosModule(nn.Module): + def __init__(self, kernel): + super().__init__() + self._kernel = kernel + self.fn = self._build_fn() + + def _build_fn(self): + kernel = self._kernel + + @magi_register_custom_op( + name=f"magi_test::module_self_kernel_{id(self)}", + extra_triton_kernels=[kernel], + ) + def op(x: torch.Tensor) -> torch.Tensor: + out = torch.empty_like(x) + n = x.numel() + kernel[_grid_1d(n)](x, out, n, BLOCK_SIZE=128) + return out + + return op + + def forward(self, x): + return self.fn(x) + + mod = CosModule(_cos_kernel).to("cuda") + x = torch.randn(1024, device="cuda") + assert_close(mod(x), torch.cos(x), atol=1e-5, rtol=1e-5) + + +# Factory created *inside* fn (kernel is a local variable, not a closure +# captured from outside). The introspector detects the bare ``kernel[grid]`` +# call but the actual kernel object lives only in the runtime locals, so +# rewrite has nothing to shadow. 
This must still execute correctly because +# ``wrap_triton`` is optional for runtime correctness (only required for +# torch.compile traceability). + + +# ============================================================================ +# SECTION 2: Wrapped, Dynamic & Exotic Retrievals +# ============================================================================ + + +class TestHelperLauncher: + def test_helper_launcher(self): + @magi_register_custom_op(name="magi_test::add_via_launcher") + def add_op(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: + return _add_launcher(a, b) + + a = torch.randn(2048, device="cuda") + b = torch.randn(2048, device="cuda") + assert_close(add_op(a, b), a + b, atol=1e-5, rtol=1e-5) + + +class TestCrossModuleLauncher: + def test_scale_via_external_launcher(self): + @magi_register_custom_op(name="magi_test::scale_via_external") + def scale_op(x: torch.Tensor, factor: float) -> torch.Tensor: + return _scale_launcher(x, factor) + + x = torch.randn(2048, device="cuda") + assert_close(scale_op(x, 0.25), x * 0.25, atol=1e-5, rtol=1e-5) + + +class TestThirdPartyThinWrapper: + def test_thin_wrapper_kernel(self): + from tests.api_tests._triton_external_helpers import maybe_capture + + @magi_register_custom_op( + name="magi_test::cos_via_thin_wrapper", + # Even though the introspector handles ``.fn``-style wrappers, we + # also pass the raw kernel as ``extra_triton_kernels`` to confirm + # the deduplication path works with this style of call. + extra_triton_kernels=[_cos_kernel], + ) + def fn(x: torch.Tensor) -> torch.Tensor: + out = torch.empty_like(x) + n = x.numel() + wrapped = maybe_capture(_cos_kernel) + wrapped[_grid_1d(n)](x, out, n, BLOCK_SIZE=128) + return out + + x = torch.randn(1024, device="cuda") + assert_close(fn(x), torch.cos(x), atol=1e-5, rtol=1e-5) + + +# wrap_triton idempotency: if the user already wrote ``wrap_triton(kernel)`` +# explicitly, we must not produce a wrap_triton(wrap_triton(kernel)). 
+ + +class TestTrueCrossModuleLauncher: + def test_external_neg_launcher(self): + from tests.api_tests._triton_external_helpers import ( + external_neg_launcher, + ) + + @magi_register_custom_op(name="magi_test::true_cross_module_neg") + def fn(x: torch.Tensor) -> torch.Tensor: + return external_neg_launcher(x) + + x = torch.randn(2048, device="cuda") + assert_close(fn(x), -x, atol=1e-5, rtol=1e-5) + + def test_rewrite_descends_into_other_module(self): + from torch._higher_order_ops.triton_kernel_wrap import ( + TraceableTritonKernelWrapper, + ) + + from magi_compiler._triton_introspect import ( + get_inner_triton_kernels, + rewrite_fn_with_wrap_triton, + ) + from tests.api_tests._triton_external_helpers import ( + external_double_kernel, + external_double_launcher, + ) + + def fn(x): + # Bare Name call so the introspector can follow it across modules + # via ``called_functions``. + return external_double_launcher(x) + + kernels = get_inner_triton_kernels(fn) + assert external_double_kernel in kernels + + rewritten = rewrite_fn_with_wrap_triton(fn, kernels) + + # ``external_double_launcher`` was captured from the enclosing test + # method's locals, so it lives in ``fn``'s closure (NOT in + # ``__globals__``). The rewrite pass must still descend into it and + # produce a rebuilt copy whose globals reference the wrap_triton- + # aware kernel. 
+ rebuilt_launcher = None + for cell in rewritten.__closure__ or (): + try: + contents = cell.cell_contents + except ValueError: + continue + if callable(contents) and getattr(contents, "__name__", None) == ( + "external_double_launcher" + ): + rebuilt_launcher = contents + break + assert rebuilt_launcher is not None, ( + "expected rewrite_fn_with_wrap_triton to keep the launcher in " + "the rewritten function's closure" + ) + assert isinstance( + rebuilt_launcher.__globals__["external_double_kernel"], + TraceableTritonKernelWrapper, + ), ( + "rewrite_fn_with_wrap_triton should rebuild cross-module helpers " + "so the kernel reference inside them is wrap_triton-aware." + ) + + # The ORIGINAL helper module's globals must NOT be mutated; only the + # rebuilt copy carries the wrapper. + from tests.api_tests import _triton_external_helpers as ext_mod + + assert not isinstance( + ext_mod.external_double_launcher.__globals__["external_double_kernel"], + TraceableTritonKernelWrapper, + ), ( + "rewrite_fn_with_wrap_triton must not mutate the helper's home " + "module globals (other unrelated callers would be affected)." + ) + + +# Nested dataclass + triton kernel: dataclass-of-dataclass arguments must be +# fully flattened so that ``infer_schema`` only sees primitive types, while +# the triton_op path still sees through the kernel. + + +class TestMixedWrappedAndBareKernels: + """When the user has manually wrapped some kernels with ``wrap_triton`` + but left others bare (a common state during incremental migration), the + decorator must wrap only the bare ones (no double-wrap) and the op must + still run. 
+ """ + + def test_mixed_wrapped_and_bare(self): + from torch.library import wrap_triton + + @magi_register_custom_op(name="magi_test::mixed_wrap_state") + def myop(x: torch.Tensor) -> torch.Tensor: + n = x.numel() + mid = torch.empty_like(x) + wrap_triton(_cos_kernel)[_grid_1d(n)](x, mid, n, BLOCK_SIZE=128) + out = torch.empty_like(x) + _scale_kernel[_grid_1d(n)](mid, out, n, 2.0, BLOCK_SIZE=128) + return out + + x = torch.randn(512, device="cuda") + out = myop(x) + assert_close(out, torch.cos(x) * 2.0, atol=1e-5, rtol=1e-5) + + +class TestWrapTritonIdempotent: + def test_user_already_wrapped(self): + from torch.library import wrap_triton + + @magi_register_custom_op(name="magi_test::cos_user_wrapped") + def fn(x: torch.Tensor) -> torch.Tensor: + out = torch.empty_like(x) + n = x.numel() + wrap_triton(_cos_kernel)[_grid_1d(n)](x, out, n, BLOCK_SIZE=128) + return out + + x = torch.randn(1024, device="cuda") + assert_close(fn(x), torch.cos(x), atol=1e-5, rtol=1e-5) + + def test_rewrite_does_not_double_wrap(self): + """Direct unit test: passing the already-wrapped kernel back through + ``rewrite_fn_with_wrap_triton`` must not produce a double wrapper.""" + from torch._higher_order_ops.triton_kernel_wrap import ( + TraceableTritonKernelWrapper, + ) + from torch.library import wrap_triton + + from magi_compiler._triton_introspect import rewrite_fn_with_wrap_triton + + wrapped_kernel = wrap_triton(_cos_kernel) + + def fn(x): + out = torch.empty_like(x) + n = x.numel() + wrapped_kernel[_grid_1d(n)](x, out, n, BLOCK_SIZE=128) + return out + + # Pass the wrapped kernel as the "kernels" argument; the rewrite path + # should pass it through ``_resolve_kernel`` and not re-wrap. + rewritten = rewrite_fn_with_wrap_triton(fn, [wrapped_kernel]) + # The closure cell for ``wrapped_kernel`` (or the rebuilt globals + # entry, depending on closure capture order) must still be a single + # TraceableTritonKernelWrapper, not nested. 
+ seen = [] + if rewritten.__closure__ is not None: + for cell in rewritten.__closure__: + try: + seen.append(cell.cell_contents) + except ValueError: + pass + seen.extend(rewritten.__globals__.values()) + wrappers = [v for v in seen if isinstance(v, TraceableTritonKernelWrapper)] + assert wrappers, "expected at least one wrap_triton wrapper to be present" + for w in wrappers: + inner = getattr(w, "kernel", None) or getattr(w, "fn", None) + assert not isinstance( + inner, TraceableTritonKernelWrapper + ), "rewrite_fn_with_wrap_triton produced a double-wrapped kernel" + + +# infer_output_meta_fn override: both the ``list[str]`` shorthand and the +# explicit ``Callable`` form should be honoured even when we go down the +# triton_op path (because triton_op pre-registers ``fn`` itself as the fake). + + +class TestExtraTritonKernels: + def test_explicit_kernel_list(self): + kernels_holder = type("KH", (), {})() + kernels_holder.k = _cos_kernel + + @magi_register_custom_op( + name="magi_test::cos_via_extra", + extra_triton_kernels=[_cos_kernel], + ) + def fn(x: torch.Tensor) -> torch.Tensor: + out = torch.empty_like(x) + n = x.numel() + kernels_holder.k[_grid_1d(n)](x, out, n, BLOCK_SIZE=128) + return out + + x = torch.randn(1024, device="cuda") + assert_close(fn(x), torch.cos(x), atol=1e-5, rtol=1e-5) + + +# Fallback: no triton kernels => still works (custom_op path). + + +class TestExtraTritonKernelsDedup: + def test_dedup_in_resolve_and_rewrite(self): + from magi_compiler._magi_register_custom_op import _resolve_triton_kernels + from magi_compiler._triton_introspect import rewrite_fn_with_wrap_triton + + def fn(x): + out = torch.empty_like(x) + n = x.numel() + _cos_kernel[_grid_1d(n)](x, out, n, BLOCK_SIZE=128) + return out + + resolved_all, resolved_bare, _user_wrapped_ids = _resolve_triton_kernels( + fn, [_cos_kernel] + ) + # Should appear exactly once even though it's both passed explicitly + # and discovered by introspection. 
+ assert resolved_all.count(_cos_kernel) == 1 + assert len(resolved_all) == 1 + assert resolved_bare.count(_cos_kernel) == 1 + assert len(resolved_bare) == 1 + + rewritten = rewrite_fn_with_wrap_triton(fn, resolved_bare) + from torch._higher_order_ops.triton_kernel_wrap import ( + TraceableTritonKernelWrapper, + ) + + wrapped = rewritten.__globals__["_cos_kernel"] + assert isinstance(wrapped, TraceableTritonKernelWrapper) + + def test_dedup_e2e(self): + @magi_register_custom_op( + name="magi_test::dedup_cos", + extra_triton_kernels=[_cos_kernel], + ) + def fn(x: torch.Tensor) -> torch.Tensor: + out = torch.empty_like(x) + n = x.numel() + _cos_kernel[_grid_1d(n)](x, out, n, BLOCK_SIZE=128) + return out + + x = torch.randn(1024, device="cuda") + assert_close(fn(x), torch.cos(x), atol=1e-5, rtol=1e-5) + + # Confirm we still went down the triton_op path even though the kernel + # was specified twice (auto-detected + extra_triton_kernels). + assert TestTritonOpRegistryAssertion._was_registered_as_triton_op( + "magi_test::dedup_cos" + ), "expected the op to be registered as a triton_op" + + +# Dataclass input + triton kernel inside the body. Verifies the dataclass- +# aware path also routes ``inner_fn`` through ``rewrite_fn_with_wrap_triton`` +# when triton kernels are present. + + +class TestExtraTritonKernelsForStaticOrClassmethod: + """``staticmethod`` / ``classmethod`` selectors are opaque to source + introspection. ``extra_triton_kernels`` keeps the op on the triton_op + path even so. 
+ """ + + def test_staticmethod_selected_kernel(self): + @magi_register_custom_op( + name="magi_test::sm_kernel", + extra_triton_kernels=[_scale_kernel], + ) + def myop(x: torch.Tensor) -> torch.Tensor: + kernel = _KernelHolder.get_static() + out = torch.empty_like(x) + n = x.numel() + kernel[_grid_1d(n)](x, out, n, 2.0, BLOCK_SIZE=128) + return out + + x = torch.randn(256, device="cuda") + out = myop(x) + assert_close(out, x * 2.0) + + def test_classmethod_selected_kernel(self): + @magi_register_custom_op( + name="magi_test::cm_kernel", + extra_triton_kernels=[_scale_kernel], + ) + def myop(x: torch.Tensor) -> torch.Tensor: + kernel = _KernelHolder.get_class() + out = torch.empty_like(x) + n = x.numel() + kernel[_grid_1d(n)](x, out, n, 3.0, BLOCK_SIZE=128) + return out + + x = torch.randn(256, device="cuda") + out = myop(x) + assert_close(out, x * 3.0) + + +class TestExtraTritonKernelsForRuntimeImport: + """A kernel imported inside the function body (runtime import) is invisible + to source introspection. ``extra_triton_kernels`` works around that. + """ + + def test_runtime_imported_kernel(self): + # The kernel object lives at module scope (we can't actually do a fresh + # ``import`` in a way that hides it from source scanning AND lets the + # function still call it). Simulate the runtime-import case by stuffing + # the kernel into a local ``import``-like alias derived from globals, + # so source introspection cannot statically resolve it. + @magi_register_custom_op( + name="magi_test::runtime_import_kernel", + extra_triton_kernels=[_cos_kernel], + ) + def myop(x: torch.Tensor) -> torch.Tensor: + module_globals = globals() + # Indirect lookup hides the kernel from static introspection of + # ``myop``'s globals/closure. 
+ kernel = module_globals["_cos_kernel"] + out = torch.empty_like(x) + n = x.numel() + kernel[_grid_1d(n)](x, out, n, BLOCK_SIZE=128) + return out + + x = torch.randn(256, device="cuda") + out = myop(x) + assert_close(out, torch.cos(x)) + + +# #21: dataclass input + is_compute_sensitive on triton path + + +class TestNoTritonFallback: + def test_no_kernel_uses_custom_op(self): + @magi_register_custom_op(name="magi_test::pure_python_op") + def fn(x: torch.Tensor) -> torch.Tensor: + return x * 2 + 1 + + x = torch.randn(8, 8) + assert_close(fn(x), x * 2 + 1) + + +# Triton path + autograd combination. + + +class TestIntrospection: + def test_get_inner_triton_kernels_flat(self): + from magi_compiler._triton_introspect import get_inner_triton_kernels + + def fn(x): + out = torch.empty_like(x) + n = x.numel() + _cos_kernel[_grid_1d(n)](x, out, n, BLOCK_SIZE=128) + return out + + kernels = get_inner_triton_kernels(fn) + assert _cos_kernel in kernels + + def test_get_inner_triton_kernels_nested(self): + from magi_compiler._triton_introspect import get_inner_triton_kernels + + def fn(a, b): + return _add_launcher(a, b) + + kernels = get_inner_triton_kernels(fn) + assert _add_kernel in kernels + + def test_rewrite_replaces_kernel_with_wrap_triton(self): + from torch._higher_order_ops.triton_kernel_wrap import ( + TraceableTritonKernelWrapper, + ) + + from magi_compiler._triton_introspect import ( + get_inner_triton_kernels, + rewrite_fn_with_wrap_triton, + ) + + def fn(x): + out = torch.empty_like(x) + n = x.numel() + _cos_kernel[_grid_1d(n)](x, out, n, BLOCK_SIZE=128) + return out + + kernels = get_inner_triton_kernels(fn) + rewritten = rewrite_fn_with_wrap_triton(fn, kernels) + + # _cos_kernel name in the rewritten globals should now point to a + # TraceableTritonKernelWrapper, not the bare JITFunction. + assert isinstance( + rewritten.__globals__["_cos_kernel"], TraceableTritonKernelWrapper + ) + # Originals untouched. 
+ from triton.runtime.jit import JITFunction + + assert isinstance(_cos_kernel, JITFunction) + + def test_rewrite_propagates_through_helpers(self): + from torch._higher_order_ops.triton_kernel_wrap import ( + TraceableTritonKernelWrapper, + ) + + from magi_compiler._triton_introspect import ( + get_inner_triton_kernels, + rewrite_fn_with_wrap_triton, + ) + + def fn(a, b): + return _add_launcher(a, b) + + kernels = get_inner_triton_kernels(fn) + rewritten = rewrite_fn_with_wrap_triton(fn, kernels) + + rebuilt_launcher = rewritten.__globals__["_add_launcher"] + assert isinstance( + rebuilt_launcher.__globals__["_add_kernel"], TraceableTritonKernelWrapper + ) + + +# Multi-level nesting: fn -> dispatch -> launcher -> kernel. +# Verifies that kernels several call-graph hops away are still detected and +# that ``rewrite_fn_with_wrap_triton`` rebuilds every helper along the path. + + +class TestInferOutputMetaOverride: + def test_meta_list_form(self): + @magi_register_custom_op( + name="magi_test::triton_meta_list", + infer_output_meta_fn=["x"], + ) + def fn(x: torch.Tensor) -> torch.Tensor: + out = torch.empty_like(x) + n = x.numel() + _cos_kernel[_grid_1d(n)](x, out, n, BLOCK_SIZE=128) + return out + + x = torch.randn(1024, device="cuda") + assert_close(fn(x), torch.cos(x), atol=1e-5, rtol=1e-5) + # And inside torch.compile (forces the fake/meta path to be used). 
+ compiled = torch.compile(fn, backend="inductor", fullgraph=True) + assert_close(compiled(x), torch.cos(x), atol=1e-5, rtol=1e-5) + + def test_meta_callable_form(self): + called = {"count": 0} + + def custom_meta(x: torch.Tensor) -> torch.Tensor: + called["count"] += 1 + return torch.empty_like(x) + + @magi_register_custom_op( + name="magi_test::triton_meta_callable", + infer_output_meta_fn=custom_meta, + ) + def fn(x: torch.Tensor) -> torch.Tensor: + out = torch.empty_like(x) + n = x.numel() + _cos_kernel[_grid_1d(n)](x, out, n, BLOCK_SIZE=128) + return out + + x = torch.randn(1024, device="cuda") + assert_close(fn(x), torch.cos(x), atol=1e-5, rtol=1e-5) + compiled = torch.compile(fn, backend="inductor", fullgraph=True) + assert_close(compiled(x), torch.cos(x), atol=1e-5, rtol=1e-5) + # Tracing through torch.compile should have invoked the user-provided + # meta at least once. + assert called["count"] >= 1 + + +# Explicit registry-level assertion that we actually went down the +# ``torch.library.triton_op`` path (i.e. Inductor would be able to inline +# the kernel), distinguishing it from the silent custom_op fallback. + + +class TestTritonOpRegistryAssertion: + """Verify we actually take the ``torch.library.triton_op`` registration + path (so Inductor / make_fx can see through the op) instead of silently + falling back to plain ``custom_op`` (which would be opaque).""" + + @staticmethod + def _was_registered_as_triton_op(op_or_name) -> bool: + # ``triton_op`` installs a torch_dispatch on FunctionalTensorMode that + # decomposes the op into ``triton_kernel_wrapper_mutation`` calls. + # Plain ``custom_op`` does not. 
+ from torch._library.custom_ops import OPDEFS + from torch._subclasses.functional_tensor import FunctionalTensorMode + + if isinstance(op_or_name, str): + opdef = OPDEFS.get(op_or_name) + if opdef is None: + return False + else: + opdef = op_or_name + dispatch_fns = getattr(opdef, "_torch_dispatch_fns", {}) or {} + return FunctionalTensorMode in dispatch_fns + + def test_registered_as_triton_op(self): + @magi_register_custom_op(name="magi_test::registry_cos") + def fn(x: torch.Tensor) -> torch.Tensor: + out = torch.empty_like(x) + n = x.numel() + _cos_kernel[_grid_1d(n)](x, out, n, BLOCK_SIZE=128) + return out + + assert self._was_registered_as_triton_op("magi_test::registry_cos"), ( + "magi_test::registry_cos should have been registered via " + "torch.library.triton_op (so make_fx decomposes it into " + "triton_kernel_wrapper_mutation), not via plain custom_op." + ) + + def test_pure_python_op_not_registered_as_triton(self): + @magi_register_custom_op(name="magi_test::registry_pure_python") + def fn(x: torch.Tensor) -> torch.Tensor: + return x * 2 + 1 + + assert not self._was_registered_as_triton_op( + "magi_test::registry_pure_python" + ), ( + "magi_test::registry_pure_python has no triton kernels; it should " + "have fallen back to the custom_op path and remain opaque to " + "make_fx." + ) + + +# extra_triton_kernels deduplication: a kernel that is *both* auto-detected +# and listed in ``extra_triton_kernels`` should appear exactly once after +# resolution and must not be wrap_triton-wrapped twice. 
+ + +# ============================================================================ +# SECTION 3: Autotune, Heuristics & Autograd in Triton +# ============================================================================ + + +class TestAutotuneKernels: + def test_autotuned(self): + @magi_register_custom_op(name="magi_test::autotuned_cos") + def fn(x: torch.Tensor) -> torch.Tensor: + out = torch.empty_like(x) + n = x.numel() + # autotuner picks BLOCK_SIZE; grid uses meta lambda + _autotuned_cos_kernel[(triton.cdiv(n, 128),)](x, out, n) + return out + + x = torch.randn(2048, device="cuda") + assert_close(fn(x), torch.cos(x), atol=1e-5, rtol=1e-5) + + +class TestMultipleAutotuneKernelsSameOp: + """A single op may launch several differently-autotuned kernels (a common + FlashAttention / Mamba pattern). Verify both kernels are detected and the + op runs end-to-end through the triton_op path. + """ + + def test_two_autotune_kernels_in_same_op(self): + # Build a *second* autotuned kernel locally so we can be sure both + # kernel objects appear in the op's call graph. 
+ @triton.autotune( + configs=[ + triton.Config({"BLOCK_SIZE": 128}, num_warps=4), + triton.Config({"BLOCK_SIZE": 256}, num_warps=4), + ], + key=["n_elements"], + ) + @triton.jit + def _autotuned_sin_kernel( + in_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr + ): + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + x = tl.load(in_ptr + offsets, mask=mask) + tl.store(out_ptr + offsets, tl.sin(x), mask=mask) + + @magi_register_custom_op( + name="magi_test::two_autotune_kernels", + extra_triton_kernels=[_autotuned_sin_kernel], + ) + def myop(x: torch.Tensor) -> torch.Tensor: + n = x.numel() + mid = torch.empty_like(x) + _autotuned_cos_kernel[_grid_1d(n)](x, mid, n) + out = torch.empty_like(x) + _autotuned_sin_kernel[_grid_1d(n)](mid, out, n) + return out + + x = torch.randn(2048, device="cuda") + out = myop(x) + assert_close(out, torch.sin(torch.cos(x)), atol=1e-4, rtol=1e-4) + + +class TestHeuristicsRejection: + """``torch.library.wrap_triton`` only accepts ``JITFunction`` and + ``Autotuner``. A top-level ``@triton.heuristics`` produces a + ``Heuristics`` instance that fails ``wrap_triton`` with a confusing + error. ``@magi_register_custom_op`` rejects this case up front with a + clearer message, while still accepting the recommended layering of + ``@triton.autotune -> @triton.heuristics -> @triton.jit``. 
+ """ + + def test_top_level_heuristics_rejected_with_clear_message(self): + """Bare ``@triton.heuristics`` on a kernel referenced from the op + body must be rejected at registration time, not deep inside + ``wrap_triton``.""" + with pytest.raises(RuntimeError, match="triton.heuristics"): + + @magi_register_custom_op(name="magi_test::heuristics_top") + def myop(x: torch.Tensor) -> torch.Tensor: + out = torch.empty_like(x) + n = x.numel() + _heuristics_top_kernel[_grid_1d(n)](x, out, n) + return out + + def test_top_level_heuristics_via_extra_triton_kernels_rejected(self): + """Same constraint applies when the user passes the offending kernel + through the ``extra_triton_kernels`` escape hatch (no auto-detection + involved).""" + with pytest.raises(RuntimeError, match="triton.heuristics"): + + @magi_register_custom_op( + name="magi_test::heuristics_extra", + extra_triton_kernels=[_heuristics_top_kernel], + ) + def myop(x: torch.Tensor) -> torch.Tensor: + # Body doesn't reference the kernel at all; rejection comes + # purely from the extra_triton_kernels list. 
+ return x.clone() + + def test_autotune_outside_heuristics_is_accepted(self): + """The recommended layering ``@triton.autotune -> @triton.heuristics + -> @triton.jit`` produces an ``Autotuner`` at the top level and is + accepted (and end-to-end functional).""" + + @magi_register_custom_op(name="magi_test::autotune_over_heuristics") + def myop(x: torch.Tensor) -> torch.Tensor: + out = torch.empty_like(x) + n = x.numel() + _autotune_then_heuristics_kernel[_grid_1d(n)](x, out, n) + return out + + x = torch.randn(512, device="cuda") + out = myop(x) + assert_close(out, x) + + +# #15 / #16: kernels not statically discoverable -> extra_triton_kernels= + + +class TestTritonWithAutograd: + def test_triton_with_backward(self): + def setup_ctx(ctx, inputs, output): + (x,) = inputs + ctx.save_for_backward(x) + + def backward(ctx, grad_out): + (x,) = ctx.saved_tensors + # d/dx cos(x) = -sin(x) + return grad_out * (-torch.sin(x)) + + @magi_register_custom_op( + name="magi_test::triton_cos_grad", + setup_context_fn=setup_ctx, + backward_fn=backward, + ) + def mycos(x: torch.Tensor) -> torch.Tensor: + out = torch.empty_like(x) + n = x.numel() + _cos_kernel[_grid_1d(n)](x, out, n, BLOCK_SIZE=128) + return out + + x = torch.randn(1024, device="cuda", requires_grad=True) + out = mycos(x) + loss = out.sum() + loss.backward() + assert_close(x.grad, -torch.sin(x.detach()), atol=1e-5, rtol=1e-5) + + +# Direct unit tests for the introspection / rewrite helpers. 
+ + +# ============================================================================ +# SECTION 4: Dataclass & End-to-End Tracing +# ============================================================================ + + +class TestDataclassWithTritonKernel: + def test_dataclass_input_with_triton(self): + # Use the module-level _DcCosCfg dataclass declared near the top of + # this test module so the type is visible from ``fn.__globals__`` + # even with ``from __future__ import annotations`` in effect (which + # turns every annotation into a string at runtime). + + @magi_register_custom_op(name="magi_test::dc_cos") + def fn(x: torch.Tensor, cfg: _DcCosCfg) -> torch.Tensor: + out = torch.empty_like(x) + n = x.numel() + _cos_kernel[_grid_1d(n)](x, out, n, BLOCK_SIZE=cfg.block_size) + return out + + x = torch.randn(1024, device="cuda") + cfg = _DcCosCfg(block_size=128) + assert_close(fn(x, cfg), torch.cos(x), atol=1e-5, rtol=1e-5) + + # The dataclass-aware path registers an *inner* op under the requested + # name. That inner op should still be a triton_op. + assert TestTritonOpRegistryAssertion._was_registered_as_triton_op( + "magi_test::dc_cos" + ), ( + "dataclass+triton path should still register the inner op as a " + "triton_op so Inductor can see through it." + ) + + +# nn.Module owning the kernel as ``self.kernel``: the introspector cannot +# see through ``self.kernel[grid](...)`` (Subscripted Attribute, not a plain +# Name), so the user must use ``extra_triton_kernels``. We assert that the +# escape hatch makes this work end-to-end. 
+ + +class TestNestedDataclassWithTritonKernel: + def test_two_level_nested_dc_with_triton(self): + """Outer dataclass containing an inner dataclass; both are unwrapped + into ``cfg__kernel__block_size``, ``cfg__kernel__extra_offset`` and + ``cfg__scale`` flat parameters.""" + + @magi_register_custom_op(name="magi_test::nested_dc_cos_scale") + def fn(x: torch.Tensor, cfg: _DcOuterCfg) -> torch.Tensor: + tmp = torch.empty_like(x) + out = torch.empty_like(x) + n = x.numel() + _cos_kernel[_grid_1d(n)](x, tmp, n, BLOCK_SIZE=cfg.kernel.block_size) + _scale_kernel[_grid_1d(n)]( + tmp, out, n, cfg.scale, BLOCK_SIZE=cfg.kernel.block_size + ) + return out + cfg.kernel.extra_offset + + x = torch.randn(1024, device="cuda") + cfg = _DcOuterCfg( + kernel=_DcKernelCfg(block_size=128, extra_offset=0.5), scale=2.5 + ) + out = fn(x, cfg) + expected = torch.cos(x) * 2.5 + 0.5 + assert_close(out, expected, atol=1e-5, rtol=1e-5) + + # Sanity-check the flat plan. + plan = fn._magi_flat_plan + cfg_node = plan[1] + assert cfg_node[0] == "dataclass" and cfg_node[1] == "cfg" + flat_names: list[str] = [] + + def _collect(node): + if node[0] == "primitive": + flat_names.append(node[2]) + else: + for child in node[3]: + _collect(child) + + _collect(cfg_node) + assert { + "cfg__kernel__block_size", + "cfg__kernel__extra_offset", + "cfg__scale", + }.issubset(flat_names) + + # And: the registered op should still go through triton_op. + assert TestTritonOpRegistryAssertion._was_registered_as_triton_op( + "magi_test::nested_dc_cos_scale" + ), ( + "nested-dataclass + triton path should still register the inner " + "op as a triton_op so Inductor can see through it." 
+ ) + + def test_nested_dc_with_triton_and_meta_fn(self): + """User-supplied meta function expressed in nested-dataclass terms, + combined with a triton kernel call.""" + + def _meta(x: torch.Tensor, cfg: _DcProjCfg) -> torch.Tensor: + return x.new_empty((*x.shape[:-1], cfg.shape.out_dim)) + + @magi_register_custom_op( + name="magi_test::nested_dc_cos_proj", + infer_output_meta_fn=_meta, + ) + def fn(x: torch.Tensor, cfg: _DcProjCfg) -> torch.Tensor: + sliced = x[..., : cfg.shape.out_dim].contiguous() + out = torch.empty_like(sliced) + n = sliced.numel() + _cos_kernel[_grid_1d(n)](sliced, out, n, BLOCK_SIZE=cfg.block_size) + return out + + x = torch.randn(2, 8, device="cuda") + cfg = _DcProjCfg(shape=_DcShapeCfg(out_dim=3), block_size=128) + out = fn(x, cfg) + expected = torch.cos(x[..., :3].contiguous()) + assert out.shape == (2, 3) + assert_close(out, expected, atol=1e-5, rtol=1e-5) + + +# Triton kernel + dataclass input + backward_fn: the dataclass-aware path +# must bridge ``setup_context`` and ``backward`` so that the user can keep +# writing them against the ORIGINAL signature (with the dataclass instance) +# while underneath the registered op is flat. 
+ + +class TestDataclassWithTritonKernelAndBackward: + def test_triton_dc_backward_basic(self): + """End-to-end backward against a dc + triton op: use the cos kernel + (analytical grad: -sin(x) * cfg.factor) so we can verify exact grads.""" + + def _setup(ctx, inputs, output): + x, cfg = inputs + assert isinstance(cfg, _DcCosCfg) + ctx.save_for_backward(x) + ctx.block_size = cfg.block_size + + def _bwd(ctx, grad_out): + (x,) = ctx.saved_tensors + return grad_out * (-torch.sin(x)), None + + @magi_register_custom_op( + name="magi_test::dc_cos_grad", + setup_context_fn=_setup, + backward_fn=_bwd, + ) + def mycos(x: torch.Tensor, cfg: _DcCosCfg) -> torch.Tensor: + out = torch.empty_like(x) + n = x.numel() + _cos_kernel[_grid_1d(n)](x, out, n, BLOCK_SIZE=cfg.block_size) + return out + + x = torch.randn(1024, device="cuda", requires_grad=True) + cfg = _DcCosCfg(block_size=128) + out = mycos(x, cfg) + out.sum().backward() + assert_close(x.grad, -torch.sin(x.detach()), atol=1e-5, rtol=1e-5) + + # Sanity: this op should still have gone through the triton_op path + # (the dataclass-aware path registers an inner op under ``op_name``). + assert TestTritonOpRegistryAssertion._was_registered_as_triton_op( + "magi_test::dc_cos_grad" + ) + + def test_triton_nested_dc_backward(self): + """Nested dataclass + triton + backward. 
The bridge must spread the + whole-nested-dc ``None`` grad over every flat slot under that + dataclass.""" + + def _setup(ctx, inputs, output): + x, cfg = inputs + assert isinstance(cfg, _DcOuterCfg) + assert isinstance(cfg.kernel, _DcKernelCfg) + ctx.save_for_backward(x) + ctx.scale = cfg.scale + + def _bwd(ctx, grad_out): + (x,) = ctx.saved_tensors + # d/dx (cos(x) * scale + offset) = -sin(x) * scale + return grad_out * (-torch.sin(x)) * ctx.scale, None + + @magi_register_custom_op( + name="magi_test::nested_dc_cos_grad", + setup_context_fn=_setup, + backward_fn=_bwd, + ) + def fn(x: torch.Tensor, cfg: _DcOuterCfg) -> torch.Tensor: + tmp = torch.empty_like(x) + out = torch.empty_like(x) + n = x.numel() + _cos_kernel[_grid_1d(n)](x, tmp, n, BLOCK_SIZE=cfg.kernel.block_size) + _scale_kernel[_grid_1d(n)]( + tmp, out, n, cfg.scale, BLOCK_SIZE=cfg.kernel.block_size + ) + return out + cfg.kernel.extra_offset + + x = torch.randn(1024, device="cuda", requires_grad=True) + cfg = _DcOuterCfg( + kernel=_DcKernelCfg(block_size=128, extra_offset=0.5), scale=2.5 + ) + out = fn(x, cfg) + out.sum().backward() + expected = -torch.sin(x.detach()) * 2.5 + assert_close(x.grad, expected, atol=1e-5, rtol=1e-5) + + def test_triton_dc_backward_with_per_field_grad(self): + """User returns per-field grads (as a same-shape dataclass with + ``None`` leaves) for the dc slot. 
The triton path must still work.""" + + def _setup(ctx, inputs, output): + x, cfg = inputs + ctx.save_for_backward(x) + ctx.block_size = cfg.block_size + + def _bwd(ctx, grad_out): + (x,) = ctx.saved_tensors + return ( + grad_out * (-torch.sin(x)), + _DcCosCfg(block_size=None), + ) + + @magi_register_custom_op( + name="magi_test::dc_cos_per_field_grad", + setup_context_fn=_setup, + backward_fn=_bwd, + ) + def mycos(x: torch.Tensor, cfg: _DcCosCfg) -> torch.Tensor: + out = torch.empty_like(x) + n = x.numel() + _cos_kernel[_grid_1d(n)](x, out, n, BLOCK_SIZE=cfg.block_size) + return out + + x = torch.randn(512, device="cuda", requires_grad=True) + out = mycos(x, _DcCosCfg(block_size=128)) + out.sum().backward() + assert_close(x.grad, -torch.sin(x.detach()), atol=1e-5, rtol=1e-5) + + def test_triton_dc_backward_with_dict_grad(self): + """User returns the dataclass slot's grad as a plain ``dict`` (handy + when constructing another instance is awkward, e.g. in a generic + backward). The bridge must spread it through ``__getitem__``-style + access into the underlying flat slots.""" + + def _setup(ctx, inputs, output): + x, cfg = inputs + ctx.save_for_backward(x) + ctx.block_size = cfg.block_size + + def _bwd(ctx, grad_out): + (x,) = ctx.saved_tensors + # Use a dict for the dc slot rather than constructing a new + # _DcCosCfg(block_size=None) instance. 
+ return grad_out * (-torch.sin(x)), {"block_size": None} + + @magi_register_custom_op( + name="magi_test::dc_cos_dict_grad", + setup_context_fn=_setup, + backward_fn=_bwd, + ) + def mycos(x: torch.Tensor, cfg: _DcCosCfg) -> torch.Tensor: + out = torch.empty_like(x) + n = x.numel() + _cos_kernel[_grid_1d(n)](x, out, n, BLOCK_SIZE=cfg.block_size) + return out + + x = torch.randn(512, device="cuda", requires_grad=True) + out = mycos(x, _DcCosCfg(block_size=128)) + out.sum().backward() + assert_close(x.grad, -torch.sin(x.detach()), atol=1e-5, rtol=1e-5) + + +# Module-level @triton.heuristics fixtures for the rejection tests below. +# Defined at module scope so triton.jit's source-resolution machinery (which +# needs a real .py file) succeeds. + + +class TestDataclassTritonComputeSensitiveSmoke: + """The dataclass-aware bridge composes cleanly with + ``is_compute_sensitive=True`` on the triton path: registration succeeds, + the op runs, and its name lands in the compute-sensitive registry. + """ + + def test_dataclass_triton_compute_sensitive(self): + from magi_compiler.config import get_compile_config + + @magi_register_custom_op( + name="magi_test::dc_triton_cs", + is_compute_sensitive=True, + ) + def myop(x: torch.Tensor, cfg: _DcCosCfg) -> torch.Tensor: + out = torch.empty_like(x) + n = x.numel() + _cos_kernel[_grid_1d(n)](x, out, n, BLOCK_SIZE=cfg.block_size) + return out + + x = torch.randn(256, device="cuda") + out = myop(x, _DcCosCfg(block_size=128)) + assert_close(out, torch.cos(x)) + assert ( + "magi_test::dc_triton_cs" + in get_compile_config().recompute_config.custom_compute_sensitive_ops + ) + + +class TestInductorSeesTritonKernel: + """The whole point of the triton_op auto-detection is that + ``torch.compile`` (Inductor) sees through the op to the underlying + triton kernel rather than treating it as opaque. Verify by inspecting + the FX graph captured by Inductor for the wrap_triton-functional HOP. 
+ """ + + def test_triton_kernel_visible_in_aot_graph(self): + from torch._functorch.aot_autograd import aot_function + + @magi_register_custom_op(name="magi_test::inductor_visible_cos") + def mycos(x: torch.Tensor) -> torch.Tensor: + out = torch.empty_like(x) + n = x.numel() + _cos_kernel[_grid_1d(n)](x, out, n, BLOCK_SIZE=128) + return out + + # Run the op through AOTAutograd directly with a custom forward + # compiler that just records the post-functionalization graph. This + # is exactly the layer where ``triton_op``s decompose into the + # ``triton_kernel_wrapper_functional`` HOP; the presence of that + # node in the captured graph proves Inductor (which runs *after* + # AOTAutograd) sees the underlying triton kernel rather than an + # opaque ``torch.ops.magi_test.inductor_visible_cos`` call. + captured_graphs: list[str] = [] + + def _capture(gm, _example_inputs): + captured_graphs.append(gm.code) + return gm.forward + + x = torch.randn(1024, device="cuda") + torch._dynamo.reset() + compiled_aot = aot_function(mycos, fw_compiler=_capture, bw_compiler=_capture) + out = compiled_aot(x) + assert_close(out, torch.cos(x), atol=1e-5, rtol=1e-5) + + joined = "\n".join(captured_graphs) + assert ( + "triton_kernel_wrapper_functional" in joined + or "triton_kernel_wrapper_mutation" in joined + ), ( + "AOT graph did not decompose magi_test::inductor_visible_cos " + "into the triton_kernel_wrapper HOP; Inductor will treat it " + "as opaque. 
Captured AOT graph:\n" + joined + ) + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) From 64a9a42da4477233ae4082024974d479848a7d51 Mon Sep 17 00:00:00 2001 From: Yunbo Zhang Date: Mon, 27 Apr 2026 19:45:50 +0800 Subject: [PATCH 2/2] [fix] fix code style --- magi_compiler/_magi_register_custom_op.py | 117 ++------- magi_compiler/_triton_introspect.py | 67 +---- tests/api_tests/_triton_external_helpers.py | 8 +- tests/api_tests/test_register_custom_op.py | 273 ++++---------------- tests/api_tests/test_register_triton_op.py | 257 ++++-------------- 5 files changed, 154 insertions(+), 568 deletions(-) diff --git a/magi_compiler/_magi_register_custom_op.py b/magi_compiler/_magi_register_custom_op.py index fb3ea98..b8d6b93 100644 --- a/magi_compiler/_magi_register_custom_op.py +++ b/magi_compiler/_magi_register_custom_op.py @@ -98,8 +98,7 @@ def _assert_op_name_namespaced(op_name: str) -> None: _LITERAL_STRING_DOWNGRADE_HINT = ( - "Use ``str`` and validate the value inside the op body, e.g. " - "``assert mode in ('a', 'b')``." + "Use ``str`` and validate the value inside the op body, e.g. " "``assert mode in ('a', 'b')``." ) @@ -439,14 +438,7 @@ def _resolve_dataclass_field_types(cls: type) -> dict[str, Any]: return {f.name: f.type for f in dataclasses.fields(cls)} -_SCHEMA_DEFAULT_TYPES: tuple[type, ...] = ( - int, - float, - bool, - str, - torch.device, - torch.dtype, -) +_SCHEMA_DEFAULT_TYPES: tuple[type, ...] 
= (int, float, bool, str, torch.device, torch.dtype) def _schema_compatible_param_default(default: Any) -> Any: @@ -491,9 +483,7 @@ def _schema_compatible_default(f: "dataclasses.Field") -> Any: return inspect.Parameter.empty -def _build_dataclass_subplan( - cls: type, attr_name: str, flat_prefix: str -) -> tuple[tuple, list[inspect.Parameter]]: +def _build_dataclass_subplan(cls: type, attr_name: str, flat_prefix: str) -> tuple[tuple, list[inspect.Parameter]]: """Recursively build a (sub-)plan and the corresponding flat parameters for one frozen-dataclass-typed value. @@ -530,18 +520,12 @@ def _build_dataclass_subplan( ) _assert_not_mutable_dataclass(f_type, where=f"field {cls.__name__}.{f.name}") if _is_frozen_dataclass(f_type): - sub_node, sub_params = _build_dataclass_subplan( - f_type, attr_name=f.name, flat_prefix=child_flat_name - ) + sub_node, sub_params = _build_dataclass_subplan(f_type, attr_name=f.name, flat_prefix=child_flat_name) children.append(sub_node) flat_params.extend(sub_params) else: - _assert_not_unsupported_container( - f_type, where=f"field {cls.__name__}.{f.name}" - ) - f_type = _maybe_downgrade_literal_or_enum( - f_type, where=f"field {cls.__name__}.{f.name}" - ) + _assert_not_unsupported_container(f_type, where=f"field {cls.__name__}.{f.name}") + f_type = _maybe_downgrade_literal_or_enum(f_type, where=f"field {cls.__name__}.{f.name}") children.append(("primitive", f.name, child_flat_name, None)) # Carry the dataclass field's default (or default_factory product) # over to the flat parameter so torch.library.infer_schema records @@ -602,16 +586,12 @@ def _build_flat_signature(fn: Callable): annotation = resolved.get(name, param.annotation) _assert_not_mutable_dataclass(annotation, where=f"parameter {name!r}") if _is_frozen_dataclass(annotation): - node, sub_params = _build_dataclass_subplan( - annotation, attr_name=name, flat_prefix=name - ) + node, sub_params = _build_dataclass_subplan(annotation, attr_name=name, flat_prefix=name) 
plan.append(node) flat_params.extend(sub_params) else: _assert_not_unsupported_container(annotation, where=f"parameter {name!r}") - annotation = _maybe_downgrade_literal_or_enum( - annotation, where=f"parameter {name!r}" - ) + annotation = _maybe_downgrade_literal_or_enum(annotation, where=f"parameter {name!r}") new_param = param.replace( kind=inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=annotation, @@ -655,9 +635,7 @@ def _wrapped(*args, **kwargs): _wrapped.__signature__ = flat_sig flat_annotations = { - p.name: p.annotation - for p in flat_sig.parameters.values() - if p.annotation is not inspect.Parameter.empty + p.name: p.annotation for p in flat_sig.parameters.values() if p.annotation is not inspect.Parameter.empty } if flat_sig.return_annotation is not inspect.Signature.empty: flat_annotations["return"] = flat_sig.return_annotation @@ -718,9 +696,7 @@ def _flatten_value_into(node: tuple, value: Any, out: list) -> None: _flatten_value_into(child, getattr(value, field_name), out) -def _flatten_call_args( - plan: list[tuple], user_sig: inspect.Signature, args: tuple, kwargs: dict -) -> list: +def _flatten_call_args(plan: list[tuple], user_sig: inspect.Signature, args: tuple, kwargs: dict) -> list: """ Flatten a user-side call (which may pass nested dataclass instances) into a positional list. The order matches the flat signature produced by @@ -798,10 +774,7 @@ def _collect_tensor_leaf_flat_names(node: tuple) -> list[str]: return out -def _expand_mutates_args( - mutates_args: tuple[str, ...] | list[str], - plan: list[tuple], -) -> tuple[str, ...]: +def _expand_mutates_args(mutates_args: tuple[str, ...] | list[str], plan: list[tuple]) -> tuple[str, ...]: """Translate ``mutates_args`` from the *original* parameter space to the *flat* parameter space. @@ -913,8 +886,7 @@ def _assert_wrap_triton_compatible(kernels: list[Any]) -> None: def _resolve_triton_kernels( - fn: Callable, - extra_triton_kernels: list[Any] | tuple[Any, ...] 
| None, + fn: Callable, extra_triton_kernels: list[Any] | tuple[Any, ...] | None ) -> tuple[list[Any], list[Any], set[int]]: """Best-effort: collect triton kernels referenced inside ``fn``. @@ -981,9 +953,7 @@ def _resolve_triton_kernels( except Exception: logger.debug("get_referenced_heuristics_kernels(%r) failed", fn, exc_info=True) referenced_heuristics = [] - _assert_wrap_triton_compatible( - list(extra_triton_kernels or ()) + list(referenced_heuristics) - ) + _assert_wrap_triton_compatible(list(extra_triton_kernels or ()) + list(referenced_heuristics)) try: user_wrapped = get_user_wrapped_triton_kernels(fn) except Exception: @@ -1064,16 +1034,12 @@ def _register_op( except ImportError: triton_op = None # type: ignore[assignment] logger.warning( - "torch.library.triton_op not available; falling back to " - "torch.library.custom_op for op %s", - op_name, + "torch.library.triton_op not available; falling back to " "torch.library.custom_op for op %s", op_name ) if triton_op is not None: try: - fn_for_register = rewrite_fn_with_wrap_triton( - fn, bare_triton_kernels, excluded_kernel_ids=excluded_kernel_ids - ) + fn_for_register = rewrite_fn_with_wrap_triton(fn, bare_triton_kernels, excluded_kernel_ids=excluded_kernel_ids) # ``rewrite_fn_with_wrap_triton`` builds a fresh # ``types.FunctionType`` from ``fn.__code__``; if ``fn`` is a # thin signature-rewriting wrapper (e.g. 
for Literal / @@ -1087,16 +1053,9 @@ def _register_op( for p in signature_override.parameters.values() if p.annotation is not inspect.Parameter.empty } - if ( - signature_override.return_annotation - is not inspect.Signature.empty - ): - fn_for_register.__annotations__["return"] = ( - signature_override.return_annotation - ) - registered_op = triton_op(op_name, mutates_args=mutates_args)( - fn_for_register - ) + if signature_override.return_annotation is not inspect.Signature.empty: + fn_for_register.__annotations__["return"] = signature_override.return_annotation + registered_op = triton_op(op_name, mutates_args=mutates_args)(fn_for_register) # ``triton_op`` already registers ``fn`` as the fake/meta # implementation. Only override when the user explicitly # supplied an ``infer_output_meta_fn``. @@ -1133,9 +1092,7 @@ def decorator(fn: Callable) -> Callable: _assert_op_name_namespaced(op_name) _assert_op_name_unused(op_name) if is_compute_sensitive: - get_compile_config().recompute_config.custom_compute_sensitive_ops.append( - op_name - ) + get_compile_config().recompute_config.custom_compute_sensitive_ops.append(op_name) if is_subgraph_boundary: get_compile_config().splitting_ops.append(op_name) @@ -1153,9 +1110,7 @@ def decorator(fn: Callable) -> Callable: # schema sees the cleaned-up version. Otherwise we register ``fn`` # directly to preserve the original zero-overhead path. 
sig_was_rewritten = _signatures_differ(flat_sig, user_sig) - fn_for_register = ( - _make_flat_signature_wrapper(fn, flat_sig) if sig_was_rewritten else fn - ) + fn_for_register = _make_flat_signature_wrapper(fn, flat_sig) if sig_was_rewritten else fn # Step 1: Build the meta/fake function (used either as a # register_fake override on the triton path, or as the regular @@ -1165,9 +1120,7 @@ def decorator(fn: Callable) -> Callable: meta_fn = _create_identity_meta_fn(meta_target) user_supplied_meta = False elif isinstance(infer_output_meta_fn, list): - meta_fn = _create_meta_fn_from_param_names( - meta_target, infer_output_meta_fn - ) + meta_fn = _create_meta_fn_from_param_names(meta_target, infer_output_meta_fn) user_supplied_meta = True else: meta_fn = infer_output_meta_fn @@ -1175,9 +1128,7 @@ def decorator(fn: Callable) -> Callable: # Step 2: Detect inner triton kernels and register the op via # triton_op (if any kernels are present) or custom_op (otherwise). - triton_kernels, bare_triton_kernels, user_wrapped_ids = ( - _resolve_triton_kernels(fn, extra_triton_kernels) - ) + triton_kernels, bare_triton_kernels, user_wrapped_ids = _resolve_triton_kernels(fn, extra_triton_kernels) registered_op = _register_op( op_name=op_name, fn=fn_for_register, @@ -1192,9 +1143,7 @@ def decorator(fn: Callable) -> Callable: # Step 3: Register autograd if backward_fn is provided if backward_fn is not None: - registered_op.register_autograd( - backward_fn, setup_context=setup_context_fn - ) + registered_op.register_autograd(backward_fn, setup_context=setup_context_fn) _REGISTERED_OP_NAMES.add(op_name) return registered_op @@ -1212,13 +1161,9 @@ def _bind_to_user_kwargs(args, kwargs): # Detect triton kernels referenced from the original (dataclass-typed) # fn. If any are present, route ``inner_fn`` through a wrap_triton-aware # copy of ``fn`` so the eventual triton_op registration captures them. 
- triton_kernels, bare_triton_kernels, user_wrapped_ids = _resolve_triton_kernels( - fn, extra_triton_kernels - ) + triton_kernels, bare_triton_kernels, user_wrapped_ids = _resolve_triton_kernels(fn, extra_triton_kernels) fn_for_inner = ( - rewrite_fn_with_wrap_triton( - fn, bare_triton_kernels, excluded_kernel_ids=user_wrapped_ids - ) + rewrite_fn_with_wrap_triton(fn, bare_triton_kernels, excluded_kernel_ids=user_wrapped_ids) if bare_triton_kernels else fn ) @@ -1243,9 +1188,7 @@ def inner_fn(*args, **kwargs): # tool reading ``__annotations__`` directly (e.g. ``get_type_hints``) # also sees the primitive types torch.library expects. flat_annotations = { - p.name: p.annotation - for p in flat_sig.parameters.values() - if p.annotation is not inspect.Parameter.empty + p.name: p.annotation for p in flat_sig.parameters.values() if p.annotation is not inspect.Parameter.empty } if flat_sig.return_annotation is not inspect.Signature.empty: flat_annotations["return"] = flat_sig.return_annotation @@ -1303,9 +1246,7 @@ def _bridged_setup_context(ctx, inputs, output): # ``inputs`` is the flat positional tuple in the order of # ``flat_sig``. Reassemble it into the user's original # (possibly nested-dataclass-bearing) shape. - flat_kwargs = { - p.name: v for p, v in zip(flat_sig.parameters.values(), inputs) - } + flat_kwargs = {p.name: v for p, v in zip(flat_sig.parameters.values(), inputs)} user_kwargs = _reassemble_user_kwargs(plan, flat_kwargs) # Preserve original positional order so users can do # ``x, cfg = inputs`` exactly like in the no-dataclass case. 
@@ -1320,9 +1261,7 @@ def _bridged_backward(ctx, *grads): user_grads = (user_grads,) return tuple(_flatten_user_grads(plan, user_grads)) - registered_op.register_autograd( - _bridged_backward, setup_context=_bridged_setup_context - ) + registered_op.register_autograd(_bridged_backward, setup_context=_bridged_setup_context) # Outer wrapper preserves the original (dataclass-aware) signature for # users while routing through the registered (flat) op underneath. diff --git a/magi_compiler/_triton_introspect.py b/magi_compiler/_triton_introspect.py index d5d10da..150e3d7 100644 --- a/magi_compiler/_triton_introspect.py +++ b/magi_compiler/_triton_introspect.py @@ -59,9 +59,7 @@ # ============================================================================== -def _find_triton_kernels_impl( - fn: Callable[..., Any], only_bare: bool = False -) -> list[object]: +def _find_triton_kernels_impl(fn: Callable[..., Any], only_bare: bool = False) -> list[object]: """Shared driver for :func:`get_inner_triton_kernels` and :func:`get_bare_triton_kernels`. 
@@ -75,11 +73,7 @@ def _find_triton_kernels_impl( # prevent infinite recursion MAX_RECURSION_DEPTH = 5 - def find_triton_kernels( - fn: Callable[..., Any], - visited_fns: set[int] | None = None, - depth: int = 0, - ) -> list[object]: + def find_triton_kernels(fn: Callable[..., Any], visited_fns: set[int] | None = None, depth: int = 0) -> list[object]: try: from triton.runtime.autotuner import Autotuner from triton.runtime.jit import JITFunction @@ -98,10 +92,7 @@ def find_triton_kernels( if fn_id in visited_fns: return [] if depth > MAX_RECURSION_DEPTH: - logger.debug( - "reached max recursion depth (%s) in find_triton_kernels", - MAX_RECURSION_DEPTH, - ) + logger.debug("reached max recursion depth (%s) in find_triton_kernels", MAX_RECURSION_DEPTH) return [] visited_fns.add(fn_id) @@ -166,9 +157,7 @@ def visit_Call(self, node: ast.Call) -> None: and attr.value.value.value.id == "torch" and attr.value.value.attr == "ops" ): - self.called_functions.append( - f"{attr.value.attr}::{attr.attr}" - ) + self.called_functions.append(f"{attr.value.attr}::{attr.attr}") # Catch capture_triton, wrap_triton that's been # imported directly elif isinstance(node.func, ast.Name): @@ -186,9 +175,7 @@ def visit_Call(self, node: ast.Call) -> None: # value is a plain Name (the most common pattern); subscripted # attributes (e.g. ``self.kernel[grid](...)``) need the # ``extra_triton_kernels`` escape hatch. 
- if isinstance(node.func, ast.Subscript) and isinstance( - node.func.value, ast.Name - ): + if isinstance(node.func, ast.Subscript) and isinstance(node.func.value, ast.Name): self.bare_kernel_names.append(node.func.value.id) self.generic_visit(node) @@ -279,11 +266,7 @@ def resolve_names_to_kernels( except ValueError: unwrapped = obj if hasattr(unwrapped, "__code__"): - nested = find_triton_kernels( - unwrapped, - visited_fns, - depth + 1, - ) + nested = find_triton_kernels(unwrapped, visited_fns, depth + 1) if nested: results.extend(nested) continue @@ -292,9 +275,7 @@ def resolve_names_to_kernels( # trace through local assignments for rhs_expr in assignments[name]: referenced = extract_names_from_expr(rhs_expr) - traced = resolve_names_to_kernels( - referenced, namespace, assignments, visited - ) + traced = resolve_names_to_kernels(referenced, namespace, assignments, visited) results.extend(traced) else: logger.debug("%s not found in namespace or assignments", name) @@ -308,16 +289,12 @@ def resolve_names_to_kernels( if only_bare: names_to_resolve: list[str] = list(collector.bare_kernel_names) else: - names_to_resolve = list(collector.bare_kernel_names) + list( - collector.wrapped_kernel_names - ) + names_to_resolve = list(collector.bare_kernel_names) + list(collector.wrapped_kernel_names) for expr in collector.return_exprs: names_to_resolve.extend(extract_names_from_expr(expr)) for name in names_to_resolve: - traced_objects = resolve_names_to_kernels( - [name], all_names, collector.assignments - ) + traced_objects = resolve_names_to_kernels([name], all_names, collector.assignments) for obj in traced_objects: obj_id = id(obj) if obj_id not in seen_ids: @@ -352,9 +329,7 @@ def resolve_names_to_kernels( seen_ids.add(kernel_id) resolved.append(kernel) except Exception: - logger.debug( - "failed to analyze called function %s", func_name, exc_info=True - ) + logger.debug("failed to analyze called function %s", func_name, exc_info=True) return resolved @@ -432,9 
+407,7 @@ def visit_Call(self, node: ast.Call) -> None: obj = namespace.get(n) if obj is None: continue - resolved = ( - obj if isinstance(obj, kernel_types) else _resolve_kernel(obj, kernel_types) - ) + resolved = obj if isinstance(obj, kernel_types) else _resolve_kernel(obj, kernel_types) if resolved is None: continue if id(resolved) in seen: @@ -616,9 +589,7 @@ def _is_user_helper(obj: object) -> bool: def rewrite_fn_with_wrap_triton( - fn: Callable[..., Any], - kernels: list[object], - excluded_kernel_ids: Optional[set[int]] = None, + fn: Callable[..., Any], kernels: list[object], excluded_kernel_ids: Optional[set[int]] = None ) -> Callable[..., Any]: """ Return a copy of ``fn`` whose globals / closures are shadowed so that every @@ -772,21 +743,11 @@ def _rebuild(f: Callable[..., Any]) -> Callable[..., Any]: new_cells.append(types.CellType(_rebuild(contents))) continue except Exception: - logger.debug( - "failed to rebuild closure helper %s", - getattr(contents, "__name__", "?"), - exc_info=True, - ) + logger.debug("failed to rebuild closure helper %s", getattr(contents, "__name__", "?"), exc_info=True) new_cells.append(cell) new_closure = tuple(new_cells) - new_fn = types.FunctionType( - f.__code__, - new_globals, - f.__name__, - f.__defaults__, - new_closure, - ) + new_fn = types.FunctionType(f.__code__, new_globals, f.__name__, f.__defaults__, new_closure) # Preserve introspectable metadata so that downstream tooling # (infer_schema, register_fake, etc.) continues to work. 
try: diff --git a/tests/api_tests/_triton_external_helpers.py b/tests/api_tests/_triton_external_helpers.py index 37cc253..0547d9d 100644 --- a/tests/api_tests/_triton_external_helpers.py +++ b/tests/api_tests/_triton_external_helpers.py @@ -41,9 +41,7 @@ if HAS_TRITON: @triton.jit - def external_neg_kernel( - in_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr - ): + def external_neg_kernel(in_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr): pid = tl.program_id(axis=0) offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) mask = offsets < n_elements @@ -51,9 +49,7 @@ def external_neg_kernel( tl.store(out_ptr + offsets, -x, mask=mask) @triton.jit - def external_double_kernel( - in_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr - ): + def external_double_kernel(in_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr): pid = tl.program_id(axis=0) offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) mask = offsets < n_elements diff --git a/tests/api_tests/test_register_custom_op.py b/tests/api_tests/test_register_custom_op.py index 3b694f7..49dbdf2 100644 --- a/tests/api_tests/test_register_custom_op.py +++ b/tests/api_tests/test_register_custom_op.py @@ -78,9 +78,7 @@ def test_multiple_inputs(self): """Test custom op with multiple input tensors.""" @magi_register_custom_op(name="test::multi_input_op", mutates_args=()) - def _multi_input_op( - a: torch.Tensor, b: torch.Tensor, scale: float - ) -> torch.Tensor: + def _multi_input_op(a: torch.Tensor, b: torch.Tensor, scale: float) -> torch.Tensor: return (a + b) * scale a = torch.randn(4, 8) @@ -98,19 +96,13 @@ class TestInferOutputMeta: def test_with_infer_output_meta(self): """Test that infer_output_meta_fn is correctly registered for tracing.""" - def _scaled_add_infer_output_meta( - x: torch.Tensor, y: torch.Tensor, scale: float - ) -> torch.Tensor: + def _scaled_add_infer_output_meta(x: torch.Tensor, y: torch.Tensor, scale: float) -> torch.Tensor: return torch.empty_like(x) @magi_register_custom_op( - 
name="test::scaled_add_op", - mutates_args=(), - infer_output_meta_fn=_scaled_add_infer_output_meta, + name="test::scaled_add_op", mutates_args=(), infer_output_meta_fn=_scaled_add_infer_output_meta ) - def _scaled_add_op( - x: torch.Tensor, y: torch.Tensor, scale: float - ) -> torch.Tensor: + def _scaled_add_op(x: torch.Tensor, y: torch.Tensor, scale: float) -> torch.Tensor: return (x + y) * scale x = torch.randn(4, 8) @@ -124,20 +116,11 @@ def _scaled_add_op( def test_multiple_outputs_infer_meta(self): """Test infer_output_meta_fn with multiple outputs.""" - def _split_op_infer_output_meta( - x: torch.Tensor, - ) -> tuple[torch.Tensor, torch.Tensor]: + def _split_op_infer_output_meta(x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: half_size = x.shape[-1] // 2 - return ( - x.new_empty((*x.shape[:-1], half_size)), - x.new_empty((*x.shape[:-1], half_size)), - ) + return (x.new_empty((*x.shape[:-1], half_size)), x.new_empty((*x.shape[:-1], half_size))) - @magi_register_custom_op( - name="test::split_op", - mutates_args=(), - infer_output_meta_fn=_split_op_infer_output_meta, - ) + @magi_register_custom_op(name="test::split_op", mutates_args=(), infer_output_meta_fn=_split_op_infer_output_meta) def _split_op(x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: half_size = x.shape[-1] // 2 # NOTE: Output cannot share the same memory with input @@ -177,9 +160,7 @@ def test_auto_name_multiple_outputs(self): """Test auto-generated name with multiple tensor outputs.""" @magi_register_custom_op() - def _auto_name_multi_out_op( - a: torch.Tensor, b: torch.Tensor - ) -> tuple[torch.Tensor, torch.Tensor]: + def _auto_name_multi_out_op(a: torch.Tensor, b: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: return torch.clone(a + 1), torch.clone(b + 2) def fn(a, b): @@ -205,9 +186,7 @@ def _auto_grad_backward(ctx, grad_output): (x,) = ctx.saved_tensors return grad_output * 2 * x - @magi_register_custom_op( - setup_context_fn=_auto_grad_setup_context, 
backward_fn=_auto_grad_backward - ) + @magi_register_custom_op(setup_context_fn=_auto_grad_setup_context, backward_fn=_auto_grad_backward) def _auto_name_square_op(x: torch.Tensor) -> torch.Tensor: return x * x @@ -245,9 +224,7 @@ def test_single_output_multiple_inputs_default_meta(self): """Test default meta function with multiple inputs but single tensor output.""" @magi_register_custom_op(name="test::default_meta_multi_in") - def _default_meta_multi_in_op( - a: torch.Tensor, b: torch.Tensor, scale: float - ) -> torch.Tensor: + def _default_meta_multi_in_op(a: torch.Tensor, b: torch.Tensor, scale: float) -> torch.Tensor: return (a + b) * scale def fn(a, b, scale): @@ -267,9 +244,7 @@ def test_multiple_outputs_default_meta(self): """Test default meta function with multiple tensor outputs.""" @magi_register_custom_op(name="test::default_meta_multi_out") - def _default_meta_multi_out_op( - x: torch.Tensor, y: torch.Tensor - ) -> tuple[torch.Tensor, torch.Tensor]: + def _default_meta_multi_out_op(x: torch.Tensor, y: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: # Clone to avoid aliasing issues return torch.clone(x * 2), torch.clone(y * 3) @@ -335,14 +310,9 @@ def fn(scale, x, offset, y): @pytest.fixture() def magi_compile_config(): """Fixture to set up a clean compile configuration for magi_compile tests.""" - compile_config = CompileConfig( - compile_mode=CompileMode.TORCH_COMPILE, cache_root_dir=tempfile.mkdtemp() - ) - - with ( - patch("magi_compiler.api.get_compile_config") as mock_get_config, - patch("torch.distributed.get_rank") as mock_rank, - ): + compile_config = CompileConfig(compile_mode=CompileMode.TORCH_COMPILE, cache_root_dir=tempfile.mkdtemp()) + + with patch("magi_compiler.api.get_compile_config") as mock_get_config, patch("torch.distributed.get_rank") as mock_rank: mock_get_config.return_value = compile_config mock_rank.return_value = 0 yield compile_config @@ -361,11 +331,7 @@ def test_custom_op_in_compiled_function(self): def 
_double_infer_output_meta(x: torch.Tensor) -> torch.Tensor: return torch.empty_like(x) - @magi_register_custom_op( - name="test::double_op", - mutates_args=(), - infer_output_meta_fn=_double_infer_output_meta, - ) + @magi_register_custom_op(name="test::double_op", mutates_args=(), infer_output_meta_fn=_double_infer_output_meta) def _double_op(x: torch.Tensor) -> torch.Tensor: return x * 2 @@ -429,11 +395,7 @@ def test_custom_op_in_magi_compiled_module(self, magi_compile_config): def _triple_infer_output_meta(x: torch.Tensor) -> torch.Tensor: return torch.empty_like(x) - @magi_register_custom_op( - name="test::triple_op", - mutates_args=(), - infer_output_meta_fn=_triple_infer_output_meta, - ) + @magi_register_custom_op(name="test::triple_op", mutates_args=(), infer_output_meta_fn=_triple_infer_output_meta) def _triple_op(x: torch.Tensor) -> torch.Tensor: return x * 3 @@ -508,9 +470,7 @@ def _relu_custom_infer_output_meta(x: torch.Tensor) -> torch.Tensor: return torch.empty_like(x) @magi_register_custom_op( - name="test::relu_custom_op", - mutates_args=(), - infer_output_meta_fn=_relu_custom_infer_output_meta, + name="test::relu_custom_op", mutates_args=(), infer_output_meta_fn=_relu_custom_infer_output_meta ) def _relu_custom_op(x: torch.Tensor) -> torch.Tensor: return torch.relu(x) @@ -541,19 +501,11 @@ def _add_one_infer_output_meta(x: torch.Tensor) -> torch.Tensor: def _mul_two_infer_output_meta(x: torch.Tensor) -> torch.Tensor: return torch.empty_like(x) - @magi_register_custom_op( - name="test::add_one_op", - mutates_args=(), - infer_output_meta_fn=_add_one_infer_output_meta, - ) + @magi_register_custom_op(name="test::add_one_op", mutates_args=(), infer_output_meta_fn=_add_one_infer_output_meta) def _add_one_op(x: torch.Tensor) -> torch.Tensor: return x + 1 - @magi_register_custom_op( - name="test::mul_two_op", - mutates_args=(), - infer_output_meta_fn=_mul_two_infer_output_meta, - ) + @magi_register_custom_op(name="test::mul_two_op", mutates_args=(), 
infer_output_meta_fn=_mul_two_infer_output_meta) def _mul_two_op(x: torch.Tensor) -> torch.Tensor: return x * 2 @@ -573,25 +525,14 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: expected = (x + 1) * 2 assert_close(output, expected) - def test_custom_op_multiple_outputs_in_magi_compiled_module( - self, magi_compile_config - ): + def test_custom_op_multiple_outputs_in_magi_compiled_module(self, magi_compile_config): """Test custom op with multiple outputs inside a magi_compile'd module.""" - def _split_v2_infer_output_meta( - x: torch.Tensor, - ) -> tuple[torch.Tensor, torch.Tensor]: + def _split_v2_infer_output_meta(x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: half = x.shape[-1] // 2 - return ( - x.new_empty((*x.shape[:-1], half)), - x.new_empty((*x.shape[:-1], half)), - ) + return (x.new_empty((*x.shape[:-1], half)), x.new_empty((*x.shape[:-1], half))) - @magi_register_custom_op( - name="test::split_v2_op", - mutates_args=(), - infer_output_meta_fn=_split_v2_infer_output_meta, - ) + @magi_register_custom_op(name="test::split_v2_op", mutates_args=(), infer_output_meta_fn=_split_v2_infer_output_meta) def _split_v2_op(x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: half = x.shape[-1] // 2 return torch.clone(x[..., :half]), torch.clone(x[..., half:]) @@ -668,19 +609,13 @@ class TestDataclassWithComputeSensitiveSmoke: def test_dataclass_compute_sensitive_smoke(self): from magi_compiler.config import get_compile_config - @magi_register_custom_op( - name="test::dc_compute_sensitive", - is_compute_sensitive=True, - ) + @magi_register_custom_op(name="test::dc_compute_sensitive", is_compute_sensitive=True) def _op(x: torch.Tensor, cfg: _CSCfg) -> torch.Tensor: return x * cfg.s out = _op(torch.ones(2), _CSCfg(s=2.0)) assert_close(out, torch.full((2,), 2.0)) - assert ( - "test::dc_compute_sensitive" - in get_compile_config().recompute_config.custom_compute_sensitive_ops - ) + assert "test::dc_compute_sensitive" in 
get_compile_config().recompute_config.custom_compute_sensitive_ops # ============================================================================ @@ -718,9 +653,7 @@ class _AttnCfg: causal: bool @magi_register_custom_op(name="test::dc_mixed_op", mutates_args=()) - def _dc_mixed_op( - q: torch.Tensor, k: torch.Tensor, cfg: _AttnCfg - ) -> torch.Tensor: + def _dc_mixed_op(q: torch.Tensor, k: torch.Tensor, cfg: _AttnCfg) -> torch.Tensor: out = (q @ k.transpose(-1, -2)) * cfg.scale if cfg.causal: mask = torch.tril(torch.ones_like(out, dtype=torch.bool)) @@ -749,9 +682,7 @@ class _ProjCfg: def _proj_meta(x: torch.Tensor, cfg: _ProjCfg) -> torch.Tensor: return x.new_empty((*x.shape[:-1], cfg.out_dim)) - @magi_register_custom_op( - name="test::dc_meta_op", mutates_args=(), infer_output_meta_fn=_proj_meta - ) + @magi_register_custom_op(name="test::dc_meta_op", mutates_args=(), infer_output_meta_fn=_proj_meta) def _dc_meta_op(x: torch.Tensor, cfg: _ProjCfg) -> torch.Tensor: return x[..., : cfg.out_dim].clone() @@ -793,9 +724,7 @@ def _nested_dc_only_op(cfg: _Outer) -> torch.Tensor: assert plan[0][1] == "cfg" children = plan[0][3] kinds = [c[0] for c in children] - assert ( - "dataclass" in kinds - ), "inner dataclass field must remain a dataclass node" + assert "dataclass" in kinds, "inner dataclass field must remain a dataclass node" # Find the leaf flat names. 
flat_names: list[str] = [] @@ -861,9 +790,7 @@ class _Root: tag: float @magi_register_custom_op(name="test::deep_nested_dc_op", mutates_args=()) - def _deep_nested_dc_op( - x: torch.Tensor, cfg: _Root, alpha: float - ) -> torch.Tensor: + def _deep_nested_dc_op(x: torch.Tensor, cfg: _Root, alpha: float) -> torch.Tensor: return x * cfg.mid.leaf.val + cfg.mid.extra + cfg.tag + alpha x = torch.randn(2, 3) @@ -887,11 +814,7 @@ class _ProjCfg: def _proj_meta(x: torch.Tensor, cfg: _ProjCfg) -> torch.Tensor: return x.new_empty((*x.shape[:-1], cfg.shape.out_dim)) - @magi_register_custom_op( - name="test::nested_dc_meta_op", - mutates_args=(), - infer_output_meta_fn=_proj_meta, - ) + @magi_register_custom_op(name="test::nested_dc_meta_op", mutates_args=(), infer_output_meta_fn=_proj_meta) def _nested_dc_meta_op(x: torch.Tensor, cfg: _ProjCfg) -> torch.Tensor: return x[..., : cfg.shape.out_dim].clone() * cfg.scale @@ -948,14 +871,7 @@ def _op(x: torch.Tensor, cfg: _Cfg) -> torch.Tensor: out_all_none = _op(x, _Cfg(None, None, None, None, None, None)) assert_close(out_all_none, x) # Some-None: bias + scale active, others ignored by the body. 
- cfg = _Cfg( - bias=torch.tensor([1.0, 2.0, 3.0]), - scale=2.0, - mode="a", - block_sizes=[4, 8], - flag=True, - count=7, - ) + cfg = _Cfg(bias=torch.tensor([1.0, 2.0, 3.0]), scale=2.0, mode="a", block_sizes=[4, 8], flag=True, count=7) out_some = _op(x, cfg) assert_close(out_some, torch.tensor([4.0, 6.0, 8.0])) @@ -974,10 +890,7 @@ def _op(x: torch.Tensor, cfg: _Cfg) -> torch.Tensor: x = torch.ones(2) assert_close(_op(x, _Cfg(bias=None)), x) - assert_close( - _op(x, _Cfg(bias=torch.tensor([10.0, 20.0]))), - torch.tensor([11.0, 21.0]), - ) + assert_close(_op(x, _Cfg(bias=torch.tensor([10.0, 20.0]))), torch.tensor([11.0, 21.0])) def test_optional_list_of_tensors_unsupported_by_torch_library(self): """Sanity-pin a known-broken case so a future torch upgrade that @@ -1067,10 +980,7 @@ def _op(x: torch.Tensor, cfg: _Cfg) -> torch.Tensor: return out x = torch.ones(2) - out = _op( - x, - _Cfg(maybe_biases=[None, torch.tensor([10.0, 20.0]), None]), - ) + out = _op(x, _Cfg(maybe_biases=[None, torch.tensor([10.0, 20.0]), None])) assert_close(out, torch.tensor([11.0, 21.0])) def test_empty_list_field(self): @@ -1149,10 +1059,7 @@ def _op(x: torch.Tensor, cfg: _Outer) -> torch.Tensor: x = torch.ones(2) assert_close(_op(x, _Outer()), torch.full((2,), 3.0)) - assert_close( - _op(x, _Outer(inner=_Inner(val=5.0), scale=2.0)), - torch.full((2,), 10.0), - ) + assert_close(_op(x, _Outer(inner=_Inner(val=5.0), scale=2.0)), torch.full((2,), 10.0)) class TestDataclassUnsupportedContainerFields: @@ -1460,10 +1367,7 @@ def test_list_tensor_return_runtime(self): def meta(x, n): return [torch.empty_like(x.chunk(n, dim=0)[0]) for _ in range(n)] - @magi_register_custom_op( - name="test::list_tensor_return", - infer_output_meta_fn=meta, - ) + @magi_register_custom_op(name="test::list_tensor_return", infer_output_meta_fn=meta) def _op(x: torch.Tensor, n: int) -> list[torch.Tensor]: # ``.chunk`` returns views that alias ``x``; clone each chunk so # the op doesn't violate torch.library's 
no-aliasing invariant @@ -1540,9 +1444,7 @@ def _square_op(x: torch.Tensor) -> torch.Tensor: def test_autograd_multiple_inputs(self): """Test autograd with multiple input tensors.""" - def _weighted_sum_infer_output_meta( - a: torch.Tensor, b: torch.Tensor, weight: float - ) -> torch.Tensor: + def _weighted_sum_infer_output_meta(a: torch.Tensor, b: torch.Tensor, weight: float) -> torch.Tensor: return torch.empty_like(a) def _weighted_sum_setup_context(ctx, inputs, output): @@ -1564,9 +1466,7 @@ def _weighted_sum_backward(ctx, grad_output): setup_context_fn=_weighted_sum_setup_context, backward_fn=_weighted_sum_backward, ) - def _weighted_sum_op( - a: torch.Tensor, b: torch.Tensor, weight: float - ) -> torch.Tensor: + def _weighted_sum_op(a: torch.Tensor, b: torch.Tensor, weight: float) -> torch.Tensor: return a * weight + b * (1 - weight) a = torch.randn(4, 8, requires_grad=True) @@ -1586,14 +1486,9 @@ def _weighted_sum_op( def test_autograd_multiple_outputs(self): """Test autograd with multiple output tensors.""" - def _split_scale_infer_output_meta( - x: torch.Tensor, scale: float - ) -> tuple[torch.Tensor, torch.Tensor]: + def _split_scale_infer_output_meta(x: torch.Tensor, scale: float) -> tuple[torch.Tensor, torch.Tensor]: half = x.shape[-1] // 2 - return ( - x.new_empty((*x.shape[:-1], half)), - x.new_empty((*x.shape[:-1], half)), - ) + return (x.new_empty((*x.shape[:-1], half)), x.new_empty((*x.shape[:-1], half))) def _split_scale_setup_context(ctx, inputs, output): x, scale = inputs @@ -1615,9 +1510,7 @@ def _split_scale_backward(ctx, grad_out1, grad_out2): setup_context_fn=_split_scale_setup_context, backward_fn=_split_scale_backward, ) - def _split_scale_op( - x: torch.Tensor, scale: float - ) -> tuple[torch.Tensor, torch.Tensor]: + def _split_scale_op(x: torch.Tensor, scale: float) -> tuple[torch.Tensor, torch.Tensor]: half = x.shape[-1] // 2 return x[..., :half] * scale, x[..., half:] * scale @@ -1656,12 +1549,7 @@ def _bwd(ctx, grad_out): (_x,) = 
ctx.saved_tensors return grad_out * ctx.scale, None - @magi_register_custom_op( - name="test::dc_bwd_basic", - mutates_args=(), - setup_context_fn=_setup, - backward_fn=_bwd, - ) + @magi_register_custom_op(name="test::dc_bwd_basic", mutates_args=(), setup_context_fn=_setup, backward_fn=_bwd) def _op(x: torch.Tensor, cfg: _ScaleCfg) -> torch.Tensor: return x * cfg.scale @@ -1691,12 +1579,7 @@ def _bwd(ctx, grad_out): # must be accepted by the bridge. return None - @magi_register_custom_op( - name="test::dc_bwd_bare_grad", - mutates_args=(), - setup_context_fn=_setup, - backward_fn=_bwd, - ) + @magi_register_custom_op(name="test::dc_bwd_bare_grad", mutates_args=(), setup_context_fn=_setup, backward_fn=_bwd) def _op(cfg: _Cfg) -> torch.Tensor: return torch.full((2, 3), cfg.alpha) @@ -1725,12 +1608,7 @@ def _bwd(ctx, grad_out): # ``None`` leaves; the bridge must spread these to flat slots. return grad_out * ctx.scale, _Cfg(scale=None, offset=None) - @magi_register_custom_op( - name="test::dc_bwd_per_field", - mutates_args=(), - setup_context_fn=_setup, - backward_fn=_bwd, - ) + @magi_register_custom_op(name="test::dc_bwd_per_field", mutates_args=(), setup_context_fn=_setup, backward_fn=_bwd) def _op(x: torch.Tensor, cfg: _Cfg) -> torch.Tensor: return x * cfg.scale + cfg.offset @@ -1756,12 +1634,7 @@ def _setup(ctx, inputs, output): def _bwd(ctx, grad_out): return grad_out * ctx.scale, {"scale": None} - @magi_register_custom_op( - name="test::dc_bwd_dict_grad", - mutates_args=(), - setup_context_fn=_setup, - backward_fn=_bwd, - ) + @magi_register_custom_op(name="test::dc_bwd_dict_grad", mutates_args=(), setup_context_fn=_setup, backward_fn=_bwd) def _op(x: torch.Tensor, cfg: _Cfg) -> torch.Tensor: return x * cfg.scale @@ -1795,12 +1668,7 @@ def _setup(ctx, inputs, output): def _bwd(ctx, grad_out): return grad_out * ctx.scale, None # whole nested dc => None - @magi_register_custom_op( - name="test::dc_bwd_nested", - mutates_args=(), - setup_context_fn=_setup, - 
backward_fn=_bwd, - ) + @magi_register_custom_op(name="test::dc_bwd_nested", mutates_args=(), setup_context_fn=_setup, backward_fn=_bwd) def _op(x: torch.Tensor, cfg: _Outer) -> torch.Tensor: return x * cfg.inner.scale + cfg.inner.bias + cfg.tag @@ -1830,12 +1698,7 @@ def _bwd(ctx, grad_out): a, b = ctx.saved_tensors return grad_out * ctx.alpha, None, grad_out * ctx.beta - @magi_register_custom_op( - name="test::dc_bwd_sandwich", - mutates_args=(), - setup_context_fn=_setup, - backward_fn=_bwd, - ) + @magi_register_custom_op(name="test::dc_bwd_sandwich", mutates_args=(), setup_context_fn=_setup, backward_fn=_bwd) def _op(a: torch.Tensor, cfg: _Cfg, b: torch.Tensor) -> torch.Tensor: return a * cfg.alpha + b * cfg.beta @@ -1865,10 +1728,7 @@ def _bad_bwd(ctx, grad_out): return (grad_out,) # missing the dataclass slot @magi_register_custom_op( - name="test::dc_bwd_wrong_count", - mutates_args=(), - setup_context_fn=_setup, - backward_fn=_bad_bwd, + name="test::dc_bwd_wrong_count", mutates_args=(), setup_context_fn=_setup, backward_fn=_bad_bwd ) def _op(x: torch.Tensor, cfg: _Cfg) -> torch.Tensor: return x * cfg.scale @@ -1898,11 +1758,7 @@ def bwd(ctx, gy0, gy1): # d(out0)/dx = s, d(out1)/dx = 1 return gy0 * ctx.s + gy1, None - @magi_register_custom_op( - name="test::tuple_bwd_dc", - setup_context_fn=setup, - backward_fn=bwd, - ) + @magi_register_custom_op(name="test::tuple_bwd_dc", setup_context_fn=setup, backward_fn=bwd) def _op(x: torch.Tensor, cfg: _BwTupleCfg) -> tuple[torch.Tensor, torch.Tensor]: return x * cfg.s, x.clone() @@ -1934,11 +1790,7 @@ def bwd(ctx, gy): # x: gx = gy * w; cfg.w: gw = gy * x; cfg.b: None. 
return gy * w, {"w": gy * x, "b": None} - @magi_register_custom_op( - name="test::partial_none_grad", - setup_context_fn=setup, - backward_fn=bwd, - ) + @magi_register_custom_op(name="test::partial_none_grad", setup_context_fn=setup, backward_fn=bwd) def _op(x: torch.Tensor, cfg: _Cfg) -> torch.Tensor: return x * cfg.w + cfg.b @@ -1972,11 +1824,7 @@ def bwd(ctx, gy): x, w = ctx.saved_tensors return matmul_grad_x(gy, w), x.t() @ gy - @magi_register_custom_op( - name="test::matmul_with_op_bwd", - setup_context_fn=setup, - backward_fn=bwd, - ) + @magi_register_custom_op(name="test::matmul_with_op_bwd", setup_context_fn=setup, backward_fn=bwd) def matmul(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor: return x @ w @@ -2021,11 +1869,7 @@ def setup(ctx, inputs, output): def bwd(ctx, gy): return gy * ctx.s, None - @magi_register_custom_op( - name="test::kwonly_dc_bwd", - setup_context_fn=setup, - backward_fn=bwd, - ) + @magi_register_custom_op(name="test::kwonly_dc_bwd", setup_context_fn=setup, backward_fn=bwd) def _op(x: torch.Tensor, *, cfg: _KwCfg) -> torch.Tensor: return x * cfg.s @@ -2050,10 +1894,7 @@ def test_mutates_args_dataclass_expands(self): # We don't actually mutate (frozen dataclass) -- we just check the # registration succeeds, which it would not if the name failed to # expand to flat tensor leaves. 
- @magi_register_custom_op( - name="test::mutates_dc_expand", - mutates_args=("cfg",), - ) + @magi_register_custom_op(name="test::mutates_dc_expand", mutates_args=("cfg",)) def _op(x: torch.Tensor, cfg: _MutCfg) -> torch.Tensor: cfg.a.add_(x) # in-place on a tensor field cfg.b.add_(x) @@ -2067,20 +1908,14 @@ def _op(x: torch.Tensor, cfg: _MutCfg) -> torch.Tensor: def test_mutates_args_unknown_name_rejected(self): with pytest.raises(ValueError, match="does not match"): - @magi_register_custom_op( - name="test::mutates_dc_unknown", - mutates_args=("does_not_exist",), - ) + @magi_register_custom_op(name="test::mutates_dc_unknown", mutates_args=("does_not_exist",)) def _op(x: torch.Tensor, cfg: _MutCfg) -> torch.Tensor: return x def test_mutates_args_flat_name_passthrough(self): """Users may also use the flat name (``cfg__a``) directly.""" - @magi_register_custom_op( - name="test::mutates_dc_flat", - mutates_args=("cfg__a",), - ) + @magi_register_custom_op(name="test::mutates_dc_flat", mutates_args=("cfg__a",)) def _op(x: torch.Tensor, cfg: _MutCfg) -> torch.Tensor: cfg.a.add_(x) return cfg.a + cfg.b diff --git a/tests/api_tests/test_register_triton_op.py b/tests/api_tests/test_register_triton_op.py index af3db6d..f893ae9 100644 --- a/tests/api_tests/test_register_triton_op.py +++ b/tests/api_tests/test_register_triton_op.py @@ -53,9 +53,7 @@ from magi_compiler.api import magi_register_custom_op # noqa: E402 -pytestmark = pytest.mark.skipif( - not torch.cuda.is_available(), reason="triton kernels require CUDA" -) +pytestmark = pytest.mark.skipif(not torch.cuda.is_available(), reason="triton kernels require CUDA") # --------------------------------------------------------------------------- @@ -97,10 +95,7 @@ def _add_kernel(a_ptr, b_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr): @triton.autotune( - configs=[ - triton.Config({"BLOCK_SIZE": 128}, num_warps=4), - triton.Config({"BLOCK_SIZE": 256}, num_warps=4), - ], + configs=[triton.Config({"BLOCK_SIZE": 128}, 
num_warps=4), triton.Config({"BLOCK_SIZE": 256}, num_warps=4)], key=["n_elements"], ) @triton.jit @@ -209,15 +204,10 @@ def _heuristics_top_kernel(in_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr tl.store(out_ptr + offsets, x, mask=mask) -@triton.autotune( - configs=[triton.Config({}, num_warps=4)], - key=["n_elements"], -) +@triton.autotune(configs=[triton.Config({}, num_warps=4)], key=["n_elements"]) @triton.heuristics({"BLOCK_SIZE": lambda args: 128}) @triton.jit -def _autotune_then_heuristics_kernel( - in_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr -): +def _autotune_then_heuristics_kernel(in_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr): pid = tl.program_id(axis=0) block_start = pid * BLOCK_SIZE offsets = block_start + tl.arange(0, BLOCK_SIZE) @@ -324,14 +314,9 @@ def fn(x: torch.Tensor) -> torch.Tensor: assert_close(fn(x), torch.cos(x), atol=1e-5, rtol=1e-5) def test_introspection_walks_all_levels(self): - from torch._higher_order_ops.triton_kernel_wrap import ( - TraceableTritonKernelWrapper, - ) + from torch._higher_order_ops.triton_kernel_wrap import TraceableTritonKernelWrapper - from magi_compiler._triton_introspect import ( - get_inner_triton_kernels, - rewrite_fn_with_wrap_triton, - ) + from magi_compiler._triton_introspect import get_inner_triton_kernels, rewrite_fn_with_wrap_triton def fn(x): return _dispatch_launcher(x) @@ -342,9 +327,7 @@ def fn(x): rewritten = rewrite_fn_with_wrap_triton(fn, kernels) rebuilt_dispatch = rewritten.__globals__["_dispatch_launcher"] rebuilt_inner = rebuilt_dispatch.__globals__["_inner_launcher"] - assert isinstance( - rebuilt_inner.__globals__["_cos_kernel"], TraceableTritonKernelWrapper - ) + assert isinstance(rebuilt_inner.__globals__["_cos_kernel"], TraceableTritonKernelWrapper) # Third-party "thin wrapper" pattern: some libraries return objects with a @@ -386,10 +369,7 @@ def __init__(self, kernel): def _build_fn(self): kernel = self._kernel - @magi_register_custom_op( - 
name=f"magi_test::module_self_kernel_{id(self)}", - extra_triton_kernels=[kernel], - ) + @magi_register_custom_op(name=f"magi_test::module_self_kernel_{id(self)}", extra_triton_kernels=[kernel]) def op(x: torch.Tensor) -> torch.Tensor: out = torch.empty_like(x) n = x.numel() @@ -468,9 +448,7 @@ def fn(x: torch.Tensor) -> torch.Tensor: class TestTrueCrossModuleLauncher: def test_external_neg_launcher(self): - from tests.api_tests._triton_external_helpers import ( - external_neg_launcher, - ) + from tests.api_tests._triton_external_helpers import external_neg_launcher @magi_register_custom_op(name="magi_test::true_cross_module_neg") def fn(x: torch.Tensor) -> torch.Tensor: @@ -480,18 +458,10 @@ def fn(x: torch.Tensor) -> torch.Tensor: assert_close(fn(x), -x, atol=1e-5, rtol=1e-5) def test_rewrite_descends_into_other_module(self): - from torch._higher_order_ops.triton_kernel_wrap import ( - TraceableTritonKernelWrapper, - ) + from torch._higher_order_ops.triton_kernel_wrap import TraceableTritonKernelWrapper - from magi_compiler._triton_introspect import ( - get_inner_triton_kernels, - rewrite_fn_with_wrap_triton, - ) - from tests.api_tests._triton_external_helpers import ( - external_double_kernel, - external_double_launcher, - ) + from magi_compiler._triton_introspect import get_inner_triton_kernels, rewrite_fn_with_wrap_triton + from tests.api_tests._triton_external_helpers import external_double_kernel, external_double_launcher def fn(x): # Bare Name call so the introspector can follow it across modules @@ -514,19 +484,13 @@ def fn(x): contents = cell.cell_contents except ValueError: continue - if callable(contents) and getattr(contents, "__name__", None) == ( - "external_double_launcher" - ): + if callable(contents) and getattr(contents, "__name__", None) == ("external_double_launcher"): rebuilt_launcher = contents break assert rebuilt_launcher is not None, ( - "expected rewrite_fn_with_wrap_triton to keep the launcher in " - "the rewritten function's closure" + 
"expected rewrite_fn_with_wrap_triton to keep the launcher in " "the rewritten function's closure" ) - assert isinstance( - rebuilt_launcher.__globals__["external_double_kernel"], - TraceableTritonKernelWrapper, - ), ( + assert isinstance(rebuilt_launcher.__globals__["external_double_kernel"], TraceableTritonKernelWrapper), ( "rewrite_fn_with_wrap_triton should rebuild cross-module helpers " "so the kernel reference inside them is wrap_triton-aware." ) @@ -536,8 +500,7 @@ def fn(x): from tests.api_tests import _triton_external_helpers as ext_mod assert not isinstance( - ext_mod.external_double_launcher.__globals__["external_double_kernel"], - TraceableTritonKernelWrapper, + ext_mod.external_double_launcher.__globals__["external_double_kernel"], TraceableTritonKernelWrapper ), ( "rewrite_fn_with_wrap_triton must not mutate the helper's home " "module globals (other unrelated callers would be affected)." @@ -590,9 +553,7 @@ def fn(x: torch.Tensor) -> torch.Tensor: def test_rewrite_does_not_double_wrap(self): """Direct unit test: passing the already-wrapped kernel back through ``rewrite_fn_with_wrap_triton`` must not produce a double wrapper.""" - from torch._higher_order_ops.triton_kernel_wrap import ( - TraceableTritonKernelWrapper, - ) + from torch._higher_order_ops.triton_kernel_wrap import TraceableTritonKernelWrapper from torch.library import wrap_triton from magi_compiler._triton_introspect import rewrite_fn_with_wrap_triton @@ -638,10 +599,7 @@ def test_explicit_kernel_list(self): kernels_holder = type("KH", (), {})() kernels_holder.k = _cos_kernel - @magi_register_custom_op( - name="magi_test::cos_via_extra", - extra_triton_kernels=[_cos_kernel], - ) + @magi_register_custom_op(name="magi_test::cos_via_extra", extra_triton_kernels=[_cos_kernel]) def fn(x: torch.Tensor) -> torch.Tensor: out = torch.empty_like(x) n = x.numel() @@ -666,9 +624,7 @@ def fn(x): _cos_kernel[_grid_1d(n)](x, out, n, BLOCK_SIZE=128) return out - resolved_all, resolved_bare, 
_user_wrapped_ids = _resolve_triton_kernels( - fn, [_cos_kernel] - ) + resolved_all, resolved_bare, _user_wrapped_ids = _resolve_triton_kernels(fn, [_cos_kernel]) # Should appear exactly once even though it's both passed explicitly # and discovered by introspection. assert resolved_all.count(_cos_kernel) == 1 @@ -677,18 +633,13 @@ def fn(x): assert len(resolved_bare) == 1 rewritten = rewrite_fn_with_wrap_triton(fn, resolved_bare) - from torch._higher_order_ops.triton_kernel_wrap import ( - TraceableTritonKernelWrapper, - ) + from torch._higher_order_ops.triton_kernel_wrap import TraceableTritonKernelWrapper wrapped = rewritten.__globals__["_cos_kernel"] assert isinstance(wrapped, TraceableTritonKernelWrapper) def test_dedup_e2e(self): - @magi_register_custom_op( - name="magi_test::dedup_cos", - extra_triton_kernels=[_cos_kernel], - ) + @magi_register_custom_op(name="magi_test::dedup_cos", extra_triton_kernels=[_cos_kernel]) def fn(x: torch.Tensor) -> torch.Tensor: out = torch.empty_like(x) n = x.numel() @@ -717,10 +668,7 @@ class TestExtraTritonKernelsForStaticOrClassmethod: """ def test_staticmethod_selected_kernel(self): - @magi_register_custom_op( - name="magi_test::sm_kernel", - extra_triton_kernels=[_scale_kernel], - ) + @magi_register_custom_op(name="magi_test::sm_kernel", extra_triton_kernels=[_scale_kernel]) def myop(x: torch.Tensor) -> torch.Tensor: kernel = _KernelHolder.get_static() out = torch.empty_like(x) @@ -733,10 +681,7 @@ def myop(x: torch.Tensor) -> torch.Tensor: assert_close(out, x * 2.0) def test_classmethod_selected_kernel(self): - @magi_register_custom_op( - name="magi_test::cm_kernel", - extra_triton_kernels=[_scale_kernel], - ) + @magi_register_custom_op(name="magi_test::cm_kernel", extra_triton_kernels=[_scale_kernel]) def myop(x: torch.Tensor) -> torch.Tensor: kernel = _KernelHolder.get_class() out = torch.empty_like(x) @@ -760,10 +705,7 @@ def test_runtime_imported_kernel(self): # function still call it). 
Simulate the runtime-import case by stuffing # the kernel into a local ``import``-like alias derived from globals, # so source introspection cannot statically resolve it. - @magi_register_custom_op( - name="magi_test::runtime_import_kernel", - extra_triton_kernels=[_cos_kernel], - ) + @magi_register_custom_op(name="magi_test::runtime_import_kernel", extra_triton_kernels=[_cos_kernel]) def myop(x: torch.Tensor) -> torch.Tensor: module_globals = globals() # Indirect lookup hides the kernel from static introspection of @@ -818,14 +760,9 @@ def fn(a, b): assert _add_kernel in kernels def test_rewrite_replaces_kernel_with_wrap_triton(self): - from torch._higher_order_ops.triton_kernel_wrap import ( - TraceableTritonKernelWrapper, - ) + from torch._higher_order_ops.triton_kernel_wrap import TraceableTritonKernelWrapper - from magi_compiler._triton_introspect import ( - get_inner_triton_kernels, - rewrite_fn_with_wrap_triton, - ) + from magi_compiler._triton_introspect import get_inner_triton_kernels, rewrite_fn_with_wrap_triton def fn(x): out = torch.empty_like(x) @@ -838,23 +775,16 @@ def fn(x): # _cos_kernel name in the rewritten globals should now point to a # TraceableTritonKernelWrapper, not the bare JITFunction. - assert isinstance( - rewritten.__globals__["_cos_kernel"], TraceableTritonKernelWrapper - ) + assert isinstance(rewritten.__globals__["_cos_kernel"], TraceableTritonKernelWrapper) # Originals untouched. 
from triton.runtime.jit import JITFunction assert isinstance(_cos_kernel, JITFunction) def test_rewrite_propagates_through_helpers(self): - from torch._higher_order_ops.triton_kernel_wrap import ( - TraceableTritonKernelWrapper, - ) + from torch._higher_order_ops.triton_kernel_wrap import TraceableTritonKernelWrapper - from magi_compiler._triton_introspect import ( - get_inner_triton_kernels, - rewrite_fn_with_wrap_triton, - ) + from magi_compiler._triton_introspect import get_inner_triton_kernels, rewrite_fn_with_wrap_triton def fn(a, b): return _add_launcher(a, b) @@ -863,9 +793,7 @@ def fn(a, b): rewritten = rewrite_fn_with_wrap_triton(fn, kernels) rebuilt_launcher = rewritten.__globals__["_add_launcher"] - assert isinstance( - rebuilt_launcher.__globals__["_add_kernel"], TraceableTritonKernelWrapper - ) + assert isinstance(rebuilt_launcher.__globals__["_add_kernel"], TraceableTritonKernelWrapper) # Multi-level nesting: fn -> dispatch -> launcher -> kernel. @@ -875,10 +803,7 @@ def fn(a, b): class TestInferOutputMetaOverride: def test_meta_list_form(self): - @magi_register_custom_op( - name="magi_test::triton_meta_list", - infer_output_meta_fn=["x"], - ) + @magi_register_custom_op(name="magi_test::triton_meta_list", infer_output_meta_fn=["x"]) def fn(x: torch.Tensor) -> torch.Tensor: out = torch.empty_like(x) n = x.numel() @@ -898,10 +823,7 @@ def custom_meta(x: torch.Tensor) -> torch.Tensor: called["count"] += 1 return torch.empty_like(x) - @magi_register_custom_op( - name="magi_test::triton_meta_callable", - infer_output_meta_fn=custom_meta, - ) + @magi_register_custom_op(name="magi_test::triton_meta_callable", infer_output_meta_fn=custom_meta) def fn(x: torch.Tensor) -> torch.Tensor: out = torch.empty_like(x) n = x.numel() @@ -963,9 +885,7 @@ def test_pure_python_op_not_registered_as_triton(self): def fn(x: torch.Tensor) -> torch.Tensor: return x * 2 + 1 - assert not self._was_registered_as_triton_op( - "magi_test::registry_pure_python" - ), ( + assert not 
self._was_registered_as_triton_op("magi_test::registry_pure_python"), ( "magi_test::registry_pure_python has no triton kernels; it should " "have fallen back to the custom_op path and remain opaque to " "make_fx." @@ -1006,16 +926,11 @@ def test_two_autotune_kernels_in_same_op(self): # Build a *second* autotuned kernel locally so we can be sure both # kernel objects appear in the op's call graph. @triton.autotune( - configs=[ - triton.Config({"BLOCK_SIZE": 128}, num_warps=4), - triton.Config({"BLOCK_SIZE": 256}, num_warps=4), - ], + configs=[triton.Config({"BLOCK_SIZE": 128}, num_warps=4), triton.Config({"BLOCK_SIZE": 256}, num_warps=4)], key=["n_elements"], ) @triton.jit - def _autotuned_sin_kernel( - in_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr - ): + def _autotuned_sin_kernel(in_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr): pid = tl.program_id(axis=0) block_start = pid * BLOCK_SIZE offsets = block_start + tl.arange(0, BLOCK_SIZE) @@ -1023,10 +938,7 @@ def _autotuned_sin_kernel( x = tl.load(in_ptr + offsets, mask=mask) tl.store(out_ptr + offsets, tl.sin(x), mask=mask) - @magi_register_custom_op( - name="magi_test::two_autotune_kernels", - extra_triton_kernels=[_autotuned_sin_kernel], - ) + @magi_register_custom_op(name="magi_test::two_autotune_kernels", extra_triton_kernels=[_autotuned_sin_kernel]) def myop(x: torch.Tensor) -> torch.Tensor: n = x.numel() mid = torch.empty_like(x) @@ -1068,10 +980,7 @@ def test_top_level_heuristics_via_extra_triton_kernels_rejected(self): involved).""" with pytest.raises(RuntimeError, match="triton.heuristics"): - @magi_register_custom_op( - name="magi_test::heuristics_extra", - extra_triton_kernels=[_heuristics_top_kernel], - ) + @magi_register_custom_op(name="magi_test::heuristics_extra", extra_triton_kernels=[_heuristics_top_kernel]) def myop(x: torch.Tensor) -> torch.Tensor: # Body doesn't reference the kernel at all; rejection comes # purely from the extra_triton_kernels list. 
@@ -1108,11 +1017,7 @@ def backward(ctx, grad_out): # d/dx cos(x) = -sin(x) return grad_out * (-torch.sin(x)) - @magi_register_custom_op( - name="magi_test::triton_cos_grad", - setup_context_fn=setup_ctx, - backward_fn=backward, - ) + @magi_register_custom_op(name="magi_test::triton_cos_grad", setup_context_fn=setup_ctx, backward_fn=backward) def mycos(x: torch.Tensor) -> torch.Tensor: out = torch.empty_like(x) n = x.numel() @@ -1154,11 +1059,8 @@ def fn(x: torch.Tensor, cfg: _DcCosCfg) -> torch.Tensor: # The dataclass-aware path registers an *inner* op under the requested # name. That inner op should still be a triton_op. - assert TestTritonOpRegistryAssertion._was_registered_as_triton_op( - "magi_test::dc_cos" - ), ( - "dataclass+triton path should still register the inner op as a " - "triton_op so Inductor can see through it." + assert TestTritonOpRegistryAssertion._was_registered_as_triton_op("magi_test::dc_cos"), ( + "dataclass+triton path should still register the inner op as a " "triton_op so Inductor can see through it." 
) @@ -1180,15 +1082,11 @@ def fn(x: torch.Tensor, cfg: _DcOuterCfg) -> torch.Tensor: out = torch.empty_like(x) n = x.numel() _cos_kernel[_grid_1d(n)](x, tmp, n, BLOCK_SIZE=cfg.kernel.block_size) - _scale_kernel[_grid_1d(n)]( - tmp, out, n, cfg.scale, BLOCK_SIZE=cfg.kernel.block_size - ) + _scale_kernel[_grid_1d(n)](tmp, out, n, cfg.scale, BLOCK_SIZE=cfg.kernel.block_size) return out + cfg.kernel.extra_offset x = torch.randn(1024, device="cuda") - cfg = _DcOuterCfg( - kernel=_DcKernelCfg(block_size=128, extra_offset=0.5), scale=2.5 - ) + cfg = _DcOuterCfg(kernel=_DcKernelCfg(block_size=128, extra_offset=0.5), scale=2.5) out = fn(x, cfg) expected = torch.cos(x) * 2.5 + 0.5 assert_close(out, expected, atol=1e-5, rtol=1e-5) @@ -1207,16 +1105,10 @@ def _collect(node): _collect(child) _collect(cfg_node) - assert { - "cfg__kernel__block_size", - "cfg__kernel__extra_offset", - "cfg__scale", - }.issubset(flat_names) + assert {"cfg__kernel__block_size", "cfg__kernel__extra_offset", "cfg__scale"}.issubset(flat_names) # And: the registered op should still go through triton_op. - assert TestTritonOpRegistryAssertion._was_registered_as_triton_op( - "magi_test::nested_dc_cos_scale" - ), ( + assert TestTritonOpRegistryAssertion._was_registered_as_triton_op("magi_test::nested_dc_cos_scale"), ( "nested-dataclass + triton path should still register the inner " "op as a triton_op so Inductor can see through it." 
) @@ -1228,10 +1120,7 @@ def test_nested_dc_with_triton_and_meta_fn(self): def _meta(x: torch.Tensor, cfg: _DcProjCfg) -> torch.Tensor: return x.new_empty((*x.shape[:-1], cfg.shape.out_dim)) - @magi_register_custom_op( - name="magi_test::nested_dc_cos_proj", - infer_output_meta_fn=_meta, - ) + @magi_register_custom_op(name="magi_test::nested_dc_cos_proj", infer_output_meta_fn=_meta) def fn(x: torch.Tensor, cfg: _DcProjCfg) -> torch.Tensor: sliced = x[..., : cfg.shape.out_dim].contiguous() out = torch.empty_like(sliced) @@ -1268,11 +1157,7 @@ def _bwd(ctx, grad_out): (x,) = ctx.saved_tensors return grad_out * (-torch.sin(x)), None - @magi_register_custom_op( - name="magi_test::dc_cos_grad", - setup_context_fn=_setup, - backward_fn=_bwd, - ) + @magi_register_custom_op(name="magi_test::dc_cos_grad", setup_context_fn=_setup, backward_fn=_bwd) def mycos(x: torch.Tensor, cfg: _DcCosCfg) -> torch.Tensor: out = torch.empty_like(x) n = x.numel() @@ -1287,9 +1172,7 @@ def mycos(x: torch.Tensor, cfg: _DcCosCfg) -> torch.Tensor: # Sanity: this op should still have gone through the triton_op path # (the dataclass-aware path registers an inner op under ``op_name``). - assert TestTritonOpRegistryAssertion._was_registered_as_triton_op( - "magi_test::dc_cos_grad" - ) + assert TestTritonOpRegistryAssertion._was_registered_as_triton_op("magi_test::dc_cos_grad") def test_triton_nested_dc_backward(self): """Nested dataclass + triton + backward. 
The bridge must spread the @@ -1308,25 +1191,17 @@ def _bwd(ctx, grad_out): # d/dx (cos(x) * scale + offset) = -sin(x) * scale return grad_out * (-torch.sin(x)) * ctx.scale, None - @magi_register_custom_op( - name="magi_test::nested_dc_cos_grad", - setup_context_fn=_setup, - backward_fn=_bwd, - ) + @magi_register_custom_op(name="magi_test::nested_dc_cos_grad", setup_context_fn=_setup, backward_fn=_bwd) def fn(x: torch.Tensor, cfg: _DcOuterCfg) -> torch.Tensor: tmp = torch.empty_like(x) out = torch.empty_like(x) n = x.numel() _cos_kernel[_grid_1d(n)](x, tmp, n, BLOCK_SIZE=cfg.kernel.block_size) - _scale_kernel[_grid_1d(n)]( - tmp, out, n, cfg.scale, BLOCK_SIZE=cfg.kernel.block_size - ) + _scale_kernel[_grid_1d(n)](tmp, out, n, cfg.scale, BLOCK_SIZE=cfg.kernel.block_size) return out + cfg.kernel.extra_offset x = torch.randn(1024, device="cuda", requires_grad=True) - cfg = _DcOuterCfg( - kernel=_DcKernelCfg(block_size=128, extra_offset=0.5), scale=2.5 - ) + cfg = _DcOuterCfg(kernel=_DcKernelCfg(block_size=128, extra_offset=0.5), scale=2.5) out = fn(x, cfg) out.sum().backward() expected = -torch.sin(x.detach()) * 2.5 @@ -1343,16 +1218,9 @@ def _setup(ctx, inputs, output): def _bwd(ctx, grad_out): (x,) = ctx.saved_tensors - return ( - grad_out * (-torch.sin(x)), - _DcCosCfg(block_size=None), - ) + return (grad_out * (-torch.sin(x)), _DcCosCfg(block_size=None)) - @magi_register_custom_op( - name="magi_test::dc_cos_per_field_grad", - setup_context_fn=_setup, - backward_fn=_bwd, - ) + @magi_register_custom_op(name="magi_test::dc_cos_per_field_grad", setup_context_fn=_setup, backward_fn=_bwd) def mycos(x: torch.Tensor, cfg: _DcCosCfg) -> torch.Tensor: out = torch.empty_like(x) n = x.numel() @@ -1381,11 +1249,7 @@ def _bwd(ctx, grad_out): # _DcCosCfg(block_size=None) instance. 
return grad_out * (-torch.sin(x)), {"block_size": None} - @magi_register_custom_op( - name="magi_test::dc_cos_dict_grad", - setup_context_fn=_setup, - backward_fn=_bwd, - ) + @magi_register_custom_op(name="magi_test::dc_cos_dict_grad", setup_context_fn=_setup, backward_fn=_bwd) def mycos(x: torch.Tensor, cfg: _DcCosCfg) -> torch.Tensor: out = torch.empty_like(x) n = x.numel() @@ -1412,10 +1276,7 @@ class TestDataclassTritonComputeSensitiveSmoke: def test_dataclass_triton_compute_sensitive(self): from magi_compiler.config import get_compile_config - @magi_register_custom_op( - name="magi_test::dc_triton_cs", - is_compute_sensitive=True, - ) + @magi_register_custom_op(name="magi_test::dc_triton_cs", is_compute_sensitive=True) def myop(x: torch.Tensor, cfg: _DcCosCfg) -> torch.Tensor: out = torch.empty_like(x) n = x.numel() @@ -1425,10 +1286,7 @@ def myop(x: torch.Tensor, cfg: _DcCosCfg) -> torch.Tensor: x = torch.randn(256, device="cuda") out = myop(x, _DcCosCfg(block_size=128)) assert_close(out, torch.cos(x)) - assert ( - "magi_test::dc_triton_cs" - in get_compile_config().recompute_config.custom_compute_sensitive_ops - ) + assert "magi_test::dc_triton_cs" in get_compile_config().recompute_config.custom_compute_sensitive_ops class TestInductorSeesTritonKernel: @@ -1468,10 +1326,7 @@ def _capture(gm, _example_inputs): assert_close(out, torch.cos(x), atol=1e-5, rtol=1e-5) joined = "\n".join(captured_graphs) - assert ( - "triton_kernel_wrapper_functional" in joined - or "triton_kernel_wrapper_mutation" in joined - ), ( + assert "triton_kernel_wrapper_functional" in joined or "triton_kernel_wrapper_mutation" in joined, ( "AOT graph did not decompose magi_test::inductor_visible_cos " "into the triton_kernel_wrapper HOP; Inductor will treat it " "as opaque. Captured AOT graph:\n" + joined