support csrc/flash_attn_with_bias_and_mask/src/fmha/smem_tile.h cuda132 build#153
Merged
Merged
Conversation
There was a problem hiding this comment.
Pull request overview
Note
Copilot was unable to run its full agentic suite in this review.
Updates SM90 pipeline barrier initialization to correctly handle non-warpgroup-multiple consumer counts, and refactors FMHA shared-memory tile fragment loads/stores for readability and reuse.
Changes:
- Round up
params.num_consumersto whole warp-groups when computing mbarrier arrival counts (flashmask v2 + flash-attn v3). - Refactor
Smem_tile_transposefragment store/load into helper methods and apply minor formatting cleanups.
Reviewed changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated 3 comments.
| File | Description |
|---|---|
| csrc/flashmask_v2/sm90_pipeline_no_cluster.hpp | Adjusts consumer arrival count computation to use ceil(#consumers / warpgroup_threads). |
| csrc/flash_attn_with_bias_and_mask/src/fmha/smem_tile.h | Extracts fragment load/store helpers; restructures offset computations and removes dead comments. |
| csrc/flash_attn_v3/sm90_pipeline_no_cluster.hpp | Same barrier consumer arrival count rounding fix as flashmask v2. |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
Comment on lines
+1504
to
+1513
| uint32_t offset = smem_ + base_offset; | ||
| const uint32_t reg0 = frag.reg(0); | ||
| const uint32_t reg1 = frag.reg(1); | ||
| const uint32_t reg2 = frag.reg(2); | ||
| const uint32_t reg3 = frag.reg(3); | ||
| fmha::sts(offset + 0 * BYTES_PER_ROW, reg0); | ||
| fmha::sts(offset + 8 * BYTES_PER_ROW, reg2); | ||
| offset ^= 4 * BYTES_PER_STS; | ||
| fmha::sts(offset + 0 * BYTES_PER_ROW, reg1); | ||
| fmha::sts(offset + 8 * BYTES_PER_ROW, reg3); |
| // PipelineTmaAsync before v3.6.0 where only 1 out of 128 threads signals the barrier. | ||
| // | ||
| // Assumption: params.num_consumers % NumThreadsPerWarpGroup == 0 | ||
| // Count consumers in whole warpgroups. A single consumer warp still needs one |
| // PipelineTmaAsync before v3.6.0 where only 1 out of 128 threads signals the barrier. | ||
| // | ||
| // Assumption: params.num_consumers % NumThreadsPerWarpGroup == 0 | ||
| // Count consumers in whole warpgroups. A single consumer warp still needs one |
csrc/flash_attn_with_bias_and_mask/src/fmha/smem_tile.h cuda132 build
umiswing
suggested changes
May 26, 2026
| const uint32_t reg3 = frag.reg(3); | ||
| fmha::sts(offset + 0 * BYTES_PER_ROW, reg0); | ||
| fmha::sts(offset + 8 * BYTES_PER_ROW, reg2); | ||
| offset ^= 4 * BYTES_PER_STS; |
Member
|
LGTM |
GuoxiaWang
approved these changes
May 26, 2026
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
feat
csrc/flash_attn_with_bias_and_mask/src/fmha/smem_tile.h支持 cuda 13.2 编译,当前 FA3适配 cuda 13.2 还比较麻烦, 暂时先在 paddle 主仓库跳过编译。完整改动可以查看 (仅限能编译):#141
paddle 适配 cuda 13.2 pr: PaddlePaddle/Paddle#78720