Skip to content

[PRE-TASK] Add Pass to Convert Triton IR Dialect to VIR Dialect #655

@xlinsist

Description

@xlinsist

Deliverable

Develop a pass named triton-to-dynamic-vector that lowers the add.ttir example into the VIR dialect.

Task Description

add.ttir ("ttir" is short for Triton Tensor IR) is generated by the triton-cpu compiler from the original Triton kernel implementation add.py. Currently, add.ttir is further lowered into add.ttcir ("ttcir" is short for Triton Tensor CPU IR). This PR is supposed to replace that lowering process so that we finally get IR at the same level as add.ttcir, but using dynamic vector types.

Before developing the automated pass, please start by manually adapting the add.ttcir file into a version that uses a dynamic vector representation. Save this new file as add_dynamic.ttcir and verify it can be successfully compiled to RVV assembly code. Please post the generated RVV assembly instructions for add_dynamic.ttcir in the comment for further discussion.

The operation definitions of the relevant dialects are listed as follows:

Consider mapping the load and store operation first.

Tips:

  1. Triton-specific semantics like "!tt.ptr" can be temporarily replaced with equivalent, generic MLIR pointer semantics. This ensures the RVV assembly can be successfully generated.

  2. All "loc (location)" information is optional. It is used only for debugging and code location tracking.

  3. The appendix shows key code segments from add.py, add.ttir, and add.ttcir.

Timeline

Coding phase: 2025-12-31 to 2026-01-07
Code review: begins on 2026-01-08

Appendix

add.py

import os
import torch

import triton
import triton.language as tl

# Run Triton on the CPU backend.
triton.runtime.driver.set_active_to_cpu()
USE_GPU = False

def get_add_kernel_autotune_config():
    """Build the autotune configurations for the add kernel.

    One config per candidate BLOCK_SIZE ("blocked" scheme: each program
    handles one contiguous chunk of BLOCK_SIZE elements).  The full
    candidate list is returned only when ENABLE_AUTOTUNING=add_kernel;
    otherwise a single default config is used.
    """
    candidates = [
        triton.Config({"BLOCK_SIZE": size})
        for size in (64, 128, 256, 512, 1024)
    ]
    if os.getenv("ENABLE_AUTOTUNING") == "add_kernel":
        assert len(candidates) > 1
        return candidates
    # Default: one fixed configuration.
    return [triton.Config({"BLOCK_SIZE": 256})]


@triton.autotune(
    configs=get_add_kernel_autotune_config(),
    key=[],
)
@triton.jit
def add_kernel(
    a_ptr, b_ptr, c_ptr,
    n_elements,
    stride_a, stride_b, stride_c,
    BLOCK_SIZE: tl.constexpr,
):
    # Elementwise add: c[i] = a[i] + b[i] over this program's block.
    pid = tl.program_id(axis=0)

    # Element index range owned by this program:
    # pid*BLOCK_SIZE ... pid*BLOCK_SIZE + BLOCK_SIZE - 1
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    # Tail mask: lanes past n_elements are neither loaded nor stored.
    mask = offsets < n_elements

    # Arbitrary strides are supported (for contiguous tensors the stride is 1).
    a = tl.load(a_ptr + offsets * stride_a, mask=mask, other=0.0)
    b = tl.load(b_ptr + offsets * stride_b, mask=mask, other=0.0)
    c = a + b
    tl.store(c_ptr + offsets * stride_c, c, mask=mask)


def add(a: torch.Tensor, b: torch.Tensor):
    """Elementwise a + b computed by the Triton kernel.

    Both inputs are flattened to 1-D, processed in BLOCK_SIZE-sized
    chunks (one Triton program per chunk), and the result is reshaped
    back to a's original shape.
    """
    assert a.shape == b.shape, "a and b must have the same shape"
    assert a.is_contiguous(), "Tensor a must be contiguous"
    assert b.is_contiguous(), "Tensor b must be contiguous"
    assert a.device == b.device, "a and b must be on the same device"
    assert a.dtype == b.dtype, "a and b must have the same dtype"

    # Flatten to 1-D for the blocked elementwise add.
    flat_a = a.reshape(-1)
    flat_b = b.reshape(-1)
    total = flat_a.numel()

    flat_c = torch.empty_like(flat_a)

    def grid(meta):
        # One program instance per BLOCK_SIZE-sized chunk.
        return (triton.cdiv(total, meta["BLOCK_SIZE"]),)

    add_kernel[grid](
        flat_a, flat_b, flat_c,
        total,
        flat_a.stride(0), flat_b.stride(0), flat_c.stride(0),
    )
    return flat_c.reshape(a.shape)


