From 5848b408c4b11444694667efc221add4dc5729ab Mon Sep 17 00:00:00 2001 From: wooway777 Date: Wed, 24 Dec 2025 15:39:24 +0800 Subject: [PATCH] issue/838 - Cambricon Batched RoPE --- src/infiniop/ops/rope/bang/rope_bang.mlu | 13 +++---- .../ops/rope/bang/rope_bang_kernel.mlu | 35 ++++++++++++++++--- 2 files changed, 37 insertions(+), 11 deletions(-) diff --git a/src/infiniop/ops/rope/bang/rope_bang.mlu b/src/infiniop/ops/rope/bang/rope_bang.mlu index b77e32d6c..9d493e67b 100644 --- a/src/infiniop/ops/rope/bang/rope_bang.mlu +++ b/src/infiniop/ops/rope/bang/rope_bang.mlu @@ -40,8 +40,9 @@ infiniStatus_t calculateRoPE(const RoPEInfo &info, const Tdata *sin_table, const Tdata *cos_table, cnrtQueue_t queue) { - auto dimx = uint32_t(info.seqlen); - auto dimy = uint32_t(info.nhead); + auto batch_size = uint32_t(info.batch); + auto seqlen = uint32_t(info.seqlen); + auto nhead = uint32_t(info.nhead); auto table_dim = uint32_t(info.table_dim); cnrtDim3_t k_dim; @@ -53,12 +54,12 @@ infiniStatus_t calculateRoPE(const RoPEInfo &info, k_dim.z = 1; k_type = CNRT_FUNC_TYPE_UNION1; - // Launch kernel + // Launch kernel with batch dimension ropeKernel<<>>( y, x, pos_ids, sin_table, cos_table, - dimx, dimy, table_dim, - info.y_stride_seqlen, info.y_stride_nhead, - info.x_stride_seqlen, info.x_stride_nhead, + batch_size, seqlen, nhead, table_dim, + info.y_stride_batch, info.y_stride_seqlen, info.y_stride_nhead, + info.x_stride_batch, info.x_stride_seqlen, info.x_stride_nhead, info.algo); cnrtQueueSync(queue); diff --git a/src/infiniop/ops/rope/bang/rope_bang_kernel.mlu b/src/infiniop/ops/rope/bang/rope_bang_kernel.mlu index fde035b4e..d88d1be1d 100644 --- a/src/infiniop/ops/rope/bang/rope_bang_kernel.mlu +++ b/src/infiniop/ops/rope/bang/rope_bang_kernel.mlu @@ -62,11 +62,14 @@ __mlu_global__ void ropeKernel( const Tindex *pos_ids, const Tdata *sin_table, const Tdata *cos_table, + uint32_t batch_size, uint32_t seqlen, uint32_t nhead, uint32_t table_dim, + ptrdiff_t y_stride_batch, ptrdiff_t y_stride_seqlen, ptrdiff_t y_stride_nhead, + ptrdiff_t x_stride_batch, ptrdiff_t x_stride_seqlen, ptrdiff_t x_stride_nhead, infiniopRoPEAlgo_t algo) { @@ -106,7 +109,7 @@ __mlu_global__ void ropeKernel( } // Task distribution - const int batch_volume = seqlen * nhead; + const int batch_volume = batch_size * seqlen * nhead; const int remaining_tasks = batch_volume % taskDim; const int base_tasks_per_core = batch_volume / taskDim; const int actual_tasks = base_tasks_per_core + (taskId < remaining_tasks ? 1 : 0); @@ -136,13 +139,35 @@ __mlu_global__ void ropeKernel( // Main processing loop for (int i = task_start_idx; i < task_start_idx + actual_tasks; i++) { - int seq_idx = i / nhead; + // Calculate 3D indices from flattened task index + int batch_idx = i / (seqlen * nhead); + int seq_idx = (i % (seqlen * nhead)) / nhead; int head_idx = i % nhead; - int out_offset = seq_idx * y_stride_seqlen + head_idx * y_stride_nhead; - int in_offset = seq_idx * x_stride_seqlen + head_idx * x_stride_nhead; + // Calculate offsets with batch dimension + // Note: For GPT-NeoX, the stride calculations might be different + int out_offset = batch_idx * y_stride_batch + seq_idx * y_stride_seqlen + head_idx * y_stride_nhead; + int in_offset = batch_idx * x_stride_batch + seq_idx * x_stride_seqlen + head_idx * x_stride_nhead; + + // Get position index for this sequence + // Position IDs are shared across batches or per batch depending on input + Tindex pos_idx; + if (use_pos_ids_buffer) { + // Position IDs loaded in NRAM + pos_idx = srcP[seq_idx]; + } else { + // Position IDs in global memory + // Handle both cases: position IDs shape could be [seqlen] or [batch_size, seqlen] + if (batch_size > 1) { + // Assume position IDs have shape [batch_size, seqlen] + int pos_flat_idx = batch_idx * seqlen + seq_idx; + pos_idx = pos_ids[pos_flat_idx]; + } else { + // Single batch case: position IDs shape is [seqlen] + pos_idx = pos_ids[seq_idx]; + } + } - Tindex pos_idx = use_pos_ids_buffer ? srcP[seq_idx] : pos_ids[seq_idx]; int rot_offset = pos_idx * table_dim; int processed = 0;