Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/codestyle/copyright.hook
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def _get_comment_mark(path):
if lang_type.search(path) is not None:
return "#"

lang_type=re.compile(r"\.(h|c|hpp|cc|cpp|cu|go|cuh|proto)$")
lang_type=re.compile(r"\.(h|c|hpp|hxx|cc|cpp|cxx|cu|go|cuh|proto)$")
if lang_type.search(path) is not None:
return "//"

Expand Down
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ repos:
name: copyright_checker
entry: python3 ./.github/codestyle/copyright.hook
language: system
files: \.(c|cc|cxx|cpp|cu|h|hpp|hxx|proto|py|sh)$
files: \.(c|cc|cxx|cpp|cu|cuh|h|hpp|hxx|proto|py|sh)$
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.4.0
hooks:
Expand Down
52 changes: 52 additions & 0 deletions Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,21 @@ FROM nvcr.io/nvidia/pytorch:25.10-py3

ARG FLASH_ATTENTION_COMMIT_ID="b613d9e2c8475945baff3fd68f2030af1b890acf"

# CUTLASS β€” source is always cloned (the magi_compiler EVT-fusion path
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Treat CUTLASS as a third-party dependency and provide an install command, because users may install magi_compiler without Docker.
Please also update the commands in README.md.

# JIT-includes its headers and our /opt/cutlass tree is the readable
# reference checkout). The CMake-driven profiler/library is compiled
# *only* when the build host is an RTX 5090 (sm_120, Blackwell consumer);
# every other arch gets the source tree but no built artefacts.
#
# Override behaviour with a build arg:
# --build-arg CUTLASS_BUILD=yes force compile (e.g. on a build farm
# without a GPU but targeting sm_120)
# --build-arg CUTLASS_BUILD=no force skip even if 5090 detected
# --build-arg CUTLASS_BUILD=auto (default) compile iff nvidia-smi
# reports compute_cap == 12.x
ARG CUTLASS_COMMIT_ID="f74fea9ce35868d3ae9f8d1dce1969d7250d3f90"
ARG CUTLASS_BUILD="auto"