def test_add():
    """Check the Triton add against torch's eager add, bit-for-bit."""
    torch.manual_seed(0)
    # A deliberately ragged shape so the kernel's tail mask is exercised.
    shape = (179, 167)
    lhs = torch.randn(shape, device="cpu", dtype=torch.float32)
    rhs = torch.randn(shape, device="cpu", dtype=torch.float32)

    got = add(lhs, rhs)
    want = lhs + rhs

    assert torch.allclose(got, want, atol=0, rtol=0), "❌ Triton and Torch differ"
    print("✅ Triton and Torch match")


# Run the self-test when this script is executed.
test_add()

add.ttir(generated from add.py)

// add.ttir — Triton tensor-level IR for add_kernel, specialized to BLOCK_SIZE = 256.
// All vectorizable values are rank-1 tensors of 256 elements; %6 is the bounds
// mask (offsets < n_elements) shared by both tt.load ops and the tt.store.
module {
  tt.func public @add_kernel(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32} loc("/home/zhouxulin/intern/AI-Benchmark/test_add.py":30:0), %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32} loc("/home/zhouxulin/intern/AI-Benchmark/test_add.py":30:0), %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32} loc("/home/zhouxulin/intern/AI-Benchmark/test_add.py":30:0), %arg3: i32 loc("/home/zhouxulin/intern/AI-Benchmark/test_add.py":30:0)) attributes {noinline = false} {
    // %cst is the "other" value (0.0) substituted into masked-off lanes of tt.load.
    %cst = arith.constant dense<0.000000e+00> : tensor<256xf32> loc(#loc1)
    %c256_i32 = arith.constant 256 : i32 loc(#loc1)
    // offsets = pid * 256 + [0..255]
    %0 = tt.get_program_id x : i32 loc(#loc2)
    %1 = arith.muli %0, %c256_i32 : i32 loc(#loc3)
    %2 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32> loc(#loc4)
    %3 = tt.splat %1 : i32 -> tensor<256xi32> loc(#loc5)
    %4 = arith.addi %3, %2 : tensor<256xi32> loc(#loc5)
    // mask %6 = offsets < n_elements (%arg3 splatted across all lanes)
    %5 = tt.splat %arg3 : i32 -> tensor<256xi32> loc(#loc6)
    %6 = arith.cmpi slt, %4, %5 : tensor<256xi32> loc(#loc6)
    // masked gather from a_ptr (%arg0) at the per-lane addresses
    %7 = tt.splat %arg0 : !tt.ptr<f32> -> tensor<256x!tt.ptr<f32>> loc(#loc7)
    %8 = tt.addptr %7, %4 : tensor<256x!tt.ptr<f32>>, tensor<256xi32> loc(#loc7)
    %9 = tt.load %8, %6, %cst : tensor<256x!tt.ptr<f32>> loc(#loc8)
    // masked gather from b_ptr (%arg1)
    %10 = tt.splat %arg1 : !tt.ptr<f32> -> tensor<256x!tt.ptr<f32>> loc(#loc9)
    %11 = tt.addptr %10, %4 : tensor<256x!tt.ptr<f32>>, tensor<256xi32> loc(#loc9)
    %12 = tt.load %11, %6, %cst : tensor<256x!tt.ptr<f32>> loc(#loc10)
    %13 = arith.addf %9, %12 : tensor<256xf32> loc(#loc11)
    // masked scatter of the sum to c_ptr (%arg2)
    %14 = tt.splat %arg2 : !tt.ptr<f32> -> tensor<256x!tt.ptr<f32>> loc(#loc12)
    %15 = tt.addptr %14, %4 : tensor<256x!tt.ptr<f32>>, tensor<256xi32> loc(#loc12)
    tt.store %15, %13, %6 : tensor<256x!tt.ptr<f32>> loc(#loc13)
    tt.return loc(#loc14)
  } loc(#loc)
} loc(#loc)
// Location aliases (debug info only; optional per the task tips).
#loc1 = loc(unknown)
#loc2 = loc("/home/zhouxulin/intern/AI-Benchmark/test_add.py":36:24)
#loc3 = loc("/home/zhouxulin/intern/AI-Benchmark/test_add.py":39:20)
#loc4 = loc("/home/zhouxulin/intern/AI-Benchmark/test_add.py":39:46)
#loc5 = loc("/home/zhouxulin/intern/AI-Benchmark/test_add.py":39:33)
#loc6 = loc("/home/zhouxulin/intern/AI-Benchmark/test_add.py":40:21)
#loc7 = loc("/home/zhouxulin/intern/AI-Benchmark/test_add.py":43:24)
#loc8 = loc("/home/zhouxulin/intern/AI-Benchmark/test_add.py":43:16)
#loc9 = loc("/home/zhouxulin/intern/AI-Benchmark/test_add.py":44:24)
#loc10 = loc("/home/zhouxulin/intern/AI-Benchmark/test_add.py":44:16)
#loc11 = loc("/home/zhouxulin/intern/AI-Benchmark/test_add.py":45:12)
#loc12 = loc("/home/zhouxulin/intern/AI-Benchmark/test_add.py":46:21)
#loc13 = loc("/home/zhouxulin/intern/AI-Benchmark/test_add.py":46:41)
#loc14 = loc("/home/zhouxulin/intern/AI-Benchmark/test_add.py":46:4)

add.ttcir(lowered from add.ttir)

// add.ttcir — add.ttir lowered to the CPU path: tensors become fixed-size
// vector<256x...> values, pointer arithmetic becomes scalar tt.addptr +
// triton_cpu.ptr_to_memref, and the masked loads/stores become
// vector.maskedload / vector.maskedstore.  This is the fixed-shape IR the
// new triton-to-dynamic-vector pass should reproduce with dynamic vectors.
module {
  tt.func public @add_kernel(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32} loc("/home/zhouxulin/intern/AI-Benchmark/test_add.py":30:0), %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32} loc("/home/zhouxulin/intern/AI-Benchmark/test_add.py":30:0), %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32} loc("/home/zhouxulin/intern/AI-Benchmark/test_add.py":30:0), %arg3: i32 loc("/home/zhouxulin/intern/AI-Benchmark/test_add.py":30:0)) attributes {noinline = false} {
    %c0 = arith.constant 0 : index loc(#loc1)
    // %cst is tt.make_range materialized as a dense constant: the little-endian
    // i32 iota 0,1,2,...,255.
    // NOTE(review): the literal below is wrapped across two lines by the paste;
    // it is a single dense<"0x..."> literal in the real file.
    %cst = arith.constant dense<"0x000000000100000002000000030000000400000005000000060000000700000008000000090000000A0000000B0000000C0000000D0000000E0000000F000000100000001100000012000000130000001400000015000000160000001700000018000000190000001A0000001B0000001C0000001D0000001E0000001F000000200000002100000022000000230000002400000025000000260000002700000028000000290000002A0000002B0000002C0000002D0000002E0000002F000000300000003100000032000000330000003400000035000000360000003700000038000000390000003A0000003B0000003C0000003D0000003E0000003F000000400000004100000042000000430000004400000045000000460000004700000048000000490000004A0000004B0000004C0000004D0000004E0000004F000000500000005100000052000000530000005400000055000000560000005700000058000000590000005A0000005B0000005C0000005D0000005E0000005F000000600000006100000062000000630000006400000065000000660000006700000068000000690000006A0000006B0000006C0000006D0000006E0000006F000000700000007100000072000000730000007400000075000000760000007700000078000000790000007A0000007B0000007C0000007D0000007E0000007F000000800000008100000082000000830000008400000085000000860000008700000088000000890000008A0000008B0000008C0000008D0000008E0000008F000000900000009100000092000000930000009400000095000000960000009700000098000000990000009A0000009B0000009C0000009D0000009E0000009F000000A0000000A1000000A2000000A3000000A4000000A5000000A6000000A7000000A8000000A9000000AA000000AB000000AC000000AD000000AE000000AF000000B0000000B1000000B2000000B3000000B4000000B5000000B6000000B7000000B8000000B9000000BA000000BB000000BC000000BD000000BE000000BF000000C0000000C1000000C2000000C3000000C4000000C5000000C6000000C7000000C8000000C9000000CA000000CB000000CC000000CD000000CE000000CF000000D0000000D1000000D2000000D3000000D4000000D5000000D6000000D7000000D8000000D9000000DA000000DB000000DC000000DD000000DE000000DF000000E0000000E1000000E2000000E3000000E4000000E5000000E6000000E7000000E8000000E9000000EA000000EB000000EC000000ED000000EE000000EF000000F0000000F1000000F2000000F3000000F4000000F5000
000F6000000F7000000F8000000F9000000FA000000FB000000FC000000FD000000FE000000FF000000"> : vector<256xi32> loc(#loc1)
    %c256_i32 = arith.constant 256 : i32 loc(#loc1)
    // Pass-through value for masked-off lanes of vector.maskedload.
    %cst_0 = arith.constant dense<0.000000e+00> : vector<256xf32> loc(#loc1)
    %0 = tt.get_program_id x : i32 loc(#loc2)
    // %1 = block base offset = pid * 256 (now scalar; folded into tt.addptr below)
    %1 = arith.muli %0, %c256_i32 : i32 loc(#loc3)
    %2 = vector.splat %1 : vector<256xi32> loc(#loc4)
    %3 = arith.addi %2, %cst : vector<256xi32> loc(#loc4)
    // mask %5 = (base + iota) < n_elements
    %4 = vector.splat %arg3 : vector<256xi32> loc(#loc5)
    %5 = arith.cmpi slt, %3, %4 : vector<256xi32> loc(#loc5)
    // a: view a_ptr + base as a 256-element memref, masked load from it
    %6 = tt.addptr %arg0, %1 : !tt.ptr<f32>, i32 loc(#loc6)
    %7 = triton_cpu.ptr_to_memref %6 : <f32> -> memref<256xf32> loc(#loc7)
    %8 = vector.maskedload %7[%c0], %5, %cst_0 : memref<256xf32>, vector<256xi1>, vector<256xf32> into vector<256xf32> loc(#loc7)
    // b: same pattern for b_ptr
    %9 = tt.addptr %arg1, %1 : !tt.ptr<f32>, i32 loc(#loc8)
    %10 = triton_cpu.ptr_to_memref %9 : <f32> -> memref<256xf32> loc(#loc9)
    %11 = vector.maskedload %10[%c0], %5, %cst_0 : memref<256xf32>, vector<256xi1>, vector<256xf32> into vector<256xf32> loc(#loc9)
    %12 = arith.addf %8, %11 : vector<256xf32> loc(#loc10)
    // c: masked store of the sum through c_ptr
    %13 = tt.addptr %arg2, %1 : !tt.ptr<f32>, i32 loc(#loc11)
    %14 = triton_cpu.ptr_to_memref %13 : <f32> -> memref<256xf32> loc(#loc12)
    vector.maskedstore %14[%c0], %5, %12 : memref<256xf32>, vector<256xi1>, vector<256xf32> loc(#loc12)
    tt.return loc(#loc13)
  } loc(#loc)
} loc(#loc)

// Location aliases (debug info only; optional per the task tips).
#loc1 = loc(unknown)
#loc2 = loc("/home/zhouxulin/intern/AI-Benchmark/test_add.py":36:24)
#loc3 = loc("/home/zhouxulin/intern/AI-Benchmark/test_add.py":39:20)
#loc4 = loc("/home/zhouxulin/intern/AI-Benchmark/test_add.py":39:33)
#loc5 = loc("/home/zhouxulin/intern/AI-Benchmark/test_add.py":40:21)
#loc6 = loc("/home/zhouxulin/intern/AI-Benchmark/test_add.py":43:24)
#loc7 = loc("/home/zhouxulin/intern/AI-Benchmark/test_add.py":43:16)
#loc8 = loc("/home/zhouxulin/intern/AI-Benchmark/test_add.py":44:24)
#loc9 = loc("/home/zhouxulin/intern/AI-Benchmark/test_add.py":44:16)
#loc10 = loc("/home/zhouxulin/intern/AI-Benchmark/test_add.py":45:12)
#loc11 = loc("/home/zhouxulin/intern/AI-Benchmark/test_add.py":46:21)
#loc12 = loc("/home/zhouxulin/intern/AI-Benchmark/test_add.py":46:41)
#loc13 = loc("/home/zhouxulin/intern/AI-Benchmark/test_add.py":46:4)

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions