diff --git a/flashmask/flash_mask/cute/block_sparsity.py b/flashmask/flash_mask/cute/block_sparsity.py index 2022dee9a97..8b11765d4f1 100644 --- a/flashmask/flash_mask/cute/block_sparsity.py +++ b/flashmask/flash_mask/cute/block_sparsity.py @@ -9,7 +9,7 @@ """ from typing import Tuple, Optional, Callable, List, NamedTuple -import paddle +import torch import cutlass.cute as cute from cutlass.cute.runtime import from_dlpack @@ -30,18 +30,18 @@ def __new_from_mlir_values__(self, values): return BlockSparseTensors(*values) -class BlockSparseTensorsPaddle(NamedTuple): - mask_block_cnt: paddle.Tensor - mask_block_idx: paddle.Tensor - full_block_cnt: Optional[paddle.Tensor] = None - full_block_idx: Optional[paddle.Tensor] = None +class BlockSparseTensorsTorch(NamedTuple): + mask_block_cnt: torch.Tensor + mask_block_idx: torch.Tensor + full_block_cnt: Optional[torch.Tensor] = None + full_block_idx: Optional[torch.Tensor] = None def _expand_sparsity_tensor( - tensor: paddle.Tensor, + tensor: torch.Tensor, expected_shape: Tuple[int, ...], tensor_name: str, -) -> paddle.Tensor: +) -> torch.Tensor: """Check if we need to expand the tensor to expected shape, and do so if possible.""" needs_expand = tensor.shape != expected_shape if not needs_expand: @@ -56,20 +56,20 @@ def _expand_sparsity_tensor( def _check_and_expand_block( name: str, - cnt: Optional[paddle.Tensor], - idx: Optional[paddle.Tensor], + cnt: Optional[torch.Tensor], + idx: Optional[torch.Tensor], expected_count_shape: Tuple[int, int, int], expected_index_shape: Tuple[int, int, int, int], -) -> Tuple[Optional[paddle.Tensor], Optional[paddle.Tensor]]: +) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]: if (cnt is None) != (idx is None): raise ValueError( f"{name}_block_cnt and {name}_block_idx must both be provided or both be None" ) if cnt is None or idx is None: return None, None - if cnt.dtype != paddle.int32 or idx.dtype != paddle.int32: - raise ValueError(f"{name}_block tensors must have dtype paddle.int32") - if cnt.place != idx.place: + if cnt.dtype != torch.int32 or idx.dtype != torch.int32: + raise ValueError(f"{name}_block tensors must have dtype torch.int32") + if cnt.device != idx.device: raise ValueError(f"{name}_block_cnt and {name}_block_idx must be on the same device") if not cnt.is_cuda or not idx.is_cuda: raise ValueError(f"{name}_block tensors must live on CUDA") @@ -79,11 +79,11 @@ def _check_and_expand_block( def normalize_block_sparse_tensors( - tensors: BlockSparseTensorsPaddle, + tensors: BlockSparseTensorsTorch, *, expected_count_shape: Tuple[int, int, int], expected_index_shape: Tuple[int, int, int, int], -) -> BlockSparseTensorsPaddle: +) -> BlockSparseTensorsTorch: if tensors.mask_block_cnt is None or tensors.mask_block_idx is None: raise ValueError("mask_block_cnt and mask_block_idx must be provided for block sparsity.") @@ -104,45 +104,33 @@ def normalize_block_sparse_tensors( expected_count_shape, expected_index_shape, ) - if full_cnt is not None and mask_cnt.place != full_cnt.place: + if full_cnt is not None and mask_cnt.device != full_cnt.device: raise ValueError("All block sparse tensors must be on the same device") - return BlockSparseTensorsPaddle( + return BlockSparseTensorsTorch( mask_block_cnt=mask_cnt, mask_block_idx=mask_idx, full_block_cnt=full_cnt, full_block_idx=full_idx, ) - -def is_block_sparsity_enabled(tensors: BlockSparseTensorsPaddle) -> bool: +def is_block_sparsity_enabled(tensors: BlockSparseTensorsTorch) -> bool: return any(t is not None for t in (tensors.full_block_cnt, tensors.mask_block_cnt)) - -def to_cute_block_sparse_tensors(tensors: BlockSparseTensorsPaddle) -> Optional[BlockSparseTensors]: +def to_cute_block_sparse_tensors(tensors: BlockSparseTensorsTorch) -> Optional[BlockSparseTensors]: if not is_block_sparsity_enabled(tensors): return None - mask_block_cnt_tensor = from_dlpack( - tensors.mask_block_cnt.detach(), assumed_align=4 - ).mark_layout_dynamic(leading_dim=2) - mask_block_idx_tensor = from_dlpack( - tensors.mask_block_idx.detach(), assumed_align=4 - ).mark_layout_dynamic(leading_dim=3) - full_block_cnt_tensor = ( - from_dlpack(tensors.full_block_cnt.detach(), assumed_align=4).mark_layout_dynamic( - leading_dim=2 - ) - if tensors.full_block_cnt is not None - else None - ) - full_block_idx_tensor = ( - from_dlpack(tensors.full_block_idx.detach(), assumed_align=4).mark_layout_dynamic( - leading_dim=3 - ) - if tensors.full_block_idx is not None - else None - ) + def _wrap(t, dim): + if t is None: + return None + return from_dlpack(t.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=dim) + + mask_block_cnt_tensor = _wrap(tensors.mask_block_cnt, 2) + mask_block_idx_tensor = _wrap(tensors.mask_block_idx, 3) + + full_block_cnt_tensor = _wrap(tensors.full_block_cnt, 2) + full_block_idx_tensor = _wrap(tensors.full_block_idx, 3) return BlockSparseTensors( mask_block_cnt_tensor, @@ -156,14 +144,14 @@ def compute_block_sparsity( config: Config, mask_mod_flex: Optional[Callable], device: str, - cu_seqlens_q: Optional[paddle.Tensor] = None, - cu_seqlens_k: Optional[paddle.Tensor] = None, - aux_tensors: Optional[List[paddle.Tensor]] = None, + cu_seqlens_q: Optional[torch.Tensor] = None, + cu_seqlens_k: Optional[torch.Tensor] = None, + aux_tensors: Optional[List[torch.Tensor]] = None, ) -> Tuple[ - Optional[paddle.Tensor], - Optional[paddle.Tensor], - Optional[paddle.Tensor], - Optional[paddle.Tensor], + Optional[torch.Tensor], + Optional[torch.Tensor], + Optional[torch.Tensor], + Optional[torch.Tensor], ]: """ Computes block sparsity tensors from a given masking function. @@ -205,30 +193,30 @@ def compute_block_sparsity( def _compute_sparsity( - config: Config, device: str, aux_tensors: Optional[List[paddle.Tensor]] -) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor]: + config: Config, device: str, aux_tensors: Optional[List[torch.Tensor]] +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: """Computes block sparsity for fixed-length sequences.""" n_blocks_q = (config.seqlen_q + config.tile_m - 1) // config.tile_m n_blocks_k = (config.seqlen_k + config.tile_n - 1) // config.tile_n # Pre-allocate output tensors - full_block_cnt = paddle.zeros( - (config.batch_size, config.nheads, n_blocks_q), dtype=paddle.int32 + full_block_cnt = torch.zeros( + (config.batch_size, config.nheads, n_blocks_q), dtype=torch.int32, device=device ) - mask_block_cnt = paddle.zeros( - (config.batch_size, config.nheads, n_blocks_q), dtype=paddle.int32 + mask_block_cnt = torch.zeros( + (config.batch_size, config.nheads, n_blocks_q), dtype=torch.int32, device=device ) - full_block_idx = paddle.zeros( - (config.batch_size, config.nheads, n_blocks_q, n_blocks_k), dtype=paddle.int32 + full_block_idx = torch.zeros( + (config.batch_size, config.nheads, n_blocks_q, n_blocks_k), dtype=torch.int32, device=device ) - mask_block_idx = paddle.zeros( - (config.batch_size, config.nheads, n_blocks_q, n_blocks_k), dtype=paddle.int32 + mask_block_idx = torch.zeros( + (config.batch_size, config.nheads, n_blocks_q, n_blocks_k), dtype=torch.int32, device=device ) # --- Identity Mask --- # All blocks are fully computed. if config.mask_mod_name == "identity": - k_blocks = paddle.arange(n_blocks_k) + k_blocks = torch.arange(n_blocks_k, dtype=torch.int32, device=device) for q_block_idx in range(n_blocks_q): full_block_cnt[:, :, q_block_idx] = n_blocks_k full_block_idx[:, :, q_block_idx, :n_blocks_k] = k_blocks @@ -236,14 +224,14 @@ def _compute_sparsity( # --- Identity Partial Mask --- # All blocks are partially computed (masked). elif config.mask_mod_name == "identity_partial": - k_blocks = paddle.arange(n_blocks_k) + k_blocks = torch.arange(n_blocks_k, dtype=torch.int32, device=device) for q_block_idx in range(n_blocks_q): mask_block_cnt[:, :, q_block_idx] = n_blocks_k mask_block_idx[:, :, q_block_idx, :n_blocks_k] = k_blocks # --- Block Causal Mask --- elif config.mask_mod_name == "block_causal": - k_blocks = paddle.arange(n_blocks_k) + k_blocks = torch.arange(n_blocks_k, dtype=torch.int32, device=device) for q_block_idx in range(n_blocks_q): causal_indices = k_blocks[k_blocks <= q_block_idx] num_causal_indices = len(causal_indices) @@ -253,16 +241,18 @@ def _compute_sparsity( # --- Causal and Sliding Window Masks --- elif config.mask_mod_name in ["causal", "sliding_window"]: - q_block_indices = paddle.arange(n_blocks_q) - k_block_indices = paddle.arange(n_blocks_k) + q_block_indices = torch.arange(n_blocks_q, dtype=torch.int32, device=device) + k_block_indices = torch.arange(n_blocks_k, dtype=torch.int32, device=device) q_starts = q_block_indices * config.tile_m - q_ends = paddle.minimum( - (q_block_indices + 1) * config.tile_m, paddle.to_tensor(config.seqlen_q) + q_ends = torch.minimum( + (q_block_indices + 1) * config.tile_m, + torch.tensor(config.seqlen_q, dtype=torch.int32, device=device) ) k_starts = k_block_indices * config.tile_n - k_ends = paddle.minimum( - (k_block_indices + 1) * config.tile_n, paddle.to_tensor(config.seqlen_k) + k_ends = torch.minimum( + (k_block_indices + 1) * config.tile_n, + torch.tensor(config.seqlen_k, dtype=torch.int32, device=device) ) # Expand dims for broadcasting: (n_blocks_q, 1) and (1, n_blocks_k) @@ -315,9 +305,9 @@ def _compute_varlen_sparsity( config: Config, mask_mod_flex: Callable, device: str, - cu_seqlens_q: paddle.Tensor, - cu_seqlens_k: paddle.Tensor, -) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor]: + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: """Computes block sparsity for variable-length sequences.""" assert cu_seqlens_k is not None, "cu_seqlens_k is required for varlen attention" assert cu_seqlens_q.shape[0] == config.batch_size + 1 @@ -336,19 +326,19 @@ def _compute_varlen_sparsity( max_n_blocks = (total_k_len + config.tile_n - 1) // config.tile_n # Pre-allocate padded output tensors - full_block_cnt = paddle.zeros( - (config.batch_size, config.nheads, max_m_blocks), dtype=paddle.int32 + full_block_cnt = torch.zeros( + (config.batch_size, config.nheads, max_m_blocks), dtype=torch.int32, device=device ) - mask_block_cnt = paddle.zeros( - (config.batch_size, config.nheads, max_m_blocks), dtype=paddle.int32 + mask_block_cnt = torch.zeros( + (config.batch_size, config.nheads, max_m_blocks), dtype=torch.int32, device=device ) - full_block_idx = paddle.zeros( + full_block_idx = torch.zeros( (config.batch_size, config.nheads, max_m_blocks, max_n_blocks), - dtype=paddle.int32, + dtype=torch.int32, device=device ) - mask_block_idx = paddle.zeros( + mask_block_idx = torch.zeros( (config.batch_size, config.nheads, max_m_blocks, max_n_blocks), - dtype=paddle.int32, + dtype=torch.int32, device=device ) # Process each sequence in the batch individually @@ -495,13 +485,13 @@ def _compute_causal_varlen_blocks( if full_blocks: full_block_cnt[seq_idx, :, m_local] = len(full_blocks) - full_block_idx[seq_idx, :, m_local, : len(full_blocks)] = paddle.to_tensor( - full_blocks, + full_block_idx[seq_idx, :, m_local, : len(full_blocks)] = torch.tensor( + full_blocks, dtype=torch.int32, device=device ) if partial_blocks: mask_block_cnt[seq_idx, :, m_local] = len(partial_blocks) - mask_block_idx[seq_idx, :, m_local, : len(partial_blocks)] = paddle.tensor( - partial_blocks, + mask_block_idx[seq_idx, :, m_local, : len(partial_blocks)] = torch.tensor( + partial_blocks, dtype=torch.int32, device=device ) @@ -555,13 +545,13 @@ def _compute_sliding_window_varlen_blocks( if full_blocks: full_block_cnt[seq_idx, :, m_local] = len(full_blocks) - full_block_idx[seq_idx, :, m_local, : len(full_blocks)] = paddle.to_tensor( - full_blocks, + full_block_idx[seq_idx, :, m_local, : len(full_blocks)] = torch.tensor( + full_blocks, dtype=torch.int32, device=device ) if partial_blocks: mask_block_cnt[seq_idx, :, m_local] = len(partial_blocks) - mask_block_idx[seq_idx, :, m_local, : len(partial_blocks)] = paddle.to_tensor( - partial_blocks, + mask_block_idx[seq_idx, :, m_local, : len(partial_blocks)] = torch.tensor( + partial_blocks, dtype=torch.int32, device=device ) @@ -576,8 +566,8 @@ def _compute_identity_varlen_blocks( **kwargs, ): """Computes identity (all-attend) block sparsity for a single varlen sequence.""" - n_blocks_global = paddle.arange( - first_n_block_global, first_n_block_global + n_blocks_k, dtype=paddle.int32 + n_blocks_global = torch.arange( + first_n_block_global, first_n_block_global + n_blocks_k, dtype=torch.int32, device=device ) for m_local in range(n_blocks_q): full_block_cnt[seq_idx, :, m_local] = n_blocks_k @@ -641,11 +631,11 @@ def _compute_generic_varlen_blocks( if full_blocks: full_block_cnt[seq_idx, h_q, m_local] = len(full_blocks) - full_block_idx[seq_idx, h_q, m_local, : len(full_blocks)] = paddle.to_tensor( - full_blocks, + full_block_idx[seq_idx, h_q, m_local, : len(full_blocks)] = torch.tensor( + full_blocks, dtype=torch.int32, device=device ) if partial_blocks: mask_block_cnt[seq_idx, h_q, m_local] = len(partial_blocks) - mask_block_idx[seq_idx, h_q, m_local, : len(partial_blocks)] = paddle.to_tensor( - partial_blocks, + mask_block_idx[seq_idx, h_q, m_local, : len(partial_blocks)] = torch.tensor( + partial_blocks, dtype=torch.int32, device=device ) diff --git a/flashmask/flash_mask/cute/compute_block_sparsity.py b/flashmask/flash_mask/cute/compute_block_sparsity.py index 33775b58baa..b4d1f9f1d4b 100644 --- a/flashmask/flash_mask/cute/compute_block_sparsity.py +++ b/flashmask/flash_mask/cute/compute_block_sparsity.py @@ -5,7 +5,7 @@ from cutlass import Boolean, Int32, Int8, const_expr import cutlass.cute as cute from cutlass.cute.runtime import from_dlpack -import paddle +import torch from flash_mask.cute.block_sparsity import BlockSparseTensors from flash_mask.cute.utils import hash_callable, scalar_to_ssa, ssa_to_scalar @@ -263,7 +263,7 @@ def compute_block_sparsity( device, compute_full_blocks: bool = True, use_fast_sampling: bool = False, -) -> Tuple[BlockSparseTensors, Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor, paddle.Tensor]]: +) -> Tuple[BlockSparseTensors, Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]]: """ Computes block sparsity for a given `mask_mod`. @@ -281,33 +281,28 @@ def compute_block_sparsity( use_fast_sampling: Whether to use 5-point sampling (4 corners + center). This is much faster, but only suitable for masks where this check is sufficient. Returns: - A tuple of `BlockSparseTensors` and the underlying paddle tensors. + A tuple of `BlockSparseTensors` and the underlying torch tensors. """ num_m_blocks = (seqlen_q + tile_m - 1) // tile_m num_n_blocks = (seqlen_k + tile_n - 1) // tile_n - mask_block_cnt = paddle.zeros((batch_size, num_heads, num_m_blocks), dtype=paddle.int32) - mask_block_idx = paddle.zeros( - (batch_size, num_heads, num_m_blocks, num_n_blocks), dtype=paddle.int32 + mask_block_cnt = torch.zeros((batch_size, num_heads, num_m_blocks), dtype=torch.int32, device=device) + mask_block_idx = torch.zeros( + (batch_size, num_heads, num_m_blocks, num_n_blocks), dtype=torch.int32, device=device ) - full_block_cnt = paddle.zeros((batch_size, num_heads, num_m_blocks), dtype=paddle.int32) - full_block_idx = paddle.zeros( - (batch_size, num_heads, num_m_blocks, num_n_blocks), dtype=paddle.int32 + full_block_cnt = torch.zeros((batch_size, num_heads, num_m_blocks), dtype=torch.int32, device=device) + full_block_idx = torch.zeros( + (batch_size, num_heads, num_m_blocks, num_n_blocks), dtype=torch.int32, device=device ) + def _wrap(t, dim): + # (Capsule) -> from_dlpack (Cute Tensor) + return from_dlpack(t.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=dim) # Convert to cute tensors - mask_cnt_cute = from_dlpack(mask_block_cnt.detach(), assumed_align=4).mark_layout_dynamic( - leading_dim=2 - ) - mask_idx_cute = from_dlpack(mask_block_idx.detach(), assumed_align=4).mark_layout_dynamic( - leading_dim=3 - ) - full_cnt_cute = from_dlpack(full_block_cnt.detach(), assumed_align=4).mark_layout_dynamic( - leading_dim=2 - ) - full_idx_cute = from_dlpack(full_block_idx.detach(), assumed_align=4).mark_layout_dynamic( - leading_dim=3 - ) + mask_cnt_cute = _wrap(mask_block_cnt, 2) + mask_idx_cute = _wrap(mask_block_idx, 3) + full_cnt_cute = _wrap(full_block_cnt, 2) + full_idx_cute = _wrap(full_block_idx, 3) blocksparse_tensors = BlockSparseTensors( mask_block_cnt=mask_cnt_cute, @@ -350,7 +345,7 @@ def compute_block_sparsity( aux_tensors, ) - # Return both the BlockSparseTensors (cute) and the underlying paddle tensors + # Return both the BlockSparseTensors (cute) and the underlying torch tensors return blocksparse_tensors, (full_block_cnt, full_block_idx, mask_block_cnt, mask_block_idx) @@ -359,6 +354,9 @@ def compute_block_sparsity( def run(): """Test the BlockSparsityKernel with a simple causal mask.""" + if not torch.cuda.is_available(): + print("Skipping test: CUDA not available.") + return print("Testing BlockSparsityKernel...") diff --git a/flashmask/flash_mask/cute/cute_dsl_utils.py b/flashmask/flash_mask/cute/cute_dsl_utils.py index 4024f0559ed..9498b8f98c6 100644 --- a/flashmask/flash_mask/cute/cute_dsl_utils.py +++ b/flashmask/flash_mask/cute/cute_dsl_utils.py @@ -2,11 +2,11 @@ import os import pathlib -from typing import Tuple +from typing import Union, Tuple from functools import partial, lru_cache from dataclasses import dataclass, fields -import paddle +import torch try: from triton.tools.disasm import extract @@ -26,10 +26,10 @@ cute_compile_og = cute.compile -paddle2cute_dtype_map = { - paddle.float16: cutlass.Float16, - paddle.bfloat16: cutlass.BFloat16, - paddle.float32: cutlass.Float32, +torch2cute_dtype_map = { + torch.float16: cutlass.Float16, + torch.bfloat16: cutlass.BFloat16, + torch.float32: cutlass.Float32, } @@ -39,8 +39,8 @@ def get_max_active_clusters(cluster_size): @lru_cache -def get_device_capacity(device: paddle.device = None) -> Tuple[int, int]: - return paddle.cuda.get_device_capability(device) +def get_device_capacity(device: Union[int, str, torch.device, None] = None) -> Tuple[int, int]: + return torch.cuda.get_device_capability(device) @dataclass diff --git a/flashmask/flash_mask/cute/flashmask_utils.py b/flashmask/flash_mask/cute/flashmask_utils.py index 94b944ebc77..e5362625944 100644 --- a/flashmask/flash_mask/cute/flashmask_utils.py +++ b/flashmask/flash_mask/cute/flashmask_utils.py @@ -20,7 +20,7 @@ from typing import Optional, NamedTuple from dataclasses import dataclass -import paddle +import torch import cutlass import cutlass.cute as cute import cuda.bindings.driver as cuda @@ -30,7 +30,7 @@ __all__ = [ "prepare_block_maxmin", - "FlashMaskInfoPaddle" + "FlashMaskInfoTorch", ] @@ -72,18 +72,18 @@ def __new_from_mlir_values__(self, values): @dataclass -class FlashMaskInfoPaddle: +class FlashMaskInfoTorch: is_causal: bool - startend_row_indices: paddle.Tensor - LTS_nblock_max: Optional[paddle.Tensor] = None - LTS_nblock_min: Optional[paddle.Tensor] = None - LTE_nblock_max: Optional[paddle.Tensor] = None - LTE_nblock_min: Optional[paddle.Tensor] = None - UTS_nblock_max: Optional[paddle.Tensor] = None - UTS_nblock_min: Optional[paddle.Tensor] = None - UTE_nblock_max: Optional[paddle.Tensor] = None - UTE_nblock_min: Optional[paddle.Tensor] = None - valid_block_count: Optional[paddle.Tensor] = None + startend_row_indices: torch.Tensor + LTS_nblock_max: Optional[torch.Tensor] = None + LTS_nblock_min: Optional[torch.Tensor] = None + LTE_nblock_max: Optional[torch.Tensor] = None + LTE_nblock_min: Optional[torch.Tensor] = None + UTS_nblock_max: Optional[torch.Tensor] = None + UTS_nblock_min: Optional[torch.Tensor] = None + UTE_nblock_max: Optional[torch.Tensor] = None + UTE_nblock_min: Optional[torch.Tensor] = None + valid_block_count: Optional[torch.Tensor] = None def _compute_nblock_seqlen(seqlen_k: int, kBlockN: int) -> int: @@ -184,18 +184,19 @@ def scan_max_min_cute( ) def _scan_max_min( - mInput: paddle.Tensor, + mInput: torch.Tensor, b: int, n: int, - mMaxO: paddle.Tensor, - mMinO: paddle.Tensor, + mMaxO: torch.Tensor, + mMinO: torch.Tensor, kBlockN: int, ): - input_tensor = from_dlpack(mInput.contiguous(), assumed_align=4).mark_layout_dynamic(leading_dim=2) + mInput_contig = mInput.contiguous() + input_tensor = from_dlpack(mInput_contig, assumed_align=4).mark_layout_dynamic(leading_dim=2) max_tensor = from_dlpack(mMaxO, assumed_align=4).mark_layout_dynamic(leading_dim=2) min_tensor = from_dlpack(mMinO, assumed_align=4).mark_layout_dynamic(leading_dim=2) - current_stream = cuda.CUstream(paddle.device.current_stream().stream_base.cuda_stream) + current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) compile_key = (b, kBlockN,) if compile_key not in _scan_max_min.compile_cache: @@ -217,7 +218,7 @@ def _scan_max_min( _scan_max_min.compile_cache = {} -def prepare_block_maxmin(flashmask_info: FlashMaskInfoPaddle, kBlockN: int = 128): +def prepare_block_maxmin(flashmask_info: FlashMaskInfoTorch, kBlockN: int = 128): """Prepare block-sparse max/min tensors for flashmask. The function will compute derived pointers/offsets and call scanMaxMinGpu @@ -227,28 +228,33 @@ def prepare_block_maxmin(flashmask_info: FlashMaskInfoPaddle, kBlockN: int = 128 batch, heads, seqlen_k, num_vecs = flashmask_info.startend_row_indices.shape nblocks = _compute_nblock_seqlen(seqlen_k, kBlockN) + device = flashmask_info.startend_row_indices.device + + def create_buffer(): + return torch.zeros(batch, heads, nblocks, dtype=torch.int32, device=device) + if num_vecs == 1 and flashmask_info.LTS_nblock_max is None and flashmask_info.LTS_nblock_min is None: - flashmask_info.LTS_nblock_max = paddle.zeros([batch, heads, nblocks], dtype=paddle.int32) - flashmask_info.LTS_nblock_min = paddle.zeros([batch, heads, nblocks], dtype=paddle.int32) + flashmask_info.LTS_nblock_max = create_buffer() + flashmask_info.LTS_nblock_min = create_buffer() _scan_max_min(flashmask_info.startend_row_indices[..., 0], batch * heads, seqlen_k, flashmask_info.LTS_nblock_max, flashmask_info.LTS_nblock_min, kBlockN) elif num_vecs == 2 and flashmask_info.is_causal and ( flashmask_info.LTS_nblock_max is None and flashmask_info.LTS_nblock_min is None and flashmask_info.LTE_nblock_max is None and flashmask_info.LTE_nblock_min is None ): - flashmask_info.LTS_nblock_max = paddle.zeros([batch, heads, nblocks], dtype=paddle.int32) - flashmask_info.LTS_nblock_min = paddle.zeros([batch, heads, nblocks], dtype=paddle.int32) - flashmask_info.LTE_nblock_max = paddle.zeros([batch, heads, nblocks], dtype=paddle.int32) - flashmask_info.LTE_nblock_min = paddle.zeros([batch, heads, nblocks], dtype=paddle.int32) + flashmask_info.LTS_nblock_max = create_buffer() + flashmask_info.LTS_nblock_min = create_buffer() + flashmask_info.LTE_nblock_max = create_buffer() + flashmask_info.LTE_nblock_min = create_buffer() _scan_max_min(flashmask_info.startend_row_indices[..., 0], batch * heads, seqlen_k, flashmask_info.LTS_nblock_max, flashmask_info.LTS_nblock_min, kBlockN) _scan_max_min(flashmask_info.startend_row_indices[..., 1], batch * heads, seqlen_k, flashmask_info.LTE_nblock_max, flashmask_info.LTE_nblock_min, kBlockN) elif num_vecs == 2 and not flashmask_info.is_causal and ( flashmask_info.LTS_nblock_max is None and flashmask_info.LTS_nblock_min is None and flashmask_info.UTE_nblock_max is None and flashmask_info.UTE_nblock_min is None ): - flashmask_info.LTS_nblock_max = paddle.zeros([batch, heads, nblocks], dtype=paddle.int32) - flashmask_info.LTS_nblock_min = paddle.zeros([batch, heads, nblocks], dtype=paddle.int32) - flashmask_info.UTE_nblock_max = paddle.zeros([batch, heads, nblocks], dtype=paddle.int32) - flashmask_info.UTE_nblock_min = paddle.zeros([batch, heads, nblocks], dtype=paddle.int32) + flashmask_info.LTS_nblock_max = create_buffer() + flashmask_info.LTS_nblock_min = create_buffer() + flashmask_info.UTE_nblock_max = create_buffer() + flashmask_info.UTE_nblock_min = create_buffer() _scan_max_min(flashmask_info.startend_row_indices[..., 0], batch * heads, seqlen_k, flashmask_info.LTS_nblock_max, flashmask_info.LTS_nblock_min, kBlockN) _scan_max_min(flashmask_info.startend_row_indices[..., 1], batch * heads, seqlen_k, flashmask_info.UTE_nblock_max, flashmask_info.UTE_nblock_min, kBlockN) elif num_vecs == 4 and ( @@ -258,14 +264,14 @@ def prepare_block_maxmin(flashmask_info: FlashMaskInfoPaddle, kBlockN: int = 128 flashmask_info.UTE_nblock_max is None and flashmask_info.UTE_nblock_min is None ): - flashmask_info.LTS_nblock_max = paddle.zeros([batch, heads, nblocks], dtype=paddle.int32) - flashmask_info.LTS_nblock_min = paddle.zeros([batch, heads, nblocks], dtype=paddle.int32) - flashmask_info.LTE_nblock_max = paddle.zeros([batch, heads, nblocks], dtype=paddle.int32) - flashmask_info.LTE_nblock_min = paddle.zeros([batch, heads, nblocks], dtype=paddle.int32) - flashmask_info.UTS_nblock_max = paddle.zeros([batch, heads, nblocks], dtype=paddle.int32) - flashmask_info.UTS_nblock_min = paddle.zeros([batch, heads, nblocks], dtype=paddle.int32) - flashmask_info.UTE_nblock_max = paddle.zeros([batch, heads, nblocks], dtype=paddle.int32) - flashmask_info.UTE_nblock_min = paddle.zeros([batch, heads, nblocks], dtype=paddle.int32) + flashmask_info.LTS_nblock_max = create_buffer() + flashmask_info.LTS_nblock_min = create_buffer() + flashmask_info.LTE_nblock_max = create_buffer() + flashmask_info.LTE_nblock_min = create_buffer() + flashmask_info.UTS_nblock_max = create_buffer() + flashmask_info.UTS_nblock_min = create_buffer() + flashmask_info.UTE_nblock_max = create_buffer() + flashmask_info.UTE_nblock_min = create_buffer() _scan_max_min(flashmask_info.startend_row_indices[..., 0], batch * heads, seqlen_k, flashmask_info.LTS_nblock_max, flashmask_info.LTS_nblock_min, kBlockN) _scan_max_min(flashmask_info.startend_row_indices[..., 1], batch * heads, seqlen_k, flashmask_info.LTE_nblock_max, flashmask_info.LTE_nblock_min, kBlockN) _scan_max_min(flashmask_info.startend_row_indices[..., 2], batch * heads, seqlen_k, flashmask_info.UTS_nblock_max, flashmask_info.UTS_nblock_min, kBlockN) @@ -274,7 +280,7 @@ def prepare_block_maxmin(flashmask_info: FlashMaskInfoPaddle, kBlockN: int = 128 raise ValueError(f"Unsupported num_vecs={num_vecs} in flashmask_info") -def is_flashmask_enabled(flashmask_info: FlashMaskInfoPaddle) -> bool: +def is_flashmask_enabled(flashmask_info: FlashMaskInfoTorch) -> bool: return any(t is not None for t in ( flashmask_info.LTS_nblock_max, flashmask_info.LTS_nblock_min, @@ -286,12 +292,17 @@ def is_flashmask_enabled(flashmask_info: FlashMaskInfoPaddle) -> bool: flashmask_info.UTE_nblock_min, )) -def to_cute_flashmask_info(flashmask_info: FlashMaskInfoPaddle) -> Optional[FlashMaskInfo]: +def to_cute_flashmask_info(flashmask_info: FlashMaskInfoTorch) -> Optional[FlashMaskInfo]: if not is_flashmask_enabled(flashmask_info): return None batch, heads, seqlen_k, num_vecs = flashmask_info.startend_row_indices.shape + def _wrap(t): + if t is None: + return None + return from_dlpack(t, assumed_align=4).mark_layout_dynamic(leading_dim=2) + startend_row_indices_tensor = from_dlpack(flashmask_info.startend_row_indices, assumed_align=4).mark_layout_dynamic(leading_dim=3) LTS_nblock_max_tensor = None LTS_nblock_min_tensor = None @@ -303,32 +314,32 @@ def to_cute_flashmask_info(flashmask_info: FlashMaskInfoPaddle) -> Optional[Flas UTE_nblock_min_tensor = None if num_vecs == 1: - LTS_nblock_max_tensor = from_dlpack(flashmask_info.LTS_nblock_max, assumed_align=4).mark_layout_dynamic(leading_dim=2) - LTS_nblock_min_tensor = from_dlpack(flashmask_info.LTS_nblock_min, assumed_align=4).mark_layout_dynamic(leading_dim=2) + LTS_nblock_max_tensor = _wrap(flashmask_info.LTS_nblock_max) + LTS_nblock_min_tensor = _wrap(flashmask_info.LTS_nblock_min) elif num_vecs == 2 and flashmask_info.is_causal: - LTS_nblock_max_tensor = from_dlpack(flashmask_info.LTS_nblock_max, assumed_align=4).mark_layout_dynamic(leading_dim=2) - LTS_nblock_min_tensor = from_dlpack(flashmask_info.LTS_nblock_min, assumed_align=4).mark_layout_dynamic(leading_dim=2) - LTE_nblock_max_tensor = from_dlpack(flashmask_info.LTE_nblock_max, assumed_align=4).mark_layout_dynamic(leading_dim=2) - LTE_nblock_min_tensor = from_dlpack(flashmask_info.LTE_nblock_min, assumed_align=4).mark_layout_dynamic(leading_dim=2) + LTS_nblock_max_tensor = _wrap(flashmask_info.LTS_nblock_max) + LTS_nblock_min_tensor = _wrap(flashmask_info.LTS_nblock_min) + LTE_nblock_max_tensor = _wrap(flashmask_info.LTE_nblock_max) + LTE_nblock_min_tensor = _wrap(flashmask_info.LTE_nblock_min) elif num_vecs == 2 and not flashmask_info.is_causal: - LTS_nblock_max_tensor = from_dlpack(flashmask_info.LTS_nblock_max, assumed_align=4).mark_layout_dynamic(leading_dim=2) - LTS_nblock_min_tensor = from_dlpack(flashmask_info.LTS_nblock_min, assumed_align=4).mark_layout_dynamic(leading_dim=2) - UTE_nblock_max_tensor = from_dlpack(flashmask_info.UTE_nblock_max, assumed_align=4).mark_layout_dynamic(leading_dim=2) - UTE_nblock_min_tensor = from_dlpack(flashmask_info.UTE_nblock_min, assumed_align=4).mark_layout_dynamic(leading_dim=2) + LTS_nblock_max_tensor = _wrap(flashmask_info.LTS_nblock_max) + LTS_nblock_min_tensor = _wrap(flashmask_info.LTS_nblock_min) + UTE_nblock_max_tensor = _wrap(flashmask_info.UTE_nblock_max) + UTE_nblock_min_tensor = _wrap(flashmask_info.UTE_nblock_min) elif num_vecs == 4: - LTS_nblock_max_tensor = from_dlpack(flashmask_info.LTS_nblock_max, assumed_align=4).mark_layout_dynamic(leading_dim=2) - LTS_nblock_min_tensor = from_dlpack(flashmask_info.LTS_nblock_min, assumed_align=4).mark_layout_dynamic(leading_dim=2) - LTE_nblock_max_tensor = from_dlpack(flashmask_info.LTE_nblock_max, assumed_align=4).mark_layout_dynamic(leading_dim=2) - LTE_nblock_min_tensor = from_dlpack(flashmask_info.LTE_nblock_min, assumed_align=4).mark_layout_dynamic(leading_dim=2) - UTS_nblock_max_tensor = from_dlpack(flashmask_info.UTS_nblock_max, assumed_align=4).mark_layout_dynamic(leading_dim=2) - UTS_nblock_min_tensor = from_dlpack(flashmask_info.UTS_nblock_min, assumed_align=4).mark_layout_dynamic(leading_dim=2) - UTE_nblock_max_tensor = from_dlpack(flashmask_info.UTE_nblock_max, assumed_align=4).mark_layout_dynamic(leading_dim=2) - UTE_nblock_min_tensor = from_dlpack(flashmask_info.UTE_nblock_min, assumed_align=4).mark_layout_dynamic(leading_dim=2) + LTS_nblock_max_tensor = _wrap(flashmask_info.LTS_nblock_max) + LTS_nblock_min_tensor = _wrap(flashmask_info.LTS_nblock_min) + LTE_nblock_max_tensor = _wrap(flashmask_info.LTE_nblock_max) + LTE_nblock_min_tensor = _wrap(flashmask_info.LTE_nblock_min) + UTS_nblock_max_tensor = _wrap(flashmask_info.UTS_nblock_max) + UTS_nblock_min_tensor = _wrap(flashmask_info.UTS_nblock_min) + UTE_nblock_max_tensor = _wrap(flashmask_info.UTE_nblock_max) + UTE_nblock_min_tensor = _wrap(flashmask_info.UTE_nblock_min) else: raise ValueError(f"Unsupported num_vecs={num_vecs} in flashmask_info") if flashmask_info.valid_block_count is not None: - valid_block_count = from_dlpack(flashmask_info.valid_block_count, assumed_align=4).mark_layout_dynamic(leading_dim=2) + valid_block_count = _wrap(flashmask_info.valid_block_count) else: valid_block_count = None @@ -462,7 +473,7 @@ def reduce_block_count_cute( # Note(wusiming): make sure call reduce_block_count after scan_max_min def reduce_block_count( - flashmask_info: FlashMaskInfo, + flashmask_info: FlashMaskInfoTorch, is_causal: bool, kBlockM: int, kBlockN: int, @@ -489,7 +500,7 @@ def reduce_block_count( has_uts = False has_ute = False - current_stream = cuda.CUstream(paddle.device.current_stream().stream_base.cuda_stream) + current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) # TODO(wusiming): Are all of these compile keys necessary? compile_key = (is_causal, kBlockM, kBlockN, batch, heads, has_lte, has_uts, has_ute) diff --git a/flashmask/flash_mask/cute/interface.py b/flashmask/flash_mask/cute/interface.py index ec126b28866..f003dd6dbff 100644 --- a/flashmask/flash_mask/cute/interface.py +++ b/flashmask/flash_mask/cute/interface.py @@ -36,7 +36,7 @@ import math from typing import Optional, Tuple, Callable, Union -import paddle +import torch import cuda.bindings.driver as cuda @@ -54,27 +54,29 @@ from flash_mask.cute.flash_bwd_postprocess import FlashAttentionBackwardPostprocess from flash_mask.cute.flash_fwd_combine import FlashAttentionForwardCombine from flash_mask.cute.flashmask_utils import ( - FlashMaskInfoPaddle, + FlashMaskInfoTorch, prepare_block_maxmin, to_cute_flashmask_info, reduce_block_count, ) from flash_mask.cute.block_sparsity import ( - BlockSparseTensorsPaddle, + BlockSparseTensorsTorch, to_cute_block_sparse_tensors, normalize_block_sparse_tensors, ) def maybe_contiguous(x): - return x.contiguous() if x is not None and x.strides[-1] != 1 else x + # PyTorch uses .stride() method. + # Check if the last dimension has a stride of 1 (RowMajor layout for the last dim) + return x.contiguous() if x is not None and x.stride(-1) != 1 else x -paddle2cute_dtype_map = { - paddle.float16: cutlass.Float16, - paddle.bfloat16: cutlass.BFloat16, - paddle.float32: cutlass.Float32, +torch2cute_dtype_map = { + torch.float16: cutlass.Float16, + torch.bfloat16: cutlass.BFloat16, + torch.float32: cutlass.Float32, } @@ -89,20 +91,20 @@ def num_splits_heuristic(total_mblocks, num_SMs, num_n_blocks, max_splits): def _flash_attn_fwd( - q: paddle.Tensor, - k: paddle.Tensor, - v: paddle.Tensor, - cu_seqlens_q: Optional[paddle.Tensor] = None, - cu_seqlens_k: Optional[paddle.Tensor] = None, - seqused_q: Optional[paddle.Tensor] = None, - seqused_k: Optional[paddle.Tensor] = None, - page_table: Optional[paddle.Tensor] = None, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + cu_seqlens_q: Optional[torch.Tensor] = None, + cu_seqlens_k: Optional[torch.Tensor] = None, + seqused_q: Optional[torch.Tensor] = None, + seqused_k: Optional[torch.Tensor] = None, + page_table: Optional[torch.Tensor] = None, softmax_scale: Optional[float] = None, causal: bool = False, softcap: Optional[float] = None, window_size_left: Optional[int] = None, window_size_right: Optional[int] = None, - learnable_sink: Optional[paddle.Tensor] = None, + learnable_sink: Optional[torch.Tensor] = None, # m_block_size: int = 128, # n_block_size: int = 64, # num_threads: int = 128, @@ -114,13 +116,13 @@ def _flash_attn_fwd( _compute_capability: Optional[int] = None, score_mod: Optional[Callable] = None, mask_mod: Optional[Callable] = None, - block_sparse_tensors: Optional[BlockSparseTensorsPaddle] = None, + block_sparse_tensors: Optional[BlockSparseTensorsTorch] = None, return_lse: bool = False, - out: Optional[paddle.Tensor] = None, - lse: Optional[paddle.Tensor] = None, - aux_tensors: Optional[list[paddle.Tensor]] = None, - startend_row_indices: Optional[paddle.Tensor] = None, -) -> Tuple[paddle.Tensor, paddle.Tensor]: + out: Optional[torch.Tensor] = None, + lse: Optional[torch.Tensor] = None, + aux_tensors: Optional[list[torch.Tensor]] = None, + startend_row_indices: Optional[torch.Tensor] = None, +) -> Tuple[torch.Tensor, torch.Tensor]: """Forward pass for FlashAttention. Args: @@ -150,21 +152,26 @@ def _flash_attn_fwd( # Note(wusiming): FA4 is so weird, but each cta process q_stage * m_block_size rows q_stage = 2 num_m_blocks = (seqlen_q + (q_stage * m_block_size) - 1) // (q_stage * m_block_size) - flashmask_info = FlashMaskInfoPaddle( + flashmask_info = FlashMaskInfoTorch( is_causal=causal, startend_row_indices=startend_row_indices, ) - flashmask_info.valid_block_count = paddle.empty([fm_batch_size, fm_heads, num_m_blocks], dtype=paddle.int32) + flashmask_info.valid_block_count = torch.empty( + [fm_batch_size, fm_heads, num_m_blocks], + dtype=torch.int32, + device=startend_row_indices.device + ) prepare_block_maxmin(flashmask_info) cute_flashmask_info = to_cute_flashmask_info(flashmask_info) reduce_block_count(cute_flashmask_info, causal, q_stage * m_block_size, n_block_size, seqlen_q) if page_table is not None: assert cu_seqlens_k is None, "page_table is not supported with cu_seqlens_k" - assert page_table.dtype == paddle.int32, "page_table must be int32" - assert page_table.strides[-1] == 1, "page_table must be contiguous in the last dimension" + assert page_table.dtype == torch.int32, "page_table must be int32" + # PyTorch uses .stride() method + assert page_table.stride(-1) == 1, "page_table must be contiguous in the last dimension" max_num_pages_per_seq = page_table.shape[1] - assert page_table.shape == [batch_size, max_num_pages_per_seq] + assert page_table.shape == (batch_size, max_num_pages_per_seq) num_pages, page_size = k.shape[:2] seqlen_k = num_pages * page_size else: @@ -174,48 +181,49 @@ def _flash_attn_fwd( head_dim_v = v.shape[-1] if cu_seqlens_k is None: if page_table is None: - assert k.shape == [batch_size, seqlen_k, num_head_kv, head_dim], ( - f"expect k with shape {[batch_size, seqlen_k, num_head_kv, head_dim]}, received {k.shape=}" + assert k.shape == (batch_size, seqlen_k, num_head_kv, head_dim), ( + f"expect k with shape {(batch_size, seqlen_k, num_head_kv, head_dim)}, received {k.shape=}" ) - assert v.shape == [batch_size, seqlen_k, num_head_kv, head_dim_v] + assert v.shape == (batch_size, seqlen_k, num_head_kv, head_dim_v) else: - assert k.shape == [num_pages, page_size, num_head_kv, head_dim] - assert v.shape == [num_pages, page_size, num_head_kv, head_dim_v] + assert k.shape == (num_pages, page_size, num_head_kv, head_dim) + assert v.shape == (num_pages, page_size, num_head_kv, head_dim_v) else: - assert k.shape == [seqlen_k, num_head_kv, head_dim] - assert v.shape == [seqlen_k, num_head_kv, head_dim_v] - assert cu_seqlens_k.shape == [ + assert k.shape == (seqlen_k, num_head_kv, head_dim) + assert v.shape == (seqlen_k, num_head_kv, head_dim_v) + assert cu_seqlens_k.shape == ( batch_size + 1, - ], "cu_seqlens_k must have shape (batch_size + 1,)" + ), "cu_seqlens_k must have shape (batch_size + 1,)" if cu_seqlens_q is not None: - assert cu_seqlens_q.shape == [ + assert cu_seqlens_q.shape == ( batch_size + 1, - ], "cu_seqlens_q must have shape (batch_size + 1,)" - assert seqused_q is None or seqused_q.shape == [ + ), "cu_seqlens_q must have shape (batch_size + 1,)" + assert seqused_q is None or seqused_q.shape == ( batch_size, - ], "seqused_q must have shape (batch_size,)" - assert seqused_k is None or seqused_k.shape == [ + ), "seqused_q must have shape (batch_size,)" + assert seqused_k is None or seqused_k.shape == ( batch_size, - ], "seqused_k must have shape (batch_size,)" - assert q.dtype in [paddle.float16, paddle.bfloat16], "inputs must be float16 or bfloat16" + ), "seqused_k must have shape (batch_size,)" + assert q.dtype in [torch.float16, torch.bfloat16], "inputs must be float16 or bfloat16" assert q.dtype == k.dtype == v.dtype, "inputs must have the same dtype" for t in [cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k]: if t is not None: - assert t.dtype == paddle.int32, ( + assert t.dtype == torch.int32, ( "cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k must be int32" ) - assert t.strides[0] == 1, ( + # PyTorch uses .stride() method + assert t.stride(0) == 1, ( "cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k must be contiguous" ) if learnable_sink is not None: - assert learnable_sink.shape == [ + assert learnable_sink.shape == ( num_head, - ] - assert learnable_sink.dtype == paddle.bfloat16, "learnable_sink must be bfloat16" + ) + assert learnable_sink.dtype == torch.bfloat16, "learnable_sink must be bfloat16" assert all( - t is None or t.place.is_gpu_place() + t is None or t.is_cuda for t in ( q, k, @@ -241,37 +249,44 @@ def _flash_attn_fwd( if pack_gqa is None: pack_gqa = qhead_per_kvhead > 1 - out_paddle_dtype = q.dtype - place = q.place + out_dtype = q.dtype + device = q.device q_batch_seqlen_shape = ( - [batch_size, seqlen_q] + (batch_size, seqlen_q) if cu_seqlens_q is None - else [ + else ( total_q, - ] + ) ) - lse_shape = [batch_size, num_head, seqlen_q] if cu_seqlens_q is None else [num_head, total_q] - requires_grad = not (q.stop_gradient and k.stop_gradient and v.stop_gradient) - + lse_shape = (batch_size, num_head, seqlen_q) if cu_seqlens_q is None else (num_head, total_q) + # Paddle's stop_gradient is the inverse of PyTorch's requires_grad + # Original: requires_grad = not (q.stop_gradient and k.stop_gradient and v.stop_gradient) + # Meaning: requires_grad if AT LEAST ONE input needs gradient. + requires_grad = q.requires_grad or k.requires_grad or v.requires_grad + + # Assuming 'q' is available to determine the correct device + if out is None: - out = paddle.zeros( - shape=[*q_batch_seqlen_shape, num_head, head_dim_v], dtype=out_paddle_dtype + out = torch.zeros( + *q_batch_seqlen_shape, num_head, head_dim_v, + dtype=out_dtype, + device=device ) else: - expected_out_shape = [*q_batch_seqlen_shape, num_head, head_dim_v] + expected_out_shape = (*q_batch_seqlen_shape, num_head, head_dim_v) assert out.shape == expected_out_shape, ( f"out tensor shape {out.shape} does not match expected shape {expected_out_shape}" ) - assert out.dtype == out_paddle_dtype, ( - f"out tensor dtype {out.dtype} does not match expected dtype {out_paddle_dtype}" + assert out.dtype == out_dtype, ( + f"out tensor dtype {out.dtype} does not match expected dtype {out_dtype}" ) - assert out.place.is_gpu_place(), ( - f"out tensor device {out.place} does not match input device" + assert out.is_cuda, ( + f"out tensor device {out.device} does not match input device" ) if lse is None: lse = ( - paddle.full(shape=lse_shape, fill_value=float('-inf'), dtype=paddle.float32) + torch.full(lse_shape, float('-inf'), dtype=torch.float32, device=device) if requires_grad or return_lse else None ) @@ -279,12 +294,12 @@ def _flash_attn_fwd( assert lse.shape == lse_shape, ( f"lse tensor shape {lse.shape} does not match expected shape {lse_shape}" ) - assert lse.dtype == paddle.float32, ( - f"lse tensor dtype {lse.dtype} does not match expected dtype paddle.float32" + assert lse.dtype == torch.float32, ( + f"lse tensor dtype {lse.dtype} does not match expected dtype torch.float32" ) - assert lse.place.is_gpu_place(), "lse tensor must be on CUDA device" + assert lse.is_cuda, "lse tensor must be on CUDA device" - dtype = paddle2cute_dtype_map[q.dtype] + dtype = torch2cute_dtype_map[q.dtype] ( cu_seqlens_q_tensor, cu_seqlens_k_tensor, @@ -303,7 +318,7 @@ def _flash_attn_fwd( else None ) compute_capability = ( - paddle.device.cuda.get_device_capability()[0] + torch.cuda.get_device_capability(q.device)[0] if _compute_capability is None else _compute_capability ) @@ -344,7 +359,7 @@ def _flash_attn_fwd( else: causal, local = False, False - current_stream = cuda.CUstream(paddle.device.current_stream().stream_base.cuda_stream) + current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) if compute_capability == 9: # TODO: tune block size according to hdim. if head_dim == head_dim_v == 128 and not causal and not local and not use_block_sparsity: @@ -383,18 +398,23 @@ def _flash_attn_fwd( total_mblocks = batch_size * num_head_kv * num_m_blocks num_splits = num_splits_heuristic( total_mblocks, - paddle.device.cuda.get_device_properties(place.gpu_device_id()).multi_processor_count, + torch.cuda.get_device_properties(q.device).multi_processor_count, num_n_blocks, 128, ) is_split_kv = num_splits > 1 if is_split_kv: - out_partial = paddle.empty( - shape=[num_splits, *q_batch_seqlen_shape, num_head, head_dim_v], dtype=paddle.float32 + out_partial = torch.empty( + num_splits, *q_batch_seqlen_shape, num_head, head_dim_v, + dtype=torch.float32, + device=q.device + ) + lse_partial = torch.empty( + num_splits, *lse_shape, + dtype=torch.float32, + device=q.device ) - lse_partial = paddle.empty(shape=[num_splits, *lse_shape], dtype=paddle.float32) - q_tensor, k_tensor, v_tensor, o_tensor = [ from_dlpack(t.detach(), assumed_align=16).mark_layout_dynamic(leading_dim=t.ndim - 1) for t in (q, k, v, out if not is_split_kv else out_partial) @@ -598,13 +618,13 @@ def _flash_attn_fwd( def _flash_attn_bwd( - q: paddle.Tensor, - k: paddle.Tensor, - v: paddle.Tensor, - out: paddle.Tensor, - dout: paddle.Tensor, - lse: paddle.Tensor, - flashmask_info: Optional[Union[FlashMaskInfoPaddle, paddle.Tensor]] = None, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + out: torch.Tensor, + dout: torch.Tensor, + lse: torch.Tensor, + flashmask_info: Optional[Union[FlashMaskInfoTorch, torch.Tensor]] = None, softmax_scale: Optional[float] = None, causal: bool = False, softcap: float = 0.0, @@ -621,24 +641,25 @@ def _flash_attn_bwd( AtomLayoutNdKV: int = 2, AtomLayoutMdQ: int = 2, V_in_regs: bool = False, - cu_seqlens_q: Optional[paddle.Tensor] = None, - cu_seqlens_k: Optional[paddle.Tensor] = None, - seqused_q: Optional[paddle.Tensor] = None, - seqused_k: Optional[paddle.Tensor] = None, + cu_seqlens_q: Optional[torch.Tensor] = None, + cu_seqlens_k: Optional[torch.Tensor] = None, + seqused_q: Optional[torch.Tensor] = None, + seqused_k: Optional[torch.Tensor] = None, deterministic: bool = False, -) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]: - compute_capability = paddle.device.cuda.get_device_capability()[0] +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + compute_capability = torch.cuda.get_device_capability(q.device)[0] assert compute_capability in [10], "Unsupported compute capability. Supported: 10.x" cute_flashmask_info = None num_flashmask_tensors = 0 - if flashmask_info is not None and isinstance(flashmask_info, paddle.Tensor): - flashmask_info = FlashMaskInfoPaddle( + + if flashmask_info is not None and isinstance(flashmask_info, torch.Tensor): + flashmask_info = FlashMaskInfoTorch( startend_row_indices=flashmask_info, is_causal=causal, ) if flashmask_info is not None: - assert isinstance(flashmask_info, FlashMaskInfoPaddle) + assert isinstance(flashmask_info, FlashMaskInfoTorch) prepare_block_maxmin(flashmask_info) cute_flashmask_info = to_cute_flashmask_info(flashmask_info) num_flashmask_tensors = 2 * flashmask_info.startend_row_indices.shape[-1] @@ -690,40 +711,41 @@ def _flash_attn_bwd( head_dim_v = v.shape[-1] if cu_seqlens_k is None: - assert k.shape == [batch_size, seqlen_k, num_head_kv, head_dim] - assert v.shape == [batch_size, seqlen_k, num_head_kv, head_dim_v] + assert k.shape == (batch_size, seqlen_k, num_head_kv, head_dim) + assert v.shape == (batch_size, seqlen_k, num_head_kv, head_dim_v) else: - assert k.shape == [total_k, num_head_kv, head_dim] - assert v.shape == [total_k, num_head_kv, head_dim_v] - assert cu_seqlens_k.shape == [ + assert k.shape == (total_k, num_head_kv, head_dim) + assert v.shape == (total_k, num_head_kv, head_dim_v) + assert cu_seqlens_k.shape == ( batch_size + 1, - ], "cu_seqlens_k must have shape (batch_size + 1,)" + ), "cu_seqlens_k must have shape (batch_size + 1,)" if cu_seqlens_q is not None: - assert cu_seqlens_q.shape == [ + assert cu_seqlens_q.shape == ( batch_size + 1, - ], "cu_seqlens_q must have shape (batch_size + 1,)" + ), "cu_seqlens_q must have shape (batch_size + 1,)" - assert out.shape == [total_q, num_head, head_dim_v] - assert dout.shape == [total_q, num_head, head_dim_v] - assert lse.shape == [num_head, total_q], "lse must have shape (num_head, total_q)" + assert out.shape == (total_q, num_head, head_dim_v) + assert dout.shape == (total_q, num_head, head_dim_v) + assert lse.shape == (num_head, total_q), "lse must have shape (num_head, total_q)" else: - assert out.shape == [batch_size, seqlen_q, num_head, head_dim_v] - assert dout.shape == [batch_size, seqlen_q, num_head, head_dim_v] - assert lse.shape == [batch_size, num_head, seqlen_q], ( + assert out.shape == (batch_size, seqlen_q, num_head, head_dim_v) + assert dout.shape == (batch_size, seqlen_q, num_head, head_dim_v) + assert lse.shape == (batch_size, num_head, seqlen_q), ( "lse must have shape (batch_size, num_head, seqlen_q)" ) - assert q.dtype in [paddle.float16, paddle.bfloat16], "inputs must be float16 or bfloat16" + assert q.dtype in [torch.float16, torch.bfloat16], "inputs must be float16 or bfloat16" assert q.dtype == k.dtype == v.dtype == out.dtype == dout.dtype, ( "inputs must have the same dtype" ) for t in [cu_seqlens_q, cu_seqlens_k]: if t is not None: - assert t.dtype == paddle.int32, "cu_seqlens_q, cu_seqlens_k must be int32" - assert lse.dtype == paddle.float32, "lse must be float32" + assert t.dtype == torch.int32, "cu_seqlens_q, cu_seqlens_k must be int32" + assert lse.dtype == torch.float32, "lse must be float32" + assert all( - t is None or t.place.is_gpu_place() + t is None or t.is_cuda for t in (q, k, v, out, dout, lse, cu_seqlens_q, cu_seqlens_k) ), "inputs must be on CUDA device" assert num_head % num_head_kv == 0, "num_head must be divisible by num_head_kv" @@ -741,32 +763,50 @@ def _flash_attn_bwd( if compute_capability != 10: assert deterministic is False, "bwd deterministic only supported for sm100 for now" - place = q.place + device = q.device # TODO: check if this is the right rounding - dq = paddle.zeros_like(q) - dk = paddle.zeros_like(k) - dv = paddle.zeros_like(v) + dq = torch.zeros_like(q) + dk = torch.zeros_like(k) + dv = torch.zeros_like(v) head_dim_rounded = (head_dim + 32 - 1) // 32 * 32 if cu_seqlens_q is None: seqlen_q_rounded = (seqlen_q + m_block_size - 1) // m_block_size * m_block_size - dq_accum = paddle.empty( - shape=[batch_size, num_head, seqlen_q_rounded * head_dim_rounded], dtype=paddle.float32 + dq_accum = torch.empty( + batch_size, num_head, seqlen_q_rounded * head_dim_rounded, + dtype=torch.float32, + device=device + ) + dpsum = torch.empty( + batch_size, num_head, seqlen_q_rounded, + dtype=torch.float32, + device=device ) - dpsum = paddle.empty(shape=[batch_size, num_head, seqlen_q_rounded], dtype=paddle.float32) - lse_log2 = paddle.empty( - shape=[batch_size, num_head, seqlen_q_rounded], dtype=paddle.float32 + lse_log2 = torch.empty( + batch_size, num_head, seqlen_q_rounded, + dtype=torch.float32, + device=device ) else: total_q_rounded_padded = ( (total_q + cu_seqlens_q.shape[0] * m_block_size - 1) // m_block_size * m_block_size ) - dq_accum = paddle.empty( - shape=[num_head, total_q_rounded_padded * head_dim_rounded], dtype=paddle.float32 + dq_accum = torch.empty( + num_head, total_q_rounded_padded * head_dim_rounded, + dtype=torch.float32, + device=device + ) + dpsum = torch.empty( + num_head, total_q_rounded_padded, + dtype=torch.float32, + device=device + ) + lse_log2 = torch.empty( + num_head, total_q_rounded_padded, + dtype=torch.float32, + device=device ) - dpsum = paddle.empty(shape=[num_head, total_q_rounded_padded], dtype=paddle.float32) - lse_log2 = paddle.empty(shape=[num_head, total_q_rounded_padded], dtype=paddle.float32) if qhead_per_kvhead > 1: head_dim_v_rounded = (head_dim_v + 32 - 1) // 32 * 32 @@ -775,13 +815,15 @@ def _flash_attn_bwd( num_n_blocks = seqlen_k_rounded // n_block_size if cluster_size == 2 and num_n_blocks % cluster_size != 0: seqlen_k_rounded = seqlen_k_rounded + n_block_size - dk_accum = paddle.zeros( - shape=[batch_size, num_head_kv, seqlen_k_rounded * head_dim_rounded], - dtype=paddle.float32, + dk_accum = torch.zeros( + batch_size, num_head_kv, seqlen_k_rounded * head_dim_rounded, + dtype=torch.float32, + device=device ) - dv_accum = paddle.zeros( - shape=[batch_size, num_head_kv, seqlen_k_rounded * head_dim_v_rounded], - dtype=paddle.float32, + dv_accum = torch.zeros( + batch_size, num_head_kv, seqlen_k_rounded * head_dim_v_rounded, + dtype=torch.float32, + device=device ) else: total_k_rounded_padded = ( @@ -790,56 +832,66 @@ def _flash_attn_bwd( num_n_blocks = total_k_rounded_padded // n_block_size if cluster_size == 2 and num_n_blocks % cluster_size != 0: total_k_rounded_padded = total_k_rounded_padded + n_block_size - dk_accum = paddle.zeros( - shape=[num_head_kv, total_k_rounded_padded * head_dim_rounded], dtype=paddle.float32 + dk_accum = torch.zeros( + num_head_kv, total_k_rounded_padded * head_dim_rounded, + dtype=torch.float32, + device=device ) - dv_accum = paddle.zeros( - shape=[num_head_kv, total_k_rounded_padded * head_dim_v_rounded], - dtype=paddle.float32, + dv_accum = torch.zeros( + num_head_kv, total_k_rounded_padded * head_dim_v_rounded, + dtype=torch.float32, + device=device ) - dtype = paddle2cute_dtype_map[q.dtype] + dtype = torch2cute_dtype_map[q.dtype] + def _wrap(t, align): + return from_dlpack(t.detach(), assumed_align=align).mark_layout_dynamic(leading_dim=t.ndim - 1) + q_tensor, k_tensor, v_tensor, o_tensor, do_tensor, dq_tensor, dk_tensor, dv_tensor = [ - from_dlpack(t.detach(), assumed_align=16).mark_layout_dynamic(leading_dim=t.ndim - 1) + _wrap(t, 16) for t in (q, k, v, out, dout, dq, dk, dv) ] - lse_tensor = from_dlpack(lse.detach(), assumed_align=4).mark_layout_dynamic( - leading_dim=lse.ndim - 1 - ) + lse_tensor = _wrap(lse, 4) dq_accum_tensor, dpsum_tensor, lse_log2_tensor = [ - from_dlpack(t.detach(), assumed_align=16).mark_layout_dynamic(leading_dim=t.ndim - 1) + _wrap(t, 16) for t in (dq_accum, dpsum, lse_log2) ] if qhead_per_kvhead > 1: dk_accum_tensor, dv_accum_tensor = [ - from_dlpack(t.detach(), assumed_align=16).mark_layout_dynamic(leading_dim=t.ndim - 1) + _wrap(t, 16) for t in (dk_accum, dv_accum) ] cu_seqlens_q_tensor, cu_seqlens_k_tensor, seqused_q_tensor, seqused_k_tensor = [ - from_dlpack(t.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=t.ndim - 1) - if t is not None - else None + _wrap(t, 4) if t is not None else None for t in (cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k) ] if deterministic: - dQ_semaphore = paddle.zeros( - shape=[batch_size, num_head, seqlen_q_rounded // m_block_size, 1], dtype=paddle.int32 + # Paddle: paddle.zeros(shape=[...], dtype=paddle.int32) + # PyTorch: torch.zeros(..., dtype=torch.int32, device=device) + dQ_semaphore = torch.zeros( + batch_size, num_head, seqlen_q_rounded // m_block_size, 1, + dtype=torch.int32, + device=device # 必须指定设备! ) else: dQ_semaphore = None if deterministic and qhead_per_kvhead > 1: - dK_semaphore = paddle.zeros( - shape=[batch_size, num_head_kv, seqlen_k_rounded // n_block_size, 2], dtype=paddle.int32 + dK_semaphore = torch.zeros( + batch_size, num_head_kv, seqlen_k_rounded // n_block_size, 2, + dtype=torch.int32, + device=device ) - dV_semaphore = paddle.zeros( - shape=[batch_size, num_head_kv, seqlen_k_rounded // n_block_size, 2], dtype=paddle.int32 + dV_semaphore = torch.zeros( + batch_size, num_head_kv, seqlen_k_rounded // n_block_size, 2, + dtype=torch.int32, + device=device ) else: dK_semaphore = None dV_semaphore = None - # Note(wusiming): paddle doesn’t expose the physics layout, so assert that the tensor is contiguous here + # Note: PyTorch 也有 .is_contiguous(),保留这些检查 if dQ_semaphore is not None: assert dQ_semaphore.is_contiguous() if dK_semaphore is not None: @@ -848,13 +900,16 @@ def _flash_attn_bwd( assert dV_semaphore.is_contiguous() dQ_semaphore_tensor, dK_semaphore_tensor, dV_semaphore_tensor = [ utils.convert_from_dlpack_leading_static( - t.detach(), leading_dim=3, alignment=4, stride_order=tuple(range(t.ndim)) + t.detach(), + leading_dim=3, + alignment=4, + stride_order=tuple(range(t.ndim)) ) if t is not None else None for t in (dQ_semaphore, dK_semaphore, dV_semaphore) ] - current_stream = cuda.CUstream(paddle.device.current_stream().stream_base.cuda_stream) + current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) # Preprocess kernel: compute (o * dout).sum(dim=-1), lse * log2_e, and zero out dq_accum. compile_key_pre = (compute_capability, dtype, head_dim_v, m_block_size, num_threads) @@ -1123,33 +1178,33 @@ def _flash_attn_bwd( _flash_attn_bwd.compile_cache_post = {} -class FlashAttnFunc(paddle.autograd.PyLayer): +class FlashAttnFunc(torch.autograd.Function): @staticmethod def forward( ctx, - q: paddle.Tensor, - k: paddle.Tensor, - v: paddle.Tensor, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, softmax_scale: Optional[float] = None, causal: bool = False, window_size: Tuple[Optional[int], Optional[int]] = (None, None), - learnable_sink: Optional[paddle.Tensor] = None, + learnable_sink: Optional[torch.Tensor] = None, softcap: float = 0.0, num_splits: int = 1, pack_gqa: Optional[bool] = None, deterministic: bool = False, mask_mod: Optional[Callable] = None, - full_block_cnt: Optional[paddle.Tensor] = None, - full_block_idx: Optional[paddle.Tensor] = None, - mask_block_cnt: Optional[paddle.Tensor] = None, - mask_block_idx: Optional[paddle.Tensor] = None, + full_block_cnt: Optional[torch.Tensor] = None, + full_block_idx: Optional[torch.Tensor] = None, + mask_block_cnt: Optional[torch.Tensor] = None, + mask_block_idx: Optional[torch.Tensor] = None, ): # Only create block sparse tensors if at least one block sparse parameter is provided block_sparse_tensors = None if any( t is not None for t in [full_block_cnt, full_block_idx, mask_block_cnt, mask_block_idx] ): - block_sparse_tensors = BlockSparseTensorsPaddle( + block_sparse_tensors = BlockSparseTensorsTorch( full_block_cnt=full_block_cnt, full_block_idx=full_block_idx, mask_block_cnt=mask_block_cnt, @@ -1180,7 +1235,7 @@ def forward( @staticmethod def backward(ctx, dout, *args): - q, k, v, out, lse = ctx.saved_tensor() + q, k, v, out, lse = ctx.saved_tensors dq, dk, dv = _flash_attn_bwd( q, k, @@ -1188,31 +1243,37 @@ def backward(ctx, dout, *args): out, dout, lse, - ctx.softmax_scale, - ctx.causal, - ctx.softcap, + # [无需修改] 这些都是在 forward 里手动存到 ctx 上的属性 + softmax_scale=ctx.softmax_scale, + causal=ctx.causal, + softcap=ctx.softcap, deterministic=ctx.deterministic, ) - # TODO(wusiming): do we need to return None for other fwd inputs? - return dq, dk, dv + # 必须补齐 None! + # forward 共有 16 个参数: + # q, k, v, softmax_scale, causal, window_size, learnable_sink, softcap, + # num_splits, pack_gqa, deterministic, mask_mod, + # full_block_cnt, full_block_idx, mask_block_cnt, mask_block_idx + # 因此这里要返回 3 个 Tensor 梯度 + 13 个 None + return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None -class FlashAttnVarlenFunc(paddle.autograd.PyLayer): +class FlashAttnVarlenFunc(torch.autograd.Function): @staticmethod def forward( ctx, - q: paddle.Tensor, - k: paddle.Tensor, - v: paddle.Tensor, - cu_seqlens_q: Optional[paddle.Tensor], - cu_seqlens_k: Optional[paddle.Tensor], - seqused_q: Optional[paddle.Tensor] = None, - seqused_k: Optional[paddle.Tensor] = None, - page_table: Optional[paddle.Tensor] = None, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + cu_seqlens_q: Optional[torch.Tensor], + cu_seqlens_k: Optional[torch.Tensor], + seqused_q: Optional[torch.Tensor] = None, + seqused_k: Optional[torch.Tensor] = None, + page_table: Optional[torch.Tensor] = None, softmax_scale: Optional[float] = None, causal: bool = False, window_size: Tuple[Optional[int], Optional[int]] = (None, None), - learnable_sink: Optional[paddle.Tensor] = None, + learnable_sink: Optional[torch.Tensor] = None, softcap: float = 0.0, num_splits: int = 1, pack_gqa: Optional[bool] = None, @@ -1246,7 +1307,8 @@ def forward( @staticmethod def backward(ctx, dout, *args): - q, k, v, out, lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k = ctx.saved_tensor() + q, k, v, out, lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k = ctx.saved_tensors + # 逻辑检查保留 assert seqused_q is None assert seqused_k is None assert ctx.softcap == 0.0 @@ -1257,9 +1319,9 @@ def backward(ctx, dout, *args): out, dout, lse, - ctx.softmax_scale, - ctx.causal, - ctx.softcap, + softmax_scale=ctx.softmax_scale, + causal=ctx.causal, + softcap=ctx.softcap, cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k, seqused_q=seqused_q, @@ -1267,27 +1329,32 @@ def backward(ctx, dout, *args): deterministic=ctx.deterministic, ) - # TODO(wusiming): do we need to return None for other fwd inputs? - return dq, dk, dv + # [修改点 2] 必须补齐 None! + # forward 共有 16 个参数: + # 1-3: q, k, v (有梯度) + # 4-16: cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k, page_table, + # softmax_scale, causal, window_size, learnable_sink, softcap, + # num_splits, pack_gqa, deterministic (无梯度) + return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None def flash_attn_func( - q: paddle.Tensor, - k: paddle.Tensor, - v: paddle.Tensor, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, softmax_scale: Optional[float] = None, causal: bool = False, window_size: Tuple[Optional[int], Optional[int]] = (None, None), - learnable_sink: Optional[paddle.Tensor] = None, + learnable_sink: Optional[torch.Tensor] = None, softcap: float = 0.0, num_splits: int = 1, pack_gqa: Optional[bool] = None, deterministic: bool = False, mask_mod: Optional[Callable] = None, - full_block_cnt: Optional[paddle.Tensor] = None, - full_block_idx: Optional[paddle.Tensor] = None, - mask_block_cnt: Optional[paddle.Tensor] = None, - mask_block_idx: Optional[paddle.Tensor] = None, + full_block_cnt: Optional[torch.Tensor] = None, + full_block_idx: Optional[torch.Tensor] = None, + mask_block_cnt: Optional[torch.Tensor] = None, + mask_block_idx: Optional[torch.Tensor] = None, ): return FlashAttnFunc.apply( q, @@ -1310,18 +1377,18 @@ def flash_attn_func( def flash_attn_varlen_func( - q: paddle.Tensor, - k: paddle.Tensor, - v: paddle.Tensor, - cu_seqlens_q: Optional[paddle.Tensor] = None, - cu_seqlens_k: Optional[paddle.Tensor] = None, - seqused_q: Optional[paddle.Tensor] = None, - seqused_k: Optional[paddle.Tensor] = None, - page_table: Optional[paddle.Tensor] = None, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + cu_seqlens_q: Optional[torch.Tensor] = None, + cu_seqlens_k: Optional[torch.Tensor] = None, + seqused_q: Optional[torch.Tensor] = None, + seqused_k: Optional[torch.Tensor] = None, + page_table: Optional[torch.Tensor] = None, softmax_scale: Optional[float] = None, causal: bool = False, window_size: Tuple[Optional[int], Optional[int]] = (None, None), - learnable_sink: Optional[paddle.Tensor] = None, + learnable_sink: Optional[torch.Tensor] = None, softcap: float = 0.0, num_splits: int = 1, pack_gqa: Optional[bool] = None, @@ -1348,14 +1415,14 @@ def flash_attn_varlen_func( def _flash_attn_fwd_combine( - out_partial: paddle.Tensor, - lse_partial: paddle.Tensor, - out: paddle.Tensor, - lse: Optional[paddle.Tensor] = None, - cu_seqlens: Optional[paddle.Tensor] = None, - seqused: Optional[paddle.Tensor] = None, - num_splits_dynamic_ptr: Optional[paddle.Tensor] = None, - semaphore_to_reset: Optional[paddle.Tensor] = None, + out_partial: torch.Tensor, + lse_partial: torch.Tensor, + out: torch.Tensor, + lse: Optional[torch.Tensor] = None, + cu_seqlens: Optional[torch.Tensor] = None, + seqused: Optional[torch.Tensor] = None, + num_splits_dynamic_ptr: Optional[torch.Tensor] = None, + semaphore_to_reset: Optional[torch.Tensor] = None, ) -> None: """Forward combine kernel for split attention computation. @@ -1373,23 +1440,21 @@ def _flash_attn_fwd_combine( seqused: Used sequence lengths for each batch num_splits_dynamic_ptr: Dynamic number of splits per batch semaphore_to_reset: Semaphore for synchronization - k_block_size: Block size for head dimension - Returns: None """ # Input validation assert out_partial.ndim in [4, 5], "out_partial must have 4 or 5 dimensions" assert lse_partial.ndim in [3, 4], "lse_partial must have 3 or 4 dimensions" - assert out_partial.dtype in [paddle.float16, paddle.bfloat16, paddle.float32], ( + assert out_partial.dtype in [torch.float16, torch.bfloat16, torch.float32], ( "out_partial must be fp16, bf16, or fp32" ) - assert lse_partial.dtype == paddle.float32, "lse_partial must be fp32" - assert out_partial.place.is_gpu_place() and lse_partial.place.is_gpu_place(), ( + assert lse_partial.dtype == torch.float32, "lse_partial must be fp32" + assert out_partial.is_cuda and lse_partial.is_cuda, ( "tensors must be on CUDA device" ) - assert out_partial.strides[-1] == 1, "out_partial must be contiguous in the last dimension" - assert lse_partial.strides[-2] == 1, "lse_partial must be contiguous in the seqlen dimension" + assert out_partial.stride(-1) == 1, "out_partial must be contiguous in the last dimension" + assert lse_partial.stride(-2) == 1, "lse_partial must be contiguous in the seqlen dimension" assert lse_partial.shape == out_partial.shape[:-1] # Determine if this is variable length based on dimensions @@ -1399,7 +1464,7 @@ def _flash_attn_fwd_combine( assert out.shape == out_partial.shape[1:], "out shape mismatch" if lse is not None: assert lse.shape == lse_partial.shape[1:], "lse shape mismatch" - assert lse.dtype == paddle.float32, "lse must be fp32" + assert lse.dtype == torch.float32, "lse must be fp32" # Validate optional tensors for t, name in [ @@ -1408,8 +1473,8 @@ def _flash_attn_fwd_combine( (num_splits_dynamic_ptr, "num_splits_dynamic_ptr"), ]: if t is not None: - assert t.dtype == paddle.int32, f"{name} must be int32" - assert t.place.is_gpu_place(), f"{name} must be on CUDA device" + assert t.dtype == torch.int32, f"{name} must be int32" + assert t.is_cuda, f"{name} must be on CUDA device" assert t.is_contiguous(), f"{name} must be contiguous" head_dim = out_partial.shape[-1] @@ -1453,11 +1518,11 @@ def _flash_attn_fwd_combine( optional_tensors ) - current_stream = cuda.CUstream(paddle.device.current_stream().stream_base.cuda_stream) + current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) # Create combine kernel configuration - dtype = paddle2cute_dtype_map[out.dtype] - dtype_partial = paddle2cute_dtype_map[out_partial.dtype] + dtype = torch2cute_dtype_map[out.dtype] + dtype_partial = torch2cute_dtype_map[out_partial.dtype] compile_key = ( dtype, @@ -1525,14 +1590,14 @@ def _flash_attn_fwd_combine( def flash_attn_combine( - out_partial: paddle.Tensor, - lse_partial: paddle.Tensor, - out: Optional[paddle.Tensor] = None, - out_dtype: Optional[paddle.dtype] = None, - cu_seqlens: Optional[paddle.Tensor] = None, - seqused: Optional[paddle.Tensor] = None, + out_partial: torch.Tensor, + lse_partial: torch.Tensor, + out: Optional[torch.Tensor] = None, + out_dtype: Optional[torch.dtype] = None, + cu_seqlens: Optional[torch.Tensor] = None, + seqused: Optional[torch.Tensor] = None, return_lse: bool = True, -) -> Tuple[paddle.Tensor, Optional[paddle.Tensor]]: +) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: """Flash Attention combine function for split attention computation. Combines partial outputs and log-sum-exp values from multiple splits @@ -1567,8 +1632,8 @@ def flash_attn_combine( # Input validation assert out_partial.ndim in [4, 5], "out_partial must have 4 or 5 dimensions" assert lse_partial.ndim in [3, 4], "lse_partial must have 3 or 4 dimensions" - assert out_partial.dtype == paddle.float32, "out_partial must be fp32 (from accumulation)" - assert lse_partial.dtype == paddle.float32, "lse_partial must be fp32" + assert out_partial.dtype == torch.float32, "out_partial must be fp32 (from accumulation)" + assert lse_partial.dtype == torch.float32, "lse_partial must be fp32" # Determine if this is variable length based on dimensions is_varlen = out_partial.ndim == 4 @@ -1576,7 +1641,7 @@ def flash_attn_combine( if is_varlen: # Variable length: (num_splits, total_q, num_heads, head_size) num_splits, total_q, num_heads, head_size = out_partial.shape - assert lse_partial.shape == [num_splits, total_q, num_heads], ( + assert lse_partial.shape == (num_splits, total_q, num_heads), ( "lse_partial shape mismatch for varlen" ) batch_size = 1 # Treat as single batch for varlen @@ -1584,7 +1649,7 @@ def flash_attn_combine( else: # Regular batched: (num_splits, batch_size, seqlen, num_heads, head_size) num_splits, batch_size, seqlen, num_heads, head_size = out_partial.shape - assert lse_partial.shape == [num_splits, batch_size, seqlen, num_heads], ( + assert lse_partial.shape == (num_splits, batch_size, seqlen, num_heads), ( "lse_partial shape mismatch" ) @@ -1593,20 +1658,28 @@ def flash_attn_combine( out_dtype = out_partial.dtype # Create output if not provided - place = out_partial.place + device = out_partial.device if out is None: if is_varlen: - out = paddle.zeros(shape=[total_q, num_heads, head_size], dtype=out_dtype) + out = torch.zeros(total_q, num_heads, head_size, dtype=out_dtype, device=device) else: - out = paddle.zeros(shape=[batch_size, seqlen, num_heads, head_size], dtype=out_dtype) + out = torch.zeros(batch_size, seqlen, num_heads, head_size, dtype=out_dtype, device=device) # Create lse output only if requested if return_lse: if is_varlen: - lse = paddle.full(shape=[num_heads, total_q], fill_value=float('-inf'), dtype=paddle.float32).transpose(0, 1) + lse = torch.full( + (num_heads, total_q), + fill_value=float('-inf'), + dtype=torch.float32, + device=device + ).transpose(0, 1) else: - lse = paddle.full( - shape=[batch_size, num_heads, seqlen], fill_value=float('-inf'), dtype=paddle.float32 + lse = torch.full( + (batch_size, num_heads, seqlen), + fill_value=float('-inf'), + dtype=torch.float32, + device=device ).transpose(1, 2) else: lse = None @@ -1621,18 +1694,18 @@ def flash_attn_combine( ) return out, lse -class FlashMaskFunc(paddle.autograd.PyLayer): +class FlashMaskFunc(torch.autograd.Function): @staticmethod def forward( ctx, - query: paddle.Tensor, - key: paddle.Tensor, - value: paddle.Tensor, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, causal: bool = False, softmax_scale: float | None = None, - startend_row_indices: paddle.Tensor | None = None, - block_mask: paddle.Tensor | None = None, - ) -> paddle.Tensor | Tuple[paddle.Tensor, paddle.Tensor]: + startend_row_indices: torch.Tensor | None = None, + block_mask: torch.Tensor | None = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: out, lse = _flash_attn_fwd( query, key, @@ -1646,13 +1719,15 @@ def forward( ctx.save_for_backward(query, key, value, startend_row_indices, out, lse) ctx.softmax_scale = softmax_scale ctx.causal = causal - return [out, lse] + + return out, lse @staticmethod - def backward(ctx, dout, *args) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]: - query, key, value, startend_row_indices, out, lse = ctx.saved_tensor() + def backward(ctx, dout, *args) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, None, None, None, None]: + query, key, value, startend_row_indices, out, lse = ctx.saved_tensors if startend_row_indices is not None: - flashmask_info = FlashMaskInfoPaddle( + # [修改点] 类名替换 + flashmask_info = FlashMaskInfoTorch( startend_row_indices=startend_row_indices, is_causal=ctx.causal, ) @@ -1665,35 +1740,41 @@ def backward(ctx, dout, *args) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Ten out, dout, lse, - flashmask_info, + flashmask_info=flashmask_info, causal=ctx.causal, - deterministic=paddle.get_flags(["FLAGS_cudnn_deterministic"])["FLAGS_cudnn_deterministic"], - ) - return dq, dk, dv + deterministic=torch.backends.cudnn.deterministic, + softmax_scale=ctx.softmax_scale, + ) + + # Forward inputs: + # 1. query (dq) + # 2. key (dk) + # 3. value (dv) + # 4. causal (None) + # 5. softmax_scale (None) + # 6. startend_row_indices (None) + # 7. block_mask (None) + return dq, dk, dv, None, None, None, None -# TODO(wusiming): should we align the parameters with those of paddle.nn.functional.flashmask_attention? def flashmask_attention( - query: paddle.Tensor, - key: paddle.Tensor, - value: paddle.Tensor, - startend_row_indices: paddle.Tensor | None = None, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + startend_row_indices: torch.Tensor | None = None, *, dropout: float = 0.0, causal: bool = False, window_size: int | tuple | None = None, return_softmax_lse: bool = False, return_seed_offset: bool = False, - fixed_seed_offset: paddle.Tensor | None = None, + fixed_seed_offset: torch.Tensor | None = None, rng_name: str = "", training: bool = True, name: str | None = None, softmax_scale: float | None = None, - block_mask: paddle.Tensor | None = None, -): - if ( - paddle.base.framework.get_flags(["FLAGS_flash_attn_version"])["FLAGS_flash_attn_version"] == 4 - and (query.shape[-1] == 64 or query.shape[-1] == 128) - ): + block_mask: torch.Tensor | None = None, +): + if query.shape[-1] in [64, 128]: # and FLASH_ATTN_VERSION == 4: assert dropout == 0.0, ( "flashmask v4 does not support dropout" ) @@ -1719,13 +1800,13 @@ def flashmask_attention( assert block_mask is None, ( "flashmask v4 does not support block mask" ) - assert paddle.base.framework.get_flags(["FLAGS_flash_attn_version"])["FLAGS_flash_attn_version"] == 4, ( - f"FLAGS_flash_attn_version:{paddle.base.framework.get_flags(['FLAGS_flash_attn_version'])['FLAGS_flash_attn_version']}, but running flashmask v4" - ) + # assert paddle.base.framework.get_flags(["FLAGS_flash_attn_version"])["FLAGS_flash_attn_version"] == 4, ( + # f"FLAGS_flash_attn_version:{paddle.base.framework.get_flags(['FLAGS_flash_attn_version'])['FLAGS_flash_attn_version']}, but running flashmask v4" + # ) if startend_row_indices is not None: - assert startend_row_indices.dtype == paddle.int32, ( - f"startend_row_indices.dtype must be paddle.int32, but got {startend_row_indices.dtype}" + assert startend_row_indices.dtype == torch.int32, ( + f"startend_row_indices.dtype must be torch.int32, but got {startend_row_indices.dtype}" ) assert len(startend_row_indices.shape) == 4, ( f"startend_row_indices rank must be 4,but got {startend_row_indices.shape}" @@ -1768,27 +1849,33 @@ def flashmask_attention( query, key, value, - causal=causal, - softmax_scale=softmax_scale, - startend_row_indices=startend_row_indices, + causal, + softmax_scale, + startend_row_indices, + None, # block_mask, 对应 forward 的最后一个参数 ) if return_softmax_lse: - return [out, lse] + return out, lse # PyTorch 习惯返回 Tuple 而不是 List else: return out else: - original_flash_attn_version = paddle.base.framework.get_flags(["FLAGS_flash_attn_version"])["FLAGS_flash_attn_version"] - if original_flash_attn_version == 4: - paddle.set_flags({"FLAGS_flash_attn_version": 2}) - assert ( - not causal or (query.shape[1] == key.shape[1]) - ), ( - f"Fallback to flashmask v1 is not supported when using causal mask " - f"and query/key sequence lengths differ (seqlen_q={query.shape[1]}, seqlen_k={key.shape[1]}). " - "Please ensure seqlen_q equals seqlen_k or disable causal." + + assert ( + not causal or (query.shape[1] == key.shape[1]) + ), ( + f"Fallback to flashmask v1 is not supported when using causal mask " + f"and query/key sequence lengths differ (seqlen_q={query.shape[1]}, seqlen_k={key.shape[1]}). " + "Please ensure seqlen_q equals seqlen_k or disable causal." + ) + + try: + from flashmask import flashmask_attention as flashmask_attention_impl + except ImportError: + raise ImportError( + "请安装 'flashmask3.0.0b1' 包。" ) try: - outputs = paddle.nn.functional.flashmask_attention( + outputs = flashmask_attention_impl( query=query, key=key, value=value, @@ -1805,86 +1892,81 @@ def flashmask_attention( softmax_scale=softmax_scale, block_mask=block_mask, ) - finally: - if original_flash_attn_version == 4: - paddle.set_flags({"FLAGS_flash_attn_version": 4}) + except ImportError: + raise ImportError("Could not import 'flashmask' package. Please ensure it is installed.") + return outputs -# Note(wusiming): do we need to align api to tridao? def flash_attention( - query: paddle.Tensor, - key: paddle.Tensor, - value: paddle.Tensor, - dropout=0.0, - causal=False, - return_softmax=False, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + dropout: float = 0.0, + causal: bool = False, + return_softmax: bool = False, *, - fixed_seed_offset=None, - rng_name="", - training=True, - name=None, - softmax_scale=None, + fixed_seed_offset: Optional[torch.Tensor] = None, + rng_name: str = "", + training: bool = True, + name: Optional[str] = None, + softmax_scale: Optional[float] = None, ): - if ( - paddle.base.framework.get_flags(["FLAGS_flash_attn_version"])["FLAGS_flash_attn_version"] == 4 - and (query.shape[-1] == 64 or query.shape[-1] == 128) - ): + if query.shape[-1] in [64, 128]: assert dropout == 0.0, ( - "flash attention 4 does not support dropout" + "flash attention v4 port does not support dropout" ) # Note(wusiming): return_softmax means return attn score, not lse assert not return_softmax, ( - "flash attention 4 does not support return_softmax" + "flash attention v4 port does not support return_softmax" ) - assert fixed_seed_offset is None , ( - "flash attention 4 does not support setting seed_offset" + assert fixed_seed_offset is None, ( + "flash attention v4 port does not support setting seed_offset" ) assert rng_name == "", ( - "flash attention 4 does not support setting rng_name" + "flash attention v4 port does not support setting rng_name" ) assert training, ( - "flash attention 4 does not support setting training to False" + "flash attention v4 port does not support setting training to False" ) assert name is None, ( - "flash attention 4 does not support setting name" + "flash attention v4 port does not support setting name" ) - # Note(wusiming): i dont think it is necessary to add a pylayer for flash_attention, just reuse flashmask + out, lse = FlashMaskFunc.apply( query, key, value, - causal=causal, - softmax_scale=softmax_scale, - startend_row_indices=None, + causal, + softmax_scale, + None, # startend_row_indices + None, # block_mask ) return out, None else: - original_flash_attn_version = paddle.base.framework.get_flags(["FLAGS_flash_attn_version"])["FLAGS_flash_attn_version"] - if original_flash_attn_version == 4: - paddle.set_flags({"FLAGS_flash_attn_version": 2}) - assert ( - not causal or (query.shape[1] == key.shape[1]) - ), ( + if causal and query.shape[1] != key.shape[1]: + raise ValueError( f"Fallback to flash attention version 2 is not supported when using causal mask " f"and query/key sequence lengths differ (seqlen_q={query.shape[1]}, seqlen_k={key.shape[1]}). " "Please ensure seqlen_q equals seqlen_k or disable causal." ) try: - out, lse = paddle.nn.functional.flash_attention.flash_attention( - query=query, - key=key, - value=value, - dropout=dropout, - causal=causal, - return_softmax=return_softmax, - fixed_seed_offset=fixed_seed_offset, - rng_name=rng_name, - training=training, - name=name, - softmax_scale=softmax_scale, + from flash_attn import flash_attn_func + except ImportError: + raise ImportError( + "Fallback path requires 'flash_attn' library. " + "Please install it via `pip install flash-attn --no-build-isolation`." ) - finally: - if original_flash_attn_version == 4: - paddle.set_flags({"FLAGS_flash_attn_version": 4}) - return out, lse + result = flash_attn_func( + query, + key, + value, + dropout_p=dropout, + softmax_scale=softmax_scale, + causal=causal, + return_attn_probs=return_softmax, + ) + if return_softmax: + return result[0], result[1] + else: + return result, None