-
Notifications
You must be signed in to change notification settings - Fork 429
[Refactor] Improve type annotations and reduce some lint errors in frontend #1777
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
|
👋 Hi! Thank you for contributing to the TileLang project. Please remember to run We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀 |
📝 WalkthroughWalkthroughCentral typing modernization: introduces Changes
Sequence Diagram(s)(omitted) Estimated code review effort🎯 4 (Complex) | ⏱️ ~45 minutes Possibly related PRs
Suggested reviewers
Poem
🚥 Pre-merge checks | ✅ 2 | ❌ 1❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. ✨ Finishing touches
🧪 Generate unit tests (beta)
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (4)
tilelang/language/loop.py (1)
13-24:⚠️ Potential issue | 🟡 MinorDocstring does not reflect updated parameter type.
The function signature now accepts
int | tir.PrimExprforextents, but the docstring on line 23 still statesextents : PrimExpr. Consider updating the docstring to match the signature.📝 Suggested docstring fix
Parameters ---------- - extents : PrimExpr + extents : int | PrimExpr The extents of the iteration.tilelang/language/kernel.py (1)
229-241:⚠️ Potential issue | 🟡 MinorDocstring could clarify accepted types for
blocksparameter.The signature now accepts
int | tir.PrimExpr, but the docstring on lines 239-240 only mentionsint. Consider updating to reflect both accepted types for clarity.📝 Suggested docstring improvement
Parameters ---------- - blocks : int - A list of extent, can be 1-3 dimension, representing gridDim.(x|y|z) + blocks : int | PrimExpr + The block extents (1-3 dimensions), representing gridDim.(x|y|z)tilelang/language/gemm_op.py (1)
40-56:⚠️ Potential issue | 🟠 MajorDon’t drop region/offset info when unwrapping let-bound buffers.
Line 49-50 returns.buffer, which loses BufferRegion/BufferLoad slicing;to_buffer_regionthen computes shapes/offsets on the full buffer, which can mis-handle sliced A/B/C (and mbar arrays). Return the let value directly and letto_buffer_regionnormalize it.🔧 Suggested fix
- if isinstance(arg, tir.Var) and T.has_let_value(arg): - return T.get_let_value(arg).buffer + if isinstance(arg, tir.Var) and T.has_let_value(arg): + return T.get_let_value(arg) return argtilelang/language/ast/ir.py (1)
983-1013:⚠️ Potential issue | 🟡 MinorHandle
Nonecondition inallocate()to prevent runtime errors.The function signature allows
condition=None(line 987), but the code only convertsbooltoIntImmand passesconditiondirectly to_ffi_api.Allocate(). Whenallocate()is called without theconditionparameter (as shown in all test usages), it defaults toNoneand will cause a runtime error. Add a guard to defaultNonetoIntImm("bool", True):🛠️ Suggested fix
if isinstance(condition, bool): condition = IntImm("bool", condition) + elif condition is None: + condition = IntImm("bool", True)
🤖 Fix all issues with AI agents
In `@tilelang/language/fastmath.py`:
- Line 127: Update the docstring for the __exp function in
tilelang/language/fastmath.py to state that it computes the natural exponential
(e**x or exp(x)) rather than 2**x; locate the docstring above the __exp
definition and replace "Calculate 2**x with fast math" with text like "Calculate
e**x (exp(x)) with fast math" so it matches the CUDA backend mapping and naming
conventions.
🧹 Nitpick comments (5)
tilelang/language/customize.py (1)
57-67: Consider enhancing the assertion message inviewfor consistency.The
viewfunction's assertion message (line 66) is less informative thanreshape's. For debugging ease, consider including the actual values.📝 Suggested improvement
- assert prim_expr_equal(bits_product(shape, dtype), bits_product(src.shape, src.dtype)), "T.reshape/view shape check failed." + assert prim_expr_equal(bits_product(shape, dtype), bits_product(src.shape, src.dtype)), ( + f"T.reshape/view shape check failed. src {src} src.shape: {src.shape}, src.dtype: {src.dtype}, target shape: {shape}, target dtype: {dtype}" + )tilelang/language/copy_op.py (1)
28-29: Consider updating docstrings to referenceBufferLikeType.The docstrings still reference the explicit union types (
Union[tir.Buffer, tir.BufferLoad, tir.BufferRegion]on lines 28-29 andtir.Bufferon lines 134-135) while the signatures now useBufferLikeType. For consistency with the new typing system, consider updating these docstrings.Also applies to: 134-135
tilelang/language/symbolics.py (1)
19-19: Minor: Docstring mentionsstrbut type hint isDType.Line 19 states
dtype (str): Data type stringbut the parameter type is nowDType. Consider updating todtype (DType): Data type for the variable.tilelang/language/allocate.py (1)
262-276: Inconsistent typing: These functions still usedtype: strinstead ofDType.For consistency with the updated functions above, consider updating the following functions to use
DTypeinstead ofstr:
alloc_wgmma_desc(line 262)alloc_tcgen05_smem_desc(line 266)alloc_tcgen05_instruction_desc(line 270)alloc_tcgen05_instr_desc(line 275)emptyoverload and implementation (lines 280, 283)♻️ Proposed fix for consistency
-def alloc_wgmma_desc(dtype: str = _dtypes.uint64): +def alloc_wgmma_desc(dtype: DType = _dtypes.uint64): return alloc_descriptor("wgmma", dtype=dtype) -def alloc_tcgen05_smem_desc(dtype: str = _dtypes.uint64): +def alloc_tcgen05_smem_desc(dtype: DType = _dtypes.uint64): return alloc_descriptor("tcgen05_smem", dtype=dtype) -def alloc_tcgen05_instruction_desc(dtype: str = _dtypes.uint32): +def alloc_tcgen05_instruction_desc(dtype: DType = _dtypes.uint32): return alloc_descriptor("tcgen05_instr", dtype=dtype) # Alias: short name consistent with imports -def alloc_tcgen05_instr_desc(dtype: str = _dtypes.uint32): +def alloc_tcgen05_instr_desc(dtype: DType = _dtypes.uint32): return alloc_tcgen05_instruction_desc(dtype) `@overload` -def empty(shape, dtype: str = _dtypes.float32) -> Tensor: ... +def empty(shape, dtype: DType = _dtypes.float32) -> Tensor: ... -def empty(*shape, dtype: str = _dtypes.float32) -> Tensor: +def empty(*shape, dtype: DType = _dtypes.float32) -> Tensor:Also applies to: 280-283
tilelang/language/math_intrinsics.py (1)
15-148: Consolidate duplicated fast math functions to a single source.The 8 fast math functions (
__log,__log2,__log10,__tan,__cos,__sin,__exp10,__exp) are duplicated betweentilelang/language/fastmath.pyandtilelang/language/math_intrinsics.pywith identical implementations. Consider havingmath_intrinsics.pyimport and re-export these functions fromfastmath.pyrather than duplicating them, or consolidate to a single module to maintain one source of truth.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 0
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
tilelang/language/logical.py (2)
68-72:⚠️ Potential issue | 🟡 MinorFix all_of error message to reference the correct API.
The message says “T.any” inside
all_of, which is confusing for users.✏️ Proposed fix
- "Only support the last dimension to be for T.any currently, please contact us if you need this feature" + "Only support the last dimension to be for T.all currently, please contact us if you need this feature"
12-43:⚠️ Potential issue | 🔴 CriticalType annotation does not match runtime handling; BufferLoad is accepted by type hint but rejected at runtime.
BufferLikeTypeincludesBufferLoad, but bothany_ofandall_ofonly handleBufferandBufferRegion, raisingValueErrorforBufferLoad. This creates a type contract violation that static type checkers will miss.Additionally, fix the
all_ofdocstring (line 53) which incorrectly says "performs the any operation" and the error message (line 71) which says "T.any currently"—both should reference "all" instead of "any".
🧹 Nitpick comments (2)
tilelang/language/experimental/gemm_sp.py (2)
38-41: Minor style inconsistency: docstrings useUnion[]while type hints use|syntax.The function signature uses the modern
BufferLikeType | tir.Varsyntax (PEP 604), but the docstrings use the olderUnion[BufferLikeType, tir.Var]format. Consider aligning docstrings with the code style for consistency.✏️ Suggested docstring update
Args: - A_sparse (Union[BufferLikeType, tir.Var]): First input matrix dense values - E (Union[BufferLikeType, tir.Var]): First input matrix sparse metadata - B (Union[BufferLikeType, tir.Var]): Second input matrix - C (Union[BufferLikeType, tir.Var]): Output matrix for results + A_sparse (BufferLikeType | tir.Var): First input matrix dense values + E (BufferLikeType | tir.Var): First input matrix sparse metadata + B (BufferLikeType | tir.Var): Second input matrix + C (BufferLikeType | tir.Var): Output matrix for results
56-67: Consider extracting duplicatelegalize_argumentshelper to module level.This inner function is duplicated identically in
gemm_sp_v2(lines 139-150). Extracting it to a module-level helper would reduce duplication and improve maintainability.♻️ Suggested refactor
Add a module-level helper before the function definitions:
def _legalize_arguments(arg: BufferLikeType | tir.Var) -> BufferLikeType | tir.Var: """Convert let-bound variables to their corresponding buffers. Args: arg (BufferLikeType | tir.Var): Input argument to legalize Returns: BufferLikeType | tir.Var: The legalized argument """ if isinstance(arg, tir.Var) and T.has_let_value(arg): return T.get_let_value(arg).buffer return argThen use
_legalize_argumentsin bothgemm_spandgemm_sp_v2instead of defining the inner function twice.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 2
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
tilelang/language/allocate.py (2)
33-48:⚠️ Potential issue | 🟠 MajorHandle non-string bool dtypes in the shared-memory workaround.
With the switch toDType, callers can pass dtype objects (e.g.,_dtypes.bool). The currentdtype == "bool"check won’t catch those, leaving scope at"shared.dyn"and reintroducing the bool-merge issue. Consider normalizing to a string or explicitly checking dtype objects.🔧 Proposed fix
- if dtype == "bool": + if str(dtype) == "bool": # lei: This is a hack to handle bool type. # Because tilelang's merge smem pass cannot merge bool type currently. scope = "shared"
279-291:⚠️ Potential issue | 🟠 MajorAllow positional dtype objects in
empty.
The branch forempty(shape_tuple, dtype)only acceptsstr. WithDType, callers can reasonably pass dtype objects positionally (e.g.,_dtypes.float32), which currently raisesTypeError. Extend the check to accept dtype objects.🔧 Proposed fix
- elif len(shape) == 2 and isinstance(shape[0], (tuple, list)) and isinstance(shape[1], str): + elif len(shape) == 2 and isinstance(shape[0], (tuple, list)) and isinstance(shape[1], (str, tl_dtype)): return OutTensor(shape[0], shape[1])
🤖 Fix all issues with AI agents
In `@tilelang/language/allocate.py`:
- Around line 79-87: The implementation of alloc_var currently narrows the init
parameter to PrimExpr | None, which conflicts with the overloads that allow int
| float as well; update the implementation signature of alloc_var to accept
init: PrimExpr | int | float | None (matching the overloads) so type checking
passes, and ensure any internal handling of init (inside alloc_var) still
correctly accepts and normalizes int/float to PrimExpr as expected.
In `@tilelang/language/builtin.py`:
- Around line 431-435: Normalize the dtype parameter to a Python string as soon
as it's validated in the function signature handling (where buffer_or_ptr,
offset, num_regs, dtype are processed) so downstream comparisons and uses
succeed; specifically, if dtype is not None convert it via dtype = str(dtype)
(or use dtype.name/str-conversion consistent with DataType expectations)
immediately after validation so comparisons like dtype == "int32" work,
DataType(dtype).bits receives a string, and the value passed into
tir.call_intrin() becomes the expected StringImm. Ensure this normalization
happens before any `if dtype == ...`, `DataType(dtype)`, or
`tir.call_intrin(...)` calls.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 2
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (4)
tilelang/language/print_op.py (3)
74-78:⚠️ Potential issue | 🟠 MajorReturn inside loop exits after first iteration.
Placing
returninside theforloop at line 78 causes the function to exit after processing only the first element, contradicting the documented behavior of iterating through all buffer elements.🐛 Proposed fix - remove return from inside loop
if condition: # Iterate through the buffer elements and print each one. for i in serial(elems): coords = index_to_coordinates(i, buffer.shape) - return tir.call_extern("handle", "debug_print_buffer_value", msg, buffer.name, i, buffer[coords]) + tir.call_extern("handle", "debug_print_buffer_value", msg, buffer.name, i, buffer[coords])
96-100:⚠️ Potential issue | 🟠 MajorReturn inside loop exits after first iteration.
Same issue as
print_shared_buffer_with_condition: thereturnat line 100 causes the function to exit after printing only the first element.🐛 Proposed fix
if condition: # Iterate through the buffer elements and print each one. for i in serial(elems): coords = index_to_coordinates(i, buffer.shape) - return tir.call_extern("handle", "debug_print_buffer_value", msg, buffer.name, i, smem[coords]) + tir.call_extern("handle", "debug_print_buffer_value", msg, buffer.name, i, smem[coords])
126-130:⚠️ Potential issue | 🟠 MajorReturn inside loop exits after first iteration.
Same issue as other buffer print functions: the
returnat line 130 causes early exit after the first element.🐛 Proposed fix
if condition: # Iterate through the buffer elements and print each one. for i in serial(elems): coords = index_to_coordinates(i, buffer.shape) - return tir.call_extern("handle", "debug_print_buffer_value", msg, buffer.name, i, buffer[coords]) + tir.call_extern("handle", "debug_print_buffer_value", msg, buffer.name, i, buffer[coords])tilelang/language/reduce_op.py (1)
24-104:⚠️ Potential issue | 🔴 CriticalFix early returns that skip writing to
outin shared→shared / fragment→shared reductions.
returnprecedescopy(red_frag_out, out), making the copy unreachable and leavingoutstale for those paths. Please return after the copy.🐛 Suggested fix
@@ - return tir.call_intrin( + call = tir.call_intrin( "handle", tir.op.Op.get(_REDUCE_OP_KEY), to_buffer_region(red_frag_in, access_type="r"), to_buffer_region(red_frag_out, access_type="w"), reduce_type, dim, clear, ) copy(red_frag_out, out) + return call @@ - return tir.call_intrin( + call = tir.call_intrin( "handle", tir.op.Op.get(_REDUCE_OP_KEY), to_buffer_region(buffer, access_type="r"), to_buffer_region(red_frag_out, access_type="w"), reduce_type, dim, clear, ) copy(red_frag_out, out) + return call
🤖 Fix all issues with AI agents
In `@tilelang/language/copy_op.py`:
- Line 24: The return type annotation that currently reads "-> tir.PrimExpr" is
incorrect because the function sometimes returns a tir.BufferStore (a tir.Stmt)
and other times returns tir.call_intrin(...) (a tir.PrimExpr); change the
annotation to a union such as "-> Union[tir.PrimExpr, tir.Stmt]" (add the
necessary typing import), and update the docstring (formerly "Returns:
tir.Call") to describe both possible return types and the conditions under which
each (tir.BufferStore vs tir.PrimExpr from tir.call_intrin) is returned.
In `@tilelang/language/print_op.py`:
- Around line 47-58: The else branch in print_global_buffer_with_condition
references i and coords which are only defined inside the for loop and must be
removed; also the current return inside the for loop causes early exit so change
the implementation so that when condition is True you iterate for i in
serial(elems), compute coords = index_to_coordinates(i, buffer.shape) and call
tir.call_extern("handle", "debug_print_buffer_value", msg, buffer.name, i,
buffer[coords]) for each element (no return inside the loop), and when condition
is False simply return None (or do nothing) so no undefined variables are
referenced.
🧹 Nitpick comments (2)
tilelang/language/copy_op.py (1)
120-130: Docstring parameters don't match new type annotations.The docstring still documents
imgandcolastir.Buffer, but the signature now acceptsBufferLikeType. Consider updating the docstring for consistency:"""Perform im2col transformation for 2D convolution. Args: - img (tir.Buffer): Input image buffer - col (tir.Buffer): Output column buffer + img (BufferLikeType): Input image buffer (Buffer, BufferLoad, or BufferRegion) + col (BufferLikeType): Output column buffer (Buffer, BufferLoad, or BufferRegion)Also, a return type annotation could be added for consistency with
copy.tilelang/language/reduce_op.py (1)
111-245: Docstrings still mentiontir.Call; update totir.PrimExpr.
A few reducer docstrings lag behind the new return type.✍️ Example update (apply similarly to other reducers)
- Returns: - tir.Call: Handle to the reduction operation + Returns: + tir.PrimExpr: Handle to the reduction operation
|
@codex review |
|
To use Codex here, create a Codex account and connect to github. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 0
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (3)
tilelang/language/allocate.py (3)
172-198:⚠️ Potential issue | 🟡 MinorDocstring args no longer match the signature.
The docstring still describesnum_cols, but the function now takesshape. This will mislead users.📝 Suggested doc update
- Args: - num_cols (int): Number of columns to allocate in TMEM. Must be a power of 2 and >= 32 but less than or equal to 512. + Args: + shape (tuple): 2D TMEM shape; must be a 2D tensor and conform to TMEM column requirements.
33-48:⚠️ Potential issue | 🟠 MajorHandle non-string
DTyperepresentations in the bool scope workaround.The type signature allows
dtypeto be adtypeobject,ir.Type, ortypein addition to strings. The current check at line 44 only handles the string casedtype == "bool", which means non-string dtype representations (e.g.,ir.Typeordtypeobjects) would bypass the workaround and reintroduce the shared memory merge limitation.🛠️ Suggested fix
- if dtype == "bool": + dtype_str = dtype if isinstance(dtype, str) else str(dtype) + if dtype_str == "bool": # lei: This is a hack to handle bool type. # Because tilelang's merge smem pass cannot merge bool type currently. scope = "shared"
279-291:⚠️ Potential issue | 🟠 MajorBranch 2 restricts positional dtype to strings only, inconsistent with DType specification.
The
empty()function's second branch (line 287) only acceptsisinstance(shape[1], str), butDTypeis defined asdtype | ir.Type | str | type. When called asempty(shape_tuple, _dtypes.float32)(positional), the dtype class passes through to line 287's check and raisesTypeError, even though branches 1 and 3 already accept dtype classes via the keyword parameter. This breaks the reasonable call style.🛠️ Suggested fix
- elif len(shape) == 2 and isinstance(shape[0], (tuple, list)) and isinstance(shape[1], str): + elif len(shape) == 2 and isinstance(shape[0], (tuple, list)) and ( + isinstance(shape[1], str) + or isinstance(shape[1], tl_dtype) + or (isinstance(shape[1], type) and issubclass(shape[1], tl_dtype)) + ): return OutTensor(shape[0], shape[1])
🧹 Nitpick comments (1)
tilelang/language/allocate.py (1)
201-233: Tightenreplicationtyping to match allowed values.
Since only"all"and"none"are accepted, a precise type hint improves linting.♻️ Suggested type refinement
-def alloc_reducer(shape: ShapeType, dtype: DType, op: ReducerOp = "sum", replication=None) -> Buffer: +def alloc_reducer( + shape: ShapeType, + dtype: DType, + op: ReducerOp = "sum", + replication: Literal["all", "none"] | None = None, +) -> Buffer:
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 0
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (5)
tilelang/language/copy_op.py (1)
121-135:⚠️ Potential issue | 🟡 MinorAlign
c2d_im2col()docstring arg types withBufferLikeType.
imgandcolare nowBufferLikeTypein the signature, but the docstring still saystir.Buffer. Please update for consistency.Proposed docstring update
- img (tir.Buffer): Input image buffer - col (tir.Buffer): Output column buffer + img (BufferLikeType): Input image buffer + col (BufferLikeType): Output column buffertilelang/language/print_op.py (4)
74-78:⚠️ Potential issue | 🔴 CriticalCritical: Same early-exit bug—
returninside loop prints only the first element.Same issue as
print_global_buffer_with_condition: thereturnon line 78 exits the loop after the first iteration, printing only one element instead of allelems.🐛 Proposed fix
if condition: # Iterate through the buffer elements and print each one. for i in serial(elems): coords = index_to_coordinates(i, buffer.shape) - return tir.call_extern("handle", "debug_print_buffer_value", msg, buffer.name, i, buffer[coords]) + tir.call_extern("handle", "debug_print_buffer_value", msg, buffer.name, i, buffer[coords])
96-100:⚠️ Potential issue | 🔴 CriticalCritical: Same early-exit bug in fragment buffer printing.
The
returnon line 100 exits after the first iteration. After allocating shared memory and copying the buffer (lines 94-95), only the first element gets printed.🐛 Proposed fix
if condition: # Iterate through the buffer elements and print each one. for i in serial(elems): coords = index_to_coordinates(i, buffer.shape) - return tir.call_extern("handle", "debug_print_buffer_value", msg, buffer.name, i, smem[coords]) + tir.call_extern("handle", "debug_print_buffer_value", msg, buffer.name, i, smem[coords])
126-130:⚠️ Potential issue | 🔴 CriticalCritical: Same early-exit bug in local buffer printing.
The
returnon line 130 exits after the first iteration, printing only the first element instead of allelemselements.🐛 Proposed fix
if condition: # Iterate through the buffer elements and print each one. for i in serial(elems): coords = index_to_coordinates(i, buffer.shape) - return tir.call_extern("handle", "debug_print_buffer_value", msg, buffer.name, i, buffer[coords]) + tir.call_extern("handle", "debug_print_buffer_value", msg, buffer.name, i, buffer[coords])
165-180:⚠️ Potential issue | 🟡 MinorDocstring does not reflect optional return type.
The return type annotation is
tir.PrimExpr | None, but the docstring at lines 179-180 states onlytir.PrimExpr. Update the docstring for consistency.📝 Proposed fix
Returns: - tir.PrimExpr: The TIR expression for the debug print operation. + tir.PrimExpr | None: The TIR expression for the debug print operation, or None if no print is executed.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pull request overview
This pull request improves type annotations and reduces lint errors in frontend language APIs for TileLang. It introduces a new tilelang/typing.py module that defines unified public type aliases (ShapeType, DType, BufferLikeType, BarrierType) to standardize type annotations across the codebase and address issue #1768 regarding inappropriate type annotations for T.view and T.reshape.
Changes:
- Introduced a new
tilelang/typing.pymodule with standardized type aliases for Python 3.9+ compatibility - Updated function signatures across multiple language API modules to use the new type aliases
- Added missing return type annotations to numerous functions
- Introduced ReduceKind and ReducerOp type literals for improved type safety in reduction operations
Reviewed changes
Copilot reviewed 24 out of 24 changed files in this pull request and generated 12 comments.
Show a summary per file
| File | Description |
|---|---|
| tilelang/typing.py | New module defining BufferLikeType, BarrierType, DType, and ShapeType type aliases with Python 3.9 compatibility |
| tilelang/utils/language.py | Updated buffer utility functions to use BufferLikeType instead of explicit unions |
| tilelang/language/tir/ir.pyi | Added DType annotations to cast/reinterpret/infinity/max_value/min_value functions |
| tilelang/language/symbolics.py | Added DType annotations to dynamic and symbolic functions |
| tilelang/language/reduce_op.py | Added ReduceKind literal, return type annotations, and BufferLikeType usage |
| tilelang/language/random.py | Added return type annotations to RNG functions |
| tilelang/language/proxy.py | Updated Tensor/Buffer proxy functions with ShapeType and DType annotations |
| tilelang/language/print_op.py | Added return type annotations (contains critical bugs) |
| tilelang/language/pdl.py | Added return type annotations to PDL functions |
| tilelang/language/parser/entry.py | Updated BufferProxy with ShapeType and DType annotations |
| tilelang/language/math_intrinsics.py | Added PrimExpr type annotations to all math intrinsics |
| tilelang/language/loop.py | Added return type annotations and parameter type improvements |
| tilelang/language/logical.py | Updated to use BufferLikeType |
| tilelang/language/kernel.py | Updated Kernel function with int/PrimExpr parameter types |
| tilelang/language/gemm_op.py | Updated to use BufferLikeType and BarrierType throughout |
| tilelang/language/fill_op.py | Updated to use BufferLikeType (missing return type annotations) |
| tilelang/language/fastmath.py | Added PrimExpr type annotations |
| tilelang/language/experimental/gemm_sp.py | Updated to use BufferLikeType |
| tilelang/language/eager/builder.py | Fixed const function return type to allow single Var |
| tilelang/language/customize.py | Updated reshape and view with ShapeType and DType |
| tilelang/language/copy_op.py | Updated to use BufferLikeType and added return type (missing for c2d_im2col) |
| tilelang/language/builtin.py | Comprehensive type improvements with BufferLikeType and BarrierType |
| tilelang/language/ast/ir.py | Extensive type annotation updates with ShapeType, DType, and Optional types |
| tilelang/language/allocate.py | Updated with ShapeType, DType, and ReducerOp literal (missing some return types) |
Comments suppressed due to low confidence (1)
tilelang/language/copy_op.py:130
- The function signature is missing a return type annotation. Based on the function implementation which returns tir.call_intrin result, the return type should be tir.PrimExpr.
def c2d_im2col(
img: BufferLikeType,
col: BufferLikeType,
nhw_step: tir.PrimExpr,
c_step: tir.PrimExpr,
kernel: int,
stride: int,
dilation: int,
pad: int,
eviction_policy: Literal["evict_normal", "evict_first", "evict_last"] | None = None,
):
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 0
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (5)
tilelang/language/print_op.py (2)
16-25:⚠️ Potential issue | 🟡 MinorUpdate docstrings to match
Nonereturn types.These helpers now return
None, but their docstrings still say they returntir.PrimExpr, which will mislead users and type checkers.✅ Suggested docstring fix
@@ - Returns: - tir.PrimExpr: The TIR expression for the debug print operation. + Returns: + None @@ - Returns: - tir.PrimExpr: The TIR expression for the debug print operation, if the condition is True. + Returns: + None @@ - Returns: - tir.PrimExpr: The TIR expression for the debug print operation. + Returns: + None @@ - Returns: - tir.PrimExpr: The TIR expression for the debug print operation. + Returns: + NoneAlso applies to: 30-40, 58-69, 78-89
100-106:⚠️ Potential issue | 🟡 MinorAvoid asserting when
print()is called with defaults.
print_msgrequires a non-empty string, butprint()allowsobj=Nonewith the defaultmsg="", which triggers an assertion. Consider guarding the msg-only path (or relax the assert).🔧 Suggested guard in print()
@@ - elif obj is None: - print_msg(msg) + elif obj is None: + if msg: + print_msg(msg)Also applies to: 158-233
tilelang/language/fill_op.py (1)
52-60:⚠️ Potential issue | 🟠 MajorHandle let-bound
BufferLoadwithout ramp indices.
get_buffer_region_from_loadreturnsNonefor non-ramp indices; in that caseclear()currently raisesValueErroreven thoughfill()can handle aBufferLoadby using unit extents. This makesclear(var)fail for valid let-bound loads whileclear(buffer_load)succeeds.🔧 Suggested fix
elif isinstance(buffer_region, tir.BufferLoad): region = get_buffer_region_from_load(buffer_region) if region is None: - raise ValueError(f"Invalid buffer region: {buffer_region}, type: {type(buffer_region)}") - return fill(region, 0) + return fill(buffer_region, 0) + return fill(region, 0)tilelang/language/allocate.py (2)
172-195:⚠️ Potential issue | 🟡 MinorUpdate alloc_tmem docstring to match the shape-based signature.
The docstring still references
num_cols, but the API now accepts a 2Dshape. This is misleading for callers.📝 Suggested docstring update
- Args: - num_cols (int): Number of columns to allocate in TMEM. Must be a power of 2 and >= 32 but less than or equal to 512. + Args: + shape (tuple[int | PrimExpr, int | PrimExpr]): 2D shape for TMEM allocation. + The number of columns should be a power of 2 and >= 32 (<= 512).
279-291:⚠️ Potential issue | 🟠 MajorHandle all DType variants in empty() positional argument branch.
The function accepts
dtype: DTypewhich includestvm.DataType(defined asdtype | ir.Type | str | type), but line 286-287 only checks forstrwhenempty()is called with 2 positional arguments. This causes valid dtype inputs liketvm.DataTypeobjects to raiseTypeError.When called as
empty((shape), dtype_object), the dtype object becomesshape[1], and theisinstance(shape[1], str)check fails for non-string DType variants. Expand the isinstance check to accept all valid DType types, or validate the dtype argument before the shape pattern matching.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pull request overview
Copilot reviewed 24 out of 24 changed files in this pull request and generated 6 comments.
Comments suppressed due to low confidence (3)
tilelang/language/print_op.py:39
- The docstring incorrectly states that the function returns
tir.PrimExpr, but the actual return type isNone. The docstring should be updated to match the new return type annotation.
def print_var_with_condition(condition: tir.PrimExpr, var: tir.PrimExpr, msg: str = "") -> None:
"""
Conditionally prints a TIR primitive expression (PrimExpr) if a given condition is True.
Parameters:
condition (tir.PrimExpr): A TIR expression representing the condition to check.
var (tir.PrimExpr): The variable or expression to be printed.
Returns:
tir.PrimExpr: The TIR expression for the debug print operation, if the condition is True.
tilelang/language/print_op.py:24
- The docstring incorrectly states that the function returns
tir.PrimExpr, but the actual return type isNone. The docstring should be updated to match the new return type annotation.
"""
Prints the value of a TIR primitive expression (PrimExpr) for debugging purposes.
Parameters:
var (tir.PrimExpr): The variable or expression to be printed.
Returns:
tir.PrimExpr: The TIR expression for the debug print operation.
tilelang/language/print_op.py:68
- The docstring incorrectly states that the function returns
tir.PrimExpr, but the actual return type isNone. The docstring should be updated to match the new return type annotation.
def print_shared_buffer_with_condition(condition: tir.PrimExpr, buffer: tir.Buffer, elems: int, msg: str = "") -> None:
"""
Conditionally prints the values of a flattened TIR buffer if the condition is True.
Parameters:
condition (tir.PrimExpr): A TIR expression representing the condition to check.
buffer (tir.Buffer): The buffer whose values need to be printed.
elems (int): The number of elements in the buffer to print.
Returns:
tir.PrimExpr: The TIR expression for the debug print operation.
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (3)
tilelang/language/allocate.py (1)
33-48:⚠️ Potential issue | 🟠 MajorNormalize
DTypebefore comparing to"bool".With
DTypenow accepting dtype classes,dtype == "bool"won’t catch_dtypes.booland the special‑case scope hack won’t trigger.✅ Suggested normalization
- if dtype == "bool": + dtype_name = str(dtype) + if dtype_name == "bool": # lei: This is a hack to handle bool type. # Because tilelang's merge smem pass cannot merge bool type currently. scope = "shared"tilelang/language/reduce_op.py (2)
46-95:⚠️ Potential issue | 🟠 MajorReturn-before-copy drops fragment→shared results.
return tir.call_intrin(...)makes the subsequentcopy(...)unreachable in the shared→shared and fragment→shared branches, so results never reachout.✅ Preserve copy and still return the intrinsic handle
if is_shared(buffer) and is_shared(out): @@ - return tir.call_intrin( + call = tir.call_intrin( "handle", tir.op.Op.get(_REDUCE_OP_KEY), to_buffer_region(red_frag_in, access_type="r"), to_buffer_region(red_frag_out, access_type="w"), reduce_type, dim, clear, ) copy(red_frag_out, out) + return call @@ elif is_fragment(buffer) and is_shared(out): @@ - return tir.call_intrin( + call = tir.call_intrin( "handle", tir.op.Op.get(_REDUCE_OP_KEY), to_buffer_region(buffer, access_type="r"), to_buffer_region(red_frag_out, access_type="w"), reduce_type, dim, clear, ) copy(red_frag_out, out) + return call
248-285:⚠️ Potential issue | 🟡 MinorFragment cumsum path returns None despite
tir.PrimExprreturn type.
cumsum()returnscumsum_fragment(...)for fragment inputs, but the macro doesn’t return the intrinsic handle.✅ Return the intrinsic handle after emitting copies
copy(src, cumsum_smem) - tir.call_intrin( + call = tir.call_intrin( "handle", tir.op.Op.get("tl.tileop.cumsum"), to_buffer_region(cumsum_smem, access_type="r"), to_buffer_region(cumsum_smem, access_type="w"), dim, reverse, ) copy(cumsum_smem, dst) + return call
🤖 Fix all issues with AI agents
In `@tilelang/_typing.py`:
- Around line 5-33: Add "from __future__ import annotations" as the very first
line of the module so the PEP 604 union types used in BarrierType,
BufferLikeType, DType, and ShapeType are deferred and won't be evaluated at
import time on Python 3.9; ensure the future import is placed before any other
imports or code in tilelang/_typing.py.
🧹 Nitpick comments (2)
tilelang/language/gemm_op.py (1)
40-51: Return type annotation may be inaccurate.The function signature declares
-> BufferLikeType, but ifargis atir.Varwithout a let value, it's returned unchanged as aVar. The return type should beBufferLikeType | tir.Varto accurately reflect all code paths.♻️ Suggested fix
- def legalize_arguments(arg: BufferLikeType | tir.Var) -> BufferLikeType: + def legalize_arguments(arg: BufferLikeType | tir.Var) -> BufferLikeType | tir.Var:tilelang/language/proxy.py (1)
58-71: Consider unifying default dtype values.
BufferProxy.__call__uses_dtypes.float32as the default (line 26), whileBufferProxy.from_ptr,BaseTensorProxy.__call__, andBaseTensorProxy.from_ptruse the string"float32". Both should work, but consistency would improve maintainability.Also applies to: 86-136
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 0
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (4)
tilelang/language/reduce_op.py (4)
24-108:⚠️ Potential issue | 🟡 MinorAlign
reduce()docstring with ReduceKind + None return.The signature now uses
ReduceKindand returnsNone, but the docstring still documentsstrandtir.Call, which is now misleading.📘 Suggested docstring update
- reduce_type (str): Type of reduction ('max', 'min', 'sum', 'abssum') + reduce_type (ReduceKind): Type of reduction ("sum", "abssum", "max", "absmax", "min", "bitand", "bitor", "bitxor") ... - Returns: - tir.Call: Handle to the reduction operation + Returns: + None
111-245:⚠️ Potential issue | 🟡 MinorUpdate reduce_ wrapper docstrings to reflect None return.*
Each wrapper now returns
None, but the “Returns” sections still mentiontir.Call, which can confuse API users.
248-285:⚠️ Potential issue | 🟡 MinorBroaden
cumsum_fragmentdst type to BufferLikeType.
cumsum()accepts buffer regions/loads, butcumsum_fragment()still narrowsdsttotir.Buffer, which will trip type checkers and contradicts the docstring.🧩 Suggested typing fix
def cumsum_fragment( src: BufferLikeType, - dst: tir.Buffer, + dst: BufferLikeType, dim: int, reverse: bool, ) -> None:
287-365:⚠️ Potential issue | 🟡 Minor
cumsum()return type doesn’t match fragment path.When
srcis a fragment,cumsum()returns the result ofcumsum_fragment()(currentlyNone), but the signature/docstring promisetir.PrimExpr.✅ Option: make the return type accurate
-def cumsum( - src: BufferLikeType, - dst: BufferLikeType | None = None, - dim: int = 0, - reverse: bool = False, -) -> tir.PrimExpr: +def cumsum( + src: BufferLikeType, + dst: BufferLikeType | None = None, + dim: int = 0, + reverse: bool = False, +) -> tir.PrimExpr | None: @@ - Returns: - tir.Call: A handle to the emitted cumulative-sum operation. + Returns: + tir.Call | None: A handle to the emitted cumulative-sum operation (None for fragment path).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (5)
tilelang/language/warpgroup.py (1)
17-30:⚠️ Potential issue | 🟡 MinorDocstring return type is now stale.
The function now annotates
-> WarpSpecializeFramebut the docstring still saysTuple[frame.LaunchThreadFrame]. Align the docstring with the updated public API.✏️ Suggested docstring fix
- Returns - ------- - res : Tuple[frame.LaunchThreadFrame] - The result LaunchThreadFrame. + Returns + ------- + res : WarpSpecializeFrame + The WarpSpecializeFrame produced by the FFI call.tilelang/language/print_op.py (1)
45-54:⚠️ Potential issue | 🟡 MinorInconsistent return type annotation.
All other similar buffer print functions (
print_shared_buffer_with_condition,print_fragment_buffer_with_condition,print_local_buffer_with_condition) were updated to return-> None, but this function still has-> tir.PrimExpr. Since the function body has no return statement (correctly, as it's a side-effect-only macro), the type annotation should also be updated for consistency.🔧 Proposed fix
`@macro` -def print_global_buffer_with_condition(condition: tir.PrimExpr, buffer: tir.Buffer, elems: int, msg: str = "") -> tir.PrimExpr: +def print_global_buffer_with_condition(condition: tir.PrimExpr, buffer: tir.Buffer, elems: int, msg: str = "") -> None: """ Conditionally prints the values of a flattened TIR buffer if the condition is True. """tilelang/language/reduce_op.py (1)
109-127:⚠️ Potential issue | 🟡 MinorDocstrings are inconsistent with the new
Nonereturn type.The docstrings for
reduce_max(lines 122-125),reduce_min(lines 139-141),reduce_sum(lines 164-166),reduce_abssum(lines 179-181),reduce_absmax(lines 194-196),reduce_bitand(lines 209-211),reduce_bitor(lines 224-226), andreduce_bitxor(lines 239-241) all state:Returns: tir.Call: Handle to the reduction operationHowever, the function signatures now correctly indicate
-> None. Consider updating or removing theReturnssection from these docstrings to match the actual behavior.📝 Proposed fix for reduce_max (apply similar changes to other functions)
def reduce_max(buffer: tir.Buffer, out: tir.Buffer, dim: int = -1, clear: bool = True) -> None: """Perform reduce max on input buffer, store the result to output buffer Parameters ---------- buffer : Buffer The input buffer. out : Buffer The output buffer. dim : int The dimension to perform reduce on clear : bool If set to True, the output buffer will first be initialized to -inf. - Returns - ------- - handle : PrimExpr """tilelang/language/gemm_op.py (1)
40-51:⚠️ Potential issue | 🟡 MinorReturn type annotation doesn't match actual behavior.
The function can return a
tir.Varwhen the input is a non-let-bound variable (line 51 returnsargunchanged), but the return type annotation specifiesBufferLikeType. Additionally, the docstring references old types (Union[tir.Buffer, tir.Var]) instead of the new type aliases.🔧 Proposed fix
- def legalize_arguments(arg: BufferLikeType | tir.Var) -> BufferLikeType: + def legalize_arguments(arg: BufferLikeType | tir.Var) -> BufferLikeType | tir.Var: """Convert let-bound variables to their corresponding buffers. Args: - arg (Union[tir.Buffer, tir.Var]): Input argument to legalize + arg (BufferLikeType | tir.Var): Input argument to legalize Returns: - Union[tir.Buffer, tir.Var]: The legalized argument + BufferLikeType | tir.Var: The legalized argument (buffer if let-bound, otherwise unchanged) """tilelang/language/customize.py (1)
70-76:⚠️ Potential issue | 🟡 MinorFix docstring to match type annotation, and clarify the intentional duplication.
The function duplicates
loop_break()intilelang/language/builtin.py:797, but thecustomize.pyversion is intentionally the public API (explicitly imported in__init__.pybefore thebuiltinwildcard import). Consider whether this duplication can be eliminated or documented.The docstring inconsistency should be resolved: it states the return type is
tir.Call(line 74), but the annotation isPrimExpr. Update the docstring to match the annotation or be more specific (e.g., "Returns: PrimExpr (specifically a tir.Call to the intrinsic)").
🤖 Fix all issues with AI agents
In `@tilelang/language/reduce_op.py`:
- Around line 285-290: The cumsum function signature declares it returns
tir.PrimExpr but when src is a fragment it returns the result of cumsum_fragment
which is currently typed to return None, causing a type mismatch; fix by making
the return types consistent — either change cumsum's signature to "->
tir.PrimExpr | None" if fragments legitimately cause no PrimExpr, or update
cumsum_fragment to return a tir.PrimExpr (propagate and return the intrinsic
result) and keep cumsum as "-> tir.PrimExpr"; locate the functions cumsum and
cumsum_fragment and apply the chosen change so callers and type checks align.
🧹 Nitpick comments (7)
tilelang/language/copy_op.py (1)
43-44: Docstring still statestir.Callbut return type is nowtir.PrimExpr | tir.Stmt.The docstring on line 44 says "Returns: tir.Call" but doesn't reflect the
tir.BufferStorepath (line 87). Consider updating to describe both return scenarios.📝 Suggested docstring update
Returns: - tir.Call: A handle to the copy operation + Union[tir.PrimExpr, tir.Stmt]: Either a tir.BufferStore (when both src and dst + are scalar BufferLoads without region extents) or a tir.Call handle to the + copy operation.tilelang/language/print_op.py (5)
15-26: Docstring "Returns" section is now stale.The return type annotation was correctly updated to
None, but the docstring at lines 23-24 still documentsReturns: tir.PrimExpr. Since this is a side-effect-only macro, the Returns section should either be removed or updated to indicateNone.📝 Suggested docstring update
""" Prints the value of a TIR primitive expression (PrimExpr) for debugging purposes. Parameters: var (tir.PrimExpr): The variable or expression to be printed. - - Returns: - tir.PrimExpr: The TIR expression for the debug print operation. """
29-42: Docstring "Returns" section is stale.Same issue as
print_var- the docstring still documents atir.PrimExprreturn value at lines 38-39, but the signature now correctly indicates-> None.
57-74: Docstring "Returns" section is stale.The return type annotation was correctly updated to
None, but the docstring at lines 67-68 still documentsReturns: tir.PrimExpr.
77-96: Docstring "Returns" section is stale.Same pattern - the docstring at lines 87-88 still documents a
tir.PrimExprreturn value, but the signature now correctly indicates-> None.
168-172: Malformed docstring parameter description.Line 171 appears to be a continuation of the
warp_idparameter description or a separate note, but it's formatted as a standalone line without proper indentation or labeling. This makes the docstring harder to parse.📝 Suggested fix
Parameters: obj (Any): The object to print. It can be either a tir.Buffer, tir.PrimExpr, or None (for msg-only print). msg (str): An optional message to include in the print statement. warp_group_id (int): The warp group id to print. - warp_id (int): The warp id to print. - print thread will be warp_group_id * warp_group_size + warp_id + warp_id (int): The warp id to print. The print thread will be + warp_group_id * warp_group_size + warp_id * warp_size.tilelang/language/gemm_op.py (1)
199-222: Type annotations don't reflect thattir.Varis accepted.The docstrings correctly document that
Varis accepted for A, B, C, and mbar (lines 213, 222), but the type annotations only specifyBufferLikeTypeandBarrierType. Sincelegalize_argumentsinternally handlestir.Varconversion, consider updating the public signatures to explicitly acceptVarfor better IDE support and API clarity.♻️ Optional: Make type annotations match documented behavior
def gemm( - A: BufferLikeType, - B: BufferLikeType, - C: BufferLikeType, + A: BufferLikeType | tir.Var, + B: BufferLikeType | tir.Var, + C: BufferLikeType | tir.Var, transpose_A: bool = False, transpose_B: bool = False, policy: GemmWarpPolicy = GemmWarpPolicy.Square, clear_accum: bool = False, k_pack: int = 1, wg_wait: int = 0, - mbar: BarrierType | None = None, + mbar: BarrierType | tir.Var | None = None, ) -> tir.PrimExpr:The same change could be applied to
gemm_v1andgemm_v2for consistency.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pull request overview
Copilot reviewed 26 out of 26 changed files in this pull request and generated 14 comments.
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| def print_var_with_condition(condition: tir.PrimExpr, var: tir.PrimExpr, msg: str = "") -> None: | ||
| """ | ||
| Conditionally prints a TIR primitive expression (PrimExpr) if a given condition is True. | ||
Copilot
AI
Feb 3, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The docstring states "Returns: tir.PrimExpr: The TIR expression for the debug print operation, if the condition is True" but the function signature has been updated to return None. The docstring should be updated or the Returns section should be removed to reflect that this macro function doesn't return a value.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 0
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (4)
tilelang/language/reduce_op.py (4)
122-125:⚠️ Potential issue | 🟡 MinorDocstring incorrectly claims a return value.
The function signature is
-> None, but the docstring states it returns aPrimExprhandle. The docstring should be updated or the "Returns" section removed to match the actual behavior.📝 Suggested fix
dim : int The dimension to perform reduce on clear : bool If set to True, the output buffer will first be initialized to -inf. - Returns - ------- - handle : PrimExpr """
139-140:⚠️ Potential issue | 🟡 MinorDocstrings across reduce helpers claim returns that don't exist.
Multiple
reduce_*functions (reduce_min,reduce_sum,reduce_abssum,reduce_absmax,reduce_bitand,reduce_bitor,reduce_bitxor) have docstrings claiming they return atir.Callhandle, but all these functions now returnNone. Consider removing or updating the "Returns" sections in these docstrings for consistency.Also applies to: 164-165, 179-180, 194-195, 209-210, 224-225, 239-240
331-332:⚠️ Potential issue | 🟡 MinorDocstring should mention the
Nonereturn case.The return type is
tir.PrimExpr | None, but the docstring only mentions returning atir.Callhandle. It should clarify thatNoneis returned whensrcis a fragment buffer (delegating to the macro implementation).📝 Suggested fix
Returns: - tir.Call: A handle to the emitted cumulative-sum operation. + tir.PrimExpr | None: A handle to the emitted cumulative-sum operation, + or None when operating on fragment buffers (macro implementation).
437-466:⚠️ Potential issue | 🟡 MinorAdd
warp_reduce_bitxorfunction or clarify why it's intentionally omitted.The file defines
reduce_bitxor(line 231) and includes"bitxor"inReduceKind(line 21), but there's no correspondingwarp_reduce_bitxorfunction. Bothbitandandbitorhave warp equivalents (lines 437 and 453), making this an inconsistency. Either addwarp_reduce_bitxorfor symmetry or document why warp-level bitwise-xor reduction is not supported.
🧹 Nitpick comments (1)
tilelang/language/reduce_op.py (1)
44-45: Consider usingReduceKindinstead ofstrfor type consistency.The outer
reducefunction usesReduceKindfor thereduce_typeparameter, but the nestedreduce_macrousesstr. For full type consistency and better IDE support, consider usingReduceKindhere as well.🔧 Suggested fix
`@macro` - def reduce_macro(buffer: tir.Buffer, out: tir.Buffer, reduce_type: str, dim: int, clear: bool) -> None: + def reduce_macro(buffer: tir.Buffer, out: tir.Buffer, reduce_type: ReduceKind, dim: int, clear: bool) -> None:
This PR majorly fixes many type annotation & lint issues in frontend language APIs (tilelang/language/...). This issue is not an easy one and can not be solved just in one PR. We need to continuously improve our type annotations and linting. Anyway, it solves #1768.
Summary by CodeRabbit
New Features
Behavior
Documentation
Refactor