
[Feat] Support dataclass in magi_register_custom_op #32

Open

themistbeforedawn wants to merge 1 commit into SandAI-org:main from themistbeforedawn:feat/magi-register-op-with-dataclass-input

Conversation

@themistbeforedawn (Collaborator)

🗂️ PR Category

  • ✨ New Feature
  • 🚀 Optimization (performance, memory, etc.)
  • 💥 Breaking Change
  • 🐛 Bug Fix
  • 🛠️ Development / Refactoring
  • 📚 Documentation
  • 🧹 Chore (Dependencies, CI/CD, Configuration, etc.)
  • 🧪 Testing

📝 Description

What's new

@magi_register_custom_op now accepts frozen-dataclass parameters (including recursively nested ones), so users can group config / flags into a single @dataclass(frozen=True) while torch.library's schema continues to see only primitives:

import dataclasses

import torch

# Import path assumed for illustration; use the project's actual public entry point.
from magi_compiler import magi_register_custom_op

@dataclasses.dataclass(frozen=True)
class AttnCfg:
    scale: float
    causal: bool = False

@magi_register_custom_op()
def attn(q: torch.Tensor, k: torch.Tensor, cfg: AttnCfg) -> torch.Tensor:
    ...
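Calling the op then looks like calling the original function; a minimal usage sketch (shapes and the scale value are made up):

q = torch.randn(2, 8, 16)
k = torch.randn(2, 8, 16)
out = attn(q, k, cfg=AttnCfg(scale=0.125))                      # default causal=False
out_causal = attn(q, k, cfg=AttnCfg(scale=0.125, causal=True))  # cfg is flattened/unflattened behind the scenes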

The same lower-signature pass also handles, transparently to users (a short sketch follows this list):

  • Literal[...] / string-Enum annotations → auto-downgraded to str
  • Unsupported defaults (mutable, dataclass instance, …) are scrubbed from the lowered signature only; user-facing defaults are preserved
  • mutates_args accepts either the dataclass-level name (which expands to all Tensor leaves) or any lowered leaf name
  • backward_fn returns one grad per original parameter (not per leaf); a whole non-differentiable dataclass arg collapses to a single None
  • Strict signature validation: variadic args and missing annotations are rejected up-front with actionable errors
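A sketch of how these rules combine in practice; mutates_args is assumed here to be a decorator keyword (as in torch.library.custom_op), and how a backward_fn gets attached is left out:

from typing import Literal

@dataclasses.dataclass(frozen=True)
class NormCfg:
    eps: float
    weight: torch.Tensor            # a Tensor leaf inside the dataclass

# Passing the dataclass-level name "cfg" to mutates_args expands to all of its
# Tensor leaves (here: cfg.weight); a lowered leaf name would also be accepted.
@magi_register_custom_op(mutates_args=("cfg",))
def norm(x: torch.Tensor, cfg: NormCfg, mode: Literal["a", "b"] = "a") -> torch.Tensor:
    ...                             # `mode` is downgraded to str in the lowered schema

# A backward_fn returns one grad per ORIGINAL parameter of norm, e.g.:
def norm_backward(ctx, grad_out):
    return grad_out, None, None     # the whole dataclass arg collapses to a single None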

Architecture — 4-slot pipeline

Each registration owns up to 4 named objects:

Slot  Object                Created by     Presence
0     fn                    user source    Always
1     lowered_fn            this PR        only when the signature needs lowering
2     torch_registered_op   torch.library  Always
3     magi_exposed_op       this PR        only when dataclass flattening is needed

The naming is deliberately dual: torch_registered_op is registered into torch.library's dispatcher; magi_exposed_op is exposed out of Magi to the user.
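Purely as an illustration of the table, a hypothetical record holding the four slots (the field names mirror the table; the PR's actual internal representation may differ):

import dataclasses
from typing import Callable, Optional

@dataclasses.dataclass
class RegistrationSlots:                    # hypothetical container, mirroring the table above
    fn: Callable                            # slot 0: user source, always present
    lowered_fn: Optional[Callable]          # slot 1: only when the signature needs lowering
    torch_registered_op: object             # slot 2: torch._ops.OpOverload, always present
    magi_exposed_op: Optional[Callable]     # slot 3: only when dataclass flattening is needed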

Architecture — 3 runtime paths

The slot set produced at registration time selects one of three runtime paths:

  1. simple  fn → torch_registered_op — zero per-call overhead; returns the OpOverload directly
  2. sig-only-rewrite  fn → lowered_fn → torch_registered_op — e.g. Literal downgrade only
  3. dataclass-flatten  fn → lowered_fn → torch_registered_op → magi_exposed_op — the wrapper flattens / unflattens on every call; the underlying OpOverload is accessible via op._magi_torch_registered_op
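From the caller's side the three paths differ only in what the decorator returned; a small sketch, using the attribute name given above (everything else illustrative):

op = attn  # whatever the decorator returned for some registration

if hasattr(op, "_magi_torch_registered_op"):
    # Path 3 (dataclass-flatten): op is a flatten/unflatten wrapper,
    # but the raw OpOverload stays reachable for advanced use.
    raw_overload = op._magi_torch_registered_op
else:
    # Paths 1 and 2: op already is the torch._ops.OpOverload.
    raw_overload = op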

magi_compiler/_magi_register_custom_op.py is laid out 1:1 against this model — 8 numbered sections grouped into registration-time helpers · runtime helpers · main pipeline.

Tests

83 new tests in tests/api_tests/test_register_custom_op.py cover all three runtime paths, autograd bridging through dataclass inputs, nested dataclasses, Optional / Literal / Enum / dtype / device fields, torch.compile integration, and full error-path coverage.
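Roughly the shape such a test takes (illustrative only; it assumes the attn op above is registered with a real implementation rather than ...):

def test_compile_with_dataclass_cfg():
    cfg = AttnCfg(scale=0.125, causal=True)
    q, k = torch.randn(2, 4, 8), torch.randn(2, 4, 8)
    eager = attn(q, k, cfg=cfg)
    compiled = torch.compile(lambda a, b: attn(a, b, cfg=cfg))(q, k)
    torch.testing.assert_close(eager, compiled)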

Comment on lines +58 to +95
Part B. Runtime paths -- the three pipelines
============================================

Three pipelines are possible; the decorator returns whichever object sits
at the end of the path:

1. simple            fn -> torch_registered_op
   Returned: ``torch._ops.OpOverload`` (slot 2).
   Runtime:  zero magi-level overhead -- straight into torch.library's
             dispatcher.

2. sig-only-rewrite  fn -> lowered_fn -> torch_registered_op
   Returned: ``torch._ops.OpOverload`` (slot 2).
   Runtime:  same as simple -- ``lowered_fn`` is a transparent
             forwarding shim (the rewrite is registration-time only).

3. dataclass-flatten fn -> lowered_fn -> torch_registered_op
                        -> magi_exposed_op
   Returned: a Python callable carrying the
             ``_magi_torch_registered_op`` attribute (slot 3).
   Runtime forward (per call):
       user code calls magi_exposed_op(x, cfg=...)
       -> _flatten_call_args          (original kwargs -> flat tuple)
       -> _flatten_value_into         (DFS over param_mapping_tree)
       -> torch_registered_op(*flat)  (slot 2 -- enters dispatcher)
       -> lowered_fn(*flat)           (slot 1 -- still in lowered shape)
       -> _reassemble_kwargs          (flat tuple -> original kwargs)
       -> _build_value_from_node      (rebuilds dataclass instances)
       -> fn(**original_kwargs)       (slot 0 -- user code finally sees
                                       its original dataclass-bearing
                                       signature)
   Runtime backward (when backward_fn is supplied):
       autograd calls _bridged_backward(ctx, *grads)
       -> user_backward(ctx, *grads)  (returns one grad per ORIGINAL
                                       input, possibly a dataclass-shaped
                                       grad object)
       -> _flatten_grads              (original grads -> flat grads)
       -> _flatten_grad_into          (DFS over param_mapping_tree)
Collaborator


Make these comments self-explanatory through the code itself instead of writing docs.

Or move them into a markdown document that clarifies the design of _register_custom_op.

Either is acceptable to me haha~

Comment on lines +546 to 579
"""Lower ``fn``'s signature into a form ``torch.library.infer_schema`` accepts.

"Lower" is used in the compiler sense (high-level -> low-level): we walk
``fn``'s parameters once and do six things at the same time -- they all
need the same resolved annotations and the same iteration:

1. VALIDATE -- reject variadics, missing annotations, mutable dataclasses,
unsupported containers, dataclass returns (sec 1).
2. RESOLVE -- turn stringified annotations into real types via
``_resolve_annotations``, so dataclass detection works.
3. NORMALIZE -- collapse parameter kinds to POSITIONAL_OR_KEYWORD,
downgrade Literal/Enum to ``str``, scrub unsupported defaults.
4. FLATTEN -- expand each frozen-dataclass parameter (recursively) into
its primitive leaves via ``_build_dataclass_sub_mapping_tree``.
5. PYTREE -- side effect of step 4: register every dataclass as a pytree
node so Dynamo / AOTAutograd can trace through it.
6. EMIT -- assemble ``(original_sig, lowered_sig, param_mapping_tree)``.

A single pass is intentional: splitting concerns would force re-resolving
annotations and threading accumulator state. When the input is already
schema-compatible the lowered signature is bit-identical to the original,
and the caller's ``_signatures_differ`` check restores the zero-overhead path.

Returns:
- 1 if the return type is a single Tensor
- N if the return type is tuple[Tensor, Tensor, ...] with N elements
- 1 if no annotation or unrecognized annotation (default to single output)
original_sig (inspect.Signature): the user's un-flattened signature.
lowered_sig (inspect.Signature): what ``infer_schema`` will see.
param_mapping_tree (list[tuple]): the bridge between the two; a list
of root nodes (one per original parameter), each of which is:
* ``("primitive", attr_name, lowered_name, None)``, or
* ``("dataclass", attr_name, cls, [child_nodes...])``.
``attr_name`` is the parameter name at top level / field name
deeper down. The same tree drives both runtime translation
directions (sec 7).
"""
Collaborator


These comments confuse me. Use human-readable sentences instead of AI-explained comments.

It takes me a long time to understand what these comments say. Provide simplified comments plus code examples. I guess you are trying to say the return value follows these rules; just paste the examples below:

original_sig: (q: Tensor, cfg: AttnCfg, mode: Literal['a','b']='a') -> Tensor
lowered_sig: (q: Tensor, cfg__scale: float, cfg__causal: bool, mode: str='a') -> Tensor
param_mapping_tree ≈
[
    ("primitive", "q",    "q",    None),
    ("dataclass", "cfg",  AttnCfg, [
        ("primitive", "scale",  "cfg__scale",  None),
        ("primitive", "causal", "cfg__causal", None),
    ]),
    ("primitive", "mode", "mode", None),
]
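A minimal, self-contained sketch of how such a tree can drive both translation directions (the node shapes follow the example above; the helper names are made up for illustration and are not the PR's):

import dataclasses

@dataclasses.dataclass(frozen=True)
class AttnCfg:
    scale: float
    causal: bool = False

def flatten_value(node, value, out):
    # DFS over one mapping-tree node: original value -> flat primitives.
    kind, name, meta, children = node
    if kind == "primitive":
        out.append(value)
    else:  # "dataclass" node: recurse into each field in tree order
        for child in children:
            flatten_value(child, getattr(value, child[1]), out)

def build_value(node, flat, pos):
    # Inverse DFS: flat primitives -> original value; returns (value, next position).
    kind, name, meta, children = node
    if kind == "primitive":
        return flat[pos], pos + 1
    fields = {}
    for child in children:
        fields[child[1]], pos = build_value(child, flat, pos)
    return meta(**fields), pos  # meta holds the dataclass type for "dataclass" nodes

cfg_node = ("dataclass", "cfg", AttnCfg, [
    ("primitive", "scale", "cfg__scale", None),
    ("primitive", "causal", "cfg__causal", None),
])
flat = []
flatten_value(cfg_node, AttnCfg(scale=0.125, causal=True), flat)   # flat == [0.125, True]
rebuilt, _ = build_value(cfg_node, flat, 0)                        # AttnCfg(scale=0.125, causal=True)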
