From 7549959824de555ed6c829fde5d8726609cb3bb5 Mon Sep 17 00:00:00 2001 From: greenhandhand <781740145@qq.com> Date: Sun, 14 Dec 2025 15:21:28 +0800 Subject: [PATCH 1/2] =?UTF-8?q?=E6=B7=BB=E5=8A=A0=20zeros=5F=20=E5=88=9D?= =?UTF-8?q?=E5=A7=8B=E5=8C=96=E7=AE=97=E5=AD=90?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- include/infinicore/ops/zeros.hpp | 15 +++++++ python/infinicore/nn/__init__.py | 3 +- python/infinicore/nn/init/__init__.py | 5 +++ python/infinicore/nn/init/zeros_.py | 9 ++++ src/infinicore/ops/zeros/zeros.cc | 26 +++++++++++ src/infinicore/ops/zeros/zeros_infiniop.cc | 52 ++++++++++++++++++++++ src/infinicore/pybind11/ops.hpp | 2 + src/infinicore/pybind11/ops/zeros_.hpp | 18 ++++++++ 8 files changed, 129 insertions(+), 1 deletion(-) create mode 100644 include/infinicore/ops/zeros.hpp create mode 100644 python/infinicore/nn/init/__init__.py create mode 100644 python/infinicore/nn/init/zeros_.py create mode 100644 src/infinicore/ops/zeros/zeros.cc create mode 100644 src/infinicore/ops/zeros/zeros_infiniop.cc create mode 100644 src/infinicore/pybind11/ops/zeros_.hpp diff --git a/include/infinicore/ops/zeros.hpp b/include/infinicore/ops/zeros.hpp new file mode 100644 index 000000000..709c41855 --- /dev/null +++ b/include/infinicore/ops/zeros.hpp @@ -0,0 +1,15 @@ +#pragma once + +#include "common/op.hpp" + +namespace infinicore::op { +class Zeros { + +public: + using schema = void (*)(Tensor); + static void execute(Tensor output); + static common::OpDispatcher &dispatcher(); +}; + +void zeros_(Tensor output); +} // namespace infinicore::op diff --git a/python/infinicore/nn/__init__.py b/python/infinicore/nn/__init__.py index 73c9f0aaa..e221fa958 100644 --- a/python/infinicore/nn/__init__.py +++ b/python/infinicore/nn/__init__.py @@ -1,5 +1,6 @@ from infinicore.nn import functional from infinicore.nn.modules import * # noqa: F403 from infinicore.nn.parameter import InfiniCoreParameter as Parameter +from 
infinicore.nn import init -__all__ = ["functional", "Parameter"] +__all__ = ["functional", "Parameter", "init"] diff --git a/python/infinicore/nn/init/__init__.py b/python/infinicore/nn/init/__init__.py new file mode 100644 index 000000000..6b37744f2 --- /dev/null +++ b/python/infinicore/nn/init/__init__.py @@ -0,0 +1,5 @@ +from .zeros_ import zeros_ + +__all__ = [ + "zeros_", +] diff --git a/python/infinicore/nn/init/zeros_.py b/python/infinicore/nn/init/zeros_.py new file mode 100644 index 000000000..718c84135 --- /dev/null +++ b/python/infinicore/nn/init/zeros_.py @@ -0,0 +1,9 @@ +import infinicore +from infinicore.lib import _infinicore +from infinicore.tensor import Tensor + + +def zeros_(input: Tensor) -> Tensor: + r"""Fill the input tensor with the scalar value 0.""" + _infinicore.zeros_(input._underlying) + return input diff --git a/src/infinicore/ops/zeros/zeros.cc b/src/infinicore/ops/zeros/zeros.cc new file mode 100644 index 000000000..97102054d --- /dev/null +++ b/src/infinicore/ops/zeros/zeros.cc @@ -0,0 +1,26 @@ +#include "infinicore/ops/zeros.hpp" +#include + +namespace infinicore::op { + +common::OpDispatcher &Zeros::dispatcher() { + static common::OpDispatcher dispatcher_; + return dispatcher_; +}; + +void Zeros::execute(Tensor output) { + context::setDevice(output->device()); + auto device_type = context::getDevice().getType(); + auto func = dispatcher().lookup(device_type); + + if (func == nullptr) { + throw std::runtime_error("No Zeros implementation found for device type: " + std::to_string(static_cast(device_type))); + } + + func(output); +} + +void zeros_(Tensor output) { + Zeros::execute(output); +} +} // namespace infinicore::op diff --git a/src/infinicore/ops/zeros/zeros_infiniop.cc b/src/infinicore/ops/zeros/zeros_infiniop.cc new file mode 100644 index 000000000..62762a4c7 --- /dev/null +++ b/src/infinicore/ops/zeros/zeros_infiniop.cc @@ -0,0 +1,52 @@ +#include "../../utils.hpp" +#include "infinicore/common/hash.hpp" +#include 
"infinicore/ops/common/cache.hpp" +#include "infinicore/ops/zeros.hpp" +#include + +namespace infinicore::op::zeros_impl::infiniop { + +thread_local common::OpCache caches( + 100, // capacity + [](infiniopZerosDescriptor_t &desc) { + if (desc != nullptr) { + INFINICORE_CHECK_ERROR(infiniopDestroyZerosDescriptor(desc)); + desc = nullptr; + } + }); + +void calculate(Tensor output) { + size_t seed = hash_combine(output); + + auto device_type = context::getDevice().getType(); + auto device_index = context::getDevice().getIndex(); + + auto &cache = caches.getCache(device_type, device_index); + + auto desc_opt = cache.get(seed); + infiniopZerosDescriptor_t desc = nullptr; + + if (!desc_opt) { + INFINICORE_CHECK_ERROR(infiniopCreateZerosDescriptor( + context::getInfiniopHandle(output->device()), &desc, + output->desc(), output->desc())); + cache.put(seed, desc); + } else { + desc = *desc_opt; + } + + size_t workspace_size = 0; + INFINICORE_CHECK_ERROR(infiniopGetZerosWorkspaceSize(desc, &workspace_size)); + std::shared_ptr workspace = context::allocateMemory(workspace_size); + + INFINICORE_CHECK_ERROR(infiniopZeros( + desc, workspace->data(), workspace_size, + output->data(), output->data(), context::getStream())); +} + +static bool registered = []() { + Zeros::dispatcher().registerAll(&calculate, false); + return true; +}(); + +} // namespace infinicore::op::zeros_impl::infiniop diff --git a/src/infinicore/pybind11/ops.hpp b/src/infinicore/pybind11/ops.hpp index 978defa17..531043dd2 100644 --- a/src/infinicore/pybind11/ops.hpp +++ b/src/infinicore/pybind11/ops.hpp @@ -15,6 +15,7 @@ #include "ops/rope.hpp" #include "ops/silu.hpp" #include "ops/swiglu.hpp" +#include "ops/zeros_.hpp" namespace py = pybind11; @@ -34,6 +35,7 @@ inline void bind(py::module &m) { bind_swiglu(m); bind_rope(m); bind_embedding(m); + bind_zeros_(m); } } // namespace infinicore::ops diff --git a/src/infinicore/pybind11/ops/zeros_.hpp b/src/infinicore/pybind11/ops/zeros_.hpp new file mode 100644 
index 000000000..660fadd30 --- /dev/null +++ b/src/infinicore/pybind11/ops/zeros_.hpp @@ -0,0 +1,18 @@ +#pragma once + +#include + +#include "infinicore/ops/zeros.hpp" + +namespace py = pybind11; + +namespace infinicore::ops { + +inline void bind_zeros_(py::module &m) { + m.def("zeros_", + &op::zeros_, + py::arg("input"), + R"doc(Fills the input tensor with zeros.)doc"); +} + +} // namespace infinicore::ops From a5db730e55d190a898a4a0370d519eb11713b15a Mon Sep 17 00:00:00 2001 From: greenhandhand <781740145@qq.com> Date: Sun, 14 Dec 2025 15:39:51 +0800 Subject: [PATCH 2/2] =?UTF-8?q?Infinicore=20=E6=AF=94=E8=B5=9B=EF=BC=8C?= =?UTF-8?q?=E5=AE=9E=E7=8E=B0=20log10,=20avg=5Fpool3d,=20histc,=20dot,=20l?= =?UTF-8?q?og1p=20=E7=AE=97=E5=AD=90?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- include/infinicore/ops/avg_pool3d.hpp | 17 +++ include/infinicore/ops/dot.hpp | 16 ++ include/infinicore/ops/histc.hpp | 16 ++ include/infinicore/ops/log10.hpp | 16 ++ include/infinicore/ops/log1p.hpp | 16 ++ python/infinicore/__init__.py | 8 + python/infinicore/nn/functional/__init__.py | 2 + python/infinicore/nn/functional/avg_pool3d.py | 68 +++++++++ python/infinicore/ops/dot.py | 16 ++ python/infinicore/ops/histc.py | 13 ++ python/infinicore/ops/log10.py | 16 ++ python/infinicore/ops/log1p.py | 16 ++ src/infinicore/ops/avg_pool3d/avg_pool3d.cc | 66 +++++++++ .../ops/avg_pool3d/avg_pool3d_cpu.cc | 123 ++++++++++++++++ src/infinicore/ops/dot/dot.cc | 39 +++++ src/infinicore/ops/dot/dot_cpu.cc | 83 +++++++++++ src/infinicore/ops/histc/histc.cc | 36 +++++ src/infinicore/ops/histc/histc_cpu.cc | 138 ++++++++++++++++++ src/infinicore/ops/log10/log10.cc | 33 +++++ src/infinicore/ops/log10/log10_cpu.cc | 69 +++++++++ src/infinicore/ops/log1p/log1p.cc | 33 +++++ src/infinicore/ops/log1p/log1p_cpu.cc | 69 +++++++++ src/infinicore/pybind11/ops.hpp | 10 ++ src/infinicore/pybind11/ops/avg_pool3d.hpp | 24 +++ src/infinicore/pybind11/ops/dot.hpp | 19 
+++ src/infinicore/pybind11/ops/histc.hpp | 30 ++++ src/infinicore/pybind11/ops/log10.hpp | 24 +++ src/infinicore/pybind11/ops/log1p.hpp | 24 +++ test/infinicore/ops/avg_pool3d.py | 6 +- test/infinicore/ops/dot.py | 6 +- test/infinicore/ops/histc.py | 6 +- test/infinicore/ops/log10.py | 6 +- test/infinicore/ops/log1p.py | 6 +- 33 files changed, 1055 insertions(+), 15 deletions(-) create mode 100644 include/infinicore/ops/avg_pool3d.hpp create mode 100644 include/infinicore/ops/dot.hpp create mode 100644 include/infinicore/ops/histc.hpp create mode 100644 include/infinicore/ops/log10.hpp create mode 100644 include/infinicore/ops/log1p.hpp create mode 100644 python/infinicore/nn/functional/avg_pool3d.py create mode 100644 python/infinicore/ops/dot.py create mode 100644 python/infinicore/ops/histc.py create mode 100644 python/infinicore/ops/log10.py create mode 100644 python/infinicore/ops/log1p.py create mode 100644 src/infinicore/ops/avg_pool3d/avg_pool3d.cc create mode 100644 src/infinicore/ops/avg_pool3d/avg_pool3d_cpu.cc create mode 100644 src/infinicore/ops/dot/dot.cc create mode 100644 src/infinicore/ops/dot/dot_cpu.cc create mode 100644 src/infinicore/ops/histc/histc.cc create mode 100644 src/infinicore/ops/histc/histc_cpu.cc create mode 100644 src/infinicore/ops/log10/log10.cc create mode 100644 src/infinicore/ops/log10/log10_cpu.cc create mode 100644 src/infinicore/ops/log1p/log1p.cc create mode 100644 src/infinicore/ops/log1p/log1p_cpu.cc create mode 100644 src/infinicore/pybind11/ops/avg_pool3d.hpp create mode 100644 src/infinicore/pybind11/ops/dot.hpp create mode 100644 src/infinicore/pybind11/ops/histc.hpp create mode 100644 src/infinicore/pybind11/ops/log10.hpp create mode 100644 src/infinicore/pybind11/ops/log1p.hpp diff --git a/include/infinicore/ops/avg_pool3d.hpp b/include/infinicore/ops/avg_pool3d.hpp new file mode 100644 index 000000000..43d4c9e28 --- /dev/null +++ b/include/infinicore/ops/avg_pool3d.hpp @@ -0,0 +1,17 @@ +#pragma once + +#include 
"../device.hpp" +#include "common/op.hpp" +#include + +namespace infinicore::op { +class AvgPool3d { +public: + using schema = void (*)(Tensor, Tensor, std::tuple, std::tuple, std::tuple, bool); + static void execute(Tensor output, Tensor input, std::tuple kernel_size, std::tuple stride, std::tuple padding, bool ceil_mode); + static common::OpDispatcher &dispatcher(); +}; + +Tensor avg_pool3d(Tensor input, std::tuple kernel_size, std::tuple stride, std::tuple padding, bool ceil_mode); +void avg_pool3d_(Tensor output, Tensor input, std::tuple kernel_size, std::tuple stride, std::tuple padding, bool ceil_mode); +} // namespace infinicore::op diff --git a/include/infinicore/ops/dot.hpp b/include/infinicore/ops/dot.hpp new file mode 100644 index 000000000..d22dbc5a2 --- /dev/null +++ b/include/infinicore/ops/dot.hpp @@ -0,0 +1,16 @@ +#pragma once + +#include "../device.hpp" +#include "common/op.hpp" + +namespace infinicore::op { +class Dot { +public: + using schema = void (*)(Tensor, Tensor, Tensor); + static void execute(Tensor c, Tensor a, Tensor b); + static common::OpDispatcher &dispatcher(); +}; + +Tensor dot(Tensor a, Tensor b); +void dot_(Tensor c, Tensor a, Tensor b); +} // namespace infinicore::op diff --git a/include/infinicore/ops/histc.hpp b/include/infinicore/ops/histc.hpp new file mode 100644 index 000000000..4ea5d6c8c --- /dev/null +++ b/include/infinicore/ops/histc.hpp @@ -0,0 +1,16 @@ +#pragma once + +#include "../device.hpp" +#include "common/op.hpp" + +namespace infinicore::op { +class Histc { +public: + using schema = void (*)(Tensor, Tensor, size_t, double, double); + static void execute(Tensor input, Tensor output, size_t bins, double min, double max); + static common::OpDispatcher &dispatcher(); +}; + +Tensor histc(Tensor input, size_t bins, double min, double max); +void histc_(Tensor input, Tensor output, size_t bins, double min, double max); +} // namespace infinicore::op diff --git a/include/infinicore/ops/log10.hpp 
b/include/infinicore/ops/log10.hpp new file mode 100644 index 000000000..02ce3882d --- /dev/null +++ b/include/infinicore/ops/log10.hpp @@ -0,0 +1,16 @@ +#pragma once + +#include "../device.hpp" +#include "common/op.hpp" + +namespace infinicore::op { +class Log10 { +public: + using schema = void (*)(Tensor, Tensor); + static void execute(Tensor output, Tensor input); + static common::OpDispatcher &dispatcher(); +}; + +Tensor log10(Tensor input); +void log10_(Tensor output, Tensor input); +} // namespace infinicore::op diff --git a/include/infinicore/ops/log1p.hpp b/include/infinicore/ops/log1p.hpp new file mode 100644 index 000000000..965e549db --- /dev/null +++ b/include/infinicore/ops/log1p.hpp @@ -0,0 +1,16 @@ +#pragma once + +#include "../device.hpp" +#include "common/op.hpp" + +namespace infinicore::op { +class Log1p { +public: + using schema = void (*)(Tensor, Tensor); + static void execute(Tensor output, Tensor input); + static common::OpDispatcher &dispatcher(); +}; + +Tensor log1p(Tensor input); +void log1p_(Tensor output, Tensor input); +} // namespace infinicore::op diff --git a/python/infinicore/__init__.py b/python/infinicore/__init__.py index 5c541ec3c..484a15773 100644 --- a/python/infinicore/__init__.py +++ b/python/infinicore/__init__.py @@ -45,6 +45,10 @@ from infinicore.ops.mul import mul from infinicore.ops.narrow import narrow from infinicore.ops.rearrange import rearrange +from infinicore.ops.dot import dot +from infinicore.ops.histc import histc +from infinicore.ops.log10 import log10 +from infinicore.ops.log1p import log1p from infinicore.tensor import ( Tensor, empty, @@ -115,6 +119,10 @@ "strided_empty", "strided_from_blob", "zeros", + "dot", + "log10", + "log1p", + "histc", ] use_ntops = False diff --git a/python/infinicore/nn/functional/__init__.py b/python/infinicore/nn/functional/__init__.py index 255079790..793204833 100644 --- a/python/infinicore/nn/functional/__init__.py +++ b/python/infinicore/nn/functional/__init__.py @@ -6,6 +6,7 
@@ from .rope import RopeAlgo, rope from .silu import silu from .swiglu import swiglu +from .avg_pool3d import avg_pool3d __all__ = [ "causal_softmax", @@ -17,4 +18,5 @@ "embedding", "rope", "RopeAlgo", + "avg_pool3d", ] diff --git a/python/infinicore/nn/functional/avg_pool3d.py b/python/infinicore/nn/functional/avg_pool3d.py new file mode 100644 index 000000000..16a3a5a9c --- /dev/null +++ b/python/infinicore/nn/functional/avg_pool3d.py @@ -0,0 +1,68 @@ +import infinicore +from infinicore.lib import _infinicore +from infinicore.tensor import Tensor + + +def _zeros_pad(input: Tensor, padding: tuple[int, ...]) -> Tensor: + r"""Pad a tensor. + + Args: + input (Tensor): The input tensor. + padding (tuple[int, ...]): The padding sizes. + + Returns: + Tensor: The padded tensor. + """ + output_shape = [] + for i in range(input.ndim): + output_shape.append(input.size(i) + 2 * padding[i]) + + output = infinicore.empty(output_shape, dtype=input.dtype, device=input.device) + output = infinicore.nn.init.zeros_(output) + + # 使用 narrow 函数获取对应的位置,然后复制数据 + # 需要逐维度进行 narrow 操作 + output_view = output + for dim in range(len(input.size())): + output_view = infinicore.narrow(output_view, dim, padding[dim], input.size(dim)) + + # 将输入数据复制到输出张量的对应位置 + infinicore.add(input, output_view, out=output_view) + + return output + + +def avg_pool3d( + input: Tensor, + kernel_size: tuple[int, int, int] | int, + stride: tuple[int, int, int] | int | None = None, + padding: tuple[int, int, int] | int = 0, + ceil_mode: bool = False, +): + r"""Applies a 3D average pooling over an input signal composed of several input + planes.""" + if isinstance(kernel_size, int): + kernel_size = (kernel_size, kernel_size, kernel_size) + + if isinstance(stride, int): + stride = (stride, stride, stride) + + if stride is None: + stride = kernel_size + + if isinstance(padding, int): + padding = [padding, padding, padding] + + if infinicore.use_ntops and input.device.type in ("cuda", "musa"): + padding = [0, 0] + 
list(padding) + + if any(p > 0 for p in padding): + input = _zeros_pad(input, padding) + return infinicore.ntops.torch.avg_pool3d(input, kernel_size, stride, ceil_mode) + + # cpu infer + return Tensor( + _infinicore.avg_pool3d( + input._underlying, kernel_size, stride, padding, ceil_mode + ) + ) diff --git a/python/infinicore/ops/dot.py b/python/infinicore/ops/dot.py new file mode 100644 index 000000000..429fa564a --- /dev/null +++ b/python/infinicore/ops/dot.py @@ -0,0 +1,16 @@ +import infinicore +from infinicore.lib import _infinicore +from infinicore.tensor import Tensor + + +def dot(input: Tensor, tensor: Tensor, *, out=None) -> Tensor: + r"""Compute the dot product of two 1-D tensors.""" + + if infinicore.use_ntops and input.device.type in ("cuda", "musa"): + return infinicore.ntops.torch.dot(input, tensor, out=out) + + if out is None: + return Tensor(_infinicore.dot(input._underlying, tensor._underlying)) + + _infinicore.dot_(out._underlying, input._underlying, tensor._underlying) + return out diff --git a/python/infinicore/ops/histc.py b/python/infinicore/ops/histc.py new file mode 100644 index 000000000..f7fb9af03 --- /dev/null +++ b/python/infinicore/ops/histc.py @@ -0,0 +1,13 @@ +import infinicore +from infinicore.lib import _infinicore +from infinicore.tensor import Tensor + + +def histc(input: Tensor, bins: int = 100, min: float | None = None, max: float | None = None) -> Tensor: + r"""Apply the logsumexp function.""" + + if infinicore.use_ntops and input.device.type in ("cuda", "musa"): + is_moore = input._underlying.device.type == _infinicore.Device.Type.MOORE + return infinicore.ntops.torch.histc(input, bins, min, max, is_moore) + + return Tensor(_infinicore.histc(input._underlying, bins, min, max)) diff --git a/python/infinicore/ops/log10.py b/python/infinicore/ops/log10.py new file mode 100644 index 000000000..6ce4e150d --- /dev/null +++ b/python/infinicore/ops/log10.py @@ -0,0 +1,16 @@ +import infinicore +from infinicore.lib import _infinicore 
+from infinicore.tensor import Tensor + + +def log10(input: Tensor, *, out=None) -> Tensor: + r"""Apply the Sigmoid Linear Unit (SiLU) function, element-wise.""" + + if infinicore.use_ntops and input.device.type in ("cuda", "musa"): + return infinicore.ntops.torch.log10(input, out=out) + + if out is None: + return Tensor(_infinicore.log10(input._underlying)) + + _infinicore.log10_(out._underlying, input._underlying) + return out diff --git a/python/infinicore/ops/log1p.py b/python/infinicore/ops/log1p.py new file mode 100644 index 000000000..3db205bcf --- /dev/null +++ b/python/infinicore/ops/log1p.py @@ -0,0 +1,16 @@ +import infinicore +from infinicore.lib import _infinicore +from infinicore.tensor import Tensor + + +def log1p(input: Tensor, *, out=None) -> Tensor: + r"""Compute the ln(x + 1).""" + + if infinicore.use_ntops and input.device.type in ("cuda", "musa"): + return infinicore.ntops.torch.log1p(input, out=out) + + if out is None: + return Tensor(_infinicore.log1p(input._underlying)) + + _infinicore.log1p_(out._underlying, input._underlying) + return out diff --git a/src/infinicore/ops/avg_pool3d/avg_pool3d.cc b/src/infinicore/ops/avg_pool3d/avg_pool3d.cc new file mode 100644 index 000000000..3f0754a2d --- /dev/null +++ b/src/infinicore/ops/avg_pool3d/avg_pool3d.cc @@ -0,0 +1,66 @@ +#include "infinicore/ops/avg_pool3d.hpp" +#include +#include + +namespace infinicore::op { + +common::OpDispatcher &AvgPool3d::dispatcher() { + static common::OpDispatcher dispatcher_; + return dispatcher_; +}; + +void AvgPool3d::execute(Tensor output, Tensor input, std::tuple kernel_size, std::tuple stride, std::tuple padding, bool ceil_mode) { + infinicore::context::setDevice(input->device(), true); + auto device_type = context::getDevice().getType(); + auto func = dispatcher().lookup(device_type); + + if (func == nullptr) { + throw std::runtime_error("No AvgPool3d implementation found for device type: " + std::to_string(static_cast(device_type))); + } + + func(output, input, 
kernel_size, stride, padding, ceil_mode); +} + +Tensor avg_pool3d(Tensor input, std::tuple kernel_size, std::tuple stride, std::tuple padding, bool ceil_mode) { + const auto ndim = input->ndim(); + auto input_shape = input->shape(); + + if (ndim != 5 && ndim != 4) { + throw std::runtime_error("Input tensor must be 4-dimensional (N, C, D_in, H_in, W_in) or (C, D_in, H_in, W_in)"); + } + + if (ndim == 4) { + input = input->view({1, input_shape[0], input_shape[1], input_shape[2], input_shape[3]}); + input_shape = input->shape(); + } + + const auto [Kernel_D, Kernel_H, Kernel_W] = kernel_size; + const auto [Stride_D, Stride_H, Stride_W] = stride; + const auto [Padding_D, Padding_H, Padding_W] = padding; + const auto D_in = input_shape[2]; + const auto H_in = input_shape[3]; + const auto W_in = input_shape[4]; + size_t D_out = 0; + size_t H_out = 0; + size_t W_out = 0; + if (ceil_mode) { + D_out = static_cast(std::ceil(static_cast(D_in + 2 * Padding_D - Kernel_D) / Stride_D)) + 1; + H_out = static_cast(std::ceil(static_cast(H_in + 2 * Padding_H - Kernel_H) / Stride_H)) + 1; + W_out = static_cast(std::ceil(static_cast(W_in + 2 * Padding_W - Kernel_W) / Stride_W)) + 1; + } else { + D_out = static_cast(std::floor(static_cast(D_in + 2 * Padding_D - Kernel_D) / Stride_D)) + 1; + H_out = static_cast(std::floor(static_cast(H_in + 2 * Padding_H - Kernel_H) / Stride_H)) + 1; + W_out = static_cast(std::floor(static_cast(W_in + 2 * Padding_W - Kernel_W) / Stride_W)) + 1; + } + + auto output_shape = Shape{input_shape[0], input_shape[1], D_out, H_out, W_out}; + + auto output = Tensor::empty(output_shape, input->dtype(), input->device()); + avg_pool3d_(output, input, kernel_size, stride, padding, ceil_mode); + return output; +} + +void avg_pool3d_(Tensor output, Tensor input, std::tuple kernel_size, std::tuple stride, std::tuple padding, bool ceil_mode) { + AvgPool3d::execute(output, input, kernel_size, stride, padding, ceil_mode); +} +} // namespace infinicore::op diff --git 
a/src/infinicore/ops/avg_pool3d/avg_pool3d_cpu.cc b/src/infinicore/ops/avg_pool3d/avg_pool3d_cpu.cc new file mode 100644 index 000000000..fc5999eca --- /dev/null +++ b/src/infinicore/ops/avg_pool3d/avg_pool3d_cpu.cc @@ -0,0 +1,123 @@ +#include "../../../utils.h" +#include "infinicore/device.hpp" +#include "infinicore/ops/avg_pool3d.hpp" +#include "infinicore/tensor.hpp" +#include +#include +#include +#include +#include + +namespace infinicore::op::avg_pool3d_impl::cpu { + +void calculate(Tensor output, Tensor input, std::tuple kernel_size, std::tuple stride, std::tuple padding, bool ceil_mode) { + // input: [N, C, D_in, H_in, W_in], output: [N, C, D_out, H_out, W_out] + auto input_shapes = input->shape(); + auto input_strides = input->strides(); + auto output_shapes = output->shape(); + auto dtype = input->dtype(); + + const size_t N = input_shapes[0]; + const size_t C = input_shapes[1]; + const size_t D_in = input_shapes[2]; + const size_t H_in = input_shapes[3]; + const size_t W_in = input_shapes[4]; + + const size_t D_out = output_shapes[2]; + const size_t H_out = output_shapes[3]; + const size_t W_out = output_shapes[4]; + + const size_t stride_N = input_strides[0]; + const size_t stride_C = input_strides[1]; + const size_t stride_D = input_strides[2]; + const size_t stride_H = input_strides[3]; + const size_t stride_W = input_strides[4]; + + const size_t kernel_d = std::get<0>(kernel_size); + const size_t kernel_h = std::get<1>(kernel_size); + const size_t kernel_w = std::get<2>(kernel_size); + + const size_t stride_d = std::get<0>(stride); + const size_t stride_h = std::get<1>(stride); + const size_t stride_w = std::get<2>(stride); + + // 使用 padding 参数(count_include_pad=True) + + auto input_base = input->data(); + auto output_base = output->data(); + const auto element_size = input->element_size(); + + // 无调试输出 + + const double kernel_vol = static_cast(kernel_d * kernel_h * kernel_w); + + for (size_t n = 0; n < N; ++n) { + for (size_t c = 0; c < C; ++c) { + 
for (size_t od = 0; od < D_out; ++od) { + long long d_start_raw = (long long)od * (long long)stride_d - (long long)std::get<0>(padding); + long long d_end_raw = d_start_raw + (long long)kernel_d; + const size_t d_begin = (size_t)std::max(0, d_start_raw); + const size_t d_end = (size_t)std::min(d_end_raw, (long long)D_in); + + for (size_t oh = 0; oh < H_out; ++oh) { + long long h_start_raw = (long long)oh * (long long)stride_h - (long long)std::get<1>(padding); + long long h_end_raw = h_start_raw + (long long)kernel_h; + const size_t h_begin = (size_t)std::max(0, h_start_raw); + const size_t h_end = (size_t)std::min(h_end_raw, (long long)H_in); + + for (size_t ow = 0; ow < W_out; ++ow) { + long long w_start_raw = (long long)ow * (long long)stride_w - (long long)std::get<2>(padding); + long long w_end_raw = w_start_raw + (long long)kernel_w; + const size_t w_begin = (size_t)std::max(0, w_start_raw); + const size_t w_end = (size_t)std::min(w_end_raw, (long long)W_in); + + double sum = 0.0; + + // 累加有效元素(padding 视为 0) + for (size_t id = d_begin; id < d_end; ++id) { + for (size_t ih = h_begin; ih < h_end; ++ih) { + for (size_t iw = w_begin; iw < w_end; ++iw) { + const size_t offset = n * stride_N + c * stride_C + id * stride_D + ih * stride_H + iw * stride_W; + + if (dtype == DataType::F32) { + auto *ptr = reinterpret_cast(input_base + offset * element_size); + sum += static_cast(*ptr); + } else if (dtype == DataType::F64) { + auto *ptr = reinterpret_cast(input_base + offset * element_size); + sum += *ptr; + } else if (dtype == DataType::F16) { + auto *ptr = reinterpret_cast(input_base + offset * element_size); + sum += static_cast(utils::cast(*ptr)); + } else { + throw std::runtime_error("Unsupported data type for avg_pool3d operation."); + } + } + } + } + + const double avg = sum / kernel_vol; // count_include_pad=True + + const size_t out_offset = n * (C * D_out * H_out * W_out) + c * (D_out * H_out * W_out) + od * (H_out * W_out) + oh * W_out + ow; + if (dtype == 
DataType::F32) { + auto *ptr = reinterpret_cast(output_base + out_offset * element_size); + *ptr = static_cast(avg); + } else if (dtype == DataType::F64) { + auto *ptr = reinterpret_cast(output_base + out_offset * element_size); + *ptr = avg; + } else if (dtype == DataType::F16) { + auto *ptr = reinterpret_cast(output_base + out_offset * element_size); + *ptr = utils::cast(static_cast(avg)); + } + } + } + } + } + } +} + +static bool registered = []() { + AvgPool3d::dispatcher().registerDevice(Device::Type::CPU, &calculate); + return true; +}(); + +} // namespace infinicore::op::avg_pool3d_impl::cpu diff --git a/src/infinicore/ops/dot/dot.cc b/src/infinicore/ops/dot/dot.cc new file mode 100644 index 000000000..85b25af2d --- /dev/null +++ b/src/infinicore/ops/dot/dot.cc @@ -0,0 +1,39 @@ +#include "infinicore/ops/dot.hpp" +#include "../../utils.hpp" +#include +#include + +namespace infinicore::op { + +common::OpDispatcher &Dot::dispatcher() { + static common::OpDispatcher dispatcher_; + return dispatcher_; +}; + +void Dot::execute(Tensor c, Tensor a, Tensor b) { + INFINICORE_ASSERT_TENSORS_SAME_DEVICE(c, a, b); + infinicore::context::setDevice(a->device(), true); + auto device_type = context::getDevice().getType(); + auto func = dispatcher().lookup(device_type); + + if (func == nullptr) { + throw std::runtime_error("No Dot implementation found for device type: " + std::to_string(static_cast(device_type))); + } + + if (a->ndim() != 1 || b->ndim() != 1 || c->ndim() != 0) { + throw std::runtime_error("Dot operation only supports 1-D tensors for a and b, and 0-D tensor for c."); + } + + func(c, a, b); +} + +Tensor dot(Tensor a, Tensor b) { + auto c = Tensor::empty(Shape{}, a->dtype(), a->device()); + dot_(c, a, b); + return c; +} + +void dot_(Tensor c, Tensor a, Tensor b) { + Dot::execute(c, a, b); +} +} // namespace infinicore::op diff --git a/src/infinicore/ops/dot/dot_cpu.cc b/src/infinicore/ops/dot/dot_cpu.cc new file mode 100644 index 000000000..fd4d2fdde --- 
/dev/null +++ b/src/infinicore/ops/dot/dot_cpu.cc @@ -0,0 +1,83 @@ +#include "../../../utils.h" +#include "infinicore/device.hpp" +#include "infinicore/ops/dot.hpp" +#include "infinicore/tensor.hpp" + +namespace infinicore::op::dot_impl::cpu { + +void calculate(Tensor c, Tensor a, Tensor b) { + auto a_shapes = a->shape(); + auto b_shapes = b->shape(); + if (a->ndim() != 1 || b->ndim() != 1) { + throw std::runtime_error("Dot CPU only supports 1-D tensors for a and b."); + } + if (a_shapes[0] != b_shapes[0]) { + throw std::runtime_error("Dot CPU requires a and b to have the same length."); + } + + auto dtype = a->dtype(); + if (dtype != b->dtype()) { + throw std::runtime_error("Dot CPU requires a and b to have the same dtype."); + } + + const size_t len = a_shapes[0]; + const auto a_stride = a->strides()[0]; + const auto b_stride = b->strides()[0]; + const auto a_element_size = a->element_size(); + const auto b_element_size = b->element_size(); + const auto c_element_size = c->element_size(); + + auto a_base = a->data(); + auto b_base = b->data(); + auto c_base = c->data(); + + double acc = 0.0; + for (size_t i = 0; i < len; ++i) { + const size_t a_off = i * static_cast(a_stride); + const size_t b_off = i * static_cast(b_stride); + if (dtype == DataType::F32) { + auto *ap = reinterpret_cast(a_base + a_off * a_element_size); + auto *bp = reinterpret_cast(b_base + b_off * b_element_size); + acc += static_cast((*ap) * (*bp)); + } else if (dtype == DataType::F64) { + auto *ap = reinterpret_cast(a_base + a_off * a_element_size); + auto *bp = reinterpret_cast(b_base + b_off * b_element_size); + acc += (*ap) * (*bp); + } else if (dtype == DataType::F16) { + auto *ap = reinterpret_cast(a_base + a_off * a_element_size); + auto *bp = reinterpret_cast(b_base + b_off * b_element_size); + float av = utils::cast(*ap); + float bv = utils::cast(*bp); + acc += static_cast(av * bv); + } else if (dtype == DataType::BF16) { + auto *ap = reinterpret_cast(a_base + a_off * a_element_size); 
+ auto *bp = reinterpret_cast(b_base + b_off * b_element_size); + float av = utils::cast(*ap); + float bv = utils::cast(*bp); + acc += static_cast(av * bv); + } else { + throw std::runtime_error("Unsupported dtype for dot CPU."); + } + } + + if (dtype == DataType::F32) { + auto *cp = reinterpret_cast(c_base); + *cp = static_cast(acc); + } else if (dtype == DataType::F64) { + auto *cp = reinterpret_cast(c_base); + *cp = acc; + } else if (dtype == DataType::F16) { + auto *cp = reinterpret_cast(c_base); + *cp = utils::cast(static_cast(acc)); + } else if (dtype == DataType::BF16) { + auto *cp = reinterpret_cast(c_base); + *cp = utils::cast(static_cast(acc)); + } +} + +static bool registered = []() { + Dot::dispatcher().registerDevice(Device::Type::CPU, &calculate); + return true; +}(); + +} // namespace infinicore::op::dot_impl::cpu diff --git a/src/infinicore/ops/histc/histc.cc b/src/infinicore/ops/histc/histc.cc new file mode 100644 index 000000000..6052883ee --- /dev/null +++ b/src/infinicore/ops/histc/histc.cc @@ -0,0 +1,36 @@ +#include "infinicore/ops/histc.hpp" +#include +#include + +namespace infinicore::op { + +common::OpDispatcher &Histc::dispatcher() { + static common::OpDispatcher dispatcher_; + return dispatcher_; +}; + +void Histc::execute(Tensor input, Tensor output, size_t bins, double min, double max) { + infinicore::context::setDevice(input->device(), true); + auto device_type = context::getDevice().getType(); + auto func = dispatcher().lookup(device_type); + + if (func == nullptr) { + throw std::runtime_error("No Histc implementation found for device type: " + std::to_string(static_cast(device_type))); + } + + func(input, output, bins, min, max); +} + +Tensor histc(Tensor input, size_t bins, double min, double max) { + auto output = Tensor::empty(Shape{ + bins, + }, + input->dtype(), input->device()); + histc_(input, output, bins, min, max); + return output; +} + +void histc_(Tensor input, Tensor output, size_t bins, double min, double max) { + 
Histc::execute(input, output, bins, min, max); +} +} // namespace infinicore::op diff --git a/src/infinicore/ops/histc/histc_cpu.cc b/src/infinicore/ops/histc/histc_cpu.cc new file mode 100644 index 000000000..9d904d88c --- /dev/null +++ b/src/infinicore/ops/histc/histc_cpu.cc @@ -0,0 +1,138 @@ +#include "../../../utils.h" +#include "infinicore/device.hpp" +#include "infinicore/ops/histc.hpp" +#include "infinicore/tensor.hpp" +#include +#include +#include +#include +#include + +namespace infinicore::op::histc_impl::cpu { + +void calculate(Tensor input, Tensor output, size_t bins, double min, double max) { + if (bins == 0) { + throw std::runtime_error("histc CPU: bins must be > 0"); + } + if (!(max > min)) { + throw std::runtime_error("histc CPU: require max > min"); + } + + auto dtype = input->dtype(); + // 输出一维 [bins],按创建逻辑默认连续 + auto out_base = output->data(); + const auto out_esize = output->element_size(); + + // 初始化输出为 0 + for (size_t i = 0; i < bins; ++i) { + const size_t off = i; + if (dtype == DataType::F32) { + auto *ptr = reinterpret_cast(out_base + off * out_esize); + *ptr = 0.0f; + } else if (dtype == DataType::F64) { + auto *ptr = reinterpret_cast(out_base + off * out_esize); + *ptr = 0.0; + } else if (dtype == DataType::F16) { + auto *ptr = reinterpret_cast(out_base + off * out_esize); + *ptr = utils::cast(0.0f); + } else if (dtype == DataType::BF16) { + auto *ptr = reinterpret_cast(out_base + off * out_esize); + *ptr = utils::cast(0.0f); + } else { + throw std::runtime_error("histc CPU: unsupported dtype for output"); + } + } + + // 输入遍历(支持任意形状/步长),对每个元素进行分箱 + auto in_base = input->data(); + const auto in_esize = input->element_size(); + auto strides = input->strides(); + auto shapes = input->shape(); + const size_t ndim = input->ndim(); + const size_t numel = input->numel(); + + const double width = (max - min) / static_cast(bins); + if (!(width > 0.0)) { + // 极端情况:width==0(min==max),将等于 max 的值计入最后一箱 + // 其他值忽略 + } + + std::vector indices(ndim, 0); 
+ for (size_t idx = 0; idx < numel; ++idx) { + size_t off = 0; + for (size_t d = 0; d < ndim; ++d) { + off += indices[d] * static_cast(strides[d]); + } + + double vald = 0.0; + if (dtype == DataType::F32) { + auto *p = reinterpret_cast(in_base + off * in_esize); + vald = static_cast(*p); + } else if (dtype == DataType::F64) { + auto *p = reinterpret_cast(in_base + off * in_esize); + vald = *p; + } else if (dtype == DataType::F16) { + auto *p = reinterpret_cast(in_base + off * in_esize); + vald = static_cast(utils::cast(*p)); + } else if (dtype == DataType::BF16) { + auto *p = reinterpret_cast(in_base + off * in_esize); + vald = static_cast(utils::cast(*p)); + } else { + throw std::runtime_error("histc CPU: unsupported dtype for input"); + } + + // 计算箱索引 + ssize_t bin = -1; + if (vald < min || vald > max) { + bin = -1; // 忽略越界 + } else if (vald == max) { + bin = static_cast(bins - 1); + } else if (width > 0.0) { + double pos = (vald - min) / width; + ssize_t ib = static_cast(std::floor(pos)); + if (ib < 0) { + ib = 0; + } + if (ib >= static_cast(bins)) { + ib = static_cast(bins - 1); + } + bin = ib; + } + + if (bin >= 0) { + const size_t out_off = static_cast(bin); + if (dtype == DataType::F32) { + auto *op = reinterpret_cast(out_base + out_off * out_esize); + *op = *op + 1.0f; + } else if (dtype == DataType::F64) { + auto *op = reinterpret_cast(out_base + out_off * out_esize); + *op = *op + 1.0; + } else if (dtype == DataType::F16) { + auto *op = reinterpret_cast(out_base + out_off * out_esize); + float cur = utils::cast(*op); + *op = utils::cast(cur + 1.0f); + } else if (dtype == DataType::BF16) { + auto *op = reinterpret_cast(out_base + out_off * out_esize); + float cur = utils::cast(*op); + *op = utils::cast(cur + 1.0f); + } + } + + // 更新多维索引 + for (ssize_t d = static_cast(ndim) - 1; d >= 0; --d) { + indices[static_cast(d)]++; + if (indices[static_cast(d)] < shapes[static_cast(d)]) { + break; + } else { + indices[static_cast(d)] = 0; + } + } + } +} + +static 
bool registered = []() { + Histc::dispatcher().registerDevice(Device::Type::CPU, &calculate); + return true; +}(); + +} // namespace infinicore::op::histc_impl::cpu diff --git a/src/infinicore/ops/log10/log10.cc b/src/infinicore/ops/log10/log10.cc new file mode 100644 index 000000000..6af2f0ebe --- /dev/null +++ b/src/infinicore/ops/log10/log10.cc @@ -0,0 +1,33 @@ +#include "infinicore/ops/log10.hpp" +#include + +namespace infinicore::op { + +common::OpDispatcher &Log10::dispatcher() { + static common::OpDispatcher dispatcher_; + return dispatcher_; +}; + +void Log10::execute(Tensor output, Tensor input) { + infinicore::context::setDevice(input->device(), true); + auto device_type = context::getDevice().getType(); + auto func = dispatcher().lookup(device_type); + + if (func == nullptr) { + throw std::runtime_error("No Log10 implementation found for device type: " + std::to_string(static_cast(device_type))); + } + + func(output, input); +} + +Tensor log10(Tensor input) { + Shape shape = input->shape(); + auto output = Tensor::empty(shape, input->dtype(), input->device()); + log10_(output, input); + return output; +} + +void log10_(Tensor output, Tensor input) { + Log10::execute(output, input); +} +} // namespace infinicore::op diff --git a/src/infinicore/ops/log10/log10_cpu.cc b/src/infinicore/ops/log10/log10_cpu.cc new file mode 100644 index 000000000..9549f18d5 --- /dev/null +++ b/src/infinicore/ops/log10/log10_cpu.cc @@ -0,0 +1,75 @@ +#include "../../../utils.h" + +#include "infinicore/device.hpp" +#include "infinicore/ops/log10.hpp" +#include "infinicore/tensor.hpp" +#include +#include +#include + +namespace infinicore::op::log10_impl::cpu { + +// CPU log10 kernel: writes the base-10 logarithm of every element of input into output (F32, F64, F16). +// Input elements are located through the input's strides and output elements through the output's +// strides, so the two tensors do not need to share a memory layout. Throws std::runtime_error for other dtypes. +void calculate(Tensor output, Tensor input) { + auto strides = input->strides(); // vector + auto out_strides = output->strides(); + auto shapes = input->shape(); // vector + auto ndim = input->ndim(); + auto dtype = input->dtype(); + auto dtype_size = input->element_size(); + auto numel = input->numel(); + + auto input_base = input->data(); + auto output_base 
= output->data(); + + std::vector indices(ndim, 0); + for (size_t idx = 0; idx < numel; ++idx) { + // Calculate the input and output element offsets for the current index + size_t offset = 0; + size_t out_offset = 0; + for (size_t dim = 0; dim < ndim; ++dim) { + offset += indices[dim] * strides[dim]; + out_offset += indices[dim] * out_strides[dim]; + } + + // Compute log10 for the current element + if (dtype == DataType::F32) { + auto *input_ptr = reinterpret_cast(input_base + offset * dtype_size); + auto *output_ptr = reinterpret_cast(output_base + out_offset * dtype_size); + *output_ptr = std::log10(*input_ptr); + } else if (dtype == DataType::F64) { + auto *input_ptr = reinterpret_cast(input_base + offset * dtype_size); + auto *output_ptr = reinterpret_cast(output_base + out_offset * dtype_size); + *output_ptr = std::log10(*input_ptr); + } else if (dtype == DataType::F16) { + // F16: convert to F32 for the computation, then convert the result back to F16 + auto *input_ptr = reinterpret_cast(input_base + offset * dtype_size); + auto *output_ptr = reinterpret_cast(output_base + out_offset * dtype_size); + + float input_f32 = utils::cast(*input_ptr); + float output_f32 = std::log10(input_f32); + *output_ptr = utils::cast(output_f32); + } else { + throw std::runtime_error("Unsupported data type for log10 operation."); + } + + // Update indices + for (ssize_t dim = ndim - 1; dim >= 0; --dim) { + indices[dim]++; + if (indices[dim] < shapes[dim]) { + break; + } else { + indices[dim] = 0; + } + } + } +} + +static bool registered = []() { + Log10::dispatcher().registerDevice(Device::Type::CPU, &calculate); + return true; +}(); + +} // namespace infinicore::op::log10_impl::cpu diff --git a/src/infinicore/ops/log1p/log1p.cc b/src/infinicore/ops/log1p/log1p.cc new file mode 100644 index 000000000..9089d1aff --- /dev/null +++ b/src/infinicore/ops/log1p/log1p.cc @@ -0,0 +1,33 @@ +#include "infinicore/ops/log1p.hpp" +#include + +namespace infinicore::op { + +common::OpDispatcher &Log1p::dispatcher() { + static common::OpDispatcher dispatcher_; + return dispatcher_; +}; + +void Log1p::execute(Tensor output, Tensor input) { + 
infinicore::context::setDevice(input->device(), true); + auto device_type = context::getDevice().getType(); + auto func = dispatcher().lookup(device_type); + + if (func == nullptr) { + throw std::runtime_error("No Log1p implementation found for device type: " + std::to_string(static_cast(device_type))); + } + + func(output, input); +} + +Tensor log1p(Tensor input) { + Shape shape = input->shape(); + auto output = Tensor::empty(shape, input->dtype(), input->device()); + log1p_(output, input); + return output; +} + +void log1p_(Tensor output, Tensor input) { + Log1p::execute(output, input); +} +} // namespace infinicore::op diff --git a/src/infinicore/ops/log1p/log1p_cpu.cc b/src/infinicore/ops/log1p/log1p_cpu.cc new file mode 100644 index 000000000..9625a14a6 --- /dev/null +++ b/src/infinicore/ops/log1p/log1p_cpu.cc @@ -0,0 +1,75 @@ +#include "../../../utils.h" + +#include "infinicore/device.hpp" +#include "infinicore/ops/log1p.hpp" +#include "infinicore/tensor.hpp" +#include +#include +#include + +namespace infinicore::op::log1p_impl::cpu { + +// CPU log1p kernel: writes log(1 + x) for every element x of input into output (F32, F64, F16), +// using std::log1p so results stay accurate when x is close to zero. Input elements are located through +// the input's strides and output elements through the output's strides. Throws std::runtime_error for other dtypes. +void calculate(Tensor output, Tensor input) { + auto strides = input->strides(); // vector + auto out_strides = output->strides(); + auto shapes = input->shape(); // vector + auto ndim = input->ndim(); + auto dtype = input->dtype(); + auto dtype_size = input->element_size(); + auto numel = input->numel(); + + auto input_base = input->data(); + auto output_base = output->data(); + + std::vector indices(ndim, 0); + for (size_t idx = 0; idx < numel; ++idx) { + // Calculate the input and output element offsets for the current index + size_t offset = 0; + size_t out_offset = 0; + for (size_t dim = 0; dim < ndim; ++dim) { + offset += indices[dim] * strides[dim]; + out_offset += indices[dim] * out_strides[dim]; + } + + // Compute log1p for the current element + if (dtype == DataType::F32) { + auto *input_ptr = reinterpret_cast(input_base + offset * dtype_size); + auto *output_ptr = reinterpret_cast(output_base + out_offset * dtype_size); + *output_ptr = std::log1p(*input_ptr); + } else if (dtype == DataType::F64) { + auto *input_ptr = reinterpret_cast(input_base + offset 
* dtype_size); + auto *output_ptr = reinterpret_cast(output_base + out_offset * dtype_size); + *output_ptr = std::log1p(*input_ptr); + } else if (dtype == DataType::F16) { + // F16: convert to F32 for the computation, then convert the result back to F16 + auto *input_ptr = reinterpret_cast(input_base + offset * dtype_size); + auto *output_ptr = reinterpret_cast(output_base + out_offset * dtype_size); + + float input_f32 = utils::cast(*input_ptr); + float output_f32 = std::log1p(input_f32); + *output_ptr = utils::cast(output_f32); + } else { + throw std::runtime_error("Unsupported data type for log1p operation."); + } + + // Update indices + for (ssize_t dim = ndim - 1; dim >= 0; --dim) { + indices[dim]++; + if (indices[dim] < shapes[dim]) { + break; + } else { + indices[dim] = 0; + } + } + } +} + +static bool registered = []() { + Log1p::dispatcher().registerDevice(Device::Type::CPU, &calculate); + return true; +}(); + +} // namespace infinicore::op::log1p_impl::cpu diff --git a/src/infinicore/pybind11/ops.hpp b/src/infinicore/pybind11/ops.hpp index 531043dd2..249e37da3 100644 --- a/src/infinicore/pybind11/ops.hpp +++ b/src/infinicore/pybind11/ops.hpp @@ -4,9 +4,14 @@ #include "ops/add.hpp" #include "ops/attention.hpp" +#include "ops/avg_pool3d.hpp" #include "ops/causal_softmax.hpp" +#include "ops/dot.hpp" #include "ops/embedding.hpp" +#include "ops/histc.hpp" #include "ops/linear.hpp" +#include "ops/log10.hpp" +#include "ops/log1p.hpp" #include "ops/matmul.hpp" #include "ops/mul.hpp" #include "ops/random_sample.hpp" @@ -35,7 +40,12 @@ inline void bind(py::module &m) { bind_swiglu(m); bind_rope(m); bind_embedding(m); + bind_histc(m); bind_zeros_(m); + bind_log10(m); + bind_avg_pool3d(m); + bind_dot(m); + bind_log1p(m); } } // namespace infinicore::ops diff --git a/src/infinicore/pybind11/ops/avg_pool3d.hpp b/src/infinicore/pybind11/ops/avg_pool3d.hpp new file mode 100644 index 000000000..7386b068f --- /dev/null +++ b/src/infinicore/pybind11/ops/avg_pool3d.hpp @@ -0,0 +1,24 @@ +#pragma once + +#include + +#include 
"infinicore/ops/avg_pool3d.hpp" + +namespace py = pybind11; + +namespace infinicore::ops { + +inline void bind_avg_pool3d(py::module &m) { + m.def("avg_pool3d", + &op::avg_pool3d, + py::arg("input"), + py::arg("kernel_size"), + py::arg("stride") = py::none(), + py::arg("padding") = 0, + py::arg("ceil_mode") = false, + R"doc(Applies 3D average-pooling operation in :math:`kD \times kH \times kW` regions by step size + :math:`sD \times sH \times sW` steps. The number of output features is equal to the number of + input planes.)doc"); +} + +} // namespace infinicore::ops diff --git a/src/infinicore/pybind11/ops/dot.hpp b/src/infinicore/pybind11/ops/dot.hpp new file mode 100644 index 000000000..a5bc98012 --- /dev/null +++ b/src/infinicore/pybind11/ops/dot.hpp @@ -0,0 +1,19 @@ +#pragma once + +#include + +#include "infinicore/ops/dot.hpp" + +namespace py = pybind11; + +namespace infinicore::ops { + +inline void bind_dot(py::module &m) { + m.def("dot", + &op::dot, + py::arg("input"), + py::arg("tensor"), + R"doc(Computes the dot product of two 1-D tensors.)doc"); +} + +} // namespace infinicore::ops diff --git a/src/infinicore/pybind11/ops/histc.hpp b/src/infinicore/pybind11/ops/histc.hpp new file mode 100644 index 000000000..44e83bb84 --- /dev/null +++ b/src/infinicore/pybind11/ops/histc.hpp @@ -0,0 +1,30 @@ +#pragma once + +#include + +#include "infinicore/ops/histc.hpp" + +namespace py = pybind11; + +namespace infinicore::ops { + +inline void bind_histc(py::module &m) { + m.def("histc", + &op::histc, + py::arg("input"), + py::arg("bins"), + py::arg("min"), + py::arg("max"), + R"doc(Computes the histogram of a tensor.)doc"); + + m.def("histc_", + &op::histc_, + py::arg("input"), + py::arg("output"), + py::arg("bins"), + py::arg("min"), + py::arg("max"), + R"doc(In-place histogram computation.)doc"); +} + +} // namespace infinicore::ops diff --git a/src/infinicore/pybind11/ops/log10.hpp b/src/infinicore/pybind11/ops/log10.hpp new file mode 100644 index 
000000000..55b1f1909 --- /dev/null +++ b/src/infinicore/pybind11/ops/log10.hpp @@ -0,0 +1,25 @@ +#pragma once + +#include + +#include "infinicore/ops/log10.hpp" + +namespace py = pybind11; + +namespace infinicore::ops { + +inline void bind_log10(py::module &m) { + m.def("log10", + &op::log10, + py::arg("input"), + R"doc(Logarithm base 10 of the tensor.)doc"); + + // op::log10_ is defined as (output, input), so the keyword names are listed in that order + m.def("log10_", + &op::log10_, + py::arg("output"), + py::arg("input"), + R"doc(In-place logarithm base 10 computation.)doc"); +} + +} // namespace infinicore::ops diff --git a/src/infinicore/pybind11/ops/log1p.hpp b/src/infinicore/pybind11/ops/log1p.hpp new file mode 100644 index 000000000..a4218134f --- /dev/null +++ b/src/infinicore/pybind11/ops/log1p.hpp @@ -0,0 +1,25 @@ +#pragma once + +#include + +#include "infinicore/ops/log1p.hpp" + +namespace py = pybind11; + +namespace infinicore::ops { + +inline void bind_log1p(py::module &m) { + m.def("log1p", + &op::log1p, + py::arg("input"), + R"doc(Returns a new tensor with the natural logarithm of (1 + input).)doc"); + + // op::log1p_ is defined as (output, input), so the keyword names are listed in that order + m.def("log1p_", + &op::log1p_, + py::arg("output"), + py::arg("input"), + R"doc(In-place computation of the natural logarithm of (1 + input).)doc"); +} + +} // namespace infinicore::ops diff --git a/test/infinicore/ops/avg_pool3d.py b/test/infinicore/ops/avg_pool3d.py index adb356227..e63e92a27 100644 --- a/test/infinicore/ops/avg_pool3d.py +++ b/test/infinicore/ops/avg_pool3d.py @@ -70,9 +70,9 @@ def get_test_cases(self): def torch_operator(self, *args, **kwargs): return torch.nn.functional.avg_pool3d(*args, **kwargs) - # def infinicore_operator(self, *args, **kwargs): - # """InfiniCore implementation (operator not yet available).""" - # return infinicore.nn.functional.avg_pool3d(*args, **kwargs) + def infinicore_operator(self, *args, **kwargs): + """InfiniCore implementation: forwards all arguments to infinicore.nn.functional.avg_pool3d.""" + return infinicore.nn.functional.avg_pool3d(*args, **kwargs) def main(): diff --git a/test/infinicore/ops/dot.py 
b/test/infinicore/ops/dot.py index 5d2300d24..00c7ccb78 100644 --- a/test/infinicore/ops/dot.py +++ b/test/infinicore/ops/dot.py @@ -64,9 +64,9 @@ def get_test_cases(self): def torch_operator(self, *args, **kwargs): return torch.dot(*args, **kwargs) - # def infinicore_operator(self, *args, **kwargs): - # """InfiniCore implementation (operator not yet available).""" - # return infinicore.dot(*args, **kwargs) + def infinicore_operator(self, *args, **kwargs): + """InfiniCore implementation: forwards all arguments to infinicore.dot.""" + return infinicore.dot(*args, **kwargs) def main(): diff --git a/test/infinicore/ops/histc.py b/test/infinicore/ops/histc.py index 26ba35986..93b4fa286 100644 --- a/test/infinicore/ops/histc.py +++ b/test/infinicore/ops/histc.py @@ -58,9 +58,9 @@ def get_test_cases(self): def torch_operator(self, *args, **kwargs): return torch.histc(*args, **kwargs) - # def infinicore_operator(self, *args, **kwargs): - # """InfiniCore implementation (operator not yet available).""" - # return infinicore.histc(*args, **kwargs) + def infinicore_operator(self, *args, **kwargs): + """InfiniCore implementation: forwards all arguments to infinicore.histc.""" + return infinicore.histc(*args, **kwargs) def main(): diff --git a/test/infinicore/ops/log10.py b/test/infinicore/ops/log10.py index fbb0863e3..05115dca0 100644 --- a/test/infinicore/ops/log10.py +++ b/test/infinicore/ops/log10.py @@ -87,9 +87,9 @@ def get_test_cases(self): def torch_operator(self, *args, **kwargs): return torch.log10(*args, **kwargs) - # def infinicore_operator(self, *args, **kwargs): - # """InfiniCore implementation (operator not yet available).""" - # return infinicore.log10(*args, **kwargs) + def infinicore_operator(self, *args, **kwargs): + """InfiniCore implementation: forwards all arguments to infinicore.log10.""" + return infinicore.log10(*args, **kwargs) def main(): diff --git a/test/infinicore/ops/log1p.py b/test/infinicore/ops/log1p.py index 32996927f..51e390cc3 100644 --- a/test/infinicore/ops/log1p.py +++ 
b/test/infinicore/ops/log1p.py @@ -87,9 +87,9 @@ def get_test_cases(self): def torch_operator(self, *args, **kwargs): return torch.log1p(*args, **kwargs) - # def infinicore_operator(self, *args, **kwargs): - # """InfiniCore implementation (operator not yet available).""" - # return infinicore.log1p(*args, **kwargs) + def infinicore_operator(self, *args, **kwargs): + """InfiniCore implementation: forwards all arguments to infinicore.log1p.""" + return infinicore.log1p(*args, **kwargs) def main():