5 changes: 2 additions & 3 deletions src/infiniop/ops/paged_attention_prefill/cuda/kernel.cuh
@@ -53,9 +53,8 @@ __global__ void pagedAttentionPrefillKernel(
     Tdata *out_ptr = out_ + global_token_idx * num_heads * head_size + head_idx * head_size;
 
     // --- KV cache related information
-    const int64_t total_seq_len = cache_lens_[seq_idx];
-    const int64_t history_len = total_seq_len - cur_new_len;
-    const int64_t causal_limit = history_len + q_token_idx;
+    const int64_t cache_len = cache_lens_[seq_idx];
+    const int64_t causal_limit = cache_len + q_token_idx;
 
     const size_t num_queries_per_kv = num_heads / num_kv_heads;
     const size_t kv_head_idx = head_idx / num_queries_per_kv;
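With this change, cache_lens_ holds only the number of previously cached (history) tokens for each sequence, so the kernel no longer derives history_len by subtracting cur_new_len: query token q_token_idx may attend to cache positions 0 through cache_len + q_token_idx. A minimal sketch of that masking rule in PyTorch, mirroring the reference implementation in the test below (sizes are illustrative):

import torch

def causal_mask(cache_len: int, num_new: int) -> torch.Tensor:
    # Query token t of the new chunk sees every cached token plus the new
    # tokens up to and including itself: positions [0, cache_len + t].
    total_len = cache_len + num_new
    mask = torch.full((num_new, total_len), float("-inf"))
    for t in range(num_new):
        mask[t, : cache_len + t + 1] = 0.0
    return mask

# e.g. 4 cached tokens and 3 new tokens: row 0 unmasks 5 positions, row 2 unmasks all 7
print(causal_mask(cache_len=4, num_new=3))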
25 changes: 24 additions & 1 deletion src/infiniop/ops/paged_attention_prefill/info.h
@@ -56,9 +56,32 @@ class PagedAttentionPrefillInfo {
             return INFINI_STATUS_BAD_PARAM;
         }
 
+        auto k_shape = k_cache_desc->shape();
+        auto v_shape = v_cache_desc->shape();
+        auto block_tables_shape = block_tables_desc->shape();
+        auto cache_lens_shape = cache_lens_desc->shape();
+        auto seq_lens_shape = seq_lens_desc->shape();
+        auto offset_shape = offset_desc->shape();
+
+        if (k_shape.size() != 4 || v_shape.size() != 4) {
+            return INFINI_STATUS_BAD_TENSOR_SHAPE;
+        }
+
+        if (block_tables_shape.size() != 2) {
+            return INFINI_STATUS_BAD_TENSOR_SHAPE;
+        }
+
+        if (cache_lens_shape.size() != 1 || offset_shape.size() != 1) {
+            return INFINI_STATUS_BAD_TENSOR_SHAPE;
+        }
+
+        if ((offset_shape[0] - cache_lens_shape[0]) != 1) {
+            return INFINI_STATUS_BAD_PARAM;
+        }
+
         // Q shape: [total_tokens, heads, dim] (3D)
         auto q_shape = q_desc->shape();
-        if (q_shape.size() < 3) {
+        if (q_shape.size() != 3) {
             return INFINI_STATUS_BAD_TENSOR_SHAPE;
         }
         size_t total_q_tokens = q_shape[0];
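The added checks make the expected input ranks explicit: Q is 3-D [total_tokens, heads, dim], the K and V caches are 4-D, block_tables is 2-D, and cache_lens and offset are 1-D with exactly one more offset entry than there are sequences. A rough sketch of tensors that would pass these checks (sizes are made up, and the exact K/V cache layout is inferred from the test's indexing rather than stated in this header):

import torch

num_seqs, total_q_tokens = 2, 10
num_heads, num_kv_heads, head_size = 8, 2, 64
num_blocks, block_size, max_blocks_per_seq = 32, 16, 8

q = torch.empty(total_q_tokens, num_heads, head_size)                   # 3-D: [total_tokens, heads, dim]
k_cache = torch.empty(num_blocks, num_kv_heads, block_size, head_size)  # 4-D paged cache
v_cache = torch.empty_like(k_cache)                                     # 4-D paged cache
block_tables = torch.zeros(num_seqs, max_blocks_per_seq, dtype=torch.int64)  # 2-D
cache_lens = torch.zeros(num_seqs, dtype=torch.int64)   # 1-D: cached history length per sequence
seq_lens = torch.zeros(num_seqs, dtype=torch.int64)     # new-token count per sequence
offset = torch.zeros(num_seqs + 1, dtype=torch.int64)   # 1-D: len(offset) == len(cache_lens) + 1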
19 changes: 10 additions & 9 deletions test/infiniop/paged_attention_prefill.py
@@ -74,14 +74,15 @@ def allocate_slots(self, request_id, num_new_tokens):


 def ref_paged_attention_multi_turn(
-    query_new, k_cache, v_cache, block_tables, seq_lens, new_lens, offset, scale
+    query_new, k_cache, v_cache, block_tables, cache_lens, new_lens, offset, scale
 ):
     block_size = k_cache.shape[2]
     outputs = torch.zeros_like(query_new)
-    for i in range(len(offset) - 1):
-        total_len = seq_lens[i].item()
+    num_seqs = len(offset) - 1
+    for i in range(num_seqs):
         num_new = new_lens[i].item()
-        history_len = total_len - num_new
+        cache_len = cache_lens[i].item()
+        total_len = cache_lens[i].item() + num_new
 
         table = block_tables[i]
         keys_all, values_all = [], []
@@ -99,7 +100,7 @@ def ref_paged_attention_multi_turn(

         mask = torch.full((num_new, total_len), float("-inf"), device=Q.device)
         for q_idx in range(num_new):
-            mask[q_idx, : history_len + q_idx + 1] = 0.0
+            mask[q_idx, : cache_len + q_idx + 1] = 0.0
 
         scores = scores + mask.unsqueeze(0)
         attn_weights = torch.softmax(scores, dim=-1).to(Q.dtype)
@@ -163,8 +164,9 @@ def test(
         offset_list.append(cur_offset)
 
         cur_new_len = seq_lens_cpu[i].item()
-        table, cache_len = manager.allocate_slots(i, cur_new_len)
-        cache_lens_list.append(cache_len)
+        table, total_len = manager.allocate_slots(i, cur_new_len)
+        cache_lens = total_len - cur_new_len
+        cache_lens_list.append(cache_lens)
         all_block_tables.append(table)
 
         # Simulated KV insertion
@@ -175,9 +177,8 @@

         cur_offset = cur_offset + cur_new_len
 
-        history_len = cache_len - cur_new_len
         for t in range(cur_new_len):
-            logical_pos = history_len + t
+            logical_pos = cache_lens + t
             b_id = table[logical_pos // block_size]
             off = logical_pos % block_size
             k_cache.torch_tensor()[b_id, :, off, :] = k_new[t]
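Taken together, the test now treats cache_lens[i] as the history length of sequence i, keeps offset as a running prefix sum of new-token counts, and writes the t-th new token at logical cache position cache_lens[i] + t. A small sketch of that bookkeeping (assuming the same paged layout as above; the helper name and raw-tensor access are illustrative, not part of the test):

def write_new_kv(k_cache, v_cache, table, cache_len, k_new, v_new, block_size):
    # Append one sequence's new tokens directly behind its cached history.
    for t in range(len(k_new)):
        logical_pos = cache_len + t              # position in the sequence's full KV stream
        b_id = table[logical_pos // block_size]  # physical block from the block table
        off = logical_pos % block_size           # slot within that block
        k_cache[b_id, :, off, :] = k_new[t]
        v_cache[b_id, :, off, :] = v_new[t]

# Per-sequence query slice from the flattened batch (offset is a prefix sum):
# q_i = query_new[offset[i] : offset[i + 1]]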