diff --git a/.gitignore b/.gitignore index 97991419fdb..2fd004dc706 100644 --- a/.gitignore +++ b/.gitignore @@ -51,4 +51,4 @@ training/data # ck modules csrc/composable_kernel csrc/cutlass -.analysis \ No newline at end of file +.amd \ No newline at end of file diff --git a/flash_attn/flash_attn_triton_amd/fwd_prefill.py b/flash_attn/flash_attn_triton_amd/fwd_prefill.py index 71b2b40458a..3a2bd56fda4 100644 --- a/flash_attn/flash_attn_triton_amd/fwd_prefill.py +++ b/flash_attn/flash_attn_triton_amd/fwd_prefill.py @@ -382,26 +382,97 @@ def _attn_fwd_mask(acc, l_i, m_i, @triton.jit -def compute_masking(seqlen_k, seqlen_q, start_m, - IS_CAUSAL: tl.constexpr, USE_SLIDING_WINDOW: tl.constexpr, - WINDOW_SIZE_LEFT: tl.constexpr, WINDOW_SIZE_RIGHT: tl.constexpr, - BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr): - """ - Classify K blocks for attention computation with sliding window support. +def compute_window_bounds(q_start, q_end, diag, seqlen_k, + WINDOW_SIZE_LEFT: tl.constexpr, + WINDOW_SIZE_RIGHT: tl.constexpr, + IS_CAUSAL: tl.constexpr): + """Calculate the window boundaries for a query block.""" + # Left boundary + if WINDOW_SIZE_LEFT < 0: + left_min = 0 + left_max = 0 + else: + left_min = tl.maximum(0, q_start + diag - WINDOW_SIZE_LEFT) + left_max = tl.maximum(0, q_end + diag - WINDOW_SIZE_LEFT) - Returns: - - n_front_skip_blocks: Blocks completely before the window - - n_front_masked_blocks: Blocks partially overlapping window front - - n_full_blocks: Blocks completely inside the window - - n_back_masked_blocks: Blocks partially overlapping window back - - n_extra_tokens: Padding tokens in last K block + # Right boundary + if IS_CAUSAL: + # Causal cap: col ≤ row + diag + right_min = tl.minimum(seqlen_k - 1, q_start + diag) + right_max = tl.minimum(seqlen_k - 1, q_end + diag) + else: + if WINDOW_SIZE_RIGHT < 0: + right_min = tl.minimum(seqlen_k - 1, q_start + diag + WINDOW_SIZE_RIGHT) + right_max = tl.minimum(seqlen_k - 1, q_end + diag + WINDOW_SIZE_RIGHT) + else: + # Non-causal doesn't have the diagonal constraint + right_min = tl.minimum(seqlen_k - 1, q_start + diag + WINDOW_SIZE_RIGHT) + right_max = tl.minimum(seqlen_k - 1, q_end + diag + WINDOW_SIZE_RIGHT) + + return left_min, left_max, right_min, right_max + +@triton.jit +def classify_window_blocks(left_min, left_max, right_min, right_max, + BLOCK_N: tl.constexpr): + """Classify blocks based on window boundaries.""" + # First and last blocks that have ANY overlap with window + first_block = left_min // BLOCK_N + last_block = right_max // BLOCK_N + + # First block that is FULLY visible for all rows in Q block + full_left_block = left_max // BLOCK_N + (left_max % BLOCK_N != 0) + clipped_left = tl.minimum(full_left_block, last_block + 1) + + # Last block that is FULLY visible for all rows in Q block + last_full_block_candidate = right_min // BLOCK_N + if (last_full_block_candidate + 1) * BLOCK_N - 1 > right_min: + last_full_block_candidate -= 1 + full_right_block = tl.maximum(last_full_block_candidate, clipped_left - 1) + + # Calculate counts + n_front_skip_blocks = first_block + n_front_masked_blocks = tl.maximum(0, clipped_left - first_block) + n_full_blocks = tl.maximum(0, full_right_block - clipped_left + 1) + n_back_masked_blocks = tl.maximum(0, last_block - full_right_block) + + return (n_front_skip_blocks, n_front_masked_blocks, + n_full_blocks, n_back_masked_blocks, + clipped_left) # Return clipped_left for padded block handling + +@triton.jit +def handle_padded_last_block(n_extra_tokens, last_block, total_k_blocks, + clipped_left, n_front_masked_blocks, + n_full_blocks, n_back_masked_blocks): + """Ensure a padded last K-block is never classified as 'full'. + + We move the padded last block (if visible) into the back-masked bucket. + If it's already back-masked, we do nothing. If it was counted in the + front-masked range, we decrement front-masked; if it was counted as full, + we decrement full. Then we increment back-masked. """ - # Example case - # BLOCK_M = 4, BLOCK_N = 4, seqlen_q = 8, seqlen_k = 10 + padded_last_k = (n_extra_tokens != 0) & (last_block == total_k_blocks - 1) + + if padded_last_k: + # current 'full' range right edge + full_right_block = clipped_left + n_full_blocks - 1 + + # If last_block is already beyond full_right_block, it's already in back-masked → nothing to do + last_already_back_masked = last_block > full_right_block + if not last_already_back_masked: + # If the window starts past last_block, it was counted in front-masked + if clipped_left > last_block: + n_front_masked_blocks = tl.maximum(0, n_front_masked_blocks - 1) + else: + # Otherwise it was counted 'full' → move it out of full + n_full_blocks = tl.maximum(0, n_full_blocks - 1) + # In both cases we need one more back-masked block + n_back_masked_blocks = n_back_masked_blocks + 1 - # Total K blocks in the key sequence - total_k_blocks = tl.cdiv(seqlen_k, BLOCK_N) + return n_front_masked_blocks, n_full_blocks, n_back_masked_blocks +@triton.jit +def compute_padding_info(seqlen_k, BLOCK_N: tl.constexpr): + """Calculate padding information for the last K block.""" # check if we will need to do masking due either BLOCK_N being bigger than seqlen_k or seqlen_k not being a factor of BLOCK_N # n_extra_tokens = 10 % 4 = 2 # This means the last K block has 2 valid tokens and 2 padding positions @@ -415,15 +486,60 @@ def compute_masking(seqlen_k, seqlen_q, start_m, elif seqlen_k % BLOCK_N: n_extra_tokens = seqlen_k % BLOCK_N else: - n_extra_tokens = 0 + n_extra_tokens = 0 + return n_extra_tokens + +@triton.jit +def compute_block_masking(seqlen_k, seqlen_q, start_m, + IS_CAUSAL: tl.constexpr, USE_SLIDING_WINDOW: tl.constexpr, + WINDOW_SIZE_LEFT: tl.constexpr, WINDOW_SIZE_RIGHT: tl.constexpr, + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr): + """ + Classify K blocks for attention computation with sliding window support. + + Returns: + - n_front_skip_blocks: Blocks completely before the window + - n_front_masked_blocks: Blocks partially overlapping window front + - n_full_blocks: Blocks completely inside the window + - n_back_masked_blocks: Blocks partially overlapping window back + - n_extra_tokens: Padding tokens in last K block + """ + + # common + q_start = start_m * BLOCK_M + q_end = tl.minimum((start_m + 1) * BLOCK_M - 1, seqlen_q - 1) + diag = seqlen_k - seqlen_q + total_k_blocks = tl.cdiv(seqlen_k, BLOCK_N) + n_extra_tokens = compute_padding_info(seqlen_k, BLOCK_N) if USE_SLIDING_WINDOW: - # TODO: Optimize by computing which blocks can be fully skipped - # For now, process all blocks with the mask function - if IS_CAUSAL: - return 0, 0, 0, total_k_blocks, n_extra_tokens - else: - return 0, 0, 0, total_k_blocks, n_extra_tokens + # get window bounds + left_min, left_max, right_min, right_max = compute_window_bounds( + q_start, q_end, diag, seqlen_k, + WINDOW_SIZE_LEFT, WINDOW_SIZE_RIGHT, IS_CAUSAL + ) + + # window vanishes → early exit + if right_max < left_min: + return 0, 0, 0, 0, n_extra_tokens + + # classify blocks + (n_front_skip_blocks, n_front_masked_blocks, + n_full_blocks, n_back_masked_blocks, + clipped_left) = classify_window_blocks( + left_min, left_max, right_min, right_max, BLOCK_N + ) + + # handle padded last block if needed + if n_extra_tokens != 0: + last_block = right_max // BLOCK_N + n_front_masked_blocks, n_full_blocks, n_back_masked_blocks = handle_padded_last_block( + n_extra_tokens, last_block, total_k_blocks, + clipped_left, n_front_masked_blocks, + n_full_blocks, n_back_masked_blocks + ) + return (n_front_skip_blocks, n_front_masked_blocks, + n_full_blocks, n_back_masked_blocks, n_extra_tokens) else: if IS_CAUSAL: # ========== CAUSAL MODE: Classify K Blocks ========== @@ -444,11 +560,6 @@ def compute_masking(seqlen_k, seqlen_q, start_m, # 1. figure out, in tokens, the right-most K position # this Q-block may attend to # ------------------------------------------------------------ - q_start = start_m * BLOCK_M - q_end = tl.minimum((start_m + 1) * BLOCK_M - 1, seqlen_q - 1) - - # causal diagonal offset between the two streams - diag = seqlen_k - seqlen_q # 0 when |Q| == |K| k_max_token = q_end + diag # last visible K index # this Q-block is entirely above the diagonal ⇒ nothing to do @@ -575,7 +686,7 @@ def attn_fwd(Q, K, V, bias, # figure out masking pattern - n_front_skip_blocks, n_front_masked_blocks, n_full_blocks, n_back_masked_blocks, n_extra_tokens = compute_masking( + n_front_skip_blocks, n_front_masked_blocks, n_full_blocks, n_back_masked_blocks, n_extra_tokens = compute_block_masking( seqlen_k, seqlen_q, start_m, IS_CAUSAL, USE_SLIDING_WINDOW, WINDOW_SIZE_LEFT, WINDOW_SIZE_RIGHT, BLOCK_M, BLOCK_N )