Skip to content

Conversation

@danielhua23
Copy link
Contributor

@danielhua23 danielhua23 commented Jan 20, 2026

As tilelang updates, I found some issues when I running AMD FA kernel, this pr is to propose the workaround for AMD FA example

  • issue1: when saving tuned kernel, AMD FA will go through kernel_global_source of CythonKernelAdapter, but that attr is lost
  • issue2: original tuned config like block_M = 32 and block_N = 32 does not work now, so remove it, specifically, the issues are
Layout infer conflict between acc_s and acc_s_cast in T.Parallel loop:
    loop Fragment([32, 32] -> [8], replicate: 2, thread: 256, ...)
    fragment Fragment([32, 32] -> [4], replicate: 1, thread: 256, ...)

and

Check failed: pb->value != 0 (0 vs. 0) : Divide by zero

are the two issues expected? thanks

Summary by CodeRabbit

  • Chores

    • Optimized AMD Flash Attention kernel example with refined block/thread configurations and improved kernel execution flow.
  • Refactor

    • Added kernel compatibility attribute to the kernel adapter for improved code compatibility.

✏️ Tip: You can customize this high-level summary in your review settings.

Copilot AI review requested due to automatic review settings January 20, 2026 10:04
@github-actions
Copy link

👋 Hi! Thank you for contributing to the TileLang project.

Please remember to run pre-commit run --all-files in the root directory of the project to ensure your changes are properly linted and formatted. This will help ensure your contribution passes the format check.

We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀

@coderabbitai
Copy link
Contributor

coderabbitai bot commented Jan 20, 2026

Note

Other AI code review bot(s) detected

CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.

📝 Walkthrough

Walkthrough

A new alias attribute kernel_global_source is added to CythonKernelAdapter for compatibility. The AMD Flash Attention forward example is refactored: dtype mapping for tensor creation, narrowed config search, a context-managed kernel loop, and explicit softmax/layout handling added.

Changes

Cohort / File(s) Summary
Cython Adapter Alias
tilelang/jit/adapter/cython/adapter.py
Adds public attribute kernel_global_source: str | None = None to CythonKernelAdapter and sets it to mirror device_kernel_source in __init__ and from_database.
AMD Flash Attention Example Refactoring
examples/amd/example_amd_flash_attn_fwd.py
Adds TileLang→PyTorch dtype mapping in supply_tensors_gpu, narrows block_M/block_N and thread options in get_configs, replaces a plain while loop with a context-managed T.While for the main bx loop, and introduces explicit softmax computation steps and layout-conflict handling (copying acc_sacc_s_cast before GEMM).
Project metadata
manifest_file, requirements.txt, pyproject.toml
Minor metadata/requirements adjustments (unchanged public APIs).

Estimated code review effort

🎯 2 (Simple) | ⏱️ ~10 minutes

Possibly related PRs

Suggested reviewers

  • LeiWang1999

Poem

🐰 I hopped through kernels, small and keen,
I mirrored sources, neat and clean.
Softmax danced within the loop so bright,
Configs trimmed for a faster flight.
Frolic, compile, and run tonight! 🥕✨

