diff --git a/csrc/flash_attn_v3/cutlass b/csrc/flash_attn_v3/cutlass index afa17722036..4faf1a1568c 160000 --- a/csrc/flash_attn_v3/cutlass +++ b/csrc/flash_attn_v3/cutlass @@ -1 +1 @@ -Subproject commit afa1772203677c5118fcd82537a9c8fefbcc7008 +Subproject commit 4faf1a1568cf1e4ad8ff71846a13e16f2a6a6f6b diff --git a/csrc/flash_attn_v3/epilogue_fwd.hpp b/csrc/flash_attn_v3/epilogue_fwd.hpp index 69102e8c4e6..50a27463f0d 100644 --- a/csrc/flash_attn_v3/epilogue_fwd.hpp +++ b/csrc/flash_attn_v3/epilogue_fwd.hpp @@ -88,11 +88,12 @@ struct CollectiveEpilogueFwd { // ((qhead_per_khead, seqlen_q), nheads_kv, batch, num_splits) using ShapeLSEPacked = std::conditional_t, cute::Shape, int32_t, int32_t, int32_t>>; using StrideLSEPacked = std::conditional_t, int64_t, int64_t, int64_t>>; + using EpilogueTileMN = decltype(select<0, 1>(TileShape_MNK_PV{})); using CopyOpR2S = std::conditional_t< ArchTag::kMinComputeCapability >= 90, // cute::SM90_U32x4_STSM_N if Element size is 2 bytes (fp16, bf16) - decltype(cutlass::epilogue::collective::detail::sm90_get_smem_store_op_for_accumulator()), + decltype(cutlass::epilogue::collective::detail::sm90_get_smem_store_op_for_accumulator()), AutoVectorizingCopyWithAssumedAlignment<128> >; using SmemCopyAtomO = Copy_Atom; diff --git a/csrc/flash_attn_v3/sm90_pipeline_no_cluster.hpp b/csrc/flash_attn_v3/sm90_pipeline_no_cluster.hpp index 65a3d1554b3..c24f6150baa 100644 --- a/csrc/flash_attn_v3/sm90_pipeline_no_cluster.hpp +++ b/csrc/flash_attn_v3/sm90_pipeline_no_cluster.hpp @@ -17,7 +17,8 @@ using namespace cute; // forward pass (especially hdim 128 causal). We instead reimplement the version of // PipelineTmaAsync before v3.6.0 where only 1 out of 128 threads signals the barrier. // -// Assumption: params.num_consumers % NumThreadsPerWarpGroup == 0 +// Count consumers in whole warpgroups. A single consumer warp still needs one +// mbarrier arrival count. template > class PipelineTmaAsyncNoCluster: public Base { public: @@ -39,7 +40,8 @@ class PipelineTmaAsyncNoCluster: public Base { if (is_initializing_warp) { // Barrier FULL and EMPTY init constexpr int producer_arv_cnt = 1; - uint32_t const num_consumer_warpgroups_per_cluster = params.num_consumers / NumThreadsPerWarpGroup; + uint32_t const num_consumer_warpgroups_per_cluster = + (params.num_consumers + NumThreadsPerWarpGroup - 1) / NumThreadsPerWarpGroup; uint32_t const multicast_consumer_arrival_count = num_consumer_warpgroups_per_cluster; cutlass::arch::detail::initialize_barrier_array_pair_aligned( diff --git a/csrc/flash_attn_with_bias_and_mask/src/fmha/smem_tile.h b/csrc/flash_attn_with_bias_and_mask/src/fmha/smem_tile.h index 491253bb999..85900e5a2a0 100644 --- a/csrc/flash_attn_with_bias_and_mask/src/fmha/smem_tile.h +++ b/csrc/flash_attn_with_bias_and_mask/src/fmha/smem_tile.h @@ -1270,7 +1270,9 @@ struct Smem_tile_mma { // fmha::sts(smem_ + offset + 0 * BYTES_PER_ROW, regs[mi][ni].y); // fmha::sts(smem_ + offset + 8 * BYTES_PER_ROW, regs[mi][ni].w); // size_t offset = smem_write_ + mi * WARPS_M * 16 * BYTES_PER_ROW + ni * WARPS_N * 16 * BYTES_PER_ELT; - uint32_t offset = smem_write_ + mi * WARPS_M * 16 * BYTES_PER_ROW + ni * WARPS_N * 16 * BYTES_PER_ELT; + uint32_t offset = + smem_write_ + mi * WARPS_M * 16 * BYTES_PER_ROW + + ni * WARPS_N * 16 * BYTES_PER_ELT; fmha::sts(offset + 0 * BYTES_PER_ROW, regs[mi][ni].x); fmha::sts(offset + 8 * BYTES_PER_ROW, regs[mi][ni].z); offset ^= 4 * BYTES_PER_STS; @@ -1333,7 +1335,9 @@ struct Smem_tile_mma_transposed : public Base { uint4 dst; // fmha::ldsmt(dst, this->smem_ + offset); // size_t offset = smem_read_ + mi * WARPS_M * 16 * BYTES_PER_ROW + ni * WARPS_N * 16 * BYTES_PER_ELT; - uint32_t offset = smem_read_ + mi * WARPS_M * 16 * BYTES_PER_ROW + ni * WARPS_N * 16 * BYTES_PER_ELT; + uint32_t offset = + smem_read_ + mi * WARPS_M * 16 * BYTES_PER_ROW + + ni * WARPS_N * 16 * BYTES_PER_ELT; fmha::ldsmt(dst, offset); frag[mi][ni].reg(0) = dst.x; frag[mi][ni].reg(1) = dst.z; // Fragment A regs col major! @@ -1413,7 +1417,8 @@ struct Smem_tile_mma_epilogue : public Base { // fmha::sts(this->smem_ + offset + 0 * BYTES_PER_ROW, y); // fmha::sts(this->smem_ + offset + 8 * BYTES_PER_ROW, w); // size_t offset = (this->smem_write_ ^ (ni * 32)) + mi * WARPS_M * 16 * BYTES_PER_ROW; - uint32_t offset = (this->smem_write_ ^ (ni * 32)) + mi * WARPS_M * 16 * BYTES_PER_ROW; + uint32_t offset = + (this->smem_write_ ^ (ni * 32)) + mi * WARPS_M * 16 * BYTES_PER_ROW; // if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) { // printf("mi = %d, ni = %d, offset - smem_write_ = %d\n", mi, ni, offset - this->smem_write_); // } @@ -1431,7 +1436,8 @@ struct Smem_tile_mma_epilogue : public Base { for( int mi = 0; mi < M; mi++ ) { for( int ni = 0; ni < N; ni++ ) { // size_t offset = (this->write_offset_ ^ (ni * 32)) + mi * WARPS_M * 16 * BYTES_PER_ROW; - uint32_t offset = (this->write_offset_ ^ (ni * 32)) + mi * WARPS_M * 16 * BYTES_PER_ROW; + uint32_t offset = + (this->write_offset_ ^ (ni * 32)) + mi * WARPS_M * 16 * BYTES_PER_ROW; fmha::sts(this->smem_ + offset + 0 * BYTES_PER_ROW, regs[mi][ni].x); fmha::sts(this->smem_ + offset + 8 * BYTES_PER_ROW, regs[mi][ni].z); offset ^= 4 * Base::BYTES_PER_STS; @@ -1485,7 +1491,6 @@ struct Smem_tile_transpose { write_col ^= (write_row & 0x07) * 4; write_offset_ = write_row * BYTES_PER_ROW + write_col * BYTES_PER_STS; - // smem_write_ = smem_ + write_row * BYTES_PER_ROW + write_col * BYTES_PER_STS; int read_row, read_col; read_row = (tidx & 0x0f); @@ -1493,20 +1498,34 @@ struct Smem_tile_transpose { read_col ^= (read_row & 0x07); read_offset_ = read_row * BYTES_PER_ROW + read_col * BYTES_PER_LDS; - // smem_read_ = smem_ + read_row * BYTES_PER_ROW + read_col * BYTES_PER_LDS; + } + + inline __device__ void store_fragment(uint32_t base_offset, const Fragment_write &frag) { + uint32_t offset = smem_ + base_offset; + const uint32_t reg0 = frag.reg(0); + const uint32_t reg1 = frag.reg(1); + const uint32_t reg2 = frag.reg(2); + const uint32_t reg3 = frag.reg(3); + fmha::sts(offset + 0 * BYTES_PER_ROW, reg0); + fmha::sts(offset + 8 * BYTES_PER_ROW, reg2); + offset ^= 4 * BYTES_PER_STS; + fmha::sts(offset + 0 * BYTES_PER_ROW, reg1); + fmha::sts(offset + 8 * BYTES_PER_ROW, reg3); + } + + inline __device__ uint4 load_fragment(uint32_t offset) { + uint4 dst; + fmha::ldsmt(dst, smem_ + offset); + return dst; } template inline __device__ void store(const Fragment_write (&frag_w)[M][N], int mi) { #pragma unroll for( int ni = 0; ni < N; ni++ ) { - // size_t offset = write_offset_ + ni * WARPS_N * 16 * BYTES_PER_ELT; - uint32_t offset = write_offset_ + ni * WARPS_N * 16 * BYTES_PER_ELT; - fmha::sts(smem_ + offset + 0 * BYTES_PER_ROW, frag_w[ni][mi].reg(0)); - fmha::sts(smem_ + offset + 8 * BYTES_PER_ROW, frag_w[ni][mi].reg(2)); - offset ^= 4 * BYTES_PER_STS; - fmha::sts(smem_ + offset + 0 * BYTES_PER_ROW, frag_w[ni][mi].reg(1)); - fmha::sts(smem_ + offset + 8 * BYTES_PER_ROW, frag_w[ni][mi].reg(3)); + const uint32_t base = + write_offset_ + ni * WARPS_N * 16 * BYTES_PER_ELT; + store_fragment(base, frag_w[ni][mi]); } } @@ -1514,10 +1533,8 @@ struct Smem_tile_transpose { inline __device__ void load(Fragment_read (&frag_r)[N]) { #pragma unroll for( int ni = 0; ni < N; ni++ ) { - // size_t offset = read_offset_ + ni * WARPS_N * 16 * BYTES_PER_ELT; uint32_t offset = read_offset_ + ni * WARPS_N * 16 * BYTES_PER_ELT; - uint4 dst; - fmha::ldsmt(dst, this->smem_ + offset); + const uint4 dst = load_fragment(offset); frag_r[ni].reg(0) = dst.x; frag_r[ni].reg(1) = dst.y; // Fragment B regs col major! frag_r[ni].reg(2) = dst.z; @@ -1530,21 +1547,14 @@ struct Smem_tile_transpose { static_assert(COLS == Cta_tile::N); #pragma unroll for( int ni = 0; ni < N; ni++ ) { - // size_t offset = write_offset_ + ni * WARPS_N * 16 * BYTES_PER_ELT; - uint32_t offset = write_offset_ + ni * WARPS_N * 16 * BYTES_PER_ELT; - fmha::sts(smem_ + offset + 0 * BYTES_PER_ROW, frag_w[ni][mi].reg(0)); - fmha::sts(smem_ + offset + 8 * BYTES_PER_ROW, frag_w[ni][mi].reg(2)); - offset ^= 4 * BYTES_PER_STS; - fmha::sts(smem_ + offset + 0 * BYTES_PER_ROW, frag_w[ni][mi].reg(1)); - fmha::sts(smem_ + offset + 8 * BYTES_PER_ROW, frag_w[ni][mi].reg(3)); + const uint32_t base = + write_offset_ + ni * WARPS_N * 16 * BYTES_PER_ELT; + store_fragment(base, frag_w[ni][mi]); } #pragma unroll for( int ni = 0; ni < N; ni++ ) { - // size_t offset = read_offset_ + ni * WARPS_N * 16 * BYTES_PER_ELT; - // size_t offset = read_offset_ + ni * WARPS_N * 16 * BYTES_PER_ELT; uint32_t offset = read_offset_ + ni * WARPS_N * 16 * BYTES_PER_ELT; - uint4 dst; - fmha::ldsmt(dst, this->smem_ + offset); + const uint4 dst = load_fragment(offset); frag_r[ni].reg(0) = dst.x; frag_r[ni].reg(1) = dst.y; // Fragment B regs col major! frag_r[ni].reg(2) = dst.z; @@ -1555,8 +1565,6 @@ struct Smem_tile_transpose { uint32_t smem_; uint32_t write_offset_; uint32_t read_offset_; - // uint32_t smem_write_; - // uint32_t smem_read_; }; //////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/csrc/flashmask_v2/cutlass b/csrc/flashmask_v2/cutlass index afa17722036..4faf1a1568c 160000 --- a/csrc/flashmask_v2/cutlass +++ b/csrc/flashmask_v2/cutlass @@ -1 +1 @@ -Subproject commit afa1772203677c5118fcd82537a9c8fefbcc7008 +Subproject commit 4faf1a1568cf1e4ad8ff71846a13e16f2a6a6f6b diff --git a/csrc/flashmask_v2/epilogue_fwd.hpp b/csrc/flashmask_v2/epilogue_fwd.hpp index 69102e8c4e6..50a27463f0d 100644 --- a/csrc/flashmask_v2/epilogue_fwd.hpp +++ b/csrc/flashmask_v2/epilogue_fwd.hpp @@ -88,11 +88,12 @@ struct CollectiveEpilogueFwd { // ((qhead_per_khead, seqlen_q), nheads_kv, batch, num_splits) using ShapeLSEPacked = std::conditional_t, cute::Shape, int32_t, int32_t, int32_t>>; using StrideLSEPacked = std::conditional_t, int64_t, int64_t, int64_t>>; + using EpilogueTileMN = decltype(select<0, 1>(TileShape_MNK_PV{})); using CopyOpR2S = std::conditional_t< ArchTag::kMinComputeCapability >= 90, // cute::SM90_U32x4_STSM_N if Element size is 2 bytes (fp16, bf16) - decltype(cutlass::epilogue::collective::detail::sm90_get_smem_store_op_for_accumulator()), + decltype(cutlass::epilogue::collective::detail::sm90_get_smem_store_op_for_accumulator()), AutoVectorizingCopyWithAssumedAlignment<128> >; using SmemCopyAtomO = Copy_Atom; diff --git a/csrc/flashmask_v2/sm90_pipeline_no_cluster.hpp b/csrc/flashmask_v2/sm90_pipeline_no_cluster.hpp index 65a3d1554b3..c24f6150baa 100644 --- a/csrc/flashmask_v2/sm90_pipeline_no_cluster.hpp +++ b/csrc/flashmask_v2/sm90_pipeline_no_cluster.hpp @@ -17,7 +17,8 @@ using namespace cute; // forward pass (especially hdim 128 causal). We instead reimplement the version of // PipelineTmaAsync before v3.6.0 where only 1 out of 128 threads signals the barrier. // -// Assumption: params.num_consumers % NumThreadsPerWarpGroup == 0 +// Count consumers in whole warpgroups. A single consumer warp still needs one +// mbarrier arrival count. template > class PipelineTmaAsyncNoCluster: public Base { public: @@ -39,7 +40,8 @@ class PipelineTmaAsyncNoCluster: public Base { if (is_initializing_warp) { // Barrier FULL and EMPTY init constexpr int producer_arv_cnt = 1; - uint32_t const num_consumer_warpgroups_per_cluster = params.num_consumers / NumThreadsPerWarpGroup; + uint32_t const num_consumer_warpgroups_per_cluster = + (params.num_consumers + NumThreadsPerWarpGroup - 1) / NumThreadsPerWarpGroup; uint32_t const multicast_consumer_arrival_count = num_consumer_warpgroups_per_cluster; cutlass::arch::detail::initialize_barrier_array_pair_aligned(