From ed1a21b42bfbf03f37a66ca14278f1a316d85b7b Mon Sep 17 00:00:00 2001 From: zhangyue Date: Mon, 22 Dec 2025 17:20:43 +0800 Subject: [PATCH 1/4] issue/826: kunlun layernorm --- src/infiniop/ops/layer_norm/kunlun/kernel.h | 77 ++++ .../ops/layer_norm/kunlun/layer_norm_kunlun.h | 7 + .../layer_norm/kunlun/layer_norm_kunlun.xpu | 428 ++++++++++++++++++ src/infiniop/ops/layer_norm/operator.cc | 15 + src/infiniop/reduce/kunlun/reduce_kunlun.h | 22 + test/infiniop/layer_norm.py | 67 ++- 6 files changed, 580 insertions(+), 36 deletions(-) create mode 100644 src/infiniop/ops/layer_norm/kunlun/kernel.h create mode 100644 src/infiniop/ops/layer_norm/kunlun/layer_norm_kunlun.h create mode 100644 src/infiniop/ops/layer_norm/kunlun/layer_norm_kunlun.xpu diff --git a/src/infiniop/ops/layer_norm/kunlun/kernel.h b/src/infiniop/ops/layer_norm/kunlun/kernel.h new file mode 100644 index 000000000..c58f3fbdb --- /dev/null +++ b/src/infiniop/ops/layer_norm/kunlun/kernel.h @@ -0,0 +1,77 @@ +#ifndef __LAYER_NORM_KUNLUN_KERNEL_H__ +#define __LAYER_NORM_KUNLUN_KERNEL_H__ + +#include "../../../devices/kunlun/kunlun_kernel_common.h" +#include "../../../reduce/kunlun/reduce_kunlun.h" + +using namespace device::kunlun::kernel; + +// Calculate norm in BLOCK_SIZE cores in one cluster. Useful for long normalized_size +template +__device__ void layerNormCluster( + __shared_ptr__ Tdata *y, + __shared_ptr__ Tdata *output_standardized, + __shared_ptr__ Tdata *output_std_deviation, + __shared_ptr__ const Tdata *input, + __shared_ptr__ const Tdata *weight, + __shared_ptr__ const Tdata *bias, + float eps, + int32_t normalized_size, + bool bias_exist) { + + // Block reduce sum of x^2 + Tcompute mean = op::common_kunlun::reduce_op:: + sum(input, normalized_size) + / normalized_size; + Tcompute sum_squared = op::common_kunlun::reduce_op:: + sumSquared(input, normalized_size); + Tcompute var = sum_squared / normalized_size - mean * mean; + // Compute rsqrt variance + epsilon + Tcompute rstd = Tcompute(1.0f) / sqrt(var + Tcompute(eps)); + + // Write to output_std_deviation + if (core_id() == 0) { + *output_std_deviation = static_cast(rstd); + } + sync_cluster(); + + for (int32_t i = core_id(); i < normalized_size; i += BLOCK_SIZE) { + Tcompute x_standard = (Tcompute(input[i]) - mean) * rstd; + output_standardized[i] = static_cast(x_standard); + y[i] = static_cast(x_standard * Tcompute(weight[i]) + (bias_exist ? Tcompute(bias[i]) : Tcompute(0))); + } + sync_cluster(); +} + +// Calculate norm in single core. 
Useful for short normalized_size +template +__device__ void layerNormBlock( + __local__ Tdata *output, + __local__ Tdata *output_standardization, + __local__ Tdata *output_rstd_deviation, + __local__ const Tdata *input, + __shared_ptr__ const Tdata *weight, + __shared_ptr__ const Tdata *bias, + float eps, + int32_t normalized_size, + bool bias_exist) { + + // Block reduce sum of x^2 + Tcompute mean = op::common_kunlun::reduce_op::blockSum(input, normalized_size) + / normalized_size; + Tcompute sum_squared = op::common_kunlun::reduce_op::blockSumSquared(input, normalized_size); + Tcompute var = sum_squared / normalized_size - mean * mean; + // Compute rsqrt variance + epsilon + Tcompute rstd = Tcompute(1.0f) / sqrt(var + Tcompute(eps)); + // Write to output_rstd_deviation + *output_rstd_deviation = static_cast(rstd); + + for (int32_t i = 0; i < normalized_size; i += 1) { + Tcompute x_standard = (Tcompute(input[i]) - mean) * rstd; + output_standardization[i] = static_cast(x_standard); + output[i] = static_cast(x_standard * Tcompute(weight[i]) + (bias_exist ? Tcompute(bias[i]) : Tcompute(0))); + } + mfence(); +} + +#endif diff --git a/src/infiniop/ops/layer_norm/kunlun/layer_norm_kunlun.h b/src/infiniop/ops/layer_norm/kunlun/layer_norm_kunlun.h new file mode 100644 index 000000000..5bccad498 --- /dev/null +++ b/src/infiniop/ops/layer_norm/kunlun/layer_norm_kunlun.h @@ -0,0 +1,7 @@ +#ifndef __LAYER_NORM_KUNLUN_API_H__ +#define __LAYER_NORM_KUNLUN_API_H__ +#include "../layer_norm.h" + +DESCRIPTOR(kunlun) + +#endif // __LAYER_NORM_KUNLUN_API_H__ diff --git a/src/infiniop/ops/layer_norm/kunlun/layer_norm_kunlun.xpu b/src/infiniop/ops/layer_norm/kunlun/layer_norm_kunlun.xpu new file mode 100644 index 000000000..2f67d1d3d --- /dev/null +++ b/src/infiniop/ops/layer_norm/kunlun/layer_norm_kunlun.xpu @@ -0,0 +1,428 @@ +#include "../../../devices/kunlun/kunlun_common.h" +#include "../../../devices/kunlun/kunlun_handle.h" +#include "../../../devices/kunlun/kunlun_kernel_common.h" +#include "kernel.h" +#include "layer_norm_kunlun.h" + +template +__global__ void layerNormKernel( + int32_t loop_idx, + Tdata *output, // [b, seq, dim] + Tdata *output_standardization, // [b, seq, dim] + Tdata *output_rstd_deviation, // [b, seq] + const Tdata *input, // [b, seq, dim] + const Tdata *weight, // [dim] + const Tdata *bias, // [None | dim] + float eps, + int32_t normalized_size, + int32_t othersize, + const void *shape, // [ndim] + const void *output_strides, // [ndim] + const void *output_standardization_strides, // [ndim] + const void *output_rstd_deviation_strides, // [ndim - 1] + const void *input_strides, // [ndim] + int32_t weight_strides, + int32_t bias_strides, + int32_t ndim, + bool bias_exist) { + + // Shape and strides + __local__ _size_t shape_local[ndim]; + __local__ _ptrdiff_t output_strides_local[ndim]; + __local__ _ptrdiff_t output_standardization_strides_local[ndim]; + __local__ _ptrdiff_t output_rstd_deviation_strides_local[ndim - 1]; + __local__ _ptrdiff_t input_strides_local[ndim]; + + // Load shape and strides from global memory to local memory + GM2LM_ASYNC(shape, shape_local, ndim * sizeof(_size_t)); + GM2LM_ASYNC(output_strides, output_strides_local, ndim * sizeof(_ptrdiff_t)); + GM2LM_ASYNC(output_standardization_strides, output_standardization_strides_local, ndim * sizeof(_ptrdiff_t)); + GM2LM_ASYNC(output_rstd_deviation_strides, output_rstd_deviation_strides_local, (ndim - 1) * sizeof(_ptrdiff_t)); + GM2LM_ASYNC(input_strides, input_strides_local, ndim * sizeof(_ptrdiff_t)); + mfence(); + 
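+    // Tiling scheme for the cluster path: each cluster normalizes one row of
+    // length normalized_size. block_idx (below) selects the row, its
+    // multi-dimensional coordinates are decoded from shape_local, and the
+    // per-tensor element offsets are accumulated from the stride arrays.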
+ // Calculate block tile coordinates + int32_t block_idx = cluster_id() + loop_idx * cluster_num(); + if (block_idx >= othersize) { + return; + } + + int32_t t_coords[ndim - 1]; + int32_t temp_block_idx = block_idx; + for (int i = ndim - 2; i >= 0; i--) { + int32_t dim_i = shape_local[i].value; + t_coords[i] = temp_block_idx % dim_i; + temp_block_idx /= dim_i; + } + + // Calculate offsets + int32_t offset_output = 0; + int32_t offset_output_standardization = 0; + int32_t offset_output_rstd_deviation = 0; + int32_t offset_input = 0; + for (int i = 0; i < ndim - 1; i++) { + int32_t dim_i = shape_local[i].value; + offset_output += t_coords[i] * output_strides_local[i].value; + offset_output_standardization += t_coords[i] * output_standardization_strides_local[i].value; + offset_output_rstd_deviation += t_coords[i] * output_rstd_deviation_strides_local[i].value; + offset_input += t_coords[i] * input_strides_local[i].value; + } + + // Shared memory allocation + __shared__ Tdata input_sm[SM_SIZE / sizeof(Tdata)]; + __shared__ Tdata weight_sm[SM_SIZE / sizeof(Tdata)]; + __shared__ Tdata bias_sm[SM_SIZE / sizeof(Tdata)]; + __shared__ Tdata output_sm[SM_SIZE / sizeof(Tdata)]; + __shared__ Tdata output_standardization_sm[SM_SIZE / sizeof(Tdata)]; + __shared__ Tdata output_rstd_sm[1]; + + // Copy data to shared memory + if (weight_strides == 1 && bias_strides == 1 && core_id() == 0) { + GM2SM_ASYNC(input + offset_input, input_sm, normalized_size * sizeof(Tdata)); + GM2SM_ASYNC(weight, weight_sm, normalized_size * sizeof(Tdata)); + if (bias_exist) { + GM2SM_ASYNC(bias, bias_sm, normalized_size * sizeof(Tdata)); + } + } else { + for (int32_t i = core_id(); i < normalized_size; i += BLOCK_SIZE) { + // Load input + GM2SM_ASYNC(input + offset_input + i * input_strides_local[ndim - 1].value, input_sm + i, sizeof(Tdata)); + // Load weight + GM2SM_ASYNC(weight + i * weight_strides, weight_sm + i, sizeof(Tdata)); + // Load bias + if (bias_exist) { + GM2SM_ASYNC(bias + i * bias_strides, bias_sm + i, sizeof(Tdata)); + } + } + } + sync_cluster(); + + // Compute layer norm in each cluster + layerNormCluster( + output_sm, + output_standardization_sm, + output_rstd_sm, + input_sm, + weight_sm, + bias_sm, + eps, + normalized_size, + bias_exist); + sync_cluster(); + + // Copy results back to global memory + if (core_id() == 0) { + SM2GM_ASYNC(output_sm, output + offset_output, normalized_size * sizeof(Tdata)); + SM2GM_ASYNC(output_standardization_sm, output_standardization + offset_output_standardization, normalized_size * sizeof(Tdata)); + SM2GM_ASYNC(output_rstd_sm, output_rstd_deviation + offset_output_rstd_deviation, sizeof(Tdata)); + } + sync_cluster(); +} + +template +__global__ void layerNormKernelV2( + Tdata *output, + Tdata *output_standardization, + Tdata *output_rstd_deviation, + const Tdata *input, + const Tdata *weight, // [dim] + const Tdata *bias, // [None | dim] + float eps, + int32_t normalized_size, + int32_t othersize, + const void *shape, // [ndim] + const void *output_strides, // [ndim] + const void *output_standardization_strides, // [ndim] + const void *output_rstd_deviation_strides, // [ndim - 1] + const void *input_strides, // [ndim] + int32_t weight_strides, + int32_t bias_strides, + int32_t ndim, + bool bias_exist) { + + int32_t cid = core_id(); + int32_t ncores = core_num(); + int32_t thread_idx = ncores * cluster_id() + cid; + int32_t nthreads = ncores * cluster_num(); + + if (thread_idx >= othersize) { + return; + } + + // Shape and strides + __local__ _size_t shape_local[ndim]; + 
__local__ _ptrdiff_t output_strides_local[ndim]; + __local__ _ptrdiff_t output_standardization_strides_local[ndim]; + __local__ _ptrdiff_t output_rstd_deviation_strides_local[ndim - 1]; + __local__ _ptrdiff_t input_strides_local[ndim]; + + // Load shape and strides from global memory to local memory + GM2LM_ASYNC(shape, shape_local, ndim * sizeof(_size_t)); + GM2LM_ASYNC(output_strides, output_strides_local, ndim * sizeof(_ptrdiff_t)); + GM2LM_ASYNC(output_standardization_strides, output_standardization_strides_local, ndim * sizeof(_ptrdiff_t)); + GM2LM_ASYNC(output_rstd_deviation_strides, output_rstd_deviation_strides_local, (ndim - 1) * sizeof(_ptrdiff_t)); + GM2LM_ASYNC(input_strides, input_strides_local, ndim * sizeof(_ptrdiff_t)); + mfence(); + + // Allocate local memory for i/o + __local__ Tdata input_local[normalized_size]; + __local__ Tdata output_local[normalized_size]; + __local__ Tdata output_standardization_local[normalized_size]; + __local__ Tdata output_rstd[1]; + + // Allocal Shared memory for weight and bias + __shared__ Tdata weight_sm[SM_SIZE / sizeof(Tdata)]; + __shared__ Tdata bias_sm[SM_SIZE / sizeof(Tdata)]; + + // Every normalization vector as a tile to be processed + for (int32_t tid = thread_idx; tid < othersize; tid += nthreads) { + // Calculate tile coordinates + int32_t t_coords[ndim - 1]; + int32_t temp_tid = tid; + for (int i = 0; i < ndim - 1; i++) { + int32_t dim_i = shape_local[i].value; + t_coords[i] = temp_tid % dim_i; + temp_tid /= dim_i; + } + + // Calculate offsets + int32_t offset_output = 0; + int32_t offset_output_standardization = 0; + int32_t offset_output_rstd_deviation = 0; + int32_t offset_input = 0; + for (int i = 0; i < ndim - 1; i++) { + int32_t dim_i = shape_local[i].value; + offset_output += t_coords[i] * output_strides_local[i].value; + offset_output_standardization += t_coords[i] * output_standardization_strides_local[i].value; + offset_output_rstd_deviation += t_coords[i] * output_rstd_deviation_strides_local[i].value; + offset_input += t_coords[i] * input_strides_local[i].value; + } + + // Load input into local memory + GM2LM_ASYNC(input + offset_input, input_local, normalized_size * sizeof(Tdata)); + + // Load weight and bias to shared memory + if (core_id() == 0) { + if (weight_strides == 1 && bias_strides == 1) { + GM2SM_ASYNC(weight, weight_sm, normalized_size * sizeof(Tdata)); + if (bias_exist) { + GM2SM_ASYNC(bias, bias_sm, normalized_size * sizeof(Tdata)); + } + } else { + for (int32_t i = 0; i < normalized_size; i++) { + // Load weight + GM2SM_ASYNC(weight + i * weight_strides, weight_sm + i, sizeof(Tdata)); + // Load bias + if (bias_exist) { + GM2SM_ASYNC(bias + i * bias_strides, bias_sm + i, sizeof(Tdata)); + } + } + } + } + sync_cluster(); + + // Compute layer norm in each core + layerNormBlock( + output_local, + output_standardization_local, + output_rstd, + input_local, + weight_sm, + bias_sm, + eps, + normalized_size, + bias_exist); + mfence(); + + // Copy result into global memory + LM2GM_ASYNC(output_local, output + offset_output, normalized_size * sizeof(Tdata)); + LM2GM_ASYNC(output_standardization_local, output_standardization + offset_output_standardization, normalized_size * sizeof(Tdata)); + LM2GM_ASYNC(output_rstd, output_rstd_deviation + offset_output_rstd_deviation, sizeof(Tdata)); + mfence(); + } +} + +namespace op::layer_norm::kunlun { + +template +infiniStatus_t launchLayerNormKernel( + const LayerNormInfo &info, + Tdata *output, + Tdata *output_standardization, + Tdata *output_rstd_deviation, + const Tdata 
*input, + const Tdata *weight, + const Tdata *bias, + kunlunStream_t stream, + void *workspace) { + + size_t ndim = info.ndim; + char *workspace_ptr = reinterpret_cast(workspace); + + // Prepare strides and shape pointer in kunlun device memory + ptrdiff_t *input_strides_kunlun = reinterpret_cast(workspace_ptr); + ptrdiff_t *output_strides_kunlun = input_strides_kunlun + ndim; + ptrdiff_t *output_standardization_strides_kunlun = output_strides_kunlun + ndim; + ptrdiff_t *output_rstd_deviation_strides_kunlun = output_standardization_strides_kunlun + ndim; + + size_t ptrdiff_array_size = (4 * ndim - 1) * sizeof(ptrdiff_t); + size_t *shape_kunlun = reinterpret_cast(workspace_ptr + ptrdiff_array_size); + + // Copy strides and shape to kunlun device memory + CHECK_KUNLUN(xpu_memcpy_async(input_strides_kunlun, info.input_strides.data(), + ndim * sizeof(ptrdiff_t), XPU_HOST_TO_DEVICE, stream)); + CHECK_KUNLUN(xpu_memcpy_async(output_strides_kunlun, info.output_strides.data(), + ndim * sizeof(ptrdiff_t), XPU_HOST_TO_DEVICE, stream)); + CHECK_KUNLUN(xpu_memcpy_async(output_standardization_strides_kunlun, info.input_standardization_strides.data(), + ndim * sizeof(ptrdiff_t), XPU_HOST_TO_DEVICE, stream)); + CHECK_KUNLUN(xpu_memcpy_async(output_rstd_deviation_strides_kunlun, info.input_std_deviation_strides.data(), + (ndim - 1) * sizeof(ptrdiff_t), XPU_HOST_TO_DEVICE, stream)); + + CHECK_KUNLUN(xpu_memcpy_async(shape_kunlun, info.input_shape.data(), + ndim * sizeof(size_t), XPU_HOST_TO_DEVICE, stream)); + + int32_t normalized_size = static_cast(info.normalized_size); + // Launch kernel + if (normalized_size <= 512) { + layerNormKernelV2 + <<<12, BLOCK_SIZE, stream>>>( + output, + output_standardization, + output_rstd_deviation, + input, + weight, + bias, + info.eps, + normalized_size, + static_cast(info.othersize), + reinterpret_cast<__global_ptr__ const void *>(shape_kunlun), + reinterpret_cast<__global_ptr__ const void *>(output_strides_kunlun), + reinterpret_cast<__global_ptr__ const void *>(output_standardization_strides_kunlun), + reinterpret_cast<__global_ptr__ const void *>(output_rstd_deviation_strides_kunlun), + reinterpret_cast<__global_ptr__ const void *>(input_strides_kunlun), + static_cast(info.weight_strides[0]), + static_cast(info.bias_exist ? info.bias_strides[0] : 0), + static_cast(ndim), + info.bias_exist); + } else { + int32_t num_blocks = static_cast(info.othersize) > 255 + ? 255 + : static_cast(info.othersize); + int32_t num_loops = (static_cast(info.othersize) + num_blocks - 1) / num_blocks; + for (int32_t i = 0; i < num_loops; i++) { + layerNormKernel + <<>>( + i, + output, + output_standardization, + output_rstd_deviation, + input, + weight, + bias, + info.eps, + normalized_size, + static_cast(info.othersize), + reinterpret_cast<__global_ptr__ const void *>(shape_kunlun), + reinterpret_cast<__global_ptr__ const void *>(output_strides_kunlun), + reinterpret_cast<__global_ptr__ const void *>(output_standardization_strides_kunlun), + reinterpret_cast<__global_ptr__ const void *>(output_rstd_deviation_strides_kunlun), + reinterpret_cast<__global_ptr__ const void *>(input_strides_kunlun), + static_cast(info.weight_strides[0]), + static_cast(info.bias_exist ? 
info.bias_strides[0] : 0), + static_cast(ndim), + info.bias_exist); + } + } + return INFINI_STATUS_SUCCESS; +} + +typedef device::kunlun::Handle::Internal HandleInternal; + +struct Descriptor::Opaque { + std::shared_ptr internal; +}; + +Descriptor::~Descriptor() { + delete _opaque; +} + +infiniStatus_t Descriptor::create( + infiniopHandle_t handle_, + Descriptor **desc_ptr, + infiniopTensorDescriptor_t output_desc, + infiniopTensorDescriptor_t input_standardization_desc, + infiniopTensorDescriptor_t input_std_deviation_desc, + infiniopTensorDescriptor_t input_desc, + infiniopTensorDescriptor_t weight_desc, + infiniopTensorDescriptor_t bias_desc, + float eps) { + + auto handle = reinterpret_cast(handle_); + auto dtype = input_desc->dtype(); + auto ndim = input_desc->ndim(); + CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_BF16); + size_t workspace_size = (ndim * 4 - 1) * sizeof(ptrdiff_t) + + ndim * sizeof(size_t); + + auto result = LayerNormInfo::createLayerNormInfo( + output_desc, + input_standardization_desc, + input_std_deviation_desc, + input_desc, + weight_desc, + bias_desc, + eps); + CHECK_RESULT(result); + + *desc_ptr = new Descriptor( + dtype, std::move(result.take()), workspace_size, + new Opaque{handle->internal()}, + handle->device, handle->device_id); + return INFINI_STATUS_SUCCESS; +} + +infiniStatus_t Descriptor::calculate( + void *workspace, + size_t workspace_size, + void *output, + void *input_standardization, + void *input_std_deviation, + const void *input, + const void *weight, + const void *bias, + void *stream_) const { + if (workspace_size < workspaceSize()) { + return INFINI_STATUS_INSUFFICIENT_WORKSPACE; + } + + kunlunStream_t stream = (kunlunStream_t)stream_; + +#define DISPATCH_KERNEL(Tdata) \ + return kunlun::launchLayerNormKernel( \ + _info, (Tdata *)output, \ + (Tdata *)input_standardization, \ + (Tdata *)input_std_deviation, \ + (const Tdata *)input, \ + (const Tdata *)weight, \ + (const Tdata *)bias, \ + stream, \ + workspace); + + const LayerNormInfo &info = this->_info; + const unsigned int BLOCK_SIZE = 64; + + if (_info.dtype == INFINI_DTYPE_F16) { + DISPATCH_KERNEL(half); + } else if (_info.dtype == INFINI_DTYPE_BF16) { + DISPATCH_KERNEL(bfloat16_t); + } else if (_info.dtype == INFINI_DTYPE_F32) { + DISPATCH_KERNEL(float); + } else { + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } +#undef DISPATCH_KERNEL + + return INFINI_STATUS_SUCCESS; +} + +} // namespace op::layer_norm::kunlun diff --git a/src/infiniop/ops/layer_norm/operator.cc b/src/infiniop/ops/layer_norm/operator.cc index 3dbbdcb21..430d3200e 100644 --- a/src/infiniop/ops/layer_norm/operator.cc +++ b/src/infiniop/ops/layer_norm/operator.cc @@ -11,6 +11,9 @@ #ifdef ENABLE_METAX_API #include "metax/layer_norm_metax.h" #endif +#ifdef ENABLE_KUNLUN_API +#include "kunlun/layer_norm_kunlun.h" +#endif __C infiniStatus_t infiniopCreateLayerNormDescriptor( infiniopHandle_t handle, @@ -52,6 +55,9 @@ __C infiniStatus_t infiniopCreateLayerNormDescriptor( #ifdef ENABLE_METAX_API CREATE(INFINI_DEVICE_METAX, metax); #endif +#ifdef ENABLE_KUNLUN_API + CREATE(INFINI_DEVICE_KUNLUN, kunlun); +#endif default: return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; @@ -81,6 +87,9 @@ __C infiniStatus_t infiniopGetLayerNormWorkspaceSize(infiniopLayerNormDescriptor #endif #ifdef ENABLE_METAX_API GET(INFINI_DEVICE_METAX, metax); +#endif +#ifdef ENABLE_KUNLUN_API + GET(INFINI_DEVICE_KUNLUN, kunlun); #endif default: return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; @@ -132,6 +141,9 @@ __C infiniStatus_t 
infiniopLayerNorm( #ifdef ENABLE_METAX_API CALCULATE(INFINI_DEVICE_METAX, metax); #endif +#ifdef ENABLE_KUNLUN_API + CALCULATE(INFINI_DEVICE_KUNLUN, kunlun); +#endif default: return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; @@ -162,6 +174,9 @@ infiniopDestroyLayerNormDescriptor(infiniopLayerNormDescriptor_t desc) { #ifdef ENABLE_METAX_API DELETE(INFINI_DEVICE_METAX, metax); #endif +#ifdef ENABLE_KUNLUN_API + DELETE(INFINI_DEVICE_KUNLUN, kunlun); +#endif default: return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; diff --git a/src/infiniop/reduce/kunlun/reduce_kunlun.h b/src/infiniop/reduce/kunlun/reduce_kunlun.h index 34f23829b..3a3b78198 100644 --- a/src/infiniop/reduce/kunlun/reduce_kunlun.h +++ b/src/infiniop/reduce/kunlun/reduce_kunlun.h @@ -73,6 +73,28 @@ __device__ inline Tdata max(__shared_ptr__ const Tdata *data_ptr, size_t count) return temp_storage; } +// Sum(x) in single block +template +__device__ inline Tcompute blockSum(__local__ const Tdata *data_ptr, size_t count) { + Tcompute ss = 0; + for (size_t i = 0; i < count; i += 1) { + Tdata xi = data_ptr[i]; + ss += Tcompute(xi); + } + return ss; +} + +// Sum(x^2) in single block +template +__device__ inline Tcompute blockSumSquared(__local__ const Tdata *data_ptr, size_t count) { + Tcompute ss = 0; + for (size_t i = 0; i < count; i += 1) { + Tdata xi = data_ptr[i]; + ss += Tcompute(xi) * Tcompute(xi); + } + return ss; +} + } // namespace op::common_kunlun::reduce_op #endif diff --git a/test/infiniop/layer_norm.py b/test/infiniop/layer_norm.py index 85a360248..d1ae6ce20 100644 --- a/test/infiniop/layer_norm.py +++ b/test/infiniop/layer_norm.py @@ -71,48 +71,41 @@ class Inplace(Enum): def torch_layer_norm( output: torch.Tensor, - input_standardization: torch.Tensor, - input_std_deviation: torch.Tensor, + output_standardization: torch.Tensor, + output_rstd: torch.Tensor, input: torch.Tensor, weight, bias, eps, bias_exist: bool, ): - normalized_shape = input.shape[-1:] - ln = torch.nn.LayerNorm( - normalized_shape=normalized_shape, - eps=eps, - dtype=torch.float, - bias=bias_exist, - device=input.device, - ) - ln.weight.data = weight.type(torch.float) - if bias_exist: - ln.bias.data = bias.type(torch.float) - input = input.type(torch.float) - mean = input.mean(dim=-1, keepdim=True) - var = input.var(dim=-1, correction=0) - std = torch.sqrt(var + eps) - input_standardization.copy_( - ((input - mean) / std.unsqueeze(2)).type(input_standardization.dtype) - ) - input_std_deviation.copy_(std.type(input_standardization.dtype)) - output.copy_(ln(input).detach().type(output.dtype)) + original_dtype = input.dtype + input_f32 = input.to(torch.float32) + weight_f32 = weight.to(torch.float32) + bias_f32 = bias.to(torch.float32) if bias_exist and bias is not None else None -def layer_norm( - output: torch.Tensor, input: torch.Tensor, weight, bias, eps, bias_exist: bool -): - normalized_shape = input.shape[-1:] - ln = torch.nn.LayerNorm( - normalized_shape=normalized_shape, eps=eps, bias=bias_exist, device=input.device - ) + mean = input_f32.mean(dim=-1, keepdim=True) # [..., 1] + # print("mean in torch", mean) + var = input_f32.var(dim=-1, keepdim=True, correction=0) # [..., 1] + # print("var in torch", var) - ln.weight.data = weight - if bias_exist: - ln.bias.data = bias - output.copy_(ln.forward(input).detach().type(output.dtype)) + rstd = torch.rsqrt(var + eps) # [..., 1] + # print("rstd in torch", rstd) + centered_input = input_f32 - mean # [..., D] + normalized = centered_input * rstd # [..., D] + + # y = normalized * weight + bias + output_f32 = 
normalized * weight_f32 + if bias_exist and bias_f32 is not None: + output_f32 = output_f32 + bias_f32 + + # 写回最终输出(转回原始 dtype) + output.copy_(output_f32.to(original_dtype).detach()) + # 写入中间输出:centered input (x - μ) + output_standardization.copy_(normalized.to(output_standardization.dtype).detach()) + # 写入中间输出:rstd,注意去掉最后的 keepdim 维度(从 [..., 1] → [...]) + output_rstd.copy_(rstd.squeeze(-1).to(output_rstd.dtype).detach()) def test( @@ -148,7 +141,7 @@ def test( device, ) - input = TestTensor(input_shape, input_strides, dtype, device, mode="zeros") + input = TestTensor(input_shape, input_strides, dtype, device) if inplace == Inplace.INPLACE: if output_strides != input_strides: return @@ -179,8 +172,10 @@ def test( else None ) - layer_norm( + torch_layer_norm( output.torch_tensor(), + input_standardization.torch_tensor(), + input_std_deviation.torch_tensor(), input.torch_tensor(), weight.torch_tensor(), bias.torch_tensor() if bias_exist else None, @@ -295,4 +290,4 @@ def lib_layer_norm(): for device in get_test_devices(args): test_operator(device, test, _TEST_CASES, _TENSOR_DTYPES) - print("\033[92mTest my layer_norm passed!\033[0m") + print("\033[92mTest layer_norm passed!\033[0m") From 37cd80217df851599eeb9b035616a6447aa55a31 Mon Sep 17 00:00:00 2001 From: zhangyue Date: Tue, 23 Dec 2025 15:53:05 +0800 Subject: [PATCH 2/4] issue/826: modify according comments --- .../layer_norm/kunlun/layer_norm_kunlun.xpu | 235 +++++++++--------- 1 file changed, 112 insertions(+), 123 deletions(-) diff --git a/src/infiniop/ops/layer_norm/kunlun/layer_norm_kunlun.xpu b/src/infiniop/ops/layer_norm/kunlun/layer_norm_kunlun.xpu index 2f67d1d3d..826e55e69 100644 --- a/src/infiniop/ops/layer_norm/kunlun/layer_norm_kunlun.xpu +++ b/src/infiniop/ops/layer_norm/kunlun/layer_norm_kunlun.xpu @@ -6,21 +6,21 @@ template __global__ void layerNormKernel( - int32_t loop_idx, - Tdata *output, // [b, seq, dim] - Tdata *output_standardization, // [b, seq, dim] - Tdata *output_rstd_deviation, // [b, seq] - const Tdata *input, // [b, seq, dim] - const Tdata *weight, // [dim] - const Tdata *bias, // [None | dim] + // int32_t loop_idx, + Tdata *output, + Tdata *output_standardization, + Tdata *output_rstd_deviation, + const Tdata *input, + const Tdata *weight, + const Tdata *bias, float eps, int32_t normalized_size, int32_t othersize, - const void *shape, // [ndim] - const void *output_strides, // [ndim] - const void *output_standardization_strides, // [ndim] - const void *output_rstd_deviation_strides, // [ndim - 1] - const void *input_strides, // [ndim] + const void *shape, + const void *output_strides, + const void *output_standardization_strides, + const void *output_rstd_deviation_strides, + const void *input_strides, int32_t weight_strides, int32_t bias_strides, int32_t ndim, @@ -41,82 +41,82 @@ __global__ void layerNormKernel( GM2LM_ASYNC(input_strides, input_strides_local, ndim * sizeof(_ptrdiff_t)); mfence(); - // Calculate block tile coordinates - int32_t block_idx = cluster_id() + loop_idx * cluster_num(); - if (block_idx >= othersize) { - return; - } - - int32_t t_coords[ndim - 1]; - int32_t temp_block_idx = block_idx; - for (int i = ndim - 2; i >= 0; i--) { - int32_t dim_i = shape_local[i].value; - t_coords[i] = temp_block_idx % dim_i; - temp_block_idx /= dim_i; - } - - // Calculate offsets - int32_t offset_output = 0; - int32_t offset_output_standardization = 0; - int32_t offset_output_rstd_deviation = 0; - int32_t offset_input = 0; - for (int i = 0; i < ndim - 1; i++) { - int32_t dim_i = shape_local[i].value; - 
offset_output += t_coords[i] * output_strides_local[i].value; - offset_output_standardization += t_coords[i] * output_standardization_strides_local[i].value; - offset_output_rstd_deviation += t_coords[i] * output_rstd_deviation_strides_local[i].value; - offset_input += t_coords[i] * input_strides_local[i].value; - } + int32_t num_clusters = cluster_num(); + int32_t num_loops = (othersize + num_clusters - 1) / num_clusters; + for (int32_t loop_idx = 0; loop_idx < num_loops; ++loop_idx) { + // Calculate block tile coordinates + int32_t block_idx = cluster_id() + loop_idx * cluster_num(); + if (block_idx >= othersize) { + return; + } + + int32_t temp_block_idx = block_idx; - // Shared memory allocation - __shared__ Tdata input_sm[SM_SIZE / sizeof(Tdata)]; - __shared__ Tdata weight_sm[SM_SIZE / sizeof(Tdata)]; - __shared__ Tdata bias_sm[SM_SIZE / sizeof(Tdata)]; - __shared__ Tdata output_sm[SM_SIZE / sizeof(Tdata)]; - __shared__ Tdata output_standardization_sm[SM_SIZE / sizeof(Tdata)]; - __shared__ Tdata output_rstd_sm[1]; - - // Copy data to shared memory - if (weight_strides == 1 && bias_strides == 1 && core_id() == 0) { - GM2SM_ASYNC(input + offset_input, input_sm, normalized_size * sizeof(Tdata)); - GM2SM_ASYNC(weight, weight_sm, normalized_size * sizeof(Tdata)); - if (bias_exist) { - GM2SM_ASYNC(bias, bias_sm, normalized_size * sizeof(Tdata)); + // Decode multi-dimensional coordinates and accumulate offsets in one pass (reverse order) + int32_t offset_output = 0; + int32_t offset_output_standardization = 0; + int32_t offset_output_rstd_deviation = 0; + int32_t offset_input = 0; + for (int i = ndim - 2; i >= 0; --i) { + int32_t dim_i = shape_local[i].value; + int32_t coord = temp_block_idx % dim_i; + temp_block_idx /= dim_i; + offset_output += coord * output_strides_local[i].value; + offset_output_standardization += coord * output_standardization_strides_local[i].value; + offset_output_rstd_deviation += coord * output_rstd_deviation_strides_local[i].value; + offset_input += coord * input_strides_local[i].value; } - } else { - for (int32_t i = core_id(); i < normalized_size; i += BLOCK_SIZE) { - // Load input - GM2SM_ASYNC(input + offset_input + i * input_strides_local[ndim - 1].value, input_sm + i, sizeof(Tdata)); - // Load weight - GM2SM_ASYNC(weight + i * weight_strides, weight_sm + i, sizeof(Tdata)); - // Load bias + + // Shared memory allocation + __shared__ Tdata input_sm[SM_SIZE / sizeof(Tdata)]; + __shared__ Tdata weight_sm[SM_SIZE / sizeof(Tdata)]; + __shared__ Tdata bias_sm[SM_SIZE / sizeof(Tdata)]; + __shared__ Tdata output_sm[SM_SIZE / sizeof(Tdata)]; + __shared__ Tdata output_standardization_sm[SM_SIZE / sizeof(Tdata)]; + __shared__ Tdata output_rstd_sm[1]; + + // Copy data to shared memory + if (weight_strides == 1 && bias_strides == 1 && core_id() == 0) { + GM2SM_ASYNC(input + offset_input, input_sm, normalized_size * sizeof(Tdata)); + GM2SM_ASYNC(weight, weight_sm, normalized_size * sizeof(Tdata)); if (bias_exist) { - GM2SM_ASYNC(bias + i * bias_strides, bias_sm + i, sizeof(Tdata)); + GM2SM_ASYNC(bias, bias_sm, normalized_size * sizeof(Tdata)); + } + } else { + for (int32_t i = core_id(); i < normalized_size; i += BLOCK_SIZE) { + // Load input + GM2SM_ASYNC(input + offset_input + i * input_strides_local[ndim - 1].value, input_sm + i, sizeof(Tdata)); + // Load weight + GM2SM_ASYNC(weight + i * weight_strides, weight_sm + i, sizeof(Tdata)); + // Load bias + if (bias_exist) { + GM2SM_ASYNC(bias + i * bias_strides, bias_sm + i, sizeof(Tdata)); + } } } + sync_cluster(); + + // 
Compute layer norm in each cluster + layerNormCluster( + output_sm, + output_standardization_sm, + output_rstd_sm, + input_sm, + weight_sm, + bias_sm, + eps, + normalized_size, + bias_exist); + sync_cluster(); + + // Copy results back to global memory + if (core_id() == 0) { + SM2GM_ASYNC(output_sm, output + offset_output, normalized_size * sizeof(Tdata)); + SM2GM_ASYNC(output_standardization_sm, output_standardization + offset_output_standardization, normalized_size * sizeof(Tdata)); + SM2GM_ASYNC(output_rstd_sm, output_rstd_deviation + offset_output_rstd_deviation, sizeof(Tdata)); + } + sync_cluster(); } - sync_cluster(); - - // Compute layer norm in each cluster - layerNormCluster( - output_sm, - output_standardization_sm, - output_rstd_sm, - input_sm, - weight_sm, - bias_sm, - eps, - normalized_size, - bias_exist); - sync_cluster(); - - // Copy results back to global memory - if (core_id() == 0) { - SM2GM_ASYNC(output_sm, output + offset_output, normalized_size * sizeof(Tdata)); - SM2GM_ASYNC(output_standardization_sm, output_standardization + offset_output_standardization, normalized_size * sizeof(Tdata)); - SM2GM_ASYNC(output_rstd_sm, output_rstd_deviation + offset_output_rstd_deviation, sizeof(Tdata)); - } - sync_cluster(); } template @@ -176,26 +176,19 @@ __global__ void layerNormKernelV2( // Every normalization vector as a tile to be processed for (int32_t tid = thread_idx; tid < othersize; tid += nthreads) { - // Calculate tile coordinates - int32_t t_coords[ndim - 1]; int32_t temp_tid = tid; - for (int i = 0; i < ndim - 1; i++) { - int32_t dim_i = shape_local[i].value; - t_coords[i] = temp_tid % dim_i; - temp_tid /= dim_i; - } - - // Calculate offsets int32_t offset_output = 0; int32_t offset_output_standardization = 0; int32_t offset_output_rstd_deviation = 0; int32_t offset_input = 0; - for (int i = 0; i < ndim - 1; i++) { + for (int i = 0; i < ndim - 1; ++i) { int32_t dim_i = shape_local[i].value; - offset_output += t_coords[i] * output_strides_local[i].value; - offset_output_standardization += t_coords[i] * output_standardization_strides_local[i].value; - offset_output_rstd_deviation += t_coords[i] * output_rstd_deviation_strides_local[i].value; - offset_input += t_coords[i] * input_strides_local[i].value; + int32_t coord = temp_tid % dim_i; + temp_tid /= dim_i; + offset_output += coord * output_strides_local[i].value; + offset_output_standardization += coord * output_standardization_strides_local[i].value; + offset_output_rstd_deviation += coord * output_rstd_deviation_strides_local[i].value; + offset_input += coord * input_strides_local[i].value; } // Load input into local memory @@ -305,33 +298,29 @@ infiniStatus_t launchLayerNormKernel( static_cast(ndim), info.bias_exist); } else { - int32_t num_blocks = static_cast(info.othersize) > 255 - ? 
255 - : static_cast(info.othersize); - int32_t num_loops = (static_cast(info.othersize) + num_blocks - 1) / num_blocks; - for (int32_t i = 0; i < num_loops; i++) { - layerNormKernel - <<>>( - i, - output, - output_standardization, - output_rstd_deviation, - input, - weight, - bias, - info.eps, - normalized_size, - static_cast(info.othersize), - reinterpret_cast<__global_ptr__ const void *>(shape_kunlun), - reinterpret_cast<__global_ptr__ const void *>(output_strides_kunlun), - reinterpret_cast<__global_ptr__ const void *>(output_standardization_strides_kunlun), - reinterpret_cast<__global_ptr__ const void *>(output_rstd_deviation_strides_kunlun), - reinterpret_cast<__global_ptr__ const void *>(input_strides_kunlun), - static_cast(info.weight_strides[0]), - static_cast(info.bias_exist ? info.bias_strides[0] : 0), - static_cast(ndim), - info.bias_exist); - } + int32_t NUM_CLUSTERS = static_cast(info.othersize) > MAX_CLUSTERS + ? MAX_CLUSTERS + : static_cast(info.othersize); + layerNormKernel + <<>>( + output, + output_standardization, + output_rstd_deviation, + input, + weight, + bias, + info.eps, + normalized_size, + static_cast(info.othersize), + reinterpret_cast<__global_ptr__ const void *>(shape_kunlun), + reinterpret_cast<__global_ptr__ const void *>(output_strides_kunlun), + reinterpret_cast<__global_ptr__ const void *>(output_standardization_strides_kunlun), + reinterpret_cast<__global_ptr__ const void *>(output_rstd_deviation_strides_kunlun), + reinterpret_cast<__global_ptr__ const void *>(input_strides_kunlun), + static_cast(info.weight_strides[0]), + static_cast(info.bias_exist ? info.bias_strides[0] : 0), + static_cast(ndim), + info.bias_exist); } return INFINI_STATUS_SUCCESS; } From f62b11fcd6b8af2cfe22029133ed6b88a90abedf Mon Sep 17 00:00:00 2001 From: zhangyue Date: Tue, 23 Dec 2025 16:50:27 +0800 Subject: [PATCH 3/4] issue/826: delete debug print --- test/infiniop/layer_norm.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/test/infiniop/layer_norm.py b/test/infiniop/layer_norm.py index d1ae6ce20..8fe0a470e 100644 --- a/test/infiniop/layer_norm.py +++ b/test/infiniop/layer_norm.py @@ -86,12 +86,9 @@ def torch_layer_norm( bias_f32 = bias.to(torch.float32) if bias_exist and bias is not None else None mean = input_f32.mean(dim=-1, keepdim=True) # [..., 1] - # print("mean in torch", mean) var = input_f32.var(dim=-1, keepdim=True, correction=0) # [..., 1] - # print("var in torch", var) rstd = torch.rsqrt(var + eps) # [..., 1] - # print("rstd in torch", rstd) centered_input = input_f32 - mean # [..., D] normalized = centered_input * rstd # [..., D] From c5f3c8a0b3c42dc4653662d3804e32dde460f3a7 Mon Sep 17 00:00:00 2001 From: zhangyue Date: Tue, 23 Dec 2025 16:55:58 +0800 Subject: [PATCH 4/4] issue/826: add print for test script --- src/infiniop/devices/kunlun/kunlun_kernel_common.h | 2 ++ test/infiniop/layer_norm.py | 5 +++-- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/src/infiniop/devices/kunlun/kunlun_kernel_common.h b/src/infiniop/devices/kunlun/kunlun_kernel_common.h index 45758e9d9..ce7e19e0e 100644 --- a/src/infiniop/devices/kunlun/kunlun_kernel_common.h +++ b/src/infiniop/devices/kunlun/kunlun_kernel_common.h @@ -13,6 +13,8 @@ namespace device::kunlun::kernel { #define SM_SIZE 40960 +#define MAX_CLUSTERS 255 // P800 +#define MAX_BLOCK_SIZE 64 /** * @brief Define ptrdiff_t and size_t for kunlun xpu diff --git a/test/infiniop/layer_norm.py b/test/infiniop/layer_norm.py index 8fe0a470e..54839eccf 100644 --- a/test/infiniop/layer_norm.py +++ 
b/test/infiniop/layer_norm.py @@ -119,8 +119,9 @@ def test( sync=None, ): print( - f"Testing layer_norm on {InfiniDeviceNames[device]} with input_shape:{input_shape}," - f"bias:{bias_exist},eps:{eps}," + f"Testing layer_norm on {InfiniDeviceNames[device]} with input_shape:{input_shape}, " + f"bias:{bias_exist}, eps:{eps}, input_strides:{input_strides}, output_strides:{output_strides}, " + f"weight_strides:{weight_strides}, inplace:{inplace}, " f"dtype:{InfiniDtypeNames[dtype]}" )
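Note: as a quick sanity check outside the test harness, the reference math that test/infiniop/layer_norm.py now encodes (float32 accumulation, biased variance, rstd = 1/sqrt(var + eps), y = (x - mean) * rstd * weight + bias) can be reproduced with the short PyTorch sketch below. The helper name layer_norm_reference and the example shapes are illustrative only and are not part of the patch.

# Standalone sketch of the LayerNorm reference exercised by the test script.
# Assumes PyTorch is available; layer_norm_reference is a hypothetical helper.
import torch

def layer_norm_reference(x, weight, bias=None, eps=1e-5):
    x32 = x.to(torch.float32)
    mean = x32.mean(dim=-1, keepdim=True)                # per-row mean
    var = x32.var(dim=-1, keepdim=True, correction=0)    # biased variance, as in the kernel
    rstd = torch.rsqrt(var + eps)                        # 1 / sqrt(var + eps)
    standardized = (x32 - mean) * rstd                   # (x - mean) * rstd
    y = standardized * weight.to(torch.float32)
    if bias is not None:
        y = y + bias.to(torch.float32)
    # Return y, the standardized input, and rstd (keepdim squeezed away), matching
    # output, output_standardization, and output_rstd_deviation in the kernels.
    return y.to(x.dtype), standardized.to(x.dtype), rstd.squeeze(-1).to(x.dtype)

# Example: a [2, 4, 8] fp16 input normalized over the last dimension.
x = torch.randn(2, 4, 8, dtype=torch.float16)
w = torch.ones(8, dtype=torch.float16)
b = torch.zeros(8, dtype=torch.float16)
y, std_x, rstd = layer_norm_reference(x, w, b, eps=1e-5)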