-
Notifications
You must be signed in to change notification settings - Fork 23
[Feat] Add CUTLASS matmul-epilogue fusion path for sm_120 #30
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. Weβll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
wtr0504
wants to merge
7
commits into
SandAI-org:main
Choose a base branch
from
wtr0504:feat/matmul_epilogue
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
7 commits
Select commit
Hold shift + click to select a range
dcc02f0
add triton matmul fusion
wtr0504 afb9399
add cute kernel
wtr0504 ea5cc68
[Feat] Add CUTLASS matmul-epilogue fusion path for sm_120
wtr0504 4e42fcf
add cutlass install in Dockerfile & update
wtr0504 0cfa820
add enable_mm_epilogue_fusion & chore
wtr0504 f62bd8c
chore
wtr0504 36f7fbf
update .github/codestyle/copyright.hook
wtr0504 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Move this file to the |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,44 @@ | ||
| # Copyright (c) 2026 SandAI. All Rights Reserved. | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| """GPU device introspection helpers. | ||
|
|
||
| Centralised so that pass-manager / FX passes / runtime modules don't all | ||
| re-implement the same try/except dance around ``torch.cuda``. | ||
| """ | ||
|
|
||
| from typing import Tuple | ||
|
|
||
|
|
||
| def device_capability(device: int = 0) -> Tuple[int, int]: | ||
| """Return ``(major, minor)`` for the given CUDA device. | ||
|
|
||
| Falls back to ``(0, 0)`` when CUDA is unavailable / not initialised / | ||
| raises any error during introspection β callers compare against a | ||
| minimum cap so a zero pair always means "feature unsupported", which | ||
| is the safe behaviour on CPU-only hosts and during static analysis. | ||
| """ | ||
| try: | ||
| import torch as _torch | ||
|
|
||
| if _torch.cuda.is_available(): | ||
| return _torch.cuda.get_device_capability(device) | ||
| except Exception: | ||
| pass | ||
| return (0, 0) | ||
|
|
||
|
|
||
| def device_capability_major(device: int = 0) -> int: | ||
| """Convenience wrapper: just the major-capability int (0 if no CUDA).""" | ||
| return device_capability(device)[0] |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,117 @@ | ||
| # Copyright (c) 2026 SandAI. All Rights Reserved. | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| import torch | ||
| import torch._inductor.fx_passes.pre_grad | ||
|
|
||
| from ...magi_depyf.timeline import emit_pass_lifecycle | ||
| from ..pass_base import MagiInductorPass | ||
|
|
||
|
|
||
| class RemoveUselessOpsPass(MagiInductorPass): | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
| """ | ||
| Remove useless convert, view, reshape operations. | ||
| When their input already has the target type and shape, these operations are redundant. | ||
| """ | ||
|
|
||
| TARGET_METHODS = { | ||
| "view", | ||
| "reshape", | ||
| "to", | ||
| "type", | ||
| "contiguous", | ||
| "clone", | ||
| "flatten", | ||
| "permute", | ||
| "transpose", | ||
| "t", | ||
| "unsqueeze", | ||
| "squeeze", | ||
| "expand", | ||
| "repeat", | ||
| "bfloat16", | ||
| "float", | ||
| "half", | ||
| "int", | ||
| "long", | ||
| "short", | ||
| "double", | ||
| "bool", | ||
| "byte", | ||
| } | ||
|
|
||
| @staticmethod | ||
| def _get_tensor_info(node: torch.fx.Node): | ||
| # Get tensor info from example_value | ||
| if "example_value" in node.meta: | ||
| val = node.meta["example_value"] | ||
| if isinstance(val, torch.Tensor): | ||
| return val.shape, val.dtype, val.stride() | ||
| elif isinstance(val, (list, tuple)) and len(val) > 0 and isinstance(val[0], torch.Tensor): | ||
| return val[0].shape, val[0].dtype, val[0].stride() | ||
|
|
||
| return None, None, None | ||
|
|
||
| def is_applicable(self, graph: torch.fx.Graph, shape: int | None = None) -> bool: | ||
| for node in graph.nodes: | ||
| if node.op == "call_method" and node.target in self.TARGET_METHODS: | ||
| return True | ||
| return False | ||
|
|
||
| @emit_pass_lifecycle | ||
| def __call__(self, graph: torch.fx.Graph): | ||
| nodes_to_remove = [] | ||
|
|
||
| for node in graph.nodes: | ||
| is_target_method = node.op == "call_method" and node.target in self.TARGET_METHODS | ||
| if not is_target_method: | ||
| continue | ||
|
|
||
| # Need at least one argument (the input tensor) | ||
| if not node.args or not isinstance(node.args[0], torch.fx.Node): | ||
| continue | ||
|
|
||
| input_node = node.args[0] | ||
|
|
||
| node_shape, node_dtype, node_stride = self._get_tensor_info(node) | ||
| input_shape, input_dtype, input_stride = self._get_tensor_info(input_node) | ||
| if node_shape is None or input_shape is None: | ||
| continue | ||
| if node_dtype is None or input_dtype is None: | ||
| continue | ||
| # Some ops or metadata might not have stride properly captured, | ||
| # but if they do, we should require them to match to be totally safe against contiguous-forcing ops. | ||
| if node_stride is not None and input_stride is not None and node_stride != input_stride: | ||
| continue | ||
|
|
||
| # Check if shape and dtype match exactly | ||
| if node_shape == input_shape and node_dtype == input_dtype: | ||
| # For _to_copy, ensure we are not changing memory format or device or other properties implicitly, | ||
| # but typically in full graph if shape and dtype match, and it's on the same device, it's safe. | ||
| # Let's also check device just in case if it's available. | ||
| def get_device(n): | ||
| if "example_value" in n.meta and isinstance(n.meta["example_value"], torch.Tensor): | ||
| return n.meta["example_value"].device | ||
|
|
||
| node_device = get_device(node) | ||
| input_device = get_device(input_node) | ||
| if node_device is not None and input_device is not None and node_device != input_device: | ||
| continue | ||
|
|
||
| # Replace uses | ||
| node.replace_all_uses_with(input_node) | ||
| nodes_to_remove.append(node) | ||
|
|
||
| for node in nodes_to_remove: | ||
| graph.erase_node(node) | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,13 @@ | ||
| # Copyright (c) 2026 SandAI. All Rights Reserved. | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. |
13 changes: 13 additions & 0 deletions
13
magi_compiler/passes/piecewise_graph/fusion/blackwell_geforce/__init__.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,13 @@ | ||
| # Copyright (c) 2026 SandAI. All Rights Reserved. | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Treat cutlass as third-party and provide install cmd cause users may install magi_compiler without docker.
Update commands in readme.md pls~