From 7b931366bd8f16289d8f6bcd162bd4f34920ef22 Mon Sep 17 00:00:00 2001 From: Enigmatisms Date: Thu, 9 Apr 2026 16:49:51 +0800 Subject: [PATCH 1/6] [Feat] CP-balance as flash_mask sub-module Co-authored-by: starcrown001 <148410714+starcrown001@users.noreply.github.com> --- flashmask/flash_mask/cp_balance/__init__.py | 2 + flashmask/flash_mask/cp_balance/cp_balance.py | 379 ++++++++++ .../cp_balance/cp_balance_cuda_kernels.py | 46 ++ .../cp_balance/csrc/cp_balance_utils.cu | 667 ++++++++++++++++++ flashmask/flash_mask/cp_balance/csrc/setup.py | 120 ++++ flashmask/setup.py | 25 +- 6 files changed, 1238 insertions(+), 1 deletion(-) create mode 100644 flashmask/flash_mask/cp_balance/__init__.py create mode 100644 flashmask/flash_mask/cp_balance/cp_balance.py create mode 100644 flashmask/flash_mask/cp_balance/cp_balance_cuda_kernels.py create mode 100644 flashmask/flash_mask/cp_balance/csrc/cp_balance_utils.cu create mode 100644 flashmask/flash_mask/cp_balance/csrc/setup.py diff --git a/flashmask/flash_mask/cp_balance/__init__.py b/flashmask/flash_mask/cp_balance/__init__.py new file mode 100644 index 00000000000..96570f1e679 --- /dev/null +++ b/flashmask/flash_mask/cp_balance/__init__.py @@ -0,0 +1,2 @@ +from .cp_balance import balance_flashmask_input +from .cp_balance_cuda_kernels import indices_rerank_cuda, indices_to_chunks_cuda diff --git a/flashmask/flash_mask/cp_balance/cp_balance.py b/flashmask/flash_mask/cp_balance/cp_balance.py new file mode 100644 index 00000000000..4175262b327 --- /dev/null +++ b/flashmask/flash_mask/cp_balance/cp_balance.py @@ -0,0 +1,379 @@ +import heapq +import paddle +import numpy as np +from .cp_balance_cuda_kernels import scanMaxMinChunkedKernel, reduce_workload, indices_to_chunks_cuda, indices_rerank_cuda +import paddle.distributed as dist +import hashlib +from typing import List, Tuple, Dict, Optional + +# --- 调试辅助函数 --- + +def save_tensor(x: paddle.Tensor, name: str): + """将 Paddle Tensor 保存为 txt 文件,用于调试。""" + x_np = x.numpy() + 
np.savetxt(f'{name}.txt', x_np.reshape(-1, x_np.shape[-1]), fmt='%d') + +def tensor_md5(tensor: paddle.Tensor) -> str: + """计算 Paddle Tensor 的 MD5 哈希值,用于验证数据一致性。""" + x_bytes = tensor.numpy().tobytes() + md5_hash = hashlib.md5(x_bytes).hexdigest() + print(f"Tensor MD5: {md5_hash}") + return md5_hash + +# --- 核心工作负载计算与分配 --- + +def get_q_workload( + start_row_indices: paddle.Tensor, + q_chunk_size: int, + m_block_size: int, + n_block_size: int +) -> paddle.Tensor: + """ + 根据稀疏attention的起止索引,估算每个query chunk的计算负载。 + 这是负载均衡的第一步,目的是量化每个数据块的计算成本。 + + Args: + start_row_indices (paddle.Tensor): 形状为 [B, H, S, 2] 或 [B, H, S, 4] 的张量, + 表示每个 query token 需要计算的 key token 的起止范围。 + 维度4的顺序为 [LTS, LTE, UTS, UTE]。 + 维度2的顺序为 [LTS, UTE]。 + q_chunk_size (int): Query 侧进行负载均衡分析的块大小。 + m_block_size (int): FlashAttention kernel 中 query 侧的块大小 (Br)。 + n_block_size (int): FlashAttention kernel 中 key 侧的块大小 (Bc)。 + + Returns: + paddle.Tensor: 形状为 [1, H, Tchunks, 2] 的张量, + 其中 Tchunks 是 chunk 的数量。 + 每个 chunk 的信息为 [workload, original_index], + 表示该 chunk 的估算工作量和其原始索引。 + """ + assert start_row_indices is not None, "start_row_indices cannot be None" + assert q_chunk_size % m_block_size == 0, "q_chunk_size must be divisible by m_block_size" + + # 1. 解析输入的起止索引 + # start_row_indices 可能包含下三角(LT)和上三角(UT)的起止(Start/End)信息 + LTS, LTE, UTS, UTE = None, None, None, None + if start_row_indices.shape[-1] == 4: + LTS, LTE, UTS, UTE = paddle.split(start_row_indices, 4, axis=-1) + LTS, LTE, UTS, UTE = [t.squeeze(-1) for t in (LTS, LTE, UTS, UTE)] + elif start_row_indices.shape[-1] == 2: + LTS, UTE = paddle.split(start_row_indices, 2, axis=-1) + LTS, UTE = LTS.squeeze(-1), UTE.squeeze(-1) + + # 2. 
获取维度信息 + # 从任意一个非None的张量中获取 Batch, Head, Sequence Length + valid_tensor = next(t for t in [LTS, LTE, UTS, UTE] if t is not None) + B, H, S = valid_tensor.shape + + # 计算块的数量 + Tr = S // m_block_size # Query 侧块总数 + Tc = S // n_block_size # Key 侧块总数 + Tchunks = S // q_chunk_size # 用于负载均衡的 chunk 总数 + assert Tr % Tchunks == 0, "Total row blocks must be divisible by total chunks" + blocks_per_chunk = Tr // Tchunks + + # 3. 使用自定义CUDA核预计算每个 Key 块内的索引最大/最小值 + # 这一步是关键优化,它将 O(S) 的扫描操作降维到 O(S/Bc), + # 极大地加速了后续工作负载的估算。 + def scan_max_min(tensor): + if tensor is not None: + return scanMaxMinChunkedKernel(tensor, n_block_size, B, H, S) + return None, None + + LTStartMax_gpu, LTStartMin_gpu = scan_max_min(LTS) + LTEndMax_gpu, LTEndMin_gpu = scan_max_min(LTE) + UTStartMax_gpu, UTStartMin_gpu = scan_max_min(UTS) + UTEndMax_gpu, UTEndMin_gpu = scan_max_min(UTE) + + # 4. 使用自定义CUDA核计算每个 Query 块的工作负载 + # 这个核模拟了 FlashAttention 的块状计算过程,但只计算需要被激活的块的数量, + # 而不是执行实际的矩阵乘法,从而高效地估算出工作负载。 + all_indices_max_min = [ + LTStartMax_gpu, LTStartMin_gpu, LTEndMax_gpu, LTEndMin_gpu, + UTStartMax_gpu, UTStartMin_gpu, UTEndMax_gpu, UTEndMin_gpu + ] + workload_per_block = reduce_workload(all_indices_max_min, B, H, Tr, Tc, m_block_size, S) + + # 5. 将每个块的工作负载聚合到 chunk 级别 + workload_grouped = workload_per_block.reshape([B, H, Tchunks, blocks_per_chunk, 1]) + workload_per_chunk = paddle.sum(workload_grouped, axis=3).sum(axis=0).reshape([1, H, Tchunks]) + + # 6. 
def assign_tasks_heap(
    tasks: np.ndarray,
    num_buckets: int
) -> Tuple[List[List[Tuple[int, int]]], List[int], int]:
    """
    Greedily distribute weighted tasks over ``num_buckets`` buckets.

    Tasks are visited from heaviest to lightest (ties keep their input
    order) and each one goes to the lightest bucket that still holds fewer
    than ``len(tasks) // num_buckets`` tasks.  Once every bucket has reached
    that quota, the remaining tasks go to the lightest bucket overall.

    Args:
        tasks: Sequence of ``[weight, index]`` rows (e.g. an ``(N, 2)`` array).
        num_buckets: Number of buckets (typically the communication world size).

    Returns:
        A tuple ``(buckets, bucket_weights, cuts)``:
        - buckets: one list of ``(weight, index)`` tuples per bucket, sorted
          by ``index`` inside each bucket.
        - bucket_weights: total weight of every bucket.
        - cuts: number of positions in the reranked order (buckets
          concatenated in bucket order) where the index is not the
          predecessor's index + 1, i.e. how fragmented the new layout is.
    """
    n = len(tasks)
    if n == 0:
        return [[] for _ in range(num_buckets)], [0] * num_buckets, 0

    # Per-bucket quota used while at least one bucket is below it.
    batch_size = n // num_buckets

    # Heaviest first; reverse=True keeps the sort stable for equal weights.
    tasks_sorted = sorted(tasks, key=lambda task: task[0], reverse=True)

    buckets = [[] for _ in range(num_buckets)]
    bucket_weights = [0] * num_buckets
    bucket_counts = [0] * num_buckets

    # Min-heap of (current_weight, bucket_index) holding only the buckets
    # that are still below the quota.  A bucket that fills up is simply not
    # pushed back, so it is never re-examined for later tasks.
    open_heap = [(0, i) for i in range(num_buckets)] if batch_size > 0 else []
    # Built lazily once every bucket reached the quota (n % num_buckets != 0).
    overflow_heap = None

    for weight, idx in tasks_sorted:
        if open_heap:
            _, bucket_idx = heapq.heappop(open_heap)
        else:
            if overflow_heap is None:
                overflow_heap = [(w, i) for i, w in enumerate(bucket_weights)]
                heapq.heapify(overflow_heap)
            _, bucket_idx = heapq.heappop(overflow_heap)

        buckets[bucket_idx].append((weight, idx))
        # Accumulate as native Python numbers so fixed-width NumPy integers
        # cannot wrap around.
        bucket_weights[bucket_idx] += weight.item() if hasattr(weight, "item") else weight
        bucket_counts[bucket_idx] += 1

        entry = (bucket_weights[bucket_idx], bucket_idx)
        if overflow_heap is not None:
            heapq.heappush(overflow_heap, entry)
        elif bucket_counts[bucket_idx] < batch_size:
            heapq.heappush(open_heap, entry)

    # Keep every bucket ordered by the original task index.
    for bucket in buckets:
        bucket.sort(key=lambda item: item[1])

    # Count discontinuities in the reranked order.  The list must NOT be
    # sorted globally first: that would always give a contiguous sequence
    # and therefore zero cuts.
    reranked = [idx for bucket in buckets for _, idx in bucket]
    cuts = sum(1 for prev, cur in zip(reranked, reranked[1:]) if cur != prev + 1)

    return buckets, bucket_weights, cuts


# --- Communication / re-layout helpers ---

def get_send_dict(
    buckets: List[List[Tuple[int, int]]],
    cp_size: int,
    rank: int,
    chunks_per_rank: Optional[int] = None
) -> Dict[int, List[int]]:
    """
    Build the all-to-all send plan of ``rank`` from the global assignment.

    Args:
        buckets: Task assignment of every rank; ``buckets[r]`` holds the
            ``(weight, chunk_index)`` tuples assigned to rank ``r``.
        cp_size: Size of the communication group.
        rank: Rank of the current process.
        chunks_per_rank: Number of consecutive chunks each rank owns before
            balancing.  Defaults to ``cp_size``, which is the layout this
            function originally hard-coded.

    Returns:
        Dict mapping each target rank to the list of local chunk indices
        (positions inside this rank's own data) to send to that rank.
    """
    per_rank = cp_size if chunks_per_rank is None else chunks_per_rank
    send_dict = {i: [] for i in range(cp_size)}
    for target_rank, bucket in enumerate(buckets):
        for _, chunk_idx in bucket:
            owner, local_idx = divmod(int(chunk_idx), per_rank)
            # Only chunks currently owned by this rank are sent by it.
            if owner == rank:
                send_dict[target_rank].append(local_idx)
    return send_dict


def get_recv_dict(
    bucket: List[Tuple[int, int]],
    cp_size: int,
    chunks_per_rank: Optional[int] = None
) -> Dict[int, List[int]]:
    """
    Build the all-to-all receive plan from this rank's own assignment.

    Args:
        bucket: ``(weight, chunk_index)`` tuples assigned to the current rank.
            Chunk indices may be Python ints or NumPy integer scalars.
        cp_size: Size of the communication group.
        chunks_per_rank: Number of consecutive chunks each rank owns before
            balancing.  Defaults to ``cp_size`` (the original assumption).

    Returns:
        Dict mapping each source rank to the list of local positions where
        the chunks received from that rank must be placed.

    Raises:
        KeyError: If a chunk index maps to a source rank outside
            ``range(cp_size)``.
    """
    per_rank = cp_size if chunks_per_rank is None else chunks_per_rank
    recv_dict = {i: [] for i in range(cp_size)}
    for local_pos, (_, chunk_idx) in enumerate(bucket):
        # int() accepts both plain ints and NumPy scalars.
        source_rank = int(chunk_idx) // per_rank
        recv_dict[source_rank].append(local_pos)
    return recv_dict
List[int]] +) -> paddle.Tensor: + """ + 执行 all-to-all 通信,根据 send/recv 字典对 `input_tensor` 进行数据重排。 + 此函数已重构,可统一处理不同维度的张量。 + + Args: + input_tensor (paddle.Tensor): 待重排的张量,如 Q, K, V。 + cp_size (int): 通信组大小。 + cp_group (dist.Group): Paddle 分布式通信组。 + chunk_size (int): 数据块的大小。 + send_dict (Dict): 发送字典。 + recv_dict (Dict): 接收字典。 + + Returns: + paddle.Tensor: 重排后的张量。 + """ + original_shape = input_tensor.shape + B, S = original_shape[0], original_shape[1] + + # 将输入张量统一 reshape 为 3D (B, S, -1) 以便统一处理 + tensor_3d = input_tensor.reshape((B, S, -1)) + HD = tensor_3d.shape[-1] + + # 1. 准备发送数据 (Gather) + # 根据 send_dict,从本地张量中收集需要发送给其他 rank 的数据块 + send_list = [] + for target_rank in range(cp_size): + indices_to_send = send_dict[target_rank] + if indices_to_send: + # 将所有要发往同一个 rank 的数据块拼接在一起 + data_to_send = paddle.concat( + [tensor_3d[:, idx * chunk_size:(idx + 1) * chunk_size, :] for idx in indices_to_send], + axis=1 + ) + send_list.append(data_to_send) + else: + # 注意:NCCL alltoall 不支持大小为 0 的张量,因此发送一个虚拟的、 + # 非常小的张量作为占位符。接收方也需对应接收。 + send_list.append(paddle.zeros((B, 1, HD), dtype=input_tensor.dtype)) + + # 2. 准备接收缓冲区 (Scatter) + # 根据 recv_dict,为从其他 rank 接收的数据准备相应大小的空缓冲区 + recv_list = [] + for source_rank in range(cp_size): + num_chunks_to_recv = len(recv_dict[source_rank]) + if num_chunks_to_recv > 0: + recv_list.append( + paddle.empty((B, chunk_size * num_chunks_to_recv, HD), dtype=input_tensor.dtype) + ) + else: + # 对应发送方的虚拟张量,接收一个同样大小的虚拟缓冲区 + recv_list.append(paddle.empty((B, 1, HD), dtype=input_tensor.dtype)) + + # 3. 执行 All-to-All 通信 + dist.alltoall(out_tensor_list=recv_list, in_tensor_list=send_list, group=cp_group) + + # 4. 
def balance_flashmask_input(
    startend_row_indices: paddle.Tensor,
    cp_size: int,
    cp_rank: int,
    balance_chunk_size: int = 2048,
    q_block_size: int = 128,
    k_block_size: int = 128
) -> Tuple[paddle.Tensor, List[List[Tuple[int, int]]]]:
    """
    Load-balance the FlashMask start/end row indices across a CP group.

    Pipeline: estimate per-chunk workload -> assign chunks to ranks ->
    rerank the index tensor in the new chunk order -> localize the indices
    for the current rank.

    Args:
        startend_row_indices: Original sparse-attention start/end indices.
        cp_size: Size of the communication group.
        cp_rank: Rank of the current process.
        balance_chunk_size: Chunk size used for workload analysis and for
            moving data; it is forwarded to every step so all of them agree.
        q_block_size: Query block size of the FlashAttention kernel.
        k_block_size: Key block size of the FlashAttention kernel.

    Returns:
        A tuple ``(local_startend_row_indices, buckets)``:
        - local_startend_row_indices: indices the current rank has to
          process after balancing and reranking.
        - buckets: the global assignment, to be reused for applying the same
          re-layout to Q, K, V and similar tensors.
    """
    # Step 1: estimate the workload of every chunk.
    paddle.base.core.nvprof_nvtx_push("get_q_workload")
    workload = get_q_workload(startend_row_indices, balance_chunk_size, q_block_size, k_block_size)
    paddle.base.core.nvprof_nvtx_pop()

    # Step 2: greedy heap assignment on the CPU.
    paddle.base.core.nvprof_nvtx_push("assign_tasks_heap")
    # NOTE(review): workload is [1, H, Tchunks, 2]; flattening it yields one
    # task per (head, chunk), so H > 1 would produce repeated chunk indices.
    # Presumably H == 1 here - confirm against the callers.
    tasks_np = workload.reshape([-1, 2]).cpu().numpy()
    buckets, _, _ = assign_tasks_heap(tasks_np, cp_size)
    paddle.base.core.nvprof_nvtx_pop()

    # Step 3: gather the index tensor in the new global chunk order.
    paddle.base.core.nvprof_nvtx_push("startend_row_indices_rerank")
    # Flattening `buckets` gives the new chunk order.
    rerank_indices = np.array([idx for bucket in buckets for _, idx in bucket], dtype=np.int32)
    indices_tensor = paddle.to_tensor(rerank_indices, dtype='int32', place=startend_row_indices.place)

    # The chunk size must be forwarded explicitly: the wrapper defaults to
    # 2048, which would disagree with steps 1 and 4 for any other value.
    startend_row_indices_rerank = indices_rerank_cuda(
        startend_row_indices, indices_tensor, balance_chunk_size
    )
    paddle.base.core.nvprof_nvtx_pop()

    # Step 4: turn the reranked global indices into indices local to this
    # rank's chunks (clipping and offsetting is done by the CUDA op).
    paddle.base.core.nvprof_nvtx_push("indices_to_chunks")
    local_bucket_indices = [x[1] for x in buckets[cp_rank]]
    local_indices_tensor = paddle.to_tensor(local_bucket_indices, dtype='int32', place=startend_row_indices.place)

    local_startend_row_indices = indices_to_chunks_cuda(
        startend_row_indices_rerank, local_indices_tensor, balance_chunk_size
    )
    paddle.base.core.nvprof_nvtx_pop()

    return local_startend_row_indices, buckets
workload + +def indices_to_chunks_cuda(startend_row_indices, bucket_idx, chunksize=2048): + result = cp_balance_ops.indices_to_chunks(startend_row_indices, bucket_idx, chunksize) + return result + +def indices_rerank_cuda(startend_row_indices, indices, balance_chunk_size=2048): + B, H, S, D = startend_row_indices.shape + num_chunks = (S + balance_chunk_size - 1) // balance_chunk_size + startend_row_indices_rerank = cp_balance_ops.indices_rerank(startend_row_indices, indices, B, H, S,D,num_chunks,balance_chunk_size) + return startend_row_indices_rerank diff --git a/flashmask/flash_mask/cp_balance/csrc/cp_balance_utils.cu b/flashmask/flash_mask/cp_balance/csrc/cp_balance_utils.cu new file mode 100644 index 00000000000..200787d0e58 --- /dev/null +++ b/flashmask/flash_mask/cp_balance/csrc/cp_balance_utils.cu @@ -0,0 +1,667 @@ +#include "paddle/extension.h" + +#define CHECK_CUDA_INPUT(x) PD_CHECK(x.is_gpu(), #x " must be a GPU Tensor.") + +int get_kBlockN(int head_size_rounded, bool is_flashmask, bool is_causal, bool has_softcap, + bool is_local, int seqlen_q, int seqlen_k, bool has_lt_end, bool has_ut_start) { + if (head_size_rounded <= 64) { + if (is_flashmask && !is_causal) { + return 96; + } else if ((is_causal && has_softcap) || is_flashmask) { + return 128; + } else { + return 128; + } + } else if (head_size_rounded <= 128) { + if (is_causal || is_local || has_softcap) { + return 128; + } else { + if (seqlen_q >= 1024 || seqlen_k >= 1024) { + return 128; + } else { + return 64; + } + } + } else if (head_size_rounded <= 256) { + if (has_lt_end && has_ut_start) { + return 32; + } else { + return 64; + } + } else { + // 不支持的情况 + throw std::runtime_error("head_size_rounded not supported"); + } +} + +template +__global__ +void scanMaxMinChunkedKernel( + const int *input, int b, int n, int *maxo, int *mino) { + int bid = threadIdx.y + blockIdx.y * blockDim.y; + if (bid >= b) return; + int i_offset = bid * n; + input = input + i_offset; + const int nblock_seqlen = ((n + 
kBlockN - 1) / kBlockN + 3) & 0xfffffffc; + constexpr int nums = (kBlockN + 31) / 32; + int warpId = blockIdx.x; + int tid = threadIdx.x; + int lane_id = threadIdx.x % 32; + int maxv, minv; + int idx = warpId * kBlockN + tid; + if (warpId * kBlockN + kBlockN > n) { + maxv = 0; + minv = INT_MAX; + #pragma unroll + for (int i = 0; i < nums; i++) { + if (idx < n && lane_id + i * 32 < kBlockN) { + maxv = max(maxv, input[idx]); + minv = min(minv, input[idx]); + } + idx += 32; + } + } else { + maxv = 0; + minv = INT_MAX; + #pragma unroll + for (int i = 0; i < nums; i++) { + if(lane_id + i * 32 < kBlockN) { + maxv = max(maxv, input[idx]); + minv = min(minv, input[idx]); + idx += 32; + } + } + } + __syncwarp(); + maxv = __reduce_max_sync(0xffffffff, maxv); + minv = __reduce_min_sync(0xffffffff, minv); + if (tid == 0) { + maxo[bid * nblock_seqlen + warpId] = maxv; + mino[bid * nblock_seqlen + warpId] = minv; + } +} + +// Enum for pointer dispatching in reduce_workload_kernel +enum PtrDispatch { SINGLE_PTR = 1, DUAL_PTR = 2, FULL_PTR = 4 }; + +template +__global__ void reduce_workload_kernel( + const int* LTStartMax, const int* LTStartMin, + const int* LTEndMax, const int* LTEndMin, + const int* UTStartMax, const int* UTStartMin, + const int* UTEndMax, const int* UTEndMin, + int* workload, // [B, H, Tr, 1] + int BH, int Tr, int Tc, int S, + int Br // m_block_size +) { + int bh = blockIdx.y; + int tr = blockIdx.x; + int tc = threadIdx.x; + int warpId = threadIdx.x / 32; + int laneId = threadIdx.x % 32; + + if(tr >= Tr) return; + + int wl = 0; + bool fully_masked = true; + bool partially_masked = false; + int lt_start_max = INT_MAX; + int lt_start_min = INT_MAX; + int lt_end_max = INT_MAX; + int lt_end_min = INT_MAX; + int ut_start_max = INT_MIN; + int ut_start_min = INT_MIN; + int ut_end_max = INT_MIN; + int ut_end_min = INT_MIN; + + __shared__ int smem[32]; + + const int idx = bh * Tc + tc; + const int q_idx = bh * Tr + tr; + + // m_block_s/e: Q block boundaries within a 
single (batch, head) — use tr only, not q_idx. + // q_idx includes the bh offset for output indexing, but mask values are in [0, S) per (b,h). + const int m_block_s = tr * kBlockM; + const int m_block_e = m_block_s + kBlockM < S ? m_block_s + kBlockM : S; + + lt_start_max = tc < Tc ? LTStartMax[idx] : INT_MAX; + lt_start_min = tc < Tc ? LTStartMin[idx] : INT_MAX; + + // 分支展开 + if constexpr (PTR_DISPATCH_TAG == FULL_PTR) { + lt_end_max = tc < Tc ? LTEndMax[idx] : INT_MAX; + lt_end_min = tc < Tc ? LTEndMin[idx] : INT_MAX; + ut_start_max = tc < Tc ? UTStartMax[idx] : INT_MIN; + ut_start_min = tc < Tc ? UTStartMin[idx] : INT_MIN; + ut_end_max = tc < Tc ? UTEndMax[idx] : INT_MIN; + ut_end_min = tc < Tc ? UTEndMin[idx] : INT_MIN; + + fully_masked = (m_block_s >= lt_start_max && m_block_e <= lt_end_min) || + (m_block_s >= ut_start_max && m_block_e <= ut_end_min); + partially_masked = (m_block_s < lt_end_max && m_block_e > lt_start_min) || + (m_block_s < ut_end_max && m_block_e > ut_start_min); + } + else if constexpr (PTR_DISPATCH_TAG == DUAL_PTR) { + if constexpr (is_causal) { + lt_end_max = tc < Tc ? LTEndMax[idx] : INT_MAX; + lt_end_min = tc < Tc ? LTEndMin[idx] : INT_MAX; + fully_masked = m_block_s >= lt_start_max && m_block_e <= lt_end_min; + partially_masked = m_block_s < lt_end_max && m_block_e > lt_start_min; + } else { + ut_end_max = tc < Tc ? UTEndMax[idx] : INT_MIN; + ut_end_min = tc < Tc ? UTEndMin[idx] : INT_MIN; + fully_masked = (m_block_s >= lt_start_max) || (m_block_e <= ut_end_min); + partially_masked = (m_block_e > lt_start_min) || (m_block_s < ut_end_max); + } + } + else if constexpr (PTR_DISPATCH_TAG == SINGLE_PTR) { + fully_masked = m_block_s >= lt_start_max; + partially_masked = m_block_e > lt_start_min; + } + + if(tc >= Tc){ + fully_masked = true; + partially_masked = false; + } + wl = fully_masked ? 
0 : 1; + + unsigned mask = 0xffffffff; + // warp reduce sum + int wl_sum = wl; + for (int offset = 16; offset > 0; offset >>= 1) { + wl_sum += __shfl_down_sync(mask, wl_sum, offset); + } + if (laneId == 0) { + smem[warpId] = wl_sum; + } + __syncthreads(); + + if (threadIdx.x < 32) { + int val = (threadIdx.x < (blockDim.x + 31)/32) ? smem[threadIdx.x] : 0; + for (int offset = 16; offset > 0; offset >>= 1) { + val += __shfl_down_sync(mask, val, offset); + } + if (threadIdx.x == 0) { + workload[q_idx] = val; + } + } +} + +__global__ void indices_to_chunks_kernel( + const int* startend_row_indices, + const int* chunk_bucket_indices, + int* chunked_result, + int num_rows, + int num_buckets, + int chunk_size) +{ + int row = blockIdx.x * blockDim.x + threadIdx.x; + if (row >= num_rows) return; + + int max_chunk_index = 0; + int row_val = startend_row_indices[row]; + + for (int bucket = 0; bucket < num_buckets; ++bucket) { + int bucket_idx = chunk_bucket_indices[bucket]; + int chunk_start = bucket_idx * chunk_size; + int local_index = row_val - chunk_start; + local_index = max(local_index, 0); + local_index = min(local_index, chunk_size); + + if (local_index > 0) { + local_index += bucket * chunk_size; + } + + if (bucket == 0 || local_index > max_chunk_index) { + max_chunk_index = local_index; + } + } + chunked_result[row] = max_chunk_index; +} + +__global__ void indices_rerank_kernel( + const int* startend_row_indices, + int* output_reranked_indices, + const int* chunk_indices, + int batch_size, + int num_heads, + int seq_len, + int feature_dim, + int num_chunks, + int chunk_size +) { + int output_seq_len = num_chunks * chunk_size; + int total_elements = batch_size * output_seq_len * num_heads * feature_dim; + int flat_idx = blockIdx.x * blockDim.x + threadIdx.x; + if (flat_idx >= total_elements) return; + + int d = flat_idx % feature_dim; + int s_out = (flat_idx / feature_dim) % output_seq_len; + int h = (flat_idx / feature_dim / output_seq_len) % num_heads; + int b = 
(flat_idx / feature_dim / output_seq_len / num_heads) % batch_size; + + int chunk_id = s_out / chunk_size; + int chunk_offset = s_out % chunk_size; + int src_s = chunk_indices[chunk_id] * chunk_size + chunk_offset; + + if (src_s >= seq_len) return; + + int src_flat_idx = ((b * num_heads + h) * seq_len + src_s) * feature_dim + d; + int dst_flat_idx = flat_idx; + + output_reranked_indices[dst_flat_idx] = startend_row_indices[src_flat_idx]; +} + + + + +// ============================================================================ +// ScanMaxMin Operator +// ============================================================================ + +std::vector scan_max_min_cuda( + const paddle::Tensor& input, + const int head_size_rounded, + const int seq_len_q, + const int seq_len_k, + const int blocksize = -1, + const bool is_causal = false, + const float softcap = 0.0, + const int window_size_left = 0, + const int window_size_right = 0) { + CHECK_CUDA_INPUT(input); + + // The scanMaxMin kernel treats input as flat [batch, seqlen]. + // Input tensor is [B, H, S] from Python (H is always 1 in practice; after squeeze(-1) from [B,H,S,D]). + // We compute total_batch = product of all dims except the last, so it handles [B,S], [B,H,S] etc. + const auto dims = input.shape(); + const auto ndim = dims.size(); + int64_t total_batch = 1; + for (int i = 0; i < ndim - 1; i++) total_batch *= dims[i]; + const auto num_sequences = dims[ndim - 1]; + // head_dim only used by get_kBlockN heuristic; safe default when blocksize is explicit + const auto head_dim = (ndim >= 4) ? 
dims[3] : 1; + + PADDLE_ENFORCE_EQ( + num_sequences, + seq_len_k, + common::errors::InvalidArgument( + "Input tensor's third dimension (num_sequences) must be equal to seq_len_k.")); + + const bool is_local = (window_size_left >= 0 || window_size_right >= 0) && !is_causal; + const bool is_flashmask = true; + const bool has_softcap = softcap > 0.0; + const bool has_lt_end = !is_causal && head_dim >= 2; + const bool has_ut_start = head_dim == 4; + + const int kernel_block_size_n = + blocksize > 0 ? blocksize : get_kBlockN(head_size_rounded, + is_flashmask, + is_causal, + has_softcap, + is_local, + seq_len_q, + seq_len_k, + has_lt_end, + has_ut_start); + + // Pad the number of blocks to be a multiple of 4 for performance + const int num_blocks_seqlen = + ((num_sequences + kernel_block_size_n - 1) / kernel_block_size_n + 3) & 0xfffffffc; + + std::vector output_shape = {total_batch, num_blocks_seqlen}; + auto max_output = paddle::empty(output_shape, input.dtype(), input.place()); + auto min_output = paddle::empty(output_shape, input.dtype(), input.place()); + + // Launch kernel + dim3 block_dim(32, 4); + dim3 grid_dim((num_sequences + kernel_block_size_n - 1) / kernel_block_size_n, + (total_batch + 3) / 4); + + const cudaStream_t stream = input.stream(); + + switch (kernel_block_size_n) { + case 32: + scanMaxMinChunkedKernel<32><<>>( + input.data(), total_batch, num_sequences, + max_output.data(), min_output.data()); + break; + case 64: + scanMaxMinChunkedKernel<64><<>>( + input.data(), total_batch, num_sequences, + max_output.data(), min_output.data()); + break; + case 96: + scanMaxMinChunkedKernel<96><<>>( + input.data(), total_batch, num_sequences, + max_output.data(), min_output.data()); + break; + case 128: + scanMaxMinChunkedKernel<128><<>>( + input.data(), total_batch, num_sequences, + max_output.data(), min_output.data()); + break; + default: + PD_THROW("Unsupported kernel_block_size_n: %d", kernel_block_size_n); + } + return {max_output, min_output}; +} + 
+std::vector ScanMaxMin( + const paddle::Tensor& input, + int head_size_rounded, + int seq_len_q, + int seq_len_k, + int blocksize, + bool is_causal, + float softcap, + int window_size_left, + int window_size_right) { +#ifdef PADDLE_WITH_CUDA + if (input.is_gpu()) { + return scan_max_min_cuda(input, + head_size_rounded, + seq_len_q, + seq_len_k, + blocksize, + is_causal, + softcap, + window_size_left, + window_size_right); + } +#endif + PD_THROW("Unsupported device: ScanMaxMin operator is only available for CUDA."); +} + + +// ============================================================================ +// ReduceWorkload Operator +// ============================================================================ + +template +void launch_reduce_workload_kernel( + const paddle::Tensor& lt_start_max, + const paddle::Tensor& lt_start_min, + const paddle::optional& lt_end_max, + const paddle::optional& lt_end_min, + const paddle::optional& ut_start_max, + const paddle::optional& ut_start_min, + const paddle::optional& ut_end_max, + const paddle::optional& ut_end_min, + paddle::Tensor& workload, + int batch_times_heads, + int num_row_blocks, + int num_col_blocks, + int stride, + int row_block_size, + bool is_causal, + cudaStream_t stream) { + + dim3 block_dim(1024, 1); + dim3 grid_dim(num_row_blocks, batch_times_heads); + + int ptr_dispatch_tag = SINGLE_PTR; + if (lt_end_max || ut_end_max) { + ptr_dispatch_tag = DUAL_PTR; + if (ut_start_max) { + ptr_dispatch_tag = FULL_PTR; + } + } + + int* workload_ptr = workload.data(); + const int* lt_start_max_ptr = lt_start_max.data(); + const int* lt_start_min_ptr = lt_start_min.data(); + const int* lt_end_max_ptr = lt_end_max ? lt_end_max.get().data() : nullptr; + const int* lt_end_min_ptr = lt_end_min ? lt_end_min.get().data() : nullptr; + const int* ut_start_max_ptr = ut_start_max ? ut_start_max.get().data() : nullptr; + const int* ut_start_min_ptr = ut_start_min ? 
ut_start_min.get().data() : nullptr; + const int* ut_end_max_ptr = ut_end_max ? ut_end_max.get().data() : nullptr; + const int* ut_end_min_ptr = ut_end_min ? ut_end_min.get().data() : nullptr; + + if (ptr_dispatch_tag == FULL_PTR) { + reduce_workload_kernel<<>>( + lt_start_max_ptr, lt_start_min_ptr, lt_end_max_ptr, lt_end_min_ptr, + ut_start_max_ptr, ut_start_min_ptr, ut_end_max_ptr, ut_end_min_ptr, + workload_ptr, batch_times_heads, num_row_blocks, num_col_blocks, stride, row_block_size); + } else if (ptr_dispatch_tag == DUAL_PTR) { + if (is_causal) { + reduce_workload_kernel<<>>( + lt_start_max_ptr, lt_start_min_ptr, lt_end_max_ptr, lt_end_min_ptr, + ut_start_max_ptr, ut_start_min_ptr, ut_end_max_ptr, ut_end_min_ptr, + workload_ptr, batch_times_heads, num_row_blocks, num_col_blocks, stride, row_block_size); + } else { + reduce_workload_kernel<<>>( + lt_start_max_ptr, lt_start_min_ptr, lt_end_max_ptr, lt_end_min_ptr, + ut_start_max_ptr, ut_start_min_ptr, ut_end_max_ptr, ut_end_min_ptr, + workload_ptr, batch_times_heads, num_row_blocks, num_col_blocks, stride, row_block_size); + } + } else if (ptr_dispatch_tag == SINGLE_PTR) { + reduce_workload_kernel<<>>( + lt_start_max_ptr, lt_start_min_ptr, lt_end_max_ptr, lt_end_min_ptr, + ut_start_max_ptr, ut_start_min_ptr, ut_end_max_ptr, ut_end_min_ptr, + workload_ptr, batch_times_heads, num_row_blocks, num_col_blocks, stride, row_block_size); + } else { + PD_THROW("Unknown pointer dispatch tag."); + } +} + +std::vector reduce_workload_cuda( + const paddle::Tensor& lt_start_max, + const paddle::Tensor& lt_start_min, + const paddle::optional& lt_end_max, + const paddle::optional& lt_end_min, + const paddle::optional& ut_start_max, + const paddle::optional& ut_start_min, + const paddle::optional& ut_end_max, + const paddle::optional& ut_end_min, + int batch_size, + int num_heads, + int num_row_blocks, + int num_col_blocks, + int stride, + int row_block_size, + bool is_causal, + int m_block_size) { + + const int kBlockM = 
m_block_size; + const int batch_times_heads = batch_size * num_heads; + + // Use the actual padded stride from scanMaxMin output, not the caller's unpadded num_col_blocks. + // scanMaxMin pads nblock_seqlen to a multiple of 4 for performance; if num_col_blocks differs + // from the tensor's actual column count, the flat index bh*Tc+tc would be wrong. + const int Tc_stride = static_cast(lt_start_max.shape()[1]); + + // Allocate output tensor + std::vector output_shape = {batch_size, num_heads, num_row_blocks, 1}; + auto workload = paddle::empty(output_shape, lt_start_max.dtype(), lt_start_max.place()); + + cudaStream_t stream = lt_start_max.stream(); + + switch (kBlockM) { + case 64: + launch_reduce_workload_kernel<64>( + lt_start_max, lt_start_min, lt_end_max, lt_end_min, ut_start_max, + ut_start_min, ut_end_max, ut_end_min, workload, batch_times_heads, + num_row_blocks, Tc_stride, stride, row_block_size, is_causal, stream); + break; + case 96: + launch_reduce_workload_kernel<96>( + lt_start_max, lt_start_min, lt_end_max, lt_end_min, ut_start_max, + ut_start_min, ut_end_max, ut_end_min, workload, batch_times_heads, + num_row_blocks, Tc_stride, stride, row_block_size, is_causal, stream); + break; + case 128: + launch_reduce_workload_kernel<128>( + lt_start_max, lt_start_min, lt_end_max, lt_end_min, ut_start_max, + ut_start_min, ut_end_max, ut_end_min, workload, batch_times_heads, + num_row_blocks, Tc_stride, stride, row_block_size, is_causal, stream); + break; + default: + PD_THROW("Unsupported m_block_size: %d", kBlockM); + } + return {workload}; +} + +std::vector ReduceWorkloadOp( + const paddle::Tensor& lt_start_max, + const paddle::Tensor& lt_start_min, + const paddle::optional& lt_end_max, + const paddle::optional& lt_end_min, + const paddle::optional& ut_start_max, + const paddle::optional& ut_start_min, + const paddle::optional& ut_end_max, + const paddle::optional& ut_end_min, + int batch_size, + int num_heads, + int num_row_blocks, + int num_col_blocks, + 
int stride, + int row_block_size, + bool is_causal, + int m_block_size) { +#ifdef PADDLE_WITH_CUDA + if (lt_start_max.is_gpu()) { + return reduce_workload_cuda(lt_start_max, + lt_start_min, + lt_end_max, + lt_end_min, + ut_start_max, + ut_start_min, + ut_end_max, + ut_end_min, + batch_size, + num_heads, + num_row_blocks, + num_col_blocks, + stride, + row_block_size, + is_causal, + m_block_size); + } +#endif + PD_THROW("Unsupported device: ReduceWorkload operator is only available for CUDA."); +} + + +// ============================================================================ +// IndicesToChunks & IndicesRerank Operators +// ============================================================================ + +std::vector IndicesToChunksOp( + const paddle::Tensor& row_indices, + const paddle::Tensor& chunk_bucket_indices, + int chunk_size) { +#ifdef PADDLE_WITH_CUDA + PADDLE_ENFORCE_EQ(row_indices.is_gpu(), true, + common::errors::InvalidArgument("Input 'row_indices' must be a CUDA tensor.")); + + auto chunked_result = paddle::empty_like(row_indices); + + const int num_rows = row_indices.numel(); + const int num_buckets = chunk_bucket_indices.numel(); + const int num_threads_per_block = 256; + const int num_blocks = (num_rows + num_threads_per_block - 1) / num_threads_per_block; + + indices_to_chunks_kernel<<>>( + row_indices.data(), + chunk_bucket_indices.data(), + chunked_result.data(), + num_rows, + num_buckets, + chunk_size); + + return {chunked_result}; +#else + PD_THROW("Unsupported device: IndicesToChunks operator is only available for CUDA."); +#endif +} + +std::vector IndicesRerankOp( + const paddle::Tensor& input_row_indices, + const paddle::Tensor& chunk_indices, + int batch_size, + int num_heads, + int seq_len, + int feature_dim, + int num_chunks, + int chunk_size) { +#ifdef PADDLE_WITH_CUDA + PADDLE_ENFORCE_EQ(input_row_indices.is_gpu(), true, + common::errors::InvalidArgument("Input 'input_row_indices' must be a CUDA tensor.")); + + const int 
output_seq_len = num_chunks * chunk_size; + auto reranked_indices = paddle::empty({batch_size, num_heads, output_seq_len, feature_dim}, + input_row_indices.dtype(), + input_row_indices.place()); + + const int total_elements = batch_size * output_seq_len * num_heads * feature_dim; + const int num_threads_per_block = 256; + const int num_blocks = (total_elements + num_threads_per_block - 1) / num_threads_per_block; + + indices_rerank_kernel<<>>( + input_row_indices.data(), + reranked_indices.data(), + chunk_indices.data(), + batch_size, + num_heads, + seq_len, + feature_dim, + num_chunks, + chunk_size); + + return {reranked_indices}; +#else + PD_THROW("Unsupported device: IndicesRerank operator is only available for CUDA."); +#endif +} + + +// ============================================================================ +// Operator Registrations +// ============================================================================ + +PD_BUILD_OP(scan_max_min) + .Inputs({"Input"}) + .Outputs({"MaxOut", "MinOut"}) + .Attrs({"head_size_rounded: int", + "seq_len_q: int", + "seq_len_k: int", + "blocksize: int", + "is_causal: bool", + "softcap: float", + "window_size_left: int", + "window_size_right: int"}) + .SetKernelFn(PD_KERNEL(ScanMaxMin)); + +PD_BUILD_OP(reduce_workload) + .Inputs({"LTStartMax", "LTStartMin", + paddle::Optional("LTEndMax"), paddle::Optional("LTEndMin"), + paddle::Optional("UTStartMax"), paddle::Optional("UTStartMin"), + paddle::Optional("UTEndMax"), paddle::Optional("UTEndMin")}) + .Outputs({"Workload"}) + .Attrs({"batch_size: int", + "num_heads: int", + "num_row_blocks: int", + "num_col_blocks: int", + "stride: int", + "row_block_size: int", + "is_causal: bool", + "m_block_size: int"}) + .SetKernelFn(PD_KERNEL(ReduceWorkloadOp)); + +PD_BUILD_OP(indices_to_chunks) + .Inputs({"RowIndices", "ChunkBucketIndices"}) + .Outputs({"ChunkedResult"}) + .Attrs({"chunk_size: int"}) + .SetKernelFn(PD_KERNEL(IndicesToChunksOp)); + +PD_BUILD_OP(indices_rerank) + 
.Inputs({"InputRowIndices", "ChunkIndices"}) + .Outputs({"RerankedIndices"}) + .Attrs({"batch_size: int", + "num_heads: int", + "seq_len: int", + "feature_dim: int", + "num_chunks: int", + "chunk_size: int"}) + .SetKernelFn(PD_KERNEL(IndicesRerankOp)); \ No newline at end of file diff --git a/flashmask/flash_mask/cp_balance/csrc/setup.py b/flashmask/flash_mask/cp_balance/csrc/setup.py new file mode 100644 index 00000000000..4297c58bb57 --- /dev/null +++ b/flashmask/flash_mask/cp_balance/csrc/setup.py @@ -0,0 +1,120 @@ +import os +import subprocess +import shutil +import re + + +def get_version_from_txt(): + version_file = os.path.join(os.path.dirname(__file__), "version.txt") + with open(version_file, "r") as f: + version = f.read().strip() + return version + + +def custom_version_scheme(version): + base_version = get_version_from_txt() + date_str = ( + subprocess.check_output( + ["git", "log", "-1", "--format=%cd", "--date=format:%Y%m%d"] + ) + .decode() + .strip() + ) + return f"{base_version}.dev{date_str}" + + +def no_local_scheme(version): + return "" + + +def change_pwd(): + """change_pwd""" + path = os.path.dirname(__file__) + if path: + os.chdir(path) + +def get_cuda_version(): + nvcc_path = shutil.which("nvcc") + if nvcc_path is None: + raise FileNotFoundError( + "nvcc command not found. Please make sure CUDA toolkit is installed and nvcc is in PATH." 
+ ) + + result = subprocess.run( + ["nvcc", "--version"], + capture_output=True, + text=True, + check=True, + ) + version_output = result.stdout + + match = re.search(r"release (\d+)\.(\d+)", version_output) + if not match: + raise ValueError( + f"Cannot parse CUDA version from nvcc output:\n{version_output}" + ) + cuda_major = int(match.group(1)) + cuda_minor = int(match.group(2)) + return cuda_major, cuda_minor + + +def setup_ops_extension(): + from paddle.utils.cpp_extension import CUDAExtension, setup + + nvcc_args = [ + "-O3", + "-U__CUDA_NO_HALF_OPERATORS__", + "-U__CUDA_NO_HALF_CONVERSIONS__", + "-U__CUDA_NO_BFLOAT16_OPERATORS__", + "-U__CUDA_NO_BFLOAT16_CONVERSIONS__", + "-U__CUDA_NO_BFLOAT162_OPERATORS__", + "-U__CUDA_NO_BFLOAT162_CONVERSIONS__", + "--expt-relaxed-constexpr", + "--expt-extended-lambda", + "-maxrregcount=32", + "-lineinfo", + "-DCUTLASS_DEBUG_TRACE_LEVEL=0", + "-gencode=arch=compute_80,code=sm_80", + "-gencode=arch=compute_90a,code=sm_90a", + "-gencode=arch=compute_100,code=sm_100", + "-DNDEBUG", + ] + cuda_major, cuda_minor = get_cuda_version() + if cuda_major < 12: + raise ValueError( + f"CUDA version must be >= 12. 
Detected version: {cuda_major}.{cuda_minor}" + ) + if cuda_major == 12 and cuda_minor < 8: + nvcc_args = [arg for arg in nvcc_args if "compute_100" not in arg] + + ext_module = CUDAExtension( + sources=[ + # cpp files + # cuda files + "./cp_balance_utils.cu", + ], + include_dirs=[ + os.path.join(os.getcwd(), "./"), + ], + extra_compile_args={ + "cxx": [ + "-O3", + "-w", + "-Wno-abi", + "-fPIC", + "-std=c++17", + ], + "nvcc": nvcc_args, + }, + ) + + change_pwd() + setup( + name="flashmask_cpbalance_cudaops", + ext_modules=[ext_module], + version="0.0.1", + setup_requires=["setuptools_scm"], + ) + + +setup_ops_extension() \ No newline at end of file diff --git a/flashmask/setup.py b/flashmask/setup.py index b6940550a88..e074eaf5274 100644 --- a/flashmask/setup.py +++ b/flashmask/setup.py @@ -76,7 +76,8 @@ def _get_version(): # ============================================================ # Packages: exclude modules not being built # ============================================================ -exclude_packages = ['build', 'build.*', 'tests', 'tests.*'] +exclude_packages = ['build', 'build.*', 'tests', 'tests.*', + 'flash_mask.cp_balance.csrc', 'flash_mask.cp_balance.csrc.*'] if not BUILD_FA3: exclude_packages += [ 'flash_mask.flashmask_attention_v3', @@ -393,3 +394,25 @@ def _get_cuda_version(): paddle_setup(**setup_kwargs, ext_modules=ext_modules) else: setuptools_setup(**setup_kwargs) + +# ============================================================ +# CP Balance: CUDA extension (built via its own setup.py after main setup) +# Paddle's cpp_extension.setup only supports 1 Extension per call, +# so we invoke cp_balance's setup.py as a subprocess. 
+# ============================================================ +if BUILD_FA3: + cp_balance_csrc_dir = os.path.join(FLASH_MASK_DIR, 'cp_balance', 'csrc') + print("[flashmask] Building CP Balance CUDA extension...") + result = subprocess.run( + [sys.executable, 'setup.py', 'install'], + cwd=cp_balance_csrc_dir, + capture_output=True, + text=True, + ) + if result.returncode != 0: + print(f"[flashmask] CP Balance build STDERR:\n{result.stderr}") + raise RuntimeError( + f"Failed to build CP Balance CUDA extension.\n" + f"You can build it manually: cd {cp_balance_csrc_dir} && python setup.py install" + ) + print("[flashmask] CP Balance CUDA extension built successfully.") From f50215bcff9d300fa10ba03ad765b10352601fbf Mon Sep 17 00:00:00 2001 From: Enigmatisms Date: Thu, 9 Apr 2026 17:01:29 +0800 Subject: [PATCH 2/6] [Chore] Add copyright --- flashmask/flash_mask/cp_balance/__init__.py | 14 ++++++++++++++ flashmask/flash_mask/cp_balance/cp_balance.py | 14 ++++++++++++++ .../cp_balance/cp_balance_cuda_kernels.py | 14 ++++++++++++++ .../flash_mask/cp_balance/csrc/cp_balance_utils.cu | 14 ++++++++++++++ flashmask/flash_mask/cp_balance/csrc/setup.py | 14 ++++++++++++++ 5 files changed, 70 insertions(+) diff --git a/flashmask/flash_mask/cp_balance/__init__.py b/flashmask/flash_mask/cp_balance/__init__.py index 96570f1e679..6804de4266f 100644 --- a/flashmask/flash_mask/cp_balance/__init__.py +++ b/flashmask/flash_mask/cp_balance/__init__.py @@ -1,2 +1,16 @@ +# Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + from .cp_balance import balance_flashmask_input from .cp_balance_cuda_kernels import indices_rerank_cuda, indices_to_chunks_cuda diff --git a/flashmask/flash_mask/cp_balance/cp_balance.py b/flashmask/flash_mask/cp_balance/cp_balance.py index 4175262b327..504f0fc7266 100644 --- a/flashmask/flash_mask/cp_balance/cp_balance.py +++ b/flashmask/flash_mask/cp_balance/cp_balance.py @@ -1,3 +1,17 @@ +# Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import heapq import paddle import numpy as np diff --git a/flashmask/flash_mask/cp_balance/cp_balance_cuda_kernels.py b/flashmask/flash_mask/cp_balance/cp_balance_cuda_kernels.py index f0487a445db..ffc4a67e3b8 100644 --- a/flashmask/flash_mask/cp_balance/cp_balance_cuda_kernels.py +++ b/flashmask/flash_mask/cp_balance/cp_balance_cuda_kernels.py @@ -1,3 +1,17 @@ +# Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import flashmask_cpbalance_cudaops as cp_balance_ops def scanMaxMinChunkedKernel(input_tensor, Bc, B, H, S): diff --git a/flashmask/flash_mask/cp_balance/csrc/cp_balance_utils.cu b/flashmask/flash_mask/cp_balance/csrc/cp_balance_utils.cu index 200787d0e58..231e00a4aaa 100644 --- a/flashmask/flash_mask/cp_balance/csrc/cp_balance_utils.cu +++ b/flashmask/flash_mask/cp_balance/csrc/cp_balance_utils.cu @@ -1,3 +1,17 @@ +// Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + #include "paddle/extension.h" #define CHECK_CUDA_INPUT(x) PD_CHECK(x.is_gpu(), #x " must be a GPU Tensor.") diff --git a/flashmask/flash_mask/cp_balance/csrc/setup.py b/flashmask/flash_mask/cp_balance/csrc/setup.py index 4297c58bb57..eac80bf937a 100644 --- a/flashmask/flash_mask/cp_balance/csrc/setup.py +++ b/flashmask/flash_mask/cp_balance/csrc/setup.py @@ -1,3 +1,17 @@ +# Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import os import subprocess import shutil From 2e31ccb9e18478a38953c834520acd0ca5afb6e0 Mon Sep 17 00:00:00 2001 From: Enigmatisms Date: Fri, 10 Apr 2026 11:18:16 +0800 Subject: [PATCH 3/6] [Major] Single package management with build_ext and .so 'hacking' --- flashmask/flash_mask/cp_balance/.gitignore | 1 + flashmask/flash_mask/cp_balance/cp_balance.py | 17 --- .../cp_balance/cp_balance_cuda_kernels.py | 8 +- flashmask/setup.py | 125 +++++++++++++++--- 4 files changed, 108 insertions(+), 43 deletions(-) create mode 100644 flashmask/flash_mask/cp_balance/.gitignore diff --git a/flashmask/flash_mask/cp_balance/.gitignore b/flashmask/flash_mask/cp_balance/.gitignore new file mode 100644 index 00000000000..b9d6fe4f85b --- /dev/null +++ b/flashmask/flash_mask/cp_balance/.gitignore @@ -0,0 +1 @@ +flashmask_cpbalance_cudaops.py \ No newline at end of file diff --git a/flashmask/flash_mask/cp_balance/cp_balance.py b/flashmask/flash_mask/cp_balance/cp_balance.py index 504f0fc7266..dbe133d8a99 100644 --- a/flashmask/flash_mask/cp_balance/cp_balance.py +++ b/flashmask/flash_mask/cp_balance/cp_balance.py @@ -17,25 +17,8 @@ import numpy as np from .cp_balance_cuda_kernels import scanMaxMinChunkedKernel, reduce_workload, indices_to_chunks_cuda, indices_rerank_cuda import paddle.distributed as dist -import hashlib from typing import List, Tuple, Dict, Optional -# --- 调试辅助函数 --- - -def save_tensor(x: paddle.Tensor, name: str): - """将 
Paddle Tensor 保存为 txt 文件,用于调试。""" - x_np = x.numpy() - np.savetxt(f'{name}.txt', x_np.reshape(-1, x_np.shape[-1]), fmt='%d') - -def tensor_md5(tensor: paddle.Tensor) -> str: - """计算 Paddle Tensor 的 MD5 哈希值,用于验证数据一致性。""" - x_bytes = tensor.numpy().tobytes() - md5_hash = hashlib.md5(x_bytes).hexdigest() - print(f"Tensor MD5: {md5_hash}") - return md5_hash - -# --- 核心工作负载计算与分配 --- - def get_q_workload( start_row_indices: paddle.Tensor, q_chunk_size: int, diff --git a/flashmask/flash_mask/cp_balance/cp_balance_cuda_kernels.py b/flashmask/flash_mask/cp_balance/cp_balance_cuda_kernels.py index ffc4a67e3b8..1113164920d 100644 --- a/flashmask/flash_mask/cp_balance/cp_balance_cuda_kernels.py +++ b/flashmask/flash_mask/cp_balance/cp_balance_cuda_kernels.py @@ -12,8 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -import flashmask_cpbalance_cudaops as cp_balance_ops - +from . import flashmask_cpbalance_cudaops as cp_balance_ops + def scanMaxMinChunkedKernel(input_tensor, Bc, B, H, S): maxo,mino = cp_balance_ops.scan_max_min( input_tensor, @@ -41,12 +41,12 @@ def reduce_workload(start_row_maxmin_indice_list, B, H, Tr, Tc, Br, S): UTEndMax, UTEndMin, ) = start_row_maxmin_indice_list - + workload = cp_balance_ops.reduce_workload( LTStartMax, LTStartMin, LTEndMax, LTEndMin, UTStartMax, UTStartMin, UTEndMax, UTEndMin, B, H, Tr, Tc, S, Br, False, 128 ) - + return workload def indices_to_chunks_cuda(startend_row_indices, bucket_idx, chunksize=2048): diff --git a/flashmask/setup.py b/flashmask/setup.py index e074eaf5274..8056e178fbf 100644 --- a/flashmask/setup.py +++ b/flashmask/setup.py @@ -30,6 +30,8 @@ import os import sys import subprocess +import shutil +import glob from setuptools import setup as setuptools_setup, find_packages @@ -376,6 +378,106 @@ def _get_cuda_version(): ) ) +# ============================================================ +# CUDA submodule builder +# 
============================================================ +# Some submodules need different nvcc flags than the main FA3 extension +# (e.g., cp_balance needs sm_80/sm_90a/sm_100 while FA3 targets sm_90a only). +# Paddle's CUDAExtension applies the same flags to ALL sources, so these +# submodules must be compiled independently. This function handles: +# 1. Run the submodule's own setup.py build_ext +# 2. Copy the resulting .so into the Python package directory +# 3. Return the package name for package_data (so the .so ships in the wheel) +# +# To add a new submodule, just call _build_cuda_submodule() and append +# the returned package name to _submodule_package_data. + +def _build_cuda_submodule(name, csrc_dir, pkg_dir): + """Build a CUDA submodule and copy outputs into its package directory. + + Paddle's build_ext produces in build/: + - {module_name}.so — compiled CUDA binary (no _pd_ suffix) + - {module_name}.py — Python wrapper that loads {module_name}_pd_.so + The wrapper hardcodes the _pd_ filename, so we rename the .so when copying. + + Args: + name: Human-readable name for log messages. + csrc_dir: Directory containing the submodule's setup.py. + pkg_dir: Python package directory to copy outputs into. + + Returns: + Package name (dot-separated) for package_data, or None if skipped. 
+ """ + if not os.path.isdir(csrc_dir): + print(f"[flashmask] {name}: csrc directory not found, skipping.") + return None + + print(f"[flashmask] Building {name} CUDA extension...") + result = subprocess.run( + [sys.executable, 'setup.py', 'build_ext'], + cwd=csrc_dir, capture_output=True, text=True, + ) + if result.returncode != 0: + print(f"[flashmask] {name} build STDERR:\n{result.stderr}") + raise RuntimeError( + f"Failed to build {name} CUDA extension.\n" + f"Build manually: cd {csrc_dir} && python setup.py build_ext" + ) + + # Find the .so and wrapper .py in build/ + so_files = glob.glob(os.path.join(csrc_dir, 'build', '**', '*.so'), recursive=True) + if not so_files: + raise RuntimeError( + f"{name} build_ext succeeded but no .so found under " + f"{os.path.join(csrc_dir, 'build')}/" + ) + so_path = so_files[0] + module_name = os.path.basename(so_path).replace('.so', '') + wrapper_path = os.path.join(os.path.dirname(so_path), f'{module_name}.py') + if not os.path.exists(wrapper_path): + raise RuntimeError( + f"{name}: Paddle-generated wrapper {module_name}.py not found " + f"alongside {so_path}" + ) + + # Copy to pkg_dir. Rename .so to add _pd_ suffix (wrapper hardcodes this name). + shutil.copy2(so_path, os.path.join(pkg_dir, f'{module_name}_pd_.so')) + shutil.copy2(wrapper_path, pkg_dir) + print(f"[flashmask] {name} built: {module_name}_pd_.so + {module_name}.py") + + # Clean up build artifacts from csrc_dir + for _d in glob.glob(os.path.join(csrc_dir, 'build')) + \ + glob.glob(os.path.join(csrc_dir, '*.egg-info')): + shutil.rmtree(_d, ignore_errors=True) + # Also clean any _pd_.so / wrapper .py that Paddle may leave in csrc_dir + for _f in glob.glob(os.path.join(csrc_dir, '*_pd_.so')) + \ + glob.glob(os.path.join(csrc_dir, f'{module_name}.py')): + os.remove(_f) + + # Derive package name from pkg_dir relative to ROOT_DIR + # e.g. 
flash_mask/cp_balance -> flash_mask.cp_balance + return os.path.relpath(pkg_dir, ROOT_DIR).replace(os.sep, '.') + + +# ============================================================ +# Build CUDA submodules +# ============================================================ +_submodule_package_data = {} + +# --- cp_balance: needs sm_80/sm_90a/sm_100 (multi-arch) --- +_pkg = _build_cuda_submodule( + 'CP Balance', + csrc_dir=os.path.join(FLASH_MASK_DIR, 'cp_balance', 'csrc'), + pkg_dir=os.path.join(FLASH_MASK_DIR, 'cp_balance'), +) +if _pkg: + _submodule_package_data[_pkg] = ['*.so'] + +# To add future submodules, just repeat: +# _pkg = _build_cuda_submodule('Name', csrc_dir=..., pkg_dir=...) +# if _pkg: +# _submodule_package_data[_pkg] = ['*.so'] + # ============================================================ # Build: use paddle's setup when building FA3, plain setuptools otherwise # ============================================================ @@ -383,6 +485,7 @@ def _get_cuda_version(): name='flash_mask', version=VERSION, packages=packages, + package_data=_submodule_package_data, author='PaddlePaddle', description='FlashMask: Efficient and Rich Mask Extension of FlashAttention', install_requires=install_requires, @@ -394,25 +497,3 @@ def _get_cuda_version(): paddle_setup(**setup_kwargs, ext_modules=ext_modules) else: setuptools_setup(**setup_kwargs) - -# ============================================================ -# CP Balance: CUDA extension (built via its own setup.py after main setup) -# Paddle's cpp_extension.setup only supports 1 Extension per call, -# so we invoke cp_balance's setup.py as a subprocess. 
-# ============================================================ -if BUILD_FA3: - cp_balance_csrc_dir = os.path.join(FLASH_MASK_DIR, 'cp_balance', 'csrc') - print("[flashmask] Building CP Balance CUDA extension...") - result = subprocess.run( - [sys.executable, 'setup.py', 'install'], - cwd=cp_balance_csrc_dir, - capture_output=True, - text=True, - ) - if result.returncode != 0: - print(f"[flashmask] CP Balance build STDERR:\n{result.stderr}") - raise RuntimeError( - f"Failed to build CP Balance CUDA extension.\n" - f"You can build it manually: cd {cp_balance_csrc_dir} && python setup.py install" - ) - print("[flashmask] CP Balance CUDA extension built successfully.") From dd653fdf88d55835bb04f9a9df96dd6d48decd39 Mon Sep 17 00:00:00 2001 From: Enigmatisms Date: Mon, 13 Apr 2026 00:04:47 +0800 Subject: [PATCH 4/6] [Trial] Try IPO solver for CP-balance --- flashmask/flash_mask/cp_balance/cp_balance.py | 68 +- .../cp_balance/cp_balance_cuda_kernels.py | 8 + .../cp_balance/csrc/cp_balance_fast.hpp | 579 ++++++++++++++++++ .../cp_balance/csrc/cp_balance_ipo_op.cpp | 38 ++ flashmask/flash_mask/cp_balance/csrc/setup.py | 3 +- 5 files changed, 689 insertions(+), 7 deletions(-) create mode 100644 flashmask/flash_mask/cp_balance/csrc/cp_balance_fast.hpp create mode 100644 flashmask/flash_mask/cp_balance/csrc/cp_balance_ipo_op.cpp diff --git a/flashmask/flash_mask/cp_balance/cp_balance.py b/flashmask/flash_mask/cp_balance/cp_balance.py index dbe133d8a99..04a84fe3525 100644 --- a/flashmask/flash_mask/cp_balance/cp_balance.py +++ b/flashmask/flash_mask/cp_balance/cp_balance.py @@ -15,7 +15,7 @@ import heapq import paddle import numpy as np -from .cp_balance_cuda_kernels import scanMaxMinChunkedKernel, reduce_workload, indices_to_chunks_cuda, indices_rerank_cuda +from .cp_balance_cuda_kernels import scanMaxMinChunkedKernel, reduce_workload, indices_to_chunks_cuda, indices_rerank_cuda, cp_balance_ipo_solve import paddle.distributed as dist from typing import List, Tuple, Dict, 
Optional @@ -184,6 +184,56 @@ def assign_tasks_heap( return buckets, bucket_weights, cuts +def assign_tasks_ipo( + tasks: np.ndarray, + num_buckets: int +) -> Tuple[List[List[Tuple[int, int]]], List[int], int]: + """ + 使用 IPO (Iterative Pairwise/Triple Optimal) 最优求解器分配任务。 + 接口与 assign_tasks_heap 完全一致。 + + 当 N > 512 或 N % num_buckets != 0 时自动 fallback 到 assign_tasks_heap。 + + Args: + tasks (np.ndarray): 形状为 (N, 2) 的任务数组,每行是 [weight, index]。 + num_buckets (int): 桶的数量。 + + Returns: + 与 assign_tasks_heap 相同的三元组 (buckets, bucket_weights, cuts)。 + """ + n = len(tasks) + if n == 0 or n > 512 or n % num_buckets != 0: + return assign_tasks_heap(tasks, num_buckets) + + K = n // num_buckets + weights = np.array([t[0] for t in tasks], dtype=np.int32) + + # 调用 C++ IPO solver + # assign_matrix: (num_buckets, K),每个元素是 item index (0..N-1) + assign_matrix, _ = cp_balance_ipo_solve(weights, num_buckets) + + buckets = [] + bucket_weights = [] + for j in range(num_buckets): + bucket = [] + bw = 0 + for t in range(K): + idx = int(assign_matrix[j, t]) + w = int(tasks[idx][0]) + chunk_idx = int(tasks[idx][1]) + bucket.append((w, chunk_idx)) + bw += w + bucket.sort(key=lambda x: x[1]) + buckets.append(bucket) + bucket_weights.append(bw) + + # 统计切分次数 + all_idx = sorted([idx for b in buckets for _, idx in b]) + cuts = sum(1 for i in range(1, len(all_idx)) if all_idx[i] != all_idx[i - 1] + 1) + + return buckets, bucket_weights, cuts + + # --- 数据通信与重排辅助函数 --- def get_send_dict(buckets: List[List[Tuple[int, int]]], cp_size: int, rank: int) -> Dict[int, List[int]]: @@ -317,7 +367,8 @@ def balance_flashmask_input( cp_rank: int, balance_chunk_size: int = 2048, q_block_size: int = 128, - k_block_size: int = 128 + k_block_size: int = 128, + use_ipo: bool = False ) -> Tuple[paddle.Tensor, List[List[Tuple[int, int]]]]: """ FlashMask 输入数据的负载均衡主流程。 @@ -330,6 +381,8 @@ def balance_flashmask_input( balance_chunk_size (int): 用于负载均衡分析和数据移动的块大小。 q_block_size (int): FlashAttention kernel 的 query 块大小。 
k_block_size (int): FlashAttention kernel 的 key 块大小。 + use_ipo (bool): 是否使用 IPO 最优求解器替代 LPT 贪心。 + N > 512 或 N % cp_size != 0 时自动 fallback 到 LPT。 Returns: Tuple: @@ -343,11 +396,14 @@ def balance_flashmask_input( workload = get_q_workload(startend_row_indices, balance_chunk_size, q_block_size, k_block_size) paddle.base.core.nvprof_nvtx_pop() - # 步骤 2: 使用堆贪心算法在 CPU 上进行任务分配 - paddle.base.core.nvprof_nvtx_push("assign_tasks_heap") - # 将 workload tensor 转换成 numpy 数组以用于 heapq + # 步骤 2: 任务分配(IPO 最优求解 或 LPT 贪心) + paddle.base.core.nvprof_nvtx_push("assign_tasks") + # 将 workload tensor 转换成 numpy 数组 tasks_np = workload.reshape([-1, 2]).cpu().numpy() - buckets, _, _ = assign_tasks_heap(tasks_np, cp_size) + if use_ipo: + buckets, _, _ = assign_tasks_ipo(tasks_np, cp_size) + else: + buckets, _, _ = assign_tasks_heap(tasks_np, cp_size) paddle.base.core.nvprof_nvtx_pop() # 步骤 5: 根据全局分配方案 `buckets`,对原始索引张量进行重排 (Gather) diff --git a/flashmask/flash_mask/cp_balance/cp_balance_cuda_kernels.py b/flashmask/flash_mask/cp_balance/cp_balance_cuda_kernels.py index 1113164920d..d93645afbf2 100644 --- a/flashmask/flash_mask/cp_balance/cp_balance_cuda_kernels.py +++ b/flashmask/flash_mask/cp_balance/cp_balance_cuda_kernels.py @@ -58,3 +58,11 @@ def indices_rerank_cuda(startend_row_indices, indices, balance_chunk_size=2048): num_chunks = (S + balance_chunk_size - 1) // balance_chunk_size startend_row_indices_rerank = cp_balance_ops.indices_rerank(startend_row_indices, indices, B, H, S,D,num_chunks,balance_chunk_size) return startend_row_indices_rerank + + +def cp_balance_ipo_solve(weights_np, M): + """调用 IPO 最优求解器。weights_np: 1-D int32 numpy, M: int。返回 (assign_matrix, max_load)。""" + import paddle + weights_t = paddle.to_tensor(weights_np, dtype='int32', place=paddle.CPUPlace()) + assign_t, ml_t = cp_balance_ops.cp_balance_ipo(weights_t, M) + return assign_t.numpy(), ml_t.numpy().item() diff --git a/flashmask/flash_mask/cp_balance/csrc/cp_balance_fast.hpp 
b/flashmask/flash_mask/cp_balance/csrc/cp_balance_fast.hpp
new file mode 100644
index 00000000000..d21849257ac
--- /dev/null
+++ b/flashmask/flash_mask/cp_balance/csrc/cp_balance_fast.hpp
@@ -0,0 +1,606 @@
+#pragma once
+// CpBalanceSolver — Production-ready CP load balancing solver
+//
+// Solves: N items with integer weights -> M workers (K = N/M items each),
+//         minimize the maximum worker load.
+//
+// Constraints: M divides N, 1 <= N <= 512, 1 <= M <= 32.
+//   Inputs outside these limits are rejected (max_load == 0, no assignment
+//   written) because every working buffer is a fixed-size stack array.
+//   Exact refinement runs only for K <= 16; larger K returns the plain LPT
+//   (largest item first, lightest worker first) assignment.
+//
+// Usage:
+//   // C++ with vector
+//   auto result = CpBalanceSolver::solve({100, 80, 60, 40, 30, 20, 15, 10}, 2);
+//   // result.max_load, result.assign[worker] = {item indices...}
+//
+//   // Raw pointer (for Paddle/Python C extension, zero-copy to tensor)
+//   int out[M * K];
+//   int max_load = CpBalanceSolver::solve_to(weights.data(), N, M, out);
+//
+// Paddle custom op example:
+//
+//   #include "paddle/extension.h"
+//   #include "cp_balance_fast.hpp"
+//
+//   std::vector<paddle::Tensor> CpBalanceOp(const paddle::Tensor& weights, int64_t M) {
+//       int N = weights.shape()[0], K = N / M;
+//       auto assign = paddle::empty({M, K}, paddle::DataType::INT32, weights.place());
+//       int max_load = CpBalanceSolver::solve_to(
+//           weights.data<int>(), N, M, assign.data<int>());
+//       return {assign, paddle::full({1}, max_load, paddle::DataType::INT32)};
+//   }
+//
+//   PD_BUILD_OP(cp_balance)
+//       .Inputs({"Weights"}).Attrs({"M: int"})
+//       .Outputs({"Assign", "MaxLoad"})
+//       .SetKernelFn(PD_KERNEL(CpBalanceOp));
+
+#include <algorithm>
+#include <array>
+#include <cstdint>
+#include <cstring>
+#include <numeric>
+#include <vector>
+
+class CpBalanceSolver {
+public:
+    // ===== Limits =====
+    static constexpr int kMaxN = 512;
+    static constexpr int kMaxM = 32;
+    // Per-worker slot capacity. With M == 1 a single worker receives all N
+    // items, so a row must hold kMaxN entries (kMaxN / 2 was too small there).
+    static constexpr int kMaxK = kMaxN;
+    static constexpr int kMax2K = 34;  // 2-way subproblem max items
+    static constexpr int kMax3K = 50;  // 3-way subproblem max items
+
+    // ===== Result (general C++ API) =====
+    struct Result {
+        int max_load = 0;
+        std::vector<std::vector<int>> assign;  // assign[worker] = item indices
+    };
+
+    // Solve and return Result with vector-based assignment.
+    // Returns an empty Result (max_load == 0, no workers) if the input is
+    // rejected by valid_shape().
+    [[nodiscard]] static Result solve(const std::vector<int>& weights, int M) {
+        Result r;
+        // Compare as size_t first so an oversized vector cannot wrap the int.
+        if (weights.size() > static_cast<std::size_t>(kMaxN)) return r;
+        int N = static_cast<int>(weights.size());
+        if (!valid_shape(N, M)) return r;
+        int K = N / M;
+
+        // Run core solver into stack arrays
+        std::array<std::array<int, kMaxK>, kMaxM> assign{};
+        std::array<int, kMaxM> count{}, load{};
+        int max_load = solve_core(weights.data(), N, M, K, assign, count, load);
+
+        // Build result
+        r.max_load = max_load;
+        r.assign.resize(M);
+        for (int j = 0; j < M; j++) {
+            r.assign[j].assign(assign[j].begin(), assign[j].begin() + count[j]);
+        }
+        return r;
+    }
+
+    // Solve and write assignment directly to a flat buffer.
+    //   out_assign: pre-allocated buffer of size [M * K], row-major.
+    //               out_assign[j * K + t] = item index for worker j, slot t.
+    // Returns max_load (0 if input invalid; out_assign is then left untouched).
+    [[nodiscard]] static int solve_to(const int* weights, int N, int M, int* out_assign) {
+        if (weights == nullptr || out_assign == nullptr || !valid_shape(N, M)) return 0;
+        int K = N / M;
+
+        std::array<std::array<int, kMaxK>, kMaxM> assign{};
+        std::array<int, kMaxM> count{}, load{};
+        int max_load = solve_core(weights, N, M, K, assign, count, load);
+
+        // Copy to flat output
+        for (int j = 0; j < M; j++) {
+            std::copy_n(assign[j].begin(), K, out_assign + j * K);
+        }
+        return max_load;
+    }
+
+private:
+    // True when (N, M) fits the fixed-size buffers and M divides N.
+    static constexpr bool valid_shape(int N, int M) {
+        return N > 0 && M > 0 && N <= kMaxN && M <= kMaxM && N % M == 0;
+    }
+
+    // ===== Internal sub-solver result types =====
+    struct Part2 {
+        int max_load = 0;
+        std::array<bool, kMax2K> in_group0{};  // true: item goes to group 0
+    };
+
+    struct Part3 {
+        int max_load = 0;
+        std::array<int, kMax3K> group{};  // group id (0..2) per item
+    };
+
+    // ===== Deterministic sort comparator: weight desc, index asc =====
+    template <typename W>
+    static auto desc_weight_asc_index(const W& w) {
+        return [&](int a, int b) {
+            return w[a] > w[b] || (w[a] == w[b] && a < b);
+        };
+    }
+
+    // ===== Meet-in-the-Middle 2-way solver (K <= 10) =====
+    // Splits n_items (= 2K) items into two groups of exactly K items each,
+    // minimizing the larger group sum. in_group0[i] marks items of group 0.
+    static Part2 mitm_2way(const int* items, int n_items) {
+        int K = n_items / 2;
+        int total = 0;
+        for (int i = 0; i < n_items; i++) total += items[i];
+
+        int size_a = K, size_b = n_items - K;
+
+        // Enumerate B-half subsets, bucket by popcount, sort by sum
+        struct SubsetInfo { int sum; uint16_t mask; };
+        static thread_local std::vector<SubsetInfo> buckets[17];
+        for (int i = 0; i <= size_b; i++) buckets[i].clear();
+
+        for (int mask = 0; mask < (1 << size_b); mask++) {
+            int sum = 0;
+            for (int i = 0; i < size_b; i++) {
+                if (mask & (1 << i)) sum += items[K + i];
+            }
+            buckets[__builtin_popcount(mask)].push_back({sum, static_cast<uint16_t>(mask)});
+        }
+        for (int i = 0; i <= size_b; i++) {
+            std::sort(buckets[i].begin(), buckets[i].end(),
+                      [](const auto& a, const auto& b) { return a.sum < b.sum; });
+        }
+
+        // Enumerate A-half subsets, binary-search B for best match
+        int best_max = total;
+        uint32_t best_mask_a = 0;
+        uint16_t best_mask_b = 0;
+
+        for (int mask_a = 0; mask_a < (1 << size_a); mask_a++) {
+            int count_a = __builtin_popcount(mask_a);
+            int sum_a = 0;
+            for (int i = 0; i < size_a; i++) {
+                if (mask_a & (1 << i)) sum_a += items[i];
+            }
+
+            int need = K - count_a;
+            if (need < 0 || need > size_b) continue;
+            auto& bucket = buckets[need];
+            if (bucket.empty()) continue;
+
+            // First B-subset whose sum lifts group 0 to at least ceil(total / 2)
+            int want = (total + 1) / 2 - sum_a;
+            int lo = 0, hi = static_cast<int>(bucket.size());
+            while (lo < hi) {
+                int mid = (lo + hi) / 2;
+                if (bucket[mid].sum < want) lo = mid + 1;
+                else hi = mid;
+            }
+
+            // Evaluate the candidates on both sides of that balance point
+            for (int delta : {-1, 0}) {
+                int idx = lo + delta;
+                if (idx < 0 || idx >= static_cast<int>(bucket.size())) continue;
+                int g0_sum = sum_a + bucket[idx].sum;
+                int cur_max = std::max(g0_sum, total - g0_sum);
+                if (cur_max < best_max) {
+                    best_max = cur_max;
+                    best_mask_a = mask_a;
+                    best_mask_b = bucket[idx].mask;
+                }
+            }
+        }
+
+        Part2 result;
+        result.max_load = best_max;
+        for (int i = 0; i < size_a; i++) {
+            if (best_mask_a & (1 << i)) result.in_group0[i] = true;
+        }
+        for (int i = 0; i < size_b; i++) {
+            if (best_mask_b & (1 << i)) result.in_group0[K + i] = true;
+        }
+        return result;
+    }
+
+    // ===== BnB 2-way solver (K > 10): binary search + DFS =====
+    // Same contract as mitm_2way. A DFS that exceeds kNodeLimit nodes is
+    // treated as infeasible for that target, so the result can be sub-optimal;
+    // the LPT split computed up front is the fallback.
+    static Part2 bnb_2way(const int* items, int n_items) {
+        int K = n_items / 2;
+        int total = 0;
+        for (int i = 0; i < n_items; i++) total += items[i];
+        int lower_bound = std::max((total + 1) / 2, items[0]);
+
+        // Sort items descending (large first → prune earlier)
+        int ord[kMax2K], sorted_w[kMax2K], suffix_sum[kMax2K + 1];
+        std::iota(ord, ord + n_items, 0);
+        std::sort(ord, ord + n_items, desc_weight_asc_index(items));
+        for (int i = 0; i < n_items; i++) sorted_w[i] = items[ord[i]];
+        suffix_sum[n_items] = 0;
+        for (int i = n_items - 1; i >= 0; i--) {
+            suffix_sum[i] = suffix_sum[i + 1] + sorted_w[i];
+        }
+
+        // LPT upper bound
+        int lpt_load[2] = {}, lpt_count[2] = {}, lpt_assign[kMax2K];
+        for (int i = 0; i < n_items; i++) {
+            int g = (lpt_count[0] >= K) ? 1
+                  : (lpt_count[1] >= K) ? 0
+                  : (lpt_load[0] <= lpt_load[1]) ? 0 : 1;
+            lpt_load[g] += sorted_w[i];
+            lpt_count[g]++;
+            lpt_assign[i] = g;
+        }
+        int upper_bound = std::max(lpt_load[0], lpt_load[1]);
+
+        Part2 best;
+        best.max_load = upper_bound;
+        for (int i = 0; i < n_items; i++) {
+            if (lpt_assign[i] == 0) best.in_group0[ord[i]] = true;
+        }
+        if (lower_bound == upper_bound) return best;
+
+        // BnB state
+        int load[2], count[2], assign[kMax2K];
+        int64_t node_count = 0;
+        static constexpr int64_t kNodeLimit = 500000;
+        int target = 0;
+
+        auto search = [&](auto& self, int pos) -> bool {
+            // Node budget exhausted: give up on this target
+            if (++node_count > kNodeLimit) return false;
+            if (pos == n_items) return true;
+
+            int cur_w = sorted_w[pos];
+            int total_cap = 0, n_open = 0, last_open = -1;
+            for (int j = 0; j < 2; j++) {
+                if (count[j] >= K) continue;
+                total_cap += target - load[j];
+                n_open++;
+                last_open = j;
+            }
+
+            // Suffix-sum pruning
+            if (suffix_sum[pos] > total_cap) return false;
+
+            // Single-group forced
+            if (n_open == 1) {
+                if (n_items - pos != K - count[last_open]) return false;
+                if (suffix_sum[pos] + load[last_open] > target) return false;
+                for (int p = pos; p < n_items; p++) assign[p] = last_open;
+                return true;
+            }
+
+            // Equal-value pruning
+            int start_group = 0;
+            if (pos > 0 && sorted_w[pos] == sorted_w[pos - 1]) {
+                start_group = assign[pos - 1];
+            }
+
+            for (int j = start_group; j < 2; j++) {
+                if (count[j] >= K || load[j] + cur_w > target) continue;
+                // Symmetric pruning
+                if (j == 1 && load[0] == load[1] && count[0] == count[1]) continue;
+
+                load[j] += cur_w; count[j]++; assign[pos] = j;
+                if (self(self, pos + 1)) return true;
+                count[j]--; load[j] -= cur_w;
+            }
+            return false;
+        };
+
+        // Binary search for minimum feasible target
+        for (int lo = lower_bound, hi = upper_bound - 1; lo <= hi;) {
+            int mid = (lo + hi) / 2;
+            target = mid;
+            std::memset(load, 0, sizeof(load));
+            std::memset(count, 0, sizeof(count));
+            node_count = 0;
+
+            if (search(search, 0)) {
+                best.max_load = mid;
+                best.in_group0 = {};
+                for (int i = 0; i < n_items; i++) {
+                    if (assign[i] == 0) best.in_group0[ord[i]] = true;
+                }
+                hi = mid - 1;
+            } else {
+                lo = mid + 1;
+            }
+        }
+        return best;
+    }
+
+    // Dispatch: MITM for K <= 10, BnB for K > 10
+    static Part2 partition_2way(const int* items, int n_items) {
+        return (n_items / 2 <= 10) ? mitm_2way(items, n_items) : bnb_2way(items, n_items);
+    }
+
+    // ===== Insertion sort for small arrays (faster than std::sort for N <= 24) =====
+    // Orders ord[0..n) by weight descending, index ascending on ties.
+    static void isort_desc(int* ord, int n, const int* w) {
+        for (int i = 1; i < n; i++) {
+            int key = ord[i], key_w = w[key];
+            int j = i - 1;
+            while (j >= 0 && (w[ord[j]] < key_w || (w[ord[j]] == key_w && ord[j] > key))) {
+                ord[j + 1] = ord[j]; j--;
+            }
+            ord[j + 1] = key;
+        }
+    }
+
+    // ===== BnB 3-way target solver: "can 3K items → 3 groups of K, each ≤ target?" =====
+    // On success max_load is the achieved maximum (<= target) and group[i] is
+    // the group of item i; otherwise max_load == target + 1.
+    static Part3 bnb_3way_target(const int* weights, int n_items, int target) {
+        int K = n_items / 3;
+        int total = 0;
+        for (int i = 0; i < n_items; i++) total += weights[i];
+
+        int ord[kMax3K], sorted_w[kMax3K], suffix_sum[kMax3K + 1];
+        std::iota(ord, ord + n_items, 0);
+        isort_desc(ord, n_items, weights);
+        for (int i = 0; i < n_items; i++) sorted_w[i] = weights[ord[i]];
+        suffix_sum[n_items] = 0;
+        for (int i = n_items - 1; i >= 0; i--) {
+            suffix_sum[i] = suffix_sum[i + 1] + sorted_w[i];
+        }
+
+        int lower_bound = std::max({(total + 2) / 3, sorted_w[0]});
+        Part3 result;
+        result.max_load = target + 1;  // infeasible by default
+        if (target < lower_bound) return result;
+
+        int load[3] = {}, count[3] = {}, assign[kMax3K];
+        int64_t node_count = 0;
+        static constexpr int64_t kNodeLimit = 2000000;
+
+        auto search = [&](auto& self, int pos) -> bool {
+            // Node budget exhausted: report infeasible
+            if (++node_count > kNodeLimit) return false;
+            if (pos == n_items) return true;
+
+            int remaining = n_items - pos;
+            int cur_w = sorted_w[pos];
+            int total_cap = 0, n_open = 0, last_open = -1;
+
+            for (int j = 0; j < 3; j++) {
+                if (count[j] >= K) continue;
+                int need = K - count[j];
+                int tail_start = n_items - need;
+                // Tail pruning
+                if (tail_start < pos) {
+                    if (remaining < need) return false;
+                    if (suffix_sum[pos] + load[j] > target) return false;
+                } else {
+                    if (suffix_sum[tail_start] + load[j] > target) return false;
+                }
+                total_cap += target - load[j];
+                n_open++;
+                last_open = j;
+            }
+
+            // Suffix-sum pruning
+            if (suffix_sum[pos] > total_cap) return false;
+
+            // Single-group forced
+            if (n_open == 1) {
+                if (remaining != K - count[last_open]) return false;
+                if (suffix_sum[pos] + load[last_open] > target) return false;
+                for (int p = pos; p < n_items; p++) assign[p] = last_open;
+                return true;
+            }
+
+            // Equal-value pruning
+            int start_group = 0;
+            if (pos > 0 && sorted_w[pos] == sorted_w[pos - 1]) {
+                start_group = assign[pos - 1];
+            }
+
+            // Symmetric pruning: skip duplicate (load, count) states
+            int seen_load[3], seen_count[3], n_seen = 0;
+            for (int j = start_group; j < 3; j++) {
+                if (count[j] >= K || load[j] + cur_w > target) continue;
+
+                bool dup = false;
+                for (int s = 0; s < n_seen; s++) {
+                    if (seen_load[s] == load[j] && seen_count[s] == count[j]) { dup = true; break; }
+                }
+                if (dup) continue;
+                seen_load[n_seen] = load[j];
+                seen_count[n_seen] = count[j];
+                n_seen++;
+
+                load[j] += cur_w; count[j]++; assign[pos] = j;
+                if (self(self, pos + 1)) return true;
+                count[j]--; load[j] -= cur_w;
+            }
+            return false;
+        };
+
+        if (search(search, 0)) {
+            int gl[3] = {};
+            for (int p = 0; p < n_items; p++) gl[assign[p]] += sorted_w[p];
+            result.max_load = std::max({gl[0], gl[1], gl[2]});
+            for (int i = 0; i < n_items; i++) result.group[ord[i]] = assign[i];
+        }
+        return result;
+    }
+
+    // ===== Core solver: writes into caller-provided stack arrays =====
+    // Expects count/load to be zero-initialized and a shape accepted by
+    // valid_shape(). Returns the maximum worker load of the final assignment.
+    static int solve_core(
+        const int* w, int N, int M, int K,
+        std::array<std::array<int, kMaxK>, kMaxM>& assign,
+        std::array<int, kMaxM>& count,
+        std::array<int, kMaxM>& load)
+    {
+        // LPT initialization: sort descending, assign to lightest worker
+        int ord[kMaxN];
+        std::iota(ord, ord + N, 0);
+        std::sort(ord, ord + N, desc_weight_asc_index(w));
+
+        for (int i = 0; i < N; i++) {
+            int idx = ord[i];
+            int lightest = -1;
+            for (int j = 0; j < M; j++) {
+                if (count[j] < K && (lightest < 0 || load[j] < load[lightest])) {
+                    lightest = j;
+                }
+            }
+            assign[lightest][count[lightest]++] = idx;
+            load[lightest] += w[idx];
+        }
+
+        // Skip optimization if K > 16 (subproblems too expensive)
+        if (K > 16) {
+            return *std::max_element(load.begin(), load.begin() + M);
+        }
+
+        // Lower bound: no solution can beat this
+        int total = 0;
+        for (int i = 0; i < N; i++) total += w[i];
+        int lower_bound = std::max((total + M - 1) / M, N > 0 ? w[ord[0]] : 0);
+
+        int max_rounds = (K <= 10) ? 200 : ((K <= 14) ? 50 : 20);
+
+        for (int round = 0; round < max_rounds; round++) {
+            bool any_improved = false;
+
+            // ---- Phase 2: Pairwise exact repartition ----
+            for (bool improved = true; improved;) {
+                improved = false;
+
+                // Find heaviest worker
+                int cur_max = 0, heavy = 0;
+                for (int j = 0; j < M; j++) {
+                    if (load[j] > cur_max) { cur_max = load[j]; heavy = j; }
+                }
+                if (cur_max <= lower_bound) break;
+
+                // Precompute top-1/top-2 load among non-heavy workers
+                int omax1 = 0, omax1_idx = -1, omax2 = 0;
+                for (int j = 0; j < M; j++) {
+                    if (j == heavy) continue;
+                    if (load[j] > omax1) { omax2 = omax1; omax1 = load[j]; omax1_idx = j; }
+                    else if (load[j] > omax2) omax2 = load[j];
+                }
+
+                // Try each partner, pick the best
+                int best_partner = -1, best_new_max = cur_max;
+                Part2 best_split;
+                for (int j = 0; j < M; j++) {
+                    if (j == heavy) continue;
+
+                    int pw[kMax2K];
+                    int p = 0;
+                    for (int t = 0; t < K; t++) pw[p++] = w[assign[heavy][t]];
+                    for (int t = 0; t < K; t++) pw[p++] = w[assign[j][t]];
+
+                    auto split = partition_2way(pw, 2 * K);
+                    int omax = (j == omax1_idx) ? omax2 : omax1;
+                    int new_max = std::max(omax, split.max_load);
+                    if (new_max < best_new_max) {
+                        best_new_max = new_max;
+                        best_partner = j;
+                        best_split = split;
+                    }
+                }
+
+                // Apply best split
+                if (best_partner >= 0) {
+                    int indices[kMax2K];
+                    int p = 0;
+                    for (int t = 0; t < K; t++) indices[p++] = assign[heavy][t];
+                    for (int t = 0; t < K; t++) indices[p++] = assign[best_partner][t];
+
+                    count[heavy] = count[best_partner] = 0;
+                    load[heavy] = load[best_partner] = 0;
+                    for (int t = 0; t < 2 * K; t++) {
+                        int wk = best_split.in_group0[t] ? heavy : best_partner;
+                        assign[wk][count[wk]++] = indices[t];
+                        load[wk] += w[indices[t]];
+                    }
+                    improved = true;
+                    any_improved = true;
+                }
+            }
+
+            // ---- Phase 3: Triple exact repartition (K <= 8 only) ----
+            if (M >= 3 && K <= 8) {
+                int cur_max = 0, heavy = 0;
+                for (int j = 0; j < M; j++) {
+                    if (load[j] > cur_max) { cur_max = load[j]; heavy = j; }
+                }
+                if (cur_max <= lower_bound) break;
+
+                // Sort partners by load ascending (try lightest first)
+                int partners[kMaxM], np = 0;
+                for (int j = 0; j < M; j++) {
+                    if (j != heavy) partners[np++] = j;
+                }
+                std::sort(partners, partners + np, [&](int a, int b) {
+                    return load[a] < load[b] || (load[a] == load[b] && a < b);
+                });
+
+                bool improved = false;
+                for (int pi = 0; pi < np && !improved; pi++) {
+                    int p1 = partners[pi];
+                    for (int pk = pi + 1; pk < np && !improved; pk++) {
+                        int p2 = partners[pk];
+
+                        // Bystander pruning
+                        int omax = 0;
+                        for (int q = np - 1; q >= 0; q--) {
+                            if (q != pi && q != pk) { omax = load[partners[q]]; break; }
+                        }
+                        if (omax >= cur_max) continue;
+
+                        // Lower-bound pruning
+                        int tri_sum = load[heavy] + load[p1] + load[p2];
+                        int lw_max = 0;
+                        for (int t = 0; t < K; t++) {
+                            lw_max = std::max(lw_max, w[assign[heavy][t]]);
+                            lw_max = std::max(lw_max, w[assign[p1][t]]);
+                            lw_max = std::max(lw_max, w[assign[p2][t]]);
+                        }
+                        if (std::max(omax, std::max((tri_sum + 2) / 3, lw_max)) >= cur_max) continue;
+
+                        // Solve 3-way feasibility
+                        int tw[kMax3K], ti[kMax3K];
+                        int p = 0;
+                        for (int t = 0; t < K; t++) { ti[p] = assign[heavy][t]; tw[p] = w[ti[p]]; p++; }
+                        for (int t = 0; t < K; t++) { ti[p] = assign[p1][t]; tw[p] = w[ti[p]]; p++; }
+                        for (int t = 0; t < K; t++) { ti[p] = assign[p2][t]; tw[p] = w[ti[p]]; p++; }
+
+                        auto s3 = bnb_3way_target(tw, 3 * K, cur_max - 1);
+                        if (std::max(omax, s3.max_load) < cur_max) {
+                            int wmap[3] = {heavy, p1, p2};
+                            count[heavy] = count[p1] = count[p2] = 0;
+                            load[heavy] = load[p1] = load[p2] = 0;
+                            for (int t = 0; t < 3 * K; t++) {
+                                int wk = wmap[s3.group[t]];
+                                assign[wk][count[wk]++] = ti[t];
+                                load[wk] += w[ti[t]];
+                            }
+                            improved = true;
+                            any_improved = true;
+                        }
+                    }
+                }
+            }
+
+            if (!any_improved) break;
+        }
+
+        return *std::max_element(load.begin(), load.begin() + M);
+    }
+};
diff --git a/flashmask/flash_mask/cp_balance/csrc/cp_balance_ipo_op.cpp b/flashmask/flash_mask/cp_balance/csrc/cp_balance_ipo_op.cpp new file mode 100644 index 00000000000..f1b819166f5 --- /dev/null +++ b/flashmask/flash_mask/cp_balance/csrc/cp_balance_ipo_op.cpp @@ -0,0 +1,38 @@ +// Copyright (c) 2026 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License.
+ +#include "paddle/extension.h" +#include "cp_balance_fast.hpp" + +std::vector CpBalanceIpoKernel( + const paddle::Tensor& weights, int64_t M) { + int N = static_cast(weights.shape()[0]); + int K = N / static_cast(M); + + auto assign_out = paddle::empty( + {static_cast(M), static_cast(K)}, + paddle::DataType::INT32, paddle::CPUPlace()); + + int max_load = CpBalanceSolver::solve_to( + weights.data(), N, static_cast(M), assign_out.data()); + + auto ml_out = paddle::full({1}, max_load, paddle::DataType::INT32); + return {assign_out, ml_out}; +} + +PD_BUILD_OP(cp_balance_ipo) + .Inputs({"Weights"}) + .Attrs({"M: int"}) + .Outputs({"Assign", "MaxLoad"}) + .SetKernelFn(PD_KERNEL(CpBalanceIpoKernel)); diff --git a/flashmask/flash_mask/cp_balance/csrc/setup.py b/flashmask/flash_mask/cp_balance/csrc/setup.py index eac80bf937a..a220908e88a 100644 --- a/flashmask/flash_mask/cp_balance/csrc/setup.py +++ b/flashmask/flash_mask/cp_balance/csrc/setup.py @@ -103,9 +103,10 @@ def setup_ops_extension(): ext_module = CUDAExtension( sources=[ - # cpp files # cuda files "./cp_balance_utils.cu", + # cpp files (compiled by host compiler, not nvcc) + "./cp_balance_ipo_op.cpp", ], include_dirs=[ os.path.join(os.getcwd(), "./"), From 2708e39c46afb8123744bb77a128662618e46428 Mon Sep 17 00:00:00 2001 From: Enigmatisms Date: Mon, 13 Apr 2026 00:53:15 +0800 Subject: [PATCH 5/6] [Fix] Fix Python/CPP API discrepancy --- flashmask/flash_mask/cp_balance/cp_balance.py | 6 +++++- flashmask/flash_mask/cp_balance/csrc/cp_balance_ipo_op.cpp | 6 +++--- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/flashmask/flash_mask/cp_balance/cp_balance.py b/flashmask/flash_mask/cp_balance/cp_balance.py index 04a84fe3525..306a7d6fba5 100644 --- a/flashmask/flash_mask/cp_balance/cp_balance.py +++ b/flashmask/flash_mask/cp_balance/cp_balance.py @@ -201,12 +201,16 @@ def assign_tasks_ipo( Returns: 与 assign_tasks_heap 相同的三元组 (buckets, bucket_weights, cuts)。 """ + # 兼容 Paddle tensor 输入(与 assign_tasks_heap 
行为对齐) + if not isinstance(tasks, np.ndarray): + tasks = tasks.cpu().numpy() if hasattr(tasks, 'cpu') else np.asarray(tasks) + n = len(tasks) if n == 0 or n > 512 or n % num_buckets != 0: return assign_tasks_heap(tasks, num_buckets) K = n // num_buckets - weights = np.array([t[0] for t in tasks], dtype=np.int32) + weights = tasks[:, 0].astype(np.int32) # 调用 C++ IPO solver # assign_matrix: (num_buckets, K),每个元素是 item index (0..N-1) diff --git a/flashmask/flash_mask/cp_balance/csrc/cp_balance_ipo_op.cpp b/flashmask/flash_mask/cp_balance/csrc/cp_balance_ipo_op.cpp index f1b819166f5..7d8866e30bb 100644 --- a/flashmask/flash_mask/cp_balance/csrc/cp_balance_ipo_op.cpp +++ b/flashmask/flash_mask/cp_balance/csrc/cp_balance_ipo_op.cpp @@ -16,16 +16,16 @@ #include "cp_balance_fast.hpp" std::vector CpBalanceIpoKernel( - const paddle::Tensor& weights, int64_t M) { + const paddle::Tensor& weights, int M) { int N = static_cast(weights.shape()[0]); - int K = N / static_cast(M); + int K = N / M; auto assign_out = paddle::empty( {static_cast(M), static_cast(K)}, paddle::DataType::INT32, paddle::CPUPlace()); int max_load = CpBalanceSolver::solve_to( - weights.data(), N, static_cast(M), assign_out.data()); + weights.data(), N, M, assign_out.data()); auto ml_out = paddle::full({1}, max_load, paddle::DataType::INT32); return {assign_out, ml_out}; From a16f37df4ae329be19329bd4893bde2a4dddaab9 Mon Sep 17 00:00:00 2001 From: Enigmatisms Date: Mon, 18 May 2026 16:03:27 +0800 Subject: [PATCH 6/6] [Fix] CP balance 256K+ seqlen bug fix --- flashmask/flash_mask/cp_balance/__init__.py | 2 +- .../cp_balance/csrc/cp_balance_utils.cu | 108 +++++++----------- 2 files changed, 43 insertions(+), 67 deletions(-) diff --git a/flashmask/flash_mask/cp_balance/__init__.py b/flashmask/flash_mask/cp_balance/__init__.py index 6804de4266f..8def756fd1d 100644 --- a/flashmask/flash_mask/cp_balance/__init__.py +++ b/flashmask/flash_mask/cp_balance/__init__.py @@ -12,5 +12,5 @@ # See the License for the specific 
language governing permissions and # limitations under the License. -from .cp_balance import balance_flashmask_input +from .cp_balance import balance_flashmask_input, assign_tasks_heap from .cp_balance_cuda_kernels import indices_rerank_cuda, indices_to_chunks_cuda diff --git a/flashmask/flash_mask/cp_balance/csrc/cp_balance_utils.cu b/flashmask/flash_mask/cp_balance/csrc/cp_balance_utils.cu index 231e00a4aaa..ced5d3a6b37 100644 --- a/flashmask/flash_mask/cp_balance/csrc/cp_balance_utils.cu +++ b/flashmask/flash_mask/cp_balance/csrc/cp_balance_utils.cu @@ -108,80 +108,55 @@ __global__ void reduce_workload_kernel( int BH, int Tr, int Tc, int S, int Br // m_block_size ) { - int bh = blockIdx.y; - int tr = blockIdx.x; - int tc = threadIdx.x; - int warpId = threadIdx.x / 32; - int laneId = threadIdx.x % 32; - - if(tr >= Tr) return; - - int wl = 0; - bool fully_masked = true; - bool partially_masked = false; - int lt_start_max = INT_MAX; - int lt_start_min = INT_MAX; - int lt_end_max = INT_MAX; - int lt_end_min = INT_MAX; - int ut_start_max = INT_MIN; - int ut_start_min = INT_MIN; - int ut_end_max = INT_MIN; - int ut_end_min = INT_MIN; + const int bh = blockIdx.y; + const int tr = blockIdx.x; + const int warpId = threadIdx.x / 32; + const int laneId = threadIdx.x % 32; - __shared__ int smem[32]; - - const int idx = bh * Tc + tc; - const int q_idx = bh * Tr + tr; + if (tr >= Tr) return; - // m_block_s/e: Q block boundaries within a single (batch, head) — use tr only, not q_idx. - // q_idx includes the bh offset for output indexing, but mask values are in [0, S) per (b,h). + // m_block_s/e: Q block boundaries within a single (batch, head). const int m_block_s = tr * kBlockM; const int m_block_e = m_block_s + kBlockM < S ? m_block_s + kBlockM : S; - lt_start_max = tc < Tc ? LTStartMax[idx] : INT_MAX; - lt_start_min = tc < Tc ? LTStartMin[idx] : INT_MAX; - - // 分支展开 - if constexpr (PTR_DISPATCH_TAG == FULL_PTR) { - lt_end_max = tc < Tc ? 
LTEndMax[idx] : INT_MAX; - lt_end_min = tc < Tc ? LTEndMin[idx] : INT_MAX; - ut_start_max = tc < Tc ? UTStartMax[idx] : INT_MIN; - ut_start_min = tc < Tc ? UTStartMin[idx] : INT_MIN; - ut_end_max = tc < Tc ? UTEndMax[idx] : INT_MIN; - ut_end_min = tc < Tc ? UTEndMin[idx] : INT_MIN; - - fully_masked = (m_block_s >= lt_start_max && m_block_e <= lt_end_min) || - (m_block_s >= ut_start_max && m_block_e <= ut_end_min); - partially_masked = (m_block_s < lt_end_max && m_block_e > lt_start_min) || - (m_block_s < ut_end_max && m_block_e > ut_start_min); - } - else if constexpr (PTR_DISPATCH_TAG == DUAL_PTR) { - if constexpr (is_causal) { - lt_end_max = tc < Tc ? LTEndMax[idx] : INT_MAX; - lt_end_min = tc < Tc ? LTEndMin[idx] : INT_MAX; - fully_masked = m_block_s >= lt_start_max && m_block_e <= lt_end_min; - partially_masked = m_block_s < lt_end_max && m_block_e > lt_start_min; - } else { - ut_end_max = tc < Tc ? UTEndMax[idx] : INT_MIN; - ut_end_min = tc < Tc ? UTEndMin[idx] : INT_MIN; - fully_masked = (m_block_s >= lt_start_max) || (m_block_e <= ut_end_min); - partially_masked = (m_block_e > lt_start_min) || (m_block_s < ut_end_max); + const int bh_offset = bh * Tc; + const int q_idx = bh * Tr + tr; + + // Stride loop: 每个 thread 处理 tc = threadIdx.x, threadIdx.x + blockDim.x, ... 
+ int thread_wl = 0; + for (int tc = static_cast(threadIdx.x); tc < Tc; tc += static_cast(blockDim.x)) { + const int idx = bh_offset + tc; + + int lt_start_max_val = LTStartMax[idx]; + bool fully_masked = true; + + if constexpr (PTR_DISPATCH_TAG == FULL_PTR) { + int lt_end_min_val = LTEndMin[idx]; + int ut_start_max_val = UTStartMax[idx]; + int ut_end_min_val = UTEndMin[idx]; + fully_masked = (m_block_s >= lt_start_max_val && m_block_e <= lt_end_min_val) || + (m_block_s >= ut_start_max_val && m_block_e <= ut_end_min_val); + } + else if constexpr (PTR_DISPATCH_TAG == DUAL_PTR) { + if constexpr (is_causal) { + int lt_end_min_val = LTEndMin[idx]; + fully_masked = m_block_s >= lt_start_max_val && m_block_e <= lt_end_min_val; + } else { + int ut_end_min_val = UTEndMin[idx]; + fully_masked = (m_block_s >= lt_start_max_val) || (m_block_e <= ut_end_min_val); + } + } + else if constexpr (PTR_DISPATCH_TAG == SINGLE_PTR) { + fully_masked = m_block_s >= lt_start_max_val; } - } - else if constexpr (PTR_DISPATCH_TAG == SINGLE_PTR) { - fully_masked = m_block_s >= lt_start_max; - partially_masked = m_block_e > lt_start_min; - } - if(tc >= Tc){ - fully_masked = true; - partially_masked = false; + thread_wl += fully_masked ? 0 : 1; } - wl = fully_masked ? 0 : 1; - unsigned mask = 0xffffffff; - // warp reduce sum - int wl_sum = wl; + // Warp reduce sum + __shared__ int smem[32]; + const unsigned mask = 0xffffffff; + int wl_sum = thread_wl; for (int offset = 16; offset > 0; offset >>= 1) { wl_sum += __shfl_down_sync(mask, wl_sum, offset); } @@ -190,8 +165,9 @@ __global__ void reduce_workload_kernel( } __syncthreads(); + // Final reduce across warps (first warp collects) if (threadIdx.x < 32) { - int val = (threadIdx.x < (blockDim.x + 31)/32) ? smem[threadIdx.x] : 0; + int val = (threadIdx.x < (blockDim.x + 31) / 32) ? smem[threadIdx.x] : 0; for (int offset = 16; offset > 0; offset >>= 1) { val += __shfl_down_sync(mask, val, offset); }