diff --git a/.github/codestyle/copyright.hook b/.github/codestyle/copyright.hook index 484ada0..3479940 100644 --- a/.github/codestyle/copyright.hook +++ b/.github/codestyle/copyright.hook @@ -43,7 +43,7 @@ def _get_comment_mark(path): if lang_type.search(path) is not None: return "#" - lang_type=re.compile(r"\.(h|c|hpp|cc|cpp|cu|go|cuh|proto)$") + lang_type=re.compile(r"\.(h|c|hpp|hxx|cc|cpp|cxx|cu|go|cuh|proto)$") if lang_type.search(path) is not None: return "//" diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 2c16f79..a460928 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -6,7 +6,7 @@ repos: name: copyright_checker entry: python3 ./.github/codestyle/copyright.hook language: system - files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|proto|py|sh)$ + files: \.(c|cc|cxx|cpp|cu|cuh|h|hpp|hxx|proto|py|sh)$ - repo: https://github.com/pre-commit/pre-commit-hooks rev: v4.4.0 hooks: diff --git a/Dockerfile b/Dockerfile index 476ad3f..e9ef25a 100644 --- a/Dockerfile +++ b/Dockerfile @@ -3,6 +3,21 @@ FROM nvcr.io/nvidia/pytorch:25.10-py3 ARG FLASH_ATTENTION_COMMIT_ID="b613d9e2c8475945baff3fd68f2030af1b890acf" +# CUTLASS — source is always cloned (the magi_compiler EVT-fusion path +# JIT-includes its headers and our /opt/cutlass tree is the readable +# reference checkout). The CMake-driven profiler/library is compiled +# *only* when the build host is an RTX 5090 (sm_120, Blackwell consumer); +# every other arch gets the source tree but no built artefacts. +# +# Override behaviour with a build arg: +# --build-arg CUTLASS_BUILD=yes force compile (e.g. on a build farm +# without a GPU but targeting sm_120) +# --build-arg CUTLASS_BUILD=no force skip even if 5090 detected +# --build-arg CUTLASS_BUILD=auto (default) compile iff nvidia-smi +# reports compute_cap == 12.x +ARG CUTLASS_COMMIT_ID="f74fea9ce35868d3ae9f8d1dce1969d7250d3f90" +ARG CUTLASS_BUILD="auto" + ENV PIP_NO_CACHE_DIR=1 \ PIP_DISABLE_PIP_VERSION_CHECK=1 \ PYTHONDONTWRITEBYTECODE=1 @@ -18,6 +33,7 @@ RUN --mount=type=secret,id=http_proxy,required=false \ ca-certificates \ git \ build-essential \ + cmake \ ninja-build && \ rm -rf /var/lib/apt/lists/* && \ apt-get clean @@ -42,6 +58,42 @@ RUN --mount=type=secret,id=http_proxy,required=false \ cp /tmp/flash-attention/hopper/flash_attn_interface.py ${python_path}/flash_attn_3/ && \ rm -rf /tmp/flash-attention + +RUN --mount=type=secret,id=http_proxy,required=false \ + --mount=type=secret,id=https_proxy,required=false \ + export http_proxy="$(cat /run/secrets/http_proxy 2>/dev/null || true)" && \ + export https_proxy="$(cat /run/secrets/https_proxy 2>/dev/null || true)" && \ + mkdir -p /opt/cutlass && \ + cd /opt/cutlass && \ + git init -q && \ + git remote add origin https://github.com/NVIDIA/cutlass.git && \ + git fetch origin ${CUTLASS_COMMIT_ID} --depth 1 && \ + git checkout ${CUTLASS_COMMIT_ID} && \ + (git submodule update --init --recursive --depth 1 --jobs 8 || \ + git submodule update --init --recursive --depth 1 --jobs 1) + + +RUN set -eu; \ + case "${CUTLASS_BUILD}" in \ + no) echo "[CUTLASS] CUTLASS_BUILD=no — skipping cmake configure."; exit 0 ;; \ + yes) DO_BUILD=1 ;; \ + auto) \ + if command -v nvidia-smi >/dev/null 2>&1 && \ + nvidia-smi --query-gpu=compute_cap --format=csv,noheader 2>/dev/null \ + | head -n1 | grep -Eq '^12\.'; then \ + echo "[CUTLASS] nvidia-smi reports sm_120 — running cmake configure."; \ + DO_BUILD=1; \ + else \ + echo "[CUTLASS] No sm_120 detected at build time — skipping cmake (headers still available)."; \ + exit 0; \ + fi ;; \ + *) echo "[CUTLASS] Unknown CUTLASS_BUILD=${CUTLASS_BUILD}"; exit 1 ;; \ + esac; \ + [ -n "${DO_BUILD:-}" ] && cd /opt/cutlass && \ + export CUDACXX="${CUDA_INSTALL_PATH:-${CUDA_HOME:-/usr/local/cuda}}/bin/nvcc" && \ + mkdir -p build && cd build && \ + cmake .. -DCUTLASS_NVCC_ARCHS=120a + RUN --mount=type=secret,id=http_proxy,required=false \ --mount=type=secret,id=https_proxy,required=false \ export http_proxy="$(cat /run/secrets/http_proxy 2>/dev/null || true)" && \ diff --git a/magi_compiler/config.py b/magi_compiler/config.py index c5edf38..7eb6468 100644 --- a/magi_compiler/config.py +++ b/magi_compiler/config.py @@ -64,6 +64,16 @@ class PassConfig(BaseModel): # TODO: Add sequence parallelism pass and async TP pass. # TODO: Add Ulysses overlap pass. enable_sage_attn: bool = Field(False, description="Whether to replace flash attention with sage attention.") + enable_mm_epilogue_fusion: bool = Field( + True, + description=( + "Whether to enable the matmul + elementwise epilogue fusion pass. " + "On RTX 5090 (sm_120) this lowers fused chains to a CUTLASS Sm80EVT " + "kernel via the blackwell_geforce.MatmulEvtEpilogueFusionPass. The " + "pass is a no-op on older architectures regardless of this flag, " + "but the flag still controls whether it is registered at all." + ), + ) @property def hash(self) -> str: diff --git a/magi_compiler/cuda/device.py b/magi_compiler/cuda/device.py new file mode 100644 index 0000000..ebcd246 --- /dev/null +++ b/magi_compiler/cuda/device.py @@ -0,0 +1,44 @@ +# Copyright (c) 2026 SandAI. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""GPU device introspection helpers. + +Centralised so that pass-manager / FX passes / runtime modules don't all +re-implement the same try/except dance around ``torch.cuda``. +""" + +from typing import Tuple + + +def device_capability(device: int = 0) -> Tuple[int, int]: + """Return ``(major, minor)`` for the given CUDA device. + + Falls back to ``(0, 0)`` when CUDA is unavailable / not initialised / + raises any error during introspection — callers compare against a + minimum cap so a zero pair always means "feature unsupported", which + is the safe behaviour on CPU-only hosts and during static analysis. + """ + try: + import torch as _torch + + if _torch.cuda.is_available(): + return _torch.cuda.get_device_capability(device) + except Exception: + pass + return (0, 0) + + +def device_capability_major(device: int = 0) -> int: + """Convenience wrapper: just the major-capability int (0 if no CUDA).""" + return device_capability(device)[0] diff --git a/magi_compiler/passes/full_graph/full_graph_pass_mgr.py b/magi_compiler/passes/full_graph/full_graph_pass_mgr.py index 502d190..0626350 100644 --- a/magi_compiler/passes/full_graph/full_graph_pass_mgr.py +++ b/magi_compiler/passes/full_graph/full_graph_pass_mgr.py @@ -16,6 +16,7 @@ from ...magi_depyf.timeline import observe_lifecycle from .remove_item import RemoveItemPass +from .remove_useless_ops import RemoveUselessOpsPass from .replace_sage_atten import ReplaceSageAttentionPass @@ -30,6 +31,7 @@ def __init__(self, pass_config): if self.pass_config.enable_sage_attn: self.passes.append(ReplaceSageAttentionPass()) self.passes.append(RemoveItemPass()) + self.passes.append(RemoveUselessOpsPass()) @observe_lifecycle("full_graph_manager") def __call__(self, gm: torch.fx.GraphModule): diff --git a/magi_compiler/passes/full_graph/remove_useless_ops.py b/magi_compiler/passes/full_graph/remove_useless_ops.py new file mode 100644 index 0000000..a31acc5 --- /dev/null +++ b/magi_compiler/passes/full_graph/remove_useless_ops.py @@ -0,0 +1,117 @@ +# Copyright (c) 2026 SandAI. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch._inductor.fx_passes.pre_grad + +from ...magi_depyf.timeline import emit_pass_lifecycle +from ..pass_base import MagiInductorPass + + +class RemoveUselessOpsPass(MagiInductorPass): + """ + Remove useless convert, view, reshape operations. + When their input already has the target type and shape, these operations are redundant. + """ + + TARGET_METHODS = { + "view", + "reshape", + "to", + "type", + "contiguous", + "clone", + "flatten", + "permute", + "transpose", + "t", + "unsqueeze", + "squeeze", + "expand", + "repeat", + "bfloat16", + "float", + "half", + "int", + "long", + "short", + "double", + "bool", + "byte", + } + + @staticmethod + def _get_tensor_info(node: torch.fx.Node): + # Get tensor info from example_value + if "example_value" in node.meta: + val = node.meta["example_value"] + if isinstance(val, torch.Tensor): + return val.shape, val.dtype, val.stride() + elif isinstance(val, (list, tuple)) and len(val) > 0 and isinstance(val[0], torch.Tensor): + return val[0].shape, val[0].dtype, val[0].stride() + + return None, None, None + + def is_applicable(self, graph: torch.fx.Graph, shape: int | None = None) -> bool: + for node in graph.nodes: + if node.op == "call_method" and node.target in self.TARGET_METHODS: + return True + return False + + @emit_pass_lifecycle + def __call__(self, graph: torch.fx.Graph): + nodes_to_remove = [] + + for node in graph.nodes: + is_target_method = node.op == "call_method" and node.target in self.TARGET_METHODS + if not is_target_method: + continue + + # Need at least one argument (the input tensor) + if not node.args or not isinstance(node.args[0], torch.fx.Node): + continue + + input_node = node.args[0] + + node_shape, node_dtype, node_stride = self._get_tensor_info(node) + input_shape, input_dtype, input_stride = self._get_tensor_info(input_node) + if node_shape is None or input_shape is None: + continue + if node_dtype is None or input_dtype is None: + continue + # Some ops or metadata might not have stride properly captured, + # but if they do, we should require them to match to be totally safe against contiguous-forcing ops. + if node_stride is not None and input_stride is not None and node_stride != input_stride: + continue + + # Check if shape and dtype match exactly + if node_shape == input_shape and node_dtype == input_dtype: + # For _to_copy, ensure we are not changing memory format or device or other properties implicitly, + # but typically in full graph if shape and dtype match, and it's on the same device, it's safe. + # Let's also check device just in case if it's available. + def get_device(n): + if "example_value" in n.meta and isinstance(n.meta["example_value"], torch.Tensor): + return n.meta["example_value"].device + + node_device = get_device(node) + input_device = get_device(input_node) + if node_device is not None and input_device is not None and node_device != input_device: + continue + + # Replace uses + node.replace_all_uses_with(input_node) + nodes_to_remove.append(node) + + for node in nodes_to_remove: + graph.erase_node(node) diff --git a/magi_compiler/passes/piecewise_graph/fusion/__init__.py b/magi_compiler/passes/piecewise_graph/fusion/__init__.py new file mode 100644 index 0000000..3eaa44a --- /dev/null +++ b/magi_compiler/passes/piecewise_graph/fusion/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2026 SandAI. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/magi_compiler/passes/piecewise_graph/fusion/blackwell_geforce/__init__.py b/magi_compiler/passes/piecewise_graph/fusion/blackwell_geforce/__init__.py new file mode 100644 index 0000000..3eaa44a --- /dev/null +++ b/magi_compiler/passes/piecewise_graph/fusion/blackwell_geforce/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2026 SandAI. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/magi_compiler/passes/piecewise_graph/fusion/blackwell_geforce/cutlass_kernels/swiglu7_combine.h b/magi_compiler/passes/piecewise_graph/fusion/blackwell_geforce/cutlass_kernels/swiglu7_combine.h new file mode 100644 index 0000000..220549f --- /dev/null +++ b/magi_compiler/passes/piecewise_graph/fusion/blackwell_geforce/cutlass_kernels/swiglu7_combine.h @@ -0,0 +1,141 @@ +// Copyright (c) 2026 SandAI. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Binary epilogue combine functor for the swiglu7 DualGemm fusion. +// +// D = silu_alpha( clamp(lhs, max=limit) ) * ( clamp(rhs, -limit, limit) + 1 ) +// +// silu_alpha(x) = x * sigmoid(alpha * x) alpha = 1.702, limit = 7.0 +// +// `lhs` is the gate-path output fragment (Op0 applied to A @ W_gate.T), +// `rhs` is the linear-path output fragment (Op1 applied to A @ W_linear.T). +// Both arrive as ElementOutput (bf16) fragments — this is dictated by the +// dual-epilogue call site (examples/45_dual_gemm/threadblock/dual_epilogue.h:413 +// passes `output_frag_ptr[0][i]` and `[1][i]`, which are post-conversion +// output-type fragments, not raw accumulator fragments). The combine upcasts +// to ElementCompute (fp32) internally, evaluates the swiglu7 expression, and +// converts back to bf16. +// +// Note on precision: the gate/linear matmuls accumulate in fp32 inside the +// MMAs. Op0/Op1 (LinearCombination, ScaleType::Nothing) downcast those fp32 +// accumulators to bf16 before this combine runs. The swiglu7 math itself +// stays in fp32 here, so the only extra precision loss vs the two-stage EVT +// version is the single fp32→bf16 round-trip on each accumulator at the +// epilogue boundary. Empirically this is well within the bf16 noise floor. +// +// Modelled on cutlass/examples/45_dual_gemm/thread/left_silu_and_mul.h — +// same interface contract: ElementOutput / ElementAccumulator / ElementCompute +// typedefs, kCount fragment width, empty Params, two operator() overloads +// (fragment + scalar), is_source_needed() returning true. + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" +#include "cutlass/array.h" +#include "cutlass/functional.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/epilogue/thread/activation.h" + +namespace cutlass { +namespace epilogue { +namespace thread { + +template < + typename ElementOutput_, + int Count, + typename ElementAccumulator_ = ElementOutput_, + typename ElementCompute_ = ElementOutput_, + FloatRoundStyle Round = FloatRoundStyle::round_to_nearest> +class Swiglu7Combine { +public: + + using ElementOutput = ElementOutput_; + using ElementAccumulator = ElementAccumulator_; + using ElementCompute = ElementCompute_; + + static int const kCount = Count; + using FragmentOutput = Array; + using FragmentAccumulator = Array; + using ComputeFragment = Array; + + static FloatRoundStyle const kRound = Round; + + struct Params {}; + +public: + + CUTLASS_HOST_DEVICE + Swiglu7Combine(Params const& /*params*/) {} + + CUTLASS_HOST_DEVICE + bool is_source_needed() const { return true; } + + CUTLASS_HOST_DEVICE + void set_k_partition(int /*k_partition*/, int /*k_partition_count*/) { + // swiglu7 cannot be split-K-reduced (non-linear epilogue). + assert(false); + } + + // Fragment-level. lhs = gate output fragment (bf16, post Op0), + // rhs = linear output fragment (bf16, post Op1). + CUTLASS_HOST_DEVICE + FragmentOutput operator()(FragmentOutput const& lhs, + FragmentOutput const& rhs) const { + NumericArrayConverter in2c; + NumericArrayConverter c2o; + + ComputeFragment gate = in2c(lhs); + ComputeFragment lin = in2c(rhs); + ComputeFragment out; + + Sigmoid sig; + ElementCompute const limit(7.0f); + ElementCompute const nlimit(-7.0f); + ElementCompute const alpha(1.702f); + ElementCompute const one(1.0f); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kCount; ++i) { + ElementCompute g = gate[i] < limit ? gate[i] : limit; + ElementCompute r = lin[i] < nlimit ? nlimit + : (lin[i] > limit ? limit : lin[i]); + ElementCompute silu_g = g * sig(alpha * g); + out[i] = silu_g * (r + one); + } + return c2o(out); + } + + // Scalar overload — required by the DualGemm epilogue boilerplate. + CUTLASS_HOST_DEVICE + ElementOutput operator()(ElementOutput const& lhs, + ElementOutput const& rhs) const { + ElementCompute g(lhs), r(rhs); + ElementCompute const limit(7.0f); + ElementCompute const nlimit(-7.0f); + ElementCompute const alpha(1.702f); + ElementCompute const one(1.0f); + + Sigmoid sig; + + g = g < limit ? g : limit; + r = r < nlimit ? nlimit : (r > limit ? limit : r); + ElementCompute silu_g = g * sig(alpha * g); + return ElementOutput(silu_g * (r + one)); + } +}; + +} // namespace thread +} // namespace epilogue +} // namespace cutlass diff --git a/magi_compiler/passes/piecewise_graph/fusion/blackwell_geforce/cutlass_kernels/swiglu7_epi_one_stage.cu b/magi_compiler/passes/piecewise_graph/fusion/blackwell_geforce/cutlass_kernels/swiglu7_epi_one_stage.cu new file mode 100644 index 0000000..4000654 --- /dev/null +++ b/magi_compiler/passes/piecewise_graph/fusion/blackwell_geforce/cutlass_kernels/swiglu7_epi_one_stage.cu @@ -0,0 +1,382 @@ +// Copyright (c) 2026 SandAI. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Single-kernel fully-fused swiglu7: +// +// D = swiglu7(A @ B.T) +// +// A : (M, K) bf16 row-major +// B : (N, K) bf16 row-major (torch.nn.Linear weight convention; N even) +// D : (M, N/2) bf16 row-major +// +// Implementation uses cutlass::gemm::device::DualGemm — the two GEMMs +// A @ W_gate.T and A @ W_linear.T run in the same threadblock sharing A's +// smem stages; their accumulators stay in registers and a custom +// Swiglu7Combine epilogue functor combines them and writes only D. +// +// AUTOTUNE: at first call per (M, N, K) tuple the runner times every +// registered (TileShape, WarpShape, Stages) candidate and caches the +// fastest one. The candidate set is hand-tuned for RTX 5090 (sm_120) +// — see register_candidates() for the rationale and SMEM budget math. + +#include +#include + +#include +#include + +#include +#include +#include + +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/epilogue/thread/scale_type.h" +#include "cutlass/util/host_tensor.h" + +#include "45_dual_gemm/device/dual_gemm.h" +#include "swiglu7_combine.h" + +//////////////////////////////////////////////////////////////////////////////// +// Data types +//////////////////////////////////////////////////////////////////////////////// + +using ElementA = cutlass::bfloat16_t; +using ElementB = cutlass::bfloat16_t; +using ElementC = cutlass::bfloat16_t; +using ElementAcc = float; +using ElementCompute = float; + +using LayoutA = cutlass::layout::RowMajor; +using LayoutB0 = cutlass::layout::ColumnMajor; // strided ldB = 2K view +using LayoutB1 = cutlass::layout::ColumnMajor; // strided ldB = 2K view +using LayoutC = cutlass::layout::RowMajor; + +constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; // = 8 +constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; // = 8 +// Output vector width = 4 (bf16, 8 bytes) so any N_out divisible by 4 is OK +// — N=27304 → N_out=13652 is 4-aligned but not 8-aligned. +constexpr int EpilogueVecCount = 4; + +using ArchTag = cutlass::arch::Sm80; +using OperatorClass = cutlass::arch::OpClassTensorOp; +using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; + +constexpr auto kScaleType = cutlass::epilogue::thread::ScaleType::Nothing; +constexpr bool kSplitKSerial = false; +constexpr bool kStoreD0 = false; +constexpr bool kStoreD1 = false; + +//////////////////////////////////////////////////////////////////////////////// +// Per-tile DualGemm wrapper. The DualGemm device type is templated on +// (TileShape, WarpShape, Stages) — every autotune candidate instantiates the +// full kernel for its tuple. Compile time grows linearly with candidate count +// but DualGemm Sm80 is much cheaper to compile than the EVT path (no visitor +// tree), so we can afford 8–10 candidates. +//////////////////////////////////////////////////////////////////////////////// + +template +struct DualGemmConfig { + using EpilogueOp0 = cutlass::epilogue::thread::LinearCombination< + ElementC, EpilogueVecCount, ElementAcc, ElementCompute, kScaleType>; + using EpilogueOp1 = cutlass::epilogue::thread::LinearCombination< + ElementC, EpilogueVecCount, ElementAcc, ElementCompute, kScaleType>; + using EpilogueOp2 = cutlass::epilogue::thread::Swiglu7Combine< + ElementC, EpilogueVecCount, ElementAcc, ElementCompute>; + + using Gemm = cutlass::gemm::device::DualGemm< + ElementA, LayoutA, + ElementB, LayoutB0, LayoutB1, + ElementC, LayoutC, + ElementAcc, + OperatorClass, ArchTag, + TbShape, WaShape, InstructionShape, + EpilogueOp0, EpilogueOp1, EpilogueOp2, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + Stages, + kStoreD0, kStoreD1, kSplitKSerial, + AlignmentA, AlignmentB>; +}; + +//////////////////////////////////////////////////////////////////////////////// +// Type-erased runner concept; one instance per autotune candidate. +//////////////////////////////////////////////////////////////////////////////// + +struct Sw7Args { + int M; // activations rows + int N_out; // = N/2 (output cols) + int K; + void* ptr_A; + void* ptr_B; // (N, K) row-major weight; gate/linear interleaved + void* ptr_D; // (M, N_out) +}; + +class Sw7Concept { + public: + virtual ~Sw7Concept() = default; + virtual size_t get_workspace_size(const Sw7Args&) = 0; + virtual cutlass::Status initialize(const Sw7Args&, void* ws, cudaStream_t) = 0; + virtual cutlass::Status run(cudaStream_t stream) = 0; + virtual const char* name() const = 0; +}; + +template +class Sw7Impl : public Sw7Concept { + public: + using GemmType = typename Cfg::Gemm; + using EpilogueOp0 = typename Cfg::EpilogueOp0; + using EpilogueOp1 = typename Cfg::EpilogueOp1; + using EpilogueOp2 = typename Cfg::EpilogueOp2; + + explicit Sw7Impl(const char* name) : name_(name) {} + + typename GemmType::Arguments make_args(const Sw7Args& a) { + auto ptrA = reinterpret_cast(a.ptr_A); + auto ptrB = reinterpret_cast(a.ptr_B); + auto ptrD = reinterpret_cast(a.ptr_D); + int const M = a.M, N_out = a.N_out, K = a.K; + + int64_t const ldB_strided = static_cast(2) * K; + LayoutB0 layoutB_gate(ldB_strided); + LayoutB1 layoutB_linear(ldB_strided); + LayoutC layoutC(static_cast(N_out)); + + using TensorRefA = cutlass::TensorRef; + using TensorRefB0 = cutlass::TensorRef; + using TensorRefB1 = cutlass::TensorRef; + using TensorRefCi = cutlass::TensorRef; + using TensorRefDo = cutlass::TensorRef; + + TensorRefA ref_A0(ptrA, LayoutA(static_cast(K))); + TensorRefB0 ref_B0(ptrB, layoutB_gate); + TensorRefCi ref_C0(nullptr, LayoutC(0)); + TensorRefDo ref_D0(nullptr, LayoutC(0)); + TensorRefB1 ref_B1(ptrB + K, layoutB_linear); + TensorRefCi ref_C1(nullptr, LayoutC(0)); + TensorRefDo ref_D1(nullptr, LayoutC(0)); + TensorRefDo ref_D2(ptrD, layoutC); + + typename EpilogueOp0::Params epi0{ElementCompute(1.0f), ElementCompute(0.0f)}; + typename EpilogueOp1::Params epi1{ElementCompute(1.0f), ElementCompute(0.0f)}; + typename EpilogueOp2::Params epi2{}; + + cutlass::gemm::GemmCoord problem{M, N_out, K}; + typename GemmType::Arguments args( + cutlass::gemm::DualGemmMode::kGemm, + problem, + ref_A0, + ref_B0, ref_C0, ref_D0, + ref_B1, ref_C1, ref_D1, + ref_D2, + epi0, epi1, epi2, + /*split_k_slices=*/1, + /*batch_count=*/1, + /*batch_stride_A=*/0, + /*batch_stride_B0=*/0, + /*batch_stride_B1=*/0, + /*batch_stride_C=*/0, + /*batch_stride_D=*/0); + return args; + } + + size_t get_workspace_size(const Sw7Args& a) override { + return GemmType::get_workspace_size(make_args(a)); + } + cutlass::Status initialize(const Sw7Args& a, void* ws, cudaStream_t s) override { + return gemm_.initialize(make_args(a), ws, s); + } + cutlass::Status run(cudaStream_t stream) override { + return gemm_.run(stream); + } + const char* name() const override { return name_; } + + private: + GemmType gemm_; + const char* name_; +}; + +//////////////////////////////////////////////////////////////////////////////// +// AutoTune runner — first call per (M, N_out, K) shape times all candidates. +//////////////////////////////////////////////////////////////////////////////// + +#define SW7_TILE(tb_m, tb_n, tb_k, wa_m, wa_n, wa_k, stages, label) \ + configs_.push_back(std::make_unique< \ + Sw7Impl, \ + cutlass::gemm::GemmShape, \ + stages>>>(label)) + +class Sw7AutoTuneRunner { + public: + Sw7AutoTuneRunner() { + // Tile candidates for RTX 5090 (sm_120, 100 KB SMEM/SM, 170 SMs). + // + // SMEM cost for DualGemm = (BM + 2*BN) * BK * 2B * stages because both + // B operands live in smem simultaneously. Budget cap ~96 KB. + // + // Bucket of M doesn't drive a separate .cu here — DualGemm compiles + // fast enough that one runner with all candidates handles every M, and + // the per-shape cache picks the best for whatever M it sees. + + // ── Small / decode-friendly tiles ──────────────────────────────────────── + SW7_TILE(64, 64, 32, 32, 32, 32, 4, "T<64,64,32>_S4"); // 36 KB + SW7_TILE(64, 64, 64, 32, 32, 64, 3, "T<64,64,64>_S3"); // 72 KB + SW7_TILE(64, 128, 32, 32, 64, 32, 3, "T<64,128,32>_S3"); // 60 KB + SW7_TILE(64, 128, 32, 32, 64, 32, 4, "T<64,128,32>_S4"); // 80 KB + + // ── Medium tiles (CUTLASS bf16 reference defaults) ────────────────────── + SW7_TILE(128, 64, 32, 64, 32, 32, 3, "T<128,64,32>_S3"); // 48 KB (original default) + SW7_TILE(128, 64, 32, 64, 32, 32, 4, "T<128,64,32>_S4"); // 64 KB + SW7_TILE(128, 64, 64, 64, 32, 64, 3, "T<128,64,64>_S3"); // 96 KB + SW7_TILE(128, 128, 32, 64, 64, 32, 3, "T<128,128,32>_S3"); // 72 KB + SW7_TILE(128, 128, 32, 64, 64, 32, 4, "T<128,128,32>_S4"); // 96 KB + + // ── Large prefill tiles ───────────────────────────────────────────────── + SW7_TILE(256, 64, 32, 64, 32, 32, 3, "T<256,64,32>_S3"); // 72 KB + // (256, 128, 32) needs stages>=3 (DualGemm requires multistage). With + // stages=3 SMEM = (256 + 256) * 32 * 2 * 3 = 96 KB — exactly at budget, + // tends to fail with SMEM allocation errors at runtime. Omitted. + + // (128, 256, 32)*3 = 120 KB > 96 — omitted. + // (64, 256, 32)*3 = 108 KB > 96 — omitted. + } + + void operator()(at::Tensor A, at::Tensor B, at::Tensor D) { + TORCH_CHECK(A.is_cuda() && B.is_cuda() && D.is_cuda(), + "all inputs must be CUDA tensors"); + TORCH_CHECK(A.scalar_type() == at::kBFloat16 && B.scalar_type() == at::kBFloat16 + && D.scalar_type() == at::kBFloat16, + "all inputs must be bf16"); + TORCH_CHECK(A.dim() == 2 && B.dim() == 2 && D.dim() == 2, "A, B, D must be 2D"); + TORCH_CHECK(A.size(1) == B.size(1), "K mismatch (A.size(1) vs B.size(1))"); + TORCH_CHECK(A.is_contiguous() && B.is_contiguous() && D.is_contiguous(), + "A, B, D must be contiguous"); + + int const M = static_cast(A.size(0)); + int const K = static_cast(A.size(1)); + int const N = static_cast(B.size(0)); + TORCH_CHECK((N % 2) == 0, "N must be even, got ", N); + int const N_out = N / 2; + TORCH_CHECK(D.size(0) == M && D.size(1) == N_out, + "D must be (M, N/2) = (", M, ",", N_out, ")"); + + Sw7Args ea; + ea.M = M; ea.N_out = N_out; ea.K = K; + ea.ptr_A = A.data_ptr(); + ea.ptr_B = B.data_ptr(); + ea.ptr_D = D.data_ptr(); + + cudaStream_t stream = at::cuda::getCurrentCUDAStream(A.device().index()).stream(); + + // Single autotune per module. The .cu is compiled per (M-bucket, N, K) + // on the Python side — every distinct weight (N, K) gets its own .cu, + // so this runner instance hosts exactly one (N, K) and one bucket. The + // first call autotunes; all subsequent calls (any M in the bucket) + // reuse `best_idx_`. + if (best_idx_ < 0) { + best_idx_ = autotune(ea, stream); + } + int idx = best_idx_; + + auto& gemm = configs_[idx]; + size_t ws_sz = gemm->get_workspace_size(ea); + if (!ws_.defined() || ws_.numel() < (int64_t)ws_sz) { + ws_ = at::empty({(int64_t)ws_sz + 1}, + at::TensorOptions().dtype(at::kByte).device(A.device())); + } + auto st = gemm->initialize(ea, ws_sz > 0 ? ws_.data_ptr() : nullptr, stream); + TORCH_CHECK(st == cutlass::Status::kSuccess, + "DualGemm init failed (", gemm->name(), "): ", + cutlassGetStatusString(st)); + st = gemm->run(stream); + TORCH_CHECK(st == cutlass::Status::kSuccess, + "DualGemm run failed (", gemm->name(), "): ", + cutlassGetStatusString(st)); + } + + int num_configs() const { return (int)configs_.size(); } + + private: + int autotune(const Sw7Args& ea, cudaStream_t stream) { + int best_idx = -1; + float best_time = 1e30f; + cudaEvent_t s, e; + cudaEventCreate(&s); cudaEventCreate(&e); + + for (size_t i = 0; i < configs_.size(); ++i) { + auto& g = configs_[i]; + size_t ws_sz = 0; + try { ws_sz = g->get_workspace_size(ea); } + catch (...) { continue; } + if (!ws_.defined() || ws_.numel() < (int64_t)ws_sz) { + ws_ = at::empty({(int64_t)ws_sz + 1}, + at::TensorOptions().dtype(at::kByte).device(at::kCUDA)); + } + void* ws_ptr = ws_sz > 0 ? ws_.data_ptr() : nullptr; + if (g->initialize(ea, ws_ptr, stream) != cutlass::Status::kSuccess) { + continue; + } + + // Warmup — 10 iters so the L2 / instruction cache settle. With only + // 3 warmups (the original count) the first timed iter sees a cold L2 + // and inflates the average, sometimes flipping the best-config choice. + for (int w = 0; w < 10; ++w) g->run(stream); + cudaStreamSynchronize(stream); + + // Time — 50 iters keeps timing noise to <1% so 2–3 % perf gaps + // between candidates are distinguishable. + cudaEventRecord(s, stream); + int iters = 50; + for (int p = 0; p < iters; ++p) g->run(stream); + cudaEventRecord(e, stream); + cudaEventSynchronize(e); + float ms = 0; + cudaEventElapsedTime(&ms, s, e); + float avg = ms / iters; + if (avg < best_time) { best_time = avg; best_idx = (int)i; } + } + cudaEventDestroy(s); cudaEventDestroy(e); + TORCH_CHECK(best_idx >= 0, + "swiglu7 AutoTune: no candidate succeeded for (M,N_out,K)=(", + ea.M, ",", ea.N_out, ",", ea.K, ")"); + return best_idx; + } + + std::vector> configs_; + int best_idx_ = -1; // -1 = not yet autotuned; sticky after first call. + at::Tensor ws_; +}; + +static Sw7AutoTuneRunner& runner() { + static Sw7AutoTuneRunner R; + return R; +} + +void swiglu7_dual_matmul_out(at::Tensor A, at::Tensor B, at::Tensor D) { + runner()(std::move(A), std::move(B), std::move(D)); +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.doc() = "CUTLASS DualGemm fully-fused swiglu7 (bf16) on sm_120 — autotune"; + m.def("swiglu7_dual_matmul_out", + &swiglu7_dual_matmul_out, + "D = swiglu7(A @ B.T) in a single fused kernel; " + "A:(M,K) bf16, B:(N,K) bf16 (N even), D:(M,N/2) bf16", + pybind11::arg("A"), + pybind11::arg("B"), + pybind11::arg("D")); + m.def("num_configs", []() { return runner().num_configs(); }); +} diff --git a/magi_compiler/passes/piecewise_graph/fusion/blackwell_geforce/evt_codegen.py b/magi_compiler/passes/piecewise_graph/fusion/blackwell_geforce/evt_codegen.py new file mode 100644 index 0000000..72f7984 --- /dev/null +++ b/magi_compiler/passes/piecewise_graph/fusion/blackwell_geforce/evt_codegen.py @@ -0,0 +1,853 @@ +# Copyright (c) 2026 SandAI. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Render a CUTLASS .cu source from an EVT IR tree. + +The output is a single self-contained file that: + 1. Declares any custom functor templates required by scalar-baked ops + (ClampMaxC, ScaledSiLuAlpha, GeluErf, …) — each baked with its constant. + 2. Declares the bottom-up Sm80EVT typedef chain. + 3. Declares the GemmKernel + DeviceGemm + entry point. + 4. Exposes ``evt_matmul_out`` via PYBIND11. + +We use CUTLASS 2.x ``Sm80EVT`` running backward-compat on sm_120; this matches +``$MAGI_CUTLASS_ROOT/examples/99_evt_demo/heavy_epi_torch_ext.cu`` (default +``/opt/cutlass/...``) which has been verified to deliver +5..+12 % vs the +Triton TMA path on RTX 5090 bf16. +""" + +from __future__ import annotations + +import textwrap +from typing import Dict, List, Tuple + +from .evt_ir import Accum, AuxLoad, ColBroadcast, Compute, RowBroadcast, Store, walk_leaves + +# ── PyTorch dtype string → CUTLASS type ────────────────────────────────────── +_DTYPE_TO_CUTLASS = {"bfloat16": "cutlass::bfloat16_t", "float16": "cutlass::half_t", "float32": "float"} + +# PyTorch dtype string → at::ScalarType / pybind dtype string used in TORCH_CHECK. +_DTYPE_TO_AT = {"bfloat16": "at::kBFloat16", "float16": "at::kHalf", "float32": "at::kFloat"} + + +# ── Per-M-bucket tile candidate sets, hand-tuned for RTX 5090 (sm_120) ────── +# Hardware constraints driving these choices: +# * 170 SMs — the optimal grid size is some multiple of 170; small tiles +# keep more CTAs in flight when M is short. +# * 100 KB SMEM / SM — per-stage SMEM = (BM + BN) * BK * 2 (bf16). With +# stages=4 and (128,128,32) we land at 128 KB which exceeds budget; we +# prefer stages=3 in that case. (128,128,32)*4 = 128KB, (128,256,32)*3=144KB, +# (256,128,32)*3=144KB are still over budget but CUTLASS auto-shrinks +# stages on Sm80 if SMEM doesn't fit. We rely on can_implement / init to +# reject illegal combos at autotune time. +# * Decode-style M (≤256) loses parallelism on big tiles — 1 wave covers +# just a handful of N tiles. Need small BM. +# * Prefill-style M (>2048) has plenty of parallelism — bigger tiles win +# because they amortise loads better. +# +# Each tuple is (BM, BN, BK, WM, WN, WK, NumStages, label). +# WarpShape is conventionally TileShape / (2, 2) along (M, N), keeping 4 warps. +# We include WK == BK to match Sm80 TensorOp's default warp tiling. +_TILE_CANDIDATES_5090: dict = { + # ── small (decode / single-token) ──────────────────────────────────────── + # M ≤ 256: low parallelism along M. Use small BM to launch more CTAs along N. + # All candidates have BM*BN ≤ 16384 to keep occupancy high on 170 SMs. + "small": [ + (64, 64, 32, 32, 32, 32, 4, "T<64,64,32>_S4"), + (64, 64, 64, 32, 32, 64, 3, "T<64,64,64>_S3"), + (64, 128, 32, 32, 64, 32, 3, "T<64,128,32>_S3"), + (64, 128, 32, 32, 64, 32, 4, "T<64,128,32>_S4"), + (64, 128, 64, 32, 64, 64, 3, "T<64,128,64>_S3"), + (64, 256, 32, 32, 64, 32, 3, "T<64,256,32>_S3"), + (128, 64, 32, 64, 32, 32, 3, "T<128,64,32>_S3"), + (128, 64, 32, 64, 32, 32, 4, "T<128,64,32>_S4"), + ], + # ── medium (256 < M ≤ 2048) ────────────────────────────────────────────── + # Standard CUTLASS bf16 sweet spot. Mix BM=128/256 with BN=64/128/256. + "medium": [ + (128, 128, 32, 64, 64, 32, 3, "T<128,128,32>_S3"), + (128, 128, 32, 64, 64, 32, 4, "T<128,128,32>_S4"), + (128, 128, 64, 64, 64, 64, 3, "T<128,128,64>_S3"), + (128, 256, 32, 64, 64, 32, 3, "T<128,256,32>_S3"), + (256, 128, 32, 64, 64, 32, 3, "T<256,128,32>_S3"), + (128, 64, 64, 64, 32, 64, 4, "T<128,64,64>_S4"), + (64, 128, 64, 32, 64, 64, 4, "T<64,128,64>_S4"), + ], + # ── large (M > 2048) ───────────────────────────────────────────────────── + # Plenty of parallelism — bigger tiles for better arith density. SMEM + # budget on 5090 (100 KB) restricts (256,128) and (128,256) to stages=3. + "large": [ + (128, 256, 32, 64, 64, 32, 3, "T<128,256,32>_S3"), + (256, 128, 32, 64, 64, 32, 3, "T<256,128,32>_S3"), + (128, 128, 32, 64, 64, 32, 4, "T<128,128,32>_S4"), + (128, 128, 64, 64, 64, 64, 3, "T<128,128,64>_S3"), + (256, 128, 64, 64, 64, 64, 3, "T<256,128,64>_S3"), + (128, 256, 64, 64, 64, 64, 3, "T<128,256,64>_S3"), + ], +} + + +def _emit_tile_candidates(m_bucket: str) -> str: + """Emit C++ EVT_TILE_CANDIDATE(...) statements for a given M bucket.""" + candidates = _TILE_CANDIDATES_5090.get(m_bucket, _TILE_CANDIDATES_5090["medium"]) + lines = [] + for bm, bn, bk, wm, wn, wk, stages, label in candidates: + lines.append(f' EVT_TILE_CANDIDATE({bm}, {bn}, {bk}, {wm}, {wn}, {wk}, ' f'{stages}, "{label}");') + return "\n".join(lines) + + +# For data_ptr() casts at the C++ layer. +_DTYPE_TO_AT_CPP = {"bfloat16": "at::BFloat16", "float16": "at::Half", "float32": "float"} + + +# ── Built-in CUTLASS op names for the visitor template-template parameter ──── +# Maps IR op name → (CUTLASS template name, is_class_template_with_T_only) +# Each value must be a `template class` accepting a single type arg. +_BUILTIN_FN_TEMPLATE = { + # binary + "add": "cutlass::plus", + "sub": "cutlass::minus", + "mul": "cutlass::multiplies", + "div": "cutlass::divides", + "max": "cutlass::maximum", + "min": "cutlass::minimum", + # unary + "neg": "cutlass::negate", + "sigmoid": "cutlass::epilogue::thread::Sigmoid", + "silu": "cutlass::epilogue::thread::SiLu", + "tanh": "cutlass::epilogue::thread::Tanh", + "relu": "cutlass::epilogue::thread::ReLu", + "abs": "cutlass::absolute_value_op", +} + +# Unary ops that need a custom emitted functor (CUTLASS has no built-in). +# Each maps to a body template; the body uses ``T`` as the element type and +# operates on a single ``T`` value named ``x``. +_CUSTOM_UNARY_BODY = { + "square": "return x * x;", + "exp": "return cutlass::fast_exp(x);", + "log": "return cutlass::fast_log(x);", + "sqrt": "return cutlass::fast_sqrt(x);", + "rsqrt": "return cutlass::fast_rsqrt(x);", + "erf": "return T(erff(float(x)));", + "gelu_erf": "return T(0.5f) * x * (T(1.0f) + T(erff(float(x) * 0.70710678118654752f)));", + "gelu_tanh": ( + "float v = float(x);" " return T(0.5f * v * (1.0f + tanhf(" "0.7978845608028654f * (v + 0.044715f * v * v * v))));" + ), +} + +# Scalar-baked unary ops. The body template uses ``x`` and ``c`` (the baked +# constant, emitted as a ``T`` literal — never a runtime value). +_CUSTOM_SCALAR_BODY = { + "add_scalar": "return x + c;", + "sub_scalar": "return x - c;", + "mul_scalar": "return x * c;", + "div_scalar": "return x / c;", + "rsub_scalar": "return c - x;", + "clamp_min_c": "return x < c ? c : x;", + "clamp_max_c": "return x < c ? x : c;", + # scaled_silu_alpha(x, alpha) = x * sigmoid(alpha * x). Used by GELU7. + "scaled_silu_alpha": ( + "T t = c * x;" " T one = T(1.0f);" " T sig = one / (one + cutlass::fast_exp(-t));" " return x * sig;" + ), + # pow_scalar(x, c) – emit as repeated multiplies for small int c. + # Otherwise fall back to powf. + "pow_scalar": "return T(powf(float(x), float(c)));", +} + + +def _scalar_literal_T(value: float) -> str: + """Emit a constant as a ``T(...)`` cast that survives bf16 / fp16 / fp32.""" + # repr keeps round-trip precision; "f" suffix forces float in C++. + return f"T({float(value)!r}f)" + + +def _emit_custom_functor(name: str, op: str, scalar=None) -> str: + """Emit a unary CUTLASS-compatible functor (scalar + Array spec).""" + if op in _CUSTOM_UNARY_BODY: + body = _CUSTOM_UNARY_BODY[op] + scalar_decl = "" + elif op in _CUSTOM_SCALAR_BODY: + if scalar is None: + raise ValueError(f"Scalar op {op!r} needs a baked constant") + body = _CUSTOM_SCALAR_BODY[op] + scalar_decl = f" const T c = {_scalar_literal_T(scalar)};\n" + else: + raise ValueError(f"No custom functor body for op {op!r}") + return textwrap.dedent( + f"""\ + template + struct {name} {{ + static const bool kIsHeavy = true; + CUTLASS_HOST_DEVICE + T operator()(T const& x) const {{ + {scalar_decl} {body} + }} + }}; + + template + struct {name}> {{ + static const bool kIsHeavy = true; + CUTLASS_HOST_DEVICE + cutlass::Array operator()(cutlass::Array const& v) const {{ + {name} op; + cutlass::Array out; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) out[i] = op(v[i]); + return out; + }} + }}; + """ + ) + + +# ── EVT typedef + leaf args walker ──────────────────────────────────────────── + + +class _EvtEmitter: + """Bottom-up walker that emits typedef chains + leaf placeholders.""" + + def __init__(self, root: Store): + self.root = root + self.typedef_lines: List[str] = [] + self.functor_decls: List[str] = [] + self._emitted_functors: Dict[Tuple[str, str], str] = {} + self._tmp_counter = 0 + # Per-leaf metadata captured during walk: leaf identity (object id) → + # (typedef_name, leaf_kind, input_idx_or_None, dtype_str) + self.leaf_typedefs: List[Tuple[str, str, "int | None", str]] = [] + self.scalar_functor_counter = 0 + + def _new_name(self, prefix: str) -> str: + self._tmp_counter += 1 + return f"{prefix}_{self._tmp_counter}" + + def _functor_name_for(self, op: str, scalar) -> str: + """Unique struct name for a custom functor, deduped by (op, scalar).""" + key = (op, repr(scalar) if scalar is not None else "") + if key in self._emitted_functors: + return self._emitted_functors[key] + # Strip dots from the scalar so the name stays a valid C++ identifier. + scalar_tag = "" + if scalar is not None: + self.scalar_functor_counter += 1 + scalar_tag = f"_v{self.scalar_functor_counter}" + name = f"Magi_{op}{scalar_tag}" + self._emitted_functors[key] = name + self.functor_decls.append(_emit_custom_functor(name, op, scalar)) + return name + + def _compute_op_template(self, node: Compute) -> str: + """Return the C++ template-name passed as ComputeFn to VisitorCompute.""" + if node.op in _BUILTIN_FN_TEMPLATE and node.scalar is None: + return _BUILTIN_FN_TEMPLATE[node.op] + # Custom functor — either scalar-baked or unary-no-builtin (e.g. erf). + return self._functor_name_for(node.op, node.scalar) + + def emit(self) -> str: + """Walk the IR; return the typedef name of the root EVT type (EVT_D).""" + # Recurse from Store.child first to build up subtrees. + body_root = self._emit_node(self.root.child) + # The store leaf itself is the StoreD typedef wrapping body_root. + store_name = self._new_name("StoreD") + self.typedef_lines.append( + "using {name} = cutlass::epilogue::threadblock::VisitorAuxStore<\n" + " OutputTileThreadMap, ElementC,\n" + " cutlass::FloatRoundStyle::round_to_nearest,\n" + " cute::Stride>;".format(name=store_name) + ) + evt_d = self._new_name("EVT_D") + self.typedef_lines.append( + f"using {evt_d} = cutlass::epilogue::threadblock::Sm80EVT<\n" f" {store_name}, {body_root}>;" + ) + # Track the StoreD leaf metadata so the launcher knows where to bind D. + self.leaf_typedefs.append((store_name, "store", None, self.root.out_dtype)) + return evt_d + + def _emit_node(self, node) -> str: + if isinstance(node, Accum): + name = self._new_name("Accum") + self.typedef_lines.append(f"using {name} = cutlass::epilogue::threadblock::VisitorAccFetch;") + return name + if isinstance(node, RowBroadcast): + name = self._new_name("RowBcast") + elem = _DTYPE_TO_CUTLASS[node.dtype] + self.typedef_lines.append( + f"using {name} = cutlass::epilogue::threadblock::VisitorRowBroadcast<\n" + f" OutputTileThreadMap, {elem},\n" + f" cute::Stride<_0, _1, int32_t>>;" + ) + self.leaf_typedefs.append((name, "row_bcast", node.input_idx, node.dtype)) + return name + if isinstance(node, ColBroadcast): + name = self._new_name("ColBcast") + elem = _DTYPE_TO_CUTLASS[node.dtype] + self.typedef_lines.append( + f"using {name} = cutlass::epilogue::threadblock::VisitorColBroadcast<\n" + f" OutputTileThreadMap, {elem},\n" + f" cute::Stride<_1, _0, int32_t>>;" + ) + self.leaf_typedefs.append((name, "col_bcast", node.input_idx, node.dtype)) + return name + if isinstance(node, AuxLoad): + name = self._new_name("Aux") + elem = _DTYPE_TO_CUTLASS[node.dtype] + self.typedef_lines.append( + f"using {name} = cutlass::epilogue::threadblock::VisitorAuxLoad<\n" + f" OutputTileThreadMap, {elem},\n" + f" cute::Stride>;" + ) + self.leaf_typedefs.append((name, "aux_load", node.input_idx, node.dtype)) + return name + if isinstance(node, Compute): + child_names = [self._emit_node(c) for c in node.children] + compute_name = self._new_name(f"Cmp_{node.op}") + fn_template = self._compute_op_template(node) + self.typedef_lines.append( + f"using {compute_name} = cutlass::epilogue::threadblock::VisitorCompute<\n" + f" {fn_template}, ElementCompute, ElementCompute,\n" + f" cutlass::FloatRoundStyle::round_to_nearest>;" + ) + evt_name = self._new_name(f"EVT_{node.op}") + child_typedef_list = ", ".join(child_names) + self.typedef_lines.append( + f"using {evt_name} = cutlass::epilogue::threadblock::Sm80EVT<\n" f" {compute_name}, {child_typedef_list}>;" + ) + return evt_name + raise TypeError(f"Unknown IR node type: {type(node).__name__}") + + +# ── Argument-tree emitter (matches EVT typedef tree) ────────────────────────── + + +def _emit_args_tree(node, leaf_args: Dict[int, str], indent: int = 4) -> str: + """Emit the nested-brace runtime callback-args literal matching the IR. + + ``leaf_args[input_idx]`` for non-Accum leaves is a small C++ snippet like + ``{ptrBias, ElementC(0), {_0{}, _1{}, int32_t(N)}}``. Accum / Compute / + Store args are empty braces ``{}``. The Store arg is ``{ptrD, {N, _1{}, + MN}}`` and is handled by the caller — this emitter only renders the body + inside StoreD. + """ + pad = " " * indent + if isinstance(node, Accum): + return f"{pad}{{}}" + if isinstance(node, (RowBroadcast, ColBroadcast, AuxLoad)): + return f"{pad}{leaf_args[node.input_idx]}" + if isinstance(node, Compute): + children_str = ",\n".join(_emit_args_tree(c, leaf_args, indent + 2) for c in node.children) + return f"{pad}{{\n" f"{children_str},\n" f"{pad} {{}}\n" f"{pad}}}" + raise TypeError(f"Unknown IR node type: {type(node).__name__}") + + +# ── Public API: render a complete .cu source string ────────────────────────── + + +_KERNEL_PREAMBLE = """\ +// AUTO-GENERATED by magi_compiler/passes/piecewise_graph/fusion/evt_codegen.py +// Do not edit by hand. Regenerate by re-running the FX pass. +// +// IR cache key: {cache_key} + +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +#include "cutlass/cutlass.h" +#include "cutlass/array.h" +#include "cutlass/numeric_types.h" +#include "cutlass/functional.h" +#include "cutlass/fast_math.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/epilogue/thread/activation.h" +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/gemm/device/gemm_universal.h" + +#include "cutlass/epilogue/threadblock/fusion/visitors.hpp" +#include "cutlass/gemm/kernel/default_gemm_universal_with_visitor.h" +#include "cutlass/gemm/device/gemm_universal_adapter.h" + +using cute::_0; +using cute::_1; + +//////////////////////////////////////////////////////////////////////////////// +// Custom functors (one per unique scalar-baked op or non-builtin unary). +//////////////////////////////////////////////////////////////////////////////// +{functor_decls} + +//////////////////////////////////////////////////////////////////////////////// +// Data types and layouts +//////////////////////////////////////////////////////////////////////////////// + +using ElementA = {a_elem}; +using ElementB = {b_elem}; +using ElementC = {c_elem}; +using ElementAcc = float; +using ElementCompute = float; + +using LayoutA = cutlass::layout::RowMajor; +using LayoutB = cutlass::layout::{b_layout}; +using LayoutC = cutlass::layout::RowMajor; + +constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; +constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; +// AlignmentC = 4 instead of 8 so any N-divisible-by-4 output works (e.g. odd +// half-N values like 13652 from N=27304). Aligned tails still vectorise. +constexpr int AlignmentC = 4; + +using ArchTag = cutlass::arch::Sm80; +using OperatorClass = cutlass::arch::OpClassTensorOp; +using InstructionShape = cutlass::gemm::GemmShape< 16, 8, 16>; +constexpr int EVTEpilogueStages = 1; + +//////////////////////////////////////////////////////////////////////////////// +// Per-tile-config GEMM type. The OutputTileThreadMap depends on +// ThreadblockShape/WarpShape, which forces every EVT typedef to be re-built +// per tile. We package the whole tree inside a template struct keyed on the +// tile/warp/stages parameters so each autotune candidate is a distinct type. +//////////////////////////////////////////////////////////////////////////////// + +template +struct EvtConfig {{ + using TheTbShape = TbShape; + using TheWarpShape = WarpShape; + + using OutputTileThreadMap = cutlass::epilogue::threadblock::OutputTileThreadLayout< + TbShape, WarpShape, ElementC, AlignmentC, EVTEpilogueStages>; + + //////////////////////////////////////////////////////////////////////////// + // EVT (Epilogue Visitor Tree) typedefs — generated from the IR tree. + //////////////////////////////////////////////////////////////////////////// +{typedef_block} + + //////////////////////////////////////////////////////////////////////////// + // GemmKernel / DeviceGemm + //////////////////////////////////////////////////////////////////////////// + using GemmKernel = typename cutlass::gemm::kernel::DefaultGemmWithVisitor< + ElementA, LayoutA, cutlass::ComplexTransform::kNone, AlignmentA, + ElementB, LayoutB, cutlass::ComplexTransform::kNone, AlignmentB, + ElementC, LayoutC, AlignmentC, + ElementAcc, + ElementCompute, + OperatorClass, + ArchTag, + TbShape, + WarpShape, + InstructionShape, + {evt_root_name}, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + NumStages, + cutlass::arch::OpMultiplyAdd, + EVTEpilogueStages>::GemmKernel; + + using DeviceGemm = cutlass::gemm::device::GemmUniversalAdapter; +}}; + +//////////////////////////////////////////////////////////////////////////////// +// Autotune runner — one candidate per tile/warp/stages combination; first call +// at a new (M, N, K) tuple times every candidate and caches the winner. +//////////////////////////////////////////////////////////////////////////////// + +struct EvtArgs {{ + int M; + int N; + int K; + void* ptr_A; + void* ptr_B; + void* ptr_D; + // Extras pointers, in IR-leaf order. + std::vector ptr_extras; +}}; + +class EvtConcept {{ + public: + virtual ~EvtConcept() = default; + virtual size_t get_workspace_size(const EvtArgs&) = 0; + virtual cutlass::Status initialize(const EvtArgs&, void* ws, cudaStream_t s) = 0; + virtual cutlass::Status run(cudaStream_t stream) = 0; + virtual const char* name() const = 0; +}}; + +template +class EvtImpl : public EvtConcept {{ + public: + using GemmType = typename Cfg::DeviceGemm; + using EvtRoot = typename Cfg::{evt_root_name}; + + explicit EvtImpl(const char* name) : name_(name) {{}} + + typename GemmType::Arguments make_args(const EvtArgs& a) {{ + auto ptrA = reinterpret_cast(a.ptr_A); + auto ptrB = reinterpret_cast(a.ptr_B); + auto ptrD = reinterpret_cast(a.ptr_D); + int const M = a.M; + int const N = a.N; + int const K = a.K; + int64_t const MN = static_cast(M) * static_cast(N); + + typename EvtRoot::Arguments callback_args{{ +{args_tree} + , + {{ptrD, {{int64_t(N), _1{{}}, MN}}}} + }}; + + cutlass::gemm::GemmCoord problem{{M, N, K}}; + typename GemmType::Arguments args( + cutlass::gemm::GemmUniversalMode::kGemm, + problem, + /*batch_count=*/1, + callback_args, + ptrA, ptrB, + /*ptr_C=*/nullptr, /*ptr_D=*/nullptr, + /*batch_stride_A=*/static_cast(M) * K, + /*batch_stride_B=*/static_cast(N) * K, + /*batch_stride_C=*/0, /*batch_stride_D=*/0, + /*stride_a=*/static_cast(K), + /*stride_b=*/static_cast({stride_b_expr}), + /*stride_c=*/0, /*stride_d=*/0); + return args; + }} + + size_t get_workspace_size(const EvtArgs& a) override {{ + auto args = make_args(a); + return GemmType::get_workspace_size(args); + }} + cutlass::Status initialize(const EvtArgs& a, void* ws, cudaStream_t s) override {{ + auto args = make_args(a); + return gemm_.initialize(args, ws, s); + }} + cutlass::Status run(cudaStream_t stream) override {{ + return gemm_.run(stream); + }} + const char* name() const override {{ return name_; }} + + private: + GemmType gemm_; + const char* name_; +}}; + +//////////////////////////////////////////////////////////////////////////////// +// Python-facing launcher +//////////////////////////////////////////////////////////////////////////////// +""" + + +_LAUNCHER_TEMPLATE = """\ +//////////////////////////////////////////////////////////////////////////////// +// Tile candidate registration. Each AutoConfigBuilder invocation instantiates +// the full EVT typedef tree + GemmKernel for that (TileShape, WarpShape, +// NumStages) tuple. Compile time grows linearly with the candidate count, so +// keep the list small and shape-relevant. +//////////////////////////////////////////////////////////////////////////////// + +#define EVT_TILE_CANDIDATE(tb_m, tb_n, tb_k, wa_m, wa_n, wa_k, stages, label) \\ + configs_.push_back(std::make_unique, \\ + cutlass::gemm::GemmShape, \\ + stages>>>(label)) + +class EvtAutoTuneRunner {{ + public: + EvtAutoTuneRunner() {{ +{tile_candidate_block} + }} + + void operator()(at::Tensor A, at::Tensor B, + std::vector extras, at::Tensor D) {{ + TORCH_CHECK(A.is_cuda() && B.is_cuda() && D.is_cuda(), + "evt_matmul_out: A/B/D must be CUDA tensors"); + TORCH_CHECK(A.scalar_type() == {a_at_dtype}, "A must be {a_dtype}"); + TORCH_CHECK(B.scalar_type() == {b_at_dtype}, "B must be {b_dtype}"); + TORCH_CHECK(D.scalar_type() == {c_at_dtype}, "D must be {c_dtype}"); + TORCH_CHECK(A.dim() == 2 && B.dim() == 2, "A, B must be 2D"); + TORCH_CHECK(A.is_contiguous() && B.is_contiguous() && D.is_contiguous(), + "A, B, D must be contiguous (row-major)"); + + int const M = static_cast(A.size(0)); + int const K = static_cast(A.size(1)); + int const N = static_cast({n_dim_expr}); + + TORCH_CHECK(D.size(0) == M && D.size(1) == N, + "D must be (M, N); got ", D.sizes()); + TORCH_CHECK(extras.size() == {n_extras}, "expected {n_extras} extra tensors, got ", extras.size()); + +{extras_validation} + + EvtArgs ea; + ea.M = M; ea.N = N; ea.K = K; + ea.ptr_A = A.data_ptr<{a_at_cpp}>(); + ea.ptr_B = B.data_ptr<{b_at_cpp}>(); + ea.ptr_D = D.data_ptr<{c_at_cpp}>(); + ea.ptr_extras.reserve({n_extras}); +{extras_ptrs} + + cudaStream_t stream = at::cuda::getCurrentCUDAStream(A.device().index()).stream(); + + // Single autotune per module. The .cu is compiled per (IR, M-bucket, + // b_layout, N, K) on the Python side — every distinct weight (N, K) + // gets its own .cu, so this runner instance hosts exactly one (N, K) + // and one bucket of M values. Autotune once on the first call; all + // subsequent calls (any M inside the bucket) reuse `best_idx_`. + if (best_idx_ < 0) {{ + best_idx_ = autotune(ea, stream); + }} + int idx = best_idx_; + + auto& gemm = configs_[idx]; + size_t ws_sz = gemm->get_workspace_size(ea); + if (!ws_.defined() || ws_.numel() < (int64_t)ws_sz) {{ + ws_ = at::empty({{(int64_t)ws_sz + 1}}, + at::TensorOptions().dtype(at::kByte).device(A.device())); + }} + auto st = gemm->initialize(ea, ws_sz > 0 ? ws_.data_ptr() : nullptr, stream); + TORCH_CHECK(st == cutlass::Status::kSuccess, + "CUTLASS init failed (", gemm->name(), "): ", cutlassGetStatusString(st)); + st = gemm->run(stream); + TORCH_CHECK(st == cutlass::Status::kSuccess, + "CUTLASS run failed (", gemm->name(), "): ", cutlassGetStatusString(st)); + }} + + int num_configs() const {{ return (int)configs_.size(); }} + + private: + int autotune(const EvtArgs& ea, cudaStream_t stream) {{ + int best_idx = -1; + float best_time = 1e30f; + cudaEvent_t s, e; + cudaEventCreate(&s); cudaEventCreate(&e); + + for (size_t i = 0; i < configs_.size(); ++i) {{ + auto& g = configs_[i]; + size_t ws_sz = 0; + try {{ ws_sz = g->get_workspace_size(ea); }} + catch (...) {{ continue; }} + if (!ws_.defined() || ws_.numel() < (int64_t)ws_sz) {{ + ws_ = at::empty({{(int64_t)ws_sz + 1}}, + at::TensorOptions().dtype(at::kByte).device(at::kCUDA)); + }} + void* ws_ptr = ws_sz > 0 ? ws_.data_ptr() : nullptr; + if (g->initialize(ea, ws_ptr, stream) != cutlass::Status::kSuccess) {{ + continue; + }} + + // Warmup — 10 iters so L2 / inst caches settle (3 was too few — first + // timed iter saw a cold L2 and biased the choice towards smaller tiles). + for (int w = 0; w < 10; ++w) g->run(stream); + cudaStreamSynchronize(stream); + + // Time — 20 iters for ~1% timing noise, matching torch.compile defaults. + cudaEventRecord(s, stream); + int iters = 20; + for (int p = 0; p < iters; ++p) g->run(stream); + cudaEventRecord(e, stream); + cudaEventSynchronize(e); + float ms = 0; + cudaEventElapsedTime(&ms, s, e); + float avg = ms / iters; + if (avg < best_time) {{ best_time = avg; best_idx = (int)i; }} + }} + cudaEventDestroy(s); cudaEventDestroy(e); + TORCH_CHECK(best_idx >= 0, + "EVT AutoTune: no candidate succeeded for (M,N,K)=(", + ea.M, ",", ea.N, ",", ea.K, ")"); + return best_idx; + }} + + std::vector> configs_; + int best_idx_ = -1; // -1 = not yet autotuned; sticky after first call. + at::Tensor ws_; +}}; + +static EvtAutoTuneRunner& runner() {{ + static EvtAutoTuneRunner R; + return R; +}} + +void evt_matmul_out(at::Tensor A, at::Tensor B, + std::vector extras, + at::Tensor D) {{ + runner()(std::move(A), std::move(B), std::move(extras), std::move(D)); +}} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {{ + m.doc() = "Magi compiler EVT-fused matmul (auto-generated, autotune)"; + m.def("evt_matmul_out", &evt_matmul_out, + "Fused EVT matmul: D = epilogue(A @ B, extras...)", + pybind11::arg("A"), pybind11::arg("B"), + pybind11::arg("extras"), pybind11::arg("D")); + m.def("num_configs", []() {{ return runner().num_configs(); }}); +}} +""" + + +def render_evt_cu( + ir: Store, a_dtype: str, b_dtype: str, cache_key_str: str = "", b_layout: str = "row", m_bucket: str = "medium" +) -> str: + """Render a complete .cu source for the given EVT IR. + + Parameters + ---------- + ir : Store + Root of the EVT IR tree. + a_dtype, b_dtype : str + Element types for A and B (typically ``"bfloat16"``). Output dtype is + taken from ``ir.out_dtype``. + cache_key_str : str + Optional hash echoed in a top-level comment, useful for debugging. + b_layout : "row" | "col" + ``"row"`` (default): B is contiguous (K, N) row-major; LayoutB = + RowMajor; ldB = N. ``"col"``: B is the underlying (N, K) row-major + weight (== column-major (K, N)); LayoutB = ColumnMajor; ldB = K. Use + ``"col"`` when the FX graph passes ``permute([1,0])(weight)`` as B. + m_bucket : "small" | "medium" | "large" + Picks a tile-candidate set tuned for RTX 5090 (sm_120) at the given M + regime. The runner inside the rendered .cu autotunes across all + candidates in that bucket on the first call per (M, N, K) shape and + caches the winner. + """ + if b_layout not in ("row", "col"): + raise ValueError(f"b_layout must be 'row' or 'col', got {b_layout!r}") + if m_bucket not in _TILE_CANDIDATES_5090: + raise ValueError(f"unknown m_bucket {m_bucket!r}; " f"expected one of {list(_TILE_CANDIDATES_5090)}") + if not isinstance(ir, Store): + raise TypeError("render_evt_cu expects a Store node as root") + tile_candidate_block = _emit_tile_candidates(m_bucket) + + a_elem = _DTYPE_TO_CUTLASS[a_dtype] + b_elem = _DTYPE_TO_CUTLASS[b_dtype] + c_elem = _DTYPE_TO_CUTLASS[ir.out_dtype] + + emitter = _EvtEmitter(ir) + evt_root = emitter.emit() + + # Build per-leaf runtime arg fragments. These get inlined into + # ``EvtImpl::make_args`` (a method on a different class than the launcher + # that fills ea.ptr_extras). The only shared state between the two scopes + # is the EvtArgs struct ``a``, so we read pointers from a.ptr_extras[i] + # and cast back to the leaf's element type. + leaves = walk_leaves(ir) + leaf_args: Dict[int, str] = {} + for leaf in leaves: + # Accum has no extras pointer / dtype — skip; it consumes the GEMM + # accumulator directly via VisitorAccFetch. + if not isinstance(leaf, (RowBroadcast, ColBroadcast, AuxLoad)): + continue + elem = _DTYPE_TO_CUTLASS[leaf.dtype] + ptr_expr = f"reinterpret_cast<{elem}*>(a.ptr_extras[{leaf.input_idx}])" + if isinstance(leaf, RowBroadcast): + leaf_args[leaf.input_idx] = f"{{{ptr_expr}, {elem}(0), {{_0{{}}, _1{{}}, int32_t(N)}}}}" + elif isinstance(leaf, ColBroadcast): + leaf_args[leaf.input_idx] = f"{{{ptr_expr}, {elem}(0), {{_1{{}}, _0{{}}, int32_t(M)}}}}" + else: # AuxLoad + leaf_args[leaf.input_idx] = f"{{{ptr_expr}, {elem}(0), {{int64_t(N), _1{{}}, MN}}}}" + # Accum has no explicit args entry. + + args_tree = _emit_args_tree(ir.child, leaf_args, indent=8) + + # Extras-validation + pointer-extraction blocks. The same external tensor + # (same input_idx) may appear at multiple leaves in the IR tree — e.g. an + # ``add(mm, bias)`` value flowing into both ``sigmoid`` and ``mul`` creates + # two RowBroadcast(0) leaves. We must declare ``ptr_extra_0`` exactly once + # in the launcher; the runtime args tree still references the same ptr + # name from each leaf-arg fragment so this dedup is purely a C++ scope fix. + extras_validation_lines = [] + extras_ptr_lines = [] + seen_extras: set = set() + extra_leaves = [n for n in leaves if not isinstance(n, Accum)] + n_extras = max((leaf.input_idx for leaf in extra_leaves), default=-1) + 1 + for leaf in extra_leaves: + i = leaf.input_idx + if i in seen_extras: + continue + seen_extras.add(i) + at_dtype = _DTYPE_TO_AT[leaf.dtype] + at_cpp = _DTYPE_TO_AT_CPP[leaf.dtype] + _DTYPE_TO_CUTLASS[leaf.dtype] + if isinstance(leaf, RowBroadcast): + extras_validation_lines.append(f' TORCH_CHECK(extras[{i}].numel() == N, "extras[{i}] must have N elements");') + elif isinstance(leaf, ColBroadcast): + extras_validation_lines.append(f' TORCH_CHECK(extras[{i}].numel() == M, "extras[{i}] must have M elements");') + elif isinstance(leaf, AuxLoad): + extras_validation_lines.append( + f' TORCH_CHECK(extras[{i}].size(0) == M && extras[{i}].size(1) == N,' f' "extras[{i}] must be (M,N)");' + ) + extras_validation_lines.append( + f' TORCH_CHECK(extras[{i}].scalar_type() == {at_dtype},' f' "extras[{i}] must be {leaf.dtype}");' + ) + extras_validation_lines.append(f' TORCH_CHECK(extras[{i}].is_cuda(), "extras[{i}] must be CUDA");') + # Push raw pointer into ea.ptr_extras for the make_args() side to + # read (it lives in a different scope than this launcher fn). + extras_ptr_lines.append(f" ea.ptr_extras.push_back(static_cast(" f"extras[{i}].data_ptr<{at_cpp}>()));") + + extras_validation = "\n".join(extras_validation_lines) if extras_validation_lines else " // no extras" + extras_ptrs = "\n".join(extras_ptr_lines) if extras_ptr_lines else "" + + # Emit. The functor decls already end with a trailing newline each. + functor_decls = "\n".join(emitter.functor_decls) if emitter.functor_decls else "// (no custom functors)" + # typedef_block lives inside ``struct EvtConfig`` — indent each line by 2 + # spaces so member typedefs read consistently with the surrounding struct. + typedef_block = "\n".join(" " + l if l.strip() else l for l in "\n".join(emitter.typedef_lines).split("\n")) + + cutlass_b_layout = "RowMajor" if b_layout == "row" else "ColumnMajor" + if b_layout == "row": + # B is (K, N) row-major contiguous: K from B.size(0), N from B.size(1), ldB = N. + n_dim_expr = "B.size(1)" + stride_b_expr = "N" + else: + # B is the underlying (N, K) row-major weight (we read the same + # bytes via ColumnMajor (K, N)): N from B.size(0), K from B.size(1), ldB = K. + n_dim_expr = "B.size(0)" + stride_b_expr = "K" + + preamble = _KERNEL_PREAMBLE.format( + cache_key=cache_key_str, + functor_decls=functor_decls, + a_elem=a_elem, + b_elem=b_elem, + c_elem=c_elem, + typedef_block=typedef_block, + evt_root_name=evt_root, + b_layout=cutlass_b_layout, + # EvtImpl::make_args uses args_tree + stride_b_expr; same values as + # the launcher (per-IR / per-layout, not per-tile-config). + args_tree=args_tree, + stride_b_expr=stride_b_expr, + ) + launcher = _LAUNCHER_TEMPLATE.format( + evt_root_name=evt_root, + args_tree=args_tree, + a_dtype=a_dtype, + b_dtype=b_dtype, + c_dtype=ir.out_dtype, + a_at_dtype=_DTYPE_TO_AT[a_dtype], + b_at_dtype=_DTYPE_TO_AT[b_dtype], + c_at_dtype=_DTYPE_TO_AT[ir.out_dtype], + a_at_cpp=_DTYPE_TO_AT_CPP[a_dtype], + b_at_cpp=_DTYPE_TO_AT_CPP[b_dtype], + c_at_cpp=_DTYPE_TO_AT_CPP[ir.out_dtype], + n_extras=n_extras, + extras_validation=extras_validation, + extras_ptrs=extras_ptrs, + n_dim_expr=n_dim_expr, + stride_b_expr=stride_b_expr, + tile_candidate_block=tile_candidate_block, + ) + return preamble + launcher diff --git a/magi_compiler/passes/piecewise_graph/fusion/blackwell_geforce/evt_ir.py b/magi_compiler/passes/piecewise_graph/fusion/blackwell_geforce/evt_ir.py new file mode 100644 index 0000000..ae6bc1e --- /dev/null +++ b/magi_compiler/passes/piecewise_graph/fusion/blackwell_geforce/evt_ir.py @@ -0,0 +1,242 @@ +# Copyright (c) 2026 SandAI. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""EVT (Epilogue Visitor Tree) intermediate representation. + +A small dataclass IR that the FX pass builds while walking the consumers of an +``aten.mm`` node, and that ``evt_codegen.py`` consumes to render a CUTLASS .cu +source. The IR is canonicalised to a deterministic JSON string used as the +cache key for the JIT'd kernel module. + +The IR is rooted at a single ``Store`` node and forms a DAG of compute nodes +over leaves (``Accum``, ``RowBroadcast``, ``ColBroadcast``, ``AuxLoad``). + +Op naming: every name in ``UNARY_OPS`` / ``BINARY_OPS`` corresponds to a +CUTLASS visitor template that ``evt_codegen.py`` knows how to emit. Adding a +new op requires updating both this file and the codegen. +""" + +from __future__ import annotations + +import hashlib +import json +from dataclasses import dataclass +from typing import List, Optional, Union + +# Ops that take a single child tensor and produce a tensor of the same shape. +# All run in fp32 inside the EVT epilogue. +UNARY_OPS = frozenset( + {"neg", "sigmoid", "silu", "gelu_erf", "gelu_tanh", "tanh", "relu", "square", "erf", "exp", "log", "sqrt", "rsqrt", "abs"} +) + +# Ops that take two child tensors. Both children must be EVT subtrees. +BINARY_OPS = frozenset({"add", "sub", "mul", "div", "max", "min"}) + +# Unary ops that bake a single fp32 scalar into the functor at codegen time. +# Used to fold scalar literals out of the IR so they don't bloat the cache key. +SCALAR_UNARY_OPS = frozenset( + { + "add_scalar", # x + c + "sub_scalar", # x - c + "mul_scalar", # x * c + "div_scalar", # x / c + "rsub_scalar", # c - x + "clamp_min_c", # max(x, c) + "clamp_max_c", # min(x, c) + "scaled_silu_alpha", # x * sigmoid(alpha * x), used by gelu7 + "pow_scalar", # x ** c (only sensible for small integer c) + } +) + +ALL_OPS = UNARY_OPS | BINARY_OPS | SCALAR_UNARY_OPS + +# Output dtype tags propagated from FakeTensor metadata into Store and leaves. +# Kept as strings (not torch.dtype) so the IR is JSON-serialisable. +DTYPES = frozenset({"bfloat16", "float16", "float32"}) + + +# ── Leaf nodes ──────────────────────────────────────────────────────────────── + + +@dataclass(frozen=True) +class Accum: + """The fp32 GEMM accumulator. Always the unique starting leaf of the IR.""" + + kind: str = "accum" + + +@dataclass(frozen=True) +class RowBroadcast: + """1-D (N,) tensor broadcast along the M axis. Maps to VisitorRowBroadcast. + + ``input_idx`` is the position of this tensor in the runtime ``extras`` list. + ``dtype`` is the storage dtype; the visitor casts to fp32 internally. + """ + + input_idx: int + dtype: str + kind: str = "row_bcast" + + +@dataclass(frozen=True) +class ColBroadcast: + """1-D (M,) tensor broadcast along the N axis. Maps to VisitorColBroadcast.""" + + input_idx: int + dtype: str + kind: str = "col_bcast" + + +@dataclass(frozen=True) +class AuxLoad: + """2-D (M, N) row-major aux tensor. Maps to VisitorAuxLoad. + + Caller must guarantee ``stride[1] == 1`` and that ``stride[0]`` is 16-byte + aligned (cp.async requirement). + """ + + input_idx: int + dtype: str + kind: str = "aux_load" + + +# ── Compute nodes ───────────────────────────────────────────────────────────── + + +@dataclass(frozen=True) +class Compute: + """An interior fp32 elementwise op. + + Children are EVT subtrees (any of the leaf or compute types). + For SCALAR_UNARY_OPS, ``children`` has length 1 and ``scalar`` carries the + baked constant. + For UNARY_OPS, ``children`` has length 1, ``scalar`` is None. + For BINARY_OPS, ``children`` has length 2, ``scalar`` is None. + """ + + op: str + children: tuple + scalar: Optional[float] = None + kind: str = "compute" + + def __post_init__(self): + # Validate at construction time so codegen never sees a malformed IR. + if self.op not in ALL_OPS: + raise ValueError(f"Unknown EVT op: {self.op!r}") + if self.op in UNARY_OPS: + if len(self.children) != 1 or self.scalar is not None: + raise ValueError(f"UNARY op {self.op!r} requires 1 child, no scalar") + elif self.op in BINARY_OPS: + if len(self.children) != 2 or self.scalar is not None: + raise ValueError(f"BINARY op {self.op!r} requires 2 children, no scalar") + elif self.op in SCALAR_UNARY_OPS: + if len(self.children) != 1 or self.scalar is None: + raise ValueError(f"SCALAR_UNARY op {self.op!r} requires 1 child + scalar") + + +@dataclass(frozen=True) +class Store: + """Root of the IR. Casts the fp32 result to ``out_dtype`` and writes D.""" + + child: object # any IR node + out_dtype: str + kind: str = "store" + + def __post_init__(self): + if self.out_dtype not in DTYPES: + raise ValueError(f"Unknown out_dtype {self.out_dtype!r}") + + +# Union type alias for type hints. +IRNode = Union[Accum, RowBroadcast, ColBroadcast, AuxLoad, Compute, Store] + + +# ── Canonicalisation + serialisation ────────────────────────────────────────── + + +def to_dict(node) -> dict: + """Recursively convert an IR node tree into a JSON-friendly dict. + + The dict layout is designed for stable hashing: keys appear in a fixed + order and floats are formatted with ``repr`` so 1.702 vs 1.7020000001 + never collide. + """ + if isinstance(node, Accum): + return {"kind": "accum"} + if isinstance(node, RowBroadcast): + return {"kind": "row_bcast", "input_idx": node.input_idx, "dtype": node.dtype} + if isinstance(node, ColBroadcast): + return {"kind": "col_bcast", "input_idx": node.input_idx, "dtype": node.dtype} + if isinstance(node, AuxLoad): + return {"kind": "aux_load", "input_idx": node.input_idx, "dtype": node.dtype} + if isinstance(node, Compute): + d = {"kind": "compute", "op": node.op, "children": [to_dict(c) for c in node.children]} + if node.scalar is not None: + # repr of a float is round-trip-safe; explicitly stringify so JSON + # never serialises 1.7000000000000002. + d["scalar"] = repr(float(node.scalar)) + return d + if isinstance(node, Store): + return {"kind": "store", "out_dtype": node.out_dtype, "child": to_dict(node.child)} + raise TypeError(f"Unknown IR node type: {type(node).__name__}") + + +def to_canonical_json(node) -> str: + """Deterministic JSON string for an IR tree. Same IR ⇒ same string.""" + return json.dumps(to_dict(node), sort_keys=True, separators=(",", ":")) + + +def cache_key(node, a_dtype: str, b_dtype: str) -> str: + """SHA-256 hash of (IR JSON, A dtype, B dtype). Used as the JIT module key.""" + payload = {"ir": to_dict(node), "a": a_dtype, "b": b_dtype, "version": 1} + blob = json.dumps(payload, sort_keys=True, separators=(",", ":")).encode("utf-8") + return hashlib.sha256(blob).hexdigest() + + +# ── Tree walkers ────────────────────────────────────────────────────────────── + + +def walk_leaves(node) -> List: + """Return all leaf nodes (Accum / RowBroadcast / ColBroadcast / AuxLoad) + in left-to-right pre-order. Used by codegen to enumerate kernel inputs.""" + out: list = [] + + def _go(n): + if isinstance(n, (Accum, RowBroadcast, ColBroadcast, AuxLoad)): + out.append(n) + elif isinstance(n, Compute): + for c in n.children: + _go(c) + elif isinstance(n, Store): + _go(n.child) + else: + raise TypeError(f"Unknown IR node type: {type(n).__name__}") + + _go(node) + return out + + +def is_trivial(node) -> bool: + """An IR is trivial when ``Store(Accum)`` — no compute on the accumulator. + + Trivial IRs would replace cuBLAS with a more expensive kernel for no + benefit, so the FX pass should refuse to emit them. + """ + return isinstance(node, Store) and isinstance(node.child, Accum) + + +def num_extras(node) -> int: + """Maximum input_idx + 1 across all non-Accum leaves, or 0 if none.""" + indices: list = [leaf.input_idx for leaf in walk_leaves(node) if not isinstance(leaf, Accum)] + return max(indices) + 1 if indices else 0 diff --git a/magi_compiler/passes/piecewise_graph/fusion/blackwell_geforce/evt_runtime.py b/magi_compiler/passes/piecewise_graph/fusion/blackwell_geforce/evt_runtime.py new file mode 100644 index 0000000..41d034a --- /dev/null +++ b/magi_compiler/passes/piecewise_graph/fusion/blackwell_geforce/evt_runtime.py @@ -0,0 +1,457 @@ +# Copyright (c) 2026 SandAI. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Runtime side of the EVT fusion: torch.library op + JIT loader + dispatch. + +This file owns: + * The ``magi_epilogue::matmul_custom_evt`` torch.library op + fake impl. + * A process-level cache mapping IR JSON → compiled cpp_extension module. + * Dispatch to one of two backends: + - ``kind == "evt"`` → JIT-compiled CUTLASS Sm80EVT kernel. + - ``kind == "swiglu7_dual"`` → vendored DualGemm one-stage kernel. + +The kernel build directory uses the IR cache key as its name so re-runs and +multi-process Inductor compile workers all hit the same on-disk cache. +""" + +from __future__ import annotations + +import hashlib +import json +import os +import threading +from typing import Optional + +import torch + +from magi_compiler.config import get_compile_config + +from .evt_codegen import render_evt_cu +from .evt_ir import Accum, AuxLoad, ColBroadcast, Compute, RowBroadcast, Store + +# ── torch.library op definition ─────────────────────────────────────────────── +# Reuse the existing ``magi_epilogue`` library so all our custom matmul ops +# live under one namespace. Defining a fresh op here is harmless even if +# ``matmul_epilogue_fusion.py`` has already initialised the library. +_LIB = torch.library.Library("magi_epilogue", "FRAGMENT") +_LIB.define( + "matmul_custom_evt(Tensor A, Tensor B, Tensor[] extras, str ir_json," " str kind, int n_out, int out_dtype_id) -> Tensor" +) + + +# ── Output-dtype encoding (must round-trip through torch.library int args) ──── +_OUT_DTYPE_ID = {torch.bfloat16: 0, torch.float16: 1, torch.float32: 2} +_ID_TO_DTYPE = {v: k for k, v in _OUT_DTYPE_ID.items()} +_DTYPE_TO_STR = {torch.bfloat16: "bfloat16", torch.float16: "float16", torch.float32: "float32"} + + +def out_dtype_id(dt: torch.dtype) -> int: + """Encode a torch.dtype as a small int for inclusion in op args.""" + if dt not in _OUT_DTYPE_ID: + raise ValueError(f"Unsupported EVT output dtype {dt}") + return _OUT_DTYPE_ID[dt] + + +def out_dtype_from_id(i: int) -> torch.dtype: + return _ID_TO_DTYPE[i] + + +# ── Compile cache + per-key build lock ──────────────────────────────────────── +_MODULE_CACHE: dict = {} # cache_key (sha256 str) → loaded cpp_extension module +# Hot-path fast cache — avoids ``json.dumps + sha256`` (~10–30 μs/call) when +# the module has already been compiled. Keyed by the 4-tuple of (Python-) +# hashable inputs that uniquely determine the rendered .cu, since equality on +# the tuple is sufficient (no need to canonicalise twice). Populated on the +# slow path inside ``_compile_evt_module``. +_MODULE_FAST_CACHE: dict = {} # (ir_json, a_dtype, b_dtype, b_layout) → module +_MODULE_LOCKS: dict = {} # cache_key → threading.Lock +_MODULE_LOCKS_GLOBAL = threading.Lock() +_SWIGLU7_LOCK = threading.Lock() # serialises insertions into _SWIGLU7_FAST_CACHE + + +# ── D output-buffer cache ──────────────────────────────────────────────────── +# Single-entry greedy cache, keyed by (M, n_out, dtype, device_idx). The hot +# path in ``_matmul_custom_evt_cuda`` reads/writes this dict directly (the +# resolver was inlined for ~1 μs/call savings), so this module only owns the +# storage and a disable switch. +# +# FX-pass guards (K % 8 == 0; generic N % 4 == 0; swiglu7 N % 8 == 0) ensure +# n_out is always a multiple of CUTLASS's AlignmentC = 4 elements, so D is +# always allocated as a true-contiguous ``torch.empty((M, n_out), dtype)`` — +# no padded stride / scratch buffer route exists. Anything that violates the +# guards is rejected upstream and falls back to torch.compile's default mm. +# +# To opt out (e.g. when bench-scripting with overlapping streams), set the +# env var ``MAGI_EVT_DISABLE_D_CACHE=1``. +_D_BUF_CACHE: dict = {} +_D_CACHE_DISABLED: bool = os.environ.get("MAGI_EVT_DISABLE_D_CACHE", "0") not in ("0", "", "false", "False") + + +def _cutlass_root() -> str: + # Default install location is /opt/cutlass (Dockerfile clones the source + # tree there). Override with MAGI_CUTLASS_ROOT for ad-hoc dev checkouts. + return os.environ.get("MAGI_CUTLASS_ROOT", "/opt/cutlass") + + +def _evt_build_dir(key: str) -> str: + cache_root = get_compile_config().cache_root_dir + return os.path.join(cache_root, "evt_kernels", key) + + +def _per_key_lock(key: str) -> threading.Lock: + """Return the per-key build lock; coalesces concurrent compile requests.""" + with _MODULE_LOCKS_GLOBAL: + lock = _MODULE_LOCKS.get(key) + if lock is None: + lock = threading.Lock() + _MODULE_LOCKS[key] = lock + return lock + + +def _compile_evt_module( + ir_json: str, + a_dtype: torch.dtype, + b_dtype: torch.dtype, + b_layout: str = "row", + m_bucket: str = "medium", + N: int = 0, + K: int = 0, +): + """Render + JIT-compile the EVT kernel for ``ir_json``. Process-level cached. + + Cache key: (IR, A dtype, B dtype, b_layout, m_bucket, N, K). Each distinct + weight (N, K) lowers to its own .cu — even though the .cu source is + identical (N/K stay runtime variables), splitting the modules gives every + (N, K) its own runner instance with isolated `best_idx_`. This avoids + cross-(N, K) autotune contamination and matches the user's per-(N, K) + cache layout: e.g. two distinct (N, K) × two M-buckets ⇒ 4 .cu modules. + """ + # Hot-path fast cache: skip ``json.dumps + sha256`` (~10–30 μs each) on + # subsequent calls with the same inputs. + fast_key = (ir_json, a_dtype, b_dtype, b_layout, m_bucket, N, K) + cached = _MODULE_FAST_CACHE.get(fast_key) + if cached is not None: + return cached + + if b_layout not in ("row", "col"): + raise ValueError(f"b_layout must be 'row' or 'col', got {b_layout!r}") + a_str = _DTYPE_TO_STR[a_dtype] + b_str = _DTYPE_TO_STR[b_dtype] + extended = json.dumps( + { + "ir": ir_json, + "a": a_str, + "b": b_str, + "b_layout": b_layout, + "m_bucket": m_bucket, + "N": int(N), + "K": int(K), + "version": 3, + }, + sort_keys=True, + ).encode("utf-8") + key = hashlib.sha256(extended).hexdigest() + + cached = _MODULE_CACHE.get(key) + if cached is not None: + _MODULE_FAST_CACHE[fast_key] = cached + return cached + + lock = _per_key_lock(key) + with lock: + cached = _MODULE_CACHE.get(key) + if cached is not None: + _MODULE_FAST_CACHE[fast_key] = cached + return cached + + # Re-hydrate the IR tree from JSON for codegen. + ir = _ir_from_json(ir_json) + src = render_evt_cu(ir, a_str, b_str, cache_key_str=key, b_layout=b_layout, m_bucket=m_bucket) + + build_dir = _evt_build_dir(key) + os.makedirs(build_dir, exist_ok=True) + src_path = os.path.join(build_dir, "evt.cu") + # Write atomically (tmp + rename) so concurrent processes don't see a + # half-written file. Use a process-specific tmp name to avoid races + # across multiple rank processes generating the same kernel. + tmp_path = f"{src_path}.{os.getpid()}.tmp" + with open(tmp_path, "w") as f: + f.write(src) + os.replace(tmp_path, src_path) + + cutlass_root = _cutlass_root() + from torch.utils.cpp_extension import load + + # cpp_extension.load uses its own file lock under build_directory, so + # multi-process races resolve to a single nvcc invocation. + module = load( + name=f"magi_evt_{key[:12]}", + sources=[src_path], + extra_include_paths=[ + os.path.join(cutlass_root, "include"), + os.path.join(cutlass_root, "tools", "util", "include"), + ], + extra_cflags=["-O3", "-std=c++17"], + extra_cuda_cflags=["-std=c++17", "-O3", "--expt-relaxed-constexpr", "-gencode=arch=compute_120,code=sm_120"], + build_directory=build_dir, + verbose=False, + ) + _MODULE_CACHE[key] = module + _MODULE_FAST_CACHE[fast_key] = module + return module + + +# ── IR (de)serialisation ───────────────────────────────────────────────────── + + +def to_ir_json(node) -> str: + from .evt_ir import to_canonical_json + + return to_canonical_json(node) + + +def _ir_from_json(s: str): + """Inverse of ``to_canonical_json``. Used only to drive codegen at compile + time — the FX pass holds the original Python objects and never round-trips + its own IR through JSON in a hot loop.""" + d = json.loads(s) + return _node_from_dict(d) + + +def _node_from_dict(d): + kind = d["kind"] + if kind == "accum": + return Accum() + if kind == "row_bcast": + return RowBroadcast(input_idx=d["input_idx"], dtype=d["dtype"]) + if kind == "col_bcast": + return ColBroadcast(input_idx=d["input_idx"], dtype=d["dtype"]) + if kind == "aux_load": + return AuxLoad(input_idx=d["input_idx"], dtype=d["dtype"]) + if kind == "compute": + scalar = d.get("scalar") + scalar_val: Optional[float] = float(scalar) if scalar is not None else None + return Compute(op=d["op"], children=tuple(_node_from_dict(c) for c in d["children"]), scalar=scalar_val) + if kind == "store": + return Store(child=_node_from_dict(d["child"]), out_dtype=d["out_dtype"]) + raise ValueError(f"Unknown IR kind {kind!r}") + + +# ── swiglu7 dual-gemm extension loader ──────────────────────────────────────── +# Per-(m_bucket, N, K) cache. The .cu source is identical across keys (N/K stay +# runtime variables); we still build separate modules so each runner instance +# hosts exactly one (N, K), giving every weight shape its own isolated +# best_idx_. Two distinct (N, K) × two M-buckets ⇒ 4 modules. +_SWIGLU7_FAST_CACHE: dict = {} # (m_bucket, N, K) → loaded module +_SWIGLU7_BUILD_LOCKS: dict = {} # (m_bucket, N, K) → threading.Lock + + +def _compile_swiglu7_dual(m_bucket: str, N: int, K: int): + """Lazy-load a per-(bucket, N, K) instance of the vendored DualGemm kernel. + + Parameters + ---------- + m_bucket : "small" | "medium" | "large" + Bucket of the activation M dim — included in the cache key so e.g. + small-M (decode) can autotune to a different best tile than large-M + (prefill) for the same (N, K). + N, K : int + Static weight shape from B (the underlying (N, K) row-major tensor). + Distinct (N, K) get distinct modules so their autotune state is + independent. + """ + fast_key = (m_bucket, int(N), int(K)) + cached = _SWIGLU7_FAST_CACHE.get(fast_key) + if cached is not None: + return cached + + with _SWIGLU7_LOCK: + lock = _SWIGLU7_BUILD_LOCKS.get(fast_key) + if lock is None: + lock = threading.Lock() + _SWIGLU7_BUILD_LOCKS[fast_key] = lock + with lock: + cached = _SWIGLU7_FAST_CACHE.get(fast_key) + if cached is not None: + return cached + + cutlass_root = _cutlass_root() + here = os.path.dirname(os.path.abspath(__file__)) + src = os.path.join(here, "cutlass_kernels", "swiglu7_epi_one_stage.cu") + if not os.path.exists(src): + raise FileNotFoundError(f"vendored swiglu7 source not found: {src}") + cache_root = get_compile_config().cache_root_dir + # Build dir embeds (bucket, N, K) so distinct keys get their own + # build artefacts. cpp_extension uses the dir as the cache identity. + build_tag = f"{m_bucket}_N{N}_K{K}" + build_dir = os.path.join(cache_root, "evt_kernels", f"swiglu7_dual_{build_tag}") + os.makedirs(build_dir, exist_ok=True) + from torch.utils.cpp_extension import load + + module = load( + name=f"magi_swiglu7_dual_{build_tag}", + sources=[src], + extra_include_paths=[ + os.path.join(cutlass_root, "include"), + os.path.join(cutlass_root, "tools", "util", "include"), + os.path.join(cutlass_root, "examples"), + os.path.join(here, "cutlass_kernels"), + ], + extra_cflags=["-O3", "-std=c++17"], + extra_cuda_cflags=["-std=c++17", "-O3", "--expt-relaxed-constexpr", "-gencode=arch=compute_120,code=sm_120"], + build_directory=build_dir, + verbose=False, + ) + _SWIGLU7_FAST_CACHE[fast_key] = module + return module + + +# ── torch.library backend impls ─────────────────────────────────────────────── + + +# ── Dispatch fast-cache ────────────────────────────────────────────────────── +# Hot-path bottleneck reduction: collapse the four-step +# out_dtype_from_id → _m_bucket → _compile_* → mod.attr-lookup +# chain into a single dict.get() returning a pre-bound callable plus the +# small amount of immutable metadata the kernel-launch site needs. +# +# Key shape: (kind, ir_json, A.dtype, B.dtype, N, K, m_bucket, out_dtype). +# Most of these are static per FX-emit site (kind / ir_json / dtypes / N / K) +# — only m_bucket varies with M. So the cache reaches steady state after the +# first time each (site, bucket) is seen. +# +# Each entry holds: +# * kernel_call : pre-bound mod.evt_matmul_out / swiglu7_dual_matmul_out +# * is_evt : True for evt_row/evt_col (need extras list), False for swiglu7 +# * out_dtype : torch.dtype to pass to D allocation +class _DispatchEntry: + __slots__ = ("kernel_call", "is_evt", "out_dtype") + + def __init__(self, kernel_call, is_evt, out_dtype): + self.kernel_call = kernel_call + self.is_evt = is_evt + self.out_dtype = out_dtype + + +_DISPATCH_CACHE: dict = {} + + +def _resolve_dispatch(kind, ir_json, a_dtype, b_dtype, N_w, K_w, m_bucket, out_dtype): + """Slow-path resolver — compiles the .cu module (cache miss) and binds + the kernel callable. Cached by (kind, ir_json, A_dt, B_dt, N, K, bucket, + out_dtype) so each FX site × bucket only pays this once.""" + if kind == "swiglu7_dual": + mod = _compile_swiglu7_dual(m_bucket, N_w, K_w) + return _DispatchEntry(mod.swiglu7_dual_matmul_out, False, out_dtype) + if kind == "evt_row" or kind == "evt": + b_layout = "row" + elif kind == "evt_col": + b_layout = "col" + else: + raise ValueError(f"Unknown EVT kind {kind!r}") + mod = _compile_evt_module(ir_json, a_dtype, b_dtype, b_layout=b_layout, m_bucket=m_bucket, N=N_w, K=K_w) + return _DispatchEntry(mod.evt_matmul_out, True, out_dtype) + + +@torch.library.impl(_LIB, "matmul_custom_evt", "CUDA") +def _matmul_custom_evt_cuda(A, B, extras, ir_json, kind, n_out, out_dtype_id_): + """Runtime entry point for the EVT-fused matmul op. + + Hot path is heavily inlined to keep per-call Python overhead under ~2 μs: + one dict.get() resolves the kernel callable + metadata, then we allocate D + (with a single-entry greedy cache) and call straight into the C++ kernel. + + Layout contract — the FX pass owns this; do not rewrite operands here: + * ``kind == "evt_row"`` : B is contiguous (K, N) row-major. + * ``kind == "evt_col"`` : B is the underlying (N, K) row-major weight; the + kernel was rendered with ``LayoutB = ColumnMajor`` so it reads (K, N) + from the same bytes via stride (1, K). + * ``kind == "swiglu7_dual"`` : B is the underlying (N, K) row-major weight + (the FX pass already replaced the ``permute([1,0])`` view with its + operand). The DualGemm kernel reads it as ColumnMajor + ldB=2K. + + Calling ``.contiguous()`` on B here would silently break the col / swiglu7 + paths by materialising a (K, N) row-major copy that no longer matches the + LayoutB the kernel was compiled with — every B value would be wrong. + """ + # ── Step 1: resolve dispatch entry (one dict lookup on the fast path) ── + # B.size(0)/size(1) are slightly faster than .shape[0]/[1] (avoid Python + # tuple construction). For all 3 kinds B's leading dim ≠ K — the launcher + # / runner derives N internally from b_layout, but for the dispatch cache + # key we just need a stable per-site discriminator, so passing the raw + # B.size pair is enough. + B_size0 = B.size(0) + B_size1 = B.size(1) + M = A.size(0) + # Inline _m_bucket: avoid the ~300 ns function call. + if M <= 256: + m_bucket = "small" + elif M <= 2048: + m_bucket = "medium" + else: + m_bucket = "large" + # Inline out_dtype_from_id: skip the function call frame. + out_dtype = _ID_TO_DTYPE[out_dtype_id_] + # B's (N, K) interpretation depends on kind. For evt_row B is (K, N), + # for evt_col / swiglu7_dual B is the underlying (N, K). Either way we + # only need (B_size0, B_size1) to disambiguate distinct weights — the + # resolver re-computes N/K correctly for compilation. + a_dtype = A.dtype + b_dtype_ = B.dtype + fast_key = (kind, ir_json, a_dtype, b_dtype_, B_size0, B_size1, m_bucket, out_dtype) + entry = _DISPATCH_CACHE.get(fast_key) + if entry is None: + # Map B sizes to (N_w, K_w) in the layout the compile path expects. + if kind == "evt_row": + K_w, N_w = B_size0, B_size1 + else: + # evt_col / swiglu7_dual: B is (N, K) underlying weight. + N_w, K_w = B_size0, B_size1 + entry = _resolve_dispatch(kind, ir_json, a_dtype, b_dtype_, N_w, K_w, m_bucket, out_dtype) + _DISPATCH_CACHE[fast_key] = entry + + # ── Step 2: alloc / fetch D (greedy single-entry cache, inlined) ── + # FX pass guards (K % 8 == 0; generic N % 4 == 0; swiglu7 N % 8 == 0) + # ensure n_out is a multiple of CUTLASS AlignmentC = 4 for every dtype, + # so a plain ``torch.empty((M, n_out), dtype)`` is already CUTLASS- + # contiguous — no padded stride / scratch buffer route is required. + # Anything that violates the guards is rejected upstream and falls back + # to torch.compile's default mm. + if _D_CACHE_DISABLED: + D = torch.empty((M, n_out), device=A.device, dtype=out_dtype) + else: + dev_idx = A.device.index or 0 + d_key = (M, n_out, out_dtype, dev_idx) + D = _D_BUF_CACHE.get(d_key) + if D is None: + D = torch.empty((M, n_out), device=A.device, dtype=out_dtype) + _D_BUF_CACHE.clear() + _D_BUF_CACHE[d_key] = D + + # ── Step 3: dispatch — pre-bound callable, single C++ trampoline ── + kernel_call = entry.kernel_call + if entry.is_evt: + kernel_call(A, B, extras, D) + else: + # swiglu7_dual: extras is always [] here (FX pass guarantees). + kernel_call(A, B, D) + return D + + +@torch.library.register_fake("magi_epilogue::matmul_custom_evt") +def _matmul_custom_evt_fake(A, B, extras, ir_json, kind, n_out, out_dtype_id_): + out_dtype = out_dtype_from_id(out_dtype_id_) + # Contiguous (M, n_out) — see _D_BUF_CACHE comment for why padding is + # never needed under the FX-pass alignment guards. + return A.new_empty((A.shape[0], n_out), dtype=out_dtype) diff --git a/magi_compiler/passes/piecewise_graph/fusion/blackwell_geforce/matmul_epilogue_fusion.py b/magi_compiler/passes/piecewise_graph/fusion/blackwell_geforce/matmul_epilogue_fusion.py new file mode 100644 index 0000000..e88b386 --- /dev/null +++ b/magi_compiler/passes/piecewise_graph/fusion/blackwell_geforce/matmul_epilogue_fusion.py @@ -0,0 +1,719 @@ +# Copyright (c) 2026 SandAI. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""FX pass that fuses aten.mm + elementwise epilogue into a CUTLASS EVT call. + +Two backends: + * Generic EVT — for the 6 non-swiglu activations and 1-D bias/scale variants. + Builds an IR tree (see ``evt_ir.py``), serialises to JSON, replaces the + matched chain with a single ``torch.ops.magi_epilogue.matmul_custom_evt`` + call. The runtime renders + JIT-compiles a CUTLASS Sm80EVT kernel keyed by + the IR hash (see ``evt_runtime.py``). + * swiglu7 — pattern-matches the canonical recipe (slice-stride-2 + dual + clamps + scaled SiLU) and dispatches to a vendored DualGemm one-stage + kernel that writes (M, N/2) directly. + +Eligibility gates (alignment, B layout, dtype) are checked up-front. Anything +not eligible stays as ``aten.mm`` for cuBLAS to handle. We do NOT fall back to +the Triton fusion path on sm120; per user decision, EVT replaces it entirely. +""" + +from __future__ import annotations + +import operator +from typing import List, Optional, Tuple + +import torch +import torch.fx as fx + +from magi_compiler.cuda.device import device_capability_major +from magi_compiler.passes.pass_base import MagiInductorPass + +from . import evt_runtime # ensures torch.library op + fake impl are registered +from .evt_ir import Accum, AuxLoad, ColBroadcast, Compute, RowBroadcast, Store, is_trivial, num_extras, to_canonical_json + +# ── Op tables ──────────────────────────────────────────────────────────────── +# Pure passthrough — no value or dtype change; alias the same IR node. +_PASSTHROUGH_OPS = frozenset({torch.ops.aten.clone.default, torch.ops.aten.contiguous.default, torch.ops.aten.alias.default}) + +# Dtype-conversion ops; the EVT compute is always fp32 internally so these are +# absorbed as no-ops as long as the start/end of the chain reach the same final +# precision (we capture that via the Store node's out_dtype). +_TYPE_CONV_OPS = frozenset({torch.ops.prims.convert_element_type.default, torch.ops.aten._to_copy.default}) + +# Unary ops with a direct EVT IR equivalent. +_UNARY_OPS = { + torch.ops.aten.neg.default: "neg", + torch.ops.aten.sigmoid.default: "sigmoid", + torch.ops.aten.tanh.default: "tanh", + torch.ops.aten.silu.default: "silu", + torch.ops.aten.relu.default: "relu", + torch.ops.aten.square.default: "square", + torch.ops.aten.erf.default: "erf", + torch.ops.aten.exp.default: "exp", + torch.ops.aten.log.default: "log", + torch.ops.aten.sqrt.default: "sqrt", + torch.ops.aten.rsqrt.default: "rsqrt", + torch.ops.aten.abs.default: "abs", +} + +# Binary tensor ops. +_BINARY_OPS = { + torch.ops.aten.add.Tensor: "add", + torch.ops.aten.sub.Tensor: "sub", + torch.ops.aten.mul.Tensor: "mul", + torch.ops.aten.div.Tensor: "div", + torch.ops.aten.maximum.default: "max", + torch.ops.aten.minimum.default: "min", + operator.add: "add", + operator.sub: "sub", + operator.mul: "mul", + operator.truediv: "div", +} + +# Scalar binary ops → SCALAR_UNARY_OPS in IR. +_SCALAR_BINARY_TO_SCALAR_UNARY = { + torch.ops.aten.add.Scalar: "add_scalar", + torch.ops.aten.sub.Scalar: "sub_scalar", + torch.ops.aten.mul.Scalar: "mul_scalar", + torch.ops.aten.div.Scalar: "div_scalar", +} + + +# Output-dtype encode helper (mirrors evt_runtime). +_DTYPE_TO_STR = {torch.bfloat16: "bfloat16", torch.float16: "float16", torch.float32: "float32"} + + +def _val_dtype(node) -> Optional[torch.dtype]: + val = node.meta.get("val") if isinstance(node, fx.Node) else None + return val.dtype if val is not None else None + + +def _val_shape(node) -> Optional[Tuple]: + val = node.meta.get("val") if isinstance(node, fx.Node) else None + return tuple(val.shape) if val is not None else None + + +def _val_stride(node) -> Optional[Tuple]: + val = node.meta.get("val") if isinstance(node, fx.Node) else None + try: + return tuple(val.stride()) if val is not None else None + except Exception: + return None + + +def _is_static_int(x) -> bool: + return type(x) is int + + +def _is_transpose_node(n) -> bool: + """True iff ``n`` is a 2-D transpose (aten.t / transpose(0,1) / permute([1,0])).""" + if not isinstance(n, fx.Node) or n.op != "call_function": + return False + if n.target is torch.ops.aten.t.default: + return True + if n.target is torch.ops.aten.transpose.int: + # transpose(x, dim0, dim1) — accept (0, 1) on a 2D tensor. + if len(n.args) >= 3: + d0, d1 = n.args[1], n.args[2] + return {d0, d1} == {0, 1} + return False + if n.target is torch.ops.aten.permute.default: + # permute(x, [1, 0]) on a 2D tensor. + if len(n.args) >= 2: + perm = n.args[1] + return list(perm) == [1, 0] + return False + return False + + +def _b_layout_kind(B_node): + """Classify B for the EVT generic path. + + Returns (b_layout, underlying_b_node, n_dim) where: + * b_layout = "row" : B is (K, N) row-major contiguous; pass B as-is. + * b_layout = "col" : B is a stride-transpose of a contiguous (N, K) + tensor; pass the underlying tensor; kernel uses + LayoutB=ColumnMajor. + * (None, None, None) : B is not in a supported layout. + """ + shape = _val_shape(B_node) + stride = _val_stride(B_node) + if shape is None or stride is None or len(shape) != 2: + return None, None, None + K_or_N0, N_or_K1 = shape[0], shape[1] + # Contiguous (K, N): row layout. N = shape[1]. + if stride == (N_or_K1, 1): + return "row", B_node, N_or_K1 + # Stride-transposed (K, N) view of a contig (N, K) weight: stride == (1, K). + # The underlying tensor is the transpose-producer's input when the FX + # graph models the view explicitly via t/transpose/permute([1,0]); fall + # back to using B itself (its data_ptr is the same). + if _is_transpose_node(B_node): + weight = B_node.args[0] + w_shape = _val_shape(weight) if isinstance(weight, fx.Node) else None + w_stride = _val_stride(weight) if isinstance(weight, fx.Node) else None + if w_shape is not None and len(w_shape) == 2 and w_stride == (w_shape[1], 1): + # weight is (N, K) row-major contig; N = w_shape[0]. + return "col", weight, w_shape[0] + # Generic stride-transposed view (no explicit transpose node) — also OK: + # we read the same memory bytes as a (N, K) row-major buffer at B itself. + if stride == (1, K_or_N0): + # B is (K, N) col-major == underlying (N, K) row-major. We don't have + # an explicit weight node so we pass B directly; the kernel reads + # (N, K) with N = shape[1], K = shape[0]. Detection via stride alone. + return "col", B_node, N_or_K1 + return None, None, None + + +# ── Pass ───────────────────────────────────────────────────────────────────── + + +# Sentinel returned by _try_fuse_evt to communicate "abort, leave mm intact". +_ABORT = object() + + +class MatmulEvtEpilogueFusionPass(MagiInductorPass): + """Fuse aten.mm + elementwise chain into a CUTLASS EVT call (sm_120).""" + + def __init__(self, allow_extras: bool = True) -> None: + # On non-sm120 we degrade to a no-op; the manager wires us only on + # sm120 anyway, but defending against misuse is cheap. + self._enabled = device_capability_major() >= 12 + self.allow_extras = allow_extras + + def __call__(self, graph: fx.Graph) -> bool: + if not self._enabled: + return False + fused = 0 + for node in list(graph.nodes): + if node.op != "call_function": + continue + if node.target not in (torch.ops.aten.mm.default, torch.ops.aten.mm): + continue + r = self._try_fuse_evt(graph, node) + if r: + fused += 1 + if fused: + graph.eliminate_dead_code() + return fused > 0 + + # ── Generic EVT chain walker ────────────────────────────────────────────── + + def _try_fuse_evt(self, graph: fx.Graph, mm_node: fx.Node) -> bool: + A, B = mm_node.args[0], mm_node.args[1] + if not isinstance(A, fx.Node) or not isinstance(B, fx.Node): + return False + a_dtype = _val_dtype(A) + b_dtype = _val_dtype(B) + if a_dtype not in (torch.bfloat16, torch.float16) or a_dtype != b_dtype: + return False + # Alignment gates — bf16/fp16 require K % 8. + a_shape = _val_shape(A) + b_shape = _val_shape(B) + if a_shape is None or b_shape is None or len(a_shape) != 2 or len(b_shape) != 2: + return False + K = a_shape[1] + N = b_shape[1] + if _is_static_int(K) and (K % 8 != 0): + return False + if _is_static_int(N) and (N % 4 != 0): + return False + + # node_to_ir: each fused fx.Node → its IR subtree. mm_node maps to Accum. + node_to_ir: dict = {mm_node: Accum()} + # In-order list of fused fx nodes (for erase + escape detection). + fused_nodes: List[fx.Node] = [mm_node] + # Walked-and-removed nodes including type-conv/passthrough that don't + # appear in node_to_ir as new IR nodes (they alias their input). + walk_seen: List[fx.Node] = [mm_node] + # External tensors injected as RowBroadcast/ColBroadcast/AuxLoad leaves. + # extras_nodes[i] is the fx.Node passed at runtime as extras[i]. + extras_nodes: List[fx.Node] = [] + # Tracks whether the IR has any swiglu7-style slice. If so we abort + # generic EVT and try the swiglu7 matcher instead. + saw_slice = False + + last_node = mm_node + last_ir = node_to_ir[mm_node] + + # Walk consumers in source order, greedily absorbing supported ops. + curr = mm_node.next + while curr is not None and curr.op != "output": + uses_fused = any(isinstance(a, fx.Node) and a in node_to_ir for a in curr.args) + if not uses_fused: + curr = curr.next + continue + + target = curr.target + + # ── Pass-through (clone / contiguous / alias) ───────────────────── + if target in _PASSTHROUGH_OPS: + node_to_ir[curr] = node_to_ir[curr.args[0]] + walk_seen.append(curr) + last_node = curr + last_ir = node_to_ir[curr] + curr = curr.next + continue + + # ── Type conversion (no-op in fp32 EVT) ─────────────────────────── + if target in _TYPE_CONV_OPS: + node_to_ir[curr] = node_to_ir[curr.args[0]] + walk_seen.append(curr) + last_node = curr + last_ir = node_to_ir[curr] + curr = curr.next + continue + + # ── Pure view ops (only if shape unchanged) ─────────────────────── + if target in (torch.ops.aten.view.default, torch.ops.aten.reshape.default, torch.ops.aten._unsafe_view.default): + in_shape = _val_shape(curr.args[0]) + out_shape = _val_shape(curr) + if in_shape == out_shape: + node_to_ir[curr] = node_to_ir[curr.args[0]] + walk_seen.append(curr) + last_node = curr + last_ir = node_to_ir[curr] + curr = curr.next + continue + break + + # ── Slice stride-2 (swiglu marker) ──────────────────────────────── + if target is torch.ops.aten.slice.Tensor: + step = curr.args[4] if len(curr.args) > 4 else curr.kwargs.get("step", 1) + if step == 2: + saw_slice = True + break + + # ── Unary ops ───────────────────────────────────────────────────── + if target in _UNARY_OPS: + op_name = _UNARY_OPS[target] + child_ir = node_to_ir[curr.args[0]] + ir = Compute(op_name, (child_ir,)) + node_to_ir[curr] = ir + fused_nodes.append(curr) + walk_seen.append(curr) + last_node = curr + last_ir = ir + curr = curr.next + continue + + # ── GELU (default = erf, alternative = tanh) ────────────────────── + if target is torch.ops.aten.gelu.default: + approx = curr.kwargs.get("approximate", "none") + op_name = "gelu_tanh" if approx == "tanh" else "gelu_erf" + child_ir = node_to_ir[curr.args[0]] + ir = Compute(op_name, (child_ir,)) + node_to_ir[curr] = ir + fused_nodes.append(curr) + walk_seen.append(curr) + last_node = curr + last_ir = ir + curr = curr.next + continue + + # ── Scalar variants of add/sub/mul/div ──────────────────────────── + if target in _SCALAR_BINARY_TO_SCALAR_UNARY: + op_name = _SCALAR_BINARY_TO_SCALAR_UNARY[target] + child_ir = node_to_ir[curr.args[0]] + if not isinstance(curr.args[1], (int, float)): + break + scalar = float(curr.args[1]) + ir = Compute(op_name, (child_ir,), scalar=scalar) + node_to_ir[curr] = ir + fused_nodes.append(curr) + walk_seen.append(curr) + last_node = curr + last_ir = ir + curr = curr.next + continue + + # ── Clamp family ────────────────────────────────────────────────── + if target in (torch.ops.aten.clamp.default, torch.ops.aten.clamp_min.default, torch.ops.aten.clamp_max.default): + child_ir = node_to_ir[curr.args[0]] + if target is torch.ops.aten.clamp_min.default: + lo = curr.args[1] if len(curr.args) > 1 else curr.kwargs.get("min") + hi = None + elif target is torch.ops.aten.clamp_max.default: + lo = None + hi = curr.args[1] if len(curr.args) > 1 else curr.kwargs.get("max") + else: + lo = curr.args[1] if len(curr.args) > 1 else curr.kwargs.get("min") + hi = curr.args[2] if len(curr.args) > 2 else curr.kwargs.get("max") + if (lo is not None and not isinstance(lo, (int, float))) or ( + hi is not None and not isinstance(hi, (int, float)) + ): + break + ir_now = child_ir + if lo is not None: + ir_now = Compute("clamp_min_c", (ir_now,), scalar=float(lo)) + if hi is not None: + ir_now = Compute("clamp_max_c", (ir_now,), scalar=float(hi)) + node_to_ir[curr] = ir_now + fused_nodes.append(curr) + walk_seen.append(curr) + last_node = curr + last_ir = ir_now + curr = curr.next + continue + + # ── pow.Tensor_Scalar — only the small-int special-cases ────────── + if target is torch.ops.aten.pow.Tensor_Scalar: + exp = curr.args[1] if len(curr.args) > 1 else None + child_ir = node_to_ir[curr.args[0]] + if exp == 2 or exp == 2.0: + ir = Compute("square", (child_ir,)) + elif isinstance(exp, (int, float)): + ir = Compute("pow_scalar", (child_ir,), scalar=float(exp)) + else: + break + node_to_ir[curr] = ir + fused_nodes.append(curr) + walk_seen.append(curr) + last_node = curr + last_ir = ir + curr = curr.next + continue + + # ── Binary tensor ops ───────────────────────────────────────────── + if target in _BINARY_OPS: + op_name = _BINARY_OPS[target] + lhs_raw = curr.args[0] + rhs_raw = curr.args[1] + # Fold int/float scalars on the RHS to scalar variants. + if isinstance(rhs_raw, (int, float)) and isinstance(lhs_raw, fx.Node) and lhs_raw in node_to_ir: + scalar_op = {"add": "add_scalar", "sub": "sub_scalar", "mul": "mul_scalar", "div": "div_scalar"}.get( + op_name + ) + if scalar_op is None: + break + ir = Compute(scalar_op, (node_to_ir[lhs_raw],), scalar=float(rhs_raw)) + node_to_ir[curr] = ir + fused_nodes.append(curr) + walk_seen.append(curr) + last_node = curr + last_ir = ir + curr = curr.next + continue + # Fold scalar-on-LHS for commutative ops; for sub/div we need rsub/rdiv. + if isinstance(lhs_raw, (int, float)) and isinstance(rhs_raw, fx.Node) and rhs_raw in node_to_ir: + if op_name in ("add", "mul"): + scalar_op = "add_scalar" if op_name == "add" else "mul_scalar" + ir = Compute(scalar_op, (node_to_ir[rhs_raw],), scalar=float(lhs_raw)) + elif op_name == "sub": + ir = Compute("rsub_scalar", (node_to_ir[rhs_raw],), scalar=float(lhs_raw)) + else: + break + node_to_ir[curr] = ir + fused_nodes.append(curr) + walk_seen.append(curr) + last_node = curr + last_ir = ir + curr = curr.next + continue + # Both tensor — either internal (already in IR) or external. + lhs_ir = self._ir_for_arg(lhs_raw, node_to_ir, extras_nodes, A, B) + rhs_ir = self._ir_for_arg(rhs_raw, node_to_ir, extras_nodes, A, B) + if lhs_ir is None or rhs_ir is None: + break + ir = Compute(op_name, (lhs_ir, rhs_ir)) + node_to_ir[curr] = ir + fused_nodes.append(curr) + walk_seen.append(curr) + last_node = curr + last_ir = ir + curr = curr.next + continue + + # Unsupported op — stop greedy walk. + break + + # If we saw a stride-2 slice and the chain is plausibly swiglu7, try + # the dedicated matcher. It rebuilds independently from mm_node. + if saw_slice: + return self._try_fuse_swiglu7(graph, mm_node) + + # Verify we made progress. + if last_ir is node_to_ir[mm_node]: + return False # only Accum — replacing cuBLAS with EVT is no win + + # Refuse if any escape: an intermediate fused node is consumed outside + # the fused region. (EVT has no "extra outputs"; the user explicitly + # opted out of cross-domain fan-out.) + # + # The exclusion ``n is not last_node`` is intentional — the last node + # in the fused chain becomes the EVT op's output and is allowed to + # have downstream consumers (that's the whole point of fusion). + # Earlier writes ([:-1] explicitly skips the last position) must not + # have any external user, otherwise the fused chain would silently + # drop their value. This previously read ``walk_seen[:-0]`` which is + # ``walk_seen[:0]`` (an empty slice!) so escape detection was a no-op + # and trivially-fusable chains like ``mm → add(residual) → square`` + # were emitted even when ``add(residual)`` was reused downstream. + fused_set = set(fused_nodes) | set(walk_seen) + for n in walk_seen[:-1]: + for u in n.users: + if u not in fused_set: + return False + + # Final eligibility check: A contiguous, B in a supported layout. + a_stride = _val_stride(A) + if a_stride is None: + return False + a_shape_now = _val_shape(A) + if a_stride != (a_shape_now[1], 1): + return False + b_layout, b_underlying, n_dim = _b_layout_kind(B) + if b_layout is None: + return False + + # Determine output dtype from the last fused node's FakeTensor metadata. + out_dt = _val_dtype(last_node) or torch.bfloat16 + if out_dt not in _DTYPE_TO_STR: + return False + + ir_root = Store(child=last_ir, out_dtype=_DTYPE_TO_STR[out_dt]) + if is_trivial(ir_root): + return False + # If extras are disabled, refuse any IR that needs them. + if not self.allow_extras and num_extras(ir_root) > 0: + return False + + ir_json = to_canonical_json(ir_root) + n_out = n_dim + out_dt_id = evt_runtime.out_dtype_id(out_dt) + kind = "evt_row" if b_layout == "row" else "evt_col" + + with graph.inserting_after(last_node): + new_node = graph.call_function( + torch.ops.magi_epilogue.matmul_custom_evt.default, + args=(A, b_underlying, extras_nodes, ir_json, kind, n_out, out_dt_id), + ) + # Propagate FakeTensor meta so downstream Inductor checks pass. + try: + val_last = last_node.meta.get("val") + if val_last is not None: + # Propagate but with 128B-aligned stride matching what the + # CUDA impl actually returns. + new_val = val_last.new_empty_strided( + val_last.shape, (evt_runtime._aligned_n_stride(int(val_last.shape[-1]), val_last.dtype), 1) + ) + new_node.meta["val"] = new_val + except Exception: + pass + + last_node.replace_all_uses_with(new_node) + for n in reversed(walk_seen): + if len(n.users) == 0 and n is not new_node: + graph.erase_node(n) + return True + + def _ir_for_arg(self, arg, node_to_ir, extras_nodes, A_node, B_node): + """Return an IR subtree for a binary-op operand. Internal → IR; external + → leaf (RowBroadcast / ColBroadcast / AuxLoad). None ⇒ abort.""" + if not isinstance(arg, fx.Node): + return None + if arg in node_to_ir: + return node_to_ir[arg] + if not self.allow_extras: + return None + # Classify external tensor by shape relative to (M, N). + a_shape = _val_shape(A_node) + b_shape = _val_shape(B_node) + if a_shape is None or b_shape is None: + return None + M = a_shape[0] + N = b_shape[1] + shape = _val_shape(arg) + stride = _val_stride(arg) + dt = _val_dtype(arg) + if shape is None or dt is None: + return None + dt_str = _DTYPE_TO_STR.get(dt) + if dt_str is None: + return None + # 1-D case: must distinguish (N,) vs (M,). Compare ints directly. + # When M is SymInt (dynamic batch dim) the M==N collision can't happen + # at compile time, so trust the (N,) match for RowBroadcast. Only the + # "both static + equal" case is ambiguous and we abort. + if len(shape) == 1: + n0 = shape[0] + m_is_static = _is_static_int(M) + n_is_static = _is_static_int(N) + if n_is_static and n0 == N: + # Could still collide with a (M,) col-broadcast iff M is also + # static and equal — abort in that ambiguous case. + if m_is_static and n0 == M: + return None + idx = self._add_extra(extras_nodes, arg) + return RowBroadcast(input_idx=idx, dtype=dt_str) + if m_is_static and n0 == M: + idx = self._add_extra(extras_nodes, arg) + return ColBroadcast(input_idx=idx, dtype=dt_str) + return None + if len(shape) == 2: + # (1, N) row-broadcast view. + if shape[0] == 1 and shape[1] == N: + idx = self._add_extra(extras_nodes, arg) + return RowBroadcast(input_idx=idx, dtype=dt_str) + # (M, 1) col-broadcast view. + if shape[1] == 1 and shape[0] == M: + idx = self._add_extra(extras_nodes, arg) + return ColBroadcast(input_idx=idx, dtype=dt_str) + # Full (M, N) aux load — require row-major contiguous. + if shape[0] == M and shape[1] == N and stride is not None and stride[1] == 1: + idx = self._add_extra(extras_nodes, arg) + return AuxLoad(input_idx=idx, dtype=dt_str) + return None + + def _add_extra(self, extras_nodes, arg) -> int: + for i, e in enumerate(extras_nodes): + if e is arg: + return i + extras_nodes.append(arg) + return len(extras_nodes) - 1 + + # ── swiglu7 special-case ────────────────────────────────────────────────── + + def _try_fuse_swiglu7(self, graph: fx.Graph, mm_node: fx.Node) -> bool: + """Match the canonical swiglu7 epilogue and dispatch to DualGemm. + + We do not attempt to encode swiglu7 in the EVT IR (the dual GEMM is a + whole different kernel structure). Instead we walk forward from mm_node + looking for the exact pattern produced by ``athena.activation.swiglu7`` + after Inductor decomposition. + + On a successful match we emit the magi_epilogue.matmul_custom_evt op + with kind="swiglu7_dual". The ``B`` argument must be the underlying + weight tensor of shape (N, K) — typically the predecessor of an + ``aten.t`` node feeding the mm. + """ + # Recover the underlying weight: B should be a 2-D transpose + # (aten.t / transpose(0,1) / permute([1,0])) of a contiguous (N, K) + # weight. Otherwise bail (no two-stage fallback). + B_node = mm_node.args[1] + if not isinstance(B_node, fx.Node) or not _is_transpose_node(B_node): + return False + weight_node = B_node.args[0] + if not isinstance(weight_node, fx.Node): + return False + w_shape = _val_shape(weight_node) + w_stride = _val_stride(weight_node) + if w_shape is None or len(w_shape) != 2 or w_stride is None: + return False + N, K = w_shape + # N % 8 ensures (a) the gate/linear interleaved split is valid (N + # even) AND (b) n_out = N // 2 satisfies CUTLASS AlignmentC = 4 + # for bf16. This lets the runtime allocate D as a true-contiguous + # (M, n_out) tensor with no padded stride / scratch path. Real + # GAGA2 has N=27304 (% 8 == 0). Smaller misaligned N falls back + # to torch.compile's default mm + python silu chain. + if not (_is_static_int(N) and N % 8 == 0): + return False + if w_stride != (K, 1): + return False # not contiguous (N, K) — abort + a_dtype = _val_dtype(mm_node.args[0]) + if a_dtype != torch.bfloat16 or _val_dtype(weight_node) != torch.bfloat16: + return False + + # We walk the chain in source order and collect every node belonging to + # the swiglu7 epilogue — anything else aborts. We don't need to verify + # the exact structure (the kernel does that intrinsically); we just need + # to find the final tensor that becomes the chain's only output, plus + # the set of nodes to erase. + chain_nodes: List[fx.Node] = [] + chain_set: set = {mm_node} + last_chain_node: Optional[fx.Node] = None + curr = mm_node.next + while curr is not None and curr.op != "output": + uses_chain = any(isinstance(a, fx.Node) and a in chain_set for a in curr.args) + if not uses_chain: + curr = curr.next + continue + if curr.target not in ( + torch.ops.aten.slice.Tensor, + torch.ops.aten.clamp.default, + torch.ops.aten.clamp_min.default, + torch.ops.aten.clamp_max.default, + torch.ops.aten.sigmoid.default, + torch.ops.aten.mul.Tensor, + torch.ops.aten.add.Tensor, + torch.ops.aten.add.Scalar, + torch.ops.aten.mul.Scalar, + torch.ops.prims.convert_element_type.default, + torch.ops.aten._to_copy.default, + torch.ops.aten.clone.default, + torch.ops.aten.contiguous.default, + torch.ops.aten.alias.default, + torch.ops.aten.view.default, + torch.ops.aten.reshape.default, + torch.ops.aten._unsafe_view.default, + ): + # Non-whitelist op consuming the chain → it's the boundary. + # Finalise last_chain_node as the previous node and stop. + # The output-shape check below verifies we actually saw the + # swiglu7 pattern (chain output's last dim must equal N//2). + break + chain_nodes.append(curr) + chain_set.add(curr) + last_chain_node = curr + curr = curr.next + + if last_chain_node is None: + return False + # Output dtype from the final node. + out_dt = _val_dtype(last_chain_node) or torch.bfloat16 + out_shape = _val_shape(last_chain_node) + if out_shape is None or len(out_shape) != 2: + return False + if not _is_static_int(out_shape[1]) or out_shape[1] != N // 2: + # The swiglu7 output's last dim must be N/2. + return False + + # No escape: every chain node's external uses must funnel through the + # final node (otherwise the DualGemm kernel produces only D and we'd + # lose the intermediate consumer). + for n in chain_nodes[:-1]: + for u in n.users: + if u not in chain_set: + return False + + # Emit the call. We do NOT pass IR JSON — the swiglu7 path ignores it. + out_dt_id = evt_runtime.out_dtype_id(out_dt) + n_out = N // 2 + with graph.inserting_after(last_chain_node): + new_node = graph.call_function( + torch.ops.magi_epilogue.matmul_custom_evt.default, + args=(mm_node.args[0], weight_node, [], "", "swiglu7_dual", n_out, out_dt_id), + ) + try: + val_last = last_chain_node.meta.get("val") + if val_last is not None: + new_val = val_last.new_empty_strided( + val_last.shape, (evt_runtime._aligned_n_stride(int(val_last.shape[-1]), val_last.dtype), 1) + ) + new_node.meta["val"] = new_val + except Exception: + pass + + last_chain_node.replace_all_uses_with(new_node) + for n in reversed(chain_nodes): + if len(n.users) == 0 and n is not new_node: + graph.erase_node(n) + # Erase mm and the t() node if no longer used. + if len(mm_node.users) == 0: + graph.erase_node(mm_node) + if isinstance(B_node, fx.Node) and len(B_node.users) == 0: + graph.erase_node(B_node) + return True diff --git a/magi_compiler/passes/piecewise_graph/post_grad_pass_manager.py b/magi_compiler/passes/piecewise_graph/post_grad_pass_manager.py index f6441e0..2672cef 100644 --- a/magi_compiler/passes/piecewise_graph/post_grad_pass_manager.py +++ b/magi_compiler/passes/piecewise_graph/post_grad_pass_manager.py @@ -18,10 +18,12 @@ from torch._inductor.custom_graph_pass import CustomGraphPass from ...config import PassConfig +from ...cuda.device import device_capability_major from ...utils import magi_logger, set_env_var from ...utils.envs import MAGI_PATTERN_MATCH_DEBUG from ..pass_base import InductorPass, get_pass_context from .fix_functionalization import FixFunctionalizationPass +from .fusion.blackwell_geforce.matmul_epilogue_fusion import MatmulEvtEpilogueFusionPass from .post_cleanup import PostCleanupPass @@ -80,7 +82,13 @@ def __call__(self, graph: fx.Graph): def configure(self, pass_config: PassConfig): self.pass_config = pass_config - # TODO: Register custom passes here (fusion, noop elimination, sequence parallelism, async TP, Ulysses overlap). + # Matmul + epilogue fusion. On sm_120 (Blackwell consumer / RTX 5090) + # we lower fused chains to a CUTLASS Sm80EVT kernel. Toggled via + # PassConfig.enable_mm_epilogue_fusion (default True). The device + # check is independent — even with the flag on, non-sm_120 hosts + # don't register the pass since its FX walker would just no-op. + if pass_config.enable_mm_epilogue_fusion and device_capability_major() >= 12: + self.add(MatmulEvtEpilogueFusionPass()) # needs a functional graph self.post_cleanup = PostCleanupPass() diff --git a/tests/feature_tests/test_matmul_epilogue_fusion.py b/tests/feature_tests/test_matmul_epilogue_fusion.py new file mode 100644 index 0000000..f6d7cfd --- /dev/null +++ b/tests/feature_tests/test_matmul_epilogue_fusion.py @@ -0,0 +1,788 @@ +# Copyright (c) 2025 SandAI. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for the CUTLASS Sm80EVT matmul-epilogue fusion path on RTX 5090. + +Three families of checks: + + 1. Positive numerical equivalence: every supported epilogue (the 7 athena + activations + binary ops + 1-D bias) must match eager within bf16 tol. + 2. Fusion-actually-fired: the emitted graph must contain a + ``magi_epilogue.matmul_custom_evt`` node — a green numerical test alone + would silently pass even if fusion was skipped (eager == "compiled"). + 3. Negative fallback: shapes / dtypes / chains the EVT pass does NOT + support must keep the original ``aten.mm`` and run through cuBLAS. + Catches over-eager fusion that would corrupt downstream consumers. +""" + +from typing import Optional + +import pytest +import torch +import torch.fx as fx +import torch.nn as nn +import torch.nn.functional as F + +from magi_compiler.api import magi_compile +from magi_compiler.config import get_compile_config + +pytestmark = pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") + +_SM120_ONLY = pytest.mark.skipif( + not torch.cuda.is_available() or torch.cuda.get_device_capability()[0] < 12, + reason="CUTLASS EVT path targets sm_120 (Blackwell consumer)", +) + + +# ── Activations from athena/performer_v16/activation.py (verbatim) ──────────── + + +def high_precision_silu(x, out_dtype: Optional[torch.dtype] = None): + out_dtype = x.dtype if out_dtype is None else out_dtype + return F.silu(x.to(torch.float32)).to(out_dtype) + + +def high_precision_sigmoid(x, out_dtype: Optional[torch.dtype] = None): + out_dtype = x.dtype if out_dtype is None else out_dtype + return F.sigmoid(x.to(torch.float32)).to(out_dtype) + + +def high_precision_gelu(x, out_dtype: Optional[torch.dtype] = None): + out_dtype = x.dtype if out_dtype is None else out_dtype + return F.gelu(x.to(torch.float32)).to(out_dtype) + + +def swiglu7(x, alpha: float = 1.702, limit: float = 7.0, out_dtype: Optional[torch.dtype] = None): + out_dtype = x.dtype if out_dtype is None else out_dtype + x = x.to(torch.float32) + x_glu, x_linear = x[..., ::2], x[..., 1::2] + x_glu = x_glu.clamp(min=None, max=limit) + x_linear = x_linear.clamp(min=-limit, max=limit) + out_glu = x_glu * torch.sigmoid(alpha * x_glu) + return (out_glu * (x_linear + 1)).to(out_dtype) + + +def gelu7(x, alpha: float = 1.702, limit: float = 7.0, out_dtype: Optional[torch.dtype] = None): + out_dtype = x.dtype if out_dtype is None else out_dtype + x = x.to(torch.float32) + x_glu = x.clamp(min=None, max=limit) + out_glu = x_glu * torch.sigmoid(alpha * x_glu) + return out_glu.to(out_dtype) + + +def relu_square(x, out_dtype: Optional[torch.dtype] = None): + out_dtype = x.dtype if out_dtype is None else out_dtype + return torch.square(F.relu(x.to(torch.float32))).to(out_dtype) + + +# ── Compile + fusion-side instrumentation ──────────────────────────────────── + + +class _FusionStats: + """Records what the EVT pass did to the graph during one ``magi_compile``. + + Captured by patching ``MatmulEvtEpilogueFusionPass.__call__`` for the scope + of a test. We track: + * mm_before — count of ``aten.mm`` nodes seen on entry + * mm_after — same after the pass + * fused_count — number of ``magi_epilogue.matmul_custom_evt`` nodes + inserted (i.e. how many mm sites the pass actually + replaced; ``mm_before - mm_after`` only matches when + fusion never aborts mid-walk). + * kinds — the ``kind`` arg of each emitted op, e.g. + ["evt_row", "swiglu7_dual"]. + + Tests assert against these to prove the pass made the right choice — a + purely numerical comparison against eager would silently pass even when + fusion was skipped (because both paths fall back to cuBLAS). + """ + + def __init__(self) -> None: + self.mm_before = 0 + self.mm_after = 0 + self.fused_count = 0 + self.kinds: list = [] + # out_dtype_id of each emitted op (args[6]). Encoded as + # bf16 → 0, fp16 → 1, fp32 → 2 (see evt_runtime._OUT_DTYPE_ID). + # Tests assert against this to catch silent dtype regressions in the + # FX pass's last-node meta lookup or codegen's ElementC typedef. + self.out_dtype_ids: list = [] + + +def _install_pass_instrument(): + """Returns (stats, restore_fn). Wraps the FX pass to record per-call deltas.""" + from magi_compiler.passes.piecewise_graph.fusion.blackwell_geforce import matmul_epilogue_fusion as P + + stats = _FusionStats() + original = P.MatmulEvtEpilogueFusionPass.__call__ + evt_op = torch.ops.magi_epilogue.matmul_custom_evt.default + mm_targets = (torch.ops.aten.mm.default, torch.ops.aten.mm) + + def _instrumented(self, graph: fx.Graph): + before = sum(1 for n in graph.nodes if n.op == "call_function" and n.target in mm_targets) + result = original(self, graph) + after = sum(1 for n in graph.nodes if n.op == "call_function" and n.target in mm_targets) + emitted_kinds = [] + emitted_out_dtype_ids = [] + for n in graph.nodes: + if n.op == "call_function" and n.target is evt_op: + # signature: (A, B, extras, ir_json, kind, n_out, out_dtype_id) + if len(n.args) >= 5: + emitted_kinds.append(n.args[4]) + if len(n.args) >= 7: + emitted_out_dtype_ids.append(n.args[6]) + stats.mm_before += before + stats.mm_after += after + stats.fused_count += len(emitted_kinds) + stats.kinds.extend(emitted_kinds) + stats.out_dtype_ids.extend(emitted_out_dtype_ids) + return result + + P.MatmulEvtEpilogueFusionPass.__call__ = _instrumented + + def restore(): + P.MatmulEvtEpilogueFusionPass.__call__ = original + + return stats, restore + + +def _compile_and_check( + model: nn.Module, + inputs, + *, + atol: float = 0.5, + rtol: float = 0.0, + expect_fused: int = -1, + expect_kinds: Optional[list] = None, + expect_out_dtype: Optional[torch.dtype] = None, + expect_actual_dtype: Optional[torch.dtype] = None, + dynamic_arg_dims=None, + cast_model_to_bf16: bool = True, +): + """Compile ``model``, run it on ``inputs``, compare against eager. + + Parameters + ---------- + model, inputs + ``inputs`` is a tuple/list passed positionally to forward. + atol, rtol + Numerical tolerance: ``|actual - expected| <= atol + rtol*|expected|``. + expect_fused + Number of mm sites the pass MUST have replaced. Use 0 for negative + tests (fusion must NOT fire). -1 disables the check. + expect_kinds + If set, the multiset of emitted op ``kind`` args must equal this list. + E.g. ``["swiglu7_dual"]`` for the swiglu7 special-case path. + expect_out_dtype + If set, every emitted op's ``out_dtype_id`` (args[6]) MUST decode to + this dtype. Catches silent regressions where the FX pass picks the + wrong terminal-node dtype, or where Inductor inserts an extra cast + that the IR walker wasn't expecting. + expect_actual_dtype + If set, the runtime result tensor MUST have this dtype. Independent + check from ``expect_out_dtype`` — they should agree but a mismatch + between them would mean the codegen's StoreD typedef diverged from + the op's declared out_dtype_id. + dynamic_arg_dims + Forwarded to magi_compile. Defaults to making the first arg's M + dynamic (matches our fusion guards). + cast_model_to_bf16 + Default True (mirrors the standard test setup). Pass False when the + model already has the dtype mix you want (e.g. fp16-only or mixed + bf16 / fp16 weights). + """ + if dynamic_arg_dims is None: + # Use the model's forward signature to pick the first arg name. + import inspect + + params = list(inspect.signature(model.forward).parameters) + if not params: + dynamic_arg_dims = {} + else: + dynamic_arg_dims = {params[0]: 0} + + model = model.cuda() + # Use bfloat16 by default so the EVT pass actually fires (the pass + # requires bf16/fp16). Skip the auto-cast for tests that explicitly + # set up a different dtype mix. + if cast_model_to_bf16 and any(p.dtype.is_floating_point for p in model.parameters()): + model = model.bfloat16() + # Disable gradients on parameters; otherwise magi_compile / aot_autograd + # produces a forward+backward joint graph and the mm node has an extra + # user (the saved tensor for backward), which the EVT escape detector + # correctly refuses to fuse. + for p in model.parameters(): + p.requires_grad_(False) + + with torch.no_grad(): + expected = model(*inputs) + + get_compile_config().disable_cache = True + stats, restore = _install_pass_instrument() + try: + compiled_model = magi_compile(model, dynamic_arg_dims=dynamic_arg_dims) + with torch.no_grad(): + actual = compiled_model(*inputs) + finally: + restore() + + # Numerical check. + abs_diff = (actual - expected).abs() + tol = atol + rtol * expected.abs() + max_violation = (abs_diff - tol).max().item() + assert max_violation <= 0, ( + f"Fused result outside tolerance: " + f"max(|diff| - tol) = {max_violation:.4f}, " + f"max |diff| = {abs_diff.max().item():.4f}, " + f"fusion stats: fused={stats.fused_count} kinds={stats.kinds}" + ) + + # Fusion-actually-fired check. + if expect_fused >= 0: + assert stats.fused_count == expect_fused, ( + f"Expected {expect_fused} fused mm sites, got {stats.fused_count}. " + f"mm_before={stats.mm_before} mm_after={stats.mm_after} " + f"emitted kinds={stats.kinds}" + ) + if expect_kinds is not None: + assert sorted(stats.kinds) == sorted(expect_kinds), ( + f"Expected emitted kinds {sorted(expect_kinds)}, " f"got {sorted(stats.kinds)}" + ) + if expect_out_dtype is not None: + from magi_compiler.passes.piecewise_graph.fusion.blackwell_geforce.evt_runtime import out_dtype_from_id + + assert stats.out_dtype_ids, ( + f"expect_out_dtype={expect_out_dtype} but no fusion fired " f"(out_dtype_ids list is empty)" + ) + decoded = [out_dtype_from_id(i) for i in stats.out_dtype_ids] + for got in decoded: + assert got == expect_out_dtype, ( + f"Emitted out_dtype mismatch: expected {expect_out_dtype}, " f"got {got} (full list: {decoded})" + ) + if expect_actual_dtype is not None: + assert actual.dtype == expect_actual_dtype, ( + f"Runtime result dtype mismatch: expected {expect_actual_dtype}, " f"got {actual.dtype}" + ) + + +# ───────────────────────────────────────────────────────────────────────────── +# Positive tests — every athena activation must fuse and stay numerically OK +# ───────────────────────────────────────────────────────────────────────────── + + +class _Bf16MmModel(nn.Module): + """All positive activation models share this skeleton: bf16 mm followed + by an epilogue fn that returns bf16. Weight is held in (N, K) row-major + form and accessed via ``permute([1, 0])`` to mirror the real GAGA2 graph.""" + + def __init__(self, k: int, n: int, epilogue): + super().__init__() + self.weight = nn.Parameter(torch.randn(n, k)) + self._epi = epilogue + + def forward(self, a): + y = torch.mm(a, self.weight.permute(1, 0)) + return self._epi(y, out_dtype=torch.bfloat16) + + +_M, _K, _N = 1024, 1024, 1024 + + +def _input_a(): + return torch.randn(_M, _K, device="cuda", dtype=torch.bfloat16) + + +@_SM120_ONLY +@pytest.mark.parametrize( + "epi_name,epi_fn,atol,rtol", + [ + ("silu", high_precision_silu, 0.5, 0.0), + ("sigmoid", high_precision_sigmoid, 0.5, 0.0), + ("gelu", high_precision_gelu, 0.5, 0.0), + ("gelu7", gelu7, 0.5, 0.0), + ("relu_square", relu_square, 0.0, 0.2), + ], +) +def test_evt_unary_activations_fuse(epi_name, epi_fn, atol, rtol): + """All unary activations must fuse to a single ``evt_col`` op.""" + model = _Bf16MmModel(_K, _N, epi_fn) + _compile_and_check(model, (_input_a(),), atol=atol, rtol=rtol, expect_fused=1, expect_kinds=["evt_col"]) + + +@_SM120_ONLY +def test_evt_relu_native(): + """Plain ``aten.relu`` (no fp32 cast) — exercises the built-in CUTLASS + ReLu functor mapping in the IR.""" + + class M(nn.Module): + def __init__(self): + super().__init__() + self.weight = nn.Parameter(torch.randn(_N, _K)) + + def forward(self, a): + return torch.relu(torch.mm(a, self.weight.permute(1, 0))).to(torch.bfloat16) + + _compile_and_check(M(), (_input_a(),), expect_fused=1, expect_kinds=["evt_col"]) + + +@_SM120_ONLY +def test_evt_swiglu7_dispatches_to_dualgemm(): + """SwiGLU7 must take the dedicated DualGemm one-stage path, not generic EVT.""" + model = _Bf16MmModel(_K, _N, swiglu7) + _compile_and_check(model, (_input_a(),), atol=0.5, rtol=0.05, expect_fused=1, expect_kinds=["swiglu7_dual"]) + + +# ───────────────────────────────────────────────────────────────────────────── +# Binary-op positive tests — chains containing add/sub/mul/div on the mm output +# ───────────────────────────────────────────────────────────────────────────── + + +@_SM120_ONLY +def test_evt_mm_plus_scalar(): + """``mm + 0.5`` — scalar add absorbs into ``add_scalar`` IR node. + + Tolerance: eager runs the add in bf16 (lossy ulp at ±0.5); CUTLASS runs + the add in fp32 then casts. The ~1.0 absolute diff observed is bf16 + rounding noise on the eager side, not a CUTLASS bug. + """ + + class M(nn.Module): + def __init__(self): + super().__init__() + self.weight = nn.Parameter(torch.randn(_N, _K)) + + def forward(self, a): + return (torch.mm(a, self.weight.permute(1, 0)) + 0.5).to(torch.bfloat16) + + _compile_and_check(M(), (_input_a(),), atol=1.5, expect_fused=1, expect_kinds=["evt_col"]) + + +@_SM120_ONLY +def test_evt_mm_times_scalar(): + """``mm * 0.25`` — scalar mul (mul_scalar IR).""" + + class M(nn.Module): + def __init__(self): + super().__init__() + self.weight = nn.Parameter(torch.randn(_N, _K)) + + def forward(self, a): + return (torch.mm(a, self.weight.permute(1, 0)) * 0.25).to(torch.bfloat16) + + _compile_and_check(M(), (_input_a(),), expect_fused=1, expect_kinds=["evt_col"]) + + +@_SM120_ONLY +def test_evt_mm_div_scalar_then_silu(): + """``silu(mm / 8)`` — scalar div + activation chain.""" + + class M(nn.Module): + def __init__(self): + super().__init__() + self.weight = nn.Parameter(torch.randn(_N, _K)) + + def forward(self, a): + y = torch.mm(a, self.weight.permute(1, 0)) / 8.0 + return high_precision_silu(y, out_dtype=torch.bfloat16) + + _compile_and_check(M(), (_input_a(),), expect_fused=1, expect_kinds=["evt_col"]) + + +@_SM120_ONLY +def test_evt_mm_minus_scalar_then_relu(): + """``relu(mm - 2.0)``.""" + + class M(nn.Module): + def __init__(self): + super().__init__() + self.weight = nn.Parameter(torch.randn(_N, _K)) + + def forward(self, a): + y = torch.mm(a, self.weight.permute(1, 0)) - 2.0 + return torch.relu(y).to(torch.bfloat16) + + _compile_and_check(M(), (_input_a(),), expect_fused=1, expect_kinds=["evt_col"]) + + +@_SM120_ONLY +def test_evt_mm_plus_1d_bias(): + """``silu(mm + bias_N)`` — 1-D bias as RowBroadcast extras.""" + + class M(nn.Module): + def __init__(self): + super().__init__() + self.weight = nn.Parameter(torch.randn(_N, _K)) + self.bias = nn.Parameter(torch.randn(_N)) + + def forward(self, a): + y = torch.mm(a, self.weight.permute(1, 0)) + self.bias + return high_precision_silu(y, out_dtype=torch.bfloat16) + + # atol=1.5: eager does the bias-add in bf16 (lossy), CUTLASS in fp32 — + # the ~1.0 abs diff is bf16 ulp noise on the eager side. + _compile_and_check(M(), (_input_a(),), atol=1.5, expect_fused=1, expect_kinds=["evt_col"]) + + +@_SM120_ONLY +def test_evt_mm_times_aux_load(): + """``(mm * gate_MxN)`` — full (M, N) auxiliary tensor multiply. + + The gate must be supplied as a regular forward arg (not a model parameter) + because magi_compile doesn't trace through Parameters of dynamic shape. + """ + + class M(nn.Module): + def __init__(self): + super().__init__() + self.weight = nn.Parameter(torch.randn(_N, _K)) + + def forward(self, a, gate): + y = torch.mm(a, self.weight.permute(1, 0)) * gate + return y.to(torch.bfloat16) + + a = _input_a() + gate = torch.randn(_M, _N, device="cuda", dtype=torch.bfloat16) + _compile_and_check( + M(), (a, gate), atol=0.0, rtol=0.1, expect_fused=1, expect_kinds=["evt_col"], dynamic_arg_dims={"a": 0, "gate": 0} + ) + + +# ───────────────────────────────────────────────────────────────────────────── +# Negative tests — fusion must NOT fire and the chain must fall back to cuBLAS +# ───────────────────────────────────────────────────────────────────────────── + + +@_SM120_ONLY +def test_evt_no_fuse_intermediate_escapes(): + """Attention → residual → RMSNorm pattern: ``add(residual, mm)`` is + consumed both by ``square(...)`` (would-be-fused) AND by ``mul(_, rsqrt)`` + later. The pass MUST refuse — fusing would silently drop the value the + rest of RMSNorm needs.""" + + class M(nn.Module): + def __init__(self): + super().__init__() + self.weight = nn.Parameter(torch.randn(5120, _K)) + self.gamma = nn.Parameter(torch.randn(5120)) + + def forward(self, a, residual): + y = torch.mm(a, self.weight.permute(1, 0)).float() + x = residual + y + var = x.pow(2).mean(-1, keepdim=True) + rsqrt = torch.rsqrt(var + 1e-6) + return (x * rsqrt * (self.gamma + 1)).to(torch.bfloat16) + + a = _input_a() + residual = torch.randn(_M, 5120, device="cuda", dtype=torch.float32) + # `residual + y` couples a's M to residual's M; mark both dynamic so + # Dynamo doesn't specialize a's declared dynamic dim → ConstraintViolation. + _compile_and_check(M(), (a, residual), atol=2.0, rtol=0.1, expect_fused=0, dynamic_arg_dims={"a": 0, "residual": 0}) + + +@_SM120_ONLY +def test_evt_no_fuse_bare_mm(): + """A bare ``mm`` with no epilogue at all — Store(Accum) is trivial. + Replacing cuBLAS with a CUTLASS GEMM that does identical work is strictly + slower, so the pass must skip.""" + + class M(nn.Module): + def __init__(self): + super().__init__() + self.weight = nn.Parameter(torch.randn(_N, _K)) + + def forward(self, a): + return torch.mm(a, self.weight.permute(1, 0)) + + _compile_and_check(M(), (_input_a(),), atol=0.5, expect_fused=0) + + +@_SM120_ONLY +def test_evt_no_fuse_k_misaligned(): + """K not divisible by 8 fails the bf16 alignment guard — cuBLAS path.""" + + class M(nn.Module): + def __init__(self, k, n): + super().__init__() + self.weight = nn.Parameter(torch.randn(n, k)) + + def forward(self, a): + y = torch.mm(a, self.weight.permute(1, 0)) + return high_precision_silu(y, out_dtype=torch.bfloat16) + + K = 1023 # 1023 % 8 = 7 → should NOT fuse + N = 1024 + a = torch.randn(_M, K, device="cuda", dtype=torch.bfloat16) + _compile_and_check(M(K, N), (a,), expect_fused=0) + + +@_SM120_ONLY +def test_evt_no_fuse_evt_n_misaligned(): + """N not divisible by 4 fails the generic-EVT N-alignment guard + (CUTLASS AlignmentC = 4) — must fall back to torch.compile / cuBLAS.""" + + class M(nn.Module): + def __init__(self, k, n): + super().__init__() + self.weight = nn.Parameter(torch.randn(n, k)) + + def forward(self, a): + y = torch.mm(a, self.weight.permute(1, 0)) + return high_precision_silu(y, out_dtype=torch.bfloat16) + + K = 1024 + N = 1026 # 1026 % 4 = 2 → should NOT fuse + a = torch.randn(_M, K, device="cuda", dtype=torch.bfloat16) + _compile_and_check(M(K, N), (a,), expect_fused=0) + + +@_SM120_ONLY +def test_evt_no_fuse_swiglu7_n_not_mult_of_8(): + """swiglu7 needs N % 8 == 0 so that n_out = N // 2 is 4-aligned for + bf16 (CUTLASS AlignmentC = 4). N = 12 (% 8 != 0) must fall back.""" + + class M(nn.Module): + def __init__(self, k, n): + super().__init__() + self.weight = nn.Parameter(torch.randn(n, k)) + + def forward(self, a): + y = torch.mm(a, self.weight.permute(1, 0)) + return swiglu7(y, out_dtype=torch.bfloat16) + + K = 1024 + N = 12 # 12 % 2 == 0 (split OK) but 12 % 8 != 0 → NOT fused + a = torch.randn(_M, K, device="cuda", dtype=torch.bfloat16) + _compile_and_check(M(K, N), (a,), expect_fused=0) + + +# ───────────────────────────────────────────────────────────────────────────── +# IR / cache key invariants +# ───────────────────────────────────────────────────────────────────────────── + + +@_SM120_ONLY +def test_evt_ir_canonical_determinism(): + """Same IR built twice → identical canonical JSON. If this regresses, the + .cu module disk cache silently misses and recompiles every run.""" + from magi_compiler.passes.piecewise_graph.fusion.blackwell_geforce.evt_ir import ( + Accum, + Compute, + Store, + cache_key, + to_canonical_json, + ) + + a = Store(Compute("silu", (Compute("add", (Accum(), Accum())),)), "bfloat16") + b = Store(Compute("silu", (Compute("add", (Accum(), Accum())),)), "bfloat16") + assert to_canonical_json(a) == to_canonical_json(b) + assert cache_key(a, "bfloat16", "bfloat16") == cache_key(b, "bfloat16", "bfloat16") + + +# ───────────────────────────────────────────────────────────────────────────── +# out_dtype correctness — verify the EVT pass picks the right Store dtype + +# the codegen's ElementC matches + the runtime returns a tensor of that dtype. +# +# Matrix: +# input dtype | epilogue compute | output dtype | expected out_dtype_id +# ───────────────────────────────────────────────────────────────────── +# bf16 | bf16 | bf16 | 0 (default) +# bf16 | fp32 | bf16 | 0 (high_precision_silu) +# bf16 | fp32 | fp32 | 2 (no final cast) +# bf16 | bf16 | fp16 | 1 (cross-precision) +# fp16 | fp16 | fp16 | 1 (fp16-only path) +# fp32 input | — | — | not fused (negative) +# ───────────────────────────────────────────────────────────────────────────── + + +@_SM120_ONLY +def test_evt_out_dtype_bf16_native(): + """bf16 mm → bf16 silu → bf16 output (no fp32 promotion). Pure-bf16 chain. + out_dtype_id MUST be 0 (bf16) and the runtime tensor MUST be bf16.""" + + class M(nn.Module): + def __init__(self): + super().__init__() + self.weight = nn.Parameter(torch.randn(_N, _K)) + + def forward(self, a): + return F.silu(torch.mm(a, self.weight.permute(1, 0))) # bf16 → bf16 + + _compile_and_check( + M(), + (_input_a(),), + expect_fused=1, + expect_kinds=["evt_col"], + expect_out_dtype=torch.bfloat16, + expect_actual_dtype=torch.bfloat16, + ) + + +@_SM120_ONLY +def test_evt_out_dtype_bf16_via_high_precision(): + """The athena ``high_precision_silu`` pattern: bf16 → cast(fp32) → silu → + cast(bf16). The IR walker absorbs both casts; final output is bf16 even + though the compute went through fp32 internally. + + This is the most common athena pattern — a regression here means the + inner-cast handling broke and out_dtype is silently wrong.""" + + class M(nn.Module): + def __init__(self): + super().__init__() + self.weight = nn.Parameter(torch.randn(_N, _K)) + + def forward(self, a): + y = torch.mm(a, self.weight.permute(1, 0)) + return high_precision_silu(y, out_dtype=torch.bfloat16) + + _compile_and_check( + M(), + (_input_a(),), + expect_fused=1, + expect_kinds=["evt_col"], + expect_out_dtype=torch.bfloat16, + expect_actual_dtype=torch.bfloat16, + ) + + +@_SM120_ONLY +def test_evt_out_dtype_fp32_no_final_cast(): + """bf16 mm → fp32 cast → silu → keep fp32 (no final cast back). + + out_dtype_id MUST be 2 (fp32). Exercises codegen's ``ElementC = float`` + path + the runtime D allocator with fp32 row-stride alignment (4 elements + = 16 bytes — different vector size than bf16's 8 bytes). + """ + + class M(nn.Module): + def __init__(self): + super().__init__() + self.weight = nn.Parameter(torch.randn(_N, _K)) + + def forward(self, a): + y = torch.mm(a, self.weight.permute(1, 0)).float() + return F.silu(y) # stays fp32 + + _compile_and_check( + M(), + (_input_a(),), + expect_fused=1, + expect_kinds=["evt_col"], + expect_out_dtype=torch.float32, + expect_actual_dtype=torch.float32, + ) + + +@_SM120_ONLY +def test_evt_out_dtype_bf16_to_fp16(): + """bf16 mm → silu → cast(fp16). Cross-precision: bf16 inputs but fp16 + output. out_dtype_id MUST be 1 (fp16). Exercises the codegen's + ``ElementA = bfloat16_t`` + ``ElementC = half_t`` mixed instantiation.""" + + class M(nn.Module): + def __init__(self): + super().__init__() + self.weight = nn.Parameter(torch.randn(_N, _K)) + + def forward(self, a): + return F.silu(torch.mm(a, self.weight.permute(1, 0))).half() + + _compile_and_check( + M(), + (_input_a(),), + atol=0.5, + expect_fused=1, + expect_kinds=["evt_col"], + expect_out_dtype=torch.float16, + expect_actual_dtype=torch.float16, + ) + + +@_SM120_ONLY +def test_evt_out_dtype_fp16_native(): + """fp16 mm + fp16 silu → fp16 output. Pure-fp16 path — exercises the + pass's bf16/fp16 branch in the input-dtype check, plus the codegen's + ``cutlass::half_t`` ElementA/B/C path end-to-end.""" + + class M(nn.Module): + def __init__(self): + super().__init__() + self.weight = nn.Parameter(torch.randn(_N, _K)) + + def forward(self, a): + return F.silu(torch.mm(a, self.weight.permute(1, 0))) # fp16 → fp16 + + a = torch.randn(_M, _K, device="cuda", dtype=torch.float16) + # Cast model to fp16 (not bf16) so all parameters match A's dtype. + model = M().cuda().half() + for p in model.parameters(): + p.requires_grad_(False) + + with torch.no_grad(): + expected = model(a) + + get_compile_config().disable_cache = True + stats, restore = _install_pass_instrument() + try: + compiled = magi_compile(model, dynamic_arg_dims={"a": 0}) + with torch.no_grad(): + actual = compiled(a) + finally: + restore() + + diff = (actual.float() - expected.float()).abs().max().item() + assert diff <= 0.5, f"fp16 silu max|diff|={diff}" + assert stats.fused_count == 1, f"fp16 path should fuse but got fused_count={stats.fused_count}" + assert stats.kinds == ["evt_col"], stats.kinds + assert stats.out_dtype_ids == [1], f"Expected out_dtype_id=[1] (fp16), got {stats.out_dtype_ids}" + assert actual.dtype == torch.float16, actual.dtype + + +@_SM120_ONLY +def test_evt_no_fuse_fp32_mm(): + """fp32 mm — pass requires bf16 (or fp16); fp32 must skip.""" + + class M(nn.Module): + def __init__(self): + super().__init__() + self.weight = nn.Parameter(torch.randn(_N, _K)) + + def forward(self, a): + y = torch.mm(a, self.weight.permute(1, 0)) + return F.silu(y) + + a = torch.randn(_M, _K, device="cuda", dtype=torch.float32) + + model = M().cuda() # fp32 — do NOT bfloat16() the model + with torch.no_grad(): + expected = model(a) + + get_compile_config().disable_cache = True + stats, restore = _install_pass_instrument() + try: + compiled_model = magi_compile(model, dynamic_arg_dims={"a": 0}) + with torch.no_grad(): + actual = compiled_model(a) + finally: + restore() + + diff = (actual - expected).abs().max().item() + assert diff <= 1.0, f"fp32 mm result diverged: {diff}" + assert stats.fused_count == 0, ( + f"fp32 mm should NOT fuse, but pass emitted {stats.fused_count} ops " f"(kinds={stats.kinds})" + ) + + +if __name__ == "__main__": + pytest.main([__file__, "-v"])