ENV PIP_NO_CACHE_DIR=1 \
PIP_DISABLE_PIP_VERSION_CHECK=1 \
PYTHONDONTWRITEBYTECODE=1
Expand All @@ -18,6 +33,7 @@ RUN --mount=type=secret,id=http_proxy,required=false \
ca-certificates \
git \
build-essential \
cmake \
ninja-build && \
rm -rf /var/lib/apt/lists/* && \
apt-get clean
Expand All @@ -42,6 +58,42 @@ RUN --mount=type=secret,id=http_proxy,required=false \
cp /tmp/flash-attention/hopper/flash_attn_interface.py ${python_path}/flash_attn_3/ && \
rm -rf /tmp/flash-attention


RUN --mount=type=secret,id=http_proxy,required=false \
--mount=type=secret,id=https_proxy,required=false \
export http_proxy="$(cat /run/secrets/http_proxy 2>/dev/null || true)" && \
export https_proxy="$(cat /run/secrets/https_proxy 2>/dev/null || true)" && \
mkdir -p /opt/cutlass && \
cd /opt/cutlass && \
git init -q && \
git remote add origin https://github.com/NVIDIA/cutlass.git && \
git fetch origin ${CUTLASS_COMMIT_ID} --depth 1 && \
git checkout ${CUTLASS_COMMIT_ID} && \
(git submodule update --init --recursive --depth 1 --jobs 8 || \
git submodule update --init --recursive --depth 1 --jobs 1)


RUN set -eu; \
case "${CUTLASS_BUILD}" in \
no) echo "[CUTLASS] CUTLASS_BUILD=no β€” skipping cmake configure."; exit 0 ;; \
yes) DO_BUILD=1 ;; \
auto) \
if command -v nvidia-smi >/dev/null 2>&1 && \
nvidia-smi --query-gpu=compute_cap --format=csv,noheader 2>/dev/null \
| head -n1 | grep -Eq '^12\.'; then \
echo "[CUTLASS] nvidia-smi reports sm_120 β€” running cmake configure."; \
DO_BUILD=1; \
else \
echo "[CUTLASS] No sm_120 detected at build time β€” skipping cmake (headers still available)."; \
exit 0; \
fi ;; \
*) echo "[CUTLASS] Unknown CUTLASS_BUILD=${CUTLASS_BUILD}"; exit 1 ;; \
esac; \
[ -n "${DO_BUILD:-}" ] && cd /opt/cutlass && \
export CUDACXX="${CUDA_INSTALL_PATH:-${CUDA_HOME:-/usr/local/cuda}}/bin/nvcc" && \
mkdir -p build && cd build && \
cmake .. -DCUTLASS_NVCC_ARCHS=120a

RUN --mount=type=secret,id=http_proxy,required=false \
--mount=type=secret,id=https_proxy,required=false \
export http_proxy="$(cat /run/secrets/http_proxy 2>/dev/null || true)" && \
Expand Down
10 changes: 10 additions & 0 deletions magi_compiler/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,16 @@ class PassConfig(BaseModel):
# TODO: Add sequence parallelism pass and async TP pass.
# TODO: Add Ulysses overlap pass.
enable_sage_attn: bool = Field(False, description="Whether to replace flash attention with sage attention.")
enable_mm_epilogue_fusion: bool = Field(
True,
description=(
"Whether to enable the matmul + elementwise epilogue fusion pass. "
"On RTX 5090 (sm_120) this lowers fused chains to a CUTLASS Sm80EVT "
"kernel via the blackwell_geforce.MatmulEvtEpilogueFusionPass. The "
"pass is a no-op on older architectures regardless of this flag, "
"but the flag still controls whether it is registered at all."
),
)

@property
def hash(self) -> str:
Expand Down
44 changes: 44 additions & 0 deletions magi_compiler/cuda/device.py
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Move this file to the utils directory and update __init__.py

Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
# Copyright (c) 2026 SandAI. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""GPU device introspection helpers.

Centralised so that pass-manager / FX passes / runtime modules don't all
re-implement the same try/except dance around ``torch.cuda``.
"""

from typing import Tuple


def device_capability(device: int = 0) -> Tuple[int, int]:
    """Return the CUDA compute capability ``(major, minor)`` of *device*.

    Any failure — torch not importable, CUDA unavailable or uninitialised,
    bad device index — yields ``(0, 0)``.  Callers compare against a minimum
    capability, so the zero pair always reads as "feature unsupported",
    which is the safe answer on CPU-only hosts and during static analysis.
    """
    try:
        import torch as _torch

        if not _torch.cuda.is_available():
            return (0, 0)
        return _torch.cuda.get_device_capability(device)
    except Exception:
        # Deliberately broad: introspection must never crash the caller.
        return (0, 0)


def device_capability_major(device: int = 0) -> int:
    """Return only the major compute-capability number (``0`` without CUDA)."""
    major, _minor = device_capability(device)
    return major
2 changes: 2 additions & 0 deletions magi_compiler/passes/full_graph/full_graph_pass_mgr.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

from ...magi_depyf.timeline import observe_lifecycle
from .remove_item import RemoveItemPass
from .remove_useless_ops import RemoveUselessOpsPass
from .replace_sage_atten import ReplaceSageAttentionPass


Expand All @@ -30,6 +31,7 @@ def __init__(self, pass_config):
if self.pass_config.enable_sage_attn:
self.passes.append(ReplaceSageAttentionPass())
self.passes.append(RemoveItemPass())
self.passes.append(RemoveUselessOpsPass())

@observe_lifecycle("full_graph_manager")
def __call__(self, gm: torch.fx.GraphModule):
Expand Down
117 changes: 117 additions & 0 deletions magi_compiler/passes/full_graph/remove_useless_ops.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
# Copyright (c) 2026 SandAI. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import torch._inductor.fx_passes.pre_grad

from ...magi_depyf.timeline import emit_pass_lifecycle
from ..pass_base import MagiInductorPass


class RemoveUselessOpsPass(MagiInductorPass):
    """Eliminate identity ``call_method`` nodes from a full FX graph.

    Methods such as ``view``, ``reshape``, ``to`` or ``contiguous`` are
    redundant when the output tensor metadata recorded in
    ``meta["example_value"]`` (shape, dtype, stride and device) exactly
    matches the input's.  Each such call has every use redirected to its
    input node and is then erased from the graph.

    NOTE(review): removing a metadata-identical ``clone`` drops the copy
    and aliases the input — safe only if nothing mutates the result
    in place; confirm the graphs this pass runs on are functional.
    """

    # Tensor methods that are candidates for identity elimination: each is
    # a no-op when it leaves shape/dtype/stride/device unchanged.
    TARGET_METHODS = {
        "view",
        "reshape",
        "to",
        "type",
        "contiguous",
        "clone",
        "flatten",
        "permute",
        "transpose",
        "t",
        "unsqueeze",
        "squeeze",
        "expand",
        "repeat",
        "bfloat16",
        "float",
        "half",
        "int",
        "long",
        "short",
        "double",
        "bool",
        "byte",
    }

    @staticmethod
    def _get_tensor_info(node: torch.fx.Node):
        """Return ``(shape, dtype, stride)`` recorded for *node*.

        Reads ``meta["example_value"]``; for list/tuple values the first
        tensor element is used as a proxy.  Returns ``(None, None, None)``
        when no tensor metadata is available.
        """
        val = node.meta.get("example_value")
        if isinstance(val, torch.Tensor):
            return val.shape, val.dtype, val.stride()
        if isinstance(val, (list, tuple)) and val and isinstance(val[0], torch.Tensor):
            first = val[0]
            return first.shape, first.dtype, first.stride()
        return None, None, None

    @staticmethod
    def _get_device(node: torch.fx.Node):
        """Return the recorded device for *node*, or ``None`` if unknown.

        Mirrors ``_get_tensor_info``'s handling of list/tuple metadata so
        the device guard is not silently skipped for list-valued
        ``example_value`` entries.
        """
        val = node.meta.get("example_value")
        if isinstance(val, torch.Tensor):
            return val.device
        if isinstance(val, (list, tuple)) and val and isinstance(val[0], torch.Tensor):
            return val[0].device
        return None

    def is_applicable(self, graph: torch.fx.Graph, shape: int | None = None) -> bool:
        """Cheap pre-check: does *graph* contain any candidate method call?"""
        return any(
            node.op == "call_method" and node.target in self.TARGET_METHODS
            for node in graph.nodes
        )

    @emit_pass_lifecycle
    def __call__(self, graph: torch.fx.Graph):
        """Redirect uses of identity method calls to their inputs, then erase them."""
        dead_nodes = []

        for node in graph.nodes:
            if node.op != "call_method" or node.target not in self.TARGET_METHODS:
                continue

            # Need at least one argument: the input tensor node.
            if not node.args or not isinstance(node.args[0], torch.fx.Node):
                continue
            src = node.args[0]

            out_shape, out_dtype, out_stride = self._get_tensor_info(node)
            in_shape, in_dtype, in_stride = self._get_tensor_info(src)
            if out_shape is None or in_shape is None:
                continue
            if out_dtype is None or in_dtype is None:
                continue
            # Stride metadata may be missing; when both sides have it, it
            # must match, guarding against contiguity-changing ops.
            if out_stride is not None and in_stride is not None and out_stride != in_stride:
                continue
            if out_shape != in_shape or out_dtype != in_dtype:
                continue
            # Never fold a call that (implicitly) moves data across devices.
            out_dev = self._get_device(node)
            in_dev = self._get_device(src)
            if out_dev is not None and in_dev is not None and out_dev != in_dev:
                continue

            node.replace_all_uses_with(src)
            dead_nodes.append(node)

        # Erase after the scan so we never mutate the node list mid-iteration.
        for node in dead_nodes:
            graph.erase_node(node)
13 changes: 13 additions & 0 deletions magi_compiler/passes/piecewise_graph/fusion/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Copyright (c) 2026 SandAI. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Copyright (c) 2026 SandAI. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Loading