1 change: 1 addition & 0 deletions include/infiniop.h
@@ -15,6 +15,7 @@
#include "infiniop/ops/lp_norm.h"
#include "infiniop/ops/mul.h"
#include "infiniop/ops/ones.h"
#include "infiniop/ops/paged_attention_prefill.h"
#include "infiniop/ops/random_sample.h"
#include "infiniop/ops/rearrange.h"
#include "infiniop/ops/relu.h"
96 changes: 96 additions & 0 deletions include/infiniop/ops/paged_attention_prefill.h
@@ -0,0 +1,96 @@
#ifndef __INFINIOP_PAGED_ATTENTION_PREFILL_API_H__
#define __INFINIOP_PAGED_ATTENTION_PREFILL_API_H__

#include "../operator_descriptor.h"

// Define an opaque handle for the Paged Attention descriptor.
typedef struct InfiniopDescriptor *infiniopPagedAttentionPrefillDescriptor_t;

/**
* @brief Creates a descriptor for the Paged Attention prefill operation.
*
* This function initializes a descriptor that holds all the metadata needed
* for the paged attention computation.
*
* @param handle The handle to the InfiniOP library context.
* @param desc_ptr A pointer to store the created descriptor.
* @param out_desc Descriptor for the output tensor.
* @param q_desc Descriptor for the query tensor.
* @param k_cache_desc Descriptor for the key cache tensor.
* @param v_cache_desc Descriptor for the value cache tensor.
* @param block_tables_desc Descriptor for the block tables tensor.
* @param seq_lens_desc Descriptor for the sequence lengths tensor.
* @param seq_offsets_desc Descriptor for the sequence offsets tensor.
* @param cache_lens_desc Descriptor for the cache lengths tensor.
* @param alibi_slopes_desc Optional descriptor for the ALiBi slopes tensor. Can be NULL.
* @param scale The attention scaling factor.
* @return infiniStatus_t Status code of the operation.
*/
__C __export infiniStatus_t infiniopCreatePagedAttentionPrefillDescriptor(
infiniopHandle_t handle,
infiniopPagedAttentionPrefillDescriptor_t *desc_ptr,
infiniopTensorDescriptor_t out_desc,
infiniopTensorDescriptor_t q_desc,
infiniopTensorDescriptor_t k_cache_desc,
infiniopTensorDescriptor_t v_cache_desc,
infiniopTensorDescriptor_t block_tables_desc,
infiniopTensorDescriptor_t seq_lens_desc,
infiniopTensorDescriptor_t seq_offsets_desc,
infiniopTensorDescriptor_t cache_lens_desc,
infiniopTensorDescriptor_t alibi_slopes_desc,
float scale);

/**
* @brief Retrieves the workspace size required for the Paged Attention operation.
*
* @param desc The Paged Attention descriptor.
* @param size A pointer to store the required workspace size in bytes.
* @return infiniStatus_t Status code of the operation.
*/
__C __export infiniStatus_t infiniopGetPagedAttentionPrefillWorkspaceSize(
infiniopPagedAttentionPrefillDescriptor_t desc, size_t *size);

/**
* @brief Executes the Paged Attention prefill operation.
*
* @param desc The Paged Attention descriptor.
* @param workspace Pointer to the workspace memory.
* @param workspace_size The size of the workspace.
* @param out Pointer to the output tensor data.
* @param q Pointer to the query tensor data.
* @param k_cache Pointer to the key cache data.
* @param v_cache Pointer to the value cache data.
* @param block_tables Pointer to the block tables data.
* @param seq_lens Pointer to the sequence lengths data.
* @param seq_offsets Pointer to the sequence offsets data.
* @param cache_lens Pointer to the cache lengths data.
* @param alibi_slopes Pointer to the ALiBi slopes data. Can be NULL.
* @param stream The CUDA stream for the operation. Can be NULL.
* @return infiniStatus_t Status code of the operation.
*/
__C __export infiniStatus_t infiniopPagedAttentionPrefill(
infiniopPagedAttentionPrefillDescriptor_t desc,
void *workspace,
size_t workspace_size,
void *out,
const void *q,
const void *k_cache,
const void *v_cache,
const void *block_tables,
const void *seq_lens,
const void *seq_offsets,
const void *cache_lens,
const void *alibi_slopes,
void *stream);

/**
* @brief Destroys a Paged Attention descriptor.
*
* @param desc The descriptor to be destroyed.
* @return infiniStatus_t Status code of the operation.
*/
__C __export infiniStatus_t infiniopDestroyPagedAttentionPrefillDescriptor(
infiniopPagedAttentionPrefillDescriptor_t desc);

#endif // __INFINIOP_PAGED_ATTENTION_PREFILL_API_H__
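
Taken together, the four entry points follow the usual create / query-workspace / run / destroy descriptor lifecycle. A minimal host-side sketch of that sequence is below; it assumes the handle, tensor descriptors, device buffers, stream, and a CHECK error-handling macro are set up elsewhere, and the scale value is illustrative.

infiniopPagedAttentionPrefillDescriptor_t desc;
CHECK(infiniopCreatePagedAttentionPrefillDescriptor(
handle, &desc, out_desc, q_desc, k_cache_desc, v_cache_desc,
block_tables_desc, seq_lens_desc, seq_offsets_desc, cache_lens_desc,
/*alibi_slopes_desc=*/NULL, /*scale=*/0.088388f)); // 1/sqrt(128), illustrative

size_t workspace_size = 0;
CHECK(infiniopGetPagedAttentionPrefillWorkspaceSize(desc, &workspace_size));
// Allocate `workspace` (workspace_size bytes) on the device here.

CHECK(infiniopPagedAttentionPrefill(
desc, workspace, workspace_size,
out, q, k_cache, v_cache,
block_tables, seq_lens, seq_offsets, cache_lens,
/*alibi_slopes=*/NULL, stream));

CHECK(infiniopDestroyPagedAttentionPrefillDescriptor(desc));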
151 changes: 151 additions & 0 deletions src/infiniop/ops/paged_attention_prefill/cuda/kernel.cuh
@@ -0,0 +1,151 @@
#ifndef __PAGED_ATTENTION_PREFILL_KERNEL_CUH__
#define __PAGED_ATTENTION_PREFILL_KERNEL_CUH__

namespace op::paged_attention_prefill::cuda {

template <typename Tdata, typename Tcompute,
size_t HEAD_DIM, size_t NUM_THREADS_PER_BLOCK>
__global__ void paged_attention_prefill(
Tdata *__restrict__ out, // [total_seq, H, D]
const Tdata *__restrict__ q, // [total_seq, H, D]
const Tdata *__restrict__ k_cache, // [num_blocks, nkvh, block_size, dh]
const Tdata *__restrict__ v_cache,
const int64_t *__restrict__ block_tables, // [B, max_blocks_per_seq]
const int64_t *__restrict__ seq_lens, // [B]
const int64_t *__restrict__ cache_lens, // [B]
const int64_t *__restrict__ seq_offsets, // [B]
size_t num_heads,
size_t num_kv_heads,
size_t max_blocks_per_seq,
size_t block_size,
ptrdiff_t q_stride, // = num_heads * HEAD_DIM
ptrdiff_t kv_block_stride, // stride between physical blocks
ptrdiff_t kv_head_stride, // stride between kv heads
ptrdiff_t out_stride, // = num_heads * HEAD_DIM
float scale) {
const int h = blockIdx.x;
const int b = blockIdx.y;
const int tid = threadIdx.x;

const int64_t seq_len = seq_lens[b];
const int64_t ctx_len = cache_lens[b];
const int64_t q_base = seq_offsets[b];

if (h >= num_heads) {
return;
}

const int kv_head = h / (num_heads / num_kv_heads);

extern __shared__ char smem_bytes[];
Tcompute *smem = reinterpret_cast<Tcompute *>(smem_bytes);
Tcompute *smem_k = smem; // [block_size, D]
Tcompute *smem_v = smem + block_size * HEAD_DIM; // [block_size, D]

// Loop over prefill tokens of this request
for (int64_t t = 0; t < seq_len; ++t) {
const int64_t q_index = q_base + t;

// -----------------------------
// Load Q
// -----------------------------
const Tdata *q_ptr = q + q_index * q_stride + h * HEAD_DIM;

Tcompute q_vec[HEAD_DIM];
#pragma unroll
for (size_t d = 0; d < HEAD_DIM; ++d) {
q_vec[d] = static_cast<Tcompute>(q_ptr[d]);
}

const int64_t q_logical = ctx_len + t;
const int64_t attn_len = q_logical + 1; // causal

// Online softmax state
Tcompute m = -INFINITY;
Tcompute s = 0;
Tcompute acc[HEAD_DIM] = {0};

// -----------------------------
// Iterate over KV blocks
// -----------------------------
for (int64_t blk = 0;
blk * block_size < attn_len;
++blk) {

const int64_t phys = block_tables[b * max_blocks_per_seq + blk];

const int64_t base = blk * block_size;

// -----------------------------
// Load K/V tile into shared memory
// -----------------------------
for (int i = tid;
i < block_size * HEAD_DIM;
i += NUM_THREADS_PER_BLOCK) {

const int tok = i / HEAD_DIM;
const int d = i % HEAD_DIM;
const int64_t logical = base + tok;

if (logical < attn_len) {
const Tdata *k_ptr = k_cache
+ phys * kv_block_stride
+ kv_head * kv_head_stride
+ tok * HEAD_DIM;

const Tdata *v_ptr = v_cache
+ phys * kv_block_stride
+ kv_head * kv_head_stride
+ tok * HEAD_DIM;

smem_k[i] = static_cast<Tcompute>(k_ptr[d]);
smem_v[i] = static_cast<Tcompute>(v_ptr[d]);
}
}
__syncthreads();

// -----------------------------
// Online softmax update
// -----------------------------
// Every thread walks all tokens in the tile and keeps an identical copy
// of the running softmax state (m, s, acc). The arithmetic is redundant
// across threads, but it avoids a cross-thread reduction; striding tokens
// by thread here would leave each thread with only a partial state and
// make the unreduced output write below incorrect.
for (int tok = 0; tok < static_cast<int>(block_size); ++tok) {
const int64_t logical = base + tok;
if (logical < attn_len) {
Tcompute dot = 0;
#pragma unroll
for (size_t d = 0; d < HEAD_DIM; ++d) {
dot += q_vec[d] * smem_k[tok * HEAD_DIM + d];
}
dot *= scale;

const Tcompute m_new = fmaxf(static_cast<float>(m), static_cast<float>(dot));
const Tcompute alpha = expf(static_cast<float>(m - m_new));
const Tcompute beta = expf(static_cast<float>(dot - m_new));

#pragma unroll
for (size_t d = 0; d < HEAD_DIM; ++d) {
acc[d] = acc[d] * alpha + beta * smem_v[tok * HEAD_DIM + d];
}

s = s * alpha + beta;
m = m_new;
}
}
__syncthreads();
}

const Tcompute inv_s = Tcompute(1) / (s + Tcompute(1e-6));

// -----------------------------
// Write output
// -----------------------------
Tdata *out_ptr = out + q_index * out_stride + h * HEAD_DIM;

// All threads hold identical (acc, s); split the store across threads.
for (size_t d = tid; d < HEAD_DIM; d += NUM_THREADS_PER_BLOCK) {
out_ptr[d] = static_cast<Tdata>(acc[d] * inv_s);
}
}
}
} // namespace op::paged_attention_prefill::cuda

#endif // __PAGED_ATTENTION_PREFILL_KERNEL_CUH__
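
The per-token inner loop above is the standard online-softmax recurrence: each new score updates the running maximum m, and the normalizer s and accumulator acc are rescaled by exp(m - m_new) so the exponentials never overflow. A scalar host-side reference of the same recurrence, handy for checking the kernel numerically (the function name is illustrative, not part of this PR):

#include <algorithm>
#include <cmath>
#include <vector>

// Scalar reference for the kernel's update: after the loop, acc / s equals
// sum_t softmax(scores)[t] * values[t], computed in a single pass.
inline float online_softmax_ref(const std::vector<float> &scores,
const std::vector<float> &values) {
float m = -INFINITY, s = 0.0f, acc = 0.0f;
for (size_t t = 0; t < scores.size(); ++t) {
const float m_new = std::max(m, scores[t]);
const float alpha = std::exp(m - m_new); // rescales the old state
const float beta = std::exp(scores[t] - m_new); // weight of the new token
acc = acc * alpha + beta * values[t];
s = s * alpha + beta;
m = m_new;
}
return acc / s;
}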
107 changes: 107 additions & 0 deletions src/infiniop/ops/paged_attention_prefill/info.h
@@ -0,0 +1,107 @@
#ifndef __PAGED_ATTENTION_PREFILL_INFO_H__
#define __PAGED_ATTENTION_PREFILL_INFO_H__

#include "../../../utils.h"
#include "../../tensor.h"
#include <iostream>
#include <optional>
#include <vector>

namespace op::paged_attention_prefill {

class PagedAttentionPrefillInfo {
PagedAttentionPrefillInfo() = default;

public:
// --- Data Types and Scale ---
infiniDtype_t dtype;
float scale;

// --- Shape Dimensions ---
size_t num_seqs;
size_t num_heads;
size_t num_kv_heads;
size_t head_size;
size_t block_size;
size_t max_num_blocks_per_seq;

// --- Strides for Memory Layout ---
ptrdiff_t q_stride;
ptrdiff_t kv_block_stride;
ptrdiff_t kv_head_stride;
ptrdiff_t o_stride;

static utils::Result<PagedAttentionPrefillInfo> create(
infiniopTensorDescriptor_t out_desc,
infiniopTensorDescriptor_t q_desc,
infiniopTensorDescriptor_t k_cache_desc,
infiniopTensorDescriptor_t v_cache_desc,
infiniopTensorDescriptor_t block_tables_desc,
infiniopTensorDescriptor_t seq_lens_desc,
infiniopTensorDescriptor_t seq_offsets_desc,
infiniopTensorDescriptor_t cache_lens_desc,
const std::optional<infiniopTensorDescriptor_t> &alibi_slopes_desc,
float scale) {

auto dtype = q_desc->dtype();
CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_BF16, INFINI_DTYPE_F32);
if (out_desc->dtype() != dtype || k_cache_desc->dtype() != dtype || v_cache_desc->dtype() != dtype) {
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}

if (q_desc->ndim() != 3 || k_cache_desc->ndim() < 4 || v_cache_desc->ndim() < 4 || block_tables_desc->ndim() != 2 || seq_lens_desc->ndim() != 1) {
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}

CHECK_DTYPE(block_tables_desc->dtype(), INFINI_DTYPE_I64);
CHECK_DTYPE(seq_lens_desc->dtype(), INFINI_DTYPE_I64);
CHECK_DTYPE(seq_offsets_desc->dtype(), INFINI_DTYPE_I64);
CHECK_DTYPE(cache_lens_desc->dtype(), INFINI_DTYPE_I64);

// --- Extract shape dimensions ---
auto q_shape = q_desc->shape();
auto k_cache_shape = k_cache_desc->shape();

size_t num_seqs = q_shape[0];
size_t num_heads = q_shape[1];
size_t head_size = q_shape[2];

if (!(head_size == 256 || head_size == 128 || head_size == 64 || head_size == 32 || head_size == 16)) {
// Report the unsupported value along with the accepted set.
std::cerr << "Paged Attention Prefill supports head dim of 16, 32, 64, 128, or 256, but got "
<< head_size << "." << std::endl;
return INFINI_STATUS_BAD_TENSOR_SHAPE;
}

size_t num_kv_heads = k_cache_shape[1];
size_t block_size = v_cache_desc->shape()[2]; // the V cache's block-size dimension is the more reliable source here
size_t max_num_blocks_per_seq = block_tables_desc->shape()[1];

// --- Extract strides for memory access ---
ptrdiff_t q_stride = q_desc->stride(0);
ptrdiff_t kv_block_stride = k_cache_desc->stride(0);
ptrdiff_t kv_head_stride = k_cache_desc->stride(1);
ptrdiff_t o_stride = out_desc->stride(0);

return utils::Result<PagedAttentionPrefillInfo>(PagedAttentionPrefillInfo{
dtype,
scale,
num_seqs,
num_heads,
num_kv_heads,
head_size,
block_size,
max_num_blocks_per_seq,
q_stride,
kv_block_stride,
kv_head_stride,
o_stride});
}
};

} // namespace op::paged_attention_prefill

#endif // __PAGED_ATTENTION_PREFILL_INFO_H__
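
For orientation, a hypothetical launcher driven by this info struct might look as follows. It is not part of this PR: the dtype, HEAD_DIM, and thread-count choices are assumptions, and the grid and shared-memory sizing simply mirror what kernel.cuh indexes (blockIdx.x = head, blockIdx.y = sequence, one K tile plus one V tile of block_size * HEAD_DIM Tcompute elements each).

// Hypothetical launch sketch. Assumes Tdata = half, Tcompute = float,
// HEAD_DIM = 128, and 256 threads per block.
constexpr size_t HEAD_DIM = 128;
constexpr size_t NUM_THREADS = 256;

dim3 grid(static_cast<unsigned>(info.num_heads),
static_cast<unsigned>(info.num_seqs)); // x = head, y = sequence
size_t smem_bytes = 2 * info.block_size * HEAD_DIM * sizeof(float); // K + V tiles

op::paged_attention_prefill::cuda::paged_attention_prefill<half, float, HEAD_DIM, NUM_THREADS>
<<<grid, NUM_THREADS, smem_bytes, stream>>>(
out, q, k_cache, v_cache,
block_tables, seq_lens, cache_lens, seq_offsets,
info.num_heads, info.num_kv_heads,
info.max_num_blocks_per_seq, info.block_size,
info.q_stride, info.kv_block_stride, info.kv_head_stride,
info.o_stride, info.scale);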