
support csrc/flash_attn_with_bias_and_mask/src/fmha/smem_tile.h cuda132 build #153

Merged
GuoxiaWang merged 3 commits into PaddlePaddle:main from gouzil:test/flash_support_cuda132 on May 26, 2026

Conversation

@gouzil
Member

@gouzil gouzil commented May 25, 2026

feat

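Make csrc/flash_attn_with_bias_and_mask/src/fmha/smem_tile.h build with CUDA 13.2. Adapting FA3 to CUDA 13.2 is still fairly involved, so for now its compilation is skipped in the main Paddle repository.

The full set of changes (sufficient only to get the code compiling) is in #141.
Paddle PR adapting to CUDA 13.2: PaddlePaddle/Paddle#78720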

Copilot AI review requested due to automatic review settings May 25, 2026 03:09

Copilot AI left a comment


Pull request overview

Note

Copilot was unable to run its full agentic suite in this review.

Updates SM90 pipeline barrier initialization to correctly handle non-warpgroup-multiple consumer counts, and refactors FMHA shared-memory tile fragment loads/stores for readability and reuse.

Changes:

  • Round up params.num_consumers to whole warp-groups when computing mbarrier arrival counts (flashmask v2 + flash-attn v3).
  • Refactor Smem_tile_transpose fragment store/load into helper methods and apply minor formatting cleanups.
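The rounding described in the first bullet can be sketched as follows. This is a minimal illustration, not the code from this PR: the helper names and the constant `kThreadsPerWarpGroup` are hypothetical, and the value 128 is an assumption (four 32-thread warps per warp group). The point is only that truncating division yields zero arrivals for a lone consumer warp, whereas ceiling division counts it as one warp group.

```cpp
#include <cstdint>

// Assumption for this sketch: a warp group is 4 warps x 32 threads.
constexpr uint32_t kThreadsPerWarpGroup = 128;

// Hypothetical helper: expected barrier arrivals when one thread per
// warp group signals. Rounds up so a partial warp group still counts once.
constexpr uint32_t consumer_arrival_count(uint32_t num_consumers) {
    return (num_consumers + kThreadsPerWarpGroup - 1) / kThreadsPerWarpGroup;
}

// Truncating division, shown only for comparison with the rounded form.
constexpr uint32_t truncating_arrival_count(uint32_t num_consumers) {
    return num_consumers / kThreadsPerWarpGroup;
}

// A single 32-thread consumer warp: truncation gives 0, rounding up gives 1.
static_assert(truncating_arrival_count(32) == 0, "truncation drops the warp");
static_assert(consumer_arrival_count(32) == 1, "rounding up keeps the warp");

// Whole multiples of the warp-group size are unaffected by the change.
static_assert(consumer_arrival_count(256) == truncating_arrival_count(256),
              "exact multiples agree");
```

When the consumer count is already a multiple of the warp-group size, both forms agree, so a change of this kind only affects configurations with a partial warp group.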

Reviewed changes

Copilot reviewed 3 out of 3 changed files in this pull request and generated 3 comments.

| File | Description |
| --- | --- |
| csrc/flashmask_v2/sm90_pipeline_no_cluster.hpp | Adjusts consumer arrival count computation to use ceil(#consumers / warpgroup_threads). |
| csrc/flash_attn_with_bias_and_mask/src/fmha/smem_tile.h | Extracts fragment load/store helpers; restructures offset computations and removes dead comments. |
| csrc/flash_attn_v3/sm90_pipeline_no_cluster.hpp | Same barrier consumer arrival count rounding fix as flashmask v2. |


Comment on lines +1504 to +1513
uint32_t offset = smem_ + base_offset;
const uint32_t reg0 = frag.reg(0);
const uint32_t reg1 = frag.reg(1);
const uint32_t reg2 = frag.reg(2);
const uint32_t reg3 = frag.reg(3);
fmha::sts(offset + 0 * BYTES_PER_ROW, reg0);
fmha::sts(offset + 8 * BYTES_PER_ROW, reg2);
offset ^= 4 * BYTES_PER_STS;
fmha::sts(offset + 0 * BYTES_PER_ROW, reg1);
fmha::sts(offset + 8 * BYTES_PER_ROW, reg3);
// PipelineTmaAsync before v3.6.0 where only 1 out of 128 threads signals the barrier.
//
// Assumption: params.num_consumers % NumThreadsPerWarpGroup == 0
// Count consumers in whole warpgroups. A single consumer warp still needs one
@gouzil gouzil changed the title [WIP][test] support cuda132 build support csrc/flash_attn_with_bias_and_mask/src/fmha/smem_tile.h cuda132 build May 25, 2026
const uint32_t reg3 = frag.reg(3);
fmha::sts(offset + 0 * BYTES_PER_ROW, reg0);
fmha::sts(offset + 8 * BYTES_PER_ROW, reg2);
offset ^= 4 * BYTES_PER_STS;
Member


Isn't this offset transformation non-equivalent to the code before the change?

Member Author


Done
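As background to the equivalence question above, the sketch below shows why an XOR-based offset update such as `offset ^= 4 * BYTES_PER_STS` is not interchangeable with an additive update in general. It is not a reconstruction of the code before or after this PR, which is not shown on this page. The helper names are hypothetical and `kBytesPerSts = 16` is an assumed value chosen only for illustration.

```cpp
#include <cstdint>

// Assumed value for this sketch; the real BYTES_PER_STS is not shown here.
constexpr uint32_t kBytesPerSts = 16;

// Hypothetical helper: flip the single bit selected by 4 * kBytesPerSts.
constexpr uint32_t toggle_xor(uint32_t offset) {
    return offset ^ (4 * kBytesPerSts);
}

// Hypothetical helper: move forward by the same amount with addition.
constexpr uint32_t advance_add(uint32_t offset) {
    return offset + 4 * kBytesPerSts;
}

// When the toggled bit is clear, XOR and addition give the same result.
static_assert(toggle_xor(0) == 64 && advance_add(0) == 64, "bit clear");

// When the bit is already set, XOR steps back while addition steps forward.
static_assert(toggle_xor(64) == 0 && advance_add(64) == 128, "bit set");

// XOR with a constant is its own inverse: applying it twice is a no-op.
static_assert(toggle_xor(toggle_xor(200)) == 200, "self-inverse");
```

Because the two forms diverge whenever the toggled bit is already set in the base offset, a refactor that reorders or regroups such updates has to be checked against the actual base offsets in use.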

@umiswing
Member

LGTM

@GuoxiaWang GuoxiaWang merged commit 1f3e4bb into PaddlePaddle:main May 26, 2026
1 check passed