From 3b2221b33b4c100fdca6eed9fe8535d7634afa31 Mon Sep 17 00:00:00 2001 From: gouzi <530971494@qq.com> Date: Tue, 28 Apr 2026 09:40:38 +0800 Subject: [PATCH 1/3] Fix CUDA 13.2 flash attention build compatibility Co-authored-by: Codex --- csrc/CMakeLists.txt | 62 +++++++- .../src/fmha/smem_tile.h | 144 +++++++++++------- 2 files changed, 149 insertions(+), 57 deletions(-) diff --git a/csrc/CMakeLists.txt b/csrc/CMakeLists.txt index 13635729d46..7a8c2c2fc40 100644 --- a/csrc/CMakeLists.txt +++ b/csrc/CMakeLists.txt @@ -257,7 +257,19 @@ if(NOT SKIP_BUILD_FA) list(APPEND HEAD_DIMENSIONS_BWD 256) endif() - set(HEAD_DIMENSIONS_FWD "all" "diff") + # CUDA 13.2 ptxas can ICE when all SM90 forward head dimensions are batched + # into the generated hdimall/hdimdiff translation units. Compile them as + # individual instantiation files on CUDA 13+ to keep each PTX module smaller. + set(FA3_SPLIT_SM90_FWD_BY_HDIM OFF) + if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL "13.0") + set(FA3_SPLIT_SM90_FWD_BY_HDIM ON) + endif() + + if(FA3_SPLIT_SM90_FWD_BY_HDIM) + set(HEAD_DIMENSIONS_FWD ${HEAD_DIMENSIONS_BWD}) + else() + set(HEAD_DIMENSIONS_FWD "all" "diff") + endif() set(HEAD_DIMENSIONS_FWD_SM80 ${HEAD_DIMENSIONS_BWD}) set(SPLIT "__EMPTY__") @@ -303,6 +315,7 @@ if(NOT SKIP_BUILD_FA) endforeach() set(sources_fwd_sm90) + set(fa3_sm90_fwd_ptxas_o0_sources) foreach(hdim ${HEAD_DIMENSIONS_FWD}) foreach(dtype ${DTYPE_FWD_SM90}) foreach(split ${SPLIT}) @@ -321,6 +334,36 @@ if(NOT SKIP_BUILD_FA) endforeach() endforeach() + if(FA3_SPLIT_SM90_FWD_BY_HDIM) + foreach(dtype ${DTYPE_FWD_SM90}) + foreach(split ${SPLIT}) + foreach(paged ${PAGEDKV}) + foreach(softcap ${SOFTCAP}) + foreach(packgqa ${PACKGQA}) + if(packgqa STREQUAL "__EMPTY__" OR (paged STREQUAL "__EMPTY__" AND split STREQUAL "__EMPTY__")) + if(NOT dtype STREQUAL "e4m3") + set(name "flash_attn_v3/instantiations/flash_fwd_hdim64_512_${dtype}${paged}${split}${softcap}${packgqa}_sm90.cu") + string(REPLACE "__EMPTY__" "" refine_name "${name}") + list(APPEND sources_fwd_sm90 "${refine_name}") + endif() + set(name "flash_attn_v3/instantiations/flash_fwd_hdim192_128_${dtype}${paged}${split}${softcap}${packgqa}_sm90.cu") + string(REPLACE "__EMPTY__" "" refine_name "${name}") + list(APPEND sources_fwd_sm90 "${refine_name}") + endif() + endforeach() + endforeach() + endforeach() + endforeach() + endforeach() + endif() + + if(FA3_SPLIT_SM90_FWD_BY_HDIM) + # CUDA 13.2 ptxas can ICE on multiple SM90 forward instantiations even + # after hdimall/hdimdiff are split, so keep the workaround scoped to + # SM90 forward sources instead of weakening backward or SM80 builds. + set(fa3_sm90_fwd_ptxas_o0_sources ${sources_fwd_sm90}) + endif() + set(sources_bwd_sm80) foreach(hdim ${HEAD_DIMENSIONS_BWD}) foreach(dtype ${DTYPE_BWD}) @@ -390,6 +433,11 @@ if(NOT SKIP_BUILD_FA) --expt-relaxed-constexpr >) + if(FA3_SPLIT_SM90_FWD_BY_HDIM AND fa3_sm90_fwd_ptxas_o0_sources) + set_property(SOURCE ${fa3_sm90_fwd_ptxas_o0_sources} APPEND PROPERTY + COMPILE_OPTIONS -Xptxas -O0) + endif() + INSTALL(TARGETS flashattnv3 LIBRARY DESTINATION "lib") INSTALL(FILES flash_attn_v3/flash_api.h DESTINATION "include" RENAME flashv3_api.h) @@ -576,6 +624,13 @@ if(NOT SKIP_BUILD_FA) endforeach() endforeach() + set(flashmaskv2_sm90_fwd_ptxas_o0_sources) + if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL "13.0") + # CUDA 13.2 ptxas can ICE on FlashMaskV2 SM90 forward instantiations + # even though these files are already split by head dimension. + set(flashmaskv2_sm90_fwd_ptxas_o0_sources ${flashmaskv2_sources_fwd_sm90}) + endif() + set(flashmaskv2_sources_bwd_sm80) foreach(hdim ${FLASHMASKV2_HEAD_DIMENSIONS_BWD}) foreach(dtype ${FLASHMASKV2_DTYPE_BWD}) @@ -685,6 +740,11 @@ if(NOT SKIP_BUILD_FA) --expt-relaxed-constexpr >) + if(flashmaskv2_sm90_fwd_ptxas_o0_sources) + set_property(SOURCE ${flashmaskv2_sm90_fwd_ptxas_o0_sources} APPEND PROPERTY + COMPILE_OPTIONS -Xptxas -O0) + endif() + if(WITH_DISTRIBUTED_OVERLAP) target_include_directories(flashmaskv2 PRIVATE ${NVSHMEM_INCLUDE_DIR}) target_link_libraries(flashmaskv2 PRIVATE flashmaskv2_distributed) diff --git a/csrc/flash_attn_with_bias_and_mask/src/fmha/smem_tile.h b/csrc/flash_attn_with_bias_and_mask/src/fmha/smem_tile.h index 491253bb999..6576a0e9a81 100644 --- a/csrc/flash_attn_with_bias_and_mask/src/fmha/smem_tile.h +++ b/csrc/flash_attn_with_bias_and_mask/src/fmha/smem_tile.h @@ -28,11 +28,30 @@ #pragma once #include "utils.h" +#include #include #include namespace fmha { +#if defined(__CUDACC_VER_MAJOR__) && __CUDACC_VER_MAJOR__ >= 13 +// CUDA 13.x can ICE in the legacy FA1 dense-mask path when these loops are +// fully unrolled around shared-memory address arithmetic. +#define FMHA_COMPAT_UNROLL _Pragma("unroll 1") +#define FMHA_COMPAT_HELPER __device__ __noinline__ +using Smem_offset_type = uint32_t; +#else +#define FMHA_COMPAT_UNROLL _Pragma("unroll") +#define FMHA_COMPAT_HELPER inline __device__ +using Smem_offset_type = uint32_t; +#endif + +template +inline __device__ uint32_t ToSmemPtr(T ptr) { + return static_cast(ptr); +} + + //////////////////////////////////////////////////////////////////////////////////////////////////// template< @@ -1259,9 +1278,9 @@ struct Smem_tile_mma { template inline __device__ void store(const uint4 (®s)[M][N]) { static_assert(COLS == Cta_tile::N); - #pragma unroll + FMHA_COMPAT_UNROLL for( int mi = 0; mi < M; mi++ ) { - #pragma unroll + FMHA_COMPAT_UNROLL for( int ni = 0; ni < N; ni++ ) { // size_t offset = write_offset_ + mi * WARPS_M * 16 * BYTES_PER_ROW + ni * WARPS_N * 16 * BYTES_PER_ELT; // fmha::sts(smem_ + offset + 0 * BYTES_PER_ROW, regs[mi][ni].x); @@ -1270,12 +1289,14 @@ struct Smem_tile_mma { // fmha::sts(smem_ + offset + 0 * BYTES_PER_ROW, regs[mi][ni].y); // fmha::sts(smem_ + offset + 8 * BYTES_PER_ROW, regs[mi][ni].w); // size_t offset = smem_write_ + mi * WARPS_M * 16 * BYTES_PER_ROW + ni * WARPS_N * 16 * BYTES_PER_ELT; - uint32_t offset = smem_write_ + mi * WARPS_M * 16 * BYTES_PER_ROW + ni * WARPS_N * 16 * BYTES_PER_ELT; - fmha::sts(offset + 0 * BYTES_PER_ROW, regs[mi][ni].x); - fmha::sts(offset + 8 * BYTES_PER_ROW, regs[mi][ni].z); + Smem_offset_type offset = + smem_write_ + mi * WARPS_M * 16 * BYTES_PER_ROW + + ni * WARPS_N * 16 * BYTES_PER_ELT; + fmha::sts(ToSmemPtr(offset + 0 * BYTES_PER_ROW), regs[mi][ni].x); + fmha::sts(ToSmemPtr(offset + 8 * BYTES_PER_ROW), regs[mi][ni].z); offset ^= 4 * BYTES_PER_STS; - fmha::sts(offset + 0 * BYTES_PER_ROW, regs[mi][ni].y); - fmha::sts(offset + 8 * BYTES_PER_ROW, regs[mi][ni].w); + fmha::sts(ToSmemPtr(offset + 0 * BYTES_PER_ROW), regs[mi][ni].y); + fmha::sts(ToSmemPtr(offset + 8 * BYTES_PER_ROW), regs[mi][ni].w); } } } @@ -1284,9 +1305,9 @@ struct Smem_tile_mma { inline __device__ void store(const Fragment (&frag)[N][M]) { static_assert(COLS == Cta_tile::N); uint4 regs[M][N]; - #pragma unroll + FMHA_COMPAT_UNROLL for( int mi = 0; mi < M; mi++ ) { - #pragma unroll + FMHA_COMPAT_UNROLL for( int ni = 0; ni < N; ni++ ) { // Need to transpose ref(1) and reg(2) here since when we load it we transpose again. regs[mi][ni] = make_uint4(frag[ni][mi].reg(0), frag[ni][mi].reg(2), @@ -1333,8 +1354,10 @@ struct Smem_tile_mma_transposed : public Base { uint4 dst; // fmha::ldsmt(dst, this->smem_ + offset); // size_t offset = smem_read_ + mi * WARPS_M * 16 * BYTES_PER_ROW + ni * WARPS_N * 16 * BYTES_PER_ELT; - uint32_t offset = smem_read_ + mi * WARPS_M * 16 * BYTES_PER_ROW + ni * WARPS_N * 16 * BYTES_PER_ELT; - fmha::ldsmt(dst, offset); + Smem_offset_type offset = + smem_read_ + mi * WARPS_M * 16 * BYTES_PER_ROW + + ni * WARPS_N * 16 * BYTES_PER_ELT; + fmha::ldsmt(dst, ToSmemPtr(offset)); frag[mi][ni].reg(0) = dst.x; frag[mi][ni].reg(1) = dst.z; // Fragment A regs col major! frag[mi][ni].reg(2) = dst.y; @@ -1379,16 +1402,16 @@ struct Smem_tile_mma_epilogue : public Base { // size_t offset = read_offset_ + ii * ROWS_PER_LDS * BYTES_PER_ROW; // fmha::lds(data[ii], this->smem_ + offset); // size_t offset = smem_read_ + ii * ROWS_PER_LDS * BYTES_PER_ROW; - uint32_t offset = smem_read_ + ii * ROWS_PER_LDS * BYTES_PER_ROW; - fmha::lds(data[ii], offset); + Smem_offset_type offset = smem_read_ + ii * ROWS_PER_LDS * BYTES_PER_ROW; + fmha::lds(data[ii], ToSmemPtr(offset)); } } template inline __device__ void store(const Acc (&acc)[M][N]){ - #pragma unroll + FMHA_COMPAT_UNROLL for( int mi = 0; mi < M; mi++ ) { - #pragma unroll + FMHA_COMPAT_UNROLL for( int ni = 0; ni < N; ni++ ) { // 1st row - 4 elements per row. float tmp00 = acc[mi][ni].elt(0); @@ -1413,15 +1436,16 @@ struct Smem_tile_mma_epilogue : public Base { // fmha::sts(this->smem_ + offset + 0 * BYTES_PER_ROW, y); // fmha::sts(this->smem_ + offset + 8 * BYTES_PER_ROW, w); // size_t offset = (this->smem_write_ ^ (ni * 32)) + mi * WARPS_M * 16 * BYTES_PER_ROW; - uint32_t offset = (this->smem_write_ ^ (ni * 32)) + mi * WARPS_M * 16 * BYTES_PER_ROW; + Smem_offset_type offset = + (this->smem_write_ ^ (ni * 32)) + mi * WARPS_M * 16 * BYTES_PER_ROW; // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) { // printf("mi = %d, ni = %d, offset - smem_write_ = %d\n", mi, ni, offset - this->smem_write_); // } - fmha::sts(offset + 0 * BYTES_PER_ROW, x); - fmha::sts(offset + 8 * BYTES_PER_ROW, z); + fmha::sts(ToSmemPtr(offset + 0 * BYTES_PER_ROW), x); + fmha::sts(ToSmemPtr(offset + 8 * BYTES_PER_ROW), z); offset ^= 4 * Base::BYTES_PER_STS; - fmha::sts(offset + 0 * BYTES_PER_ROW, y); - fmha::sts(offset + 8 * BYTES_PER_ROW, w); + fmha::sts(ToSmemPtr(offset + 0 * BYTES_PER_ROW), y); + fmha::sts(ToSmemPtr(offset + 8 * BYTES_PER_ROW), w); } } } @@ -1431,12 +1455,13 @@ struct Smem_tile_mma_epilogue : public Base { for( int mi = 0; mi < M; mi++ ) { for( int ni = 0; ni < N; ni++ ) { // size_t offset = (this->write_offset_ ^ (ni * 32)) + mi * WARPS_M * 16 * BYTES_PER_ROW; - uint32_t offset = (this->write_offset_ ^ (ni * 32)) + mi * WARPS_M * 16 * BYTES_PER_ROW; - fmha::sts(this->smem_ + offset + 0 * BYTES_PER_ROW, regs[mi][ni].x); - fmha::sts(this->smem_ + offset + 8 * BYTES_PER_ROW, regs[mi][ni].z); + Smem_offset_type offset = + (this->write_offset_ ^ (ni * 32)) + mi * WARPS_M * 16 * BYTES_PER_ROW; + fmha::sts(ToSmemPtr(this->smem_ + offset + 0 * BYTES_PER_ROW), regs[mi][ni].x); + fmha::sts(ToSmemPtr(this->smem_ + offset + 8 * BYTES_PER_ROW), regs[mi][ni].z); offset ^= 4 * Base::BYTES_PER_STS; - fmha::sts(this->smem_ + offset + 0 * BYTES_PER_ROW, regs[mi][ni].y); - fmha::sts(this->smem_ + offset + 8 * BYTES_PER_ROW, regs[mi][ni].w); + fmha::sts(ToSmemPtr(this->smem_ + offset + 0 * BYTES_PER_ROW), regs[mi][ni].y); + fmha::sts(ToSmemPtr(this->smem_ + offset + 8 * BYTES_PER_ROW), regs[mi][ni].w); } } } @@ -1485,7 +1510,6 @@ struct Smem_tile_transpose { write_col ^= (write_row & 0x07) * 4; write_offset_ = write_row * BYTES_PER_ROW + write_col * BYTES_PER_STS; - // smem_write_ = smem_ + write_row * BYTES_PER_ROW + write_col * BYTES_PER_STS; int read_row, read_col; read_row = (tidx & 0x0f); @@ -1493,31 +1517,48 @@ struct Smem_tile_transpose { read_col ^= (read_row & 0x07); read_offset_ = read_row * BYTES_PER_ROW + read_col * BYTES_PER_LDS; - // smem_read_ = smem_ + read_row * BYTES_PER_ROW + read_col * BYTES_PER_LDS; + } + + FMHA_COMPAT_HELPER + void store_fragment(Smem_offset_type base_offset, const Fragment_write &frag) { + const uint32_t ptr0 = ToSmemPtr(smem_ + base_offset + 0 * BYTES_PER_ROW); + const uint32_t ptr1 = ToSmemPtr(smem_ + base_offset + 8 * BYTES_PER_ROW); + const Smem_offset_type swizzled_offset = base_offset ^ (4 * BYTES_PER_STS); + const uint32_t ptr2 = ToSmemPtr(smem_ + swizzled_offset + 0 * BYTES_PER_ROW); + const uint32_t ptr3 = ToSmemPtr(smem_ + swizzled_offset + 8 * BYTES_PER_ROW); + const uint32_t reg0 = frag.reg(0); + const uint32_t reg1 = frag.reg(1); + const uint32_t reg2 = frag.reg(2); + const uint32_t reg3 = frag.reg(3); + fmha::sts(ptr0, reg0); + fmha::sts(ptr1, reg2); + fmha::sts(ptr2, reg1); + fmha::sts(ptr3, reg3); + } + + FMHA_COMPAT_HELPER + uint4 load_fragment(Smem_offset_type offset) { + uint4 dst; + fmha::ldsmt(dst, ToSmemPtr(smem_ + offset)); + return dst; } template inline __device__ void store(const Fragment_write (&frag_w)[M][N], int mi) { - #pragma unroll + FMHA_COMPAT_UNROLL for( int ni = 0; ni < N; ni++ ) { - // size_t offset = write_offset_ + ni * WARPS_N * 16 * BYTES_PER_ELT; - uint32_t offset = write_offset_ + ni * WARPS_N * 16 * BYTES_PER_ELT; - fmha::sts(smem_ + offset + 0 * BYTES_PER_ROW, frag_w[ni][mi].reg(0)); - fmha::sts(smem_ + offset + 8 * BYTES_PER_ROW, frag_w[ni][mi].reg(2)); - offset ^= 4 * BYTES_PER_STS; - fmha::sts(smem_ + offset + 0 * BYTES_PER_ROW, frag_w[ni][mi].reg(1)); - fmha::sts(smem_ + offset + 8 * BYTES_PER_ROW, frag_w[ni][mi].reg(3)); + const Smem_offset_type base = + write_offset_ + ni * WARPS_N * 16 * BYTES_PER_ELT; + store_fragment(base, frag_w[ni][mi]); } } template inline __device__ void load(Fragment_read (&frag_r)[N]) { - #pragma unroll + FMHA_COMPAT_UNROLL for( int ni = 0; ni < N; ni++ ) { - // size_t offset = read_offset_ + ni * WARPS_N * 16 * BYTES_PER_ELT; - uint32_t offset = read_offset_ + ni * WARPS_N * 16 * BYTES_PER_ELT; - uint4 dst; - fmha::ldsmt(dst, this->smem_ + offset); + Smem_offset_type offset = read_offset_ + ni * WARPS_N * 16 * BYTES_PER_ELT; + const uint4 dst = load_fragment(offset); frag_r[ni].reg(0) = dst.x; frag_r[ni].reg(1) = dst.y; // Fragment B regs col major! frag_r[ni].reg(2) = dst.z; @@ -1528,23 +1569,16 @@ struct Smem_tile_transpose { template inline __device__ void transpose(const Fragment_write (&frag_w)[M][N], Fragment_read (&frag_r)[M], int mi) { static_assert(COLS == Cta_tile::N); - #pragma unroll + FMHA_COMPAT_UNROLL for( int ni = 0; ni < N; ni++ ) { - // size_t offset = write_offset_ + ni * WARPS_N * 16 * BYTES_PER_ELT; - uint32_t offset = write_offset_ + ni * WARPS_N * 16 * BYTES_PER_ELT; - fmha::sts(smem_ + offset + 0 * BYTES_PER_ROW, frag_w[ni][mi].reg(0)); - fmha::sts(smem_ + offset + 8 * BYTES_PER_ROW, frag_w[ni][mi].reg(2)); - offset ^= 4 * BYTES_PER_STS; - fmha::sts(smem_ + offset + 0 * BYTES_PER_ROW, frag_w[ni][mi].reg(1)); - fmha::sts(smem_ + offset + 8 * BYTES_PER_ROW, frag_w[ni][mi].reg(3)); + const Smem_offset_type base = + write_offset_ + ni * WARPS_N * 16 * BYTES_PER_ELT; + store_fragment(base, frag_w[ni][mi]); } - #pragma unroll + FMHA_COMPAT_UNROLL for( int ni = 0; ni < N; ni++ ) { - // size_t offset = read_offset_ + ni * WARPS_N * 16 * BYTES_PER_ELT; - // size_t offset = read_offset_ + ni * WARPS_N * 16 * BYTES_PER_ELT; - uint32_t offset = read_offset_ + ni * WARPS_N * 16 * BYTES_PER_ELT; - uint4 dst; - fmha::ldsmt(dst, this->smem_ + offset); + Smem_offset_type offset = read_offset_ + ni * WARPS_N * 16 * BYTES_PER_ELT; + const uint4 dst = load_fragment(offset); frag_r[ni].reg(0) = dst.x; frag_r[ni].reg(1) = dst.y; // Fragment B regs col major! frag_r[ni].reg(2) = dst.z; @@ -1555,8 +1589,6 @@ struct Smem_tile_transpose { uint32_t smem_; uint32_t write_offset_; uint32_t read_offset_; - // uint32_t smem_write_; - // uint32_t smem_read_; }; //////////////////////////////////////////////////////////////////////////////////////////////////// From acac4d1e55ab48d84a747f4658ded97427db4825 Mon Sep 17 00:00:00 2001 From: gouzi <530971494@qq.com> Date: Sun, 10 May 2026 15:21:25 +0800 Subject: [PATCH 2/3] Update SM90 pipeline barrier handling Co-authored-by: Codex --- csrc/CMakeLists.txt | 62 +------------------ csrc/flash_attn_v3/cutlass | 2 +- .../sm90_pipeline_no_cluster.hpp | 6 +- csrc/flashmask_v2/cutlass | 2 +- .../flashmask_v2/sm90_pipeline_no_cluster.hpp | 6 +- 5 files changed, 11 insertions(+), 67 deletions(-) diff --git a/csrc/CMakeLists.txt b/csrc/CMakeLists.txt index 7a8c2c2fc40..13635729d46 100644 --- a/csrc/CMakeLists.txt +++ b/csrc/CMakeLists.txt @@ -257,19 +257,7 @@ if(NOT SKIP_BUILD_FA) list(APPEND HEAD_DIMENSIONS_BWD 256) endif() - # CUDA 13.2 ptxas can ICE when all SM90 forward head dimensions are batched - # into the generated hdimall/hdimdiff translation units. Compile them as - # individual instantiation files on CUDA 13+ to keep each PTX module smaller. - set(FA3_SPLIT_SM90_FWD_BY_HDIM OFF) - if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL "13.0") - set(FA3_SPLIT_SM90_FWD_BY_HDIM ON) - endif() - - if(FA3_SPLIT_SM90_FWD_BY_HDIM) - set(HEAD_DIMENSIONS_FWD ${HEAD_DIMENSIONS_BWD}) - else() - set(HEAD_DIMENSIONS_FWD "all" "diff") - endif() + set(HEAD_DIMENSIONS_FWD "all" "diff") set(HEAD_DIMENSIONS_FWD_SM80 ${HEAD_DIMENSIONS_BWD}) set(SPLIT "__EMPTY__") @@ -315,7 +303,6 @@ if(NOT SKIP_BUILD_FA) endforeach() set(sources_fwd_sm90) - set(fa3_sm90_fwd_ptxas_o0_sources) foreach(hdim ${HEAD_DIMENSIONS_FWD}) foreach(dtype ${DTYPE_FWD_SM90}) foreach(split ${SPLIT}) @@ -334,36 +321,6 @@ if(NOT SKIP_BUILD_FA) endforeach() endforeach() - if(FA3_SPLIT_SM90_FWD_BY_HDIM) - foreach(dtype ${DTYPE_FWD_SM90}) - foreach(split ${SPLIT}) - foreach(paged ${PAGEDKV}) - foreach(softcap ${SOFTCAP}) - foreach(packgqa ${PACKGQA}) - if(packgqa STREQUAL "__EMPTY__" OR (paged STREQUAL "__EMPTY__" AND split STREQUAL "__EMPTY__")) - if(NOT dtype STREQUAL "e4m3") - set(name "flash_attn_v3/instantiations/flash_fwd_hdim64_512_${dtype}${paged}${split}${softcap}${packgqa}_sm90.cu") - string(REPLACE "__EMPTY__" "" refine_name "${name}") - list(APPEND sources_fwd_sm90 "${refine_name}") - endif() - set(name "flash_attn_v3/instantiations/flash_fwd_hdim192_128_${dtype}${paged}${split}${softcap}${packgqa}_sm90.cu") - string(REPLACE "__EMPTY__" "" refine_name "${name}") - list(APPEND sources_fwd_sm90 "${refine_name}") - endif() - endforeach() - endforeach() - endforeach() - endforeach() - endforeach() - endif() - - if(FA3_SPLIT_SM90_FWD_BY_HDIM) - # CUDA 13.2 ptxas can ICE on multiple SM90 forward instantiations even - # after hdimall/hdimdiff are split, so keep the workaround scoped to - # SM90 forward sources instead of weakening backward or SM80 builds. - set(fa3_sm90_fwd_ptxas_o0_sources ${sources_fwd_sm90}) - endif() - set(sources_bwd_sm80) foreach(hdim ${HEAD_DIMENSIONS_BWD}) foreach(dtype ${DTYPE_BWD}) @@ -433,11 +390,6 @@ if(NOT SKIP_BUILD_FA) --expt-relaxed-constexpr >) - if(FA3_SPLIT_SM90_FWD_BY_HDIM AND fa3_sm90_fwd_ptxas_o0_sources) - set_property(SOURCE ${fa3_sm90_fwd_ptxas_o0_sources} APPEND PROPERTY - COMPILE_OPTIONS -Xptxas -O0) - endif() - INSTALL(TARGETS flashattnv3 LIBRARY DESTINATION "lib") INSTALL(FILES flash_attn_v3/flash_api.h DESTINATION "include" RENAME flashv3_api.h) @@ -624,13 +576,6 @@ if(NOT SKIP_BUILD_FA) endforeach() endforeach() - set(flashmaskv2_sm90_fwd_ptxas_o0_sources) - if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL "13.0") - # CUDA 13.2 ptxas can ICE on FlashMaskV2 SM90 forward instantiations - # even though these files are already split by head dimension. - set(flashmaskv2_sm90_fwd_ptxas_o0_sources ${flashmaskv2_sources_fwd_sm90}) - endif() - set(flashmaskv2_sources_bwd_sm80) foreach(hdim ${FLASHMASKV2_HEAD_DIMENSIONS_BWD}) foreach(dtype ${FLASHMASKV2_DTYPE_BWD}) @@ -740,11 +685,6 @@ if(NOT SKIP_BUILD_FA) --expt-relaxed-constexpr >) - if(flashmaskv2_sm90_fwd_ptxas_o0_sources) - set_property(SOURCE ${flashmaskv2_sm90_fwd_ptxas_o0_sources} APPEND PROPERTY - COMPILE_OPTIONS -Xptxas -O0) - endif() - if(WITH_DISTRIBUTED_OVERLAP) target_include_directories(flashmaskv2 PRIVATE ${NVSHMEM_INCLUDE_DIR}) target_link_libraries(flashmaskv2 PRIVATE flashmaskv2_distributed) diff --git a/csrc/flash_attn_v3/cutlass b/csrc/flash_attn_v3/cutlass index afa17722036..4faf1a1568c 160000 --- a/csrc/flash_attn_v3/cutlass +++ b/csrc/flash_attn_v3/cutlass @@ -1 +1 @@ -Subproject commit afa1772203677c5118fcd82537a9c8fefbcc7008 +Subproject commit 4faf1a1568cf1e4ad8ff71846a13e16f2a6a6f6b diff --git a/csrc/flash_attn_v3/sm90_pipeline_no_cluster.hpp b/csrc/flash_attn_v3/sm90_pipeline_no_cluster.hpp index 65a3d1554b3..c24f6150baa 100644 --- a/csrc/flash_attn_v3/sm90_pipeline_no_cluster.hpp +++ b/csrc/flash_attn_v3/sm90_pipeline_no_cluster.hpp @@ -17,7 +17,8 @@ using namespace cute; // forward pass (especially hdim 128 causal). We instead reimplement the version of // PipelineTmaAsync before v3.6.0 where only 1 out of 128 threads signals the barrier. // -// Assumption: params.num_consumers % NumThreadsPerWarpGroup == 0 +// Count consumers in whole warpgroups. A single consumer warp still needs one +// mbarrier arrival count. template > class PipelineTmaAsyncNoCluster: public Base { public: @@ -39,7 +40,8 @@ class PipelineTmaAsyncNoCluster: public Base { if (is_initializing_warp) { // Barrier FULL and EMPTY init constexpr int producer_arv_cnt = 1; - uint32_t const num_consumer_warpgroups_per_cluster = params.num_consumers / NumThreadsPerWarpGroup; + uint32_t const num_consumer_warpgroups_per_cluster = + (params.num_consumers + NumThreadsPerWarpGroup - 1) / NumThreadsPerWarpGroup; uint32_t const multicast_consumer_arrival_count = num_consumer_warpgroups_per_cluster; cutlass::arch::detail::initialize_barrier_array_pair_aligned( diff --git a/csrc/flashmask_v2/cutlass b/csrc/flashmask_v2/cutlass index afa17722036..4faf1a1568c 160000 --- a/csrc/flashmask_v2/cutlass +++ b/csrc/flashmask_v2/cutlass @@ -1 +1 @@ -Subproject commit afa1772203677c5118fcd82537a9c8fefbcc7008 +Subproject commit 4faf1a1568cf1e4ad8ff71846a13e16f2a6a6f6b diff --git a/csrc/flashmask_v2/sm90_pipeline_no_cluster.hpp b/csrc/flashmask_v2/sm90_pipeline_no_cluster.hpp index 65a3d1554b3..c24f6150baa 100644 --- a/csrc/flashmask_v2/sm90_pipeline_no_cluster.hpp +++ b/csrc/flashmask_v2/sm90_pipeline_no_cluster.hpp @@ -17,7 +17,8 @@ using namespace cute; // forward pass (especially hdim 128 causal). We instead reimplement the version of // PipelineTmaAsync before v3.6.0 where only 1 out of 128 threads signals the barrier. // -// Assumption: params.num_consumers % NumThreadsPerWarpGroup == 0 +// Count consumers in whole warpgroups. A single consumer warp still needs one +// mbarrier arrival count. template > class PipelineTmaAsyncNoCluster: public Base { public: @@ -39,7 +40,8 @@ class PipelineTmaAsyncNoCluster: public Base { if (is_initializing_warp) { // Barrier FULL and EMPTY init constexpr int producer_arv_cnt = 1; - uint32_t const num_consumer_warpgroups_per_cluster = params.num_consumers / NumThreadsPerWarpGroup; + uint32_t const num_consumer_warpgroups_per_cluster = + (params.num_consumers + NumThreadsPerWarpGroup - 1) / NumThreadsPerWarpGroup; uint32_t const multicast_consumer_arrival_count = num_consumer_warpgroups_per_cluster; cutlass::arch::detail::initialize_barrier_array_pair_aligned( From fc8209e65f647ee3e09f12f8090b513b93767522 Mon Sep 17 00:00:00 2001 From: gouzi <530971494@qq.com> Date: Wed, 13 May 2026 10:25:07 +0800 Subject: [PATCH 3/3] Enhance the EpilogueTileMN type to support memory storage operations for CUDA 13.2 --- csrc/flash_attn_v3/epilogue_fwd.hpp | 3 +- .../src/fmha/smem_tile.h | 108 +++++++----------- csrc/flashmask_v2/epilogue_fwd.hpp | 3 +- 3 files changed, 46 insertions(+), 68 deletions(-) diff --git a/csrc/flash_attn_v3/epilogue_fwd.hpp b/csrc/flash_attn_v3/epilogue_fwd.hpp index 69102e8c4e6..50a27463f0d 100644 --- a/csrc/flash_attn_v3/epilogue_fwd.hpp +++ b/csrc/flash_attn_v3/epilogue_fwd.hpp @@ -88,11 +88,12 @@ struct CollectiveEpilogueFwd { // ((qhead_per_khead, seqlen_q), nheads_kv, batch, num_splits) using ShapeLSEPacked = std::conditional_t, cute::Shape, int32_t, int32_t, int32_t>>; using StrideLSEPacked = std::conditional_t, int64_t, int64_t, int64_t>>; + using EpilogueTileMN = decltype(select<0, 1>(TileShape_MNK_PV{})); using CopyOpR2S = std::conditional_t< ArchTag::kMinComputeCapability >= 90, // cute::SM90_U32x4_STSM_N if Element size is 2 bytes (fp16, bf16) - decltype(cutlass::epilogue::collective::detail::sm90_get_smem_store_op_for_accumulator()), + decltype(cutlass::epilogue::collective::detail::sm90_get_smem_store_op_for_accumulator()), AutoVectorizingCopyWithAssumedAlignment<128> >; using SmemCopyAtomO = Copy_Atom; diff --git a/csrc/flash_attn_with_bias_and_mask/src/fmha/smem_tile.h b/csrc/flash_attn_with_bias_and_mask/src/fmha/smem_tile.h index 6576a0e9a81..85900e5a2a0 100644 --- a/csrc/flash_attn_with_bias_and_mask/src/fmha/smem_tile.h +++ b/csrc/flash_attn_with_bias_and_mask/src/fmha/smem_tile.h @@ -28,30 +28,11 @@ #pragma once #include "utils.h" -#include #include #include namespace fmha { -#if defined(__CUDACC_VER_MAJOR__) && __CUDACC_VER_MAJOR__ >= 13 -// CUDA 13.x can ICE in the legacy FA1 dense-mask path when these loops are -// fully unrolled around shared-memory address arithmetic. -#define FMHA_COMPAT_UNROLL _Pragma("unroll 1") -#define FMHA_COMPAT_HELPER __device__ __noinline__ -using Smem_offset_type = uint32_t; -#else -#define FMHA_COMPAT_UNROLL _Pragma("unroll") -#define FMHA_COMPAT_HELPER inline __device__ -using Smem_offset_type = uint32_t; -#endif - -template -inline __device__ uint32_t ToSmemPtr(T ptr) { - return static_cast(ptr); -} - - //////////////////////////////////////////////////////////////////////////////////////////////////// template< @@ -1278,9 +1259,9 @@ struct Smem_tile_mma { template inline __device__ void store(const uint4 (®s)[M][N]) { static_assert(COLS == Cta_tile::N); - FMHA_COMPAT_UNROLL + #pragma unroll for( int mi = 0; mi < M; mi++ ) { - FMHA_COMPAT_UNROLL + #pragma unroll for( int ni = 0; ni < N; ni++ ) { // size_t offset = write_offset_ + mi * WARPS_M * 16 * BYTES_PER_ROW + ni * WARPS_N * 16 * BYTES_PER_ELT; // fmha::sts(smem_ + offset + 0 * BYTES_PER_ROW, regs[mi][ni].x); @@ -1289,14 +1270,14 @@ struct Smem_tile_mma { // fmha::sts(smem_ + offset + 0 * BYTES_PER_ROW, regs[mi][ni].y); // fmha::sts(smem_ + offset + 8 * BYTES_PER_ROW, regs[mi][ni].w); // size_t offset = smem_write_ + mi * WARPS_M * 16 * BYTES_PER_ROW + ni * WARPS_N * 16 * BYTES_PER_ELT; - Smem_offset_type offset = + uint32_t offset = smem_write_ + mi * WARPS_M * 16 * BYTES_PER_ROW + ni * WARPS_N * 16 * BYTES_PER_ELT; - fmha::sts(ToSmemPtr(offset + 0 * BYTES_PER_ROW), regs[mi][ni].x); - fmha::sts(ToSmemPtr(offset + 8 * BYTES_PER_ROW), regs[mi][ni].z); + fmha::sts(offset + 0 * BYTES_PER_ROW, regs[mi][ni].x); + fmha::sts(offset + 8 * BYTES_PER_ROW, regs[mi][ni].z); offset ^= 4 * BYTES_PER_STS; - fmha::sts(ToSmemPtr(offset + 0 * BYTES_PER_ROW), regs[mi][ni].y); - fmha::sts(ToSmemPtr(offset + 8 * BYTES_PER_ROW), regs[mi][ni].w); + fmha::sts(offset + 0 * BYTES_PER_ROW, regs[mi][ni].y); + fmha::sts(offset + 8 * BYTES_PER_ROW, regs[mi][ni].w); } } } @@ -1305,9 +1286,9 @@ struct Smem_tile_mma { inline __device__ void store(const Fragment (&frag)[N][M]) { static_assert(COLS == Cta_tile::N); uint4 regs[M][N]; - FMHA_COMPAT_UNROLL + #pragma unroll for( int mi = 0; mi < M; mi++ ) { - FMHA_COMPAT_UNROLL + #pragma unroll for( int ni = 0; ni < N; ni++ ) { // Need to transpose ref(1) and reg(2) here since when we load it we transpose again. regs[mi][ni] = make_uint4(frag[ni][mi].reg(0), frag[ni][mi].reg(2), @@ -1354,10 +1335,10 @@ struct Smem_tile_mma_transposed : public Base { uint4 dst; // fmha::ldsmt(dst, this->smem_ + offset); // size_t offset = smem_read_ + mi * WARPS_M * 16 * BYTES_PER_ROW + ni * WARPS_N * 16 * BYTES_PER_ELT; - Smem_offset_type offset = + uint32_t offset = smem_read_ + mi * WARPS_M * 16 * BYTES_PER_ROW + ni * WARPS_N * 16 * BYTES_PER_ELT; - fmha::ldsmt(dst, ToSmemPtr(offset)); + fmha::ldsmt(dst, offset); frag[mi][ni].reg(0) = dst.x; frag[mi][ni].reg(1) = dst.z; // Fragment A regs col major! frag[mi][ni].reg(2) = dst.y; @@ -1402,16 +1383,16 @@ struct Smem_tile_mma_epilogue : public Base { // size_t offset = read_offset_ + ii * ROWS_PER_LDS * BYTES_PER_ROW; // fmha::lds(data[ii], this->smem_ + offset); // size_t offset = smem_read_ + ii * ROWS_PER_LDS * BYTES_PER_ROW; - Smem_offset_type offset = smem_read_ + ii * ROWS_PER_LDS * BYTES_PER_ROW; - fmha::lds(data[ii], ToSmemPtr(offset)); + uint32_t offset = smem_read_ + ii * ROWS_PER_LDS * BYTES_PER_ROW; + fmha::lds(data[ii], offset); } } template inline __device__ void store(const Acc (&acc)[M][N]){ - FMHA_COMPAT_UNROLL + #pragma unroll for( int mi = 0; mi < M; mi++ ) { - FMHA_COMPAT_UNROLL + #pragma unroll for( int ni = 0; ni < N; ni++ ) { // 1st row - 4 elements per row. float tmp00 = acc[mi][ni].elt(0); @@ -1436,16 +1417,16 @@ struct Smem_tile_mma_epilogue : public Base { // fmha::sts(this->smem_ + offset + 0 * BYTES_PER_ROW, y); // fmha::sts(this->smem_ + offset + 8 * BYTES_PER_ROW, w); // size_t offset = (this->smem_write_ ^ (ni * 32)) + mi * WARPS_M * 16 * BYTES_PER_ROW; - Smem_offset_type offset = + uint32_t offset = (this->smem_write_ ^ (ni * 32)) + mi * WARPS_M * 16 * BYTES_PER_ROW; // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) { // printf("mi = %d, ni = %d, offset - smem_write_ = %d\n", mi, ni, offset - this->smem_write_); // } - fmha::sts(ToSmemPtr(offset + 0 * BYTES_PER_ROW), x); - fmha::sts(ToSmemPtr(offset + 8 * BYTES_PER_ROW), z); + fmha::sts(offset + 0 * BYTES_PER_ROW, x); + fmha::sts(offset + 8 * BYTES_PER_ROW, z); offset ^= 4 * Base::BYTES_PER_STS; - fmha::sts(ToSmemPtr(offset + 0 * BYTES_PER_ROW), y); - fmha::sts(ToSmemPtr(offset + 8 * BYTES_PER_ROW), w); + fmha::sts(offset + 0 * BYTES_PER_ROW, y); + fmha::sts(offset + 8 * BYTES_PER_ROW, w); } } } @@ -1455,13 +1436,13 @@ struct Smem_tile_mma_epilogue : public Base { for( int mi = 0; mi < M; mi++ ) { for( int ni = 0; ni < N; ni++ ) { // size_t offset = (this->write_offset_ ^ (ni * 32)) + mi * WARPS_M * 16 * BYTES_PER_ROW; - Smem_offset_type offset = + uint32_t offset = (this->write_offset_ ^ (ni * 32)) + mi * WARPS_M * 16 * BYTES_PER_ROW; - fmha::sts(ToSmemPtr(this->smem_ + offset + 0 * BYTES_PER_ROW), regs[mi][ni].x); - fmha::sts(ToSmemPtr(this->smem_ + offset + 8 * BYTES_PER_ROW), regs[mi][ni].z); + fmha::sts(this->smem_ + offset + 0 * BYTES_PER_ROW, regs[mi][ni].x); + fmha::sts(this->smem_ + offset + 8 * BYTES_PER_ROW, regs[mi][ni].z); offset ^= 4 * Base::BYTES_PER_STS; - fmha::sts(ToSmemPtr(this->smem_ + offset + 0 * BYTES_PER_ROW), regs[mi][ni].y); - fmha::sts(ToSmemPtr(this->smem_ + offset + 8 * BYTES_PER_ROW), regs[mi][ni].w); + fmha::sts(this->smem_ + offset + 0 * BYTES_PER_ROW, regs[mi][ni].y); + fmha::sts(this->smem_ + offset + 8 * BYTES_PER_ROW, regs[mi][ni].w); } } } @@ -1519,35 +1500,30 @@ struct Smem_tile_transpose { read_offset_ = read_row * BYTES_PER_ROW + read_col * BYTES_PER_LDS; } - FMHA_COMPAT_HELPER - void store_fragment(Smem_offset_type base_offset, const Fragment_write &frag) { - const uint32_t ptr0 = ToSmemPtr(smem_ + base_offset + 0 * BYTES_PER_ROW); - const uint32_t ptr1 = ToSmemPtr(smem_ + base_offset + 8 * BYTES_PER_ROW); - const Smem_offset_type swizzled_offset = base_offset ^ (4 * BYTES_PER_STS); - const uint32_t ptr2 = ToSmemPtr(smem_ + swizzled_offset + 0 * BYTES_PER_ROW); - const uint32_t ptr3 = ToSmemPtr(smem_ + swizzled_offset + 8 * BYTES_PER_ROW); + inline __device__ void store_fragment(uint32_t base_offset, const Fragment_write &frag) { + uint32_t offset = smem_ + base_offset; const uint32_t reg0 = frag.reg(0); const uint32_t reg1 = frag.reg(1); const uint32_t reg2 = frag.reg(2); const uint32_t reg3 = frag.reg(3); - fmha::sts(ptr0, reg0); - fmha::sts(ptr1, reg2); - fmha::sts(ptr2, reg1); - fmha::sts(ptr3, reg3); + fmha::sts(offset + 0 * BYTES_PER_ROW, reg0); + fmha::sts(offset + 8 * BYTES_PER_ROW, reg2); + offset ^= 4 * BYTES_PER_STS; + fmha::sts(offset + 0 * BYTES_PER_ROW, reg1); + fmha::sts(offset + 8 * BYTES_PER_ROW, reg3); } - FMHA_COMPAT_HELPER - uint4 load_fragment(Smem_offset_type offset) { + inline __device__ uint4 load_fragment(uint32_t offset) { uint4 dst; - fmha::ldsmt(dst, ToSmemPtr(smem_ + offset)); + fmha::ldsmt(dst, smem_ + offset); return dst; } template inline __device__ void store(const Fragment_write (&frag_w)[M][N], int mi) { - FMHA_COMPAT_UNROLL + #pragma unroll for( int ni = 0; ni < N; ni++ ) { - const Smem_offset_type base = + const uint32_t base = write_offset_ + ni * WARPS_N * 16 * BYTES_PER_ELT; store_fragment(base, frag_w[ni][mi]); } @@ -1555,9 +1531,9 @@ struct Smem_tile_transpose { template inline __device__ void load(Fragment_read (&frag_r)[N]) { - FMHA_COMPAT_UNROLL + #pragma unroll for( int ni = 0; ni < N; ni++ ) { - Smem_offset_type offset = read_offset_ + ni * WARPS_N * 16 * BYTES_PER_ELT; + uint32_t offset = read_offset_ + ni * WARPS_N * 16 * BYTES_PER_ELT; const uint4 dst = load_fragment(offset); frag_r[ni].reg(0) = dst.x; frag_r[ni].reg(1) = dst.y; // Fragment B regs col major! @@ -1569,15 +1545,15 @@ struct Smem_tile_transpose { template inline __device__ void transpose(const Fragment_write (&frag_w)[M][N], Fragment_read (&frag_r)[M], int mi) { static_assert(COLS == Cta_tile::N); - FMHA_COMPAT_UNROLL + #pragma unroll for( int ni = 0; ni < N; ni++ ) { - const Smem_offset_type base = + const uint32_t base = write_offset_ + ni * WARPS_N * 16 * BYTES_PER_ELT; store_fragment(base, frag_w[ni][mi]); } - FMHA_COMPAT_UNROLL + #pragma unroll for( int ni = 0; ni < N; ni++ ) { - Smem_offset_type offset = read_offset_ + ni * WARPS_N * 16 * BYTES_PER_ELT; + uint32_t offset = read_offset_ + ni * WARPS_N * 16 * BYTES_PER_ELT; const uint4 dst = load_fragment(offset); frag_r[ni].reg(0) = dst.x; frag_r[ni].reg(1) = dst.y; // Fragment B regs col major! diff --git a/csrc/flashmask_v2/epilogue_fwd.hpp b/csrc/flashmask_v2/epilogue_fwd.hpp index 69102e8c4e6..50a27463f0d 100644 --- a/csrc/flashmask_v2/epilogue_fwd.hpp +++ b/csrc/flashmask_v2/epilogue_fwd.hpp @@ -88,11 +88,12 @@ struct CollectiveEpilogueFwd { // ((qhead_per_khead, seqlen_q), nheads_kv, batch, num_splits) using ShapeLSEPacked = std::conditional_t, cute::Shape, int32_t, int32_t, int32_t>>; using StrideLSEPacked = std::conditional_t, int64_t, int64_t, int64_t>>; + using EpilogueTileMN = decltype(select<0, 1>(TileShape_MNK_PV{})); using CopyOpR2S = std::conditional_t< ArchTag::kMinComputeCapability >= 90, // cute::SM90_U32x4_STSM_N if Element size is 2 bytes (fp16, bf16) - decltype(cutlass::epilogue::collective::detail::sm90_get_smem_store_op_for_accumulator()), + decltype(cutlass::epilogue::collective::detail::sm90_get_smem_store_op_for_accumulator()), AutoVectorizingCopyWithAssumedAlignment<128> >; using SmemCopyAtomO = Copy_Atom;