-
Notifications
You must be signed in to change notification settings - Fork 239
Description
Deliverable
Develop a pass named triton-to-dynamic-vector to lower add.ttir into VIR dialect examples.
Task Description
add.ttir ("ttir" is short for Triton Tensor IR) is generated by the triton-cpu compiler from the original Triton kernel implementation add.py. Currently, add.ttir is further lowered into add.ttcir ("ttcir" is short for Triton Tensor CPU IR). This PR is supposed to replace that lowering process so that we finally obtain IR at the same level as add.ttcir, but using a dynamic vector type.
Before developing the automated pass, please start by manually adapting the add.ttcir file into a version that uses a dynamic vector representation. Save this new file as add_dynamic.ttcir and verify it can be successfully compiled to RVV assembly code. Please post the generated RVV assembly instructions for add_dynamic.ttcir in the comment for further discussion.
The operation definitions of the relevant dialects are listed as follows:
ttir — the source dialect; ttcir — the reference dialect; VIR — the target dialect
Consider mapping the load and store operation first.
Tips:
-
Triton-specific semantics like "!tt.ptr" can be temporarily replaced with equivalent, generic MLIR pointer semantics. This ensures the RVV assembly can be successfully generated.
-
All "loc (location)" information is optional. It is used only for debugging and code location tracking.
-
The appendix shows key code segments from
add.py, add.ttir, and add.ttcir.
Timeline
Coding phase: 2025-12-31 to 2026-01-07
Code review: begins on 2026-01-08
Appendix
add.py
import os
import torch
import triton
import triton.language as tl
# Run Triton on the CPU backend (required: the ttir in the appendix is
# produced by the triton-cpu compiler, not the GPU path).
triton.runtime.driver.set_active_to_cpu()
USE_GPU = False
def get_add_kernel_autotune_config():
    """Return the autotune configurations for ``add_kernel``.

    When the environment variable ENABLE_AUTOTUNING equals "add_kernel",
    expose one config per candidate block size so Triton can benchmark
    them all; otherwise fall back to a single default of BLOCK_SIZE=256.
    """
    # Blocked strategy: each program handles a BLOCK_SIZE-long slice.
    candidates = [
        triton.Config({"BLOCK_SIZE": size})
        for size in (64, 128, 256, 512, 1024)
    ]
    if os.getenv("ENABLE_AUTOTUNING") == "add_kernel":
        assert len(candidates) > 1
        return candidates
    # Default: a single configuration.
    return [triton.Config({"BLOCK_SIZE": 256})]
# Blocked element-wise addition kernel: each program instance computes
# c[i] = a[i] + b[i] for its own BLOCK_SIZE-wide slice of the flattened
# vectors. BLOCK_SIZE is a compile-time constant chosen by the autotuner.
@triton.autotune(
    configs=get_add_kernel_autotune_config(),
    key=[],
)
@triton.jit
def add_kernel(
    a_ptr, b_ptr, c_ptr,
    n_elements,
    stride_a, stride_b, stride_c,
    BLOCK_SIZE: tl.constexpr,
):
    pid = tl.program_id(axis=0)
    # Element index range owned by this program:
    # pid*BLOCK_SIZE ... pid*BLOCK_SIZE + BLOCK_SIZE - 1
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    # Guards the tail block when n_elements is not a multiple of BLOCK_SIZE;
    # masked-off lanes load 0.0 and are not stored.
    mask = offsets < n_elements
    # Arbitrary strides are supported (a contiguous tensor has stride 1).
    a = tl.load(a_ptr + offsets * stride_a, mask=mask, other=0.0)
    b = tl.load(b_ptr + offsets * stride_b, mask=mask, other=0.0)
    c = a + b
    tl.store(c_ptr + offsets * stride_c, c, mask=mask)
def add(a: torch.Tensor, b: torch.Tensor):
    """Element-wise ``a + b`` computed by the Triton kernel.

    Both tensors must be contiguous, same shape/dtype/device. Returns a
    new tensor shaped like ``a``.
    """
    assert a.shape == b.shape, "a and b must have the same shape"
    assert a.is_contiguous(), "Tensor a must be contiguous"
    assert b.is_contiguous(), "Tensor b must be contiguous"
    assert a.device == b.device, "a and b must be on the same device"
    assert a.dtype == b.dtype, "a and b must have the same dtype"
    # Flatten to 1-D and perform the addition in blocks.
    flat_a = a.reshape(-1)
    flat_b = b.reshape(-1)
    total = flat_a.numel()
    flat_out = torch.empty_like(flat_a)
    # One program per BLOCK_SIZE-wide chunk (rounded up for the tail).
    grid = lambda META: (triton.cdiv(total, META["BLOCK_SIZE"]),)
    add_kernel[grid](
        flat_a, flat_b, flat_out,
        total,
        flat_a.stride(0), flat_b.stride(0), flat_out.stride(0),
    )
    return flat_out.reshape(a.shape)
def test_add():
    """Check the Triton kernel against plain torch addition."""
    torch.manual_seed(0)
    # Deliberately non-round sizes so the tail mask in the kernel is
    # actually exercised.
    shape = (179, 167)
    lhs = torch.randn(shape, device="cpu", dtype=torch.float32)
    rhs = torch.randn(shape, device="cpu", dtype=torch.float32)
    triton_output = add(lhs, rhs)
    torch_output = lhs + rhs
    assert torch.allclose(triton_output, torch_output, atol=0, rtol=0), "❌ Triton and Torch differ"
    print("✅ Triton and Torch match")
test_add()
add.ttir (generated from add.py)
// add.ttir: tensor-level Triton IR for add_kernel, as emitted by triton-cpu.
// Every vector value is a fixed tensor<256x...> because BLOCK_SIZE = 256.
// %arg0/%arg1/%arg2 are a_ptr/b_ptr/c_ptr; %arg3 is n_elements.
module {
tt.func public @add_kernel(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32} loc("/home/zhouxulin/intern/AI-Benchmark/test_add.py":30:0), %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32} loc("/home/zhouxulin/intern/AI-Benchmark/test_add.py":30:0), %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32} loc("/home/zhouxulin/intern/AI-Benchmark/test_add.py":30:0), %arg3: i32 loc("/home/zhouxulin/intern/AI-Benchmark/test_add.py":30:0)) attributes {noinline = false} {
// `other=0.0` fill value for masked-off lanes of tt.load.
%cst = arith.constant dense<0.000000e+00> : tensor<256xf32> loc(#loc1)
%c256_i32 = arith.constant 256 : i32 loc(#loc1)
// pid = tl.program_id(axis=0)
%0 = tt.get_program_id x : i32 loc(#loc2)
// offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
%1 = arith.muli %0, %c256_i32 : i32 loc(#loc3)
%2 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32> loc(#loc4)
%3 = tt.splat %1 : i32 -> tensor<256xi32> loc(#loc5)
%4 = arith.addi %3, %2 : tensor<256xi32> loc(#loc5)
// mask = offsets < n_elements
%5 = tt.splat %arg3 : i32 -> tensor<256xi32> loc(#loc6)
%6 = arith.cmpi slt, %4, %5 : tensor<256xi32> loc(#loc6)
// a = tl.load(a_ptr + offsets, mask, other=0.0) — per-element pointer tensor.
%7 = tt.splat %arg0 : !tt.ptr<f32> -> tensor<256x!tt.ptr<f32>> loc(#loc7)
%8 = tt.addptr %7, %4 : tensor<256x!tt.ptr<f32>>, tensor<256xi32> loc(#loc7)
%9 = tt.load %8, %6, %cst : tensor<256x!tt.ptr<f32>> loc(#loc8)
// b = tl.load(b_ptr + offsets, mask, other=0.0)
%10 = tt.splat %arg1 : !tt.ptr<f32> -> tensor<256x!tt.ptr<f32>> loc(#loc9)
%11 = tt.addptr %10, %4 : tensor<256x!tt.ptr<f32>>, tensor<256xi32> loc(#loc9)
%12 = tt.load %11, %6, %cst : tensor<256x!tt.ptr<f32>> loc(#loc10)
// c = a + b
%13 = arith.addf %9, %12 : tensor<256xf32> loc(#loc11)
// tl.store(c_ptr + offsets, c, mask)
%14 = tt.splat %arg2 : !tt.ptr<f32> -> tensor<256x!tt.ptr<f32>> loc(#loc12)
%15 = tt.addptr %14, %4 : tensor<256x!tt.ptr<f32>>, tensor<256xi32> loc(#loc12)
tt.store %15, %13, %6 : tensor<256x!tt.ptr<f32>> loc(#loc13)
tt.return loc(#loc14)
} loc(#loc)
} loc(#loc)
// Source locations in test_add.py (optional; debugging only).
#loc1 = loc(unknown)
#loc2 = loc("/home/zhouxulin/intern/AI-Benchmark/test_add.py":36:24)
#loc3 = loc("/home/zhouxulin/intern/AI-Benchmark/test_add.py":39:20)
#loc4 = loc("/home/zhouxulin/intern/AI-Benchmark/test_add.py":39:46)
#loc5 = loc("/home/zhouxulin/intern/AI-Benchmark/test_add.py":39:33)
#loc6 = loc("/home/zhouxulin/intern/AI-Benchmark/test_add.py":40:21)
#loc7 = loc("/home/zhouxulin/intern/AI-Benchmark/test_add.py":43:24)
#loc8 = loc("/home/zhouxulin/intern/AI-Benchmark/test_add.py":43:16)
#loc9 = loc("/home/zhouxulin/intern/AI-Benchmark/test_add.py":44:24)
#loc10 = loc("/home/zhouxulin/intern/AI-Benchmark/test_add.py":44:16)
#loc11 = loc("/home/zhouxulin/intern/AI-Benchmark/test_add.py":45:12)
#loc12 = loc("/home/zhouxulin/intern/AI-Benchmark/test_add.py":46:21)
#loc13 = loc("/home/zhouxulin/intern/AI-Benchmark/test_add.py":46:41)
#loc14 = loc("/home/zhouxulin/intern/AI-Benchmark/test_add.py":46:4)
add.ttcir (lowered from add.ttir)
// add.ttcir: vector-level Triton CPU IR lowered from add.ttir. Fixed-width
// vector<256x...> types throughout — the add_dynamic.ttcir deliverable should
// replace these with a dynamic/scalable vector representation for RVV.
module {
tt.func public @add_kernel(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32} loc("/home/zhouxulin/intern/AI-Benchmark/test_add.py":30:0), %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32} loc("/home/zhouxulin/intern/AI-Benchmark/test_add.py":30:0), %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32} loc("/home/zhouxulin/intern/AI-Benchmark/test_add.py":30:0), %arg3: i32 loc("/home/zhouxulin/intern/AI-Benchmark/test_add.py":30:0)) attributes {noinline = false} {
// Base index for the memref loads/stores below.
%c0 = arith.constant 0 : index loc(#loc1)
// Materialized tl.arange(0, 256): the i32 values 0..255 as little-endian bytes.
// NOTE(review): this dense<"0x..."> blob was wrapped onto a second line by text
// extraction; rejoin it into a single line before feeding the file to a parser.
%cst = arith.constant dense<"0x000000000100000002000000030000000400000005000000060000000700000008000000090000000A0000000B0000000C0000000D0000000E0000000F000000100000001100000012000000130000001400000015000000160000001700000018000000190000001A0000001B0000001C0000001D0000001E0000001F000000200000002100000022000000230000002400000025000000260000002700000028000000290000002A0000002B0000002C0000002D0000002E0000002F000000300000003100000032000000330000003400000035000000360000003700000038000000390000003A0000003B0000003C0000003D0000003E0000003F000000400000004100000042000000430000004400000045000000460000004700000048000000490000004A0000004B0000004C0000004D0000004E0000004F000000500000005100000052000000530000005400000055000000560000005700000058000000590000005A0000005B0000005C0000005D0000005E0000005F000000600000006100000062000000630000006400000065000000660000006700000068000000690000006A0000006B0000006C0000006D0000006E0000006F000000700000007100000072000000730000007400000075000000760000007700000078000000790000007A0000007B0000007C0000007D0000007E0000007F000000800000008100000082000000830000008400000085000000860000008700000088000000890000008A0000008B0000008C0000008D0000008E0000008F000000900000009100000092000000930000009400000095000000960000009700000098000000990000009A0000009B0000009C0000009D0000009E0000009F000000A0000000A1000000A2000000A3000000A4000000A5000000A6000000A7000000A8000000A9000000AA000000AB000000AC000000AD000000AE000000AF000000B0000000B1000000B2000000B3000000B4000000B5000000B6000000B7000000B8000000B9000000BA000000BB000000BC000000BD000000BE000000BF000000C0000000C1000000C2000000C3000000C4000000C5000000C6000000C7000000C8000000C9000000CA000000CB000000CC000000CD000000CE000000CF000000D0000000D1000000D2000000D3000000D4000000D5000000D6000000D7000000D8000000D9000000DA000000DB000000DC000000DD000000DE000000DF000000E0000000E1000000E2000000E3000000E4000000E5000000E6000000E7000000E8000000E9000000EA000000EB000000EC000000ED000000EE000000EF000000F0000000F1000000F2000000F3000000F4000000F5000000F
6000000F7000000F8000000F9000000FA000000FB000000FC000000FD000000FE000000FF000000"> : vector<256xi32> loc(#loc1)
%c256_i32 = arith.constant 256 : i32 loc(#loc1)
// `other=0.0` fill value for masked-off lanes.
%cst_0 = arith.constant dense<0.000000e+00> : vector<256xf32> loc(#loc1)
// pid = tl.program_id(axis=0)
%0 = tt.get_program_id x : i32 loc(#loc2)
// Scalar block base offset: pid * 256.
%1 = arith.muli %0, %c256_i32 : i32 loc(#loc3)
// offsets = splat(base) + [0..255]
%2 = vector.splat %1 : vector<256xi32> loc(#loc4)
%3 = arith.addi %2, %cst : vector<256xi32> loc(#loc4)
// mask = offsets < n_elements
%4 = vector.splat %arg3 : vector<256xi32> loc(#loc5)
%5 = arith.cmpi slt, %3, %4 : vector<256xi32> loc(#loc5)
// Load a: advance the base pointer by the scalar offset, view it as a
// memref<256xf32>, then contiguous masked load. The per-element pointer
// tensor from ttir has been strength-reduced to base + unit-stride access —
// presumably valid because stride == 1 here; confirm for the dynamic version.
%6 = tt.addptr %arg0, %1 : !tt.ptr<f32>, i32 loc(#loc6)
%7 = triton_cpu.ptr_to_memref %6 : <f32> -> memref<256xf32> loc(#loc7)
%8 = vector.maskedload %7[%c0], %5, %cst_0 : memref<256xf32>, vector<256xi1>, vector<256xf32> into vector<256xf32> loc(#loc7)
// Load b: same pattern on %arg1.
%9 = tt.addptr %arg1, %1 : !tt.ptr<f32>, i32 loc(#loc8)
%10 = triton_cpu.ptr_to_memref %9 : <f32> -> memref<256xf32> loc(#loc9)
%11 = vector.maskedload %10[%c0], %5, %cst_0 : memref<256xf32>, vector<256xi1>, vector<256xf32> into vector<256xf32> loc(#loc9)
// c = a + b
%12 = arith.addf %8, %11 : vector<256xf32> loc(#loc10)
// Store c through %arg2 with the same mask.
%13 = tt.addptr %arg2, %1 : !tt.ptr<f32>, i32 loc(#loc11)
%14 = triton_cpu.ptr_to_memref %13 : <f32> -> memref<256xf32> loc(#loc12)
vector.maskedstore %14[%c0], %5, %12 : memref<256xf32>, vector<256xi1>, vector<256xf32> loc(#loc12)
tt.return loc(#loc13)
} loc(#loc)
} loc(#loc)
// Source locations in test_add.py (optional; debugging only).
#loc1 = loc(unknown)
#loc2 = loc("/home/zhouxulin/intern/AI-Benchmark/test_add.py":36:24)
#loc3 = loc("/home/zhouxulin/intern/AI-Benchmark/test_add.py":39:20)
#loc4 = loc("/home/zhouxulin/intern/AI-Benchmark/test_add.py":39:33)
#loc5 = loc("/home/zhouxulin/intern/AI-Benchmark/test_add.py":40:21)
#loc6 = loc("/home/zhouxulin/intern/AI-Benchmark/test_add.py":43:24)
#loc7 = loc("/home/zhouxulin/intern/AI-Benchmark/test_add.py":43:16)
#loc8 = loc("/home/zhouxulin/intern/AI-Benchmark/test_add.py":44:24)
#loc9 = loc("/home/zhouxulin/intern/AI-Benchmark/test_add.py":44:16)
#loc10 = loc("/home/zhouxulin/intern/AI-Benchmark/test_add.py":45:12)
#loc11 = loc("/home/zhouxulin/intern/AI-Benchmark/test_add.py":46:21)
#loc12 = loc("/home/zhouxulin/intern/AI-Benchmark/test_add.py":46:41)
#loc13 = loc("/home/zhouxulin/intern/AI-Benchmark/test_add.py":46:4)