From 7549959824de555ed6c829fde5d8726609cb3bb5 Mon Sep 17 00:00:00 2001 From: greenhandhand <781740145@qq.com> Date: Sun, 14 Dec 2025 15:21:28 +0800 Subject: [PATCH 1/2] =?UTF-8?q?=E6=B7=BB=E5=8A=A0=20zeros=5F=20=E5=88=9D?= =?UTF-8?q?=E5=A7=8B=E5=8C=96=E7=AE=97=E5=AD=90?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- include/infinicore/ops/zeros.hpp | 15 +++++++ python/infinicore/nn/__init__.py | 3 +- python/infinicore/nn/init/__init__.py | 5 +++ python/infinicore/nn/init/zeros_.py | 9 ++++ src/infinicore/ops/zeros/zeros.cc | 26 +++++++++++ src/infinicore/ops/zeros/zeros_infiniop.cc | 52 ++++++++++++++++++++++ src/infinicore/pybind11/ops.hpp | 2 + src/infinicore/pybind11/ops/zeros_.hpp | 18 ++++++++ 8 files changed, 129 insertions(+), 1 deletion(-) create mode 100644 include/infinicore/ops/zeros.hpp create mode 100644 python/infinicore/nn/init/__init__.py create mode 100644 python/infinicore/nn/init/zeros_.py create mode 100644 src/infinicore/ops/zeros/zeros.cc create mode 100644 src/infinicore/ops/zeros/zeros_infiniop.cc create mode 100644 src/infinicore/pybind11/ops/zeros_.hpp diff --git a/include/infinicore/ops/zeros.hpp b/include/infinicore/ops/zeros.hpp new file mode 100644 index 000000000..709c41855 --- /dev/null +++ b/include/infinicore/ops/zeros.hpp @@ -0,0 +1,15 @@ +#pragma once + +#include "common/op.hpp" + +namespace infinicore::op { +class Zeros { + +public: + using schema = void (*)(Tensor); + static void execute(Tensor output); + static common::OpDispatcher &dispatcher(); +}; + +void zeros_(Tensor output); +} // namespace infinicore::op diff --git a/python/infinicore/nn/__init__.py b/python/infinicore/nn/__init__.py index 73c9f0aaa..e221fa958 100644 --- a/python/infinicore/nn/__init__.py +++ b/python/infinicore/nn/__init__.py @@ -1,5 +1,6 @@ from infinicore.nn import functional from infinicore.nn.modules import * # noqa: F403 from infinicore.nn.parameter import InfiniCoreParameter as Parameter +from infinicore.nn import init -__all__ = ["functional", "Parameter"] +__all__ = ["functional", "Parameter", "init"] diff --git a/python/infinicore/nn/init/__init__.py b/python/infinicore/nn/init/__init__.py new file mode 100644 index 000000000..6b37744f2 --- /dev/null +++ b/python/infinicore/nn/init/__init__.py @@ -0,0 +1,5 @@ +from .zeros_ import zeros_ + +__all__ = [ + "zeros_", +] diff --git a/python/infinicore/nn/init/zeros_.py b/python/infinicore/nn/init/zeros_.py new file mode 100644 index 000000000..718c84135 --- /dev/null +++ b/python/infinicore/nn/init/zeros_.py @@ -0,0 +1,9 @@ +import infinicore +from infinicore.lib import _infinicore +from infinicore.tensor import Tensor + + +def zeros_(input: Tensor) -> Tensor: + r"""Fill the input tensor with the scalar value 0.""" + _infinicore.zeros_(input._underlying) + return input diff --git a/src/infinicore/ops/zeros/zeros.cc b/src/infinicore/ops/zeros/zeros.cc new file mode 100644 index 000000000..97102054d --- /dev/null +++ b/src/infinicore/ops/zeros/zeros.cc @@ -0,0 +1,26 @@ +#include "infinicore/ops/zeros.hpp" +#include + +namespace infinicore::op { + +common::OpDispatcher &Zeros::dispatcher() { + static common::OpDispatcher dispatcher_; + return dispatcher_; +}; + +void Zeros::execute(Tensor output) { + context::setDevice(output->device()); + auto device_type = context::getDevice().getType(); + auto func = dispatcher().lookup(device_type); + + if (func == nullptr) { + throw std::runtime_error("No Zeros implementation found for device type: " + std::to_string(static_cast(device_type))); + } + + func(output); +} + +void zeros_(Tensor output) { + Zeros::execute(output); +} +} // namespace infinicore::op diff --git a/src/infinicore/ops/zeros/zeros_infiniop.cc b/src/infinicore/ops/zeros/zeros_infiniop.cc new file mode 100644 index 000000000..62762a4c7 --- /dev/null +++ b/src/infinicore/ops/zeros/zeros_infiniop.cc @@ -0,0 +1,52 @@ +#include "../../utils.hpp" +#include "infinicore/common/hash.hpp" +#include "infinicore/ops/common/cache.hpp" +#include "infinicore/ops/zeros.hpp" +#include + +namespace infinicore::op::zeros_impl::infiniop { + +thread_local common::OpCache caches( + 100, // capacity + [](infiniopZerosDescriptor_t &desc) { + if (desc != nullptr) { + INFINICORE_CHECK_ERROR(infiniopDestroyZerosDescriptor(desc)); + desc = nullptr; + } + }); + +void calculate(Tensor output) { + size_t seed = hash_combine(output); + + auto device_type = context::getDevice().getType(); + auto device_index = context::getDevice().getIndex(); + + auto &cache = caches.getCache(device_type, device_index); + + auto desc_opt = cache.get(seed); + infiniopZerosDescriptor_t desc = nullptr; + + if (!desc_opt) { + INFINICORE_CHECK_ERROR(infiniopCreateZerosDescriptor( + context::getInfiniopHandle(output->device()), &desc, + output->desc(), output->desc())); + cache.put(seed, desc); + } else { + desc = *desc_opt; + } + + size_t workspace_size = 0; + INFINICORE_CHECK_ERROR(infiniopGetZerosWorkspaceSize(desc, &workspace_size)); + std::shared_ptr workspace = context::allocateMemory(workspace_size); + + INFINICORE_CHECK_ERROR(infiniopZeros( + desc, workspace->data(), workspace_size, + output->data(), output->data(), context::getStream())); +} + +static bool registered = []() { + Zeros::dispatcher().registerAll(&calculate, false); + return true; +}(); + +} // namespace infinicore::op::zeros_impl::infiniop diff --git a/src/infinicore/pybind11/ops.hpp b/src/infinicore/pybind11/ops.hpp index 978defa17..531043dd2 100644 --- a/src/infinicore/pybind11/ops.hpp +++ b/src/infinicore/pybind11/ops.hpp @@ -15,6 +15,7 @@ #include "ops/rope.hpp" #include "ops/silu.hpp" #include "ops/swiglu.hpp" +#include "ops/zeros_.hpp" namespace py = pybind11; @@ -34,6 +35,7 @@ inline void bind(py::module &m) { bind_swiglu(m); bind_rope(m); bind_embedding(m); + bind_zeros_(m); } } // namespace infinicore::ops diff --git a/src/infinicore/pybind11/ops/zeros_.hpp b/src/infinicore/pybind11/ops/zeros_.hpp new file mode 100644 index 000000000..660fadd30 --- /dev/null +++ b/src/infinicore/pybind11/ops/zeros_.hpp @@ -0,0 +1,18 @@ +#pragma once + +#include + +#include "infinicore/ops/zeros.hpp" + +namespace py = pybind11; + +namespace infinicore::ops { + +inline void bind_zeros_(py::module &m) { + m.def("zeros_", + &op::zeros_, + py::arg("input"), + R"doc(Fills the input tensor with zeros.)doc"); +} + +} // namespace infinicore::ops From d82a480d35502e748687a49ba4e77a40d106ee30 Mon Sep 17 00:00:00 2001 From: greenhandhand <781740145@qq.com> Date: Mon, 15 Dec 2025 15:38:49 +0800 Subject: [PATCH 2/2] =?UTF-8?q?Infinicore=20=E6=AF=94=E8=B5=9B=EF=BC=8C?= =?UTF-8?q?=E5=AE=9E=E7=8E=B0=20bitwise=5Fleft=5Fshift,=20fold,=20index=5F?= =?UTF-8?q?select,=20log2,=20mish=20=E7=AE=97=E5=AD=90?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- include/infinicore/ops/bitwise_left_shift.hpp | 16 ++ include/infinicore/ops/fold.hpp | 18 +++ include/infinicore/ops/index_select.hpp | 16 ++ include/infinicore/ops/log2.hpp | 16 ++ include/infinicore/ops/mish.hpp | 15 ++ python/infinicore/__init__.py | 5 + python/infinicore/nn/functional/__init__.py | 4 + python/infinicore/nn/functional/fold.py | 47 ++++++ python/infinicore/nn/functional/mish.py | 20 +++ python/infinicore/ops/bitwise_left_shift.py | 21 +++ python/infinicore/ops/index_select.py | 20 +++ python/infinicore/ops/log2.py | 21 +++ .../bitwise_left_shift/bitwise_left_shift.cc | 28 ++++ .../bitwise_left_shift_cpu.cc | 141 ++++++++++++++++ src/infinicore/ops/fold/fold.cc | 72 +++++++++ src/infinicore/ops/fold/fold_cpu.cc | 122 ++++++++++++++ .../ops/index_select/index_select.cc | 43 +++++ .../ops/index_select/index_select_cpu.cc | 153 ++++++++++++++++++ src/infinicore/ops/log2/log2.cc | 27 ++++ src/infinicore/ops/log2/log2_cpu.cc | 69 ++++++++ src/infinicore/ops/mish/mish.cc | 29 ++++ src/infinicore/ops/mish/mish_cpu.cc | 104 ++++++++++++ src/infinicore/pybind11/ops.hpp | 10 ++ .../pybind11/ops/bitwise_left_shift.hpp | 26 +++ src/infinicore/pybind11/ops/fold.hpp | 23 +++ src/infinicore/pybind11/ops/index_select.hpp | 28 ++++ src/infinicore/pybind11/ops/log2.hpp | 24 +++ src/infinicore/pybind11/ops/mish.hpp | 20 +++ test/infinicore/ops/bitwise_left_shift.py | 6 +- test/infinicore/ops/fold.py | 6 +- test/infinicore/ops/index_select.py | 6 +- test/infinicore/ops/log2.py | 6 +- test/infinicore/ops/mish.py | 6 +- 33 files changed, 1153 insertions(+), 15 deletions(-) create mode 100644 include/infinicore/ops/bitwise_left_shift.hpp create mode 100644 include/infinicore/ops/fold.hpp create mode 100644 include/infinicore/ops/index_select.hpp create mode 100644 include/infinicore/ops/log2.hpp create mode 100644 include/infinicore/ops/mish.hpp create mode 100644 python/infinicore/nn/functional/fold.py create mode 100644 python/infinicore/nn/functional/mish.py create mode 100644 python/infinicore/ops/bitwise_left_shift.py create mode 100644 python/infinicore/ops/index_select.py create mode 100644 python/infinicore/ops/log2.py create mode 100644 src/infinicore/ops/bitwise_left_shift/bitwise_left_shift.cc create mode 100644 src/infinicore/ops/bitwise_left_shift/bitwise_left_shift_cpu.cc create mode 100644 src/infinicore/ops/fold/fold.cc create mode 100644 src/infinicore/ops/fold/fold_cpu.cc create mode 100644 src/infinicore/ops/index_select/index_select.cc create mode 100644 src/infinicore/ops/index_select/index_select_cpu.cc create mode 100644 src/infinicore/ops/log2/log2.cc create mode 100644 src/infinicore/ops/log2/log2_cpu.cc create mode 100644 src/infinicore/ops/mish/mish.cc create mode 100644 src/infinicore/ops/mish/mish_cpu.cc create mode 100644 src/infinicore/pybind11/ops/bitwise_left_shift.hpp create mode 100644 src/infinicore/pybind11/ops/fold.hpp create mode 100644 src/infinicore/pybind11/ops/index_select.hpp create mode 100644 src/infinicore/pybind11/ops/log2.hpp create mode 100644 src/infinicore/pybind11/ops/mish.hpp diff --git a/include/infinicore/ops/bitwise_left_shift.hpp b/include/infinicore/ops/bitwise_left_shift.hpp new file mode 100644 index 000000000..a62b7bfa6 --- /dev/null +++ b/include/infinicore/ops/bitwise_left_shift.hpp @@ -0,0 +1,16 @@ +#pragma once + +#include "../device.hpp" +#include "common/op.hpp" + +namespace infinicore::op { +class BitwiseLeftShift { +public: + using schema = void (*)(Tensor, Tensor, Tensor); + static void execute(Tensor c, Tensor a, Tensor b); + static common::OpDispatcher &dispatcher(); +}; + +Tensor bitwise_left_shift(Tensor a, Tensor b); +void bitwise_left_shift_(Tensor c, Tensor a, Tensor b); +} // namespace infinicore::op diff --git a/include/infinicore/ops/fold.hpp b/include/infinicore/ops/fold.hpp new file mode 100644 index 000000000..7aebaa334 --- /dev/null +++ b/include/infinicore/ops/fold.hpp @@ -0,0 +1,18 @@ +#pragma once + +#include "../device.hpp" +#include "common/op.hpp" +#include + +namespace infinicore::op { +class Fold { +public: + using schema = void (*)(Tensor, Tensor, std::tuple, std::tuple, std::tuple, std::tuple, std::tuple); + // Pytorch 文档目前说明了只支持 (N, C, H, W) 和 (C, H, W) 格式的输入输出 + static void execute(Tensor output, Tensor input, std::tuple output_size, std::tuple kernel_size, std::tuple dilation, std::tuple padding, std::tuple stride); + static common::OpDispatcher &dispatcher(); +}; + +Tensor fold(Tensor input, std::tuple output_size, std::tuple kernel_size, std::tuple dilation, std::tuple padding, std::tuple stride); +void fold_(Tensor output, Tensor input, std::tuple output_size, std::tuple kernel_size, std::tuple dilation, std::tuple padding, std::tuple stride); +} // namespace infinicore::op diff --git a/include/infinicore/ops/index_select.hpp b/include/infinicore/ops/index_select.hpp new file mode 100644 index 000000000..a1faccd81 --- /dev/null +++ b/include/infinicore/ops/index_select.hpp @@ -0,0 +1,16 @@ +#pragma once + +#include "../device.hpp" +#include "common/op.hpp" + +namespace infinicore::op { +class IndexSelect { +public: + using schema = void (*)(Tensor, Tensor, int, Tensor); + static void execute(Tensor output, Tensor input, int dim, Tensor index); + static common::OpDispatcher &dispatcher(); +}; + +Tensor index_select(Tensor input, int dim, Tensor index); +void index_select_(Tensor output, Tensor input, int dim, Tensor index); +} // namespace infinicore::op diff --git a/include/infinicore/ops/log2.hpp b/include/infinicore/ops/log2.hpp new file mode 100644 index 000000000..bfc51fd4d --- /dev/null +++ b/include/infinicore/ops/log2.hpp @@ -0,0 +1,16 @@ +#pragma once + +#include "../device.hpp" +#include "common/op.hpp" + +namespace infinicore::op { +class Log2 { +public: + using schema = void (*)(Tensor, Tensor); + static void execute(Tensor output, Tensor input); + static common::OpDispatcher &dispatcher(); +}; + +Tensor log2(Tensor input); +void log2_(Tensor output, Tensor input); +} // namespace infinicore::op \ No newline at end of file diff --git a/include/infinicore/ops/mish.hpp b/include/infinicore/ops/mish.hpp new file mode 100644 index 000000000..814ebc53e --- /dev/null +++ b/include/infinicore/ops/mish.hpp @@ -0,0 +1,15 @@ +#pragma once + +#include "../device.hpp" +#include "common/op.hpp" + +namespace infinicore::op { +class Mish { +public: + using schema = void (*)(Tensor, Tensor, bool); + static void execute(Tensor output, Tensor input, bool inplace); + static common::OpDispatcher &dispatcher(); +}; + +Tensor mish(Tensor input, bool inplace); +} // namespace infinicore::op diff --git a/python/infinicore/__init__.py b/python/infinicore/__init__.py index 5c541ec3c..ed601a438 100644 --- a/python/infinicore/__init__.py +++ b/python/infinicore/__init__.py @@ -45,6 +45,9 @@ from infinicore.ops.mul import mul from infinicore.ops.narrow import narrow from infinicore.ops.rearrange import rearrange +from infinicore.ops.bitwise_left_shift import bitwise_left_shift +from infinicore.ops.index_select import index_select +from infinicore.ops.log2 import log2 from infinicore.tensor import ( Tensor, empty, @@ -115,6 +118,8 @@ "strided_empty", "strided_from_blob", "zeros", + "bitwise_left_shift", + "index_select", ] use_ntops = False diff --git a/python/infinicore/nn/functional/__init__.py b/python/infinicore/nn/functional/__init__.py index 255079790..b3bc38d89 100644 --- a/python/infinicore/nn/functional/__init__.py +++ b/python/infinicore/nn/functional/__init__.py @@ -6,6 +6,8 @@ from .rope import RopeAlgo, rope from .silu import silu from .swiglu import swiglu +from .fold import fold +from .mish import mish __all__ = [ "causal_softmax", @@ -17,4 +19,6 @@ "embedding", "rope", "RopeAlgo", + "fold", + "mish", ] diff --git a/python/infinicore/nn/functional/fold.py b/python/infinicore/nn/functional/fold.py new file mode 100644 index 000000000..94aa04c66 --- /dev/null +++ b/python/infinicore/nn/functional/fold.py @@ -0,0 +1,47 @@ +import infinicore +from infinicore.lib import _infinicore +from infinicore.tensor import Tensor + + +def fold( + input: Tensor, + output_size: int | tuple[int, int], + kernel_size: int | tuple[int, int], + dilation: int | tuple[int, int] = 1, + padding: int | tuple[int, int] = 0, + stride: int | tuple[int, int] = 1, +) -> Tensor: + r"""Combines an array of sliding local blocks into a large containing tensor. Currently, only unbatched (3D) or batched (4D) image-like output tensors are supported.""" + + if isinstance(output_size, int): + output_size = (output_size, output_size) + + if isinstance(kernel_size, int): + kernel_size = (kernel_size, kernel_size) + + if isinstance(dilation, int): + dilation = (dilation, dilation) + + if isinstance(padding, int): + padding = (padding, padding) + + if isinstance(stride, int): + stride = (stride, stride) + + assert input.ndim in (3, 4), "only 3D or 4D input tensors are supported" + assert len(output_size) == 2, "output_size must be a tuple of two integers (H, W)" + assert len(kernel_size) == 2, "kernel_size must be a tuple of two integers (kH, kW)" + assert len(dilation) == 2, "dilation must be a tuple of two integers (dH, dW)" + assert len(padding) == 2, "padding must be a tuple of two integers (pH, pW)" + assert len(stride) == 2, "stride must be a tuple of two integers (sH, sW)" + + if infinicore.use_ntops and input.device.type in ("cuda", "musa"): + return infinicore.ntops.torch.fold( + input, output_size, kernel_size, dilation, padding, stride + ) + + return Tensor( + _infinicore.fold( + input._underlying, output_size, kernel_size, dilation, padding, stride + ) + ) diff --git a/python/infinicore/nn/functional/mish.py b/python/infinicore/nn/functional/mish.py new file mode 100644 index 000000000..87f55125f --- /dev/null +++ b/python/infinicore/nn/functional/mish.py @@ -0,0 +1,20 @@ +import infinicore +from infinicore.lib import _infinicore +from infinicore.tensor import Tensor + + +def mish( + input: Tensor, inplace: bool = False +) -> Tensor: + r"""Applies the Mish activation function element-wise: mish(x) = x * tanh(softplus(x)).""" + + if infinicore.use_ntops and input.device.type in ("cuda", "musa"): + return infinicore.ntops.torch.mish(input, inplace) + + if inplace: + _infinicore.mish(input._underlying, inplace) + return input + else: + return Tensor( + _infinicore.mish(input._underlying, inplace) + ) \ No newline at end of file diff --git a/python/infinicore/ops/bitwise_left_shift.py b/python/infinicore/ops/bitwise_left_shift.py new file mode 100644 index 000000000..38aff1ef9 --- /dev/null +++ b/python/infinicore/ops/bitwise_left_shift.py @@ -0,0 +1,21 @@ +import infinicore +from infinicore.lib import _infinicore +from infinicore.tensor import Tensor + + +def bitwise_left_shift(input: Tensor, other: Tensor, *, out=None) -> Tensor: + r"""Computes the left arithmetic shift of input by other bits. The input tensor must be of integral type.""" + + if infinicore.use_ntops and input.device.type in ("cuda", "musa"): + return infinicore.ntops.torch.bitwise_left_shift(input, other, out=out) + + if out is None: + return Tensor( + _infinicore.bitwise_left_shift(input._underlying, other._underlying) + ) + + _infinicore.bitwise_left_shift_( + out._underlying, input._underlying, other._underlying + ) + + return out diff --git a/python/infinicore/ops/index_select.py b/python/infinicore/ops/index_select.py new file mode 100644 index 000000000..c1b9580a4 --- /dev/null +++ b/python/infinicore/ops/index_select.py @@ -0,0 +1,20 @@ +import infinicore +from infinicore.lib import _infinicore +from infinicore.tensor import Tensor + + +def index_select(input: Tensor, dim: int, index: Tensor, *, out=None) -> Tensor: + r"""Selects elements from input along a specific dimension.""" + + if infinicore.use_ntops and input.device.type in ("cuda", "musa"): + return infinicore.ntops.torch.index_select(input, dim, index, out=out) + if out is None: + return Tensor( + _infinicore.index_select(input._underlying, dim, index._underlying) + ) + + _infinicore.index_select_( + out._underlying, input._underlying, dim, index._underlying + ) + + return out diff --git a/python/infinicore/ops/log2.py b/python/infinicore/ops/log2.py new file mode 100644 index 000000000..70f1cedfd --- /dev/null +++ b/python/infinicore/ops/log2.py @@ -0,0 +1,21 @@ +import infinicore +from infinicore.lib import _infinicore +from infinicore.tensor import Tensor + + +def log2(input: Tensor, *, out=None) -> Tensor: + r"""Computes the base-2 logarithm of the input tensor element-wise.""" + + if infinicore.use_ntops and input.device.type in ("cuda", "musa"): + return infinicore.ntops.torch.log2(input, out=out) + + if out is None: + return Tensor( + _infinicore.log2(input._underlying) + ) + + _infinicore.log2_( + out._underlying, input._underlying + ) + + return out diff --git a/src/infinicore/ops/bitwise_left_shift/bitwise_left_shift.cc b/src/infinicore/ops/bitwise_left_shift/bitwise_left_shift.cc new file mode 100644 index 000000000..edb7cee6a --- /dev/null +++ b/src/infinicore/ops/bitwise_left_shift/bitwise_left_shift.cc @@ -0,0 +1,28 @@ +#include "infinicore/ops/bitwise_left_shift.hpp" + +#include "../../utils.hpp" + +namespace infinicore::op { + +common::OpDispatcher &BitwiseLeftShift::dispatcher() { + static common::OpDispatcher dispatcher_; + return dispatcher_; +}; + +void BitwiseLeftShift::execute(Tensor c, Tensor a, Tensor b) { + INFINICORE_ASSERT_TENSORS_SAME_DEVICE(c, a, b); + infinicore::context::setDevice(c->device(), true); + dispatcher().lookup(c->device().getType())(c, a, b); +} + +Tensor bitwise_left_shift(Tensor a, Tensor b) { + auto c = Tensor::empty(a->shape(), a->dtype(), a->device()); + bitwise_left_shift_(c, a, b); + return c; +} + +void bitwise_left_shift_(Tensor c, Tensor a, Tensor b) { + BitwiseLeftShift::execute(c, a, b); +} + +} // namespace infinicore::op diff --git a/src/infinicore/ops/bitwise_left_shift/bitwise_left_shift_cpu.cc b/src/infinicore/ops/bitwise_left_shift/bitwise_left_shift_cpu.cc new file mode 100644 index 000000000..c11bd8e3b --- /dev/null +++ b/src/infinicore/ops/bitwise_left_shift/bitwise_left_shift_cpu.cc @@ -0,0 +1,141 @@ +#include "../../../utils.h" + +#include "infinicore/device.hpp" +#include "infinicore/ops/bitwise_left_shift.hpp" +#include "infinicore/tensor.hpp" +#include +#include +#include + +namespace infinicore::op::bitwise_left_shift_impl::cpu { + +void calculate(Tensor c, Tensor a, Tensor b) { + auto a_shapes = a->shape(); + auto b_shapes = b->shape(); + auto c_shapes = c->shape(); + auto a_strides = a->strides(); + auto b_strides = b->strides(); + auto c_strides = c->strides(); + auto dtype = a->dtype(); + auto dtype_size = a->element_size(); + + auto a_base = a->data(); + auto b_base = b->data(); + auto c_base = c->data(); + + size_t c_numel = c->numel(); + + // 处理广播:逐元素操作 +#pragma omp parallel for + for (size_t c_idx = 0; c_idx < c_numel; ++c_idx) { + // 计算输出张量的多维索引 + std::vector c_indices(c_shapes.size()); + size_t temp_idx = c_idx; + for (int i = static_cast(c_shapes.size()) - 1; i >= 0; --i) { + c_indices[i] = temp_idx % c_shapes[i]; + temp_idx /= c_shapes[i]; + } + + // 计算输入张量 a 和 b 的偏移(考虑广播) + size_t a_offset = 0; + size_t b_offset = 0; + + // 处理维度差异(广播) + int a_dim_offset = c_shapes.size() - a_shapes.size(); + int b_dim_offset = c_shapes.size() - b_shapes.size(); + + for (int i = 0; i < static_cast(c_shapes.size()); ++i) { + // 计算 a 的偏移 + if (i >= a_dim_offset) { + int a_idx = i - a_dim_offset; + if (a_shapes[a_idx] > 1) { + a_offset += c_indices[i] * a_strides[a_idx]; + } + } + + // 计算 b 的偏移 + if (i >= b_dim_offset) { + int b_idx = i - b_dim_offset; + if (b_shapes[b_idx] > 1) { + b_offset += c_indices[i] * b_strides[b_idx]; + } + } + } + + // 获取位移量的数据类型和大小 + auto b_dtype = b->dtype(); + auto b_dtype_size = b->element_size(); + + // 读取位移量(转换为 int) + int shift_amount = 0; + if (b_dtype == DataType::I8) { + shift_amount = static_cast(*reinterpret_cast(b_base + b_offset * b_dtype_size)); + } else if (b_dtype == DataType::I16) { + shift_amount = static_cast(*reinterpret_cast(b_base + b_offset * b_dtype_size)); + } else if (b_dtype == DataType::I32) { + shift_amount = static_cast(*reinterpret_cast(b_base + b_offset * b_dtype_size)); + } else if (b_dtype == DataType::I64) { + shift_amount = static_cast(*reinterpret_cast(b_base + b_offset * b_dtype_size)); + } else if (b_dtype == DataType::U8) { + shift_amount = static_cast(*reinterpret_cast(b_base + b_offset * b_dtype_size)); + } else if (b_dtype == DataType::U16) { + shift_amount = static_cast(*reinterpret_cast(b_base + b_offset * b_dtype_size)); + } else if (b_dtype == DataType::U32) { + shift_amount = static_cast(*reinterpret_cast(b_base + b_offset * b_dtype_size)); + } else if (b_dtype == DataType::U64) { + shift_amount = static_cast(*reinterpret_cast(b_base + b_offset * b_dtype_size)); + } else { + throw std::runtime_error("Unsupported shift amount data type for bitwise_left_shift operation."); + } + + // 计算 c 的偏移(考虑非连续内存布局) + size_t c_offset = 0; + for (int i = 0; i < static_cast(c_shapes.size()); ++i) { + c_offset += c_indices[i] * c_strides[i]; + } + + // 根据数据类型执行按位左移 + if (dtype == DataType::I8) { + auto *a_ptr = reinterpret_cast(a_base + a_offset * dtype_size); + auto *c_ptr = reinterpret_cast(c_base + c_offset * dtype_size); + *c_ptr = (shift_amount >= 0 && shift_amount < 8) ? (*a_ptr << shift_amount) : 0; + } else if (dtype == DataType::I16) { + auto *a_ptr = reinterpret_cast(a_base + a_offset * dtype_size); + auto *c_ptr = reinterpret_cast(c_base + c_offset * dtype_size); + *c_ptr = (shift_amount >= 0 && shift_amount < 16) ? (*a_ptr << shift_amount) : 0; + } else if (dtype == DataType::I32) { + auto *a_ptr = reinterpret_cast(a_base + a_offset * dtype_size); + auto *c_ptr = reinterpret_cast(c_base + c_offset * dtype_size); + *c_ptr = (shift_amount >= 0 && shift_amount < 32) ? (*a_ptr << shift_amount) : 0; + } else if (dtype == DataType::I64) { + auto *a_ptr = reinterpret_cast(a_base + a_offset * dtype_size); + auto *c_ptr = reinterpret_cast(c_base + c_offset * dtype_size); + *c_ptr = (shift_amount >= 0 && shift_amount < 64) ? (*a_ptr << shift_amount) : 0; + } else if (dtype == DataType::U8) { + auto *a_ptr = reinterpret_cast(a_base + a_offset * dtype_size); + auto *c_ptr = reinterpret_cast(c_base + c_offset * dtype_size); + *c_ptr = (shift_amount >= 0 && shift_amount < 8) ? (*a_ptr << shift_amount) : 0; + } else if (dtype == DataType::U16) { + auto *a_ptr = reinterpret_cast(a_base + a_offset * dtype_size); + auto *c_ptr = reinterpret_cast(c_base + c_offset * dtype_size); + *c_ptr = (shift_amount >= 0 && shift_amount < 16) ? (*a_ptr << shift_amount) : 0; + } else if (dtype == DataType::U32) { + auto *a_ptr = reinterpret_cast(a_base + a_offset * dtype_size); + auto *c_ptr = reinterpret_cast(c_base + c_offset * dtype_size); + *c_ptr = (shift_amount >= 0 && shift_amount < 32) ? (*a_ptr << shift_amount) : 0; + } else if (dtype == DataType::U64) { + auto *a_ptr = reinterpret_cast(a_base + a_offset * dtype_size); + auto *c_ptr = reinterpret_cast(c_base + c_offset * dtype_size); + *c_ptr = (shift_amount >= 0 && shift_amount < 64) ? (*a_ptr << shift_amount) : 0; + } else { + throw std::runtime_error("Unsupported data type for bitwise_left_shift operation."); + } + } +} + +static bool registered = []() { + BitwiseLeftShift::dispatcher().registerDevice(Device::Type::CPU, &calculate); + return true; +}(); + +} // namespace infinicore::op::bitwise_left_shift_impl::cpu diff --git a/src/infinicore/ops/fold/fold.cc b/src/infinicore/ops/fold/fold.cc new file mode 100644 index 000000000..c2a5e7878 --- /dev/null +++ b/src/infinicore/ops/fold/fold.cc @@ -0,0 +1,72 @@ +#include "infinicore/ops/fold.hpp" +#include +#include + +namespace infinicore::op { + +common::OpDispatcher &Fold::dispatcher() { + static common::OpDispatcher dispatcher_; + return dispatcher_; +}; + +void Fold::execute(Tensor output, Tensor input, std::tuple output_size, std::tuple kernel_size, std::tuple dilation, std::tuple padding, std::tuple stride) { + infinicore::context::setDevice(input->device(), true); + auto device_type = context::getDevice().getType(); + auto func = dispatcher().lookup(device_type); + + if (func == nullptr) { + throw std::runtime_error("No Fold implementation found for device type: " + std::to_string(static_cast(device_type))); + } + + func(output, input, output_size, kernel_size, dilation, padding, stride); +} + +Tensor fold(Tensor input, std::tuple output_size, std::tuple kernel_size, std::tuple dilation, std::tuple padding, std::tuple stride) { + const auto ndim = input->ndim(); + auto input_shape = input->shape(); + + if (ndim != 3 && ndim != 2) { + throw std::runtime_error("Input tensor must be 3-dimensional (N, C * K_h * K_w, L) or 2-dimensional (C * K_h * K_w, L)"); + } + + // Normalize input to shape [N, C*K_h*K_w, L] + if (ndim == 2) { + // input: [C*K_h*K_w, L] -> [1, C*K_h*K_w, L] + input = input->view({1, input_shape[0], input_shape[1]}); + input_shape = input->shape(); + } // if ndim==3, assume already [N, C*K*K, L] + + // input: [N, C * K_h * K_w, L] + const auto [Kernel_H, Kernel_W] = kernel_size; + const auto [Output_H, Output_W] = output_size; + const auto [Dilation_H, Dilation_W] = dilation; + const auto [Padding_H, Padding_W] = padding; + const auto [Stride_H, Stride_W] = stride; + const auto C = input_shape[1] / (Kernel_H * Kernel_W); + + if (C * Kernel_H * Kernel_W != input_shape[1]) { + throw std::runtime_error("Input channel dimension is not divisible by kernel size product"); + } + // Validate input L equals computed number of sliding positions + const auto L = input_shape[2]; + const auto L_h = (Output_H + 2 * Padding_H >= Dilation_H * (Kernel_H - 1) + 1) + ? (static_cast(std::floor((static_cast(Output_H) + 2.0 * Padding_H - static_cast(Dilation_H) * (Kernel_H - 1) - 1) / Stride_H)) + 1) + : 0; + const auto L_w = (Output_W + 2 * Padding_W >= Dilation_W * (Kernel_W - 1) + 1) + ? (static_cast(std::floor((static_cast(Output_W) + 2.0 * Padding_W - static_cast(Dilation_W) * (Kernel_W - 1) - 1) / Stride_W)) + 1) + : 0; + if (L != L_h * L_w) { + throw std::runtime_error("Input L does not match computed sliding window count"); + } + + auto output_shape = Shape{input_shape[0], C, Output_H, Output_W}; + + auto output = Tensor::empty(output_shape, input->dtype(), input->device()); + fold_(output, input, output_size, kernel_size, dilation, padding, stride); + return output; +} + +void fold_(Tensor output, Tensor input, std::tuple output_size, std::tuple kernel_size, std::tuple dilation, std::tuple padding, std::tuple stride) { + Fold::execute(output, input, output_size, kernel_size, dilation, padding, stride); +} +} // namespace infinicore::op diff --git a/src/infinicore/ops/fold/fold_cpu.cc b/src/infinicore/ops/fold/fold_cpu.cc new file mode 100644 index 000000000..1ad4add97 --- /dev/null +++ b/src/infinicore/ops/fold/fold_cpu.cc @@ -0,0 +1,122 @@ +#include "../../../utils.h" +#include "infinicore/device.hpp" +#include "infinicore/ops/fold.hpp" +#include "infinicore/tensor.hpp" +#include +#include +#include +#include +#include + +namespace infinicore::op::fold_impl::cpu { + +void calculate(Tensor output, Tensor input, std::tuple output_size, std::tuple kernel_size, std::tuple dilation, std::tuple padding, std::tuple stride) { + // input: [N, C * K_h * K_w, L], output: [N, C, H_out, W_out] + // L = floor((H_out + 2*pad_h - dil_h*(K_h-1) - 1)/stride_h + 1) * + // floor((W_out + 2*pad_w - dil_w*(K_w-1) - 1)/stride_w + 1) + + auto in_shape = input->shape(); + auto in_strides = input->strides(); + auto out_shape = output->shape(); + auto out_strides = output->strides(); + auto dtype = input->dtype(); + + const size_t N = in_shape[0]; + const size_t Ckk = in_shape[1]; // C * K_h * K_w + const size_t L = in_shape[2]; // number of sliding positions + + const size_t C = out_shape[1]; + const size_t H_out = out_shape[2]; + const size_t W_out = out_shape[3]; + + const size_t strideN_in = in_strides[0]; + const size_t strideC_in = in_strides[1]; + const size_t strideL_in = in_strides[2]; + + const size_t strideN_out = out_strides[0]; + const size_t strideC_out = out_strides[1]; + const size_t strideH_out = out_strides[2]; + const size_t strideW_out = out_strides[3]; + + const size_t K_h = std::get<0>(kernel_size); + const size_t K_w = std::get<1>(kernel_size); + const size_t S_h = std::get<0>(stride); + const size_t S_w = std::get<1>(stride); + const size_t D_h = std::get<0>(dilation); + const size_t D_w = std::get<1>(dilation); + const size_t P_h = std::get<0>(padding); + const size_t P_w = std::get<1>(padding); + + // Basic sanity check + if (C * K_h * K_w != Ckk) { + throw std::runtime_error("Input channel dimension is not divisible by kernel size product"); + } + + auto in_base = input->data(); + auto out_base = output->data(); + const auto elem_size = input->element_size(); + + // Compute L_h and L_w (number of sliding positions per dimension) + const size_t L_h = (H_out + 2 * P_h >= D_h * (K_h - 1) + 1) + ? (static_cast(std::floor((static_cast(H_out) + 2.0 * P_h - static_cast(D_h) * (K_h - 1) - 1) / S_h)) + 1) + : 0; + const size_t L_w = (W_out + 2 * P_w >= D_w * (K_w - 1) + 1) + ? (static_cast(std::floor((static_cast(W_out) + 2.0 * P_w - static_cast(D_w) * (K_w - 1) - 1) / S_w)) + 1) + : 0; + if (L != L_h * L_w) { + throw std::runtime_error("Input L does not match computed sliding window count"); + } + + // Zero-initialize output (accumulate) + std::memset(out_base, 0, out_shape[0] * out_shape[1] * out_shape[2] * out_shape[3] * elem_size); + + for (size_t n = 0; n < N; ++n) { + for (size_t c = 0; c < C; ++c) { + for (size_t oh = 0; oh < L_h; ++oh) { + for (size_t ow = 0; ow < L_w; ++ow) { + const size_t l_idx = oh * L_w + ow; + + for (size_t kh = 0; kh < K_h; ++kh) { + const long y = static_cast(oh) * static_cast(S_h) - static_cast(P_h) + static_cast(kh) * static_cast(D_h); + if (y < 0 || y >= static_cast(H_out)) continue; + + for (size_t kw = 0; kw < K_w; ++kw) { + const long x = static_cast(ow) * static_cast(S_w) - static_cast(P_w) + static_cast(kw) * static_cast(D_w); + if (x < 0 || x >= static_cast(W_out)) continue; + + const size_t ckk = c * (K_h * K_w) + kh * K_w + kw; + + const size_t in_offset = n * strideN_in + ckk * strideC_in + l_idx * strideL_in; + const size_t out_offset = n * strideN_out + c * strideC_out + static_cast(y) * strideH_out + static_cast(x) * strideW_out; + + if (dtype == DataType::F32) { + auto *in_ptr = reinterpret_cast(in_base + in_offset * elem_size); + auto *out_ptr = reinterpret_cast(out_base + out_offset * elem_size); + *out_ptr += *in_ptr; + } else if (dtype == DataType::F64) { + auto *in_ptr = reinterpret_cast(in_base + in_offset * elem_size); + auto *out_ptr = reinterpret_cast(out_base + out_offset * elem_size); + *out_ptr += *in_ptr; + } else if (dtype == DataType::F16) { + auto *in_ptr = reinterpret_cast(in_base + in_offset * elem_size); + auto *out_ptr = reinterpret_cast(out_base + out_offset * elem_size); + float acc = utils::cast(*out_ptr); + acc += utils::cast(*in_ptr); + *out_ptr = utils::cast(acc); + } else { + throw std::runtime_error("Unsupported dtype for fold CPU"); + } + } + } + } + } + } + } +} + +static bool registered = []() { + Fold::dispatcher().registerDevice(Device::Type::CPU, &calculate); + return true; +}(); + +} // namespace infinicore::op::fold_impl::cpu diff --git a/src/infinicore/ops/index_select/index_select.cc b/src/infinicore/ops/index_select/index_select.cc new file mode 100644 index 000000000..a24ca1f53 --- /dev/null +++ b/src/infinicore/ops/index_select/index_select.cc @@ -0,0 +1,43 @@ +#include "infinicore/ops/index_select.hpp" + +#include "../../utils.hpp" + +namespace infinicore::op { + +common::OpDispatcher &IndexSelect::dispatcher() { + static common::OpDispatcher dispatcher_; + return dispatcher_; +}; + +void IndexSelect::execute(Tensor output, Tensor input, int dim, Tensor index) { + INFINICORE_ASSERT_TENSORS_SAME_DEVICE(output, input, index); + infinicore::context::setDevice(output->device(), true); + dispatcher().lookup(output->device().getType())(output, input, dim, index); +} + +Tensor index_select(Tensor input, int dim, Tensor index) { + if (index->ndim() != 1) { + throw std::runtime_error("Index tensor must be 1-dimensional for index_select operation."); + } + + if (dim < 0) { + dim = dim + input->ndim(); + } + + if (dim < 0 || dim >= static_cast(input->ndim())) { + throw std::runtime_error("Dimension out of range for index_select operation."); + } + + auto output_shape = input->shape(); + output_shape[dim] = index->shape()[0]; + + auto output = Tensor::empty(output_shape, input->dtype(), input->device()); + index_select_(output, input, dim, index); + return output; +} + +void index_select_(Tensor output, Tensor input, int dim, Tensor index) { + IndexSelect::execute(output, input, dim, index); +} + +} // namespace infinicore::op diff --git a/src/infinicore/ops/index_select/index_select_cpu.cc b/src/infinicore/ops/index_select/index_select_cpu.cc new file mode 100644 index 000000000..2891a3219 --- /dev/null +++ b/src/infinicore/ops/index_select/index_select_cpu.cc @@ -0,0 +1,153 @@ +#include "../../../utils.h" + +#include "infinicore/device.hpp" +#include "infinicore/ops/index_select.hpp" +#include "infinicore/tensor.hpp" +#include +#include +#include + +namespace infinicore::op::index_select_impl::cpu { + +void calculate(Tensor output, Tensor input, int dim, Tensor index) { + auto input_shapes = input->shape(); + auto index_shapes = index->shape(); + auto output_shapes = output->shape(); + auto input_strides = input->strides(); + auto index_strides = index->strides(); + auto output_strides = output->strides(); + auto dtype = input->dtype(); + auto dtype_size = input->element_size(); + + auto input_base = input->data(); + auto index_base = index->data(); + auto output_base = output->data(); + + size_t output_numel = output->numel(); + auto ndim = input->ndim(); + + // 规范化 dim 到 [0, ndim) + if (dim < 0) { + dim = ndim + dim; + } + + // 获取索引张量的数据类型 + auto index_dtype = index->dtype(); + size_t index_numel = index->numel(); + + // 并行遍历输出张量的每个元素 +#pragma omp parallel for + for (size_t output_idx = 0; output_idx < output_numel; ++output_idx) { + // 计算输出张量的多维索引 + std::vector output_indices(ndim); + size_t temp_idx = output_idx; + for (int i = static_cast(ndim) - 1; i >= 0; --i) { + output_indices[i] = temp_idx % output_shapes[i]; + temp_idx /= output_shapes[i]; + } + + // 构造输入张量的多维索引 + // 对于 dim 维度,需要从 index 张量中读取实际索引值 + std::vector input_indices(ndim); + for (int i = 0; i < static_cast(ndim); ++i) { + if (i == dim) { + // 从 index 张量读取索引值 + size_t index_offset = output_indices[i]; + int64_t selected_index = 0; + + if (index_dtype == DataType::I32) { + selected_index = static_cast(*reinterpret_cast(index_base + index_offset * sizeof(int32_t))); + } else if (index_dtype == DataType::I64) { + selected_index = *reinterpret_cast(index_base + index_offset * sizeof(int64_t)); + } else if (index_dtype == DataType::I8) { + selected_index = static_cast(*reinterpret_cast(index_base + index_offset * sizeof(int8_t))); + } else if (index_dtype == DataType::I16) { + selected_index = static_cast(*reinterpret_cast(index_base + index_offset * sizeof(int16_t))); + } else { + throw std::runtime_error("Unsupported index data type for index_select operation."); + } + + // 处理负索引 + if (selected_index < 0) { + selected_index += input_shapes[dim]; + } + + input_indices[i] = static_cast(selected_index); + } else { + input_indices[i] = output_indices[i]; + } + } + + // 计算输入张量的偏移 + size_t input_offset = 0; + for (int i = 0; i < static_cast(ndim); ++i) { + input_offset += input_indices[i] * input_strides[i]; + } + + // 计算输出张量的偏移 + size_t output_offset = 0; + for (int i = 0; i < static_cast(ndim); ++i) { + output_offset += output_indices[i] * output_strides[i]; + } + + // 根据数据类型复制数据 + if (dtype == DataType::F32) { + auto *input_ptr = reinterpret_cast(input_base + input_offset * dtype_size); + auto *output_ptr = reinterpret_cast(output_base + output_offset * dtype_size); + *output_ptr = *input_ptr; + } else if (dtype == DataType::F16) { + auto *input_ptr = reinterpret_cast(input_base + input_offset * dtype_size); + auto *output_ptr = reinterpret_cast(output_base + output_offset * dtype_size); + *output_ptr = *input_ptr; + } else if (dtype == DataType::I32) { + auto *input_ptr = reinterpret_cast(input_base + input_offset * dtype_size); + auto *output_ptr = reinterpret_cast(output_base + output_offset * dtype_size); + *output_ptr = *input_ptr; + } else if (dtype == DataType::I64) { + auto *input_ptr = reinterpret_cast(input_base + input_offset * dtype_size); + auto *output_ptr = reinterpret_cast(output_base + output_offset * dtype_size); + *output_ptr = *input_ptr; + } else if (dtype == DataType::I8) { + auto *input_ptr = reinterpret_cast(input_base + input_offset * dtype_size); + auto *output_ptr = reinterpret_cast(output_base + output_offset * dtype_size); + *output_ptr = *input_ptr; + } else if (dtype == DataType::I16) { + auto *input_ptr = reinterpret_cast(input_base + input_offset * dtype_size); + auto *output_ptr = reinterpret_cast(output_base + output_offset * dtype_size); + *output_ptr = *input_ptr; + } else if (dtype == DataType::U8) { + auto *input_ptr = reinterpret_cast(input_base + input_offset * dtype_size); + auto *output_ptr = reinterpret_cast(output_base + output_offset * dtype_size); + *output_ptr = *input_ptr; + } else if (dtype == DataType::U16) { + auto *input_ptr = reinterpret_cast(input_base + input_offset * dtype_size); + auto *output_ptr = reinterpret_cast(output_base + output_offset * dtype_size); + *output_ptr = *input_ptr; + } else if (dtype == DataType::U32) { + auto *input_ptr = reinterpret_cast(input_base + input_offset * dtype_size); + auto *output_ptr = reinterpret_cast(output_base + output_offset * dtype_size); + *output_ptr = *input_ptr; + } else if (dtype == DataType::U64) { + auto *input_ptr = reinterpret_cast(input_base + input_offset * dtype_size); + auto *output_ptr = reinterpret_cast(output_base + output_offset * dtype_size); + *output_ptr = *input_ptr; + } else if (dtype == DataType::BF16) { + auto *input_ptr = reinterpret_cast(input_base + input_offset * dtype_size); + auto *output_ptr = reinterpret_cast(output_base + output_offset * dtype_size); + *output_ptr = *input_ptr; + } else if (dtype == DataType::BOOL) { + auto *input_ptr = reinterpret_cast(input_base + input_offset * dtype_size); + auto *output_ptr = reinterpret_cast(output_base + output_offset * dtype_size); + *output_ptr = *input_ptr; + } else { + throw std::runtime_error("Unsupported data type for index_select operation."); + } + } +} + +static bool registered = []() { + IndexSelect::dispatcher().registerDevice(Device::Type::CPU, &calculate); + return true; +}(); + +} // namespace infinicore::op::index_select_impl::cpu \ No newline at end of file diff --git a/src/infinicore/ops/log2/log2.cc b/src/infinicore/ops/log2/log2.cc new file mode 100644 index 000000000..ca3491cd0 --- /dev/null +++ b/src/infinicore/ops/log2/log2.cc @@ -0,0 +1,27 @@ +#include "infinicore/ops/log2.hpp" + +#include "../../utils.hpp" + +namespace infinicore::op { + +common::OpDispatcher &Log2::dispatcher() { + static common::OpDispatcher dispatcher_; + return dispatcher_; +}; + +void Log2::execute(Tensor output, Tensor input) { + infinicore::context::setDevice(output->device(), true); + dispatcher().lookup(output->device().getType())(output, input); +} + +Tensor log2(Tensor input) { + auto output = Tensor::empty(input->shape(), input->dtype(), input->device()); + log2_(output, input); + return output; +} + +void log2_(Tensor output, Tensor input) { + Log2::execute(output, input); +} + +} // namespace infinicore::op diff --git a/src/infinicore/ops/log2/log2_cpu.cc b/src/infinicore/ops/log2/log2_cpu.cc new file mode 100644 index 000000000..091af11b4 --- /dev/null +++ b/src/infinicore/ops/log2/log2_cpu.cc @@ -0,0 +1,69 @@ +#include "../../../utils.h" + +#include "infinicore/device.hpp" +#include "infinicore/ops/log2.hpp" +#include "infinicore/tensor.hpp" +#include +#include +#include + +namespace infinicore::op::log2_impl::infiniop { + +void calculate(Tensor output, Tensor input) { + auto strides = input->strides(); // vector + auto shapes = input->shape(); // vector + auto ndim = input->ndim(); + auto dtype = input->dtype(); + auto dtype_size = input->element_size(); + auto numel = input->numel(); + + auto input_base = input->data(); + auto output_base = output->data(); + + std::vector indices(ndim, 0); + for (size_t idx = 0; idx < numel; ++idx) { + // Calculate the offset for the current index + size_t offset = 0; + for (size_t dim = 0; dim < ndim; ++dim) { + offset += indices[dim] * strides[dim]; + } + + // Compute log2 for the current element + if (dtype == DataType::F32) { + auto *input_ptr = reinterpret_cast(input_base + offset * dtype_size); + auto *output_ptr = reinterpret_cast(output_base + offset * dtype_size); + *output_ptr = std::log2(*input_ptr); + } else if (dtype == DataType::F64) { + auto *input_ptr = reinterpret_cast(input_base + offset * dtype_size); + auto *output_ptr = reinterpret_cast(output_base + offset * dtype_size); + *output_ptr = std::log2(*input_ptr); + } else if (dtype == DataType::F16) { + // F16: 转换为 F32 计算,再转回 F16 + auto *input_ptr = reinterpret_cast(input_base + offset * dtype_size); + auto *output_ptr = reinterpret_cast(output_base + offset * dtype_size); + + float input_f32 = utils::cast(*input_ptr); + float output_f32 = std::log2(input_f32); + *output_ptr = utils::cast(output_f32); + } else { + throw std::runtime_error("Unsupported data type for log2 operation."); + } + + // Update indices + for (ssize_t dim = ndim - 1; dim >= 0; --dim) { + indices[dim]++; + if (indices[dim] < shapes[dim]) { + break; + } else { + indices[dim] = 0; + } + } + } +} + +static bool registered = []() { + Log2::dispatcher().registerDevice(Device::Type::CPU, &calculate); + return true; +}(); + +} // namespace infinicore::op::log2_impl::infiniop \ No newline at end of file diff --git a/src/infinicore/ops/mish/mish.cc b/src/infinicore/ops/mish/mish.cc new file mode 100644 index 000000000..8eb132712 --- /dev/null +++ b/src/infinicore/ops/mish/mish.cc @@ -0,0 +1,29 @@ +#include "infinicore/ops/mish.hpp" + +#include "../../utils.hpp" + +namespace infinicore::op { + +common::OpDispatcher &Mish::dispatcher() { + static common::OpDispatcher dispatcher_; + return dispatcher_; +}; + +void Mish::execute(Tensor output, Tensor input, bool inplace) { + INFINICORE_ASSERT_TENSORS_SAME_DEVICE(output, input); + infinicore::context::setDevice(output->device()); + dispatcher().lookup(output->device().getType())(output, input, inplace); +} + +Tensor mish(Tensor input, bool inplace) { + if(inplace) { + Mish::execute(input, input, inplace); + return input; + } + + auto output = Tensor::empty(input->shape(), input->dtype(), input->device()); + Mish::execute(output, input, inplace); + return output; +} + +} // namespace infinicore::op diff --git a/src/infinicore/ops/mish/mish_cpu.cc b/src/infinicore/ops/mish/mish_cpu.cc new file mode 100644 index 000000000..5eae10f4f --- /dev/null +++ b/src/infinicore/ops/mish/mish_cpu.cc @@ -0,0 +1,104 @@ +#include "../../../utils.h" + +#include "infinicore/device.hpp" +#include "infinicore/ops/mish.hpp" +#include "infinicore/tensor.hpp" +#include +#include +#include + +namespace infinicore::op::mish_impl::cpu { + +void calculate(Tensor output, Tensor input, bool inplace) { + auto input_shapes = input->shape(); + auto output_shapes = output->shape(); + auto input_strides = input->strides(); + auto output_strides = output->strides(); + auto dtype = input->dtype(); + auto dtype_size = input->element_size(); + auto in_base = input->data(); + auto out_base = output->data(); + + size_t out_numel = output->numel(); + + // 支持 F16 / BF16 / F32 / F64 + if (!(dtype == DataType::F16 || dtype == DataType::BF16 || dtype == DataType::F32 || dtype == DataType::F64)) { + throw std::runtime_error("Mish CPU supports only F16/BF16/F32/F64."); + } + + auto softplus = [](double x) { + // 稳定的 softplus:log1p(exp(-|x|)) + max(x, 0) + double ax = std::abs(x); + return std::log1p(std::exp(-ax)) + std::max(x, 0.0); + }; + +#pragma omp parallel for + for (size_t o_idx = 0; o_idx < out_numel; ++o_idx) { + // 计算输出张量的多维索引 + std::vector o_indices(output_shapes.size()); + size_t tmp = o_idx; + for (int i = static_cast(output_shapes.size()) - 1; i >= 0; --i) { + o_indices[i] = tmp % output_shapes[i]; + tmp /= output_shapes[i]; + } + + // 计算输入偏移(处理广播) + size_t in_offset = 0; + int in_dim_offset = static_cast(output_shapes.size()) - static_cast(input_shapes.size()); + for (int i = 0; i < static_cast(output_shapes.size()); ++i) { + if (i >= in_dim_offset) { + int in_i = i - in_dim_offset; + size_t idx = (input_shapes[in_i] > 1) ? o_indices[i] : 0; + in_offset += idx * input_strides[in_i]; + } + } + + // 输出偏移(支持非连续) + size_t out_offset = 0; + for (int i = 0; i < static_cast(output_shapes.size()); ++i) { + out_offset += o_indices[i] * output_strides[i]; + } + + // 读取输入值为 double(根据 dtype 做转换) + auto in_ptr = in_base + in_offset * dtype_size; + double x; + if (dtype == DataType::F32) { + x = static_cast(*reinterpret_cast(in_ptr)); + } else if (dtype == DataType::F64) { + x = *reinterpret_cast(in_ptr); + } else if (dtype == DataType::F16) { + auto *hptr = reinterpret_cast(in_ptr); + float xf = utils::cast(*hptr); + x = static_cast(xf); + } else { // BF16 + auto *bptr = reinterpret_cast(in_ptr); + float xf = utils::cast(*bptr); + x = static_cast(xf); + } + + // mish: x * tanh(softplus(x)) + double sp = softplus(x); + double y = x * std::tanh(sp); + + // 写回输出,保持 dtype(使用 utils::cast 简化 F16/BF16) + auto out_ptr = out_base + out_offset * dtype_size; + if (dtype == DataType::F32) { + *reinterpret_cast(out_ptr) = static_cast(y); + } else if (dtype == DataType::F64) { + *reinterpret_cast(out_ptr) = y; + } else if (dtype == DataType::F16) { + auto *hptr = reinterpret_cast(out_ptr); + *hptr = utils::cast(static_cast(y)); + } else { // BF16 + auto *bptr = reinterpret_cast(out_ptr); + *bptr = utils::cast(static_cast(y)); + } + } +} + +static bool registered = []() { + Mish::dispatcher().registerDevice(Device::Type::CPU, &calculate); + return true; +}(); + +} // namespace infinicore::op::mish_impl::cpu \ No newline at end of file diff --git a/src/infinicore/pybind11/ops.hpp b/src/infinicore/pybind11/ops.hpp index 531043dd2..5ecbdb0f1 100644 --- a/src/infinicore/pybind11/ops.hpp +++ b/src/infinicore/pybind11/ops.hpp @@ -4,6 +4,7 @@ #include "ops/add.hpp" #include "ops/attention.hpp" +#include "ops/bitwise_left_shift.hpp" #include "ops/causal_softmax.hpp" #include "ops/embedding.hpp" #include "ops/linear.hpp" @@ -16,6 +17,10 @@ #include "ops/silu.hpp" #include "ops/swiglu.hpp" #include "ops/zeros_.hpp" +#include "ops/index_select.hpp" +#include "ops/fold.hpp" +#include "ops/mish.hpp" +#include "ops/log2.hpp" namespace py = pybind11; @@ -36,6 +41,11 @@ inline void bind(py::module &m) { bind_rope(m); bind_embedding(m); bind_zeros_(m); + bind_bitwise_left_shift(m); + bind_index_select(m); + bind_fold(m); + bind_mish(m); + bind_log2(m); } } // namespace infinicore::ops diff --git a/src/infinicore/pybind11/ops/bitwise_left_shift.hpp b/src/infinicore/pybind11/ops/bitwise_left_shift.hpp new file mode 100644 index 000000000..cbc6d11d4 --- /dev/null +++ b/src/infinicore/pybind11/ops/bitwise_left_shift.hpp @@ -0,0 +1,26 @@ +#pragma once + +#include + +#include "infinicore/ops/bitwise_left_shift.hpp" + +namespace py = pybind11; + +namespace infinicore::ops { + +inline void bind_bitwise_left_shift(py::module &m) { + m.def("bitwise_left_shift", + &op::bitwise_left_shift, + py::arg("a"), + py::arg("b"), + R"doc(Element-wise bitwise left shift of two tensors.)doc"); + + m.def("bitwise_left_shift_", + &op::bitwise_left_shift_, + py::arg("c"), + py::arg("a"), + py::arg("b"), + R"doc(In-place element-wise tensor bitwise left shift.)doc"); +} + +} // namespace infinicore::ops diff --git a/src/infinicore/pybind11/ops/fold.hpp b/src/infinicore/pybind11/ops/fold.hpp new file mode 100644 index 000000000..0f0b956fc --- /dev/null +++ b/src/infinicore/pybind11/ops/fold.hpp @@ -0,0 +1,23 @@ +#pragma once + +#include + +#include "infinicore/ops/fold.hpp" + +namespace py = pybind11; + +namespace infinicore::ops { + +inline void bind_fold(py::module &m) { + m.def("fold", + &op::fold, + py::arg("input"), + py::arg("output_size"), + py::arg("kernel_size"), + py::arg("dilation") = std::make_tuple(1, 1), + py::arg("padding") = std::make_tuple(0, 0), + py::arg("stride") = std::make_tuple(1, 1), + R"doc(Folds a tensor.)doc"); +} + +} // namespace infinicore::ops diff --git a/src/infinicore/pybind11/ops/index_select.hpp b/src/infinicore/pybind11/ops/index_select.hpp new file mode 100644 index 000000000..26e899995 --- /dev/null +++ b/src/infinicore/pybind11/ops/index_select.hpp @@ -0,0 +1,28 @@ +#pragma once + +#include + +#include "infinicore/ops/index_select.hpp" + +namespace py = pybind11; + +namespace infinicore::ops { + +inline void bind_index_select(py::module &m) { + m.def("index_select", + &op::index_select, + py::arg("input"), + py::arg("dim"), + py::arg("index"), + R"doc(Returns a new tensor which indexes the input tensor along dimension dim using the entries in index which is a LongTensor.)doc"); + + m.def("index_select_", + &op::index_select_, + py::arg("output"), + py::arg("input"), + py::arg("dim"), + py::arg("index"), + R"doc(In-place index select operation.)doc"); +} + +} // namespace infinicore::ops diff --git a/src/infinicore/pybind11/ops/log2.hpp b/src/infinicore/pybind11/ops/log2.hpp new file mode 100644 index 000000000..ff9a13055 --- /dev/null +++ b/src/infinicore/pybind11/ops/log2.hpp @@ -0,0 +1,24 @@ +#pragma once + +#include + +#include "infinicore/ops/log2.hpp" + +namespace py = pybind11; + +namespace infinicore::ops { + +inline void bind_log2(py::module &m) { + m.def("log2", + &op::log2, + py::arg("input"), + R"doc(Logarithm base 2 of the tensor.)doc"); + + m.def("log2_", + &op::log2_, + py::arg("input"), + py::arg("output"), + R"doc(In-place logarithm base 2 computation.)doc"); +} + +} // namespace infinicore::ops \ No newline at end of file diff --git a/src/infinicore/pybind11/ops/mish.hpp b/src/infinicore/pybind11/ops/mish.hpp new file mode 100644 index 000000000..6d26bfc83 --- /dev/null +++ b/src/infinicore/pybind11/ops/mish.hpp @@ -0,0 +1,20 @@ +#pragma once + +#include + +#include "infinicore/ops/mish.hpp" + +namespace py = pybind11; + +namespace infinicore::ops { + +inline void bind_mish(py::module &m) { + m.def("mish", + &op::mish, + py::arg("input"), + py::arg("inplace") = false, + R"doc(Applies the Mish activation function: x * tanh(softplus(x)).)doc"); + +} + +} // namespace infinicore::ops diff --git a/test/infinicore/ops/bitwise_left_shift.py b/test/infinicore/ops/bitwise_left_shift.py index 3b0b503e3..98220bbf5 100644 --- a/test/infinicore/ops/bitwise_left_shift.py +++ b/test/infinicore/ops/bitwise_left_shift.py @@ -131,9 +131,9 @@ def get_test_cases(self): def torch_operator(self, *args, **kwargs): return torch.bitwise_left_shift(*args, **kwargs) - # def infinicore_operator(self, *args, **kwargs): - # """InfiniCore implementation (operator not yet available).""" - # return infinicore.bitwise_left_shift(*args, **kwargs) + def infinicore_operator(self, *args, **kwargs): + """InfiniCore implementation (operator not yet available).""" + return infinicore.bitwise_left_shift(*args, **kwargs) def main(): diff --git a/test/infinicore/ops/fold.py b/test/infinicore/ops/fold.py index f9fb6f99c..ae4d80ed8 100644 --- a/test/infinicore/ops/fold.py +++ b/test/infinicore/ops/fold.py @@ -94,9 +94,9 @@ def get_test_cases(self): def torch_operator(self, *args, **kwargs): return torch.nn.functional.fold(*args, **kwargs) - # def infinicore_operator(self, *args, **kwargs): - # """InfiniCore implementation (operator not yet available).""" - # return infinicore.nn.functional.fold(*args, **kwargs) + def infinicore_operator(self, *args, **kwargs): + """InfiniCore implementation (operator not yet available).""" + return infinicore.nn.functional.fold(*args, **kwargs) def main(): diff --git a/test/infinicore/ops/index_select.py b/test/infinicore/ops/index_select.py index a28bcc223..f55307cff 100644 --- a/test/infinicore/ops/index_select.py +++ b/test/infinicore/ops/index_select.py @@ -65,9 +65,9 @@ def get_test_cases(self): def torch_operator(self, *args, **kwargs): return torch.index_select(*args, **kwargs) - # def infinicore_operator(self, *args, **kwargs): - # """InfiniCore implementation (operator not yet available).""" - # return infinicore.index_select(*args, **kwargs) + def infinicore_operator(self, *args, **kwargs): + """InfiniCore implementation (operator not yet available).""" + return infinicore.index_select(*args, **kwargs) def main(): diff --git a/test/infinicore/ops/log2.py b/test/infinicore/ops/log2.py index 6c4ebd740..a5d8457ca 100644 --- a/test/infinicore/ops/log2.py +++ b/test/infinicore/ops/log2.py @@ -87,9 +87,9 @@ def get_test_cases(self): def torch_operator(self, *args, **kwargs): return torch.log2(*args, **kwargs) - # def infinicore_operator(self, *args, **kwargs): - # """InfiniCore implementation (operator not yet available).""" - # return infinicore.log2(*args, **kwargs) + def infinicore_operator(self, *args, **kwargs): + """InfiniCore implementation (operator not yet available).""" + return infinicore.log2(*args, **kwargs) def main(): diff --git a/test/infinicore/ops/mish.py b/test/infinicore/ops/mish.py index 087cabeac..c37957bd6 100644 --- a/test/infinicore/ops/mish.py +++ b/test/infinicore/ops/mish.py @@ -68,9 +68,9 @@ def get_test_cases(self): def torch_operator(self, *args, **kwargs): return torch.nn.functional.mish(*args, **kwargs) - # def infinicore_operator(self, *args, **kwargs): - # """InfiniCore implementation (operator not yet available).""" - # return infinicore.nn.functional.mish(*args, **kwargs) + def infinicore_operator(self, *args, **kwargs): + """InfiniCore implementation (operator not yet available).""" + return infinicore.nn.functional.mish(*args, **kwargs) def main():