From ca382b6a03a12f809cdbe2a36cbbbe0ec6ca7221 Mon Sep 17 00:00:00 2001 From: MaYuhang <2902139028@qq.com> Date: Wed, 17 Dec 2025 22:07:39 +0800 Subject: [PATCH] =?UTF-8?q?issue/804:=20=E5=AE=8C=E5=96=84OpenCL=E5=90=8E?= =?UTF-8?q?=E7=AB=AF=E6=94=AF=E6=8C=81=EF=BC=9A=E6=B7=BB=E5=8A=A0device=20?= =?UTF-8?q?handle=E3=80=81program=E7=BC=93=E5=AD=98=E3=80=81opencl?= =?UTF-8?q?=E7=AE=97=E5=AD=90?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/infiniop/devices/handle.cc | 9 + src/infiniop/devices/opencl/opencl_common.h | 39 + src/infiniop/devices/opencl/opencl_handle.cc | 59 ++ src/infiniop/devices/opencl/opencl_handle.h | 28 + .../devices/opencl/opencl_kernel_common.h | 225 ++++++ .../devices/opencl/opencl_program_cache.cc | 85 +++ .../devices/opencl/opencl_program_cache.h | 47 ++ .../opencl/causal_softmax_opencl.cc | 282 +++++++ .../opencl/causal_softmax_opencl.h | 8 + src/infiniop/ops/causal_softmax/operator.cc | 15 + src/infiniop/ops/gemm/opencl/gemm_opencl.cc | 244 ++++++ src/infiniop/ops/gemm/opencl/gemm_opencl.h | 8 + src/infiniop/ops/gemm/operator.cc | 15 + .../opencl/random_sample_opencl.cc | 314 ++++++++ .../opencl/random_sample_opencl.h | 8 + src/infiniop/ops/random_sample/operator.cc | 15 + .../ops/rearrange/opencl/rearrange_opencl.cc | 720 ++++++++++++++++++ .../ops/rearrange/opencl/rearrange_opencl.h | 8 + src/infiniop/ops/rearrange/operator.cc | 12 + .../ops/rms_norm/opencl/rms_norm_opencl.cc | 234 ++++++ .../ops/rms_norm/opencl/rms_norm_opencl.h | 8 + src/infiniop/ops/rms_norm/operator.cc | 15 + src/infiniop/ops/rope/opencl/rope_opencl.cc | 284 +++++++ src/infiniop/ops/rope/opencl/rope_opencl.h | 8 + src/infiniop/ops/rope/operator.cc | 15 + .../ops/swiglu/opencl/swiglu_opencl.cc | 219 ++++++ .../ops/swiglu/opencl/swiglu_opencl.h | 89 +++ src/infiniop/ops/swiglu/operator.cc | 15 + test/infiniop/libinfiniop/devices.py | 3 + test/infiniop/libinfiniop/utils.py | 14 + xmake.lua | 3 + xmake/opencl.lua | 20 + 32 
files changed, 3068 insertions(+) create mode 100644 src/infiniop/devices/opencl/opencl_common.h create mode 100644 src/infiniop/devices/opencl/opencl_handle.cc create mode 100644 src/infiniop/devices/opencl/opencl_handle.h create mode 100644 src/infiniop/devices/opencl/opencl_kernel_common.h create mode 100644 src/infiniop/devices/opencl/opencl_program_cache.cc create mode 100644 src/infiniop/devices/opencl/opencl_program_cache.h create mode 100644 src/infiniop/ops/causal_softmax/opencl/causal_softmax_opencl.cc create mode 100644 src/infiniop/ops/causal_softmax/opencl/causal_softmax_opencl.h create mode 100644 src/infiniop/ops/gemm/opencl/gemm_opencl.cc create mode 100644 src/infiniop/ops/gemm/opencl/gemm_opencl.h create mode 100644 src/infiniop/ops/random_sample/opencl/random_sample_opencl.cc create mode 100644 src/infiniop/ops/random_sample/opencl/random_sample_opencl.h create mode 100644 src/infiniop/ops/rearrange/opencl/rearrange_opencl.cc create mode 100644 src/infiniop/ops/rearrange/opencl/rearrange_opencl.h create mode 100644 src/infiniop/ops/rms_norm/opencl/rms_norm_opencl.cc create mode 100644 src/infiniop/ops/rms_norm/opencl/rms_norm_opencl.h create mode 100644 src/infiniop/ops/rope/opencl/rope_opencl.cc create mode 100644 src/infiniop/ops/rope/opencl/rope_opencl.h create mode 100644 src/infiniop/ops/swiglu/opencl/swiglu_opencl.cc create mode 100644 src/infiniop/ops/swiglu/opencl/swiglu_opencl.h diff --git a/src/infiniop/devices/handle.cc b/src/infiniop/devices/handle.cc index 6b036e553..c5198b9cb 100644 --- a/src/infiniop/devices/handle.cc +++ b/src/infiniop/devices/handle.cc @@ -23,6 +23,9 @@ #ifdef ENABLE_METAX_API #include "metax/metax_handle.h" #endif +#ifdef ENABLE_OPENCL_API +#include "opencl/opencl_handle.h" +#endif __C infiniStatus_t infiniopCreateHandle(infiniopHandle_t *handle_ptr) { if (handle_ptr == nullptr) { @@ -68,6 +71,9 @@ __C infiniStatus_t infiniopCreateHandle(infiniopHandle_t *handle_ptr) { #ifdef ENABLE_HYGON_API 
CREATE(INFINI_DEVICE_HYGON, hygon); #endif +#ifdef ENABLE_OPENCL_API + CREATE(INFINI_DEVICE_OPENCL, opencl); +#endif default: return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; @@ -113,6 +119,9 @@ __C infiniStatus_t infiniopDestroyHandle(infiniopHandle_t handle) { #endif #ifdef ENABLE_HYGON_API DELETE(INFINI_DEVICE_HYGON, hygon); +#endif +#ifdef ENABLE_OPENCL_API + DELETE(INFINI_DEVICE_OPENCL, opencl); #endif default: return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; diff --git a/src/infiniop/devices/opencl/opencl_common.h b/src/infiniop/devices/opencl/opencl_common.h new file mode 100644 index 000000000..7e3c23ea9 --- /dev/null +++ b/src/infiniop/devices/opencl/opencl_common.h @@ -0,0 +1,39 @@ +#ifndef __INFINIOP_OPENCL_COMMON_H__ +#define __INFINIOP_OPENCL_COMMON_H__ + +#include "../../../utils.h" +#include "../pool.h" +#include "opencl_handle.h" +#include "opencl_kernel_common.h" +#include "opencl_program_cache.h" +#include +#include + +namespace device::opencl { + +class Handle::Internal { + + int _warp_size, + _max_threads_per_block, + _block_size[3]; + + template + using Fn = std::function; + +public: + Internal(int); + + int warpSize() const; + int maxThreadsPerBlock() const; + int blockSizeX() const; + int blockSizeY() const; + int blockSizeZ() const; + ProgramCache *programCache() const; + +private: + std::unique_ptr program_cache_; +}; + +} // namespace device::opencl + +#endif // __INFINIOP_OPENCL_COMMON_H__ diff --git a/src/infiniop/devices/opencl/opencl_handle.cc b/src/infiniop/devices/opencl/opencl_handle.cc new file mode 100644 index 000000000..a27f9bfd8 --- /dev/null +++ b/src/infiniop/devices/opencl/opencl_handle.cc @@ -0,0 +1,59 @@ +#include "../../../infinirt/opencl/infinirt_opencl.h" +#include "opencl_common.h" + +namespace device::opencl { +Handle::Handle(infiniDevice_t device, int device_id) + : InfiniopHandle{device, device_id}, + _internal(std::make_shared(device_id)) {} + +Handle::Handle(int device_id) : Handle(INFINI_DEVICE_OPENCL, device_id) 
{} + +auto Handle::internal() const -> const std::shared_ptr & { + return _internal; +} + +Handle::Internal::Internal(int device_id) { + infinirtInit(); + cl_device_id cl_device; + infinirtOpenclDevice_t device; + infinirtGetOpenclDevice(&device); + cl_device = static_cast(device); + +#if defined(INTEL) + _warp_size = 32; +#elif defined(ADRENO) + _warp_size = 128; +#else + _warp_size = 32; +#endif + + size_t device_max_wg = 0; + clGetDeviceInfo(cl_device, CL_DEVICE_MAX_WORK_GROUP_SIZE, sizeof(device_max_wg), &device_max_wg, nullptr); + _max_threads_per_block = static_cast(device_max_wg); + + size_t max_item_sizes[3]; + clGetDeviceInfo(cl_device, CL_DEVICE_MAX_WORK_ITEM_SIZES, sizeof(max_item_sizes), max_item_sizes, nullptr); + _block_size[0] = max_item_sizes[0]; + _block_size[1] = max_item_sizes[1]; + _block_size[2] = max_item_sizes[2]; + program_cache_ = std::make_unique(); +} + +int Handle::Internal::warpSize() const { return _warp_size; } + +int Handle::Internal::maxThreadsPerBlock() const { return _max_threads_per_block; } + +int Handle::Internal::blockSizeX() const { return _block_size[0]; } + +int Handle::Internal::blockSizeY() const { return _block_size[1]; } + +int Handle::Internal::blockSizeZ() const { return _block_size[2]; } + +ProgramCache *Handle::Internal::programCache() const { return program_cache_.get(); } + +infiniStatus_t Handle::create(InfiniopHandle **handle_ptr, int device_id) { + *handle_ptr = new Handle(INFINI_DEVICE_OPENCL, device_id); + return INFINI_STATUS_SUCCESS; +} + +} // namespace device::opencl diff --git a/src/infiniop/devices/opencl/opencl_handle.h b/src/infiniop/devices/opencl/opencl_handle.h new file mode 100644 index 000000000..ef2a02c58 --- /dev/null +++ b/src/infiniop/devices/opencl/opencl_handle.h @@ -0,0 +1,28 @@ +#ifndef __INFINIOP_OPENCL_HANDLE_H__ +#define __INFINIOP_OPENCL_HANDLE_H__ + +#include "../../handle.h" +#include + +namespace device { +namespace opencl { + +struct Handle : public InfiniopHandle { + Handle(int 
device_id); + class Internal; + auto internal() const -> const std::shared_ptr &; + +protected: + Handle(infiniDevice_t device, int device_id); + +public: + static infiniStatus_t create(InfiniopHandle **handle_ptr, int device_id); + +private: + std::shared_ptr _internal; +}; + +} // namespace opencl +} // namespace device + +#endif // __INFINIOP_OPENCL_HANDLE_H__ diff --git a/src/infiniop/devices/opencl/opencl_kernel_common.h b/src/infiniop/devices/opencl/opencl_kernel_common.h new file mode 100644 index 000000000..d0e81b92e --- /dev/null +++ b/src/infiniop/devices/opencl/opencl_kernel_common.h @@ -0,0 +1,225 @@ +#ifndef __INFINIOP_OPENCL_KERNEL_COMMON_H__ +#define __INFINIOP_OPENCL_KERNEL_COMMON_H__ + +#include "infinicore.h" +#include +#include + +#ifndef CL_TARGET_OPENCL_VERSION +#define CL_TARGET_OPENCL_VERSION 300 +#endif +#include + +namespace device::opencl::kernel { + +inline size_t dtypeSize(infiniDtype_t dtype) { + switch (dtype) { + case INFINI_DTYPE_BYTE: + return 1; + case INFINI_DTYPE_BOOL: + return 1; + case INFINI_DTYPE_I8: + return 1; + case INFINI_DTYPE_U8: + return 1; + case INFINI_DTYPE_I16: + return 2; + case INFINI_DTYPE_U16: + return 2; + case INFINI_DTYPE_F16: + return 2; + case INFINI_DTYPE_I32: + return 4; + case INFINI_DTYPE_U32: + return 4; + case INFINI_DTYPE_F32: + return 4; + case INFINI_DTYPE_I64: + return 8; + case INFINI_DTYPE_U64: + return 8; + case INFINI_DTYPE_F64: + return 8; + default: + return 0; + } +} + +inline size_t indexToOffset( + size_t flat_index, + size_t ndim, + const size_t *shape, + const ptrdiff_t *strides) { + size_t res = 0; + for (size_t i = ndim; i-- > 0;) { + res += (flat_index % shape[i]) * strides[i]; + flat_index /= shape[i]; + } + return res; +} + +inline const char *clErrorString(cl_int err) { + switch (err) { + case CL_SUCCESS: + return "CL_SUCCESS"; + case CL_DEVICE_NOT_FOUND: + return "CL_DEVICE_NOT_FOUND"; + case CL_DEVICE_NOT_AVAILABLE: + return "CL_DEVICE_NOT_AVAILABLE"; + case 
CL_COMPILER_NOT_AVAILABLE: + return "CL_COMPILER_NOT_AVAILABLE"; + case CL_MEM_OBJECT_ALLOCATION_FAILURE: + return "CL_MEM_OBJECT_ALLOCATION_FAILURE"; + case CL_OUT_OF_RESOURCES: + return "CL_OUT_OF_RESOURCES"; + case CL_OUT_OF_HOST_MEMORY: + return "CL_OUT_OF_HOST_MEMORY"; + case CL_PROFILING_INFO_NOT_AVAILABLE: + return "CL_PROFILING_INFO_NOT_AVAILABLE"; + case CL_MEM_COPY_OVERLAP: + return "CL_MEM_COPY_OVERLAP"; + case CL_IMAGE_FORMAT_MISMATCH: + return "CL_IMAGE_FORMAT_MISMATCH"; + case CL_IMAGE_FORMAT_NOT_SUPPORTED: + return "CL_IMAGE_FORMAT_NOT_SUPPORTED"; + case CL_BUILD_PROGRAM_FAILURE: + return "CL_BUILD_PROGRAM_FAILURE"; + case CL_MAP_FAILURE: + return "CL_MAP_FAILURE"; + case CL_INVALID_VALUE: + return "CL_INVALID_VALUE"; + case CL_INVALID_DEVICE_TYPE: + return "CL_INVALID_DEVICE_TYPE"; + case CL_INVALID_PLATFORM: + return "CL_INVALID_PLATFORM"; + case CL_INVALID_DEVICE: + return "CL_INVALID_DEVICE"; + case CL_INVALID_CONTEXT: + return "CL_INVALID_CONTEXT"; + case CL_INVALID_QUEUE_PROPERTIES: + return "CL_INVALID_QUEUE_PROPERTIES"; + case CL_INVALID_COMMAND_QUEUE: + return "CL_INVALID_COMMAND_QUEUE"; + case CL_INVALID_HOST_PTR: + return "CL_INVALID_HOST_PTR"; + case CL_INVALID_MEM_OBJECT: + return "CL_INVALID_MEM_OBJECT"; + case CL_INVALID_IMAGE_FORMAT_DESCRIPTOR: + return "CL_INVALID_IMAGE_FORMAT_DESCRIPTOR"; + case CL_INVALID_IMAGE_SIZE: + return "CL_INVALID_IMAGE_SIZE"; + case CL_INVALID_SAMPLER: + return "CL_INVALID_SAMPLER"; + case CL_INVALID_BINARY: + return "CL_INVALID_BINARY"; + case CL_INVALID_BUILD_OPTIONS: + return "CL_INVALID_BUILD_OPTIONS"; + case CL_INVALID_PROGRAM: + return "CL_INVALID_PROGRAM"; + case CL_INVALID_PROGRAM_EXECUTABLE: + return "CL_INVALID_PROGRAM_EXECUTABLE"; + case CL_INVALID_KERNEL_NAME: + return "CL_INVALID_KERNEL_NAME"; + case CL_INVALID_KERNEL_DEFINITION: + return "CL_INVALID_KERNEL_DEFINITION"; + case CL_INVALID_KERNEL: + return "CL_INVALID_KERNEL"; + case CL_INVALID_ARG_INDEX: + return "CL_INVALID_ARG_INDEX"; + case 
CL_INVALID_ARG_VALUE: + return "CL_INVALID_ARG_VALUE"; + case CL_INVALID_ARG_SIZE: + return "CL_INVALID_ARG_SIZE"; + case CL_INVALID_KERNEL_ARGS: + return "CL_INVALID_KERNEL_ARGS"; + case CL_INVALID_WORK_DIMENSION: + return "CL_INVALID_WORK_DIMENSION"; + case CL_INVALID_WORK_GROUP_SIZE: + return "CL_INVALID_WORK_GROUP_SIZE"; + case CL_INVALID_WORK_ITEM_SIZE: + return "CL_INVALID_WORK_ITEM_SIZE"; + case CL_INVALID_GLOBAL_OFFSET: + return "CL_INVALID_GLOBAL_OFFSET"; + case CL_INVALID_EVENT_WAIT_LIST: + return "CL_INVALID_EVENT_WAIT_LIST"; + case CL_INVALID_EVENT: + return "CL_INVALID_EVENT"; + case CL_INVALID_OPERATION: + return "CL_INVALID_OPERATION"; + case CL_INVALID_GL_OBJECT: + return "CL_INVALID_GL_OBJECT"; + case CL_INVALID_BUFFER_SIZE: + return "CL_INVALID_BUFFER_SIZE"; + case CL_INVALID_MIP_LEVEL: + return "CL_INVALID_MIP_LEVEL"; + case CL_INVALID_GLOBAL_WORK_SIZE: + return "CL_INVALID_GLOBAL_WORK_SIZE"; + default: + return "UNKNOWN_CL_ERROR"; + } +} + +inline bool dtypeToClType(infiniDtype_t dt, std::string &out) noexcept { + switch (dt) { + case INFINI_DTYPE_INVALID: + return false; + case INFINI_DTYPE_BYTE: + return false; + case INFINI_DTYPE_BOOL: + out = "bool"; + return true; + case INFINI_DTYPE_I8: + out = "char"; + return true; + case INFINI_DTYPE_I16: + out = "short"; + return true; + case INFINI_DTYPE_I32: + out = "int"; + return true; + case INFINI_DTYPE_I64: + out = "long"; + return true; + case INFINI_DTYPE_U8: + out = "uchar"; + return true; + case INFINI_DTYPE_U16: + out = "ushort"; + return true; + case INFINI_DTYPE_U32: + out = "uint"; + return true; + case INFINI_DTYPE_U64: + out = "ulong"; + return true; + case INFINI_DTYPE_F8: + return false; + case INFINI_DTYPE_F16: + // half 需要 cl_khr_fp16 支持 + out = "half"; + return true; + case INFINI_DTYPE_F32: + out = "float"; + return true; + case INFINI_DTYPE_F64: + // double 需要 cl_khr_fp64 支持 + out = "double"; + return true; + case INFINI_DTYPE_C16: + return false; + case INFINI_DTYPE_C32: + 
return false; + case INFINI_DTYPE_C64: + return false; + case INFINI_DTYPE_C128: + return false; + case INFINI_DTYPE_BF16: + return false; + default: + return false; + } +} + +} // namespace device::opencl::kernel + +#endif diff --git a/src/infiniop/devices/opencl/opencl_program_cache.cc b/src/infiniop/devices/opencl/opencl_program_cache.cc new file mode 100644 index 000000000..9d488d4cb --- /dev/null +++ b/src/infiniop/devices/opencl/opencl_program_cache.cc @@ -0,0 +1,85 @@ +#include "opencl_program_cache.h" +#include +#include + +namespace device::opencl { + +std::shared_ptr ProgramCache::getOrBuildWithSource( + const std::string &op_name, + const std::string &source, + const std::string &build_opts, + cl_context context, + cl_device_id device) { + + std::ostringstream oss; + oss << build_opts << "#dev:" << (uintptr_t)device << "#ctx:" << (uintptr_t)context + << "#op_name:" << op_name; + std::string key = oss.str(); + + std::unique_lock lk(mtx_); + auto &entry_ptr = map_[key]; + if (!entry_ptr) { + entry_ptr.reset(new Entry()); + } + Entry *entry = entry_ptr.get(); + + if (entry->program) { + return entry->program; + } + if (entry->failed) { + return nullptr; + } + if (entry->building) { + entry->cv.wait(lk, [&] { return !entry->building; }); + return entry->program; + } + + entry->building = true; + lk.unlock(); + + cl_program raw_prog = nullptr; + cl_int clerr = CL_SUCCESS; + const char *src_ptr = source.c_str(); + size_t src_len = source.size(); + + raw_prog = clCreateProgramWithSource(context, 1, &src_ptr, &src_len, &clerr); + if (raw_prog == nullptr || clerr != CL_SUCCESS) { + raw_prog = nullptr; + } else { + clerr = clBuildProgram(raw_prog, 1, &device, build_opts.c_str(), nullptr, nullptr); + if (clerr != CL_SUCCESS) { + // print build log + size_t log_size = 0; + clGetProgramBuildInfo(raw_prog, device, CL_PROGRAM_BUILD_LOG, 0, nullptr, &log_size); + if (log_size > 0) { + std::vector log(log_size + 1); + clGetProgramBuildInfo(raw_prog, device, 
CL_PROGRAM_BUILD_LOG, log_size, log.data(), nullptr); + log[log_size] = '\0'; + std::cerr << "[OpenCL] build log (" << op_name << "): " << log.data() << std::endl; + } else { + std::cerr << "[OpenCL] build failed for op " << op_name << ", clBuildProgram returned " << clerr << std::endl; + } + clReleaseProgram(raw_prog); + raw_prog = nullptr; + } else { + } + } + + lk.lock(); + if (raw_prog == nullptr) { + entry->building = false; + entry->failed = true; + entry->cv.notify_all(); + return nullptr; + } else { + auto deleter = [](void *p) { if (p){ clReleaseProgram(reinterpret_cast(p)); +} }; + std::shared_ptr prog_shared(reinterpret_cast(raw_prog), deleter); + entry->program = prog_shared; + entry->building = false; + entry->cv.notify_all(); + return entry->program; + } +} + +} // namespace device::opencl diff --git a/src/infiniop/devices/opencl/opencl_program_cache.h b/src/infiniop/devices/opencl/opencl_program_cache.h new file mode 100644 index 000000000..b1693db14 --- /dev/null +++ b/src/infiniop/devices/opencl/opencl_program_cache.h @@ -0,0 +1,47 @@ +#ifndef __INFINIOP_OPENCL_PROGRAM_CACHE_H__ +#define __INFINIOP_OPENCL_PROGRAM_CACHE_H__ + +#include +#include +#include +#include +#include +#include +#include +#include +#ifndef CL_TARGET_OPENCL_VERSION +#define CL_TARGET_OPENCL_VERSION 300 +#endif +#include + +namespace device::opencl { + +class ProgramCache { +public: + ProgramCache() = default; + ~ProgramCache() = default; + + std::shared_ptr getOrBuildWithSource( + const std::string &op_name, + const std::string &source, + const std::string &build_opts, + cl_context context, + cl_device_id device); + + ProgramCache(const ProgramCache &) = delete; + ProgramCache &operator=(const ProgramCache &) = delete; + +private: + struct Entry { + std::shared_ptr program; + bool building = false; + bool failed = false; + std::condition_variable cv; + }; + mutable std::mutex mtx_; + std::unordered_map> map_; +}; + +} // namespace device::opencl + +#endif // 
__INFINIOP_OPENCL_PROGRAM_CACHE_H__ diff --git a/src/infiniop/ops/causal_softmax/opencl/causal_softmax_opencl.cc b/src/infiniop/ops/causal_softmax/opencl/causal_softmax_opencl.cc new file mode 100644 index 000000000..94e28c0fa --- /dev/null +++ b/src/infiniop/ops/causal_softmax/opencl/causal_softmax_opencl.cc @@ -0,0 +1,282 @@ +#include "causal_softmax_opencl.h" +#include "../../../../infinirt/opencl/infinirt_opencl.h" +#include "../../../devices/opencl/opencl_common.h" +#include +#include +#include + +const size_t ITEMS_THREAD = 8; + +static const char *CausalSoftmaxKernelSource = R"CLC( +#define CL_TARGET_OPENCL_VERSION 200 +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#ifndef Tval +#define Tval float +#endif + +#ifndef ITEMS_THREAD +#define ITEMS_THREAD 8 +#endif + +#ifndef MASK +#define MASK causal_mask +#endif + +typedef unsigned int Tidx; + +bool causal_mask(Tidx tok_id, Tidx seq_len, + Tidx pos_id, Tidx att_len) { + // tok_id ↓ |<---att_len--->| + // 0 | * * ... * | + // 1 | * * ... * * | + // 2 | * * ... * * * | + // seq_len: 3 |---------------| + return att_len + tok_id >= pos_id + seq_len; +} + +kernel void softmax_register( + global Tval *att_y, + global Tval *att_x, + Tidx const seq_len, + Tidx const att_len, + int const head_stride_y, + int const tok_stride_y, + int const head_stride_x, + int const tok_stride_x) { + + Tidx const + head_idx = get_group_id(1), + tok_id = get_group_id(0), + l_idx = get_local_id(0), + l_len = get_local_size(0); + + global Tval *y = att_y + head_idx * head_stride_y + tok_id * tok_stride_y; + global Tval *x = att_x + head_idx * head_stride_x + tok_id * tok_stride_x; + + float + data[ITEMS_THREAD], + max_ = -FLT_MAX, + sum_ = 0; + + for (Tidx i = 0, idx = l_idx; idx < att_len; ++i, idx += l_len) { + data[i] = causal_mask(tok_id, seq_len, idx, att_len) ? 
x[idx] : -FLT_MAX; + max_ = fmax(max_, data[i]); + } + + max_ = work_group_reduce_max(max_); + + for (Tidx i = 0, idx = l_idx; idx < att_len; ++i, idx += l_len) { + data[i] = exp(data[i] - max_); + sum_ += data[i]; + } + + barrier(CLK_LOCAL_MEM_FENCE); + float const k = 1 / work_group_reduce_add(sum_); + + for (Tidx i = 0, idx = l_idx; idx < att_len; ++i, idx += l_len) + y[idx] = data[i] * k; +} + +kernel void softmax_global( + global Tval *att_y, + global Tval *att_x, + Tidx const seq_len, + Tidx const att_len, + int const head_stride_y, + int const tok_stride_y, + int const head_stride_x, + int const tok_stride_x) { + + Tidx const + head_idx = get_group_id(1), + tok_id = get_group_id(0), + l_idx = get_local_id(0), + l_len = get_local_size(0); + + global Tval *y = att_y + head_idx * head_stride_y + tok_id * tok_stride_y; + global Tval *x = att_x + head_idx * head_stride_x + tok_id * tok_stride_x; + + float + max_ = -FLT_MAX, + sum_ = 0; + + for (Tidx i = 0, idx = l_idx; idx < att_len; ++i, idx += l_len) { + float const data = causal_mask(tok_id, seq_len, idx, att_len) ? 
x[idx] : -FLT_MAX; + max_ = fmax(max_, data); + } + + max_ = work_group_reduce_max(max_); + + for (Tidx i = 0, idx = l_idx; idx < att_len; ++i, idx += l_len) { + float const data = exp(x[idx] - max_); + y[idx] = data; + sum_ += data; + } + + barrier(CLK_LOCAL_MEM_FENCE); + float const k = 1 / work_group_reduce_add(sum_); + + for (Tidx i = 0, idx = l_idx; idx < att_len; ++i, idx += l_len) + y[idx] *= k; +} + +)CLC"; + +inline int last_power_of_two(int n) { + int p = 1; + while (p * 2 <= n) { + p *= 2; + } + return p; +} + +namespace op::causal_softmax::opencl { + +using namespace device::opencl::kernel; + +struct Descriptor::Opaque { + std::shared_ptr internal; +}; + +Descriptor::~Descriptor() { + delete _opaque; +} + +infiniStatus_t Descriptor::create( + infiniopHandle_t handle, + Descriptor **desc_ptr, + infiniopTensorDescriptor_t y_desc, + infiniopTensorDescriptor_t x_desc) { + auto info = CausalSoftmaxInfo::create(y_desc, x_desc); + CHECK_RESULT(info); + auto opaque = new Descriptor::Opaque{ + reinterpret_cast(handle)->internal()}; + *desc_ptr = new Descriptor( + opaque, + info.take(), 0, handle->device, handle->device_id); + return INFINI_STATUS_SUCCESS; +} + +infiniStatus_t launchKernel(void *y, const void *x, infiniDtype_t dtype, + size_t batch_size, size_t seq_len, size_t total_seq_len, + ptrdiff_t y_stride_b, ptrdiff_t y_stride_i, + ptrdiff_t x_stride_b, ptrdiff_t x_stride_i, + size_t block_size, cl_context context, + cl_device_id device, cl_command_queue cl_queue, + cl_program program) { + cl_int clerr; + + int group_size = last_power_of_two(std::min(total_seq_len, block_size)); + int items_thread = (total_seq_len + group_size - 1) / group_size; + cl_kernel kernel; + if (items_thread <= ITEMS_THREAD) { + kernel = clCreateKernel(program, "softmax_register", &clerr); + if (clerr != CL_SUCCESS || kernel == nullptr) { + // NOTE: do NOT clReleaseProgram here — the program is owned by ProgramCache's shared_ptr. + return INFINI_STATUS_INTERNAL_ERROR; + } + } else { +
kernel = clCreateKernel(program, "softmax_global", &clerr); + if (clerr != CL_SUCCESS || kernel == nullptr) { + // NOTE: do NOT clReleaseProgram here — the program is owned by ProgramCache's shared_ptr. + return INFINI_STATUS_INTERNAL_ERROR; + } + } + + int arg_idx = 0; + void *y_svm = NULL; + clerr = clSetKernelArgSVMPointer(kernel, arg_idx++, y); + if (clerr != CL_SUCCESS) { // for python test + infinirtMalloc(&y_svm, ((batch_size - 1) * y_stride_b + (seq_len - 1) * y_stride_i + total_seq_len) * dtypeSize(dtype)); + infinirtMemcpy(y_svm, y, ((batch_size - 1) * y_stride_b + (seq_len - 1) * y_stride_i + total_seq_len) * dtypeSize(dtype), INFINIRT_MEMCPY_H2D); + arg_idx -= 1; + clerr = clSetKernelArgSVMPointer(kernel, arg_idx++, y_svm); + } + clerr |= clSetKernelArgSVMPointer(kernel, arg_idx++, x); + if (clerr != CL_SUCCESS) { // for python test + void *x_svm = NULL; + infinirtMalloc(&x_svm, ((batch_size - 1) * x_stride_b + (seq_len - 1) * x_stride_i + total_seq_len) * dtypeSize(dtype)); + infinirtMemcpy(x_svm, x, ((batch_size - 1) * x_stride_b + (seq_len - 1) * x_stride_i + total_seq_len) * dtypeSize(dtype), INFINIRT_MEMCPY_H2D); + arg_idx -= 1; + clerr = clSetKernelArgSVMPointer(kernel, arg_idx++, x_svm); + } + + cl_int s_len = static_cast(seq_len); + clerr |= clSetKernelArg(kernel, arg_idx++, sizeof(cl_int), &s_len); + cl_int att_len = static_cast(total_seq_len); + clerr |= clSetKernelArg(kernel, arg_idx++, sizeof(cl_int), &att_len); + cl_int y_s_b = static_cast(y_stride_b); + clerr |= clSetKernelArg(kernel, arg_idx++, sizeof(cl_int), &y_s_b); + cl_int y_s_i = static_cast(y_stride_i); + clerr |= clSetKernelArg(kernel, arg_idx++, sizeof(cl_int), &y_s_i); + cl_int x_s_b = static_cast(x_stride_b); + clerr |= clSetKernelArg(kernel, arg_idx++, sizeof(cl_int), &x_s_b); + cl_int x_s_i = static_cast(x_stride_i); + clerr |= clSetKernelArg(kernel, arg_idx++, sizeof(cl_int), &x_s_i); + + size_t global_size[2] = {group_size * seq_len, batch_size}; + size_t local_size[2] = {size_t(group_size), size_t(1)}; + + clerr = 
clEnqueueNDRangeKernel(cl_queue, kernel, 2, nullptr, global_size, local_size, 0, nullptr, nullptr); + + if (clerr != CL_SUCCESS) { + fprintf(stderr, "[OpenCL] clEnqueueNDRangeKernel failed: %s (%d)\n", clErrorString(clerr), clerr); + fprintf(stderr, " global_size: %zu, local_size: %zu\n", global_size[0], local_size[0]); + clReleaseKernel(kernel); + return INFINI_STATUS_INTERNAL_ERROR; + } + if (y_svm) { // for python test + infinirtMemcpy(y, y_svm, ((batch_size - 1) * y_stride_b + (seq_len - 1) * y_stride_i + total_seq_len) * dtypeSize(dtype), INFINIRT_MEMCPY_D2H); + } + + clReleaseKernel(kernel); + + return INFINI_STATUS_SUCCESS; +} + +infiniStatus_t Descriptor::calculate(void *workspace, size_t workspace_size, + void *y, + const void *x, + void *stream_) const { + size_t block_size = _opaque->internal->maxThreadsPerBlock(); + void *device; + void *context; + CHECK_STATUS(infinirtGetOpenclDevice(&device)); + CHECK_STATUS(infinirtGetOpenclContext(&context)); + cl_context clcontext = static_cast(context); + cl_device_id cldevice = static_cast(device); + if (!stream_) { + CHECK_STATUS(infinirtGetOpenclStream(&stream_)); + } + cl_command_queue clqueue = static_cast(stream_); + auto dtype = _info.dtype; + std::string dt_val; + if (dtype == INFINI_DTYPE_F16) { + dt_val = "half"; + } else if (dtype == INFINI_DTYPE_F32) { + dt_val = "float"; + } else { + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + + // build options + std::string build_opts; + build_opts += "-D Tval=" + dt_val + " "; + build_opts += "-D ITEMS_THREAD=" + std::to_string(ITEMS_THREAD) + " "; + build_opts += "-cl-std=CL2.0 "; + + auto prog_shared = this->_opaque->internal->programCache()->getOrBuildWithSource("causal_softmax", CausalSoftmaxKernelSource, build_opts, clcontext, cldevice); + if (!prog_shared) { + return INFINI_STATUS_INTERNAL_ERROR; + } + cl_program clprogram = reinterpret_cast(prog_shared.get()); + + CHECK_STATUS(launchKernel(y, x, dtype, _info.batch_size, _info.seq_len, _info.total_seq_len, + 
_info.y_stride_b, _info.y_stride_i, _info.x_stride_b, _info.x_stride_i, block_size, clcontext, cldevice, clqueue, clprogram)); + + return INFINI_STATUS_SUCCESS; +} + +} // namespace op::causal_softmax::opencl diff --git a/src/infiniop/ops/causal_softmax/opencl/causal_softmax_opencl.h b/src/infiniop/ops/causal_softmax/opencl/causal_softmax_opencl.h new file mode 100644 index 000000000..d09ab742d --- /dev/null +++ b/src/infiniop/ops/causal_softmax/opencl/causal_softmax_opencl.h @@ -0,0 +1,8 @@ +#ifndef __CAUSAL_SOFTMAX_OPENCL_H__ +#define __CAUSAL_SOFTMAX_OPENCL_H__ + +#include "../causal_softmax.h" + +DESCRIPTOR(opencl) + +#endif \ No newline at end of file diff --git a/src/infiniop/ops/causal_softmax/operator.cc b/src/infiniop/ops/causal_softmax/operator.cc index b1be4c075..67476cfac 100644 --- a/src/infiniop/ops/causal_softmax/operator.cc +++ b/src/infiniop/ops/causal_softmax/operator.cc @@ -23,6 +23,9 @@ #ifdef ENABLE_MOORE_API #include "moore/causal_softmax_moore.h" #endif +#ifdef ENABLE_OPENCL_API +#include "opencl/causal_softmax_opencl.h" +#endif __C infiniStatus_t infiniopCreateCausalSoftmaxDescriptor( infiniopHandle_t handle, @@ -68,6 +71,9 @@ __C infiniStatus_t infiniopCreateCausalSoftmaxDescriptor( #endif #ifdef ENABLE_MOORE_API CREATE(INFINI_DEVICE_MOORE, moore) +#endif +#ifdef ENABLE_OPENCL_API + CREATE(INFINI_DEVICE_OPENCL, opencl) #endif } return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; @@ -110,6 +116,9 @@ __C infiniStatus_t infiniopGetCausalSoftmaxWorkspaceSize(infiniopCausalSoftmaxDe #endif #ifdef ENABLE_MOORE_API GET(INFINI_DEVICE_MOORE, moore) +#endif +#ifdef ENABLE_OPENCL_API + GET(INFINI_DEVICE_OPENCL, opencl) #endif } return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; @@ -157,6 +166,9 @@ __C infiniStatus_t infiniopCausalSoftmax( #endif #ifdef ENABLE_MOORE_API CALCULATE(INFINI_DEVICE_MOORE, moore) +#endif +#ifdef ENABLE_OPENCL_API + CALCULATE(INFINI_DEVICE_OPENCL, opencl) #endif } return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; @@ -199,6 +211,9 @@ 
__C infiniStatus_t infiniopDestroyCausalSoftmaxDescriptor(infiniopCausalSoftmaxD #endif #ifdef ENABLE_MOORE_API DESTROY(INFINI_DEVICE_MOORE, moore) +#endif +#ifdef ENABLE_OPENCL_API + DESTROY(INFINI_DEVICE_OPENCL, opencl) #endif } return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; diff --git a/src/infiniop/ops/gemm/opencl/gemm_opencl.cc b/src/infiniop/ops/gemm/opencl/gemm_opencl.cc new file mode 100644 index 000000000..b920d23fa --- /dev/null +++ b/src/infiniop/ops/gemm/opencl/gemm_opencl.cc @@ -0,0 +1,244 @@ +#include "gemm_opencl.h" +#include "../../../../infinirt/opencl/infinirt_opencl.h" +#include "../../../devices/opencl/opencl_common.h" +#include +#include +#include +#include + +static const char *GemmKernelSource = R"CLC( +#define CL_TARGET_OPENCL_VERSION 200 +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#ifndef Tval +#define Tval float +#endif + +#ifdef USE_HALF +#define MUL(valueA, valueB) (float)(valueA * valueB) +#define SCAL(beta, p, alpha, value) (half)(beta * (float)(*p) + alpha * value) +#define SCAL1(alpha, value) (half)(alpha * value) +#else +#define MUL(valueA, valueB) valueA *valueB +#define SCAL(beta, p, alpha, value) beta *(*p) + alpha *value +#define SCAL1(alpha, value) alpha *value +#endif + +__kernel void general_gemm(__global Tval *A, __global Tval *B, __global Tval *C, + int as, int ars, int acs, int bs, int brs, int bcs, + int cs, int crs, int ccs, int batch, + int M, int N, int K, float alpha, float beta) { + int g_idx = get_global_id(1); + int g_idy = get_global_id(0); + int row_id = g_idy / N; + int col_id = g_idy % N; + + Tval valueA = 0.0f; + Tval valueB = 0.0f; + float value = 0.0f; + + for (int i = 0; i < K; i++) { + valueA = *(A + g_idx * as + row_id * ars + i * acs); + valueB = *(B + g_idx * bs + i * brs + col_id * bcs); + value += MUL(valueA, valueB); + } + + __global Tval *p = C + g_idx * cs + row_id * crs + col_id * ccs; + if (beta != 0) + *p = SCAL(beta, p, alpha, value); + else + *p = SCAL1(alpha, value); +} + 
+)CLC"; + +namespace op::gemm::opencl { + +using namespace device::opencl::kernel; + +struct Descriptor::Opaque { + std::shared_ptr internal; +}; + +Descriptor::~Descriptor() { + delete _opaque; +} + +infiniStatus_t Descriptor::create( + infiniopHandle_t handle_, + Descriptor **desc_ptr, + infiniopTensorDescriptor_t c_desc, + infiniopTensorDescriptor_t a_desc, + infiniopTensorDescriptor_t b_desc) { + auto handle = reinterpret_cast(handle_); + auto dtype = c_desc->dtype(); + + CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_BF16); + + auto result = MatmulInfo::create(c_desc, a_desc, b_desc, MatrixLayout::COL_MAJOR); + CHECK_RESULT(result); + + *desc_ptr = new Descriptor( + dtype, result.take(), 0, + new Opaque{handle->internal()}, + handle->device, handle->device_id); + return INFINI_STATUS_SUCCESS; +} + +infiniStatus_t launchKernel( + void *c, const void *a, const void *b, + infiniDtype_t dtype, size_t batch, size_t m, size_t n, size_t k, float alpha, float beta, + ptrdiff_t c_stride, ptrdiff_t c_row_stride, ptrdiff_t c_col_stride, + ptrdiff_t a_stride, ptrdiff_t a_row_stride, ptrdiff_t a_col_stride, + ptrdiff_t b_stride, ptrdiff_t b_row_stride, ptrdiff_t b_col_stride, + size_t block_size, + cl_context context, + cl_device_id device, + cl_command_queue cl_queue, + cl_program program) { + + cl_int clerr; + cl_kernel kernel = clCreateKernel(program, "general_gemm", &clerr); + if (clerr != CL_SUCCESS || kernel == nullptr) { + std::cout << clErrorString(clerr) << std::endl; + return INFINI_STATUS_INTERNAL_ERROR; + } + + int arg_idx = 0; + void *a_svm = NULL; + clerr = clSetKernelArgSVMPointer(kernel, arg_idx++, a); + if (clerr != CL_SUCCESS) { // for python test + infinirtMalloc(&a_svm, ((batch - 1) * a_stride + (m - 1) * a_row_stride + (k - 1) * a_col_stride + 1) * dtypeSize(dtype)); + infinirtMemcpy(a_svm, a, ((batch - 1) * a_stride + (m - 1) * a_row_stride + (k - 1) * a_col_stride + 1) * dtypeSize(dtype), INFINIRT_MEMCPY_H2D); + arg_idx -= 1; + 
clerr = clSetKernelArgSVMPointer(kernel, arg_idx++, a_svm); + } + void *b_svm = NULL; + clerr = clSetKernelArgSVMPointer(kernel, arg_idx++, b); + if (clerr != CL_SUCCESS) { // for python test + infinirtMalloc(&b_svm, ((batch - 1) * b_stride + (k - 1) * b_row_stride + (n - 1) * b_col_stride + 1) * dtypeSize(dtype)); + infinirtMemcpy(b_svm, b, ((batch - 1) * b_stride + (k - 1) * b_row_stride + (n - 1) * b_col_stride + 1) * dtypeSize(dtype), INFINIRT_MEMCPY_H2D); + arg_idx -= 1; + clerr = clSetKernelArgSVMPointer(kernel, arg_idx++, b_svm); + } + void *c_svm = NULL; + clerr = clSetKernelArgSVMPointer(kernel, arg_idx++, c); + if (clerr != CL_SUCCESS) { // for python test + infinirtMalloc(&c_svm, ((batch - 1) * c_stride + (m - 1) * c_row_stride + (n - 1) * c_col_stride + 1) * dtypeSize(dtype)); + infinirtMemcpy(c_svm, c, ((batch - 1) * c_stride + (m - 1) * c_row_stride + (n - 1) * c_col_stride + 1) * dtypeSize(dtype), INFINIRT_MEMCPY_H2D); + arg_idx -= 1; + clerr = clSetKernelArgSVMPointer(kernel, arg_idx++, c_svm); + } + cl_int a_s = static_cast(a_stride); + clerr |= clSetKernelArg(kernel, arg_idx++, sizeof(cl_int), &a_s); + cl_int a_r_s = static_cast(a_row_stride); + clerr |= clSetKernelArg(kernel, arg_idx++, sizeof(cl_int), &a_r_s); + cl_int a_c_s = static_cast(a_col_stride); + clerr |= clSetKernelArg(kernel, arg_idx++, sizeof(cl_int), &a_c_s); + cl_int b_s = static_cast(b_stride); + clerr |= clSetKernelArg(kernel, arg_idx++, sizeof(cl_int), &b_s); + cl_int b_r_s = static_cast(b_row_stride); + clerr |= clSetKernelArg(kernel, arg_idx++, sizeof(cl_int), &b_r_s); + cl_int b_c_s = static_cast(b_col_stride); + clerr |= clSetKernelArg(kernel, arg_idx++, sizeof(cl_int), &b_c_s); + cl_int c_s = static_cast(c_stride); + clerr |= clSetKernelArg(kernel, arg_idx++, sizeof(cl_int), &c_s); + cl_int c_r_s = static_cast(c_row_stride); + clerr |= clSetKernelArg(kernel, arg_idx++, sizeof(cl_int), &c_r_s); + cl_int c_c_s = static_cast(c_col_stride); + clerr |= clSetKernelArg(kernel, 
arg_idx++, sizeof(cl_int), &c_c_s); + clerr |= clSetKernelArg(kernel, arg_idx++, sizeof(cl_int), &batch); + clerr |= clSetKernelArg(kernel, arg_idx++, sizeof(cl_int), &m); + clerr |= clSetKernelArg(kernel, arg_idx++, sizeof(cl_int), &n); + clerr |= clSetKernelArg(kernel, arg_idx++, sizeof(cl_int), &k); + clerr |= clSetKernelArg(kernel, arg_idx++, sizeof(float), &alpha); + clerr |= clSetKernelArg(kernel, arg_idx++, sizeof(float), &beta); + + size_t global_size[2] = {m * n, batch}; + size_t local_size[2] = {block_size, 1}; + + clerr = clEnqueueNDRangeKernel(cl_queue, kernel, 2, nullptr, global_size, local_size, 0, nullptr, nullptr); + if (clerr != CL_SUCCESS) { + fprintf(stderr, "[OpenCL] clEnqueueNDRangeKernel failed: %s (%d)\n", clErrorString(clerr), clerr); + fprintf(stderr, " global_size: %zu, local_size: %zu\n", global_size[0], local_size[0]); + clReleaseKernel(kernel); + return INFINI_STATUS_INTERNAL_ERROR; + } + + if (c_svm) { // for python test + infinirtMemcpy(c, c_svm, ((batch - 1) * c_stride + (m - 1) * c_row_stride + (n - 1) * c_col_stride + 1) * dtypeSize(dtype), INFINIRT_MEMCPY_D2H); + } + + clReleaseKernel(kernel); + return INFINI_STATUS_SUCCESS; +} + +infiniStatus_t Descriptor::calculate( + void *workspace, + size_t workspace_size, + void *c, + float beta, + const void *a, + const void *b, + float alpha, + void *stream) const { + + if (_info.is_transed) { + std::swap(a, b); + } + // _dtype + auto batch = _info.batch; + auto m = _info.m; + auto n = _info.n; + auto k = _info.k; + + auto c_stride = _info.c_matrix.stride; + auto c_row_stride = _info.c_matrix.row_stride; + auto c_col_stride = _info.c_matrix.col_stride; + auto a_stride = _info.a_matrix.stride; + auto a_row_stride = _info.a_matrix.row_stride; + auto a_col_stride = _info.a_matrix.col_stride; + auto b_stride = _info.b_matrix.stride; + auto b_row_stride = _info.b_matrix.row_stride; + auto b_col_stride = _info.b_matrix.col_stride; + + size_t block_size = _opaque->internal->maxThreadsPerBlock(); 
+ void *device; + void *context; + CHECK_STATUS(infinirtGetOpenclDevice(&device)); + CHECK_STATUS(infinirtGetOpenclContext(&context)); + cl_context clcontext = static_cast(context); + cl_device_id cldevice = static_cast(device); + if (!stream) { + CHECK_STATUS(infinirtGetOpenclStream(&stream)); + } + cl_command_queue clqueue = static_cast(stream); + + std::string dt_val; + if (_dtype == INFINI_DTYPE_F16) { + dt_val = "half"; + } else if (_dtype == INFINI_DTYPE_F32) { + dt_val = "float"; + } else { + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + + // build options + std::string build_opts; + build_opts += "-D Tval=" + dt_val + " "; + if (_dtype == INFINI_DTYPE_F16) { + build_opts += "-D USE_HALF=1 "; + } + build_opts += "-cl-std=CL2.0 "; + auto prog_shared = this->_opaque->internal->programCache()->getOrBuildWithSource("gemm", GemmKernelSource, build_opts, clcontext, cldevice); + if (!prog_shared) { + return INFINI_STATUS_INTERNAL_ERROR; + } + cl_program clprogram = reinterpret_cast(prog_shared.get()); + + CHECK_STATUS(launchKernel(c, a, b, _dtype, batch, m, n, k, alpha, beta, + c_stride, c_row_stride, c_col_stride, a_stride, a_row_stride, a_col_stride, b_stride, b_row_stride, b_col_stride, block_size, clcontext, cldevice, clqueue, clprogram)); + return INFINI_STATUS_SUCCESS; +} + +} // namespace op::gemm::opencl diff --git a/src/infiniop/ops/gemm/opencl/gemm_opencl.h b/src/infiniop/ops/gemm/opencl/gemm_opencl.h new file mode 100644 index 000000000..02636db1d --- /dev/null +++ b/src/infiniop/ops/gemm/opencl/gemm_opencl.h @@ -0,0 +1,8 @@ +#ifndef __GEMM_OPENCL_CUH__ +#define __GEMM_OPENCL_CUH__ + +#include "../gemm.h" + +DESCRIPTOR(opencl) + +#endif // __GEMM_OPENCL_CUH__ diff --git a/src/infiniop/ops/gemm/operator.cc b/src/infiniop/ops/gemm/operator.cc index 0a0995e8e..aaac1b7d0 100644 --- a/src/infiniop/ops/gemm/operator.cc +++ b/src/infiniop/ops/gemm/operator.cc @@ -23,6 +23,9 @@ #ifdef ENABLE_KUNLUN_API #include "kunlun/gemm_kunlun.h" #endif +#ifdef 
ENABLE_OPENCL_API +#include "opencl/gemm_opencl.h" +#endif __C infiniStatus_t infiniopCreateGemmDescriptor( infiniopHandle_t handle, @@ -73,6 +76,9 @@ __C infiniStatus_t infiniopCreateGemmDescriptor( #ifdef ENABLE_KUNLUN_API CREATE(INFINI_DEVICE_KUNLUN, kunlun); #endif +#ifdef ENABLE_OPENCL_API + CREATE(INFINI_DEVICE_OPENCL, opencl); +#endif default: return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; @@ -123,6 +129,9 @@ infiniopGetGemmWorkspaceSize( #ifdef ENABLE_KUNLUN_API GET(INFINI_DEVICE_KUNLUN, kunlun); #endif +#ifdef ENABLE_OPENCL_API + GET(INFINI_DEVICE_OPENCL, opencl); +#endif default: return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; @@ -181,6 +190,9 @@ __C infiniStatus_t infiniopGemm( #ifdef ENABLE_KUNLUN_API CALCULATE(INFINI_DEVICE_KUNLUN, kunlun); #endif +#ifdef ENABLE_OPENCL_API + CALCULATE(INFINI_DEVICE_OPENCL, opencl); +#endif default: return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; @@ -229,6 +241,9 @@ infiniopDestroyGemmDescriptor(infiniopGemmDescriptor_t desc) { #ifdef ENABLE_KUNLUN_API DELETE(INFINI_DEVICE_KUNLUN, kunlun); #endif +#ifdef ENABLE_OPENCL_API + DELETE(INFINI_DEVICE_OPENCL, opencl); +#endif default: return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; diff --git a/src/infiniop/ops/random_sample/opencl/random_sample_opencl.cc b/src/infiniop/ops/random_sample/opencl/random_sample_opencl.cc new file mode 100644 index 000000000..4679fb58f --- /dev/null +++ b/src/infiniop/ops/random_sample/opencl/random_sample_opencl.cc @@ -0,0 +1,314 @@ +#include "random_sample_opencl.h" +#include "../../../../infinirt/opencl/infinirt_opencl.h" +#include "../../../devices/opencl/opencl_common.h" +#include "../info.h" +#include +#include +#include +#include + +static const char *RandomSampleKernelSource = R"CLC( +#define CL_TARGET_OPENCL_VERSION 200 +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#ifndef Tval +#define Tval float +#endif +#ifndef Tidx +#define Tidx unsigned int +#endif +#ifndef GROUP_SIZE +#define GROUP_SIZE 1024 +#endif + +typedef unsigned int 
T_idx; + +typedef struct { + Tidx idx; + Tval val; +} KVPair; + +KVPair group_argmax(local KVPair *data, KVPair reg) { + T_idx const idx = get_local_id(0), + len = get_local_size(0); + + data[idx] = reg; + barrier(CLK_LOCAL_MEM_FENCE); + + for (T_idx stride = len >> 1; stride; stride >>= 1) { + if (idx < stride) { + local KVPair + *a = data + idx, + *b = data + idx + stride; + if (b->val > a->val) *a = *b; + else if (b->val == a->val) { + if(b->idx < a->idx) *a = *b; + } + } + barrier(CLK_LOCAL_MEM_FENCE); + } + + return data[0]; +} + +kernel void argmax_build_pairs( + global Tval const *input, + global KVPair *output, + T_idx const n, + float init) { + + T_idx const + g_idx = get_global_id(0), + g_len = get_global_size(0), + l_idx = get_local_id(0); + + KVPair reg = {-1, (Tval) init}; + for (T_idx i = g_idx; i < n; i += g_len) { + Tval const val = input[i]; + if (val > reg.val) reg = (KVPair) {i, val}; + } + + local KVPair kv_pairs[GROUP_SIZE]; + reg = group_argmax(kv_pairs, reg); + + if (l_idx == 0) output[g_idx / GROUP_SIZE] = reg; +} + +kernel void argmax_reduce( + global KVPair const *pairs, + global Tidx *output, + T_idx const n, + float init) { + + T_idx const + g_idx = get_global_id(0), + g_len = get_global_size(0), + l_idx = get_local_id(0); + + KVPair reg = {-1, (Tval) init}; + for (T_idx i = g_idx; i < n; i += g_len) { + KVPair const pair = pairs[i]; + if (pair.val > reg.val) reg = pair; + else if (pair.val == reg.val) { + if (pair.idx < reg.idx) reg = pair; + } + } + + local KVPair kv_pairs[GROUP_SIZE]; + reg = group_argmax(kv_pairs, reg); + + // 最终结果写回 global + if (l_idx == 0) *output = reg.idx; +} + + +)CLC"; + +static size_t alignTo(size_t x, size_t a) { return (x + a - 1) / a * a; } + +namespace op::random_sample::opencl { + +using namespace device::opencl::kernel; + +struct Descriptor::Opaque { + std::shared_ptr internal; +}; + +Descriptor::~Descriptor() { + delete _opaque; +} + +infiniStatus_t Descriptor::create( + infiniopHandle_t handle_, + 
Descriptor **desc_ptr, + infiniopTensorDescriptor_t result_desc, + infiniopTensorDescriptor_t probs_desc) { + auto handle = reinterpret_cast(handle_); + auto internal = handle->internal(); + + auto result = RandomSampleInfo::create(result_desc, probs_desc); + CHECK_RESULT(result); + + auto info = result.take(); + size_t workspace_size; + size_t wg = internal->maxThreadsPerBlock(); + size_t num_partials = (info.n + wg - 1) / wg; + size_t argmax_tmp = num_partials * (dtypeSize(info.dt_i) + dtypeSize(info.dt_p)); + workspace_size = alignTo(argmax_tmp + 256, 256); + + *desc_ptr = new Descriptor( + info, + workspace_size, + new Opaque{handle->internal()}, + handle->device, handle->device_id); + return INFINI_STATUS_SUCCESS; +} + +size_t Descriptor::minWorkspaceSize() const { + return _min_workspace_size; +} + +infiniStatus_t launchKernel( + const void *probs, size_t n, void *result, void *workspace, size_t workspace_size, + infiniDtype_t dt_i, infiniDtype_t dt_p, float random_val, float topp, int topk, + float temperature, size_t block_size, cl_context context, cl_device_id device, + cl_command_queue cl_queue, cl_program program) { + // todo: add random + // argmax + if (topk != 1) { + std::cout << " only argmax" << std::endl; + return INFINI_STATUS_INTERNAL_ERROR; + } + + cl_int clerr; + size_t n_pairs = (n + block_size - 1) / block_size / 2; + size_t reduce_size = std::min(n_pairs, block_size); + + cl_kernel kernel_0 = clCreateKernel(program, "argmax_build_pairs", &clerr); + if (clerr != CL_SUCCESS || kernel_0 == nullptr) { + return INFINI_STATUS_INTERNAL_ERROR; + } + cl_kernel kernel_1 = clCreateKernel(program, "argmax_reduce", &clerr); + if (clerr != CL_SUCCESS || kernel_1 == nullptr) { + return INFINI_STATUS_INTERNAL_ERROR; + } + int arg_idx = 0; + void *probs_svm = NULL; + clerr = clSetKernelArgSVMPointer(kernel_0, arg_idx++, probs); + if (clerr != CL_SUCCESS) { // for python test + infinirtMalloc(&probs_svm, n * dtypeSize(dt_p)); + infinirtMemcpy(probs_svm, 
probs, n * dtypeSize(dt_p), INFINIRT_MEMCPY_H2D); + arg_idx -= 1; + clerr = clSetKernelArgSVMPointer(kernel_0, arg_idx++, probs_svm); + } + + void *w_svm = NULL; + clerr |= clSetKernelArgSVMPointer(kernel_0, arg_idx++, workspace); + if (clerr != CL_SUCCESS) { // for python test + infinirtMalloc(&w_svm, workspace_size); + infinirtMemcpy(w_svm, workspace, workspace_size, INFINIRT_MEMCPY_H2D); + arg_idx -= 1; + clerr = clSetKernelArgSVMPointer(kernel_0, arg_idx++, w_svm); + } + cl_int len = static_cast(n); + clerr |= clSetKernelArg(kernel_0, arg_idx++, sizeof(cl_int), &len); + + float neg_inf = -INFINITY; + clerr |= clSetKernelArg(kernel_0, arg_idx++, sizeof(float), &neg_inf); + + size_t global_size[1] = {n_pairs * block_size}; + size_t local_size[1] = {block_size}; + + clerr = clEnqueueNDRangeKernel(cl_queue, kernel_0, 1, nullptr, global_size, local_size, 0, nullptr, nullptr); + if (clerr != CL_SUCCESS) { + fprintf(stderr, "[OpenCL] clEnqueueNDRangeKernel failed: %s (%d)\n", clErrorString(clerr), clerr); + fprintf(stderr, " global_size: %zu, local_size: %zu\n", global_size[0], local_size[0]); + clReleaseKernel(kernel_0); + clReleaseKernel(kernel_1); + clReleaseProgram(program); + return INFINI_STATUS_INTERNAL_ERROR; + } + + arg_idx = 0; + if (!w_svm) { + clerr |= clSetKernelArgSVMPointer(kernel_1, arg_idx++, workspace); + } else { + clerr = clSetKernelArgSVMPointer(kernel_1, arg_idx++, w_svm); + } + void *result_svm = NULL; + clerr |= clSetKernelArgSVMPointer(kernel_1, arg_idx++, result); + if (clerr != CL_SUCCESS) { // for python test + infinirtMalloc(&result_svm, sizeof(dt_i)); + infinirtMemcpy(result_svm, result, sizeof(dt_i), INFINIRT_MEMCPY_H2D); + arg_idx -= 1; + clerr = clSetKernelArgSVMPointer(kernel_1, arg_idx++, result_svm); + } + + len = static_cast(n_pairs); + clerr |= clSetKernelArg(kernel_1, arg_idx++, sizeof(cl_int), &len); + clerr |= clSetKernelArg(kernel_1, arg_idx++, sizeof(float), &neg_inf); + + global_size[0] = {reduce_size}; + local_size[0] = 
{reduce_size}; + + clerr = clEnqueueNDRangeKernel(cl_queue, kernel_1, 1, nullptr, global_size, local_size, 0, nullptr, nullptr); + if (clerr != CL_SUCCESS) { + fprintf(stderr, "[OpenCL] clEnqueueNDRangeKernel failed: %s (%d)\n", clErrorString(clerr), clerr); + fprintf(stderr, " global_size: %zu, local_size: %zu\n", global_size[0], local_size[0]); + clReleaseKernel(kernel_0); + clReleaseKernel(kernel_1); + clReleaseProgram(program); + return INFINI_STATUS_INTERNAL_ERROR; + } + + if (result_svm) { // for python test + infinirtMemcpy(result, result_svm, dtypeSize(dt_i), INFINIRT_MEMCPY_D2H); + } + + // cleanup kernel + clReleaseKernel(kernel_0); + clReleaseKernel(kernel_1); + return INFINI_STATUS_SUCCESS; +} + +infiniStatus_t Descriptor::calculate( + void *workspace, + size_t workspace_size, + void *result, + const void *probs, + float random_val, + float topp, + int topk, + float temperature, + void *stream) const { + + if (workspace_size < _min_workspace_size) { + return INFINI_STATUS_INSUFFICIENT_WORKSPACE; + } + + size_t block_size = _opaque->internal->maxThreadsPerBlock(); + auto dt_i = _info.dt_i; + auto dt_p = _info.dt_p; + void *device; + void *context; + CHECK_STATUS(infinirtGetOpenclDevice(&device)); + CHECK_STATUS(infinirtGetOpenclContext(&context)); + cl_context clcontext = static_cast(context); + cl_device_id cldevice = static_cast(device); + if (!stream) { + CHECK_STATUS(infinirtGetOpenclStream(&stream)); + } + cl_command_queue clqueue = static_cast(stream); + + // Build program with cache + std::string dt_probs, dt_idx; + if (!dtypeToClType(dt_p, dt_probs)) { + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + if (!dtypeToClType(dt_i, dt_idx)) { + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + + // build options + std::string build_opts; + build_opts += "-D Tval=" + dt_probs + " "; + build_opts += "-D Tidx=" + dt_idx + " "; + build_opts += "-D GROUP_SIZE=" + std::to_string(block_size) + " "; + if (dt_p == INFINI_DTYPE_F16) { + build_opts += "-D USE_HALF=1 "; + 
} + build_opts += "-cl-std=CL2.0 "; + + auto prog_shared = _opaque->internal->programCache()->getOrBuildWithSource("random_sample", RandomSampleKernelSource, build_opts, clcontext, cldevice); + if (!prog_shared) { + return INFINI_STATUS_INTERNAL_ERROR; + } + cl_program clprogram = reinterpret_cast(prog_shared.get()); + + CHECK_STATUS(launchKernel(probs, _info.n, result, workspace, workspace_size, dt_i, dt_p, random_val, topp, topk, + temperature, block_size, clcontext, cldevice, clqueue, clprogram)); + + return INFINI_STATUS_SUCCESS; +} + +} // namespace op::random_sample::opencl diff --git a/src/infiniop/ops/random_sample/opencl/random_sample_opencl.h b/src/infiniop/ops/random_sample/opencl/random_sample_opencl.h new file mode 100644 index 000000000..d58595c95 --- /dev/null +++ b/src/infiniop/ops/random_sample/opencl/random_sample_opencl.h @@ -0,0 +1,8 @@ +#ifndef __RANDOM_SAMPLE_OPENCL_CUH__ +#define __RANDOM_SAMPLE_OPENCL_CUH__ + +#include "../random_sample.h" + +DESCRIPTOR(opencl) + +#endif // __RANDOM_SAMPLE_OPENCL_CUH__ diff --git a/src/infiniop/ops/random_sample/operator.cc b/src/infiniop/ops/random_sample/operator.cc index 8239d97c5..07fa49bed 100644 --- a/src/infiniop/ops/random_sample/operator.cc +++ b/src/infiniop/ops/random_sample/operator.cc @@ -23,6 +23,9 @@ #ifdef ENABLE_KUNLUN_API #include "kunlun/random_sample_kunlun.h" #endif +#ifdef ENABLE_OPENCL_API +#include "opencl/random_sample_opencl.h" +#endif __C infiniStatus_t infiniopCreateRandomSampleDescriptor( @@ -71,6 +74,9 @@ infiniopCreateRandomSampleDescriptor( #ifdef ENABLE_KUNLUN_API CREATE(INFINI_DEVICE_KUNLUN, kunlun); #endif +#ifdef ENABLE_OPENCL_API + CREATE(INFINI_DEVICE_OPENCL, opencl); +#endif default: return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; @@ -122,6 +128,9 @@ __C infiniStatus_t infiniopGetRandomSampleWorkspaceSize( #ifdef ENABLE_KUNLUN_API GET(INFINI_DEVICE_KUNLUN, kunlun); #endif +#ifdef ENABLE_OPENCL_API + GET(INFINI_DEVICE_OPENCL, opencl); +#endif default: return 
INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; @@ -183,6 +192,9 @@ __C infiniStatus_t infiniopRandomSample( #ifdef ENABLE_KUNLUN_API CALCULATE(INFINI_DEVICE_KUNLUN, kunlun); #endif +#ifdef ENABLE_OPENCL_API + CALCULATE(INFINI_DEVICE_OPENCL, opencl); +#endif default: return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; @@ -231,6 +243,9 @@ __C infiniStatus_t infiniopDestroyRandomSampleDescriptor( #ifdef ENABLE_KUNLUN_API DELETE(INFINI_DEVICE_KUNLUN, kunlun); #endif +#ifdef ENABLE_OPENCL_API + DELETE(INFINI_DEVICE_OPENCL, opencl); +#endif default: return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; diff --git a/src/infiniop/ops/rearrange/opencl/rearrange_opencl.cc b/src/infiniop/ops/rearrange/opencl/rearrange_opencl.cc new file mode 100644 index 000000000..ede320c77 --- /dev/null +++ b/src/infiniop/ops/rearrange/opencl/rearrange_opencl.cc @@ -0,0 +1,720 @@ +#include "rearrange_opencl.h" +#include "../../../../infinirt/opencl/infinirt_opencl.h" +#include "../../../devices/opencl/opencl_common.h" +#include "../../../tensor.h" +#include +#include +#include +#include +#include +#include +#include + +namespace op::rearrange::opencl { + +using ARRAY_TYPE_STRIDE = ptrdiff_t; +using ARRAY_TYPE_SIZE = size_t; + +template +struct Constraint { + ElementType grid_idx; + ElementType block_idx; + ElementType grid_div_block; + ElementType total_len; +}; + +// RearrangeParams +struct RearrangeParams { + std::vector block_len; + std::vector src_block_stride; + std::vector dst_block_stride; + std::vector grid_len; + std::vector src_grid_stride; + std::vector dst_grid_stride; + size_t block_dim; + size_t block_len_total; + std::vector> constraints; + size_t unit_size; // bytes per unit +}; + +inline std::string unitSizeToOpenCLType(size_t unit_size) { + switch (unit_size) { + case 1: + return "uchar"; + case 2: + return "uchar2"; + case 4: + return "float"; + case 8: + return "float2"; + case 16: + return "float4"; + case 32: + return "double4"; // 需要 cl_khr_fp64 + default: + return ""; // unsupported 
+ } +} + +inline std::string kernelName(const RearrangeParams &p, int constraint_num) { + std::ostringstream oss; + oss << "rearrange_unit_" << unitSizeToOpenCLType(p.unit_size) + << "_block_" << p.block_len.size() + << "_grid_" << p.grid_len.size() + << "_constrain_" << constraint_num; + return oss.str(); +} + +inline std::string generateOpenCLKernel(const RearrangeParams &p) { + auto grid_num = p.grid_len.size(); + auto block_num = p.block_len.size(); + auto constraint_num = p.constraints.size(); + CHECK_OR_RETURN(grid_num <= 5 && grid_num != 0, NULL); + CHECK_OR_RETURN(block_num <= 5 && block_num != 0, NULL); // grid和block的维数都不超过5 + CHECK_OR_RETURN(constraint_num <= 2, NULL); + auto unit_type = unitSizeToOpenCLType(p.unit_size); + if (unit_type.empty()) { + throw std::runtime_error("unsupported unit_size"); + } + + const size_t B = p.block_len.size(); + const size_t G = p.grid_len.size(); + + std::ostringstream s; + + s << "#pragma OPENCL EXTENSION cl_khr_fp64 : enable\n"; // for double4 type + s << "typedef int ARRAY_TYPE_STRIDE;\n"; + s << "typedef unsigned int ARRAY_TYPE_SIZE;\n\n"; + + std::string kn = kernelName(p, constraint_num); + s << "// Generated kernel: " << kn << "\n"; + + s << "__kernel void " << kn << "(\n"; + s << " __global uchar* dst,\n"; + s << " __global const uchar* src,\n"; + s << " const ulong block_len_total"; + + for (size_t i = 0; i < B; ++i) { + s << ",\n const ARRAY_TYPE_SIZE block_len_" << i; + s << ",\n const ARRAY_TYPE_STRIDE src_block_stride_" << i; + s << ",\n const ARRAY_TYPE_STRIDE dst_block_stride_" << i; + } + + for (size_t i = 0; i < G; ++i) { + s << ",\n const ARRAY_TYPE_SIZE grid_len_" << i; + s << ",\n const ARRAY_TYPE_STRIDE src_grid_stride_" << i; + s << ",\n const ARRAY_TYPE_STRIDE dst_grid_stride_" << i; + } + + for (int c = 0; c < constraint_num; ++c) { + s << ",\n const ARRAY_TYPE_SIZE c" << c << "_grid_idx"; + s << ",\n const ARRAY_TYPE_SIZE c" << c << "_block_idx"; + s << ",\n const ARRAY_TYPE_SIZE c" << c << 
"_grid_div_block"; + s << ",\n const ARRAY_TYPE_SIZE c" << c << "_total_len"; + } + + s << " ) {\n"; + s << " const size_t local_tid = get_local_id(0);\n"; + s << " const size_t group_id = get_group_id(0);\n\n"; + + s << " if (local_tid >= (size_t)block_len_total) return;\n\n"; + + s << " __local ARRAY_TYPE_STRIDE shared_src_offset_arr[1];\n"; + s << " __local ARRAY_TYPE_STRIDE shared_dst_offset_arr[1];\n"; + if (constraint_num > 0) { + s << " __local ARRAY_TYPE_SIZE shared_constraints_grid_idx_multiple[" << constraint_num << "];\n"; + } + s << "\n"; + + s << " if (local_tid == 0) {\n"; + s << " ARRAY_TYPE_STRIDE src_offset = 0;\n"; + s << " ARRAY_TYPE_STRIDE dst_offset = 0;\n"; + if (constraint_num > 0) { + s << " ARRAY_TYPE_SIZE constraints_grid_idx_multiple[" << constraint_num << "];\n"; + } + s << " size_t rem = group_id;\n\n"; + + for (int i = (int)G - 1; i >= 0; --i) { + s << " {\n"; + s << " size_t idx = rem % (size_t)grid_len_" << i << ";\n"; + s << " rem = rem / (size_t)grid_len_" << i << ";\n"; + s << " src_offset += (ARRAY_TYPE_STRIDE)idx * src_grid_stride_" << i << ";\n"; + s << " dst_offset += (ARRAY_TYPE_STRIDE)idx * dst_grid_stride_" << i << ";\n"; + if (constraint_num > 0) { + for (int j = 0; j < constraint_num; ++j) { + s << " if (" << i << " == (int)c" << j << "_grid_idx) constraints_grid_idx_multiple[" << j << "] = idx * (ARRAY_TYPE_SIZE)c" << j << "_grid_div_block;\n"; + } + } + s << " }\n"; + } + + s << " shared_src_offset_arr[0] = src_offset;\n"; + s << " shared_dst_offset_arr[0] = dst_offset;\n"; + if (constraint_num > 0) { + s << " for (int j=0; j<" << constraint_num << "; ++j) shared_constraints_grid_idx_multiple[j] = constraints_grid_idx_multiple[j];\n"; + } + s << " }\n\n"; + + // barrier and load + s << " barrier(CLK_LOCAL_MEM_FENCE);\n\n"; + s << " ARRAY_TYPE_STRIDE src_offset = shared_src_offset_arr[0];\n"; + s << " ARRAY_TYPE_STRIDE dst_offset = shared_dst_offset_arr[0];\n"; + if (constraint_num > 0) { + s << " ARRAY_TYPE_SIZE 
constraints_grid_idx_multiple[" << constraint_num << "];\n"; + s << " for (int j=0;j<" << constraint_num << "; ++j) constraints_grid_idx_multiple[j] = shared_constraints_grid_idx_multiple[j];\n"; + } + s << "\n"; + s << " size_t rem_local = local_tid;\n\n"; + for (int i = (int)B - 1; i >= 1; --i) { + s << " {\n"; + s << " size_t idx = rem_local % (size_t)block_len_" << i << ";\n"; + s << " rem_local = rem_local / (size_t)block_len_" << i << ";\n"; + s << " src_offset += (ARRAY_TYPE_STRIDE)idx * src_block_stride_" << i << ";\n"; + s << " dst_offset += (ARRAY_TYPE_STRIDE)idx * dst_block_stride_" << i << ";\n"; + if (constraint_num > 0) { + for (int j = 0; j < constraint_num; ++j) { + s << " if (" << i << " == (int)c" << j << "_block_idx) { if (constraints_grid_idx_multiple[" << j << "] + idx >= c" << j << "_total_len) return; }\n"; + } + } + s << " }\n"; + } + + s << " {\n"; + s << " size_t idx = rem_local;\n"; + s << " src_offset += (ARRAY_TYPE_STRIDE)idx * src_block_stride_0;\n"; + s << " dst_offset += (ARRAY_TYPE_STRIDE)idx * dst_block_stride_0;\n"; + if (constraint_num > 0) { + for (int j = 0; j < constraint_num; ++j) { + s << " if (0 == (int)c" << j << "_block_idx) { if (constraints_grid_idx_multiple[" << j << "] + idx >= c" << j << "_total_len) return; }\n"; + } + } + s << " }\n\n"; + + s << " // unit_size = " << p.unit_size << " (bytes)\n"; + if (p.unit_size == 1) { + s << " dst[dst_offset] = src[src_offset];\n"; + } else { + s << " for (size_t b = 0; b < " << p.unit_size << "; ++b) {\n"; + s << " dst[dst_offset + b] = src[src_offset + b];\n"; + s << " }\n"; + } + + s << "}\n"; + + return s.str(); +} + +struct Descriptor::Opaque { + std::shared_ptr internal; +}; + +Descriptor::~Descriptor() { + delete _opaque; +} + +infiniStatus_t Descriptor::create( + infiniopHandle_t handle, + Descriptor **desc_ptr, + infiniopTensorDescriptor_t y_desc, + infiniopTensorDescriptor_t x_desc) { + + auto dtype = y_desc->dtype(); + auto ndim = y_desc->ndim(); + + 
CHECK_OR_RETURN(x_desc->dtype() == dtype, INFINI_STATUS_BAD_TENSOR_DTYPE); + CHECK_OR_RETURN(x_desc->ndim() == ndim, INFINI_STATUS_BAD_TENSOR_SHAPE); + auto x_shape = x_desc->shape(); + auto y_shape = y_desc->shape(); + auto y_strides = y_desc->strides(); + auto x_strides = x_desc->strides(); + + CHECK_SAME_SHAPE(x_shape, y_shape); + + auto meta = utils::RearrangeMeta::create( + y_shape.data(), + y_strides.data(), + x_strides.data(), + ndim, + infiniSizeOf(dtype)); + + CHECK_RESULT(meta); + + *desc_ptr = new Descriptor( + std::move(*meta), + new Opaque{reinterpret_cast(handle)->internal()}, + handle->device, handle->device_id); + return INFINI_STATUS_SUCCESS; +} + +struct Dim { + size_t len; + ARRAY_TYPE_STRIDE src_stride; + ARRAY_TYPE_STRIDE dst_stride; +}; + +struct SplitDim { + size_t choose_idx; + size_t num_per_block; + size_t num_per_grid; + int array_struct_idx_block; + int array_struct_idx_grid; + size_t dim_len; +}; + +utils::Result prepareRearrangeParams(const utils::RearrangeMeta &original_meta, int max_threads) { + RearrangeParams params; + + auto meta_result = original_meta.distributeUnit({32, 16, 8, 4, 2, 1}); + + CHECK_RESULT(meta_result); + + const utils::RearrangeMeta &meta = meta_result.take(); + + const size_t ndim = meta.ndim(); + const size_t unit = meta.unit(); + + if (ndim == 0) { + params.block_dim = 0; + params.block_len_total = 1; + params.block_len = {static_cast(1)}; + params.src_block_stride = {static_cast(0)}; + params.dst_block_stride = {static_cast(0)}; + params.grid_len = {static_cast(1)}; + params.src_grid_stride = {static_cast(0)}; + params.dst_grid_stride = {static_cast(0)}; + params.unit_size = unit; + return utils::Result(params); + } + + const ptrdiff_t *idx_strides = meta.idx_strides(); + const ptrdiff_t *dst_strides = meta.dst_strides(); + const ptrdiff_t *src_strides = meta.src_strides(); + + std::vector dims; + std::vector shape; + dims.reserve(ndim); + shape.reserve(ndim); + + auto prev_idx_stride = meta.count(); + for 
(size_t i = 0; i < ndim; ++i) { + size_t len = prev_idx_stride / idx_strides[i]; + shape.push_back(len); + dims.push_back({len, src_strides[i], dst_strides[i]}); + prev_idx_stride = idx_strides[i]; + } + + std::vector src_strides_desc_idx(ndim); + for (size_t i = 0; i < ndim; ++i) { + src_strides_desc_idx[i] = i; + } + std::sort(src_strides_desc_idx.begin(), src_strides_desc_idx.end(), + [&dims](size_t a, size_t b) { + return std::abs(dims[a].src_stride) > std::abs(dims[b].src_stride); + }); + + const size_t block_size = max_threads; + std::vector block_dim_choose(ndim, false); + + size_t block_elements = 1; + size_t block_src_elements = 1; + size_t block_dst_elements = 1; + size_t src_choose_idx = ndim; + size_t dst_choose_idx = ndim; + + std::vector split_dims; + + while (src_choose_idx > 0 && dst_choose_idx > 0) { + size_t src_idx = src_strides_desc_idx[src_choose_idx - 1]; + size_t dst_idx = dst_choose_idx - 1; + + if (src_idx == dst_idx) { + size_t idx = src_idx; + size_t len = shape[idx]; + + if (block_elements * len <= block_size) { + block_dim_choose[idx] = true; + block_elements *= len; + block_src_elements *= len; + block_dst_elements *= len; + src_choose_idx--; + dst_choose_idx--; + } else { + + size_t num_per_block = block_size / block_elements; + + if (num_per_block > 0 && len >= num_per_block && num_per_block > 1) { + size_t num_per_grid = (len + num_per_block - 1) / num_per_block; + + SplitDim split_dim = { + idx, + num_per_block, + num_per_grid, + 0, + 0, + len}; + split_dims.push_back(split_dim); + } + break; + } + } else { + double src_div_dst = static_cast(block_src_elements) / block_dst_elements; + double src_num_per_block = std::sqrt(block_size / (double)block_elements / src_div_dst); + double dst_num_per_block = src_num_per_block * src_div_dst; + + size_t src_current_dim_len = shape[src_idx]; + size_t dst_current_dim_len = shape[dst_idx]; + + if (static_cast(src_current_dim_len) < src_num_per_block) { + block_dim_choose[src_idx] = true; + 
block_elements *= src_current_dim_len; + block_src_elements *= src_current_dim_len; + src_choose_idx--; + } else if (static_cast(dst_current_dim_len) < dst_num_per_block) { + block_dim_choose[dst_idx] = true; + block_elements *= dst_current_dim_len; + block_dst_elements *= dst_current_dim_len; + dst_choose_idx--; + } else { + size_t src_num_per_block_int = static_cast(std::floor(src_num_per_block)); + size_t dst_num_per_block_int = static_cast(std::floor(dst_num_per_block)); + + size_t src_num_per_grid = (src_current_dim_len + src_num_per_block_int - 1) / src_num_per_block_int; + size_t dst_num_per_grid = (dst_current_dim_len + dst_num_per_block_int - 1) / dst_num_per_block_int; + + if (src_num_per_block_int > 1) { + if (src_num_per_grid == 1) { + + block_dim_choose[src_idx] = true; + block_elements *= src_current_dim_len; + block_src_elements *= src_current_dim_len; + src_choose_idx--; + } else { + SplitDim split_dim = { + src_idx, + src_num_per_block_int, + src_num_per_grid, + 0, + 0, + src_current_dim_len}; + split_dims.push_back(split_dim); + } + } + + if (dst_num_per_block_int > 1) { + if (dst_num_per_grid == 1) { + + block_dim_choose[dst_idx] = true; + block_elements *= dst_current_dim_len; + block_dst_elements *= dst_current_dim_len; + dst_choose_idx--; + } else { + + SplitDim split_dim = { + dst_idx, + dst_num_per_block_int, + dst_num_per_grid, + 0, + 0, + dst_current_dim_len}; + split_dims.push_back(split_dim); + } + } + + break; + } + } + } + + size_t block_dim = 0; + size_t block_len_total = 1; + + std::vector block_len; + std::vector src_block_stride; + std::vector dst_block_stride; + + std::vector grid_len; + std::vector src_grid_stride; + std::vector dst_grid_stride; + + for (size_t i = 0; i < ndim; ++i) { + if (block_dim_choose[i]) { + block_len.push_back(shape[i]); + src_block_stride.push_back(dims[i].src_stride); + dst_block_stride.push_back(dims[i].dst_stride); + block_dim += 1; + block_len_total *= shape[i]; + } + + for (size_t j = 0; j < 
split_dims.size(); ++j) { + if (i == split_dims[j].choose_idx) { + block_len.push_back(split_dims[j].num_per_block); + src_block_stride.push_back(dims[i].src_stride); + dst_block_stride.push_back(dims[i].dst_stride); + split_dims[j].array_struct_idx_block = static_cast(block_dim); + block_dim += 1; + block_len_total *= split_dims[j].num_per_block; + } + } + } + + for (size_t i = 0; i < ndim; ++i) { + if (!block_dim_choose[i]) { + bool is_split = false; + + for (size_t j = 0; j < split_dims.size(); ++j) { + if (i == split_dims[j].choose_idx) { + is_split = true; + grid_len.push_back(split_dims[j].num_per_grid); + src_grid_stride.push_back(dims[i].src_stride * split_dims[j].num_per_block); + dst_grid_stride.push_back(dims[i].dst_stride * split_dims[j].num_per_block); + split_dims[j].array_struct_idx_grid = static_cast(grid_len.size() - 1); + } + } + + if (!is_split) { + grid_len.push_back(shape[i]); + src_grid_stride.push_back(dims[i].src_stride); + dst_grid_stride.push_back(dims[i].dst_stride); + } + } + } + + if (grid_len.empty()) { + grid_len.push_back(1); + src_grid_stride.push_back(0); + dst_grid_stride.push_back(0); + } + + std::vector> constraints; + + for (size_t i = 0; i < split_dims.size(); ++i) { + if (split_dims[i].dim_len % split_dims[i].num_per_block == 0) { + continue; + } + Constraint constraint; + constraint.grid_idx = split_dims[i].array_struct_idx_grid; + constraint.block_idx = split_dims[i].array_struct_idx_block; + constraint.grid_div_block = split_dims[i].num_per_block; + constraint.total_len = split_dims[i].dim_len; + constraints.push_back(constraint); + } + + params.block_dim = block_dim; + params.block_len_total = block_len_total; + params.block_len = block_len; + params.src_block_stride = src_block_stride; + params.dst_block_stride = dst_block_stride; + params.grid_len = grid_len; + params.src_grid_stride = src_grid_stride; + params.dst_grid_stride = dst_grid_stride; + params.constraints = constraints; + params.unit_size = unit; + + return 
utils::Result(params); +} + +infiniStatus_t launchKernel( + void *y, + const void *x, + size_t y_size, size_t x_size, + const RearrangeParams ¶ms, + size_t grid_size, size_t block_size, size_t unit_size, + cl_context context, cl_device_id device, + cl_command_queue cl_queue, cl_program program) { + + cl_int clerr; + cl_kernel kernel = clCreateKernel(program, kernelName(params, params.constraints.size()).c_str(), &clerr); + if (clerr != CL_SUCCESS || kernel == nullptr) { + return INFINI_STATUS_INTERNAL_ERROR; + } + + int arg_idx = 0; + void *y_svm = NULL; + clerr = clSetKernelArgSVMPointer(kernel, arg_idx++, y); + + if (clerr != CL_SUCCESS) { // for python test + infinirtMalloc(&y_svm, y_size); + clerr = infinirtMemcpy(y_svm, y, y_size, INFINIRT_MEMCPY_H2D); + arg_idx -= 1; + clerr = clSetKernelArgSVMPointer(kernel, arg_idx++, y_svm); + } + void *x_svm = NULL; + clerr = clSetKernelArgSVMPointer(kernel, arg_idx++, x); + if (clerr != CL_SUCCESS) { // for python test + infinirtMalloc(&x_svm, x_size); + infinirtMemcpy(x_svm, x, x_size, INFINIRT_MEMCPY_H2D); + arg_idx -= 1; + clerr = clSetKernelArgSVMPointer(kernel, arg_idx++, x_svm); + } + clerr = clSetKernelArg(kernel, arg_idx++, sizeof(unsigned int), &block_size); + if (clerr != CL_SUCCESS) { + std::cerr << "[OpenCL] clSetKernelArg(block_size) failed\n"; + } + + for (size_t i = 0; i < params.block_len.size(); ++i) { + ARRAY_TYPE_SIZE v_block_len = static_cast(params.block_len[i]); + clerr = clSetKernelArg(kernel, arg_idx++, sizeof(ARRAY_TYPE_SIZE), &v_block_len); + if (clerr != CL_SUCCESS) { + std::cerr << "[OpenCL] clSetKernelArg(block_len_" << i << ") failed\n"; + } + + ARRAY_TYPE_STRIDE v_src_bs = static_cast(params.src_block_stride[i]); + clerr = clSetKernelArg(kernel, arg_idx++, sizeof(ARRAY_TYPE_STRIDE), &v_src_bs); + if (clerr != CL_SUCCESS) { + std::cerr << "[OpenCL] clSetKernelArg(src_block_stride_" << i << ") failed\n"; + } + + ARRAY_TYPE_STRIDE v_dst_bs = static_cast(params.dst_block_stride[i]); + clerr = 
clSetKernelArg(kernel, arg_idx++, sizeof(ARRAY_TYPE_STRIDE), &v_dst_bs); + if (clerr != CL_SUCCESS) { + std::cerr << "[OpenCL] clSetKernelArg(dst_block_stride_" << i << ") failed\n"; + } + } + + for (size_t i = 0; i < params.grid_len.size(); ++i) { + ARRAY_TYPE_SIZE v_grid_len = static_cast(params.grid_len[i]); + clerr = clSetKernelArg(kernel, arg_idx++, sizeof(ARRAY_TYPE_SIZE), &v_grid_len); + if (clerr != CL_SUCCESS) { + std::cerr << "[OpenCL] clSetKernelArg(grid_len_" << i << ") failed\n"; + } + + ARRAY_TYPE_STRIDE v_src_gs = static_cast(params.src_grid_stride[i]); + clerr = clSetKernelArg(kernel, arg_idx++, sizeof(ARRAY_TYPE_STRIDE), &v_src_gs); + if (clerr != CL_SUCCESS) { + std::cerr << "[OpenCL] clSetKernelArg(src_grid_stride_" << i << ") failed\n"; + } + + ARRAY_TYPE_STRIDE v_dst_gs = static_cast(params.dst_grid_stride[i]); + clerr = clSetKernelArg(kernel, arg_idx++, sizeof(ARRAY_TYPE_STRIDE), &v_dst_gs); + if (clerr != CL_SUCCESS) { + std::cerr << "[OpenCL] clSetKernelArg(dst_grid_stride_" << i << ") failed\n"; + } + } + + for (size_t c = 0; c < params.constraints.size(); ++c) { + const auto &cc = params.constraints[c]; + ARRAY_TYPE_SIZE v0 = cc.grid_idx; + ARRAY_TYPE_SIZE v1 = cc.block_idx; + ARRAY_TYPE_SIZE v2 = cc.grid_div_block; + ARRAY_TYPE_SIZE v3 = cc.total_len; + clerr = clSetKernelArg(kernel, arg_idx++, sizeof(ARRAY_TYPE_SIZE), &v0); + if (clerr != CL_SUCCESS) { + std::cerr << "[OpenCL] clSetKernelArg(c" << c << "_grid_idx) failed\n"; + } + clerr = clSetKernelArg(kernel, arg_idx++, sizeof(ARRAY_TYPE_SIZE), &v1); + if (clerr != CL_SUCCESS) { + std::cerr << "[OpenCL] clSetKernelArg(c" << c << "_block_idx) failed\n"; + } + clerr = clSetKernelArg(kernel, arg_idx++, sizeof(ARRAY_TYPE_SIZE), &v2); + if (clerr != CL_SUCCESS) { + std::cerr << "[OpenCL] clSetKernelArg(c" << c << "_grid_div_block) failed\n"; + } + clerr = clSetKernelArg(kernel, arg_idx++, sizeof(ARRAY_TYPE_SIZE), &v3); + if (clerr != CL_SUCCESS) { + std::cerr << "[OpenCL] clSetKernelArg(c" 
<< c << "_total_len) failed\n"; + } + } + + size_t global_size[1] = {block_size * grid_size}; + size_t local_size[1] = {block_size}; + + clerr = clEnqueueNDRangeKernel(cl_queue, kernel, 1, nullptr, global_size, local_size, 0, nullptr, nullptr); + if (clerr != CL_SUCCESS) { + fprintf(stderr, " global_size: %zu, local_size: %zu\n", global_size[0], local_size[0]); + clReleaseKernel(kernel); + return INFINI_STATUS_INTERNAL_ERROR; + } + + if (y_svm) { // for python test + infinirtMemcpy(y, y_svm, y_size, INFINIRT_MEMCPY_D2H); + } + clReleaseKernel(kernel); + + return INFINI_STATUS_SUCCESS; +} + +infiniStatus_t Descriptor::calculate( + void *y, + const void *x, + void *stream) const { + + void *device; + void *context; + CHECK_STATUS(infinirtGetOpenclDevice(&device)); + CHECK_STATUS(infinirtGetOpenclContext(&context)); + cl_context clcontext = static_cast(context); + cl_device_id cldevice = static_cast(device); + if (!stream) { + CHECK_STATUS(infinirtGetOpenclStream(&stream)); + } + cl_command_queue clqueue = static_cast(stream); + + if (_meta.ndim() == 0) { + auto clerr = infinirtMemcpyAsync(y, x, _meta.unit(), INFINIRT_MEMCPY_D2D, clqueue); + + if (clerr != CL_SUCCESS) { + return INFINI_STATUS_INTERNAL_ERROR; + } + return INFINI_STATUS_SUCCESS; + } + + const size_t ndim = _meta.ndim(); + const size_t unit = _meta.unit(); + + const ptrdiff_t *idx_strides = _meta.idx_strides(); + const ptrdiff_t *dst_strides = _meta.dst_strides(); + const ptrdiff_t *src_strides = _meta.src_strides(); + std::vector shape; + shape.reserve(ndim); + + auto prev_idx_stride = _meta.count(); + for (size_t i = 0; i < ndim; ++i) { + size_t len = prev_idx_stride / idx_strides[i]; + shape.push_back(len); + prev_idx_stride = idx_strides[i]; + } + + size_t y_size = unit; + size_t x_size = unit; + for (size_t i = 0; i < ndim; ++i) { + y_size += dst_strides[i] * (shape[i] - 1); + x_size += src_strides[i] * (shape[i] - 1); + } + + int max_threads = _opaque->internal->maxThreadsPerBlock(); + + auto 
params_result = prepareRearrangeParams(_meta, max_threads); + CHECK_RESULT(params_result); + auto params = params_result.take(); + + size_t grid_size = 1; + for (size_t i = 0; i < params.grid_len.size(); ++i) { + grid_size *= params.grid_len[i]; + } + + if (grid_size == 0) { + return INFINI_STATUS_BAD_PARAM; + } + + infiniStatus_t status = INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED; + + size_t block_size = params.block_len_total; + + RearrangeParams params_copy = params; + std::string RearrangeKernelSource = generateOpenCLKernel(params_copy); + + std::string build_opts; + build_opts += "-cl-std=CL2.0 "; + + auto prog_shared = this->_opaque->internal->programCache()->getOrBuildWithSource(kernelName(params, params.constraints.size()).c_str(), RearrangeKernelSource, build_opts, clcontext, cldevice); + if (!prog_shared) { + return INFINI_STATUS_INTERNAL_ERROR; + } + cl_program clprogram = reinterpret_cast(prog_shared.get()); + + status = launchKernel(y, x, y_size, x_size, params, grid_size, block_size, _meta.unit(), clcontext, cldevice, clqueue, clprogram); + return status; + return INFINI_STATUS_SUCCESS; +} + +} // namespace op::rearrange::opencl diff --git a/src/infiniop/ops/rearrange/opencl/rearrange_opencl.h b/src/infiniop/ops/rearrange/opencl/rearrange_opencl.h new file mode 100644 index 000000000..2bd295ed2 --- /dev/null +++ b/src/infiniop/ops/rearrange/opencl/rearrange_opencl.h @@ -0,0 +1,8 @@ +#ifndef __REARRANGE_OPENCL_H__ +#define __REARRANGE_OPENCL_H__ + +#include "../rearrange.h" + +DESCRIPTOR(opencl) + +#endif // __REARRANGE_OPENCL_H__ diff --git a/src/infiniop/ops/rearrange/operator.cc b/src/infiniop/ops/rearrange/operator.cc index c7a309033..1659d1ac0 100644 --- a/src/infiniop/ops/rearrange/operator.cc +++ b/src/infiniop/ops/rearrange/operator.cc @@ -23,6 +23,9 @@ #ifdef ENABLE_KUNLUN_API #include "kunlun/rearrange_kunlun.h" #endif +#ifdef ENABLE_OPENCL_API +#include "opencl/rearrange_opencl.h" +#endif __C infiniStatus_t 
infiniopCreateRearrangeDescriptor( infiniopHandle_t handle, @@ -69,6 +72,9 @@ __C infiniStatus_t infiniopCreateRearrangeDescriptor( #endif #ifdef ENABLE_KUNLUN_API CREATE(INFINI_DEVICE_KUNLUN, kunlun); +#endif +#ifdef ENABLE_OPENCL_API + CREATE(INFINI_DEVICE_OPENCL, opencl); #endif default: return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; @@ -120,6 +126,9 @@ __C infiniStatus_t infiniopRearrange( #ifdef ENABLE_KUNLUN_API CALCULATE(INFINI_DEVICE_KUNLUN, kunlun); #endif +#ifdef ENABLE_OPENCL_API + CALCULATE(INFINI_DEVICE_OPENCL, opencl); +#endif default: return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; @@ -168,6 +177,9 @@ __C infiniStatus_t infiniopDestroyRearrangeDescriptor( #ifdef ENABLE_KUNLUN_API DELETE(INFINI_DEVICE_KUNLUN, kunlun); #endif +#ifdef ENABLE_OPENCL_API + DELETE(INFINI_DEVICE_OPENCL, opencl); +#endif default: return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; diff --git a/src/infiniop/ops/rms_norm/opencl/rms_norm_opencl.cc b/src/infiniop/ops/rms_norm/opencl/rms_norm_opencl.cc new file mode 100644 index 000000000..c593eaab8 --- /dev/null +++ b/src/infiniop/ops/rms_norm/opencl/rms_norm_opencl.cc @@ -0,0 +1,234 @@ +#include "rms_norm_opencl.h" +#include "../../../../infinirt/opencl/infinirt_opencl.h" +#include "../../../devices/opencl/opencl_common.h" +#include "../../../devices/opencl/opencl_kernel_common.h" +#include +#include +#include + +static const char *RmsNormKernelSource = R"CLC( +#define CL_TARGET_OPENCL_VERSION 200 +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#ifndef Ta +#define Ta float +#endif + +#ifndef Tw +#define Tw float +#endif + +#ifndef Tcompute +#define Tcompute float +#endif + +#ifndef ITEMS_THREAD +#define ITEMS_THREAD 1 +#endif + +typedef unsigned int Tidx; + +kernel void rms_norm( + global Ta *y_, + int const s_y_batch, + int const s_y_nhead, + global Ta const *x_, + int const s_x_batch, + int const s_x_nhead, + global Tw const *w, + float const epsilon, + Tidx const nhead, + Tidx const d) { + + Tidx g_idx = get_group_id(0), + 
l_idx = get_local_id(0), + l_len = get_local_size(0); + Tidx batch_id = g_idx / nhead, + nhead_id = g_idx % nhead; + global Ta + *y = y_ + batch_id * s_y_batch + nhead_id * s_y_nhead; + global Ta const + *x = x_ + batch_id * s_x_batch + nhead_id * s_x_nhead; + + Tcompute val_x[ITEMS_THREAD]; + Tcompute val_w[ITEMS_THREAD]; + Tcompute squared = 0; + for (Tidx i = 0, idx = l_idx; idx < d; ++i, idx += l_len) { + val_x[i] = (Tcompute)x[idx]; + val_w[i] = (Tcompute)w[idx]; + squared += val_x[i] * val_x[i]; + } + // TODO:测试加载相邻元素处理; + Tcompute mean_sq = work_group_reduce_add(squared) / (Tcompute)d; + Tcompute rms = native_rsqrt(mean_sq + (Tcompute)epsilon); + + for (Tidx i = 0, idx = l_idx; idx < d; ++i, idx += l_len) + y[idx] = (Ta)(rms * val_x[i] * val_w[i]); +} +)CLC"; + +namespace op::rms_norm::opencl { + +using namespace device::opencl::kernel; + +struct Descriptor::Opaque { + std::shared_ptr internal; +}; + +Descriptor::~Descriptor() { + delete _opaque; +} + +infiniStatus_t Descriptor::create( + infiniopHandle_t handle, + Descriptor **desc_ptr, + infiniopTensorDescriptor_t y_desc, + infiniopTensorDescriptor_t x_desc, + infiniopTensorDescriptor_t w_desc, + float epsilon) { + auto result = RMSNormInfo::create(y_desc, x_desc, w_desc, epsilon); + CHECK_RESULT(result); + auto info = result.take(); + + *desc_ptr = new Descriptor( + new Opaque{reinterpret_cast(handle)->internal()}, + std::move(info), + 0, + handle->device, handle->device_id); + return INFINI_STATUS_SUCCESS; +} + +// launch kernel +infiniStatus_t launchKernel( + uint32_t batch_size, size_t nhead, size_t dim, + void *y, infiniDtype_t atype, ptrdiff_t stride_y_batch, ptrdiff_t stride_y_nhead, + const void *x, ptrdiff_t stride_x_batch, ptrdiff_t stride_x_nhead, + const void *w, infiniDtype_t wtype, + float epsilon, + size_t block_size, + cl_context context, + cl_device_id device, + cl_command_queue cl_queue, + cl_program program) { + + cl_int clerr; + + cl_kernel kernel = clCreateKernel(program, "rms_norm", 
&clerr); + if (clerr != CL_SUCCESS || kernel == nullptr) { + return INFINI_STATUS_INTERNAL_ERROR; + } + + int arg_idx = 0; + void *y_svm = NULL; + clerr = clSetKernelArgSVMPointer(kernel, arg_idx++, y); + if (clerr != CL_SUCCESS) { // for python test + infinirtMalloc(&y_svm, ((batch_size - 1) * stride_y_batch + (nhead - 1) * stride_y_nhead + dim) * infiniSizeOf(atype)); + infinirtMemcpy(y_svm, y, ((batch_size - 1) * stride_y_batch + (nhead - 1) * stride_y_nhead + dim) * infiniSizeOf(atype), INFINIRT_MEMCPY_H2D); + arg_idx -= 1; + clerr = clSetKernelArgSVMPointer(kernel, arg_idx++, y_svm); + } + cl_int s_y_batch = static_cast(stride_y_batch); + clerr |= clSetKernelArg(kernel, arg_idx++, sizeof(cl_int), &s_y_batch); + cl_int s_y_nhead = static_cast(stride_y_nhead); + clerr |= clSetKernelArg(kernel, arg_idx++, sizeof(cl_int), &s_y_nhead); + clerr |= clSetKernelArgSVMPointer(kernel, arg_idx++, x); + if (clerr != CL_SUCCESS) { // for python test + void *x_svm = NULL; + infinirtMalloc(&x_svm, ((batch_size - 1) * stride_x_batch + (nhead - 1) * stride_x_nhead + dim) * infiniSizeOf(atype)); + infinirtMemcpy(x_svm, x, ((batch_size - 1) * stride_x_batch + (nhead - 1) * stride_x_nhead + dim) * infiniSizeOf(atype), INFINIRT_MEMCPY_H2D); + arg_idx -= 1; + clerr = clSetKernelArgSVMPointer(kernel, arg_idx++, x_svm); + } + cl_int s_x_batch = static_cast(stride_x_batch); + clerr |= clSetKernelArg(kernel, arg_idx++, sizeof(cl_int), &s_x_batch); + cl_int s_x_nhead = static_cast(stride_x_nhead); + clerr |= clSetKernelArg(kernel, arg_idx++, sizeof(cl_int), &s_x_nhead); + clerr |= clSetKernelArgSVMPointer(kernel, arg_idx++, w); + if (clerr != CL_SUCCESS) { // for python test + void *w_svm = NULL; + infinirtMalloc(&w_svm, dim * infiniSizeOf(wtype)); + infinirtMemcpy(w_svm, w, dim * infiniSizeOf(wtype), INFINIRT_MEMCPY_H2D); + arg_idx -= 1; + clerr = clSetKernelArgSVMPointer(kernel, arg_idx++, w_svm); + } + clerr |= clSetKernelArg(kernel, arg_idx++, sizeof(float), &epsilon); + clerr |= 
clSetKernelArg(kernel, arg_idx++, sizeof(int), &nhead); + clerr |= clSetKernelArg(kernel, arg_idx++, sizeof(int), &dim); + + size_t global_size = batch_size * nhead * block_size; + + clerr = clEnqueueNDRangeKernel(cl_queue, kernel, 1, nullptr, &global_size, &block_size, 0, nullptr, nullptr); + if (clerr != CL_SUCCESS) { + fprintf(stderr, "[OpenCL] clEnqueueNDRangeKernel failed: %s (%d)\n", clErrorString(clerr), clerr); + fprintf(stderr, " global_size: %zu, local_size: %zu\n", global_size, block_size); + clReleaseKernel(kernel); + return INFINI_STATUS_INTERNAL_ERROR; + } + if (y_svm) { // for python test + infinirtMemcpy(y, y_svm, ((batch_size - 1) * stride_y_batch + (nhead - 1) * stride_y_nhead + dim) * infiniSizeOf(atype), INFINIRT_MEMCPY_D2H); + } + + clReleaseKernel(kernel); + return INFINI_STATUS_SUCCESS; +} + +infiniStatus_t Descriptor::calculate( + void *workspace, size_t workspace_size, + void *y, const void *x, const void *w, + void *stream) const { + + if (workspace_size < _workspace_size) { + return INFINI_STATUS_INSUFFICIENT_WORKSPACE; + } + + auto stride_x_batch = _info.x_strides[0]; + auto stride_x_nhead = _info.x_strides[1]; + auto stride_y_batch = _info.y_strides[0]; + auto stride_y_nhead = _info.y_strides[1]; + auto dim = _info.dim(); + uint32_t batch_size = static_cast(_info.shape[0]); + size_t nhead = _info.shape.size() > 2 ? 
_info.shape[1] : 1; + size_t block_size = _opaque->internal->maxThreadsPerBlock(); + + void *device; + void *context; + + CHECK_STATUS(infinirtGetOpenclDevice(&device)); + CHECK_STATUS(infinirtGetOpenclContext(&context)); + cl_context clcontext = static_cast(context); + cl_device_id cldevice = static_cast(device); + if (!stream) { + CHECK_STATUS(infinirtGetOpenclStream(&stream)); + } + cl_command_queue clqueue = static_cast(stream); + + std::string dt_a, dt_w, dt_compute; + dt_compute = "float"; + if (!dtypeToClType(_info.atype, dt_a)) { + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + if (!dtypeToClType(_info.wtype, dt_w)) { + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + + size_t items_perthread = (dim + block_size - 1) / block_size; + + // build options + std::string build_opts; + build_opts += "-D Ta=" + dt_a + " "; + build_opts += "-D Tw=" + dt_w + " "; + build_opts += "-D Tc=" + dt_compute + " "; + build_opts += "-D ITEMS_THREAD=" + std::to_string(items_perthread) + " "; + build_opts += "-cl-std=CL2.0 "; + + auto prog_shared = this->_opaque->internal->programCache()->getOrBuildWithSource("rms_norm", RmsNormKernelSource, build_opts, clcontext, cldevice); + if (!prog_shared) { + return INFINI_STATUS_INTERNAL_ERROR; + } + cl_program clprogram = reinterpret_cast(prog_shared.get()); + + CHECK_STATUS(launchKernel(batch_size, nhead, dim, y, _info.atype, stride_y_batch, stride_y_nhead, x, stride_x_batch, stride_x_nhead, w, _info.wtype, _info.epsilon, block_size, clcontext, cldevice, clqueue, clprogram)); + return INFINI_STATUS_SUCCESS; +} + +} // namespace op::rms_norm::opencl diff --git a/src/infiniop/ops/rms_norm/opencl/rms_norm_opencl.h b/src/infiniop/ops/rms_norm/opencl/rms_norm_opencl.h new file mode 100644 index 000000000..bac27e825 --- /dev/null +++ b/src/infiniop/ops/rms_norm/opencl/rms_norm_opencl.h @@ -0,0 +1,8 @@ +#ifndef __RMS_NORM_OPENCL_H__ +#define __RMS_NORM_OPENCL_H__ + +#include "../rms_norm.h" + +DESCRIPTOR(opencl) + +#endif diff --git 
a/src/infiniop/ops/rms_norm/operator.cc b/src/infiniop/ops/rms_norm/operator.cc index 4311f516a..f75d267ca 100644 --- a/src/infiniop/ops/rms_norm/operator.cc +++ b/src/infiniop/ops/rms_norm/operator.cc @@ -23,6 +23,9 @@ #ifdef ENABLE_KUNLUN_API #include "kunlun/rms_norm_kunlun.h" #endif +#ifdef ENABLE_OPENCL_API +#include "opencl/rms_norm_opencl.h" +#endif __C infiniStatus_t infiniopCreateRMSNormDescriptor( infiniopHandle_t handle, @@ -72,6 +75,9 @@ __C infiniStatus_t infiniopCreateRMSNormDescriptor( #endif #ifdef ENABLE_MOORE_API CREATE(INFINI_DEVICE_MOORE, moore); +#endif +#ifdef ENABLE_OPENCL_API + CREATE(INFINI_DEVICE_OPENCL, opencl); #endif } @@ -117,6 +123,9 @@ __C infiniStatus_t infiniopGetRMSNormWorkspaceSize(infiniopRMSNormDescriptor_t d #endif #ifdef ENABLE_MOORE_API GET(INFINI_DEVICE_MOORE, moore); +#endif +#ifdef ENABLE_OPENCL_API + GET(INFINI_DEVICE_OPENCL, opencl); #endif } @@ -163,6 +172,9 @@ __C infiniStatus_t infiniopRMSNorm(infiniopRMSNormDescriptor_t desc, void *works #endif #ifdef ENABLE_MOORE_API CALCULATE(INFINI_DEVICE_MOORE, moore); +#endif +#ifdef ENABLE_OPENCL_API + CALCULATE(INFINI_DEVICE_OPENCL, opencl); #endif } @@ -208,6 +220,9 @@ __C infiniStatus_t infiniopDestroyRMSNormDescriptor(infiniopRMSNormDescriptor_t #endif #ifdef ENABLE_MOORE_API DESTROY(INFINI_DEVICE_MOORE, moore); +#endif +#ifdef ENABLE_OPENCL_API + DESTROY(INFINI_DEVICE_OPENCL, opencl); #endif } diff --git a/src/infiniop/ops/rope/opencl/rope_opencl.cc b/src/infiniop/ops/rope/opencl/rope_opencl.cc new file mode 100644 index 000000000..7433530fc --- /dev/null +++ b/src/infiniop/ops/rope/opencl/rope_opencl.cc @@ -0,0 +1,284 @@ +#include "rope_opencl.h" +#include "../../../../infinirt/opencl/infinirt_opencl.h" +#include "../../../devices/opencl/opencl_common.h" +#include +#include +#include + +static const char *RopeKernelSource = R"CLC( +#define CL_TARGET_OPENCL_VERSION 200 +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#ifndef Tval +#define Tval float2 +#endif + +#ifndef 
Tpos +#define Tpos unsigned int +#endif + +#ifndef Ttable +#define Ttable float +#endif + +#ifdef USE_HALF +#define LOAD_DATA(ptr) vload_half2(0, (__global half *) ptr) +#define STORE_DATA(ptr, val) vstore_half2(val, 0, (__global half *) ptr) +#else +#define LOAD_DATA(ptr) (*ptr) +#define STORE_DATA(ptr, val) (*ptr = val) +#endif + +typedef unsigned int Tidx; + +__kernel void rope( + __global Tval *t, + int const ystride_token, + int const ystride_head, + __global Tval *x, + int const xstride_token, + int const xstride_head, + __global Tpos const *pos, + __global Ttable const *sin_table, + __global Ttable const *cos_table, + float const theta) { + + Tidx nh_l = get_local_size(0), + dh = get_local_size(1), + it = get_group_id(0), + ih_h = get_group_id(1), + ih_l = get_local_id(0), + ih = ih_h * nh_l + ih_l, + i = get_local_id(1); + + __global Tval *t2 = t + it * ystride_token + ih * ystride_head + i; + __global Tval *x2 = x + it * xstride_token + ih * xstride_head + i; + + float2 data = LOAD_DATA(x2); + + int index = pos[it] * dh + i; // 防越界 + float sin_val = sin_table[index]; + float cos_val = cos_table[index]; + + float2 result; + result.x = data.x * cos_val - data.y * sin_val; + result.y = data.x * sin_val + data.y * cos_val; + STORE_DATA(t2, result); +} + +)CLC"; + +namespace op::rope::opencl { + +using namespace device::opencl::kernel; + +struct Descriptor::Opaque { + std::shared_ptr internal; +}; + +Descriptor::~Descriptor() { + delete _opaque; +} + +infiniStatus_t Descriptor::create( + infiniopHandle_t handle_, + Descriptor **desc_ptr, + infiniopTensorDescriptor_t y_desc, + infiniopTensorDescriptor_t x_desc, + infiniopTensorDescriptor_t pos_desc, + infiniopTensorDescriptor_t sin_desc, + infiniopTensorDescriptor_t cos_desc, + infiniopRoPEAlgo_t algo) { + + auto handle = reinterpret_cast(handle_); + + auto info = RoPEInfo::createRoPEInfo(y_desc, x_desc, pos_desc, sin_desc, cos_desc, algo); + CHECK_RESULT(info); + + *desc_ptr = new Descriptor( + info.take(), + 
0, + new Opaque{reinterpret_cast(handle)->internal()}, + handle->device, + handle->device_id); + + return INFINI_STATUS_SUCCESS; +} + +infiniStatus_t launchKernel( + size_t dimx, size_t dimy, size_t table_dim, + void *y, infiniDtype_t tdata, ptrdiff_t y_stride_seqlen, ptrdiff_t y_stride_nhead, + const void *x, ptrdiff_t x_stride_seqlen, ptrdiff_t x_stride_nhead, + infiniDtype_t tpos, const void *pos_ids, const void *sin_table, const void *cos_table, + size_t block_size, + cl_context context, + cl_device_id device, + cl_command_queue cl_queue, + cl_program program) { + + cl_int clerr; + cl_kernel kernel = clCreateKernel(program, "rope", &clerr); + if (clerr != CL_SUCCESS || kernel == nullptr) { + return INFINI_STATUS_INTERNAL_ERROR; + } + + int arg_idx = 0; + void *y_svm = NULL; + clerr = clSetKernelArgSVMPointer(kernel, arg_idx++, y); + if (clerr != CL_SUCCESS) { // for python test + infinirtMalloc(&y_svm, ((dimx - 1) * y_stride_seqlen + (dimy - 1) * y_stride_nhead + table_dim) * dtypeSize(tdata)); + infinirtMemcpy(y_svm, y, ((dimx - 1) * y_stride_seqlen + (dimy - 1) * y_stride_nhead + table_dim) * dtypeSize(tdata), INFINIRT_MEMCPY_H2D); + arg_idx -= 1; + clerr = clSetKernelArgSVMPointer(kernel, arg_idx++, y_svm); + } + cl_int s_y_batch = static_cast(y_stride_seqlen) / 2; + clerr |= clSetKernelArg(kernel, arg_idx++, sizeof(cl_int), &s_y_batch); + cl_int s_y_nhead = static_cast(y_stride_nhead) / 2; + clerr |= clSetKernelArg(kernel, arg_idx++, sizeof(cl_int), &s_y_nhead); + clerr |= clSetKernelArgSVMPointer(kernel, arg_idx++, x); + if (clerr != CL_SUCCESS) { // for python test + void *x_svm = NULL; + infinirtMalloc(&x_svm, ((dimx - 1) * x_stride_seqlen + (dimy - 1) * x_stride_nhead + table_dim * 2) * dtypeSize(tdata)); + infinirtMemcpy(x_svm, x, ((dimx - 1) * x_stride_seqlen + (dimy - 1) * x_stride_nhead + table_dim * 2) * dtypeSize(tdata), INFINIRT_MEMCPY_H2D); + arg_idx -= 1; + clerr = clSetKernelArgSVMPointer(kernel, arg_idx++, x_svm); + } + cl_int s_x_batch = 
static_cast(x_stride_seqlen) / 2; + clerr |= clSetKernelArg(kernel, arg_idx++, sizeof(cl_int), &s_x_batch); + cl_int s_x_nhead = static_cast(x_stride_nhead) / 2; + clerr |= clSetKernelArg(kernel, arg_idx++, sizeof(cl_int), &s_x_nhead); + + clerr |= clSetKernelArgSVMPointer(kernel, arg_idx++, pos_ids); + if (clerr != CL_SUCCESS) { // for python test + void *pos_svm = NULL; + infinirtMalloc(&pos_svm, dimx * dtypeSize(tpos)); + infinirtMemcpy(pos_svm, pos_ids, dimx * dtypeSize(tpos), INFINIRT_MEMCPY_H2D); + arg_idx -= 1; + clerr = clSetKernelArgSVMPointer(kernel, arg_idx++, pos_svm); + } + clerr |= clSetKernelArgSVMPointer(kernel, arg_idx++, sin_table); + if (clerr != CL_SUCCESS) { // for python test + void *sin_svm = NULL; + infinirtMalloc(&sin_svm, dimx * table_dim * dtypeSize(tdata)); + infinirtMemcpy(sin_svm, sin_table, dimx * table_dim * dtypeSize(tdata), INFINIRT_MEMCPY_H2D); + arg_idx -= 1; + clerr = clSetKernelArgSVMPointer(kernel, arg_idx++, sin_svm); + } + clerr |= clSetKernelArgSVMPointer(kernel, arg_idx++, cos_table); + if (clerr != CL_SUCCESS) { // for python test + void *cos_svm = NULL; + infinirtMalloc(&cos_svm, dimx * table_dim * dtypeSize(tdata)); + infinirtMemcpy(cos_svm, cos_table, dimx * table_dim * dtypeSize(tdata), INFINIRT_MEMCPY_H2D); + arg_idx -= 1; + clerr = clSetKernelArgSVMPointer(kernel, arg_idx++, cos_svm); + } + clerr |= clSetKernelArg(kernel, arg_idx++, sizeof(int), &table_dim); + + if (block_size % table_dim != 0) { + return INFINI_STATUS_BAD_TENSOR_SHAPE; + } + int max_nh_l = std::min(block_size / table_dim, dimy); + int nh_l = 1; + for (int candidate = max_nh_l; candidate >= 1; --candidate) { + if (dimy % candidate == 0) { + nh_l = candidate; + break; + } + } + int nh_h = dimy / nh_l; + size_t global_size[2] = {dimx * nh_l, nh_h * table_dim}; + size_t local_size[2] = {nh_l, table_dim}; + + clerr = clEnqueueNDRangeKernel(cl_queue, kernel, 2, nullptr, global_size, local_size, 0, nullptr, nullptr); + if (clerr != CL_SUCCESS) { + 
fprintf(stderr, "[OpenCL] clEnqueueNDRangeKernel failed: %s (%d)\n", clErrorString(clerr), clerr); + fprintf(stderr, " global_size: %zu, local_size: %zu\n", global_size[0], local_size[0]); + clReleaseKernel(kernel); + + return INFINI_STATUS_INTERNAL_ERROR; + } + if (y_svm) { // for python test + infinirtMemcpy(y, y_svm, ((dimx - 1) * y_stride_seqlen + (dimy - 1) * y_stride_nhead + table_dim * 2) * dtypeSize(tdata), INFINIRT_MEMCPY_D2H); + } + + clReleaseKernel(kernel); + return INFINI_STATUS_SUCCESS; +} + +infiniStatus_t Descriptor::calculate( + void *workspace, + size_t workspace_size, + void *y, + const void *x, + const void *pos_ids, + const void *sin_table, + const void *cos_table, + void *stream) const { + + if (workspace_size < _workspace_size) { + return INFINI_STATUS_INSUFFICIENT_WORKSPACE; + } + + auto y_stride_seqlen = _info.y_stride_seqlen; + auto y_stride_nhead = _info.y_stride_nhead; + auto x_stride_seqlen = _info.x_stride_seqlen; + auto x_stride_nhead = _info.x_stride_nhead; + auto table_dim = _info.table_dim; + + auto dimx = uint32_t(_info.seqlen), + dimy = uint32_t(_info.nhead); + size_t block_size = _opaque->internal->maxThreadsPerBlock(); + auto tdata = _info.data_type; + auto tpos = _info.pos_type; + + void *device; + void *context; + CHECK_STATUS(infinirtGetOpenclDevice(&device)); + CHECK_STATUS(infinirtGetOpenclContext(&context)); + cl_context clcontext = static_cast(context); + cl_device_id cldevice = static_cast(device); + if (!stream) { + CHECK_STATUS(infinirtGetOpenclStream(&stream)); + } + cl_command_queue clqueue = static_cast(stream); + + std::string dt_val, dt_pos, dt_table; + if (tdata == INFINI_DTYPE_F16) { + dt_val = "half2"; + dt_table = "half"; + } else if (tdata == INFINI_DTYPE_F32) { + dt_val = "float2"; + dt_table = "float"; + } else { + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + if (tpos == INFINI_DTYPE_I32) { + dt_pos = "int"; + } else if (tpos == INFINI_DTYPE_U32) { + dt_pos = "uint"; + } else { + return 
INFINI_STATUS_BAD_TENSOR_DTYPE; + } + + // build options + std::string build_opts; + build_opts += "-D Tval=" + dt_val + " "; + build_opts += "-D Tpos=" + dt_pos + " "; + build_opts += "-D Ttable=" + dt_table + " "; + if (tdata == INFINI_DTYPE_F16) { + build_opts += "-D USE_HALF=1 "; + } + build_opts += "-cl-std=CL2.0 "; + + auto prog_shared = this->_opaque->internal->programCache()->getOrBuildWithSource("rope", RopeKernelSource, build_opts, clcontext, cldevice); + if (!prog_shared) { + return INFINI_STATUS_INTERNAL_ERROR; + } + cl_program clprogram = reinterpret_cast(prog_shared.get()); + + CHECK_STATUS(launchKernel(dimx, dimy, table_dim, y, tdata, y_stride_seqlen, y_stride_nhead, + x, x_stride_seqlen, x_stride_nhead, tpos, pos_ids, sin_table, cos_table, block_size, clcontext, cldevice, clqueue, clprogram)); + return INFINI_STATUS_SUCCESS; +} + +} // namespace op::rope::opencl diff --git a/src/infiniop/ops/rope/opencl/rope_opencl.h b/src/infiniop/ops/rope/opencl/rope_opencl.h new file mode 100644 index 000000000..f92ae1d65 --- /dev/null +++ b/src/infiniop/ops/rope/opencl/rope_opencl.h @@ -0,0 +1,8 @@ +#ifndef __ROPE_OPENCL_H__ +#define __ROPE_OPENCL_H__ + +#include "../rope.h" + +DESCRIPTOR(opencl) + +#endif diff --git a/src/infiniop/ops/rope/operator.cc b/src/infiniop/ops/rope/operator.cc index d24ec4090..99971723b 100644 --- a/src/infiniop/ops/rope/operator.cc +++ b/src/infiniop/ops/rope/operator.cc @@ -23,6 +23,9 @@ #ifdef ENABLE_MOORE_API #include "moore/rope_moore.h" #endif +#ifdef ENABLE_OPENCL_API +#include "opencl/rope_opencl.h" +#endif __C infiniStatus_t infiniopCreateRoPEDescriptor( infiniopHandle_t handle, @@ -76,6 +79,9 @@ __C infiniStatus_t infiniopCreateRoPEDescriptor( #endif #ifdef ENABLE_CAMBRICON_API CREATE(INFINI_DEVICE_CAMBRICON, bang); +#endif +#ifdef ENABLE_OPENCL_API + CREATE(INFINI_DEVICE_OPENCL, opencl); #endif } @@ -121,6 +127,9 @@ __C infiniStatus_t infiniopGetRoPEWorkspaceSize(infiniopRoPEDescriptor_t desc, #endif #ifdef 
ENABLE_ASCEND_API GET(INFINI_DEVICE_ASCEND, ascend); +#endif +#ifdef ENABLE_OPENCL_API + GET(INFINI_DEVICE_OPENCL, opencl); #endif } @@ -175,6 +184,9 @@ __C infiniStatus_t infiniopRoPE( #endif #ifdef ENABLE_ASCEND_API CALCULATE(INFINI_DEVICE_ASCEND, ascend); +#endif +#ifdef ENABLE_OPENCL_API + CALCULATE(INFINI_DEVICE_OPENCL, opencl); #endif } @@ -221,6 +233,9 @@ infiniopDestroyRoPEDescriptor(infiniopRoPEDescriptor_t desc) { #endif #ifdef ENABLE_ASCEND_API DELETE(INFINI_DEVICE_ASCEND, ascend); +#endif +#ifdef ENABLE_OPENCL_API + DELETE(INFINI_DEVICE_OPENCL, opencl); #endif } diff --git a/src/infiniop/ops/swiglu/opencl/swiglu_opencl.cc b/src/infiniop/ops/swiglu/opencl/swiglu_opencl.cc new file mode 100644 index 000000000..4f87fc606 --- /dev/null +++ b/src/infiniop/ops/swiglu/opencl/swiglu_opencl.cc @@ -0,0 +1,219 @@ +#include "swiglu_opencl.h" +#include "../../../../infinirt/opencl/infinirt_opencl.h" +#include "../../../devices/opencl/opencl_common.h" +#include +#include +#include +#include + +static const char *SwigluKernelSource = R"CLC( +#define CL_TARGET_OPENCL_VERSION 200 +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#ifndef Tval +#define Tval float +#endif + +typedef unsigned int Tidx; + +__kernel void swiglu( + __global Tval *c, + int const stride_c_b, + int const stride_c_s, + __global Tval *gate, + int const stride_gate_b, + int const stride_gate_s, + __global Tval *up, + int const stride_up_b, + int const stride_up_s, + int const seq) { + + Tidx g_idx = get_global_id(0); + Tidx g_idy = get_global_id(1); + Tidx g_idx_b = get_global_id(0) / seq; + Tidx g_idx_s = get_global_id(0) % seq; + + Tidx k = g_idx_b * stride_c_b + g_idx_s * stride_c_s + g_idy; + Tidx i = g_idx_b * stride_gate_b + g_idx_s * stride_gate_s + g_idy; + Tidx j = g_idx_b * stride_up_b + g_idx_s * stride_up_s + g_idy; + + Tval x = gate[i]; + Tval y = up[j]; + + Tval sig = 1.0f / (1.0f + exp(-x)); + c[k] = x * sig * y; +} + +)CLC"; + +namespace op::swiglu::opencl { + +using namespace 
device::opencl::kernel; + +struct Descriptor::Opaque { + std::shared_ptr internal; +}; + +Descriptor::~Descriptor() { + delete _opaque; +} + +infiniStatus_t Descriptor::create(infiniopHandle_t handle_, Descriptor **desc_ptr, + infiniopTensorDescriptor_t c_desc, + std::vector input_descs) { + auto handle = reinterpret_cast(handle_); + + auto dtype = c_desc->dtype(); + CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32); + + const auto &a_desc = input_descs[0]; + const auto &b_desc = input_descs[1]; + + auto info = SwigluInfo::create(c_desc, a_desc, b_desc); + CHECK_RESULT(info); + + size_t workspace_size = 0; + + *desc_ptr = new Descriptor( + info.take(), + workspace_size, + new Opaque{reinterpret_cast(handle)->internal()}, + handle->device, + handle->device_id); + return INFINI_STATUS_SUCCESS; +} + +infiniStatus_t launchKernel( + void *c, void *a, void *b, + infiniDtype_t dtype, size_t batch, size_t seq, size_t hd, + ptrdiff_t stride_batch_c, ptrdiff_t stride_batch_a, ptrdiff_t stride_batch_b, + ptrdiff_t stride_seq_c, ptrdiff_t stride_seq_a, ptrdiff_t stride_seq_b, + size_t block_size, + cl_context context, + cl_device_id device, + cl_command_queue cl_queue, + cl_program program) { + + cl_int clerr; + cl_kernel kernel = clCreateKernel(program, "swiglu", &clerr); + + if (clerr != CL_SUCCESS || kernel == nullptr) { + + return INFINI_STATUS_INTERNAL_ERROR; + } + + int arg_idx = 0; + void *c_svm = NULL; + clerr = clSetKernelArgSVMPointer(kernel, arg_idx++, c); + if (clerr != CL_SUCCESS) { // for python test + infinirtMalloc(&c_svm, ((batch - 1) * stride_batch_c + (seq - 1) * stride_seq_c + hd) * dtypeSize(dtype)); + infinirtMemcpy(c_svm, c, ((batch - 1) * stride_batch_c + (seq - 1) * stride_seq_c + hd) * dtypeSize(dtype), INFINIRT_MEMCPY_H2D); + arg_idx -= 1; + clerr = clSetKernelArgSVMPointer(kernel, arg_idx++, c_svm); + } + cl_int s_c_batch = static_cast(stride_batch_c); + clerr |= clSetKernelArg(kernel, arg_idx++, sizeof(cl_int), &s_c_batch); + cl_int s_c_seq = 
static_cast(stride_seq_c); + clerr |= clSetKernelArg(kernel, arg_idx++, sizeof(cl_int), &s_c_seq); + + void *b_svm = NULL; + clerr = clSetKernelArgSVMPointer(kernel, arg_idx++, b); + if (clerr != CL_SUCCESS) { // for python test + infinirtMalloc(&b_svm, ((batch - 1) * stride_batch_b + (seq - 1) * stride_seq_b + hd) * dtypeSize(dtype)); + infinirtMemcpy(b_svm, b, ((batch - 1) * stride_batch_b + (seq - 1) * stride_seq_b + hd) * dtypeSize(dtype), INFINIRT_MEMCPY_H2D); + arg_idx -= 1; + clerr = clSetKernelArgSVMPointer(kernel, arg_idx++, b_svm); + } + cl_int s_b_batch = static_cast(stride_batch_b); + clerr |= clSetKernelArg(kernel, arg_idx++, sizeof(cl_int), &s_b_batch); + cl_int s_b_seq = static_cast(stride_seq_b); + clerr |= clSetKernelArg(kernel, arg_idx++, sizeof(cl_int), &s_b_seq); + + void *a_svm = NULL; + clerr = clSetKernelArgSVMPointer(kernel, arg_idx++, a); + if (clerr != CL_SUCCESS) { // for python test + infinirtMalloc(&a_svm, ((batch - 1) * stride_batch_a + (seq - 1) * stride_seq_a + hd) * dtypeSize(dtype)); + infinirtMemcpy(a_svm, a, ((batch - 1) * stride_batch_a + (seq - 1) * stride_seq_a + hd) * dtypeSize(dtype), INFINIRT_MEMCPY_H2D); + arg_idx -= 1; + clerr = clSetKernelArgSVMPointer(kernel, arg_idx++, a_svm); + } + cl_int s_a_batch = static_cast(stride_batch_a); + clerr |= clSetKernelArg(kernel, arg_idx++, sizeof(cl_int), &s_a_batch); + cl_int s_a_seq = static_cast(stride_seq_a); + clerr |= clSetKernelArg(kernel, arg_idx++, sizeof(cl_int), &s_a_seq); + + clerr |= clSetKernelArg(kernel, arg_idx++, sizeof(cl_int), &seq); + + size_t global_size[2] = {batch * seq, hd}; + size_t local_size[2] = {1, block_size}; + + clerr = clEnqueueNDRangeKernel(cl_queue, kernel, 2, nullptr, global_size, local_size, 0, nullptr, nullptr); + if (clerr != CL_SUCCESS) { + fprintf(stderr, "[OpenCL] clEnqueueNDRangeKernel failed: %s (%d)\n", clErrorString(clerr), clerr); + fprintf(stderr, " global_size: %zu, local_size: %zu\n", global_size[0], local_size[0]); + 
clReleaseKernel(kernel); + return INFINI_STATUS_INTERNAL_ERROR; + } + if (c_svm) { // for python test + infinirtMemcpy(c, c_svm, ((batch - 1) * stride_batch_c + (seq - 1) * stride_seq_c + hd) * dtypeSize(dtype), INFINIRT_MEMCPY_D2H); + } + + clReleaseKernel(kernel); + + return INFINI_STATUS_SUCCESS; +} + +infiniStatus_t Descriptor::calculate(void *workspace, + size_t workspace_size, + void *c, + std::vector inputs, + void *stream) const { + auto batch = _info.ndim == 2 ? 1 : _info.shape[0]; + auto seq_len = _info.ndim == 2 ? _info.shape[0] : _info.shape[1]; + auto hidden_size = _info.shape[_info.ndim - 1]; + auto stride_batch_c = _info.ndim == 2 ? 1 : _info.c_strides[0]; + auto stride_batch_a = _info.ndim == 2 ? 1 : _info.a_strides[0]; + auto stride_batch_b = _info.ndim == 2 ? 1 : _info.b_strides[0]; + auto stride_seq_c = _info.ndim == 2 ? _info.c_strides[0] : _info.c_strides[1]; + auto stride_seq_a = _info.ndim == 2 ? _info.a_strides[0] : _info.a_strides[1]; + auto stride_seq_b = _info.ndim == 2 ? 
_info.b_strides[0] : _info.b_strides[1]; + size_t block_size = _opaque->internal->maxThreadsPerBlock(); + void *device; + void *context; + CHECK_STATUS(infinirtGetOpenclDevice(&device)); + CHECK_STATUS(infinirtGetOpenclContext(&context)); + cl_context clcontext = static_cast(context); + cl_device_id cldevice = static_cast(device); + if (!stream) { + CHECK_STATUS(infinirtGetOpenclStream(&stream)); + } + cl_command_queue clqueue = static_cast(stream); + + std::string dt_val; + if (_info.dtype == INFINI_DTYPE_F16) { + dt_val = "half"; + } else if (_info.dtype == INFINI_DTYPE_F32) { + dt_val = "float"; + } else { + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + + // build options + std::string build_opts; + build_opts += "-D Tval=" + dt_val + " "; + if (_info.dtype == INFINI_DTYPE_F16) { + build_opts += "-D USE_HALF=1 "; + } + build_opts += "-cl-std=CL2.0 "; + + auto prog_shared = this->_opaque->internal->programCache()->getOrBuildWithSource("swiglu", SwigluKernelSource, build_opts, clcontext, cldevice); + if (!prog_shared) { + return INFINI_STATUS_INTERNAL_ERROR; + } + cl_program clprogram = reinterpret_cast(prog_shared.get()); + + CHECK_STATUS(launchKernel(c, (void *)inputs[0], (void *)inputs[1], _info.dtype, batch, seq_len, hidden_size, stride_batch_c, stride_batch_a, stride_batch_b, stride_seq_c, stride_seq_a, stride_seq_b, block_size, clcontext, cldevice, clqueue, clprogram)); + return INFINI_STATUS_SUCCESS; +} + +} // namespace op::swiglu::opencl diff --git a/src/infiniop/ops/swiglu/opencl/swiglu_opencl.h b/src/infiniop/ops/swiglu/opencl/swiglu_opencl.h new file mode 100644 index 000000000..7ea9e13b8 --- /dev/null +++ b/src/infiniop/ops/swiglu/opencl/swiglu_opencl.h @@ -0,0 +1,89 @@ +#ifndef __SWIGLU_OPENCL_H__ +#define __SWIGLU_OPENCL_H__ + +#include "../../../../utils.h" +#include "../../../operator.h" +#include "../../../tensor.h" + +namespace op::swiglu::opencl { +class SwigluInfo { + +private: + SwigluInfo() = default; + +public: + infiniDtype_t dtype; + 
std::vector shape; + int32_t ndim; + std::vector c_strides; + std::vector a_strides; + std::vector b_strides; + + static utils::Result create(infiniopTensorDescriptor_t c_desc, infiniopTensorDescriptor_t a_desc, infiniopTensorDescriptor_t b_desc) { + CHECK_OR_RETURN(c_desc && a_desc && b_desc, INFINI_STATUS_BAD_PARAM); + CHECK_OR_RETURN(!c_desc->hasBroadcastDim(), INFINI_STATUS_BAD_TENSOR_STRIDES); + CHECK_OR_RETURN(c_desc->ndim() == a_desc->ndim() + && c_desc->ndim() == b_desc->ndim() + && (c_desc->ndim() == 2 || c_desc->ndim() == 3), + INFINI_STATUS_BAD_TENSOR_SHAPE); + CHECK_SAME_SHAPE(c_desc->shape(), a_desc->shape(), b_desc->shape()); + int32_t ndim = c_desc->ndim(); + CHECK_OR_RETURN(c_desc->stride(ndim - 1) == 1 + && a_desc->stride(ndim - 1) == 1 + && b_desc->stride(ndim - 1) == 1, + INFINI_STATUS_BAD_TENSOR_STRIDES); + CHECK_OR_RETURN(c_desc->dtype() == a_desc->dtype() + && c_desc->dtype() == b_desc->dtype(), + INFINI_STATUS_BAD_TENSOR_DTYPE); + + return utils::Result(SwigluInfo{ + c_desc->dtype(), + c_desc->shape(), + ndim, + c_desc->strides(), + a_desc->strides(), + b_desc->strides(), + }); + } +}; + +class Descriptor final : public InfiniopDescriptor { + struct Opaque; + Opaque *_opaque; + SwigluInfo _info; + size_t _workspace_size; + + Descriptor( + SwigluInfo info, + size_t workspace_size, + Opaque *opaque, + infiniDevice_t device_type, + int device_id) + : InfiniopDescriptor{device_type, device_id}, + _opaque(opaque), + _info(info), + _workspace_size(workspace_size) {} + +public: + ~Descriptor(); + static infiniStatus_t create(infiniopHandle_t handle, Descriptor **desc_ptr, + infiniopTensorDescriptor_t c_desc, + std::vector input_descs); + size_t workspaceSize() const { return _workspace_size; } + + infiniStatus_t calculate( + void *workspace, + size_t workspace_size, + void *c, + std::vector inputs, + void *stream) const; +}; + +// extern "C" infiniStatus_t swiglu_kernel_launch( +// void *c, void *a, void *b, +// infiniDtype_t dtype, size_t batch, 
size_t seq, size_t hd, +// ptrdiff_t stride_batch_c, ptrdiff_t stride_batch_a, ptrdiff_t stride_batch_b, +// ptrdiff_t stride_seq_c, ptrdiff_t stride_seq_a, ptrdiff_t stride_seq_b, void *stream); + +} // namespace op::swiglu::opencl +#endif // __SWIGLU_OPENCL_H__ diff --git a/src/infiniop/ops/swiglu/operator.cc b/src/infiniop/ops/swiglu/operator.cc index 9d8e6406a..f1b168865 100644 --- a/src/infiniop/ops/swiglu/operator.cc +++ b/src/infiniop/ops/swiglu/operator.cc @@ -23,6 +23,9 @@ #ifdef ENABLE_MOORE_API #include "moore/swiglu_moore.h" #endif +#ifdef ENABLE_OPENCL_API +#include "opencl/swiglu_opencl.h" +#endif __C infiniStatus_t infiniopCreateSwiGLUDescriptor( infiniopHandle_t handle, @@ -72,6 +75,9 @@ __C infiniStatus_t infiniopCreateSwiGLUDescriptor( #ifdef ENABLE_MOORE_API CREATE(INFINI_DEVICE_MOORE, moore); #endif +#ifdef ENABLE_OPENCL_API + CREATE(INFINI_DEVICE_OPENCL, opencl); +#endif default: return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; @@ -117,6 +123,9 @@ __C infiniStatus_t infiniopGetSwiGLUWorkspaceSize(infiniopSwiGLUDescriptor_t des #endif #ifdef ENABLE_MOORE_API GET(INFINI_DEVICE_MOORE, moore); +#endif +#ifdef ENABLE_OPENCL_API + GET(INFINI_DEVICE_OPENCL, opencl); #endif } @@ -171,6 +180,9 @@ __C infiniStatus_t infiniopSwiGLU( #ifdef ENABLE_MOORE_API CALCULATE(INFINI_DEVICE_MOORE, moore); #endif +#ifdef ENABLE_OPENCL_API + CALCULATE(INFINI_DEVICE_OPENCL, opencl); +#endif default: return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; @@ -219,6 +231,9 @@ infiniopDestroySwiGLUDescriptor(infiniopSwiGLUDescriptor_t desc) { #ifdef ENABLE_MOORE_API DELETE(INFINI_DEVICE_MOORE, moore); #endif +#ifdef ENABLE_OPENCL_API + DELETE(INFINI_DEVICE_OPENCL, opencl); +#endif default: return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; diff --git a/test/infiniop/libinfiniop/devices.py b/test/infiniop/libinfiniop/devices.py index db2e8ae4d..e6c65a9fa 100644 --- a/test/infiniop/libinfiniop/devices.py +++ b/test/infiniop/libinfiniop/devices.py @@ -9,6 +9,7 @@ class InfiniDeviceEnum: 
KUNLUN = 7 HYGON = 8 QY = 9 + OPENCL = 10 InfiniDeviceNames = { @@ -22,6 +23,7 @@ class InfiniDeviceEnum: InfiniDeviceEnum.KUNLUN: "Kunlun", InfiniDeviceEnum.HYGON: "Hygon", InfiniDeviceEnum.QY: "QY", + InfiniDeviceEnum.OPENCL: "Opencl", } # Mapping that maps InfiniDeviceEnum to torch device string @@ -36,4 +38,5 @@ class InfiniDeviceEnum: InfiniDeviceEnum.KUNLUN: "cuda", InfiniDeviceEnum.HYGON: "cuda", InfiniDeviceEnum.QY: "cuda", + InfiniDeviceEnum.OPENCL: "cpu", } diff --git a/test/infiniop/libinfiniop/utils.py b/test/infiniop/libinfiniop/utils.py index 9b2e9798b..946a218b1 100644 --- a/test/infiniop/libinfiniop/utils.py +++ b/test/infiniop/libinfiniop/utils.py @@ -433,6 +433,11 @@ def get_args(): action="store_true", help="Run HYGON DCU test", ) + parser.add_argument( + "--opencl", + action="store_true", + help="Run OPENCL test", + ) return parser.parse_args() @@ -693,6 +698,7 @@ def test_operator(device, test_func, test_cases, tensor_dtypes): to be passed to `test_func`. - tensor_dtypes (list): A list of tensor data types (e.g., `torch.float32`) to test. 
""" + LIBINFINIOP.infinirtInit() LIBINFINIOP.infinirtSetDevice(device, ctypes.c_int(0)) handle = create_handle() tensor_dtypes = filter_tensor_dtypes_by_device(device, tensor_dtypes) @@ -730,6 +736,8 @@ def get_test_devices(args): devices_to_test.append(InfiniDeviceEnum.ILUVATAR) if args.qy: devices_to_test.append(InfiniDeviceEnum.QY) + if args.opencl: + devices_to_test.append(InfiniDeviceEnum.OPENCL) if args.cambricon: import torch_mlu @@ -768,6 +776,12 @@ def get_sync_func(device): if device == InfiniDeviceEnum.CPU or device == InfiniDeviceEnum.CAMBRICON: sync = None + elif device == InfiniDeviceEnum.OPENCL: + + def opencl_sync(): + LIBINFINIOP.infinirtDeviceSynchronize() + + return opencl_sync else: sync = getattr(torch, torch_device_map[device]).synchronize diff --git a/xmake.lua b/xmake.lua index 32ce49426..7e5d10172 100644 --- a/xmake.lua +++ b/xmake.lua @@ -308,6 +308,9 @@ target("infiniop") if has_config("hygon-dcu") then add_deps("infiniop-hygon") end + if has_config("opencl") then + add_deps("infiniop-opencl") + end set_languages("cxx17") add_files("src/infiniop/devices/handle.cc") add_files("src/infiniop/ops/*/operator.cc") diff --git a/xmake/opencl.lua b/xmake/opencl.lua index 255881aa3..dcf1cc104 100644 --- a/xmake/opencl.lua +++ b/xmake/opencl.lua @@ -5,6 +5,26 @@ if not (OPENCL_HEADERS and OPENCL_LIB) then raise("Please set OPENCL_HEADERS and OPENCL_LIB environment variables") end +target("infiniop-opencl") + set_kind("static") + add_deps("infini-utils") + on_install(function (target) end) + set_languages("cxx17") + + on_load(function (target) + target:add("includedirs", OPENCL_HEADERS) + target:add("linkdirs", OPENCL_LIB) + target:add("links", "OpenCL") + end) + + if not is_plat("windows") then + add_cxflags("-fPIC") + end + + add_files("../src/infiniop/devices/opencl/*.cc", + "../src/infiniop/ops/*/opencl/*.cc") +target_end() + target("infinirt-opencl") set_kind("static") add_deps("infini-utils")