10 changes: 10 additions & 0 deletions include/infinicore/common/hash.hpp
@@ -2,6 +2,7 @@

#include "../tensor.hpp"

#include <optional>
#include <type_traits>

namespace infinicore {
@@ -24,6 +25,15 @@ inline void hash_combine(size_t &seed, Tensor tensor) {
}
}

// Specialization for optional
template <typename T>
inline void hash_combine(size_t &seed, const std::optional<T> &opt) {
hash_combine(seed, opt.has_value());
if (opt) {
hash_combine(seed, *opt);
}
}

// Specialization for std::string
inline void hash_combine(size_t &seed, const std::string &str) {
hash_combine(seed, std::hash<std::string>{}(str));
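Note: the flag-then-payload order in the new std::optional overload is what keeps std::nullopt and an engaged default value from hashing identically. A minimal sketch, assuming the header's generic hash_combine overload for plain values (the same one the overload itself uses for the bool flag):

#include <optional>
#include "infinicore/common/hash.hpp"

void demo() {
    size_t a = 0, b = 0;
    std::optional<int> empty;           // disengaged
    std::optional<int> zero = 0;        // engaged, holding 0
    infinicore::hash_combine(a, empty); // mixes in `false` only
    infinicore::hash_combine(b, zero);  // mixes in `true`, then 0
    // a and b differ (barring a hash collision) because the flag is hashed first.
}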
1 change: 1 addition & 0 deletions include/infinicore/ops.hpp
@@ -6,6 +6,7 @@
#include "ops/matmul.hpp"
#include "ops/ones.hpp"
#include "ops/paged_attention.hpp"
#include "ops/paged_attention_prefill.hpp"
#include "ops/paged_caching.hpp"
#include "ops/random_sample.hpp"
#include "ops/rearrange.hpp"
6 changes: 3 additions & 3 deletions include/infinicore/ops/paged_attention.hpp
@@ -9,10 +9,10 @@ namespace infinicore::op {
class PagedAttention {
public:
using schema = void (*)(Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, std::optional<Tensor>, float);
static void execute(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor seq_lens, std::optional<Tensor> alibi_slopes, float);
static void execute(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor cache_lens, std::optional<Tensor> alibi_slopes, float);
static common::OpDispatcher<schema> &dispatcher();
};

Tensor paged_attention(Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor seq_lens, std::optional<Tensor> alibi_slopes, float scale);
void paged_attention_(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor seq_lens, std::optional<Tensor> alibi_slopes, float scale);
Tensor paged_attention(Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor cache_lens, std::optional<Tensor> alibi_slopes, float scale);
void paged_attention_(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor cache_lens, std::optional<Tensor> alibi_slopes, float scale);
} // namespace infinicore::op
18 changes: 18 additions & 0 deletions include/infinicore/ops/paged_attention_prefill.hpp
@@ -0,0 +1,18 @@
#pragma once

#include "../device.hpp"
#include "common/op.hpp"
#include <optional>

namespace infinicore::op {

class PagedAttentionPrefill {
public:
using schema = void (*)(Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, std::optional<Tensor>, float);
static void execute(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor cache_lens, Tensor seq_lens, Tensor seq_offsets, std::optional<Tensor> alibi_slopes, float);
static common::OpDispatcher<schema> &dispatcher();
};

Tensor paged_attention_prefill(Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor cache_lens, Tensor seq_lens, Tensor seq_offsets, std::optional<Tensor> alibi_slopes, float scale);
void paged_attention_prefill_(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor cache_lens, Tensor seq_lens, Tensor seq_offsets, std::optional<Tensor> alibi_slopes, float scale);
} // namespace infinicore::op
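For orientation, a sketch of a call into the new prefill entry point; the shape comments are assumptions about the packed-query layout, not guarantees made by this header:

// Illustrative only; tensor construction elided, layouts assumed.
using namespace infinicore;
Tensor out = op::paged_attention_prefill(
    q,            // packed new tokens for all sequences (assumed layout)
    k_cache, v_cache,
    block_tables, // per-sequence physical block indices (assumed layout)
    cache_lens,   // tokens already resident in the cache, per sequence
    seq_lens,     // new tokens being prefilled, per sequence
    seq_offsets,  // start of each sequence within q (assumed)
    std::nullopt, // alibi_slopes: none
    /*scale=*/0.125f);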
2 changes: 2 additions & 0 deletions python/infinicore/__init__.py
@@ -45,6 +45,7 @@
from infinicore.ops.mul import mul
from infinicore.ops.narrow import narrow
from infinicore.ops.paged_attention import paged_attention
from infinicore.ops.paged_attention_prefill import paged_attention_prefill
from infinicore.ops.paged_caching import paged_caching
from infinicore.ops.rearrange import rearrange
from infinicore.ops.squeeze import squeeze
@@ -119,6 +120,7 @@
"from_torch",
"paged_caching",
"paged_attention",
"paged_attention_prefill",
"ones",
"strided_empty",
"strided_from_blob",
6 changes: 3 additions & 3 deletions python/infinicore/ops/paged_attention.py
@@ -7,7 +7,7 @@ def paged_attention(
k_cache: Tensor,
v_cache: Tensor,
block_tables: Tensor,
seq_lens: Tensor,
cache_lens: Tensor,
alibi_slopes: Tensor | None = None,
scale: float = 1.0,
*,
@@ -20,7 +20,7 @@ def paged_attention(
k_cache._underlying,
v_cache._underlying,
block_tables._underlying,
seq_lens._underlying,
cache_lens._underlying,
alibi_slopes._underlying if alibi_slopes is not None else None,
scale,
)
@@ -32,7 +32,7 @@ def paged_attention(
k_cache._underlying,
v_cache._underlying,
block_tables._underlying,
seq_lens._underlying,
cache_lens._underlying,
alibi_slopes._underlying if alibi_slopes is not None else None,
scale,
)
46 changes: 46 additions & 0 deletions python/infinicore/ops/paged_attention_prefill.py
@@ -0,0 +1,46 @@
from infinicore.lib import _infinicore
from infinicore.tensor import Tensor


def paged_attention_prefill(
q: Tensor,
k_cache: Tensor,
v_cache: Tensor,
block_tables: Tensor,
cache_lens: Tensor,
seq_lens: Tensor,
seq_offsets: Tensor,
alibi_slopes: Tensor | None = None,
scale: float = 1.0,
*,
out: Tensor | None = None,
):
if out is None:
return Tensor(
_infinicore.paged_attention_prefill(
q._underlying,
k_cache._underlying,
v_cache._underlying,
block_tables._underlying,
cache_lens._underlying,
seq_lens._underlying,
seq_offsets._underlying,
alibi_slopes._underlying if alibi_slopes is not None else None,
scale,
)
)

_infinicore.paged_attention_prefill_(
out._underlying,
q._underlying,
k_cache._underlying,
v_cache._underlying,
block_tables._underlying,
cache_lens._underlying,
seq_lens._underlying,
seq_offsets._underlying,
alibi_slopes._underlying if alibi_slopes is not None else None,
scale,
)

return out
14 changes: 7 additions & 7 deletions src/infinicore/ops/paged_attention/paged_attention.cc
@@ -9,20 +9,20 @@ common::OpDispatcher<PagedAttention::schema> &PagedAttention::dispatcher() {
return dispatcher_;
};

void PagedAttention::execute(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor seq_lens, std::optional<Tensor> alibi_slopes, float scale) {
INFINICORE_ASSERT_TENSORS_SAME_DEVICE(out, q, k_cache, v_cache, block_tables, seq_lens);
void PagedAttention::execute(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor cache_lens, std::optional<Tensor> alibi_slopes, float scale) {
INFINICORE_ASSERT_TENSORS_SAME_DEVICE(out, q, k_cache, v_cache, block_tables, cache_lens);
infinicore::context::setDevice(out->device());
dispatcher().lookup(out->device().getType())(out, q, k_cache, v_cache, block_tables, seq_lens, alibi_slopes, scale);
dispatcher().lookup(out->device().getType())(out, q, k_cache, v_cache, block_tables, cache_lens, alibi_slopes, scale);
}

Tensor paged_attention(Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor seq_lens, std::optional<Tensor> alibi_slopes, float scale) {
Tensor paged_attention(Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor cache_lens, std::optional<Tensor> alibi_slopes, float scale) {
auto out = Tensor::empty(q->shape(), q->dtype(), q->device());
paged_attention_(out, q, k_cache, v_cache, block_tables, seq_lens, alibi_slopes, scale);
paged_attention_(out, q, k_cache, v_cache, block_tables, cache_lens, alibi_slopes, scale);
return out;
}

void paged_attention_(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor seq_lens, std::optional<Tensor> alibi_slopes, float scale) {
PagedAttention::execute(out, q, k_cache, v_cache, block_tables, seq_lens, alibi_slopes, scale);
void paged_attention_(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor cache_lens, std::optional<Tensor> alibi_slopes, float scale) {
PagedAttention::execute(out, q, k_cache, v_cache, block_tables, cache_lens, alibi_slopes, scale);
}

} // namespace infinicore::op
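Both entry points funnel into the same execute path; a minimal usage sketch of the two calling styles (tensor setup assumed):

// Out-of-place: allocates out with q's shape, dtype, and device.
Tensor out = op::paged_attention(q, k_cache, v_cache, block_tables,
                                 cache_lens, std::nullopt, scale);
// In-place: writes into a caller-provided buffer.
op::paged_attention_(out, q, k_cache, v_cache, block_tables,
                     cache_lens, std::nullopt, scale);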
@@ -15,8 +15,8 @@ thread_local common::OpCache<size_t, infiniopPagedAttentionDescriptor_t> caches(
}
});

void calculate(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor seq_lens, std::optional<Tensor> alibi_slopes, float scale) {
size_t seed = hash_combine(out, q, k_cache, v_cache, block_tables, seq_lens);
void calculate(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor cache_lens, std::optional<Tensor> alibi_slopes, float scale) {
size_t seed = hash_combine(out, q, k_cache, v_cache, block_tables, cache_lens, alibi_slopes, scale);

auto device = context::getDevice();
auto &cache = caches.getCache(device);
@@ -27,7 +27,7 @@ void calculate(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor bloc
if (!desc_opt) {
INFINICORE_CHECK_ERROR(infiniopCreatePagedAttentionDescriptor(
context::getInfiniopHandle(device), &desc,
out->desc(), q->desc(), k_cache->desc(), v_cache->desc(), block_tables->desc(), seq_lens->desc(),
out->desc(), q->desc(), k_cache->desc(), v_cache->desc(), block_tables->desc(), cache_lens->desc(),
alibi_slopes.has_value() ? alibi_slopes.value()->desc() : nullptr,
scale));
cache.put(seed, desc);
@@ -41,7 +41,7 @@ void calculate(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor bloc

INFINICORE_CHECK_ERROR(infiniopPagedAttention(
desc, workspace->data(), workspace_size,
out->data(), q->data(), k_cache->data(), v_cache->data(), block_tables->data(), seq_lens->data(),
out->data(), q->data(), k_cache->data(), v_cache->data(), block_tables->data(), cache_lens->data(),
alibi_slopes.has_value() ? alibi_slopes.value()->data() : nullptr,
context::getStream()));
}
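Worth noting in the hunk above: the decode path's cache seed now also mixes in alibi_slopes and scale, where the old seed stopped at the (renamed) length tensor. Since the cached infiniop descriptor is created from those arguments, calls differing only in them must not share a key; the new std::optional overload in hash.hpp is what lets alibi_slopes participate at all. A sketch, assuming the variadic, value-returning hash_combine used in that hunk:

// Calls differing only in scale now map to distinct descriptor-cache keys.
size_t key_half = hash_combine(out, q, k_cache, v_cache, block_tables,
                               cache_lens, alibi_slopes, /*scale=*/0.5f);
size_t key_one  = hash_combine(out, q, k_cache, v_cache, block_tables,
                               cache_lens, alibi_slopes, /*scale=*/1.0f);
// key_half != key_one (barring collisions), so a descriptor created for one
// scale is not silently reused for the other.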
@@ -0,0 +1,28 @@
#include "infinicore/ops/paged_attention_prefill.hpp"

#include "../../utils.hpp"

namespace infinicore::op {

common::OpDispatcher<PagedAttentionPrefill::schema> &PagedAttentionPrefill::dispatcher() {
static common::OpDispatcher<PagedAttentionPrefill::schema> dispatcher_;
return dispatcher_;
};

void PagedAttentionPrefill::execute(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor cache_lens, Tensor seq_lens, Tensor seq_offsets, std::optional<Tensor> alibi_slopes, float scale) {
INFINICORE_ASSERT_TENSORS_SAME_DEVICE(out, q, k_cache, v_cache, block_tables, cache_lens, seq_lens, seq_offsets);
infinicore::context::setDevice(out->device());
dispatcher().lookup(out->device().getType())(out, q, k_cache, v_cache, block_tables, cache_lens, seq_lens, seq_offsets, alibi_slopes, scale);
}

Tensor paged_attention_prefill(Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor cache_lens, Tensor seq_lens, Tensor seq_offsets, std::optional<Tensor> alibi_slopes, float scale) {
auto out = Tensor::empty(q->shape(), q->dtype(), q->device());
paged_attention_prefill_(out, q, k_cache, v_cache, block_tables, cache_lens, seq_lens, seq_offsets, alibi_slopes, scale);
return out;
}

void paged_attention_prefill_(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor cache_lens, Tensor seq_lens, Tensor seq_offsets, std::optional<Tensor> alibi_slopes, float scale) {
PagedAttentionPrefill::execute(out, q, k_cache, v_cache, block_tables, cache_lens, seq_lens, seq_offsets, alibi_slopes, scale);
}

} // namespace infinicore::op
@@ -0,0 +1,55 @@
#include "../../utils.hpp"
#include "infinicore/common/hash.hpp"
#include "infinicore/ops/common/cache.hpp"
#include "infinicore/ops/paged_attention_prefill.hpp"
#include <infiniop.h>

namespace infinicore::op::paged_attention_prefill_impl::infiniop {

thread_local common::OpCache<size_t, infiniopPagedAttentionPrefillDescriptor_t> caches(
100, // capacity
[](infiniopPagedAttentionPrefillDescriptor_t &desc) {
if (desc != nullptr) {
INFINICORE_CHECK_ERROR(infiniopDestroyPagedAttentionPrefillDescriptor(desc));
desc = nullptr;
}
});

void calculate(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor cache_lens, Tensor seq_lens, Tensor seq_offsets, std::optional<Tensor> alibi_slopes, float scale) {
size_t seed = hash_combine(out, q, k_cache, v_cache, block_tables, cache_lens, seq_lens, seq_offsets, alibi_slopes, scale);

auto device = context::getDevice();
auto &cache = caches.getCache(device);

auto desc_opt = cache.get(seed);
infiniopPagedAttentionPrefillDescriptor_t desc = nullptr;

if (!desc_opt) {
INFINICORE_CHECK_ERROR(infiniopCreatePagedAttentionPrefillDescriptor(
context::getInfiniopHandle(device), &desc,
out->desc(), q->desc(), k_cache->desc(), v_cache->desc(), block_tables->desc(),
cache_lens->desc(), seq_lens->desc(), seq_offsets->desc(),
alibi_slopes.has_value() ? alibi_slopes.value()->desc() : nullptr,
scale));
cache.put(seed, desc);
} else {
desc = *desc_opt;
}

size_t workspace_size = 0;
INFINICORE_CHECK_ERROR(infiniopGetPagedAttentionPrefillWorkspaceSize(desc, &workspace_size));
std::shared_ptr<Memory> workspace = context::allocateMemory(workspace_size);

INFINICORE_CHECK_ERROR(infiniopPagedAttentionPrefill(
desc, workspace->data(), workspace_size,
out->data(), q->data(), k_cache->data(), v_cache->data(), block_tables->data(), cache_lens->data(), seq_lens->data(), seq_offsets->data(),
alibi_slopes.has_value() ? alibi_slopes.value()->data() : nullptr,
context::getStream()));
}

static bool registered = []() {
PagedAttentionPrefill::dispatcher().registerAll(&calculate, false);
return true;
}();

} // namespace infinicore::op::paged_attention_prefill_impl::infiniop
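The registered initializer above is the self-registration idiom these backend files share: a file-scope static forces the lambda to run at library load, installing calculate in the dispatcher for every device type (the false argument presumably declines to overwrite existing registrations). A condensed, standalone sketch of the pattern, with names invented for illustration:

#include <iostream>
#include <vector>

using Kernel = void (*)();

std::vector<Kernel> &registry() {
    static std::vector<Kernel> r; // function-local static avoids init-order issues
    return r;
}

void my_kernel() { std::cout << "kernel ran\n"; }

// Evaluated during static initialization, before main().
static bool registered = []() {
    registry().push_back(&my_kernel);
    return true;
}();

int main() {
    for (Kernel k : registry()) {
        k(); // prints "kernel ran"
    }
    return 0;
}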
12 changes: 6 additions & 6 deletions src/infinicore/pybind11/ops/paged_attention.hpp
@@ -8,21 +8,21 @@ namespace py = pybind11;

namespace infinicore::ops {

Tensor py_paged_attention(Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor seq_lens, pybind11::object alibi_slopes, float scale) {
Tensor py_paged_attention(Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor cache_lens, pybind11::object alibi_slopes, float scale) {
std::optional<Tensor> alibi_slopes_tensor = std::nullopt;
if (!alibi_slopes.is_none()) {
alibi_slopes_tensor = alibi_slopes.cast<Tensor>();
}
return op::paged_attention(q, k_cache, v_cache, block_tables, seq_lens, alibi_slopes_tensor, scale);
return op::paged_attention(q, k_cache, v_cache, block_tables, cache_lens, alibi_slopes_tensor, scale);
}

void py_paged_attention_(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor seq_lens, pybind11::object alibi_slopes, float scale) {
void py_paged_attention_(Tensor out, Tensor q, Tensor k_cache, Tensor v_cache, Tensor block_tables, Tensor cache_lens, pybind11::object alibi_slopes, float scale) {
std::optional<Tensor> alibi_slopes_tensor = std::nullopt;
if (!alibi_slopes.is_none()) {
alibi_slopes_tensor = alibi_slopes.cast<Tensor>();
}

op::paged_attention_(out, q, k_cache, v_cache, block_tables, seq_lens, alibi_slopes_tensor, scale);
op::paged_attention_(out, q, k_cache, v_cache, block_tables, cache_lens, alibi_slopes_tensor, scale);
}

inline void bind_paged_attention(py::module &m) {
@@ -32,7 +32,7 @@ inline void bind_paged_attention(py::module &m) {
py::arg("k_cache"),
py::arg("v_cache"),
py::arg("block_tables"),
py::arg("seq_lens"),
py::arg("cache_lens"),
py::arg("alibi_slopes"),
py::arg("scale"),
R"doc(Paged attention of query and key cache tensors.)doc");
@@ -44,7 +44,7 @@ inline void bind_paged_attention(py::module &m) {
py::arg("k_cache"),
py::arg("v_cache"),
py::arg("block_tables"),
py::arg("seq_lens"),
py::arg("cache_lens"),
py::arg("alibi_slopes"),
py::arg("scale"),
R"doc(In-place paged attention of query and key cache tensors.)doc");