diff --git a/flashmask/flash_mask/cute/flash_bwd_sm100.py b/flashmask/flash_mask/cute/flash_bwd_sm100.py index 9ebfdaab421..fbdd8d37bd6 100644 --- a/flashmask/flash_mask/cute/flash_bwd_sm100.py +++ b/flashmask/flash_mask/cute/flash_bwd_sm100.py @@ -355,7 +355,8 @@ def _setup_smem_layout(self): else: self.sdKV_layout = cute.make_layout((self.tile_n * self.dK_reduce_ncol, 2)) - # TODO(GuoxiaWang): 2 means only support flashmask startend_row_indices.shape[-1] <= 2 + # 2 columns: [LTS, LTE or UTE] loaded into smem. + # For num_vec==4, UTS and UTE are read directly from gmem (smem budget is full). self.sStartEndRowIndices_layout = cute.make_layout( shape=(self.tile_n, 2), stride=(1, self.tile_n), @@ -402,6 +403,9 @@ def __call__( self.ds_dtype = self.q_dtype self.enable_flashmask = cutlass.const_expr(flashmask_info is not None) + self.has_lte = const_expr(flashmask_info is not None and flashmask_info.LTE_nblock_max is not None) + self.has_uts = const_expr(flashmask_info is not None and flashmask_info.UTS_nblock_max is not None) + self.has_ute = const_expr(flashmask_info is not None and flashmask_info.UTE_nblock_max is not None) if const_expr(self.qhead_per_kvhead > 1): assert self.dk_dtype.width == 32, "Must accumulate dK in float precision for GQA" @@ -714,9 +718,9 @@ class SharedStorage: cute.struct.MemRange[self.startend_row_indices_dtype, cute.cosize(self.sStartEndRowIndices_layout)], 64, ] - # sStartEndRowIndices_layout (128,4):(1,128) - # 128 * 4 = 512 * 4 = 2048 - # 234496 + # sStartEndRowIndices_layout (128,2):(1,128) + # 128 * 2 = 256 * 4 = 1024 + # 233472 self.shared_storage = SharedStorage #print("self.shared_storage.size_in_bytes()", self.shared_storage.size_in_bytes()) @@ -1216,6 +1220,7 @@ def kernel( sStartEndRowIndices, sFM_max_min, flashmask_loaded_mbar_ptr, + mQ.shape[2], ) cute.arch.mbarrier_arrive(tmem_dealloc_mbar_ptr) @@ -1461,7 +1466,8 @@ def load( if tidx == 0 and self.debug_print: cute.printf('n_block: %d, after load_step 0 ~ UTS: %d', n_block, m_block) # Subtract 1 beforehand to use loop_start + 1 uniformly in the for loop. - loop_start = sFM_max_min[7] - 1 + # Use max to advance past UTS region, avoiding double-loading when UTE_min <= UTS_max + loop_start = max(sFM_max_min[4], sFM_max_min[7] - 1) else: loop_start = sFM_max_min[7] @@ -1667,22 +1673,38 @@ def load_fm( sFM_max_min[6] = (UTE_nblock_max[n_block] - 1) // self.tile_m sFM_max_min[7] = UTE_nblock_min[n_block] // self.tile_m + # sStartEndRowIndices layout is (tile_n, 2). + # For num_vec==4 (has_uts), UTS and UTE are read from gmem by the compute warp + # because smem budget is at the 228KB limit and cannot fit 4 columns. + # + # Possible masking conditions depending on num_vec: + # num_vec==1, causal: row >= LTS + # num_vec==2, causal: row >= LTS AND row < LTE + # num_vec==2, non-causal: row >= LTS OR row < UTE + # num_vec==4: (row >= LTS AND row < LTE) OR (row >= UTS AND row < UTE) + # + # smem column mapping: + # [:, 0] = LTS (all cases) + # [:, 1] = LTE (num_vec==2 causal or num_vec==4) or UTE (num_vec==2 non-causal) + # + # Default values ensure no masking when the bound is not supplied or column is OOB: + # [:, 0] = Int32.max (LTS default): "row >= Int32.max" is always false. + # [:, 1] = 0 (LTE/UTE default): "row < 0" is always false. for i in cutlass.range_constexpr(ntimes_copy): copy_offset = i * num_load_threads + tidx - sStartEndRowIndices[copy_offset, 0] = 2147483647 - sStartEndRowIndices[copy_offset, 1] = 2147483647 + sStartEndRowIndices[copy_offset, 0] = Int32.max # LTS default + sStartEndRowIndices[copy_offset, 1] = 0 # LTE/UTE default if (copy_offset < self.tile_n and n_block * self.tile_n + copy_offset < seqlen_k): LTS = flashmask_info.startend_row_indices[fm_batch_idx, fm_head_idx, None, 0] sStartEndRowIndices[copy_offset, 0] = LTS[n_block * self.tile_n + copy_offset] - #assert const_expr(num_vec <= 2), "only support num_vec == 2 now" if const_expr(flashmask_info.LTE_nblock_max is not None): + # num_vec==2 causal or num_vec==4: LTE at source index 1 LTE = flashmask_info.startend_row_indices[fm_batch_idx, fm_head_idx, None, 1] sStartEndRowIndices[copy_offset, 1] = LTE[n_block * self.tile_n + copy_offset] - if const_expr(flashmask_info.UTE_nblock_max is not None): + elif const_expr(flashmask_info.UTE_nblock_max is not None): + # num_vec==2, non-causal: UTE at source index 1 UTE = flashmask_info.startend_row_indices[fm_batch_idx, fm_head_idx, None, 1] sStartEndRowIndices[copy_offset, 1] = UTE[n_block * self.tile_n + copy_offset] - #cute.printf("%d, %d", copy_offset, sStartEndRowIndices[copy_offset, 0]) - #cute.print_tensor(LTS) cute.arch.sync_warp() @cute.jit @@ -1832,7 +1854,10 @@ def mma( num_blocks = num_blocks + max(0, (loop_end - loop_start)) if tidx == 0 and self.debug_print: cute.printf('after uts mma: n_block: %d, %d', n_block, num_blocks) - loop_start = sFM_max_min[7] + # Advance past UTS region to avoid double-counting when UTE_min <= UTS_max + loop_start = max(sFM_max_min[4] + 1, sFM_max_min[7]) + else: + loop_start = sFM_max_min[7] # UTE ~ LTS #loop_end = m_block_max if m_block_max < sFM_max_min[0] + 1 else sFM_max_min[0] + 1 @@ -2089,6 +2114,7 @@ def compute_loop( sStartEndRowIndices: cute.Tensor, sFM_max_min: cute.Tensor, flashmask_loaded_mbar_ptr: cute.Pointer, + num_heads: Int32, ): sLSE_2D = cute.make_tensor( sLSE.iterator, @@ -2204,6 +2230,14 @@ def compute_loop( seqlen, n_block // self.cluster_shape_mnk[0] ) mask = AttentionMaskCls(seqlen.seqlen_q, seqlen.seqlen_k) + # For num_vec==4, compute fm_batch_idx/fm_head_idx for direct gmem access + # to UTS/UTE (not loaded into smem because 228KB smem budget is full). + fm_batch_idx_compute = Int32(0) + fm_head_idx_compute = Int32(0) + if const_expr(self.has_uts): + bsz, fm_heads, seqlen_k_fm, num_vec = flashmask_info.startend_row_indices.shape + fm_batch_idx_compute = batch_idx if bsz > 1 else 0 + fm_head_idx_compute = head_idx // (num_heads // fm_heads) # TODO: condition mask_seqlen mask_fn = partial( mask.apply_mask_sm100_transposed, @@ -2214,6 +2248,12 @@ def compute_loop( mask_causal=self.is_causal, mask_local=self.is_local, sStartEndRowIndices=sStartEndRowIndices, + startend_row_indices=flashmask_info.startend_row_indices if const_expr(self.has_uts) else None, + fm_batch_idx=fm_batch_idx_compute, + fm_head_idx=fm_head_idx_compute, + has_lte=self.has_lte, + has_uts=self.has_uts, + has_ute=self.has_ute, ) # prefetch_LSE = not self.is_causal @@ -2292,6 +2332,9 @@ def compute_loop( if tidx == 0 and self.debug_print: cute.printf('n_block: %d, after compute_step UTS_min ~ UTS_max: %d', n_block, m_block) + # Advance past UTS region to avoid double-processing when UTE_min <= UTS_max + loop_start = sFM_max_min[4] + 1 # UTS_max + 1 + loop_start = max(loop_start, sFM_max_min[7]) # UTE_min loop_end = min(sFM_max_min[6] + 1, m_block_max) # UTE_max for m_block in cutlass.range(loop_start, loop_end, unroll=1): diff --git a/flashmask/flash_mask/cute/interface.py b/flashmask/flash_mask/cute/interface.py index 8fdb636d7e9..ee8c6cc023a 100644 --- a/flashmask/flash_mask/cute/interface.py +++ b/flashmask/flash_mask/cute/interface.py @@ -1696,7 +1696,6 @@ def flashmask_attention( if ( paddle.base.framework.get_flags(["FLAGS_flash_attn_version"])["FLAGS_flash_attn_version"] == 4 and query.shape[-1] <= 128 and key.shape[-1] <= 128 and value.shape[-1] <= 128 - and (startend_row_indices is None or startend_row_indices.shape[-1] != 4) ): assert dropout == 0.0, ( "flashmask v4 does not support dropout" diff --git a/flashmask/flash_mask/cute/mask.py b/flashmask/flash_mask/cute/mask.py index 8f2880b868b..c1dd648ad7f 100644 --- a/flashmask/flash_mask/cute/mask.py +++ b/flashmask/flash_mask/cute/mask.py @@ -512,17 +512,31 @@ def apply_mask_sm100_transposed( mask_local: cutlass.Constexpr, sStartEndRowIndices: cute.Tensor, partially_masked: bool, + startend_row_indices: Optional[cute.Tensor] = None, + fm_batch_idx: cutlass.Int32 = 0, + fm_head_idx: cutlass.Int32 = 0, + has_lte: cutlass.Constexpr[bool] = False, + has_uts: cutlass.Constexpr[bool] = False, + has_ute: cutlass.Constexpr[bool] = False, ) -> None: """ Backward pass: mask S = K @ Q.T where n_block tiles seqlen_k and m_block tiles seqlen_q. + sStartEndRowIndices layout: (tile_n, 2). + [:, 0] = LTS (all cases) + [:, 1] = LTE (num_vec==2 causal or num_vec==4) or UTE (num_vec==2 non-causal) + For num_vec==4 (has_uts), UTS and UTE are read from gmem via startend_row_indices + because smem cannot fit 4 columns at the 228KB limit. + FlashMask masking conditions: + num_vec==1, causal: row >= LTS + num_vec==2, causal: row >= LTS AND row < LTE + num_vec==2, non-causal: row >= LTS OR row < UTE + num_vec==4: (row >= LTS AND row < LTE) OR (row >= UTS AND row < UTE) """ assert not (mask_causal and mask_local), "mask_causal and mask_local cannot be both True" ROW = 0 if const_expr(not self.swap_AB) else 1 COL = 1 if const_expr(not self.swap_AB) else 0 thr_col_offset = tScS_t2r[0][COL] seqlenk_col_limit = self.seqlen_k - n_block * self.tile_n - thr_col_offset - #cute.printf('seqlenk_col_limit: %d, thr_col_offset: %d, t0ScS_t2r[0][COL]: %d, %d', seqlenk_col_limit, thr_col_offset, t0ScS_t2r[0][COL], t0ScS_t2r[32][COL]) - #cute.print_tensor(t0ScS_t2r) if const_expr(not mask_causal and not mask_local): if const_expr(mask_seqlen): if t0ScS_t2r[0][COL] >= seqlenk_col_limit: @@ -530,15 +544,44 @@ def apply_mask_sm100_transposed( acc_S[i] = -cutlass.Float32.inf # FlashMask if partially_masked: - for i in cutlass.range(cute.size(acc_S.shape), unroll_full=True): - lts = sStartEndRowIndices[tScS_t2r[i][COL], 0] - m_block * self.tile_m - ute = sStartEndRowIndices[tScS_t2r[i][COL], 1] - m_block * self.tile_m - acc_S[i] = ( - -cutlass.Float32.inf if tScS_t2r[i][ROW] >= lts else acc_S[i] - ) - acc_S[i] = ( - -cutlass.Float32.inf if tScS_t2r[i][ROW] < ute else acc_S[i] - ) + if const_expr(has_uts): + # num_vec==4: (row >= LTS AND row < LTE) OR (row >= UTS AND row < UTE) + # LTS, LTE from smem; UTS, UTE from gmem via startend_row_indices + for i in cutlass.range(cute.size(acc_S.shape), unroll_full=True): + lts = sStartEndRowIndices[tScS_t2r[i][COL], 0] - m_block * self.tile_m + lte = sStartEndRowIndices[tScS_t2r[i][COL], 1] - m_block * self.tile_m + # Guard gmem access: when seqlen_k is not divisible by tile_n, + # the last n_block may have out-of-bound columns. + col_idx = n_block * self.tile_n + tScS_t2r[i][COL] + uts = 0 + ute = 0 + if col_idx < self.seqlen_k: + uts = startend_row_indices[fm_batch_idx, fm_head_idx, col_idx, 2] - m_block * self.tile_m + ute = startend_row_indices[fm_batch_idx, fm_head_idx, col_idx, 3] - m_block * self.tile_m + acc_S[i] = ( + -cutlass.Float32.inf + if (tScS_t2r[i][ROW] >= lts and tScS_t2r[i][ROW] < lte) or (tScS_t2r[i][ROW] >= uts and tScS_t2r[i][ROW] < ute) + else acc_S[i] + ) + elif const_expr(has_ute): + # num_vec==2, non-causal: row >= LTS OR row < UTE + # UTE stored in smem [:, 1] + for i in cutlass.range(cute.size(acc_S.shape), unroll_full=True): + lts = sStartEndRowIndices[tScS_t2r[i][COL], 0] - m_block * self.tile_m + ute = sStartEndRowIndices[tScS_t2r[i][COL], 1] - m_block * self.tile_m + acc_S[i] = ( + -cutlass.Float32.inf if tScS_t2r[i][ROW] >= lts else acc_S[i] + ) + acc_S[i] = ( + -cutlass.Float32.inf if tScS_t2r[i][ROW] < ute else acc_S[i] + ) + else: + # num_vec==1: row >= LTS + for i in cutlass.range(cute.size(acc_S.shape), unroll_full=True): + lts = sStartEndRowIndices[tScS_t2r[i][COL], 0] - m_block * self.tile_m + acc_S[i] = ( + -cutlass.Float32.inf if tScS_t2r[i][ROW] >= lts else acc_S[i] + ) else: # Causal or local thr_row_offset = tScS_t2r[0][ROW] @@ -548,9 +591,6 @@ def apply_mask_sm100_transposed( if const_expr(mask_causal): col0 = t0ScS_t2r[0][COL] row_limit_top = col0 - causal_row_offset - # tidx = cute.arch.thread_idx()[0] % 256 - # if tidx < 32: - # cute.printf("tidx = {}, {} {}, {} {}, col0 = {}", tidx, tScS_t2r[0][0], tScS_t2r[0][1], tScS_t2r[1][0], tScS_t2r[1][1], col0) if const_expr(mask_seqlen): # If col is beyond the column limit, we want to mask out the entire # column, by setting row limit to be self.tile_m. @@ -568,12 +608,21 @@ def apply_mask_sm100_transposed( mask_r2p_transposed(acc_S, row_limit_top, num_rep) if partially_masked: - # FlashMask - for i in cutlass.range(cute.size(acc_S.shape), unroll_full=True): - lts = sStartEndRowIndices[tScS_t2r[i][COL], 0] - m_block * self.tile_m - lte = sStartEndRowIndices[tScS_t2r[i][COL], 1] - m_block * self.tile_m - acc_S[i] = ( - -cutlass.Float32.inf if tScS_t2r[i][ROW] >= lts and tScS_t2r[i][ROW] < lte else acc_S[i] - ) + # FlashMask (causal: has_uts is never True since num_vec==4 is always non-causal) + if const_expr(has_lte): + # num_vec==2, causal: row >= LTS AND row < LTE + for i in cutlass.range(cute.size(acc_S.shape), unroll_full=True): + lts = sStartEndRowIndices[tScS_t2r[i][COL], 0] - m_block * self.tile_m + lte = sStartEndRowIndices[tScS_t2r[i][COL], 1] - m_block * self.tile_m + acc_S[i] = ( + -cutlass.Float32.inf if tScS_t2r[i][ROW] >= lts and tScS_t2r[i][ROW] < lte else acc_S[i] + ) + else: + # num_vec==1: row >= LTS + for i in cutlass.range(cute.size(acc_S.shape), unroll_full=True): + lts = sStartEndRowIndices[tScS_t2r[i][COL], 0] - m_block * self.tile_m + acc_S[i] = ( + -cutlass.Float32.inf if tScS_t2r[i][ROW] >= lts else acc_S[i] + ) else: assert False, "Local masking isn't supported yet"