From dcc02f0c693677e417c7051d4360ea22d23428ad Mon Sep 17 00:00:00 2001 From: wtr Date: Sat, 11 Apr 2026 14:21:04 +0800 Subject: [PATCH 1/7] add triton matmul fusion --- magi_compiler/magi_backend/magi_backend.py | 5 +- .../passes/full_graph/full_graph_pass_mgr.py | 2 + .../passes/full_graph/remove_useless_ops.py | 117 ++++ .../passes/piecewise_graph/fusion/__init__.py | 13 + .../fusion/matmul_epilogue_fusion.py | 443 +++++++++++++ .../piecewise_graph/fusion/triton_kernels.py | 582 ++++++++++++++++++ .../piecewise_graph/post_grad_pass_manager.py | 2 + .../test_matmul_epilogue_fusion.py | 199 ++++++ 8 files changed, 1362 insertions(+), 1 deletion(-) create mode 100644 magi_compiler/passes/full_graph/remove_useless_ops.py create mode 100644 magi_compiler/passes/piecewise_graph/fusion/__init__.py create mode 100644 magi_compiler/passes/piecewise_graph/fusion/matmul_epilogue_fusion.py create mode 100644 magi_compiler/passes/piecewise_graph/fusion/triton_kernels.py create mode 100644 tests/feature_tests/test_matmul_epilogue_fusion.py diff --git a/magi_compiler/magi_backend/magi_backend.py b/magi_compiler/magi_backend/magi_backend.py index 0d010e3..7bafdf5 100644 --- a/magi_compiler/magi_backend/magi_backend.py +++ b/magi_compiler/magi_backend/magi_backend.py @@ -591,7 +591,7 @@ def _split_graph(self, graph: fx.GraphModule) -> tuple[fx.GraphModule, list[Spli # Step 5: visualize the split graph if envs.MAGI_ENABLE_FX_GRAPH_VIZ: - save_fx_graph_visualization(split_gm.graph, sub_dir="after_split", filename="split_gm_root") + # save_fx_graph_visualization(split_gm.graph, sub_dir="after_split", filename="split_gm_root") for item in piecewise_graphs: save_fx_graph_visualization(item.graph.graph, sub_dir="after_split", filename=item.submod_name) @@ -605,6 +605,9 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> MagiSerializableFun self._init_cache() + # if envs.MAGI_ENABLE_FX_GRAPH_VIZ: + # save_fx_graph_visualization(graph, sub_dir="before_split", filename="gm_root") + self.full_graph_pass_manager(graph) split_gm, piecewise_graphs = self._split_graph(graph) diff --git a/magi_compiler/passes/full_graph/full_graph_pass_mgr.py b/magi_compiler/passes/full_graph/full_graph_pass_mgr.py index 502d190..0626350 100644 --- a/magi_compiler/passes/full_graph/full_graph_pass_mgr.py +++ b/magi_compiler/passes/full_graph/full_graph_pass_mgr.py @@ -16,6 +16,7 @@ from ...magi_depyf.timeline import observe_lifecycle from .remove_item import RemoveItemPass +from .remove_useless_ops import RemoveUselessOpsPass from .replace_sage_atten import ReplaceSageAttentionPass @@ -30,6 +31,7 @@ def __init__(self, pass_config): if self.pass_config.enable_sage_attn: self.passes.append(ReplaceSageAttentionPass()) self.passes.append(RemoveItemPass()) + self.passes.append(RemoveUselessOpsPass()) @observe_lifecycle("full_graph_manager") def __call__(self, gm: torch.fx.GraphModule): diff --git a/magi_compiler/passes/full_graph/remove_useless_ops.py b/magi_compiler/passes/full_graph/remove_useless_ops.py new file mode 100644 index 0000000..a31acc5 --- /dev/null +++ b/magi_compiler/passes/full_graph/remove_useless_ops.py @@ -0,0 +1,117 @@ +# Copyright (c) 2026 SandAI. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch._inductor.fx_passes.pre_grad + +from ...magi_depyf.timeline import emit_pass_lifecycle +from ..pass_base import MagiInductorPass + + +class RemoveUselessOpsPass(MagiInductorPass): + """ + Remove useless convert, view, reshape operations. + When their input already has the target type and shape, these operations are redundant. + """ + + TARGET_METHODS = { + "view", + "reshape", + "to", + "type", + "contiguous", + "clone", + "flatten", + "permute", + "transpose", + "t", + "unsqueeze", + "squeeze", + "expand", + "repeat", + "bfloat16", + "float", + "half", + "int", + "long", + "short", + "double", + "bool", + "byte", + } + + @staticmethod + def _get_tensor_info(node: torch.fx.Node): + # Get tensor info from example_value + if "example_value" in node.meta: + val = node.meta["example_value"] + if isinstance(val, torch.Tensor): + return val.shape, val.dtype, val.stride() + elif isinstance(val, (list, tuple)) and len(val) > 0 and isinstance(val[0], torch.Tensor): + return val[0].shape, val[0].dtype, val[0].stride() + + return None, None, None + + def is_applicable(self, graph: torch.fx.Graph, shape: int | None = None) -> bool: + for node in graph.nodes: + if node.op == "call_method" and node.target in self.TARGET_METHODS: + return True + return False + + @emit_pass_lifecycle + def __call__(self, graph: torch.fx.Graph): + nodes_to_remove = [] + + for node in graph.nodes: + is_target_method = node.op == "call_method" and node.target in self.TARGET_METHODS + if not is_target_method: + continue + + # Need at least one argument (the input tensor) + if not node.args or not isinstance(node.args[0], torch.fx.Node): + continue + + input_node = node.args[0] + + node_shape, node_dtype, node_stride = self._get_tensor_info(node) + input_shape, input_dtype, input_stride = self._get_tensor_info(input_node) + if node_shape is None or input_shape is None: + continue + if node_dtype is None or input_dtype is None: + continue + # Some ops or metadata might not have stride properly captured, + # but if they do, we should require them to match to be totally safe against contiguous-forcing ops. + if node_stride is not None and input_stride is not None and node_stride != input_stride: + continue + + # Check if shape and dtype match exactly + if node_shape == input_shape and node_dtype == input_dtype: + # For _to_copy, ensure we are not changing memory format or device or other properties implicitly, + # but typically in full graph if shape and dtype match, and it's on the same device, it's safe. + # Let's also check device just in case if it's available. + def get_device(n): + if "example_value" in n.meta and isinstance(n.meta["example_value"], torch.Tensor): + return n.meta["example_value"].device + + node_device = get_device(node) + input_device = get_device(input_node) + if node_device is not None and input_device is not None and node_device != input_device: + continue + + # Replace uses + node.replace_all_uses_with(input_node) + nodes_to_remove.append(node) + + for node in nodes_to_remove: + graph.erase_node(node) diff --git a/magi_compiler/passes/piecewise_graph/fusion/__init__.py b/magi_compiler/passes/piecewise_graph/fusion/__init__.py new file mode 100644 index 0000000..3eaa44a --- /dev/null +++ b/magi_compiler/passes/piecewise_graph/fusion/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2026 SandAI. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/magi_compiler/passes/piecewise_graph/fusion/matmul_epilogue_fusion.py b/magi_compiler/passes/piecewise_graph/fusion/matmul_epilogue_fusion.py new file mode 100644 index 0000000..ecc271f --- /dev/null +++ b/magi_compiler/passes/piecewise_graph/fusion/matmul_epilogue_fusion.py @@ -0,0 +1,443 @@ +# Copyright (c) 2026 SandAI. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import operator + +import torch +import torch.fx as fx +from torch.fx.node import Node + +from magi_compiler.passes.pass_base import MagiInductorPass + +from .triton_kernels import matmul_custom_epilogue + +_LIB = torch.library.Library("magi_epilogue", "DEF") +_LIB.define("matmul_custom(Tensor A, Tensor B, Tensor[] extras, str epilogue_code, bool reduce_n_by_2) -> Tensor") + + +@torch.library.impl(_LIB, "matmul_custom", "CUDA") +def _matmul_custom_cuda(A, B, extras, epilogue_code, reduce_n_by_2): + return matmul_custom_epilogue(A, B, extras, epilogue_code, reduce_n_by_2) + + +@torch.library.register_fake("magi_epilogue::matmul_custom") +def _matmul_custom_abstract(A, B, extras, epilogue_code, reduce_n_by_2): + N_out = B.shape[1] // 2 if reduce_n_by_2 else B.shape[1] + # Mirror the 128-byte-aligned row stride used by the real kernel so that + # Inductor's assert_size_stride matches what we actually return. + # Keep the logical shape as (M, N_out) — changing it would interfere with + # Inductor's own K-dimension padding for the downstream mm. + align_elems = 128 // A.element_size() + N_stride = (N_out + align_elems - 1) // align_elems * align_elems + return A.new_empty_strided((A.shape[0], N_out), (N_stride, 1)) + + +# ── Triton expression templates ──────────────────────────────────────────────── +# Unary elementwise ops: {x} = operand expression string +_UNARY_EXPRS = { + # Arithmetic + torch.ops.aten.neg.default: "-({x})", + torch.ops.aten.abs.default: "tl.abs({x})", + torch.ops.aten.sign.default: "tl.math.sign({x})", + torch.ops.aten.reciprocal.default: "1.0 / ({x})", + torch.ops.aten.square.default: "({x}) * ({x})", + # Exponential / logarithm + torch.ops.aten.exp.default: "tl.exp({x})", + torch.ops.aten.exp2.default: "tl.exp2({x})", + torch.ops.aten.expm1.default: "tl.exp({x}) - 1.0", + torch.ops.aten.log.default: "tl.log({x})", + torch.ops.aten.log2.default: "tl.log2({x})", + torch.ops.aten.log10.default: "tl.log({x}) * 0.4342944819032518", + torch.ops.aten.log1p.default: "tl.log(1.0 + ({x}))", + # Square-root family + torch.ops.aten.sqrt.default: "tl.sqrt({x})", + torch.ops.aten.rsqrt.default: "1.0 / tl.sqrt({x})", + # Trigonometric + torch.ops.aten.sin.default: "tl.sin({x})", + torch.ops.aten.cos.default: "tl.cos({x})", + torch.ops.aten.tan.default: "tl.math.tan({x})", + torch.ops.aten.asin.default: "tl.math.asin({x})", + torch.ops.aten.acos.default: "tl.math.acos({x})", + torch.ops.aten.atan.default: "tl.math.atan({x})", + # Hyperbolic + torch.ops.aten.tanh.default: "tl.tanh({x})", + torch.ops.aten.sinh.default: "tl.math.sinh({x})", + torch.ops.aten.cosh.default: "tl.math.cosh({x})", + # Activations + torch.ops.aten.sigmoid.default: "tl.sigmoid({x})", + torch.ops.aten.relu.default: "tl.maximum({x}, 0.0)", + # Error function + torch.ops.aten.erf.default: "tl.math.erf({x})", + torch.ops.aten.erfinv.default: "tl.math.erfinv({x})", + torch.ops.aten.erfc.default: "tl.math.erfc({x})", + # Rounding + torch.ops.aten.floor.default: "tl.math.floor({x})", + torch.ops.aten.ceil.default: "tl.math.ceil({x})", + torch.ops.aten.trunc.default: "tl.math.trunc({x})", + torch.ops.aten.round.default: "tl.math.round({x})", + torch.ops.aten.frac.default: "({x}) - tl.math.trunc({x})", + # Bitwise / logical + torch.ops.aten.logical_not.default: "~({x})", + torch.ops.aten.bitwise_not.default: "~({x})", + # Predicates + torch.ops.aten.isnan.default: "tl.math.isnan({x})", + torch.ops.aten.isinf.default: "tl.math.isinf({x})", + torch.ops.aten.isfinite.default: "~tl.math.isinf({x}) & ~tl.math.isnan({x})", +} + +# Binary elementwise ops: {x} = left, {y} = right +_BINARY_EXPRS = { + # Addition / subtraction (alpha handled separately) + torch.ops.aten.add.Tensor: "({x}) + ({y})", + torch.ops.aten.add.Scalar: "({x}) + ({y})", + operator.add: "({x}) + ({y})", + torch.ops.aten.sub.Tensor: "({x}) - ({y})", + torch.ops.aten.sub.Scalar: "({x}) - ({y})", + operator.sub: "({x}) - ({y})", + # Multiplication / division + torch.ops.aten.mul.Tensor: "({x}) * ({y})", + torch.ops.aten.mul.Scalar: "({x}) * ({y})", + operator.mul: "({x}) * ({y})", + torch.ops.aten.div.Tensor: "({x}) / ({y})", + torch.ops.aten.div.Scalar: "({x}) / ({y})", + operator.truediv: "({x}) / ({y})", + torch.ops.aten.remainder.Tensor: "({x}) % ({y})", + torch.ops.aten.remainder.Scalar: "({x}) % ({y})", + operator.mod: "({x}) % ({y})", + # Min / max + torch.ops.aten.maximum.default: "tl.maximum({x}, {y})", + torch.ops.aten.minimum.default: "tl.minimum({x}, {y})", + # Trigonometric binary + torch.ops.aten.atan2.default: "tl.math.atan2({x}, {y})", + # Bitwise / logical binary + torch.ops.aten.bitwise_and.Tensor: "({x}) & ({y})", + torch.ops.aten.bitwise_and.Scalar: "({x}) & ({y})", + operator.and_: "({x}) & ({y})", + torch.ops.aten.bitwise_or.Tensor: "({x}) | ({y})", + torch.ops.aten.bitwise_or.Scalar: "({x}) | ({y})", + operator.or_: "({x}) | ({y})", + torch.ops.aten.bitwise_xor.Tensor: "({x}) ^ ({y})", + torch.ops.aten.bitwise_xor.Scalar: "({x}) ^ ({y})", + operator.xor: "({x}) ^ ({y})", + torch.ops.aten.logical_and.default: "({x}) & ({y})", + torch.ops.aten.logical_or.default: "({x}) | ({y})", + torch.ops.aten.logical_xor.default: "({x}) ^ ({y})", +} + +# Ops that pass through without any value transformation +_PASSTHROUGH_OPS = frozenset( + { + torch.ops.prims.convert_element_type.default, + torch.ops.aten._to_copy.default, + torch.ops.aten.clone.default, + torch.ops.aten.contiguous.default, + torch.ops.aten.alias.default, + } +) + + +def _get_static_dims(mm_node: fx.Node) -> dict: + """Return {name: value} for mm dimensions that are compile-time-constant. + + FX shapes carry plain Python ``int`` for static dims and ``torch.SymInt`` + for symbolic (dynamic) ones. ``type(d) is int`` excludes SymInt even in + PyTorch versions where SymInt happens to subclass int. + """ + static: dict = {} + A, B = mm_node.args + try: + val_a = A.meta.get("val") if isinstance(A, fx.Node) else None + if val_a is not None and val_a.dim() == 2: + for name, idx in (("M", 0), ("K", 1)): + d = val_a.shape[idx] + if type(d) is int: + static[name] = d + val_b = B.meta.get("val") if isinstance(B, fx.Node) else None + if val_b is not None and val_b.dim() == 2: + d = val_b.shape[1] + if type(d) is int: + static["N"] = d + except Exception: + pass + return static + + +class MatmulCustomEpilogueFusionPass(MagiInductorPass): + def __call__(self, graph: fx.Graph) -> bool: + fused = 0 + for node in list(graph.nodes): + if node.op == "call_function" and node.target in (torch.ops.aten.mm.default, torch.ops.aten.mm): + fused += self._try_fuse_custom_chain(graph, node) + + if fused: + graph.eliminate_dead_code() + return fused > 0 + + def _try_fuse_custom_chain(self, graph: fx.Graph, mm_node: fx.Node) -> int: + A, B = mm_node.args + + fused_nodes = {mm_node: "acc"} + nodes_to_remove = [] + epilogue_lines = [] + extras = [] + is_swiglu = False + + def get_val(arg): + if isinstance(arg, Node): + if arg in fused_nodes: + return fused_nodes[arg] + # External tensor — inject a load + idx = len(extras) + extras.append(arg) + name = f"ext_{idx}" + val = arg.meta.get("val") + if val is not None and val.dim() == 1: + epilogue_lines.append(f"{name}_ptrs = Extra_{idx}_ptr + offs_dn[None, :]") + epilogue_lines.append(f"{name} = tl.load({name}_ptrs, mask=offs_dn[None, :] < N, other=0.0)") + else: + epilogue_lines.append( + f"{name}_ptrs = Extra_{idx}_ptr + stride_dm * offs_dm[:, None] + stride_dn * offs_dn[None, :]" + ) + epilogue_lines.append(f"{name} = tl.load({name}_ptrs, mask=mask, other=0.0)") + fused_nodes[arg] = name + return name + return str(arg) + + curr = mm_node.next + last_fused_node = mm_node + + while curr.op != "output": + uses_fused = any(isinstance(a, Node) and a in fused_nodes for a in curr.args) + if not uses_fused: + curr = curr.next + continue + + var_name = f"v_{curr.name}" + target = curr.target + code = None + + # ── 1. Pass-through (type conversion / clone / alias) ───────────── + if target in _PASSTHROUGH_OPS: + fused_nodes[curr] = fused_nodes[curr.args[0]] + nodes_to_remove.append(curr) + last_fused_node = curr + curr = curr.next + continue + + # ── 2. Unary elementwise ops (from dispatch table) ──────────────── + elif target in _UNARY_EXPRS: + x = get_val(curr.args[0]) + code = f"{var_name} = " + _UNARY_EXPRS[target].format(x=x) + + # ── 3. Compound activation functions ────────────────────────────── + elif target in (torch.ops.aten.silu.default, torch.ops.aten.silu): + x = get_val(curr.args[0]) + code = f"{var_name} = ({x}) * tl.sigmoid({x})" + + elif target in (torch.ops.aten.gelu.default, torch.ops.aten.gelu): + x = get_val(curr.args[0]) + approx = curr.kwargs.get("approximate", "none") + if approx == "tanh": + code = ( + f"{var_name} = ({x}) * 0.5 * " + f"(1.0 + tl.tanh(0.7978845608 * (({x}) + 0.044715 * ({x}) * ({x}) * ({x}))))" + ) + else: + code = f"{var_name} = 0.5 * ({x}) * (1.0 + tl.math.erf(({x}) * 0.7071067811865476))" + + elif target == torch.ops.aten.leaky_relu.default: + x = get_val(curr.args[0]) + slope = curr.args[1] if len(curr.args) > 1 else curr.kwargs.get("negative_slope", 0.01) + code = f"{var_name} = tl.where({x} >= 0.0, {x}, {slope} * ({x}))" + + elif target == torch.ops.aten.hardtanh.default: + x = get_val(curr.args[0]) + lo = curr.args[1] if len(curr.args) > 1 else curr.kwargs.get("min_val", -1.0) + hi = curr.args[2] if len(curr.args) > 2 else curr.kwargs.get("max_val", 1.0) + code = f"{var_name} = tl.minimum(tl.maximum({x}, {lo}), {hi})" + + elif target == torch.ops.aten.hardsigmoid.default: + x = get_val(curr.args[0]) + code = f"{var_name} = tl.minimum(tl.maximum(({x}) / 6.0 + 0.5, 0.0), 1.0)" + + elif target == torch.ops.aten.hardswish.default: + x = get_val(curr.args[0]) + code = f"{var_name} = ({x}) * tl.minimum(tl.maximum(({x}) / 6.0 + 0.5, 0.0), 1.0)" + + elif target == torch.ops.aten.mish.default: + x = get_val(curr.args[0]) + code = f"{var_name} = ({x}) * tl.tanh(tl.log(1.0 + tl.exp({x})))" + + # ── 4. Clamp family ─────────────────────────────────────────────── + elif target in ( + torch.ops.aten.clamp.default, + torch.ops.aten.clamp.Tensor, + torch.ops.aten.clamp_max.default, + torch.ops.aten.clamp_min.default, + ): + x = get_val(curr.args[0]) + if target is torch.ops.aten.clamp_max.default: + lo, hi = None, curr.args[1] if len(curr.args) > 1 else curr.kwargs.get("max", None) + elif target is torch.ops.aten.clamp_min.default: + lo, hi = curr.args[1] if len(curr.args) > 1 else curr.kwargs.get("min", None), None + else: + lo = curr.args[1] if len(curr.args) > 1 else curr.kwargs.get("min", None) + hi = curr.args[2] if len(curr.args) > 2 else curr.kwargs.get("max", None) + expr = x + if lo is not None: + expr = f"tl.maximum({expr}, {get_val(lo)})" + if hi is not None: + expr = f"tl.minimum({expr}, {get_val(hi)})" + code = f"{var_name} = {expr}" + + # ── 5. Ternary select ───────────────────────────────────────────── + elif target in (torch.ops.aten.where.self, torch.ops.aten.where.ScalarSelf, torch.ops.aten.where.ScalarOther): + cond = get_val(curr.args[0]) + t = get_val(curr.args[1]) + f_ = get_val(curr.args[2]) + code = f"{var_name} = tl.where({cond}, {t}, {f_})" + + # ── 6. pow (special-cased exponents) ───────────────────────────── + elif target in (torch.ops.aten.pow.Tensor_Scalar, torch.ops.aten.pow.Tensor_Tensor): + x = get_val(curr.args[0]) + y = get_val(curr.args[1]) + if str(y) in ("2", "2.0"): + code = f"{var_name} = ({x}) * ({x})" + elif str(y) in ("0.5",): + code = f"{var_name} = tl.sqrt({x})" + elif str(y) in ("-0.5",): + code = f"{var_name} = 1.0 / tl.sqrt({x})" + elif str(y) in ("-1", "-1.0"): + code = f"{var_name} = 1.0 / ({x})" + else: + code = f"{var_name} = tl.math.pow({x}, {y})" + + # ── 7. div with rounding_mode ───────────────────────────────────── + elif target is torch.ops.aten.div.Tensor_mode: + x = get_val(curr.args[0]) + y = get_val(curr.args[1]) + rounding_mode = curr.kwargs.get("rounding_mode", None) or (curr.args[2] if len(curr.args) > 2 else None) + if rounding_mode == "floor": + code = f"{var_name} = tl.math.floor(({x}) / ({y}))" + elif rounding_mode == "trunc": + code = f"{var_name} = tl.math.trunc(({x}) / ({y}))" + else: + code = f"{var_name} = ({x}) / ({y})" + + # ── 8. Binary elementwise ops (from dispatch table) ─────────────── + elif target in _BINARY_EXPRS: + x = get_val(curr.args[0]) + y_raw = curr.args[1] + y = get_val(y_raw) + # Handle optional alpha scalar for add/sub (aten convention) + alpha = (curr.args[2] if len(curr.args) > 2 else None) or curr.kwargs.get("alpha", None) + if alpha is not None and alpha != 1: + y = f"{alpha} * ({y})" + code = f"{var_name} = " + _BINARY_EXPRS[target].format(x=x, y=y) + + # ── 9. Slice: SwiGLU (stride-2 along last dim) ─────────────────── + elif target is torch.ops.aten.slice.Tensor: + dim = curr.args[1] if len(curr.args) > 1 else curr.kwargs.get("dim", 0) + start = curr.args[2] if len(curr.args) > 2 else curr.kwargs.get("start", None) + step = curr.args[4] if len(curr.args) > 4 else curr.kwargs.get("step", 1) + + src = curr.args[0] + if isinstance(src, fx.Node) and "val" in src.meta: + rank = src.meta["val"].dim() + is_last_dim = (dim % rank) == (rank - 1) + else: + is_last_dim = dim == -1 + + if is_last_dim and step == 2: + is_swiglu = True + x = get_val(curr.args[0]) + if not x.endswith("_reshaped"): + epilogue_lines.append(f"{x}_reshaped = tl.reshape({x}, (BLOCK_M, BLOCK_N // 2, 2))") + epilogue_lines.append(f"{x}_split_0, {x}_split_1 = tl.split({x}_reshaped)") + fused_nodes[curr.args[0]] = f"{x}_reshaped" + base_x = x + else: + base_x = x[:-9] # strip '_reshaped' + + idx = 0 if (start == 0 or start is None) else 1 + code = f"{var_name} = {base_x}_split_{idx}" + else: + break # non-strided / non-trailing slice — stop fusion + + # ── Unsupported op — stop greedy fusion ──────────────────────────── + else: + break + + if code: + epilogue_lines.append(code) + fused_nodes[curr] = var_name + nodes_to_remove.append(curr) + last_fused_node = curr + + curr = curr.next + + # Validate: intermediate nodes must not escape the fused set + if not nodes_to_remove: + return 0 + for node in nodes_to_remove[:-1]: + for user in node.users: + if user not in nodes_to_remove: + return 0 + + final_var = fused_nodes[last_fused_node] + + # Skip fusion if the epilogue is a no-op (only passthrough ops were + # collected — e.g. a bare _to_copy after mm). Replacing cuBLAS with + # a Triton GEMM that does the exact same work is strictly slower. + if final_var == "acc": + return 0 + + epilogue_lines.append(f"acc = {final_var}") + + epilogue_code = "\n".join(epilogue_lines) + + # Prepend a comment that encodes which mm dimensions are statically + # known at trace time. triton_kernels.py parses this header and + # annotates the corresponding kernel parameters as tl.constexpr so + # Triton can specialise (and optimise) the compiled kernel per value. + static_dims = _get_static_dims(mm_node) + if static_dims: + epilogue_code = f"# @static:{json.dumps(static_dims, separators=(',', ':'))}\n" + epilogue_code + + with graph.inserting_after(last_fused_node): + fused_node = graph.call_function( + torch.ops.magi_epilogue.matmul_custom.default, args=(A, B, extras, epilogue_code, is_swiglu) + ) + if "val" in last_fused_node.meta: + val = last_fused_node.meta["val"] + # Propagate the 128-byte-aligned row stride so downstream + # assert_size_stride checks match what we actually return. + try: + N_out = int(val.shape[-1]) + elem_size = val.element_size() + align_elems = 128 // elem_size + N_stride = (N_out + align_elems - 1) // align_elems * align_elems + new_stride = val.stride()[:-2] + (N_stride, 1) + fused_node.meta["val"] = val.new_empty_strided(val.shape, new_stride) + except Exception: + fused_node.meta["val"] = val + + last_fused_node.replace_all_uses_with(fused_node) + + for n in reversed(nodes_to_remove): + graph.erase_node(n) + graph.erase_node(mm_node) + + return 1 diff --git a/magi_compiler/passes/piecewise_graph/fusion/triton_kernels.py b/magi_compiler/passes/piecewise_graph/fusion/triton_kernels.py new file mode 100644 index 0000000..203ffef --- /dev/null +++ b/magi_compiler/passes/piecewise_graph/fusion/triton_kernels.py @@ -0,0 +1,582 @@ +# Copyright (c) 2026 SandAI. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import math +import os + +import torch +import triton +import triton.language as tl + +from magi_compiler.config import get_compile_config + +# ── Python-level kernel caches ───────────────────────────────────────────────── +# (num_extras, epilogue_code, reduce_n_by_2) → kernel object +_KERNEL_CACHE: dict = {} +_KERNEL_TMA_CACHE: dict = {} + +# ── Persistent autotune result caches (survive process restart) ──────────────── +_cache_root = get_compile_config().cache_root_dir +_AUTOTUNE_FILE = os.path.join(_cache_root, "magi_epilogue_autotune.json") +_AUTOTUNE_FILE_TMA = os.path.join(_cache_root, "magi_epilogue_autotune_tma.json") +_AUTOTUNE_PERSIST: dict = {} +_AUTOTUNE_PERSIST_TMA: dict = {} + + +def _load_autotune_cache() -> None: + global _AUTOTUNE_PERSIST + try: + with open(_AUTOTUNE_FILE) as f: + _AUTOTUNE_PERSIST = json.load(f) + except (FileNotFoundError, json.JSONDecodeError): + _AUTOTUNE_PERSIST = {} + + +def _save_autotune_cache() -> None: + os.makedirs(os.path.dirname(_AUTOTUNE_FILE), exist_ok=True) + with open(_AUTOTUNE_FILE, "w") as f: + json.dump(_AUTOTUNE_PERSIST, f) + + +def _load_autotune_cache_tma() -> None: + global _AUTOTUNE_PERSIST_TMA + try: + with open(_AUTOTUNE_FILE_TMA) as f: + _AUTOTUNE_PERSIST_TMA = json.load(f) + except (FileNotFoundError, json.JSONDecodeError): + _AUTOTUNE_PERSIST_TMA = {} + + +def _save_autotune_cache_tma() -> None: + os.makedirs(os.path.dirname(_AUTOTUNE_FILE_TMA), exist_ok=True) + with open(_AUTOTUNE_FILE_TMA, "w") as f: + json.dump(_AUTOTUNE_PERSIST_TMA, f) + + +_load_autotune_cache() + + +def _check_tma() -> bool: + """Return True when SM90+ TMA with device-side descriptors is available.""" + try: + return ( + torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 9 and hasattr(tl, "make_tensor_descriptor") + ) + except Exception: + return False + + +_TMA_AVAILABLE: bool = _check_tma() +_TMA_ALLOCATOR_SET: bool = False + +if _TMA_AVAILABLE: + _load_autotune_cache_tma() + + +def _ensure_tma_allocator() -> None: + """Set a Triton global-memory allocator once; required by device-side TMA descriptors.""" + global _TMA_ALLOCATOR_SET + if _TMA_ALLOCATOR_SET: + return + + def _alloc_fn(size: int, alignment: int, stream): + return torch.empty(size, device="cuda", dtype=torch.int8) + + triton.set_allocator(_alloc_fn) + _TMA_ALLOCATOR_SET = True + + +def _parse_static_dims(epilogue_code: str) -> dict: + """Parse the ``# @static:{...}`` header injected by the fusion pass. + + Returns a dict like ``{"M": 2048, "K": 4096, "N": 8192}`` (only the keys + that are actually static). Missing keys mean the dimension is dynamic. + """ + for line in epilogue_code.splitlines(): + if line.startswith("# @static:"): + try: + return json.loads(line[len("# @static:") :]) + except Exception: + pass + return {} + + +def _bucket_m(M: int) -> int: + """Round M up to the nearest power-of-2 bucket. + + This drastically reduces the number of distinct (M, N, K) triples + that trigger autotune: e.g. M=1000 and M=1023 both map to 1024, + reusing the same benchmark result instead of each triggering 27 × 125 + device kernel launches. + """ + return 1 << math.ceil(math.log2(max(M, 1))) + + +# ── Autotune config list ─────────────────────────────────────────────────────── +# Shapes that prune_configs removes: +# • BLOCK_M > M_bucket → waste SM occupancy on empty rows +# • BLOCK_K > K → single-iteration k-loop, large overhead +# • BLOCK_N > N → waste on empty columns + + +def _prune_configs(configs, named_args, **kwargs): + M = named_args["M"] + N = named_args["N"] + K = named_args["K"] + pruned = [] + for cfg in configs: + bm = cfg.kwargs["BLOCK_M"] + bn = cfg.kwargs["BLOCK_N"] + bk = cfg.kwargs["BLOCK_K"] + # Keep configs whose tiles are no larger than 4× the dimension + # (leaving room for the autotuner to still test large tiles that + # can handle moderate-size matrices efficiently). + if bm > 4 * M or bn > 4 * N or bk > K: + continue + pruned.append(cfg) + # Always keep at least one fallback + return pruned if pruned else [configs[0]] + + +# ── Shared autotune config list (embedded as a string in both templates) ─────── +_AUTOTUNE_CONFIGS_BODY = """ + # ── Large-tile: high-throughput for large M/N (training) ────────────────── + triton.Config({"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 64, "GROUP_M": 8}, num_stages=3, num_warps=8), + triton.Config({"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 128, "GROUP_M": 8}, num_stages=3, num_warps=8), + triton.Config({"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 128, "GROUP_M": 8}, num_stages=3, num_warps=8), + triton.Config({"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 128, "GROUP_M": 8}, num_stages=4, num_warps=4), + triton.Config({"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 128, "GROUP_M": 8}, num_stages=4, num_warps=4), + triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 128, "GROUP_M": 8}, num_stages=4, num_warps=4), + # ── Medium-tile: balanced for mixed shapes ───────────────────────────────── + triton.Config({"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 64, "GROUP_M": 8}, num_stages=4, num_warps=4), + triton.Config({"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 64, "GROUP_M": 8}, num_stages=4, num_warps=4), + triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 64, "GROUP_M": 8}, num_stages=4, num_warps=4), + triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 128, "GROUP_M": 8}, num_stages=4, num_warps=4), + triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 128, "GROUP_M": 8}, num_stages=4, num_warps=4), + triton.Config({"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 32, "GROUP_M": 8}, num_stages=4, num_warps=4), + triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 32, "GROUP_M": 8}, num_stages=4, num_warps=4), + triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 64, "GROUP_M": 8}, num_stages=4, num_warps=4), + triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 64, "GROUP_M": 8}, num_stages=4, num_warps=4), + triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 32, "GROUP_M": 8}, num_stages=4, num_warps=4), + triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 32, "GROUP_M": 8}, num_stages=4, num_warps=4), + # ── Small-tile: high occupancy for small-M or tail dimensions ───────────── + triton.Config({"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 128, "GROUP_M": 8}, num_stages=4, num_warps=4), + triton.Config({"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8}, num_stages=4, num_warps=4), + triton.Config({"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8}, num_stages=5, num_warps=2), + triton.Config({"BLOCK_M": 32, "BLOCK_N": 64, "BLOCK_K": 64, "GROUP_M": 8}, num_stages=5, num_warps=2), + triton.Config({"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 32, "GROUP_M": 8}, num_stages=4, num_warps=4), + triton.Config({"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 32, "GROUP_M": 8}, num_stages=5, num_warps=2), + triton.Config({"BLOCK_M": 32, "BLOCK_N": 64, "BLOCK_K": 32, "GROUP_M": 8}, num_stages=5, num_warps=2), + triton.Config({"BLOCK_M": 32, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8}, num_stages=5, num_warps=2), + triton.Config({"BLOCK_M": 16, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8}, num_stages=6, num_warps=2), + triton.Config({"BLOCK_M": 32, "BLOCK_N": 16, "BLOCK_K": 64, "GROUP_M": 8}, num_stages=6, num_warps=2), +""" + + +# ───────────────────────────────────────────────────────────────────────────── +# Non-persistent kernel template (all CUDA GPUs) +# Uses tl.where + tl.max_contiguous + tl.multiple_of for vectorised loads. +# ───────────────────────────────────────────────────────────────────────────── +KERNEL_TEMPLATE = """ +import triton +import triton.language as tl + +_AUTOTUNE_CONFIGS = [ +{autotune_configs} +] + +@triton.autotune( + configs=_AUTOTUNE_CONFIGS, + key=["M_BUCKET", "N", "K"], + prune_configs_by={{"early_config_prune": {prune_fn_name}}}, + warmup=10, + rep=30, +) +@triton.jit +def dynamic_matmul_epilogue_kernel( + A_ptr, B_ptr, D_ptr, + {extra_ptrs_args} + M{M_annot}, N{N_annot}, K{K_annot}, + M_BUCKET, + stride_am, stride_ak, + stride_bk, stride_bn, + stride_dm, stride_dn, + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, + GROUP_M: tl.constexpr, +): + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_M) + num_pid_n = tl.cdiv(N, BLOCK_N) + + num_pid_in_group = GROUP_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_M) + pid_m = first_pid_m + (pid % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + start_m = pid_m * BLOCK_M + start_n = pid_n * BLOCK_N + + offs_am = start_m + tl.arange(0, BLOCK_M) + offs_bn = start_n + tl.arange(0, BLOCK_N) +{offs_am_guard}{offs_bn_guard} offs_am = tl.max_contiguous(tl.multiple_of(offs_am, BLOCK_M), BLOCK_M) + offs_bn = tl.max_contiguous(tl.multiple_of(offs_bn, BLOCK_N), BLOCK_N) + offs_k = tl.arange(0, BLOCK_K) + + A_ptrs = A_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) + B_ptrs = B_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_K)): + a = tl.load(A_ptrs{k_mask_a}) + b = tl.load(B_ptrs{k_mask_b}) + acc = tl.dot(a, b, acc) + A_ptrs += BLOCK_K * stride_ak + B_ptrs += BLOCK_K * stride_bk + + offs_dm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_dn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + mask = {out_mask_expr} + +{epilogue_code} + +{store_code} +""" + + +# ───────────────────────────────────────────────────────────────────────────── +# TMA persistent kernel template (SM90+: H100 / Hopper and newer) +# +# Key advantages over the non-persistent path: +# 1. Device-side tl.make_tensor_descriptor — no host→device descriptor copy. +# 2. Persistent CTA loop — each SM processes multiple tiles, amortising +# kernel-launch and L2-warmup overhead. +# 3. Hardware-managed OOB fill — TMA zero-fills out-of-bounds tile edges, +# so the k-loop needs no software mask. +# 4. B read as [K, N] (no pre-transpose required). +# +# {epilogue_code} and {store_code} are injected at 8-space indent so they +# land inside the `for tile_id` persistent loop body. +# ───────────────────────────────────────────────────────────────────────────── +KERNEL_TEMPLATE_TMA_PERSISTENT = """ +import triton +import triton.language as tl + +_AUTOTUNE_CONFIGS_TMA = [ +{autotune_configs} +] + +@triton.autotune( + configs=_AUTOTUNE_CONFIGS_TMA, + key=["M_BUCKET", "N", "K"], + prune_configs_by={{"early_config_prune": {prune_fn_name}}}, + warmup=10, + rep=30, +) +@triton.jit +def dynamic_matmul_epilogue_kernel_tma( + A_ptr, B_ptr, D_ptr, + {extra_ptrs_args} + M{M_annot}, N{N_annot}, K{K_annot}, + M_BUCKET, + stride_dm, stride_dn, + NUM_SMS: tl.constexpr, + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, + GROUP_M: tl.constexpr, +): + # Device-side TMA descriptor creation — eliminates host→device copy latency. + # A is [M, K] row-major; B is [K, N] row-major (no pre-transpose needed). + # TMA hardware zero-fills tiles that extend past the tensor boundary. + a_desc = tl.make_tensor_descriptor( + A_ptr, shape=[M, K], strides=[K, 1], block_shape=[BLOCK_M, BLOCK_K], + ) + b_desc = tl.make_tensor_descriptor( + B_ptr, shape=[K, N], strides=[N, 1], block_shape=[BLOCK_K, BLOCK_N], + ) + + start_pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_M) + num_pid_n = tl.cdiv(N, BLOCK_N) + num_tiles = num_pid_m * num_pid_n + num_pid_in_group = GROUP_M * num_pid_n + + # Each CTA iterates over multiple tiles, stepping NUM_SMS at a time. + for tile_id in tl.range(start_pid, num_tiles, NUM_SMS, flatten=True): + group_id = tile_id // num_pid_in_group + first_pid_m = group_id * GROUP_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_M) + pid_m = first_pid_m + (tile_id % group_size_m) + pid_n = (tile_id % num_pid_in_group) // group_size_m + + offs_am = pid_m * BLOCK_M + offs_bn = pid_n * BLOCK_N + + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_K)): + offs_k = k * BLOCK_K + a = a_desc.load([offs_am, offs_k]) + b = b_desc.load([offs_k, offs_bn]) + acc = tl.dot(a, b, acc) + + offs_dm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_dn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + mask = {out_mask_expr} + +{epilogue_code} + +{store_code} +""" + + +def _build_kernel_via_exec( + template: str, kernel_name: str, num_extras: int, epilogue_code: str, reduce_n_by_2: bool, indent: int, persist_cache: dict +) -> object: + """Compile *template* with exec() and return the resulting Triton kernel.""" + extra_ptrs_args = "".join([f"Extra_{i}_ptr, " for i in range(num_extras)]) + + # ── Derive tl.constexpr annotations and static mask/guard expressions ──── + # The fusion pass prepends a "# @static:{...}" comment to epilogue_code + # whenever it can prove (from FakeTensor meta) that a dimension is a plain + # Python int rather than a SymInt. + static_dims = _parse_static_dims(epilogue_code) + M_static = static_dims.get("M") + N_static = static_dims.get("N") + K_static = static_dims.get("K") + + # tl.constexpr annotation: Triton JIT-compiles one kernel variant per + # unique value, making all constexpr-dependent expressions compile-time + # constants (loop bounds, tile counts, mask predicates, etc.). + M_annot = ": tl.constexpr" if M_static is not None else "" + N_annot = ": tl.constexpr" if N_static is not None else "" + K_annot = ": tl.constexpr" if K_static is not None else "" + + # ── k-loop load masks ───────────────────────────────────────────────────── + # Our BLOCK_K configs are {32, 64, 128}; the mask in the k-loop is needed + # only when K is not a multiple of the chosen BLOCK_K. If K % 128 == 0, + # then K is a multiple of every BLOCK_K in the config set, so the mask + # predicate is always all-true and we can emit bare (unmasked) tl.load + # calls — the hottest path in the kernel. + if K_static is not None and K_static % 128 == 0: + k_mask_a = "" + k_mask_b = "" + else: + k_mask_a = ", mask=offs_k[None, :] < K - k * BLOCK_K, other=0.0" + k_mask_b = ", mask=offs_k[:, None] < K - k * BLOCK_K, other=0.0" + + # ── A / B index boundary guards ─────────────────────────────────────────── + # tl.where(offs < dim, offs, 0) prevents out-of-bounds pointer arithmetic + # when a tile straddles the last row/column. If dim is a multiple of the + # largest BLOCK size (256 covers all configs {16,32,64,128,256}), every + # tile is a full tile and the guard is dead code — remove it. + m_tile_aligned = M_static is not None and M_static % 256 == 0 + n_tile_aligned = N_static is not None and N_static % 256 == 0 + + offs_am_guard = "" if m_tile_aligned else " offs_am = tl.where(offs_am < M, offs_am, 0)\n" + offs_bn_guard = "" if n_tile_aligned else " offs_bn = tl.where(offs_bn < N, offs_bn, 0)\n" + + # ── Output (and epilogue) mask ──────────────────────────────────────────── + # The mask tensor is referenced by both the output store and extra-tensor + # loads inside epilogue_code. When a dimension is tile-aligned we drop + # its component from the predicate; both dropped → constant True mask (the + # compiler will eliminate it entirely from the PTX). + if m_tile_aligned and n_tile_aligned: + out_mask_expr = "tl.full([BLOCK_M, BLOCK_N], True, dtype=tl.int1)" + elif m_tile_aligned: + out_mask_expr = "offs_dn[None, :] < N" + elif n_tile_aligned: + out_mask_expr = "offs_dm[:, None] < M" + else: + out_mask_expr = "(offs_dm[:, None] < M) & (offs_dn[None, :] < N)" + + pad = " " * indent + indented_epilogue = "\n".join([f"{pad}{line}" for line in epilogue_code.strip().split("\n") if line]) + + if reduce_n_by_2: + # For SwiGLU the output N is N//2; output BLOCK size is BLOCK_N//2 + # whose maximum across configs is 128. Tile-alignment condition: + # (N_static // 2) % 128 == 0 ↔ N_static % 256 == 0 (same as n_tile_aligned). + if m_tile_aligned and n_tile_aligned: + mask_out_expr = "tl.full([BLOCK_M, BLOCK_N // 2], True, dtype=tl.int1)" + elif m_tile_aligned: + mask_out_expr = "offs_dn_out[None, :] < N // 2" + elif n_tile_aligned: + mask_out_expr = "offs_dm[:, None] < M" + else: + mask_out_expr = "(offs_dm[:, None] < M) & (offs_dn_out[None, :] < N // 2)" + store_code = ( + f"{pad}offs_dn_out = pid_n * (BLOCK_N // 2) + tl.arange(0, BLOCK_N // 2)\n" + f"{pad}mask_out = {mask_out_expr}\n" + f"{pad}D_ptrs = D_ptr + stride_dm * offs_dm[:, None] + stride_dn * offs_dn_out[None, :]\n" + f"{pad}tl.store(D_ptrs, acc.to(D_ptr.dtype.element_ty), mask=mask_out)" + ) + else: + store_code = ( + f"{pad}D_ptrs = D_ptr + stride_dm * offs_dm[:, None] + stride_dn * offs_dn[None, :]\n" + f"{pad}tl.store(D_ptrs, acc.to(D_ptr.dtype.element_ty), mask=mask)" + ) + + code = template.format( + autotune_configs=_AUTOTUNE_CONFIGS_BODY, + extra_ptrs_args=extra_ptrs_args, + epilogue_code=indented_epilogue, + store_code=store_code, + prune_fn_name="_prune_configs", + M_annot=M_annot, + N_annot=N_annot, + K_annot=K_annot, + offs_am_guard=offs_am_guard, + offs_bn_guard=offs_bn_guard, + k_mask_a=k_mask_a, + k_mask_b=k_mask_b, + out_mask_expr=out_mask_expr, + ) + + import linecache + import uuid + + filename = f"" + linecache.cache[filename] = (len(code), None, [line + "\n" for line in code.splitlines()], filename) + compiled = compile(code, filename, "exec") + + namespace: dict = {} + exec(compiled, {"triton": triton, "tl": tl, "_prune_configs": _prune_configs}, namespace) + kernel = namespace[kernel_name] + + # Warm the in-process autotune cache from the persisted JSON so that + # known shapes skip the benchmark entirely on restart. + key_str = str((num_extras, epilogue_code, reduce_n_by_2)) + for cache_key, best_cfg in persist_cache.items(): + if cache_key.startswith(key_str + "|"): + suffix = cache_key[len(key_str) + 1 :] + try: + m_bucket, n, k = (int(x) for x in suffix.split(",")) + except ValueError: + continue + triton_key = (m_bucket, n, k) + cfg = triton.Config( + {k2: v for k2, v in best_cfg["kwargs"].items()}, + num_stages=best_cfg["num_stages"], + num_warps=best_cfg["num_warps"], + ) + kernel.cache[triton_key] = cfg + + return kernel + + +def get_dynamic_kernel(num_extras: int, epilogue_code: str, reduce_n_by_2: bool): + key = (num_extras, epilogue_code, reduce_n_by_2) + if key in _KERNEL_CACHE: + return _KERNEL_CACHE[key] + kernel = _build_kernel_via_exec( + KERNEL_TEMPLATE, + "dynamic_matmul_epilogue_kernel", + num_extras, + epilogue_code, + reduce_n_by_2, + indent=4, + persist_cache=_AUTOTUNE_PERSIST, + ) + _KERNEL_CACHE[key] = kernel + return kernel + + +def get_dynamic_kernel_tma(num_extras: int, epilogue_code: str, reduce_n_by_2: bool): + """Build the TMA-persistent variant via exec().""" + key = (num_extras, epilogue_code, reduce_n_by_2) + if key in _KERNEL_TMA_CACHE: + return _KERNEL_TMA_CACHE[key] + kernel = _build_kernel_via_exec( + KERNEL_TEMPLATE_TMA_PERSISTENT, + "dynamic_matmul_epilogue_kernel_tma", + num_extras, + epilogue_code, + reduce_n_by_2, + indent=8, # epilogue/store are inside the persistent for-loop + persist_cache=_AUTOTUNE_PERSIST_TMA, + ) + _KERNEL_TMA_CACHE[key] = kernel + return kernel + + +def _record_best_config(kernel, epilogue_key: str, M_bucket: int, N: int, K: int, persist: dict, save_fn) -> None: + """Persist the winning autotune config to disk after it is chosen.""" + triton_key = (M_bucket, N, K) + cfg = kernel.cache.get(triton_key) + if cfg is None: + return + cache_key = f"{epilogue_key}|{M_bucket},{N},{K}" + persist[cache_key] = {"kwargs": dict(cfg.kwargs), "num_stages": cfg.num_stages, "num_warps": cfg.num_warps} + save_fn() + + +def matmul_custom_epilogue( + A: torch.Tensor, B: torch.Tensor, extras: list[torch.Tensor], epilogue_code: str, reduce_n_by_2: bool +) -> torch.Tensor: + M, K = A.shape + _, N = B.shape + M_bucket = _bucket_m(M) + + N_out = N // 2 if reduce_n_by_2 else N + + # Align the row stride to 128 bytes so a subsequent cuBLAS mm can read + # this buffer as its A operand without Inductor inserting a row-padding copy. + elem_size = A.element_size() + align_elems = 128 // elem_size + N_stride = (N_out + align_elems - 1) // align_elems * align_elems + D = torch.empty((M, N_stride), device=A.device, dtype=A.dtype)[:, :N_out] + + epilogue_key = str((len(extras), epilogue_code, reduce_n_by_2)) + triton_key = (M_bucket, N, K) + + use_tma = _TMA_AVAILABLE and A.is_contiguous() and B.is_contiguous() + + if use_tma: + # ── TMA persistent path (SM90+) ─────────────────────────────────────── + # Device-side descriptors + persistent CTA loop over NUM_SMS SMs. + # B is read as [K, N] row-major; no pre-transpose required. + _ensure_tma_allocator() + NUM_SMS = torch.cuda.get_device_properties(A.device).multi_processor_count + kernel = get_dynamic_kernel_tma(len(extras), epilogue_code, reduce_n_by_2) + needs_persist = triton_key not in kernel.cache + + grid = lambda meta: (min(NUM_SMS, triton.cdiv(M, meta["BLOCK_M"]) * triton.cdiv(N, meta["BLOCK_N"])),) + + args = [A, B, D] + args.extend(extras) + args.extend([M, N, K, M_bucket, D.stride(0), D.stride(1), NUM_SMS]) + + kernel[grid](*args) + + if needs_persist: + _record_best_config(kernel, epilogue_key, M_bucket, N, K, _AUTOTUNE_PERSIST_TMA, _save_autotune_cache_tma) + + else: + # ── Non-persistent pointer-arithmetic path (all CUDA GPUs) ─────────── + kernel = get_dynamic_kernel(len(extras), epilogue_code, reduce_n_by_2) + needs_persist = triton_key not in kernel.cache + + grid = lambda meta: (triton.cdiv(M, meta["BLOCK_M"]) * triton.cdiv(N, meta["BLOCK_N"]),) + + args = [A, B, D] + args.extend(extras) + args.extend([M, N, K, M_bucket, A.stride(0), A.stride(1), B.stride(0), B.stride(1), D.stride(0), D.stride(1)]) + + kernel[grid](*args) + + if needs_persist: + _record_best_config(kernel, epilogue_key, M_bucket, N, K, _AUTOTUNE_PERSIST, _save_autotune_cache) + + return D diff --git a/magi_compiler/passes/piecewise_graph/post_grad_pass_manager.py b/magi_compiler/passes/piecewise_graph/post_grad_pass_manager.py index f6441e0..8e48203 100644 --- a/magi_compiler/passes/piecewise_graph/post_grad_pass_manager.py +++ b/magi_compiler/passes/piecewise_graph/post_grad_pass_manager.py @@ -22,6 +22,7 @@ from ...utils.envs import MAGI_PATTERN_MATCH_DEBUG from ..pass_base import InductorPass, get_pass_context from .fix_functionalization import FixFunctionalizationPass +from .fusion.matmul_epilogue_fusion import MatmulCustomEpilogueFusionPass from .post_cleanup import PostCleanupPass @@ -81,6 +82,7 @@ def configure(self, pass_config: PassConfig): self.pass_config = pass_config # TODO: Register custom passes here (fusion, noop elimination, sequence parallelism, async TP, Ulysses overlap). + self.add(MatmulCustomEpilogueFusionPass()) # needs a functional graph self.post_cleanup = PostCleanupPass() diff --git a/tests/feature_tests/test_matmul_epilogue_fusion.py b/tests/feature_tests/test_matmul_epilogue_fusion.py new file mode 100644 index 0000000..15e7127 --- /dev/null +++ b/tests/feature_tests/test_matmul_epilogue_fusion.py @@ -0,0 +1,199 @@ +# Copyright (c) 2025 SandAI. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional + +import pytest +import torch +import torch.nn as nn +import torch.nn.functional as F + +from magi_compiler.api import magi_compile +from magi_compiler.config import get_compile_config + +pytestmark = pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") + + +# --------------------------------------------------------------------------- +# Activation functions +# --------------------------------------------------------------------------- + + +def high_precision_silu(x, out_dtype: Optional[torch.dtype] = None): + out_dtype = x.dtype if out_dtype is None else out_dtype + x = x.to(torch.float32) + return F.silu(x).to(out_dtype) + + +def high_precision_sigmoid(x, out_dtype: Optional[torch.dtype] = None): + out_dtype = x.dtype if out_dtype is None else out_dtype + x = x.to(torch.float32) + return F.sigmoid(x).to(out_dtype) + + +def high_precision_gelu(x, out_dtype: Optional[torch.dtype] = None): + out_dtype = x.dtype if out_dtype is None else out_dtype + x = x.to(torch.float32) + return F.gelu(x).to(out_dtype) + + +def swiglu7(x, alpha: float = 1.702, limit: float = 7.0, out_dtype: Optional[torch.dtype] = None): + out_dtype = x.dtype if out_dtype is None else out_dtype + x = x.to(torch.float32) + x_glu, x_linear = x[..., ::2], x[..., 1::2] + x_glu = x_glu.clamp(min=None, max=limit) + x_linear = x_linear.clamp(min=-limit, max=limit) + out_glu = x_glu * torch.sigmoid(alpha * x_glu) + return (out_glu * (x_linear + 1)).to(out_dtype) + + +def gelu7(x, alpha: float = 1.702, limit: float = 7.0, out_dtype: Optional[torch.dtype] = None): + out_dtype = x.dtype if out_dtype is None else out_dtype + x = x.to(torch.float32) + x_glu = x.clamp(min=None, max=limit) + out_glu = x_glu * torch.sigmoid(alpha * x_glu) + return out_glu.to(out_dtype) + + +def relu_square(x, out_dtype: Optional[torch.dtype] = None): + out_dtype = x.dtype if out_dtype is None else out_dtype + x = x.to(torch.float32) + return torch.square(F.relu(x)).to(out_dtype) + + +# --------------------------------------------------------------------------- +# Model wrappers +# --------------------------------------------------------------------------- + + +class SiluModel(nn.Module): + def forward(self, a, b): + return high_precision_silu(torch.mm(a, b), out_dtype=torch.bfloat16) + + +class SigmoidModel(nn.Module): + def forward(self, a, b): + return high_precision_sigmoid(torch.mm(a, b), out_dtype=torch.bfloat16) + + +class GeluModel(nn.Module): + def forward(self, a, b): + return high_precision_gelu(torch.mm(a, b), out_dtype=torch.bfloat16) + + +class Swiglu7Model(nn.Module): + def forward(self, a, b): + return swiglu7(torch.mm(a, b), out_dtype=torch.bfloat16) + + +class Gelu7Model(nn.Module): + def forward(self, a, b): + return gelu7(torch.mm(a, b), out_dtype=torch.bfloat16) + + +class ReluSquareModel(nn.Module): + def forward(self, a, b): + return relu_square(torch.mm(a, b), out_dtype=torch.bfloat16) + + +# --------------------------------------------------------------------------- +# Helper +# --------------------------------------------------------------------------- + + +def _run_fusion_test(model: nn.Module, a: torch.Tensor, b: torch.Tensor, atol: float = 0.5, rtol: float = 0.0): + """Run a matmul-epilogue fusion test. + + Checks that the fused result satisfies: |actual - expected| < atol + rtol * |expected| + + atol=0.5 covers the bf16 → fp32 accumulation difference for element-wise + activations whose output magnitude is O(1). For activations that amplify + magnitude (e.g. relu_square), pass a non-zero rtol instead. + """ + model = model.cuda().bfloat16() + with torch.no_grad(): + expected = model(a, b) + + get_compile_config().disable_cache = True + compiled_model = magi_compile(model, dynamic_arg_dims={"a": 0}) + with torch.no_grad(): + actual = compiled_model(a, b) + + abs_diff = (actual - expected).abs() + tol = atol + rtol * expected.abs() + max_violation = (abs_diff - tol).max().item() + assert max_violation <= 0, ( + f"Fused result too far from reference: " + f"max(|diff| - tol) = {max_violation:.4f}, " + f"max |diff| = {abs_diff.max().item():.4f}" + ) + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +def test_matmul_epilogue_fusion_silu(): + M, K, N = 128, 256, 512 + a = torch.randn(M, K, device="cuda", dtype=torch.bfloat16) + b = torch.randn(K, N, device="cuda", dtype=torch.bfloat16) + _run_fusion_test(SiluModel(), a, b) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +def test_matmul_epilogue_fusion_sigmoid(): + M, K, N = 128, 256, 512 + a = torch.randn(M, K, device="cuda", dtype=torch.bfloat16) + b = torch.randn(K, N, device="cuda", dtype=torch.bfloat16) + _run_fusion_test(SigmoidModel(), a, b) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +def test_matmul_epilogue_fusion_gelu(): + M, K, N = 128, 256, 512 + a = torch.randn(M, K, device="cuda", dtype=torch.bfloat16) + b = torch.randn(K, N, device="cuda", dtype=torch.bfloat16) + _run_fusion_test(GeluModel(), a, b) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +def test_matmul_epilogue_fusion_swiglu7(): + M, K, N = 128, 256, 512 + a = torch.randn(M, K, device="cuda", dtype=torch.bfloat16) + b = torch.randn(K, N, device="cuda", dtype=torch.bfloat16) + _run_fusion_test(Swiglu7Model(), a, b) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +def test_matmul_epilogue_fusion_gelu7(): + M, K, N = 128, 256, 512 + a = torch.randn(M, K, device="cuda", dtype=torch.bfloat16) + b = torch.randn(K, N, device="cuda", dtype=torch.bfloat16) + _run_fusion_test(Gelu7Model(), a, b) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +def test_matmul_epilogue_fusion_relu_square(): + M, K, N = 128, 256, 512 + a = torch.randn(M, K, device="cuda", dtype=torch.bfloat16) + b = torch.randn(K, N, device="cuda", dtype=torch.bfloat16) + # relu_square amplifies values quadratically (output ~ x^2, up to ~256), + # so use relative tolerance instead of a fixed absolute bound. + _run_fusion_test(ReluSquareModel(), a, b, atol=0.0, rtol=0.2) + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) From afb93998ce309f333075f1c7a16aade59ab57e73 Mon Sep 17 00:00:00 2001 From: wtr Date: Mon, 13 Apr 2026 19:21:01 +0800 Subject: [PATCH 2/7] add cute kernel --- .../piecewise_graph/fusion/cute_kernel.py | 1080 +++++++++++++++++ .../fusion/matmul_epilogue_fusion.py | 57 +- 2 files changed, 1128 insertions(+), 9 deletions(-) create mode 100644 magi_compiler/passes/piecewise_graph/fusion/cute_kernel.py diff --git a/magi_compiler/passes/piecewise_graph/fusion/cute_kernel.py b/magi_compiler/passes/piecewise_graph/fusion/cute_kernel.py new file mode 100644 index 0000000..fe6e4a0 --- /dev/null +++ b/magi_compiler/passes/piecewise_graph/fusion/cute_kernel.py @@ -0,0 +1,1080 @@ +# Copyright (c) 2026 SandAI. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""CuTe DSL GEMM with fused in-kernel epilogue for Hopper (SM90+). + +Design +------ +The key insight is that WGMMA accumulates results into register files (``tRS_rD``). +Before those registers are written to shared/global memory, we can apply elementwise +epilogue operations (activation, bias-add, scale, …) *in-place on the register +values* — completely avoiding the extra read-back from global memory that a +separate Triton epilogue pass would require. + +Concretely, inside the CuTe kernel's epilogue loop: + + for epi_idx in range_constexpr(epi_tile_num): + for epi_v in range_constexpr(size_tRS_rD): + tRS_rD[epi_v] = tRS_rAcc[epi_idx * size_tRS_rD + epi_v] + + acc_vec = tRS_rD.load() # FP32 register tensor + # ── INJECT: fused epilogue ────────────────────────────────── + acc_vec = self._apply_epilogue(acc_vec) + # ──────────────────────────────────────────────────────────── + tRS_rD_out.store(acc_vec.to(self.c_dtype)) + ... + +``HopperWgmmaGemmEpilogueFusedKernel`` subclasses +``HopperWgmmaGemmPersistentKernel`` and overrides ``kernel()`` with this +single extra line, plus the mechanism to supply ``_apply_epilogue``. + +Epilogue representation +----------------------- +The epilogue is described by two complementary representations: + +1. **Triton epilogue string** (``epilogue_code``) — already generated by + ``MatmulCustomEpilogueFusionPass._try_fuse_custom_chain``. We *parse* this + string to drive the CuTe DSL code that runs inside the kernel. + +2. **CuTe DSL epilogue callable** (``epilogue_fn``) — a Python callable that + accepts a ``TensorSSA`` (FP32 accumulator tile) and returns a transformed + ``TensorSSA`` of the same shape. It is invoked at ``@cute.jit`` trace time + so it must only use CuTe DSL primitives (``cute.exp``, ``cute.tanh``, …). + +The ``_build_epilogue_fn`` factory converts the Triton epilogue string into a +CuTe DSL callable. It covers the same op set that ``triton_kernels.py`` +supports so all fused chains are handled correctly. + +Extras (bias tensors, etc.) +--------------------------- +The Triton string may reference ``Extra_0_ptr``, ``Extra_1_ptr``, … which are +additional (bias / scale) tensors. At CuTe DSL level these arrive as plain +FP16 1-D or 2-D GPU tensors; the epilogue builder injects loads via a small +helper that reads the correct row of the extra tensor for the current +``epi_idx`` subtile. + +Fallback +-------- +On non-Hopper or when ``cutlass-dsl`` is unavailable the module falls back to +the pure-Triton path (``matmul_custom_epilogue`` from ``triton_kernels.py``). +""" + +import ast +import sys +from dataclasses import dataclass +from typing import Callable, List, Optional + +import torch + +from .triton_kernels import matmul_custom_epilogue + +# ── CuTe DSL availability ────────────────────────────────────────────────────── +_HAS_CUTLASS: bool = False +_IS_HOPPER: bool = False + +try: + _IS_HOPPER = torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 9 + if _IS_HOPPER: + _CUTLASS_HOPPER_DIR = "/root/cutlass/examples/python/CuTeDSL/hopper" + if _CUTLASS_HOPPER_DIR not in sys.path: + sys.path.insert(0, _CUTLASS_HOPPER_DIR) + import cuda.bindings.driver as cuda + import cutlass + import cutlass.cute as cute + import cutlass.torch as cutlass_torch + import cutlass.utils + from dense_gemm_persistent import HopperWgmmaGemmPersistentKernel + + _HAS_CUTLASS = True +except Exception: + pass + + +# ── Epilogue-string → CuTe DSL translator ───────────────────────────────────── + + +def _build_epilogue_fn( + epilogue_code: str, extras: list, reduce_n_by_2: bool # list of GPU torch.Tensor (bias, scale, …) +) -> Optional[Callable]: + """Parse the Triton epilogue code string and return a CuTe DSL callable. + + The returned function has signature:: + + fn(acc_vec: TensorSSA, epi_idx: int, epi_tile_m: int, epi_tile_n: int, + extra_cute_tensors: list) -> TensorSSA + + where ``acc_vec`` is the FP32 register tile (shape = (EPI_TILE_M, EPI_TILE_N) + or a flat vector, depending on how cute delivers it). + + Returns ``None`` if the code string cannot be translated (fall back to Triton). + + Supported Triton constructs → CuTe DSL mapping + ----------------------------------------------- + acc → acc_vec (float32 register tensor) + tl.exp(x) → cute.exp(x) + tl.exp2(x) → cute.exp2(x) + tl.log(x) → cute.log(x) + tl.log2(x) → cute.log2(x) + tl.sqrt(x) → cute.sqrt(x) + tl.tanh(x) → cute.tanh(x) + tl.math.erf(x) → cute.erf(x) + tl.sigmoid(x) → 1/(1+cute.exp(-x)) + tl.maximum(x, y) → cute.where(x > y, x, y) + tl.minimum(x, y) → cute.where(x < y, x, y) + tl.where(c, x, y) → cute.where(c, x, y) + tl.abs(x) → cute.where(x >= 0, x, -x) + Arithmetic (+,-,*,/) → native Python operators on TensorSSA + ext_0 / ext_1 / … → broadcast-loaded from extras list + + Limitation: tl.split / tl.reshape (SwiGLU) are NOT supported in-kernel; + ``reduce_n_by_2=True`` cases fall back to the Triton epilogue path. + """ + if reduce_n_by_2: + return None # SwiGLU split not representable as a simple register op + + # Strip the static-dims header before parsing + code_lines = [l for l in epilogue_code.splitlines() if not l.startswith("# @static:")] + code = "\n".join(code_lines).strip() + if not code or code == "acc = acc": + return None # no-op epilogue — skip + + try: + tree = ast.parse(code, mode="exec") + except SyntaxError: + return None + + # Quick scan: reject unsupported constructs before building the callable + for node in ast.walk(tree): + if isinstance(node, ast.Call): + fn_name = "" + if isinstance(node.func, ast.Attribute): + # e.g. tl.split, tl.reshape → not supported + fn_name = node.func.attr + elif isinstance(node.func, ast.Name): + fn_name = node.func.id + if fn_name in ("split", "reshape"): + return None + + # Build the executable epilogue function via exec() in the CuTe DSL + # namespace. We translate Triton names to their CuTe equivalents by + # injecting a thin shim object ``tl`` that redirects attribute accesses. + fn_src = _emit_cute_epilogue_fn(code_lines, len(extras)) + if fn_src is None: + return None + + ns: dict = {} + exec_globals = {"cute": cute, "cutlass": cutlass} + try: + exec(compile(fn_src, "", "exec"), exec_globals, ns) + except Exception: + return None + + fn = ns.get("_cute_epilogue_fn") + return fn + + +def _emit_cute_epilogue_fn(code_lines: List[str], num_extras: int) -> Optional[str]: + """Emit a Python function that applies the epilogue on a CuTe register tensor. + + The generated function signature is:: + + def _cute_epilogue_fn(acc_vec, extras): + # translated epilogue body + ... + return acc_vec # final result + + ``acc_vec`` is the FP32 ``TensorSSA`` loaded from ``tRS_rD``. + ``extras`` is a list of already-loaded FP32 ``TensorSSA`` slices for each + extra operand (one slice per epi_idx, already broadcast/sliced to the + correct tile). + + Translation rules (Triton → CuTe): + acc → acc_vec + tl.exp(x) → cute.exp(x) + tl.exp2(x) → cute.exp2(x) + tl.log(x) → cute.log(x) + tl.log2(x) → cute.log2(x) + tl.sqrt(x) → cute.sqrt(x) + tl.tanh(x) → cute.tanh(x) + tl.math.erf(x) → cute.erf(x) + tl.sigmoid(x) → 1.0/(1.0+cute.exp(-x)) (emitted inline) + tl.maximum(x,y)→ cute.where(x>y,x,y) + tl.minimum(x,y)→ cute.where(x=0,x,-x) + ext_N → extras[N] (pre-loaded slice) + loads of extra ptrs (ext_N_ptrs / tl.load) → skipped (pre-loaded) + """ + body_lines = [] + + for raw in code_lines: + line = raw.strip() + if not line or line.startswith("#"): + continue + + # Skip the "ext_N_ptrs = ..." and "ext_N = tl.load(...)" lines — + # we supply pre-loaded slices in ``extras`` directly. + if "_ptrs" in line and ("Extra_" in line or "ext_" in line): + continue + # Detect ext_N = tl.load(...) patterns → replace with extras[N] lookup + if line.startswith("ext_") and "= tl.load(" in line: + # e.g. ext_0 = tl.load(ext_0_ptrs, ...) + varname = line.split("=")[0].strip() # "ext_0" + try: + idx = int(varname.split("_")[1]) + except (IndexError, ValueError): + return None + body_lines.append(f" {varname} = extras[{idx}]") + continue + + # Translate the rest + translated = _translate_line(line) + if translated is None: + return None + body_lines.append(f" {translated}") + + # Ensure the function ends with `return acc_vec` + if not any("return" in l for l in body_lines): + body_lines.append(" return acc_vec") + + fn_src = "def _cute_epilogue_fn(acc_vec, extras):\n" + fn_src += "\n".join(body_lines) if body_lines else " pass\n" + fn_src += "\n return acc_vec\n" + return fn_src + + +# ── Line-level Triton → CuTe DSL translator ─────────────────────────────────── + +# Mapping of tl.* / tl.math.* function names to their CuTe equivalents +_TL_TO_CUTE: dict = { + "exp": "cute.exp", + "exp2": "cute.exp2", + "log": "cute.log", + "log2": "cute.log2", + "sqrt": "cute.sqrt", + "rsqrt": "cute.rsqrt", # via cutlass.cute.math + "tanh": "cute.tanh", + "sin": "cute.sin", + "cos": "cute.cos", + "abs": "__cute_abs__", # special-cased + "maximum": "__cute_max__", # special-cased + "minimum": "__cute_min__", # special-cased + "where": "cute.where", + # tl.math.* + "erf": "cute.erf", + "sign": "__cute_sign__", # special-cased +} + +_TL_PASSTHROUGH = frozenset(["maximum", "minimum", "where"]) + + +def _translate_line(line: str) -> Optional[str]: + """Translate a single Triton epilogue line to a CuTe DSL expression. + + Returns the translated line string, or None if untranslatable. + """ + # Replace 'acc' variable (bare or in expressions) with 'acc_vec' + # Use a simple text replacement — won't confuse 'acc' with 'accumulator' etc. + # because the epilogue code only uses 'acc'. + line = _replace_token(line, "acc", "acc_vec") + + # tl.math.erf(x) → cute.erf(x) + line = line.replace("tl.math.erf(", "cute.erf(") + line = line.replace("tl.math.erfc(", "__cute_erfc__(") + line = line.replace("tl.math.erfinv(", "__cute_erfinv__(") + line = line.replace("tl.math.sign(", "__cute_sign__(") + line = line.replace("tl.math.isnan(", "__cute_isnan__(") + line = line.replace("tl.math.isinf(", "__cute_isinf__(") + line = line.replace("tl.math.floor(", "__cute_floor__(") + line = line.replace("tl.math.ceil(", "__cute_ceil__(") + line = line.replace("tl.math.trunc(", "__cute_trunc__(") + line = line.replace("tl.math.round(", "__cute_round__(") + line = line.replace("tl.math.pow(", "__cute_pow__(") + line = line.replace("tl.math.tan(", "__cute_tan__(") + line = line.replace("tl.math.asin(", "__cute_asin__(") + line = line.replace("tl.math.acos(", "__cute_acos__(") + line = line.replace("tl.math.atan(", "__cute_atan__(") + line = line.replace("tl.math.atan2(", "__cute_atan2__(") + line = line.replace("tl.math.sinh(", "__cute_sinh__(") + line = line.replace("tl.math.cosh(", "__cute_cosh__(") + + # tl.abs(x) → cute.where(x >= 0, x, -x) [no native cute.abs] + line = line.replace("tl.abs(", "__cute_abs__(") + + # tl.sigmoid(x) → (1.0/(1.0+cute.exp(-x))) + line = line.replace("tl.sigmoid(", "__cute_sigmoid__(") + + # tl.maximum / tl.minimum / tl.where → cute.where-based + line = line.replace("tl.maximum(", "__cute_max__(") + line = line.replace("tl.minimum(", "__cute_min__(") + line = line.replace("tl.where(", "cute.where(") + + # Standard tl.* math functions + for tl_name, cute_name in _TL_TO_CUTE.items(): + if cute_name.startswith("cute."): + line = line.replace(f"tl.{tl_name}(", f"{cute_name}(") + + # Reject any remaining tl.* calls (unsupported) + if "tl." in line: + return None + + # Expand the __cute_*__ shims inline (simple single-argument forms) + line = _expand_shims(line) + + return line + + +def _replace_token(s: str, old: str, new: str) -> str: + """Replace whole-token occurrences of ``old`` with ``new``.""" + import re + + return re.sub(r'\b' + re.escape(old) + r'\b', new, s) + + +def _expand_shims(line: str) -> str: + """Expand __cute_*__ shims to full CuTe DSL expressions. + + For single-argument shims this is straightforward string replacement. + For multi-argument (max/min) we can't easily parse here, so we emit + helper calls that are defined in the exec namespace. + """ + # These shims are injected into the exec namespace instead + # so no string expansion is needed at this stage — just keep them. + return line + + +def _make_exec_globals() -> dict: + """Build the exec namespace with CuTe DSL helpers for all shims.""" + if not _HAS_CUTLASS: + return {} + + def _cute_abs(x): + zero = cute.full_like(x, 0) + return cute.where(x >= zero, x, -x) + + def _cute_max(x, y): + if isinstance(y, (int, float)): + y = cute.full_like(x, float(y)) + return cute.where(x > y, x, y) + + def _cute_min(x, y): + if isinstance(y, (int, float)): + y = cute.full_like(x, float(y)) + return cute.where(x < y, x, y) + + def _cute_sigmoid(x): + one = cute.full_like(x, 1.0) + return one / (one + cute.exp(-x)) + + def _cute_sign(x): + zero = cute.full_like(x, 0.0) + one = cute.full_like(x, 1.0) + return cute.where(x > zero, one, cute.where(x < zero, -one, zero)) + + def _cute_pow(x, y): + return cute.exp(y * cute.log(x)) + + def _cute_erfc(x): + one = cute.full_like(x, 1.0) + return one - cute.erf(x) + + # Approximate inverse erf (not in CuTe math) + def _cute_erfinv(x): + # Halley approximation — good enough for epilogues + a = cute.full_like(x, 0.147) + pi_a = cute.full_like(x, 2.0 / (3.14159265358979 * 0.147)) + ln_term = cute.log(cute.full_like(x, 1.0) - x * x) + t = cute.sqrt( + cute.sqrt((pi_a + ln_term / cute.full_like(x, 2.0)) ** cute.full_like(x, 2.0) - ln_term / a) + - (pi_a + ln_term / cute.full_like(x, 2.0)) + ) + return cute.where(x >= cute.full_like(x, 0.0), t, -t) + + def _cute_isnan(x): + return x != x + + def _cute_isinf(x): + return cute.where(x != x, cute.full_like(x, 0.0), cute.full_like(x, 1.0)) != cute.full_like(x, 1.0) # placeholder + + def _cute_floor(x): + return cute.exp(cute.full_like(x, 0.0)) * x # placeholder — not in cute.math + + def _cute_ceil(x): + return x + + def _cute_trunc(x): + return x + + def _cute_round(x): + return x + + def _cute_tan(x): + return cute.sin(x) / cute.cos(x) + + def _cute_asin(x): + return cute.math.asin(x) + + def _cute_acos(x): + return cute.math.acos(x) + + def _cute_atan(x): + return cute.math.atan(x) + + def _cute_atan2(x, y): + return cute.math.atan2(x, y) + + def _cute_sinh(x): + ex = cute.exp(x) + return (ex - cute.full_like(x, 1.0) / ex) / cute.full_like(x, 2.0) + + def _cute_cosh(x): + ex = cute.exp(x) + return (ex + cute.full_like(x, 1.0) / ex) / cute.full_like(x, 2.0) + + return { + "cute": cute, + "cutlass": cutlass, + "__cute_abs__": _cute_abs, + "__cute_max__": _cute_max, + "__cute_min__": _cute_min, + "__cute_sigmoid__": _cute_sigmoid, + "__cute_sign__": _cute_sign, + "__cute_pow__": _cute_pow, + "__cute_erfc__": _cute_erfc, + "__cute_erfinv__": _cute_erfinv, + "__cute_isnan__": _cute_isnan, + "__cute_isinf__": _cute_isinf, + "__cute_floor__": _cute_floor, + "__cute_ceil__": _cute_ceil, + "__cute_trunc__": _cute_trunc, + "__cute_round__": _cute_round, + "__cute_tan__": _cute_tan, + "__cute_asin__": _cute_asin, + "__cute_acos__": _cute_acos, + "__cute_atan__": _cute_atan, + "__cute_atan2__": _cute_atan2, + "__cute_sinh__": _cute_sinh, + "__cute_cosh__": _cute_cosh, + } + + +def _compile_epilogue_fn(epilogue_code: str, num_extras: int, reduce_n_by_2: bool) -> Optional[Callable]: + """Compile the epilogue string into a CuTe DSL Python callable. + + Returns None if the epilogue cannot be represented (→ fallback to Triton). + """ + if reduce_n_by_2: + return None + + code_lines = [l for l in epilogue_code.splitlines() if not l.startswith("# @static:")] + code_lines = [l for l in code_lines if l.strip()] + + # Detect extra pointer load patterns and skip them (we inject extras directly) + filtered = [] + for l in code_lines: + stripped = l.strip() + # Skip "ext_N_ptrs = Extra_N_ptr + ..." lines + if "Extra_" in stripped and "_ptrs" in stripped: + continue + # Replace "ext_N = tl.load(ext_N_ptrs, ...)" with "ext_N = extras[N]" + if stripped.startswith("ext_") and "= tl.load(" in stripped: + varname = stripped.split("=")[0].strip() + try: + idx = int(varname.split("_")[1]) + filtered.append(f" {varname} = extras[{idx}]") + except (IndexError, ValueError): + return None + continue + # Translate the line + translated = _translate_line(stripped) + if translated is None: + return None + filtered.append(f" {translated}") + + if not filtered: + return None + + fn_src = "def _cute_epilogue_fn(acc_vec, extras):\n" + fn_src += "\n".join(filtered) + fn_src += "\n return acc_vec\n" + + exec_globals = _make_exec_globals() + ns: dict = {} + try: + exec(compile(fn_src, "", "exec"), exec_globals, ns) + except Exception: + return None + + return ns.get("_cute_epilogue_fn") + + +# ── In-kernel fused GEMM subclass ───────────────────────────────────────────── + +if _HAS_CUTLASS: + + class HopperWgmmaGemmEpilogueFusedKernel(HopperWgmmaGemmPersistentKernel): + """Hopper GEMM with epilogue fused into the accumulator register phase. + + The epilogue is applied on the FP32 accumulator register tensor + *before* it is converted to FP16 and stored, eliminating the extra + global-memory round-trip that a separate Triton epilogue kernel would need. + + Parameters + ---------- + epilogue_fn : callable or None + A CuTe DSL Python function ``fn(acc_vec, extras) -> TensorSSA``. + Compiled from the fusion-pass epilogue string by ``_compile_epilogue_fn``. + When *None*, the behaviour is identical to the base class. + extra_cute_tensors : list[cute.Tensor] + Pre-sliced CuTe tensors for bias / scale operands. One per extra + referenced by the epilogue. Passed through to ``epilogue_fn``. + All other args forwarded to ``HopperWgmmaGemmPersistentKernel.__init__``. + """ + + def __init__( + self, + acc_dtype, + tile_shape_mn, + cluster_shape_mn, + swizzle_size=1, + raster_along_m=True, + epilogue_fn=None, + extra_cute_tensors=None, + ): + super().__init__(acc_dtype, tile_shape_mn, cluster_shape_mn, swizzle_size, raster_along_m) + self._epilogue_fn = epilogue_fn + self._extra_cute_tensors = extra_cute_tensors or [] + + def _apply_epilogue(self, acc_vec): + """Apply the user-supplied epilogue to the FP32 accumulator tile.""" + if self._epilogue_fn is None: + return acc_vec + return self._epilogue_fn(acc_vec, self._extra_cute_tensors) + + # ── Override the GPU kernel to inject the epilogue ───────────────────── + @cute.kernel + def kernel( + self, + tma_atom_a, + mA_mkl, + tma_atom_b, + mB_nkl, + tma_atom_c, + mC_mnl, + tiled_mma, + cta_layout_mnk, + a_smem_layout_staged, + b_smem_layout_staged, + epi_smem_layout_staged, + tile_sched_params, + ): + # ── verbatim copy of the base class kernel body ──────────────────── + # with a single change: acc_vec is passed through _apply_epilogue + # before being stored. + tidx, _, _ = cute.arch.thread_idx() + warp_idx = cute.arch.warp_idx() + warp_idx = cute.arch.make_warp_uniform(warp_idx) + + if warp_idx == 0: + cute.nvgpu.cpasync.prefetch_descriptor(tma_atom_a) + cute.nvgpu.cpasync.prefetch_descriptor(tma_atom_b) + cute.nvgpu.cpasync.prefetch_descriptor(tma_atom_c) + + cta_rank_in_cluster = cute.arch.make_warp_uniform(cute.arch.block_idx_in_cluster()) + cluster_coord_mnk = cta_layout_mnk.get_flat_coord(cta_rank_in_cluster) + + a_mcast_mask = cute.make_layout_image_mask(cta_layout_mnk, cluster_coord_mnk, mode=1) + b_mcast_mask = cute.make_layout_image_mask(cta_layout_mnk, cluster_coord_mnk, mode=0) + + a_mcast_mask = a_mcast_mask if self.is_a_mcast else 0 + b_mcast_mask = b_mcast_mask if self.is_b_mcast else 0 + a_smem_layout = cute.slice_(a_smem_layout_staged, (None, None, 0)) + b_smem_layout = cute.slice_(b_smem_layout_staged, (None, None, 0)) + tma_copy_bytes = cute.size_in_bytes(self.a_dtype, a_smem_layout) + cute.size_in_bytes(self.b_dtype, b_smem_layout) + + import cutlass.pipeline as pipeline + import cutlass.utils as utils_mod + from cutlass.pipeline import pipeline_init_arrive, pipeline_init_wait + + smem = utils_mod.SmemAllocator() + storage = smem.allocate(self.shared_storage) + + mainloop_pipeline_array_ptr = storage.mainloop_pipeline_array_ptr.data_ptr() + mainloop_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread) + mcast_size = self.num_mcast_ctas_a + self.num_mcast_ctas_b - 1 + consumer_arrive_cnt = mcast_size * self.num_mma_warp_groups * self.num_warps_per_warp_group + mainloop_pipeline_consumer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread, consumer_arrive_cnt) + mainloop_pipeline = pipeline.PipelineTmaAsync.create( + barrier_storage=mainloop_pipeline_array_ptr, + num_stages=self.ab_stage, + producer_group=mainloop_pipeline_producer_group, + consumer_group=mainloop_pipeline_consumer_group, + tx_count=tma_copy_bytes, + cta_layout_vmnk=cute.make_layout((1, *cta_layout_mnk.shape)), + defer_sync=True, + ) + + pipeline_init_arrive(cluster_shape_mn=self.cluster_shape_mn, is_relaxed=True) + + sA = storage.sA.get_tensor(a_smem_layout_staged.outer, swizzle=a_smem_layout_staged.inner) + sB = storage.sB.get_tensor(b_smem_layout_staged.outer, swizzle=b_smem_layout_staged.inner) + sC = storage.sC.get_tensor(epi_smem_layout_staged.outer, swizzle=epi_smem_layout_staged.inner) + + gA_mkl = cute.local_tile(mA_mkl, cute.slice_(self.tile_shape_mnk, (None, 0, None)), (None, None, None)) + gB_nkl = cute.local_tile(mB_nkl, cute.slice_(self.tile_shape_mnk, (0, None, None)), (None, None, None)) + gC_mnl = cute.local_tile(mC_mnl, cute.slice_(self.tile_shape_mnk, (None, None, 0)), (None, None, None)) + + a_cta_layout = cute.make_layout(cute.slice_(cta_layout_mnk, (0, None, 0)).shape) + a_cta_crd = cluster_coord_mnk[1] + tAsA, tAgA = cute.nvgpu.cpasync.tma_partition( + tma_atom_a, a_cta_crd, a_cta_layout, cute.group_modes(sA, 0, 2), cute.group_modes(gA_mkl, 0, 2) + ) + + b_cta_layout = cute.make_layout(cute.slice_(cta_layout_mnk, (None, 0, 0)).shape) + b_cta_crd = cluster_coord_mnk[0] + tBsB, tBgB = cute.nvgpu.cpasync.tma_partition( + tma_atom_b, b_cta_crd, b_cta_layout, cute.group_modes(sB, 0, 2), cute.group_modes(gB_nkl, 0, 2) + ) + + warp_group_idx = cute.arch.make_warp_uniform(tidx // self.num_threads_per_warp_group) + mma_warp_group_thread_layout = cute.make_layout(self.num_mma_warp_groups, stride=self.num_threads_per_warp_group) + thr_mma = tiled_mma.get_slice(mma_warp_group_thread_layout(warp_group_idx - self.num_dma_warp_groups)) + + tCsA = thr_mma.partition_A(sA) + tCsB = thr_mma.partition_B(sB) + tCrA = tiled_mma.make_fragment_A(tCsA) + tCrB = tiled_mma.make_fragment_B(tCsB) + + tCgC = thr_mma.partition_C(gC_mnl) + acc_shape = tCgC.shape[:3] + accumulators = cute.make_rmem_tensor(acc_shape, self.acc_dtype) + + k_tile_cnt = cute.size(gA_mkl, mode=[3]) + + pipeline_init_wait(cluster_shape_mn=self.cluster_shape_mn) + + is_dma_warp_group = warp_group_idx < self.num_dma_warp_groups + if is_dma_warp_group: + cute.arch.setmaxregister_decrease(self.load_register_requirement) + + # ── DMA warp group ───────────────────────────────────────────────── + if warp_idx == self.load_warp_id: + tile_sched = utils_mod.StaticPersistentTileScheduler.create( + tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() + ) + work_tile = tile_sched.initial_work_tile_info() + mainloop_producer_state = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.ab_stage) + + while work_tile.is_valid_tile: + tile_coord_mnl = work_tile.tile_idx + tAgA_mkl = tAgA[(None, tile_coord_mnl[0], None, tile_coord_mnl[2])] + tBgB_nkl = tBgB[(None, tile_coord_mnl[1], None, tile_coord_mnl[2])] + mainloop_producer_state.reset_count() + + for k_tile in range(k_tile_cnt): + mainloop_pipeline.producer_acquire(mainloop_producer_state) + tAgA_k = tAgA_mkl[(None, mainloop_producer_state.count)] + tAsA_pipe = tAsA[(None, mainloop_producer_state.index)] + tBgB_k = tBgB_nkl[(None, mainloop_producer_state.count)] + tBsB_pipe = tBsB[(None, mainloop_producer_state.index)] + + cute.copy( + tma_atom_a, + tAgA_k, + tAsA_pipe, + tma_bar_ptr=mainloop_pipeline.producer_get_barrier(mainloop_producer_state), + mcast_mask=a_mcast_mask, + ) + cute.copy( + tma_atom_b, + tBgB_k, + tBsB_pipe, + tma_bar_ptr=mainloop_pipeline.producer_get_barrier(mainloop_producer_state), + mcast_mask=b_mcast_mask, + ) + mainloop_pipeline.producer_commit(mainloop_producer_state) + mainloop_producer_state.advance() + + tile_sched.advance_to_next_work() + work_tile = tile_sched.get_current_work() + + mainloop_pipeline.producer_tail(mainloop_producer_state) + + # ── MMA warp group ───────────────────────────────────────────────── + if not is_dma_warp_group: + cute.arch.setmaxregister_increase(self.mma_register_requirement) + tile_sched = utils_mod.StaticPersistentTileScheduler.create( + tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() + ) + work_tile = tile_sched.initial_work_tile_info() + + mainloop_consumer_read_state = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.ab_stage) + mainloop_consumer_release_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.ab_stage + ) + + num_k_blocks = cute.size(tCrA, mode=[2]) + + import cutlass.utils.hopper_helpers as sm90_utils + + copy_atom_r2s = sm90_utils.sm90_get_smem_store_op( + self.c_layout, elem_ty_d=self.c_dtype, elem_ty_acc=self.acc_dtype + ) + + copy_atom_C = cute.make_copy_atom( + cute.nvgpu.warp.StMatrix8x8x16bOp(self.c_layout.is_m_major_c(), 4), self.c_dtype + ) + tiled_copy_C_Atom = cute.make_tiled_copy_C_atom(copy_atom_C, tiled_mma) + tiled_copy_r2s = cute.make_tiled_copy_S(copy_atom_r2s, tiled_copy_C_Atom) + + thr_copy_r2s = tiled_copy_r2s.get_slice(tidx - self.num_dma_warp_groups * self.num_threads_per_warp_group) + tRS_sD = thr_copy_r2s.partition_D(sC) + tRS_rAcc = tiled_copy_r2s.retile(accumulators) + + rD_shape = cute.shape(thr_copy_r2s.partition_S(sC)) + tRS_rD_layout = cute.make_layout(rD_shape[:3]) + tRS_rD = cute.make_rmem_tensor(tRS_rD_layout.shape, self.acc_dtype) + tRS_rD_out = cute.make_rmem_tensor(tRS_rD_layout.shape, self.c_dtype) + size_tRS_rD = cute.size(tRS_rD) + + k_pipe_mmas = 1 + prologue_mma_cnt = min(k_pipe_mmas, k_tile_cnt) + + tma_store_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread, self.num_mma_threads) + tma_store_pipeline = pipeline.PipelineTmaStore.create( + num_stages=self.epi_stage, producer_group=tma_store_producer_group + ) + + while work_tile.is_valid_tile: + tile_coord_mnl = work_tile.tile_idx + gC_mnl_slice = gC_mnl[(None, None, *tile_coord_mnl)] + + mainloop_consumer_read_state.reset_count() + mainloop_consumer_release_state.reset_count() + accumulators.fill(0.0) + tiled_mma.set(cute.nvgpu.warpgroup.Field.ACCUMULATE, True) + cute.nvgpu.warpgroup.fence() + + for k_tile in range(prologue_mma_cnt): + mainloop_pipeline.consumer_wait(mainloop_consumer_read_state) + for k_block_idx in cutlass.range_constexpr(num_k_blocks): + k_block_coord = (None, None, k_block_idx, mainloop_consumer_read_state.index) + cute.gemm(tiled_mma, accumulators, tCrA[k_block_coord], tCrB[k_block_coord], accumulators) + cute.nvgpu.warpgroup.commit_group() + mainloop_consumer_read_state.advance() + + for k_tile in range(prologue_mma_cnt, k_tile_cnt): + mainloop_pipeline.consumer_wait(mainloop_consumer_read_state) + for k_block_idx in cutlass.range_constexpr(num_k_blocks): + k_block_coord = (None, None, k_block_idx, mainloop_consumer_read_state.index) + cute.gemm(tiled_mma, accumulators, tCrA[k_block_coord], tCrB[k_block_coord], accumulators) + cute.nvgpu.warpgroup.commit_group() + cute.nvgpu.warpgroup.wait_group(k_pipe_mmas) + mainloop_pipeline.consumer_release(mainloop_consumer_release_state) + mainloop_consumer_release_state.advance() + mainloop_consumer_read_state.advance() + + cute.nvgpu.warpgroup.wait_group(0) + for k_tile in range(prologue_mma_cnt): + mainloop_pipeline.consumer_release(mainloop_consumer_release_state) + mainloop_consumer_release_state.advance() + + # Epilogue + tCgC_for_tma_partition = cute.zipped_divide(gC_mnl_slice, self.epi_tile) + bSG_sD, bSG_gD = cute.nvgpu.cpasync.tma_partition( + tma_atom_c, 0, cute.make_layout(1), cute.group_modes(sC, 0, 2), tCgC_for_tma_partition + ) + epi_tile_num = cute.size(tCgC_for_tma_partition, mode=[1]) + epi_tile_shape = tCgC_for_tma_partition.shape[1] + epi_tile_layout = cute.make_layout(epi_tile_shape, stride=(epi_tile_shape[1], 1)) + num_prev_epi_tiles = tile_sched.num_tiles_executed * epi_tile_num + + for epi_idx in cutlass.range_constexpr(epi_tile_num): + for epi_v in cutlass.range_constexpr(size_tRS_rD): + tRS_rD[epi_v] = tRS_rAcc[epi_idx * size_tRS_rD + epi_v] + + # ── Load FP32 accumulator tile ───────────────────────── + acc_vec = tRS_rD.load() + + # ── FUSED EPILOGUE: apply in registers ───────────────── + acc_vec = self._apply_epilogue(acc_vec) + + # ── Convert to output dtype and store ────────────────── + tRS_rD_out.store(acc_vec.to(self.c_dtype)) + + epi_buffer = (num_prev_epi_tiles + epi_idx) % cute.size(tRS_sD, mode=[3]) + cute.copy(tiled_copy_r2s, tRS_rD_out, tRS_sD[(None, None, None, epi_buffer)]) + cute.arch.fence_proxy("async.shared", space="cta") + self.epilog_sync_barrier.arrive_and_wait() + + gmem_coord = epi_tile_layout.get_hier_coord(epi_idx) + if warp_idx == self.epi_store_warp_id: + cute.copy(tma_atom_c, bSG_sD[(None, epi_buffer)], bSG_gD[(None, gmem_coord)]) + tma_store_pipeline.producer_commit() + tma_store_pipeline.producer_acquire() + + self.epilog_sync_barrier.arrive_and_wait() + + tile_sched.advance_to_next_work() + work_tile = tile_sched.get_current_work() + + tma_store_pipeline.producer_tail() + + +# ── Two-level fused GEMM cache ───────────────────────────────────────────────── +# +# Shape-polymorphism strategy +# --------------------------- +# ``cute.compile()`` with ``is_dynamic_layout=True`` produces a kernel binary +# that is polymorphic in the M dimension: a kernel compiled for template M=128 +# can be called at runtime for any M (verified experimentally). N and K are +# typically static (weight-matrix dimensions) while M = batch×seq_len varies. +# +# We therefore split the cache into two levels: +# +# _COMPILED_CACHE key: (N, K, epilogue_code, num_extras, reduce_n_by_2) +# value: _CompiledEntry (compiled_gemm) +# → populated once, reused for every new M +# +# _BUFFER_CACHE key: (M, N, K) +# value: _BufferEntry (a/b/c aligned device buffers + CuTe +# descriptors for the specific M) +# → populated once per unique M, much cheaper than recompile +# +# This ensures ``cute.compile()`` is called at most once per (N,K,...) config +# regardless of how many distinct M values appear at runtime. + + +@dataclass +class _CompiledEntry: + """Compiled CuTe kernel — shape-polymorphic in the M dimension.""" + + compiled_gemm: object # result of cute.compile(...) + max_active_clusters: int # baked at compile time (HW-dependent constant) + + +@dataclass +class _BufferEntry: + """Aligned device buffers and CuTe descriptors for a specific (M, N, K).""" + + a_cute: object + a_ref: torch.Tensor # (M, K, 1) — input A + b_cute: object + b_ref: torch.Tensor # (N, K, 1) — input B (transposed) + c_cute: object + c_ref: torch.Tensor # (M, N, 1) — output C + + +_COMPILED_CACHE: dict = {} # (N, K, epi_code, num_extras, reduce_n) → _CompiledEntry | None +_BUFFER_CACHE: dict = {} # (M, N, K) → _BufferEntry + +_TILE_MN = (128, 256) +_CLUSTER_MN = (1, 1) +# Template M used for cute.compile(); the compiled kernel runs for any M. +_TEMPLATE_M = 128 + + +def _compile_kernel(N: int, K: int, epilogue_fn, extra_cute_tensors: list) -> Optional[_CompiledEntry]: + """Compile the fused GEMM kernel for fixed (N, K); polymorphic in M. + + Uses ``_TEMPLATE_M`` as a placeholder M during compilation — the resulting + binary runs correctly for any M because ``is_dynamic_layout=True`` keeps + M out of any ``Constexpr`` baked values. + + Returns None on any compilation failure. + """ + if not _HAS_CUTLASS: + return None + if K % 8 != 0 or N % 8 != 0: + return None + + M = _TEMPLATE_M + l = 1 + a_dtype = cutlass.Float16 + b_dtype = cutlass.Float16 + c_dtype = cutlass.Float16 + acc_dtype = cutlass.Float32 + + a_cpu = cutlass_torch.matrix(l, M, K, False, a_dtype) + b_cpu = cutlass_torch.matrix(l, N, K, False, b_dtype) + c_cpu = cutlass_torch.matrix(l, M, N, False, c_dtype) + + a_cute, _ = cutlass_torch.cute_tensor_like(a_cpu, a_dtype, is_dynamic_layout=True, assumed_align=16) + b_cute, _ = cutlass_torch.cute_tensor_like(b_cpu, b_dtype, is_dynamic_layout=True, assumed_align=16) + c_cute, _ = cutlass_torch.cute_tensor_like(c_cpu, c_dtype, is_dynamic_layout=True, assumed_align=16) + + gemm = HopperWgmmaGemmEpilogueFusedKernel( + acc_dtype, + _TILE_MN, + _CLUSTER_MN, + swizzle_size=1, + raster_along_m=True, + epilogue_fn=epilogue_fn, + extra_cute_tensors=extra_cute_tensors, + ) + + hw = cutlass.utils.HardwareInfo() + mac = hw.get_max_active_clusters(_CLUSTER_MN[0] * _CLUSTER_MN[1]) + cu_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) + + try: + compiled_gemm = cute.compile(gemm, a_cute, b_cute, c_cute, mac, cu_stream) + except Exception: + return None + + return _CompiledEntry(compiled_gemm=compiled_gemm, max_active_clusters=mac) + + +def _get_or_create_buffers(M: int, N: int, K: int) -> Optional[_BufferEntry]: + """Return pre-allocated aligned CuTe buffers for the given (M, N, K). + + Allocates once per unique (M, N, K) and caches the result. Allocation is + much cheaper than ``cute.compile()`` but still non-trivial (GPU malloc + + CuTe descriptor creation), so caching across calls with the same shape is + important for training loops where M is fixed per microbatch. + """ + buf_key = (M, N, K) + if buf_key in _BUFFER_CACHE: + return _BUFFER_CACHE[buf_key] + + if not _HAS_CUTLASS: + return None + + l = 1 + a_dtype = cutlass.Float16 + b_dtype = cutlass.Float16 + c_dtype = cutlass.Float16 + + a_cpu = cutlass_torch.matrix(l, M, K, False, a_dtype) + b_cpu = cutlass_torch.matrix(l, N, K, False, b_dtype) + c_cpu = cutlass_torch.matrix(l, M, N, False, c_dtype) + + try: + a_cute, a_ref = cutlass_torch.cute_tensor_like(a_cpu, a_dtype, is_dynamic_layout=True, assumed_align=16) + b_cute, b_ref = cutlass_torch.cute_tensor_like(b_cpu, b_dtype, is_dynamic_layout=True, assumed_align=16) + c_cute, c_ref = cutlass_torch.cute_tensor_like(c_cpu, c_dtype, is_dynamic_layout=True, assumed_align=16) + except Exception: + _BUFFER_CACHE[buf_key] = None + return None + + entry = _BufferEntry(a_cute=a_cute, a_ref=a_ref, b_cute=b_cute, b_ref=b_ref, c_cute=c_cute, c_ref=c_ref) + _BUFFER_CACHE[buf_key] = entry + return entry + + +def _compiled_cache_key(N, K, epilogue_code, num_extras, reduce_n_by_2): + """Cache key for the compiled kernel — M-independent.""" + return (N, K, epilogue_code, num_extras, reduce_n_by_2) + + +# ── Public API ───────────────────────────────────────────────────────────────── + + +def matmul_cute_custom_epilogue( + A: torch.Tensor, B: torch.Tensor, extras: list, epilogue_code: str, reduce_n_by_2: bool +) -> torch.Tensor: + """Run GEMM + epilogue fully fused in the CuTe Hopper kernel. + + The epilogue is applied on the FP32 accumulator register file *before* + type conversion and TMA store, saving one full read of the (M×N) result + from global memory compared to a separate Triton epilogue pass. + + Shape-polymorphic caching + ------------------------- + ``cute.compile()`` is called **at most once** per unique (N, K, epilogue) + configuration regardless of how many distinct M values appear at runtime. + For a typical transformer, N and K are static weight-matrix dimensions + while M = batch×seq_len varies freely; this strategy ensures the expensive + JIT compilation cost is paid only once per layer, not per step. + + At FX graph level, static dims satisfy ``type(d) is int`` on + ``node.meta["val"].shape``; dynamic dims are ``torch.SymInt``. This + function exploits that structure automatically via the two-level cache. + + Falls back to ``matmul_custom_epilogue`` (Triton TMA-persistent) when: + - Not running on Hopper (SM < 90), or + - ``cutlass-dsl`` is not installed, or + - The epilogue contains constructs not representable as CuTe register ops + (e.g. SwiGLU ``tl.split``), or + - The problem dimensions violate 16-byte alignment requirements. + + Parameters + ---------- + A : torch.Tensor — (M, K) FP16 row-major + B : torch.Tensor — (K, N) FP16 row-major + extras : list[torch.Tensor] + Additional bias / scale tensors referenced by the epilogue. + epilogue_code : str + Triton epilogue snippet from the fusion pass. + reduce_n_by_2 : bool + True for SwiGLU (output N = input N / 2). + """ + M, K = A.shape + _, N = B.shape + + if not _HAS_CUTLASS: + return matmul_custom_epilogue(A, B, extras, epilogue_code, reduce_n_by_2) + + # ── Level-1: compiled kernel lookup (expensive; M-independent) ──────────── + compile_key = _compiled_cache_key(N, K, epilogue_code, len(extras), reduce_n_by_2) + + if compile_key not in _COMPILED_CACHE: + epi_fn = _compile_epilogue_fn(epilogue_code, len(extras), reduce_n_by_2) + + if epi_fn is None: + _COMPILED_CACHE[compile_key] = None + else: + extra_cute = [] + for t in extras: + try: + from cutlass.cute.runtime import from_dlpack + + extra_cute.append(from_dlpack(t, assumed_align=16)) + except Exception: + extra_cute = None + break + + if extra_cute is None: + _COMPILED_CACHE[compile_key] = None + else: + compiled_entry = _compile_kernel(N, K, epi_fn, extra_cute) + _COMPILED_CACHE[compile_key] = compiled_entry # None on failure + + compiled_entry = _COMPILED_CACHE.get(compile_key) + if compiled_entry is None: + return matmul_custom_epilogue(A, B, extras, epilogue_code, reduce_n_by_2) + + # ── Level-2: buffer lookup (cheap; once per unique M) ───────────────────── + buf = _get_or_create_buffers(M, N, K) + if buf is None: + return matmul_custom_epilogue(A, B, extras, epilogue_code, reduce_n_by_2) + + # ── Copy input data into aligned CuTe buffers ────────────────────────────── + buf.a_ref.copy_(A.unsqueeze(2)) + buf.b_ref.copy_(B.T.contiguous().unsqueeze(2)) + + # ── Run the fused CuTe kernel ────────────────────────────────────────────── + cu_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) + compiled_entry.compiled_gemm(buf.a_cute, buf.b_cute, buf.c_cute, cu_stream) + + # ── Extract result ───────────────────────────────────────────────────────── + N_out = N // 2 if reduce_n_by_2 else N + elem_size = A.element_size() + align_elems = 128 // elem_size + N_stride = (N_out + align_elems - 1) // align_elems * align_elems + D = torch.empty((M, N_stride), device=A.device, dtype=A.dtype)[:, :N_out] + + # c_ref layout is (M, N, 1); the kernel writes into it via TMA store + D.copy_(buf.c_ref[:, :N_out, 0]) + return D diff --git a/magi_compiler/passes/piecewise_graph/fusion/matmul_epilogue_fusion.py b/magi_compiler/passes/piecewise_graph/fusion/matmul_epilogue_fusion.py index ecc271f..e7c4704 100644 --- a/magi_compiler/passes/piecewise_graph/fusion/matmul_epilogue_fusion.py +++ b/magi_compiler/passes/piecewise_graph/fusion/matmul_epilogue_fusion.py @@ -21,10 +21,12 @@ from magi_compiler.passes.pass_base import MagiInductorPass +from .cute_kernel import _HAS_CUTLASS, matmul_cute_custom_epilogue from .triton_kernels import matmul_custom_epilogue _LIB = torch.library.Library("magi_epilogue", "DEF") _LIB.define("matmul_custom(Tensor A, Tensor B, Tensor[] extras, str epilogue_code, bool reduce_n_by_2) -> Tensor") +_LIB.define("matmul_custom_cute(Tensor A, Tensor B, Tensor[] extras, str epilogue_code, bool reduce_n_by_2) -> Tensor") @torch.library.impl(_LIB, "matmul_custom", "CUDA") @@ -32,18 +34,31 @@ def _matmul_custom_cuda(A, B, extras, epilogue_code, reduce_n_by_2): return matmul_custom_epilogue(A, B, extras, epilogue_code, reduce_n_by_2) -@torch.library.register_fake("magi_epilogue::matmul_custom") -def _matmul_custom_abstract(A, B, extras, epilogue_code, reduce_n_by_2): +@torch.library.impl(_LIB, "matmul_custom_cute", "CUDA") +def _matmul_custom_cute_cuda(A, B, extras, epilogue_code, reduce_n_by_2): + return matmul_cute_custom_epilogue(A, B, extras, epilogue_code, reduce_n_by_2) + + +def _matmul_abstract_shape(A, B, reduce_n_by_2): + """Shared shape + stride logic for both torch.library fake impls.""" N_out = B.shape[1] // 2 if reduce_n_by_2 else B.shape[1] # Mirror the 128-byte-aligned row stride used by the real kernel so that # Inductor's assert_size_stride matches what we actually return. - # Keep the logical shape as (M, N_out) — changing it would interfere with - # Inductor's own K-dimension padding for the downstream mm. align_elems = 128 // A.element_size() N_stride = (N_out + align_elems - 1) // align_elems * align_elems return A.new_empty_strided((A.shape[0], N_out), (N_stride, 1)) +@torch.library.register_fake("magi_epilogue::matmul_custom") +def _matmul_custom_abstract(A, B, extras, epilogue_code, reduce_n_by_2): + return _matmul_abstract_shape(A, B, reduce_n_by_2) + + +@torch.library.register_fake("magi_epilogue::matmul_custom_cute") +def _matmul_custom_cute_abstract(A, B, extras, epilogue_code, reduce_n_by_2): + return _matmul_abstract_shape(A, B, reduce_n_by_2) + + # ── Triton expression templates ──────────────────────────────────────────────── # Unary elementwise ops: {x} = operand expression string _UNARY_EXPRS = { @@ -179,13 +194,39 @@ def __call__(self, graph: fx.Graph) -> bool: fused = 0 for node in list(graph.nodes): if node.op == "call_function" and node.target in (torch.ops.aten.mm.default, torch.ops.aten.mm): - fused += self._try_fuse_custom_chain(graph, node) + # Prefer the CuTe path on Hopper; fall back to Triton-only. + if _HAS_CUTLASS: + fused += self._try_fuse_custom_chain_cute(graph, node) + else: + fused += self._try_fuse_custom_chain(graph, node) if fused: graph.eliminate_dead_code() return fused > 0 - def _try_fuse_custom_chain(self, graph: fx.Graph, mm_node: fx.Node) -> int: + def _try_fuse_custom_chain_cute(self, graph: fx.Graph, mm_node: fx.Node) -> int: + """Like ``_try_fuse_custom_chain`` but emits ``matmul_custom_cute``. + + Uses ``HopperWgmmaGemmPersistentKernel`` for the GEMM and a separate + Triton kernel for the epilogue. The epilogue code string is identical + to the one produced by ``_try_fuse_custom_chain`` so the two methods + share the same generation logic — only the dispatched op differs. + """ + return self._try_fuse_custom_chain(graph, mm_node, op=torch.ops.magi_epilogue.matmul_custom_cute.default) + + def _try_fuse_custom_chain(self, graph: fx.Graph, mm_node: fx.Node, *, op=None) -> int: + """Fuse a chain of elementwise ops following *mm_node* into a single kernel. + + Parameters + ---------- + op : callable, optional + The dispatch target to call in the fused graph node. Defaults to + ``torch.ops.magi_epilogue.matmul_custom.default`` (pure Triton). + Pass ``torch.ops.magi_epilogue.matmul_custom_cute.default`` to use + the CuTe GEMM path instead. + """ + if op is None: + op = torch.ops.magi_epilogue.matmul_custom.default A, B = mm_node.args fused_nodes = {mm_node: "acc"} @@ -417,9 +458,7 @@ def get_val(arg): epilogue_code = f"# @static:{json.dumps(static_dims, separators=(',', ':'))}\n" + epilogue_code with graph.inserting_after(last_fused_node): - fused_node = graph.call_function( - torch.ops.magi_epilogue.matmul_custom.default, args=(A, B, extras, epilogue_code, is_swiglu) - ) + fused_node = graph.call_function(op, args=(A, B, extras, epilogue_code, is_swiglu)) if "val" in last_fused_node.meta: val = last_fused_node.meta["val"] # Propagate the 128-byte-aligned row stride so downstream From ea5cc68c9d9a8cc2345bf6208a61114a6c96f854 Mon Sep 17 00:00:00 2001 From: wtr Date: Tue, 28 Apr 2026 20:00:47 +0800 Subject: [PATCH 3/7] [Feat] Add CUTLASS matmul-epilogue fusion path for sm_120 --- .../fusion/blackwell_geforce/__init__.py | 13 + .../cutlass_kernels/swiglu7_combine.h | 130 ++ .../cutlass_kernels/swiglu7_epi_one_stage.cu | 371 ++++++ .../fusion/blackwell_geforce/evt_codegen.py | 852 +++++++++++++ .../fusion/blackwell_geforce/evt_ir.py | 242 ++++ .../fusion/blackwell_geforce/evt_runtime.py | 583 +++++++++ .../matmul_epilogue_fusion.py | 716 +++++++++++ .../piecewise_graph/fusion/cute_kernel.py | 1080 ----------------- .../fusion/matmul_epilogue_fusion.py | 482 -------- .../piecewise_graph/fusion/triton_kernels.py | 582 --------- .../piecewise_graph/post_grad_pass_manager.py | 19 +- .../test_matmul_epilogue_fusion.py | 539 ++++++-- 12 files changed, 3365 insertions(+), 2244 deletions(-) create mode 100644 magi_compiler/passes/piecewise_graph/fusion/blackwell_geforce/__init__.py create mode 100644 magi_compiler/passes/piecewise_graph/fusion/blackwell_geforce/cutlass_kernels/swiglu7_combine.h create mode 100644 magi_compiler/passes/piecewise_graph/fusion/blackwell_geforce/cutlass_kernels/swiglu7_epi_one_stage.cu create mode 100644 magi_compiler/passes/piecewise_graph/fusion/blackwell_geforce/evt_codegen.py create mode 100644 magi_compiler/passes/piecewise_graph/fusion/blackwell_geforce/evt_ir.py create mode 100644 magi_compiler/passes/piecewise_graph/fusion/blackwell_geforce/evt_runtime.py create mode 100644 magi_compiler/passes/piecewise_graph/fusion/blackwell_geforce/matmul_epilogue_fusion.py delete mode 100644 magi_compiler/passes/piecewise_graph/fusion/cute_kernel.py delete mode 100644 magi_compiler/passes/piecewise_graph/fusion/matmul_epilogue_fusion.py delete mode 100644 magi_compiler/passes/piecewise_graph/fusion/triton_kernels.py diff --git a/magi_compiler/passes/piecewise_graph/fusion/blackwell_geforce/__init__.py b/magi_compiler/passes/piecewise_graph/fusion/blackwell_geforce/__init__.py new file mode 100644 index 0000000..3eaa44a --- /dev/null +++ b/magi_compiler/passes/piecewise_graph/fusion/blackwell_geforce/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2026 SandAI. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/magi_compiler/passes/piecewise_graph/fusion/blackwell_geforce/cutlass_kernels/swiglu7_combine.h b/magi_compiler/passes/piecewise_graph/fusion/blackwell_geforce/cutlass_kernels/swiglu7_combine.h new file mode 100644 index 0000000..631a490 --- /dev/null +++ b/magi_compiler/passes/piecewise_graph/fusion/blackwell_geforce/cutlass_kernels/swiglu7_combine.h @@ -0,0 +1,130 @@ +// Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: BSD-3-Clause +// +// Binary epilogue combine functor for the swiglu7 DualGemm fusion. +// +// D = silu_alpha( clamp(lhs, max=limit) ) * ( clamp(rhs, -limit, limit) + 1 ) +// +// silu_alpha(x) = x * sigmoid(alpha * x) alpha = 1.702, limit = 7.0 +// +// `lhs` is the gate-path output fragment (Op0 applied to A @ W_gate.T), +// `rhs` is the linear-path output fragment (Op1 applied to A @ W_linear.T). +// Both arrive as ElementOutput (bf16) fragments — this is dictated by the +// dual-epilogue call site (examples/45_dual_gemm/threadblock/dual_epilogue.h:413 +// passes `output_frag_ptr[0][i]` and `[1][i]`, which are post-conversion +// output-type fragments, not raw accumulator fragments). The combine upcasts +// to ElementCompute (fp32) internally, evaluates the swiglu7 expression, and +// converts back to bf16. +// +// Note on precision: the gate/linear matmuls accumulate in fp32 inside the +// MMAs. Op0/Op1 (LinearCombination, ScaleType::Nothing) downcast those fp32 +// accumulators to bf16 before this combine runs. The swiglu7 math itself +// stays in fp32 here, so the only extra precision loss vs the two-stage EVT +// version is the single fp32→bf16 round-trip on each accumulator at the +// epilogue boundary. Empirically this is well within the bf16 noise floor. +// +// Modelled on cutlass/examples/45_dual_gemm/thread/left_silu_and_mul.h — +// same interface contract: ElementOutput / ElementAccumulator / ElementCompute +// typedefs, kCount fragment width, empty Params, two operator() overloads +// (fragment + scalar), is_source_needed() returning true. + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" +#include "cutlass/array.h" +#include "cutlass/functional.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/epilogue/thread/activation.h" + +namespace cutlass { +namespace epilogue { +namespace thread { + +template < + typename ElementOutput_, + int Count, + typename ElementAccumulator_ = ElementOutput_, + typename ElementCompute_ = ElementOutput_, + FloatRoundStyle Round = FloatRoundStyle::round_to_nearest> +class Swiglu7Combine { +public: + + using ElementOutput = ElementOutput_; + using ElementAccumulator = ElementAccumulator_; + using ElementCompute = ElementCompute_; + + static int const kCount = Count; + using FragmentOutput = Array; + using FragmentAccumulator = Array; + using ComputeFragment = Array; + + static FloatRoundStyle const kRound = Round; + + struct Params {}; + +public: + + CUTLASS_HOST_DEVICE + Swiglu7Combine(Params const& /*params*/) {} + + CUTLASS_HOST_DEVICE + bool is_source_needed() const { return true; } + + CUTLASS_HOST_DEVICE + void set_k_partition(int /*k_partition*/, int /*k_partition_count*/) { + // swiglu7 cannot be split-K-reduced (non-linear epilogue). + assert(false); + } + + // Fragment-level. lhs = gate output fragment (bf16, post Op0), + // rhs = linear output fragment (bf16, post Op1). + CUTLASS_HOST_DEVICE + FragmentOutput operator()(FragmentOutput const& lhs, + FragmentOutput const& rhs) const { + NumericArrayConverter in2c; + NumericArrayConverter c2o; + + ComputeFragment gate = in2c(lhs); + ComputeFragment lin = in2c(rhs); + ComputeFragment out; + + Sigmoid sig; + ElementCompute const limit(7.0f); + ElementCompute const nlimit(-7.0f); + ElementCompute const alpha(1.702f); + ElementCompute const one(1.0f); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kCount; ++i) { + ElementCompute g = gate[i] < limit ? gate[i] : limit; + ElementCompute r = lin[i] < nlimit ? nlimit + : (lin[i] > limit ? limit : lin[i]); + ElementCompute silu_g = g * sig(alpha * g); + out[i] = silu_g * (r + one); + } + return c2o(out); + } + + // Scalar overload — required by the DualGemm epilogue boilerplate. + CUTLASS_HOST_DEVICE + ElementOutput operator()(ElementOutput const& lhs, + ElementOutput const& rhs) const { + ElementCompute g(lhs), r(rhs); + ElementCompute const limit(7.0f); + ElementCompute const nlimit(-7.0f); + ElementCompute const alpha(1.702f); + ElementCompute const one(1.0f); + + Sigmoid sig; + + g = g < limit ? g : limit; + r = r < nlimit ? nlimit : (r > limit ? limit : r); + ElementCompute silu_g = g * sig(alpha * g); + return ElementOutput(silu_g * (r + one)); + } +}; + +} // namespace thread +} // namespace epilogue +} // namespace cutlass diff --git a/magi_compiler/passes/piecewise_graph/fusion/blackwell_geforce/cutlass_kernels/swiglu7_epi_one_stage.cu b/magi_compiler/passes/piecewise_graph/fusion/blackwell_geforce/cutlass_kernels/swiglu7_epi_one_stage.cu new file mode 100644 index 0000000..3be0203 --- /dev/null +++ b/magi_compiler/passes/piecewise_graph/fusion/blackwell_geforce/cutlass_kernels/swiglu7_epi_one_stage.cu @@ -0,0 +1,371 @@ +// Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: BSD-3-Clause +// +// Single-kernel fully-fused swiglu7: +// +// D = swiglu7(A @ B.T) +// +// A : (M, K) bf16 row-major +// B : (N, K) bf16 row-major (torch.nn.Linear weight convention; N even) +// D : (M, N/2) bf16 row-major +// +// Implementation uses cutlass::gemm::device::DualGemm — the two GEMMs +// A @ W_gate.T and A @ W_linear.T run in the same threadblock sharing A's +// smem stages; their accumulators stay in registers and a custom +// Swiglu7Combine epilogue functor combines them and writes only D. +// +// AUTOTUNE: at first call per (M, N, K) tuple the runner times every +// registered (TileShape, WarpShape, Stages) candidate and caches the +// fastest one. The candidate set is hand-tuned for RTX 5090 (sm_120) +// — see register_candidates() for the rationale and SMEM budget math. + +#include +#include + +#include +#include + +#include +#include +#include + +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/epilogue/thread/scale_type.h" +#include "cutlass/util/host_tensor.h" + +#include "45_dual_gemm/device/dual_gemm.h" +#include "swiglu7_combine.h" + +//////////////////////////////////////////////////////////////////////////////// +// Data types +//////////////////////////////////////////////////////////////////////////////// + +using ElementA = cutlass::bfloat16_t; +using ElementB = cutlass::bfloat16_t; +using ElementC = cutlass::bfloat16_t; +using ElementAcc = float; +using ElementCompute = float; + +using LayoutA = cutlass::layout::RowMajor; +using LayoutB0 = cutlass::layout::ColumnMajor; // strided ldB = 2K view +using LayoutB1 = cutlass::layout::ColumnMajor; // strided ldB = 2K view +using LayoutC = cutlass::layout::RowMajor; + +constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; // = 8 +constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; // = 8 +// Output vector width = 4 (bf16, 8 bytes) so any N_out divisible by 4 is OK +// — N=27304 → N_out=13652 is 4-aligned but not 8-aligned. +constexpr int EpilogueVecCount = 4; + +using ArchTag = cutlass::arch::Sm80; +using OperatorClass = cutlass::arch::OpClassTensorOp; +using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; + +constexpr auto kScaleType = cutlass::epilogue::thread::ScaleType::Nothing; +constexpr bool kSplitKSerial = false; +constexpr bool kStoreD0 = false; +constexpr bool kStoreD1 = false; + +//////////////////////////////////////////////////////////////////////////////// +// Per-tile DualGemm wrapper. The DualGemm device type is templated on +// (TileShape, WarpShape, Stages) — every autotune candidate instantiates the +// full kernel for its tuple. Compile time grows linearly with candidate count +// but DualGemm Sm80 is much cheaper to compile than the EVT path (no visitor +// tree), so we can afford 8–10 candidates. +//////////////////////////////////////////////////////////////////////////////// + +template +struct DualGemmConfig { + using EpilogueOp0 = cutlass::epilogue::thread::LinearCombination< + ElementC, EpilogueVecCount, ElementAcc, ElementCompute, kScaleType>; + using EpilogueOp1 = cutlass::epilogue::thread::LinearCombination< + ElementC, EpilogueVecCount, ElementAcc, ElementCompute, kScaleType>; + using EpilogueOp2 = cutlass::epilogue::thread::Swiglu7Combine< + ElementC, EpilogueVecCount, ElementAcc, ElementCompute>; + + using Gemm = cutlass::gemm::device::DualGemm< + ElementA, LayoutA, + ElementB, LayoutB0, LayoutB1, + ElementC, LayoutC, + ElementAcc, + OperatorClass, ArchTag, + TbShape, WaShape, InstructionShape, + EpilogueOp0, EpilogueOp1, EpilogueOp2, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + Stages, + kStoreD0, kStoreD1, kSplitKSerial, + AlignmentA, AlignmentB>; +}; + +//////////////////////////////////////////////////////////////////////////////// +// Type-erased runner concept; one instance per autotune candidate. +//////////////////////////////////////////////////////////////////////////////// + +struct Sw7Args { + int M; // activations rows + int N_out; // = N/2 (output cols) + int K; + void* ptr_A; + void* ptr_B; // (N, K) row-major weight; gate/linear interleaved + void* ptr_D; // (M, N_out) +}; + +class Sw7Concept { + public: + virtual ~Sw7Concept() = default; + virtual size_t get_workspace_size(const Sw7Args&) = 0; + virtual cutlass::Status initialize(const Sw7Args&, void* ws, cudaStream_t) = 0; + virtual cutlass::Status run(cudaStream_t stream) = 0; + virtual const char* name() const = 0; +}; + +template +class Sw7Impl : public Sw7Concept { + public: + using GemmType = typename Cfg::Gemm; + using EpilogueOp0 = typename Cfg::EpilogueOp0; + using EpilogueOp1 = typename Cfg::EpilogueOp1; + using EpilogueOp2 = typename Cfg::EpilogueOp2; + + explicit Sw7Impl(const char* name) : name_(name) {} + + typename GemmType::Arguments make_args(const Sw7Args& a) { + auto ptrA = reinterpret_cast(a.ptr_A); + auto ptrB = reinterpret_cast(a.ptr_B); + auto ptrD = reinterpret_cast(a.ptr_D); + int const M = a.M, N_out = a.N_out, K = a.K; + + int64_t const ldB_strided = static_cast(2) * K; + LayoutB0 layoutB_gate(ldB_strided); + LayoutB1 layoutB_linear(ldB_strided); + LayoutC layoutC(static_cast(N_out)); + + using TensorRefA = cutlass::TensorRef; + using TensorRefB0 = cutlass::TensorRef; + using TensorRefB1 = cutlass::TensorRef; + using TensorRefCi = cutlass::TensorRef; + using TensorRefDo = cutlass::TensorRef; + + TensorRefA ref_A0(ptrA, LayoutA(static_cast(K))); + TensorRefB0 ref_B0(ptrB, layoutB_gate); + TensorRefCi ref_C0(nullptr, LayoutC(0)); + TensorRefDo ref_D0(nullptr, LayoutC(0)); + TensorRefB1 ref_B1(ptrB + K, layoutB_linear); + TensorRefCi ref_C1(nullptr, LayoutC(0)); + TensorRefDo ref_D1(nullptr, LayoutC(0)); + TensorRefDo ref_D2(ptrD, layoutC); + + typename EpilogueOp0::Params epi0{ElementCompute(1.0f), ElementCompute(0.0f)}; + typename EpilogueOp1::Params epi1{ElementCompute(1.0f), ElementCompute(0.0f)}; + typename EpilogueOp2::Params epi2{}; + + cutlass::gemm::GemmCoord problem{M, N_out, K}; + typename GemmType::Arguments args( + cutlass::gemm::DualGemmMode::kGemm, + problem, + ref_A0, + ref_B0, ref_C0, ref_D0, + ref_B1, ref_C1, ref_D1, + ref_D2, + epi0, epi1, epi2, + /*split_k_slices=*/1, + /*batch_count=*/1, + /*batch_stride_A=*/0, + /*batch_stride_B0=*/0, + /*batch_stride_B1=*/0, + /*batch_stride_C=*/0, + /*batch_stride_D=*/0); + return args; + } + + size_t get_workspace_size(const Sw7Args& a) override { + return GemmType::get_workspace_size(make_args(a)); + } + cutlass::Status initialize(const Sw7Args& a, void* ws, cudaStream_t s) override { + return gemm_.initialize(make_args(a), ws, s); + } + cutlass::Status run(cudaStream_t stream) override { + return gemm_.run(stream); + } + const char* name() const override { return name_; } + + private: + GemmType gemm_; + const char* name_; +}; + +//////////////////////////////////////////////////////////////////////////////// +// AutoTune runner — first call per (M, N_out, K) shape times all candidates. +//////////////////////////////////////////////////////////////////////////////// + +#define SW7_TILE(tb_m, tb_n, tb_k, wa_m, wa_n, wa_k, stages, label) \ + configs_.push_back(std::make_unique< \ + Sw7Impl, \ + cutlass::gemm::GemmShape, \ + stages>>>(label)) + +class Sw7AutoTuneRunner { + public: + Sw7AutoTuneRunner() { + // Tile candidates for RTX 5090 (sm_120, 100 KB SMEM/SM, 170 SMs). + // + // SMEM cost for DualGemm = (BM + 2*BN) * BK * 2B * stages because both + // B operands live in smem simultaneously. Budget cap ~96 KB. + // + // Bucket of M doesn't drive a separate .cu here — DualGemm compiles + // fast enough that one runner with all candidates handles every M, and + // the per-shape cache picks the best for whatever M it sees. + + // ── Small / decode-friendly tiles ──────────────────────────────────────── + SW7_TILE(64, 64, 32, 32, 32, 32, 4, "T<64,64,32>_S4"); // 36 KB + SW7_TILE(64, 64, 64, 32, 32, 64, 3, "T<64,64,64>_S3"); // 72 KB + SW7_TILE(64, 128, 32, 32, 64, 32, 3, "T<64,128,32>_S3"); // 60 KB + SW7_TILE(64, 128, 32, 32, 64, 32, 4, "T<64,128,32>_S4"); // 80 KB + + // ── Medium tiles (CUTLASS bf16 reference defaults) ────────────────────── + SW7_TILE(128, 64, 32, 64, 32, 32, 3, "T<128,64,32>_S3"); // 48 KB (original default) + SW7_TILE(128, 64, 32, 64, 32, 32, 4, "T<128,64,32>_S4"); // 64 KB + SW7_TILE(128, 64, 64, 64, 32, 64, 3, "T<128,64,64>_S3"); // 96 KB + SW7_TILE(128, 128, 32, 64, 64, 32, 3, "T<128,128,32>_S3"); // 72 KB + SW7_TILE(128, 128, 32, 64, 64, 32, 4, "T<128,128,32>_S4"); // 96 KB + + // ── Large prefill tiles ───────────────────────────────────────────────── + SW7_TILE(256, 64, 32, 64, 32, 32, 3, "T<256,64,32>_S3"); // 72 KB + // (256, 128, 32) needs stages>=3 (DualGemm requires multistage). With + // stages=3 SMEM = (256 + 256) * 32 * 2 * 3 = 96 KB — exactly at budget, + // tends to fail with SMEM allocation errors at runtime. Omitted. + + // (128, 256, 32)*3 = 120 KB > 96 — omitted. + // (64, 256, 32)*3 = 108 KB > 96 — omitted. + } + + void operator()(at::Tensor A, at::Tensor B, at::Tensor D) { + TORCH_CHECK(A.is_cuda() && B.is_cuda() && D.is_cuda(), + "all inputs must be CUDA tensors"); + TORCH_CHECK(A.scalar_type() == at::kBFloat16 && B.scalar_type() == at::kBFloat16 + && D.scalar_type() == at::kBFloat16, + "all inputs must be bf16"); + TORCH_CHECK(A.dim() == 2 && B.dim() == 2 && D.dim() == 2, "A, B, D must be 2D"); + TORCH_CHECK(A.size(1) == B.size(1), "K mismatch (A.size(1) vs B.size(1))"); + TORCH_CHECK(A.is_contiguous() && B.is_contiguous() && D.is_contiguous(), + "A, B, D must be contiguous"); + + int const M = static_cast(A.size(0)); + int const K = static_cast(A.size(1)); + int const N = static_cast(B.size(0)); + TORCH_CHECK((N % 2) == 0, "N must be even, got ", N); + int const N_out = N / 2; + TORCH_CHECK(D.size(0) == M && D.size(1) == N_out, + "D must be (M, N/2) = (", M, ",", N_out, ")"); + + Sw7Args ea; + ea.M = M; ea.N_out = N_out; ea.K = K; + ea.ptr_A = A.data_ptr(); + ea.ptr_B = B.data_ptr(); + ea.ptr_D = D.data_ptr(); + + cudaStream_t stream = at::cuda::getCurrentCUDAStream(A.device().index()).stream(); + + // Single autotune per module. The .cu is compiled per (M-bucket, N, K) + // on the Python side — every distinct weight (N, K) gets its own .cu, + // so this runner instance hosts exactly one (N, K) and one bucket. The + // first call autotunes; all subsequent calls (any M in the bucket) + // reuse `best_idx_`. + if (best_idx_ < 0) { + best_idx_ = autotune(ea, stream); + } + int idx = best_idx_; + + auto& gemm = configs_[idx]; + size_t ws_sz = gemm->get_workspace_size(ea); + if (!ws_.defined() || ws_.numel() < (int64_t)ws_sz) { + ws_ = at::empty({(int64_t)ws_sz + 1}, + at::TensorOptions().dtype(at::kByte).device(A.device())); + } + auto st = gemm->initialize(ea, ws_sz > 0 ? ws_.data_ptr() : nullptr, stream); + TORCH_CHECK(st == cutlass::Status::kSuccess, + "DualGemm init failed (", gemm->name(), "): ", + cutlassGetStatusString(st)); + st = gemm->run(stream); + TORCH_CHECK(st == cutlass::Status::kSuccess, + "DualGemm run failed (", gemm->name(), "): ", + cutlassGetStatusString(st)); + } + + int num_configs() const { return (int)configs_.size(); } + + private: + int autotune(const Sw7Args& ea, cudaStream_t stream) { + int best_idx = -1; + float best_time = 1e30f; + cudaEvent_t s, e; + cudaEventCreate(&s); cudaEventCreate(&e); + + for (size_t i = 0; i < configs_.size(); ++i) { + auto& g = configs_[i]; + size_t ws_sz = 0; + try { ws_sz = g->get_workspace_size(ea); } + catch (...) { continue; } + if (!ws_.defined() || ws_.numel() < (int64_t)ws_sz) { + ws_ = at::empty({(int64_t)ws_sz + 1}, + at::TensorOptions().dtype(at::kByte).device(at::kCUDA)); + } + void* ws_ptr = ws_sz > 0 ? ws_.data_ptr() : nullptr; + if (g->initialize(ea, ws_ptr, stream) != cutlass::Status::kSuccess) { + continue; + } + + // Warmup — 10 iters so the L2 / instruction cache settle. With only + // 3 warmups (the original count) the first timed iter sees a cold L2 + // and inflates the average, sometimes flipping the best-config choice. + for (int w = 0; w < 10; ++w) g->run(stream); + cudaStreamSynchronize(stream); + + // Time — 50 iters keeps timing noise to <1% so 2–3 % perf gaps + // between candidates are distinguishable. + cudaEventRecord(s, stream); + int iters = 50; + for (int p = 0; p < iters; ++p) g->run(stream); + cudaEventRecord(e, stream); + cudaEventSynchronize(e); + float ms = 0; + cudaEventElapsedTime(&ms, s, e); + float avg = ms / iters; + if (avg < best_time) { best_time = avg; best_idx = (int)i; } + } + cudaEventDestroy(s); cudaEventDestroy(e); + TORCH_CHECK(best_idx >= 0, + "swiglu7 AutoTune: no candidate succeeded for (M,N_out,K)=(", + ea.M, ",", ea.N_out, ",", ea.K, ")"); + return best_idx; + } + + std::vector> configs_; + int best_idx_ = -1; // -1 = not yet autotuned; sticky after first call. + at::Tensor ws_; +}; + +static Sw7AutoTuneRunner& runner() { + static Sw7AutoTuneRunner R; + return R; +} + +void swiglu7_dual_matmul_out(at::Tensor A, at::Tensor B, at::Tensor D) { + runner()(std::move(A), std::move(B), std::move(D)); +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.doc() = "CUTLASS DualGemm fully-fused swiglu7 (bf16) on sm_120 — autotune"; + m.def("swiglu7_dual_matmul_out", + &swiglu7_dual_matmul_out, + "D = swiglu7(A @ B.T) in a single fused kernel; " + "A:(M,K) bf16, B:(N,K) bf16 (N even), D:(M,N/2) bf16", + pybind11::arg("A"), + pybind11::arg("B"), + pybind11::arg("D")); + m.def("num_configs", []() { return runner().num_configs(); }); +} diff --git a/magi_compiler/passes/piecewise_graph/fusion/blackwell_geforce/evt_codegen.py b/magi_compiler/passes/piecewise_graph/fusion/blackwell_geforce/evt_codegen.py new file mode 100644 index 0000000..af5bc82 --- /dev/null +++ b/magi_compiler/passes/piecewise_graph/fusion/blackwell_geforce/evt_codegen.py @@ -0,0 +1,852 @@ +# Copyright (c) 2026 SandAI. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Render a CUTLASS .cu source from an EVT IR tree. + +The output is a single self-contained file that: + 1. Declares any custom functor templates required by scalar-baked ops + (ClampMaxC, ScaledSiLuAlpha, GeluErf, …) — each baked with its constant. + 2. Declares the bottom-up Sm80EVT typedef chain. + 3. Declares the GemmKernel + DeviceGemm + entry point. + 4. Exposes ``evt_matmul_out`` via PYBIND11. + +We use CUTLASS 2.x ``Sm80EVT`` running backward-compat on sm_120; this matches +``/root/cutlass/examples/99_evt_demo/heavy_epi_torch_ext.cu`` which has been +verified to deliver +5..+12 % vs the Triton TMA path on RTX 5090 bf16. +""" + +from __future__ import annotations + +import textwrap +from typing import Dict, List, Tuple + +from .evt_ir import Accum, AuxLoad, ColBroadcast, Compute, RowBroadcast, Store, walk_leaves + +# ── PyTorch dtype string → CUTLASS type ────────────────────────────────────── +_DTYPE_TO_CUTLASS = {"bfloat16": "cutlass::bfloat16_t", "float16": "cutlass::half_t", "float32": "float"} + +# PyTorch dtype string → at::ScalarType / pybind dtype string used in TORCH_CHECK. +_DTYPE_TO_AT = {"bfloat16": "at::kBFloat16", "float16": "at::kHalf", "float32": "at::kFloat"} + + +# ── Per-M-bucket tile candidate sets, hand-tuned for RTX 5090 (sm_120) ────── +# Hardware constraints driving these choices: +# * 170 SMs — the optimal grid size is some multiple of 170; small tiles +# keep more CTAs in flight when M is short. +# * 100 KB SMEM / SM — per-stage SMEM = (BM + BN) * BK * 2 (bf16). With +# stages=4 and (128,128,32) we land at 128 KB which exceeds budget; we +# prefer stages=3 in that case. (128,128,32)*4 = 128KB, (128,256,32)*3=144KB, +# (256,128,32)*3=144KB are still over budget but CUTLASS auto-shrinks +# stages on Sm80 if SMEM doesn't fit. We rely on can_implement / init to +# reject illegal combos at autotune time. +# * Decode-style M (≤256) loses parallelism on big tiles — 1 wave covers +# just a handful of N tiles. Need small BM. +# * Prefill-style M (>2048) has plenty of parallelism — bigger tiles win +# because they amortise loads better. +# +# Each tuple is (BM, BN, BK, WM, WN, WK, NumStages, label). +# WarpShape is conventionally TileShape / (2, 2) along (M, N), keeping 4 warps. +# We include WK == BK to match Sm80 TensorOp's default warp tiling. +_TILE_CANDIDATES_5090: dict = { + # ── small (decode / single-token) ──────────────────────────────────────── + # M ≤ 256: low parallelism along M. Use small BM to launch more CTAs along N. + # All candidates have BM*BN ≤ 16384 to keep occupancy high on 170 SMs. + "small": [ + (64, 64, 32, 32, 32, 32, 4, "T<64,64,32>_S4"), + (64, 64, 64, 32, 32, 64, 3, "T<64,64,64>_S3"), + (64, 128, 32, 32, 64, 32, 3, "T<64,128,32>_S3"), + (64, 128, 32, 32, 64, 32, 4, "T<64,128,32>_S4"), + (64, 128, 64, 32, 64, 64, 3, "T<64,128,64>_S3"), + (64, 256, 32, 32, 64, 32, 3, "T<64,256,32>_S3"), + (128, 64, 32, 64, 32, 32, 3, "T<128,64,32>_S3"), + (128, 64, 32, 64, 32, 32, 4, "T<128,64,32>_S4"), + ], + # ── medium (256 < M ≤ 2048) ────────────────────────────────────────────── + # Standard CUTLASS bf16 sweet spot. Mix BM=128/256 with BN=64/128/256. + "medium": [ + (128, 128, 32, 64, 64, 32, 3, "T<128,128,32>_S3"), + (128, 128, 32, 64, 64, 32, 4, "T<128,128,32>_S4"), + (128, 128, 64, 64, 64, 64, 3, "T<128,128,64>_S3"), + (128, 256, 32, 64, 64, 32, 3, "T<128,256,32>_S3"), + (256, 128, 32, 64, 64, 32, 3, "T<256,128,32>_S3"), + (128, 64, 64, 64, 32, 64, 4, "T<128,64,64>_S4"), + (64, 128, 64, 32, 64, 64, 4, "T<64,128,64>_S4"), + ], + # ── large (M > 2048) ───────────────────────────────────────────────────── + # Plenty of parallelism — bigger tiles for better arith density. SMEM + # budget on 5090 (100 KB) restricts (256,128) and (128,256) to stages=3. + "large": [ + (128, 256, 32, 64, 64, 32, 3, "T<128,256,32>_S3"), + (256, 128, 32, 64, 64, 32, 3, "T<256,128,32>_S3"), + (128, 128, 32, 64, 64, 32, 4, "T<128,128,32>_S4"), + (128, 128, 64, 64, 64, 64, 3, "T<128,128,64>_S3"), + (256, 128, 64, 64, 64, 64, 3, "T<256,128,64>_S3"), + (128, 256, 64, 64, 64, 64, 3, "T<128,256,64>_S3"), + ], +} + + +def _emit_tile_candidates(m_bucket: str) -> str: + """Emit C++ EVT_TILE_CANDIDATE(...) statements for a given M bucket.""" + candidates = _TILE_CANDIDATES_5090.get(m_bucket, _TILE_CANDIDATES_5090["medium"]) + lines = [] + for bm, bn, bk, wm, wn, wk, stages, label in candidates: + lines.append(f' EVT_TILE_CANDIDATE({bm}, {bn}, {bk}, {wm}, {wn}, {wk}, ' f'{stages}, "{label}");') + return "\n".join(lines) + + +# For data_ptr() casts at the C++ layer. +_DTYPE_TO_AT_CPP = {"bfloat16": "at::BFloat16", "float16": "at::Half", "float32": "float"} + + +# ── Built-in CUTLASS op names for the visitor template-template parameter ──── +# Maps IR op name → (CUTLASS template name, is_class_template_with_T_only) +# Each value must be a `template class` accepting a single type arg. +_BUILTIN_FN_TEMPLATE = { + # binary + "add": "cutlass::plus", + "sub": "cutlass::minus", + "mul": "cutlass::multiplies", + "div": "cutlass::divides", + "max": "cutlass::maximum", + "min": "cutlass::minimum", + # unary + "neg": "cutlass::negate", + "sigmoid": "cutlass::epilogue::thread::Sigmoid", + "silu": "cutlass::epilogue::thread::SiLu", + "tanh": "cutlass::epilogue::thread::Tanh", + "relu": "cutlass::epilogue::thread::ReLu", + "abs": "cutlass::absolute_value_op", +} + +# Unary ops that need a custom emitted functor (CUTLASS has no built-in). +# Each maps to a body template; the body uses ``T`` as the element type and +# operates on a single ``T`` value named ``x``. +_CUSTOM_UNARY_BODY = { + "square": "return x * x;", + "exp": "return cutlass::fast_exp(x);", + "log": "return cutlass::fast_log(x);", + "sqrt": "return cutlass::fast_sqrt(x);", + "rsqrt": "return cutlass::fast_rsqrt(x);", + "erf": "return T(erff(float(x)));", + "gelu_erf": "return T(0.5f) * x * (T(1.0f) + T(erff(float(x) * 0.70710678118654752f)));", + "gelu_tanh": ( + "float v = float(x);" " return T(0.5f * v * (1.0f + tanhf(" "0.7978845608028654f * (v + 0.044715f * v * v * v))));" + ), +} + +# Scalar-baked unary ops. The body template uses ``x`` and ``c`` (the baked +# constant, emitted as a ``T`` literal — never a runtime value). +_CUSTOM_SCALAR_BODY = { + "add_scalar": "return x + c;", + "sub_scalar": "return x - c;", + "mul_scalar": "return x * c;", + "div_scalar": "return x / c;", + "rsub_scalar": "return c - x;", + "clamp_min_c": "return x < c ? c : x;", + "clamp_max_c": "return x < c ? x : c;", + # scaled_silu_alpha(x, alpha) = x * sigmoid(alpha * x). Used by GELU7. + "scaled_silu_alpha": ( + "T t = c * x;" " T one = T(1.0f);" " T sig = one / (one + cutlass::fast_exp(-t));" " return x * sig;" + ), + # pow_scalar(x, c) – emit as repeated multiplies for small int c. + # Otherwise fall back to powf. + "pow_scalar": "return T(powf(float(x), float(c)));", +} + + +def _scalar_literal_T(value: float) -> str: + """Emit a constant as a ``T(...)`` cast that survives bf16 / fp16 / fp32.""" + # repr keeps round-trip precision; "f" suffix forces float in C++. + return f"T({float(value)!r}f)" + + +def _emit_custom_functor(name: str, op: str, scalar=None) -> str: + """Emit a unary CUTLASS-compatible functor (scalar + Array spec).""" + if op in _CUSTOM_UNARY_BODY: + body = _CUSTOM_UNARY_BODY[op] + scalar_decl = "" + elif op in _CUSTOM_SCALAR_BODY: + if scalar is None: + raise ValueError(f"Scalar op {op!r} needs a baked constant") + body = _CUSTOM_SCALAR_BODY[op] + scalar_decl = f" const T c = {_scalar_literal_T(scalar)};\n" + else: + raise ValueError(f"No custom functor body for op {op!r}") + return textwrap.dedent( + f"""\ + template + struct {name} {{ + static const bool kIsHeavy = true; + CUTLASS_HOST_DEVICE + T operator()(T const& x) const {{ + {scalar_decl} {body} + }} + }}; + + template + struct {name}> {{ + static const bool kIsHeavy = true; + CUTLASS_HOST_DEVICE + cutlass::Array operator()(cutlass::Array const& v) const {{ + {name} op; + cutlass::Array out; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) out[i] = op(v[i]); + return out; + }} + }}; + """ + ) + + +# ── EVT typedef + leaf args walker ──────────────────────────────────────────── + + +class _EvtEmitter: + """Bottom-up walker that emits typedef chains + leaf placeholders.""" + + def __init__(self, root: Store): + self.root = root + self.typedef_lines: List[str] = [] + self.functor_decls: List[str] = [] + self._emitted_functors: Dict[Tuple[str, str], str] = {} + self._tmp_counter = 0 + # Per-leaf metadata captured during walk: leaf identity (object id) → + # (typedef_name, leaf_kind, input_idx_or_None, dtype_str) + self.leaf_typedefs: List[Tuple[str, str, "int | None", str]] = [] + self.scalar_functor_counter = 0 + + def _new_name(self, prefix: str) -> str: + self._tmp_counter += 1 + return f"{prefix}_{self._tmp_counter}" + + def _functor_name_for(self, op: str, scalar) -> str: + """Unique struct name for a custom functor, deduped by (op, scalar).""" + key = (op, repr(scalar) if scalar is not None else "") + if key in self._emitted_functors: + return self._emitted_functors[key] + # Strip dots from the scalar so the name stays a valid C++ identifier. + scalar_tag = "" + if scalar is not None: + self.scalar_functor_counter += 1 + scalar_tag = f"_v{self.scalar_functor_counter}" + name = f"Magi_{op}{scalar_tag}" + self._emitted_functors[key] = name + self.functor_decls.append(_emit_custom_functor(name, op, scalar)) + return name + + def _compute_op_template(self, node: Compute) -> str: + """Return the C++ template-name passed as ComputeFn to VisitorCompute.""" + if node.op in _BUILTIN_FN_TEMPLATE and node.scalar is None: + return _BUILTIN_FN_TEMPLATE[node.op] + # Custom functor — either scalar-baked or unary-no-builtin (e.g. erf). + return self._functor_name_for(node.op, node.scalar) + + def emit(self) -> str: + """Walk the IR; return the typedef name of the root EVT type (EVT_D).""" + # Recurse from Store.child first to build up subtrees. + body_root = self._emit_node(self.root.child) + # The store leaf itself is the StoreD typedef wrapping body_root. + store_name = self._new_name("StoreD") + self.typedef_lines.append( + "using {name} = cutlass::epilogue::threadblock::VisitorAuxStore<\n" + " OutputTileThreadMap, ElementC,\n" + " cutlass::FloatRoundStyle::round_to_nearest,\n" + " cute::Stride>;".format(name=store_name) + ) + evt_d = self._new_name("EVT_D") + self.typedef_lines.append( + f"using {evt_d} = cutlass::epilogue::threadblock::Sm80EVT<\n" f" {store_name}, {body_root}>;" + ) + # Track the StoreD leaf metadata so the launcher knows where to bind D. + self.leaf_typedefs.append((store_name, "store", None, self.root.out_dtype)) + return evt_d + + def _emit_node(self, node) -> str: + if isinstance(node, Accum): + name = self._new_name("Accum") + self.typedef_lines.append(f"using {name} = cutlass::epilogue::threadblock::VisitorAccFetch;") + return name + if isinstance(node, RowBroadcast): + name = self._new_name("RowBcast") + elem = _DTYPE_TO_CUTLASS[node.dtype] + self.typedef_lines.append( + f"using {name} = cutlass::epilogue::threadblock::VisitorRowBroadcast<\n" + f" OutputTileThreadMap, {elem},\n" + f" cute::Stride<_0, _1, int32_t>>;" + ) + self.leaf_typedefs.append((name, "row_bcast", node.input_idx, node.dtype)) + return name + if isinstance(node, ColBroadcast): + name = self._new_name("ColBcast") + elem = _DTYPE_TO_CUTLASS[node.dtype] + self.typedef_lines.append( + f"using {name} = cutlass::epilogue::threadblock::VisitorColBroadcast<\n" + f" OutputTileThreadMap, {elem},\n" + f" cute::Stride<_1, _0, int32_t>>;" + ) + self.leaf_typedefs.append((name, "col_bcast", node.input_idx, node.dtype)) + return name + if isinstance(node, AuxLoad): + name = self._new_name("Aux") + elem = _DTYPE_TO_CUTLASS[node.dtype] + self.typedef_lines.append( + f"using {name} = cutlass::epilogue::threadblock::VisitorAuxLoad<\n" + f" OutputTileThreadMap, {elem},\n" + f" cute::Stride>;" + ) + self.leaf_typedefs.append((name, "aux_load", node.input_idx, node.dtype)) + return name + if isinstance(node, Compute): + child_names = [self._emit_node(c) for c in node.children] + compute_name = self._new_name(f"Cmp_{node.op}") + fn_template = self._compute_op_template(node) + self.typedef_lines.append( + f"using {compute_name} = cutlass::epilogue::threadblock::VisitorCompute<\n" + f" {fn_template}, ElementCompute, ElementCompute,\n" + f" cutlass::FloatRoundStyle::round_to_nearest>;" + ) + evt_name = self._new_name(f"EVT_{node.op}") + child_typedef_list = ", ".join(child_names) + self.typedef_lines.append( + f"using {evt_name} = cutlass::epilogue::threadblock::Sm80EVT<\n" f" {compute_name}, {child_typedef_list}>;" + ) + return evt_name + raise TypeError(f"Unknown IR node type: {type(node).__name__}") + + +# ── Argument-tree emitter (matches EVT typedef tree) ────────────────────────── + + +def _emit_args_tree(node, leaf_args: Dict[int, str], indent: int = 4) -> str: + """Emit the nested-brace runtime callback-args literal matching the IR. + + ``leaf_args[input_idx]`` for non-Accum leaves is a small C++ snippet like + ``{ptrBias, ElementC(0), {_0{}, _1{}, int32_t(N)}}``. Accum / Compute / + Store args are empty braces ``{}``. The Store arg is ``{ptrD, {N, _1{}, + MN}}`` and is handled by the caller — this emitter only renders the body + inside StoreD. + """ + pad = " " * indent + if isinstance(node, Accum): + return f"{pad}{{}}" + if isinstance(node, (RowBroadcast, ColBroadcast, AuxLoad)): + return f"{pad}{leaf_args[node.input_idx]}" + if isinstance(node, Compute): + children_str = ",\n".join(_emit_args_tree(c, leaf_args, indent + 2) for c in node.children) + return f"{pad}{{\n" f"{children_str},\n" f"{pad} {{}}\n" f"{pad}}}" + raise TypeError(f"Unknown IR node type: {type(node).__name__}") + + +# ── Public API: render a complete .cu source string ────────────────────────── + + +_KERNEL_PREAMBLE = """\ +// AUTO-GENERATED by magi_compiler/passes/piecewise_graph/fusion/evt_codegen.py +// Do not edit by hand. Regenerate by re-running the FX pass. +// +// IR cache key: {cache_key} + +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +#include "cutlass/cutlass.h" +#include "cutlass/array.h" +#include "cutlass/numeric_types.h" +#include "cutlass/functional.h" +#include "cutlass/fast_math.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/epilogue/thread/activation.h" +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/gemm/device/gemm_universal.h" + +#include "cutlass/epilogue/threadblock/fusion/visitors.hpp" +#include "cutlass/gemm/kernel/default_gemm_universal_with_visitor.h" +#include "cutlass/gemm/device/gemm_universal_adapter.h" + +using cute::_0; +using cute::_1; + +//////////////////////////////////////////////////////////////////////////////// +// Custom functors (one per unique scalar-baked op or non-builtin unary). +//////////////////////////////////////////////////////////////////////////////// +{functor_decls} + +//////////////////////////////////////////////////////////////////////////////// +// Data types and layouts +//////////////////////////////////////////////////////////////////////////////// + +using ElementA = {a_elem}; +using ElementB = {b_elem}; +using ElementC = {c_elem}; +using ElementAcc = float; +using ElementCompute = float; + +using LayoutA = cutlass::layout::RowMajor; +using LayoutB = cutlass::layout::{b_layout}; +using LayoutC = cutlass::layout::RowMajor; + +constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; +constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; +// AlignmentC = 4 instead of 8 so any N-divisible-by-4 output works (e.g. odd +// half-N values like 13652 from N=27304). Aligned tails still vectorise. +constexpr int AlignmentC = 4; + +using ArchTag = cutlass::arch::Sm80; +using OperatorClass = cutlass::arch::OpClassTensorOp; +using InstructionShape = cutlass::gemm::GemmShape< 16, 8, 16>; +constexpr int EVTEpilogueStages = 1; + +//////////////////////////////////////////////////////////////////////////////// +// Per-tile-config GEMM type. The OutputTileThreadMap depends on +// ThreadblockShape/WarpShape, which forces every EVT typedef to be re-built +// per tile. We package the whole tree inside a template struct keyed on the +// tile/warp/stages parameters so each autotune candidate is a distinct type. +//////////////////////////////////////////////////////////////////////////////// + +template +struct EvtConfig {{ + using TheTbShape = TbShape; + using TheWarpShape = WarpShape; + + using OutputTileThreadMap = cutlass::epilogue::threadblock::OutputTileThreadLayout< + TbShape, WarpShape, ElementC, AlignmentC, EVTEpilogueStages>; + + //////////////////////////////////////////////////////////////////////////// + // EVT (Epilogue Visitor Tree) typedefs — generated from the IR tree. + //////////////////////////////////////////////////////////////////////////// +{typedef_block} + + //////////////////////////////////////////////////////////////////////////// + // GemmKernel / DeviceGemm + //////////////////////////////////////////////////////////////////////////// + using GemmKernel = typename cutlass::gemm::kernel::DefaultGemmWithVisitor< + ElementA, LayoutA, cutlass::ComplexTransform::kNone, AlignmentA, + ElementB, LayoutB, cutlass::ComplexTransform::kNone, AlignmentB, + ElementC, LayoutC, AlignmentC, + ElementAcc, + ElementCompute, + OperatorClass, + ArchTag, + TbShape, + WarpShape, + InstructionShape, + {evt_root_name}, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + NumStages, + cutlass::arch::OpMultiplyAdd, + EVTEpilogueStages>::GemmKernel; + + using DeviceGemm = cutlass::gemm::device::GemmUniversalAdapter; +}}; + +//////////////////////////////////////////////////////////////////////////////// +// Autotune runner — one candidate per tile/warp/stages combination; first call +// at a new (M, N, K) tuple times every candidate and caches the winner. +//////////////////////////////////////////////////////////////////////////////// + +struct EvtArgs {{ + int M; + int N; + int K; + void* ptr_A; + void* ptr_B; + void* ptr_D; + // Extras pointers, in IR-leaf order. + std::vector ptr_extras; +}}; + +class EvtConcept {{ + public: + virtual ~EvtConcept() = default; + virtual size_t get_workspace_size(const EvtArgs&) = 0; + virtual cutlass::Status initialize(const EvtArgs&, void* ws, cudaStream_t s) = 0; + virtual cutlass::Status run(cudaStream_t stream) = 0; + virtual const char* name() const = 0; +}}; + +template +class EvtImpl : public EvtConcept {{ + public: + using GemmType = typename Cfg::DeviceGemm; + using EvtRoot = typename Cfg::{evt_root_name}; + + explicit EvtImpl(const char* name) : name_(name) {{}} + + typename GemmType::Arguments make_args(const EvtArgs& a) {{ + auto ptrA = reinterpret_cast(a.ptr_A); + auto ptrB = reinterpret_cast(a.ptr_B); + auto ptrD = reinterpret_cast(a.ptr_D); + int const M = a.M; + int const N = a.N; + int const K = a.K; + int64_t const MN = static_cast(M) * static_cast(N); + + typename EvtRoot::Arguments callback_args{{ +{args_tree} + , + {{ptrD, {{int64_t(N), _1{{}}, MN}}}} + }}; + + cutlass::gemm::GemmCoord problem{{M, N, K}}; + typename GemmType::Arguments args( + cutlass::gemm::GemmUniversalMode::kGemm, + problem, + /*batch_count=*/1, + callback_args, + ptrA, ptrB, + /*ptr_C=*/nullptr, /*ptr_D=*/nullptr, + /*batch_stride_A=*/static_cast(M) * K, + /*batch_stride_B=*/static_cast(N) * K, + /*batch_stride_C=*/0, /*batch_stride_D=*/0, + /*stride_a=*/static_cast(K), + /*stride_b=*/static_cast({stride_b_expr}), + /*stride_c=*/0, /*stride_d=*/0); + return args; + }} + + size_t get_workspace_size(const EvtArgs& a) override {{ + auto args = make_args(a); + return GemmType::get_workspace_size(args); + }} + cutlass::Status initialize(const EvtArgs& a, void* ws, cudaStream_t s) override {{ + auto args = make_args(a); + return gemm_.initialize(args, ws, s); + }} + cutlass::Status run(cudaStream_t stream) override {{ + return gemm_.run(stream); + }} + const char* name() const override {{ return name_; }} + + private: + GemmType gemm_; + const char* name_; +}}; + +//////////////////////////////////////////////////////////////////////////////// +// Python-facing launcher +//////////////////////////////////////////////////////////////////////////////// +""" + + +_LAUNCHER_TEMPLATE = """\ +//////////////////////////////////////////////////////////////////////////////// +// Tile candidate registration. Each AutoConfigBuilder invocation instantiates +// the full EVT typedef tree + GemmKernel for that (TileShape, WarpShape, +// NumStages) tuple. Compile time grows linearly with the candidate count, so +// keep the list small and shape-relevant. +//////////////////////////////////////////////////////////////////////////////// + +#define EVT_TILE_CANDIDATE(tb_m, tb_n, tb_k, wa_m, wa_n, wa_k, stages, label) \\ + configs_.push_back(std::make_unique, \\ + cutlass::gemm::GemmShape, \\ + stages>>>(label)) + +class EvtAutoTuneRunner {{ + public: + EvtAutoTuneRunner() {{ +{tile_candidate_block} + }} + + void operator()(at::Tensor A, at::Tensor B, + std::vector extras, at::Tensor D) {{ + TORCH_CHECK(A.is_cuda() && B.is_cuda() && D.is_cuda(), + "evt_matmul_out: A/B/D must be CUDA tensors"); + TORCH_CHECK(A.scalar_type() == {a_at_dtype}, "A must be {a_dtype}"); + TORCH_CHECK(B.scalar_type() == {b_at_dtype}, "B must be {b_dtype}"); + TORCH_CHECK(D.scalar_type() == {c_at_dtype}, "D must be {c_dtype}"); + TORCH_CHECK(A.dim() == 2 && B.dim() == 2, "A, B must be 2D"); + TORCH_CHECK(A.is_contiguous() && B.is_contiguous() && D.is_contiguous(), + "A, B, D must be contiguous (row-major)"); + + int const M = static_cast(A.size(0)); + int const K = static_cast(A.size(1)); + int const N = static_cast({n_dim_expr}); + + TORCH_CHECK(D.size(0) == M && D.size(1) == N, + "D must be (M, N); got ", D.sizes()); + TORCH_CHECK(extras.size() == {n_extras}, "expected {n_extras} extra tensors, got ", extras.size()); + +{extras_validation} + + EvtArgs ea; + ea.M = M; ea.N = N; ea.K = K; + ea.ptr_A = A.data_ptr<{a_at_cpp}>(); + ea.ptr_B = B.data_ptr<{b_at_cpp}>(); + ea.ptr_D = D.data_ptr<{c_at_cpp}>(); + ea.ptr_extras.reserve({n_extras}); +{extras_ptrs} + + cudaStream_t stream = at::cuda::getCurrentCUDAStream(A.device().index()).stream(); + + // Single autotune per module. The .cu is compiled per (IR, M-bucket, + // b_layout, N, K) on the Python side — every distinct weight (N, K) + // gets its own .cu, so this runner instance hosts exactly one (N, K) + // and one bucket of M values. Autotune once on the first call; all + // subsequent calls (any M inside the bucket) reuse `best_idx_`. + if (best_idx_ < 0) {{ + best_idx_ = autotune(ea, stream); + }} + int idx = best_idx_; + + auto& gemm = configs_[idx]; + size_t ws_sz = gemm->get_workspace_size(ea); + if (!ws_.defined() || ws_.numel() < (int64_t)ws_sz) {{ + ws_ = at::empty({{(int64_t)ws_sz + 1}}, + at::TensorOptions().dtype(at::kByte).device(A.device())); + }} + auto st = gemm->initialize(ea, ws_sz > 0 ? ws_.data_ptr() : nullptr, stream); + TORCH_CHECK(st == cutlass::Status::kSuccess, + "CUTLASS init failed (", gemm->name(), "): ", cutlassGetStatusString(st)); + st = gemm->run(stream); + TORCH_CHECK(st == cutlass::Status::kSuccess, + "CUTLASS run failed (", gemm->name(), "): ", cutlassGetStatusString(st)); + }} + + int num_configs() const {{ return (int)configs_.size(); }} + + private: + int autotune(const EvtArgs& ea, cudaStream_t stream) {{ + int best_idx = -1; + float best_time = 1e30f; + cudaEvent_t s, e; + cudaEventCreate(&s); cudaEventCreate(&e); + + for (size_t i = 0; i < configs_.size(); ++i) {{ + auto& g = configs_[i]; + size_t ws_sz = 0; + try {{ ws_sz = g->get_workspace_size(ea); }} + catch (...) {{ continue; }} + if (!ws_.defined() || ws_.numel() < (int64_t)ws_sz) {{ + ws_ = at::empty({{(int64_t)ws_sz + 1}}, + at::TensorOptions().dtype(at::kByte).device(at::kCUDA)); + }} + void* ws_ptr = ws_sz > 0 ? ws_.data_ptr() : nullptr; + if (g->initialize(ea, ws_ptr, stream) != cutlass::Status::kSuccess) {{ + continue; + }} + + // Warmup — 10 iters so L2 / inst caches settle (3 was too few — first + // timed iter saw a cold L2 and biased the choice towards smaller tiles). + for (int w = 0; w < 10; ++w) g->run(stream); + cudaStreamSynchronize(stream); + + // Time — 20 iters for ~1% timing noise, matching torch.compile defaults. + cudaEventRecord(s, stream); + int iters = 20; + for (int p = 0; p < iters; ++p) g->run(stream); + cudaEventRecord(e, stream); + cudaEventSynchronize(e); + float ms = 0; + cudaEventElapsedTime(&ms, s, e); + float avg = ms / iters; + if (avg < best_time) {{ best_time = avg; best_idx = (int)i; }} + }} + cudaEventDestroy(s); cudaEventDestroy(e); + TORCH_CHECK(best_idx >= 0, + "EVT AutoTune: no candidate succeeded for (M,N,K)=(", + ea.M, ",", ea.N, ",", ea.K, ")"); + return best_idx; + }} + + std::vector> configs_; + int best_idx_ = -1; // -1 = not yet autotuned; sticky after first call. + at::Tensor ws_; +}}; + +static EvtAutoTuneRunner& runner() {{ + static EvtAutoTuneRunner R; + return R; +}} + +void evt_matmul_out(at::Tensor A, at::Tensor B, + std::vector extras, + at::Tensor D) {{ + runner()(std::move(A), std::move(B), std::move(extras), std::move(D)); +}} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {{ + m.doc() = "Magi compiler EVT-fused matmul (auto-generated, autotune)"; + m.def("evt_matmul_out", &evt_matmul_out, + "Fused EVT matmul: D = epilogue(A @ B, extras...)", + pybind11::arg("A"), pybind11::arg("B"), + pybind11::arg("extras"), pybind11::arg("D")); + m.def("num_configs", []() {{ return runner().num_configs(); }}); +}} +""" + + +def render_evt_cu( + ir: Store, a_dtype: str, b_dtype: str, cache_key_str: str = "", b_layout: str = "row", m_bucket: str = "medium" +) -> str: + """Render a complete .cu source for the given EVT IR. + + Parameters + ---------- + ir : Store + Root of the EVT IR tree. + a_dtype, b_dtype : str + Element types for A and B (typically ``"bfloat16"``). Output dtype is + taken from ``ir.out_dtype``. + cache_key_str : str + Optional hash echoed in a top-level comment, useful for debugging. + b_layout : "row" | "col" + ``"row"`` (default): B is contiguous (K, N) row-major; LayoutB = + RowMajor; ldB = N. ``"col"``: B is the underlying (N, K) row-major + weight (== column-major (K, N)); LayoutB = ColumnMajor; ldB = K. Use + ``"col"`` when the FX graph passes ``permute([1,0])(weight)`` as B. + m_bucket : "small" | "medium" | "large" + Picks a tile-candidate set tuned for RTX 5090 (sm_120) at the given M + regime. The runner inside the rendered .cu autotunes across all + candidates in that bucket on the first call per (M, N, K) shape and + caches the winner. + """ + if b_layout not in ("row", "col"): + raise ValueError(f"b_layout must be 'row' or 'col', got {b_layout!r}") + if m_bucket not in _TILE_CANDIDATES_5090: + raise ValueError(f"unknown m_bucket {m_bucket!r}; " f"expected one of {list(_TILE_CANDIDATES_5090)}") + if not isinstance(ir, Store): + raise TypeError("render_evt_cu expects a Store node as root") + tile_candidate_block = _emit_tile_candidates(m_bucket) + + a_elem = _DTYPE_TO_CUTLASS[a_dtype] + b_elem = _DTYPE_TO_CUTLASS[b_dtype] + c_elem = _DTYPE_TO_CUTLASS[ir.out_dtype] + + emitter = _EvtEmitter(ir) + evt_root = emitter.emit() + + # Build per-leaf runtime arg fragments. These get inlined into + # ``EvtImpl::make_args`` (a method on a different class than the launcher + # that fills ea.ptr_extras). The only shared state between the two scopes + # is the EvtArgs struct ``a``, so we read pointers from a.ptr_extras[i] + # and cast back to the leaf's element type. + leaves = walk_leaves(ir) + leaf_args: Dict[int, str] = {} + for leaf in leaves: + # Accum has no extras pointer / dtype — skip; it consumes the GEMM + # accumulator directly via VisitorAccFetch. + if not isinstance(leaf, (RowBroadcast, ColBroadcast, AuxLoad)): + continue + elem = _DTYPE_TO_CUTLASS[leaf.dtype] + ptr_expr = f"reinterpret_cast<{elem}*>(a.ptr_extras[{leaf.input_idx}])" + if isinstance(leaf, RowBroadcast): + leaf_args[leaf.input_idx] = f"{{{ptr_expr}, {elem}(0), {{_0{{}}, _1{{}}, int32_t(N)}}}}" + elif isinstance(leaf, ColBroadcast): + leaf_args[leaf.input_idx] = f"{{{ptr_expr}, {elem}(0), {{_1{{}}, _0{{}}, int32_t(M)}}}}" + else: # AuxLoad + leaf_args[leaf.input_idx] = f"{{{ptr_expr}, {elem}(0), {{int64_t(N), _1{{}}, MN}}}}" + # Accum has no explicit args entry. + + args_tree = _emit_args_tree(ir.child, leaf_args, indent=8) + + # Extras-validation + pointer-extraction blocks. The same external tensor + # (same input_idx) may appear at multiple leaves in the IR tree — e.g. an + # ``add(mm, bias)`` value flowing into both ``sigmoid`` and ``mul`` creates + # two RowBroadcast(0) leaves. We must declare ``ptr_extra_0`` exactly once + # in the launcher; the runtime args tree still references the same ptr + # name from each leaf-arg fragment so this dedup is purely a C++ scope fix. + extras_validation_lines = [] + extras_ptr_lines = [] + seen_extras: set = set() + extra_leaves = [n for n in leaves if not isinstance(n, Accum)] + n_extras = max((leaf.input_idx for leaf in extra_leaves), default=-1) + 1 + for leaf in extra_leaves: + i = leaf.input_idx + if i in seen_extras: + continue + seen_extras.add(i) + at_dtype = _DTYPE_TO_AT[leaf.dtype] + at_cpp = _DTYPE_TO_AT_CPP[leaf.dtype] + _DTYPE_TO_CUTLASS[leaf.dtype] + if isinstance(leaf, RowBroadcast): + extras_validation_lines.append(f' TORCH_CHECK(extras[{i}].numel() == N, "extras[{i}] must have N elements");') + elif isinstance(leaf, ColBroadcast): + extras_validation_lines.append(f' TORCH_CHECK(extras[{i}].numel() == M, "extras[{i}] must have M elements");') + elif isinstance(leaf, AuxLoad): + extras_validation_lines.append( + f' TORCH_CHECK(extras[{i}].size(0) == M && extras[{i}].size(1) == N,' f' "extras[{i}] must be (M,N)");' + ) + extras_validation_lines.append( + f' TORCH_CHECK(extras[{i}].scalar_type() == {at_dtype},' f' "extras[{i}] must be {leaf.dtype}");' + ) + extras_validation_lines.append(f' TORCH_CHECK(extras[{i}].is_cuda(), "extras[{i}] must be CUDA");') + # Push raw pointer into ea.ptr_extras for the make_args() side to + # read (it lives in a different scope than this launcher fn). + extras_ptr_lines.append(f" ea.ptr_extras.push_back(static_cast(" f"extras[{i}].data_ptr<{at_cpp}>()));") + + extras_validation = "\n".join(extras_validation_lines) if extras_validation_lines else " // no extras" + extras_ptrs = "\n".join(extras_ptr_lines) if extras_ptr_lines else "" + + # Emit. The functor decls already end with a trailing newline each. + functor_decls = "\n".join(emitter.functor_decls) if emitter.functor_decls else "// (no custom functors)" + # typedef_block lives inside ``struct EvtConfig`` — indent each line by 2 + # spaces so member typedefs read consistently with the surrounding struct. + typedef_block = "\n".join(" " + l if l.strip() else l for l in "\n".join(emitter.typedef_lines).split("\n")) + + cutlass_b_layout = "RowMajor" if b_layout == "row" else "ColumnMajor" + if b_layout == "row": + # B is (K, N) row-major contiguous: K from B.size(0), N from B.size(1), ldB = N. + n_dim_expr = "B.size(1)" + stride_b_expr = "N" + else: + # B is the underlying (N, K) row-major weight (we read the same + # bytes via ColumnMajor (K, N)): N from B.size(0), K from B.size(1), ldB = K. + n_dim_expr = "B.size(0)" + stride_b_expr = "K" + + preamble = _KERNEL_PREAMBLE.format( + cache_key=cache_key_str, + functor_decls=functor_decls, + a_elem=a_elem, + b_elem=b_elem, + c_elem=c_elem, + typedef_block=typedef_block, + evt_root_name=evt_root, + b_layout=cutlass_b_layout, + # EvtImpl::make_args uses args_tree + stride_b_expr; same values as + # the launcher (per-IR / per-layout, not per-tile-config). + args_tree=args_tree, + stride_b_expr=stride_b_expr, + ) + launcher = _LAUNCHER_TEMPLATE.format( + evt_root_name=evt_root, + args_tree=args_tree, + a_dtype=a_dtype, + b_dtype=b_dtype, + c_dtype=ir.out_dtype, + a_at_dtype=_DTYPE_TO_AT[a_dtype], + b_at_dtype=_DTYPE_TO_AT[b_dtype], + c_at_dtype=_DTYPE_TO_AT[ir.out_dtype], + a_at_cpp=_DTYPE_TO_AT_CPP[a_dtype], + b_at_cpp=_DTYPE_TO_AT_CPP[b_dtype], + c_at_cpp=_DTYPE_TO_AT_CPP[ir.out_dtype], + n_extras=n_extras, + extras_validation=extras_validation, + extras_ptrs=extras_ptrs, + n_dim_expr=n_dim_expr, + stride_b_expr=stride_b_expr, + tile_candidate_block=tile_candidate_block, + ) + return preamble + launcher diff --git a/magi_compiler/passes/piecewise_graph/fusion/blackwell_geforce/evt_ir.py b/magi_compiler/passes/piecewise_graph/fusion/blackwell_geforce/evt_ir.py new file mode 100644 index 0000000..ae6bc1e --- /dev/null +++ b/magi_compiler/passes/piecewise_graph/fusion/blackwell_geforce/evt_ir.py @@ -0,0 +1,242 @@ +# Copyright (c) 2026 SandAI. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""EVT (Epilogue Visitor Tree) intermediate representation. + +A small dataclass IR that the FX pass builds while walking the consumers of an +``aten.mm`` node, and that ``evt_codegen.py`` consumes to render a CUTLASS .cu +source. The IR is canonicalised to a deterministic JSON string used as the +cache key for the JIT'd kernel module. + +The IR is rooted at a single ``Store`` node and forms a DAG of compute nodes +over leaves (``Accum``, ``RowBroadcast``, ``ColBroadcast``, ``AuxLoad``). + +Op naming: every name in ``UNARY_OPS`` / ``BINARY_OPS`` corresponds to a +CUTLASS visitor template that ``evt_codegen.py`` knows how to emit. Adding a +new op requires updating both this file and the codegen. +""" + +from __future__ import annotations + +import hashlib +import json +from dataclasses import dataclass +from typing import List, Optional, Union + +# Ops that take a single child tensor and produce a tensor of the same shape. +# All run in fp32 inside the EVT epilogue. +UNARY_OPS = frozenset( + {"neg", "sigmoid", "silu", "gelu_erf", "gelu_tanh", "tanh", "relu", "square", "erf", "exp", "log", "sqrt", "rsqrt", "abs"} +) + +# Ops that take two child tensors. Both children must be EVT subtrees. +BINARY_OPS = frozenset({"add", "sub", "mul", "div", "max", "min"}) + +# Unary ops that bake a single fp32 scalar into the functor at codegen time. +# Used to fold scalar literals out of the IR so they don't bloat the cache key. +SCALAR_UNARY_OPS = frozenset( + { + "add_scalar", # x + c + "sub_scalar", # x - c + "mul_scalar", # x * c + "div_scalar", # x / c + "rsub_scalar", # c - x + "clamp_min_c", # max(x, c) + "clamp_max_c", # min(x, c) + "scaled_silu_alpha", # x * sigmoid(alpha * x), used by gelu7 + "pow_scalar", # x ** c (only sensible for small integer c) + } +) + +ALL_OPS = UNARY_OPS | BINARY_OPS | SCALAR_UNARY_OPS + +# Output dtype tags propagated from FakeTensor metadata into Store and leaves. +# Kept as strings (not torch.dtype) so the IR is JSON-serialisable. +DTYPES = frozenset({"bfloat16", "float16", "float32"}) + + +# ── Leaf nodes ──────────────────────────────────────────────────────────────── + + +@dataclass(frozen=True) +class Accum: + """The fp32 GEMM accumulator. Always the unique starting leaf of the IR.""" + + kind: str = "accum" + + +@dataclass(frozen=True) +class RowBroadcast: + """1-D (N,) tensor broadcast along the M axis. Maps to VisitorRowBroadcast. + + ``input_idx`` is the position of this tensor in the runtime ``extras`` list. + ``dtype`` is the storage dtype; the visitor casts to fp32 internally. + """ + + input_idx: int + dtype: str + kind: str = "row_bcast" + + +@dataclass(frozen=True) +class ColBroadcast: + """1-D (M,) tensor broadcast along the N axis. Maps to VisitorColBroadcast.""" + + input_idx: int + dtype: str + kind: str = "col_bcast" + + +@dataclass(frozen=True) +class AuxLoad: + """2-D (M, N) row-major aux tensor. Maps to VisitorAuxLoad. + + Caller must guarantee ``stride[1] == 1`` and that ``stride[0]`` is 16-byte + aligned (cp.async requirement). + """ + + input_idx: int + dtype: str + kind: str = "aux_load" + + +# ── Compute nodes ───────────────────────────────────────────────────────────── + + +@dataclass(frozen=True) +class Compute: + """An interior fp32 elementwise op. + + Children are EVT subtrees (any of the leaf or compute types). + For SCALAR_UNARY_OPS, ``children`` has length 1 and ``scalar`` carries the + baked constant. + For UNARY_OPS, ``children`` has length 1, ``scalar`` is None. + For BINARY_OPS, ``children`` has length 2, ``scalar`` is None. + """ + + op: str + children: tuple + scalar: Optional[float] = None + kind: str = "compute" + + def __post_init__(self): + # Validate at construction time so codegen never sees a malformed IR. + if self.op not in ALL_OPS: + raise ValueError(f"Unknown EVT op: {self.op!r}") + if self.op in UNARY_OPS: + if len(self.children) != 1 or self.scalar is not None: + raise ValueError(f"UNARY op {self.op!r} requires 1 child, no scalar") + elif self.op in BINARY_OPS: + if len(self.children) != 2 or self.scalar is not None: + raise ValueError(f"BINARY op {self.op!r} requires 2 children, no scalar") + elif self.op in SCALAR_UNARY_OPS: + if len(self.children) != 1 or self.scalar is None: + raise ValueError(f"SCALAR_UNARY op {self.op!r} requires 1 child + scalar") + + +@dataclass(frozen=True) +class Store: + """Root of the IR. Casts the fp32 result to ``out_dtype`` and writes D.""" + + child: object # any IR node + out_dtype: str + kind: str = "store" + + def __post_init__(self): + if self.out_dtype not in DTYPES: + raise ValueError(f"Unknown out_dtype {self.out_dtype!r}") + + +# Union type alias for type hints. +IRNode = Union[Accum, RowBroadcast, ColBroadcast, AuxLoad, Compute, Store] + + +# ── Canonicalisation + serialisation ────────────────────────────────────────── + + +def to_dict(node) -> dict: + """Recursively convert an IR node tree into a JSON-friendly dict. + + The dict layout is designed for stable hashing: keys appear in a fixed + order and floats are formatted with ``repr`` so 1.702 vs 1.7020000001 + never collide. + """ + if isinstance(node, Accum): + return {"kind": "accum"} + if isinstance(node, RowBroadcast): + return {"kind": "row_bcast", "input_idx": node.input_idx, "dtype": node.dtype} + if isinstance(node, ColBroadcast): + return {"kind": "col_bcast", "input_idx": node.input_idx, "dtype": node.dtype} + if isinstance(node, AuxLoad): + return {"kind": "aux_load", "input_idx": node.input_idx, "dtype": node.dtype} + if isinstance(node, Compute): + d = {"kind": "compute", "op": node.op, "children": [to_dict(c) for c in node.children]} + if node.scalar is not None: + # repr of a float is round-trip-safe; explicitly stringify so JSON + # never serialises 1.7000000000000002. + d["scalar"] = repr(float(node.scalar)) + return d + if isinstance(node, Store): + return {"kind": "store", "out_dtype": node.out_dtype, "child": to_dict(node.child)} + raise TypeError(f"Unknown IR node type: {type(node).__name__}") + + +def to_canonical_json(node) -> str: + """Deterministic JSON string for an IR tree. Same IR ⇒ same string.""" + return json.dumps(to_dict(node), sort_keys=True, separators=(",", ":")) + + +def cache_key(node, a_dtype: str, b_dtype: str) -> str: + """SHA-256 hash of (IR JSON, A dtype, B dtype). Used as the JIT module key.""" + payload = {"ir": to_dict(node), "a": a_dtype, "b": b_dtype, "version": 1} + blob = json.dumps(payload, sort_keys=True, separators=(",", ":")).encode("utf-8") + return hashlib.sha256(blob).hexdigest() + + +# ── Tree walkers ────────────────────────────────────────────────────────────── + + +def walk_leaves(node) -> List: + """Return all leaf nodes (Accum / RowBroadcast / ColBroadcast / AuxLoad) + in left-to-right pre-order. Used by codegen to enumerate kernel inputs.""" + out: list = [] + + def _go(n): + if isinstance(n, (Accum, RowBroadcast, ColBroadcast, AuxLoad)): + out.append(n) + elif isinstance(n, Compute): + for c in n.children: + _go(c) + elif isinstance(n, Store): + _go(n.child) + else: + raise TypeError(f"Unknown IR node type: {type(n).__name__}") + + _go(node) + return out + + +def is_trivial(node) -> bool: + """An IR is trivial when ``Store(Accum)`` — no compute on the accumulator. + + Trivial IRs would replace cuBLAS with a more expensive kernel for no + benefit, so the FX pass should refuse to emit them. + """ + return isinstance(node, Store) and isinstance(node.child, Accum) + + +def num_extras(node) -> int: + """Maximum input_idx + 1 across all non-Accum leaves, or 0 if none.""" + indices: list = [leaf.input_idx for leaf in walk_leaves(node) if not isinstance(leaf, Accum)] + return max(indices) + 1 if indices else 0 diff --git a/magi_compiler/passes/piecewise_graph/fusion/blackwell_geforce/evt_runtime.py b/magi_compiler/passes/piecewise_graph/fusion/blackwell_geforce/evt_runtime.py new file mode 100644 index 0000000..56fa681 --- /dev/null +++ b/magi_compiler/passes/piecewise_graph/fusion/blackwell_geforce/evt_runtime.py @@ -0,0 +1,583 @@ +# Copyright (c) 2026 SandAI. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Runtime side of the EVT fusion: torch.library op + JIT loader + dispatch. + +This file owns: + * The ``magi_epilogue::matmul_custom_evt`` torch.library op + fake impl. + * A process-level cache mapping IR JSON → compiled cpp_extension module. + * Dispatch to one of two backends: + - ``kind == "evt"`` → JIT-compiled CUTLASS Sm80EVT kernel. + - ``kind == "swiglu7_dual"`` → vendored DualGemm one-stage kernel. + +The kernel build directory uses the IR cache key as its name so re-runs and +multi-process Inductor compile workers all hit the same on-disk cache. +""" + +from __future__ import annotations + +import hashlib +import json +import os +import threading +from typing import Optional + +import torch + +from magi_compiler.config import get_compile_config + +from .evt_codegen import render_evt_cu +from .evt_ir import Accum, AuxLoad, ColBroadcast, Compute, RowBroadcast, Store + +# ── torch.library op definition ─────────────────────────────────────────────── +# Reuse the existing ``magi_epilogue`` library so all our custom matmul ops +# live under one namespace. Defining a fresh op here is harmless even if +# ``matmul_epilogue_fusion.py`` has already initialised the library. +_LIB = torch.library.Library("magi_epilogue", "FRAGMENT") +_LIB.define( + "matmul_custom_evt(Tensor A, Tensor B, Tensor[] extras, str ir_json," " str kind, int n_out, int out_dtype_id) -> Tensor" +) + + +# ── Output-dtype encoding (must round-trip through torch.library int args) ──── +_OUT_DTYPE_ID = {torch.bfloat16: 0, torch.float16: 1, torch.float32: 2} +_ID_TO_DTYPE = {v: k for k, v in _OUT_DTYPE_ID.items()} +_DTYPE_TO_STR = {torch.bfloat16: "bfloat16", torch.float16: "float16", torch.float32: "float32"} + + +def out_dtype_id(dt: torch.dtype) -> int: + """Encode a torch.dtype as a small int for inclusion in op args.""" + if dt not in _OUT_DTYPE_ID: + raise ValueError(f"Unsupported EVT output dtype {dt}") + return _OUT_DTYPE_ID[dt] + + +def out_dtype_from_id(i: int) -> torch.dtype: + return _ID_TO_DTYPE[i] + + +# ── M-bucket dispatch ───────────────────────────────────────────────────────── +# Three coarse buckets matching the tile-candidate sets in +# ``evt_codegen._TILE_CANDIDATES_5090``: +# small — M ≤ 256 (decode / single-token) +# medium — 256 < M ≤ 2048 (mid-size prefill) +# large — M > 2048 (large prefill / batched) +# Each bucket compiles a distinct .cu module containing its own tile-candidate +# vector; the per-module C++ runner then autotunes the actual best (TileShape, +# WarpShape, NumStages) tuple at first call per (M, N, K) and caches the +# winning index inside the module — so the Python side only pays one extra +# cache key dimension. +_M_BUCKET_BOUNDARIES = (256, 2048) + + +def _m_bucket(M: int) -> str: + if M <= _M_BUCKET_BOUNDARIES[0]: + return "small" + if M <= _M_BUCKET_BOUNDARIES[1]: + return "medium" + return "large" + + +# ── Output row-stride helper ────────────────────────────────────────────────── +# CUTLASS Sm80EVT and the swiglu7 DualGemm both require D's row stride to be a +# multiple of AlignmentC * sizeof(ElementC) = 4 * sizeof(bf16) = 8 bytes (i.e. +# 4 elements for bf16/fp16, 2 elements for fp32). When n_out already meets this +# requirement we return a *contiguous* (M, n_out) tensor — avoids an extra D2D +# scratch copy on the hot path. Only when n_out fails the alignment do we fall +# back to padding the row stride. +# +# Earlier this padded everything to 128 bytes (matching the Triton path's +# convention) but on shapes like N_out=13652 the resulting non-contig D forced +# a kernel-into-scratch + scratch-into-D copy worth ~5% of the kernel runtime +# at (M=7697, N=27304, K=5120) — which fully accounted for the perf gap users +# saw between the standalone benchmark (no scratch) and the real model. +# +# Pre-computed alignment per dtype to avoid the ~2–5 μs cost of +# ``torch.empty([], dtype=dt).element_size()`` per op invocation. Hit count on +# this lookup is 2× per fused op (runtime impl + fake impl), so on a model with +# 100 fused-op calls per forward this shaves ~1 ms off the dispatch overhead. +_ALIGN_BY_DTYPE: dict = { + torch.bfloat16: 4, # 8 bytes / 2 = 4 elements + torch.float16: 4, + torch.float32: 2, # 8 bytes / 4 = 2 elements +} + + +def _aligned_n_stride(n_out: int, dt: torch.dtype) -> int: + align = _ALIGN_BY_DTYPE.get(dt) + if align is None: # rare: a dtype we haven't pre-tabulated + align = max(1, 8 // torch.empty([], dtype=dt).element_size()) + return (n_out + align - 1) // align * align + + +# ── Compile cache + per-key build lock ──────────────────────────────────────── +_MODULE_CACHE: dict = {} # cache_key (sha256 str) → loaded cpp_extension module +# Hot-path fast cache — avoids ``json.dumps + sha256`` (~10–30 μs/call) when +# the module has already been compiled. Keyed by the 4-tuple of (Python-) +# hashable inputs that uniquely determine the rendered .cu, since equality on +# the tuple is sufficient (no need to canonicalise twice). Populated on the +# slow path inside ``_compile_evt_module``. +_MODULE_FAST_CACHE: dict = {} # (ir_json, a_dtype, b_dtype, b_layout) → module +_MODULE_LOCKS: dict = {} # cache_key → threading.Lock +_MODULE_LOCKS_GLOBAL = threading.Lock() +_SWIGLU7_LOCK = threading.Lock() # serialises insertions into _SWIGLU7_FAST_CACHE + + +# ── D output-buffer cache ──────────────────────────────────────────────────── +# Keyed by (M, n_out, n_stride, out_dtype, device_idx). Mirrors the same +# cache pattern in ``sm120_triton_kernel.py:_buf_cache`` — which has been +# shipping in this codebase for the Triton path. Reusing D across calls +# avoids the per-call ``torch.empty`` overhead (~5–15 μs of Python work + +# allocator metadata) and the (rare) scratch slice; on hot paths with +# millisecond-scale kernels this is a measurable but small win. +# +# Correctness contract — same as the Triton path: this is a single-stream +# inference cache. The previous call's D consumer must already have read it +# before the next call lands. Inductor-generated ``call(...)`` functions +# satisfy this because they execute serially on the default CUDA stream and +# the returned tensor is consumed before the next op-level dispatch. +# +# To opt out (e.g. when bench-scripting with overlapping streams), set the +# env var ``MAGI_EVT_DISABLE_D_CACHE=1``. +_D_BUF_CACHE: dict = {} +_D_CACHE_DISABLED: bool = os.environ.get("MAGI_EVT_DISABLE_D_CACHE", "0") not in ("0", "", "false", "False") + + +def _get_or_alloc_D(M: int, n_out: int, out_dtype: torch.dtype, device: torch.device) -> "torch.Tensor": + """Return a (possibly cached) (M, n_out) output buffer. + + The buffer is contiguous when ``n_stride == n_out`` (the fast path); when + ``n_out`` is mis-aligned we keep the padded ``[:, :n_out]`` slice so the + fake impl's stride matches at runtime. + """ + # Fast path: cache key first, recompute n_stride only on miss. The cache + # is keyed by (M, n_out, dtype, device_idx); two distinct (n_out, dtype) + # always have the same alignment, so we don't need n_stride in the key. + idx = device.index or 0 # index is None for default device → falsy → 0 + key = (M, n_out, out_dtype, idx) + cached = _D_BUF_CACHE.get(key) + if cached is not None and not _D_CACHE_DISABLED: + return cached + n_stride = _aligned_n_stride(n_out, out_dtype) + if n_stride == n_out: + D = torch.empty((M, n_out), device=device, dtype=out_dtype) + else: + D = torch.empty((M, n_stride), device=device, dtype=out_dtype)[:, :n_out] + if not _D_CACHE_DISABLED: + # Single-entry cache: evict everything else, then install the new one. + # We can't iterate-and-delete on the live dict (RuntimeError under any + # workload that puts >1 entry in the cache — e.g. CP=4 sees multiple + # per-rank shapes during warmup, while a single-card run often reuses + # one shape and never tripped the bug). + _D_BUF_CACHE.clear() + _D_BUF_CACHE[key] = D + return D + + +def _cutlass_root() -> str: + return os.environ.get("MAGI_CUTLASS_ROOT", "/root/cutlass") + + +def _evt_build_dir(key: str) -> str: + cache_root = get_compile_config().cache_root_dir + return os.path.join(cache_root, "evt_kernels", key) + + +def _per_key_lock(key: str) -> threading.Lock: + """Return the per-key build lock; coalesces concurrent compile requests.""" + with _MODULE_LOCKS_GLOBAL: + lock = _MODULE_LOCKS.get(key) + if lock is None: + lock = threading.Lock() + _MODULE_LOCKS[key] = lock + return lock + + +def _compile_evt_module( + ir_json: str, + a_dtype: torch.dtype, + b_dtype: torch.dtype, + b_layout: str = "row", + m_bucket: str = "medium", + N: int = 0, + K: int = 0, +): + """Render + JIT-compile the EVT kernel for ``ir_json``. Process-level cached. + + Cache key: (IR, A dtype, B dtype, b_layout, m_bucket, N, K). Each distinct + weight (N, K) lowers to its own .cu — even though the .cu source is + identical (N/K stay runtime variables), splitting the modules gives every + (N, K) its own runner instance with isolated `best_idx_`. This avoids + cross-(N, K) autotune contamination and matches the user's per-(N, K) + cache layout: e.g. two distinct (N, K) × two M-buckets ⇒ 4 .cu modules. + """ + # Hot-path fast cache: skip ``json.dumps + sha256`` (~10–30 μs each) on + # subsequent calls with the same inputs. + fast_key = (ir_json, a_dtype, b_dtype, b_layout, m_bucket, N, K) + cached = _MODULE_FAST_CACHE.get(fast_key) + if cached is not None: + return cached + + if b_layout not in ("row", "col"): + raise ValueError(f"b_layout must be 'row' or 'col', got {b_layout!r}") + a_str = _DTYPE_TO_STR[a_dtype] + b_str = _DTYPE_TO_STR[b_dtype] + extended = json.dumps( + { + "ir": ir_json, + "a": a_str, + "b": b_str, + "b_layout": b_layout, + "m_bucket": m_bucket, + "N": int(N), + "K": int(K), + "version": 3, + }, + sort_keys=True, + ).encode("utf-8") + key = hashlib.sha256(extended).hexdigest() + + cached = _MODULE_CACHE.get(key) + if cached is not None: + _MODULE_FAST_CACHE[fast_key] = cached + return cached + + lock = _per_key_lock(key) + with lock: + cached = _MODULE_CACHE.get(key) + if cached is not None: + _MODULE_FAST_CACHE[fast_key] = cached + return cached + + # Re-hydrate the IR tree from JSON for codegen. + ir = _ir_from_json(ir_json) + src = render_evt_cu(ir, a_str, b_str, cache_key_str=key, b_layout=b_layout, m_bucket=m_bucket) + + build_dir = _evt_build_dir(key) + os.makedirs(build_dir, exist_ok=True) + src_path = os.path.join(build_dir, "evt.cu") + # Write atomically (tmp + rename) so concurrent processes don't see a + # half-written file. Use a process-specific tmp name to avoid races + # across multiple rank processes generating the same kernel. + tmp_path = f"{src_path}.{os.getpid()}.tmp" + with open(tmp_path, "w") as f: + f.write(src) + os.replace(tmp_path, src_path) + + cutlass_root = _cutlass_root() + from torch.utils.cpp_extension import load + + # cpp_extension.load uses its own file lock under build_directory, so + # multi-process races resolve to a single nvcc invocation. + module = load( + name=f"magi_evt_{key[:12]}", + sources=[src_path], + extra_include_paths=[ + os.path.join(cutlass_root, "include"), + os.path.join(cutlass_root, "tools", "util", "include"), + ], + extra_cflags=["-O3", "-std=c++17"], + extra_cuda_cflags=["-std=c++17", "-O3", "--expt-relaxed-constexpr", "-gencode=arch=compute_120,code=sm_120"], + build_directory=build_dir, + verbose=False, + ) + _MODULE_CACHE[key] = module + _MODULE_FAST_CACHE[fast_key] = module + return module + + +# ── IR (de)serialisation ───────────────────────────────────────────────────── + + +def to_ir_json(node) -> str: + from .evt_ir import to_canonical_json + + return to_canonical_json(node) + + +def _ir_from_json(s: str): + """Inverse of ``to_canonical_json``. Used only to drive codegen at compile + time — the FX pass holds the original Python objects and never round-trips + its own IR through JSON in a hot loop.""" + d = json.loads(s) + return _node_from_dict(d) + + +def _node_from_dict(d): + kind = d["kind"] + if kind == "accum": + return Accum() + if kind == "row_bcast": + return RowBroadcast(input_idx=d["input_idx"], dtype=d["dtype"]) + if kind == "col_bcast": + return ColBroadcast(input_idx=d["input_idx"], dtype=d["dtype"]) + if kind == "aux_load": + return AuxLoad(input_idx=d["input_idx"], dtype=d["dtype"]) + if kind == "compute": + scalar = d.get("scalar") + scalar_val: Optional[float] = float(scalar) if scalar is not None else None + return Compute(op=d["op"], children=tuple(_node_from_dict(c) for c in d["children"]), scalar=scalar_val) + if kind == "store": + return Store(child=_node_from_dict(d["child"]), out_dtype=d["out_dtype"]) + raise ValueError(f"Unknown IR kind {kind!r}") + + +# ── swiglu7 dual-gemm extension loader ──────────────────────────────────────── +# Per-(m_bucket, N, K) cache. The .cu source is identical across keys (N/K stay +# runtime variables); we still build separate modules so each runner instance +# hosts exactly one (N, K), giving every weight shape its own isolated +# best_idx_. Two distinct (N, K) × two M-buckets ⇒ 4 modules. +_SWIGLU7_FAST_CACHE: dict = {} # (m_bucket, N, K) → loaded module +_SWIGLU7_BUILD_LOCKS: dict = {} # (m_bucket, N, K) → threading.Lock + + +def _compile_swiglu7_dual(m_bucket: str, N: int, K: int): + """Lazy-load a per-(bucket, N, K) instance of the vendored DualGemm kernel. + + Parameters + ---------- + m_bucket : "small" | "medium" | "large" + Bucket of the activation M dim — included in the cache key so e.g. + small-M (decode) can autotune to a different best tile than large-M + (prefill) for the same (N, K). + N, K : int + Static weight shape from B (the underlying (N, K) row-major tensor). + Distinct (N, K) get distinct modules so their autotune state is + independent. + """ + fast_key = (m_bucket, int(N), int(K)) + cached = _SWIGLU7_FAST_CACHE.get(fast_key) + if cached is not None: + return cached + + with _SWIGLU7_LOCK: + lock = _SWIGLU7_BUILD_LOCKS.get(fast_key) + if lock is None: + lock = threading.Lock() + _SWIGLU7_BUILD_LOCKS[fast_key] = lock + with lock: + cached = _SWIGLU7_FAST_CACHE.get(fast_key) + if cached is not None: + return cached + + cutlass_root = _cutlass_root() + here = os.path.dirname(os.path.abspath(__file__)) + src = os.path.join(here, "cutlass_kernels", "swiglu7_epi_one_stage.cu") + if not os.path.exists(src): + raise FileNotFoundError(f"vendored swiglu7 source not found: {src}") + cache_root = get_compile_config().cache_root_dir + # Build dir embeds (bucket, N, K) so distinct keys get their own + # build artefacts. cpp_extension uses the dir as the cache identity. + build_tag = f"{m_bucket}_N{N}_K{K}" + build_dir = os.path.join(cache_root, "evt_kernels", f"swiglu7_dual_{build_tag}") + os.makedirs(build_dir, exist_ok=True) + from torch.utils.cpp_extension import load + + module = load( + name=f"magi_swiglu7_dual_{build_tag}", + sources=[src], + extra_include_paths=[ + os.path.join(cutlass_root, "include"), + os.path.join(cutlass_root, "tools", "util", "include"), + os.path.join(cutlass_root, "examples"), + os.path.join(here, "cutlass_kernels"), + ], + extra_cflags=["-O3", "-std=c++17"], + extra_cuda_cflags=["-std=c++17", "-O3", "--expt-relaxed-constexpr", "-gencode=arch=compute_120,code=sm_120"], + build_directory=build_dir, + verbose=False, + ) + _SWIGLU7_FAST_CACHE[fast_key] = module + return module + + +# ── torch.library backend impls ─────────────────────────────────────────────── + + +# Single-entry scratch cache for the rare mis-aligned-N path. Same greedy +# eviction policy as ``_D_BUF_CACHE`` — bounded memory across many shapes +# (e.g. CP=4 sees several per-rank M values during warmup; we don't want a +# scratch buffer for every one). +_SCRATCH_CACHE: dict = {} + + +def _get_or_alloc_scratch(M: int, n_out: int, out_dtype: torch.dtype, device: torch.device) -> "torch.Tensor": + if _D_CACHE_DISABLED: + return torch.empty((M, n_out), device=device, dtype=out_dtype) + idx = device.index or 0 + key = (M, n_out, out_dtype, idx) + cached = _SCRATCH_CACHE.get(key) + if cached is not None: + return cached + s = torch.empty((M, n_out), device=device, dtype=out_dtype) + # Greedy eviction: one shape at a time. + _SCRATCH_CACHE.clear() + _SCRATCH_CACHE[key] = s + return s + + +# ── Dispatch fast-cache ────────────────────────────────────────────────────── +# Hot-path bottleneck reduction: collapse the four-step +# out_dtype_from_id → _m_bucket → _compile_* → mod.attr-lookup +# chain into a single dict.get() returning a pre-bound callable plus the +# small amount of immutable metadata the kernel-launch site needs. +# +# Key shape: (kind, ir_json, A.dtype, B.dtype, N, K, m_bucket, out_dtype). +# Most of these are static per FX-emit site (kind / ir_json / dtypes / N / K) +# — only m_bucket varies with M. So the cache reaches steady state after the +# first time each (site, bucket) is seen. +# +# Each entry holds: +# * kernel_call : pre-bound mod.evt_matmul_out / swiglu7_dual_matmul_out +# * is_evt : True for evt_row/evt_col (need extras list), False for swiglu7 +# * out_dtype : torch.dtype to pass to D allocation +class _DispatchEntry: + __slots__ = ("kernel_call", "is_evt", "out_dtype") + + def __init__(self, kernel_call, is_evt, out_dtype): + self.kernel_call = kernel_call + self.is_evt = is_evt + self.out_dtype = out_dtype + + +_DISPATCH_CACHE: dict = {} + + +def _resolve_dispatch(kind, ir_json, a_dtype, b_dtype, N_w, K_w, m_bucket, out_dtype): + """Slow-path resolver — compiles the .cu module (cache miss) and binds + the kernel callable. Cached by (kind, ir_json, A_dt, B_dt, N, K, bucket, + out_dtype) so each FX site × bucket only pays this once.""" + if kind == "swiglu7_dual": + mod = _compile_swiglu7_dual(m_bucket, N_w, K_w) + return _DispatchEntry(mod.swiglu7_dual_matmul_out, False, out_dtype) + if kind == "evt_row" or kind == "evt": + b_layout = "row" + elif kind == "evt_col": + b_layout = "col" + else: + raise ValueError(f"Unknown EVT kind {kind!r}") + mod = _compile_evt_module(ir_json, a_dtype, b_dtype, b_layout=b_layout, m_bucket=m_bucket, N=N_w, K=K_w) + return _DispatchEntry(mod.evt_matmul_out, True, out_dtype) + + +@torch.library.impl(_LIB, "matmul_custom_evt", "CUDA") +def _matmul_custom_evt_cuda(A, B, extras, ir_json, kind, n_out, out_dtype_id_): + """Runtime entry point for the EVT-fused matmul op. + + Hot path is heavily inlined to keep per-call Python overhead under ~2 μs: + one dict.get() resolves the kernel callable + metadata, then we allocate D + (with a single-entry greedy cache) and call straight into the C++ kernel. + + Layout contract — the FX pass owns this; do not rewrite operands here: + * ``kind == "evt_row"`` : B is contiguous (K, N) row-major. + * ``kind == "evt_col"`` : B is the underlying (N, K) row-major weight; the + kernel was rendered with ``LayoutB = ColumnMajor`` so it reads (K, N) + from the same bytes via stride (1, K). + * ``kind == "swiglu7_dual"`` : B is the underlying (N, K) row-major weight + (the FX pass already replaced the ``permute([1,0])`` view with its + operand). The DualGemm kernel reads it as ColumnMajor + ldB=2K. + + Calling ``.contiguous()`` on B here would silently break the col / swiglu7 + paths by materialising a (K, N) row-major copy that no longer matches the + LayoutB the kernel was compiled with — every B value would be wrong. + """ + # ── Step 1: resolve dispatch entry (one dict lookup on the fast path) ── + # B.size(0)/size(1) are slightly faster than .shape[0]/[1] (avoid Python + # tuple construction). For all 3 kinds B's leading dim ≠ K — the launcher + # / runner derives N internally from b_layout, but for the dispatch cache + # key we just need a stable per-site discriminator, so passing the raw + # B.size pair is enough. + B_size0 = B.size(0) + B_size1 = B.size(1) + M = A.size(0) + # Inline _m_bucket: avoid the ~300 ns function call. + if M <= 256: + m_bucket = "small" + elif M <= 2048: + m_bucket = "medium" + else: + m_bucket = "large" + # Inline out_dtype_from_id: skip the function call frame. + out_dtype = _ID_TO_DTYPE[out_dtype_id_] + # B's (N, K) interpretation depends on kind. For evt_row B is (K, N), + # for evt_col / swiglu7_dual B is the underlying (N, K). Either way we + # only need (B_size0, B_size1) to disambiguate distinct weights — the + # resolver re-computes N/K correctly for compilation. + a_dtype = A.dtype + b_dtype_ = B.dtype + fast_key = (kind, ir_json, a_dtype, b_dtype_, B_size0, B_size1, m_bucket, out_dtype) + entry = _DISPATCH_CACHE.get(fast_key) + if entry is None: + # Map B sizes to (N_w, K_w) in the layout the compile path expects. + if kind == "evt_row": + K_w, N_w = B_size0, B_size1 + else: + # evt_col / swiglu7_dual: B is (N, K) underlying weight. + N_w, K_w = B_size0, B_size1 + entry = _resolve_dispatch(kind, ir_json, a_dtype, b_dtype_, N_w, K_w, m_bucket, out_dtype) + _DISPATCH_CACHE[fast_key] = entry + + # ── Step 2: alloc / fetch D (greedy single-entry cache, inlined) ── + # D matches the fake impl's shape. CUTLASS launchers require D contiguous; + # when n_out happens to be mis-aligned the row stride is padded and we + # route through a scratch buffer. + if _D_CACHE_DISABLED: + n_stride = _aligned_n_stride(n_out, out_dtype) + if n_stride == n_out: + D = torch.empty((M, n_out), device=A.device, dtype=out_dtype) + else: + D = torch.empty((M, n_stride), device=A.device, dtype=out_dtype)[:, :n_out] + else: + dev_idx = A.device.index or 0 + d_key = (M, n_out, out_dtype, dev_idx) + D = _D_BUF_CACHE.get(d_key) + if D is None: + n_stride = _aligned_n_stride(n_out, out_dtype) + if n_stride == n_out: + D = torch.empty((M, n_out), device=A.device, dtype=out_dtype) + else: + D = torch.empty((M, n_stride), device=A.device, dtype=out_dtype)[:, :n_out] + _D_BUF_CACHE.clear() + _D_BUF_CACHE[d_key] = D + + # ── Step 3: dispatch — pre-bound callable, single C++ trampoline ── + # `D.stride(0) != n_out` is the only branch we take per call to decide + # whether we need the scratch route. Cheap C++ attribute compare. + needs_scratch = D.stride(0) != n_out + kernel_call = entry.kernel_call + + if entry.is_evt: + if needs_scratch: + scratch = _get_or_alloc_scratch(M, n_out, out_dtype, A.device) + kernel_call(A, B, extras, scratch) + D.copy_(scratch) + return D + kernel_call(A, B, extras, D) + return D + + # swiglu7_dual: extras is always [] here (FX pass guarantees). + if needs_scratch: + scratch = _get_or_alloc_scratch(M, n_out, out_dtype, A.device) + kernel_call(A, B, scratch) + D.copy_(scratch) + return D + kernel_call(A, B, D) + return D + + +@torch.library.register_fake("magi_epilogue::matmul_custom_evt") +def _matmul_custom_evt_fake(A, B, extras, ir_json, kind, n_out, out_dtype_id_): + out_dtype = out_dtype_from_id(out_dtype_id_) + n_stride = _aligned_n_stride(n_out, out_dtype) + return A.new_empty_strided((A.shape[0], n_out), (n_stride, 1), dtype=out_dtype) diff --git a/magi_compiler/passes/piecewise_graph/fusion/blackwell_geforce/matmul_epilogue_fusion.py b/magi_compiler/passes/piecewise_graph/fusion/blackwell_geforce/matmul_epilogue_fusion.py new file mode 100644 index 0000000..dd5dc99 --- /dev/null +++ b/magi_compiler/passes/piecewise_graph/fusion/blackwell_geforce/matmul_epilogue_fusion.py @@ -0,0 +1,716 @@ +# Copyright (c) 2026 SandAI. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""FX pass that fuses aten.mm + elementwise epilogue into a CUTLASS EVT call. + +Two backends: + * Generic EVT — for the 6 non-swiglu activations and 1-D bias/scale variants. + Builds an IR tree (see ``evt_ir.py``), serialises to JSON, replaces the + matched chain with a single ``torch.ops.magi_epilogue.matmul_custom_evt`` + call. The runtime renders + JIT-compiles a CUTLASS Sm80EVT kernel keyed by + the IR hash (see ``evt_runtime.py``). + * swiglu7 — pattern-matches the canonical recipe (slice-stride-2 + dual + clamps + scaled SiLU) and dispatches to a vendored DualGemm one-stage + kernel that writes (M, N/2) directly. + +Eligibility gates (alignment, B layout, dtype) are checked up-front. Anything +not eligible stays as ``aten.mm`` for cuBLAS to handle. We do NOT fall back to +the Triton fusion path on sm120; per user decision, EVT replaces it entirely. +""" + +from __future__ import annotations + +import operator +from typing import List, Optional, Tuple + +import torch +import torch.fx as fx + +from magi_compiler.passes.pass_base import MagiInductorPass + +from . import evt_runtime # ensures torch.library op + fake impl are registered +from .evt_ir import Accum, AuxLoad, ColBroadcast, Compute, RowBroadcast, Store, is_trivial, num_extras, to_canonical_json + +# ── Op tables ──────────────────────────────────────────────────────────────── +# Pure passthrough — no value or dtype change; alias the same IR node. +_PASSTHROUGH_OPS = frozenset({torch.ops.aten.clone.default, torch.ops.aten.contiguous.default, torch.ops.aten.alias.default}) + +# Dtype-conversion ops; the EVT compute is always fp32 internally so these are +# absorbed as no-ops as long as the start/end of the chain reach the same final +# precision (we capture that via the Store node's out_dtype). +_TYPE_CONV_OPS = frozenset({torch.ops.prims.convert_element_type.default, torch.ops.aten._to_copy.default}) + +# Unary ops with a direct EVT IR equivalent. +_UNARY_OPS = { + torch.ops.aten.neg.default: "neg", + torch.ops.aten.sigmoid.default: "sigmoid", + torch.ops.aten.tanh.default: "tanh", + torch.ops.aten.silu.default: "silu", + torch.ops.aten.relu.default: "relu", + torch.ops.aten.square.default: "square", + torch.ops.aten.erf.default: "erf", + torch.ops.aten.exp.default: "exp", + torch.ops.aten.log.default: "log", + torch.ops.aten.sqrt.default: "sqrt", + torch.ops.aten.rsqrt.default: "rsqrt", + torch.ops.aten.abs.default: "abs", +} + +# Binary tensor ops. +_BINARY_OPS = { + torch.ops.aten.add.Tensor: "add", + torch.ops.aten.sub.Tensor: "sub", + torch.ops.aten.mul.Tensor: "mul", + torch.ops.aten.div.Tensor: "div", + torch.ops.aten.maximum.default: "max", + torch.ops.aten.minimum.default: "min", + operator.add: "add", + operator.sub: "sub", + operator.mul: "mul", + operator.truediv: "div", +} + +# Scalar binary ops → SCALAR_UNARY_OPS in IR. +_SCALAR_BINARY_TO_SCALAR_UNARY = { + torch.ops.aten.add.Scalar: "add_scalar", + torch.ops.aten.sub.Scalar: "sub_scalar", + torch.ops.aten.mul.Scalar: "mul_scalar", + torch.ops.aten.div.Scalar: "div_scalar", +} + + +# Output-dtype encode helper (mirrors evt_runtime). +_DTYPE_TO_STR = {torch.bfloat16: "bfloat16", torch.float16: "float16", torch.float32: "float32"} + + +def _val_dtype(node) -> Optional[torch.dtype]: + val = node.meta.get("val") if isinstance(node, fx.Node) else None + return val.dtype if val is not None else None + + +def _val_shape(node) -> Optional[Tuple]: + val = node.meta.get("val") if isinstance(node, fx.Node) else None + return tuple(val.shape) if val is not None else None + + +def _val_stride(node) -> Optional[Tuple]: + val = node.meta.get("val") if isinstance(node, fx.Node) else None + try: + return tuple(val.stride()) if val is not None else None + except Exception: + return None + + +def _is_static_int(x) -> bool: + return type(x) is int + + +def _is_transpose_node(n) -> bool: + """True iff ``n`` is a 2-D transpose (aten.t / transpose(0,1) / permute([1,0])).""" + if not isinstance(n, fx.Node) or n.op != "call_function": + return False + if n.target is torch.ops.aten.t.default: + return True + if n.target is torch.ops.aten.transpose.int: + # transpose(x, dim0, dim1) — accept (0, 1) on a 2D tensor. + if len(n.args) >= 3: + d0, d1 = n.args[1], n.args[2] + return {d0, d1} == {0, 1} + return False + if n.target is torch.ops.aten.permute.default: + # permute(x, [1, 0]) on a 2D tensor. + if len(n.args) >= 2: + perm = n.args[1] + return list(perm) == [1, 0] + return False + return False + + +def _b_layout_kind(B_node): + """Classify B for the EVT generic path. + + Returns (b_layout, underlying_b_node, n_dim) where: + * b_layout = "row" : B is (K, N) row-major contiguous; pass B as-is. + * b_layout = "col" : B is a stride-transpose of a contiguous (N, K) + tensor; pass the underlying tensor; kernel uses + LayoutB=ColumnMajor. + * (None, None, None) : B is not in a supported layout. + """ + shape = _val_shape(B_node) + stride = _val_stride(B_node) + if shape is None or stride is None or len(shape) != 2: + return None, None, None + K_or_N0, N_or_K1 = shape[0], shape[1] + # Contiguous (K, N): row layout. N = shape[1]. + if stride == (N_or_K1, 1): + return "row", B_node, N_or_K1 + # Stride-transposed (K, N) view of a contig (N, K) weight: stride == (1, K). + # The underlying tensor is the transpose-producer's input when the FX + # graph models the view explicitly via t/transpose/permute([1,0]); fall + # back to using B itself (its data_ptr is the same). + if _is_transpose_node(B_node): + weight = B_node.args[0] + w_shape = _val_shape(weight) if isinstance(weight, fx.Node) else None + w_stride = _val_stride(weight) if isinstance(weight, fx.Node) else None + if w_shape is not None and len(w_shape) == 2 and w_stride == (w_shape[1], 1): + # weight is (N, K) row-major contig; N = w_shape[0]. + return "col", weight, w_shape[0] + # Generic stride-transposed view (no explicit transpose node) — also OK: + # we read the same memory bytes as a (N, K) row-major buffer at B itself. + if stride == (1, K_or_N0): + # B is (K, N) col-major == underlying (N, K) row-major. We don't have + # an explicit weight node so we pass B directly; the kernel reads + # (N, K) with N = shape[1], K = shape[0]. Detection via stride alone. + return "col", B_node, N_or_K1 + return None, None, None + + +# ── Pass ───────────────────────────────────────────────────────────────────── + + +# Sentinel returned by _try_fuse_evt to communicate "abort, leave mm intact". +_ABORT = object() + + +class MatmulEvtEpilogueFusionPass(MagiInductorPass): + """Fuse aten.mm + elementwise chain into a CUTLASS EVT call (sm_120).""" + + def __init__(self, allow_extras: bool = True) -> None: + # On non-sm120 we degrade to a no-op; the manager wires us only on + # sm120 anyway, but defending against misuse is cheap. + try: + cap = torch.cuda.get_device_capability() if torch.cuda.is_available() else (0, 0) + except Exception: + cap = (0, 0) + self._enabled = cap[0] >= 12 + self.allow_extras = allow_extras + + def __call__(self, graph: fx.Graph) -> bool: + if not self._enabled: + return False + fused = 0 + for node in list(graph.nodes): + if node.op != "call_function": + continue + if node.target not in (torch.ops.aten.mm.default, torch.ops.aten.mm): + continue + r = self._try_fuse_evt(graph, node) + if r: + fused += 1 + if fused: + graph.eliminate_dead_code() + return fused > 0 + + # ── Generic EVT chain walker ────────────────────────────────────────────── + + def _try_fuse_evt(self, graph: fx.Graph, mm_node: fx.Node) -> bool: + A, B = mm_node.args[0], mm_node.args[1] + if not isinstance(A, fx.Node) or not isinstance(B, fx.Node): + return False + a_dtype = _val_dtype(A) + b_dtype = _val_dtype(B) + if a_dtype not in (torch.bfloat16, torch.float16) or a_dtype != b_dtype: + return False + # Alignment gates — bf16/fp16 require K % 8. + a_shape = _val_shape(A) + b_shape = _val_shape(B) + if a_shape is None or b_shape is None or len(a_shape) != 2 or len(b_shape) != 2: + return False + K = a_shape[1] + N = b_shape[1] + if _is_static_int(K) and (K % 8 != 0): + return False + if _is_static_int(N) and (N % 4 != 0): + return False + + # node_to_ir: each fused fx.Node → its IR subtree. mm_node maps to Accum. + node_to_ir: dict = {mm_node: Accum()} + # In-order list of fused fx nodes (for erase + escape detection). + fused_nodes: List[fx.Node] = [mm_node] + # Walked-and-removed nodes including type-conv/passthrough that don't + # appear in node_to_ir as new IR nodes (they alias their input). + walk_seen: List[fx.Node] = [mm_node] + # External tensors injected as RowBroadcast/ColBroadcast/AuxLoad leaves. + # extras_nodes[i] is the fx.Node passed at runtime as extras[i]. + extras_nodes: List[fx.Node] = [] + # Tracks whether the IR has any swiglu7-style slice. If so we abort + # generic EVT and try the swiglu7 matcher instead. + saw_slice = False + + last_node = mm_node + last_ir = node_to_ir[mm_node] + + # Walk consumers in source order, greedily absorbing supported ops. + curr = mm_node.next + while curr is not None and curr.op != "output": + uses_fused = any(isinstance(a, fx.Node) and a in node_to_ir for a in curr.args) + if not uses_fused: + curr = curr.next + continue + + target = curr.target + + # ── Pass-through (clone / contiguous / alias) ───────────────────── + if target in _PASSTHROUGH_OPS: + node_to_ir[curr] = node_to_ir[curr.args[0]] + walk_seen.append(curr) + last_node = curr + last_ir = node_to_ir[curr] + curr = curr.next + continue + + # ── Type conversion (no-op in fp32 EVT) ─────────────────────────── + if target in _TYPE_CONV_OPS: + node_to_ir[curr] = node_to_ir[curr.args[0]] + walk_seen.append(curr) + last_node = curr + last_ir = node_to_ir[curr] + curr = curr.next + continue + + # ── Pure view ops (only if shape unchanged) ─────────────────────── + if target in (torch.ops.aten.view.default, torch.ops.aten.reshape.default, torch.ops.aten._unsafe_view.default): + in_shape = _val_shape(curr.args[0]) + out_shape = _val_shape(curr) + if in_shape == out_shape: + node_to_ir[curr] = node_to_ir[curr.args[0]] + walk_seen.append(curr) + last_node = curr + last_ir = node_to_ir[curr] + curr = curr.next + continue + break + + # ── Slice stride-2 (swiglu marker) ──────────────────────────────── + if target is torch.ops.aten.slice.Tensor: + step = curr.args[4] if len(curr.args) > 4 else curr.kwargs.get("step", 1) + if step == 2: + saw_slice = True + break + + # ── Unary ops ───────────────────────────────────────────────────── + if target in _UNARY_OPS: + op_name = _UNARY_OPS[target] + child_ir = node_to_ir[curr.args[0]] + ir = Compute(op_name, (child_ir,)) + node_to_ir[curr] = ir + fused_nodes.append(curr) + walk_seen.append(curr) + last_node = curr + last_ir = ir + curr = curr.next + continue + + # ── GELU (default = erf, alternative = tanh) ────────────────────── + if target is torch.ops.aten.gelu.default: + approx = curr.kwargs.get("approximate", "none") + op_name = "gelu_tanh" if approx == "tanh" else "gelu_erf" + child_ir = node_to_ir[curr.args[0]] + ir = Compute(op_name, (child_ir,)) + node_to_ir[curr] = ir + fused_nodes.append(curr) + walk_seen.append(curr) + last_node = curr + last_ir = ir + curr = curr.next + continue + + # ── Scalar variants of add/sub/mul/div ──────────────────────────── + if target in _SCALAR_BINARY_TO_SCALAR_UNARY: + op_name = _SCALAR_BINARY_TO_SCALAR_UNARY[target] + child_ir = node_to_ir[curr.args[0]] + if not isinstance(curr.args[1], (int, float)): + break + scalar = float(curr.args[1]) + ir = Compute(op_name, (child_ir,), scalar=scalar) + node_to_ir[curr] = ir + fused_nodes.append(curr) + walk_seen.append(curr) + last_node = curr + last_ir = ir + curr = curr.next + continue + + # ── Clamp family ────────────────────────────────────────────────── + if target in (torch.ops.aten.clamp.default, torch.ops.aten.clamp_min.default, torch.ops.aten.clamp_max.default): + child_ir = node_to_ir[curr.args[0]] + if target is torch.ops.aten.clamp_min.default: + lo = curr.args[1] if len(curr.args) > 1 else curr.kwargs.get("min") + hi = None + elif target is torch.ops.aten.clamp_max.default: + lo = None + hi = curr.args[1] if len(curr.args) > 1 else curr.kwargs.get("max") + else: + lo = curr.args[1] if len(curr.args) > 1 else curr.kwargs.get("min") + hi = curr.args[2] if len(curr.args) > 2 else curr.kwargs.get("max") + if (lo is not None and not isinstance(lo, (int, float))) or ( + hi is not None and not isinstance(hi, (int, float)) + ): + break + ir_now = child_ir + if lo is not None: + ir_now = Compute("clamp_min_c", (ir_now,), scalar=float(lo)) + if hi is not None: + ir_now = Compute("clamp_max_c", (ir_now,), scalar=float(hi)) + node_to_ir[curr] = ir_now + fused_nodes.append(curr) + walk_seen.append(curr) + last_node = curr + last_ir = ir_now + curr = curr.next + continue + + # ── pow.Tensor_Scalar — only the small-int special-cases ────────── + if target is torch.ops.aten.pow.Tensor_Scalar: + exp = curr.args[1] if len(curr.args) > 1 else None + child_ir = node_to_ir[curr.args[0]] + if exp == 2 or exp == 2.0: + ir = Compute("square", (child_ir,)) + elif isinstance(exp, (int, float)): + ir = Compute("pow_scalar", (child_ir,), scalar=float(exp)) + else: + break + node_to_ir[curr] = ir + fused_nodes.append(curr) + walk_seen.append(curr) + last_node = curr + last_ir = ir + curr = curr.next + continue + + # ── Binary tensor ops ───────────────────────────────────────────── + if target in _BINARY_OPS: + op_name = _BINARY_OPS[target] + lhs_raw = curr.args[0] + rhs_raw = curr.args[1] + # Fold int/float scalars on the RHS to scalar variants. + if isinstance(rhs_raw, (int, float)) and isinstance(lhs_raw, fx.Node) and lhs_raw in node_to_ir: + scalar_op = {"add": "add_scalar", "sub": "sub_scalar", "mul": "mul_scalar", "div": "div_scalar"}.get( + op_name + ) + if scalar_op is None: + break + ir = Compute(scalar_op, (node_to_ir[lhs_raw],), scalar=float(rhs_raw)) + node_to_ir[curr] = ir + fused_nodes.append(curr) + walk_seen.append(curr) + last_node = curr + last_ir = ir + curr = curr.next + continue + # Fold scalar-on-LHS for commutative ops; for sub/div we need rsub/rdiv. + if isinstance(lhs_raw, (int, float)) and isinstance(rhs_raw, fx.Node) and rhs_raw in node_to_ir: + if op_name in ("add", "mul"): + scalar_op = "add_scalar" if op_name == "add" else "mul_scalar" + ir = Compute(scalar_op, (node_to_ir[rhs_raw],), scalar=float(lhs_raw)) + elif op_name == "sub": + ir = Compute("rsub_scalar", (node_to_ir[rhs_raw],), scalar=float(lhs_raw)) + else: + break + node_to_ir[curr] = ir + fused_nodes.append(curr) + walk_seen.append(curr) + last_node = curr + last_ir = ir + curr = curr.next + continue + # Both tensor — either internal (already in IR) or external. + lhs_ir = self._ir_for_arg(lhs_raw, node_to_ir, extras_nodes, A, B) + rhs_ir = self._ir_for_arg(rhs_raw, node_to_ir, extras_nodes, A, B) + if lhs_ir is None or rhs_ir is None: + break + ir = Compute(op_name, (lhs_ir, rhs_ir)) + node_to_ir[curr] = ir + fused_nodes.append(curr) + walk_seen.append(curr) + last_node = curr + last_ir = ir + curr = curr.next + continue + + # Unsupported op — stop greedy walk. + break + + # If we saw a stride-2 slice and the chain is plausibly swiglu7, try + # the dedicated matcher. It rebuilds independently from mm_node. + if saw_slice: + return self._try_fuse_swiglu7(graph, mm_node) + + # Verify we made progress. + if last_ir is node_to_ir[mm_node]: + return False # only Accum — replacing cuBLAS with EVT is no win + + # Refuse if any escape: an intermediate fused node is consumed outside + # the fused region. (EVT has no "extra outputs"; the user explicitly + # opted out of cross-domain fan-out.) + # + # The exclusion ``n is not last_node`` is intentional — the last node + # in the fused chain becomes the EVT op's output and is allowed to + # have downstream consumers (that's the whole point of fusion). + # Earlier writes ([:-1] explicitly skips the last position) must not + # have any external user, otherwise the fused chain would silently + # drop their value. This previously read ``walk_seen[:-0]`` which is + # ``walk_seen[:0]`` (an empty slice!) so escape detection was a no-op + # and trivially-fusable chains like ``mm → add(residual) → square`` + # were emitted even when ``add(residual)`` was reused downstream. + fused_set = set(fused_nodes) | set(walk_seen) + for n in walk_seen[:-1]: + for u in n.users: + if u not in fused_set: + return False + + # Final eligibility check: A contiguous, B in a supported layout. + a_stride = _val_stride(A) + if a_stride is None: + return False + a_shape_now = _val_shape(A) + if a_stride != (a_shape_now[1], 1): + return False + b_layout, b_underlying, n_dim = _b_layout_kind(B) + if b_layout is None: + return False + + # Determine output dtype from the last fused node's FakeTensor metadata. + out_dt = _val_dtype(last_node) or torch.bfloat16 + if out_dt not in _DTYPE_TO_STR: + return False + + ir_root = Store(child=last_ir, out_dtype=_DTYPE_TO_STR[out_dt]) + if is_trivial(ir_root): + return False + # If extras are disabled, refuse any IR that needs them. + if not self.allow_extras and num_extras(ir_root) > 0: + return False + + ir_json = to_canonical_json(ir_root) + n_out = n_dim + out_dt_id = evt_runtime.out_dtype_id(out_dt) + kind = "evt_row" if b_layout == "row" else "evt_col" + + with graph.inserting_after(last_node): + new_node = graph.call_function( + torch.ops.magi_epilogue.matmul_custom_evt.default, + args=(A, b_underlying, extras_nodes, ir_json, kind, n_out, out_dt_id), + ) + # Propagate FakeTensor meta so downstream Inductor checks pass. + try: + val_last = last_node.meta.get("val") + if val_last is not None: + # Propagate but with 128B-aligned stride matching what the + # CUDA impl actually returns. + new_val = val_last.new_empty_strided( + val_last.shape, (evt_runtime._aligned_n_stride(int(val_last.shape[-1]), val_last.dtype), 1) + ) + new_node.meta["val"] = new_val + except Exception: + pass + + last_node.replace_all_uses_with(new_node) + for n in reversed(walk_seen): + if len(n.users) == 0 and n is not new_node: + graph.erase_node(n) + return True + + def _ir_for_arg(self, arg, node_to_ir, extras_nodes, A_node, B_node): + """Return an IR subtree for a binary-op operand. Internal → IR; external + → leaf (RowBroadcast / ColBroadcast / AuxLoad). None ⇒ abort.""" + if not isinstance(arg, fx.Node): + return None + if arg in node_to_ir: + return node_to_ir[arg] + if not self.allow_extras: + return None + # Classify external tensor by shape relative to (M, N). + a_shape = _val_shape(A_node) + b_shape = _val_shape(B_node) + if a_shape is None or b_shape is None: + return None + M = a_shape[0] + N = b_shape[1] + shape = _val_shape(arg) + stride = _val_stride(arg) + dt = _val_dtype(arg) + if shape is None or dt is None: + return None + dt_str = _DTYPE_TO_STR.get(dt) + if dt_str is None: + return None + # 1-D case: must distinguish (N,) vs (M,). Compare ints directly. + # When M is SymInt (dynamic batch dim) the M==N collision can't happen + # at compile time, so trust the (N,) match for RowBroadcast. Only the + # "both static + equal" case is ambiguous and we abort. + if len(shape) == 1: + n0 = shape[0] + m_is_static = _is_static_int(M) + n_is_static = _is_static_int(N) + if n_is_static and n0 == N: + # Could still collide with a (M,) col-broadcast iff M is also + # static and equal — abort in that ambiguous case. + if m_is_static and n0 == M: + return None + idx = self._add_extra(extras_nodes, arg) + return RowBroadcast(input_idx=idx, dtype=dt_str) + if m_is_static and n0 == M: + idx = self._add_extra(extras_nodes, arg) + return ColBroadcast(input_idx=idx, dtype=dt_str) + return None + if len(shape) == 2: + # (1, N) row-broadcast view. + if shape[0] == 1 and shape[1] == N: + idx = self._add_extra(extras_nodes, arg) + return RowBroadcast(input_idx=idx, dtype=dt_str) + # (M, 1) col-broadcast view. + if shape[1] == 1 and shape[0] == M: + idx = self._add_extra(extras_nodes, arg) + return ColBroadcast(input_idx=idx, dtype=dt_str) + # Full (M, N) aux load — require row-major contiguous. + if shape[0] == M and shape[1] == N and stride is not None and stride[1] == 1: + idx = self._add_extra(extras_nodes, arg) + return AuxLoad(input_idx=idx, dtype=dt_str) + return None + + def _add_extra(self, extras_nodes, arg) -> int: + for i, e in enumerate(extras_nodes): + if e is arg: + return i + extras_nodes.append(arg) + return len(extras_nodes) - 1 + + # ── swiglu7 special-case ────────────────────────────────────────────────── + + def _try_fuse_swiglu7(self, graph: fx.Graph, mm_node: fx.Node) -> bool: + """Match the canonical swiglu7 epilogue and dispatch to DualGemm. + + We do not attempt to encode swiglu7 in the EVT IR (the dual GEMM is a + whole different kernel structure). Instead we walk forward from mm_node + looking for the exact pattern produced by ``athena.activation.swiglu7`` + after Inductor decomposition. + + On a successful match we emit the magi_epilogue.matmul_custom_evt op + with kind="swiglu7_dual". The ``B`` argument must be the underlying + weight tensor of shape (N, K) — typically the predecessor of an + ``aten.t`` node feeding the mm. + """ + # Recover the underlying weight: B should be a 2-D transpose + # (aten.t / transpose(0,1) / permute([1,0])) of a contiguous (N, K) + # weight. Otherwise bail (no two-stage fallback). + B_node = mm_node.args[1] + if not isinstance(B_node, fx.Node) or not _is_transpose_node(B_node): + return False + weight_node = B_node.args[0] + if not isinstance(weight_node, fx.Node): + return False + w_shape = _val_shape(weight_node) + w_stride = _val_stride(weight_node) + if w_shape is None or len(w_shape) != 2 or w_stride is None: + return False + N, K = w_shape + if not (_is_static_int(N) and N % 2 == 0): + return False + if w_stride != (K, 1): + return False # not contiguous (N, K) — abort + a_dtype = _val_dtype(mm_node.args[0]) + if a_dtype != torch.bfloat16 or _val_dtype(weight_node) != torch.bfloat16: + return False + + # We walk the chain in source order and collect every node belonging to + # the swiglu7 epilogue — anything else aborts. We don't need to verify + # the exact structure (the kernel does that intrinsically); we just need + # to find the final tensor that becomes the chain's only output, plus + # the set of nodes to erase. + chain_nodes: List[fx.Node] = [] + chain_set: set = {mm_node} + last_chain_node: Optional[fx.Node] = None + curr = mm_node.next + while curr is not None and curr.op != "output": + uses_chain = any(isinstance(a, fx.Node) and a in chain_set for a in curr.args) + if not uses_chain: + curr = curr.next + continue + if curr.target not in ( + torch.ops.aten.slice.Tensor, + torch.ops.aten.clamp.default, + torch.ops.aten.clamp_min.default, + torch.ops.aten.clamp_max.default, + torch.ops.aten.sigmoid.default, + torch.ops.aten.mul.Tensor, + torch.ops.aten.add.Tensor, + torch.ops.aten.add.Scalar, + torch.ops.aten.mul.Scalar, + torch.ops.prims.convert_element_type.default, + torch.ops.aten._to_copy.default, + torch.ops.aten.clone.default, + torch.ops.aten.contiguous.default, + torch.ops.aten.alias.default, + torch.ops.aten.view.default, + torch.ops.aten.reshape.default, + torch.ops.aten._unsafe_view.default, + ): + # Non-whitelist op consuming the chain → it's the boundary. + # Finalise last_chain_node as the previous node and stop. + # The output-shape check below verifies we actually saw the + # swiglu7 pattern (chain output's last dim must equal N//2). + break + chain_nodes.append(curr) + chain_set.add(curr) + last_chain_node = curr + curr = curr.next + + if last_chain_node is None: + return False + # Output dtype from the final node. + out_dt = _val_dtype(last_chain_node) or torch.bfloat16 + out_shape = _val_shape(last_chain_node) + if out_shape is None or len(out_shape) != 2: + return False + if not _is_static_int(out_shape[1]) or out_shape[1] != N // 2: + # The swiglu7 output's last dim must be N/2. + return False + + # No escape: every chain node's external uses must funnel through the + # final node (otherwise the DualGemm kernel produces only D and we'd + # lose the intermediate consumer). + for n in chain_nodes[:-1]: + for u in n.users: + if u not in chain_set: + return False + + # Emit the call. We do NOT pass IR JSON — the swiglu7 path ignores it. + out_dt_id = evt_runtime.out_dtype_id(out_dt) + n_out = N // 2 + with graph.inserting_after(last_chain_node): + new_node = graph.call_function( + torch.ops.magi_epilogue.matmul_custom_evt.default, + args=(mm_node.args[0], weight_node, [], "", "swiglu7_dual", n_out, out_dt_id), + ) + try: + val_last = last_chain_node.meta.get("val") + if val_last is not None: + new_val = val_last.new_empty_strided( + val_last.shape, (evt_runtime._aligned_n_stride(int(val_last.shape[-1]), val_last.dtype), 1) + ) + new_node.meta["val"] = new_val + except Exception: + pass + + last_chain_node.replace_all_uses_with(new_node) + for n in reversed(chain_nodes): + if len(n.users) == 0 and n is not new_node: + graph.erase_node(n) + # Erase mm and the t() node if no longer used. + if len(mm_node.users) == 0: + graph.erase_node(mm_node) + if isinstance(B_node, fx.Node) and len(B_node.users) == 0: + graph.erase_node(B_node) + return True diff --git a/magi_compiler/passes/piecewise_graph/fusion/cute_kernel.py b/magi_compiler/passes/piecewise_graph/fusion/cute_kernel.py deleted file mode 100644 index fe6e4a0..0000000 --- a/magi_compiler/passes/piecewise_graph/fusion/cute_kernel.py +++ /dev/null @@ -1,1080 +0,0 @@ -# Copyright (c) 2026 SandAI. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""CuTe DSL GEMM with fused in-kernel epilogue for Hopper (SM90+). - -Design ------- -The key insight is that WGMMA accumulates results into register files (``tRS_rD``). -Before those registers are written to shared/global memory, we can apply elementwise -epilogue operations (activation, bias-add, scale, …) *in-place on the register -values* — completely avoiding the extra read-back from global memory that a -separate Triton epilogue pass would require. - -Concretely, inside the CuTe kernel's epilogue loop: - - for epi_idx in range_constexpr(epi_tile_num): - for epi_v in range_constexpr(size_tRS_rD): - tRS_rD[epi_v] = tRS_rAcc[epi_idx * size_tRS_rD + epi_v] - - acc_vec = tRS_rD.load() # FP32 register tensor - # ── INJECT: fused epilogue ────────────────────────────────── - acc_vec = self._apply_epilogue(acc_vec) - # ──────────────────────────────────────────────────────────── - tRS_rD_out.store(acc_vec.to(self.c_dtype)) - ... - -``HopperWgmmaGemmEpilogueFusedKernel`` subclasses -``HopperWgmmaGemmPersistentKernel`` and overrides ``kernel()`` with this -single extra line, plus the mechanism to supply ``_apply_epilogue``. - -Epilogue representation ------------------------ -The epilogue is described by two complementary representations: - -1. **Triton epilogue string** (``epilogue_code``) — already generated by - ``MatmulCustomEpilogueFusionPass._try_fuse_custom_chain``. We *parse* this - string to drive the CuTe DSL code that runs inside the kernel. - -2. **CuTe DSL epilogue callable** (``epilogue_fn``) — a Python callable that - accepts a ``TensorSSA`` (FP32 accumulator tile) and returns a transformed - ``TensorSSA`` of the same shape. It is invoked at ``@cute.jit`` trace time - so it must only use CuTe DSL primitives (``cute.exp``, ``cute.tanh``, …). - -The ``_build_epilogue_fn`` factory converts the Triton epilogue string into a -CuTe DSL callable. It covers the same op set that ``triton_kernels.py`` -supports so all fused chains are handled correctly. - -Extras (bias tensors, etc.) ---------------------------- -The Triton string may reference ``Extra_0_ptr``, ``Extra_1_ptr``, … which are -additional (bias / scale) tensors. At CuTe DSL level these arrive as plain -FP16 1-D or 2-D GPU tensors; the epilogue builder injects loads via a small -helper that reads the correct row of the extra tensor for the current -``epi_idx`` subtile. - -Fallback --------- -On non-Hopper or when ``cutlass-dsl`` is unavailable the module falls back to -the pure-Triton path (``matmul_custom_epilogue`` from ``triton_kernels.py``). -""" - -import ast -import sys -from dataclasses import dataclass -from typing import Callable, List, Optional - -import torch - -from .triton_kernels import matmul_custom_epilogue - -# ── CuTe DSL availability ────────────────────────────────────────────────────── -_HAS_CUTLASS: bool = False -_IS_HOPPER: bool = False - -try: - _IS_HOPPER = torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 9 - if _IS_HOPPER: - _CUTLASS_HOPPER_DIR = "/root/cutlass/examples/python/CuTeDSL/hopper" - if _CUTLASS_HOPPER_DIR not in sys.path: - sys.path.insert(0, _CUTLASS_HOPPER_DIR) - import cuda.bindings.driver as cuda - import cutlass - import cutlass.cute as cute - import cutlass.torch as cutlass_torch - import cutlass.utils - from dense_gemm_persistent import HopperWgmmaGemmPersistentKernel - - _HAS_CUTLASS = True -except Exception: - pass - - -# ── Epilogue-string → CuTe DSL translator ───────────────────────────────────── - - -def _build_epilogue_fn( - epilogue_code: str, extras: list, reduce_n_by_2: bool # list of GPU torch.Tensor (bias, scale, …) -) -> Optional[Callable]: - """Parse the Triton epilogue code string and return a CuTe DSL callable. - - The returned function has signature:: - - fn(acc_vec: TensorSSA, epi_idx: int, epi_tile_m: int, epi_tile_n: int, - extra_cute_tensors: list) -> TensorSSA - - where ``acc_vec`` is the FP32 register tile (shape = (EPI_TILE_M, EPI_TILE_N) - or a flat vector, depending on how cute delivers it). - - Returns ``None`` if the code string cannot be translated (fall back to Triton). - - Supported Triton constructs → CuTe DSL mapping - ----------------------------------------------- - acc → acc_vec (float32 register tensor) - tl.exp(x) → cute.exp(x) - tl.exp2(x) → cute.exp2(x) - tl.log(x) → cute.log(x) - tl.log2(x) → cute.log2(x) - tl.sqrt(x) → cute.sqrt(x) - tl.tanh(x) → cute.tanh(x) - tl.math.erf(x) → cute.erf(x) - tl.sigmoid(x) → 1/(1+cute.exp(-x)) - tl.maximum(x, y) → cute.where(x > y, x, y) - tl.minimum(x, y) → cute.where(x < y, x, y) - tl.where(c, x, y) → cute.where(c, x, y) - tl.abs(x) → cute.where(x >= 0, x, -x) - Arithmetic (+,-,*,/) → native Python operators on TensorSSA - ext_0 / ext_1 / … → broadcast-loaded from extras list - - Limitation: tl.split / tl.reshape (SwiGLU) are NOT supported in-kernel; - ``reduce_n_by_2=True`` cases fall back to the Triton epilogue path. - """ - if reduce_n_by_2: - return None # SwiGLU split not representable as a simple register op - - # Strip the static-dims header before parsing - code_lines = [l for l in epilogue_code.splitlines() if not l.startswith("# @static:")] - code = "\n".join(code_lines).strip() - if not code or code == "acc = acc": - return None # no-op epilogue — skip - - try: - tree = ast.parse(code, mode="exec") - except SyntaxError: - return None - - # Quick scan: reject unsupported constructs before building the callable - for node in ast.walk(tree): - if isinstance(node, ast.Call): - fn_name = "" - if isinstance(node.func, ast.Attribute): - # e.g. tl.split, tl.reshape → not supported - fn_name = node.func.attr - elif isinstance(node.func, ast.Name): - fn_name = node.func.id - if fn_name in ("split", "reshape"): - return None - - # Build the executable epilogue function via exec() in the CuTe DSL - # namespace. We translate Triton names to their CuTe equivalents by - # injecting a thin shim object ``tl`` that redirects attribute accesses. - fn_src = _emit_cute_epilogue_fn(code_lines, len(extras)) - if fn_src is None: - return None - - ns: dict = {} - exec_globals = {"cute": cute, "cutlass": cutlass} - try: - exec(compile(fn_src, "", "exec"), exec_globals, ns) - except Exception: - return None - - fn = ns.get("_cute_epilogue_fn") - return fn - - -def _emit_cute_epilogue_fn(code_lines: List[str], num_extras: int) -> Optional[str]: - """Emit a Python function that applies the epilogue on a CuTe register tensor. - - The generated function signature is:: - - def _cute_epilogue_fn(acc_vec, extras): - # translated epilogue body - ... - return acc_vec # final result - - ``acc_vec`` is the FP32 ``TensorSSA`` loaded from ``tRS_rD``. - ``extras`` is a list of already-loaded FP32 ``TensorSSA`` slices for each - extra operand (one slice per epi_idx, already broadcast/sliced to the - correct tile). - - Translation rules (Triton → CuTe): - acc → acc_vec - tl.exp(x) → cute.exp(x) - tl.exp2(x) → cute.exp2(x) - tl.log(x) → cute.log(x) - tl.log2(x) → cute.log2(x) - tl.sqrt(x) → cute.sqrt(x) - tl.tanh(x) → cute.tanh(x) - tl.math.erf(x) → cute.erf(x) - tl.sigmoid(x) → 1.0/(1.0+cute.exp(-x)) (emitted inline) - tl.maximum(x,y)→ cute.where(x>y,x,y) - tl.minimum(x,y)→ cute.where(x=0,x,-x) - ext_N → extras[N] (pre-loaded slice) - loads of extra ptrs (ext_N_ptrs / tl.load) → skipped (pre-loaded) - """ - body_lines = [] - - for raw in code_lines: - line = raw.strip() - if not line or line.startswith("#"): - continue - - # Skip the "ext_N_ptrs = ..." and "ext_N = tl.load(...)" lines — - # we supply pre-loaded slices in ``extras`` directly. - if "_ptrs" in line and ("Extra_" in line or "ext_" in line): - continue - # Detect ext_N = tl.load(...) patterns → replace with extras[N] lookup - if line.startswith("ext_") and "= tl.load(" in line: - # e.g. ext_0 = tl.load(ext_0_ptrs, ...) - varname = line.split("=")[0].strip() # "ext_0" - try: - idx = int(varname.split("_")[1]) - except (IndexError, ValueError): - return None - body_lines.append(f" {varname} = extras[{idx}]") - continue - - # Translate the rest - translated = _translate_line(line) - if translated is None: - return None - body_lines.append(f" {translated}") - - # Ensure the function ends with `return acc_vec` - if not any("return" in l for l in body_lines): - body_lines.append(" return acc_vec") - - fn_src = "def _cute_epilogue_fn(acc_vec, extras):\n" - fn_src += "\n".join(body_lines) if body_lines else " pass\n" - fn_src += "\n return acc_vec\n" - return fn_src - - -# ── Line-level Triton → CuTe DSL translator ─────────────────────────────────── - -# Mapping of tl.* / tl.math.* function names to their CuTe equivalents -_TL_TO_CUTE: dict = { - "exp": "cute.exp", - "exp2": "cute.exp2", - "log": "cute.log", - "log2": "cute.log2", - "sqrt": "cute.sqrt", - "rsqrt": "cute.rsqrt", # via cutlass.cute.math - "tanh": "cute.tanh", - "sin": "cute.sin", - "cos": "cute.cos", - "abs": "__cute_abs__", # special-cased - "maximum": "__cute_max__", # special-cased - "minimum": "__cute_min__", # special-cased - "where": "cute.where", - # tl.math.* - "erf": "cute.erf", - "sign": "__cute_sign__", # special-cased -} - -_TL_PASSTHROUGH = frozenset(["maximum", "minimum", "where"]) - - -def _translate_line(line: str) -> Optional[str]: - """Translate a single Triton epilogue line to a CuTe DSL expression. - - Returns the translated line string, or None if untranslatable. - """ - # Replace 'acc' variable (bare or in expressions) with 'acc_vec' - # Use a simple text replacement — won't confuse 'acc' with 'accumulator' etc. - # because the epilogue code only uses 'acc'. - line = _replace_token(line, "acc", "acc_vec") - - # tl.math.erf(x) → cute.erf(x) - line = line.replace("tl.math.erf(", "cute.erf(") - line = line.replace("tl.math.erfc(", "__cute_erfc__(") - line = line.replace("tl.math.erfinv(", "__cute_erfinv__(") - line = line.replace("tl.math.sign(", "__cute_sign__(") - line = line.replace("tl.math.isnan(", "__cute_isnan__(") - line = line.replace("tl.math.isinf(", "__cute_isinf__(") - line = line.replace("tl.math.floor(", "__cute_floor__(") - line = line.replace("tl.math.ceil(", "__cute_ceil__(") - line = line.replace("tl.math.trunc(", "__cute_trunc__(") - line = line.replace("tl.math.round(", "__cute_round__(") - line = line.replace("tl.math.pow(", "__cute_pow__(") - line = line.replace("tl.math.tan(", "__cute_tan__(") - line = line.replace("tl.math.asin(", "__cute_asin__(") - line = line.replace("tl.math.acos(", "__cute_acos__(") - line = line.replace("tl.math.atan(", "__cute_atan__(") - line = line.replace("tl.math.atan2(", "__cute_atan2__(") - line = line.replace("tl.math.sinh(", "__cute_sinh__(") - line = line.replace("tl.math.cosh(", "__cute_cosh__(") - - # tl.abs(x) → cute.where(x >= 0, x, -x) [no native cute.abs] - line = line.replace("tl.abs(", "__cute_abs__(") - - # tl.sigmoid(x) → (1.0/(1.0+cute.exp(-x))) - line = line.replace("tl.sigmoid(", "__cute_sigmoid__(") - - # tl.maximum / tl.minimum / tl.where → cute.where-based - line = line.replace("tl.maximum(", "__cute_max__(") - line = line.replace("tl.minimum(", "__cute_min__(") - line = line.replace("tl.where(", "cute.where(") - - # Standard tl.* math functions - for tl_name, cute_name in _TL_TO_CUTE.items(): - if cute_name.startswith("cute."): - line = line.replace(f"tl.{tl_name}(", f"{cute_name}(") - - # Reject any remaining tl.* calls (unsupported) - if "tl." in line: - return None - - # Expand the __cute_*__ shims inline (simple single-argument forms) - line = _expand_shims(line) - - return line - - -def _replace_token(s: str, old: str, new: str) -> str: - """Replace whole-token occurrences of ``old`` with ``new``.""" - import re - - return re.sub(r'\b' + re.escape(old) + r'\b', new, s) - - -def _expand_shims(line: str) -> str: - """Expand __cute_*__ shims to full CuTe DSL expressions. - - For single-argument shims this is straightforward string replacement. - For multi-argument (max/min) we can't easily parse here, so we emit - helper calls that are defined in the exec namespace. - """ - # These shims are injected into the exec namespace instead - # so no string expansion is needed at this stage — just keep them. - return line - - -def _make_exec_globals() -> dict: - """Build the exec namespace with CuTe DSL helpers for all shims.""" - if not _HAS_CUTLASS: - return {} - - def _cute_abs(x): - zero = cute.full_like(x, 0) - return cute.where(x >= zero, x, -x) - - def _cute_max(x, y): - if isinstance(y, (int, float)): - y = cute.full_like(x, float(y)) - return cute.where(x > y, x, y) - - def _cute_min(x, y): - if isinstance(y, (int, float)): - y = cute.full_like(x, float(y)) - return cute.where(x < y, x, y) - - def _cute_sigmoid(x): - one = cute.full_like(x, 1.0) - return one / (one + cute.exp(-x)) - - def _cute_sign(x): - zero = cute.full_like(x, 0.0) - one = cute.full_like(x, 1.0) - return cute.where(x > zero, one, cute.where(x < zero, -one, zero)) - - def _cute_pow(x, y): - return cute.exp(y * cute.log(x)) - - def _cute_erfc(x): - one = cute.full_like(x, 1.0) - return one - cute.erf(x) - - # Approximate inverse erf (not in CuTe math) - def _cute_erfinv(x): - # Halley approximation — good enough for epilogues - a = cute.full_like(x, 0.147) - pi_a = cute.full_like(x, 2.0 / (3.14159265358979 * 0.147)) - ln_term = cute.log(cute.full_like(x, 1.0) - x * x) - t = cute.sqrt( - cute.sqrt((pi_a + ln_term / cute.full_like(x, 2.0)) ** cute.full_like(x, 2.0) - ln_term / a) - - (pi_a + ln_term / cute.full_like(x, 2.0)) - ) - return cute.where(x >= cute.full_like(x, 0.0), t, -t) - - def _cute_isnan(x): - return x != x - - def _cute_isinf(x): - return cute.where(x != x, cute.full_like(x, 0.0), cute.full_like(x, 1.0)) != cute.full_like(x, 1.0) # placeholder - - def _cute_floor(x): - return cute.exp(cute.full_like(x, 0.0)) * x # placeholder — not in cute.math - - def _cute_ceil(x): - return x - - def _cute_trunc(x): - return x - - def _cute_round(x): - return x - - def _cute_tan(x): - return cute.sin(x) / cute.cos(x) - - def _cute_asin(x): - return cute.math.asin(x) - - def _cute_acos(x): - return cute.math.acos(x) - - def _cute_atan(x): - return cute.math.atan(x) - - def _cute_atan2(x, y): - return cute.math.atan2(x, y) - - def _cute_sinh(x): - ex = cute.exp(x) - return (ex - cute.full_like(x, 1.0) / ex) / cute.full_like(x, 2.0) - - def _cute_cosh(x): - ex = cute.exp(x) - return (ex + cute.full_like(x, 1.0) / ex) / cute.full_like(x, 2.0) - - return { - "cute": cute, - "cutlass": cutlass, - "__cute_abs__": _cute_abs, - "__cute_max__": _cute_max, - "__cute_min__": _cute_min, - "__cute_sigmoid__": _cute_sigmoid, - "__cute_sign__": _cute_sign, - "__cute_pow__": _cute_pow, - "__cute_erfc__": _cute_erfc, - "__cute_erfinv__": _cute_erfinv, - "__cute_isnan__": _cute_isnan, - "__cute_isinf__": _cute_isinf, - "__cute_floor__": _cute_floor, - "__cute_ceil__": _cute_ceil, - "__cute_trunc__": _cute_trunc, - "__cute_round__": _cute_round, - "__cute_tan__": _cute_tan, - "__cute_asin__": _cute_asin, - "__cute_acos__": _cute_acos, - "__cute_atan__": _cute_atan, - "__cute_atan2__": _cute_atan2, - "__cute_sinh__": _cute_sinh, - "__cute_cosh__": _cute_cosh, - } - - -def _compile_epilogue_fn(epilogue_code: str, num_extras: int, reduce_n_by_2: bool) -> Optional[Callable]: - """Compile the epilogue string into a CuTe DSL Python callable. - - Returns None if the epilogue cannot be represented (→ fallback to Triton). - """ - if reduce_n_by_2: - return None - - code_lines = [l for l in epilogue_code.splitlines() if not l.startswith("# @static:")] - code_lines = [l for l in code_lines if l.strip()] - - # Detect extra pointer load patterns and skip them (we inject extras directly) - filtered = [] - for l in code_lines: - stripped = l.strip() - # Skip "ext_N_ptrs = Extra_N_ptr + ..." lines - if "Extra_" in stripped and "_ptrs" in stripped: - continue - # Replace "ext_N = tl.load(ext_N_ptrs, ...)" with "ext_N = extras[N]" - if stripped.startswith("ext_") and "= tl.load(" in stripped: - varname = stripped.split("=")[0].strip() - try: - idx = int(varname.split("_")[1]) - filtered.append(f" {varname} = extras[{idx}]") - except (IndexError, ValueError): - return None - continue - # Translate the line - translated = _translate_line(stripped) - if translated is None: - return None - filtered.append(f" {translated}") - - if not filtered: - return None - - fn_src = "def _cute_epilogue_fn(acc_vec, extras):\n" - fn_src += "\n".join(filtered) - fn_src += "\n return acc_vec\n" - - exec_globals = _make_exec_globals() - ns: dict = {} - try: - exec(compile(fn_src, "", "exec"), exec_globals, ns) - except Exception: - return None - - return ns.get("_cute_epilogue_fn") - - -# ── In-kernel fused GEMM subclass ───────────────────────────────────────────── - -if _HAS_CUTLASS: - - class HopperWgmmaGemmEpilogueFusedKernel(HopperWgmmaGemmPersistentKernel): - """Hopper GEMM with epilogue fused into the accumulator register phase. - - The epilogue is applied on the FP32 accumulator register tensor - *before* it is converted to FP16 and stored, eliminating the extra - global-memory round-trip that a separate Triton epilogue kernel would need. - - Parameters - ---------- - epilogue_fn : callable or None - A CuTe DSL Python function ``fn(acc_vec, extras) -> TensorSSA``. - Compiled from the fusion-pass epilogue string by ``_compile_epilogue_fn``. - When *None*, the behaviour is identical to the base class. - extra_cute_tensors : list[cute.Tensor] - Pre-sliced CuTe tensors for bias / scale operands. One per extra - referenced by the epilogue. Passed through to ``epilogue_fn``. - All other args forwarded to ``HopperWgmmaGemmPersistentKernel.__init__``. - """ - - def __init__( - self, - acc_dtype, - tile_shape_mn, - cluster_shape_mn, - swizzle_size=1, - raster_along_m=True, - epilogue_fn=None, - extra_cute_tensors=None, - ): - super().__init__(acc_dtype, tile_shape_mn, cluster_shape_mn, swizzle_size, raster_along_m) - self._epilogue_fn = epilogue_fn - self._extra_cute_tensors = extra_cute_tensors or [] - - def _apply_epilogue(self, acc_vec): - """Apply the user-supplied epilogue to the FP32 accumulator tile.""" - if self._epilogue_fn is None: - return acc_vec - return self._epilogue_fn(acc_vec, self._extra_cute_tensors) - - # ── Override the GPU kernel to inject the epilogue ───────────────────── - @cute.kernel - def kernel( - self, - tma_atom_a, - mA_mkl, - tma_atom_b, - mB_nkl, - tma_atom_c, - mC_mnl, - tiled_mma, - cta_layout_mnk, - a_smem_layout_staged, - b_smem_layout_staged, - epi_smem_layout_staged, - tile_sched_params, - ): - # ── verbatim copy of the base class kernel body ──────────────────── - # with a single change: acc_vec is passed through _apply_epilogue - # before being stored. - tidx, _, _ = cute.arch.thread_idx() - warp_idx = cute.arch.warp_idx() - warp_idx = cute.arch.make_warp_uniform(warp_idx) - - if warp_idx == 0: - cute.nvgpu.cpasync.prefetch_descriptor(tma_atom_a) - cute.nvgpu.cpasync.prefetch_descriptor(tma_atom_b) - cute.nvgpu.cpasync.prefetch_descriptor(tma_atom_c) - - cta_rank_in_cluster = cute.arch.make_warp_uniform(cute.arch.block_idx_in_cluster()) - cluster_coord_mnk = cta_layout_mnk.get_flat_coord(cta_rank_in_cluster) - - a_mcast_mask = cute.make_layout_image_mask(cta_layout_mnk, cluster_coord_mnk, mode=1) - b_mcast_mask = cute.make_layout_image_mask(cta_layout_mnk, cluster_coord_mnk, mode=0) - - a_mcast_mask = a_mcast_mask if self.is_a_mcast else 0 - b_mcast_mask = b_mcast_mask if self.is_b_mcast else 0 - a_smem_layout = cute.slice_(a_smem_layout_staged, (None, None, 0)) - b_smem_layout = cute.slice_(b_smem_layout_staged, (None, None, 0)) - tma_copy_bytes = cute.size_in_bytes(self.a_dtype, a_smem_layout) + cute.size_in_bytes(self.b_dtype, b_smem_layout) - - import cutlass.pipeline as pipeline - import cutlass.utils as utils_mod - from cutlass.pipeline import pipeline_init_arrive, pipeline_init_wait - - smem = utils_mod.SmemAllocator() - storage = smem.allocate(self.shared_storage) - - mainloop_pipeline_array_ptr = storage.mainloop_pipeline_array_ptr.data_ptr() - mainloop_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread) - mcast_size = self.num_mcast_ctas_a + self.num_mcast_ctas_b - 1 - consumer_arrive_cnt = mcast_size * self.num_mma_warp_groups * self.num_warps_per_warp_group - mainloop_pipeline_consumer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread, consumer_arrive_cnt) - mainloop_pipeline = pipeline.PipelineTmaAsync.create( - barrier_storage=mainloop_pipeline_array_ptr, - num_stages=self.ab_stage, - producer_group=mainloop_pipeline_producer_group, - consumer_group=mainloop_pipeline_consumer_group, - tx_count=tma_copy_bytes, - cta_layout_vmnk=cute.make_layout((1, *cta_layout_mnk.shape)), - defer_sync=True, - ) - - pipeline_init_arrive(cluster_shape_mn=self.cluster_shape_mn, is_relaxed=True) - - sA = storage.sA.get_tensor(a_smem_layout_staged.outer, swizzle=a_smem_layout_staged.inner) - sB = storage.sB.get_tensor(b_smem_layout_staged.outer, swizzle=b_smem_layout_staged.inner) - sC = storage.sC.get_tensor(epi_smem_layout_staged.outer, swizzle=epi_smem_layout_staged.inner) - - gA_mkl = cute.local_tile(mA_mkl, cute.slice_(self.tile_shape_mnk, (None, 0, None)), (None, None, None)) - gB_nkl = cute.local_tile(mB_nkl, cute.slice_(self.tile_shape_mnk, (0, None, None)), (None, None, None)) - gC_mnl = cute.local_tile(mC_mnl, cute.slice_(self.tile_shape_mnk, (None, None, 0)), (None, None, None)) - - a_cta_layout = cute.make_layout(cute.slice_(cta_layout_mnk, (0, None, 0)).shape) - a_cta_crd = cluster_coord_mnk[1] - tAsA, tAgA = cute.nvgpu.cpasync.tma_partition( - tma_atom_a, a_cta_crd, a_cta_layout, cute.group_modes(sA, 0, 2), cute.group_modes(gA_mkl, 0, 2) - ) - - b_cta_layout = cute.make_layout(cute.slice_(cta_layout_mnk, (None, 0, 0)).shape) - b_cta_crd = cluster_coord_mnk[0] - tBsB, tBgB = cute.nvgpu.cpasync.tma_partition( - tma_atom_b, b_cta_crd, b_cta_layout, cute.group_modes(sB, 0, 2), cute.group_modes(gB_nkl, 0, 2) - ) - - warp_group_idx = cute.arch.make_warp_uniform(tidx // self.num_threads_per_warp_group) - mma_warp_group_thread_layout = cute.make_layout(self.num_mma_warp_groups, stride=self.num_threads_per_warp_group) - thr_mma = tiled_mma.get_slice(mma_warp_group_thread_layout(warp_group_idx - self.num_dma_warp_groups)) - - tCsA = thr_mma.partition_A(sA) - tCsB = thr_mma.partition_B(sB) - tCrA = tiled_mma.make_fragment_A(tCsA) - tCrB = tiled_mma.make_fragment_B(tCsB) - - tCgC = thr_mma.partition_C(gC_mnl) - acc_shape = tCgC.shape[:3] - accumulators = cute.make_rmem_tensor(acc_shape, self.acc_dtype) - - k_tile_cnt = cute.size(gA_mkl, mode=[3]) - - pipeline_init_wait(cluster_shape_mn=self.cluster_shape_mn) - - is_dma_warp_group = warp_group_idx < self.num_dma_warp_groups - if is_dma_warp_group: - cute.arch.setmaxregister_decrease(self.load_register_requirement) - - # ── DMA warp group ───────────────────────────────────────────────── - if warp_idx == self.load_warp_id: - tile_sched = utils_mod.StaticPersistentTileScheduler.create( - tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() - ) - work_tile = tile_sched.initial_work_tile_info() - mainloop_producer_state = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.ab_stage) - - while work_tile.is_valid_tile: - tile_coord_mnl = work_tile.tile_idx - tAgA_mkl = tAgA[(None, tile_coord_mnl[0], None, tile_coord_mnl[2])] - tBgB_nkl = tBgB[(None, tile_coord_mnl[1], None, tile_coord_mnl[2])] - mainloop_producer_state.reset_count() - - for k_tile in range(k_tile_cnt): - mainloop_pipeline.producer_acquire(mainloop_producer_state) - tAgA_k = tAgA_mkl[(None, mainloop_producer_state.count)] - tAsA_pipe = tAsA[(None, mainloop_producer_state.index)] - tBgB_k = tBgB_nkl[(None, mainloop_producer_state.count)] - tBsB_pipe = tBsB[(None, mainloop_producer_state.index)] - - cute.copy( - tma_atom_a, - tAgA_k, - tAsA_pipe, - tma_bar_ptr=mainloop_pipeline.producer_get_barrier(mainloop_producer_state), - mcast_mask=a_mcast_mask, - ) - cute.copy( - tma_atom_b, - tBgB_k, - tBsB_pipe, - tma_bar_ptr=mainloop_pipeline.producer_get_barrier(mainloop_producer_state), - mcast_mask=b_mcast_mask, - ) - mainloop_pipeline.producer_commit(mainloop_producer_state) - mainloop_producer_state.advance() - - tile_sched.advance_to_next_work() - work_tile = tile_sched.get_current_work() - - mainloop_pipeline.producer_tail(mainloop_producer_state) - - # ── MMA warp group ───────────────────────────────────────────────── - if not is_dma_warp_group: - cute.arch.setmaxregister_increase(self.mma_register_requirement) - tile_sched = utils_mod.StaticPersistentTileScheduler.create( - tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() - ) - work_tile = tile_sched.initial_work_tile_info() - - mainloop_consumer_read_state = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.ab_stage) - mainloop_consumer_release_state = pipeline.make_pipeline_state( - pipeline.PipelineUserType.Consumer, self.ab_stage - ) - - num_k_blocks = cute.size(tCrA, mode=[2]) - - import cutlass.utils.hopper_helpers as sm90_utils - - copy_atom_r2s = sm90_utils.sm90_get_smem_store_op( - self.c_layout, elem_ty_d=self.c_dtype, elem_ty_acc=self.acc_dtype - ) - - copy_atom_C = cute.make_copy_atom( - cute.nvgpu.warp.StMatrix8x8x16bOp(self.c_layout.is_m_major_c(), 4), self.c_dtype - ) - tiled_copy_C_Atom = cute.make_tiled_copy_C_atom(copy_atom_C, tiled_mma) - tiled_copy_r2s = cute.make_tiled_copy_S(copy_atom_r2s, tiled_copy_C_Atom) - - thr_copy_r2s = tiled_copy_r2s.get_slice(tidx - self.num_dma_warp_groups * self.num_threads_per_warp_group) - tRS_sD = thr_copy_r2s.partition_D(sC) - tRS_rAcc = tiled_copy_r2s.retile(accumulators) - - rD_shape = cute.shape(thr_copy_r2s.partition_S(sC)) - tRS_rD_layout = cute.make_layout(rD_shape[:3]) - tRS_rD = cute.make_rmem_tensor(tRS_rD_layout.shape, self.acc_dtype) - tRS_rD_out = cute.make_rmem_tensor(tRS_rD_layout.shape, self.c_dtype) - size_tRS_rD = cute.size(tRS_rD) - - k_pipe_mmas = 1 - prologue_mma_cnt = min(k_pipe_mmas, k_tile_cnt) - - tma_store_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread, self.num_mma_threads) - tma_store_pipeline = pipeline.PipelineTmaStore.create( - num_stages=self.epi_stage, producer_group=tma_store_producer_group - ) - - while work_tile.is_valid_tile: - tile_coord_mnl = work_tile.tile_idx - gC_mnl_slice = gC_mnl[(None, None, *tile_coord_mnl)] - - mainloop_consumer_read_state.reset_count() - mainloop_consumer_release_state.reset_count() - accumulators.fill(0.0) - tiled_mma.set(cute.nvgpu.warpgroup.Field.ACCUMULATE, True) - cute.nvgpu.warpgroup.fence() - - for k_tile in range(prologue_mma_cnt): - mainloop_pipeline.consumer_wait(mainloop_consumer_read_state) - for k_block_idx in cutlass.range_constexpr(num_k_blocks): - k_block_coord = (None, None, k_block_idx, mainloop_consumer_read_state.index) - cute.gemm(tiled_mma, accumulators, tCrA[k_block_coord], tCrB[k_block_coord], accumulators) - cute.nvgpu.warpgroup.commit_group() - mainloop_consumer_read_state.advance() - - for k_tile in range(prologue_mma_cnt, k_tile_cnt): - mainloop_pipeline.consumer_wait(mainloop_consumer_read_state) - for k_block_idx in cutlass.range_constexpr(num_k_blocks): - k_block_coord = (None, None, k_block_idx, mainloop_consumer_read_state.index) - cute.gemm(tiled_mma, accumulators, tCrA[k_block_coord], tCrB[k_block_coord], accumulators) - cute.nvgpu.warpgroup.commit_group() - cute.nvgpu.warpgroup.wait_group(k_pipe_mmas) - mainloop_pipeline.consumer_release(mainloop_consumer_release_state) - mainloop_consumer_release_state.advance() - mainloop_consumer_read_state.advance() - - cute.nvgpu.warpgroup.wait_group(0) - for k_tile in range(prologue_mma_cnt): - mainloop_pipeline.consumer_release(mainloop_consumer_release_state) - mainloop_consumer_release_state.advance() - - # Epilogue - tCgC_for_tma_partition = cute.zipped_divide(gC_mnl_slice, self.epi_tile) - bSG_sD, bSG_gD = cute.nvgpu.cpasync.tma_partition( - tma_atom_c, 0, cute.make_layout(1), cute.group_modes(sC, 0, 2), tCgC_for_tma_partition - ) - epi_tile_num = cute.size(tCgC_for_tma_partition, mode=[1]) - epi_tile_shape = tCgC_for_tma_partition.shape[1] - epi_tile_layout = cute.make_layout(epi_tile_shape, stride=(epi_tile_shape[1], 1)) - num_prev_epi_tiles = tile_sched.num_tiles_executed * epi_tile_num - - for epi_idx in cutlass.range_constexpr(epi_tile_num): - for epi_v in cutlass.range_constexpr(size_tRS_rD): - tRS_rD[epi_v] = tRS_rAcc[epi_idx * size_tRS_rD + epi_v] - - # ── Load FP32 accumulator tile ───────────────────────── - acc_vec = tRS_rD.load() - - # ── FUSED EPILOGUE: apply in registers ───────────────── - acc_vec = self._apply_epilogue(acc_vec) - - # ── Convert to output dtype and store ────────────────── - tRS_rD_out.store(acc_vec.to(self.c_dtype)) - - epi_buffer = (num_prev_epi_tiles + epi_idx) % cute.size(tRS_sD, mode=[3]) - cute.copy(tiled_copy_r2s, tRS_rD_out, tRS_sD[(None, None, None, epi_buffer)]) - cute.arch.fence_proxy("async.shared", space="cta") - self.epilog_sync_barrier.arrive_and_wait() - - gmem_coord = epi_tile_layout.get_hier_coord(epi_idx) - if warp_idx == self.epi_store_warp_id: - cute.copy(tma_atom_c, bSG_sD[(None, epi_buffer)], bSG_gD[(None, gmem_coord)]) - tma_store_pipeline.producer_commit() - tma_store_pipeline.producer_acquire() - - self.epilog_sync_barrier.arrive_and_wait() - - tile_sched.advance_to_next_work() - work_tile = tile_sched.get_current_work() - - tma_store_pipeline.producer_tail() - - -# ── Two-level fused GEMM cache ───────────────────────────────────────────────── -# -# Shape-polymorphism strategy -# --------------------------- -# ``cute.compile()`` with ``is_dynamic_layout=True`` produces a kernel binary -# that is polymorphic in the M dimension: a kernel compiled for template M=128 -# can be called at runtime for any M (verified experimentally). N and K are -# typically static (weight-matrix dimensions) while M = batch×seq_len varies. -# -# We therefore split the cache into two levels: -# -# _COMPILED_CACHE key: (N, K, epilogue_code, num_extras, reduce_n_by_2) -# value: _CompiledEntry (compiled_gemm) -# → populated once, reused for every new M -# -# _BUFFER_CACHE key: (M, N, K) -# value: _BufferEntry (a/b/c aligned device buffers + CuTe -# descriptors for the specific M) -# → populated once per unique M, much cheaper than recompile -# -# This ensures ``cute.compile()`` is called at most once per (N,K,...) config -# regardless of how many distinct M values appear at runtime. - - -@dataclass -class _CompiledEntry: - """Compiled CuTe kernel — shape-polymorphic in the M dimension.""" - - compiled_gemm: object # result of cute.compile(...) - max_active_clusters: int # baked at compile time (HW-dependent constant) - - -@dataclass -class _BufferEntry: - """Aligned device buffers and CuTe descriptors for a specific (M, N, K).""" - - a_cute: object - a_ref: torch.Tensor # (M, K, 1) — input A - b_cute: object - b_ref: torch.Tensor # (N, K, 1) — input B (transposed) - c_cute: object - c_ref: torch.Tensor # (M, N, 1) — output C - - -_COMPILED_CACHE: dict = {} # (N, K, epi_code, num_extras, reduce_n) → _CompiledEntry | None -_BUFFER_CACHE: dict = {} # (M, N, K) → _BufferEntry - -_TILE_MN = (128, 256) -_CLUSTER_MN = (1, 1) -# Template M used for cute.compile(); the compiled kernel runs for any M. -_TEMPLATE_M = 128 - - -def _compile_kernel(N: int, K: int, epilogue_fn, extra_cute_tensors: list) -> Optional[_CompiledEntry]: - """Compile the fused GEMM kernel for fixed (N, K); polymorphic in M. - - Uses ``_TEMPLATE_M`` as a placeholder M during compilation — the resulting - binary runs correctly for any M because ``is_dynamic_layout=True`` keeps - M out of any ``Constexpr`` baked values. - - Returns None on any compilation failure. - """ - if not _HAS_CUTLASS: - return None - if K % 8 != 0 or N % 8 != 0: - return None - - M = _TEMPLATE_M - l = 1 - a_dtype = cutlass.Float16 - b_dtype = cutlass.Float16 - c_dtype = cutlass.Float16 - acc_dtype = cutlass.Float32 - - a_cpu = cutlass_torch.matrix(l, M, K, False, a_dtype) - b_cpu = cutlass_torch.matrix(l, N, K, False, b_dtype) - c_cpu = cutlass_torch.matrix(l, M, N, False, c_dtype) - - a_cute, _ = cutlass_torch.cute_tensor_like(a_cpu, a_dtype, is_dynamic_layout=True, assumed_align=16) - b_cute, _ = cutlass_torch.cute_tensor_like(b_cpu, b_dtype, is_dynamic_layout=True, assumed_align=16) - c_cute, _ = cutlass_torch.cute_tensor_like(c_cpu, c_dtype, is_dynamic_layout=True, assumed_align=16) - - gemm = HopperWgmmaGemmEpilogueFusedKernel( - acc_dtype, - _TILE_MN, - _CLUSTER_MN, - swizzle_size=1, - raster_along_m=True, - epilogue_fn=epilogue_fn, - extra_cute_tensors=extra_cute_tensors, - ) - - hw = cutlass.utils.HardwareInfo() - mac = hw.get_max_active_clusters(_CLUSTER_MN[0] * _CLUSTER_MN[1]) - cu_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) - - try: - compiled_gemm = cute.compile(gemm, a_cute, b_cute, c_cute, mac, cu_stream) - except Exception: - return None - - return _CompiledEntry(compiled_gemm=compiled_gemm, max_active_clusters=mac) - - -def _get_or_create_buffers(M: int, N: int, K: int) -> Optional[_BufferEntry]: - """Return pre-allocated aligned CuTe buffers for the given (M, N, K). - - Allocates once per unique (M, N, K) and caches the result. Allocation is - much cheaper than ``cute.compile()`` but still non-trivial (GPU malloc + - CuTe descriptor creation), so caching across calls with the same shape is - important for training loops where M is fixed per microbatch. - """ - buf_key = (M, N, K) - if buf_key in _BUFFER_CACHE: - return _BUFFER_CACHE[buf_key] - - if not _HAS_CUTLASS: - return None - - l = 1 - a_dtype = cutlass.Float16 - b_dtype = cutlass.Float16 - c_dtype = cutlass.Float16 - - a_cpu = cutlass_torch.matrix(l, M, K, False, a_dtype) - b_cpu = cutlass_torch.matrix(l, N, K, False, b_dtype) - c_cpu = cutlass_torch.matrix(l, M, N, False, c_dtype) - - try: - a_cute, a_ref = cutlass_torch.cute_tensor_like(a_cpu, a_dtype, is_dynamic_layout=True, assumed_align=16) - b_cute, b_ref = cutlass_torch.cute_tensor_like(b_cpu, b_dtype, is_dynamic_layout=True, assumed_align=16) - c_cute, c_ref = cutlass_torch.cute_tensor_like(c_cpu, c_dtype, is_dynamic_layout=True, assumed_align=16) - except Exception: - _BUFFER_CACHE[buf_key] = None - return None - - entry = _BufferEntry(a_cute=a_cute, a_ref=a_ref, b_cute=b_cute, b_ref=b_ref, c_cute=c_cute, c_ref=c_ref) - _BUFFER_CACHE[buf_key] = entry - return entry - - -def _compiled_cache_key(N, K, epilogue_code, num_extras, reduce_n_by_2): - """Cache key for the compiled kernel — M-independent.""" - return (N, K, epilogue_code, num_extras, reduce_n_by_2) - - -# ── Public API ───────────────────────────────────────────────────────────────── - - -def matmul_cute_custom_epilogue( - A: torch.Tensor, B: torch.Tensor, extras: list, epilogue_code: str, reduce_n_by_2: bool -) -> torch.Tensor: - """Run GEMM + epilogue fully fused in the CuTe Hopper kernel. - - The epilogue is applied on the FP32 accumulator register file *before* - type conversion and TMA store, saving one full read of the (M×N) result - from global memory compared to a separate Triton epilogue pass. - - Shape-polymorphic caching - ------------------------- - ``cute.compile()`` is called **at most once** per unique (N, K, epilogue) - configuration regardless of how many distinct M values appear at runtime. - For a typical transformer, N and K are static weight-matrix dimensions - while M = batch×seq_len varies freely; this strategy ensures the expensive - JIT compilation cost is paid only once per layer, not per step. - - At FX graph level, static dims satisfy ``type(d) is int`` on - ``node.meta["val"].shape``; dynamic dims are ``torch.SymInt``. This - function exploits that structure automatically via the two-level cache. - - Falls back to ``matmul_custom_epilogue`` (Triton TMA-persistent) when: - - Not running on Hopper (SM < 90), or - - ``cutlass-dsl`` is not installed, or - - The epilogue contains constructs not representable as CuTe register ops - (e.g. SwiGLU ``tl.split``), or - - The problem dimensions violate 16-byte alignment requirements. - - Parameters - ---------- - A : torch.Tensor — (M, K) FP16 row-major - B : torch.Tensor — (K, N) FP16 row-major - extras : list[torch.Tensor] - Additional bias / scale tensors referenced by the epilogue. - epilogue_code : str - Triton epilogue snippet from the fusion pass. - reduce_n_by_2 : bool - True for SwiGLU (output N = input N / 2). - """ - M, K = A.shape - _, N = B.shape - - if not _HAS_CUTLASS: - return matmul_custom_epilogue(A, B, extras, epilogue_code, reduce_n_by_2) - - # ── Level-1: compiled kernel lookup (expensive; M-independent) ──────────── - compile_key = _compiled_cache_key(N, K, epilogue_code, len(extras), reduce_n_by_2) - - if compile_key not in _COMPILED_CACHE: - epi_fn = _compile_epilogue_fn(epilogue_code, len(extras), reduce_n_by_2) - - if epi_fn is None: - _COMPILED_CACHE[compile_key] = None - else: - extra_cute = [] - for t in extras: - try: - from cutlass.cute.runtime import from_dlpack - - extra_cute.append(from_dlpack(t, assumed_align=16)) - except Exception: - extra_cute = None - break - - if extra_cute is None: - _COMPILED_CACHE[compile_key] = None - else: - compiled_entry = _compile_kernel(N, K, epi_fn, extra_cute) - _COMPILED_CACHE[compile_key] = compiled_entry # None on failure - - compiled_entry = _COMPILED_CACHE.get(compile_key) - if compiled_entry is None: - return matmul_custom_epilogue(A, B, extras, epilogue_code, reduce_n_by_2) - - # ── Level-2: buffer lookup (cheap; once per unique M) ───────────────────── - buf = _get_or_create_buffers(M, N, K) - if buf is None: - return matmul_custom_epilogue(A, B, extras, epilogue_code, reduce_n_by_2) - - # ── Copy input data into aligned CuTe buffers ────────────────────────────── - buf.a_ref.copy_(A.unsqueeze(2)) - buf.b_ref.copy_(B.T.contiguous().unsqueeze(2)) - - # ── Run the fused CuTe kernel ────────────────────────────────────────────── - cu_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) - compiled_entry.compiled_gemm(buf.a_cute, buf.b_cute, buf.c_cute, cu_stream) - - # ── Extract result ───────────────────────────────────────────────────────── - N_out = N // 2 if reduce_n_by_2 else N - elem_size = A.element_size() - align_elems = 128 // elem_size - N_stride = (N_out + align_elems - 1) // align_elems * align_elems - D = torch.empty((M, N_stride), device=A.device, dtype=A.dtype)[:, :N_out] - - # c_ref layout is (M, N, 1); the kernel writes into it via TMA store - D.copy_(buf.c_ref[:, :N_out, 0]) - return D diff --git a/magi_compiler/passes/piecewise_graph/fusion/matmul_epilogue_fusion.py b/magi_compiler/passes/piecewise_graph/fusion/matmul_epilogue_fusion.py deleted file mode 100644 index e7c4704..0000000 --- a/magi_compiler/passes/piecewise_graph/fusion/matmul_epilogue_fusion.py +++ /dev/null @@ -1,482 +0,0 @@ -# Copyright (c) 2026 SandAI. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import json -import operator - -import torch -import torch.fx as fx -from torch.fx.node import Node - -from magi_compiler.passes.pass_base import MagiInductorPass - -from .cute_kernel import _HAS_CUTLASS, matmul_cute_custom_epilogue -from .triton_kernels import matmul_custom_epilogue - -_LIB = torch.library.Library("magi_epilogue", "DEF") -_LIB.define("matmul_custom(Tensor A, Tensor B, Tensor[] extras, str epilogue_code, bool reduce_n_by_2) -> Tensor") -_LIB.define("matmul_custom_cute(Tensor A, Tensor B, Tensor[] extras, str epilogue_code, bool reduce_n_by_2) -> Tensor") - - -@torch.library.impl(_LIB, "matmul_custom", "CUDA") -def _matmul_custom_cuda(A, B, extras, epilogue_code, reduce_n_by_2): - return matmul_custom_epilogue(A, B, extras, epilogue_code, reduce_n_by_2) - - -@torch.library.impl(_LIB, "matmul_custom_cute", "CUDA") -def _matmul_custom_cute_cuda(A, B, extras, epilogue_code, reduce_n_by_2): - return matmul_cute_custom_epilogue(A, B, extras, epilogue_code, reduce_n_by_2) - - -def _matmul_abstract_shape(A, B, reduce_n_by_2): - """Shared shape + stride logic for both torch.library fake impls.""" - N_out = B.shape[1] // 2 if reduce_n_by_2 else B.shape[1] - # Mirror the 128-byte-aligned row stride used by the real kernel so that - # Inductor's assert_size_stride matches what we actually return. - align_elems = 128 // A.element_size() - N_stride = (N_out + align_elems - 1) // align_elems * align_elems - return A.new_empty_strided((A.shape[0], N_out), (N_stride, 1)) - - -@torch.library.register_fake("magi_epilogue::matmul_custom") -def _matmul_custom_abstract(A, B, extras, epilogue_code, reduce_n_by_2): - return _matmul_abstract_shape(A, B, reduce_n_by_2) - - -@torch.library.register_fake("magi_epilogue::matmul_custom_cute") -def _matmul_custom_cute_abstract(A, B, extras, epilogue_code, reduce_n_by_2): - return _matmul_abstract_shape(A, B, reduce_n_by_2) - - -# ── Triton expression templates ──────────────────────────────────────────────── -# Unary elementwise ops: {x} = operand expression string -_UNARY_EXPRS = { - # Arithmetic - torch.ops.aten.neg.default: "-({x})", - torch.ops.aten.abs.default: "tl.abs({x})", - torch.ops.aten.sign.default: "tl.math.sign({x})", - torch.ops.aten.reciprocal.default: "1.0 / ({x})", - torch.ops.aten.square.default: "({x}) * ({x})", - # Exponential / logarithm - torch.ops.aten.exp.default: "tl.exp({x})", - torch.ops.aten.exp2.default: "tl.exp2({x})", - torch.ops.aten.expm1.default: "tl.exp({x}) - 1.0", - torch.ops.aten.log.default: "tl.log({x})", - torch.ops.aten.log2.default: "tl.log2({x})", - torch.ops.aten.log10.default: "tl.log({x}) * 0.4342944819032518", - torch.ops.aten.log1p.default: "tl.log(1.0 + ({x}))", - # Square-root family - torch.ops.aten.sqrt.default: "tl.sqrt({x})", - torch.ops.aten.rsqrt.default: "1.0 / tl.sqrt({x})", - # Trigonometric - torch.ops.aten.sin.default: "tl.sin({x})", - torch.ops.aten.cos.default: "tl.cos({x})", - torch.ops.aten.tan.default: "tl.math.tan({x})", - torch.ops.aten.asin.default: "tl.math.asin({x})", - torch.ops.aten.acos.default: "tl.math.acos({x})", - torch.ops.aten.atan.default: "tl.math.atan({x})", - # Hyperbolic - torch.ops.aten.tanh.default: "tl.tanh({x})", - torch.ops.aten.sinh.default: "tl.math.sinh({x})", - torch.ops.aten.cosh.default: "tl.math.cosh({x})", - # Activations - torch.ops.aten.sigmoid.default: "tl.sigmoid({x})", - torch.ops.aten.relu.default: "tl.maximum({x}, 0.0)", - # Error function - torch.ops.aten.erf.default: "tl.math.erf({x})", - torch.ops.aten.erfinv.default: "tl.math.erfinv({x})", - torch.ops.aten.erfc.default: "tl.math.erfc({x})", - # Rounding - torch.ops.aten.floor.default: "tl.math.floor({x})", - torch.ops.aten.ceil.default: "tl.math.ceil({x})", - torch.ops.aten.trunc.default: "tl.math.trunc({x})", - torch.ops.aten.round.default: "tl.math.round({x})", - torch.ops.aten.frac.default: "({x}) - tl.math.trunc({x})", - # Bitwise / logical - torch.ops.aten.logical_not.default: "~({x})", - torch.ops.aten.bitwise_not.default: "~({x})", - # Predicates - torch.ops.aten.isnan.default: "tl.math.isnan({x})", - torch.ops.aten.isinf.default: "tl.math.isinf({x})", - torch.ops.aten.isfinite.default: "~tl.math.isinf({x}) & ~tl.math.isnan({x})", -} - -# Binary elementwise ops: {x} = left, {y} = right -_BINARY_EXPRS = { - # Addition / subtraction (alpha handled separately) - torch.ops.aten.add.Tensor: "({x}) + ({y})", - torch.ops.aten.add.Scalar: "({x}) + ({y})", - operator.add: "({x}) + ({y})", - torch.ops.aten.sub.Tensor: "({x}) - ({y})", - torch.ops.aten.sub.Scalar: "({x}) - ({y})", - operator.sub: "({x}) - ({y})", - # Multiplication / division - torch.ops.aten.mul.Tensor: "({x}) * ({y})", - torch.ops.aten.mul.Scalar: "({x}) * ({y})", - operator.mul: "({x}) * ({y})", - torch.ops.aten.div.Tensor: "({x}) / ({y})", - torch.ops.aten.div.Scalar: "({x}) / ({y})", - operator.truediv: "({x}) / ({y})", - torch.ops.aten.remainder.Tensor: "({x}) % ({y})", - torch.ops.aten.remainder.Scalar: "({x}) % ({y})", - operator.mod: "({x}) % ({y})", - # Min / max - torch.ops.aten.maximum.default: "tl.maximum({x}, {y})", - torch.ops.aten.minimum.default: "tl.minimum({x}, {y})", - # Trigonometric binary - torch.ops.aten.atan2.default: "tl.math.atan2({x}, {y})", - # Bitwise / logical binary - torch.ops.aten.bitwise_and.Tensor: "({x}) & ({y})", - torch.ops.aten.bitwise_and.Scalar: "({x}) & ({y})", - operator.and_: "({x}) & ({y})", - torch.ops.aten.bitwise_or.Tensor: "({x}) | ({y})", - torch.ops.aten.bitwise_or.Scalar: "({x}) | ({y})", - operator.or_: "({x}) | ({y})", - torch.ops.aten.bitwise_xor.Tensor: "({x}) ^ ({y})", - torch.ops.aten.bitwise_xor.Scalar: "({x}) ^ ({y})", - operator.xor: "({x}) ^ ({y})", - torch.ops.aten.logical_and.default: "({x}) & ({y})", - torch.ops.aten.logical_or.default: "({x}) | ({y})", - torch.ops.aten.logical_xor.default: "({x}) ^ ({y})", -} - -# Ops that pass through without any value transformation -_PASSTHROUGH_OPS = frozenset( - { - torch.ops.prims.convert_element_type.default, - torch.ops.aten._to_copy.default, - torch.ops.aten.clone.default, - torch.ops.aten.contiguous.default, - torch.ops.aten.alias.default, - } -) - - -def _get_static_dims(mm_node: fx.Node) -> dict: - """Return {name: value} for mm dimensions that are compile-time-constant. - - FX shapes carry plain Python ``int`` for static dims and ``torch.SymInt`` - for symbolic (dynamic) ones. ``type(d) is int`` excludes SymInt even in - PyTorch versions where SymInt happens to subclass int. - """ - static: dict = {} - A, B = mm_node.args - try: - val_a = A.meta.get("val") if isinstance(A, fx.Node) else None - if val_a is not None and val_a.dim() == 2: - for name, idx in (("M", 0), ("K", 1)): - d = val_a.shape[idx] - if type(d) is int: - static[name] = d - val_b = B.meta.get("val") if isinstance(B, fx.Node) else None - if val_b is not None and val_b.dim() == 2: - d = val_b.shape[1] - if type(d) is int: - static["N"] = d - except Exception: - pass - return static - - -class MatmulCustomEpilogueFusionPass(MagiInductorPass): - def __call__(self, graph: fx.Graph) -> bool: - fused = 0 - for node in list(graph.nodes): - if node.op == "call_function" and node.target in (torch.ops.aten.mm.default, torch.ops.aten.mm): - # Prefer the CuTe path on Hopper; fall back to Triton-only. - if _HAS_CUTLASS: - fused += self._try_fuse_custom_chain_cute(graph, node) - else: - fused += self._try_fuse_custom_chain(graph, node) - - if fused: - graph.eliminate_dead_code() - return fused > 0 - - def _try_fuse_custom_chain_cute(self, graph: fx.Graph, mm_node: fx.Node) -> int: - """Like ``_try_fuse_custom_chain`` but emits ``matmul_custom_cute``. - - Uses ``HopperWgmmaGemmPersistentKernel`` for the GEMM and a separate - Triton kernel for the epilogue. The epilogue code string is identical - to the one produced by ``_try_fuse_custom_chain`` so the two methods - share the same generation logic — only the dispatched op differs. - """ - return self._try_fuse_custom_chain(graph, mm_node, op=torch.ops.magi_epilogue.matmul_custom_cute.default) - - def _try_fuse_custom_chain(self, graph: fx.Graph, mm_node: fx.Node, *, op=None) -> int: - """Fuse a chain of elementwise ops following *mm_node* into a single kernel. - - Parameters - ---------- - op : callable, optional - The dispatch target to call in the fused graph node. Defaults to - ``torch.ops.magi_epilogue.matmul_custom.default`` (pure Triton). - Pass ``torch.ops.magi_epilogue.matmul_custom_cute.default`` to use - the CuTe GEMM path instead. - """ - if op is None: - op = torch.ops.magi_epilogue.matmul_custom.default - A, B = mm_node.args - - fused_nodes = {mm_node: "acc"} - nodes_to_remove = [] - epilogue_lines = [] - extras = [] - is_swiglu = False - - def get_val(arg): - if isinstance(arg, Node): - if arg in fused_nodes: - return fused_nodes[arg] - # External tensor — inject a load - idx = len(extras) - extras.append(arg) - name = f"ext_{idx}" - val = arg.meta.get("val") - if val is not None and val.dim() == 1: - epilogue_lines.append(f"{name}_ptrs = Extra_{idx}_ptr + offs_dn[None, :]") - epilogue_lines.append(f"{name} = tl.load({name}_ptrs, mask=offs_dn[None, :] < N, other=0.0)") - else: - epilogue_lines.append( - f"{name}_ptrs = Extra_{idx}_ptr + stride_dm * offs_dm[:, None] + stride_dn * offs_dn[None, :]" - ) - epilogue_lines.append(f"{name} = tl.load({name}_ptrs, mask=mask, other=0.0)") - fused_nodes[arg] = name - return name - return str(arg) - - curr = mm_node.next - last_fused_node = mm_node - - while curr.op != "output": - uses_fused = any(isinstance(a, Node) and a in fused_nodes for a in curr.args) - if not uses_fused: - curr = curr.next - continue - - var_name = f"v_{curr.name}" - target = curr.target - code = None - - # ── 1. Pass-through (type conversion / clone / alias) ───────────── - if target in _PASSTHROUGH_OPS: - fused_nodes[curr] = fused_nodes[curr.args[0]] - nodes_to_remove.append(curr) - last_fused_node = curr - curr = curr.next - continue - - # ── 2. Unary elementwise ops (from dispatch table) ──────────────── - elif target in _UNARY_EXPRS: - x = get_val(curr.args[0]) - code = f"{var_name} = " + _UNARY_EXPRS[target].format(x=x) - - # ── 3. Compound activation functions ────────────────────────────── - elif target in (torch.ops.aten.silu.default, torch.ops.aten.silu): - x = get_val(curr.args[0]) - code = f"{var_name} = ({x}) * tl.sigmoid({x})" - - elif target in (torch.ops.aten.gelu.default, torch.ops.aten.gelu): - x = get_val(curr.args[0]) - approx = curr.kwargs.get("approximate", "none") - if approx == "tanh": - code = ( - f"{var_name} = ({x}) * 0.5 * " - f"(1.0 + tl.tanh(0.7978845608 * (({x}) + 0.044715 * ({x}) * ({x}) * ({x}))))" - ) - else: - code = f"{var_name} = 0.5 * ({x}) * (1.0 + tl.math.erf(({x}) * 0.7071067811865476))" - - elif target == torch.ops.aten.leaky_relu.default: - x = get_val(curr.args[0]) - slope = curr.args[1] if len(curr.args) > 1 else curr.kwargs.get("negative_slope", 0.01) - code = f"{var_name} = tl.where({x} >= 0.0, {x}, {slope} * ({x}))" - - elif target == torch.ops.aten.hardtanh.default: - x = get_val(curr.args[0]) - lo = curr.args[1] if len(curr.args) > 1 else curr.kwargs.get("min_val", -1.0) - hi = curr.args[2] if len(curr.args) > 2 else curr.kwargs.get("max_val", 1.0) - code = f"{var_name} = tl.minimum(tl.maximum({x}, {lo}), {hi})" - - elif target == torch.ops.aten.hardsigmoid.default: - x = get_val(curr.args[0]) - code = f"{var_name} = tl.minimum(tl.maximum(({x}) / 6.0 + 0.5, 0.0), 1.0)" - - elif target == torch.ops.aten.hardswish.default: - x = get_val(curr.args[0]) - code = f"{var_name} = ({x}) * tl.minimum(tl.maximum(({x}) / 6.0 + 0.5, 0.0), 1.0)" - - elif target == torch.ops.aten.mish.default: - x = get_val(curr.args[0]) - code = f"{var_name} = ({x}) * tl.tanh(tl.log(1.0 + tl.exp({x})))" - - # ── 4. Clamp family ─────────────────────────────────────────────── - elif target in ( - torch.ops.aten.clamp.default, - torch.ops.aten.clamp.Tensor, - torch.ops.aten.clamp_max.default, - torch.ops.aten.clamp_min.default, - ): - x = get_val(curr.args[0]) - if target is torch.ops.aten.clamp_max.default: - lo, hi = None, curr.args[1] if len(curr.args) > 1 else curr.kwargs.get("max", None) - elif target is torch.ops.aten.clamp_min.default: - lo, hi = curr.args[1] if len(curr.args) > 1 else curr.kwargs.get("min", None), None - else: - lo = curr.args[1] if len(curr.args) > 1 else curr.kwargs.get("min", None) - hi = curr.args[2] if len(curr.args) > 2 else curr.kwargs.get("max", None) - expr = x - if lo is not None: - expr = f"tl.maximum({expr}, {get_val(lo)})" - if hi is not None: - expr = f"tl.minimum({expr}, {get_val(hi)})" - code = f"{var_name} = {expr}" - - # ── 5. Ternary select ───────────────────────────────────────────── - elif target in (torch.ops.aten.where.self, torch.ops.aten.where.ScalarSelf, torch.ops.aten.where.ScalarOther): - cond = get_val(curr.args[0]) - t = get_val(curr.args[1]) - f_ = get_val(curr.args[2]) - code = f"{var_name} = tl.where({cond}, {t}, {f_})" - - # ── 6. pow (special-cased exponents) ───────────────────────────── - elif target in (torch.ops.aten.pow.Tensor_Scalar, torch.ops.aten.pow.Tensor_Tensor): - x = get_val(curr.args[0]) - y = get_val(curr.args[1]) - if str(y) in ("2", "2.0"): - code = f"{var_name} = ({x}) * ({x})" - elif str(y) in ("0.5",): - code = f"{var_name} = tl.sqrt({x})" - elif str(y) in ("-0.5",): - code = f"{var_name} = 1.0 / tl.sqrt({x})" - elif str(y) in ("-1", "-1.0"): - code = f"{var_name} = 1.0 / ({x})" - else: - code = f"{var_name} = tl.math.pow({x}, {y})" - - # ── 7. div with rounding_mode ───────────────────────────────────── - elif target is torch.ops.aten.div.Tensor_mode: - x = get_val(curr.args[0]) - y = get_val(curr.args[1]) - rounding_mode = curr.kwargs.get("rounding_mode", None) or (curr.args[2] if len(curr.args) > 2 else None) - if rounding_mode == "floor": - code = f"{var_name} = tl.math.floor(({x}) / ({y}))" - elif rounding_mode == "trunc": - code = f"{var_name} = tl.math.trunc(({x}) / ({y}))" - else: - code = f"{var_name} = ({x}) / ({y})" - - # ── 8. Binary elementwise ops (from dispatch table) ─────────────── - elif target in _BINARY_EXPRS: - x = get_val(curr.args[0]) - y_raw = curr.args[1] - y = get_val(y_raw) - # Handle optional alpha scalar for add/sub (aten convention) - alpha = (curr.args[2] if len(curr.args) > 2 else None) or curr.kwargs.get("alpha", None) - if alpha is not None and alpha != 1: - y = f"{alpha} * ({y})" - code = f"{var_name} = " + _BINARY_EXPRS[target].format(x=x, y=y) - - # ── 9. Slice: SwiGLU (stride-2 along last dim) ─────────────────── - elif target is torch.ops.aten.slice.Tensor: - dim = curr.args[1] if len(curr.args) > 1 else curr.kwargs.get("dim", 0) - start = curr.args[2] if len(curr.args) > 2 else curr.kwargs.get("start", None) - step = curr.args[4] if len(curr.args) > 4 else curr.kwargs.get("step", 1) - - src = curr.args[0] - if isinstance(src, fx.Node) and "val" in src.meta: - rank = src.meta["val"].dim() - is_last_dim = (dim % rank) == (rank - 1) - else: - is_last_dim = dim == -1 - - if is_last_dim and step == 2: - is_swiglu = True - x = get_val(curr.args[0]) - if not x.endswith("_reshaped"): - epilogue_lines.append(f"{x}_reshaped = tl.reshape({x}, (BLOCK_M, BLOCK_N // 2, 2))") - epilogue_lines.append(f"{x}_split_0, {x}_split_1 = tl.split({x}_reshaped)") - fused_nodes[curr.args[0]] = f"{x}_reshaped" - base_x = x - else: - base_x = x[:-9] # strip '_reshaped' - - idx = 0 if (start == 0 or start is None) else 1 - code = f"{var_name} = {base_x}_split_{idx}" - else: - break # non-strided / non-trailing slice — stop fusion - - # ── Unsupported op — stop greedy fusion ──────────────────────────── - else: - break - - if code: - epilogue_lines.append(code) - fused_nodes[curr] = var_name - nodes_to_remove.append(curr) - last_fused_node = curr - - curr = curr.next - - # Validate: intermediate nodes must not escape the fused set - if not nodes_to_remove: - return 0 - for node in nodes_to_remove[:-1]: - for user in node.users: - if user not in nodes_to_remove: - return 0 - - final_var = fused_nodes[last_fused_node] - - # Skip fusion if the epilogue is a no-op (only passthrough ops were - # collected — e.g. a bare _to_copy after mm). Replacing cuBLAS with - # a Triton GEMM that does the exact same work is strictly slower. - if final_var == "acc": - return 0 - - epilogue_lines.append(f"acc = {final_var}") - - epilogue_code = "\n".join(epilogue_lines) - - # Prepend a comment that encodes which mm dimensions are statically - # known at trace time. triton_kernels.py parses this header and - # annotates the corresponding kernel parameters as tl.constexpr so - # Triton can specialise (and optimise) the compiled kernel per value. - static_dims = _get_static_dims(mm_node) - if static_dims: - epilogue_code = f"# @static:{json.dumps(static_dims, separators=(',', ':'))}\n" + epilogue_code - - with graph.inserting_after(last_fused_node): - fused_node = graph.call_function(op, args=(A, B, extras, epilogue_code, is_swiglu)) - if "val" in last_fused_node.meta: - val = last_fused_node.meta["val"] - # Propagate the 128-byte-aligned row stride so downstream - # assert_size_stride checks match what we actually return. - try: - N_out = int(val.shape[-1]) - elem_size = val.element_size() - align_elems = 128 // elem_size - N_stride = (N_out + align_elems - 1) // align_elems * align_elems - new_stride = val.stride()[:-2] + (N_stride, 1) - fused_node.meta["val"] = val.new_empty_strided(val.shape, new_stride) - except Exception: - fused_node.meta["val"] = val - - last_fused_node.replace_all_uses_with(fused_node) - - for n in reversed(nodes_to_remove): - graph.erase_node(n) - graph.erase_node(mm_node) - - return 1 diff --git a/magi_compiler/passes/piecewise_graph/fusion/triton_kernels.py b/magi_compiler/passes/piecewise_graph/fusion/triton_kernels.py deleted file mode 100644 index 203ffef..0000000 --- a/magi_compiler/passes/piecewise_graph/fusion/triton_kernels.py +++ /dev/null @@ -1,582 +0,0 @@ -# Copyright (c) 2026 SandAI. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import json -import math -import os - -import torch -import triton -import triton.language as tl - -from magi_compiler.config import get_compile_config - -# ── Python-level kernel caches ───────────────────────────────────────────────── -# (num_extras, epilogue_code, reduce_n_by_2) → kernel object -_KERNEL_CACHE: dict = {} -_KERNEL_TMA_CACHE: dict = {} - -# ── Persistent autotune result caches (survive process restart) ──────────────── -_cache_root = get_compile_config().cache_root_dir -_AUTOTUNE_FILE = os.path.join(_cache_root, "magi_epilogue_autotune.json") -_AUTOTUNE_FILE_TMA = os.path.join(_cache_root, "magi_epilogue_autotune_tma.json") -_AUTOTUNE_PERSIST: dict = {} -_AUTOTUNE_PERSIST_TMA: dict = {} - - -def _load_autotune_cache() -> None: - global _AUTOTUNE_PERSIST - try: - with open(_AUTOTUNE_FILE) as f: - _AUTOTUNE_PERSIST = json.load(f) - except (FileNotFoundError, json.JSONDecodeError): - _AUTOTUNE_PERSIST = {} - - -def _save_autotune_cache() -> None: - os.makedirs(os.path.dirname(_AUTOTUNE_FILE), exist_ok=True) - with open(_AUTOTUNE_FILE, "w") as f: - json.dump(_AUTOTUNE_PERSIST, f) - - -def _load_autotune_cache_tma() -> None: - global _AUTOTUNE_PERSIST_TMA - try: - with open(_AUTOTUNE_FILE_TMA) as f: - _AUTOTUNE_PERSIST_TMA = json.load(f) - except (FileNotFoundError, json.JSONDecodeError): - _AUTOTUNE_PERSIST_TMA = {} - - -def _save_autotune_cache_tma() -> None: - os.makedirs(os.path.dirname(_AUTOTUNE_FILE_TMA), exist_ok=True) - with open(_AUTOTUNE_FILE_TMA, "w") as f: - json.dump(_AUTOTUNE_PERSIST_TMA, f) - - -_load_autotune_cache() - - -def _check_tma() -> bool: - """Return True when SM90+ TMA with device-side descriptors is available.""" - try: - return ( - torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 9 and hasattr(tl, "make_tensor_descriptor") - ) - except Exception: - return False - - -_TMA_AVAILABLE: bool = _check_tma() -_TMA_ALLOCATOR_SET: bool = False - -if _TMA_AVAILABLE: - _load_autotune_cache_tma() - - -def _ensure_tma_allocator() -> None: - """Set a Triton global-memory allocator once; required by device-side TMA descriptors.""" - global _TMA_ALLOCATOR_SET - if _TMA_ALLOCATOR_SET: - return - - def _alloc_fn(size: int, alignment: int, stream): - return torch.empty(size, device="cuda", dtype=torch.int8) - - triton.set_allocator(_alloc_fn) - _TMA_ALLOCATOR_SET = True - - -def _parse_static_dims(epilogue_code: str) -> dict: - """Parse the ``# @static:{...}`` header injected by the fusion pass. - - Returns a dict like ``{"M": 2048, "K": 4096, "N": 8192}`` (only the keys - that are actually static). Missing keys mean the dimension is dynamic. - """ - for line in epilogue_code.splitlines(): - if line.startswith("# @static:"): - try: - return json.loads(line[len("# @static:") :]) - except Exception: - pass - return {} - - -def _bucket_m(M: int) -> int: - """Round M up to the nearest power-of-2 bucket. - - This drastically reduces the number of distinct (M, N, K) triples - that trigger autotune: e.g. M=1000 and M=1023 both map to 1024, - reusing the same benchmark result instead of each triggering 27 × 125 - device kernel launches. - """ - return 1 << math.ceil(math.log2(max(M, 1))) - - -# ── Autotune config list ─────────────────────────────────────────────────────── -# Shapes that prune_configs removes: -# • BLOCK_M > M_bucket → waste SM occupancy on empty rows -# • BLOCK_K > K → single-iteration k-loop, large overhead -# • BLOCK_N > N → waste on empty columns - - -def _prune_configs(configs, named_args, **kwargs): - M = named_args["M"] - N = named_args["N"] - K = named_args["K"] - pruned = [] - for cfg in configs: - bm = cfg.kwargs["BLOCK_M"] - bn = cfg.kwargs["BLOCK_N"] - bk = cfg.kwargs["BLOCK_K"] - # Keep configs whose tiles are no larger than 4× the dimension - # (leaving room for the autotuner to still test large tiles that - # can handle moderate-size matrices efficiently). - if bm > 4 * M or bn > 4 * N or bk > K: - continue - pruned.append(cfg) - # Always keep at least one fallback - return pruned if pruned else [configs[0]] - - -# ── Shared autotune config list (embedded as a string in both templates) ─────── -_AUTOTUNE_CONFIGS_BODY = """ - # ── Large-tile: high-throughput for large M/N (training) ────────────────── - triton.Config({"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 64, "GROUP_M": 8}, num_stages=3, num_warps=8), - triton.Config({"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 128, "GROUP_M": 8}, num_stages=3, num_warps=8), - triton.Config({"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 128, "GROUP_M": 8}, num_stages=3, num_warps=8), - triton.Config({"BLOCK_M": 256, "BLOCK_N": 64, "BLOCK_K": 128, "GROUP_M": 8}, num_stages=4, num_warps=4), - triton.Config({"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 128, "GROUP_M": 8}, num_stages=4, num_warps=4), - triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 128, "GROUP_M": 8}, num_stages=4, num_warps=4), - # ── Medium-tile: balanced for mixed shapes ───────────────────────────────── - triton.Config({"BLOCK_M": 128, "BLOCK_N": 256, "BLOCK_K": 64, "GROUP_M": 8}, num_stages=4, num_warps=4), - triton.Config({"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 64, "GROUP_M": 8}, num_stages=4, num_warps=4), - triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 64, "GROUP_M": 8}, num_stages=4, num_warps=4), - triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 128, "GROUP_M": 8}, num_stages=4, num_warps=4), - triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 128, "GROUP_M": 8}, num_stages=4, num_warps=4), - triton.Config({"BLOCK_M": 64, "BLOCK_N": 256, "BLOCK_K": 32, "GROUP_M": 8}, num_stages=4, num_warps=4), - triton.Config({"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 32, "GROUP_M": 8}, num_stages=4, num_warps=4), - triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 64, "GROUP_M": 8}, num_stages=4, num_warps=4), - triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 64, "GROUP_M": 8}, num_stages=4, num_warps=4), - triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "BLOCK_K": 32, "GROUP_M": 8}, num_stages=4, num_warps=4), - triton.Config({"BLOCK_M": 64, "BLOCK_N": 128, "BLOCK_K": 32, "GROUP_M": 8}, num_stages=4, num_warps=4), - # ── Small-tile: high occupancy for small-M or tail dimensions ───────────── - triton.Config({"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 128, "GROUP_M": 8}, num_stages=4, num_warps=4), - triton.Config({"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8}, num_stages=4, num_warps=4), - triton.Config({"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8}, num_stages=5, num_warps=2), - triton.Config({"BLOCK_M": 32, "BLOCK_N": 64, "BLOCK_K": 64, "GROUP_M": 8}, num_stages=5, num_warps=2), - triton.Config({"BLOCK_M": 128, "BLOCK_N": 32, "BLOCK_K": 32, "GROUP_M": 8}, num_stages=4, num_warps=4), - triton.Config({"BLOCK_M": 64, "BLOCK_N": 32, "BLOCK_K": 32, "GROUP_M": 8}, num_stages=5, num_warps=2), - triton.Config({"BLOCK_M": 32, "BLOCK_N": 64, "BLOCK_K": 32, "GROUP_M": 8}, num_stages=5, num_warps=2), - triton.Config({"BLOCK_M": 32, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8}, num_stages=5, num_warps=2), - triton.Config({"BLOCK_M": 16, "BLOCK_N": 32, "BLOCK_K": 64, "GROUP_M": 8}, num_stages=6, num_warps=2), - triton.Config({"BLOCK_M": 32, "BLOCK_N": 16, "BLOCK_K": 64, "GROUP_M": 8}, num_stages=6, num_warps=2), -""" - - -# ───────────────────────────────────────────────────────────────────────────── -# Non-persistent kernel template (all CUDA GPUs) -# Uses tl.where + tl.max_contiguous + tl.multiple_of for vectorised loads. -# ───────────────────────────────────────────────────────────────────────────── -KERNEL_TEMPLATE = """ -import triton -import triton.language as tl - -_AUTOTUNE_CONFIGS = [ -{autotune_configs} -] - -@triton.autotune( - configs=_AUTOTUNE_CONFIGS, - key=["M_BUCKET", "N", "K"], - prune_configs_by={{"early_config_prune": {prune_fn_name}}}, - warmup=10, - rep=30, -) -@triton.jit -def dynamic_matmul_epilogue_kernel( - A_ptr, B_ptr, D_ptr, - {extra_ptrs_args} - M{M_annot}, N{N_annot}, K{K_annot}, - M_BUCKET, - stride_am, stride_ak, - stride_bk, stride_bn, - stride_dm, stride_dn, - BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, - GROUP_M: tl.constexpr, -): - pid = tl.program_id(axis=0) - num_pid_m = tl.cdiv(M, BLOCK_M) - num_pid_n = tl.cdiv(N, BLOCK_N) - - num_pid_in_group = GROUP_M * num_pid_n - group_id = pid // num_pid_in_group - first_pid_m = group_id * GROUP_M - group_size_m = min(num_pid_m - first_pid_m, GROUP_M) - pid_m = first_pid_m + (pid % group_size_m) - pid_n = (pid % num_pid_in_group) // group_size_m - - start_m = pid_m * BLOCK_M - start_n = pid_n * BLOCK_N - - offs_am = start_m + tl.arange(0, BLOCK_M) - offs_bn = start_n + tl.arange(0, BLOCK_N) -{offs_am_guard}{offs_bn_guard} offs_am = tl.max_contiguous(tl.multiple_of(offs_am, BLOCK_M), BLOCK_M) - offs_bn = tl.max_contiguous(tl.multiple_of(offs_bn, BLOCK_N), BLOCK_N) - offs_k = tl.arange(0, BLOCK_K) - - A_ptrs = A_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) - B_ptrs = B_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) - - acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) - for k in range(0, tl.cdiv(K, BLOCK_K)): - a = tl.load(A_ptrs{k_mask_a}) - b = tl.load(B_ptrs{k_mask_b}) - acc = tl.dot(a, b, acc) - A_ptrs += BLOCK_K * stride_ak - B_ptrs += BLOCK_K * stride_bk - - offs_dm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - offs_dn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) - mask = {out_mask_expr} - -{epilogue_code} - -{store_code} -""" - - -# ───────────────────────────────────────────────────────────────────────────── -# TMA persistent kernel template (SM90+: H100 / Hopper and newer) -# -# Key advantages over the non-persistent path: -# 1. Device-side tl.make_tensor_descriptor — no host→device descriptor copy. -# 2. Persistent CTA loop — each SM processes multiple tiles, amortising -# kernel-launch and L2-warmup overhead. -# 3. Hardware-managed OOB fill — TMA zero-fills out-of-bounds tile edges, -# so the k-loop needs no software mask. -# 4. B read as [K, N] (no pre-transpose required). -# -# {epilogue_code} and {store_code} are injected at 8-space indent so they -# land inside the `for tile_id` persistent loop body. -# ───────────────────────────────────────────────────────────────────────────── -KERNEL_TEMPLATE_TMA_PERSISTENT = """ -import triton -import triton.language as tl - -_AUTOTUNE_CONFIGS_TMA = [ -{autotune_configs} -] - -@triton.autotune( - configs=_AUTOTUNE_CONFIGS_TMA, - key=["M_BUCKET", "N", "K"], - prune_configs_by={{"early_config_prune": {prune_fn_name}}}, - warmup=10, - rep=30, -) -@triton.jit -def dynamic_matmul_epilogue_kernel_tma( - A_ptr, B_ptr, D_ptr, - {extra_ptrs_args} - M{M_annot}, N{N_annot}, K{K_annot}, - M_BUCKET, - stride_dm, stride_dn, - NUM_SMS: tl.constexpr, - BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, - GROUP_M: tl.constexpr, -): - # Device-side TMA descriptor creation — eliminates host→device copy latency. - # A is [M, K] row-major; B is [K, N] row-major (no pre-transpose needed). - # TMA hardware zero-fills tiles that extend past the tensor boundary. - a_desc = tl.make_tensor_descriptor( - A_ptr, shape=[M, K], strides=[K, 1], block_shape=[BLOCK_M, BLOCK_K], - ) - b_desc = tl.make_tensor_descriptor( - B_ptr, shape=[K, N], strides=[N, 1], block_shape=[BLOCK_K, BLOCK_N], - ) - - start_pid = tl.program_id(axis=0) - num_pid_m = tl.cdiv(M, BLOCK_M) - num_pid_n = tl.cdiv(N, BLOCK_N) - num_tiles = num_pid_m * num_pid_n - num_pid_in_group = GROUP_M * num_pid_n - - # Each CTA iterates over multiple tiles, stepping NUM_SMS at a time. - for tile_id in tl.range(start_pid, num_tiles, NUM_SMS, flatten=True): - group_id = tile_id // num_pid_in_group - first_pid_m = group_id * GROUP_M - group_size_m = min(num_pid_m - first_pid_m, GROUP_M) - pid_m = first_pid_m + (tile_id % group_size_m) - pid_n = (tile_id % num_pid_in_group) // group_size_m - - offs_am = pid_m * BLOCK_M - offs_bn = pid_n * BLOCK_N - - acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) - for k in range(0, tl.cdiv(K, BLOCK_K)): - offs_k = k * BLOCK_K - a = a_desc.load([offs_am, offs_k]) - b = b_desc.load([offs_k, offs_bn]) - acc = tl.dot(a, b, acc) - - offs_dm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) - offs_dn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) - mask = {out_mask_expr} - -{epilogue_code} - -{store_code} -""" - - -def _build_kernel_via_exec( - template: str, kernel_name: str, num_extras: int, epilogue_code: str, reduce_n_by_2: bool, indent: int, persist_cache: dict -) -> object: - """Compile *template* with exec() and return the resulting Triton kernel.""" - extra_ptrs_args = "".join([f"Extra_{i}_ptr, " for i in range(num_extras)]) - - # ── Derive tl.constexpr annotations and static mask/guard expressions ──── - # The fusion pass prepends a "# @static:{...}" comment to epilogue_code - # whenever it can prove (from FakeTensor meta) that a dimension is a plain - # Python int rather than a SymInt. - static_dims = _parse_static_dims(epilogue_code) - M_static = static_dims.get("M") - N_static = static_dims.get("N") - K_static = static_dims.get("K") - - # tl.constexpr annotation: Triton JIT-compiles one kernel variant per - # unique value, making all constexpr-dependent expressions compile-time - # constants (loop bounds, tile counts, mask predicates, etc.). - M_annot = ": tl.constexpr" if M_static is not None else "" - N_annot = ": tl.constexpr" if N_static is not None else "" - K_annot = ": tl.constexpr" if K_static is not None else "" - - # ── k-loop load masks ───────────────────────────────────────────────────── - # Our BLOCK_K configs are {32, 64, 128}; the mask in the k-loop is needed - # only when K is not a multiple of the chosen BLOCK_K. If K % 128 == 0, - # then K is a multiple of every BLOCK_K in the config set, so the mask - # predicate is always all-true and we can emit bare (unmasked) tl.load - # calls — the hottest path in the kernel. - if K_static is not None and K_static % 128 == 0: - k_mask_a = "" - k_mask_b = "" - else: - k_mask_a = ", mask=offs_k[None, :] < K - k * BLOCK_K, other=0.0" - k_mask_b = ", mask=offs_k[:, None] < K - k * BLOCK_K, other=0.0" - - # ── A / B index boundary guards ─────────────────────────────────────────── - # tl.where(offs < dim, offs, 0) prevents out-of-bounds pointer arithmetic - # when a tile straddles the last row/column. If dim is a multiple of the - # largest BLOCK size (256 covers all configs {16,32,64,128,256}), every - # tile is a full tile and the guard is dead code — remove it. - m_tile_aligned = M_static is not None and M_static % 256 == 0 - n_tile_aligned = N_static is not None and N_static % 256 == 0 - - offs_am_guard = "" if m_tile_aligned else " offs_am = tl.where(offs_am < M, offs_am, 0)\n" - offs_bn_guard = "" if n_tile_aligned else " offs_bn = tl.where(offs_bn < N, offs_bn, 0)\n" - - # ── Output (and epilogue) mask ──────────────────────────────────────────── - # The mask tensor is referenced by both the output store and extra-tensor - # loads inside epilogue_code. When a dimension is tile-aligned we drop - # its component from the predicate; both dropped → constant True mask (the - # compiler will eliminate it entirely from the PTX). - if m_tile_aligned and n_tile_aligned: - out_mask_expr = "tl.full([BLOCK_M, BLOCK_N], True, dtype=tl.int1)" - elif m_tile_aligned: - out_mask_expr = "offs_dn[None, :] < N" - elif n_tile_aligned: - out_mask_expr = "offs_dm[:, None] < M" - else: - out_mask_expr = "(offs_dm[:, None] < M) & (offs_dn[None, :] < N)" - - pad = " " * indent - indented_epilogue = "\n".join([f"{pad}{line}" for line in epilogue_code.strip().split("\n") if line]) - - if reduce_n_by_2: - # For SwiGLU the output N is N//2; output BLOCK size is BLOCK_N//2 - # whose maximum across configs is 128. Tile-alignment condition: - # (N_static // 2) % 128 == 0 ↔ N_static % 256 == 0 (same as n_tile_aligned). - if m_tile_aligned and n_tile_aligned: - mask_out_expr = "tl.full([BLOCK_M, BLOCK_N // 2], True, dtype=tl.int1)" - elif m_tile_aligned: - mask_out_expr = "offs_dn_out[None, :] < N // 2" - elif n_tile_aligned: - mask_out_expr = "offs_dm[:, None] < M" - else: - mask_out_expr = "(offs_dm[:, None] < M) & (offs_dn_out[None, :] < N // 2)" - store_code = ( - f"{pad}offs_dn_out = pid_n * (BLOCK_N // 2) + tl.arange(0, BLOCK_N // 2)\n" - f"{pad}mask_out = {mask_out_expr}\n" - f"{pad}D_ptrs = D_ptr + stride_dm * offs_dm[:, None] + stride_dn * offs_dn_out[None, :]\n" - f"{pad}tl.store(D_ptrs, acc.to(D_ptr.dtype.element_ty), mask=mask_out)" - ) - else: - store_code = ( - f"{pad}D_ptrs = D_ptr + stride_dm * offs_dm[:, None] + stride_dn * offs_dn[None, :]\n" - f"{pad}tl.store(D_ptrs, acc.to(D_ptr.dtype.element_ty), mask=mask)" - ) - - code = template.format( - autotune_configs=_AUTOTUNE_CONFIGS_BODY, - extra_ptrs_args=extra_ptrs_args, - epilogue_code=indented_epilogue, - store_code=store_code, - prune_fn_name="_prune_configs", - M_annot=M_annot, - N_annot=N_annot, - K_annot=K_annot, - offs_am_guard=offs_am_guard, - offs_bn_guard=offs_bn_guard, - k_mask_a=k_mask_a, - k_mask_b=k_mask_b, - out_mask_expr=out_mask_expr, - ) - - import linecache - import uuid - - filename = f"" - linecache.cache[filename] = (len(code), None, [line + "\n" for line in code.splitlines()], filename) - compiled = compile(code, filename, "exec") - - namespace: dict = {} - exec(compiled, {"triton": triton, "tl": tl, "_prune_configs": _prune_configs}, namespace) - kernel = namespace[kernel_name] - - # Warm the in-process autotune cache from the persisted JSON so that - # known shapes skip the benchmark entirely on restart. - key_str = str((num_extras, epilogue_code, reduce_n_by_2)) - for cache_key, best_cfg in persist_cache.items(): - if cache_key.startswith(key_str + "|"): - suffix = cache_key[len(key_str) + 1 :] - try: - m_bucket, n, k = (int(x) for x in suffix.split(",")) - except ValueError: - continue - triton_key = (m_bucket, n, k) - cfg = triton.Config( - {k2: v for k2, v in best_cfg["kwargs"].items()}, - num_stages=best_cfg["num_stages"], - num_warps=best_cfg["num_warps"], - ) - kernel.cache[triton_key] = cfg - - return kernel - - -def get_dynamic_kernel(num_extras: int, epilogue_code: str, reduce_n_by_2: bool): - key = (num_extras, epilogue_code, reduce_n_by_2) - if key in _KERNEL_CACHE: - return _KERNEL_CACHE[key] - kernel = _build_kernel_via_exec( - KERNEL_TEMPLATE, - "dynamic_matmul_epilogue_kernel", - num_extras, - epilogue_code, - reduce_n_by_2, - indent=4, - persist_cache=_AUTOTUNE_PERSIST, - ) - _KERNEL_CACHE[key] = kernel - return kernel - - -def get_dynamic_kernel_tma(num_extras: int, epilogue_code: str, reduce_n_by_2: bool): - """Build the TMA-persistent variant via exec().""" - key = (num_extras, epilogue_code, reduce_n_by_2) - if key in _KERNEL_TMA_CACHE: - return _KERNEL_TMA_CACHE[key] - kernel = _build_kernel_via_exec( - KERNEL_TEMPLATE_TMA_PERSISTENT, - "dynamic_matmul_epilogue_kernel_tma", - num_extras, - epilogue_code, - reduce_n_by_2, - indent=8, # epilogue/store are inside the persistent for-loop - persist_cache=_AUTOTUNE_PERSIST_TMA, - ) - _KERNEL_TMA_CACHE[key] = kernel - return kernel - - -def _record_best_config(kernel, epilogue_key: str, M_bucket: int, N: int, K: int, persist: dict, save_fn) -> None: - """Persist the winning autotune config to disk after it is chosen.""" - triton_key = (M_bucket, N, K) - cfg = kernel.cache.get(triton_key) - if cfg is None: - return - cache_key = f"{epilogue_key}|{M_bucket},{N},{K}" - persist[cache_key] = {"kwargs": dict(cfg.kwargs), "num_stages": cfg.num_stages, "num_warps": cfg.num_warps} - save_fn() - - -def matmul_custom_epilogue( - A: torch.Tensor, B: torch.Tensor, extras: list[torch.Tensor], epilogue_code: str, reduce_n_by_2: bool -) -> torch.Tensor: - M, K = A.shape - _, N = B.shape - M_bucket = _bucket_m(M) - - N_out = N // 2 if reduce_n_by_2 else N - - # Align the row stride to 128 bytes so a subsequent cuBLAS mm can read - # this buffer as its A operand without Inductor inserting a row-padding copy. - elem_size = A.element_size() - align_elems = 128 // elem_size - N_stride = (N_out + align_elems - 1) // align_elems * align_elems - D = torch.empty((M, N_stride), device=A.device, dtype=A.dtype)[:, :N_out] - - epilogue_key = str((len(extras), epilogue_code, reduce_n_by_2)) - triton_key = (M_bucket, N, K) - - use_tma = _TMA_AVAILABLE and A.is_contiguous() and B.is_contiguous() - - if use_tma: - # ── TMA persistent path (SM90+) ─────────────────────────────────────── - # Device-side descriptors + persistent CTA loop over NUM_SMS SMs. - # B is read as [K, N] row-major; no pre-transpose required. - _ensure_tma_allocator() - NUM_SMS = torch.cuda.get_device_properties(A.device).multi_processor_count - kernel = get_dynamic_kernel_tma(len(extras), epilogue_code, reduce_n_by_2) - needs_persist = triton_key not in kernel.cache - - grid = lambda meta: (min(NUM_SMS, triton.cdiv(M, meta["BLOCK_M"]) * triton.cdiv(N, meta["BLOCK_N"])),) - - args = [A, B, D] - args.extend(extras) - args.extend([M, N, K, M_bucket, D.stride(0), D.stride(1), NUM_SMS]) - - kernel[grid](*args) - - if needs_persist: - _record_best_config(kernel, epilogue_key, M_bucket, N, K, _AUTOTUNE_PERSIST_TMA, _save_autotune_cache_tma) - - else: - # ── Non-persistent pointer-arithmetic path (all CUDA GPUs) ─────────── - kernel = get_dynamic_kernel(len(extras), epilogue_code, reduce_n_by_2) - needs_persist = triton_key not in kernel.cache - - grid = lambda meta: (triton.cdiv(M, meta["BLOCK_M"]) * triton.cdiv(N, meta["BLOCK_N"]),) - - args = [A, B, D] - args.extend(extras) - args.extend([M, N, K, M_bucket, A.stride(0), A.stride(1), B.stride(0), B.stride(1), D.stride(0), D.stride(1)]) - - kernel[grid](*args) - - if needs_persist: - _record_best_config(kernel, epilogue_key, M_bucket, N, K, _AUTOTUNE_PERSIST, _save_autotune_cache) - - return D diff --git a/magi_compiler/passes/piecewise_graph/post_grad_pass_manager.py b/magi_compiler/passes/piecewise_graph/post_grad_pass_manager.py index 8e48203..d95e50b 100644 --- a/magi_compiler/passes/piecewise_graph/post_grad_pass_manager.py +++ b/magi_compiler/passes/piecewise_graph/post_grad_pass_manager.py @@ -22,10 +22,22 @@ from ...utils.envs import MAGI_PATTERN_MATCH_DEBUG from ..pass_base import InductorPass, get_pass_context from .fix_functionalization import FixFunctionalizationPass -from .fusion.matmul_epilogue_fusion import MatmulCustomEpilogueFusionPass +from .fusion.blackwell_geforce.matmul_epilogue_fusion import MatmulEvtEpilogueFusionPass from .post_cleanup import PostCleanupPass +def _device_capability_major() -> int: + """Return the CUDA major capability, or 0 when CUDA is unavailable.""" + try: + import torch as _torch + + if _torch.cuda.is_available(): + return _torch.cuda.get_device_capability()[0] + except Exception: + pass + return 0 + + def with_pattern_match_debug(fn): """ Function decorator that turns on inductor pattern match debug @@ -81,8 +93,9 @@ def __call__(self, graph: fx.Graph): def configure(self, pass_config: PassConfig): self.pass_config = pass_config - # TODO: Register custom passes here (fusion, noop elimination, sequence parallelism, async TP, Ulysses overlap). - self.add(MatmulCustomEpilogueFusionPass()) + # Matmul + epilogue fusion. On sm_120 (Blackwell consumer / RTX 5090) + if _device_capability_major() >= 12: + self.add(MatmulEvtEpilogueFusionPass()) # needs a functional graph self.post_cleanup = PostCleanupPass() diff --git a/tests/feature_tests/test_matmul_epilogue_fusion.py b/tests/feature_tests/test_matmul_epilogue_fusion.py index 15e7127..b6489bb 100644 --- a/tests/feature_tests/test_matmul_epilogue_fusion.py +++ b/tests/feature_tests/test_matmul_epilogue_fusion.py @@ -12,10 +12,25 @@ # See the License for the specific language governing permissions and # limitations under the License. +"""Tests for the CUTLASS Sm80EVT matmul-epilogue fusion path on RTX 5090. + +Three families of checks: + + 1. Positive numerical equivalence: every supported epilogue (the 7 athena + activations + binary ops + 1-D bias) must match eager within bf16 tol. + 2. Fusion-actually-fired: the emitted graph must contain a + ``magi_epilogue.matmul_custom_evt`` node — a green numerical test alone + would silently pass even if fusion was skipped (eager == "compiled"). + 3. Negative fallback: shapes / dtypes / chains the EVT pass does NOT + support must keep the original ``aten.mm`` and run through cuBLAS. + Catches over-eager fusion that would corrupt downstream consumers. +""" + from typing import Optional import pytest import torch +import torch.fx as fx import torch.nn as nn import torch.nn.functional as F @@ -24,28 +39,28 @@ pytestmark = pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +_SM120_ONLY = pytest.mark.skipif( + not torch.cuda.is_available() or torch.cuda.get_device_capability()[0] < 12, + reason="CUTLASS EVT path targets sm_120 (Blackwell consumer)", +) -# --------------------------------------------------------------------------- -# Activation functions -# --------------------------------------------------------------------------- + +# ── Activations from athena/performer_v16/activation.py (verbatim) ──────────── def high_precision_silu(x, out_dtype: Optional[torch.dtype] = None): out_dtype = x.dtype if out_dtype is None else out_dtype - x = x.to(torch.float32) - return F.silu(x).to(out_dtype) + return F.silu(x.to(torch.float32)).to(out_dtype) def high_precision_sigmoid(x, out_dtype: Optional[torch.dtype] = None): out_dtype = x.dtype if out_dtype is None else out_dtype - x = x.to(torch.float32) - return F.sigmoid(x).to(out_dtype) + return F.sigmoid(x.to(torch.float32)).to(out_dtype) def high_precision_gelu(x, out_dtype: Optional[torch.dtype] = None): out_dtype = x.dtype if out_dtype is None else out_dtype - x = x.to(torch.float32) - return F.gelu(x).to(out_dtype) + return F.gelu(x.to(torch.float32)).to(out_dtype) def swiglu7(x, alpha: float = 1.702, limit: float = 7.0, out_dtype: Optional[torch.dtype] = None): @@ -68,131 +83,461 @@ def gelu7(x, alpha: float = 1.702, limit: float = 7.0, out_dtype: Optional[torch def relu_square(x, out_dtype: Optional[torch.dtype] = None): out_dtype = x.dtype if out_dtype is None else out_dtype - x = x.to(torch.float32) - return torch.square(F.relu(x)).to(out_dtype) + return torch.square(F.relu(x.to(torch.float32))).to(out_dtype) + + +# ── Compile + fusion-side instrumentation ──────────────────────────────────── + + +class _FusionStats: + """Records what the EVT pass did to the graph during one ``magi_compile``. + + Captured by patching ``MatmulEvtEpilogueFusionPass.__call__`` for the scope + of a test. We track: + * mm_before — count of ``aten.mm`` nodes seen on entry + * mm_after — same after the pass + * fused_count — number of ``magi_epilogue.matmul_custom_evt`` nodes + inserted (i.e. how many mm sites the pass actually + replaced; ``mm_before - mm_after`` only matches when + fusion never aborts mid-walk). + * kinds — the ``kind`` arg of each emitted op, e.g. + ["evt_row", "swiglu7_dual"]. + + Tests assert against these to prove the pass made the right choice — a + purely numerical comparison against eager would silently pass even when + fusion was skipped (because both paths fall back to cuBLAS). + """ + + def __init__(self) -> None: + self.mm_before = 0 + self.mm_after = 0 + self.fused_count = 0 + self.kinds: list = [] + + +def _install_pass_instrument(): + """Returns (stats, restore_fn). Wraps the FX pass to record per-call deltas.""" + from magi_compiler.passes.piecewise_graph.fusion.blackwell_geforce import matmul_epilogue_fusion as P + + stats = _FusionStats() + original = P.MatmulEvtEpilogueFusionPass.__call__ + evt_op = torch.ops.magi_epilogue.matmul_custom_evt.default + mm_targets = (torch.ops.aten.mm.default, torch.ops.aten.mm) + + def _instrumented(self, graph: fx.Graph): + before = sum(1 for n in graph.nodes if n.op == "call_function" and n.target in mm_targets) + result = original(self, graph) + after = sum(1 for n in graph.nodes if n.op == "call_function" and n.target in mm_targets) + emitted_kinds = [] + for n in graph.nodes: + if n.op == "call_function" and n.target is evt_op: + # signature: (A, B, extras, ir_json, kind, n_out, out_dtype_id) + if len(n.args) >= 5: + emitted_kinds.append(n.args[4]) + stats.mm_before += before + stats.mm_after += after + stats.fused_count += len(emitted_kinds) + stats.kinds.extend(emitted_kinds) + return result + + P.MatmulEvtEpilogueFusionPass.__call__ = _instrumented + + def restore(): + P.MatmulEvtEpilogueFusionPass.__call__ = original + + return stats, restore + + +def _compile_and_check( + model: nn.Module, + inputs, + *, + atol: float = 0.5, + rtol: float = 0.0, + expect_fused: int = -1, + expect_kinds: Optional[list] = None, + dynamic_arg_dims=None, +): + """Compile ``model``, run it on ``inputs``, compare against eager. + + Parameters + ---------- + model, inputs + ``inputs`` is a tuple/list passed positionally to forward. + atol, rtol + Numerical tolerance: ``|actual - expected| <= atol + rtol*|expected|``. + expect_fused + Number of mm sites the pass MUST have replaced. Use 0 for negative + tests (fusion must NOT fire). -1 disables the check. + expect_kinds + If set, the multiset of emitted op ``kind`` args must equal this list. + E.g. ``["swiglu7_dual"]`` for the swiglu7 special-case path. + dynamic_arg_dims + Forwarded to magi_compile. Defaults to making the first arg's M + dynamic (matches our fusion guards). + """ + if dynamic_arg_dims is None: + # Use the model's forward signature to pick the first arg name. + import inspect + + params = list(inspect.signature(model.forward).parameters) + if not params: + dynamic_arg_dims = {} + else: + dynamic_arg_dims = {params[0]: 0} + + model = model.cuda() + # Use bfloat16 so the EVT pass actually fires (the pass requires bf16). + if any(p.dtype.is_floating_point for p in model.parameters()): + model = model.bfloat16() + # Disable gradients on parameters; otherwise magi_compile / aot_autograd + # produces a forward+backward joint graph and the mm node has an extra + # user (the saved tensor for backward), which the EVT escape detector + # correctly refuses to fuse. + for p in model.parameters(): + p.requires_grad_(False) + + with torch.no_grad(): + expected = model(*inputs) + + get_compile_config().disable_cache = True + stats, restore = _install_pass_instrument() + try: + compiled_model = magi_compile(model, dynamic_arg_dims=dynamic_arg_dims) + with torch.no_grad(): + actual = compiled_model(*inputs) + finally: + restore() + + # Numerical check. + abs_diff = (actual - expected).abs() + tol = atol + rtol * expected.abs() + max_violation = (abs_diff - tol).max().item() + assert max_violation <= 0, ( + f"Fused result outside tolerance: " + f"max(|diff| - tol) = {max_violation:.4f}, " + f"max |diff| = {abs_diff.max().item():.4f}, " + f"fusion stats: fused={stats.fused_count} kinds={stats.kinds}" + ) + + # Fusion-actually-fired check. + if expect_fused >= 0: + assert stats.fused_count == expect_fused, ( + f"Expected {expect_fused} fused mm sites, got {stats.fused_count}. " + f"mm_before={stats.mm_before} mm_after={stats.mm_after} " + f"emitted kinds={stats.kinds}" + ) + if expect_kinds is not None: + assert sorted(stats.kinds) == sorted(expect_kinds), ( + f"Expected emitted kinds {sorted(expect_kinds)}, " f"got {sorted(stats.kinds)}" + ) -# --------------------------------------------------------------------------- -# Model wrappers -# --------------------------------------------------------------------------- +# ───────────────────────────────────────────────────────────────────────────── +# Positive tests — every athena activation must fuse and stay numerically OK +# ───────────────────────────────────────────────────────────────────────────── -class SiluModel(nn.Module): - def forward(self, a, b): - return high_precision_silu(torch.mm(a, b), out_dtype=torch.bfloat16) +class _Bf16MmModel(nn.Module): + """All positive activation models share this skeleton: bf16 mm followed + by an epilogue fn that returns bf16. Weight is held in (N, K) row-major + form and accessed via ``permute([1, 0])`` to mirror the real GAGA2 graph.""" + def __init__(self, k: int, n: int, epilogue): + super().__init__() + self.weight = nn.Parameter(torch.randn(n, k)) + self._epi = epilogue -class SigmoidModel(nn.Module): - def forward(self, a, b): - return high_precision_sigmoid(torch.mm(a, b), out_dtype=torch.bfloat16) + def forward(self, a): + y = torch.mm(a, self.weight.permute(1, 0)) + return self._epi(y, out_dtype=torch.bfloat16) -class GeluModel(nn.Module): - def forward(self, a, b): - return high_precision_gelu(torch.mm(a, b), out_dtype=torch.bfloat16) +_M, _K, _N = 1024, 1024, 1024 -class Swiglu7Model(nn.Module): - def forward(self, a, b): - return swiglu7(torch.mm(a, b), out_dtype=torch.bfloat16) +def _input_a(): + return torch.randn(_M, _K, device="cuda", dtype=torch.bfloat16) -class Gelu7Model(nn.Module): - def forward(self, a, b): - return gelu7(torch.mm(a, b), out_dtype=torch.bfloat16) +@_SM120_ONLY +@pytest.mark.parametrize( + "epi_name,epi_fn,atol,rtol", + [ + ("silu", high_precision_silu, 0.5, 0.0), + ("sigmoid", high_precision_sigmoid, 0.5, 0.0), + ("gelu", high_precision_gelu, 0.5, 0.0), + ("gelu7", gelu7, 0.5, 0.0), + ("relu_square", relu_square, 0.0, 0.2), + ], +) +def test_evt_unary_activations_fuse(epi_name, epi_fn, atol, rtol): + """All unary activations must fuse to a single ``evt_col`` op.""" + model = _Bf16MmModel(_K, _N, epi_fn) + _compile_and_check(model, (_input_a(),), atol=atol, rtol=rtol, expect_fused=1, expect_kinds=["evt_col"]) -class ReluSquareModel(nn.Module): - def forward(self, a, b): - return relu_square(torch.mm(a, b), out_dtype=torch.bfloat16) +@_SM120_ONLY +def test_evt_relu_native(): + """Plain ``aten.relu`` (no fp32 cast) — exercises the built-in CUTLASS + ReLu functor mapping in the IR.""" + class M(nn.Module): + def __init__(self): + super().__init__() + self.weight = nn.Parameter(torch.randn(_N, _K)) -# --------------------------------------------------------------------------- -# Helper -# --------------------------------------------------------------------------- + def forward(self, a): + return torch.relu(torch.mm(a, self.weight.permute(1, 0))).to(torch.bfloat16) + _compile_and_check(M(), (_input_a(),), expect_fused=1, expect_kinds=["evt_col"]) -def _run_fusion_test(model: nn.Module, a: torch.Tensor, b: torch.Tensor, atol: float = 0.5, rtol: float = 0.0): - """Run a matmul-epilogue fusion test. - Checks that the fused result satisfies: |actual - expected| < atol + rtol * |expected| +@_SM120_ONLY +def test_evt_swiglu7_dispatches_to_dualgemm(): + """SwiGLU7 must take the dedicated DualGemm one-stage path, not generic EVT.""" + model = _Bf16MmModel(_K, _N, swiglu7) + _compile_and_check(model, (_input_a(),), atol=0.5, rtol=0.05, expect_fused=1, expect_kinds=["swiglu7_dual"]) - atol=0.5 covers the bf16 → fp32 accumulation difference for element-wise - activations whose output magnitude is O(1). For activations that amplify - magnitude (e.g. relu_square), pass a non-zero rtol instead. + +# ───────────────────────────────────────────────────────────────────────────── +# Binary-op positive tests — chains containing add/sub/mul/div on the mm output +# ───────────────────────────────────────────────────────────────────────────── + + +@_SM120_ONLY +def test_evt_mm_plus_scalar(): + """``mm + 0.5`` — scalar add absorbs into ``add_scalar`` IR node. + + Tolerance: eager runs the add in bf16 (lossy ulp at ±0.5); CUTLASS runs + the add in fp32 then casts. The ~1.0 absolute diff observed is bf16 + rounding noise on the eager side, not a CUTLASS bug. """ - model = model.cuda().bfloat16() - with torch.no_grad(): - expected = model(a, b) - get_compile_config().disable_cache = True - compiled_model = magi_compile(model, dynamic_arg_dims={"a": 0}) - with torch.no_grad(): - actual = compiled_model(a, b) + class M(nn.Module): + def __init__(self): + super().__init__() + self.weight = nn.Parameter(torch.randn(_N, _K)) - abs_diff = (actual - expected).abs() - tol = atol + rtol * expected.abs() - max_violation = (abs_diff - tol).max().item() - assert max_violation <= 0, ( - f"Fused result too far from reference: " - f"max(|diff| - tol) = {max_violation:.4f}, " - f"max |diff| = {abs_diff.max().item():.4f}" - ) + def forward(self, a): + return (torch.mm(a, self.weight.permute(1, 0)) + 0.5).to(torch.bfloat16) + + _compile_and_check(M(), (_input_a(),), atol=1.5, expect_fused=1, expect_kinds=["evt_col"]) + + +@_SM120_ONLY +def test_evt_mm_times_scalar(): + """``mm * 0.25`` — scalar mul (mul_scalar IR).""" + class M(nn.Module): + def __init__(self): + super().__init__() + self.weight = nn.Parameter(torch.randn(_N, _K)) -# --------------------------------------------------------------------------- -# Tests -# --------------------------------------------------------------------------- + def forward(self, a): + return (torch.mm(a, self.weight.permute(1, 0)) * 0.25).to(torch.bfloat16) + _compile_and_check(M(), (_input_a(),), expect_fused=1, expect_kinds=["evt_col"]) -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") -def test_matmul_epilogue_fusion_silu(): - M, K, N = 128, 256, 512 - a = torch.randn(M, K, device="cuda", dtype=torch.bfloat16) - b = torch.randn(K, N, device="cuda", dtype=torch.bfloat16) - _run_fusion_test(SiluModel(), a, b) +@_SM120_ONLY +def test_evt_mm_div_scalar_then_silu(): + """``silu(mm / 8)`` — scalar div + activation chain.""" -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") -def test_matmul_epilogue_fusion_sigmoid(): - M, K, N = 128, 256, 512 - a = torch.randn(M, K, device="cuda", dtype=torch.bfloat16) - b = torch.randn(K, N, device="cuda", dtype=torch.bfloat16) - _run_fusion_test(SigmoidModel(), a, b) + class M(nn.Module): + def __init__(self): + super().__init__() + self.weight = nn.Parameter(torch.randn(_N, _K)) + def forward(self, a): + y = torch.mm(a, self.weight.permute(1, 0)) / 8.0 + return high_precision_silu(y, out_dtype=torch.bfloat16) -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") -def test_matmul_epilogue_fusion_gelu(): - M, K, N = 128, 256, 512 - a = torch.randn(M, K, device="cuda", dtype=torch.bfloat16) - b = torch.randn(K, N, device="cuda", dtype=torch.bfloat16) - _run_fusion_test(GeluModel(), a, b) + _compile_and_check(M(), (_input_a(),), expect_fused=1, expect_kinds=["evt_col"]) -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") -def test_matmul_epilogue_fusion_swiglu7(): - M, K, N = 128, 256, 512 - a = torch.randn(M, K, device="cuda", dtype=torch.bfloat16) - b = torch.randn(K, N, device="cuda", dtype=torch.bfloat16) - _run_fusion_test(Swiglu7Model(), a, b) +@_SM120_ONLY +def test_evt_mm_minus_scalar_then_relu(): + """``relu(mm - 2.0)``.""" + class M(nn.Module): + def __init__(self): + super().__init__() + self.weight = nn.Parameter(torch.randn(_N, _K)) -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") -def test_matmul_epilogue_fusion_gelu7(): - M, K, N = 128, 256, 512 - a = torch.randn(M, K, device="cuda", dtype=torch.bfloat16) - b = torch.randn(K, N, device="cuda", dtype=torch.bfloat16) - _run_fusion_test(Gelu7Model(), a, b) + def forward(self, a): + y = torch.mm(a, self.weight.permute(1, 0)) - 2.0 + return torch.relu(y).to(torch.bfloat16) + _compile_and_check(M(), (_input_a(),), expect_fused=1, expect_kinds=["evt_col"]) + + +@_SM120_ONLY +def test_evt_mm_plus_1d_bias(): + """``silu(mm + bias_N)`` — 1-D bias as RowBroadcast extras.""" + + class M(nn.Module): + def __init__(self): + super().__init__() + self.weight = nn.Parameter(torch.randn(_N, _K)) + self.bias = nn.Parameter(torch.randn(_N)) + + def forward(self, a): + y = torch.mm(a, self.weight.permute(1, 0)) + self.bias + return high_precision_silu(y, out_dtype=torch.bfloat16) + + # atol=1.5: eager does the bias-add in bf16 (lossy), CUTLASS in fp32 — + # the ~1.0 abs diff is bf16 ulp noise on the eager side. + _compile_and_check(M(), (_input_a(),), atol=1.5, expect_fused=1, expect_kinds=["evt_col"]) + + +@_SM120_ONLY +def test_evt_mm_times_aux_load(): + """``(mm * gate_MxN)`` — full (M, N) auxiliary tensor multiply. + + The gate must be supplied as a regular forward arg (not a model parameter) + because magi_compile doesn't trace through Parameters of dynamic shape. + """ + + class M(nn.Module): + def __init__(self): + super().__init__() + self.weight = nn.Parameter(torch.randn(_N, _K)) + + def forward(self, a, gate): + y = torch.mm(a, self.weight.permute(1, 0)) * gate + return y.to(torch.bfloat16) + + a = _input_a() + gate = torch.randn(_M, _N, device="cuda", dtype=torch.bfloat16) + # rtol=0.1: ``mm * gate`` in eager is bf16 (lossy multiply); CUTLASS + # multiplies in fp32 then casts. Output magnitude scales like sqrt(K)*1*1 + # ≈ 32, so 5–10 % relative diff is expected purely from bf16 vs fp32. + _compile_and_check(M(), (a, gate), atol=0.0, rtol=0.1, expect_fused=1, expect_kinds=["evt_col"], dynamic_arg_dims={"a": 0}) + + +# ───────────────────────────────────────────────────────────────────────────── +# Negative tests — fusion must NOT fire and the chain must fall back to cuBLAS +# ───────────────────────────────────────────────────────────────────────────── + + +@_SM120_ONLY +def test_evt_no_fuse_intermediate_escapes(): + """Attention → residual → RMSNorm pattern: ``add(residual, mm)`` is + consumed both by ``square(...)`` (would-be-fused) AND by ``mul(_, rsqrt)`` + later. The pass MUST refuse — fusing would silently drop the value the + rest of RMSNorm needs.""" + + class M(nn.Module): + def __init__(self): + super().__init__() + self.weight = nn.Parameter(torch.randn(5120, _K)) + self.gamma = nn.Parameter(torch.randn(5120)) + + def forward(self, a, residual): + y = torch.mm(a, self.weight.permute(1, 0)).float() + x = residual + y + var = x.pow(2).mean(-1, keepdim=True) + rsqrt = torch.rsqrt(var + 1e-6) + return (x * rsqrt * (self.gamma + 1)).to(torch.bfloat16) + + a = _input_a() + residual = torch.randn(_M, 5120, device="cuda", dtype=torch.float32) + _compile_and_check(M(), (a, residual), atol=2.0, rtol=0.1, expect_fused=0) + + +@_SM120_ONLY +def test_evt_no_fuse_bare_mm(): + """A bare ``mm`` with no epilogue at all — Store(Accum) is trivial. + Replacing cuBLAS with a CUTLASS GEMM that does identical work is strictly + slower, so the pass must skip.""" + + class M(nn.Module): + def __init__(self): + super().__init__() + self.weight = nn.Parameter(torch.randn(_N, _K)) + + def forward(self, a): + return torch.mm(a, self.weight.permute(1, 0)) + + _compile_and_check(M(), (_input_a(),), atol=0.5, expect_fused=0) + + +@_SM120_ONLY +def test_evt_no_fuse_k_misaligned(): + """K not divisible by 8 fails the bf16 alignment guard — cuBLAS path.""" + + class M(nn.Module): + def __init__(self, k, n): + super().__init__() + self.weight = nn.Parameter(torch.randn(n, k)) + + def forward(self, a): + y = torch.mm(a, self.weight.permute(1, 0)) + return high_precision_silu(y, out_dtype=torch.bfloat16) + + K = 1023 # 1023 % 8 = 7 → should NOT fuse + N = 1024 + a = torch.randn(_M, K, device="cuda", dtype=torch.bfloat16) + _compile_and_check(M(K, N), (a,), expect_fused=0) + + +@_SM120_ONLY +def test_evt_no_fuse_fp32_mm(): + """fp32 mm — pass requires bf16 (or fp16); fp32 must skip.""" + + class M(nn.Module): + def __init__(self): + super().__init__() + self.weight = nn.Parameter(torch.randn(_N, _K)) + + def forward(self, a): + y = torch.mm(a, self.weight.permute(1, 0)) + return F.silu(y) + + a = torch.randn(_M, _K, device="cuda", dtype=torch.float32) + + model = M().cuda() # fp32 — do NOT bfloat16() the model + with torch.no_grad(): + expected = model(a) + + get_compile_config().disable_cache = True + stats, restore = _install_pass_instrument() + try: + compiled_model = magi_compile(model, dynamic_arg_dims={"a": 0}) + with torch.no_grad(): + actual = compiled_model(a) + finally: + restore() + + diff = (actual - expected).abs().max().item() + assert diff <= 1.0, f"fp32 mm result diverged: {diff}" + assert stats.fused_count == 0, ( + f"fp32 mm should NOT fuse, but pass emitted {stats.fused_count} ops " f"(kinds={stats.kinds})" + ) + + +# ───────────────────────────────────────────────────────────────────────────── +# IR / cache key invariants +# ───────────────────────────────────────────────────────────────────────────── + + +@_SM120_ONLY +def test_evt_ir_canonical_determinism(): + """Same IR built twice → identical canonical JSON. If this regresses, the + .cu module disk cache silently misses and recompiles every run.""" + from magi_compiler.passes.piecewise_graph.fusion.blackwell_geforce.evt_ir import ( + Accum, + Compute, + Store, + cache_key, + to_canonical_json, + ) -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") -def test_matmul_epilogue_fusion_relu_square(): - M, K, N = 128, 256, 512 - a = torch.randn(M, K, device="cuda", dtype=torch.bfloat16) - b = torch.randn(K, N, device="cuda", dtype=torch.bfloat16) - # relu_square amplifies values quadratically (output ~ x^2, up to ~256), - # so use relative tolerance instead of a fixed absolute bound. - _run_fusion_test(ReluSquareModel(), a, b, atol=0.0, rtol=0.2) + a = Store(Compute("silu", (Compute("add", (Accum(), Accum())),)), "bfloat16") + b = Store(Compute("silu", (Compute("add", (Accum(), Accum())),)), "bfloat16") + assert to_canonical_json(a) == to_canonical_json(b) + assert cache_key(a, "bfloat16", "bfloat16") == cache_key(b, "bfloat16", "bfloat16") if __name__ == "__main__": From 4e42fcf3107ba444e4308b638cd341b62424876d Mon Sep 17 00:00:00 2001 From: wtr Date: Wed, 29 Apr 2026 19:40:06 +0800 Subject: [PATCH 4/7] add cutlass install in Dockerfile & update --- Dockerfile | 52 +++ .../fusion/blackwell_geforce/evt_codegen.py | 5 +- .../fusion/blackwell_geforce/evt_runtime.py | 178 ++--------- .../matmul_epilogue_fusion.py | 8 +- .../test_matmul_epilogue_fusion.py | 302 ++++++++++++++++-- 5 files changed, 361 insertions(+), 184 deletions(-) diff --git a/Dockerfile b/Dockerfile index 476ad3f..e9ef25a 100644 --- a/Dockerfile +++ b/Dockerfile @@ -3,6 +3,21 @@ FROM nvcr.io/nvidia/pytorch:25.10-py3 ARG FLASH_ATTENTION_COMMIT_ID="b613d9e2c8475945baff3fd68f2030af1b890acf" +# CUTLASS — source is always cloned (the magi_compiler EVT-fusion path +# JIT-includes its headers and our /opt/cutlass tree is the readable +# reference checkout). The CMake-driven profiler/library is compiled +# *only* when the build host is an RTX 5090 (sm_120, Blackwell consumer); +# every other arch gets the source tree but no built artefacts. +# +# Override behaviour with a build arg: +# --build-arg CUTLASS_BUILD=yes force compile (e.g. on a build farm +# without a GPU but targeting sm_120) +# --build-arg CUTLASS_BUILD=no force skip even if 5090 detected +# --build-arg CUTLASS_BUILD=auto (default) compile iff nvidia-smi +# reports compute_cap == 12.x +ARG CUTLASS_COMMIT_ID="f74fea9ce35868d3ae9f8d1dce1969d7250d3f90" +ARG CUTLASS_BUILD="auto" + ENV PIP_NO_CACHE_DIR=1 \ PIP_DISABLE_PIP_VERSION_CHECK=1 \ PYTHONDONTWRITEBYTECODE=1 @@ -18,6 +33,7 @@ RUN --mount=type=secret,id=http_proxy,required=false \ ca-certificates \ git \ build-essential \ + cmake \ ninja-build && \ rm -rf /var/lib/apt/lists/* && \ apt-get clean @@ -42,6 +58,42 @@ RUN --mount=type=secret,id=http_proxy,required=false \ cp /tmp/flash-attention/hopper/flash_attn_interface.py ${python_path}/flash_attn_3/ && \ rm -rf /tmp/flash-attention + +RUN --mount=type=secret,id=http_proxy,required=false \ + --mount=type=secret,id=https_proxy,required=false \ + export http_proxy="$(cat /run/secrets/http_proxy 2>/dev/null || true)" && \ + export https_proxy="$(cat /run/secrets/https_proxy 2>/dev/null || true)" && \ + mkdir -p /opt/cutlass && \ + cd /opt/cutlass && \ + git init -q && \ + git remote add origin https://github.com/NVIDIA/cutlass.git && \ + git fetch origin ${CUTLASS_COMMIT_ID} --depth 1 && \ + git checkout ${CUTLASS_COMMIT_ID} && \ + (git submodule update --init --recursive --depth 1 --jobs 8 || \ + git submodule update --init --recursive --depth 1 --jobs 1) + + +RUN set -eu; \ + case "${CUTLASS_BUILD}" in \ + no) echo "[CUTLASS] CUTLASS_BUILD=no — skipping cmake configure."; exit 0 ;; \ + yes) DO_BUILD=1 ;; \ + auto) \ + if command -v nvidia-smi >/dev/null 2>&1 && \ + nvidia-smi --query-gpu=compute_cap --format=csv,noheader 2>/dev/null \ + | head -n1 | grep -Eq '^12\.'; then \ + echo "[CUTLASS] nvidia-smi reports sm_120 — running cmake configure."; \ + DO_BUILD=1; \ + else \ + echo "[CUTLASS] No sm_120 detected at build time — skipping cmake (headers still available)."; \ + exit 0; \ + fi ;; \ + *) echo "[CUTLASS] Unknown CUTLASS_BUILD=${CUTLASS_BUILD}"; exit 1 ;; \ + esac; \ + [ -n "${DO_BUILD:-}" ] && cd /opt/cutlass && \ + export CUDACXX="${CUDA_INSTALL_PATH:-${CUDA_HOME:-/usr/local/cuda}}/bin/nvcc" && \ + mkdir -p build && cd build && \ + cmake .. -DCUTLASS_NVCC_ARCHS=120a + RUN --mount=type=secret,id=http_proxy,required=false \ --mount=type=secret,id=https_proxy,required=false \ export http_proxy="$(cat /run/secrets/http_proxy 2>/dev/null || true)" && \ diff --git a/magi_compiler/passes/piecewise_graph/fusion/blackwell_geforce/evt_codegen.py b/magi_compiler/passes/piecewise_graph/fusion/blackwell_geforce/evt_codegen.py index af5bc82..72f7984 100644 --- a/magi_compiler/passes/piecewise_graph/fusion/blackwell_geforce/evt_codegen.py +++ b/magi_compiler/passes/piecewise_graph/fusion/blackwell_geforce/evt_codegen.py @@ -22,8 +22,9 @@ 4. Exposes ``evt_matmul_out`` via PYBIND11. We use CUTLASS 2.x ``Sm80EVT`` running backward-compat on sm_120; this matches -``/root/cutlass/examples/99_evt_demo/heavy_epi_torch_ext.cu`` which has been -verified to deliver +5..+12 % vs the Triton TMA path on RTX 5090 bf16. +``$MAGI_CUTLASS_ROOT/examples/99_evt_demo/heavy_epi_torch_ext.cu`` (default +``/opt/cutlass/...``) which has been verified to deliver +5..+12 % vs the +Triton TMA path on RTX 5090 bf16. """ from __future__ import annotations diff --git a/magi_compiler/passes/piecewise_graph/fusion/blackwell_geforce/evt_runtime.py b/magi_compiler/passes/piecewise_graph/fusion/blackwell_geforce/evt_runtime.py index 56fa681..41d034a 100644 --- a/magi_compiler/passes/piecewise_graph/fusion/blackwell_geforce/evt_runtime.py +++ b/magi_compiler/passes/piecewise_graph/fusion/blackwell_geforce/evt_runtime.py @@ -67,60 +67,6 @@ def out_dtype_from_id(i: int) -> torch.dtype: return _ID_TO_DTYPE[i] -# ── M-bucket dispatch ───────────────────────────────────────────────────────── -# Three coarse buckets matching the tile-candidate sets in -# ``evt_codegen._TILE_CANDIDATES_5090``: -# small — M ≤ 256 (decode / single-token) -# medium — 256 < M ≤ 2048 (mid-size prefill) -# large — M > 2048 (large prefill / batched) -# Each bucket compiles a distinct .cu module containing its own tile-candidate -# vector; the per-module C++ runner then autotunes the actual best (TileShape, -# WarpShape, NumStages) tuple at first call per (M, N, K) and caches the -# winning index inside the module — so the Python side only pays one extra -# cache key dimension. -_M_BUCKET_BOUNDARIES = (256, 2048) - - -def _m_bucket(M: int) -> str: - if M <= _M_BUCKET_BOUNDARIES[0]: - return "small" - if M <= _M_BUCKET_BOUNDARIES[1]: - return "medium" - return "large" - - -# ── Output row-stride helper ────────────────────────────────────────────────── -# CUTLASS Sm80EVT and the swiglu7 DualGemm both require D's row stride to be a -# multiple of AlignmentC * sizeof(ElementC) = 4 * sizeof(bf16) = 8 bytes (i.e. -# 4 elements for bf16/fp16, 2 elements for fp32). When n_out already meets this -# requirement we return a *contiguous* (M, n_out) tensor — avoids an extra D2D -# scratch copy on the hot path. Only when n_out fails the alignment do we fall -# back to padding the row stride. -# -# Earlier this padded everything to 128 bytes (matching the Triton path's -# convention) but on shapes like N_out=13652 the resulting non-contig D forced -# a kernel-into-scratch + scratch-into-D copy worth ~5% of the kernel runtime -# at (M=7697, N=27304, K=5120) — which fully accounted for the perf gap users -# saw between the standalone benchmark (no scratch) and the real model. -# -# Pre-computed alignment per dtype to avoid the ~2–5 μs cost of -# ``torch.empty([], dtype=dt).element_size()`` per op invocation. Hit count on -# this lookup is 2× per fused op (runtime impl + fake impl), so on a model with -# 100 fused-op calls per forward this shaves ~1 ms off the dispatch overhead. -_ALIGN_BY_DTYPE: dict = { - torch.bfloat16: 4, # 8 bytes / 2 = 4 elements - torch.float16: 4, - torch.float32: 2, # 8 bytes / 4 = 2 elements -} - - -def _aligned_n_stride(n_out: int, dt: torch.dtype) -> int: - align = _ALIGN_BY_DTYPE.get(dt) - if align is None: # rare: a dtype we haven't pre-tabulated - align = max(1, 8 // torch.empty([], dtype=dt).element_size()) - return (n_out + align - 1) // align * align - - # ── Compile cache + per-key build lock ──────────────────────────────────────── _MODULE_CACHE: dict = {} # cache_key (sha256 str) → loaded cpp_extension module # Hot-path fast cache — avoids ``json.dumps + sha256`` (~10–30 μs/call) when @@ -135,18 +81,16 @@ def _aligned_n_stride(n_out: int, dt: torch.dtype) -> int: # ── D output-buffer cache ──────────────────────────────────────────────────── -# Keyed by (M, n_out, n_stride, out_dtype, device_idx). Mirrors the same -# cache pattern in ``sm120_triton_kernel.py:_buf_cache`` — which has been -# shipping in this codebase for the Triton path. Reusing D across calls -# avoids the per-call ``torch.empty`` overhead (~5–15 μs of Python work + -# allocator metadata) and the (rare) scratch slice; on hot paths with -# millisecond-scale kernels this is a measurable but small win. +# Single-entry greedy cache, keyed by (M, n_out, dtype, device_idx). The hot +# path in ``_matmul_custom_evt_cuda`` reads/writes this dict directly (the +# resolver was inlined for ~1 μs/call savings), so this module only owns the +# storage and a disable switch. # -# Correctness contract — same as the Triton path: this is a single-stream -# inference cache. The previous call's D consumer must already have read it -# before the next call lands. Inductor-generated ``call(...)`` functions -# satisfy this because they execute serially on the default CUDA stream and -# the returned tensor is consumed before the next op-level dispatch. +# FX-pass guards (K % 8 == 0; generic N % 4 == 0; swiglu7 N % 8 == 0) ensure +# n_out is always a multiple of CUTLASS's AlignmentC = 4 elements, so D is +# always allocated as a true-contiguous ``torch.empty((M, n_out), dtype)`` — +# no padded stride / scratch buffer route exists. Anything that violates the +# guards is rejected upstream and falls back to torch.compile's default mm. # # To opt out (e.g. when bench-scripting with overlapping streams), set the # env var ``MAGI_EVT_DISABLE_D_CACHE=1``. @@ -154,39 +98,10 @@ def _aligned_n_stride(n_out: int, dt: torch.dtype) -> int: _D_CACHE_DISABLED: bool = os.environ.get("MAGI_EVT_DISABLE_D_CACHE", "0") not in ("0", "", "false", "False") -def _get_or_alloc_D(M: int, n_out: int, out_dtype: torch.dtype, device: torch.device) -> "torch.Tensor": - """Return a (possibly cached) (M, n_out) output buffer. - - The buffer is contiguous when ``n_stride == n_out`` (the fast path); when - ``n_out`` is mis-aligned we keep the padded ``[:, :n_out]`` slice so the - fake impl's stride matches at runtime. - """ - # Fast path: cache key first, recompute n_stride only on miss. The cache - # is keyed by (M, n_out, dtype, device_idx); two distinct (n_out, dtype) - # always have the same alignment, so we don't need n_stride in the key. - idx = device.index or 0 # index is None for default device → falsy → 0 - key = (M, n_out, out_dtype, idx) - cached = _D_BUF_CACHE.get(key) - if cached is not None and not _D_CACHE_DISABLED: - return cached - n_stride = _aligned_n_stride(n_out, out_dtype) - if n_stride == n_out: - D = torch.empty((M, n_out), device=device, dtype=out_dtype) - else: - D = torch.empty((M, n_stride), device=device, dtype=out_dtype)[:, :n_out] - if not _D_CACHE_DISABLED: - # Single-entry cache: evict everything else, then install the new one. - # We can't iterate-and-delete on the live dict (RuntimeError under any - # workload that puts >1 entry in the cache — e.g. CP=4 sees multiple - # per-rank shapes during warmup, while a single-card run often reuses - # one shape and never tripped the bug). - _D_BUF_CACHE.clear() - _D_BUF_CACHE[key] = D - return D - - def _cutlass_root() -> str: - return os.environ.get("MAGI_CUTLASS_ROOT", "/root/cutlass") + # Default install location is /opt/cutlass (Dockerfile clones the source + # tree there). Override with MAGI_CUTLASS_ROOT for ad-hoc dev checkouts. + return os.environ.get("MAGI_CUTLASS_ROOT", "/opt/cutlass") def _evt_build_dir(key: str) -> str: @@ -405,28 +320,6 @@ def _compile_swiglu7_dual(m_bucket: str, N: int, K: int): # ── torch.library backend impls ─────────────────────────────────────────────── -# Single-entry scratch cache for the rare mis-aligned-N path. Same greedy -# eviction policy as ``_D_BUF_CACHE`` — bounded memory across many shapes -# (e.g. CP=4 sees several per-rank M values during warmup; we don't want a -# scratch buffer for every one). -_SCRATCH_CACHE: dict = {} - - -def _get_or_alloc_scratch(M: int, n_out: int, out_dtype: torch.dtype, device: torch.device) -> "torch.Tensor": - if _D_CACHE_DISABLED: - return torch.empty((M, n_out), device=device, dtype=out_dtype) - idx = device.index or 0 - key = (M, n_out, out_dtype, idx) - cached = _SCRATCH_CACHE.get(key) - if cached is not None: - return cached - s = torch.empty((M, n_out), device=device, dtype=out_dtype) - # Greedy eviction: one shape at a time. - _SCRATCH_CACHE.clear() - _SCRATCH_CACHE[key] = s - return s - - # ── Dispatch fast-cache ────────────────────────────────────────────────────── # Hot-path bottleneck reduction: collapse the four-step # out_dtype_from_id → _m_bucket → _compile_* → mod.attr-lookup @@ -529,55 +422,36 @@ def _matmul_custom_evt_cuda(A, B, extras, ir_json, kind, n_out, out_dtype_id_): _DISPATCH_CACHE[fast_key] = entry # ── Step 2: alloc / fetch D (greedy single-entry cache, inlined) ── - # D matches the fake impl's shape. CUTLASS launchers require D contiguous; - # when n_out happens to be mis-aligned the row stride is padded and we - # route through a scratch buffer. + # FX pass guards (K % 8 == 0; generic N % 4 == 0; swiglu7 N % 8 == 0) + # ensure n_out is a multiple of CUTLASS AlignmentC = 4 for every dtype, + # so a plain ``torch.empty((M, n_out), dtype)`` is already CUTLASS- + # contiguous — no padded stride / scratch buffer route is required. + # Anything that violates the guards is rejected upstream and falls back + # to torch.compile's default mm. if _D_CACHE_DISABLED: - n_stride = _aligned_n_stride(n_out, out_dtype) - if n_stride == n_out: - D = torch.empty((M, n_out), device=A.device, dtype=out_dtype) - else: - D = torch.empty((M, n_stride), device=A.device, dtype=out_dtype)[:, :n_out] + D = torch.empty((M, n_out), device=A.device, dtype=out_dtype) else: dev_idx = A.device.index or 0 d_key = (M, n_out, out_dtype, dev_idx) D = _D_BUF_CACHE.get(d_key) if D is None: - n_stride = _aligned_n_stride(n_out, out_dtype) - if n_stride == n_out: - D = torch.empty((M, n_out), device=A.device, dtype=out_dtype) - else: - D = torch.empty((M, n_stride), device=A.device, dtype=out_dtype)[:, :n_out] + D = torch.empty((M, n_out), device=A.device, dtype=out_dtype) _D_BUF_CACHE.clear() _D_BUF_CACHE[d_key] = D # ── Step 3: dispatch — pre-bound callable, single C++ trampoline ── - # `D.stride(0) != n_out` is the only branch we take per call to decide - # whether we need the scratch route. Cheap C++ attribute compare. - needs_scratch = D.stride(0) != n_out kernel_call = entry.kernel_call - if entry.is_evt: - if needs_scratch: - scratch = _get_or_alloc_scratch(M, n_out, out_dtype, A.device) - kernel_call(A, B, extras, scratch) - D.copy_(scratch) - return D kernel_call(A, B, extras, D) - return D - - # swiglu7_dual: extras is always [] here (FX pass guarantees). - if needs_scratch: - scratch = _get_or_alloc_scratch(M, n_out, out_dtype, A.device) - kernel_call(A, B, scratch) - D.copy_(scratch) - return D - kernel_call(A, B, D) + else: + # swiglu7_dual: extras is always [] here (FX pass guarantees). + kernel_call(A, B, D) return D @torch.library.register_fake("magi_epilogue::matmul_custom_evt") def _matmul_custom_evt_fake(A, B, extras, ir_json, kind, n_out, out_dtype_id_): out_dtype = out_dtype_from_id(out_dtype_id_) - n_stride = _aligned_n_stride(n_out, out_dtype) - return A.new_empty_strided((A.shape[0], n_out), (n_stride, 1), dtype=out_dtype) + # Contiguous (M, n_out) — see _D_BUF_CACHE comment for why padding is + # never needed under the FX-pass alignment guards. + return A.new_empty((A.shape[0], n_out), dtype=out_dtype) diff --git a/magi_compiler/passes/piecewise_graph/fusion/blackwell_geforce/matmul_epilogue_fusion.py b/magi_compiler/passes/piecewise_graph/fusion/blackwell_geforce/matmul_epilogue_fusion.py index dd5dc99..d8e4af2 100644 --- a/magi_compiler/passes/piecewise_graph/fusion/blackwell_geforce/matmul_epilogue_fusion.py +++ b/magi_compiler/passes/piecewise_graph/fusion/blackwell_geforce/matmul_epilogue_fusion.py @@ -616,7 +616,13 @@ def _try_fuse_swiglu7(self, graph: fx.Graph, mm_node: fx.Node) -> bool: if w_shape is None or len(w_shape) != 2 or w_stride is None: return False N, K = w_shape - if not (_is_static_int(N) and N % 2 == 0): + # N % 8 ensures (a) the gate/linear interleaved split is valid (N + # even) AND (b) n_out = N // 2 satisfies CUTLASS AlignmentC = 4 + # for bf16. This lets the runtime allocate D as a true-contiguous + # (M, n_out) tensor with no padded stride / scratch path. Real + # GAGA2 has N=27304 (% 8 == 0). Smaller misaligned N falls back + # to torch.compile's default mm + python silu chain. + if not (_is_static_int(N) and N % 8 == 0): return False if w_stride != (K, 1): return False # not contiguous (N, K) — abort diff --git a/tests/feature_tests/test_matmul_epilogue_fusion.py b/tests/feature_tests/test_matmul_epilogue_fusion.py index b6489bb..f6d7cfd 100644 --- a/tests/feature_tests/test_matmul_epilogue_fusion.py +++ b/tests/feature_tests/test_matmul_epilogue_fusion.py @@ -113,6 +113,11 @@ def __init__(self) -> None: self.mm_after = 0 self.fused_count = 0 self.kinds: list = [] + # out_dtype_id of each emitted op (args[6]). Encoded as + # bf16 → 0, fp16 → 1, fp32 → 2 (see evt_runtime._OUT_DTYPE_ID). + # Tests assert against this to catch silent dtype regressions in the + # FX pass's last-node meta lookup or codegen's ElementC typedef. + self.out_dtype_ids: list = [] def _install_pass_instrument(): @@ -129,15 +134,19 @@ def _instrumented(self, graph: fx.Graph): result = original(self, graph) after = sum(1 for n in graph.nodes if n.op == "call_function" and n.target in mm_targets) emitted_kinds = [] + emitted_out_dtype_ids = [] for n in graph.nodes: if n.op == "call_function" and n.target is evt_op: # signature: (A, B, extras, ir_json, kind, n_out, out_dtype_id) if len(n.args) >= 5: emitted_kinds.append(n.args[4]) + if len(n.args) >= 7: + emitted_out_dtype_ids.append(n.args[6]) stats.mm_before += before stats.mm_after += after stats.fused_count += len(emitted_kinds) stats.kinds.extend(emitted_kinds) + stats.out_dtype_ids.extend(emitted_out_dtype_ids) return result P.MatmulEvtEpilogueFusionPass.__call__ = _instrumented @@ -156,7 +165,10 @@ def _compile_and_check( rtol: float = 0.0, expect_fused: int = -1, expect_kinds: Optional[list] = None, + expect_out_dtype: Optional[torch.dtype] = None, + expect_actual_dtype: Optional[torch.dtype] = None, dynamic_arg_dims=None, + cast_model_to_bf16: bool = True, ): """Compile ``model``, run it on ``inputs``, compare against eager. @@ -172,9 +184,23 @@ def _compile_and_check( expect_kinds If set, the multiset of emitted op ``kind`` args must equal this list. E.g. ``["swiglu7_dual"]`` for the swiglu7 special-case path. + expect_out_dtype + If set, every emitted op's ``out_dtype_id`` (args[6]) MUST decode to + this dtype. Catches silent regressions where the FX pass picks the + wrong terminal-node dtype, or where Inductor inserts an extra cast + that the IR walker wasn't expecting. + expect_actual_dtype + If set, the runtime result tensor MUST have this dtype. Independent + check from ``expect_out_dtype`` — they should agree but a mismatch + between them would mean the codegen's StoreD typedef diverged from + the op's declared out_dtype_id. dynamic_arg_dims Forwarded to magi_compile. Defaults to making the first arg's M dynamic (matches our fusion guards). + cast_model_to_bf16 + Default True (mirrors the standard test setup). Pass False when the + model already has the dtype mix you want (e.g. fp16-only or mixed + bf16 / fp16 weights). """ if dynamic_arg_dims is None: # Use the model's forward signature to pick the first arg name. @@ -187,8 +213,10 @@ def _compile_and_check( dynamic_arg_dims = {params[0]: 0} model = model.cuda() - # Use bfloat16 so the EVT pass actually fires (the pass requires bf16). - if any(p.dtype.is_floating_point for p in model.parameters()): + # Use bfloat16 by default so the EVT pass actually fires (the pass + # requires bf16/fp16). Skip the auto-cast for tests that explicitly + # set up a different dtype mix. + if cast_model_to_bf16 and any(p.dtype.is_floating_point for p in model.parameters()): model = model.bfloat16() # Disable gradients on parameters; otherwise magi_compile / aot_autograd # produces a forward+backward joint graph and the mm node has an extra @@ -231,6 +259,21 @@ def _compile_and_check( assert sorted(stats.kinds) == sorted(expect_kinds), ( f"Expected emitted kinds {sorted(expect_kinds)}, " f"got {sorted(stats.kinds)}" ) + if expect_out_dtype is not None: + from magi_compiler.passes.piecewise_graph.fusion.blackwell_geforce.evt_runtime import out_dtype_from_id + + assert stats.out_dtype_ids, ( + f"expect_out_dtype={expect_out_dtype} but no fusion fired " f"(out_dtype_ids list is empty)" + ) + decoded = [out_dtype_from_id(i) for i in stats.out_dtype_ids] + for got in decoded: + assert got == expect_out_dtype, ( + f"Emitted out_dtype mismatch: expected {expect_out_dtype}, " f"got {got} (full list: {decoded})" + ) + if expect_actual_dtype is not None: + assert actual.dtype == expect_actual_dtype, ( + f"Runtime result dtype mismatch: expected {expect_actual_dtype}, " f"got {actual.dtype}" + ) # ───────────────────────────────────────────────────────────────────────────── @@ -410,10 +453,9 @@ def forward(self, a, gate): a = _input_a() gate = torch.randn(_M, _N, device="cuda", dtype=torch.bfloat16) - # rtol=0.1: ``mm * gate`` in eager is bf16 (lossy multiply); CUTLASS - # multiplies in fp32 then casts. Output magnitude scales like sqrt(K)*1*1 - # ≈ 32, so 5–10 % relative diff is expected purely from bf16 vs fp32. - _compile_and_check(M(), (a, gate), atol=0.0, rtol=0.1, expect_fused=1, expect_kinds=["evt_col"], dynamic_arg_dims={"a": 0}) + _compile_and_check( + M(), (a, gate), atol=0.0, rtol=0.1, expect_fused=1, expect_kinds=["evt_col"], dynamic_arg_dims={"a": 0, "gate": 0} + ) # ───────────────────────────────────────────────────────────────────────────── @@ -443,7 +485,9 @@ def forward(self, a, residual): a = _input_a() residual = torch.randn(_M, 5120, device="cuda", dtype=torch.float32) - _compile_and_check(M(), (a, residual), atol=2.0, rtol=0.1, expect_fused=0) + # `residual + y` couples a's M to residual's M; mark both dynamic so + # Dynamo doesn't specialize a's declared dynamic dim → ConstraintViolation. + _compile_and_check(M(), (a, residual), atol=2.0, rtol=0.1, expect_fused=0, dynamic_arg_dims={"a": 0, "residual": 0}) @_SM120_ONLY @@ -483,38 +527,43 @@ def forward(self, a): @_SM120_ONLY -def test_evt_no_fuse_fp32_mm(): - """fp32 mm — pass requires bf16 (or fp16); fp32 must skip.""" +def test_evt_no_fuse_evt_n_misaligned(): + """N not divisible by 4 fails the generic-EVT N-alignment guard + (CUTLASS AlignmentC = 4) — must fall back to torch.compile / cuBLAS.""" class M(nn.Module): - def __init__(self): + def __init__(self, k, n): super().__init__() - self.weight = nn.Parameter(torch.randn(_N, _K)) + self.weight = nn.Parameter(torch.randn(n, k)) def forward(self, a): y = torch.mm(a, self.weight.permute(1, 0)) - return F.silu(y) + return high_precision_silu(y, out_dtype=torch.bfloat16) - a = torch.randn(_M, _K, device="cuda", dtype=torch.float32) + K = 1024 + N = 1026 # 1026 % 4 = 2 → should NOT fuse + a = torch.randn(_M, K, device="cuda", dtype=torch.bfloat16) + _compile_and_check(M(K, N), (a,), expect_fused=0) - model = M().cuda() # fp32 — do NOT bfloat16() the model - with torch.no_grad(): - expected = model(a) - get_compile_config().disable_cache = True - stats, restore = _install_pass_instrument() - try: - compiled_model = magi_compile(model, dynamic_arg_dims={"a": 0}) - with torch.no_grad(): - actual = compiled_model(a) - finally: - restore() +@_SM120_ONLY +def test_evt_no_fuse_swiglu7_n_not_mult_of_8(): + """swiglu7 needs N % 8 == 0 so that n_out = N // 2 is 4-aligned for + bf16 (CUTLASS AlignmentC = 4). N = 12 (% 8 != 0) must fall back.""" - diff = (actual - expected).abs().max().item() - assert diff <= 1.0, f"fp32 mm result diverged: {diff}" - assert stats.fused_count == 0, ( - f"fp32 mm should NOT fuse, but pass emitted {stats.fused_count} ops " f"(kinds={stats.kinds})" - ) + class M(nn.Module): + def __init__(self, k, n): + super().__init__() + self.weight = nn.Parameter(torch.randn(n, k)) + + def forward(self, a): + y = torch.mm(a, self.weight.permute(1, 0)) + return swiglu7(y, out_dtype=torch.bfloat16) + + K = 1024 + N = 12 # 12 % 2 == 0 (split OK) but 12 % 8 != 0 → NOT fused + a = torch.randn(_M, K, device="cuda", dtype=torch.bfloat16) + _compile_and_check(M(K, N), (a,), expect_fused=0) # ───────────────────────────────────────────────────────────────────────────── @@ -540,5 +589,200 @@ def test_evt_ir_canonical_determinism(): assert cache_key(a, "bfloat16", "bfloat16") == cache_key(b, "bfloat16", "bfloat16") +# ───────────────────────────────────────────────────────────────────────────── +# out_dtype correctness — verify the EVT pass picks the right Store dtype + +# the codegen's ElementC matches + the runtime returns a tensor of that dtype. +# +# Matrix: +# input dtype | epilogue compute | output dtype | expected out_dtype_id +# ───────────────────────────────────────────────────────────────────── +# bf16 | bf16 | bf16 | 0 (default) +# bf16 | fp32 | bf16 | 0 (high_precision_silu) +# bf16 | fp32 | fp32 | 2 (no final cast) +# bf16 | bf16 | fp16 | 1 (cross-precision) +# fp16 | fp16 | fp16 | 1 (fp16-only path) +# fp32 input | — | — | not fused (negative) +# ───────────────────────────────────────────────────────────────────────────── + + +@_SM120_ONLY +def test_evt_out_dtype_bf16_native(): + """bf16 mm → bf16 silu → bf16 output (no fp32 promotion). Pure-bf16 chain. + out_dtype_id MUST be 0 (bf16) and the runtime tensor MUST be bf16.""" + + class M(nn.Module): + def __init__(self): + super().__init__() + self.weight = nn.Parameter(torch.randn(_N, _K)) + + def forward(self, a): + return F.silu(torch.mm(a, self.weight.permute(1, 0))) # bf16 → bf16 + + _compile_and_check( + M(), + (_input_a(),), + expect_fused=1, + expect_kinds=["evt_col"], + expect_out_dtype=torch.bfloat16, + expect_actual_dtype=torch.bfloat16, + ) + + +@_SM120_ONLY +def test_evt_out_dtype_bf16_via_high_precision(): + """The athena ``high_precision_silu`` pattern: bf16 → cast(fp32) → silu → + cast(bf16). The IR walker absorbs both casts; final output is bf16 even + though the compute went through fp32 internally. + + This is the most common athena pattern — a regression here means the + inner-cast handling broke and out_dtype is silently wrong.""" + + class M(nn.Module): + def __init__(self): + super().__init__() + self.weight = nn.Parameter(torch.randn(_N, _K)) + + def forward(self, a): + y = torch.mm(a, self.weight.permute(1, 0)) + return high_precision_silu(y, out_dtype=torch.bfloat16) + + _compile_and_check( + M(), + (_input_a(),), + expect_fused=1, + expect_kinds=["evt_col"], + expect_out_dtype=torch.bfloat16, + expect_actual_dtype=torch.bfloat16, + ) + + +@_SM120_ONLY +def test_evt_out_dtype_fp32_no_final_cast(): + """bf16 mm → fp32 cast → silu → keep fp32 (no final cast back). + + out_dtype_id MUST be 2 (fp32). Exercises codegen's ``ElementC = float`` + path + the runtime D allocator with fp32 row-stride alignment (4 elements + = 16 bytes — different vector size than bf16's 8 bytes). + """ + + class M(nn.Module): + def __init__(self): + super().__init__() + self.weight = nn.Parameter(torch.randn(_N, _K)) + + def forward(self, a): + y = torch.mm(a, self.weight.permute(1, 0)).float() + return F.silu(y) # stays fp32 + + _compile_and_check( + M(), + (_input_a(),), + expect_fused=1, + expect_kinds=["evt_col"], + expect_out_dtype=torch.float32, + expect_actual_dtype=torch.float32, + ) + + +@_SM120_ONLY +def test_evt_out_dtype_bf16_to_fp16(): + """bf16 mm → silu → cast(fp16). Cross-precision: bf16 inputs but fp16 + output. out_dtype_id MUST be 1 (fp16). Exercises the codegen's + ``ElementA = bfloat16_t`` + ``ElementC = half_t`` mixed instantiation.""" + + class M(nn.Module): + def __init__(self): + super().__init__() + self.weight = nn.Parameter(torch.randn(_N, _K)) + + def forward(self, a): + return F.silu(torch.mm(a, self.weight.permute(1, 0))).half() + + _compile_and_check( + M(), + (_input_a(),), + atol=0.5, + expect_fused=1, + expect_kinds=["evt_col"], + expect_out_dtype=torch.float16, + expect_actual_dtype=torch.float16, + ) + + +@_SM120_ONLY +def test_evt_out_dtype_fp16_native(): + """fp16 mm + fp16 silu → fp16 output. Pure-fp16 path — exercises the + pass's bf16/fp16 branch in the input-dtype check, plus the codegen's + ``cutlass::half_t`` ElementA/B/C path end-to-end.""" + + class M(nn.Module): + def __init__(self): + super().__init__() + self.weight = nn.Parameter(torch.randn(_N, _K)) + + def forward(self, a): + return F.silu(torch.mm(a, self.weight.permute(1, 0))) # fp16 → fp16 + + a = torch.randn(_M, _K, device="cuda", dtype=torch.float16) + # Cast model to fp16 (not bf16) so all parameters match A's dtype. + model = M().cuda().half() + for p in model.parameters(): + p.requires_grad_(False) + + with torch.no_grad(): + expected = model(a) + + get_compile_config().disable_cache = True + stats, restore = _install_pass_instrument() + try: + compiled = magi_compile(model, dynamic_arg_dims={"a": 0}) + with torch.no_grad(): + actual = compiled(a) + finally: + restore() + + diff = (actual.float() - expected.float()).abs().max().item() + assert diff <= 0.5, f"fp16 silu max|diff|={diff}" + assert stats.fused_count == 1, f"fp16 path should fuse but got fused_count={stats.fused_count}" + assert stats.kinds == ["evt_col"], stats.kinds + assert stats.out_dtype_ids == [1], f"Expected out_dtype_id=[1] (fp16), got {stats.out_dtype_ids}" + assert actual.dtype == torch.float16, actual.dtype + + +@_SM120_ONLY +def test_evt_no_fuse_fp32_mm(): + """fp32 mm — pass requires bf16 (or fp16); fp32 must skip.""" + + class M(nn.Module): + def __init__(self): + super().__init__() + self.weight = nn.Parameter(torch.randn(_N, _K)) + + def forward(self, a): + y = torch.mm(a, self.weight.permute(1, 0)) + return F.silu(y) + + a = torch.randn(_M, _K, device="cuda", dtype=torch.float32) + + model = M().cuda() # fp32 — do NOT bfloat16() the model + with torch.no_grad(): + expected = model(a) + + get_compile_config().disable_cache = True + stats, restore = _install_pass_instrument() + try: + compiled_model = magi_compile(model, dynamic_arg_dims={"a": 0}) + with torch.no_grad(): + actual = compiled_model(a) + finally: + restore() + + diff = (actual - expected).abs().max().item() + assert diff <= 1.0, f"fp32 mm result diverged: {diff}" + assert stats.fused_count == 0, ( + f"fp32 mm should NOT fuse, but pass emitted {stats.fused_count} ops " f"(kinds={stats.kinds})" + ) + + if __name__ == "__main__": pytest.main([__file__, "-v"]) From 0cfa8208cf898b440b9edbb9292f71a7f3b92e40 Mon Sep 17 00:00:00 2001 From: wtr Date: Wed, 29 Apr 2026 20:02:13 +0800 Subject: [PATCH 5/7] add enable_mm_epilogue_fusion & chore --- magi_compiler/config.py | 10 +++++ magi_compiler/cuda/device.py | 44 +++++++++++++++++++ magi_compiler/magi_backend/magi_backend.py | 2 +- .../matmul_epilogue_fusion.py | 7 +-- .../piecewise_graph/post_grad_pass_manager.py | 19 +++----- 5 files changed, 63 insertions(+), 19 deletions(-) create mode 100644 magi_compiler/cuda/device.py diff --git a/magi_compiler/config.py b/magi_compiler/config.py index c5edf38..7eb6468 100644 --- a/magi_compiler/config.py +++ b/magi_compiler/config.py @@ -64,6 +64,16 @@ class PassConfig(BaseModel): # TODO: Add sequence parallelism pass and async TP pass. # TODO: Add Ulysses overlap pass. enable_sage_attn: bool = Field(False, description="Whether to replace flash attention with sage attention.") + enable_mm_epilogue_fusion: bool = Field( + True, + description=( + "Whether to enable the matmul + elementwise epilogue fusion pass. " + "On RTX 5090 (sm_120) this lowers fused chains to a CUTLASS Sm80EVT " + "kernel via the blackwell_geforce.MatmulEvtEpilogueFusionPass. The " + "pass is a no-op on older architectures regardless of this flag, " + "but the flag still controls whether it is registered at all." + ), + ) @property def hash(self) -> str: diff --git a/magi_compiler/cuda/device.py b/magi_compiler/cuda/device.py new file mode 100644 index 0000000..ebcd246 --- /dev/null +++ b/magi_compiler/cuda/device.py @@ -0,0 +1,44 @@ +# Copyright (c) 2026 SandAI. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""GPU device introspection helpers. + +Centralised so that pass-manager / FX passes / runtime modules don't all +re-implement the same try/except dance around ``torch.cuda``. +""" + +from typing import Tuple + + +def device_capability(device: int = 0) -> Tuple[int, int]: + """Return ``(major, minor)`` for the given CUDA device. + + Falls back to ``(0, 0)`` when CUDA is unavailable / not initialised / + raises any error during introspection — callers compare against a + minimum cap so a zero pair always means "feature unsupported", which + is the safe behaviour on CPU-only hosts and during static analysis. + """ + try: + import torch as _torch + + if _torch.cuda.is_available(): + return _torch.cuda.get_device_capability(device) + except Exception: + pass + return (0, 0) + + +def device_capability_major(device: int = 0) -> int: + """Convenience wrapper: just the major-capability int (0 if no CUDA).""" + return device_capability(device)[0] diff --git a/magi_compiler/magi_backend/magi_backend.py b/magi_compiler/magi_backend/magi_backend.py index 7bafdf5..43a54c6 100644 --- a/magi_compiler/magi_backend/magi_backend.py +++ b/magi_compiler/magi_backend/magi_backend.py @@ -591,7 +591,7 @@ def _split_graph(self, graph: fx.GraphModule) -> tuple[fx.GraphModule, list[Spli # Step 5: visualize the split graph if envs.MAGI_ENABLE_FX_GRAPH_VIZ: - # save_fx_graph_visualization(split_gm.graph, sub_dir="after_split", filename="split_gm_root") + save_fx_graph_visualization(split_gm.graph, sub_dir="after_split", filename="split_gm_root") for item in piecewise_graphs: save_fx_graph_visualization(item.graph.graph, sub_dir="after_split", filename=item.submod_name) diff --git a/magi_compiler/passes/piecewise_graph/fusion/blackwell_geforce/matmul_epilogue_fusion.py b/magi_compiler/passes/piecewise_graph/fusion/blackwell_geforce/matmul_epilogue_fusion.py index d8e4af2..e88b386 100644 --- a/magi_compiler/passes/piecewise_graph/fusion/blackwell_geforce/matmul_epilogue_fusion.py +++ b/magi_compiler/passes/piecewise_graph/fusion/blackwell_geforce/matmul_epilogue_fusion.py @@ -37,6 +37,7 @@ import torch import torch.fx as fx +from magi_compiler.cuda.device import device_capability_major from magi_compiler.passes.pass_base import MagiInductorPass from . import evt_runtime # ensures torch.library op + fake impl are registered @@ -189,11 +190,7 @@ class MatmulEvtEpilogueFusionPass(MagiInductorPass): def __init__(self, allow_extras: bool = True) -> None: # On non-sm120 we degrade to a no-op; the manager wires us only on # sm120 anyway, but defending against misuse is cheap. - try: - cap = torch.cuda.get_device_capability() if torch.cuda.is_available() else (0, 0) - except Exception: - cap = (0, 0) - self._enabled = cap[0] >= 12 + self._enabled = device_capability_major() >= 12 self.allow_extras = allow_extras def __call__(self, graph: fx.Graph) -> bool: diff --git a/magi_compiler/passes/piecewise_graph/post_grad_pass_manager.py b/magi_compiler/passes/piecewise_graph/post_grad_pass_manager.py index d95e50b..2672cef 100644 --- a/magi_compiler/passes/piecewise_graph/post_grad_pass_manager.py +++ b/magi_compiler/passes/piecewise_graph/post_grad_pass_manager.py @@ -18,6 +18,7 @@ from torch._inductor.custom_graph_pass import CustomGraphPass from ...config import PassConfig +from ...cuda.device import device_capability_major from ...utils import magi_logger, set_env_var from ...utils.envs import MAGI_PATTERN_MATCH_DEBUG from ..pass_base import InductorPass, get_pass_context @@ -26,18 +27,6 @@ from .post_cleanup import PostCleanupPass -def _device_capability_major() -> int: - """Return the CUDA major capability, or 0 when CUDA is unavailable.""" - try: - import torch as _torch - - if _torch.cuda.is_available(): - return _torch.cuda.get_device_capability()[0] - except Exception: - pass - return 0 - - def with_pattern_match_debug(fn): """ Function decorator that turns on inductor pattern match debug @@ -94,7 +83,11 @@ def configure(self, pass_config: PassConfig): self.pass_config = pass_config # Matmul + epilogue fusion. On sm_120 (Blackwell consumer / RTX 5090) - if _device_capability_major() >= 12: + # we lower fused chains to a CUTLASS Sm80EVT kernel. Toggled via + # PassConfig.enable_mm_epilogue_fusion (default True). The device + # check is independent — even with the flag on, non-sm_120 hosts + # don't register the pass since its FX walker would just no-op. + if pass_config.enable_mm_epilogue_fusion and device_capability_major() >= 12: self.add(MatmulEvtEpilogueFusionPass()) # needs a functional graph From f62bd8cd3cea7090559fec41f1efa4c1106c2d13 Mon Sep 17 00:00:00 2001 From: wtr Date: Wed, 29 Apr 2026 20:04:30 +0800 Subject: [PATCH 6/7] chore --- magi_compiler/magi_backend/magi_backend.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/magi_compiler/magi_backend/magi_backend.py b/magi_compiler/magi_backend/magi_backend.py index 43a54c6..0d010e3 100644 --- a/magi_compiler/magi_backend/magi_backend.py +++ b/magi_compiler/magi_backend/magi_backend.py @@ -605,9 +605,6 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> MagiSerializableFun self._init_cache() - # if envs.MAGI_ENABLE_FX_GRAPH_VIZ: - # save_fx_graph_visualization(graph, sub_dir="before_split", filename="gm_root") - self.full_graph_pass_manager(graph) split_gm, piecewise_graphs = self._split_graph(graph) From 36f7fbf9ac2445640e31b9abcd7c2c771751072e Mon Sep 17 00:00:00 2001 From: wtr Date: Thu, 30 Apr 2026 11:52:29 +0800 Subject: [PATCH 7/7] update .github/codestyle/copyright.hook --- .github/codestyle/copyright.hook | 2 +- .pre-commit-config.yaml | 2 +- .../cutlass_kernels/swiglu7_combine.h | 15 +++++++++++++-- .../cutlass_kernels/swiglu7_epi_one_stage.cu | 15 +++++++++++++-- 4 files changed, 28 insertions(+), 6 deletions(-) diff --git a/.github/codestyle/copyright.hook b/.github/codestyle/copyright.hook index 484ada0..3479940 100644 --- a/.github/codestyle/copyright.hook +++ b/.github/codestyle/copyright.hook @@ -43,7 +43,7 @@ def _get_comment_mark(path): if lang_type.search(path) is not None: return "#" - lang_type=re.compile(r"\.(h|c|hpp|cc|cpp|cu|go|cuh|proto)$") + lang_type=re.compile(r"\.(h|c|hpp|hxx|cc|cpp|cxx|cu|go|cuh|proto)$") if lang_type.search(path) is not None: return "//" diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 2c16f79..a460928 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -6,7 +6,7 @@ repos: name: copyright_checker entry: python3 ./.github/codestyle/copyright.hook language: system - files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|proto|py|sh)$ + files: \.(c|cc|cxx|cpp|cu|cuh|h|hpp|hxx|proto|py|sh)$ - repo: https://github.com/pre-commit/pre-commit-hooks rev: v4.4.0 hooks: diff --git a/magi_compiler/passes/piecewise_graph/fusion/blackwell_geforce/cutlass_kernels/swiglu7_combine.h b/magi_compiler/passes/piecewise_graph/fusion/blackwell_geforce/cutlass_kernels/swiglu7_combine.h index 631a490..220549f 100644 --- a/magi_compiler/passes/piecewise_graph/fusion/blackwell_geforce/cutlass_kernels/swiglu7_combine.h +++ b/magi_compiler/passes/piecewise_graph/fusion/blackwell_geforce/cutlass_kernels/swiglu7_combine.h @@ -1,6 +1,17 @@ -// Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -// SPDX-License-Identifier: BSD-3-Clause +// Copyright (c) 2026 SandAI. All Rights Reserved. // +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + // Binary epilogue combine functor for the swiglu7 DualGemm fusion. // // D = silu_alpha( clamp(lhs, max=limit) ) * ( clamp(rhs, -limit, limit) + 1 ) diff --git a/magi_compiler/passes/piecewise_graph/fusion/blackwell_geforce/cutlass_kernels/swiglu7_epi_one_stage.cu b/magi_compiler/passes/piecewise_graph/fusion/blackwell_geforce/cutlass_kernels/swiglu7_epi_one_stage.cu index 3be0203..4000654 100644 --- a/magi_compiler/passes/piecewise_graph/fusion/blackwell_geforce/cutlass_kernels/swiglu7_epi_one_stage.cu +++ b/magi_compiler/passes/piecewise_graph/fusion/blackwell_geforce/cutlass_kernels/swiglu7_epi_one_stage.cu @@ -1,6 +1,17 @@ -// Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -// SPDX-License-Identifier: BSD-3-Clause +// Copyright (c) 2026 SandAI. All Rights Reserved. // +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + // Single-kernel fully-fused swiglu7: // // D = swiglu7(A @ B.T)