Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 55 additions & 12 deletions flashmask/flash_mask/cute/flash_bwd_sm100.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,7 +355,8 @@ def _setup_smem_layout(self):
else:
self.sdKV_layout = cute.make_layout((self.tile_n * self.dK_reduce_ncol, 2))

# TODO(GuoxiaWang): 2 means only support flashmask startend_row_indices.shape[-1] <= 2
# 2 columns: [LTS, LTE or UTE] loaded into smem.
# For num_vec==4, UTS and UTE are read directly from gmem (smem budget is full).
self.sStartEndRowIndices_layout = cute.make_layout(
shape=(self.tile_n, 2),
stride=(1, self.tile_n),
Expand Down Expand Up @@ -402,6 +403,9 @@ def __call__(
self.ds_dtype = self.q_dtype

self.enable_flashmask = cutlass.const_expr(flashmask_info is not None)
self.has_lte = const_expr(flashmask_info is not None and flashmask_info.LTE_nblock_max is not None)
self.has_uts = const_expr(flashmask_info is not None and flashmask_info.UTS_nblock_max is not None)
self.has_ute = const_expr(flashmask_info is not None and flashmask_info.UTE_nblock_max is not None)

if const_expr(self.qhead_per_kvhead > 1):
assert self.dk_dtype.width == 32, "Must accumulate dK in float precision for GQA"
Expand Down Expand Up @@ -714,9 +718,9 @@ class SharedStorage:
cute.struct.MemRange[self.startend_row_indices_dtype, cute.cosize(self.sStartEndRowIndices_layout)],
64,
]
# sStartEndRowIndices_layout (128,4):(1,128)
# 128 * 4 = 512 * 4 = 2048
# 234496
# sStartEndRowIndices_layout (128,2):(1,128)
# 128 * 2 = 256 elements; 256 elements * 4 bytes = 1024 bytes
# 233472 bytes (= 228 KB)

self.shared_storage = SharedStorage
#print("self.shared_storage.size_in_bytes()", self.shared_storage.size_in_bytes())
Expand Down Expand Up @@ -1216,6 +1220,7 @@ def kernel(
sStartEndRowIndices,
sFM_max_min,
flashmask_loaded_mbar_ptr,
mQ.shape[2],
)
cute.arch.mbarrier_arrive(tmem_dealloc_mbar_ptr)

Expand Down Expand Up @@ -1461,7 +1466,8 @@ def load(
if tidx == 0 and self.debug_print:
cute.printf('n_block: %d, after load_step 0 ~ UTS: %d', n_block, m_block)
# Subtract 1 beforehand to use loop_start + 1 uniformly in the for loop.
loop_start = sFM_max_min[7] - 1
# Use max to advance past UTS region, avoiding double-loading when UTE_min <= UTS_max
loop_start = max(sFM_max_min[4], sFM_max_min[7] - 1)
else:
loop_start = sFM_max_min[7]

Expand Down Expand Up @@ -1667,22 +1673,38 @@ def load_fm(
sFM_max_min[6] = (UTE_nblock_max[n_block] - 1) // self.tile_m
sFM_max_min[7] = UTE_nblock_min[n_block] // self.tile_m

# sStartEndRowIndices layout is (tile_n, 2).
# For num_vec==4 (has_uts), UTS and UTE are read from gmem by the compute warp
# because smem budget is at the 228KB limit and cannot fit 4 columns.
#
# Possible masking conditions depending on num_vec:
# num_vec==1, causal: row >= LTS
# num_vec==2, causal: row >= LTS AND row < LTE
# num_vec==2, non-causal: row >= LTS OR row < UTE
# num_vec==4: (row >= LTS AND row < LTE) OR (row >= UTS AND row < UTE)
#
# smem column mapping:
# [:, 0] = LTS (all cases)
# [:, 1] = LTE (num_vec==2 causal or num_vec==4) or UTE (num_vec==2 non-causal)
#
# Default values ensure no masking when the bound is not supplied or column is OOB:
# [:, 0] = Int32.max (LTS default): "row >= Int32.max" is always false.
# [:, 1] = 0 (LTE/UTE default): "row < 0" is always false.
for i in cutlass.range_constexpr(ntimes_copy):
copy_offset = i * num_load_threads + tidx
sStartEndRowIndices[copy_offset, 0] = 2147483647
sStartEndRowIndices[copy_offset, 1] = 2147483647
sStartEndRowIndices[copy_offset, 0] = Int32.max # LTS default
sStartEndRowIndices[copy_offset, 1] = 0 # LTE/UTE default
if (copy_offset < self.tile_n and n_block * self.tile_n + copy_offset < seqlen_k):
LTS = flashmask_info.startend_row_indices[fm_batch_idx, fm_head_idx, None, 0]
sStartEndRowIndices[copy_offset, 0] = LTS[n_block * self.tile_n + copy_offset]
#assert const_expr(num_vec <= 2), "only support num_vec == 2 now"
if const_expr(flashmask_info.LTE_nblock_max is not None):
# num_vec==2 causal or num_vec==4: LTE at source index 1
LTE = flashmask_info.startend_row_indices[fm_batch_idx, fm_head_idx, None, 1]
sStartEndRowIndices[copy_offset, 1] = LTE[n_block * self.tile_n + copy_offset]
if const_expr(flashmask_info.UTE_nblock_max is not None):
elif const_expr(flashmask_info.UTE_nblock_max is not None):
# num_vec==2, non-causal: UTE at source index 1
UTE = flashmask_info.startend_row_indices[fm_batch_idx, fm_head_idx, None, 1]
sStartEndRowIndices[copy_offset, 1] = UTE[n_block * self.tile_n + copy_offset]
#cute.printf("%d, %d", copy_offset, sStartEndRowIndices[copy_offset, 0])
#cute.print_tensor(LTS)
cute.arch.sync_warp()

@cute.jit
Expand Down Expand Up @@ -1832,7 +1854,10 @@ def mma(
num_blocks = num_blocks + max(0, (loop_end - loop_start))
if tidx == 0 and self.debug_print:
cute.printf('after uts mma: n_block: %d, %d', n_block, num_blocks)
loop_start = sFM_max_min[7]
# Advance past UTS region to avoid double-counting when UTE_min <= UTS_max
loop_start = max(sFM_max_min[4] + 1, sFM_max_min[7])
else:
loop_start = sFM_max_min[7]

# UTE ~ LTS
#loop_end = m_block_max if m_block_max < sFM_max_min[0] + 1 else sFM_max_min[0] + 1
Expand Down Expand Up @@ -2089,6 +2114,7 @@ def compute_loop(
sStartEndRowIndices: cute.Tensor,
sFM_max_min: cute.Tensor,
flashmask_loaded_mbar_ptr: cute.Pointer,
num_heads: Int32,
):
sLSE_2D = cute.make_tensor(
sLSE.iterator,
Expand Down Expand Up @@ -2204,6 +2230,14 @@ def compute_loop(
seqlen, n_block // self.cluster_shape_mnk[0]
)
mask = AttentionMaskCls(seqlen.seqlen_q, seqlen.seqlen_k)
# For num_vec==4, compute fm_batch_idx/fm_head_idx for direct gmem access
# to UTS/UTE (not loaded into smem because 228KB smem budget is full).
fm_batch_idx_compute = Int32(0)
fm_head_idx_compute = Int32(0)
if const_expr(self.has_uts):
bsz, fm_heads, seqlen_k_fm, num_vec = flashmask_info.startend_row_indices.shape
fm_batch_idx_compute = batch_idx if bsz > 1 else 0
fm_head_idx_compute = head_idx // (num_heads // fm_heads)
# TODO: condition mask_seqlen
mask_fn = partial(
mask.apply_mask_sm100_transposed,
Expand All @@ -2214,6 +2248,12 @@ def compute_loop(
mask_causal=self.is_causal,
mask_local=self.is_local,
sStartEndRowIndices=sStartEndRowIndices,
startend_row_indices=flashmask_info.startend_row_indices if const_expr(self.has_uts) else None,
fm_batch_idx=fm_batch_idx_compute,
fm_head_idx=fm_head_idx_compute,
has_lte=self.has_lte,
has_uts=self.has_uts,
has_ute=self.has_ute,
)

# prefetch_LSE = not self.is_causal
Expand Down Expand Up @@ -2292,6 +2332,9 @@ def compute_loop(
if tidx == 0 and self.debug_print:
cute.printf('n_block: %d, after compute_step UTS_min ~ UTS_max: %d', n_block, m_block)

# Advance past UTS region to avoid double-processing when UTE_min <= UTS_max
loop_start = sFM_max_min[4] + 1 # UTS_max + 1

loop_start = max(loop_start, sFM_max_min[7]) # UTE_min
loop_end = min(sFM_max_min[6] + 1, m_block_max) # UTE_max
for m_block in cutlass.range(loop_start, loop_end, unroll=1):
Expand Down
1 change: 0 additions & 1 deletion flashmask/flash_mask/cute/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -1696,7 +1696,6 @@ def flashmask_attention(
if (
paddle.base.framework.get_flags(["FLAGS_flash_attn_version"])["FLAGS_flash_attn_version"] == 4
and query.shape[-1] <= 128 and key.shape[-1] <= 128 and value.shape[-1] <= 128
and (startend_row_indices is None or startend_row_indices.shape[-1] != 4)
):
assert dropout == 0.0, (
"flashmask v4 does not support dropout"
Expand Down
91 changes: 70 additions & 21 deletions flashmask/flash_mask/cute/mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -512,33 +512,76 @@ def apply_mask_sm100_transposed(
mask_local: cutlass.Constexpr,
sStartEndRowIndices: cute.Tensor,
partially_masked: bool,
startend_row_indices: Optional[cute.Tensor] = None,
fm_batch_idx: cutlass.Int32 = 0,
fm_head_idx: cutlass.Int32 = 0,
has_lte: cutlass.Constexpr[bool] = False,
has_uts: cutlass.Constexpr[bool] = False,
has_ute: cutlass.Constexpr[bool] = False,
) -> None:
"""
Backward pass: mask S = K @ Q.T where n_block tiles seqlen_k and m_block tiles seqlen_q.
sStartEndRowIndices layout: (tile_n, 2).
[:, 0] = LTS (all cases)
[:, 1] = LTE (num_vec==2 causal or num_vec==4) or UTE (num_vec==2 non-causal)
For num_vec==4 (has_uts), UTS and UTE are read from gmem via startend_row_indices
because smem cannot fit 4 columns at the 228KB limit.
FlashMask masking conditions:
num_vec==1, causal: row >= LTS
num_vec==2, causal: row >= LTS AND row < LTE
num_vec==2, non-causal: row >= LTS OR row < UTE
num_vec==4: (row >= LTS AND row < LTE) OR (row >= UTS AND row < UTE)
"""
assert not (mask_causal and mask_local), "mask_causal and mask_local cannot be both True"
ROW = 0 if const_expr(not self.swap_AB) else 1
COL = 1 if const_expr(not self.swap_AB) else 0
thr_col_offset = tScS_t2r[0][COL]
seqlenk_col_limit = self.seqlen_k - n_block * self.tile_n - thr_col_offset
#cute.printf('seqlenk_col_limit: %d, thr_col_offset: %d, t0ScS_t2r[0][COL]: %d, %d', seqlenk_col_limit, thr_col_offset, t0ScS_t2r[0][COL], t0ScS_t2r[32][COL])
#cute.print_tensor(t0ScS_t2r)
if const_expr(not mask_causal and not mask_local):
if const_expr(mask_seqlen):
if t0ScS_t2r[0][COL] >= seqlenk_col_limit:
for i in cutlass.range(cute.size(acc_S.shape), unroll_full=True):
acc_S[i] = -cutlass.Float32.inf
# FlashMask
if partially_masked:
for i in cutlass.range(cute.size(acc_S.shape), unroll_full=True):
lts = sStartEndRowIndices[tScS_t2r[i][COL], 0] - m_block * self.tile_m
ute = sStartEndRowIndices[tScS_t2r[i][COL], 1] - m_block * self.tile_m
acc_S[i] = (
-cutlass.Float32.inf if tScS_t2r[i][ROW] >= lts else acc_S[i]
)
acc_S[i] = (
-cutlass.Float32.inf if tScS_t2r[i][ROW] < ute else acc_S[i]
)
if const_expr(has_uts):
# num_vec==4: (row >= LTS AND row < LTE) OR (row >= UTS AND row < UTE)
# LTS, LTE from smem; UTS, UTE from gmem via startend_row_indices
for i in cutlass.range(cute.size(acc_S.shape), unroll_full=True):
lts = sStartEndRowIndices[tScS_t2r[i][COL], 0] - m_block * self.tile_m
lte = sStartEndRowIndices[tScS_t2r[i][COL], 1] - m_block * self.tile_m
# Guard the gmem access: when seqlen_k is not divisible by tile_n,
# the last n_block may contain out-of-bounds columns.
col_idx = n_block * self.tile_n + tScS_t2r[i][COL]
uts = 0
ute = 0
if col_idx < self.seqlen_k:
uts = startend_row_indices[fm_batch_idx, fm_head_idx, col_idx, 2] - m_block * self.tile_m
ute = startend_row_indices[fm_batch_idx, fm_head_idx, col_idx, 3] - m_block * self.tile_m
acc_S[i] = (
-cutlass.Float32.inf
if (tScS_t2r[i][ROW] >= lts and tScS_t2r[i][ROW] < lte) or (tScS_t2r[i][ROW] >= uts and tScS_t2r[i][ROW] < ute)
else acc_S[i]
)
elif const_expr(has_ute):
# num_vec==2, non-causal: row >= LTS OR row < UTE
# UTE stored in smem [:, 1]
for i in cutlass.range(cute.size(acc_S.shape), unroll_full=True):
lts = sStartEndRowIndices[tScS_t2r[i][COL], 0] - m_block * self.tile_m
ute = sStartEndRowIndices[tScS_t2r[i][COL], 1] - m_block * self.tile_m
acc_S[i] = (
-cutlass.Float32.inf if tScS_t2r[i][ROW] >= lts else acc_S[i]
)
acc_S[i] = (
-cutlass.Float32.inf if tScS_t2r[i][ROW] < ute else acc_S[i]
)
else:
# num_vec==1: row >= LTS
for i in cutlass.range(cute.size(acc_S.shape), unroll_full=True):
lts = sStartEndRowIndices[tScS_t2r[i][COL], 0] - m_block * self.tile_m
acc_S[i] = (
-cutlass.Float32.inf if tScS_t2r[i][ROW] >= lts else acc_S[i]
)

else: # Causal or local
thr_row_offset = tScS_t2r[0][ROW]
Expand All @@ -548,9 +591,6 @@ def apply_mask_sm100_transposed(
if const_expr(mask_causal):
col0 = t0ScS_t2r[0][COL]
row_limit_top = col0 - causal_row_offset
# tidx = cute.arch.thread_idx()[0] % 256
# if tidx < 32:
# cute.printf("tidx = {}, {} {}, {} {}, col0 = {}", tidx, tScS_t2r[0][0], tScS_t2r[0][1], tScS_t2r[1][0], tScS_t2r[1][1], col0)
if const_expr(mask_seqlen):
# If col is beyond the column limit, we want to mask out the entire
# column, by setting row limit to be self.tile_m.
Expand All @@ -568,12 +608,21 @@ def apply_mask_sm100_transposed(
mask_r2p_transposed(acc_S, row_limit_top, num_rep)

if partially_masked:
# FlashMask
for i in cutlass.range(cute.size(acc_S.shape), unroll_full=True):
lts = sStartEndRowIndices[tScS_t2r[i][COL], 0] - m_block * self.tile_m
lte = sStartEndRowIndices[tScS_t2r[i][COL], 1] - m_block * self.tile_m
acc_S[i] = (
-cutlass.Float32.inf if tScS_t2r[i][ROW] >= lts and tScS_t2r[i][ROW] < lte else acc_S[i]
)
# FlashMask (causal: has_uts is never True since num_vec==4 is always non-causal)
if const_expr(has_lte):
# num_vec==2, causal: row >= LTS AND row < LTE
for i in cutlass.range(cute.size(acc_S.shape), unroll_full=True):
lts = sStartEndRowIndices[tScS_t2r[i][COL], 0] - m_block * self.tile_m
lte = sStartEndRowIndices[tScS_t2r[i][COL], 1] - m_block * self.tile_m
acc_S[i] = (
-cutlass.Float32.inf if tScS_t2r[i][ROW] >= lts and tScS_t2r[i][ROW] < lte else acc_S[i]
)
else:
# num_vec==1: row >= LTS
for i in cutlass.range(cute.size(acc_S.shape), unroll_full=True):
lts = sStartEndRowIndices[tScS_t2r[i][COL], 0] - m_block * self.tile_m
acc_S[i] = (
-cutlass.Float32.inf if tScS_t2r[i][ROW] >= lts else acc_S[i]
)
else:
assert False, "Local masking isn't supported yet"