189 changes: 189 additions & 0 deletions examples/gemm/example_gemm_intrinsics_dcu.py
@@ -0,0 +1,189 @@
from tilelang import tvm as tvm
from tvm import DataType
import tilelang
import tilelang.language as T
from tilelang.intrinsics import get_swizzle_layout
from tilelang.intrinsics.mmac_macro_generator import (
    MatrixCoreIntrinEmitter,
)
from tilelang.transform import simplify_prim_func
from tilelang import disable_cache

disable_cache()


def make_swizzle_layout(shared_buf):
    dtype = shared_buf.dtype
    shape = shared_buf.shape

    can_swizzle = shape[-1] * DataType(dtype).bits == 512
    if not can_swizzle:
        return T.Layout(shape, lambda *args: args)

    def transform_func(i, j):
        new_warp_i, new_warp_j = get_swizzle_layout(i, j, shape[-1], dtype)
        return [new_warp_i, new_warp_j]

    return T.Layout(shape, transform_func)


@tilelang.jit(out_idx=[2])
@simplify_prim_func
def tl_matmul(
    M,
    N,
    K,
    in_dtype,
    out_dtype,
    accum_dtype,
):
    assert in_dtype in [
        "float16",
        "int8",
    ], "Currently only float16 and int8 are supported"
    assert out_dtype in [
        "float16",
        "float32",
        "int32",
    ], "Currently only float16, float32 and int32 are supported"

    micro_size_x = micro_size_y = micro_size_k = 16

    if out_dtype == "int32":
        micro_size_k = 32

Comment on lines +52 to +54
⚠️ Potential issue | 🟠 Major

micro_size_k should depend on in_dtype, not out_dtype

Int8 paths require K-fragment=32 regardless of the final out dtype. Gate on in_dtype to avoid incorrect tiling when out_dtype is int32 but inputs aren’t int8.

Apply this diff:

-    if out_dtype == "int32":
-        micro_size_k = 32
+    if in_dtype == "int8":
+        micro_size_k = 32
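For readability, the gating could also be factored into a small helper keyed on the input dtype. A minimal sketch; the helper name is illustrative, and the 32-deep K fragment for int8 is taken from the review suggestion above rather than re-verified against the MFMA ISA:

    # Hypothetical helper mirroring the review's suggestion: the K fragment is
    # chosen by the *input* dtype; int8 inputs use a 32-deep K fragment even
    # when out_dtype is int32.
    def select_micro_sizes(in_dtype: str):
        micro_size_x = micro_size_y = micro_size_k = 16
        if in_dtype == "int8":
            micro_size_k = 32
        return micro_size_x, micro_size_y, micro_size_k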

    # This is a debug config
    block_row_warps = 2
    block_col_warps = 2
    warp_row_tiles = 64
    warp_col_tiles = 64
    # chunk = 32 if in_dtype == "float16" else 64
    chunk = 32
    shared_scope = "shared.dyn"

    # Pipeline Stage
    stage = 2

    block_M = block_row_warps * warp_row_tiles
    block_N = block_col_warps * warp_col_tiles
    block_K = chunk

Comment on lines +67 to +70
⚠️ Potential issue | 🔴 Critical

Add bounds-checking guards or assertion for non-multiple tensor dimensions

The kernel grid uses ceildiv (line 99) but load/store loops assume perfect multiples. With M, N, or K not divisible by block_M, block_N, or block_K, the loops read/write out of bounds. The current test (M=N=K=16384) masks this because all are divisible by their block sizes.

Either add bounds checks in the A/B loads (lines 117–121) and C store (lines 138–145), or add an early assertion:

     block_K = chunk
@@
+    assert M % block_M == 0 and N % block_N == 0 and K % block_K == 0, \
+        "Example kernel requires M, N, K to be multiples of block sizes (M:%d, N:%d, K:%d; block_M:%d, block_N:%d, block_K:%d)" % (M, N, K, block_M, block_N, block_K)
     A_shape = (M, K)
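If the assertion route is preferred, it can also live in a thin host-side wrapper so callers fail fast before compilation. A minimal sketch, assuming the example's derived block sizes (block_M = block_N = 128, block_K = 32) and a hypothetical wrapper name:

    # Hypothetical pre-check wrapper; the default block sizes below assume the
    # example's configuration (block_row_warps * warp_row_tiles = 128, chunk = 32).
    def checked_tl_matmul(M, N, K, in_dtype, out_dtype, accum_dtype,
                          block_M=128, block_N=128, block_K=32):
        for name, dim, blk in (("M", M, block_M), ("N", N, block_N), ("K", K, block_K)):
            if dim % blk != 0:
                raise ValueError(f"{name}={dim} is not a multiple of block_{name}={blk}")
        return tl_matmul(M, N, K, in_dtype, out_dtype, accum_dtype)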

    A_shape = (M, K)
    B_shape = (N, K)
    A_shared_shape = (block_M, block_K)
    B_shared_shape = (block_N, block_K)
    C_shared_shape = (
        block_M // micro_size_x,
        block_N // micro_size_y,
        micro_size_x,
        micro_size_y,
    )

    warp_size = 64
    threads = warp_size * (block_row_warps * block_col_warps)
    local_size_a = (micro_size_x * micro_size_k) // warp_size
    local_size_b = (micro_size_y * micro_size_k) // warp_size
    local_size_c = (micro_size_x * micro_size_y) // warp_size
    warp_rows = warp_row_tiles // micro_size_x
    warp_cols = warp_col_tiles // micro_size_y

    # MMAC Wrapper to Auto Generate Code for MMAC
    mmac_emitter = MatrixCoreIntrinEmitter(
        a_dtype=in_dtype,
        b_dtype=in_dtype,
        accum_dtype=accum_dtype,
        a_transposed=False,
        b_transposed=True,
        block_row_warps=block_row_warps,
        block_col_warps=block_col_warps,
        warp_row_tiles=warp_row_tiles,
        warp_col_tiles=warp_col_tiles,
        chunk=chunk,
    )

    @T.prim_func
    def gemm_intrinsics(
            A: T.Tensor(A_shape, in_dtype),
            B: T.Tensor(B_shape, in_dtype),
            C: T.Tensor((M, N), out_dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
            A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope)
            B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope)
            C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope)
            A_local = T.alloc_local((warp_rows * local_size_a), in_dtype)
            B_local = T.alloc_local((warp_cols * local_size_b), in_dtype)
            C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype)

            T.annotate_layout(
                {
                    A_shared: make_swizzle_layout(A_shared),
                    B_shared: make_swizzle_layout(B_shared),
                }
            )

            # Improve L2 Cache
            T.use_swizzle(panel_size=10)

            T.clear(C_local)

            for ko in T.Pipelined((K // block_K), num_stages=stage):
                # Load A into shared memory
                for i, k in T.Parallel(block_M, block_K):
                    A_shared[i, k] = A[by * block_M + i, ko * block_K + k]

                # Load B into shared memory
                for j, k in T.Parallel(block_N, block_K):
                    B_shared[j, k] = B[bx * block_N + j, ko * block_K + k]

                for ki in T.serial(0, (block_K // micro_size_k)):
                    # Load A into fragment
                    mmac_emitter.ldmatrix_a(A_local, A_shared, ki)

                    # Load B into fragment
                    mmac_emitter.ldmatrix_b(B_local, B_shared, ki)

                    # Perform Matrix Multiplication
                    mmac_emitter.mmac(A_local, B_local, C_local)

            # Perform STMatrix
            mmac_emitter.stmatrix(C_local, C_shared)

            # Store shared into global
            for i, j in T.Parallel(block_M, block_N):
                C[by * block_M + i, bx * block_N + j] = C_shared[
                    j // micro_size_y,
                    i // micro_size_x,
                    i % micro_size_x,
                    j % micro_size_y,
                ]

    return gemm_intrinsics


def ref_program(A, B):
    return A @ B.T


def main():
    M, N, K = 16384, 16384, 16384
    in_dtype, out_dtype, accum_dtype = "float16", "float16", "float32"
    kernel = tl_matmul(M, N, K, in_dtype, out_dtype, accum_dtype)
    src_code = kernel.get_kernel_source()
    # src_code is the generated kernel source
    assert src_code is not None

    profiler = kernel.get_profiler()

    latency = profiler.do_bench(profiler.func, warmup=25)

    print(latency)
    print(kernel.get_kernel_source())
    # Ensure that the latency is not None
    assert latency is not None

    profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2)


if __name__ == "__main__":
    main()
161 changes: 161 additions & 0 deletions examples/minference/ops/vertical_slash_index.hip
@@ -0,0 +1,161 @@
// !!! This is a file automatically generated by hipify!!!
#include <ATen/dtk_macros.h>
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.

#include <assert.h>

#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <pybind11/numpy.h>
#include <torch/extension.h>

#include <hip/hip_runtime.h>

__device__ void save_blocks(int* block_offset, int range_start, int range_end, int block_size, int& block_count) {
    for (int idx = range_start; idx < range_end; idx += block_size) {
        block_offset[block_count++] = idx;
    }
}

__global__ void convert_vertical_slash_indexes_kernel(
    const int* seqlens,          // [BATCH, ]
    const int* vertical_indexes, // [BATCH, N_HEADS, NNZ_V]
    const int* slash_indexes,    // [BATCH, N_HEADS, NNZ_S]
    int* block_count,            // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)]
    int* block_offset,           // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_S]
    int* column_count,           // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)]
    int* column_index,           // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_V]
    int N_HEADS,
    int N_ROWS,
    int BLOCK_SIZE_M,
    int BLOCK_SIZE_N,
    int NNZ_V,
    int NNZ_S
) {
    const int batch_idx = blockIdx.y;
    const int head_idx = blockIdx.x;
    const int group_idx = blockIdx.z;

    int seqlen = seqlens[batch_idx];
    int block_idx_m = group_idx * blockDim.x + threadIdx.x;
    int start_m = block_idx_m * BLOCK_SIZE_M;
    if (start_m >= seqlen) {
        return;
    }
    int end_m = start_m + BLOCK_SIZE_M;
    vertical_indexes += (batch_idx * N_HEADS + head_idx) * NNZ_V;
    slash_indexes += (batch_idx * N_HEADS + head_idx) * NNZ_S;
    int row_offset = (batch_idx * N_HEADS + head_idx) * N_ROWS + block_idx_m;
    block_count += row_offset;
    block_offset += row_offset * NNZ_S;
    column_count += row_offset;
    column_index += row_offset * NNZ_V;

Comment on lines +40 to +54
⚠️ Potential issue | 🔴 Critical

Guard against row index overflow.

Threads with block_idx_m ≥ N_ROWS can compute a valid start_m < seqlen (if seqlen > context), causing OOB on row_offset. Add an explicit guard.

Apply:

   int seqlen = seqlens[batch_idx];
   int block_idx_m = group_idx * blockDim.x + threadIdx.x;
+  if (block_idx_m >= N_ROWS) {
+      return;
+  }
   int start_m = block_idx_m * BLOCK_SIZE_M;
   if (start_m >= seqlen) {
       return;
   }
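For intuition, a quick back-of-the-envelope check in plain Python (hypothetical sizes) of how the ceildiv grid over-provisions row threads whenever N_ROWS is not a multiple of the 64-thread block:

    # Hypothetical sizes; the launcher below uses N_THREADS = 64 per block.
    N_ROWS, N_THREADS = 100, 64
    groups = (N_ROWS + N_THREADS - 1) // N_THREADS   # ceildiv -> 2 groups
    launched = groups * N_THREADS                    # 128 row-threads launched
    print(launched - N_ROWS)                         # 28 threads with block_idx_m >= N_ROWS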

    int tmp_col_cnt = 0, tmp_blk_cnt = 0;
    int s = 0, v = 0;
    int v_idx = vertical_indexes[v++];
    int s_idx = slash_indexes[s++];
    while (s_idx >= end_m) {
        s_idx = slash_indexes[s++];
    }
    s_idx = max(end_m - s_idx, BLOCK_SIZE_M);
Comment on lines +55 to +62
⚠️ Potential issue | 🔴 Critical

Fix OOB reads when NNZ_S/NNZ_V are zero and bound the pre-scan.

Accessing vertical_indexes[v++] and slash_indexes[s++] without checking NNZ_* risks OOB. The pre-loop while also lacks a bound on s.

Apply:

-    int tmp_col_cnt = 0, tmp_blk_cnt = 0;
-    int s = 0, v = 0;
-    int v_idx = vertical_indexes[v++];
-    int s_idx = slash_indexes[s++];
-    while (s_idx >= end_m) {
-        s_idx = slash_indexes[s++];
-    }
-    s_idx = max(end_m - s_idx, BLOCK_SIZE_M);
+    int tmp_col_cnt = 0, tmp_blk_cnt = 0;
+    int s = 0, v = 0;
+    // Safe init of v_idx
+    int v_idx = (NNZ_V > 0) ? vertical_indexes[v++] : (end_m + BLOCK_SIZE_M);
+    // Handle NNZ_S == 0 early
+    if (NNZ_S == 0) {
+        block_count[0] = 0;
+        column_count[0] = 0;
+        return;
+    }
+    int s_idx = slash_indexes[s++];
+    while (s < NNZ_S && s_idx >= end_m) {
+        s_idx = slash_indexes[s++];
+    }
+    if (s_idx >= end_m) {
+        // No slash indices relevant for this row
+        block_count[0] = 0;
+        column_count[0] = 0;
+        return;
+    }
+    s_idx = max(end_m - s_idx, BLOCK_SIZE_M);

    int range_start = s_idx - BLOCK_SIZE_M, range_end = s_idx;
    while (1) {
        if (v_idx < range_end) {
            if (v_idx < range_start) {
                column_index[tmp_col_cnt++] = v_idx;
            }
            if (v < NNZ_V) {
                v_idx = vertical_indexes[v++];
            } else {
                v_idx = end_m + BLOCK_SIZE_M;
            }
        } else {
            if (s < NNZ_S) {
                s_idx = max(end_m - slash_indexes[s++], BLOCK_SIZE_M);
            } else {
                save_blocks(block_offset, range_start, range_end, BLOCK_SIZE_N, tmp_blk_cnt);
Comment on lines +67 to +78
⚠️ Potential issue | 🔴 Critical

Cap per-row writes to avoid overflow of column_index and block_offset.

tmp_col_cnt and tmp_blk_cnt can exceed NNZ_V/NNZ_S; cap writes and bound save_blocks.

Apply:

-                column_index[tmp_col_cnt++] = v_idx;
+                if (tmp_col_cnt < NNZ_V) {
+                    column_index[tmp_col_cnt++] = v_idx;
+                }
@@
-                save_blocks(block_offset, range_start, range_end, BLOCK_SIZE_N, tmp_blk_cnt);
+                save_blocks(block_offset, range_start, range_end, BLOCK_SIZE_N, tmp_blk_cnt, NNZ_S);
@@
-                save_blocks(block_offset, range_start, range_end, BLOCK_SIZE_N, tmp_blk_cnt);
+                save_blocks(block_offset, range_start, range_end, BLOCK_SIZE_N, tmp_blk_cnt, NNZ_S);

And update save_blocks to accept a max:

-__device__ void save_blocks(int* block_offset, int range_start, int range_end, int block_size, int& block_count) {
-    for (int idx = range_start; idx < range_end; idx += block_size) {
-        block_offset[block_count++] = idx;
-    }
-}
+__device__ __forceinline__ void save_blocks(int* block_offset,
+                                            int range_start,
+                                            int range_end,
+                                            int block_size,
+                                            int& blk_cnt,
+                                            int max_blocks) {
+    for (int idx = range_start; idx < range_end && blk_cnt < max_blocks; idx += block_size) {
+        block_offset[blk_cnt++] = idx;
+    }
+}


                break;
            }
            if (s_idx > range_end + BLOCK_SIZE_M) {
                save_blocks(block_offset, range_start, range_end, BLOCK_SIZE_N, tmp_blk_cnt);
                range_start = s_idx - BLOCK_SIZE_M;
                range_end = s_idx;
            } else if (s_idx > range_end) {
                range_end += BLOCK_SIZE_M;
            }
        }
    }

    block_count[0] = tmp_blk_cnt;
    column_count[0] = tmp_col_cnt;
}

void convert_vertical_slash_indexes_64x64(
    const int* seqlens,          // [BATCH, ]
    const int* vertical_indexes, // [BATCH, N_HEADS, NNZ_V]
    const int* slash_indexes,    // [BATCH, N_HEADS, NNZ_S]
    int* block_count,            // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)]
    int* block_offset,           // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_S]
    int* column_count,           // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)]
    int* column_index,           // [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_V]
    int BATCH_SIZE,
    int N_HEADS,
    int N_ROWS,
    int NNZ_V,
    int NNZ_S
) {
    const int BLOCK_SIZE_M = 64;
    const int BLOCK_SIZE_N = 64;
    const int N_THREADS = 64;
    const dim3 dimBlock(N_THREADS);
    const dim3 dimGrid(N_HEADS, BATCH_SIZE, (N_ROWS + N_THREADS - 1) / N_THREADS);
    hipLaunchKernelGGL((convert_vertical_slash_indexes_kernel), dim3(dimGrid), dim3(dimBlock), 0, 0,
        seqlens, vertical_indexes, slash_indexes,
        block_count, block_offset, column_count, column_index,
        N_HEADS, N_ROWS, BLOCK_SIZE_M, BLOCK_SIZE_N, NNZ_V, NNZ_S
    );
Comment on lines +114 to +118
⚠️ Potential issue | 🔴 Critical

Use PyTorch’s current stream, not stream 0; add launch error check.

Launching on stream 0 breaks PyTorch stream semantics and can race other ops. Also add a kernel launch check.

Apply:

+#include <ATen/cuda/CUDAContext.h>  // works for CUDA and ROCm builds
@@
-   hipLaunchKernelGGL(( convert_vertical_slash_indexes_kernel), dim3(dimGrid), dim3(dimBlock), 0, 0, 
+   auto stream = at::cuda::getCurrentCUDAStream();
+   hipLaunchKernelGGL(( convert_vertical_slash_indexes_kernel), dim3(dimGrid), dim3(dimBlock), 0, stream.stream(),
         seqlens, vertical_indexes, slash_indexes,
         block_count, block_offset, column_count, column_index,
         N_HEADS, N_ROWS, BLOCK_SIZE_M, BLOCK_SIZE_N, NNZ_V, NNZ_S
     );
+   AT_CUDA_CHECK(hipGetLastError());


}

std::vector<at::Tensor> convert_vertical_slash_indexes(
    torch::Tensor seqlens,          // [BATCH, ]
    torch::Tensor vertical_indexes, // [BATCH, N_HEADS, NNZ_V]
    torch::Tensor slash_indexes,    // [BATCH, N_HEADS, NNZ_S]
    int context_size,
    int block_size_M,
    int block_size_N
) {
    assert(block_size_M == 64);
    assert(block_size_N == 64);

    hipSetDevice(seqlens.get_device());

    int batch_size = slash_indexes.size(0);
    int num_heads = slash_indexes.size(1);
    int nnz_slash = slash_indexes.size(2);
    int nnz_vertical = vertical_indexes.size(2);
    int num_rows = (context_size + block_size_M - 1) / block_size_M;

    torch::Tensor block_count = torch::zeros({batch_size, num_heads, num_rows}, seqlens.options());
    torch::Tensor block_offset = torch::zeros({batch_size, num_heads, num_rows, nnz_slash}, seqlens.options());
    torch::Tensor column_count = torch::zeros({batch_size, num_heads, num_rows}, seqlens.options());
    torch::Tensor column_index = torch::zeros({batch_size, num_heads, num_rows, nnz_vertical}, seqlens.options());

    convert_vertical_slash_indexes_64x64(
        seqlens.data_ptr<int>(),
        vertical_indexes.data_ptr<int>(),
        slash_indexes.data_ptr<int>(),
        block_count.data_ptr<int>(),
        block_offset.data_ptr<int>(),
        column_count.data_ptr<int>(),
        column_index.data_ptr<int>(),
        batch_size,
        num_heads,
        num_rows,
        nnz_vertical,
        nnz_slash
    );
Comment on lines +121 to +158
⚠️ Potential issue | 🟠 Major

Validate inputs, use DeviceGuard, enforce dtypes/contiguity; replace assert with TORCH_CHECK.

Ensure tensors live on the same device/stream, are int32, and contiguous. Avoid hipSetDevice; use DeviceGuard. Assert() is compiled out in Release.

Apply:

+#include <c10/core/DeviceGuard.h>
@@
-    assert(block_size_M == 64);
-    assert(block_size_N == 64);
+    TORCH_CHECK(block_size_M == 64, "block_size_M must be 64");
+    TORCH_CHECK(block_size_N == 64, "block_size_N must be 64");
@@
-    hipSetDevice(seqlens.get_device());
+    c10::DeviceGuard guard(seqlens.device());
+    TORCH_CHECK(seqlens.is_cuda(), "seqlens must be on CUDA/HIP device");
+    TORCH_CHECK(vertical_indexes.is_cuda() && slash_indexes.is_cuda(), "Inputs must be CUDA/HIP tensors");
+    TORCH_CHECK(vertical_indexes.device() == seqlens.device() && slash_indexes.device() == seqlens.device(),
+                "All inputs must be on the same device");
+    TORCH_CHECK(seqlens.scalar_type() == at::kInt, "seqlens must be int32");
+    TORCH_CHECK(vertical_indexes.scalar_type() == at::kInt && slash_indexes.scalar_type() == at::kInt,
+                "vertical_indexes/slash_indexes must be int32");
+
+    seqlens = seqlens.contiguous();
+    vertical_indexes = vertical_indexes.contiguous();
+    slash_indexes = slash_indexes.contiguous();
@@
-    int batch_size = slash_indexes.size(0);
-    int num_heads = slash_indexes.size(1);
-    int nnz_slash = slash_indexes.size(2);
-    int nnz_vertical = vertical_indexes.size(2);
+    int batch_size = static_cast<int>(slash_indexes.size(0));
+    int num_heads = static_cast<int>(slash_indexes.size(1));
+    int nnz_slash = static_cast<int>(slash_indexes.size(2));
+    int nnz_vertical = static_cast<int>(vertical_indexes.size(2));


    return { block_count, block_offset, column_count, column_index };
}
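For host-side validation, the per-row device loop above can be re-expressed in plain Python and compared against the kernel's output for a single (batch, head, row) cell. This is a rough line-by-line translation, so it deliberately inherits the NNZ_V/NNZ_S == 0 and unbounded pre-scan issues flagged in the review; the function name and argument order are illustrative only.

    # Rough host-side reference for one (batch, head, row) cell; `vertical` and
    # `slash` are 1-D lists of int indices for that (batch, head), and start_m
    # is block_idx_m * block_size_m.
    def convert_row_reference(vertical, slash, start_m, block_size_m, block_size_n):
        end_m = start_m + block_size_m
        blocks, columns = [], []

        def save_blocks(range_start, range_end):
            blocks.extend(range(range_start, range_end, block_size_n))

        s, v = 0, 0
        v_idx = vertical[v]; v += 1          # same unguarded first read as the kernel
        s_idx = slash[s]; s += 1
        while s_idx >= end_m:                # same unbounded pre-scan as the kernel
            s_idx = slash[s]; s += 1
        s_idx = max(end_m - s_idx, block_size_m)

        range_start, range_end = s_idx - block_size_m, s_idx
        while True:
            if v_idx < range_end:
                if v_idx < range_start:
                    columns.append(v_idx)
                if v < len(vertical):
                    v_idx = vertical[v]; v += 1
                else:
                    v_idx = end_m + block_size_m
            else:
                if s < len(slash):
                    s_idx = max(end_m - slash[s], block_size_m); s += 1
                else:
                    save_blocks(range_start, range_end)
                    break
                if s_idx > range_end + block_size_m:
                    save_blocks(range_start, range_end)
                    range_start, range_end = s_idx - block_size_m, s_idx
                elif s_idx > range_end:
                    range_end += block_size_m

        return blocks, columns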
19 changes: 19 additions & 0 deletions src/layout/gemm_layouts.cc
@@ -156,6 +156,23 @@ Fragment makeGemmSparseFragmentC(const int block_m, const int block_n,
  return block_layout;
}

Fragment makeGemmFragmentCDCU(const int block_m, const int block_n,
                              const int warp_m, const int warp_n,
                              const int element_size) {
  if (element_size == 64)
    LOG(FATAL) << "Not supported";
  ICHECK(block_m % warp_m == 0);
  ICHECK(block_n % warp_n == 0);
  ICHECK(warp_m % 16 == 0) << "warp_m=" << warp_m;
  ICHECK(warp_n % 16 == 0) << "warp_n=" << warp_n;
  auto base_layout = makeGemmFragmentC16x16CDNA()->Repeat({1, 1}, false);
  auto warp_layout =
      base_layout->Repeat({warp_m / 16, warp_n / 16}, false, false);
  auto block_layout =
      warp_layout->Repeat({block_m / warp_m, block_n / warp_n}, true, false);
  return block_layout;
}

Fragment makeGemmFragmentCCDNA(const int block_m, const int block_n,
                               const int warp_m, const int warp_n,
                               const int element_size) {
@@ -730,6 +747,8 @@ Layout makeGemmABLayout(int mat_stride, int mat_continuous, int continuity,
  if (!k_inner && element_size == 8) // int8 KxN
    return makeGemmABLayoutPadded(mat_stride, mat_continuous, element_size);
  else if (mat_continuous % (vector_size * 8) == 0)
    // return makeHalfBankSwizzleLayout(mat_stride, mat_continuous,
    // element_size);
    return makeFullBankSwizzleLayout(mat_stride, mat_continuous, element_size);
  else if (mat_continuous % (vector_size * 4) == 0)
    return makeHalfBankSwizzleLayout(mat_stride, mat_continuous, element_size);
3 changes: 3 additions & 0 deletions src/layout/layout.h
@@ -207,6 +207,9 @@ Fragment makeGemmSparseFragmentC(const int block_m, const int block_n,
Fragment makeGemmFragmentCCDNA(const int block_m, const int block_n,
                               const int warp_m, const int warp_n,
                               const int element_size);
Fragment makeGemmFragmentCDCU(const int block_m, const int block_n,
                              const int warp_m, const int warp_n,
                              const int element_size);
Fragment makeGemmFragmentCHopper(const int block_m, const int block_n,
                                 const int warp_m, const int warp_n,
                                 const int element_size);