From 3cd7c25ee18322e257fae26627b04bea8b73247f Mon Sep 17 00:00:00 2001 From: umiswing Date: Thu, 19 Jun 2025 17:01:41 +0800 Subject: [PATCH 1/4] fix dq_accum dv_accum numeric overflow on very long sequence --- csrc/flash_attn_v3/epilogue_bwd.hpp | 8 ++++---- csrc/flash_attn_v3/flash_bwd_launch_template.h | 10 +++++----- csrc/flash_attn_v3/flash_bwd_postprocess_kernel.h | 4 ++-- csrc/flash_attn_v3/flash_bwd_preprocess_kernel.h | 4 ++-- csrc/flash_attn_v3/mainloop_bwd_sm90_tma_gmma_ws.hpp | 6 +++--- 5 files changed, 16 insertions(+), 16 deletions(-) diff --git a/csrc/flash_attn_v3/epilogue_bwd.hpp b/csrc/flash_attn_v3/epilogue_bwd.hpp index 9362b040453..d8684541e10 100644 --- a/csrc/flash_attn_v3/epilogue_bwd.hpp +++ b/csrc/flash_attn_v3/epilogue_bwd.hpp @@ -87,7 +87,7 @@ struct CollectiveEpilogueBwd { cute::array_aligned, SmemAlignmentdKV> smem_dv; }; - using ShapedKV = cute::Shape; // (seqlen_k, d, head, batch) + using ShapedKV = cute::Shape; // (seqlen_k, d, head, batch) using StridedKV = cute::Stride; using TMA_dKV = std::conditional_t< @@ -350,7 +350,7 @@ struct CollectiveEpilogueBwdGQA { }; using TensorStorage = std::conditional_t; - using ShapedKV = cute::Shape; // (seqlen_k_rounded * d, head, batch) + using ShapedKV = cute::Shape; // (seqlen_k_rounded * d, head, batch) using StridedKV = cute::Stride<_1, int64_t, int64_t>; // Host side kernel arguments @@ -420,8 +420,8 @@ struct CollectiveEpilogueBwdGQA { bool const is_varlen = Varlen && params.cu_seqlens; Tensor mdKaccum = make_tensor(make_gmem_ptr(params.ptr_dKaccum), params.shape_dKaccum, params.stride_dKaccum)(_, bidh_kv, !is_varlen ? bidb : 0); Tensor mdVaccum = make_tensor(make_gmem_ptr(params.ptr_dVaccum), params.shape_dKaccum, params.stride_dVaccum)(_, bidh_kv, !is_varlen ? bidb : 0); - Tensor gdKaccum = local_tile(domain_offset(make_coord(seqlen_info.offset_padded * kHeadDim), mdKaccum), Shape>{}, make_coord(n_block)); // (M * K) - Tensor gdVaccum = local_tile(domain_offset(make_coord(seqlen_info.offset_padded * kHeadDim), mdVaccum), Shape>{}, make_coord(n_block)); // (M * K) + Tensor gdKaccum = local_tile(domain_offset(make_coord(int64_t{seqlen_info.offset_padded} * int64_t{kHeadDim}), mdKaccum), Shape>{}, make_coord(n_block)); // (M * K) + Tensor gdVaccum = local_tile(domain_offset(make_coord(int64_t{seqlen_info.offset_padded} * int64_t{kHeadDim}), mdVaccum), Shape>{}, make_coord(n_block)); // (M * K) R2STiledCopydKVaccum r2s_tiled_copy_dKVaccum; auto r2s_thr_copy_dKVaccum = r2s_tiled_copy_dKVaccum.get_thread_slice(thread_idx); diff --git a/csrc/flash_attn_v3/flash_bwd_launch_template.h b/csrc/flash_attn_v3/flash_bwd_launch_template.h index 76ded0407ec..a8564a78435 100644 --- a/csrc/flash_attn_v3/flash_bwd_launch_template.h +++ b/csrc/flash_attn_v3/flash_bwd_launch_template.h @@ -61,8 +61,8 @@ void run_flash_bwd(Flash_bwd_params ¶ms, cudaStream_t stream) { static_cast(params.softmax_lse_log2_ptr), {_1{}, seqlen_q_rounded, !is_varlen_q ? params.h * params.seqlen_q_rounded : 0}, // stride_LSE_log2 static_cast(params.dq_accum_ptr), - {seqlen_q_rounded * params.d_rounded, params.h, batch_q}, // shape_dQaccum - {_1{}, seqlen_q_rounded * params.d_rounded, !is_varlen_q ? params.d_rounded * seqlen_q_rounded * params.h : 0}, // stride_dQaccum + {int64_t{seqlen_q_rounded} * int64_t{params.d_rounded}, params.h, batch_q}, // shape_dQaccum + {_1{}, int64_t{seqlen_q_rounded} * int64_t{params.d_rounded}, !is_varlen_q ? int64_t{params.d_rounded} * int64_t{seqlen_q_rounded} * int64_t{params.h} : 0}, // stride_dQaccum params.b, params.dq_semaphore, params.cu_seqlens_q, @@ -112,7 +112,7 @@ void run_flash_bwd(Flash_bwd_params ¶ms, cudaStream_t stream) { static_cast(params.do_ptr), {params.do_row_stride, _1{}, params.do_head_stride, !is_varlen_q ? params.do_batch_stride : 0}, // stride_dO static_cast(params.dq_accum_ptr), - {seqlen_q_rounded * params.d_rounded, params.h, batch_q}, // shape_dQaccum + {int64_t{seqlen_q_rounded} * int64_t{params.d_rounded}, params.h, batch_q}, // shape_dQaccum {_1{}, seqlen_q_rounded * params.d_rounded, !is_varlen_q ? params.d_rounded * params.seqlen_q_rounded * params.h : 0}, // stride_dQaccum static_cast(params.softmax_lse_log2_ptr), {seqlen_q_rounded, params.h, batch_q}, // shape_LSE @@ -134,7 +134,7 @@ void run_flash_bwd(Flash_bwd_params ¶ms, cudaStream_t stream) { if constexpr (!GQA) { return typename CollectiveEpilogue::ShapedKV {seqlen_k, params.d, params.h, batch_k}; // shape_dK } else { - return typename CollectiveEpilogue::ShapedKV {seqlen_k_rounded * params.d_rounded, params.h_k, batch_k}; // shape_dKaccum + return typename CollectiveEpilogue::ShapedKV {int64_t{seqlen_k_rounded} * int64_t{params.d_rounded}, params.h_k, batch_k}; // shape_dKaccum } }(), [&] { @@ -216,7 +216,7 @@ void run_flash_bwd(Flash_bwd_params ¶ms, cudaStream_t stream) { >; typename PostprocessKernel::Arguments postprocess_args { static_cast(params.dq_accum_ptr), - {seqlen_q_rounded * params.d_rounded, params.h, batch_q}, // shape_dQaccum + {int64_t{seqlen_q_rounded} * int64_t{params.d_rounded}, params.h, batch_q}, // shape_dQaccum {_1{}, seqlen_q_rounded * params.d_rounded, !is_varlen_q ? params.d_rounded * params.seqlen_q_rounded * params.h : 0}, // stride_dQaccum static_cast(params.dq_ptr), {seqlen_q, params.d, params.h, batch_q}, // shape_dQ diff --git a/csrc/flash_attn_v3/flash_bwd_postprocess_kernel.h b/csrc/flash_attn_v3/flash_bwd_postprocess_kernel.h index c91e261507d..9fa19f04ed1 100644 --- a/csrc/flash_attn_v3/flash_bwd_postprocess_kernel.h +++ b/csrc/flash_attn_v3/flash_bwd_postprocess_kernel.h @@ -104,7 +104,7 @@ class FlashAttnBwdPostprocessConvertdQ { using ShapedQ = cute::Shape; // (seqlen_q, d, head, batch) using StridedQ = cute::Stride; - using ShapedQaccum = cute::Shape; // (seqlen_q * d, head, batch) + using ShapedQaccum = cute::Shape; // (seqlen_q * d, head, batch) using StridedQaccum = cute::Stride<_1, int64_t, int64_t>; // Device side arguments @@ -174,7 +174,7 @@ class FlashAttnBwdPostprocessConvertdQ { // Step 1: load dQaccum from gmem to smem Tensor mdQaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.ptr_dQaccum)), params.shape_dQaccum, params.stride_dQaccum)(_, bidh, !is_varlen ? bidb : 0); - Tensor gdQaccum = local_tile(domain_offset(make_coord(seqlen_info.offset_padded * kHeadDim), mdQaccum), Shape>{}, make_coord(m_block)); // (M * K) + Tensor gdQaccum = local_tile(domain_offset(make_coord(int64_t{seqlen_info.offset_padded} * int64_t{kHeadDim}), mdQaccum), Shape>{}, make_coord(m_block)); // (M * K) if constexpr (IsSm90) { // Use BulkCopy static constexpr uint32_t TmaTransactionBytesdQaccum = static_cast(size(SmemLayoutdQaccumFlat{}) * cute::sizeof_bits_v / 8); auto bulk_copy = Copy_Traits{}; diff --git a/csrc/flash_attn_v3/flash_bwd_preprocess_kernel.h b/csrc/flash_attn_v3/flash_bwd_preprocess_kernel.h index 85e877f9d4f..668357e7719 100644 --- a/csrc/flash_attn_v3/flash_bwd_preprocess_kernel.h +++ b/csrc/flash_attn_v3/flash_bwd_preprocess_kernel.h @@ -63,7 +63,7 @@ class FlashAttnBwdPreprocess { using StrideO = cute::Stride; using ShapedPsum = cute::Shape; // (seqlen_q, head, batch) using StridedPsum = cute::Stride<_1, int64_t, int64_t>; - using ShapedQaccum = cute::Shape; // (seqlen_q * d, head, batch) + using ShapedQaccum = cute::Shape; // (seqlen_q * d, head, batch) using StridedQaccum = cute::Stride<_1, int64_t, int64_t>; // Device side arguments @@ -230,7 +230,7 @@ class FlashAttnBwdPreprocess { if constexpr (Clear_dQaccum) { Tensor mdQaccum = make_tensor(make_gmem_ptr(params.ptr_dQaccum), params.shape_dQaccum, params.stride_dQaccum)(_, bidh, !is_varlen ? bidb : 0); - Tensor gdQaccum = local_tile(cute::domain_offset(make_coord(seqlen_info.offset_padded * kHeadDim), mdQaccum), Shape>{}, make_coord(m_block)); + Tensor gdQaccum = local_tile(cute::domain_offset(make_coord(int64_t{seqlen_info.offset_padded} * int64_t{kHeadDim}), mdQaccum), Shape>{}, make_coord(m_block)); GmemTiledCopyAccum gmem_tiled_copy_dQaccum; auto gmem_thr_copy_dQaccum = gmem_tiled_copy_dQaccum.get_thread_slice(thread_idx); Tensor tdQgdQaccum = gmem_thr_copy_dQaccum.partition_D(gdQaccum); diff --git a/csrc/flash_attn_v3/mainloop_bwd_sm90_tma_gmma_ws.hpp b/csrc/flash_attn_v3/mainloop_bwd_sm90_tma_gmma_ws.hpp index 71cfb020469..7e1904a3ba2 100644 --- a/csrc/flash_attn_v3/mainloop_bwd_sm90_tma_gmma_ws.hpp +++ b/csrc/flash_attn_v3/mainloop_bwd_sm90_tma_gmma_ws.hpp @@ -215,7 +215,7 @@ struct CollectiveMainloopBwdSm90 { using StrideQKV = cute::Stride; using ShapeLSE = cute::Shape; // (seqlen, head, batch) using StrideLSE = cute::Stride<_1, int64_t, int64_t>; // (seqlen, head, batch) - using ShapedQaccum = cute::Shape; // (seqlen_q * d, head, batch) + using ShapedQaccum = cute::Shape; // (seqlen_q * d, head, batch) using StridedQaccum = cute::Stride<_1, int64_t, int64_t>; using TMA_QdO = decltype(make_tma_copy_A_sm90( @@ -608,7 +608,7 @@ struct CollectiveMainloopBwdSm90 { bool const is_varlen = Varlen && params.cu_seqlens_q; Tensor mdQaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.ptr_dQaccum)), params.shape_dQaccum, params.stride_dQaccum)(_, bidh, !is_varlen ? bidb : 0); - Tensor gdQaccum_ = local_tile(domain_offset(make_coord(seqlen_info.offset_q_padded * kHeadDim), mdQaccum), Shape>{}, make_coord(_)); // (M * K, _) + Tensor gdQaccum_ = local_tile(domain_offset(make_coord(int64_t{seqlen_info.offset_q_padded} * int64_t{kHeadDim}), mdQaccum), Shape>{}, make_coord(_)); // (M * K, _) Tensor gdQaccum = cute::flat_divide(gdQaccum_, Int{}); // (M * K / WG, WG, _) int const num_batch = params.num_batch; @@ -785,7 +785,7 @@ struct CollectiveMainloopBwdSm90 { bool const is_varlen = Varlen && params.cu_seqlens_q; Tensor mdQaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.ptr_dQaccum)), params.shape_dQaccum, params.stride_dQaccum)(_, bidh, !is_varlen ? bidb : 0); - Tensor gdQaccum_ = local_tile(domain_offset(make_coord(seqlen_info.offset_q_padded * kHeadDim), mdQaccum), Shape>{}, make_coord(_)); // (M * K, _) + Tensor gdQaccum_ = local_tile(domain_offset(make_coord(int64_t{seqlen_info.offset_q_padded} * int64_t{kHeadDim}), mdQaccum), Shape>{}, make_coord(_)); // (M * K, _) Tensor gdQaccum = cute::flat_divide(gdQaccum_, Int{}); // (M * K / WG, WG, _) // We can reuse r2s_thr_copy_dQaccum for this partitioning Tensor tdQgdQaccum = r2s_thr_copy_dQaccum.partition_D(gdQaccum); From e74f69e8b61857bcbde84f6a8abcddbb9a52c591 Mon Sep 17 00:00:00 2001 From: umiswing Date: Thu, 19 Jun 2025 17:15:55 +0800 Subject: [PATCH 2/4] revert unnessary int64_t --- csrc/flash_attn_v3/epilogue_bwd.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/flash_attn_v3/epilogue_bwd.hpp b/csrc/flash_attn_v3/epilogue_bwd.hpp index d8684541e10..8fa880f3671 100644 --- a/csrc/flash_attn_v3/epilogue_bwd.hpp +++ b/csrc/flash_attn_v3/epilogue_bwd.hpp @@ -87,7 +87,7 @@ struct CollectiveEpilogueBwd { cute::array_aligned, SmemAlignmentdKV> smem_dv; }; - using ShapedKV = cute::Shape; // (seqlen_k, d, head, batch) + using ShapedKV = cute::Shape; // (seqlen_k, d, head, batch) using StridedKV = cute::Stride; using TMA_dKV = std::conditional_t< From 1ac962eb1d58744d1948a960c7ca2c8ea5675fae Mon Sep 17 00:00:00 2001 From: Zichao <40557101+hxzd5568@users.noreply.github.com> Date: Wed, 4 Jun 2025 11:25:23 +0800 Subject: [PATCH 3/4] Support hdimQK != hdimV backward (#64) --- csrc/flash_attn_v3/epilogue_bwd.hpp | 45 ++++++++++++------- csrc/flash_attn_v3/flash_api.cu | 20 ++++----- .../flash_attn_v3/flash_bwd_launch_template.h | 19 +++++--- csrc/flash_attn_v3/mainloop_bwd_sm80.hpp | 28 ++++++++---- .../mainloop_bwd_sm90_tma_gmma_ws.hpp | 13 ++++-- flash_attn/flash_attn_interface.py | 12 ++--- 6 files changed, 86 insertions(+), 51 deletions(-) diff --git a/csrc/flash_attn_v3/epilogue_bwd.hpp b/csrc/flash_attn_v3/epilogue_bwd.hpp index 8fa880f3671..1ab95b8ca31 100644 --- a/csrc/flash_attn_v3/epilogue_bwd.hpp +++ b/csrc/flash_attn_v3/epilogue_bwd.hpp @@ -107,6 +107,7 @@ struct CollectiveEpilogueBwd { ShapedKV const shape_dK; StridedKV const stride_dK; Element* ptr_dV; + ShapedKV const shape_dV; StridedKV const stride_dV; int const num_heads_q; int* dk_semaphore; @@ -121,6 +122,7 @@ struct CollectiveEpilogueBwd { ShapedKV const shape_dK; StridedKV const stride_dK; Element* ptr_dV; + ShapedKV const shape_dV; StridedKV const stride_dV; TMA_dKV tma_store_dK, tma_store_dV; int const* cu_seqlens = nullptr; @@ -130,7 +132,7 @@ struct CollectiveEpilogueBwd { static Params to_underlying_arguments(Arguments const& args) { Tensor mdK = make_tensor(make_gmem_ptr(args.ptr_dK), args.shape_dK, args.stride_dK); - Tensor mdV = make_tensor(make_gmem_ptr(args.ptr_dV), args.shape_dK, args.stride_dV); + Tensor mdV = make_tensor(make_gmem_ptr(args.ptr_dV), args.shape_dV, args.stride_dV); TMA_dKV tma_store_dK = [&] { if constexpr (Use_TMA) { return make_tma_copy(GmemTiledCopydKVTMA{}, mdK, SmemLayoutdKVTMA{}, select<1, 2>(TileShape_MNK{}), _1{}); // no mcast for dKV @@ -145,7 +147,7 @@ struct CollectiveEpilogueBwd { return nullptr; } }(); - return {args.ptr_dK, args.shape_dK, args.stride_dK, args.ptr_dV, args.stride_dV, + return {args.ptr_dK, args.shape_dK, args.stride_dK, args.ptr_dV, args.shape_dV, args.stride_dV, tma_store_dK, tma_store_dV, args.cu_seqlens, args.seqused}; } @@ -197,7 +199,7 @@ struct CollectiveEpilogueBwd { cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); Tensor mdK = params.tma_store_dK.get_tma_tensor(params.shape_dK); - Tensor mdV = params.tma_store_dV.get_tma_tensor(params.shape_dK); + Tensor mdV = params.tma_store_dV.get_tma_tensor(params.shape_dV); Tensor gdK = local_tile(mdK(_, _, bidh, bidb), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{})); // (M, K) Tensor gdV = local_tile(mdV(_, _, bidh, bidb), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{})); // (M, K) auto block_tma_dK = params.tma_store_dK.get_slice(_0{}); @@ -227,7 +229,7 @@ struct CollectiveEpilogueBwd { bool const is_varlen = Varlen && params.cu_seqlens; Tensor mdK = make_tensor(make_gmem_ptr(params.ptr_dK), params.shape_dK, params.stride_dK)(_, _, bidh, !is_varlen ? bidb : 0); Tensor gdK = local_tile(cute::domain_offset(make_coord(seqlen_info.offset, _0{}), mdK), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{})); // (M, K) - Tensor mdV = make_tensor(make_gmem_ptr(params.ptr_dV), params.shape_dK, params.stride_dV)(_, _, bidh, !is_varlen ? bidb : 0); + Tensor mdV = make_tensor(make_gmem_ptr(params.ptr_dV), params.shape_dV, params.stride_dV)(_, _, bidh, !is_varlen ? bidb : 0); Tensor gdV = local_tile(cute::domain_offset(make_coord(seqlen_info.offset, _0{}), mdV), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{})); // (M, K) GmemTiledCopydKV gmem_tiled_copy_dKV; @@ -241,25 +243,28 @@ struct CollectiveEpilogueBwd { Tensor cdKV = cute::make_identity_tensor(select<1, 2>(TileShape_MNK{})); // (BLK_N,BLK_K) -> (blk_n,blk_k) // Repeat the partitioning with identity layouts Tensor tdKVcdKV = gmem_thr_copy_dKV.partition_D(cdKV); - Tensor tdKVpdKV = make_tensor(make_shape(size<2>(tdKVgdV))); + Tensor tdKVpdV = make_tensor(make_shape(size<2>(tdKVgdV))); + Tensor tdKVpdK = make_tensor(make_shape(size<2>(tdKVgdK))); #pragma unroll - for (int k = 0; k < size(tdKVpdKV); ++k) { tdKVpdKV(k) = get<1>(tdKVcdKV(_0{}, _0{}, k)) < get<1>(params.shape_dK); } + for (int k = 0; k < size(tdKVpdV); ++k) { tdKVpdV(k) = get<1>(tdKVcdKV(_0{}, _0{}, k)) < get<1>(params.shape_dV); } + #pragma unroll + for (int k = 0; k < size(tdKVpdK); ++k) { tdKVpdK(k) = get<1>(tdKVcdKV(_0{}, _0{}, k)) < get<1>(params.shape_dK); } // Need to check OOB when reading from smem if kBlockN isn't evenly tiled static constexpr bool EvenN = kBlockN % CUTE_STATIC_V(size<0>(GmemLayoutAtom{})) == 0; flash::copy( - gmem_tiled_copy_dKV, tdKVsdV, tdKVrdV, tdKVcdKV, tdKVpdKV, kBlockN); + gmem_tiled_copy_dKV, tdKVsdV, tdKVrdV, tdKVcdKV, tdKVpdV, kBlockN); flash::copy( - gmem_tiled_copy_dKV, tdKVsdK, tdKVrdK, tdKVcdKV, tdKVpdKV, kBlockN); + gmem_tiled_copy_dKV, tdKVsdK, tdKVrdK, tdKVcdKV, tdKVpdK, kBlockN); // // Tell warp 0 that smem_k and smem_v are ready // cutlass::arch::fence_view_async_shared(); // ensure smem reads are done before next TMA to smem_k/v // flash::named_barrier_arrive(NumEpilogueThreads + cutlass::NumThreadsPerWarp, static_cast(BwdNamedBarriers::KVEmpty) /*id*/); // Construct identity layout for gdKV // Clear_OOB_K must be false since we don't want to write zeros to gmem flash::copy( - gmem_tiled_copy_dKV, tdKVrdV, tdKVgdV, tdKVcdKV, tdKVpdKV, std::min(seqlen_info.seqlen - n_block * kBlockN, kBlockN) + gmem_tiled_copy_dKV, tdKVrdV, tdKVgdV, tdKVcdKV, tdKVpdV, std::min(seqlen_info.seqlen - n_block * kBlockN, kBlockN) ); flash::copy( - gmem_tiled_copy_dKV, tdKVrdK, tdKVgdK, tdKVcdKV, tdKVpdKV, std::min(seqlen_info.seqlen - n_block * kBlockN, kBlockN) + gmem_tiled_copy_dKV, tdKVrdK, tdKVgdK, tdKVcdKV, tdKVpdK, std::min(seqlen_info.seqlen - n_block * kBlockN, kBlockN) ); } } @@ -282,7 +287,7 @@ struct CollectiveEpilogueBwd { bool const is_varlen = Varlen && params.cu_seqlens; Tensor mdK = make_tensor(make_gmem_ptr(params.ptr_dK), params.shape_dK, params.stride_dK)(_, _, bidh, !is_varlen ? bidb : 0); Tensor gdK = local_tile(cute::domain_offset(make_coord(seqlen_info.offset, _0{}), mdK), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{})); // (M, K) - Tensor mdV = make_tensor(make_gmem_ptr(params.ptr_dV), params.shape_dK, params.stride_dV)(_, _, bidh, !is_varlen ? bidb : 0); + Tensor mdV = make_tensor(make_gmem_ptr(params.ptr_dV), params.shape_dV, params.stride_dV)(_, _, bidh, !is_varlen ? bidb : 0); Tensor gdV = local_tile(cute::domain_offset(make_coord(seqlen_info.offset, _0{}), mdV), select<1, 2>(TileShape_MNK{}), make_coord(n_block, _0{})); // (M, K) GmemTiledCopydKV gmem_tiled_copy_dKV; @@ -295,15 +300,18 @@ struct CollectiveEpilogueBwd { Tensor cdKV = cute::make_identity_tensor(select<1, 2>(TileShape_MNK{})); // (BLK_M,BLK_K) -> (blk_m,blk_k) // Repeat the partitioning with identity layouts Tensor tdKVcdKV = gmem_thr_copy_dKV.partition_D(cdKV); - Tensor tdKVpdKV = make_tensor(make_shape(size<2>(tdKVgdK))); + Tensor tdKVpdK = make_tensor(make_shape(size<2>(tdKVgdK))); + Tensor tdKVpdV = make_tensor(make_shape(size<2>(tdKVgdV))); + #pragma unroll + for (int k = 0; k < size(tdKVpdK); ++k) { tdKVpdK(k) = get<1>(tdKVcdKV(_0{}, _0{}, k)) < get<1>(params.shape_dK); } #pragma unroll - for (int k = 0; k < size(tdKVpdKV); ++k) { tdKVpdKV(k) = get<1>(tdKVcdKV(_0{}, _0{}, k)) < get<1>(params.shape_dK); } + for (int k = 0; k < size(tdKVpdV); ++k) { tdKVpdV(k) = get<1>(tdKVcdKV(_0{}, _0{}, k)) < get<1>(params.shape_dV); } // Clear_OOB_K must be false since we don't want to write zeros to gmem flash::copy( - gmem_tiled_copy_dKV, tdKVrdKV, tdKVgdK, tdKVcdKV, tdKVpdKV, seqlen_info.seqlen - n_block * kBlockN + gmem_tiled_copy_dKV, tdKVrdKV, tdKVgdK, tdKVcdKV, tdKVpdK, seqlen_info.seqlen - n_block * kBlockN ); flash::copy( - gmem_tiled_copy_dKV, tdKVrdKV, tdKVgdV, tdKVcdKV, tdKVpdKV, seqlen_info.seqlen - n_block * kBlockN + gmem_tiled_copy_dKV, tdKVrdKV, tdKVgdV, tdKVcdKV, tdKVpdV, seqlen_info.seqlen - n_block * kBlockN ); } @@ -359,6 +367,7 @@ struct CollectiveEpilogueBwdGQA { ShapedKV const shape_dKaccum; StridedKV const stride_dKaccum; ElementAccum* ptr_dVaccum; + ShapedKV const shape_dVaccum; StridedKV const stride_dVaccum; int num_heads_q; int* dk_semaphore; @@ -373,6 +382,7 @@ struct CollectiveEpilogueBwdGQA { ShapedKV const shape_dKaccum; StridedKV const stride_dKaccum; ElementAccum* ptr_dVaccum; + ShapedKV const shape_dVaccum; StridedKV const stride_dVaccum; cutlass::FastDivmod qhead_per_khead_divmod; int* dk_semaphore; @@ -387,7 +397,7 @@ struct CollectiveEpilogueBwdGQA { assert(args.dk_semaphore != nullptr); assert(args.dv_semaphore != nullptr); } - return {args.ptr_dKaccum, args.shape_dKaccum, args.stride_dKaccum, args.ptr_dVaccum, args.stride_dVaccum, + return {args.ptr_dKaccum, args.shape_dKaccum, args.stride_dKaccum, args.ptr_dVaccum, args.shape_dVaccum, args.stride_dVaccum, cutlass::FastDivmod(cute::ceil_div(args.num_heads_q, get<1>(args.shape_dKaccum))), args.dk_semaphore, args.dv_semaphore, args.cu_seqlens, args.seqused}; @@ -419,7 +429,8 @@ struct CollectiveEpilogueBwdGQA { flash::SeqlenInfo seqlen_info{bidb, size<0>(params.shape_dKaccum), params.cu_seqlens, params.seqused}; bool const is_varlen = Varlen && params.cu_seqlens; Tensor mdKaccum = make_tensor(make_gmem_ptr(params.ptr_dKaccum), params.shape_dKaccum, params.stride_dKaccum)(_, bidh_kv, !is_varlen ? bidb : 0); - Tensor mdVaccum = make_tensor(make_gmem_ptr(params.ptr_dVaccum), params.shape_dKaccum, params.stride_dVaccum)(_, bidh_kv, !is_varlen ? bidb : 0); + + Tensor mdVaccum = make_tensor(make_gmem_ptr(params.ptr_dVaccum), params.shape_dVaccum, params.stride_dVaccum)(_, bidh_kv, !is_varlen ? bidb : 0); Tensor gdKaccum = local_tile(domain_offset(make_coord(int64_t{seqlen_info.offset_padded} * int64_t{kHeadDim}), mdKaccum), Shape>{}, make_coord(n_block)); // (M * K) Tensor gdVaccum = local_tile(domain_offset(make_coord(int64_t{seqlen_info.offset_padded} * int64_t{kHeadDim}), mdVaccum), Shape>{}, make_coord(n_block)); // (M * K) diff --git a/csrc/flash_attn_v3/flash_api.cu b/csrc/flash_attn_v3/flash_api.cu index 3983c8c5c4d..65cbad2e71d 100644 --- a/csrc/flash_attn_v3/flash_api.cu +++ b/csrc/flash_attn_v3/flash_api.cu @@ -236,38 +236,38 @@ void run_mha_bwd(Flash_bwd_params ¶ms, cudaStream_t stream) { if (!params.is_bf16) { #ifndef FLASHATTENTION_DISABLE_FP16 #ifndef FLASHATTENTION_DISABLE_HDIM64 - if (params.d <= 64) { return run_mha_bwd_(params, stream); } + if (params.d_rounded <= 64) { return run_mha_bwd_(params, stream); } #endif #ifndef FLASHATTENTION_DISABLE_HDIM96 - if (params.d <= 96) { return run_mha_bwd_(params, stream); } + if (params.d_rounded <= 96) { return run_mha_bwd_(params, stream); } #endif #ifndef FLASHATTENTION_DISABLE_HDIM128 - if (params.d <= 128) { return run_mha_bwd_(params, stream); } + if (params.d_rounded <= 128) { return run_mha_bwd_(params, stream); } #endif #ifndef FLASHATTENTION_DISABLE_HDIM192 - if (params.d <= 192) { return run_mha_bwd_(params, stream); } + if (params.d_rounded <= 192) { return run_mha_bwd_(params, stream); } #endif #ifndef FLASHATTENTION_DISABLE_HDIM256 - if (params.d <= 256) { return run_mha_bwd_(params, stream); } + if (params.d_rounded <= 256) { return run_mha_bwd_(params, stream); } #endif #else PADDLE_CHECK(false, "This flash attention build does not support FP16."); #endif } else { #ifndef FLASHATTENTION_DISABLE_HDIM64 - if (params.d <= 64) { return run_mha_bwd_(params, stream); } + if (params.d_rounded <= 64) { return run_mha_bwd_(params, stream); } #endif #ifndef FLASHATTENTION_DISABLE_HDIM96 - if (params.d <= 96) { return run_mha_bwd_(params, stream); } + if (params.d_rounded <= 96) { return run_mha_bwd_(params, stream); } #endif #ifndef FLASHATTENTION_DISABLE_HDIM128 - if (params.d <= 128) { return run_mha_bwd_(params, stream); } + if (params.d_rounded <= 128) { return run_mha_bwd_(params, stream); } #endif #ifndef FLASHATTENTION_DISABLE_HDIM192 - if (params.d <= 192) { return run_mha_bwd_(params, stream); } + if (params.d_rounded <= 192) { return run_mha_bwd_(params, stream); } #endif #ifndef FLASHATTENTION_DISABLE_HDIM256 - if (params.d <= 256) { return run_mha_bwd_(params, stream); } + if (params.d_rounded <= 256) { return run_mha_bwd_(params, stream); } #endif } }); diff --git a/csrc/flash_attn_v3/flash_bwd_launch_template.h b/csrc/flash_attn_v3/flash_bwd_launch_template.h index a8564a78435..607fb2e6c4d 100644 --- a/csrc/flash_attn_v3/flash_bwd_launch_template.h +++ b/csrc/flash_attn_v3/flash_bwd_launch_template.h @@ -49,7 +49,7 @@ void run_flash_bwd(Flash_bwd_params ¶ms, cudaStream_t stream) { using PreprocessKernel = flash::FlashAttnBwdPreprocess; typename PreprocessKernel::Arguments preprocess_args { static_cast(params.o_ptr), - {seqlen_q, params.d, params.h, batch_q}, // shape_O + {seqlen_q, params.dv, params.h, batch_q}, // shape_O {params.o_row_stride, _1{}, params.o_head_stride, !is_varlen_q ? params.o_batch_stride : 0}, // stride_O static_cast(params.do_ptr), {params.do_row_stride, _1{}, params.do_head_stride, !is_varlen_q ? params.do_batch_stride : 0}, // stride_dO @@ -108,8 +108,10 @@ void run_flash_bwd(Flash_bwd_params ¶ms, cudaStream_t stream) { {seqlen_k, params.d, params.h_k, batch_k}, // shape_K {params.k_row_stride, _1{}, params.k_head_stride, !is_varlen_k ? params.k_batch_stride : 0}, // stride_K static_cast(params.v_ptr), + {seqlen_k, params.dv, params.h_k, batch_k}, // shape_V {params.v_row_stride, _1{}, params.v_head_stride, !is_varlen_k ? params.v_batch_stride : 0}, // stride_V static_cast(params.do_ptr), + {seqlen_q, params.dv, params.h, batch_q}, // shape_dO {params.do_row_stride, _1{}, params.do_head_stride, !is_varlen_q ? params.do_batch_stride : 0}, // stride_dO static_cast(params.dq_accum_ptr), {int64_t{seqlen_q_rounded} * int64_t{params.d_rounded}, params.h, batch_q}, // shape_dQaccum @@ -145,11 +147,18 @@ void run_flash_bwd(Flash_bwd_params ¶ms, cudaStream_t stream) { } }(), static_cast(!GQA ? params.dv_ptr : params.dv_accum_ptr), + [&] { + if constexpr (!GQA) { + return typename CollectiveEpilogue::ShapedKV {seqlen_k, params.dv, params.h, batch_k}; // shape_dV + } else { + return typename CollectiveEpilogue::ShapedKV {seqlen_k_rounded * params.dv_rounded, params.h_k, batch_k}; // shape_dVaccum + } + }(), [&] { if constexpr (!GQA) { return typename CollectiveEpilogue::StridedKV {params.dv_row_stride, _1{}, params.dv_head_stride, !is_varlen_k ? params.dv_batch_stride : 0}; // stride_dV } else { - return typename CollectiveEpilogue::StridedKV {_1{}, params.d_rounded * seqlen_k_rounded, !is_varlen_k ? params.h_k * params.d_rounded * params.seqlen_k_rounded : 0}; // stride_dVaccum + return typename CollectiveEpilogue::StridedKV {_1{}, params.dv_rounded * seqlen_k_rounded, !is_varlen_k ? params.h_k * params.dv_rounded * params.seqlen_k_rounded : 0}; // stride_dVaccum } }(), params.h, @@ -256,10 +265,10 @@ void run_flash_bwd(Flash_bwd_params ¶ms, cudaStream_t stream) { typename PostprocessKerneldKV::Params postprocess_dK_params = PostprocessKerneldKV::to_underlying_arguments(postprocess_dK_args); typename PostprocessKerneldKV::Arguments postprocess_dV_args { static_cast(params.dv_accum_ptr), - {seqlen_k_rounded * params.d_rounded, params.h_k, batch_k}, // shape_dVaccum - {_1{}, seqlen_k_rounded * params.d_rounded, !is_varlen_k ? params.d_rounded * params.seqlen_k_rounded * params.h_k : 0}, // stride_dVaccum + {seqlen_k_rounded * params.dv_rounded, params.h_k, batch_k}, // shape_dVaccum + {_1{}, seqlen_k_rounded * params.dv_rounded, !is_varlen_k ? params.dv_rounded * params.seqlen_k_rounded * params.h_k : 0}, // stride_dVaccum static_cast(params.dv_ptr), - {seqlen_k, params.d, params.h_k, batch_k}, // shape_dV + {seqlen_k, params.dv, params.h_k, batch_k}, // shape_dV {params.dv_row_stride, _1{}, params.dv_head_stride, params.dv_batch_stride}, // stride_dV 1.f, params.cu_seqlens_k, diff --git a/csrc/flash_attn_v3/mainloop_bwd_sm80.hpp b/csrc/flash_attn_v3/mainloop_bwd_sm80.hpp index 0a79670f475..c5043e7037d 100644 --- a/csrc/flash_attn_v3/mainloop_bwd_sm80.hpp +++ b/csrc/flash_attn_v3/mainloop_bwd_sm80.hpp @@ -284,8 +284,10 @@ struct CollectiveMainloopBwdSm80 { ShapeQKV const shape_K; StrideQKV const stride_K; Element const* const ptr_V; + ShapeQKV const shape_V; StrideQKV const stride_V; Element const* const ptr_dO; + ShapeQKV const shape_dO; StrideQKV const stride_dO; ElementAccum* const ptr_dQaccum; ShapedQaccum const shape_dQaccum; @@ -315,8 +317,10 @@ struct CollectiveMainloopBwdSm80 { ShapeQKV const shape_K; StrideQKV const stride_K; Element const* const ptr_V; + ShapeQKV const shape_V; StrideQKV const stride_V; Element const* const ptr_dO; + ShapeQKV const shape_dO; StrideQKV const stride_dO; ElementAccum* const ptr_dQaccum; ShapedQaccum const shape_dQaccum; @@ -352,8 +356,8 @@ struct CollectiveMainloopBwdSm80 { // (the original softmax_scale) at the end. return {args.ptr_Q, args.shape_Q, args.stride_Q, args.ptr_K, args.shape_K, args.stride_K, - args.ptr_V, args.stride_V, - args.ptr_dO, args.stride_dO, + args.ptr_V, args.shape_V, args.stride_V, + args.ptr_dO, args.shape_dO, args.stride_dO, args.ptr_dQaccum, args.shape_dQaccum, args.stride_dQaccum, cutlass::FastDivmod(cute::ceil_div(get<2>(args.shape_Q), get<2>(args.shape_K))), args.ptr_LSE_log2, args.shape_LSE, args.stride_LSE_log2, args.ptr_dPsum, args.stride_dPsum, @@ -413,9 +417,9 @@ struct CollectiveMainloopBwdSm80 { bool const is_varlen_k = Varlen && params.cu_seqlens_k; int bidh_kv = params.qhead_per_khead_divmod.divide(bidh); Tensor mQ = make_tensor(make_gmem_ptr(params.ptr_Q), params.shape_Q, params.stride_Q)(_, _, bidh, !is_varlen_q ? bidb : 0); - Tensor mdO = make_tensor(make_gmem_ptr(params.ptr_dO), params.shape_Q, params.stride_dO)(_, _, bidh, !is_varlen_q ? bidb : 0); + Tensor mdO = make_tensor(make_gmem_ptr(params.ptr_dO), params.shape_dO, params.stride_dO)(_, _, bidh, !is_varlen_q ? bidb : 0); Tensor mK = make_tensor(make_gmem_ptr(params.ptr_K), params.shape_K, params.stride_K)(_, _, bidh_kv, !is_varlen_k ? bidb : 0); - Tensor mV = make_tensor(make_gmem_ptr(params.ptr_V), params.shape_K, params.stride_V)(_, _, bidh_kv, !is_varlen_k ? bidb : 0); + Tensor mV = make_tensor(make_gmem_ptr(params.ptr_V), params.shape_V, params.stride_V)(_, _, bidh_kv, !is_varlen_k ? bidb : 0); Tensor mLSE = make_tensor(make_gmem_ptr(params.ptr_LSE_log2), params.shape_LSE, params.stride_LSE_log2)(_, bidh, !is_varlen_q ? bidb : 0); Tensor mdPsum = make_tensor(make_gmem_ptr(params.ptr_dPsum), params.shape_LSE, params.stride_dPsum)(_, bidh, !is_varlen_q ? bidb : 0); Tensor mdQaccum = make_tensor(make_gmem_ptr(reinterpret_cast(params.ptr_dQaccum)), @@ -527,6 +531,9 @@ struct CollectiveMainloopBwdSm80 { for (int k = 0; k < size(tQpQ); ++k) { tQpQ(k) = get<1>(tQcQ(_0{}, _0{}, k)) < get<1>(params.shape_Q); } Tensor cLSE = cute::make_identity_tensor(select<0>(TileShape_MNK{})); Tensor tLSEcLSE = gmem_thr_copy_lse.partition_S(cLSE); + Tensor tdOpdO = make_tensor(make_shape(size<2>(tdOsdO))); + #pragma unroll + for (int k = 0; k < size(tdOpdO); ++k) { tdOpdO(k) = get<1>(tQcQ(_0{}, _0{}, k)) < get<1>(params.shape_dO); } int const seqlen_q = seqlen_info.seqlen_q; int const seqlen_k = seqlen_info.seqlen_k; @@ -545,9 +552,12 @@ struct CollectiveMainloopBwdSm80 { Tensor cKV = cute::make_identity_tensor(select<1, 2>(TileShape_MNK{})); Tensor tKVcKV = gmem_thr_copy_QKV.partition_S(cKV); Tensor t0KVcKV = gmem_thr0_copy_QKV.partition_S(cKV); - Tensor tKVpKV = make_tensor(make_shape(size<2>(tKsK))); + Tensor tKpK = make_tensor(make_shape(size<2>(tKsK))); + Tensor tVpV = make_tensor(make_shape(size<2>(tVsV))); + #pragma unroll + for (int k = 0; k < size(tKpK); ++k) { tKpK(k) = get<1>(tKVcKV(_0{}, _0{}, k)) < get<1>(params.shape_K); } #pragma unroll - for (int k = 0; k < size(tKVpKV); ++k) { tKVpKV(k) = get<1>(tKVcKV(_0{}, _0{}, k)) < get<1>(params.shape_K); } + for (int k = 0; k < size(tVpV); ++k) { tVpV(k) = get<1>(tKVcKV(_0{}, _0{}, k)) < get<1>(params.shape_V); } // Do we need bound check to make sure the row doesn't go above kBlockN static constexpr bool EvenN = kBlockN % CUTE_STATIC_V(shape<0>(GmemLayoutAtom{})) == 0; // static_assert(EvenN); // It simplifies the loading of K and V @@ -567,7 +577,7 @@ struct CollectiveMainloopBwdSm80 { bool const predicate_n = get<0>(t0KVcKV(_0{}, m, _0{})) < seqlenk_row_limit; #pragma unroll for (int k = 0; k < size<2>(tVsV); ++k) { - cute::copy(gmem_tiled_copy_QKV.with(tKVpKV(k) && predicate_n), tVgV(_, m, k), tVsV(_, m, k)); + cute::copy(gmem_tiled_copy_QKV.with(tVpV(k) && predicate_n), tVgV(_, m, k), tVsV(_, m, k)); } } } @@ -580,7 +590,7 @@ struct CollectiveMainloopBwdSm80 { bool const predicate_n = get<0>(t0KVcKV(_0{}, m, _0{})) < seqlenk_row_limit; #pragma unroll for (int k = 0; k < size<2>(tKsK); ++k) { - cute::copy(gmem_tiled_copy_QKV.with(tKVpKV(k) && predicate_n), tKgK(_, m, k), tKsK(_, m, k)); + cute::copy(gmem_tiled_copy_QKV.with(tKpK(k) && predicate_n), tKgK(_, m, k), tKsK(_, m, k)); } } } @@ -653,7 +663,7 @@ struct CollectiveMainloopBwdSm80 { bool const predicate_m = get<0>(t0QcQ(_0{}, m, _0{})) < seqlenq_row_limit; #pragma unroll for (int k = 0; k < size<2>(tdOsdO); ++k) { - cute::copy(gmem_tiled_copy_QKV.with(tQpQ(k) && predicate_m), tdOgdO_cur(_, m, k), tdOsdO_cur(_, m, k)); + cute::copy(gmem_tiled_copy_QKV.with(tdOpdO(k) && predicate_m), tdOgdO_cur(_, m, k), tdOsdO_cur(_, m, k)); } } } diff --git a/csrc/flash_attn_v3/mainloop_bwd_sm90_tma_gmma_ws.hpp b/csrc/flash_attn_v3/mainloop_bwd_sm90_tma_gmma_ws.hpp index 7e1904a3ba2..5c6aea98abc 100644 --- a/csrc/flash_attn_v3/mainloop_bwd_sm90_tma_gmma_ws.hpp +++ b/csrc/flash_attn_v3/mainloop_bwd_sm90_tma_gmma_ws.hpp @@ -298,8 +298,10 @@ struct CollectiveMainloopBwdSm90 { ShapeQKV const shape_K; StrideQKV const stride_K; Element const* const ptr_V; + ShapeQKV const shape_V; StrideQKV const stride_V; Element const* const ptr_dO; + ShapeQKV const shape_dO; StrideQKV const stride_dO; ElementAccum* const ptr_dQaccum; ShapedQaccum const shape_dQaccum; @@ -324,6 +326,8 @@ struct CollectiveMainloopBwdSm90 { struct Params { ShapeQKV const shape_Q; ShapeQKV const shape_K; + ShapeQKV const shape_V; + ShapeQKV const shape_dO; ElementAccum* const ptr_dQaccum; ShapedQaccum const shape_dQaccum; StridedQaccum stride_dQaccum; @@ -356,7 +360,7 @@ struct CollectiveMainloopBwdSm90 { SmemLayoutQ{}(_, _, _0{}), TileShape_MNK{}, ClusterShape{}); // mcast along N mode for this M load, if any - Tensor mdO = make_tensor(make_gmem_ptr(args.ptr_dO), args.shape_Q, args.stride_dO); + Tensor mdO = make_tensor(make_gmem_ptr(args.ptr_dO), args.shape_dO, args.stride_dO); TMA_QdO tma_load_dO = make_tma_copy_A_sm90( GmemTiledCopyQdO{}, mdO, @@ -370,7 +374,7 @@ struct CollectiveMainloopBwdSm90 { SmemLayoutK{}, TileShape_MNK{}, ClusterShape{}); // no mcast for KV - Tensor mV = make_tensor(make_gmem_ptr(args.ptr_V), args.shape_K, args.stride_V); + Tensor mV = make_tensor(make_gmem_ptr(args.ptr_V), args.shape_V, args.stride_V); TMA_V tma_load_V = make_tma_copy_B_sm90( GmemTiledCopyKV{}, mV, @@ -388,6 +392,7 @@ struct CollectiveMainloopBwdSm90 { // Instead we multiply by (1 - tanh^2) and multiply dK and dV by params.softmax_scale // (the original softmax_scale) at the end. return {args.shape_Q, args.shape_K, + args.shape_V, args.shape_dO, args.ptr_dQaccum, args.shape_dQaccum, args.stride_dQaccum, cutlass::FastDivmod(cute::ceil_div(get<2>(args.shape_Q), get<2>(args.shape_K))), tma_load_Q, tma_load_dO, tma_load_K, tma_load_V, @@ -453,9 +458,9 @@ struct CollectiveMainloopBwdSm90 { bool const is_varlen_q = Varlen && params.cu_seqlens_q; bool const is_varlen_k = Varlen && params.cu_seqlens_k; Tensor mQ = params.tma_load_Q.get_tma_tensor(params.shape_Q)(_, _, bidh, !is_varlen_q ? bidb : 0); - Tensor mdO = params.tma_load_dO.get_tma_tensor(params.shape_Q)(_, _, bidh, !is_varlen_q ? bidb : 0); + Tensor mdO = params.tma_load_dO.get_tma_tensor(params.shape_dO)(_, _, bidh, !is_varlen_q ? bidb : 0); Tensor mK = params.tma_load_K.get_tma_tensor(params.shape_K)(_, _, bidh_kv, !is_varlen_k ? bidb : 0); - Tensor mV = params.tma_load_V.get_tma_tensor(params.shape_K)(_, _, bidh_kv, !is_varlen_k ? bidb : 0); + Tensor mV = params.tma_load_V.get_tma_tensor(params.shape_V)(_, _, bidh_kv, !is_varlen_k ? bidb : 0); Tensor mLSE = make_tensor(make_gmem_ptr(params.ptr_LSE_log2), params.shape_LSE, params.stride_LSE_log2)(_, bidh, !is_varlen_q ? bidb : 0); Tensor mdPsum = make_tensor(make_gmem_ptr(params.ptr_dPsum), params.shape_LSE, params.stride_dPsum)(_, bidh, !is_varlen_q ? bidb : 0); diff --git a/flash_attn/flash_attn_interface.py b/flash_attn/flash_attn_interface.py index 68a7b9a0581..9e1e38dbe79 100644 --- a/flash_attn/flash_attn_interface.py +++ b/flash_attn/flash_attn_interface.py @@ -238,9 +238,9 @@ def backward(ctx, dout, *args): dq, dk, dv, ctx.dropout_p, ctx.softmax_scale, ctx.causal, rng_state=rng_state ) - dq = dq[..., :dout.shape[-1]] # We could have padded the head dimension - dk = dk[..., :dout.shape[-1]] - dv = dv[..., :dout.shape[-1]] + dq = dq[..., :q.shape[-1]] # We could have padded the head dimension + dk = dk[..., :k.shape[-1]] + dv = dv[..., :v.shape[-1]] return dq, dk, dv, None, None, None, None, None, None, None, None @@ -273,9 +273,9 @@ def backward(ctx, dout, *args): ctx.max_seqlen_q, ctx.max_seqlen_k, ctx.dropout_p, ctx.softmax_scale, ctx.causal, rng_state=rng_state ) - dq = dq[..., :dout.shape[-1]] # We could have padded the head dimension - dk = dk[..., :dout.shape[-1]] - dv = dv[..., :dout.shape[-1]] + dq = dq[..., :q.shape[-1]] # We could have padded the head dimension + dk = dk[..., :k.shape[-1]] + dv = dv[..., :v.shape[-1]] return dq, dk, dv, None, None, None, None, None, None, None, None From 3a830f65c475d8c57a8d3eeb59db32134705c117 Mon Sep 17 00:00:00 2001 From: umiswing Date: Thu, 19 Jun 2025 17:31:54 +0800 Subject: [PATCH 4/4] fix new int32_t overflow for dk_accum and dv_accum --- csrc/flash_attn_v3/flash_bwd_launch_template.h | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/csrc/flash_attn_v3/flash_bwd_launch_template.h b/csrc/flash_attn_v3/flash_bwd_launch_template.h index 607fb2e6c4d..2256e89a98a 100644 --- a/csrc/flash_attn_v3/flash_bwd_launch_template.h +++ b/csrc/flash_attn_v3/flash_bwd_launch_template.h @@ -143,7 +143,7 @@ void run_flash_bwd(Flash_bwd_params ¶ms, cudaStream_t stream) { if constexpr (!GQA) { return typename CollectiveEpilogue::StridedKV {params.dk_row_stride, _1{}, params.dk_head_stride, !is_varlen_k ? params.dk_batch_stride : 0}; // stride_dK } else { - return typename CollectiveEpilogue::StridedKV {_1{}, params.d_rounded * seqlen_k_rounded, !is_varlen_k ? params.h_k * params.d_rounded * params.seqlen_k_rounded : 0}; // stride_dKaccum + return typename CollectiveEpilogue::StridedKV {_1{}, int64_t{params.d_rounded} * int64_t{seqlen_k_rounded}, !is_varlen_k ? int64_t{params.h_k} * int64_t{params.d_rounded} * int64_t{params.seqlen_k_rounded} : 0}; // stride_dKaccum } }(), static_cast(!GQA ? params.dv_ptr : params.dv_accum_ptr), @@ -151,14 +151,14 @@ void run_flash_bwd(Flash_bwd_params ¶ms, cudaStream_t stream) { if constexpr (!GQA) { return typename CollectiveEpilogue::ShapedKV {seqlen_k, params.dv, params.h, batch_k}; // shape_dV } else { - return typename CollectiveEpilogue::ShapedKV {seqlen_k_rounded * params.dv_rounded, params.h_k, batch_k}; // shape_dVaccum + return typename CollectiveEpilogue::ShapedKV {int64_t{seqlen_k_rounded} * int64_t{params.dv_rounded}, params.h_k, batch_k}; // shape_dVaccum } }(), [&] { if constexpr (!GQA) { return typename CollectiveEpilogue::StridedKV {params.dv_row_stride, _1{}, params.dv_head_stride, !is_varlen_k ? params.dv_batch_stride : 0}; // stride_dV } else { - return typename CollectiveEpilogue::StridedKV {_1{}, params.dv_rounded * seqlen_k_rounded, !is_varlen_k ? params.h_k * params.dv_rounded * params.seqlen_k_rounded : 0}; // stride_dVaccum + return typename CollectiveEpilogue::StridedKV {_1{}, int64_t{params.dv_rounded} * int64_t{seqlen_k_rounded}, !is_varlen_k ? int64_t{params.h_k} * int64_t{params.dv_rounded} * int64_t{params.seqlen_k_rounded} : 0}; // stride_dVaccum } }(), params.h, @@ -254,7 +254,7 @@ void run_flash_bwd(Flash_bwd_params ¶ms, cudaStream_t stream) { typename PostprocessKerneldKV::Arguments postprocess_dK_args { static_cast(params.dk_accum_ptr), {seqlen_k_rounded * params.d_rounded, params.h_k, batch_k}, // shape_dKaccum - {_1{}, seqlen_k_rounded * params.d_rounded, !is_varlen_k ? params.d_rounded * params.seqlen_k_rounded * params.h_k : 0}, // stride_dKaccum + {_1{}, int64_t{seqlen_k_rounded} * int64_t{params.d_rounded}, !is_varlen_k ? int64_t{params.d_rounded} * int64_t{params.seqlen_k_rounded} * int64_t{params.h_k} : 0}, // stride_dKaccum static_cast(params.dk_ptr), {seqlen_k, params.d, params.h_k, batch_k}, // shape_dK {params.dk_row_stride, _1{}, params.dk_head_stride, params.dk_batch_stride}, // stride_dK @@ -265,8 +265,8 @@ void run_flash_bwd(Flash_bwd_params ¶ms, cudaStream_t stream) { typename PostprocessKerneldKV::Params postprocess_dK_params = PostprocessKerneldKV::to_underlying_arguments(postprocess_dK_args); typename PostprocessKerneldKV::Arguments postprocess_dV_args { static_cast(params.dv_accum_ptr), - {seqlen_k_rounded * params.dv_rounded, params.h_k, batch_k}, // shape_dVaccum - {_1{}, seqlen_k_rounded * params.dv_rounded, !is_varlen_k ? params.dv_rounded * params.seqlen_k_rounded * params.h_k : 0}, // stride_dVaccum + {int64_t{seqlen_k_rounded} * int64_t{params.dv_rounded}, params.h_k, batch_k}, // shape_dVaccum + {_1{}, int64_t{seqlen_k_rounded} * int64_t{params.dv_rounded}, !is_varlen_k ? int64_t{params.dv_rounded} * int64_t{params.seqlen_k_rounded} * int64_t{params.h_k} : 0}, // stride_dVaccum static_cast(params.dv_ptr), {seqlen_k, params.dv, params.h_k, batch_k}, // shape_dV {params.dv_row_stride, _1{}, params.dv_head_stride, params.dv_batch_stride}, // stride_dV