From e13ad8f93749462e41fb34637c08525426e30558 Mon Sep 17 00:00:00 2001
From: PanZezhong
Date: Tue, 30 Dec 2025 07:49:15 +0000
Subject: [PATCH 1/2] issue/847 correct cache_lens naming

---
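Reviewer note (not part of the commit message): this patch is a pure rename.
The per-sequence length argument of paged_attention becomes cache_lens, since
it counts the tokens already resident in the KV cache rather than the query
length. Below is a minimal sketch of the renamed Python entry point; the
shapes mirror test/infinicore/ops/paged_attention.py, and the use of
from_torch (exported from python/infinicore/__init__.py) for conversion is an
assumption for illustration, not a copy of the test itself.

    import torch
    import infinicore

    num_seqs, num_heads, head_size = 4, 8, 64
    block_size, max_blocks_per_seq = 16, 8
    num_blocks = num_seqs * max_blocks_per_seq

    q = infinicore.from_torch(torch.rand(num_seqs, num_heads, head_size))
    k_cache = infinicore.from_torch(
        torch.rand(num_blocks, num_heads, block_size, head_size)
    )
    v_cache = infinicore.from_torch(
        torch.rand(num_blocks, num_heads, block_size, head_size)
    )
    block_tables = infinicore.from_torch(
        torch.arange(num_blocks, dtype=torch.int64).reshape(num_seqs, max_blocks_per_seq)
    )
    # Renamed argument: tokens already cached per sequence (formerly seq_lens).
    cache_lens = infinicore.from_torch(
        torch.randint(1, block_size * max_blocks_per_seq, (num_seqs,), dtype=torch.int64)
    )

    out = infinicore.paged_attention(
        q, k_cache, v_cache, block_tables, cache_lens, scale=head_size**-0.5
    )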
 include/infinicore/ops/paged_attention.hpp       |  6 +++---
 python/infinicore/ops/paged_attention.py         |  6 +++---
 .../ops/paged_attention/paged_attention.cc       | 14 +++++++-------
 .../paged_attention/paged_attention_infiniop.cc  |  8 ++++----
 src/infinicore/pybind11/ops/paged_attention.hpp  | 12 ++++++------
 test/infinicore/ops/paged_attention.py           | 16 ++++++++--------
 6 files changed, 31 insertions(+), 31 deletions(-)

diff --git a/include/infinicore/ops/paged_attention.hpp b/include/infinicore/ops/paged_attention.hpp
index 43d79214c..54d61fa89 100644
--- a/include/infinicore/ops/paged_attention.hpp
+++ b/include/infinicore/ops/paged_attention.hpp
@@ -9,10 +9,10 @@ namespace infinicore::op {
 class PagedAttention {
 public:
     using schema = void (*)(Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, std::optional<Tensor>, float);
-    static void execute(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor seq_lens, std::optional<Tensor> alibi_slopes, float);
+    static void execute(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor cache_lens, std::optional<Tensor> alibi_slopes, float);
     static common::OpDispatcher &dispatcher();
 };

-Tensor paged_attention(Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor seq_lens, std::optional<Tensor> alibi_slopes, float scale);
-void paged_attention_(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor seq_lens, std::optional<Tensor> alibi_slopes, float scale);
+Tensor paged_attention(Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor cache_lens, std::optional<Tensor> alibi_slopes, float scale);
+void paged_attention_(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor cache_lens, std::optional<Tensor> alibi_slopes, float scale);

 } // namespace infinicore::op
diff --git a/python/infinicore/ops/paged_attention.py b/python/infinicore/ops/paged_attention.py
index e5b448ff0..dfefa6a76 100644
--- a/python/infinicore/ops/paged_attention.py
+++ b/python/infinicore/ops/paged_attention.py
@@ -7,7 +7,7 @@ def paged_attention(
     k_cache: Tensor,
     v_cache: Tensor,
     block_tables: Tensor,
-    seq_lens: Tensor,
+    cache_lens: Tensor,
     alibi_slopes: Tensor | None = None,
     scale: float = 1.0,
     *,
@@ -20,7 +20,7 @@ def paged_attention(
                 k_cache._underlying,
                 v_cache._underlying,
                 block_tables._underlying,
-                seq_lens._underlying,
+                cache_lens._underlying,
                 alibi_slopes._underlying if alibi_slopes is not None else None,
                 scale,
             )
@@ -32,7 +32,7 @@ def paged_attention(
         k_cache._underlying,
         v_cache._underlying,
         block_tables._underlying,
-        seq_lens._underlying,
+        cache_lens._underlying,
         alibi_slopes._underlying if alibi_slopes is not None else None,
         scale,
     )
diff --git a/src/infinicore/ops/paged_attention/paged_attention.cc b/src/infinicore/ops/paged_attention/paged_attention.cc
index 30a6bdf10..86ba68097 100644
--- a/src/infinicore/ops/paged_attention/paged_attention.cc
+++ b/src/infinicore/ops/paged_attention/paged_attention.cc
@@ -9,20 +9,20 @@ common::OpDispatcher &PagedAttention::dispatcher() {
     return dispatcher_;
 };

-void PagedAttention::execute(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor seq_lens, std::optional<Tensor> alibi_slopes, float scale) {
-    INFINICORE_ASSERT_TENSORS_SAME_DEVICE(out, q, k_cache, v_cache, block_tables, seq_lens);
+void PagedAttention::execute(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor cache_lens, std::optional<Tensor> alibi_slopes, float scale) {
+    INFINICORE_ASSERT_TENSORS_SAME_DEVICE(out, q, k_cache, v_cache, block_tables, cache_lens);
     infinicore::context::setDevice(out->device());
-    dispatcher().lookup(out->device().getType())(out, q, k_cache, v_cache, block_tables, seq_lens, alibi_slopes, scale);
+    dispatcher().lookup(out->device().getType())(out, q, k_cache, v_cache, block_tables, cache_lens, alibi_slopes, scale);
 }

-Tensor paged_attention(Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor seq_lens, std::optional<Tensor> alibi_slopes, float scale) {
+Tensor paged_attention(Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor cache_lens, std::optional<Tensor> alibi_slopes, float scale) {
     auto out = Tensor::empty(q->shape(), q->dtype(), q->device());
-    paged_attention_(out, q, k_cache, v_cache, block_tables, seq_lens, alibi_slopes, scale);
+    paged_attention_(out, q, k_cache, v_cache, block_tables, cache_lens, alibi_slopes, scale);
     return out;
 }

-void paged_attention_(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor seq_lens, std::optional<Tensor> alibi_slopes, float scale) {
-    PagedAttention::execute(out, q, k_cache, v_cache, block_tables, seq_lens, alibi_slopes, scale);
+void paged_attention_(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor cache_lens, std::optional<Tensor> alibi_slopes, float scale) {
+    PagedAttention::execute(out, q, k_cache, v_cache, block_tables, cache_lens, alibi_slopes, scale);
 }

 } // namespace infinicore::op
diff --git a/src/infinicore/ops/paged_attention/paged_attention_infiniop.cc b/src/infinicore/ops/paged_attention/paged_attention_infiniop.cc
index 2a164ef62..64c659693 100644
--- a/src/infinicore/ops/paged_attention/paged_attention_infiniop.cc
+++ b/src/infinicore/ops/paged_attention/paged_attention_infiniop.cc
@@ -15,8 +15,8 @@ thread_local common::OpCache<infiniopPagedAttentionDescriptor_t> caches(
         }
     });

-void calculate(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor seq_lens, std::optional<Tensor> alibi_slopes, float scale) {
-    size_t seed = hash_combine(out, q, k_cache, v_cache, block_tables, seq_lens);
+void calculate(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor cache_lens, std::optional<Tensor> alibi_slopes, float scale) {
+    size_t seed = hash_combine(out, q, k_cache, v_cache, block_tables, cache_lens);

     auto device = context::getDevice();
     auto &cache = caches.getCache(device);
@@ -27,7 +27,7 @@ void calculate(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor bloc
     if (!desc_opt) {
         INFINICORE_CHECK_ERROR(infiniopCreatePagedAttentionDescriptor(
             context::getInfiniopHandle(device), &desc,
-            out->desc(), q->desc(), k_cache->desc(), v_cache->desc(), block_tables->desc(), seq_lens->desc(),
+            out->desc(), q->desc(), k_cache->desc(), v_cache->desc(), block_tables->desc(), cache_lens->desc(),
             alibi_slopes.has_value() ? alibi_slopes.value()->desc() : nullptr,
             scale));
         cache.put(seed, desc);
@@ -41,7 +41,7 @@ void calculate(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor bloc
     INFINICORE_CHECK_ERROR(infiniopPagedAttention(
         desc, workspace->data(), workspace_size,
-        out->data(), q->data(), k_cache->data(), v_cache->data(), block_tables->data(), seq_lens->data(),
+        out->data(), q->data(), k_cache->data(), v_cache->data(), block_tables->data(), cache_lens->data(),
         alibi_slopes.has_value() ? alibi_slopes.value()->data() : nullptr,
         context::getStream()));
 }
diff --git a/src/infinicore/pybind11/ops/paged_attention.hpp b/src/infinicore/pybind11/ops/paged_attention.hpp
index dcf12925a..ab77c87a4 100644
--- a/src/infinicore/pybind11/ops/paged_attention.hpp
+++ b/src/infinicore/pybind11/ops/paged_attention.hpp
@@ -8,21 +8,21 @@ namespace py = pybind11;

 namespace infinicore::ops {

-Tensor py_paged_attention(Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor seq_lens, pybind11::object alibi_slopes, float scale) {
+Tensor py_paged_attention(Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor cache_lens, pybind11::object alibi_slopes, float scale) {
     std::optional<Tensor> alibi_slopes_tensor = std::nullopt;
     if (!alibi_slopes.is_none()) {
         alibi_slopes_tensor = alibi_slopes.cast<Tensor>();
     }

-    return op::paged_attention(q, k_cache, v_cache, block_tables, seq_lens, alibi_slopes_tensor, scale);
+    return op::paged_attention(q, k_cache, v_cache, block_tables, cache_lens, alibi_slopes_tensor, scale);
 }

-void py_paged_attention_(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor seq_lens, pybind11::object alibi_slopes, float scale) {
+void py_paged_attention_(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor cache_lens, pybind11::object alibi_slopes, float scale) {
     std::optional<Tensor> alibi_slopes_tensor = std::nullopt;
     if (!alibi_slopes.is_none()) {
         alibi_slopes_tensor = alibi_slopes.cast<Tensor>();
     }

-    op::paged_attention_(out, q, k_cache, v_cache, block_tables, seq_lens, alibi_slopes_tensor, scale);
+    op::paged_attention_(out, q, k_cache, v_cache, block_tables, cache_lens, alibi_slopes_tensor, scale);
 }

 inline void bind_paged_attention(py::module &m) {
@@ -32,7 +32,7 @@ inline void bind_paged_attention(py::module &m) {
           py::arg("k_cache"),
           py::arg("v_cache"),
           py::arg("block_tables"),
-          py::arg("seq_lens"),
+          py::arg("cache_lens"),
           py::arg("alibi_slopes"),
           py::arg("scale"),
           R"doc(Paged attention of query and key cache tensors.)doc");
@@ -44,7 +44,7 @@ inline void bind_paged_attention(py::module &m) {
           py::arg("k_cache"),
           py::arg("v_cache"),
           py::arg("block_tables"),
-          py::arg("seq_lens"),
+          py::arg("cache_lens"),
           py::arg("alibi_slopes"),
           py::arg("scale"),
           R"doc(In-place paged attention of query and key cache tensors.)doc");
diff --git a/test/infinicore/ops/paged_attention.py b/test/infinicore/ops/paged_attention.py
index 79ddc83ab..cc9f2977f 100644
--- a/test/infinicore/ops/paged_attention.py
+++ b/test/infinicore/ops/paged_attention.py
@@ -62,7 +62,7 @@ def parse_test_cases():
     max_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
     num_blocks = num_seqs * max_blocks_per_seq  # A reasonable number for testing

-    seq_lens_torch = torch.randint(1, max_seq_len, (num_seqs,), dtype=torch.int64)
+    cache_lens_torch = torch.randint(1, max_seq_len, (num_seqs,), dtype=torch.int64)

     block_tables = torch.arange(
         0, num_seqs * max_blocks_per_seq, dtype=torch.int64
@@ -75,7 +75,7 @@ def parse_test_cases():
     v_cache_shape = (num_blocks, num_kv_heads, block_size, head_size)

     block_tables_shape = block_tables.shape
-    seq_lens_shape = seq_lens_torch.shape
+    cache_lens_shape = cache_lens_torch.shape

     # Generate test cases for all data types
     for dtype in _TENSOR_DTYPES:
@@ -91,10 +91,10 @@ def parse_test_cases():
             set_tensor=block_tables,
             dtype=infinicore.int64,
         )
-        seq_lens_spec = TensorSpec.from_tensor(
-            seq_lens_shape,
+        cache_lens_spec = TensorSpec.from_tensor(
+            cache_lens_shape,
             init_mode=TensorInitializer.MANUAL,
-            set_tensor=seq_lens_torch,
+            set_tensor=cache_lens_torch,
             dtype=infinicore.int64,
         )
@@ -108,7 +108,7 @@ def parse_test_cases():
                 k_cache_spec,
                 v_cache_spec,
                 block_tables_spec,
-                seq_lens_spec,
+                cache_lens_spec,
             ],
             kwargs={"alibi_slopes": None, "scale": scale},
             output_spec=None,
@@ -132,7 +132,7 @@ def ref_masked_attention(query, key, value, scale, attn_mask=None):


 def ref_single_query_cached_kv_attention(
-    query, key_cache, value_cache, block_tables, seq_lens, alibi_slopes, scale
+    query, key_cache, value_cache, block_tables, cache_lens, alibi_slopes, scale
 ):
     # Reference implementation for paged attention, iterating through each sequence.
     output = torch.empty_like(query)
@@ -143,7 +143,7 @@ def ref_single_query_cached_kv_attention(

     for i in range(num_seqs):
         q = query[i].unsqueeze(0)
-        seq_len = seq_lens[i].item()
+        seq_len = cache_lens[i].item()
         block_table = block_tables[i]

         keys_lst, values_lst = [], []

From 99b940b273832ea05a0531081d7866ee5b3bd198 Mon Sep 17 00:00:00 2001
From: PanZezhong
Date: Tue, 30 Dec 2025 08:34:44 +0000
Subject: [PATCH 2/2] issue/847 paged attention prefill single-stage interface
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
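Reviewer note (not part of the commit message): the single-stage interface
takes the queries of all prefill sequences flattened into one
[total_tokens, num_heads, head_size] tensor. Per sequence, seq_lens holds the
number of new tokens, cache_lens the tokens already in the KV cache, and
seq_offsets the start of each sequence on the flattened token axis; the
kernel binary-searches offset[i] <= token < offset[i + 1], so seq_offsets
needs num_seqs + 1 cumulative entries. The std::optional overload added to
hash_combine also lets alibi_slopes and scale participate in the
descriptor-cache seed of the existing decode path. Below is a sketch of the
intended call; the offset construction is inferred from the CUDA kernel, and
from_torch is assumed for conversion.

    import torch
    import infinicore

    num_seqs, num_heads, head_size, block_size = 3, 8, 64, 16
    seq_lens_t = torch.tensor([5, 3, 7], dtype=torch.int64)  # new tokens per sequence
    cache_lens_t = torch.zeros(num_seqs, dtype=torch.int64)  # nothing cached yet
    # Exclusive prefix sum over seq_lens: [0, 5, 8, 15]; 15 tokens in total.
    seq_offsets_t = torch.cat([torch.zeros(1, dtype=torch.int64), seq_lens_t.cumsum(0)])

    total_tokens = int(seq_offsets_t[-1])
    max_blocks_per_seq = (int(seq_lens_t.max()) + block_size - 1) // block_size
    num_blocks = num_seqs * max_blocks_per_seq

    q = infinicore.from_torch(torch.rand(total_tokens, num_heads, head_size))
    k_cache = infinicore.from_torch(
        torch.rand(num_blocks, num_heads, block_size, head_size)
    )
    v_cache = infinicore.from_torch(
        torch.rand(num_blocks, num_heads, block_size, head_size)
    )
    block_tables = infinicore.from_torch(
        torch.arange(num_blocks, dtype=torch.int64).reshape(num_seqs, max_blocks_per_seq)
    )

    out = infinicore.paged_attention_prefill(
        q, k_cache, v_cache, block_tables,
        infinicore.from_torch(cache_lens_t),
        infinicore.from_torch(seq_lens_t),
        infinicore.from_torch(seq_offsets_t),
        scale=head_size**-0.5,
    )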
 include/infinicore/common/hash.hpp                 | 10 ++++
 include/infinicore/ops.hpp                         |  1 +
 .../ops/paged_attention_prefill.hpp                | 18 ++++++
 python/infinicore/__init__.py                      |  2 +
 .../infinicore/ops/paged_attention_prefill.py      | 46 ++++++++++++++++
 .../paged_attention_infiniop.cc                    |  2 +-
 .../paged_attention_prefill.cc                     | 28 ++++++++++
 .../paged_attention_prefill_infiniop.cc            | 55 +++++++++++++++++++
 .../paged_attention_prefill/cuda/kernel.cuh        | 28 +++++-----
 xmake/nvidia.lua                                   |  2 +-
 10 files changed, 176 insertions(+), 16 deletions(-)
 create mode 100644 include/infinicore/ops/paged_attention_prefill.hpp
 create mode 100644 python/infinicore/ops/paged_attention_prefill.py
 create mode 100644 src/infinicore/ops/paged_attention_prefill/paged_attention_prefill.cc
 create mode 100644 src/infinicore/ops/paged_attention_prefill/paged_attention_prefill_infiniop.cc

diff --git a/include/infinicore/common/hash.hpp b/include/infinicore/common/hash.hpp
index ec930a53f..a634e0150 100644
--- a/include/infinicore/common/hash.hpp
+++ b/include/infinicore/common/hash.hpp
@@ -2,6 +2,7 @@

 #include "../tensor.hpp"

+#include <optional>
 #include <string>

 namespace infinicore {
@@ -24,6 +25,15 @@ inline void hash_combine(size_t &seed, Tensor tensor) {
     }
 }

+// Specialization for std::optional
+template <typename T>
+inline void hash_combine(size_t &seed, const std::optional<T> &opt) {
+    hash_combine(seed, opt.has_value());
+    if (opt) {
+        hash_combine(seed, *opt);
+    }
+}
+
 // Specialization for std::string
 inline void hash_combine(size_t &seed, const std::string &str) {
     hash_combine(seed, std::hash<std::string>{}(str));
diff --git a/include/infinicore/ops.hpp b/include/infinicore/ops.hpp
index 2e46fdf63..7d523472a 100644
--- a/include/infinicore/ops.hpp
+++ b/include/infinicore/ops.hpp
@@ -6,6 +6,7 @@
 #include "ops/matmul.hpp"
 #include "ops/ones.hpp"
 #include "ops/paged_attention.hpp"
+#include "ops/paged_attention_prefill.hpp"
 #include "ops/paged_caching.hpp"
 #include "ops/random_sample.hpp"
 #include "ops/rearrange.hpp"
diff --git a/include/infinicore/ops/paged_attention_prefill.hpp b/include/infinicore/ops/paged_attention_prefill.hpp
new file mode 100644
index 000000000..5f5bcdfe9
--- /dev/null
+++ b/include/infinicore/ops/paged_attention_prefill.hpp
@@ -0,0 +1,18 @@
+#pragma once
+
+#include "../device.hpp"
+#include "common/op.hpp"
+#include <optional>
+
+namespace infinicore::op {
+
+class PagedAttentionPrefill {
+public:
+    using schema = void (*)(Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, std::optional<Tensor>, float);
+    static void execute(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor cache_lens, Tensor seq_lens, Tensor seq_offsets, std::optional<Tensor> alibi_slopes, float);
+    static common::OpDispatcher &dispatcher();
+};
+
+Tensor paged_attention_prefill(Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor cache_lens, Tensor seq_lens, Tensor seq_offsets, std::optional<Tensor> alibi_slopes, float scale);
+void paged_attention_prefill_(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor cache_lens, Tensor seq_lens, Tensor seq_offsets, std::optional<Tensor> alibi_slopes, float scale);
+} // namespace infinicore::op
diff --git a/python/infinicore/__init__.py b/python/infinicore/__init__.py
index 43b0cbc09..16115e753 100644
--- a/python/infinicore/__init__.py
+++ b/python/infinicore/__init__.py
@@ -45,6 +45,7 @@
 from infinicore.ops.mul import mul
 from infinicore.ops.narrow import narrow
 from infinicore.ops.paged_attention import paged_attention
+from infinicore.ops.paged_attention_prefill import paged_attention_prefill
 from infinicore.ops.paged_caching import paged_caching
 from infinicore.ops.rearrange import rearrange
 from infinicore.ops.squeeze import squeeze
@@ -119,6 +120,7 @@
     "from_torch",
     "paged_caching",
     "paged_attention",
+    "paged_attention_prefill",
     "ones",
     "strided_empty",
     "strided_from_blob",
diff --git a/python/infinicore/ops/paged_attention_prefill.py b/python/infinicore/ops/paged_attention_prefill.py
new file mode 100644
index 000000000..2527ab23e
--- /dev/null
+++ b/python/infinicore/ops/paged_attention_prefill.py
@@ -0,0 +1,46 @@
+from infinicore.lib import _infinicore
+from infinicore.tensor import Tensor
+
+
+def paged_attention_prefill(
+    q: Tensor,
+    k_cache: Tensor,
+    v_cache: Tensor,
+    block_tables: Tensor,
+    cache_lens: Tensor,
+    seq_lens: Tensor,
+    seq_offsets: Tensor,
+    alibi_slopes: Tensor | None = None,
+    scale: float = 1.0,
+    *,
+    out: Tensor | None = None,
+):
+    if out is None:
+        return Tensor(
+            _infinicore.paged_attention_prefill(
+                q._underlying,
+                k_cache._underlying,
+                v_cache._underlying,
+                block_tables._underlying,
+                cache_lens._underlying,
+                seq_lens._underlying,
+                seq_offsets._underlying,
+                alibi_slopes._underlying if alibi_slopes is not None else None,
+                scale,
+            )
+        )
+
+    _infinicore.paged_attention_prefill_(
+        out._underlying,
+        q._underlying,
+        k_cache._underlying,
+        v_cache._underlying,
+        block_tables._underlying,
+        cache_lens._underlying,
+        seq_lens._underlying,
+        seq_offsets._underlying,
+        alibi_slopes._underlying if alibi_slopes is not None else None,
+        scale,
+    )
+
+    return out
diff --git a/src/infinicore/ops/paged_attention/paged_attention_infiniop.cc b/src/infinicore/ops/paged_attention/paged_attention_infiniop.cc
index 64c659693..bf72e9f8f 100644
--- a/src/infinicore/ops/paged_attention/paged_attention_infiniop.cc
+++ b/src/infinicore/ops/paged_attention/paged_attention_infiniop.cc
@@ -16,7 +16,7 @@ thread_local common::OpCache<infiniopPagedAttentionDescriptor_t> caches(
     });

 void calculate(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor cache_lens, std::optional<Tensor> alibi_slopes, float scale) {
-    size_t seed = hash_combine(out, q, k_cache, v_cache, block_tables, cache_lens);
+    size_t seed = hash_combine(out, q, k_cache, v_cache, block_tables, cache_lens, alibi_slopes, scale);

     auto device = context::getDevice();
     auto &cache = caches.getCache(device);
diff --git a/src/infinicore/ops/paged_attention_prefill/paged_attention_prefill.cc b/src/infinicore/ops/paged_attention_prefill/paged_attention_prefill.cc
new file mode 100644
index 000000000..c5f7f1a7f
--- /dev/null
+++ b/src/infinicore/ops/paged_attention_prefill/paged_attention_prefill.cc
@@ -0,0 +1,28 @@
+#include "infinicore/ops/paged_attention_prefill.hpp"
+
+#include "../../utils.hpp"
+
+namespace infinicore::op {
+
+common::OpDispatcher &PagedAttentionPrefill::dispatcher() {
+    static common::OpDispatcher dispatcher_;
+    return dispatcher_;
+};
+
+void PagedAttentionPrefill::execute(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor cache_lens, Tensor seq_lens, Tensor seq_offsets, std::optional<Tensor> alibi_slopes, float scale) {
+    INFINICORE_ASSERT_TENSORS_SAME_DEVICE(out, q, k_cache, v_cache, block_tables, cache_lens);
+    infinicore::context::setDevice(out->device());
+    dispatcher().lookup(out->device().getType())(out, q, k_cache, v_cache, block_tables, cache_lens, seq_lens, seq_offsets, alibi_slopes, scale);
+}
+
+Tensor paged_attention_prefill(Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor cache_lens, Tensor seq_lens, Tensor seq_offsets, std::optional<Tensor> alibi_slopes, float scale) {
+    auto out = Tensor::empty(q->shape(), q->dtype(), q->device());
+    paged_attention_prefill_(out, q, k_cache, v_cache, block_tables, cache_lens, seq_lens, seq_offsets, alibi_slopes, scale);
+    return out;
+}
+
+void paged_attention_prefill_(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor cache_lens, Tensor seq_lens, Tensor seq_offsets, std::optional<Tensor> alibi_slopes, float scale) {
+    PagedAttentionPrefill::execute(out, q, k_cache, v_cache, block_tables, cache_lens, seq_lens, seq_offsets, alibi_slopes, scale);
+}
+
+} // namespace infinicore::op
diff --git a/src/infinicore/ops/paged_attention_prefill/paged_attention_prefill_infiniop.cc b/src/infinicore/ops/paged_attention_prefill/paged_attention_prefill_infiniop.cc
new file mode 100644
index 000000000..a6db877dc
--- /dev/null
+++ b/src/infinicore/ops/paged_attention_prefill/paged_attention_prefill_infiniop.cc
@@ -0,0 +1,55 @@
+#include "../../utils.hpp"
+#include "infinicore/common/hash.hpp"
+#include "infinicore/ops/common/cache.hpp"
+#include "infinicore/ops/paged_attention_prefill.hpp"
+#include <infiniop.h>
+
+namespace infinicore::op::paged_attention_prefill_impl::infiniop {
+
+thread_local common::OpCache<infiniopPagedAttentionPrefillDescriptor_t> caches(
+    100, // capacity
+    [](infiniopPagedAttentionPrefillDescriptor_t &desc) {
+        if (desc != nullptr) {
+            INFINICORE_CHECK_ERROR(infiniopDestroyPagedAttentionPrefillDescriptor(desc));
+            desc = nullptr;
+        }
+    });
+
+void calculate(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor cache_lens, Tensor seq_lens, Tensor seq_offsets, std::optional<Tensor> alibi_slopes, float scale) {
+    size_t seed = hash_combine(out, q, k_cache, v_cache, block_tables, cache_lens, seq_lens, seq_offsets, alibi_slopes, scale);
+
+    auto device = context::getDevice();
+    auto &cache = caches.getCache(device);
+
+    auto desc_opt = cache.get(seed);
+    infiniopPagedAttentionPrefillDescriptor_t desc = nullptr;
+
+    if (!desc_opt) {
+        INFINICORE_CHECK_ERROR(infiniopCreatePagedAttentionPrefillDescriptor(
+            context::getInfiniopHandle(device), &desc,
+            out->desc(), q->desc(), k_cache->desc(), v_cache->desc(), block_tables->desc(),
+            cache_lens->desc(), seq_lens->desc(), seq_offsets->desc(),
+            alibi_slopes.has_value() ? alibi_slopes.value()->desc() : nullptr,
+            scale));
+        cache.put(seed, desc);
+    } else {
+        desc = *desc_opt;
+    }
+
+    size_t workspace_size = 0;
+    INFINICORE_CHECK_ERROR(infiniopGetPagedAttentionPrefillWorkspaceSize(desc, &workspace_size));
+    std::shared_ptr<Memory> workspace = context::allocateMemory(workspace_size);
+
+    INFINICORE_CHECK_ERROR(infiniopPagedAttentionPrefill(
+        desc, workspace->data(), workspace_size,
+        out->data(), q->data(), k_cache->data(), v_cache->data(), block_tables->data(), cache_lens->data(), seq_lens->data(), seq_offsets->data(),
+        alibi_slopes.has_value() ? alibi_slopes.value()->data() : nullptr,
+        context::getStream()));
+}
+
+static bool registered = []() {
+    PagedAttentionPrefill::dispatcher().registerAll(&calculate, false);
+    return true;
+}();
+
+} // namespace infinicore::op::paged_attention_prefill_impl::infiniop
diff --git a/src/infiniop/ops/paged_attention_prefill/cuda/kernel.cuh b/src/infiniop/ops/paged_attention_prefill/cuda/kernel.cuh
index ec9aad40c..75220da93 100644
--- a/src/infiniop/ops/paged_attention_prefill/cuda/kernel.cuh
+++ b/src/infiniop/ops/paged_attention_prefill/cuda/kernel.cuh
@@ -4,10 +4,10 @@
 namespace op::paged_attention_prefill::cuda {

 // Helper: binary search to find which sequence the current global_token_idx belongs to
-__device__ __forceinline__ int find_seq_id(int token_idx, const int64_t *offset, int num_seqs) {
-    int low = 0, high = num_seqs - 1;
+__device__ __forceinline__ size_t find_seq_id(size_t token_idx, const int64_t *offset, size_t num_seqs) {
+    size_t low = 0, high = num_seqs - 1;
     while (low <= high) {
-        int mid = (low + high) >> 1;
+        size_t mid = (low + high) >> 1;
         if (token_idx >= offset[mid] && token_idx < offset[mid + 1]) {
             return mid;
         } else if (token_idx < offset[mid]) {
@@ -32,22 +32,22 @@ __global__ void pagedAttentionPrefillKernel(
     const size_t num_seqs) {

     // --- 2D grid coordinates ---
-    const int global_token_idx = blockIdx.x; // global token index after flattening
-    const int head_idx = blockIdx.y;         // head index
-    const int dim_idx = threadIdx.x;         // dimension within the head
+    const size_t global_token_idx = blockIdx.x; // global token index after flattening
+    const size_t head_idx = blockIdx.y;         // head index
+    const size_t dim_idx = threadIdx.x;         // dimension within the head

     if (dim_idx >= head_size) {
         return;
     }

     // --- Binary-search the offsets for the owning seq_idx ---
-    int seq_idx = find_seq_id(global_token_idx, offset_, num_seqs);
+    size_t seq_idx = find_seq_id(global_token_idx, offset_, num_seqs);

     // --- Number of new tokens this sequence contributes to this prefill
     const int64_t cur_new_len = seq_lens_[seq_idx];

     // --- Relative position of this token within its sequence
-    int q_token_idx = global_token_idx - offset_[seq_idx];
+    size_t q_token_idx = global_token_idx - offset_[seq_idx];

     const Tdata *q_ptr_base = q_ + global_token_idx * num_heads * head_size + head_idx * head_size;
     Tdata *out_ptr = out_ + global_token_idx * num_heads * head_size + head_idx * head_size;
@@ -65,14 +65,14 @@
     // Pass 1: compute scores and track the maximum
     Tcompute max_score = -FLT_MAX;
-    for (int t = 0; t <= causal_limit; ++t) {
+    for (size_t t = 0; t <= causal_limit; ++t) {
         const int64_t b_idx = t / block_size;
         const int64_t t_off = t % block_size;
         const int64_t physical_block_id = block_table[b_idx];
         const Tdata *k_vec = k_cache_ + physical_block_id * kv_block_stride + kv_head_idx * kv_head_stride + t_off * head_size;
         Tcompute score = 0.0f;
-        for (int d = 0; d < head_size; ++d) {
+        for (size_t d = 0; d < head_size; ++d) {
             score += static_cast<Tcompute>(q_ptr_base[d]) * static_cast<Tcompute>(k_vec[d]);
         }
         score *= static_cast<Tcompute>(scale);
@@ -86,14 +86,14 @@
     // Pass 2: compute the sum of exponentials
     Tcompute sum_exp = 0.0f;
-    for (int t = 0; t <= causal_limit; ++t) {
+    for (size_t t = 0; t <= causal_limit; ++t) {
         const int64_t b_idx = t / block_size;
         const int64_t t_off = t % block_size;
         const int64_t physical_block_id = block_table[b_idx];
         const Tdata *k_vec = k_cache_ + physical_block_id * kv_block_stride + kv_head_idx * kv_head_stride + t_off * head_size;
         Tcompute score = 0.0f;
-        for (int d = 0; d < head_size; ++d) {
+        for (size_t d = 0; d < head_size; ++d) {
             score += static_cast<Tcompute>(q_ptr_base[d]) * static_cast<Tcompute>(k_vec[d]);
         }
         score *= static_cast<Tcompute>(scale);
@@ -106,14 +106,14 @@
     // Pass 3: weighted sum to produce the output
     Tcompute acc = 0.0f;
     Tcompute inv_sum = 1.0f / (sum_exp + 1e-6f);
-    for (int t = 0; t <= causal_limit; ++t) {
+    for (size_t t = 0; t <= causal_limit; ++t) {
         const int64_t b_idx = t / block_size;
         const int64_t t_off = t % block_size;
         const int64_t physical_block_id = block_table[b_idx];
         const Tdata *k_vec = k_cache_ + physical_block_id * kv_block_stride + kv_head_idx * kv_head_stride + t_off * head_size;
         Tcompute score = 0.0f;
-        for (int d = 0; d < head_size; ++d) {
+        for (size_t d = 0; d < head_size; ++d) {
             score += static_cast<Tcompute>(q_ptr_base[d]) * static_cast<Tcompute>(k_vec[d]);
         }
         score *= static_cast<Tcompute>(scale);
diff --git a/xmake/nvidia.lua b/xmake/nvidia.lua
index db575969b..a86090776 100644
--- a/xmake/nvidia.lua
+++ b/xmake/nvidia.lua
@@ -55,7 +55,7 @@ target("infiniop-nvidia")
         end
     end

-    add_cuflags("-Xcompiler=-Wno-error=deprecated-declarations")
+    add_cuflags("-Xcompiler=-Wno-error=deprecated-declarations", "-Xcompiler=-Wno-error=unused-function")

     local arch_opt = get_config("cuda_arch")
     if arch_opt and type(arch_opt) == "string" then