From 832002919087ae7342e0826c3b4925753db3408f Mon Sep 17 00:00:00 2001
From: guozhihao-224
Date: Mon, 15 Dec 2025 19:43:12 +0800
Subject: [PATCH 01/17] feat: support logical_or, logical_xor, and logsigmoid
 CPU operators
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 include/infiniop.h | 3 +
 python/infinicore/__init__.py | 6 ++
 python/infinicore/nn/functional/__init__.py | 2 +
 src/infinicore/pybind11/ops.hpp | 6 ++
 test/infinicore/framework/utils.py | 91 +++++++++++++++------
 test/infinicore/ops/logical_or.py | 6 +-
 test/infinicore/ops/logical_xor.py | 6 +-
 test/infinicore/ops/logsigmoid.py | 6 +-
 8 files changed, 91 insertions(+), 35 deletions(-)

diff --git a/include/infiniop.h b/include/infiniop.h
index 92e6f5963..8eaeaed6d 100644
--- a/include/infiniop.h
+++ b/include/infiniop.h
@@ -11,6 +11,9 @@
 #include "infiniop/ops/gelu.h"
 #include "infiniop/ops/gemm.h"
 #include "infiniop/ops/layer_norm.h"
+#include "infiniop/ops/logical_or.h"
+#include "infiniop/ops/logical_xor.h"
+#include "infiniop/ops/logsigmoid.h"
 #include "infiniop/ops/logsoftmax.h"
 #include "infiniop/ops/lp_norm.h"
 #include "infiniop/ops/mul.h"
diff --git a/python/infinicore/__init__.py b/python/infinicore/__init__.py
index 5c541ec3c..ebae5827a 100644
--- a/python/infinicore/__init__.py
+++ b/python/infinicore/__init__.py
@@ -41,6 +41,9 @@
 )
 from infinicore.ops.add import add
 from infinicore.ops.attention import attention
+from infinicore.ops.logical_or import logical_or
+from infinicore.ops.logical_xor import logical_xor
+from infinicore.ops.logsigmoid import logsigmoid
 from infinicore.ops.matmul import matmul
 from infinicore.ops.mul import mul
 from infinicore.ops.narrow import narrow
@@ -101,6 +104,9 @@
     # Operations.
"add", "attention", + "logical_or", + "logical_xor", + "logsigmoid", "matmul", "mul", "narrow", diff --git a/python/infinicore/nn/functional/__init__.py b/python/infinicore/nn/functional/__init__.py index 255079790..5819a19e6 100644 --- a/python/infinicore/nn/functional/__init__.py +++ b/python/infinicore/nn/functional/__init__.py @@ -1,6 +1,7 @@ from .causal_softmax import causal_softmax from .embedding import embedding from .linear import linear +from .logsigmoid import logsigmoid from .random_sample import random_sample from .rms_norm import rms_norm from .rope import RopeAlgo, rope @@ -17,4 +18,5 @@ "embedding", "rope", "RopeAlgo", + "logsigmoid", ] diff --git a/src/infinicore/pybind11/ops.hpp b/src/infinicore/pybind11/ops.hpp index 978defa17..cda30f3c2 100644 --- a/src/infinicore/pybind11/ops.hpp +++ b/src/infinicore/pybind11/ops.hpp @@ -7,6 +7,9 @@ #include "ops/causal_softmax.hpp" #include "ops/embedding.hpp" #include "ops/linear.hpp" +#include "ops/logical_or.hpp" +#include "ops/logical_xor.hpp" +#include "ops/logsigmoid.hpp" #include "ops/matmul.hpp" #include "ops/mul.hpp" #include "ops/random_sample.hpp" @@ -26,6 +29,9 @@ inline void bind(py::module &m) { bind_causal_softmax(m); bind_random_sample(m); bind_linear(m); + bind_logical_or(m); + bind_logical_xor(m); + bind_logsigmoid(m); bind_matmul(m); bind_mul(m); bind_rearrange(m); diff --git a/test/infinicore/framework/utils.py b/test/infinicore/framework/utils.py index dc16cee81..564557348 100644 --- a/test/infinicore/framework/utils.py +++ b/test/infinicore/framework/utils.py @@ -26,14 +26,21 @@ def debug(actual, desired, atol=0, rtol=1e-2, equal_nan=False, verbose=True): elif actual.dtype == torch.bfloat16 or desired.dtype == torch.bfloat16: actual = actual.to(torch.float32) desired = desired.to(torch.float32) + # Note: bool tensors are handled inside print_discrepancy print_discrepancy(actual, desired, atol, rtol, equal_nan, verbose) import numpy as np - np.testing.assert_allclose( - actual.cpu(), desired.cpu(), rtol, atol, equal_nan, verbose=True - ) + # For bool tensors, use assert_equal instead of assert_allclose + if actual.dtype == torch.bool or desired.dtype == torch.bool: + np.testing.assert_equal( + actual.cpu().numpy(), desired.cpu().numpy() + ) + else: + np.testing.assert_allclose( + actual.cpu(), desired.cpu(), rtol, atol, equal_nan, verbose=True + ) def print_discrepancy( @@ -47,19 +54,40 @@ def print_discrepancy( import sys is_terminal = sys.stdout.isatty() + + # Handle bool tensors specially - PyTorch doesn't support subtraction for bool + is_bool = actual.dtype == torch.bool or expected.dtype == torch.bool + + if is_bool: + # For bool tensors, convert to int8 for comparison operations + actual_for_calc = actual.to(torch.int8) if actual.dtype == torch.bool else actual + expected_for_calc = expected.to(torch.int8) if expected.dtype == torch.bool else expected + else: + actual_for_calc = actual + expected_for_calc = expected - actual_isnan = torch.isnan(actual) - expected_isnan = torch.isnan(expected) + actual_isnan = torch.isnan(actual_for_calc) + expected_isnan = torch.isnan(expected_for_calc) # Calculate difference mask - nan_mismatch = ( - actual_isnan ^ expected_isnan if equal_nan else actual_isnan | expected_isnan - ) - diff_mask = nan_mismatch | ( - torch.abs(actual - expected) > (atol + rtol * torch.abs(expected)) - ) + if is_bool: + # For bool tensors, just check equality + diff_mask = actual != expected + else: + nan_mismatch = ( + actual_isnan ^ expected_isnan if equal_nan else actual_isnan | 
expected_isnan
+        )
+        diff_mask = nan_mismatch | (
+            torch.abs(actual_for_calc - expected_for_calc) > (atol + rtol * torch.abs(expected_for_calc))
+        )
+
     diff_indices = torch.nonzero(diff_mask, as_tuple=False)
-    delta = actual - expected
+
+    # Calculate delta (difference) - convert bool to int if needed
+    if is_bool:
+        delta = (actual.to(torch.int8) - expected.to(torch.int8))
+    else:
+        delta = actual_for_calc - expected_for_calc
 
     # Display formatting
     col_width = [18, 20, 20, 20]
@@ -75,11 +103,20 @@ def add_color(text, color_code):
     if verbose:
         for idx in diff_indices:
             index_tuple = tuple(idx.tolist())
-            actual_str = f"{actual[index_tuple]:<{col_width[1]}.{decimal_places[1]}f}"
-            expected_str = (
-                f"{expected[index_tuple]:<{col_width[2]}.{decimal_places[2]}f}"
-            )
-            delta_str = f"{delta[index_tuple]:<{col_width[3]}.{decimal_places[3]}f}"
+            if is_bool:
+                # For bool, display as True/False
+                actual_val = actual[index_tuple].item()
+                expected_val = expected[index_tuple].item()
+                actual_str = f"{str(actual_val):<{col_width[1]}}"
+                expected_str = f"{str(expected_val):<{col_width[2]}}"
+                delta_val = delta[index_tuple].item()
+                delta_str = f"{delta_val:<{col_width[3]}}"
+            else:
+                actual_str = f"{actual[index_tuple]:<{col_width[1]}.{decimal_places[1]}f}"
+                expected_str = (
+                    f"{expected[index_tuple]:<{col_width[2]}.{decimal_places[2]}f}"
+                )
+                delta_str = f"{delta[index_tuple]:<{col_width[3]}.{decimal_places[3]}f}"
             print(
                 f" > Index: {str(index_tuple):<{col_width[0]}}"
                 f"actual: {add_color(actual_str, 31)}"
@@ -95,15 +132,17 @@ def add_color(text, color_code):
     print(
         f" - Mismatched elements: {len(diff_indices)} / {actual.numel()} ({len(diff_indices) / actual.numel() * 100}%)"
     )
-    print(
-        f" - Min(actual) : {torch.min(actual):<{col_width[1]}} | Max(actual) : {torch.max(actual):<{col_width[2]}}"
-    )
-    print(
-        f" - Min(desired): {torch.min(expected):<{col_width[1]}} | Max(desired): {torch.max(expected):<{col_width[2]}}"
-    )
-    print(
-        f" - Min(delta) : {torch.min(delta):<{col_width[1]}} | Max(delta) : {torch.max(delta):<{col_width[2]}}"
-    )
+
+    if not is_bool:
+        print(
+            f" - Min(actual) : {torch.min(actual):<{col_width[1]}} | Max(actual) : {torch.max(actual):<{col_width[2]}}"
+        )
+        print(
+            f" - Min(desired): {torch.min(expected):<{col_width[1]}} | Max(desired): {torch.max(expected):<{col_width[2]}}"
+        )
+        print(
+            f" - Min(delta) : {torch.min(delta):<{col_width[1]}} | Max(delta) : {torch.max(delta):<{col_width[2]}}"
+        )
     print("-" * total_width)
 
     return diff_indices
diff --git a/test/infinicore/ops/logical_or.py b/test/infinicore/ops/logical_or.py
index a51b7384f..5fc39cc77 100644
--- a/test/infinicore/ops/logical_or.py
+++ b/test/infinicore/ops/logical_or.py
@@ -107,9 +107,9 @@ def get_test_cases(self):
     def torch_operator(self, *args, **kwargs):
         return torch.logical_or(*args, **kwargs)
 
-    # def infinicore_operator(self, *args, **kwargs):
-    #     """InfiniCore implementation (operator not yet available)."""
-    #     return infinicore.logical_or(*args, **kwargs)
+    def infinicore_operator(self, *args, **kwargs):
+        """InfiniCore implementation."""
+        return infinicore.logical_or(*args, **kwargs)
 
 
 def main():
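With the infinicore_operator hook enabled above, the framework runs the InfiniCore op against the torch baseline on the same inputs. A minimal standalone sketch of the call being exercised (assumptions: the extension module is built with the CPU backend enabled, and from_torch is the torch-to-infinicore conversion helper from infinicore.tensor, as used by python/infinicore/ops/where.py in the next patch):

    import torch
    import infinicore
    from infinicore.tensor import from_torch

    a = from_torch(torch.tensor([True, False, True]))
    b = from_torch(torch.tensor([False, False, True]))
    c = infinicore.logical_or(a, b)  # BOOL tensor holding [True, False, True]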
diff --git a/test/infinicore/ops/logical_xor.py b/test/infinicore/ops/logical_xor.py
index 4f82cc613..f6e2ac92d 100644
--- a/test/infinicore/ops/logical_xor.py
+++ b/test/infinicore/ops/logical_xor.py
@@ -107,9 +107,9 @@ def get_test_cases(self):
     def torch_operator(self, *args, **kwargs):
         return torch.logical_xor(*args, **kwargs)
 
-    # def infinicore_operator(self, *args, **kwargs):
-    #     """InfiniCore implementation (operator not yet available)."""
-    #     return infinicore.logical_xor(*args, **kwargs)
+    def infinicore_operator(self, *args, **kwargs):
+        """InfiniCore implementation."""
+        return infinicore.logical_xor(*args, **kwargs)
 
 
 def main():
diff --git a/test/infinicore/ops/logsigmoid.py b/test/infinicore/ops/logsigmoid.py
index 0c9131323..44212ec80 100644
--- a/test/infinicore/ops/logsigmoid.py
+++ b/test/infinicore/ops/logsigmoid.py
@@ -68,9 +68,9 @@ def get_test_cases(self):
     def torch_operator(self, *args, **kwargs):
         return torch.nn.functional.logsigmoid(*args, **kwargs)
 
-    # def infinicore_operator(self, *args, **kwargs):
-    #     """InfiniCore implementation (operator not yet available)."""
-    #     return infinicore.nn.functional.logsigmoid(*args, **kwargs)
+    def infinicore_operator(self, *args, **kwargs):
+        """InfiniCore implementation."""
+        return infinicore.nn.functional.logsigmoid(*args, **kwargs)
 
 
 def main():

From 5756b4b502b285ba4c63a4d7156a22ca8256fbbe Mon Sep 17 00:00:00 2001
From: guozhihao-224
Date: Mon, 15 Dec 2025 20:02:36 +0800
Subject: [PATCH 02/17] feat: support the where operator
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 include/infinicore/ops/logical_or.hpp | 17 +++
 include/infinicore/ops/logical_xor.hpp | 17 +++
 include/infinicore/ops/logsigmoid.hpp | 17 +++
 include/infinicore/ops/where.hpp | 20 +++
 include/infiniop.h | 1 +
 include/infiniop/ops/logical_or.h | 27 ++++
 include/infiniop/ops/logical_xor.h | 27 ++++
 include/infiniop/ops/logsigmoid.h | 25 ++++
 include/infiniop/ops/where.h | 36 ++++++
 python/infinicore/__init__.py | 2 +
 python/infinicore/nn/functional/logsigmoid.py | 13 ++
 python/infinicore/ops/logical_or.py | 12 ++
 python/infinicore/ops/logical_xor.py | 12 ++
 python/infinicore/ops/logsigmoid.py | 12 ++
 python/infinicore/ops/where.py | 52 ++++++++
 python/infinicore/utils.py | 4 +
 src/infinicore/ops/logical_or/logical_or.cc | 28 +++++
 .../ops/logical_or/logical_or_infiniop.cc | 53 ++++++++
 src/infinicore/ops/logical_xor/logical_xor.cc | 28 +++++
 .../ops/logical_xor/logical_xor_infiniop.cc | 53 ++++++++
 src/infinicore/ops/logsigmoid/logsigmoid.cc | 35 ++++++
 .../ops/logsigmoid/logsigmoid_infiniop.cc | 53 ++++++++
 src/infinicore/ops/where/where.cc | 40 ++++++
 src/infinicore/ops/where/where_infiniop.cc | 54 ++++++++
 src/infinicore/pybind11/ops.hpp | 2 +
 src/infinicore/pybind11/ops/logical_or.hpp | 27 ++++
 src/infinicore/pybind11/ops/logical_xor.hpp | 27 ++++
 src/infinicore/pybind11/ops/logsigmoid.hpp | 25 ++++
 src/infinicore/pybind11/ops/where.hpp | 30 +++++
 .../ops/logical_or/cpu/logical_or_cpu.cc | 116 ++++++++++++++++++
 .../ops/logical_or/cpu/logical_or_cpu.h | 25 ++++
 src/infiniop/ops/logical_or/operator.cc | 99 +++++++++++++++
 .../ops/logical_xor/cpu/logical_xor_cpu.cc | 116 ++++++++++++++++++
 .../ops/logical_xor/cpu/logical_xor_cpu.h | 25 ++++
 src/infiniop/ops/logical_xor/operator.cc | 99 +++++++++++++++
 .../ops/logsigmoid/cpu/logsigmoid_cpu.cc | 52 ++++++++
 .../ops/logsigmoid/cpu/logsigmoid_cpu.h | 21 ++++
 src/infiniop/ops/logsigmoid/operator.cc | 101 +++++++++++++++
 src/infiniop/ops/where/cpu/where_cpu.cc | 92 ++++++++++++++
 src/infiniop/ops/where/cpu/where_cpu.h | 33 +++++
 src/infiniop/ops/where/operator.cc | 106 ++++++++++++++++
 test/infinicore/ops/where.py | 6 +-
 42 files changed, 1637 insertions(+), 3 deletions(-)
 create mode 100644 include/infinicore/ops/logical_or.hpp
 create mode 100644
include/infinicore/ops/logical_xor.hpp create mode 100644 include/infinicore/ops/logsigmoid.hpp create mode 100644 include/infinicore/ops/where.hpp create mode 100644 include/infiniop/ops/logical_or.h create mode 100644 include/infiniop/ops/logical_xor.h create mode 100644 include/infiniop/ops/logsigmoid.h create mode 100644 include/infiniop/ops/where.h create mode 100644 python/infinicore/nn/functional/logsigmoid.py create mode 100644 python/infinicore/ops/logical_or.py create mode 100644 python/infinicore/ops/logical_xor.py create mode 100644 python/infinicore/ops/logsigmoid.py create mode 100644 python/infinicore/ops/where.py create mode 100644 src/infinicore/ops/logical_or/logical_or.cc create mode 100644 src/infinicore/ops/logical_or/logical_or_infiniop.cc create mode 100644 src/infinicore/ops/logical_xor/logical_xor.cc create mode 100644 src/infinicore/ops/logical_xor/logical_xor_infiniop.cc create mode 100644 src/infinicore/ops/logsigmoid/logsigmoid.cc create mode 100644 src/infinicore/ops/logsigmoid/logsigmoid_infiniop.cc create mode 100644 src/infinicore/ops/where/where.cc create mode 100644 src/infinicore/ops/where/where_infiniop.cc create mode 100644 src/infinicore/pybind11/ops/logical_or.hpp create mode 100644 src/infinicore/pybind11/ops/logical_xor.hpp create mode 100644 src/infinicore/pybind11/ops/logsigmoid.hpp create mode 100644 src/infinicore/pybind11/ops/where.hpp create mode 100644 src/infiniop/ops/logical_or/cpu/logical_or_cpu.cc create mode 100644 src/infiniop/ops/logical_or/cpu/logical_or_cpu.h create mode 100644 src/infiniop/ops/logical_or/operator.cc create mode 100644 src/infiniop/ops/logical_xor/cpu/logical_xor_cpu.cc create mode 100644 src/infiniop/ops/logical_xor/cpu/logical_xor_cpu.h create mode 100644 src/infiniop/ops/logical_xor/operator.cc create mode 100644 src/infiniop/ops/logsigmoid/cpu/logsigmoid_cpu.cc create mode 100644 src/infiniop/ops/logsigmoid/cpu/logsigmoid_cpu.h create mode 100644 src/infiniop/ops/logsigmoid/operator.cc create mode 100644 src/infiniop/ops/where/cpu/where_cpu.cc create mode 100644 src/infiniop/ops/where/cpu/where_cpu.h create mode 100644 src/infiniop/ops/where/operator.cc diff --git a/include/infinicore/ops/logical_or.hpp b/include/infinicore/ops/logical_or.hpp new file mode 100644 index 000000000..6342aeda3 --- /dev/null +++ b/include/infinicore/ops/logical_or.hpp @@ -0,0 +1,17 @@ +#pragma once + +#include "../device.hpp" +#include "common/op.hpp" + +namespace infinicore::op { +class LogicalOr { +public: + using schema = void (*)(Tensor, Tensor, Tensor); + static void execute(Tensor c, Tensor a, Tensor b); + static common::OpDispatcher &dispatcher(); +}; + +Tensor logical_or(Tensor a, Tensor b); +void logical_or_(Tensor c, Tensor a, Tensor b); +} // namespace infinicore::op + diff --git a/include/infinicore/ops/logical_xor.hpp b/include/infinicore/ops/logical_xor.hpp new file mode 100644 index 000000000..e9624bdd6 --- /dev/null +++ b/include/infinicore/ops/logical_xor.hpp @@ -0,0 +1,17 @@ +#pragma once + +#include "../device.hpp" +#include "common/op.hpp" + +namespace infinicore::op { +class LogicalXor { +public: + using schema = void (*)(Tensor, Tensor, Tensor); + static void execute(Tensor c, Tensor a, Tensor b); + static common::OpDispatcher &dispatcher(); +}; + +Tensor logical_xor(Tensor a, Tensor b); +void logical_xor_(Tensor c, Tensor a, Tensor b); +} // namespace infinicore::op + diff --git a/include/infinicore/ops/logsigmoid.hpp b/include/infinicore/ops/logsigmoid.hpp new file mode 100644 index 000000000..ba7af1699 --- 
/dev/null +++ b/include/infinicore/ops/logsigmoid.hpp @@ -0,0 +1,17 @@ +#pragma once + +#include "../device.hpp" +#include "common/op.hpp" + +namespace infinicore::op { +class LogSigmoid { +public: + using schema = void (*)(Tensor, Tensor); + static void execute(Tensor output, Tensor input); + static common::OpDispatcher &dispatcher(); +}; + +Tensor logsigmoid(Tensor input); +void logsigmoid_(Tensor output, Tensor input); +} // namespace infinicore::op + diff --git a/include/infinicore/ops/where.hpp b/include/infinicore/ops/where.hpp new file mode 100644 index 000000000..72f1b8506 --- /dev/null +++ b/include/infinicore/ops/where.hpp @@ -0,0 +1,20 @@ +#pragma once + +#include "../device.hpp" +#include "common/op.hpp" + +namespace infinicore::op { + +class Where { +public: + using schema = void (*)(Tensor, Tensor, Tensor, Tensor); + static void execute(Tensor out, Tensor cond, Tensor x, Tensor y); + static common::OpDispatcher &dispatcher(); +}; + +Tensor where(Tensor cond, Tensor x, Tensor y); +void where_(Tensor out, Tensor cond, Tensor x, Tensor y); + +} // namespace infinicore::op + + diff --git a/include/infiniop.h b/include/infiniop.h index 8eaeaed6d..66f15df70 100644 --- a/include/infiniop.h +++ b/include/infiniop.h @@ -32,6 +32,7 @@ #include "infiniop/ops/tanh.h" #include "infiniop/ops/topkrouter.h" #include "infiniop/ops/topksoftmax.h" +#include "infiniop/ops/where.h" #include "infiniop/ops/zeros.h" #include "infiniop/tensor_descriptor.h" diff --git a/include/infiniop/ops/logical_or.h b/include/infiniop/ops/logical_or.h new file mode 100644 index 000000000..0c5351f63 --- /dev/null +++ b/include/infiniop/ops/logical_or.h @@ -0,0 +1,27 @@ +#ifndef __INFINIOP_LOGICAL_OR_API_H__ +#define __INFINIOP_LOGICAL_OR_API_H__ + +#include "../operator_descriptor.h" + +typedef struct InfiniopDescriptor *infiniopLogicalOrDescriptor_t; + +__C __export infiniStatus_t infiniopCreateLogicalOrDescriptor(infiniopHandle_t handle, + infiniopLogicalOrDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t c_desc, + infiniopTensorDescriptor_t a_desc, + infiniopTensorDescriptor_t b_desc); + +__C __export infiniStatus_t infiniopGetLogicalOrWorkspaceSize(infiniopLogicalOrDescriptor_t desc, size_t *size); + +__C __export infiniStatus_t infiniopLogicalOr(infiniopLogicalOrDescriptor_t desc, + void *workspace, + size_t workspace_size, + void *c, + const void *a, + const void *b, + void *stream); + +__C __export infiniStatus_t infiniopDestroyLogicalOrDescriptor(infiniopLogicalOrDescriptor_t desc); + +#endif + diff --git a/include/infiniop/ops/logical_xor.h b/include/infiniop/ops/logical_xor.h new file mode 100644 index 000000000..fe6f9ba68 --- /dev/null +++ b/include/infiniop/ops/logical_xor.h @@ -0,0 +1,27 @@ +#ifndef __INFINIOP_LOGICAL_XOR_API_H__ +#define __INFINIOP_LOGICAL_XOR_API_H__ + +#include "../operator_descriptor.h" + +typedef struct InfiniopDescriptor *infiniopLogicalXorDescriptor_t; + +__C __export infiniStatus_t infiniopCreateLogicalXorDescriptor(infiniopHandle_t handle, + infiniopLogicalXorDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t c_desc, + infiniopTensorDescriptor_t a_desc, + infiniopTensorDescriptor_t b_desc); + +__C __export infiniStatus_t infiniopGetLogicalXorWorkspaceSize(infiniopLogicalXorDescriptor_t desc, size_t *size); + +__C __export infiniStatus_t infiniopLogicalXor(infiniopLogicalXorDescriptor_t desc, + void *workspace, + size_t workspace_size, + void *c, + const void *a, + const void *b, + void *stream); + +__C __export infiniStatus_t 
infiniopDestroyLogicalXorDescriptor(infiniopLogicalXorDescriptor_t desc);
+
+#endif
+
diff --git a/include/infiniop/ops/logsigmoid.h b/include/infiniop/ops/logsigmoid.h
new file mode 100644
index 000000000..5037bc67e
--- /dev/null
+++ b/include/infiniop/ops/logsigmoid.h
@@ -0,0 +1,25 @@
+#ifndef __INFINIOP_LOGSIGMOID_API_H__
+#define __INFINIOP_LOGSIGMOID_API_H__
+
+#include "../operator_descriptor.h"
+
+typedef struct InfiniopDescriptor *infiniopLogSigmoidDescriptor_t;
+
+__C __export infiniStatus_t infiniopCreateLogSigmoidDescriptor(infiniopHandle_t handle,
+                                                               infiniopLogSigmoidDescriptor_t *desc_ptr,
+                                                               infiniopTensorDescriptor_t y,
+                                                               infiniopTensorDescriptor_t x);
+
+__C __export infiniStatus_t infiniopGetLogSigmoidWorkspaceSize(infiniopLogSigmoidDescriptor_t desc, size_t *size);
+
+__C __export infiniStatus_t infiniopLogSigmoid(infiniopLogSigmoidDescriptor_t desc,
+                                               void *workspace,
+                                               size_t workspace_size,
+                                               void *y,
+                                               const void *x,
+                                               void *stream);
+
+__C __export infiniStatus_t infiniopDestroyLogSigmoidDescriptor(infiniopLogSigmoidDescriptor_t desc);
+
+#endif
+
diff --git a/include/infiniop/ops/where.h b/include/infiniop/ops/where.h
new file mode 100644
index 000000000..95f8dd6f7
--- /dev/null
+++ b/include/infiniop/ops/where.h
@@ -0,0 +1,36 @@
+#ifndef __INFINIOP_WHERE_API_H__
+#define __INFINIOP_WHERE_API_H__
+
+#include "../operator_descriptor.h"
+
+typedef struct InfiniopDescriptor *infiniopWhereDescriptor_t;
+
+// out = where(cond, x, y)
+__C __export infiniStatus_t infiniopCreateWhereDescriptor(
+    infiniopHandle_t handle,
+    infiniopWhereDescriptor_t *desc_ptr,
+    infiniopTensorDescriptor_t out_desc,
+    infiniopTensorDescriptor_t cond_desc,
+    infiniopTensorDescriptor_t x_desc,
+    infiniopTensorDescriptor_t y_desc);
+
+__C __export infiniStatus_t infiniopGetWhereWorkspaceSize(
+    infiniopWhereDescriptor_t desc,
+    size_t *size);
+
+__C __export infiniStatus_t infiniopWhere(
+    infiniopWhereDescriptor_t desc,
+    void *workspace,
+    size_t workspace_size,
+    void *out,
+    const void *cond,
+    const void *x,
+    const void *y,
+    void *stream);
+
+__C __export infiniStatus_t infiniopDestroyWhereDescriptor(
+    infiniopWhereDescriptor_t desc);
+
+#endif
+
+
diff --git a/python/infinicore/__init__.py b/python/infinicore/__init__.py
index ebae5827a..dcf8f4af3 100644
--- a/python/infinicore/__init__.py
+++ b/python/infinicore/__init__.py
@@ -44,6 +44,7 @@
 from infinicore.ops.logical_or import logical_or
 from infinicore.ops.logical_xor import logical_xor
 from infinicore.ops.logsigmoid import logsigmoid
+from infinicore.ops.where import where
 from infinicore.ops.matmul import matmul
 from infinicore.ops.mul import mul
 from infinicore.ops.narrow import narrow
@@ -107,6 +108,9 @@
     "logical_or",
     "logical_xor",
     "logsigmoid",
+    "where",
     "matmul",
     "mul",
     "narrow",
diff --git a/python/infinicore/nn/functional/logsigmoid.py b/python/infinicore/nn/functional/logsigmoid.py
new file mode 100644
index 000000000..bb327f5b6
--- /dev/null
+++ b/python/infinicore/nn/functional/logsigmoid.py
@@ -0,0 +1,13 @@
+from infinicore.lib import _infinicore
+from infinicore.tensor import Tensor
+
+
+def logsigmoid(input: Tensor, out=None) -> Tensor:
+    """Apply elementwise log-sigmoid."""
+    if out is None:
+        return Tensor(_infinicore.logsigmoid(input._underlying))
+
+    _infinicore.logsigmoid_(out._underlying, input._underlying)
+    return out
+
+
diff --git a/python/infinicore/ops/logical_or.py b/python/infinicore/ops/logical_or.py
new file mode 100644
index 000000000..5016046a2
--- /dev/null
+++ b/python/infinicore/ops/logical_or.py
@@
-0,0 +1,12 @@ +from infinicore.lib import _infinicore +from infinicore.tensor import Tensor + + +def logical_or(input, other, *, out=None): + if out is None: + return Tensor(_infinicore.logical_or(input._underlying, other._underlying)) + + _infinicore.logical_or_(out._underlying, input._underlying, other._underlying) + + return out + diff --git a/python/infinicore/ops/logical_xor.py b/python/infinicore/ops/logical_xor.py new file mode 100644 index 000000000..c4029f6bb --- /dev/null +++ b/python/infinicore/ops/logical_xor.py @@ -0,0 +1,12 @@ +from infinicore.lib import _infinicore +from infinicore.tensor import Tensor + + +def logical_xor(input, other, *, out=None): + if out is None: + return Tensor(_infinicore.logical_xor(input._underlying, other._underlying)) + + _infinicore.logical_xor_(out._underlying, input._underlying, other._underlying) + + return out + diff --git a/python/infinicore/ops/logsigmoid.py b/python/infinicore/ops/logsigmoid.py new file mode 100644 index 000000000..128ced942 --- /dev/null +++ b/python/infinicore/ops/logsigmoid.py @@ -0,0 +1,12 @@ +from infinicore.lib import _infinicore +from infinicore.tensor import Tensor + + +def logsigmoid(input, *, out=None): + if out is None: + return Tensor(_infinicore.logsigmoid(input._underlying)) + + _infinicore.logsigmoid_(out._underlying, input._underlying) + + return out + diff --git a/python/infinicore/ops/where.py b/python/infinicore/ops/where.py new file mode 100644 index 000000000..18650bd98 --- /dev/null +++ b/python/infinicore/ops/where.py @@ -0,0 +1,52 @@ +from infinicore.lib import _infinicore +from infinicore.tensor import Tensor, from_torch + +import torch + + +def where(*args, out=None): + """Elementwise where(cond, x, y) selection. + + Two supported call patterns: + - where(cond, x, y) -> Tensor + - where(cond, x, y, out=...) -> out + + The condition-only variant where(cond) returning indices is implemented + by delegating to the underlying Torch tensor stored in cond._torch_ref. + """ + # condition-only mode: where(cond) -> indices tuple + if len(args) == 1: + cond = args[0] + + # Prefer using the original Torch tensor reference when available + cond_torch = getattr(cond, "_torch_ref", None) + if cond_torch is None: + # Fallback: create a Torch tensor, then copy data from infinicore tensor. + # Tests use CPU bool tensors for condition-only where. + cond_torch = torch.zeros( + cond.shape, + dtype=torch.bool, + device="cpu", + ) + # Share storage between Torch tensor and an infinicore view, then copy. + ic_view = from_torch(cond_torch) + ic_view.copy_(cond) + + idx_tensors = torch.where(cond_torch) + # torch.where(cond) returns a tuple of index tensors; mirror that with + # infinicore tensors sharing the same underlying storage. 
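+        # For example (hypothetical input): cond = [[True, False], [False, True]]
+        # gives idx_tensors == (tensor([0, 1]), tensor([0, 1])), one index
+        # tensor per dimension; each is wrapped below without copying storage.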
+ return tuple(from_torch(t) for t in idx_tensors) + + if len(args) != 3: + raise TypeError("infinicore.where expects (cond, x, y)") + + cond, x, y = args + + if out is None: + return Tensor(_infinicore.where(cond._underlying, x._underlying, y._underlying)) + + _infinicore.where_(out._underlying, cond._underlying, x._underlying, y._underlying) + return out + + + diff --git a/python/infinicore/utils.py b/python/infinicore/utils.py index 094b2230e..4e6d3343c 100644 --- a/python/infinicore/utils.py +++ b/python/infinicore/utils.py @@ -23,6 +23,8 @@ def to_torch_dtype(infini_dtype): return torch.int64 elif infini_dtype == infinicore.uint8: return torch.uint8 + elif infini_dtype == infinicore.bool: + return torch.bool else: raise ValueError(f"Unsupported infinicore dtype: {infini_dtype}") @@ -45,6 +47,8 @@ def to_infinicore_dtype(torch_dtype): return infinicore.int64 elif torch_dtype == torch.uint8: return infinicore.uint8 + elif torch_dtype == torch.bool: + return infinicore.bool else: raise ValueError(f"Unsupported torch dtype: {torch_dtype}") diff --git a/src/infinicore/ops/logical_or/logical_or.cc b/src/infinicore/ops/logical_or/logical_or.cc new file mode 100644 index 000000000..96df42470 --- /dev/null +++ b/src/infinicore/ops/logical_or/logical_or.cc @@ -0,0 +1,28 @@ +#include "infinicore/ops/logical_or.hpp" +#include "../../utils.hpp" +#include "infinicore/dtype.hpp" + +namespace infinicore::op { + +common::OpDispatcher &LogicalOr::dispatcher() { + static common::OpDispatcher dispatcher_; + return dispatcher_; +} + +void LogicalOr::execute(Tensor c, Tensor a, Tensor b) { + INFINICORE_ASSERT_TENSORS_SAME_DEVICE(c, a, b); + infinicore::context::setDevice(c->device()); + dispatcher().lookup(c->device().getType())(c, a, b); +} + +Tensor logical_or(Tensor a, Tensor b) { + auto c = Tensor::empty(a->shape(), DataType::BOOL, a->device()); + logical_or_(c, a, b); + return c; +} + +void logical_or_(Tensor c, Tensor a, Tensor b) { + LogicalOr::execute(c, a, b); +} +} // namespace infinicore::op + diff --git a/src/infinicore/ops/logical_or/logical_or_infiniop.cc b/src/infinicore/ops/logical_or/logical_or_infiniop.cc new file mode 100644 index 000000000..66118968d --- /dev/null +++ b/src/infinicore/ops/logical_or/logical_or_infiniop.cc @@ -0,0 +1,53 @@ +#include "../../utils.hpp" +#include "infinicore/common/hash.hpp" +#include "infinicore/ops/common/cache.hpp" +#include "infinicore/ops/logical_or.hpp" +#include + +namespace infinicore::op::logical_or_impl::infiniop { + +thread_local common::OpCache caches( + 100, + [](infiniopLogicalOrDescriptor_t &desc) { + if (desc != nullptr) { + INFINICORE_CHECK_ERROR(infiniopDestroyLogicalOrDescriptor(desc)); + desc = nullptr; + } + }); + +void calculate(Tensor c, Tensor a, Tensor b) { + size_t seed = hash_combine(c, a, b); + + auto device_type = context::getDevice().getType(); + auto device_index = context::getDevice().getIndex(); + + auto &cache = caches.getCache(device_type, device_index); + + auto desc_opt = cache.get(seed); + infiniopLogicalOrDescriptor_t desc = nullptr; + + if (!desc_opt) { + INFINICORE_CHECK_ERROR(infiniopCreateLogicalOrDescriptor( + context::getInfiniopHandle(c->device()), &desc, + c->desc(), a->desc(), b->desc())); + cache.put(seed, desc); + } else { + desc = *desc_opt; + } + + size_t workspace_size = 0; + INFINICORE_CHECK_ERROR(infiniopGetLogicalOrWorkspaceSize(desc, &workspace_size)); + std::shared_ptr workspace = context::allocateMemory(workspace_size); + + INFINICORE_CHECK_ERROR(infiniopLogicalOr( + desc, workspace->data(), 
workspace_size, + c->data(), a->data(), b->data(), context::getStream())); +} + +static bool registered = []() { + LogicalOr::dispatcher().registerAll(&calculate, false); + return true; +}(); + +} // namespace infinicore::op::logical_or_impl::infiniop + diff --git a/src/infinicore/ops/logical_xor/logical_xor.cc b/src/infinicore/ops/logical_xor/logical_xor.cc new file mode 100644 index 000000000..bd893f1f6 --- /dev/null +++ b/src/infinicore/ops/logical_xor/logical_xor.cc @@ -0,0 +1,28 @@ +#include "infinicore/ops/logical_xor.hpp" +#include "../../utils.hpp" +#include "infinicore/dtype.hpp" + +namespace infinicore::op { + +common::OpDispatcher &LogicalXor::dispatcher() { + static common::OpDispatcher dispatcher_; + return dispatcher_; +} + +void LogicalXor::execute(Tensor c, Tensor a, Tensor b) { + INFINICORE_ASSERT_TENSORS_SAME_DEVICE(c, a, b); + infinicore::context::setDevice(c->device()); + dispatcher().lookup(c->device().getType())(c, a, b); +} + +Tensor logical_xor(Tensor a, Tensor b) { + auto c = Tensor::empty(a->shape(), DataType::BOOL, a->device()); + logical_xor_(c, a, b); + return c; +} + +void logical_xor_(Tensor c, Tensor a, Tensor b) { + LogicalXor::execute(c, a, b); +} +} // namespace infinicore::op + diff --git a/src/infinicore/ops/logical_xor/logical_xor_infiniop.cc b/src/infinicore/ops/logical_xor/logical_xor_infiniop.cc new file mode 100644 index 000000000..76226cbbc --- /dev/null +++ b/src/infinicore/ops/logical_xor/logical_xor_infiniop.cc @@ -0,0 +1,53 @@ +#include "../../utils.hpp" +#include "infinicore/common/hash.hpp" +#include "infinicore/ops/common/cache.hpp" +#include "infinicore/ops/logical_xor.hpp" +#include + +namespace infinicore::op::logical_xor_impl::infiniop { + +thread_local common::OpCache caches( + 100, + [](infiniopLogicalXorDescriptor_t &desc) { + if (desc != nullptr) { + INFINICORE_CHECK_ERROR(infiniopDestroyLogicalXorDescriptor(desc)); + desc = nullptr; + } + }); + +void calculate(Tensor c, Tensor a, Tensor b) { + size_t seed = hash_combine(c, a, b); + + auto device_type = context::getDevice().getType(); + auto device_index = context::getDevice().getIndex(); + + auto &cache = caches.getCache(device_type, device_index); + + auto desc_opt = cache.get(seed); + infiniopLogicalXorDescriptor_t desc = nullptr; + + if (!desc_opt) { + INFINICORE_CHECK_ERROR(infiniopCreateLogicalXorDescriptor( + context::getInfiniopHandle(c->device()), &desc, + c->desc(), a->desc(), b->desc())); + cache.put(seed, desc); + } else { + desc = *desc_opt; + } + + size_t workspace_size = 0; + INFINICORE_CHECK_ERROR(infiniopGetLogicalXorWorkspaceSize(desc, &workspace_size)); + std::shared_ptr workspace = context::allocateMemory(workspace_size); + + INFINICORE_CHECK_ERROR(infiniopLogicalXor( + desc, workspace->data(), workspace_size, + c->data(), a->data(), b->data(), context::getStream())); +} + +static bool registered = []() { + LogicalXor::dispatcher().registerAll(&calculate, false); + return true; +}(); + +} // namespace infinicore::op::logical_xor_impl::infiniop + diff --git a/src/infinicore/ops/logsigmoid/logsigmoid.cc b/src/infinicore/ops/logsigmoid/logsigmoid.cc new file mode 100644 index 000000000..1ed7da9f6 --- /dev/null +++ b/src/infinicore/ops/logsigmoid/logsigmoid.cc @@ -0,0 +1,35 @@ +#include "infinicore/ops/logsigmoid.hpp" +#include "../../utils.hpp" + +namespace infinicore::op { + +common::OpDispatcher &LogSigmoid::dispatcher() { + static common::OpDispatcher dispatcher_; + return dispatcher_; +} + +void LogSigmoid::execute(Tensor output, Tensor input) { + 
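+    // Validate that all tensors share a device, bind the current context to
+    // the output's device, then dispatch to the kernel registered for that
+    // device type; lookup() returns nullptr if none was registered.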
INFINICORE_ASSERT_TENSORS_SAME_DEVICE(output, input); + infinicore::context::setDevice(output->device()); + auto device_type = output->device().getType(); + auto func = dispatcher().lookup(device_type); + + if (func == nullptr) { + throw std::runtime_error("No LogSigmoid implementation found for device type: " + std::to_string(static_cast(device_type))); + } + + func(output, input); +} + +Tensor logsigmoid(Tensor input) { + Shape shape = input->shape(); + auto output = Tensor::empty(shape, input->dtype(), input->device()); + logsigmoid_(output, input); + return output; +} + +void logsigmoid_(Tensor output, Tensor input) { + LogSigmoid::execute(output, input); +} +} // namespace infinicore::op + diff --git a/src/infinicore/ops/logsigmoid/logsigmoid_infiniop.cc b/src/infinicore/ops/logsigmoid/logsigmoid_infiniop.cc new file mode 100644 index 000000000..e81a880a0 --- /dev/null +++ b/src/infinicore/ops/logsigmoid/logsigmoid_infiniop.cc @@ -0,0 +1,53 @@ +#include "../../utils.hpp" +#include "infinicore/common/hash.hpp" +#include "infinicore/ops/common/cache.hpp" +#include "infinicore/ops/logsigmoid.hpp" +#include + +namespace infinicore::op::logsigmoid_impl::infiniop { + +thread_local common::OpCache caches( + 100, // capacity + [](infiniopLogSigmoidDescriptor_t &desc) { + if (desc != nullptr) { + INFINICORE_CHECK_ERROR(infiniopDestroyLogSigmoidDescriptor(desc)); + desc = nullptr; + } + }); + +void calculate(Tensor output, Tensor input) { + size_t seed = hash_combine(output, input); + + auto device_type = context::getDevice().getType(); + auto device_index = context::getDevice().getIndex(); + + auto &cache = caches.getCache(device_type, device_index); + + auto desc_opt = cache.get(seed); + infiniopLogSigmoidDescriptor_t desc = nullptr; + + if (!desc_opt) { + INFINICORE_CHECK_ERROR(infiniopCreateLogSigmoidDescriptor( + context::getInfiniopHandle(output->device()), &desc, + output->desc(), input->desc())); + cache.put(seed, desc); + } else { + desc = *desc_opt; + } + + size_t workspace_size = 0; + INFINICORE_CHECK_ERROR(infiniopGetLogSigmoidWorkspaceSize(desc, &workspace_size)); + std::shared_ptr workspace = context::allocateMemory(workspace_size); + + INFINICORE_CHECK_ERROR(infiniopLogSigmoid( + desc, workspace->data(), workspace_size, + output->data(), input->data(), context::getStream())); +} + +static bool registered = []() { + LogSigmoid::dispatcher().registerAll(&calculate, false); + return true; +}(); + +} // namespace infinicore::op::logsigmoid_impl::infiniop + diff --git a/src/infinicore/ops/where/where.cc b/src/infinicore/ops/where/where.cc new file mode 100644 index 000000000..6eabc9e1b --- /dev/null +++ b/src/infinicore/ops/where/where.cc @@ -0,0 +1,40 @@ +#include "infinicore/ops/where.hpp" +#include "../../utils.hpp" + +namespace infinicore::op { + +common::OpDispatcher &Where::dispatcher() { + static common::OpDispatcher dispatcher_; + return dispatcher_; +} + +void Where::execute(Tensor out, Tensor cond, Tensor x, Tensor y) { + INFINICORE_ASSERT_TENSORS_SAME_DEVICE(out, cond, x, y); + infinicore::context::setDevice(out->device()); + auto device_type = out->device().getType(); + auto func = dispatcher().lookup(device_type); + + if (func == nullptr) { + throw std::runtime_error( + "No Where implementation found for device type: " + + std::to_string(static_cast(device_type))); + } + + func(out, cond, x, y); +} + +Tensor where(Tensor cond, Tensor x, Tensor y) { + INFINICORE_ASSERT_TENSORS_SAME_DEVICE(cond, x, y); + // Output dtype follows x/y dtype + auto out = 
Tensor::empty(x->shape(), x->dtype(), x->device()); + where_(out, cond, x, y); + return out; +} + +void where_(Tensor out, Tensor cond, Tensor x, Tensor y) { + Where::execute(out, cond, x, y); +} + +} // namespace infinicore::op + + diff --git a/src/infinicore/ops/where/where_infiniop.cc b/src/infinicore/ops/where/where_infiniop.cc new file mode 100644 index 000000000..e8964acf4 --- /dev/null +++ b/src/infinicore/ops/where/where_infiniop.cc @@ -0,0 +1,54 @@ +#include "../../utils.hpp" +#include "infinicore/common/hash.hpp" +#include "infinicore/ops/common/cache.hpp" +#include "infinicore/ops/where.hpp" +#include + +namespace infinicore::op::where_impl::infiniop { + +thread_local common::OpCache caches( + 100, + [](infiniopWhereDescriptor_t &desc) { + if (desc != nullptr) { + INFINICORE_CHECK_ERROR(infiniopDestroyWhereDescriptor(desc)); + desc = nullptr; + } + }); + +void calculate(Tensor out, Tensor cond, Tensor x, Tensor y) { + size_t seed = hash_combine(out, cond, x, y); + + auto device_type = context::getDevice().getType(); + auto device_index = context::getDevice().getIndex(); + + auto &cache = caches.getCache(device_type, device_index); + + auto desc_opt = cache.get(seed); + infiniopWhereDescriptor_t desc = nullptr; + + if (!desc_opt) { + INFINICORE_CHECK_ERROR(infiniopCreateWhereDescriptor( + context::getInfiniopHandle(out->device()), &desc, + out->desc(), cond->desc(), x->desc(), y->desc())); + cache.put(seed, desc); + } else { + desc = *desc_opt; + } + + size_t workspace_size = 0; + INFINICORE_CHECK_ERROR(infiniopGetWhereWorkspaceSize(desc, &workspace_size)); + std::shared_ptr workspace = context::allocateMemory(workspace_size); + + INFINICORE_CHECK_ERROR(infiniopWhere( + desc, workspace->data(), workspace_size, + out->data(), cond->data(), x->data(), y->data(), context::getStream())); +} + +static bool registered = []() { + Where::dispatcher().registerAll(&calculate, false); + return true; +}(); + +} // namespace infinicore::op::where_impl::infiniop + + diff --git a/src/infinicore/pybind11/ops.hpp b/src/infinicore/pybind11/ops.hpp index cda30f3c2..f8373b37f 100644 --- a/src/infinicore/pybind11/ops.hpp +++ b/src/infinicore/pybind11/ops.hpp @@ -17,6 +17,7 @@ #include "ops/rms_norm.hpp" #include "ops/rope.hpp" #include "ops/silu.hpp" +#include "ops/where.hpp" #include "ops/swiglu.hpp" namespace py = pybind11; @@ -40,6 +41,7 @@ inline void bind(py::module &m) { bind_swiglu(m); bind_rope(m); bind_embedding(m); + bind_where(m); } } // namespace infinicore::ops diff --git a/src/infinicore/pybind11/ops/logical_or.hpp b/src/infinicore/pybind11/ops/logical_or.hpp new file mode 100644 index 000000000..fef213e30 --- /dev/null +++ b/src/infinicore/pybind11/ops/logical_or.hpp @@ -0,0 +1,27 @@ +#pragma once + +#include + +#include "infinicore/ops/logical_or.hpp" + +namespace py = pybind11; + +namespace infinicore::ops { + +inline void bind_logical_or(py::module &m) { + m.def("logical_or", + &op::logical_or, + py::arg("a"), + py::arg("b"), + R"doc(Logical OR of two tensors.)doc"); + + m.def("logical_or_", + &op::logical_or_, + py::arg("c"), + py::arg("a"), + py::arg("b"), + R"doc(In-place logical OR of two tensors.)doc"); +} + +} // namespace infinicore::ops + diff --git a/src/infinicore/pybind11/ops/logical_xor.hpp b/src/infinicore/pybind11/ops/logical_xor.hpp new file mode 100644 index 000000000..53ef5361b --- /dev/null +++ b/src/infinicore/pybind11/ops/logical_xor.hpp @@ -0,0 +1,27 @@ +#pragma once + +#include + +#include "infinicore/ops/logical_xor.hpp" + +namespace py = pybind11; + 
+namespace infinicore::ops { + +inline void bind_logical_xor(py::module &m) { + m.def("logical_xor", + &op::logical_xor, + py::arg("a"), + py::arg("b"), + R"doc(Logical XOR of two tensors.)doc"); + + m.def("logical_xor_", + &op::logical_xor_, + py::arg("c"), + py::arg("a"), + py::arg("b"), + R"doc(In-place logical XOR of two tensors.)doc"); +} + +} // namespace infinicore::ops + diff --git a/src/infinicore/pybind11/ops/logsigmoid.hpp b/src/infinicore/pybind11/ops/logsigmoid.hpp new file mode 100644 index 000000000..7a56de01d --- /dev/null +++ b/src/infinicore/pybind11/ops/logsigmoid.hpp @@ -0,0 +1,25 @@ +#pragma once + +#include + +#include "infinicore/ops/logsigmoid.hpp" + +namespace py = pybind11; + +namespace infinicore::ops { + +inline void bind_logsigmoid(py::module &m) { + m.def("logsigmoid", + &op::logsigmoid, + py::arg("input"), + R"doc(LogSigmoid activation function.)doc"); + + m.def("logsigmoid_", + &op::logsigmoid_, + py::arg("output"), + py::arg("input"), + R"doc(In-place LogSigmoid activation function.)doc"); +} + +} // namespace infinicore::ops + diff --git a/src/infinicore/pybind11/ops/where.hpp b/src/infinicore/pybind11/ops/where.hpp new file mode 100644 index 000000000..9c067b4d6 --- /dev/null +++ b/src/infinicore/pybind11/ops/where.hpp @@ -0,0 +1,30 @@ +#pragma once + +#include + +#include "infinicore/ops/where.hpp" + +namespace py = pybind11; + +namespace infinicore::ops { + +inline void bind_where(py::module &m) { + m.def("where", + &op::where, + py::arg("cond"), + py::arg("x"), + py::arg("y"), + R"doc(Elementwise where(cond, x, y) selection.)doc"); + + m.def("where_", + &op::where_, + py::arg("out"), + py::arg("cond"), + py::arg("x"), + py::arg("y"), + R"doc(In-place elementwise where(cond, x, y) selection into out tensor.)doc"); +} + +} // namespace infinicore::ops + + diff --git a/src/infiniop/ops/logical_or/cpu/logical_or_cpu.cc b/src/infiniop/ops/logical_or/cpu/logical_or_cpu.cc new file mode 100644 index 000000000..1225ca125 --- /dev/null +++ b/src/infiniop/ops/logical_or/cpu/logical_or_cpu.cc @@ -0,0 +1,116 @@ +#include "logical_or_cpu.h" +#include "../../../devices/cpu/common_cpu.h" +#include +#include + +namespace op::logical_or::cpu { + +// Track input dtype for each Descriptor so we can support both +// - out-of-place / inplace(out) with bool output +// - inplace(a) / inplace(b) with output dtype == input dtype (int32/uint8) +static std::unordered_map g_input_dtype; +static std::mutex g_input_dtype_mutex; + +Descriptor::~Descriptor() { + std::lock_guard lock(g_input_dtype_mutex); + g_input_dtype.erase(this); +} + +infiniStatus_t Descriptor::create( + infiniopHandle_t handle_, + Descriptor **desc_ptr, + infiniopTensorDescriptor_t out_desc, + std::vector input_desc_vec) { + + auto handle = reinterpret_cast(handle_); + + // True output dtype (memory layout of output buffer) + auto out_dtype = out_desc->dtype(); + + const auto &a_desc = input_desc_vec.at(0); + const auto &b_desc = input_desc_vec.at(1); + const auto &c_shape = out_desc->shape(); + const auto &a_shape = a_desc->shape(); + const auto &b_shape = b_desc->shape(); + + // Input dtype(s) + auto a_dtype = a_desc->dtype(); + auto b_dtype = b_desc->dtype(); + + // Inputs must have the same dtype and be one of the supported ones + CHECK_OR_RETURN(a_dtype == b_dtype, INFINI_STATUS_BAD_TENSOR_DTYPE); + CHECK_DTYPE(a_dtype, INFINI_DTYPE_BOOL, INFINI_DTYPE_I32, INFINI_DTYPE_U8); + + // Output must be either bool (standard case) or equal to input dtype + if (!(out_dtype == INFINI_DTYPE_BOOL || out_dtype == 
a_dtype)) { + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + + CHECK_SAME_SHAPE(c_shape, a_shape, b_shape); + + // Here we pass the *output* dtype into the descriptor, keeping the + // semantics of CREATE_ELEMENTWISE_CPU_DESCRIPTOR consistent with + // other elementwise ops. + CREATE_ELEMENTWISE_CPU_DESCRIPTOR(handle, out_dtype, out_desc, input_desc_vec); + + // Remember the common input dtype for this descriptor + { + std::lock_guard lock(g_input_dtype_mutex); + g_input_dtype[*desc_ptr] = a_dtype; + } + + return INFINI_STATUS_SUCCESS; +} + +infiniStatus_t Descriptor::calculate( + void *workspace, + size_t workspace_size, + void *output, + std::vector inputs, + void *stream) const { + + // _dtype now represents the output dtype (as in other elementwise ops) + auto out_dtype = _dtype; + + // Look up the input dtype for this descriptor + infiniDtype_t in_dtype; + { + std::lock_guard lock(g_input_dtype_mutex); + auto it = g_input_dtype.find(this); + if (it == g_input_dtype.end()) { + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + in_dtype = it->second; + } + + // Case 1: boolean inputs (and thus boolean output) + if (in_dtype == INFINI_DTYPE_BOOL) { + // All bool: we can use the homogeneous path + return _device_info->calculate(_info, output, inputs, stream); + } + + // Case 2: integer inputs (int32 / uint8) + if (in_dtype == INFINI_DTYPE_I32) { + if (out_dtype == INFINI_DTYPE_BOOL) { + // Inputs int32, output bool + return _device_info->calculate(_info, output, inputs, stream); + } else if (out_dtype == INFINI_DTYPE_I32) { + // Inplace(a/b): inputs and output are int32 + // Use homogeneous path; LogicalOrOp returns bool which is + // implicitly converted to int32 (0/1). + return _device_info->calculate(_info, output, inputs, stream); + } + } else if (in_dtype == INFINI_DTYPE_U8) { + if (out_dtype == INFINI_DTYPE_BOOL) { + // Inputs uint8, output bool + return _device_info->calculate(_info, output, inputs, stream); + } else if (out_dtype == INFINI_DTYPE_U8) { + // Inplace(a/b): inputs and output are uint8 + return _device_info->calculate(_info, output, inputs, stream); + } + } + + return INFINI_STATUS_BAD_TENSOR_DTYPE; +} +} // namespace op::logical_or::cpu + diff --git a/src/infiniop/ops/logical_or/cpu/logical_or_cpu.h b/src/infiniop/ops/logical_or/cpu/logical_or_cpu.h new file mode 100644 index 000000000..36c862570 --- /dev/null +++ b/src/infiniop/ops/logical_or/cpu/logical_or_cpu.h @@ -0,0 +1,25 @@ +#ifndef __LOGICAL_OR_CPU_H__ +#define __LOGICAL_OR_CPU_H__ + +#include "../../../elementwise/cpu/elementwise_cpu.h" + +ELEMENTWISE_DESCRIPTOR(logical_or, cpu) + +namespace op::logical_or::cpu { +typedef struct LogicalOrOp { +public: + static constexpr size_t num_inputs = 2; + template + bool operator()(const T &a, const T &b) const { + return static_cast(a) || static_cast(b); + } + // Support heterogeneous input types for elementwise framework + template + Tout operator()(const Ta &a, const Tb &b) const { + return static_cast(static_cast(a) || static_cast(b)); + } +} LogicalOrOp; +} // namespace op::logical_or::cpu + +#endif // __LOGICAL_OR_CPU_H__ + diff --git a/src/infiniop/ops/logical_or/operator.cc b/src/infiniop/ops/logical_or/operator.cc new file mode 100644 index 000000000..ee8d4f18f --- /dev/null +++ b/src/infiniop/ops/logical_or/operator.cc @@ -0,0 +1,99 @@ +#include "../../operator.h" +#include "../../handle.h" +#include "infiniop/ops/logical_or.h" + +#ifdef ENABLE_CPU_API +#include "cpu/logical_or_cpu.h" +#endif + +__C infiniStatus_t infiniopCreateLogicalOrDescriptor( + infiniopHandle_t 
handle, + infiniopLogicalOrDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t c_desc, + infiniopTensorDescriptor_t a_desc, + infiniopTensorDescriptor_t b_desc) { + +#define CREATE(CASE, NAMESPACE) \ + case CASE: \ + return op::logical_or::NAMESPACE::Descriptor::create( \ + handle, \ + reinterpret_cast(desc_ptr), \ + c_desc, \ + {a_desc, b_desc}) + + switch (handle->device) { + +#ifdef ENABLE_CPU_API + CREATE(INFINI_DEVICE_CPU, cpu); +#endif + + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } + +#undef CREATE +} + +__C infiniStatus_t infiniopGetLogicalOrWorkspaceSize(infiniopLogicalOrDescriptor_t desc, size_t *size) { +#define GET(CASE, NAMESPACE) \ + case CASE: \ + *size = reinterpret_cast(desc)->workspaceSize(); \ + return INFINI_STATUS_SUCCESS; + + switch (desc->device_type) { +#ifdef ENABLE_CPU_API + GET(INFINI_DEVICE_CPU, cpu) +#endif + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } +#undef GET + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; +} + +__C infiniStatus_t infiniopLogicalOr( + infiniopLogicalOrDescriptor_t desc, + void *workspace, + size_t workspace_size, + void *c, + const void *a, + const void *b, + void *stream) { + +#define CALCULATE(CASE, NAMESPACE) \ + case CASE: \ + return reinterpret_cast(desc) \ + ->calculate(workspace, workspace_size, c, {a, b}, stream) + + switch (desc->device_type) { + +#ifdef ENABLE_CPU_API + CALCULATE(INFINI_DEVICE_CPU, cpu); +#endif + + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } + +#undef CALCULATE +} + +__C infiniStatus_t infiniopDestroyLogicalOrDescriptor(infiniopLogicalOrDescriptor_t desc) { +#define DELETE(CASE, NAMESPACE) \ + case CASE: \ + delete reinterpret_cast(desc); \ + return INFINI_STATUS_SUCCESS; + + switch (desc->device_type) { + +#ifdef ENABLE_CPU_API + DELETE(INFINI_DEVICE_CPU, cpu); +#endif + + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } + +#undef DELETE +} + diff --git a/src/infiniop/ops/logical_xor/cpu/logical_xor_cpu.cc b/src/infiniop/ops/logical_xor/cpu/logical_xor_cpu.cc new file mode 100644 index 000000000..bcb8ce34f --- /dev/null +++ b/src/infiniop/ops/logical_xor/cpu/logical_xor_cpu.cc @@ -0,0 +1,116 @@ +#include "logical_xor_cpu.h" +#include "../../../devices/cpu/common_cpu.h" +#include +#include + +namespace op::logical_xor::cpu { + +// Track input dtype for each Descriptor so we can support both +// - out-of-place / inplace(out) with bool output +// - inplace(a) / inplace(b) with output dtype == input dtype (int32/uint8) +static std::unordered_map g_input_dtype; +static std::mutex g_input_dtype_mutex; + +Descriptor::~Descriptor() { + std::lock_guard lock(g_input_dtype_mutex); + g_input_dtype.erase(this); +} + +infiniStatus_t Descriptor::create( + infiniopHandle_t handle_, + Descriptor **desc_ptr, + infiniopTensorDescriptor_t out_desc, + std::vector input_desc_vec) { + + auto handle = reinterpret_cast(handle_); + + // True output dtype (memory layout of output buffer) + auto out_dtype = out_desc->dtype(); + + const auto &a_desc = input_desc_vec.at(0); + const auto &b_desc = input_desc_vec.at(1); + const auto &c_shape = out_desc->shape(); + const auto &a_shape = a_desc->shape(); + const auto &b_shape = b_desc->shape(); + + // Input dtype(s) + auto a_dtype = a_desc->dtype(); + auto b_dtype = b_desc->dtype(); + + // Inputs must have the same dtype and be one of the supported ones + CHECK_OR_RETURN(a_dtype == b_dtype, INFINI_STATUS_BAD_TENSOR_DTYPE); + CHECK_DTYPE(a_dtype, INFINI_DTYPE_BOOL, INFINI_DTYPE_I32, INFINI_DTYPE_U8); + + // Output must be 
either bool (standard case) or equal to input dtype + if (!(out_dtype == INFINI_DTYPE_BOOL || out_dtype == a_dtype)) { + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + + CHECK_SAME_SHAPE(c_shape, a_shape, b_shape); + + // Here we pass the *output* dtype into the descriptor, keeping the + // semantics of CREATE_ELEMENTWISE_CPU_DESCRIPTOR consistent with + // other elementwise ops. + CREATE_ELEMENTWISE_CPU_DESCRIPTOR(handle, out_dtype, out_desc, input_desc_vec); + + // Remember the common input dtype for this descriptor + { + std::lock_guard lock(g_input_dtype_mutex); + g_input_dtype[*desc_ptr] = a_dtype; + } + + return INFINI_STATUS_SUCCESS; +} + +infiniStatus_t Descriptor::calculate( + void *workspace, + size_t workspace_size, + void *output, + std::vector inputs, + void *stream) const { + + // _dtype now represents the output dtype (as in other elementwise ops) + auto out_dtype = _dtype; + + // Look up the input dtype for this descriptor + infiniDtype_t in_dtype; + { + std::lock_guard lock(g_input_dtype_mutex); + auto it = g_input_dtype.find(this); + if (it == g_input_dtype.end()) { + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + in_dtype = it->second; + } + + // Case 1: boolean inputs (and thus boolean output) + if (in_dtype == INFINI_DTYPE_BOOL) { + // All bool: we can use the homogeneous path + return _device_info->calculate(_info, output, inputs, stream); + } + + // Case 2: integer inputs (int32 / uint8) + if (in_dtype == INFINI_DTYPE_I32) { + if (out_dtype == INFINI_DTYPE_BOOL) { + // Inputs int32, output bool + return _device_info->calculate(_info, output, inputs, stream); + } else if (out_dtype == INFINI_DTYPE_I32) { + // Inplace(a/b): inputs and output are int32 + // Use homogeneous path; LogicalXorOp returns bool which is + // implicitly converted to int32 (0/1). 
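+            // Illustrative values: a = {0, 2, 5}, b = {0, 0, 7}
+            // -> out = {0, 1, 0} as int32.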
+ return _device_info->calculate(_info, output, inputs, stream); + } + } else if (in_dtype == INFINI_DTYPE_U8) { + if (out_dtype == INFINI_DTYPE_BOOL) { + // Inputs uint8, output bool + return _device_info->calculate(_info, output, inputs, stream); + } else if (out_dtype == INFINI_DTYPE_U8) { + // Inplace(a/b): inputs and output are uint8 + return _device_info->calculate(_info, output, inputs, stream); + } + } + + return INFINI_STATUS_BAD_TENSOR_DTYPE; +} +} // namespace op::logical_xor::cpu + diff --git a/src/infiniop/ops/logical_xor/cpu/logical_xor_cpu.h b/src/infiniop/ops/logical_xor/cpu/logical_xor_cpu.h new file mode 100644 index 000000000..3ca53f889 --- /dev/null +++ b/src/infiniop/ops/logical_xor/cpu/logical_xor_cpu.h @@ -0,0 +1,25 @@ +#ifndef __LOGICAL_XOR_CPU_H__ +#define __LOGICAL_XOR_CPU_H__ + +#include "../../../elementwise/cpu/elementwise_cpu.h" + +ELEMENTWISE_DESCRIPTOR(logical_xor, cpu) + +namespace op::logical_xor::cpu { +typedef struct LogicalXorOp { +public: + static constexpr size_t num_inputs = 2; + template + bool operator()(const T &a, const T &b) const { + return static_cast(a) != static_cast(b); + } + // Support heterogeneous input types for elementwise framework + template + Tout operator()(const Ta &a, const Tb &b) const { + return static_cast(static_cast(a) != static_cast(b)); + } +} LogicalXorOp; +} // namespace op::logical_xor::cpu + +#endif // __LOGICAL_XOR_CPU_H__ + diff --git a/src/infiniop/ops/logical_xor/operator.cc b/src/infiniop/ops/logical_xor/operator.cc new file mode 100644 index 000000000..658e6a9fe --- /dev/null +++ b/src/infiniop/ops/logical_xor/operator.cc @@ -0,0 +1,99 @@ +#include "../../operator.h" +#include "../../handle.h" +#include "infiniop/ops/logical_xor.h" + +#ifdef ENABLE_CPU_API +#include "cpu/logical_xor_cpu.h" +#endif + +__C infiniStatus_t infiniopCreateLogicalXorDescriptor( + infiniopHandle_t handle, + infiniopLogicalXorDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t c_desc, + infiniopTensorDescriptor_t a_desc, + infiniopTensorDescriptor_t b_desc) { + +#define CREATE(CASE, NAMESPACE) \ + case CASE: \ + return op::logical_xor::NAMESPACE::Descriptor::create( \ + handle, \ + reinterpret_cast(desc_ptr), \ + c_desc, \ + {a_desc, b_desc}) + + switch (handle->device) { + +#ifdef ENABLE_CPU_API + CREATE(INFINI_DEVICE_CPU, cpu); +#endif + + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } + +#undef CREATE +} + +__C infiniStatus_t infiniopGetLogicalXorWorkspaceSize(infiniopLogicalXorDescriptor_t desc, size_t *size) { +#define GET(CASE, NAMESPACE) \ + case CASE: \ + *size = reinterpret_cast(desc)->workspaceSize(); \ + return INFINI_STATUS_SUCCESS; + + switch (desc->device_type) { +#ifdef ENABLE_CPU_API + GET(INFINI_DEVICE_CPU, cpu) +#endif + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } +#undef GET + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; +} + +__C infiniStatus_t infiniopLogicalXor( + infiniopLogicalXorDescriptor_t desc, + void *workspace, + size_t workspace_size, + void *c, + const void *a, + const void *b, + void *stream) { + +#define CALCULATE(CASE, NAMESPACE) \ + case CASE: \ + return reinterpret_cast(desc) \ + ->calculate(workspace, workspace_size, c, {a, b}, stream) + + switch (desc->device_type) { + +#ifdef ENABLE_CPU_API + CALCULATE(INFINI_DEVICE_CPU, cpu); +#endif + + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } + +#undef CALCULATE +} + +__C infiniStatus_t infiniopDestroyLogicalXorDescriptor(infiniopLogicalXorDescriptor_t desc) { +#define DELETE(CASE, NAMESPACE) \ + case 
CASE: \ + delete reinterpret_cast(desc); \ + return INFINI_STATUS_SUCCESS; + + switch (desc->device_type) { + +#ifdef ENABLE_CPU_API + DELETE(INFINI_DEVICE_CPU, cpu); +#endif + + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } + +#undef DELETE +} + diff --git a/src/infiniop/ops/logsigmoid/cpu/logsigmoid_cpu.cc b/src/infiniop/ops/logsigmoid/cpu/logsigmoid_cpu.cc new file mode 100644 index 000000000..f765c7bb7 --- /dev/null +++ b/src/infiniop/ops/logsigmoid/cpu/logsigmoid_cpu.cc @@ -0,0 +1,52 @@ +#include "logsigmoid_cpu.h" + +namespace op::logsigmoid::cpu { + +Descriptor::~Descriptor() = default; + +infiniStatus_t Descriptor::create( + infiniopHandle_t handle_, + Descriptor **desc_ptr, + infiniopTensorDescriptor_t out_desc, + std::vector input_desc_vec) { + + auto handle = reinterpret_cast(handle_); + auto dtype = out_desc->dtype(); + + const auto &x_desc = input_desc_vec.at(0); + const auto &y_shape = out_desc->shape(); + const auto &x_shape = x_desc->shape(); + + CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_F64, INFINI_DTYPE_BF16); + CHECK_SAME_SHAPE(y_shape, x_shape); + + // create CPU elementwise descriptor + CREATE_ELEMENTWISE_CPU_DESCRIPTOR(handle, dtype, out_desc, input_desc_vec); + + return INFINI_STATUS_SUCCESS; +} + +infiniStatus_t Descriptor::calculate( + void *workspace, + size_t workspace_size, + void *output, + std::vector inputs, + void *stream) const { + + switch (_dtype) { + case INFINI_DTYPE_F16: + return _device_info->calculate(_info, output, inputs, stream); + case INFINI_DTYPE_F32: + return _device_info->calculate(_info, output, inputs, stream); + case INFINI_DTYPE_F64: + return _device_info->calculate(_info, output, inputs, stream); + case INFINI_DTYPE_BF16: + return _device_info->calculate(_info, output, inputs, stream); + default: + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + + return INFINI_STATUS_SUCCESS; +} +} // namespace op::logsigmoid::cpu + diff --git a/src/infiniop/ops/logsigmoid/cpu/logsigmoid_cpu.h b/src/infiniop/ops/logsigmoid/cpu/logsigmoid_cpu.h new file mode 100644 index 000000000..8d6484d27 --- /dev/null +++ b/src/infiniop/ops/logsigmoid/cpu/logsigmoid_cpu.h @@ -0,0 +1,21 @@ +#ifndef __LOGSIGMOID_CPU_H__ +#define __LOGSIGMOID_CPU_H__ + +#include "../../../elementwise/cpu/elementwise_cpu.h" + +ELEMENTWISE_DESCRIPTOR(logsigmoid, cpu) + +namespace op::logsigmoid::cpu { +typedef struct LogSigmoidOp { +public: + static constexpr size_t num_inputs = 1; + template + T operator()(const T &x) const { + // logsigmoid(x) = log(sigmoid(x)) = -log(1 + exp(-x)) + return -std::log(T(1) + std::exp(-x)); + } +} LogSigmoidOp; +} // namespace op::logsigmoid::cpu + +#endif // __LOGSIGMOID_CPU_H__ + diff --git a/src/infiniop/ops/logsigmoid/operator.cc b/src/infiniop/ops/logsigmoid/operator.cc new file mode 100644 index 000000000..4d0283994 --- /dev/null +++ b/src/infiniop/ops/logsigmoid/operator.cc @@ -0,0 +1,101 @@ +#include "../../operator.h" +#include "../../handle.h" +#include "infiniop/ops/logsigmoid.h" + +#ifdef ENABLE_CPU_API +#include "cpu/logsigmoid_cpu.h" +#endif + +__C infiniStatus_t infiniopCreateLogSigmoidDescriptor( + infiniopHandle_t handle, + infiniopLogSigmoidDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t y_desc, + infiniopTensorDescriptor_t x_desc) { + +#define CREATE(CASE, NAMESPACE) \ + case CASE: \ + return op::logsigmoid::NAMESPACE::Descriptor::create( \ + handle, \ + reinterpret_cast(desc_ptr), \ + y_desc, \ + {x_desc}) + + switch (handle->device) { + +#ifdef ENABLE_CPU_API + CREATE(INFINI_DEVICE_CPU, 
diff --git a/src/infiniop/ops/logsigmoid/operator.cc b/src/infiniop/ops/logsigmoid/operator.cc
new file mode 100644
index 000000000..4d0283994
--- /dev/null
+++ b/src/infiniop/ops/logsigmoid/operator.cc
@@ -0,0 +1,101 @@
+#include "../../operator.h"
+#include "../../handle.h"
+#include "infiniop/ops/logsigmoid.h"
+
+#ifdef ENABLE_CPU_API
+#include "cpu/logsigmoid_cpu.h"
+#endif
+
+__C infiniStatus_t infiniopCreateLogSigmoidDescriptor(
+    infiniopHandle_t handle,
+    infiniopLogSigmoidDescriptor_t *desc_ptr,
+    infiniopTensorDescriptor_t y_desc,
+    infiniopTensorDescriptor_t x_desc) {
+
+#define CREATE(CASE, NAMESPACE) \
+    case CASE: \
+        return op::logsigmoid::NAMESPACE::Descriptor::create( \
+            handle, \
+            reinterpret_cast<op::logsigmoid::NAMESPACE::Descriptor **>(desc_ptr), \
+            y_desc, \
+            {x_desc})
+
+    switch (handle->device) {
+
+#ifdef ENABLE_CPU_API
+        CREATE(INFINI_DEVICE_CPU, cpu);
+#endif
+
+    default:
+        return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
+    }
+
+#undef CREATE
+}
+
+__C infiniStatus_t infiniopGetLogSigmoidWorkspaceSize(infiniopLogSigmoidDescriptor_t desc, size_t *size) {
+
+#define GET(CASE, NAMESPACE) \
+    case CASE: \
+        *size = reinterpret_cast<op::logsigmoid::NAMESPACE::Descriptor *>(desc)->workspaceSize(); \
+        return INFINI_STATUS_SUCCESS;
+
+    switch (desc->device_type) {
+#ifdef ENABLE_CPU_API
+        GET(INFINI_DEVICE_CPU, cpu)
+#endif
+    default:
+        return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
+    }
+#undef GET
+
+    return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
+}
+
+__C infiniStatus_t infiniopLogSigmoid(
+    infiniopLogSigmoidDescriptor_t desc,
+    void *workspace,
+    size_t workspace_size,
+    void *y,
+    const void *x,
+    void *stream) {
+
+#define CALCULATE(CASE, NAMESPACE) \
+    case CASE: \
+        return reinterpret_cast<const op::logsigmoid::NAMESPACE::Descriptor *>(desc) \
+            ->calculate(workspace, workspace_size, y, {x}, stream)
+
+    switch (desc->device_type) {
+
+#ifdef ENABLE_CPU_API
+        CALCULATE(INFINI_DEVICE_CPU, cpu);
+#endif
+
+    default:
+        return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
+    }
+
+#undef CALCULATE
+}
+
+__C infiniStatus_t
+infiniopDestroyLogSigmoidDescriptor(infiniopLogSigmoidDescriptor_t desc) {
+
+#define DELETE(CASE, NAMESPACE) \
+    case CASE: \
+        delete reinterpret_cast<const op::logsigmoid::NAMESPACE::Descriptor *>(desc); \
+        return INFINI_STATUS_SUCCESS;
+
+    switch (desc->device_type) {
+
+#ifdef ENABLE_CPU_API
+        DELETE(INFINI_DEVICE_CPU, cpu);
+#endif
+
+    default:
+        return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
+    }
+
+#undef DELETE
+}
+
diff --git a/src/infiniop/ops/where/cpu/where_cpu.cc b/src/infiniop/ops/where/cpu/where_cpu.cc
new file mode 100644
index 000000000..26befdaf3
--- /dev/null
+++ b/src/infiniop/ops/where/cpu/where_cpu.cc
@@ -0,0 +1,92 @@
+#include "where_cpu.h"
+#include "../../../devices/cpu/common_cpu.h"
+
+namespace op::where::cpu {
+
+Descriptor::~Descriptor() = default;
+
+infiniStatus_t Descriptor::create(
+    infiniopHandle_t handle_,
+    Descriptor **desc_ptr,
+    infiniopTensorDescriptor_t out_desc,
+    std::vector<infiniopTensorDescriptor_t> input_desc_vec) {
+
+    auto handle = reinterpret_cast<device::cpu::Handle *>(handle_);
+
+    // Expected inputs: 0: cond, 1: x, 2: y
+    if (input_desc_vec.size() != 3) {
+        return INFINI_STATUS_BAD_PARAM;
+    }
+
+    auto out_dtype = out_desc->dtype();
+
+    const auto &cond_desc = input_desc_vec.at(0);
+    const auto &x_desc = input_desc_vec.at(1);
+    const auto &y_desc = input_desc_vec.at(2);
+
+    const auto &out_shape = out_desc->shape();
+    const auto &cond_shape = cond_desc->shape();
+    const auto &x_shape = x_desc->shape();
+    const auto &y_shape = y_desc->shape();
+
+    // cond must be bool
+    CHECK_DTYPE(cond_desc->dtype(), INFINI_DTYPE_BOOL);
+
+    // x, y and output must share the same dtype
+    auto x_dtype = x_desc->dtype();
+    auto y_dtype = y_desc->dtype();
+    CHECK_OR_RETURN(x_dtype == y_dtype, INFINI_STATUS_BAD_TENSOR_DTYPE);
+    CHECK_OR_RETURN(x_dtype == out_dtype, INFINI_STATUS_BAD_TENSOR_DTYPE);
+
+    // Supported value dtypes (extend if needed)
+    CHECK_DTYPE(
+        out_dtype,
+        INFINI_DTYPE_F16,
+        INFINI_DTYPE_F32,
+        INFINI_DTYPE_F64,
+        INFINI_DTYPE_BF16,
+        INFINI_DTYPE_I32,
+        INFINI_DTYPE_I64,
+        INFINI_DTYPE_U8);
+
+    // For now, require all shapes to match (no broadcasting)
+    CHECK_SAME_SHAPE(out_shape, cond_shape, x_shape, y_shape);
+
+    // Create CPU elementwise descriptor with output dtype
+    CREATE_ELEMENTWISE_CPU_DESCRIPTOR(handle, out_dtype, out_desc, input_desc_vec);
+
+    return INFINI_STATUS_SUCCESS;
+}
+
+infiniStatus_t Descriptor::calculate(
+    void *workspace,
+    size_t workspace_size,
+    void *output,
+    std::vector<const void *> inputs,
+    void *stream) const {
+
+    switch (_dtype) {
+    case INFINI_DTYPE_F16:
+        return _device_info->calculate<WhereOp, fp16_t, bool, fp16_t, fp16_t>(_info, output, inputs, stream);
+    case INFINI_DTYPE_F32:
+        return _device_info->calculate<WhereOp, float, bool, float, float>(_info, output, inputs, stream);
+    case INFINI_DTYPE_F64:
+        return _device_info->calculate<WhereOp, double, bool, double, double>(_info, output, inputs, stream);
+    case INFINI_DTYPE_BF16:
+        return _device_info->calculate<WhereOp, bf16_t, bool, bf16_t, bf16_t>(_info, output, inputs, stream);
+    case INFINI_DTYPE_I32:
+        return _device_info->calculate<WhereOp, int32_t, bool, int32_t, int32_t>(_info, output, inputs, stream);
+    case INFINI_DTYPE_I64:
+        return _device_info->calculate<WhereOp, int64_t, bool, int64_t, int64_t>(_info, output, inputs, stream);
+    case INFINI_DTYPE_U8:
+        return _device_info->calculate<WhereOp, uint8_t, bool, uint8_t, uint8_t>(_info, output, inputs, stream);
+    default:
+        return INFINI_STATUS_BAD_TENSOR_DTYPE;
+    }
+
+    return INFINI_STATUS_SUCCESS;
+}
+
+} // namespace op::where::cpu
+
+
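Because the descriptor above requires identical shapes (no broadcasting), callers must pre-expand cond/x/y before invoking the operator. The contract the kernel then implements is a plain elementwise select; a reference sketch over flat buffers (hypothetical helper, for illustration only):

#include <cassert>
#include <cstddef>
#include <vector>

// out[i] = cond[i] ? x[i] : y[i]; all four buffers share one shape.
template <typename T>
void where_ref(std::vector<T> &out, const std::vector<bool> &cond,
               const std::vector<T> &x, const std::vector<T> &y) {
    assert(cond.size() == x.size() && x.size() == y.size());
    out.resize(x.size());
    for (size_t i = 0; i < x.size(); ++i) {
        out[i] = cond[i] ? x[i] : y[i];
    }
}

int main() {
    std::vector<float> out;
    where_ref<float>(out, {true, false, true}, {1.f, 2.f, 3.f}, {9.f, 8.f, 7.f});
    assert(out[0] == 1.f && out[1] == 8.f && out[2] == 3.f);
    return 0;
}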
diff --git a/src/infiniop/ops/where/cpu/where_cpu.h b/src/infiniop/ops/where/cpu/where_cpu.h
new file mode 100644
index 000000000..6b4399a7c
--- /dev/null
+++ b/src/infiniop/ops/where/cpu/where_cpu.h
@@ -0,0 +1,33 @@
+#ifndef __WHERE_CPU_H__
+#define __WHERE_CPU_H__
+
+#include "../../../elementwise/cpu/elementwise_cpu.h"
+
+// Define Descriptor: op::where::cpu::Descriptor
+ELEMENTWISE_DESCRIPTOR(where, cpu)
+
+namespace op::where::cpu {
+
+struct WhereOp {
+public:
+    // Three inputs: cond, x, y
+    static constexpr size_t num_inputs = 3;
+
+    // Homogeneous version: cond is already bool, x/y have same type as output
+    template <typename T>
+    T operator()(const bool &cond, const T &x, const T &y) const {
+        return cond ? x : y;
+    }
+
+    // Heterogeneous version: support non-bool cond or explicit Tout
+    template <typename Tout, typename Tcond, typename Tx, typename Ty>
+    Tout operator()(const Tcond &cond, const Tx &x, const Ty &y) const {
+        return static_cast<bool>(cond) ? static_cast<Tout>(x) : static_cast<Tout>(y);
+    }
+};
+
+} // namespace op::where::cpu
+
+#endif // __WHERE_CPU_H__
+
+
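The two operator() overloads in WhereOp divide the work by overload resolution: when cond arrives as bool and x/y share the output type, the homogeneous overload is an exact match; otherwise the heterogeneous one is instantiated with an explicit Tout. A small sketch showing which overload fires (an assumed standalone mirror of the functor, not patch code):

#include <cassert>
#include <cstdint>

struct WhereOpSketch {
    // Mirrors the homogeneous overload: exact match when cond is bool.
    template <typename T>
    T operator()(const bool &cond, const T &x, const T &y) const {
        return cond ? x : y;
    }
    // Mirrors the heterogeneous overload: Tout must be given explicitly.
    template <typename Tout, typename Tcond, typename Tx, typename Ty>
    Tout operator()(const Tcond &cond, const Tx &x, const Ty &y) const {
        return static_cast<bool>(cond) ? static_cast<Tout>(x) : static_cast<Tout>(y);
    }
};

int main() {
    WhereOpSketch op;
    bool c = true;
    assert(op(c, 1.5f, 2.5f) == 1.5f);                       // homogeneous overload selected
    assert((op.operator()<int>(uint8_t{0}, 3.9, 7.1)) == 7); // heterogeneous, Tout = int
    return 0;
}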
diff --git a/src/infiniop/ops/where/operator.cc b/src/infiniop/ops/where/operator.cc
new file mode 100644
index 000000000..c966af409
--- /dev/null
+++ b/src/infiniop/ops/where/operator.cc
@@ -0,0 +1,106 @@
+#include "../../operator.h"
+#include "../../handle.h"
+#include "infiniop/ops/where.h"
+
+#ifdef ENABLE_CPU_API
+#include "cpu/where_cpu.h"
+#endif
+
+__C infiniStatus_t infiniopCreateWhereDescriptor(
+    infiniopHandle_t handle,
+    infiniopWhereDescriptor_t *desc_ptr,
+    infiniopTensorDescriptor_t out_desc,
+    infiniopTensorDescriptor_t cond_desc,
+    infiniopTensorDescriptor_t x_desc,
+    infiniopTensorDescriptor_t y_desc) {
+
+#define CREATE(CASE, NAMESPACE) \
+    case CASE: \
+        return op::where::NAMESPACE::Descriptor::create( \
+            handle, \
+            reinterpret_cast<op::where::NAMESPACE::Descriptor **>(desc_ptr), \
+            out_desc, \
+            {cond_desc, x_desc, y_desc})
+
+    switch (handle->device) {
+
+#ifdef ENABLE_CPU_API
+        CREATE(INFINI_DEVICE_CPU, cpu);
+#endif
+
+    default:
+        return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
+    }
+
+#undef CREATE
+}
+
+__C infiniStatus_t infiniopGetWhereWorkspaceSize(
+    infiniopWhereDescriptor_t desc,
+    size_t *size) {
+
+#define GET(CASE, NAMESPACE) \
+    case CASE: \
+        *size = reinterpret_cast<op::where::NAMESPACE::Descriptor *>(desc)->workspaceSize(); \
+        return INFINI_STATUS_SUCCESS;
+
+    switch (desc->device_type) {
+#ifdef ENABLE_CPU_API
+        GET(INFINI_DEVICE_CPU, cpu)
+#endif
+    default:
+        return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
+    }
+#undef GET
+    return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
+}
+
+__C infiniStatus_t infiniopWhere(
+    infiniopWhereDescriptor_t desc,
+    void *workspace,
+    size_t workspace_size,
+    void *out,
+    const void *cond,
+    const void *x,
+    const void *y,
+    void *stream) {
+
+#define CALCULATE(CASE, NAMESPACE) \
+    case CASE: \
+        return reinterpret_cast<const op::where::NAMESPACE::Descriptor *>(desc) \
+            ->calculate(workspace, workspace_size, out, {cond, x, y}, stream)
+
+    switch (desc->device_type) {
+
+#ifdef ENABLE_CPU_API
+        CALCULATE(INFINI_DEVICE_CPU, cpu);
+#endif
+
+    default:
+        return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
+    }
+
+#undef CALCULATE
+}
+
+__C infiniStatus_t infiniopDestroyWhereDescriptor(infiniopWhereDescriptor_t desc) {
+
+#define DELETE(CASE, NAMESPACE) \
+    case CASE: \
+        delete reinterpret_cast<const op::where::NAMESPACE::Descriptor *>(desc); \
+        return INFINI_STATUS_SUCCESS;
+
+    switch (desc->device_type) {
+
+#ifdef ENABLE_CPU_API
+        DELETE(INFINI_DEVICE_CPU, cpu);
+#endif
+
+    default:
+        return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
+    }
+
+#undef DELETE
+}
+
+
diff --git a/test/infinicore/ops/where.py b/test/infinicore/ops/where.py
index bc0013bd6..d1fc6fd50 100644
--- a/test/infinicore/ops/where.py
+++ b/test/infinicore/ops/where.py
@@ -71,9 +71,9 @@ def get_test_cases(self):
     def torch_operator(self, *args, **kwargs):
         return torch.where(*args, **kwargs)

-    # def infinicore_operator(self, *args, **kwargs):
-    #     """InfiniCore implementation (operator not yet available)."""
-    #     return infinicore.where(*args, **kwargs)
+    def infinicore_operator(self, *args, **kwargs):
+        """InfiniCore implementation."""
+        return infinicore.where(*args, **kwargs)


 def main():

From 40320697fb871e506c67295fd35533e4e4ef1ee3 Mon Sep 17 00:00:00 2001
From: guozhihao-224
Date: Mon, 15 Dec 2025 20:29:59 +0800
Subject: feat: support the vdot cpu operator
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 include/infinicore/ops/vdot.hpp          |  20 ++++
 include/infiniop.h                       |   1 +
 include/infiniop/ops/vdot.h              |  34 ++++++
 python/infinicore/__init__.py            |   2 +
 python/infinicore/ops/vdot.py            |  12 ++
 src/infinicore/ops/vdot/vdot.cc          |  50 +++++++++
 src/infinicore/ops/vdot/vdot_infiniop.cc |  54 +++++++++
 src/infinicore/pybind11/ops.hpp          |   2 +
 src/infinicore/pybind11/ops/vdot.hpp     |  21 ++++
 src/infiniop/ops/vdot/cpu/vdot_cpu.cc    | 137 +++++++++++++++++++++
 src/infiniop/ops/vdot/cpu/vdot_cpu.h     |  56 +++++++++
 src/infiniop/ops/vdot/operator.cc        |  98 ++++++++++++++++
 test/infinicore/ops/vdot.py              |   6 +-
 13 files changed, 490 insertions(+), 3 deletions(-)
 create mode 100644 include/infinicore/ops/vdot.hpp
 create mode 100644 include/infiniop/ops/vdot.h
 create mode 100644 python/infinicore/ops/vdot.py
 create mode 100644 src/infinicore/ops/vdot/vdot.cc
 create mode 100644 src/infinicore/ops/vdot/vdot_infiniop.cc
 create mode 100644 src/infinicore/pybind11/ops/vdot.hpp
 create mode 100644 src/infiniop/ops/vdot/cpu/vdot_cpu.cc
 create mode 100644 src/infiniop/ops/vdot/cpu/vdot_cpu.h
 create mode 100644 src/infiniop/ops/vdot/operator.cc

diff --git a/include/infinicore/ops/vdot.hpp b/include/infinicore/ops/vdot.hpp
new file mode 100644
index 000000000..ba2b9550c
--- /dev/null
+++ b/include/infinicore/ops/vdot.hpp
@@ -0,0 +1,20 @@
+#pragma once
+
+#include "../device.hpp"
+#include "common/op.hpp"
+
+namespace infinicore::op {
+
+class Vdot {
+public:
+    using schema = void (*)(Tensor out, Tensor a, Tensor b);
+    static void execute(Tensor out, Tensor a, Tensor b);
+    static common::OpDispatcher<schema> &dispatcher();
+};
+
+Tensor vdot(Tensor a, Tensor b);
+void vdot_(Tensor out, Tensor a, Tensor b);
+
+} // namespace infinicore::op
+
diff --git a/include/infiniop.h b/include/infiniop.h
index 66f15df70..cc221682d 100644
--- a/include/infiniop.h
+++ b/include/infiniop.h
@@ -33,6 +33,7 @@
 #include "infiniop/ops/topkrouter.h"
 #include 
"infiniop/ops/topksoftmax.h" #include "infiniop/ops/where.h" +#include "infiniop/ops/vdot.h" #include "infiniop/ops/zeros.h" #include "infiniop/tensor_descriptor.h" diff --git a/include/infiniop/ops/vdot.h b/include/infiniop/ops/vdot.h new file mode 100644 index 000000000..b486b8e72 --- /dev/null +++ b/include/infiniop/ops/vdot.h @@ -0,0 +1,34 @@ +#ifndef __INFINIOP_VDOT_API_H__ +#define __INFINIOP_VDOT_API_H__ + +#include "../operator_descriptor.h" + +typedef struct InfiniopDescriptor *infiniopVdotDescriptor_t; + +// out = vdot(a, b) +__C __export infiniStatus_t infiniopCreateVdotDescriptor( + infiniopHandle_t handle, + infiniopVdotDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t out_desc, + infiniopTensorDescriptor_t a_desc, + infiniopTensorDescriptor_t b_desc); + +__C __export infiniStatus_t infiniopGetVdotWorkspaceSize( + infiniopVdotDescriptor_t desc, + size_t *size); + +__C __export infiniStatus_t infiniopVdot( + infiniopVdotDescriptor_t desc, + void *workspace, + size_t workspace_size, + void *out, + const void *a, + const void *b, + void *stream); + +__C __export infiniStatus_t infiniopDestroyVdotDescriptor( + infiniopVdotDescriptor_t desc); + +#endif // __INFINIOP_VDOT_API_H__ + + diff --git a/python/infinicore/__init__.py b/python/infinicore/__init__.py index dcf8f4af3..00aa7d066 100644 --- a/python/infinicore/__init__.py +++ b/python/infinicore/__init__.py @@ -45,6 +45,7 @@ from infinicore.ops.logical_xor import logical_xor from infinicore.ops.logsigmoid import logsigmoid from infinicore.ops.where import where +from infinicore.ops.vdot import vdot from infinicore.ops.matmul import matmul from infinicore.ops.mul import mul from infinicore.ops.narrow import narrow @@ -109,6 +110,7 @@ "logical_xor", "logsigmoid", "where", + "vdot", "matmul", "mul", "narrow", diff --git a/python/infinicore/ops/vdot.py b/python/infinicore/ops/vdot.py new file mode 100644 index 000000000..9085b4ae5 --- /dev/null +++ b/python/infinicore/ops/vdot.py @@ -0,0 +1,12 @@ +from infinicore.lib import _infinicore +from infinicore.tensor import Tensor + + +def vdot(a: Tensor, b: Tensor) -> Tensor: + """ + InfiniCore vdot: 1D vector dot product, aligned with torch.vdot + for real-valued tensors (no complex conjugation). 
+ """ + return Tensor(_infinicore.vdot(a._underlying, b._underlying)) + + diff --git a/src/infinicore/ops/vdot/vdot.cc b/src/infinicore/ops/vdot/vdot.cc new file mode 100644 index 000000000..822b86c4a --- /dev/null +++ b/src/infinicore/ops/vdot/vdot.cc @@ -0,0 +1,50 @@ +#include "infinicore/ops/vdot.hpp" +#include "../../utils.hpp" + +namespace infinicore::op { + +common::OpDispatcher &Vdot::dispatcher() { + static common::OpDispatcher dispatcher_; + return dispatcher_; +} + +void Vdot::execute(Tensor out, Tensor a, Tensor b) { + INFINICORE_ASSERT_TENSORS_SAME_DEVICE(out, a, b); + + // Inputs must be 1D and same length + if (a->ndim() != 1 || b->ndim() != 1) { + throw std::runtime_error("vdot: input tensors must be 1D"); + } + if (a->shape()[0] != b->shape()[0]) { + throw std::runtime_error("vdot: input tensors must have the same length"); + } + + infinicore::context::setDevice(out->device()); + auto device_type = out->device().getType(); + auto func = dispatcher().lookup(device_type); + + if (func == nullptr) { + throw std::runtime_error( + "No Vdot implementation found for device type: " + + std::to_string(static_cast(device_type))); + } + + func(out, a, b); +} + +Tensor vdot(Tensor a, Tensor b) { + INFINICORE_ASSERT_TENSORS_SAME_DEVICE(a, b); + + // Output dtype equals input dtype for now + auto out = Tensor::empty({}, a->dtype(), a->device()); + vdot_(out, a, b); + return out; +} + +void vdot_(Tensor out, Tensor a, Tensor b) { + Vdot::execute(out, a, b); +} + +} // namespace infinicore::op + + diff --git a/src/infinicore/ops/vdot/vdot_infiniop.cc b/src/infinicore/ops/vdot/vdot_infiniop.cc new file mode 100644 index 000000000..866b1a044 --- /dev/null +++ b/src/infinicore/ops/vdot/vdot_infiniop.cc @@ -0,0 +1,54 @@ +#include "../../utils.hpp" +#include "infinicore/common/hash.hpp" +#include "infinicore/ops/common/cache.hpp" +#include "infinicore/ops/vdot.hpp" +#include + +namespace infinicore::op::vdot_impl::infiniop { + +thread_local common::OpCache caches( + 100, + [](infiniopVdotDescriptor_t &desc) { + if (desc != nullptr) { + INFINICORE_CHECK_ERROR(infiniopDestroyVdotDescriptor(desc)); + desc = nullptr; + } + }); + +void calculate(Tensor out, Tensor a, Tensor b) { + size_t seed = hash_combine(out, a, b); + + auto device_type = context::getDevice().getType(); + auto device_index = context::getDevice().getIndex(); + + auto &cache = caches.getCache(device_type, device_index); + + auto desc_opt = cache.get(seed); + infiniopVdotDescriptor_t desc = nullptr; + + if (!desc_opt) { + INFINICORE_CHECK_ERROR(infiniopCreateVdotDescriptor( + context::getInfiniopHandle(out->device()), &desc, + out->desc(), a->desc(), b->desc())); + cache.put(seed, desc); + } else { + desc = *desc_opt; + } + + size_t workspace_size = 0; + INFINICORE_CHECK_ERROR(infiniopGetVdotWorkspaceSize(desc, &workspace_size)); + std::shared_ptr workspace = context::allocateMemory(workspace_size); + + INFINICORE_CHECK_ERROR(infiniopVdot( + desc, workspace->data(), workspace_size, + out->data(), a->data(), b->data(), context::getStream())); +} + +static bool registered = []() { + Vdot::dispatcher().registerAll(&calculate, false); + return true; +}(); + +} // namespace infinicore::op::vdot_impl::infiniop + + diff --git a/src/infinicore/pybind11/ops.hpp b/src/infinicore/pybind11/ops.hpp index f8373b37f..a015574a2 100644 --- a/src/infinicore/pybind11/ops.hpp +++ b/src/infinicore/pybind11/ops.hpp @@ -19,6 +19,7 @@ #include "ops/silu.hpp" #include "ops/where.hpp" #include "ops/swiglu.hpp" +#include "ops/vdot.hpp" namespace py = 
pybind11; @@ -42,6 +43,7 @@ inline void bind(py::module &m) { bind_rope(m); bind_embedding(m); bind_where(m); + bind_vdot(m); } } // namespace infinicore::ops diff --git a/src/infinicore/pybind11/ops/vdot.hpp b/src/infinicore/pybind11/ops/vdot.hpp new file mode 100644 index 000000000..63a4f85c0 --- /dev/null +++ b/src/infinicore/pybind11/ops/vdot.hpp @@ -0,0 +1,21 @@ +#pragma once + +#include + +#include "infinicore/ops/vdot.hpp" + +namespace py = pybind11; + +namespace infinicore::ops { + +inline void bind_vdot(py::module &m) { + m.def("vdot", + &op::vdot, + py::arg("a"), + py::arg("b"), + R"doc(Vector dot product for 1D tensors (real-valued).)doc"); +} + +} // namespace infinicore::ops + + diff --git a/src/infiniop/ops/vdot/cpu/vdot_cpu.cc b/src/infiniop/ops/vdot/cpu/vdot_cpu.cc new file mode 100644 index 000000000..13682ef48 --- /dev/null +++ b/src/infiniop/ops/vdot/cpu/vdot_cpu.cc @@ -0,0 +1,137 @@ +#include "vdot_cpu.h" +#include "../../../../utils.h" + +namespace op::vdot::cpu { + +Descriptor::~Descriptor() = default; + +infiniStatus_t Descriptor::create( + infiniopHandle_t handle_, + Descriptor **desc_ptr, + infiniopTensorDescriptor_t out_desc, + infiniopTensorDescriptor_t a_desc, + infiniopTensorDescriptor_t b_desc) { + + auto handle = reinterpret_cast(handle_); + auto in_dtype = a_desc->dtype(); + auto b_dtype = b_desc->dtype(); + auto out_dtype = out_desc->dtype(); + + // Inputs must be 1D vectors with same length + if (a_desc->ndim() != 1 || b_desc->ndim() != 1) { + return INFINI_STATUS_BAD_TENSOR_SHAPE; + } + if (a_desc->numel() != b_desc->numel()) { + return INFINI_STATUS_BAD_TENSOR_SHAPE; + } + + // Input dtypes must match and be in supported set + CHECK_OR_RETURN(in_dtype == b_dtype, INFINI_STATUS_BAD_TENSOR_DTYPE); + CHECK_DTYPE( + in_dtype, + INFINI_DTYPE_F16, + INFINI_DTYPE_F32, + INFINI_DTYPE_BF16); + + // Simplest: output dtype equals input dtype + CHECK_OR_RETURN(out_dtype == in_dtype, INFINI_STATUS_BAD_TENSOR_DTYPE); + + size_t length = a_desc->numel(); + ptrdiff_t a_stride = a_desc->stride(0); + ptrdiff_t b_stride = b_desc->stride(0); + + *desc_ptr = new Descriptor( + in_dtype, + out_dtype, + length, + a_stride, + b_stride, + handle->device, + handle->device_id); + + return INFINI_STATUS_SUCCESS; +} + +template +static void vdot_impl( + Tout *out, + const Tin *a_base, + const Tin *b_base, + size_t n, + ptrdiff_t a_stride, + ptrdiff_t b_stride) { + + if constexpr (std::is_same_v || std::is_same_v) { + // Accumulate in float for half/bfloat16, then cast back + float acc = 0.0f; + +#pragma omp parallel for reduction(+ : acc) + for (ptrdiff_t i = 0; i < static_cast(n); ++i) { + const Tin &av = a_base[i * a_stride]; + const Tin &bv = b_base[i * b_stride]; + float av_f = utils::cast(av); + float bv_f = utils::cast(bv); + acc += av_f * bv_f; + } + + *out = utils::cast(acc); + } else { + Tout acc{}; + +#pragma omp parallel for reduction(+ : acc) + for (ptrdiff_t i = 0; i < static_cast(n); ++i) { + const Tin &av = a_base[i * a_stride]; + const Tin &bv = b_base[i * b_stride]; + acc += static_cast(av) * static_cast(bv); + } + + *out = acc; + } +} + +infiniStatus_t Descriptor::calculate( + void *workspace, + size_t workspace_size, + void *out, + const void *a, + const void *b, + void *stream) const { + + switch (_in_dtype) { + case INFINI_DTYPE_F16: + vdot_impl( + reinterpret_cast(out), + reinterpret_cast(a), + reinterpret_cast(b), + _length, + _a_stride, + _b_stride); + break; + case INFINI_DTYPE_F32: + vdot_impl( + reinterpret_cast(out), + reinterpret_cast(a), + 
reinterpret_cast(b), + _length, + _a_stride, + _b_stride); + break; + case INFINI_DTYPE_BF16: + vdot_impl( + reinterpret_cast(out), + reinterpret_cast(a), + reinterpret_cast(b), + _length, + _a_stride, + _b_stride); + break; + default: + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + + return INFINI_STATUS_SUCCESS; +} + +} // namespace op::vdot::cpu + + diff --git a/src/infiniop/ops/vdot/cpu/vdot_cpu.h b/src/infiniop/ops/vdot/cpu/vdot_cpu.h new file mode 100644 index 000000000..528901ba0 --- /dev/null +++ b/src/infiniop/ops/vdot/cpu/vdot_cpu.h @@ -0,0 +1,56 @@ +#ifndef __VDOT_CPU_H__ +#define __VDOT_CPU_H__ + +#include "../../../devices/cpu/common_cpu.h" +#include "../../../tensor.h" +#include "../../../operator.h" + +namespace op::vdot::cpu { + +class Descriptor final : public InfiniopDescriptor { + infiniDtype_t _in_dtype; + infiniDtype_t _out_dtype; + size_t _length; + ptrdiff_t _a_stride; + ptrdiff_t _b_stride; + +public: + Descriptor(infiniDtype_t in_dtype, + infiniDtype_t out_dtype, + size_t length, + ptrdiff_t a_stride, + ptrdiff_t b_stride, + infiniDevice_t device_type, + int device_id) + : InfiniopDescriptor{device_type, device_id}, + _in_dtype(in_dtype), + _out_dtype(out_dtype), + _length(length), + _a_stride(a_stride), + _b_stride(b_stride) {} + + ~Descriptor(); + + size_t workspaceSize() const { return 0; } + + static infiniStatus_t create( + infiniopHandle_t handle, + Descriptor **desc_ptr, + infiniopTensorDescriptor_t out_desc, + infiniopTensorDescriptor_t a_desc, + infiniopTensorDescriptor_t b_desc); + + infiniStatus_t calculate( + void *workspace, + size_t workspace_size, + void *out, + const void *a, + const void *b, + void *stream) const; +}; + +} // namespace op::vdot::cpu + +#endif // __VDOT_CPU_H__ + + diff --git a/src/infiniop/ops/vdot/operator.cc b/src/infiniop/ops/vdot/operator.cc new file mode 100644 index 000000000..2bc570ab8 --- /dev/null +++ b/src/infiniop/ops/vdot/operator.cc @@ -0,0 +1,98 @@ +#include "../../operator.h" +#include "../../handle.h" +#include "infiniop/ops/vdot.h" + +#ifdef ENABLE_CPU_API +#include "cpu/vdot_cpu.h" +#endif + +__C infiniStatus_t infiniopCreateVdotDescriptor( + infiniopHandle_t handle, + infiniopVdotDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t out_desc, + infiniopTensorDescriptor_t a_desc, + infiniopTensorDescriptor_t b_desc) { + +#define CREATE(CASE, NAMESPACE) \ + case CASE: \ + return op::vdot::NAMESPACE::Descriptor::create( \ + handle, \ + reinterpret_cast(desc_ptr), \ + out_desc, \ + a_desc, b_desc) + + switch (handle->device) { +#ifdef ENABLE_CPU_API + CREATE(INFINI_DEVICE_CPU, cpu); +#endif + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } + +#undef CREATE +} + +__C infiniStatus_t infiniopGetVdotWorkspaceSize( + infiniopVdotDescriptor_t desc, + size_t *size) { + +#define GET(CASE, NAMESPACE) \ + case CASE: \ + *size = reinterpret_cast(desc)->workspaceSize(); \ + return INFINI_STATUS_SUCCESS; + + switch (desc->device_type) { +#ifdef ENABLE_CPU_API + GET(INFINI_DEVICE_CPU, cpu) +#endif + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } +#undef GET + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; +} + +__C infiniStatus_t infiniopVdot( + infiniopVdotDescriptor_t desc, + void *workspace, + size_t workspace_size, + void *out, + const void *a, + const void *b, + void *stream) { + +#define CALCULATE(CASE, NAMESPACE) \ + case CASE: \ + return reinterpret_cast(desc) \ + ->calculate(workspace, workspace_size, out, a, b, stream) + + switch (desc->device_type) { +#ifdef ENABLE_CPU_API + 
CALCULATE(INFINI_DEVICE_CPU, cpu);
+#endif
+    default:
+        return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
+    }
+
+#undef CALCULATE
+}
+
+__C infiniStatus_t infiniopDestroyVdotDescriptor(infiniopVdotDescriptor_t desc) {
+
+#define DELETE(CASE, NAMESPACE) \
+    case CASE: \
+        delete reinterpret_cast<const op::vdot::NAMESPACE::Descriptor *>(desc); \
+        return INFINI_STATUS_SUCCESS;
+
+    switch (desc->device_type) {
+#ifdef ENABLE_CPU_API
+        DELETE(INFINI_DEVICE_CPU, cpu);
+#endif
+    default:
+        return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
+    }
+
+#undef DELETE
+}
+
+
diff --git a/test/infinicore/ops/vdot.py b/test/infinicore/ops/vdot.py
index 2baf715f0..e294bc6dc 100644
--- a/test/infinicore/ops/vdot.py
+++ b/test/infinicore/ops/vdot.py
@@ -62,9 +62,9 @@ def get_test_cases(self):
     def torch_operator(self, *args, **kwargs):
         return torch.vdot(*args, **kwargs)

-    # def infinicore_operator(self, *args, **kwargs):
-    #     """InfiniCore implementation (operator not yet available)."""
-    #     return infinicore.vdot(*args, **kwargs)
+    def infinicore_operator(self, *args, **kwargs):
+        """InfiniCore implementation."""
+        return infinicore.vdot(*args, **kwargs)


 def main():
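To summarize the CPU contract before the GPU ports in the next patch: infiniopVdot reduces two equal-length, possibly strided 1D vectors to a scalar, accumulating fp16/bf16 in float. A host-side reference of that reduction (a sketch assuming non-negative element strides, independent of the patch):

#include <cassert>
#include <cstddef>

// out = sum_i a[i * a_stride] * b[i * b_stride], matching the CPU kernel's
// float accumulation path. Strides are counted in elements, as in the descriptor.
static float vdot_ref(const float *a, const float *b, size_t n,
                      ptrdiff_t a_stride, ptrdiff_t b_stride) {
    float acc = 0.0f;
    for (size_t i = 0; i < n; ++i) {
        acc += a[i * a_stride] * b[i * b_stride];
    }
    return acc;
}

int main() {
    const float a[] = {1.f, 2.f, 3.f};
    const float b[] = {4.f, 5.f, 6.f};
    assert(vdot_ref(a, b, 3, 1, 1) == 32.f); // 4 + 10 + 18
    return 0;
}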
Tensor(_infinicore.logical_or(input._underlying, other._underlying))

From 2672106ea916eb44c9075fac689d2e82d7bbae39 Mon Sep 17 00:00:00 2001
From: guozhihao-224
Date: Tue, 16 Dec 2025 01:14:42 +0800
Subject: feat: adapt logical_or/xor and logsigmoid to the ntops interface
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 python/infinicore/nn/functional/logsigmoid.py | 3 +++
 python/infinicore/nn/functional/silu.py       | 2 +-
 python/infinicore/ops/logical_or.py           | 4 ++++
 python/infinicore/ops/logical_xor.py          | 5 +++++
 python/infinicore/ops/logsigmoid.py           | 4 ++++
 5 files changed, 17 insertions(+), 1 deletion(-)

diff --git a/python/infinicore/nn/functional/logsigmoid.py b/python/infinicore/nn/functional/logsigmoid.py
index bb327f5b6..c1bcf2b63 100644
--- a/python/infinicore/nn/functional/logsigmoid.py
+++ b/python/infinicore/nn/functional/logsigmoid.py
@@ -1,9 +1,12 @@
+import infinicore
 from infinicore.lib import _infinicore
 from infinicore.tensor import Tensor


 def logsigmoid(input: Tensor, out=None) -> Tensor:
     """Apply elementwise log-sigmoid."""
+    if infinicore.use_ntops and input.device.type in ("cuda", "musa"):
+        return infinicore.ntops.torch.logsigmoid(input, out = out)

     if out is None:
         return Tensor(_infinicore.logsigmoid(input._underlying))

diff --git a/python/infinicore/nn/functional/silu.py b/python/infinicore/nn/functional/silu.py
index 36c7e4e3e..b2f97d0c0 100644
--- a/python/infinicore/nn/functional/silu.py
+++ b/python/infinicore/nn/functional/silu.py
@@ -6,7 +6,7 @@
 def silu(input: Tensor, inplace: bool = False, *, out=None) -> Tensor:
     r"""Apply the Sigmoid Linear Unit (SiLU) function, element-wise."""

-    if infinicore.use_ntops and input.device.type in ("cuda", "musa") and out is None:
+    if infinicore.use_ntops and input.device.type in ("cuda", "musa"):
         return infinicore.ntops.torch.silu(input, inplace=inplace)

     if inplace:

diff --git a/python/infinicore/ops/logical_or.py b/python/infinicore/ops/logical_or.py
index 5016046a2..01024413e 100644
--- a/python/infinicore/ops/logical_or.py
+++ b/python/infinicore/ops/logical_or.py
@@ -1,8 +1,12 @@
+import infinicore
 from infinicore.lib import _infinicore
 from infinicore.tensor import Tensor


 def logical_or(input, other, *, out=None):
+    if infinicore.use_ntops and input.device.type in ("cuda", "musa"):
+        return infinicore.ntops.torch.logical_or(input, other, out = out)
+
     if out is None:
         return Tensor(_infinicore.logical_or(input._underlying, other._underlying))

diff --git a/python/infinicore/ops/logical_xor.py b/python/infinicore/ops/logical_xor.py
index c4029f6bb..b11380844 100644
--- a/python/infinicore/ops/logical_xor.py
+++ b/python/infinicore/ops/logical_xor.py
@@ -1,8 +1,13 @@
+import infinicore
 from infinicore.lib import _infinicore
 from infinicore.tensor import Tensor


 def logical_xor(input, other, *, out=None):
+    if infinicore.use_ntops and input.device.type in ("cuda", "musa"):
+        return infinicore.ntops.torch.logical_xor(input, other, out = out)
+
+
     if out is None:
         return Tensor(_infinicore.logical_xor(input._underlying, other._underlying))

diff --git a/python/infinicore/ops/logsigmoid.py b/python/infinicore/ops/logsigmoid.py
index 128ced942..cb22c9e70 100644
--- a/python/infinicore/ops/logsigmoid.py
+++ b/python/infinicore/ops/logsigmoid.py
@@ -1,8 +1,12 @@
+import infinicore
 from infinicore.lib import _infinicore
 from infinicore.tensor import Tensor


 def logsigmoid(input, *, out=None):
+    if infinicore.use_ntops and input.device.type in ("cuda", "musa"):
+        return infinicore.ntops.torch.logsigmoid(input, out = out)
+
     if out is None:
         return Tensor(_infinicore.logsigmoid(input._underlying))

From ef08d1e1ce5fcbfccdf5d09ce8b59bbdc1fb6300 Mon Sep 17 00:00:00 2001
From: guozhihao-224
Date: Tue, 16 Dec 2025 15:18:38 +0800
Subject: feat: support vdot GPU operators on all platforms
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 python/infinicore/ops/vdot.py                |   1 +
 src/infiniop/ops/vdot/cuda/kernel.cuh        |  36 ++++
 src/infiniop/ops/vdot/metax/vdot_metax.cu    | 129 ++++++++++++++
 src/infiniop/ops/vdot/metax/vdot_metax.h     |  46 +++++
 src/infiniop/ops/vdot/moore/vdot_moore.cu    | 136 +++++++++++++++
 src/infiniop/ops/vdot/moore/vdot_moore.h     |  46 +++++
 src/infiniop/ops/vdot/nvidia/vdot_nvidia.cu  | 167 +++++++++++++++++++
 src/infiniop/ops/vdot/nvidia/vdot_nvidia.cuh |  47 ++++++
 src/infiniop/ops/vdot/operator.cc            |  69 ++++++++
 9 files changed, 677 insertions(+)
 create mode 100644 src/infiniop/ops/vdot/cuda/kernel.cuh
 create mode 100644 src/infiniop/ops/vdot/metax/vdot_metax.cu
 create mode 100644 src/infiniop/ops/vdot/metax/vdot_metax.h
 create mode 100644 src/infiniop/ops/vdot/moore/vdot_moore.cu
 create mode 100644 src/infiniop/ops/vdot/moore/vdot_moore.h
 create mode 100644 src/infiniop/ops/vdot/nvidia/vdot_nvidia.cu
 create mode 100644 src/infiniop/ops/vdot/nvidia/vdot_nvidia.cuh

diff --git a/python/infinicore/ops/vdot.py b/python/infinicore/ops/vdot.py
index 9085b4ae5..2fd9fd7e6 100644
--- a/python/infinicore/ops/vdot.py
+++ b/python/infinicore/ops/vdot.py
@@ -1,3 +1,4 @@
+import infinicore
 from infinicore.lib import _infinicore
 from infinicore.tensor import Tensor

diff --git a/src/infiniop/ops/vdot/cuda/kernel.cuh b/src/infiniop/ops/vdot/cuda/kernel.cuh
new file mode 100644
index 000000000..6fe772406
--- /dev/null
+++ b/src/infiniop/ops/vdot/cuda/kernel.cuh
@@ -0,0 +1,36 @@
+#ifndef __VDOT_CUDA_KERNEL_CUH__
+#define __VDOT_CUDA_KERNEL_CUH__
+
+#include <cub/cub.cuh>
+
+namespace op::vdot::cuda {
+
+template <unsigned int BLOCK_SIZE, typename Tdata, typename Tcompute>
+__global__ void vdotKernel(Tcompute *out, const Tdata *a, const Tdata *b,
+                           size_t length, ptrdiff_t a_stride,
+                           ptrdiff_t b_stride) {
+
+    Tcompute dot = 0;
+
+    // Each thread computes its partial dot product
+    for (size_t i = threadIdx.x; i < length; i += BLOCK_SIZE) {
+        Tcompute a_val = Tcompute(a[i * a_stride]);
+        Tcompute b_val = Tcompute(b[i * 
b_stride]); + dot += a_val * b_val; + } + + // Use CUB block-level reduction + using BlockReduce = cub::BlockReduce; + __shared__ typename BlockReduce::TempStorage temp_storage; + + Tcompute block_dot = BlockReduce(temp_storage).Sum(dot); + + // Thread 0 writes the result + if (threadIdx.x == 0) { + *out = block_dot; + } +} + +} // namespace op::vdot::cuda + +#endif // __VDOT_CUDA_KERNEL_CUH__ diff --git a/src/infiniop/ops/vdot/metax/vdot_metax.cu b/src/infiniop/ops/vdot/metax/vdot_metax.cu new file mode 100644 index 000000000..a258625e5 --- /dev/null +++ b/src/infiniop/ops/vdot/metax/vdot_metax.cu @@ -0,0 +1,129 @@ +#include "../../../devices/metax/metax_handle.h" +#include "../../../devices/nvidia/nvidia_kernel_common.cuh" +#include "vdot_metax.h" +#include +#include + +namespace op::vdot::metax { + +Descriptor::~Descriptor() = default; + +infiniStatus_t Descriptor::create(infiniopHandle_t handle_, + Descriptor **desc_ptr, + infiniopTensorDescriptor_t out_desc, + infiniopTensorDescriptor_t a_desc, + infiniopTensorDescriptor_t b_desc) { + + auto handle = reinterpret_cast(handle_); + auto in_dtype = a_desc->dtype(); + auto b_dtype = b_desc->dtype(); + auto out_dtype = out_desc->dtype(); + + // Inputs must be 1D vectors with same length + if (a_desc->ndim() != 1 || b_desc->ndim() != 1) { + return INFINI_STATUS_BAD_TENSOR_SHAPE; + } + if (a_desc->numel() != b_desc->numel()) { + return INFINI_STATUS_BAD_TENSOR_SHAPE; + } + + // Input dtypes must match and be in supported set + CHECK_OR_RETURN(in_dtype == b_dtype, INFINI_STATUS_BAD_TENSOR_DTYPE); + CHECK_DTYPE(in_dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_F64, + INFINI_DTYPE_BF16); + + // Output dtype equals input dtype + CHECK_OR_RETURN(out_dtype == in_dtype, INFINI_STATUS_BAD_TENSOR_DTYPE); + + size_t length = a_desc->numel(); + ptrdiff_t a_stride = a_desc->stride(0); + ptrdiff_t b_stride = b_desc->stride(0); + + *desc_ptr = new Descriptor(in_dtype, out_dtype, length, a_stride, b_stride, + handle->device, handle->device_id); + + return INFINI_STATUS_SUCCESS; +} + +infiniStatus_t Descriptor::calculate(void *workspace, size_t workspace_size, + void *out, const void *a, const void *b, + void *stream) const { + + auto cuda_stream = reinterpret_cast(stream); + constexpr unsigned int BLOCK_SIZE = 256; + + switch (_in_dtype) { + case INFINI_DTYPE_F32: { + float *out_f = reinterpret_cast(out); + const float *a_f = reinterpret_cast(a); + const float *b_f = reinterpret_cast(b); + op::vdot::cuda::vdotKernel + <<<1, BLOCK_SIZE, 0, cuda_stream>>>(out_f, a_f, b_f, _length, _a_stride, + _b_stride); + CHECK_CUDA(cudaGetLastError()); + break; + } + case INFINI_DTYPE_F64: { + double *out_d = reinterpret_cast(out); + const double *a_d = reinterpret_cast(a); + const double *b_d = reinterpret_cast(b); + op::vdot::cuda::vdotKernel + <<<1, BLOCK_SIZE, 0, cuda_stream>>>(out_d, a_d, b_d, _length, _a_stride, + _b_stride); + CHECK_CUDA(cudaGetLastError()); + break; + } + case INFINI_DTYPE_F16: { + // For FP16, accumulate in float, then cast back to half + if (workspace_size < sizeof(float)) { + return INFINI_STATUS_INSUFFICIENT_WORKSPACE; + } + float *tmp_out = reinterpret_cast(workspace); + { + const __half *a_h = reinterpret_cast(a); + const __half *b_h = reinterpret_cast(b); + op::vdot::cuda::vdotKernel + <<<1, BLOCK_SIZE, 0, cuda_stream>>>(tmp_out, a_h, b_h, _length, + _a_stride, _b_stride); + CHECK_CUDA(cudaGetLastError()); + float result_f; + CHECK_CUDA(cudaMemcpyAsync(&result_f, tmp_out, sizeof(float), + cudaMemcpyDeviceToHost, cuda_stream)); + 
CHECK_CUDA(cudaStreamSynchronize(cuda_stream)); + __half h_result = __float2half(result_f); + CHECK_CUDA(cudaMemcpyAsync(out, &h_result, sizeof(__half), + cudaMemcpyHostToDevice, cuda_stream)); + } + break; + } + case INFINI_DTYPE_BF16: { + // For BF16, accumulate in float, then cast back + if (workspace_size < sizeof(float)) { + return INFINI_STATUS_INSUFFICIENT_WORKSPACE; + } + float *tmp_out = reinterpret_cast(workspace); + { + const __nv_bfloat16 *a_bf = reinterpret_cast(a); + const __nv_bfloat16 *b_bf = reinterpret_cast(b); + op::vdot::cuda::vdotKernel + <<<1, BLOCK_SIZE, 0, cuda_stream>>>(tmp_out, a_bf, b_bf, _length, + _a_stride, _b_stride); + CHECK_CUDA(cudaGetLastError()); + float result_f; + CHECK_CUDA(cudaMemcpyAsync(&result_f, tmp_out, sizeof(float), + cudaMemcpyDeviceToHost, cuda_stream)); + CHECK_CUDA(cudaStreamSynchronize(cuda_stream)); + __nv_bfloat16 bf_result = __float2bfloat16(result_f); + CHECK_CUDA(cudaMemcpyAsync(out, &bf_result, sizeof(__nv_bfloat16), + cudaMemcpyHostToDevice, cuda_stream)); + } + break; + } + default: + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + + return INFINI_STATUS_SUCCESS; +} + +} // namespace op::vdot::metax diff --git a/src/infiniop/ops/vdot/metax/vdot_metax.h b/src/infiniop/ops/vdot/metax/vdot_metax.h new file mode 100644 index 000000000..ffe85cff2 --- /dev/null +++ b/src/infiniop/ops/vdot/metax/vdot_metax.h @@ -0,0 +1,46 @@ +#ifndef __VDOT_METAX_API_H__ +#define __VDOT_METAX_API_H__ + +#include "../../../devices/metax/metax_handle.h" +#include "../../../operator.h" +#include "../../../tensor.h" +#include "../cuda/kernel.cuh" + +namespace op::vdot::metax { + +class Descriptor final : public InfiniopDescriptor { + infiniDtype_t _in_dtype; + infiniDtype_t _out_dtype; + size_t _length; + ptrdiff_t _a_stride; + ptrdiff_t _b_stride; + +public: + Descriptor(infiniDtype_t in_dtype, infiniDtype_t out_dtype, size_t length, + ptrdiff_t a_stride, ptrdiff_t b_stride, infiniDevice_t device_type, + int device_id) + : InfiniopDescriptor{device_type, device_id}, _in_dtype(in_dtype), + _out_dtype(out_dtype), _length(length), _a_stride(a_stride), + _b_stride(b_stride) {} + + ~Descriptor(); + + size_t workspaceSize() const { + // Need workspace for FP16/BF16 to accumulate in float + return (_in_dtype == INFINI_DTYPE_F16 || _in_dtype == INFINI_DTYPE_BF16) + ? 
sizeof(float) + : 0; + } + + static infiniStatus_t create(infiniopHandle_t handle, Descriptor **desc_ptr, + infiniopTensorDescriptor_t out_desc, + infiniopTensorDescriptor_t a_desc, + infiniopTensorDescriptor_t b_desc); + + infiniStatus_t calculate(void *workspace, size_t workspace_size, void *out, + const void *a, const void *b, void *stream) const; +}; + +} // namespace op::vdot::metax + +#endif // __VDOT_METAX_API_H__ diff --git a/src/infiniop/ops/vdot/moore/vdot_moore.cu b/src/infiniop/ops/vdot/moore/vdot_moore.cu new file mode 100644 index 000000000..331123170 --- /dev/null +++ b/src/infiniop/ops/vdot/moore/vdot_moore.cu @@ -0,0 +1,136 @@ +#include "../../../devices/moore/moore_handle.h" +#include "../../../devices/nvidia/nvidia_kernel_common.cuh" +#include "vdot_moore.h" +#include +#include + +namespace op::vdot::moore { + +Descriptor::~Descriptor() = default; + +infiniStatus_t Descriptor::create(infiniopHandle_t handle_, + Descriptor **desc_ptr, + infiniopTensorDescriptor_t out_desc, + infiniopTensorDescriptor_t a_desc, + infiniopTensorDescriptor_t b_desc) { + + auto handle = reinterpret_cast(handle_); + auto in_dtype = a_desc->dtype(); + auto b_dtype = b_desc->dtype(); + auto out_dtype = out_desc->dtype(); + + // Inputs must be 1D vectors with same length + if (a_desc->ndim() != 1 || b_desc->ndim() != 1) { + return INFINI_STATUS_BAD_TENSOR_SHAPE; + } + if (a_desc->numel() != b_desc->numel()) { + return INFINI_STATUS_BAD_TENSOR_SHAPE; + } + + // Input dtypes must match and be in supported set + CHECK_OR_RETURN(in_dtype == b_dtype, INFINI_STATUS_BAD_TENSOR_DTYPE); + CHECK_DTYPE(in_dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_F64, + INFINI_DTYPE_BF16); + + // Output dtype equals input dtype + CHECK_OR_RETURN(out_dtype == in_dtype, INFINI_STATUS_BAD_TENSOR_DTYPE); + + size_t length = a_desc->numel(); + ptrdiff_t a_stride = a_desc->stride(0); + ptrdiff_t b_stride = b_desc->stride(0); + + *desc_ptr = new Descriptor(in_dtype, out_dtype, length, a_stride, b_stride, + handle->device, handle->device_id); + + return INFINI_STATUS_SUCCESS; +} + +infiniStatus_t Descriptor::calculate(void *workspace, size_t workspace_size, + void *out, const void *a, const void *b, + void *stream) const { + + auto cuda_stream = reinterpret_cast(stream); + constexpr unsigned int BLOCK_SIZE = 256; + + switch (_in_dtype) { + case INFINI_DTYPE_F32: { + float *out_f = reinterpret_cast(out); + const float *a_f = reinterpret_cast(a); + const float *b_f = reinterpret_cast(b); + op::vdot::cuda::vdotKernel + <<<1, BLOCK_SIZE, 0, cuda_stream>>>(out_f, a_f, b_f, _length, _a_stride, + _b_stride); + CHECK_CUDA(cudaGetLastError()); + break; + } + case INFINI_DTYPE_F64: { + double *out_d = reinterpret_cast(out); + const double *a_d = reinterpret_cast(a); + const double *b_d = reinterpret_cast(b); + op::vdot::cuda::vdotKernel + <<<1, BLOCK_SIZE, 0, cuda_stream>>>(out_d, a_d, b_d, _length, _a_stride, + _b_stride); + CHECK_CUDA(cudaGetLastError()); + break; + } + case INFINI_DTYPE_F16: { + // For FP16, accumulate in float, then cast back to half + // Use workspace for temporary float buffer + if (workspace_size < sizeof(float)) { + return INFINI_STATUS_INSUFFICIENT_WORKSPACE; + } + float *tmp_out = reinterpret_cast(workspace); + { + // If workspace is too small, we need to allocate + // For simplicity, use a device-side kernel that writes directly to out + // But we need float accumulation, so use a temporary approach + const __half *a_h = reinterpret_cast(a); + const __half *b_h = reinterpret_cast(b); + // Launch 
kernel that accumulates in float and writes half result + op::vdot::cuda::vdotKernel + <<<1, BLOCK_SIZE, 0, cuda_stream>>>(tmp_out, a_h, b_h, _length, + _a_stride, _b_stride); + CHECK_CUDA(cudaGetLastError()); + // Use a simple device kernel to cast float to half + // For now, copy to host, cast, and copy back + float result_f; + CHECK_CUDA(cudaMemcpyAsync(&result_f, tmp_out, sizeof(float), + cudaMemcpyDeviceToHost, cuda_stream)); + CHECK_CUDA(cudaStreamSynchronize(cuda_stream)); + __half h_result = __float2half(result_f); + CHECK_CUDA(cudaMemcpyAsync(out, &h_result, sizeof(__half), + cudaMemcpyHostToDevice, cuda_stream)); + } + break; + } + case INFINI_DTYPE_BF16: { + // For BF16, accumulate in float, then cast back + if (workspace_size < sizeof(float)) { + return INFINI_STATUS_INSUFFICIENT_WORKSPACE; + } + float *tmp_out = reinterpret_cast(workspace); + { + const __nv_bfloat16 *a_bf = reinterpret_cast(a); + const __nv_bfloat16 *b_bf = reinterpret_cast(b); + op::vdot::cuda::vdotKernel + <<<1, BLOCK_SIZE, 0, cuda_stream>>>(tmp_out, a_bf, b_bf, _length, + _a_stride, _b_stride); + CHECK_CUDA(cudaGetLastError()); + float result_f; + CHECK_CUDA(cudaMemcpyAsync(&result_f, tmp_out, sizeof(float), + cudaMemcpyDeviceToHost, cuda_stream)); + CHECK_CUDA(cudaStreamSynchronize(cuda_stream)); + __nv_bfloat16 bf_result = __float2bfloat16(result_f); + CHECK_CUDA(cudaMemcpyAsync(out, &bf_result, sizeof(__nv_bfloat16), + cudaMemcpyHostToDevice, cuda_stream)); + } + break; + } + default: + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + + return INFINI_STATUS_SUCCESS; +} + +} // namespace op::vdot::moore diff --git a/src/infiniop/ops/vdot/moore/vdot_moore.h b/src/infiniop/ops/vdot/moore/vdot_moore.h new file mode 100644 index 000000000..6bd36593d --- /dev/null +++ b/src/infiniop/ops/vdot/moore/vdot_moore.h @@ -0,0 +1,46 @@ +#ifndef __VDOT_MOORE_API_H__ +#define __VDOT_MOORE_API_H__ + +#include "../../../devices/moore/moore_handle.h" +#include "../../../operator.h" +#include "../../../tensor.h" +#include "../cuda/kernel.cuh" + +namespace op::vdot::moore { + +class Descriptor final : public InfiniopDescriptor { + infiniDtype_t _in_dtype; + infiniDtype_t _out_dtype; + size_t _length; + ptrdiff_t _a_stride; + ptrdiff_t _b_stride; + +public: + Descriptor(infiniDtype_t in_dtype, infiniDtype_t out_dtype, size_t length, + ptrdiff_t a_stride, ptrdiff_t b_stride, infiniDevice_t device_type, + int device_id) + : InfiniopDescriptor{device_type, device_id}, _in_dtype(in_dtype), + _out_dtype(out_dtype), _length(length), _a_stride(a_stride), + _b_stride(b_stride) {} + + ~Descriptor(); + + size_t workspaceSize() const { + // Need workspace for FP16/BF16 to accumulate in float + return (_in_dtype == INFINI_DTYPE_F16 || _in_dtype == INFINI_DTYPE_BF16) + ? 
sizeof(float) + : 0; + } + + static infiniStatus_t create(infiniopHandle_t handle, Descriptor **desc_ptr, + infiniopTensorDescriptor_t out_desc, + infiniopTensorDescriptor_t a_desc, + infiniopTensorDescriptor_t b_desc); + + infiniStatus_t calculate(void *workspace, size_t workspace_size, void *out, + const void *a, const void *b, void *stream) const; +}; + +} // namespace op::vdot::moore + +#endif // __VDOT_MOORE_API_H__ diff --git a/src/infiniop/ops/vdot/nvidia/vdot_nvidia.cu b/src/infiniop/ops/vdot/nvidia/vdot_nvidia.cu new file mode 100644 index 000000000..356e16901 --- /dev/null +++ b/src/infiniop/ops/vdot/nvidia/vdot_nvidia.cu @@ -0,0 +1,167 @@ +#include "../../../devices/nvidia/nvidia_handle.cuh" +#include "../../../devices/nvidia/nvidia_kernel_common.cuh" +#include "../cuda/kernel.cuh" +#include "vdot_nvidia.cuh" + +namespace op::vdot::nvidia { + +Descriptor::~Descriptor() = default; + +infiniStatus_t Descriptor::create(infiniopHandle_t handle_, + Descriptor **desc_ptr, + infiniopTensorDescriptor_t out_desc, + infiniopTensorDescriptor_t a_desc, + infiniopTensorDescriptor_t b_desc) { + + auto handle = reinterpret_cast(handle_); + auto in_dtype = a_desc->dtype(); + auto b_dtype = b_desc->dtype(); + auto out_dtype = out_desc->dtype(); + + // Inputs must be 1D vectors with same length + if (a_desc->ndim() != 1 || b_desc->ndim() != 1) { + return INFINI_STATUS_BAD_TENSOR_SHAPE; + } + if (a_desc->numel() != b_desc->numel()) { + return INFINI_STATUS_BAD_TENSOR_SHAPE; + } + + // Input dtypes must match and be in supported set + CHECK_OR_RETURN(in_dtype == b_dtype, INFINI_STATUS_BAD_TENSOR_DTYPE); + CHECK_DTYPE(in_dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_F64, + INFINI_DTYPE_BF16); + + // Output dtype equals input dtype + CHECK_OR_RETURN(out_dtype == in_dtype, INFINI_STATUS_BAD_TENSOR_DTYPE); + + size_t length = a_desc->numel(); + ptrdiff_t a_stride = a_desc->stride(0); + ptrdiff_t b_stride = b_desc->stride(0); + + *desc_ptr = + new Descriptor(in_dtype, out_dtype, length, a_stride, b_stride, + handle->internal(), handle->device, handle->device_id); + + return INFINI_STATUS_SUCCESS; +} + +infiniStatus_t Descriptor::calculate(void *workspace, size_t workspace_size, + void *out, const void *a, const void *b, + void *stream) const { + + auto cuda_stream = reinterpret_cast(stream); + + // For FP16/BF16, use CUDA kernel instead of cuBLAS + if (_in_dtype == INFINI_DTYPE_F16 || _in_dtype == INFINI_DTYPE_BF16) { + switch (_in_dtype) { + case INFINI_DTYPE_F16: { + if (workspace_size < sizeof(float)) { + return INFINI_STATUS_INSUFFICIENT_WORKSPACE; + } + float *tmp_out = reinterpret_cast(workspace); + const __half *a_h = reinterpret_cast(a); + const __half *b_h = reinterpret_cast(b); + constexpr unsigned int BLOCK_SIZE = 256; + op::vdot::cuda::vdotKernel + <<<1, BLOCK_SIZE, 0, cuda_stream>>>(tmp_out, a_h, b_h, _length, + _a_stride, _b_stride); + CHECK_CUDA(cudaGetLastError()); + float result_f; + CHECK_CUDA(cudaMemcpyAsync(&result_f, tmp_out, sizeof(float), + cudaMemcpyDeviceToHost, cuda_stream)); + CHECK_CUDA(cudaStreamSynchronize(cuda_stream)); + __half h_result = __float2half(result_f); + CHECK_CUDA(cudaMemcpyAsync(out, &h_result, sizeof(__half), + cudaMemcpyHostToDevice, cuda_stream)); + break; + } + case INFINI_DTYPE_BF16: { + if (workspace_size < sizeof(float)) { + return INFINI_STATUS_INSUFFICIENT_WORKSPACE; + } + float *tmp_out = reinterpret_cast(workspace); + const __nv_bfloat16 *a_bf = reinterpret_cast(a); + const __nv_bfloat16 *b_bf = reinterpret_cast(b); + constexpr unsigned 
int BLOCK_SIZE = 256; + op::vdot::cuda::vdotKernel + <<<1, BLOCK_SIZE, 0, cuda_stream>>>(tmp_out, a_bf, b_bf, _length, + _a_stride, _b_stride); + CHECK_CUDA(cudaGetLastError()); + float result_f; + CHECK_CUDA(cudaMemcpyAsync(&result_f, tmp_out, sizeof(float), + cudaMemcpyDeviceToHost, cuda_stream)); + CHECK_CUDA(cudaStreamSynchronize(cuda_stream)); + __nv_bfloat16 bf_result = __float2bfloat16(result_f); + CHECK_CUDA(cudaMemcpyAsync(out, &bf_result, sizeof(__nv_bfloat16), + cudaMemcpyHostToDevice, cuda_stream)); + break; + } + default: + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + return INFINI_STATUS_SUCCESS; + } + + // Use cuBLAS for F32/F64 + CHECK_STATUS(_internal->useCublas(cuda_stream, [&](cublasHandle_t handle) { + switch (_in_dtype) { + case INFINI_DTYPE_F32: { + if (_a_stride == 1 && _b_stride == 1) { + // Contiguous case: use cublasSdot + float result; + CHECK_CUBLAS(cublasSdot(handle, static_cast(_length), + reinterpret_cast(a), + static_cast(_a_stride), + reinterpret_cast(b), + static_cast(_b_stride), &result)); + CHECK_CUDA(cudaMemcpyAsync(out, &result, sizeof(float), + cudaMemcpyHostToDevice, cuda_stream)); + } else { + // Strided case: use cublasDotEx + float result; + CHECK_CUBLAS(cublasDotEx(handle, static_cast(_length), + reinterpret_cast(a), CUDA_R_32F, + static_cast(_a_stride), + reinterpret_cast(b), CUDA_R_32F, + static_cast(_b_stride), &result, + CUDA_R_32F, CUDA_R_32F)); + CHECK_CUDA(cudaMemcpyAsync(out, &result, sizeof(float), + cudaMemcpyHostToDevice, cuda_stream)); + } + break; + } + case INFINI_DTYPE_F64: { + if (_a_stride == 1 && _b_stride == 1) { + // Contiguous case: use cublasDdot + double result; + CHECK_CUBLAS(cublasDdot(handle, static_cast(_length), + reinterpret_cast(a), + static_cast(_a_stride), + reinterpret_cast(b), + static_cast(_b_stride), &result)); + CHECK_CUDA(cudaMemcpyAsync(out, &result, sizeof(double), + cudaMemcpyHostToDevice, cuda_stream)); + } else { + // Strided case: use cublasDotEx + double result; + CHECK_CUBLAS(cublasDotEx(handle, static_cast(_length), + reinterpret_cast(a), + CUDA_R_64F, static_cast(_a_stride), + reinterpret_cast(b), + CUDA_R_64F, static_cast(_b_stride), + &result, CUDA_R_64F, CUDA_R_64F)); + CHECK_CUDA(cudaMemcpyAsync(out, &result, sizeof(double), + cudaMemcpyHostToDevice, cuda_stream)); + } + break; + } + default: + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + return INFINI_STATUS_SUCCESS; + })); + + return INFINI_STATUS_SUCCESS; +} + +} // namespace op::vdot::nvidia diff --git a/src/infiniop/ops/vdot/nvidia/vdot_nvidia.cuh b/src/infiniop/ops/vdot/nvidia/vdot_nvidia.cuh new file mode 100644 index 000000000..9c25fc140 --- /dev/null +++ b/src/infiniop/ops/vdot/nvidia/vdot_nvidia.cuh @@ -0,0 +1,47 @@ +#ifndef __VDOT_NVIDIA_CUH__ +#define __VDOT_NVIDIA_CUH__ + +#include "../../../devices/nvidia/nvidia_handle.cuh" +#include "../../../operator.h" +#include "../../../tensor.h" + +namespace op::vdot::nvidia { + +class Descriptor final : public InfiniopDescriptor { + infiniDtype_t _in_dtype; + infiniDtype_t _out_dtype; + size_t _length; + ptrdiff_t _a_stride; + ptrdiff_t _b_stride; + std::shared_ptr _internal; + +public: + Descriptor(infiniDtype_t in_dtype, infiniDtype_t out_dtype, size_t length, + ptrdiff_t a_stride, ptrdiff_t b_stride, + std::shared_ptr internal, + infiniDevice_t device_type, int device_id) + : InfiniopDescriptor{device_type, device_id}, _in_dtype(in_dtype), + _out_dtype(out_dtype), _length(length), _a_stride(a_stride), + _b_stride(b_stride), _internal(internal) {} + + ~Descriptor(); + + size_t 
workspaceSize() const { + // Need workspace for FP16/BF16 to accumulate in float + return (_in_dtype == INFINI_DTYPE_F16 || _in_dtype == INFINI_DTYPE_BF16) + ? sizeof(float) + : 0; + } + + static infiniStatus_t create(infiniopHandle_t handle, Descriptor **desc_ptr, + infiniopTensorDescriptor_t out_desc, + infiniopTensorDescriptor_t a_desc, + infiniopTensorDescriptor_t b_desc); + + infiniStatus_t calculate(void *workspace, size_t workspace_size, void *out, + const void *a, const void *b, void *stream) const; +}; + +} // namespace op::vdot::nvidia + +#endif // __VDOT_NVIDIA_CUH__ diff --git a/src/infiniop/ops/vdot/operator.cc b/src/infiniop/ops/vdot/operator.cc index 2bc570ab8..bbf3b088c 100644 --- a/src/infiniop/ops/vdot/operator.cc +++ b/src/infiniop/ops/vdot/operator.cc @@ -5,6 +5,15 @@ #ifdef ENABLE_CPU_API #include "cpu/vdot_cpu.h" #endif +#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API) +#include "nvidia/vdot_nvidia.cuh" +#endif +#ifdef ENABLE_METAX_API +#include "metax/vdot_metax.h" +#endif +#ifdef ENABLE_MOORE_API +#include "moore/vdot_moore.h" +#endif __C infiniStatus_t infiniopCreateVdotDescriptor( infiniopHandle_t handle, @@ -24,6 +33,21 @@ __C infiniStatus_t infiniopCreateVdotDescriptor( switch (handle->device) { #ifdef ENABLE_CPU_API CREATE(INFINI_DEVICE_CPU, cpu); +#endif +#ifdef ENABLE_NVIDIA_API + CREATE(INFINI_DEVICE_NVIDIA, nvidia); +#endif +#ifdef ENABLE_ILUVATAR_API + CREATE(INFINI_DEVICE_ILUVATAR, nvidia); +#endif +#ifdef ENABLE_QY_API + CREATE(INFINI_DEVICE_QY, nvidia); +#endif +#ifdef ENABLE_METAX_API + CREATE(INFINI_DEVICE_METAX, metax); +#endif +#ifdef ENABLE_MOORE_API + CREATE(INFINI_DEVICE_MOORE, moore); #endif default: return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; @@ -44,6 +68,21 @@ __C infiniStatus_t infiniopGetVdotWorkspaceSize( switch (desc->device_type) { #ifdef ENABLE_CPU_API GET(INFINI_DEVICE_CPU, cpu) +#endif +#ifdef ENABLE_NVIDIA_API + GET(INFINI_DEVICE_NVIDIA, nvidia) +#endif +#ifdef ENABLE_ILUVATAR_API + GET(INFINI_DEVICE_ILUVATAR, nvidia) +#endif +#ifdef ENABLE_QY_API + GET(INFINI_DEVICE_QY, nvidia) +#endif +#ifdef ENABLE_METAX_API + GET(INFINI_DEVICE_METAX, metax) +#endif +#ifdef ENABLE_MOORE_API + GET(INFINI_DEVICE_MOORE, moore) #endif default: return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; @@ -69,6 +108,21 @@ __C infiniStatus_t infiniopVdot( switch (desc->device_type) { #ifdef ENABLE_CPU_API CALCULATE(INFINI_DEVICE_CPU, cpu); +#endif +#ifdef ENABLE_NVIDIA_API + CALCULATE(INFINI_DEVICE_NVIDIA, nvidia); +#endif +#ifdef ENABLE_ILUVATAR_API + CALCULATE(INFINI_DEVICE_ILUVATAR, nvidia); +#endif +#ifdef ENABLE_QY_API + CALCULATE(INFINI_DEVICE_QY, nvidia); +#endif +#ifdef ENABLE_METAX_API + CALCULATE(INFINI_DEVICE_METAX, metax); +#endif +#ifdef ENABLE_MOORE_API + CALCULATE(INFINI_DEVICE_MOORE, moore); #endif default: return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; @@ -87,6 +141,21 @@ __C infiniStatus_t infiniopDestroyVdotDescriptor(infiniopVdotDescriptor_t desc) switch (desc->device_type) { #ifdef ENABLE_CPU_API DELETE(INFINI_DEVICE_CPU, cpu); +#endif +#ifdef ENABLE_NVIDIA_API + DELETE(INFINI_DEVICE_NVIDIA, nvidia); +#endif +#ifdef ENABLE_ILUVATAR_API + DELETE(INFINI_DEVICE_ILUVATAR, nvidia); +#endif +#ifdef ENABLE_QY_API + DELETE(INFINI_DEVICE_QY, nvidia); +#endif +#ifdef ENABLE_METAX_API + DELETE(INFINI_DEVICE_METAX, metax); +#endif +#ifdef ENABLE_MOORE_API + DELETE(INFINI_DEVICE_MOORE, moore); #endif default: return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; From fba66e8f819b892024a2fa38c05b15eed562c8d9 Mon Sep 17 
00:00:00 2001
From: guozhihao-224
Date: Tue, 16 Dec 2025 20:13:09 +0800
Subject: feat: adapt the where CUDA operator and fix formatting
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 include/infinicore/ops/where.hpp              |  10 +
 include/infiniop/ops/where.h                  |  24 ++
 python/infinicore/nn/functional/logsigmoid.py |   4 +-
 python/infinicore/ops/logical_or.py           |   3 +-
 python/infinicore/ops/logical_xor.py          |   4 +-
 python/infinicore/ops/logsigmoid.py           |   3 +-
 python/infinicore/ops/vdot.py                 |   2 -
 python/infinicore/ops/where.py                |  29 +-
 src/infinicore/ops/logical_or/logical_or.cc   |   1 -
 .../ops/logical_or/logical_or_infiniop.cc     |   1 -
 src/infinicore/ops/logical_xor/logical_xor.cc |   1 -
 .../ops/logical_xor/logical_xor_infiniop.cc   |   1 -
 src/infinicore/ops/logsigmoid/logsigmoid.cc   |   1 -
 .../ops/logsigmoid/logsigmoid_infiniop.cc     |   1 -
 src/infinicore/ops/vdot/vdot.cc               |   2 -
 src/infinicore/ops/vdot/vdot_infiniop.cc      |   2 -
 src/infinicore/ops/where/where.cc             |  23 +-
 .../ops/where/where_indices_infiniop.cc       |  85 ++++++
 src/infinicore/ops/where/where_infiniop.cc    |   2 -
 src/infinicore/pybind11/ops/where.hpp         |   5 +
 .../ops/logical_or/cpu/logical_or_cpu.cc      |   1 -
 .../ops/logical_or/cpu/logical_or_cpu.h       |   1 -
 src/infiniop/ops/logical_or/operator.cc       |  27 +-
 .../ops/logical_xor/cpu/logical_xor_cpu.cc    |   1 -
 .../ops/logical_xor/cpu/logical_xor_cpu.h     |   1 -
 src/infiniop/ops/logical_xor/operator.cc      |   7 +-
 .../ops/logsigmoid/cpu/logsigmoid_cpu.cc      |   1 -
 .../ops/logsigmoid/cpu/logsigmoid_cpu.h       |   1 -
 src/infiniop/ops/logsigmoid/operator.cc       |  27 +-
 src/infiniop/ops/vdot/cpu/vdot_cpu.cc         |   2 -
 src/infiniop/ops/vdot/cpu/vdot_cpu.h          |   4 +-
 src/infiniop/ops/vdot/cuda/kernel.cuh         |  30 +-
 src/infiniop/ops/vdot/metax/vdot_metax.cu     | 184 ++++++------
 src/infiniop/ops/vdot/metax/vdot_metax.h      |  56 ++--
 src/infiniop/ops/vdot/moore/vdot_moore.cu     | 198 ++++++------
 src/infiniop/ops/vdot/moore/vdot_moore.h      |  56 ++--
 src/infiniop/ops/vdot/nvidia/vdot_nvidia.cu   | 281 +++++++++---------
 src/infiniop/ops/vdot/nvidia/vdot_nvidia.cuh  |  60 ++--
 src/infiniop/ops/vdot/operator.cc             |  24 +-
 src/infiniop/ops/where/cpu/where_cpu.cc       |   2 -
 src/infiniop/ops/where/cpu/where_cpu.h        |   2 -
 .../ops/where/cpu/where_indices_cpu.cc        |  86 ++++++
 .../ops/where/cpu/where_indices_cpu.h         |  57 ++++
 .../ops/where/cuda/where_indices_kernel.cuh   |  74 +++++
 .../ops/where/metax/where_indices_metax.cu    | 139 +++++++++
 .../ops/where/metax/where_indices_metax.h     |  58 ++++
 .../ops/where/moore/where_indices_moore.cu    | 153 ++++++++++
 .../ops/where/moore/where_indices_moore.h     |  58 ++++
 .../ops/where/nvidia/where_indices_nvidia.cu  | 156 ++++++++++
 .../ops/where/nvidia/where_indices_nvidia.h   |  61 ++++
 src/infiniop/ops/where/operator.cc            | 183 +++++++++++-
 51 files changed, 1642 insertions(+), 553 deletions(-)
 create mode 100644 src/infinicore/ops/where/where_indices_infiniop.cc
 create mode 100644 src/infiniop/ops/where/cpu/where_indices_cpu.cc
 create mode 100644 src/infiniop/ops/where/cpu/where_indices_cpu.h
 create mode 100644 src/infiniop/ops/where/cuda/where_indices_kernel.cuh
 create mode 100644 src/infiniop/ops/where/metax/where_indices_metax.cu
 create mode 100644 src/infiniop/ops/where/metax/where_indices_metax.h
 create mode 100644 src/infiniop/ops/where/moore/where_indices_moore.cu
 create mode 100644 src/infiniop/ops/where/moore/where_indices_moore.h
 create mode 100644 src/infiniop/ops/where/nvidia/where_indices_nvidia.cu
 create mode 100644 src/infiniop/ops/where/nvidia/where_indices_nvidia.h
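The where_indices entry points introduced below mirror torch.where(cond): for an NDIM condition tensor they fill NDIM index outputs and report the number of true elements through num_true. A reference sketch of that unraveling for a row-major 2D mask (hypothetical helper, not taken from the patch):

#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Collect (row, col) coordinates of true elements in row-major order,
// and report how many were found -- the num_true output of the C API.
static size_t where_indices_2d(const std::vector<bool> &cond, size_t rows, size_t cols,
                               std::vector<int64_t> &row_idx, std::vector<int64_t> &col_idx) {
    row_idx.clear();
    col_idx.clear();
    for (size_t r = 0; r < rows; ++r) {
        for (size_t c = 0; c < cols; ++c) {
            if (cond[r * cols + c]) {
                row_idx.push_back(static_cast<int64_t>(r));
                col_idx.push_back(static_cast<int64_t>(c));
            }
        }
    }
    return row_idx.size();
}

int main() {
    std::vector<bool> cond = {false, true, false,
                              true,  false, true};
    std::vector<int64_t> rows, cols;
    size_t n = where_indices_2d(cond, 2, 3, rows, cols);
    assert(n == 3);
    assert(rows[0] == 0 && cols[0] == 1); // first true element at (0, 1)
    return 0;
}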
diff --git a/include/infinicore/ops/where.hpp b/include/infinicore/ops/where.hpp index 72f1b8506..85bfce417 100644 --- a/include/infinicore/ops/where.hpp +++ b/include/infinicore/ops/where.hpp @@ -15,6 +15,16 @@ class Where { Tensor where(Tensor cond, Tensor x, Tensor y); void where_(Tensor out, Tensor cond, Tensor x, Tensor y); +class WhereIndices { +public: + using schema = std::vector<Tensor> (*)(Tensor); + static std::vector<Tensor> execute(Tensor cond); + static common::OpDispatcher<schema> &dispatcher(); +}; + +// where(cond) -> tuple of index tensors +std::vector<Tensor> where_indices(Tensor cond); + } // namespace infinicore::op diff --git a/include/infiniop/ops/where.h b/include/infiniop/ops/where.h index 95f8dd6f7..196720cfb 100644 --- a/include/infiniop/ops/where.h +++ b/include/infiniop/ops/where.h @@ -31,6 +31,30 @@ __C __export infiniStatus_t infiniopWhere( __C __export infiniStatus_t infiniopDestroyWhereDescriptor( infiniopWhereDescriptor_t desc); +// where(cond) -> indices tuple +typedef struct InfiniopDescriptor *infiniopWhereIndicesDescriptor_t; + +__C __export infiniStatus_t infiniopCreateWhereIndicesDescriptor( + infiniopHandle_t handle, + infiniopWhereIndicesDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t cond_desc); + +__C __export infiniStatus_t infiniopGetWhereIndicesWorkspaceSize( + infiniopWhereIndicesDescriptor_t desc, + size_t *size); + +__C __export infiniStatus_t infiniopWhereIndices( + infiniopWhereIndicesDescriptor_t desc, + void *workspace, + size_t workspace_size, + void **outputs, // array of pointers to the NDIM output tensors + const void *cond, + void *stream, + size_t *num_true); // out: the number of True elements + +__C __export infiniStatus_t infiniopDestroyWhereIndicesDescriptor( + infiniopWhereIndicesDescriptor_t desc); + #endif diff --git a/python/infinicore/nn/functional/logsigmoid.py b/python/infinicore/nn/functional/logsigmoid.py index c1bcf2b63..e79f16f65 100644 --- a/python/infinicore/nn/functional/logsigmoid.py +++ b/python/infinicore/nn/functional/logsigmoid.py @@ -6,11 +6,9 @@ def logsigmoid(input: Tensor, out=None) -> Tensor: """Apply elementwise log-sigmoid.""" if infinicore.use_ntops and input.device.type in ("cuda", "musa"): - return infinicore.ntops.torch.logsigmoid(input, out = out) + return infinicore.ntops.torch.logsigmoid(input, out=out) if out is None: return Tensor(_infinicore.logsigmoid(input._underlying)) _infinicore.logsigmoid_(out._underlying, input._underlying) return out - - diff --git a/python/infinicore/ops/logical_or.py b/python/infinicore/ops/logical_or.py index 01024413e..b83e23c88 100644 --- a/python/infinicore/ops/logical_or.py +++ b/python/infinicore/ops/logical_or.py @@ -5,7 +5,7 @@ def logical_or(input, other, *, out=None): if infinicore.use_ntops and input.device.type in ("cuda", "musa"): - return infinicore.ntops.torch.logical_or(input, other, out = out) + return infinicore.ntops.torch.logical_or(input, other, out=out) if out is None: return Tensor(_infinicore.logical_or(input._underlying, other._underlying)) @@ -13,4 +13,3 @@ def logical_or(input, other, *, out=None): _infinicore.logical_or_(out._underlying, input._underlying, other._underlying) return out - diff --git a/python/infinicore/ops/logical_xor.py b/python/infinicore/ops/logical_xor.py index b11380844..5065d6e8d 100644 --- a/python/infinicore/ops/logical_xor.py +++ b/python/infinicore/ops/logical_xor.py @@ -5,8 +5,7 @@ def logical_xor(input, other, *, out=None): if infinicore.use_ntops and input.device.type in ("cuda", "musa"): - return
infinicore.ntops.torch.logical_xor(input, other, out = out) - + return infinicore.ntops.torch.logical_xor(input, other, out=out) if out is None: return Tensor(_infinicore.logical_xor(input._underlying, other._underlying)) @@ -14,4 +13,3 @@ def logical_xor(input, other, *, out=None): _infinicore.logical_xor_(out._underlying, input._underlying, other._underlying) return out - diff --git a/python/infinicore/ops/logsigmoid.py b/python/infinicore/ops/logsigmoid.py index cb22c9e70..3a0b85d39 100644 --- a/python/infinicore/ops/logsigmoid.py +++ b/python/infinicore/ops/logsigmoid.py @@ -5,7 +5,7 @@ def logsigmoid(input, *, out=None): if infinicore.use_ntops and input.device.type in ("cuda", "musa"): - return infinicore.ntops.torch.logsigmoid(input, out = out) + return infinicore.ntops.torch.logsigmoid(input, out=out) if out is None: return Tensor(_infinicore.logsigmoid(input._underlying)) @@ -13,4 +13,3 @@ def logsigmoid(input, *, out=None): _infinicore.logsigmoid_(out._underlying, input._underlying) return out - diff --git a/python/infinicore/ops/vdot.py b/python/infinicore/ops/vdot.py index 2fd9fd7e6..2928a1d20 100644 --- a/python/infinicore/ops/vdot.py +++ b/python/infinicore/ops/vdot.py @@ -9,5 +9,3 @@ def vdot(a: Tensor, b: Tensor) -> Tensor: for real-valued tensors (no complex conjugation). """ return Tensor(_infinicore.vdot(a._underlying, b._underlying)) - - diff --git a/python/infinicore/ops/where.py b/python/infinicore/ops/where.py index 18650bd98..552f3b1fb 100644 --- a/python/infinicore/ops/where.py +++ b/python/infinicore/ops/where.py @@ -1,5 +1,6 @@ from infinicore.lib import _infinicore from infinicore.tensor import Tensor, from_torch +import infinicore import torch @@ -18,35 +19,21 @@ def where(*args, out=None): if len(args) == 1: cond = args[0] - # Prefer using the original Torch tensor reference when available - cond_torch = getattr(cond, "_torch_ref", None) - if cond_torch is None: - # Fallback: create a Torch tensor, then copy data from infinicore tensor. - # Tests use CPU bool tensors for condition-only where. - cond_torch = torch.zeros( - cond.shape, - dtype=torch.bool, - device="cpu", - ) - # Share storage between Torch tensor and an infinicore view, then copy. - ic_view = from_torch(cond_torch) - ic_view.copy_(cond) - - idx_tensors = torch.where(cond_torch) - # torch.where(cond) returns a tuple of index tensors; mirror that with - # infinicore tensors sharing the same underlying storage. 
- return tuple(from_torch(t) for t in idx_tensors) + # Use native infiniop implementation + idx_tensors = _infinicore.where_indices(cond._underlying) + # Convert C++ Tensor objects to Python Tensor objects + return tuple(Tensor(t) for t in idx_tensors) if len(args) != 3: raise TypeError("infinicore.where expects (cond, x, y)") cond, x, y = args + if infinicore.use_ntops and x.device.type in ("cuda", "musa"): + return infinicore.ntops.torch.where(cond, x, y, out=out) + if out is None: return Tensor(_infinicore.where(cond._underlying, x._underlying, y._underlying)) _infinicore.where_(out._underlying, cond._underlying, x._underlying, y._underlying) return out - - - diff --git a/src/infinicore/ops/logical_or/logical_or.cc b/src/infinicore/ops/logical_or/logical_or.cc index 96df42470..597f38d88 100644 --- a/src/infinicore/ops/logical_or/logical_or.cc +++ b/src/infinicore/ops/logical_or/logical_or.cc @@ -25,4 +25,3 @@ void logical_or_(Tensor c, Tensor a, Tensor b) { LogicalOr::execute(c, a, b); } } // namespace infinicore::op - diff --git a/src/infinicore/ops/logical_or/logical_or_infiniop.cc b/src/infinicore/ops/logical_or/logical_or_infiniop.cc index 66118968d..2398cae10 100644 --- a/src/infinicore/ops/logical_or/logical_or_infiniop.cc +++ b/src/infinicore/ops/logical_or/logical_or_infiniop.cc @@ -50,4 +50,3 @@ static bool registered = []() { }(); } // namespace infinicore::op::logical_or_impl::infiniop - diff --git a/src/infinicore/ops/logical_xor/logical_xor.cc b/src/infinicore/ops/logical_xor/logical_xor.cc index bd893f1f6..13270a90f 100644 --- a/src/infinicore/ops/logical_xor/logical_xor.cc +++ b/src/infinicore/ops/logical_xor/logical_xor.cc @@ -25,4 +25,3 @@ void logical_xor_(Tensor c, Tensor a, Tensor b) { LogicalXor::execute(c, a, b); } } // namespace infinicore::op - diff --git a/src/infinicore/ops/logical_xor/logical_xor_infiniop.cc b/src/infinicore/ops/logical_xor/logical_xor_infiniop.cc index 76226cbbc..514b69ca4 100644 --- a/src/infinicore/ops/logical_xor/logical_xor_infiniop.cc +++ b/src/infinicore/ops/logical_xor/logical_xor_infiniop.cc @@ -50,4 +50,3 @@ static bool registered = []() { }(); } // namespace infinicore::op::logical_xor_impl::infiniop - diff --git a/src/infinicore/ops/logsigmoid/logsigmoid.cc b/src/infinicore/ops/logsigmoid/logsigmoid.cc index 1ed7da9f6..a1e5580a0 100644 --- a/src/infinicore/ops/logsigmoid/logsigmoid.cc +++ b/src/infinicore/ops/logsigmoid/logsigmoid.cc @@ -32,4 +32,3 @@ void logsigmoid_(Tensor output, Tensor input) { LogSigmoid::execute(output, input); } } // namespace infinicore::op - diff --git a/src/infinicore/ops/logsigmoid/logsigmoid_infiniop.cc b/src/infinicore/ops/logsigmoid/logsigmoid_infiniop.cc index e81a880a0..9142032bf 100644 --- a/src/infinicore/ops/logsigmoid/logsigmoid_infiniop.cc +++ b/src/infinicore/ops/logsigmoid/logsigmoid_infiniop.cc @@ -50,4 +50,3 @@ static bool registered = []() { }(); } // namespace infinicore::op::logsigmoid_impl::infiniop - diff --git a/src/infinicore/ops/vdot/vdot.cc b/src/infinicore/ops/vdot/vdot.cc index 822b86c4a..24af3c279 100644 --- a/src/infinicore/ops/vdot/vdot.cc +++ b/src/infinicore/ops/vdot/vdot.cc @@ -46,5 +46,3 @@ void vdot_(Tensor out, Tensor a, Tensor b) { } } // namespace infinicore::op - - diff --git a/src/infinicore/ops/vdot/vdot_infiniop.cc b/src/infinicore/ops/vdot/vdot_infiniop.cc index 866b1a044..ff799aaad 100644 --- a/src/infinicore/ops/vdot/vdot_infiniop.cc +++ b/src/infinicore/ops/vdot/vdot_infiniop.cc @@ -50,5 +50,3 @@ static bool registered = []() { }(); } // namespace 
infinicore::op::vdot_impl::infiniop - - diff --git a/src/infinicore/ops/where/where.cc b/src/infinicore/ops/where/where.cc index 6eabc9e1b..2e00025b9 100644 --- a/src/infinicore/ops/where/where.cc +++ b/src/infinicore/ops/where/where.cc @@ -35,6 +35,27 @@ void where_(Tensor out, Tensor cond, Tensor x, Tensor y) { Where::execute(out, cond, x, y); } -} // namespace infinicore::op +common::OpDispatcher<WhereIndices::schema> &WhereIndices::dispatcher() { + static common::OpDispatcher<WhereIndices::schema> dispatcher_; + return dispatcher_; +} + +std::vector<Tensor> WhereIndices::execute(Tensor cond) { + infinicore::context::setDevice(cond->device()); + auto device_type = cond->device().getType(); + auto func = dispatcher().lookup(device_type); + + if (func == nullptr) { + throw std::runtime_error( + "No WhereIndices implementation found for device type: " + + std::to_string(static_cast<int>(device_type))); + } + return func(cond); +} +std::vector<Tensor> where_indices(Tensor cond) { + return WhereIndices::execute(cond); +} + +} // namespace infinicore::op diff --git a/src/infinicore/ops/where/where_indices_infiniop.cc b/src/infinicore/ops/where/where_indices_infiniop.cc new file mode 100644 index 000000000..0930f8034 --- /dev/null +++ b/src/infinicore/ops/where/where_indices_infiniop.cc @@ -0,0 +1,85 @@ +#include "../../utils.hpp" +#include "infinicore/common/hash.hpp" +#include "infinicore/dtype.hpp" +#include "infinicore/ops/common/cache.hpp" +#include "infinicore/ops/where.hpp" +#include <vector> + +namespace infinicore::op::where_impl::infiniop { + +thread_local common::OpCache<size_t, infiniopWhereIndicesDescriptor_t> + indices_caches(100, [](infiniopWhereIndicesDescriptor_t &desc) { + if (desc != nullptr) { + INFINICORE_CHECK_ERROR(infiniopDestroyWhereIndicesDescriptor(desc)); + desc = nullptr; + } + }); + +std::vector<Tensor> calculate(Tensor cond) { + size_t seed = hash_combine(cond); + + auto device_type = context::getDevice().getType(); + auto device_index = context::getDevice().getIndex(); + + auto &cache = indices_caches.getCache(device_type, device_index); + + auto desc_opt = cache.get(seed); + infiniopWhereIndicesDescriptor_t desc = nullptr; + + if (!desc_opt) { + INFINICORE_CHECK_ERROR(infiniopCreateWhereIndicesDescriptor( + context::getInfiniopHandle(cond->device()), &desc, cond->desc())); + cache.put(seed, desc); + } else { + desc = *desc_opt; + } + + size_t workspace_size = 0; + INFINICORE_CHECK_ERROR( + infiniopGetWhereIndicesWorkspaceSize(desc, &workspace_size)); + std::shared_ptr<Memory> workspace = context::allocateMemory(workspace_size); + + size_t numel = cond->numel(); + int ndim = static_cast<int>(cond->ndim()); + + // Allocate the outputs at the largest possible size (numel) first; the actual size is only known after the API call + std::vector<Tensor> outputs; + std::vector<void *> output_ptrs; + + for (int dim = 0; dim < ndim; ++dim) { + auto out = Tensor::empty({numel}, DataType::I64, cond->device()); + outputs.push_back(out); + output_ptrs.push_back(out->data()); + } + + // Call the infiniop API, which computes num_true + size_t num_true = 0; + INFINICORE_CHECK_ERROR(infiniopWhereIndices( + desc, workspace->data(), workspace_size, output_ptrs.data(), cond->data(), + context::getStream(), &num_true)); + + // Synchronize the stream to make sure the computation has finished + context::syncStream(); + + // If the actual num_true is smaller than numel, the output tensors must be shrunk; + // Tensor may not support resizing, so create new tensors and copy the data + if (num_true < numel) { + std::vector<Tensor> resized_outputs; + for (int dim = 0; dim < ndim; ++dim) { + auto resized = Tensor::empty({num_true}, DataType::I64, cond->device()); + // Copy the first num_true elements + resized->copy_from(outputs[dim]->narrow({{0, 0, num_true}})); + resized_outputs.push_back(resized); + } + return resized_outputs; + } + + return outputs; +} + +static bool registered = []() {
+ WhereIndices::dispatcher().registerAll(&calculate, false); + return true; +}(); + +} // namespace infinicore::op::where_impl::infiniop diff --git a/src/infinicore/ops/where/where_infiniop.cc b/src/infinicore/ops/where/where_infiniop.cc index e8964acf4..1cd807d39 100644 --- a/src/infinicore/ops/where/where_infiniop.cc +++ b/src/infinicore/ops/where/where_infiniop.cc @@ -50,5 +50,3 @@ static bool registered = []() { }(); } // namespace infinicore::op::where_impl::infiniop - - diff --git a/src/infinicore/pybind11/ops/where.hpp b/src/infinicore/pybind11/ops/where.hpp index 9c067b4d6..0488ca662 100644 --- a/src/infinicore/pybind11/ops/where.hpp +++ b/src/infinicore/pybind11/ops/where.hpp @@ -23,6 +23,11 @@ inline void bind_where(py::module &m) { py::arg("x"), py::arg("y"), R"doc(In-place elementwise where(cond, x, y) selection into out tensor.)doc"); + + m.def("where_indices", + &op::where_indices, + py::arg("cond"), + R"doc(Return a tuple of index tensors where condition is True.)doc"); } } // namespace infinicore::ops diff --git a/src/infiniop/ops/logical_or/cpu/logical_or_cpu.cc b/src/infiniop/ops/logical_or/cpu/logical_or_cpu.cc index 1225ca125..892b3d967 100644 --- a/src/infiniop/ops/logical_or/cpu/logical_or_cpu.cc +++ b/src/infiniop/ops/logical_or/cpu/logical_or_cpu.cc @@ -113,4 +113,3 @@ infiniStatus_t Descriptor::calculate( return INFINI_STATUS_BAD_TENSOR_DTYPE; } } // namespace op::logical_or::cpu - diff --git a/src/infiniop/ops/logical_or/cpu/logical_or_cpu.h b/src/infiniop/ops/logical_or/cpu/logical_or_cpu.h index 36c862570..720837dbf 100644 --- a/src/infiniop/ops/logical_or/cpu/logical_or_cpu.h +++ b/src/infiniop/ops/logical_or/cpu/logical_or_cpu.h @@ -22,4 +22,3 @@ typedef struct LogicalOrOp { } // namespace op::logical_or::cpu #endif // __LOGICAL_OR_CPU_H__ - diff --git a/src/infiniop/ops/logical_or/operator.cc b/src/infiniop/ops/logical_or/operator.cc index ee8d4f18f..544090081 100644 --- a/src/infiniop/ops/logical_or/operator.cc +++ b/src/infiniop/ops/logical_or/operator.cc @@ -13,12 +13,12 @@ __C infiniStatus_t infiniopCreateLogicalOrDescriptor( infiniopTensorDescriptor_t a_desc, infiniopTensorDescriptor_t b_desc) { -#define CREATE(CASE, NAMESPACE) \ - case CASE: \ - return op::logical_or::NAMESPACE::Descriptor::create( \ - handle, \ - reinterpret_cast(desc_ptr), \ - c_desc, \ +#define CREATE(CASE, NAMESPACE) \ + case CASE: \ + return op::logical_or::NAMESPACE::Descriptor::create( \ + handle, \ + reinterpret_cast(desc_ptr), \ + c_desc, \ {a_desc, b_desc}) switch (handle->device) { @@ -35,8 +35,8 @@ __C infiniStatus_t infiniopCreateLogicalOrDescriptor( } __C infiniStatus_t infiniopGetLogicalOrWorkspaceSize(infiniopLogicalOrDescriptor_t desc, size_t *size) { -#define GET(CASE, NAMESPACE) \ - case CASE: \ +#define GET(CASE, NAMESPACE) \ + case CASE: \ *size = reinterpret_cast(desc)->workspaceSize(); \ return INFINI_STATUS_SUCCESS; @@ -60,8 +60,8 @@ __C infiniStatus_t infiniopLogicalOr( const void *b, void *stream) { -#define CALCULATE(CASE, NAMESPACE) \ - case CASE: \ +#define CALCULATE(CASE, NAMESPACE) \ + case CASE: \ return reinterpret_cast(desc) \ ->calculate(workspace, workspace_size, c, {a, b}, stream) @@ -79,9 +79,9 @@ __C infiniStatus_t infiniopLogicalOr( } __C infiniStatus_t infiniopDestroyLogicalOrDescriptor(infiniopLogicalOrDescriptor_t desc) { -#define DELETE(CASE, NAMESPACE) \ - case CASE: \ - delete reinterpret_cast(desc); \ +#define DELETE(CASE, NAMESPACE) \ + case CASE: \ + delete reinterpret_cast(desc); \ return INFINI_STATUS_SUCCESS; switch (desc->device_type) 
{ @@ -96,4 +96,3 @@ __C infiniStatus_t infiniopDestroyLogicalOrDescriptor(infiniopLogicalOrDescripto #undef DELETE } - diff --git a/src/infiniop/ops/logical_xor/cpu/logical_xor_cpu.cc b/src/infiniop/ops/logical_xor/cpu/logical_xor_cpu.cc index bcb8ce34f..384473b49 100644 --- a/src/infiniop/ops/logical_xor/cpu/logical_xor_cpu.cc +++ b/src/infiniop/ops/logical_xor/cpu/logical_xor_cpu.cc @@ -113,4 +113,3 @@ infiniStatus_t Descriptor::calculate( return INFINI_STATUS_BAD_TENSOR_DTYPE; } } // namespace op::logical_xor::cpu - diff --git a/src/infiniop/ops/logical_xor/cpu/logical_xor_cpu.h b/src/infiniop/ops/logical_xor/cpu/logical_xor_cpu.h index 3ca53f889..35747d742 100644 --- a/src/infiniop/ops/logical_xor/cpu/logical_xor_cpu.h +++ b/src/infiniop/ops/logical_xor/cpu/logical_xor_cpu.h @@ -22,4 +22,3 @@ typedef struct LogicalXorOp { } // namespace op::logical_xor::cpu #endif // __LOGICAL_XOR_CPU_H__ - diff --git a/src/infiniop/ops/logical_xor/operator.cc b/src/infiniop/ops/logical_xor/operator.cc index 658e6a9fe..1558f4925 100644 --- a/src/infiniop/ops/logical_xor/operator.cc +++ b/src/infiniop/ops/logical_xor/operator.cc @@ -15,9 +15,9 @@ __C infiniStatus_t infiniopCreateLogicalXorDescriptor( #define CREATE(CASE, NAMESPACE) \ case CASE: \ - return op::logical_xor::NAMESPACE::Descriptor::create( \ + return op::logical_xor::NAMESPACE::Descriptor::create( \ handle, \ - reinterpret_cast(desc_ptr), \ + reinterpret_cast(desc_ptr), \ c_desc, \ {a_desc, b_desc}) @@ -81,7 +81,7 @@ __C infiniStatus_t infiniopLogicalXor( __C infiniStatus_t infiniopDestroyLogicalXorDescriptor(infiniopLogicalXorDescriptor_t desc) { #define DELETE(CASE, NAMESPACE) \ case CASE: \ - delete reinterpret_cast(desc); \ + delete reinterpret_cast(desc); \ return INFINI_STATUS_SUCCESS; switch (desc->device_type) { @@ -96,4 +96,3 @@ __C infiniStatus_t infiniopDestroyLogicalXorDescriptor(infiniopLogicalXorDescrip #undef DELETE } - diff --git a/src/infiniop/ops/logsigmoid/cpu/logsigmoid_cpu.cc b/src/infiniop/ops/logsigmoid/cpu/logsigmoid_cpu.cc index f765c7bb7..1fdd2207b 100644 --- a/src/infiniop/ops/logsigmoid/cpu/logsigmoid_cpu.cc +++ b/src/infiniop/ops/logsigmoid/cpu/logsigmoid_cpu.cc @@ -49,4 +49,3 @@ infiniStatus_t Descriptor::calculate( return INFINI_STATUS_SUCCESS; } } // namespace op::logsigmoid::cpu - diff --git a/src/infiniop/ops/logsigmoid/cpu/logsigmoid_cpu.h b/src/infiniop/ops/logsigmoid/cpu/logsigmoid_cpu.h index 8d6484d27..bdfaca385 100644 --- a/src/infiniop/ops/logsigmoid/cpu/logsigmoid_cpu.h +++ b/src/infiniop/ops/logsigmoid/cpu/logsigmoid_cpu.h @@ -18,4 +18,3 @@ typedef struct LogSigmoidOp { } // namespace op::logsigmoid::cpu #endif // __LOGSIGMOID_CPU_H__ - diff --git a/src/infiniop/ops/logsigmoid/operator.cc b/src/infiniop/ops/logsigmoid/operator.cc index 4d0283994..5409a672a 100644 --- a/src/infiniop/ops/logsigmoid/operator.cc +++ b/src/infiniop/ops/logsigmoid/operator.cc @@ -12,12 +12,12 @@ __C infiniStatus_t infiniopCreateLogSigmoidDescriptor( infiniopTensorDescriptor_t y_desc, infiniopTensorDescriptor_t x_desc) { -#define CREATE(CASE, NAMESPACE) \ - case CASE: \ - return op::logsigmoid::NAMESPACE::Descriptor::create( \ - handle, \ +#define CREATE(CASE, NAMESPACE) \ + case CASE: \ + return op::logsigmoid::NAMESPACE::Descriptor::create( \ + handle, \ reinterpret_cast(desc_ptr), \ - y_desc, \ + y_desc, \ {x_desc}) switch (handle->device) { @@ -35,9 +35,9 @@ __C infiniStatus_t infiniopCreateLogSigmoidDescriptor( __C infiniStatus_t infiniopGetLogSigmoidWorkspaceSize(infiniopLogSigmoidDescriptor_t desc, size_t *size) { 
-#define GET(CASE, NAMESPACE) \ - case CASE: \ - *size = reinterpret_cast(desc)->workspaceSize(); \ +#define GET(CASE, NAMESPACE) \ + case CASE: \ + *size = reinterpret_cast(desc)->workspaceSize(); \ return INFINI_STATUS_SUCCESS; switch (desc->device_type) { @@ -60,9 +60,9 @@ __C infiniStatus_t infiniopLogSigmoid( const void *x, void *stream) { -#define CALCULATE(CASE, NAMESPACE) \ - case CASE: \ - return reinterpret_cast(desc) \ +#define CALCULATE(CASE, NAMESPACE) \ + case CASE: \ + return reinterpret_cast(desc) \ ->calculate(workspace, workspace_size, y, {x}, stream) switch (desc->device_type) { @@ -81,8 +81,8 @@ __C infiniStatus_t infiniopLogSigmoid( __C infiniStatus_t infiniopDestroyLogSigmoidDescriptor(infiniopLogSigmoidDescriptor_t desc) { -#define DELETE(CASE, NAMESPACE) \ - case CASE: \ +#define DELETE(CASE, NAMESPACE) \ + case CASE: \ delete reinterpret_cast(desc); \ return INFINI_STATUS_SUCCESS; @@ -98,4 +98,3 @@ infiniopDestroyLogSigmoidDescriptor(infiniopLogSigmoidDescriptor_t desc) { #undef DELETE } - diff --git a/src/infiniop/ops/vdot/cpu/vdot_cpu.cc b/src/infiniop/ops/vdot/cpu/vdot_cpu.cc index 13682ef48..e3511ebba 100644 --- a/src/infiniop/ops/vdot/cpu/vdot_cpu.cc +++ b/src/infiniop/ops/vdot/cpu/vdot_cpu.cc @@ -133,5 +133,3 @@ infiniStatus_t Descriptor::calculate( } } // namespace op::vdot::cpu - - diff --git a/src/infiniop/ops/vdot/cpu/vdot_cpu.h b/src/infiniop/ops/vdot/cpu/vdot_cpu.h index 528901ba0..08a01ff6e 100644 --- a/src/infiniop/ops/vdot/cpu/vdot_cpu.h +++ b/src/infiniop/ops/vdot/cpu/vdot_cpu.h @@ -2,8 +2,8 @@ #define __VDOT_CPU_H__ #include "../../../devices/cpu/common_cpu.h" -#include "../../../tensor.h" #include "../../../operator.h" +#include "../../../tensor.h" namespace op::vdot::cpu { @@ -52,5 +52,3 @@ class Descriptor final : public InfiniopDescriptor { } // namespace op::vdot::cpu #endif // __VDOT_CPU_H__ - - diff --git a/src/infiniop/ops/vdot/cuda/kernel.cuh b/src/infiniop/ops/vdot/cuda/kernel.cuh index 6fe772406..9b5342e92 100644 --- a/src/infiniop/ops/vdot/cuda/kernel.cuh +++ b/src/infiniop/ops/vdot/cuda/kernel.cuh @@ -10,25 +10,25 @@ __global__ void vdotKernel(Tcompute *out, const Tdata *a, const Tdata *b, size_t length, ptrdiff_t a_stride, ptrdiff_t b_stride) { - Tcompute dot = 0; + Tcompute dot = 0; - // Each thread computes its partial dot product - for (size_t i = threadIdx.x; i < length; i += BLOCK_SIZE) { - Tcompute a_val = Tcompute(a[i * a_stride]); - Tcompute b_val = Tcompute(b[i * b_stride]); - dot += a_val * b_val; - } + // Each thread computes its partial dot product + for (size_t i = threadIdx.x; i < length; i += BLOCK_SIZE) { + Tcompute a_val = Tcompute(a[i * a_stride]); + Tcompute b_val = Tcompute(b[i * b_stride]); + dot += a_val * b_val; + } - // Use CUB block-level reduction - using BlockReduce = cub::BlockReduce; - __shared__ typename BlockReduce::TempStorage temp_storage; + // Use CUB block-level reduction + using BlockReduce = cub::BlockReduce; + __shared__ typename BlockReduce::TempStorage temp_storage; - Tcompute block_dot = BlockReduce(temp_storage).Sum(dot); + Tcompute block_dot = BlockReduce(temp_storage).Sum(dot); - // Thread 0 writes the result - if (threadIdx.x == 0) { - *out = block_dot; - } + // Thread 0 writes the result + if (threadIdx.x == 0) { + *out = block_dot; + } } } // namespace op::vdot::cuda diff --git a/src/infiniop/ops/vdot/metax/vdot_metax.cu b/src/infiniop/ops/vdot/metax/vdot_metax.cu index a258625e5..f75605a0a 100644 --- a/src/infiniop/ops/vdot/metax/vdot_metax.cu +++ 
b/src/infiniop/ops/vdot/metax/vdot_metax.cu @@ -14,116 +14,116 @@ infiniStatus_t Descriptor::create(infiniopHandle_t handle_, infiniopTensorDescriptor_t a_desc, infiniopTensorDescriptor_t b_desc) { - auto handle = reinterpret_cast(handle_); - auto in_dtype = a_desc->dtype(); - auto b_dtype = b_desc->dtype(); - auto out_dtype = out_desc->dtype(); + auto handle = reinterpret_cast(handle_); + auto in_dtype = a_desc->dtype(); + auto b_dtype = b_desc->dtype(); + auto out_dtype = out_desc->dtype(); - // Inputs must be 1D vectors with same length - if (a_desc->ndim() != 1 || b_desc->ndim() != 1) { - return INFINI_STATUS_BAD_TENSOR_SHAPE; - } - if (a_desc->numel() != b_desc->numel()) { - return INFINI_STATUS_BAD_TENSOR_SHAPE; - } + // Inputs must be 1D vectors with same length + if (a_desc->ndim() != 1 || b_desc->ndim() != 1) { + return INFINI_STATUS_BAD_TENSOR_SHAPE; + } + if (a_desc->numel() != b_desc->numel()) { + return INFINI_STATUS_BAD_TENSOR_SHAPE; + } - // Input dtypes must match and be in supported set - CHECK_OR_RETURN(in_dtype == b_dtype, INFINI_STATUS_BAD_TENSOR_DTYPE); - CHECK_DTYPE(in_dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_F64, - INFINI_DTYPE_BF16); + // Input dtypes must match and be in supported set + CHECK_OR_RETURN(in_dtype == b_dtype, INFINI_STATUS_BAD_TENSOR_DTYPE); + CHECK_DTYPE(in_dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_F64, + INFINI_DTYPE_BF16); - // Output dtype equals input dtype - CHECK_OR_RETURN(out_dtype == in_dtype, INFINI_STATUS_BAD_TENSOR_DTYPE); + // Output dtype equals input dtype + CHECK_OR_RETURN(out_dtype == in_dtype, INFINI_STATUS_BAD_TENSOR_DTYPE); - size_t length = a_desc->numel(); - ptrdiff_t a_stride = a_desc->stride(0); - ptrdiff_t b_stride = b_desc->stride(0); + size_t length = a_desc->numel(); + ptrdiff_t a_stride = a_desc->stride(0); + ptrdiff_t b_stride = b_desc->stride(0); - *desc_ptr = new Descriptor(in_dtype, out_dtype, length, a_stride, b_stride, - handle->device, handle->device_id); + *desc_ptr = new Descriptor(in_dtype, out_dtype, length, a_stride, b_stride, + handle->device, handle->device_id); - return INFINI_STATUS_SUCCESS; + return INFINI_STATUS_SUCCESS; } infiniStatus_t Descriptor::calculate(void *workspace, size_t workspace_size, void *out, const void *a, const void *b, void *stream) const { - auto cuda_stream = reinterpret_cast(stream); - constexpr unsigned int BLOCK_SIZE = 256; + auto cuda_stream = reinterpret_cast(stream); + constexpr unsigned int BLOCK_SIZE = 256; - switch (_in_dtype) { - case INFINI_DTYPE_F32: { - float *out_f = reinterpret_cast(out); - const float *a_f = reinterpret_cast(a); - const float *b_f = reinterpret_cast(b); - op::vdot::cuda::vdotKernel - <<<1, BLOCK_SIZE, 0, cuda_stream>>>(out_f, a_f, b_f, _length, _a_stride, - _b_stride); - CHECK_CUDA(cudaGetLastError()); - break; - } - case INFINI_DTYPE_F64: { - double *out_d = reinterpret_cast(out); - const double *a_d = reinterpret_cast(a); - const double *b_d = reinterpret_cast(b); - op::vdot::cuda::vdotKernel - <<<1, BLOCK_SIZE, 0, cuda_stream>>>(out_d, a_d, b_d, _length, _a_stride, - _b_stride); - CHECK_CUDA(cudaGetLastError()); - break; - } - case INFINI_DTYPE_F16: { - // For FP16, accumulate in float, then cast back to half - if (workspace_size < sizeof(float)) { - return INFINI_STATUS_INSUFFICIENT_WORKSPACE; + switch (_in_dtype) { + case INFINI_DTYPE_F32: { + float *out_f = reinterpret_cast(out); + const float *a_f = reinterpret_cast(a); + const float *b_f = reinterpret_cast(b); + op::vdot::cuda::vdotKernel + <<<1, BLOCK_SIZE, 0, 
cuda_stream>>>(out_f, a_f, b_f, _length, _a_stride, + _b_stride); + CHECK_CUDA(cudaGetLastError()); + break; + } + case INFINI_DTYPE_F64: { + double *out_d = reinterpret_cast(out); + const double *a_d = reinterpret_cast(a); + const double *b_d = reinterpret_cast(b); + op::vdot::cuda::vdotKernel + <<<1, BLOCK_SIZE, 0, cuda_stream>>>(out_d, a_d, b_d, _length, _a_stride, + _b_stride); + CHECK_CUDA(cudaGetLastError()); + break; } - float *tmp_out = reinterpret_cast(workspace); - { - const __half *a_h = reinterpret_cast(a); - const __half *b_h = reinterpret_cast(b); - op::vdot::cuda::vdotKernel - <<<1, BLOCK_SIZE, 0, cuda_stream>>>(tmp_out, a_h, b_h, _length, - _a_stride, _b_stride); - CHECK_CUDA(cudaGetLastError()); - float result_f; - CHECK_CUDA(cudaMemcpyAsync(&result_f, tmp_out, sizeof(float), - cudaMemcpyDeviceToHost, cuda_stream)); - CHECK_CUDA(cudaStreamSynchronize(cuda_stream)); - __half h_result = __float2half(result_f); - CHECK_CUDA(cudaMemcpyAsync(out, &h_result, sizeof(__half), - cudaMemcpyHostToDevice, cuda_stream)); + case INFINI_DTYPE_F16: { + // For FP16, accumulate in float, then cast back to half + if (workspace_size < sizeof(float)) { + return INFINI_STATUS_INSUFFICIENT_WORKSPACE; + } + float *tmp_out = reinterpret_cast(workspace); + { + const __half *a_h = reinterpret_cast(a); + const __half *b_h = reinterpret_cast(b); + op::vdot::cuda::vdotKernel + <<<1, BLOCK_SIZE, 0, cuda_stream>>>(tmp_out, a_h, b_h, _length, + _a_stride, _b_stride); + CHECK_CUDA(cudaGetLastError()); + float result_f; + CHECK_CUDA(cudaMemcpyAsync(&result_f, tmp_out, sizeof(float), + cudaMemcpyDeviceToHost, cuda_stream)); + CHECK_CUDA(cudaStreamSynchronize(cuda_stream)); + __half h_result = __float2half(result_f); + CHECK_CUDA(cudaMemcpyAsync(out, &h_result, sizeof(__half), + cudaMemcpyHostToDevice, cuda_stream)); + } + break; } - break; - } - case INFINI_DTYPE_BF16: { - // For BF16, accumulate in float, then cast back - if (workspace_size < sizeof(float)) { - return INFINI_STATUS_INSUFFICIENT_WORKSPACE; + case INFINI_DTYPE_BF16: { + // For BF16, accumulate in float, then cast back + if (workspace_size < sizeof(float)) { + return INFINI_STATUS_INSUFFICIENT_WORKSPACE; + } + float *tmp_out = reinterpret_cast(workspace); + { + const __nv_bfloat16 *a_bf = reinterpret_cast(a); + const __nv_bfloat16 *b_bf = reinterpret_cast(b); + op::vdot::cuda::vdotKernel + <<<1, BLOCK_SIZE, 0, cuda_stream>>>(tmp_out, a_bf, b_bf, _length, + _a_stride, _b_stride); + CHECK_CUDA(cudaGetLastError()); + float result_f; + CHECK_CUDA(cudaMemcpyAsync(&result_f, tmp_out, sizeof(float), + cudaMemcpyDeviceToHost, cuda_stream)); + CHECK_CUDA(cudaStreamSynchronize(cuda_stream)); + __nv_bfloat16 bf_result = __float2bfloat16(result_f); + CHECK_CUDA(cudaMemcpyAsync(out, &bf_result, sizeof(__nv_bfloat16), + cudaMemcpyHostToDevice, cuda_stream)); + } + break; } - float *tmp_out = reinterpret_cast(workspace); - { - const __nv_bfloat16 *a_bf = reinterpret_cast(a); - const __nv_bfloat16 *b_bf = reinterpret_cast(b); - op::vdot::cuda::vdotKernel - <<<1, BLOCK_SIZE, 0, cuda_stream>>>(tmp_out, a_bf, b_bf, _length, - _a_stride, _b_stride); - CHECK_CUDA(cudaGetLastError()); - float result_f; - CHECK_CUDA(cudaMemcpyAsync(&result_f, tmp_out, sizeof(float), - cudaMemcpyDeviceToHost, cuda_stream)); - CHECK_CUDA(cudaStreamSynchronize(cuda_stream)); - __nv_bfloat16 bf_result = __float2bfloat16(result_f); - CHECK_CUDA(cudaMemcpyAsync(out, &bf_result, sizeof(__nv_bfloat16), - cudaMemcpyHostToDevice, cuda_stream)); + default: + return 
INFINI_STATUS_BAD_TENSOR_DTYPE; } - break; - } - default: - return INFINI_STATUS_BAD_TENSOR_DTYPE; - } - return INFINI_STATUS_SUCCESS; + return INFINI_STATUS_SUCCESS; } } // namespace op::vdot::metax diff --git a/src/infiniop/ops/vdot/metax/vdot_metax.h b/src/infiniop/ops/vdot/metax/vdot_metax.h index ffe85cff2..b5eed7093 100644 --- a/src/infiniop/ops/vdot/metax/vdot_metax.h +++ b/src/infiniop/ops/vdot/metax/vdot_metax.h @@ -9,36 +9,36 @@ namespace op::vdot::metax { class Descriptor final : public InfiniopDescriptor { - infiniDtype_t _in_dtype; - infiniDtype_t _out_dtype; - size_t _length; - ptrdiff_t _a_stride; - ptrdiff_t _b_stride; + infiniDtype_t _in_dtype; + infiniDtype_t _out_dtype; + size_t _length; + ptrdiff_t _a_stride; + ptrdiff_t _b_stride; public: - Descriptor(infiniDtype_t in_dtype, infiniDtype_t out_dtype, size_t length, - ptrdiff_t a_stride, ptrdiff_t b_stride, infiniDevice_t device_type, - int device_id) - : InfiniopDescriptor{device_type, device_id}, _in_dtype(in_dtype), - _out_dtype(out_dtype), _length(length), _a_stride(a_stride), - _b_stride(b_stride) {} - - ~Descriptor(); - - size_t workspaceSize() const { - // Need workspace for FP16/BF16 to accumulate in float - return (_in_dtype == INFINI_DTYPE_F16 || _in_dtype == INFINI_DTYPE_BF16) - ? sizeof(float) - : 0; - } - - static infiniStatus_t create(infiniopHandle_t handle, Descriptor **desc_ptr, - infiniopTensorDescriptor_t out_desc, - infiniopTensorDescriptor_t a_desc, - infiniopTensorDescriptor_t b_desc); - - infiniStatus_t calculate(void *workspace, size_t workspace_size, void *out, - const void *a, const void *b, void *stream) const; + Descriptor(infiniDtype_t in_dtype, infiniDtype_t out_dtype, size_t length, + ptrdiff_t a_stride, ptrdiff_t b_stride, infiniDevice_t device_type, + int device_id) + : InfiniopDescriptor{device_type, device_id}, _in_dtype(in_dtype), + _out_dtype(out_dtype), _length(length), _a_stride(a_stride), + _b_stride(b_stride) {} + + ~Descriptor(); + + size_t workspaceSize() const { + // Need workspace for FP16/BF16 to accumulate in float + return (_in_dtype == INFINI_DTYPE_F16 || _in_dtype == INFINI_DTYPE_BF16) + ? 
sizeof(float) + : 0; + } + + static infiniStatus_t create(infiniopHandle_t handle, Descriptor **desc_ptr, + infiniopTensorDescriptor_t out_desc, + infiniopTensorDescriptor_t a_desc, + infiniopTensorDescriptor_t b_desc); + + infiniStatus_t calculate(void *workspace, size_t workspace_size, void *out, + const void *a, const void *b, void *stream) const; }; } // namespace op::vdot::metax diff --git a/src/infiniop/ops/vdot/moore/vdot_moore.cu b/src/infiniop/ops/vdot/moore/vdot_moore.cu index 331123170..2c9a6e474 100644 --- a/src/infiniop/ops/vdot/moore/vdot_moore.cu +++ b/src/infiniop/ops/vdot/moore/vdot_moore.cu @@ -14,123 +14,123 @@ infiniStatus_t Descriptor::create(infiniopHandle_t handle_, infiniopTensorDescriptor_t a_desc, infiniopTensorDescriptor_t b_desc) { - auto handle = reinterpret_cast(handle_); - auto in_dtype = a_desc->dtype(); - auto b_dtype = b_desc->dtype(); - auto out_dtype = out_desc->dtype(); + auto handle = reinterpret_cast(handle_); + auto in_dtype = a_desc->dtype(); + auto b_dtype = b_desc->dtype(); + auto out_dtype = out_desc->dtype(); - // Inputs must be 1D vectors with same length - if (a_desc->ndim() != 1 || b_desc->ndim() != 1) { - return INFINI_STATUS_BAD_TENSOR_SHAPE; - } - if (a_desc->numel() != b_desc->numel()) { - return INFINI_STATUS_BAD_TENSOR_SHAPE; - } + // Inputs must be 1D vectors with same length + if (a_desc->ndim() != 1 || b_desc->ndim() != 1) { + return INFINI_STATUS_BAD_TENSOR_SHAPE; + } + if (a_desc->numel() != b_desc->numel()) { + return INFINI_STATUS_BAD_TENSOR_SHAPE; + } - // Input dtypes must match and be in supported set - CHECK_OR_RETURN(in_dtype == b_dtype, INFINI_STATUS_BAD_TENSOR_DTYPE); - CHECK_DTYPE(in_dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_F64, - INFINI_DTYPE_BF16); + // Input dtypes must match and be in supported set + CHECK_OR_RETURN(in_dtype == b_dtype, INFINI_STATUS_BAD_TENSOR_DTYPE); + CHECK_DTYPE(in_dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_F64, + INFINI_DTYPE_BF16); - // Output dtype equals input dtype - CHECK_OR_RETURN(out_dtype == in_dtype, INFINI_STATUS_BAD_TENSOR_DTYPE); + // Output dtype equals input dtype + CHECK_OR_RETURN(out_dtype == in_dtype, INFINI_STATUS_BAD_TENSOR_DTYPE); - size_t length = a_desc->numel(); - ptrdiff_t a_stride = a_desc->stride(0); - ptrdiff_t b_stride = b_desc->stride(0); + size_t length = a_desc->numel(); + ptrdiff_t a_stride = a_desc->stride(0); + ptrdiff_t b_stride = b_desc->stride(0); - *desc_ptr = new Descriptor(in_dtype, out_dtype, length, a_stride, b_stride, - handle->device, handle->device_id); + *desc_ptr = new Descriptor(in_dtype, out_dtype, length, a_stride, b_stride, + handle->device, handle->device_id); - return INFINI_STATUS_SUCCESS; + return INFINI_STATUS_SUCCESS; } infiniStatus_t Descriptor::calculate(void *workspace, size_t workspace_size, void *out, const void *a, const void *b, void *stream) const { - auto cuda_stream = reinterpret_cast(stream); - constexpr unsigned int BLOCK_SIZE = 256; + auto cuda_stream = reinterpret_cast(stream); + constexpr unsigned int BLOCK_SIZE = 256; - switch (_in_dtype) { - case INFINI_DTYPE_F32: { - float *out_f = reinterpret_cast(out); - const float *a_f = reinterpret_cast(a); - const float *b_f = reinterpret_cast(b); - op::vdot::cuda::vdotKernel - <<<1, BLOCK_SIZE, 0, cuda_stream>>>(out_f, a_f, b_f, _length, _a_stride, - _b_stride); - CHECK_CUDA(cudaGetLastError()); - break; - } - case INFINI_DTYPE_F64: { - double *out_d = reinterpret_cast(out); - const double *a_d = reinterpret_cast(a); - const double *b_d = 
reinterpret_cast(b); - op::vdot::cuda::vdotKernel - <<<1, BLOCK_SIZE, 0, cuda_stream>>>(out_d, a_d, b_d, _length, _a_stride, - _b_stride); - CHECK_CUDA(cudaGetLastError()); - break; - } - case INFINI_DTYPE_F16: { - // For FP16, accumulate in float, then cast back to half - // Use workspace for temporary float buffer - if (workspace_size < sizeof(float)) { - return INFINI_STATUS_INSUFFICIENT_WORKSPACE; + switch (_in_dtype) { + case INFINI_DTYPE_F32: { + float *out_f = reinterpret_cast(out); + const float *a_f = reinterpret_cast(a); + const float *b_f = reinterpret_cast(b); + op::vdot::cuda::vdotKernel + <<<1, BLOCK_SIZE, 0, cuda_stream>>>(out_f, a_f, b_f, _length, _a_stride, + _b_stride); + CHECK_CUDA(cudaGetLastError()); + break; + } + case INFINI_DTYPE_F64: { + double *out_d = reinterpret_cast(out); + const double *a_d = reinterpret_cast(a); + const double *b_d = reinterpret_cast(b); + op::vdot::cuda::vdotKernel + <<<1, BLOCK_SIZE, 0, cuda_stream>>>(out_d, a_d, b_d, _length, _a_stride, + _b_stride); + CHECK_CUDA(cudaGetLastError()); + break; } - float *tmp_out = reinterpret_cast(workspace); - { - // If workspace is too small, we need to allocate - // For simplicity, use a device-side kernel that writes directly to out - // But we need float accumulation, so use a temporary approach - const __half *a_h = reinterpret_cast(a); - const __half *b_h = reinterpret_cast(b); - // Launch kernel that accumulates in float and writes half result - op::vdot::cuda::vdotKernel - <<<1, BLOCK_SIZE, 0, cuda_stream>>>(tmp_out, a_h, b_h, _length, - _a_stride, _b_stride); - CHECK_CUDA(cudaGetLastError()); - // Use a simple device kernel to cast float to half - // For now, copy to host, cast, and copy back - float result_f; - CHECK_CUDA(cudaMemcpyAsync(&result_f, tmp_out, sizeof(float), - cudaMemcpyDeviceToHost, cuda_stream)); - CHECK_CUDA(cudaStreamSynchronize(cuda_stream)); - __half h_result = __float2half(result_f); - CHECK_CUDA(cudaMemcpyAsync(out, &h_result, sizeof(__half), - cudaMemcpyHostToDevice, cuda_stream)); + case INFINI_DTYPE_F16: { + // For FP16, accumulate in float, then cast back to half + // Use workspace for temporary float buffer + if (workspace_size < sizeof(float)) { + return INFINI_STATUS_INSUFFICIENT_WORKSPACE; + } + float *tmp_out = reinterpret_cast(workspace); + { + // If workspace is too small, we need to allocate + // For simplicity, use a device-side kernel that writes directly to out + // But we need float accumulation, so use a temporary approach + const __half *a_h = reinterpret_cast(a); + const __half *b_h = reinterpret_cast(b); + // Launch kernel that accumulates in float and writes half result + op::vdot::cuda::vdotKernel + <<<1, BLOCK_SIZE, 0, cuda_stream>>>(tmp_out, a_h, b_h, _length, + _a_stride, _b_stride); + CHECK_CUDA(cudaGetLastError()); + // Use a simple device kernel to cast float to half + // For now, copy to host, cast, and copy back + float result_f; + CHECK_CUDA(cudaMemcpyAsync(&result_f, tmp_out, sizeof(float), + cudaMemcpyDeviceToHost, cuda_stream)); + CHECK_CUDA(cudaStreamSynchronize(cuda_stream)); + __half h_result = __float2half(result_f); + CHECK_CUDA(cudaMemcpyAsync(out, &h_result, sizeof(__half), + cudaMemcpyHostToDevice, cuda_stream)); + } + break; } - break; - } - case INFINI_DTYPE_BF16: { - // For BF16, accumulate in float, then cast back - if (workspace_size < sizeof(float)) { - return INFINI_STATUS_INSUFFICIENT_WORKSPACE; + case INFINI_DTYPE_BF16: { + // For BF16, accumulate in float, then cast back + if (workspace_size < sizeof(float)) { + return 
INFINI_STATUS_INSUFFICIENT_WORKSPACE; + } + float *tmp_out = reinterpret_cast(workspace); + { + const __nv_bfloat16 *a_bf = reinterpret_cast(a); + const __nv_bfloat16 *b_bf = reinterpret_cast(b); + op::vdot::cuda::vdotKernel + <<<1, BLOCK_SIZE, 0, cuda_stream>>>(tmp_out, a_bf, b_bf, _length, + _a_stride, _b_stride); + CHECK_CUDA(cudaGetLastError()); + float result_f; + CHECK_CUDA(cudaMemcpyAsync(&result_f, tmp_out, sizeof(float), + cudaMemcpyDeviceToHost, cuda_stream)); + CHECK_CUDA(cudaStreamSynchronize(cuda_stream)); + __nv_bfloat16 bf_result = __float2bfloat16(result_f); + CHECK_CUDA(cudaMemcpyAsync(out, &bf_result, sizeof(__nv_bfloat16), + cudaMemcpyHostToDevice, cuda_stream)); + } + break; } - float *tmp_out = reinterpret_cast(workspace); - { - const __nv_bfloat16 *a_bf = reinterpret_cast(a); - const __nv_bfloat16 *b_bf = reinterpret_cast(b); - op::vdot::cuda::vdotKernel - <<<1, BLOCK_SIZE, 0, cuda_stream>>>(tmp_out, a_bf, b_bf, _length, - _a_stride, _b_stride); - CHECK_CUDA(cudaGetLastError()); - float result_f; - CHECK_CUDA(cudaMemcpyAsync(&result_f, tmp_out, sizeof(float), - cudaMemcpyDeviceToHost, cuda_stream)); - CHECK_CUDA(cudaStreamSynchronize(cuda_stream)); - __nv_bfloat16 bf_result = __float2bfloat16(result_f); - CHECK_CUDA(cudaMemcpyAsync(out, &bf_result, sizeof(__nv_bfloat16), - cudaMemcpyHostToDevice, cuda_stream)); + default: + return INFINI_STATUS_BAD_TENSOR_DTYPE; } - break; - } - default: - return INFINI_STATUS_BAD_TENSOR_DTYPE; - } - return INFINI_STATUS_SUCCESS; + return INFINI_STATUS_SUCCESS; } } // namespace op::vdot::moore diff --git a/src/infiniop/ops/vdot/moore/vdot_moore.h b/src/infiniop/ops/vdot/moore/vdot_moore.h index 6bd36593d..912326f22 100644 --- a/src/infiniop/ops/vdot/moore/vdot_moore.h +++ b/src/infiniop/ops/vdot/moore/vdot_moore.h @@ -9,36 +9,36 @@ namespace op::vdot::moore { class Descriptor final : public InfiniopDescriptor { - infiniDtype_t _in_dtype; - infiniDtype_t _out_dtype; - size_t _length; - ptrdiff_t _a_stride; - ptrdiff_t _b_stride; + infiniDtype_t _in_dtype; + infiniDtype_t _out_dtype; + size_t _length; + ptrdiff_t _a_stride; + ptrdiff_t _b_stride; public: - Descriptor(infiniDtype_t in_dtype, infiniDtype_t out_dtype, size_t length, - ptrdiff_t a_stride, ptrdiff_t b_stride, infiniDevice_t device_type, - int device_id) - : InfiniopDescriptor{device_type, device_id}, _in_dtype(in_dtype), - _out_dtype(out_dtype), _length(length), _a_stride(a_stride), - _b_stride(b_stride) {} - - ~Descriptor(); - - size_t workspaceSize() const { - // Need workspace for FP16/BF16 to accumulate in float - return (_in_dtype == INFINI_DTYPE_F16 || _in_dtype == INFINI_DTYPE_BF16) - ? sizeof(float) - : 0; - } - - static infiniStatus_t create(infiniopHandle_t handle, Descriptor **desc_ptr, - infiniopTensorDescriptor_t out_desc, - infiniopTensorDescriptor_t a_desc, - infiniopTensorDescriptor_t b_desc); - - infiniStatus_t calculate(void *workspace, size_t workspace_size, void *out, - const void *a, const void *b, void *stream) const; + Descriptor(infiniDtype_t in_dtype, infiniDtype_t out_dtype, size_t length, + ptrdiff_t a_stride, ptrdiff_t b_stride, infiniDevice_t device_type, + int device_id) + : InfiniopDescriptor{device_type, device_id}, _in_dtype(in_dtype), + _out_dtype(out_dtype), _length(length), _a_stride(a_stride), + _b_stride(b_stride) {} + + ~Descriptor(); + + size_t workspaceSize() const { + // Need workspace for FP16/BF16 to accumulate in float + return (_in_dtype == INFINI_DTYPE_F16 || _in_dtype == INFINI_DTYPE_BF16) + ? 
sizeof(float) + : 0; + } + + static infiniStatus_t create(infiniopHandle_t handle, Descriptor **desc_ptr, + infiniopTensorDescriptor_t out_desc, + infiniopTensorDescriptor_t a_desc, + infiniopTensorDescriptor_t b_desc); + + infiniStatus_t calculate(void *workspace, size_t workspace_size, void *out, + const void *a, const void *b, void *stream) const; }; } // namespace op::vdot::moore diff --git a/src/infiniop/ops/vdot/nvidia/vdot_nvidia.cu b/src/infiniop/ops/vdot/nvidia/vdot_nvidia.cu index 356e16901..1b9061bac 100644 --- a/src/infiniop/ops/vdot/nvidia/vdot_nvidia.cu +++ b/src/infiniop/ops/vdot/nvidia/vdot_nvidia.cu @@ -13,155 +13,154 @@ infiniStatus_t Descriptor::create(infiniopHandle_t handle_, infiniopTensorDescriptor_t a_desc, infiniopTensorDescriptor_t b_desc) { - auto handle = reinterpret_cast(handle_); - auto in_dtype = a_desc->dtype(); - auto b_dtype = b_desc->dtype(); - auto out_dtype = out_desc->dtype(); - - // Inputs must be 1D vectors with same length - if (a_desc->ndim() != 1 || b_desc->ndim() != 1) { - return INFINI_STATUS_BAD_TENSOR_SHAPE; - } - if (a_desc->numel() != b_desc->numel()) { - return INFINI_STATUS_BAD_TENSOR_SHAPE; - } - - // Input dtypes must match and be in supported set - CHECK_OR_RETURN(in_dtype == b_dtype, INFINI_STATUS_BAD_TENSOR_DTYPE); - CHECK_DTYPE(in_dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_F64, - INFINI_DTYPE_BF16); - - // Output dtype equals input dtype - CHECK_OR_RETURN(out_dtype == in_dtype, INFINI_STATUS_BAD_TENSOR_DTYPE); - - size_t length = a_desc->numel(); - ptrdiff_t a_stride = a_desc->stride(0); - ptrdiff_t b_stride = b_desc->stride(0); - - *desc_ptr = - new Descriptor(in_dtype, out_dtype, length, a_stride, b_stride, - handle->internal(), handle->device, handle->device_id); - - return INFINI_STATUS_SUCCESS; + auto handle = reinterpret_cast(handle_); + auto in_dtype = a_desc->dtype(); + auto b_dtype = b_desc->dtype(); + auto out_dtype = out_desc->dtype(); + + // Inputs must be 1D vectors with same length + if (a_desc->ndim() != 1 || b_desc->ndim() != 1) { + return INFINI_STATUS_BAD_TENSOR_SHAPE; + } + if (a_desc->numel() != b_desc->numel()) { + return INFINI_STATUS_BAD_TENSOR_SHAPE; + } + + // Input dtypes must match and be in supported set + CHECK_OR_RETURN(in_dtype == b_dtype, INFINI_STATUS_BAD_TENSOR_DTYPE); + CHECK_DTYPE(in_dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_F64, + INFINI_DTYPE_BF16); + + // Output dtype equals input dtype + CHECK_OR_RETURN(out_dtype == in_dtype, INFINI_STATUS_BAD_TENSOR_DTYPE); + + size_t length = a_desc->numel(); + ptrdiff_t a_stride = a_desc->stride(0); + ptrdiff_t b_stride = b_desc->stride(0); + + *desc_ptr = new Descriptor(in_dtype, out_dtype, length, a_stride, b_stride, + handle->internal(), handle->device, handle->device_id); + + return INFINI_STATUS_SUCCESS; } infiniStatus_t Descriptor::calculate(void *workspace, size_t workspace_size, void *out, const void *a, const void *b, void *stream) const { - auto cuda_stream = reinterpret_cast(stream); - - // For FP16/BF16, use CUDA kernel instead of cuBLAS - if (_in_dtype == INFINI_DTYPE_F16 || _in_dtype == INFINI_DTYPE_BF16) { - switch (_in_dtype) { - case INFINI_DTYPE_F16: { - if (workspace_size < sizeof(float)) { - return INFINI_STATUS_INSUFFICIENT_WORKSPACE; - } - float *tmp_out = reinterpret_cast(workspace); - const __half *a_h = reinterpret_cast(a); - const __half *b_h = reinterpret_cast(b); - constexpr unsigned int BLOCK_SIZE = 256; - op::vdot::cuda::vdotKernel - <<<1, BLOCK_SIZE, 0, cuda_stream>>>(tmp_out, a_h, b_h, _length, - 
_a_stride, _b_stride); - CHECK_CUDA(cudaGetLastError()); - float result_f; - CHECK_CUDA(cudaMemcpyAsync(&result_f, tmp_out, sizeof(float), - cudaMemcpyDeviceToHost, cuda_stream)); - CHECK_CUDA(cudaStreamSynchronize(cuda_stream)); - __half h_result = __float2half(result_f); - CHECK_CUDA(cudaMemcpyAsync(out, &h_result, sizeof(__half), - cudaMemcpyHostToDevice, cuda_stream)); - break; - } - case INFINI_DTYPE_BF16: { - if (workspace_size < sizeof(float)) { - return INFINI_STATUS_INSUFFICIENT_WORKSPACE; - } - float *tmp_out = reinterpret_cast(workspace); - const __nv_bfloat16 *a_bf = reinterpret_cast(a); - const __nv_bfloat16 *b_bf = reinterpret_cast(b); - constexpr unsigned int BLOCK_SIZE = 256; - op::vdot::cuda::vdotKernel - <<<1, BLOCK_SIZE, 0, cuda_stream>>>(tmp_out, a_bf, b_bf, _length, - _a_stride, _b_stride); - CHECK_CUDA(cudaGetLastError()); - float result_f; - CHECK_CUDA(cudaMemcpyAsync(&result_f, tmp_out, sizeof(float), - cudaMemcpyDeviceToHost, cuda_stream)); - CHECK_CUDA(cudaStreamSynchronize(cuda_stream)); - __nv_bfloat16 bf_result = __float2bfloat16(result_f); - CHECK_CUDA(cudaMemcpyAsync(out, &bf_result, sizeof(__nv_bfloat16), - cudaMemcpyHostToDevice, cuda_stream)); - break; + auto cuda_stream = reinterpret_cast(stream); + + // For FP16/BF16, use CUDA kernel instead of cuBLAS + if (_in_dtype == INFINI_DTYPE_F16 || _in_dtype == INFINI_DTYPE_BF16) { + switch (_in_dtype) { + case INFINI_DTYPE_F16: { + if (workspace_size < sizeof(float)) { + return INFINI_STATUS_INSUFFICIENT_WORKSPACE; + } + float *tmp_out = reinterpret_cast(workspace); + const __half *a_h = reinterpret_cast(a); + const __half *b_h = reinterpret_cast(b); + constexpr unsigned int BLOCK_SIZE = 256; + op::vdot::cuda::vdotKernel + <<<1, BLOCK_SIZE, 0, cuda_stream>>>(tmp_out, a_h, b_h, _length, + _a_stride, _b_stride); + CHECK_CUDA(cudaGetLastError()); + float result_f; + CHECK_CUDA(cudaMemcpyAsync(&result_f, tmp_out, sizeof(float), + cudaMemcpyDeviceToHost, cuda_stream)); + CHECK_CUDA(cudaStreamSynchronize(cuda_stream)); + __half h_result = __float2half(result_f); + CHECK_CUDA(cudaMemcpyAsync(out, &h_result, sizeof(__half), + cudaMemcpyHostToDevice, cuda_stream)); + break; + } + case INFINI_DTYPE_BF16: { + if (workspace_size < sizeof(float)) { + return INFINI_STATUS_INSUFFICIENT_WORKSPACE; + } + float *tmp_out = reinterpret_cast(workspace); + const __nv_bfloat16 *a_bf = reinterpret_cast(a); + const __nv_bfloat16 *b_bf = reinterpret_cast(b); + constexpr unsigned int BLOCK_SIZE = 256; + op::vdot::cuda::vdotKernel + <<<1, BLOCK_SIZE, 0, cuda_stream>>>(tmp_out, a_bf, b_bf, _length, + _a_stride, _b_stride); + CHECK_CUDA(cudaGetLastError()); + float result_f; + CHECK_CUDA(cudaMemcpyAsync(&result_f, tmp_out, sizeof(float), + cudaMemcpyDeviceToHost, cuda_stream)); + CHECK_CUDA(cudaStreamSynchronize(cuda_stream)); + __nv_bfloat16 bf_result = __float2bfloat16(result_f); + CHECK_CUDA(cudaMemcpyAsync(out, &bf_result, sizeof(__nv_bfloat16), + cudaMemcpyHostToDevice, cuda_stream)); + break; + } + default: + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + return INFINI_STATUS_SUCCESS; } - default: - return INFINI_STATUS_BAD_TENSOR_DTYPE; - } - return INFINI_STATUS_SUCCESS; - } - - // Use cuBLAS for F32/F64 - CHECK_STATUS(_internal->useCublas(cuda_stream, [&](cublasHandle_t handle) { - switch (_in_dtype) { - case INFINI_DTYPE_F32: { - if (_a_stride == 1 && _b_stride == 1) { - // Contiguous case: use cublasSdot - float result; - CHECK_CUBLAS(cublasSdot(handle, static_cast(_length), - reinterpret_cast(a), - static_cast(_a_stride), - 
reinterpret_cast(b), - static_cast(_b_stride), &result)); - CHECK_CUDA(cudaMemcpyAsync(out, &result, sizeof(float), - cudaMemcpyHostToDevice, cuda_stream)); - } else { - // Strided case: use cublasDotEx - float result; - CHECK_CUBLAS(cublasDotEx(handle, static_cast(_length), - reinterpret_cast(a), CUDA_R_32F, - static_cast(_a_stride), - reinterpret_cast(b), CUDA_R_32F, - static_cast(_b_stride), &result, - CUDA_R_32F, CUDA_R_32F)); - CHECK_CUDA(cudaMemcpyAsync(out, &result, sizeof(float), - cudaMemcpyHostToDevice, cuda_stream)); - } - break; - } - case INFINI_DTYPE_F64: { - if (_a_stride == 1 && _b_stride == 1) { - // Contiguous case: use cublasDdot - double result; - CHECK_CUBLAS(cublasDdot(handle, static_cast(_length), - reinterpret_cast(a), - static_cast(_a_stride), - reinterpret_cast(b), - static_cast(_b_stride), &result)); - CHECK_CUDA(cudaMemcpyAsync(out, &result, sizeof(double), - cudaMemcpyHostToDevice, cuda_stream)); - } else { - // Strided case: use cublasDotEx - double result; - CHECK_CUBLAS(cublasDotEx(handle, static_cast(_length), - reinterpret_cast(a), - CUDA_R_64F, static_cast(_a_stride), - reinterpret_cast(b), - CUDA_R_64F, static_cast(_b_stride), - &result, CUDA_R_64F, CUDA_R_64F)); - CHECK_CUDA(cudaMemcpyAsync(out, &result, sizeof(double), - cudaMemcpyHostToDevice, cuda_stream)); - } - break; - } - default: - return INFINI_STATUS_BAD_TENSOR_DTYPE; - } - return INFINI_STATUS_SUCCESS; - })); - return INFINI_STATUS_SUCCESS; + // Use cuBLAS for F32/F64 + CHECK_STATUS(_internal->useCublas(cuda_stream, [&](cublasHandle_t handle) { + switch (_in_dtype) { + case INFINI_DTYPE_F32: { + if (_a_stride == 1 && _b_stride == 1) { + // Contiguous case: use cublasSdot + float result; + CHECK_CUBLAS(cublasSdot(handle, static_cast(_length), + reinterpret_cast(a), + static_cast(_a_stride), + reinterpret_cast(b), + static_cast(_b_stride), &result)); + CHECK_CUDA(cudaMemcpyAsync(out, &result, sizeof(float), + cudaMemcpyHostToDevice, cuda_stream)); + } else { + // Strided case: use cublasDotEx + float result; + CHECK_CUBLAS(cublasDotEx(handle, static_cast(_length), + reinterpret_cast(a), CUDA_R_32F, + static_cast(_a_stride), + reinterpret_cast(b), CUDA_R_32F, + static_cast(_b_stride), &result, + CUDA_R_32F, CUDA_R_32F)); + CHECK_CUDA(cudaMemcpyAsync(out, &result, sizeof(float), + cudaMemcpyHostToDevice, cuda_stream)); + } + break; + } + case INFINI_DTYPE_F64: { + if (_a_stride == 1 && _b_stride == 1) { + // Contiguous case: use cublasDdot + double result; + CHECK_CUBLAS(cublasDdot(handle, static_cast(_length), + reinterpret_cast(a), + static_cast(_a_stride), + reinterpret_cast(b), + static_cast(_b_stride), &result)); + CHECK_CUDA(cudaMemcpyAsync(out, &result, sizeof(double), + cudaMemcpyHostToDevice, cuda_stream)); + } else { + // Strided case: use cublasDotEx + double result; + CHECK_CUBLAS(cublasDotEx(handle, static_cast(_length), + reinterpret_cast(a), + CUDA_R_64F, static_cast(_a_stride), + reinterpret_cast(b), + CUDA_R_64F, static_cast(_b_stride), + &result, CUDA_R_64F, CUDA_R_64F)); + CHECK_CUDA(cudaMemcpyAsync(out, &result, sizeof(double), + cudaMemcpyHostToDevice, cuda_stream)); + } + break; + } + default: + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + return INFINI_STATUS_SUCCESS; + })); + + return INFINI_STATUS_SUCCESS; } } // namespace op::vdot::nvidia diff --git a/src/infiniop/ops/vdot/nvidia/vdot_nvidia.cuh b/src/infiniop/ops/vdot/nvidia/vdot_nvidia.cuh index 9c25fc140..d4d082705 100644 --- a/src/infiniop/ops/vdot/nvidia/vdot_nvidia.cuh +++ 
b/src/infiniop/ops/vdot/nvidia/vdot_nvidia.cuh @@ -8,38 +8,38 @@ namespace op::vdot::nvidia { class Descriptor final : public InfiniopDescriptor { - infiniDtype_t _in_dtype; - infiniDtype_t _out_dtype; - size_t _length; - ptrdiff_t _a_stride; - ptrdiff_t _b_stride; - std::shared_ptr _internal; + infiniDtype_t _in_dtype; + infiniDtype_t _out_dtype; + size_t _length; + ptrdiff_t _a_stride; + ptrdiff_t _b_stride; + std::shared_ptr _internal; public: - Descriptor(infiniDtype_t in_dtype, infiniDtype_t out_dtype, size_t length, - ptrdiff_t a_stride, ptrdiff_t b_stride, - std::shared_ptr internal, - infiniDevice_t device_type, int device_id) - : InfiniopDescriptor{device_type, device_id}, _in_dtype(in_dtype), - _out_dtype(out_dtype), _length(length), _a_stride(a_stride), - _b_stride(b_stride), _internal(internal) {} - - ~Descriptor(); - - size_t workspaceSize() const { - // Need workspace for FP16/BF16 to accumulate in float - return (_in_dtype == INFINI_DTYPE_F16 || _in_dtype == INFINI_DTYPE_BF16) - ? sizeof(float) - : 0; - } - - static infiniStatus_t create(infiniopHandle_t handle, Descriptor **desc_ptr, - infiniopTensorDescriptor_t out_desc, - infiniopTensorDescriptor_t a_desc, - infiniopTensorDescriptor_t b_desc); - - infiniStatus_t calculate(void *workspace, size_t workspace_size, void *out, - const void *a, const void *b, void *stream) const; + Descriptor(infiniDtype_t in_dtype, infiniDtype_t out_dtype, size_t length, + ptrdiff_t a_stride, ptrdiff_t b_stride, + std::shared_ptr internal, + infiniDevice_t device_type, int device_id) + : InfiniopDescriptor{device_type, device_id}, _in_dtype(in_dtype), + _out_dtype(out_dtype), _length(length), _a_stride(a_stride), + _b_stride(b_stride), _internal(internal) {} + + ~Descriptor(); + + size_t workspaceSize() const { + // Need workspace for FP16/BF16 to accumulate in float + return (_in_dtype == INFINI_DTYPE_F16 || _in_dtype == INFINI_DTYPE_BF16) + ? 
sizeof(float) + : 0; + } + + static infiniStatus_t create(infiniopHandle_t handle, Descriptor **desc_ptr, + infiniopTensorDescriptor_t out_desc, + infiniopTensorDescriptor_t a_desc, + infiniopTensorDescriptor_t b_desc); + + infiniStatus_t calculate(void *workspace, size_t workspace_size, void *out, + const void *a, const void *b, void *stream) const; }; } // namespace op::vdot::nvidia diff --git a/src/infiniop/ops/vdot/operator.cc b/src/infiniop/ops/vdot/operator.cc index bbf3b088c..ca25c18a4 100644 --- a/src/infiniop/ops/vdot/operator.cc +++ b/src/infiniop/ops/vdot/operator.cc @@ -22,12 +22,12 @@ infiniopCreateVdotDescriptor( infiniopTensorDescriptor_t a_desc, infiniopTensorDescriptor_t b_desc) { -#define CREATE(CASE, NAMESPACE) \ - case CASE: \ - return op::vdot::NAMESPACE::Descriptor::create( \ - handle, \ - reinterpret_cast<op::vdot::NAMESPACE::Descriptor **>(desc_ptr), \ - out_desc, \ +#define CREATE(CASE, NAMESPACE) \ + case CASE: \ + return op::vdot::NAMESPACE::Descriptor::create( \ + handle, \ + reinterpret_cast<op::vdot::NAMESPACE::Descriptor **>(desc_ptr), \ + out_desc, \ a_desc, b_desc) switch (handle->device) { @@ -60,8 +60,8 @@ __C infiniStatus_t infiniopGetVdotWorkspaceSize( infiniopVdotDescriptor_t desc, size_t *size) { -#define GET(CASE, NAMESPACE) \ - case CASE: \ +#define GET(CASE, NAMESPACE) \ + case CASE: \ *size = reinterpret_cast<op::vdot::NAMESPACE::Descriptor *>(desc)->workspaceSize(); \ return INFINI_STATUS_SUCCESS; @@ -133,9 +133,9 @@ __C infiniStatus_t infiniopVdot( __C infiniStatus_t infiniopDestroyVdotDescriptor(infiniopVdotDescriptor_t desc) { -#define DELETE(CASE, NAMESPACE) \ - case CASE: \ - delete reinterpret_cast<op::vdot::NAMESPACE::Descriptor *>(desc); \ +#define DELETE(CASE, NAMESPACE) \ + case CASE: \ + delete reinterpret_cast<op::vdot::NAMESPACE::Descriptor *>(desc); \ return INFINI_STATUS_SUCCESS; switch (desc->device_type) { @@ -163,5 +163,3 @@ __C infiniStatus_t infiniopDestroyVdotDescriptor(infiniopVdotDescriptor_t desc) #undef DELETE } - - diff --git a/src/infiniop/ops/where/cpu/where_cpu.cc b/src/infiniop/ops/where/cpu/where_cpu.cc index 26befdaf3..3bf58e55f 100644 --- a/src/infiniop/ops/where/cpu/where_cpu.cc +++ b/src/infiniop/ops/where/cpu/where_cpu.cc @@ -88,5 +88,3 @@ infiniStatus_t Descriptor::calculate( } } // namespace op::where::cpu - - diff --git a/src/infiniop/ops/where/cpu/where_cpu.h b/src/infiniop/ops/where/cpu/where_cpu.h index 6b4399a7c..c706e1edb 100644 --- a/src/infiniop/ops/where/cpu/where_cpu.h +++ b/src/infiniop/ops/where/cpu/where_cpu.h @@ -29,5 +29,3 @@ struct WhereOp { } // namespace op::where::cpu #endif // __WHERE_CPU_H__ - - diff --git a/src/infiniop/ops/where/cpu/where_indices_cpu.cc b/src/infiniop/ops/where/cpu/where_indices_cpu.cc new file mode 100644 index 000000000..864dab738 --- /dev/null +++ b/src/infiniop/ops/where/cpu/where_indices_cpu.cc @@ -0,0 +1,86 @@ +#include "where_indices_cpu.h" +#include <cstdint> +#include <functional> +#include <vector> + +namespace op::where::cpu { + +infiniStatus_t IndicesDescriptor::create( + infiniopHandle_t handle_, + IndicesDescriptor **desc_ptr, + infiniopTensorDescriptor_t cond_desc) { + + // The condition must be a bool tensor + if (cond_desc->dtype() != INFINI_DTYPE_BOOL) { + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + + size_t numel = cond_desc->numel(); + int ndim = static_cast<int>(cond_desc->ndim()); + + std::vector<size_t> shape(ndim); + std::vector<ptrdiff_t> strides(ndim); + for (int i = 0; i < ndim; ++i) { + shape[i] = cond_desc->shape()[i]; + strides[i] = cond_desc->stride(i); + } + + *desc_ptr = new IndicesDescriptor( + numel, ndim, shape.data(), strides.data(), + INFINI_DEVICE_CPU, 0); + + return INFINI_STATUS_SUCCESS; +} + +infiniStatus_t IndicesDescriptor::calculate( + void *workspace, + size_t workspace_size, + void **outputs, + const void *cond, + void *stream, + size_t *num_true) const { + + const bool *cond_ptr = reinterpret_cast<const bool *>(cond); + int64_t **output_ptrs = new int64_t *[_ndim]; + for (int i = 0; i < _ndim; ++i) { + output_ptrs[i] = reinterpret_cast<int64_t *>(outputs[i]); + } + + // Recursively walk every multi-dimensional index so strided tensors are handled correctly + std::vector<size_t> indices(_ndim, 0); + size_t output_idx = 0; + + // Recursive helper that enumerates all multi-dimensional indices + std::function<void(int)> traverse = [&](int dim) { + if (dim == _ndim) { + // Compute the memory offset (taking strides into account) + size_t offset = 0; + for (int i = 0; i < _ndim; ++i) { + offset += indices[i] * static_cast<size_t>(_strides[i]); + } + + // Check whether the condition is True + if (cond_ptr[offset]) { + // Record the multi-dimensional index + for (int i = 0; i < _ndim; ++i) { + output_ptrs[i][output_idx] = static_cast<int64_t>(indices[i]); + } + output_idx++; + } + } else { + // Recurse over every value of the current dimension + for (size_t i = 0; i < _shape[dim]; ++i) { + indices[dim] = i; + traverse(dim + 1); + } + } + }; + + traverse(0); + + *num_true = output_idx; + delete[] output_ptrs; + return INFINI_STATUS_SUCCESS; +} + +} // namespace op::where::cpu
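The CPU backend above walks every multi-dimensional index with a recursive lambda and applies the element strides when reading the condition buffer, so non-contiguous tensors are handled correctly. The same traversal as a short Python sketch (names are illustrative, not part of the codebase):

# Python sketch of the CPU traversal: enumerate every multi-index of a
# strided bool buffer and append the coordinates of each True element,
# exactly as IndicesDescriptor::calculate does.
def where_indices_cpu(flat, shape, strides):
    ndim = len(shape)
    outputs = [[] for _ in range(ndim)]
    idx = [0] * ndim

    def traverse(dim):
        if dim == ndim:
            offset = sum(i * s for i, s in zip(idx, strides))
            if flat[offset]:
                for d in range(ndim):
                    outputs[d].append(idx[d])
            return
        for i in range(shape[dim]):
            idx[dim] = i
            traverse(dim + 1)

    traverse(0)
    return outputs  # one index list per dimension; each has num_true entries

# 2x2 row-major example with element strides (2, 1): True at (0, 0) and (1, 1)
print(where_indices_cpu([True, False, False, True], (2, 2), (2, 1)))  # [[0, 1], [0, 1]]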
diff --git a/src/infiniop/ops/where/cpu/where_indices_cpu.h b/src/infiniop/ops/where/cpu/where_indices_cpu.h
new file mode 100644
index 000000000..3712533f7
--- /dev/null
+++ b/src/infiniop/ops/where/cpu/where_indices_cpu.h
@@ -0,0 +1,57 @@
+#ifndef __WHERE_INDICES_CPU_H__
+#define __WHERE_INDICES_CPU_H__
+
+#include "../../../operator.h"
+#include "../../../tensor.h"
+
+namespace op::where::cpu {
+
+class IndicesDescriptor final : public InfiniopDescriptor {
+    size_t _numel;
+    int _ndim;
+    size_t *_shape;
+    ptrdiff_t *_strides;
+
+public:
+    IndicesDescriptor(
+        size_t numel,
+        int ndim,
+        const size_t *shape,
+        const ptrdiff_t *strides,
+        infiniDevice_t device_type,
+        int device_id)
+        : InfiniopDescriptor{device_type, device_id},
+          _numel(numel),
+          _ndim(ndim) {
+        _shape = new size_t[ndim];
+        _strides = new ptrdiff_t[ndim];
+        for (int i = 0; i < ndim; ++i) {
+            _shape[i] = shape[i];
+            _strides[i] = strides[i];
+        }
+    }
+
+    ~IndicesDescriptor() {
+        delete[] _shape;
+        delete[] _strides;
+    }
+
+    size_t workspaceSize() const { return 0; }
+
+    static infiniStatus_t create(
+        infiniopHandle_t handle,
+        IndicesDescriptor **desc_ptr,
+        infiniopTensorDescriptor_t cond_desc);
+
+    infiniStatus_t calculate(
+        void *workspace,
+        size_t workspace_size,
+        void **outputs,
+        const void *cond,
+        void *stream,
+        size_t *num_true) const;
+};
+
+} // namespace op::where::cpu
+
+#endif // __WHERE_INDICES_CPU_H__
diff --git a/src/infiniop/ops/where/cuda/where_indices_kernel.cuh b/src/infiniop/ops/where/cuda/where_indices_kernel.cuh
new file mode 100644
index 000000000..f00763153
--- /dev/null
+++ b/src/infiniop/ops/where/cuda/where_indices_kernel.cuh
@@ -0,0 +1,74 @@
+#ifndef __WHERE_INDICES_KERNEL_CUH__
+#define __WHERE_INDICES_KERNEL_CUH__
+
+#include <cstddef>
+#include <cuda_runtime.h>
+
+namespace op::where::cuda {
+
+// Stage 1: mark true elements (convert bool to int64_t: 1 or 0).
+// Strided tensors are supported: the linear index is converted into a
+// multi-dimensional index, and the strides give the memory offset.
+template <typename Tidx>
+__global__ void markTrueElements(
+    Tidx *flags,              // out: 1/0 flag per element
+    const bool *cond,         // in: condition tensor
+    const size_t *shape,      // in: tensor shape
+    const ptrdiff_t *strides, // in: tensor strides
+    size_t numel,
+    int ndim) {
+    size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
+    if (idx < numel) {
+        // Convert the linear index into a multi-dimensional index.
+        size_t remaining = idx;
+        size_t offset = 0;
+        for (int dim = ndim - 1; dim >= 0; --dim) {
+            size_t dim_idx = remaining % shape[dim];
+            offset += dim_idx * static_cast<size_t>(strides[dim]);
+            remaining /= shape[dim];
+        }
+        flags[idx] = cond[offset] ? static_cast<Tidx>(1) : static_cast<Tidx>(0);
+    }
+}
+
+// Stage 2: gather the indices for each dimension.
+// An N-dimensional tensor needs one index array per dimension.
+template <typename Tidx>
+__global__ void collectIndices(
+    Tidx **outputs,           // out: device array of NDIM index-tensor pointers
+    const Tidx *flags,        // in: flag array (after the prefix sum)
+    const bool *cond,         // in: condition tensor (unused, may be nullptr)
+    const size_t *shape,      // in: tensor shape (on the device)
+    const ptrdiff_t *strides, // in: tensor strides (on the device, unused in this kernel)
+    size_t numel,
+    int ndim) {
+    (void)cond;
+    (void)strides;
+
+    size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
+    if (idx < numel) {
+        // flags tells us whether this linear index holds a true element:
+        // for a true element, flags[idx] is one larger than the value before it.
+        Tidx curr = flags[idx];
+        Tidx prev = (idx == 0) ? static_cast<Tidx>(0) : flags[idx - 1];
+        if (curr == prev) {
+            return;
+        }
+
+        // Locate this element's output slot from the prefix-sum result:
+        // flags holds an inclusive sum, so flags[idx] - 1 is the output position.
+        // For the first element (idx == 0), a true value gives flags[0] = 1, i.e. output_idx = 0.
+        Tidx output_idx = curr - 1;
+
+        // Linear index -> multi-dimensional index.
+        size_t remaining = idx;
+        for (int dim = ndim - 1; dim >= 0; --dim) {
+            size_t dim_idx = remaining % shape[dim];
+            outputs[dim][output_idx] = static_cast<Tidx>(dim_idx);
+            remaining /= shape[dim];
+        }
+    }
+}
+
+} // namespace op::where::cuda
+
+#endif // __WHERE_INDICES_KERNEL_CUH__
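[Editor's note] A quick worked example of the two-kernel compaction scheme above (a sketch, not part of the patch): an inclusive prefix sum over the 0/1 flags turns them into one-past-the-end output positions, so each true element lands at flags[i] - 1 and the last flag equals num_true.

#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
    // cond = {false, true, true, false, true}  ->  flags = {0, 1, 1, 0, 1}
    std::vector<int64_t> flags = {0, 1, 1, 0, 1};
    // Inclusive prefix sum (what cub::DeviceScan::InclusiveSum computes on device):
    // flags becomes {0, 1, 2, 2, 3}; the final entry is the true count.
    for (size_t i = 1; i < flags.size(); ++i) {
        flags[i] += flags[i - 1];
    }
    for (size_t i = 0; i < flags.size(); ++i) {
        int64_t prev = (i == 0) ? 0 : flags[i - 1];
        if (flags[i] != prev) { // element i was true
            std::printf("input %zu -> output slot %lld\n", i,
                        static_cast<long long>(flags[i] - 1));
        }
    }
    return 0; // num_true == flags.back() == 3
}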
diff --git a/src/infiniop/ops/where/metax/where_indices_metax.cu b/src/infiniop/ops/where/metax/where_indices_metax.cu
new file mode 100644
index 000000000..d205ae5f9
--- /dev/null
+++ b/src/infiniop/ops/where/metax/where_indices_metax.cu
@@ -0,0 +1,139 @@
+#include "../../../devices/metax/metax_handle.h"
+#include "../../../devices/nvidia/nvidia_kernel_common.cuh"
+#include "../cuda/where_indices_kernel.cuh"
+#include "where_indices_metax.h"
+#include <cub/cub.cuh>
+#include <vector>
+
+namespace op::where::metax {
+
+// Thin wrapper around CUB's InclusiveSum.
+template <typename T>
+static hcError_t inclusiveSum(void *workspace_ptr, size_t &workspace_len,
+                              T *data, int n, hcStream_t stream) {
+    return cub::DeviceScan::InclusiveSum(workspace_ptr, workspace_len, data, data,
+                                         n, stream);
+}
+
+// Round sizes up to a 256-byte boundary.
+static constexpr size_t align256(size_t size) { return (size + 255) & (~255); }
+
+size_t IndicesDescriptor::workspaceSize() const {
+    const auto n = static_cast<int>(_numel);
+
+    // Size of the flags array.
+    size_t flags_size = align256(sizeof(int64_t) * _numel);
+
+    // CUB scan workspace.
+    size_t scan_workspace = 0;
+    int64_t *dummy = nullptr;
+    CHECK_CUDA(inclusiveSum(nullptr, scan_workspace, dummy, n, nullptr));
+
+    return flags_size + scan_workspace;
+}
+
+infiniStatus_t IndicesDescriptor::create(infiniopHandle_t handle_,
+                                         IndicesDescriptor **desc_ptr,
+                                         infiniopTensorDescriptor_t cond_desc) {
+
+    auto handle = reinterpret_cast<device::metax::Handle *>(handle_);
+
+    // The condition tensor must have dtype bool.
+    if (cond_desc->dtype() != INFINI_DTYPE_BOOL) {
+        return INFINI_STATUS_BAD_TENSOR_DTYPE;
+    }
+
+    size_t numel = cond_desc->numel();
+    int ndim = static_cast<int>(cond_desc->ndim());
+
+    std::vector<size_t> shape(ndim);
+    std::vector<ptrdiff_t> strides(ndim);
+    for (int i = 0; i < ndim; ++i) {
+        shape[i] = cond_desc->shape()[i];
+        strides[i] = cond_desc->stride(i);
+    }
+
+    *desc_ptr = new IndicesDescriptor(numel, ndim, shape.data(), strides.data(),
+                                      handle->device, handle->device_id);
+
+    return INFINI_STATUS_SUCCESS;
+}
+
+infiniStatus_t IndicesDescriptor::calculate(void *workspace,
+                                            size_t workspace_size,
+                                            void **outputs, const void *cond,
+                                            void *stream,
+                                            size_t *num_true) const {
+
+    auto cuda_stream = reinterpret_cast<cudaStream_t>(stream);
+    const bool *cond_ptr = reinterpret_cast<const bool *>(cond);
+
+    // Carve the buffers out of the workspace.
+    int64_t *flags = reinterpret_cast<int64_t *>(workspace);
+    size_t flags_size = align256(sizeof(int64_t) * _numel);
+    void *scan_workspace = static_cast<char *>(workspace) + flags_size;
+    size_t scan_workspace_size = workspace_size - flags_size;
+
+    // Copy shape and strides to the device (used by markTrueElements).
+    size_t *d_shape;
+    ptrdiff_t *d_strides;
+    CHECK_CUDA(cudaMallocAsync(&d_shape, sizeof(size_t) * _ndim, cuda_stream));
+    CHECK_CUDA(
+        cudaMallocAsync(&d_strides, sizeof(ptrdiff_t) * _ndim, cuda_stream));
+    CHECK_CUDA(cudaMemcpyAsync(d_shape, _shape, sizeof(size_t) * _ndim,
+                               cudaMemcpyHostToDevice, cuda_stream));
+    CHECK_CUDA(cudaMemcpyAsync(d_strides, _strides, sizeof(ptrdiff_t) * _ndim,
+                               cudaMemcpyHostToDevice, cuda_stream));
+
+    // Stage 1: mark the true elements.
+    constexpr int BLOCK_SIZE = 256;
+    int grid_size = static_cast<int>((_numel + BLOCK_SIZE - 1) / BLOCK_SIZE);
+    op::where::cuda::markTrueElements<int64_t>
+        <<<grid_size, BLOCK_SIZE, 0, cuda_stream>>>(flags, cond_ptr, d_shape,
+                                                    d_strides, _numel, _ndim);
+    CHECK_CUDA(cudaGetLastError());
+
+    // Stage 2: inclusive prefix sum.
+    size_t temp_workspace_size = scan_workspace_size;
+    CHECK_CUDA(inclusiveSum(scan_workspace, temp_workspace_size, flags,
+                            static_cast<int>(_numel), cuda_stream));
+
+    // Fetch the total number of true elements.
+    int64_t num_true_val = 0;
+    CHECK_CUDA(cudaMemcpyAsync(&num_true_val, flags + _numel - 1, sizeof(int64_t),
+                               cudaMemcpyDeviceToHost, cuda_stream));
+    CHECK_CUDA(cudaStreamSynchronize(cuda_stream));
+    *num_true = static_cast<size_t>(num_true_val);
+
+    // Stage 3: gather the indices for each dimension.
+    int64_t **output_ptrs = new int64_t *[_ndim];
+    for (int i = 0; i < _ndim; ++i) {
+        output_ptrs[i] = reinterpret_cast<int64_t *>(outputs[i]);
+    }
+
+    // d_shape / d_strides were already allocated and filled in stage 1; reuse them here.
+
+    // Copy output_ptrs to the device.
+    int64_t **d_output_ptrs;
+    CHECK_CUDA(
+        cudaMallocAsync(&d_output_ptrs, sizeof(int64_t *) * _ndim, cuda_stream));
+    CHECK_CUDA(cudaMemcpyAsync(d_output_ptrs, output_ptrs,
+                               sizeof(int64_t *) * _ndim, cudaMemcpyHostToDevice,
+                               cuda_stream));
+
+    // Launch the index-gathering kernel.
+    op::where::cuda::collectIndices<int64_t>
+        <<<grid_size, BLOCK_SIZE, 0, cuda_stream>>>(
+            d_output_ptrs, flags, cond_ptr, d_shape, d_strides, _numel, _ndim);
+    CHECK_CUDA(cudaGetLastError());
+
+    // Cleanup.
+    CHECK_CUDA(cudaFreeAsync(d_shape, cuda_stream));
+    CHECK_CUDA(cudaFreeAsync(d_strides, cuda_stream));
+    CHECK_CUDA(cudaFreeAsync(d_output_ptrs, cuda_stream));
+    delete[] output_ptrs;
+
+    return INFINI_STATUS_SUCCESS;
+}
+
+} // namespace op::where::metax
diff --git a/src/infiniop/ops/where/metax/where_indices_metax.h b/src/infiniop/ops/where/metax/where_indices_metax.h
new file mode 100644
index 000000000..42c606153
--- /dev/null
+++ b/src/infiniop/ops/where/metax/where_indices_metax.h
@@ -0,0 +1,58 @@
+#ifndef __WHERE_INDICES_METAX_H__
+#define __WHERE_INDICES_METAX_H__
+
+#include "../../../devices/metax/metax_handle.h"
+#include "../../../operator.h"
+#include "../../../tensor.h"
+
+namespace op::where::metax {
+
+class IndicesDescriptor final : public InfiniopDescriptor {
+    size_t _numel;
+    int _ndim;
+    size_t *_shape;
+    ptrdiff_t *_strides;
+
+public:
+    IndicesDescriptor(
+        size_t numel,
+        int ndim,
+        const size_t *shape,
+        const ptrdiff_t *strides,
+        infiniDevice_t device_type,
+        int device_id)
+        : InfiniopDescriptor{device_type, device_id},
+          _numel(numel),
+          _ndim(ndim) {
+        _shape = new size_t[ndim];
+        _strides = new ptrdiff_t[ndim];
+        for (int i = 0; i < ndim; ++i) {
+            _shape[i] = shape[i];
+            _strides[i] = strides[i];
+        }
+    }
+
+    ~IndicesDescriptor() {
+        delete[] _shape;
+        delete[] _strides;
+    }
+
+    size_t workspaceSize() const;
+
+    static infiniStatus_t create(
+        infiniopHandle_t handle,
+        IndicesDescriptor **desc_ptr,
+        infiniopTensorDescriptor_t cond_desc);
+
+    infiniStatus_t calculate(
+        void *workspace,
+        size_t workspace_size,
+        void **outputs,
+        const void *cond,
+        void *stream,
+        size_t *num_true) const;
+};
+
+} // namespace op::where::metax
+
+#endif // __WHERE_INDICES_METAX_H__
diff --git a/src/infiniop/ops/where/moore/where_indices_moore.cu b/src/infiniop/ops/where/moore/where_indices_moore.cu
new file mode 100644
index 000000000..b35bb1b9e
--- /dev/null
+++ b/src/infiniop/ops/where/moore/where_indices_moore.cu
@@ -0,0 +1,153 @@
+#include "../../../devices/moore/moore_handle.h"
+#include "../../../devices/nvidia/nvidia_kernel_common.cuh"
+#include "../cuda/where_indices_kernel.cuh"
+#include "where_indices_moore.h"
+#include <cub/cub.cuh>
+#include <vector>
+
+namespace op::where::moore {
+
+// Thin wrapper around CUB's InclusiveSum.
+template <typename T>
+static musaError_t inclusiveSum(
+    void *workspace_ptr, size_t &workspace_len,
+    T *data, int n,
+    musaStream_t stream) {
+    return cub::DeviceScan::InclusiveSum(
+        workspace_ptr, workspace_len,
+        data, data, n,
+        stream);
+}
+
+// Round sizes up to a 256-byte boundary.
+static constexpr size_t align256(size_t size) {
+    return (size + 255) & (~255);
+}
+
+size_t IndicesDescriptor::workspaceSize() const {
+    const auto n = static_cast<int>(_numel);
+
+    // Size of the flags array.
+    size_t flags_size = align256(sizeof(int64_t) * _numel);
+
+    // CUB scan workspace.
+    size_t scan_workspace = 0;
+    int64_t *dummy = nullptr;
+    CHECK_CUDA(inclusiveSum(
+        nullptr, scan_workspace,
+        dummy, n,
+        nullptr));
+
+    return flags_size + scan_workspace;
+}
+
+infiniStatus_t IndicesDescriptor::create(
+    infiniopHandle_t handle_,
+    IndicesDescriptor **desc_ptr,
+    infiniopTensorDescriptor_t cond_desc) {
+
+    auto handle = reinterpret_cast<device::moore::Handle *>(handle_);
+
+    // The condition tensor must have dtype bool.
+    if (cond_desc->dtype() != INFINI_DTYPE_BOOL) {
+        return INFINI_STATUS_BAD_TENSOR_DTYPE;
+    }
+
+    size_t numel = cond_desc->numel();
+    int ndim = static_cast<int>(cond_desc->ndim());
+
+    std::vector<size_t> shape(ndim);
+    std::vector<ptrdiff_t> strides(ndim);
+    for (int i = 0; i < ndim; ++i) {
+        shape[i] = cond_desc->shape()[i];
+        strides[i] = cond_desc->stride(i);
+    }
+
+    *desc_ptr = new IndicesDescriptor(
+        numel, ndim, shape.data(), strides.data(),
+        handle->device, handle->device_id);
+
+    return INFINI_STATUS_SUCCESS;
+}
+
+infiniStatus_t IndicesDescriptor::calculate(
+    void *workspace,
+    size_t workspace_size,
+    void **outputs,
+    const void *cond,
+    void *stream,
+    size_t *num_true) const {
+
+    auto cuda_stream = reinterpret_cast<cudaStream_t>(stream);
+    const bool *cond_ptr = reinterpret_cast<const bool *>(cond);
+
+    // Carve the buffers out of the workspace.
+    int64_t *flags = reinterpret_cast<int64_t *>(workspace);
+    size_t flags_size = align256(sizeof(int64_t) * _numel);
+    void *scan_workspace = static_cast<char *>(workspace) + flags_size;
+    size_t scan_workspace_size = workspace_size - flags_size;
+
+    // Copy shape and strides to the device (used by markTrueElements).
+    size_t *d_shape;
+    ptrdiff_t *d_strides;
+    CHECK_CUDA(cudaMallocAsync(&d_shape, sizeof(size_t) * _ndim, cuda_stream));
+    CHECK_CUDA(cudaMallocAsync(&d_strides, sizeof(ptrdiff_t) * _ndim, cuda_stream));
+    CHECK_CUDA(cudaMemcpyAsync(
+        d_shape, _shape, sizeof(size_t) * _ndim,
+        cudaMemcpyHostToDevice, cuda_stream));
+    CHECK_CUDA(cudaMemcpyAsync(
+        d_strides, _strides, sizeof(ptrdiff_t) * _ndim,
+        cudaMemcpyHostToDevice, cuda_stream));
+
+    // Stage 1: mark the true elements.
+    constexpr int BLOCK_SIZE = 256;
+    int grid_size = static_cast<int>((_numel + BLOCK_SIZE - 1) / BLOCK_SIZE);
+    op::where::cuda::markTrueElements<int64_t><<<grid_size, BLOCK_SIZE, 0, cuda_stream>>>(
+        flags, cond_ptr, d_shape, d_strides, _numel, _ndim);
+    CHECK_CUDA(cudaGetLastError());
+
+    // Stage 2: inclusive prefix sum.
+    size_t temp_workspace_size = scan_workspace_size;
+    CHECK_CUDA(inclusiveSum(
+        scan_workspace, temp_workspace_size,
+        flags, static_cast<int>(_numel),
+        cuda_stream));
+
+    // Fetch the total number of true elements.
+    int64_t num_true_val = 0;
+    CHECK_CUDA(cudaMemcpyAsync(
+        &num_true_val, flags + _numel - 1, sizeof(int64_t),
+        cudaMemcpyDeviceToHost, cuda_stream));
+    CHECK_CUDA(cudaStreamSynchronize(cuda_stream));
+    *num_true = static_cast<size_t>(num_true_val);
+
+    // Stage 3: gather the indices for each dimension.
+    int64_t **output_ptrs = new int64_t *[_ndim];
+    for (int i = 0; i < _ndim; ++i) {
+        output_ptrs[i] = reinterpret_cast<int64_t *>(outputs[i]);
+    }
+
+    // d_shape was already allocated and filled in stage 1; reuse it here.
+
+    // Copy output_ptrs to the device.
+    int64_t **d_output_ptrs;
+    CHECK_CUDA(cudaMallocAsync(&d_output_ptrs, sizeof(int64_t *) * _ndim, cuda_stream));
+    CHECK_CUDA(cudaMemcpyAsync(
+        d_output_ptrs, output_ptrs, sizeof(int64_t *) * _ndim,
+        cudaMemcpyHostToDevice, cuda_stream));
+
+    // Launch the index-gathering kernel.
+    op::where::cuda::collectIndices<int64_t><<<grid_size, BLOCK_SIZE, 0, cuda_stream>>>(
+        d_output_ptrs, flags, cond_ptr, d_shape, d_strides, _numel, _ndim);
+    CHECK_CUDA(cudaGetLastError());
+
+    // Cleanup.
+    CHECK_CUDA(cudaFreeAsync(d_shape, cuda_stream));
+    CHECK_CUDA(cudaFreeAsync(d_strides, cuda_stream));
+    CHECK_CUDA(cudaFreeAsync(d_output_ptrs, cuda_stream));
+    delete[] output_ptrs;
+
+    return INFINI_STATUS_SUCCESS;
+}
+
+} // namespace op::where::moore
diff --git a/src/infiniop/ops/where/moore/where_indices_moore.h b/src/infiniop/ops/where/moore/where_indices_moore.h
new file mode 100644
index 000000000..ccb126f31
--- /dev/null
+++ b/src/infiniop/ops/where/moore/where_indices_moore.h
@@ -0,0 +1,58 @@
+#ifndef __WHERE_INDICES_MOORE_H__
+#define __WHERE_INDICES_MOORE_H__
+
+#include "../../../devices/moore/moore_handle.h"
+#include "../../../operator.h"
+#include "../../../tensor.h"
+
+namespace op::where::moore {
+
+class IndicesDescriptor final : public InfiniopDescriptor {
+    size_t _numel;
+    int _ndim;
+    size_t *_shape;
+    ptrdiff_t *_strides;
+
+public:
+    IndicesDescriptor(
+        size_t numel,
+        int ndim,
+        const size_t *shape,
+        const ptrdiff_t *strides,
+        infiniDevice_t device_type,
+        int device_id)
+        : InfiniopDescriptor{device_type, device_id},
+          _numel(numel),
+          _ndim(ndim) {
+        _shape = new size_t[ndim];
+        _strides = new ptrdiff_t[ndim];
+        for (int i = 0; i < ndim; ++i) {
+            _shape[i] = shape[i];
+            _strides[i] = strides[i];
+        }
+    }
+
+    ~IndicesDescriptor() {
+        delete[] _shape;
+        delete[] _strides;
+    }
+
+    size_t workspaceSize() const;
+
+    static infiniStatus_t create(
+        infiniopHandle_t handle,
+        IndicesDescriptor **desc_ptr,
+        infiniopTensorDescriptor_t cond_desc);
+
+    infiniStatus_t calculate(
+        void *workspace,
+        size_t workspace_size,
+        void **outputs,
+        const void *cond,
+        void *stream,
+        size_t *num_true) const;
+};
+
+} // namespace op::where::moore
+
+#endif // __WHERE_INDICES_MOORE_H__
diff --git a/src/infiniop/ops/where/nvidia/where_indices_nvidia.cu b/src/infiniop/ops/where/nvidia/where_indices_nvidia.cu
new file mode 100644
index 000000000..eaa410f94
--- /dev/null
+++ b/src/infiniop/ops/where/nvidia/where_indices_nvidia.cu
@@ -0,0 +1,156 @@
+#include "../../../devices/nvidia/nvidia_handle.cuh"
+#include "../../../devices/nvidia/nvidia_kernel_common.cuh"
+#include "../cuda/where_indices_kernel.cuh"
+#include "where_indices_nvidia.h"
+#include <cstdint>
+#include <cub/cub.cuh>
+#include <vector>
+
+namespace op::where::nvidia {
+
+// Thin wrapper around CUB's InclusiveSum.
+template <typename T>
+static cudaError inclusiveSum(
+    void *workspace_ptr, size_t &workspace_len,
+    T *data, int n,
+    cudaStream_t stream) {
+    return cub::DeviceScan::InclusiveSum(
+        workspace_ptr, workspace_len,
+        data, data, n,
+        stream);
+}
+
+// Round sizes up to a 256-byte boundary.
+static constexpr size_t align256(size_t size) {
+    return (size + 255) & (~255);
+}
+
+size_t IndicesDescriptor::workspaceSize() const {
+    const auto n = static_cast<int>(_numel);
+
+    // Size of the flags array.
+    size_t flags_size = align256(sizeof(int64_t) * _numel);
+
+    // CUB scan workspace.
+    size_t scan_workspace = 0;
+    int64_t *dummy = nullptr;
+    CHECK_CUDA(inclusiveSum(
+        nullptr, scan_workspace,
+        dummy, n,
+        nullptr));
+
+    return flags_size + scan_workspace;
+}
+
+infiniStatus_t IndicesDescriptor::create(
+    infiniopHandle_t handle_,
+    IndicesDescriptor **desc_ptr,
+    infiniopTensorDescriptor_t cond_desc) {
+
+    auto handle = reinterpret_cast<device::nvidia::Handle *>(handle_);
+
+    // The condition tensor must have dtype bool.
+    if (cond_desc->dtype() != INFINI_DTYPE_BOOL) {
+        return INFINI_STATUS_BAD_TENSOR_DTYPE;
+    }
+
+    size_t numel = cond_desc->numel();
+    int ndim = static_cast<int>(cond_desc->ndim());
+
+    std::vector<size_t> shape(ndim);
+    std::vector<ptrdiff_t> strides(ndim);
+    for (int i = 0; i < ndim; ++i) {
+        shape[i] = cond_desc->shape()[i];
+        strides[i] = cond_desc->stride(i);
+    }
+
+    *desc_ptr = new IndicesDescriptor(
+        numel, ndim, shape.data(), strides.data(),
+        handle->internal(),
+        handle->device, handle->device_id);
+
+    return INFINI_STATUS_SUCCESS;
+}
+
+infiniStatus_t IndicesDescriptor::calculate(
+    void *workspace,
+    size_t workspace_size,
+    void **outputs, // outputs[i] is the index tensor for dimension i
+    const void *cond,
+    void *stream,
+    size_t *num_true) const {
+
+    auto cuda_stream = reinterpret_cast<cudaStream_t>(stream);
+    const bool *cond_ptr = reinterpret_cast<const bool *>(cond);
+
+    // Carve the buffers out of the workspace.
+    int64_t *flags = reinterpret_cast<int64_t *>(workspace);
+    size_t flags_size = align256(sizeof(int64_t) * _numel);
+    void *scan_workspace = static_cast<char *>(workspace) + flags_size;
+    size_t scan_workspace_size = workspace_size - flags_size;
+
+    // Copy shape and strides to the device (used by markTrueElements / collectIndices).
+    size_t *d_shape;
+    ptrdiff_t *d_strides;
+    CHECK_CUDA(cudaMallocAsync(&d_shape, sizeof(size_t) * _ndim, cuda_stream));
+    CHECK_CUDA(cudaMallocAsync(&d_strides, sizeof(ptrdiff_t) * _ndim, cuda_stream));
+    CHECK_CUDA(cudaMemcpyAsync(
+        d_shape, _shape, sizeof(size_t) * _ndim,
+        cudaMemcpyHostToDevice, cuda_stream));
+    CHECK_CUDA(cudaMemcpyAsync(
+        d_strides, _strides, sizeof(ptrdiff_t) * _ndim,
+        cudaMemcpyHostToDevice, cuda_stream));
+
+    // Stage 1: mark the true elements.
+    constexpr int BLOCK_SIZE = 256;
+    int grid_size = static_cast<int>((_numel + BLOCK_SIZE - 1) / BLOCK_SIZE);
+    op::where::cuda::markTrueElements<int64_t><<<grid_size, BLOCK_SIZE, 0, cuda_stream>>>(
+        flags, cond_ptr, d_shape, d_strides, _numel, _ndim);
+    CHECK_CUDA(cudaGetLastError());
+
+    // Stage 2: inclusive prefix sum.
+    size_t temp_workspace_size = scan_workspace_size;
+    CHECK_CUDA(inclusiveSum(
+        scan_workspace, temp_workspace_size,
+        flags, static_cast<int>(_numel),
+        cuda_stream));
+
+    // Fetch the total number of true elements (the prefix-sum value of the last element).
+    int64_t num_true_val = 0;
+    CHECK_CUDA(cudaMemcpyAsync(
+        &num_true_val, flags + _numel - 1, sizeof(int64_t),
+        cudaMemcpyDeviceToHost, cuda_stream));
+    CHECK_CUDA(cudaStreamSynchronize(cuda_stream));
+    *num_true = static_cast<size_t>(num_true_val);
+
+    // Stage 3: gather the indices for each dimension.
+    // Convert outputs into an int64_t* array and copy it to the device.
+    int64_t **output_ptrs = new int64_t *[_ndim];
+    for (int i = 0; i < _ndim; ++i) {
+        output_ptrs[i] = reinterpret_cast<int64_t *>(outputs[i]);
+    }
+
+    // d_shape / d_strides were already allocated and filled in stage 1; reuse them here.
+
+    // Copy output_ptrs to the device.
+    int64_t **d_output_ptrs;
+    CHECK_CUDA(cudaMallocAsync(&d_output_ptrs, sizeof(int64_t *) * _ndim, cuda_stream));
+    CHECK_CUDA(cudaMemcpyAsync(
+        d_output_ptrs, output_ptrs, sizeof(int64_t *) * _ndim,
+        cudaMemcpyHostToDevice, cuda_stream));
+
+    // Launch the index-gathering kernel.
+    op::where::cuda::collectIndices<int64_t><<<grid_size, BLOCK_SIZE, 0, cuda_stream>>>(
+        d_output_ptrs, flags, cond_ptr, d_shape, d_strides, _numel, _ndim);
+    CHECK_CUDA(cudaGetLastError());
+
+    // Cleanup.
+    CHECK_CUDA(cudaFreeAsync(d_shape, cuda_stream));
+    CHECK_CUDA(cudaFreeAsync(d_strides, cuda_stream));
+    CHECK_CUDA(cudaFreeAsync(d_output_ptrs, cuda_stream));
+    delete[] output_ptrs;
+
+    return INFINI_STATUS_SUCCESS;
+}
+
+} // namespace op::where::nvidia
diff --git a/src/infiniop/ops/where/nvidia/where_indices_nvidia.h b/src/infiniop/ops/where/nvidia/where_indices_nvidia.h
new file mode 100644
index 000000000..a07f6e4b0
--- /dev/null
+++ b/src/infiniop/ops/where/nvidia/where_indices_nvidia.h
@@ -0,0 +1,61 @@
+#ifndef __WHERE_INDICES_NVIDIA_H__
+#define __WHERE_INDICES_NVIDIA_H__
+
+#include "../../../devices/nvidia/nvidia_handle.cuh"
+#include "../../../operator.h"
+#include "../../../tensor.h"
+
+namespace op::where::nvidia {
+
+class IndicesDescriptor final : public InfiniopDescriptor {
+    size_t _numel;
+    int _ndim;
+    size_t *_shape;
+    ptrdiff_t *_strides;
+    std::shared_ptr<device::nvidia::Handle::Internal> _internal;
+
+public:
+    IndicesDescriptor(
+        size_t numel,
+        int ndim,
+        const size_t *shape,
+        const ptrdiff_t *strides,
+        std::shared_ptr<device::nvidia::Handle::Internal> internal,
+        infiniDevice_t device_type,
+        int device_id)
+        : InfiniopDescriptor{device_type, device_id},
+          _numel(numel),
+          _ndim(ndim),
+          _internal(internal) {
+        _shape = new size_t[ndim];
+        _strides = new ptrdiff_t[ndim];
+        for (int i = 0; i < ndim; ++i) {
+            _shape[i] = shape[i];
+            _strides[i] = strides[i];
+        }
+    }
+
+    ~IndicesDescriptor() {
+        delete[] _shape;
+        delete[] _strides;
+    }
+
+    size_t workspaceSize() const;
+
+    static infiniStatus_t create(
+        infiniopHandle_t handle,
+        IndicesDescriptor **desc_ptr,
+        infiniopTensorDescriptor_t cond_desc);
+
+    infiniStatus_t calculate(
+        void *workspace,
+        size_t workspace_size,
+        void **outputs, // array of NDIM output-tensor pointers
+        const void *cond,
+        void *stream,
+        size_t *num_true) const; // out: the number of true elements
+};
+
+} // namespace op::where::nvidia
+
+#endif // __WHERE_INDICES_NVIDIA_H__
diff --git a/src/infiniop/ops/where/operator.cc b/src/infiniop/ops/where/operator.cc
index c966af409..da309a9f8 100644
--- a/src/infiniop/ops/where/operator.cc
+++ b/src/infiniop/ops/where/operator.cc
@@ -4,6 +4,16 @@
 
 #ifdef ENABLE_CPU_API
 #include "cpu/where_cpu.h"
+#include "cpu/where_indices_cpu.h"
+#endif
+#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API)
+#include "nvidia/where_indices_nvidia.h"
+#endif
+#ifdef ENABLE_METAX_API
+#include "metax/where_indices_metax.h"
+#endif
+#ifdef ENABLE_MOORE_API
+#include "moore/where_indices_moore.h"
 #endif
 
 __C infiniStatus_t infiniopCreateWhereDescriptor(
@@ -14,12 +24,12 @@ __C infiniStatus_t infiniopCreateWhereDescriptor(
     infiniopTensorDescriptor_t x_desc,
     infiniopTensorDescriptor_t y_desc) {
 
-#define CREATE(CASE, NAMESPACE)                                               \
-    case CASE:                                                                \
-        return op::where::NAMESPACE::Descriptor::create(                      \
-            handle,                                                           \
-            reinterpret_cast<op::where::NAMESPACE::Descriptor **>(desc_ptr),  \
-            out_desc,                                                         \
+#define CREATE(CASE, NAMESPACE)                                              \
+    case CASE:                                                               \
+        return op::where::NAMESPACE::Descriptor::create(                     \
+            handle,                                                          \
+            reinterpret_cast<op::where::NAMESPACE::Descriptor **>(desc_ptr), \
+            out_desc,                                                        \
             {cond_desc, x_desc, y_desc})
 
     switch (handle->device) {
@@ -39,8 +49,8 @@ __C infiniStatus_t infiniopGetWhereWorkspaceSize(
     infiniopWhereDescriptor_t desc,
     size_t *size) {
 
-#define GET(CASE, NAMESPACE)                                                                  \
-    case CASE:                                                                                \
+#define GET(CASE, NAMESPACE)                                                                 \
+    case CASE:                                                                               \
         *size = reinterpret_cast<op::where::NAMESPACE::Descriptor *>(desc)->workspaceSize(); \
         return INFINI_STATUS_SUCCESS;
 
@@ -65,9 +75,9 @@ __C infiniStatus_t infiniopWhere(
     const void *y,
     void *stream) {
 
-#define CALCULATE(CASE, NAMESPACE)                                             \
-    case CASE:                                                                 \
-        return reinterpret_cast<op::where::NAMESPACE::Descriptor *>(desc)      \
+#define CALCULATE(CASE, NAMESPACE)                                            \
+    case CASE:                                                                \
+        return reinterpret_cast<op::where::NAMESPACE::Descriptor *>(desc)     \
             ->calculate(workspace, workspace_size, out, {cond, x, y}, stream)
 
     switch (desc->device_type) {
@@ -85,9 +95,9 @@ __C infiniStatus_t infiniopWhere(
 
 __C infiniStatus_t infiniopDestroyWhereDescriptor(infiniopWhereDescriptor_t desc) {
 
-#define DELETE(CASE, NAMESPACE)                                              \
-    case CASE:                                                               \
-        delete reinterpret_cast<op::where::NAMESPACE::Descriptor *>(desc);   \
+#define DELETE(CASE, NAMESPACE)                                             \
+    case CASE:                                                              \
+        delete reinterpret_cast<op::where::NAMESPACE::Descriptor *>(desc);  \
         return INFINI_STATUS_SUCCESS;
 
     switch (desc->device_type) {
@@ -103,4 +113,149 @@ __C infiniStatus_t infiniopDestroyWhereDescriptor(infiniopWhereDescriptor_t desc
 
 #undef DELETE
 }
 
+// where(cond) -> indices tuple
+__C infiniStatus_t infiniopCreateWhereIndicesDescriptor(
+    infiniopHandle_t handle,
+    infiniopWhereIndicesDescriptor_t *desc_ptr,
+    infiniopTensorDescriptor_t cond_desc) {
+
+#define CREATE_INDICES(CASE, NAMESPACE)                                              \
+    case CASE:                                                                       \
+        return op::where::NAMESPACE::IndicesDescriptor::create(                      \
+            handle,                                                                  \
+            reinterpret_cast<op::where::NAMESPACE::IndicesDescriptor **>(desc_ptr),  \
+            cond_desc)
+
+    switch (handle->device) {
+#ifdef ENABLE_CPU_API
+        CREATE_INDICES(INFINI_DEVICE_CPU, cpu);
+#endif
+#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API)
+        CREATE_INDICES(INFINI_DEVICE_NVIDIA, nvidia);
+#endif
+#ifdef ENABLE_ILUVATAR_API
+        CREATE_INDICES(INFINI_DEVICE_ILUVATAR, nvidia);
+#endif
+#ifdef ENABLE_QY_API
+        CREATE_INDICES(INFINI_DEVICE_QY, nvidia);
+#endif
+#ifdef ENABLE_METAX_API
+        CREATE_INDICES(INFINI_DEVICE_METAX, metax);
+#endif
+#ifdef ENABLE_MOORE_API
+        CREATE_INDICES(INFINI_DEVICE_MOORE, moore);
+#endif
+    default:
+        return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
+    }
+
+#undef CREATE_INDICES
+}
+
+__C infiniStatus_t infiniopGetWhereIndicesWorkspaceSize(
+    infiniopWhereIndicesDescriptor_t desc,
+    size_t *size) {
+
+#define GET_INDICES(CASE, NAMESPACE)                                                                \
+    case CASE:                                                                                      \
+        *size = reinterpret_cast<op::where::NAMESPACE::IndicesDescriptor *>(desc)->workspaceSize(); \
+        return INFINI_STATUS_SUCCESS;
+
+    switch (desc->device_type) {
+#ifdef ENABLE_CPU_API
+        GET_INDICES(INFINI_DEVICE_CPU, cpu)
+#endif
+#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API)
+        GET_INDICES(INFINI_DEVICE_NVIDIA, nvidia)
+#endif
+#ifdef ENABLE_ILUVATAR_API
+        GET_INDICES(INFINI_DEVICE_ILUVATAR, nvidia)
+#endif
+#ifdef ENABLE_QY_API
+        GET_INDICES(INFINI_DEVICE_QY, nvidia)
+#endif
+#ifdef ENABLE_METAX_API
+        GET_INDICES(INFINI_DEVICE_METAX, metax)
+#endif
+#ifdef ENABLE_MOORE_API
+        GET_INDICES(INFINI_DEVICE_MOORE, moore)
+#endif
+    default:
+        return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
+    }
+#undef GET_INDICES
+    return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
+}
+
+__C infiniStatus_t infiniopWhereIndices(
+    infiniopWhereIndicesDescriptor_t desc,
+    void *workspace,
+    size_t workspace_size,
+    void **outputs,
+    const void *cond,
+    void *stream,
+    size_t *num_true) {
+
+#define CALCULATE_INDICES(CASE, NAMESPACE)                                        \
+    case CASE:                                                                    \
+        return reinterpret_cast<op::where::NAMESPACE::IndicesDescriptor *>(desc)  \
+            ->calculate(workspace, workspace_size, outputs, cond, stream, num_true)
+
+    switch (desc->device_type) {
+#ifdef ENABLE_CPU_API
+        CALCULATE_INDICES(INFINI_DEVICE_CPU, cpu);
+#endif
+#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API)
+        CALCULATE_INDICES(INFINI_DEVICE_NVIDIA, nvidia);
+#endif
+#ifdef ENABLE_ILUVATAR_API
+        CALCULATE_INDICES(INFINI_DEVICE_ILUVATAR, nvidia);
+#endif
+#ifdef ENABLE_QY_API
+        CALCULATE_INDICES(INFINI_DEVICE_QY, nvidia);
+#endif
+#ifdef ENABLE_METAX_API
+        CALCULATE_INDICES(INFINI_DEVICE_METAX, metax);
+#endif
+#ifdef ENABLE_MOORE_API
+        CALCULATE_INDICES(INFINI_DEVICE_MOORE, moore);
+#endif
+    default:
+        return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
+    }
+
+#undef CALCULATE_INDICES
+}
+
+__C infiniStatus_t infiniopDestroyWhereIndicesDescriptor(infiniopWhereIndicesDescriptor_t desc) {
+
+#define DELETE_INDICES(CASE, NAMESPACE)                                             \
+    case CASE:                                                                      \
+        delete reinterpret_cast<op::where::NAMESPACE::IndicesDescriptor *>(desc);   \
+        return INFINI_STATUS_SUCCESS;
+    switch (desc->device_type) {
+#ifdef ENABLE_CPU_API
+        DELETE_INDICES(INFINI_DEVICE_CPU, cpu);
+#endif
+#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API)
+        DELETE_INDICES(INFINI_DEVICE_NVIDIA, nvidia);
+#endif
+#ifdef ENABLE_ILUVATAR_API
+        DELETE_INDICES(INFINI_DEVICE_ILUVATAR, nvidia);
+#endif
+#ifdef ENABLE_QY_API
+        DELETE_INDICES(INFINI_DEVICE_QY, nvidia);
+#endif
+#ifdef ENABLE_METAX_API
+        DELETE_INDICES(INFINI_DEVICE_METAX, metax);
+#endif
+#ifdef ENABLE_MOORE_API
+        DELETE_INDICES(INFINI_DEVICE_MOORE, moore);
+#endif
+    default:
+        return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
+    }
+
+#undef DELETE_INDICES
+}
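[Editor's note] For orientation, a hedged sketch of how a caller might drive the four new entry points end to end. Only the infiniopWhereIndices* functions are defined by this patch; the CHECK macro, the deviceAlloc helper, and the handle/cond_desc/cond_data/stream variables are illustrative assumptions:

#include <cstdint>
#include <vector>

// Sketch: run where(cond) on an ndim-dimensional bool tensor.
infiniopWhereIndicesDescriptor_t desc = nullptr;
CHECK(infiniopCreateWhereIndicesDescriptor(handle, &desc, cond_desc));

size_t workspace_size = 0;
CHECK(infiniopGetWhereIndicesWorkspaceSize(desc, &workspace_size));
void *workspace = deviceAlloc(workspace_size); // hypothetical allocator

// One int64 index buffer per dimension, sized for the worst case (numel entries).
std::vector<void *> outputs(ndim);
for (int i = 0; i < ndim; ++i) {
    outputs[i] = deviceAlloc(numel * sizeof(int64_t));
}

size_t num_true = 0;
CHECK(infiniopWhereIndices(desc, workspace, workspace_size,
                           outputs.data(), cond_data, stream, &num_true));
// Only the first num_true entries of each outputs[i] are meaningful.

CHECK(infiniopDestroyWhereIndicesDescriptor(desc));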
From 99eb9434890d17fc02ed0c50c172d1f578641b93 Mon Sep 17 00:00:00 2001
From: guozhihao-224
Date: Tue, 16 Dec 2025 21:56:46 +0800
Subject: [PATCH 07/17] fix: remove the cub dependency
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 src/infiniop/ops/vdot/cuda/kernel.cuh     | 35 ++++++++++++++---------
 src/infiniop/ops/vdot/metax/vdot_metax.cu |  1 -
 src/infiniop/ops/vdot/moore/vdot_moore.cu |  1 -
 3 files changed, 21 insertions(+), 16 deletions(-)

diff --git a/src/infiniop/ops/vdot/cuda/kernel.cuh b/src/infiniop/ops/vdot/cuda/kernel.cuh
index 9b5342e92..34c233e02 100644
--- a/src/infiniop/ops/vdot/cuda/kernel.cuh
+++ b/src/infiniop/ops/vdot/cuda/kernel.cuh
@@ -1,7 +1,8 @@
 #ifndef __VDOT_CUDA_KERNEL_CUH__
 #define __VDOT_CUDA_KERNEL_CUH__
 
-#include <cub/block/block_reduce.cuh>
+#include <cuda_runtime.h>
+#include <cstddef>
 
 namespace op::vdot::cuda {
 
@@ -10,24 +11,30 @@ __global__ void vdotKernel(Tcompute *out, const Tdata *a, const Tdata *b,
                            size_t length, ptrdiff_t a_stride,
                            ptrdiff_t b_stride) {
 
-    Tcompute dot = 0;
-
-    // Each thread computes its partial dot product
+    // Each thread computes a partial dot product.
+    Tcompute local_sum = 0;
     for (size_t i = threadIdx.x; i < length; i += BLOCK_SIZE) {
-        Tcompute a_val = Tcompute(a[i * a_stride]);
-        Tcompute b_val = Tcompute(b[i * b_stride]);
-        dot += a_val * b_val;
+        Tcompute a_val = static_cast<Tcompute>(a[i * a_stride]);
+        Tcompute b_val = static_cast<Tcompute>(b[i * b_stride]);
+        local_sum += a_val * b_val;
     }
 
-    // Use CUB block-level reduction
-    using BlockReduce = cub::BlockReduce<Tcompute, BLOCK_SIZE>;
-    __shared__ typename BlockReduce::TempStorage temp_storage;
-
-    Tcompute block_dot = BlockReduce(temp_storage).Sum(dot);
+    // Block-level reduction in shared memory (no CUB dependency).
+    __shared__ Tcompute sdata[BLOCK_SIZE];
+    sdata[threadIdx.x] = local_sum;
+    __syncthreads();
+
+    // Standard binary tree reduction.
+    for (unsigned int s = BLOCK_SIZE / 2; s > 0; s >>= 1) {
+        if (threadIdx.x < s) {
+            sdata[threadIdx.x] += sdata[threadIdx.x + s];
+        }
+        __syncthreads();
+    }
 
-    // Thread 0 writes the result
+    // Thread 0 writes the result.
     if (threadIdx.x == 0) {
-        *out = block_dot;
+        *out = sdata[0];
     }
 }
diff --git a/src/infiniop/ops/vdot/metax/vdot_metax.cu b/src/infiniop/ops/vdot/metax/vdot_metax.cu
index f75605a0a..ed8eea47a 100644
--- a/src/infiniop/ops/vdot/metax/vdot_metax.cu
+++ b/src/infiniop/ops/vdot/metax/vdot_metax.cu
@@ -1,7 +1,6 @@
 #include "../../../devices/metax/metax_handle.h"
 #include "../../../devices/nvidia/nvidia_kernel_common.cuh"
 #include "vdot_metax.h"
-#include <cub/block/block_reduce.cuh>
 #include <cuda_runtime.h>
 
 namespace op::vdot::metax {
diff --git a/src/infiniop/ops/vdot/moore/vdot_moore.cu b/src/infiniop/ops/vdot/moore/vdot_moore.cu
index 2c9a6e474..787be15f2 100644
--- a/src/infiniop/ops/vdot/moore/vdot_moore.cu
+++ b/src/infiniop/ops/vdot/moore/vdot_moore.cu
@@ -1,7 +1,6 @@
 #include "../../../devices/moore/moore_handle.h"
 #include "../../../devices/nvidia/nvidia_kernel_common.cuh"
 #include "vdot_moore.h"
-#include <cub/block/block_reduce.cuh>
 #include <cuda_runtime.h>
 
 namespace op::vdot::moore {

From 5785576d52dbc0270763fb2405979e462b3cb1b1 Mon Sep 17 00:00:00 2001
From: guozhihao-224
Date: Tue, 16 Dec 2025 22:19:01 +0800
Subject: [PATCH 08/17] fix: metax moore

---
 src/infiniop/ops/vdot/cuda/kernel.cuh          |  2 +-
 .../metax/{vdot_metax.cu => vdot_metax.maca}   | 53 ++++++++--------
 .../moore/{vdot_moore.cu => vdot_moore.mu}     | 60 +++++++++----------
 3 files changed, 55 insertions(+), 60 deletions(-)
 rename src/infiniop/ops/vdot/metax/{vdot_metax.cu => vdot_metax.maca} (68%)
 rename src/infiniop/ops/vdot/moore/{vdot_moore.cu => vdot_moore.mu} (64%)

diff --git a/src/infiniop/ops/vdot/cuda/kernel.cuh b/src/infiniop/ops/vdot/cuda/kernel.cuh
index 34c233e02..7e12ec090 100644
--- a/src/infiniop/ops/vdot/cuda/kernel.cuh
+++ b/src/infiniop/ops/vdot/cuda/kernel.cuh
@@ -1,8 +1,8 @@
 #ifndef __VDOT_CUDA_KERNEL_CUH__
 #define __VDOT_CUDA_KERNEL_CUH__
 
-#include <cuda_runtime.h>
 #include <cstddef>
+// Do not include cuda_runtime.h here; each platform's headers provide the required definitions.
 
 namespace op::vdot::cuda {
 
diff --git a/src/infiniop/ops/vdot/metax/vdot_metax.cu b/src/infiniop/ops/vdot/metax/vdot_metax.maca
similarity index 68%
rename from src/infiniop/ops/vdot/metax/vdot_metax.cu
rename to src/infiniop/ops/vdot/metax/vdot_metax.maca
index ed8eea47a..fb2087e5a 100644
--- a/src/infiniop/ops/vdot/metax/vdot_metax.cu
+++ b/src/infiniop/ops/vdot/metax/vdot_metax.maca
@@ -1,7 +1,7 @@
-#include "../../../devices/metax/metax_handle.h"
-#include "../../../devices/nvidia/nvidia_kernel_common.cuh"
+#include "../../../devices/metax/metax_common.h"
 #include "vdot_metax.h"
-#include <cuda_runtime.h>
+#include "../../../devices/metax/metax_kernel_common.h"
+#include "../cuda/kernel.cuh"
 
 namespace op::vdot::metax {
 
@@ -48,7 +48,7 @@ infiniStatus_t Descriptor::calculate(void *workspace, size_t workspace_size,
                                      void *out, const void *a, const void *b,
                                      void *stream) const {
 
-    auto cuda_stream = reinterpret_cast<cudaStream_t>(stream);
+    auto hc_stream = reinterpret_cast<hcStream_t>(stream);
 
     constexpr unsigned int BLOCK_SIZE = 256;
     switch (_in_dtype) {
@@ -57,9 +57,9 @@ infiniStatus_t Descriptor::calculate(void *workspace, size_t workspace_size,
         const float *a_f = reinterpret_cast<const float *>(a);
         const float *b_f = reinterpret_cast<const float *>(b);
         op::vdot::cuda::vdotKernel<float, float, BLOCK_SIZE>
-            <<<1, BLOCK_SIZE, 0, cuda_stream>>>(out_f, a_f, b_f, _length, _a_stride,
+            <<<1, BLOCK_SIZE, 0, hc_stream>>>(out_f, a_f, b_f, _length, _a_stride,
                                                 _b_stride);
-        CHECK_CUDA(cudaGetLastError());
+        CHECK_METAX(hcGetLastError());
         break;
     }
     case INFINI_DTYPE_F64: {
@@ -67,9 +67,9 @@ infiniStatus_t Descriptor::calculate(void *workspace, size_t workspace_size,
         const double *a_d = reinterpret_cast<const double *>(a);
         const double *b_d = reinterpret_cast<const double *>(b);
         op::vdot::cuda::vdotKernel<double, double, BLOCK_SIZE>
-            <<<1, BLOCK_SIZE, 0, cuda_stream>>>(out_d, a_d, b_d, _length, _a_stride,
+            <<<1, BLOCK_SIZE, 0, hc_stream>>>(out_d, a_d, b_d, _length, _a_stride,
                                                 _b_stride);
-        CHECK_CUDA(cudaGetLastError());
+        CHECK_METAX(hcGetLastError());
         break;
     }
     case INFINI_DTYPE_F16: {
@@ -82,16 +82,16 @@ infiniStatus_t Descriptor::calculate(void *workspace, size_t workspace_size,
             const __half *a_h = reinterpret_cast<const __half *>(a);
             const __half *b_h = reinterpret_cast<const __half *>(b);
             op::vdot::cuda::vdotKernel<__half, float, BLOCK_SIZE>
-                <<<1, BLOCK_SIZE, 0, cuda_stream>>>(tmp_out, a_h, b_h, _length,
+                <<<1, BLOCK_SIZE, 0, hc_stream>>>(tmp_out, a_h, b_h, _length,
                                                     _a_stride, _b_stride);
-            CHECK_CUDA(cudaGetLastError());
+            CHECK_METAX(hcGetLastError());
             float result_f;
-            CHECK_CUDA(cudaMemcpyAsync(&result_f, tmp_out, sizeof(float),
-                                       cudaMemcpyDeviceToHost, cuda_stream));
-            CHECK_CUDA(cudaStreamSynchronize(cuda_stream));
+            CHECK_METAX(hcMemcpyAsync(&result_f, tmp_out, sizeof(float),
+                                      hcMemcpyDeviceToHost, hc_stream));
+            CHECK_METAX(hcStreamSynchronize(hc_stream));
             __half h_result = __float2half(result_f);
-            CHECK_CUDA(cudaMemcpyAsync(out, &h_result, sizeof(__half),
-                                       cudaMemcpyHostToDevice, cuda_stream));
+            CHECK_METAX(hcMemcpyAsync(out, &h_result, sizeof(__half),
+                                      hcMemcpyHostToDevice, hc_stream));
         }
         break;
     }
@@ -102,19 +102,19 @@ infiniStatus_t Descriptor::calculate(void *workspace, size_t workspace_size,
         }
         float *tmp_out = reinterpret_cast<float *>(workspace);
         {
-            const __nv_bfloat16 *a_bf = reinterpret_cast<const __nv_bfloat16 *>(a);
-            const __nv_bfloat16 *b_bf = reinterpret_cast<const __nv_bfloat16 *>(b);
-            op::vdot::cuda::vdotKernel<__nv_bfloat16, float, BLOCK_SIZE>
-                <<<1, BLOCK_SIZE, 0, cuda_stream>>>(tmp_out, a_bf, b_bf, _length,
+            const __hpcc_bfloat16 *a_bf = reinterpret_cast<const __hpcc_bfloat16 *>(a);
+            const __hpcc_bfloat16 *b_bf = reinterpret_cast<const __hpcc_bfloat16 *>(b);
+            op::vdot::cuda::vdotKernel<__hpcc_bfloat16, float, BLOCK_SIZE>
+                <<<1, BLOCK_SIZE, 0, hc_stream>>>(tmp_out, a_bf, b_bf, _length,
                                                     _a_stride, _b_stride);
-            CHECK_CUDA(cudaGetLastError());
+            CHECK_METAX(hcGetLastError());
             float result_f;
-            CHECK_CUDA(cudaMemcpyAsync(&result_f, tmp_out, sizeof(float),
-                                       cudaMemcpyDeviceToHost, cuda_stream));
-            CHECK_CUDA(cudaStreamSynchronize(cuda_stream));
-            __nv_bfloat16 bf_result = __float2bfloat16(result_f);
-            CHECK_CUDA(cudaMemcpyAsync(out, &bf_result, sizeof(__nv_bfloat16),
-                                       cudaMemcpyHostToDevice, cuda_stream));
+            CHECK_METAX(hcMemcpyAsync(&result_f, tmp_out, sizeof(float),
+                                      hcMemcpyDeviceToHost, hc_stream));
+            CHECK_METAX(hcStreamSynchronize(hc_stream));
+            __hpcc_bfloat16 bf_result = __float2bfloat16(result_f);
+            CHECK_METAX(hcMemcpyAsync(out, &bf_result, sizeof(__hpcc_bfloat16),
+                                      hcMemcpyHostToDevice, hc_stream));
         }
         break;
     }
@@ -126,3 +126,4 @@ infiniStatus_t Descriptor::calculate(void *workspace, size_t workspace_size,
 }
 
 } // namespace op::vdot::metax
+
diff --git a/src/infiniop/ops/vdot/moore/vdot_moore.cu b/src/infiniop/ops/vdot/moore/vdot_moore.mu
similarity index 64%
rename from src/infiniop/ops/vdot/moore/vdot_moore.cu
rename to src/infiniop/ops/vdot/moore/vdot_moore.mu
index 787be15f2..74af7d31d 100644
--- a/src/infiniop/ops/vdot/moore/vdot_moore.cu
+++ b/src/infiniop/ops/vdot/moore/vdot_moore.mu
@@ -1,7 +1,7 @@
-#include "../../../devices/moore/moore_handle.h"
-#include "../../../devices/nvidia/nvidia_kernel_common.cuh"
+#include "../../../devices/moore/moore_common.h"
 #include "vdot_moore.h"
-#include <cuda_runtime.h>
+#include "../../../devices/moore/moore_kernel_common.h"
+#include "../cuda/kernel.cuh"
 
 namespace op::vdot::moore {
 
@@ -48,7 +48,7 @@ infiniStatus_t Descriptor::calculate(void *workspace, size_t workspace_size,
                                      void *out, const void *a, const void *b,
                                      void *stream) const {
 
-    auto cuda_stream = reinterpret_cast<cudaStream_t>(stream);
+    auto musa_stream = reinterpret_cast<musaStream_t>(stream);
 
     constexpr unsigned int BLOCK_SIZE = 256;
     switch (_in_dtype) {
@@ -57,9 +57,9 @@ infiniStatus_t Descriptor::calculate(void *workspace, size_t workspace_size,
         const float *a_f = reinterpret_cast<const float *>(a);
         const float *b_f = reinterpret_cast<const float *>(b);
         op::vdot::cuda::vdotKernel<float, float, BLOCK_SIZE>
-            <<<1, BLOCK_SIZE, 0, cuda_stream>>>(out_f, a_f, b_f, _length, _a_stride,
+            <<<1, BLOCK_SIZE, 0, musa_stream>>>(out_f, a_f, b_f, _length, _a_stride,
                                                 _b_stride);
-        CHECK_CUDA(cudaGetLastError());
+        CHECK_MOORE(musaGetLastError());
        break;
     }
     case INFINI_DTYPE_F64: {
@@ -67,38 +67,31 @@ infiniStatus_t Descriptor::calculate(void *workspace, size_t workspace_size,
         const double *a_d = reinterpret_cast<const double *>(a);
         const double *b_d = reinterpret_cast<const double *>(b);
         op::vdot::cuda::vdotKernel<double, double, BLOCK_SIZE>
-            <<<1, BLOCK_SIZE, 0, cuda_stream>>>(out_d, a_d, b_d, _length, _a_stride,
+            <<<1, BLOCK_SIZE, 0, musa_stream>>>(out_d, a_d, b_d, _length, _a_stride,
                                                 _b_stride);
-        CHECK_CUDA(cudaGetLastError());
+        CHECK_MOORE(musaGetLastError());
         break;
     }
     case INFINI_DTYPE_F16: {
         // For FP16, accumulate in float, then cast back to half
-        // Use workspace for temporary float buffer
         if (workspace_size < sizeof(float)) {
             return INFINI_STATUS_INSUFFICIENT_WORKSPACE;
         }
         float *tmp_out = reinterpret_cast<float *>(workspace);
         {
-            // If workspace is too small, we need to allocate
-            // For simplicity, use a device-side kernel that writes directly to out
-            // But we need float accumulation, so use a temporary approach
             const __half *a_h = reinterpret_cast<const __half *>(a);
             const __half *b_h = reinterpret_cast<const __half *>(b);
-            // Launch kernel that accumulates in float and writes half result
             op::vdot::cuda::vdotKernel<__half, float, BLOCK_SIZE>
-                <<<1, BLOCK_SIZE, 0, cuda_stream>>>(tmp_out, a_h, b_h, _length,
+                <<<1, BLOCK_SIZE, 0, musa_stream>>>(tmp_out, a_h, b_h, _length,
                                                     _a_stride, _b_stride);
-            CHECK_CUDA(cudaGetLastError());
-            // Use a simple device kernel to cast float to half
-            // For now, copy to host, cast, and copy back
+            CHECK_MOORE(musaGetLastError());
             float result_f;
-            CHECK_CUDA(cudaMemcpyAsync(&result_f, tmp_out, sizeof(float),
-                                       cudaMemcpyDeviceToHost, cuda_stream));
-            CHECK_CUDA(cudaStreamSynchronize(cuda_stream));
+            CHECK_MOORE(musaMemcpyAsync(&result_f, tmp_out, sizeof(float),
+                                        musaMemcpyDeviceToHost, musa_stream));
+            CHECK_MOORE(musaStreamSynchronize(musa_stream));
             __half h_result = __float2half(result_f);
-            CHECK_CUDA(cudaMemcpyAsync(out, &h_result, sizeof(__half),
-                                       cudaMemcpyHostToDevice, cuda_stream));
+            CHECK_MOORE(musaMemcpyAsync(out, &h_result, sizeof(__half),
+                                        musaMemcpyHostToDevice, musa_stream));
         }
         break;
     }
@@ -109,19 +102,19 @@ infiniStatus_t Descriptor::calculate(void *workspace, size_t workspace_size,
         }
         float *tmp_out = reinterpret_cast<float *>(workspace);
         {
-            const __nv_bfloat16 *a_bf = reinterpret_cast<const __nv_bfloat16 *>(a);
-            const __nv_bfloat16 *b_bf = reinterpret_cast<const __nv_bfloat16 *>(b);
-            op::vdot::cuda::vdotKernel<__nv_bfloat16, float, BLOCK_SIZE>
-                <<<1, BLOCK_SIZE, 0, cuda_stream>>>(tmp_out, a_bf, b_bf, _length,
+            const __mt_bfloat16 *a_bf = reinterpret_cast<const __mt_bfloat16 *>(a);
+            const __mt_bfloat16 *b_bf = reinterpret_cast<const __mt_bfloat16 *>(b);
+            op::vdot::cuda::vdotKernel<__mt_bfloat16, float, BLOCK_SIZE>
+                <<<1, BLOCK_SIZE, 0, musa_stream>>>(tmp_out, a_bf, b_bf, _length,
                                                     _a_stride, _b_stride);
-            CHECK_CUDA(cudaGetLastError());
+            CHECK_MOORE(musaGetLastError());
             float result_f;
-            CHECK_CUDA(cudaMemcpyAsync(&result_f, tmp_out, sizeof(float),
-                                       cudaMemcpyDeviceToHost, cuda_stream));
-            CHECK_CUDA(cudaStreamSynchronize(cuda_stream));
-            __nv_bfloat16 bf_result = __float2bfloat16(result_f);
-            CHECK_CUDA(cudaMemcpyAsync(out, &bf_result, sizeof(__nv_bfloat16),
-                                       cudaMemcpyHostToDevice, cuda_stream));
+            CHECK_MOORE(musaMemcpyAsync(&result_f, tmp_out, sizeof(float),
+                                        musaMemcpyDeviceToHost, musa_stream));
+            CHECK_MOORE(musaStreamSynchronize(musa_stream));
+            __mt_bfloat16 bf_result = __float2bfloat16(result_f);
+            CHECK_MOORE(musaMemcpyAsync(out, &bf_result, sizeof(__mt_bfloat16),
+                                        musaMemcpyHostToDevice, musa_stream));
         }
         break;
     }
@@ -133,3 +126,4 @@ infiniStatus_t Descriptor::calculate(void *workspace, size_t workspace_size,
 }
 
 } // namespace op::vdot::moore
+
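[Editor's note] The CUB-free reduction introduced in PATCH 07 and carried into the .maca/.mu ports above assumes a single-block launch and a power-of-two BLOCK_SIZE. A small host-side model of the invariant (a sketch under those assumptions, not part of the patch):

#include <cassert>
#include <vector>

// Host model of vdotKernel's tree reduction: after each halving step s,
// the prefix sdata[0..s) still sums to the original total.
float blockReduceModel(std::vector<float> sdata) {
    const size_t n = sdata.size();
    assert(n != 0 && (n & (n - 1)) == 0 && "BLOCK_SIZE must be a power of two");
    for (size_t s = n / 2; s > 0; s >>= 1) {
        for (size_t i = 0; i < s; ++i) {
            sdata[i] += sdata[i + s]; // mirrors one __syncthreads() round
        }
    }
    return sdata[0];
}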
From dc4b340a6586a95c163d3a085d3ea8d81a10266b Mon Sep 17 00:00:00 2001
From: guozhihao-224
Date: Tue, 16 Dec 2025 22:23:11 +0800
Subject: [PATCH 09/17] fix

---
 src/infiniop/ops/vdot/metax/vdot_metax.h | 1 -
 src/infiniop/ops/vdot/moore/vdot_moore.h | 1 -
 2 files changed, 2 deletions(-)

diff --git a/src/infiniop/ops/vdot/metax/vdot_metax.h b/src/infiniop/ops/vdot/metax/vdot_metax.h
index b5eed7093..bd589e79c 100644
--- a/src/infiniop/ops/vdot/metax/vdot_metax.h
+++ b/src/infiniop/ops/vdot/metax/vdot_metax.h
@@ -4,7 +4,6 @@
 #include "../../../devices/metax/metax_handle.h"
 #include "../../../operator.h"
 #include "../../../tensor.h"
-#include "../cuda/kernel.cuh"
 
 namespace op::vdot::metax {
 
diff --git a/src/infiniop/ops/vdot/moore/vdot_moore.h b/src/infiniop/ops/vdot/moore/vdot_moore.h
index 912326f22..6d30b76f3 100644
--- a/src/infiniop/ops/vdot/moore/vdot_moore.h
+++ b/src/infiniop/ops/vdot/moore/vdot_moore.h
@@ -4,7 +4,6 @@
 #include "../../../devices/moore/moore_handle.h"
 #include "../../../operator.h"
 #include "../../../tensor.h"
-#include "../cuda/kernel.cuh"
 
 namespace op::vdot::moore {
 

From 5cf4c87f384627314c08182f75a18e9a200f068a Mon Sep 17 00:00:00 2001
From: guozhihao-224
Date: Tue, 16 Dec 2025 22:41:30 +0800
Subject: [PATCH 10/17] fix: metax check error

---
 src/infiniop/ops/vdot/metax/vdot_metax.maca | 4 ----
 1 file changed, 4 deletions(-)

diff --git a/src/infiniop/ops/vdot/metax/vdot_metax.maca b/src/infiniop/ops/vdot/metax/vdot_metax.maca
index fb2087e5a..c914d5856 100644
--- a/src/infiniop/ops/vdot/metax/vdot_metax.maca
+++ b/src/infiniop/ops/vdot/metax/vdot_metax.maca
@@ -59,7 +59,6 @@ infiniStatus_t Descriptor::calculate(void *workspace, size_t workspace_size,
         op::vdot::cuda::vdotKernel<float, float, BLOCK_SIZE>
             <<<1, BLOCK_SIZE, 0, hc_stream>>>(out_f, a_f, b_f, _length, _a_stride,
                                               _b_stride);
-        CHECK_METAX(hcGetLastError());
         break;
     }
     case INFINI_DTYPE_F64: {
@@ -69,7 +68,6 @@ infiniStatus_t Descriptor::calculate(void *workspace, size_t workspace_size,
         op::vdot::cuda::vdotKernel<double, double, BLOCK_SIZE>
             <<<1, BLOCK_SIZE, 0, hc_stream>>>(out_d, a_d, b_d, _length, _a_stride,
                                               _b_stride);
-        CHECK_METAX(hcGetLastError());
         break;
     }
     case INFINI_DTYPE_F16: {
@@ -84,7 +82,6 @@ infiniStatus_t Descriptor::calculate(void *workspace, size_t workspace_size,
             op::vdot::cuda::vdotKernel<__half, float, BLOCK_SIZE>
                 <<<1, BLOCK_SIZE, 0, hc_stream>>>(tmp_out, a_h, b_h, _length,
                                                   _a_stride, _b_stride);
-            CHECK_METAX(hcGetLastError());
             float result_f;
             CHECK_METAX(hcMemcpyAsync(&result_f, tmp_out, sizeof(float),
                                       hcMemcpyDeviceToHost, hc_stream));
@@ -107,7 +104,6 @@ infiniStatus_t Descriptor::calculate(void *workspace, size_t workspace_size,
             op::vdot::cuda::vdotKernel<__hpcc_bfloat16, float, BLOCK_SIZE>
                 <<<1, BLOCK_SIZE, 0, hc_stream>>>(tmp_out, a_bf, b_bf, _length,
                                                   _a_stride, _b_stride);
-            CHECK_METAX(hcGetLastError());
             float result_f;
             CHECK_METAX(hcMemcpyAsync(&result_f, tmp_out, sizeof(float),
                                       hcMemcpyDeviceToHost, hc_stream));

From 8308a5ed4770e3cc84591869dc6edd35ebb52c5d Mon Sep 17 00:00:00 2001
From: guozhihao-224
Date: Tue, 16 Dec 2025 22:54:53 +0800
Subject: [PATCH 11/17] fix: moore

---
 src/infiniop/ops/vdot/moore/vdot_moore.mu | 4 ----
 1 file changed, 4 deletions(-)

diff --git a/src/infiniop/ops/vdot/moore/vdot_moore.mu b/src/infiniop/ops/vdot/moore/vdot_moore.mu
index 74af7d31d..91b4dee5c 100644
--- a/src/infiniop/ops/vdot/moore/vdot_moore.mu
+++ b/src/infiniop/ops/vdot/moore/vdot_moore.mu
@@ -59,7 +59,6 @@ infiniStatus_t Descriptor::calculate(void *workspace, size_t workspace_size,
         op::vdot::cuda::vdotKernel<float, float, BLOCK_SIZE>
             <<<1, BLOCK_SIZE, 0, musa_stream>>>(out_f, a_f, b_f, _length, _a_stride,
                                                 _b_stride);
-        CHECK_MOORE(musaGetLastError());
         break;
     }
     case INFINI_DTYPE_F64: {
@@ -69,7 +68,6 @@ infiniStatus_t Descriptor::calculate(void *workspace, size_t workspace_size,
         op::vdot::cuda::vdotKernel<double, double, BLOCK_SIZE>
             <<<1, BLOCK_SIZE, 0, musa_stream>>>(out_d, a_d, b_d, _length, _a_stride,
                                                 _b_stride);
-        CHECK_MOORE(musaGetLastError());
         break;
     }
     case INFINI_DTYPE_F16: {
@@ -84,7 +82,6 @@ infiniStatus_t Descriptor::calculate(void *workspace, size_t workspace_size,
             op::vdot::cuda::vdotKernel<__half, float, BLOCK_SIZE>
                 <<<1, BLOCK_SIZE, 0, musa_stream>>>(tmp_out, a_h, b_h, _length,
                                                     _a_stride, _b_stride);
-            CHECK_MOORE(musaGetLastError());
             float result_f;
             CHECK_MOORE(musaMemcpyAsync(&result_f, tmp_out, sizeof(float),
                                         musaMemcpyDeviceToHost, musa_stream));
@@ -107,7 +104,6 @@ infiniStatus_t Descriptor::calculate(void *workspace, size_t workspace_size,
             op::vdot::cuda::vdotKernel<__mt_bfloat16, float, BLOCK_SIZE>
                 <<<1, BLOCK_SIZE, 0, musa_stream>>>(tmp_out, a_bf, b_bf, _length,
                                                     _a_stride, _b_stride);
-            CHECK_MOORE(musaGetLastError());
             float result_f;
             CHECK_MOORE(musaMemcpyAsync(&result_f, tmp_out, sizeof(float),
                                         musaMemcpyDeviceToHost, musa_stream));

From 97be611c2b388ab33fa34d735781feab13339dcd Mon Sep 17 00:00:00 2001
From: guozhihao-224
Date: Tue, 16 Dec 2025 23:11:03 +0800
Subject: [PATCH 12/17] fix: where

---
 ...ices_metax.cu => where_indices_metax.maca} | 58 +++++++++----------
 ...ndices_moore.cu => where_indices_moore.mu} | 52 ++++++++---------
 2 files changed, 53 insertions(+), 57 deletions(-)
 rename src/infiniop/ops/where/metax/{where_indices_metax.cu => where_indices_metax.maca} (68%)
 rename src/infiniop/ops/where/moore/{where_indices_moore.cu => where_indices_moore.mu} (76%)

diff --git a/src/infiniop/ops/where/metax/where_indices_metax.cu b/src/infiniop/ops/where/metax/where_indices_metax.maca
similarity index 68%
rename from src/infiniop/ops/where/metax/where_indices_metax.cu
rename to src/infiniop/ops/where/metax/where_indices_metax.maca
index d205ae5f9..977c453b6 100644
--- a/src/infiniop/ops/where/metax/where_indices_metax.cu
+++ b/src/infiniop/ops/where/metax/where_indices_metax.maca
@@ -1,9 +1,8 @@
-#include "../../../devices/metax/metax_handle.h"
-#include "../../../devices/nvidia/nvidia_kernel_common.cuh"
-#include "../cuda/where_indices_kernel.cuh"
+#include "../../../devices/metax/metax_common.h"
 #include "where_indices_metax.h"
+#include "../../../devices/metax/metax_kernel_common.h"
+#include "../cuda/where_indices_kernel.cuh"
 #include <cub/cub.cuh>
-#include <vector>
 
 namespace op::where::metax {
 
@@ -27,7 +26,7 @@ size_t IndicesDescriptor::workspaceSize() const {
     // CUB scan workspace.
     size_t scan_workspace = 0;
     int64_t *dummy = nullptr;
-    CHECK_CUDA(inclusiveSum(nullptr, scan_workspace, dummy, n, nullptr));
+    CHECK_METAX(inclusiveSum(nullptr, scan_workspace, dummy, n, nullptr));
 
     return flags_size + scan_workspace;
 }
@@ -65,7 +64,7 @@ infiniStatus_t IndicesDescriptor::calculate(void *workspace,
                                             void *stream,
                                             size_t *num_true) const {
 
-    auto cuda_stream = reinterpret_cast<cudaStream_t>(stream);
+    auto hc_stream = reinterpret_cast<hcStream_t>(stream);
     const bool *cond_ptr = reinterpret_cast<const bool *>(cond);
 
     // Carve the buffers out of the workspace.
@@ -77,32 +76,31 @@ infiniStatus_t IndicesDescriptor::calculate(void *workspace,
     // Copy shape and strides to the device (used by markTrueElements).
     size_t *d_shape;
     ptrdiff_t *d_strides;
-    CHECK_CUDA(cudaMallocAsync(&d_shape, sizeof(size_t) * _ndim, cuda_stream));
-    CHECK_CUDA(
-        cudaMallocAsync(&d_strides, sizeof(ptrdiff_t) * _ndim, cuda_stream));
-    CHECK_CUDA(cudaMemcpyAsync(d_shape, _shape, sizeof(size_t) * _ndim,
-                               cudaMemcpyHostToDevice, cuda_stream));
-    CHECK_CUDA(cudaMemcpyAsync(d_strides, _strides, sizeof(ptrdiff_t) * _ndim,
-                               cudaMemcpyHostToDevice, cuda_stream));
+    CHECK_METAX(hcMallocAsync(&d_shape, sizeof(size_t) * _ndim, hc_stream));
+    CHECK_METAX(
+        hcMallocAsync(&d_strides, sizeof(ptrdiff_t) * _ndim, hc_stream));
+    CHECK_METAX(hcMemcpyAsync(d_shape, _shape, sizeof(size_t) * _ndim,
+                              hcMemcpyHostToDevice, hc_stream));
+    CHECK_METAX(hcMemcpyAsync(d_strides, _strides, sizeof(ptrdiff_t) * _ndim,
+                              hcMemcpyHostToDevice, hc_stream));
 
     // Stage 1: mark the true elements.
     constexpr int BLOCK_SIZE = 256;
     int grid_size = static_cast<int>((_numel + BLOCK_SIZE - 1) / BLOCK_SIZE);
     op::where::cuda::markTrueElements<int64_t>
-        <<<grid_size, BLOCK_SIZE, 0, cuda_stream>>>(flags, cond_ptr, d_shape,
+        <<<grid_size, BLOCK_SIZE, 0, hc_stream>>>(flags, cond_ptr, d_shape,
                                                     d_strides, _numel, _ndim);
-    CHECK_CUDA(cudaGetLastError());
 
     // Stage 2: inclusive prefix sum.
     size_t temp_workspace_size = scan_workspace_size;
-    CHECK_CUDA(inclusiveSum(scan_workspace, temp_workspace_size, flags,
-                            static_cast<int>(_numel), cuda_stream));
+    CHECK_METAX(inclusiveSum(scan_workspace, temp_workspace_size, flags,
+                             static_cast<int>(_numel), hc_stream));
 
     // Fetch the total number of true elements.
     int64_t num_true_val = 0;
-    CHECK_CUDA(cudaMemcpyAsync(&num_true_val, flags + _numel - 1, sizeof(int64_t),
-                               cudaMemcpyDeviceToHost, cuda_stream));
-    CHECK_CUDA(cudaStreamSynchronize(cuda_stream));
+    CHECK_METAX(hcMemcpyAsync(&num_true_val, flags + _numel - 1, sizeof(int64_t),
+                              hcMemcpyDeviceToHost, hc_stream));
+    CHECK_METAX(hcStreamSynchronize(hc_stream));
     *num_true = static_cast<size_t>(num_true_val);
 
     // Stage 3: gather the indices for each dimension.
@@ -115,25 +113,25 @@ infiniStatus_t IndicesDescriptor::calculate(void *workspace,
 
     // Copy output_ptrs to the device.
     int64_t **d_output_ptrs;
-    CHECK_CUDA(
-        cudaMallocAsync(&d_output_ptrs, sizeof(int64_t *) * _ndim, cuda_stream));
-    CHECK_CUDA(cudaMemcpyAsync(d_output_ptrs, output_ptrs,
-                               sizeof(int64_t *) * _ndim, cudaMemcpyHostToDevice,
-                               cuda_stream));
+    CHECK_METAX(
+        hcMallocAsync(&d_output_ptrs, sizeof(int64_t *) * _ndim, hc_stream));
+    CHECK_METAX(hcMemcpyAsync(d_output_ptrs, output_ptrs,
+                              sizeof(int64_t *) * _ndim, hcMemcpyHostToDevice,
+                              hc_stream));
 
     // Launch the index-gathering kernel.
     op::where::cuda::collectIndices<int64_t>
-        <<<grid_size, BLOCK_SIZE, 0, cuda_stream>>>(
+        <<<grid_size, BLOCK_SIZE, 0, hc_stream>>>(
             d_output_ptrs, flags, cond_ptr, d_shape, d_strides, _numel, _ndim);
-    CHECK_CUDA(cudaGetLastError());
 
     // Cleanup.
-    CHECK_CUDA(cudaFreeAsync(d_shape, cuda_stream));
-    CHECK_CUDA(cudaFreeAsync(d_strides, cuda_stream));
-    CHECK_CUDA(cudaFreeAsync(d_output_ptrs, cuda_stream));
+    CHECK_METAX(hcFreeAsync(d_shape, hc_stream));
+    CHECK_METAX(hcFreeAsync(d_strides, hc_stream));
+    CHECK_METAX(hcFreeAsync(d_output_ptrs, hc_stream));
     delete[] output_ptrs;
 
     return INFINI_STATUS_SUCCESS;
 }
 
 } // namespace op::where::metax
+
diff --git a/src/infiniop/ops/where/moore/where_indices_moore.cu b/src/infiniop/ops/where/moore/where_indices_moore.mu
similarity index 76%
rename from src/infiniop/ops/where/moore/where_indices_moore.cu
rename to src/infiniop/ops/where/moore/where_indices_moore.mu
index b35bb1b9e..969350012 100644
--- a/src/infiniop/ops/where/moore/where_indices_moore.cu
+++ b/src/infiniop/ops/where/moore/where_indices_moore.mu
@@ -1,9 +1,8 @@
-#include "../../../devices/moore/moore_handle.h"
-#include "../../../devices/nvidia/nvidia_kernel_common.cuh"
-#include "../cuda/where_indices_kernel.cuh"
+#include "../../../devices/moore/moore_common.h"
 #include "where_indices_moore.h"
+#include "../../../devices/moore/moore_kernel_common.h"
+#include "../cuda/where_indices_kernel.cuh"
 #include <cub/cub.cuh>
-#include <vector>
 
 namespace op::where::moore {
 
@@ -33,7 +32,7 @@ size_t IndicesDescriptor::workspaceSize() const {
     // CUB scan workspace.
     size_t scan_workspace = 0;
     int64_t *dummy = nullptr;
-    CHECK_CUDA(inclusiveSum(
+    CHECK_MOORE(inclusiveSum(
         nullptr, scan_workspace,
         dummy, n,
         nullptr));
@@ -78,7 +77,7 @@ infiniStatus_t IndicesDescriptor::calculate(
     void *stream,
     size_t *num_true) const {
 
-    auto cuda_stream = reinterpret_cast<cudaStream_t>(stream);
+    auto musa_stream = reinterpret_cast<musaStream_t>(stream);
     const bool *cond_ptr = reinterpret_cast<const bool *>(cond);
 
     // Carve the buffers out of the workspace.
@@ -90,35 +89,34 @@ infiniStatus_t IndicesDescriptor::calculate(
     // Copy shape and strides to the device (used by markTrueElements).
     size_t *d_shape;
     ptrdiff_t *d_strides;
-    CHECK_CUDA(cudaMallocAsync(&d_shape, sizeof(size_t) * _ndim, cuda_stream));
-    CHECK_CUDA(cudaMallocAsync(&d_strides, sizeof(ptrdiff_t) * _ndim, cuda_stream));
-    CHECK_CUDA(cudaMemcpyAsync(
+    CHECK_MOORE(musaMallocAsync(&d_shape, sizeof(size_t) * _ndim, musa_stream));
+    CHECK_MOORE(musaMallocAsync(&d_strides, sizeof(ptrdiff_t) * _ndim, musa_stream));
+    CHECK_MOORE(musaMemcpyAsync(
         d_shape, _shape, sizeof(size_t) * _ndim,
-        cudaMemcpyHostToDevice, cuda_stream));
-    CHECK_CUDA(cudaMemcpyAsync(
+        musaMemcpyHostToDevice, musa_stream));
+    CHECK_MOORE(musaMemcpyAsync(
         d_strides, _strides, sizeof(ptrdiff_t) * _ndim,
-        cudaMemcpyHostToDevice, cuda_stream));
+        musaMemcpyHostToDevice, musa_stream));
 
     // Stage 1: mark the true elements.
     constexpr int BLOCK_SIZE = 256;
     int grid_size = static_cast<int>((_numel + BLOCK_SIZE - 1) / BLOCK_SIZE);
-    op::where::cuda::markTrueElements<int64_t><<<grid_size, BLOCK_SIZE, 0, cuda_stream>>>(
+    op::where::cuda::markTrueElements<int64_t><<<grid_size, BLOCK_SIZE, 0, musa_stream>>>(
         flags, cond_ptr, d_shape, d_strides, _numel, _ndim);
-    CHECK_CUDA(cudaGetLastError());
 
     // Stage 2: inclusive prefix sum.
     size_t temp_workspace_size = scan_workspace_size;
-    CHECK_CUDA(inclusiveSum(
+    CHECK_MOORE(inclusiveSum(
         scan_workspace, temp_workspace_size,
         flags, static_cast<int>(_numel),
-        cuda_stream));
+        musa_stream));
 
     // Fetch the total number of true elements.
     int64_t num_true_val = 0;
-    CHECK_CUDA(cudaMemcpyAsync(
+    CHECK_MOORE(musaMemcpyAsync(
         &num_true_val, flags + _numel - 1, sizeof(int64_t),
-        cudaMemcpyDeviceToHost, cuda_stream));
-    CHECK_CUDA(cudaStreamSynchronize(cuda_stream));
+        musaMemcpyDeviceToHost, musa_stream));
+    CHECK_MOORE(musaStreamSynchronize(musa_stream));
     *num_true = static_cast<size_t>(num_true_val);
 
     // Stage 3: gather the indices for each dimension.
@@ -131,23 +129,23 @@ infiniStatus_t IndicesDescriptor::calculate(
 
     // Copy output_ptrs to the device.
     int64_t **d_output_ptrs;
-    CHECK_CUDA(cudaMallocAsync(&d_output_ptrs, sizeof(int64_t *) * _ndim, cuda_stream));
-    CHECK_CUDA(cudaMemcpyAsync(
+    CHECK_MOORE(musaMallocAsync(&d_output_ptrs, sizeof(int64_t *) * _ndim, musa_stream));
+    CHECK_MOORE(musaMemcpyAsync(
         d_output_ptrs, output_ptrs, sizeof(int64_t *) * _ndim,
-        cudaMemcpyHostToDevice, cuda_stream));
+        musaMemcpyHostToDevice, musa_stream));
 
     // Launch the index-gathering kernel.
-    op::where::cuda::collectIndices<int64_t><<<grid_size, BLOCK_SIZE, 0, cuda_stream>>>(
+    op::where::cuda::collectIndices<int64_t><<<grid_size, BLOCK_SIZE, 0, musa_stream>>>(
         d_output_ptrs, flags, cond_ptr, d_shape, d_strides, _numel, _ndim);
-    CHECK_CUDA(cudaGetLastError());
 
     // Cleanup.
-    CHECK_CUDA(cudaFreeAsync(d_shape, cuda_stream));
-    CHECK_CUDA(cudaFreeAsync(d_strides, cuda_stream));
-    CHECK_CUDA(cudaFreeAsync(d_output_ptrs, cuda_stream));
+    CHECK_MOORE(musaFreeAsync(d_shape, musa_stream));
+    CHECK_MOORE(musaFreeAsync(d_strides, musa_stream));
+    CHECK_MOORE(musaFreeAsync(d_output_ptrs, musa_stream));
     delete[] output_ptrs;
 
     return INFINI_STATUS_SUCCESS;
 }
 
 } // namespace op::where::moore
+
From 36bed3e7f7f8534a4281928e5082ac2ef8310fd4 Mon Sep 17 00:00:00 2001
From: guozhihao-224
Date: Tue, 16 Dec 2025 23:15:00 +0800
Subject: [PATCH 13/17] refactor: remove cuda_runtime.h inclusion for
 platform-specific header definitions

---
 src/infiniop/ops/where/cuda/where_indices_kernel.cuh | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/infiniop/ops/where/cuda/where_indices_kernel.cuh b/src/infiniop/ops/where/cuda/where_indices_kernel.cuh
index f00763153..834aeab80 100644
--- a/src/infiniop/ops/where/cuda/where_indices_kernel.cuh
+++ b/src/infiniop/ops/where/cuda/where_indices_kernel.cuh
@@ -2,7 +2,7 @@
 #define __WHERE_INDICES_KERNEL_CUH__
 
 #include <cstddef>
-#include <cuda_runtime.h>
+// Do not include cuda_runtime.h here; each platform's headers provide the required definitions.
 
 namespace op::where::cuda {
 

From b3eaba0799b39fcefc10867db091ee2aeb3b4af3 Mon Sep 17 00:00:00 2001
From: guozhihao-224
Date: Tue, 16 Dec 2025 23:21:23 +0800
Subject: [PATCH 14/17] refactor: replace async memory allocation with
 synchronous calls in where_indices_moore

---
 src/infiniop/ops/where/moore/where_indices_moore.mu | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/src/infiniop/ops/where/moore/where_indices_moore.mu b/src/infiniop/ops/where/moore/where_indices_moore.mu
index 969350012..3058c712b 100644
--- a/src/infiniop/ops/where/moore/where_indices_moore.mu
+++ b/src/infiniop/ops/where/moore/where_indices_moore.mu
@@ -89,8 +89,8 @@ infiniStatus_t IndicesDescriptor::calculate(
     // Copy shape and strides to the device (used by markTrueElements).
     size_t *d_shape;
     ptrdiff_t *d_strides;
-    CHECK_MOORE(musaMallocAsync(&d_shape, sizeof(size_t) * _ndim, musa_stream));
-    CHECK_MOORE(musaMallocAsync(&d_strides, sizeof(ptrdiff_t) * _ndim, musa_stream));
+    CHECK_MOORE(musaMalloc(&d_shape, sizeof(size_t) * _ndim));
+    CHECK_MOORE(musaMalloc(&d_strides, sizeof(ptrdiff_t) * _ndim));
     CHECK_MOORE(musaMemcpyAsync(
         d_shape, _shape, sizeof(size_t) * _ndim,
         musaMemcpyHostToDevice, musa_stream));
@@ -129,7 +129,7 @@ infiniStatus_t IndicesDescriptor::calculate(
 
     // Copy output_ptrs to the device.
     int64_t **d_output_ptrs;
-    CHECK_MOORE(musaMallocAsync(&d_output_ptrs, sizeof(int64_t *) * _ndim, musa_stream));
+    CHECK_MOORE(musaMalloc(&d_output_ptrs, sizeof(int64_t *) * _ndim));
     CHECK_MOORE(musaMemcpyAsync(
         d_output_ptrs, output_ptrs, sizeof(int64_t *) * _ndim,
         musaMemcpyHostToDevice, musa_stream));
@@ -139,9 +139,9 @@ infiniStatus_t IndicesDescriptor::calculate(
         d_output_ptrs, flags, cond_ptr, d_shape, d_strides, _numel, _ndim);
 
     // Cleanup.
-    CHECK_MOORE(musaFreeAsync(d_shape, musa_stream));
-    CHECK_MOORE(musaFreeAsync(d_strides, musa_stream));
-    CHECK_MOORE(musaFreeAsync(d_output_ptrs, musa_stream));
+    CHECK_MOORE(musaFree(d_shape));
+    CHECK_MOORE(musaFree(d_strides));
+    CHECK_MOORE(musaFree(d_output_ptrs));
     delete[] output_ptrs;
 
     return INFINI_STATUS_SUCCESS;
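[Editor's note] PATCH 14 trades the stream-ordered allocator for synchronous musaMalloc/musaFree, presumably because the MUSA toolchain in use does not reliably support the *Async allocator; the cost is an implicit device synchronization per call. A possible middle ground (a sketch only; the offsets and the enlarged workspaceSize() are assumptions, not what the patch does) is to stage the small shape/stride/pointer arrays inside the existing workspace so no per-call allocation is needed at all:

// Sketch: carve the staging buffers out of the user-supplied workspace
// instead of allocating per call. workspaceSize() would have to grow by
// align256(sizeof(size_t) * ndim) + align256(sizeof(ptrdiff_t) * ndim)
// + align256(sizeof(int64_t *) * ndim) for this layout to be valid, and the
// 256-byte alignment keeps each sub-buffer suitably aligned.
char *base = static_cast<char *>(workspace) + flags_size + scan_workspace_size;
size_t *d_shape = reinterpret_cast<size_t *>(base);
ptrdiff_t *d_strides = reinterpret_cast<ptrdiff_t *>(
    base + align256(sizeof(size_t) * ndim));
int64_t **d_output_ptrs = reinterpret_cast<int64_t **>(
    base + align256(sizeof(size_t) * ndim) + align256(sizeof(ptrdiff_t) * ndim));
// The three musaMemcpyAsync calls stay; the musaMalloc/musaFree pairs go away.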
From 12c4197194671c6b4380249a8fb3684361ea7252 Mon Sep 17 00:00:00 2001
From: guozhihao-224
Date: Tue, 16 Dec 2025 23:32:21 +0800
Subject: [PATCH 15/17] refactor: update memory allocation to use void
 pointers in where_indices_metax

---
 .../ops/where/metax/where_indices_metax.maca | 18 ++++++++++++------
 1 file changed, 12 insertions(+), 6 deletions(-)

diff --git a/src/infiniop/ops/where/metax/where_indices_metax.maca b/src/infiniop/ops/where/metax/where_indices_metax.maca
index 977c453b6..0a22b7b25 100644
--- a/src/infiniop/ops/where/metax/where_indices_metax.maca
+++ b/src/infiniop/ops/where/metax/where_indices_metax.maca
@@ -76,9 +76,13 @@ infiniStatus_t IndicesDescriptor::calculate(void *workspace,
     // Copy shape and strides to the device (used by markTrueElements)
     size_t *d_shape;
     ptrdiff_t *d_strides;
-    CHECK_METAX(hcMallocAsync(&d_shape, sizeof(size_t) * _ndim, hc_stream));
+    void *d_shape_void;
+    void *d_strides_void;
+    CHECK_METAX(hcMallocAsync(&d_shape_void, sizeof(size_t) * _ndim, hc_stream));
     CHECK_METAX(
-        hcMallocAsync(&d_strides, sizeof(ptrdiff_t) * _ndim, hc_stream));
+        hcMallocAsync(&d_strides_void, sizeof(ptrdiff_t) * _ndim, hc_stream));
+    d_shape = reinterpret_cast<size_t *>(d_shape_void);
+    d_strides = reinterpret_cast<ptrdiff_t *>(d_strides_void);
     CHECK_METAX(hcMemcpyAsync(d_shape, _shape, sizeof(size_t) * _ndim,
                               hcMemcpyHostToDevice, hc_stream));
     CHECK_METAX(hcMemcpyAsync(d_strides, _strides, sizeof(ptrdiff_t) * _ndim,
@@ -113,8 +117,10 @@ infiniStatus_t IndicesDescriptor::calculate(void *workspace,
 
     // Copy output_ptrs to the device
     int64_t **d_output_ptrs;
+    void *d_output_ptrs_void;
     CHECK_METAX(
-        hcMallocAsync(&d_output_ptrs, sizeof(int64_t *) * _ndim, hc_stream));
+        hcMallocAsync(&d_output_ptrs_void, sizeof(int64_t *) * _ndim, hc_stream));
+    d_output_ptrs = reinterpret_cast<int64_t **>(d_output_ptrs_void);
     CHECK_METAX(hcMemcpyAsync(d_output_ptrs, output_ptrs,
                               sizeof(int64_t *) * _ndim,
                               hcMemcpyHostToDevice, hc_stream));
@@ -125,9 +131,9 @@ infiniStatus_t IndicesDescriptor::calculate(void *workspace,
         d_output_ptrs, flags, cond_ptr, d_shape, d_strides, _numel, _ndim);
 
     // Cleanup
-    CHECK_METAX(hcFreeAsync(d_shape, hc_stream));
-    CHECK_METAX(hcFreeAsync(d_strides, hc_stream));
-    CHECK_METAX(hcFreeAsync(d_output_ptrs, hc_stream));
+    CHECK_METAX(hcFreeAsync(d_shape_void, hc_stream));
+    CHECK_METAX(hcFreeAsync(d_strides_void, hc_stream));
+    CHECK_METAX(hcFreeAsync(d_output_ptrs_void, hc_stream));
     delete[] output_ptrs;
 
     return INFINI_STATUS_SUCCESS;
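Note on patch 15 above: the temporary void pointers exist because hcMallocAsync evidently takes a plain void ** rather than the templated overload CUDA's runtime provides, so typed pointers such as size_t * cannot be passed directly. A hypothetical typed wrapper can hide the cast; everything here except hcMallocAsync itself is illustrative, and the stream type is left generic so no Metax type name has to be assumed:

    // Allocates count elements of T asynchronously and hands back a typed
    // pointer, wrapping the void*-then-reinterpret_cast pattern used above.
    template <typename T, typename Stream>
    static inline auto hcMallocAsyncTyped(T **out, size_t count, Stream stream) {
        void *raw = nullptr;
        auto status = hcMallocAsync(&raw, sizeof(T) * count, stream);
        *out = static_cast<T *>(raw);
        return status;
    }

    // Usage, replacing the explicit d_shape_void dance:
    //     CHECK_METAX(hcMallocAsyncTyped(&d_shape, _ndim, hc_stream));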
From 25e56a5e3b9a3f657f0801fca40be1c116f8761d Mon Sep 17 00:00:00 2001
From: guozhihao-224
Date: Wed, 17 Dec 2025 00:08:19 +0800
Subject: [PATCH 16/17] refactor: add stream synchronization after kernel
 execution in where_indices implementations

---
 src/infiniop/ops/where/metax/where_indices_metax.maca | 3 +++
 src/infiniop/ops/where/moore/where_indices_moore.mu   | 3 +++
 src/infiniop/ops/where/nvidia/where_indices_nvidia.cu | 3 +++
 src/infiniop/ops/where/operator.cc                    | 5 ++++-
 4 files changed, 13 insertions(+), 1 deletion(-)

diff --git a/src/infiniop/ops/where/metax/where_indices_metax.maca b/src/infiniop/ops/where/metax/where_indices_metax.maca
index 0a22b7b25..01a43b10f 100644
--- a/src/infiniop/ops/where/metax/where_indices_metax.maca
+++ b/src/infiniop/ops/where/metax/where_indices_metax.maca
@@ -130,6 +130,9 @@ infiniStatus_t IndicesDescriptor::calculate(void *workspace,
         <<<grid_size, BLOCK_SIZE, 0, hc_stream>>>(
             d_output_ptrs, flags, cond_ptr, d_shape, d_strides, _numel, _ndim);
 
+    // Synchronize the stream to ensure all asynchronous work (including the collectIndices kernel) has completed
+    CHECK_METAX(hcStreamSynchronize(hc_stream));
+
     // Cleanup
     CHECK_METAX(hcFreeAsync(d_shape_void, hc_stream));
     CHECK_METAX(hcFreeAsync(d_strides_void, hc_stream));

diff --git a/src/infiniop/ops/where/moore/where_indices_moore.mu b/src/infiniop/ops/where/moore/where_indices_moore.mu
index 3058c712b..ab7d5b59d 100644
--- a/src/infiniop/ops/where/moore/where_indices_moore.mu
+++ b/src/infiniop/ops/where/moore/where_indices_moore.mu
@@ -138,6 +138,9 @@ infiniStatus_t IndicesDescriptor::calculate(
     op::where::cuda::collectIndices<<<grid_size, BLOCK_SIZE, 0, musa_stream>>>(
         d_output_ptrs, flags, cond_ptr, d_shape, d_strides, _numel, _ndim);
 
+    // Synchronize the stream to ensure all asynchronous work (including the collectIndices kernel) has completed
+    CHECK_MOORE(musaStreamSynchronize(musa_stream));
+
     // Cleanup
     CHECK_MOORE(musaFree(d_shape));
     CHECK_MOORE(musaFree(d_strides));

diff --git a/src/infiniop/ops/where/nvidia/where_indices_nvidia.cu b/src/infiniop/ops/where/nvidia/where_indices_nvidia.cu
index eaa410f94..ea4aeee09 100644
--- a/src/infiniop/ops/where/nvidia/where_indices_nvidia.cu
+++ b/src/infiniop/ops/where/nvidia/where_indices_nvidia.cu
@@ -144,6 +144,9 @@ infiniStatus_t IndicesDescriptor::calculate(
         d_output_ptrs, flags, cond_ptr, d_shape, d_strides, _numel, _ndim);
     CHECK_CUDA(cudaGetLastError());
 
+    // Synchronize the stream to ensure all asynchronous work (including the collectIndices kernel) has completed
+    CHECK_CUDA(cudaStreamSynchronize(cuda_stream));
+
     // Cleanup
     CHECK_CUDA(cudaFreeAsync(d_shape, cuda_stream));
     CHECK_CUDA(cudaFreeAsync(d_strides, cuda_stream));

diff --git a/src/infiniop/ops/where/operator.cc b/src/infiniop/ops/where/operator.cc
index da309a9f8..e13dbd35d 100644
--- a/src/infiniop/ops/where/operator.cc
+++ b/src/infiniop/ops/where/operator.cc
@@ -231,7 +231,10 @@ __C infiniStatus_t infiniopDestroyWhereIndicesDescriptor(infiniopWhereIndicesDes
 
 #define DELETE_INDICES(CASE, NAMESPACE)                                     \
     case CASE:                                                              \
-        delete reinterpret_cast<NAMESPACE::IndicesDescriptor *>(desc);      \
+        if (desc != nullptr) {                                              \
+            delete reinterpret_cast<NAMESPACE::IndicesDescriptor *>(        \
+                const_cast<void *>(reinterpret_cast<const void *>(desc)));  \
+        }                                                                   \
         return INFINI_STATUS_SUCCESS;
 
     switch (desc->device_type) {

From efbaabebc3679aa153d67296b4ba62a73fb9558d Mon Sep 17 00:00:00 2001
From: guozhihao-224
Date: Wed, 17 Dec 2025 00:12:46 +0800
Subject: [PATCH 17/17] refactor: comment out cache management in
 BenchmarkUtils for clarity

---
 test/infinicore/framework/benchmark.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/test/infinicore/framework/benchmark.py b/test/infinicore/framework/benchmark.py
index 74466feea..c4b105b93 100644
--- a/test/infinicore/framework/benchmark.py
+++ b/test/infinicore/framework/benchmark.py
@@ -117,10 +117,10 @@ def _clear_cache():
     if infinicore.use_ntops:
         import triton
 
-        cache = triton.runtime.driver.active.get_empty_cache_for_benchmark()
+        # cache = triton.runtime.driver.active.get_empty_cache_for_benchmark()
 
-        def _clear_cache():
-            triton.runtime.driver.active.clear_cache(cache)
+        # def _clear_cache():
+        #     triton.runtime.driver.active.clear_cache(cache)
 
     # Create pairs of DeviceEvents for each iteration
     start_events = [