-
Notifications
You must be signed in to change notification settings - Fork 429
feat(commonir): add commonir abstract backend #1754
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
350ed98
aa21242
a7b215a
cc1057f
e6fe539
4741081
e0505cf
17fb2ad
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,54 @@ | ||
| # Copyright (c) Tile-AI Corporation. | ||
| # Licensed under the MIT License. | ||
| import os | ||
|
|
||
| import tilelang | ||
| import tilelang.language as T | ||
|
|
||
| import torch | ||
|
|
||
| dtype = "float32" | ||
| seq_len = 1024 | ||
|
|
||
|
|
||
| def vec_add(N, block_N, dtype="float32"): | ||
| n_num = N // block_N | ||
|
|
||
| @T.prim_func | ||
| def main( | ||
| A: T.Tensor((N), dtype), | ||
| B: T.Tensor((N), dtype), | ||
| C: T.Tensor((N), dtype), | ||
| ): | ||
| with T.Kernel(n_num, 1) as (by, bx): | ||
| start_y1 = by * block_N | ||
| start_y = start_y1 + bx | ||
| for (local_y) in T.Parallel(block_N): | ||
| y = start_y + local_y | ||
| C[y] = A[y] + B[y] | ||
|
|
||
| return main | ||
|
|
||
|
|
||
def test_vec_add():
    """Compile the vec_add kernel and check it against a torch reference on NPU.

    Raises:
        AssertionError: if the kernel output differs from torch beyond atol=1e-2.
    """
    # torch_npu must be imported so Tensor.npu() / the "npu" device are
    # registered with torch; without it the .npu() calls below fail.
    import torch_npu  # noqa: F401

    func = vec_add(seq_len, seq_len // 4)
    compiled_kernel = tilelang.compile(func)

    # getattr resolves e.g. "float32" -> torch.float32 without eval() on a string.
    torch_dtype = getattr(torch, dtype)
    v1 = torch.randn(size=[seq_len], dtype=torch_dtype).npu()
    v2 = torch.randn(size=[seq_len], dtype=torch_dtype).npu()
    v3 = torch.zeros(size=[seq_len], dtype=torch_dtype).npu()

    y_ref = v1 + v2
    compiled_kernel(v1, v2, v3)  # kernel writes the sum into v3 in place

    print(f'The maximum difference between torch and TileLang is '
          f'{torch.max(torch.abs(y_ref - v3))}')

    torch.testing.assert_close(v3, y_ref, atol=1e-2, rtol=0)


if __name__ == "__main__":
    test_vec_add()
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,52 @@ | ||
| # Copyright (c) Tile-AI Corporation. | ||
| # Licensed under the MIT License. | ||
| import os | ||
|
|
||
| import tilelang | ||
| import tilelang.language as T | ||
| from functools import partial | ||
|
|
||
| import torch | ||
| import torch_npu | ||
| import time | ||
| import numpy as np | ||
| from typing import Callable, Optional, Union, List | ||
|
|
||
|
|
||
| dtype = "float32" | ||
| seq_len = 1024 | ||
|
|
||
| def vec_add(N, block_N, dtype="float32"): | ||
| n_num = N // block_N | ||
|
|
||
| @T.prim_func | ||
| def main( | ||
| A: T.Tensor((N), dtype), | ||
| B: T.Tensor((N), dtype), | ||
| C: T.Tensor((N), dtype), | ||
| ): | ||
| with T.Kernel(n_num, 1) as (by, bx): | ||
| start_y1 = by * block_N | ||
| start_y = start_y1 + bx | ||
| for (local_y) in T.Parallel(block_N): | ||
| y = start_y + local_y | ||
| C[y] = A[y] + B[y] | ||
|
|
||
| return main | ||
|
|
||
def ref_program(v1, v2):
    """Reference implementation for the benchmark: element-wise sum of the inputs."""
    expected = v1 + v2
    return expected
|
|
||
def test_vec_add():
    """Validate the compiled vec_add kernel against the torch reference, then benchmark both."""
    kernel_source = vec_add(seq_len, seq_len // 4)
    jit_kernel = tilelang.compile(kernel_source, out_idx=[2])

    bench = jit_kernel.get_profiler()
    # Check correctness first so timings below are only taken for a valid kernel.
    bench.assert_allclose(ref_program, rtol=0.01, atol=0.01)
    baseline = bench.do_bench(ref_program, warmup=500)
    kernel_time = bench.do_bench(warmup=500)
    print(f"⏱ latency base is {baseline}")
    print(f"⏱ latency is {kernel_time}")


if __name__ == "__main__":
    test_vec_add()
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,52 @@ | ||||||||||||||||||||||||||
# Copyright (c) Tile-AI Corporation.
# Licensed under the MIT License.

import tilelang
import tilelang.language as T

import torch
import torch_npu

# NOTE(review): this NPU device lookup runs at import time and will fail on
# hosts without an initialized NPU; consider resolving the device inside
# main() instead — TODO confirm with the backend maintainers.
device = torch.npu.current_device()
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Module-level NPU device access may cause import failure.
🛡️ Proposed fix import torch
import torch_npu
-device = torch.npu.current_device()
dtype = torch.float16
def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
...
def main():
+ device = torch.npu.current_device()
func = matmul(1024, 1024, 1024, 128, 128, 32)📝 Committable suggestion
Suggested change
🤖 Prompt for AI Agents |
||||||||||||||||||||||||||
dtype = torch.float16  # element dtype of the tensors in the test below


def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
    """Build a tiled TileLang GEMM kernel computing C = A @ B.

    Args:
        M, N, K: problem sizes (A is M x K, B is K x N, C is M x N).
        block_M, block_N, block_K: tile sizes processed per kernel block.
        dtype: storage dtype of A, B and C.
        accum_dtype: dtype of the local accumulator fragment.

    Returns:
        The T.prim_func GEMM kernel; compile it with tilelang.compile().
    """

    @T.prim_func
    def gemm(
            A: T.Tensor((M, K), dtype),
            B: T.Tensor((K, N), dtype),
            C: T.Tensor((M, N), dtype),
    ):
        # One kernel block per (block_M x block_N) output tile, 128 threads each.
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K), dtype)
            B_shared = T.alloc_shared((block_K, block_N), dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)

            T.clear(C_local)
            # Pipeline the tile copies with the accumulation over K (3 stages).
            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):
                T.copy(A[by * block_M, k * block_K], A_shared)
                T.copy(B[k * block_K, bx * block_N], B_shared)
                T.gemm(A_shared, B_shared, C_local)

            # Write the accumulated tile back to the output.
            T.copy(C_local, C[by * block_M, bx * block_N])

    return gemm
