Fix CUDA 13.2 flash attention build compatibility#141
Conversation
Co-authored-by: Codex <codex@openai.com>
There was a problem hiding this comment.
Pull request overview
Fixes CUDA 13.2 toolchain build failures in the CUDA/CUTLASS FMHA codepaths by adjusting shared-memory address arithmetic patterns and by applying a targeted ptxas optimization workaround for SM90 forward instantiation translation units.
Changes:
- Add CUDA 13.x compatibility helpers in
smem_tile.h(limit unrolling + factor shared-memory address arithmetic into helper functions). - Update SM90 forward kernel instantiation selection to avoid
hdimall/hdimdiffbatching on CUDA 13+. - Apply
-Xptxas -O0to SM90 forward instantiation sources for FA3 and FlashMaskV2 on CUDA 13+.
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 4 comments.
| File | Description |
|---|---|
csrc/flash_attn_with_bias_and_mask/src/fmha/smem_tile.h |
Adds CUDA 13.x ICE workaround via unroll control + helper functions for smem address arithmetic. |
csrc/CMakeLists.txt |
Splits SM90 forward instantiations by head-dim on CUDA 13+ and applies ptxas -O0 to SM90 forward sources to avoid ptxas ICEs. |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| # CUDA 13.2 ptxas can ICE when all SM90 forward head dimensions are batched | ||
| # into the generated hdimall/hdimdiff translation units. Compile them as | ||
| # individual instantiation files on CUDA 13+ to keep each PTX module smaller. | ||
| set(FA3_SPLIT_SM90_FWD_BY_HDIM OFF) | ||
| if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL "13.0") | ||
| set(FA3_SPLIT_SM90_FWD_BY_HDIM ON) | ||
| endif() |
There was a problem hiding this comment.
The comment says the workaround is for “CUDA 13.2 ptxas”, but the version gate enables it for all CUDA >= 13.0. If the issue is specific to 13.2, consider tightening the check to 13.2+ (or adjust the comment to reflect that all 13.x are affected) so we don’t unnecessarily change build characteristics for unaffected toolchains.
| if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL "13.0") | ||
| # CUDA 13.2 ptxas can ICE on FlashMaskV2 SM90 forward instantiations | ||
| # even though these files are already split by head dimension. | ||
| set(flashmaskv2_sm90_fwd_ptxas_o0_sources ${flashmaskv2_sources_fwd_sm90}) | ||
| endif() |
There was a problem hiding this comment.
This section also describes a “CUDA 13.2 ptxas” ICE, but the workaround is applied whenever CMAKE_CUDA_COMPILER_VERSION >= 13.0. If only 13.2 is problematic, consider tightening the version predicate (or updating the comment) to avoid disabling ptxas optimizations on 13.0/13.1 unnecessarily.
| #if defined(__CUDACC_VER_MAJOR__) && __CUDACC_VER_MAJOR__ >= 13 | ||
| // CUDA 13.x can ICE in the legacy FA1 dense-mask path when these loops are | ||
| // fully unrolled around shared-memory address arithmetic. | ||
| #define FMHA_COMPAT_UNROLL _Pragma("unroll 1") | ||
| #define FMHA_COMPAT_HELPER __device__ __noinline__ | ||
| using Smem_offset_type = uint32_t; | ||
| #else | ||
| #define FMHA_COMPAT_UNROLL _Pragma("unroll") | ||
| #define FMHA_COMPAT_HELPER inline __device__ | ||
| using Smem_offset_type = uint32_t; | ||
| #endif |
There was a problem hiding this comment.
FMHA_COMPAT_UNROLL / FMHA_COMPAT_HELPER are introduced as preprocessor macros in a public header, which will leak into any translation unit that includes this file. This repo’s fmha headers don’t appear to use/undef similar macros, so this increases risk of name collisions and makes debugging harder. Consider replacing these with #if blocks at the call sites (so no macro identifiers escape), or at least #undef FMHA_COMPAT_UNROLL / #undef FMHA_COMPAT_HELPER at the end of the header after the last use.
| #pragma once | ||
|
|
||
| #include "utils.h" | ||
| #include <cstddef> |
There was a problem hiding this comment.
#include <cstddef> is newly added but doesn’t appear to be used anywhere in this header (only size_t mentions are in comments). If it’s not needed for the CUDA 13.x workaround, consider removing it to keep includes minimal.
| #include <cstddef> |
|
|
||
| if(FA3_SPLIT_SM90_FWD_BY_HDIM AND fa3_sm90_fwd_ptxas_o0_sources) | ||
| set_property(SOURCE ${fa3_sm90_fwd_ptxas_o0_sources} APPEND PROPERTY | ||
| COMPILE_OPTIONS -Xptxas -O0) |
There was a problem hiding this comment.
-Xptxas -O0 disables the PTX assembler's optimizations. Could we explore alternatives that avoid the CUDA 13.2 ptxas ICE without sacrificing kernel performance?
|
|
||
| if(flashmaskv2_sm90_fwd_ptxas_o0_sources) | ||
| set_property(SOURCE ${flashmaskv2_sm90_fwd_ptxas_o0_sources} APPEND PROPERTY | ||
| COMPILE_OPTIONS -Xptxas -O0) |
There was a problem hiding this comment.
-Xptxas -O0 disables the PTX assembler's optimizations. Could we explore alternatives that avoid the CUDA 13.2 ptxas ICE without sacrificing kernel performance?
Co-authored-by: Codex <codex@openai.com>
|
@gouzil Why is the cutlass update needed here? |
跟 Dao-AILab#1860 的原因是一样的,在 cutlass v4.2.0 版本才修复, 但是我目前是直接更新到了 v4.3.5, 还在测试 paddle 那边的影响 |
@gouzil 加上Dao-AILab#1860 的修复之后,不升级cutlass还是会挂吗? |
会挂,编译的时候挂的 |
fix
修复 CUDA 13.2 下的编译错误
smem_tile.h地址计算重写, 替换为等价逻辑对 SM90 forward 实例化文件关闭ptxas优化,避免ptxas fatal: C7907params.num_consumers(例如,32)小于NumThreadsPerWarpGroup(128)时,整数除法会导致num_consumer_warpgroups_per_cluster为 0,从而在初始化期间导致编译器报错。改为向上取整除法,以确保最小值为 1。cc: @swgu98
最小复现案例
参考链接: