diff --git a/flashmask_attention_torch.py b/flashmask_attention_torch.py new file mode 100644 index 0000000..c7ca23c --- /dev/null +++ b/flashmask_attention_torch.py @@ -0,0 +1,541 @@ +# Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +FlashMask Attention - PyTorch Implementation + +This module provides a PyTorch implementation of FlashMask attention algorithm. +The core equation is: + + result = softmax(Q @ K^T / sqrt(d) + M) @ V + +where M is the column-wise sparse mask introduced by FlashMask. + +Note: This is a reference implementation using standard PyTorch operations. +For optimal performance on NVIDIA GPUs, consider using the CUDA-optimized version. +""" + +import math +from typing import Optional, Tuple, Union + +import torch +import torch.nn.functional as F +from torch import Tensor + + +def flashmask_to_dense_mask( + startend_row_indices: Tensor, + seqlen_q: int, + seqlen_k: int, + dtype: torch.dtype, + causal: bool = True, +) -> Tensor: + """ + Convert FlashMask's startend_row_indices to dense attention mask. + + Args: + startend_row_indices: Column-wise sparse attention mask row indices tensor. 
+ Shape: [batch_size, num_heads, seqlen_k, {1, 2, 4}] + seqlen_q: Query sequence length + seqlen_k: Key sequence length + dtype: Data type for the mask (should match attention scores dtype) + causal: Whether to use causal mask mode + + Returns: + Dense attention mask of shape [batch_size, num_heads, seqlen_q, seqlen_k] + """ + if startend_row_indices is None: + return None + + bz, num_head, seq_len_k, bound_num = startend_row_indices.shape + + # Create mask tensor initialized to 0 (will be filled with -inf for masked positions) + mask = torch.zeros((bz, num_head, seqlen_q, seqlen_k), dtype=dtype, device=startend_row_indices.device) + + has_end = (causal and bound_num == 2) or ((not causal) and bound_num == 4) + + # Create index tensors for vectorized operations + row_indices = torch.arange(seqlen_q, device=startend_row_indices.device, dtype=torch.int32) + col_indices = torch.arange(seqlen_k, device=startend_row_indices.device, dtype=torch.int32) + + for bi in range(bz): + for hi in range(num_head): + for j in range(seq_len_k): + # Lower triangular start + downstart = startend_row_indices[bi, hi, j, 0].item() + + if has_end: + downend = startend_row_indices[bi, hi, j, 1].item() + # Mask lower triangular region [downstart:downend, j] + if downstart < downend: + start_row = max(0, downstart) + end_row = min(seqlen_q, downend) + mask[bi, hi, start_row:end_row, j] = float('-inf') + + if causal: + # For causal mask, also mask upper triangle (future positions) + mask[bi, hi, :j, j] = float('-inf') + else: + # Upper triangular mask for bidirectional attention + upstart = startend_row_indices[bi, hi, j, 2].item() + upend = startend_row_indices[bi, hi, j, 3].item() + if upstart < upend: + start_row = max(0, upstart) + end_row = min(seqlen_q, upend) + mask[bi, hi, start_row:end_row, j] = float('-inf') + else: + # Mask from downstart to end + if downstart < seqlen_q: + mask[bi, hi, downstart:, j] = float('-inf') + + if causal: + mask[bi, hi, :j, j] = float('-inf') + else: + 
upend = startend_row_indices[bi, hi, j, 1].item() + if upend > 0: + mask[bi, hi, :upend, j] = float('-inf') + + return mask + + +def flashmask_to_dense_mask_vectorized( + startend_row_indices: Tensor, + seqlen_q: int, + seqlen_k: int, + dtype: torch.dtype, + causal: bool = True, +) -> Tensor: + """ + Vectorized version of flashmask_to_dense_mask for better performance. + + Args: + startend_row_indices: Column-wise sparse attention mask row indices tensor. + Shape: [batch_size, num_heads, seqlen_k, {1, 2, 4}] + seqlen_q: Query sequence length + seqlen_k: Key sequence length + dtype: Data type for the mask + causal: Whether to use causal mask mode + + Returns: + Dense attention mask of shape [batch_size, num_heads, seqlen_q, seqlen_k] + """ + if startend_row_indices is None: + return None + + bz, num_head, seq_len_k, bound_num = startend_row_indices.shape + device = startend_row_indices.device + + # Create row and column index tensors + # row_idx: [seqlen_q, 1], col_idx: [1, seqlen_k] + row_idx = torch.arange(seqlen_q, device=device, dtype=torch.int32).unsqueeze(1) + col_idx = torch.arange(seqlen_k, device=device, dtype=torch.int32).unsqueeze(0) + + # Initialize mask to zeros + mask = torch.zeros((bz, num_head, seqlen_q, seqlen_k), dtype=dtype, device=device) + + has_end = (causal and bound_num == 2) or ((not causal) and bound_num == 4) + + if causal: + # Causal mask: mask positions where row < col (upper triangle) + causal_mask = (row_idx < col_idx).unsqueeze(0).unsqueeze(0) # [1, 1, seqlen_q, seqlen_k] + mask.masked_fill_(causal_mask.expand(bz, num_head, -1, -1), float('-inf')) + + if bound_num == 1: + # Lower triangular start only + # startend_row_indices[..., 0] gives the start row for each column + # Mask rows >= downstart for each column + downstart = startend_row_indices[..., 0] # [bz, num_head, seq_len_k] + # Expand for broadcasting: [bz, num_head, 1, seq_len_k] + downstart_expanded = downstart.unsqueeze(2) + # Mask where row >= downstart + lower_mask = 
(row_idx.unsqueeze(0).unsqueeze(0) >= downstart_expanded) + mask.masked_fill_(lower_mask, float('-inf')) + + elif bound_num == 2: + # Lower triangular with start and end + downstart = startend_row_indices[..., 0] # [bz, num_head, seq_len_k] + downend = startend_row_indices[..., 1] # [bz, num_head, seq_len_k] + + downstart_expanded = downstart.unsqueeze(2) + downend_expanded = downend.unsqueeze(2) + + # Mask where downstart <= row < downend + lower_mask = (row_idx.unsqueeze(0).unsqueeze(0) >= downstart_expanded) & \ + (row_idx.unsqueeze(0).unsqueeze(0) < downend_expanded) + mask.masked_fill_(lower_mask, float('-inf')) + else: + # Bidirectional attention + if bound_num == 2: + # Lower triangular start + Upper triangular end + downstart = startend_row_indices[..., 0] # [bz, num_head, seq_len_k] + upend = startend_row_indices[..., 1] # [bz, num_head, seq_len_k] + + downstart_expanded = downstart.unsqueeze(2) + upend_expanded = upend.unsqueeze(2) + + # Lower mask: row >= downstart + lower_mask = (row_idx.unsqueeze(0).unsqueeze(0) >= downstart_expanded) + mask.masked_fill_(lower_mask, float('-inf')) + + # Upper mask: row < upend + upper_mask = (row_idx.unsqueeze(0).unsqueeze(0) < upend_expanded) + mask.masked_fill_(upper_mask, float('-inf')) + + elif bound_num == 4: + # Full bidirectional with start and end for both + downstart = startend_row_indices[..., 0] + downend = startend_row_indices[..., 1] + upstart = startend_row_indices[..., 2] + upend = startend_row_indices[..., 3] + + downstart_expanded = downstart.unsqueeze(2) + downend_expanded = downend.unsqueeze(2) + upstart_expanded = upstart.unsqueeze(2) + upend_expanded = upend.unsqueeze(2) + + # Lower mask: downstart <= row < downend + lower_mask = (row_idx.unsqueeze(0).unsqueeze(0) >= downstart_expanded) & \ + (row_idx.unsqueeze(0).unsqueeze(0) < downend_expanded) + mask.masked_fill_(lower_mask, float('-inf')) + + # Upper mask: upstart <= row < upend + upper_mask = (row_idx.unsqueeze(0).unsqueeze(0) >= 
upstart_expanded) & \ + (row_idx.unsqueeze(0).unsqueeze(0) < upend_expanded) + mask.masked_fill_(upper_mask, float('-inf')) + + return mask + + +def flashmask_attention( + query: Tensor, + key: Tensor, + value: Tensor, + startend_row_indices: Optional[Tensor] = None, + *, + dropout: float = 0.0, + causal: bool = False, + window_size: Optional[Union[int, tuple]] = None, + return_softmax_lse: bool = False, + return_seed_offset: bool = False, + fixed_seed_offset: Optional[Tensor] = None, + rng_name: str = "", + training: bool = True, + name: Optional[str] = None, + softmax_scale: Optional[float] = None, + block_mask: Optional[Tensor] = None, +) -> Union[Tensor, Tuple[Tensor, Tensor]]: + """ + FlashMask: PyTorch Implementation + + This module provides the PyTorch implementation of the FlashMask algorithm. + The core equation is: + + result = softmax(Q @ K^T / sqrt(d) + M) @ V + + where M is the column-wise sparse mask introduced by FlashMask. + + Args: + query: The query tensor with shape [batch_size, q_seq_len, num_heads, head_dim]. + dtype can be float16 or bfloat16. + key: The key tensor with shape [batch_size, k_seq_len, k_num_heads, head_dim]. + dtype can be float16 or bfloat16. + value: The value tensor with shape [batch_size, k_seq_len, k_num_heads, head_dim]. + dtype can be float16 or bfloat16. + startend_row_indices: Column-wise sparse attention mask row indices tensor. + Shape: [batch_size, k_num_heads, k_seq_len, {1, 2, 4}]. dtype must be int32. + + - When `causal=True` and shape is [..., 1]: The value represents the starting + row index of the lower triangular mask. + - When `causal=True` and shape is [..., 2]: Values represent start and end + row indices of the lower triangular mask. + - When `causal=False` and shape is [..., 2]: Values represent lower triangular + start and upper triangular end. + - When `causal=False` and shape is [..., 4]: Values represent start/end for + both lower and upper triangular masks. + + dropout: Dropout ratio. Default is 0.0. 
+ causal: Whether to enable causal mode. Default is False. + window_size: Sliding window size for local attention. Default is None. + return_softmax_lse: Whether to return log-sum-exp of softmax. Default is False. + return_seed_offset: Whether to return random seed offset. Default is False. + fixed_seed_offset: Fixed seed offset for dropout. Default is None. + rng_name: Random number generator name. Default is "". + training: Whether in training mode. Default is True. + name: Operation name. Default is None. + softmax_scale: Softmax scaling factor. Default is None (uses 1/sqrt(head_dim)). + block_mask: Block-level mask tensor. Currently not supported in this implementation. + + Returns: + Output tensor with shape [batch_size, q_seq_len, num_heads, head_dim]. + If return_softmax_lse is True, also returns the log-sum-exp tensor. + + Examples: + >>> import torch + >>> batch_size, seqlen, num_heads, head_dim = 1, 10, 2, 32 + >>> q = torch.rand(batch_size, seqlen, num_heads, head_dim, dtype=torch.bfloat16) + >>> k = torch.rand(batch_size, seqlen, num_heads, head_dim, dtype=torch.bfloat16) + >>> v = torch.rand(batch_size, seqlen, num_heads, head_dim, dtype=torch.bfloat16) + >>> startend_row_indices = torch.tensor([8]*10, dtype=torch.int32).reshape(1, 1, 10, 1) + >>> output = flashmask_attention(q, k, v, startend_row_indices, causal=True) + """ + # Input validation + assert query.dtype in [torch.float16, torch.bfloat16], \ + f"query dtype must be float16 or bfloat16, got {query.dtype}" + assert query.dtype == key.dtype == value.dtype, \ + "query, key, value must have the same dtype" + + batch_size, seqlen_q, num_heads, head_dim = query.shape + _, seqlen_k, num_heads_kv, _ = key.shape + + # Handle GQA (Grouped Query Attention) + if num_heads != num_heads_kv: + assert num_heads % num_heads_kv == 0, \ + f"num_heads ({num_heads}) must be divisible by num_heads_kv ({num_heads_kv})" + # Repeat key and value to match query heads + num_groups = num_heads // num_heads_kv + key = 
key.repeat_interleave(num_groups, dim=2) + value = value.repeat_interleave(num_groups, dim=2) + + # Handle window_size + if window_size is not None: + if isinstance(window_size, int): + window_size = (window_size, window_size) + assert startend_row_indices is None, \ + "Cannot use window_size with startend_row_indices" + # Generate sliding window mask + if causal: + startend_row_indices = torch.arange( + window_size[0] + 1, seqlen_q + window_size[0] + 1, + dtype=torch.int32, device=query.device + ).reshape(1, 1, seqlen_q, 1) + startend_row_indices = torch.clamp(startend_row_indices, max=seqlen_q) + startend_row_indices = startend_row_indices.repeat(batch_size, num_heads_kv, 1, 1) + else: + startend_row_indices = torch.empty((1, 1, seqlen_q, 2), dtype=torch.int32, device=query.device) + startend_row_indices[0, 0, :, 0] = torch.arange( + window_size[0] + 1, seqlen_q + window_size[0] + 1, dtype=torch.int32, device=query.device + ) + startend_row_indices[0, 0, :, 1] = torch.arange( + -window_size[1], seqlen_q - window_size[1], dtype=torch.int32, device=query.device + ) + startend_row_indices = torch.clamp(startend_row_indices, min=0, max=seqlen_q) + startend_row_indices = startend_row_indices.repeat(batch_size, num_heads_kv, 1, 1) + + # Validate startend_row_indices + if startend_row_indices is not None: + assert startend_row_indices.dtype == torch.int32, \ + f"startend_row_indices dtype must be int32, got {startend_row_indices.dtype}" + assert len(startend_row_indices.shape) == 4, \ + f"startend_row_indices must be 4D, got {startend_row_indices.shape}" + + assert startend_row_indices.shape[0] == batch_size, \ + f"startend_row_indices batch size mismatch" + assert startend_row_indices.shape[2] == seqlen_k, \ + f"startend_row_indices seq_len mismatch" + + # Handle head broadcasting + if startend_row_indices.shape[1] == 1: + startend_row_indices = startend_row_indices.expand(-1, num_heads, -1, -1) + elif startend_row_indices.shape[1] != num_heads: + if 
startend_row_indices.shape[1] == num_heads_kv: + startend_row_indices = startend_row_indices.repeat_interleave( + num_heads // num_heads_kv, dim=1 + ) + + # Set softmax scale + if softmax_scale is None: + softmax_scale = 1.0 / math.sqrt(head_dim) + + # Convert to attention format [batch, heads, seq, dim] for scaled_dot_product_attention + query_t = query.transpose(1, 2) # [batch, num_heads, seqlen_q, head_dim] + key_t = key.transpose(1, 2) # [batch, num_heads, seqlen_k, head_dim] + value_t = value.transpose(1, 2) # [batch, num_heads, seqlen_k, head_dim] + + # Generate attention mask + attn_mask = None + if startend_row_indices is not None: + attn_mask = flashmask_to_dense_mask_vectorized( + startend_row_indices, + seqlen_q, + seqlen_k, + dtype=query.dtype, + causal=causal, + ) + elif causal: + # Standard causal mask + attn_mask = torch.triu( + torch.ones(seqlen_q, seqlen_k, dtype=query.dtype, device=query.device) * float('-inf'), + diagonal=1 + ).unsqueeze(0).unsqueeze(0) + + # Compute attention using PyTorch's scaled_dot_product_attention + # This will use Flash Attention v2 if available + if attn_mask is not None: + # scaled_dot_product_attention expects attn_mask where masked positions have -inf + attn_output = F.scaled_dot_product_attention( + query_t, key_t, value_t, + attn_mask=attn_mask, + dropout_p=dropout if training else 0.0, + scale=softmax_scale, + ) + else: + attn_output = F.scaled_dot_product_attention( + query_t, key_t, value_t, + dropout_p=dropout if training else 0.0, + is_causal=causal, + scale=softmax_scale, + ) + + # Convert back to [batch, seq, heads, dim] + output = attn_output.transpose(1, 2) + + if return_softmax_lse: + # Compute log-sum-exp manually + # scores = (query_t @ key_t.transpose(-2, -1)) * softmax_scale + # if attn_mask is not None: + # scores = scores + attn_mask + # lse = torch.logsumexp(scores, dim=-1) + # lse = lse.transpose(1, 2) # [batch, seq, heads] + + # For efficiency, compute LSE in half precision + with torch.no_grad(): 
+ scores = torch.matmul(query_t, key_t.transpose(-2, -1)) * softmax_scale + if attn_mask is not None: + scores = scores + attn_mask + lse = torch.logsumexp(scores.float(), dim=-1) + lse = lse.transpose(1, 2) # [batch, seq, heads] + + return output, lse + + if return_seed_offset: + # Not fully supported in PyTorch version + return output, None + + return output + + +class FlashMaskAttention(torch.nn.Module): + """ + FlashMask Attention Module for use in neural networks. + + This module wraps the flashmask_attention function for convenient use + in transformer architectures. + + Args: + causal: Whether to use causal attention. Default is False. + softmax_scale: Softmax scaling factor. Default is None (auto-computed). + dropout: Dropout probability. Default is 0.0. + + Examples: + >>> attn = FlashMaskAttention(causal=True) + >>> q = torch.rand(1, 10, 2, 32, dtype=torch.bfloat16) + >>> k = torch.rand(1, 10, 2, 32, dtype=torch.bfloat16) + >>> v = torch.rand(1, 10, 2, 32, dtype=torch.bfloat16) + >>> mask = torch.tensor([8]*10, dtype=torch.int32).reshape(1, 1, 10, 1) + >>> output = attn(q, k, v, mask) + """ + + def __init__( + self, + causal: bool = False, + softmax_scale: Optional[float] = None, + dropout: float = 0.0, + ): + super().__init__() + self.causal = causal + self.softmax_scale = softmax_scale + self.dropout = dropout + + def forward( + self, + query: Tensor, + key: Tensor, + value: Tensor, + startend_row_indices: Optional[Tensor] = None, + ) -> Tensor: + """ + Forward pass of FlashMask attention. 
+ + Args: + query: Query tensor [batch, seq_q, heads, dim] + key: Key tensor [batch, seq_k, heads_kv, dim] + value: Value tensor [batch, seq_k, heads_kv, dim] + startend_row_indices: Optional mask indices [batch, heads_kv, seq_k, bounds] + + Returns: + Output tensor [batch, seq_q, heads, dim] + """ + return flashmask_attention( + query, key, value, + startend_row_indices=startend_row_indices, + causal=self.causal, + softmax_scale=self.softmax_scale, + dropout=self.dropout, + training=self.training, + ) + + +# Utility functions for creating common mask patterns + +def create_causal_mask(seqlen: int, batch_size: int = 1, num_heads: int = 1) -> Tensor: + """Create a standard causal mask's startend_row_indices.""" + startend_row_indices = torch.full( + (batch_size, num_heads, seqlen, 1), + seqlen, dtype=torch.int32 + ) + return startend_row_indices + + +def create_sliding_window_mask( + seqlen: int, + window_size: int, + batch_size: int = 1, + num_heads: int = 1, +) -> Tensor: + """Create a sliding window mask's startend_row_indices for causal attention.""" + startend_row_indices = torch.arange( + window_size + 1, seqlen + window_size + 1, dtype=torch.int32 + ).reshape(1, 1, seqlen, 1) + startend_row_indices = torch.clamp(startend_row_indices, max=seqlen) + return startend_row_indices.repeat(batch_size, num_heads, 1, 1) + + +def create_document_mask( + seqlen: int, + doc_boundaries: list, + batch_size: int = 1, + num_heads: int = 1, + causal: bool = True, +) -> Tensor: + """ + Create a document mask's startend_row_indices. 
+ + Args: + seqlen: Total sequence length + doc_boundaries: List of document end positions (e.g., [4, 7, 10] for docs of lengths 4, 3, 3) + batch_size: Batch size + num_heads: Number of heads + causal: Whether to create causal document mask + + Returns: + startend_row_indices tensor + """ + startend = torch.zeros(seqlen, dtype=torch.int32) + doc_start = 0 + for doc_end in doc_boundaries: + startend[doc_start:doc_end] = doc_end + doc_start = doc_end + + startend_row_indices = startend.reshape(1, 1, seqlen, 1) + return startend_row_indices.repeat(batch_size, num_heads, 1, 1) diff --git a/generate_startend_row_indices_torch.py b/generate_startend_row_indices_torch.py new file mode 100644 index 0000000..0c751da --- /dev/null +++ b/generate_startend_row_indices_torch.py @@ -0,0 +1,347 @@ +import torch +import numpy as np + +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + +def startend_row_indices_to_attn_bias(startend_row_indices, seqlen_q, nheads, dtype, causal=True): + if startend_row_indices is None: + return None + bz, num_head, seqlen_k, bound_num = startend_row_indices.shape + assert nheads % num_head == 0 + m = torch.zeros((bz, num_head, seqlen_q, seqlen_k), dtype=dtype, device=device) + has_end = (causal and bound_num == 2) or ((not causal) and bound_num == 4) + for bi in range(bz): + for hi in range(num_head): + for j in range(seqlen_k): + downstart = startend_row_indices[bi, hi, j, 0] + if has_end: + downend = startend_row_indices[bi, hi, j, 1] + m[bi, hi, downstart:downend, j] = -np.inf + else: + m[bi, hi, downstart:, j] = -np.inf + if causal: + # from flash-attention 2.1 and in flash-attention 3, If seqlen_q != seqlen_k and causal=True, + # the causal mask is aligned to the bottom right corner of the attention matrix, + # instead of the top-left corner. 
+ # See: https://github.com/Dao-AILab/flash-attention?tab=readme-ov-file#21-change-behavior-of-causal-flag + m[bi, hi, :max(0, j - (seqlen_k - seqlen_q)), j] = -np.inf + else: + if has_end: + upstart = startend_row_indices[bi, hi, j, 2] + upend = startend_row_indices[bi, hi, j, 3] + m[bi, hi, upstart:upend, j] = -np.inf + else: + upend = startend_row_indices[bi, hi, j, 1] + m[bi, hi, :upend, j] = -np.inf + m = torch.repeat_interleave(x=m, repeats=nheads // num_head, axis=1) + return m + +def generate_none_mask(batch_size, seqlen_q, seqlen_k, h, causal=True): + return None, causal + +def generate_sliding_window_mask(batch_size, seqlen_q, seqlen_k, h, window_size=None): + if window_size == None: + window_size = 1024 + if seqlen_k != 8192: + window_size = int(window_size * (seqlen_k / 8192)) + print(f"{seqlen_k=}, auto setting window_size to {window_size}") + + startend_row_indices = torch.arange( + window_size, seqlen_k + window_size, dtype=torch.int32, device=device + ).reshape((1, 1, seqlen_k, 1)) + startend_row_indices = torch.clip( + startend_row_indices, max=seqlen_q + ).repeat_interleave(batch_size, 0) + + causal=True + return startend_row_indices, causal + +def generate_causal_document_mask(batch_size, seqlen_q, seqlen_k, h, doc_seqlens=None): + # TODO: this seems buggy, to be fixed + if doc_seqlens == None: + doc_seqlens = [2538, 1742, 3213] + if seqlen_k != 8192: + doc_seqlens = [int(doc_seqlen * (seqlen_k / 8192)) for doc_seqlen in doc_seqlens] + print(f"{seqlen_k=}, auto setting doc_seqlens to {doc_seqlens}") + total_seqlen = np.sum(doc_seqlens) + assert total_seqlen <= seqlen_k + assert len(doc_seqlens) >= 3 + padding = seqlen_k - np.sum(doc_seqlens) + doc_seqlens[-1] += padding + seq_cusums = np.cumsum(doc_seqlens) + + startend_row_indices = np.repeat(seq_cusums, doc_seqlens) + # startend_row_indices = torch.as_tensor(startend_row_indices, dtype=torch.int32).reshape((1, 1, seqlen_k, 1)).repeat_interleave(batch_size, 0) + startend_row_indices = ( + 
torch.as_tensor( + startend_row_indices, + dtype=torch.int32, + device=device + ) + .reshape(1, 1, seqlen_k, 1) + .repeat_interleave(batch_size, 0) + ) + + startend_row_indices = torch.clip(startend_row_indices, max=seqlen_q) + + causal = True + return startend_row_indices, causal + +def generate_document_mask(batch_size, seqlen_q, seqlen_k, h, doc_seqlens=None): + # TODO: this seems buggy, to be fixed + if doc_seqlens == None: + doc_seqlens = [2538, 1742, 3213] + if seqlen_k != 8192: + doc_seqlens = [int(doc_seqlen * (seqlen_k / 8192)) for doc_seqlen in doc_seqlens] + print(f"{seqlen_k=}, auto setting doc_seqlens to {doc_seqlens}") + total_seqlen = np.sum(doc_seqlens) + assert total_seqlen <= seqlen_k + assert len(doc_seqlens) >= 3 + padding = seqlen_k - np.sum(doc_seqlens) + + down_left_row_indices = [] + up_right_row_indices = [] + + cur_len_so_far = doc_seqlens[0] + for i in range(len(doc_seqlens)): + down_left_row_indices.extend([cur_len_so_far] * doc_seqlens[i]) + if i < len(doc_seqlens) -1: + cur_len_so_far += doc_seqlens[i+1] + if padding > 0: + down_left_row_indices.extend([cur_len_so_far] * padding) + + cur_len_so_far = 0 + for i in range(len(doc_seqlens)): + up_right_row_indices.extend([cur_len_so_far] * doc_seqlens[i]) + if i < len(doc_seqlens) -1: + cur_len_so_far += doc_seqlens[i+1] + if padding > 0: + up_right_row_indices.extend([cur_len_so_far] * padding) + + down_left_row_indices = torch.as_tensor(down_left_row_indices, dtype=torch.int32, device=device).reshape((1, 1, seqlen_k, 1)).repeat_interleave(batch_size, 0) + up_right_row_indices = torch.as_tensor(up_right_row_indices, dtype=torch.int32, device=device).reshape((1, 1, seqlen_k, 1)).repeat_interleave(batch_size, 0) + startend_row_indices = torch.concat([down_left_row_indices, up_right_row_indices], axis=-1) + startend_row_indices = torch.clip(startend_row_indices, max=seqlen_q) + + causal = False + return startend_row_indices, causal + +def generate_share_question_mask(batch_size, seqlen_q, 
seqlen_k, h, doc_seqlens=None): + if doc_seqlens == None: + doc_seqlens = [2538, 1742, 3213] + if seqlen_k != 8192: + doc_seqlens = [int(doc_seqlen * (seqlen_k / 8192)) for doc_seqlen in doc_seqlens] + print(f"{seqlen_k=}, auto setting doc_seqlens to {doc_seqlens}") + + seq_cusums = np.cumsum(doc_seqlens) + seq_cusums = np.append(seq_cusums, 128) + + total_seqlen = np.sum(doc_seqlens) + assert total_seqlen <= seqlen_k + assert len(doc_seqlens) >= 3 + padding = seqlen_k - total_seqlen + + #startend_row_indices = [S] * doc_seq_lens[0] + startend_row_indices = [total_seqlen] * doc_seqlens[0] + + cur_len_so_far = doc_seqlens[0] + for idx in range(1, len(doc_seqlens)): + cur_len_so_far += doc_seqlens[idx] + startend_row_indices.extend([cur_len_so_far] * doc_seqlens[idx]) + + if padding > 0: + startend_row_indices.extend([cur_len_so_far] * padding) + + startend_row_indices = torch.as_tensor(startend_row_indices, dtype=torch.int32, device=device).reshape((1, 1, seqlen_k, 1)).repeat_interleave(batch_size, 0) + startend_row_indices = torch.clip(startend_row_indices, max=seqlen_q) + + causal = True + return startend_row_indices, causal + +def generate_global_sliding_window_mask(batch_size, seqlen_q, seqlen_k, h, global_token=16, window_size=None): + if window_size == None: + window_size = (512, 512) + if seqlen_k != 8192: + window_size = tuple(int(ws * (seqlen_k / 8192)) for ws in window_size) + print(f"{seqlen_k=}, auto setting window_size to {window_size}") + assert len(window_size) == 2 + left_window_size, right_window_size = window_size + + down_left_start_row_indices = [] + down_left_end_row_indices = [] + up_right_start_row_indices = [] + up_right_end_row_indices = [] + + down_left_start_row_indices = torch.arange( + left_window_size + 1, seqlen_k + left_window_size + 1, dtype=torch.int32, device=device + ).clip(max=seqlen_q) + down_left_start_row_indices[:global_token] = seqlen_q + down_left_start_row_indices = down_left_start_row_indices.reshape((1, 1, seqlen_k, 
1)).repeat_interleave(batch_size, 0) + + down_left_end_row_indices = torch.full([seqlen_k], seqlen_q, dtype=torch.int32, device=device).reshape((1, 1, seqlen_k, 1)).repeat_interleave(batch_size, 0) + + up_right_start_row_indices = torch.full([seqlen_k], global_token, dtype=torch.int32, device=device) + up_right_start_row_indices[:global_token+right_window_size+1] = 0 + up_right_start_row_indices = up_right_start_row_indices.reshape((1, 1, seqlen_k, 1)).repeat_interleave(batch_size, 0) + + up_right_end_row_indices = torch.arange( + -right_window_size, seqlen_k - right_window_size, dtype=torch.int32, device=device + ) + up_right_end_row_indices[:global_token+right_window_size+1] = 0 + up_right_end_row_indices = up_right_end_row_indices.reshape((1, 1, seqlen_k, 1)).repeat_interleave(batch_size, 0) + + startend_row_indices = torch.concat([down_left_start_row_indices, down_left_end_row_indices, up_right_start_row_indices, up_right_end_row_indices], axis=-1) + startend_row_indices = torch.clip(startend_row_indices, max=seqlen_q) + + causal = False + return startend_row_indices, causal + +def generate_causal_blockwise_mask(batch_size, seqlen_q, seqlen_k, h, doc_seqlens=None): + # TODO: this seems buggy, to be fixed + if doc_seqlens == None: + doc_seqlens = [2538, 1742, 3213] + if seqlen_k != 8192: + doc_seqlens = [int(doc_seqlen * (seqlen_k / 8192)) for doc_seqlen in doc_seqlens] + print(f"{seqlen_k=}, auto setting doc_seqlens to {doc_seqlens}") + total_seqlen = np.sum(doc_seqlens) + assert total_seqlen <= seqlen_k + assert len(doc_seqlens) >= 3 + padding = seqlen_k - np.sum(doc_seqlens) + + start_row_indices = [] + cur_len_so_far = doc_seqlens[0] + for i in range(len(doc_seqlens)): + start_row_indices.extend([cur_len_so_far] * doc_seqlens[i]) + if i < len(doc_seqlens) - 1: + cur_len_so_far += doc_seqlens[i+1] + if padding > 0: + start_row_indices.extend([cur_len_so_far] * padding) + start_row_indices = torch.as_tensor(start_row_indices, dtype=torch.int32, 
device=device).reshape((1, 1, seqlen_k, 1)).repeat_interleave(batch_size, 0) + + seq_cusums = np.cumsum(doc_seqlens) + end_row_indices = [seq_cusums[-2]] * seq_cusums[-2] + [seq_cusums[-1]] * doc_seqlens[-1] + [seqlen_k] * padding + end_row_indices = torch.as_tensor(end_row_indices, dtype=torch.int32, device=device).reshape((1, 1, seqlen_k, 1)).repeat_interleave(batch_size, 0) + + startend_row_indices = torch.concat([start_row_indices, end_row_indices], axis=-1) + startend_row_indices = torch.clip(startend_row_indices, max=seqlen_q) + + causal = True + return startend_row_indices, causal + +def generate_prefix_lm_document_mask(batch_size, seqlen_q, seqlen_k, h, doc_seqlens=None): + """ + tuple(prefix_length, seq_length) + """ + if doc_seqlens == None: + doc_seqlens=[(1024, 2538), (1742, 1742), (512, 3213)] + if seqlen_k != 8192: + scale = seqlen_k / 8192 + doc_seqlens = [tuple(int(v * scale) for v in pair) for pair in doc_seqlens] + print(f"{seqlen_k=}, auto setting doc_seqlens to {doc_seqlens}") + + assert len(doc_seqlens) >= 2 + total_seqlen = 0 + for prefix_length, seq_length in doc_seqlens: + total_seqlen += seq_length + assert total_seqlen <= seqlen_k + padding = seqlen_k - total_seqlen + + down_left_row_indices = [] + cur_len_so_far = doc_seqlens[0][1] + for i in range(len(doc_seqlens)): + down_left_row_indices.extend([cur_len_so_far] * doc_seqlens[i][1]) + if i < len(doc_seqlens) - 1: + cur_len_so_far += doc_seqlens[i+1][1] + if padding > 0: + down_left_row_indices.extend([cur_len_so_far] * padding) + down_left_row_indices = torch.as_tensor(down_left_row_indices, dtype=torch.int32, device=device).reshape((1, 1, seqlen_k, 1)).repeat_interleave(batch_size, 0) + + up_right_row_indices = [] + cur_len_so_far = 0 + for prefix_length, seq_length in doc_seqlens: + up_right_row_indices.extend([cur_len_so_far] * prefix_length + list(range(cur_len_so_far+prefix_length, cur_len_so_far+seq_length))) + cur_len_so_far += seq_length + if padding > 0: + 
up_right_row_indices.extend([total_seqlen] * padding) + up_right_row_indices = torch.as_tensor(up_right_row_indices, dtype=torch.int32, device=device).reshape((1, 1, seqlen_k, 1)).repeat_interleave(batch_size, 0) + + startend_row_indices = torch.concat([down_left_row_indices, up_right_row_indices], axis=-1) + + startend_row_indices = torch.clip(startend_row_indices, max=seqlen_q) + + causal = False + return startend_row_indices, causal + +def generate_prefix_lm_causal_mask(batch_size, seqlen_q, seqlen_k, h, prefix_length=None): + """ + tuple(prefix_length, seq_length) + """ + if prefix_length == None: + prefix_length = 1024 + if seqlen_k != 8192: + prefix_length = int(prefix_length * (seqlen_k / 8192)) + print(f"{seqlen_k=}, auto setting doc_seqlens to {prefix_length}") + assert prefix_length <= seqlen_k + down_left_row_indices = torch.full([seqlen_k], seqlen_k, dtype=torch.int32, device=device).reshape((1, 1, seqlen_k, 1)).repeat_interleave(batch_size, 0) + up_right_row_indices = torch.as_tensor([0] * prefix_length + list(range(prefix_length, seqlen_k)), dtype=torch.int32, device=device).reshape((1, 1, seqlen_k, 1)).repeat_interleave(batch_size, 0) + startend_row_indices = torch.concat([down_left_row_indices, up_right_row_indices], axis=-1) + startend_row_indices = torch.clip(startend_row_indices, max=seqlen_q) + + causal = False + return startend_row_indices, causal + +def generate_qk_sparse_mask(batch_size, seqlen_q, seqlen_k, h, maskout_pair=None): + """ + tuple(offset, maskout_len) + """ + if maskout_pair == None: + maskout_pair=[(1024, 538), (2358, 1700)] + if seqlen_k != 8192: + scale = seqlen_k / 8192 + maskout_pair = [tuple(int(v * scale) for v in pair) for pair in maskout_pair] + print(f"{seqlen_k=}, auto setting maskout_pair to {maskout_pair}") + start_row_indices = [] + end_row_indices = [] + last_offset = 0 + for offset, maskout_len in maskout_pair: + assert offset > last_offset + start_row_indices.extend([seqlen_k]*(offset-last_offset)) + 
end_row_indices.extend([seqlen_k]*(offset-last_offset)) + + start_row_indices.extend(list(range(offset, offset+maskout_len))) + end_row_indices.extend([offset+maskout_len]*(maskout_len)) + + last_offset = offset + maskout_len + + last_offset <= seqlen_k + start_row_indices.extend([seqlen_k]*(seqlen_k-last_offset)) + end_row_indices.extend([seqlen_k]*(seqlen_k-last_offset)) + + start_row_indices = torch.as_tensor(start_row_indices, dtype=torch.int32, device=device).reshape((1, 1, seqlen_k, 1)).repeat_interleave(batch_size, 0) + end_row_indices = torch.as_tensor(end_row_indices, dtype=torch.int32, device=device).reshape((1, 1, seqlen_k, 1)).repeat_interleave(batch_size, 0) + startend_row_indices = torch.concat([start_row_indices, end_row_indices], axis=-1) + startend_row_indices = torch.clip(startend_row_indices, max=seqlen_q) + + causal = True + return startend_row_indices, causal + +def generate_random_eviction_mask(batch_size, seqlen_q, seqlen_k, h, start_row=None): + # np.random.seed(0) + if start_row == None: + start_row = 4096 + if seqlen_k != 8192: + start_row = int(start_row * (seqlen_k / 8192)) + print(f"{seqlen_k=}, auto setting start_row to {start_row}") + start_rows_list = [] + for bz_idx in range(batch_size): + for head_idx in range(h): + start_rows = np.array([seqlen_k+1] * seqlen_k) + mask_pos = np.random.choice(seqlen_k-1, seqlen_k - start_row, replace=False) + index = np.arange(start_row, seqlen_k) + mask_pos = np.concatenate([mask_pos[mask_pos < index - 1], mask_pos[mask_pos >= index - 1]]) + start_rows[mask_pos] = index + start_rows_list.append(start_rows) + startend_row_indices = torch.as_tensor(start_rows_list, dtype=torch.int32, device=device).reshape((batch_size, h, seqlen_k, 1)) + startend_row_indices = torch.clip(startend_row_indices, max=seqlen_q) + causal = True + return startend_row_indices, causal diff --git a/run_torch_8gpu.sh b/run_torch_8gpu.sh new file mode 100644 index 0000000..8816df6 --- /dev/null +++ b/run_torch_8gpu.sh @@ -0,0 
+1,170 @@ +#!/usr/bin/env bash +set -euo pipefail + +# ============================================================ +# Multi-GPU pytest runner with persistent split + resume +# +# 特性: +# 1) split 只生成一次 +# 2) 如果 flashmask_split_ 存在则直接复用 +# 3) resume 通过统计 log 已完成数量 +# ============================================================ + +NUM_GPUS=${NUM_GPUS:-8} +TEST_FILE="test_flashmask_torch.py" +LOG_DIR="logs" +SPLIT_DIR="./flashmask_split_" + +mkdir -p "${LOG_DIR}" +mkdir -p "${SPLIT_DIR}" + +echo "" +echo "========================================" +echo " Multi-GPU Pytest Runner" +echo " GPUs : ${NUM_GPUS}" +echo " Test file : ${TEST_FILE}" +echo " Split dir : ${SPLIT_DIR}" +echo "========================================" +echo "" + +# ────────────────────────────────────── +# Step 1: 如果没有 split 文件,则生成 +# ────────────────────────────────────── + +need_generate=false + +for (( gpu=0; gpu/dev/null \ + | grep '::' \ + > "${SPLIT_DIR}/all_tests.txt" || true + + TOTAL=$(wc -l < "${SPLIT_DIR}/all_tests.txt") + echo "[INFO] Total test cases: ${TOTAL}" + + if [ "${TOTAL}" -eq 0 ]; then + echo "[ERROR] No test cases collected." + exit 1 + fi + + # 初始化 GPU 文件 + for (( gpu=0; gpu "${SPLIT_DIR}/gpu_${gpu}.txt" + done + + idx=0 + while IFS= read -r line; do + gpu=$(( idx % NUM_GPUS )) + echo "${line}" >> "${SPLIT_DIR}/gpu_${gpu}.txt" + idx=$(( idx + 1 )) + done < "${SPLIT_DIR}/all_tests.txt" + + echo "[INFO] Split completed:" + for (( gpu=0; gpu "${remain_list}" + + test_args=$(tr '\n' ' ' < "${remain_list}") + + ( + export CUDA_VISIBLE_DEVICES="${gpu}" + + if [ -f "${log_file}" ]; then + echo "===== RESUME $(date) =====" >> "${log_file}" + else + echo "===== START $(date) =====" > "${log_file}" + fi + + python -m pytest -v --tb=short ${test_args} >> "${log_file}" 2>&1 || true + + echo "" >> "${log_file}" + echo "===== GPU ${gpu}: done =====" >> "${log_file}" + ) & + + pids+=($!) 
+ echo "[INFO] GPU ${gpu}: started (PID $!), log -> ${log_file}" +done + +# ────────────────────────────────────── +# Step 3: 等待 +# ────────────────────────────────────── + +echo "" +echo "[INFO] Waiting for all GPU processes..." + +for pid in "${pids[@]}"; do + wait "${pid}" +done + +# ────────────────────────────────────── +# Step 4: 汇总 +# ────────────────────────────────────── + +echo "" +echo "========================================" +echo " All GPU processes finished." +echo "========================================" + +for (( gpu=0; gpu/dev/null || echo 0) + f=$(grep -c ' FAILED' "${log_file}" 2>/dev/null || echo 0) + s=$(grep -c ' SKIPPED' "${log_file}" 2>/dev/null || echo 0) + e=$(grep -c ' ERROR' "${log_file}" 2>/dev/null || echo 0) + printf " GPU %d: %4d passed, %4d failed, %4d skipped, %4d errors\n" \ + "${gpu}" "${p}" "${f}" "${s}" "${e}" + fi +done + +echo "" +echo "[INFO] Done." diff --git a/test_flashmask_torch.py b/test_flashmask_torch.py new file mode 100644 index 0000000..b6e1841 --- /dev/null +++ b/test_flashmask_torch.py @@ -0,0 +1,237 @@ +import os +import math +import itertools +import pytest +from einops import rearrange, repeat +import torch + +from interface_torch import flashmask_attention + +# from flash_mask.flashmask_attention_v3.flashmask_interface import flashmask_attention + +from generate_startend_row_indices_torch import ( + startend_row_indices_to_attn_bias, + generate_none_mask, + generate_sliding_window_mask, + generate_causal_document_mask, + generate_document_mask, + generate_share_question_mask, + generate_global_sliding_window_mask, + generate_causal_blockwise_mask, + generate_prefix_lm_document_mask, + generate_prefix_lm_causal_mask, + generate_qk_sparse_mask, + generate_random_eviction_mask +) + +from functools import partial +from test_util_torch import attention_ref + +# batch_size, seqlen_q, seqlen_k, nheads, nheads_kv +shape_cases = ( + [ + (2840, 32, 32, 16, 4), + (1, 300, 300, 16, 16), + # (2, 8192, 32768, 32, 4), # 
this will oom + # (2, 8192, 8192, 32, 4), # this will oom + (2, 8192, 8192, 14, 1), + (2, 16384, 16384, 4, 1), + (1, 1, 127, 1, 1), + (1, 128, 127, 1, 1), + (1, 127, 128, 1, 1), + (2, 16383, 16384, 4, 1), + (2, 16384, 16383, 4, 1), + (2, 1000, 1000, 4, 1), + (2, 2000, 2000, 4, 1), + (2, 3000, 3000, 4, 1), + (1, 4000, 4000, 1, 1), + (1, 8192, 32768+1024, 2, 1), + (1, 8192, 16384+1024, 2, 1) + # my case + ] + # tridao case + + list(itertools.product( + [9], # batch_size + [1, 64, 128, 256, 239, 799, 113, 113, 128, 113, 108, 256, 384, 640, 512, 1024, 1023, 1024,], # seqlen_q + [128, 192, 256, 203, 128, 217, 211, 256, 512, 256, 128, 256, 1024, 1024, 1023,], # seqlen_k + [6], # nheads + [6, 2, 1], # nheads_kv + )) + + list(itertools.product( + [2], # batch_size + [4096, 4224], # seqlen_q + [4096, 4224], # seqlen_k + [6], # nheads + [6, 2, 1], # nheads_kv + )) +) + +# Generate all combinations for second param +def generate_shapes(): + for batch_size, seqlen_q, seqlen_k, nheads, nheads_kv in shape_cases: + if nheads_kv == 1: + nheads_startend_row_indices_values = [1] + else: + nheads_startend_row_indices_values = [1, nheads_kv] + for nheads_startend_row_indices in nheads_startend_row_indices_values: + yield ( + batch_size, seqlen_q, seqlen_k, nheads, nheads_kv, nheads_startend_row_indices + ) + +@pytest.mark.parametrize("dtype", [torch.bfloat16]) +@pytest.mark.parametrize("fa_version", [3]) +@pytest.mark.parametrize("d, dv", + [ + (64, 64), + (80, 80), + (128, 128), + (192, 192), + (256, 256), + ]) +@pytest.mark.parametrize( + "batch_size, seqlen_q, seqlen_k, nheads, nheads_kv, nheads_startend_row_indices", + list(generate_shapes()) +) +@pytest.mark.parametrize( + "gen_startend_row_indices", + [ + # partial(generate_none_mask, causal=False), # full + # partial(generate_none_mask, causal=True), # causal + partial(generate_sliding_window_mask), # sliding window + partial(generate_causal_document_mask), # causal document mask + partial(generate_document_mask), # document 
mask + partial(generate_share_question_mask), # share question mask + partial(generate_global_sliding_window_mask), # global sliding window + partial(generate_causal_blockwise_mask), # causal blockwise mask + partial(generate_prefix_lm_document_mask), # prefix lm document mask + partial(generate_prefix_lm_causal_mask), # prefix lm causal mask + partial(generate_qk_sparse_mask), # qk-sparse mask + partial(generate_random_eviction_mask), # random eviction mask + ], +) +def test_flashmask( + batch_size, seqlen_q, seqlen_k, nheads, nheads_kv, d, dv, nheads_startend_row_indices, fa_version, dtype, gen_startend_row_indices, softcap=0.0 +): + torch.manual_seed(2024) + assert nheads % nheads_kv == 0 + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + q_ref = torch.randn([batch_size, seqlen_q, nheads, d], dtype=dtype, device=device) + k_ref = torch.randn([batch_size, seqlen_k, nheads_kv, d], dtype=dtype, device=device) + v_ref = torch.randn([batch_size, seqlen_k, nheads_kv, dv], dtype=dtype, device=device) + + q_ref.requires_grad_(True) + k_ref.requires_grad_(True) + v_ref.requires_grad_(True) + + q_bf16, k_bf16, v_bf16 = [x.detach().clone() for x in (q_ref, k_ref, v_ref)] + + q_bf16.requires_grad_(True) + k_bf16.requires_grad_(True) + v_bf16.requires_grad_(True) + + q, k, v = [x.detach().clone() for x in (q_ref, k_ref, v_ref)] + + q.requires_grad_(True) + k.requires_grad_(True) + v.requires_grad_(True) + + startend_row_indices, causal = gen_startend_row_indices(batch_size, seqlen_q, seqlen_k, nheads_startend_row_indices) + + if startend_row_indices is None and causal and d == 80: + pytest.skip(f"Skipping because running headdim 80 with flash_attn in causal mask") + + if fa_version == 2 and seqlen_q != seqlen_k and causal: + pytest.skip(f"Skipping because running fa2 in causal when seqlen_q != seqlen_k") + + if fa_version == 4 and d != 128 and d != 64 and seqlen_q != seqlen_k and causal: + pytest.skip(f"Skipping because running fa4 in causal when 
seqlen_q != seqlen_k and d not int [128, 64]") + + if fa_version == 4 and startend_row_indices is not None and startend_row_indices.shape[-1] == 4: + pytest.skip(f"Skipping because running fa4 when startend_row_indices.shape[-1] == 4") + + attn_bias = startend_row_indices_to_attn_bias(startend_row_indices, seqlen_q, nheads, dtype, causal) + + out_ref, attn_ref = attention_ref( + q_ref, + k_ref, + v_ref, + causal=causal, + attn_bias=attn_bias + ) + + out_bf16, attn_bf16 = attention_ref( + q_bf16, + k_bf16, + v_bf16, + causal=causal, + attn_bias=attn_bias, + upcast=False, + reorder_ops=True + ) + + # # Numerical error if we just do any arithmetic on out_ref + fwd_atol = 2 * (out_ref + 0.3 - 0.3 - out_ref).abs().max().item() + assert softcap == 0.0 + rtol = 2 if softcap == 0.0 else 3 + + print(f"Torch naive bf16 Output max diff: {(out_bf16 - out_ref).abs().max().item()}") + print(f"Torch naive bf16 Output mean diff: {(out_bf16 - out_ref).abs().mean().item()}") + + if fa_version == 2: + os.environ["FLAGS_flash_attn_version"] = "2" + elif fa_version == 3: + os.environ["FLAGS_flash_attn_version"] = "3" + elif fa_version == 4: + os.environ["FLAGS_flash_attn_version"] = "4" + raise ValueError( + f"Invalid flash attention version: {fa_version}" + ) + + out, lse = flashmask_attention( + q, + k, + v, + startend_row_indices, + causal=causal, + return_softmax_lse=True + ) + + print(f"flashmask Output max diff: {(out - out_ref).abs().max().item()}") + print(f"flashmask Output mean diff: {(out - out_ref).abs().mean().item()}") + # if not causal: + # print(f"LSE max diff: {(lse - lse_ref).abs().max().item()}") + # breakpoint() + + # Check that FlashAttention's numerical error is at most twice the numerical error + # of a Pytorch implementation. 
+ + assert (out - out_ref).abs().max().item() <= rtol * (out_bf16 - out_ref).abs().max().item() + fwd_atol + + # g = torch.randn(out.shape, dtype=out.dtype) + g = torch.randn_like(out) + + out.backward(g) + out_ref.backward(g) + out_bf16.backward(g) + + print(f"flashmask dQ max diff: {(q.grad - q_ref.grad).abs().max().item()}") + print(f"flashmask dK max diff: {(k.grad - k_ref.grad).abs().max().item()}") + print(f"flashmask dV max diff: {(v.grad - v_ref.grad).abs().max().item()}") + print(f"flashmask dQ mean diff: {(q.grad - q_ref.grad).abs().mean().item()}") + print(f"flashmask dK mean diff: {(k.grad - k_ref.grad).abs().mean().item()}") + print(f"flashmask dV mean diff: {(v.grad - v_ref.grad).abs().mean().item()}") + + print(f"Torch naive bf16 dQ max diff: {(q_bf16.grad - q_ref.grad).abs().max().item()}") + print(f"Torch naive bf16 dK max diff: {(k_bf16.grad - k_ref.grad).abs().max().item()}") + print(f"Torch naive bf16 dV max diff: {(v_bf16.grad - v_ref.grad).abs().max().item()}") + print(f"Torch naive bf16 dQ mean diff: {(q_bf16.grad - q_ref.grad).abs().mean().item()}") + print(f"Torch naive bf16 dK mean diff: {(k_bf16.grad - k_ref.grad).abs().mean().item()}") + print(f"Torch naive bf16 dV mean diff: {(v_bf16.grad - v_ref.grad).abs().mean().item()}") + + dq_atol = 2 * (q_ref.grad + 0.3 - 0.3 - q_ref.grad).abs().max().item() + (0 if softcap == 0 else 3e-4) + assert (q.grad - q_ref.grad).abs().max().item() <= rtol * (q_bf16.grad - q_ref.grad).abs().max().item() + dq_atol + dk_atol = 2 * (k_ref.grad + 0.3 - 0.3 - k_ref.grad).abs().max().item() + (0 if softcap == 0 else 3e-4) + assert (k.grad - k_ref.grad).abs().max().item() <= rtol * (k_bf16.grad - k_ref.grad).abs().max().item() + dk_atol + dv_atol = 2 * (v_ref.grad + 0.3 - 0.3 - v_ref.grad).abs().max().item() + (0 if softcap == 0 else 3e-4) + assert (v.grad - v_ref.grad).abs().max().item() <= rtol * (v_bf16.grad - v_ref.grad).abs().max().item() + dv_atol diff --git a/test_util_torch.py b/test_util_torch.py new 
# file mode 100644 index 0000000..a212aab --- /dev/null +++ b/test_util_torch.py @@ -0,0 +1,279 @@
"""Reference (non-fused) attention and dense-mask helpers for the FlashMask tests."""
import math

import numpy as np
import torch
from einops import rearrange, repeat


def construct_local_mask(
    seqlen_q,
    seqlen_k,
    window_size=(-1, -1),  # -1 means infinite window size
    sink_token_length=0,
    query_padding_mask=None,
    key_padding_mask=None,
    key_leftpad=None,
    device=None,
):
    """Build the boolean sliding-window mask consumed by ``attention_ref``.

    ``True`` marks a (query row, key column) pair that ``attention_ref`` fills
    with ``-inf``.  The window is anchored with the offset ``sk - sq`` (number
    of valid keys minus number of valid queries).

    Args:
        seqlen_q: Number of query rows.
        seqlen_k: Number of key columns.
        window_size: ``(left, right)`` window extents; a negative ``left``
            disables the left bound.
        sink_token_length: Columns below this index are never removed by the
            left bound.
        query_padding_mask: Optional ``(batch, seqlen_q)`` mask; its row sums
            replace ``seqlen_q`` in the offset.
        key_padding_mask: Optional ``(batch, seqlen_k)`` mask; its row sums
            replace ``seqlen_k`` in the offset.
        key_leftpad: Optional ``(batch,)`` left-padding lengths subtracted from
            the column indices.
        device: Device on which the index tensors are created.

    Returns:
        A boolean tensor broadcastable to ``(batch, 1, seqlen_q, seqlen_k)``.
    """
    row_idx = rearrange(torch.arange(seqlen_q, dtype=torch.int64, device=device), "s -> s 1")
    col_idx = torch.arange(seqlen_k, dtype=torch.int64, device=device)
    if key_leftpad is not None:
        key_leftpad = rearrange(key_leftpad, "b -> b 1 1 1")
        col_idx = repeat(col_idx, "s -> b 1 1 s", b=key_leftpad.shape[0])
        # Columns that fall inside the left padding get a huge index so that
        # every comparison below treats them as out of range.
        col_idx = torch.where(col_idx >= key_leftpad, col_idx - key_leftpad, 2**32)
    sk = (
        seqlen_k
        if key_padding_mask is None
        else rearrange(key_padding_mask.sum(-1), "b -> b 1 1 1")
    )
    sq = (
        seqlen_q
        if query_padding_mask is None
        else rearrange(query_padding_mask.sum(-1), "b -> b 1 1 1")
    )
    if window_size[0] < 0:
        # Only the right bound applies.
        return col_idx > row_idx + sk - sq + window_size[1]
    sk = torch.full_like(col_idx, seqlen_k) if key_padding_mask is None else sk
    return torch.logical_or(
        col_idx > torch.minimum(row_idx + sk - sq + window_size[1], sk),
        torch.logical_and(
            col_idx < row_idx + sk - sq - window_size[0],
            col_idx >= sink_token_length,
        ),
    )


def attention_ref(
    q,
    k,
    v,
    query_padding_mask=None,
    key_padding_mask=None,
    key_leftpad=None,
    attn_bias=None,
    dropout_p=0.0,
    dropout_mask=None,
    causal=False,
    qv=None,
    q_descale=None, k_descale=None, v_descale=None,
    window_size=(-1, -1),  # -1 means infinite window size
    attention_chunk=0,
    sink_token_length=0,
    softcap=0.0,
    upcast=True,
    reorder_ops=False,
    intermediate_dtype=None,
):
    """Compute attention with plain PyTorch ops as a numerical reference.

    Arguments:
        q: (batch_size, seqlen_q, nheads, head_dim)
        k: (batch_size, seqlen_k, nheads_kv, head_dim)
        v: (batch_size, seqlen_k, nheads_kv, head_dim_v)
        attn_bias: (batch_size, h, seqlen_q, seqlen_k) additive bias, with
            ``nheads`` a multiple of ``h``; it is repeated along the head axis.
        key_leftpad: forwarded to ``construct_local_mask``.
        dropout_p: float; only rescales ``v`` by ``1 / (1 - dropout_p)``.
        causal: whether to apply causal masking (forces the right window to 0).
        window_size: (left, right) local window; -1 means unbounded.
        sink_token_length: forwarded to ``construct_local_mask``.
        upcast: whether to cast q, k, v to fp32, do all computation in fp32,
            then cast the outputs back to the input dtype.
        reorder_ops: whether to change the order of operations (scaling k
            instead of scaling q) without changing the math.  This is to
            estimate the numerical error from operation reordering.
        intermediate_dtype: if given, the attention probabilities are rounded
            through this dtype before being multiplied with ``v``.
        query_padding_mask, key_padding_mask, dropout_mask, qv, q_descale,
        k_descale, v_descale, attention_chunk, softcap: accepted for signature
            compatibility but not supported; passing a non-default value raises.
    Output:
        output: (batch_size, seqlen_q, nheads, head_dim_v)
        attention: (batch_size, nheads, seqlen_q, seqlen_k), softmax probabilities
    Raises:
        AssertionError: if any unsupported option listed above is requested.
    """
    unsupported = {
        "q_descale": q_descale is not None,
        "k_descale": k_descale is not None,
        "v_descale": v_descale is not None,
        "qv": qv is not None,
        "softcap": softcap > 0,
        "query_padding_mask": query_padding_mask is not None,
        "key_padding_mask": key_padding_mask is not None,
        "attention_chunk": attention_chunk > 0,
        "dropout_mask": dropout_mask is not None,
    }
    requested = [name for name, used in unsupported.items() if used]
    if requested:
        # Raised explicitly (same exception type as the former `assert False`
        # guards) so the check is not stripped when Python runs with -O.
        raise AssertionError(
            "attention_ref does not support: " + ", ".join(requested)
        )

    if causal:
        window_size = (window_size[0], 0)
    dtype_og = q.dtype
    if upcast:
        q = q.to(torch.float32)
        k = k.to(torch.float32)
        v = v.to(torch.float32)

    seqlen_q, seqlen_k = q.shape[1], k.shape[1]

    # (batch_size, seqlen, nheads, head_dim) -> (batch_size, nheads, seqlen, head_dim)
    q = torch.permute(q, [0, 2, 1, 3])
    k = torch.permute(k, [0, 2, 1, 3])
    v = torch.permute(v, [0, 2, 1, 3])

    # Expand grouped KV heads (and the bias heads) to one per query head.
    k = repeat(k, "b h s d -> b (h g) s d", g=q.shape[1] // k.shape[1])
    v = repeat(v, "b h s d -> b (h g) s d", g=q.shape[1] // v.shape[1])
    if attn_bias is not None:
        attn_bias = repeat(attn_bias, "b h s d -> b (h g) s d ", g=q.shape[1] // attn_bias.shape[1])

    softmax_scale = 1.0 / math.sqrt(q.shape[-1])
    if not reorder_ops:
        scores = torch.matmul(q * softmax_scale, k.transpose(-2, -1))
    else:
        scores = torch.matmul(q, (k * softmax_scale).transpose(-2, -1))

    local_mask = None
    if window_size[0] >= 0 or window_size[1] >= 0:
        local_mask = construct_local_mask(
            seqlen_q,
            seqlen_k,
            window_size,
            sink_token_length,
            query_padding_mask,
            key_padding_mask,
            key_leftpad=key_leftpad,
            device=q.device,
        )
    if local_mask is not None:
        scores.masked_fill_(local_mask, float("-inf"))

    all_inf_mask = None
    if attn_bias is not None:
        scores = scores + attn_bias.to(torch.float32)
        # When a whole row of attn_bias is -inf, replace that row of scores
        # with a finite value so that softmax does not produce NaN.
        all_inf_mask = (attn_bias == -np.inf).all(dim=-1, keepdim=True)
        scores = torch.where(all_inf_mask, torch.full_like(scores, -1e9), scores)

    attention = torch.softmax(scores, dim=-1).to(v.dtype)

    if all_inf_mask is not None:
        # The rows patched above come out of softmax as a uniform 1/seqlen_k;
        # they are fully masked, so zero them.
        attention = torch.where(all_inf_mask, torch.zeros_like(attention), attention)

    # Some rows might be completely masked out so we fill them with zero instead of NaN
    if local_mask is not None:
        attention = attention.masked_fill(torch.all(local_mask, dim=-1, keepdim=True), 0.0)

    dropout_scaling = 1.0 / (1 - dropout_p)
    attention_drop = attention
    if intermediate_dtype is not None:
        attention_drop = attention_drop.to(intermediate_dtype).to(attention_drop.dtype)
    output = torch.matmul(attention_drop, v * dropout_scaling)
    output = torch.permute(output, [0, 2, 1, 3])
    return output.to(dtype=dtype_og), attention.to(dtype=dtype_og)


# blockmask utils
def random_blockmask(shape, dtype='int32', is_causal=False, ref_q=None):
    """Return a random 0/1 ``torch.int32`` block mask of the given 4-D shape.

    ``dtype``, ``is_causal`` and ``ref_q`` are accepted but not used.

    Raises:
        ValueError: if ``shape`` does not have exactly four entries.
    """
    if len(shape) != 4:
        raise ValueError(f"shape must have 4 entries, got {len(shape)}")
    return torch.randint(0, 2, shape, dtype=torch.int32)


def flashmask_to_densemask(startend_row_indices, seqlen_q, nheads, causal=True):
    """Expand FlashMask column-wise row bounds into a dense boolean mask.

    Args:
        startend_row_indices: Integer tensor of shape
            ``(batch, num_head, seqlen_k, bound_num)`` or ``None``.
        seqlen_q: Number of query rows of the dense mask.
        nheads: Number of heads of the result; must be a multiple of
            ``num_head``.
        causal: Selects how the bounds are interpreted (see below).

    For every key column ``j`` the rows ``[bounds[0], bounds[1])`` are cleared
    (``[bounds[0], seqlen_q)`` when no end bound is present).  In causal mode
    the rows above the bottom-right-aligned diagonal are cleared as well;
    otherwise the rows ``[bounds[2], bounds[3])`` (four bounds) or
    ``[0, bounds[1])`` (two bounds) are cleared.

    Returns:
        ``None`` if ``startend_row_indices`` is ``None``; otherwise a boolean
        tensor of shape ``(batch, nheads, seqlen_q, seqlen_k)`` on the device of
        ``startend_row_indices``, ``False`` at every cleared position.
    """
    if startend_row_indices is None:
        return None
    bz, num_head, seqlen_k, bound_num = startend_row_indices.shape
    assert nheads % num_head == 0, "nheads must be a multiple of the mask's head count"
    has_end = (causal and bound_num == 2) or ((not causal) and bound_num == 4)

    # One host copy of all bounds as Python ints, instead of reading a 0-dim
    # tensor for every single slice bound inside the loop.
    bounds = startend_row_indices.tolist()
    m = torch.ones((bz, num_head, seqlen_q, seqlen_k), dtype=torch.bool)
    for bi in range(bz):
        for hi in range(num_head):
            for j in range(seqlen_k):
                col = bounds[bi][hi][j]
                downstart = col[0]
                if has_end:
                    m[bi, hi, downstart:col[1], j] = False
                else:
                    m[bi, hi, downstart:, j] = False
                if causal:
                    # from flash-attention 2.1 and in flash-attention 3, If seqlen_q != seqlen_k and causal=True,
                    # the causal mask is aligned to the bottom right corner of the attention matrix,
                    # instead of the top-left corner.
                    # See: https://github.com/Dao-AILab/flash-attention?tab=readme-ov-file#21-change-behavior-of-causal-flag
                    m[bi, hi, :max(0, j - (seqlen_k - seqlen_q)), j] = False
                elif has_end:
                    m[bi, hi, col[2]:col[3], j] = False
                else:
                    m[bi, hi, :col[1], j] = False
    m = m.repeat_interleave(nheads // num_head, dim=1)
    return m.to(startend_row_indices.device)


def blockmask_to_densemask(blockmask, q_len, k_len, dtype, causal=True, block_q=128, block_k=128):
    """Expand a block-level mask to a dense boolean mask.

    Args:
        blockmask: [b, s, q_blocks, k_blocks] 0/1 mask (1 means masked,
            0 means visible), or ``None``.
        q_len: int, query sequence length.
        k_len: int, key sequence length.
        dtype: torch dtype the block mask is cast to before expansion.
        causal: bool, accepted but not used by this function.
        block_q: rows covered by one block along the query axis.
        block_k: columns covered by one block along the key axis.

    Returns:
        ``None`` if ``blockmask`` is ``None``; otherwise a boolean tensor
        [b, s, q_len, k_len] (smaller when the blocks do not cover the
        requested lengths).

    Raises:
        ValueError: if ``blockmask`` is not 4-dimensional.
    """
    if blockmask is None:
        return None
    if blockmask.dim() != 4:
        raise ValueError(f"blockmask must be 4-D, got {blockmask.dim()}-D")

    # 1. Expand every block to block_q x block_k entries, then crop.
    densemask = (
        blockmask.to(dtype)
        .repeat_interleave(block_q, dim=2)
        .repeat_interleave(block_k, dim=3)
    )
    densemask = densemask[:, :, :q_len, :k_len]
    return densemask.to(torch.bool)