🚥 Pre-merge checks | ✅ 2 | ❌ 1
❌ Failed checks (1 warning)
Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 50.00% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (2 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title check ✅ Passed The title '[AMD] Fix bugs about AMD FA kernel' directly addresses the main objective: fixing AMD FA kernel issues after TileLang updates.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing touches
  • 📝 Generate docstrings

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

Copy link
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This PR fixes two issues with the AMD Flash Attention kernel:

  1. Adds the missing kernel_global_source attribute to CythonKernelAdapter which is required when saving/loading tuned kernels
  2. Removes problematic configuration parameters (block_M=32, block_N=32, threads=512) that don't work with the current implementation

Changes:

  • Added kernel_global_source attribute as an alias for device_kernel_source in CythonKernelAdapter for compatibility
  • Updated supply_tensors_gpu to properly map TileLang dtypes to PyTorch dtypes
  • Removed problematic tuning configurations and updated from while to T.While construct

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 2 comments.

File Description
tilelang/jit/adapter/cython/adapter.py Added kernel_global_source field and properly initialized it in both __init__ and from_database methods as an alias to device_kernel_source
examples/amd/example_amd_flash_attn_fwd.py Fixed dtype mapping, removed problematic configs, updated to use T.While construct, and improved comments

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment on lines 13 to 28
# TileLang dtype to PyTorch dtype mapping
dtype_map = {
T.float16: torch.float16,
T.float32: torch.float32,
T.int32: torch.int32,
T.int64: torch.int64,
}

tensors = []
for param in params:
if hasattr(param, "shape") and hasattr(param, "dtype"):
# Force creation on GPU device
shape = [int(s) for s in param.shape]
tensor = torch.randn(shape, dtype=param.dtype, device="cuda")
# Convert TileLang dtype to PyTorch dtype
torch_dtype = dtype_map.get(param.dtype, torch.float16)
tensor = torch.randn(shape, dtype=torch_dtype, device="cuda")
Copy link

Copilot AI Jan 20, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The manual dtype mapping could be simplified by using the built-in torch_dtype() method from KernelParam. Instead of maintaining a manual dtype_map dictionary, you can replace lines 14-19 and 27 with:

torch_dtype = param.torch_dtype()

This would eliminate the need for manual mapping and automatically handle all dtype conversions, including edge cases like float8 types that require special handling for HIP vs CUDA backends.

Copilot uses AI. Check for mistakes.
Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

🤖 Fix all issues with AI agents
In `@examples/amd/example_amd_flash_attn_fwd.py`:
- Around line 13-28: Replace the manual dtype mapping in supply_tensors_gpu: the
current dtype_map and lookup using param.dtype should be removed and replaced by
calling param.torch_dtype() to obtain the correct PyTorch dtype; update the code
that sets torch_dtype (and remove the dtype_map variable) so tensor creation
uses torch_dtype = param.torch_dtype() ensuring all KernelParam dtypes
(bfloat16, float8, etc.) are handled and no silent fallback to torch.float16
occurs.
🧹 Nitpick comments (1)
examples/amd/example_amd_flash_attn_fwd.py (1)

53-57: Add a comment explaining why specific configurations were removed.

The docstring mentions "avoiding problematic combinations" but doesn't explain the root cause. Per the PR description, block_M=32 and block_N=32 cause layout inference conflicts and divide-by-zero errors. Adding this context prevents future maintainers from re-introducing these values:

📝 Suggested improvement
 def get_configs():
-    """Generates configurations for the autotuner, avoiding problematic combinations."""
-    block_M = [64, 128, 256]
-    block_N = [64, 128, 256]
-    threads = [128, 256]
+    """Generates configurations for the autotuner, avoiding problematic combinations.
+    
+    Note: block_M/N=32 removed due to layout inference conflicts between acc_s and
+    acc_s_cast fragments. threads=512 removed to avoid divide-by-zero check failures.
+    """
+    block_M = [64, 128, 256]  # 32 causes layout conflict
+    block_N = [64, 128, 256]  # 32 causes layout conflict
+    threads = [128, 256]  # 512 causes divide-by-zero

Comment on lines +13 to +28
# TileLang dtype to PyTorch dtype mapping
dtype_map = {
T.float16: torch.float16,
T.float32: torch.float32,
T.int32: torch.int32,
T.int64: torch.int64,
}

tensors = []
for param in params:
if hasattr(param, "shape") and hasattr(param, "dtype"):
# Force creation on GPU device
shape = [int(s) for s in param.shape]
tensor = torch.randn(shape, dtype=param.dtype, device="cuda")
# Convert TileLang dtype to PyTorch dtype
torch_dtype = dtype_map.get(param.dtype, torch.float16)
tensor = torch.randn(shape, dtype=torch_dtype, device="cuda")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
# First, read the actual file to understand context
echo "=== File content around lines 13-28 ==="
head -50 examples/amd/example_amd_flash_attn_fwd.py | cat -n

Repository: tile-ai/tilelang

Length of output: 2301


🏁 Script executed:

#!/bin/bash
# Check the KernelParam class for torch_dtype method
echo "=== Searching for torch_dtype in param.py ==="
rg -n "torch_dtype" tilelang/engine/param.py -B 3 -A 10

Repository: tile-ai/tilelang

Length of output: 1004


🏁 Script executed:

#!/bin/bash
# Find the supply_tensors_gpu function to see how params are created/used
echo "=== Searching for supply_tensors_gpu function ==="
rg -n "def supply_tensors_gpu" examples/amd/example_amd_flash_attn_fwd.py -B 2 -A 50 | head -80

Repository: tile-ai/tilelang

Length of output: 2218


🏁 Script executed:

#!/bin/bash
# Look for where tensors are created from params to understand the full flow
echo "=== Checking how params are passed into this function ==="
rg -n "supply_tensors_gpu" examples/amd/example_amd_flash_attn_fwd.py

Repository: tile-ai/tilelang

Length of output: 255


🏁 Script executed:

#!/bin/bash
# Verify that params are indeed KernelParam objects by checking autotune decorator
echo "=== Searching for autotune decorator and how it passes params ==="
rg -n "def autotune" tilelang/ -A 20 | head -80

Repository: tile-ai/tilelang

Length of output: 1518


🏁 Script executed:

#!/bin/bash
# Check if there are any other uses of dtype mapping in similar contexts
echo "=== Checking for other dtype mapping patterns in the codebase ==="
rg -n "dtype_map" --type py | head -20

Repository: tile-ai/tilelang

Length of output: 2198


🏁 Script executed:

#!/bin/bash
# Verify the T.dtype().as_torch() handles all the expected types
echo "=== Checking what T.dtype().as_torch() implementation looks like ==="
rg -n "def as_torch" tilelang/ -B 2 -A 10 | head -60

Repository: tile-ai/tilelang

Length of output: 865


🏁 Script executed:

#!/bin/bash
# Get the full implementation of as_torch()
echo "=== Full as_torch() implementation ==="
rg -n "def as_torch" tilelang/language/dtypes.py -A 20

Repository: tile-ai/tilelang

Length of output: 632


🏁 Script executed:

#!/bin/bash
# Check the dtype class in TVM to understand what as_torch covers
echo "=== Checking what dtypes are defined in tilelang ==="
grep -n "float16\|float32\|int32\|int64\|bfloat16\|float8" tilelang/language/dtypes.py | head -30

Repository: tile-ai/tilelang

Length of output: 957


🏁 Script executed:

#!/bin/bash
# Check the supply_prog flow - what type of params are passed
echo "=== Looking at how supply_prog is called in autotune ==="
rg -n "supply_prog" tilelang/autotuner/tuner.py -B 3 -A 3 | head -40

Repository: tile-ai/tilelang

Length of output: 1970


🏁 Script executed:

#!/bin/bash
# Check if KernelParam is the standard param type or if there are other types
echo "=== Checking KernelParam usage in the autotune system ==="
rg -n "class KernelParam" tilelang/engine/param.py -A 5

Repository: tile-ai/tilelang

Length of output: 331


🏁 Script executed:

#!/bin/bash
# Verify by looking at actual test or usage examples
echo "=== Looking for other examples of supply_prog usage ==="
rg -n "supply_prog" examples/ --type py -B 2 -A 8

Repository: tile-ai/tilelang

Length of output: 4367


Simplify dtype conversion using the built-in torch_dtype() method.

The param objects passed to supply_tensors_gpu are KernelParam instances with a torch_dtype() method that handles dtype conversion comprehensively. Replace the manual dtype_map dictionary with:

torch_dtype = param.torch_dtype()

This approach:

  • Eliminates redundant code
  • Handles all supported dtypes (including bfloat16, float8 variants, etc.) instead of just the four in your manual map
  • Avoids the silent fallback to torch.float16 that could mask dtype mismatches
Before and after
-    dtype_map = {
-        T.float16: torch.float16,
-        T.float32: torch.float32,
-        T.int32: torch.int32,
-        T.int64: torch.int64,
-    }
-
     tensors = []
     for param in params:
         if hasattr(param, "shape") and hasattr(param, "dtype"):
             shape = [int(s) for s in param.shape]
-            torch_dtype = dtype_map.get(param.dtype, torch.float16)
+            torch_dtype = param.torch_dtype()
             tensor = torch.randn(shape, dtype=torch_dtype, device="cuda")
🤖 Prompt for AI Agents
In `@examples/amd/example_amd_flash_attn_fwd.py` around lines 13 - 28, Replace the
manual dtype mapping in supply_tensors_gpu: the current dtype_map and lookup
using param.dtype should be removed and replaced by calling param.torch_dtype()
to obtain the correct PyTorch dtype; update the code that sets torch_dtype (and
remove the dtype_map variable) so tensor creation uses torch_dtype =
param.torch_dtype() ensuring all KernelParam dtypes (bfloat16, float8, etc.) are
handled and no silent fallback to torch.float16 occurs.

Copy link
Member

@LeiWang1999 LeiWang1999 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, Sorry that I forgot to submit review though I left some messages.

shape = [int(s) for s in param.shape]
tensor = torch.randn(shape, dtype=param.dtype, device="cuda")
# Convert TileLang dtype to PyTorch dtype
torch_dtype = dtype_map.get(param.dtype, torch.float16)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

parm.dtype.as_torch()

bx = b_split

while bx < num_q_blocks:
with T.While(bx < num_q_blocks):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why we need to change while into T.While?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants