From 142070c284e26430b4e8e81e90c77b6da1370a06 Mon Sep 17 00:00:00 2001 From: umiswing Date: Mon, 13 Apr 2026 21:35:28 +0800 Subject: [PATCH 01/10] add fm convert to varlen test --- run_varlen.sh | 1 + test_flashmask_to_varlen.py | 256 ++++++++++++++++++++++++++++++++++++ varlen_utils.py | 31 +++++ 3 files changed, 288 insertions(+) create mode 100644 run_varlen.sh create mode 100644 test_flashmask_to_varlen.py create mode 100644 varlen_utils.py diff --git a/run_varlen.sh b/run_varlen.sh new file mode 100644 index 0000000..3a8a491 --- /dev/null +++ b/run_varlen.sh @@ -0,0 +1 @@ +../wsm_varlen_env/bin/python -m pytest -v test_flashmask_to_varlen.py diff --git a/test_flashmask_to_varlen.py b/test_flashmask_to_varlen.py new file mode 100644 index 0000000..d8d621f --- /dev/null +++ b/test_flashmask_to_varlen.py @@ -0,0 +1,256 @@ +""" +Test: FlashMask (Paddle) vs flash_attn_varlen_func (PyTorch) via convert_to_varlen. + +Workflow: + 1. Generate q, k, v, causal, startend_row_indices (Paddle tensors, padded layout). + 2. Call convert_to_varlen() to transform startend_row_indices into varlen format: + - q_varlen, k_varlen, v_varlen: concatenated Paddle tensors (total_q, nheads, d) + - cu_seqlens_q, cu_seqlens_k: cumulative sequence lengths (Paddle, int32) + - max_seqlen_q, max_seqlen_k: maximum sequence lengths (int) + 3. Call Paddle's flashmask_attention with the original padded input. + 4. Convert varlen tensors from Paddle to PyTorch, then call PyTorch's + flash_attn_varlen_func. + 5. Compare the two outputs via np.allclose. +""" + +import os +import math +import itertools +import pytest +import numpy as np +from functools import partial + +import paddle +import torch + +# ── Paddle: flashmask_attention ────────────────────────────────────────────── +try: + from flash_mask.cute.interface import flashmask_attention +except (ImportError, ModuleNotFoundError): + from paddle.nn.functional.flash_attention import flashmask_attention + +# ── PyTorch: flash_attn_varlen_func (FA4 cute) ────────────────────────────── +from flash_attn.cute.interface import flash_attn_varlen_func + +# ── convert_to_varlen (fake implementation in varlen_utils.py) ─────────────── +from varlen_utils import convert_to_varlen + +# ── Mask generators (Paddle) ──────────────────────────────────────────────── +from generate_startend_row_indices import ( + generate_causal_document_mask, + generate_document_mask, +) + +from test_util import attention_ref + +# ───────────────────────────────────────────────────────────────────────────── +# Helpers +# ───────────────────────────────────────────────────────────────────────────── + +def paddle_to_torch(t: paddle.Tensor) -> torch.Tensor: + """Convert a Paddle tensor to a PyTorch CUDA tensor. + + For bf16 tensors we view as int16 before going through numpy (which + doesn't support bf16), then reinterpret back to bfloat16 on the + PyTorch side. + """ + if t.dtype == paddle.bfloat16: + np_arr = t.view(paddle.int16).numpy() + return torch.from_numpy(np_arr).view(torch.bfloat16).cuda() + return torch.from_numpy(t.numpy()).cuda() + + +def torch_to_paddle(t: torch.Tensor) -> paddle.Tensor: + """Convert a PyTorch CUDA tensor to a Paddle tensor. + + For bf16 tensors we view as int16 before going through numpy, then + reinterpret back to bfloat16 on the Paddle side. + """ + if t.dtype == torch.bfloat16: + np_arr = t.cpu().view(torch.int16).numpy() + return paddle.to_tensor(np_arr).view(paddle.bfloat16) + return paddle.to_tensor(t.cpu().numpy()) + + +# ───────────────────────────────────────────────────────────────────────────── +# Test parameters +# ───────────────────────────────────────────────────────────────────────────── + +# (batch_size, seqlen_q, seqlen_k, nheads, nheads_kv) +shape_cases = [ + (1, 256, 256, 4, 4), + (2, 512, 512, 8, 2), + (1, 1024, 1024, 4, 1), + (2, 300, 300, 6, 2), + (1, 128, 128, 1, 1), + (2, 1000, 1000, 4, 1), +] + + +def generate_shapes(): + for batch_size, seqlen_q, seqlen_k, nheads, nheads_kv in shape_cases: + if nheads_kv == 1: + nheads_startend_row_indices_values = [1] + else: + nheads_startend_row_indices_values = [1, nheads_kv] + for nheads_sri in nheads_startend_row_indices_values: + yield (batch_size, seqlen_q, seqlen_k, nheads, nheads_kv, nheads_sri) + + +# Only test mask types that are compatible with varlen (causal-style masks). +mask_generators = [ + partial(generate_document_mask), # document + partial(generate_causal_document_mask), # causal document +] + + +# ───────────────────────────────────────────────────────────────────────────── +# The test +# ───────────────────────────────────────────────────────────────────────────── + +@pytest.mark.parametrize("dtype", [paddle.bfloat16]) +@pytest.mark.parametrize("d, dv", [(64, 64), (128, 128)]) +@pytest.mark.parametrize( + "batch_size, seqlen_q, seqlen_k, nheads, nheads_kv, nheads_startend_row_indices", + list(generate_shapes()), +) +@pytest.mark.parametrize("gen_startend_row_indices", mask_generators) +def test_flashmask_to_varlen( + batch_size, + seqlen_q, + seqlen_k, + nheads, + nheads_kv, + d, + dv, + nheads_startend_row_indices, + dtype, + gen_startend_row_indices, +): + """ + Compare Paddle flashmask_attention output with PyTorch flash_attn_varlen_func output + after converting startend_row_indices to varlen format via convert_to_varlen(). + """ + paddle.seed(2024) + torch.manual_seed(2024) + assert nheads % nheads_kv == 0 + + # ── 1. Generate padded Q, K, V (Paddle) ───────────────────────────────── + q_paddle = paddle.randn(shape=[batch_size, seqlen_q, nheads, d], dtype=dtype) + k_paddle = paddle.randn(shape=[batch_size, seqlen_k, nheads_kv, d], dtype=dtype) + v_paddle = paddle.randn(shape=[batch_size, seqlen_k, nheads_kv, dv], dtype=dtype) + + # Generate mask + startend_row_indices, causal = gen_startend_row_indices( + batch_size, seqlen_q, seqlen_k, nheads_startend_row_indices + ) + + # ── 2. Convert to varlen format ────────────────────────────────────────── + # convert_to_varlen returns Paddle tensors in a dict: + # "q": (total_q, nheads, d) Paddle + # "k": (total_k, nheads_kv, d) Paddle + # "v": (total_k, nheads_kv, dv) Paddle + # "cu_seqlens_q": (num_seqs + 1,) Paddle int32 + # "cu_seqlens_k": (num_seqs + 1,) Paddle int32 + # "max_seqlen_q": int + # "max_seqlen_k": int + # "causal": bool + + varlen = convert_to_varlen( + q_paddle, k_paddle, v_paddle, causal, startend_row_indices + ) + + q_vl_paddle = varlen["q"] + k_vl_paddle = varlen["k"] + v_vl_paddle = varlen["v"] + cu_seqlens_q_paddle = varlen["cu_seqlens_q"] + cu_seqlens_k_paddle = varlen["cu_seqlens_k"] + max_seqlen_q = varlen["max_seqlen_q"] + max_seqlen_k = varlen["max_seqlen_k"] + varlen_causal = varlen.get("causal", causal) + + # ── 3. Call Paddle's flashmask_attention ───────────────────────────────── + paddle.set_flags({"FLAGS_flash_attn_version": 4}) + + # Skip if FA4 doesn't support this configuration + if startend_row_indices is not None and startend_row_indices.shape[-1] == 4: + pytest.skip("FA4 does not support startend_row_indices with last dim == 4") + + q_fm = q_paddle.detach().clone() + k_fm = k_paddle.detach().clone() + v_fm = v_paddle.detach().clone() + q_fm.stop_gradient = False + k_fm.stop_gradient = False + v_fm.stop_gradient = False + + out_fm, lse_fm = flashmask_attention( + q_fm, + k_fm, + v_fm, + startend_row_indices=startend_row_indices, + causal=causal, + return_softmax_lse=True, + ) + + # ── 4. Call PyTorch's flash_attn_varlen_func ───────────────────────────── + # Convert Paddle varlen tensors to PyTorch CUDA tensors + q_varlen_pt = paddle_to_torch(q_vl_paddle).contiguous() + k_varlen_pt = paddle_to_torch(k_vl_paddle).contiguous() + v_varlen_pt = paddle_to_torch(v_vl_paddle).contiguous() + cu_seqlens_q_pt = paddle_to_torch(cu_seqlens_q_paddle).contiguous().to(torch.int32) + cu_seqlens_k_pt = paddle_to_torch(cu_seqlens_k_paddle).contiguous().to(torch.int32) + + out_varlen_pt, lse_varlen_pt = flash_attn_varlen_func( + q_varlen_pt, + k_varlen_pt, + v_varlen_pt, + cu_seqlens_q=cu_seqlens_q_pt, + cu_seqlens_k=cu_seqlens_k_pt, + max_seqlen_q=max_seqlen_q, + max_seqlen_k=max_seqlen_k, + causal=varlen_causal, + return_lse=True, + ) + + # ── 5. Compare outputs ─────────────────────────────────────────────────── + # Map varlen output back to padded layout for comparison. + # If convert_to_varlen provides an "output_to_padded" callable, use it; + # otherwise reshape the flat (total_q, nheads, dv) back to + # (batch_size, seqlen_q, nheads, dv). + if "output_to_padded" in varlen: + out_varlen_padded_pt = varlen["output_to_padded"](out_varlen_pt) + else: + out_varlen_padded_pt = out_varlen_pt.reshape(batch_size, seqlen_q, nheads, dv) + + # Convert both outputs to float32 numpy for comparison + out_fm_np = paddle.cast(out_fm, paddle.float32).numpy() + out_vl_np = torch_to_paddle(out_varlen_padded_pt).cast(paddle.float32).numpy() + + max_diff = np.max(np.abs(out_fm_np - out_vl_np)) + mean_diff = np.mean(np.abs(out_fm_np - out_vl_np)) + print(f"\n[flashmask vs varlen] max diff: {max_diff:.6e}, mean diff: {mean_diff:.6e}") + + assert np.allclose(out_fm_np, out_vl_np, rtol=1e-2, atol=1e-2), ( + f"Output mismatch: max diff {max_diff:.6e}, mean diff {mean_diff:.6e}" + ) + + +# ───────────────────────────────────────────────────────────────────────────── +# Standalone runner +# ───────────────────────────────────────────────────────────────────────────── + +if __name__ == "__main__": + # Quick smoke test: single config, causal document mask + test_flashmask_to_varlen( + batch_size=2, + seqlen_q=512, + seqlen_k=512, + nheads=4, + nheads_kv=2, + d=128, + dv=128, + nheads_startend_row_indices=1, + dtype=paddle.bfloat16, + gen_startend_row_indices=partial(generate_causal_document_mask), + ) + print("\nSmoke test passed!") diff --git a/varlen_utils.py b/varlen_utils.py new file mode 100644 index 0000000..f58f1eb --- /dev/null +++ b/varlen_utils.py @@ -0,0 +1,31 @@ +import paddle +def convert_to_varlen( + q, + k, + v, + causal, + startend_row_indices, +): + b, sq, hq, d = q.shape + _, skv, hkv, dv = v.shape + assert sq == skv + q_varlen = q.reshape([b * sq, hq, d]) + k_varlen = k.reshape([b * skv, hkv, d]) + v_varlen = v.reshape([b * skv, hkv, dv]) + + cu_seqlens_q = paddle.to_tensor([0, b * sq], dtype=paddle.int32) + cu_seqlens_k = paddle.to_tensor([0, b * skv], dtype=paddle.int32) + + max_seqlen_q = b * sq + max_seqlen_k = b * skv + + return { + "q": q_varlen, + "k": k_varlen, + "v": v_varlen, + "cu_seqlens_q": cu_seqlens_q, + "cu_seqlens_k": cu_seqlens_k, + "max_seqlen_q": max_seqlen_q, + "max_seqlen_k": max_seqlen_k, + "causal": causal, + } From 29aec3cba599a80e320618c937184ece7de58249 Mon Sep 17 00:00:00 2001 From: umiswing Date: Tue, 14 Apr 2026 14:31:44 +0800 Subject: [PATCH 02/10] implement convert to varlen --- varlen_utils.py | 68 ++++++++++++++++++++++++++++++++++++++++++------- 1 file changed, 59 insertions(+), 9 deletions(-) diff --git a/varlen_utils.py b/varlen_utils.py index f58f1eb..4841dd7 100644 --- a/varlen_utils.py +++ b/varlen_utils.py @@ -1,4 +1,7 @@ import paddle +import torch + + def convert_to_varlen( q, k, @@ -9,23 +12,70 @@ def convert_to_varlen( b, sq, hq, d = q.shape _, skv, hkv, dv = v.shape assert sq == skv + + # ── Extract document boundaries from startend_row_indices ─────────── + # startend_row_indices shape: (batch, nheads_sri, seqlen_k, bound_num) + # Column 0 encodes the "end of document" boundary for each key position. + # Tokens within the same document share the same value. + # Document boundaries are where this value changes. + s = startend_row_indices[0, 0, :, 0] # (seqlen_k,) + + # Find positions where values change -> document start positions + diff = paddle.not_equal(s[1:], s[:-1]) # (seqlen_k - 1,) + change_idx = paddle.nonzero(diff).flatten().cast(paddle.int32) + 1 + + # The real end of documents = max value in column 0. + # For causal: equals seqlen_k (padding absorbed into last doc). + # For non-causal: may be < seqlen_k (padding rows attend to nothing). + real_end = int(s.max().item()) + + # Always use seqlen_k as last boundary so padding tokens are included + # in the last document (their KV is visible to the last doc's rows). + boundaries = paddle.concat([ + paddle.zeros([1], dtype=paddle.int32), + change_idx, + paddle.to_tensor([skv], dtype=paddle.int32), + ]) # (num_docs + 1,) + + # ── Flatten q, k, v: (batch, seqlen, heads, dim) -> (total, heads, dim) q_varlen = q.reshape([b * sq, hq, d]) k_varlen = k.reshape([b * skv, hkv, d]) v_varlen = v.reshape([b * skv, hkv, dv]) - cu_seqlens_q = paddle.to_tensor([0, b * sq], dtype=paddle.int32) - cu_seqlens_k = paddle.to_tensor([0, b * skv], dtype=paddle.int32) + # ── Build cu_seqlens for all batch items ──────────────────────────── + batch_offsets = (paddle.arange(b, dtype=paddle.int32) * skv).unsqueeze(1) + per_batch_starts = boundaries[:-1].unsqueeze(0) + batch_offsets # (b, num_docs) + cu_seqlens = paddle.concat([ + per_batch_starts.reshape([-1]), + paddle.to_tensor([b * skv], dtype=paddle.int32), + ]) - max_seqlen_q = b * sq - max_seqlen_k = b * skv + # ── Max sequence length (max document length) ─────────────────────── + doc_lengths = boundaries[1:] - boundaries[:-1] + max_seqlen = int(doc_lengths.max().item()) - return { + result = { "q": q_varlen, "k": k_varlen, "v": v_varlen, - "cu_seqlens_q": cu_seqlens_q, - "cu_seqlens_k": cu_seqlens_k, - "max_seqlen_q": max_seqlen_q, - "max_seqlen_k": max_seqlen_k, + "cu_seqlens_q": cu_seqlens, + "cu_seqlens_k": cu_seqlens, + "max_seqlen_q": max_seqlen, + "max_seqlen_k": max_seqlen, "causal": causal, } + + # For non-causal masks with trailing padding: padding rows attend to + # nothing in flashmask (zero output), but varlen computes non-zero + # output for them. Zero out padding rows to match flashmask. + if real_end < skv: + _b, _sq, _real_end = b, sq, real_end + + def output_to_padded(out_varlen_pt): + out_padded = out_varlen_pt.reshape(_b, _sq, -1, out_varlen_pt.shape[-1]) + out_padded[:, _real_end:] = 0 + return out_padded + + result["output_to_padded"] = output_to_padded + + return result From 8872ef0e0fa7eb3acaebd2d931f5896ac4f821c0 Mon Sep 17 00:00:00 2001 From: umiswing Date: Tue, 14 Apr 2026 16:26:22 +0800 Subject: [PATCH 03/10] add document mask diff batch test case --- generate_startend_row_indices.py | 102 +++++++++++++++++++++++++++++++ test_flashmask_to_varlen.py | 4 ++ 2 files changed, 106 insertions(+) diff --git a/generate_startend_row_indices.py b/generate_startend_row_indices.py index fc33912..ce74ad1 100644 --- a/generate_startend_row_indices.py +++ b/generate_startend_row_indices.py @@ -344,3 +344,105 @@ def generate_random_eviction_mask(batch_size, seqlen_q, seqlen_k, h, start_row=N startend_row_indices = paddle.clip(startend_row_indices, max=seqlen_q) causal = True return startend_row_indices, causal + + +def _scale_doc_seqlens(doc_seqlens, seqlen_k): + """Scale doc_seqlens from base 8192 to target seqlen_k.""" + if seqlen_k != 8192: + doc_seqlens = [int(d * (seqlen_k / 8192)) for d in doc_seqlens] + return doc_seqlens + + +# Pre-defined document-length distributions for per-batch variation. +_DIFF_BATCH_DOC_SEQLENS = [ + [2538, 1742, 3213], + [1500, 3500, 2493], + [3000, 1000, 3493], + [800, 2200, 4493], +] + + +def generate_causal_document_mask_diff_batch(batch_size, seqlen_q, seqlen_k, h, doc_seqlens_list=None): + """Causal document mask where each batch item has DIFFERENT document boundaries.""" + if doc_seqlens_list is None: + doc_seqlens_list = [ + _scale_doc_seqlens(_DIFF_BATCH_DOC_SEQLENS[i % len(_DIFF_BATCH_DOC_SEQLENS)], seqlen_k) + for i in range(batch_size) + ] + if seqlen_k != 8192: + print(f"{seqlen_k=}, auto setting per-batch doc_seqlens to {doc_seqlens_list}") + + batch_indices = [] + for bi in range(batch_size): + doc_seqlens = list(doc_seqlens_list[bi]) + total_seqlen = np.sum(doc_seqlens) + assert total_seqlen <= seqlen_k + assert len(doc_seqlens) >= 3 + padding = seqlen_k - np.sum(doc_seqlens) + doc_seqlens[-1] += padding + seq_cusums = np.cumsum(doc_seqlens) + + sri = np.repeat(seq_cusums, doc_seqlens) + batch_indices.append(sri) + + stacked = np.stack(batch_indices, axis=0) # (batch_size, seqlen_k) + startend_row_indices = paddle.to_tensor(stacked, dtype=paddle.int32).reshape( + (batch_size, 1, seqlen_k, 1) + ) + startend_row_indices = paddle.clip(startend_row_indices, max=seqlen_q) + + causal = True + return startend_row_indices, causal + + +def generate_document_mask_diff_batch(batch_size, seqlen_q, seqlen_k, h, doc_seqlens_list=None): + """Non-causal document mask where each batch item has DIFFERENT document boundaries.""" + if doc_seqlens_list is None: + doc_seqlens_list = [ + _scale_doc_seqlens(_DIFF_BATCH_DOC_SEQLENS[i % len(_DIFF_BATCH_DOC_SEQLENS)], seqlen_k) + for i in range(batch_size) + ] + if seqlen_k != 8192: + print(f"{seqlen_k=}, auto setting per-batch doc_seqlens to {doc_seqlens_list}") + + batch_down_left = [] + batch_up_right = [] + for bi in range(batch_size): + doc_seqlens = list(doc_seqlens_list[bi]) + total_seqlen = np.sum(doc_seqlens) + assert total_seqlen <= seqlen_k + assert len(doc_seqlens) >= 3 + padding = seqlen_k - np.sum(doc_seqlens) + + down_left_row_indices = [] + up_right_row_indices = [] + + cur_len_so_far = doc_seqlens[0] + for i in range(len(doc_seqlens)): + down_left_row_indices.extend([cur_len_so_far] * doc_seqlens[i]) + if i < len(doc_seqlens) - 1: + cur_len_so_far += doc_seqlens[i + 1] + if padding > 0: + down_left_row_indices.extend([cur_len_so_far] * padding) + + cur_len_so_far = 0 + for i in range(len(doc_seqlens)): + up_right_row_indices.extend([cur_len_so_far] * doc_seqlens[i]) + if i < len(doc_seqlens) - 1: + cur_len_so_far += doc_seqlens[i] + if padding > 0: + up_right_row_indices.extend([cur_len_so_far] * padding) + + batch_down_left.append(down_left_row_indices) + batch_up_right.append(up_right_row_indices) + + down_left = np.array(batch_down_left) # (batch_size, seqlen_k) + up_right = np.array(batch_up_right) # (batch_size, seqlen_k) + + down_left = paddle.to_tensor(down_left, dtype=paddle.int32).reshape((batch_size, 1, seqlen_k, 1)) + up_right = paddle.to_tensor(up_right, dtype=paddle.int32).reshape((batch_size, 1, seqlen_k, 1)) + startend_row_indices = paddle.concat([down_left, up_right], axis=-1) + startend_row_indices = paddle.clip(startend_row_indices, max=seqlen_q) + + causal = False + return startend_row_indices, causal diff --git a/test_flashmask_to_varlen.py b/test_flashmask_to_varlen.py index d8d621f..b64595c 100644 --- a/test_flashmask_to_varlen.py +++ b/test_flashmask_to_varlen.py @@ -39,6 +39,8 @@ from generate_startend_row_indices import ( generate_causal_document_mask, generate_document_mask, + generate_causal_document_mask_diff_batch, + generate_document_mask_diff_batch, ) from test_util import attention_ref @@ -101,6 +103,8 @@ def generate_shapes(): mask_generators = [ partial(generate_document_mask), # document partial(generate_causal_document_mask), # causal document + partial(generate_document_mask_diff_batch), # document + partial(generate_causal_document_mask_diff_batch), # causal document ] From 1db63bafc08338071663eec724b99c2d07a934e1 Mon Sep 17 00:00:00 2001 From: umiswing Date: Tue, 14 Apr 2026 18:19:25 +0800 Subject: [PATCH 04/10] support diff batch convert --- varlen_utils.py | 101 +++++++++++++++++++++++++++++------------------- 1 file changed, 61 insertions(+), 40 deletions(-) diff --git a/varlen_utils.py b/varlen_utils.py index 4841dd7..c9fb6d5 100644 --- a/varlen_utils.py +++ b/varlen_utils.py @@ -13,67 +13,88 @@ def convert_to_varlen( _, skv, hkv, dv = v.shape assert sq == skv - # ── Extract document boundaries from startend_row_indices ─────────── + # ── Extract document boundaries PER BATCH from startend_row_indices ── # startend_row_indices shape: (batch, nheads_sri, seqlen_k, bound_num) # Column 0 encodes the "end of document" boundary for each key position. - # Tokens within the same document share the same value. - # Document boundaries are where this value changes. - s = startend_row_indices[0, 0, :, 0] # (seqlen_k,) - - # Find positions where values change -> document start positions - diff = paddle.not_equal(s[1:], s[:-1]) # (seqlen_k - 1,) - change_idx = paddle.nonzero(diff).flatten().cast(paddle.int32) + 1 - - # The real end of documents = max value in column 0. - # For causal: equals seqlen_k (padding absorbed into last doc). - # For non-causal: may be < seqlen_k (padding rows attend to nothing). - real_end = int(s.max().item()) - - # Always use seqlen_k as last boundary so padding tokens are included - # in the last document (their KV is visible to the last doc's rows). - boundaries = paddle.concat([ - paddle.zeros([1], dtype=paddle.int32), - change_idx, - paddle.to_tensor([skv], dtype=paddle.int32), - ]) # (num_docs + 1,) + # Tokens within the same document share the same value; boundaries are + # where this value changes. Different batch items may have different + # document layouts, so we extract boundaries for each batch independently. + s = startend_row_indices[:, 0, :, 0] # (batch, seqlen_k) + + cu_seqlens_parts = [] + max_doc_len = 0 + needs_padding_fixup = False + real_ends = [] + + for bi in range(b): + s_bi = s[bi] # (seqlen_k,) + + # Find change positions -> document boundaries + diff_bi = paddle.not_equal(s_bi[1:], s_bi[:-1]) + change_idx_bi = paddle.nonzero(diff_bi).flatten().cast(paddle.int32) + 1 + + # Real end of documents (max value in column 0) + real_end_bi = int(s_bi.max().item()) + real_ends.append(real_end_bi) + if real_end_bi < skv: + needs_padding_fixup = True + + # Boundaries: [0, change_1, ..., seqlen_k] + # Always use seqlen_k as last boundary so padding KV (visible to + # the last doc's rows in flashmask) is included in the last doc. + boundaries_bi = paddle.concat([ + paddle.zeros([1], dtype=paddle.int32), + change_idx_bi, + paddle.to_tensor([skv], dtype=paddle.int32), + ]) + + # Track max document length across all batches + doc_lens_bi = boundaries_bi[1:] - boundaries_bi[:-1] + max_doc_len = max(max_doc_len, int(doc_lens_bi.max().item())) + + # Collect document start positions with batch offset + cu_seqlens_parts.append(boundaries_bi[:-1] + bi * skv) + + # Build cu_seqlens: concat per-batch starts + final endpoint + cu_seqlens = paddle.concat( + cu_seqlens_parts + [paddle.to_tensor([b * skv], dtype=paddle.int32)] + ) # ── Flatten q, k, v: (batch, seqlen, heads, dim) -> (total, heads, dim) q_varlen = q.reshape([b * sq, hq, d]) k_varlen = k.reshape([b * skv, hkv, d]) v_varlen = v.reshape([b * skv, hkv, dv]) - # ── Build cu_seqlens for all batch items ──────────────────────────── - batch_offsets = (paddle.arange(b, dtype=paddle.int32) * skv).unsqueeze(1) - per_batch_starts = boundaries[:-1].unsqueeze(0) + batch_offsets # (b, num_docs) - cu_seqlens = paddle.concat([ - per_batch_starts.reshape([-1]), - paddle.to_tensor([b * skv], dtype=paddle.int32), - ]) - - # ── Max sequence length (max document length) ─────────────────────── - doc_lengths = boundaries[1:] - boundaries[:-1] - max_seqlen = int(doc_lengths.max().item()) - result = { "q": q_varlen, "k": k_varlen, "v": v_varlen, "cu_seqlens_q": cu_seqlens, "cu_seqlens_k": cu_seqlens, - "max_seqlen_q": max_seqlen, - "max_seqlen_k": max_seqlen, + "max_seqlen_q": max_doc_len, + "max_seqlen_k": max_doc_len, "causal": causal, } # For non-causal masks with trailing padding: padding rows attend to # nothing in flashmask (zero output), but varlen computes non-zero - # output for them. Zero out padding rows to match flashmask. - if real_end < skv: - _b, _sq, _real_end = b, sq, real_end + # output for them. Zero out padding rows per batch to match flashmask. + # Note: real_end can differ across batch items. + if needs_padding_fixup: + _b, _sq = b, sq + _real_ends = real_ends def output_to_padded(out_varlen_pt): - out_padded = out_varlen_pt.reshape(_b, _sq, -1, out_varlen_pt.shape[-1]) - out_padded[:, _real_end:] = 0 + nh = out_varlen_pt.shape[1] + dv_out = out_varlen_pt.shape[2] + out_padded = out_varlen_pt.reshape(_b, _sq, nh, dv_out) + # Vectorised per-batch zeroing + row_idx = torch.arange(_sq, device=out_padded.device) + real_end_t = torch.tensor( + _real_ends, device=out_padded.device, dtype=torch.int64, + ).unsqueeze(1) + padding_mask = row_idx.unsqueeze(0) >= real_end_t # (b, sq) + out_padded[padding_mask] = 0 return out_padded result["output_to_padded"] = output_to_padded From ca43668006094c70ee9ee1066c548aafda06b1ad Mon Sep 17 00:00:00 2001 From: umiswing Date: Tue, 14 Apr 2026 20:59:30 +0800 Subject: [PATCH 05/10] fix causal document mask padding, add causal document mask simu --- generate_startend_row_indices.py | 37 ++++++++++++++++++++++++++++++-- 1 file changed, 35 insertions(+), 2 deletions(-) diff --git a/generate_startend_row_indices.py b/generate_startend_row_indices.py index ce74ad1..aacf682 100644 --- a/generate_startend_row_indices.py +++ b/generate_startend_row_indices.py @@ -76,10 +76,11 @@ def generate_causal_document_mask(batch_size, seqlen_q, seqlen_k, h, doc_seqlens assert total_seqlen <= seqlen_k assert len(doc_seqlens) >= 3 padding = seqlen_k - np.sum(doc_seqlens) - doc_seqlens[-1] += padding seq_cusums = np.cumsum(doc_seqlens) startend_row_indices = np.repeat(seq_cusums, doc_seqlens) + padding_mask = np.repeat(seq_cusums[-1], padding) + startend_row_indices = np.concatenate([startend_row_indices, padding_mask]) startend_row_indices = paddle.to_tensor(startend_row_indices, dtype=paddle.int32).reshape((1, 1, seqlen_k, 1)).repeat_interleave(batch_size, 0) startend_row_indices = paddle.clip(startend_row_indices, max=seqlen_q) @@ -379,10 +380,11 @@ def generate_causal_document_mask_diff_batch(batch_size, seqlen_q, seqlen_k, h, assert total_seqlen <= seqlen_k assert len(doc_seqlens) >= 3 padding = seqlen_k - np.sum(doc_seqlens) - doc_seqlens[-1] += padding seq_cusums = np.cumsum(doc_seqlens) sri = np.repeat(seq_cusums, doc_seqlens) + padding_mask = np.repeat(seq_cusums[-1], padding) + sri = np.concatenate([sri, padding_mask]) batch_indices.append(sri) stacked = np.stack(batch_indices, axis=0) # (batch_size, seqlen_k) @@ -446,3 +448,34 @@ def generate_document_mask_diff_batch(batch_size, seqlen_q, seqlen_k, h, doc_seq causal = False return startend_row_indices, causal + +def generate_document_mask_simu(batch_size, seqlen_q, seqlen_k, h, doc_seqlens=None): + assert seqlen_q == seqlen_k + lts, causal = generate_causal_document_mask(batch_size, seqlen_q, seqlen_k, h, doc_seqlens) + causal = False + + b, h, s, _ = lts.shape + ute = paddle.arange( + 0, s, 1, dtype="int32" + ).reshape((1, 1, s, 1)).repeat_interleave(b, 0).repeat_interleave(h, 1) + + ute = paddle.where(ute <= lts, ute, lts) + startend_row_indices = paddle.concat([lts, ute], axis=-1) + + return startend_row_indices, causal + +def generate_document_mask_diff_batch_simu(batch_size, seqlen_q, seqlen_k, h, doc_seqlens_list=None): + assert seqlen_q == seqlen_k + lts, causal = generate_causal_document_mask_diff_batch(batch_size, seqlen_q, seqlen_k, h, doc_seqlens_list) + causal = False + + b, h, s, _ = lts.shape + ute = paddle.arange( + 0, s, 1, dtype="int32" + ).reshape((1, 1, s, 1)).repeat_interleave(b, 0).repeat_interleave(h, 1) + + ute = paddle.where(ute <= lts, ute, lts) + startend_row_indices = paddle.concat([lts, ute], axis=-1) + + return startend_row_indices, causal + From a489828d1a0d21a1ac5e6faf55c025ad7696c0f6 Mon Sep 17 00:00:00 2001 From: umiswing Date: Tue, 14 Apr 2026 22:02:01 +0800 Subject: [PATCH 06/10] fix varlen convert when simu causal with lts + ute --- varlen_utils.py | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/varlen_utils.py b/varlen_utils.py index c9fb6d5..c4ec41b 100644 --- a/varlen_utils.py +++ b/varlen_utils.py @@ -65,6 +65,21 @@ def convert_to_varlen( k_varlen = k.reshape([b * skv, hkv, d]) v_varlen = v.reshape([b * skv, hkv, dv]) + # ── Detect simulated causal masks ────────────────────────────────── + # A causal document mask can be encoded as causal=False with bound_num=2 + # (LTS + UTE) where UTE[j] = min(j, LTS[j]) reproduces the causal + # diagonal. For true non-causal masks, UTE is constant within each + # document (= document start), which differs from min(j, LTS[j]). + varlen_causal = causal + bound_num = startend_row_indices.shape[-1] + if not causal and bound_num == 2: + lts_all = startend_row_indices[:, 0, :, 0] # (b, skv) + ute_all = startend_row_indices[:, 0, :, 1] # (b, skv) + arange_ref = paddle.arange(skv, dtype=paddle.int32).unsqueeze(0) # (1, skv) + expected_causal_ute = paddle.minimum(arange_ref, lts_all) # (b, skv) + if paddle.equal_all(ute_all, expected_causal_ute).item(): + varlen_causal = True + result = { "q": q_varlen, "k": k_varlen, @@ -73,7 +88,7 @@ def convert_to_varlen( "cu_seqlens_k": cu_seqlens, "max_seqlen_q": max_doc_len, "max_seqlen_k": max_doc_len, - "causal": causal, + "causal": varlen_causal, } # For non-causal masks with trailing padding: padding rows attend to From 79e2498a2f3d96f4e9331977486f11166103b330 Mon Sep 17 00:00:00 2001 From: umiswing Date: Wed, 15 Apr 2026 12:09:02 +0800 Subject: [PATCH 07/10] add simu causal document mask case to test_flashmask_to_varlen.py --- test_flashmask_to_varlen.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/test_flashmask_to_varlen.py b/test_flashmask_to_varlen.py index b64595c..db0d6f6 100644 --- a/test_flashmask_to_varlen.py +++ b/test_flashmask_to_varlen.py @@ -41,6 +41,8 @@ generate_document_mask, generate_causal_document_mask_diff_batch, generate_document_mask_diff_batch, + generate_document_mask_simu, + generate_document_mask_diff_batch_simu, ) from test_util import attention_ref @@ -105,6 +107,8 @@ def generate_shapes(): partial(generate_causal_document_mask), # causal document partial(generate_document_mask_diff_batch), # document partial(generate_causal_document_mask_diff_batch), # causal document + partial(generate_document_mask_simu), # simu causal document + partial(generate_document_mask_diff_batch_simu), # simu causal document diff batch ] From 5de34fb54d0d2457887b3b83e0bfcad1da58d770 Mon Sep 17 00:00:00 2001 From: umiswing Date: Wed, 15 Apr 2026 19:50:42 +0800 Subject: [PATCH 08/10] adapt test_flashmask_to_varlen.py to flashmask's use_varlen. add bwd test --- test_flashmask_to_varlen.py | 105 +++++++++++++++--------------------- 1 file changed, 44 insertions(+), 61 deletions(-) diff --git a/test_flashmask_to_varlen.py b/test_flashmask_to_varlen.py index db0d6f6..cf27b01 100644 --- a/test_flashmask_to_varlen.py +++ b/test_flashmask_to_varlen.py @@ -24,16 +24,8 @@ import torch # ── Paddle: flashmask_attention ────────────────────────────────────────────── -try: - from flash_mask.cute.interface import flashmask_attention -except (ImportError, ModuleNotFoundError): - from paddle.nn.functional.flash_attention import flashmask_attention - -# ── PyTorch: flash_attn_varlen_func (FA4 cute) ────────────────────────────── -from flash_attn.cute.interface import flash_attn_varlen_func - -# ── convert_to_varlen (fake implementation in varlen_utils.py) ─────────────── -from varlen_utils import convert_to_varlen +from flash_mask import flashmask_attention +import flash_mask # ── Mask generators (Paddle) ──────────────────────────────────────────────── from generate_startend_row_indices import ( @@ -153,30 +145,6 @@ def test_flashmask_to_varlen( batch_size, seqlen_q, seqlen_k, nheads_startend_row_indices ) - # ── 2. Convert to varlen format ────────────────────────────────────────── - # convert_to_varlen returns Paddle tensors in a dict: - # "q": (total_q, nheads, d) Paddle - # "k": (total_k, nheads_kv, d) Paddle - # "v": (total_k, nheads_kv, dv) Paddle - # "cu_seqlens_q": (num_seqs + 1,) Paddle int32 - # "cu_seqlens_k": (num_seqs + 1,) Paddle int32 - # "max_seqlen_q": int - # "max_seqlen_k": int - # "causal": bool - - varlen = convert_to_varlen( - q_paddle, k_paddle, v_paddle, causal, startend_row_indices - ) - - q_vl_paddle = varlen["q"] - k_vl_paddle = varlen["k"] - v_vl_paddle = varlen["v"] - cu_seqlens_q_paddle = varlen["cu_seqlens_q"] - cu_seqlens_k_paddle = varlen["cu_seqlens_k"] - max_seqlen_q = varlen["max_seqlen_q"] - max_seqlen_k = varlen["max_seqlen_k"] - varlen_causal = varlen.get("causal", causal) - # ── 3. Call Paddle's flashmask_attention ───────────────────────────────── paddle.set_flags({"FLAGS_flash_attn_version": 4}) @@ -191,7 +159,7 @@ def test_flashmask_to_varlen( k_fm.stop_gradient = False v_fm.stop_gradient = False - out_fm, lse_fm = flashmask_attention( + out_fm, lse_fm = flash_mask.cute.interface.flashmask_attention( q_fm, k_fm, v_fm, @@ -200,48 +168,63 @@ def test_flashmask_to_varlen( return_softmax_lse=True, ) + q_varlen = q_paddle.detach().clone() + k_varlen = k_paddle.detach().clone() + v_varlen = v_paddle.detach().clone() + q_varlen.stop_gradient = False + k_varlen.stop_gradient = False + v_varlen.stop_gradient = False + # ── 4. Call PyTorch's flash_attn_varlen_func ───────────────────────────── # Convert Paddle varlen tensors to PyTorch CUDA tensors - q_varlen_pt = paddle_to_torch(q_vl_paddle).contiguous() - k_varlen_pt = paddle_to_torch(k_vl_paddle).contiguous() - v_varlen_pt = paddle_to_torch(v_vl_paddle).contiguous() - cu_seqlens_q_pt = paddle_to_torch(cu_seqlens_q_paddle).contiguous().to(torch.int32) - cu_seqlens_k_pt = paddle_to_torch(cu_seqlens_k_paddle).contiguous().to(torch.int32) - - out_varlen_pt, lse_varlen_pt = flash_attn_varlen_func( - q_varlen_pt, - k_varlen_pt, - v_varlen_pt, - cu_seqlens_q=cu_seqlens_q_pt, - cu_seqlens_k=cu_seqlens_k_pt, - max_seqlen_q=max_seqlen_q, - max_seqlen_k=max_seqlen_k, - causal=varlen_causal, - return_lse=True, + out_varlen, lse_varlen = flashmask_attention( + q_varlen, + k_varlen, + v_varlen, + startend_row_indices=startend_row_indices, + causal=causal, + return_softmax_lse=True, + use_varlen=True ) # ── 5. Compare outputs ─────────────────────────────────────────────────── - # Map varlen output back to padded layout for comparison. - # If convert_to_varlen provides an "output_to_padded" callable, use it; - # otherwise reshape the flat (total_q, nheads, dv) back to - # (batch_size, seqlen_q, nheads, dv). - if "output_to_padded" in varlen: - out_varlen_padded_pt = varlen["output_to_padded"](out_varlen_pt) - else: - out_varlen_padded_pt = out_varlen_pt.reshape(batch_size, seqlen_q, nheads, dv) # Convert both outputs to float32 numpy for comparison out_fm_np = paddle.cast(out_fm, paddle.float32).numpy() - out_vl_np = torch_to_paddle(out_varlen_padded_pt).cast(paddle.float32).numpy() + out_vl_np = paddle.cast(out_varlen, paddle.float32).numpy() max_diff = np.max(np.abs(out_fm_np - out_vl_np)) mean_diff = np.mean(np.abs(out_fm_np - out_vl_np)) - print(f"\n[flashmask vs varlen] max diff: {max_diff:.6e}, mean diff: {mean_diff:.6e}") + print(f"\n[fwd] max diff: {max_diff:.6e}, mean diff: {mean_diff:.6e}") assert np.allclose(out_fm_np, out_vl_np, rtol=1e-2, atol=1e-2), ( f"Output mismatch: max diff {max_diff:.6e}, mean diff {mean_diff:.6e}" ) + # ── 6. Backward ────────────────────────────────────────────────────────── + # Generate the same random gradient for both paths. + g_fm = paddle.randn(shape=out_fm.shape, dtype=out_fm.dtype) + + # Flashmask backward (Paddle) + out_fm.backward(g_fm) + + g_vl = g_fm.detach().clone() + out_varlen.backward(g_vl) + + for name, grad_fm, grad_vl in [ + ("dQ", q_fm.grad, q_varlen.grad), + ("dK", k_fm.grad, k_varlen.grad), + ("dV", v_fm.grad, v_varlen.grad), + ]: + grad_fm_np = paddle.cast(grad_fm, paddle.float32).numpy() + grad_vl_np = paddle.cast(grad_vl, paddle.float32).numpy() + + max_diff = np.max(np.abs(grad_fm_np - grad_vl_np)) + mean_diff = np.mean(np.abs(grad_fm_np - grad_vl_np)) + print(f"[bwd {name}] max diff: {max_diff:.6e}, mean diff: {mean_diff:.6e}") + assert np.allclose(grad_fm_np, grad_vl_np, rtol=1e-2, atol=1e-2), ( + f"{name} mismatch: max diff {max_diff:.6e}, mean diff {mean_diff:.6e}" + ) # ───────────────────────────────────────────────────────────────────────────── # Standalone runner From 1497a839545dc3ae4e9b67d9f6496e4056693295 Mon Sep 17 00:00:00 2001 From: umiswing Date: Thu, 16 Apr 2026 16:58:19 +0800 Subject: [PATCH 09/10] add fa4 varlen benchmark --- benchmark_flashmask.py | 68 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 68 insertions(+) diff --git a/benchmark_flashmask.py b/benchmark_flashmask.py index 548a69d..72b78d3 100644 --- a/benchmark_flashmask.py +++ b/benchmark_flashmask.py @@ -9,6 +9,8 @@ from flash_mask.cute.interface import flashmask_attention except (ImportError, ModuleNotFoundError): from paddle.nn.functional.flash_attention import flashmask_attention + +from flash_mask import flashmask_attention as flashmask_attention_interface import random import os from datetime import datetime @@ -207,6 +209,68 @@ def flashmask_bwd(): return fwd_time_ms, bwd_time_ms, total_time_ms, fwd_flops, bwd_flops, total_flops, fwd_tflops, bwd_tflops, total_tflops, sparsity +def test_mask_varlen( + generate_mask_fn, + B, + S, + SKV, + H, + HKV, + D, + dtype = 'bf16', +): + if dtype == 'bf16': + data_type = paddle.bfloat16 + else: + data_type = paddle.float16 + + query = paddle.randn([B, S, H, D], dtype=data_type) + key = paddle.randn([B, SKV, HKV, D], dtype=data_type) + value = paddle.randn([B, SKV, HKV, D], dtype=data_type) + gradOut = paddle.randn([B, S, H, D], dtype=data_type) + + query.stop_gradient = False + key.stop_gradient = False + value.stop_gradient = False + + startend_row_indices, causal = None, True + if generate_mask_fn is not None: + startend_row_indices, causal = generate_mask_fn(B, SKV, HKV, D) + + sparsity = flashmask_block_sparsity(causal, startend_row_indices, B, H, HKV, S, SKV) + density = 1.0 - sparsity + + def fwd_bwd(): + q = query.detach().clone() + k = key.detach().clone() + v = value.detach().clone() + q.stop_gradient = False + k.stop_gradient = False + v.stop_gradient = False + out, lse = flashmask_attention_interface( + q, k, v, + startend_row_indices=startend_row_indices, + causal=causal, + return_softmax_lse=True, + use_varlen=True, + ) + out.backward(gradOut, retain_graph=True) + + total_time_ms = do_bench(fwd_bwd) + + total_flops = density * cal_flops(B, H, S, SKV, D, mode='fwd_bwd') + total_tflops = cal_tflops(total_flops, total_time_ms) + + # Return format consistent with test_mask; fwd/bwd individual values are N/A + fwd_time_ms = float('nan') + bwd_time_ms = float('nan') + fwd_flops = float('nan') + bwd_flops = float('nan') + fwd_tflops = float('nan') + bwd_tflops = float('nan') + + return fwd_time_ms, bwd_time_ms, total_time_ms, fwd_flops, bwd_flops, total_flops, fwd_tflops, bwd_tflops, total_tflops, sparsity + def flashmask_block_sparsity( causal, flashmask, @@ -797,6 +861,10 @@ def main(examples: List[str] = ["all"], dtype='bf16', fm_version=1, suffix="_bas # Note(umiswing): support load mask and hybrid mask like this, and also, support simulate cp benchmark # "Dumped Mask": lambda: test_mask(generate_mask_fn=partial(load_mask, path=mask_path, causal=False, cp_size=cp_size, cp_rank=cp_rank), B=B, S=SQ, SKV=SKV, H=H, HKV=HKV, D=D, dtype=dtype), # "Hybrid SWA": lambda: test_mask(generate_mask_fn=partial(load_mask, path=mask_path, causal=False, cp_size=cp_size, cp_rank=cp_rank, hybrid_mask_fn=partial(hybrid_swa, window_size=512, swa_ratio=0.75)), B=B, S=SQ, SKV=SKV, H=H, HKV=HKV, D=D, dtype=dtype), + + # Varlen benchmarks (fwd+bwd combined, use_varlen=True) + "Varlen Causal Document Mask": lambda: test_mask_varlen(generate_mask_fn=partial(generate_causal_document_mask, doc_seq_lens=doc_seq_lens), B=B, S=SQ, SKV=SKV, H=H, HKV=HKV, D=D, dtype=dtype), + "Varlen Document Mask": lambda: test_mask_varlen(generate_mask_fn=partial(generate_document_mask, doc_seq_lens=doc_seq_lens), B=B, S=SQ, SKV=SKV, H=H, HKV=HKV, D=D, dtype=dtype), } if "all" in examples: From 496a8e9e3141c12912b02d6732c55d42f3d04b15 Mon Sep 17 00:00:00 2001 From: umiswing Date: Thu, 16 Apr 2026 21:51:15 +0800 Subject: [PATCH 10/10] add mla benchmark (d != dv) --- benchmark_flashmask.py | 64 ++++++++++++++++++++++++------------------ 1 file changed, 36 insertions(+), 28 deletions(-) diff --git a/benchmark_flashmask.py b/benchmark_flashmask.py index 72b78d3..5f2283d 100644 --- a/benchmark_flashmask.py +++ b/benchmark_flashmask.py @@ -119,7 +119,10 @@ def test_mask( HKV, D, dtype = 'bf16', + DV = None, ): + if DV is None: + DV = D if dtype == 'bf16': data_type = paddle.bfloat16 @@ -128,8 +131,8 @@ def test_mask( query = paddle.randn([B, S, H, D], dtype=data_type) key = paddle.randn([B, SKV, HKV, D], dtype=data_type) - value = paddle.randn([B, SKV, HKV, D], dtype=data_type) - gradOut = paddle.randn([B, S, H, D], dtype=data_type) + value = paddle.randn([B, SKV, HKV, DV], dtype=data_type) + gradOut = paddle.randn([B, S, H, DV], dtype=data_type) query.stop_gradient = False key.stop_gradient = False @@ -218,7 +221,11 @@ def test_mask_varlen( HKV, D, dtype = 'bf16', + DV = None, ): + if DV is None: + DV = D + if dtype == 'bf16': data_type = paddle.bfloat16 else: @@ -226,8 +233,8 @@ def test_mask_varlen( query = paddle.randn([B, S, H, D], dtype=data_type) key = paddle.randn([B, SKV, HKV, D], dtype=data_type) - value = paddle.randn([B, SKV, HKV, D], dtype=data_type) - gradOut = paddle.randn([B, S, H, D], dtype=data_type) + value = paddle.randn([B, SKV, HKV, DV], dtype=data_type) + gradOut = paddle.randn([B, S, H, DV], dtype=data_type) query.stop_gradient = False key.stop_gradient = False @@ -818,8 +825,9 @@ def main(examples: List[str] = ["all"], dtype='bf16', fm_version=1, suffix="_bas qksparse_mask = eval(line.split(":")[-1].split("#")[1].strip()) doc_seq_lens_list.append((total_length, doc_list, qksparse_mask)) #doc_seq_lens_list = doc_seq_lens_list[::-1] - for D in [128] if fm_version == 4 else [64, 128, 256]: - H = 4096 // D + # (D, DV, H): standard configs + MLA (d=192, dv=128) + head_configs = [(64, 64, 64), (128, 128, 32), (192, 128, 16), (256, 256, 16)] + for D, DV, H in head_configs: HKV = H for idx, (S, prefix_doc_seq_lens, qksparse_mask) in enumerate(doc_seq_lens_list): B = 128 * 1024 // S @@ -829,10 +837,10 @@ def main(examples: List[str] = ["all"], dtype='bf16', fm_version=1, suffix="_bas doc_seq_lens = [x[1] for x in prefix_doc_seq_lens] maskout_pair = [] offset = 0 - print(f"{B}_{S}_{H}_{HKV}_{D}_{idx}_{dtype}") + print(f"{B}_{S}_{H}_{HKV}_{D}_{DV}_{idx}_{dtype}") if not overwrite: - if os.path.exists(f"{dtype}{suffix}/flashmaskv{fm_version}_{B}_{S}_{H}_{D}_{idx}.csv"): - print(f"{dtype}{suffix}/flashmaskv{fm_version}_{B}_{S}_{H}_{D}_{idx}.csv already exists, skipping. To enable overwrite, use: --overwrite (True by default).") + if os.path.exists(f"{dtype}{suffix}/flashmaskv{fm_version}_{B}_{S}_{H}_{D}_{DV}_{idx}.csv"): + print(f"{dtype}{suffix}/flashmaskv{fm_version}_{B}_{S}_{H}_{D}_{DV}_{idx}.csv already exists, skipping. To enable overwrite, use: --overwrite (True by default).") continue if sum(qksparse_mask) == 0: maskout_pair = [(1024, 538), (2358, 1700)] @@ -844,27 +852,27 @@ def main(examples: List[str] = ["all"], dtype='bf16', fm_version=1, suffix="_bas share_qa_docs = [split_sequence(doc_seq) for doc_seq in doc_seq_lens] available_examples = { - "Full": lambda: test_mask(generate_mask_fn=partial(generate_none_mask, causal=False), B=B, S=SQ, SKV=SKV, H=H, HKV=HKV, D=D, dtype=dtype), - "Causal": lambda: test_mask(generate_mask_fn=partial(generate_none_mask, causal=True), B=B, S=SQ, SKV=SKV, H=H, HKV=HKV, D=D, dtype=dtype), - "Sliding Window": lambda: test_mask(generate_mask_fn=partial(generate_sliding_window_mask, window_size=int(S*0.0625)), B=B, S=SQ, SKV=SKV, H=H, HKV=HKV, D=D, dtype=dtype), - "Causal Document Mask": lambda: test_mask(generate_mask_fn=partial(generate_causal_document_mask, doc_seq_lens=doc_seq_lens), B=B, S=SQ, SKV=SKV, H=H, HKV=HKV, D=D, dtype=dtype), - "Document Mask": lambda: test_mask(generate_mask_fn=partial(generate_document_mask, doc_seq_lens=doc_seq_lens), B=B, S=SQ, SKV=SKV, H=H, HKV=HKV, D=D, dtype=dtype), - "Share Question Mask": lambda: test_mask(generate_mask_fn=partial(generate_share_question_mask, doc_seq_lens=share_qa_docs), B=B, S=SQ, SKV=SKV, H=H, HKV=HKV, D=D, dtype=dtype), - # "Global Sliding Window": lambda: test_mask(generate_mask_fn=partial(generate_global_sliding_window_mask, global_token=16, window_size=(int(S*0.0625), int(S*0.0625))), B=B, S=SQ, SKV=SKV, H=H, HKV=HKV, D=D, dtype=dtype), - "Causal Blockwise Mask": lambda: test_mask(generate_mask_fn=partial(generate_causal_blockwise_mask, doc_seq_lens=doc_seq_lens), B=B, S=SQ, SKV=SKV, H=H, HKV=HKV, D=D, dtype=dtype), - "Prefix LM Document Mask": lambda: test_mask(generate_mask_fn=partial(generate_prefix_lm_document_mask, doc_seq_lens=prefix_doc_seq_lens), B=B, S=SQ, SKV=SKV, H=H, HKV=HKV, D=D, dtype=dtype), - "Prefix LM Causal Mask": lambda: test_mask(generate_mask_fn=partial(generate_prefix_lm_causal_mask, prefix_length=int(S*0.5)), B=B, S=SQ, SKV=SKV, H=H, HKV=HKV, D=D, dtype=dtype), - "QK-sparse Mask": lambda: test_mask(generate_mask_fn=partial(generate_qk_sparse_mask, maskout_pair=maskout_pair), B=B, S=SQ, SKV=SKV, H=H, HKV=HKV, D=D, dtype=dtype), - "Random Eviction Mask": lambda: test_mask(generate_mask_fn=partial(generate_random_eviction_mask, start_row=S//2), B=B, S=SQ, SKV=SKV, H=H, HKV=HKV, D=D, dtype=dtype), - # "Hybrid SWA Prefix LM Doc": lambda: test_mask(generate_mask_fn=partial(generate_hybrid_swa_prefix_lm_document_mask, doc_seq_lens=prefix_doc_seq_lens), B=B, S=SQ, SKV=SKV, H=H, HKV=HKV, D=D, dtype=dtype), + "Full": lambda: test_mask(generate_mask_fn=partial(generate_none_mask, causal=False), B=B, S=SQ, SKV=SKV, H=H, HKV=HKV, D=D, dtype=dtype, DV=DV), + "Causal": lambda: test_mask(generate_mask_fn=partial(generate_none_mask, causal=True), B=B, S=SQ, SKV=SKV, H=H, HKV=HKV, D=D, dtype=dtype, DV=DV), + "Sliding Window": lambda: test_mask(generate_mask_fn=partial(generate_sliding_window_mask, window_size=int(S*0.0625)), B=B, S=SQ, SKV=SKV, H=H, HKV=HKV, D=D, dtype=dtype, DV=DV), + "Causal Document Mask": lambda: test_mask(generate_mask_fn=partial(generate_causal_document_mask, doc_seq_lens=doc_seq_lens), B=B, S=SQ, SKV=SKV, H=H, HKV=HKV, D=D, dtype=dtype, DV=DV), + "Document Mask": lambda: test_mask(generate_mask_fn=partial(generate_document_mask, doc_seq_lens=doc_seq_lens), B=B, S=SQ, SKV=SKV, H=H, HKV=HKV, D=D, dtype=dtype, DV=DV), + "Share Question Mask": lambda: test_mask(generate_mask_fn=partial(generate_share_question_mask, doc_seq_lens=share_qa_docs), B=B, S=SQ, SKV=SKV, H=H, HKV=HKV, D=D, dtype=dtype, DV=DV), + # "Global Sliding Window": lambda: test_mask(generate_mask_fn=partial(generate_global_sliding_window_mask, global_token=16, window_size=(int(S*0.0625), int(S*0.0625))), B=B, S=SQ, SKV=SKV, H=H, HKV=HKV, D=D, dtype=dtype, DV=DV), + "Causal Blockwise Mask": lambda: test_mask(generate_mask_fn=partial(generate_causal_blockwise_mask, doc_seq_lens=doc_seq_lens), B=B, S=SQ, SKV=SKV, H=H, HKV=HKV, D=D, dtype=dtype, DV=DV), + "Prefix LM Document Mask": lambda: test_mask(generate_mask_fn=partial(generate_prefix_lm_document_mask, doc_seq_lens=prefix_doc_seq_lens), B=B, S=SQ, SKV=SKV, H=H, HKV=HKV, D=D, dtype=dtype, DV=DV), + "Prefix LM Causal Mask": lambda: test_mask(generate_mask_fn=partial(generate_prefix_lm_causal_mask, prefix_length=int(S*0.5)), B=B, S=SQ, SKV=SKV, H=H, HKV=HKV, D=D, dtype=dtype, DV=DV), + "QK-sparse Mask": lambda: test_mask(generate_mask_fn=partial(generate_qk_sparse_mask, maskout_pair=maskout_pair), B=B, S=SQ, SKV=SKV, H=H, HKV=HKV, D=D, dtype=dtype, DV=DV), + "Random Eviction Mask": lambda: test_mask(generate_mask_fn=partial(generate_random_eviction_mask, start_row=S//2), B=B, S=SQ, SKV=SKV, H=H, HKV=HKV, D=D, dtype=dtype, DV=DV), + # "Hybrid SWA Prefix LM Doc": lambda: test_mask(generate_mask_fn=partial(generate_hybrid_swa_prefix_lm_document_mask, doc_seq_lens=prefix_doc_seq_lens), B=B, S=SQ, SKV=SKV, H=H, HKV=HKV, D=D, dtype=dtype, DV=DV), # Note(umiswing): support load mask and hybrid mask like this, and also, support simulate cp benchmark - # "Dumped Mask": lambda: test_mask(generate_mask_fn=partial(load_mask, path=mask_path, causal=False, cp_size=cp_size, cp_rank=cp_rank), B=B, S=SQ, SKV=SKV, H=H, HKV=HKV, D=D, dtype=dtype), - # "Hybrid SWA": lambda: test_mask(generate_mask_fn=partial(load_mask, path=mask_path, causal=False, cp_size=cp_size, cp_rank=cp_rank, hybrid_mask_fn=partial(hybrid_swa, window_size=512, swa_ratio=0.75)), B=B, S=SQ, SKV=SKV, H=H, HKV=HKV, D=D, dtype=dtype), + # "Dumped Mask": lambda: test_mask(generate_mask_fn=partial(load_mask, path=mask_path, causal=False, cp_size=cp_size, cp_rank=cp_rank), B=B, S=SQ, SKV=SKV, H=H, HKV=HKV, D=D, dtype=dtype, DV=DV), + # "Hybrid SWA": lambda: test_mask(generate_mask_fn=partial(load_mask, path=mask_path, causal=False, cp_size=cp_size, cp_rank=cp_rank, hybrid_mask_fn=partial(hybrid_swa, window_size=512, swa_ratio=0.75)), B=B, S=SQ, SKV=SKV, H=H, HKV=HKV, D=D, dtype=dtype, DV=DV), # Varlen benchmarks (fwd+bwd combined, use_varlen=True) - "Varlen Causal Document Mask": lambda: test_mask_varlen(generate_mask_fn=partial(generate_causal_document_mask, doc_seq_lens=doc_seq_lens), B=B, S=SQ, SKV=SKV, H=H, HKV=HKV, D=D, dtype=dtype), - "Varlen Document Mask": lambda: test_mask_varlen(generate_mask_fn=partial(generate_document_mask, doc_seq_lens=doc_seq_lens), B=B, S=SQ, SKV=SKV, H=H, HKV=HKV, D=D, dtype=dtype), + "Varlen Causal Document Mask": lambda: test_mask_varlen(generate_mask_fn=partial(generate_causal_document_mask, doc_seq_lens=doc_seq_lens), B=B, S=SQ, SKV=SKV, H=H, HKV=HKV, D=D, dtype=dtype, DV=DV), + "Varlen Document Mask": lambda: test_mask_varlen(generate_mask_fn=partial(generate_document_mask, doc_seq_lens=doc_seq_lens), B=B, S=SQ, SKV=SKV, H=H, HKV=HKV, D=D, dtype=dtype, DV=DV), } if "all" in examples: @@ -905,8 +913,8 @@ def main(examples: List[str] = ["all"], dtype='bf16', fm_version=1, suffix="_bas content2=tabulate(results, headers=headers, tablefmt="tsv") os.makedirs(f"{dtype}{suffix}", exist_ok=True) # Note(umiswing): this file name is better, but i need to keep the old name for fig plotting - # text_file = open(f"{dtype}{suffix}/flashmaskv{fm_version}_{current_time}_{B}_{SQ}_{SKV}_{H}_{HKV}_{D}_{idx}.csv","w") - text_file = open(f"{dtype}{suffix}/flashmaskv{fm_version}_{current_time}_{B}_{S}_{H}_{HKV}_{D}_{idx}.csv","w") + # text_file = open(f"{dtype}{suffix}/flashmaskv{fm_version}_{current_time}_{B}_{SQ}_{SKV}_{H}_{HKV}_{D}_{DV}_{idx}.csv","w") + text_file = open(f"{dtype}{suffix}/flashmaskv{fm_version}_{current_time}_{B}_{S}_{H}_{HKV}_{D}_{DV}_{idx}.csv","w") text_file.write(content2) text_file.close()