-
Notifications
You must be signed in to change notification settings - Fork 431
[Example][BugFix] 1SM GEMM example on Blackwell and fix handling of mbar
#1774
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
📝 WalkthroughWalkthroughAdds a new SM100 GEMM example kernel and refactors mbar handling from region-based to BufferLoad-based across GEMM lowering and TileLang builtins; introduces MakeAccessPtrFromBufferLoad, changes TMA bulk-copy gating to use tl_shuffle_elect, and updates TileLang/Gemm APIs to accept/return BufferLoad or None. Changes
Sequence Diagram(s)mermaid Host->>Kernel: launch gemm kernel Estimated code review effort🎯 4 (Complex) | ⏱️ ~60 minutes Possibly related PRs
Suggested reviewers
Poem
🚥 Pre-merge checks | ✅ 2 | ❌ 1❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. ✨ Finishing touches
🧪 Generate unit tests (beta)
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
|
👋 Hi! Thank you for contributing to the TileLang project. Please remember to run We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 3
🤖 Fix all issues with AI agents
In `@src/op/gemm_py.cc`:
- Around line 83-86: When handling the optional mbar argument in the arg-parsing
block, don't silently ignore non-BufferLoad values: if args.size() > 16 then
attempt the BufferLoadNode cast as currently done (check for BufferLoadNode and
set node->mbar_ via Downcast<BufferLoad>), but add an else branch that fails
fast (throw or LOG(FATAL)/CHECK) reporting that arg 16 was expected to be a
BufferLoad and include the actual argument's type/name (use whatever runtime
type introspection is available on the Expr/Node to include in the message) so
callers get a clear error instead of silently skipping mbar_.
In `@src/op/gemm_py.h`:
- Line 32: mbar_ is declared as a non-optional tir::BufferLoad but is only
conditionally assigned (when args.size() > 16 and args[16] is a BufferLoadNode),
causing a type-contract mismatch; change the field declaration from
tir::BufferLoad mbar_ to std::optional<tir::BufferLoad> mbar_, update the
parser/initializer (where args is inspected) to emplace/assign mbar_ only in the
conditional branch, and adjust any uses of mbar_ (check has_value() or use
value_or) so code and the Python bindings safely handle the absent case;
alternatively, if you prefer non-optional, ensure mbar_ is unconditionally
initialized in the same constructor code path and remove the "optional" comment.
In `@src/op/utils.cc`:
- Around line 95-122: The function MakeAccessPtrFromBufferLoad uses hard-coded
DataType::Int(32) for offset, stride and extent which can overflow for large
buffers; change all occurrences of make_const(DataType::Int(32), ...) and the
IntImm for rw_mask to use the buffer's index dtype (buf->index_dtype) instead:
initialize offset and stride with make_const(buf->index_dtype, 0/1), compute
offset/stride arithmetic with that dtype, set extent using
make_const(buf->index_dtype, 1), and construct the rw_mask as
IntImm(buf->index_dtype, rw_mask) when building acc_args; update references
inside MakeAccessPtrFromBufferLoad (offset, stride, extent, acc_args)
accordingly.
🧹 Nitpick comments (1)
src/op/gemm.cc (1)
84-88: Consider hard-failing on unexpected mbar argument type.When
args.size() > 16, a non-BufferLoad value is silently dropped. That can mask frontend mismatches. Prefer anICHECKto fail early when the caller passes the wrong type.Suggested tightening
if (args.size() > 16) { - if (const auto *load = args[16].as<BufferLoadNode>()) { - node->mbar_ = Downcast<BufferLoad>(args[16]); - } else { - node->mbar_ = std::nullopt; - } + ICHECK(args[16].as<BufferLoadNode>()) + << "mbar must be provided as BufferLoad when present"; + node->mbar_ = Downcast<BufferLoad>(args[16]); }
| if (args.size() > 16) { | ||
| if (const auto *load = args[16].as<BufferLoadNode>()) { | ||
| node->mbarRegion_ = | ||
| NormalizeToBufferRegion(Downcast<BufferLoad>(args[16])); | ||
| node->mbar_ = node->mbarRegion_->buffer; | ||
| node->mbar_ = Downcast<BufferLoad>(args[16]); | ||
| } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fail fast when mbar is present but not a BufferLoad.
Silently ignoring non-BufferLoad inputs can mask call-site errors and lead to a missing barrier later; an explicit check makes failures clearer.
💡 Suggested fix
- if (args.size() > 16) {
- if (const auto *load = args[16].as<BufferLoadNode>()) {
- node->mbar_ = Downcast<BufferLoad>(args[16]);
- }
- }
+ if (args.size() > 16) {
+ ICHECK(args[16].as<BufferLoadNode>())
+ << "mbar must be provided as a BufferLoad when present";
+ node->mbar_ = Downcast<BufferLoad>(args[16]);
+ }📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| if (args.size() > 16) { | |
| if (const auto *load = args[16].as<BufferLoadNode>()) { | |
| node->mbarRegion_ = | |
| NormalizeToBufferRegion(Downcast<BufferLoad>(args[16])); | |
| node->mbar_ = node->mbarRegion_->buffer; | |
| node->mbar_ = Downcast<BufferLoad>(args[16]); | |
| } | |
| if (args.size() > 16) { | |
| ICHECK(args[16].as<BufferLoadNode>()) | |
| << "mbar must be provided as a BufferLoad when present"; | |
| node->mbar_ = Downcast<BufferLoad>(args[16]); | |
| } |
🤖 Prompt for AI Agents
In `@src/op/gemm_py.cc` around lines 83 - 86, When handling the optional mbar
argument in the arg-parsing block, don't silently ignore non-BufferLoad values:
if args.size() > 16 then attempt the BufferLoadNode cast as currently done
(check for BufferLoadNode and set node->mbar_ via Downcast<BufferLoad>), but add
an else branch that fails fast (throw or LOG(FATAL)/CHECK) reporting that arg 16
was expected to be a BufferLoad and include the actual argument's type/name (use
whatever runtime type introspection is available on the Expr/Node to include in
the message) so callers get a clear error instead of silently skipping mbar_.
| PrimExpr clearAccum_ = const_false(); | ||
| BufferRegion mbarRegion_; | ||
| tir::Buffer mbar_; // mbar is optional, only used for TCGEN5MMA | ||
| tir::BufferLoad mbar_; // mbar is optional, only used for TCGEN5MMA |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
🧩 Analysis chain
🏁 Script executed:
#!/bin/bash
# Verify mbar_ usage for null-safety and initialization paths
rg -n -C3 'mbar_' src/op/gemm_py.cc src/op/gemm_py.hRepository: tile-ai/tilelang
Length of output: 1338
🏁 Script executed:
# Also check the GemmNode definition to see the original type
rg -n -C3 'struct GemmNode\|class GemmNode' src/op/gemm_py.h src/op/gemm_py.ccRepository: tile-ai/tilelang
Length of output: 42
🏁 Script executed:
# Look for mbar_ initialization in constructors or Create methods
rg -n 'mbar_\s*=' src/op/gemm_py.cc src/op/gemm_py.hRepository: tile-ai/tilelang
Length of output: 132
🏁 Script executed:
# Check if .defined() or other null checks are used with mbar_
rg -n 'mbar_.*defined|if.*mbar_' src/op/gemm_py.cc src/op/gemm_py.hRepository: tile-ai/tilelang
Length of output: 42
🏁 Script executed:
# Search for how mbar_ is used/read (not just assigned)
rg -n 'mbar_' src/op/gemm_py.cc src/op/gemm_py.h -A 5 -B 2Repository: tile-ai/tilelang
Length of output: 1511
🏁 Script executed:
# Look for uses of mbar in generated IR or kernel code
rg -n 'mbar' src/op/gemm_py.cc src/op/gemm_py.hRepository: tile-ai/tilelang
Length of output: 286
🏁 Script executed:
# Check if there's a TVM BufferLoad definition to understand default initialization
fd -t f '\.h$' -x grep -l 'class BufferLoad\|struct BufferLoad' {} \; | head -5Repository: tile-ai/tilelang
Length of output: 42
🏁 Script executed:
# Search for mbar usage in gemm_py.cc to find where it's read/accessed
rg -n 'mbar' src/op/gemm_py.cc | head -20Repository: tile-ai/tilelang
Length of output: 114
🏁 Script executed:
# Search for GemmPyNode usage in other files to see how mbar is accessed
rg -n 'GemmPyNode\|mbar' src/ --include='*.cc' --include='*.h' | grep -v 'gemm_py' | head -20Repository: tile-ai/tilelang
Length of output: 501
🏁 Script executed:
# Check the entire gemm_py.cc to see function context where mbar is used
wc -l src/op/gemm_py.ccRepository: tile-ai/tilelang
Length of output: 81
🏁 Script executed:
cat -n src/op/gemm_py.ccRepository: tile-ai/tilelang
Length of output: 15164
Resolve "optional" mismatch for mbar_ (type-contract inconsistency).
mbar_ is declared as non-optional tir::BufferLoad, but the comment marks it as optional and initialization is conditional (line 85: only assigned if args.size() > 16 and args[16] is a BufferLoadNode). This creates a type-contract mismatch: the field can remain uninitialized while the type signature suggests it's always present.
Suggestion: Make this std::optional<tir::BufferLoad> to match the semantic intent, or guarantee initialization and remove the optional comment. Verify Python-side code handles undefined mbar_ safely.
🤖 Prompt for AI Agents
In `@src/op/gemm_py.h` at line 32, mbar_ is declared as a non-optional
tir::BufferLoad but is only conditionally assigned (when args.size() > 16 and
args[16] is a BufferLoadNode), causing a type-contract mismatch; change the
field declaration from tir::BufferLoad mbar_ to std::optional<tir::BufferLoad>
mbar_, update the parser/initializer (where args is inspected) to emplace/assign
mbar_ only in the conditional branch, and adjust any uses of mbar_ (check
has_value() or use value_or) so code and the Python bindings safely handle the
absent case; alternatively, if you prefer non-optional, ensure mbar_ is
unconditionally initialized in the same constructor code path and remove the
"optional" comment.
| PrimExpr MakeAccessPtrFromBufferLoad(const BufferLoad &load, int rw_mask) { | ||
| Buffer buf = load->buffer; | ||
| int ndim = static_cast<int>(buf->shape.size()); | ||
|
|
||
| // Compute offset using row-major layout (iterate in reverse) | ||
| PrimExpr offset = make_const(DataType::Int(32), 0); | ||
| PrimExpr stride = make_const(DataType::Int(32), 1); | ||
|
|
||
| for (int i = ndim - 1; i >= 0; --i) { | ||
| const PrimExpr &index = load->indices[i]; | ||
| if (const auto *ramp = index.as<RampNode>()) { | ||
| // For Ramp, use the base | ||
| offset = offset + ramp->base * stride; | ||
| } else { | ||
| // For scalar index (IntImm or other PrimExpr) | ||
| offset = offset + index * stride; | ||
| } | ||
| stride = stride * buf->shape[i]; | ||
| } | ||
|
|
||
| // Extent is 1 element for a single BufferLoad access | ||
| PrimExpr extent = make_const(DataType::Int(32), 1); | ||
|
|
||
| // Build access_ptr | ||
| PrimExpr ptype = tir::TypeAnnotation(buf->dtype); | ||
| Array<PrimExpr> acc_args{ptype, buf->data, offset, extent, | ||
| IntImm(DataType::Int(32), rw_mask)}; | ||
| return Call(DataType::Handle(), builtin::tvm_access_ptr(), acc_args); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Avoid int32 offset/stride to prevent overflow on large buffers.
DataType::Int(32) can overflow for large shapes and diverges from MakeAccessPtrFromRegion. Use the buffer index dtype for offset/stride/extent.
💡 Suggested fix
- // Compute offset using row-major layout (iterate in reverse)
- PrimExpr offset = make_const(DataType::Int(32), 0);
- PrimExpr stride = make_const(DataType::Int(32), 1);
+ // Compute offset using row-major layout (iterate in reverse)
+ DataType idx_dtype = buf->shape[0].dtype();
+ PrimExpr offset = make_const(idx_dtype, 0);
+ PrimExpr stride = make_const(idx_dtype, 1);
@@
- // Extent is 1 element for a single BufferLoad access
- PrimExpr extent = make_const(DataType::Int(32), 1);
+ // Extent is 1 element for a single BufferLoad access
+ PrimExpr extent = make_const(idx_dtype, 1);📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| PrimExpr MakeAccessPtrFromBufferLoad(const BufferLoad &load, int rw_mask) { | |
| Buffer buf = load->buffer; | |
| int ndim = static_cast<int>(buf->shape.size()); | |
| // Compute offset using row-major layout (iterate in reverse) | |
| PrimExpr offset = make_const(DataType::Int(32), 0); | |
| PrimExpr stride = make_const(DataType::Int(32), 1); | |
| for (int i = ndim - 1; i >= 0; --i) { | |
| const PrimExpr &index = load->indices[i]; | |
| if (const auto *ramp = index.as<RampNode>()) { | |
| // For Ramp, use the base | |
| offset = offset + ramp->base * stride; | |
| } else { | |
| // For scalar index (IntImm or other PrimExpr) | |
| offset = offset + index * stride; | |
| } | |
| stride = stride * buf->shape[i]; | |
| } | |
| // Extent is 1 element for a single BufferLoad access | |
| PrimExpr extent = make_const(DataType::Int(32), 1); | |
| // Build access_ptr | |
| PrimExpr ptype = tir::TypeAnnotation(buf->dtype); | |
| Array<PrimExpr> acc_args{ptype, buf->data, offset, extent, | |
| IntImm(DataType::Int(32), rw_mask)}; | |
| return Call(DataType::Handle(), builtin::tvm_access_ptr(), acc_args); | |
| PrimExpr MakeAccessPtrFromBufferLoad(const BufferLoad &load, int rw_mask) { | |
| Buffer buf = load->buffer; | |
| int ndim = static_cast<int>(buf->shape.size()); | |
| // Compute offset using row-major layout (iterate in reverse) | |
| DataType idx_dtype = buf->shape[0].dtype(); | |
| PrimExpr offset = make_const(idx_dtype, 0); | |
| PrimExpr stride = make_const(idx_dtype, 1); | |
| for (int i = ndim - 1; i >= 0; --i) { | |
| const PrimExpr &index = load->indices[i]; | |
| if (const auto *ramp = index.as<RampNode>()) { | |
| // For Ramp, use the base | |
| offset = offset + ramp->base * stride; | |
| } else { | |
| // For scalar index (IntImm or other PrimExpr) | |
| offset = offset + index * stride; | |
| } | |
| stride = stride * buf->shape[i]; | |
| } | |
| // Extent is 1 element for a single BufferLoad access | |
| PrimExpr extent = make_const(idx_dtype, 1); | |
| // Build access_ptr | |
| PrimExpr ptype = tir::TypeAnnotation(buf->dtype); | |
| Array<PrimExpr> acc_args{ptype, buf->data, offset, extent, | |
| IntImm(DataType::Int(32), rw_mask)}; | |
| return Call(DataType::Handle(), builtin::tvm_access_ptr(), acc_args); | |
| } |
🤖 Prompt for AI Agents
In `@src/op/utils.cc` around lines 95 - 122, The function
MakeAccessPtrFromBufferLoad uses hard-coded DataType::Int(32) for offset, stride
and extent which can overflow for large buffers; change all occurrences of
make_const(DataType::Int(32), ...) and the IntImm for rw_mask to use the
buffer's index dtype (buf->index_dtype) instead: initialize offset and stride
with make_const(buf->index_dtype, 0/1), compute offset/stride arithmetic with
that dtype, set extent using make_const(buf->index_dtype, 1), and construct the
rw_mask as IntImm(buf->index_dtype, rw_mask) when building acc_args; update
references inside MakeAccessPtrFromBufferLoad (offset, stride, extent, acc_args)
accordingly.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 2
🤖 Fix all issues with AI agents
In `@examples/gemm_sm100/gemm_tcgen5mma_ws.py`:
- Around line 23-47: Add an explicit precondition that M, N, and K are exact
multiples of block_M, block_N, and block_K respectively to prevent out-of-bounds
writes when the kernel (which computes k_iters via T.ceildiv and writes C_shared
-> C at by*block_M, bx*block_N) runs; insert an assertion near the top of the
function before the T.Kernel block (where k_iters, A/B/C and shared buffers are
set up) that checks M % block_M == 0, N % block_N == 0, and K % block_K == 0 and
fail early if not, mirroring the guards used in other GEMM kernels.
In `@tilelang/language/builtin.py`:
- Around line 812-814: The code references mbar_ptr in the tir.call_intrin call
but only sets mbar_ptr inside the isinstance(mbar, (tir.Buffer, BufferLoad))
branch, causing UnboundLocalError when a raw PrimExpr is passed; fix by ensuring
mbar_ptr is always defined: after the existing if-block set mbar_ptr = mbar for
the PrimExpr case (or otherwise convert the PrimExpr to the expected pointer
form) so that mbar_ptr is available before calling tir.call_intrin("void",
tir.op.Op.get("tl.tcgen05_mma_arrive"), mbar_ptr); modify the logic around the
mbar/mbar_ptr handling in the function in tilelang/language/builtin.py to cover
both tir.Buffer/BufferLoad and PrimExpr inputs.
| k_iters = T.ceildiv(K, block_K) | ||
|
|
||
| A: T.Tensor[[M, K], in_dtype] | ||
| B: T.Tensor[[K, N], in_dtype] | ||
| C = T.empty((M, N), out_dtype) | ||
|
|
||
| with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): | ||
| A_shared = T.alloc_shared((num_stages, block_M, block_K), in_dtype) | ||
| B_shared = T.alloc_shared((num_stages, block_K, block_N), in_dtype) | ||
| C_tmem = T.alloc_tmem([block_M, block_N], accum_dtype) | ||
| C_local = T.alloc_fragment((block_M, block_N), accum_dtype) | ||
| C_shared = T.alloc_shared((block_M, block_N), out_dtype) | ||
| loaded = T.alloc_barrier([32] * num_stages) | ||
| consumed = T.alloc_barrier([1] * num_stages) | ||
| tmem_full = T.alloc_barrier([1]) | ||
|
|
||
| tx = T.get_thread_binding() | ||
|
|
||
| T.use_swizzle(8) | ||
|
|
||
| if tx < 32: # warp 0: issue tma | ||
| for k in T.serial(k_iters): | ||
| T.mbarrier_wait_parity(consumed[k % num_stages], ((k // num_stages) & 1) ^ 1) | ||
| T.copy(A[by * block_M : (by + 1) * block_M, k * block_K : (k + 1) * block_K], A_shared[k % num_stages, :, :]) | ||
| T.copy(B[k * block_K : (k + 1) * block_K, bx * block_N : (bx + 1) * block_N], B_shared[k % num_stages, :, :]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
🧩 Analysis chain
🏁 Script executed:
# Find the file mentioned in the review
fd "gemm_tcgen5mma_ws.py" --type fRepository: tile-ai/tilelang
Length of output: 100
🏁 Script executed:
# List the repository structure to understand the codebase layout
ls -la examples/ 2>/dev/null | head -20Repository: tile-ai/tilelang
Length of output: 1244
🏁 Script executed:
# Search for other GEMM examples to understand the pattern
fd "gemm.*\.py" examples/ --type f | head -20Repository: tile-ai/tilelang
Length of output: 1130
🏁 Script executed:
# Examine the file mentioned in the review
cat -n examples/gemm_sm100/gemm_tcgen5mma_ws.pyRepository: tile-ai/tilelang
Length of output: 4292
🏁 Script executed:
# Also search for other gemm_sm100 files to get context
ls -la examples/gemm_sm100/Repository: tile-ai/tilelang
Length of output: 434
🏁 Script executed:
# Search for other GEMM examples to understand how they handle non-divisible dimensions
grep -r "T.copy.*\[.*\*.*block_" examples/gemm* --include="*.py" | head -20Repository: tile-ai/tilelang
Length of output: 2122
🏁 Script executed:
# Look at other SM100 GEMM variants to see if they have guards
cat -n examples/gemm_sm100/gemm_tcgen5mma.pyRepository: tile-ai/tilelang
Length of output: 3381
🏁 Script executed:
# Search for any guard patterns (assertions, checks) in GEMM examples
rg "assert.*%|assert.*divisible|if.*%" examples/gemm* -A 2 -B 2 | head -40Repository: tile-ai/tilelang
Length of output: 3576
Add explicit precondition to guard against non‑divisible M/N/K dimensions.
The kernel uses T.ceildiv for grid dimensions, and line 68 copies C_shared (shape [block_M, block_N]) to C[by * block_M, bx * block_N] without bounds checking. If M/N/K are not exact multiples of their block sizes, the last threadblocks will write out of bounds. The test case passes only because M=N=K=16384 are exact multiples of their block sizes. Add an assertion to enforce this precondition, as done in other GEMM kernels in the codebase.
🛡️ Suggested guard (adjust to TileLang-supported assertion APIs)
def gemm(
A,
B,
@@
):
M, N, K = T.const("M, N, K")
+ # Kernel requires exact tile coverage to avoid OOB access at line 68.
+ assert (
+ M % block_M == 0
+ and N % block_N == 0
+ and K % block_K == 0
+ ), "M/N/K must be divisible by block_M/block_N/block_K"🤖 Prompt for AI Agents
In `@examples/gemm_sm100/gemm_tcgen5mma_ws.py` around lines 23 - 47, Add an
explicit precondition that M, N, and K are exact multiples of block_M, block_N,
and block_K respectively to prevent out-of-bounds writes when the kernel (which
computes k_iters via T.ceildiv and writes C_shared -> C at by*block_M,
bx*block_N) runs; insert an assertion near the top of the function before the
T.Kernel block (where k_iters, A/B/C and shared buffers are set up) that checks
M % block_M == 0, N % block_N == 0, and K % block_K == 0 and fail early if not,
mirroring the guards used in other GEMM kernels.
| if isinstance(mbar, (tir.Buffer, BufferLoad)): | ||
| mbar_ptr = retrieve_ptr(mbar, access_type="rw") | ||
| return tir.call_intrin("void", tir.op.Op.get("tl.tcgen05_mma_arrive"), mbar_ptr) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
UnboundLocalError when mbar is a PrimExpr.
The function signature accepts PrimExpr, but mbar_ptr is only assigned when mbar is a Buffer or BufferLoad. When a raw PrimExpr (pointer address) is passed, the if-block is skipped and line 814 references an undefined mbar_ptr.
🐛 Proposed fix
if isinstance(mbar, (tir.Buffer, BufferLoad)):
mbar_ptr = retrieve_ptr(mbar, access_type="rw")
+ else:
+ mbar_ptr = mbar
return tir.call_intrin("void", tir.op.Op.get("tl.tcgen05_mma_arrive"), mbar_ptr)📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| if isinstance(mbar, (tir.Buffer, BufferLoad)): | |
| mbar_ptr = retrieve_ptr(mbar, access_type="rw") | |
| return tir.call_intrin("void", tir.op.Op.get("tl.tcgen05_mma_arrive"), mbar_ptr) | |
| if isinstance(mbar, (tir.Buffer, BufferLoad)): | |
| mbar_ptr = retrieve_ptr(mbar, access_type="rw") | |
| else: | |
| mbar_ptr = mbar | |
| return tir.call_intrin("void", tir.op.Op.get("tl.tcgen05_mma_arrive"), mbar_ptr) |
🤖 Prompt for AI Agents
In `@tilelang/language/builtin.py` around lines 812 - 814, The code references
mbar_ptr in the tir.call_intrin call but only sets mbar_ptr inside the
isinstance(mbar, (tir.Buffer, BufferLoad)) branch, causing UnboundLocalError
when a raw PrimExpr is passed; fix by ensuring mbar_ptr is always defined: after
the existing if-block set mbar_ptr = mbar for the PrimExpr case (or otherwise
convert the PrimExpr to the expected pointer form) so that mbar_ptr is available
before calling tir.call_intrin("void", tir.op.Op.get("tl.tcgen05_mma_arrive"),
mbar_ptr); modify the logic around the mbar/mbar_ptr handling in the function in
tilelang/language/builtin.py to cover both tir.Buffer/BufferLoad and PrimExpr
inputs.
mbarasBufferLoadto avoid missing indexThanks @Hamerlate for providing the dev machine.
Summary by CodeRabbit
New Features
Refactoring
API