
@CokeDong CokeDong commented Jan 29, 2026

  • Summary

    • This PR introduces CommonIR as a pluggable abstract TileLang target, providing an end-to-end path: lower → generate CommonIR → compile via DLCompiler → run via DLCompiler on diverse domestic GPUs.
    • CommonIR aims to support seamless execution of the TileLang DSL on domestic GPUs while maintaining performance through DLCompiler integration.
  • User-facing changes

    • Use the environment variable USE_COMMONIR to enable or disable the CommonIR feature pluggably.
    • USE_COMMONIR=0: all existing targets keep their previous defaults and are entirely unaffected.
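The toggle parse can be sketched as follows (mirroring the `COMMONIR_enabled` flag added in tilelang/utils/target.py; the helper name here is illustrative):

```python
import os

def commonir_enabled() -> bool:
    # Same parse as the COMMONIR_enabled module flag:
    # '1', 'true', or 'on' enable the CommonIR path; anything else disables it.
    return os.environ.get("USE_COMMONIR", "0") in ("1", "true", "on")
```

With the variable unset or set to "0", all existing targets behave as before.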
  • What’s included

    • Target / lowering glue
      • tilelang.utils.target.determine_target() recognizes the commonir target and injects the commonir key.
      • tilelang.engine.lower.device_codegen*() dispatches to target.build.tilelang_commonir* when the target has the commonir key.
    • Codegen + runtime
      • New CommonIR codegen (src/target/codegen_commonir.*) built on top of the tvm codegen infrastructure (CodeGenC).
      • New runtime module builder for CommonIR without compilation (src/target/rt_mod_commonir.cc).
    • JIT / execution backend
      • A new CommonIR adapter and execution backend are integrated into DLCompiler.
    • Tests
      • Adds examples/commonir/*.py (currently tested on Huawei Ascend NPU).
  • Requirements / constraints

    • The CommonIR backend requires DLCompiler; please install DLCompiler accordingly.
  • Backward compatibility

    • Intended to be fully backward compatible.
  • Follow-ups (not in this PR)

    • Codegen: support more TileLang ops and enhancements.
    • Tests/CI: add more test cases and supplement CI on more domestic GPUs.
    • Kernel cache: extend kernel cache to support commonir artifacts (ideally on DLCompiler).
    • Compatibility: support profiler for commonir (ideally on DLCompiler).

Summary by CodeRabbit

  • New Features

    • Added CommonIR target support and runtime backend selectable via env var (USE_COMMONIR); integrates DLCompiler for CommonIR compilation.
    • New JIT/pipeline for NPU/CommonIR kernels with profiling and benchmarking support.
  • Documentation & Examples

    • Added vector-add examples (including profiler) and a tiled GEMM example demonstrating end-to-end compile/run.
  • Chores

    • Added external DLCompiler submodule.


@github-actions

👋 Hi! Thank you for contributing to the TileLang project.

Please remember to run pre-commit run --all-files in the root directory of the project to ensure your changes are properly linted and formatted. This will help ensure your contribution passes the format check.

We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀


coderabbitai bot commented Jan 29, 2026

📝 Walkthrough

Walkthrough

Adds a CommonIR backend: a DLCompiler git submodule, a new C++ COMMONIR codegen and runtime integration, Python JIT/profiler for CommonIR, examples, and pipeline changes to select and short-circuit to the CommonIR path.

Changes

  • Submodule — .gitmodules, 3rdparty/DLCompiler
    Adds the 3rdparty/DLCompiler git submodule and updates the submodule pointer.
  • Build config — CMakeLists.txt
    Adds a USE_COMMONIR toggle (env var and auto-detect) and includes CommonIR sources in the build when enabled.
  • C++ CommonIR backend — src/target/codegen_commonir.h, src/target/codegen_commonir.cc, src/target/rt_mod_commonir.cc
    Implements CodeGenTileLangCOMMONIR, SSA types (Scalar/Memref/Tensor), extensive IR visitors, memory/layout handling, GEMM/fill/copy codegen, and function emission; registers the commonir target plus a runtime builder.
  • Python pipeline integration — tilelang/utils/target.py, tilelang/engine/lower.py, tilelang/engine/phase.py, tilelang/jit/__init__.py
    Adds the COMMONIR_enabled flag and commonir target; short-circuits lowering/optimization and dispatches to CommonIR codegen/JIT when enabled or when the target is commonir.
  • CommonIR JIT & tooling — tilelang/jit/jit_commonir.py, tilelang/jit/kernel_npu.py
    New CommonIR JIT compiler (compiler_common), grid/signature parsing, MLIR I/O, JitKernel_NPU, Profiler, NPU adapter, tensor supply factories, and benchmarking utilities.
  • API/data model change — tilelang/engine/param.py
    Removed the @dataclass decorator from KernelParam (now a plain class).
  • Examples — examples/commonir/add_vector.py, examples/commonir/add_vector_profiler.py, examples/commonir/gemm.py
    Adds vector-add, profiler, and tiled GEMM examples demonstrating the CommonIR pipeline and validation.
  • Docs — README.md
    Updates Latest News to mention CommonIR support and DLCompiler integration.

Sequence Diagram

sequenceDiagram
    participant User
    participant TileLang as TileLang Compiler
    participant Lower as Lower/Optimize
    participant CodeGen as CommonIR CodeGen
    participant JIT as CommonIR JIT
    participant NPU as NPU Device

    User->>TileLang: submit prim_func + target
    TileLang->>Lower: lower(mod)
    Note over Lower: if target == "commonir" or COMMONIR_enabled\napply limited passes and return MLIR
    Lower->>CodeGen: emit CommonIR/C source
    CodeGen->>JIT: provide C source + metadata
    JIT->>JIT: parse grid & signature, compile via CommonIRCompiler
    JIT-->>User: return JitKernel_NPU
    User->>NPU: execute kernel
    NPU-->>User: results

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~75 minutes

Suggested reviewers

  • LeiWang1999

Poem

🐰 I found a new submodule door,
CommonIR, a shiny core.
GEMM and vectors, tuned and spry,
DLCompiler waves hello — hi!
TileLang hops forward, wheels set to roar.

🚥 Pre-merge checks: ✅ 2 passed | ❌ 1 failed

❌ Failed checks (1 warning)
  • Docstring Coverage ⚠️ Warning — Docstring coverage is 8.47%, which is below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them to satisfy the coverage threshold.

✅ Passed checks (2 passed)
  • Description Check ✅ Passed — Check skipped; CodeRabbit’s high-level summary is enabled.
  • Title Check ✅ Passed — The PR title clearly describes the main change: adding CommonIR as an abstract backend for TileLang, the primary objective of this PR.





@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 18

🤖 Fix all issues with AI agents
In `@examples/commonir/gemm.py`:
- Line 9: The module-level call to torch.npu.current_device() (stored in the
device variable) can fail at import if NPU isn't available; move the device
detection inside main() (or another runtime entrypoint) and replace the
module-level device assignment with deferred lookup so imports succeed without
an active NPU context; update any references that used the module-level device
variable to retrieve the device inside main() or pass it down to functions that
need it.

In `@README.md`:
- Line 16: The changelog entry "01/22/2025 🚀: Added CommonIR support, enabling
compilation on wider domestic GPU through
[DLCompiler](https://github.com/DeepLink-org/DLCompiler) integration." is out of
chronological order; either correct its date if it should be the latest (e.g.,
change year to 2026) or move this exact entry string to its proper position in
the reverse-chronological list (after the "02/10/2025" entry and before the
"01/20/2025" entry) so the README's news section remains sorted newest-first.

In `@src/target/codegen_commonir.cc`:
- Around line 185-219: GetCastOp can fall through without returning for
unsupported src/dst combinations; add a clear fallback at the end of GetCastOp
(referencing GetCastOp, src_type, dst_type, DataType::bits) that either returns
a safe sentinel (e.g., empty string) or better, raises a deterministic error
(throw std::runtime_error or call your project's fatal logging) including the
src/dst type info so undefined behavior is avoided and debugging is easier.
Ensure the fallback is reached only when none of the existing branches match and
include descriptive text with src_type/dst_type in the error.
- Around line 1301-1314: The is_arg branch currently emits a single "?x" and
collapses multidimensional argument memrefs; update the branch in GetMemrefInfo
(the block using memrefObj->is_arg, memrefObj->shape and building memref_type)
to iterate over memrefObj->shape and append a "?x" for each dimension (similar
to the non-arg loop), preserving the original rank, then call
PrintType(memrefObj->dtype, memref_type) as before.
- Around line 1148-1149: The VisitExpr_ implementation for DivNode currently
emits a placeholder "<<<divf>>>"; replace this by inspecting op->dtype (use
op->dtype() or op->dtype().base_type()) and emit the correct division IR token
for floats vs integers (e.g., a floating-point divide string for float types and
an integer divide string or signed/unsigned variant for integer types) via the
same PrintBinary call (e.g., PrintBinary(op, "<correct-div-op>", os, this)); if
the dtype is unsupported, hard-fail (throw or LOG(FATAL)) with a clear message
referencing DivNode so callers see the unsupported dtype instead of invalid IR.
- Around line 92-147: The helpers GetStrideFromShape, GetBufferStrides, and
getBroadcastDim dereference as_const_int() results without null checks, which
breaks on non-constant (symbolic) shapes/strides; update them to validate
as_const_int(...) before dereferencing and handle non-constant cases explicitly
(e.g., return an error, throw/log an explanatory message, or implement dynamic
stride handling) — in GetStrideFromShape skip or error on non-constant shape
elements and avoid building shape_int from nulls; in GetBufferStrides check each
buffer->strides element via as_const_int and handle the null case; in
getBroadcastDim guard every *as_const_int(...) access, ensure shapes are
constant or emit a clear error/exception when symbolic dimensions are
encountered.
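The total-lookup-with-fallback pattern suggested for GetCastOp above can be sketched in Python; the table entries beyond arith.fptosi are illustrative MLIR arith op names, and the simplified string "kinds" stand in for the real DataType inspection:

```python
def get_cast_op(src_kind: str, dst_kind: str) -> str:
    # Illustrative mapping from (source, destination) type kinds to cast ops.
    # Only arith.fptosi is confirmed from the quoted C++; the rest are
    # standard MLIR arith op names used here for shape of the pattern.
    table = {
        ("float", "int"): "arith.fptosi",
        ("float", "uint"): "arith.fptoui",
        ("int", "float"): "arith.sitofp",
        ("uint", "float"): "arith.uitofp",
    }
    try:
        return table[(src_kind, dst_kind)]
    except KeyError:
        # Deterministic failure instead of falling through without a return,
        # with the src/dst info included for debugging.
        raise RuntimeError(
            f"GetCastOp: unsupported cast {src_kind} -> {dst_kind}")
```

The key point is that every input either hits a table entry or raises with a descriptive message, so no control path exits without a value.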

In `@src/target/codegen_commonir.h`:
- Around line 85-86: The header declares PrintExtraAttrs(const PrimFunc &f) but
there's no implementation causing linker/vtable errors; add a definition in
codegen_commonir.cc that either implements the intended behavior or a minimal
stub (e.g., empty body) to satisfy the linker. Locate the declaration of
PrintExtraAttrs in src/target/codegen_commonir.h and add a corresponding
function definition in src/target/codegen_commonir.cc matching the exact
signature (void PrintExtraAttrs(const PrimFunc &f)) and ensure it uses/compiles
with the PrimFunc type and existing namespace so symbol linkage matches.
- Around line 14-17: The header codegen_commonir.h uses std::map and std::vector
but only includes <unordered_map> and <string>; add the missing includes for
<map> and <vector> at the top of the header so symbols like std::map and
std::vector used anywhere in codegen_commonir.h are always available regardless
of include order.

In `@tilelang/engine/lower.py`:
- Around line 261-266: Remove the leftover debug print by deleting the
commented-out "print(codegen_mod.inspect_source())" line and then address the
host_mod being passed as None to CompiledArtifact: either update this call in
lower.py (where target.kind.name / COMMONIR_enabled lead to device_codegen(mod,
target)) to supply a valid tvm.IRModule host_mod, or confirm and update the
CompiledArtifact type/signature and all downstream consumers to accept
Optional[tvm.IRModule] (search for CompiledArtifact usages to ensure None is
handled safely), keeping the call site and symbol names device_codegen and
CompiledArtifact consistent.

In `@tilelang/jit/__init__.py`:
- Around line 101-104: The current routing uses COMMONIR_enabled with lower
precedence than an explicit target and returns the wrong type; change the
condition so an explicit target argument takes priority over the global flag
(i.e., check target == 'commonir' first or require COMMONIR_enabled only when
target is None/default), and ensure the returned object matches the declared
return type by either returning a JITKernel instance or updating the
function/generic signature to allow JitKernel_NPU; specifically update the
branch using COMMONIR_enabled, the call to
compiler_common(out_idx=out_idx)/compiler_commonir.compile(func), and the
function's return annotation (JITKernel[_KP,_T] vs JitKernel_NPU) so the runtime
branch and type signature agree.
- Around line 40-41: The unconditional top-level import of compiler_common from
.jit_commonir causes ImportError on systems without the NPU stack because
jit_commonir imports torch_npu at module scope; move the import into the
conditional branch where compiler_common is actually needed (perform a local
import of .jit_commonir and access compiler_common inside that block) so
jit_commonir’s torch_npu dependency is only resolved when CommonIR functionality
is used.

In `@tilelang/jit/jit_commonir.py`:
- Line 88: The call uses redundant double conversion str(str(self.mod)) in the
re.search invocation; update the expression in jit_commonir.py (the line that
sets match via re.search(pattern, ...)) to pass a single string conversion—e.g.,
replace str(str(self.mod)) with str(self.mod)—so the regex runs against the same
string value without the unnecessary nested str() call.
- Around line 20-24: Remove the duplicate typing imports by editing the import
block in jit_commonir.py: keep only the new symbol(s) needed here (e.g., List)
and delete Any and Union which are already imported earlier (line 8), ensuring
there are no redundant imports of Any or Union in the file.
- Line 13: The unconditional top-level import of torch_npu in
tilelang.jit.jit_commonir breaks imports on systems without the NPU stack;
change it to a lazy or guarded import (e.g., move the import into the
function(s) or method(s) that actually use torch_npu or wrap the import in a
try/except ImportError and surface a clear runtime error only when CommonIR
functionality is invoked). Locate the module-level import "import torch_npu" in
jit_commonir.py and replace it with an on-demand import inside the specific
functions/classes that require torch_npu (or a guarded try/except that sets a
flag and raises a descriptive error when CommonIR APIs are called on non-NPU
hosts).
- Around line 1-14: Remove the seven unused imports from the top of
tilelang/jit/jit_commonir.py: subprocess, dataclass (from dataclasses), Path
(from pathlib), ModuleType (from types), sysconfig, pybind11, and functools;
delete their import tokens or whole import statements referencing them and keep
the remaining imports intact (os, re, tempfile, typing items, shutil, pybind11
only if used elsewhere, torch, torch_npu) and run lint/tests to confirm nothing
else references these symbols (search for subprocess, dataclass, Path,
ModuleType, sysconfig, pybind11, functools to ensure safe removal).

In `@tilelang/jit/kernel_npu.py`:
- Around line 39-56: The timing estimate computed from
start_event.elapsed_time(end_event) can be zero; clamp estimate_ms to a small
positive floor (e.g., 1e-3 ms) before using it to compute n_warmup and n_repeat
to avoid divide-by-zero and extreme counts. Locate the block that creates
start_event/end_event and sets estimate_ms, then set estimate_ms =
max(estimate_ms, <small_positive>) prior to computing n_warmup/n_repeat (keeping
later overrides for _n_warmup/_n_repeat intact); reference symbols: start_event,
end_event, estimate_ms, n_warmup, n_repeat, _n_warmup, _n_repeat, warmup, rep.
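The clamp can be sketched as follows; plan_bench and the 1e-3 ms floor are illustrative, not the actual kernel_npu.py code:

```python
def plan_bench(estimate_ms: float, warmup: float = 25.0, rep: float = 100.0):
    # Clamp the single-run estimate to a small positive floor so a
    # 0 ms measurement cannot divide by zero or explode the counts.
    estimate_ms = max(estimate_ms, 1e-3)
    n_warmup = max(1, int(warmup / estimate_ms))
    n_repeat = max(1, int(rep / estimate_ms))
    return n_warmup, n_repeat
```

Explicit _n_warmup/_n_repeat overrides would still be applied after this computation, exactly as before.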
- Around line 94-113: The get_tensor function currently treats empty
KernelParam.shape as an error, but KernelParam.is_scalar() uses an empty shape
for 0‑dim tensors; remove the early ValueError and allow empty shapes as valid
scalars. Update the guard in get_tensor (the block that checks "if
hasattr(param, 'shape') and not param.shape:") to not raise for empty shapes,
and ensure the subsequent loop that checks for tir.Var only runs when
param.shape is non-empty (e.g., iterate only if param.shape). This preserves the
tir.Var static-shape check while allowing scalar KernelParam instances to be
handled as 0‑dim tensors.
- Around line 178-192: In _legalize_result_idx, fix the int branch bounds check
to forbid result_idx >= len(self.params) (not just >) and normalize negative
indices as currently done; additionally, when result_idx is a list (the else
branch), validate each element is an int, normalize any negative indices by
adding len(self.params), and raise ValueError for any index < 0 or >=
len(self.params); update references to params via self.params and return the
normalized list of ints.
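The suggested normalization can be sketched as follows (legalize_result_idx is a hypothetical standalone version of the method, taking the param count explicitly):

```python
from typing import List, Union

def legalize_result_idx(result_idx: Union[int, List[int]],
                        n_params: int) -> List[int]:
    def norm(i):
        if not isinstance(i, int):
            raise ValueError(
                f"result index must be int, got {type(i).__name__}")
        if i < 0:
            i += n_params  # normalize negative indices
        # Forbid i >= n_params, not just i > n_params.
        if not 0 <= i < n_params:
            raise ValueError(
                f"result index {i} out of range for {n_params} params")
        return i

    if isinstance(result_idx, int):
        return [norm(result_idx)]
    return [norm(i) for i in result_idx]
```

Both the scalar and list branches now share the same bounds check and negative-index handling.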
🧹 Nitpick comments (15)
3rdparty/DLCompiler (2)

1-1: Document the rationale for this specific commit.

Consider adding documentation (e.g., in a CHANGELOG, commit message, or README) explaining why this specific commit (51ef7ccd953557fc1e3e7f46efcd0c3fdb88c53d) was chosen. This helps with:

  • Future maintenance and upgrades
  • Understanding compatibility requirements
  • Debugging issues related to this version

1-1: Establish a maintenance and update strategy for the submodule.

Third-party submodules require ongoing maintenance. Consider:

  • Setting up automated checks for updates and security patches
  • Documenting the update process and testing requirements
  • Establishing ownership for monitoring and updating this dependency
  • Adding submodule update guidelines to your contribution docs
examples/commonir/add_vector.py (3)

3-3: Remove unused import.

The os module is imported but never used in this file.

🧹 Proposed fix
-import os
-
 import tilelang

35-35: Consider adding explicit target='commonir' for consistency.

The gemm.py example in the same directory explicitly specifies target='commonir' (see examples/commonir/gemm.py line 38), but this example omits it. For clarity and consistency across CommonIR examples, consider making the target explicit.

♻️ Proposed fix
-    compiled_kernel = tilelang.compile(func)
+    compiled_kernel = tilelang.compile(func, target='commonir')

37-39: Replace eval() with getattr() for safer dtype resolution.

Using eval() to construct torch dtypes is flagged as potentially insecure (S307). While the dtype variable is controlled in this example, getattr() is the idiomatic and safer approach.

🛡️ Proposed fix
-    v1 = torch.randn(size=[seq_len], dtype=eval("torch." + dtype)).npu()
-    v2 = torch.randn(size=[seq_len], dtype=eval("torch." + dtype)).npu()
-    v3 = torch.zeros(size=[seq_len], dtype=eval("torch." + dtype)).npu()
+    torch_dtype = getattr(torch, dtype)
+    v1 = torch.randn(size=[seq_len], dtype=torch_dtype).npu()
+    v2 = torch.randn(size=[seq_len], dtype=torch_dtype).npu()
+    v3 = torch.zeros(size=[seq_len], dtype=torch_dtype).npu()
examples/commonir/add_vector_profiler.py (2)

3-13: Remove unused imports.

Several imports are not used in this file: os, partial from functools, time, numpy, and typing imports (Callable, Optional, Union, List).

🧹 Proposed fix
-import os
-
 import tilelang
 import tilelang.language as T
-from functools import partial

 import torch
 import torch_npu
-import time
-import numpy as np
-from typing import Callable, Optional, Union, List

19-35: Consider extracting shared vec_add kernel to reduce duplication.

This vec_add function is identical to the one in examples/commonir/add_vector.py. Consider extracting the shared kernel definition to a common module to avoid code duplication across examples.

examples/commonir/gemm.py (1)

36-49: Consider using consistent dimension constants.

SIZEALL = 1024 duplicates the matrix dimensions already passed to matmul(1024, 1024, 1024, ...). Consider defining the dimensions once to avoid inconsistency if one is changed without the other.

♻️ Proposed fix
 def main():
-    func = matmul(1024, 1024, 1024, 128, 128, 32)
-    kernel = tilelang.compile(func, target='commonir')
-    SIZEALL = 1024
+    M, N, K = 1024, 1024, 1024
+    func = matmul(M, N, K, 128, 128, 32)
+    kernel = tilelang.compile(func, target='commonir')

     torch.manual_seed(0)
-    a = torch.rand((SIZEALL, SIZEALL), dtype=dtype, device=device) - 0.5
-    b = torch.rand((SIZEALL, SIZEALL), dtype=dtype, device=device) - 0.5
-    result = torch.zeros((SIZEALL, SIZEALL), dtype=dtype, device=device)
+    a = torch.rand((M, K), dtype=dtype, device=device) - 0.5
+    b = torch.rand((K, N), dtype=dtype, device=device) - 0.5
+    result = torch.zeros((M, N), dtype=dtype, device=device)
CMakeLists.txt (2)

188-194: Simplify source file addition.

Using file(GLOB ...) for exactly two explicitly named files is unnecessary. A direct list(APPEND) is clearer and avoids the minor overhead of globbing.

♻️ Proposed fix
 if(USE_COMMONIR)
-  file(GLOB TILE_LANG_COMMONIR_SRCS
+  list(APPEND TILE_LANG_SRCS
     src/target/codegen_commonir.cc
     src/target/rt_mod_commonir.cc
   )
-  list(APPEND TILE_LANG_SRCS ${TILE_LANG_COMMONIR_SRCS})
 endif()

173-174: Consider adding COMMONIR to TILELANG_BACKENDS for consistency.

COMMONIR is not part of the TILELANG_BACKENDS list (line 68) and is not configured via tilelang_define_backend_option. Unlike CUDA, ROCM, and METAL, it doesn't benefit from the backend option caching logic and lacks documentation. If CommonIR is a first-class backend, add it to TILELANG_BACKENDS alongside the other backends to follow the established pattern.

Additionally, the source file inclusion at lines 189-192 uses file(GLOB) to match just two hardcoded filenames; consider using list(APPEND) directly instead for clarity.

tilelang/utils/target.py (2)

27-28: Missing blank line between module constant and function definition.

PEP 8 recommends two blank lines before top-level function definitions. Also consider adding a brief comment explaining the purpose of this flag.

 COMMONIR_enabled = os.environ.get('USE_COMMONIR', '0') in ('1', 'true', 'on')
+
+
 def describe_supported_targets() -> dict[str, str]:

109-113: COMMONIR_enabled overrides explicit user-specified target without documentation.

When USE_COMMONIR=1 is set, this function returns "commonir" regardless of the target argument passed by the user. This behavior isn't documented in the docstring.

Consider either:

  1. Documenting this behavior in the docstring
  2. Only applying the override when target == "auto"
+    # Note: When USE_COMMONIR environment variable is set, CommonIR target
+    # takes precedence over any user-specified target.
     if COMMONIR_enabled:
         return_var = "commonir"
tilelang/jit/jit_commonir.py (3)

36-36: Replace print statements with proper logging.

Multiple print statements (lines 36, 53, 59, 67) appear to be debug output. These should use Python's logging module for proper log level control in production.

+import logging
+logger = logging.getLogger(__name__)
...
-        print(f"mod is in compiler_common: \n {mod}")
+        logger.debug(f"mod is in compiler_common: \n {mod}")
...
-        print(self.mlir_content)
+        logger.debug(self.mlir_content)

28-33: Fix type hints to use explicit Optional.

PEP 484 prohibits implicit Optional. Parameters with = None default should be explicitly typed as Optional.

+from typing import Optional
+
 class compiler_common:
     def __init__(self,
-        out_idx: Union[List[int], int] = None,
+        out_idx: Optional[Union[List[int], int]] = None,
     ):
       self.out_idx = out_idx

-    def compile(self, mod: PrimFunc = None) -> JitKernel_NPU:
+    def compile(self, mod: Optional[PrimFunc] = None) -> JitKernel_NPU:

97-119: Improve exception handling - avoid catching broad Exception.

The file I/O methods catch Exception broadly and silently return False/None. This can mask unexpected errors. Consider catching specific exceptions or re-raising with context.

     def _write_mlir_file(self, file_path):
         try:
             with open(file_path, "w", encoding="utf-8") as file:
                 file.write(self.mlir_content)
-            return True
         except FileNotFoundError:
-            print(f"Error: Directory for '{file_path}' does not exist")
-            return False
-        except Exception as e:
-            print(f"Error occurred while writing to the file: {e}")
-            return False
+            logger.error(f"Directory for '{file_path}' does not exist")
+            raise
+        except OSError as e:
+            logger.error(f"Error writing to file: {e}")
+            raise
+        return True


import torch
import torch_npu
device = torch.npu.current_device()

⚠️ Potential issue | 🟡 Minor

Module-level NPU device access may cause import failure.

torch.npu.current_device() is called at module load time, which will raise an error if NPU is not available or torch_npu is not properly initialized. Consider moving this inside main() to allow the module to be imported without an active NPU context.

🛡️ Proposed fix
 import torch
 import torch_npu
-device = torch.npu.current_device()
 dtype = torch.float16

 def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
     ...

 def main():
+    device = torch.npu.current_device()
     func = matmul(1024, 1024, 1024, 128, 128, 32)

README.md Outdated
<img src=./images/MatmulExample.png />

## Latest News
- 01/22/2025 🚀: Added CommonIR support, enabling compilation on wider domestic GPU through [DLCompiler](https://github.com/DeepLink-org/DLCompiler) integration.

⚠️ Potential issue | 🟡 Minor

Date placement appears out of chronological order.

The new entry dated 01/22/2025 is placed at the top, but chronologically it falls between 01/20/2025 and 02/10/2025. If news entries are meant to be in reverse chronological order (newest first), this entry should be placed lower in the list, after 02/10/2025 and before 01/20/2025.

Alternatively, if this is the most recent addition (perhaps the year should be 2026?), please verify the date is correct.


Comment on lines +92 to +147
std::vector<unsigned long> GetStrideFromShape(Array<tvm::PrimExpr> shape) {
  std::vector<unsigned long> strides;
  unsigned long total_size = 1;
  std::vector<int> shape_int;
  for (PrimExpr s : shape) {
    if (auto s_int = as_const_int(s)) {
      total_size *= *s_int;
      shape_int.push_back(*s_int);
    }
  }
  for (int i = 0; i < shape.size(); i++) {
    total_size /= shape_int[i];
    strides.push_back(total_size);
  }
  return strides;
}

String GetBufferStrides(Buffer buffer) {
  Array<PrimExpr> shape = buffer->shape;
  std::vector<unsigned long> strides;
  int dim = buffer->shape.size();
  if (buffer->strides.empty()) {
    strides = GetStrideFromShape(shape);
  } else {
    for (PrimExpr stride : buffer->strides) {
      if (auto stride_int = as_const_int(stride)) {
        strides.push_back(*stride_int);
      }
    }
  }
  String res = "[";
  for (int i = 0; i < dim; i++) {
    if (i > 0)
      res = res + ", ";
    res = res + std::to_string(strides[i]);
  }
  res = res + "]";
  return res;
}

static std::vector<int> getBroadcastDim(Array<PrimExpr> &buffer_shape0,
                                        Array<PrimExpr> &buffer_shape1) {
  assert(buffer_shape0.size() == buffer_shape1.size());
  std::vector<int> dims;
  for (int i = 0; i < buffer_shape0.size(); i++) {
    if (*as_const_int(buffer_shape0[i]) == 1 &&
        *as_const_int(buffer_shape1[i]) != 1) {
      dims.emplace_back(i);
    }
    if (*as_const_int(buffer_shape0[i]) != 1 &&
        *as_const_int(buffer_shape1[i]) == 1) {
      dims.emplace_back(i);
    }
    assert(*as_const_int(buffer_shape0[i]) == *as_const_int(buffer_shape1[i]));
  }
  return dims;
}

⚠️ Potential issue | 🟠 Major

Guard against non-constant shapes/strides in stride & broadcast helpers.

as_const_int() is dereferenced without null checks; if any shape/stride is symbolic, this triggers UB or out‑of‑bounds access. Add validation or explicit dynamic handling.

✅ Suggested fix
 std::vector<unsigned long> GetStrideFromShape(Array<tvm::PrimExpr> shape) {
   std::vector<unsigned long> strides;
   unsigned long total_size = 1;
-  std::vector<int> shape_int;
+  std::vector<int64_t> shape_int;
   for (PrimExpr s : shape) {
-    if (auto s_int = as_const_int(s)) {
-      total_size *= *s_int;
-      shape_int.push_back(*s_int);
-    }
+    auto s_int = as_const_int(s);
+    ICHECK(s_int) << "dynamic shapes are not supported in GetStrideFromShape";
+    total_size *= *s_int;
+    shape_int.push_back(*s_int);
   }
   for (int i = 0; i < shape.size(); i++) {
     total_size /= shape_int[i];
     strides.push_back(total_size);
   }
   return strides;
 }
 
 String GetBufferStrides(Buffer buffer) {
   Array<PrimExpr> shape = buffer->shape;
   std::vector<unsigned long> strides;
   int dim = buffer->shape.size();
   if (buffer->strides.empty()) {
     strides = GetStrideFromShape(shape);
   } else {
     for (PrimExpr stride : buffer->strides) {
-      if (auto stride_int = as_const_int(stride)) {
-        strides.push_back(*stride_int);
-      }
+      auto stride_int = as_const_int(stride);
+      ICHECK(stride_int) << "dynamic strides are not supported in GetBufferStrides";
+      strides.push_back(*stride_int);
     }
   }
+  ICHECK(strides.size() == static_cast<size_t>(dim)) << "stride rank mismatch";
   String res = "[";
   for (int i = 0; i < dim; i++) {
     if (i > 0)
       res = res + ", ";
     res = res + std::to_string(strides[i]);
   }
   res = res + "]";
   return res;
 }
 
 static std::vector<int> getBroadcastDim(Array<PrimExpr> &buffer_shape0,
                                         Array<PrimExpr> &buffer_shape1) {
   assert(buffer_shape0.size() == buffer_shape1.size());
   std::vector<int> dims;
   for (int i = 0; i < buffer_shape0.size(); i++) {
-    if (*as_const_int(buffer_shape0[i]) == 1 &&
-        *as_const_int(buffer_shape1[i]) != 1) {
+    auto s0 = as_const_int(buffer_shape0[i]);
+    auto s1 = as_const_int(buffer_shape1[i]);
+    ICHECK(s0 && s1) << "broadcast requires static shapes";
+    if (*s0 == 1 && *s1 != 1) {
       dims.emplace_back(i);
     }
-    if (*as_const_int(buffer_shape0[i]) != 1 &&
-        *as_const_int(buffer_shape1[i]) == 1) {
+    if (*s0 != 1 && *s1 == 1) {
       dims.emplace_back(i);
     }
-    assert(*as_const_int(buffer_shape0[i]) == *as_const_int(buffer_shape1[i]));
+    ICHECK(*s0 == *s1) << "shape mismatch for broadcast";
   }
   return dims;
 }
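As a side note, the stride computation the helper performs is plain row-major arithmetic; a minimal Python sketch (illustrative only, with `None` standing in for a non-constant dim) makes the rejected case explicit:

```python
def strides_from_shape(shape):
    """Row-major strides: stride[i] is the product of shape[i+1:].

    Mirrors what GetStrideFromShape computes, but rejects non-constant
    dims (modelled as None) up front instead of silently skipping them.
    """
    if any(d is None for d in shape):
        raise ValueError("dynamic shapes are not supported")
    total = 1
    for d in shape:
        total *= d
    strides = []
    for d in shape:
        total //= d
        strides.append(total)
    return strides
```

For a `[2, 3, 4]` buffer this yields `[12, 4, 1]`, the contiguous layout the C++ helper assumes when `buffer->strides` is empty.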

Comment on lines +185 to +219
std::string GetCastOp(DataType src_type, DataType dst_type) {
  bool srcIsFloat = src_type.is_float() || src_type.is_bfloat16();
  bool srcIsInt = src_type.is_int();
  bool srcIsUInt = src_type.is_uint();
  bool targetIsFloat = dst_type.is_float() || dst_type.is_bfloat16();
  bool targetIsInt = dst_type.is_int();
  bool targetIsUInt = dst_type.is_uint();
  if (srcIsFloat && targetIsInt) {
    return "arith.fptosi";
  } else if (srcIsFloat && targetIsUInt) {
    return "arith.fptoui";
  } else if (srcIsInt && targetIsFloat) {
    return "arith.sitofp";
  } else if (srcIsUInt && targetIsFloat) {
    return "arith.uitofp";
  } else if (targetIsInt) {
    if (dst_type.bits() > src_type.bits()) {
      return "arith.extsi";
    } else {
      return "arith.trunci";
    }
  } else if (targetIsUInt) {
    if (dst_type.bits() > src_type.bits()) {
      return "arith.extui";
    } else {
      return "arith.trunci";
    }
  } else if (targetIsFloat) {
    if (dst_type.bits() > src_type.bits()) {
      return "arith.extf";
    } else {
      return "arith.truncf";
    }
  }
}

⚠️ Potential issue | 🟠 Major

Add a fallback/error path in GetCastOp to avoid UB.

The function can fall through without returning a value for unsupported type combinations, which is undefined behavior.

✅ Suggested fix
   } else if (targetIsFloat) {
     if (dst_type.bits() > src_type.bits()) {
       return "arith.extf";
     } else {
       return "arith.truncf";
     }
   }
+  LOG(FATAL) << "Unsupported cast from " << src_type << " to " << dst_type;
+  return "";
 }
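For intuition, the full decision table (with the fatal fallback) can be modelled in a few lines of Python; the `kind` strings below are an illustration of the dtype predicates, not part of the C++ API:

```python
def cast_op(src_kind, src_bits, dst_kind, dst_bits):
    """Select an MLIR arith cast op name, raising instead of falling through.

    kind is one of "float", "int", "uint" (bfloat16 counts as "float").
    """
    cross = {
        ("float", "int"): "arith.fptosi",
        ("float", "uint"): "arith.fptoui",
        ("int", "float"): "arith.sitofp",
        ("uint", "float"): "arith.uitofp",
    }
    if (src_kind, dst_kind) in cross:
        return cross[(src_kind, dst_kind)]
    widen = {"int": "arith.extsi", "uint": "arith.extui", "float": "arith.extf"}
    narrow = {"int": "arith.trunci", "uint": "arith.trunci", "float": "arith.truncf"}
    if dst_kind in widen:
        return widen[dst_kind] if dst_bits > src_bits else narrow[dst_kind]
    # the combination the original C++ silently falls through on
    raise ValueError(f"unsupported cast {src_kind}{src_bits} -> {dst_kind}{dst_bits}")
```

Equal-width conversions land in the truncate branch, matching the original's strict `>` comparison.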

Comment on lines +1148 to +1149
void CodeGenTileLangCOMMONIR::VisitExpr_(const DivNode *op, std::ostream &os) {
  PrintBinary(op, "<<<divf>>>", os, this);

⚠️ Potential issue | 🟠 Major

Implement DivNode emission (currently emits a placeholder).

<<<divf>>> yields invalid IR. Emit a real div op or hard‑fail for unsupported dtypes.

✅ Suggested fix
 void CodeGenTileLangCOMMONIR::VisitExpr_(const DivNode *op, std::ostream &os) {
-  PrintBinary(op, "<<<divf>>>", os, this);
+  if (op->dtype.is_int()) {
+    PrintBinary(op, "divsi", os, this);
+  } else if (op->dtype.is_uint()) {
+    PrintBinary(op, "divui", os, this);
+  } else if (op->dtype.is_float()) {
+    PrintBinary(op, "divf", os, this);
+  } else {
+    LOG(FATAL) << "Unsupported dtype for DivNode: " << op->dtype;
+  }
 }
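The dispatch in the fix is essentially a three-entry table keyed on the dtype class; sketched in Python (op names taken from the suggested fix, `kind` strings are illustrative):

```python
DIV_OPS = {"int": "divsi", "uint": "divui", "float": "divf"}

def div_op(kind):
    """Map a dtype class to its division op, failing loudly otherwise."""
    try:
        return DIV_OPS[kind]
    except KeyError:
        raise ValueError(f"unsupported dtype for DivNode: {kind}") from None
```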

Comment on lines +20 to +24
from typing import (
    Any,
    List,
    Union,
)

🛠️ Refactor suggestion | 🟠 Major

Remove duplicate imports flagged by static analysis.

Any and Union are already imported on line 8. Remove the redundant imports.

-from typing import (
-    Any,
-    List,
-    Union,
-)
+from typing import List
🧰 Tools
🪛 Ruff (0.14.14)

21-21: Redefinition of unused Any from line 8: Any redefined here

Remove definition: Any

(F811)


23-23: Redefinition of unused Union from line 8: Union redefined here

Remove definition: Union

(F811)


        }
        block_indices = {"x": None, "y": None, "z": None}
        for dim, pattern in patterns.items():
            match = re.search(pattern, str(str(self.mod)))

🛠️ Refactor suggestion | 🟠 Major

Redundant str(str(...)) call.

The double str() conversion is unnecessary.

-            match = re.search(pattern, str(str(self.mod)))
+            match = re.search(pattern, str(self.mod))
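Once the redundant conversion is dropped, the scan itself is easy to exercise in isolation; the IR snippet and patterns below are hypothetical placeholders, not the real CommonIR text:

```python
import re

def extract_grid_dims(mod_text, patterns):
    """Pull one integer per grid dimension out of the module's textual form."""
    dims = {}
    for dim, pattern in patterns.items():
        match = re.search(pattern, mod_text)  # a single str() upstream suffices
        dims[dim] = int(match.group(1)) if match else None
    return dims

# hypothetical launch line and patterns, for illustration only
text = "gpu.launch blocks(%bx, %by, %bz) in (128, 4, 1)"
patterns = {
    "x": r"in \((\d+),",
    "y": r"in \(\d+, (\d+),",
    "z": r"in \(\d+, \d+, (\d+)\)",
}
```

A missing dimension simply yields `None`, mirroring the `block_indices` initialization above.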

Comment on lines +39 to +56
    # Estimate the runtime of the function
    start_event = torch.npu.Event(enable_timing=True)
    end_event = torch.npu.Event(enable_timing=True)
    start_event.record()
    for _ in range(5):
        cache.zero_()
        fn()
    end_event.record()
    torch.npu.synchronize()
    estimate_ms = start_event.elapsed_time(end_event) / 5

    # compute number of warmup and repeat
    n_warmup = max(1, int(warmup / estimate_ms))
    n_repeat = max(1, int(rep / estimate_ms))
    if _n_warmup > 0:
        n_warmup = _n_warmup
    if _n_repeat > 0:
        n_repeat = _n_repeat

⚠️ Potential issue | 🟠 Major

Guard against zero/near-zero timing estimates in do_bench.

elapsed_time() can return 0 for very fast kernels or coarse timer resolution, which triggers a divide-by-zero when computing n_warmup/n_repeat. Clamp to a small positive value to keep benchmarking stable.

✅ Suggested fix
-    estimate_ms = start_event.elapsed_time(end_event) / 5
+    estimate_ms = start_event.elapsed_time(end_event) / 5
+    if estimate_ms <= 0:
+        estimate_ms = 1e-6
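The iteration-count arithmetic with the clamp applied is pure Python and can be checked without torch; `plan_bench` below is an illustrative extraction, not a function in this PR:

```python
def plan_bench(estimate_ms, warmup, rep, n_warmup_override=0, n_repeat_override=0):
    """Derive iteration counts from a timing estimate, guarding estimate_ms <= 0."""
    estimate_ms = max(estimate_ms, 1e-6)  # clamp: coarse timers can report 0
    n_warmup = max(1, int(warmup / estimate_ms))
    n_repeat = max(1, int(rep / estimate_ms))
    if n_warmup_override > 0:
        n_warmup = n_warmup_override
    if n_repeat_override > 0:
        n_repeat = n_repeat_override
    return n_warmup, n_repeat
```

With the clamp in place, a 0 ms estimate produces large-but-finite counts instead of a ZeroDivisionError.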

Comment on lines +94 to +113
    def get_tensor(param: KernelParam) -> torch.Tensor:
        dtype: torch.dtype = param.torch_dtype()
        device = torch.npu.current_device()

        if hasattr(param, "shape") and not param.shape:
            raise ValueError(
                f"TensorType must have a shape, but got {type(param)}, "
                "likely you are trying to generate a random tensor with a dynamic symbolic shape."
            )

        # Check if with dynamic symbolic shape
        for shape in param.shape:
            if isinstance(shape, tir.Var):
                raise ValueError(
                    f"TensorType must have a static shape, but got {shape}, "
                    "likely you are trying to generate a random tensor with a dynamic symbolic shape."
                )

        shape = list(map(int, param.shape))
        if supply_type == TensorSupplyType.Auto:

⚠️ Potential issue | 🟠 Major

Allow scalar KernelParam shapes instead of raising.

KernelParam.is_scalar() uses an empty shape; the current guard raises and breaks profiling for scalar parameters. Treat empty shapes as valid (0‑dim tensors).

✅ Suggested fix
-        if hasattr(param, "shape") and not param.shape:
-            raise ValueError(
-                f"TensorType must have a shape, but got {type(param)}, "
-                "likely you are trying to generate a random tensor with a dynamic symbolic shape."
-            )
-
-        # Check if with dynamic symbolic shape
-        for shape in param.shape:
-            if isinstance(shape, tir.Var):
-                raise ValueError(
-                    f"TensorType must have a static shape, but got {shape}, "
-                    "likely you are trying to generate a random tensor with a dynamic symbolic shape."
-                )
-
-        shape = list(map(int, param.shape))
+        if hasattr(param, "shape") and not param.shape:
+            # scalar parameter is valid
+            shape = []
+        else:
+            # Check if with dynamic symbolic shape
+            for dim in param.shape:
+                if isinstance(dim, tir.Var):
+                    raise ValueError(
+                        f"TensorType must have a static shape, but got {dim}, "
+                        "likely you are trying to generate a random tensor with a dynamic symbolic shape."
+                    )
+            shape = list(map(int, param.shape))
🧰 Tools
🪛 Ruff (0.14.14)

99-102: Avoid specifying long messages outside the exception class

(TRY003)


107-110: Prefer TypeError exception for invalid type

(TRY004)


107-110: Avoid specifying long messages outside the exception class

(TRY003)
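The normalization rule the fix encodes — empty shape means a 0-dim tensor, symbolic dims are rejected — can be isolated in a torch-free sketch (`SymbolicDim` stands in for `tir.Var`):

```python
class SymbolicDim:
    """Stand-in for tir.Var: a dimension with no static value."""

def normalize_shape(shape):
    """Return a list of static ints; [] denotes a valid scalar (0-dim) tensor."""
    if not shape:  # scalar parameter: valid, do not raise
        return []
    for dim in shape:
        if isinstance(dim, SymbolicDim):
            raise ValueError(f"static shape required, got symbolic dim {dim!r}")
    return [int(d) for d in shape]
```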


Comment on lines +178 to +192
    def _legalize_result_idx(self, result_idx: Optional[List[int]] = None) -> List[int]:
        params = self.params
        # result_idx is a list of indices of the output tensors
        if result_idx is None:
            result_idx = []
        elif isinstance(result_idx, int):
            if result_idx > len(params) or result_idx < -len(params):
                raise ValueError(
                    f"result_idx should be an integer between {-len(params)} and {len(params) - 1}")
            if result_idx < 0:
                result_idx = len(params) + result_idx
            result_idx = [result_idx]
        elif not isinstance(result_idx, list):
            raise ValueError("result_idx should be a list of integers")
        return result_idx

⚠️ Potential issue | 🟠 Major

Normalize list result_idx (negative/out‑of‑range) and fix the >= check.

The int branch allows result_idx == len(params) (out of range). The list branch doesn’t normalize negatives or validate bounds, which mislabels outputs.

✅ Suggested fix
-        elif isinstance(result_idx, int):
-            if result_idx > len(params) or result_idx < -len(params):
+        elif isinstance(result_idx, int):
+            if result_idx >= len(params) or result_idx < -len(params):
                 raise ValueError(
                     f"result_idx should be an integer between {-len(params)} and {len(params) - 1}")
             if result_idx < 0:
                 result_idx = len(params) + result_idx
             result_idx = [result_idx]
-        elif not isinstance(result_idx, list):
-            raise ValueError("result_idx should be a list of integers")
+        elif isinstance(result_idx, list):
+            normalized = []
+            for idx in result_idx:
+                if not isinstance(idx, int):
+                    raise ValueError("result_idx should be a list of integers")
+                if idx >= len(params) or idx < -len(params):
+                    raise ValueError(
+                        f"result_idx should be an integer between {-len(params)} and {len(params) - 1}")
+                if idx < 0:
+                    idx = len(params) + idx
+                normalized.append(idx)
+            result_idx = normalized
+        else:
+            raise ValueError("result_idx should be a list of integers")
         return result_idx
🧰 Tools
🪛 Ruff (0.14.14)

185-186: Avoid specifying long messages outside the exception class

(TRY003)


191-191: Avoid specifying long messages outside the exception class

(TRY003)
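Because the routine is self-contained, the fixed behaviour is easy to unit-test; this sketch mirrors the suggested logic with `n` standing in for `len(self.params)`:

```python
def legalize_result_idx(result_idx, n):
    """Normalize result_idx to a list of in-range, non-negative indices."""
    if result_idx is None:
        return []
    if isinstance(result_idx, int):
        result_idx = [result_idx]
    if not isinstance(result_idx, list):
        raise ValueError("result_idx should be a list of integers")
    normalized = []
    for idx in result_idx:
        if not isinstance(idx, int):
            raise ValueError("result_idx should be a list of integers")
        if idx >= n or idx < -n:  # >= n: forbid the off-by-one the int branch allowed
            raise ValueError(f"result_idx should be between {-n} and {n - 1}")
        normalized.append(idx + n if idx < 0 else idx)
    return normalized
```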



@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 2

🧹 Nitpick comments (1)
tilelang/utils/target.py (1)

27-27: Missing blank line after module-level constant.

Per PEP 8, there should be a blank line between module-level definitions and function definitions.

 COMMONIR_enabled = os.environ.get('USE_COMMONIR', '0') in ('1', 'true', 'on')
+
 def describe_supported_targets() -> dict[str, str]:

<img src=./images/MatmulExample.png />

## Latest News
- 01/29/2025 🚀: Added CommonIR support, enabling compilation on wider domestic GPU through [DLCompiler](https://github.com/DeepLink-org/DLCompiler) integration.

⚠️ Potential issue | 🟡 Minor

Date year appears incorrect.

The entry is dated 01/29/2025, but given the current date (January 2026) and the reverse-chronological ordering of the news section, this should likely be 01/29/2026. With the current date, this entry would be out of order—it should appear much lower in the list, after entries from later in 2025.

📝 Suggested fix
-- 01/29/2025 🚀: Added CommonIR support, enabling compilation on wider domestic GPU through [DLCompiler](https://github.com/DeepLink-org/DLCompiler) integration.
+- 01/29/2026 🚀: Added CommonIR support, enabling compilation on wider domestic GPU through [DLCompiler](https://github.com/DeepLink-org/DLCompiler) integration.

Comment on lines +140 to +144
    if COMMONIR_enabled:
        return_var = "commonir"
        if return_object:
            return Target(return_var)
        return return_var

⚠️ Potential issue | 🟠 Major

COMMONIR_enabled unconditionally overrides any explicit target argument.

When USE_COMMONIR=1, this early return bypasses all target detection logic, ignoring even explicit user-specified targets like target="cuda". This may be intentional for a global opt-in, but it creates surprising behavior where the user's explicit target choice is silently overridden.

Consider adjusting the logic to respect explicit target specifications:

     return_var: str | Target = target
-    if COMMONIR_enabled:
+    if target == "commonir" or (target == "auto" and COMMONIR_enabled):
         return_var = "commonir"
         if return_object:
             return Target(return_var)
         return return_var

This would allow USE_COMMONIR=1 to affect only target="auto" (the default), while explicit targets like target="cuda" remain honored.


@LeiWang1999 LeiWang1999 marked this pull request as draft January 30, 2026 08:54