Skip to content
124 changes: 100 additions & 24 deletions benchmark_flashmask.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
from flash_mask.cute.interface import flashmask_attention
except (ImportError, ModuleNotFoundError):
from paddle.nn.functional.flash_attention import flashmask_attention

from flash_mask import flashmask_attention as flashmask_attention_interface
import random
import os
from datetime import datetime
Expand Down Expand Up @@ -117,7 +119,10 @@ def test_mask(
HKV,
D,
dtype = 'bf16',
DV = None,
):
if DV is None:
DV = D

if dtype == 'bf16':
data_type = paddle.bfloat16
Expand All @@ -126,8 +131,8 @@ def test_mask(

query = paddle.randn([B, S, H, D], dtype=data_type)
key = paddle.randn([B, SKV, HKV, D], dtype=data_type)
value = paddle.randn([B, SKV, HKV, D], dtype=data_type)
gradOut = paddle.randn([B, S, H, D], dtype=data_type)
value = paddle.randn([B, SKV, HKV, DV], dtype=data_type)
gradOut = paddle.randn([B, S, H, DV], dtype=data_type)

query.stop_gradient = False
key.stop_gradient = False
Expand Down Expand Up @@ -207,6 +212,72 @@ def flashmask_bwd():

return fwd_time_ms, bwd_time_ms, total_time_ms, fwd_flops, bwd_flops, total_flops, fwd_tflops, bwd_tflops, total_tflops, sparsity

def test_mask_varlen(
    generate_mask_fn,
    B,
    S,
    SKV,
    H,
    HKV,
    D,
    dtype = 'bf16',
    DV = None,
):
    """Benchmark one combined forward+backward pass with ``use_varlen=True``.

    Random query/key/value tensors of shapes ``[B, S, H, D]``,
    ``[B, SKV, HKV, D]`` and ``[B, SKV, HKV, DV]`` are created, the mask is
    obtained from ``generate_mask_fn(B, SKV, HKV, D)`` (or a plain causal
    mask when ``generate_mask_fn`` is ``None``), and a closure that runs
    ``flashmask_attention_interface`` followed by ``backward`` is handed to
    ``do_bench``.

    Args:
        generate_mask_fn: Callable returning ``(startend_row_indices, causal)``
            or ``None`` for an unmasked causal run.
        B, S, SKV, H, HKV, D: Batch size, query length, key/value length,
            query heads, key/value heads and query/key head dimension.
        dtype: ``'bf16'`` selects ``paddle.bfloat16``; any other value selects
            ``paddle.float16``.
        DV: Value head dimension; falls back to ``D`` when ``None``.

    Returns:
        A 10-tuple laid out like the one returned by ``test_mask``:
        ``(fwd_time_ms, bwd_time_ms, total_time_ms, fwd_flops, bwd_flops,
        total_flops, fwd_tflops, bwd_tflops, total_tflops, sparsity)``.
        The forward-only and backward-only slots are NaN because only the
        combined pass is measured here.
    """
    DV = D if DV is None else DV
    data_type = paddle.bfloat16 if dtype == 'bf16' else paddle.float16

    # Creation order is kept fixed so the random stream is consumed identically.
    query = paddle.randn([B, S, H, D], dtype=data_type)
    key = paddle.randn([B, SKV, HKV, D], dtype=data_type)
    value = paddle.randn([B, SKV, HKV, DV], dtype=data_type)
    gradOut = paddle.randn([B, S, H, DV], dtype=data_type)

    for source in (query, key, value):
        source.stop_gradient = False

    if generate_mask_fn is None:
        startend_row_indices, causal = None, True
    else:
        startend_row_indices, causal = generate_mask_fn(B, SKV, HKV, D)

    sparsity = flashmask_block_sparsity(causal, startend_row_indices, B, H, HKV, S, SKV)

    def fwd_bwd():
        # Fresh leaf tensors on every call; the clones are part of what
        # do_bench executes.
        q, k, v = (source.detach().clone() for source in (query, key, value))
        for leaf in (q, k, v):
            leaf.stop_gradient = False
        out, _lse = flashmask_attention_interface(
            q, k, v,
            startend_row_indices=startend_row_indices,
            causal=causal,
            return_softmax_lse=True,
            use_varlen=True,
        )
        out.backward(gradOut, retain_graph=True)

    total_time_ms = do_bench(fwd_bwd)

    # FLOPs are scaled by the fraction of the mask that is not skipped.
    total_flops = (1.0 - sparsity) * cal_flops(B, H, S, SKV, D, mode='fwd_bwd')
    total_tflops = cal_tflops(total_flops, total_time_ms)

    # Forward-only / backward-only figures are not measured in this mode.
    not_measured = float('nan')
    return (
        not_measured,
        not_measured,
        total_time_ms,
        not_measured,
        not_measured,
        total_flops,
        not_measured,
        not_measured,
        total_tflops,
        sparsity,
    )

def flashmask_block_sparsity(
causal,
flashmask,
Expand Down Expand Up @@ -754,8 +825,9 @@ def main(examples: List[str] = ["all"], dtype='bf16', fm_version=1, suffix="_bas
qksparse_mask = eval(line.split(":")[-1].split("#")[1].strip())
doc_seq_lens_list.append((total_length, doc_list, qksparse_mask))
#doc_seq_lens_list = doc_seq_lens_list[::-1]
for D in [128] if fm_version == 4 else [64, 128, 256]:
H = 4096 // D
# (D, DV, H): standard configs + MLA (d=192, dv=128)
head_configs = [(64, 64, 64), (128, 128, 32), (192, 128, 16), (256, 256, 16)]
for D, DV, H in head_configs:
HKV = H
for idx, (S, prefix_doc_seq_lens, qksparse_mask) in enumerate(doc_seq_lens_list):
B = 128 * 1024 // S
Expand All @@ -765,10 +837,10 @@ def main(examples: List[str] = ["all"], dtype='bf16', fm_version=1, suffix="_bas
doc_seq_lens = [x[1] for x in prefix_doc_seq_lens]
maskout_pair = []
offset = 0
print(f"{B}_{S}_{H}_{HKV}_{D}_{idx}_{dtype}")
print(f"{B}_{S}_{H}_{HKV}_{D}_{DV}_{idx}_{dtype}")
if not overwrite:
if os.path.exists(f"{dtype}{suffix}/flashmaskv{fm_version}_{B}_{S}_{H}_{D}_{idx}.csv"):
print(f"{dtype}{suffix}/flashmaskv{fm_version}_{B}_{S}_{H}_{D}_{idx}.csv already exists, skipping. To enable overwrite, use: --overwrite (True by default).")
if os.path.exists(f"{dtype}{suffix}/flashmaskv{fm_version}_{B}_{S}_{H}_{D}_{DV}_{idx}.csv"):
print(f"{dtype}{suffix}/flashmaskv{fm_version}_{B}_{S}_{H}_{D}_{DV}_{idx}.csv already exists, skipping. To enable overwrite, use: --overwrite (True by default).")
continue
if sum(qksparse_mask) == 0:
maskout_pair = [(1024, 538), (2358, 1700)]
Expand All @@ -780,23 +852,27 @@ def main(examples: List[str] = ["all"], dtype='bf16', fm_version=1, suffix="_bas

share_qa_docs = [split_sequence(doc_seq) for doc_seq in doc_seq_lens]
available_examples = {
"Full": lambda: test_mask(generate_mask_fn=partial(generate_none_mask, causal=False), B=B, S=SQ, SKV=SKV, H=H, HKV=HKV, D=D, dtype=dtype),
"Causal": lambda: test_mask(generate_mask_fn=partial(generate_none_mask, causal=True), B=B, S=SQ, SKV=SKV, H=H, HKV=HKV, D=D, dtype=dtype),
"Sliding Window": lambda: test_mask(generate_mask_fn=partial(generate_sliding_window_mask, window_size=int(S*0.0625)), B=B, S=SQ, SKV=SKV, H=H, HKV=HKV, D=D, dtype=dtype),
"Causal Document Mask": lambda: test_mask(generate_mask_fn=partial(generate_causal_document_mask, doc_seq_lens=doc_seq_lens), B=B, S=SQ, SKV=SKV, H=H, HKV=HKV, D=D, dtype=dtype),
"Document Mask": lambda: test_mask(generate_mask_fn=partial(generate_document_mask, doc_seq_lens=doc_seq_lens), B=B, S=SQ, SKV=SKV, H=H, HKV=HKV, D=D, dtype=dtype),
"Share Question Mask": lambda: test_mask(generate_mask_fn=partial(generate_share_question_mask, doc_seq_lens=share_qa_docs), B=B, S=SQ, SKV=SKV, H=H, HKV=HKV, D=D, dtype=dtype),
# "Global Sliding Window": lambda: test_mask(generate_mask_fn=partial(generate_global_sliding_window_mask, global_token=16, window_size=(int(S*0.0625), int(S*0.0625))), B=B, S=SQ, SKV=SKV, H=H, HKV=HKV, D=D, dtype=dtype),
"Causal Blockwise Mask": lambda: test_mask(generate_mask_fn=partial(generate_causal_blockwise_mask, doc_seq_lens=doc_seq_lens), B=B, S=SQ, SKV=SKV, H=H, HKV=HKV, D=D, dtype=dtype),
"Prefix LM Document Mask": lambda: test_mask(generate_mask_fn=partial(generate_prefix_lm_document_mask, doc_seq_lens=prefix_doc_seq_lens), B=B, S=SQ, SKV=SKV, H=H, HKV=HKV, D=D, dtype=dtype),
"Prefix LM Causal Mask": lambda: test_mask(generate_mask_fn=partial(generate_prefix_lm_causal_mask, prefix_length=int(S*0.5)), B=B, S=SQ, SKV=SKV, H=H, HKV=HKV, D=D, dtype=dtype),
"QK-sparse Mask": lambda: test_mask(generate_mask_fn=partial(generate_qk_sparse_mask, maskout_pair=maskout_pair), B=B, S=SQ, SKV=SKV, H=H, HKV=HKV, D=D, dtype=dtype),
"Random Eviction Mask": lambda: test_mask(generate_mask_fn=partial(generate_random_eviction_mask, start_row=S//2), B=B, S=SQ, SKV=SKV, H=H, HKV=HKV, D=D, dtype=dtype),
# "Hybrid SWA Prefix LM Doc": lambda: test_mask(generate_mask_fn=partial(generate_hybrid_swa_prefix_lm_document_mask, doc_seq_lens=prefix_doc_seq_lens), B=B, S=SQ, SKV=SKV, H=H, HKV=HKV, D=D, dtype=dtype),
"Full": lambda: test_mask(generate_mask_fn=partial(generate_none_mask, causal=False), B=B, S=SQ, SKV=SKV, H=H, HKV=HKV, D=D, dtype=dtype, DV=DV),
"Causal": lambda: test_mask(generate_mask_fn=partial(generate_none_mask, causal=True), B=B, S=SQ, SKV=SKV, H=H, HKV=HKV, D=D, dtype=dtype, DV=DV),
"Sliding Window": lambda: test_mask(generate_mask_fn=partial(generate_sliding_window_mask, window_size=int(S*0.0625)), B=B, S=SQ, SKV=SKV, H=H, HKV=HKV, D=D, dtype=dtype, DV=DV),
"Causal Document Mask": lambda: test_mask(generate_mask_fn=partial(generate_causal_document_mask, doc_seq_lens=doc_seq_lens), B=B, S=SQ, SKV=SKV, H=H, HKV=HKV, D=D, dtype=dtype, DV=DV),
"Document Mask": lambda: test_mask(generate_mask_fn=partial(generate_document_mask, doc_seq_lens=doc_seq_lens), B=B, S=SQ, SKV=SKV, H=H, HKV=HKV, D=D, dtype=dtype, DV=DV),
"Share Question Mask": lambda: test_mask(generate_mask_fn=partial(generate_share_question_mask, doc_seq_lens=share_qa_docs), B=B, S=SQ, SKV=SKV, H=H, HKV=HKV, D=D, dtype=dtype, DV=DV),
# "Global Sliding Window": lambda: test_mask(generate_mask_fn=partial(generate_global_sliding_window_mask, global_token=16, window_size=(int(S*0.0625), int(S*0.0625))), B=B, S=SQ, SKV=SKV, H=H, HKV=HKV, D=D, dtype=dtype, DV=DV),
"Causal Blockwise Mask": lambda: test_mask(generate_mask_fn=partial(generate_causal_blockwise_mask, doc_seq_lens=doc_seq_lens), B=B, S=SQ, SKV=SKV, H=H, HKV=HKV, D=D, dtype=dtype, DV=DV),
"Prefix LM Document Mask": lambda: test_mask(generate_mask_fn=partial(generate_prefix_lm_document_mask, doc_seq_lens=prefix_doc_seq_lens), B=B, S=SQ, SKV=SKV, H=H, HKV=HKV, D=D, dtype=dtype, DV=DV),
"Prefix LM Causal Mask": lambda: test_mask(generate_mask_fn=partial(generate_prefix_lm_causal_mask, prefix_length=int(S*0.5)), B=B, S=SQ, SKV=SKV, H=H, HKV=HKV, D=D, dtype=dtype, DV=DV),
"QK-sparse Mask": lambda: test_mask(generate_mask_fn=partial(generate_qk_sparse_mask, maskout_pair=maskout_pair), B=B, S=SQ, SKV=SKV, H=H, HKV=HKV, D=D, dtype=dtype, DV=DV),
"Random Eviction Mask": lambda: test_mask(generate_mask_fn=partial(generate_random_eviction_mask, start_row=S//2), B=B, S=SQ, SKV=SKV, H=H, HKV=HKV, D=D, dtype=dtype, DV=DV),
# "Hybrid SWA Prefix LM Doc": lambda: test_mask(generate_mask_fn=partial(generate_hybrid_swa_prefix_lm_document_mask, doc_seq_lens=prefix_doc_seq_lens), B=B, S=SQ, SKV=SKV, H=H, HKV=HKV, D=D, dtype=dtype, DV=DV),

# Note(umiswing): support load mask and hybrid mask like this, and also, support simulate cp benchmark
# "Dumped Mask": lambda: test_mask(generate_mask_fn=partial(load_mask, path=mask_path, causal=False, cp_size=cp_size, cp_rank=cp_rank), B=B, S=SQ, SKV=SKV, H=H, HKV=HKV, D=D, dtype=dtype),
# "Hybrid SWA": lambda: test_mask(generate_mask_fn=partial(load_mask, path=mask_path, causal=False, cp_size=cp_size, cp_rank=cp_rank, hybrid_mask_fn=partial(hybrid_swa, window_size=512, swa_ratio=0.75)), B=B, S=SQ, SKV=SKV, H=H, HKV=HKV, D=D, dtype=dtype),
# "Dumped Mask": lambda: test_mask(generate_mask_fn=partial(load_mask, path=mask_path, causal=False, cp_size=cp_size, cp_rank=cp_rank), B=B, S=SQ, SKV=SKV, H=H, HKV=HKV, D=D, dtype=dtype, DV=DV),
# "Hybrid SWA": lambda: test_mask(generate_mask_fn=partial(load_mask, path=mask_path, causal=False, cp_size=cp_size, cp_rank=cp_rank, hybrid_mask_fn=partial(hybrid_swa, window_size=512, swa_ratio=0.75)), B=B, S=SQ, SKV=SKV, H=H, HKV=HKV, D=D, dtype=dtype, DV=DV),

# Varlen benchmarks (fwd+bwd combined, use_varlen=True)
"Varlen Causal Document Mask": lambda: test_mask_varlen(generate_mask_fn=partial(generate_causal_document_mask, doc_seq_lens=doc_seq_lens), B=B, S=SQ, SKV=SKV, H=H, HKV=HKV, D=D, dtype=dtype, DV=DV),
"Varlen Document Mask": lambda: test_mask_varlen(generate_mask_fn=partial(generate_document_mask, doc_seq_lens=doc_seq_lens), B=B, S=SQ, SKV=SKV, H=H, HKV=HKV, D=D, dtype=dtype, DV=DV),
}

if "all" in examples:
Expand Down Expand Up @@ -837,8 +913,8 @@ def main(examples: List[str] = ["all"], dtype='bf16', fm_version=1, suffix="_bas
content2=tabulate(results, headers=headers, tablefmt="tsv")
os.makedirs(f"{dtype}{suffix}", exist_ok=True)
# Note(umiswing): this file name is better, but i need to keep the old name for fig plotting
# text_file = open(f"{dtype}{suffix}/flashmaskv{fm_version}_{current_time}_{B}_{SQ}_{SKV}_{H}_{HKV}_{D}_{idx}.csv","w")
text_file = open(f"{dtype}{suffix}/flashmaskv{fm_version}_{current_time}_{B}_{S}_{H}_{HKV}_{D}_{idx}.csv","w")
# text_file = open(f"{dtype}{suffix}/flashmaskv{fm_version}_{current_time}_{B}_{SQ}_{SKV}_{H}_{HKV}_{D}_{DV}_{idx}.csv","w")
text_file = open(f"{dtype}{suffix}/flashmaskv{fm_version}_{current_time}_{B}_{S}_{H}_{HKV}_{D}_{DV}_{idx}.csv","w")
text_file.write(content2)
text_file.close()

Expand Down
137 changes: 136 additions & 1 deletion generate_startend_row_indices.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,10 +76,11 @@ def generate_causal_document_mask(batch_size, seqlen_q, seqlen_k, h, doc_seqlens
assert total_seqlen <= seqlen_k
assert len(doc_seqlens) >= 3
padding = seqlen_k - np.sum(doc_seqlens)
doc_seqlens[-1] += padding
seq_cusums = np.cumsum(doc_seqlens)

startend_row_indices = np.repeat(seq_cusums, doc_seqlens)
padding_mask = np.repeat(seq_cusums[-1], padding)
startend_row_indices = np.concatenate([startend_row_indices, padding_mask])
startend_row_indices = paddle.to_tensor(startend_row_indices, dtype=paddle.int32).reshape((1, 1, seqlen_k, 1)).repeat_interleave(batch_size, 0)
startend_row_indices = paddle.clip(startend_row_indices, max=seqlen_q)

Expand Down Expand Up @@ -344,3 +345,137 @@ def generate_random_eviction_mask(batch_size, seqlen_q, seqlen_k, h, start_row=N
startend_row_indices = paddle.clip(startend_row_indices, max=seqlen_q)
causal = True
return startend_row_indices, causal


def _scale_doc_seqlens(doc_seqlens, seqlen_k):
"""Scale doc_seqlens from base 8192 to target seqlen_k."""
if seqlen_k != 8192:
doc_seqlens = [int(d * (seqlen_k / 8192)) for d in doc_seqlens]
return doc_seqlens


# Pre-defined document-length distributions for per-batch variation.
# Each row lists three document lengths that sum to 7493, i.e. less than the
# base sequence length of 8192 that _scale_doc_seqlens rescales from; the
# generators below cycle through the rows with `i % len(...)` per batch item.
_DIFF_BATCH_DOC_SEQLENS = [
[2538, 1742, 3213],
[1500, 3500, 2493],
[3000, 1000, 3493],
[800, 2200, 4493],
]


def generate_causal_document_mask_diff_batch(batch_size, seqlen_q, seqlen_k, h, doc_seqlens_list=None):
    """Build a causal document mask whose document boundaries differ per batch item.

    Args:
        batch_size: Number of batch items; one entry of ``doc_seqlens_list``
            is consumed per item.
        seqlen_q: Upper bound used to clip the resulting row indices.
        seqlen_k: Number of key positions described per batch item.
        h: Unused by this function.
        doc_seqlens_list: Per-item document lengths. When ``None``, the
            presets in ``_DIFF_BATCH_DOC_SEQLENS`` are cycled and rescaled
            to ``seqlen_k``.

    Returns:
        ``(startend_row_indices, causal)`` where ``startend_row_indices`` is
        an int32 tensor of shape ``(batch_size, 1, seqlen_k, 1)`` and
        ``causal`` is ``True``.
    """
    if doc_seqlens_list is None:
        preset_count = len(_DIFF_BATCH_DOC_SEQLENS)
        doc_seqlens_list = [
            _scale_doc_seqlens(_DIFF_BATCH_DOC_SEQLENS[item % preset_count], seqlen_k)
            for item in range(batch_size)
        ]
        if seqlen_k != 8192:
            print(f"{seqlen_k=}, auto setting per-batch doc_seqlens to {doc_seqlens_list}")

    per_item_rows = []
    for item in range(batch_size):
        lengths = list(doc_seqlens_list[item])
        covered = np.sum(lengths)
        assert covered <= seqlen_k
        assert len(lengths) >= 3

        doc_ends = np.cumsum(lengths)
        # Every key position carries the end offset of the document it is in.
        in_document = np.repeat(doc_ends, lengths)
        # Positions past the last document reuse the last document's end.
        trailing = np.repeat(doc_ends[-1], seqlen_k - covered)
        per_item_rows.append(np.concatenate([in_document, trailing]))

    stacked = np.stack(per_item_rows, axis=0)  # (batch_size, seqlen_k)
    startend_row_indices = paddle.to_tensor(stacked, dtype=paddle.int32)
    startend_row_indices = startend_row_indices.reshape((batch_size, 1, seqlen_k, 1))
    startend_row_indices = paddle.clip(startend_row_indices, max=seqlen_q)

    return startend_row_indices, True


def generate_document_mask_diff_batch(batch_size, seqlen_q, seqlen_k, h, doc_seqlens_list=None):
    """Build a non-causal document mask whose boundaries differ per batch item.

    Args:
        batch_size: Number of batch items; one entry of ``doc_seqlens_list``
            is consumed per item.
        seqlen_q: Upper bound used to clip the resulting row indices.
        seqlen_k: Number of key positions described per batch item.
        h: Unused by this function.
        doc_seqlens_list: Per-item document lengths. When ``None``, the
            presets in ``_DIFF_BATCH_DOC_SEQLENS`` are cycled and rescaled
            to ``seqlen_k``.

    Returns:
        ``(startend_row_indices, causal)`` where ``startend_row_indices`` is
        an int32 tensor of shape ``(batch_size, 1, seqlen_k, 2)`` holding the
        down-left values in channel 0 and the up-right values in channel 1,
        and ``causal`` is ``False``.
    """
    if doc_seqlens_list is None:
        preset_count = len(_DIFF_BATCH_DOC_SEQLENS)
        doc_seqlens_list = [
            _scale_doc_seqlens(_DIFF_BATCH_DOC_SEQLENS[item % preset_count], seqlen_k)
            for item in range(batch_size)
        ]
        if seqlen_k != 8192:
            print(f"{seqlen_k=}, auto setting per-batch doc_seqlens to {doc_seqlens_list}")

    batch_down_left, batch_up_right = [], []
    for item in range(batch_size):
        lengths = list(doc_seqlens_list[item])
        covered = np.sum(lengths)
        assert covered <= seqlen_k
        assert len(lengths) >= 3
        padding = seqlen_k - covered

        down_left_row_indices = []
        up_right_row_indices = []
        doc_begin = 0
        last_doc_begin = 0
        for length in lengths:
            doc_finish = doc_begin + length
            # Down-left uses the document's end offset, up-right its start.
            down_left_row_indices.extend([doc_finish] * length)
            up_right_row_indices.extend([doc_begin] * length)
            last_doc_begin = doc_begin
            doc_begin = doc_finish

        if padding > 0:
            # Padding positions reuse the last document's end offset for
            # down-left and the last document's start offset for up-right.
            down_left_row_indices.extend([doc_begin] * padding)
            up_right_row_indices.extend([last_doc_begin] * padding)

        batch_down_left.append(down_left_row_indices)
        batch_up_right.append(up_right_row_indices)

    target_shape = (batch_size, 1, seqlen_k, 1)
    down_left = paddle.to_tensor(np.array(batch_down_left), dtype=paddle.int32).reshape(target_shape)
    up_right = paddle.to_tensor(np.array(batch_up_right), dtype=paddle.int32).reshape(target_shape)

    startend_row_indices = paddle.concat([down_left, up_right], axis=-1)
    startend_row_indices = paddle.clip(startend_row_indices, max=seqlen_q)

    return startend_row_indices, False

def generate_document_mask_simu(batch_size, seqlen_q, seqlen_k, h, doc_seqlens=None):
    """Derive a two-channel, non-causal mask from the causal document mask.

    The first channel is the tensor returned by
    ``generate_causal_document_mask``; the second channel is, per key
    position, the smaller of that position's index and the first channel's
    value. Requires ``seqlen_q == seqlen_k``.

    Returns:
        ``(startend_row_indices, causal)`` with both channels concatenated on
        the last axis and ``causal`` set to ``False``.
    """
    assert seqlen_q == seqlen_k
    lts, _ = generate_causal_document_mask(batch_size, seqlen_q, seqlen_k, h, doc_seqlens)

    num_batch, num_heads, seq_len, _ = lts.shape
    positions = paddle.arange(0, seq_len, 1, dtype="int32").reshape((1, 1, seq_len, 1))
    ute = positions.repeat_interleave(num_batch, 0).repeat_interleave(num_heads, 1)

    # Cap each position index at the corresponding first-channel value.
    ute = paddle.where(ute <= lts, ute, lts)

    return paddle.concat([lts, ute], axis=-1), False

def generate_document_mask_diff_batch_simu(batch_size, seqlen_q, seqlen_k, h, doc_seqlens_list=None):
    """Derive a two-channel, non-causal mask from the per-batch causal document mask.

    The first channel is the tensor returned by
    ``generate_causal_document_mask_diff_batch``; the second channel is, per
    key position, the smaller of that position's index and the first
    channel's value. Requires ``seqlen_q == seqlen_k``.

    Returns:
        ``(startend_row_indices, causal)`` with both channels concatenated on
        the last axis and ``causal`` set to ``False``.
    """
    assert seqlen_q == seqlen_k
    lts, _ = generate_causal_document_mask_diff_batch(
        batch_size, seqlen_q, seqlen_k, h, doc_seqlens_list
    )

    num_batch, num_heads, seq_len, _ = lts.shape
    positions = paddle.arange(0, seq_len, 1, dtype="int32").reshape((1, 1, seq_len, 1))
    ute = positions.repeat_interleave(num_batch, 0).repeat_interleave(num_heads, 1)

    # Cap each position index at the corresponding first-channel value.
    ute = paddle.where(ute <= lts, ute, lts)

    return paddle.concat([lts, ute], axis=-1), False

1 change: 1 addition & 0 deletions run_varlen.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# Run test_flashmask_to_varlen.py under pytest in verbose mode, using the
# Python interpreter at ../wsm_varlen_env/bin (presumably a virtualenv that
# sits next to this checkout; the relative path depends on the working dir).
../wsm_varlen_env/bin/python -m pytest -v test_flashmask_to_varlen.py
Loading