diff --git a/generate_startend_row_indices.py b/generate_startend_row_indices.py index 5209376..07d1e60 100644 --- a/generate_startend_row_indices.py +++ b/generate_startend_row_indices.py @@ -1,62 +1,84 @@ -import paddle +# import paddle +import torch import numpy as np def startend_row_indices_to_attn_bias(startend_row_indices, seqlen_q, nheads, dtype, causal=True): if startend_row_indices is None: return None + if not isinstance(startend_row_indices, torch.Tensor): + startend_row_indices = torch.tensor(startend_row_indices) + bz, num_head, seqlen_k, bound_num = startend_row_indices.shape assert nheads % num_head == 0 - m = paddle.zeros((bz, num_head, seqlen_q, seqlen_k), dtype=dtype) + device = startend_row_indices.device + + # 1. 将索引转到 CPU Numpy 以加速循环 + startend_cpu = startend_row_indices.detach().cpu().numpy() + + # 2. 初始化 CPU Numpy 数组 (关键修改: 使用 m_cpu 而不是 m) + # float32 兼容后续的 bfloat16/float16 + m_cpu = np.zeros((bz, num_head, seqlen_q, seqlen_k), dtype=np.float32) + has_end = (causal and bound_num == 2) or ((not causal) and bound_num == 4) + inf_value = float("-inf") + + # 3. 在 CPU 上进行循环赋值 (速度快) for bi in range(bz): for hi in range(num_head): for j in range(seqlen_k): - downstart = startend_row_indices[bi, hi, j, 0] + downstart = startend_cpu[bi, hi, j, 0] + + # 注意:这里全部操作 m_cpu if has_end: - downend = startend_row_indices[bi, hi, j, 1] - m[bi, hi, downstart:downend, j] = -np.inf + downend = startend_cpu[bi, hi, j, 1] + m_cpu[bi, hi, downstart:downend, j] = inf_value else: - m[bi, hi, downstart:, j] = -np.inf + m_cpu[bi, hi, downstart:, j] = inf_value + if causal: - # from flash-attention 2.1 and in flash-attention 3, If seqlen_q != seqlen_k and causal=True, - # the causal mask is aligned to the bottom right corner of the attention matrix, - # instead of the top-left corner. - # See: https://github.com/Dao-AILab/flash-attention?tab=readme-ov-file#21-change-behavior-of-causal-flag - m[bi, hi, :max(0, j - (seqlen_k - seqlen_q)), j] = -np.inf + # FlashAttention Causal 逻辑 + row_limit = max(0, j - (seqlen_k - seqlen_q)) + m_cpu[bi, hi, :row_limit, j] = inf_value else: if has_end: - upstart = startend_row_indices[bi, hi, j, 2] - upend = startend_row_indices[bi, hi, j, 3] - m[bi, hi, upstart:upend, j] = -np.inf + upstart = startend_cpu[bi, hi, j, 2] + upend = startend_cpu[bi, hi, j, 3] + m_cpu[bi, hi, upstart:upend, j] = inf_value else: - upend = startend_row_indices[bi, hi, j, 1] - m[bi, hi, :upend, j] = -np.inf - m = paddle.repeat_interleave(x=m, repeats=nheads // num_head, axis=1) + upend = startend_cpu[bi, hi, j, 1] + m_cpu[bi, hi, :upend, j] = inf_value + + # 4. 最后一次性转回 GPU Tensor + m = torch.tensor(m_cpu, dtype=dtype, device=device) + m = torch.repeat_interleave(m, repeats=nheads // num_head, dim=1) return m def generate_none_mask(batch_size, seqlen_q, seqlen_k, h, causal=True): return None, causal def generate_sliding_window_mask(batch_size, seqlen_q, seqlen_k, h, window_size=None): - if window_size == None: + if window_size is None: window_size = 1024 if seqlen_k != 8192: window_size = int(window_size * (seqlen_k / 8192)) print(f"{seqlen_k=}, auto setting window_size to {window_size}") - startend_row_indices = paddle.arange( - window_size, seqlen_k + window_size, dtype="int32" - ).reshape((1, 1, seqlen_k, 1)) - startend_row_indices = paddle.clip( + # 生成从 window_size 开始的索引 + startend_row_indices = torch.arange( + window_size, seqlen_k + window_size, dtype=torch.int32 + ).reshape(1, 1, seqlen_k, 1) + + # paddle.clip -> torch.clamp + startend_row_indices = torch.clamp( startend_row_indices, max=seqlen_q - ).repeat_interleave(batch_size, 0) + ).repeat_interleave(batch_size, dim=0) - causal=True + causal = True return startend_row_indices, causal def generate_causal_document_mask(batch_size, seqlen_q, seqlen_k, h, doc_seqlens=None): # TODO: this seems buggy, to be fixed - if doc_seqlens == None: + if doc_seqlens is None: doc_seqlens = [2538, 1742, 3213] if seqlen_k != 8192: doc_seqlens = [int(doc_seqlen * (seqlen_k / 8192)) for doc_seqlen in doc_seqlens] @@ -69,15 +91,21 @@ def generate_causal_document_mask(batch_size, seqlen_q, seqlen_k, h, doc_seqlens seq_cusums = np.cumsum(doc_seqlens) startend_row_indices = np.repeat(seq_cusums, doc_seqlens) - startend_row_indices = paddle.to_tensor(startend_row_indices, dtype=paddle.int32).reshape((1, 1, seqlen_k, 1)).repeat_interleave(batch_size, 0) - startend_row_indices = paddle.clip(startend_row_indices, max=seqlen_q) + # Paddle: paddle.to_tensor(...).reshape(...).repeat_interleave(...) + # PyTorch: torch.tensor(...).reshape(...).repeat_interleave(...) + startend_row_indices = torch.tensor( + startend_row_indices, dtype=torch.int32 + ).reshape(1, 1, seqlen_k, 1).repeat_interleave(batch_size, dim=0) + + # paddle.clip -> torch.clamp + startend_row_indices = torch.clamp(startend_row_indices, max=seqlen_q) causal = True return startend_row_indices, causal def generate_document_mask(batch_size, seqlen_q, seqlen_k, h, doc_seqlens=None): # TODO: this seems buggy, to be fixed - if doc_seqlens == None: + if doc_seqlens is None : doc_seqlens = [2538, 1742, 3213] if seqlen_k != 8192: doc_seqlens = [int(doc_seqlen * (seqlen_k / 8192)) for doc_seqlen in doc_seqlens] @@ -93,7 +121,7 @@ def generate_document_mask(batch_size, seqlen_q, seqlen_k, h, doc_seqlens=None): cur_len_so_far = doc_seqlens[0] for i in range(len(doc_seqlens)): down_left_row_indices.extend([cur_len_so_far] * doc_seqlens[i]) - if i < len(doc_seqlens) -1: + if i < len(doc_seqlens) - 1: cur_len_so_far += doc_seqlens[i+1] if padding > 0: down_left_row_indices.extend([cur_len_so_far] * padding) @@ -101,21 +129,33 @@ def generate_document_mask(batch_size, seqlen_q, seqlen_k, h, doc_seqlens=None): cur_len_so_far = 0 for i in range(len(doc_seqlens)): up_right_row_indices.extend([cur_len_so_far] * doc_seqlens[i]) - if i < len(doc_seqlens) -1: + if i < len(doc_seqlens) - 1: cur_len_so_far += doc_seqlens[i+1] if padding > 0: up_right_row_indices.extend([cur_len_so_far] * padding) - down_left_row_indices = paddle.to_tensor(down_left_row_indices, dtype=paddle.int32).reshape((1, 1, seqlen_k, 1)).repeat_interleave(batch_size, 0) - up_right_row_indices = paddle.to_tensor(up_right_row_indices, dtype=paddle.int32).reshape((1, 1, seqlen_k, 1)).repeat_interleave(batch_size, 0) - startend_row_indices = paddle.concat([down_left_row_indices, up_right_row_indices], axis=-1) - startend_row_indices = paddle.clip(startend_row_indices, max=seqlen_q) + # Paddle -> PyTorch 转换 + down_left_row_indices = torch.tensor( + down_left_row_indices, dtype=torch.int32 + ).reshape(1, 1, seqlen_k, 1).repeat_interleave(batch_size, dim=0) + + up_right_row_indices = torch.tensor( + up_right_row_indices, dtype=torch.int32 + ).reshape(1, 1, seqlen_k, 1).repeat_interleave(batch_size, dim=0) + + # paddle.concat -> torch.cat + startend_row_indices = torch.cat( + [down_left_row_indices, up_right_row_indices], dim=-1 + ) + + # paddle.clip -> torch.clamp + startend_row_indices = torch.clamp(startend_row_indices, max=seqlen_q) causal = False return startend_row_indices, causal def generate_share_question_mask(batch_size, seqlen_q, seqlen_k, h, doc_seqlens=None): - if doc_seqlens == None: + if doc_seqlens is None: doc_seqlens = [2538, 1742, 3213] if seqlen_k != 8192: doc_seqlens = [int(doc_seqlen * (seqlen_k / 8192)) for doc_seqlen in doc_seqlens] @@ -129,7 +169,7 @@ def generate_share_question_mask(batch_size, seqlen_q, seqlen_k, h, doc_seqlens= assert len(doc_seqlens) >= 3 padding = seqlen_k - total_seqlen - #startend_row_indices = [S] * doc_seq_lens[0] + # startend_row_indices = [S] * doc_seq_lens[0] startend_row_indices = [total_seqlen] * doc_seqlens[0] cur_len_so_far = doc_seqlens[0] @@ -140,14 +180,19 @@ def generate_share_question_mask(batch_size, seqlen_q, seqlen_k, h, doc_seqlens= if padding > 0: startend_row_indices.extend([cur_len_so_far] * padding) - startend_row_indices = paddle.to_tensor(startend_row_indices, dtype=paddle.int32).reshape((1, 1, seqlen_k, 1)).repeat_interleave(batch_size, 0) - startend_row_indices = paddle.clip(startend_row_indices, max=seqlen_q) + # Paddle -> PyTorch 转换 + startend_row_indices = torch.tensor( + startend_row_indices, dtype=torch.int32 + ).reshape(1, 1, seqlen_k, 1).repeat_interleave(batch_size, dim=0) + + # paddle.clip -> torch.clamp + startend_row_indices = torch.clamp(startend_row_indices, max=seqlen_q) causal = True return startend_row_indices, causal def generate_global_sliding_window_mask(batch_size, seqlen_q, seqlen_k, h, global_token=16, window_size=None): - if window_size == None: + if window_size is None: window_size = (512, 512) if seqlen_k != 8192: window_size = tuple(int(ws * (seqlen_k / 8192)) for ws in window_size) @@ -155,38 +200,46 @@ def generate_global_sliding_window_mask(batch_size, seqlen_q, seqlen_k, h, globa assert len(window_size) == 2 left_window_size, right_window_size = window_size - down_left_start_row_indices = [] - down_left_end_row_indices = [] - up_right_start_row_indices = [] - up_right_end_row_indices = [] - - down_left_start_row_indices = paddle.arange( - left_window_size + 1, seqlen_k + left_window_size + 1, dtype="int32" - ).clip(max=seqlen_q) + # 1. Down Left Start + down_left_start_row_indices = torch.arange( + left_window_size + 1, seqlen_k + left_window_size + 1, dtype=torch.int32 + ).clamp(max=seqlen_q) + down_left_start_row_indices[:global_token] = seqlen_q - down_left_start_row_indices = down_left_start_row_indices.reshape((1, 1, seqlen_k, 1)).repeat_interleave(batch_size, 0) + down_left_start_row_indices = down_left_start_row_indices.reshape(1, 1, seqlen_k, 1).repeat_interleave(batch_size, dim=0) - down_left_end_row_indices = paddle.full([seqlen_k], seqlen_q, dtype="int32").reshape((1, 1, seqlen_k, 1)).repeat_interleave(batch_size, 0) + # 2. Down Left End + down_left_end_row_indices = torch.full( + (seqlen_k,), seqlen_q, dtype=torch.int32 + ).reshape(1, 1, seqlen_k, 1).repeat_interleave(batch_size, dim=0) - up_right_start_row_indices = paddle.full([seqlen_k], global_token, dtype="int32") + # 3. Up Right Start + up_right_start_row_indices = torch.full( + (seqlen_k,), global_token, dtype=torch.int32 + ) up_right_start_row_indices[:global_token+right_window_size+1] = 0 - up_right_start_row_indices = up_right_start_row_indices.reshape((1, 1, seqlen_k, 1)).repeat_interleave(batch_size, 0) + up_right_start_row_indices = up_right_start_row_indices.reshape(1, 1, seqlen_k, 1).repeat_interleave(batch_size, dim=0) - up_right_end_row_indices = paddle.arange( - -right_window_size, seqlen_k - right_window_size, dtype="int32" + # 4. Up Right End + up_right_end_row_indices = torch.arange( + -right_window_size, seqlen_k - right_window_size, dtype=torch.int32 ) up_right_end_row_indices[:global_token+right_window_size+1] = 0 - up_right_end_row_indices = up_right_end_row_indices.reshape((1, 1, seqlen_k, 1)).repeat_interleave(batch_size, 0) + up_right_end_row_indices = up_right_end_row_indices.reshape(1, 1, seqlen_k, 1).repeat_interleave(batch_size, dim=0) - startend_row_indices = paddle.concat([down_left_start_row_indices, down_left_end_row_indices, up_right_start_row_indices, up_right_end_row_indices], axis=-1) - startend_row_indices = paddle.clip(startend_row_indices, max=seqlen_q) + # Concatenate & Clamp + startend_row_indices = torch.cat( + [down_left_start_row_indices, down_left_end_row_indices, up_right_start_row_indices, up_right_end_row_indices], + dim=-1 + ) + startend_row_indices = torch.clamp(startend_row_indices, max=seqlen_q) causal = False return startend_row_indices, causal def generate_causal_blockwise_mask(batch_size, seqlen_q, seqlen_k, h, doc_seqlens=None): # TODO: this seems buggy, to be fixed - if doc_seqlens == None: + if doc_seqlens is None: doc_seqlens = [2538, 1742, 3213] if seqlen_k != 8192: doc_seqlens = [int(doc_seqlen * (seqlen_k / 8192)) for doc_seqlen in doc_seqlens] @@ -204,14 +257,22 @@ def generate_causal_blockwise_mask(batch_size, seqlen_q, seqlen_k, h, doc_seqlen cur_len_so_far += doc_seqlens[i+1] if padding > 0: start_row_indices.extend([cur_len_so_far] * padding) - start_row_indices = paddle.to_tensor(start_row_indices, dtype=paddle.int32).reshape((1, 1, seqlen_k, 1)).repeat_interleave(batch_size, 0) + + start_row_indices = torch.tensor( + start_row_indices, dtype=torch.int32 + ).reshape(1, 1, seqlen_k, 1).repeat_interleave(batch_size, dim=0) seq_cusums = np.cumsum(doc_seqlens) end_row_indices = [seq_cusums[-2]] * seq_cusums[-2] + [seq_cusums[-1]] * doc_seqlens[-1] + [seqlen_k] * padding - end_row_indices = paddle.to_tensor(end_row_indices, dtype=paddle.int32).reshape((1, 1, seqlen_k, 1)).repeat_interleave(batch_size, 0) + + end_row_indices = torch.tensor( + end_row_indices, dtype=torch.int32 + ).reshape(1, 1, seqlen_k, 1).repeat_interleave(batch_size, dim=0) - startend_row_indices = paddle.concat([start_row_indices, end_row_indices], axis=-1) - startend_row_indices = paddle.clip(startend_row_indices, max=seqlen_q) + startend_row_indices = torch.cat( + [start_row_indices, end_row_indices], dim=-1 + ) + startend_row_indices = torch.clamp(startend_row_indices, max=seqlen_q) causal = True return startend_row_indices, causal @@ -220,8 +281,8 @@ def generate_prefix_lm_document_mask(batch_size, seqlen_q, seqlen_k, h, doc_seql """ tuple(prefix_length, seq_length) """ - if doc_seqlens == None: - doc_seqlens=[(1024, 2538), (1742, 1742), (512, 3213)] + if doc_seqlens is None: + doc_seqlens = [(1024, 2538), (1742, 1742), (512, 3213)] if seqlen_k != 8192: scale = seqlen_k / 8192 doc_seqlens = [tuple(int(v * scale) for v in pair) for pair in doc_seqlens] @@ -234,6 +295,7 @@ def generate_prefix_lm_document_mask(batch_size, seqlen_q, seqlen_k, h, doc_seql assert total_seqlen <= seqlen_k padding = seqlen_k - total_seqlen + # 1. Down Left Logic down_left_row_indices = [] cur_len_so_far = doc_seqlens[0][1] for i in range(len(doc_seqlens)): @@ -242,8 +304,12 @@ def generate_prefix_lm_document_mask(batch_size, seqlen_q, seqlen_k, h, doc_seql cur_len_so_far += doc_seqlens[i+1][1] if padding > 0: down_left_row_indices.extend([cur_len_so_far] * padding) - down_left_row_indices = paddle.to_tensor(down_left_row_indices, dtype=paddle.int32).reshape((1, 1, seqlen_k, 1)).repeat_interleave(batch_size, 0) + + down_left_row_indices = torch.tensor( + down_left_row_indices, dtype=torch.int32 + ).reshape(1, 1, seqlen_k, 1).repeat_interleave(batch_size, dim=0) + # 2. Up Right Logic up_right_row_indices = [] cur_len_so_far = 0 for prefix_length, seq_length in doc_seqlens: @@ -251,11 +317,17 @@ def generate_prefix_lm_document_mask(batch_size, seqlen_q, seqlen_k, h, doc_seql cur_len_so_far += seq_length if padding > 0: up_right_row_indices.extend([total_seqlen] * padding) - up_right_row_indices = paddle.to_tensor(up_right_row_indices, dtype=paddle.int32).reshape((1, 1, seqlen_k, 1)).repeat_interleave(batch_size, 0) + + up_right_row_indices = torch.tensor( + up_right_row_indices, dtype=torch.int32 + ).reshape(1, 1, seqlen_k, 1).repeat_interleave(batch_size, dim=0) - startend_row_indices = paddle.concat([down_left_row_indices, up_right_row_indices], axis=-1) + # 3. Concat & Clamp + startend_row_indices = torch.cat( + [down_left_row_indices, up_right_row_indices], dim=-1 + ) - startend_row_indices = paddle.clip(startend_row_indices, max=seqlen_q) + startend_row_indices = torch.clamp(startend_row_indices, max=seqlen_q) causal = False return startend_row_indices, causal @@ -264,16 +336,34 @@ def generate_prefix_lm_causal_mask(batch_size, seqlen_q, seqlen_k, h, prefix_len """ tuple(prefix_length, seq_length) """ - if prefix_length == None: + if prefix_length is None: prefix_length = 1024 if seqlen_k != 8192: prefix_length = int(prefix_length * (seqlen_k / 8192)) print(f"{seqlen_k=}, auto setting doc_seqlens to {prefix_length}") assert prefix_length <= seqlen_k - down_left_row_indices = paddle.full([seqlen_k], seqlen_k, dtype=paddle.int32).reshape((1, 1, seqlen_k, 1)).repeat_interleave(batch_size, 0) - up_right_row_indices = paddle.to_tensor([0] * prefix_length + list(range(prefix_length, seqlen_k)), dtype=paddle.int32).reshape((1, 1, seqlen_k, 1)).repeat_interleave(batch_size, 0) - startend_row_indices = paddle.concat([down_left_row_indices, up_right_row_indices], axis=-1) - startend_row_indices = paddle.clip(startend_row_indices, max=seqlen_q) + + # 1. Down Left Logic + # paddle.full -> torch.full + down_left_row_indices = torch.full( + (seqlen_k,), seqlen_k, dtype=torch.int32 + ).reshape(1, 1, seqlen_k, 1).repeat_interleave(batch_size, dim=0) + + # 2. Up Right Logic + # paddle.to_tensor -> torch.tensor + up_right_row_indices = torch.tensor( + [0] * prefix_length + list(range(prefix_length, seqlen_k)), + dtype=torch.int32 + ).reshape(1, 1, seqlen_k, 1).repeat_interleave(batch_size, dim=0) + + # 3. Concat & Clamp + # paddle.concat -> torch.cat + startend_row_indices = torch.cat( + [down_left_row_indices, up_right_row_indices], dim=-1 + ) + + # paddle.clip -> torch.clamp + startend_row_indices = torch.clamp(startend_row_indices, max=seqlen_q) causal = False return startend_row_indices, causal @@ -282,15 +372,18 @@ def generate_qk_sparse_mask(batch_size, seqlen_q, seqlen_k, h, maskout_pair=None """ tuple(offset, maskout_len) """ - if maskout_pair == None: - maskout_pair=[(1024, 538), (2358, 1700)] + if maskout_pair is None: + maskout_pair = [(1024, 538), (2358, 1700)] if seqlen_k != 8192: scale = seqlen_k / 8192 maskout_pair = [tuple(int(v * scale) for v in pair) for pair in maskout_pair] print(f"{seqlen_k=}, auto setting maskout_pair to {maskout_pair}") + start_row_indices = [] end_row_indices = [] last_offset = 0 + + # 纯 Python 逻辑保持不变 for offset, maskout_len in maskout_pair: assert offset > last_offset start_row_indices.extend([seqlen_k]*(offset-last_offset)) @@ -301,21 +394,37 @@ def generate_qk_sparse_mask(batch_size, seqlen_q, seqlen_k, h, maskout_pair=None last_offset = offset + maskout_len - last_offset <= seqlen_k + # 注意:原代码这里只是一个比较表达式,没有 assert,也没有赋值,可能是个笔误或者无效语句 + # last_offset <= seqlen_k + # 如果原意是断言,建议加上 assert: + # assert last_offset <= seqlen_k + start_row_indices.extend([seqlen_k]*(seqlen_k-last_offset)) end_row_indices.extend([seqlen_k]*(seqlen_k-last_offset)) - start_row_indices = paddle.to_tensor(start_row_indices, dtype=paddle.int32).reshape((1, 1, seqlen_k, 1)).repeat_interleave(batch_size, 0) - end_row_indices = paddle.to_tensor(end_row_indices, dtype=paddle.int32).reshape((1, 1, seqlen_k, 1)).repeat_interleave(batch_size, 0) - startend_row_indices = paddle.concat([start_row_indices, end_row_indices], axis=-1) - startend_row_indices = paddle.clip(startend_row_indices, max=seqlen_q) + # Tensor 转换 + start_row_indices = torch.tensor( + start_row_indices, dtype=torch.int32 + ).reshape(1, 1, seqlen_k, 1).repeat_interleave(batch_size, dim=0) + + end_row_indices = torch.tensor( + end_row_indices, dtype=torch.int32 + ).reshape(1, 1, seqlen_k, 1).repeat_interleave(batch_size, dim=0) + + # paddle.concat -> torch.cat + startend_row_indices = torch.cat( + [start_row_indices, end_row_indices], dim=-1 + ) + + # paddle.clip -> torch.clamp + startend_row_indices = torch.clamp(startend_row_indices, max=seqlen_q) causal = True return startend_row_indices, causal def generate_random_eviction_mask(batch_size, seqlen_q, seqlen_k, h, start_row=None): # np.random.seed(0) - if start_row == None: + if start_row is None: start_row = 4096 if seqlen_k != 8192: start_row = int(start_row * (seqlen_k / 8192)) @@ -329,7 +438,15 @@ def generate_random_eviction_mask(batch_size, seqlen_q, seqlen_k, h, start_row=N mask_pos = np.concatenate([mask_pos[mask_pos < index - 1], mask_pos[mask_pos >= index - 1]]) start_rows[mask_pos] = index start_rows_list.append(start_rows) - startend_row_indices = paddle.to_tensor(start_rows_list, dtype=paddle.int32).reshape((batch_size, h, seqlen_k, 1)) - startend_row_indices = paddle.clip(startend_row_indices, max=seqlen_q) + + # Paddle: paddle.to_tensor(list).reshape(...) + # PyTorch: 建议先转 np.array 再转 tensor 以避免潜在的效率问题 + startend_row_indices = torch.tensor( + np.array(start_rows_list), dtype=torch.int32 + ).reshape(batch_size, h, seqlen_k, 1) + + # paddle.clip -> torch.clamp + startend_row_indices = torch.clamp(startend_row_indices, max=seqlen_q) + causal = True return startend_row_indices, causal diff --git a/test_flashmask.py b/test_flashmask.py index ab98115..8fa1bb3 100644 --- a/test_flashmask.py +++ b/test_flashmask.py @@ -3,8 +3,21 @@ import itertools import pytest from einops import rearrange, repeat -import paddle -from paddle.nn.functional.flash_attention import flashmask_attention +# import paddle + +import sys +current_dir = os.path.dirname(os.path.abspath(__file__)) +parent_dir = os.path.dirname(current_dir) +if parent_dir not in sys.path: + sys.path.append(parent_dir) +try: + from flashmask_interface import flashmask_attention +except Exception as e: + import traceback + traceback.print_exc() + raise + +import torch from generate_startend_row_indices import ( startend_row_indices_to_attn_bias, generate_none_mask, @@ -63,7 +76,8 @@ def generate_shapes(): batch_size, seqlen_q, seqlen_k, nheads, nheads_kv, nheads_startend_row_indices ) -@pytest.mark.parametrize("dtype", [paddle.bfloat16]) +# @pytest.mark.parametrize("dtype", [paddle.bfloat16]) +@pytest.mark.parametrize("dtype", [torch.bfloat16]) # 使用 torch.bfloat16 @pytest.mark.parametrize("fa_version", [3]) @pytest.mark.parametrize("d, dv", [ @@ -80,8 +94,8 @@ def generate_shapes(): @pytest.mark.parametrize( "gen_startend_row_indices", [ - partial(generate_none_mask, causal=False), # full - partial(generate_none_mask, causal=True), # causal + # partial(generate_none_mask, causal=False), # full + # partial(generate_none_mask, causal=True), # causal partial(generate_sliding_window_mask), # sliding window partial(generate_causal_document_mask), # causal document mask partial(generate_document_mask), # document mask @@ -97,30 +111,38 @@ def generate_shapes(): def test_flashmask( batch_size, seqlen_q, seqlen_k, nheads, nheads_kv, d, dv, nheads_startend_row_indices, fa_version, dtype, gen_startend_row_indices, softcap=0.0 ): - paddle.seed(2024) + torch.manual_seed(2024) + # paddle.seed(2024) assert nheads % nheads_kv == 0 - q_ref = paddle.randn(shape=[batch_size, seqlen_q, nheads, d], dtype=dtype) - k_ref = paddle.randn(shape=[batch_size, seqlen_k, nheads_kv, d], dtype=dtype) - v_ref = paddle.randn(shape=[batch_size, seqlen_k, nheads_kv, dv], dtype=dtype) - q_ref.stop_gradient = False - k_ref.stop_gradient = False - v_ref.stop_gradient = False + q_ref = torch.randn(batch_size, seqlen_q, nheads, d, dtype=dtype, device='cuda', requires_grad=True) + k_ref = torch.randn(batch_size, seqlen_k, nheads_kv, d, dtype=dtype, device='cuda', requires_grad=True) + v_ref = torch.randn(batch_size, seqlen_k, nheads_kv, dv, dtype=dtype, device='cuda', requires_grad=True) - q_bf16, k_bf16, v_bf16 = [x.detach().clone() for x in (q_ref, k_ref, v_ref)] - q_bf16.stop_gradient = False - k_bf16.stop_gradient = False - v_bf16.stop_gradient = False + q_bf16 = q_ref.detach().clone().requires_grad_(True) + k_bf16 = k_ref.detach().clone().requires_grad_(True) + v_bf16 = v_ref.detach().clone().requires_grad_(True) - q, k, v = [x.detach().clone() for x in (q_ref, k_ref, v_ref)] - - q.stop_gradient = False - k.stop_gradient = False - v.stop_gradient = False + q = q_ref.detach().clone().requires_grad_(True) + k = k_ref.detach().clone().requires_grad_(True) + v = v_ref.detach().clone().requires_grad_(True) startend_row_indices, causal = gen_startend_row_indices(batch_size, seqlen_q, seqlen_k, nheads_startend_row_indices) + if startend_row_indices is None: + pytest.skip("Skipping because startend_row_indices is None") + + if startend_row_indices is not None: + if not isinstance(startend_row_indices, torch.Tensor): + # 如果是 numpy 或 paddle tensor (先转numpy) + if hasattr(startend_row_indices, 'numpy'): + startend_row_indices = torch.tensor(startend_row_indices.numpy(), device='cuda', dtype=torch.int32) + else: + startend_row_indices = torch.tensor(startend_row_indices, device='cuda', dtype=torch.int32) + else: + startend_row_indices = startend_row_indices.to('cuda', dtype=torch.int32) + if startend_row_indices is None and causal and d in (80, 192): pytest.skip(f"Skipping because running headdim {d} with flash_attn in causal mask") @@ -144,22 +166,16 @@ def test_flashmask( reorder_ops=True ) - # # Numerical error if we just do any arithmetic on out_ref fwd_atol = 2 * (out_ref + 0.3 - 0.3 - out_ref).abs().max().item() assert softcap == 0.0 rtol = 2 if softcap == 0.0 else 3 - print(f"Paddle naive bf16 Output max diff: {(out_bf16 - out_ref).abs().max().item()}") - print(f"Paddle naive bf16 Output mean diff: {(out_bf16 - out_ref).abs().mean().item()}") + print(f"Torch naive bf16 Output max diff: {(out_bf16 - out_ref).abs().max().item()}") + print(f"Torch naive bf16 Output mean diff: {(out_bf16 - out_ref).abs().mean().item()}") - if fa_version == 2: - paddle.set_flags({'FLAGS_flash_attn_version': 2}) - elif fa_version == 3: - paddle.set_flags({'FLAGS_flash_attn_version': 3}) - else: - raise ValueError( - f"Invalid flash attention version: {fa_version}" - ) + # 确保 startend_row_indices 在 CUDA 上且为 int32 + if isinstance(startend_row_indices, torch.Tensor): + startend_row_indices = startend_row_indices.to('cuda', dtype=torch.int32) out, lse = flashmask_attention( q, @@ -171,16 +187,11 @@ def test_flashmask( ) print(f"flashmask Output max diff: {(out - out_ref).abs().max().item()}") print(f"flashmask Output mean diff: {(out - out_ref).abs().mean().item()}") - # if not causal: - # print(f"LSE max diff: {(lse - lse_ref).abs().max().item()}") - # breakpoint() - - # Check that FlashAttention's numerical error is at most twice the numerical error - # of a Pytorch implementation. assert (out - out_ref).abs().max().item() <= rtol * (out_bf16 - out_ref).abs().max().item() + fwd_atol - g = paddle.randn(shape=out.shape, dtype=out.dtype) + # #Backward Check + g = torch.randn_like(out) out.backward(g) out_ref.backward(g) out_bf16.backward(g) @@ -192,12 +203,12 @@ def test_flashmask( print(f"flashmask dK mean diff: {(k.grad - k_ref.grad).abs().mean().item()}") print(f"flashmask dV mean diff: {(v.grad - v_ref.grad).abs().mean().item()}") - print(f"Paddle naive bf16 dQ max diff: {(q_bf16.grad - q_ref.grad).abs().max().item()}") - print(f"Paddle naive bf16 dK max diff: {(k_bf16.grad - k_ref.grad).abs().max().item()}") - print(f"Paddle naive bf16 dV max diff: {(v_bf16.grad - v_ref.grad).abs().max().item()}") - print(f"Paddle naive bf16 dQ mean diff: {(q_bf16.grad - q_ref.grad).abs().mean().item()}") - print(f"Paddle naive bf16 dK mean diff: {(k_bf16.grad - k_ref.grad).abs().mean().item()}") - print(f"Paddle naive bf16 dV mean diff: {(v_bf16.grad - v_ref.grad).abs().mean().item()}") + print(f"Torch naive bf16 dQ max diff: {(q_bf16.grad - q_ref.grad).abs().max().item()}") + print(f"Torch naive bf16 dK max diff: {(k_bf16.grad - k_ref.grad).abs().max().item()}") + print(f"Torch naive bf16 dV max diff: {(v_bf16.grad - v_ref.grad).abs().max().item()}") + print(f"Torch naive bf16 dQ mean diff: {(q_bf16.grad - q_ref.grad).abs().mean().item()}") + print(f"Torch naive bf16 dK mean diff: {(k_bf16.grad - k_ref.grad).abs().mean().item()}") + print(f"Torch naive bf16 dV mean diff: {(v_bf16.grad - v_ref.grad).abs().mean().item()}") dq_atol = 2 * (q_ref.grad + 0.3 - 0.3 - q_ref.grad).abs().max().item() + (0 if softcap == 0 else 3e-4) assert (q.grad - q_ref.grad).abs().max().item() <= rtol * (q_bf16.grad - q_ref.grad).abs().max().item() + dq_atol diff --git a/test_util.py b/test_util.py index f4120bd..37e592e 100644 --- a/test_util.py +++ b/test_util.py @@ -1,6 +1,7 @@ import math from einops import repeat, rearrange -import paddle +# import paddle +import torch from einops import rearrange, repeat import numpy as np @@ -13,13 +14,21 @@ def construct_local_mask( query_padding_mask=None, key_padding_mask=None, key_leftpad=None, + device=None, ): - row_idx = rearrange(paddle.arange(seqlen_q, dtype=paddle.int64), "s -> s 1") - col_idx = paddle.arange(seqlen_k, dtype=paddle.int64) + if device is None: + device = 'cuda' if torch.cuda.is_available() else 'cpu' + + # row_idx = rearrange(paddle.arange(seqlen_q, dtype=paddle.int64), "s -> s 1") + # col_idx = paddle.arange(seqlen_k, dtype=paddle.int64) + + row_idx = rearrange(torch.arange(seqlen_q, dtype=torch.int64, device=device), "s -> s 1") + col_idx = torch.arange(seqlen_k, dtype=torch.int64, device=device) if key_leftpad is not None: key_leftpad = rearrange(key_leftpad, "b -> b 1 1 1") col_idx = repeat(col_idx, "s -> b 1 1 s", b=key_leftpad.shape[0]) - col_idx = paddle.where(col_idx >= key_leftpad, col_idx - key_leftpad, 2**32) + val_inf = torch.tensor(2**32, device=device, dtype=torch.int64) + col_idx = torch.where(col_idx >= key_leftpad, col_idx - key_leftpad, 2**32) sk = ( seqlen_k if key_padding_mask is None @@ -33,10 +42,10 @@ def construct_local_mask( if window_size[0] < 0: return col_idx > row_idx + sk - sq + window_size[1] else: - sk = paddle.full_like(col_idx, seqlen_k) if key_padding_mask is None else sk - return paddle.logical_or( - col_idx > paddle.minimum(row_idx + sk - sq + window_size[1], sk), - paddle.logical_and(col_idx < row_idx + sk - sq - window_size[0], col_idx >= sink_token_length), + sk = torch.full_like(col_idx, seqlen_k) if key_padding_mask is None else sk + return torch.logical_or( + col_idx > torch.minimum(row_idx + sk - sq + window_size[1], sk), + torch.logical_and(col_idx < row_idx + sk - sq - window_size[0], col_idx >= sink_token_length), ) def attention_ref( @@ -85,54 +94,63 @@ def attention_ref( window_size = (window_size[0], 0) dtype_og = q.dtype if upcast: - q = paddle.cast(q, paddle.float32) - k = paddle.cast(k, paddle.float32) - v = paddle.cast(v, paddle.float32) + q = q.to(torch.float32) + k = k.to(torch.float32) + v = v.to(torch.float32) if qv is not None: - qv = paddle.cast(qv, paddle.float32) + qv = qv.to(torch.float32) if q_descale is not None: - assert False - q_descale = repeat(q_descale, "b h -> b 1 (h g) 1", g=q.shape[2] // k.shape[2]) - q = (q.cast(paddle.float32) * q_descale).cast(q.dtype) - qv = (qv.cast(paddle.float32) * q_descale).cast(qv.dtype) if qv is not None else None + assert False + # Paddle: repeat(q_descale, "b h -> b 1 (h g) 1", g=q.shape[2] // k.shape[2]) + g = q.shape[2] // k.shape[2] + q_descale = repeat(q_descale, "b h -> b 1 (h g) 1", g=g) + q = (q.to(torch.float32) * q_descale).to(q.dtype) + if qv is not None: + qv = (qv.to(torch.float32) * q_descale).to(qv.dtype) if k_descale is not None: assert False - k = (k.cast(paddle.float32) * rearrange(k_descale, "b h -> b 1 h 1")).cast(k.dtype) + k_descale_r = rearrange(k_descale, "b h -> b 1 h 1") + k = (k.to(torch.float32) * k_descale_r).to(k.dtype) if v_descale is not None: assert False - v = (v.cast(paddle.float32) * rearrange(v_descale, "b h -> b 1 h 1")).cast(v.dtype) + v_descale_r = rearrange(v_descale, "b h -> b 1 h 1") + v = (v.to(torch.float32) * v_descale_r).to(v.dtype) seqlen_q, seqlen_k = q.shape[1], k.shape[1] # (batch_size, seqlen, nheads, head_dim) -> (batch_size, nheads, seqlen, head_dim) - q = paddle.transpose(q, [0, 2, 1, 3]) - k = paddle.transpose(k, [0, 2, 1, 3]) - v = paddle.transpose(v, [0, 2, 1, 3]) + # q = paddle.transpose(q, [0, 2, 1, 3]) + # k = paddle.transpose(k, [0, 2, 1, 3]) + # v = paddle.transpose(v, [0, 2, 1, 3]) + q = q.permute(0, 2, 1, 3) + k = k.permute(0, 2, 1, 3) + v = v.permute(0, 2, 1, 3) k = repeat(k, "b h s d -> b (h g) s d", g=q.shape[1] // k.shape[1]) v = repeat(v, "b h s d -> b (h g) s d", g=q.shape[1] // v.shape[1]) if attn_bias is not None: - attn_bias = repeat(attn_bias, "b h s d -> b (h g) s d ", g=q.shape[1] // attn_bias.shape[1]) + attn_bias = attn_bias.to(q.device) + if attn_bias.ndim == 4 and attn_bias.shape[1] != q.shape[1]: + attn_bias = repeat(attn_bias, "b h s d -> b (h g) s d ", g=q.shape[1] // attn_bias.shape[1]) d = q.shape[-1] dv = v.shape[-1] softmax_scale = 1.0 / math.sqrt(d if qv is None else d + dv) if not reorder_ops: - scores = paddle.matmul(q * softmax_scale, k, transpose_y=True) + scores = torch.matmul(q * softmax_scale, k.transpose(-2, -1)) else: - scores = paddle.matmul(q, k * softmax_scale, transpose_y=True) + scores = torch.matmul(q, k.transpose(-2, -1) * softmax_scale) if qv is not None: assert False - scores = scores + paddle.matmul(qv * softmax_scale, v, transpose_y=True) + scores = scores + torch.matmul(qv * softmax_scale, v.transpose(-2, -1)) if softcap > 0: - assert False - scores = paddle.tanh(scores / softcap) * softcap + scores = torch.tanh(scores / softcap) * softcap if key_padding_mask is not None: assert False @@ -148,51 +166,53 @@ def attention_ref( query_padding_mask, key_padding_mask, key_leftpad=key_leftpad, + device=q.device ) if attention_chunk > 0: assert False - chunk_mask = construct_chunk_mask( - seqlen_q, - seqlen_k, - attention_chunk, - query_padding_mask, - key_padding_mask, - key_leftpad=key_leftpad, - device=q.device, - ) - local_mask = paddle.logical_or(local_mask, chunk_mask) if local_mask is not None else chunk_mask + # chunk_mask = construct_chunk_mask( + # seqlen_q, + # seqlen_k, + # attention_chunk, + # query_padding_mask, + # key_padding_mask, + # key_leftpad=key_leftpad, + # device=q.device, + # ) + # local_mask = paddle.logical_or(local_mask, chunk_mask) if local_mask is not None else chunk_mask if local_mask is not None: scores.masked_fill_(local_mask, float("-inf")) if attn_bias is not None: - scores = scores + attn_bias.cast(paddle.float32) + scores = scores + attn_bias.to(torch.float32) # print("scores:", scores[0,0,0,:]) # when all values in a line of attn_bias are -inf, setting value in this line to a very small value # to prevend softmax giving nan output - all_inf_mask = (attn_bias == -np.inf).all(axis=-1, keepdim=True) - scores = paddle.where(all_inf_mask, paddle.full_like(scores, -1e9), scores) + all_inf_mask = (attn_bias == -float('inf')).all(dim=-1, keepdim=True) + scores = torch.where(all_inf_mask, torch.full_like(scores, -1e9), scores) - attention = paddle.nn.functional.softmax(scores, axis=-1).cast(v.dtype) + attention = torch.softmax(scores, dim=-1).to(v.dtype) if attn_bias is not None: # when all values in a line of attn_bias are -inf, we setting value in this line to a very small value # to prevend softmax giving nan output, however, after softmax, values in this line become 1/seqlen, # so setting them to 0 after softmax - attention = paddle.where(all_inf_mask, paddle.zeros_like(attention), attention) + attention = torch.where(all_inf_mask, torch.zeros_like(attention), attention) # We want to mask here so that the attention matrix doesn't have any NaNs # Otherwise we'll get NaN in dV if query_padding_mask is not None: assert False - attention = attention.masked_fill(rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0) + # attention = attention.masked_fill(rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0) # Without this we might get NaN in dv if key_padding_mask is not None: assert False - attention = attention.masked_fill(rearrange(~key_padding_mask, "b s -> b 1 1 s"), 0.0) + # attention = attention.masked_fill(rearrange(~key_padding_mask, "b s -> b 1 1 s"), 0.0) + # Some rows might be completely masked out so we fill them with zero instead of NaN if local_mask is not None: - attention = attention.masked_fill(paddle.all(local_mask, axis=-1, keepdim=True), 0.0) + attention = attention.masked_fill(torch.all(local_mask, dim=-1, keepdim=True), 0.0) dropout_scaling = 1.0 / (1 - dropout_p) # attention_drop = attention.masked_fill(~dropout_mask, 0.0) * dropout_scaling # output = paddle.matmul(attention_drop, v, transpose_y=True) @@ -202,53 +222,60 @@ def attention_ref( else: attention_drop = attention if intermediate_dtype is not None: - attention_drop = attention_drop.cast(intermediate_dtype).cast(attention_drop.dtype) - output = paddle.matmul(attention_drop, v * dropout_scaling) - output = paddle.transpose(output, [0, 2, 1, 3]) + attention_drop = attention_drop.to(intermediate_dtype).to(attention_drop.dtype) + output = torch.matmul(attention_drop, v * dropout_scaling) + output = output.permute(0, 2, 1, 3) # Back to (b, s, h, d) if query_padding_mask is not None: output.masked_fill_(rearrange(~query_padding_mask, "b s -> b s 1 1"), 0.0) - return output.cast(dtype=dtype_og), attention.cast(dtype=dtype_og) + return output.to(dtype_og), attention.to(dtype_og) #blockmask utils -def random_blockmask(shape, dtype='int32',is_causal=False, ref_q = None): +def random_blockmask(shape, dtype=torch.int32, is_causal=False, ref_q=None, device='cuda'): # 随机生成 0/1 mask - mask = paddle.randint(0, 2, shape, dtype=paddle.int32) + mask = torch.randint(0, 2, shape, dtype=dtype, device=device) B, S, Q, K = shape return mask def flashmask_to_densemask(startend_row_indices, seqlen_q, nheads, causal=True): if startend_row_indices is None: return None - bz, num_head, seqlen_k, bound_num = startend_row_indices.shape + if not isinstance(startend_row_indices, torch.Tensor): + startend_row_indices = torch.tensor(startend_row_indices, device='cuda') + startend_cpu = startend_row_indices.detach().cpu().numpy() + bz, num_head, seqlen_k, bound_num = startend_cpu.shape assert nheads % num_head == 0 - m = paddle.ones((bz, num_head, seqlen_q, seqlen_k), dtype=paddle.int32) + m_cpu = np.ones((bz, num_head, seqlen_q, seqlen_k), dtype=np.int32) has_end = (causal and bound_num == 2) or ((not causal) and bound_num == 4) for bi in range(bz): for hi in range(num_head): for j in range(seqlen_k): - downstart = startend_row_indices[bi, hi, j, 0] + downstart = startend_cpu[bi, hi, j, 0] if has_end: - downend = startend_row_indices[bi, hi, j, 1] - m[bi, hi, downstart:downend, j] = 0 + downend = startend_cpu[bi, hi, j, 1] + m_cpu[bi, hi, downstart:downend, j] = 0 else: - m[bi, hi, downstart:, j] = 0 + m_cpu[bi, hi, downstart:, j] = 0 if causal: # from flash-attention 2.1 and in flash-attention 3, If seqlen_q != seqlen_k and causal=True, # the causal mask is aligned to the bottom right corner of the attention matrix, # instead of the top-left corner. # See: https://github.com/Dao-AILab/flash-attention?tab=readme-ov-file#21-change-behavior-of-causal-flag - m[bi, hi, :max(0, j - (seqlen_k - seqlen_q)), j] = 0 + row_limit = max(0, j - (seqlen_k - seqlen_q)) + #有疑问 + m_cpu[bi, hi, :row_limit, j] = 0 else: if has_end: - upstart = startend_row_indices[bi, hi, j, 2] - upend = startend_row_indices[bi, hi, j, 3] - m[bi, hi, upstart:upend, j] = 0 + upstart = startend_cpu[bi, hi, j, 2] + upend = startend_cpu[bi, hi, j, 3] + m_cpu[bi, hi, upstart:upend, j] = 0 else: - upend = startend_row_indices[bi, hi, j, 1] - m[bi, hi, :upend, j] = 0 - m = paddle.repeat_interleave(x=m, repeats=nheads // num_head, axis=1) - m = m.astype(paddle.bool) + upend = startend_cpu[bi, hi, j, 1] + m_cpu[bi, hi, :upend, j] = 0 + device = startend_row_indices.device if startend_row_indices.is_cuda else 'cuda' + m = torch.tensor(m_cpu, device=device, dtype=torch.int32) + m = torch.repeat_interleave(m, repeats=nheads // num_head, dim=1) + m = m.to(torch.bool) return m def blockmask_to_densemask(blockmask, q_len, k_len, dtype, causal=True): @@ -270,8 +297,8 @@ def blockmask_to_densemask(blockmask, q_len, k_len, dtype, causal=True): block_k = 128 # 1. 展开到[bs, s, q_len, k_len] - densemask = blockmask.astype(dtype).repeat_interleave(block_q, axis=2).repeat_interleave(block_k, axis=3) + densemask = blockmask.to(dtype).repeat_interleave(block_q, dim=2).repeat_interleave(block_k, dim=3) densemask = densemask[:, :, :q_len, :k_len] # print(densemask) - return densemask.astype(paddle.bool) + return densemask.to(torch.bool)