From 0b4e6d7978002f855bf1b8691c9d2d0d35c5f75c Mon Sep 17 00:00:00 2001 From: Kinorw Date: Wed, 26 Nov 2025 19:33:45 +0800 Subject: [PATCH 01/32] feat: bilinear,but sem out --- include/infiniop.h | 1 + include/infiniop/ops/bilinear.h | 35 ++++ src/infiniop/ops/bilinear/operator.cc | 261 ++++++++++++++++++++++++++ test/infiniop/bilinear.py | 185 ++++++++++++++++++ 4 files changed, 482 insertions(+) create mode 100644 include/infiniop/ops/bilinear.h create mode 100644 src/infiniop/ops/bilinear/operator.cc create mode 100644 test/infiniop/bilinear.py diff --git a/include/infiniop.h b/include/infiniop.h index 92e6f5963..5b6ff7c55 100644 --- a/include/infiniop.h +++ b/include/infiniop.h @@ -4,6 +4,7 @@ #include "infiniop/handle.h" #include "infiniop/ops/add.h" #include "infiniop/ops/attention.h" +#include "infiniop/ops/bilinear.h" #include "infiniop/ops/causal_softmax.h" #include "infiniop/ops/clip.h" #include "infiniop/ops/conv.h" diff --git a/include/infiniop/ops/bilinear.h b/include/infiniop/ops/bilinear.h new file mode 100644 index 000000000..ea4a0555f --- /dev/null +++ b/include/infiniop/ops/bilinear.h @@ -0,0 +1,35 @@ +#ifndef __INFINIOP_BILINEAR_API_H__ +#define __INFINIOP_BILINEAR_API_H__ + +#include "../operator_descriptor.h" + +typedef struct InfiniopDescriptor *infiniopBilinearDescriptor_t; + +__C __export infiniStatus_t infiniopCreateBilinearDescriptor( + infiniopHandle_t handle, + infiniopBilinearDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t out_desc, + infiniopTensorDescriptor_t x1_desc, + infiniopTensorDescriptor_t x2_desc, + infiniopTensorDescriptor_t weight_desc, + infiniopTensorDescriptor_t bias_desc); // bias 可以为 nullptr + +__C __export infiniStatus_t infiniopGetBilinearWorkspaceSize( + infiniopBilinearDescriptor_t desc, + size_t *size); + +__C __export infiniStatus_t infiniopBilinear( + infiniopBilinearDescriptor_t desc, + void *workspace, + size_t workspace_size, + void *out, + void const *x1, + void const *x2, + void const *weight, + void const *bias, + void *stream); + +__C __export infiniStatus_t infiniopDestroyBilinearDescriptor( + infiniopBilinearDescriptor_t desc); + +#endif \ No newline at end of file diff --git a/src/infiniop/ops/bilinear/operator.cc b/src/infiniop/ops/bilinear/operator.cc new file mode 100644 index 000000000..0edacad25 --- /dev/null +++ b/src/infiniop/ops/bilinear/operator.cc @@ -0,0 +1,261 @@ +#include "../../operator.h" +#include "../../../utils.h" +#include "../../../utils/check.h" +#include "../../handle.h" +#include "../../tensor.h" +#include "infiniop/ops/bilinear.h" +#include "infiniop/ops/gemm.h" +#include "infiniop/ops/add.h" + +#include + +struct InfiniopBilinearDescriptor { + InfiniopDescriptor _super; + infiniopGemmDescriptor_t matmul1_desc; // x2 * W^T -> T + infiniopGemmDescriptor_t matmul2_desc; // x1 * T^T -> Out + infiniopAddDescriptor_t add_desc; // Out + Bias + infiniopTensorDescriptor_t bias_view_desc; + + size_t workspace_size; + size_t t_tensor_offset; // 中间变量 T 的偏移量 + size_t t_tensor_size; // 中间变量 T 的大小 + size_t op_workspace_offset; // 子算子 Workspace 偏移量 + size_t op_workspace_size; // 子算子 Workspace 大小 + + bool has_bias; + bool owns_bias_view_desc; +}; + +__C __export infiniStatus_t infiniopCreateBilinearDescriptor( + infiniopHandle_t handle, + infiniopBilinearDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t out_desc, + infiniopTensorDescriptor_t x1_desc, + infiniopTensorDescriptor_t x2_desc, + infiniopTensorDescriptor_t weight_desc, + infiniopTensorDescriptor_t bias_desc) { + + if (out_desc->ndim() != 2 || 
        x1_desc->ndim() != 2 || x2_desc->ndim() != 2 || weight_desc->ndim() != 3) {
+        return INFINI_STATUS_BAD_TENSOR_SHAPE;
+    }
+
+    if (!out_desc->isContiguous() || !x1_desc->isContiguous() || !x2_desc->isContiguous() || !weight_desc->isContiguous()) {
+        return INFINI_STATUS_BAD_TENSOR_STRIDES;
+    }
+
+    size_t N = x1_desc->shape()[0];
+    size_t H_in1 = x1_desc->shape()[1];
+    size_t H_in2 = x2_desc->shape()[1];
+    size_t H_out = weight_desc->shape()[0];
+    size_t alignment = 256;
+
+    if (x2_desc->shape()[0] != N || weight_desc->shape()[1] != H_in1 || weight_desc->shape()[2] != H_in2) {
+        return INFINI_STATUS_BAD_PARAM;
+    }
+
+    if (out_desc->shape()[0] != N || out_desc->shape()[1] != H_out) {
+        return INFINI_STATUS_BAD_PARAM;
+    }
+
+    if (x1_desc->dtype() != x2_desc->dtype() || x1_desc->dtype() != out_desc->dtype() || x1_desc->dtype() != weight_desc->dtype()) {
+        return INFINI_STATUS_BAD_TENSOR_DTYPE;
+    }
+
+    bool has_bias = (bias_desc != nullptr && bias_desc->ndim() > 0);
+    if (has_bias) {
+        if (bias_desc->dtype() != out_desc->dtype()) {
+            return INFINI_STATUS_BAD_TENSOR_DTYPE;
+        }
+        if (bias_desc->shape().back() != H_out) {
+            return INFINI_STATUS_BAD_PARAM;
+        }
+    }
+
+    // 2. Prepare MatMul1: T = x2 * W_flat^T
+    // x2: [N, H_in2]
+    // W: [H_out, H_in1, H_in2] -> W_flat: [H_out * H_in1, H_in2]
+    // We need W_flat^T, i.e. [H_in2, H_out * H_in1]
+
+    infiniopTensorDescriptor_t w_view_desc;
+    CHECK_STATUS(infiniopCreateTensorDescriptor(&w_view_desc, 3, weight_desc->shape().data(), weight_desc->strides().data(), weight_desc->dtype()));
+    // Merge the first two dims: [H_out, H_in1, H_in2] -> [H_out * H_in1, H_in2]
+    TRANSFORM_TENSOR_DESC(w_view_desc, dimMerge(0, 1));
+    // Transpose: [H_out * H_in1, H_in2] -> [H_in2, H_out * H_in1]
+    // Note: this transpose is purely logical (a stride swap); the GEMM handles it automatically
+    TRANSFORM_TENSOR_DESC(w_view_desc, dimPermute({1, 0}));
+
+    // Build the descriptor of the intermediate tensor T: [N, H_out * H_in1]
+    infiniopTensorDescriptor_t t_desc;
+    size_t t_shape[2] = {N, H_out * H_in1};
+    CHECK_STATUS(infiniopCreateTensorDescriptor(&t_desc, 2, t_shape, nullptr, out_desc->dtype()));
+
+    // Create the MatMul1 descriptor
+    infiniopGemmDescriptor_t matmul1_desc;
+    CHECK_STATUS(infiniopCreateGemmDescriptor(handle, &matmul1_desc, t_desc, x2_desc, w_view_desc));
+
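+    // NOTE (editorial): for reference, the computation assembled here matches the
+    // torch.einsum("ni,oij,nj->no", x1, W, x2) reference in test/infiniop/bilinear.py:
+    //
+    //     out[n, o] = sum_{i, j} x1[n, i] * W[o, i, j] * x2[n, j]  (+ bias[o])
+    //
+    // split into two GEMMs:
+    //     (1) T[n, o * H_in1 + i] = sum_j x2[n, j] * W[o, i, j]    (MatMul1 above)
+    //     (2) out[n, o]           = sum_i x1[n, i] * T[n, o, i]    (MatMul2 below, batched over N)
+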
+    // 3. Prepare MatMul2: Out = x1 * T_view^T
+    // This is a batched GEMM
+    // x1: [N, H_in1] -> [N, 1, H_in1]
+    // T: [N, H_out * H_in1] -> [N, H_out, H_in1] -> transposed to [N, H_in1, H_out]
+    // Out: [N, 1, H_out] (corresponding to the actual output [N, H_out])
+
+    // x1 view: [N, 1, H_in1]
+    infiniopTensorDescriptor_t x1_view_desc;
+    CHECK_STATUS(infiniopCreateTensorDescriptor(&x1_view_desc, 2, x1_desc->shape().data(), x1_desc->strides().data(), x1_desc->dtype()));
+    TRANSFORM_TENSOR_DESC(x1_view_desc, dimSplit(1, {1, H_in1}));
+
+    // T view: [N, H_out, H_in1] -> transposed and used as the B input
+    infiniopTensorDescriptor_t t_view_desc;
+    // The strides here must be derived from t_desc (i.e. the contiguous buffer in the workspace)
+    CHECK_STATUS(infiniopCreateTensorDescriptor(&t_view_desc, 2, t_desc->shape().data(), nullptr, t_desc->dtype()));
+    TRANSFORM_TENSOR_DESC(t_view_desc, dimSplit(1, {H_out, H_in1})); // [N, H_out, H_in1]
+    TRANSFORM_TENSOR_DESC(t_view_desc, dimPermute({0, 2, 1}));       // [N, H_in1, H_out]
+
+    // Out view: [N, 1, H_out] (the GEMM output must be 3D to match the batch)
+    infiniopTensorDescriptor_t out_view_desc;
+    CHECK_STATUS(infiniopCreateTensorDescriptor(&out_view_desc, 2, out_desc->shape().data(), out_desc->strides().data(), out_desc->dtype()));
+    TRANSFORM_TENSOR_DESC(out_view_desc, dimSplit(1, {1, H_out}));
+
+    // Create the MatMul2 descriptor
+    infiniopGemmDescriptor_t matmul2_desc;
+    CHECK_STATUS(infiniopCreateGemmDescriptor(handle, &matmul2_desc, out_view_desc, x1_view_desc, t_view_desc));
+
+    // 4. Prepare the bias add (optional)
+    infiniopAddDescriptor_t add_desc = nullptr;
+    size_t add_ws = 0;
+    infiniopTensorDescriptor_t bias_view_desc = nullptr;
+    bool owns_bias_view_desc = false;
+    if (has_bias) {
+        if (bias_desc->ndim() == 1 && bias_desc->shape()[0] == H_out) {
+            size_t bias_shape[2] = {N, H_out};
+            ssize_t bias_strides[2] = {0, bias_desc->strides()[0]};
+            CHECK_STATUS(infiniopCreateTensorDescriptor(&bias_view_desc, 2, bias_shape, bias_strides, bias_desc->dtype()));
+            owns_bias_view_desc = true;
+        } else if (bias_desc->ndim() == 2 && bias_desc->shape()[0] == N && bias_desc->shape()[1] == H_out) {
+            bias_view_desc = bias_desc;
+        } else {
+            return INFINI_STATUS_BAD_TENSOR_SHAPE;
+        }
+        // Use the Add operator: Out = Out + Bias (broadcast)
+        CHECK_STATUS(infiniopCreateAddDescriptor(handle, &add_desc, out_desc, out_desc, bias_view_desc));
+        CHECK_STATUS(infiniopGetAddWorkspaceSize(add_desc, &add_ws));
+    }
+
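+    // NOTE (editorial): a worked example of the layout computed below, using the first
+    // test case in test/infiniop/bilinear.py (N=4, H_in1=3, H_in2=5, H_out=2, f16) and
+    // assuming utils::align rounds up to a multiple of `alignment`:
+    //     T holds N * H_out * H_in1 = 4 * 2 * 3 = 24 elements -> 48 bytes -> t_size = 256.
+    //     The sub-op workspace (max of the two GEMM workspaces and the Add workspace,
+    //     each 256-aligned) is appended after T, so total_size = 256 + op_ws_size
+    //     whenever any sub-operator requests workspace.
+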
+    // 5. Compute the workspace layout
+    size_t mm1_ws, mm2_ws;
+    CHECK_STATUS(infiniopGetGemmWorkspaceSize(matmul1_desc, &mm1_ws));
+    CHECK_STATUS(infiniopGetGemmWorkspaceSize(matmul2_desc, &mm2_ws));
+
+    // Align
+    mm1_ws = utils::align(mm1_ws, alignment);
+    mm2_ws = utils::align(mm2_ws, alignment);
+    add_ws = utils::align(add_ws, alignment);
+
+    // Size of the intermediate tensor T
+    size_t t_size = utils::align(t_desc->numel() * infiniSizeOf(t_desc->dtype()), alignment);
+
+    // Total workspace = T + max(sub-op workspaces)
+    size_t op_ws_size = std::max(mm1_ws, std::max(mm2_ws, add_ws));
+    size_t op_ws_offset = 0;
+    size_t total_size = t_size;
+    if (op_ws_size > 0) {
+        op_ws_offset = utils::align(total_size, alignment);
+        total_size = op_ws_offset + op_ws_size;
+    }
+
+    *(InfiniopBilinearDescriptor **)desc_ptr = new InfiniopBilinearDescriptor{
+        {handle->device, handle->device_id},
+        matmul1_desc,
+        matmul2_desc,
+        add_desc,
+        bias_view_desc,
+        total_size,
+        0,
+        t_size,
+        op_ws_offset,
+        op_ws_size,
+        has_bias,
+        owns_bias_view_desc
+    };
+
+    // Clean up temporary descriptors (if they are newly created objects).
+    // Note: InfiniOP's Create interfaces usually copy the descriptor information, so releasing them here is safe.
+    // This depends on the underlying implementation, but in general a descriptor that was new-ed and never taken over still has to be handled.
+    // Simplified here: we assume the TRANSFORM macros take care of the ownership transfer.
+
+    return INFINI_STATUS_SUCCESS;
+}
+
+__C __export infiniStatus_t infiniopGetBilinearWorkspaceSize(
+    infiniopBilinearDescriptor_t desc,
+    size_t *size) {
+    *size = ((InfiniopBilinearDescriptor *)desc)->workspace_size;
+    return INFINI_STATUS_SUCCESS;
+}
+
+__C __export infiniStatus_t infiniopBilinear(
+    infiniopBilinearDescriptor_t desc_,
+    void *workspace,
+    size_t workspace_size,
+    void *out,
+    const void *x1,
+    const void *x2,
+    const void *weight,
+    const void *bias,
+    void *stream) {
+
+    auto desc = (InfiniopBilinearDescriptor *)desc_;
+    printf("Executing Bilinear Op on device %d (id %d)\n", desc->_super.device_type, desc->_super.device_id);
+    if (workspace_size < desc->workspace_size) {
+        return INFINI_STATUS_INSUFFICIENT_WORKSPACE;
+    }
+
+    if (desc->has_bias && bias == nullptr) {
+        return INFINI_STATUS_BAD_PARAM;
+    }
+
+    // Pointer arithmetic
+    void *t_ptr = (char *)workspace + desc->t_tensor_offset;
+    void *op_ws_ptr = desc->op_workspace_size > 0 ? (char *)workspace + desc->op_workspace_offset : nullptr;
+    size_t op_ws_size = desc->op_workspace_size;
+    printf("Bilinear Op workspace: total %zu bytes, T tensor at offset %zu, op workspace at offset %zu size %zu\n",
+           desc->workspace_size, desc->t_tensor_offset, desc->op_workspace_offset, desc->op_workspace_size);
+    // 1. Run MatMul1: T = x2 * W^T
+    // Gemm: C(T) = alpha * A(x2) * B(W^T)
+    CHECK_STATUS(infiniopGemm(desc->matmul1_desc,
+                              op_ws_ptr, op_ws_size,
+                              t_ptr, x2, weight, 1.0f, 0.0f, stream));
+
+    // 2. Run MatMul2: Out = x1 * T^T
+    // Gemm: C(Out) = alpha * A(x1) * B(T^T)
+    CHECK_STATUS(infiniopGemm(desc->matmul2_desc,
+                              op_ws_ptr, op_ws_size,
+                              out, x1, t_ptr, 1.0f, 0.0f, stream));
+
+    printf("Bilinear Op MatMul2 completed, Out tensor computed.\n");
+    printf("Bilinear Op Bias Add starting.\n");
+    // 3. 
执行 Bias Add (可选) + if (desc->has_bias && desc->add_desc) { + CHECK_STATUS(infiniopAdd(desc->add_desc, + op_ws_ptr, op_ws_size, + out, out, bias, stream)); + } + + return INFINI_STATUS_SUCCESS; +} + +__C __export infiniStatus_t infiniopDestroyBilinearDescriptor( + infiniopBilinearDescriptor_t desc_) { + + auto desc = (InfiniopBilinearDescriptor *)desc_; + CHECK_STATUS(infiniopDestroyGemmDescriptor(desc->matmul1_desc)); + CHECK_STATUS(infiniopDestroyGemmDescriptor(desc->matmul2_desc)); + if (desc->add_desc) { + CHECK_STATUS(infiniopDestroyAddDescriptor(desc->add_desc)); + if (desc->owns_bias_view_desc && desc->bias_view_desc) { + CHECK_STATUS(infiniopDestroyTensorDescriptor(desc->bias_view_desc)); + } + } + delete desc; + return INFINI_STATUS_SUCCESS; +} \ No newline at end of file diff --git a/test/infiniop/bilinear.py b/test/infiniop/bilinear.py new file mode 100644 index 000000000..436298bc6 --- /dev/null +++ b/test/infiniop/bilinear.py @@ -0,0 +1,185 @@ +import torch +import ctypes +from ctypes import c_uint64 + +from libinfiniop import ( + LIBINFINIOP, + TestTensor, + get_test_devices, + check_error, + test_operator, + get_args, + debug, + get_tolerance, + profile_operation, + TestWorkspace, + InfiniDtype, + InfiniDtypeNames, + InfiniDeviceNames, + infiniopOperatorDescriptor_t, +) + + +_TEST_CASES = [ + # batch, in1, in2, out, use_bias + (4, 3, 5, 2, True), + (1, 6, 7, 3, True), + (8, 2, 4, 5, False), + (2, 3, 3, 4, True), + (6, 10, 12, 7, False), + (3, 1, 1, 2, True), +] + +_TENSOR_DTYPES = [InfiniDtype.F16, InfiniDtype.BF16, InfiniDtype.F32] + +_TOLERANCE_MAP = { + InfiniDtype.F16: {"atol": 1e-3, "rtol": 1e-2}, + InfiniDtype.BF16: {"atol": 1e-2, "rtol": 5e-2}, + InfiniDtype.F32: {"atol": 1e-5, "rtol": 1e-4}, +} + +DEBUG = False +PROFILE = False +NUM_PRERUN = 10 +NUM_ITERATIONS = 1000 + + +def reference_bilinear(x1, x2, weight, bias): + """Compute bilinear output in FP32 for stability and cast back.""" + out = torch.einsum( + "ni,oij,nj->no", + x1.to(torch.float32), + weight.to(torch.float32), + x2.to(torch.float32), + ) + if bias is not None: + out = out + bias.to(torch.float32) + return out.to(x1.dtype) + + +def test( + handle, + device, + batch, + in1_features, + in2_features, + out_features, + use_bias, + dtype=InfiniDtype.F16, + sync=None, +): + print( + f"Testing Bilinear on {InfiniDeviceNames[device]} with N:{batch} in1:{in1_features} in2:{in2_features} " + f"out:{out_features} bias:{use_bias} dtype:{InfiniDtypeNames[dtype]}" + ) + + out_tensor = TestTensor((batch, out_features), None, dtype, device, mode="zeros") + x1 = TestTensor((batch, in1_features), None, dtype, device, scale=0.1, bias=-0.05) + x2 = TestTensor((batch, in2_features), None, dtype, device, scale=0.1, bias=-0.05) + weight = TestTensor( + (out_features, in1_features, in2_features), + None, + dtype, + device, + scale=0.1, + bias=-0.05, + ) + bias_tensor = ( + TestTensor((out_features,), None, dtype, device, scale=0.1, bias=-0.05) + if use_bias + else None + ) + + ref = reference_bilinear( + x1.torch_tensor(), + x2.torch_tensor(), + weight.torch_tensor(), + bias_tensor.torch_tensor() if bias_tensor else None, + ) + + if sync is not None: + sync() + + descriptor = infiniopOperatorDescriptor_t() + check_error( + LIBINFINIOP.infiniopCreateBilinearDescriptor( + handle, + ctypes.byref(descriptor), + out_tensor.descriptor, + x1.descriptor, + x2.descriptor, + weight.descriptor, + bias_tensor.descriptor if bias_tensor else None, + ) + ) + + tensors = [out_tensor, x1, x2, weight] + if bias_tensor: + 
tensors.append(bias_tensor) + for tensor in tensors: + tensor.destroy_desc() + + workspace_size = c_uint64(0) + check_error( + LIBINFINIOP.infiniopGetBilinearWorkspaceSize( + descriptor, ctypes.byref(workspace_size) + ) + ) + workspace = TestWorkspace(workspace_size.value, device) + + def lib_bilinear(): + check_error( + LIBINFINIOP.infiniopBilinear( + descriptor, + workspace.data(), + workspace_size.value, + out_tensor.data(), + x1.data(), + x2.data(), + weight.data(), + bias_tensor.data() if bias_tensor else None, + None, + ) + ) + print(f"workspace size: {workspace_size.value} bytes") + print(f"tensor shapes: x1{tuple(x1.shape)} x2{tuple(x2.shape)} weight{tuple(weight.shape)} bias{tuple(bias_tensor.shape) if bias_tensor else None} out{tuple(out_tensor.shape)}") + lib_bilinear() + print(f"new workspace size: {workspace_size.value} bytes") + print(f"new tensor shapes: x1{tuple(x1.shape)} x2{tuple(x2.shape)} weight{tuple(weight.shape)} bias{tuple(bias_tensor.shape) if bias_tensor else None} out{tuple(out_tensor.shape)}") + print(f"--") + + atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype) + if DEBUG: + debug(out_tensor.actual_tensor(), ref, atol=atol, rtol=rtol) + assert torch.allclose(out_tensor.actual_tensor(), ref, atol=atol, rtol=rtol) + + if PROFILE: + profile_operation( + "PyTorch", + lambda: reference_bilinear( + x1.torch_tensor(), + x2.torch_tensor(), + weight.torch_tensor(), + bias_tensor.torch_tensor() if bias_tensor else None, + ), + device, + NUM_PRERUN, + NUM_ITERATIONS, + ) + profile_operation(" lib", lambda: lib_bilinear(), device, NUM_PRERUN, NUM_ITERATIONS) + + check_error(LIBINFINIOP.infiniopDestroyBilinearDescriptor(descriptor)) + + +if __name__ == "__main__": + args = get_args() + + DEBUG = args.debug + PROFILE = args.profile + NUM_PRERUN = args.num_prerun + NUM_ITERATIONS = args.num_iterations + + for device in get_test_devices(args): + test_operator(device, test, _TEST_CASES, _TENSOR_DTYPES) + + print("\033[92mTest passed!\033[0m") From dfa35c1743b581d9982e1c4df526db7c4d0800f9 Mon Sep 17 00:00:00 2001 From: Kinorw Date: Wed, 26 Nov 2025 20:11:04 +0800 Subject: [PATCH 02/32] fix: test/op_register.py --- src/infiniop/ops/bilinear/operator.cc | 378 +++++++++++------------ test/infiniop/libinfiniop/op_register.py | 38 +++ 2 files changed, 227 insertions(+), 189 deletions(-) diff --git a/src/infiniop/ops/bilinear/operator.cc b/src/infiniop/ops/bilinear/operator.cc index 0edacad25..5ec71b460 100644 --- a/src/infiniop/ops/bilinear/operator.cc +++ b/src/infiniop/ops/bilinear/operator.cc @@ -6,26 +6,33 @@ #include "infiniop/ops/bilinear.h" #include "infiniop/ops/gemm.h" #include "infiniop/ops/add.h" +#include "infiniop/ops/rearrange.h" #include struct InfiniopBilinearDescriptor { InfiniopDescriptor _super; - infiniopGemmDescriptor_t matmul1_desc; // x2 * W^T -> T - infiniopGemmDescriptor_t matmul2_desc; // x1 * T^T -> Out - infiniopAddDescriptor_t add_desc; // Out + Bias - infiniopTensorDescriptor_t bias_view_desc; - + infiniopGemmDescriptor_t imediate_desc; + infiniopGemmDescriptor_t result_desc; + infiniopRearrangeDescriptor_t weight_rearrange_desc; + infiniopAddDescriptor_t bias_add_desc; size_t workspace_size; - size_t t_tensor_offset; // 中间变量 T 的偏移量 - size_t t_tensor_size; // 中间变量 T 的大小 - size_t op_workspace_offset; // 子算子 Workspace 偏移量 - size_t op_workspace_size; // 子算子 Workspace 大小 - - bool has_bias; - bool owns_bias_view_desc; + size_t weight_offset; + size_t weight_size; + size_t imediate_offset; + size_t imediate_size; + size_t op_workspace_offset; + size_t 
op_workspace_size; }; +namespace { +constexpr size_t kWorkspaceAlignment = 256; + +size_t aligned_size(size_t value) { + return utils::align(value, kWorkspaceAlignment); +} +} // namespace + __C __export infiniStatus_t infiniopCreateBilinearDescriptor( infiniopHandle_t handle, infiniopBilinearDescriptor_t *desc_ptr, @@ -39,157 +46,133 @@ __C __export infiniStatus_t infiniopCreateBilinearDescriptor( return INFINI_STATUS_BAD_TENSOR_SHAPE; } - if (!out_desc->isContiguous() || !x1_desc->isContiguous() || !x2_desc->isContiguous() || !weight_desc->isContiguous()) { - return INFINI_STATUS_BAD_TENSOR_STRIDES; - } - - size_t N = x1_desc->shape()[0]; - size_t H_in1 = x1_desc->shape()[1]; - size_t H_in2 = x2_desc->shape()[1]; - size_t H_out = weight_desc->shape()[0]; - size_t alignment = 256; - - if (x2_desc->shape()[0] != N || weight_desc->shape()[1] != H_in1 || weight_desc->shape()[2] != H_in2) { - return INFINI_STATUS_BAD_PARAM; - } + size_t batch_size = x1_desc->shape()[0]; + size_t in1_features = x1_desc->shape()[1]; + size_t in2_features = x2_desc->shape()[1]; + size_t out_features = out_desc->shape()[1]; - if (out_desc->shape()[0] != N || out_desc->shape()[1] != H_out) { - return INFINI_STATUS_BAD_PARAM; + if (x2_desc->shape()[0] != batch_size || + weight_desc->shape()[0] != out_features || + weight_desc->shape()[1] != in1_features || + weight_desc->shape()[2] != in2_features || + out_desc->shape()[0] != batch_size) { + return INFINI_STATUS_BAD_TENSOR_SHAPE; } - if (x1_desc->dtype() != x2_desc->dtype() || x1_desc->dtype() != out_desc->dtype() || x1_desc->dtype() != weight_desc->dtype()) { - return INFINI_STATUS_BAD_TENSOR_DTYPE; - } + auto dtype = out_desc->dtype(); + CHECK_OR_RETURN(x1_desc->dtype() == dtype && x2_desc->dtype() == dtype && weight_desc->dtype() == dtype, + INFINI_STATUS_BAD_TENSOR_DTYPE); - bool has_bias = (bias_desc != nullptr && bias_desc->ndim() > 0); - if (has_bias) { - if (bias_desc->dtype() != out_desc->dtype()) { - return INFINI_STATUS_BAD_TENSOR_DTYPE; - } - if (bias_desc->shape().back() != H_out) { - return INFINI_STATUS_BAD_PARAM; - } + if (bias_desc) { + CHECK_OR_RETURN(bias_desc->dtype() == dtype, INFINI_STATUS_BAD_TENSOR_DTYPE); + CHECK_OR_RETURN(bias_desc->ndim() == 1 && bias_desc->dim(0) == out_features, + INFINI_STATUS_BAD_TENSOR_SHAPE); } - // 2. 准备 MatMul1: T = x2 * W_flat^T - // x2: [N, H_in2] - // W: [H_out, H_in1, H_in2] -> W_flat: [H_out * H_in1, H_in2] - // 我们需要 W_flat^T,即 [H_in2, H_out * H_in1] - - infiniopTensorDescriptor_t w_view_desc; - CHECK_STATUS(infiniopCreateTensorDescriptor(&w_view_desc, 3, weight_desc->shape().data(), weight_desc->strides().data(), weight_desc->dtype())); - // 合并前两维: [H_out, H_in1, H_in2] -> [H_out * H_in1, H_in2] - TRANSFORM_TENSOR_DESC(w_view_desc, dimMerge(0, 1)); - // 转置: [H_out * H_in1, H_in2] -> [H_in2, H_out * H_in1] - // 注意:这里的转置是逻辑上的 (stride swap),Gemm 会自动处理 - TRANSFORM_TENSOR_DESC(w_view_desc, dimPermute({1, 0})); - - // 构建中间变量 T 的描述符: [N, H_out * H_in1] - infiniopTensorDescriptor_t t_desc; - size_t t_shape[2] = {N, H_out * H_in1}; - CHECK_STATUS(infiniopCreateTensorDescriptor(&t_desc, 2, t_shape, nullptr, out_desc->dtype())); - - // 创建 MatMul1 描述符 - infiniopGemmDescriptor_t matmul1_desc; - CHECK_STATUS(infiniopCreateGemmDescriptor(handle, &matmul1_desc, t_desc, x2_desc, w_view_desc)); - - // 3. 
准备 MatMul2: Out = x1 * T_view^T - // 这是一个 Batch Gemm - // x1: [N, H_in1] -> [N, 1, H_in1] - // T: [N, H_out * H_in1] -> [N, H_out, H_in1] -> 转置为 [N, H_in1, H_out] - // Out: [N, 1, H_out] (对应实际输出 [N, H_out]) - - // x1 视图: [N, 1, H_in1] - infiniopTensorDescriptor_t x1_view_desc; - CHECK_STATUS(infiniopCreateTensorDescriptor(&x1_view_desc, 2, x1_desc->shape().data(), x1_desc->strides().data(), x1_desc->dtype())); - TRANSFORM_TENSOR_DESC(x1_view_desc, dimSplit(1, {1, H_in1})); - - // T 视图: [N, H_out, H_in1] -> 转置后作为 B 输入 - infiniopTensorDescriptor_t t_view_desc; - // 这里的 stride 需要根据 t_desc (即 workspace 中的连续内存) 来推导 - CHECK_STATUS(infiniopCreateTensorDescriptor(&t_view_desc, 2, t_desc->shape().data(), nullptr, t_desc->dtype())); - TRANSFORM_TENSOR_DESC(t_view_desc, dimSplit(1, {H_out, H_in1})); // [N, H_out, H_in1] - TRANSFORM_TENSOR_DESC(t_view_desc, dimPermute({0, 2, 1})); // [N, H_in1, H_out] - - // Out 视图: [N, 1, H_out] (Gemm 输出需要 3D 以匹配 Batch) - infiniopTensorDescriptor_t out_view_desc; - CHECK_STATUS(infiniopCreateTensorDescriptor(&out_view_desc, 2, out_desc->shape().data(), out_desc->strides().data(), out_desc->dtype())); - TRANSFORM_TENSOR_DESC(out_view_desc, dimSplit(1, {1, H_out})); - - // 创建 MatMul2 描述符 - infiniopGemmDescriptor_t matmul2_desc; - CHECK_STATUS(infiniopCreateGemmDescriptor(handle, &matmul2_desc, out_view_desc, x1_view_desc, t_view_desc)); - - // 4. 准备 Bias Add (可选) - infiniopAddDescriptor_t add_desc = nullptr; - size_t add_ws = 0; - infiniopTensorDescriptor_t bias_view_desc = nullptr; - bool owns_bias_view_desc = false; - if (has_bias) { - if (bias_desc->ndim() == 1 && bias_desc->shape()[0] == H_out) { - size_t bias_shape[2] = {N, H_out}; - ssize_t bias_strides[2] = {0, bias_desc->strides()[0]}; - CHECK_STATUS(infiniopCreateTensorDescriptor(&bias_view_desc, 2, bias_shape, bias_strides, bias_desc->dtype())); - owns_bias_view_desc = true; - } else if (bias_desc->ndim() == 2 && bias_desc->shape()[0] == N && bias_desc->shape()[1] == H_out) { - bias_view_desc = bias_desc; - } else { - return INFINI_STATUS_BAD_TENSOR_SHAPE; - } - // 使用 Add 算子: Out = Out + Bias (Broadcast) - CHECK_STATUS(infiniopCreateAddDescriptor(handle, &add_desc, out_desc, out_desc, bias_view_desc)); - CHECK_STATUS(infiniopGetAddWorkspaceSize(add_desc, &add_ws)); - } + size_t dtype_size = infiniSizeOf(dtype); + + // Prepare packed weight layout to allow reshape into GEMM matrix + size_t weight_shape[3] = {out_features, in1_features, in2_features}; + ptrdiff_t packed_strides[3] = { + static_cast(in2_features), + static_cast(out_features * in2_features), + static_cast(1), + }; - // 5. 
计算 Workspace - size_t mm1_ws, mm2_ws; - CHECK_STATUS(infiniopGetGemmWorkspaceSize(matmul1_desc, &mm1_ws)); - CHECK_STATUS(infiniopGetGemmWorkspaceSize(matmul2_desc, &mm2_ws)); - - // 对齐 - mm1_ws = utils::align(mm1_ws, alignment); - mm2_ws = utils::align(mm2_ws, alignment); - add_ws = utils::align(add_ws, alignment); - - // 中间变量 T 的大小 - size_t t_size = utils::align(t_desc->numel() * infiniSizeOf(t_desc->dtype()), alignment); - - // 总 Workspace = T + Max(Op_Workspace) - size_t op_ws_size = std::max(mm1_ws, std::max(mm2_ws, add_ws)); - size_t op_ws_offset = 0; - size_t total_size = t_size; - if (op_ws_size > 0) { - op_ws_offset = utils::align(total_size, alignment); - total_size = op_ws_offset + op_ws_size; + infiniopTensorDescriptor_t weight_packed_desc; + CHECK_STATUS(infiniopCreateTensorDescriptor(&weight_packed_desc, 3, weight_shape, packed_strides, dtype)); + + infiniopRearrangeDescriptor_t weight_rearrange_desc; + CHECK_STATUS(infiniopCreateRearrangeDescriptor(handle, &weight_rearrange_desc, weight_packed_desc, weight_desc)); + CHECK_STATUS(infiniopDestroyTensorDescriptor(weight_packed_desc)); + + infiniopTensorDescriptor_t weight_matrix_desc; + CHECK_STATUS(infiniopCreateTensorDescriptor(&weight_matrix_desc, 3, weight_shape, packed_strides, dtype)); + TRANSFORM_TENSOR_DESC(weight_matrix_desc, dimPermute({1, 0, 2})); + TRANSFORM_TENSOR_DESC(weight_matrix_desc, dimMerge(1, 2)); + + size_t imediate_shape[2] = {batch_size, out_features * in2_features}; + infiniopTensorDescriptor_t imediate_flat_desc; + CHECK_STATUS(infiniopCreateTensorDescriptor(&imediate_flat_desc, 2, imediate_shape, nullptr, dtype)); + + infiniopGemmDescriptor_t imediate_desc; + CHECK_STATUS(infiniopCreateGemmDescriptor(handle, &imediate_desc, imediate_flat_desc, x1_desc, weight_matrix_desc)); + CHECK_STATUS(infiniopDestroyTensorDescriptor(weight_matrix_desc)); + + infiniopTensorDescriptor_t imediate_split_desc; + CHECK_STATUS(infiniopCreateTensorDescriptor(&imediate_split_desc, 2, imediate_shape, nullptr, dtype)); + TRANSFORM_TENSOR_DESC(imediate_split_desc, dimSplit(1, {out_features, in2_features})); + + auto x2_shape = x2_desc->shape(); + auto x2_strides = x2_desc->strides(); + infiniopTensorDescriptor_t x2_col_desc; + CHECK_STATUS(infiniopCreateTensorDescriptor(&x2_col_desc, 2, x2_shape.data(), x2_strides.data(), dtype)); + TRANSFORM_TENSOR_DESC(x2_col_desc, dimSplit(1, {in2_features, 1})); + + auto out_shape = out_desc->shape(); + auto out_strides = out_desc->strides(); + infiniopTensorDescriptor_t out_col_desc; + CHECK_STATUS(infiniopCreateTensorDescriptor(&out_col_desc, 2, out_shape.data(), out_strides.data(), dtype)); + TRANSFORM_TENSOR_DESC(out_col_desc, dimSplit(1, {out_features, 1})); + + infiniopGemmDescriptor_t result_desc; + CHECK_STATUS(infiniopCreateGemmDescriptor(handle, &result_desc, out_col_desc, imediate_split_desc, x2_col_desc)); + + size_t gemm1_workspace_size = 0; + size_t gemm2_workspace_size = 0; + CHECK_STATUS(infiniopGetGemmWorkspaceSize(imediate_desc, &gemm1_workspace_size)); + CHECK_STATUS(infiniopGetGemmWorkspaceSize(result_desc, &gemm2_workspace_size)); + + CHECK_STATUS(infiniopDestroyTensorDescriptor(imediate_flat_desc)); + CHECK_STATUS(infiniopDestroyTensorDescriptor(imediate_split_desc)); + CHECK_STATUS(infiniopDestroyTensorDescriptor(x2_col_desc)); + CHECK_STATUS(infiniopDestroyTensorDescriptor(out_col_desc)); + + infiniopAddDescriptor_t bias_add_desc = nullptr; + size_t add_workspace_size = 0; + if (bias_desc) { + size_t bias_shape[2] = {batch_size, out_features}; + ptrdiff_t 
bias_strides[2] = {0, bias_desc->stride(0)}; + infiniopTensorDescriptor_t bias_broadcast_desc; + CHECK_STATUS(infiniopCreateTensorDescriptor(&bias_broadcast_desc, 2, bias_shape, bias_strides, dtype)); + CHECK_STATUS(infiniopCreateAddDescriptor(handle, &bias_add_desc, out_desc, out_desc, bias_broadcast_desc)); + CHECK_STATUS(infiniopGetAddWorkspaceSize(bias_add_desc, &add_workspace_size)); + CHECK_STATUS(infiniopDestroyTensorDescriptor(bias_broadcast_desc)); } + size_t weight_size = aligned_size(out_features * in1_features * in2_features * dtype_size); + size_t imediate_size = aligned_size(batch_size * out_features * in2_features * dtype_size); + + size_t op_workspace_size = std::max(gemm1_workspace_size, gemm2_workspace_size); + op_workspace_size = std::max(op_workspace_size, add_workspace_size); + op_workspace_size = aligned_size(op_workspace_size); + + size_t weight_offset = 0; + size_t imediate_offset = weight_offset + weight_size; + size_t op_workspace_offset = imediate_offset + imediate_size; + size_t workspace_size = op_workspace_offset + op_workspace_size; + *(InfiniopBilinearDescriptor **)desc_ptr = new InfiniopBilinearDescriptor{ {handle->device, handle->device_id}, - matmul1_desc, - matmul2_desc, - add_desc, - bias_view_desc, - total_size, - 0, - t_size, - op_ws_offset, - op_ws_size, - has_bias, - owns_bias_view_desc - }; - - // 清理临时描述符 (如果是新创建的对象) - // 注意:InfiniOP 的 Create 接口通常会拷贝描述符信息,所以这里释放是安全的 - // 具体依赖于底层实现,但通常 Descriptor 指针如果是 new 出来的且没被 take,则需要处理。 - // 在此简略,假设 TRANSFORM 宏处理了所有权转移。 + imediate_desc, + result_desc, + weight_rearrange_desc, + bias_add_desc, + workspace_size, + weight_offset, + weight_size, + imediate_offset, + imediate_size, + op_workspace_offset, + op_workspace_size}; return INFINI_STATUS_SUCCESS; } __C __export infiniStatus_t infiniopGetBilinearWorkspaceSize( - infiniopBilinearDescriptor_t desc, + infiniopBilinearDescriptor_t desc, size_t *size) { - *size = ((InfiniopBilinearDescriptor *)desc)->workspace_size; + *size = reinterpret_cast(desc)->workspace_size; return INFINI_STATUS_SUCCESS; } @@ -198,64 +181,81 @@ __C __export infiniStatus_t infiniopBilinear( void *workspace, size_t workspace_size, void *out, - const void *x1, - const void *x2, - const void *weight, - const void *bias, + void const *x1, + void const *x2, + void const *weight, + void const *bias, void *stream) { - - auto desc = (InfiniopBilinearDescriptor *)desc_; - printf("Executing Bilinear Op on device %d (id %d)\n", desc->_super.device_type, desc->_super.device_id); + + auto desc = reinterpret_cast(desc_); if (workspace_size < desc->workspace_size) { return INFINI_STATUS_INSUFFICIENT_WORKSPACE; } - if (desc->has_bias && bias == nullptr) { + if (desc->bias_add_desc && bias == nullptr) { return INFINI_STATUS_BAD_PARAM; } - // 指针计算 - void *t_ptr = (char *)workspace + desc->t_tensor_offset; - void *op_ws_ptr = desc->op_workspace_size > 0 ? (char *)workspace + desc->op_workspace_offset : nullptr; - size_t op_ws_size = desc->op_workspace_size; - printf("Bilinear Op workspace: total %zu bytes, T tensor at offset %zu, op workspace at offset %zu size %zu\n", - desc->workspace_size, desc->t_tensor_offset, desc->op_workspace_offset, desc->op_workspace_size); - // 1. 执行 MatMul1: T = x2 * W^T - // Gemm: C(T) = alpha * A(x2) * B(W^T) - CHECK_STATUS(infiniopGemm(desc->matmul1_desc, - op_ws_ptr, op_ws_size, - t_ptr, x2, weight, 1.0f, 0.0f, stream)); - - // 2. 
执行 MatMul2: Out = x1 * T^T - // Gemm: C(Out) = alpha * A(x1) * B(T^T) - CHECK_STATUS(infiniopGemm(desc->matmul2_desc, - op_ws_ptr, op_ws_size, - out, x1, t_ptr, 1.0f, 0.0f, stream)); - - printf("Bilinear Op MatMul2 completed, Out tensor computed.\n"); - printf("Bilinear Op Bias Add starting.\n"); - // 3. 执行 Bias Add (可选) - if (desc->has_bias && desc->add_desc) { - CHECK_STATUS(infiniopAdd(desc->add_desc, - op_ws_ptr, op_ws_size, - out, out, bias, stream)); + if (!desc->bias_add_desc && bias != nullptr) { + return INFINI_STATUS_BAD_PARAM; + } + + char *workspace_ptr = reinterpret_cast(workspace); + void *weight_buffer = workspace_ptr + desc->weight_offset; + void *imediate_buffer = workspace_ptr + desc->imediate_offset; + void *op_workspace = workspace_ptr + desc->op_workspace_offset; + printf("workspace_ptr: %p, weight_buffer: %p, imediate_buffer: %p, op_workspace: %p\n", + workspace_ptr, weight_buffer, imediate_buffer, op_workspace); + + printf("desc->weight_rearrange_desc: %p, weight_buffer: %p, weight: %p, stream: %p\n", + desc->weight_rearrange_desc, weight_buffer, weight, stream); + CHECK_STATUS(infiniopRearrange(desc->weight_rearrange_desc, weight_buffer, weight, stream)); + printf("Bilinear: weight rearranged\n"); + + CHECK_STATUS(infiniopGemm(desc->imediate_desc, + op_workspace, desc->op_workspace_size, + imediate_buffer, + x1, + weight_buffer, + 1.0f, + 0.0f, + stream)); + printf("Bilinear: first GEMM done\n"); + + CHECK_STATUS(infiniopGemm(desc->result_desc, + op_workspace, desc->op_workspace_size, + out, + imediate_buffer, + x2, + 1.0f, + 0.0f, + stream)); + printf("Bilinear: second GEMM done\n"); + + if (desc->bias_add_desc) { + CHECK_STATUS(infiniopAdd(desc->bias_add_desc, + op_workspace, desc->op_workspace_size, + out, + out, + bias, + stream)); } + printf("Bilinear: bias added\n"); return INFINI_STATUS_SUCCESS; } __C __export infiniStatus_t infiniopDestroyBilinearDescriptor( infiniopBilinearDescriptor_t desc_) { - - auto desc = (InfiniopBilinearDescriptor *)desc_; - CHECK_STATUS(infiniopDestroyGemmDescriptor(desc->matmul1_desc)); - CHECK_STATUS(infiniopDestroyGemmDescriptor(desc->matmul2_desc)); - if (desc->add_desc) { - CHECK_STATUS(infiniopDestroyAddDescriptor(desc->add_desc)); - if (desc->owns_bias_view_desc && desc->bias_view_desc) { - CHECK_STATUS(infiniopDestroyTensorDescriptor(desc->bias_view_desc)); - } + auto desc = reinterpret_cast(desc_); + + CHECK_STATUS(infiniopDestroyGemmDescriptor(desc->imediate_desc)); + CHECK_STATUS(infiniopDestroyGemmDescriptor(desc->result_desc)); + CHECK_STATUS(infiniopDestroyRearrangeDescriptor(desc->weight_rearrange_desc)); + if (desc->bias_add_desc) { + CHECK_STATUS(infiniopDestroyAddDescriptor(desc->bias_add_desc)); } + delete desc; return INFINI_STATUS_SUCCESS; -} \ No newline at end of file +} diff --git a/test/infiniop/libinfiniop/op_register.py b/test/infiniop/libinfiniop/op_register.py index 5b2974111..e95764e20 100644 --- a/test/infiniop/libinfiniop/op_register.py +++ b/test/infiniop/libinfiniop/op_register.py @@ -938,3 +938,41 @@ def tanh_(lib): lib.infiniopDestroyTanhDescriptor.argtypes = [ infiniopOperatorDescriptor_t, ] + + +@OpRegister.operator +def bilinear_(lib): + lib.infiniopCreateBilinearDescriptor.restype = c_int32 + lib.infiniopCreateBilinearDescriptor.argtypes = [ + infiniopHandle_t, + POINTER(infiniopOperatorDescriptor_t), + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + ] + + 
lib.infiniopGetBilinearWorkspaceSize.restype = c_int32 + lib.infiniopGetBilinearWorkspaceSize.argtypes = [ + infiniopOperatorDescriptor_t, + POINTER(c_size_t), + ] + + lib.infiniopBilinear.restype = c_int32 + lib.infiniopBilinear.argtypes = [ + infiniopOperatorDescriptor_t, + c_void_p, + c_size_t, + c_void_p, + c_void_p, + c_void_p, + c_void_p, + c_void_p, + c_void_p, + ] + + lib.infiniopDestroyBilinearDescriptor.restype = c_int32 + lib.infiniopDestroyBilinearDescriptor.argtypes = [ + infiniopOperatorDescriptor_t, + ] From 7349c80f5af47bbf021d91a285ead217d8bfce84 Mon Sep 17 00:00:00 2001 From: Kinorw Date: Thu, 27 Nov 2025 13:58:00 +0800 Subject: [PATCH 03/32] =?UTF-8?q?=E5=88=A0=E9=99=A4=E8=B0=83=E8=AF=95?= =?UTF-8?q?=E8=AF=AD=E5=8F=A5=EF=BC=8C=E6=B7=BB=E5=8A=A0=E5=A4=96=E9=83=A8?= =?UTF-8?q?=E6=8E=A5=E5=8F=A3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- include/infinicore/ops/bilinear.hpp | 16 ++++++ python/infinicore/__init__.py | 1 + python/infinicore/ops/bilinear.py | 28 ++++++++++ src/infinicore/ops/bilinear/bilinear.cc | 27 +++++++++ .../ops/bilinear/bilinear_infiniop.cc | 55 +++++++++++++++++++ src/infinicore/pybind11/ops.hpp | 1 + src/infinicore/pybind11/ops/bilinear.hpp | 29 ++++++++++ src/infiniop/ops/bilinear/bilinear.h | 38 +++++++++++++ src/infiniop/ops/bilinear/operator.cc | 8 --- test/infiniop/bilinear.py | 6 +- 10 files changed, 196 insertions(+), 13 deletions(-) create mode 100644 include/infinicore/ops/bilinear.hpp create mode 100644 python/infinicore/ops/bilinear.py create mode 100644 src/infinicore/ops/bilinear/bilinear.cc create mode 100644 src/infinicore/ops/bilinear/bilinear_infiniop.cc create mode 100644 src/infinicore/pybind11/ops/bilinear.hpp create mode 100644 src/infiniop/ops/bilinear/bilinear.h diff --git a/include/infinicore/ops/bilinear.hpp b/include/infinicore/ops/bilinear.hpp new file mode 100644 index 000000000..46c1de894 --- /dev/null +++ b/include/infinicore/ops/bilinear.hpp @@ -0,0 +1,16 @@ +#pragma once + +#include "../device.hpp" +#include "common/op.hpp" + +namespace infinicore::op { +class Bilinear { +public: + using schema = void (*)(Tensor, Tensor, Tensor, Tensor, Tensor); + static void execute(Tensor out, Tensor x1, Tensor x2, Tensor weight, Tensor bias); + static common::OpDispatcher &dispatcher(); +}; + +Tensor bilinear(Tensor x1, Tensor x2, Tensor weight, Tensor bias); +void bilinear_(Tensor out, Tensor x1, Tensor x2, Tensor weight, Tensor bias); +} // namespace infinicore::op \ No newline at end of file diff --git a/python/infinicore/__init__.py b/python/infinicore/__init__.py index 06294bf3e..efb529a84 100644 --- a/python/infinicore/__init__.py +++ b/python/infinicore/__init__.py @@ -99,6 +99,7 @@ # Operations. 
"add", "attention", + "bilinear", "matmul", "mul", "narrow", diff --git a/python/infinicore/ops/bilinear.py b/python/infinicore/ops/bilinear.py new file mode 100644 index 000000000..f9fd489a6 --- /dev/null +++ b/python/infinicore/ops/bilinear.py @@ -0,0 +1,28 @@ +from infinicore.lib import _infinicore +from infinicore.tensor import Tensor + +def bilinear(input1, input2, weight, bias=None, *, out=None): + if out is None: + if bias is None: + return Tensor(_infinicore.bilinear(input1._underlying, + input2._underlying, + weight._underlying)) + else: + return Tensor(_infinicore.bilinear_bias(input1._underlying, + input2._underlying, + weight._underlying, + bias._underlying)) + + if bias is None: + _infinicore.bilinear_(out._underlying, + input1._underlying, + input2._underlying, + weight._underlying) + else: + _infinicore.bilinear_bias_(out._underlying, + input1._underlying, + input2._underlying, + weight._underlying, + bias._underlying) + + return out \ No newline at end of file diff --git a/src/infinicore/ops/bilinear/bilinear.cc b/src/infinicore/ops/bilinear/bilinear.cc new file mode 100644 index 000000000..41a9a14ed --- /dev/null +++ b/src/infinicore/ops/bilinear/bilinear.cc @@ -0,0 +1,27 @@ +#include "infinicore/ops/bilinear.hpp" + +namespace infinicore::op { + +common::OpDispatcher &Bilinear::dispatcher() { + static common::OpDispatcher dispatcher_; + return dispatcher_; +}; + +void Bilinear::execute(Tensor out, Tensor x1, Tensor x2, Tensor weight, Tensor bias) { + dispatcher().lookup(context::getDevice().getType())(out, x1, x2, weight, bias); +} + +Tensor bilinear(Tensor x1, Tensor x2, Tensor weight, Tensor bias) { + size_t batch_size = x1->shape()[0]; + size_t out_features = weight->shape()[0]; + Shape shape = {batch_size, out_features}; + auto out = Tensor::empty(shape, x1->dtype(), x1->device()); + bilinear_(out, x1, x2, weight, bias); + return out; +} + +void bilinear_(Tensor out, Tensor x1, Tensor x2, Tensor weight, Tensor bias) { + Bilinear::execute(out, x1, x2, weight, bias); +} + +} // namespace infinicore::op \ No newline at end of file diff --git a/src/infinicore/ops/bilinear/bilinear_infiniop.cc b/src/infinicore/ops/bilinear/bilinear_infiniop.cc new file mode 100644 index 000000000..1a92e02d4 --- /dev/null +++ b/src/infinicore/ops/bilinear/bilinear_infiniop.cc @@ -0,0 +1,55 @@ +#include "../../utils.hpp" +#include "infinicore/common/hash.hpp" +#include "infinicore/ops/bilinear.hpp" +#include "infinicore/ops/common/cache.hpp" +#include + +namespace infinicore::op::bilinear_impl::infiniop { + +thread_local common::OpCache caches( + 100, // capacity + [](infiniopBilinearDescriptor_t &desc) { + if (desc != nullptr) { + INFINICORE_CHECK_ERROR(infiniopDestroyBilinearDescriptor(desc)); + desc = nullptr; + } + }); + +void calculate(Tensor out, Tensor x1, Tensor x2, Tensor weight, Tensor bias) { + size_t seed = hash_combine(out, x1, x2, weight, bias); + + auto device_type = context::getDevice().getType(); + auto device_index = context::getDevice().getIndex(); + + auto &cache = caches.getCache(device_type, device_index); + + auto desc_opt = cache.get(seed); + infiniopBilinearDescriptor_t desc = nullptr; + + if (!desc_opt) { + INFINICORE_CHECK_ERROR(infiniopCreateBilinearDescriptor( + context::getInfiniopHandle(out->device()), &desc, + out->desc(), x1->desc(), x2->desc(), weight->desc(), + bias.operator->() ? 
bias->desc() : nullptr)); + cache.put(seed, desc); + } else { + desc = *desc_opt; + } + + size_t workspace_size = 0; + INFINICORE_CHECK_ERROR(infiniopGetBilinearWorkspaceSize(desc, &workspace_size)); + std::shared_ptr workspace = context::allocateMemory(workspace_size); + + INFINICORE_CHECK_ERROR(infiniopBilinear( + desc, workspace->data(), workspace_size, + out->data(), x1->data(), x2->data(), + weight->data(), bias.operator->() ? bias->data() : nullptr, + context::getStream())); +} + +static bool registered = []() { + Bilinear::dispatcher().registerAll(&calculate, false); + return true; +}(); + +} // namespace infinicore::op::bilinear_impl::infiniop \ No newline at end of file diff --git a/src/infinicore/pybind11/ops.hpp b/src/infinicore/pybind11/ops.hpp index 978defa17..c7dd6bf0a 100644 --- a/src/infinicore/pybind11/ops.hpp +++ b/src/infinicore/pybind11/ops.hpp @@ -4,6 +4,7 @@ #include "ops/add.hpp" #include "ops/attention.hpp" +#include "ops/bilinear.hpp" #include "ops/causal_softmax.hpp" #include "ops/embedding.hpp" #include "ops/linear.hpp" diff --git a/src/infinicore/pybind11/ops/bilinear.hpp b/src/infinicore/pybind11/ops/bilinear.hpp new file mode 100644 index 000000000..8987f2e12 --- /dev/null +++ b/src/infinicore/pybind11/ops/bilinear.hpp @@ -0,0 +1,29 @@ +#pragma once + +#include + +#include "infinicore/ops/bilinear.hpp" + +namespace py = pybind11; + +namespace infinicore::ops { + +inline void bind_bilinear(py::module &m) { + m.def("bilinear", + &op::bilinear, + py::arg("x1"), + py::arg("x2"), + py::arg("weight"), + py::arg("bias"), + R"doc(Bilinear transformation of two input tensors.)doc"); + m.def("bilinear_", + &op::bilinear_, + py::arg("out"), + py::arg("x1"), + py::arg("x2"), + py::arg("weight"), + py::arg("bias"), + R"doc(In-place bilinear transformation of two input tensors.)doc"); +} + +} // namespace infinicore::ops \ No newline at end of file diff --git a/src/infiniop/ops/bilinear/bilinear.h b/src/infiniop/ops/bilinear/bilinear.h new file mode 100644 index 000000000..3e78007e2 --- /dev/null +++ b/src/infiniop/ops/bilinear/bilinear.h @@ -0,0 +1,38 @@ +#ifndef __BILINEAR_H__ +#define __BILINEAR_H__ + +#include "../../operator.h" + +#define DESCRIPTOR(NAMESPACE) \ + \ + namespace op::bilinear::NAMESPACE { \ + class Descriptor final : public InfiniopDescriptor { \ + struct Opaque; \ + Opaque *_opaque; \ + size_t _workspace_size; \ + \ + Descriptor( \ + Opaque *opaque, \ + size_t workspace_size_, \ + infiniDevice_t device_type, \ + int device_id) \ + : InfiniopDescriptor{device_type, device_id}, \ + _opaque(opaque), \ + _workspace_size(workspace_size_) {} \ + \ + public: \ + ~Descriptor(); \ + \ + size_t workspaceSize() const { return _workspace_size; } \ + \ + static infiniStatus_t create( \ + infiniopHandle_t handle, \ + Descriptor **desc_ptr, \ + infiniopTensorDescriptor_t out_desc, \ + infiniopTensorDescriptor_t x1_desc, \ + infiniopTensorDescriptor_t x2_desc, \ + infiniopTensorDescriptor_t weight_desc, \ + infiniopTensorDescriptor_t bias_desc); \ + \ + } +#endif // __BILINEAR_H__ diff --git a/src/infiniop/ops/bilinear/operator.cc b/src/infiniop/ops/bilinear/operator.cc index 5ec71b460..00f0741cc 100644 --- a/src/infiniop/ops/bilinear/operator.cc +++ b/src/infiniop/ops/bilinear/operator.cc @@ -204,13 +204,8 @@ __C __export infiniStatus_t infiniopBilinear( void *weight_buffer = workspace_ptr + desc->weight_offset; void *imediate_buffer = workspace_ptr + desc->imediate_offset; void *op_workspace = workspace_ptr + desc->op_workspace_offset; - printf("workspace_ptr: %p, 
weight_buffer: %p, imediate_buffer: %p, op_workspace: %p\n", - workspace_ptr, weight_buffer, imediate_buffer, op_workspace); - printf("desc->weight_rearrange_desc: %p, weight_buffer: %p, weight: %p, stream: %p\n", - desc->weight_rearrange_desc, weight_buffer, weight, stream); CHECK_STATUS(infiniopRearrange(desc->weight_rearrange_desc, weight_buffer, weight, stream)); - printf("Bilinear: weight rearranged\n"); CHECK_STATUS(infiniopGemm(desc->imediate_desc, op_workspace, desc->op_workspace_size, @@ -220,7 +215,6 @@ __C __export infiniStatus_t infiniopBilinear( 1.0f, 0.0f, stream)); - printf("Bilinear: first GEMM done\n"); CHECK_STATUS(infiniopGemm(desc->result_desc, op_workspace, desc->op_workspace_size, @@ -230,7 +224,6 @@ __C __export infiniStatus_t infiniopBilinear( 1.0f, 0.0f, stream)); - printf("Bilinear: second GEMM done\n"); if (desc->bias_add_desc) { CHECK_STATUS(infiniopAdd(desc->bias_add_desc, @@ -240,7 +233,6 @@ __C __export infiniStatus_t infiniopBilinear( bias, stream)); } - printf("Bilinear: bias added\n"); return INFINI_STATUS_SUCCESS; } diff --git a/test/infiniop/bilinear.py b/test/infiniop/bilinear.py index 436298bc6..47af60812 100644 --- a/test/infiniop/bilinear.py +++ b/test/infiniop/bilinear.py @@ -141,12 +141,8 @@ def lib_bilinear(): None, ) ) - print(f"workspace size: {workspace_size.value} bytes") - print(f"tensor shapes: x1{tuple(x1.shape)} x2{tuple(x2.shape)} weight{tuple(weight.shape)} bias{tuple(bias_tensor.shape) if bias_tensor else None} out{tuple(out_tensor.shape)}") + lib_bilinear() - print(f"new workspace size: {workspace_size.value} bytes") - print(f"new tensor shapes: x1{tuple(x1.shape)} x2{tuple(x2.shape)} weight{tuple(weight.shape)} bias{tuple(bias_tensor.shape) if bias_tensor else None} out{tuple(out_tensor.shape)}") - print(f"--") atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype) if DEBUG: From 08af55407b5167f802ede4dfa84217da2e21e440 Mon Sep 17 00:00:00 2001 From: Kinorw Date: Thu, 27 Nov 2025 15:19:07 +0800 Subject: [PATCH 04/32] =?UTF-8?q?=E8=AE=BF=E9=97=AE=E8=B6=8A=E7=95=8C?= =?UTF-8?q?=E5=BE=85=E5=A4=84=E7=90=86?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- include/infinicore/ops/bilinear.hpp | 9 +- python/infinicore/ops/bilinear.py | 36 ++- src/infinicore/ops/bilinear/bilinear.cc | 6 +- .../ops/bilinear/bilinear_infiniop.cc | 11 +- src/infinicore/pybind11/ops.hpp | 1 + src/infinicore/pybind11/ops/bilinear.hpp | 57 ++++- src/infiniop/ops/bilinear/operator.cc | 220 ++++++++++++------ test/infinicore/ops/bilinear.py | 14 +- 8 files changed, 234 insertions(+), 120 deletions(-) diff --git a/include/infinicore/ops/bilinear.hpp b/include/infinicore/ops/bilinear.hpp index 46c1de894..9ecc9b622 100644 --- a/include/infinicore/ops/bilinear.hpp +++ b/include/infinicore/ops/bilinear.hpp @@ -2,15 +2,16 @@ #include "../device.hpp" #include "common/op.hpp" +#include namespace infinicore::op { class Bilinear { public: - using schema = void (*)(Tensor, Tensor, Tensor, Tensor, Tensor); - static void execute(Tensor out, Tensor x1, Tensor x2, Tensor weight, Tensor bias); + using schema = void (*)(Tensor, Tensor, Tensor, Tensor, std::optional); + static void execute(Tensor out, Tensor x1, Tensor x2, Tensor weight, std::optional bias); static common::OpDispatcher &dispatcher(); }; -Tensor bilinear(Tensor x1, Tensor x2, Tensor weight, Tensor bias); -void bilinear_(Tensor out, Tensor x1, Tensor x2, Tensor weight, Tensor bias); +Tensor bilinear(Tensor x1, Tensor x2, Tensor weight, std::optional bias); +void bilinear_(Tensor 
out, Tensor x1, Tensor x2, Tensor weight, std::optional bias); } // namespace infinicore::op \ No newline at end of file diff --git a/python/infinicore/ops/bilinear.py b/python/infinicore/ops/bilinear.py index f9fd489a6..887bd6016 100644 --- a/python/infinicore/ops/bilinear.py +++ b/python/infinicore/ops/bilinear.py @@ -3,26 +3,20 @@ def bilinear(input1, input2, weight, bias=None, *, out=None): if out is None: - if bias is None: - return Tensor(_infinicore.bilinear(input1._underlying, - input2._underlying, - weight._underlying)) - else: - return Tensor(_infinicore.bilinear_bias(input1._underlying, - input2._underlying, - weight._underlying, - bias._underlying)) - - if bias is None: - _infinicore.bilinear_(out._underlying, - input1._underlying, - input2._underlying, - weight._underlying) - else: - _infinicore.bilinear_bias_(out._underlying, - input1._underlying, - input2._underlying, - weight._underlying, - bias._underlying) + return Tensor( + _infinicore.bilinear( + input1._underlying, + input2._underlying, + weight._underlying, + bias._underlying if bias is not None else None, + ) + ) + _infinicore.bilinear_( + out._underlying, + input1._underlying, + input2._underlying, + weight._underlying, + bias._underlying if bias is not None else None, + ) return out \ No newline at end of file diff --git a/src/infinicore/ops/bilinear/bilinear.cc b/src/infinicore/ops/bilinear/bilinear.cc index 41a9a14ed..617acd6a1 100644 --- a/src/infinicore/ops/bilinear/bilinear.cc +++ b/src/infinicore/ops/bilinear/bilinear.cc @@ -7,11 +7,11 @@ common::OpDispatcher &Bilinear::dispatcher() { return dispatcher_; }; -void Bilinear::execute(Tensor out, Tensor x1, Tensor x2, Tensor weight, Tensor bias) { +void Bilinear::execute(Tensor out, Tensor x1, Tensor x2, Tensor weight, std::optional bias) { dispatcher().lookup(context::getDevice().getType())(out, x1, x2, weight, bias); } -Tensor bilinear(Tensor x1, Tensor x2, Tensor weight, Tensor bias) { +Tensor bilinear(Tensor x1, Tensor x2, Tensor weight, std::optional bias) { size_t batch_size = x1->shape()[0]; size_t out_features = weight->shape()[0]; Shape shape = {batch_size, out_features}; @@ -20,7 +20,7 @@ Tensor bilinear(Tensor x1, Tensor x2, Tensor weight, Tensor bias) { return out; } -void bilinear_(Tensor out, Tensor x1, Tensor x2, Tensor weight, Tensor bias) { +void bilinear_(Tensor out, Tensor x1, Tensor x2, Tensor weight, std::optional bias) { Bilinear::execute(out, x1, x2, weight, bias); } diff --git a/src/infinicore/ops/bilinear/bilinear_infiniop.cc b/src/infinicore/ops/bilinear/bilinear_infiniop.cc index 1a92e02d4..44497c80c 100644 --- a/src/infinicore/ops/bilinear/bilinear_infiniop.cc +++ b/src/infinicore/ops/bilinear/bilinear_infiniop.cc @@ -15,8 +15,11 @@ thread_local common::OpCache caches( } }); -void calculate(Tensor out, Tensor x1, Tensor x2, Tensor weight, Tensor bias) { - size_t seed = hash_combine(out, x1, x2, weight, bias); +void calculate(Tensor out, Tensor x1, Tensor x2, Tensor weight, std::optional bias) { + size_t seed = hash_combine(out, x1, x2, weight); + if (bias) { + seed = hash_combine(out, x1, x2, weight,*bias); + } auto device_type = context::getDevice().getType(); auto device_index = context::getDevice().getIndex(); @@ -30,7 +33,7 @@ void calculate(Tensor out, Tensor x1, Tensor x2, Tensor weight, Tensor bias) { INFINICORE_CHECK_ERROR(infiniopCreateBilinearDescriptor( context::getInfiniopHandle(out->device()), &desc, out->desc(), x1->desc(), x2->desc(), weight->desc(), - bias.operator->() ? bias->desc() : nullptr)); + bias ? 
(*bias)->desc() : nullptr)); cache.put(seed, desc); } else { desc = *desc_opt; @@ -43,7 +46,7 @@ void calculate(Tensor out, Tensor x1, Tensor x2, Tensor weight, Tensor bias) { INFINICORE_CHECK_ERROR(infiniopBilinear( desc, workspace->data(), workspace_size, out->data(), x1->data(), x2->data(), - weight->data(), bias.operator->() ? bias->data() : nullptr, + weight->data(), bias ? (*bias)->data() : nullptr, context::getStream())); } diff --git a/src/infinicore/pybind11/ops.hpp b/src/infinicore/pybind11/ops.hpp index c7dd6bf0a..dfbd16434 100644 --- a/src/infinicore/pybind11/ops.hpp +++ b/src/infinicore/pybind11/ops.hpp @@ -24,6 +24,7 @@ namespace infinicore::ops { inline void bind(py::module &m) { bind_add(m); bind_attention(m); + bind_bilinear(m); bind_causal_softmax(m); bind_random_sample(m); bind_linear(m); diff --git a/src/infinicore/pybind11/ops/bilinear.hpp b/src/infinicore/pybind11/ops/bilinear.hpp index 8987f2e12..08d0447f4 100644 --- a/src/infinicore/pybind11/ops/bilinear.hpp +++ b/src/infinicore/pybind11/ops/bilinear.hpp @@ -8,22 +8,55 @@ namespace py = pybind11; namespace infinicore::ops { +Tensor py_bilinear(Tensor x1, Tensor x2, Tensor weight, pybind11::object bias) { + std::optional bias_tensor = std::nullopt; + if (!bias.is_none()) { + bias_tensor = bias.cast(); + } + return op::bilinear(x1, x2, weight, bias_tensor); +} + +void py_bilinear_(Tensor out, Tensor x1, Tensor x2, Tensor weight, pybind11::object bias) { + std::optional bias_tensor = std::nullopt; + if (!bias.is_none()) { + bias_tensor = bias.cast(); + } + op::bilinear_(out, x1, x2, weight, bias_tensor); +} + inline void bind_bilinear(py::module &m) { m.def("bilinear", - &op::bilinear, - py::arg("x1"), - py::arg("x2"), - py::arg("weight"), - py::arg("bias"), - R"doc(Bilinear transformation of two input tensors.)doc"); - m.def("bilinear_", - &op::bilinear_, - py::arg("out"), + &py_bilinear, py::arg("x1"), py::arg("x2"), py::arg("weight"), py::arg("bias"), - R"doc(In-place bilinear transformation of two input tensors.)doc"); -} + R"doc(Bilinear transformation of two input tensors. +Args: + x1: First input tensor + x2: Second input tensor + weight: Weight tensor + bias: Bias tensor (optional) +Returns: + Output tensor after bilinear transformation +)doc"); + + m.def("bilinear_", + &py_bilinear_, + py::arg("out"), + py::arg("x1"), + py::arg("x2"), + py::arg("weight"), + py::arg("bias"), + R"doc(In-place bilinear transformation of two input tensors. 
+Args: + out: Output tensor + x1: First input tensor + x2: Second input tensor + weight: Weight tensor + bias: Bias tensor (optional) +)doc"); + +} -} // namespace infinicore::ops \ No newline at end of file +} // namespace infinicore::ops \ No newline at end of file diff --git a/src/infiniop/ops/bilinear/operator.cc b/src/infiniop/ops/bilinear/operator.cc index 00f0741cc..01a5f9490 100644 --- a/src/infiniop/ops/bilinear/operator.cc +++ b/src/infiniop/ops/bilinear/operator.cc @@ -12,15 +12,17 @@ struct InfiniopBilinearDescriptor { InfiniopDescriptor _super; - infiniopGemmDescriptor_t imediate_desc; - infiniopGemmDescriptor_t result_desc; + infiniopGemmDescriptor_t gemm_1_desc; + infiniopGemmDescriptor_t gemm_2_desc; infiniopRearrangeDescriptor_t weight_rearrange_desc; + infiniopRearrangeDescriptor_t x1_rearrange_desc; + infiniopRearrangeDescriptor_t x2_rearrange_desc; infiniopAddDescriptor_t bias_add_desc; size_t workspace_size; size_t weight_offset; - size_t weight_size; size_t imediate_offset; - size_t imediate_size; + size_t x1_cont_offset; + size_t x2_cont_offset; size_t op_workspace_offset; size_t op_workspace_size; }; @@ -71,63 +73,110 @@ __C __export infiniStatus_t infiniopCreateBilinearDescriptor( size_t dtype_size = infiniSizeOf(dtype); - // Prepare packed weight layout to allow reshape into GEMM matrix - size_t weight_shape[3] = {out_features, in1_features, in2_features}; - ptrdiff_t packed_strides[3] = { - static_cast(in2_features), - static_cast(out_features * in2_features), - static_cast(1), - }; + // --- Weight Rearrangement --- + // Target: [in1, out * in2] (flattened from [in1, out, in2]) + // Source: [out, in1, in2] + + // 1. Create descriptor for contiguous target weight [in1, out, in2] + size_t weight_dst_shape[3] = {in1_features, out_features, in2_features}; + infiniopTensorDescriptor_t weight_dst_desc; + CHECK_STATUS(infiniopCreateTensorDescriptor(&weight_dst_desc, 3, weight_dst_shape, nullptr, dtype)); + + // 2. Create descriptor for source weight viewed as [in1, out, in2] (permuted) + infiniopTensorDescriptor_t weight_src_permuted; + CHECK_STATUS(infiniopCreateTensorDescriptor(&weight_src_permuted, 3, weight_desc->shape().data(), weight_desc->strides().data(), dtype)); + TRANSFORM_TENSOR_DESC(weight_src_permuted, dimPermute({1, 0, 2})); + + // 3. 
Create Rearrange descriptor + infiniopRearrangeDescriptor_t weight_rearrange_desc; + CHECK_STATUS(infiniopCreateRearrangeDescriptor(handle, &weight_rearrange_desc, weight_dst_desc, weight_src_permuted)); + CHECK_STATUS(infiniopDestroyTensorDescriptor(weight_src_permuted)); - infiniopTensorDescriptor_t weight_packed_desc; - CHECK_STATUS(infiniopCreateTensorDescriptor(&weight_packed_desc, 3, weight_shape, packed_strides, dtype)); + size_t weight_bytes = in1_features * out_features * in2_features * dtype_size; - infiniopRearrangeDescriptor_t weight_rearrange_desc; - CHECK_STATUS(infiniopCreateRearrangeDescriptor(handle, &weight_rearrange_desc, weight_packed_desc, weight_desc)); - CHECK_STATUS(infiniopDestroyTensorDescriptor(weight_packed_desc)); + // --- GEMM 1: x1 @ weight_matrix --- + // x1: [batch, in1] + // weight_matrix: [in1, out * in2] + // result (imediate): [batch, out * in2] + // Prepare weight matrix descriptor (view of weight_dst_desc) infiniopTensorDescriptor_t weight_matrix_desc; - CHECK_STATUS(infiniopCreateTensorDescriptor(&weight_matrix_desc, 3, weight_shape, packed_strides, dtype)); - TRANSFORM_TENSOR_DESC(weight_matrix_desc, dimPermute({1, 0, 2})); + CHECK_STATUS(infiniopCreateTensorDescriptor(&weight_matrix_desc, 3, weight_dst_shape, nullptr, dtype)); TRANSFORM_TENSOR_DESC(weight_matrix_desc, dimMerge(1, 2)); + CHECK_STATUS(infiniopDestroyTensorDescriptor(weight_dst_desc)); // Done with weight_dst_desc + + // Prepare x1 descriptor (handle non-contiguous) + infiniopTensorDescriptor_t x1_gemm_desc = x1_desc; + infiniopRearrangeDescriptor_t x1_rearrange_desc = nullptr; + size_t x1_cont_bytes = 0; + + if (x1_desc->strides()[x1_desc->ndim() - 1] != 1) { + infiniopTensorDescriptor_t x1_cont_desc; + CHECK_STATUS(infiniopCreateTensorDescriptor(&x1_cont_desc, 2, x1_desc->shape().data(), nullptr, dtype)); + CHECK_STATUS(infiniopCreateRearrangeDescriptor(handle, &x1_rearrange_desc, x1_cont_desc, x1_desc)); + x1_gemm_desc = x1_cont_desc; // Use the contiguous descriptor for GEMM creation + x1_cont_bytes = batch_size * in1_features * dtype_size; + } + // Prepare imediate descriptor size_t imediate_shape[2] = {batch_size, out_features * in2_features}; infiniopTensorDescriptor_t imediate_flat_desc; CHECK_STATUS(infiniopCreateTensorDescriptor(&imediate_flat_desc, 2, imediate_shape, nullptr, dtype)); - infiniopGemmDescriptor_t imediate_desc; - CHECK_STATUS(infiniopCreateGemmDescriptor(handle, &imediate_desc, imediate_flat_desc, x1_desc, weight_matrix_desc)); + infiniopGemmDescriptor_t gemm_1_desc; + CHECK_STATUS(infiniopCreateGemmDescriptor(handle, &gemm_1_desc, imediate_flat_desc, x1_gemm_desc, weight_matrix_desc)); CHECK_STATUS(infiniopDestroyTensorDescriptor(weight_matrix_desc)); + if (x1_rearrange_desc) { + CHECK_STATUS(infiniopDestroyTensorDescriptor(x1_gemm_desc)); + } + + size_t imediate_bytes = batch_size * out_features * in2_features * dtype_size; + // --- GEMM 2: imediate @ x2^T --- + // We perform Batch GEMM: + // A (imediate): [batch, out, in2] + // B (x2): [batch, in2, 1] + // C (out): [batch, out, 1] + + // Prepare A: imediate viewed as [batch, out, in2] infiniopTensorDescriptor_t imediate_split_desc; CHECK_STATUS(infiniopCreateTensorDescriptor(&imediate_split_desc, 2, imediate_shape, nullptr, dtype)); TRANSFORM_TENSOR_DESC(imediate_split_desc, dimSplit(1, {out_features, in2_features})); + CHECK_STATUS(infiniopDestroyTensorDescriptor(imediate_flat_desc)); // Done with flat + + // Prepare B: x2 viewed as [batch, in2, 1] (handle non-contiguous) + infiniopTensorDescriptor_t 
x2_gemm_desc = x2_desc; + infiniopRearrangeDescriptor_t x2_rearrange_desc = nullptr; + size_t x2_cont_bytes = 0; + + if (x2_desc->strides()[x2_desc->ndim() - 1] != 1) { + infiniopTensorDescriptor_t x2_cont_desc; + CHECK_STATUS(infiniopCreateTensorDescriptor(&x2_cont_desc, 2, x2_desc->shape().data(), nullptr, dtype)); + CHECK_STATUS(infiniopCreateRearrangeDescriptor(handle, &x2_rearrange_desc, x2_cont_desc, x2_desc)); + x2_gemm_desc = x2_cont_desc; + x2_cont_bytes = batch_size * in2_features * dtype_size; + } - auto x2_shape = x2_desc->shape(); - auto x2_strides = x2_desc->strides(); infiniopTensorDescriptor_t x2_col_desc; - CHECK_STATUS(infiniopCreateTensorDescriptor(&x2_col_desc, 2, x2_shape.data(), x2_strides.data(), dtype)); + CHECK_STATUS(infiniopCreateTensorDescriptor(&x2_col_desc, 2, x2_gemm_desc->shape().data(), x2_gemm_desc->strides().data(), dtype)); TRANSFORM_TENSOR_DESC(x2_col_desc, dimSplit(1, {in2_features, 1})); + if (x2_rearrange_desc) { + CHECK_STATUS(infiniopDestroyTensorDescriptor(x2_gemm_desc)); + } - auto out_shape = out_desc->shape(); - auto out_strides = out_desc->strides(); + // Prepare C: out viewed as [batch, out, 1] infiniopTensorDescriptor_t out_col_desc; - CHECK_STATUS(infiniopCreateTensorDescriptor(&out_col_desc, 2, out_shape.data(), out_strides.data(), dtype)); + CHECK_STATUS(infiniopCreateTensorDescriptor(&out_col_desc, 2, out_desc->shape().data(), out_desc->strides().data(), dtype)); TRANSFORM_TENSOR_DESC(out_col_desc, dimSplit(1, {out_features, 1})); - infiniopGemmDescriptor_t result_desc; - CHECK_STATUS(infiniopCreateGemmDescriptor(handle, &result_desc, out_col_desc, imediate_split_desc, x2_col_desc)); - - size_t gemm1_workspace_size = 0; - size_t gemm2_workspace_size = 0; - CHECK_STATUS(infiniopGetGemmWorkspaceSize(imediate_desc, &gemm1_workspace_size)); - CHECK_STATUS(infiniopGetGemmWorkspaceSize(result_desc, &gemm2_workspace_size)); - - CHECK_STATUS(infiniopDestroyTensorDescriptor(imediate_flat_desc)); + infiniopGemmDescriptor_t gemm_2_desc; + CHECK_STATUS(infiniopCreateGemmDescriptor(handle, &gemm_2_desc, out_col_desc, imediate_split_desc, x2_col_desc)); + CHECK_STATUS(infiniopDestroyTensorDescriptor(imediate_split_desc)); CHECK_STATUS(infiniopDestroyTensorDescriptor(x2_col_desc)); CHECK_STATUS(infiniopDestroyTensorDescriptor(out_col_desc)); + // --- Bias Add --- infiniopAddDescriptor_t bias_add_desc = nullptr; size_t add_workspace_size = 0; if (bias_desc) { @@ -140,29 +189,46 @@ __C __export infiniStatus_t infiniopCreateBilinearDescriptor( CHECK_STATUS(infiniopDestroyTensorDescriptor(bias_broadcast_desc)); } - size_t weight_size = aligned_size(out_features * in1_features * in2_features * dtype_size); - size_t imediate_size = aligned_size(batch_size * out_features * in2_features * dtype_size); + // --- Workspace Calculation --- + size_t gemm1_workspace_size = 0; + size_t gemm2_workspace_size = 0; + CHECK_STATUS(infiniopGetGemmWorkspaceSize(gemm_1_desc, &gemm1_workspace_size)); + CHECK_STATUS(infiniopGetGemmWorkspaceSize(gemm_2_desc, &gemm2_workspace_size)); - size_t op_workspace_size = std::max(gemm1_workspace_size, gemm2_workspace_size); - op_workspace_size = std::max(op_workspace_size, add_workspace_size); + size_t op_workspace_size = std::max({gemm1_workspace_size, gemm2_workspace_size, add_workspace_size}); op_workspace_size = aligned_size(op_workspace_size); - size_t weight_offset = 0; - size_t imediate_offset = weight_offset + weight_size; - size_t op_workspace_offset = imediate_offset + imediate_size; - size_t workspace_size = 
op_workspace_offset + op_workspace_size; + size_t workspace_cursor = 0; + auto reserve_buffer = [&](size_t bytes) -> size_t { + if (bytes == 0) { + return 0; + } + workspace_cursor = aligned_size(workspace_cursor); + size_t offset = workspace_cursor; + workspace_cursor += bytes; + return offset; + }; + + size_t weight_offset = reserve_buffer(weight_bytes); + size_t imediate_offset = reserve_buffer(imediate_bytes); + size_t x1_cont_offset = x1_rearrange_desc ? reserve_buffer(x1_cont_bytes) : 0; + size_t x2_cont_offset = x2_rearrange_desc ? reserve_buffer(x2_cont_bytes) : 0; + size_t op_workspace_offset = reserve_buffer(op_workspace_size); + size_t workspace_size = aligned_size(workspace_cursor); *(InfiniopBilinearDescriptor **)desc_ptr = new InfiniopBilinearDescriptor{ {handle->device, handle->device_id}, - imediate_desc, - result_desc, + gemm_1_desc, + gemm_2_desc, weight_rearrange_desc, + x1_rearrange_desc, + x2_rearrange_desc, bias_add_desc, workspace_size, weight_offset, - weight_size, imediate_offset, - imediate_size, + x1_cont_offset, + x2_cont_offset, op_workspace_offset, op_workspace_size}; @@ -205,33 +271,37 @@ __C __export infiniStatus_t infiniopBilinear( void *imediate_buffer = workspace_ptr + desc->imediate_offset; void *op_workspace = workspace_ptr + desc->op_workspace_offset; + // 1. Rearrange Weight CHECK_STATUS(infiniopRearrange(desc->weight_rearrange_desc, weight_buffer, weight, stream)); - CHECK_STATUS(infiniopGemm(desc->imediate_desc, - op_workspace, desc->op_workspace_size, - imediate_buffer, - x1, - weight_buffer, - 1.0f, - 0.0f, - stream)); - - CHECK_STATUS(infiniopGemm(desc->result_desc, - op_workspace, desc->op_workspace_size, - out, - imediate_buffer, - x2, - 1.0f, - 0.0f, - stream)); + // 2. Prepare x1 + void const *x1_ptr = x1; + if (desc->x1_rearrange_desc) { + void *x1_buffer = workspace_ptr + desc->x1_cont_offset; + CHECK_STATUS(infiniopRearrange(desc->x1_rearrange_desc, x1_buffer, x1, stream)); + x1_ptr = x1_buffer; + } + + // 3. GEMM 1: x1 @ weight -> imediate + CHECK_STATUS(infiniopGemm(desc->gemm_1_desc, op_workspace, desc->op_workspace_size, + imediate_buffer, x1_ptr, weight_buffer, 1.0f, 0.0f, stream)); + + // 4. Prepare x2 + void const *x2_ptr = x2; + if (desc->x2_rearrange_desc) { + void *x2_buffer = workspace_ptr + desc->x2_cont_offset; + CHECK_STATUS(infiniopRearrange(desc->x2_rearrange_desc, x2_buffer, x2, stream)); + x2_ptr = x2_buffer; + } + + // 5. GEMM 2: imediate @ x2 -> out + CHECK_STATUS(infiniopGemm(desc->gemm_2_desc, op_workspace, desc->op_workspace_size, + out, imediate_buffer, x2_ptr, 1.0f, 0.0f, stream)); + // 6. 
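
(Reference sketch, not part of the patch: the reserve_buffer lambda above packs the rearranged weight, the intermediate, the optional contiguous copies of x1/x2 and the sub-operator workspace into one allocation at 256-byte-aligned offsets. A plain-Python version of the same bookkeeping, with made-up sizes:)

ALIGN = 256

def aligned(n):
    # round n up to the next multiple of ALIGN
    return (n + ALIGN - 1) // ALIGN * ALIGN

def layout(sizes):
    cursor, offsets = 0, []
    for nbytes in sizes:
        if nbytes == 0:          # unused buffer (e.g. x1 already contiguous)
            offsets.append(0)
            continue
        cursor = aligned(cursor)
        offsets.append(cursor)
        cursor += nbytes
    return offsets, aligned(cursor)

# weight, intermediate, x1 copy (not needed), op workspace
print(layout([1000, 4096, 0, 512]))   # ([0, 1024, 0, 5120], 5632)
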
Bias Add if (desc->bias_add_desc) { - CHECK_STATUS(infiniopAdd(desc->bias_add_desc, - op_workspace, desc->op_workspace_size, - out, - out, - bias, - stream)); + CHECK_STATUS(infiniopAdd(desc->bias_add_desc, op_workspace, desc->op_workspace_size, + out, out, bias, stream)); } return INFINI_STATUS_SUCCESS; @@ -241,9 +311,15 @@ __C __export infiniStatus_t infiniopDestroyBilinearDescriptor( infiniopBilinearDescriptor_t desc_) { auto desc = reinterpret_cast(desc_); - CHECK_STATUS(infiniopDestroyGemmDescriptor(desc->imediate_desc)); - CHECK_STATUS(infiniopDestroyGemmDescriptor(desc->result_desc)); + CHECK_STATUS(infiniopDestroyGemmDescriptor(desc->gemm_1_desc)); + CHECK_STATUS(infiniopDestroyGemmDescriptor(desc->gemm_2_desc)); CHECK_STATUS(infiniopDestroyRearrangeDescriptor(desc->weight_rearrange_desc)); + if (desc->x1_rearrange_desc) { + CHECK_STATUS(infiniopDestroyRearrangeDescriptor(desc->x1_rearrange_desc)); + } + if (desc->x2_rearrange_desc) { + CHECK_STATUS(infiniopDestroyRearrangeDescriptor(desc->x2_rearrange_desc)); + } if (desc->bias_add_desc) { CHECK_STATUS(infiniopDestroyAddDescriptor(desc->bias_add_desc)); } diff --git a/test/infinicore/ops/bilinear.py b/test/infinicore/ops/bilinear.py index d8eba5111..b4f5ae972 100644 --- a/test/infinicore/ops/bilinear.py +++ b/test/infinicore/ops/bilinear.py @@ -44,12 +44,18 @@ def parse_test_cases(): in1 = TensorSpec.from_tensor(in1_shape, in1_strides, dtype) in2 = TensorSpec.from_tensor(in2_shape, in2_strides, dtype) weight = TensorSpec.from_tensor(weight_shape, weight_strides, dtype) + + inputs = [in1, in2, weight] + if bias_present: + bias_shape = (weight_shape[0],) + bias = TensorSpec.from_tensor(bias_shape, None, dtype) + inputs.append(bias) kwargs = {} test_cases.append( TestCase( - inputs=[in1, in2, weight], + inputs=inputs, kwargs=kwargs, output_spec=None, comparison_target=None, @@ -73,9 +79,9 @@ def get_test_cases(self): def torch_operator(self, *args, **kwargs): return torch.nn.functional.bilinear(*args, **kwargs) - # def infinicore_operator(self, *args, **kwargs): - # """InfiniCore implementation (operator not yet available).""" - # return infinicore.nn.functional.bilinear(*args, **kwargs) + def infinicore_operator(self, *args, **kwargs): + from infinicore.ops.bilinear import bilinear + return bilinear(*args, **kwargs) def main(): From cd5768d1c591648eec52bfba45bd7461f711893e Mon Sep 17 00:00:00 2001 From: Kinorw Date: Thu, 27 Nov 2025 18:35:00 +0800 Subject: [PATCH 05/32] =?UTF-8?q?=E4=BF=AE=E6=94=B9=E5=AE=9E=E7=8E=B0?= =?UTF-8?q?=E6=96=B9=E6=B3=95=EF=BC=8C=E4=BB=8ECore/ops=E7=BA=A7=E5=88=AB?= =?UTF-8?q?=E5=AE=8C=E6=88=90=E5=AE=9E=E7=8E=B0=EF=BC=8C=E4=B8=8D=E5=88=9B?= =?UTF-8?q?=E5=BB=BA=E5=8D=95=E7=8B=AC=E7=9A=84InfiniOp?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- include/infinicore/ops/bilinear.hpp | 7 +- src/infinicore/ops/bilinear/bilinear.cc | 48 ++++++-- .../ops/bilinear/bilinear_infiniop.cc | 116 +++++++++--------- 3 files changed, 95 insertions(+), 76 deletions(-) diff --git a/include/infinicore/ops/bilinear.hpp b/include/infinicore/ops/bilinear.hpp index 9ecc9b622..3f5f44aac 100644 --- a/include/infinicore/ops/bilinear.hpp +++ b/include/infinicore/ops/bilinear.hpp @@ -5,13 +5,8 @@ #include namespace infinicore::op { -class Bilinear { -public: - using schema = void (*)(Tensor, Tensor, Tensor, Tensor, std::optional); - static void execute(Tensor out, Tensor x1, Tensor x2, Tensor weight, std::optional bias); - static common::OpDispatcher &dispatcher(); -}; Tensor 
bilinear(Tensor x1, Tensor x2, Tensor weight, std::optional bias); void bilinear_(Tensor out, Tensor x1, Tensor x2, Tensor weight, std::optional bias); + } // namespace infinicore::op \ No newline at end of file diff --git a/src/infinicore/ops/bilinear/bilinear.cc b/src/infinicore/ops/bilinear/bilinear.cc index 617acd6a1..97ff35ac7 100644 --- a/src/infinicore/ops/bilinear/bilinear.cc +++ b/src/infinicore/ops/bilinear/bilinear.cc @@ -1,27 +1,51 @@ #include "infinicore/ops/bilinear.hpp" +#include "infinicore/ops/gemm.hpp" +#include "infinicore/ops/add.hpp" +#include "infinicore/ops/rearrange.hpp" namespace infinicore::op { -common::OpDispatcher &Bilinear::dispatcher() { - static common::OpDispatcher dispatcher_; - return dispatcher_; -}; - -void Bilinear::execute(Tensor out, Tensor x1, Tensor x2, Tensor weight, std::optional bias) { - dispatcher().lookup(context::getDevice().getType())(out, x1, x2, weight, bias); -} Tensor bilinear(Tensor x1, Tensor x2, Tensor weight, std::optional bias) { + size_t batch_size = x1->shape()[0]; + size_t in1_features = x1->shape()[1]; + size_t in2_features = x2->shape()[1]; size_t out_features = weight->shape()[0]; - Shape shape = {batch_size, out_features}; - auto out = Tensor::empty(shape, x1->dtype(), x1->device()); - bilinear_(out, x1, x2, weight, bias); + + auto dtype = x1->dtype(); + auto device = x1->device(); + + Tensor x1_cont = x1->is_contiguous() ? x1 : x1->contiguous(); + Tensor x2_cont = x2->is_contiguous() ? x2 : x2->contiguous(); + Tensor weight_cont = weight->is_contiguous() ? weight : weight->contiguous(); + + Tensor weight_permuted = weight_cont->permute({1, 0, 2}); + Tensor weight_permuted_cont = weight_permuted->contiguous(); + Tensor weight_matrix = weight_permuted_cont->view({in1_features, out_features * in2_features}); + + Tensor intermediate = gemm(x1_cont, weight_matrix); + + Tensor intermediate_3d = intermediate->view({batch_size, out_features, in2_features}); + + Tensor x2_col = x2_cont->view({batch_size, in2_features, 1}); + + Tensor out_3d = gemm(intermediate_3d, x2_col); + + Tensor out = out_3d->view({batch_size, out_features}); + + if (bias) { + Tensor bias_broadcast = (*bias)->as_strided({batch_size, out_features}, {0, (*bias)->strides()[0]}); + out = add(out, bias_broadcast); + } + return out; } void bilinear_(Tensor out, Tensor x1, Tensor x2, Tensor weight, std::optional bias) { - Bilinear::execute(out, x1, x2, weight, bias); + Tensor result = bilinear(x1, x2, weight, bias); + // Copy result to out + rearrange_(out, result); } } // namespace infinicore::op \ No newline at end of file diff --git a/src/infinicore/ops/bilinear/bilinear_infiniop.cc b/src/infinicore/ops/bilinear/bilinear_infiniop.cc index 44497c80c..b63f5f164 100644 --- a/src/infinicore/ops/bilinear/bilinear_infiniop.cc +++ b/src/infinicore/ops/bilinear/bilinear_infiniop.cc @@ -1,58 +1,58 @@ -#include "../../utils.hpp" -#include "infinicore/common/hash.hpp" -#include "infinicore/ops/bilinear.hpp" -#include "infinicore/ops/common/cache.hpp" -#include - -namespace infinicore::op::bilinear_impl::infiniop { - -thread_local common::OpCache caches( - 100, // capacity - [](infiniopBilinearDescriptor_t &desc) { - if (desc != nullptr) { - INFINICORE_CHECK_ERROR(infiniopDestroyBilinearDescriptor(desc)); - desc = nullptr; - } - }); - -void calculate(Tensor out, Tensor x1, Tensor x2, Tensor weight, std::optional bias) { - size_t seed = hash_combine(out, x1, x2, weight); - if (bias) { - seed = hash_combine(out, x1, x2, weight,*bias); - } - - auto device_type = 
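
(Reference sketch, not part of the patch: the bias handling in the ops-level bilinear above broadcasts a [out_features] vector to [batch, out_features] by giving it a zero batch stride instead of copying it. A NumPy illustration of the same trick, with np.lib.stride_tricks.as_strided standing in for the as_strided call used here:)

import numpy as np

batch, out_f = 4, 3
bias = np.arange(out_f, dtype=np.float32)

# stride 0 along the batch axis makes every row alias the same out_f values
bias_b = np.lib.stride_tricks.as_strided(
    bias, shape=(batch, out_f), strides=(0, bias.strides[0]))

assert np.array_equal(bias_b, np.tile(bias, (batch, 1)))
print(bias_b.shape)   # (4, 3)
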
context::getDevice().getType(); - auto device_index = context::getDevice().getIndex(); - - auto &cache = caches.getCache(device_type, device_index); - - auto desc_opt = cache.get(seed); - infiniopBilinearDescriptor_t desc = nullptr; - - if (!desc_opt) { - INFINICORE_CHECK_ERROR(infiniopCreateBilinearDescriptor( - context::getInfiniopHandle(out->device()), &desc, - out->desc(), x1->desc(), x2->desc(), weight->desc(), - bias ? (*bias)->desc() : nullptr)); - cache.put(seed, desc); - } else { - desc = *desc_opt; - } - - size_t workspace_size = 0; - INFINICORE_CHECK_ERROR(infiniopGetBilinearWorkspaceSize(desc, &workspace_size)); - std::shared_ptr workspace = context::allocateMemory(workspace_size); - - INFINICORE_CHECK_ERROR(infiniopBilinear( - desc, workspace->data(), workspace_size, - out->data(), x1->data(), x2->data(), - weight->data(), bias ? (*bias)->data() : nullptr, - context::getStream())); -} - -static bool registered = []() { - Bilinear::dispatcher().registerAll(&calculate, false); - return true; -}(); - -} // namespace infinicore::op::bilinear_impl::infiniop \ No newline at end of file +// #include "../../utils.hpp" +// #include "infinicore/common/hash.hpp" +// #include "infinicore/ops/bilinear.hpp" +// #include "infinicore/ops/common/cache.hpp" +// #include + +// namespace infinicore::op::bilinear_impl::infiniop { + +// thread_local common::OpCache caches( +// 100, // capacity +// [](infiniopBilinearDescriptor_t &desc) { +// if (desc != nullptr) { +// INFINICORE_CHECK_ERROR(infiniopDestroyBilinearDescriptor(desc)); +// desc = nullptr; +// } +// }); + +// void calculate(Tensor out, Tensor x1, Tensor x2, Tensor weight, std::optional bias) { +// size_t seed = hash_combine(out, x1, x2, weight); +// if (bias) { +// seed = hash_combine(out, x1, x2, weight,*bias); +// } + +// auto device_type = context::getDevice().getType(); +// auto device_index = context::getDevice().getIndex(); + +// auto &cache = caches.getCache(device_type, device_index); + +// auto desc_opt = cache.get(seed); +// infiniopBilinearDescriptor_t desc = nullptr; + +// if (!desc_opt) { +// INFINICORE_CHECK_ERROR(infiniopCreateBilinearDescriptor( +// context::getInfiniopHandle(out->device()), &desc, +// out->desc(), x1->desc(), x2->desc(), weight->desc(), +// bias ? (*bias)->desc() : nullptr)); +// cache.put(seed, desc); +// } else { +// desc = *desc_opt; +// } + +// size_t workspace_size = 0; +// INFINICORE_CHECK_ERROR(infiniopGetBilinearWorkspaceSize(desc, &workspace_size)); +// std::shared_ptr workspace = context::allocateMemory(workspace_size); + +// INFINICORE_CHECK_ERROR(infiniopBilinear( +// desc, workspace->data(), workspace_size, +// out->data(), x1->data(), x2->data(), +// weight->data(), bias ? 
(*bias)->data() : nullptr, +// context::getStream())); +// } + +// static bool registered = []() { +// Bilinear::dispatcher().registerAll(&calculate, false); +// return true; +// }(); + +// } // namespace infinicore::op::bilinear_impl::infiniop \ No newline at end of file From 1285badd4eae78996bb3b2dbb090c5d54854477c Mon Sep 17 00:00:00 2001 From: Kinorw Date: Sat, 29 Nov 2025 19:05:19 +0800 Subject: [PATCH 06/32] =?UTF-8?q?fix:=20rearrange=E5=88=86=E9=85=8D?= =?UTF-8?q?=E6=9C=80=E5=B0=8FUnit=E6=9C=AA=E8=80=83=E8=99=91=E5=B8=A6?= =?UTF-8?q?=E6=9C=89strides=E7=9A=84=E6=83=85=E5=86=B5=E8=83=BD=E5=90=A6?= =?UTF-8?q?=E8=A2=AB=E6=95=B4=E9=99=A4?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/infinicore/ops/bilinear/bilinear.cc | 9 +++------ src/utils/rearrange.cc | 22 ++++++++++++++++------ 2 files changed, 19 insertions(+), 12 deletions(-) diff --git a/src/infinicore/ops/bilinear/bilinear.cc b/src/infinicore/ops/bilinear/bilinear.cc index 97ff35ac7..dfe97fa87 100644 --- a/src/infinicore/ops/bilinear/bilinear.cc +++ b/src/infinicore/ops/bilinear/bilinear.cc @@ -1,5 +1,5 @@ #include "infinicore/ops/bilinear.hpp" -#include "infinicore/ops/gemm.hpp" +#include "infinicore/ops/matmul.hpp" #include "infinicore/ops/add.hpp" #include "infinicore/ops/rearrange.hpp" @@ -15,7 +15,6 @@ Tensor bilinear(Tensor x1, Tensor x2, Tensor weight, std::optional bias) auto dtype = x1->dtype(); auto device = x1->device(); - Tensor x1_cont = x1->is_contiguous() ? x1 : x1->contiguous(); Tensor x2_cont = x2->is_contiguous() ? x2 : x2->contiguous(); Tensor weight_cont = weight->is_contiguous() ? weight : weight->contiguous(); @@ -24,21 +23,19 @@ Tensor bilinear(Tensor x1, Tensor x2, Tensor weight, std::optional bias) Tensor weight_permuted_cont = weight_permuted->contiguous(); Tensor weight_matrix = weight_permuted_cont->view({in1_features, out_features * in2_features}); - Tensor intermediate = gemm(x1_cont, weight_matrix); + Tensor intermediate = matmul(x1_cont, weight_matrix); Tensor intermediate_3d = intermediate->view({batch_size, out_features, in2_features}); Tensor x2_col = x2_cont->view({batch_size, in2_features, 1}); - Tensor out_3d = gemm(intermediate_3d, x2_col); - + Tensor out_3d = matmul(intermediate_3d, x2_col); Tensor out = out_3d->view({batch_size, out_features}); if (bias) { Tensor bias_broadcast = (*bias)->as_strided({batch_size, out_features}, {0, (*bias)->strides()[0]}); out = add(out, bias_broadcast); } - return out; } diff --git a/src/utils/rearrange.cc b/src/utils/rearrange.cc index 7465302d0..48c3c4ed9 100644 --- a/src/utils/rearrange.cc +++ b/src/utils/rearrange.cc @@ -144,13 +144,26 @@ void rearrange( utils::Result RearrangeMeta::distributeUnit(const std::vector &candidates) const { // 获取当前的unit大小 size_t current_unit = _meta[0]; + size_t ndim_value = this->ndim(); - // 寻找满足条件的unit值:当前unit能被其整除 + // 寻找满足条件的unit值:当前unit能被其整除,且所有strides也能被其整除 size_t new_unit = 0; for (size_t candidate : candidates) { if (current_unit % candidate == 0) { - new_unit = candidate; - break; + // 检查所有 strides 是否都能被 candidate 整除(确保内存对齐) + bool strides_aligned = true; + for (size_t i = 0; i < ndim_value; ++i) { + ptrdiff_t dst_stride = std::abs(dst_strides()[i]); + ptrdiff_t src_stride = std::abs(src_strides()[i]); + if (dst_stride % candidate != 0 || src_stride % candidate != 0) { + strides_aligned = false; + break; + } + } + if (strides_aligned) { + new_unit = candidate; + break; + } } } @@ -164,9 +177,6 @@ utils::Result RearrangeMeta::distributeUnit(const 
std::vector(_meta); } - // 获取当前维度 - size_t ndim_value = this->ndim(); - // 创建新的布局数组 std::vector layout(2 + (ndim_value + 1) * 3, 0); From 7b7b9e382860f9e7d1f57ebc3f4e8317db596909 Mon Sep 17 00:00:00 2001 From: littleotherut Date: Sun, 30 Nov 2025 15:52:47 +0800 Subject: [PATCH 07/32] =?UTF-8?q?=E4=BF=AE=E6=94=B9gemm=E7=BC=96=E8=AF=91?= =?UTF-8?q?=E9=80=89=E9=A1=B9,=E8=A7=A3=E5=86=B3=E7=B2=BE=E5=BA=A6?= =?UTF-8?q?=E9=97=AE=E9=A2=98(=E5=BD=93=E5=89=8D=E5=AE=9E=E7=8E=B0?= =?UTF-8?q?=E4=B8=8BTF32=20rtol=E4=BB=85=E8=83=BD=E6=94=AF=E6=8C=815e-4?= =?UTF-8?q?=E7=BA=A7=E5=88=AB=E7=B2=BE=E5=BA=A6),=E5=88=A0=E9=99=A4?= =?UTF-8?q?=E6=97=A0=E7=94=A8=E5=86=97=E4=BD=99=E5=86=85=E5=AE=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- sbatch.sh | 3 + .../ops/bilinear/bilinear_infiniop.cc | 58 --- src/infiniop/ops/bilinear/bilinear.h | 38 -- src/infiniop/ops/bilinear/operator.cc | 329 ------------------ src/infiniop/ops/gemm/nvidia/gemm_nvidia.cu | 2 +- 5 files changed, 4 insertions(+), 426 deletions(-) create mode 100755 sbatch.sh delete mode 100644 src/infinicore/ops/bilinear/bilinear_infiniop.cc delete mode 100644 src/infiniop/ops/bilinear/bilinear.h delete mode 100644 src/infiniop/ops/bilinear/operator.cc diff --git a/sbatch.sh b/sbatch.sh new file mode 100755 index 000000000..a1b8f442e --- /dev/null +++ b/sbatch.sh @@ -0,0 +1,3 @@ +srun --partition=nvidia --nodes=1 --gres=gpu:nvidia:2 --ntasks=1 --cpus-per-task=16 --mem=64G --time=00:20:00 \ + --output=output_%j.log \ + python test/infinicore/ops/bilinear.py --nvidia --verbose --bench --debug diff --git a/src/infinicore/ops/bilinear/bilinear_infiniop.cc b/src/infinicore/ops/bilinear/bilinear_infiniop.cc deleted file mode 100644 index b63f5f164..000000000 --- a/src/infinicore/ops/bilinear/bilinear_infiniop.cc +++ /dev/null @@ -1,58 +0,0 @@ -// #include "../../utils.hpp" -// #include "infinicore/common/hash.hpp" -// #include "infinicore/ops/bilinear.hpp" -// #include "infinicore/ops/common/cache.hpp" -// #include - -// namespace infinicore::op::bilinear_impl::infiniop { - -// thread_local common::OpCache caches( -// 100, // capacity -// [](infiniopBilinearDescriptor_t &desc) { -// if (desc != nullptr) { -// INFINICORE_CHECK_ERROR(infiniopDestroyBilinearDescriptor(desc)); -// desc = nullptr; -// } -// }); - -// void calculate(Tensor out, Tensor x1, Tensor x2, Tensor weight, std::optional bias) { -// size_t seed = hash_combine(out, x1, x2, weight); -// if (bias) { -// seed = hash_combine(out, x1, x2, weight,*bias); -// } - -// auto device_type = context::getDevice().getType(); -// auto device_index = context::getDevice().getIndex(); - -// auto &cache = caches.getCache(device_type, device_index); - -// auto desc_opt = cache.get(seed); -// infiniopBilinearDescriptor_t desc = nullptr; - -// if (!desc_opt) { -// INFINICORE_CHECK_ERROR(infiniopCreateBilinearDescriptor( -// context::getInfiniopHandle(out->device()), &desc, -// out->desc(), x1->desc(), x2->desc(), weight->desc(), -// bias ? (*bias)->desc() : nullptr)); -// cache.put(seed, desc); -// } else { -// desc = *desc_opt; -// } - -// size_t workspace_size = 0; -// INFINICORE_CHECK_ERROR(infiniopGetBilinearWorkspaceSize(desc, &workspace_size)); -// std::shared_ptr workspace = context::allocateMemory(workspace_size); - -// INFINICORE_CHECK_ERROR(infiniopBilinear( -// desc, workspace->data(), workspace_size, -// out->data(), x1->data(), x2->data(), -// weight->data(), bias ? 
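
(Reference sketch, not part of the patch: the RearrangeMeta::distributeUnit fix in patch 06 above only accepts a copy unit that divides both the contiguous unit and every source/destination stride, since a vectorized copy of `unit` bytes could otherwise straddle two strided rows. A plain-Python version of that selection rule, with made-up byte counts:)

def pick_unit(current_unit, candidates, dst_strides, src_strides):
    # first candidate that divides the current unit and the absolute value of every stride
    for c in candidates:
        if current_unit % c != 0:
            continue
        if all(abs(s) % c == 0 for s in list(dst_strides) + list(src_strides)):
            return c
    return 0

# 32-byte contiguous unit, but one destination stride is 48 bytes -> fall back to 16
print(pick_unit(32, [32, 16, 8], dst_strides=[48, 32], src_strides=[128, 32]))   # 16
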
(*bias)->data() : nullptr, -// context::getStream())); -// } - -// static bool registered = []() { -// Bilinear::dispatcher().registerAll(&calculate, false); -// return true; -// }(); - -// } // namespace infinicore::op::bilinear_impl::infiniop \ No newline at end of file diff --git a/src/infiniop/ops/bilinear/bilinear.h b/src/infiniop/ops/bilinear/bilinear.h deleted file mode 100644 index 3e78007e2..000000000 --- a/src/infiniop/ops/bilinear/bilinear.h +++ /dev/null @@ -1,38 +0,0 @@ -#ifndef __BILINEAR_H__ -#define __BILINEAR_H__ - -#include "../../operator.h" - -#define DESCRIPTOR(NAMESPACE) \ - \ - namespace op::bilinear::NAMESPACE { \ - class Descriptor final : public InfiniopDescriptor { \ - struct Opaque; \ - Opaque *_opaque; \ - size_t _workspace_size; \ - \ - Descriptor( \ - Opaque *opaque, \ - size_t workspace_size_, \ - infiniDevice_t device_type, \ - int device_id) \ - : InfiniopDescriptor{device_type, device_id}, \ - _opaque(opaque), \ - _workspace_size(workspace_size_) {} \ - \ - public: \ - ~Descriptor(); \ - \ - size_t workspaceSize() const { return _workspace_size; } \ - \ - static infiniStatus_t create( \ - infiniopHandle_t handle, \ - Descriptor **desc_ptr, \ - infiniopTensorDescriptor_t out_desc, \ - infiniopTensorDescriptor_t x1_desc, \ - infiniopTensorDescriptor_t x2_desc, \ - infiniopTensorDescriptor_t weight_desc, \ - infiniopTensorDescriptor_t bias_desc); \ - \ - } -#endif // __BILINEAR_H__ diff --git a/src/infiniop/ops/bilinear/operator.cc b/src/infiniop/ops/bilinear/operator.cc deleted file mode 100644 index 01a5f9490..000000000 --- a/src/infiniop/ops/bilinear/operator.cc +++ /dev/null @@ -1,329 +0,0 @@ -#include "../../operator.h" -#include "../../../utils.h" -#include "../../../utils/check.h" -#include "../../handle.h" -#include "../../tensor.h" -#include "infiniop/ops/bilinear.h" -#include "infiniop/ops/gemm.h" -#include "infiniop/ops/add.h" -#include "infiniop/ops/rearrange.h" - -#include - -struct InfiniopBilinearDescriptor { - InfiniopDescriptor _super; - infiniopGemmDescriptor_t gemm_1_desc; - infiniopGemmDescriptor_t gemm_2_desc; - infiniopRearrangeDescriptor_t weight_rearrange_desc; - infiniopRearrangeDescriptor_t x1_rearrange_desc; - infiniopRearrangeDescriptor_t x2_rearrange_desc; - infiniopAddDescriptor_t bias_add_desc; - size_t workspace_size; - size_t weight_offset; - size_t imediate_offset; - size_t x1_cont_offset; - size_t x2_cont_offset; - size_t op_workspace_offset; - size_t op_workspace_size; -}; - -namespace { -constexpr size_t kWorkspaceAlignment = 256; - -size_t aligned_size(size_t value) { - return utils::align(value, kWorkspaceAlignment); -} -} // namespace - -__C __export infiniStatus_t infiniopCreateBilinearDescriptor( - infiniopHandle_t handle, - infiniopBilinearDescriptor_t *desc_ptr, - infiniopTensorDescriptor_t out_desc, - infiniopTensorDescriptor_t x1_desc, - infiniopTensorDescriptor_t x2_desc, - infiniopTensorDescriptor_t weight_desc, - infiniopTensorDescriptor_t bias_desc) { - - if (out_desc->ndim() != 2 || x1_desc->ndim() != 2 || x2_desc->ndim() != 2 || weight_desc->ndim() != 3) { - return INFINI_STATUS_BAD_TENSOR_SHAPE; - } - - size_t batch_size = x1_desc->shape()[0]; - size_t in1_features = x1_desc->shape()[1]; - size_t in2_features = x2_desc->shape()[1]; - size_t out_features = out_desc->shape()[1]; - - if (x2_desc->shape()[0] != batch_size || - weight_desc->shape()[0] != out_features || - weight_desc->shape()[1] != in1_features || - weight_desc->shape()[2] != in2_features || - out_desc->shape()[0] != batch_size) { - return 
INFINI_STATUS_BAD_TENSOR_SHAPE; - } - - auto dtype = out_desc->dtype(); - CHECK_OR_RETURN(x1_desc->dtype() == dtype && x2_desc->dtype() == dtype && weight_desc->dtype() == dtype, - INFINI_STATUS_BAD_TENSOR_DTYPE); - - if (bias_desc) { - CHECK_OR_RETURN(bias_desc->dtype() == dtype, INFINI_STATUS_BAD_TENSOR_DTYPE); - CHECK_OR_RETURN(bias_desc->ndim() == 1 && bias_desc->dim(0) == out_features, - INFINI_STATUS_BAD_TENSOR_SHAPE); - } - - size_t dtype_size = infiniSizeOf(dtype); - - // --- Weight Rearrangement --- - // Target: [in1, out * in2] (flattened from [in1, out, in2]) - // Source: [out, in1, in2] - - // 1. Create descriptor for contiguous target weight [in1, out, in2] - size_t weight_dst_shape[3] = {in1_features, out_features, in2_features}; - infiniopTensorDescriptor_t weight_dst_desc; - CHECK_STATUS(infiniopCreateTensorDescriptor(&weight_dst_desc, 3, weight_dst_shape, nullptr, dtype)); - - // 2. Create descriptor for source weight viewed as [in1, out, in2] (permuted) - infiniopTensorDescriptor_t weight_src_permuted; - CHECK_STATUS(infiniopCreateTensorDescriptor(&weight_src_permuted, 3, weight_desc->shape().data(), weight_desc->strides().data(), dtype)); - TRANSFORM_TENSOR_DESC(weight_src_permuted, dimPermute({1, 0, 2})); - - // 3. Create Rearrange descriptor - infiniopRearrangeDescriptor_t weight_rearrange_desc; - CHECK_STATUS(infiniopCreateRearrangeDescriptor(handle, &weight_rearrange_desc, weight_dst_desc, weight_src_permuted)); - CHECK_STATUS(infiniopDestroyTensorDescriptor(weight_src_permuted)); - - size_t weight_bytes = in1_features * out_features * in2_features * dtype_size; - - // --- GEMM 1: x1 @ weight_matrix --- - // x1: [batch, in1] - // weight_matrix: [in1, out * in2] - // result (imediate): [batch, out * in2] - - // Prepare weight matrix descriptor (view of weight_dst_desc) - infiniopTensorDescriptor_t weight_matrix_desc; - CHECK_STATUS(infiniopCreateTensorDescriptor(&weight_matrix_desc, 3, weight_dst_shape, nullptr, dtype)); - TRANSFORM_TENSOR_DESC(weight_matrix_desc, dimMerge(1, 2)); - CHECK_STATUS(infiniopDestroyTensorDescriptor(weight_dst_desc)); // Done with weight_dst_desc - - // Prepare x1 descriptor (handle non-contiguous) - infiniopTensorDescriptor_t x1_gemm_desc = x1_desc; - infiniopRearrangeDescriptor_t x1_rearrange_desc = nullptr; - size_t x1_cont_bytes = 0; - - if (x1_desc->strides()[x1_desc->ndim() - 1] != 1) { - infiniopTensorDescriptor_t x1_cont_desc; - CHECK_STATUS(infiniopCreateTensorDescriptor(&x1_cont_desc, 2, x1_desc->shape().data(), nullptr, dtype)); - CHECK_STATUS(infiniopCreateRearrangeDescriptor(handle, &x1_rearrange_desc, x1_cont_desc, x1_desc)); - x1_gemm_desc = x1_cont_desc; // Use the contiguous descriptor for GEMM creation - x1_cont_bytes = batch_size * in1_features * dtype_size; - } - - // Prepare imediate descriptor - size_t imediate_shape[2] = {batch_size, out_features * in2_features}; - infiniopTensorDescriptor_t imediate_flat_desc; - CHECK_STATUS(infiniopCreateTensorDescriptor(&imediate_flat_desc, 2, imediate_shape, nullptr, dtype)); - - infiniopGemmDescriptor_t gemm_1_desc; - CHECK_STATUS(infiniopCreateGemmDescriptor(handle, &gemm_1_desc, imediate_flat_desc, x1_gemm_desc, weight_matrix_desc)); - CHECK_STATUS(infiniopDestroyTensorDescriptor(weight_matrix_desc)); - if (x1_rearrange_desc) { - CHECK_STATUS(infiniopDestroyTensorDescriptor(x1_gemm_desc)); - } - - size_t imediate_bytes = batch_size * out_features * in2_features * dtype_size; - - // --- GEMM 2: imediate @ x2^T --- - // We perform Batch GEMM: - // A (imediate): [batch, out, in2] 
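
(Reference sketch, not part of the patch: the two-GEMM decomposition described in the comments of this now-removed operator, and reused by the ops-level implementation later in the series, can be checked against torch.nn.functional.bilinear. A PyTorch example, illustrative only:)

import torch

b, in1, in2, out_f = 2, 3, 5, 4
x1, x2 = torch.randn(b, in1), torch.randn(b, in2)
w, bias = torch.randn(out_f, in1, in2), torch.randn(out_f)

# GEMM 1: x1 [b, in1] @ weight viewed as [in1, out * in2] -> intermediate [b, out * in2]
inter = x1 @ w.permute(1, 0, 2).reshape(in1, out_f * in2)
# GEMM 2 (batched): [b, out, in2] @ [b, in2, 1] -> [b, out, 1]
out = torch.bmm(inter.view(b, out_f, in2), x2.unsqueeze(-1)).squeeze(-1) + bias

ref = torch.nn.functional.bilinear(x1, x2, w, bias)
print(torch.allclose(out, ref, atol=1e-5))   # True
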
- // B (x2): [batch, in2, 1] - // C (out): [batch, out, 1] - - // Prepare A: imediate viewed as [batch, out, in2] - infiniopTensorDescriptor_t imediate_split_desc; - CHECK_STATUS(infiniopCreateTensorDescriptor(&imediate_split_desc, 2, imediate_shape, nullptr, dtype)); - TRANSFORM_TENSOR_DESC(imediate_split_desc, dimSplit(1, {out_features, in2_features})); - CHECK_STATUS(infiniopDestroyTensorDescriptor(imediate_flat_desc)); // Done with flat - - // Prepare B: x2 viewed as [batch, in2, 1] (handle non-contiguous) - infiniopTensorDescriptor_t x2_gemm_desc = x2_desc; - infiniopRearrangeDescriptor_t x2_rearrange_desc = nullptr; - size_t x2_cont_bytes = 0; - - if (x2_desc->strides()[x2_desc->ndim() - 1] != 1) { - infiniopTensorDescriptor_t x2_cont_desc; - CHECK_STATUS(infiniopCreateTensorDescriptor(&x2_cont_desc, 2, x2_desc->shape().data(), nullptr, dtype)); - CHECK_STATUS(infiniopCreateRearrangeDescriptor(handle, &x2_rearrange_desc, x2_cont_desc, x2_desc)); - x2_gemm_desc = x2_cont_desc; - x2_cont_bytes = batch_size * in2_features * dtype_size; - } - - infiniopTensorDescriptor_t x2_col_desc; - CHECK_STATUS(infiniopCreateTensorDescriptor(&x2_col_desc, 2, x2_gemm_desc->shape().data(), x2_gemm_desc->strides().data(), dtype)); - TRANSFORM_TENSOR_DESC(x2_col_desc, dimSplit(1, {in2_features, 1})); - if (x2_rearrange_desc) { - CHECK_STATUS(infiniopDestroyTensorDescriptor(x2_gemm_desc)); - } - - // Prepare C: out viewed as [batch, out, 1] - infiniopTensorDescriptor_t out_col_desc; - CHECK_STATUS(infiniopCreateTensorDescriptor(&out_col_desc, 2, out_desc->shape().data(), out_desc->strides().data(), dtype)); - TRANSFORM_TENSOR_DESC(out_col_desc, dimSplit(1, {out_features, 1})); - - infiniopGemmDescriptor_t gemm_2_desc; - CHECK_STATUS(infiniopCreateGemmDescriptor(handle, &gemm_2_desc, out_col_desc, imediate_split_desc, x2_col_desc)); - - CHECK_STATUS(infiniopDestroyTensorDescriptor(imediate_split_desc)); - CHECK_STATUS(infiniopDestroyTensorDescriptor(x2_col_desc)); - CHECK_STATUS(infiniopDestroyTensorDescriptor(out_col_desc)); - - // --- Bias Add --- - infiniopAddDescriptor_t bias_add_desc = nullptr; - size_t add_workspace_size = 0; - if (bias_desc) { - size_t bias_shape[2] = {batch_size, out_features}; - ptrdiff_t bias_strides[2] = {0, bias_desc->stride(0)}; - infiniopTensorDescriptor_t bias_broadcast_desc; - CHECK_STATUS(infiniopCreateTensorDescriptor(&bias_broadcast_desc, 2, bias_shape, bias_strides, dtype)); - CHECK_STATUS(infiniopCreateAddDescriptor(handle, &bias_add_desc, out_desc, out_desc, bias_broadcast_desc)); - CHECK_STATUS(infiniopGetAddWorkspaceSize(bias_add_desc, &add_workspace_size)); - CHECK_STATUS(infiniopDestroyTensorDescriptor(bias_broadcast_desc)); - } - - // --- Workspace Calculation --- - size_t gemm1_workspace_size = 0; - size_t gemm2_workspace_size = 0; - CHECK_STATUS(infiniopGetGemmWorkspaceSize(gemm_1_desc, &gemm1_workspace_size)); - CHECK_STATUS(infiniopGetGemmWorkspaceSize(gemm_2_desc, &gemm2_workspace_size)); - - size_t op_workspace_size = std::max({gemm1_workspace_size, gemm2_workspace_size, add_workspace_size}); - op_workspace_size = aligned_size(op_workspace_size); - - size_t workspace_cursor = 0; - auto reserve_buffer = [&](size_t bytes) -> size_t { - if (bytes == 0) { - return 0; - } - workspace_cursor = aligned_size(workspace_cursor); - size_t offset = workspace_cursor; - workspace_cursor += bytes; - return offset; - }; - - size_t weight_offset = reserve_buffer(weight_bytes); - size_t imediate_offset = reserve_buffer(imediate_bytes); - size_t x1_cont_offset = 
x1_rearrange_desc ? reserve_buffer(x1_cont_bytes) : 0; - size_t x2_cont_offset = x2_rearrange_desc ? reserve_buffer(x2_cont_bytes) : 0; - size_t op_workspace_offset = reserve_buffer(op_workspace_size); - size_t workspace_size = aligned_size(workspace_cursor); - - *(InfiniopBilinearDescriptor **)desc_ptr = new InfiniopBilinearDescriptor{ - {handle->device, handle->device_id}, - gemm_1_desc, - gemm_2_desc, - weight_rearrange_desc, - x1_rearrange_desc, - x2_rearrange_desc, - bias_add_desc, - workspace_size, - weight_offset, - imediate_offset, - x1_cont_offset, - x2_cont_offset, - op_workspace_offset, - op_workspace_size}; - - return INFINI_STATUS_SUCCESS; -} - -__C __export infiniStatus_t infiniopGetBilinearWorkspaceSize( - infiniopBilinearDescriptor_t desc, - size_t *size) { - *size = reinterpret_cast(desc)->workspace_size; - return INFINI_STATUS_SUCCESS; -} - -__C __export infiniStatus_t infiniopBilinear( - infiniopBilinearDescriptor_t desc_, - void *workspace, - size_t workspace_size, - void *out, - void const *x1, - void const *x2, - void const *weight, - void const *bias, - void *stream) { - - auto desc = reinterpret_cast(desc_); - if (workspace_size < desc->workspace_size) { - return INFINI_STATUS_INSUFFICIENT_WORKSPACE; - } - - if (desc->bias_add_desc && bias == nullptr) { - return INFINI_STATUS_BAD_PARAM; - } - - if (!desc->bias_add_desc && bias != nullptr) { - return INFINI_STATUS_BAD_PARAM; - } - - char *workspace_ptr = reinterpret_cast(workspace); - void *weight_buffer = workspace_ptr + desc->weight_offset; - void *imediate_buffer = workspace_ptr + desc->imediate_offset; - void *op_workspace = workspace_ptr + desc->op_workspace_offset; - - // 1. Rearrange Weight - CHECK_STATUS(infiniopRearrange(desc->weight_rearrange_desc, weight_buffer, weight, stream)); - - // 2. Prepare x1 - void const *x1_ptr = x1; - if (desc->x1_rearrange_desc) { - void *x1_buffer = workspace_ptr + desc->x1_cont_offset; - CHECK_STATUS(infiniopRearrange(desc->x1_rearrange_desc, x1_buffer, x1, stream)); - x1_ptr = x1_buffer; - } - - // 3. GEMM 1: x1 @ weight -> imediate - CHECK_STATUS(infiniopGemm(desc->gemm_1_desc, op_workspace, desc->op_workspace_size, - imediate_buffer, x1_ptr, weight_buffer, 1.0f, 0.0f, stream)); - - // 4. Prepare x2 - void const *x2_ptr = x2; - if (desc->x2_rearrange_desc) { - void *x2_buffer = workspace_ptr + desc->x2_cont_offset; - CHECK_STATUS(infiniopRearrange(desc->x2_rearrange_desc, x2_buffer, x2, stream)); - x2_ptr = x2_buffer; - } - - // 5. GEMM 2: imediate @ x2 -> out - CHECK_STATUS(infiniopGemm(desc->gemm_2_desc, op_workspace, desc->op_workspace_size, - out, imediate_buffer, x2_ptr, 1.0f, 0.0f, stream)); - - // 6. 
Bias Add - if (desc->bias_add_desc) { - CHECK_STATUS(infiniopAdd(desc->bias_add_desc, op_workspace, desc->op_workspace_size, - out, out, bias, stream)); - } - - return INFINI_STATUS_SUCCESS; -} - -__C __export infiniStatus_t infiniopDestroyBilinearDescriptor( - infiniopBilinearDescriptor_t desc_) { - auto desc = reinterpret_cast(desc_); - - CHECK_STATUS(infiniopDestroyGemmDescriptor(desc->gemm_1_desc)); - CHECK_STATUS(infiniopDestroyGemmDescriptor(desc->gemm_2_desc)); - CHECK_STATUS(infiniopDestroyRearrangeDescriptor(desc->weight_rearrange_desc)); - if (desc->x1_rearrange_desc) { - CHECK_STATUS(infiniopDestroyRearrangeDescriptor(desc->x1_rearrange_desc)); - } - if (desc->x2_rearrange_desc) { - CHECK_STATUS(infiniopDestroyRearrangeDescriptor(desc->x2_rearrange_desc)); - } - if (desc->bias_add_desc) { - CHECK_STATUS(infiniopDestroyAddDescriptor(desc->bias_add_desc)); - } - - delete desc; - return INFINI_STATUS_SUCCESS; -} diff --git a/src/infiniop/ops/gemm/nvidia/gemm_nvidia.cu b/src/infiniop/ops/gemm/nvidia/gemm_nvidia.cu index 0e0c65f2b..93a263faf 100644 --- a/src/infiniop/ops/gemm/nvidia/gemm_nvidia.cu +++ b/src/infiniop/ops/gemm/nvidia/gemm_nvidia.cu @@ -71,7 +71,7 @@ infiniStatus_t Descriptor::calculate( #if defined(ENABLE_ILUVATAR_API) || defined(ENABLE_HYGON_API) compute_type = CUDA_R_32F; #else - compute_type = CUBLAS_COMPUTE_32F_FAST_TF32; + compute_type = CUBLAS_COMPUTE_32F; #endif break; From a5c5f7e15a8ec71af444c8383bc87ffe33904662 Mon Sep 17 00:00:00 2001 From: Kinorw Date: Sun, 30 Nov 2025 16:04:16 +0800 Subject: [PATCH 08/32] =?UTF-8?q?=E9=81=BF=E5=85=8D=E6=B1=A1=E6=9F=93,?= =?UTF-8?q?=E5=8F=96=E6=B6=88gemm=E4=BF=AE=E6=94=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/infiniop/ops/gemm/nvidia/gemm_nvidia.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/infiniop/ops/gemm/nvidia/gemm_nvidia.cu b/src/infiniop/ops/gemm/nvidia/gemm_nvidia.cu index 93a263faf..0e0c65f2b 100644 --- a/src/infiniop/ops/gemm/nvidia/gemm_nvidia.cu +++ b/src/infiniop/ops/gemm/nvidia/gemm_nvidia.cu @@ -71,7 +71,7 @@ infiniStatus_t Descriptor::calculate( #if defined(ENABLE_ILUVATAR_API) || defined(ENABLE_HYGON_API) compute_type = CUDA_R_32F; #else - compute_type = CUBLAS_COMPUTE_32F; + compute_type = CUBLAS_COMPUTE_32F_FAST_TF32; #endif break; From 7978405feca989e3625eafc59380a35dd9346baf Mon Sep 17 00:00:00 2001 From: Kinorw Date: Sun, 30 Nov 2025 20:45:01 +0800 Subject: [PATCH 09/32] =?UTF-8?q?feat:=20baddbmm=5Fdemo,=E6=9C=AA=E5=A4=84?= =?UTF-8?q?=E7=90=86beta,alpha?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- include/infinicore/ops.hpp | 2 + include/infinicore/ops/baddbmm.hpp | 15 ++++++ include/infiniop.h | 1 - include/infiniop/ops/bilinear.h | 35 ------------- python/infinicore/__init__.py | 3 ++ python/infinicore/ops/baddbmm.py | 24 +++++++++ sbatch.sh | 2 +- src/infinicore/ops/baddbmm/baddbmm.cc | 53 +++++++++++++++++++ src/infinicore/pybind11/ops.hpp | 2 + src/infinicore/pybind11/ops/baddbmm.hpp | 67 ++++++++++++++++++++++++ test/infinicore/ops/baddbmm.py | 8 +-- test/infiniop/libinfiniop/op_register.py | 2 +- 12 files changed, 172 insertions(+), 42 deletions(-) create mode 100644 include/infinicore/ops/baddbmm.hpp delete mode 100644 include/infiniop/ops/bilinear.h create mode 100644 python/infinicore/ops/baddbmm.py create mode 100644 src/infinicore/ops/baddbmm/baddbmm.cc create mode 100644 src/infinicore/pybind11/ops/baddbmm.hpp diff --git a/include/infinicore/ops.hpp 
b/include/infinicore/ops.hpp index 0937a4821..919e3bc05 100644 --- a/include/infinicore/ops.hpp +++ b/include/infinicore/ops.hpp @@ -2,6 +2,8 @@ #include "ops/add.hpp" #include "ops/attention.hpp" +#include "ops/baddbmm.hpp" +#include "ops/bilinear.hpp" #include "ops/causal_softmax.hpp" #include "ops/matmul.hpp" #include "ops/ones.hpp" diff --git a/include/infinicore/ops/baddbmm.hpp b/include/infinicore/ops/baddbmm.hpp new file mode 100644 index 000000000..b4c8df05b --- /dev/null +++ b/include/infinicore/ops/baddbmm.hpp @@ -0,0 +1,15 @@ +#pragma once + +#include "../device.hpp" +#include "common/op.hpp" +#include + +namespace infinicore::op { + +Tensor baddbmm(Tensor input, Tensor batch1, Tensor batch2, + std::optional beta = std::nullopt, + std::optional alpha = std::nullopt); +void baddbmm_(Tensor out, Tensor input, Tensor batch1, Tensor batch2, + std::optional beta = std::nullopt, + std::optional alpha = std::nullopt); +} // namespace infinicore::op \ No newline at end of file diff --git a/include/infiniop.h b/include/infiniop.h index 5b6ff7c55..92e6f5963 100644 --- a/include/infiniop.h +++ b/include/infiniop.h @@ -4,7 +4,6 @@ #include "infiniop/handle.h" #include "infiniop/ops/add.h" #include "infiniop/ops/attention.h" -#include "infiniop/ops/bilinear.h" #include "infiniop/ops/causal_softmax.h" #include "infiniop/ops/clip.h" #include "infiniop/ops/conv.h" diff --git a/include/infiniop/ops/bilinear.h b/include/infiniop/ops/bilinear.h deleted file mode 100644 index ea4a0555f..000000000 --- a/include/infiniop/ops/bilinear.h +++ /dev/null @@ -1,35 +0,0 @@ -#ifndef __INFINIOP_BILINEAR_API_H__ -#define __INFINIOP_BILINEAR_API_H__ - -#include "../operator_descriptor.h" - -typedef struct InfiniopDescriptor *infiniopBilinearDescriptor_t; - -__C __export infiniStatus_t infiniopCreateBilinearDescriptor( - infiniopHandle_t handle, - infiniopBilinearDescriptor_t *desc_ptr, - infiniopTensorDescriptor_t out_desc, - infiniopTensorDescriptor_t x1_desc, - infiniopTensorDescriptor_t x2_desc, - infiniopTensorDescriptor_t weight_desc, - infiniopTensorDescriptor_t bias_desc); // bias 可以为 nullptr - -__C __export infiniStatus_t infiniopGetBilinearWorkspaceSize( - infiniopBilinearDescriptor_t desc, - size_t *size); - -__C __export infiniStatus_t infiniopBilinear( - infiniopBilinearDescriptor_t desc, - void *workspace, - size_t workspace_size, - void *out, - void const *x1, - void const *x2, - void const *weight, - void const *bias, - void *stream); - -__C __export infiniStatus_t infiniopDestroyBilinearDescriptor( - infiniopBilinearDescriptor_t desc); - -#endif \ No newline at end of file diff --git a/python/infinicore/__init__.py b/python/infinicore/__init__.py index efb529a84..8f32043e3 100644 --- a/python/infinicore/__init__.py +++ b/python/infinicore/__init__.py @@ -44,6 +44,8 @@ from infinicore.ops.mul import mul from infinicore.ops.narrow import narrow from infinicore.ops.rearrange import rearrange +from infinicore.ops.baddbmm import baddbmm +from infinicore.ops.bilinear import bilinear from infinicore.tensor import ( Tensor, empty, @@ -99,6 +101,7 @@ # Operations. 
"add", "attention", + "baddbmm", "bilinear", "matmul", "mul", diff --git a/python/infinicore/ops/baddbmm.py b/python/infinicore/ops/baddbmm.py new file mode 100644 index 000000000..96b2544a9 --- /dev/null +++ b/python/infinicore/ops/baddbmm.py @@ -0,0 +1,24 @@ +from infinicore.lib import _infinicore +from infinicore.tensor import Tensor + +def baddbmm(input, batch1, batch2, *, beta=1.0, alpha=1.0, out=None): + if out is None: + return Tensor( + _infinicore.baddbmm( + input._underlying, + batch1._underlying, + batch2._underlying, + float(beta), + float(alpha), + ) + ) + _infinicore.baddbmm_( + out._underlying, + input._underlying, + batch1._underlying, + batch2._underlying, + float(beta), + float(alpha), + ) + + return out \ No newline at end of file diff --git a/sbatch.sh b/sbatch.sh index a1b8f442e..b1b2b2ddf 100755 --- a/sbatch.sh +++ b/sbatch.sh @@ -1,3 +1,3 @@ srun --partition=nvidia --nodes=1 --gres=gpu:nvidia:2 --ntasks=1 --cpus-per-task=16 --mem=64G --time=00:20:00 \ --output=output_%j.log \ - python test/infinicore/ops/bilinear.py --nvidia --verbose --bench --debug + python test/infinicore/ops/baddbmm.py --nvidia --verbose --bench --debug diff --git a/src/infinicore/ops/baddbmm/baddbmm.cc b/src/infinicore/ops/baddbmm/baddbmm.cc new file mode 100644 index 000000000..37a3096a6 --- /dev/null +++ b/src/infinicore/ops/baddbmm/baddbmm.cc @@ -0,0 +1,53 @@ +#include "infinicore/ops/baddbmm.hpp" +#include "infinicore/ops/matmul.hpp" +#include "infinicore/ops/mul.hpp" +#include "infinicore/ops/add.hpp" +#include "infinicore/ops/rearrange.hpp" +#include + +namespace infinicore::op { + +Tensor baddbmm(Tensor input, Tensor batch1, Tensor batch2, + std::optional beta, + std::optional alpha) { + + size_t batch_size = batch1->shape()[0]; + size_t m = batch1->shape()[1]; + size_t k = batch1->shape()[2]; + size_t n = batch2->shape()[2]; + + Tensor input_cont = input->is_contiguous() ? input : input->contiguous(); + Tensor batch1_cont = batch1->is_contiguous() ? batch1 : batch1->contiguous(); + Tensor batch2_cont = batch2->is_contiguous() ? 
batch2 : batch2->contiguous(); + + Tensor result = matmul(batch1_cont, batch2_cont); + + if (alpha.has_value()) { + Tensor alpha_broadcast = (*alpha)->as_strided({batch_size, m ,n}, {0, 0, 0}); + result = mul(alpha_broadcast, result); + } + + Tensor input_part = input_cont; + if (input_part->ndim() == 2) { + input_part = input_part->as_strided({batch_size, m, n}, {0, input_part->strides()[0], input_part->strides()[1]}); + } else if (input_part->ndim() == 3 && input_part->shape()[0] == 1 && batch_size > 1) { + input_part = input_part->as_strided({batch_size, m, n}, {0, input_part->strides()[1], input_part->strides()[2]}); + } + + if (beta.has_value()) { + Tensor beta_broadcast = (*beta)->as_strided({batch_size, m ,n}, {0, 0, 0}); + input_part = mul(beta_broadcast, input_part); + } + + return add(input_part, result); +} + +void baddbmm_(Tensor out, Tensor input, Tensor batch1, Tensor batch2, + std::optional beta, + std::optional alpha) { + Tensor result = baddbmm(input, batch1, batch2, beta, alpha); + // Copy result to out + rearrange_(out, result); +} + +} // namespace infinicore::op \ No newline at end of file diff --git a/src/infinicore/pybind11/ops.hpp b/src/infinicore/pybind11/ops.hpp index dfbd16434..9403c7405 100644 --- a/src/infinicore/pybind11/ops.hpp +++ b/src/infinicore/pybind11/ops.hpp @@ -4,6 +4,7 @@ #include "ops/add.hpp" #include "ops/attention.hpp" +#include "ops/baddbmm.hpp" #include "ops/bilinear.hpp" #include "ops/causal_softmax.hpp" #include "ops/embedding.hpp" @@ -24,6 +25,7 @@ namespace infinicore::ops { inline void bind(py::module &m) { bind_add(m); bind_attention(m); + bind_baddbmm(m); bind_bilinear(m); bind_causal_softmax(m); bind_random_sample(m); diff --git a/src/infinicore/pybind11/ops/baddbmm.hpp b/src/infinicore/pybind11/ops/baddbmm.hpp new file mode 100644 index 000000000..68df6b41a --- /dev/null +++ b/src/infinicore/pybind11/ops/baddbmm.hpp @@ -0,0 +1,67 @@ +#pragma once + +#include + +#include "infinicore/ops/baddbmm.hpp" + +namespace py = pybind11; + +namespace infinicore::ops { + +Tensor py_baddbmm(Tensor input, Tensor batch1, Tensor batch2, float beta = 1.0f, float alpha = 1.0f) { + if(beta != 1.0f || alpha != 1.0f) { + std::optional beta_tensor = Tensor::from_blob((void*)&beta, {}, DataType::F32, input->device()); + std::optional alpha_tensor = Tensor::from_blob((void*)&alpha, {}, DataType::F32, input->device()); + return op::baddbmm(input, batch1, batch2, beta_tensor, alpha_tensor); + } + return op::baddbmm(input, batch1, batch2); +} + +void py_baddbmm_(Tensor out, Tensor input, Tensor batch1, Tensor batch2, float beta = 1.0f, float alpha = 1.0f) { + if(beta != 1.0f || alpha != 1.0f) { + std::optional beta_tensor = Tensor::from_blob((void*)&beta, {}, DataType::F32, input->device()); + std::optional alpha_tensor = Tensor::from_blob((void*)&alpha, {}, DataType::F32, input->device()); + op::baddbmm_(out, input, batch1, batch2, beta_tensor, alpha_tensor); + return; + } + op::baddbmm_(out, input, batch1, batch2); +} + +inline void bind_baddbmm(py::module &m) { + m.def("baddbmm", + &py_baddbmm, + py::arg("input"), + py::arg("batch1"), + py::arg("batch2"), + py::arg("beta") = 1.0f, + py::arg("alpha") = 1.0f, + R"doc(Batched matrix-matrix product with addition. 
+Args: + input: Input tensor + batch1: First batch of matrices + batch2: Second batch of matrices + beta: Scaling factor for input tensor + alpha: Scaling factor for the product of batch1 and batch2 +Returns: + Output tensor after baddbmm operation +)doc"); + m.def("baddbmm_", + &py_baddbmm_, + py::arg("out"), + py::arg("input"), + py::arg("batch1"), + py::arg("batch2"), + py::arg("beta") = 1.0f, + py::arg("alpha") = 1.0f, + R"doc(In-place batched matrix-matrix product with addition. +Args: + out: Output tensor + input: Input tensor + batch1: First batch of matrices + batch2: Second batch of matrices + beta: Scaling factor for input tensor + alpha: Scaling factor for the product of batch1 and batch2 +)doc"); +} + +} // namespace infinicore::ops \ No newline at end of file diff --git a/test/infinicore/ops/baddbmm.py b/test/infinicore/ops/baddbmm.py index 54ba12cae..1284ebbef 100644 --- a/test/infinicore/ops/baddbmm.py +++ b/test/infinicore/ops/baddbmm.py @@ -12,7 +12,7 @@ # Test cases format: (input_shape, batch1_shape, batch2_shape, input_strides_or_None, batch1_strides_or_None, batch2_strides_or_None, beta_or_None, alpha_or_None) _TEST_CASES_DATA = [ - ((3, 5), (2, 3, 4), (2, 4, 5), None, None, None, None, None), + # ((3, 5), (2, 3, 4), (2, 4, 5), None, None, None, None, None), ((8, 8), (4, 8, 8), (4, 8, 8), None, None, None, 0.5, 2.0), ((5, 7), (2, 5, 6), (2, 6, 7), (30, 1), (0, 5, 1), None, None, None), ((16, 16), (2, 16, 16), (2, 16, 16), None, None, (512, 1, 1), 1.0, None), @@ -95,9 +95,9 @@ def get_test_cases(self): def torch_operator(self, *args, **kwargs): return torch.baddbmm(*args, **kwargs) - # def infinicore_operator(self, *args, **kwargs): - # """InfiniCore implementation (operator not yet available).""" - # return infinicore.baddbmm(*args, **kwargs) + def infinicore_operator(self, *args, **kwargs): + """InfiniCore implementation (operator not yet available).""" + return infinicore.baddbmm(*args, **kwargs) def main(): diff --git a/test/infiniop/libinfiniop/op_register.py b/test/infiniop/libinfiniop/op_register.py index e95764e20..178f57899 100644 --- a/test/infiniop/libinfiniop/op_register.py +++ b/test/infiniop/libinfiniop/op_register.py @@ -975,4 +975,4 @@ def bilinear_(lib): lib.infiniopDestroyBilinearDescriptor.restype = c_int32 lib.infiniopDestroyBilinearDescriptor.argtypes = [ infiniopOperatorDescriptor_t, - ] + ] \ No newline at end of file From 02e7268863c9ad2b66762d6506e3a5d13874bcdb Mon Sep 17 00:00:00 2001 From: Kinorw Date: Mon, 1 Dec 2025 18:19:32 +0800 Subject: [PATCH 10/32] =?UTF-8?q?feat:=E5=AE=9E=E7=8E=B0=E7=B1=BB=E5=9E=8B?= =?UTF-8?q?=E8=BD=AC=E6=8D=A2,baddbmm?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .gitignore | 2 ++ include/infinicore/dtype.hpp | 2 +- src/infinicore/dtype.cc | 47 +++++++++++++++++++++++++ src/infinicore/pybind11/ops/baddbmm.hpp | 35 +++++++++++++++--- test/infinicore/ops/baddbmm.py | 2 +- 5 files changed, 82 insertions(+), 6 deletions(-) diff --git a/.gitignore b/.gitignore index d9479360b..7683cae5e 100644 --- a/.gitignore +++ b/.gitignore @@ -32,3 +32,5 @@ cache/ *.gz *.zip *.tar + +*.txt \ No newline at end of file diff --git a/include/infinicore/dtype.hpp b/include/infinicore/dtype.hpp index ea3f49286..0806a2e88 100644 --- a/include/infinicore/dtype.hpp +++ b/include/infinicore/dtype.hpp @@ -29,5 +29,5 @@ enum class DataType { std::string toString(const DataType &dtype); size_t dsize(const DataType &dtype); - +void convertFloat(double value, DataType dtype, void* buffer); } // 
namespace infinicore diff --git a/src/infinicore/dtype.cc b/src/infinicore/dtype.cc index 5702c4c28..222da2c53 100644 --- a/src/infinicore/dtype.cc +++ b/src/infinicore/dtype.cc @@ -1,5 +1,9 @@ #include +#include +#include +#include + namespace infinicore { std::string toString(const DataType &dtype) { @@ -80,4 +84,47 @@ size_t dsize(const DataType &dtype) { return 0; } +void convertFloat(double value, DataType dtype, void* buffer) { + switch (dtype){ + case DataType::F32: { + float f32_val = static_cast(value); + std::memcpy(buffer, &f32_val, sizeof(float)); + break; + } + case DataType::F64: { + double f64_val = value; + std::memcpy(buffer, &f64_val, sizeof(double)); + break; + } + case DataType::F16: { + float f32_val = static_cast(value); + uint32_t f; + std::memcpy(&f, &f32_val, sizeof(float)); + + uint16_t h; + uint32_t sign = (f >> 16) & 0x8000; + int32_t exp = ((f >> 23) & 0xff) - 127 + 15; + uint32_t mant = f & 0x7fffff; + if (exp <= 0) { + h = sign; + } else if (exp >= 31) { + h = sign | 0x7c00; + } else { + h = sign | (exp << 10) | (mant >> 13); + } + std::memcpy(buffer, &h, sizeof(uint16_t)); + break; + } + case DataType::BF16: { + float f32_val = static_cast(value); + uint32_t f; + std::memcpy(&f, &f32_val, sizeof(float)); + uint16_t bf16 = static_cast(f >> 16); + std::memcpy(buffer, &bf16, sizeof(uint16_t)); + break; + } + default: + throw std::runtime_error("Unsupported dtype for float conversion"); + } +} } // namespace infinicore diff --git a/src/infinicore/pybind11/ops/baddbmm.hpp b/src/infinicore/pybind11/ops/baddbmm.hpp index 68df6b41a..60054c664 100644 --- a/src/infinicore/pybind11/ops/baddbmm.hpp +++ b/src/infinicore/pybind11/ops/baddbmm.hpp @@ -2,16 +2,39 @@ #include +#include "infinicore/dtype.hpp" #include "infinicore/ops/baddbmm.hpp" namespace py = pybind11; namespace infinicore::ops { +namespace { +// Helper function to create a scalar tensor with the correct dtype +Tensor create_scalar_tensor(float value, DataType dtype, Device device) { + // Create buffer for the target dtype + alignas(8) char buffer[8]; + convertFloat(static_cast(value), dtype, buffer); + + // Create CPU tensor with correct dtype + Tensor cpu_tensor = Tensor::from_blob(buffer, {}, dtype, Device(Device::Type::CPU)); + + // Create device tensor and copy + Tensor device_tensor = Tensor::empty({}, dtype, device); + device_tensor->copy_from(cpu_tensor); + + return device_tensor; +} +} + Tensor py_baddbmm(Tensor input, Tensor batch1, Tensor batch2, float beta = 1.0f, float alpha = 1.0f) { if(beta != 1.0f || alpha != 1.0f) { - std::optional beta_tensor = Tensor::from_blob((void*)&beta, {}, DataType::F32, input->device()); - std::optional alpha_tensor = Tensor::from_blob((void*)&alpha, {}, DataType::F32, input->device()); + DataType dtype = batch1->dtype(); + Device device = input->device(); + + Tensor beta_tensor = create_scalar_tensor(beta, dtype, device); + Tensor alpha_tensor = create_scalar_tensor(alpha, dtype, device); + return op::baddbmm(input, batch1, batch2, beta_tensor, alpha_tensor); } return op::baddbmm(input, batch1, batch2); @@ -19,8 +42,12 @@ Tensor py_baddbmm(Tensor input, Tensor batch1, Tensor batch2, float beta = 1.0f, void py_baddbmm_(Tensor out, Tensor input, Tensor batch1, Tensor batch2, float beta = 1.0f, float alpha = 1.0f) { if(beta != 1.0f || alpha != 1.0f) { - std::optional beta_tensor = Tensor::from_blob((void*)&beta, {}, DataType::F32, input->device()); - std::optional alpha_tensor = Tensor::from_blob((void*)&alpha, {}, DataType::F32, input->device()); + DataType dtype = 
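
(Reference sketch, not part of the patch: the convertFloat F16 path above packs a float32 into half by truncating the mantissa and flushing subnormals to signed zero, so it can differ from a round-to-nearest conversion in the last bit; the helper is dropped again in the next commit of this series. A NumPy/struct check, illustrative only:)

import struct
import numpy as np

def f32_to_f16_bits(x):
    # mirrors the C++ helper: truncated mantissa, flush-to-zero, saturate to infinity
    f = struct.unpack('<I', struct.pack('<f', x))[0]
    sign = (f >> 16) & 0x8000
    exp = ((f >> 23) & 0xff) - 127 + 15
    mant = f & 0x7fffff
    if exp <= 0:
        return sign
    if exp >= 31:
        return sign | 0x7c00
    return sign | (exp << 10) | (mant >> 13)

for v in (1.0, 0.9999999, 3e-8, 65504.0):
    mine = f32_to_f16_bits(v)
    ref = int(np.array(v, dtype=np.float32).astype(np.float16).view(np.uint16))
    # 0.9999999 (truncation vs rounding) and 3e-8 (subnormal flush) differ from NumPy
    print(v, hex(mine), hex(ref), mine == ref)
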
batch1->dtype(); + Device device = input->device(); + + Tensor beta_tensor = create_scalar_tensor(beta, dtype, device); + Tensor alpha_tensor = create_scalar_tensor(alpha, dtype, device); + op::baddbmm_(out, input, batch1, batch2, beta_tensor, alpha_tensor); return; } diff --git a/test/infinicore/ops/baddbmm.py b/test/infinicore/ops/baddbmm.py index 1284ebbef..10ebbce10 100644 --- a/test/infinicore/ops/baddbmm.py +++ b/test/infinicore/ops/baddbmm.py @@ -12,7 +12,7 @@ # Test cases format: (input_shape, batch1_shape, batch2_shape, input_strides_or_None, batch1_strides_or_None, batch2_strides_or_None, beta_or_None, alpha_or_None) _TEST_CASES_DATA = [ - # ((3, 5), (2, 3, 4), (2, 4, 5), None, None, None, None, None), + ((3, 5), (2, 3, 4), (2, 4, 5), None, None, None, None, None), ((8, 8), (4, 8, 8), (4, 8, 8), None, None, None, 0.5, 2.0), ((5, 7), (2, 5, 6), (2, 6, 7), (30, 1), (0, 5, 1), None, None, None), ((16, 16), (2, 16, 16), (2, 16, 16), None, None, (512, 1, 1), 1.0, None), From ca8f8856e6138e4cde1bc0a191b4de046d042676 Mon Sep 17 00:00:00 2001 From: Kinorw Date: Mon, 1 Dec 2025 18:47:17 +0800 Subject: [PATCH 11/32] =?UTF-8?q?=E6=8D=A2=E7=94=A8=E6=9B=B4=E5=90=88?= =?UTF-8?q?=E9=80=82=E7=9A=84gemm=E8=BF=9B=E8=A1=8C=E9=87=8D=E6=9E=84,?= =?UTF-8?q?=E6=97=A0=E9=9C=80=E7=B1=BB=E5=9E=8B=E8=BD=AC=E6=8D=A2,?= =?UTF-8?q?=E5=8F=96=E6=B6=88=E4=BF=AE=E6=94=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- include/infinicore/dtype.hpp | 2 +- include/infinicore/ops/baddbmm.hpp | 8 ++-- src/infinicore/dtype.cc | 47 ---------------------- src/infinicore/ops/baddbmm/baddbmm.cc | 53 ++++++++++++------------- src/infinicore/pybind11/ops/baddbmm.hpp | 43 ++------------------ 5 files changed, 33 insertions(+), 120 deletions(-) diff --git a/include/infinicore/dtype.hpp b/include/infinicore/dtype.hpp index 0806a2e88..ea3f49286 100644 --- a/include/infinicore/dtype.hpp +++ b/include/infinicore/dtype.hpp @@ -29,5 +29,5 @@ enum class DataType { std::string toString(const DataType &dtype); size_t dsize(const DataType &dtype); -void convertFloat(double value, DataType dtype, void* buffer); + } // namespace infinicore diff --git a/include/infinicore/ops/baddbmm.hpp b/include/infinicore/ops/baddbmm.hpp index b4c8df05b..73e20a633 100644 --- a/include/infinicore/ops/baddbmm.hpp +++ b/include/infinicore/ops/baddbmm.hpp @@ -7,9 +7,9 @@ namespace infinicore::op { Tensor baddbmm(Tensor input, Tensor batch1, Tensor batch2, - std::optional beta = std::nullopt, - std::optional alpha = std::nullopt); + float beta = 1.0f, + float alpha = 1.0f); void baddbmm_(Tensor out, Tensor input, Tensor batch1, Tensor batch2, - std::optional beta = std::nullopt, - std::optional alpha = std::nullopt); + float beta = 1.0f, + float alpha = 1.0f); } // namespace infinicore::op \ No newline at end of file diff --git a/src/infinicore/dtype.cc b/src/infinicore/dtype.cc index 222da2c53..5702c4c28 100644 --- a/src/infinicore/dtype.cc +++ b/src/infinicore/dtype.cc @@ -1,9 +1,5 @@ #include -#include -#include -#include - namespace infinicore { std::string toString(const DataType &dtype) { @@ -84,47 +80,4 @@ size_t dsize(const DataType &dtype) { return 0; } -void convertFloat(double value, DataType dtype, void* buffer) { - switch (dtype){ - case DataType::F32: { - float f32_val = static_cast(value); - std::memcpy(buffer, &f32_val, sizeof(float)); - break; - } - case DataType::F64: { - double f64_val = value; - std::memcpy(buffer, &f64_val, sizeof(double)); - break; - } - case DataType::F16: { - float 
f32_val = static_cast(value); - uint32_t f; - std::memcpy(&f, &f32_val, sizeof(float)); - - uint16_t h; - uint32_t sign = (f >> 16) & 0x8000; - int32_t exp = ((f >> 23) & 0xff) - 127 + 15; - uint32_t mant = f & 0x7fffff; - if (exp <= 0) { - h = sign; - } else if (exp >= 31) { - h = sign | 0x7c00; - } else { - h = sign | (exp << 10) | (mant >> 13); - } - std::memcpy(buffer, &h, sizeof(uint16_t)); - break; - } - case DataType::BF16: { - float f32_val = static_cast(value); - uint32_t f; - std::memcpy(&f, &f32_val, sizeof(float)); - uint16_t bf16 = static_cast(f >> 16); - std::memcpy(buffer, &bf16, sizeof(uint16_t)); - break; - } - default: - throw std::runtime_error("Unsupported dtype for float conversion"); - } -} } // namespace infinicore diff --git a/src/infinicore/ops/baddbmm/baddbmm.cc b/src/infinicore/ops/baddbmm/baddbmm.cc index 37a3096a6..5aaae420b 100644 --- a/src/infinicore/ops/baddbmm/baddbmm.cc +++ b/src/infinicore/ops/baddbmm/baddbmm.cc @@ -1,53 +1,50 @@ #include "infinicore/ops/baddbmm.hpp" -#include "infinicore/ops/matmul.hpp" -#include "infinicore/ops/mul.hpp" -#include "infinicore/ops/add.hpp" +#include "infinicore/ops/gemm.hpp" #include "infinicore/ops/rearrange.hpp" #include namespace infinicore::op { Tensor baddbmm(Tensor input, Tensor batch1, Tensor batch2, - std::optional beta, - std::optional alpha) { + float beta, + float alpha) { size_t batch_size = batch1->shape()[0]; size_t m = batch1->shape()[1]; - size_t k = batch1->shape()[2]; size_t n = batch2->shape()[2]; + - Tensor input_cont = input->is_contiguous() ? input : input->contiguous(); Tensor batch1_cont = batch1->is_contiguous() ? batch1 : batch1->contiguous(); Tensor batch2_cont = batch2->is_contiguous() ? batch2 : batch2->contiguous(); - Tensor result = matmul(batch1_cont, batch2_cont); - - if (alpha.has_value()) { - Tensor alpha_broadcast = (*alpha)->as_strided({batch_size, m ,n}, {0, 0, 0}); - result = mul(alpha_broadcast, result); - } - - Tensor input_part = input_cont; - if (input_part->ndim() == 2) { - input_part = input_part->as_strided({batch_size, m, n}, {0, input_part->strides()[0], input_part->strides()[1]}); - } else if (input_part->ndim() == 3 && input_part->shape()[0] == 1 && batch_size > 1) { - input_part = input_part->as_strided({batch_size, m, n}, {0, input_part->strides()[1], input_part->strides()[2]}); + Tensor result = Tensor::empty({batch_size, m, n}, batch1->dtype(), batch1->device()); + + Tensor input_cont = input->is_contiguous() ? 
input : input->contiguous(); + if (input->ndim() == 2) { + Tensor input_broadcast = input_cont->as_strided( + {batch_size, m, n}, + {0, input_cont->strides()[0], input_cont->strides()[1]}); + result->copy_from(input_broadcast); + } else if (input->ndim() == 3 && input->shape()[0] == 1 && batch_size > 1) { + Tensor input_broadcast = input_cont->as_strided( + {batch_size, m, n}, + {0, input_cont->strides()[1], input_cont->strides()[2]}); + result->copy_from(input_broadcast); + } else { + result->copy_from(input_cont); } - if (beta.has_value()) { - Tensor beta_broadcast = (*beta)->as_strided({batch_size, m ,n}, {0, 0, 0}); - input_part = mul(beta_broadcast, input_part); - } + // Fused operation: result = alpha * batch1 @ batch2 + beta * result + gemm_(result, batch1_cont, batch2_cont, alpha, beta); - return add(input_part, result); + return result; } void baddbmm_(Tensor out, Tensor input, Tensor batch1, Tensor batch2, - std::optional beta, - std::optional alpha) { + float beta, + float alpha) { Tensor result = baddbmm(input, batch1, batch2, beta, alpha); - // Copy result to out - rearrange_(out, result); + out->copy_from(result); } } // namespace infinicore::op \ No newline at end of file diff --git a/src/infinicore/pybind11/ops/baddbmm.hpp b/src/infinicore/pybind11/ops/baddbmm.hpp index 60054c664..79f7cb412 100644 --- a/src/infinicore/pybind11/ops/baddbmm.hpp +++ b/src/infinicore/pybind11/ops/baddbmm.hpp @@ -2,56 +2,19 @@ #include -#include "infinicore/dtype.hpp" #include "infinicore/ops/baddbmm.hpp" namespace py = pybind11; namespace infinicore::ops { -namespace { -// Helper function to create a scalar tensor with the correct dtype -Tensor create_scalar_tensor(float value, DataType dtype, Device device) { - // Create buffer for the target dtype - alignas(8) char buffer[8]; - convertFloat(static_cast(value), dtype, buffer); - - // Create CPU tensor with correct dtype - Tensor cpu_tensor = Tensor::from_blob(buffer, {}, dtype, Device(Device::Type::CPU)); - - // Create device tensor and copy - Tensor device_tensor = Tensor::empty({}, dtype, device); - device_tensor->copy_from(cpu_tensor); - - return device_tensor; -} -} -Tensor py_baddbmm(Tensor input, Tensor batch1, Tensor batch2, float beta = 1.0f, float alpha = 1.0f) { - if(beta != 1.0f || alpha != 1.0f) { - DataType dtype = batch1->dtype(); - Device device = input->device(); - - Tensor beta_tensor = create_scalar_tensor(beta, dtype, device); - Tensor alpha_tensor = create_scalar_tensor(alpha, dtype, device); - - return op::baddbmm(input, batch1, batch2, beta_tensor, alpha_tensor); - } - return op::baddbmm(input, batch1, batch2); +Tensor py_baddbmm(Tensor input, Tensor batch1, Tensor batch2, float beta = 1.0f, float alpha = 1.0f) { + return op::baddbmm(input, batch1, batch2, beta, alpha); } void py_baddbmm_(Tensor out, Tensor input, Tensor batch1, Tensor batch2, float beta = 1.0f, float alpha = 1.0f) { - if(beta != 1.0f || alpha != 1.0f) { - DataType dtype = batch1->dtype(); - Device device = input->device(); - - Tensor beta_tensor = create_scalar_tensor(beta, dtype, device); - Tensor alpha_tensor = create_scalar_tensor(alpha, dtype, device); - - op::baddbmm_(out, input, batch1, batch2, beta_tensor, alpha_tensor); - return; - } - op::baddbmm_(out, input, batch1, batch2); + op::baddbmm_(out, input, batch1, batch2, beta, alpha); } inline void bind_baddbmm(py::module &m) { From 802e21743c161241ce0617fcfc74d16770642fc3 Mon Sep 17 00:00:00 2001 From: Kinorw Date: Mon, 1 Dec 2025 20:22:17 +0800 Subject: [PATCH 12/32] 
=?UTF-8?q?=E4=BC=98=E5=8C=96:=E7=89=B9=E5=88=A4beta?= =?UTF-8?q?=3D0,=E5=8E=BB=E6=8E=89=E4=B8=8D=E5=BF=85=E8=A6=81=E7=9A=84cont?= =?UTF-8?q?iguous=E5=A4=84=E7=90=86?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/infinicore/ops/baddbmm/baddbmm.cc | 47 ++++++++++++++++++--------- 1 file changed, 31 insertions(+), 16 deletions(-) diff --git a/src/infinicore/ops/baddbmm/baddbmm.cc b/src/infinicore/ops/baddbmm/baddbmm.cc index 5aaae420b..76bb731c8 100644 --- a/src/infinicore/ops/baddbmm/baddbmm.cc +++ b/src/infinicore/ops/baddbmm/baddbmm.cc @@ -13,28 +13,43 @@ Tensor baddbmm(Tensor input, Tensor batch1, Tensor batch2, size_t m = batch1->shape()[1]; size_t n = batch2->shape()[2]; - Tensor batch1_cont = batch1->is_contiguous() ? batch1 : batch1->contiguous(); Tensor batch2_cont = batch2->is_contiguous() ? batch2 : batch2->contiguous(); + bool input_is_target_shape = (input->ndim() == 3 && + input->shape()[0] == batch_size && + input->shape()[1] == m && + input->shape()[2] == n); + + if (input_is_target_shape && input->is_contiguous()) { + Tensor result = Tensor::empty({batch_size, m, n}, batch1->dtype(), batch1->device()); + if (beta != 0.0f) { + rearrange_(result, input); + } + gemm_(result, batch1_cont, batch2_cont, alpha, beta); + return result; + } + Tensor result = Tensor::empty({batch_size, m, n}, batch1->dtype(), batch1->device()); - Tensor input_cont = input->is_contiguous() ? input : input->contiguous(); - if (input->ndim() == 2) { - Tensor input_broadcast = input_cont->as_strided( - {batch_size, m, n}, - {0, input_cont->strides()[0], input_cont->strides()[1]}); - result->copy_from(input_broadcast); - } else if (input->ndim() == 3 && input->shape()[0] == 1 && batch_size > 1) { - Tensor input_broadcast = input_cont->as_strided( - {batch_size, m, n}, - {0, input_cont->strides()[1], input_cont->strides()[2]}); - result->copy_from(input_broadcast); - } else { - result->copy_from(input_cont); + if (beta != 0.0f) { + if (input->ndim() == 2) { + auto strides = input->strides(); + Tensor input_broadcast = input->as_strided( + {batch_size, m, n}, + {0, strides[0], strides[1]}); + rearrange_(result, input_broadcast); + } else if (input->ndim() == 3 && input->shape()[0] == 1 && batch_size > 1) { + auto strides = input->strides(); + Tensor input_broadcast = input->as_strided( + {batch_size, m, n}, + {0, strides[1], strides[2]}); + rearrange_(result, input_broadcast); + } else { + rearrange_(result, input); + } } - // Fused operation: result = alpha * batch1 @ batch2 + beta * result gemm_(result, batch1_cont, batch2_cont, alpha, beta); return result; @@ -44,7 +59,7 @@ void baddbmm_(Tensor out, Tensor input, Tensor batch1, Tensor batch2, float beta, float alpha) { Tensor result = baddbmm(input, batch1, batch2, beta, alpha); - out->copy_from(result); + rearrange_(out, result); } } // namespace infinicore::op \ No newline at end of file From e1cf0a16c705d25df628724748d0aeb4a989a4dd Mon Sep 17 00:00:00 2001 From: Kinorw Date: Tue, 2 Dec 2025 17:47:30 +0800 Subject: [PATCH 13/32] =?UTF-8?q?=E4=BC=98=E5=8C=96=E5=AE=9E=E7=8E=B0,?= =?UTF-8?q?=E4=BD=BF=E7=94=A8=E9=9D=A2=E5=90=91BLAS=E5=85=BC=E5=AE=B9?= =?UTF-8?q?=E7=9A=84=E6=A3=80=E6=9F=A5=E8=80=8C=E9=9D=9E=E8=BF=9E=E7=BB=AD?= =?UTF-8?q?=E6=A3=80=E6=9F=A5,=E5=87=8F=E5=B0=91baddbmm=E7=9A=84=E6=96=B0?= =?UTF-8?q?=E5=BB=BA=E5=BC=A0=E9=87=8F,device=E8=BE=BE=E5=88=B01.00?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/infinicore/ops/baddbmm/baddbmm.cc | 121 
++++++++++++++++---------- 1 file changed, 77 insertions(+), 44 deletions(-) diff --git a/src/infinicore/ops/baddbmm/baddbmm.cc b/src/infinicore/ops/baddbmm/baddbmm.cc index 76bb731c8..acacb58e7 100644 --- a/src/infinicore/ops/baddbmm/baddbmm.cc +++ b/src/infinicore/ops/baddbmm/baddbmm.cc @@ -1,65 +1,98 @@ #include "infinicore/ops/baddbmm.hpp" #include "infinicore/ops/gemm.hpp" #include "infinicore/ops/rearrange.hpp" -#include namespace infinicore::op { +// 内联的 BLAS 兼容性检查,减少函数调用开销 +inline bool is_blas_compatible(const Tensor &t) { + const auto ndim = t->ndim(); + if (ndim == 2) { + const auto rs = t->stride(0); + const auto cs = t->stride(1); + if (rs != 1 && cs != 1) return false; + if (rs == 1 && cs == 1) { + return t->shape()[0] == 1 || t->shape()[1] == 1; + } + return true; + } else if (ndim == 3) { + const auto rs = t->stride(1); + const auto cs = t->stride(2); + if (t->shape()[0] > 1 && t->stride(0) == 0) return false; + if (rs != 1 && cs != 1) return false; + if (rs == 1 && cs == 1) { + return t->shape()[1] == 1 || t->shape()[2] == 1; + } + return true; + } + return false; +} + +inline void prepare_gemm_input(Tensor &output,Tensor & input, const size_t batch_size, const size_t m, const size_t n) { + const auto input_ndim = input->ndim(); + if (input_ndim == 2) { + rearrange_(output, input->as_strided( + {batch_size, m, n}, + {0, input->stride(0), input->stride(1)})); + } else if (input_ndim == 3 && input->shape()[0] == 1 && batch_size > 1) { + rearrange_(output, input->as_strided( + {batch_size, m, n}, + {0, input->stride(1), input->stride(2)})); + } else { + rearrange_(output, input); + } +} + Tensor baddbmm(Tensor input, Tensor batch1, Tensor batch2, float beta, float alpha) { - - size_t batch_size = batch1->shape()[0]; - size_t m = batch1->shape()[1]; - size_t n = batch2->shape()[2]; + const size_t batch_size = batch1->shape()[0]; + const size_t m = batch1->shape()[1]; + const size_t n = batch2->shape()[2]; - Tensor batch1_cont = batch1->is_contiguous() ? batch1 : batch1->contiguous(); - Tensor batch2_cont = batch2->is_contiguous() ? batch2 : batch2->contiguous(); + const Tensor &a = is_blas_compatible(batch1) ? batch1 : rearrange(batch1); + const Tensor &b = is_blas_compatible(batch2) ? 
batch2 : rearrange(batch2); - bool input_is_target_shape = (input->ndim() == 3 && - input->shape()[0] == batch_size && - input->shape()[1] == m && - input->shape()[2] == n); - - if (input_is_target_shape && input->is_contiguous()) { - Tensor result = Tensor::empty({batch_size, m, n}, batch1->dtype(), batch1->device()); - if (beta != 0.0f) { - rearrange_(result, input); - } - gemm_(result, batch1_cont, batch2_cont, alpha, beta); - return result; + if (beta == 0.0f) { + return gemm(a, b, alpha, 0.0f); } + + Tensor result = Tensor::empty({batch_size, m, n}, a->dtype(), a->device()); - Tensor result = Tensor::empty({batch_size, m, n}, batch1->dtype(), batch1->device()); + prepare_gemm_input(result, input, batch_size, m, n); - if (beta != 0.0f) { - if (input->ndim() == 2) { - auto strides = input->strides(); - Tensor input_broadcast = input->as_strided( - {batch_size, m, n}, - {0, strides[0], strides[1]}); - rearrange_(result, input_broadcast); - } else if (input->ndim() == 3 && input->shape()[0] == 1 && batch_size > 1) { - auto strides = input->strides(); - Tensor input_broadcast = input->as_strided( - {batch_size, m, n}, - {0, strides[1], strides[2]}); - rearrange_(result, input_broadcast); - } else { - rearrange_(result, input); - } - } - - gemm_(result, batch1_cont, batch2_cont, alpha, beta); - + gemm_(result, a, b, alpha, beta); return result; } void baddbmm_(Tensor out, Tensor input, Tensor batch1, Tensor batch2, float beta, float alpha) { - Tensor result = baddbmm(input, batch1, batch2, beta, alpha); - rearrange_(out, result); + const size_t batch_size = batch1->shape()[0]; + const size_t m = batch1->shape()[1]; + const size_t n = batch2->shape()[2]; + + const Tensor &a = is_blas_compatible(batch1) ? batch1 : rearrange(batch1); + const Tensor &b = is_blas_compatible(batch2) ? 
batch2 : rearrange(batch2); + + const bool out_is_usable = out->is_contiguous() && + out->ndim() == 3 && + out->shape()[0] == batch_size && + out->shape()[1] == m && + out->shape()[2] == n; + + if (out_is_usable) { + if (beta != 0.0f && input->data() != out->data()) { + prepare_gemm_input(out, input, batch_size, m, n); + } + gemm_(out, a, b, alpha, beta); + } else { + Tensor result = Tensor::empty({batch_size, m, n}, a->dtype(), a->device()); + if (beta != 0.0f) { + prepare_gemm_input(result, input, batch_size, m, n); + } + gemm_(result, a, b, alpha, beta); + rearrange_(out, result); + } } - -} // namespace infinicore::op \ No newline at end of file +} \ No newline at end of file From a765e7ca0961cc022fa666b756584178d0892aff Mon Sep 17 00:00:00 2001 From: Kinorw Date: Wed, 3 Dec 2025 16:23:55 +0800 Subject: [PATCH 14/32] feat: fmod cpu and nvidia --- include/infinicore/ops.hpp | 1 + include/infinicore/ops/fmod.hpp | 16 ++ include/infiniop.h | 1 + include/infiniop/ops/fmod.h | 26 +++ python/infinicore/__init__.py | 2 + python/infinicore/ops/fmod.py | 11 ++ src/infinicore/ops/fmod/fmod.cc | 24 +++ src/infinicore/ops/fmod/fmod_infiniop.cc | 53 ++++++ src/infinicore/pybind11/ops.hpp | 2 + src/infinicore/pybind11/ops/fmod.hpp | 26 +++ src/infiniop/ops/fmod/cpu/fmod_cpu.cc | 53 ++++++ src/infiniop/ops/fmod/cpu/fmod_cpu.h | 19 +++ src/infiniop/ops/fmod/cuda/kernel.cuh | 48 ++++++ src/infiniop/ops/fmod/nvidia/fmod_nvidia.cu | 59 +++++++ src/infiniop/ops/fmod/nvidia/fmod_nvidia.cuh | 8 + src/infiniop/ops/fmod/operator.cc | 163 +++++++++++++++++++ test/infinicore/ops/fmod.py | 6 +- 17 files changed, 515 insertions(+), 3 deletions(-) create mode 100644 include/infinicore/ops/fmod.hpp create mode 100644 include/infiniop/ops/fmod.h create mode 100644 python/infinicore/ops/fmod.py create mode 100644 src/infinicore/ops/fmod/fmod.cc create mode 100644 src/infinicore/ops/fmod/fmod_infiniop.cc create mode 100644 src/infinicore/pybind11/ops/fmod.hpp create mode 100644 src/infiniop/ops/fmod/cpu/fmod_cpu.cc create mode 100644 src/infiniop/ops/fmod/cpu/fmod_cpu.h create mode 100644 src/infiniop/ops/fmod/cuda/kernel.cuh create mode 100644 src/infiniop/ops/fmod/nvidia/fmod_nvidia.cu create mode 100644 src/infiniop/ops/fmod/nvidia/fmod_nvidia.cuh create mode 100644 src/infiniop/ops/fmod/operator.cc diff --git a/include/infinicore/ops.hpp b/include/infinicore/ops.hpp index 919e3bc05..bb05d022f 100644 --- a/include/infinicore/ops.hpp +++ b/include/infinicore/ops.hpp @@ -5,6 +5,7 @@ #include "ops/baddbmm.hpp" #include "ops/bilinear.hpp" #include "ops/causal_softmax.hpp" +#include "ops/fmod.hpp" #include "ops/matmul.hpp" #include "ops/ones.hpp" #include "ops/rearrange.hpp" diff --git a/include/infinicore/ops/fmod.hpp b/include/infinicore/ops/fmod.hpp new file mode 100644 index 000000000..87b90d515 --- /dev/null +++ b/include/infinicore/ops/fmod.hpp @@ -0,0 +1,16 @@ +#pragma once + +#include "../device.hpp" +#include "common/op.hpp" + +namespace infinicore::op { +class Fmod { +public: + using schema = void (*)(Tensor, Tensor, Tensor); + static void execute(Tensor c, Tensor a, Tensor b); + static common::OpDispatcher &dispatcher(); +}; + +Tensor fmod(Tensor a, Tensor b); +void fmod_(Tensor c, Tensor a, Tensor b); +} // namespace infinicore::op diff --git a/include/infiniop.h b/include/infiniop.h index 92e6f5963..3bafe4209 100644 --- a/include/infiniop.h +++ b/include/infiniop.h @@ -8,6 +8,7 @@ #include "infiniop/ops/clip.h" #include "infiniop/ops/conv.h" #include "infiniop/ops/dequantize_awq.h" +#include 
"infiniop/ops/fmod.h" #include "infiniop/ops/gelu.h" #include "infiniop/ops/gemm.h" #include "infiniop/ops/layer_norm.h" diff --git a/include/infiniop/ops/fmod.h b/include/infiniop/ops/fmod.h new file mode 100644 index 000000000..b74b6daca --- /dev/null +++ b/include/infiniop/ops/fmod.h @@ -0,0 +1,26 @@ +#ifndef __INFINIOP_FMOD_API_H_ +#define __INFINIOP_FMOD_API_H__ + +#include "../operator_descriptor.h" + +typedef struct InfiniopDescriptor *infiniopFmodDescriptor_t; + +__C __export infiniStatus_t infiniopCreateFmodDescriptor(infiniopHandle_t handle, + infiniopFmodDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t c, + infiniopTensorDescriptor_t a, + infiniopTensorDescriptor_t b); + +__C __export infiniStatus_t infiniopGetFmodWorkspaceSize(infiniopFmodDescriptor_t desc, size_t *size); + +__C __export infiniStatus_t infiniopFmod(infiniopFmodDescriptor_t desc, + void *workspace, + size_t workspace_size, + void *c, + const void *a, + const void *b, + void *stream); + +__C __export infiniStatus_t infiniopDestroyFmodDescriptor(infiniopFmodDescriptor_t desc); + +#endif \ No newline at end of file diff --git a/python/infinicore/__init__.py b/python/infinicore/__init__.py index 8f32043e3..1d81972bf 100644 --- a/python/infinicore/__init__.py +++ b/python/infinicore/__init__.py @@ -46,6 +46,7 @@ from infinicore.ops.rearrange import rearrange from infinicore.ops.baddbmm import baddbmm from infinicore.ops.bilinear import bilinear +from infinicore.ops.fmod import fmod from infinicore.tensor import ( Tensor, empty, @@ -103,6 +104,7 @@ "attention", "baddbmm", "bilinear", + "fmod", "matmul", "mul", "narrow", diff --git a/python/infinicore/ops/fmod.py b/python/infinicore/ops/fmod.py new file mode 100644 index 000000000..e52be82cb --- /dev/null +++ b/python/infinicore/ops/fmod.py @@ -0,0 +1,11 @@ +from infinicore.lib import _infinicore +from infinicore.tensor import Tensor + + +def fmod(input, other, *, out=None): + if out is None: + return Tensor(_infinicore.fmod(input._underlying, other._underlying)) + + _infinicore.fmod_(out._underlying, input._underlying, other._underlying) + + return out diff --git a/src/infinicore/ops/fmod/fmod.cc b/src/infinicore/ops/fmod/fmod.cc new file mode 100644 index 000000000..d5aa0d0c7 --- /dev/null +++ b/src/infinicore/ops/fmod/fmod.cc @@ -0,0 +1,24 @@ +#include "infinicore/ops/fmod.hpp" + +namespace infinicore::op { + +common::OpDispatcher &Fmod::dispatcher() { + static common::OpDispatcher dispatcher_; + return dispatcher_; +}; + +void Fmod::execute(Tensor c, Tensor a, Tensor b) { + dispatcher().lookup(context::getDevice().getType())(c, a, b); +} + +Tensor fmod(Tensor a, Tensor b) { + auto c = Tensor::empty(a->shape(), a->dtype(), a->device()); + fmod_(c, a, b); + return c; +} + +void fmod_(Tensor c, Tensor a, Tensor b) { + Fmod::execute(c, a, b); +} + +} // namespace infinicore::op diff --git a/src/infinicore/ops/fmod/fmod_infiniop.cc b/src/infinicore/ops/fmod/fmod_infiniop.cc new file mode 100644 index 000000000..31434e325 --- /dev/null +++ b/src/infinicore/ops/fmod/fmod_infiniop.cc @@ -0,0 +1,53 @@ +#include "../../utils.hpp" +#include "infinicore/common/hash.hpp" +#include "infinicore/ops/common/cache.hpp" +#include "infinicore/ops/fmod.hpp" +#include + +namespace infinicore::op::fmod_impl::infiniop { + +thread_local common::OpCache caches( + 100, // capacity + [](infiniopFmodDescriptor_t &desc) { + if (desc != nullptr) { + INFINICORE_CHECK_ERROR(infiniopDestroyFmodDescriptor(desc)); + desc = nullptr; + } + }); + +void calculate(Tensor c, Tensor a, Tensor b) { + 
size_t seed = hash_combine(c, b, a); + + auto device_type = context::getDevice().getType(); + auto device_index = context::getDevice().getIndex(); + + auto &cache = caches.getCache(device_type, device_index); + + auto desc_opt = cache.get(seed); + infiniopFmodDescriptor_t desc = nullptr; + + if (!desc_opt) { + INFINICORE_CHECK_ERROR(infiniopCreateFmodDescriptor( + context::getInfiniopHandle(c->device()), &desc, + c->desc(), a->desc(), b->desc())); + cache.put(seed, desc); + } else { + desc = *desc_opt; + } + + size_t workspace_size = 0; + INFINICORE_CHECK_ERROR(infiniopGetFmodWorkspaceSize(desc, &workspace_size)); + std::shared_ptr workspace = context::allocateMemory(workspace_size); + + INFINICORE_CHECK_ERROR(infiniopFmod( + desc, workspace->data(), workspace_size, + c->data(), a->data(), b->data(), context::getStream())); + +} + +static bool registered = []() { + Fmod::dispatcher().registerAll(&calculate, false); + return true; +}(); + +} // namespace infinicore::op::fmod_impl::infiniop \ No newline at end of file diff --git a/src/infinicore/pybind11/ops.hpp b/src/infinicore/pybind11/ops.hpp index 9403c7405..0bd903922 100644 --- a/src/infinicore/pybind11/ops.hpp +++ b/src/infinicore/pybind11/ops.hpp @@ -8,6 +8,7 @@ #include "ops/bilinear.hpp" #include "ops/causal_softmax.hpp" #include "ops/embedding.hpp" +#include "ops/fmod.hpp" #include "ops/linear.hpp" #include "ops/matmul.hpp" #include "ops/mul.hpp" @@ -28,6 +29,7 @@ inline void bind(py::module &m) { bind_baddbmm(m); bind_bilinear(m); bind_causal_softmax(m); + bind_fmod(m); bind_random_sample(m); bind_linear(m); bind_matmul(m); diff --git a/src/infinicore/pybind11/ops/fmod.hpp b/src/infinicore/pybind11/ops/fmod.hpp new file mode 100644 index 000000000..97af57da2 --- /dev/null +++ b/src/infinicore/pybind11/ops/fmod.hpp @@ -0,0 +1,26 @@ +#pragma once + +#include + +#include "infinicore/ops/fmod.hpp" + +namespace py = pybind11; + +namespace infinicore::ops { + +inline void bind_fmod(py::module &m) { + m.def("fmod", + &op::fmod, + py::arg("a"), + py::arg("b"), + R"doc(Element-wise floating point remainder of division of two tensors.)doc"); + + m.def("fmod_", + &op::fmod_, + py::arg("c"), + py::arg("a"), + py::arg("b"), + R"doc(In-place element-wise floating point remainder of division of two tensors.)doc"); +} + +} // namespace infinicore::ops diff --git a/src/infiniop/ops/fmod/cpu/fmod_cpu.cc b/src/infiniop/ops/fmod/cpu/fmod_cpu.cc new file mode 100644 index 000000000..1114470a4 --- /dev/null +++ b/src/infiniop/ops/fmod/cpu/fmod_cpu.cc @@ -0,0 +1,53 @@ +#include "fmod_cpu.h" + +namespace op::fmod::cpu { + +Descriptor::~Descriptor() = default; + +infiniStatus_t Descriptor::create( + infiniopHandle_t handle_, + Descriptor **desc_ptr, + infiniopTensorDescriptor_t out_desc, + std::vector input_desc_vec) { + + auto handle = reinterpret_cast(handle_); + auto dtype = out_desc->dtype(); + + const auto &a_desc = input_desc_vec.at(0); + const auto &b_desc = input_desc_vec.at(1); + const auto &out_shape = out_desc->shape(); + const auto &a_shape = a_desc->shape(); + const auto &b_shape = b_desc->shape(); + + CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_F64, INFINI_DTYPE_BF16); + + CHECK_SAME_SHAPE(out_shape, a_shape, b_shape); + + // create CPU elementwise descriptor + CREATE_ELEMENTWISE_CPU_DESCRIPTOR(handle, dtype, out_desc, input_desc_vec); + + return INFINI_STATUS_SUCCESS; +} + +infiniStatus_t Descriptor::calculate( + void *workspace, + size_t workspace_size, + void *output, + std::vector inputs, + void *stream) const { + 
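+    // Clarifying note (not in the original patch): _dtype was already validated in
+    // create(); each case below presumably instantiates the templated elementwise
+    // CPU kernel with FmodOp (which wraps std::fmod) for that concrete type.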
+ switch (_dtype) { + case INFINI_DTYPE_F16: + return _device_info->calculate(_info, output, inputs, stream); + case INFINI_DTYPE_F32: + return _device_info->calculate(_info, output, inputs, stream); + case INFINI_DTYPE_F64: + return _device_info->calculate(_info, output, inputs, stream); + case INFINI_DTYPE_BF16: + return _device_info->calculate(_info, output, inputs, stream); + default: + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + return INFINI_STATUS_SUCCESS; +} +} // namespace op::mul::cpu diff --git a/src/infiniop/ops/fmod/cpu/fmod_cpu.h b/src/infiniop/ops/fmod/cpu/fmod_cpu.h new file mode 100644 index 000000000..54af25540 --- /dev/null +++ b/src/infiniop/ops/fmod/cpu/fmod_cpu.h @@ -0,0 +1,19 @@ +#ifndef _FMOD_CPU_H__ +#define _FMOD_CPU_H__ + +#include "../../../elementwise/cpu/elementwise_cpu.h" + +ELEMENTWISE_DESCRIPTOR(fmod, cpu) + +namespace op::fmod::cpu { +typedef struct FmodOp { +public: + static constexpr size_t num_inputs = 2; + template + T operator()(const T &a, const T &b) const { + return std::fmod(a, b); + } +} FmodOp; +} // namespace op::fmod::cpu + +#endif // _FMOD_CPU_H__ \ No newline at end of file diff --git a/src/infiniop/ops/fmod/cuda/kernel.cuh b/src/infiniop/ops/fmod/cuda/kernel.cuh new file mode 100644 index 000000000..6e30ed25e --- /dev/null +++ b/src/infiniop/ops/fmod/cuda/kernel.cuh @@ -0,0 +1,48 @@ +#ifndef __FMOD_CUDA_H__ +#define __FMOD_CUDA_H__ + +namespace op::fmod::cuda { +typedef struct FmodOp { + static constexpr size_t num_inputs = 2; + template + __device__ __forceinline__ T operator()(const T &a, const T &b) const { + // fmod(a, b) = a - b * trunc(a / b) + if constexpr (std::is_same_v) { + // 对于 half2,转换为 float 计算后再转回 + float2 af = __half22float2(a); + float2 bf = __half22float2(b); + float2 result; + result.x = ::fmodf(af.x, bf.x); + result.y = ::fmodf(af.y, bf.y); + return __float22half2_rn(result); + } else if constexpr (std::is_same_v) { + // 对于 bfloat162,转换为 float 计算后再转回 + float af_low = __bfloat162float(__low2bfloat16(a)); + float af_high = __bfloat162float(__high2bfloat16(a)); + float bf_low = __bfloat162float(__low2bfloat16(b)); + float bf_high = __bfloat162float(__high2bfloat16(b)); + return __floats2bfloat162_rn(::fmodf(af_low, bf_low), ::fmodf(af_high, bf_high)); + } else if constexpr (std::is_same_v) { + // 对于 half,转换为 float 计算后再转回 + float af = __half2float(a); + float bf = __half2float(b); + return __float2half(::fmodf(af, bf)); + } else if constexpr (std::is_same_v) { + // 对于 bfloat16,转换为 float 计算后再转回 + float af = __bfloat162float(a); + float bf = __bfloat162float(b); + return __float2bfloat16(::fmodf(af, bf)); + } else if constexpr (std::is_same_v) { + return ::fmodf(a, b); + } else if constexpr (std::is_same_v) { + return ::fmod(a, b); + } else { + // 整数类型使用 % 运算符 + return a % b; + } + } +} FmodOp; + +} // namespace op::fmod::cuda + +#endif // __FMOD_CUDA_H__ \ No newline at end of file diff --git a/src/infiniop/ops/fmod/nvidia/fmod_nvidia.cu b/src/infiniop/ops/fmod/nvidia/fmod_nvidia.cu new file mode 100644 index 000000000..5f834ef95 --- /dev/null +++ b/src/infiniop/ops/fmod/nvidia/fmod_nvidia.cu @@ -0,0 +1,59 @@ +#include "../../../elementwise/nvidia/elementwise_nvidia.cuh" +#include "../cuda/kernel.cuh" +#include "fmod_nvidia.cuh" + +namespace op::fmod::nvidia { + +Descriptor::~Descriptor() = default; + +infiniStatus_t Descriptor::create( + infiniopHandle_t handle_, + Descriptor **desc_ptr, + infiniopTensorDescriptor_t out_desc, + std::vector input_desc_vec) { + + auto handle = reinterpret_cast(handle_); + auto dtype = 
out_desc->dtype(); + + const auto &a_desc = input_desc_vec.at(0); + const auto &b_desc = input_desc_vec.at(1); + const auto &c_shape = out_desc->shape(); + const auto &a_shape = a_desc->shape(); + const auto &b_shape = b_desc->shape(); + + CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_F64, INFINI_DTYPE_BF16); + + CHECK_SAME_SHAPE(c_shape, a_shape, b_shape); + + CREATE_ELEMENTWISE_CUDA_DESCRIPTOR(handle, dtype, out_desc, input_desc_vec) + + return INFINI_STATUS_SUCCESS; +} + +infiniStatus_t Descriptor::calculate( + void *workspace, + size_t workspace_size, + void *output, + std::vector inputs, + void *stream) const { + + if (workspace_size < _workspace_size){ + return INFINI_STATUS_INSUFFICIENT_WORKSPACE; + } + + switch (_dtype){ + case INFINI_DTYPE_F16: + return _device_info->calculate<256, cuda::FmodOp, half>(_info, workspace, output, inputs, stream); + case INFINI_DTYPE_F32: + return _device_info->calculate<256, cuda::FmodOp, float>(_info, workspace, output, inputs, stream); + case INFINI_DTYPE_F64: + return _device_info->calculate<256, cuda::FmodOp, double>(_info, workspace, output, inputs, stream); + case INFINI_DTYPE_BF16: + return _device_info->calculate<256, cuda::FmodOp, cuda_bfloat16>(_info, workspace, output, inputs, stream); + default: + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + + return INFINI_STATUS_SUCCESS; +} +} // namespace op::fmod::nvidia \ No newline at end of file diff --git a/src/infiniop/ops/fmod/nvidia/fmod_nvidia.cuh b/src/infiniop/ops/fmod/nvidia/fmod_nvidia.cuh new file mode 100644 index 000000000..e40d0088d --- /dev/null +++ b/src/infiniop/ops/fmod/nvidia/fmod_nvidia.cuh @@ -0,0 +1,8 @@ +#ifndef __MUL_CUDA_API_H__ +#define __MUL_CUDA_API_H__ + +#include "../../../elementwise/nvidia/elementwise_nvidia_api.cuh" + +ELEMENTWISE_DESCRIPTOR(fmod, nvidia) + +#endif // __MUL_CUDA_API_H__ \ No newline at end of file diff --git a/src/infiniop/ops/fmod/operator.cc b/src/infiniop/ops/fmod/operator.cc new file mode 100644 index 000000000..08fe0ef3c --- /dev/null +++ b/src/infiniop/ops/fmod/operator.cc @@ -0,0 +1,163 @@ +#include "../../operator.h" +#include "../../handle.h" +#include "infiniop/ops/fmod.h" + +#ifdef ENABLE_CPU_API +#include "cpu/fmod_cpu.h" +#endif +#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API) +#include "nvidia/fmod_nvidia.cuh" +#endif +#ifdef ENABLE_METAX_API +#include "metax/fmod_metax.h" +#endif +#ifdef ENABLE_MOORE_API +#include "moore/fmod_moore.h" +#endif + +__C infiniStatus_t infiniopCreateFmodDescriptor( + infiniopHandle_t handle, + infiniopFmodDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t c_desc, + infiniopTensorDescriptor_t a_desc, + infiniopTensorDescriptor_t b_desc) { +#define CREATE(CASE, NAMESPACE) \ + case CASE: \ + return op::fmod::NAMESPACE::Descriptor::create( \ + handle, \ + reinterpret_cast(desc_ptr), \ + c_desc, \ + {a_desc, \ + b_desc}) + + switch (handle->device) { + +#ifdef ENABLE_CPU_API + CREATE(INFINI_DEVICE_CPU, cpu); +#endif +#ifdef ENABLE_NVIDIA_API + CREATE(INFINI_DEVICE_NVIDIA, nvidia); +#endif +#ifdef ENABLE_ILUVATAR_API + CREATE(INFINI_DEVICE_ILUVATAR, nvidia); +#endif +#ifdef ENABLE_QY_API + CREATE(INFINI_DEVICE_QY, nvidia); +#endif +#ifdef ENABLE_METAX_API + CREATE(INFINI_DEVICE_METAX, metax); +#endif +#ifdef ENABLE_MOORE_API + CREATE(INFINI_DEVICE_MOORE, moore); +#endif + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } + +#undef CREATE +} + +__C infiniStatus_t infiniopGetFmodWorkspaceSize(infiniopFmodDescriptor_t desc, size_t *size) { 
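+    // Note (not in the original patch): same per-device dispatch boilerplate as the
+    // create/calculate/destroy entry points — cast the opaque descriptor to the
+    // backend's Descriptor type and report its workspaceSize(), with each branch
+    // guarded by the corresponding ENABLE_*_API macro.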
+ +#define GET(CASE, NAMESPACE) \ + case CASE: \ + *size = reinterpret_cast(desc)->workspaceSize(); \ + return INFINI_STATUS_SUCCESS; + + switch (desc->device_type) { +#ifdef ENABLE_CPU_API + GET(INFINI_DEVICE_CPU, cpu); +#endif +#ifdef ENABLE_NVIDIA_API + GET(INFINI_DEVICE_NVIDIA, nvidia); +#endif +#ifdef ENABLE_ILUVATAR_API + GET(INFINI_DEVICE_ILUVATAR, nvidia); +#endif +#ifdef ENABLE_QY_API + GET(INFINI_DEVICE_QY, nvidia); +#endif +#ifdef ENABLE_METAX_API + GET(INFINI_DEVICE_METAX, metax); +#endif +#ifdef ENABLE_MOORE_API + GET(INFINI_DEVICE_MOORE, moore); +#endif + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } +#undef GET + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; +} + +__C infiniStatus_t infiniopFmod( + infiniopFmodDescriptor_t desc, + void *workspace, + size_t workspace_size, + void *c, + const void *a, + const void *b, + void *stream) { +#define CALCULATE(CASE, NAMESPACE) \ + case CASE: \ + return reinterpret_cast(desc) \ + ->calculate(workspace, workspace_size, c, {a, b}, stream) + + switch (desc->device_type) { +#ifdef ENABLE_CPU_API + CALCULATE(INFINI_DEVICE_CPU, cpu); +#endif +#ifdef ENABLE_NVIDIA_API + CALCULATE(INFINI_DEVICE_NVIDIA, nvidia); +#endif +#ifdef ENABLE_ILUVATAR_API + CALCULATE(INFINI_DEVICE_ILUVATAR, nvidia); +#endif +#ifdef ENABLE_QY_API + CALCULATE(INFINI_DEVICE_QY, nvidia); +#endif +#ifdef ENABLE_METAX_API + CALCULATE(INFINI_DEVICE_METAX, metax); +#endif +#ifdef ENABLE_MOORE_API + CALCULATE(INFINI_DEVICE_MOORE, moore); +#endif + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } +#undef CALCULATE +} + +__C infiniStatus_t infiniopDestroyFmodDescriptor(infiniopFmodDescriptor_t desc) { + +#define GET(CASE, NAMESPACE) \ + case CASE: \ + delete reinterpret_cast(desc); \ + return INFINI_STATUS_SUCCESS; + + switch (desc->device_type) { +#ifdef ENABLE_CPU_API + GET(INFINI_DEVICE_CPU, cpu); +#endif +#ifdef ENABLE_NVIDIA_API + GET(INFINI_DEVICE_NVIDIA, nvidia); +#endif +#ifdef ENABLE_ILUVATAR_API + GET(INFINI_DEVICE_ILUVATAR, nvidia); +#endif +#ifdef ENABLE_QY_API + GET(INFINI_DEVICE_QY, nvidia); +#endif +#ifdef ENABLE_METAX_API + GET(INFINI_DEVICE_METAX, metax); +#endif +#ifdef ENABLE_MOORE_API + GET(INFINI_DEVICE_MOORE, moore); +#endif + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } +#undef DELETE +} + + \ No newline at end of file diff --git a/test/infinicore/ops/fmod.py b/test/infinicore/ops/fmod.py index 97762bda8..9e86a6377 100644 --- a/test/infinicore/ops/fmod.py +++ b/test/infinicore/ops/fmod.py @@ -99,9 +99,9 @@ def get_test_cases(self): def torch_operator(self, *args, **kwargs): return torch.fmod(*args, **kwargs) - # def infinicore_operator(self, *args, **kwargs): - # """InfiniCore implementation (operator not yet available).""" - # return infinicore.fmod(*args, **kwargs) + def infinicore_operator(self, *args, **kwargs): + """InfiniCore implementation (operator not yet available).""" + return infinicore.fmod(*args, **kwargs) def main(): From eaee08b52c72e29124ec79e661280758518d3984 Mon Sep 17 00:00:00 2001 From: Kinorw Date: Wed, 3 Dec 2025 17:36:16 +0800 Subject: [PATCH 15/32] moore and metax test --- src/infiniop/ops/fmod/metax/fmod_metax.h | 8 +++ src/infiniop/ops/fmod/metax/mul_metax.maca | 61 +++++++++++++++++++++ src/infiniop/ops/fmod/moore/fmod_moore.h | 8 +++ src/infiniop/ops/fmod/moore/fmod_moore.mu | 63 ++++++++++++++++++++++ 4 files changed, 140 insertions(+) create mode 100644 src/infiniop/ops/fmod/metax/fmod_metax.h create mode 100644 src/infiniop/ops/fmod/metax/mul_metax.maca create mode 100644 
src/infiniop/ops/fmod/moore/fmod_moore.h create mode 100644 src/infiniop/ops/fmod/moore/fmod_moore.mu diff --git a/src/infiniop/ops/fmod/metax/fmod_metax.h b/src/infiniop/ops/fmod/metax/fmod_metax.h new file mode 100644 index 000000000..ad5769231 --- /dev/null +++ b/src/infiniop/ops/fmod/metax/fmod_metax.h @@ -0,0 +1,8 @@ +#ifndef __FMOD_METAX_API_H__ +#define __FMOD_METAX_API_H__ + +#include "../../../elementwise/metax/elementwise_metax_api.h" + +ELEMENTWISE_DESCRIPTOR(fmod, metax) + +#endif // __FMOD_METAX_API_H__ diff --git a/src/infiniop/ops/fmod/metax/mul_metax.maca b/src/infiniop/ops/fmod/metax/mul_metax.maca new file mode 100644 index 000000000..c9d54ad62 --- /dev/null +++ b/src/infiniop/ops/fmod/metax/mul_metax.maca @@ -0,0 +1,61 @@ +#include "../../../elementwise/metax/elementwise_metax.h" + +#include "../cuda/kernel.cuh" + +#include "fmod_metax.h" + +namespace op::fmod::metax { + +Descriptor::~Descriptor() = default; + +infiniStatus_t Descriptor::create( + infiniopHandle_t handle_, + Descriptor **desc_ptr, + infiniopTensorDescriptor_t out_desc, + std::vector input_desc_vec) { + + auto handle = reinterpret_cast(handle_); + auto dtype = out_desc->dtype(); + + const auto &a_desc = input_desc_vec.at(0); + const auto &b_desc = input_desc_vec.at(1); + const auto &c_shape = out_desc->shape(); + const auto &a_shape = a_desc->shape(); + const auto &b_shape = b_desc->shape(); + + CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_F64, INFINI_DTYPE_BF16); + + CHECK_SAME_SHAPE(c_shape, a_shape, b_shape); + + CREATE_ELEMENTWISE_METAX_DESCRIPTOR(handle, dtype, out_desc, input_desc_vec) + + return INFINI_STATUS_SUCCESS; +} + +infiniStatus_t Descriptor::calculate( + void *workspace, + size_t workspace_size, + void *output, + std::vector inputs, + void *stream) const { + + if (workspace_size < _workspace_size) { + return INFINI_STATUS_INSUFFICIENT_WORKSPACE; + } + + switch (_dtype) { + case INFINI_DTYPE_F16: + return _device_info->calculate<256, cuda::FmodOp, half>(_info, workspace, output, inputs, stream); + case INFINI_DTYPE_F32: + return _device_info->calculate<256, cuda::FmodOp, float>(_info, workspace, output, inputs, stream); + case INFINI_DTYPE_F64: + return _device_info->calculate<256, cuda::FmodOp, double>(_info, workspace, output, inputs, stream); + case INFINI_DTYPE_BF16: + return _device_info->calculate<256, cuda::FmodOp, cuda_bfloat16>(_info, workspace, output, inputs, stream); + default: + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + + return INFINI_STATUS_SUCCESS; +} +} // namespace op::fmod::metax diff --git a/src/infiniop/ops/fmod/moore/fmod_moore.h b/src/infiniop/ops/fmod/moore/fmod_moore.h new file mode 100644 index 000000000..b24c337a8 --- /dev/null +++ b/src/infiniop/ops/fmod/moore/fmod_moore.h @@ -0,0 +1,8 @@ +#ifndef __FMOD_MOORE_API_H__ +#define __FMOD_MOORE_API_H__ + +#include "../../../elementwise/moore/elementwise_moore_api.h" + +ELEMENTWISE_DESCRIPTOR(fmod, moore) + +#endif // __FMOD_MOORE_API_H__ \ No newline at end of file diff --git a/src/infiniop/ops/fmod/moore/fmod_moore.mu b/src/infiniop/ops/fmod/moore/fmod_moore.mu new file mode 100644 index 000000000..0d1f9af4b --- /dev/null +++ b/src/infiniop/ops/fmod/moore/fmod_moore.mu @@ -0,0 +1,63 @@ +#include "fmod_moore.h" + +#include "../../../elementwise/moore/elementwise_moore.h" + +#include "../cuda/kernel.cuh" + +namespace op::fmod::moore { + +Descriptor::~Descriptor() = default; + +infiniStatus_t Descriptor::create( + infiniopHandle_t handle_, + Descriptor **desc_ptr, + infiniopTensorDescriptor_t 
out_desc, + std::vector input_desc_vec) { + + auto handle = reinterpret_cast(handle_); + auto dtype = out_desc->dtype(); + + const auto &a_desc = input_desc_vec.at(0); + const auto &b_desc = input_desc_vec.at(1); + const auto &c_shape = out_desc->shape(); + const auto &a_shape = a_desc->shape(); + const auto &b_shape = b_desc->shape(); + + CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_F64, INFINI_DTYPE_BF16); + + CHECK_SAME_SHAPE(c_shape, a_shape, b_shape); + + // create MOORE elementwise descriptor + CREATE_ELEMENTWISE_MOORE_DESCRIPTOR(handle, dtype, out_desc, input_desc_vec) + + return INFINI_STATUS_SUCCESS; +} + +infiniStatus_t Descriptor::calculate( + void *workspace, + size_t workspace_size, + void *output, + std::vector inputs, + void *stream) const { + + if (workspace_size < _workspace_size) { + return INFINI_STATUS_INSUFFICIENT_WORKSPACE; + } + + switch (_dtype) { + case INFINI_DTYPE_F16: + return _device_info->calculate<256, moore::FmodOp, half>(_info, workspace, output, inputs, stream); + case INFINI_DTYPE_BF16: + return _device_info->calculate<256, moore::FmodOp, cuda_bfloat16>(_info, workspace, output, inputs, stream); + case INFINI_DTYPE_F32: + return _device_info->calculate<256, moore::FmodOp, float>(_info, workspace, output, inputs, stream); + case INFINI_DTYPE_F64: + return _device_info->calculate<256, moore::FmodOp, double>(_info, workspace, output, inputs, stream); + default: + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + + return INFINI_STATUS_SUCCESS; +} + +} // namespace op::fmod::moore \ No newline at end of file From 6ea950ef66f86f4c6fe70c56a424c3d22f4bf3fa Mon Sep 17 00:00:00 2001 From: littleotherut Date: Thu, 4 Dec 2025 20:16:58 +0800 Subject: [PATCH 16/32] =?UTF-8?q?fix:=20=E5=9F=BA=E4=BA=8Ematmul=E7=9A=84?= =?UTF-8?q?=E6=9B=B4=E6=96=B0=E8=B0=83=E6=95=B4bilinear?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- sbatch.sh | 4 ++-- src/infinicore/ops/bilinear/bilinear.cc | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/sbatch.sh b/sbatch.sh index b1b2b2ddf..bd1204adc 100755 --- a/sbatch.sh +++ b/sbatch.sh @@ -1,3 +1,3 @@ -srun --partition=nvidia --nodes=1 --gres=gpu:nvidia:2 --ntasks=1 --cpus-per-task=16 --mem=64G --time=00:20:00 \ +srun --partition=nvidia --nodes=1 --gres=gpu:nvidia:2 --ntasks=1 --cpus-per-task=16 --mem=64G --time=00:02:00 \ --output=output_%j.log \ - python test/infinicore/ops/baddbmm.py --nvidia --verbose --bench --debug + python test/infinicore/ops/fmod.py --nvidia --verbose --bench --debug diff --git a/src/infinicore/ops/bilinear/bilinear.cc b/src/infinicore/ops/bilinear/bilinear.cc index dfe97fa87..1fcea4ff3 100644 --- a/src/infinicore/ops/bilinear/bilinear.cc +++ b/src/infinicore/ops/bilinear/bilinear.cc @@ -23,13 +23,13 @@ Tensor bilinear(Tensor x1, Tensor x2, Tensor weight, std::optional bias) Tensor weight_permuted_cont = weight_permuted->contiguous(); Tensor weight_matrix = weight_permuted_cont->view({in1_features, out_features * in2_features}); - Tensor intermediate = matmul(x1_cont, weight_matrix); + Tensor intermediate = matmul(x1_cont, weight_matrix,1.0f); Tensor intermediate_3d = intermediate->view({batch_size, out_features, in2_features}); Tensor x2_col = x2_cont->view({batch_size, in2_features, 1}); - Tensor out_3d = matmul(intermediate_3d, x2_col); + Tensor out_3d = matmul(intermediate_3d, x2_col,1.0f); Tensor out = out_3d->view({batch_size, out_features}); if (bias) { From 06cfc357c768d99f4feea9e427bc991873f3736d Mon Sep 17 00:00:00 2001 
From: littleotherut Date: Thu, 4 Dec 2025 20:19:29 +0800 Subject: [PATCH 17/32] moore_test --- sbatch.sh | 4 ++-- src/infiniop/ops/fmod/moore/fmod_moore.mu | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/sbatch.sh b/sbatch.sh index bd1204adc..fa3ff1e47 100755 --- a/sbatch.sh +++ b/sbatch.sh @@ -1,3 +1,3 @@ -srun --partition=nvidia --nodes=1 --gres=gpu:nvidia:2 --ntasks=1 --cpus-per-task=16 --mem=64G --time=00:02:00 \ +srun --partition=mt --nodes=1 --gres=gpu:mt:2 --ntasks=1 --cpus-per-task=16 --mem=256G --time=00:01:00 \ --output=output_%j.log \ - python test/infinicore/ops/fmod.py --nvidia --verbose --bench --debug + python test/infinicore/ops/bilinear.py --moore --verbose --bench --debug diff --git a/src/infiniop/ops/fmod/moore/fmod_moore.mu b/src/infiniop/ops/fmod/moore/fmod_moore.mu index 0d1f9af4b..0c37da459 100644 --- a/src/infiniop/ops/fmod/moore/fmod_moore.mu +++ b/src/infiniop/ops/fmod/moore/fmod_moore.mu @@ -46,13 +46,13 @@ infiniStatus_t Descriptor::calculate( switch (_dtype) { case INFINI_DTYPE_F16: - return _device_info->calculate<256, moore::FmodOp, half>(_info, workspace, output, inputs, stream); + return _device_info->calculate<256, cuda::FmodOp, half>(_info, workspace, output, inputs, stream); case INFINI_DTYPE_BF16: - return _device_info->calculate<256, moore::FmodOp, cuda_bfloat16>(_info, workspace, output, inputs, stream); + return _device_info->calculate<256, cuda::FmodOp, cuda_bfloat16>(_info, workspace, output, inputs, stream); case INFINI_DTYPE_F32: - return _device_info->calculate<256, moore::FmodOp, float>(_info, workspace, output, inputs, stream); + return _device_info->calculate<256, cuda::FmodOp, float>(_info, workspace, output, inputs, stream); case INFINI_DTYPE_F64: - return _device_info->calculate<256, moore::FmodOp, double>(_info, workspace, output, inputs, stream); + return _device_info->calculate<256, cuda::FmodOp, double>(_info, workspace, output, inputs, stream); default: return INFINI_STATUS_BAD_TENSOR_DTYPE; } From c1d7ded3c4a421b673ee8f96c3b699948b9fc848 Mon Sep 17 00:00:00 2001 From: littleotherut Date: Fri, 5 Dec 2025 16:26:20 +0800 Subject: [PATCH 18/32] bilinear --- src/infinicore/ops/bilinear/bilinear.cc | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/infinicore/ops/bilinear/bilinear.cc b/src/infinicore/ops/bilinear/bilinear.cc index 1fcea4ff3..2efeb7763 100644 --- a/src/infinicore/ops/bilinear/bilinear.cc +++ b/src/infinicore/ops/bilinear/bilinear.cc @@ -23,13 +23,15 @@ Tensor bilinear(Tensor x1, Tensor x2, Tensor weight, std::optional bias) Tensor weight_permuted_cont = weight_permuted->contiguous(); Tensor weight_matrix = weight_permuted_cont->view({in1_features, out_features * in2_features}); - Tensor intermediate = matmul(x1_cont, weight_matrix,1.0f); + Tensor intermediate = matmul(x1_cont, weight_matrix, 1.0f); Tensor intermediate_3d = intermediate->view({batch_size, out_features, in2_features}); - Tensor x2_col = x2_cont->view({batch_size, in2_features, 1}); + Tensor x2_row = x2_cont->view({batch_size, 1, in2_features}); + Tensor intermediate_perm = intermediate_3d->permute({0, 2, 1}); + // Multiply as (1 x in2) @ (in2 x out) to keep ld >= k on all backends. 
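+    // For reference (comment added for this write-up, not in the original patch):
+    // overall this computes out[b, o] = sum over i, j of x1[b, i] * W[o, i, j] * x2[b, j]
+    // (plus bias), the same contraction torch.nn.functional.bilinear performs.
+    // The first matmul above contracts the in1_features axis into W; this batched
+    // matmul contracts the remaining in2_features axis against x2.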
+ Tensor out_3d = matmul(x2_row, intermediate_perm, 1.0f); - Tensor out_3d = matmul(intermediate_3d, x2_col,1.0f); Tensor out = out_3d->view({batch_size, out_features}); if (bias) { From 3b35a33bad0dab2e2d3674ed9e53f5a200fe9839 Mon Sep 17 00:00:00 2001 From: Kinorw Date: Sat, 6 Dec 2025 16:20:14 +0800 Subject: [PATCH 19/32] =?UTF-8?q?=E5=A4=8D=E5=8E=9F,moore=20lda=3D1?= =?UTF-8?q?=E6=9C=AA=E8=A7=A3=E5=86=B3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/infinicore/ops/bilinear/bilinear.cc | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/src/infinicore/ops/bilinear/bilinear.cc b/src/infinicore/ops/bilinear/bilinear.cc index 2efeb7763..ec51a79c1 100644 --- a/src/infinicore/ops/bilinear/bilinear.cc +++ b/src/infinicore/ops/bilinear/bilinear.cc @@ -5,7 +5,6 @@ namespace infinicore::op { - Tensor bilinear(Tensor x1, Tensor x2, Tensor weight, std::optional bias) { size_t batch_size = x1->shape()[0]; @@ -27,11 +26,9 @@ Tensor bilinear(Tensor x1, Tensor x2, Tensor weight, std::optional bias) Tensor intermediate_3d = intermediate->view({batch_size, out_features, in2_features}); - Tensor x2_row = x2_cont->view({batch_size, 1, in2_features}); - Tensor intermediate_perm = intermediate_3d->permute({0, 2, 1}); - // Multiply as (1 x in2) @ (in2 x out) to keep ld >= k on all backends. - Tensor out_3d = matmul(x2_row, intermediate_perm, 1.0f); + Tensor x2_col = x2_cont->view({batch_size, in2_features, 1}); + Tensor out_3d = matmul(intermediate_3d, x2_col, 1.0f); Tensor out = out_3d->view({batch_size, out_features}); if (bias) { From 7d6dec6a3f5119702b9d25cdde39556a040e2641 Mon Sep 17 00:00:00 2001 From: Kinorw Date: Sat, 6 Dec 2025 16:52:31 +0800 Subject: [PATCH 20/32] =?UTF-8?q?=E8=B0=83=E6=95=B4=E4=B8=BA=E8=A1=8C?= =?UTF-8?q?=E5=90=91=E9=87=8F(=E6=9B=B4=E7=AC=A6=E5=90=88=E6=95=B0?= =?UTF-8?q?=E5=AD=A6=E7=9B=B4=E8=A7=89=E4=B8=8E=E5=BA=95=E5=B1=82=E8=A6=81?= =?UTF-8?q?=E6=B1=82)=E7=9A=84=E5=AE=9E=E7=8E=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/infinicore/ops/bilinear/bilinear.cc | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/infinicore/ops/bilinear/bilinear.cc b/src/infinicore/ops/bilinear/bilinear.cc index ec51a79c1..46020d1c0 100644 --- a/src/infinicore/ops/bilinear/bilinear.cc +++ b/src/infinicore/ops/bilinear/bilinear.cc @@ -26,9 +26,11 @@ Tensor bilinear(Tensor x1, Tensor x2, Tensor weight, std::optional bias) Tensor intermediate_3d = intermediate->view({batch_size, out_features, in2_features}); - Tensor x2_col = x2_cont->view({batch_size, in2_features, 1}); + Tensor intermediate_3d_trans = intermediate_3d->permute({0, 2, 1})->contiguous(); - Tensor out_3d = matmul(intermediate_3d, x2_col, 1.0f); + Tensor x2_row = x2_cont->view({batch_size, 1, in2_features}); + + Tensor out_3d = matmul(x2_row, intermediate_3d_trans, 1.0f); Tensor out = out_3d->view({batch_size, out_features}); if (bias) { From e6b921ac1b2af2d5a86131bfa2aa54e5e1784409 Mon Sep 17 00:00:00 2001 From: Kinorw Date: Sat, 6 Dec 2025 17:55:51 +0800 Subject: [PATCH 21/32] =?UTF-8?q?fix:=20=E5=AF=B9=E6=89=80=E6=9C=89?= =?UTF-8?q?=E5=BC=A0=E9=87=8F=E6=B7=BB=E5=8A=A0=E6=A3=80=E6=B5=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/infinicore/ops/bilinear/bilinear.cc | 71 +++++++++++++++++++------ 1 file changed, 55 insertions(+), 16 deletions(-) diff --git a/src/infinicore/ops/bilinear/bilinear.cc 
b/src/infinicore/ops/bilinear/bilinear.cc index 46020d1c0..87748240c 100644 --- a/src/infinicore/ops/bilinear/bilinear.cc +++ b/src/infinicore/ops/bilinear/bilinear.cc @@ -5,36 +5,76 @@ namespace infinicore::op { +namespace { + +inline bool is_gemm_compatible_3d(const Tensor &t) { + if (t->ndim() != 3) return false; + + const auto batch = t->shape()[0]; + const auto rows = t->shape()[1]; + const auto cols = t->shape()[2]; + const auto bs = t->stride(0); + const auto rs = t->stride(1); + const auto cs = t->stride(2); + + if (rs != 1 && cs != 1) return false; + + if (cs == 1) { + if (rs < static_cast(cols)) return false; + } else { + if (cs < static_cast(rows)) return false; + } + + if (batch > 1 && bs == 0) return false; + + return true; +} + +inline Tensor ensure_gemm_compatible(const Tensor &t) { + if (t->ndim() == 2) { + return t->is_contiguous() ? t : rearrange(t); + } else if (t->ndim() == 3) { + return is_gemm_compatible_3d(t) ? t : rearrange(t); + } + return t->is_contiguous() ? t : rearrange(t); +} + +} // anonymous namespace + Tensor bilinear(Tensor x1, Tensor x2, Tensor weight, std::optional bias) { - size_t batch_size = x1->shape()[0]; - size_t in1_features = x1->shape()[1]; - size_t in2_features = x2->shape()[1]; - size_t out_features = weight->shape()[0]; + const size_t batch_size = x1->shape()[0]; + const size_t in1_features = x1->shape()[1]; + const size_t in2_features = x2->shape()[1]; + const size_t out_features = weight->shape()[0]; - auto dtype = x1->dtype(); - auto device = x1->device(); - Tensor x1_cont = x1->is_contiguous() ? x1 : x1->contiguous(); - Tensor x2_cont = x2->is_contiguous() ? x2 : x2->contiguous(); + Tensor x1_compat = ensure_gemm_compatible(x1); + Tensor x2_compat = ensure_gemm_compatible(x2); Tensor weight_cont = weight->is_contiguous() ? weight : weight->contiguous(); Tensor weight_permuted = weight_cont->permute({1, 0, 2}); - Tensor weight_permuted_cont = weight_permuted->contiguous(); + Tensor weight_permuted_cont = weight_permuted->is_contiguous() + ? 
weight_permuted + : weight_permuted->contiguous(); Tensor weight_matrix = weight_permuted_cont->view({in1_features, out_features * in2_features}); - Tensor intermediate = matmul(x1_cont, weight_matrix, 1.0f); + Tensor intermediate = matmul(x1_compat, weight_matrix, 1.0f); Tensor intermediate_3d = intermediate->view({batch_size, out_features, in2_features}); + Tensor intermediate_transposed = intermediate_3d->permute({0, 2, 1}); + Tensor intermediate_compat = ensure_gemm_compatible(intermediate_transposed); - Tensor intermediate_3d_trans = intermediate_3d->permute({0, 2, 1})->contiguous(); - - Tensor x2_row = x2_cont->view({batch_size, 1, in2_features}); + Tensor x2_row = x2_compat->view({batch_size, 1, in2_features}); + Tensor x2_row_compat = ensure_gemm_compatible(x2_row); - Tensor out_3d = matmul(x2_row, intermediate_3d_trans, 1.0f); + Tensor out_3d = matmul(x2_row_compat, intermediate_compat, 1.0f); Tensor out = out_3d->view({batch_size, out_features}); if (bias) { - Tensor bias_broadcast = (*bias)->as_strided({batch_size, out_features}, {0, (*bias)->strides()[0]}); + Tensor bias_broadcast = (*bias)->as_strided( + {batch_size, out_features}, + {0, (*bias)->strides()[0]} + ); out = add(out, bias_broadcast); } return out; @@ -42,7 +82,6 @@ Tensor bilinear(Tensor x1, Tensor x2, Tensor weight, std::optional bias) void bilinear_(Tensor out, Tensor x1, Tensor x2, Tensor weight, std::optional bias) { Tensor result = bilinear(x1, x2, weight, bias); - // Copy result to out rearrange_(out, result); } From ba9b48c51457169a157312cb819c2a9f45e3b2b2 Mon Sep 17 00:00:00 2001 From: littleotherut Date: Sun, 7 Dec 2025 12:48:28 +0800 Subject: [PATCH 22/32] =?UTF-8?q?gemm=E6=B7=BB=E5=8A=A0TF32=E5=BC=80?= =?UTF-8?q?=E5=85=B3=E6=8E=A5=E5=8F=A3,bilinear=E7=A6=81=E7=94=A8?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/infinicore/ops/bilinear/bilinear.cc | 22 +++++++++++++++++++++ src/infiniop/ops/gemm/nvidia/gemm_nvidia.cu | 11 ++++++++++- 2 files changed, 32 insertions(+), 1 deletion(-) diff --git a/src/infinicore/ops/bilinear/bilinear.cc b/src/infinicore/ops/bilinear/bilinear.cc index 87748240c..6f2e2be2d 100644 --- a/src/infinicore/ops/bilinear/bilinear.cc +++ b/src/infinicore/ops/bilinear/bilinear.cc @@ -3,9 +3,30 @@ #include "infinicore/ops/add.hpp" #include "infinicore/ops/rearrange.hpp" +#ifdef ENABLE_NVIDIA_API +namespace op::gemm::nvidia { + void set_tf32_enabled(bool); +} +#endif + namespace infinicore::op { namespace { +// RAII 守卫:作用域内禁用 TF32 +struct ScopedTF32Disable { + ScopedTF32Disable() { +#ifdef ENABLE_NVIDIA_API + // 实际项目中建议添加检查,仅在 NVIDIA 设备上调用 + // 使用 ::op 强制从全局命名空间查找,避免被当前的 infinicore::op 遮蔽 + ::op::gemm::nvidia::set_tf32_enabled(false); +#endif + } + ~ScopedTF32Disable() { +#ifdef ENABLE_NVIDIA_API + ::op::gemm::nvidia::set_tf32_enabled(true); +#endif + } +}; inline bool is_gemm_compatible_3d(const Tensor &t) { if (t->ndim() != 3) return false; @@ -42,6 +63,7 @@ inline Tensor ensure_gemm_compatible(const Tensor &t) { } // anonymous namespace Tensor bilinear(Tensor x1, Tensor x2, Tensor weight, std::optional bias) { + ScopedTF32Disable tf32_guard; const size_t batch_size = x1->shape()[0]; const size_t in1_features = x1->shape()[1]; diff --git a/src/infiniop/ops/gemm/nvidia/gemm_nvidia.cu b/src/infiniop/ops/gemm/nvidia/gemm_nvidia.cu index 0e0c65f2b..580cca658 100644 --- a/src/infiniop/ops/gemm/nvidia/gemm_nvidia.cu +++ b/src/infiniop/ops/gemm/nvidia/gemm_nvidia.cu @@ -3,6 +3,14 @@ namespace op::gemm::nvidia { +// 添加线程局部控制开关 +thread_local 
bool g_tf32_enabled = true; + +// 暴露设置函数(非静态,以便外部链接) +void set_tf32_enabled(bool enabled) { + g_tf32_enabled = enabled; +} + struct Descriptor::Opaque { std::shared_ptr internal; }; @@ -71,7 +79,8 @@ infiniStatus_t Descriptor::calculate( #if defined(ENABLE_ILUVATAR_API) || defined(ENABLE_HYGON_API) compute_type = CUDA_R_32F; #else - compute_type = CUBLAS_COMPUTE_32F_FAST_TF32; + // compute_type = CUBLAS_COMPUTE_32F_FAST_TF32; + compute_type = g_tf32_enabled ? CUBLAS_COMPUTE_32F_FAST_TF32 : CUBLAS_COMPUTE_32F; #endif break; From e917db20c5452baaeec42aa73c7d4fafd75df944 Mon Sep 17 00:00:00 2001 From: Kinorw Date: Sun, 7 Dec 2025 16:04:10 +0800 Subject: [PATCH 23/32] =?UTF-8?q?asinh=5Fdemo=5Fnvidia(=E6=9C=AA=E4=BC=98?= =?UTF-8?q?=E5=8C=96)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- include/infinicore/ops.hpp | 1 + include/infinicore/ops/asinh.hpp | 16 ++ include/infiniop.h | 1 + include/infiniop/ops/asinh.h | 24 +++ python/infinicore/__init__.py | 2 + python/infinicore/ops/asinh.py | 11 ++ src/infinicore/ops/asinh/asinh.cc | 27 ++++ src/infinicore/ops/asinh/asinh_infiniop.cc | 52 +++++++ src/infinicore/pybind11/ops.hpp | 2 + src/infinicore/pybind11/ops/asinh.hpp | 24 +++ src/infiniop/ops/asinh/cpu/asinh_cpu.cc | 50 +++++++ src/infiniop/ops/asinh/cpu/asinh_cpu.h | 22 +++ src/infiniop/ops/asinh/cuda/kernel.cuh | 31 ++++ src/infiniop/ops/asinh/nvidia/asinh_nvidia.cu | 57 +++++++ .../ops/asinh/nvidia/asinh_nvidia.cuh | 8 + src/infiniop/ops/asinh/operator.cc | 141 ++++++++++++++++++ src/infiniop/ops/fmod/operator.cc | 11 +- test/infinicore/ops/asinh.py | 6 +- 18 files changed, 473 insertions(+), 13 deletions(-) create mode 100644 include/infinicore/ops/asinh.hpp create mode 100644 include/infiniop/ops/asinh.h create mode 100644 python/infinicore/ops/asinh.py create mode 100644 src/infinicore/ops/asinh/asinh.cc create mode 100644 src/infinicore/ops/asinh/asinh_infiniop.cc create mode 100644 src/infinicore/pybind11/ops/asinh.hpp create mode 100644 src/infiniop/ops/asinh/cpu/asinh_cpu.cc create mode 100644 src/infiniop/ops/asinh/cpu/asinh_cpu.h create mode 100644 src/infiniop/ops/asinh/cuda/kernel.cuh create mode 100644 src/infiniop/ops/asinh/nvidia/asinh_nvidia.cu create mode 100644 src/infiniop/ops/asinh/nvidia/asinh_nvidia.cuh create mode 100644 src/infiniop/ops/asinh/operator.cc diff --git a/include/infinicore/ops.hpp b/include/infinicore/ops.hpp index bb05d022f..80b8cad19 100644 --- a/include/infinicore/ops.hpp +++ b/include/infinicore/ops.hpp @@ -2,6 +2,7 @@ #include "ops/add.hpp" #include "ops/attention.hpp" +#include "ops/asinh.hpp" #include "ops/baddbmm.hpp" #include "ops/bilinear.hpp" #include "ops/causal_softmax.hpp" diff --git a/include/infinicore/ops/asinh.hpp b/include/infinicore/ops/asinh.hpp new file mode 100644 index 000000000..1d04d6bea --- /dev/null +++ b/include/infinicore/ops/asinh.hpp @@ -0,0 +1,16 @@ +#pragma once + +#include "../device.hpp" +#include "common/op.hpp" + +namespace infinicore::op { +class Asinh { +public: + using schema = void (*)(Tensor, Tensor); + static void execute(Tensor y, Tensor x); + static common::OpDispatcher &dispatcher(); +}; + +Tensor asinh(Tensor x); +void asinh_(Tensor y, Tensor x); +} // namespace infinicore::op diff --git a/include/infiniop.h b/include/infiniop.h index 3bafe4209..aa93c177d 100644 --- a/include/infiniop.h +++ b/include/infiniop.h @@ -4,6 +4,7 @@ #include "infiniop/handle.h" #include "infiniop/ops/add.h" #include "infiniop/ops/attention.h" +#include "infiniop/ops/asinh.h" 
#include "infiniop/ops/causal_softmax.h" #include "infiniop/ops/clip.h" #include "infiniop/ops/conv.h" diff --git a/include/infiniop/ops/asinh.h b/include/infiniop/ops/asinh.h new file mode 100644 index 000000000..4849bc422 --- /dev/null +++ b/include/infiniop/ops/asinh.h @@ -0,0 +1,24 @@ +#ifndef __INFINIOP_ASINH_API_H_ +#define __INFINIOP_ASINH_API_H__ + +#include "../operator_descriptor.h" + +typedef struct InfiniopDescriptor *infiniopAsinhDescriptor_t; + +__C __export infiniStatus_t infiniopCreateAsinhDescriptor(infiniopHandle_t handle, + infiniopAsinhDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t y, + infiniopTensorDescriptor_t x); + +__C __export infiniStatus_t infiniopGetAsinhWorkspaceSize(infiniopAsinhDescriptor_t desc, size_t *size); + +__C __export infiniStatus_t infiniopAsinh(infiniopAsinhDescriptor_t desc, + void *workspace, + size_t workspace_size, + void *y, + const void *x, + void *stream); + +__C __export infiniStatus_t infiniopDestroyAsinhDescriptor(infiniopAsinhDescriptor_t desc); + +#endif \ No newline at end of file diff --git a/python/infinicore/__init__.py b/python/infinicore/__init__.py index b1f571e6c..a2f32b62c 100644 --- a/python/infinicore/__init__.py +++ b/python/infinicore/__init__.py @@ -41,6 +41,7 @@ ) from infinicore.ops.add import add from infinicore.ops.attention import attention +from infinicore.ops.asinh import asinh from infinicore.ops.matmul import matmul from infinicore.ops.mul import mul from infinicore.ops.narrow import narrow @@ -104,6 +105,7 @@ # Operations. "add", "attention", + "asinh", "baddbmm", "bilinear", "fmod", diff --git a/python/infinicore/ops/asinh.py b/python/infinicore/ops/asinh.py new file mode 100644 index 000000000..05ec58779 --- /dev/null +++ b/python/infinicore/ops/asinh.py @@ -0,0 +1,11 @@ +from infinicore.lib import _infinicore +from infinicore.tensor import Tensor + + +def asinh(input, *, out=None): + if out is None: + return Tensor(_infinicore.asinh(input._underlying)) + + _infinicore.asinh_(out._underlying, input._underlying) + + return out diff --git a/src/infinicore/ops/asinh/asinh.cc b/src/infinicore/ops/asinh/asinh.cc new file mode 100644 index 000000000..fbf131d99 --- /dev/null +++ b/src/infinicore/ops/asinh/asinh.cc @@ -0,0 +1,27 @@ +#include "infinicore/ops/asinh.hpp" +#include "../../utils.hpp" + +namespace infinicore::op { + +common::OpDispatcher &Asinh::dispatcher() { + static common::OpDispatcher dispatcher_; + return dispatcher_; +}; + +void Asinh::execute(Tensor y, Tensor x) { + INFINICORE_ASSERT_TENSORS_SAME_DEVICE(y, x); + infinicore::context::setDevice(y->device()); + dispatcher().lookup(y->device().getType())(y, x); +} + +Tensor asinh(Tensor x) { + auto y = Tensor::empty(x->shape(), x->dtype(), x->device()); + asinh_(y, x); + return y; +} + +void asinh_(Tensor y, Tensor x) { + Asinh::execute(y, x); +} + +} // namespace infinicore::op diff --git a/src/infinicore/ops/asinh/asinh_infiniop.cc b/src/infinicore/ops/asinh/asinh_infiniop.cc new file mode 100644 index 000000000..ceed8d5a2 --- /dev/null +++ b/src/infinicore/ops/asinh/asinh_infiniop.cc @@ -0,0 +1,52 @@ +#include "../../utils.hpp" +#include "infinicore/common/hash.hpp" +#include "infinicore/ops/asinh.hpp" +#include "infinicore/ops/common/cache.hpp" +#include + +namespace infinicore::op::asinh_impl::infiniop { + +thread_local common::OpCache caches( + 100, // capacity + [](infiniopAsinhDescriptor_t &desc) { + if (desc != nullptr) { + INFINICORE_CHECK_ERROR(infiniopDestroyAsinhDescriptor(desc)); + desc = nullptr; + } + }); + +void calculate(Tensor 
y, Tensor x) { + size_t seed = hash_combine(y, x); + + auto device_type = context::getDevice().getType(); + auto device_index = context::getDevice().getIndex(); + + auto &cache = caches.getCache(device_type, device_index); + + auto desc_opt = cache.get(seed); + infiniopAsinhDescriptor_t desc = nullptr; + + if (!desc_opt) { + INFINICORE_CHECK_ERROR(infiniopCreateAsinhDescriptor( + context::getInfiniopHandle(y->device()), &desc, + y->desc(), x->desc())); + cache.put(seed, desc); + } else { + desc = *desc_opt; + } + + size_t workspace_size = 0; + INFINICORE_CHECK_ERROR(infiniopGetAsinhWorkspaceSize(desc, &workspace_size)); + std::shared_ptr workspace = context::allocateMemory(workspace_size); + + INFINICORE_CHECK_ERROR(infiniopAsinh( + desc, workspace->data(), workspace_size, + y->data(), x->data(), context::getStream())); +} + +static bool registered = []() { + Asinh::dispatcher().registerAll(&calculate, false); + return true; +}(); + +} // namespace infinicore::op::asinh_impl::infiniop diff --git a/src/infinicore/pybind11/ops.hpp b/src/infinicore/pybind11/ops.hpp index 0bd903922..966e69eaa 100644 --- a/src/infinicore/pybind11/ops.hpp +++ b/src/infinicore/pybind11/ops.hpp @@ -4,6 +4,7 @@ #include "ops/add.hpp" #include "ops/attention.hpp" +#include "ops/asinh.hpp" #include "ops/baddbmm.hpp" #include "ops/bilinear.hpp" #include "ops/causal_softmax.hpp" @@ -26,6 +27,7 @@ namespace infinicore::ops { inline void bind(py::module &m) { bind_add(m); bind_attention(m); + bind_asinh(m); bind_baddbmm(m); bind_bilinear(m); bind_causal_softmax(m); diff --git a/src/infinicore/pybind11/ops/asinh.hpp b/src/infinicore/pybind11/ops/asinh.hpp new file mode 100644 index 000000000..bf1fcca23 --- /dev/null +++ b/src/infinicore/pybind11/ops/asinh.hpp @@ -0,0 +1,24 @@ +#pragma once + +#include + +#include "infinicore/ops/asinh.hpp" + +namespace py = pybind11; + +namespace infinicore::ops { + +inline void bind_asinh(py::module &m) { + m.def("asinh", + &op::asinh, + py::arg("x"), + R"doc(Element-wise inverse hyperbolic sine function.)doc"); + + m.def("asinh_", + &op::asinh_, + py::arg("y"), + py::arg("x"), + R"doc(In-place element-wise inverse hyperbolic sine function.)doc"); +} + +} // namespace infinicore::ops diff --git a/src/infiniop/ops/asinh/cpu/asinh_cpu.cc b/src/infiniop/ops/asinh/cpu/asinh_cpu.cc new file mode 100644 index 000000000..860e06032 --- /dev/null +++ b/src/infiniop/ops/asinh/cpu/asinh_cpu.cc @@ -0,0 +1,50 @@ +#include "asinh_cpu.h" + +namespace op::asinh::cpu { + +Descriptor::~Descriptor() = default; + +infiniStatus_t Descriptor::create( + infiniopHandle_t handle_, + Descriptor **desc_ptr, + infiniopTensorDescriptor_t out_desc, + std::vector input_desc_vec) { + + auto handle = reinterpret_cast(handle_); + auto dtype = out_desc->dtype(); + + const auto &x_desc = input_desc_vec.at(0); + const auto &y_shape = out_desc->shape(); + const auto &x_shape = x_desc->shape(); + + CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_F64, INFINI_DTYPE_BF16); + + CHECK_SAME_SHAPE(y_shape, x_shape); + + CREATE_ELEMENTWISE_CPU_DESCRIPTOR(handle, dtype, out_desc, input_desc_vec); + + return INFINI_STATUS_SUCCESS; +} + +infiniStatus_t Descriptor::calculate( + void *workspace, + size_t workspace_size, + void *output, + std::vector inputs, + void *stream) const{ + + switch (_dtype) { + case INFINI_DTYPE_F16: + return _device_info->calculate(_info, output, inputs, stream); + case INFINI_DTYPE_F32: + return _device_info->calculate(_info, output, inputs, stream); + case INFINI_DTYPE_F64: + return 
_device_info->calculate(_info, output, inputs, stream); + case INFINI_DTYPE_BF16: + return _device_info->calculate(_info, output, inputs, stream); + default : + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } +} + +} // namespace op::asinh::cpu \ No newline at end of file diff --git a/src/infiniop/ops/asinh/cpu/asinh_cpu.h b/src/infiniop/ops/asinh/cpu/asinh_cpu.h new file mode 100644 index 000000000..196e208a7 --- /dev/null +++ b/src/infiniop/ops/asinh/cpu/asinh_cpu.h @@ -0,0 +1,22 @@ +#ifndef __ASINH_CPU_H__ +#define __ASINH_CPU_H__ + +#include + +#include "../../../elementwise/cpu/elementwise_cpu.h" + +ELEMENTWISE_DESCRIPTOR(asinh, cpu) + +namespace op::asinh::cpu { +typedef struct AsinhOp { +public: + static constexpr size_t num_inputs = 1; + + template + T operator()(const T &x) const { + return std::asinh(x); + } +} AsinhOp; +} // namespace op::asinh::cpu + +#endif // __ASINH_CPU_H__ \ No newline at end of file diff --git a/src/infiniop/ops/asinh/cuda/kernel.cuh b/src/infiniop/ops/asinh/cuda/kernel.cuh new file mode 100644 index 000000000..9b85c830a --- /dev/null +++ b/src/infiniop/ops/asinh/cuda/kernel.cuh @@ -0,0 +1,31 @@ +#ifndef __ASINH_CUDA_KERNEL_H__ +#define __ASINH_CUDA_KERNEL_H__ + +namespace op::asinh::cuda{ + +typedef struct AsinhOp { +public : + static constexpr size_t num_inputs = 1; + template + __device__ __forceinline__ T operator()(const T &x) const { + + if constexpr (std::is_same_v){ + float x_f = __half2float(x); + return __float2half(asinhf(x_f)); + } + else if constexpr (std::is_same_v) { + float x_f = __bfloat162float(x); + return __float2bfloat16(asinhf(x_f)); + } + else if constexpr (std::is_same_v) { + return asinhf(x); + }else { + return ::asinh(x); + } + } + +} AsinOp; + +} // namespace op::Asinh::cuda + +#endif // __ASINH_CUDA_KERNEL_H__ \ No newline at end of file diff --git a/src/infiniop/ops/asinh/nvidia/asinh_nvidia.cu b/src/infiniop/ops/asinh/nvidia/asinh_nvidia.cu new file mode 100644 index 000000000..ae9233bf7 --- /dev/null +++ b/src/infiniop/ops/asinh/nvidia/asinh_nvidia.cu @@ -0,0 +1,57 @@ +#include "../../../devices/nvidia/nvidia_common.cuh" +#include "../../../elementwise/nvidia/elementwise_nvidia.cuh" + +#include "../cuda/kernel.cuh" +#include "asinh_nvidia.cuh" + +namespace op::asinh::nvidia { + +Descriptor::~Descriptor() = default; + +infiniStatus_t Descriptor::create( + infiniopHandle_t handle_, + Descriptor **desc_ptr, + infiniopTensorDescriptor_t out_desc, + std::vector input_desc_vec) { + + auto handle = reinterpret_cast(handle_); + auto dtype = out_desc->dtype(); + + const auto &x_desc = input_desc_vec.at(0); + const auto &y_shape = out_desc->shape(); + const auto &x_shape = x_desc->shape(); + + CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_F64, INFINI_DTYPE_BF16); + + CHECK_SAME_SHAPE(y_shape, x_shape); + + CREATE_ELEMENTWISE_CUDA_DESCRIPTOR(handle, dtype, out_desc, input_desc_vec) + + return INFINI_STATUS_SUCCESS; +} + +infiniStatus_t Descriptor::calculate( + void *workspace, + size_t workspace_size, + void *output, + std::vector inputs, + void *stream) const { + if (workspace_size < _workspace_size) { + return INFINI_STATUS_INSUFFICIENT_WORKSPACE; + } + switch (_dtype) { + case INFINI_DTYPE_F16: + return _device_info->calculate<256, cuda::AsinOp,half>(_info, workspace, output, inputs, stream); + case INFINI_DTYPE_BF16: + return _device_info->calculate<256, cuda::AsinOp,cuda_bfloat16>(_info, workspace, output, inputs, stream); + case INFINI_DTYPE_F32: + return _device_info->calculate<256, cuda::AsinOp,float>(_info, 
workspace, output, inputs, stream); + case INFINI_DTYPE_F64: + return _device_info->calculate<256, cuda::AsinOp,double>(_info, workspace, output, inputs, stream); + default: + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + +} + +} // namespace op::Asinh::nvidia \ No newline at end of file diff --git a/src/infiniop/ops/asinh/nvidia/asinh_nvidia.cuh b/src/infiniop/ops/asinh/nvidia/asinh_nvidia.cuh new file mode 100644 index 000000000..5b75a553c --- /dev/null +++ b/src/infiniop/ops/asinh/nvidia/asinh_nvidia.cuh @@ -0,0 +1,8 @@ +#ifndef __ASINH_NVIDIA_API_H__ +#define __ASINH_NVIDIA_API_H__ + +#include "../../../elementwise/nvidia/elementwise_nvidia_api.cuh" + +ELEMENTWISE_DESCRIPTOR(asinh, nvidia) + +#endif // __ASINH_NVIDIA_API_H \ No newline at end of file diff --git a/src/infiniop/ops/asinh/operator.cc b/src/infiniop/ops/asinh/operator.cc new file mode 100644 index 000000000..3a4eb5d14 --- /dev/null +++ b/src/infiniop/ops/asinh/operator.cc @@ -0,0 +1,141 @@ +#include "../../operator.h" +#include "../../handle.h" +#include "infiniop/ops/asinh.h" + +#ifdef ENABLE_CPU_API +#include "cpu/asinh_cpu.h" +#endif +#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) +#include "nvidia/asinh_nvidia.cuh" +#endif +#ifdef ENABLE_METAX_API +#include "metax/asinh_metax.h" +#endif +#ifdef ENABLE_MOORE_API +#include "moore/asinh_moore.h" +#endif + +__C infiniStatus_t infiniopCreateAsinhDescriptor( + infiniopHandle_t handle, + infiniopAsinhDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t y_desc, + infiniopTensorDescriptor_t x_desc) { +#define CREATE(CASE, NAMESPACE) \ + case CASE: \ + return op::asinh::NAMESPACE::Descriptor::create( \ + handle, \ + reinterpret_cast(desc_ptr), \ + y_desc, \ + {x_desc}) + switch (handle->device) { + +#ifdef ENABLE_CPU_API + CREATE(INFINI_DEVICE_CPU, cpu); +#endif +#ifdef ENABLE_NVIDIA_API + CREATE(INFINI_DEVICE_NVIDIA, nvidia); +#endif +#ifdef ENABLE_ILUVATAR_API + CREATE(INFINI_DEVICE_ILUVATAR, nvidia); +#endif +#ifdef ENABLE_METAX_API + CREATE(INFINI_DEVICE_METAX, metax); +#endif +#ifdef ENABLE_MOORE_API + CREATE(INFINI_DEVICE_MOORE, moore); +#endif + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } +#undef CREATE +} + +__C infiniStatus_t infiniopGetAsinhWorkspaceSize(infiniopAsinhDescriptor_t desc, size_t *size) { + +#define GET(CASE, NAMESPACE) \ + case CASE: \ + *size = reinterpret_cast(desc)->workspaceSize(); \ + return INFINI_STATUS_SUCCESS + + switch (desc->device_type) { +#ifdef ENABLE_CPU_API + GET(INFINI_DEVICE_CPU, cpu); +#endif +#ifdef ENABLE_NVIDIA_API + GET(INFINI_DEVICE_NVIDIA, nvidia); +#endif +#ifdef ENABLE_ILUVATAR_API + GET(INFINI_DEVICE_ILUVATAR, nvidia); +#endif +#ifdef ENABLE_METAX_API + GET(INFINI_DEVICE_METAX, metax); +#endif +#ifdef ENABLE_MOORE_API + GET(INFINI_DEVICE_MOORE, moore); +#endif + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } +#undef GET + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; +} + +__C infiniStatus_t infiniopAsinh(infiniopAsinhDescriptor_t desc, + void *workspace, + size_t workspace_size, + void *y, + const void *x, + void *stream) { +#define CALCULATE(CASE, NAMESPACE) \ + case CASE: \ + return reinterpret_cast(desc) \ + ->calculate(workspace, workspace_size, y, {x}, stream); + switch (desc->device_type) { +#ifdef ENABLE_CPU_API + CALCULATE(INFINI_DEVICE_CPU, cpu); +#endif +#ifdef ENABLE_NVIDIA_API + CALCULATE(INFINI_DEVICE_NVIDIA, nvidia); +#endif +#ifdef ENABLE_ILUVATAR_API + CALCULATE(INFINI_DEVICE_ILUVATAR, nvidia); +#endif +#ifdef ENABLE_METAX_API + CALCULATE(INFINI_DEVICE_METAX, 
metax); +#endif +#ifdef ENABLE_MOORE_API + CALCULATE(INFINI_DEVICE_MOORE, moore); +#endif + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } +#undef CALCULATE +} + +__C infiniStatus_t infiniopDestroyAsinhDescriptor(infiniopAsinhDescriptor_t desc) { +#define GET(CASE, NAMESPACE) \ + case CASE: \ + reinterpret_cast(desc); \ + return INFINI_STATUS_SUCCESS; + + switch (desc->device_type) { +#ifdef ENABLE_CPU_API + GET(INFINI_DEVICE_CPU, cpu); +#endif +#ifdef ENABLE_NVIDIA_API + GET(INFINI_DEVICE_NVIDIA, nvidia); +#endif +#ifdef ENABLE_ILUVATAR_API + GET(INFINI_DEVICE_ILUVATAR, nvidia); +#endif +#ifdef ENABLE_METAX_API + GET(INFINI_DEVICE_METAX, metax); +#endif +#ifdef ENABLE_MOORE_API + GET(INFINI_DEVICE_MOORE, moore); +#endif + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } +#undef DELETE +} \ No newline at end of file diff --git a/src/infiniop/ops/fmod/operator.cc b/src/infiniop/ops/fmod/operator.cc index 08fe0ef3c..c94fdfc10 100644 --- a/src/infiniop/ops/fmod/operator.cc +++ b/src/infiniop/ops/fmod/operator.cc @@ -5,7 +5,7 @@ #ifdef ENABLE_CPU_API #include "cpu/fmod_cpu.h" #endif -#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API) +#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) #include "nvidia/fmod_nvidia.cuh" #endif #ifdef ENABLE_METAX_API @@ -41,9 +41,6 @@ __C infiniStatus_t infiniopCreateFmodDescriptor( #ifdef ENABLE_ILUVATAR_API CREATE(INFINI_DEVICE_ILUVATAR, nvidia); #endif -#ifdef ENABLE_QY_API - CREATE(INFINI_DEVICE_QY, nvidia); -#endif #ifdef ENABLE_METAX_API CREATE(INFINI_DEVICE_METAX, metax); #endif @@ -74,9 +71,6 @@ __C infiniStatus_t infiniopGetFmodWorkspaceSize(infiniopFmodDescriptor_t desc, s #ifdef ENABLE_ILUVATAR_API GET(INFINI_DEVICE_ILUVATAR, nvidia); #endif -#ifdef ENABLE_QY_API - GET(INFINI_DEVICE_QY, nvidia); -#endif #ifdef ENABLE_METAX_API GET(INFINI_DEVICE_METAX, metax); #endif @@ -113,9 +107,6 @@ __C infiniStatus_t infiniopFmod( #ifdef ENABLE_ILUVATAR_API CALCULATE(INFINI_DEVICE_ILUVATAR, nvidia); #endif -#ifdef ENABLE_QY_API - CALCULATE(INFINI_DEVICE_QY, nvidia); -#endif #ifdef ENABLE_METAX_API CALCULATE(INFINI_DEVICE_METAX, metax); #endif diff --git a/test/infinicore/ops/asinh.py b/test/infinicore/ops/asinh.py index 97bcd5edb..715cded1b 100644 --- a/test/infinicore/ops/asinh.py +++ b/test/infinicore/ops/asinh.py @@ -97,9 +97,9 @@ def get_test_cases(self): def torch_operator(self, *args, **kwargs): return torch.asinh(*args, **kwargs) - # def infinicore_operator(self, *args, **kwargs): - # """InfiniCore implementation (operator not yet available).""" - # return infinicore.asinh(*args, **kwargs) + def infinicore_operator(self, *args, **kwargs): + """InfiniCore implementation (operator not yet available).""" + return infinicore.asinh(*args, **kwargs) def main(): From c6abfd5ba7bfef3b1a665e5b68bc4b662bc4a07b Mon Sep 17 00:00:00 2001 From: Kinorw Date: Mon, 8 Dec 2025 14:21:31 +0800 Subject: [PATCH 24/32] =?UTF-8?q?metax=20=E5=92=8C=20moore=E7=9A=84?= =?UTF-8?q?=E5=AE=9E=E7=8E=B0(=E5=BE=85=E6=B5=8B)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/infiniop/ops/asinh/cuda/kernel.cuh | 2 +- src/infiniop/ops/asinh/metax/asinh.maca | 58 ++++++++++++++++++ src/infiniop/ops/asinh/metax/asinh_metax.h | 8 +++ src/infiniop/ops/asinh/moore/asinh_moore.h | 8 +++ src/infiniop/ops/asinh/moore/asinh_moore.mu | 59 +++++++++++++++++++ src/infiniop/ops/asinh/nvidia/asinh_nvidia.cu | 10 ++-- 6 files changed, 139 insertions(+), 6 deletions(-) create mode 
100644 src/infiniop/ops/asinh/metax/asinh.maca create mode 100644 src/infiniop/ops/asinh/metax/asinh_metax.h create mode 100644 src/infiniop/ops/asinh/moore/asinh_moore.h create mode 100644 src/infiniop/ops/asinh/moore/asinh_moore.mu diff --git a/src/infiniop/ops/asinh/cuda/kernel.cuh b/src/infiniop/ops/asinh/cuda/kernel.cuh index 9b85c830a..eba99efd0 100644 --- a/src/infiniop/ops/asinh/cuda/kernel.cuh +++ b/src/infiniop/ops/asinh/cuda/kernel.cuh @@ -24,7 +24,7 @@ public : } } -} AsinOp; +} AsinhOp; } // namespace op::Asinh::cuda diff --git a/src/infiniop/ops/asinh/metax/asinh.maca b/src/infiniop/ops/asinh/metax/asinh.maca new file mode 100644 index 000000000..f6f4ac3f9 --- /dev/null +++ b/src/infiniop/ops/asinh/metax/asinh.maca @@ -0,0 +1,58 @@ +#include "asinh_metax.h" +#include "../../../elementwise/metax/elementwise_metax.h" + +#include "../cuda/kernel.cuh" + +namespace op::asinh::metax { + +Descriptor::~Descriptor() = default; + +infiniStatus_t Descriptor::create( + infiniopHandle_t handle_, + Descriptor **desc_ptr, + infiniopTensorDescriptor_t out_desc, + std::vector input_desc_vec) { + + auto handle = reinterpret_cast(handle_); + auto dtype = out_desc->dtype(); + + const auto &x_desc = input_desc_vec.at(0); + const auto &y_shape = out_desc->shape(); + const auto &x_shape = x_desc->shape(); + + CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_F64, INFINI_DTYPE_BF16); + + CHECK_SAME_SHAPE(y_shape, x_shape); + + // create CUDA elementwise descriptor + CREATE_ELEMENTWISE_METAX_DESCRIPTOR(handle, dtype, out_desc, input_desc_vec) + + return INFINI_STATUS_SUCCESS; +} + +infiniStatus_t Descriptor::calculate( + void *workspace, + size_t workspace_size, + void *output, + std::vector inputs, + void *stream) const { + + if (workspace_size < _workspace_size) { + return INFINI_STATUS_INSUFFICIENT_WORKSPACE; + } + + switch (_dtype) { + case INFINI_DTYPE_F16: + return _device_info->calculate<256, cuda::AsinhOp, half>(_info, workspace, output, inputs, stream); + case INFINI_DTYPE_BF16: + return _device_info->calculate<256, cuda::AsinhOp, cuda_bfloat16>(_info, workspace, output, inputs, stream); + case INFINI_DTYPE_F32: + return _device_info->calculate<256, cuda::AsinhOp, float>(_info, workspace, output, inputs, stream); + case INFINI_DTYPE_F64: + return _device_info->calculate<256, cuda::AsinhOp, double>(_info, workspace, output, inputs, stream); + default: + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } +} + +} // namespace op::asinh::metax diff --git a/src/infiniop/ops/asinh/metax/asinh_metax.h b/src/infiniop/ops/asinh/metax/asinh_metax.h new file mode 100644 index 000000000..dacb77f0d --- /dev/null +++ b/src/infiniop/ops/asinh/metax/asinh_metax.h @@ -0,0 +1,8 @@ +#ifndef __ASINH_METAX_API_H__ +#define __ASINH_METAX_API_H__ + +#include "../../../elementwise/metax/elementwise_metax_api.h" + +ELEMENTWISE_DESCRIPTOR(asinh, metax) + +#endif // __ASINH_METAX_API_H__ \ No newline at end of file diff --git a/src/infiniop/ops/asinh/moore/asinh_moore.h b/src/infiniop/ops/asinh/moore/asinh_moore.h new file mode 100644 index 000000000..36c93d53a --- /dev/null +++ b/src/infiniop/ops/asinh/moore/asinh_moore.h @@ -0,0 +1,8 @@ +#ifndef __ASINH_MOORE_API_H__ +#define __ASINH_MOORE_API_H__ + +#include "../../../elementwise/moore/elementwise_moore_api.h" + +ELEMENTWISE_DESCRIPTOR(asinh, moore) + +#endif // __ASINH_MOORE_API_H__ \ No newline at end of file diff --git a/src/infiniop/ops/asinh/moore/asinh_moore.mu b/src/infiniop/ops/asinh/moore/asinh_moore.mu new file mode 100644 index 
000000000..35a8d6475 --- /dev/null +++ b/src/infiniop/ops/asinh/moore/asinh_moore.mu @@ -0,0 +1,59 @@ +#include "asinh_moore.h" + +#include "../../../elementwise/moore/elementwise_moore.h" + +#include "../cuda/kernel.cuh" + +namespace op::asinh::moore { +Descriptor::~Descriptor() = default; + +infiniStatus_t Descriptor::create( + infiniopHandle_t handle_, + Descriptor **desc_ptr, + infiniopTensorDescriptor_t out_desc, + std::vector input_desc_vec) { + + auto handle = reinterpret_cast(handle_); + auto dtype = out_desc->dtype(); + + const auto &x_desc = input_desc_vec.at(0); + const auto &y_shape = out_desc->shape(); + const auto &x_shape = x_desc->shape(); + + CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_F64, INFINI_DTYPE_BF16); + + CHECK_SAME_SHAPE(y_shape, x_shape); + + // create MOORE elementwise descriptor + CREATE_ELEMENTWISE_MOORE_DESCRIPTOR(handle, dtype, out_desc, input_desc_vec) + + return INFINI_STATUS_SUCCESS; +} + +infiniStatus_t Descriptor::calculate( + void *workspace, + size_t workspace_size, + void *output, + std::vector inputs, + void *stream) const { + + if (workspace_size < _workspace_size) { + return INFINI_STATUS_INSUFFICIENT_WORKSPACE; + } + + switch (_dtype) { + case INFINI_DTYPE_F16: + return _device_info->calculate<256, cuda::AsinhOp, half>(_info, workspace, output, inputs, stream); + case INFINI_DTYPE_BF16: + return _device_info->calculate<256, cuda::AsinhOp, cuda_bfloat16>(_info, workspace, output, inputs, stream); + case INFINI_DTYPE_F32: + return _device_info->calculate<256, cuda::AsinhOp, float>(_info, workspace, output, inputs, stream); + case INFINI_DTYPE_F64: + return _device_info->calculate<256, cuda::AsinhOp, double>(_info, workspace, output, inputs, stream); + default: + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + + return INFINI_STATUS_SUCCESS; +} +} // namespace op::asinh::moore \ No newline at end of file diff --git a/src/infiniop/ops/asinh/nvidia/asinh_nvidia.cu b/src/infiniop/ops/asinh/nvidia/asinh_nvidia.cu index ae9233bf7..788ab4502 100644 --- a/src/infiniop/ops/asinh/nvidia/asinh_nvidia.cu +++ b/src/infiniop/ops/asinh/nvidia/asinh_nvidia.cu @@ -41,17 +41,17 @@ infiniStatus_t Descriptor::calculate( } switch (_dtype) { case INFINI_DTYPE_F16: - return _device_info->calculate<256, cuda::AsinOp,half>(_info, workspace, output, inputs, stream); + return _device_info->calculate<256, cuda::AsinhOp,half>(_info, workspace, output, inputs, stream); case INFINI_DTYPE_BF16: - return _device_info->calculate<256, cuda::AsinOp,cuda_bfloat16>(_info, workspace, output, inputs, stream); + return _device_info->calculate<256, cuda::AsinhOp,cuda_bfloat16>(_info, workspace, output, inputs, stream); case INFINI_DTYPE_F32: - return _device_info->calculate<256, cuda::AsinOp,float>(_info, workspace, output, inputs, stream); + return _device_info->calculate<256, cuda::AsinhOp,float>(_info, workspace, output, inputs, stream); case INFINI_DTYPE_F64: - return _device_info->calculate<256, cuda::AsinOp,double>(_info, workspace, output, inputs, stream); + return _device_info->calculate<256, cuda::AsinhOp,double>(_info, workspace, output, inputs, stream); default: return INFINI_STATUS_BAD_TENSOR_DTYPE; } } -} // namespace op::Asinh::nvidia \ No newline at end of file +} // namespace op::asinh::nvidia \ No newline at end of file From 1ca4654b565b08fb713fbceb24d214e24c449982 Mon Sep 17 00:00:00 2001 From: Kinorw Date: Tue, 9 Dec 2025 19:39:58 +0800 Subject: [PATCH 25/32] =?UTF-8?q?adaptive=20max=20pool1d=20cpu=20demo(?= =?UTF-8?q?=E5=B7=B2=E6=B5=8B=E8=AF=95)?= 
MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- include/infinicore/ops.hpp | 1 + .../infinicore/ops/adaptive_max_pool1d.hpp | 16 ++ include/infiniop.h | 1 + include/infiniop/ops/adaptive_max_pool1d.h | 22 +++ python/infinicore/nn/functional/__init__.py | 2 + .../nn/functional/adaptive_max_pool1d.py | 43 +++++ .../adaptive_max_pool1d.cc | 30 ++++ .../adaptive_max_pool1d_infiniop.cc | 52 +++++++ src/infinicore/pybind11/ops.hpp | 2 + .../pybind11/ops/adaptive_max_pool1d.hpp | 39 +++++ .../adaptive_max_pool1d/adaptive_max_pool1d.h | 47 ++++++ .../cpu/adaptive_max_pool1d_cpu.cc | 98 ++++++++++++ .../cpu/adaptive_max_pool1d_cpu.h | 8 + src/infiniop/ops/adaptive_max_pool1d/info.h | 69 ++++++++ .../ops/adaptive_max_pool1d/operator.cc | 147 ++++++++++++++++++ test/infinicore/ops/adaptive_max_pool1d.py | 6 +- 16 files changed, 580 insertions(+), 3 deletions(-) create mode 100644 include/infinicore/ops/adaptive_max_pool1d.hpp create mode 100644 include/infiniop/ops/adaptive_max_pool1d.h create mode 100644 python/infinicore/nn/functional/adaptive_max_pool1d.py create mode 100644 src/infinicore/ops/adaptive_max_pool1d/adaptive_max_pool1d.cc create mode 100644 src/infinicore/ops/adaptive_max_pool1d/adaptive_max_pool1d_infiniop.cc create mode 100644 src/infinicore/pybind11/ops/adaptive_max_pool1d.hpp create mode 100644 src/infiniop/ops/adaptive_max_pool1d/adaptive_max_pool1d.h create mode 100644 src/infiniop/ops/adaptive_max_pool1d/cpu/adaptive_max_pool1d_cpu.cc create mode 100644 src/infiniop/ops/adaptive_max_pool1d/cpu/adaptive_max_pool1d_cpu.h create mode 100644 src/infiniop/ops/adaptive_max_pool1d/info.h create mode 100644 src/infiniop/ops/adaptive_max_pool1d/operator.cc diff --git a/include/infinicore/ops.hpp b/include/infinicore/ops.hpp index 80b8cad19..1a6cb51a4 100644 --- a/include/infinicore/ops.hpp +++ b/include/infinicore/ops.hpp @@ -1,5 +1,6 @@ #pragma once +#include "ops/adaptive_max_pool1d.hpp" #include "ops/add.hpp" #include "ops/attention.hpp" #include "ops/asinh.hpp" diff --git a/include/infinicore/ops/adaptive_max_pool1d.hpp b/include/infinicore/ops/adaptive_max_pool1d.hpp new file mode 100644 index 000000000..7d1d03494 --- /dev/null +++ b/include/infinicore/ops/adaptive_max_pool1d.hpp @@ -0,0 +1,16 @@ +#pragma once + +#include "../device.hpp" +#include "common/op.hpp" + +namespace infinicore::op { +class AdaptiveMaxPool1D { +public: + using schema = void (*)(Tensor, Tensor, size_t); + static void execute(Tensor y, Tensor x, size_t output_size); + static common::OpDispatcher &dispatcher(); +}; + +Tensor adaptive_max_pool1d(Tensor x, size_t output_size); +void adaptive_max_pool1d_(Tensor y, Tensor x, size_t output_size); +} // namespace infinicore::op \ No newline at end of file diff --git a/include/infiniop.h b/include/infiniop.h index aa93c177d..9bb21e0ca 100644 --- a/include/infiniop.h +++ b/include/infiniop.h @@ -2,6 +2,7 @@ #define __INFINIOP_API_H__ #include "infiniop/handle.h" +#include "infiniop/ops/adaptive_max_pool1d.h" #include "infiniop/ops/add.h" #include "infiniop/ops/attention.h" #include "infiniop/ops/asinh.h" diff --git a/include/infiniop/ops/adaptive_max_pool1d.h b/include/infiniop/ops/adaptive_max_pool1d.h new file mode 100644 index 000000000..17a126abb --- /dev/null +++ b/include/infiniop/ops/adaptive_max_pool1d.h @@ -0,0 +1,22 @@ +#ifndef __INFINIOP_ADAPTIVE_MAX_POOL1D_H__ +#define __INFINIOP_ADAPTIVE_MAX_POOL1D_H__ + +#include "../operator_descriptor.h" + +typedef struct InfiniopDescriptor 
*infiniopAdaptiveMaxPool1dDescriptor_t; + +__C __export infiniStatus_t infiniopCreateAdaptiveMaxPool1dDescriptor( + infiniopHandle_t handle, + infiniopAdaptiveMaxPool1dDescriptor_t *desc, + infiniopTensorDescriptor_t y_desc, + infiniopTensorDescriptor_t x_desc, + size_t output_size); + +__C __export infiniStatus_t infiniopGetAdaptiveMaxPool1dWorkspaceSize(infiniopAdaptiveMaxPool1dDescriptor_t desc, size_t *size); + +__C __export infiniStatus_t infiniopAdaptiveMaxPool1d(infiniopAdaptiveMaxPool1dDescriptor_t desc, void *workspace, size_t workspace_size, + void *y, const void *x, void *stream); + +__C __export infiniStatus_t infiniopDestroyAdaptiveMaxPool1dDescriptor(infiniopAdaptiveMaxPool1dDescriptor_t desc); + +#endif \ No newline at end of file diff --git a/python/infinicore/nn/functional/__init__.py b/python/infinicore/nn/functional/__init__.py index 255079790..b9b313b67 100644 --- a/python/infinicore/nn/functional/__init__.py +++ b/python/infinicore/nn/functional/__init__.py @@ -1,3 +1,4 @@ +from .adaptive_max_pool1d import adaptive_max_pool1d from .causal_softmax import causal_softmax from .embedding import embedding from .linear import linear @@ -8,6 +9,7 @@ from .swiglu import swiglu __all__ = [ + "adaptive_max_pool1d", "causal_softmax", "random_sample", "rms_norm", diff --git a/python/infinicore/nn/functional/adaptive_max_pool1d.py b/python/infinicore/nn/functional/adaptive_max_pool1d.py new file mode 100644 index 000000000..355b7e096 --- /dev/null +++ b/python/infinicore/nn/functional/adaptive_max_pool1d.py @@ -0,0 +1,43 @@ +from typing import List + +from infinicore.lib import _infinicore +from infinicore.tensor import Tensor + + +def adaptive_max_pool1d( + input: Tensor, + output_size: int, + *, + out=None, +) -> Tensor: + r"""Applies a 1D adaptive max pooling over an input signal composed of + several input planes. + + The output size is H_out. The algorithm used is fairly simple: + + .. math:: + \text{start} = \left\lfloor \frac{i \cdot L_{in}}{L_{out}} \right\rfloor + + \text{end} = \left\lceil \frac{(i + 1) \cdot L_{in}}{L_{out}} \right\rceil + + where :math:`L_{in}` is the size of the input dimension, and :math:`L_{out}` is the size of the output dimension. + + Args: + input (Tensor): Input tensor of shape (N, C, L_in) + output_size (int): The target output size (L_out) + out (Tensor, optional): Output tensor. + + Returns: + Tensor: The result of the adaptive max pooling operation. 
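+
+    Example (illustrative, following the start/end formula above): with
+    L_in = 5 and output_size = 3 the pooling windows are [0, 2), [1, 4)
+    and [3, 5), so an input row of [1., 3., 2., 5., 4.] produces
+    [3., 5., 5.].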
+ """ + + if out is None: + return Tensor( + _infinicore.adaptive_max_pool1d(input._underlying, output_size) + ) + + _infinicore.adaptive_max_pool1d_( + out._underlying, input._underlying, output_size + ) + + return out \ No newline at end of file diff --git a/src/infinicore/ops/adaptive_max_pool1d/adaptive_max_pool1d.cc b/src/infinicore/ops/adaptive_max_pool1d/adaptive_max_pool1d.cc new file mode 100644 index 000000000..de2d24508 --- /dev/null +++ b/src/infinicore/ops/adaptive_max_pool1d/adaptive_max_pool1d.cc @@ -0,0 +1,30 @@ +#include "infinicore/ops/adaptive_max_pool1d.hpp" + +#include "../../utils.hpp" + +namespace infinicore::op{ + +common::OpDispatcher &AdaptiveMaxPool1D::dispatcher() { + static common::OpDispatcher dispatcher_; + return dispatcher_; +} + +void AdaptiveMaxPool1D::execute(Tensor y, Tensor x, size_t output_size) { + INFINICORE_ASSERT_TENSORS_SAME_DEVICE(y, x); + infinicore::context::setDevice(y->device()); + dispatcher().lookup(y->device().getType())(y, x, output_size); +} + +Tensor adaptive_max_pool1d(Tensor x, size_t output_size) { + infinicore::Shape y_shape = x->shape(); + y_shape.back() = output_size; + auto y = Tensor::empty(y_shape, x->dtype(), x->device()); + adaptive_max_pool1d_(y, x, output_size); + return y; +} + +void adaptive_max_pool1d_(Tensor y, Tensor x, size_t output_size) { + AdaptiveMaxPool1D::execute(y, x, output_size); +} + +} // namespace infinicore::op diff --git a/src/infinicore/ops/adaptive_max_pool1d/adaptive_max_pool1d_infiniop.cc b/src/infinicore/ops/adaptive_max_pool1d/adaptive_max_pool1d_infiniop.cc new file mode 100644 index 000000000..90b8eb77c --- /dev/null +++ b/src/infinicore/ops/adaptive_max_pool1d/adaptive_max_pool1d_infiniop.cc @@ -0,0 +1,52 @@ +#include "../../utils.hpp" +#include "infinicore/common/hash.hpp" +#include "infinicore/ops/common/cache.hpp" +#include "infinicore/ops/adaptive_max_pool1d.hpp" +#include + +namespace infinicore::op::adaptive_max_pool1d_impl::infiniop { + +thread_local common::OpCache caches( + 100, // capacity + [](infiniopAdaptiveMaxPool1dDescriptor_t &desc) { + if (desc != nullptr) { + INFINICORE_CHECK_ERROR(infiniopDestroyAdaptiveMaxPool1dDescriptor(desc)); + desc = nullptr; + } + }); + +void calculate(Tensor y, Tensor x, size_t out) { + size_t seed = hash_combine(y, x, out); + + auto device_type = context::getDevice().getType(); + auto device_index = context::getDevice().getIndex(); + + auto &cache = caches.getCache(device_type, device_index); + + auto desc_opt = cache.get(seed); + infiniopAdaptiveMaxPool1dDescriptor_t desc = nullptr; + + if (!desc_opt) { + INFINICORE_CHECK_ERROR(infiniopCreateAdaptiveMaxPool1dDescriptor( + context::getInfiniopHandle(y->device()), &desc, + y->desc(), x->desc(), out)); + cache.put(seed, desc); + } else { + desc = *desc_opt; + } + + size_t workspace_size = 0; + INFINICORE_CHECK_ERROR(infiniopGetAdaptiveMaxPool1dWorkspaceSize(desc, &workspace_size)); + std::shared_ptr workspace = context::allocateMemory(workspace_size); + + INFINICORE_CHECK_ERROR(infiniopAdaptiveMaxPool1d( + desc, workspace->data(), workspace_size, + y->data(), x->data(), context::getStream())); +} + +static bool registered = []() { + AdaptiveMaxPool1D::dispatcher().registerAll(&calculate, false); + return true; +}(); + +} // namespace infinicore::op::adaptive_max_pool1d_impl::infiniop diff --git a/src/infinicore/pybind11/ops.hpp b/src/infinicore/pybind11/ops.hpp index 966e69eaa..aded1e685 100644 --- a/src/infinicore/pybind11/ops.hpp +++ b/src/infinicore/pybind11/ops.hpp @@ -2,6 +2,7 @@ #include 
+#include "ops/adaptive_max_pool1d.hpp" #include "ops/add.hpp" #include "ops/attention.hpp" #include "ops/asinh.hpp" @@ -25,6 +26,7 @@ namespace py = pybind11; namespace infinicore::ops { inline void bind(py::module &m) { + bind_adaptive_max_pool1d(m); bind_add(m); bind_attention(m); bind_asinh(m); diff --git a/src/infinicore/pybind11/ops/adaptive_max_pool1d.hpp b/src/infinicore/pybind11/ops/adaptive_max_pool1d.hpp new file mode 100644 index 000000000..747d92b9a --- /dev/null +++ b/src/infinicore/pybind11/ops/adaptive_max_pool1d.hpp @@ -0,0 +1,39 @@ +#pragma once + +#include + +#include "infinicore/ops/adaptive_max_pool1d.hpp" + +namespace py = pybind11; + +namespace infinicore::ops { + +inline void bind_adaptive_max_pool1d(py::module &m) { + m.def("adaptive_max_pool1d", + &op::adaptive_max_pool1d, + py::arg("x"), + py::arg("output_size"), + R"doc(1D Adaptive Max Pooling. + +Args: + x: Input tensor of shape (N, C, L_in) or (N, L_in) + output_size: Target output size L_out +Returns: + Output tensor of shape (N, C, L_out) or (N, L_out) +)doc"); + + m.def("adaptive_max_pool1d_", + &op::adaptive_max_pool1d_, + py::arg("y"), + py::arg("x"), + py::arg("output_size"), + R"doc(In-place 1D Adaptive Max Pooling. + +Args: + y: Output tensor of shape (N, C, L_out) or (N, L_out) + x: Input tensor of shape (N, C, L_in) or (N, L_in) + output_size: Target output size L_out +)doc"); +} + +} // namespace infinicore::ops \ No newline at end of file diff --git a/src/infiniop/ops/adaptive_max_pool1d/adaptive_max_pool1d.h b/src/infiniop/ops/adaptive_max_pool1d/adaptive_max_pool1d.h new file mode 100644 index 000000000..9c2fc797a --- /dev/null +++ b/src/infiniop/ops/adaptive_max_pool1d/adaptive_max_pool1d.h @@ -0,0 +1,47 @@ +#ifndef ADAPTIVE_MAX_POOL1D_H +#define ADAPTIVE_MAX_POOL1D_H + +#include "../../operator.h" +#include "info.h" + +#define DESCRIPTOR(NAMESPACE) \ + \ + namespace op::adaptive_max_pool1d::NAMESPACE { \ + class Descriptor final : public InfiniopDescriptor { \ + struct Opaque; \ + Opaque *_opaque; \ + AdaptiveMaxPool1DInfo _info; \ + size_t _workspace_size; \ + \ + Descriptor( \ + Opaque *opaque, \ + AdaptiveMaxPool1DInfo info, \ + size_t workspace_size, \ + infiniDevice_t device_type, \ + int device_id) \ + : InfiniopDescriptor{device_type, device_id}, \ + _opaque(opaque), \ + _info(info), \ + _workspace_size(workspace_size) {} \ + \ + public: \ + ~Descriptor(); \ + \ + size_t workspaceSize() const { return _workspace_size; } \ + \ + static infiniStatus_t create( \ + infiniopHandle_t handle, \ + Descriptor **desc_ptr, \ + infiniopTensorDescriptor_t y_desc, \ + infiniopTensorDescriptor_t x_desc, \ + size_t output_size); \ + \ + infiniStatus_t calculate( \ + void *workspace, size_t workspace_size, \ + void *y, \ + const void *x, \ + void *stream) const; \ + }; \ + } + +#endif // ADAPTIVE_MAX_POOL1D_H \ No newline at end of file diff --git a/src/infiniop/ops/adaptive_max_pool1d/cpu/adaptive_max_pool1d_cpu.cc b/src/infiniop/ops/adaptive_max_pool1d/cpu/adaptive_max_pool1d_cpu.cc new file mode 100644 index 000000000..3f2e8e450 --- /dev/null +++ b/src/infiniop/ops/adaptive_max_pool1d/cpu/adaptive_max_pool1d_cpu.cc @@ -0,0 +1,98 @@ +#include "adaptive_max_pool1d_cpu.h" +#include "../../../devices/cpu/common_cpu.h" +#include "../../../reduce/cpu/reduce.h" +#include +#include + +namespace op::adaptive_max_pool1d::cpu { + +Descriptor::~Descriptor() {} + +infiniStatus_t Descriptor::create( + infiniopHandle_t handle, + Descriptor **desc_ptr, + infiniopTensorDescriptor_t y_desc, + infiniopTensorDescriptor_t 
x_desc, + size_t output_size) { + auto result = AdaptiveMaxPool1DInfo::create(y_desc, x_desc, output_size); + CHECK_RESULT(result); + *desc_ptr = new Descriptor(nullptr, result.take(), 0, handle->device, handle->device_id); + return INFINI_STATUS_SUCCESS; +} + +template +infiniStatus_t adaptiveMaxPool1D(const AdaptiveMaxPool1DInfo *info, T *y, const T *x) { + + const size_t ndim = info->ndim(); + const size_t batch_size = info->shape[0]; + const size_t channels = ndim > 2 ? info->shape[1] : 1; + + const size_t input_length = info->input_length(); + const size_t output_length = info->output_length(); + + // 计算总的任务块数 (Batch * Channels) + const ptrdiff_t total_blocks = static_cast(batch_size * channels); + + const ptrdiff_t x_stride_last = info->x_strides.back(); + +#pragma omp parallel for + for (ptrdiff_t block_idx = 0; block_idx < total_blocks; ++block_idx) { + const size_t i = block_idx / channels; // batch index + const size_t j = block_idx % channels; // channel index + + const T *x_ptr_base; + T *y_ptr_base; + + if (ndim > 2) { // (N, C, L) + x_ptr_base = x + i * info->x_strides[0] + j * info->x_strides[1]; + y_ptr_base = y + i * info->y_strides[0] + j * info->y_strides[1]; + } else { // (N, L) + x_ptr_base = x + i * info->x_strides[0]; + y_ptr_base = y + i * info->y_strides[0]; + } + + for (size_t out_idx = 0; out_idx < output_length; ++out_idx) { + // 计算池化窗口范围 [start_index, end_index) + // 公式参考 PyTorch: + // start = floor(out_idx * L_in / L_out) + // end = ceil((out_idx + 1) * L_in / L_out) + int start_index = std::floor((float)out_idx * input_length / output_length); + int end_index = std::ceil((float)(out_idx + 1) * input_length / output_length); + + start_index = std::max(start_index, 0); + end_index = std::min(end_index, (int)input_length); + int window_len = end_index - start_index; + + if (window_len <= 0) { + continue; + } + + const T *window_ptr = x_ptr_base + start_index * x_stride_last; + + auto max_val = op::common_cpu::reduce_op::max(window_ptr, window_len, x_stride_last); + y_ptr_base[out_idx] = utils::cast(max_val); + } + } + + return INFINI_STATUS_SUCCESS; +} + +infiniStatus_t Descriptor::calculate( + void *workspace, size_t workspace_size, + void *y, const void *x, + void *stream) const { + + if (_info.atype == INFINI_DTYPE_F32) { + return adaptiveMaxPool1D(&_info, (float *)y, (const float *)x); + } else if (_info.atype == INFINI_DTYPE_F16) { + return adaptiveMaxPool1D(&_info, (fp16_t *)y, (const fp16_t *)x); + } else if (_info.atype == INFINI_DTYPE_BF16) { + return adaptiveMaxPool1D(&_info, (bf16_t *)y, (const bf16_t *)x); + } else if (_info.atype == INFINI_DTYPE_F64) { + return adaptiveMaxPool1D(&_info, (double *)y, (const double *)x); + } + + return INFINI_STATUS_BAD_TENSOR_DTYPE; +} + +} // namespace op::adaptive_max_pool1d::cpu \ No newline at end of file diff --git a/src/infiniop/ops/adaptive_max_pool1d/cpu/adaptive_max_pool1d_cpu.h b/src/infiniop/ops/adaptive_max_pool1d/cpu/adaptive_max_pool1d_cpu.h new file mode 100644 index 000000000..f3e8ced3c --- /dev/null +++ b/src/infiniop/ops/adaptive_max_pool1d/cpu/adaptive_max_pool1d_cpu.h @@ -0,0 +1,8 @@ +#ifndef __ADAPTIVE_MAX_POOL1D_CPU_H__ +#define __ADAPTIVE_MAX_POOL1D_CPU_H__ + +#include "../adaptive_max_pool1d.h" + +DESCRIPTOR(cpu) + +#endif \ No newline at end of file diff --git a/src/infiniop/ops/adaptive_max_pool1d/info.h b/src/infiniop/ops/adaptive_max_pool1d/info.h new file mode 100644 index 000000000..142297cf4 --- /dev/null +++ b/src/infiniop/ops/adaptive_max_pool1d/info.h @@ -0,0 +1,69 @@ +#ifndef 
__ADAPATIVE_MAX_POOL1D_H__ +#define __ADAPATIVE_MAX_POOL1D_H__ + +#include "../../../utils.h" +#include "../../tensor.h" +#include + +namespace op::adaptive_max_pool1d { + +class AdaptiveMaxPool1DInfo { + AdaptiveMaxPool1DInfo() = default; + +public: + infiniDtype_t atype; + std::vector shape; + std::vector y_strides; + std::vector x_strides; + size_t input_size; + size_t output_size; + size_t ndim() const { return shape.size(); } + size_t input_length() const { return input_size; } + size_t output_length() const { return output_size; } + + static utils::Result create( + infiniopTensorDescriptor_t y_desc, + infiniopTensorDescriptor_t x_desc, + size_t output_size) { + + auto atype = y_desc->dtype(); + if (x_desc->dtype() != atype){ + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + if (atype != INFINI_DTYPE_F16 && atype != INFINI_DTYPE_BF16 && + atype != INFINI_DTYPE_F32 && atype != INFINI_DTYPE_F64) { + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + + const size_t y_ndim = y_desc->ndim(); + const size_t x_ndim = x_desc->ndim(); + + if (y_ndim != x_ndim) { + return INFINI_STATUS_BAD_TENSOR_SHAPE; + } + + for (size_t i = 0; i < y_ndim - 1; ++i) { + if (x_desc->dim(i) != y_desc->dim(i)){ + return INFINI_STATUS_BAD_TENSOR_SHAPE; + } + } + + if (y_desc->dim(y_ndim - 1) != output_size) { + return INFINI_STATUS_BAD_TENSOR_SHAPE; + } + + return utils::Result(AdaptiveMaxPool1DInfo{ + atype, + y_desc->shape(), + y_desc->strides(), + x_desc->strides(), + x_desc->dim(x_ndim - 1), + output_size + }); + + } + +}; +} // namespace op::adaptive_max_pool1d + +#endif // __ADAPATIVE_MAX_POOL1D_H__ \ No newline at end of file diff --git a/src/infiniop/ops/adaptive_max_pool1d/operator.cc b/src/infiniop/ops/adaptive_max_pool1d/operator.cc new file mode 100644 index 000000000..ca3573b62 --- /dev/null +++ b/src/infiniop/ops/adaptive_max_pool1d/operator.cc @@ -0,0 +1,147 @@ +#include "../../operator.h" +#include "../../handle.h" +#include "infiniop/ops/adaptive_max_pool1d.h" + +#ifdef ENABLE_CPU_API +#include "cpu/adaptive_max_pool1d_cpu.h" +#endif +// #if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) +// #include "nvidia/adaptive_max_pool1d_cuda.h" +// #endif +#ifdef ENABLE_METAX_API +#include "metax/adaptive_max_pool1d_metax.h" +#endif +#ifdef ENABLE_MOORE_API +#include "moore/adaptive_max_pool1d_moore.h" +#endif + +__C infiniStatus_t infiniopCreateAdaptiveMaxPool1dDescriptor( + infiniopHandle_t handle, + infiniopAdaptiveMaxPool1dDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t y_desc, + infiniopTensorDescriptor_t x_desc, + size_t output_size) { + +#define CREATE(CASE, NAMESPACE) \ + case CASE: \ + return op::adaptive_max_pool1d::NAMESPACE::Descriptor::create( \ + handle, \ + reinterpret_cast(desc_ptr), \ + y_desc, \ + x_desc, \ + output_size) + + switch (handle->device) { +#ifdef ENABLE_CPU_API + CREATE(INFINI_DEVICE_CPU, cpu); +#endif +// #ifdef ENABLE_NVIDIA_API +// CREATE(INFINI_DEVICE_NVIDIA, nvidia); +// #endif +// #ifdef ENABLE_ILUVATAR_API +// CREATE(INFINI_DEVICE_ILUVATAR, nvidia); +// #endif +#ifdef ENABLE_METAX_API + CREATE(INFINI_DEVICE_METAX, metax); +#endif +#ifdef ENABLE_MOORE_API + CREATE(INFINI_DEVICE_MOORE, moore); +#endif + } +#undef CREATE + + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; +} + +__C infiniStatus_t infiniopGetAdaptiveMaxPool1dWorkspaceSize( + infiniopAdaptiveMaxPool1dDescriptor_t desc, + size_t *size) { +#define GET(CASE, NAMESPACE) \ + case CASE: \ + *size = reinterpret_cast(desc)->workspaceSize(); \ + return INFINI_STATUS_SUCCESS; + + switch (desc->device_type) { 
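+    // Dispatch on the device type recorded in the descriptor at creation time;
+    // each enabled backend reports its workspaceSize(), and any backend compiled
+    // out falls through to INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED after the switch.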
+#ifdef ENABLE_CPU_API + GET(INFINI_DEVICE_CPU, cpu); +#endif +// #ifdef ENABLE_NVIDIA_API +// GET(INFINI_DEVICE_NVIDIA, nvidia); +// #endif +// #ifdef ENABLE_ILUVATAR_API +// GET(INFINI_DEVICE_ILUVATAR, nvidia); +// #endif +#ifdef ENABLE_METAX_API + GET(INFINI_DEVICE_METAX, metax); +#endif +#ifdef ENABLE_MOORE_API + GET(INFINI_DEVICE_MOORE, moore); +#endif + } +#undef GET + + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; +} + +__C infiniStatus_t infiniopAdaptiveMaxPool1d( + infiniopAdaptiveMaxPool1dDescriptor_t desc, + void *workspace, + size_t workspace_size, + void *y, + const void *x, + void *stream) { +#define CALCULATE(CASE, NAMESPACE) \ + case CASE: \ + return reinterpret_cast(desc)->calculate( \ + workspace, workspace_size, y, x, stream); + + switch (desc->device_type) { +#ifdef ENABLE_CPU_API + CALCULATE(INFINI_DEVICE_CPU, cpu); +#endif +// #ifdef ENABLE_NVIDIA_API +// CALCULATE(INFINI_DEVICE_NVIDIA, nvidia); +// #endif +// #ifdef ENABLE_ILUVATAR_API +// CALCULATE(INFINI_DEVICE_ILUVATAR, nvidia); +// #endif +#ifdef ENABLE_METAX_API + CALCULATE(INFINI_DEVICE_METAX, metax); +#endif +#ifdef ENABLE_MOORE_API + CALCULATE(INFINI_DEVICE_MOORE, moore); +#endif + } +#undef CALCULATE + + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; +} + +__C infiniStatus_t infiniopDestroyAdaptiveMaxPool1dDescriptor( + infiniopAdaptiveMaxPool1dDescriptor_t desc) { +#define DESTROY(CASE, NAMESPACE) \ + case CASE: \ + delete reinterpret_cast(desc); \ + return INFINI_STATUS_SUCCESS; + + switch (desc->device_type) { +#ifdef ENABLE_CPU_API + DESTROY(INFINI_DEVICE_CPU, cpu); +#endif +// #ifdef ENABLE_NVIDIA_API +// DESTROY(INFINI_DEVICE_NVIDIA, nvidia); +// #endif +// #ifdef ENABLE_ILUVATAR_API +// DESTROY(INFINI_DEVICE_ILUVATAR, nvidia); +// #endif +#ifdef ENABLE_METAX_API + DESTROY(INFINI_DEVICE_METAX, metax); +#endif +#ifdef ENABLE_MOORE_API + DESTROY(INFINI_DEVICE_MOORE, moore); +#endif + } +#undef DESTROY + + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; +} \ No newline at end of file diff --git a/test/infinicore/ops/adaptive_max_pool1d.py b/test/infinicore/ops/adaptive_max_pool1d.py index 0e683b4f1..e3aa89d3b 100644 --- a/test/infinicore/ops/adaptive_max_pool1d.py +++ b/test/infinicore/ops/adaptive_max_pool1d.py @@ -67,9 +67,9 @@ def get_test_cases(self): def torch_operator(self, *args, **kwargs): return torch.nn.functional.adaptive_max_pool1d(*args, **kwargs) - # def infinicore_operator(self, *args, **kwargs): - # """InfiniCore implementation (operator not yet available).""" - # return infinicore.nn.functional.adaptive_max_pool1d(*args, **kwargs) + def infinicore_operator(self, *args, **kwargs): + """InfiniCore implementation (operator not yet available).""" + return infinicore.nn.functional.adaptive_max_pool1d(*args, **kwargs) def main(): From 78f8c769bdced1bbb44fc0fa0ec23a5dd2cf0a6e Mon Sep 17 00:00:00 2001 From: Kinorw Date: Wed, 10 Dec 2025 14:27:56 +0800 Subject: [PATCH 26/32] adaptive_max_pool1d nvidia --- .../ops/adaptive_max_pool1d/cuda/kernel.cuh | 55 +++++++ .../nvidia/adaptive_max_pool1d_nvidia.cu | 144 ++++++++++++++++++ .../nvidia/adaptive_max_pool1d_nvidia.cuh | 8 + .../ops/adaptive_max_pool1d/operator.cc | 56 +++---- 4 files changed, 235 insertions(+), 28 deletions(-) create mode 100644 src/infiniop/ops/adaptive_max_pool1d/cuda/kernel.cuh create mode 100644 src/infiniop/ops/adaptive_max_pool1d/nvidia/adaptive_max_pool1d_nvidia.cu create mode 100644 src/infiniop/ops/adaptive_max_pool1d/nvidia/adaptive_max_pool1d_nvidia.cuh diff --git 
a/src/infiniop/ops/adaptive_max_pool1d/cuda/kernel.cuh b/src/infiniop/ops/adaptive_max_pool1d/cuda/kernel.cuh new file mode 100644 index 000000000..5f21748bc --- /dev/null +++ b/src/infiniop/ops/adaptive_max_pool1d/cuda/kernel.cuh @@ -0,0 +1,55 @@ +#ifndef __ADAPTIVE_MAX_POOL1D_CUDA_KERNEL_H__ +#define __ADAPTIVE_MAX_POOL1D_CUDA_KERNEL_H__ + +#include +#include + +template +__device__ void adaptiveMaxPool1dBlock( + Tdata *__restrict__ y, + ptrdiff_t stride_y_batch, + ptrdiff_t stride_y_channel, + const Tdata *__restrict__ x, + ptrdiff_t stride_x_batch, + ptrdiff_t stride_x_channel, + ptrdiff_t stride_x_length, + size_t channels, + size_t input_length, + size_t output_length, + size_t ndim){ + + size_t block_idx = blockIdx.x; + size_t batch_idx = block_idx / channels; + size_t channel_idx = block_idx % channels; + + const Tdata *x_ptr; + Tdata *y_ptr; + + if(ndim > 2) { + x_ptr = x + batch_idx * stride_x_batch + channel_idx * stride_x_channel; + y_ptr = y + batch_idx * stride_y_batch + channel_idx * stride_y_channel; + } else { + x_ptr = x + batch_idx * stride_x_batch; + y_ptr = y + batch_idx * stride_y_batch; + } + + for (size_t out_idx = threadIdx.x; out_idx < output_length ; out_idx += BLOCK_SIZE) { + int start_index = static_cast(floorf((float)out_idx * input_length / output_length)); + int end_index = static_cast(ceilf((float)(out_idx + 1) * input_length / output_length)); + + if (end_index <= start_index) { + continue; + } + + Tcompute max_val = Tcompute(x_ptr[start_index * stride_x_length]); + for (int i = start_index + 1; i < end_index; ++i) { + Tcompute val = Tcompute(x_ptr[i * stride_x_length]); + max_val = max(max_val, val); + } + + y_ptr[out_idx] = Tdata(max_val); + } + +} + +#endif \ No newline at end of file diff --git a/src/infiniop/ops/adaptive_max_pool1d/nvidia/adaptive_max_pool1d_nvidia.cu b/src/infiniop/ops/adaptive_max_pool1d/nvidia/adaptive_max_pool1d_nvidia.cu new file mode 100644 index 000000000..6654d8ef2 --- /dev/null +++ b/src/infiniop/ops/adaptive_max_pool1d/nvidia/adaptive_max_pool1d_nvidia.cu @@ -0,0 +1,144 @@ +#include "../../../devices/nvidia/nvidia_common.cuh" +#include "adaptive_max_pool1d_nvidia.cuh" + +#include "../../../devices/nvidia/nvidia_kernel_common.cuh" + +#include "../cuda/kernel.cuh" + +template +INFINIOP_CUDA_KERNEL adaptiveMaxPool1dKernel( + Tdata *__restrict__ y, + ptrdiff_t stride_y_batch, + ptrdiff_t stride_y_channel, + const Tdata *__restrict__ x, + ptrdiff_t stride_x_batch, + ptrdiff_t stride_x_channel, + ptrdiff_t stride_x_length, + size_t channels, + size_t input_length, + size_t output_length, + size_t ndim){ + + adaptiveMaxPool1dBlock( + y, stride_y_batch, stride_y_channel, + x, stride_x_batch, stride_x_channel, stride_x_length, + channels, input_length, output_length, ndim); +} + +namespace op::adaptive_max_pool1d::nvidia { + +struct Descriptor::Opaque { + std::shared_ptr internal; +}; + +Descriptor::~Descriptor() { + delete _opaque; +} + +infiniStatus_t Descriptor::create( + infiniopHandle_t handle, + Descriptor **desc_ptr, + infiniopTensorDescriptor_t y_desc, + infiniopTensorDescriptor_t x_desc, + size_t output_size) { + auto result = AdaptiveMaxPool1DInfo::create(y_desc, x_desc, output_size); + CHECK_RESULT(result); + auto info = result.take(); + + *desc_ptr = new Descriptor( + new Opaque{reinterpret_cast(handle)->internal()}, + std::move(info), + 0, + handle->device, handle->device_id); + return INFINI_STATUS_SUCCESS; +} + +template +infiniStatus_t launchKernel( + uint32_t num_blocks, + void *y, infiniDtype_t dtype, + ptrdiff_t 
stride_y_batch, ptrdiff_t stride_y_channel, + const void *x, + ptrdiff_t stride_x_batch, ptrdiff_t stride_x_channel, ptrdiff_t stride_x_length, + size_t channels, size_t input_length, size_t output_length, size_t ndim, + cudaStream_t cuda_stream) { + +#define LAUNCH_KERNEL(Tdata, Tcompute) \ + adaptiveMaxPool1dKernel<<>>( \ + reinterpret_cast(y), \ + stride_y_batch, stride_y_channel, \ + reinterpret_cast(x), \ + stride_x_batch, stride_x_channel, stride_x_length, \ + channels, input_length, output_length, ndim) + + if (dtype == INFINI_DTYPE_F16) { + LAUNCH_KERNEL(half, float); + } else if (dtype == INFINI_DTYPE_BF16) { + LAUNCH_KERNEL(__nv_bfloat16, float); + } else if (dtype == INFINI_DTYPE_F32) { + LAUNCH_KERNEL(float, float); + } else if (dtype == INFINI_DTYPE_F64) { + LAUNCH_KERNEL(double, double); + } else { + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + +#undef LAUNCH_KERNEL + + return INFINI_STATUS_SUCCESS; +} + +infiniStatus_t Descriptor::calculate( + void *workspace, size_t workspace_size, + void *y, const void *x, + void *stream) const { + + if (workspace_size < _workspace_size) { + return INFINI_STATUS_INSUFFICIENT_WORKSPACE; + } + + const size_t ndim = _info.ndim(); + const size_t batch_size = _info.shape[0]; + const size_t channels = ndim > 2 ? _info.shape[1] : 1; + const size_t input_length = _info.input_length(); + const size_t output_length = _info.output_length(); + + ptrdiff_t stride_x_batch = _info.x_strides[0]; + ptrdiff_t stride_x_channel = ndim > 2 ? _info.x_strides[1] : 0; + ptrdiff_t stride_x_length = _info.x_strides.back(); + + ptrdiff_t stride_y_batch = _info.y_strides[0]; + ptrdiff_t stride_y_channel = ndim > 2 ? _info.y_strides[1] : 0; + + uint32_t num_blocks = static_cast(batch_size * channels); + auto cuda_stream = reinterpret_cast(stream); + + if (_opaque->internal->maxThreadsPerBlock() >= CUDA_BLOCK_SIZE_1024) { + CHECK_STATUS(launchKernel( + num_blocks, y, _info.atype, + stride_y_batch, stride_y_channel, + x, stride_x_batch, stride_x_channel, stride_x_length, + channels, input_length, output_length, ndim, + cuda_stream)); + } else if (_opaque->internal->maxThreadsPerBlock() >= CUDA_BLOCK_SIZE_512) { + CHECK_STATUS(launchKernel( + num_blocks, y, _info.atype, + stride_y_batch, stride_y_channel, + x, stride_x_batch, stride_x_channel, stride_x_length, + channels, input_length, output_length, ndim, + cuda_stream)); + } else if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_4096) { + CHECK_STATUS(launchKernel( + num_blocks, y, _info.atype, + stride_y_batch, stride_y_channel, + x, stride_x_batch, stride_x_channel, stride_x_length, + channels, input_length, output_length, ndim, + cuda_stream)); + } else { + return INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED; + } + + return INFINI_STATUS_SUCCESS; +} + +} // namespace op::adaptive_max_pool1d::nvidia \ No newline at end of file diff --git a/src/infiniop/ops/adaptive_max_pool1d/nvidia/adaptive_max_pool1d_nvidia.cuh b/src/infiniop/ops/adaptive_max_pool1d/nvidia/adaptive_max_pool1d_nvidia.cuh new file mode 100644 index 000000000..b980ce269 --- /dev/null +++ b/src/infiniop/ops/adaptive_max_pool1d/nvidia/adaptive_max_pool1d_nvidia.cuh @@ -0,0 +1,8 @@ +#ifndef __ADAPTIVE_MAX_POOL1D_CUDA_H__ +#define __ADAPTIVE_MAX_POOL1D_CUDA_H__ + +#include "../adaptive_max_pool1d.h" + +DESCRIPTOR(nvidia) + +#endif \ No newline at end of file diff --git a/src/infiniop/ops/adaptive_max_pool1d/operator.cc b/src/infiniop/ops/adaptive_max_pool1d/operator.cc index ca3573b62..c301db551 100644 --- 
a/src/infiniop/ops/adaptive_max_pool1d/operator.cc +++ b/src/infiniop/ops/adaptive_max_pool1d/operator.cc @@ -5,11 +5,11 @@ #ifdef ENABLE_CPU_API #include "cpu/adaptive_max_pool1d_cpu.h" #endif -// #if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) -// #include "nvidia/adaptive_max_pool1d_cuda.h" -// #endif +#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) +#include "nvidia/adaptive_max_pool1d_nvidia.cuh" +#endif #ifdef ENABLE_METAX_API -#include "metax/adaptive_max_pool1d_metax.h" +#include "metax/adaptive_max_pool1d_metax.cuh" #endif #ifdef ENABLE_MOORE_API #include "moore/adaptive_max_pool1d_moore.h" @@ -35,12 +35,12 @@ __C infiniStatus_t infiniopCreateAdaptiveMaxPool1dDescriptor( #ifdef ENABLE_CPU_API CREATE(INFINI_DEVICE_CPU, cpu); #endif -// #ifdef ENABLE_NVIDIA_API -// CREATE(INFINI_DEVICE_NVIDIA, nvidia); -// #endif -// #ifdef ENABLE_ILUVATAR_API -// CREATE(INFINI_DEVICE_ILUVATAR, nvidia); -// #endif +#ifdef ENABLE_NVIDIA_API + CREATE(INFINI_DEVICE_NVIDIA, nvidia); +#endif +#ifdef ENABLE_ILUVATAR_API + CREATE(INFINI_DEVICE_ILUVATAR, nvidia); +#endif #ifdef ENABLE_METAX_API CREATE(INFINI_DEVICE_METAX, metax); #endif @@ -65,12 +65,12 @@ __C infiniStatus_t infiniopGetAdaptiveMaxPool1dWorkspaceSize( #ifdef ENABLE_CPU_API GET(INFINI_DEVICE_CPU, cpu); #endif -// #ifdef ENABLE_NVIDIA_API -// GET(INFINI_DEVICE_NVIDIA, nvidia); -// #endif -// #ifdef ENABLE_ILUVATAR_API -// GET(INFINI_DEVICE_ILUVATAR, nvidia); -// #endif +#ifdef ENABLE_NVIDIA_API + GET(INFINI_DEVICE_NVIDIA, nvidia); +#endif +#ifdef ENABLE_ILUVATAR_API + GET(INFINI_DEVICE_ILUVATAR, nvidia); +#endif #ifdef ENABLE_METAX_API GET(INFINI_DEVICE_METAX, metax); #endif @@ -99,12 +99,12 @@ __C infiniStatus_t infiniopAdaptiveMaxPool1d( #ifdef ENABLE_CPU_API CALCULATE(INFINI_DEVICE_CPU, cpu); #endif -// #ifdef ENABLE_NVIDIA_API -// CALCULATE(INFINI_DEVICE_NVIDIA, nvidia); -// #endif -// #ifdef ENABLE_ILUVATAR_API -// CALCULATE(INFINI_DEVICE_ILUVATAR, nvidia); -// #endif +#ifdef ENABLE_NVIDIA_API + CALCULATE(INFINI_DEVICE_NVIDIA, nvidia); +#endif +#ifdef ENABLE_ILUVATAR_API + CALCULATE(INFINI_DEVICE_ILUVATAR, nvidia); +#endif #ifdef ENABLE_METAX_API CALCULATE(INFINI_DEVICE_METAX, metax); #endif @@ -128,12 +128,12 @@ __C infiniStatus_t infiniopDestroyAdaptiveMaxPool1dDescriptor( #ifdef ENABLE_CPU_API DESTROY(INFINI_DEVICE_CPU, cpu); #endif -// #ifdef ENABLE_NVIDIA_API -// DESTROY(INFINI_DEVICE_NVIDIA, nvidia); -// #endif -// #ifdef ENABLE_ILUVATAR_API -// DESTROY(INFINI_DEVICE_ILUVATAR, nvidia); -// #endif +#ifdef ENABLE_NVIDIA_API + DESTROY(INFINI_DEVICE_NVIDIA, nvidia); +#endif +#ifdef ENABLE_ILUVATAR_API + DESTROY(INFINI_DEVICE_ILUVATAR, nvidia); +#endif #ifdef ENABLE_METAX_API DESTROY(INFINI_DEVICE_METAX, metax); #endif From 0bea611396cb7a28688b2549df2b98e6e26303a2 Mon Sep 17 00:00:00 2001 From: Kinorw Date: Wed, 10 Dec 2025 16:53:48 +0800 Subject: [PATCH 27/32] =?UTF-8?q?=E4=BF=AE=E6=AD=A3=E5=91=BD=E5=90=8D,moor?= =?UTF-8?q?e=E5=92=8Cmetax=E5=B9=B3=E5=8F=B0=E5=AE=9E=E7=8E=B0(=E5=BE=85?= =?UTF-8?q?=E6=B5=8B=E8=AF=95)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../infinicore/ops/adaptive_max_pool1d.hpp | 2 +- .../adaptive_max_pool1d.cc | 8 +- .../adaptive_max_pool1d_infiniop.cc | 2 +- .../adaptive_max_pool1d/adaptive_max_pool1d.h | 4 +- .../cpu/adaptive_max_pool1d_cpu.cc | 12 +- src/infiniop/ops/adaptive_max_pool1d/info.h | 8 +- .../metax/adaptive_max_pool1d_metax.cuh | 8 + .../metax/adaptive_max_pool1d_metax.maca | 130 ++++++++++++++++ 
.../moore/adaptive_max_pool1d_moore.h | 144 ++++++++++++++++++ .../moore/adaptive_max_pool1d_moore.mu | 8 + .../nvidia/adaptive_max_pool1d_nvidia.cu | 2 +- 11 files changed, 309 insertions(+), 19 deletions(-) create mode 100644 src/infiniop/ops/adaptive_max_pool1d/metax/adaptive_max_pool1d_metax.cuh create mode 100644 src/infiniop/ops/adaptive_max_pool1d/metax/adaptive_max_pool1d_metax.maca create mode 100644 src/infiniop/ops/adaptive_max_pool1d/moore/adaptive_max_pool1d_moore.h create mode 100644 src/infiniop/ops/adaptive_max_pool1d/moore/adaptive_max_pool1d_moore.mu diff --git a/include/infinicore/ops/adaptive_max_pool1d.hpp b/include/infinicore/ops/adaptive_max_pool1d.hpp index 7d1d03494..05e49b490 100644 --- a/include/infinicore/ops/adaptive_max_pool1d.hpp +++ b/include/infinicore/ops/adaptive_max_pool1d.hpp @@ -4,7 +4,7 @@ #include "common/op.hpp" namespace infinicore::op { -class AdaptiveMaxPool1D { +class AdaptiveMaxPool1d { public: using schema = void (*)(Tensor, Tensor, size_t); static void execute(Tensor y, Tensor x, size_t output_size); diff --git a/src/infinicore/ops/adaptive_max_pool1d/adaptive_max_pool1d.cc b/src/infinicore/ops/adaptive_max_pool1d/adaptive_max_pool1d.cc index de2d24508..e615eb292 100644 --- a/src/infinicore/ops/adaptive_max_pool1d/adaptive_max_pool1d.cc +++ b/src/infinicore/ops/adaptive_max_pool1d/adaptive_max_pool1d.cc @@ -4,12 +4,12 @@ namespace infinicore::op{ -common::OpDispatcher &AdaptiveMaxPool1D::dispatcher() { - static common::OpDispatcher dispatcher_; +common::OpDispatcher &AdaptiveMaxPool1d::dispatcher() { + static common::OpDispatcher dispatcher_; return dispatcher_; } -void AdaptiveMaxPool1D::execute(Tensor y, Tensor x, size_t output_size) { +void AdaptiveMaxPool1d::execute(Tensor y, Tensor x, size_t output_size) { INFINICORE_ASSERT_TENSORS_SAME_DEVICE(y, x); infinicore::context::setDevice(y->device()); dispatcher().lookup(y->device().getType())(y, x, output_size); @@ -24,7 +24,7 @@ Tensor adaptive_max_pool1d(Tensor x, size_t output_size) { } void adaptive_max_pool1d_(Tensor y, Tensor x, size_t output_size) { - AdaptiveMaxPool1D::execute(y, x, output_size); + AdaptiveMaxPool1d::execute(y, x, output_size); } } // namespace infinicore::op diff --git a/src/infinicore/ops/adaptive_max_pool1d/adaptive_max_pool1d_infiniop.cc b/src/infinicore/ops/adaptive_max_pool1d/adaptive_max_pool1d_infiniop.cc index 90b8eb77c..630d067a8 100644 --- a/src/infinicore/ops/adaptive_max_pool1d/adaptive_max_pool1d_infiniop.cc +++ b/src/infinicore/ops/adaptive_max_pool1d/adaptive_max_pool1d_infiniop.cc @@ -45,7 +45,7 @@ void calculate(Tensor y, Tensor x, size_t out) { } static bool registered = []() { - AdaptiveMaxPool1D::dispatcher().registerAll(&calculate, false); + AdaptiveMaxPool1d::dispatcher().registerAll(&calculate, false); return true; }(); diff --git a/src/infiniop/ops/adaptive_max_pool1d/adaptive_max_pool1d.h b/src/infiniop/ops/adaptive_max_pool1d/adaptive_max_pool1d.h index 9c2fc797a..288c2ece4 100644 --- a/src/infiniop/ops/adaptive_max_pool1d/adaptive_max_pool1d.h +++ b/src/infiniop/ops/adaptive_max_pool1d/adaptive_max_pool1d.h @@ -10,12 +10,12 @@ class Descriptor final : public InfiniopDescriptor { \ struct Opaque; \ Opaque *_opaque; \ - AdaptiveMaxPool1DInfo _info; \ + AdaptiveMaxPool1dInfo _info; \ size_t _workspace_size; \ \ Descriptor( \ Opaque *opaque, \ - AdaptiveMaxPool1DInfo info, \ + AdaptiveMaxPool1dInfo info, \ size_t workspace_size, \ infiniDevice_t device_type, \ int device_id) \ diff --git 
a/src/infiniop/ops/adaptive_max_pool1d/cpu/adaptive_max_pool1d_cpu.cc b/src/infiniop/ops/adaptive_max_pool1d/cpu/adaptive_max_pool1d_cpu.cc index 3f2e8e450..75ade0eb3 100644 --- a/src/infiniop/ops/adaptive_max_pool1d/cpu/adaptive_max_pool1d_cpu.cc +++ b/src/infiniop/ops/adaptive_max_pool1d/cpu/adaptive_max_pool1d_cpu.cc @@ -14,14 +14,14 @@ infiniStatus_t Descriptor::create( infiniopTensorDescriptor_t y_desc, infiniopTensorDescriptor_t x_desc, size_t output_size) { - auto result = AdaptiveMaxPool1DInfo::create(y_desc, x_desc, output_size); + auto result = AdaptiveMaxPool1dInfo::create(y_desc, x_desc, output_size); CHECK_RESULT(result); *desc_ptr = new Descriptor(nullptr, result.take(), 0, handle->device, handle->device_id); return INFINI_STATUS_SUCCESS; } template -infiniStatus_t adaptiveMaxPool1D(const AdaptiveMaxPool1DInfo *info, T *y, const T *x) { +infiniStatus_t adaptiveMaxPool1d(const AdaptiveMaxPool1dInfo *info, T *y, const T *x) { const size_t ndim = info->ndim(); const size_t batch_size = info->shape[0]; @@ -83,13 +83,13 @@ infiniStatus_t Descriptor::calculate( void *stream) const { if (_info.atype == INFINI_DTYPE_F32) { - return adaptiveMaxPool1D(&_info, (float *)y, (const float *)x); + return adaptiveMaxPool1d(&_info, (float *)y, (const float *)x); } else if (_info.atype == INFINI_DTYPE_F16) { - return adaptiveMaxPool1D(&_info, (fp16_t *)y, (const fp16_t *)x); + return adaptiveMaxPool1d(&_info, (fp16_t *)y, (const fp16_t *)x); } else if (_info.atype == INFINI_DTYPE_BF16) { - return adaptiveMaxPool1D(&_info, (bf16_t *)y, (const bf16_t *)x); + return adaptiveMaxPool1d(&_info, (bf16_t *)y, (const bf16_t *)x); } else if (_info.atype == INFINI_DTYPE_F64) { - return adaptiveMaxPool1D(&_info, (double *)y, (const double *)x); + return adaptiveMaxPool1d(&_info, (double *)y, (const double *)x); } return INFINI_STATUS_BAD_TENSOR_DTYPE; diff --git a/src/infiniop/ops/adaptive_max_pool1d/info.h b/src/infiniop/ops/adaptive_max_pool1d/info.h index 142297cf4..f346bca9a 100644 --- a/src/infiniop/ops/adaptive_max_pool1d/info.h +++ b/src/infiniop/ops/adaptive_max_pool1d/info.h @@ -7,8 +7,8 @@ namespace op::adaptive_max_pool1d { -class AdaptiveMaxPool1DInfo { - AdaptiveMaxPool1DInfo() = default; +class AdaptiveMaxPool1dInfo { + AdaptiveMaxPool1dInfo() = default; public: infiniDtype_t atype; @@ -21,7 +21,7 @@ class AdaptiveMaxPool1DInfo { size_t input_length() const { return input_size; } size_t output_length() const { return output_size; } - static utils::Result create( + static utils::Result create( infiniopTensorDescriptor_t y_desc, infiniopTensorDescriptor_t x_desc, size_t output_size) { @@ -52,7 +52,7 @@ class AdaptiveMaxPool1DInfo { return INFINI_STATUS_BAD_TENSOR_SHAPE; } - return utils::Result(AdaptiveMaxPool1DInfo{ + return utils::Result(AdaptiveMaxPool1dInfo{ atype, y_desc->shape(), y_desc->strides(), diff --git a/src/infiniop/ops/adaptive_max_pool1d/metax/adaptive_max_pool1d_metax.cuh b/src/infiniop/ops/adaptive_max_pool1d/metax/adaptive_max_pool1d_metax.cuh new file mode 100644 index 000000000..fcd068b6d --- /dev/null +++ b/src/infiniop/ops/adaptive_max_pool1d/metax/adaptive_max_pool1d_metax.cuh @@ -0,0 +1,8 @@ +#ifndef __ADAPTIVE_MAX_POOL1D_METAX_CUH__ +#define __ADAPTIVE_MAX_POOL1D_METAX_CUH__ + +#include "../adaptive_max_pool1d.h" + +DESCRIPTOR(metax) + +#endif \ No newline at end of file diff --git a/src/infiniop/ops/adaptive_max_pool1d/metax/adaptive_max_pool1d_metax.maca b/src/infiniop/ops/adaptive_max_pool1d/metax/adaptive_max_pool1d_metax.maca new file mode 100644 index 
000000000..21b117c9e --- /dev/null +++ b/src/infiniop/ops/adaptive_max_pool1d/metax/adaptive_max_pool1d_metax.maca @@ -0,0 +1,130 @@ +#include "../../../devices/metax/metax_common.h" +#include "adaptive_max_pool1d_metax.cuh" + +#include "../../../devices/metax/metax_kernel_common.h" + +#include "../cuda/kernel.cuh" + +template +INFINIOP_METAX_KERNEL adaptiveMaxPool1dKernel( + Tdata *__restrict__ y, + ptrdiff_t stride_y_batch, + ptrdiff_t stride_y_channel, + const Tdata *__restrict__ x, + ptrdiff_t stride_x_batch, + ptrdiff_t stride_x_channel, + ptrdiff_t stride_x_length, + size_t channels, + size_t input_length, + size_t output_length, + size_t ndim) { + + adaptiveMaxPool1dBlock( + y, stride_y_batch, stride_y_channel, + x, stride_x_batch, stride_x_channel, stride_x_length, + channels, input_length, output_length,ndim); +} + +namespace op::adaptive_max_pool1d::metax { + +struct Descriptor::Opaque { + std::shared_ptr internal; +}; + +Descriptor::~Descriptor(){ + delete _opaque; +} + +infiniStatus_t Descriptor::create( + infiniopHandle_t handle, + Descriptor **desc_ptr, + infiniopTensorDescriptor_t y_desc, + infiniopTensorDescriptor_t x_desc, + size_t output_size) { + + auto result = AdaptiveMaxPool1dInfo::create(y_desc, x_desc, output_size); + CHECK_RESULT(result); + auto info = result.take(); + + *desc_ptr = new Descriptor( + new Opaque{reinterpret_cast(handle)->internal()}, + std::move(info), + 0, + handle->device, handle->device_id); + return INFINI_STATUS_SUCCESS; +} + +template +infiniStatus_t launchKernel( + uint32_t numblock, + void *y, infiniDtype_t dtype, + ptrdiff_t stride_x_batch, ptrdiff_t stride_y_channel, + cosnt void *x, + ptrdiff_t stride_x_batch, ptrdiff_t stride_x_channel, ptrdiff_t stride_x_length, + size_t channels, size_t input_length, size_t output_length, size_t ndim, + hcStream_t stream){ + +#define LAUNCH_KERNEL(Tdata, Tcompute) \ + adaptiveMaxPool1dKernel<<>> ( \ + reinterpret_cast(y), \ + stride_y_batch, stride_y_channel, \ + reinterpret_cast(x), \ + stride_x_batch, stride_x_channel, stride_x_length, \ + channels, input_length, output_length, ndim) + + if (dtype == INFINI_DTYPE_F16) { + LAUNCH_KERNEL(half, float); + } else if (dtype == INFINI_DTYPE_BF16) { + LAUNCH_KERNEL(__hpcc_bfloat16, float); + } else if (dtype == INFINI_DTYPE_F32) { + LAUNCH_KERNEL(float, float); + } else if (dtype == INFINI_DTYPE_F64) { + LAUNCH_KERNEL(double, double); + } else { + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } +#undef LAUNCH_KERNEL + + return INFINI_STATUS_SUCCESS; +} + +nfiniStatus_t Descriptor::calculate( + void *workspace, size_t workspace_size, + void *y, const void *x, + void *stream_) const { + + if (workspace_size < _workspace_size) { + return INFINI_STATUS_INSUFFICIENT_WORKSPACE; + } + + const size_t ndim = _info.ndim(); + const size_t batch_size = _info.shape[0]; + const size_t channels = ndim > 2 ? _info.shape[1] : 1; + const size_t input_length = _info.input_length(); + const size_t output_length = _info.output_length(); + + ptrdiff_t stride_x_batch = _info.x_strides[0]; + ptrdiff_t stride_x_channel = ndim > 2 ? _info.x_strides[1] : 0; + ptrdiff_t stride_x_length = _info.x_strides.back(); + + ptrdiff_t stride_y_batch = _info.y_strides[0]; + ptrdiff_t stride_y_channel = ndim > 2 ? 
_info.y_strides[1] : 0; + + uint32_t num_blocks = static_cast(batch_size * channels); + auto stream = reinterpret_cast(stream_); + + if (_opaque->internal->maxThreadsPerBlock() >= METAX_BLOCK_SIZE_1024) { + CHECK_STATUS(launchKernel( + num_blocks, y, _info.atype, + stride_y_batch, stride_y_channel, + x, stride_x_batch, stride_x_channel, stride_x_length, + channels, input_length, output_length, ndim, + cuda_stream)); + } else { + return INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED; + } + + return INFINI_STATUS_SUCCESS; +} + +} // namespace op::adaptive_max_pool1d::metax \ No newline at end of file diff --git a/src/infiniop/ops/adaptive_max_pool1d/moore/adaptive_max_pool1d_moore.h b/src/infiniop/ops/adaptive_max_pool1d/moore/adaptive_max_pool1d_moore.h new file mode 100644 index 000000000..256392f78 --- /dev/null +++ b/src/infiniop/ops/adaptive_max_pool1d/moore/adaptive_max_pool1d_moore.h @@ -0,0 +1,144 @@ +#include "../../../devices/moore/moore_common.h" +#include "adaptive_max_pool1d_moore.h" + +#include "../../../devices/moore/moore_kernel_common.h" + +#include "../cuda/kernel.cuh" + +template +INFINIOP_MOORE_KERNEL adaptiveMaxPool1dKernel( + Tdata *__restrict__ y, + ptrdiff_t stride_y_batch, + ptrdiff_t stride_y_channel, + const Tdata *__restrict__ x, + ptrdiff_t stride_x_batch, + ptrdiff_t stride_x_channel, + ptrdiff_t stride_x_length, + size_t channels, + size_t input_length, + size_t output_length, + size_t ndim){ + + adaptiveMaxPool1dBlock( + y, stride_y_batch, stride_y_channel, + x, stride_x_batch, stride_x_channel, stride_x_length, + channels, input_length, output_length, ndim); +} + +namespace op::adaptive_max_pool1d::moore { + +struct Descriptor::Opaque { + std::shared_ptr internal; +}; + +Descriptor::~Descriptor() { + delete _opaque; +} + +infiniStatus_t Descriptor::create( + infiniopHandle_t handle, + Descriptor **desc_ptr, + infiniopTensorDescriptor_t y_desc, + infiniopTensorDescriptor_t x_desc, + size_t output_size) { + auto result = AdaptiveMaxPool1dInfo::create(y_desc, x_desc, output_size); + CHECK_RESULT(result); + auto info = result.take(); + + *desc_ptr = new Descriptor( + new Opaque{reinterpret_cast(handle)->internal()}, + std::move(info), + 0, + handle->device, handle->device_id); + return INFINI_STATUS_SUCCESS; +} + +template +infiniStatus_t launchKernel( + uint32_t num_blocks, + void *y, infiniDtype_t dtype, + ptrdiff_t stride_y_batch, ptrdiff_t stride_y_channel, + const void *x, + ptrdiff_t stride_x_batch, ptrdiff_t stride_x_channel, ptrdiff_t stride_x_length, + size_t channels, size_t input_length, size_t output_length, size_t ndim, + musaStream_t musa_stream) { + +#define LAUNCH_KERNEL(Tdata, Tcompute) \ + adaptiveMaxPool1dKernel<<>>( \ + reinterpret_cast(y), \ + stride_y_batch, stride_y_channel, \ + reinterpret_cast(x), \ + stride_x_batch, stride_x_channel, stride_x_length, \ + channels, input_length, output_length, ndim) + + if (dtype == INFINI_DTYPE_F16) { + LAUNCH_KERNEL(half, float); + } else if (dtype == INFINI_DTYPE_BF16) { + LAUNCH_KERNEL(__mt_bfloat16, float); + } else if (dtype == INFINI_DTYPE_F32) { + LAUNCH_KERNEL(float, float); + } else if (dtype == INFINI_DTYPE_F64) { + LAUNCH_KERNEL(double, double); + } else { + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + +#undef LAUNCH_KERNEL + + return INFINI_STATUS_SUCCESS; +} + +infiniStatus_t Descriptor::calculate( + void *workspace, size_t workspace_size, + void *y, const void *x, + void *stream) const { + + if (workspace_size < _workspace_size) { + return INFINI_STATUS_INSUFFICIENT_WORKSPACE; + } + + 
const size_t ndim = _info.ndim(); + const size_t batch_size = _info.shape[0]; + const size_t channels = ndim > 2 ? _info.shape[1] : 1; + const size_t input_length = _info.input_length(); + const size_t output_length = _info.output_length(); + + ptrdiff_t stride_x_batch = _info.x_strides[0]; + ptrdiff_t stride_x_channel = ndim > 2 ? _info.x_strides[1] : 0; + ptrdiff_t stride_x_length = _info.x_strides.back(); + + ptrdiff_t stride_y_batch = _info.y_strides[0]; + ptrdiff_t stride_y_channel = ndim > 2 ? _info.y_strides[1] : 0; + + uint32_t num_blocks = static_cast(batch_size * channels); + auto musa_stream = reinterpret_cast(stream); + + if (_opaque->internal->maxThreadsPerBlock() >= MOORE_BLOCK_SIZE_1024) { + CHECK_STATUS(launchKernel( + num_blocks, y, _info.atype, + stride_y_batch, stride_y_channel, + x, stride_x_batch, stride_x_channel, stride_x_length, + channels, input_length, output_length, ndim, + musa_stream)); + } else if (_opaque->internal->maxThreadsPerBlock() >= MOORE_BLOCK_SIZE_512) { + CHECK_STATUS(launchKernel( + num_blocks, y, _info.atype, + stride_y_batch, stride_y_channel, + x, stride_x_batch, stride_x_channel, stride_x_length, + channels, input_length, output_length, ndim, + musa_stream)); + } else if (_opaque->internal->maxThreadsPerBlock() == MOORE_BLOCK_SIZE_2048) { + CHECK_STATUS(launchKernel( + num_blocks, y, _info.atype, + stride_y_batch, stride_y_channel, + x, stride_x_batch, stride_x_channel, stride_x_length, + channels, input_length, output_length, ndim, + musa_stream)); + } else { + return INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED; + } + + return INFINI_STATUS_SUCCESS; +} + +} // namespace op::adaptive_max_pool1d::moore \ No newline at end of file diff --git a/src/infiniop/ops/adaptive_max_pool1d/moore/adaptive_max_pool1d_moore.mu b/src/infiniop/ops/adaptive_max_pool1d/moore/adaptive_max_pool1d_moore.mu new file mode 100644 index 000000000..c56ad6fd4 --- /dev/null +++ b/src/infiniop/ops/adaptive_max_pool1d/moore/adaptive_max_pool1d_moore.mu @@ -0,0 +1,8 @@ +#ifndef __ADAPTIVE_MAX_POOL1D_MOOORE_H__ +#define __ADAPTIVE_MAX_POOL1D_MOOORE_H__ + +#include "../adaptive_max_pool1d.h" + +DESCRIPTOR(moore) + +#endif diff --git a/src/infiniop/ops/adaptive_max_pool1d/nvidia/adaptive_max_pool1d_nvidia.cu b/src/infiniop/ops/adaptive_max_pool1d/nvidia/adaptive_max_pool1d_nvidia.cu index 6654d8ef2..f85f0a135 100644 --- a/src/infiniop/ops/adaptive_max_pool1d/nvidia/adaptive_max_pool1d_nvidia.cu +++ b/src/infiniop/ops/adaptive_max_pool1d/nvidia/adaptive_max_pool1d_nvidia.cu @@ -41,7 +41,7 @@ infiniStatus_t Descriptor::create( infiniopTensorDescriptor_t y_desc, infiniopTensorDescriptor_t x_desc, size_t output_size) { - auto result = AdaptiveMaxPool1DInfo::create(y_desc, x_desc, output_size); + auto result = AdaptiveMaxPool1dInfo::create(y_desc, x_desc, output_size); CHECK_RESULT(result); auto info = result.take(); From 2aa7819090711d10a80e3c516edd1a5130f5b09f Mon Sep 17 00:00:00 2001 From: Kinorw Date: Wed, 10 Dec 2025 19:31:46 +0800 Subject: [PATCH 28/32] =?UTF-8?q?moore=E5=B9=B3=E5=8F=B0=E6=B5=8B=E8=AF=95?= =?UTF-8?q?=E5=AE=8C=E6=AF=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../moore/adaptive_max_pool1d_moore.h | 146 +----------------- .../moore/adaptive_max_pool1d_moore.mu | 146 +++++++++++++++++- 2 files changed, 146 insertions(+), 146 deletions(-) diff --git a/src/infiniop/ops/adaptive_max_pool1d/moore/adaptive_max_pool1d_moore.h b/src/infiniop/ops/adaptive_max_pool1d/moore/adaptive_max_pool1d_moore.h index 
256392f78..c56ad6fd4 100644 --- a/src/infiniop/ops/adaptive_max_pool1d/moore/adaptive_max_pool1d_moore.h +++ b/src/infiniop/ops/adaptive_max_pool1d/moore/adaptive_max_pool1d_moore.h @@ -1,144 +1,8 @@ -#include "../../../devices/moore/moore_common.h" -#include "adaptive_max_pool1d_moore.h" +#ifndef __ADAPTIVE_MAX_POOL1D_MOOORE_H__ +#define __ADAPTIVE_MAX_POOL1D_MOOORE_H__ -#include "../../../devices/moore/moore_kernel_common.h" +#include "../adaptive_max_pool1d.h" -#include "../cuda/kernel.cuh" +DESCRIPTOR(moore) -template -INFINIOP_MOORE_KERNEL adaptiveMaxPool1dKernel( - Tdata *__restrict__ y, - ptrdiff_t stride_y_batch, - ptrdiff_t stride_y_channel, - const Tdata *__restrict__ x, - ptrdiff_t stride_x_batch, - ptrdiff_t stride_x_channel, - ptrdiff_t stride_x_length, - size_t channels, - size_t input_length, - size_t output_length, - size_t ndim){ - - adaptiveMaxPool1dBlock( - y, stride_y_batch, stride_y_channel, - x, stride_x_batch, stride_x_channel, stride_x_length, - channels, input_length, output_length, ndim); -} - -namespace op::adaptive_max_pool1d::moore { - -struct Descriptor::Opaque { - std::shared_ptr internal; -}; - -Descriptor::~Descriptor() { - delete _opaque; -} - -infiniStatus_t Descriptor::create( - infiniopHandle_t handle, - Descriptor **desc_ptr, - infiniopTensorDescriptor_t y_desc, - infiniopTensorDescriptor_t x_desc, - size_t output_size) { - auto result = AdaptiveMaxPool1dInfo::create(y_desc, x_desc, output_size); - CHECK_RESULT(result); - auto info = result.take(); - - *desc_ptr = new Descriptor( - new Opaque{reinterpret_cast(handle)->internal()}, - std::move(info), - 0, - handle->device, handle->device_id); - return INFINI_STATUS_SUCCESS; -} - -template -infiniStatus_t launchKernel( - uint32_t num_blocks, - void *y, infiniDtype_t dtype, - ptrdiff_t stride_y_batch, ptrdiff_t stride_y_channel, - const void *x, - ptrdiff_t stride_x_batch, ptrdiff_t stride_x_channel, ptrdiff_t stride_x_length, - size_t channels, size_t input_length, size_t output_length, size_t ndim, - musaStream_t musa_stream) { - -#define LAUNCH_KERNEL(Tdata, Tcompute) \ - adaptiveMaxPool1dKernel<<>>( \ - reinterpret_cast(y), \ - stride_y_batch, stride_y_channel, \ - reinterpret_cast(x), \ - stride_x_batch, stride_x_channel, stride_x_length, \ - channels, input_length, output_length, ndim) - - if (dtype == INFINI_DTYPE_F16) { - LAUNCH_KERNEL(half, float); - } else if (dtype == INFINI_DTYPE_BF16) { - LAUNCH_KERNEL(__mt_bfloat16, float); - } else if (dtype == INFINI_DTYPE_F32) { - LAUNCH_KERNEL(float, float); - } else if (dtype == INFINI_DTYPE_F64) { - LAUNCH_KERNEL(double, double); - } else { - return INFINI_STATUS_BAD_TENSOR_DTYPE; - } - -#undef LAUNCH_KERNEL - - return INFINI_STATUS_SUCCESS; -} - -infiniStatus_t Descriptor::calculate( - void *workspace, size_t workspace_size, - void *y, const void *x, - void *stream) const { - - if (workspace_size < _workspace_size) { - return INFINI_STATUS_INSUFFICIENT_WORKSPACE; - } - - const size_t ndim = _info.ndim(); - const size_t batch_size = _info.shape[0]; - const size_t channels = ndim > 2 ? _info.shape[1] : 1; - const size_t input_length = _info.input_length(); - const size_t output_length = _info.output_length(); - - ptrdiff_t stride_x_batch = _info.x_strides[0]; - ptrdiff_t stride_x_channel = ndim > 2 ? _info.x_strides[1] : 0; - ptrdiff_t stride_x_length = _info.x_strides.back(); - - ptrdiff_t stride_y_batch = _info.y_strides[0]; - ptrdiff_t stride_y_channel = ndim > 2 ? 
_info.y_strides[1] : 0; - - uint32_t num_blocks = static_cast(batch_size * channels); - auto musa_stream = reinterpret_cast(stream); - - if (_opaque->internal->maxThreadsPerBlock() >= MOORE_BLOCK_SIZE_1024) { - CHECK_STATUS(launchKernel( - num_blocks, y, _info.atype, - stride_y_batch, stride_y_channel, - x, stride_x_batch, stride_x_channel, stride_x_length, - channels, input_length, output_length, ndim, - musa_stream)); - } else if (_opaque->internal->maxThreadsPerBlock() >= MOORE_BLOCK_SIZE_512) { - CHECK_STATUS(launchKernel( - num_blocks, y, _info.atype, - stride_y_batch, stride_y_channel, - x, stride_x_batch, stride_x_channel, stride_x_length, - channels, input_length, output_length, ndim, - musa_stream)); - } else if (_opaque->internal->maxThreadsPerBlock() == MOORE_BLOCK_SIZE_2048) { - CHECK_STATUS(launchKernel( - num_blocks, y, _info.atype, - stride_y_batch, stride_y_channel, - x, stride_x_batch, stride_x_channel, stride_x_length, - channels, input_length, output_length, ndim, - musa_stream)); - } else { - return INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED; - } - - return INFINI_STATUS_SUCCESS; -} - -} // namespace op::adaptive_max_pool1d::moore \ No newline at end of file +#endif diff --git a/src/infiniop/ops/adaptive_max_pool1d/moore/adaptive_max_pool1d_moore.mu b/src/infiniop/ops/adaptive_max_pool1d/moore/adaptive_max_pool1d_moore.mu index c56ad6fd4..256392f78 100644 --- a/src/infiniop/ops/adaptive_max_pool1d/moore/adaptive_max_pool1d_moore.mu +++ b/src/infiniop/ops/adaptive_max_pool1d/moore/adaptive_max_pool1d_moore.mu @@ -1,8 +1,144 @@ -#ifndef __ADAPTIVE_MAX_POOL1D_MOOORE_H__ -#define __ADAPTIVE_MAX_POOL1D_MOOORE_H__ +#include "../../../devices/moore/moore_common.h" +#include "adaptive_max_pool1d_moore.h" -#include "../adaptive_max_pool1d.h" +#include "../../../devices/moore/moore_kernel_common.h" -DESCRIPTOR(moore) +#include "../cuda/kernel.cuh" -#endif +template +INFINIOP_MOORE_KERNEL adaptiveMaxPool1dKernel( + Tdata *__restrict__ y, + ptrdiff_t stride_y_batch, + ptrdiff_t stride_y_channel, + const Tdata *__restrict__ x, + ptrdiff_t stride_x_batch, + ptrdiff_t stride_x_channel, + ptrdiff_t stride_x_length, + size_t channels, + size_t input_length, + size_t output_length, + size_t ndim){ + + adaptiveMaxPool1dBlock( + y, stride_y_batch, stride_y_channel, + x, stride_x_batch, stride_x_channel, stride_x_length, + channels, input_length, output_length, ndim); +} + +namespace op::adaptive_max_pool1d::moore { + +struct Descriptor::Opaque { + std::shared_ptr internal; +}; + +Descriptor::~Descriptor() { + delete _opaque; +} + +infiniStatus_t Descriptor::create( + infiniopHandle_t handle, + Descriptor **desc_ptr, + infiniopTensorDescriptor_t y_desc, + infiniopTensorDescriptor_t x_desc, + size_t output_size) { + auto result = AdaptiveMaxPool1dInfo::create(y_desc, x_desc, output_size); + CHECK_RESULT(result); + auto info = result.take(); + + *desc_ptr = new Descriptor( + new Opaque{reinterpret_cast(handle)->internal()}, + std::move(info), + 0, + handle->device, handle->device_id); + return INFINI_STATUS_SUCCESS; +} + +template +infiniStatus_t launchKernel( + uint32_t num_blocks, + void *y, infiniDtype_t dtype, + ptrdiff_t stride_y_batch, ptrdiff_t stride_y_channel, + const void *x, + ptrdiff_t stride_x_batch, ptrdiff_t stride_x_channel, ptrdiff_t stride_x_length, + size_t channels, size_t input_length, size_t output_length, size_t ndim, + musaStream_t musa_stream) { + +#define LAUNCH_KERNEL(Tdata, Tcompute) \ + adaptiveMaxPool1dKernel<<>>( \ + reinterpret_cast(y), \ + stride_y_batch, 
stride_y_channel, \ + reinterpret_cast(x), \ + stride_x_batch, stride_x_channel, stride_x_length, \ + channels, input_length, output_length, ndim) + + if (dtype == INFINI_DTYPE_F16) { + LAUNCH_KERNEL(half, float); + } else if (dtype == INFINI_DTYPE_BF16) { + LAUNCH_KERNEL(__mt_bfloat16, float); + } else if (dtype == INFINI_DTYPE_F32) { + LAUNCH_KERNEL(float, float); + } else if (dtype == INFINI_DTYPE_F64) { + LAUNCH_KERNEL(double, double); + } else { + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + +#undef LAUNCH_KERNEL + + return INFINI_STATUS_SUCCESS; +} + +infiniStatus_t Descriptor::calculate( + void *workspace, size_t workspace_size, + void *y, const void *x, + void *stream) const { + + if (workspace_size < _workspace_size) { + return INFINI_STATUS_INSUFFICIENT_WORKSPACE; + } + + const size_t ndim = _info.ndim(); + const size_t batch_size = _info.shape[0]; + const size_t channels = ndim > 2 ? _info.shape[1] : 1; + const size_t input_length = _info.input_length(); + const size_t output_length = _info.output_length(); + + ptrdiff_t stride_x_batch = _info.x_strides[0]; + ptrdiff_t stride_x_channel = ndim > 2 ? _info.x_strides[1] : 0; + ptrdiff_t stride_x_length = _info.x_strides.back(); + + ptrdiff_t stride_y_batch = _info.y_strides[0]; + ptrdiff_t stride_y_channel = ndim > 2 ? _info.y_strides[1] : 0; + + uint32_t num_blocks = static_cast(batch_size * channels); + auto musa_stream = reinterpret_cast(stream); + + if (_opaque->internal->maxThreadsPerBlock() >= MOORE_BLOCK_SIZE_1024) { + CHECK_STATUS(launchKernel( + num_blocks, y, _info.atype, + stride_y_batch, stride_y_channel, + x, stride_x_batch, stride_x_channel, stride_x_length, + channels, input_length, output_length, ndim, + musa_stream)); + } else if (_opaque->internal->maxThreadsPerBlock() >= MOORE_BLOCK_SIZE_512) { + CHECK_STATUS(launchKernel( + num_blocks, y, _info.atype, + stride_y_batch, stride_y_channel, + x, stride_x_batch, stride_x_channel, stride_x_length, + channels, input_length, output_length, ndim, + musa_stream)); + } else if (_opaque->internal->maxThreadsPerBlock() == MOORE_BLOCK_SIZE_2048) { + CHECK_STATUS(launchKernel( + num_blocks, y, _info.atype, + stride_y_batch, stride_y_channel, + x, stride_x_batch, stride_x_channel, stride_x_length, + channels, input_length, output_length, ndim, + musa_stream)); + } else { + return INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED; + } + + return INFINI_STATUS_SUCCESS; +} + +} // namespace op::adaptive_max_pool1d::moore \ No newline at end of file From e2d2e3e4f2b60bec0b7ab4595e8b2e96fa9cb8bb Mon Sep 17 00:00:00 2001 From: Kinorw Date: Wed, 10 Dec 2025 20:35:05 +0800 Subject: [PATCH 29/32] =?UTF-8?q?metax=E5=AE=9E=E7=8E=B0=E5=8B=98=E8=AF=AF?= =?UTF-8?q?(gpu=E6=97=A0=E7=A9=BA=E9=97=B2,=E6=9A=82=E6=9C=AA=E6=B5=8B?= =?UTF-8?q?=E8=AF=95)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../metax/adaptive_max_pool1d_metax.maca | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/infiniop/ops/adaptive_max_pool1d/metax/adaptive_max_pool1d_metax.maca b/src/infiniop/ops/adaptive_max_pool1d/metax/adaptive_max_pool1d_metax.maca index 21b117c9e..f72aae852 100644 --- a/src/infiniop/ops/adaptive_max_pool1d/metax/adaptive_max_pool1d_metax.maca +++ b/src/infiniop/ops/adaptive_max_pool1d/metax/adaptive_max_pool1d_metax.maca @@ -58,8 +58,8 @@ template infiniStatus_t launchKernel( uint32_t numblock, void *y, infiniDtype_t dtype, - ptrdiff_t stride_x_batch, ptrdiff_t stride_y_channel, - cosnt void *x, + ptrdiff_t stride_y_batch, 
ptrdiff_t stride_y_channel, + const void *x, ptrdiff_t stride_x_batch, ptrdiff_t stride_x_channel, ptrdiff_t stride_x_length, size_t channels, size_t input_length, size_t output_length, size_t ndim, hcStream_t stream){ @@ -88,7 +88,7 @@ infiniStatus_t launchKernel( return INFINI_STATUS_SUCCESS; } -nfiniStatus_t Descriptor::calculate( +infiniStatus_t Descriptor::calculate( void *workspace, size_t workspace_size, void *y, const void *x, void *stream_) const { @@ -119,7 +119,7 @@ nfiniStatus_t Descriptor::calculate( stride_y_batch, stride_y_channel, x, stride_x_batch, stride_x_channel, stride_x_length, channels, input_length, output_length, ndim, - cuda_stream)); + stream)); } else { return INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED; } From 3c3c7f16e999903be4ff91a5320a041b8aa8c049 Mon Sep 17 00:00:00 2001 From: Kinorw Date: Fri, 12 Dec 2025 14:40:48 +0800 Subject: [PATCH 30/32] =?UTF-8?q?=E4=BF=AE=E6=94=B9core=E6=89=A7=E8=A1=8C?= =?UTF-8?q?=E9=83=A8=E5=88=86,=E7=AC=A6=E5=90=88=E6=9B=B4=E6=96=B0?= =?UTF-8?q?=E5=90=8E=E8=AE=BE=E7=BD=AE?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/infinicore/ops/fmod/fmod.cc | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/infinicore/ops/fmod/fmod.cc b/src/infinicore/ops/fmod/fmod.cc index d5aa0d0c7..30bee17d6 100644 --- a/src/infinicore/ops/fmod/fmod.cc +++ b/src/infinicore/ops/fmod/fmod.cc @@ -1,5 +1,7 @@ #include "infinicore/ops/fmod.hpp" +#include "../../utils.hpp" + namespace infinicore::op { common::OpDispatcher &Fmod::dispatcher() { @@ -8,7 +10,9 @@ common::OpDispatcher &Fmod::dispatcher() { }; void Fmod::execute(Tensor c, Tensor a, Tensor b) { - dispatcher().lookup(context::getDevice().getType())(c, a, b); + INFINICORE_ASSERT_TENSORS_SAME_DEVICE(c, a, b); + infinicore::context::setDevice(c->device()); + dispatcher().lookup(c->device().getType())(c, a, b); } Tensor fmod(Tensor a, Tensor b) { From 108603ed88c5b46b7e73cbad194e75e563a2917e Mon Sep 17 00:00:00 2001 From: Kinorw Date: Sat, 13 Dec 2025 20:35:51 +0800 Subject: [PATCH 31/32] =?UTF-8?q?=E6=B7=BB=E5=8A=A0gemm=20beta=E6=A3=80?= =?UTF-8?q?=E6=9F=A5,=E9=81=BF=E5=85=8Dbeta=E4=B8=BA0=E8=AE=BF=E9=97=AE?= =?UTF-8?q?=E5=88=B0=E8=84=8F=E5=86=85=E5=AD=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- include/infinicore/ops.hpp | 2 +- include/infinicore/ops/asinh.hpp | 2 +- include/infinicore/ops/baddbmm.hpp | 12 +- include/infiniop.h | 2 +- include/infiniop/ops/adaptive_max_pool1d.h | 4 +- .../nn/functional/adaptive_max_pool1d.py | 10 +- python/infinicore/ops/baddbmm.py | 3 +- python/infinicore/ops/bilinear.py | 3 +- .../adaptive_max_pool1d.cc | 2 +- .../adaptive_max_pool1d_infiniop.cc | 2 +- src/infinicore/ops/baddbmm/baddbmm.cc | 54 ++-- src/infinicore/ops/bilinear/bilinear.cc | 43 ++- src/infinicore/ops/fmod/fmod_infiniop.cc | 3 +- src/infinicore/pybind11/ops.hpp | 2 +- src/infinicore/pybind11/ops/baddbmm.hpp | 21 +- src/infinicore/pybind11/ops/bilinear.hpp | 17 +- .../cpu/adaptive_max_pool1d_cpu.cc | 26 +- .../ops/adaptive_max_pool1d/cuda/kernel.cuh | 13 +- src/infiniop/ops/adaptive_max_pool1d/info.h | 16 +- .../nvidia/adaptive_max_pool1d_nvidia.cu | 12 +- .../ops/adaptive_max_pool1d/operator.cc | 38 +-- src/infiniop/ops/asinh/cpu/asinh_cpu.cc | 16 +- src/infiniop/ops/asinh/cpu/asinh_cpu.h | 2 +- src/infiniop/ops/asinh/cuda/kernel.cuh | 20 +- src/infiniop/ops/asinh/nvidia/asinh_nvidia.cu | 13 +- src/infiniop/ops/asinh/operator.cc | 42 +-- 
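The subject of PATCH 31 is a guard in the GEMM path for beta == 0: out = alpha * A * B + beta * C reads C, and when C is an uninitialized output buffer, even a mathematically harmless beta of zero can propagate garbage (0 * NaN is still NaN). A minimal sketch of that guard follows; gemm_epilogue is an illustrative stand-in, not the signature used in gemm_cpu.cc.

// Sketch of the beta == 0 guard: when beta is zero the destination must be
// overwritten, never scaled, because C may hold uninitialized memory.
#include <cstddef>
#include <vector>

void gemm_epilogue(std::vector<float> &c, const std::vector<float> &ab,
                   float alpha, float beta) {
    for (std::size_t i = 0; i < c.size(); ++i) {
        if (beta == 0.0f) {
            c[i] = alpha * ab[i];                // never read the old C
        } else {
            c[i] = alpha * ab[i] + beta * c[i];
        }
    }
}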
src/infiniop/ops/fmod/cpu/fmod_cpu.cc | 2 +- src/infiniop/ops/fmod/nvidia/fmod_nvidia.cu | 14 +- src/infiniop/ops/fmod/operator.cc | 32 +- src/infiniop/ops/gemm/cpu/gemm_cpu.cc | 6 +- test/infinicore/ops/bilinear.py | 3 +- test/infiniop/bilinear.py | 298 +++++++++--------- test/infiniop/libinfiniop/op_register.py | 38 --- 33 files changed, 369 insertions(+), 404 deletions(-) diff --git a/include/infinicore/ops.hpp b/include/infinicore/ops.hpp index 1a6cb51a4..79a0655be 100644 --- a/include/infinicore/ops.hpp +++ b/include/infinicore/ops.hpp @@ -2,8 +2,8 @@ #include "ops/adaptive_max_pool1d.hpp" #include "ops/add.hpp" -#include "ops/attention.hpp" #include "ops/asinh.hpp" +#include "ops/attention.hpp" #include "ops/baddbmm.hpp" #include "ops/bilinear.hpp" #include "ops/causal_softmax.hpp" diff --git a/include/infinicore/ops/asinh.hpp b/include/infinicore/ops/asinh.hpp index 1d04d6bea..505eb97d9 100644 --- a/include/infinicore/ops/asinh.hpp +++ b/include/infinicore/ops/asinh.hpp @@ -10,7 +10,7 @@ class Asinh { static void execute(Tensor y, Tensor x); static common::OpDispatcher &dispatcher(); }; - + Tensor asinh(Tensor x); void asinh_(Tensor y, Tensor x); } // namespace infinicore::op diff --git a/include/infinicore/ops/baddbmm.hpp b/include/infinicore/ops/baddbmm.hpp index 73e20a633..3c08b98d9 100644 --- a/include/infinicore/ops/baddbmm.hpp +++ b/include/infinicore/ops/baddbmm.hpp @@ -6,10 +6,10 @@ namespace infinicore::op { -Tensor baddbmm(Tensor input, Tensor batch1, Tensor batch2, - float beta = 1.0f, - float alpha = 1.0f); -void baddbmm_(Tensor out, Tensor input, Tensor batch1, Tensor batch2, - float beta = 1.0f, - float alpha = 1.0f); +Tensor baddbmm(Tensor input, Tensor batch1, Tensor batch2, + float beta = 1.0f, + float alpha = 1.0f); +void baddbmm_(Tensor out, Tensor input, Tensor batch1, Tensor batch2, + float beta = 1.0f, + float alpha = 1.0f); } // namespace infinicore::op \ No newline at end of file diff --git a/include/infiniop.h b/include/infiniop.h index 9bb21e0ca..c4c114559 100644 --- a/include/infiniop.h +++ b/include/infiniop.h @@ -4,8 +4,8 @@ #include "infiniop/handle.h" #include "infiniop/ops/adaptive_max_pool1d.h" #include "infiniop/ops/add.h" -#include "infiniop/ops/attention.h" #include "infiniop/ops/asinh.h" +#include "infiniop/ops/attention.h" #include "infiniop/ops/causal_softmax.h" #include "infiniop/ops/clip.h" #include "infiniop/ops/conv.h" diff --git a/include/infiniop/ops/adaptive_max_pool1d.h b/include/infiniop/ops/adaptive_max_pool1d.h index 17a126abb..484413e21 100644 --- a/include/infiniop/ops/adaptive_max_pool1d.h +++ b/include/infiniop/ops/adaptive_max_pool1d.h @@ -15,8 +15,8 @@ __C __export infiniStatus_t infiniopCreateAdaptiveMaxPool1dDescriptor( __C __export infiniStatus_t infiniopGetAdaptiveMaxPool1dWorkspaceSize(infiniopAdaptiveMaxPool1dDescriptor_t desc, size_t *size); __C __export infiniStatus_t infiniopAdaptiveMaxPool1d(infiniopAdaptiveMaxPool1dDescriptor_t desc, void *workspace, size_t workspace_size, - void *y, const void *x, void *stream); + void *y, const void *x, void *stream); __C __export infiniStatus_t infiniopDestroyAdaptiveMaxPool1dDescriptor(infiniopAdaptiveMaxPool1dDescriptor_t desc); -#endif \ No newline at end of file +#endif \ No newline at end of file diff --git a/python/infinicore/nn/functional/adaptive_max_pool1d.py b/python/infinicore/nn/functional/adaptive_max_pool1d.py index 355b7e096..74a8c56e9 100644 --- a/python/infinicore/nn/functional/adaptive_max_pool1d.py +++ b/python/infinicore/nn/functional/adaptive_max_pool1d.py @@ 
-32,12 +32,8 @@ def adaptive_max_pool1d( """ if out is None: - return Tensor( - _infinicore.adaptive_max_pool1d(input._underlying, output_size) - ) + return Tensor(_infinicore.adaptive_max_pool1d(input._underlying, output_size)) - _infinicore.adaptive_max_pool1d_( - out._underlying, input._underlying, output_size - ) + _infinicore.adaptive_max_pool1d_(out._underlying, input._underlying, output_size) - return out \ No newline at end of file + return out diff --git a/python/infinicore/ops/baddbmm.py b/python/infinicore/ops/baddbmm.py index 96b2544a9..4a34cbb64 100644 --- a/python/infinicore/ops/baddbmm.py +++ b/python/infinicore/ops/baddbmm.py @@ -1,6 +1,7 @@ from infinicore.lib import _infinicore from infinicore.tensor import Tensor + def baddbmm(input, batch1, batch2, *, beta=1.0, alpha=1.0, out=None): if out is None: return Tensor( @@ -21,4 +22,4 @@ def baddbmm(input, batch1, batch2, *, beta=1.0, alpha=1.0, out=None): float(alpha), ) - return out \ No newline at end of file + return out diff --git a/python/infinicore/ops/bilinear.py b/python/infinicore/ops/bilinear.py index 887bd6016..4773dd825 100644 --- a/python/infinicore/ops/bilinear.py +++ b/python/infinicore/ops/bilinear.py @@ -1,6 +1,7 @@ from infinicore.lib import _infinicore from infinicore.tensor import Tensor + def bilinear(input1, input2, weight, bias=None, *, out=None): if out is None: return Tensor( @@ -19,4 +20,4 @@ def bilinear(input1, input2, weight, bias=None, *, out=None): bias._underlying if bias is not None else None, ) - return out \ No newline at end of file + return out diff --git a/src/infinicore/ops/adaptive_max_pool1d/adaptive_max_pool1d.cc b/src/infinicore/ops/adaptive_max_pool1d/adaptive_max_pool1d.cc index e615eb292..bd80b0771 100644 --- a/src/infinicore/ops/adaptive_max_pool1d/adaptive_max_pool1d.cc +++ b/src/infinicore/ops/adaptive_max_pool1d/adaptive_max_pool1d.cc @@ -2,7 +2,7 @@ #include "../../utils.hpp" -namespace infinicore::op{ +namespace infinicore::op { common::OpDispatcher &AdaptiveMaxPool1d::dispatcher() { static common::OpDispatcher dispatcher_; diff --git a/src/infinicore/ops/adaptive_max_pool1d/adaptive_max_pool1d_infiniop.cc b/src/infinicore/ops/adaptive_max_pool1d/adaptive_max_pool1d_infiniop.cc index 630d067a8..451489e15 100644 --- a/src/infinicore/ops/adaptive_max_pool1d/adaptive_max_pool1d_infiniop.cc +++ b/src/infinicore/ops/adaptive_max_pool1d/adaptive_max_pool1d_infiniop.cc @@ -1,7 +1,7 @@ #include "../../utils.hpp" #include "infinicore/common/hash.hpp" -#include "infinicore/ops/common/cache.hpp" #include "infinicore/ops/adaptive_max_pool1d.hpp" +#include "infinicore/ops/common/cache.hpp" #include namespace infinicore::op::adaptive_max_pool1d_impl::infiniop { diff --git a/src/infinicore/ops/baddbmm/baddbmm.cc b/src/infinicore/ops/baddbmm/baddbmm.cc index acacb58e7..3a8ee1518 100644 --- a/src/infinicore/ops/baddbmm/baddbmm.cc +++ b/src/infinicore/ops/baddbmm/baddbmm.cc @@ -10,7 +10,9 @@ inline bool is_blas_compatible(const Tensor &t) { if (ndim == 2) { const auto rs = t->stride(0); const auto cs = t->stride(1); - if (rs != 1 && cs != 1) return false; + if (rs != 1 && cs != 1) { + return false; + } if (rs == 1 && cs == 1) { return t->shape()[0] == 1 || t->shape()[1] == 1; } @@ -18,8 +20,12 @@ inline bool is_blas_compatible(const Tensor &t) { } else if (ndim == 3) { const auto rs = t->stride(1); const auto cs = t->stride(2); - if (t->shape()[0] > 1 && t->stride(0) == 0) return false; - if (rs != 1 && cs != 1) return false; + if (t->shape()[0] > 1 && t->stride(0) == 0) { + return false; + } 
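For reference, the stride conditions that is_blas_compatible enforces on a 3-D operand can be restated standalone: the batch stride must not be 0 when the batch dimension is larger than 1 (a broadcast batch has to be materialized first), and at least one of the two trailing strides must be 1 so each matrix is contiguous along rows or columns. The helper below restates those visible checks over plain shape/stride vectors; it is illustrative, not the repository's Tensor API, and omits any further leading-dimension checks the full function may perform.

// Standalone restatement of the 3-D BLAS-compatibility check.
#include <cstddef>
#include <vector>

bool blas_compatible_3d(const std::vector<std::size_t> &shape,
                        const std::vector<std::ptrdiff_t> &stride) {
    if (shape.size() != 3 || stride.size() != 3) {
        return false;
    }
    if (shape[0] > 1 && stride[0] == 0) {
        return false; // broadcast batch: must be materialized first
    }
    const std::ptrdiff_t rs = stride[1], cs = stride[2];
    if (rs != 1 && cs != 1) {
        return false; // neither rows nor columns are unit-stride
    }
    if (rs == 1 && cs == 1) {
        return shape[1] == 1 || shape[2] == 1; // degenerate vector case
    }
    return true;
}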
+ if (rs != 1 && cs != 1) { + return false; + } if (rs == 1 && cs == 1) { return t->shape()[1] == 1 || t->shape()[2] == 1; } @@ -28,28 +34,28 @@ inline bool is_blas_compatible(const Tensor &t) { return false; } -inline void prepare_gemm_input(Tensor &output,Tensor & input, const size_t batch_size, const size_t m, const size_t n) { +inline void prepare_gemm_input(Tensor &output, Tensor &input, const size_t batch_size, const size_t m, const size_t n) { const auto input_ndim = input->ndim(); if (input_ndim == 2) { rearrange_(output, input->as_strided( - {batch_size, m, n}, - {0, input->stride(0), input->stride(1)})); + {batch_size, m, n}, + {0, input->stride(0), input->stride(1)})); } else if (input_ndim == 3 && input->shape()[0] == 1 && batch_size > 1) { rearrange_(output, input->as_strided( - {batch_size, m, n}, - {0, input->stride(1), input->stride(2)})); + {batch_size, m, n}, + {0, input->stride(1), input->stride(2)})); } else { rearrange_(output, input); } } -Tensor baddbmm(Tensor input, Tensor batch1, Tensor batch2, - float beta, - float alpha) { +Tensor baddbmm(Tensor input, Tensor batch1, Tensor batch2, + float beta, + float alpha) { const size_t batch_size = batch1->shape()[0]; const size_t m = batch1->shape()[1]; const size_t n = batch2->shape()[2]; - + const Tensor &a = is_blas_compatible(batch1) ? batch1 : rearrange(batch1); const Tensor &b = is_blas_compatible(batch2) ? batch2 : rearrange(batch2); @@ -58,29 +64,25 @@ Tensor baddbmm(Tensor input, Tensor batch1, Tensor batch2, } Tensor result = Tensor::empty({batch_size, m, n}, a->dtype(), a->device()); - + prepare_gemm_input(result, input, batch_size, m, n); - + gemm_(result, a, b, alpha, beta); return result; } -void baddbmm_(Tensor out, Tensor input, Tensor batch1, Tensor batch2, - float beta, - float alpha) { +void baddbmm_(Tensor out, Tensor input, Tensor batch1, Tensor batch2, + float beta, + float alpha) { const size_t batch_size = batch1->shape()[0]; const size_t m = batch1->shape()[1]; const size_t n = batch2->shape()[2]; - + const Tensor &a = is_blas_compatible(batch1) ? batch1 : rearrange(batch1); const Tensor &b = is_blas_compatible(batch2) ? 
batch2 : rearrange(batch2); - - const bool out_is_usable = out->is_contiguous() && - out->ndim() == 3 && - out->shape()[0] == batch_size && - out->shape()[1] == m && - out->shape()[2] == n; - + + const bool out_is_usable = out->is_contiguous() && out->ndim() == 3 && out->shape()[0] == batch_size && out->shape()[1] == m && out->shape()[2] == n; + if (out_is_usable) { if (beta != 0.0f && input->data() != out->data()) { prepare_gemm_input(out, input, batch_size, m, n); @@ -95,4 +97,4 @@ void baddbmm_(Tensor out, Tensor input, Tensor batch1, Tensor batch2, rearrange_(out, result); } } -} \ No newline at end of file +} // namespace infinicore::op \ No newline at end of file diff --git a/src/infinicore/ops/bilinear/bilinear.cc b/src/infinicore/ops/bilinear/bilinear.cc index 6f2e2be2d..ab88a28f9 100644 --- a/src/infinicore/ops/bilinear/bilinear.cc +++ b/src/infinicore/ops/bilinear/bilinear.cc @@ -1,11 +1,11 @@ #include "infinicore/ops/bilinear.hpp" -#include "infinicore/ops/matmul.hpp" #include "infinicore/ops/add.hpp" +#include "infinicore/ops/matmul.hpp" #include "infinicore/ops/rearrange.hpp" #ifdef ENABLE_NVIDIA_API namespace op::gemm::nvidia { - void set_tf32_enabled(bool); +void set_tf32_enabled(bool); } #endif @@ -29,8 +29,10 @@ struct ScopedTF32Disable { }; inline bool is_gemm_compatible_3d(const Tensor &t) { - if (t->ndim() != 3) return false; - + if (t->ndim() != 3) { + return false; + } + const auto batch = t->shape()[0]; const auto rows = t->shape()[1]; const auto cols = t->shape()[2]; @@ -38,16 +40,24 @@ inline bool is_gemm_compatible_3d(const Tensor &t) { const auto rs = t->stride(1); const auto cs = t->stride(2); - if (rs != 1 && cs != 1) return false; - + if (rs != 1 && cs != 1) { + return false; + } + if (cs == 1) { - if (rs < static_cast(cols)) return false; + if (rs < static_cast(cols)) { + return false; + } } else { - if (cs < static_cast(rows)) return false; + if (cs < static_cast(rows)) { + return false; + } + } + + if (batch > 1 && bs == 0) { + return false; } - - if (batch > 1 && bs == 0) return false; - + return true; } @@ -75,9 +85,9 @@ Tensor bilinear(Tensor x1, Tensor x2, Tensor weight, std::optional bias) Tensor weight_cont = weight->is_contiguous() ? weight : weight->contiguous(); Tensor weight_permuted = weight_cont->permute({1, 0, 2}); - Tensor weight_permuted_cont = weight_permuted->is_contiguous() - ? weight_permuted - : weight_permuted->contiguous(); + Tensor weight_permuted_cont = weight_permuted->is_contiguous() + ? 
weight_permuted + : weight_permuted->contiguous(); Tensor weight_matrix = weight_permuted_cont->view({in1_features, out_features * in2_features}); Tensor intermediate = matmul(x1_compat, weight_matrix, 1.0f); @@ -94,9 +104,8 @@ Tensor bilinear(Tensor x1, Tensor x2, Tensor weight, std::optional bias) if (bias) { Tensor bias_broadcast = (*bias)->as_strided( - {batch_size, out_features}, - {0, (*bias)->strides()[0]} - ); + {batch_size, out_features}, + {0, (*bias)->strides()[0]}); out = add(out, bias_broadcast); } return out; diff --git a/src/infinicore/ops/fmod/fmod_infiniop.cc b/src/infinicore/ops/fmod/fmod_infiniop.cc index 31434e325..e796090d0 100644 --- a/src/infinicore/ops/fmod/fmod_infiniop.cc +++ b/src/infinicore/ops/fmod/fmod_infiniop.cc @@ -34,7 +34,7 @@ void calculate(Tensor c, Tensor a, Tensor b) { } else { desc = *desc_opt; } - + size_t workspace_size = 0; INFINICORE_CHECK_ERROR(infiniopGetFmodWorkspaceSize(desc, &workspace_size)); std::shared_ptr workspace = context::allocateMemory(workspace_size); @@ -42,7 +42,6 @@ void calculate(Tensor c, Tensor a, Tensor b) { INFINICORE_CHECK_ERROR(infiniopFmod( desc, workspace->data(), workspace_size, c->data(), a->data(), b->data(), context::getStream())); - } static bool registered = []() { diff --git a/src/infinicore/pybind11/ops.hpp b/src/infinicore/pybind11/ops.hpp index aded1e685..5d954c47f 100644 --- a/src/infinicore/pybind11/ops.hpp +++ b/src/infinicore/pybind11/ops.hpp @@ -4,8 +4,8 @@ #include "ops/adaptive_max_pool1d.hpp" #include "ops/add.hpp" -#include "ops/attention.hpp" #include "ops/asinh.hpp" +#include "ops/attention.hpp" #include "ops/baddbmm.hpp" #include "ops/bilinear.hpp" #include "ops/causal_softmax.hpp" diff --git a/src/infinicore/pybind11/ops/baddbmm.hpp b/src/infinicore/pybind11/ops/baddbmm.hpp index 79f7cb412..3aef0ce20 100644 --- a/src/infinicore/pybind11/ops/baddbmm.hpp +++ b/src/infinicore/pybind11/ops/baddbmm.hpp @@ -8,8 +8,7 @@ namespace py = pybind11; namespace infinicore::ops { - -Tensor py_baddbmm(Tensor input, Tensor batch1, Tensor batch2, float beta = 1.0f, float alpha = 1.0f) { +Tensor py_baddbmm(Tensor input, Tensor batch1, Tensor batch2, float beta = 1.0f, float alpha = 1.0f) { return op::baddbmm(input, batch1, batch2, beta, alpha); } @@ -35,15 +34,15 @@ inline void bind_baddbmm(py::module &m) { Returns: Output tensor after baddbmm operation )doc"); - m.def("baddbmm_", - &py_baddbmm_, - py::arg("out"), - py::arg("input"), - py::arg("batch1"), - py::arg("batch2"), - py::arg("beta") = 1.0f, - py::arg("alpha") = 1.0f, - R"doc(In-place batched matrix-matrix product with addition. + m.def("baddbmm_", + &py_baddbmm_, + py::arg("out"), + py::arg("input"), + py::arg("batch1"), + py::arg("batch2"), + py::arg("beta") = 1.0f, + py::arg("alpha") = 1.0f, + R"doc(In-place batched matrix-matrix product with addition. Args: out: Output tensor input: Input tensor diff --git a/src/infinicore/pybind11/ops/bilinear.hpp b/src/infinicore/pybind11/ops/bilinear.hpp index 08d0447f4..9c8ff80d6 100644 --- a/src/infinicore/pybind11/ops/bilinear.hpp +++ b/src/infinicore/pybind11/ops/bilinear.hpp @@ -41,14 +41,14 @@ inline void bind_bilinear(py::module &m) { Output tensor after bilinear transformation )doc"); - m.def("bilinear_", - &py_bilinear_, - py::arg("out"), - py::arg("x1"), - py::arg("x2"), - py::arg("weight"), - py::arg("bias"), - R"doc(In-place bilinear transformation of two input tensors. 
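The quantity this operator computes is the standard bilinear form, out[n][o] = bias[o] + sum_{i,j} x1[n][i] * W[o][i][j] * x2[n][j]. The implementation above evaluates it with two matmuls over a permuted and flattened weight, but the naive triple loop below is a convenient reference for checking shapes and results; bilinear_ref is an illustrative helper, not part of the codebase.

// Naive reference for the bilinear form (row-major buffers):
// out[n][o] = bias[o] + sum_i sum_j x1[n][i] * W[o][i][j] * x2[n][j]
#include <cstddef>
#include <vector>

std::vector<float> bilinear_ref(const std::vector<float> &x1, // [N, I1]
                                const std::vector<float> &x2, // [N, I2]
                                const std::vector<float> &w,  // [O, I1, I2]
                                const std::vector<float> &b,  // [O] or empty
                                std::size_t N, std::size_t I1,
                                std::size_t I2, std::size_t O) {
    std::vector<float> out(N * O, 0.0f);
    for (std::size_t n = 0; n < N; ++n) {
        for (std::size_t o = 0; o < O; ++o) {
            float acc = b.empty() ? 0.0f : b[o];
            for (std::size_t i = 0; i < I1; ++i) {
                for (std::size_t j = 0; j < I2; ++j) {
                    acc += x1[n * I1 + i] * w[(o * I1 + i) * I2 + j] * x2[n * I2 + j];
                }
            }
            out[n * O + o] = acc;
        }
    }
    return out;
}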
+ m.def("bilinear_", + &py_bilinear_, + py::arg("out"), + py::arg("x1"), + py::arg("x2"), + py::arg("weight"), + py::arg("bias"), + R"doc(In-place bilinear transformation of two input tensors. Args: out: Output tensor x1: First input tensor @@ -56,7 +56,6 @@ inline void bind_bilinear(py::module &m) { weight: Weight tensor bias: Bias tensor (optional) )doc"); - } } // namespace infinicore::ops \ No newline at end of file diff --git a/src/infiniop/ops/adaptive_max_pool1d/cpu/adaptive_max_pool1d_cpu.cc b/src/infiniop/ops/adaptive_max_pool1d/cpu/adaptive_max_pool1d_cpu.cc index 75ade0eb3..69edf83bc 100644 --- a/src/infiniop/ops/adaptive_max_pool1d/cpu/adaptive_max_pool1d_cpu.cc +++ b/src/infiniop/ops/adaptive_max_pool1d/cpu/adaptive_max_pool1d_cpu.cc @@ -1,8 +1,8 @@ #include "adaptive_max_pool1d_cpu.h" #include "../../../devices/cpu/common_cpu.h" #include "../../../reduce/cpu/reduce.h" -#include #include +#include namespace op::adaptive_max_pool1d::cpu { @@ -26,10 +26,10 @@ infiniStatus_t adaptiveMaxPool1d(const AdaptiveMaxPool1dInfo *info, T *y, const const size_t ndim = info->ndim(); const size_t batch_size = info->shape[0]; const size_t channels = ndim > 2 ? info->shape[1] : 1; - + const size_t input_length = info->input_length(); const size_t output_length = info->output_length(); - + // 计算总的任务块数 (Batch * Channels) const ptrdiff_t total_blocks = static_cast(batch_size * channels); @@ -44,27 +44,27 @@ infiniStatus_t adaptiveMaxPool1d(const AdaptiveMaxPool1dInfo *info, T *y, const T *y_ptr_base; if (ndim > 2) { // (N, C, L) - x_ptr_base = x + i * info->x_strides[0] + j * info->x_strides[1]; - y_ptr_base = y + i * info->y_strides[0] + j * info->y_strides[1]; + x_ptr_base = x + i * info->x_strides[0] + j * info->x_strides[1]; + y_ptr_base = y + i * info->y_strides[0] + j * info->y_strides[1]; } else { // (N, L) - x_ptr_base = x + i * info->x_strides[0]; - y_ptr_base = y + i * info->y_strides[0]; + x_ptr_base = x + i * info->x_strides[0]; + y_ptr_base = y + i * info->y_strides[0]; } for (size_t out_idx = 0; out_idx < output_length; ++out_idx) { // 计算池化窗口范围 [start_index, end_index) - // 公式参考 PyTorch: + // 公式参考 PyTorch: // start = floor(out_idx * L_in / L_out) // end = ceil((out_idx + 1) * L_in / L_out) int start_index = std::floor((float)out_idx * input_length / output_length); int end_index = std::ceil((float)(out_idx + 1) * input_length / output_length); - + start_index = std::max(start_index, 0); end_index = std::min(end_index, (int)input_length); int window_len = end_index - start_index; if (window_len <= 0) { - continue; + continue; } const T *window_ptr = x_ptr_base + start_index * x_stride_last; @@ -79,9 +79,9 @@ infiniStatus_t adaptiveMaxPool1d(const AdaptiveMaxPool1dInfo *info, T *y, const infiniStatus_t Descriptor::calculate( void *workspace, size_t workspace_size, - void *y, const void *x, + void *y, const void *x, void *stream) const { - + if (_info.atype == INFINI_DTYPE_F32) { return adaptiveMaxPool1d(&_info, (float *)y, (const float *)x); } else if (_info.atype == INFINI_DTYPE_F16) { @@ -91,7 +91,7 @@ infiniStatus_t Descriptor::calculate( } else if (_info.atype == INFINI_DTYPE_F64) { return adaptiveMaxPool1d(&_info, (double *)y, (const double *)x); } - + return INFINI_STATUS_BAD_TENSOR_DTYPE; } diff --git a/src/infiniop/ops/adaptive_max_pool1d/cuda/kernel.cuh b/src/infiniop/ops/adaptive_max_pool1d/cuda/kernel.cuh index 5f21748bc..814688846 100644 --- a/src/infiniop/ops/adaptive_max_pool1d/cuda/kernel.cuh +++ b/src/infiniop/ops/adaptive_max_pool1d/cuda/kernel.cuh @@ -16,16 +16,16 @@ 
__device__ void adaptiveMaxPool1dBlock( size_t channels, size_t input_length, size_t output_length, - size_t ndim){ - + size_t ndim) { + size_t block_idx = blockIdx.x; size_t batch_idx = block_idx / channels; size_t channel_idx = block_idx % channels; - + const Tdata *x_ptr; Tdata *y_ptr; - if(ndim > 2) { + if (ndim > 2) { x_ptr = x + batch_idx * stride_x_batch + channel_idx * stride_x_channel; y_ptr = y + batch_idx * stride_y_batch + channel_idx * stride_y_channel; } else { @@ -33,11 +33,11 @@ __device__ void adaptiveMaxPool1dBlock( y_ptr = y + batch_idx * stride_y_batch; } - for (size_t out_idx = threadIdx.x; out_idx < output_length ; out_idx += BLOCK_SIZE) { + for (size_t out_idx = threadIdx.x; out_idx < output_length; out_idx += BLOCK_SIZE) { int start_index = static_cast(floorf((float)out_idx * input_length / output_length)); int end_index = static_cast(ceilf((float)(out_idx + 1) * input_length / output_length)); - if (end_index <= start_index) { + if (end_index <= start_index) { continue; } @@ -49,7 +49,6 @@ __device__ void adaptiveMaxPool1dBlock( y_ptr[out_idx] = Tdata(max_val); } - } #endif \ No newline at end of file diff --git a/src/infiniop/ops/adaptive_max_pool1d/info.h b/src/infiniop/ops/adaptive_max_pool1d/info.h index f346bca9a..7194d2d93 100644 --- a/src/infiniop/ops/adaptive_max_pool1d/info.h +++ b/src/infiniop/ops/adaptive_max_pool1d/info.h @@ -24,14 +24,13 @@ class AdaptiveMaxPool1dInfo { static utils::Result create( infiniopTensorDescriptor_t y_desc, infiniopTensorDescriptor_t x_desc, - size_t output_size) { + size_t output_size) { auto atype = y_desc->dtype(); - if (x_desc->dtype() != atype){ + if (x_desc->dtype() != atype) { return INFINI_STATUS_BAD_TENSOR_DTYPE; } - if (atype != INFINI_DTYPE_F16 && atype != INFINI_DTYPE_BF16 && - atype != INFINI_DTYPE_F32 && atype != INFINI_DTYPE_F64) { + if (atype != INFINI_DTYPE_F16 && atype != INFINI_DTYPE_BF16 && atype != INFINI_DTYPE_F32 && atype != INFINI_DTYPE_F64) { return INFINI_STATUS_BAD_TENSOR_DTYPE; } @@ -43,11 +42,11 @@ class AdaptiveMaxPool1dInfo { } for (size_t i = 0; i < y_ndim - 1; ++i) { - if (x_desc->dim(i) != y_desc->dim(i)){ + if (x_desc->dim(i) != y_desc->dim(i)) { return INFINI_STATUS_BAD_TENSOR_SHAPE; } } - + if (y_desc->dim(y_ndim - 1) != output_size) { return INFINI_STATUS_BAD_TENSOR_SHAPE; } @@ -58,11 +57,8 @@ class AdaptiveMaxPool1dInfo { y_desc->strides(), x_desc->strides(), x_desc->dim(x_ndim - 1), - output_size - }); - + output_size}); } - }; } // namespace op::adaptive_max_pool1d diff --git a/src/infiniop/ops/adaptive_max_pool1d/nvidia/adaptive_max_pool1d_nvidia.cu b/src/infiniop/ops/adaptive_max_pool1d/nvidia/adaptive_max_pool1d_nvidia.cu index f85f0a135..96ffe573f 100644 --- a/src/infiniop/ops/adaptive_max_pool1d/nvidia/adaptive_max_pool1d_nvidia.cu +++ b/src/infiniop/ops/adaptive_max_pool1d/nvidia/adaptive_max_pool1d_nvidia.cu @@ -17,7 +17,7 @@ INFINIOP_CUDA_KERNEL adaptiveMaxPool1dKernel( size_t channels, size_t input_length, size_t output_length, - size_t ndim){ + size_t ndim) { adaptiveMaxPool1dBlock( y, stride_y_batch, stride_y_channel, @@ -63,12 +63,12 @@ infiniStatus_t launchKernel( size_t channels, size_t input_length, size_t output_length, size_t ndim, cudaStream_t cuda_stream) { -#define LAUNCH_KERNEL(Tdata, Tcompute) \ +#define LAUNCH_KERNEL(Tdata, Tcompute) \ adaptiveMaxPool1dKernel<<>>( \ - reinterpret_cast(y), \ - stride_y_batch, stride_y_channel, \ - reinterpret_cast(x), \ - stride_x_batch, stride_x_channel, stride_x_length, \ + reinterpret_cast(y), \ + stride_y_batch, 
stride_y_channel, \ + reinterpret_cast(x), \ + stride_x_batch, stride_x_channel, stride_x_length, \ channels, input_length, output_length, ndim) if (dtype == INFINI_DTYPE_F16) { diff --git a/src/infiniop/ops/adaptive_max_pool1d/operator.cc b/src/infiniop/ops/adaptive_max_pool1d/operator.cc index c301db551..7048a1033 100644 --- a/src/infiniop/ops/adaptive_max_pool1d/operator.cc +++ b/src/infiniop/ops/adaptive_max_pool1d/operator.cc @@ -22,13 +22,13 @@ __C infiniStatus_t infiniopCreateAdaptiveMaxPool1dDescriptor( infiniopTensorDescriptor_t x_desc, size_t output_size) { -#define CREATE(CASE, NAMESPACE) \ - case CASE: \ +#define CREATE(CASE, NAMESPACE) \ + case CASE: \ return op::adaptive_max_pool1d::NAMESPACE::Descriptor::create( \ - handle, \ + handle, \ reinterpret_cast(desc_ptr), \ - y_desc, \ - x_desc, \ + y_desc, \ + x_desc, \ output_size) switch (handle->device) { @@ -54,13 +54,13 @@ __C infiniStatus_t infiniopCreateAdaptiveMaxPool1dDescriptor( } __C infiniStatus_t infiniopGetAdaptiveMaxPool1dWorkspaceSize( - infiniopAdaptiveMaxPool1dDescriptor_t desc, + infiniopAdaptiveMaxPool1dDescriptor_t desc, size_t *size) { -#define GET(CASE, NAMESPACE) \ - case CASE: \ - *size = reinterpret_cast(desc)->workspaceSize(); \ +#define GET(CASE, NAMESPACE) \ + case CASE: \ + *size = reinterpret_cast(desc)->workspaceSize(); \ return INFINI_STATUS_SUCCESS; - + switch (desc->device_type) { #ifdef ENABLE_CPU_API GET(INFINI_DEVICE_CPU, cpu); @@ -84,14 +84,14 @@ __C infiniStatus_t infiniopGetAdaptiveMaxPool1dWorkspaceSize( } __C infiniStatus_t infiniopAdaptiveMaxPool1d( - infiniopAdaptiveMaxPool1dDescriptor_t desc, - void *workspace, + infiniopAdaptiveMaxPool1dDescriptor_t desc, + void *workspace, size_t workspace_size, - void *y, - const void *x, + void *y, + const void *x, void *stream) { -#define CALCULATE(CASE, NAMESPACE) \ - case CASE: \ +#define CALCULATE(CASE, NAMESPACE) \ + case CASE: \ return reinterpret_cast(desc)->calculate( \ workspace, workspace_size, y, x, stream); @@ -114,13 +114,13 @@ __C infiniStatus_t infiniopAdaptiveMaxPool1d( } #undef CALCULATE - return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; } __C infiniStatus_t infiniopDestroyAdaptiveMaxPool1dDescriptor( infiniopAdaptiveMaxPool1dDescriptor_t desc) { -#define DESTROY(CASE, NAMESPACE) \ - case CASE: \ +#define DESTROY(CASE, NAMESPACE) \ + case CASE: \ delete reinterpret_cast(desc); \ return INFINI_STATUS_SUCCESS; diff --git a/src/infiniop/ops/asinh/cpu/asinh_cpu.cc b/src/infiniop/ops/asinh/cpu/asinh_cpu.cc index 860e06032..4d7627473 100644 --- a/src/infiniop/ops/asinh/cpu/asinh_cpu.cc +++ b/src/infiniop/ops/asinh/cpu/asinh_cpu.cc @@ -9,7 +9,7 @@ infiniStatus_t Descriptor::create( Descriptor **desc_ptr, infiniopTensorDescriptor_t out_desc, std::vector input_desc_vec) { - + auto handle = reinterpret_cast(handle_); auto dtype = out_desc->dtype(); @@ -27,22 +27,22 @@ infiniStatus_t Descriptor::create( } infiniStatus_t Descriptor::calculate( - void *workspace, - size_t workspace_size, - void *output, - std::vector inputs, - void *stream) const{ + void *workspace, + size_t workspace_size, + void *output, + std::vector inputs, + void *stream) const { switch (_dtype) { case INFINI_DTYPE_F16: return _device_info->calculate(_info, output, inputs, stream); case INFINI_DTYPE_F32: - return _device_info->calculate(_info, output, inputs, stream); + return _device_info->calculate(_info, output, inputs, stream); case INFINI_DTYPE_F64: return _device_info->calculate(_info, output, inputs, stream); case 
INFINI_DTYPE_BF16: return _device_info->calculate(_info, output, inputs, stream); - default : + default: return INFINI_STATUS_BAD_TENSOR_DTYPE; } } diff --git a/src/infiniop/ops/asinh/cpu/asinh_cpu.h b/src/infiniop/ops/asinh/cpu/asinh_cpu.h index 196e208a7..076fcb30a 100644 --- a/src/infiniop/ops/asinh/cpu/asinh_cpu.h +++ b/src/infiniop/ops/asinh/cpu/asinh_cpu.h @@ -11,7 +11,7 @@ namespace op::asinh::cpu { typedef struct AsinhOp { public: static constexpr size_t num_inputs = 1; - + template T operator()(const T &x) const { return std::asinh(x); diff --git a/src/infiniop/ops/asinh/cuda/kernel.cuh b/src/infiniop/ops/asinh/cuda/kernel.cuh index eba99efd0..2bd6dcbf0 100644 --- a/src/infiniop/ops/asinh/cuda/kernel.cuh +++ b/src/infiniop/ops/asinh/cuda/kernel.cuh @@ -1,31 +1,29 @@ #ifndef __ASINH_CUDA_KERNEL_H__ #define __ASINH_CUDA_KERNEL_H__ -namespace op::asinh::cuda{ - +namespace op::asinh::cuda { + typedef struct AsinhOp { -public : +public: static constexpr size_t num_inputs = 1; template __device__ __forceinline__ T operator()(const T &x) const { - - if constexpr (std::is_same_v){ + + if constexpr (std::is_same_v) { float x_f = __half2float(x); return __float2half(asinhf(x_f)); - } - else if constexpr (std::is_same_v) { + } else if constexpr (std::is_same_v) { float x_f = __bfloat162float(x); return __float2bfloat16(asinhf(x_f)); - } - else if constexpr (std::is_same_v) { + } else if constexpr (std::is_same_v) { return asinhf(x); - }else { + } else { return ::asinh(x); } } } AsinhOp; -} // namespace op::Asinh::cuda +} // namespace op::asinh::cuda #endif // __ASINH_CUDA_KERNEL_H__ \ No newline at end of file diff --git a/src/infiniop/ops/asinh/nvidia/asinh_nvidia.cu b/src/infiniop/ops/asinh/nvidia/asinh_nvidia.cu index 788ab4502..77a4652bc 100644 --- a/src/infiniop/ops/asinh/nvidia/asinh_nvidia.cu +++ b/src/infiniop/ops/asinh/nvidia/asinh_nvidia.cu @@ -26,7 +26,7 @@ infiniStatus_t Descriptor::create( CHECK_SAME_SHAPE(y_shape, x_shape); CREATE_ELEMENTWISE_CUDA_DESCRIPTOR(handle, dtype, out_desc, input_desc_vec) - + return INFINI_STATUS_SUCCESS; } @@ -41,17 +41,16 @@ infiniStatus_t Descriptor::calculate( } switch (_dtype) { case INFINI_DTYPE_F16: - return _device_info->calculate<256, cuda::AsinhOp,half>(_info, workspace, output, inputs, stream); + return _device_info->calculate<256, cuda::AsinhOp, half>(_info, workspace, output, inputs, stream); case INFINI_DTYPE_BF16: - return _device_info->calculate<256, cuda::AsinhOp,cuda_bfloat16>(_info, workspace, output, inputs, stream); + return _device_info->calculate<256, cuda::AsinhOp, cuda_bfloat16>(_info, workspace, output, inputs, stream); case INFINI_DTYPE_F32: - return _device_info->calculate<256, cuda::AsinhOp,float>(_info, workspace, output, inputs, stream); + return _device_info->calculate<256, cuda::AsinhOp, float>(_info, workspace, output, inputs, stream); case INFINI_DTYPE_F64: - return _device_info->calculate<256, cuda::AsinhOp,double>(_info, workspace, output, inputs, stream); + return _device_info->calculate<256, cuda::AsinhOp, double>(_info, workspace, output, inputs, stream); default: return INFINI_STATUS_BAD_TENSOR_DTYPE; } - -} +} } // namespace op::asinh::nvidia \ No newline at end of file diff --git a/src/infiniop/ops/asinh/operator.cc b/src/infiniop/ops/asinh/operator.cc index 3a4eb5d14..e3decacf1 100644 --- a/src/infiniop/ops/asinh/operator.cc +++ b/src/infiniop/ops/asinh/operator.cc @@ -5,7 +5,7 @@ #ifdef ENABLE_CPU_API #include "cpu/asinh_cpu.h" #endif -#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) +#if 
defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) #include "nvidia/asinh_nvidia.cuh" #endif #ifdef ENABLE_METAX_API @@ -20,12 +20,12 @@ __C infiniStatus_t infiniopCreateAsinhDescriptor( infiniopAsinhDescriptor_t *desc_ptr, infiniopTensorDescriptor_t y_desc, infiniopTensorDescriptor_t x_desc) { -#define CREATE(CASE, NAMESPACE) \ - case CASE: \ - return op::asinh::NAMESPACE::Descriptor::create( \ - handle, \ - reinterpret_cast(desc_ptr), \ - y_desc, \ +#define CREATE(CASE, NAMESPACE) \ + case CASE: \ + return op::asinh::NAMESPACE::Descriptor::create( \ + handle, \ + reinterpret_cast(desc_ptr), \ + y_desc, \ {x_desc}) switch (handle->device) { @@ -52,19 +52,19 @@ __C infiniStatus_t infiniopCreateAsinhDescriptor( __C infiniStatus_t infiniopGetAsinhWorkspaceSize(infiniopAsinhDescriptor_t desc, size_t *size) { -#define GET(CASE, NAMESPACE) \ - case CASE: \ - *size = reinterpret_cast(desc)->workspaceSize(); \ +#define GET(CASE, NAMESPACE) \ + case CASE: \ + *size = reinterpret_cast(desc)->workspaceSize(); \ return INFINI_STATUS_SUCCESS switch (desc->device_type) { #ifdef ENABLE_CPU_API GET(INFINI_DEVICE_CPU, cpu); #endif -#ifdef ENABLE_NVIDIA_API +#ifdef ENABLE_NVIDIA_API GET(INFINI_DEVICE_NVIDIA, nvidia); #endif -#ifdef ENABLE_ILUVATAR_API +#ifdef ENABLE_ILUVATAR_API GET(INFINI_DEVICE_ILUVATAR, nvidia); #endif #ifdef ENABLE_METAX_API @@ -86,18 +86,18 @@ __C infiniStatus_t infiniopAsinh(infiniopAsinhDescriptor_t desc, void *y, const void *x, void *stream) { -#define CALCULATE(CASE, NAMESPACE) \ - case CASE: \ +#define CALCULATE(CASE, NAMESPACE) \ + case CASE: \ return reinterpret_cast(desc) \ ->calculate(workspace, workspace_size, y, {x}, stream); switch (desc->device_type) { #ifdef ENABLE_CPU_API CALCULATE(INFINI_DEVICE_CPU, cpu); #endif -#ifdef ENABLE_NVIDIA_API +#ifdef ENABLE_NVIDIA_API CALCULATE(INFINI_DEVICE_NVIDIA, nvidia); #endif -#ifdef ENABLE_ILUVATAR_API +#ifdef ENABLE_ILUVATAR_API CALCULATE(INFINI_DEVICE_ILUVATAR, nvidia); #endif #ifdef ENABLE_METAX_API @@ -113,19 +113,19 @@ __C infiniStatus_t infiniopAsinh(infiniopAsinhDescriptor_t desc, } __C infiniStatus_t infiniopDestroyAsinhDescriptor(infiniopAsinhDescriptor_t desc) { -#define GET(CASE, NAMESPACE) \ - case CASE: \ +#define GET(CASE, NAMESPACE) \ + case CASE: \ reinterpret_cast(desc); \ return INFINI_STATUS_SUCCESS; - + switch (desc->device_type) { #ifdef ENABLE_CPU_API GET(INFINI_DEVICE_CPU, cpu); #endif -#ifdef ENABLE_NVIDIA_API +#ifdef ENABLE_NVIDIA_API GET(INFINI_DEVICE_NVIDIA, nvidia); #endif -#ifdef ENABLE_ILUVATAR_API +#ifdef ENABLE_ILUVATAR_API GET(INFINI_DEVICE_ILUVATAR, nvidia); #endif #ifdef ENABLE_METAX_API diff --git a/src/infiniop/ops/fmod/cpu/fmod_cpu.cc b/src/infiniop/ops/fmod/cpu/fmod_cpu.cc index 1114470a4..1f27290de 100644 --- a/src/infiniop/ops/fmod/cpu/fmod_cpu.cc +++ b/src/infiniop/ops/fmod/cpu/fmod_cpu.cc @@ -50,4 +50,4 @@ infiniStatus_t Descriptor::calculate( } return INFINI_STATUS_SUCCESS; } -} // namespace op::mul::cpu +} // namespace op::fmod::cpu diff --git a/src/infiniop/ops/fmod/nvidia/fmod_nvidia.cu b/src/infiniop/ops/fmod/nvidia/fmod_nvidia.cu index 5f834ef95..a74295264 100644 --- a/src/infiniop/ops/fmod/nvidia/fmod_nvidia.cu +++ b/src/infiniop/ops/fmod/nvidia/fmod_nvidia.cu @@ -26,7 +26,7 @@ infiniStatus_t Descriptor::create( CHECK_SAME_SHAPE(c_shape, a_shape, b_shape); CREATE_ELEMENTWISE_CUDA_DESCRIPTOR(handle, dtype, out_desc, input_desc_vec) - + return INFINI_STATUS_SUCCESS; } @@ -37,19 +37,19 @@ infiniStatus_t Descriptor::calculate( std::vector inputs, void *stream) const { - if 
(workspace_size < _workspace_size){ + if (workspace_size < _workspace_size) { return INFINI_STATUS_INSUFFICIENT_WORKSPACE; } - switch (_dtype){ + switch (_dtype) { case INFINI_DTYPE_F16: - return _device_info->calculate<256, cuda::FmodOp, half>(_info, workspace, output, inputs, stream); + return _device_info->calculate<256, cuda::FmodOp, half>(_info, workspace, output, inputs, stream); case INFINI_DTYPE_F32: - return _device_info->calculate<256, cuda::FmodOp, float>(_info, workspace, output, inputs, stream); + return _device_info->calculate<256, cuda::FmodOp, float>(_info, workspace, output, inputs, stream); case INFINI_DTYPE_F64: - return _device_info->calculate<256, cuda::FmodOp, double>(_info, workspace, output, inputs, stream); + return _device_info->calculate<256, cuda::FmodOp, double>(_info, workspace, output, inputs, stream); case INFINI_DTYPE_BF16: - return _device_info->calculate<256, cuda::FmodOp, cuda_bfloat16>(_info, workspace, output, inputs, stream); + return _device_info->calculate<256, cuda::FmodOp, cuda_bfloat16>(_info, workspace, output, inputs, stream); default: return INFINI_STATUS_BAD_TENSOR_DTYPE; } diff --git a/src/infiniop/ops/fmod/operator.cc b/src/infiniop/ops/fmod/operator.cc index c94fdfc10..1fd433c4a 100644 --- a/src/infiniop/ops/fmod/operator.cc +++ b/src/infiniop/ops/fmod/operator.cc @@ -5,7 +5,7 @@ #ifdef ENABLE_CPU_API #include "cpu/fmod_cpu.h" #endif -#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) +#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) #include "nvidia/fmod_nvidia.cuh" #endif #ifdef ENABLE_METAX_API @@ -21,14 +21,14 @@ __C infiniStatus_t infiniopCreateFmodDescriptor( infiniopTensorDescriptor_t c_desc, infiniopTensorDescriptor_t a_desc, infiniopTensorDescriptor_t b_desc) { -#define CREATE(CASE, NAMESPACE) \ - case CASE: \ +#define CREATE(CASE, NAMESPACE) \ + case CASE: \ return op::fmod::NAMESPACE::Descriptor::create( \ - handle, \ + handle, \ reinterpret_cast(desc_ptr), \ - c_desc, \ - {a_desc, \ - b_desc}) + c_desc, \ + {a_desc, \ + b_desc}) switch (handle->device) { @@ -92,19 +92,19 @@ __C infiniStatus_t infiniopFmod( const void *a, const void *b, void *stream) { -#define CALCULATE(CASE, NAMESPACE) \ - case CASE: \ +#define CALCULATE(CASE, NAMESPACE) \ + case CASE: \ return reinterpret_cast(desc) \ - ->calculate(workspace, workspace_size, c, {a, b}, stream) + ->calculate(workspace, workspace_size, c, {a, b}, stream) switch (desc->device_type) { #ifdef ENABLE_CPU_API CALCULATE(INFINI_DEVICE_CPU, cpu); #endif -#ifdef ENABLE_NVIDIA_API +#ifdef ENABLE_NVIDIA_API CALCULATE(INFINI_DEVICE_NVIDIA, nvidia); #endif -#ifdef ENABLE_ILUVATAR_API +#ifdef ENABLE_ILUVATAR_API CALCULATE(INFINI_DEVICE_ILUVATAR, nvidia); #endif #ifdef ENABLE_METAX_API @@ -121,8 +121,8 @@ __C infiniStatus_t infiniopFmod( __C infiniStatus_t infiniopDestroyFmodDescriptor(infiniopFmodDescriptor_t desc) { -#define GET(CASE, NAMESPACE) \ - case CASE: \ +#define GET(CASE, NAMESPACE) \ + case CASE: \ delete reinterpret_cast(desc); \ return INFINI_STATUS_SUCCESS; @@ -144,11 +144,9 @@ __C infiniStatus_t infiniopDestroyFmodDescriptor(infiniopFmodDescriptor_t desc) #endif #ifdef ENABLE_MOORE_API GET(INFINI_DEVICE_MOORE, moore); -#endif +#endif default: return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; } #undef DELETE } - - \ No newline at end of file diff --git a/src/infiniop/ops/gemm/cpu/gemm_cpu.cc b/src/infiniop/ops/gemm/cpu/gemm_cpu.cc index d19965614..6f7a2e3e0 100644 --- a/src/infiniop/ops/gemm/cpu/gemm_cpu.cc +++ b/src/infiniop/ops/gemm/cpu/gemm_cpu.cc @@ 
-64,7 +64,11 @@ void calculate( *c_ = utils::cast(beta * utils::cast(*c_) + alpha * sum); } } else { - *c_ = beta * (*c_) + alpha * sum; + if (beta == 0) { + *c_ = alpha * sum; + } else { + *c_ = beta * (*c_) + alpha * sum; + } } } } diff --git a/test/infinicore/ops/bilinear.py b/test/infinicore/ops/bilinear.py index 765e423cf..2b9970a4c 100644 --- a/test/infinicore/ops/bilinear.py +++ b/test/infinicore/ops/bilinear.py @@ -43,7 +43,7 @@ def parse_test_cases(): in1 = TensorSpec.from_tensor(in1_shape, in1_strides, dtype) in2 = TensorSpec.from_tensor(in2_shape, in2_strides, dtype) weight = TensorSpec.from_tensor(weight_shape, weight_strides, dtype) - + inputs = [in1, in2, weight] if bias_present: bias_shape = (weight_shape[0],) @@ -80,6 +80,7 @@ def torch_operator(self, *args, **kwargs): def infinicore_operator(self, *args, **kwargs): from infinicore.ops.bilinear import bilinear + return bilinear(*args, **kwargs) diff --git a/test/infiniop/bilinear.py b/test/infiniop/bilinear.py index 47af60812..31dac2063 100644 --- a/test/infiniop/bilinear.py +++ b/test/infiniop/bilinear.py @@ -3,39 +3,39 @@ from ctypes import c_uint64 from libinfiniop import ( - LIBINFINIOP, - TestTensor, - get_test_devices, - check_error, - test_operator, - get_args, - debug, - get_tolerance, - profile_operation, - TestWorkspace, - InfiniDtype, - InfiniDtypeNames, - InfiniDeviceNames, - infiniopOperatorDescriptor_t, + LIBINFINIOP, + TestTensor, + get_test_devices, + check_error, + test_operator, + get_args, + debug, + get_tolerance, + profile_operation, + TestWorkspace, + InfiniDtype, + InfiniDtypeNames, + InfiniDeviceNames, + infiniopOperatorDescriptor_t, ) _TEST_CASES = [ - # batch, in1, in2, out, use_bias - (4, 3, 5, 2, True), - (1, 6, 7, 3, True), - (8, 2, 4, 5, False), - (2, 3, 3, 4, True), - (6, 10, 12, 7, False), - (3, 1, 1, 2, True), + # batch, in1, in2, out, use_bias + (4, 3, 5, 2, True), + (1, 6, 7, 3, True), + (8, 2, 4, 5, False), + (2, 3, 3, 4, True), + (6, 10, 12, 7, False), + (3, 1, 1, 2, True), ] _TENSOR_DTYPES = [InfiniDtype.F16, InfiniDtype.BF16, InfiniDtype.F32] _TOLERANCE_MAP = { - InfiniDtype.F16: {"atol": 1e-3, "rtol": 1e-2}, - InfiniDtype.BF16: {"atol": 1e-2, "rtol": 5e-2}, - InfiniDtype.F32: {"atol": 1e-5, "rtol": 1e-4}, + InfiniDtype.F16: {"atol": 1e-3, "rtol": 1e-2}, + InfiniDtype.BF16: {"atol": 1e-2, "rtol": 5e-2}, + InfiniDtype.F32: {"atol": 1e-5, "rtol": 1e-4}, } DEBUG = False @@ -45,137 +45,139 @@ def reference_bilinear(x1, x2, weight, bias): - """Compute bilinear output in FP32 for stability and cast back.""" - out = torch.einsum( - "ni,oij,nj->no", - x1.to(torch.float32), - weight.to(torch.float32), - x2.to(torch.float32), - ) - if bias is not None: - out = out + bias.to(torch.float32) - return out.to(x1.dtype) + """Compute bilinear output in FP32 for stability and cast back.""" + out = torch.einsum( + "ni,oij,nj->no", + x1.to(torch.float32), + weight.to(torch.float32), + x2.to(torch.float32), + ) + if bias is not None: + out = out + bias.to(torch.float32) + return out.to(x1.dtype) def test( - handle, - device, - batch, - in1_features, - in2_features, - out_features, - use_bias, - dtype=InfiniDtype.F16, - sync=None, + handle, + device, + batch, + in1_features, + in2_features, + out_features, + use_bias, + dtype=InfiniDtype.F16, + sync=None, ): - print( - f"Testing Bilinear on {InfiniDeviceNames[device]} with N:{batch} in1:{in1_features} in2:{in2_features} " - f"out:{out_features} bias:{use_bias} dtype:{InfiniDtypeNames[dtype]}" - ) - - out_tensor = TestTensor((batch, out_features), None, 
dtype, device, mode="zeros") - x1 = TestTensor((batch, in1_features), None, dtype, device, scale=0.1, bias=-0.05) - x2 = TestTensor((batch, in2_features), None, dtype, device, scale=0.1, bias=-0.05) - weight = TestTensor( - (out_features, in1_features, in2_features), - None, - dtype, - device, - scale=0.1, - bias=-0.05, - ) - bias_tensor = ( - TestTensor((out_features,), None, dtype, device, scale=0.1, bias=-0.05) - if use_bias - else None - ) - - ref = reference_bilinear( - x1.torch_tensor(), - x2.torch_tensor(), - weight.torch_tensor(), - bias_tensor.torch_tensor() if bias_tensor else None, - ) - - if sync is not None: - sync() - - descriptor = infiniopOperatorDescriptor_t() - check_error( - LIBINFINIOP.infiniopCreateBilinearDescriptor( - handle, - ctypes.byref(descriptor), - out_tensor.descriptor, - x1.descriptor, - x2.descriptor, - weight.descriptor, - bias_tensor.descriptor if bias_tensor else None, - ) - ) - - tensors = [out_tensor, x1, x2, weight] - if bias_tensor: - tensors.append(bias_tensor) - for tensor in tensors: - tensor.destroy_desc() - - workspace_size = c_uint64(0) - check_error( - LIBINFINIOP.infiniopGetBilinearWorkspaceSize( - descriptor, ctypes.byref(workspace_size) - ) - ) - workspace = TestWorkspace(workspace_size.value, device) - - def lib_bilinear(): - check_error( - LIBINFINIOP.infiniopBilinear( - descriptor, - workspace.data(), - workspace_size.value, - out_tensor.data(), - x1.data(), - x2.data(), - weight.data(), - bias_tensor.data() if bias_tensor else None, - None, - ) - ) - - lib_bilinear() - - atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype) - if DEBUG: - debug(out_tensor.actual_tensor(), ref, atol=atol, rtol=rtol) - assert torch.allclose(out_tensor.actual_tensor(), ref, atol=atol, rtol=rtol) - - if PROFILE: - profile_operation( - "PyTorch", - lambda: reference_bilinear( - x1.torch_tensor(), - x2.torch_tensor(), - weight.torch_tensor(), - bias_tensor.torch_tensor() if bias_tensor else None, - ), - device, - NUM_PRERUN, - NUM_ITERATIONS, - ) - profile_operation(" lib", lambda: lib_bilinear(), device, NUM_PRERUN, NUM_ITERATIONS) - - check_error(LIBINFINIOP.infiniopDestroyBilinearDescriptor(descriptor)) + print( + f"Testing Bilinear on {InfiniDeviceNames[device]} with N:{batch} in1:{in1_features} in2:{in2_features} " + f"out:{out_features} bias:{use_bias} dtype:{InfiniDtypeNames[dtype]}" + ) + + out_tensor = TestTensor((batch, out_features), None, dtype, device, mode="zeros") + x1 = TestTensor((batch, in1_features), None, dtype, device, scale=0.1, bias=-0.05) + x2 = TestTensor((batch, in2_features), None, dtype, device, scale=0.1, bias=-0.05) + weight = TestTensor( + (out_features, in1_features, in2_features), + None, + dtype, + device, + scale=0.1, + bias=-0.05, + ) + bias_tensor = ( + TestTensor((out_features,), None, dtype, device, scale=0.1, bias=-0.05) + if use_bias + else None + ) + + ref = reference_bilinear( + x1.torch_tensor(), + x2.torch_tensor(), + weight.torch_tensor(), + bias_tensor.torch_tensor() if bias_tensor else None, + ) + + if sync is not None: + sync() + + descriptor = infiniopOperatorDescriptor_t() + check_error( + LIBINFINIOP.infiniopCreateBilinearDescriptor( + handle, + ctypes.byref(descriptor), + out_tensor.descriptor, + x1.descriptor, + x2.descriptor, + weight.descriptor, + bias_tensor.descriptor if bias_tensor else None, + ) + ) + + tensors = [out_tensor, x1, x2, weight] + if bias_tensor: + tensors.append(bias_tensor) + for tensor in tensors: + tensor.destroy_desc() + + workspace_size = c_uint64(0) + check_error( + 
LIBINFINIOP.infiniopGetBilinearWorkspaceSize( + descriptor, ctypes.byref(workspace_size) + ) + ) + workspace = TestWorkspace(workspace_size.value, device) + + def lib_bilinear(): + check_error( + LIBINFINIOP.infiniopBilinear( + descriptor, + workspace.data(), + workspace_size.value, + out_tensor.data(), + x1.data(), + x2.data(), + weight.data(), + bias_tensor.data() if bias_tensor else None, + None, + ) + ) + + lib_bilinear() + + atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype) + if DEBUG: + debug(out_tensor.actual_tensor(), ref, atol=atol, rtol=rtol) + assert torch.allclose(out_tensor.actual_tensor(), ref, atol=atol, rtol=rtol) + + if PROFILE: + profile_operation( + "PyTorch", + lambda: reference_bilinear( + x1.torch_tensor(), + x2.torch_tensor(), + weight.torch_tensor(), + bias_tensor.torch_tensor() if bias_tensor else None, + ), + device, + NUM_PRERUN, + NUM_ITERATIONS, + ) + profile_operation( + " lib", lambda: lib_bilinear(), device, NUM_PRERUN, NUM_ITERATIONS + ) + + check_error(LIBINFINIOP.infiniopDestroyBilinearDescriptor(descriptor)) if __name__ == "__main__": - args = get_args() + args = get_args() - DEBUG = args.debug - PROFILE = args.profile - NUM_PRERUN = args.num_prerun - NUM_ITERATIONS = args.num_iterations + DEBUG = args.debug + PROFILE = args.profile + NUM_PRERUN = args.num_prerun + NUM_ITERATIONS = args.num_iterations - for device in get_test_devices(args): - test_operator(device, test, _TEST_CASES, _TENSOR_DTYPES) + for device in get_test_devices(args): + test_operator(device, test, _TEST_CASES, _TENSOR_DTYPES) - print("\033[92mTest passed!\033[0m") + print("\033[92mTest passed!\033[0m") diff --git a/test/infiniop/libinfiniop/op_register.py b/test/infiniop/libinfiniop/op_register.py index 178f57899..5b2974111 100644 --- a/test/infiniop/libinfiniop/op_register.py +++ b/test/infiniop/libinfiniop/op_register.py @@ -938,41 +938,3 @@ def tanh_(lib): lib.infiniopDestroyTanhDescriptor.argtypes = [ infiniopOperatorDescriptor_t, ] - - -@OpRegister.operator -def bilinear_(lib): - lib.infiniopCreateBilinearDescriptor.restype = c_int32 - lib.infiniopCreateBilinearDescriptor.argtypes = [ - infiniopHandle_t, - POINTER(infiniopOperatorDescriptor_t), - infiniopTensorDescriptor_t, - infiniopTensorDescriptor_t, - infiniopTensorDescriptor_t, - infiniopTensorDescriptor_t, - infiniopTensorDescriptor_t, - ] - - lib.infiniopGetBilinearWorkspaceSize.restype = c_int32 - lib.infiniopGetBilinearWorkspaceSize.argtypes = [ - infiniopOperatorDescriptor_t, - POINTER(c_size_t), - ] - - lib.infiniopBilinear.restype = c_int32 - lib.infiniopBilinear.argtypes = [ - infiniopOperatorDescriptor_t, - c_void_p, - c_size_t, - c_void_p, - c_void_p, - c_void_p, - c_void_p, - c_void_p, - c_void_p, - ] - - lib.infiniopDestroyBilinearDescriptor.restype = c_int32 - lib.infiniopDestroyBilinearDescriptor.argtypes = [ - infiniopOperatorDescriptor_t, - ] \ No newline at end of file From 13d2d4a74f38e3e339ccb195faf13ed9f812ed94 Mon Sep 17 00:00:00 2001 From: Kinorw Date: Sun, 14 Dec 2025 05:31:47 +0800 Subject: [PATCH 32/32] =?UTF-8?q?=E6=B8=85=E7=90=86=E6=97=A0=E5=85=B3?= =?UTF-8?q?=E5=86=97=E4=BD=99=E6=96=87=E4=BB=B6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .gitignore | 4 +- sbatch.sh | 3 - test/infiniop/bilinear.py | 183 -------------------------------------- 3 files changed, 1 insertion(+), 189 deletions(-) delete mode 100755 sbatch.sh delete mode 100644 test/infiniop/bilinear.py diff --git a/.gitignore b/.gitignore index 7683cae5e..4944a5db4 100644 --- 
a/.gitignore +++ b/.gitignore @@ -31,6 +31,4 @@ cache/ # Compressed *.gz *.zip -*.tar - -*.txt \ No newline at end of file +*.tar \ No newline at end of file diff --git a/sbatch.sh b/sbatch.sh deleted file mode 100755 index fa3ff1e47..000000000 --- a/sbatch.sh +++ /dev/null @@ -1,3 +0,0 @@ -srun --partition=mt --nodes=1 --gres=gpu:mt:2 --ntasks=1 --cpus-per-task=16 --mem=256G --time=00:01:00 \ - --output=output_%j.log \ - python test/infinicore/ops/bilinear.py --moore --verbose --bench --debug diff --git a/test/infiniop/bilinear.py b/test/infiniop/bilinear.py deleted file mode 100644 index 31dac2063..000000000 --- a/test/infiniop/bilinear.py +++ /dev/null @@ -1,183 +0,0 @@ -import torch -import ctypes -from ctypes import c_uint64 - -from libinfiniop import ( - LIBINFINIOP, - TestTensor, - get_test_devices, - check_error, - test_operator, - get_args, - debug, - get_tolerance, - profile_operation, - TestWorkspace, - InfiniDtype, - InfiniDtypeNames, - InfiniDeviceNames, - infiniopOperatorDescriptor_t, -) - - -_TEST_CASES = [ - # batch, in1, in2, out, use_bias - (4, 3, 5, 2, True), - (1, 6, 7, 3, True), - (8, 2, 4, 5, False), - (2, 3, 3, 4, True), - (6, 10, 12, 7, False), - (3, 1, 1, 2, True), -] - -_TENSOR_DTYPES = [InfiniDtype.F16, InfiniDtype.BF16, InfiniDtype.F32] - -_TOLERANCE_MAP = { - InfiniDtype.F16: {"atol": 1e-3, "rtol": 1e-2}, - InfiniDtype.BF16: {"atol": 1e-2, "rtol": 5e-2}, - InfiniDtype.F32: {"atol": 1e-5, "rtol": 1e-4}, -} - -DEBUG = False -PROFILE = False -NUM_PRERUN = 10 -NUM_ITERATIONS = 1000 - - -def reference_bilinear(x1, x2, weight, bias): - """Compute bilinear output in FP32 for stability and cast back.""" - out = torch.einsum( - "ni,oij,nj->no", - x1.to(torch.float32), - weight.to(torch.float32), - x2.to(torch.float32), - ) - if bias is not None: - out = out + bias.to(torch.float32) - return out.to(x1.dtype) - - -def test( - handle, - device, - batch, - in1_features, - in2_features, - out_features, - use_bias, - dtype=InfiniDtype.F16, - sync=None, -): - print( - f"Testing Bilinear on {InfiniDeviceNames[device]} with N:{batch} in1:{in1_features} in2:{in2_features} " - f"out:{out_features} bias:{use_bias} dtype:{InfiniDtypeNames[dtype]}" - ) - - out_tensor = TestTensor((batch, out_features), None, dtype, device, mode="zeros") - x1 = TestTensor((batch, in1_features), None, dtype, device, scale=0.1, bias=-0.05) - x2 = TestTensor((batch, in2_features), None, dtype, device, scale=0.1, bias=-0.05) - weight = TestTensor( - (out_features, in1_features, in2_features), - None, - dtype, - device, - scale=0.1, - bias=-0.05, - ) - bias_tensor = ( - TestTensor((out_features,), None, dtype, device, scale=0.1, bias=-0.05) - if use_bias - else None - ) - - ref = reference_bilinear( - x1.torch_tensor(), - x2.torch_tensor(), - weight.torch_tensor(), - bias_tensor.torch_tensor() if bias_tensor else None, - ) - - if sync is not None: - sync() - - descriptor = infiniopOperatorDescriptor_t() - check_error( - LIBINFINIOP.infiniopCreateBilinearDescriptor( - handle, - ctypes.byref(descriptor), - out_tensor.descriptor, - x1.descriptor, - x2.descriptor, - weight.descriptor, - bias_tensor.descriptor if bias_tensor else None, - ) - ) - - tensors = [out_tensor, x1, x2, weight] - if bias_tensor: - tensors.append(bias_tensor) - for tensor in tensors: - tensor.destroy_desc() - - workspace_size = c_uint64(0) - check_error( - LIBINFINIOP.infiniopGetBilinearWorkspaceSize( - descriptor, ctypes.byref(workspace_size) - ) - ) - workspace = TestWorkspace(workspace_size.value, device) - - def lib_bilinear(): - 
check_error( - LIBINFINIOP.infiniopBilinear( - descriptor, - workspace.data(), - workspace_size.value, - out_tensor.data(), - x1.data(), - x2.data(), - weight.data(), - bias_tensor.data() if bias_tensor else None, - None, - ) - ) - - lib_bilinear() - - atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype) - if DEBUG: - debug(out_tensor.actual_tensor(), ref, atol=atol, rtol=rtol) - assert torch.allclose(out_tensor.actual_tensor(), ref, atol=atol, rtol=rtol) - - if PROFILE: - profile_operation( - "PyTorch", - lambda: reference_bilinear( - x1.torch_tensor(), - x2.torch_tensor(), - weight.torch_tensor(), - bias_tensor.torch_tensor() if bias_tensor else None, - ), - device, - NUM_PRERUN, - NUM_ITERATIONS, - ) - profile_operation( - " lib", lambda: lib_bilinear(), device, NUM_PRERUN, NUM_ITERATIONS - ) - - check_error(LIBINFINIOP.infiniopDestroyBilinearDescriptor(descriptor)) - - -if __name__ == "__main__": - args = get_args() - - DEBUG = args.debug - PROFILE = args.profile - NUM_PRERUN = args.num_prerun - NUM_ITERATIONS = args.num_iterations - - for device in get_test_devices(args): - test_operator(device, test, _TEST_CASES, _TENSOR_DTYPES) - - print("\033[92mTest passed!\033[0m")
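
Note on the adaptive_max_pool1d kernel touched above: each output position o reads the input window [floor(o * L_in / L_out), ceil((o + 1) * L_in / L_out)), which is exactly what adaptiveMaxPool1dBlock computes with floorf/ceilf before taking the max. A minimal Python sketch of that reduction, for reference only (the function name and the integer-arithmetic ceil are illustrative, not library code):

def adaptive_max_pool1d(x, output_length):
    """Windowing used by adaptiveMaxPool1dBlock: output position o reduces
    the inputs in [floor(o * L / L_out), ceil((o + 1) * L / L_out))."""
    input_length = len(x)
    out = []
    for o in range(output_length):
        start = (o * input_length) // output_length            # floor
        end = -((-(o + 1) * input_length) // output_length)    # ceil
        out.append(max(x[start:end]))
    return out

# adaptive_max_pool1d([1, 5, 2, 7, 3, 0], 3) == [5, 7, 3]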
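
Note on the gemm_cpu.cc hunk: the new beta == 0 branch skips reading *c_ entirely, following the usual BLAS convention that C is not read when beta is zero. This matters because 0 * NaN is still NaN in IEEE arithmetic, so an uninitialized output buffer could otherwise poison the result even though it is "multiplied by zero". A sketch of the per-element update (names here are illustrative):

import math

def scale_accumulate(c_old, s, alpha, beta):
    """Per-element update mirroring gemm_cpu.cc: when beta == 0 the old value
    of C is ignored instead of multiplied, so NaNs in uninitialized output
    memory cannot propagate into the result."""
    if beta == 0:
        return alpha * s
    return beta * c_old + alpha * s

assert scale_accumulate(math.nan, 2.0, 1.0, 0.0) == 2.0  # the old code path would yield nan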
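
Note on reference_bilinear in the bilinear tests: the einsum "ni,oij,nj->no" is the standard bilinear form out[n, o] = x1[n, :] @ weight[o] @ x2[n, :] + bias[o], and should agree with torch.nn.functional.bilinear up to dtype rounding. An equivalent loop expansion, shown only to spell out what the einsum computes (not part of the test suite):

import torch

def bilinear_loops(x1, x2, weight, bias=None):
    """Loop expansion of the einsum in reference_bilinear:
    out[n, o] = x1[n, :] @ weight[o] @ x2[n, :] (+ bias[o]), accumulated in FP32."""
    n, out_features = x1.shape[0], weight.shape[0]
    out = torch.empty(n, out_features, dtype=torch.float32)
    for i in range(n):
        for o in range(out_features):
            out[i, o] = x1[i].float() @ weight[o].float() @ x2[i].float()
    if bias is not None:
        out = out + bias.float()
    return out.to(x1.dtype)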