diff --git a/benchmark_flashmask.py b/benchmark_flashmask.py index 548a69d..5f2283d 100644 --- a/benchmark_flashmask.py +++ b/benchmark_flashmask.py @@ -9,6 +9,8 @@ from flash_mask.cute.interface import flashmask_attention except (ImportError, ModuleNotFoundError): from paddle.nn.functional.flash_attention import flashmask_attention + +from flash_mask import flashmask_attention as flashmask_attention_interface import random import os from datetime import datetime @@ -117,7 +119,10 @@ def test_mask( HKV, D, dtype = 'bf16', + DV = None, ): + if DV is None: + DV = D if dtype == 'bf16': data_type = paddle.bfloat16 @@ -126,8 +131,8 @@ def test_mask( query = paddle.randn([B, S, H, D], dtype=data_type) key = paddle.randn([B, SKV, HKV, D], dtype=data_type) - value = paddle.randn([B, SKV, HKV, D], dtype=data_type) - gradOut = paddle.randn([B, S, H, D], dtype=data_type) + value = paddle.randn([B, SKV, HKV, DV], dtype=data_type) + gradOut = paddle.randn([B, S, H, DV], dtype=data_type) query.stop_gradient = False key.stop_gradient = False @@ -207,6 +212,72 @@ def flashmask_bwd(): return fwd_time_ms, bwd_time_ms, total_time_ms, fwd_flops, bwd_flops, total_flops, fwd_tflops, bwd_tflops, total_tflops, sparsity +def test_mask_varlen( + generate_mask_fn, + B, + S, + SKV, + H, + HKV, + D, + dtype = 'bf16', + DV = None, +): + if DV is None: + DV = D + + if dtype == 'bf16': + data_type = paddle.bfloat16 + else: + data_type = paddle.float16 + + query = paddle.randn([B, S, H, D], dtype=data_type) + key = paddle.randn([B, SKV, HKV, D], dtype=data_type) + value = paddle.randn([B, SKV, HKV, DV], dtype=data_type) + gradOut = paddle.randn([B, S, H, DV], dtype=data_type) + + query.stop_gradient = False + key.stop_gradient = False + value.stop_gradient = False + + startend_row_indices, causal = None, True + if generate_mask_fn is not None: + startend_row_indices, causal = generate_mask_fn(B, SKV, HKV, D) + + sparsity = flashmask_block_sparsity(causal, startend_row_indices, B, H, HKV, S, SKV) + 
density = 1.0 - sparsity + + def fwd_bwd(): + q = query.detach().clone() + k = key.detach().clone() + v = value.detach().clone() + q.stop_gradient = False + k.stop_gradient = False + v.stop_gradient = False + out, lse = flashmask_attention_interface( + q, k, v, + startend_row_indices=startend_row_indices, + causal=causal, + return_softmax_lse=True, + use_varlen=True, + ) + out.backward(gradOut, retain_graph=True) + + total_time_ms = do_bench(fwd_bwd) + + total_flops = density * cal_flops(B, H, S, SKV, D, mode='fwd_bwd') + total_tflops = cal_tflops(total_flops, total_time_ms) + + # Return format consistent with test_mask; fwd/bwd individual values are N/A + fwd_time_ms = float('nan') + bwd_time_ms = float('nan') + fwd_flops = float('nan') + bwd_flops = float('nan') + fwd_tflops = float('nan') + bwd_tflops = float('nan') + + return fwd_time_ms, bwd_time_ms, total_time_ms, fwd_flops, bwd_flops, total_flops, fwd_tflops, bwd_tflops, total_tflops, sparsity + def flashmask_block_sparsity( causal, flashmask, @@ -754,8 +825,9 @@ def main(examples: List[str] = ["all"], dtype='bf16', fm_version=1, suffix="_bas qksparse_mask = eval(line.split(":")[-1].split("#")[1].strip()) doc_seq_lens_list.append((total_length, doc_list, qksparse_mask)) #doc_seq_lens_list = doc_seq_lens_list[::-1] - for D in [128] if fm_version == 4 else [64, 128, 256]: - H = 4096 // D + # (D, DV, H): standard configs + MLA (d=192, dv=128) + head_configs = [(64, 64, 64), (128, 128, 32), (192, 128, 16), (256, 256, 16)] + for D, DV, H in head_configs: HKV = H for idx, (S, prefix_doc_seq_lens, qksparse_mask) in enumerate(doc_seq_lens_list): B = 128 * 1024 // S @@ -765,10 +837,10 @@ def main(examples: List[str] = ["all"], dtype='bf16', fm_version=1, suffix="_bas doc_seq_lens = [x[1] for x in prefix_doc_seq_lens] maskout_pair = [] offset = 0 - print(f"{B}_{S}_{H}_{HKV}_{D}_{idx}_{dtype}") + print(f"{B}_{S}_{H}_{HKV}_{D}_{DV}_{idx}_{dtype}") if not overwrite: - if 
os.path.exists(f"{dtype}{suffix}/flashmaskv{fm_version}_{B}_{S}_{H}_{D}_{idx}.csv"): - print(f"{dtype}{suffix}/flashmaskv{fm_version}_{B}_{S}_{H}_{D}_{idx}.csv already exists, skipping. To enable overwrite, use: --overwrite (True by default).") + if os.path.exists(f"{dtype}{suffix}/flashmaskv{fm_version}_{B}_{S}_{H}_{D}_{DV}_{idx}.csv"): + print(f"{dtype}{suffix}/flashmaskv{fm_version}_{B}_{S}_{H}_{D}_{DV}_{idx}.csv already exists, skipping. To enable overwrite, use: --overwrite (True by default).") continue if sum(qksparse_mask) == 0: maskout_pair = [(1024, 538), (2358, 1700)] @@ -780,23 +852,27 @@ def main(examples: List[str] = ["all"], dtype='bf16', fm_version=1, suffix="_bas share_qa_docs = [split_sequence(doc_seq) for doc_seq in doc_seq_lens] available_examples = { - "Full": lambda: test_mask(generate_mask_fn=partial(generate_none_mask, causal=False), B=B, S=SQ, SKV=SKV, H=H, HKV=HKV, D=D, dtype=dtype), - "Causal": lambda: test_mask(generate_mask_fn=partial(generate_none_mask, causal=True), B=B, S=SQ, SKV=SKV, H=H, HKV=HKV, D=D, dtype=dtype), - "Sliding Window": lambda: test_mask(generate_mask_fn=partial(generate_sliding_window_mask, window_size=int(S*0.0625)), B=B, S=SQ, SKV=SKV, H=H, HKV=HKV, D=D, dtype=dtype), - "Causal Document Mask": lambda: test_mask(generate_mask_fn=partial(generate_causal_document_mask, doc_seq_lens=doc_seq_lens), B=B, S=SQ, SKV=SKV, H=H, HKV=HKV, D=D, dtype=dtype), - "Document Mask": lambda: test_mask(generate_mask_fn=partial(generate_document_mask, doc_seq_lens=doc_seq_lens), B=B, S=SQ, SKV=SKV, H=H, HKV=HKV, D=D, dtype=dtype), - "Share Question Mask": lambda: test_mask(generate_mask_fn=partial(generate_share_question_mask, doc_seq_lens=share_qa_docs), B=B, S=SQ, SKV=SKV, H=H, HKV=HKV, D=D, dtype=dtype), - # "Global Sliding Window": lambda: test_mask(generate_mask_fn=partial(generate_global_sliding_window_mask, global_token=16, window_size=(int(S*0.0625), int(S*0.0625))), B=B, S=SQ, SKV=SKV, H=H, HKV=HKV, D=D, dtype=dtype), - "Causal 
Blockwise Mask": lambda: test_mask(generate_mask_fn=partial(generate_causal_blockwise_mask, doc_seq_lens=doc_seq_lens), B=B, S=SQ, SKV=SKV, H=H, HKV=HKV, D=D, dtype=dtype), - "Prefix LM Document Mask": lambda: test_mask(generate_mask_fn=partial(generate_prefix_lm_document_mask, doc_seq_lens=prefix_doc_seq_lens), B=B, S=SQ, SKV=SKV, H=H, HKV=HKV, D=D, dtype=dtype), - "Prefix LM Causal Mask": lambda: test_mask(generate_mask_fn=partial(generate_prefix_lm_causal_mask, prefix_length=int(S*0.5)), B=B, S=SQ, SKV=SKV, H=H, HKV=HKV, D=D, dtype=dtype), - "QK-sparse Mask": lambda: test_mask(generate_mask_fn=partial(generate_qk_sparse_mask, maskout_pair=maskout_pair), B=B, S=SQ, SKV=SKV, H=H, HKV=HKV, D=D, dtype=dtype), - "Random Eviction Mask": lambda: test_mask(generate_mask_fn=partial(generate_random_eviction_mask, start_row=S//2), B=B, S=SQ, SKV=SKV, H=H, HKV=HKV, D=D, dtype=dtype), - # "Hybrid SWA Prefix LM Doc": lambda: test_mask(generate_mask_fn=partial(generate_hybrid_swa_prefix_lm_document_mask, doc_seq_lens=prefix_doc_seq_lens), B=B, S=SQ, SKV=SKV, H=H, HKV=HKV, D=D, dtype=dtype), + "Full": lambda: test_mask(generate_mask_fn=partial(generate_none_mask, causal=False), B=B, S=SQ, SKV=SKV, H=H, HKV=HKV, D=D, dtype=dtype, DV=DV), + "Causal": lambda: test_mask(generate_mask_fn=partial(generate_none_mask, causal=True), B=B, S=SQ, SKV=SKV, H=H, HKV=HKV, D=D, dtype=dtype, DV=DV), + "Sliding Window": lambda: test_mask(generate_mask_fn=partial(generate_sliding_window_mask, window_size=int(S*0.0625)), B=B, S=SQ, SKV=SKV, H=H, HKV=HKV, D=D, dtype=dtype, DV=DV), + "Causal Document Mask": lambda: test_mask(generate_mask_fn=partial(generate_causal_document_mask, doc_seq_lens=doc_seq_lens), B=B, S=SQ, SKV=SKV, H=H, HKV=HKV, D=D, dtype=dtype, DV=DV), + "Document Mask": lambda: test_mask(generate_mask_fn=partial(generate_document_mask, doc_seq_lens=doc_seq_lens), B=B, S=SQ, SKV=SKV, H=H, HKV=HKV, D=D, dtype=dtype, DV=DV), + "Share Question Mask": lambda: 
test_mask(generate_mask_fn=partial(generate_share_question_mask, doc_seq_lens=share_qa_docs), B=B, S=SQ, SKV=SKV, H=H, HKV=HKV, D=D, dtype=dtype, DV=DV), + # "Global Sliding Window": lambda: test_mask(generate_mask_fn=partial(generate_global_sliding_window_mask, global_token=16, window_size=(int(S*0.0625), int(S*0.0625))), B=B, S=SQ, SKV=SKV, H=H, HKV=HKV, D=D, dtype=dtype, DV=DV), + "Causal Blockwise Mask": lambda: test_mask(generate_mask_fn=partial(generate_causal_blockwise_mask, doc_seq_lens=doc_seq_lens), B=B, S=SQ, SKV=SKV, H=H, HKV=HKV, D=D, dtype=dtype, DV=DV), + "Prefix LM Document Mask": lambda: test_mask(generate_mask_fn=partial(generate_prefix_lm_document_mask, doc_seq_lens=prefix_doc_seq_lens), B=B, S=SQ, SKV=SKV, H=H, HKV=HKV, D=D, dtype=dtype, DV=DV), + "Prefix LM Causal Mask": lambda: test_mask(generate_mask_fn=partial(generate_prefix_lm_causal_mask, prefix_length=int(S*0.5)), B=B, S=SQ, SKV=SKV, H=H, HKV=HKV, D=D, dtype=dtype, DV=DV), + "QK-sparse Mask": lambda: test_mask(generate_mask_fn=partial(generate_qk_sparse_mask, maskout_pair=maskout_pair), B=B, S=SQ, SKV=SKV, H=H, HKV=HKV, D=D, dtype=dtype, DV=DV), + "Random Eviction Mask": lambda: test_mask(generate_mask_fn=partial(generate_random_eviction_mask, start_row=S//2), B=B, S=SQ, SKV=SKV, H=H, HKV=HKV, D=D, dtype=dtype, DV=DV), + # "Hybrid SWA Prefix LM Doc": lambda: test_mask(generate_mask_fn=partial(generate_hybrid_swa_prefix_lm_document_mask, doc_seq_lens=prefix_doc_seq_lens), B=B, S=SQ, SKV=SKV, H=H, HKV=HKV, D=D, dtype=dtype, DV=DV), # Note(umiswing): support load mask and hybrid mask like this, and also, support simulate cp benchmark - # "Dumped Mask": lambda: test_mask(generate_mask_fn=partial(load_mask, path=mask_path, causal=False, cp_size=cp_size, cp_rank=cp_rank), B=B, S=SQ, SKV=SKV, H=H, HKV=HKV, D=D, dtype=dtype), - # "Hybrid SWA": lambda: test_mask(generate_mask_fn=partial(load_mask, path=mask_path, causal=False, cp_size=cp_size, cp_rank=cp_rank, hybrid_mask_fn=partial(hybrid_swa, 
window_size=512, swa_ratio=0.75)), B=B, S=SQ, SKV=SKV, H=H, HKV=HKV, D=D, dtype=dtype), + # "Dumped Mask": lambda: test_mask(generate_mask_fn=partial(load_mask, path=mask_path, causal=False, cp_size=cp_size, cp_rank=cp_rank), B=B, S=SQ, SKV=SKV, H=H, HKV=HKV, D=D, dtype=dtype, DV=DV), + # "Hybrid SWA": lambda: test_mask(generate_mask_fn=partial(load_mask, path=mask_path, causal=False, cp_size=cp_size, cp_rank=cp_rank, hybrid_mask_fn=partial(hybrid_swa, window_size=512, swa_ratio=0.75)), B=B, S=SQ, SKV=SKV, H=H, HKV=HKV, D=D, dtype=dtype, DV=DV), + + # Varlen benchmarks (fwd+bwd combined, use_varlen=True) + "Varlen Causal Document Mask": lambda: test_mask_varlen(generate_mask_fn=partial(generate_causal_document_mask, doc_seq_lens=doc_seq_lens), B=B, S=SQ, SKV=SKV, H=H, HKV=HKV, D=D, dtype=dtype, DV=DV), + "Varlen Document Mask": lambda: test_mask_varlen(generate_mask_fn=partial(generate_document_mask, doc_seq_lens=doc_seq_lens), B=B, S=SQ, SKV=SKV, H=H, HKV=HKV, D=D, dtype=dtype, DV=DV), } if "all" in examples: @@ -837,8 +913,8 @@ def main(examples: List[str] = ["all"], dtype='bf16', fm_version=1, suffix="_bas content2=tabulate(results, headers=headers, tablefmt="tsv") os.makedirs(f"{dtype}{suffix}", exist_ok=True) # Note(umiswing): this file name is better, but i need to keep the old name for fig plotting - # text_file = open(f"{dtype}{suffix}/flashmaskv{fm_version}_{current_time}_{B}_{SQ}_{SKV}_{H}_{HKV}_{D}_{idx}.csv","w") - text_file = open(f"{dtype}{suffix}/flashmaskv{fm_version}_{current_time}_{B}_{S}_{H}_{HKV}_{D}_{idx}.csv","w") + # text_file = open(f"{dtype}{suffix}/flashmaskv{fm_version}_{current_time}_{B}_{SQ}_{SKV}_{H}_{HKV}_{D}_{DV}_{idx}.csv","w") + text_file = open(f"{dtype}{suffix}/flashmaskv{fm_version}_{current_time}_{B}_{S}_{H}_{HKV}_{D}_{DV}_{idx}.csv","w") text_file.write(content2) text_file.close() diff --git a/generate_startend_row_indices.py b/generate_startend_row_indices.py index fc33912..aacf682 100644 --- 
a/generate_startend_row_indices.py +++ b/generate_startend_row_indices.py @@ -76,10 +76,11 @@ def generate_causal_document_mask(batch_size, seqlen_q, seqlen_k, h, doc_seqlens assert total_seqlen <= seqlen_k assert len(doc_seqlens) >= 3 padding = seqlen_k - np.sum(doc_seqlens) - doc_seqlens[-1] += padding seq_cusums = np.cumsum(doc_seqlens) startend_row_indices = np.repeat(seq_cusums, doc_seqlens) + padding_mask = np.repeat(seq_cusums[-1], padding) + startend_row_indices = np.concatenate([startend_row_indices, padding_mask]) startend_row_indices = paddle.to_tensor(startend_row_indices, dtype=paddle.int32).reshape((1, 1, seqlen_k, 1)).repeat_interleave(batch_size, 0) startend_row_indices = paddle.clip(startend_row_indices, max=seqlen_q) @@ -344,3 +345,137 @@ def generate_random_eviction_mask(batch_size, seqlen_q, seqlen_k, h, start_row=N startend_row_indices = paddle.clip(startend_row_indices, max=seqlen_q) causal = True return startend_row_indices, causal + + +def _scale_doc_seqlens(doc_seqlens, seqlen_k): + """Scale doc_seqlens from base 8192 to target seqlen_k.""" + if seqlen_k != 8192: + doc_seqlens = [int(d * (seqlen_k / 8192)) for d in doc_seqlens] + return doc_seqlens + + +# Pre-defined document-length distributions for per-batch variation. 
+_DIFF_BATCH_DOC_SEQLENS = [ + [2538, 1742, 3213], + [1500, 3500, 2493], + [3000, 1000, 3493], + [800, 2200, 4493], +] + + +def generate_causal_document_mask_diff_batch(batch_size, seqlen_q, seqlen_k, h, doc_seqlens_list=None): + """Causal document mask where each batch item has DIFFERENT document boundaries.""" + if doc_seqlens_list is None: + doc_seqlens_list = [ + _scale_doc_seqlens(_DIFF_BATCH_DOC_SEQLENS[i % len(_DIFF_BATCH_DOC_SEQLENS)], seqlen_k) + for i in range(batch_size) + ] + if seqlen_k != 8192: + print(f"{seqlen_k=}, auto setting per-batch doc_seqlens to {doc_seqlens_list}") + + batch_indices = [] + for bi in range(batch_size): + doc_seqlens = list(doc_seqlens_list[bi]) + total_seqlen = np.sum(doc_seqlens) + assert total_seqlen <= seqlen_k + assert len(doc_seqlens) >= 3 + padding = seqlen_k - np.sum(doc_seqlens) + seq_cusums = np.cumsum(doc_seqlens) + + sri = np.repeat(seq_cusums, doc_seqlens) + padding_mask = np.repeat(seq_cusums[-1], padding) + sri = np.concatenate([sri, padding_mask]) + batch_indices.append(sri) + + stacked = np.stack(batch_indices, axis=0) # (batch_size, seqlen_k) + startend_row_indices = paddle.to_tensor(stacked, dtype=paddle.int32).reshape( + (batch_size, 1, seqlen_k, 1) + ) + startend_row_indices = paddle.clip(startend_row_indices, max=seqlen_q) + + causal = True + return startend_row_indices, causal + + +def generate_document_mask_diff_batch(batch_size, seqlen_q, seqlen_k, h, doc_seqlens_list=None): + """Non-causal document mask where each batch item has DIFFERENT document boundaries.""" + if doc_seqlens_list is None: + doc_seqlens_list = [ + _scale_doc_seqlens(_DIFF_BATCH_DOC_SEQLENS[i % len(_DIFF_BATCH_DOC_SEQLENS)], seqlen_k) + for i in range(batch_size) + ] + if seqlen_k != 8192: + print(f"{seqlen_k=}, auto setting per-batch doc_seqlens to {doc_seqlens_list}") + + batch_down_left = [] + batch_up_right = [] + for bi in range(batch_size): + doc_seqlens = list(doc_seqlens_list[bi]) + total_seqlen = np.sum(doc_seqlens) + 
assert total_seqlen <= seqlen_k + assert len(doc_seqlens) >= 3 + padding = seqlen_k - np.sum(doc_seqlens) + + down_left_row_indices = [] + up_right_row_indices = [] + + cur_len_so_far = doc_seqlens[0] + for i in range(len(doc_seqlens)): + down_left_row_indices.extend([cur_len_so_far] * doc_seqlens[i]) + if i < len(doc_seqlens) - 1: + cur_len_so_far += doc_seqlens[i + 1] + if padding > 0: + down_left_row_indices.extend([cur_len_so_far] * padding) + + cur_len_so_far = 0 + for i in range(len(doc_seqlens)): + up_right_row_indices.extend([cur_len_so_far] * doc_seqlens[i]) + if i < len(doc_seqlens) - 1: + cur_len_so_far += doc_seqlens[i] + if padding > 0: + up_right_row_indices.extend([cur_len_so_far] * padding) + + batch_down_left.append(down_left_row_indices) + batch_up_right.append(up_right_row_indices) + + down_left = np.array(batch_down_left) # (batch_size, seqlen_k) + up_right = np.array(batch_up_right) # (batch_size, seqlen_k) + + down_left = paddle.to_tensor(down_left, dtype=paddle.int32).reshape((batch_size, 1, seqlen_k, 1)) + up_right = paddle.to_tensor(up_right, dtype=paddle.int32).reshape((batch_size, 1, seqlen_k, 1)) + startend_row_indices = paddle.concat([down_left, up_right], axis=-1) + startend_row_indices = paddle.clip(startend_row_indices, max=seqlen_q) + + causal = False + return startend_row_indices, causal + +def generate_document_mask_simu(batch_size, seqlen_q, seqlen_k, h, doc_seqlens=None): + assert seqlen_q == seqlen_k + lts, causal = generate_causal_document_mask(batch_size, seqlen_q, seqlen_k, h, doc_seqlens) + causal = False + + b, h, s, _ = lts.shape + ute = paddle.arange( + 0, s, 1, dtype="int32" + ).reshape((1, 1, s, 1)).repeat_interleave(b, 0).repeat_interleave(h, 1) + + ute = paddle.where(ute <= lts, ute, lts) + startend_row_indices = paddle.concat([lts, ute], axis=-1) + + return startend_row_indices, causal + +def generate_document_mask_diff_batch_simu(batch_size, seqlen_q, seqlen_k, h, doc_seqlens_list=None): + assert seqlen_q == 
seqlen_k + lts, causal = generate_causal_document_mask_diff_batch(batch_size, seqlen_q, seqlen_k, h, doc_seqlens_list) + causal = False + + b, h, s, _ = lts.shape + ute = paddle.arange( + 0, s, 1, dtype="int32" + ).reshape((1, 1, s, 1)).repeat_interleave(b, 0).repeat_interleave(h, 1) + + ute = paddle.where(ute <= lts, ute, lts) + startend_row_indices = paddle.concat([lts, ute], axis=-1) + + return startend_row_indices, causal + diff --git a/run_varlen.sh b/run_varlen.sh new file mode 100644 index 0000000..3a8a491 --- /dev/null +++ b/run_varlen.sh @@ -0,0 +1 @@ +../wsm_varlen_env/bin/python -m pytest -v test_flashmask_to_varlen.py diff --git a/test_flashmask_to_varlen.py b/test_flashmask_to_varlen.py new file mode 100644 index 0000000..cf27b01 --- /dev/null +++ b/test_flashmask_to_varlen.py @@ -0,0 +1,247 @@ +""" +Test: FlashMask (Paddle) vs flash_attn_varlen_func (PyTorch) via convert_to_varlen. + +Workflow: + 1. Generate q, k, v, causal, startend_row_indices (Paddle tensors, padded layout). + 2. Call convert_to_varlen() to transform startend_row_indices into varlen format: + - q_varlen, k_varlen, v_varlen: concatenated Paddle tensors (total_q, nheads, d) + - cu_seqlens_q, cu_seqlens_k: cumulative sequence lengths (Paddle, int32) + - max_seqlen_q, max_seqlen_k: maximum sequence lengths (int) + 3. Call Paddle's flashmask_attention with the original padded input. + 4. Convert varlen tensors from Paddle to PyTorch, then call PyTorch's + flash_attn_varlen_func. + 5. Compare the two outputs via np.allclose. 
+""" + +import os +import math +import itertools +import pytest +import numpy as np +from functools import partial + +import paddle +import torch + +# ── Paddle: flashmask_attention ────────────────────────────────────────────── +from flash_mask import flashmask_attention +import flash_mask + +# ── Mask generators (Paddle) ──────────────────────────────────────────────── +from generate_startend_row_indices import ( + generate_causal_document_mask, + generate_document_mask, + generate_causal_document_mask_diff_batch, + generate_document_mask_diff_batch, + generate_document_mask_simu, + generate_document_mask_diff_batch_simu, +) + +from test_util import attention_ref + +# ───────────────────────────────────────────────────────────────────────────── +# Helpers +# ───────────────────────────────────────────────────────────────────────────── + +def paddle_to_torch(t: paddle.Tensor) -> torch.Tensor: + """Convert a Paddle tensor to a PyTorch CUDA tensor. + + For bf16 tensors we view as int16 before going through numpy (which + doesn't support bf16), then reinterpret back to bfloat16 on the + PyTorch side. + """ + if t.dtype == paddle.bfloat16: + np_arr = t.view(paddle.int16).numpy() + return torch.from_numpy(np_arr).view(torch.bfloat16).cuda() + return torch.from_numpy(t.numpy()).cuda() + + +def torch_to_paddle(t: torch.Tensor) -> paddle.Tensor: + """Convert a PyTorch CUDA tensor to a Paddle tensor. + + For bf16 tensors we view as int16 before going through numpy, then + reinterpret back to bfloat16 on the Paddle side. 
+ """ + if t.dtype == torch.bfloat16: + np_arr = t.cpu().view(torch.int16).numpy() + return paddle.to_tensor(np_arr).view(paddle.bfloat16) + return paddle.to_tensor(t.cpu().numpy()) + + +# ───────────────────────────────────────────────────────────────────────────── +# Test parameters +# ───────────────────────────────────────────────────────────────────────────── + +# (batch_size, seqlen_q, seqlen_k, nheads, nheads_kv) +shape_cases = [ + (1, 256, 256, 4, 4), + (2, 512, 512, 8, 2), + (1, 1024, 1024, 4, 1), + (2, 300, 300, 6, 2), + (1, 128, 128, 1, 1), + (2, 1000, 1000, 4, 1), +] + + +def generate_shapes(): + for batch_size, seqlen_q, seqlen_k, nheads, nheads_kv in shape_cases: + if nheads_kv == 1: + nheads_startend_row_indices_values = [1] + else: + nheads_startend_row_indices_values = [1, nheads_kv] + for nheads_sri in nheads_startend_row_indices_values: + yield (batch_size, seqlen_q, seqlen_k, nheads, nheads_kv, nheads_sri) + + +# Only test mask types that are compatible with varlen (causal-style masks). 
+mask_generators = [ + partial(generate_document_mask), # document + partial(generate_causal_document_mask), # causal document + partial(generate_document_mask_diff_batch), # document, per-batch layouts + partial(generate_causal_document_mask_diff_batch), # causal document, per-batch layouts + partial(generate_document_mask_simu), # causal document encoded as causal=False (LTS + UTE) + partial(generate_document_mask_diff_batch_simu), # same encoding, per-batch layouts +] + + +# ───────────────────────────────────────────────────────────────────────────── +# The test +# ───────────────────────────────────────────────────────────────────────────── + +@pytest.mark.parametrize("dtype", [paddle.bfloat16]) +@pytest.mark.parametrize("d, dv", [(64, 64), (128, 128)]) +@pytest.mark.parametrize( + "batch_size, seqlen_q, seqlen_k, nheads, nheads_kv, nheads_startend_row_indices", + list(generate_shapes()), +) +@pytest.mark.parametrize("gen_startend_row_indices", mask_generators) +def test_flashmask_to_varlen( + batch_size, + seqlen_q, + seqlen_k, + nheads, + nheads_kv, + d, + dv, + nheads_startend_row_indices, + dtype, + gen_startend_row_indices, +): + """ + Compare Paddle flashmask_attention output with PyTorch flash_attn_varlen_func output + after converting startend_row_indices to varlen format via convert_to_varlen(). + """ + paddle.seed(2024) + torch.manual_seed(2024) + assert nheads % nheads_kv == 0 + + # ── 1. Generate padded Q, K, V (Paddle) ───────────────────────────────── + q_paddle = paddle.randn(shape=[batch_size, seqlen_q, nheads, d], dtype=dtype) + k_paddle = paddle.randn(shape=[batch_size, seqlen_k, nheads_kv, d], dtype=dtype) + v_paddle = paddle.randn(shape=[batch_size, seqlen_k, nheads_kv, dv], dtype=dtype) + + # ── 2. Generate mask + startend_row_indices, causal = gen_startend_row_indices( + batch_size, seqlen_q, seqlen_k, nheads_startend_row_indices + ) + + # ── 3. 
Call Paddle's flashmask_attention ───────────────────────────────── + paddle.set_flags({"FLAGS_flash_attn_version": 4}) + + # Skip if FA4 doesn't support this configuration + if startend_row_indices is not None and startend_row_indices.shape[-1] == 4: + pytest.skip("FA4 does not support startend_row_indices with last dim == 4") + + q_fm = q_paddle.detach().clone() + k_fm = k_paddle.detach().clone() + v_fm = v_paddle.detach().clone() + q_fm.stop_gradient = False + k_fm.stop_gradient = False + v_fm.stop_gradient = False + + out_fm, lse_fm = flash_mask.cute.interface.flashmask_attention( + q_fm, + k_fm, + v_fm, + startend_row_indices=startend_row_indices, + causal=causal, + return_softmax_lse=True, + ) + + q_varlen = q_paddle.detach().clone() + k_varlen = k_paddle.detach().clone() + v_varlen = v_paddle.detach().clone() + q_varlen.stop_gradient = False + k_varlen.stop_gradient = False + v_varlen.stop_gradient = False + + # ── 4. Call flashmask_attention with use_varlen=True ───────────────────── + # Same padded Paddle inputs and mask as step 3, routed through flash_mask.flashmask_attention + out_varlen, lse_varlen = flashmask_attention( + q_varlen, + k_varlen, + v_varlen, + startend_row_indices=startend_row_indices, + causal=causal, + return_softmax_lse=True, + use_varlen=True + ) + + # ── 5. Compare outputs ─────────────────────────────────────────────────── + + # Convert both outputs to float32 numpy for comparison + out_fm_np = paddle.cast(out_fm, paddle.float32).numpy() + out_vl_np = paddle.cast(out_varlen, paddle.float32).numpy() + + max_diff = np.max(np.abs(out_fm_np - out_vl_np)) + mean_diff = np.mean(np.abs(out_fm_np - out_vl_np)) + print(f"\n[fwd] max diff: {max_diff:.6e}, mean diff: {mean_diff:.6e}") + + assert np.allclose(out_fm_np, out_vl_np, rtol=1e-2, atol=1e-2), ( + f"Output mismatch: max diff {max_diff:.6e}, mean diff {mean_diff:.6e}" + ) + + # ── 6. Backward ────────────────────────────────────────────────────────── + # Generate the same random gradient for both paths. 
+ g_fm = paddle.randn(shape=out_fm.shape, dtype=out_fm.dtype) + + # Flashmask backward (Paddle) + out_fm.backward(g_fm) + + g_vl = g_fm.detach().clone() + out_varlen.backward(g_vl) + + for name, grad_fm, grad_vl in [ + ("dQ", q_fm.grad, q_varlen.grad), + ("dK", k_fm.grad, k_varlen.grad), + ("dV", v_fm.grad, v_varlen.grad), + ]: + grad_fm_np = paddle.cast(grad_fm, paddle.float32).numpy() + grad_vl_np = paddle.cast(grad_vl, paddle.float32).numpy() + + max_diff = np.max(np.abs(grad_fm_np - grad_vl_np)) + mean_diff = np.mean(np.abs(grad_fm_np - grad_vl_np)) + print(f"[bwd {name}] max diff: {max_diff:.6e}, mean diff: {mean_diff:.6e}") + assert np.allclose(grad_fm_np, grad_vl_np, rtol=1e-2, atol=1e-2), ( + f"{name} mismatch: max diff {max_diff:.6e}, mean diff {mean_diff:.6e}" + ) + +# ───────────────────────────────────────────────────────────────────────────── +# Standalone runner +# ───────────────────────────────────────────────────────────────────────────── + +if __name__ == "__main__": + # Quick smoke test: single config, causal document mask + test_flashmask_to_varlen( + batch_size=2, + seqlen_q=512, + seqlen_k=512, + nheads=4, + nheads_kv=2, + d=128, + dv=128, + nheads_startend_row_indices=1, + dtype=paddle.bfloat16, + gen_startend_row_indices=partial(generate_causal_document_mask), + ) + print("\nSmoke test passed!") diff --git a/varlen_utils.py b/varlen_utils.py new file mode 100644 index 0000000..c4ec41b --- /dev/null +++ b/varlen_utils.py @@ -0,0 +1,117 @@ +import paddle +import torch + + +def convert_to_varlen( + q, + k, + v, + causal, + startend_row_indices, +): + b, sq, hq, d = q.shape + _, skv, hkv, dv = v.shape + assert sq == skv + + # ── Extract document boundaries PER BATCH from startend_row_indices ── + # startend_row_indices shape: (batch, nheads_sri, seqlen_k, bound_num) + # Column 0 encodes the "end of document" boundary for each key position. + # Tokens within the same document share the same value; boundaries are + # where this value changes. 
Different batch items may have different + # document layouts, so we extract boundaries for each batch independently. + s = startend_row_indices[:, 0, :, 0] # (batch, seqlen_k) + + cu_seqlens_parts = [] + max_doc_len = 0 + needs_padding_fixup = False + real_ends = [] + + for bi in range(b): + s_bi = s[bi] # (seqlen_k,) + + # Find change positions -> document boundaries + diff_bi = paddle.not_equal(s_bi[1:], s_bi[:-1]) + change_idx_bi = paddle.nonzero(diff_bi).flatten().cast(paddle.int32) + 1 + + # Real end of documents (max value in column 0) + real_end_bi = int(s_bi.max().item()) + real_ends.append(real_end_bi) + if real_end_bi < skv: + needs_padding_fixup = True + + # Boundaries: [0, change_1, ..., seqlen_k] + # Always use seqlen_k as last boundary so padding KV (visible to + # the last doc's rows in flashmask) is included in the last doc. + boundaries_bi = paddle.concat([ + paddle.zeros([1], dtype=paddle.int32), + change_idx_bi, + paddle.to_tensor([skv], dtype=paddle.int32), + ]) + + # Track max document length across all batches + doc_lens_bi = boundaries_bi[1:] - boundaries_bi[:-1] + max_doc_len = max(max_doc_len, int(doc_lens_bi.max().item())) + + # Collect document start positions with batch offset + cu_seqlens_parts.append(boundaries_bi[:-1] + bi * skv) + + # Build cu_seqlens: concat per-batch starts + final endpoint + cu_seqlens = paddle.concat( + cu_seqlens_parts + [paddle.to_tensor([b * skv], dtype=paddle.int32)] + ) + + # ── Flatten q, k, v: (batch, seqlen, heads, dim) -> (total, heads, dim) + q_varlen = q.reshape([b * sq, hq, d]) + k_varlen = k.reshape([b * skv, hkv, d]) + v_varlen = v.reshape([b * skv, hkv, dv]) + + # ── Detect simulated causal masks ────────────────────────────────── + # A causal document mask can be encoded as causal=False with bound_num=2 + # (LTS + UTE) where UTE[j] = min(j, LTS[j]) reproduces the causal + # diagonal. 
For true non-causal masks, UTE is constant within each + # document (= document start), which differs from min(j, LTS[j]). + varlen_causal = causal + bound_num = startend_row_indices.shape[-1] + if not causal and bound_num == 2: + lts_all = startend_row_indices[:, 0, :, 0] # (b, skv) + ute_all = startend_row_indices[:, 0, :, 1] # (b, skv) + arange_ref = paddle.arange(skv, dtype=paddle.int32).unsqueeze(0) # (1, skv) + expected_causal_ute = paddle.minimum(arange_ref, lts_all) # (b, skv) + if paddle.equal_all(ute_all, expected_causal_ute).item(): + varlen_causal = True + + result = { + "q": q_varlen, + "k": k_varlen, + "v": v_varlen, + "cu_seqlens_q": cu_seqlens, + "cu_seqlens_k": cu_seqlens, + "max_seqlen_q": max_doc_len, + "max_seqlen_k": max_doc_len, + "causal": varlen_causal, + } + + # For non-causal masks with trailing padding: padding rows attend to + # nothing in flashmask (zero output), but varlen computes non-zero + # output for them. Zero out padding rows per batch to match flashmask. + # Note: real_end can differ across batch items. + if needs_padding_fixup: + _b, _sq = b, sq + _real_ends = real_ends + + def output_to_padded(out_varlen_pt): + nh = out_varlen_pt.shape[1] + dv_out = out_varlen_pt.shape[2] + out_padded = out_varlen_pt.reshape(_b, _sq, nh, dv_out) + # Vectorised per-batch zeroing + row_idx = torch.arange(_sq, device=out_padded.device) + real_end_t = torch.tensor( + _real_ends, device=out_padded.device, dtype=torch.int64, + ).unsqueeze(1) + padding_mask = row_idx.unsqueeze(0) >= real_end_t # (b, sq) + out_padded[padding_mask] = 0 + return out_padded + + result["output_to_padded"] = output_to_padded + + return result