diff --git a/include/infinicore/ops.hpp b/include/infinicore/ops.hpp index 0937a4821..aa92fe58b 100644 --- a/include/infinicore/ops.hpp +++ b/include/infinicore/ops.hpp @@ -1,6 +1,7 @@ #pragma once #include "ops/add.hpp" +#include "ops/add_rms_norm.hpp" #include "ops/attention.hpp" #include "ops/causal_softmax.hpp" #include "ops/matmul.hpp" diff --git a/include/infinicore/ops/add_rms_norm.hpp b/include/infinicore/ops/add_rms_norm.hpp new file mode 100644 index 000000000..e8a955a3c --- /dev/null +++ b/include/infinicore/ops/add_rms_norm.hpp @@ -0,0 +1,20 @@ +#pragma once + +#include "../device.hpp" +#include "common/op.hpp" +#include + +namespace infinicore::op { +class AddRMSNorm { +public: + using schema = void (*)(Tensor, Tensor, Tensor, Tensor, Tensor, float); + static void execute(Tensor y, Tensor residual_out, Tensor a, Tensor b, Tensor weight, float epsilon = 1e-5f); + static common::OpDispatcher &dispatcher(); +}; + +// Fused Add and RMS Normalization +// Returns: (normalized_result, add_result) +// The add_result can be used as residual for subsequent layers +std::pair add_rms_norm(Tensor a, Tensor b, Tensor weight, float epsilon = 1e-5f); +void add_rms_norm_(Tensor y, Tensor residual_out, Tensor a, Tensor b, Tensor weight, float epsilon = 1e-5f); +} // namespace infinicore::op diff --git a/include/infiniop.h b/include/infiniop.h index 92e6f5963..16cb23d36 100644 --- a/include/infiniop.h +++ b/include/infiniop.h @@ -3,6 +3,7 @@ #include "infiniop/handle.h" #include "infiniop/ops/add.h" +#include "infiniop/ops/add_rms_norm.h" #include "infiniop/ops/attention.h" #include "infiniop/ops/causal_softmax.h" #include "infiniop/ops/clip.h" diff --git a/include/infiniop/ops/add_rms_norm.h b/include/infiniop/ops/add_rms_norm.h new file mode 100644 index 000000000..7742c1343 --- /dev/null +++ b/include/infiniop/ops/add_rms_norm.h @@ -0,0 +1,32 @@ +#ifndef __INFINIOP_ADD_RMS_NORM_API_H__ +#define __INFINIOP_ADD_RMS_NORM_API_H__ + +#include "../operator_descriptor.h" + +typedef struct InfiniopDescriptor *infiniopAddRMSNormDescriptor_t; + +__C __export infiniStatus_t infiniopCreateAddRMSNormDescriptor( + infiniopHandle_t handle, + infiniopAddRMSNormDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t y_desc, + infiniopTensorDescriptor_t a_desc, + infiniopTensorDescriptor_t b_desc, + infiniopTensorDescriptor_t weight_desc, + float epsilon, + infiniopTensorDescriptor_t residual_out_desc); + +__C __export infiniStatus_t infiniopGetAddRMSNormWorkspaceSize(infiniopAddRMSNormDescriptor_t desc, size_t *size); + +__C __export infiniStatus_t infiniopAddRMSNorm(infiniopAddRMSNormDescriptor_t desc, + void *workspace, + size_t workspace_size, + void *y, + const void *a, + const void *b, + const void *weight, + void *residual_out, + void *stream); + +__C __export infiniStatus_t infiniopDestroyAddRMSNormDescriptor(infiniopAddRMSNormDescriptor_t desc); + +#endif diff --git a/python/infinicore/__init__.py b/python/infinicore/__init__.py index 7ca962449..94d21d88e 100644 --- a/python/infinicore/__init__.py +++ b/python/infinicore/__init__.py @@ -40,6 +40,7 @@ uint8, ) from infinicore.ops.add import add +from infinicore.ops.add_rms_norm import add_rms_norm, add_rms_norm_ from infinicore.ops.attention import attention from infinicore.ops.matmul import matmul from infinicore.ops.mul import mul @@ -102,6 +103,8 @@ "uint8", # Operations. 
"add", + "add_rms_norm", + "add_rms_norm_", "attention", "matmul", "mul", diff --git a/python/infinicore/ops/add_rms_norm.py b/python/infinicore/ops/add_rms_norm.py new file mode 100644 index 000000000..4ad347812 --- /dev/null +++ b/python/infinicore/ops/add_rms_norm.py @@ -0,0 +1,47 @@ +from infinicore.lib import _infinicore +from infinicore.tensor import Tensor + + +def add_rms_norm(a, b, weight, epsilon=1e-5, *, out=None): + """ + Fused Add and RMS Normalization. + + Args: + a: First input tensor + b: Second input tensor + weight: Scale weights + epsilon: Small constant for numerical stability, default is 1e-5 + out: Optional output tuple (y, residual_out) for in-place operation + + Returns: + Tuple of (normalized_result, add_result): (RMSNorm(a + b) * weight, a + b) + The add_result can be used as residual for subsequent layers. + """ + if out is None: + result = _infinicore.add_rms_norm( + a._underlying, b._underlying, weight._underlying, epsilon + ) + return (Tensor(result[0]), Tensor(result[1])) + + y, residual_out = out + _infinicore.add_rms_norm_( + y._underlying, + residual_out._underlying, + a._underlying, + b._underlying, + weight._underlying, + epsilon, + ) + return (y, residual_out) + + +def add_rms_norm_(y, residual_out, a, b, weight, epsilon=1e-5): + """In-place Fused Add and RMS Normalization.""" + _infinicore.add_rms_norm_( + y._underlying, + residual_out._underlying, + a._underlying, + b._underlying, + weight._underlying, + epsilon, + ) diff --git a/src/infinicore/ops/add_rms_norm/add_rms_norm.cc b/src/infinicore/ops/add_rms_norm/add_rms_norm.cc new file mode 100644 index 000000000..650ce87e6 --- /dev/null +++ b/src/infinicore/ops/add_rms_norm/add_rms_norm.cc @@ -0,0 +1,29 @@ +#include "infinicore/ops/add_rms_norm.hpp" + +#include "../../utils.hpp" + +namespace infinicore::op { + +common::OpDispatcher &AddRMSNorm::dispatcher() { + static common::OpDispatcher dispatcher_; + return dispatcher_; +}; + +void AddRMSNorm::execute(Tensor y, Tensor residual_out, Tensor a, Tensor b, Tensor weight, float epsilon) { + INFINICORE_ASSERT_TENSORS_SAME_DEVICE(y, residual_out, a, b, weight); + infinicore::context::setDevice(y->device()); + dispatcher().lookup(y->device().getType())(y, residual_out, a, b, weight, epsilon); +} + +std::pair add_rms_norm(Tensor a, Tensor b, Tensor weight, float epsilon) { + auto y = Tensor::empty(a->shape(), a->dtype(), a->device()); + auto residual_out = Tensor::empty(a->shape(), a->dtype(), a->device()); + add_rms_norm_(y, residual_out, a, b, weight, epsilon); + return std::make_pair(y, residual_out); +} + +void add_rms_norm_(Tensor y, Tensor residual_out, Tensor a, Tensor b, Tensor weight, float epsilon) { + AddRMSNorm::execute(y, residual_out, a, b, weight, epsilon); +} + +} // namespace infinicore::op diff --git a/src/infinicore/ops/add_rms_norm/add_rms_norm_infiniop.cc b/src/infinicore/ops/add_rms_norm/add_rms_norm_infiniop.cc new file mode 100644 index 000000000..d6540a039 --- /dev/null +++ b/src/infinicore/ops/add_rms_norm/add_rms_norm_infiniop.cc @@ -0,0 +1,50 @@ +#include "../../utils.hpp" +#include "infinicore/common/hash.hpp" +#include "infinicore/ops/add_rms_norm.hpp" +#include "infinicore/ops/common/cache.hpp" +#include + +namespace infinicore::op::add_rms_norm_impl::infiniop { + +thread_local common::OpCache caches( + 100, // capacity + [](infiniopAddRMSNormDescriptor_t &desc) { + if (desc != nullptr) { + INFINICORE_CHECK_ERROR(infiniopDestroyAddRMSNormDescriptor(desc)); + desc = nullptr; + } + }); + +void calculate(Tensor y, Tensor 
residual_out, Tensor a, Tensor b, Tensor weight, float epsilon) { + size_t seed = hash_combine(y, residual_out, a, b, weight, epsilon); + + auto device = context::getDevice(); + auto &cache = caches.getCache(device); + + auto desc_opt = cache.get(seed); + infiniopAddRMSNormDescriptor_t desc = nullptr; + + if (!desc_opt) { + INFINICORE_CHECK_ERROR(infiniopCreateAddRMSNormDescriptor( + context::getInfiniopHandle(device), &desc, + y->desc(), a->desc(), b->desc(), weight->desc(), epsilon, residual_out->desc())); + cache.put(seed, desc); + } else { + desc = *desc_opt; + } + + size_t workspace_size = 0; + INFINICORE_CHECK_ERROR(infiniopGetAddRMSNormWorkspaceSize(desc, &workspace_size)); + std::shared_ptr workspace = context::allocateMemory(workspace_size); + + INFINICORE_CHECK_ERROR(infiniopAddRMSNorm( + desc, workspace->data(), workspace_size, + y->data(), a->data(), b->data(), weight->data(), residual_out->data(), context::getStream())); +} + +static bool registered = []() { + AddRMSNorm::dispatcher().registerAll(&calculate, false); + return true; +}(); + +} // namespace infinicore::op::add_rms_norm_impl::infiniop diff --git a/src/infinicore/pybind11/ops.hpp b/src/infinicore/pybind11/ops.hpp index 978defa17..05bf7bb4e 100644 --- a/src/infinicore/pybind11/ops.hpp +++ b/src/infinicore/pybind11/ops.hpp @@ -3,6 +3,7 @@ #include #include "ops/add.hpp" +#include "ops/add_rms_norm.hpp" #include "ops/attention.hpp" #include "ops/causal_softmax.hpp" #include "ops/embedding.hpp" @@ -22,6 +23,7 @@ namespace infinicore::ops { inline void bind(py::module &m) { bind_add(m); + bind_add_rms_norm(m); bind_attention(m); bind_causal_softmax(m); bind_random_sample(m); diff --git a/src/infinicore/pybind11/ops/add_rms_norm.hpp b/src/infinicore/pybind11/ops/add_rms_norm.hpp new file mode 100644 index 000000000..5f9b243e5 --- /dev/null +++ b/src/infinicore/pybind11/ops/add_rms_norm.hpp @@ -0,0 +1,51 @@ +#pragma once + +#include + +#include "infinicore/ops/add_rms_norm.hpp" + +namespace py = pybind11; + +namespace infinicore::ops { + +inline void bind_add_rms_norm(py::module &m) { + m.def("add_rms_norm", + &op::add_rms_norm, + py::arg("a"), + py::arg("b"), + py::arg("weight"), + py::arg("epsilon") = 1e-5f, + R"doc(Fused Add and RMS Normalization. + +Args: + a: First input tensor + b: Second input tensor + weight: Scale weights + epsilon: Small constant for numerical stability, default is 1e-5 + +Returns: + Tuple of (normalized_result, add_result): (RMSNorm(a + b) * weight, a + b) + The add_result can be used as residual for subsequent layers. +)doc"); + + m.def("add_rms_norm_", + &op::add_rms_norm_, + py::arg("y"), + py::arg("residual_out"), + py::arg("a"), + py::arg("b"), + py::arg("weight"), + py::arg("epsilon") = 1e-5f, + R"doc(In-place Fused Add and RMS Normalization. 
+ +Args: + y: Output tensor for normalized result + residual_out: Output tensor for add result (a + b) before normalization + a: First input tensor + b: Second input tensor + weight: Scale weights + epsilon: Small constant for numerical stability, default is 1e-5 +)doc"); +} + +} // namespace infinicore::ops diff --git a/src/infiniop/ops/add_rms_norm/add_rms_norm.h b/src/infiniop/ops/add_rms_norm/add_rms_norm.h new file mode 100644 index 000000000..c5d63333d --- /dev/null +++ b/src/infiniop/ops/add_rms_norm/add_rms_norm.h @@ -0,0 +1,53 @@ +#ifndef ADD_RMS_NORM_H +#define ADD_RMS_NORM_H + +#include "../../operator.h" +#include "info.h" + +#define DESCRIPTOR(NAMESPACE) \ + \ + namespace op::add_rms_norm::NAMESPACE { \ + class Descriptor final : public InfiniopDescriptor { \ + struct Opaque; \ + Opaque *_opaque; \ + AddRMSNormInfo _info; \ + size_t _workspace_size; \ + \ + Descriptor( \ + Opaque *opaque, \ + AddRMSNormInfo info, \ + size_t workspace_size, \ + infiniDevice_t device_type, \ + int device_id) \ + : InfiniopDescriptor{device_type, device_id}, \ + _opaque(opaque), \ + _info(info), \ + _workspace_size(workspace_size) {} \ + \ + public: \ + ~Descriptor(); \ + \ + size_t workspaceSize() const { return _workspace_size; } \ + \ + static infiniStatus_t create( \ + infiniopHandle_t handle, \ + Descriptor **desc_ptr, \ + infiniopTensorDescriptor_t y_desc, \ + infiniopTensorDescriptor_t a_desc, \ + infiniopTensorDescriptor_t b_desc, \ + infiniopTensorDescriptor_t weight_desc, \ + float epsilon, \ + infiniopTensorDescriptor_t residual_out_desc); \ + \ + infiniStatus_t calculate( \ + void *workspace, size_t workspace_size, \ + void *y, \ + const void *a, \ + const void *b, \ + const void *weight, \ + void *residual_out, \ + void *stream) const; \ + }; \ + } + +#endif // ADD_RMS_NORM_H diff --git a/src/infiniop/ops/add_rms_norm/cpu/add_rms_norm_cpu.cc b/src/infiniop/ops/add_rms_norm/cpu/add_rms_norm_cpu.cc new file mode 100644 index 000000000..5e7954b71 --- /dev/null +++ b/src/infiniop/ops/add_rms_norm/cpu/add_rms_norm_cpu.cc @@ -0,0 +1,147 @@ +#include "add_rms_norm_cpu.h" +#include "../../../devices/cpu/common_cpu.h" +#include "../../../reduce/cpu/reduce.h" + +namespace op::add_rms_norm::cpu { + +Descriptor::~Descriptor() {} + +infiniStatus_t Descriptor::create( + infiniopHandle_t handle, + Descriptor **desc_ptr, + infiniopTensorDescriptor_t y_desc, + infiniopTensorDescriptor_t a_desc, + infiniopTensorDescriptor_t b_desc, + infiniopTensorDescriptor_t weight_desc, + float epsilon, + infiniopTensorDescriptor_t residual_out_desc) { + auto result = AddRMSNormInfo::create(y_desc, a_desc, b_desc, weight_desc, epsilon, residual_out_desc); + CHECK_RESULT(result); + *desc_ptr = new Descriptor(nullptr, result.take(), 0, handle->device, handle->device_id); + return INFINI_STATUS_SUCCESS; +} + +template +infiniStatus_t add_rmsnorm(const AddRMSNormInfo *info, T *y, const T *a, const T *b, const T *w, T *residual_out) { + const size_t batch_size = info->shape[0]; + const size_t nhead = info->ndim() > 2 ? 
info->shape[1] : 1; + const size_t dim = info->dim(); + const ptrdiff_t total_blocks = static_cast(batch_size * nhead); + +#pragma omp parallel for + for (ptrdiff_t block_idx = 0; block_idx < total_blocks; ++block_idx) { + const size_t i = block_idx / nhead; // batch index + const size_t j = block_idx % nhead; // head index + + const T *a_ptr = a + i * info->a_strides[0] + j * info->a_strides[1]; + const T *b_ptr = b + i * info->b_strides[0] + j * info->b_strides[1]; + T *y_ptr = y + i * info->y_strides[0] + j * info->y_strides[1]; + T *residual_out_ptr = residual_out + i * info->residual_out_strides[0] + j * info->residual_out_strides[1]; + + // Compute add(a, b) once and store it + T sum_squared = (T)0; + for (size_t k = 0; k < dim; k++) { + T sum_val = a_ptr[k] + b_ptr[k]; + residual_out_ptr[k] = sum_val; // Store add result + sum_squared += sum_val * sum_val; + } + + // Compute RMS: 1 / (sqrt(mean(sum^2) + eps)) + // Note: mean = sum_squared / dim + T rms = (T)1 / std::sqrt(sum_squared / (T)(dim) + (T)(info->epsilon)); + + // Apply normalization: y = (a + b) * w * rms + // Reuse stored values from residual_out + for (size_t k = 0; k < dim; k++) { + y_ptr[k] = residual_out_ptr[k] * w[k] * rms; + } + } + + return INFINI_STATUS_SUCCESS; +} + +template +infiniStatus_t add_rmsnormHalfPrecision(const AddRMSNormInfo *info, T *y, const T *a, const T *b, const Tw *w, T *residual_out) { + static_assert(std::is_same::value || std::is_same::value, + "T must be fp16_t or bf16_t"); + + const size_t batch_size = info->shape[0]; + const size_t nhead = info->ndim() > 2 ? info->shape[1] : 1; + const size_t dim = info->dim(); + const ptrdiff_t total_blocks = static_cast(batch_size * nhead); + +#pragma omp parallel for + for (ptrdiff_t block_idx = 0; block_idx < total_blocks; ++block_idx) { + const size_t i = block_idx / nhead; // batch index + const size_t j = block_idx % nhead; // head index + + const T *a_ptr = a + i * info->a_strides[0] + j * info->a_strides[1]; + const T *b_ptr = b + i * info->b_strides[0] + j * info->b_strides[1]; + T *y_ptr = y + i * info->y_strides[0] + j * info->y_strides[1]; + T *residual_out_ptr = residual_out + i * info->residual_out_strides[0] + j * info->residual_out_strides[1]; + + // Compute sum of squares for RMS normalization and store add result + float sum_squared = 0.0f; + for (size_t k = 0; k < dim; k++) { + float sum_val = utils::cast(a_ptr[k]) + utils::cast(b_ptr[k]); + residual_out_ptr[k] = utils::cast(sum_val); // Store add result + sum_squared += sum_val * sum_val; + } + + // Compute RMS: 1 / (sqrt(sum/dim + eps)) + float rms = 1.f / std::sqrt(sum_squared / (float)(dim) + info->epsilon); + + // Apply normalization: y = (a + b) * w * rms + // Reuse stored values from residual_out + for (size_t k = 0; k < dim; k++) { + float sum_val = utils::cast(residual_out_ptr[k]); + float val; + if constexpr (std::is_same::value) { + val = sum_val * w[k] * rms; + } else if constexpr (std::is_same::value || std::is_same_v || std::is_same_v) { + val = sum_val * utils::cast(w[k]) * rms; + } else { + std::abort(); + } + y_ptr[k] = utils::cast(val); + } + } + + return INFINI_STATUS_SUCCESS; +} + +infiniStatus_t Descriptor::calculate( + void *workspace, size_t workspace_size, + void *y, const void *a, const void *b, const void *weight, + void *residual_out, void *stream) const { + if (_info.atype == INFINI_DTYPE_F16) { + if (_info.wtype == INFINI_DTYPE_F16) { + CHECK_STATUS(add_rmsnormHalfPrecision(&_info, (fp16_t *)y, (const fp16_t *)a, (const fp16_t *)b, (const fp16_t *)weight, 
(fp16_t *)residual_out)); + } else if (_info.wtype == INFINI_DTYPE_F32) { + CHECK_STATUS(add_rmsnormHalfPrecision(&_info, (fp16_t *)y, (const fp16_t *)a, (const fp16_t *)b, (const float *)weight, (fp16_t *)residual_out)); + } else if (_info.wtype == INFINI_DTYPE_BF16) { + CHECK_STATUS(add_rmsnormHalfPrecision(&_info, (fp16_t *)y, (const fp16_t *)a, (const fp16_t *)b, (const bf16_t *)weight, (fp16_t *)residual_out)); + } else { + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + } else if (_info.atype == INFINI_DTYPE_BF16) { + if (_info.wtype == INFINI_DTYPE_BF16) { + CHECK_STATUS(add_rmsnormHalfPrecision(&_info, (bf16_t *)y, (const bf16_t *)a, (const bf16_t *)b, (const bf16_t *)weight, (bf16_t *)residual_out)); + } else if (_info.wtype == INFINI_DTYPE_F32) { + CHECK_STATUS(add_rmsnormHalfPrecision(&_info, (bf16_t *)y, (const bf16_t *)a, (const bf16_t *)b, (const float *)weight, (bf16_t *)residual_out)); + } else if (_info.wtype == INFINI_DTYPE_F16) { + CHECK_STATUS(add_rmsnormHalfPrecision(&_info, (bf16_t *)y, (const bf16_t *)a, (const bf16_t *)b, (const fp16_t *)weight, (bf16_t *)residual_out)); + } else { + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + } else if (_info.atype == INFINI_DTYPE_F32) { + CHECK_STATUS(add_rmsnorm(&_info, (float *)y, (const float *)a, (const float *)b, (const float *)weight, (float *)residual_out)); + } else if (_info.atype == INFINI_DTYPE_F64) { + CHECK_STATUS(add_rmsnorm(&_info, (double *)y, (const double *)a, (const double *)b, (const double *)weight, (double *)residual_out)); + } else { + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + + return INFINI_STATUS_SUCCESS; +} +} // namespace op::add_rms_norm::cpu diff --git a/src/infiniop/ops/add_rms_norm/cpu/add_rms_norm_cpu.h b/src/infiniop/ops/add_rms_norm/cpu/add_rms_norm_cpu.h new file mode 100644 index 000000000..f4173adbf --- /dev/null +++ b/src/infiniop/ops/add_rms_norm/cpu/add_rms_norm_cpu.h @@ -0,0 +1,7 @@ +#ifndef __ADD_RMS_NORM_CPU_H__ +#define __ADD_RMS_NORM_CPU_H__ +#include "../add_rms_norm.h" + +DESCRIPTOR(cpu) + +#endif diff --git a/src/infiniop/ops/add_rms_norm/cuda/kernel.cuh b/src/infiniop/ops/add_rms_norm/cuda/kernel.cuh new file mode 100644 index 000000000..97f254a09 --- /dev/null +++ b/src/infiniop/ops/add_rms_norm/cuda/kernel.cuh @@ -0,0 +1,63 @@ +#ifndef __ADD_RMS_NORM_CUDA_KERNEL_H__ +#define __ADD_RMS_NORM_CUDA_KERNEL_H__ + +#include + +template +__device__ void add_rmsnormBlock( + Tdata *__restrict__ y, + Tdata *__restrict__ residual_out, + ptrdiff_t stride_y_batch, + ptrdiff_t stride_y_nhead, + ptrdiff_t stride_residual_out_batch, + ptrdiff_t stride_residual_out_nhead, + const Tdata *__restrict__ a, + ptrdiff_t stride_a_batch, + ptrdiff_t stride_a_nhead, + const Tdata *__restrict__ b, + ptrdiff_t stride_b_batch, + ptrdiff_t stride_b_nhead, + const Tweight *__restrict__ w, + size_t nhead, + size_t dim, + float epsilon) { + // Each block takes care of one head in one batch + // Each thread deals with every block_size element in the row + size_t batch_idx = blockIdx.x / nhead; + size_t head_idx = blockIdx.x % nhead; + + auto y_ptr = y + batch_idx * stride_y_batch + head_idx * stride_y_nhead; + auto a_ptr = a + batch_idx * stride_a_batch + head_idx * stride_a_nhead; + auto b_ptr = b + batch_idx * stride_b_batch + head_idx * stride_b_nhead; + auto w_ptr = w; + Tdata *residual_out_ptr = residual_out + batch_idx * stride_residual_out_batch + head_idx * stride_residual_out_nhead; + + // Compute add(a, b) and sum of squares in one pass + Tcompute sum_squared = 0; + for (size_t i = threadIdx.x; i < dim; 
i += BLOCK_SIZE) { + Tcompute sum_val = Tcompute(a_ptr[i]) + Tcompute(b_ptr[i]); + residual_out_ptr[i] = Tdata(sum_val); // Store add result + sum_squared += sum_val * sum_val; + } + + // Block-reduce sum of squares + using BlockReduce = cub::BlockReduce; + __shared__ typename BlockReduce::TempStorage temp_storage; + sum_squared = BlockReduce(temp_storage).Sum(sum_squared); + + // Thread_0 computes RMS=1/sqrt(ss/dim+epsilon) and stores in shared memory + __shared__ Tcompute rms; + if (threadIdx.x == 0) { + rms = Tcompute(rsqrtf(sum_squared / Tcompute(dim) + epsilon)); + } + __syncthreads(); + + // Apply normalization: y = (a + b) * w * rms + // Reuse stored values from residual_out + for (size_t i = threadIdx.x; i < dim; i += BLOCK_SIZE) { + Tcompute sum_val = Tcompute(residual_out_ptr[i]); // Reuse stored value + y_ptr[i] = Tdata(sum_val * Tcompute(w_ptr[i]) * rms); + } +} + +#endif diff --git a/src/infiniop/ops/add_rms_norm/info.h b/src/infiniop/ops/add_rms_norm/info.h new file mode 100644 index 000000000..abe1b5059 --- /dev/null +++ b/src/infiniop/ops/add_rms_norm/info.h @@ -0,0 +1,132 @@ +#ifndef __ADD_RMS_NORM_INFO_H__ +#define __ADD_RMS_NORM_INFO_H__ + +#include "../../../utils.h" +#include "../../tensor.h" +#include + +namespace op::add_rms_norm { + +class AddRMSNormInfo { + AddRMSNormInfo() = default; + +public: + infiniDtype_t wtype; + infiniDtype_t atype; + float epsilon; + std::vector shape; + std::vector y_strides; + std::vector a_strides; + std::vector b_strides; + std::vector residual_out_strides; + bool has_residual_out; + + size_t ndim() const { return shape.size(); } + size_t dim() const { return shape[ndim() - 1]; } + + static utils::Result create( + infiniopTensorDescriptor_t y_desc, + infiniopTensorDescriptor_t a_desc, + infiniopTensorDescriptor_t b_desc, + infiniopTensorDescriptor_t weight_desc, + float epsilon, + infiniopTensorDescriptor_t residual_out_desc) { + + auto atype = y_desc->dtype(); + auto wtype = weight_desc->dtype(); + + // Check that all input tensors have the same dtype + if (a_desc->dtype() != atype || b_desc->dtype() != atype) { + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + + if (atype == INFINI_DTYPE_F16 || atype == INFINI_DTYPE_BF16) { + // For half-precision types (FP16/BF16), weights can be the same half-precision type or FP32 + if (wtype != atype && wtype != INFINI_DTYPE_F32 && wtype != INFINI_DTYPE_BF16 && wtype != INFINI_DTYPE_F16) { + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + } else if (atype == INFINI_DTYPE_F32 || atype == INFINI_DTYPE_F64) { + // For FP32/FP64, activations and weights must be of the same type + if (atype != wtype) { + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + } else { + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + + const size_t y_ndim = y_desc->ndim(); + const size_t a_ndim = a_desc->ndim(); + const size_t b_ndim = b_desc->ndim(); + const size_t w_ndim = weight_desc->ndim(); + + if (y_ndim != a_ndim || y_ndim != b_ndim || w_ndim != 1) { + return INFINI_STATUS_BAD_TENSOR_SHAPE; + } + + size_t batch = 1; + size_t nhead = 1; + size_t dim = 0; + + if (y_ndim == 2) { + batch = y_desc->dim(0); + dim = y_desc->dim(1); + + if (a_desc->dim(0) != batch || a_desc->dim(1) != dim || b_desc->dim(0) != batch || b_desc->dim(1) != dim || weight_desc->dim(0) != dim) { + return INFINI_STATUS_BAD_TENSOR_SHAPE; + } + } else if (y_ndim == 3) { + batch = y_desc->dim(0); + nhead = y_desc->dim(1); + dim = y_desc->dim(2); + + if (a_desc->dim(0) != batch || a_desc->dim(1) != nhead || a_desc->dim(2) != dim || b_desc->dim(0) != batch || 
b_desc->dim(1) != nhead || b_desc->dim(2) != dim || weight_desc->dim(0) != dim) { + return INFINI_STATUS_BAD_TENSOR_SHAPE; + } + } else { + return INFINI_STATUS_BAD_TENSOR_SHAPE; + } + + // Check contiguity of the last dimension + if (y_desc->stride(y_ndim - 1) != 1 || a_desc->stride(a_ndim - 1) != 1 || b_desc->stride(b_ndim - 1) != 1 || weight_desc->stride(w_ndim - 1) != 1) { + return INFINI_STATUS_BAD_TENSOR_STRIDES; + } + + // residual_out_desc is required (always needed for fused operator) + if (residual_out_desc == nullptr) { + return INFINI_STATUS_BAD_PARAM; + } + + const size_t residual_out_ndim = residual_out_desc->ndim(); + if (residual_out_ndim != y_ndim) { + return INFINI_STATUS_BAD_TENSOR_SHAPE; + } + if (residual_out_desc->dtype() != atype) { + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + // Check shape matches + for (size_t i = 0; i < y_ndim; i++) { + if (residual_out_desc->dim(i) != y_desc->dim(i)) { + return INFINI_STATUS_BAD_TENSOR_SHAPE; + } + } + if (residual_out_desc->stride(residual_out_ndim - 1) != 1) { + return INFINI_STATUS_BAD_TENSOR_STRIDES; + } + + AddRMSNormInfo info; + info.wtype = wtype; + info.atype = atype; + info.epsilon = epsilon; + info.shape = y_desc->shape(); + info.y_strides = y_desc->strides(); + info.a_strides = a_desc->strides(); + info.b_strides = b_desc->strides(); + info.has_residual_out = true; // Always true now + info.residual_out_strides = residual_out_desc->strides(); + return utils::Result(info); + } +}; + +} // namespace op::add_rms_norm + +#endif // __ADD_RMS_NORM_INFO_H__ diff --git a/src/infiniop/ops/add_rms_norm/nvidia/add_rms_norm_nvidia.cu b/src/infiniop/ops/add_rms_norm/nvidia/add_rms_norm_nvidia.cu new file mode 100644 index 000000000..03601205f --- /dev/null +++ b/src/infiniop/ops/add_rms_norm/nvidia/add_rms_norm_nvidia.cu @@ -0,0 +1,175 @@ +#include "../../../devices/nvidia/nvidia_common.cuh" +#include "add_rms_norm_nvidia.cuh" + +#include "../../../devices/nvidia/nvidia_kernel_common.cuh" +#include + +#include "../../../reduce/cuda/reduce.cuh" + +#include "../cuda/kernel.cuh" + +template +INFINIOP_CUDA_KERNEL add_rmsnormKernel( + Tdata *__restrict__ y, + Tdata *__restrict__ residual_out, + ptrdiff_t stride_y_batch, + ptrdiff_t stride_y_nhead, + ptrdiff_t stride_residual_out_batch, + ptrdiff_t stride_residual_out_nhead, + const Tdata *__restrict__ a, + ptrdiff_t stride_a_batch, + ptrdiff_t stride_a_nhead, + const Tdata *__restrict__ b, + ptrdiff_t stride_b_batch, + ptrdiff_t stride_b_nhead, + const Tweight *__restrict__ w, + size_t nhead, + size_t dim, + float epsilon) { + add_rmsnormBlock( + y, residual_out, + stride_y_batch, stride_y_nhead, + stride_residual_out_batch, stride_residual_out_nhead, + a, stride_a_batch, stride_a_nhead, + b, stride_b_batch, stride_b_nhead, + w, nhead, dim, epsilon); +} + +namespace op::add_rms_norm::nvidia { + +struct Descriptor::Opaque { + std::shared_ptr internal; +}; + +Descriptor::~Descriptor() { + delete _opaque; +} + +infiniStatus_t Descriptor::create( + infiniopHandle_t handle, + Descriptor **desc_ptr, + infiniopTensorDescriptor_t y_desc, + infiniopTensorDescriptor_t a_desc, + infiniopTensorDescriptor_t b_desc, + infiniopTensorDescriptor_t weight_desc, + float epsilon, + infiniopTensorDescriptor_t residual_out_desc) { + auto result = AddRMSNormInfo::create(y_desc, a_desc, b_desc, weight_desc, epsilon, residual_out_desc); + CHECK_RESULT(result); + auto info = result.take(); + + *desc_ptr = new Descriptor( + new Opaque{reinterpret_cast(handle)->internal()}, + std::move(info), + 0, + 
handle->device, handle->device_id); + return INFINI_STATUS_SUCCESS; +} + +// launch kernel with different data types +template +infiniStatus_t launchKernel( + uint32_t batch_size, size_t nhead, size_t dim, + void *y, infiniDtype_t atype, ptrdiff_t stride_y_batch, ptrdiff_t stride_y_nhead, + void *residual_out, ptrdiff_t stride_residual_out_batch, ptrdiff_t stride_residual_out_nhead, + const void *a, ptrdiff_t stride_a_batch, ptrdiff_t stride_a_nhead, + const void *b, ptrdiff_t stride_b_batch, ptrdiff_t stride_b_nhead, + const void *w, infiniDtype_t wtype, + float epsilon, + cudaStream_t cuda_stream) { + +#define LAUNCH_KERNEL(Tdata, Tweight, Tcompute) \ + add_rmsnormKernel<<>>( \ + reinterpret_cast(y), \ + reinterpret_cast(residual_out), \ + stride_y_batch, \ + stride_y_nhead, \ + stride_residual_out_batch, \ + stride_residual_out_nhead, \ + reinterpret_cast(a), \ + stride_a_batch, \ + stride_a_nhead, \ + reinterpret_cast(b), \ + stride_b_batch, \ + stride_b_nhead, \ + reinterpret_cast(w), \ + nhead, \ + dim, \ + epsilon) + + if (atype == INFINI_DTYPE_F16 && wtype == INFINI_DTYPE_F16) { + LAUNCH_KERNEL(half, half, float); + } else if (atype == INFINI_DTYPE_F16 && wtype == INFINI_DTYPE_BF16) { + LAUNCH_KERNEL(half, __nv_bfloat16, float); + } else if (atype == INFINI_DTYPE_F16 && wtype == INFINI_DTYPE_F32) { + LAUNCH_KERNEL(half, float, float); + } else if (atype == INFINI_DTYPE_BF16 && wtype == INFINI_DTYPE_BF16) { + LAUNCH_KERNEL(__nv_bfloat16, __nv_bfloat16, float); + } else if (atype == INFINI_DTYPE_BF16 && wtype == INFINI_DTYPE_F16) { + LAUNCH_KERNEL(__nv_bfloat16, half, float); + } else if (atype == INFINI_DTYPE_BF16 && wtype == INFINI_DTYPE_F32) { + LAUNCH_KERNEL(__nv_bfloat16, float, float); + } else if (atype == INFINI_DTYPE_F32 && wtype == INFINI_DTYPE_F32) { + LAUNCH_KERNEL(float, float, float); + } else { + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + +#undef LAUNCH_KERNEL + + return INFINI_STATUS_SUCCESS; +} + +infiniStatus_t Descriptor::calculate( + void *workspace, size_t workspace_size, + void *y, const void *a, const void *b, const void *weight, + void *residual_out, void *stream) const { + + if (workspace_size < _workspace_size) { + return INFINI_STATUS_INSUFFICIENT_WORKSPACE; + } + + auto stride_a_batch = _info.a_strides[0]; + auto stride_a_nhead = _info.a_strides[1]; + auto stride_b_batch = _info.b_strides[0]; + auto stride_b_nhead = _info.b_strides[1]; + auto stride_y_batch = _info.y_strides[0]; + auto stride_y_nhead = _info.y_strides[1]; + auto stride_residual_out_batch = _info.residual_out_strides[0]; + auto stride_residual_out_nhead = _info.residual_out_strides[1]; + auto dim = _info.dim(); + uint32_t batch_size = static_cast(_info.shape[0]); + size_t nhead = _info.shape.size() > 2 ? 
_info.shape[1] : 1; + auto cuda_stream = reinterpret_cast(stream); + + // launch kernel with different block sizes + if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_1024) { + CHECK_STATUS(launchKernel( + batch_size, nhead, dim, + y, _info.atype, stride_y_batch, stride_y_nhead, + residual_out, stride_residual_out_batch, stride_residual_out_nhead, + a, stride_a_batch, stride_a_nhead, + b, stride_b_batch, stride_b_nhead, + weight, _info.wtype, _info.epsilon, cuda_stream)); + } else if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_512) { + CHECK_STATUS(launchKernel( + batch_size, nhead, dim, + y, _info.atype, stride_y_batch, stride_y_nhead, + residual_out, stride_residual_out_batch, stride_residual_out_nhead, + a, stride_a_batch, stride_a_nhead, + b, stride_b_batch, stride_b_nhead, + weight, _info.wtype, _info.epsilon, cuda_stream)); + } else if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_4096) { + CHECK_STATUS(launchKernel( + batch_size, nhead, dim, + y, _info.atype, stride_y_batch, stride_y_nhead, + residual_out, stride_residual_out_batch, stride_residual_out_nhead, + a, stride_a_batch, stride_a_nhead, + b, stride_b_batch, stride_b_nhead, + weight, _info.wtype, _info.epsilon, cuda_stream)); + } else { + return INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED; + } + return INFINI_STATUS_SUCCESS; +} +} // namespace op::add_rms_norm::nvidia diff --git a/src/infiniop/ops/add_rms_norm/nvidia/add_rms_norm_nvidia.cuh b/src/infiniop/ops/add_rms_norm/nvidia/add_rms_norm_nvidia.cuh new file mode 100644 index 000000000..548be83e1 --- /dev/null +++ b/src/infiniop/ops/add_rms_norm/nvidia/add_rms_norm_nvidia.cuh @@ -0,0 +1,8 @@ +#ifndef __ADD_RMS_NORM_NVIDIA_CUDA_H__ +#define __ADD_RMS_NORM_NVIDIA_CUDA_H__ + +#include "../add_rms_norm.h" + +DESCRIPTOR(nvidia) + +#endif diff --git a/src/infiniop/ops/add_rms_norm/operator.cc b/src/infiniop/ops/add_rms_norm/operator.cc new file mode 100644 index 000000000..a856e5447 --- /dev/null +++ b/src/infiniop/ops/add_rms_norm/operator.cc @@ -0,0 +1,189 @@ +#include "../../operator.h" +#include "../../handle.h" +#include "infiniop/ops/add_rms_norm.h" + +#ifdef ENABLE_CPU_API +#include "cpu/add_rms_norm_cpu.h" +#endif +#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API) || defined(ENABLE_HYGON_API) +#include "nvidia/add_rms_norm_nvidia.cuh" +#endif +#ifdef ENABLE_ASCEND_API +// TODO: Add Ascend implementation +// #include "ascend/add_rms_norm_aclnn.h" +#endif +#ifdef ENABLE_CAMBRICON_API +// TODO: Add Cambricon implementation +// #include "bang/add_rms_norm_bang.h" +#endif +#ifdef ENABLE_METAX_API +// TODO: Add Metax implementation +// #include "metax/add_rms_norm_metax.cuh" +#endif +#ifdef ENABLE_MOORE_API +// TODO: Add Moore implementation +// #include "moore/add_rms_norm_moore.h" +#endif +#ifdef ENABLE_KUNLUN_API +// TODO: Add Kunlun implementation +// #include "kunlun/add_rms_norm_kunlun.h" +#endif + +__C infiniStatus_t infiniopCreateAddRMSNormDescriptor( + infiniopHandle_t handle, + infiniopAddRMSNormDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t y_desc, + infiniopTensorDescriptor_t a_desc, + infiniopTensorDescriptor_t b_desc, + infiniopTensorDescriptor_t weight_desc, + float epsilon, + infiniopTensorDescriptor_t residual_out_desc) { + +#define CREATE(CASE, NAMESPACE) \ + case CASE: \ + return op::add_rms_norm::NAMESPACE::Descriptor::create( \ + handle, \ + reinterpret_cast(desc_ptr), \ + y_desc, \ + a_desc, \ + b_desc, \ + weight_desc, \ + epsilon, \ + residual_out_desc) + + 
switch (handle->device) { +#ifdef ENABLE_CPU_API + CREATE(INFINI_DEVICE_CPU, cpu); +#endif +#ifdef ENABLE_NVIDIA_API + CREATE(INFINI_DEVICE_NVIDIA, nvidia); +#endif +#ifdef ENABLE_ILUVATAR_API + CREATE(INFINI_DEVICE_ILUVATAR, nvidia); +#endif +#ifdef ENABLE_QY_API + CREATE(INFINI_DEVICE_QY, nvidia); +#endif +#ifdef ENABLE_HYGON_API + CREATE(INFINI_DEVICE_HYGON, nvidia); +#endif +#ifdef ENABLE_KUNLUN_API + // CREATE(INFINI_DEVICE_KUNLUN, kunlun); +#endif + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } + +#undef CREATE +} + +__C infiniStatus_t infiniopGetAddRMSNormWorkspaceSize(infiniopAddRMSNormDescriptor_t desc, size_t *size) { + +#define GET(CASE, NAMESPACE) \ + case CASE: \ + *size = reinterpret_cast(desc)->workspaceSize(); \ + return INFINI_STATUS_SUCCESS + + switch (desc->device_type) { +#ifdef ENABLE_CPU_API + GET(INFINI_DEVICE_CPU, cpu); +#endif +#ifdef ENABLE_NVIDIA_API + GET(INFINI_DEVICE_NVIDIA, nvidia); +#endif +#ifdef ENABLE_ILUVATAR_API + GET(INFINI_DEVICE_ILUVATAR, nvidia); +#endif +#ifdef ENABLE_QY_API + GET(INFINI_DEVICE_QY, nvidia); +#endif +#ifdef ENABLE_HYGON_API + GET(INFINI_DEVICE_HYGON, nvidia); +#endif +#ifdef ENABLE_KUNLUN_API + // GET(INFINI_DEVICE_KUNLUN, kunlun); +#endif + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } +#undef GET + + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; +} + +__C infiniStatus_t infiniopAddRMSNorm( + infiniopAddRMSNormDescriptor_t desc, + void *workspace, + size_t workspace_size, + void *y, + const void *a, + const void *b, + const void *weight, + void *residual_out, + void *stream) { + +#define CALCULATE(CASE, NAMESPACE) \ + case CASE: \ + return reinterpret_cast(desc) \ + ->calculate(workspace, workspace_size, y, a, b, weight, residual_out, stream) + + switch (desc->device_type) { + +#ifdef ENABLE_CPU_API + CALCULATE(INFINI_DEVICE_CPU, cpu); +#endif +#ifdef ENABLE_NVIDIA_API + CALCULATE(INFINI_DEVICE_NVIDIA, nvidia); +#endif +#ifdef ENABLE_ILUVATAR_API + CALCULATE(INFINI_DEVICE_ILUVATAR, nvidia); +#endif +#ifdef ENABLE_QY_API + CALCULATE(INFINI_DEVICE_QY, nvidia); +#endif +#ifdef ENABLE_HYGON_API + CALCULATE(INFINI_DEVICE_HYGON, nvidia); +#endif +#ifdef ENABLE_KUNLUN_API + // CALCULATE(INFINI_DEVICE_KUNLUN, kunlun); +#endif + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } +#undef CALCULATE +} + +__C infiniStatus_t infiniopDestroyAddRMSNormDescriptor(infiniopAddRMSNormDescriptor_t desc) { + if (desc == nullptr) { + return INFINI_STATUS_SUCCESS; + } + +#define DESTROY(CASE, NAMESPACE) \ + case CASE: \ + delete reinterpret_cast(desc); \ + return INFINI_STATUS_SUCCESS + + switch (desc->device_type) { +#ifdef ENABLE_CPU_API + DESTROY(INFINI_DEVICE_CPU, cpu); +#endif +#ifdef ENABLE_NVIDIA_API + DESTROY(INFINI_DEVICE_NVIDIA, nvidia); +#endif +#ifdef ENABLE_ILUVATAR_API + DESTROY(INFINI_DEVICE_ILUVATAR, nvidia); +#endif +#ifdef ENABLE_QY_API + DESTROY(INFINI_DEVICE_QY, nvidia); +#endif +#ifdef ENABLE_HYGON_API + DESTROY(INFINI_DEVICE_HYGON, nvidia); +#endif +#ifdef ENABLE_KUNLUN_API + // DESTROY(INFINI_DEVICE_KUNLUN, kunlun); +#endif + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } +#undef DESTROY +} diff --git a/test/infinicore/ops/add_rms_norm.py b/test/infinicore/ops/add_rms_norm.py new file mode 100644 index 000000000..429d9df25 --- /dev/null +++ b/test/infinicore/ops/add_rms_norm.py @@ -0,0 +1,171 @@ +import sys +import os + +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) + +import torch +import infinicore +from framework import ( + BaseOperatorTest, + 
TensorSpec, + TestCase, + GenericTestRunner, + is_broadcast, +) + +# ============================================================================== +# Operator-specific configuration +# ============================================================================== + +# Test cases format: (y_shape, a_shape, b_shape, w_shape, y_strides, a_strides, b_strides) +_TEST_CASES_DATA = [ + # Basic cases + ((1, 4), (1, 4), (1, 4), (4,), None, None, None), + ((2, 4), (2, 4), (2, 4), (4,), None, None, None), + ((2, 2, 4), (2, 2, 4), (2, 2, 4), (4,), None, None, None), + # Strided cases + ((2, 2, 4), (2, 2, 4), (2, 2, 4), (4,), (12, 8, 1), (12, 8, 1), (12, 8, 1)), + # Large tensors + ((16, 2048), (16, 2048), (16, 2048), (2048,), None, None, None), + ((16, 2048), (16, 2048), (16, 2048), (2048,), (4096, 1), (4096, 1), (4096, 1)), + ((15, 3584), (15, 3584), (15, 3584), (3584,), None, None, None), + ((4, 4, 2048), (4, 4, 2048), (4, 4, 2048), (2048,), None, None, None), + ((4, 4, 2048), (4, 4, 2048), (4, 4, 2048), (2048,), (2048, 8192, 1), (2048, 8192, 1), (2048, 8192, 1)), + ((4, 4, 2048), (4, 4, 2048), (4, 4, 2048), (2048,), (16384, 4096, 1), (16384, 4096, 1), (16384, 4096, 1)), +] + +# Tolerance configuration +_TOLERANCE_MAP = { + infinicore.float16: {"atol": 2e-3, "rtol": 2e-3}, + infinicore.bfloat16: {"atol": 1e-2, "rtol": 1e-2}, + infinicore.float32: {"atol": 1e-5, "rtol": 1e-4}, +} + +# Data types for individual tensors +_INPUT_DTYPES = [infinicore.float16, infinicore.bfloat16] +_WEIGHT_DTYPES = [infinicore.float16, infinicore.bfloat16, infinicore.float32] + +# EPSILON constant for AddRMSNorm +_EPSILON = 1e-5 + + +def parse_test_cases(): + """ + Parse AddRMSNorm test case data and return list of TestCase objects. + Format: (y_shape, a_shape, b_shape, w_shape, y_strides, a_strides, b_strides) + """ + test_cases = [] + + for data in _TEST_CASES_DATA: + y_shape = data[0] # Output shape + a_shape = data[1] # First input shape + b_shape = data[2] # Second input shape + w_shape = data[3] # Weight shape (1D) + y_strides = data[4] if len(data) > 4 else None + a_strides = data[5] if len(data) > 5 else None + b_strides = data[6] if len(data) > 6 else None + + # Check if tensors support in-place operations + a_supports_inplace = not is_broadcast(a_strides) + b_supports_inplace = not is_broadcast(b_strides) + y_supports_inplace = not is_broadcast(y_strides) + + # Generate test cases for all dtype combinations + for input_dtype in _INPUT_DTYPES: + for weight_dtype in _WEIGHT_DTYPES: + # Use input dtype tolerance for output + tolerance = _TOLERANCE_MAP.get( + input_dtype, {"atol": 1e-5, "rtol": 1e-4} + ) + + # Create typed tensor specs + a_spec = TensorSpec.from_tensor(a_shape, a_strides, input_dtype) + b_spec = TensorSpec.from_tensor(b_shape, b_strides, input_dtype) + w_spec = TensorSpec.from_tensor( + w_shape, None, weight_dtype + ) # Weight is always contiguous + y_spec = TensorSpec.from_tensor(y_shape, y_strides, input_dtype) + + # Test Case 1: Out-of-place (return value) - returns (normalized_result, add_result) + residual_out_spec = TensorSpec.from_tensor(a_shape, a_strides, input_dtype) + test_cases.append( + TestCase( + inputs=[a_spec, b_spec, w_spec], + kwargs={"epsilon": _EPSILON}, + output_specs=[y_spec, residual_out_spec], # Two outputs + comparison_target=None, + tolerance=tolerance, + output_count=2, # Two outputs: normalized_result and add_result + description=f"AddRMSNorm - OUT_OF_PLACE", + ) + ) + + # Test Case 2: In-place with explicit output tensors (add_rms_norm_(y, residual_out, a, b, w)) + if 
y_supports_inplace: + residual_out_spec = TensorSpec.from_tensor(a_shape, a_strides, input_dtype) + test_cases.append( + TestCase( + inputs=[a_spec, b_spec, w_spec], + kwargs={"epsilon": _EPSILON, "out": (y_spec, residual_out_spec)}, + output_specs=[y_spec, residual_out_spec], # Two outputs + comparison_target="out", + tolerance=tolerance, + output_count=2, + description=f"AddRMSNorm - INPLACE(out)", + ) + ) + + return test_cases + + +class OpTest(BaseOperatorTest): + """AddRMSNorm operator test with simplified implementation""" + + def __init__(self): + super().__init__("AddRMSNorm") + + def get_test_cases(self): + return parse_test_cases() + + def torch_operator(self, a, b, weight, epsilon=_EPSILON, out=None, **kwargs): + """PyTorch AddRMSNorm implementation - returns (normalized_result, add_result)""" + input_dtype = a.dtype + + # Compute add(a, b) + sum_tensor = a.to(torch.float32) + b.to(torch.float32) + weight_fp32 = weight.to(torch.float32) + + # Calculate RMSNorm: (a + b) * weight / sqrt(mean((a+b)^2) + epsilon) + variance = sum_tensor.pow(2).mean(-1, keepdim=True) + normalized_result = sum_tensor * torch.rsqrt(variance + epsilon) * weight_fp32 + + # Convert back to original dtype + normalized_result = normalized_result.to(input_dtype) + add_result = sum_tensor.to(input_dtype) + + if out is not None: + # For in-place operations, we need to handle the output tuple + if isinstance(out, (tuple, list)) and len(out) == 2: + out[0].copy_(normalized_result) + out[1].copy_(add_result) + return tuple(out) + else: + # Single output - just return normalized result for backward compatibility + out.copy_(normalized_result) + return out + + return (normalized_result, add_result) + + def infinicore_operator(self, a, b, weight, epsilon=_EPSILON, out=None, **kwargs): + """InfiniCore AddRMSNorm implementation - returns (normalized_result, add_result)""" + return infinicore.add_rms_norm(a, b, weight, epsilon, out=out) + + +def main(): + """Main entry point""" + runner = GenericTestRunner(OpTest) + runner.run_and_exit() + + +if __name__ == "__main__": + main() diff --git a/test/infiniop/add_rms_norm.py b/test/infiniop/add_rms_norm.py new file mode 100644 index 000000000..930314761 --- /dev/null +++ b/test/infiniop/add_rms_norm.py @@ -0,0 +1,185 @@ +import torch +import ctypes +from ctypes import c_uint64 +from libinfiniop import ( + LIBINFINIOP, + TestTensor, + get_test_devices, + check_error, + test_operator, + get_args, + debug, + get_tolerance, + profile_operation, + TestWorkspace, + InfiniDtype, + InfiniDtypeNames, + InfiniDeviceNames, + infiniopOperatorDescriptor_t, +) + +# ============================================================================== +# Configuration (Internal Use Only) +# ============================================================================== +# These are not meant to be imported from other modules +_TEST_CASES_ = [ + # y_shape, a_shape, b_shape, w_shape, y_stride, a_stride, b_stride + ((1, 4), (1, 4), (1, 4), (4,), None, None, None), + ((2, 4), (2, 4), (2, 4), (4,), None, None, None), + ((2, 2, 4), (2, 2, 4), (2, 2, 4), (4,), None, None, None), + ((2, 2, 4), (2, 2, 4), (2, 2, 4), (4,), (12, 8, 1), (12, 8, 1), (12, 8, 1)), + ((16, 2048), (16, 2048), (16, 2048), (2048,), None, None, None), + ((16, 2048), (16, 2048), (16, 2048), (2048,), (4096, 1), (4096, 1), (4096, 1)), + ((15, 3584), (15, 3584), (15, 3584), (3584,), None, None, None), + ((4, 4, 2048), (4, 4, 2048), (4, 4, 2048), (2048,), None, None, None), + ((4, 4, 2048), (4, 4, 2048), (4, 4, 2048), (2048,), (2048, 
8192, 1), (2048, 8192, 1), (2048, 8192, 1)), + ((4, 4, 2048), (4, 4, 2048), (4, 4, 2048), (2048,), (16384, 4096, 1), (16384, 4096, 1), (16384, 4096, 1)), + ((15, 3584), (15, 3584), (15, 3584), (3584,), None, None, None), + ((15, 8192), (15, 8192), (15, 8192), (8192,), None, None, None), +] + +# w (weight) types +# Note: 'None' means the same as input dtype +_WEIGHT_DTYPES = [None, InfiniDtype.F32, InfiniDtype.F16, InfiniDtype.BF16] +# a, b types used for testing +_TENSOR_DTYPES = [InfiniDtype.F16, InfiniDtype.BF16] + +# Form the test cases by appending each element of _WEIGHT_DTYPES to each tuple in _TEST_CASES_ +_TEST_CASES = [ + test_case + (w_dtype,) for test_case in _TEST_CASES_ for w_dtype in _WEIGHT_DTYPES +] + +# Tolerance map for different data types +_TOLERANCE_MAP = { + InfiniDtype.F16: {"atol": 2e-3, "rtol": 2e-3}, + InfiniDtype.BF16: {"atol": 1e-2, "rtol": 1e-2}, +} + +DEBUG = False +PROFILE = False +NUM_PRERUN = 10 +NUM_ITERATIONS = 1000 + + +def add_rms_norm(ans, a, b, w, eps): + input_dtype = a.dtype + # Compute add(a, b) + sum_tensor = a.to(torch.float32) + b.to(torch.float32) + # Compute RMS normalization + scale = sum_tensor.pow(2).mean(-1, keepdim=True).add_(eps).rsqrt_() + ans.set_((sum_tensor.mul_(scale).mul_(w.to(torch.float32))).to(input_dtype)) + + +def test( + handle, + device, + y_shape, + a_shape, + b_shape, + w_shape, + y_stride, + a_stride, + b_stride, + w_dtype=InfiniDtype.F32, + dtype=InfiniDtype.F16, + sync=None, +): + w_dtype = w_dtype if w_dtype else dtype + print( + f"Testing AddRMSNorm on {InfiniDeviceNames[device]} with y_shape:{y_shape} a_shape:{a_shape} b_shape:{b_shape} w_shape:{w_shape}" + f" y_stride:{y_stride} a_stride:{a_stride} b_stride:{b_stride} w_dtype:{InfiniDtypeNames[w_dtype]} dtype:{InfiniDtypeNames[dtype]}" + ) + + y = TestTensor(y_shape, y_stride, dtype, device, mode="ones") + residual_out = TestTensor(a_shape, a_stride, dtype, device, mode="ones") + a = TestTensor(a_shape, a_stride, dtype, device, scale=0.01) + b = TestTensor(b_shape, b_stride, dtype, device, scale=0.01) + w = TestTensor(w_shape, None, w_dtype, device) + + eps = 1e-6 + add_rms_norm(y.torch_tensor(), a.torch_tensor(), b.torch_tensor(), w.torch_tensor(), eps) + + if sync is not None: + sync() + + descriptor = infiniopOperatorDescriptor_t() + + check_error( + LIBINFINIOP.infiniopCreateAddRMSNormDescriptor( + handle, + ctypes.byref(descriptor), + y.descriptor, + a.descriptor, + b.descriptor, + w.descriptor, + eps, + residual_out.descriptor, + ) + ) + + # Invalidate the shape and strides in the descriptor to prevent them from being directly used by the kernel + for tensor in [a, b, y, w, residual_out]: + tensor.destroy_desc() + + workspace_size = c_uint64(0) + check_error( + LIBINFINIOP.infiniopGetAddRMSNormWorkspaceSize( + descriptor, ctypes.byref(workspace_size) + ) + ) + workspace = TestWorkspace(workspace_size.value, y.device) + + def lib_add_rms_norm(): + check_error( + LIBINFINIOP.infiniopAddRMSNorm( + descriptor, + workspace.data(), + workspace_size.value, + y.data(), + a.data(), + b.data(), + w.data(), + residual_out.data(), + None, + ) + ) + + lib_add_rms_norm() + + atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype) + + # Verify normalized result (y) + if DEBUG: + debug(y.actual_tensor(), y.torch_tensor(), atol=atol, rtol=rtol) + assert torch.allclose(y.actual_tensor(), y.torch_tensor(), atol=atol, rtol=rtol) + + # Verify add result (residual_out) - should be a + b + expected_residual = a.torch_tensor().to(torch.float32) + b.torch_tensor().to(torch.float32) + 
expected_residual = expected_residual.to(a.torch_tensor().dtype)
+    if DEBUG:
+        debug(residual_out.actual_tensor(), expected_residual, atol=atol, rtol=rtol)
+    assert torch.allclose(residual_out.actual_tensor(), expected_residual, atol=atol, rtol=rtol)
+
+    # Profiling workflow
+    if PROFILE:
+        # fmt: off
+        profile_operation("PyTorch", lambda: add_rms_norm(y.torch_tensor(), a.torch_tensor(), b.torch_tensor(), w.torch_tensor(), eps), device, NUM_PRERUN, NUM_ITERATIONS)
+        profile_operation(" lib", lambda: lib_add_rms_norm(), device, NUM_PRERUN, NUM_ITERATIONS)
+        # fmt: on
+    check_error(LIBINFINIOP.infiniopDestroyAddRMSNormDescriptor(descriptor))
+
+
+if __name__ == "__main__":
+    args = get_args()
+
+    # Configure testing options
+    DEBUG = args.debug
+    PROFILE = args.profile
+    NUM_PRERUN = args.num_prerun
+    NUM_ITERATIONS = args.num_iterations
+
+    # Execute tests
+    for device in get_test_devices(args):
+        test_operator(device, test, _TEST_CASES, _TENSOR_DTYPES)
+
+    print("\033[92mTest passed!\033[0m")
diff --git a/test/infiniop/libinfiniop/op_register.py b/test/infiniop/libinfiniop/op_register.py
index 5b2974111..a61d99189 100644
--- a/test/infiniop/libinfiniop/op_register.py
+++ b/test/infiniop/libinfiniop/op_register.py
@@ -383,6 +383,45 @@ def rms_norm_(lib):
     ]
 
 
+@OpRegister.operator
+def add_rms_norm_(lib):
+    lib.infiniopCreateAddRMSNormDescriptor.restype = c_int32
+    lib.infiniopCreateAddRMSNormDescriptor.argtypes = [
+        infiniopHandle_t,
+        POINTER(infiniopOperatorDescriptor_t),
+        infiniopTensorDescriptor_t,  # y_desc
+        infiniopTensorDescriptor_t,  # a_desc
+        infiniopTensorDescriptor_t,  # b_desc
+        infiniopTensorDescriptor_t,  # weight_desc
+        c_float,  # epsilon
+        infiniopTensorDescriptor_t,  # residual_out_desc
+    ]
+
+    lib.infiniopGetAddRMSNormWorkspaceSize.restype = c_int32
+    lib.infiniopGetAddRMSNormWorkspaceSize.argtypes = [
+        infiniopOperatorDescriptor_t,
+        POINTER(c_size_t),
+    ]
+
+    lib.infiniopAddRMSNorm.restype = c_int32
+    lib.infiniopAddRMSNorm.argtypes = [
+        infiniopOperatorDescriptor_t,
+        c_void_p,  # workspace
+        c_size_t,  # workspace_size
+        c_void_p,  # y
+        c_void_p,  # a
+        c_void_p,  # b
+        c_void_p,  # weight
+        c_void_p,  # residual_out
+        c_void_p,  # stream
+    ]
+
+    lib.infiniopDestroyAddRMSNormDescriptor.restype = c_int32
+    lib.infiniopDestroyAddRMSNormDescriptor.argtypes = [
+        infiniopOperatorDescriptor_t,
+    ]
+
+
 @OpRegister.operator
 def rope_(lib):
     lib.infiniopCreateRoPEDescriptor.restype = c_int32
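
Reviewer note: as a quick cross-check of the fused semantics this patch adds (y = RMSNorm(a + b) * weight and residual_out = a + b, accumulated in float32 and cast back to the activation dtype), here is a small PyTorch-only reference sketch. It restates the reference helpers already present in test/infiniop/add_rms_norm.py and test/infinicore/ops/add_rms_norm.py; the function and variable names below are the reviewer's own illustration and are not part of the patch.

import torch

def add_rms_norm_reference(a, b, weight, epsilon=1e-5):
    # residual_out = a + b, computed in float32 as in the test helpers
    residual = a.to(torch.float32) + b.to(torch.float32)
    # y = (a + b) * weight / sqrt(mean((a + b)^2 over the last dim) + epsilon)
    inv_rms = torch.rsqrt(residual.pow(2).mean(-1, keepdim=True) + epsilon)
    y = residual * inv_rms * weight.to(torch.float32)
    # cast both outputs back to the activation dtype, matching the operator's outputs
    return y.to(a.dtype), residual.to(a.dtype)

# Example using shapes from the test cases above: (16, 2048) activations, (2048,) weight.
a = torch.randn(16, 2048, dtype=torch.float16)
b = torch.randn(16, 2048, dtype=torch.float16)
w = torch.randn(2048, dtype=torch.float32)
y, residual = add_rms_norm_reference(a, b, w)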