[AMD] Fix bugs about AMD FA kernel #1701
base: main
Conversation
👋 Hi! Thank you for contributing to the TileLang project. We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀
Note: CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.

📝 Walkthrough: A new alias attribute, `kernel_global_source`, is added to the Cython kernel adapter, and the AMD flash-attention forward example is updated.
Estimated code review effort: 🎯 2 (Simple) | ⏱️ ~10 minutes
🚥 Pre-merge checks: ✅ 2 passed | ❌ 1 failed (1 warning)
Pull request overview
This PR fixes two issues with the AMD Flash Attention kernel:
- Adds the missing `kernel_global_source` attribute to `CythonKernelAdapter`, which is required when saving/loading tuned kernels
- Removes problematic configuration parameters (`block_M=32`, `block_N=32`, `threads=512`) that don't work with the current implementation
Changes:
- Added a `kernel_global_source` attribute as an alias for `device_kernel_source` in `CythonKernelAdapter` for compatibility (a sketch of this pattern follows below)
- Updated `supply_tensors_gpu` to properly map TileLang dtypes to PyTorch dtypes
- Removed problematic tuning configurations and switched from `while` to the `T.While` construct
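For context, here is a minimal sketch of the alias pattern described in the first item above, assuming a heavily simplified adapter; the real `CythonKernelAdapter` in `tilelang/jit/adapter/cython/adapter.py` carries much more state and logic.

```python
# Simplified, hypothetical sketch of the alias described in this PR; not the
# actual adapter code.
class CythonKernelAdapter:
    def __init__(self, device_kernel_source: str):
        self.device_kernel_source = device_kernel_source
        # Alias kept so code paths that expect `kernel_global_source`
        # (e.g. when saving/loading tuned kernels) keep working.
        self.kernel_global_source = device_kernel_source

    @classmethod
    def from_database(cls, device_kernel_source: str) -> "CythonKernelAdapter":
        adapter = cls.__new__(cls)
        adapter.device_kernel_source = device_kernel_source
        # Initialize the alias here too, so adapters restored from the kernel
        # database behave like freshly constructed ones.
        adapter.kernel_global_source = device_kernel_source
        return adapter
```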
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 2 comments.
| File | Description |
|---|---|
| `tilelang/jit/adapter/cython/adapter.py` | Added `kernel_global_source` field and properly initialized it in both `__init__` and `from_database` methods as an alias to `device_kernel_source` |
| `examples/amd/example_amd_flash_attn_fwd.py` | Fixed dtype mapping, removed problematic configs, updated to use the `T.While` construct, and improved comments |
```diff
+    # TileLang dtype to PyTorch dtype mapping
+    dtype_map = {
+        T.float16: torch.float16,
+        T.float32: torch.float32,
+        T.int32: torch.int32,
+        T.int64: torch.int64,
+    }
+
     tensors = []
     for param in params:
         if hasattr(param, "shape") and hasattr(param, "dtype"):
             # Force creation on GPU device
             shape = [int(s) for s in param.shape]
-            tensor = torch.randn(shape, dtype=param.dtype, device="cuda")
+            # Convert TileLang dtype to PyTorch dtype
+            torch_dtype = dtype_map.get(param.dtype, torch.float16)
+            tensor = torch.randn(shape, dtype=torch_dtype, device="cuda")
```
Copilot AI commented on Jan 20, 2026:
The manual dtype mapping could be simplified by using the built-in `torch_dtype()` method from `KernelParam`. Instead of maintaining a manual `dtype_map` dictionary, you can replace lines 14-19 and 27 with:

```python
torch_dtype = param.torch_dtype()
```

This would eliminate the need for manual mapping and automatically handle all dtype conversions, including edge cases like float8 types that require special handling for HIP vs. CUDA backends.
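As a rough sketch of what this simplification could look like in the example's `supply_tensors_gpu` (an illustration under the assumption that each `param` is a `KernelParam` exposing `shape` and `torch_dtype()`, not the exact code from the PR):

```python
import torch

def supply_tensors_gpu(params):
    # Illustration of the suggested change; assumes each `param` is a
    # tilelang KernelParam exposing `.shape` and `.torch_dtype()`.
    tensors = []
    for param in params:
        if hasattr(param, "shape") and hasattr(param, "dtype"):
            shape = [int(s) for s in param.shape]
            torch_dtype = param.torch_dtype()  # replaces the manual dtype_map
            # Note: torch.randn only supports floating-point dtypes; integer
            # parameters would need something like torch.randint instead.
            tensors.append(torch.randn(shape, dtype=torch_dtype, device="cuda"))
    return tensors
```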
Actionable comments posted: 1
🤖 Fix all issues with AI agents
In `@examples/amd/example_amd_flash_attn_fwd.py`:
- Around lines 13-28: Replace the manual dtype mapping in `supply_tensors_gpu`: the current `dtype_map` and the lookup using `param.dtype` should be removed and replaced by calling `param.torch_dtype()` to obtain the correct PyTorch dtype; update the code that sets `torch_dtype` (and remove the `dtype_map` variable) so tensor creation uses `torch_dtype = param.torch_dtype()`, ensuring all `KernelParam` dtypes (bfloat16, float8, etc.) are handled and no silent fallback to `torch.float16` occurs.
🧹 Nitpick comments (1)
examples/amd/example_amd_flash_attn_fwd.py (1)
53-57: Add a comment explaining why specific configurations were removed.

The docstring mentions "avoiding problematic combinations" but doesn't explain the root cause. Per the PR description, `block_M=32` and `block_N=32` cause layout inference conflicts and `threads=512` causes divide-by-zero errors. Adding this context prevents future maintainers from re-introducing these values.

📝 Suggested improvement
```diff
 def get_configs():
-    """Generates configurations for the autotuner, avoiding problematic combinations."""
-    block_M = [64, 128, 256]
-    block_N = [64, 128, 256]
-    threads = [128, 256]
+    """Generates configurations for the autotuner, avoiding problematic combinations.
+
+    Note: block_M/N=32 removed due to layout inference conflicts between acc_s and
+    acc_s_cast fragments. threads=512 removed to avoid divide-by-zero check failures.
+    """
+    block_M = [64, 128, 256]  # 32 causes layout conflict
+    block_N = [64, 128, 256]  # 32 causes layout conflict
+    threads = [128, 256]  # 512 causes divide-by-zero
```
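For illustration only, this is one hypothetical way such value lists are commonly expanded into a configuration sweep; the actual autotuner wiring in the example may differ.

```python
from itertools import product

def get_configs():
    # Hypothetical expansion of the pruned value lists into config dicts;
    # the values below are the ones kept by this PR.
    block_M = [64, 128, 256]  # 32 removed: layout inference conflict
    block_N = [64, 128, 256]  # 32 removed: layout inference conflict
    threads = [128, 256]      # 512 removed: divide-by-zero check failure
    return [
        {"block_M": m, "block_N": n, "threads": t}
        for m, n, t in product(block_M, block_N, threads)
    ]
```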
🧩 Analysis chain
🏁 Script executed:

```bash
#!/bin/bash
# First, read the actual file to understand context
echo "=== File content around lines 13-28 ==="
head -50 examples/amd/example_amd_flash_attn_fwd.py | cat -n
```

Repository: tile-ai/tilelang
Length of output: 2301

🏁 Script executed:

```bash
#!/bin/bash
# Check the KernelParam class for torch_dtype method
echo "=== Searching for torch_dtype in param.py ==="
rg -n "torch_dtype" tilelang/engine/param.py -B 3 -A 10
```

Repository: tile-ai/tilelang
Length of output: 1004

🏁 Script executed:

```bash
#!/bin/bash
# Find the supply_tensors_gpu function to see how params are created/used
echo "=== Searching for supply_tensors_gpu function ==="
rg -n "def supply_tensors_gpu" examples/amd/example_amd_flash_attn_fwd.py -B 2 -A 50 | head -80
```

Repository: tile-ai/tilelang
Length of output: 2218

🏁 Script executed:

```bash
#!/bin/bash
# Look for where tensors are created from params to understand the full flow
echo "=== Checking how params are passed into this function ==="
rg -n "supply_tensors_gpu" examples/amd/example_amd_flash_attn_fwd.py
```

Repository: tile-ai/tilelang
Length of output: 255

🏁 Script executed:

```bash
#!/bin/bash
# Verify that params are indeed KernelParam objects by checking autotune decorator
echo "=== Searching for autotune decorator and how it passes params ==="
rg -n "def autotune" tilelang/ -A 20 | head -80
```

Repository: tile-ai/tilelang
Length of output: 1518

🏁 Script executed:

```bash
#!/bin/bash
# Check if there are any other uses of dtype mapping in similar contexts
echo "=== Checking for other dtype mapping patterns in the codebase ==="
rg -n "dtype_map" --type py | head -20
```

Repository: tile-ai/tilelang
Length of output: 2198

🏁 Script executed:

```bash
#!/bin/bash
# Verify the T.dtype().as_torch() handles all the expected types
echo "=== Checking what T.dtype().as_torch() implementation looks like ==="
rg -n "def as_torch" tilelang/ -B 2 -A 10 | head -60
```

Repository: tile-ai/tilelang
Length of output: 865

🏁 Script executed:

```bash
#!/bin/bash
# Get the full implementation of as_torch()
echo "=== Full as_torch() implementation ==="
rg -n "def as_torch" tilelang/language/dtypes.py -A 20
```

Repository: tile-ai/tilelang
Length of output: 632

🏁 Script executed:

```bash
#!/bin/bash
# Check the dtype class in TVM to understand what as_torch covers
echo "=== Checking what dtypes are defined in tilelang ==="
grep -n "float16\|float32\|int32\|int64\|bfloat16\|float8" tilelang/language/dtypes.py | head -30
```

Repository: tile-ai/tilelang
Length of output: 957

🏁 Script executed:

```bash
#!/bin/bash
# Check the supply_prog flow - what type of params are passed
echo "=== Looking at how supply_prog is called in autotune ==="
rg -n "supply_prog" tilelang/autotuner/tuner.py -B 3 -A 3 | head -40
```

Repository: tile-ai/tilelang
Length of output: 1970

🏁 Script executed:

```bash
#!/bin/bash
# Check if KernelParam is the standard param type or if there are other types
echo "=== Checking KernelParam usage in the autotune system ==="
rg -n "class KernelParam" tilelang/engine/param.py -A 5
```

Repository: tile-ai/tilelang
Length of output: 331

🏁 Script executed:

```bash
#!/bin/bash
# Verify by looking at actual test or usage examples
echo "=== Looking for other examples of supply_prog usage ==="
rg -n "supply_prog" examples/ --type py -B 2 -A 8
```

Repository: tile-ai/tilelang
Length of output: 4367
Simplify dtype conversion using the built-in `torch_dtype()` method.

The `param` objects passed to `supply_tensors_gpu` are `KernelParam` instances with a `torch_dtype()` method that handles dtype conversion comprehensively. Replace the manual `dtype_map` dictionary with:

```python
torch_dtype = param.torch_dtype()
```

This approach:
- Eliminates redundant code
- Handles all supported dtypes (including bfloat16, float8 variants, etc.) instead of just the four in your manual map
- Avoids the silent fallback to `torch.float16` that could mask dtype mismatches
Before and after:

```diff
-    dtype_map = {
-        T.float16: torch.float16,
-        T.float32: torch.float32,
-        T.int32: torch.int32,
-        T.int64: torch.int64,
-    }
-
     tensors = []
     for param in params:
         if hasattr(param, "shape") and hasattr(param, "dtype"):
             shape = [int(s) for s in param.shape]
-            torch_dtype = dtype_map.get(param.dtype, torch.float16)
+            torch_dtype = param.torch_dtype()
             tensor = torch.randn(shape, dtype=torch_dtype, device="cuda")
```
LeiWang1999 left a comment:
LGTM. Sorry that I forgot to submit the review even though I left some messages.
```diff
             shape = [int(s) for s in param.shape]
-            tensor = torch.randn(shape, dtype=param.dtype, device="cuda")
+            # Convert TileLang dtype to PyTorch dtype
+            torch_dtype = dtype_map.get(param.dtype, torch.float16)
```
`param.dtype.as_torch()`
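Both review suggestions point at the same simplification; here is a minimal sketch, assuming the accessors exist exactly as named in the comments above (the names are taken from those comments, not verified here):

```python
# Sketch only: both accessors come from the review comments above and are
# assumed to return a torch.dtype for the given kernel parameter.
torch_dtype = param.torch_dtype()      # Copilot's suggestion (KernelParam method)
# or, per LeiWang1999's comment:
torch_dtype = param.dtype.as_torch()   # dtype-level accessor
tensor = torch.randn(shape, dtype=torch_dtype, device="cuda")
```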
```diff
         bx = b_split

-        while bx < num_q_blocks:
+        with T.While(bx < num_q_blocks):
```
Why do we need to change `while` into `T.While`?
As TileLang has been updated, I found some issues when running the AMD FA kernel; this PR proposes a workaround for the AMD FA example.
and
Are the two issues expected? Thanks.
Summary by CodeRabbit
Chores
Refactor