From d9f48d5c9af3b1252dddb476a33351cd0cee40bb Mon Sep 17 00:00:00 2001 From: PanZezhong Date: Thu, 25 Dec 2025 11:11:30 +0000 Subject: [PATCH 1/2] issue/848 paged attention prefill --- include/infiniop.h | 1 + .../infiniop/ops/paged_attention_prefill.h | 95 +++++++ .../paged_attention_prefill/cuda/kernel.cuh | 151 ++++++++++++ .../ops/paged_attention_prefill/info.h | 107 ++++++++ .../nvidia/paged_attention_prefill_nvidia.cu | 233 ++++++++++++++++++ .../nvidia/paged_attention_prefill_nvidia.cuh | 8 + .../ops/paged_attention_prefill/operator.cc | 96 ++++++++ .../paged_attention_prefill.h | 55 +++++ 8 files changed, 746 insertions(+) create mode 100644 include/infiniop/ops/paged_attention_prefill.h create mode 100644 src/infiniop/ops/paged_attention_prefill/cuda/kernel.cuh create mode 100644 src/infiniop/ops/paged_attention_prefill/info.h create mode 100644 src/infiniop/ops/paged_attention_prefill/nvidia/paged_attention_prefill_nvidia.cu create mode 100644 src/infiniop/ops/paged_attention_prefill/nvidia/paged_attention_prefill_nvidia.cuh create mode 100644 src/infiniop/ops/paged_attention_prefill/operator.cc create mode 100644 src/infiniop/ops/paged_attention_prefill/paged_attention_prefill.h diff --git a/include/infiniop.h b/include/infiniop.h index 92e6f5963..eec3c1389 100644 --- a/include/infiniop.h +++ b/include/infiniop.h @@ -15,6 +15,7 @@ #include "infiniop/ops/lp_norm.h" #include "infiniop/ops/mul.h" #include "infiniop/ops/ones.h" +#include "infiniop/ops/paged_attention_prefill.h" #include "infiniop/ops/random_sample.h" #include "infiniop/ops/rearrange.h" #include "infiniop/ops/relu.h" diff --git a/include/infiniop/ops/paged_attention_prefill.h b/include/infiniop/ops/paged_attention_prefill.h new file mode 100644 index 000000000..6ceb357e5 --- /dev/null +++ b/include/infiniop/ops/paged_attention_prefill.h @@ -0,0 +1,95 @@ +#ifndef __INFINIOP_PAGED_ATTENTION_PREFILL_API_H__ +#define __INFINIOP_PAGED_ATTENTION_PREFILL_API_H__ + +#include 
"../operator_descriptor.h" + +// Define an opaque handle for the Paged Attention descriptor. +typedef struct InfiniopDescriptor *infiniopPagedAttentionPrefillDescriptor_t; + +/** + * @brief Creates a descriptor for the Paged Attention Prefill operation. + * + * This function initializes a descriptor that holds all the metadata needed + * for the paged attention computation. + * + * @param handle The handle to the InfiniOP library context. + * @param desc_ptr A pointer to store the created descriptor. + * @param out_desc Descriptor for the output tensor. + * @param q_desc Descriptor for the query tensor. + * @param k_cache_desc Descriptor for the key cache tensor. + * @param v_cache_desc Descriptor for the value cache tensor. + * @param block_tables_desc Descriptor for the block tables tensor. + * @param seq_lens_desc Descriptor for the sequence lengths tensor. + * @param seq_offsets_desc Descriptor for the sequence offsets tensor. + * @param cache_lens_desc Descriptor for the cache lengths tensor. + * @param alibi_slopes_desc Optional descriptor for the ALiBi slopes tensor. Can be NULL. + * @param scale The attention scaling factor. + * @note max_num_blocks_per_seq is not passed explicitly; it is read from dimension 1 of block_tables_desc. + * @return infiniStatus_t Status code of the operation. + */ +__C __export infiniStatus_t infiniopCreatePagedAttentionPrefillDescriptor( + infiniopHandle_t handle, + infiniopPagedAttentionPrefillDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t out_desc, + infiniopTensorDescriptor_t q_desc, + infiniopTensorDescriptor_t k_cache_desc, + infiniopTensorDescriptor_t v_cache_desc, + infiniopTensorDescriptor_t block_tables_desc, + infiniopTensorDescriptor_t seq_lens_desc, + infiniopTensorDescriptor_t seq_offsets_desc, + infiniopTensorDescriptor_t cache_lens_desc, + infiniopTensorDescriptor_t alibi_slopes_desc, + float scale); + +/** + * @brief Retrieves the workspace size required for the Paged Attention operation. + * + * @param desc The Paged Attention descriptor. 
+ * @param size A pointer to store the required workspace size in bytes. + * @return infiniStatus_t Status code of the operation. + */ +__C __export infiniStatus_t infiniopGetPagedAttentionPrefillWorkspaceSize( + infiniopPagedAttentionPrefillDescriptor_t desc, size_t *size); + +/** + * @brief Executes the Paged Attention Prefill operation. + * + * @param desc The Paged Attention descriptor. + * @param workspace Pointer to the workspace memory. + * @param workspace_size The size of the workspace. + * @param out Pointer to the output tensor data. + * @param q Pointer to the query tensor data. + * @param k_cache Pointer to the key cache data. + * @param v_cache Pointer to the value cache data. + * @param block_tables Pointer to the block tables data. + * @param seq_lens Pointer to the sequence lengths data. + * @param seq_offsets Pointer to the sequence offsets data. + * @param cache_lens Pointer to the cache lengths data. + * @param alibi_slopes Pointer to the ALiBi slopes data. Can be NULL. + * @param stream The CUDA stream for the operation. Can be NULL. + * @return infiniStatus_t Status code of the operation. + */ +__C __export infiniStatus_t infiniopPagedAttentionPrefill( + infiniopPagedAttentionPrefillDescriptor_t desc, + void *workspace, + size_t workspace_size, + void *out, + const void *q, + const void *k_cache, + const void *v_cache, + const void *block_tables, + const void *seq_lens, + const void *seq_offsets, + const void *alibi_slopes, + void *stream); + +/** + * @brief Destroys a Paged Attention descriptor. + * + * @param desc The descriptor to be destroyed. + * @return infiniStatus_t Status code of the operation. 
+ */ +__C __export infiniStatus_t infiniopDestroyPagedAttentionPrefillDescriptor( + infiniopPagedAttentionPrefillDescriptor_t desc); + +#endif // __INFINIOP_PAGED_ATTENTION_PREFILL_API_H__ diff --git a/src/infiniop/ops/paged_attention_prefill/cuda/kernel.cuh b/src/infiniop/ops/paged_attention_prefill/cuda/kernel.cuh new file mode 100644 index 000000000..e3a54d1ea --- /dev/null +++ b/src/infiniop/ops/paged_attention_prefill/cuda/kernel.cuh @@ -0,0 +1,151 @@ +#ifndef __PAGED_ATTENTION_PREFILL_KERNEL_CUH__ +#define __PAGED_ATTENTION_PREFILL_KERNEL_CUH__ + +namespace op::paged_attention_prefill::cuda { + +template +__global__ void paged_attention_prefill( + Tdata *__restrict__ out, // [total_seq, H, D] + const Tdata *__restrict__ q, // [total_seq, H, D] + const Tdata *__restrict__ k_cache, // [num_blocks, nkvh, block_size, dh] + const Tdata *__restrict__ v_cache, + const int64_t *__restrict__ block_tables, // [B, max_blocks_per_seq] + const int64_t *__restrict__ seq_lens, // [B] + const int64_t *__restrict__ cache_lens, // [B] + const int64_t *__restrict__ seq_offsets, // [B] + size_t num_heads, + size_t num_kv_heads, + size_t max_blocks_per_seq, + size_t block_size, + ptrdiff_t q_stride, // = num_heads * HEAD_DIM + ptrdiff_t kv_block_stride, // stride between physical blocks + ptrdiff_t kv_head_stride, // stride between kv heads + ptrdiff_t out_stride, // = num_heads * HEAD_DIM + float scale) { + const int h = blockIdx.x; + const int b = blockIdx.y; + const int tid = threadIdx.x; + + const int64_t seq_len = seq_lens[b]; + const int64_t ctx_len = cache_lens[b]; + const int64_t q_base = seq_offsets[b]; + + if (h >= num_heads) { + return; + } + + const int kv_head = h / (num_heads / num_kv_heads); + + extern __shared__ char smem_bytes[]; + Tcompute *smem = reinterpret_cast(smem_bytes); + Tcompute *smem_k = smem; // [block_size, D] + Tcompute *smem_v = smem + block_size * HEAD_DIM; // [block_size, D] + + // Loop over prefill tokens of this request + for (int64_t t = 0; t 
< seq_len; ++t) { + const int64_t q_index = q_base + t; + + // ----------------------------- + // Load Q + // ----------------------------- + const Tdata *q_ptr = q + q_index * q_stride + h * HEAD_DIM; + + Tcompute q_vec[HEAD_DIM]; +#pragma unroll + for (size_t d = 0; d < HEAD_DIM; ++d) { + q_vec[d] = static_cast(q_ptr[d]); + } + + const int64_t q_logical = ctx_len + t; + const int64_t attn_len = q_logical + 1; // causal + + // Online softmax state + Tcompute m = -INFINITY; + Tcompute s = 0; + Tcompute acc[HEAD_DIM] = {0}; + + // ----------------------------- + // Iterate over KV blocks + // ----------------------------- + for (int64_t blk = 0; + blk * block_size < attn_len; + ++blk) { + + const int64_t phys = block_tables[b * max_blocks_per_seq + blk]; + + const int64_t base = blk * block_size; + + // ----------------------------- + // Load K/V tile into shared memory + // ----------------------------- + for (int i = tid; + i < block_size * HEAD_DIM; + i += NUM_THREADS_PER_BLOCK) { + + const int tok = i / HEAD_DIM; + const int d = i % HEAD_DIM; + const int64_t logical = base + tok; + + if (logical < attn_len) { + const Tdata *k_ptr = k_cache + + phys * kv_block_stride + + kv_head * kv_head_stride + + tok * HEAD_DIM; + + const Tdata *v_ptr = v_cache + + phys * kv_block_stride + + kv_head * kv_head_stride + + tok * HEAD_DIM; + + smem_k[i] = static_cast(k_ptr[d]); + smem_v[i] = static_cast(v_ptr[d]); + } + } + __syncthreads(); + + // ----------------------------- + // Online softmax update + // ----------------------------- + for (int tok = tid; tok < block_size; tok += NUM_THREADS_PER_BLOCK) { + const int64_t logical = base + tok; + if (logical < attn_len) { + Tcompute dot = 0; +#pragma unroll + for (size_t d = 0; d < HEAD_DIM; ++d) { + dot += q_vec[d] * smem_k[tok * HEAD_DIM + d]; + } + dot *= scale; + + const Tcompute m_new = max((float)m, (float)dot); + const Tcompute alpha = exp((float)(m - m_new)); + const Tcompute beta = exp((float)(dot - m_new)); + +#pragma 
unroll + for (size_t d = 0; d < HEAD_DIM; ++d) { + acc[d] = acc[d] * alpha + beta * smem_v[tok * HEAD_DIM + d]; + } + + s = s * alpha + beta; + m = m_new; + } + } + __syncthreads(); + } + + const Tcompute inv_s = Tcompute(1) / (s + Tcompute(1e-6)); + + // ----------------------------- + // Write output + // ----------------------------- + Tdata *out_ptr = out + q_index * out_stride + h * HEAD_DIM; + +#pragma unroll + for (size_t d = 0; d < HEAD_DIM; ++d) { + out_ptr[d] = static_cast(acc[d] * inv_s); + } + } +} +} // namespace op::paged_attention_prefill::cuda + +#endif // __PAGED_ATTENTION_PREFILL_KERNEL_CUH__ diff --git a/src/infiniop/ops/paged_attention_prefill/info.h b/src/infiniop/ops/paged_attention_prefill/info.h new file mode 100644 index 000000000..a4d92e41a --- /dev/null +++ b/src/infiniop/ops/paged_attention_prefill/info.h @@ -0,0 +1,107 @@ +#ifndef __PAGED_ATTENTION_PREFILL_INFO_H__ +#define __PAGED_ATTENTION_PREFILL_INFO_H__ + +#include "../../../utils.h" +#include "../../tensor.h" +#include +#include +#include + +namespace op::paged_attention_prefill { + +class PagedAttentionPrefillInfo { + PagedAttentionPrefillInfo() = default; + +public: + // --- Data Types and Scale --- + infiniDtype_t dtype; + float scale; + + // --- Shape Dimensions --- + size_t num_seqs; + size_t num_heads; + size_t num_kv_heads; + size_t head_size; + size_t block_size; + size_t max_num_blocks_per_seq; + + // --- Strides for Memory Layout --- + ptrdiff_t q_stride; + ptrdiff_t kv_block_stride; + ptrdiff_t kv_head_stride; + ptrdiff_t o_stride; + + static utils::Result create( + infiniopTensorDescriptor_t out_desc, + infiniopTensorDescriptor_t q_desc, + infiniopTensorDescriptor_t k_cache_desc, + infiniopTensorDescriptor_t v_cache_desc, + infiniopTensorDescriptor_t block_tables_desc, + infiniopTensorDescriptor_t seq_lens_desc, + infiniopTensorDescriptor_t seq_offsets_desc, + infiniopTensorDescriptor_t cache_lens_desc, + const std::optional &alibi_slopes_desc, + float scale) { + + 
auto dtype = q_desc->dtype(); + CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_BF16, INFINI_DTYPE_F32); + if (out_desc->dtype() != dtype || k_cache_desc->dtype() != dtype || v_cache_desc->dtype() != dtype) { + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + + if (q_desc->ndim() != 3 || k_cache_desc->ndim() < 4 || v_cache_desc->ndim() < 4 || block_tables_desc->ndim() != 2 || seq_lens_desc->ndim() != 1) { + return INFINI_STATUS_BAD_TENSOR_SHAPE; + } + + CHECK_DTYPE(block_tables_desc->dtype(), INFINI_DTYPE_I32); + CHECK_DTYPE(seq_lens_desc->dtype(), INFINI_DTYPE_I32); + CHECK_DTYPE(cache_lens_desc->dtype(), INFINI_DTYPE_I32); + + // --- Extract shape dimensions --- + auto q_shape = q_desc->shape(); + auto k_cache_shape = k_cache_desc->shape(); + + size_t num_seqs = q_shape[0]; + size_t num_heads = q_shape[1]; + size_t head_size = q_shape[2]; + + if (head_size != 256 || head_size != 128 || head_size != 64 || head_size != 32 || head_size != 16) { + // 输出具体的错误原因和当前的参数值 + std::cerr << "[Error] Now only supports head_size = 128, but got " + << head_size << "." << std::endl; + // 建议返回 SHAPE 相关的错误码 + return INFINI_STATUS_BAD_TENSOR_SHAPE; + } + + size_t num_kv_heads = k_cache_shape[1]; + size_t block_size = v_cache_desc->shape()[2]; // 使用V cache的block size维度更可靠 + size_t max_num_blocks_per_seq = block_tables_desc->shape()[1]; + + // --- Calculate max_seq_len for shared memory allocation --- + // This is a safe upper bound. 
+ // info.max_seq_len = info.max_num_blocks_per_seq * info.block_size; + // --- Extract strides for memory access --- + ptrdiff_t q_stride = q_desc->stride(0); + ptrdiff_t kv_block_stride = k_cache_desc->stride(0); + ptrdiff_t kv_head_stride = k_cache_desc->stride(1); + ptrdiff_t o_stride = out_desc->stride(0); + + return utils::Result(PagedAttentionPrefillInfo{ + dtype, + scale, + num_seqs, + num_heads, + num_kv_heads, + head_size, + block_size, + max_num_blocks_per_seq, + q_stride, + kv_block_stride, + kv_head_stride, + o_stride}); + } +}; + +} // namespace op::paged_attention_prefill + +#endif // __PAGED_ATTENTION_PREFILL_INFO_H__ diff --git a/src/infiniop/ops/paged_attention_prefill/nvidia/paged_attention_prefill_nvidia.cu b/src/infiniop/ops/paged_attention_prefill/nvidia/paged_attention_prefill_nvidia.cu new file mode 100644 index 000000000..cf724da05 --- /dev/null +++ b/src/infiniop/ops/paged_attention_prefill/nvidia/paged_attention_prefill_nvidia.cu @@ -0,0 +1,233 @@ + + +#include "../../../devices/nvidia/nvidia_common.cuh" +#include "../../../devices/nvidia/nvidia_kernel_common.cuh" + +#include "../../../reduce/cuda/reduce.cuh" + +#include "../cuda/kernel.cuh" +#include "paged_attention_prefill_nvidia.cuh" + +template +static inline infiniStatus_t launch_prefill_kernel_impl( + void *out, + const void *q, + const void *k_cache, + const void *v_cache, + const void *block_tables, + const void *seq_lens, + const void *context_lens, + const void *seq_offsets, + size_t num_heads, + size_t num_seqs, + size_t num_kv_heads, + size_t max_blocks_per_seq, + size_t block_size, + ptrdiff_t q_stride, + ptrdiff_t kv_block_stride, + ptrdiff_t kv_head_stride, + ptrdiff_t out_stride, + float scale, + cudaStream_t stream) { + dim3 grid(num_heads, num_seqs, 1); + dim3 block(NUM_THREADS_PER_BLOCK); + + // Shared memory: + // - K block: block_size * HEAD_DIM + // - V block: block_size * HEAD_DIM + size_t smem_bytes = 2 * block_size * HEAD_DIM * sizeof(Tdata); + + 
op::paged_attention_prefill::cuda::paged_attention_prefill + <<>>( + (Tdata *)out, + (const Tdata *)q, + (const Tdata *)k_cache, + (const Tdata *)v_cache, + (const int64_t *)block_tables, + (const int64_t *)seq_lens, + (const int64_t *)context_lens, + (const int64_t *)seq_offsets, + num_heads, + num_kv_heads, + max_blocks_per_seq, + block_size, + q_stride, + kv_block_stride, + kv_head_stride, + out_stride, + scale); + + return INFINI_STATUS_SUCCESS; +} + +namespace op::paged_attention_prefill::nvidia { + +struct Descriptor::Opaque { + std::shared_ptr internal; +}; + +Descriptor::~Descriptor() { + delete _opaque; +} + +infiniStatus_t Descriptor::create( + infiniopHandle_t handle, + Descriptor **desc_ptr, + infiniopTensorDescriptor_t out_desc, + infiniopTensorDescriptor_t q_desc, + infiniopTensorDescriptor_t k_cache_desc, + infiniopTensorDescriptor_t v_cache_desc, + infiniopTensorDescriptor_t block_tables_desc, + infiniopTensorDescriptor_t seq_lens_desc, + infiniopTensorDescriptor_t seq_offsets_desc, + infiniopTensorDescriptor_t cache_lens_desc, + const std::optional &alibi_slopes_desc, + float scale) { + auto info = PagedAttentionPrefillInfo::create(out_desc, q_desc, k_cache_desc, v_cache_desc, block_tables_desc, seq_lens_desc, seq_offsets_desc, cache_lens_desc, alibi_slopes_desc, scale); + CHECK_RESULT(info); + *desc_ptr = new Descriptor( + new Opaque{reinterpret_cast(handle)->internal()}, + info.take(), 0, handle->device, handle->device_id); + + return INFINI_STATUS_SUCCESS; +} + +template +static infiniStatus_t dispatch_prefill_head_dim_and_dtype( + infiniDtype_t dtype, + size_t head_dim, + void *out, + const void *q, + const void *k_cache, + const void *v_cache, + const void *block_tables, + const void *seq_lens, + const void *context_lens, + const void *seq_offsets, + size_t num_heads, + size_t num_seqs, + size_t num_kv_heads, + size_t max_blocks_per_seq, + size_t block_size, + ptrdiff_t q_stride, + ptrdiff_t kv_block_stride, + ptrdiff_t kv_head_stride, + 
ptrdiff_t out_stride, + float scale, + cudaStream_t stream) { +#define DISPATCH_HEAD_DIM(HD) \ + case HD: \ + if (dtype == INFINI_DTYPE_F16) \ + return launch_prefill_kernel_impl( \ + out, q, k_cache, v_cache, block_tables, \ + seq_lens, context_lens, seq_offsets, \ + num_heads, num_seqs, num_kv_heads, \ + max_blocks_per_seq, block_size, \ + q_stride, kv_block_stride, kv_head_stride, out_stride, \ + scale, stream); \ + if (dtype == INFINI_DTYPE_BF16) \ + return launch_prefill_kernel_impl<__nv_bfloat16, HD, NUM_THREADS_PER_BLOCK>( \ + out, q, k_cache, v_cache, block_tables, \ + seq_lens, context_lens, seq_offsets, \ + num_heads, num_seqs, num_kv_heads, \ + max_blocks_per_seq, block_size, \ + q_stride, kv_block_stride, kv_head_stride, out_stride, \ + scale, stream); \ + if (dtype == INFINI_DTYPE_F32) \ + return launch_prefill_kernel_impl( \ + out, q, k_cache, v_cache, block_tables, \ + seq_lens, context_lens, seq_offsets, \ + num_heads, num_seqs, num_kv_heads, \ + max_blocks_per_seq, block_size, \ + q_stride, kv_block_stride, kv_head_stride, out_stride, \ + scale, stream); \ + return INFINI_STATUS_BAD_TENSOR_DTYPE; + + switch (head_dim) { + DISPATCH_HEAD_DIM(16) + DISPATCH_HEAD_DIM(32) + DISPATCH_HEAD_DIM(64) + DISPATCH_HEAD_DIM(128) + DISPATCH_HEAD_DIM(256) + default: + return INFINI_STATUS_BAD_TENSOR_SHAPE; + } + +#undef DISPATCH_HEAD_DIM +} + +infiniStatus_t Descriptor::calculate( + void *, + size_t, + void *out, + const void *q, + const void *k_cache, + const void *v_cache, + const void *block_tables, + const void *seq_lens, + const void *seq_offsets, + const void *context_lens, + const void *alibi_slopes, + void *stream_) const { + cudaStream_t stream = (cudaStream_t)stream_; + const size_t max_threads = _opaque->internal->maxThreadsPerBlock(); + + if (max_threads == CUDA_BLOCK_SIZE_4096) { + return dispatch_prefill_head_dim_and_dtype( + _info.dtype, + _info.head_size, + out, q, k_cache, v_cache, + block_tables, seq_lens, context_lens, seq_offsets, + 
_info.num_heads, + _info.num_seqs, + _info.num_kv_heads, + _info.max_num_blocks_per_seq, + _info.block_size, + _info.q_stride, + _info.kv_block_stride, + _info.kv_head_stride, + _info.o_stride, + _info.scale, + stream); + } else if (max_threads == CUDA_BLOCK_SIZE_1024) { + return dispatch_prefill_head_dim_and_dtype( + _info.dtype, + _info.head_size, + out, q, k_cache, v_cache, + block_tables, seq_lens, context_lens, seq_offsets, + _info.num_heads, + _info.num_seqs, + _info.num_kv_heads, + _info.max_num_blocks_per_seq, + _info.block_size, + _info.q_stride, + _info.kv_block_stride, + _info.kv_head_stride, + _info.o_stride, + _info.scale, + stream); + } else if (max_threads == CUDA_BLOCK_SIZE_512) { + return dispatch_prefill_head_dim_and_dtype( + _info.dtype, + _info.head_size, + out, q, k_cache, v_cache, + block_tables, seq_lens, context_lens, seq_offsets, + _info.num_heads, + _info.num_seqs, + _info.num_kv_heads, + _info.max_num_blocks_per_seq, + _info.block_size, + _info.q_stride, + _info.kv_block_stride, + _info.kv_head_stride, + _info.o_stride, + _info.scale, + stream); + } + + return INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED; +} + +} // namespace op::paged_attention_prefill::nvidia diff --git a/src/infiniop/ops/paged_attention_prefill/nvidia/paged_attention_prefill_nvidia.cuh b/src/infiniop/ops/paged_attention_prefill/nvidia/paged_attention_prefill_nvidia.cuh new file mode 100644 index 000000000..b9d3e97f1 --- /dev/null +++ b/src/infiniop/ops/paged_attention_prefill/nvidia/paged_attention_prefill_nvidia.cuh @@ -0,0 +1,8 @@ +#ifndef __PAGED_ATTENTION_PREFILL_NVIDIA_H__ +#define __PAGED_ATTENTION_PREFILL_NVIDIA_H__ + +#include "../paged_attention_prefill.h" + +DESCRIPTOR(nvidia) + +#endif // __PAGED_ATTENTION_PREFILL_NVIDIA_H__ diff --git a/src/infiniop/ops/paged_attention_prefill/operator.cc b/src/infiniop/ops/paged_attention_prefill/operator.cc new file mode 100644 index 000000000..99bddffc8 --- /dev/null +++ 
b/src/infiniop/ops/paged_attention_prefill/operator.cc @@ -0,0 +1,96 @@ +#include "../../operator.h" +#include "../../handle.h" +#include "infiniop/ops/paged_attention.h" + +#ifdef ENABLE_NVIDIA_API +#include "nvidia/paged_attention_prefill_nvidia.cuh" +#endif + +__C infiniStatus_t infiniopCreatePagedAttentionPrefillDescriptor( + infiniopHandle_t handle, + infiniopPagedAttentionDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t out_desc, + infiniopTensorDescriptor_t q_desc, + infiniopTensorDescriptor_t k_cache_desc, + infiniopTensorDescriptor_t v_cache_desc, + infiniopTensorDescriptor_t block_tables_desc, + infiniopTensorDescriptor_t seq_lens_desc, + infiniopTensorDescriptor_t seq_offsets_desc, + infiniopTensorDescriptor_t cache_lens_desc, + infiniopTensorDescriptor_t alibi_slopes_desc, + float scale) { + + infiniopTensorDescriptor_t alibi_opt = (alibi_slopes_desc == nullptr) ? nullptr : alibi_slopes_desc; + +#define CREATE(CASE, NAMESPACE) \ + case CASE: \ + return op::paged_attention_prefill::NAMESPACE::Descriptor::create( \ + handle, \ + reinterpret_cast(desc_ptr), \ + out_desc, q_desc, k_cache_desc, v_cache_desc, block_tables_desc, seq_lens_desc, seq_offsets_desc, cache_lens_desc, alibi_opt, scale); + + switch (handle->device) { +#ifdef ENABLE_NVIDIA_API + CREATE(INFINI_DEVICE_NVIDIA, nvidia) +#endif + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } +} + +__C infiniStatus_t infiniopGetPagedAttentionPrefillWorkspaceSize( + infiniopPagedAttentionDescriptor_t desc, + size_t *size) { + +#define GET(CASE, NAMESPACE) \ + case CASE: \ + *size = reinterpret_cast(desc)->workspaceSize(); \ + return INFINI_STATUS_SUCCESS; + + switch (desc->device_type) { +#ifdef ENABLE_NVIDIA_API + GET(INFINI_DEVICE_NVIDIA, nvidia) +#endif + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } +} + +__C infiniStatus_t infiniopPagedAttentionPrefill( + infiniopPagedAttentionDescriptor_t desc, + void *workspace, size_t workspace_size, + void *out, const void *q, 
const void *k_cache, const void *v_cache, + const void *block_tables, const void *seq_lens, const void *seq_offsets, const void *cache_lens, const void *alibi_slopes, + void *stream) { + +#define CALCULATE(CASE, NAMESPACE) \ + case CASE: \ + return reinterpret_cast(desc)->calculate( \ + workspace, workspace_size, out, q, k_cache, v_cache, block_tables, \ + seq_lens, seq_offsets, cache_lens, alibi_slopes, stream); + + switch (desc->device_type) { +#ifdef ENABLE_NVIDIA_API + CALCULATE(INFINI_DEVICE_NVIDIA, nvidia) +#endif + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } +} + +__C infiniStatus_t infiniopDestroyPagedAttentionPrefillDescriptor( + infiniopPagedAttentionDescriptor_t desc) { + +#define DESTROY(CASE, NAMESPACE) \ + case CASE: \ + delete reinterpret_cast(desc); \ + return INFINI_STATUS_SUCCESS; + + switch (desc->device_type) { +#ifdef ENABLE_NVIDIA_API + DESTROY(INFINI_DEVICE_NVIDIA, nvidia) +#endif + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } +} diff --git a/src/infiniop/ops/paged_attention_prefill/paged_attention_prefill.h b/src/infiniop/ops/paged_attention_prefill/paged_attention_prefill.h new file mode 100644 index 000000000..1e4fd7ddf --- /dev/null +++ b/src/infiniop/ops/paged_attention_prefill/paged_attention_prefill.h @@ -0,0 +1,55 @@ +#ifndef PAGED_ATTENTION_PREFILL_H +#define PAGED_ATTENTION_PREFILL_H + +#include "../../operator.h" +#include "info.h" + +#define DESCRIPTOR(NAMESPACE) \ + \ + namespace op::paged_attention_prefill::NAMESPACE { \ + class Descriptor final : public InfiniopDescriptor { \ + struct Opaque; \ + Opaque *_opaque; \ + PagedAttentionPrefillInfo _info; \ + size_t _workspace_size; \ + \ + Descriptor( \ + Opaque *opaque, \ + PagedAttentionPrefillInfo info, \ + size_t workspace_size, \ + infiniDevice_t device_type, \ + int device_id) \ + : InfiniopDescriptor{device_type, device_id}, \ + _opaque(opaque), \ + _info(info), \ + _workspace_size(workspace_size) {} \ + \ + public: \ + ~Descriptor(); \ + 
\ + size_t workspaceSize() const { return _workspace_size; } \ + \ + static infiniStatus_t create( \ + infiniopHandle_t handle, \ + Descriptor **desc_ptr, \ + infiniopTensorDescriptor_t out_desc, \ + infiniopTensorDescriptor_t q_desc, \ + infiniopTensorDescriptor_t k_cache_desc, \ + infiniopTensorDescriptor_t v_cache_desc, \ + infiniopTensorDescriptor_t block_tables_desc, \ + infiniopTensorDescriptor_t seq_lens_desc, \ + infiniopTensorDescriptor_t seq_offsets_desc, \ + infiniopTensorDescriptor_t cache_lens_desc, \ + const std::optional &alibi_slopes_desc, \ + float scale); \ + \ + infiniStatus_t calculate( \ + void *workspace, size_t workspace_size, \ + void *out, const void *q, const void *k_cache, const void *v_cache, \ + const void *block_tables, const void *seq_lens, const void *seq_offsets, const void *cache_lens, \ + const void *alibi_slopes, \ + void *stream) const; \ + }; \ + } + +#endif // PAGED_ATTENTION_PREFILL_H From fd96c013e414e33eb97400d4f8b84e0883e89220 Mon Sep 17 00:00:00 2001 From: PanZezhong Date: Fri, 26 Dec 2025 01:29:56 +0000 Subject: [PATCH 2/2] issue/848 fix naming and checking --- include/infiniop/ops/paged_attention_prefill.h | 1 + src/infiniop/ops/paged_attention_prefill/info.h | 10 +++++----- src/infiniop/ops/paged_attention_prefill/operator.cc | 10 +++++----- 3 files changed, 11 insertions(+), 10 deletions(-) diff --git a/include/infiniop/ops/paged_attention_prefill.h b/include/infiniop/ops/paged_attention_prefill.h index 6ceb357e5..64e669254 100644 --- a/include/infiniop/ops/paged_attention_prefill.h +++ b/include/infiniop/ops/paged_attention_prefill.h @@ -80,6 +80,7 @@ __C __export infiniStatus_t infiniopPagedAttentionPrefill( const void *block_tables, const void *seq_lens, const void *seq_offsets, + const void *cache_lens, const void *alibi_slopes, void *stream); diff --git a/src/infiniop/ops/paged_attention_prefill/info.h b/src/infiniop/ops/paged_attention_prefill/info.h index a4d92e41a..6f9aaf7ca 100644 --- 
a/src/infiniop/ops/paged_attention_prefill/info.h +++ b/src/infiniop/ops/paged_attention_prefill/info.h @@ -53,9 +53,9 @@ class PagedAttentionPrefillInfo { return INFINI_STATUS_BAD_TENSOR_SHAPE; } - CHECK_DTYPE(block_tables_desc->dtype(), INFINI_DTYPE_I32); - CHECK_DTYPE(seq_lens_desc->dtype(), INFINI_DTYPE_I32); - CHECK_DTYPE(cache_lens_desc->dtype(), INFINI_DTYPE_I32); + CHECK_DTYPE(block_tables_desc->dtype(), INFINI_DTYPE_I64); + CHECK_DTYPE(seq_lens_desc->dtype(), INFINI_DTYPE_I64); + CHECK_DTYPE(cache_lens_desc->dtype(), INFINI_DTYPE_I64); // --- Extract shape dimensions --- auto q_shape = q_desc->shape(); @@ -65,9 +65,9 @@ class PagedAttentionPrefillInfo { size_t num_heads = q_shape[1]; size_t head_size = q_shape[2]; - if (head_size != 256 || head_size != 128 || head_size != 64 || head_size != 32 || head_size != 16) { + if (!(head_size == 256 || head_size == 128 || head_size == 64 || head_size == 32 || head_size == 16)) { // 输出具体的错误原因和当前的参数值 - std::cerr << "[Error] Now only supports head_size = 128, but got " + std::cerr << "Paged Attention Prefill supports head dim of 16, 32, 64, 128, or 256, but got " << head_size << "." 
<< std::endl; // 建议返回 SHAPE 相关的错误码 return INFINI_STATUS_BAD_TENSOR_SHAPE; diff --git a/src/infiniop/ops/paged_attention_prefill/operator.cc b/src/infiniop/ops/paged_attention_prefill/operator.cc index 99bddffc8..93229eba9 100644 --- a/src/infiniop/ops/paged_attention_prefill/operator.cc +++ b/src/infiniop/ops/paged_attention_prefill/operator.cc @@ -1,6 +1,6 @@ #include "../../operator.h" #include "../../handle.h" -#include "infiniop/ops/paged_attention.h" +#include "infiniop/ops/paged_attention_prefill.h" #ifdef ENABLE_NVIDIA_API #include "nvidia/paged_attention_prefill_nvidia.cuh" @@ -8,7 +8,7 @@ __C infiniStatus_t infiniopCreatePagedAttentionPrefillDescriptor( infiniopHandle_t handle, - infiniopPagedAttentionDescriptor_t *desc_ptr, + infiniopPagedAttentionPrefillDescriptor_t *desc_ptr, infiniopTensorDescriptor_t out_desc, infiniopTensorDescriptor_t q_desc, infiniopTensorDescriptor_t k_cache_desc, @@ -39,7 +39,7 @@ __C infiniStatus_t infiniopCreatePagedAttentionPrefillDescriptor( } __C infiniStatus_t infiniopGetPagedAttentionPrefillWorkspaceSize( - infiniopPagedAttentionDescriptor_t desc, + infiniopPagedAttentionPrefillDescriptor_t desc, size_t *size) { #define GET(CASE, NAMESPACE) \ @@ -57,7 +57,7 @@ __C infiniStatus_t infiniopGetPagedAttentionPrefillWorkspaceSize( } __C infiniStatus_t infiniopPagedAttentionPrefill( - infiniopPagedAttentionDescriptor_t desc, + infiniopPagedAttentionPrefillDescriptor_t desc, void *workspace, size_t workspace_size, void *out, const void *q, const void *k_cache, const void *v_cache, const void *block_tables, const void *seq_lens, const void *seq_offsets, const void *cache_lens, const void *alibi_slopes, @@ -79,7 +79,7 @@ __C infiniStatus_t infiniopPagedAttentionPrefill( } __C infiniStatus_t infiniopDestroyPagedAttentionPrefillDescriptor( - infiniopPagedAttentionDescriptor_t desc) { + infiniopPagedAttentionPrefillDescriptor_t desc) { #define DESTROY(CASE, NAMESPACE) \ case CASE: \