Rachmanino (Collaborator) commented Feb 2, 2026

  • 1-SM GEMM reaching ~1430 TFLOPS on B200
  • Handle mbar as a BufferLoad so the barrier index is not dropped

Thanks @Hamerlate for providing the dev machine.

Summary by CodeRabbit

  • New Features

    • Added a new high-performance GEMM example with validation and end-to-end benchmarking (prints kernel source, latency, TFLOPS).
  • Refactoring

    • Streamlined internal memory-barrier representation used by GEMM lowering.
    • TMA bulk-copy emission now uses a thread-election mechanism for more reliable predicates.
    • Added utility for constructing buffer access pointers.
  • API

    • Broadened intrinsic input handling to accept buffer or raw-pointer forms for barrier-related calls.
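
For readers who want to relate the reported latency to the TFLOPS figure, here is a minimal sketch of the usual conversion, assuming the standard 2·M·N·K operation count for a dense GEMM. The function name `gemm_tflops` and the 6.15 ms latency are hypothetical, chosen only for illustration; the 16384 shape is the test size mentioned later in this review.

```python
def gemm_tflops(m: int, n: int, k: int, latency_ms: float) -> float:
    """Return achieved TFLOPS for an (m x k) @ (k x n) GEMM."""
    # One multiply and one add per inner-product term.
    flops = 2 * m * n * k
    seconds = latency_ms * 1e-3
    return flops / seconds / 1e12


# Hypothetical latency, for illustration only (not a measured value).
example = gemm_tflops(16384, 16384, 16384, latency_ms=6.15)
print(f"{example:.0f} TFLOPS")
```

With these assumed inputs the result lands close to the ~1430 TFLOPS quoted in the description; a real benchmark would substitute the measured kernel latency.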

coderabbitai bot (Contributor) commented Feb 2, 2026

Walkthrough

Adds a new SM100 GEMM example kernel and refactors mbar handling from region-based to BufferLoad-based across GEMM lowering and TileLang builtins; introduces MakeAccessPtrFromBufferLoad, changes TMA bulk-copy gating to use tl_shuffle_elect, and updates TileLang/Gemm APIs to accept/return BufferLoad or None.

Changes

| Cohort / File(s) | Summary |
| --- | --- |
| New GEMM example: examples/gemm_sm100/gemm_tcgen5mma_ws.py | Adds a 1‑SM, non-persistent GEMM kernel with two-warp staging (load vs compute), shared/tile/local memory orchestration, jit-annotated gemm() and main() with correctness and perf benchmarks. |
| GEMM node & lowering: src/op/gemm.h, src/op/gemm.cc, src/op/gemm_py.h, src/op/gemm_py.cc | Replaces region-based mbar representation with Optional<tir::BufferLoad> mbar_; removes mbarRegion_; updates constructors, reflection (exposes cCoords_), and TCGEN5MMA lowering to build access ptrs from BufferLoad. |
| Access pointer utility: src/op/utils.h, src/op/utils.cc | Adds MakeAccessPtrFromBufferLoad(...) to compute linear offsets and build tvm_access_ptr calls from BufferLoad indices (handles Ramp/scalars), paralleling region-based API. |
| TMA bulk-copy gating: src/op/copy.cc | Changes final TMA bulk-copy predicate to use tl_shuffle_elect(...) instead of a per-thread equality check against thread_bounds->min. |
| TileLang builtin & GEMM python API updates: tilelang/language/builtin.py, tilelang/language/gemm_op.py, tilelang/tileop/gemm/gemm_base.py, tilelang/tileop/gemm/gemm_tcgen05.py | Broadened tcgen05_mma_arrive signature to accept Buffer/BufferLoad/PrimExpr; removed zero-valued mbar fallback; GemmBase.mbar now returns BufferLoad |

Sequence Diagram(s)

mermaid
sequenceDiagram
participant Host as "Host"
participant Kernel as "Kernel"
participant WarpLoad as "Warp0 (load)"
participant WarpMMA as "Warp1 (compute)"
participant Shared as "SharedMemory"
participant C_tmem as "C_tmem (tile memory)"
participant GlobalC as "Global C"

Host->>Kernel: launch gemm kernel
Kernel->>WarpLoad: assign threads (tx<32)
Kernel->>WarpMMA: assign threads (32<=tx<64)
WarpLoad->>Shared: load A/B tiles (per-stage)
WarpLoad->>WarpMMA: signal data ready (parity/stage)
WarpMMA->>C_tmem: perform tcgen5 MMA accumulations
WarpMMA->>Shared: write partial results (per-stage)
WarpLoad-->>WarpMMA: barrier sync
WarpMMA->>GlobalC: move from C_tmem -> local -> shared -> global C
Kernel->>Host: return computed C
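
The "parity/stage" signalling in the diagram corresponds to the index arithmetic in the example's load warp, which selects a barrier with `k % num_stages` and waits on parity `((k // num_stages) & 1) ^ 1`. The sketch below only tabulates those two expressions in plain Python; the helper name `stage_and_wait_parity` is hypothetical, and it makes no claim about hardware barrier semantics beyond reproducing the arithmetic.

```python
def stage_and_wait_parity(k: int, num_stages: int) -> tuple:
    """Reproduce the load warp's barrier index and wait parity for iteration k."""
    stage = k % num_stages            # which pipeline slot this iteration reuses
    phase = (k // num_stages) & 1     # flips each time the ring wraps around
    wait_parity = phase ^ 1           # value passed to the wait on `consumed`
    return stage, phase, wait_parity


# Tabulate the first two trips around a 4-stage ring (stage count is assumed).
for k in range(8):
    print(k, stage_and_wait_parity(k, num_stages=4))
```

The stage index cycles through the ring while the parity stays constant within one trip and flips on the next, which is what lets a fixed set of `num_stages` barriers be reused across all K iterations.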

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Suggested reviewers

  • LeiWang1999

Poem

🐰 I hopped through buffers, tiny and bright,
From regions to loads in the soft SM light.
Warps loaded, warps counted, barriers sighed—
Tiles danced together, results verified.
A carrot of code — fast, neat, and right.

🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)

| Check name | Status | Explanation | Resolution |
| --- | --- | --- | --- |
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 37.50% which is insufficient. The required threshold is 80.00%. | Write docstrings for the functions missing them to satisfy the coverage threshold. |

✅ Passed checks (2 passed)

| Check name | Status | Explanation |
| --- | --- | --- |
| Description Check | ✅ Passed | Check skipped - CodeRabbit’s high-level summary is enabled. |
| Title check | ✅ Passed | The title accurately captures the main changes: adding a 1SM GEMM example for Blackwell and refactoring mbar handling from Buffer/BufferRegion to BufferLoad representation. |


github-actions bot commented Feb 2, 2026

👋 Hi! Thank you for contributing to the TileLang project.

Please remember to run pre-commit run --all-files in the root directory of the project to ensure your changes are properly linted and formatted. This will help ensure your contribution passes the format check.

We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀

Rachmanino marked this pull request as ready for review February 3, 2026 07:39

coderabbitai bot (Contributor) left a comment

Actionable comments posted: 3

🧹 Nitpick comments (1)
src/op/gemm.cc (1)

84-88: Consider hard-failing on unexpected mbar argument type.

When args.size() > 16, a non-BufferLoad value is silently dropped. That can mask frontend mismatches. Prefer an ICHECK to fail early when the caller passes the wrong type.

Suggested tightening
 if (args.size() > 16) {
-    if (const auto *load = args[16].as<BufferLoadNode>()) {
-      node->mbar_ = Downcast<BufferLoad>(args[16]);
-    } else {
-      node->mbar_ = std::nullopt;
-    }
+    ICHECK(args[16].as<BufferLoadNode>())
+        << "mbar must be provided as BufferLoad when present";
+    node->mbar_ = Downcast<BufferLoad>(args[16]);
 }

src/op/gemm_py.cc
Comment on lines 83 to 86
   if (args.size() > 16) {
     if (const auto *load = args[16].as<BufferLoadNode>()) {
-      node->mbarRegion_ =
-          NormalizeToBufferRegion(Downcast<BufferLoad>(args[16]));
-      node->mbar_ = node->mbarRegion_->buffer;
+      node->mbar_ = Downcast<BufferLoad>(args[16]);
     }

⚠️ Potential issue | 🟡 Minor

Fail fast when mbar is present but not a BufferLoad.

Silently ignoring non-BufferLoad inputs can mask call-site errors and lead to a missing barrier later; an explicit check makes failures clearer.

💡 Suggested fix
-  if (args.size() > 16) {
-    if (const auto *load = args[16].as<BufferLoadNode>()) {
-      node->mbar_ = Downcast<BufferLoad>(args[16]);
-    }
-  }
+  if (args.size() > 16) {
+    ICHECK(args[16].as<BufferLoadNode>())
+        << "mbar must be provided as a BufferLoad when present";
+    node->mbar_ = Downcast<BufferLoad>(args[16]);
+  }
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
-  if (args.size() > 16) {
-    if (const auto *load = args[16].as<BufferLoadNode>()) {
-      node->mbar_ = Downcast<BufferLoad>(args[16]);
-    }
+  if (args.size() > 16) {
+    ICHECK(args[16].as<BufferLoadNode>())
+        << "mbar must be provided as a BufferLoad when present";
+    node->mbar_ = Downcast<BufferLoad>(args[16]);
+  }
🤖 Prompt for AI Agents
In `@src/op/gemm_py.cc` around lines 83 - 86, When handling the optional mbar
argument in the arg-parsing block, don't silently ignore non-BufferLoad values:
if args.size() > 16 then attempt the BufferLoadNode cast as currently done
(check for BufferLoadNode and set node->mbar_ via Downcast<BufferLoad>), but add
an else branch that fails fast (throw or LOG(FATAL)/CHECK) reporting that arg 16
was expected to be a BufferLoad and include the actual argument's type/name (use
whatever runtime type introspection is available on the Expr/Node to include in
the message) so callers get a clear error instead of silently skipping mbar_.
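
To make the difference between silently skipping and failing fast concrete, here is a small language-neutral sketch in Python. It is not the C++ constructor from the PR: `BufferLoad` is a hypothetical stand-in class, and `parse_mbar_lenient` / `parse_mbar_strict` are illustrative names for the current and suggested behaviours.

```python
MBAR_ARG_INDEX = 16  # position of the optional mbar argument, per the review


class BufferLoad:
    """Hypothetical stand-in for tir::BufferLoad."""

    def __init__(self, name: str):
        self.name = name


def parse_mbar_lenient(args):
    """Current behaviour: a wrong-typed argument is silently ignored."""
    if len(args) > MBAR_ARG_INDEX and isinstance(args[MBAR_ARG_INDEX], BufferLoad):
        return args[MBAR_ARG_INDEX]
    return None


def parse_mbar_strict(args):
    """Suggested behaviour: absent is fine, present-but-wrong is an error."""
    if len(args) <= MBAR_ARG_INDEX:
        return None
    mbar = args[MBAR_ARG_INDEX]
    if not isinstance(mbar, BufferLoad):
        raise TypeError(
            f"arg {MBAR_ARG_INDEX} must be a BufferLoad, got {type(mbar).__name__}"
        )
    return mbar
```

With the lenient version a caller that passes the wrong object gets `None` back and only discovers the missing barrier later; the strict version reports the offending type at the call site.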

src/op/gemm_py.h
   PrimExpr clearAccum_ = const_false();
-  BufferRegion mbarRegion_;
-  tir::Buffer mbar_; // mbar is optional, only used for TCGEN5MMA
+  tir::BufferLoad mbar_; // mbar is optional, only used for TCGEN5MMA

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
# Verify mbar_ usage for null-safety and initialization paths
rg -n -C3 'mbar_' src/op/gemm_py.cc src/op/gemm_py.h

Repository: tile-ai/tilelang

Length of output: 1338


🏁 Script executed:

# Also check the GemmNode definition to see the original type
rg -n -C3 'struct GemmNode\|class GemmNode' src/op/gemm_py.h src/op/gemm_py.cc

Repository: tile-ai/tilelang

Length of output: 42


🏁 Script executed:

# Look for mbar_ initialization in constructors or Create methods
rg -n 'mbar_\s*=' src/op/gemm_py.cc src/op/gemm_py.h

Repository: tile-ai/tilelang

Length of output: 132


🏁 Script executed:

# Check if .defined() or other null checks are used with mbar_
rg -n 'mbar_.*defined|if.*mbar_' src/op/gemm_py.cc src/op/gemm_py.h

Repository: tile-ai/tilelang

Length of output: 42


🏁 Script executed:

# Search for how mbar_ is used/read (not just assigned)
rg -n 'mbar_' src/op/gemm_py.cc src/op/gemm_py.h -A 5 -B 2

Repository: tile-ai/tilelang

Length of output: 1511


🏁 Script executed:

# Look for uses of mbar in generated IR or kernel code
rg -n 'mbar' src/op/gemm_py.cc src/op/gemm_py.h

Repository: tile-ai/tilelang

Length of output: 286


🏁 Script executed:

# Check if there's a TVM BufferLoad definition to understand default initialization
fd -t f '\.h$' -x grep -l 'class BufferLoad\|struct BufferLoad' {} \; | head -5

Repository: tile-ai/tilelang

Length of output: 42


🏁 Script executed:

# Search for mbar usage in gemm_py.cc to find where it's read/accessed
rg -n 'mbar' src/op/gemm_py.cc | head -20

Repository: tile-ai/tilelang

Length of output: 114


🏁 Script executed:

# Search for GemmPyNode usage in other files to see how mbar is accessed
rg -n 'GemmPyNode\|mbar' src/ --include='*.cc' --include='*.h' | grep -v 'gemm_py' | head -20

Repository: tile-ai/tilelang

Length of output: 501


🏁 Script executed:

# Check the entire gemm_py.cc to see function context where mbar is used
wc -l src/op/gemm_py.cc

Repository: tile-ai/tilelang

Length of output: 81


🏁 Script executed:

cat -n src/op/gemm_py.cc

Repository: tile-ai/tilelang

Length of output: 15164


Resolve "optional" mismatch for mbar_ (type-contract inconsistency).

mbar_ is declared as non-optional tir::BufferLoad, but the comment marks it as optional and initialization is conditional (line 85: only assigned if args.size() > 16 and args[16] is a BufferLoadNode). This creates a type-contract mismatch: the field can remain uninitialized while the type signature suggests it's always present.

Suggestion: Make this std::optional<tir::BufferLoad> to match the semantic intent, or guarantee initialization and remove the optional comment. Verify Python-side code handles undefined mbar_ safely.

🤖 Prompt for AI Agents
In `@src/op/gemm_py.h` at line 32, mbar_ is declared as a non-optional
tir::BufferLoad but is only conditionally assigned (when args.size() > 16 and
args[16] is a BufferLoadNode), causing a type-contract mismatch; change the
field declaration from tir::BufferLoad mbar_ to std::optional<tir::BufferLoad>
mbar_, update the parser/initializer (where args is inspected) to emplace/assign
mbar_ only in the conditional branch, and adjust any uses of mbar_ (check
has_value() or use value_or) so code and the Python bindings safely handle the
absent case; alternatively, if you prefer non-optional, ensure mbar_ is
unconditionally initialized in the same constructor code path and remove the
"optional" comment.

src/op/utils.cc
Comment on lines +95 to +122
PrimExpr MakeAccessPtrFromBufferLoad(const BufferLoad &load, int rw_mask) {
  Buffer buf = load->buffer;
  int ndim = static_cast<int>(buf->shape.size());

  // Compute offset using row-major layout (iterate in reverse)
  PrimExpr offset = make_const(DataType::Int(32), 0);
  PrimExpr stride = make_const(DataType::Int(32), 1);

  for (int i = ndim - 1; i >= 0; --i) {
    const PrimExpr &index = load->indices[i];
    if (const auto *ramp = index.as<RampNode>()) {
      // For Ramp, use the base
      offset = offset + ramp->base * stride;
    } else {
      // For scalar index (IntImm or other PrimExpr)
      offset = offset + index * stride;
    }
    stride = stride * buf->shape[i];
  }

  // Extent is 1 element for a single BufferLoad access
  PrimExpr extent = make_const(DataType::Int(32), 1);

  // Build access_ptr
  PrimExpr ptype = tir::TypeAnnotation(buf->dtype);
  Array<PrimExpr> acc_args{ptype, buf->data, offset, extent,
                           IntImm(DataType::Int(32), rw_mask)};
  return Call(DataType::Handle(), builtin::tvm_access_ptr(), acc_args);

⚠️ Potential issue | 🟠 Major

Avoid int32 offset/stride to prevent overflow on large buffers.

DataType::Int(32) can overflow for large shapes and diverges from MakeAccessPtrFromRegion. Use the buffer index dtype for offset/stride/extent.

💡 Suggested fix
-  // Compute offset using row-major layout (iterate in reverse)
-  PrimExpr offset = make_const(DataType::Int(32), 0);
-  PrimExpr stride = make_const(DataType::Int(32), 1);
+  // Compute offset using row-major layout (iterate in reverse)
+  DataType idx_dtype = buf->shape[0].dtype();
+  PrimExpr offset = make_const(idx_dtype, 0);
+  PrimExpr stride = make_const(idx_dtype, 1);
@@
-  // Extent is 1 element for a single BufferLoad access
-  PrimExpr extent = make_const(DataType::Int(32), 1);
+  // Extent is 1 element for a single BufferLoad access
+  PrimExpr extent = make_const(idx_dtype, 1);
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
-PrimExpr MakeAccessPtrFromBufferLoad(const BufferLoad &load, int rw_mask) {
-  Buffer buf = load->buffer;
-  int ndim = static_cast<int>(buf->shape.size());
-  // Compute offset using row-major layout (iterate in reverse)
-  PrimExpr offset = make_const(DataType::Int(32), 0);
-  PrimExpr stride = make_const(DataType::Int(32), 1);
-  for (int i = ndim - 1; i >= 0; --i) {
-    const PrimExpr &index = load->indices[i];
-    if (const auto *ramp = index.as<RampNode>()) {
-      // For Ramp, use the base
-      offset = offset + ramp->base * stride;
-    } else {
-      // For scalar index (IntImm or other PrimExpr)
-      offset = offset + index * stride;
-    }
-    stride = stride * buf->shape[i];
-  }
-  // Extent is 1 element for a single BufferLoad access
-  PrimExpr extent = make_const(DataType::Int(32), 1);
-  // Build access_ptr
-  PrimExpr ptype = tir::TypeAnnotation(buf->dtype);
-  Array<PrimExpr> acc_args{ptype, buf->data, offset, extent,
-                           IntImm(DataType::Int(32), rw_mask)};
-  return Call(DataType::Handle(), builtin::tvm_access_ptr(), acc_args);
+PrimExpr MakeAccessPtrFromBufferLoad(const BufferLoad &load, int rw_mask) {
+  Buffer buf = load->buffer;
+  int ndim = static_cast<int>(buf->shape.size());
+  // Compute offset using row-major layout (iterate in reverse)
+  DataType idx_dtype = buf->shape[0].dtype();
+  PrimExpr offset = make_const(idx_dtype, 0);
+  PrimExpr stride = make_const(idx_dtype, 1);
+  for (int i = ndim - 1; i >= 0; --i) {
+    const PrimExpr &index = load->indices[i];
+    if (const auto *ramp = index.as<RampNode>()) {
+      // For Ramp, use the base
+      offset = offset + ramp->base * stride;
+    } else {
+      // For scalar index (IntImm or other PrimExpr)
+      offset = offset + index * stride;
+    }
+    stride = stride * buf->shape[i];
+  }
+  // Extent is 1 element for a single BufferLoad access
+  PrimExpr extent = make_const(idx_dtype, 1);
+  // Build access_ptr
+  PrimExpr ptype = tir::TypeAnnotation(buf->dtype);
+  Array<PrimExpr> acc_args{ptype, buf->data, offset, extent,
+                           IntImm(DataType::Int(32), rw_mask)};
+  return Call(DataType::Handle(), builtin::tvm_access_ptr(), acc_args);
+}
🤖 Prompt for AI Agents
In `@src/op/utils.cc` around lines 95 - 122, The function
MakeAccessPtrFromBufferLoad uses hard-coded DataType::Int(32) for offset, stride
and extent, which can overflow for large buffers; replace those
make_const(DataType::Int(32), ...) calls with the buffer's index dtype, taken
from the shape as in the suggested fix above (DataType idx_dtype =
buf->shape[0].dtype()): initialize offset and stride with make_const(idx_dtype,
0/1), keep the offset/stride arithmetic in that dtype, and set extent using
make_const(idx_dtype, 1). The rw_mask argument is a small flag value and stays
IntImm(DataType::Int(32), rw_mask), as in the suggested fix; update references
inside MakeAccessPtrFromBufferLoad (offset, stride, extent) accordingly.
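
The row-major offset computation and the overflow concern can be illustrated with a short Python sketch. It mirrors the reverse-iteration loop in the C++ above using plain integers; `linear_offset` and `fits_int32` are hypothetical helper names, the shapes are made up, and Ramp indices are not modelled.

```python
INT32_MAX = 2**31 - 1
INT32_MIN = -(2**31)


def linear_offset(indices, shape):
    """Row-major linear offset: walk dimensions from innermost to outermost."""
    offset, stride = 0, 1
    for index, extent in zip(reversed(indices), reversed(shape)):
        offset += index * stride
        stride *= extent
    return offset


def fits_int32(value: int) -> bool:
    """True if the value is representable as a signed 32-bit integer."""
    return INT32_MIN <= value <= INT32_MAX


# A small case: element (1, 2) of a 3 x 4 buffer sits at 1 * 4 + 2.
print(linear_offset((1, 2), (3, 4)))

# A large (assumed) shape where the offset no longer fits in int32.
big = linear_offset((40000, 0), (65536, 65536))
print(big, fits_int32(big))
```

Python integers do not wrap, so the second call shows the true offset; in a 32-bit expression the same arithmetic would overflow, which is the motivation for deriving the dtype from the buffer shape.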

coderabbitai bot (Contributor) left a comment

Actionable comments posted: 2


examples/gemm_sm100/gemm_tcgen5mma_ws.py
Comment on lines +23 to +47
    k_iters = T.ceildiv(K, block_K)

    A: T.Tensor[[M, K], in_dtype]
    B: T.Tensor[[K, N], in_dtype]
    C = T.empty((M, N), out_dtype)

    with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
        A_shared = T.alloc_shared((num_stages, block_M, block_K), in_dtype)
        B_shared = T.alloc_shared((num_stages, block_K, block_N), in_dtype)
        C_tmem = T.alloc_tmem([block_M, block_N], accum_dtype)
        C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
        C_shared = T.alloc_shared((block_M, block_N), out_dtype)
        loaded = T.alloc_barrier([32] * num_stages)
        consumed = T.alloc_barrier([1] * num_stages)
        tmem_full = T.alloc_barrier([1])

        tx = T.get_thread_binding()

        T.use_swizzle(8)

        if tx < 32:  # warp 0: issue tma
            for k in T.serial(k_iters):
                T.mbarrier_wait_parity(consumed[k % num_stages], ((k // num_stages) & 1) ^ 1)
                T.copy(A[by * block_M : (by + 1) * block_M, k * block_K : (k + 1) * block_K], A_shared[k % num_stages, :, :])
                T.copy(B[k * block_K : (k + 1) * block_K, bx * block_N : (bx + 1) * block_N], B_shared[k % num_stages, :, :])

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

# Find the file mentioned in the review
fd "gemm_tcgen5mma_ws.py" --type f

Repository: tile-ai/tilelang

Length of output: 100


🏁 Script executed:

# List the repository structure to understand the codebase layout
ls -la examples/ 2>/dev/null | head -20

Repository: tile-ai/tilelang

Length of output: 1244


🏁 Script executed:

# Search for other GEMM examples to understand the pattern
fd "gemm.*\.py" examples/ --type f | head -20

Repository: tile-ai/tilelang

Length of output: 1130


🏁 Script executed:

# Examine the file mentioned in the review
cat -n examples/gemm_sm100/gemm_tcgen5mma_ws.py

Repository: tile-ai/tilelang

Length of output: 4292


🏁 Script executed:

# Also search for other gemm_sm100 files to get context
ls -la examples/gemm_sm100/

Repository: tile-ai/tilelang

Length of output: 434


🏁 Script executed:

# Search for other GEMM examples to understand how they handle non-divisible dimensions
grep -r "T.copy.*\[.*\*.*block_" examples/gemm* --include="*.py" | head -20

Repository: tile-ai/tilelang

Length of output: 2122


🏁 Script executed:

# Look at other SM100 GEMM variants to see if they have guards
cat -n examples/gemm_sm100/gemm_tcgen5mma.py

Repository: tile-ai/tilelang

Length of output: 3381


🏁 Script executed:

# Search for any guard patterns (assertions, checks) in GEMM examples
rg "assert.*%|assert.*divisible|if.*%" examples/gemm* -A 2 -B 2 | head -40

Repository: tile-ai/tilelang

Length of output: 3576


Add explicit precondition to guard against non‑divisible M/N/K dimensions.
The kernel uses T.ceildiv for grid dimensions, and line 68 copies C_shared (shape [block_M, block_N]) to C[by * block_M, bx * block_N] without bounds checking. If M/N/K are not exact multiples of their block sizes, the last threadblocks will write out of bounds. The test case passes only because M=N=K=16384 are exact multiples of their block sizes. Add an assertion to enforce this precondition, as done in other GEMM kernels in the codebase.

🛡️ Suggested guard (adjust to TileLang-supported assertion APIs)
 def gemm(
     A,
     B,
@@
 ):
     M, N, K = T.const("M, N, K")
+    # Kernel requires exact tile coverage to avoid OOB access at line 68.
+    assert (
+        M % block_M == 0
+        and N % block_N == 0
+        and K % block_K == 0
+    ), "M/N/K must be divisible by block_M/block_N/block_K"
🤖 Prompt for AI Agents
In `@examples/gemm_sm100/gemm_tcgen5mma_ws.py` around lines 23 - 47, Add an
explicit precondition that M, N, and K are exact multiples of block_M, block_N,
and block_K respectively to prevent out-of-bounds writes when the kernel (which
computes k_iters via T.ceildiv and writes C_shared -> C at by*block_M,
bx*block_N) runs; insert an assertion near the top of the function before the
T.Kernel block (where k_iters, A/B/C and shared buffers are set up) that checks
M % block_M == 0, N % block_N == 0, and K % block_K == 0 and fail early if not,
mirroring the guards used in other GEMM kernels.
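
The divisibility concern can be checked on the host with a few lines of plain Python. The sketch below enumerates the grid implied by `ceildiv` and lists the tiles whose full `block_M x block_N` extent would reach past the matrix. The helper names and the block sizes (128 x 256) are assumptions for illustration, and whether an actual out-of-bounds write occurs depends on how the copy is lowered.

```python
def ceildiv(a: int, b: int) -> int:
    """Integer ceiling division, as used for the grid dimensions."""
    return -(-a // b)


def overhanging_tiles(m: int, n: int, block_m: int, block_n: int):
    """Return (by, bx) for every tile whose full extent exceeds the M x N matrix."""
    tiles = []
    for by in range(ceildiv(m, block_m)):
        for bx in range(ceildiv(n, block_n)):
            row_end = (by + 1) * block_m
            col_end = (bx + 1) * block_n
            if row_end > m or col_end > n:
                tiles.append((by, bx))
    return tiles


# Exact multiples: no tile overhangs, so the unguarded write is safe.
print(overhanging_tiles(16384, 16384, 128, 256))

# M = 300 is not a multiple of 128: the last row of tiles overhangs.
print(overhanging_tiles(300, 256, 128, 256))
```

An assertion such as the one suggested above is equivalent to requiring that this list be empty (plus the matching condition on K) before launching the kernel.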

tilelang/language/builtin.py
Comment on lines +812 to 814
    if isinstance(mbar, (tir.Buffer, BufferLoad)):
        mbar_ptr = retrieve_ptr(mbar, access_type="rw")
    return tir.call_intrin("void", tir.op.Op.get("tl.tcgen05_mma_arrive"), mbar_ptr)

⚠️ Potential issue | 🔴 Critical

UnboundLocalError when mbar is a PrimExpr.

The function signature accepts PrimExpr, but mbar_ptr is only assigned when mbar is a Buffer or BufferLoad. When a raw PrimExpr (pointer address) is passed, the if-block is skipped and line 814 references an undefined mbar_ptr.

🐛 Proposed fix
     if isinstance(mbar, (tir.Buffer, BufferLoad)):
         mbar_ptr = retrieve_ptr(mbar, access_type="rw")
+    else:
+        mbar_ptr = mbar
     return tir.call_intrin("void", tir.op.Op.get("tl.tcgen05_mma_arrive"), mbar_ptr)
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
-    if isinstance(mbar, (tir.Buffer, BufferLoad)):
-        mbar_ptr = retrieve_ptr(mbar, access_type="rw")
-    return tir.call_intrin("void", tir.op.Op.get("tl.tcgen05_mma_arrive"), mbar_ptr)
+    if isinstance(mbar, (tir.Buffer, BufferLoad)):
+        mbar_ptr = retrieve_ptr(mbar, access_type="rw")
+    else:
+        mbar_ptr = mbar
+    return tir.call_intrin("void", tir.op.Op.get("tl.tcgen05_mma_arrive"), mbar_ptr)
🤖 Prompt for AI Agents
In `@tilelang/language/builtin.py` around lines 812 - 814, The code references
mbar_ptr in the tir.call_intrin call but only sets mbar_ptr inside the
isinstance(mbar, (tir.Buffer, BufferLoad)) branch, causing UnboundLocalError
when a raw PrimExpr is passed; fix by ensuring mbar_ptr is always defined: after
the existing if-block set mbar_ptr = mbar for the PrimExpr case (or otherwise
convert the PrimExpr to the expected pointer form) so that mbar_ptr is available
before calling tir.call_intrin("void", tir.op.Op.get("tl.tcgen05_mma_arrive"),
mbar_ptr); modify the logic around the mbar/mbar_ptr handling in the function in
tilelang/language/builtin.py to cover both tir.Buffer/BufferLoad and PrimExpr
inputs.
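
The failure mode is ordinary Python scoping and can be reproduced without TVM. In the sketch below, `Buffer`, `retrieve_ptr`, `arrive_buggy`, and `arrive_fixed` are hypothetical stand-ins that only mimic the control flow of the builtin; they are not the TileLang API.

```python
class Buffer:
    """Hypothetical stand-in for tir.Buffer / BufferLoad."""

    def __init__(self, name: str):
        self.name = name


def retrieve_ptr(buf: Buffer):
    """Hypothetical stand-in that turns a buffer into a pointer-like value."""
    return ("ptr", buf.name)


def arrive_buggy(mbar):
    # mbar_ptr is bound only on the Buffer path.
    if isinstance(mbar, Buffer):
        mbar_ptr = retrieve_ptr(mbar)
    return ("tcgen05_mma_arrive", mbar_ptr)


def arrive_fixed(mbar):
    # Every path binds mbar_ptr before it is used.
    if isinstance(mbar, Buffer):
        mbar_ptr = retrieve_ptr(mbar)
    else:
        mbar_ptr = mbar  # already a pointer-like expression
    return ("tcgen05_mma_arrive", mbar_ptr)


try:
    arrive_buggy("raw_pointer_expr")
except UnboundLocalError as err:
    print("buggy version failed:", type(err).__name__)

print(arrive_fixed("raw_pointer_expr"))
```

The buggy variant works for buffers and fails only for the raw-pointer form, which is why the problem does not show up until a caller exercises the newly accepted input type.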
