Conversation

@SiriusNEO
Collaborator

@SiriusNEO SiriusNEO commented Feb 3, 2026

This PR fixes many type annotation and lint issues in the frontend language APIs (tilelang/language/...). The problem is too large to solve in a single PR, so we need to keep improving our type annotations and linting over time. This PR resolves #1768.

Summary by CodeRabbit

  • New Features

    • Broader acceptance of buffer- and barrier-like inputs across memory, GEMM, copy, fill, reduce, and other APIs.
    • New public type aliases for dtypes/shapes and explicit reducer kinds; dtype-aware allocation and descriptors.
    • Tensor creation now requires an active builder; const may return a single value or a tuple.
  • Behavior

    • print() and related helpers are side-effect-only (no return values).
    • reshape/view now validate shape–dtype compatibility at runtime.
  • Documentation

    • Many public signatures gained explicit types and return annotations for better IDE hints.
  • Refactor

    • Centralized typing for consistent public APIs (shape, dtype, buffer/barrier aliases).

@github-actions

github-actions bot commented Feb 3, 2026

👋 Hi! Thank you for contributing to the TileLang project.

Please remember to run pre-commit run --all-files in the root directory of the project to ensure your changes are properly linted and formatted. This will help ensure your contribution passes the format check.

We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀

@coderabbitai
Contributor

coderabbitai bot commented Feb 3, 2026

📝 Walkthrough

Central typing modernization: introduces DType, ShapeType, BufferLikeType, BarrierType and updates many public signatures and return annotations across allocation, IR, builtin/memory, intrinsics, proxy/parser, loop, reduce, and utility modules to use these aliases and Optional types.
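A minimal sketch of what centralized aliases of this kind could look like. The three classes are hypothetical stand-ins for tir.Buffer, tir.BufferLoad, and tir.BufferRegion, and the ShapeType definition here is an assumption; the actual contents of tilelang/_typing.py are not shown in this thread.

```python
from typing import Sequence, Tuple, Union


class Buffer:
    """Hypothetical stand-in for tir.Buffer."""


class BufferLoad:
    """Hypothetical stand-in for tir.BufferLoad."""


class BufferRegion:
    """Hypothetical stand-in for tir.BufferRegion."""


# One alias per concept keeps public signatures short and consistent.
BufferLikeType = Union[Buffer, BufferLoad, BufferRegion]

# Tuple form of the same members, usable directly with isinstance().
BufferLikeTypeTuple: Tuple[type, ...] = (Buffer, BufferLoad, BufferRegion)

# Assumption: the real alias probably also admits symbolic extents.
ShapeType = Sequence[int]


def is_buffer_like(obj: object) -> bool:
    """Runtime counterpart of the BufferLikeType annotation."""
    return isinstance(obj, BufferLikeTypeTuple)
```

Keeping a tuple next to the Union alias lets runtime validation and static annotations draw from the same list of members.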

Changes

| Cohort / File(s) | Summary |
| --- | --- |
| Type System Infrastructure: tilelang/_typing.py | Add DType, ShapeType, BufferLikeType, BarrierType, and BufferLikeTypeTuple aliases for consistent typing. |
| Allocation & IR: tilelang/language/allocate.py, tilelang/language/ast/ir.py, tilelang/language/tir/ir.pyi | Replace generic dtype/shape hints with DType/ShapeType; many buffer/alloc APIs updated to Optional parameters and typed returns; add ReducerOp. |
| Builtin & Memory ops: tilelang/language/builtin.py, tilelang/utils/language.py | Broadened memory/barrier APIs to accept BufferLikeType/BarrierType; centralized mbarrier→buffer_load conversion; updated ldg*/stg* and fence helpers to validate BufferLikeType. |
| Copy / Fill / Print / RNG: tilelang/language/copy_op.py, tilelang/language/fill_op.py, tilelang/language/print_op.py, tilelang/language/random.py | copy/im2col/fill accept BufferLikeType; fill/clear normalize let-bound buffers; print helpers now side-effect-only (-> None); RNG functions annotated to return PrimExpr. |
| GEMM & Experimental GEMM_SP: tilelang/language/gemm_op.py, tilelang/language/experimental/gemm_sp.py | GEMM APIs accept BufferLikeType and BarrierType; legalization updated to handle let/Var and buffer-like inputs; experimental GEMM_SP updated similarly. |
| Reduce & Warp ops: tilelang/language/reduce_op.py | Add ReduceKind alias and adjust reduction API typings/return annotations; cumsum/cumsum_fragment updated to use BufferLikeType; many reductions now typed as -> None. |
| Proxy, Parser & Builder: tilelang/language/proxy.py, tilelang/language/parser/entry.py, tilelang/language/eager/builder.py | Proxies and parser buffer factories now use ShapeType/DType and return tir.Buffer; make_tensor enforces Builder context; const return widened to `Var |
| Customize / Kernel / Loop: tilelang/language/customize.py, tilelang/language/kernel.py, tilelang/language/loop.py | reshape/view accept ShapeType/DType and validate bit-size; Kernel variadic blocks accept `int |
| Math & Intrinsics: tilelang/language/fastmath.py, tilelang/language/math_intrinsics.py | Annotate math/IEEE intrinsics to accept and return tir.PrimExpr; add explicit PrimExpr imports. |
| Symbolics / Logical / PDL / Utils / Warpgroup: tilelang/language/symbolics.py, tilelang/language/logical.py, tilelang/language/pdl.py, tilelang/language/utils.py, tilelang/language/warpgroup.py | dynamic/symbolic use DType; logical helpers accept BufferLikeType; several small helpers gain explicit return types. |

Sequence Diagram(s)

(omitted)

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

Possibly related PRs

Suggested reviewers

  • kurisu6912
  • LeiWang1999

Poem

🐇 In the burrow where type-flowers spring,

DType and ShapeType start to sing.
Buffers, barriers, all set in line,
Typed and tidy — a rabbit’s fine design.
Hops of joy for every signed type!

🚥 Pre-merge checks | ✅ 2 | ❌ 1
❌ Failed checks (1 warning)

| Check name | Status | Explanation | Resolution |
| --- | --- | --- | --- |
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 76.04% which is insufficient. The required threshold is 80.00%. | Write docstrings for the functions missing them to satisfy the coverage threshold. |

✅ Passed checks (2 passed)

| Check name | Status | Explanation |
| --- | --- | --- |
| Description Check | ✅ Passed | Check skipped - CodeRabbit’s high-level summary is enabled. |
| Title check | ✅ Passed | The title accurately describes the main objective of the PR: improving type annotations and reducing lint errors in the frontend language module. |

Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (4)
tilelang/language/loop.py (1)

13-24: ⚠️ Potential issue | 🟡 Minor

Docstring does not reflect updated parameter type.

The function signature now accepts int | tir.PrimExpr for extents, but the docstring on line 23 still states extents : PrimExpr. Consider updating the docstring to match the signature.

📝 Suggested docstring fix
     Parameters
     ----------
-    extents : PrimExpr
+    extents : int | PrimExpr
         The extents of the iteration.
tilelang/language/kernel.py (1)

229-241: ⚠️ Potential issue | 🟡 Minor

Docstring could clarify accepted types for blocks parameter.

The signature now accepts int | tir.PrimExpr, but the docstring on lines 239-240 only mentions int. Consider updating to reflect both accepted types for clarity.

📝 Suggested docstring improvement
     Parameters
     ----------
-    blocks : int
-        A list of extent, can be 1-3 dimension, representing gridDim.(x|y|z)
+    blocks : int | PrimExpr
+        The block extents (1-3 dimensions), representing gridDim.(x|y|z)
tilelang/language/gemm_op.py (1)

40-56: ⚠️ Potential issue | 🟠 Major

Don’t drop region/offset info when unwrapping let-bound buffers.
Lines 49-50 return .buffer, which loses BufferRegion/BufferLoad slicing; to_buffer_region then computes shapes/offsets on the full buffer, which can mis-handle sliced A/B/C (and mbar arrays). Return the let value directly and let to_buffer_region normalize it.

🔧 Suggested fix
-        if isinstance(arg, tir.Var) and T.has_let_value(arg):
-            return T.get_let_value(arg).buffer
+        if isinstance(arg, tir.Var) and T.has_let_value(arg):
+            return T.get_let_value(arg)
         return arg
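To make the concern concrete, here is a small self-contained sketch. Buffer, BufferRegion, Var, and the LET_VALUES table are hypothetical stand-ins for the TIR classes and for T.has_let_value / T.get_let_value, not the real API; the point is only that unwrapping to .buffer discards the slice.

```python
from dataclasses import dataclass
from typing import Dict, Tuple


@dataclass(frozen=True)
class Buffer:  # hypothetical stand-in for tir.Buffer
    name: str
    shape: Tuple[int, ...]


@dataclass(frozen=True)
class BufferRegion:  # hypothetical stand-in for tir.BufferRegion
    buffer: Buffer
    offsets: Tuple[int, ...]
    extents: Tuple[int, ...]


@dataclass(frozen=True)
class Var:  # hypothetical stand-in for tir.Var
    name: str


# Hypothetical let-binding table standing in for the builder's state.
LET_VALUES: Dict[Var, BufferRegion] = {}


def legalize_lossy(arg):
    # Mirrors the flagged pattern: .buffer forgets offsets and extents.
    if isinstance(arg, Var) and arg in LET_VALUES:
        return LET_VALUES[arg].buffer
    return arg


def legalize_preserving(arg):
    # Mirrors the suggested fix: hand back the let value itself.
    if isinstance(arg, Var) and arg in LET_VALUES:
        return LET_VALUES[arg]
    return arg


def extents_of(arg) -> Tuple[int, ...]:
    """Extents a downstream region normalizer would see."""
    return arg.extents if isinstance(arg, BufferRegion) else arg.shape


A = Buffer("A", (128, 128))
a_tile = Var("a_tile")
LET_VALUES[a_tile] = BufferRegion(A, offsets=(0, 64), extents=(128, 64))
```

With the lossy variant a consumer sees the full (128, 128) buffer; with the preserving variant it sees the (128, 64) tile that was actually bound.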
tilelang/language/ast/ir.py (1)

983-1013: ⚠️ Potential issue | 🟡 Minor

Handle None condition in allocate() to prevent runtime errors.

The function signature allows condition=None (line 987), but the code only converts bool to IntImm and passes condition directly to _ffi_api.Allocate(). When allocate() is called without the condition parameter (as shown in all test usages), it defaults to None and will cause a runtime error. Add a guard to default None to IntImm("bool", True):

🛠️ Suggested fix
     if isinstance(condition, bool):
         condition = IntImm("bool", condition)
+    elif condition is None:
+        condition = IntImm("bool", True)
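The guard can be sketched in isolation as follows. IntImm here is a hypothetical dataclass stand-in for tir.IntImm, and normalize_condition is an illustrative helper name, not a function from the PR.

```python
from dataclasses import dataclass
from typing import Optional, Union


@dataclass(frozen=True)
class IntImm:  # hypothetical stand-in for tir.IntImm
    dtype: str
    value: int


def normalize_condition(condition: Optional[Union[bool, IntImm]]) -> IntImm:
    """Map None and Python bools to an IntImm so the callee never sees None."""
    if condition is None:
        # An omitted condition is treated as "always true".
        return IntImm("bool", 1)
    if isinstance(condition, bool):
        return IntImm("bool", int(condition))
    return condition
```

Checking for None before the bool branch keeps the default path explicit and leaves already-constructed expressions untouched.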
🤖 Fix all issues with AI agents
In `@tilelang/language/fastmath.py`:
- Line 127: Update the docstring for the __exp function in
tilelang/language/fastmath.py to state that it computes the natural exponential
(e**x or exp(x)) rather than 2**x; locate the docstring above the __exp
definition and replace "Calculate 2**x with fast math" with text like "Calculate
e**x (exp(x)) with fast math" so it matches the CUDA backend mapping and naming
conventions.
🧹 Nitpick comments (5)
tilelang/language/customize.py (1)

57-67: Consider enhancing the assertion message in view for consistency.

The view function's assertion message (line 66) is less informative than reshape's. For debugging ease, consider including the actual values.

📝 Suggested improvement
-    assert prim_expr_equal(bits_product(shape, dtype), bits_product(src.shape, src.dtype)), "T.reshape/view shape check failed."
+    assert prim_expr_equal(bits_product(shape, dtype), bits_product(src.shape, src.dtype)), (
+        f"T.reshape/view shape check failed. src {src} src.shape: {src.shape}, src.dtype: {src.dtype}, target shape: {shape}, target dtype: {dtype}"
+    )
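For readers unfamiliar with the check, the sketch below shows the bit-count invariant using plain integers. The DTYPE_BITS table and the check_view name are assumptions made for illustration; the real bits_product and prim_expr_equal operate on TIR expressions.

```python
from math import prod
from typing import Sequence

# Assumed bit widths for a few dtype names; illustration only.
DTYPE_BITS = {"int8": 8, "float16": 16, "float32": 32, "float64": 64}


def bits_product(shape: Sequence[int], dtype: str) -> int:
    """Total number of bits occupied by a tensor of this shape and dtype."""
    return prod(shape) * DTYPE_BITS[dtype]


def check_view(src_shape, src_dtype, shape, dtype) -> None:
    """Raise with full context when a reshape/view would change the bit count."""
    if bits_product(src_shape, src_dtype) != bits_product(shape, dtype):
        raise ValueError(
            "T.reshape/view shape check failed. "
            f"src.shape: {tuple(src_shape)}, src.dtype: {src_dtype}, "
            f"target shape: {tuple(shape)}, target dtype: {dtype}"
        )
```

A (4, 8) float32 tensor and a (4, 16) float16 tensor both occupy 1024 bits, so that view passes, while keeping the shape and halving the element width does not.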
tilelang/language/copy_op.py (1)

28-29: Consider updating docstrings to reference BufferLikeType.

The docstrings still reference the explicit union types (Union[tir.Buffer, tir.BufferLoad, tir.BufferRegion] on lines 28-29 and tir.Buffer on lines 134-135) while the signatures now use BufferLikeType. For consistency with the new typing system, consider updating these docstrings.

Also applies to: 134-135

tilelang/language/symbolics.py (1)

19-19: Minor: Docstring mentions str but type hint is DType.

Line 19 states dtype (str): Data type string but the parameter type is now DType. Consider updating to dtype (DType): Data type for the variable.

tilelang/language/allocate.py (1)

262-276: Inconsistent typing: These functions still use dtype: str instead of DType.

For consistency with the updated functions above, consider updating the following functions to use DType instead of str:

  • alloc_wgmma_desc (line 262)
  • alloc_tcgen05_smem_desc (line 266)
  • alloc_tcgen05_instruction_desc (line 270)
  • alloc_tcgen05_instr_desc (line 275)
  • empty overload and implementation (lines 280, 283)
♻️ Proposed fix for consistency
-def alloc_wgmma_desc(dtype: str = _dtypes.uint64):
+def alloc_wgmma_desc(dtype: DType = _dtypes.uint64):
     return alloc_descriptor("wgmma", dtype=dtype)


-def alloc_tcgen05_smem_desc(dtype: str = _dtypes.uint64):
+def alloc_tcgen05_smem_desc(dtype: DType = _dtypes.uint64):
     return alloc_descriptor("tcgen05_smem", dtype=dtype)


-def alloc_tcgen05_instruction_desc(dtype: str = _dtypes.uint32):
+def alloc_tcgen05_instruction_desc(dtype: DType = _dtypes.uint32):
     return alloc_descriptor("tcgen05_instr", dtype=dtype)


 # Alias: short name consistent with imports
-def alloc_tcgen05_instr_desc(dtype: str = _dtypes.uint32):
+def alloc_tcgen05_instr_desc(dtype: DType = _dtypes.uint32):
     return alloc_tcgen05_instruction_desc(dtype)


 @overload
-def empty(shape, dtype: str = _dtypes.float32) -> Tensor: ...
+def empty(shape, dtype: DType = _dtypes.float32) -> Tensor: ...


-def empty(*shape, dtype: str = _dtypes.float32) -> Tensor:
+def empty(*shape, dtype: DType = _dtypes.float32) -> Tensor:

Also applies to: 280-283

tilelang/language/math_intrinsics.py (1)

15-148: Consolidate duplicated fast math functions to a single source.

The 8 fast math functions (__log, __log2, __log10, __tan, __cos, __sin, __exp10, __exp) are duplicated between tilelang/language/fastmath.py and tilelang/language/math_intrinsics.py with identical implementations. Consider having math_intrinsics.py import and re-export these functions from fastmath.py rather than duplicating them, or consolidate to a single module to maintain one source of truth.

Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 0

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (2)
tilelang/language/logical.py (2)

68-72: ⚠️ Potential issue | 🟡 Minor

Fix all_of error message to reference the correct API.

The message says “T.any” inside all_of, which is confusing for users.

✏️ Proposed fix
-                        "Only support the last dimension to be for T.any currently, please contact us if you need this feature"
+                        "Only support the last dimension to be for T.all currently, please contact us if you need this feature"

12-43: ⚠️ Potential issue | 🔴 Critical

Type annotation does not match runtime handling; BufferLoad is accepted by type hint but rejected at runtime.

BufferLikeType includes BufferLoad, but both any_of and all_of only handle Buffer and BufferRegion, raising ValueError for BufferLoad. This creates a type contract violation that static type checkers will miss.

Additionally, fix the all_of docstring (line 53) which incorrectly says "performs the any operation" and the error message (line 71) which says "T.any currently"—both should reference "all" instead of "any".
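One way such a mismatch could be caught mechanically is to compare the members of the Union alias against the types the function actually branches on. The classes and the HANDLED tuple below are hypothetical stand-ins; only typing.get_args is a real library call.

```python
from typing import Union, get_args


class Buffer:
    """Hypothetical stand-in for tir.Buffer."""


class BufferLoad:
    """Hypothetical stand-in for tir.BufferLoad."""


class BufferRegion:
    """Hypothetical stand-in for tir.BufferRegion."""


BufferLikeType = Union[Buffer, BufferLoad, BufferRegion]

# Types that a hypothetical any_of/all_of body has isinstance() branches for.
HANDLED = (Buffer, BufferRegion)


def unhandled_members(alias, handled) -> set:
    """Members of a Union alias that no isinstance() branch covers."""
    return {t for t in get_args(alias) if not issubclass(t, handled)}
```

A unit test asserting that this set is empty would fail as soon as the annotation promises more than the implementation delivers.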

🧹 Nitpick comments (2)
tilelang/language/experimental/gemm_sp.py (2)

38-41: Minor style inconsistency: docstrings use Union[] while type hints use | syntax.

The function signature uses the modern BufferLikeType | tir.Var syntax (PEP 604), but the docstrings use the older Union[BufferLikeType, tir.Var] format. Consider aligning docstrings with the code style for consistency.

✏️ Suggested docstring update
     Args:
-        A_sparse (Union[BufferLikeType, tir.Var]): First input matrix dense values
-        E (Union[BufferLikeType, tir.Var]): First input matrix sparse metadata
-        B (Union[BufferLikeType, tir.Var]): Second input matrix
-        C (Union[BufferLikeType, tir.Var]): Output matrix for results
+        A_sparse (BufferLikeType | tir.Var): First input matrix dense values
+        E (BufferLikeType | tir.Var): First input matrix sparse metadata
+        B (BufferLikeType | tir.Var): Second input matrix
+        C (BufferLikeType | tir.Var): Output matrix for results

56-67: Consider extracting duplicate legalize_arguments helper to module level.

This inner function is duplicated identically in gemm_sp_v2 (lines 139-150). Extracting it to a module-level helper would reduce duplication and improve maintainability.

♻️ Suggested refactor

Add a module-level helper before the function definitions:

def _legalize_arguments(arg: BufferLikeType | tir.Var) -> BufferLikeType | tir.Var:
    """Convert let-bound variables to their corresponding buffers.

    Args:
        arg (BufferLikeType | tir.Var): Input argument to legalize

    Returns:
        BufferLikeType | tir.Var: The legalized argument
    """
    if isinstance(arg, tir.Var) and T.has_let_value(arg):
        return T.get_let_value(arg).buffer
    return arg

Then use _legalize_arguments in both gemm_sp and gemm_sp_v2 instead of defining the inner function twice.

Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 2

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (2)
tilelang/language/allocate.py (2)

33-48: ⚠️ Potential issue | 🟠 Major

Handle non-string bool dtypes in the shared-memory workaround.
With the switch to DType, callers can pass dtype objects (e.g., _dtypes.bool). The current dtype == "bool" check won’t catch those, leaving scope at "shared.dyn" and reintroducing the bool-merge issue. Consider normalizing to a string or explicitly checking dtype objects.

🔧 Proposed fix
-    if dtype == "bool":
+    if str(dtype) == "bool":
         # lei: This is a hack to handle bool type.
         # Because tilelang's merge smem pass cannot merge bool type currently.
         scope = "shared"
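The following sketch illustrates why a raw string comparison can miss dtype objects. DTypeObj is a hypothetical class that stringifies to its name but does not compare equal to strings; whether TileLang's actual dtype objects behave this way is an assumption, not something confirmed in this thread.

```python
class DTypeObj:
    """Hypothetical dtype object whose str() yields the canonical name."""

    def __init__(self, name: str) -> None:
        self.name = name

    def __str__(self) -> str:
        return self.name


def pick_scope_raw(dtype) -> str:
    # Raw comparison: only plain strings can match "bool".
    return "shared" if dtype == "bool" else "shared.dyn"


def pick_scope(dtype) -> str:
    # Normalized comparison: str and object inputs behave the same.
    return "shared" if str(dtype) == "bool" else "shared.dyn"
```

Under these assumptions, pick_scope_raw sends a bool dtype object to "shared.dyn", while pick_scope keeps it in "shared".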

279-291: ⚠️ Potential issue | 🟠 Major

Allow positional dtype objects in empty.
The branch for empty(shape_tuple, dtype) only accepts str. With DType, callers can reasonably pass dtype objects positionally (e.g., _dtypes.float32), which currently raises TypeError. Extend the check to accept dtype objects.

🔧 Proposed fix
-    elif len(shape) == 2 and isinstance(shape[0], (tuple, list)) and isinstance(shape[1], str):
+    elif len(shape) == 2 and isinstance(shape[0], (tuple, list)) and isinstance(shape[1], (str, tl_dtype)):
         return OutTensor(shape[0], shape[1])
🤖 Fix all issues with AI agents
In `@tilelang/language/allocate.py`:
- Around line 79-87: The implementation of alloc_var currently narrows the init
parameter to PrimExpr | None, which conflicts with the overloads that allow int
| float as well; update the implementation signature of alloc_var to accept
init: PrimExpr | int | float | None (matching the overloads) so type checking
passes, and ensure any internal handling of init (inside alloc_var) still
correctly accepts and normalizes int/float to PrimExpr as expected.
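As a rough illustration of the overload rule being described, the implementation signature below accepts the union of everything the overloads accept. PrimExpr is a hypothetical stand-in class and alloc_var_init is an invented name; the real alloc_var takes additional parameters not modeled here.

```python
from typing import Optional, Union, overload


class PrimExpr:  # hypothetical stand-in for tir.PrimExpr
    def __init__(self, value: Union[int, float]) -> None:
        self.value = value


@overload
def alloc_var_init(init: None = None) -> None: ...


@overload
def alloc_var_init(init: Union[PrimExpr, int, float]) -> PrimExpr: ...


def alloc_var_init(
    init: Optional[Union[PrimExpr, int, float]] = None,
) -> Optional[PrimExpr]:
    """The implementation signature covers every overload's parameter types."""
    if init is None:
        return None
    # Normalize Python scalars so later code only deals with PrimExpr.
    return init if isinstance(init, PrimExpr) else PrimExpr(init)
```

If the implementation were narrowed to PrimExpr | None, a type checker would typically report the second overload as inconsistent with it.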

In `@tilelang/language/builtin.py`:
- Around line 431-435: Normalize the dtype parameter to a Python string as soon
as it's validated in the function signature handling (where buffer_or_ptr,
offset, num_regs, dtype are processed) so downstream comparisons and uses
succeed; specifically, if dtype is not None convert it via dtype = str(dtype)
(or use dtype.name/str-conversion consistent with DataType expectations)
immediately after validation so comparisons like dtype == "int32" work,
DataType(dtype).bits receives a string, and the value passed into
tir.call_intrin() becomes the expected StringImm. Ensure this normalization
happens before any `if dtype == ...`, `DataType(dtype)`, or
`tir.call_intrin(...)` calls.

@SiriusNEO SiriusNEO changed the title [Refactor] Improve type annotations and reduce some lint errors [Refactor] Improve type annotations and reduce some lint errors in frontend Feb 3, 2026
Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 2

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (4)
tilelang/language/print_op.py (3)

74-78: ⚠️ Potential issue | 🟠 Major

Return inside loop exits after first iteration.

Placing return inside the for loop at line 78 causes the function to exit after processing only the first element, contradicting the documented behavior of iterating through all buffer elements.

🐛 Proposed fix - remove return from inside loop
     if condition:
         # Iterate through the buffer elements and print each one.
         for i in serial(elems):
             coords = index_to_coordinates(i, buffer.shape)
-            return tir.call_extern("handle", "debug_print_buffer_value", msg, buffer.name, i, buffer[coords])
+            tir.call_extern("handle", "debug_print_buffer_value", msg, buffer.name, i, buffer[coords])
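The control-flow issue can be reproduced in plain Python. This sketch uses a list as a stand-in for the debug print call and says nothing about how TileLang's macro tracing treats the statement; it only shows the ordinary Python semantics of return inside a loop.

```python
from typing import List


def print_first_only(values: List[int], log: List[str]) -> None:
    for i, v in enumerate(values):
        # `return` ends the function during the first iteration.
        return log.append(f"[{i}] {v}")


def print_all(values: List[int], log: List[str]) -> None:
    for i, v in enumerate(values):
        # Without `return`, every element is visited.
        log.append(f"[{i}] {v}")
```

Calling print_first_only on three values records one entry, while print_all records three.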

96-100: ⚠️ Potential issue | 🟠 Major

Return inside loop exits after first iteration.

Same issue as print_shared_buffer_with_condition: the return at line 100 causes the function to exit after printing only the first element.

🐛 Proposed fix
     if condition:
         # Iterate through the buffer elements and print each one.
         for i in serial(elems):
             coords = index_to_coordinates(i, buffer.shape)
-            return tir.call_extern("handle", "debug_print_buffer_value", msg, buffer.name, i, smem[coords])
+            tir.call_extern("handle", "debug_print_buffer_value", msg, buffer.name, i, smem[coords])

126-130: ⚠️ Potential issue | 🟠 Major

Return inside loop exits after first iteration.

Same issue as other buffer print functions: the return at line 130 causes early exit after the first element.

🐛 Proposed fix
     if condition:
         # Iterate through the buffer elements and print each one.
         for i in serial(elems):
             coords = index_to_coordinates(i, buffer.shape)
-            return tir.call_extern("handle", "debug_print_buffer_value", msg, buffer.name, i, buffer[coords])
+            tir.call_extern("handle", "debug_print_buffer_value", msg, buffer.name, i, buffer[coords])
tilelang/language/reduce_op.py (1)

24-104: ⚠️ Potential issue | 🔴 Critical

Fix early returns that skip writing to out in shared→shared / fragment→shared reductions.
return precedes copy(red_frag_out, out), making the copy unreachable and leaving out stale for those paths. Please return after the copy.

🐛 Suggested fix
@@
-            return tir.call_intrin(
+            call = tir.call_intrin(
                 "handle",
                 tir.op.Op.get(_REDUCE_OP_KEY),
                 to_buffer_region(red_frag_in, access_type="r"),
                 to_buffer_region(red_frag_out, access_type="w"),
                 reduce_type,
                 dim,
                 clear,
             )
             copy(red_frag_out, out)
+            return call
@@
-            return tir.call_intrin(
+            call = tir.call_intrin(
                 "handle",
                 tir.op.Op.get(_REDUCE_OP_KEY),
                 to_buffer_region(buffer, access_type="r"),
                 to_buffer_region(red_frag_out, access_type="w"),
                 reduce_type,
                 dim,
                 clear,
             )
             copy(red_frag_out, out)
+            return call
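A stripped-down analogue of the unreachable copy, using a dict as a hypothetical stand-in for the out buffer and sum() in place of the reduce intrinsic:

```python
from typing import Dict, List


def reduce_then_copy_buggy(src: List[int], out: Dict[str, int]) -> int:
    total = sum(src)
    return total
    out["value"] = total  # unreachable: `out` is never written


def reduce_then_copy_fixed(src: List[int], out: Dict[str, int]) -> int:
    total = sum(src)
    out["value"] = total  # write the result before returning
    return total
```

Both functions return the same value, which is why the bug is easy to miss: only the side effect on out differs.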
🤖 Fix all issues with AI agents
In `@tilelang/language/copy_op.py`:
- Line 24: The return type annotation that currently reads "-> tir.PrimExpr" is
incorrect because the function sometimes returns a tir.BufferStore (a tir.Stmt)
and other times returns tir.call_intrin(...) (a tir.PrimExpr); change the
annotation to a union such as "-> Union[tir.PrimExpr, tir.Stmt]" (add the
necessary typing import), and update the docstring (formerly "Returns:
tir.Call") to describe both possible return types and the conditions under which
each (tir.BufferStore vs tir.PrimExpr from tir.call_intrin) is returned.

In `@tilelang/language/print_op.py`:
- Around line 47-58: The else branch in print_global_buffer_with_condition
references i and coords which are only defined inside the for loop and must be
removed; also the current return inside the for loop causes early exit so change
the implementation so that when condition is True you iterate for i in
serial(elems), compute coords = index_to_coordinates(i, buffer.shape) and call
tir.call_extern("handle", "debug_print_buffer_value", msg, buffer.name, i,
buffer[coords]) for each element (no return inside the loop), and when condition
is False simply return None (or do nothing) so no undefined variables are
referenced.
🧹 Nitpick comments (2)
tilelang/language/copy_op.py (1)

120-130: Docstring parameters don't match new type annotations.

The docstring still documents img and col as tir.Buffer, but the signature now accepts BufferLikeType. Consider updating the docstring for consistency:

     """Perform im2col transformation for 2D convolution.

     Args:
-        img (tir.Buffer): Input image buffer
-        col (tir.Buffer): Output column buffer
+        img (BufferLikeType): Input image buffer (Buffer, BufferLoad, or BufferRegion)
+        col (BufferLikeType): Output column buffer (Buffer, BufferLoad, or BufferRegion)

Also, a return type annotation could be added for consistency with copy.

tilelang/language/reduce_op.py (1)

111-245: Docstrings still mention tir.Call; update to tir.PrimExpr.
A few reducer docstrings lag behind the new return type.

✍️ Example update (apply similarly to other reducers)
-    Returns:
-        tir.Call: Handle to the reduction operation
+    Returns:
+        tir.PrimExpr: Handle to the reduction operation

@SiriusNEO
Collaborator Author

@codex review

@chatgpt-codex-connector

To use Codex here, create a Codex account and connect it to GitHub.

@SiriusNEO SiriusNEO requested a review from Copilot February 3, 2026 07:24
Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 0

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (3)
tilelang/language/allocate.py (3)

172-198: ⚠️ Potential issue | 🟡 Minor

Docstring args no longer match the signature.
The docstring still describes num_cols, but the function now takes shape. This will mislead users.

📝 Suggested doc update
-    Args:
-        num_cols (int): Number of columns to allocate in TMEM. Must be a power of 2 and >= 32 but less than or equal to 512.
+    Args:
+        shape (tuple): 2D TMEM shape; must be a 2D tensor and conform to TMEM column requirements.

33-48: ⚠️ Potential issue | 🟠 Major

Handle non-string DType representations in the bool scope workaround.

The type signature allows dtype to be a dtype object, ir.Type, or type in addition to strings. The current check at line 44 only handles the string case dtype == "bool", which means non-string dtype representations (e.g., ir.Type or dtype objects) would bypass the workaround and reintroduce the shared memory merge limitation.

🛠️ Suggested fix
-    if dtype == "bool":
+    dtype_str = dtype if isinstance(dtype, str) else str(dtype)
+    if dtype_str == "bool":
         # lei: This is a hack to handle bool type.
         # Because tilelang's merge smem pass cannot merge bool type currently.
         scope = "shared"

279-291: ⚠️ Potential issue | 🟠 Major

Branch 2 restricts positional dtype to strings only, inconsistent with DType specification.

The empty() function's second branch (line 287) only accepts isinstance(shape[1], str), but DType is defined as dtype | ir.Type | str | type. When called as empty(shape_tuple, _dtypes.float32) (positional), the dtype class passes through to line 287's check and raises TypeError, even though branches 1 and 3 already accept dtype classes via the keyword parameter. This breaks the reasonable call style.

🛠️ Suggested fix
-    elif len(shape) == 2 and isinstance(shape[0], (tuple, list)) and isinstance(shape[1], str):
+    elif len(shape) == 2 and isinstance(shape[0], (tuple, list)) and (
+        isinstance(shape[1], str)
+        or isinstance(shape[1], tl_dtype)
+        or (isinstance(shape[1], type) and issubclass(shape[1], tl_dtype))
+    ):
         return OutTensor(shape[0], shape[1])
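A simplified sketch of the three call styles, assuming the branch structure described above. tl_dtype is a hypothetical stand-in for the dtype class named in the fix, and parse_empty_args is an invented helper that returns the parsed (shape, dtype) pair instead of building a tensor.

```python
class tl_dtype:  # hypothetical stand-in for the dtype class named in the fix
    def __init__(self, name: str) -> None:
        self.name = name


def parse_empty_args(*shape, dtype="float32"):
    """Return (shape, dtype) for the three call styles of an empty()-like API."""
    if len(shape) == 1 and isinstance(shape[0], (tuple, list)):
        # empty((2, 3), dtype=...)
        return tuple(shape[0]), dtype
    if (
        len(shape) == 2
        and isinstance(shape[0], (tuple, list))
        and isinstance(shape[1], (str, tl_dtype))
    ):
        # empty((2, 3), dtype) with dtype passed positionally
        return tuple(shape[0]), shape[1]
    if all(isinstance(s, int) for s in shape):
        # empty(2, 3, dtype=...)
        return tuple(shape), dtype
    raise TypeError(f"Invalid shape {shape}")
```

With the widened isinstance check, a positional dtype object follows the same path as a positional dtype string.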
🧹 Nitpick comments (1)
tilelang/language/allocate.py (1)

201-233: Tighten replication typing to match allowed values.
Since only "all" and "none" are accepted, a precise type hint improves linting.

♻️ Suggested type refinement
-def alloc_reducer(shape: ShapeType, dtype: DType, op: ReducerOp = "sum", replication=None) -> Buffer:
+def alloc_reducer(
+    shape: ShapeType,
+    dtype: DType,
+    op: ReducerOp = "sum",
+    replication: Literal["all", "none"] | None = None,
+) -> Buffer:
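A Literal annotation can also drive the runtime check, so the two cannot drift apart. The Replication alias and check_replication helper are illustrative names, not part of the PR.

```python
from typing import Literal, Optional, get_args

Replication = Literal["all", "none"]


def check_replication(replication: Optional[str]) -> Optional[str]:
    """Runtime guard derived from the Literal annotation."""
    allowed = get_args(Replication)
    if replication is not None and replication not in allowed:
        raise ValueError(f"replication must be one of {allowed} or None")
    return replication
```

Static checkers reject other string literals at call sites, and the guard covers values that only become known at runtime.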

Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 0

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (5)
tilelang/language/copy_op.py (1)

121-135: ⚠️ Potential issue | 🟡 Minor

Align c2d_im2col() docstring arg types with BufferLikeType.

img and col are now BufferLikeType in the signature, but the docstring still says tir.Buffer. Please update for consistency.

Proposed docstring update
-        img (tir.Buffer): Input image buffer
-        col (tir.Buffer): Output column buffer
+        img (BufferLikeType): Input image buffer
+        col (BufferLikeType): Output column buffer
tilelang/language/print_op.py (4)

74-78: ⚠️ Potential issue | 🔴 Critical

Critical: Same early-exit bug—return inside loop prints only the first element.

Same issue as print_global_buffer_with_condition: the return on line 78 exits the loop after the first iteration, printing only one element instead of all elems.

🐛 Proposed fix
     if condition:
         # Iterate through the buffer elements and print each one.
         for i in serial(elems):
             coords = index_to_coordinates(i, buffer.shape)
-            return tir.call_extern("handle", "debug_print_buffer_value", msg, buffer.name, i, buffer[coords])
+            tir.call_extern("handle", "debug_print_buffer_value", msg, buffer.name, i, buffer[coords])

96-100: ⚠️ Potential issue | 🔴 Critical

Critical: Same early-exit bug in fragment buffer printing.

The return on line 100 exits after the first iteration. After allocating shared memory and copying the buffer (lines 94-95), only the first element gets printed.

🐛 Proposed fix
     if condition:
         # Iterate through the buffer elements and print each one.
         for i in serial(elems):
             coords = index_to_coordinates(i, buffer.shape)
-            return tir.call_extern("handle", "debug_print_buffer_value", msg, buffer.name, i, smem[coords])
+            tir.call_extern("handle", "debug_print_buffer_value", msg, buffer.name, i, smem[coords])

126-130: ⚠️ Potential issue | 🔴 Critical

Critical: Same early-exit bug in local buffer printing.

The return on line 130 exits after the first iteration, printing only the first element instead of all elems elements.

🐛 Proposed fix
     if condition:
         # Iterate through the buffer elements and print each one.
         for i in serial(elems):
             coords = index_to_coordinates(i, buffer.shape)
-            return tir.call_extern("handle", "debug_print_buffer_value", msg, buffer.name, i, buffer[coords])
+            tir.call_extern("handle", "debug_print_buffer_value", msg, buffer.name, i, buffer[coords])

165-180: ⚠️ Potential issue | 🟡 Minor

Docstring does not reflect optional return type.

The return type annotation is tir.PrimExpr | None, but the docstring at lines 179-180 states only tir.PrimExpr. Update the docstring for consistency.

📝 Proposed fix
     Returns:
-        tir.PrimExpr: The TIR expression for the debug print operation.
+        tir.PrimExpr | None: The TIR expression for the debug print operation, or None if no print is executed.

Contributor

Copilot AI left a comment

Pull request overview

This pull request improves type annotations and reduces lint errors in frontend language APIs for TileLang. It introduces a new tilelang/typing.py module that defines unified public type aliases (ShapeType, DType, BufferLikeType, BarrierType) to standardize type annotations across the codebase and address issue #1768 regarding inappropriate type annotations for T.view and T.reshape.

Changes:

  • Introduced a new tilelang/typing.py module with standardized type aliases for Python 3.9+ compatibility
  • Updated function signatures across multiple language API modules to use the new type aliases
  • Added missing return type annotations to numerous functions
  • Introduced ReduceKind and ReducerOp type literals for improved type safety in reduction operations

Reviewed changes

Copilot reviewed 24 out of 24 changed files in this pull request and generated 12 comments.

Show a summary per file
File Description
tilelang/typing.py New module defining BufferLikeType, BarrierType, DType, and ShapeType type aliases with Python 3.9 compatibility
tilelang/utils/language.py Updated buffer utility functions to use BufferLikeType instead of explicit unions
tilelang/language/tir/ir.pyi Added DType annotations to cast/reinterpret/infinity/max_value/min_value functions
tilelang/language/symbolics.py Added DType annotations to dynamic and symbolic functions
tilelang/language/reduce_op.py Added ReduceKind literal, return type annotations, and BufferLikeType usage
tilelang/language/random.py Added return type annotations to RNG functions
tilelang/language/proxy.py Updated Tensor/Buffer proxy functions with ShapeType and DType annotations
tilelang/language/print_op.py Added return type annotations (contains critical bugs)
tilelang/language/pdl.py Added return type annotations to PDL functions
tilelang/language/parser/entry.py Updated BufferProxy with ShapeType and DType annotations
tilelang/language/math_intrinsics.py Added PrimExpr type annotations to all math intrinsics
tilelang/language/loop.py Added return type annotations and parameter type improvements
tilelang/language/logical.py Updated to use BufferLikeType
tilelang/language/kernel.py Updated Kernel function with int/PrimExpr parameter types
tilelang/language/gemm_op.py Updated to use BufferLikeType and BarrierType throughout
tilelang/language/fill_op.py Updated to use BufferLikeType (missing return type annotations)
tilelang/language/fastmath.py Added PrimExpr type annotations
tilelang/language/experimental/gemm_sp.py Updated to use BufferLikeType
tilelang/language/eager/builder.py Fixed const function return type to allow single Var
tilelang/language/customize.py Updated reshape and view with ShapeType and DType
tilelang/language/copy_op.py Updated to use BufferLikeType and added return type (missing for c2d_im2col)
tilelang/language/builtin.py Comprehensive type improvements with BufferLikeType and BarrierType
tilelang/language/ast/ir.py Extensive type annotation updates with ShapeType, DType, and Optional types
tilelang/language/allocate.py Updated with ShapeType, DType, and ReducerOp literal (missing some return types)
Comments suppressed due to low confidence (1)

tilelang/language/copy_op.py:130

  • The function signature is missing a return type annotation. Based on the function implementation which returns tir.call_intrin result, the return type should be tir.PrimExpr.
def c2d_im2col(
    img: BufferLikeType,
    col: BufferLikeType,
    nhw_step: tir.PrimExpr,
    c_step: tir.PrimExpr,
    kernel: int,
    stride: int,
    dilation: int,
    pad: int,
    eviction_policy: Literal["evict_normal", "evict_first", "evict_last"] | None = None,
):

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 0

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (5)
tilelang/language/print_op.py (2)

16-25: ⚠️ Potential issue | 🟡 Minor

Update docstrings to match None return types.

These helpers now return None, but their docstrings still say they return tir.PrimExpr, which will mislead users and type checkers.

✅ Suggested docstring fix
@@
-    Returns:
-        tir.PrimExpr: The TIR expression for the debug print operation.
+    Returns:
+        None
@@
-    Returns:
-        tir.PrimExpr: The TIR expression for the debug print operation, if the condition is True.
+    Returns:
+        None
@@
-    Returns:
-        tir.PrimExpr: The TIR expression for the debug print operation.
+    Returns:
+        None
@@
-    Returns:
-        tir.PrimExpr: The TIR expression for the debug print operation.
+    Returns:
+        None

Also applies to: 30-40, 58-69, 78-89


100-106: ⚠️ Potential issue | 🟡 Minor

Avoid asserting when print() is called with defaults.

print_msg requires a non-empty string, but print() allows obj=None with the default msg="", which triggers an assertion. Consider guarding the msg-only path (or relax the assert).

🔧 Suggested guard in print()
@@
-    elif obj is None:
-        print_msg(msg)
+    elif obj is None:
+        if msg:
+            print_msg(msg)

Also applies to: 158-233
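
A minimal sketch of the suggested guard, assuming a print_msg-style helper that asserts on an empty message. The names print_msg_stub and debug_print are hypothetical stand-ins, not the tilelang functions, and a list plays the role of the emitted output.

```python
def print_msg_stub(msg: str, log: list) -> None:
    """Stand-in for a helper that refuses empty messages."""
    assert msg, "msg must be a non-empty string"
    log.append(msg)


def debug_print(log: list, obj=None, msg: str = "") -> None:
    """Stand-in for a print() wrapper with the guard applied."""
    if obj is None:
        # Guard: only take the message-only path when there is a message.
        if msg:
            print_msg_stub(msg, log)
        return
    log.append(f"{msg}{obj!r}")


log = []
debug_print(log)               # defaults: no assertion, nothing recorded
debug_print(log, msg="hello")  # message-only path
debug_print(log, obj=3, msg="x=")
print(log)
```

Calling the wrapper with its defaults becomes a no-op instead of tripping the assertion.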

tilelang/language/fill_op.py (1)

52-60: ⚠️ Potential issue | 🟠 Major

Handle let-bound BufferLoad without ramp indices.

get_buffer_region_from_load returns None for non-ramp indices; in that case clear() currently raises ValueError even though fill() can handle a BufferLoad by using unit extents. This makes clear(var) fail for valid let-bound loads while clear(buffer_load) succeeds.

🔧 Suggested fix
         elif isinstance(buffer_region, tir.BufferLoad):
             region = get_buffer_region_from_load(buffer_region)
             if region is None:
-                raise ValueError(f"Invalid buffer region: {buffer_region}, type: {type(buffer_region)}")
-            return fill(region, 0)
+                return fill(buffer_region, 0)
+            return fill(region, 0)
tilelang/language/allocate.py (2)

172-195: ⚠️ Potential issue | 🟡 Minor

Update alloc_tmem docstring to match the shape-based signature.

The docstring still references num_cols, but the API now accepts a 2D shape. This is misleading for callers.

📝 Suggested docstring update
-    Args:
-        num_cols (int): Number of columns to allocate in TMEM. Must be a power of 2 and >= 32 but less than or equal to 512.
+    Args:
+        shape (tuple[int | PrimExpr, int | PrimExpr]): 2D shape for TMEM allocation.
+            The number of columns should be a power of 2 and >= 32 (<= 512).

279-291: ⚠️ Potential issue | 🟠 Major

Handle all DType variants in empty() positional argument branch.

The function accepts dtype: DType, which includes tvm.DataType (DType is defined as dtype | ir.Type | str | type), but the check on lines 286-287 only tests for str when empty() is called with 2 positional arguments. As a result, valid dtype inputs such as tvm.DataType objects raise TypeError.

When called as empty((shape), dtype_object), the dtype object becomes shape[1], and the isinstance(shape[1], str) check fails for non-string DType variants. Expand the isinstance check to accept all valid DType types, or validate the dtype argument before the shape pattern matching.
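
One way to widen the check is to test against a tuple of accepted dtype-like types. This is a sketch under assumptions: FakeDataType stands in for a non-string dtype object, and split_shape_and_dtype is a hypothetical helper, not the actual body of empty().

```python
class FakeDataType:
    """Hypothetical stand-in for a non-string dtype object."""

    def __init__(self, name: str) -> None:
        self.name = name


# Every kind of object treated as a dtype in this sketch.
DTYPE_LIKE = (str, type, FakeDataType)


def split_shape_and_dtype(*args, dtype="float32"):
    """Accept either f((d0, d1), dtype) or f(d0, d1, dtype=...)."""
    if (
        len(args) == 2
        and isinstance(args[0], (tuple, list))
        and isinstance(args[1], DTYPE_LIKE)
    ):
        return tuple(args[0]), args[1]
    return tuple(args), dtype


shape, dt = split_shape_and_dtype((4, 8), FakeDataType("float16"))
print(shape, dt.name)
```

The two-argument form then works for string and non-string dtypes alike, while plain integer dimensions still fall through to the default branch.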

Copilot AI left a comment

Pull request overview

Copilot reviewed 24 out of 24 changed files in this pull request and generated 6 comments.

Comments suppressed due to low confidence (3)

tilelang/language/print_op.py:39

  • The docstring incorrectly states that the function returns tir.PrimExpr, but the actual return type is None. The docstring should be updated to match the new return type annotation.
def print_var_with_condition(condition: tir.PrimExpr, var: tir.PrimExpr, msg: str = "") -> None:
    """
    Conditionally prints a TIR primitive expression (PrimExpr) if a given condition is True.

    Parameters:
        condition (tir.PrimExpr): A TIR expression representing the condition to check.
        var (tir.PrimExpr): The variable or expression to be printed.

    Returns:
        tir.PrimExpr: The TIR expression for the debug print operation, if the condition is True.

tilelang/language/print_op.py:24

  • The docstring incorrectly states that the function returns tir.PrimExpr, but the actual return type is None. The docstring should be updated to match the new return type annotation.
    """
    Prints the value of a TIR primitive expression (PrimExpr) for debugging purposes.

    Parameters:
        var (tir.PrimExpr): The variable or expression to be printed.

    Returns:
        tir.PrimExpr: The TIR expression for the debug print operation.

tilelang/language/print_op.py:68

  • The docstring incorrectly states that the function returns tir.PrimExpr, but the actual return type is None. The docstring should be updated to match the new return type annotation.
def print_shared_buffer_with_condition(condition: tir.PrimExpr, buffer: tir.Buffer, elems: int, msg: str = "") -> None:
    """
    Conditionally prints the values of a flattened TIR buffer if the condition is True.

    Parameters:
        condition (tir.PrimExpr): A TIR expression representing the condition to check.
        buffer (tir.Buffer): The buffer whose values need to be printed.
        elems (int): The number of elements in the buffer to print.

    Returns:
        tir.PrimExpr: The TIR expression for the debug print operation.

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (3)
tilelang/language/allocate.py (1)

33-48: ⚠️ Potential issue | 🟠 Major

Normalize DType before comparing to "bool".

With DType now accepting dtype classes, dtype == "bool" won’t catch _dtypes.bool and the special‑case scope hack won’t trigger.

✅ Suggested normalization
-    if dtype == "bool":
+    dtype_name = str(dtype)
+    if dtype_name == "bool":
         # lei: This is a hack to handle bool type.
         # Because tilelang's merge smem pass cannot merge bool type currently.
         scope = "shared"
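
The comparison pitfall can be illustrated without TVM. In the sketch below, DTypeStub is a hypothetical dtype object that prints as its name but does not compare equal to a string, and the "shared.dyn" default scope is an assumption made for illustration.

```python
class DTypeStub:
    """Hypothetical dtype object: str() gives the name, == uses identity."""

    def __init__(self, name: str) -> None:
        self._name = name

    def __str__(self) -> str:
        return self._name


def pick_scope(dtype, default_scope: str = "shared.dyn") -> str:
    # Normalize first so string and object dtypes take the same branch.
    return "shared" if str(dtype) == "bool" else default_scope


bool_obj = DTypeStub("bool")
print(bool_obj == "bool")    # the direct comparison misses the object form
print(pick_scope(bool_obj))  # the normalized comparison catches it
```

Normalizing with str() is only safe if every accepted dtype variant stringifies to its canonical name, which is worth confirming for the real DType members.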
tilelang/language/reduce_op.py (2)

46-95: ⚠️ Potential issue | 🟠 Major

Return-before-copy drops fragment→shared results.

return tir.call_intrin(...) makes the subsequent copy(...) unreachable in the shared→shared and fragment→shared branches, so the reduced results are never copied into out.

✅ Preserve copy and still return the intrinsic handle
         if is_shared(buffer) and is_shared(out):
@@
-            return tir.call_intrin(
+            call = tir.call_intrin(
                 "handle",
                 tir.op.Op.get(_REDUCE_OP_KEY),
                 to_buffer_region(red_frag_in, access_type="r"),
                 to_buffer_region(red_frag_out, access_type="w"),
                 reduce_type,
                 dim,
                 clear,
             )
             copy(red_frag_out, out)
+            return call
@@
         elif is_fragment(buffer) and is_shared(out):
@@
-            return tir.call_intrin(
+            call = tir.call_intrin(
                 "handle",
                 tir.op.Op.get(_REDUCE_OP_KEY),
                 to_buffer_region(buffer, access_type="r"),
                 to_buffer_region(red_frag_out, access_type="w"),
                 reduce_type,
                 dim,
                 clear,
             )
             copy(red_frag_out, out)
+            return call
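
A plain-Python analogy of the reordering, with a list recording which operations get emitted. The function names are hypothetical and the strings merely label the two steps.

```python
def emit(events: list, name: str) -> str:
    """Record an emitted operation and hand back a handle for it."""
    events.append(name)
    return name


def reduce_then_copy_buggy(events: list) -> str:
    return emit(events, "reduce")
    emit(events, "copy")  # unreachable: the function has already returned


def reduce_then_copy_fixed(events: list) -> str:
    call = emit(events, "reduce")
    emit(events, "copy")  # now runs before the handle is returned
    return call


buggy, fixed = [], []
reduce_then_copy_buggy(buggy)
reduce_then_copy_fixed(fixed)
print(buggy, fixed)
```

Both variants return the same handle, but only the fixed one emits the copy.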

248-285: ⚠️ Potential issue | 🟡 Minor

Fragment cumsum path returns None despite tir.PrimExpr return type.

cumsum() returns cumsum_fragment(...) for fragment inputs, but the macro doesn’t return the intrinsic handle.

✅ Return the intrinsic handle after emitting copies
     copy(src, cumsum_smem)
-    tir.call_intrin(
+    call = tir.call_intrin(
         "handle",
         tir.op.Op.get("tl.tileop.cumsum"),
         to_buffer_region(cumsum_smem, access_type="r"),
         to_buffer_region(cumsum_smem, access_type="w"),
         dim,
         reverse,
     )
     copy(cumsum_smem, dst)
+    return call
🤖 Fix all issues with AI agents
In `@tilelang/_typing.py`:
- Around line 5-33: Add "from __future__ import annotations" as the very first
line of the module so the PEP 604 union types used in BarrierType,
BufferLikeType, DType, and ShapeType are deferred and won't be evaluated at
import time on Python 3.9; ensure the future import is placed before any other
imports or code in tilelang/_typing.py.
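
Worth keeping in mind when applying this: the future import defers annotations only. A module-level alias assignment is an ordinary expression and is still evaluated at import time, so typing.Union is the version-independent spelling for aliases. The sketch below uses a hand-written string annotation, which is stored unevaluated in the same way; ShapeLike and rank are hypothetical names.

```python
from typing import Union, get_args

# An alias is a runtime expression, so it is evaluated immediately.
# typing.Union avoids relying on the `X | Y` operator between types.
ShapeLike = Union[tuple, list]


def rank(shape: "tuple | list") -> int:
    """The string annotation is stored as text and never evaluated here."""
    return len(shape)


print(get_args(ShapeLike))
print(rank.__annotations__["shape"])
print(rank((2, 3, 4)))
```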
🧹 Nitpick comments (2)
tilelang/language/gemm_op.py (1)

40-51: Return type annotation may be inaccurate.

The function signature declares -> BufferLikeType, but if arg is a tir.Var without a let value, it's returned unchanged as a Var. The return type should be BufferLikeType | tir.Var to accurately reflect all code paths.

♻️ Suggested fix
-    def legalize_arguments(arg: BufferLikeType | tir.Var) -> BufferLikeType:
+    def legalize_arguments(arg: BufferLikeType | tir.Var) -> BufferLikeType | tir.Var:
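
The pass-through path is easy to see in a stripped-down model. Var, Buffer, and the let_bindings dictionary below are hypothetical stand-ins for the TIR classes and the let-value lookup.

```python
from typing import Dict, Union


class Var:
    """Stand-in for tir.Var."""


class Buffer:
    """Stand-in for tir.Buffer."""


def legalize(
    arg: Union[Buffer, Var], let_bindings: Dict[Var, Buffer]
) -> Union[Buffer, Var]:
    # A let-bound Var is swapped for its buffer; anything else passes through.
    if isinstance(arg, Var) and arg in let_bindings:
        return let_bindings[arg]
    return arg


bound, free, buf = Var(), Var(), Buffer()
print(legalize(bound, {bound: buf}) is buf)
print(legalize(free, {bound: buf}) is free)
```

Because an unbound Var comes back unchanged, the return annotation has to include Var alongside the buffer types.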
tilelang/language/proxy.py (1)

58-71: Consider unifying default dtype values.

BufferProxy.__call__ uses _dtypes.float32 as the default (line 26), while BufferProxy.from_ptr, BaseTensorProxy.__call__, and BaseTensorProxy.from_ptr use the string "float32". Both should work, but consistency would improve maintainability.

Also applies to: 86-136

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 0

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (4)
tilelang/language/reduce_op.py (4)

24-108: ⚠️ Potential issue | 🟡 Minor

Align reduce() docstring with ReduceKind + None return.

The signature now uses ReduceKind and returns None, but the docstring still documents str and tir.Call, which is now misleading.

📘 Suggested docstring update
-        reduce_type (str): Type of reduction ('max', 'min', 'sum', 'abssum')
+        reduce_type (ReduceKind): Type of reduction ("sum", "abssum", "max", "absmax", "min", "bitand", "bitor", "bitxor")
 ...
-    Returns:
-        tir.Call: Handle to the reduction operation
+    Returns:
+        None

111-245: ⚠️ Potential issue | 🟡 Minor

Update reduce_* wrapper docstrings to reflect None return.

Each wrapper now returns None, but the “Returns” sections still mention tir.Call, which can confuse API users.


248-285: ⚠️ Potential issue | 🟡 Minor

Broaden cumsum_fragment dst type to BufferLikeType.

cumsum() accepts buffer regions/loads, but cumsum_fragment() still narrows dst to tir.Buffer, which will trip type checkers and contradicts the docstring.

🧩 Suggested typing fix
 def cumsum_fragment(
     src: BufferLikeType,
-    dst: tir.Buffer,
+    dst: BufferLikeType,
     dim: int,
     reverse: bool,
 ) -> None:

287-365: ⚠️ Potential issue | 🟡 Minor

cumsum() return type doesn’t match fragment path.

When src is a fragment, cumsum() returns the result of cumsum_fragment() (currently None), but the signature/docstring promise tir.PrimExpr.

✅ Option: make the return type accurate
-def cumsum(
-    src: BufferLikeType,
-    dst: BufferLikeType | None = None,
-    dim: int = 0,
-    reverse: bool = False,
-) -> tir.PrimExpr:
+def cumsum(
+    src: BufferLikeType,
+    dst: BufferLikeType | None = None,
+    dim: int = 0,
+    reverse: bool = False,
+) -> tir.PrimExpr | None:
@@
-    Returns:
-        tir.Call: A handle to the emitted cumulative-sum operation.
+    Returns:
+        tir.Call | None: A handle to the emitted cumulative-sum operation (None for fragment path).

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (5)
tilelang/language/warpgroup.py (1)

17-30: ⚠️ Potential issue | 🟡 Minor

Docstring return type is now stale.

The function now annotates -> WarpSpecializeFrame but the docstring still says Tuple[frame.LaunchThreadFrame]. Align the docstring with the updated public API.

✏️ Suggested docstring fix
-    Returns
-    -------
-    res : Tuple[frame.LaunchThreadFrame]
-        The result LaunchThreadFrame.
+    Returns
+    -------
+    res : WarpSpecializeFrame
+        The WarpSpecializeFrame produced by the FFI call.
tilelang/language/print_op.py (1)

45-54: ⚠️ Potential issue | 🟡 Minor

Inconsistent return type annotation.

All other similar buffer print functions (print_shared_buffer_with_condition, print_fragment_buffer_with_condition, print_local_buffer_with_condition) were updated to return -> None, but this function still has -> tir.PrimExpr. Since the function body has no return statement (correctly, as it's a side-effect-only macro), the type annotation should also be updated for consistency.

🔧 Proposed fix
 @macro
-def print_global_buffer_with_condition(condition: tir.PrimExpr, buffer: tir.Buffer, elems: int, msg: str = "") -> tir.PrimExpr:
+def print_global_buffer_with_condition(condition: tir.PrimExpr, buffer: tir.Buffer, elems: int, msg: str = "") -> None:
     """
     Conditionally prints the values of a flattened TIR buffer if the condition is True.
     """
tilelang/language/reduce_op.py (1)

109-127: ⚠️ Potential issue | 🟡 Minor

Docstrings are inconsistent with the new None return type.

The docstrings for reduce_max (lines 122-125), reduce_min (lines 139-141), reduce_sum (lines 164-166), reduce_abssum (lines 179-181), reduce_absmax (lines 194-196), reduce_bitand (lines 209-211), reduce_bitor (lines 224-226), and reduce_bitxor (lines 239-241) all state:

Returns:
    tir.Call: Handle to the reduction operation

However, the function signatures now correctly indicate -> None. Consider updating or removing the Returns section from these docstrings to match the actual behavior.

📝 Proposed fix for reduce_max (apply similar changes to other functions)
 def reduce_max(buffer: tir.Buffer, out: tir.Buffer, dim: int = -1, clear: bool = True) -> None:
     """Perform reduce max on input buffer, store the result to output buffer

     Parameters
     ----------
     buffer : Buffer
         The input buffer.
     out : Buffer
         The output buffer.
     dim : int
         The dimension to perform reduce on
     clear : bool
         If set to True, the output buffer will first be initialized to -inf.
-    Returns
-    -------
-    handle : PrimExpr
     """
tilelang/language/gemm_op.py (1)

40-51: ⚠️ Potential issue | 🟡 Minor

Return type annotation doesn't match actual behavior.

The function can return a tir.Var when the input is a non-let-bound variable (line 51 returns arg unchanged), but the return type annotation specifies BufferLikeType. Additionally, the docstring references old types (Union[tir.Buffer, tir.Var]) instead of the new type aliases.

🔧 Proposed fix
-    def legalize_arguments(arg: BufferLikeType | tir.Var) -> BufferLikeType:
+    def legalize_arguments(arg: BufferLikeType | tir.Var) -> BufferLikeType | tir.Var:
         """Convert let-bound variables to their corresponding buffers.

         Args:
-            arg (Union[tir.Buffer, tir.Var]): Input argument to legalize
+            arg (BufferLikeType | tir.Var): Input argument to legalize

         Returns:
-            Union[tir.Buffer, tir.Var]: The legalized argument
+            BufferLikeType | tir.Var: The legalized argument (buffer if let-bound, otherwise unchanged)
         """
tilelang/language/customize.py (1)

70-76: ⚠️ Potential issue | 🟡 Minor

Fix docstring to match type annotation, and clarify the intentional duplication.

The function duplicates loop_break() in tilelang/language/builtin.py:797, but the customize.py version is intentionally the public API (explicitly imported in __init__.py before the builtin wildcard import). Consider whether this duplication can be eliminated or documented.

The docstring inconsistency should be resolved: it states the return type is tir.Call (line 74), but the annotation is PrimExpr. Update the docstring to match the annotation or be more specific (e.g., "Returns: PrimExpr (specifically a tir.Call to the intrinsic)").

🤖 Fix all issues with AI agents
In `@tilelang/language/reduce_op.py`:
- Around line 285-290: The cumsum function signature declares it returns
tir.PrimExpr but when src is a fragment it returns the result of cumsum_fragment
which is currently typed to return None, causing a type mismatch; fix by making
the return types consistent — either change cumsum's signature to "->
tir.PrimExpr | None" if fragments legitimately cause no PrimExpr, or update
cumsum_fragment to return a tir.PrimExpr (propagate and return the intrinsic
result) and keep cumsum as "-> tir.PrimExpr"; locate the functions cumsum and
cumsum_fragment and apply the chosen change so callers and type checks align.
🧹 Nitpick comments (7)
tilelang/language/copy_op.py (1)

43-44: Docstring still states tir.Call but return type is now tir.PrimExpr | tir.Stmt.

The docstring on line 44 says "Returns: tir.Call" but doesn't reflect the tir.BufferStore path (line 87). Consider updating to describe both return scenarios.

📝 Suggested docstring update
     Returns:
-        tir.Call: A handle to the copy operation
+        Union[tir.PrimExpr, tir.Stmt]: Either a tir.BufferStore (when both src and dst
+            are scalar BufferLoads without region extents) or a tir.Call handle to the
+            copy operation.
tilelang/language/print_op.py (5)

15-26: Docstring "Returns" section is now stale.

The return type annotation was correctly updated to None, but the docstring at lines 23-24 still documents Returns: tir.PrimExpr. Since this is a side-effect-only macro, the Returns section should either be removed or updated to indicate None.

📝 Suggested docstring update
     """
     Prints the value of a TIR primitive expression (PrimExpr) for debugging purposes.

     Parameters:
         var (tir.PrimExpr): The variable or expression to be printed.
-
-    Returns:
-        tir.PrimExpr: The TIR expression for the debug print operation.
     """

29-42: Docstring "Returns" section is stale.

Same issue as print_var - the docstring still documents a tir.PrimExpr return value at lines 38-39, but the signature now correctly indicates -> None.


57-74: Docstring "Returns" section is stale.

The return type annotation was correctly updated to None, but the docstring at lines 67-68 still documents Returns: tir.PrimExpr.


77-96: Docstring "Returns" section is stale.

Same pattern - the docstring at lines 87-88 still documents a tir.PrimExpr return value, but the signature now correctly indicates -> None.


168-172: Malformed docstring parameter description.

Line 171 appears to be a continuation of the warp_id parameter description or a separate note, but it's formatted as a standalone line without proper indentation or labeling. This makes the docstring harder to parse.

📝 Suggested fix
     Parameters:
         obj (Any): The object to print. It can be either a tir.Buffer, tir.PrimExpr, or None (for msg-only print).
         msg (str): An optional message to include in the print statement.
         warp_group_id (int): The warp group id to print.
-        warp_id (int): The warp id to print.
-        print thread will be warp_group_id * warp_group_size + warp_id
+        warp_id (int): The warp id to print. The print thread will be
+            warp_group_id * warp_group_size + warp_id * warp_size.
tilelang/language/gemm_op.py (1)

199-222: Type annotations don't reflect that tir.Var is accepted.

The docstrings correctly document that Var is accepted for A, B, C, and mbar (lines 213, 222), but the type annotations only specify BufferLikeType and BarrierType. Since legalize_arguments internally handles tir.Var conversion, consider updating the public signatures to explicitly accept Var for better IDE support and API clarity.

♻️ Optional: Make type annotations match documented behavior
 def gemm(
-    A: BufferLikeType,
-    B: BufferLikeType,
-    C: BufferLikeType,
+    A: BufferLikeType | tir.Var,
+    B: BufferLikeType | tir.Var,
+    C: BufferLikeType | tir.Var,
     transpose_A: bool = False,
     transpose_B: bool = False,
     policy: GemmWarpPolicy = GemmWarpPolicy.Square,
     clear_accum: bool = False,
     k_pack: int = 1,
     wg_wait: int = 0,
-    mbar: BarrierType | None = None,
+    mbar: BarrierType | tir.Var | None = None,
 ) -> tir.PrimExpr:

The same change could be applied to gemm_v1 and gemm_v2 for consistency.

Copilot AI left a comment

Pull request overview

Copilot reviewed 26 out of 26 changed files in this pull request and generated 14 comments.


Comment on lines +30 to 33
def print_var_with_condition(condition: tir.PrimExpr, var: tir.PrimExpr, msg: str = "") -> None:
"""
Conditionally prints a TIR primitive expression (PrimExpr) if a given condition is True.
Copilot AI Feb 3, 2026

The docstring states "Returns: tir.PrimExpr: The TIR expression for the debug print operation, if the condition is True" but the function signature has been updated to return None. The docstring should be updated or the Returns section should be removed to reflect that this macro function doesn't return a value.

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 0

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (4)
tilelang/language/reduce_op.py (4)

122-125: ⚠️ Potential issue | 🟡 Minor

Docstring incorrectly claims a return value.

The function signature is -> None, but the docstring states it returns a PrimExpr handle. The docstring should be updated or the "Returns" section removed to match the actual behavior.

📝 Suggested fix
     dim : int
         The dimension to perform reduce on
     clear : bool
         If set to True, the output buffer will first be initialized to -inf.
-    Returns
-    -------
-    handle : PrimExpr
     """

139-140: ⚠️ Potential issue | 🟡 Minor

Docstrings across reduce helpers claim returns that don't exist.

Multiple reduce_* functions (reduce_min, reduce_sum, reduce_abssum, reduce_absmax, reduce_bitand, reduce_bitor, reduce_bitxor) have docstrings claiming they return a tir.Call handle, but all these functions now return None. Consider removing or updating the "Returns" sections in these docstrings for consistency.

Also applies to: 164-165, 179-180, 194-195, 209-210, 224-225, 239-240


331-332: ⚠️ Potential issue | 🟡 Minor

Docstring should mention the None return case.

The return type is tir.PrimExpr | None, but the docstring only mentions returning a tir.Call handle. It should clarify that None is returned when src is a fragment buffer (delegating to the macro implementation).

📝 Suggested fix
     Returns:
-        tir.Call: A handle to the emitted cumulative-sum operation.
+        tir.PrimExpr | None: A handle to the emitted cumulative-sum operation,
+            or None when operating on fragment buffers (macro implementation).

437-466: ⚠️ Potential issue | 🟡 Minor

Add warp_reduce_bitxor function or clarify why it's intentionally omitted.

The file defines reduce_bitxor (line 231) and includes "bitxor" in ReduceKind (line 21), but there's no corresponding warp_reduce_bitxor function. Both bitand and bitor have warp equivalents (lines 437 and 453), making this an inconsistency. Either add warp_reduce_bitxor for symmetry or document why warp-level bitwise-xor reduction is not supported.
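
For reference, a warp-level XOR reduction typically follows the same butterfly pattern as the other bitwise reductions. The sketch below simulates that pattern on a Python list; it is not tilelang's implementation, and warp_allreduce_xor is a hypothetical name.

```python
from functools import reduce
from operator import xor


def warp_allreduce_xor(lanes: list) -> list:
    """Simulate a butterfly XOR all-reduce across a power-of-two lane count."""
    n = len(lanes)
    if n & (n - 1):
        raise ValueError("lane count must be a power of two")
    vals = list(lanes)
    offset = n // 2
    while offset > 0:
        # Each lane combines its value with the partner lane `i ^ offset`.
        vals = [vals[i] ^ vals[i ^ offset] for i in range(n)]
        offset //= 2
    return vals


lanes = [i * 7 + 3 for i in range(32)]
result = warp_allreduce_xor(lanes)
print(result[0] == reduce(xor, lanes))
```

After log2(n) exchange steps, every lane holds the XOR of all inputs, matching a sequential fold.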

🧹 Nitpick comments (1)
tilelang/language/reduce_op.py (1)

44-45: Consider using ReduceKind instead of str for type consistency.

The outer reduce function uses ReduceKind for the reduce_type parameter, but the nested reduce_macro uses str. For full type consistency and better IDE support, consider using ReduceKind here as well.

🔧 Suggested fix
     @macro
-    def reduce_macro(buffer: tir.Buffer, out: tir.Buffer, reduce_type: str, dim: int, clear: bool) -> None:
+    def reduce_macro(buffer: tir.Buffer, out: tir.Buffer, reduce_type: ReduceKind, dim: int, clear: bool) -> None:
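
A sketch of what the narrower annotation buys, assuming ReduceKind is a typing.Literal over the kinds listed in the docstring suggestion earlier in this thread. validate_reduce_kind is a hypothetical helper added here for a runtime check.

```python
from typing import Literal, get_args

# Assumed definition, mirroring the kinds named in the review above.
ReduceKind = Literal[
    "sum", "abssum", "max", "absmax", "min", "bitand", "bitor", "bitxor"
]


def validate_reduce_kind(kind: str) -> str:
    """Reject strings outside the literal set at runtime."""
    if kind not in get_args(ReduceKind):
        raise ValueError(f"unsupported reduce kind: {kind!r}")
    return kind


print(validate_reduce_kind("bitxor"))
```

A static checker flags a misspelled kind at the call site when the parameter is typed as ReduceKind, whereas a plain str annotation accepts any string.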

@SiriusNEO SiriusNEO merged commit 3b3369e into tile-ai:main Feb 3, 2026
6 checks passed