diff --git a/flashmask/flash_mask/cute/flash_bwd_sm100.py b/flashmask/flash_mask/cute/flash_bwd_sm100.py index c9d29e9ee30..6c104938d76 100644 --- a/flashmask/flash_mask/cute/flash_bwd_sm100.py +++ b/flashmask/flash_mask/cute/flash_bwd_sm100.py @@ -63,6 +63,7 @@ def __init__( deterministic: bool = False, cluster_size: int = 1, use_2cta_instrs: bool = False, + is_split_d: bool = False, ): # padding head_dim to a multiple of 64 to match head_dim_rounded in interface hdim_multiple_of = 64 @@ -76,28 +77,51 @@ def __init__( self.tile_m = tile_m self.tile_n = tile_n self.debug_print = False + self.is_split_d = is_split_d - assert self.tile_hdim <= 128 or (self.tile_hdim == 192 and self.tile_hdimv == 128) - assert self.tile_hdimv <= 128 + if is_split_d: + assert self.tile_hdim > 192 and self.tile_hdim == self.tile_hdimv, ( + "Split-D BWD requires head_dim > 192 and head_dim == head_dim_v" + ) + # Split-D sub-GEMM dimension (each 256-dim op split into two 128-dim ops) + self.tile_hdim_split = self.tile_hdim // 2 # 128 + else: + assert self.tile_hdim <= 128 or (self.tile_hdim == 192 and self.tile_hdimv == 128) + assert self.tile_hdimv <= 128 self.use_2cta_instrs = bool(use_2cta_instrs and cluster_size == 2) self.cta_group_size = 2 if self.use_2cta_instrs else 1 assert self.tile_hdim != 192 or self.use_2cta_instrs, "Must use 2CTA for hdim 192" + if is_split_d: + assert self.use_2cta_instrs, "Split-D BWD requires 2CTA" # CTA tiler self.cta_tiler = (tile_n, tile_m, self.tile_hdim) - # S = K @ Q.T - self.mma_tiler_kq = (self.cta_group_size * tile_n, tile_m, self.tile_hdim) - # dP = V @ dO.T - self.mma_tiler_vdo = (self.cta_group_size * tile_n, tile_m, self.tile_hdimv) - # dV = P.T @ dO - self.mma_tiler_pdo = (self.cta_group_size * tile_n, self.tile_hdimv, tile_m) - # dK = dS.T @ Q - self.mma_tiler_dsq = (self.cta_group_size * tile_n, self.tile_hdim, tile_m) - # dQ = dS @ K - # 2-CTA: reduction dim is cluster-wide (tile_n * cta_group_size). 
- self.mma_tiler_dsk = (tile_m, self.tile_hdim, tile_n * self.cta_group_size) + if self.is_split_d: + # Split-D: each sub-GEMM operates on tile_hdim_split=128 + # S_lo = K_lo @ Q_lo.T, S_hi = K_hi @ Q_hi.T (accumulated) + self.mma_tiler_kq = (self.cta_group_size * tile_n, tile_m, self.tile_hdim_split) + # dP_lo = V_lo @ dO_lo.T, dP_hi = V_hi @ dO_hi.T (accumulated) + self.mma_tiler_vdo = (self.cta_group_size * tile_n, tile_m, self.tile_hdim_split) + # dV_lo = P.T @ dO_lo, dV_hi = P.T @ dO_hi + self.mma_tiler_pdo = (self.cta_group_size * tile_n, self.tile_hdim_split, tile_m) + # dK_lo = dS.T @ Q_lo, dK_hi = dS.T @ Q_hi + self.mma_tiler_dsq = (self.cta_group_size * tile_n, self.tile_hdim_split, tile_m) + # dQ_lo = dS @ K_lo, dQ_hi = dS @ K_hi + self.mma_tiler_dsk = (tile_m, self.tile_hdim_split, tile_n * self.cta_group_size) + else: + # S = K @ Q.T + self.mma_tiler_kq = (self.cta_group_size * tile_n, tile_m, self.tile_hdim) + # dP = V @ dO.T + self.mma_tiler_vdo = (self.cta_group_size * tile_n, tile_m, self.tile_hdimv) + # dV = P.T @ dO + self.mma_tiler_pdo = (self.cta_group_size * tile_n, self.tile_hdimv, tile_m) + # dK = dS.T @ Q + self.mma_tiler_dsq = (self.cta_group_size * tile_n, self.tile_hdim, tile_m) + # dQ = dS @ K + # 2-CTA: reduction dim is cluster-wide (tile_n * cta_group_size). 
+ self.mma_tiler_dsk = (tile_m, self.tile_hdim, tile_n * self.cta_group_size) self.acc_dtype = Float32 self.startend_row_indices_dtype = Int32 @@ -163,7 +187,26 @@ def __init__( # self.tmem_total = self.tmem_S_offset + self.tile_n # assert self.tmem_total <= self.tmem_alloc_cols - if self.use_2cta_instrs and self.tile_hdim == 192 and self.tile_hdimv == 128: + if self.is_split_d: + # Split-D (d=dv=256) TMEM layout: + # |--- S/P ---|--- dV_low ---|--- dV_high ---|--- dP/dS ---| + # [0, 128) [128, 256) [256, 384) [384, 512) + # TMEM[0,128) is time-division multiplexed: S/P, dK_partial, dQ_partial + assert self.tile_m == 128 + assert self.tile_n == 128 + self.tmem_S_offset = 0 + self.tmem_P_offset = 0 # overlap with S + self.tmem_dV_lo_offset = 128 + self.tmem_dV_hi_offset = 256 + self.tmem_dV_offset = self.tmem_dV_lo_offset # for compatibility + self.tmem_dP_offset = 384 + self.tmem_dS_offset = 384 # overlap with dP + # dK uses GMEM atomic reduce (no persistent TMEM accumulation) + self.tmem_dK_offset = 0 # time-multiplexed with S/P + # dQ partial accumulation at S/P slot + self.tmem_dQ_offset = 0 # time-multiplexed with S/P + self.dK_as_reduce = True # dK accumulated via GMEM atomic reduce + elif self.use_2cta_instrs and self.tile_hdim == 192 and self.tile_hdimv == 128: assert self.tile_m == 128 assert self.tile_n == 128 self.tmem_dV_offset = 0 @@ -186,6 +229,9 @@ def __init__( self.tmem_dK_offset = self.tmem_dP_offset + self.tile_m self.tmem_dS_offset = self.tmem_dP_offset # overlap with dP + if not self.is_split_d: + self.dK_as_reduce = False + if (not is_causal and not is_local) or deterministic: self.num_regs_reduce = 136 if self.use_2cta_instrs else 152 self.num_regs_compute = 136 @@ -204,6 +250,13 @@ def __init__( self.num_regs_load = 128 - 24 self.num_regs_mma = self.num_regs_load + if const_expr(self.is_split_d): + # Split-D (d=256): similar to d=192 register pressure + self.num_regs_reduce = 136 + self.num_regs_compute = 136 + self.num_regs_load = 104 + 
self.num_regs_mma = 104 + assert ( self.num_regs_reduce + self.num_regs_compute * 2 @@ -213,6 +266,13 @@ def __init__( self.buffer_align_bytes = 1024 + # Split-D: byte offset from low to high half in SMEM buffers + if self.is_split_d: + # For K-major SMEM (e.g., sQ: tile_m × tile_hdim), high half starts at + # tile_hdim_split * element_size bytes from the start of each row + # Stored here; used in the MMA loop to compute sub-GEMM offsets + self.split_d_smem_k_offset_elems = self.tile_hdim_split # 128 elements + def _setup_attributes(self): self.Q_stage = 1 if self.use_2cta_instrs else 2 self.dO_stage = 1 @@ -221,7 +281,12 @@ def _setup_attributes(self): self.sdKVaccum_stage = 2 # number of tma reduce adds per dQacc mma # todo: try 32/1 or 48/2 for 2cta d=192 dv=128 - if self.use_2cta_instrs and self.tile_hdim == 192: + if self.is_split_d: + # Split-D: dQ reduce operates on tile_hdim_split=128 (each half independently) + self.dQ_reduce_ncol_t2r = 32 + self.dQ_reduce_ncol = 16 if self.deterministic else 8 + self.sdQaccum_stage = 2 if self.deterministic else 4 + elif self.use_2cta_instrs and self.tile_hdim == 192: self.dQ_reduce_ncol_t2r = 32 self.dQ_reduce_ncol = 24 if not self.is_causal else 32 self.sdQaccum_stage = 2 if not self.is_causal else 1 @@ -234,9 +299,11 @@ def _setup_attributes(self): self.dQ_reduce_ncol = 32 self.sdQaccum_stage = 64 // self.dQ_reduce_ncol self.dQ_reduce_ncol_t2r = self.dQ_reduce_ncol - assert (self.tile_hdim // self.cta_group_size) % self.dQ_reduce_ncol == 0 - self.dQaccum_reduce_stage = self.tile_hdim // self.dQ_reduce_ncol - self.dQaccum_reduce_stage_t2r = self.tile_hdim // self.dQ_reduce_ncol_t2r + # For Split-D, reduce operations work on half dimensions + tile_hdim_for_reduce = self.tile_hdim_split if self.is_split_d else self.tile_hdim + assert (tile_hdim_for_reduce // self.cta_group_size) % self.dQ_reduce_ncol == 0 + self.dQaccum_reduce_stage = tile_hdim_for_reduce // self.dQ_reduce_ncol + self.dQaccum_reduce_stage_t2r = 
tile_hdim_for_reduce // self.dQ_reduce_ncol_t2r self.cluster_reduce_dQ = False and cute.size(self.cluster_shape_mn) > 1 # number of tma reduce adds for dKacc and dVacc epilogue self.dK_reduce_ncol = 32 @@ -298,80 +365,96 @@ def _get_tiled_mma(self): return tiled_mma_S, tiled_mma_dP, tiled_mma_dK, tiled_mma_dV, tiled_mma_dQ def _setup_smem_layout(self): + # For Split-D: SMEM must hold full 256-dim data, but MMA operates on 128-dim halves. + # We define "full" tilers for SMEM layout computation and use the split tilers for MMA. + if self.is_split_d: + # Full tilers for SMEM allocation (tile_hdim=256 in K-dimension) + mma_tiler_kq_full = (self.cta_group_size * self.tile_n, self.tile_m, self.tile_hdim) + mma_tiler_vdo_full = (self.cta_group_size * self.tile_n, self.tile_m, self.tile_hdimv) + mma_tiler_pdo_full = (self.cta_group_size * self.tile_n, self.tile_hdimv, self.tile_m) + mma_tiler_dsq_full = (self.cta_group_size * self.tile_n, self.tile_hdim, self.tile_m) + mma_tiler_dsk_full = (self.tile_m, self.tile_hdim, self.tile_n * self.cta_group_size) + else: + mma_tiler_kq_full = self.mma_tiler_kq + mma_tiler_vdo_full = self.mma_tiler_vdo + mma_tiler_pdo_full = self.mma_tiler_pdo + mma_tiler_dsq_full = self.mma_tiler_dsq + mma_tiler_dsk_full = self.mma_tiler_dsk + # S.T = K @ Q.T sK_layout = sm100_utils_basic.make_smem_layout_a( self.tiled_mma_S, - self.mma_tiler_kq, + mma_tiler_kq_full, self.k_dtype, 1, ) self.sK_layout = cute.slice_(sK_layout, (None, None, None, 0)) self.sQ_layout = sm100_utils_basic.make_smem_layout_b( self.tiled_mma_S, - self.mma_tiler_kq, + mma_tiler_kq_full, self.q_dtype, self.Q_stage, ) # dP.T = V @ dO.T sV_layout = sm100_utils_basic.make_smem_layout_a( self.tiled_mma_dP, - self.mma_tiler_vdo, + mma_tiler_vdo_full, self.v_dtype, 1, ) self.sV_layout = cute.slice_(sV_layout, (None, None, None, 0)) self.sdOt_layout = sm100_utils_basic.make_smem_layout_b( self.tiled_mma_dP, - self.mma_tiler_vdo, + mma_tiler_vdo_full, self.do_dtype, self.dO_stage, ) # 
dV += P.T @ dO tP_layout = sm100_utils_basic.make_smem_layout_a( self.tiled_mma_dV, - self.mma_tiler_pdo, + mma_tiler_pdo_full, self.do_dtype, 1, ) self.tP_layout = cute.slice_(tP_layout, (None, None, None, 0)) self.sdO_layout = sm100_utils_basic.make_smem_layout_b( self.tiled_mma_dV, - self.mma_tiler_pdo, + mma_tiler_pdo_full, self.do_dtype, self.dO_stage, ) # dK += dS.T @ Q sdSt_layout = sm100_utils_basic.make_smem_layout_a( self.tiled_mma_dK, - self.mma_tiler_dsq, + mma_tiler_dsq_full, self.ds_dtype, 1, ) self.sdSt_layout = cute.slice_(sdSt_layout, (None, None, None, 0)) tdS_layout = sm100_utils_basic.make_smem_layout_a( self.tiled_mma_dK, - self.mma_tiler_dsq, + mma_tiler_dsq_full, self.ds_dtype, 1, ) self.tdS_layout = cute.slice_(tdS_layout, (None, None, None, 0)) self.sQt_layout = sm100_utils_basic.make_smem_layout_b( self.tiled_mma_dK, - self.mma_tiler_dsq, + mma_tiler_dsq_full, self.q_dtype, self.Q_stage, ) # dQ = dS @ K sdS_layout = sm100_utils_basic.make_smem_layout_a( self.tiled_mma_dQ, - self.mma_tiler_dsk, + mma_tiler_dsk_full, self.ds_dtype, 1, ) self.sdS_layout = cute.slice_(sdS_layout, (None, None, None, 0)) sKt_layout = sm100_utils_basic.make_smem_layout_b( self.tiled_mma_dQ, - self.mma_tiler_dsk, + mma_tiler_dsk_full, self.k_dtype, 1, ) diff --git a/flashmask/flash_mask/cute/flash_fwd_sm100.py b/flashmask/flash_mask/cute/flash_fwd_sm100.py index da24d0d1310..323606a85cc 100644 --- a/flashmask/flash_mask/cute/flash_fwd_sm100.py +++ b/flashmask/flash_mask/cute/flash_fwd_sm100.py @@ -102,6 +102,7 @@ def __init__( has_aux_tensors: cutlass.Constexpr = False, paged_kv_non_tma: bool = False, is_varlen_q: bool = False, + is_split_d: bool = False, ): self.use_tma_KV = not paged_kv_non_tma # self.dtype = dtype @@ -116,7 +117,9 @@ def __init__( self.check_hdim_v_oob = head_dim_v != self.head_dim_v_padded self.m_block_size = m_block_size self.n_block_size = n_block_size - self.q_stage = 2 + self.is_split_d = is_split_d + # Split-D: q_stage must be 1 to fit 
TMEM (S + O_full = 128 + 256 = 384, with gap total = 512) + self.q_stage = 1 if is_split_d else 2 assert self.q_stage in [1, 2] # 2 Q tile per CTA @@ -141,6 +144,12 @@ def __init__( assert not (self.is_split_kv and self.head_dim_v_padded >= 192), ( "SplitKV is not supported for hdim >= 192" ) + if is_split_d: + assert self.head_dim_padded > 192 and self.head_dim_padded == self.head_dim_v_padded, ( + "Split-D requires head_dim > 192 and head_dim == head_dim_v" + ) + assert not self.is_split_kv, "Split-D does not support SplitKV" + assert not self.pack_gqa, "Split-D does not support pack_gqa" self.score_mod = score_mod self.mask_mod = mask_mod if cutlass.const_expr(has_aux_tensors): @@ -190,6 +199,7 @@ def __init__( ] # e.g., 256, 384 self.tmem_total = self.tmem_o_offset[-1] + self.head_dim_v_padded assert self.tmem_total <= SM100_TMEM_CAPACITY_COLUMNS + # print(f"[DEBUG-tmem] s_offset={self.tmem_s_offset}, o_offset={self.tmem_o_offset}, total={self.tmem_total}, capacity={SM100_TMEM_CAPACITY_COLUMNS}") self.tmem_s_to_p_offset = self.n_block_size // 2 self.tmem_p_offset = [ self.tmem_s_offset[i] + self.tmem_s_to_p_offset for i in range(2) @@ -202,6 +212,13 @@ def __init__( self.num_regs_softmax = 200 self.num_regs_correction = 64 self.num_regs_other = 48 + elif self.is_split_d: + # Split-D (d=256): register budget must satisfy 8*S + 4*C + 4*O <= 65536 (SM100 has 65536 regs/SM) + # Equivalently: 2*S + C + O <= 512 (in units of 32-thread warps) + # Using S=200, C=80, O=32: 2*200 + 80 + 32 = 512 + self.num_regs_softmax = 200 + self.num_regs_correction = 80 + self.num_regs_other = 32 else: # self.num_regs_softmax = 192 if self.is_causal or self.is_local else 184 self.num_regs_softmax = 200 @@ -235,13 +252,17 @@ def _setup_attributes(self): if self.head_dim_padded == 192 and self.head_dim_v_padded == 128: self.kv_stage = 2 if self.enable_flashmask else 3 + elif self.is_split_d: + # Split-D (d=dv=256): SMEM budget is tight (Q=64KB + KV=128KB=192KB with kv_stage=2) + 
self.kv_stage = 2 elif self.q_dtype.width == 8 or self.q_stage == 1: self.kv_stage = 4 else: self.kv_stage = 3 self.acc_stage = 1 - self.epi_stage = 2 + # Split-D: reduce epi_stage to 1 to fit SMEM (O=128KB with epi_stage=2 is too large) + self.epi_stage = 1 if self.is_split_d else 2 self.generate_block_stage = 2 # For hdim 192,128, we don't have enough smem to store all 3 stages of KV: # 128 x 192 x 2 bytes x 3 stages = 144KB, and we need 96KB for Q. @@ -577,7 +598,8 @@ def __call__( self.overlap_sO_sQ = ( (self.head_dim_padded == 192 and self.head_dim_v_padded >= 64) or - (self.head_dim_v_padded >= 128 and self.is_split_kv) + (self.head_dim_v_padded >= 128 and self.is_split_kv) or + self.is_split_d ) if const_expr(self.enable_flashmask): self.overlap_sO_sQ = True @@ -630,8 +652,8 @@ def __call__( self.mbar_O_full_offset = self.mbar_S_full_offset + 2 self.mbar_softmax_corr_full_offset = self.mbar_O_full_offset + 2 self.mbar_softmax_corr_empty_offset = self.mbar_softmax_corr_full_offset + 2 - self.mbar_corr_epi_full_offset = self.mbar_softmax_corr_empty_offset + self.epi_stage - self.mbar_corr_epi_empty_offset = self.mbar_corr_epi_full_offset + self.epi_stage + self.mbar_corr_epi_full_offset = self.mbar_softmax_corr_empty_offset + 2 # softmax_corr_empty always has 2 barriers + self.mbar_corr_epi_empty_offset = self.mbar_corr_epi_full_offset + self.q_stage # corr_epi_full has q_stage barriers self.mbar_s0_s1_sequence_offset = self.mbar_corr_epi_empty_offset + 2 self.mbar_tmem_dealloc_offset = self.mbar_s0_s1_sequence_offset + 8 self.mbar_P_full_2_offset = self.mbar_tmem_dealloc_offset + 1 @@ -676,7 +698,7 @@ class SharedStorage: tmem_holding_buf: Int32 # Smem tensors # store row max and row sum - sScale: cute.struct.MemRange[Float32, self.q_stage * self.m_block_size * 2] + sScale: cute.struct.MemRange[Float32, 2 * self.m_block_size * 2] s_startend_row_indices_size = 0 s_startend_row_indices_block_max_min_size = 0 @@ -698,6 +720,15 @@ class SharedStorage: 
s_extra_flags: cute.struct.MemRange[Int32, s_extra_flags_size] # TODO(wusiming): would it be better to alloc more space to s_n_block? self.shared_storage = SharedStorage + # Debug: print SMEM component sizes + # print(f"[DEBUG-smem] sO_size={sO_size}, sQ_size={sQ_size}, sK_cosize={cute.cosize(sK_layout)}") + # print(f"[DEBUG-smem] mbar_total={self.mbar_total}, sScale_size={2 * self.m_block_size * 2}") + # if const_expr(self.enable_flashmask): + # print(f"[DEBUG-smem] flashmask: s_startend_row_indices_size={4 * self.n_block_size * self.kv_stage}") + # print(f"[DEBUG-smem] flashmask: s_startend_row_indices_block_max_min_size={8 * self.generate_block_buffer_block_count * self.generate_block_stage}") + # print(f"[DEBUG-smem] flashmask: s_n_block_size={self.generate_block_buffer_block_count * self.generate_block_stage}") + # print(f"[DEBUG-smem] flashmask: generate_block_buffer_block_count={self.generate_block_buffer_block_count}") + import sys; sys.stdout.flush() LOG2_E = math.log2(math.e) if const_expr(self.score_mod is None): @@ -730,6 +761,18 @@ class SharedStorage: raise NotImplementedError("Block sparsity + paged KV not supported on SM100") # Launch the kernel synchronously + # print(f"[DEBUG-launch] smem={self.shared_storage.size_in_bytes()}, threads={self.threads_per_cta}, grid={grid_dim}, num_regs_softmax={self.num_regs_softmax}, num_regs_correction={self.num_regs_correction}, is_split_d={self.is_split_d}") + # print(f"[DEBUG-mbar] q_stage={self.q_stage}, epi_stage={self.epi_stage}, kv_stage={self.kv_stage}") + # print(f"[DEBUG-mbar] load_q_full={self.mbar_load_q_full_offset}, load_q_empty={self.mbar_load_q_empty_offset}") + # print(f"[DEBUG-mbar] load_kv_full={self.mbar_load_kv_full_offset}, load_kv_empty={self.mbar_load_kv_empty_offset}") + # print(f"[DEBUG-mbar] P_full_O_rescaled={self.mbar_P_full_O_rescaled_offset}, S_full={self.mbar_S_full_offset}, O_full={self.mbar_O_full_offset}") + # print(f"[DEBUG-mbar] 
softmax_corr_full={self.mbar_softmax_corr_full_offset}, softmax_corr_empty={self.mbar_softmax_corr_empty_offset}") + # print(f"[DEBUG-mbar] corr_epi_full={self.mbar_corr_epi_full_offset}, corr_epi_empty={self.mbar_corr_epi_empty_offset}") + # print(f"[DEBUG-mbar] s0_s1_seq={self.mbar_s0_s1_sequence_offset}, tmem_dealloc={self.mbar_tmem_dealloc_offset}, P_full_2={self.mbar_P_full_2_offset}") + # print(f"[DEBUG-mbar] generate_block_full={self.mbar_generate_block_full_offset}, generate_block_empty={self.mbar_generate_block_empty_offset}") + # print(f"[DEBUG-mbar] load_startend_full={self.mbar_load_startend_row_indices_full_offset}, load_startend_empty={self.mbar_load_startend_row_indices_empty_offset}") + # print(f"[DEBUG-mbar] mbar_total={self.mbar_total}") + import sys; sys.stdout.flush() self.kernel( mQ, mK, @@ -826,6 +869,7 @@ def kernel( """ warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + # if cute.arch.thread_idx()[0] == 0 and cute.arch.block_idx()[0] == 0: cute.printf("[DEBUG-K] entry warp=%d\n", warp_idx) # Prefetch tma descriptor if warp_idx == 0: @@ -837,10 +881,12 @@ def kernel( if const_expr(tma_atom_O is not None): cpasync.prefetch_descriptor(tma_atom_O) + # if cute.arch.thread_idx()[0] == 0 and cute.arch.block_idx()[0] == 0: cute.printf("[DEBUG-K] after prefetch warp=%d\n", warp_idx) # Alloc smem = cutlass.utils.SmemAllocator() storage = smem.allocate(self.shared_storage) + # if cute.arch.thread_idx()[0] == 0 and cute.arch.block_idx()[0] == 0: cute.printf("[DEBUG-K] after smem alloc warp=%d\n", warp_idx) mbar_ptr = storage.mbar_ptr.data_ptr() # Use the first N warps to initialize barriers if warp_idx == 1: @@ -934,6 +980,7 @@ def kernel( # Relying on pipeline_kv constructor to call mbarrier_init_fence and sync pipeline_kv = self.make_and_init_load_kv_pipeline(mbar_ptr + self.mbar_load_kv_full_offset) + # if cute.arch.thread_idx()[0] == 0 and cute.arch.block_idx()[0] == 0: cute.printf("[DEBUG-K] after pipeline_kv init warp=%d\n", warp_idx) 
# Generate smem tensor Q/K/V/O sQ = storage.sQ.get_tensor(sQ_layout.outer, swizzle=sQ_layout.inner) # (MMA, MMA_K, MMA_D, PIPE) @@ -946,7 +993,7 @@ def kernel( else: sO = cute.make_tensor(cute.recast_ptr(sQ.iterator, sO_layout.inner, self.o_dtype), sO_layout.outer) - sScale = storage.sScale.get_tensor(cute.make_layout(self.q_stage * self.m_block_size * 2)) + sScale = storage.sScale.get_tensor(cute.make_layout(2 * self.m_block_size * 2)) thr_mma_qk = tiled_mma_qk.get_slice(0) # default 1SM thr_mma_pv = tiled_mma_pv.get_slice(0) # default 1SM @@ -1034,6 +1081,7 @@ def kernel( s_extra_flags = None s_startend_row_indices = None + # if cute.arch.thread_idx()[0] == 0 and cute.arch.block_idx()[0] == 0: cute.printf("[DEBUG-K] before warp branch warp=%d\n", warp_idx) # /////////////////////////////////////////////////////////////////////////////// # EMPTY # /////////////////////////////////////////////////////////////////////////////// @@ -1054,6 +1102,7 @@ def kernel( if warp_idx >= self.generate_block_warp_ids[0] and warp_idx <= self.generate_block_warp_ids[-1]: # TODO(wusiming): tune reg for generate block cute.arch.warpgroup_reg_dealloc(self.num_regs_other) + # if cute.arch.thread_idx()[0] % 32 == 0: cute.printf("[DEBUG] Generate block warp started, warp_idx=%d\\n", warp_idx) self.generate_block( s_startend_row_indices_block_max_min, s_n_block, @@ -1113,6 +1162,7 @@ def kernel( cute.arch.alloc_tmem(tmem_alloc_cols, storage.tmem_holding_buf) cute.arch.sync_warp() + # if cute.arch.thread_idx()[0] % 32 == 0: cute.printf("[DEBUG] MMA warp started, warp_idx=%d\\n", warp_idx) self.mma( tiled_mma_qk, tiled_mma_pv, @@ -1172,6 +1222,7 @@ def kernel( if warp_idx < self.correction_warp_ids[0]: # increase register after decreasing cute.arch.warpgroup_reg_alloc(self.num_regs_softmax) + # if cute.arch.thread_idx()[0] % 32 == 0: cute.printf("[DEBUG] Softmax warp started, warp_idx=%d\\n", warp_idx) softmax_loop = partial( self.softmax_loop, softmax_scale_log2=softmax_scale_log2, @@ 
-1221,6 +1272,7 @@ def kernel( # /////////////////////////////////////////////////////////////////////////////// if warp_idx >= self.correction_warp_ids[0] and warp_idx < self.mma_warp_id: cute.arch.warpgroup_reg_dealloc(self.num_regs_correction) + # if cute.arch.thread_idx()[0] % 32 == 0: cute.printf("[DEBUG] Correction warp started, warp_idx=%d\\n", warp_idx) self.correction_loop( thr_mma_qk, thr_mma_pv, @@ -2006,7 +2058,7 @@ def mma( qk_mma_op, self.tmem_s_offset[stage], tSrQs[stage], - sA=sQ[None, None, None, stage], + sA=sQ[None, None, None, stage if self.q_stage == 2 else 0], zero_init=True, ) for stage in range(2) @@ -2037,6 +2089,7 @@ def mma( while work_tile.is_valid_tile: m_block, head_idx, batch_idx, split_idx = work_tile.tile_idx seqlen = SeqlenInfoCls(batch_idx) + # if cute.arch.thread_idx()[0] % 32 == 0: cute.printf("[DEBUG-MMA] work_tile: m_block=%d, head=%d, batch=%d\\n", m_block, head_idx, batch_idx) block_iter_count = Int32(0) process_tile = False @@ -2044,6 +2097,7 @@ def mma( if const_expr(self.enable_flashmask): block_iter_count = flashmask_info.valid_block_count[batch_idx, head_idx // h_h_flashmask_ratio, m_block] process_tile = block_iter_count > Int32(0) + # if cute.arch.thread_idx()[0] % 32 == 0: cute.printf("[DEBUG-MMA] flashmask block_iter_count=%d for m_block=%d\\n", block_iter_count, m_block) elif const_expr(self.use_block_sparsity): block_iter_count = get_total_block_count(blocksparse_tensors, batch_idx, head_idx, m_block) process_tile = block_iter_count > Int32(0) @@ -2056,22 +2110,20 @@ def mma( process_tile = n_block_min < n_block_max if process_tile: - for stage in cutlass.range_constexpr(self.q_stage): + for stage in cutlass.range_constexpr(2): # GEMM_QK00 (Q0 * K0 -> S0) or GEMM_QK01 (Q1 * K0 -> S1) + # In split-d (q_stage=1), both S[0] and S[1] are computed from the same Q + # and the same K (full head_dim=256 dot product). S[0]==S[1]. # 1. 
wait for Q0 / Q1 - cute.arch.mbarrier_wait( - mbar_ptr + self.mbar_load_q_full_offset + stage, mma_q_consumer_phase - ) + if const_expr(self.q_stage == 2) or const_expr(stage == 0): + cute.arch.mbarrier_wait( + mbar_ptr + self.mbar_load_q_full_offset + (stage if const_expr(self.q_stage == 2) else 0), mma_q_consumer_phase + ) # 2. wait for K0 if const_expr(stage == 0): pipeline_kv.consumer_wait(mma_kv_consumer_state) tSrKi = tSrK[None, None, None, mma_kv_consumer_state.index] - # We don't need to acquire empty S0 / S1. - # For the first iteration, we don't need to wait as we're guaranteed S0 / S1 - # are empty. For subsequent iterations, the wait happened at the end - # of the while loop. # 3. gemm - # tiled_mma_qk = sm100_utils.gemm(tiled_mma_qk, tStSs[stage], tSrQs[stage], tSrKi, zero_init=True) sK_cur = sK[None, None, None, mma_kv_consumer_state.index] if const_expr(self.uneven_kv_smem): sK_cur = self.offset_kv_smem( @@ -2092,6 +2144,7 @@ def mma( # O hasn't been accumulated yet, its first MMA calculation doesn't need to accumulate block_loop_count = block_iter_count - 1 O_should_accumulate = False + # if cute.arch.thread_idx()[0] % 32 == 0: cute.printf("[DEBUG-MMA] entering block_loop, block_loop_count=%d\\n", block_loop_count) for i in cutlass.range(block_loop_count, unroll=1): # GEMM_PV00 (P0 * V0 -> O0_partial), O0 needs to be accumulated in the seqlen_kv loop @@ -2103,33 +2156,27 @@ def mma( tOrVi = tOrV[None, None, None, Vi_index] for stage in cutlass.range_constexpr(2): # 2. acquire corrected O0/O1_partial and P0 / P1 - # For the first iteration in this work tile, waiting for O0/O1_partial - # means that the correction warps has finished reading tO during - # the last iteration of the previous work tile has finished. cute.arch.mbarrier_wait( mbar_ptr + self.mbar_P_full_O_rescaled_offset + stage, P_full_O_rescaled_phase, ) - # 3. 
gemm - # sm100_utils.gemm(tiled_mma_pv, tOtO0, tOrP0, tOrVi, zero_init=True) - # gemm_Pi[stage](tCrB=tOrVi, sB=sV[None, None, None, Vi_index], zero_init=not O_should_accumulate) - sV_cur = sV[None, None, None, Vi_index] - if const_expr(self.uneven_kv_smem): - sV_cur = self.offset_kv_smem(sV_cur, Vi_index, Vi_phase) - gemm_Pi[stage]( - tCrB=tOrVi, - sB=sV_cur, - zero_init=not O_should_accumulate, - mbar_ptr=mbar_ptr + self.mbar_P_full_2_offset + stage, - mbar_phase=P_full_O_rescaled_phase, - ) - # 4. release accumulated O0_partial / O1_partial - # Don't need to signal O_full to the correction warps anymore since the - # correction warps wait for the softmax warps anyway. By the time the softmax - # warps finished, S_i for the next iteration must have been done, so O_i-1 - # must have been done as well. - # with cute.arch.elect_one(): - # tcgen05.commit(mbar_ptr + self.mbar_O_full_offset + stage) + # 3. gemm (skip stage 1 in split-d: S[0]==S[1] so P0==P1, only need one PV) + if const_expr(self.q_stage == 2 or stage == 0): + sV_cur = sV[None, None, None, Vi_index] + if const_expr(self.uneven_kv_smem): + sV_cur = self.offset_kv_smem(sV_cur, Vi_index, Vi_phase) + gemm_Pi[stage]( + tCrB=tOrVi, + sB=sV_cur, + zero_init=not O_should_accumulate, + mbar_ptr=mbar_ptr + self.mbar_P_full_2_offset + stage, + mbar_phase=P_full_O_rescaled_phase, + ) + else: + # split-d stage 1: consume P_full_2 barrier without doing gemm + cute.arch.mbarrier_wait( + mbar_ptr + self.mbar_P_full_2_offset + stage, P_full_O_rescaled_phase + ) # 5. release V(i-1) if const_expr(stage == 1): pipeline_kv.consumer_release(mma_kv_release_state) @@ -2181,19 +2228,23 @@ def mma( mbar_ptr + self.mbar_P_full_O_rescaled_offset + stage, P_full_O_rescaled_phase ) - # 3. 
gemm - # sm100_utils.gemm(tiled_mma_pv, tOtO0, tOrP0, tOrVi, zero_init=True) - # gemm_Pi[stage](tCrB=tOrVi, sB=sV[None, None, None, Vi_index], zero_init=not O_should_accumulate) - sV_cur = sV[None, None, None, Vi_index] - if const_expr(self.uneven_kv_smem): - sV_cur = self.offset_kv_smem(sV_cur, Vi_index, Vi_phase) - gemm_Pi[stage]( - tCrB=tOrVi, - sB=sV_cur, - zero_init=not O_should_accumulate, - mbar_ptr=mbar_ptr + self.mbar_P_full_2_offset + stage, - mbar_phase=P_full_O_rescaled_phase, - ) + # 3. gemm (skip stage 1 in split-d) + if const_expr(self.q_stage == 2 or stage == 0): + sV_cur = sV[None, None, None, Vi_index] + if const_expr(self.uneven_kv_smem): + sV_cur = self.offset_kv_smem(sV_cur, Vi_index, Vi_phase) + gemm_Pi[stage]( + tCrB=tOrVi, + sB=sV_cur, + zero_init=not O_should_accumulate, + mbar_ptr=mbar_ptr + self.mbar_P_full_2_offset + stage, + mbar_phase=P_full_O_rescaled_phase, + ) + else: + # split-d stage 1: consume P_full_2 barrier without doing gemm + cute.arch.mbarrier_wait( + mbar_ptr + self.mbar_P_full_2_offset + stage, P_full_O_rescaled_phase + ) # 4. release accumulated O0_partial # We do need O_full here since for the last tile, by the time the softmax warp # has signaled to the correction warps, the softmax warp has just finished compute @@ -2317,9 +2368,13 @@ def softmax_loop( n_block_min, n_block_max = block_info.get_n_block_min_max(seqlen, m_block, split_idx, num_splits) mask = AttentionMaskCls(seqlen.seqlen_q, seqlen.seqlen_k) + # In split-d (q_stage=1), both softmax warp groups process the SAME Q block + # (stage is KV pipeline buffer index, not Q row index), so m_block should be identical. + # In non-split-d (q_stage=2), stage indexes different Q rows. 
+ m_block_for_mask = self.q_stage * m_block + (stage if self.q_stage == 2 else 0) if const_expr(self.enable_flashmask): shared_mask_kwargs = dict( - m_block=self.q_stage * m_block + stage, + m_block=m_block_for_mask, thr_mma=thr_mma_qk, thr_tmem_load=thr_tmem_load, mask_causal=self.is_causal, @@ -2341,7 +2396,7 @@ def softmax_loop( ) else: shared_mask_kwargs = dict( - m_block=self.q_stage * m_block + stage, + m_block=m_block_for_mask, thr_mma=thr_mma_qk, thr_tmem_load=thr_tmem_load, mask_causal=self.is_causal, @@ -2399,7 +2454,7 @@ def softmax_loop( stage=stage, batch_idx=batch_idx, head_idx=head_idx, - m_block=self.q_stage * m_block + stage, + m_block=m_block_for_mask, seqlen=seqlen, aux_tensors=aux_tensors, fastdiv_mods=fastdiv_mods, @@ -2905,6 +2960,7 @@ def correction_loop( m_block, head_idx, batch_idx, split_idx = work_tile.tile_idx seqlen = SeqlenInfoCls(batch_idx) n_block_min, n_block_max = block_info.get_n_block_min_max(seqlen, m_block, split_idx, num_splits) + # if cute.arch.thread_idx()[0] % 32 == 0: cute.printf("[DEBUG-CORR] work_tile: m_block=%d, head=%d, batch=%d\\n", m_block, head_idx, batch_idx) if const_expr(self.is_split_kv): mO_cur = seqlen.offset_batch_Q(mO, batch_idx, dim=3)[None, None, head_idx, split_idx] @@ -2960,10 +3016,14 @@ def correction_loop( # Don't need O_full anymore, since by the time softmax has signaled the correction # warps, S_i must have been done, so O_i-1 must have been done as well. # cute.arch.mbarrier_wait(mbar_ptr + self.mbar_O_full_offset + stage, o_corr_consumer_phase) - if should_rescale: - self.correction_rescale( - thr_mma_pv, tOtOs[stage if self.q_stage == 2 else 0], tidx, scale - ) + # In split-d, skip stage 1 rescale: S[0]==S[1] so softmax0 and softmax1 + # track identical max sequences, causing double rescaling of the same O buffer. + # Only stage 0 correction is needed since only stage 0 PV gemm executes. 
+ if const_expr(self.q_stage == 2 or stage == 0): + if should_rescale: + self.correction_rescale( + thr_mma_pv, tOtOs[stage if self.q_stage == 2 else 0], tidx, scale + ) cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_P_full_O_rescaled_offset + stage) cute.arch.mbarrier_arrive( @@ -3045,6 +3105,19 @@ def correction_loop( # if tidx == 0: cute.printf("Correction final scale for stage %d: %f\n", stage, scale) + # In split-d (q_stage=1), epilogue only processed stage 0. We must still consume + # stage 1's final softmax signal and O_full commit to prevent barrier deadlock. + if const_expr(self.q_stage == 1): + cute.arch.mbarrier_wait( + mbar_ptr + self.mbar_softmax_corr_full_offset + 1, + softmax_corr_consumer_phase, + ) + cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_softmax_corr_empty_offset + 1) + cute.arch.mbarrier_wait( + mbar_ptr + self.mbar_O_full_offset + 1, o_corr_consumer_phase + ) + cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_P_full_O_rescaled_offset + 1) + o_corr_consumer_phase ^= 1 softmax_corr_consumer_phase ^= 1 corr_epi_producer_phase ^= 1 diff --git a/flashmask/flash_mask/cute/interface.py b/flashmask/flash_mask/cute/interface.py index fb213e4e5c4..3075bddf52c 100644 --- a/flashmask/flash_mask/cute/interface.py +++ b/flashmask/flash_mask/cute/interface.py @@ -152,8 +152,10 @@ def _flash_attn_fwd( fm_batch_size = startend_row_indices.shape[0] fm_heads = startend_row_indices.shape[1] # Note(wusiming): FA4 is so weird, but each cta process q_stage * m_block_size rows - q_stage = 2 + # Split-D (d>192, d==dv) uses q_stage=1 to fit TMEM budget + q_stage = 1 if (head_dim > 192 and head_dim == v.shape[-1]) else 2 num_m_blocks = (seqlen_q + (q_stage * m_block_size) - 1) // (q_stage * m_block_size) + # print(f"[DEBUG-interface] flashmask: q_stage={q_stage}, num_m_blocks={num_m_blocks}, head_dim={head_dim}, m_block_size={m_block_size}, seqlen_q={seqlen_q}") flashmask_info = FlashMaskInfoPaddle( is_causal=causal, startend_row_indices=startend_row_indices, @@ 
-364,6 +366,8 @@ def _flash_attn_fwd( # TODO: fix GQA + SplitKV + non-varlen if pack_gqa and num_splits != 1 and cu_seqlens_q is None: pack_gqa = False + # Split-D for d=dv=256 (head_dim > 192 requires q_stage=1 to fit TMEM) + is_split_d = head_dim > 192 and head_dim == head_dim_v if num_splits < 1: max_seqlen_k = ( @@ -490,6 +494,7 @@ def _flash_attn_fwd( page_size not in [None, 128], # paged KV non-TMA # flashmask startend_row_indices.shape[3] if startend_row_indices is not None else None, + is_split_d if compute_capability == 10 else False, ) if compile_key not in _flash_attn_fwd.compile_cache: if compute_capability == 9: @@ -537,6 +542,7 @@ def _flash_attn_fwd( has_aux_tensors=aux_tensors is not None, paged_kv_non_tma=page_size not in [None, 128], is_varlen_q=cu_seqlens_q is not None or seqused_q is not None, + is_split_d=is_split_d, ) else: raise ValueError( @@ -648,6 +654,7 @@ def _flash_attn_bwd( num_flashmask_tensors = 2 * flashmask_info.startend_row_indices.shape[-1] num_head, head_dim = q.shape[-2:] + is_split_d_bwd = False if compute_capability == 9: m_block_size = 80 if not causal else 64 @@ -673,6 +680,8 @@ def _flash_attn_bwd( need_large_cluster = (head_dim > 128) or (head_dim == 128 and flashmask_info is None) cluster_size = 2 if need_large_cluster else 1 use_2cta_instrs = cluster_size == 2 + # Split-D BWD for d=dv=256 (requires 2CTA + sub-GEMM decomposition) + is_split_d_bwd = head_dim > 192 and head_dim == head_dim_v q, k, v, out, dout, lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k = [ maybe_contiguous(t) @@ -780,7 +789,7 @@ def _flash_attn_bwd( dpsum = paddle.empty(shape=[num_head, total_q_rounded_padded], dtype=paddle.float32) lse_log2 = paddle.empty(shape=[num_head, total_q_rounded_padded], dtype=paddle.float32) - if qhead_per_kvhead > 1: + if qhead_per_kvhead > 1 or is_split_d_bwd: head_dim_v_rounded = (head_dim_v + hdim_round_to - 1) // hdim_round_to * hdim_round_to if cu_seqlens_k is None: seqlen_k_rounded = (seqlen_k + n_block_size - 
1) // n_block_size * n_block_size @@ -822,7 +831,7 @@ def _flash_attn_bwd( from_dlpack(t.detach(), assumed_align=16).mark_layout_dynamic(leading_dim=t.ndim - 1) for t in (dq_accum, dpsum, lse_log2) ] - if qhead_per_kvhead > 1: + if qhead_per_kvhead > 1 or is_split_d_bwd: dk_accum_tensor, dv_accum_tensor = [ from_dlpack(t.detach(), assumed_align=16).mark_layout_dynamic(leading_dim=t.ndim - 1) for t in (dk_accum, dv_accum) @@ -945,6 +954,7 @@ def _flash_attn_bwd( pack_gqa, cluster_size, deterministic, + is_split_d_bwd if compute_capability == 10 else False, ) num_threads = 384 if compile_key not in _flash_attn_bwd.compile_cache: @@ -1000,6 +1010,7 @@ def _flash_attn_bwd( cluster_size=cluster_size, use_2cta_instrs=use_2cta_instrs, deterministic=deterministic, + is_split_d=is_split_d_bwd, ) # TODO: check @can_implement _flash_attn_bwd.compile_cache[compile_key] = cute.compile( @@ -1011,8 +1022,8 @@ def _flash_attn_bwd( lse_log2_tensor, dpsum_tensor, dq_accum_tensor, - dk_tensor if qhead_per_kvhead == 1 else dk_accum_tensor, - dv_tensor if qhead_per_kvhead == 1 else dv_accum_tensor, + dk_tensor if (qhead_per_kvhead == 1 and not is_split_d_bwd) else dk_accum_tensor, + dv_tensor if (qhead_per_kvhead == 1 and not is_split_d_bwd) else dv_accum_tensor, softmax_scale, current_stream, cu_seqlens_q_tensor, @@ -1032,8 +1043,8 @@ def _flash_attn_bwd( lse_log2_tensor, dpsum_tensor, dq_accum_tensor, - dk_tensor if qhead_per_kvhead == 1 else dk_accum_tensor, - dv_tensor if qhead_per_kvhead == 1 else dv_accum_tensor, + dk_tensor if (qhead_per_kvhead == 1 and not is_split_d_bwd) else dk_accum_tensor, + dv_tensor if (qhead_per_kvhead == 1 and not is_split_d_bwd) else dv_accum_tensor, softmax_scale, current_stream, cu_seqlens_q_tensor, @@ -1074,7 +1085,7 @@ def _flash_attn_bwd( current_stream, ) - if qhead_per_kvhead > 1: + if qhead_per_kvhead > 1 or is_split_d_bwd: # Postprocess kernel: convert dk_accum & dv_accum from float32 to bf16/fp16 compile_key_post = (dtype, head_dim, arch, 
n_block_size, num_threads, AtomLayoutNdKV, dKV_swapAB) if compile_key_post not in _flash_attn_bwd.compile_cache_post: @@ -1715,6 +1726,8 @@ def flashmask_attention( (query.shape[-1] <= 128 and key.shape[-1] <= 128 and value.shape[-1] <= 128) or (query.shape[-1] == 192 and key.shape[-1] == 192 and value.shape[-1] == 128) + or + (query.shape[-1] == 256 and key.shape[-1] == 256 and value.shape[-1] == 256) ) and (startend_row_indices is None or startend_row_indices.shape[-1] != 4) ): @@ -1855,6 +1868,8 @@ def flash_attention( (query.shape[-1] <= 128 and key.shape[-1] <= 128 and value.shape[-1] <= 128) or (query.shape[-1] == 192 and key.shape[-1] == 192 and value.shape[-1] == 128) + or + (query.shape[-1] == 256 and key.shape[-1] == 256 and value.shape[-1] == 256) ) ): assert dropout == 0.0, (