|
|
||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||
def main():
    """Compile the GEMM for the 'commonir' backend and verify it against torch.

    Raises:
        AssertionError: if the kernel result differs from `a @ b` beyond 1e-2.
    """
    # Resolve the NPU device lazily so merely importing this module does not
    # require an initialized NPU (the module-level lookup was flagged in review).
    device = torch.npu.current_device()

    SIZEALL = 1024  # square problem size; used for both compilation and tensors
    func = matmul(SIZEALL, SIZEALL, SIZEALL, 128, 128, 32)
    kernel = tilelang.compile(func, target='commonir')

    torch.manual_seed(0)
    # Center inputs around zero to exercise cancellation in the fp16 GEMM.
    a = torch.rand((SIZEALL, SIZEALL), dtype=dtype, device=device) - 0.5
    b = torch.rand((SIZEALL, SIZEALL), dtype=dtype, device=device) - 0.5
    result = torch.zeros((SIZEALL, SIZEALL), dtype=dtype, device=device)

    kernel(a, b, result)  # kernel writes into `result` in place
    golden = a @ b
    torch.testing.assert_close(result, golden, atol=1e-2, rtol=1e-2)


if __name__ == "__main__":
    main()
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Date year appears incorrect.
The entry is dated
01/29/2025, but given the current date (January 2026) and the reverse-chronological ordering of the news section, this should likely be01/29/2026. With the current date, this entry would be out of order—it should appear much lower in the list, after entries from later in 2025.📝 Suggested fix
📝 Committable suggestion
🤖 Prompt for AI Agents