
[Feature Request] Add CUDA graph backend support to autotuner #1633

@cscyuge

Description


Required prerequisites

  • I have searched the Issue Tracker and confirmed this hasn't already been reported. (Comment there if it has.)

Problem

The AutoTuner currently benchmarks candidate configurations with the default "event" backend, which includes kernel launch overhead in the measured time. For short-duration kernels this overhead can dominate the measurement and skew configuration selection: when the kernels are later executed inside CUDA graphs (which minimizes launch overhead), the configuration the autotuner picked may no longer be optimal.

Request: add a CUDA graph backend to the profiler so that users can specify the profiling backend during auto-tuning.
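
For reference, the sketch below shows what graph-replay timing looks like in plain PyTorch with torch.cuda.CUDAGraph: the callable is captured once and then replayed repeatedly, so per-launch CPU overhead is largely excluded from the measurement. This is only an illustration of the technique, not tilelang's profiler API; the function name bench_with_cuda_graph and its parameters are made up for this example.

import torch


def bench_with_cuda_graph(fn, warmup=10, iters=100):
    # Warm up on a side stream so the capture only records the kernel itself.
    stream = torch.cuda.Stream()
    stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(stream):
        for _ in range(warmup):
            fn()
    torch.cuda.current_stream().wait_stream(stream)

    # Capture a single call into a CUDA graph, then time repeated replays.
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        fn()

    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        graph.replay()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters  # average ms per replay

This is essentially what triton.testing.do_bench_cudagraph measures, and it is the quantity the autotuner should ideally optimize when kernels are deployed inside CUDA graphs.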

Details

We compared different profiling methods and found significant discrepancies, but we don't know which measurement method is the most reliable.

Results from benchmarking a small GEMM kernel (M=1, N=1024, K=1024):

Method                                Latency
tilelang profiler, "event" backend    8.77 us
tilelang profiler, "cupti" backend    4.64 us
triton.testing.do_bench               8.40 us
triton.testing.do_bench_cudagraph     3.65 us
ncu (direct kernel launch)            5.38 us
ncu (CUDA graph)                      5.18 us

Test script:

import tilelang
import tilelang.language as T
import torch
import triton


@tilelang.jit()
def matmul(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float32):
    @T.prim_func
    def gemm(
        A: T.Tensor((M, K), dtype),
        B: T.Tensor((K, N), dtype),
        C: T.Tensor((M, N), dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K), dtype)
            B_shared = T.alloc_shared((block_K, block_N), dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)

            T.clear(C_local)
            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=4):
                T.copy(A[by * block_M, k * block_K], A_shared)
                T.copy(B[k * block_K, bx * block_N], B_shared)
                T.gemm(A_shared, B_shared, C_local)

            T.copy(C_local, C[by * block_M, bx * block_N])

    return gemm


def main():
    kernel = matmul(1, 1024, 1024, 16, 32, 128)

    a = torch.randn(1, 1024).cuda().half()
    b = torch.randn(1024, 1024).cuda().half()
    c = torch.empty(1, 1024).cuda().half()

    kernel(a, b, c)  # warm-up run
    
    # benchmark
    profiler = kernel.get_profiler()
    latency_cupti = profiler.do_bench(backend="cupti")
    latency_event = profiler.do_bench(backend="event")
    def func():
        kernel(a, b, c)
    latency_triton = triton.testing.do_bench(func)
    latency_triton_cudagraph = triton.testing.do_bench_cudagraph(func)
    print(f"cupti latency: {latency_cupti}ms")
    print(f"event latency: {latency_event}ms")
    print(f"triton latency: {latency_triton}ms")
    print(f"triton cudagraph latency: {latency_triton_cudagraph}ms")
    
if __name__ == "__main__":
    main()

Solution

No response

Alternatives

No response

Additional context

No response

Labels

enhancement (New feature or request)
