
[BUG] TMA fails #1648

@bucket-xv

Required prerequisites

What version of TileLang are you using?

0.1.7.post2+cuda.gita56212de

System information

Python: 3.12.9 | (main, Feb 6 2025, 18:56:27) [GCC 11.2.0] linux
TileLang: 0.1.7.post2+cuda.gita56212de
PyTorch: 2.9.0+cu130
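The fields above can be gathered with a short script. This is a minimal sketch that assumes the packages are installed under the distribution names `tilelang` and `torch`. It reads versions through `importlib.metadata`, so for source builds the reported string may differ slightly from `tilelang.__version__`.

```python
import sys
from importlib import metadata


def collect_env(packages=("tilelang", "torch")):
    """Collect the Python, platform, and package versions listed under 'System information'."""
    info = {"python": sys.version, "platform": sys.platform}
    for name in packages:
        try:
            # Version recorded in the installed distribution's metadata.
            info[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            info[name] = "not installed"
    return info


if __name__ == "__main__":
    for key, value in collect_env().items():
        print(f"{key}: {value}")
```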

Problem description

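Compilation fails once the `TL_DISABLE_TMA_LOWER` pass config is removed, i.e. when TMA lowering is enabled. The generated CUDA code passes `mbarrier[0]` to `tl::tma_load`, but `mbarrier` is never declared, so nvcc rejects the kernel.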

Reproducible example code

The Python snippet:

```python
import tilelang
from tilelang import language as T
import torch

@tilelang.jit(
    pass_configs={
        tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
        tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: False,
        # The failure appears once TL_DISABLE_TMA_LOWER: True is removed from this dict.
    },
)
def get_sample_kernel(k: int):
    m = T.symbolic('m')
    block_m = 4
    block_k = 128
    num_threads = 32
    in_dtype = T.float8_e4m3fn
    @T.prim_func
    def sample_kernel(x: T.Tensor[(m, k), in_dtype]):
        with T.Kernel(T.ceildiv(m, block_m), T.ceildiv(k, block_k), threads=num_threads) as (
            pid_m,
            pid_k,
        ):
            # Local buffers
            x_in_shared = T.alloc_shared((block_m, block_k), dtype=in_dtype)

            # Zero the shared tile, then copy a (block_m, block_k) tile of x into shared memory
            T.fill(x_in_shared, 0)
            T.copy(
                x[pid_m * block_m: (pid_m + 1) * block_m, pid_k * block_k: (pid_k + 1) * block_k],
                x_in_shared
            )

    return sample_kernel

kernel = get_sample_kernel(128)
print(kernel.get_kernel_source())
```
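For reference, the tile arithmetic the `T.copy` above implies can be worked out by hand. This is a plain-Python sketch, not TileLang code; the helper names are hypothetical. It assumes the TMA coordinates are passed innermost dimension first (as with `cuTensorMapEncodeTiled`, where dimension 0 is the contiguous one), so each block's tile origin is `(pid_k * block_k, pid_m * block_m)`. With `k = 128` the grid has a single block along k, which gives `(0, pid_m * 4)`. The byte count is the size of one float8 tile. The 16-byte check reflects the TMA rule that the innermost box extent in bytes must be a multiple of 16.

```python
def ceildiv(a: int, b: int) -> int:
    """Integer ceiling division for positive ints (same result as T.ceildiv)."""
    return -(-a // b)


def grid_shape(m: int, k: int, block_m: int = 4, block_k: int = 128):
    """Number of blocks along m and k, matching the T.Kernel launch."""
    return ceildiv(m, block_m), ceildiv(k, block_k)


def tma_coords(pid_m: int, pid_k: int, block_m: int = 4, block_k: int = 128):
    """Tile origin in element units, innermost (contiguous k) dimension first."""
    return pid_k * block_k, pid_m * block_m


def tile_bytes(block_m: int = 4, block_k: int = 128, elem_bytes: int = 1):
    """Bytes moved per tile; float8_e4m3fn is 1 byte per element."""
    return block_m * block_k * elem_bytes


def inner_box_ok(block_k: int = 128, elem_bytes: int = 1) -> bool:
    """Innermost box extent in bytes must be a multiple of 16 for TMA."""
    return (block_k * elem_bytes) % 16 == 0


if __name__ == "__main__":
    print(grid_shape(10, 128))  # (3, 1)
    print(tma_coords(2, 0))     # (0, 8)
    print(tile_bytes())         # 512
    print(inner_box_ok())       # True
```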

Traceback

```
/tmp/tmpmgb2e0ij/tvm_kernels.cu(23): error: identifier "mbarrier" is undefined
      tl::tma_load(x_desc, mbarrier[0], (&(x_in_shared[0])), 0, (((int)blockIdx.x) * 4));
```
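To triage this class of failure without running nvcc, one can scan a captured CUDA source string for an identifier that is indexed but never declared. The sketch below is a heuristic and makes several assumptions. The function name is hypothetical. A "declaration" is taken to mean a line that begins with a type-like token (optionally `__shared__`) followed by the name, for example `__shared__ uint64_t mbarrier[1];` or `auto mbarrier = ...`. The exact form TileLang emits for its barrier buffer is not assumed.

```python
import re


def uses_undeclared(source: str, name: str = "mbarrier") -> bool:
    """Heuristic: True if `name[` is used but no declaration-like line for `name` exists."""
    used = re.search(rf"\b{re.escape(name)}\s*\[", source) is not None
    # Assumed declaration shape: optional __shared__, a type token (not `return`),
    # optional template args and pointer/reference marker, then the name itself.
    decl_pattern = (
        rf"^\s*(?:__shared__\s+)?(?!return\b)[A-Za-z_][\w:]*(?:<[^>\n]*>)?"
        rf"\s*[*&]?\s*{re.escape(name)}\b"
    )
    declared = re.search(decl_pattern, source, flags=re.MULTILINE) is not None
    return used and not declared


if __name__ == "__main__":
    failing = "tl::tma_load(x_desc, mbarrier[0], (&(x_in_shared[0])), 0, 0);"
    print(uses_undeclared(failing))  # True
    print(uses_undeclared("__shared__ uint64_t mbarrier[1];\n" + failing))  # False
```

For builds that do compile, the same check could be applied to the string returned by `kernel.get_kernel_source()`.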

Expected behavior

No response

Additional context

No response


Labels: bug (Something isn't working)
