feat(commonir): add commonir abstract backend #1754
base: main
📝 Walkthrough

Adds a CommonIR backend: a DLCompiler git submodule, a new C++ COMMONIR codegen and runtime integration, Python JIT/profiler for CommonIR, examples, and pipeline changes to select and short-circuit to the CommonIR path.
Sequence Diagram

```mermaid
sequenceDiagram
    participant User
    participant TileLang as TileLang Compiler
    participant Lower as Lower/Optimize
    participant CodeGen as CommonIR CodeGen
    participant JIT as CommonIR JIT
    participant NPU as NPU Device
    User->>TileLang: submit prim_func + target
    TileLang->>Lower: lower(mod)
    Note over Lower: if target == "commonir" or COMMONIR_enabled<br/>apply limited passes and return MLIR
    Lower->>CodeGen: emit CommonIR/C source
    CodeGen->>JIT: provide C source + metadata
    JIT->>JIT: parse grid & signature, compile via CommonIRCompiler
    JIT-->>User: return JitKernel_NPU
    User->>NPU: execute kernel
    NPU-->>User: results
```
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~75 minutes
🚥 Pre-merge checks: ✅ 2 passed | ❌ 1 failed (1 warning)
Actionable comments posted: 18
🤖 Fix all issues with AI agents
In `@examples/commonir/gemm.py`:
- Line 9: The module-level call to torch.npu.current_device() (stored in the
device variable) can fail at import if NPU isn't available; move the device
detection inside main() (or another runtime entrypoint) and replace the
module-level device assignment with deferred lookup so imports succeed without
an active NPU context; update any references that used the module-level device
variable to retrieve the device inside main() or pass it down to functions that
need it.
In `@README.md`:
- Line 16: The changelog entry "01/22/2025 🚀: Added CommonIR support, enabling
compilation on wider domestic GPU through
[DLCompiler](https://github.com/DeepLink-org/DLCompiler) integration." is out of
chronological order; either correct its date if it should be the latest (e.g.,
change year to 2026) or move this exact entry string to its proper position in
the reverse-chronological list (after the "02/10/2025" entry and before the
"01/20/2025" entry) so the README's news section remains sorted newest-first.
In `@src/target/codegen_commonir.cc`:
- Around line 185-219: GetCastOp can fall through without returning for
unsupported src/dst combinations; add a clear fallback at the end of GetCastOp
(referencing GetCastOp, src_type, dst_type, DataType::bits) that either returns
a safe sentinel (e.g., empty string) or better, raises a deterministic error
(throw std::runtime_error or call your project's fatal logging) including the
src/dst type info so undefined behavior is avoided and debugging is easier.
Ensure the fallback is reached only when none of the existing branches match and
include descriptive text with src_type/dst_type in the error.
- Around line 1301-1314: The is_arg branch currently emits a single "?x" and
collapses multidimensional argument memrefs; update the branch in GetMemrefInfo
(the block using memrefObj->is_arg, memrefObj->shape and building memref_type)
to iterate over memrefObj->shape and append a "?x" for each dimension (similar
to the non-arg loop), preserving the original rank, then call
PrintType(memrefObj->dtype, memref_type) as before.
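The rank-preserving string construction this item asks for can be sketched in Python (a hypothetical helper mirroring the C++ logic; the real code builds the string inside GetMemrefInfo via PrintType):

```python
def memref_type_for(shape, dtype_str):
    """Build an MLIR memref type string with one dynamic '?' per dimension.

    Hypothetical Python mirror of the requested fix: iterate over the
    argument's shape and append one "?x" per dimension so the original
    rank is kept, instead of collapsing every argument to a single "?x".
    """
    dims = "".join("?x" for _ in shape)  # one "?x" per dimension
    return f"memref<{dims}{dtype_str}>"
```

For example, a rank-2 f16 argument yields `memref<?x?xf16>` rather than the collapsed `memref<?xf16>`.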
- Around line 1148-1149: The VisitExpr_ implementation for DivNode currently
emits a placeholder "<<<divf>>>"; replace this by inspecting op->dtype (use
op->dtype() or op->dtype().base_type()) and emit the correct division IR token
for floats vs integers (e.g., a floating-point divide string for float types and
an integer divide string or signed/unsigned variant for integer types) via the
same PrintBinary call (e.g., PrintBinary(op, "<correct-div-op>", os, this)); if
the dtype is unsupported, hard-fail (throw or LOG(FATAL)) with a clear message
referencing DivNode so callers see the unsupported dtype instead of invalid IR.
- Around line 92-147: The helpers GetStrideFromShape, GetBufferStrides, and
getBroadcastDim dereference as_const_int() results without null checks, which
breaks on non-constant (symbolic) shapes/strides; update them to validate
as_const_int(...) before dereferencing and handle non-constant cases explicitly
(e.g., return an error, throw/log an explanatory message, or implement dynamic
stride handling) — in GetStrideFromShape skip or error on non-constant shape
elements and avoid building shape_int from nulls; in GetBufferStrides check each
buffer->strides element via as_const_int and handle the null case; in
getBroadcastDim guard every *as_const_int(...) access, ensure shapes are
constant or emit a clear error/exception when symbolic dimensions are
encountered.
In `@src/target/codegen_commonir.h`:
- Around line 85-86: The header declares PrintExtraAttrs(const PrimFunc &f) but
there's no implementation causing linker/vtable errors; add a definition in
codegen_commonir.cc that either implements the intended behavior or a minimal
stub (e.g., empty body) to satisfy the linker. Locate the declaration of
PrintExtraAttrs in src/target/codegen_commonir.h and add a corresponding
function definition in src/target/codegen_commonir.cc matching the exact
signature (void PrintExtraAttrs(const PrimFunc &f)) and ensure it uses/compiles
with the PrimFunc type and existing namespace so symbol linkage matches.
- Around line 14-17: The header codegen_commonir.h uses std::map and std::vector
but only includes <unordered_map> and <string>; add the missing includes for
<map> and <vector> at the top of the header so symbols like std::map and
std::vector used anywhere in codegen_commonir.h are always available regardless
of include order.
In `@tilelang/engine/lower.py`:
- Around line 261-266: Remove the leftover debug print by deleting the
commented-out "print(codegen_mod.inspect_source())" line and then address the
host_mod being passed as None to CompiledArtifact: either update this call in
lower.py (where target.kind.name / COMMONIR_enabled lead to device_codegen(mod,
target)) to supply a valid tvm.IRModule host_mod, or confirm and update the
CompiledArtifact type/signature and all downstream consumers to accept
Optional[tvm.IRModule] (search for CompiledArtifact usages to ensure None is
handled safely), keeping the call site and symbol names device_codegen and
CompiledArtifact consistent.
In `@tilelang/jit/__init__.py`:
- Around line 101-104: The current routing uses COMMONIR_enabled with lower
precedence than an explicit target and returns the wrong type; change the
condition so an explicit target argument takes priority over the global flag
(i.e., check target == 'commonir' first or require COMMONIR_enabled only when
target is None/default), and ensure the returned object matches the declared
return type by either returning a JITKernel instance or updating the
function/generic signature to allow JitKernel_NPU; specifically update the
branch using COMMONIR_enabled, the call to
compiler_common(out_idx=out_idx)/compiler_commonir.compile(func), and the
function's return annotation (JITKernel[_KP,_T] vs JitKernel_NPU) so the runtime
branch and type signature agree.
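The precedence rule described above can be sketched as follows (illustrative names only, not the real tilelang signatures):

```python
def resolve_backend(target=None, commonir_flag=False):
    """Sketch of the requested routing: an explicit `target` argument
    always wins; the global CommonIR flag only applies when no target
    was given. Function and parameter names are illustrative."""
    if target is not None:
        return target  # user's explicit choice takes priority
    return "commonir" if commonir_flag else "auto"
```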
- Around line 40-41: The unconditional top-level import of compiler_common from
.jit_commonir causes ImportError on systems without the NPU stack because
jit_commonir imports torch_npu at module scope; move the import into the
conditional branch where compiler_common is actually needed (perform a local
import of .jit_commonir and access compiler_common inside that block) so
jit_commonir’s torch_npu dependency is only resolved when CommonIR functionality
is used.
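The deferred-import pattern being requested looks like this in miniature (stdlib `json` stands in for `.jit_commonir` and its torch_npu dependency):

```python
def summarize(data, as_json=False):
    """Demonstrates the lazy-import pattern: the dependency is imported
    only on the branch that actually needs it, so merely importing this
    module never fails on hosts where the dependency is missing."""
    if as_json:
        import json  # resolved only when this branch executes
        return json.dumps(data)
    return str(data)
```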
In `@tilelang/jit/jit_commonir.py`:
- Line 88: The call uses redundant double conversion str(str(self.mod)) in the
re.search invocation; update the expression in jit_commonir.py (the line that
sets match via re.search(pattern, ...)) to pass a single string conversion—e.g.,
replace str(str(self.mod)) with str(self.mod)—so the regex runs against the same
string value without the unnecessary nested str() call.
- Around line 20-24: Remove the duplicate typing imports by editing the import
block in jit_commonir.py: keep only the new symbol(s) needed here (e.g., List)
and delete Any and Union which are already imported earlier (line 8), ensuring
there are no redundant imports of Any or Union in the file.
- Line 13: The unconditional top-level import of torch_npu in
tilelang.jit.jit_commonir breaks imports on systems without the NPU stack;
change it to a lazy or guarded import (e.g., move the import into the
function(s) or method(s) that actually use torch_npu or wrap the import in a
try/except ImportError and surface a clear runtime error only when CommonIR
functionality is invoked). Locate the module-level import "import torch_npu" in
jit_commonir.py and replace it with an on-demand import inside the specific
functions/classes that require torch_npu (or a guarded try/except that sets a
flag and raises a descriptive error when CommonIR APIs are called on non-NPU
hosts).
- Around line 1-14: Remove the seven unused imports from the top of
tilelang/jit/jit_commonir.py: subprocess, dataclass (from dataclasses), Path
(from pathlib), ModuleType (from types), sysconfig, pybind11, and functools;
delete their import tokens or whole import statements referencing them and keep
the remaining imports intact (os, re, tempfile, typing items, shutil, pybind11
only if used elsewhere, torch, torch_npu) and run lint/tests to confirm nothing
else references these symbols (search for subprocess, dataclass, Path,
ModuleType, sysconfig, pybind11, functools to ensure safe removal).
In `@tilelang/jit/kernel_npu.py`:
- Around line 39-56: The timing estimate computed from
start_event.elapsed_time(end_event) can be zero; clamp estimate_ms to a small
positive floor (e.g., 1e-3 ms) before using it to compute n_warmup and n_repeat
to avoid divide-by-zero and extreme counts. Locate the block that creates
start_event/end_event and sets estimate_ms, then set estimate_ms =
max(estimate_ms, <small_positive>) prior to computing n_warmup/n_repeat (keeping
later overrides for _n_warmup/_n_repeat intact); reference symbols: start_event,
end_event, estimate_ms, n_warmup, n_repeat, _n_warmup, _n_repeat, warmup, rep.
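The clamping described above can be sketched as (the 1e-3 ms floor and the warmup/rep constants are illustrative, not the values in kernel_npu.py):

```python
def plan_bench(estimate_ms, warmup=25, rep=100):
    """Derive warmup/repeat counts from a timed estimate. Clamping the
    estimate to a small positive floor prevents divide-by-zero and
    absurd iteration counts when the events report 0 ms."""
    estimate_ms = max(estimate_ms, 1e-3)  # floor at 1e-3 ms
    n_warmup = max(1, int(warmup / estimate_ms))
    n_repeat = max(1, int(rep / estimate_ms))
    return n_warmup, n_repeat
```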
- Around line 94-113: The get_tensor function currently treats empty
KernelParam.shape as an error, but KernelParam.is_scalar() uses an empty shape
for 0‑dim tensors; remove the early ValueError and allow empty shapes as valid
scalars. Update the guard in get_tensor (the block that checks "if
hasattr(param, 'shape') and not param.shape:") to not raise for empty shapes,
and ensure the subsequent loop that checks for tir.Var only runs when
param.shape is non-empty (e.g., iterate only if param.shape). This preserves the
tir.Var static-shape check while allowing scalar KernelParam instances to be
handled as 0‑dim tensors.
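The intended guard can be sketched like this (an `isinstance(dim, int)` check stands in for the real `tir.Var` test; not the actual class):

```python
def check_param_shape(shape):
    """An empty shape is a valid 0-dim (scalar) parameter, mirroring
    KernelParam.is_scalar(); the static-shape scan only runs when the
    shape is non-empty."""
    if not shape:
        return ()  # scalar: no error, nothing to scan
    for dim in shape:  # loop only reached for non-empty shapes
        if not isinstance(dim, int):  # stand-in for the tir.Var check
            raise ValueError(f"dynamic dimension {dim!r}: static shape required")
    return tuple(shape)
```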
- Around line 178-192: In _legalize_result_idx, fix the int branch bounds check
to forbid result_idx >= len(self.params) (not just >) and normalize negative
indices as currently done; additionally, when result_idx is a list (the else
branch), validate each element is an int, normalize any negative indices by
adding len(self.params), and raise ValueError for any index < 0 or >=
len(self.params); update references to params via self.params and return the
normalized list of ints.
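The requested behavior can be sketched as a standalone function (names are illustrative; the real method reads `self.params`):

```python
def legalize_result_idx(result_idx, n_params):
    """Normalize output-tensor indices: accept an int or a list, map
    negative indices by adding n_params, and reject anything outside
    [0, n_params) -- note `>=`, so idx == n_params is also forbidden."""
    idx_list = [result_idx] if isinstance(result_idx, int) else list(result_idx)
    normalized = []
    for idx in idx_list:
        if not isinstance(idx, int):
            raise ValueError(f"result_idx entry must be int, got {idx!r}")
        if idx < 0:
            idx += n_params  # normalize negative index
        if idx < 0 or idx >= n_params:
            raise ValueError(f"result_idx {idx} out of range [0, {n_params})")
        normalized.append(idx)
    return normalized
```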
🧹 Nitpick comments (15)
3rdparty/DLCompiler (2)

1-1: Document the rationale for this specific commit.

Consider adding documentation (e.g., in a CHANGELOG, commit message, or README) explaining why this specific commit (51ef7ccd953557fc1e3e7f46efcd0c3fdb88c53d) was chosen. This helps with:
- Future maintenance and upgrades
- Understanding compatibility requirements
- Debugging issues related to this version

1-1: Establish a maintenance and update strategy for the submodule.

Third-party submodules require ongoing maintenance. Consider:
- Setting up automated checks for updates and security patches
- Documenting the update process and testing requirements
- Establishing ownership for monitoring and updating this dependency
- Adding submodule update guidelines to your contribution docs
examples/commonir/add_vector.py (3)

3-3: Remove unused import.

The `os` module is imported but never used in this file.

🧹 Proposed fix
```diff
-import os
-
 import tilelang
```
35-35: Consider adding explicit `target='commonir'` for consistency.

The `gemm.py` example in the same directory explicitly specifies `target='commonir'` (see `examples/commonir/gemm.py` line 38), but this example omits it. For clarity and consistency across CommonIR examples, consider making the target explicit.

♻️ Proposed fix
```diff
-    compiled_kernel = tilelang.compile(func)
+    compiled_kernel = tilelang.compile(func, target='commonir')
```
37-39: Replace `eval()` with `getattr()` for safer dtype resolution.

Using `eval()` to construct torch dtypes is flagged as potentially insecure (S307). While the `dtype` variable is controlled in this example, `getattr()` is the idiomatic and safer approach.

🛡️ Proposed fix
```diff
-    v1 = torch.randn(size=[seq_len], dtype=eval("torch." + dtype)).npu()
-    v2 = torch.randn(size=[seq_len], dtype=eval("torch." + dtype)).npu()
-    v3 = torch.zeros(size=[seq_len], dtype=eval("torch." + dtype)).npu()
+    torch_dtype = getattr(torch, dtype)
+    v1 = torch.randn(size=[seq_len], dtype=torch_dtype).npu()
+    v2 = torch.randn(size=[seq_len], dtype=torch_dtype).npu()
+    v3 = torch.zeros(size=[seq_len], dtype=torch_dtype).npu()
```

examples/commonir/add_vector_profiler.py (2)
3-13: Remove unused imports.

Several imports are not used in this file: `os`, `partial` from functools, `time`, `numpy`, and the typing imports (`Callable`, `Optional`, `Union`, `List`).

🧹 Proposed fix
```diff
-import os
-
 import tilelang
 import tilelang.language as T
-from functools import partial
 import torch
 import torch_npu
-import time
-import numpy as np
-from typing import Callable, Optional, Union, List
```
19-35: Consider extracting the shared `vec_add` kernel to reduce duplication.

This `vec_add` function is identical to the one in `examples/commonir/add_vector.py`. Consider extracting the shared kernel definition to a common module to avoid code duplication across examples.

examples/commonir/gemm.py (1)
36-49: Consider using consistent dimension constants.

`SIZEALL = 1024` duplicates the matrix dimensions already passed to `matmul(1024, 1024, 1024, ...)`. Consider defining the dimensions once to avoid inconsistency if one is changed without the other.

♻️ Proposed fix
```diff
 def main():
-    func = matmul(1024, 1024, 1024, 128, 128, 32)
-    kernel = tilelang.compile(func, target='commonir')
-    SIZEALL = 1024
+    M, N, K = 1024, 1024, 1024
+    func = matmul(M, N, K, 128, 128, 32)
+    kernel = tilelang.compile(func, target='commonir')
     torch.manual_seed(0)
-    a = torch.rand((SIZEALL, SIZEALL), dtype=dtype, device=device) - 0.5
-    b = torch.rand((SIZEALL, SIZEALL), dtype=dtype, device=device) - 0.5
-    result = torch.zeros((SIZEALL, SIZEALL), dtype=dtype, device=device)
+    a = torch.rand((M, K), dtype=dtype, device=device) - 0.5
+    b = torch.rand((K, N), dtype=dtype, device=device) - 0.5
+    result = torch.zeros((M, N), dtype=dtype, device=device)
```

CMakeLists.txt (2)
188-194: Simplify source file addition.

Using `file(GLOB ...)` for exactly two explicitly named files is unnecessary. A direct `list(APPEND)` is clearer and avoids the minor overhead of globbing.

♻️ Proposed fix
```diff
 if(USE_COMMONIR)
-  file(GLOB TILE_LANG_COMMONIR_SRCS
+  list(APPEND TILE_LANG_SRCS
     src/target/codegen_commonir.cc
     src/target/rt_mod_commonir.cc
   )
-  list(APPEND TILE_LANG_SRCS ${TILE_LANG_COMMONIR_SRCS})
 endif()
```
173-174: Consider adding COMMONIR to `TILELANG_BACKENDS` for consistency.

COMMONIR is not part of the `TILELANG_BACKENDS` list (line 68) and is not configured via `tilelang_define_backend_option`. Unlike CUDA, ROCM, and METAL, it doesn't benefit from the backend option caching logic and lacks documentation. If CommonIR is a first-class backend, add it to `TILELANG_BACKENDS` alongside the other backends to follow the established pattern.

Additionally, the source file inclusion at lines 189-192 uses `file(GLOB)` to match just two hardcoded filenames; consider using `list(APPEND)` directly instead for clarity.

tilelang/utils/target.py (2)
27-28: Missing blank line between module constant and function definition.

PEP 8 recommends two blank lines before top-level function definitions. Also consider adding a brief comment explaining the purpose of this flag.

```diff
 COMMONIR_enabled = os.environ.get('USE_COMMONIR', '0') in ('1', 'true', 'on')
+
+
 def describe_supported_targets() -> dict[str, str]:
```
109-113: `COMMONIR_enabled` overrides an explicitly user-specified target without documentation.

When `USE_COMMONIR=1` is set, this function returns `"commonir"` regardless of the `target` argument passed by the user. This behavior isn't documented in the docstring.

Consider either:
- Documenting this behavior in the docstring
- Only applying the override when `target == "auto"`

```diff
+    # Note: When USE_COMMONIR environment variable is set, CommonIR target
+    # takes precedence over any user-specified target.
     if COMMONIR_enabled:
         return_var = "commonir"
```

tilelang/jit/jit_commonir.py (3)
36-36: Replace `print` statements with the `logging` module for proper log level control in production.

```diff
+import logging
+logger = logging.getLogger(__name__)
 ...
-    print(f"mod is in compiler_common: \n {mod}")
+    logger.debug(f"mod is in compiler_common: \n {mod}")
 ...
-    print(self.mlir_content)
+    logger.debug(self.mlir_content)
```
28-33: Fix type hints to use explicit `Optional`.

PEP 484 prohibits implicit `Optional`. Parameters with a `= None` default should be explicitly typed as `Optional`.

```diff
+from typing import Optional
+
 class compiler_common:
     def __init__(self,
-                 out_idx: Union[List[int], int] = None,
+                 out_idx: Optional[Union[List[int], int]] = None,
                  ):
         self.out_idx = out_idx

-    def compile(self, mod: PrimFunc = None) -> JitKernel_NPU:
+    def compile(self, mod: Optional[PrimFunc] = None) -> JitKernel_NPU:
```
97-119: Improve exception handling by avoiding a broad `except Exception`.

The file I/O methods catch `Exception` broadly and silently return `False`/`None`. This can mask unexpected errors. Consider catching specific exceptions or re-raising with context.

```diff
 def _write_mlir_file(self, file_path):
     try:
         with open(file_path, "w", encoding="utf-8") as file:
             file.write(self.mlir_content)
-        return True
     except FileNotFoundError:
-        print(f"Error: Directory for '{file_path}' does not exist")
-        return False
-    except Exception as e:
-        print(f"Error occurred while writing to the file: {e}")
-        return False
+        logger.error(f"Directory for '{file_path}' does not exist")
+        raise
+    except OSError as e:
+        logger.error(f"Error writing to file: {e}")
+        raise
+    return True
```
examples/commonir/gemm.py

```python
import torch
import torch_npu
device = torch.npu.current_device()
```
Module-level NPU device access may cause import failure.
torch.npu.current_device() is called at module load time, which will raise an error if NPU is not available or torch_npu is not properly initialized. Consider moving this inside main() to allow the module to be imported without an active NPU context.
🛡️ Proposed fix
```diff
 import torch
 import torch_npu
-device = torch.npu.current_device()
 dtype = torch.float16
 def matmul(M, N, K, block_M, block_N, block_K, dtype="float16", accum_dtype="float"):
     ...
 def main():
+    device = torch.npu.current_device()
     func = matmul(1024, 1024, 1024, 128, 128, 32)
```
README.md (Outdated)

```diff
 <img src=./images/MatmulExample.png />

 ## Latest News
+- 01/22/2025 🚀: Added CommonIR support, enabling compilation on wider domestic GPU through [DLCompiler](https://github.com/DeepLink-org/DLCompiler) integration.
```
Date placement appears out of chronological order.
The new entry dated 01/22/2025 is placed at the top, but chronologically it falls between 01/20/2025 and 02/10/2025. If news entries are meant to be in reverse chronological order (newest first), this entry should be placed lower in the list, after 02/10/2025 and before 01/20/2025.
Alternatively, if this is the most recent addition (perhaps the year should be 2026?), please verify the date is correct.
```cpp
std::vector<unsigned long> GetStrideFromShape(Array<tvm::PrimExpr> shape) {
  std::vector<unsigned long> strides;
  unsigned long total_size = 1;
  std::vector<int> shape_int;
  for (PrimExpr s : shape) {
    if (auto s_int = as_const_int(s)) {
      total_size *= *s_int;
      shape_int.push_back(*s_int);
    }
  }
  for (int i = 0; i < shape.size(); i++) {
    total_size /= shape_int[i];
    strides.push_back(total_size);
  }
  return strides;
}

String GetBufferStrides(Buffer buffer) {
  Array<PrimExpr> shape = buffer->shape;
  std::vector<unsigned long> strides;
  int dim = buffer->shape.size();
  if (buffer->strides.empty()) {
    strides = GetStrideFromShape(shape);
  } else {
    for (PrimExpr stride : buffer->strides) {
      if (auto stride_int = as_const_int(stride)) {
        strides.push_back(*stride_int);
      }
    }
  }
  String res = "[";
  for (int i = 0; i < dim; i++) {
    if (i > 0)
      res = res + ", ";
    res = res + std::to_string(strides[i]);
  }
  res = res + "]";
  return res;
}

static std::vector<int> getBroadcastDim(Array<PrimExpr> &buffer_shape0,
                                        Array<PrimExpr> &buffer_shape1) {
  assert(buffer_shape0.size() == buffer_shape1.size());
  std::vector<int> dims;
  for (int i = 0; i < buffer_shape0.size(); i++) {
    if (*as_const_int(buffer_shape0[i]) == 1 &&
        *as_const_int(buffer_shape1[i]) != 1) {
      dims.emplace_back(i);
    }
    if (*as_const_int(buffer_shape0[i]) != 1 &&
        *as_const_int(buffer_shape1[i]) == 1) {
      dims.emplace_back(i);
    }
    assert(*as_const_int(buffer_shape0[i]) == *as_const_int(buffer_shape1[i]));
  }
  return dims;
}
```
Guard against non-constant shapes/strides in stride & broadcast helpers.
as_const_int() is dereferenced without null checks; if any shape/stride is symbolic, this triggers UB or out‑of‑bounds access. Add validation or explicit dynamic handling.
✅ Suggested fix
```diff
 std::vector<unsigned long> GetStrideFromShape(Array<tvm::PrimExpr> shape) {
   std::vector<unsigned long> strides;
   unsigned long total_size = 1;
-  std::vector<int> shape_int;
+  std::vector<int64_t> shape_int;
   for (PrimExpr s : shape) {
-    if (auto s_int = as_const_int(s)) {
-      total_size *= *s_int;
-      shape_int.push_back(*s_int);
-    }
+    auto s_int = as_const_int(s);
+    ICHECK(s_int) << "dynamic shapes are not supported in GetStrideFromShape";
+    total_size *= *s_int;
+    shape_int.push_back(*s_int);
   }
   for (int i = 0; i < shape.size(); i++) {
     total_size /= shape_int[i];
     strides.push_back(total_size);
   }
   return strides;
 }

 String GetBufferStrides(Buffer buffer) {
   Array<PrimExpr> shape = buffer->shape;
   std::vector<unsigned long> strides;
   int dim = buffer->shape.size();
   if (buffer->strides.empty()) {
     strides = GetStrideFromShape(shape);
   } else {
     for (PrimExpr stride : buffer->strides) {
-      if (auto stride_int = as_const_int(stride)) {
-        strides.push_back(*stride_int);
-      }
+      auto stride_int = as_const_int(stride);
+      ICHECK(stride_int) << "dynamic strides are not supported in GetBufferStrides";
+      strides.push_back(*stride_int);
     }
   }
+  ICHECK(strides.size() == static_cast<size_t>(dim)) << "stride rank mismatch";
   String res = "[";
   for (int i = 0; i < dim; i++) {
     if (i > 0)
       res = res + ", ";
     res = res + std::to_string(strides[i]);
   }
   res = res + "]";
   return res;
 }

 static std::vector<int> getBroadcastDim(Array<PrimExpr> &buffer_shape0,
                                         Array<PrimExpr> &buffer_shape1) {
   assert(buffer_shape0.size() == buffer_shape1.size());
   std::vector<int> dims;
   for (int i = 0; i < buffer_shape0.size(); i++) {
-    if (*as_const_int(buffer_shape0[i]) == 1 &&
-        *as_const_int(buffer_shape1[i]) != 1) {
+    auto s0 = as_const_int(buffer_shape0[i]);
+    auto s1 = as_const_int(buffer_shape1[i]);
+    ICHECK(s0 && s1) << "broadcast requires static shapes";
+    if (*s0 == 1 && *s1 != 1) {
       dims.emplace_back(i);
     }
-    if (*as_const_int(buffer_shape0[i]) != 1 &&
-        *as_const_int(buffer_shape1[i]) == 1) {
+    if (*s0 != 1 && *s1 == 1) {
       dims.emplace_back(i);
     }
-    assert(*as_const_int(buffer_shape0[i]) == *as_const_int(buffer_shape1[i]));
+    ICHECK(*s0 == *s1) << "shape mismatch for broadcast";
   }
   return dims;
 }
```
```cpp
std::string GetCastOp(DataType src_type, DataType dst_type) {
  bool srcIsFloat = src_type.is_float() || src_type.is_bfloat16();
  bool srcIsInt = src_type.is_int();
  bool srcIsUInt = src_type.is_uint();
  bool targetIsFloat = dst_type.is_float() || dst_type.is_bfloat16();
  bool targetIsInt = dst_type.is_int();
  bool targetIsUInt = dst_type.is_uint();
  if (srcIsFloat && targetIsInt) {
    return "arith.fptosi";
  } else if (srcIsFloat && targetIsUInt) {
    return "arith.fptoui";
  } else if (srcIsInt && targetIsFloat) {
    return "arith.sitofp";
  } else if (srcIsUInt && targetIsFloat) {
    return "arith.uitofp";
  } else if (targetIsInt) {
    if (dst_type.bits() > src_type.bits()) {
      return "arith.extsi";
    } else {
      return "arith.trunci";
    }
  } else if (targetIsUInt) {
    if (dst_type.bits() > src_type.bits()) {
      return "arith.extui";
    } else {
      return "arith.trunci";
    }
  } else if (targetIsFloat) {
    if (dst_type.bits() > src_type.bits()) {
      return "arith.extf";
    } else {
      return "arith.truncf";
    }
  }
}
```
Add a fallback/error path in GetCastOp to avoid UB.
The function can fall through without returning a value for unsupported type combinations, which is undefined behavior.
✅ Suggested fix
```diff
   } else if (targetIsFloat) {
     if (dst_type.bits() > src_type.bits()) {
       return "arith.extf";
     } else {
       return "arith.truncf";
     }
   }
+  LOG(FATAL) << "Unsupported cast from " << src_type << " to " << dst_type;
+  return "";
 }
```
🤖 Prompt for AI Agents
In `@src/target/codegen_commonir.cc` around lines 185 - 219, GetCastOp can fall
through without returning for unsupported src/dst combinations; add a clear
fallback at the end of GetCastOp (referencing GetCastOp, src_type, dst_type,
DataType::bits) that either returns a safe sentinel (e.g., empty string) or
better, raises a deterministic error (throw std::runtime_error or call your
project's fatal logging) including the src/dst type info so undefined behavior
is avoided and debugging is easier. Ensure the fallback is reached only when
none of the existing branches match and include descriptive text with
src_type/dst_type in the error.
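The fallback requirement above is easy to check in isolation. Here is a hedged Python sketch of the same dispatch table — a hypothetical `get_cast_op` over `(kind, bits)` tuples, not the C++ function itself — showing the deterministic error instead of a silent fall-through:

```python
def get_cast_op(src, dst):
    """Pick an arith.* conversion op name from (kind, bits) type descriptors.
    kind is one of "float", "int", "uint"; any unmatched pair fails loudly."""
    (src_kind, src_bits), (dst_kind, dst_bits) = src, dst
    if src_kind == "float" and dst_kind == "int":
        return "arith.fptosi"
    if src_kind == "float" and dst_kind == "uint":
        return "arith.fptoui"
    if src_kind == "int" and dst_kind == "float":
        return "arith.sitofp"
    if src_kind == "uint" and dst_kind == "float":
        return "arith.uitofp"
    if dst_kind == "int":
        return "arith.extsi" if dst_bits > src_bits else "arith.trunci"
    if dst_kind == "uint":
        return "arith.extui" if dst_bits > src_bits else "arith.trunci"
    if dst_kind == "float":
        return "arith.extf" if dst_bits > src_bits else "arith.truncf"
    # deterministic error with both types in the message, mirroring LOG(FATAL)
    raise ValueError(f"unsupported cast from {src} to {dst}")
```

Any (src, dst) pair that matches no branch now fails with both types in the message, which is what the suggested C++ fallback achieves.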
```cpp
void CodeGenTileLangCOMMONIR::VisitExpr_(const DivNode *op, std::ostream &os) {
  PrintBinary(op, "<<<divf>>>", os, this);
```
Implement DivNode emission (currently emits a placeholder).
`<<<divf>>>` yields invalid IR. Emit a real div op or hard-fail for unsupported dtypes.
✅ Suggested fix

```diff
 void CodeGenTileLangCOMMONIR::VisitExpr_(const DivNode *op, std::ostream &os) {
-  PrintBinary(op, "<<<divf>>>", os, this);
+  if (op->dtype.is_int()) {
+    PrintBinary(op, "divsi", os, this);
+  } else if (op->dtype.is_uint()) {
+    PrintBinary(op, "divui", os, this);
+  } else if (op->dtype.is_float()) {
+    PrintBinary(op, "divf", os, this);
+  } else {
+    LOG(FATAL) << "Unsupported dtype for DivNode: " << op->dtype;
+  }
 }
```

📝 Committable suggestion
```cpp
void CodeGenTileLangCOMMONIR::VisitExpr_(const DivNode *op, std::ostream &os) {
  if (op->dtype.is_int()) {
    PrintBinary(op, "divsi", os, this);
  } else if (op->dtype.is_uint()) {
    PrintBinary(op, "divui", os, this);
  } else if (op->dtype.is_float()) {
    PrintBinary(op, "divf", os, this);
  } else {
    LOG(FATAL) << "Unsupported dtype for DivNode: " << op->dtype;
  }
}
```
🤖 Prompt for AI Agents
In `@src/target/codegen_commonir.cc` around lines 1148 - 1149, The VisitExpr_
implementation for DivNode currently emits a placeholder "<<<divf>>>"; replace
this by inspecting op->dtype (use op->dtype() or op->dtype().base_type()) and
emit the correct division IR token for floats vs integers (e.g., a
floating-point divide string for float types and an integer divide string or
signed/unsigned variant for integer types) via the same PrintBinary call (e.g.,
PrintBinary(op, "<correct-div-op>", os, this)); if the dtype is unsupported,
hard-fail (throw or LOG(FATAL)) with a clear message referencing DivNode so
callers see the unsupported dtype instead of invalid IR.
```python
from typing import (
    Any,
    List,
    Union,
)
```
🛠️ Refactor suggestion | 🟠 Major
Remove duplicate imports flagged by static analysis.
`Any` and `Union` are already imported on line 8. Remove the redundant imports.
```diff
-from typing import (
-    Any,
-    List,
-    Union,
-)
+from typing import List
```

📝 Committable suggestion
```python
from typing import List
```
🧰 Tools
🪛 Ruff (0.14.14)
21-21: Redefinition of unused Any from line 8: Any redefined here
Remove definition: Any
(F811)
23-23: Redefinition of unused Union from line 8: Union redefined here
Remove definition: Union
(F811)
🤖 Prompt for AI Agents
In `@tilelang/jit/jit_commonir.py` around lines 20 - 24, Remove the duplicate
typing imports by editing the import block in jit_commonir.py: keep only the new
symbol(s) needed here (e.g., List) and delete Any and Union which are already
imported earlier (line 8), ensuring there are no redundant imports of Any or
Union in the file.
```python
    }
    block_indices = {"x": None, "y": None, "z": None}
    for dim, pattern in patterns.items():
        match = re.search(pattern, str(str(self.mod)))
```
🛠️ Refactor suggestion | 🟠 Major
Redundant str(str(...)) call.
The double str() conversion is unnecessary.
```diff
-        match = re.search(pattern, str(str(self.mod)))
+        match = re.search(pattern, str(self.mod))
```

📝 Committable suggestion
```python
        match = re.search(pattern, str(self.mod))
```
🤖 Prompt for AI Agents
In `@tilelang/jit/jit_commonir.py` at line 88, The call uses redundant double
conversion str(str(self.mod)) in the re.search invocation; update the expression
in jit_commonir.py (the line that sets match via re.search(pattern, ...)) to
pass a single string conversion—e.g., replace str(str(self.mod)) with
str(self.mod)—so the regex runs against the same string value without the
unnecessary nested str() call.
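For context, the grid-parsing pattern under review reduces to a single regex pass per dimension over the stringified module. A self-contained sketch — the `gridX`/`gridY`/`gridZ` patterns are illustrative assumptions, not the JIT's actual ones:

```python
import re

def parse_block_indices(mlir_text):
    """Extract per-dimension grid sizes from MLIR-like text, one regex
    search per dimension; missing dimensions stay None."""
    patterns = {
        "x": r"gridX\s*=\s*(\d+)",
        "y": r"gridY\s*=\s*(\d+)",
        "z": r"gridZ\s*=\s*(\d+)",
    }
    block_indices = {"x": None, "y": None, "z": None}
    for dim, pattern in patterns.items():
        # a single str() conversion upstream is sufficient here
        match = re.search(pattern, mlir_text)
        if match:
            block_indices[dim] = int(match.group(1))
    return block_indices
```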
```python
# Estimate the runtime of the function
start_event = torch.npu.Event(enable_timing=True)
end_event = torch.npu.Event(enable_timing=True)
start_event.record()
for _ in range(5):
    cache.zero_()
    fn()
end_event.record()
torch.npu.synchronize()
estimate_ms = start_event.elapsed_time(end_event) / 5

# compute number of warmup and repeat
n_warmup = max(1, int(warmup / estimate_ms))
n_repeat = max(1, int(rep / estimate_ms))
if _n_warmup > 0:
    n_warmup = _n_warmup
if _n_repeat > 0:
    n_repeat = _n_repeat
```
Guard against zero/near-zero timing estimates in do_bench.
elapsed_time() can return 0 for very fast kernels or coarse timer resolution, which triggers a divide-by-zero when computing n_warmup/n_repeat. Clamp to a small positive value to keep benchmarking stable.
✅ Suggested fix

```diff
 estimate_ms = start_event.elapsed_time(end_event) / 5
+if estimate_ms <= 0:
+    estimate_ms = 1e-6
```

📝 Committable suggestion
```python
# Estimate the runtime of the function
start_event = torch.npu.Event(enable_timing=True)
end_event = torch.npu.Event(enable_timing=True)
start_event.record()
for _ in range(5):
    cache.zero_()
    fn()
end_event.record()
torch.npu.synchronize()
estimate_ms = start_event.elapsed_time(end_event) / 5
if estimate_ms <= 0:
    estimate_ms = 1e-6
# compute number of warmup and repeat
n_warmup = max(1, int(warmup / estimate_ms))
n_repeat = max(1, int(rep / estimate_ms))
if _n_warmup > 0:
    n_warmup = _n_warmup
if _n_repeat > 0:
    n_repeat = _n_repeat
```
🤖 Prompt for AI Agents
In `@tilelang/jit/kernel_npu.py` around lines 39 - 56, The timing estimate
computed from start_event.elapsed_time(end_event) can be zero; clamp estimate_ms
to a small positive floor (e.g., 1e-3 ms) before using it to compute n_warmup
and n_repeat to avoid divide-by-zero and extreme counts. Locate the block that
creates start_event/end_event and sets estimate_ms, then set estimate_ms =
max(estimate_ms, <small_positive>) prior to computing n_warmup/n_repeat (keeping
later overrides for _n_warmup/_n_repeat intact); reference symbols: start_event,
end_event, estimate_ms, n_warmup, n_repeat, _n_warmup, _n_repeat, warmup, rep.
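The clamping suggested above can be factored into a small pure helper, which also makes the divide-by-zero case testable without NPU events (names and the floor value here are hypothetical):

```python
def compute_bench_counts(estimate_ms, warmup_ms, rep_ms, floor_ms=1e-3):
    """Derive warmup/repeat iteration counts from a per-call time estimate,
    clamped so a zero or near-zero estimate cannot divide by zero."""
    estimate_ms = max(estimate_ms, floor_ms)
    n_warmup = max(1, int(warmup_ms / estimate_ms))
    n_repeat = max(1, int(rep_ms / estimate_ms))
    return n_warmup, n_repeat
```

With a 2 ms per-call estimate and the usual 25 ms warmup / 100 ms repeat budgets, this yields 12 warmup and 50 timed iterations; a zero estimate falls back to the floor instead of raising `ZeroDivisionError`.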
```python
def get_tensor(param: KernelParam) -> torch.Tensor:
    dtype: torch.dtype = param.torch_dtype()
    device = torch.npu.current_device()

    if hasattr(param, "shape") and not param.shape:
        raise ValueError(
            f"TensorType must have a shape, but got {type(param)}, "
            "likely you are trying to generate a random tensor with a dynamic symbolic shape."
        )

    # Check if with dynamic symbolic shape
    for shape in param.shape:
        if isinstance(shape, tir.Var):
            raise ValueError(
                f"TensorType must have a static shape, but got {shape}, "
                "likely you are trying to generate a random tensor with a dynamic symbolic shape."
            )

    shape = list(map(int, param.shape))
    if supply_type == TensorSupplyType.Auto:
```
Allow scalar KernelParam shapes instead of raising.
KernelParam.is_scalar() uses an empty shape; the current guard raises and breaks profiling for scalar parameters. Treat empty shapes as valid (0‑dim tensors).
✅ Suggested fix

```diff
-    if hasattr(param, "shape") and not param.shape:
-        raise ValueError(
-            f"TensorType must have a shape, but got {type(param)}, "
-            "likely you are trying to generate a random tensor with a dynamic symbolic shape."
-        )
-
-    # Check if with dynamic symbolic shape
-    for shape in param.shape:
-        if isinstance(shape, tir.Var):
-            raise ValueError(
-                f"TensorType must have a static shape, but got {shape}, "
-                "likely you are trying to generate a random tensor with a dynamic symbolic shape."
-            )
-
-    shape = list(map(int, param.shape))
+    if hasattr(param, "shape") and not param.shape:
+        # scalar parameter is valid
+        shape = []
+    else:
+        # Check if with dynamic symbolic shape
+        for dim in param.shape:
+            if isinstance(dim, tir.Var):
+                raise ValueError(
+                    f"TensorType must have a static shape, but got {dim}, "
+                    "likely you are trying to generate a random tensor with a dynamic symbolic shape."
+                )
+        shape = list(map(int, param.shape))
```

🧰 Tools
🪛 Ruff (0.14.14)
99-102: Avoid specifying long messages outside the exception class
(TRY003)
107-110: Prefer TypeError exception for invalid type
(TRY004)
107-110: Avoid specifying long messages outside the exception class
(TRY003)
🤖 Prompt for AI Agents
In `@tilelang/jit/kernel_npu.py` around lines 94 - 113, The get_tensor function
currently treats empty KernelParam.shape as an error, but
KernelParam.is_scalar() uses an empty shape for 0‑dim tensors; remove the early
ValueError and allow empty shapes as valid scalars. Update the guard in
get_tensor (the block that checks "if hasattr(param, 'shape') and not
param.shape:") to not raise for empty shapes, and ensure the subsequent loop
that checks for tir.Var only runs when param.shape is non-empty (e.g., iterate
only if param.shape). This preserves the tir.Var static-shape check while
allowing scalar KernelParam instances to be handled as 0‑dim tensors.
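The scalar-friendly behavior can be isolated into a small helper. This sketch uses a plain `isinstance(dim, int)` check as a stand-in for the `tir.Var` test — a hypothetical `normalize_shape`, not the project's API:

```python
def normalize_shape(shape):
    """Treat an empty shape as a valid 0-dim (scalar) tensor shape;
    reject symbolic (non-integer) dimensions with a clear error."""
    if not shape:
        return []  # scalar: e.g. torch.empty([]) produces a 0-dim tensor
    dims = []
    for dim in shape:
        if not isinstance(dim, int):
            raise ValueError(f"expected a static integer dimension, got {dim!r}")
        dims.append(dim)
    return dims
```

The key change relative to the reviewed code is that the empty-shape branch returns `[]` instead of raising, while dynamic dimensions still fail early with a descriptive message.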
```python
def _legalize_result_idx(self, result_idx: Optional[List[int]] = None) -> List[int]:
    params = self.params
    # result_idx is a list of indices of the output tensors
    if result_idx is None:
        result_idx = []
    elif isinstance(result_idx, int):
        if result_idx > len(params) or result_idx < -len(params):
            raise ValueError(
                f"result_idx should be an integer between {-len(params)} and {len(params) - 1}")
        if result_idx < 0:
            result_idx = len(params) + result_idx
        result_idx = [result_idx]
    elif not isinstance(result_idx, list):
        raise ValueError("result_idx should be a list of integers")
    return result_idx
```
Normalize list result_idx (negative/out‑of‑range) and fix the >= check.
The int branch allows result_idx == len(params) (out of range). The list branch doesn’t normalize negatives or validate bounds, which mislabels outputs.
✅ Suggested fix

```diff
     elif isinstance(result_idx, int):
-        if result_idx > len(params) or result_idx < -len(params):
+        if result_idx >= len(params) or result_idx < -len(params):
             raise ValueError(
                 f"result_idx should be an integer between {-len(params)} and {len(params) - 1}")
         if result_idx < 0:
             result_idx = len(params) + result_idx
         result_idx = [result_idx]
-    elif not isinstance(result_idx, list):
-        raise ValueError("result_idx should be a list of integers")
+    elif isinstance(result_idx, list):
+        normalized = []
+        for idx in result_idx:
+            if not isinstance(idx, int):
+                raise ValueError("result_idx should be a list of integers")
+            if idx >= len(params) or idx < -len(params):
+                raise ValueError(
+                    f"result_idx should be an integer between {-len(params)} and {len(params) - 1}")
+            if idx < 0:
+                idx = len(params) + idx
+            normalized.append(idx)
+        result_idx = normalized
+    else:
+        raise ValueError("result_idx should be a list of integers")
     return result_idx
```

🧰 Tools
🪛 Ruff (0.14.14)
185-186: Avoid specifying long messages outside the exception class
(TRY003)
191-191: Avoid specifying long messages outside the exception class
(TRY003)
🤖 Prompt for AI Agents
In `@tilelang/jit/kernel_npu.py` around lines 178 - 192, In _legalize_result_idx,
fix the int branch bounds check to forbid result_idx >= len(self.params) (not
just >) and normalize negative indices as currently done; additionally, when
result_idx is a list (the else branch), validate each element is an int,
normalize any negative indices by adding len(self.params), and raise ValueError
for any index < 0 or >= len(self.params); update references to params via
self.params and return the normalized list of ints.
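As a standalone mirror of the suggested normalization — a hypothetical free function decoupled from the class, taking the parameter count directly:

```python
def legalize_result_idx(result_idx, n_params):
    """Normalize result_idx (None, int, or list of ints) into a list of
    non-negative parameter indices; out-of-range and non-int inputs raise."""
    if result_idx is None:
        return []
    if isinstance(result_idx, int):
        result_idx = [result_idx]
    if not isinstance(result_idx, list):
        raise ValueError("result_idx should be a list of integers")
    normalized = []
    for idx in result_idx:
        if not isinstance(idx, int):
            raise ValueError("result_idx should be a list of integers")
        if idx >= n_params or idx < -n_params:
            raise ValueError(
                f"result_idx should be between {-n_params} and {n_params - 1}")
        normalized.append(idx + n_params if idx < 0 else idx)
    return normalized
```

Note that the `>=` bound applies uniformly to the int and list paths, so `idx == n_params` is rejected in both.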
Actionable comments posted: 2
🤖 Fix all issues with AI agents
In `@README.md`:
- Line 16: Update the dated news entry that currently reads "01/29/2025 🚀:
Added CommonIR support..." to the correct year "01/29/2026" and ensure the
News/Changelog block remains in reverse-chronological order; specifically locate
the "01/29/2025" string in README.md and change it to "01/29/2026" (and if
needed reorder adjacent entries so later 2025/2026 items remain in proper
descending date order).
In `@tilelang/utils/target.py`:
- Around line 140-144: The current check if COMMONIR_enabled unconditionally
overrides any explicit target; change it so COMMONIR only hijacks when the
requested target is the default/auto value—i.e., update the conditional around
return_var/Target(return_var) (the block using COMMONIR_enabled, return_var and
return_object) to also test the function's target argument (e.g., target is None
or target == "auto") before returning "commonir" or Target("commonir"); keep the
existing return_object behavior intact so explicit targets like "cuda" are
preserved.
🧹 Nitpick comments (1)
tilelang/utils/target.py (1)
27-27: Missing blank line after module-level constant. Per PEP 8, there should be a blank line between module-level definitions and function definitions.

```diff
 COMMONIR_enabled = os.environ.get('USE_COMMONIR', '0') in ('1', 'true', 'on')
+
 def describe_supported_targets() -> dict[str, str]:
```
```markdown
<img src=./images/MatmulExample.png />

## Latest News
- 01/29/2025 🚀: Added CommonIR support, enabling compilation on wider domestic GPU through [DLCompiler](https://github.com/DeepLink-org/DLCompiler) integration.
```
Date year appears incorrect.
The entry is dated 01/29/2025, but given the current date (January 2026) and the reverse-chronological ordering of the news section, this should likely be 01/29/2026. With the current date, this entry would be out of order—it should appear much lower in the list, after entries from later in 2025.
📝 Suggested fix

```diff
-- 01/29/2025 🚀: Added CommonIR support, enabling compilation on wider domestic GPU through [DLCompiler](https://github.com/DeepLink-org/DLCompiler) integration.
+- 01/29/2026 🚀: Added CommonIR support, enabling compilation on wider domestic GPU through [DLCompiler](https://github.com/DeepLink-org/DLCompiler) integration.
```

📝 Committable suggestion
```markdown
- 01/29/2026 🚀: Added CommonIR support, enabling compilation on wider domestic GPU through [DLCompiler](https://github.com/DeepLink-org/DLCompiler) integration.
```
🤖 Prompt for AI Agents
In `@README.md` at line 16, Update the dated news entry that currently reads
"01/29/2025 🚀: Added CommonIR support..." to the correct year "01/29/2026" and
ensure the News/Changelog block remains in reverse-chronological order;
specifically locate the "01/29/2025" string in README.md and change it to
"01/29/2026" (and if needed reorder adjacent entries so later 2025/2026 items
remain in proper descending date order).
```python
if COMMONIR_enabled:
    return_var = "commonir"
if return_object:
    return Target(return_var)
return return_var
```
COMMONIR_enabled unconditionally overrides any explicit target argument.
When USE_COMMONIR=1, this early return bypasses all target detection logic, ignoring even explicit user-specified targets like target="cuda". This may be intentional for a global opt-in, but it creates surprising behavior where the user's explicit target choice is silently overridden.
Consider adjusting the logic to respect explicit target specifications:

```diff
 return_var: str | Target = target
-if COMMONIR_enabled:
+if target == "commonir" or (target == "auto" and COMMONIR_enabled):
     return_var = "commonir"
 if return_object:
     return Target(return_var)
 return return_var
```

This would allow USE_COMMONIR=1 to affect only target="auto" (the default), while explicit targets like target="cuda" remain honored.
📝 Committable suggestion
```python
if target == "commonir" or (target == "auto" and COMMONIR_enabled):
    return_var = "commonir"
if return_object:
    return Target(return_var)
return return_var
```
🤖 Prompt for AI Agents
In `@tilelang/utils/target.py` around lines 140 - 144, The current check if
COMMONIR_enabled unconditionally overrides any explicit target; change it so
COMMONIR only hijacks when the requested target is the default/auto value—i.e.,
update the conditional around return_var/Target(return_var) (the block using
COMMONIR_enabled, return_var and return_object) to also test the function's
target argument (e.g., target is None or target == "auto") before returning
"commonir" or Target("commonir"); keep the existing return_object behavior
intact so explicit targets like "cuda" are preserved.
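The intended precedence — explicit target wins, env flag only affects `auto` — can be sketched as a pure function; the `"cuda"` fallback below is a placeholder for the real auto-detection path, not TileLang's actual implementation:

```python
import os

# env flag read once at import time, mirroring the module-level constant
COMMONIR_ENABLED = os.environ.get("USE_COMMONIR", "0") in ("1", "true", "on")

def determine_target(target="auto"):
    """Resolve the compilation target. Only the default "auto" value is
    hijacked by USE_COMMONIR; explicit choices like "cuda" are preserved."""
    if target == "commonir" or (target == "auto" and COMMONIR_ENABLED):
        return "commonir"
    if target == "auto":
        return "cuda"  # placeholder for the usual auto-detection path
    return target
```

With this shape, `determine_target("cuda")` always returns `"cuda"` regardless of the env flag, which is the behavior the review asks for.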
## Summary

### User-facing changes
- New environment variable `USE_COMMONIR` to enable or disable the CommonIR feature pluggably.
- With `USE_COMMONIR=0`, all existing targets keep their previous defaults, entirely unaffected.

### What's included
- `tilelang.utils.target.determine_target()` recognizes the commonir target and injects the commonir key.
- `tilelang.engine.lower.device_codegen*()` dispatches to `target.build.tilelang_commonir*` when the target has the commonir key.
- A new C++ codegen (`src/target/codegen_commonir.*`) built on top of the TVM codegen infrastructure (`CodeGenC`).
- Runtime module registration (`src/target/rt_mod_commonir.cc`).
- Examples under `examples/commonir/*.py` (tested currently on Huawei Ascend NPU).

### Requirements / constraints

### Backward compatibility

### Follow-ups (not in this PR)

## Summary by CodeRabbit
- New Features
- Documentation & Examples
- Chores