diff --git a/src/infiniop/ops/paged_attention_prefill/cuda/kernel.cuh b/src/infiniop/ops/paged_attention_prefill/cuda/kernel.cuh index 75220da93..7c2d90197 100644 --- a/src/infiniop/ops/paged_attention_prefill/cuda/kernel.cuh +++ b/src/infiniop/ops/paged_attention_prefill/cuda/kernel.cuh @@ -53,9 +53,8 @@ __global__ void pagedAttentionPrefillKernel( Tdata *out_ptr = out_ + global_token_idx * num_heads * head_size + head_idx * head_size; // --- KV Cache 相关信息 - const int64_t total_seq_len = cache_lens_[seq_idx]; - const int64_t history_len = total_seq_len - cur_new_len; - const int64_t causal_limit = history_len + q_token_idx; + const int64_t cache_len = cache_lens_[seq_idx]; + const int64_t causal_limit = cache_len + q_token_idx; const size_t num_queries_per_kv = num_heads / num_kv_heads; const size_t kv_head_idx = head_idx / num_queries_per_kv; diff --git a/src/infiniop/ops/paged_attention_prefill/info.h b/src/infiniop/ops/paged_attention_prefill/info.h index 39c6b5715..2820f88b6 100644 --- a/src/infiniop/ops/paged_attention_prefill/info.h +++ b/src/infiniop/ops/paged_attention_prefill/info.h @@ -56,9 +56,32 @@ class PagedAttentionPrefillInfo { return INFINI_STATUS_BAD_PARAM; } + auto k_shape = k_cache_desc->shape(); + auto v_shape = v_cache_desc->shape(); + auto block_tables_shape = block_tables_desc->shape(); + auto cache_lens_shape = cache_lens_desc->shape(); + auto seq_lens_shape = seq_lens_desc->shape(); + auto offset_shape = offset_desc->shape(); + + if (k_shape.size() != 4 || v_shape.size() != 4) { + return INFINI_STATUS_BAD_TENSOR_SHAPE; + } + + if (block_tables_shape.size() != 2) { + return INFINI_STATUS_BAD_TENSOR_SHAPE; + } + + if (cache_lens_shape.size() != 1 || offset_shape.size() != 1) { + return INFINI_STATUS_BAD_TENSOR_SHAPE; + } + + if ((offset_shape[0] - cache_lens_shape[0]) != 1) { + return INFINI_STATUS_BAD_PARAM; + } + // Q shape: [total_tokens, heads, dim] (3D) auto q_shape = q_desc->shape(); - if (q_shape.size() < 3) { + if (q_shape.size() != 3) { return INFINI_STATUS_BAD_TENSOR_SHAPE; } size_t total_q_tokens = q_shape[0]; diff --git a/test/infiniop/paged_attention_prefill.py b/test/infiniop/paged_attention_prefill.py index 948fd72d5..112ecabf3 100644 --- a/test/infiniop/paged_attention_prefill.py +++ b/test/infiniop/paged_attention_prefill.py @@ -74,14 +74,15 @@ def allocate_slots(self, request_id, num_new_tokens): def ref_paged_attention_multi_turn( - query_new, k_cache, v_cache, block_tables, seq_lens, new_lens, offset, scale + query_new, k_cache, v_cache, block_tables, cache_lens, new_lens, offset, scale ): block_size = k_cache.shape[2] outputs = torch.zeros_like(query_new) - for i in range(len(offset) - 1): - total_len = seq_lens[i].item() + num_seqs = len(offset) - 1 + for i in range(num_seqs): num_new = new_lens[i].item() - history_len = total_len - num_new + cache_len = cache_lens[i].item() + total_len = cache_lens[i].item() + num_new table = block_tables[i] keys_all, values_all = [], [] @@ -99,7 +100,7 @@ def ref_paged_attention_multi_turn( mask = torch.full((num_new, total_len), float("-inf"), device=Q.device) for q_idx in range(num_new): - mask[q_idx, : history_len + q_idx + 1] = 0.0 + mask[q_idx, : cache_len + q_idx + 1] = 0.0 scores = scores + mask.unsqueeze(0) attn_weights = torch.softmax(scores, dim=-1).to(Q.dtype) @@ -163,8 +164,9 @@ def test( offset_list.append(cur_offset) cur_new_len = seq_lens_cpu[i].item() - table, cache_len = manager.allocate_slots(i, cur_new_len) - cache_lens_list.append(cache_len) + table, total_len = manager.allocate_slots(i, cur_new_len) + cache_lens = total_len - cur_new_len + cache_lens_list.append(cache_lens) all_block_tables.append(table) # Simulated KV insertion @@ -175,9 +177,8 @@ def test( cur_offset = cur_offset + cur_new_len - history_len = cache_len - cur_new_len for t in range(cur_new_len): - logical_pos = history_len + t + logical_pos = cache_lens + t b_id = table[logical_pos // block_size] off = logical_pos % block_size k_cache.torch_tensor()[b_id, :, off, :] = k_new[t]