Skip to content

Fix CUDA 13.2 flash attention build compatibility#141

Open
gouzil wants to merge 4 commits into
PaddlePaddle:mainfrom
gouzil:codex/cuda-13-2-flashattn-compat
Open

Fix CUDA 13.2 flash attention build compatibility#141
gouzil wants to merge 4 commits into
PaddlePaddle:mainfrom
gouzil:codex/cuda-13-2-flashattn-compat

Conversation

@gouzil
Copy link
Copy Markdown
Member

@gouzil gouzil commented Apr 28, 2026

fix

修复 CUDA 13.2 下的编译错误

  • smem_tile.h 地址计算重写, 替换为等价逻辑
  • 对 SM90 forward 实例化文件关闭 ptxas 优化,避免 ptxas fatal: C7907
  • params.num_consumers(例如,32)小于 NumThreadsPerWarpGroup(128)时,整数除法会导致 num_consumer_warpgroups_per_cluster 为 0,从而在初始化期间导致编译器报错。改为向上取整除法,以确保最小值为 1。

cc: @swgu98

最小复现案例

// Standalone CUDA 13.2 front-end alloc_fe ICE repro for the smem_tile.h fix.
// Default compile fails; add -DREPRO_APPLY_SMEM_TILE_FIX=1 to validate the
// source workaround.

#include <cuda_runtime.h>
#include <cstddef>
#include <stdint.h>

extern "C" __device__ uint32_t __nvvm_get_smem_pointer(void* ptr);

#ifndef REPRO_APPLY_SMEM_TILE_FIX
#define REPRO_APPLY_SMEM_TILE_FIX 1
#endif

#if REPRO_APPLY_SMEM_TILE_FIX
#define REPRO_COMPAT_UNROLL _Pragma("unroll")
#define REPRO_COMPAT_HELPER __device__ inline
using SmemOffsetType = uint32_t;
#else
#define REPRO_COMPAT_UNROLL _Pragma("unroll")
using SmemOffsetType = uint32_t;
#endif

template <typename T>
__device__ inline uint32_t ToSmemPtr(T ptr) {
  return static_cast<uint32_t>(ptr);
}

struct Col {};

template <typename Layout>
struct FragmentB {
  __device__ inline uint32_t& reg(int ii) { return regs_[ii]; }
  __device__ inline const uint32_t& reg(int ii) const { return regs_[ii]; }
  uint32_t regs_[4];
};

__device__ inline void sts(uint32_t ptr, uint32_t val) {
  asm volatile("st.shared.b32 [%0], %1;\n" : : "r"(ptr), "r"(val));
}

template <int COLS_, int WARPS_N_>
struct SmemTileTranspose {
  using FragmentWrite = FragmentB<Col>;

  enum { COLS = COLS_ };
  enum { BYTES_PER_ELT = 2 };
  enum { BYTES_PER_STS = 4 };
  enum { BYTES_PER_ROW = COLS * BYTES_PER_ELT };
  enum { WARPS_N = WARPS_N_ };

  __device__ inline SmemTileTranspose(char* smem, int tidx) {
    smem_ = __nvvm_get_smem_pointer(smem);

    int write_row = (tidx & 0x1c) / 4;
    int write_col = (tidx & 0xe0) / 4 + (tidx & 0x03);
    write_col ^= (write_row & 0x07) * 4;
#if REPRO_APPLY_SMEM_TILE_FIX
    smem_write_ = smem_ + write_row * BYTES_PER_ROW + write_col * BYTES_PER_STS;
#else
    write_offset_ = write_row * BYTES_PER_ROW + write_col * BYTES_PER_STS;
#endif

  }

#if REPRO_APPLY_SMEM_TILE_FIX
  REPRO_COMPAT_HELPER void store_fragment(SmemOffsetType base,
                                          const FragmentWrite& frag) {
    const SmemOffsetType swizzled = base ^ (4 * BYTES_PER_STS);
    sts(ToSmemPtr(base + 0 * BYTES_PER_ROW), frag.reg(0));
    sts(ToSmemPtr(base + 8 * BYTES_PER_ROW), frag.reg(2));
    sts(ToSmemPtr(swizzled + 0 * BYTES_PER_ROW), frag.reg(1));
    sts(ToSmemPtr(swizzled + 8 * BYTES_PER_ROW), frag.reg(3));
  }
#endif

  template <int M, int N>
  __device__ inline void store(const FragmentWrite (&frag_w)[M][N]) {
    REPRO_COMPAT_UNROLL
    for (int ni = 0; ni < N; ni++) {
#if REPRO_APPLY_SMEM_TILE_FIX
      const SmemOffsetType base =
          smem_write_ + ni * WARPS_N * 16 * BYTES_PER_ELT;
      store_fragment(base, frag_w[ni][0]);
#else
      uint32_t offset = write_offset_ + ni * WARPS_N * 16 * BYTES_PER_ELT;
      sts(smem_ + offset + 0 * BYTES_PER_ROW, frag_w[ni][0].reg(0));
      sts(smem_ + offset + 8 * BYTES_PER_ROW, frag_w[ni][0].reg(2));
      offset ^= 4 * BYTES_PER_STS;
      sts(smem_ + offset + 0 * BYTES_PER_ROW, frag_w[ni][0].reg(1));
      sts(smem_ + offset + 8 * BYTES_PER_ROW, frag_w[ni][0].reg(3));
#endif
    }
  }

  uint32_t smem_;
#if REPRO_APPLY_SMEM_TILE_FIX
  uint32_t smem_write_;
#else
  uint32_t write_offset_;
#endif
};

template <int COLS, int WARPS_N>
__device__ inline void device_1xN_with_mask_bias_fake() {
  extern __shared__ char smem[];
  const int tidx = threadIdx.x;
  const uint32_t seed = static_cast<uint32_t>(tidx);

  using SmemTile = SmemTileTranspose<COLS, WARPS_N>;
  using FragmentWrite = typename SmemTile::FragmentWrite;

  SmemTile tile(smem, tidx);

  FragmentWrite frag_w[1][1];

  frag_w[0][0].reg(0) = seed;
  frag_w[0][0].reg(1) = seed + 1;
  frag_w[0][0].reg(2) = seed + 2;
  frag_w[0][0].reg(3) = seed + 3;

  tile.store(frag_w);
}

__global__ void flashattn_hdim32_seq128_dropout0() {
  device_1xN_with_mask_bias_fake<128, 4>();
}
// Minimal standalone CUDA 13.2 ptxas C7907 reproducer.
// No Paddle/FlashAttention/CUTLASS headers required.
// Compile:
//   nvcc -O3 -std=c++17 -gencode arch=compute_90a,code=sm_90a \
//        -x cu -c cuda13_mbarrier_c7907_probe.cu
// Workaround validation:
//   add `-Xptxas -O0`

#include <cuda/barrier>

extern "C" __global__ void cuda13_mbarrier_probe() {
  using Barrier = cuda::barrier<cuda::thread_scope_block>;
  __shared__ alignas(Barrier) unsigned char storage[sizeof(Barrier)];
  auto *barrier = reinterpret_cast<Barrier *>(storage);
  if (threadIdx.x == 0) {
    init(barrier, 0);
  }
}

参考链接:

Co-authored-by: Codex <codex@openai.com>
Copilot AI review requested due to automatic review settings April 28, 2026 01:46
Copy link
Copy Markdown

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Fixes CUDA 13.2 toolchain build failures in the CUDA/CUTLASS FMHA codepaths by adjusting shared-memory address arithmetic patterns and by applying a targeted ptxas optimization workaround for SM90 forward instantiation translation units.

Changes:

  • Add CUDA 13.x compatibility helpers in smem_tile.h (limit unrolling + factor shared-memory address arithmetic into helper functions).
  • Update SM90 forward kernel instantiation selection to avoid hdimall/hdimdiff batching on CUDA 13+.
  • Apply -Xptxas -O0 to SM90 forward instantiation sources for FA3 and FlashMaskV2 on CUDA 13+.

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 4 comments.

File Description
csrc/flash_attn_with_bias_and_mask/src/fmha/smem_tile.h Adds CUDA 13.x ICE workaround via unroll control + helper functions for smem address arithmetic.
csrc/CMakeLists.txt Splits SM90 forward instantiations by head-dim on CUDA 13+ and applies ptxas -O0 to SM90 forward sources to avoid ptxas ICEs.

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment thread csrc/CMakeLists.txt Outdated
Comment on lines +260 to +266
# CUDA 13.2 ptxas can ICE when all SM90 forward head dimensions are batched
# into the generated hdimall/hdimdiff translation units. Compile them as
# individual instantiation files on CUDA 13+ to keep each PTX module smaller.
set(FA3_SPLIT_SM90_FWD_BY_HDIM OFF)
if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL "13.0")
set(FA3_SPLIT_SM90_FWD_BY_HDIM ON)
endif()
Copy link

Copilot AI Apr 28, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The comment says the workaround is for “CUDA 13.2 ptxas”, but the version gate enables it for all CUDA >= 13.0. If the issue is specific to 13.2, consider tightening the check to 13.2+ (or adjust the comment to reflect that all 13.x are affected) so we don’t unnecessarily change build characteristics for unaffected toolchains.

Copilot uses AI. Check for mistakes.
Comment thread csrc/CMakeLists.txt Outdated
Comment on lines +628 to +632
if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL "13.0")
# CUDA 13.2 ptxas can ICE on FlashMaskV2 SM90 forward instantiations
# even though these files are already split by head dimension.
set(flashmaskv2_sm90_fwd_ptxas_o0_sources ${flashmaskv2_sources_fwd_sm90})
endif()
Copy link

Copilot AI Apr 28, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This section also describes a “CUDA 13.2 ptxas” ICE, but the workaround is applied whenever CMAKE_CUDA_COMPILER_VERSION >= 13.0. If only 13.2 is problematic, consider tightening the version predicate (or updating the comment) to avoid disabling ptxas optimizations on 13.0/13.1 unnecessarily.

Copilot uses AI. Check for mistakes.
Comment on lines +37 to +47
#if defined(__CUDACC_VER_MAJOR__) && __CUDACC_VER_MAJOR__ >= 13
// CUDA 13.x can ICE in the legacy FA1 dense-mask path when these loops are
// fully unrolled around shared-memory address arithmetic.
#define FMHA_COMPAT_UNROLL _Pragma("unroll 1")
#define FMHA_COMPAT_HELPER __device__ __noinline__
using Smem_offset_type = uint32_t;
#else
#define FMHA_COMPAT_UNROLL _Pragma("unroll")
#define FMHA_COMPAT_HELPER inline __device__
using Smem_offset_type = uint32_t;
#endif
Copy link

Copilot AI Apr 28, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FMHA_COMPAT_UNROLL / FMHA_COMPAT_HELPER are introduced as preprocessor macros in a public header, which will leak into any translation unit that includes this file. This repo’s fmha headers don’t appear to use/undef similar macros, so this increases risk of name collisions and makes debugging harder. Consider replacing these with #if blocks at the call sites (so no macro identifiers escape), or at least #undef FMHA_COMPAT_UNROLL / #undef FMHA_COMPAT_HELPER at the end of the header after the last use.

Copilot uses AI. Check for mistakes.
#pragma once

#include "utils.h"
#include <cstddef>
Copy link

Copilot AI Apr 28, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

#include <cstddef> is newly added but doesn’t appear to be used anywhere in this header (only size_t mentions are in comments). If it’s not needed for the CUDA 13.x workaround, consider removing it to keep includes minimal.

Suggested change
#include <cstddef>

Copilot uses AI. Check for mistakes.
Comment thread csrc/CMakeLists.txt Outdated

if(FA3_SPLIT_SM90_FWD_BY_HDIM AND fa3_sm90_fwd_ptxas_o0_sources)
set_property(SOURCE ${fa3_sm90_fwd_ptxas_o0_sources} APPEND PROPERTY
COMPILE_OPTIONS -Xptxas -O0)
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

-Xptxas -O0 disables the PTX assembler's optimizations. Could we explore alternatives that avoid the CUDA 13.2 ptxas ICE without sacrificing kernel performance?

Comment thread csrc/CMakeLists.txt Outdated

if(flashmaskv2_sm90_fwd_ptxas_o0_sources)
set_property(SOURCE ${flashmaskv2_sm90_fwd_ptxas_o0_sources} APPEND PROPERTY
COMPILE_OPTIONS -Xptxas -O0)
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

-Xptxas -O0 disables the PTX assembler's optimizations. Could we explore alternatives that avoid the CUDA 13.2 ptxas ICE without sacrificing kernel performance?

Co-authored-by: Codex <codex@openai.com>
@umiswing
Copy link
Copy Markdown
Member

@gouzil Why is the cutlass update needed here?

@gouzil
Copy link
Copy Markdown
Member Author

gouzil commented May 11, 2026

@gouzil Why is the cutlass update needed here?

https://github.com/NVIDIA/cutlass/blob/main/include/cutlass/pipeline/sm90_pipeline.hpp#L311-L313

Dao-AILab#1860 的原因是一样的,在 cutlass v4.2.0 版本才修复, 但是我目前是直接更新到了 v4.3.5, 还在测试 paddle 那边的影响

@umiswing
Copy link
Copy Markdown
Member

@gouzil Why is the cutlass update needed here?

https://github.com/NVIDIA/cutlass/blob/main/include/cutlass/pipeline/sm90_pipeline.hpp#L311-L313

Dao-AILab#1860 的原因是一样的,在 cutlass v4.2.0 版本才修复, 但是我目前是直接更新到了 v4.3.5, 还在测试 paddle 那边的影响

@gouzil 加上Dao-AILab#1860 的修复之后,不升级cutlass还是会挂吗?

@gouzil
Copy link
Copy Markdown
Member Author

gouzil commented May 13, 2026

@gouzil 加上Dao-AILab#1860 的修复之后,不升级cutlass还是会挂吗?

会挂,编译的时候挂的

@gouzil gouzil changed the title [WIP] Fix CUDA 13.2 flash attention build compatibility Fix CUDA 13.2 flash attention build compatibility May 14, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants