From 63e61877e1bab9c23484010ee0551d2fb9b86695 Mon Sep 17 00:00:00 2001 From: Siyu Lou Date: Fri, 13 Mar 2026 12:01:34 +0800 Subject: [PATCH] add rrattn estimate func and interface --- flashmask/flash_mask/utils/__init__.py | 22 + .../flash_mask/utils/block_mask_utils.py | 502 ++++++++ flashmask/flash_mask/utils/index_utils.py | 77 ++ .../utils/rr_attn_estimate_triton_op.py | 1122 +++++++++++++++++ .../flash_mask/utils/rr_attn_interface.py | 171 +++ 5 files changed, 1894 insertions(+) create mode 100644 flashmask/flash_mask/utils/__init__.py create mode 100644 flashmask/flash_mask/utils/block_mask_utils.py create mode 100644 flashmask/flash_mask/utils/index_utils.py create mode 100644 flashmask/flash_mask/utils/rr_attn_estimate_triton_op.py create mode 100644 flashmask/flash_mask/utils/rr_attn_interface.py diff --git a/flashmask/flash_mask/utils/__init__.py b/flashmask/flash_mask/utils/__init__.py new file mode 100644 index 00000000000..a85946e62d8 --- /dev/null +++ b/flashmask/flash_mask/utils/__init__.py @@ -0,0 +1,22 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Utility modules for Python-only FlashMask sparse attention wrappers.""" + +__all__ = [ + "block_mask_utils", + "index_utils", + "rr_attn_estimate_triton_op", + "rr_attn_interface", +] diff --git a/flashmask/flash_mask/utils/block_mask_utils.py b/flashmask/flash_mask/utils/block_mask_utils.py new file mode 100644 index 00000000000..581e9db6f16 --- /dev/null +++ b/flashmask/flash_mask/utils/block_mask_utils.py @@ -0,0 +1,502 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+import math
+
+import paddle
+import triton
+import triton.language as tl
+
+
+# Load the four per-column bound vectors (lt_start, lt_end, ut_start, ut_end)
+# at `base_offset + k_offsets`.
+#
+# Which vectors are actually read depends on the compile-time `causal`/`mode`:
+#   * lt_start is always loaded;
+#   * lt_end is loaded when (causal and mode == 2) or (not causal and mode == 4);
+#   * ut_start is loaded only when (not causal and mode == 4);
+#   * ut_end is loaded when (not causal) and mode is 2 or 4.
+# A vector that is not loaded is filled with a sentinel instead: INT_MAX for the
+# "lt" pair and INT_MIN for the "ut" pair. The same sentinels are used as the
+# `other` value for lanes disabled by `load_mask`.
+#
+# Returns (b_lts, b_lte, b_uts, b_ute).
+@triton.jit
+def _load_bounds(
+    base_offset,
+    k_offsets,
+    load_mask,
+    ptr_start_lt,
+    ptr_end_lt,
+    ptr_start_ut,
+    ptr_end_ut,
+    causal: tl.constexpr,
+    mode: tl.constexpr,
+):
+    INT_MAX: tl.constexpr = 2147483647
+    INT_MIN: tl.constexpr = -2147483648
+
+    pad_lt = INT_MAX
+    pad_ut = INT_MIN
+
+    b_lts = tl.load(
+        ptr_start_lt + base_offset + k_offsets, mask=load_mask, other=pad_lt
+    )
+
+    need_lte: tl.constexpr = (causal and mode == 2) or (not causal and mode == 4)
+
+    if need_lte:
+        b_lte = tl.load(
+            ptr_end_lt + base_offset + k_offsets, mask=load_mask, other=pad_lt
+        )
+    else:
+        # No end bound available: INT_MAX leaves the "lt" interval open-ended.
+        b_lte = tl.full(b_lts.shape, pad_lt, dtype=tl.int32)
+
+    if causal:
+        b_uts = tl.full(b_lts.shape, pad_ut, dtype=tl.int32)
+    else:
+        if mode == 4:
+            b_uts = tl.load(
+                ptr_start_ut + base_offset + k_offsets,
+                mask=load_mask,
+                other=pad_ut,
+            )
+        else:
+            b_uts = tl.full(b_lts.shape, pad_ut, dtype=tl.int32)
+
+    need_ute: tl.constexpr = (not causal) and (mode == 2 or mode == 4)
+
+    if need_ute:
+        b_ute = tl.load(
+            ptr_end_ut + base_offset + k_offsets, mask=load_mask, other=pad_ut
+        )
+    else:
+        # INT_MIN as the end bound makes the "ut" test `row < ute` never true.
+        b_ute = tl.full(b_lts.shape, pad_ut, dtype=tl.int32)
+
+    return b_lts, b_lte, b_uts, b_ute
+
+
+# Return a [len(block_rows), len(bounds)] boolean matrix that is True where the
+# row index lies in [lts_max, lte_min) or in [uts_max, ute_min).
+@triton.jit
+def _is_block_fully_masked(
+    block_rows,
+    lts_max,
+    lte_min,
+    uts_max,
+    ute_min,
+):
+    # since we pass exact row indices now, use "<" for end
+    in_lt = (block_rows[:, None] >= lts_max[None, :]) & (
+        block_rows[:, None] < lte_min[None, :]
+    )
+    in_ut = (block_rows[:, None] >= uts_max[None, :]) & (
+        block_rows[:, None] < ute_min[None, :]
+    )
+
+    mask = in_lt | in_ut
+    return mask
+
+
+# Load the bound vectors through `_load_bounds` and evaluate
+# `_is_block_fully_masked` for `q_rows`. Columns disabled by `k_load_mask` are
+# additionally reported as masked for every row.
+# NOTE(review): the `ptrs_strict_*` arguments are presumably the per-group
+# max-of-start / min-of-end arrays (that is how the caller in
+# rr_attn_estimate_triton_op.py passes them) -- confirm against the producer.
+@triton.jit
+def check_fully_masked_state(
+    mask_ptr_base_offset,
+    k_offsets,
+    k_load_mask,
+    q_rows,
+    ptrs_strict_lt_start,
+    ptrs_strict_lt_end,
+    ptrs_strict_ut_start,
+    ptrs_strict_ut_end,
+    causal: tl.constexpr,
+    mode: tl.constexpr,
+):
+    fm_lts, fm_lte, fm_uts, fm_ute = _load_bounds(
+        mask_ptr_base_offset,
+        k_offsets,
+        k_load_mask,
+        ptrs_strict_lt_start,
+        ptrs_strict_lt_end,
+        ptrs_strict_ut_start,
+        ptrs_strict_ut_end,
+        causal=causal,
+        mode=mode,
+    )
+
+    fm_geo = _is_block_fully_masked(
+        q_rows,
+        fm_lts,
+        fm_lte,
+        fm_uts,
+        fm_ute,
+    )
+    # Out-of-bounds columns count as masked regardless of the bounds.
+    fm_oob = ~k_load_mask[None, :]
+
+    return fm_geo | fm_oob
+
+
+# Return a [len(block_rows), len(bounds)] boolean matrix that is True where the
+# row index lies in [lts_min, lte_max) or in [uts_min, ute_max).
+@triton.jit
+def _is_block_partially_masked(
+    block_rows,
+    lts_min,
+    lte_max,
+    uts_min,
+    ute_max,
+):
+    # Logic: Overlap exists if Q is potentially inside [min_start, max_end)
+    overlap_lt = (block_rows[:, None] < lte_max[None, :]) & (
+        block_rows[:, None] >= lts_min[None, :]
+    )
+    overlap_ut = (block_rows[:, None] < ute_max[None, :]) & (
+        block_rows[:, None] >= uts_min[None, :]
+    )
+
+    return overlap_lt | overlap_ut
+
+
+# Load the bound vectors through `_load_bounds` and evaluate
+# `_is_block_partially_masked` for `q_rows`. Unlike `check_fully_masked_state`,
+# columns disabled by `k_load_mask` get no special handling here beyond the
+# sentinel values substituted by `_load_bounds`.
+@triton.jit
+def check_partially_masked_state(
+    mask_ptr_base_offset,
+    k_offsets,
+    k_load_mask,
+    q_rows,
+    ptrs_perm_lt_start,
+    ptrs_perm_lt_end,
+    ptrs_perm_ut_start,
+    ptrs_perm_ut_end,
+    causal: tl.constexpr,
+    mode: tl.constexpr,
+):
+    pm_lts, pm_lte, pm_uts, pm_ute = _load_bounds(
+        mask_ptr_base_offset,
+        k_offsets,
+        k_load_mask,
+        ptrs_perm_lt_start,
+        ptrs_perm_lt_end,
+        ptrs_perm_ut_start,
+        ptrs_perm_ut_end,
+        causal=causal,
+        mode=mode,
+    )
+
+    return _is_block_partially_masked(q_rows, pm_lts, pm_lte, pm_uts, pm_ute)
+
+
+# One compare-and-swap round of the bitonic network, applied to the values `x`
+# and, in lock-step, to the companion index tensor `ids`.
+#
+# `x` is viewed as [outer, 2, inner] so that the two elements of every compared
+# pair sit along axis 1. Where `(left > right) != flip` the pair is exchanged;
+# the exchange is done with XOR on the bit patterns (values) and on the indices.
+# Returns the updated (x, ids).
+@triton.jit
+def _compare_and_swap(
+    x,
+    ids,
+    flip,
+    i: tl.constexpr,
+    n_dims: tl.constexpr,
+):
+    n_outer: tl.constexpr = x.numel >> n_dims
+    shape: tl.constexpr = [n_outer * 2**i, 2, 2 ** (n_dims - i - 1)]
+    y = tl.reshape(x, shape)
+
+    # slice left/right with 'stride' 2**(n_dims - i - 1)
+    mask = tl.arange(0, 2)[None, :, None]
+    # Values are selected with tl.where (not a multiply), so a -inf element is
+    # not turned into NaN by a 0 * inf product.
+    left = tl.broadcast_to(tl.sum(tl.where(mask == 0, y, 0), 1)[:, None, :], shape).to(
+        y.dtype
+    )
+    right = tl.broadcast_to(tl.sum(tl.where(mask == 1, y, 0), 1)[:, None, :], shape).to(
+        y.dtype
+    )
+    left = tl.reshape(left, x.shape)
+    right = tl.reshape(right, x.shape)
+
+    # idx
+    y_idx = tl.reshape(ids, shape)
+    left_idx = tl.broadcast_to(tl.sum(y_idx * (1 - mask), 1)[:, None, :], shape)
+    right_idx = tl.broadcast_to(tl.sum(y_idx * mask, 1)[:, None, :], shape)
+    left_idx = tl.reshape(left_idx, x.shape).to(y_idx.dtype)
+    right_idx = tl.reshape(right_idx, x.shape).to(y_idx.dtype)
+
+    # actual compare-and-swap
+    idtype = tl.core.get_int_dtype(bitwidth=x.dtype.primitive_bitwidth, signed=True)
+    ileft = left.to(idtype, bitcast=True)
+    iright = right.to(idtype, bitcast=True)
+    ix = x.to(idtype, bitcast=True)
+
+    cond = (left > right) != flip
+    # a ^ (a ^ b) == b: XOR-ing with (left ^ right) swaps the element with its
+    # partner exactly where `cond` holds and leaves it unchanged elsewhere.
+    ret = ix ^ tl.where(cond, ileft ^ iright, tl.zeros_like(ix))
+    new_ids = ids ^ tl.where(cond, left_idx ^ right_idx, tl.zeros_like(ids))
+    return ret.to(x.dtype, bitcast=True), new_ids
+
+
+# Run `stage` compare-and-swap rounds on (x, ids).
+# `order` selects the direction pattern: 2 builds an alternating flip pattern,
+# any other value is used directly as the flip flag for every pair.
+@triton.jit
+def _bitonic_merge(
+    x,
+    ids,
+    stage: tl.constexpr,
+    order: tl.constexpr,
+    n_dims: tl.constexpr,
+):
+    n_outer: tl.constexpr = x.numel >> n_dims
+    tl.static_assert(stage <= n_dims)
+
+    # flip denotes whether to re-arrange sub-sequences of elements in ascending or
+    # descending order.
+    # if flip = 00000000... then all elements will be re-arranged ascendingly at this stage
+    # if flip = 00110011... then all the elements will be re-arranged alternatingly (with
+    # a stride of 2) at this stage
+
+    if order == 2:
+        shape: tl.constexpr = [n_outer * 2 ** (n_dims - 1 - stage), 2, 2**stage]
+        flip = tl.reshape(
+            tl.broadcast_to(tl.arange(0, 2)[None, :, None], shape), x.shape
+        )
+    else:
+        flip = order
+
+    # perform `stage` rounds of `compare-and-swap`
+    for i in tl.static_range(stage):
+        x, ids = _compare_and_swap(x, ids, flip, i + (n_dims - stage), n_dims)
+    return x, ids
+
+
+# Sort `x` and permute `ids` identically, using `n_dims` merge stages.
+# Intermediate stages use the alternating order (2); only the last stage uses
+# `descending` (0 or 1) as the direction. The number of sorted elements is
+# presumably 2**n_dims (that is how find_blocks_topp derives NUM_DIMS).
+# Returns (sorted_x, permuted_ids).
+@triton.jit
+def bitonic_argsort_device(
+    x, ids, n_dims: tl.constexpr, descending: tl.constexpr = tl.core.CONSTEXPR_0
+):
+    for i in tl.static_range(1, n_dims + 1):
+        x, ids = _bitonic_merge(x, ids, i, 2 if i < n_dims else descending, n_dims)
+
+    return x, ids
+
+
+# Per-row top-p selection; one program handles one row of N_COLS values.
+#
+# The row is sorted in descending order and an element is kept while the sum of
+# the elements strictly before it (exclusive prefix sum) is still below
+# `row_sum * threshold_p`. The keep flags are written to `Out_ptr` at the
+# elements' original column positions as int8. A row whose sum is exactly 0 gets
+# an all-zero output.
+# `stride_row` is used for both the input and the output row offset, so both
+# buffers must share the same row stride.
+@triton.jit
+def top_p_kernel(
+    X_ptr,
+    Out_ptr,
+    stride_row,
+    threshold_p,
+    N_COLS,
+    BLOCK_SIZE: tl.constexpr,
+    NUM_DIMS: tl.constexpr,
+):
+    pid = tl.program_id(0)
+    row_start_ptr = X_ptr + pid * stride_row
+
+    offsets = tl.arange(0, BLOCK_SIZE)
+    mask_load = offsets < N_COLS
+
+    # Load with 0.0 padding to calculate correct sum
+    x_raw = tl.load(row_start_ptr + offsets, mask=mask_load, other=0.0).to(tl.float32)
+    row_sum = tl.sum(x_raw, axis=0)
+
+    out_row_ptr = Out_ptr + pid * stride_row
+    if row_sum == 0.0:
+        tl.store(
+            out_row_ptr + offsets,
+            tl.zeros([BLOCK_SIZE], dtype=tl.int8),
+            mask=mask_load,
+        )
+        return
+
+    # Calculate the actual threshold value based on the sum
+    actual_cutoff = row_sum * threshold_p
+
+    # Padding lanes become -inf so that the descending sort pushes them last.
+    padding_val = float("-inf")
+    x_for_sort = tl.where(mask_load, x_raw, padding_val)
+    ids = tl.arange(0, BLOCK_SIZE)  # Initialize indices [0, 1, ... BLOCK_SIZE-1]
+
+    # Perform Bitonic Sort (Descending)
+    # x_sorted: values from high to low
+    # ids_sorted: original indices corresponding to those values
+    x_sorted, ids_sorted = bitonic_argsort_device(
+        x_for_sort, ids, NUM_DIMS, descending=1
+    )
+
+    cum_probs = tl.cumsum(x_sorted, axis=0)
+    # (inclusive cumsum - own value) is the exclusive prefix sum.
+    mask_keep = (cum_probs - x_sorted) < actual_cutoff
+
+    # Force padding elements to be False (just in case)
+    is_not_padding = x_sorted > padding_val
+    mask_keep = mask_keep & is_not_padding
+
+    # Scatter Write (Restore original order)
+    mask_store = ids_sorted < N_COLS
+    tl.store(out_row_ptr + ids_sorted, mask_keep.to(tl.int8), mask=mask_store)
+
+
+def find_blocks_topp(x: paddle.Tensor, p: float):
+    """
+    Select, for every row along the last axis, the largest entries whose
+    cumulative share of the row sum reaches the threshold ``p``.
+
+    The tensor is flattened to 2-D, one ``top_p_kernel`` program is launched per
+    row, and the result is reshaped back to the input shape.
+
+    Input:
+        x: [b, h, m, n] float tensor (probabilities, unnormalized)
+        p: float, threshold
+    Output:
+        mask: [b, h, m, n] bool tensor
+    """
+    original_shape = x.shape
+    n = original_shape[-1]
+
+    x_reshaped = x.reshape(-1, n).contiguous()
+    B = x_reshaped.shape[0]  # Total number of rows
+
+    # The bitonic sort operates on a power-of-two number of lanes.
+    block_size = triton.next_power_of_2(n)
+    if block_size < 1:
+        block_size = 1
+    num_dims = int(math.log2(block_size))
+
+    output_mask = paddle.empty(x_reshaped.shape, dtype=paddle.bool, device=x.device)
+
+    grid = (B,)
+
+    top_p_kernel[grid](
+        x_reshaped,
+        output_mask,
+        x_reshaped.strides[0],
+        p,
+        n,
+        BLOCK_SIZE=block_size,
+        NUM_DIMS=num_dims,
+    )
+
+    return output_mask.reshape(original_shape)
+
+
+def find_blocks_chunked(
+    input_tensor,
+    current_index,
+    threshold,
+    num_to_choose,
+    decoding: bool,
+    mode: str = "both",
+    causal=True,
+):
+    """
+    Finds and selects relevant blocks of attention for transformer-based models based on a
+    threshold or a predefined number of blocks.
+
+    Parameters:
+    - input_tensor (paddle.Tensor): The input tensor of shape (batch_size, head_num, chunk_num, block_num).
+    - current_index (int): The current index in the sequence processing.
+    - threshold (float or None): A threshold value used to determine the minimum attention weight sum.
+      A paddle.Tensor is also accepted; it is broadcast against
+      (batch_size, head_num, chunk_num, 1).
+    - num_to_choose (int or None): The number of blocks to be selected, ensuring sufficient information retrieval.
+      Selection by count is not implemented: when ``threshold`` is None the
+      function raises NotImplementedError.
+    - decoding (bool): If True, operates in decoding mode; otherwise, it's in encoding mode.
+    - mode (str): Defines the processing mode, either 'both', 'prefill', or 'decode'.
+    - causal (bool): If True, applies causal masking to prevent future information leakage.
+
+    Returns:
+    - paddle.Tensor: A boolean mask of shape (batch_size, head_num, chunk_num, block_num),
+      indicating which blocks should be attended to.
+
+    Raises:
+    - AssertionError: If both ``threshold`` and ``num_to_choose`` are given, or if
+      the final sanity checks on the causal mask fail.
+    - NotImplementedError: If ``threshold`` is None.
+    """
+    assert threshold is None or num_to_choose is None
+    batch_size, head_num, chunk_num, block_num = input_tensor.shape
+    # 0 -- -- -- -- current_index
+    # 0 -- -- -- -- -- current_index+1
+    # 0 -- -- -- -- -- ----------- current_index + chunk_num - 1
+    if mode == "prefill" and decoding:
+        return paddle.ones_like(input_tensor, dtype=paddle.bool)
+    if mode == "decode" and not decoding:
+        mask = paddle.ones_like(input_tensor, dtype=paddle.bool)
+        if causal:
+            mask[:, :, :, current_index : current_index + chunk_num] = paddle.tril(
+                paddle.ones(1, head_num, chunk_num, chunk_num)
+            )
+            mask[:, :, current_index + chunk_num :, :] = 0
+            # NOTE(review): `mask` built above is not used by the return below,
+            # and the two pieces are sliced along axis 2 but concatenated along
+            # the last axis -- confirm this branch does what is intended.
+            return paddle.cat(
+                [
+                    paddle.ones_like(input_tensor, dtype=paddle.bool)[
+                        :, :, 0 : current_index + 1
+                    ],
+                    paddle.zeros_like(input_tensor, dtype=paddle.bool)[
+                        :, :, current_index + 1 :
+                    ],
+                ],
+                dim=-1,
+            )
+        else:
+            return mask
+    input_tensor = input_tensor.astype("float32")
+
+    if threshold is not None:
+        # Each row must accumulate at least `threshold` times its own total.
+        total_sum = input_tensor.sum(dim=-1, keepdim=True)
+        if isinstance(threshold, paddle.Tensor):
+            threshold = threshold.astype("float32")
+            required_sum = total_sum * threshold.unsqueeze(0).unsqueeze(-1).unsqueeze(
+                -1
+            ).expand((batch_size, head_num, chunk_num, 1))
+        else:
+            required_sum = total_sum * threshold
+        if causal:
+            # Always-selected entries: column 0 and the diagonal of the
+            # [current_index, current_index + chunk_num) window.
+            mask = paddle.zeros_like(input_tensor, dtype=paddle.bool)
+            mask[:, :, :, 0] = 1
+            mask[:, :, :, current_index : current_index + chunk_num] = (
+                paddle.eye(chunk_num)
+                .unsqueeze(0)
+                .unsqueeze(0)
+                .expand(1, head_num, chunk_num, chunk_num)
+            )
+            other_values = input_tensor.masked_fill(mask, 0.0)
+            sorted_values, _ = paddle.compat.sort(other_values, dim=-1, descending=True)
+
+            # Value sequence used for the running sum: a zero, then the combined
+            # weight of the always-selected entries, then the remaining values in
+            # descending order (the last two are dropped to keep block_num columns).
+            sorted_values = paddle.cat(
+                [
+                    input_tensor.new_zeros((batch_size, head_num, chunk_num, 1)),
+                    paddle.where(mask, input_tensor, 0.0).sum(dim=-1, keepdim=True),
+                    sorted_values[:, :, :, :-2],
+                ],
+                dim=-1,
+            )
+
+            # Boost the always-selected entries so they sort to the front of `index`.
+            _, index = paddle.compat.sort(
+                paddle.where(mask, 100000 * (1 + input_tensor), input_tensor),
+                dim=-1,
+                descending=True,
+            )
+            # Shift right by one and cumsum: the sum accumulated *before* each position.
+            cumulative_sum_without_self = paddle.cat(
+                [
+                    sorted_values.new_zeros((batch_size, head_num, chunk_num, 1)),
+                    sorted_values[:, :, :, 0:-1],
+                ],
+                dim=-1,
+            ).cumsum(dim=-1)
+
+            index_mask = cumulative_sum_without_self < required_sum
+            # Positions that are not needed are redirected to column 0, which is
+            # already selected in this branch.
+            index = paddle.where(index_mask, index, 0)
+            mask = mask.view(batch_size, head_num * chunk_num, block_num)
+            index = index.view(batch_size, head_num * chunk_num, block_num)
+            mask[:, paddle.arange(mask.shape[1]).unsqueeze(dim=-1), index] = True
+            mask = mask.view(batch_size, head_num, chunk_num, block_num)
+            # assert(bool((paddle.where(mask,input_tensor,0).sum(dim=-1,keepdim=True) >= required_sum*0.99).all()))
+        else:
+            mask = paddle.zeros_like(input_tensor, dtype=paddle.bool)
+            sorted_values, index = paddle.compat.sort(
+                input_tensor, dim=-1, descending=True
+            )
+            cumulative_sum_without_self = paddle.cat(
+                [
+                    sorted_values.new_zeros((batch_size, head_num, chunk_num, 1)),
+                    sorted_values[:, :, :, 0:-1],
+                ],
+                dim=-1,
+            ).cumsum(dim=-1)
+            index_mask = cumulative_sum_without_self < required_sum
+            # NOTE(review): positions that are not needed are redirected to
+            # column 0, so column 0 is set to True whenever any position is
+            # unneeded -- confirm this is intended for the non-causal path.
+            index = paddle.where(index_mask, index, 0)
+            mask = mask.view(batch_size, head_num * chunk_num, block_num)
+            index = index.view(batch_size, head_num * chunk_num, block_num)
+            mask[
+                :,
+                paddle.arange(mask.shape[1]).unsqueeze(dim=-1),
+                index,
+            ] = True
+            mask = mask.view(batch_size, head_num, chunk_num, block_num)
+    else:
+        raise NotImplementedError("block num chunk prefill not implemented")
+
+    # Clear any selection to the right of the current chunk window.
+    if causal and paddle.any(mask[:, :, :, current_index + chunk_num :]):
+        mask[:, :, :, current_index + chunk_num :] = False
+
+    # Sanity checks: the always-selected entries must be present in the result.
+    if causal:
+        if decoding:
+            assert mask[:, :, :, 0].all() and mask[:, :, :, -1].all()
+        else:
+            lambda_mask = paddle.zeros_like(input_tensor, dtype=bool)
+            lambda_mask[:, :, :, 0] = 1
+            lambda_mask[:, :, :, current_index : current_index + chunk_num] = (
+                paddle.eye(chunk_num)
+                .unsqueeze(0)
+                .unsqueeze(0)
+                .expand(1, head_num, chunk_num, chunk_num)
+            )
+            assert paddle.where(lambda_mask, mask, True).all()
+
+    return mask
diff --git a/flashmask/flash_mask/utils/index_utils.py b/flashmask/flash_mask/utils/index_utils.py
new file mode 100644
index 00000000000..6dbc4b639e8
--- /dev/null
+++ b/flashmask/flash_mask/utils/index_utils.py
@@ -0,0 +1,77 @@
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import paddle
+import triton
+import triton.language as tl
+
+
+# Chunked max/min reduction along the last axis.
+#
+# Grid: (number of BN-wide tiles, batch * heads). Each program loads BN
+# consecutive positions of one row, views them as [BN // chunk_size, chunk_size]
+# and reduces every chunk to its maximum and its minimum. Positions at or beyond
+# `seqlen` are replaced by INT_MIN (for the max) and INT_MAX (for the min) so
+# they never win the reduction. Results are stored at chunk granularity and
+# guarded by `num_chunks`.
+#
+# Requirements (enforced by `prepare_maxmin`): `chunk_size` divides `BN`, and the
+# input is a dense row-major buffer with row length `seqlen`.
+@triton.jit
+def scan_maxmin_chunked(
+    input_ptr,
+    output_max_ptr,
+    output_min_ptr,
+    seqlen,
+    num_chunks,
+    chunk_size: tl.constexpr,
+    BN: tl.constexpr,
+):
+    INT_MAX: tl.constexpr = 2147483647
+    INT_MIN: tl.constexpr = -2147483648
+
+    i_tile = tl.program_id(0)
+    i_bh = tl.program_id(1)
+
+    p_tile = i_tile * BN + tl.arange(0, BN)
+    mask_tile = p_tile < seqlen
+    b_tile = tl.load(input_ptr + i_bh * seqlen + p_tile, mask=mask_tile)
+
+    b_omax = tl.where(mask_tile, b_tile, INT_MIN).reshape(
+        (BN // chunk_size, chunk_size)
+    )
+    b_omax = tl.max(b_omax, axis=1)
+
+    b_omin = tl.where(mask_tile, b_tile, INT_MAX).reshape(
+        (BN // chunk_size, chunk_size)
+    )
+    b_omin = tl.min(b_omin, axis=1)
+
+    offs_out = tl.arange(0, BN // chunk_size) + i_tile * (BN // chunk_size)
+    mask_out = offs_out < num_chunks
+    tl.store(output_max_ptr + i_bh * num_chunks + offs_out, b_omax, mask=mask_out)
+    tl.store(output_min_ptr + i_bh * num_chunks + offs_out, b_omin, mask=mask_out)
+
+
+def prepare_maxmin(
+    input: paddle.Tensor, chunk_size: int
+) -> tuple[paddle.Tensor, paddle.Tensor]:
+    """Compute per-chunk maxima and minima along the last axis of ``input``.
+
+    Args:
+        input: 3-D tensor unpacked as ``[bsz, num_heads, seq_len]``. The kernel
+            addresses it as a dense row-major buffer, so it is made contiguous
+            first. Values are presumably int32 (the outputs are int32 and the
+            padding sentinels are the int32 limits) -- TODO confirm with callers.
+        chunk_size: Number of consecutive positions reduced into one output
+            element. Must be a positive power of two.
+
+    Returns:
+        ``(output_max, output_min)``, each an int32 tensor of shape
+        ``[bsz, num_heads, ceil(seq_len / chunk_size)]``.
+
+    Raises:
+        ValueError: If ``chunk_size`` is not a positive power of two.
+    """
+    # The kernel reshapes each tile to (BN // chunk_size, chunk_size); Triton
+    # block dimensions must be powers of two, so reject anything else up front
+    # instead of failing inside kernel compilation.
+    if chunk_size <= 0 or (chunk_size & (chunk_size - 1)) != 0:
+        raise ValueError(
+            f"chunk_size must be a positive power of two, got {chunk_size}"
+        )
+
+    # The kernel computes addresses as `i_bh * seq_len + position`.
+    input = input.contiguous()
+    bsz, num_heads, seq_len = input.shape
+    num_chunks = (seq_len + chunk_size - 1) // chunk_size
+
+    output_max = paddle.empty([bsz, num_heads, num_chunks], dtype=paddle.int32)
+    output_min = paddle.empty([bsz, num_heads, num_chunks], dtype=paddle.int32)
+
+    # A tile must hold a whole number of chunks; grow it for chunk_size > 512.
+    BN = max(512, chunk_size)
+    grid = ((seq_len + BN - 1) // BN, bsz * num_heads)
+    scan_maxmin_chunked[grid](
+        input,
+        output_max,
+        output_min,
+        seq_len,
+        num_chunks,
+        chunk_size=chunk_size,
+        BN=BN,
+    )
+
+    return output_max, output_min
diff --git a/flashmask/flash_mask/utils/rr_attn_estimate_triton_op.py b/flashmask/flash_mask/utils/rr_attn_estimate_triton_op.py
new file mode 100644
index 00000000000..4ecf7155693
--- /dev/null
+++ b/flashmask/flash_mask/utils/rr_attn_estimate_triton_op.py
@@ -0,0 +1,1122 @@
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from dataclasses import dataclass + +import paddle + +# paddle.compat.enable_torch_proxy() +import triton +import triton.language as tl + +from .block_mask_utils import ( + check_fully_masked_state, + check_partially_masked_state, + find_blocks_chunked, +) +from .index_utils import ( + prepare_maxmin, +) + +LOG2E = 1.4426950408889634 # 1 / ln(2) + + +@triton.jit +def flashmask_apply( + X, + q_rows, + base_offset, + k_offsets, + load_mask, + lt_start_ptr, + lt_end_ptr, + ut_start_ptr, + ut_end_ptr, + causal: tl.constexpr, + mode: tl.constexpr, +): + INT_MAX: tl.constexpr = 2147483647 + INT_MIN: tl.constexpr = -2147483648 + + pad_lt = INT_MAX + pad_ut = INT_MIN + + lts = tl.load(lt_start_ptr + base_offset + k_offsets, mask=load_mask, other=pad_lt) + if mode == 1: + dense_mask = q_rows[:, None] >= lts[None, :] + elif mode == 4: + lte = tl.load( + lt_end_ptr + base_offset + k_offsets, mask=load_mask, other=pad_lt + ) + uts = tl.load( + ut_start_ptr + base_offset + k_offsets, mask=load_mask, other=pad_ut + ) + ute = tl.load( + ut_end_ptr + base_offset + k_offsets, mask=load_mask, other=pad_ut + ) + dense_mask = ( + (q_rows[:, None] >= lts[None, :]) & (q_rows[:, None] < lte[None, :]) + ) | ((q_rows[:, None] >= uts[None, :]) & (q_rows[:, None] < ute[None, :])) + else: + if causal: + lte = tl.load( + lt_end_ptr + base_offset + k_offsets, + mask=load_mask, + other=pad_lt, + ) + dense_mask = (q_rows[:, None] >= 
lts[None, :]) & ( + q_rows[:, None] < lte[None, :] + ) + else: + ute = tl.load( + ut_end_ptr + base_offset + k_offsets, + mask=load_mask, + other=pad_ut, + ) + dense_mask = (q_rows[:, None] >= lts[None, :]) | ( + q_rows[:, None] < ute[None, :] + ) + + X = (1.0 - dense_mask) * X # set 0 for sum reduce + return X, dense_mask + + +@triton.jit +def check_dense_contains_partial_stride( + dense_flashmask, + q_token_mask, + k_token_mask, + BLOCK_SIZE: tl.constexpr, + STRIDE: tl.constexpr, +): + dense_flashmask = tl.where( + (q_token_mask[:, None] & k_token_mask[None, :]), + dense_flashmask.to(tl.int32), + tl.full([], 0, tl.int32), + ) + mask_stride_cnt = dense_flashmask.reshape( + BLOCK_SIZE // STRIDE, BLOCK_SIZE // STRIDE, STRIDE + ).sum(2) + mask_stride_valid_cnt = ( + k_token_mask.reshape(1, BLOCK_SIZE // STRIDE, STRIDE).to(tl.int32).sum(2) + ) + + mask_stride_is_partial = (mask_stride_cnt > 0) & ( + mask_stride_cnt < mask_stride_valid_cnt + ) + # return mask_stride_is_partial + return tl.sum(mask_stride_is_partial.to(tl.int32)) > 0 + + +@triton.jit +def gemm_fuse_softmax_causal( + q, + k, + out, + out_boundary_mask, + # --- Mask Pointers --- + lt_start_ptr, + lt_end_ptr, + ut_start_ptr, + ut_end_ptr, + lt_start_nstridemax, + lt_start_nstridemin, + lt_end_nstridemax, + lt_end_nstridemin, + ut_start_nstridemax, + ut_start_nstridemin, + ut_end_nstridemax, + ut_end_nstridemin, + # --- Params --- + scale: float, + seqlen_q: int, + seqlen_k: int, + num_q_blocks: int, + num_k_blocks: int, + N_STRIDES, + STRIDE: tl.constexpr, + HQ: tl.constexpr, + H: tl.constexpr, + HIDS: tl.constexpr, + K: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + mode: tl.constexpr, +): + i_block = tl.program_id(0).to(tl.int64) + i_h = tl.program_id(1).to(tl.int64) + i_b = tl.program_id(2).to(tl.int64) + + ratio: tl.constexpr = BLOCK_SIZE // STRIDE + G: tl.constexpr = HQ // H + GIDS: tl.constexpr = HQ // HIDS + + i_hkv = i_h // G + i_hid = i_h // GIDS + + # ================= 1. 
Coordinates Setup ================= + q_stride_base = i_block * ratio + offs_q_stride = q_stride_base + tl.arange(0, ratio) + + mask_ptr_base_bh_stride = i_b * N_STRIDES * HIDS + i_hid * N_STRIDES + mask_ptr_base_bh_tokens = i_b * seqlen_k * HIDS + i_hid * seqlen_k + + # Load Q + p_q = q + i_b * seqlen_q * HQ * K + (i_block * BLOCK_SIZE) * HQ * K + i_h * K + p_q = ( + p_q + + tl.arange(0, ratio)[:, None] * (HQ * K * STRIDE) + + tl.arange(0, K)[None, :] + + HQ * K * (i_h % STRIDE) + ) + offs_tokens_q = ( + tl.arange(0, ratio) * STRIDE + i_block * BLOCK_SIZE + (i_h % STRIDE) + ) # round-robin offset + mask_q = offs_tokens_q < seqlen_q + # mask_q = offs_tokens_q[:, None] < seqlen_q + + b_q = tl.load(p_q, mask=mask_q[:, None], other=0.0) + b_q = (b_q * scale).to(b_q.dtype) + + # Softmax Accumulators + m_i = tl.full([ratio], float("-inf"), dtype=tl.float32) + l_i = tl.zeros([ratio], dtype=tl.float32) + + # Causal / FA3 Setup + shift = seqlen_k - seqlen_q + # xattn v14 applies causal in reshaped (stride) space. + shift_stride = shift // STRIDE + + # k_safe_end: K blocks strictly to the left of the diagonal (Safe to compute fully) + # Condition: k_block_end <= q_block_start + shift + # (k + 1) * BLOCK <= i_block * BLOCK + shift + k_safe_end = (i_block * BLOCK_SIZE + shift) // BLOCK_SIZE + k_safe_end = min(num_k_blocks, max(0, k_safe_end)) + + # k_valid_end: The last K block that intersects with the diagonal or Q block + # Condition: k_block_start <= q_block_end_idx + shift + # k * BLOCK <= ((i_block + 1) * BLOCK - 1) + shift + k_valid_end = ((i_block + 1) * BLOCK_SIZE - 1 + shift) // BLOCK_SIZE + 1 + k_valid_end = min(num_k_blocks, max(k_safe_end, k_valid_end)) + + p_k_base = k + i_b * seqlen_k * H * K + i_hkv * K + offs_k_base = tl.arange(0, K)[:, None] + offs_stride_k = tl.arange(0, ratio) + offs_tokens_k = tl.arange(0, BLOCK_SIZE) + + # ================= 2. 
Loop 1: Statistics ================= + for iter in range(0, k_safe_end): + curr_stride_offset = mask_ptr_base_bh_stride + iter * ratio + curr_load_mask = (iter * ratio + offs_stride_k) < N_STRIDES + + fully_masked_stride_mask = check_fully_masked_state( + curr_stride_offset, + offs_stride_k, + curr_load_mask, + offs_tokens_q, + lt_start_nstridemax, + lt_end_nstridemin, + ut_start_nstridemax, + ut_end_nstridemin, + causal=True, + mode=mode, + ) + + if tl.sum(fully_masked_stride_mask.to(tl.int32)) < ratio * ratio: + # Load K & Compute Dot + p_k = ( + p_k_base + + iter * BLOCK_SIZE * H * K + + tl.arange(0, BLOCK_SIZE)[None, :] * H * K + + offs_k_base + ) + b_k = tl.load(p_k) + # CHANGE: NO REDUCE HERE + # logits = tl.dot(b_q, b_k) # [ratio, BLOCK_SIZE] + + partially_masked_stride_mask = check_partially_masked_state( + curr_stride_offset, + offs_stride_k, + curr_load_mask, + offs_tokens_q, + lt_start_nstridemin, + lt_end_nstridemax, + ut_start_nstridemin, + ut_end_nstridemax, + causal=True, + mode=mode, + ) + + real_partially_masked_stride_mask = ( + ~fully_masked_stride_mask + ) & partially_masked_stride_mask + if tl.sum(real_partially_masked_stride_mask) > 0: + logits = tl.dot(b_q, b_k) # [ratio, BLOCK_SIZE] + curr_token_offset = mask_ptr_base_bh_tokens + iter * BLOCK_SIZE + curr_token_load_mask = (iter * BLOCK_SIZE + offs_tokens_k) < seqlen_k + X, dense_flashmask = flashmask_apply( + logits, + offs_tokens_q, + curr_token_offset, + offs_tokens_k, + curr_token_load_mask, + lt_start_ptr, + lt_end_ptr, + ut_start_ptr, + ut_end_ptr, + causal=True, + mode=mode, + ) + + # Reduce token logits to get stride score + X = X.reshape(ratio, ratio, STRIDE).sum(axis=2) + fully_masked_by_fm = ( + dense_flashmask.reshape(ratio, ratio, STRIDE) + ).min(axis=2) == 1 + X = tl.where(fully_masked_by_fm, -1.0e6, X) + else: + # Reduce token logits to get stride score + X = tl.dot(b_q, b_k.reshape(K, ratio, STRIDE).sum(2)) + + X = tl.where(fully_masked_stride_mask, -1.0e6, X) + + # Update 
Stats + m_local = tl.max(X, 1) + m_new = tl.maximum(m_i, m_local) + alpha = tl.math.exp2(m_i - m_new) + X = X - m_new[:, None] + l_local = tl.sum(tl.math.exp2(X), 1) + l_i = l_i * alpha + l_local + m_i = m_new + + for iter in range(k_safe_end, k_valid_end): + curr_stride_offset = mask_ptr_base_bh_stride + iter * ratio + curr_load_mask = (iter * ratio + offs_stride_k) < N_STRIDES + # k_col_min = iter * BLOCK_SIZE + + fully_masked_stride_mask = check_fully_masked_state( + curr_stride_offset, + offs_stride_k, + curr_load_mask, + offs_tokens_q, + lt_start_nstridemax, + lt_end_nstridemin, + ut_start_nstridemax, + ut_end_nstridemin, + causal=True, + mode=mode, + ) + + if tl.sum(fully_masked_stride_mask.to(tl.int32)) < ratio * ratio: + p_k = ( + p_k_base + + iter * BLOCK_SIZE * H * K + + tl.arange(0, BLOCK_SIZE)[None, :] * H * K + + offs_k_base + ) + mask_k = (tl.arange(0, BLOCK_SIZE)[None, :] + iter * BLOCK_SIZE) < seqlen_k + b_k = tl.load(p_k, mask=mask_k, other=0.0) + # b_k = b_k.reshape(K, ratio, STRIDE) + # b_k = tl.sum(b_k, axis=2) + logits = tl.dot(b_q, b_k) + + curr_token_offset = mask_ptr_base_bh_tokens + iter * BLOCK_SIZE + curr_token_load_mask = (iter * BLOCK_SIZE + offs_tokens_k) < seqlen_k + X, dense_flashmask = flashmask_apply( + logits, + offs_tokens_q, + curr_token_offset, + offs_tokens_k, + curr_token_load_mask, + lt_start_ptr, + lt_end_ptr, + ut_start_ptr, + ut_end_ptr, + causal=True, + mode=mode, + ) + # Reduce token logits to stride space first, then apply + # stride-level causal mask to align with xattn v14 behavior. 
+ X = X.reshape(ratio, ratio, STRIDE).sum(axis=2) + global_offs_k_stride = iter * ratio + offs_stride_k + causal_mask_stride = global_offs_k_stride[None, :] > ( + offs_q_stride[:, None] + shift_stride + ) + fully_masked_by_fm = ( + dense_flashmask.reshape(ratio, ratio, STRIDE).min(axis=2) == 1 + ) + fully_masked_by_fm = fully_masked_by_fm | causal_mask_stride + X = tl.where(fully_masked_by_fm, -1.0e6, X) + + X = tl.where(fully_masked_stride_mask, -1.0e6, X) + + m_local = tl.max(X, 1) + m_new = tl.maximum(m_i, m_local) + alpha = tl.math.exp2(m_i - m_new) + X = X - m_new[:, None] + l_local = tl.sum(tl.math.exp2(X), 1) + l_i = l_i * alpha + l_local + m_i = m_new + + # ================= 3. Output Preparation ================= + l_i_inv = 1.0 / l_i + + stride_out_b = (HQ * num_q_blocks * num_k_blocks).to(tl.int64) + stride_out_head = (num_q_blocks * num_k_blocks).to(tl.int64) + stride_out_q = num_k_blocks.to(tl.int64) + p_out = out + i_b * stride_out_b + i_h * stride_out_head + i_block * stride_out_q + p_out_mask = ( + out_boundary_mask + + i_b * stride_out_b + + i_h * stride_out_head + + i_block * stride_out_q + ) + + # ================= 4. 
Loop 2: Output (Exact Mirror) ================= + # 4.1 Non-Causal Blocks + for iter in range(0, k_safe_end): + curr_stride_offset = mask_ptr_base_bh_stride + iter * ratio + curr_load_mask = (iter * ratio + offs_stride_k) < N_STRIDES + + fully_masked_stride_mask = check_fully_masked_state( + curr_stride_offset, + offs_stride_k, + curr_load_mask, + offs_tokens_q, + lt_start_nstridemax, + lt_end_nstridemin, + ut_start_nstridemax, + ut_end_nstridemin, + causal=True, + mode=mode, + ) + + if tl.sum(fully_masked_stride_mask.to(tl.int32)) < ratio * ratio: + # Load K & Compute Dot + p_k = ( + p_k_base + + iter * BLOCK_SIZE * H * K + + tl.arange(0, BLOCK_SIZE)[None, :] * H * K + + offs_k_base + ) + b_k = tl.load(p_k) + + partially_masked_stride_mask = check_partially_masked_state( + curr_stride_offset, + offs_stride_k, + curr_load_mask, + offs_tokens_q, + lt_start_nstridemin, + lt_end_nstridemax, + ut_start_nstridemin, + ut_end_nstridemax, + causal=True, + mode=mode, + ) + + real_partially_masked_stride_mask = ( + ~fully_masked_stride_mask + ) & partially_masked_stride_mask + + if tl.sum(real_partially_masked_stride_mask) > 0: + logits = tl.dot(b_q, b_k) # [ratio, BLOCK_SIZE] + + curr_token_offset = mask_ptr_base_bh_tokens + iter * BLOCK_SIZE + + curr_token_load_mask = (iter * BLOCK_SIZE + offs_tokens_k) < seqlen_k + + X, dense_flashmask = flashmask_apply( + logits, + offs_tokens_q, + curr_token_offset, + offs_tokens_k, + curr_token_load_mask, + lt_start_ptr, + lt_end_ptr, + ut_start_ptr, + ut_end_ptr, + causal=True, + mode=mode, + ) + + # Reduce token logits to get stride score + + X = X.reshape(ratio, ratio, STRIDE).sum(axis=2) + fully_masked_by_fm = ( + dense_flashmask.reshape(ratio, ratio, STRIDE).min(axis=2) == 1 + ) + X = tl.where(fully_masked_by_fm, -1.0e6, X) + + has_partial = check_dense_contains_partial_stride( + dense_flashmask, + q_token_mask=mask_q, # [ratio] + k_token_mask=curr_token_load_mask, # [block_size] + BLOCK_SIZE=BLOCK_SIZE, + STRIDE=STRIDE, + ) + 
tl.store(p_out_mask + iter, has_partial.to(tl.int8)) + + else: + X = tl.dot(b_q, b_k.reshape(K, ratio, STRIDE).sum(2)) + tl.store(p_out_mask + iter, tl.zeros([], dtype=tl.int8)) + + X = tl.where(fully_masked_stride_mask, -1.0e6, X) + + # Normalization & Reduction + X = tl.exp2(X - m_i[:, None]) * l_i_inv[:, None] + X = tl.where(mask_q[:, None], X, 0) + X = tl.where(m_i[:, None] < -1.0e5, 0, X) + X = tl.sum(X, 1) # Sum K-strides + X = tl.sum(X, 0) # Sum Q-tokens + tl.store(p_out + iter, X.to(out.type.element_ty)) + + else: + tl.store(p_out + iter, tl.zeros([], dtype=out.type.element_ty)) + tl.store(p_out_mask + iter, tl.zeros([], dtype=tl.int8)) + + # 4.2 Causal Block + for iter in range(k_safe_end, k_valid_end): + curr_stride_offset = mask_ptr_base_bh_stride + iter * ratio + + curr_load_mask = (iter * ratio + offs_stride_k) < N_STRIDES + + fully_masked_stride_mask = check_fully_masked_state( + curr_stride_offset, + offs_stride_k, + curr_load_mask, + offs_tokens_q, + lt_start_nstridemax, + lt_end_nstridemin, + ut_start_nstridemax, + ut_end_nstridemin, + causal=True, + mode=mode, + ) + + if tl.sum(fully_masked_stride_mask.to(tl.int32)) < ratio * ratio: + p_k = ( + p_k_base + + iter * BLOCK_SIZE * H * K + + tl.arange(0, BLOCK_SIZE)[None, :] * H * K + + offs_k_base + ) + + mask_k = (tl.arange(0, BLOCK_SIZE)[None, :] + iter * BLOCK_SIZE) < seqlen_k + + b_k = tl.load(p_k, mask=mask_k, other=0.0) + + logits = tl.dot(b_q, b_k) + partially_masked_stride_mask = check_partially_masked_state( + curr_stride_offset, + offs_stride_k, + curr_load_mask, + offs_tokens_q, + lt_start_nstridemin, + lt_end_nstridemax, + ut_start_nstridemin, + ut_end_nstridemax, + causal=True, + mode=mode, + ) + + real_partially_masked_stride_mask = ( + ~fully_masked_stride_mask + ) & partially_masked_stride_mask + + curr_token_offset = mask_ptr_base_bh_tokens + iter * BLOCK_SIZE + + curr_token_load_mask = (iter * BLOCK_SIZE + offs_tokens_k) < seqlen_k + + X, dense_flashmask = flashmask_apply( + logits, 
@triton.jit
def gemm_fuse_softmax_non_causal(
    q,
    k,
    out,
    out_boundary_mask,
    # --- Mask Pointers ---
    lt_start_ptr,
    lt_end_ptr,
    ut_start_ptr,
    ut_end_ptr,
    lt_start_nstridemax,
    lt_start_nstridemin,
    lt_end_nstridemax,
    lt_end_nstridemin,
    ut_start_nstridemax,
    ut_start_nstridemin,
    ut_end_nstridemax,
    ut_end_nstridemin,
    # --- Params ---
    scale: float,
    seqlen_q: int,
    seqlen_k: int,
    num_q_blocks: int,
    num_k_blocks: int,
    N_STRIDES,
    STRIDE: tl.constexpr,
    HQ: tl.constexpr,
    H: tl.constexpr,
    HIDS: tl.constexpr,
    K: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    mode: tl.constexpr,
):
    """
    Non-causal (bidirectional) block-score estimator.

    One program handles one (query block, query head, batch) triple, taken
    from program ids 0, 1 and 2 respectively. For that triple it writes one
    scalar per key block into ``out`` and one int8 flag per key block into
    ``out_boundary_mask``.

    The kernel makes two passes over ALL key blocks (there is no causal
    split and no diagonal handling):

    1. Loop 1 accumulates the running row maximum ``m_i`` and the running
       denominator ``l_i`` of a base-2 softmax over stride-level scores.
    2. Loop 2 recomputes the same stride-level scores, normalises them with
       the final ``m_i`` / ``l_i`` and sums them per key block.

    Queries are sub-sampled: from every group of ``STRIDE`` consecutive
    tokens of the query block only the token at offset ``i_h % STRIDE`` is
    loaded, so different heads sample different tokens (round-robin).
    Keys are reduced by summing each group of ``STRIDE`` consecutive tokens.

    Args:
        q: Query pointer, addressed as [batch, seqlen_q, HQ, K] (contiguous
            layout is assumed by the pointer arithmetic).
        k: Key pointer, addressed as [batch, seqlen_k, H, K].
        out: Output pointer, addressed as
            [batch, HQ, num_q_blocks, num_k_blocks].
        out_boundary_mask: Flag output addressed exactly like ``out``.
        lt_start_ptr, lt_end_ptr, ut_start_ptr, ut_end_ptr: Token-level mask
            bounds, forwarded unchanged to ``flashmask_apply``.
        *_nstridemax, *_nstridemin: Stride-level max/min summaries of the
            bounds, forwarded to ``check_fully_masked_state`` and
            ``check_partially_masked_state``.
        scale: Multiplier applied to Q before the dot products. Scores are
            exponentiated with ``exp2``, so the caller is expected to fold
            log2(e) into it -- NOTE(review): confirm against the launcher.
        seqlen_q, seqlen_k: Sequence lengths used for bounds masking.
        num_q_blocks, num_k_blocks: Block counts; they define the output
            strides and the key-loop trip count.
        N_STRIDES: Number of stride entries per (batch, mask head) in the
            stride-level summaries.
        STRIDE: Tokens per stride group.
        HQ, H, HIDS: Number of query heads, key heads and mask heads.
        K: Head dimension.
        BLOCK_SIZE: Tokens per block; ``ratio = BLOCK_SIZE // STRIDE``.
        mode: Forwarded to the mask helpers (their semantics are not visible
            in this block).
    """

    # Grid coordinates, widened to int64 so the pointer offsets below are
    # computed in 64-bit arithmetic.
    i_block = tl.program_id(0).to(tl.int64)
    i_h = tl.program_id(1).to(tl.int64)
    i_b = tl.program_id(2).to(tl.int64)

    # ratio: number of stride groups per block.
    # G / GIDS: query heads per key head / per mask head.
    ratio: tl.constexpr = BLOCK_SIZE // STRIDE
    G: tl.constexpr = HQ // H
    GIDS: tl.constexpr = HQ // HIDS

    # Key head and mask head that serve this query head.
    i_hkv = i_h // G
    i_hid = i_h // GIDS

    # ================= 1. Coordinates Setup =================

    # Base offsets of this (batch, mask head) inside the stride-level and the
    # token-level mask arrays. The token-level base is indexed with seqlen_k.
    mask_ptr_base_bh_stride = i_b * N_STRIDES * HIDS + i_hid * N_STRIDES
    mask_ptr_base_bh_tokens = i_b * seqlen_k * HIDS + i_hid * seqlen_k
    # Load Q (Round-Robin Sampling)

    # Pointer to the first token of this query block for head i_h.
    p_q = q + i_b * seqlen_q * HQ * K + (i_block * BLOCK_SIZE) * HQ * K + i_h * K

    # One row per stride group: rows are STRIDE tokens apart and shifted by
    # (i_h % STRIDE) tokens, giving a [ratio, K] tile.
    p_q = (
        p_q
        + tl.arange(0, ratio)[:, None] * (HQ * K * STRIDE)
        + tl.arange(0, K)[None, :]
        + HQ * K * (i_h % STRIDE)
    )

    # Absolute token index of every sampled query row.
    offs_tokens_q = tl.arange(0, ratio) * STRIDE + i_block * BLOCK_SIZE + (i_h % STRIDE)

    # Rows past the end of the sequence load zeros and are excluded from the
    # final sums in Loop 2.
    mask_q = offs_tokens_q < seqlen_q
    b_q = tl.load(p_q, mask=mask_q[:, None], other=0.0)
    b_q = (b_q * scale).to(b_q.dtype)

    # Softmax Accumulators
    m_i = tl.full([ratio], float("-inf"), dtype=tl.float32)
    l_i = tl.zeros([ratio], dtype=tl.float32)

    # K Pointers Setup
    # K is loaded transposed as a [K, BLOCK_SIZE] tile (head dim on rows).
    p_k_base = k + i_b * seqlen_k * H * K + i_hkv * K
    offs_k_base = tl.arange(0, K)[:, None]
    offs_stride_k = tl.arange(0, ratio)
    offs_tokens_k = tl.arange(0, BLOCK_SIZE)
    # ================= 2. Loop 1: Statistics =================
    # Iterate over ALL K blocks (No causal split)

    for iter in range(0, num_k_blocks):
        curr_stride_offset = mask_ptr_base_bh_stride + iter * ratio
        curr_load_mask = (iter * ratio + offs_stride_k) < N_STRIDES

        # [Check Fully Masked]
        # causal=False affects logic inside check (e.g. loads UT bounds)
        fully_masked_stride_mask = check_fully_masked_state(
            curr_stride_offset,
            offs_stride_k,
            curr_load_mask,
            offs_tokens_q,
            lt_start_nstridemax,
            lt_end_nstridemin,
            ut_start_nstridemax,
            ut_end_nstridemin,
            causal=False,
            mode=mode,
        )

        # Skip the whole key block when all ratio * ratio (q-row, k-stride)
        # pairs are fully masked; such a block contributes nothing to m_i/l_i.
        if tl.sum(fully_masked_stride_mask.to(tl.int32)) < ratio * ratio:
            # Load K & Compute Dot
            p_k = (
                p_k_base
                + iter * BLOCK_SIZE * H * K
                + tl.arange(0, BLOCK_SIZE)[None, :] * H * K
                + offs_k_base
            )

            # Key tokens past seqlen_k are loaded as zeros.
            mask_k = (tl.arange(0, BLOCK_SIZE)[None, :] + iter * BLOCK_SIZE) < seqlen_k

            b_k = tl.load(p_k, mask=mask_k, other=0.0)
            # The token-level product [ratio, K] @ [K, BLOCK_SIZE] is only
            # computed below when some stride is partially masked.
            # [Check Partial Mask]

            partially_masked_stride_mask = check_partially_masked_state(
                curr_stride_offset,
                offs_stride_k,
                curr_load_mask,
                offs_tokens_q,
                lt_start_nstridemin,
                lt_end_nstridemax,
                ut_start_nstridemin,
                ut_end_nstridemax,
                causal=False,
                mode=mode,
            )

            # Strides flagged as partial that are not already fully masked.
            real_partially_masked_stride_mask = (
                ~fully_masked_stride_mask
            ) & partially_masked_stride_mask

            if tl.sum(real_partially_masked_stride_mask) > 0:
                # Slow path: token-level logits with the token-level mask
                # applied, then reduced to stride level.
                logits = tl.dot(b_q, b_k)
                curr_token_offset = mask_ptr_base_bh_tokens + iter * BLOCK_SIZE
                curr_token_load_mask = (iter * BLOCK_SIZE + offs_tokens_k) < seqlen_k

                X, dense_flashmask = flashmask_apply(
                    logits,
                    offs_tokens_q,
                    curr_token_offset,
                    offs_tokens_k,
                    curr_token_load_mask,
                    lt_start_ptr,
                    lt_end_ptr,
                    ut_start_ptr,
                    ut_end_ptr,
                    causal=False,
                    mode=mode,
                )

                # Reduce token logits to get stride score
                # [ratio, BLOCK_SIZE] -> [ratio, ratio, STRIDE] -> [ratio, ratio]
                X = X.reshape(ratio, ratio, STRIDE).sum(axis=2)

                # A stride whose minimum mask value is 1 gets the sentinel
                # score (presumably 1 marks a masked token in
                # dense_flashmask -- confirm against flashmask_apply).
                fully_masked_by_fm = (
                    dense_flashmask.reshape(ratio, ratio, STRIDE).min(axis=2) == 1
                )
                X = tl.where(fully_masked_by_fm, -1.0e6, X)

            else:
                # Fast path: sum the keys of each stride first, then take one
                # [ratio, K] @ [K, ratio] product. By linearity this equals
                # summing the unmasked token logits per stride.
                X = tl.dot(b_q, b_k.reshape(K, ratio, STRIDE).sum(2))

            # Explicitly mask out fully masked stride blocks
            X = tl.where(fully_masked_stride_mask, -1.0e6, X)

            # Update Stats
            # Online softmax in base 2: rescale the old denominator by
            # 2 ** (m_old - m_new) before adding this block's terms.
            m_local = tl.max(X, 1)
            m_new = tl.maximum(m_i, m_local)
            alpha = tl.math.exp2(m_i - m_new)
            X = X - m_new[:, None]
            l_local = tl.sum(tl.math.exp2(X), 1)
            l_i = l_i * alpha + l_local
            m_i = m_new

    # ================= 3. Output Preparation =================

    # l_i may still be 0 for rows that never entered the branch above; those
    # rows are zeroed explicitly in Loop 2 through the m_i < -1.0e5 test.
    l_i_inv = 1.0 / l_i
    stride_out_b = (HQ * num_q_blocks * num_k_blocks).to(tl.int64)
    stride_out_head = (num_q_blocks * num_k_blocks).to(tl.int64)
    stride_out_q = num_k_blocks.to(tl.int64)

    # Row of the output owned by this program: out[i_b, i_h, i_block, :].
    p_out = out + i_b * stride_out_b + i_h * stride_out_head + i_block * stride_out_q

    p_out_mask = (
        out_boundary_mask
        + i_b * stride_out_b
        + i_h * stride_out_head
        + i_block * stride_out_q
    )

    # ================= 4. Loop 2: Output (Exact Mirror) =================
    # Recomputes the scores exactly as in Loop 1, then normalises and stores.

    for iter in range(0, num_k_blocks):
        curr_stride_offset = mask_ptr_base_bh_stride + iter * ratio
        curr_load_mask = (iter * ratio + offs_stride_k) < N_STRIDES

        fully_masked_stride_mask = check_fully_masked_state(
            curr_stride_offset,
            offs_stride_k,
            curr_load_mask,
            offs_tokens_q,
            lt_start_nstridemax,
            lt_end_nstridemin,
            ut_start_nstridemax,
            ut_end_nstridemin,
            causal=False,
            mode=mode,
        )

        if tl.sum(fully_masked_stride_mask.to(tl.int32)) < ratio * ratio:
            p_k = (
                p_k_base
                + iter * BLOCK_SIZE * H * K
                + tl.arange(0, BLOCK_SIZE)[None, :] * H * K
                + offs_k_base
            )
            mask_k = (tl.arange(0, BLOCK_SIZE)[None, :] + iter * BLOCK_SIZE) < seqlen_k

            b_k = tl.load(p_k, mask=mask_k, other=0.0)

            partially_masked_stride_mask = check_partially_masked_state(
                curr_stride_offset,
                offs_stride_k,
                curr_load_mask,
                offs_tokens_q,
                lt_start_nstridemin,
                lt_end_nstridemax,
                ut_start_nstridemin,
                ut_end_nstridemax,
                causal=False,
                mode=mode,
            )

            real_partially_masked_stride_mask = (
                ~fully_masked_stride_mask
            ) & partially_masked_stride_mask

            if tl.sum(real_partially_masked_stride_mask) > 0:
                logits = tl.dot(b_q, b_k)
                curr_token_offset = mask_ptr_base_bh_tokens + iter * BLOCK_SIZE

                curr_token_load_mask = (iter * BLOCK_SIZE + offs_tokens_k) < seqlen_k

                X, dense_flashmask = flashmask_apply(
                    logits,
                    offs_tokens_q,
                    curr_token_offset,
                    offs_tokens_k,
                    curr_token_load_mask,
                    lt_start_ptr,
                    lt_end_ptr,
                    ut_start_ptr,
                    ut_end_ptr,
                    causal=False,
                    mode=mode,
                )
                # Reduce token logits to get stride score
                X = X.reshape(ratio, ratio, STRIDE).sum(axis=2)

                fully_masked_by_fm = (
                    dense_flashmask.reshape(ratio, ratio, STRIDE).min(axis=2) == 1
                )
                X = tl.where(fully_masked_by_fm, -1.0e6, X)

                # Only the slow path can raise the boundary flag for this
                # (q block, k block); the helper decides from the dense mask.
                has_partial = check_dense_contains_partial_stride(
                    dense_flashmask,
                    q_token_mask=mask_q,
                    k_token_mask=curr_token_load_mask,
                    BLOCK_SIZE=BLOCK_SIZE,
                    STRIDE=STRIDE,
                )
                tl.store(p_out_mask + iter, has_partial.to(tl.int8))

            else:
                # Reduce token logits to get stride score
                X = tl.dot(b_q, b_k.reshape(K, ratio, STRIDE).sum(2))
                tl.store(p_out_mask + iter, tl.zeros([], dtype=tl.int8))

            X = tl.where(fully_masked_stride_mask, -1.0e6, X)

            # Normalization & Reduction
            X = tl.exp2(X - m_i[:, None]) * l_i_inv[:, None]
            # Drop query rows that lie past the end of the sequence.
            X = tl.where(mask_q[:, None], X, 0)
            # Drop rows whose running maximum never rose above the -1.0e6
            # sentinel (this also covers m_i == -inf, where l_i_inv is inf).
            X = tl.where(m_i[:, None] < -1.0e5, 0, X)
            X = tl.sum(X, 1)  # Sum K-strides
            X = tl.sum(X, 0)  # Sum Q-tokens
            tl.store(p_out + iter, X.to(out.type.element_ty))

        else:
            # Fully masked key block: zero score and no boundary flag.
            tl.store(p_out + iter, tl.zeros([], dtype=out.type.element_ty))
            tl.store(p_out_mask + iter, tl.zeros([], dtype=tl.int8))


@dataclass(frozen=True)
class RawPtrs:
    # Token-level mask bounds, one tensor per bound.
    # Each tensor is the [..., i] slice of the 4-D startend_row_indices input,
    # i.e. [B, HIDS, seq_len]. The kernels index this axis with seqlen_k and
    # the launcher checks it against the key length, so seq_len is the key
    # length (the previous comment said seqlen_q).
    # Bounds that a given mode does not provide alias lt_start as placeholders.
    lt_start: paddle.Tensor
    lt_end: paddle.Tensor
    ut_start: paddle.Tensor
    ut_end: paddle.Tensor
@dataclass(frozen=True)
class StrideMaxMinPtrs:
    # Stride-level summaries of the mask bounds: for every bound, the
    # per-stride maximum and minimum, each shaped [B, HIDS, n_strides].
    lt_start_max: paddle.Tensor
    lt_start_min: paddle.Tensor
    lt_end_max: paddle.Tensor
    lt_end_min: paddle.Tensor
    ut_start_max: paddle.Tensor
    ut_start_min: paddle.Tensor
    ut_end_max: paddle.Tensor
    ut_end_min: paddle.Tensor
    # Number of stride entries along the last axis of the tensors above.
    n_strides: int


def _require(cond: bool, msg: str):
    """Raise ``ValueError(msg)`` unless ``cond`` is truthy."""
    if cond:
        return
    raise ValueError(msg)


def _extract_raw_ptrs(
    startend_row_indices: paddle.Tensor,
    causal: bool,
) -> tuple[int, RawPtrs]:
    """
    Split ``startend_row_indices`` into its four token-level bound tensors.

    The input is [B, HIDS, seq_len, mode] with mode in {1, 2, 4}; the callers
    in this file require ``seq_len`` to equal the key length. The meaning of
    the last axis depends on ``mode`` and ``causal``:
      - mode=1: only lt_start
      - mode=2:
          causal=True  -> (lt_start, lt_end)
          causal=False -> (lt_start, ut_end)
      - mode=4: (lt_start, lt_end, ut_start, ut_end), non-causal only

    Bounds that the layout does not provide alias ``lt_start`` so that every
    field of the result is always a valid tensor.

    Returns:
        ``(mode, RawPtrs)``; every extracted column is made contiguous.

    Raises:
        ValueError: if the last axis is not 1, 2 or 4 wide, or if it is 4
            wide while ``causal`` is true.
    """
    num_bounds = startend_row_indices.shape[-1]
    _require(
        num_bounds in (1, 2, 4),
        f"Unsupported mode={num_bounds}, expected 1/2/4",
    )
    _require(
        not causal or num_bounds != 4,
        "mode=4 is only valid when causal=False in FlashMask semantics",
    )

    # Make the packed tensor contiguous once, up front.
    packed = startend_row_indices.contiguous()

    def _column(index):
        # Slicing the last axis yields a strided view; materialise it.
        return packed[..., index].contiguous()

    first = _column(0)
    bounds = {
        "lt_start": first,
        "lt_end": first,
        "ut_start": first,
        "ut_end": first,
    }

    if num_bounds == 2:
        # The second column is the lower-triangle end when causal, otherwise
        # the upper-triangle end.
        bounds["lt_end" if causal else "ut_end"] = _column(1)
    elif num_bounds == 4:
        bounds["lt_end"] = _column(1)
        bounds["ut_start"] = _column(2)
        bounds["ut_end"] = _column(3)

    return num_bounds, RawPtrs(**bounds)
def _prepare_stride_maxmin_ptrs(
    raw: RawPtrs,
    mode: int,
    causal: bool,
    stride: int,
) -> StrideMaxMinPtrs:
    """Build the stride-level max/min summaries of the token-level bounds.

    ``prepare_maxmin`` is applied to every bound that the given
    ``mode`` / ``causal`` combination actually provides (the same combinations
    as in ``_extract_raw_ptrs``). Bounds that are absent reuse the
    ``lt_start`` maximum as a placeholder so every field is a valid tensor.

    Args:
        raw: Token-level bounds.
        mode: Width of the last axis of ``startend_row_indices`` (1, 2 or 4).
        causal: Selects the meaning of the second column when ``mode == 2``.
        stride: Number of tokens per stride group; must be positive.

    Returns:
        A ``StrideMaxMinPtrs`` whose ``n_strides`` is read from axis 2 of the
        ``lt_start`` summary.

    Raises:
        ValueError: if ``stride`` is not positive.
    """
    _require(stride > 0, "stride must be positive")

    lt_start_max, lt_start_min = prepare_maxmin(raw.lt_start, stride)
    n_strides = lt_start_max.shape[2]

    # Placeholder for bounds this mode does not define.
    placeholder = lt_start_max

    lt_end_max = lt_end_min = placeholder
    ut_start_max = ut_start_min = placeholder
    ut_end_max = ut_end_min = placeholder

    if mode == 2:
        if causal:
            lt_end_max, lt_end_min = prepare_maxmin(raw.lt_end, stride)
        else:
            ut_end_max, ut_end_min = prepare_maxmin(raw.ut_end, stride)
    elif mode == 4:
        lt_end_max, lt_end_min = prepare_maxmin(raw.lt_end, stride)
        ut_start_max, ut_start_min = prepare_maxmin(raw.ut_start, stride)
        ut_end_max, ut_end_min = prepare_maxmin(raw.ut_end, stride)

    return StrideMaxMinPtrs(
        lt_start_max=lt_start_max,
        lt_start_min=lt_start_min,
        lt_end_max=lt_end_max,
        lt_end_min=lt_end_min,
        ut_start_max=ut_start_max,
        ut_start_min=ut_start_min,
        ut_end_max=ut_end_max,
        ut_end_min=ut_end_min,
        n_strides=n_strides,
    )


@paddle.compat.use_torch_proxy_guard()
def rr_attn_estimate_triton_func(
    q: paddle.Tensor,
    k: paddle.Tensor,
    startend_row_indices: paddle.Tensor,
    stride: int = 8,
    causal: bool = True,
    threshold: float = 1.0,
) -> tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
    """Estimate block-level attention mass and derive a block selection.

    Launches ``gemm_fuse_softmax_causal`` or ``gemm_fuse_softmax_non_causal``
    (chosen by ``causal``) on a ``(num_q_blocks, num_q_heads, batch)`` grid
    with a fixed block size of 128 tokens, then feeds the resulting block
    scores to ``find_blocks_chunked``.

    Args:
        q: Queries, ``[B, seqlen_q, HQ, head_dim]``.
        k: Keys, ``[B, seqlen_k, H, head_dim]``; ``HQ`` must be a multiple
            of ``H``.
        startend_row_indices: FlashMask bounds,
            ``[B, HIDS, seqlen_k, mode]`` with ``mode`` in {1, 2, 4}, on the
            same device as ``q``; ``HQ`` must be a multiple of ``HIDS``.
        stride: Tokens per stride group. Must be positive and divide the
            block size (128), because the kernels reshape each block into
            ``128 // stride`` groups of ``stride`` tokens.
        causal: Select the causal kernel and the causal bound layout.
        threshold: Forwarded to ``find_blocks_chunked``.

    Returns:
        A 3-tuple ``(attn_sums, boundary_protection_mask, selected_blocks)``:

        * ``attn_sums``: ``[B, HQ, ceil(seqlen_q/128), ceil(seqlen_k/128)]``
          in ``q``'s dtype, the per-block scores written by the kernel.
        * ``boundary_protection_mask``: same shape, bool, the per-block flags
          written by the kernel.
        * ``selected_blocks``: the result of ``find_blocks_chunked`` on
          ``attn_sums`` with ``threshold`` (its exact semantics live in that
          helper, which is defined outside this block).

    Raises:
        ValueError: on any shape, head-count, device or ``stride`` mismatch
            described above.
    """
    BLOCK_SIZE = 128

    # --- 1. Shape / argument validation ---
    _require(
        startend_row_indices.ndim == 4,
        "startend_row_indices must be [B, HIDS, seqlen_q, mode]",
    )

    bsz, q_len, num_q_heads, head_dim = q.shape
    bsz2, kv_len, num_kv_heads, k_head_dim = k.shape
    _require(bsz2 == bsz, "q/k batch size mismatch")
    # The kernels read K rows with q's head_dim, so the two must agree.
    _require(k_head_dim == head_dim, "q/k head_dim mismatch")

    _require(
        startend_row_indices.shape[0] == bsz,
        "startend_row_indices batch mismatch",
    )
    _require(
        startend_row_indices.shape[2] == kv_len,
        "startend_row_indices seqlen_k mismatch",
    )

    num_indices_heads = startend_row_indices.shape[1]
    _require(
        num_q_heads % num_kv_heads == 0,
        "MHA/GQA requires num_q_heads % num_kv_heads == 0",
    )
    _require(
        num_q_heads % num_indices_heads == 0,
        "Require num_q_heads % num_indices_heads == 0 for head mapping",
    )

    _require(
        startend_row_indices.place == q.place,
        "startend_row_indices must be on the same device as q",
    )
    _require(stride > 0, "stride must be positive")
    # Fail early with a readable message instead of deep inside the Triton
    # compiler: the kernels reshape a block to (ratio, ratio, stride) with
    # ratio = BLOCK_SIZE // stride, which needs an exact division.
    _require(
        BLOCK_SIZE % stride == 0,
        f"stride must divide BLOCK_SIZE={BLOCK_SIZE}, got stride={stride}",
    )

    # --- 2. Mask preprocessing ---
    mode, raw = _extract_raw_ptrs(startend_row_indices, causal)
    stride_mm = _prepare_stride_maxmin_ptrs(raw, mode, causal, stride)

    # --- 3. Kernel launch setup ---
    num_q_blocks = triton.cdiv(q_len, BLOCK_SIZE)
    num_k_blocks = triton.cdiv(kv_len, BLOCK_SIZE)

    out_shape = (bsz, num_q_heads, num_q_blocks, num_k_blocks)
    # Both buffers are written for every (q block, k block) by the kernel,
    # so uninitialised allocation is sufficient.
    attn_sums = paddle.empty(out_shape, dtype=q.dtype)
    boundary_protection_mask = paddle.empty(out_shape, dtype=paddle.bool)

    grid = (num_q_blocks, num_q_heads, bsz)

    # LOG2E converts the kernels' exp2 into a natural-base softmax; the extra
    # 1 / stride compensates for scores being sums over `stride` key tokens.
    scale = LOG2E / math.sqrt(head_dim) / stride

    kernel = gemm_fuse_softmax_causal if causal else gemm_fuse_softmax_non_causal

    kernel[grid](
        q,
        k,
        attn_sums,
        boundary_protection_mask,
        # raw pointers (token-level)
        raw.lt_start,
        raw.lt_end,
        raw.ut_start,
        raw.ut_end,
        # stride max/min pointers
        stride_mm.lt_start_max,
        stride_mm.lt_start_min,
        stride_mm.lt_end_max,
        stride_mm.lt_end_min,
        stride_mm.ut_start_max,
        stride_mm.ut_start_min,
        stride_mm.ut_end_max,
        stride_mm.ut_end_min,
        # meta
        scale=scale,
        seqlen_q=q_len,
        seqlen_k=kv_len,
        num_q_blocks=num_q_blocks,
        num_k_blocks=num_k_blocks,
        N_STRIDES=stride_mm.n_strides,
        STRIDE=stride,
        HQ=num_q_heads,
        H=num_kv_heads,
        HIDS=num_indices_heads,
        K=head_dim,
        BLOCK_SIZE=BLOCK_SIZE,
        mode=mode,
    )

    # --- 4. Block selection ---
    selected_blocks = find_blocks_chunked(
        attn_sums,
        0,
        threshold,
        None,
        decoding=False,
        mode="prefill",
        causal=causal,
    )

    return attn_sums, boundary_protection_mask, selected_blocks
def rr_attention(
    query: paddle.Tensor,
    key: paddle.Tensor,
    value: paddle.Tensor,
    startend_row_indices: paddle.Tensor,
    *,
    threshold: float = 1.0,
    stride: int = 8,
    causal: bool = False,
    dropout: float = 0.0,
    training: bool = True,
    keep_sink: bool = True,
    keep_last: bool = True,
    return_softmax_lse: bool = False,
    return_seed_offset: bool = False,
) -> paddle.Tensor:
    """
    RRAttention using threshold-based sparse pattern estimation.

    A Triton kernel first estimates how much attention mass each
    (query block, key block) pair receives. The blocks chosen from that
    estimate are OR-ed with the kernel's boundary-protection mask and passed
    to ``F.flashmask_attention`` as ``block_mask``.

    Args:
        query (paddle.Tensor):
            Query tensor with shape [batch_size, seq_len_q, num_heads, head_dim]
        key (paddle.Tensor):
            Key tensor with shape [batch_size, seq_len_k, num_kv_heads, head_dim];
            ``num_heads`` must be a multiple of ``num_kv_heads``.
        value (paddle.Tensor):
            Value tensor with shape [batch_size, seq_len_k, num_kv_heads, head_dim]
        startend_row_indices (paddle.Tensor):
            See flashmask_attention for details. ``None`` is accepted and
            selects the dense fast path. When it has fewer heads than
            ``query`` it is repeated along the head axis to match.
        threshold (float, optional):
            Selection threshold in range [0, 1], forwarded to the block
            estimator. Values >= 1.0 skip the estimation entirely and run
            full (mask-only) attention; lower values are expected to keep
            fewer blocks. Default: 1.0 (full attention)
        stride (int, optional):
            Stride for attention pattern estimation. Controls granularity of
            block processing. Default: 8
        causal (bool, optional):
            Whether to apply causal masking. Default: False
        dropout (float, optional):
            Dropout probability for attention weights. Default: 0.0
        training (bool, optional):
            Whether in training mode. Default: True
        keep_sink (bool, optional):
            Always keep the first key block for every query block.
            Default: True
        keep_last (bool, optional):
            Always keep every key block for the last query block.
            Default: True
        return_softmax_lse (bool, optional):
            Whether to return log-sum-exp values. Default: False
        return_seed_offset (bool, optional):
            Whether to return seed offset. Default: False

    Returns:
        paddle.Tensor:
            Attention output with shape [batch_size, seq_len_q, num_heads, head_dim].
            With ``return_softmax_lse`` / ``return_seed_offset`` the result is
            whatever ``F.flashmask_attention`` returns for those flags.

    Raises:
        ValueError: If ``threshold`` is negative or NaN, if the query head
            count is not a multiple of the ``startend_row_indices`` head
            count, or if the estimator rejects the input shapes.

    Example:
        >>> import paddle
        >>> from flash_mask.utils.rr_attn_interface import rr_attention
        >>>
        >>> query = paddle.randn((1, 128, 8, 128), dtype=paddle.bfloat16)
        >>> key = paddle.randn((1, 128, 2, 128), dtype=paddle.bfloat16)
        >>> value = paddle.randn((1, 128, 2, 128), dtype=paddle.bfloat16)
        >>> startend_row_indices = paddle.full([1, 1, 128, 1], 128, dtype=paddle.int32)
        >>>
        >>> output = rr_attention(
        ...     query, key, value, startend_row_indices,
        ...     threshold=0.5, causal=True,
        ... )
        >>> print(output.shape)  # [1, 128, 8, 128]

    Note:
        - When threshold>=1.0, falls back to standard flashmask_attention
        - The sparse pattern is constructed by boundary_mask and top-p mask
    """

    # Fast path: nothing to sparsify without a mask or with a full budget.
    # `>=` instead of `==` keeps out-of-range budgets on the dense path.
    if startend_row_indices is None or threshold >= 1.0:
        return F.flashmask_attention(
            query,
            key,
            value,
            startend_row_indices=startend_row_indices,
            dropout=dropout,
            causal=causal,
            training=training,
            return_softmax_lse=return_softmax_lse,
            return_seed_offset=return_seed_offset,
        )

    # Written as a negated comparison so that NaN is rejected as well.
    if not threshold >= 0.0:
        raise ValueError(f"threshold must be in [0, 1], got {threshold}")

    # Ensure startend_row_indices has same number of heads as query
    num_heads_q = query.shape[2]
    num_heads_indices = startend_row_indices.shape[1]
    if num_heads_indices != num_heads_q:
        if num_heads_q % num_heads_indices != 0:
            raise ValueError(
                f"query heads ({num_heads_q}) must be divisible by "
                f"startend_row_indices heads ({num_heads_indices})"
            )
        repeat_factor = num_heads_q // num_heads_indices
        startend_row_indices = startend_row_indices.repeat_interleave(
            repeat_factor, axis=1
        ).contiguous()

    # The estimate only selects blocks; it must not record gradients.
    with paddle.no_grad():
        attn_sums, boundary_mask, topp_mask = rr_attn_estimate_triton_func(
            q=query,
            k=key,
            startend_row_indices=startend_row_indices,
            stride=stride,
            threshold=threshold,
            causal=causal,
        )

    # Combine masks: boundary protection + top-p sparsity
    block_mask = paddle.logical_or(boundary_mask, topp_mask).astype(paddle.int32)
    if keep_sink:
        # First key block (attention sink) for every query block.
        block_mask[:, :, :, 0] = 1
    if keep_last:
        # Every key block for the last query block.
        block_mask[:, :, -1, :] = 1

    # Apply sparse attention with computed block mask
    return F.flashmask_attention(
        query,
        key,
        value,
        startend_row_indices=startend_row_indices,
        dropout=dropout,
        causal=causal,
        training=training,
        return_softmax_lse=return_softmax_lse,
        return_seed_offset=return_seed_offset,
        block_mask=block_mask,
    )
128, 2, 128), dtype=paddle.bfloat16) +# value = paddle.randn((1, 128, 2, 128), dtype=paddle.bfloat16) + +# startend_row_indices = paddle.full([1, 1, 128, 1], 128, dtype=paddle.int32) +# print(rr_attention(query, key, value, startend_row_indices, threshold=0.5, causal=True))