Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion csrc/flash_attn_v3/cutlass
Submodule cutlass updated 2870 files
3 changes: 2 additions & 1 deletion csrc/flash_attn_v3/epilogue_fwd.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -88,11 +88,12 @@ struct CollectiveEpilogueFwd {
// ((qhead_per_khead, seqlen_q), nheads_kv, batch, num_splits)
using ShapeLSEPacked = std::conditional_t<!PackGQA, cute::Shape<int32_t, int32_t, int32_t, int32_t>, cute::Shape<cute::Shape<int32_t, int32_t>, int32_t, int32_t, int32_t>>;
using StrideLSEPacked = std::conditional_t<!PackGQA, StrideLSE, cute::Stride<cute::Stride<int64_t, _1>, int64_t, int64_t, int64_t>>;
using EpilogueTileMN = decltype(select<0, 1>(TileShape_MNK_PV{}));

using CopyOpR2S = std::conditional_t<
ArchTag::kMinComputeCapability >= 90,
// cute::SM90_U32x4_STSM_N if Element size is 2 bytes (fp16, bf16)
decltype(cutlass::epilogue::collective::detail::sm90_get_smem_store_op_for_accumulator<StrideO, Element>()),
decltype(cutlass::epilogue::collective::detail::sm90_get_smem_store_op_for_accumulator<StrideO, Element, EpilogueTileMN>()),
AutoVectorizingCopyWithAssumedAlignment<128>
>;
using SmemCopyAtomO = Copy_Atom<CopyOpR2S, Element>;
Expand Down
6 changes: 4 additions & 2 deletions csrc/flash_attn_v3/sm90_pipeline_no_cluster.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@ using namespace cute;
// forward pass (especially hdim 128 causal). We instead reimplement the version of
// PipelineTmaAsync before v3.6.0 where only 1 out of 128 threads signals the barrier.
//
// Assumption: params.num_consumers % NumThreadsPerWarpGroup == 0
// Count consumers in whole warpgroups. A single consumer warp still needs one
// mbarrier arrival count.
template <int Stages_, class Base=cutlass::PipelineTmaAsync<Stages_>>
class PipelineTmaAsyncNoCluster: public Base {
public:
Expand All @@ -39,7 +40,8 @@ class PipelineTmaAsyncNoCluster: public Base {
if (is_initializing_warp) {
// Barrier FULL and EMPTY init
constexpr int producer_arv_cnt = 1;
uint32_t const num_consumer_warpgroups_per_cluster = params.num_consumers / NumThreadsPerWarpGroup;
uint32_t const num_consumer_warpgroups_per_cluster =
(params.num_consumers + NumThreadsPerWarpGroup - 1) / NumThreadsPerWarpGroup;
uint32_t const multicast_consumer_arrival_count = num_consumer_warpgroups_per_cluster;

cutlass::arch::detail::initialize_barrier_array_pair_aligned<decltype(storage.full_barrier_), decltype(storage.empty_barrier_), Stages>(
Expand Down
66 changes: 37 additions & 29 deletions csrc/flash_attn_with_bias_and_mask/src/fmha/smem_tile.h
Original file line number Diff line number Diff line change
Expand Up @@ -1270,7 +1270,9 @@ struct Smem_tile_mma {
// fmha::sts(smem_ + offset + 0 * BYTES_PER_ROW, regs[mi][ni].y);
// fmha::sts(smem_ + offset + 8 * BYTES_PER_ROW, regs[mi][ni].w);
// size_t offset = smem_write_ + mi * WARPS_M * 16 * BYTES_PER_ROW + ni * WARPS_N * 16 * BYTES_PER_ELT;
uint32_t offset = smem_write_ + mi * WARPS_M * 16 * BYTES_PER_ROW + ni * WARPS_N * 16 * BYTES_PER_ELT;
uint32_t offset =
smem_write_ + mi * WARPS_M * 16 * BYTES_PER_ROW +
ni * WARPS_N * 16 * BYTES_PER_ELT;
fmha::sts(offset + 0 * BYTES_PER_ROW, regs[mi][ni].x);
fmha::sts(offset + 8 * BYTES_PER_ROW, regs[mi][ni].z);
offset ^= 4 * BYTES_PER_STS;
Expand Down Expand Up @@ -1333,7 +1335,9 @@ struct Smem_tile_mma_transposed : public Base {
uint4 dst;
// fmha::ldsmt(dst, this->smem_ + offset);
// size_t offset = smem_read_ + mi * WARPS_M * 16 * BYTES_PER_ROW + ni * WARPS_N * 16 * BYTES_PER_ELT;
uint32_t offset = smem_read_ + mi * WARPS_M * 16 * BYTES_PER_ROW + ni * WARPS_N * 16 * BYTES_PER_ELT;
uint32_t offset =
smem_read_ + mi * WARPS_M * 16 * BYTES_PER_ROW +
ni * WARPS_N * 16 * BYTES_PER_ELT;
fmha::ldsmt(dst, offset);
frag[mi][ni].reg(0) = dst.x;
frag[mi][ni].reg(1) = dst.z; // Fragment A regs col major!
Expand Down Expand Up @@ -1413,7 +1417,8 @@ struct Smem_tile_mma_epilogue : public Base {
// fmha::sts(this->smem_ + offset + 0 * BYTES_PER_ROW, y);
// fmha::sts(this->smem_ + offset + 8 * BYTES_PER_ROW, w);
// size_t offset = (this->smem_write_ ^ (ni * 32)) + mi * WARPS_M * 16 * BYTES_PER_ROW;
uint32_t offset = (this->smem_write_ ^ (ni * 32)) + mi * WARPS_M * 16 * BYTES_PER_ROW;
uint32_t offset =
(this->smem_write_ ^ (ni * 32)) + mi * WARPS_M * 16 * BYTES_PER_ROW;
// if ((threadIdx.x == 0) && (blockIdx.x == 0) && (blockIdx.y == 0)) {
// printf("mi = %d, ni = %d, offset - smem_write_ = %d\n", mi, ni, offset - this->smem_write_);
// }
Expand All @@ -1431,7 +1436,8 @@ struct Smem_tile_mma_epilogue : public Base {
for( int mi = 0; mi < M; mi++ ) {
for( int ni = 0; ni < N; ni++ ) {
// size_t offset = (this->write_offset_ ^ (ni * 32)) + mi * WARPS_M * 16 * BYTES_PER_ROW;
uint32_t offset = (this->write_offset_ ^ (ni * 32)) + mi * WARPS_M * 16 * BYTES_PER_ROW;
uint32_t offset =
(this->write_offset_ ^ (ni * 32)) + mi * WARPS_M * 16 * BYTES_PER_ROW;
fmha::sts(this->smem_ + offset + 0 * BYTES_PER_ROW, regs[mi][ni].x);
fmha::sts(this->smem_ + offset + 8 * BYTES_PER_ROW, regs[mi][ni].z);
offset ^= 4 * Base::BYTES_PER_STS;
Expand Down Expand Up @@ -1485,39 +1491,50 @@ struct Smem_tile_transpose {
write_col ^= (write_row & 0x07) * 4;

write_offset_ = write_row * BYTES_PER_ROW + write_col * BYTES_PER_STS;
// smem_write_ = smem_ + write_row * BYTES_PER_ROW + write_col * BYTES_PER_STS;

int read_row, read_col;
read_row = (tidx & 0x0f);
read_col = (tidx & 0xe0) / 16 + (tidx & 0x1c) / 16;

read_col ^= (read_row & 0x07);
read_offset_ = read_row * BYTES_PER_ROW + read_col * BYTES_PER_LDS;
// smem_read_ = smem_ + read_row * BYTES_PER_ROW + read_col * BYTES_PER_LDS;
}

inline __device__ void store_fragment(uint32_t base_offset, const Fragment_write &frag) {
uint32_t offset = smem_ + base_offset;
const uint32_t reg0 = frag.reg(0);
const uint32_t reg1 = frag.reg(1);
const uint32_t reg2 = frag.reg(2);
const uint32_t reg3 = frag.reg(3);
fmha::sts(offset + 0 * BYTES_PER_ROW, reg0);
fmha::sts(offset + 8 * BYTES_PER_ROW, reg2);
offset ^= 4 * BYTES_PER_STS;
fmha::sts(offset + 0 * BYTES_PER_ROW, reg1);
fmha::sts(offset + 8 * BYTES_PER_ROW, reg3);
}

inline __device__ uint4 load_fragment(uint32_t offset) {
uint4 dst;
fmha::ldsmt(dst, smem_ + offset);
return dst;
}

template<int M, int N>
inline __device__ void store(const Fragment_write (&frag_w)[M][N], int mi) {
#pragma unroll
for( int ni = 0; ni < N; ni++ ) {
// size_t offset = write_offset_ + ni * WARPS_N * 16 * BYTES_PER_ELT;
uint32_t offset = write_offset_ + ni * WARPS_N * 16 * BYTES_PER_ELT;
fmha::sts(smem_ + offset + 0 * BYTES_PER_ROW, frag_w[ni][mi].reg(0));
fmha::sts(smem_ + offset + 8 * BYTES_PER_ROW, frag_w[ni][mi].reg(2));
offset ^= 4 * BYTES_PER_STS;
fmha::sts(smem_ + offset + 0 * BYTES_PER_ROW, frag_w[ni][mi].reg(1));
fmha::sts(smem_ + offset + 8 * BYTES_PER_ROW, frag_w[ni][mi].reg(3));
const uint32_t base =
write_offset_ + ni * WARPS_N * 16 * BYTES_PER_ELT;
store_fragment(base, frag_w[ni][mi]);
}
}

template<int N>
inline __device__ void load(Fragment_read (&frag_r)[N]) {
#pragma unroll
for( int ni = 0; ni < N; ni++ ) {
// size_t offset = read_offset_ + ni * WARPS_N * 16 * BYTES_PER_ELT;
uint32_t offset = read_offset_ + ni * WARPS_N * 16 * BYTES_PER_ELT;
uint4 dst;
fmha::ldsmt(dst, this->smem_ + offset);
const uint4 dst = load_fragment(offset);
frag_r[ni].reg(0) = dst.x;
frag_r[ni].reg(1) = dst.y; // Fragment B regs col major!
frag_r[ni].reg(2) = dst.z;
Expand All @@ -1530,21 +1547,14 @@ struct Smem_tile_transpose {
static_assert(COLS == Cta_tile::N);
#pragma unroll
for( int ni = 0; ni < N; ni++ ) {
// size_t offset = write_offset_ + ni * WARPS_N * 16 * BYTES_PER_ELT;
uint32_t offset = write_offset_ + ni * WARPS_N * 16 * BYTES_PER_ELT;
fmha::sts(smem_ + offset + 0 * BYTES_PER_ROW, frag_w[ni][mi].reg(0));
fmha::sts(smem_ + offset + 8 * BYTES_PER_ROW, frag_w[ni][mi].reg(2));
offset ^= 4 * BYTES_PER_STS;
fmha::sts(smem_ + offset + 0 * BYTES_PER_ROW, frag_w[ni][mi].reg(1));
fmha::sts(smem_ + offset + 8 * BYTES_PER_ROW, frag_w[ni][mi].reg(3));
const uint32_t base =
write_offset_ + ni * WARPS_N * 16 * BYTES_PER_ELT;
store_fragment(base, frag_w[ni][mi]);
}
#pragma unroll
for( int ni = 0; ni < N; ni++ ) {
// size_t offset = read_offset_ + ni * WARPS_N * 16 * BYTES_PER_ELT;
// size_t offset = read_offset_ + ni * WARPS_N * 16 * BYTES_PER_ELT;
uint32_t offset = read_offset_ + ni * WARPS_N * 16 * BYTES_PER_ELT;
uint4 dst;
fmha::ldsmt(dst, this->smem_ + offset);
const uint4 dst = load_fragment(offset);
frag_r[ni].reg(0) = dst.x;
frag_r[ni].reg(1) = dst.y; // Fragment B regs col major!
frag_r[ni].reg(2) = dst.z;
Expand All @@ -1555,8 +1565,6 @@ struct Smem_tile_transpose {
uint32_t smem_;
uint32_t write_offset_;
uint32_t read_offset_;
// uint32_t smem_write_;
// uint32_t smem_read_;
};

////////////////////////////////////////////////////////////////////////////////////////////////////
Expand Down
2 changes: 1 addition & 1 deletion csrc/flashmask_v2/cutlass
Submodule cutlass updated 2870 files
3 changes: 2 additions & 1 deletion csrc/flashmask_v2/epilogue_fwd.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -88,11 +88,12 @@ struct CollectiveEpilogueFwd {
// ((qhead_per_khead, seqlen_q), nheads_kv, batch, num_splits)
using ShapeLSEPacked = std::conditional_t<!PackGQA, cute::Shape<int32_t, int32_t, int32_t, int32_t>, cute::Shape<cute::Shape<int32_t, int32_t>, int32_t, int32_t, int32_t>>;
using StrideLSEPacked = std::conditional_t<!PackGQA, StrideLSE, cute::Stride<cute::Stride<int64_t, _1>, int64_t, int64_t, int64_t>>;
using EpilogueTileMN = decltype(select<0, 1>(TileShape_MNK_PV{}));

using CopyOpR2S = std::conditional_t<
ArchTag::kMinComputeCapability >= 90,
// cute::SM90_U32x4_STSM_N if Element size is 2 bytes (fp16, bf16)
decltype(cutlass::epilogue::collective::detail::sm90_get_smem_store_op_for_accumulator<StrideO, Element>()),
decltype(cutlass::epilogue::collective::detail::sm90_get_smem_store_op_for_accumulator<StrideO, Element, EpilogueTileMN>()),
AutoVectorizingCopyWithAssumedAlignment<128>
>;
using SmemCopyAtomO = Copy_Atom<CopyOpR2S, Element>;
Expand Down
6 changes: 4 additions & 2 deletions csrc/flashmask_v2/sm90_pipeline_no_cluster.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@ using namespace cute;
// forward pass (especially hdim 128 causal). We instead reimplement the version of
// PipelineTmaAsync before v3.6.0 where only 1 out of 128 threads signals the barrier.
//
// Assumption: params.num_consumers % NumThreadsPerWarpGroup == 0
// Count consumers in whole warpgroups. A single consumer warp still needs one
// mbarrier arrival count.
template <int Stages_, class Base=cutlass::PipelineTmaAsync<Stages_>>
class PipelineTmaAsyncNoCluster: public Base {
public:
Expand All @@ -39,7 +40,8 @@ class PipelineTmaAsyncNoCluster: public Base {
if (is_initializing_warp) {
// Barrier FULL and EMPTY init
constexpr int producer_arv_cnt = 1;
uint32_t const num_consumer_warpgroups_per_cluster = params.num_consumers / NumThreadsPerWarpGroup;
uint32_t const num_consumer_warpgroups_per_cluster =
(params.num_consumers + NumThreadsPerWarpGroup - 1) / NumThreadsPerWarpGroup;
uint32_t const multicast_consumer_arrival_count = num_consumer_warpgroups_per_cluster;

cutlass::arch::detail::initialize_barrier_array_pair_aligned<decltype(storage.full_barrier_), decltype(storage.empty_barrier_), Stages>(
Expand Down