-
Notifications
You must be signed in to change notification settings - Fork 93
Issue/791 增加add_rms_norm融合算子 #842
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,20 @@ | ||
| #pragma once | ||
|
|
||
| #include "../device.hpp" | ||
| #include "common/op.hpp" | ||
| #include <utility> | ||
|
|
||
| namespace infinicore::op { | ||
| class AddRMSNorm { | ||
| public: | ||
| using schema = void (*)(Tensor, Tensor, Tensor, Tensor, Tensor, float); | ||
| static void execute(Tensor y, Tensor residual_out, Tensor a, Tensor b, Tensor weight, float epsilon = 1e-5f); | ||
| static common::OpDispatcher<schema> &dispatcher(); | ||
| }; | ||
|
|
||
| // Fused Add and RMS Normalization | ||
| // Returns: (normalized_result, add_result) | ||
| // The add_result can be used as residual for subsequent layers | ||
| std::pair<Tensor, Tensor> add_rms_norm(Tensor a, Tensor b, Tensor weight, float epsilon = 1e-5f); | ||
| void add_rms_norm_(Tensor y, Tensor residual_out, Tensor a, Tensor b, Tensor weight, float epsilon = 1e-5f); | ||
| } // namespace infinicore::op | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,32 @@ | ||
| #ifndef __INFINIOP_ADD_RMS_NORM_API_H__ | ||
| #define __INFINIOP_ADD_RMS_NORM_API_H__ | ||
|
|
||
| #include "../operator_descriptor.h" | ||
|
|
||
| typedef struct InfiniopDescriptor *infiniopAddRMSNormDescriptor_t; | ||
|
|
||
| __C __export infiniStatus_t infiniopCreateAddRMSNormDescriptor( | ||
| infiniopHandle_t handle, | ||
| infiniopAddRMSNormDescriptor_t *desc_ptr, | ||
| infiniopTensorDescriptor_t y_desc, | ||
| infiniopTensorDescriptor_t a_desc, | ||
| infiniopTensorDescriptor_t b_desc, | ||
| infiniopTensorDescriptor_t weight_desc, | ||
| float epsilon, | ||
| infiniopTensorDescriptor_t residual_out_desc); | ||
|
|
||
| __C __export infiniStatus_t infiniopGetAddRMSNormWorkspaceSize(infiniopAddRMSNormDescriptor_t desc, size_t *size); | ||
|
|
||
| __C __export infiniStatus_t infiniopAddRMSNorm(infiniopAddRMSNormDescriptor_t desc, | ||
| void *workspace, | ||
| size_t workspace_size, | ||
| void *y, | ||
| const void *a, | ||
| const void *b, | ||
| const void *weight, | ||
| void *residual_out, | ||
| void *stream); | ||
|
|
||
| __C __export infiniStatus_t infiniopDestroyAddRMSNormDescriptor(infiniopAddRMSNormDescriptor_t desc); | ||
|
|
||
| #endif |
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,47 @@ | ||||||||||||||||||||||||||||||||
| from infinicore.lib import _infinicore | ||||||||||||||||||||||||||||||||
| from infinicore.tensor import Tensor | ||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||
| def add_rms_norm(a, b, weight, epsilon=1e-5, *, out=None): | ||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||
| Fused Add and RMS Normalization. | ||||||||||||||||||||||||||||||||
| Args: | ||||||||||||||||||||||||||||||||
| a: First input tensor | ||||||||||||||||||||||||||||||||
| b: Second input tensor | ||||||||||||||||||||||||||||||||
| weight: Scale weights | ||||||||||||||||||||||||||||||||
| epsilon: Small constant for numerical stability, default is 1e-5 | ||||||||||||||||||||||||||||||||
| out: Optional output tuple (y, residual_out) for in-place operation | ||||||||||||||||||||||||||||||||
| Returns: | ||||||||||||||||||||||||||||||||
| Tuple of (normalized_result, add_result): (RMSNorm(a + b) * weight, a + b) | ||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||
| Tuple of (normalized_result, add_result): (RMSNorm(a + b) * weight, a + b) | |
| Tuple (normalized_result, add_result), where: | |
| add_result = a + b | |
| normalized_result = (a + b) * weight / sqrt(mean((a + b) ** 2) + epsilon) |
Copilot
AI
Jan 7, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The docstring for the in-place function add_rms_norm_ lacks parameter documentation. While the out-of-place version has detailed documentation for all parameters and return values, the in-place version only has a one-line description. Add parameter documentation (y, residual_out, a, b, weight, epsilon) for consistency and to help users understand the API.
| """In-place Fused Add and RMS Normalization.""" | |
| """ | |
| In-place Fused Add and RMS Normalization. | |
| Args: | |
| y: Output tensor for the normalized result (RMSNorm(a + b) * weight). | |
| residual_out: Output tensor for the residual result (a + b). | |
| a: First input tensor. | |
| b: Second input tensor. | |
| weight: Scale weights applied after RMS normalization. | |
| epsilon: Small constant for numerical stability, default is 1e-5. | |
| Returns: | |
| None. The results are written in-place into ``y`` and ``residual_out``. | |
| """ |
Copilot
AI
Jan 7, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The in-place function add_rms_norm_ doesn't return anything (implicit None), but the out-of-place function add_rms_norm with out parameter returns (y, residual_out). For API consistency, the in-place function should also return the tuple (y, residual_out) to allow for method chaining and to match the pattern used in the out-of-place variant when out is provided.
| ) | |
| ) | |
| return (y, residual_out) |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,29 @@ | ||
| #include "infinicore/ops/add_rms_norm.hpp" | ||
|
|
||
| #include "../../utils.hpp" | ||
|
|
||
| namespace infinicore::op { | ||
|
|
||
| common::OpDispatcher<AddRMSNorm::schema> &AddRMSNorm::dispatcher() { | ||
| static common::OpDispatcher<AddRMSNorm::schema> dispatcher_; | ||
| return dispatcher_; | ||
| }; | ||
|
|
||
| void AddRMSNorm::execute(Tensor y, Tensor residual_out, Tensor a, Tensor b, Tensor weight, float epsilon) { | ||
| INFINICORE_ASSERT_TENSORS_SAME_DEVICE(y, residual_out, a, b, weight); | ||
| infinicore::context::setDevice(y->device()); | ||
| dispatcher().lookup(y->device().getType())(y, residual_out, a, b, weight, epsilon); | ||
| } | ||
|
|
||
| std::pair<Tensor, Tensor> add_rms_norm(Tensor a, Tensor b, Tensor weight, float epsilon) { | ||
| auto y = Tensor::empty(a->shape(), a->dtype(), a->device()); | ||
| auto residual_out = Tensor::empty(a->shape(), a->dtype(), a->device()); | ||
| add_rms_norm_(y, residual_out, a, b, weight, epsilon); | ||
| return std::make_pair(y, residual_out); | ||
| } | ||
|
|
||
| void add_rms_norm_(Tensor y, Tensor residual_out, Tensor a, Tensor b, Tensor weight, float epsilon) { | ||
| AddRMSNorm::execute(y, residual_out, a, b, weight, epsilon); | ||
| } | ||
|
|
||
| } // namespace infinicore::op |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,50 @@ | ||
| #include "../../utils.hpp" | ||
| #include "infinicore/common/hash.hpp" | ||
| #include "infinicore/ops/add_rms_norm.hpp" | ||
| #include "infinicore/ops/common/cache.hpp" | ||
| #include <infiniop.h> | ||
|
|
||
| namespace infinicore::op::add_rms_norm_impl::infiniop { | ||
|
|
||
| thread_local common::OpCache<size_t, infiniopAddRMSNormDescriptor_t> caches( | ||
| 100, // capacity | ||
| [](infiniopAddRMSNormDescriptor_t &desc) { | ||
| if (desc != nullptr) { | ||
| INFINICORE_CHECK_ERROR(infiniopDestroyAddRMSNormDescriptor(desc)); | ||
| desc = nullptr; | ||
| } | ||
| }); | ||
|
|
||
| void calculate(Tensor y, Tensor residual_out, Tensor a, Tensor b, Tensor weight, float epsilon) { | ||
| size_t seed = hash_combine(y, residual_out, a, b, weight, epsilon); | ||
|
|
||
| auto device = context::getDevice(); | ||
| auto &cache = caches.getCache(device); | ||
|
|
||
| auto desc_opt = cache.get(seed); | ||
| infiniopAddRMSNormDescriptor_t desc = nullptr; | ||
|
|
||
| if (!desc_opt) { | ||
| INFINICORE_CHECK_ERROR(infiniopCreateAddRMSNormDescriptor( | ||
| context::getInfiniopHandle(device), &desc, | ||
| y->desc(), a->desc(), b->desc(), weight->desc(), epsilon, residual_out->desc())); | ||
| cache.put(seed, desc); | ||
| } else { | ||
| desc = *desc_opt; | ||
| } | ||
|
|
||
| size_t workspace_size = 0; | ||
| INFINICORE_CHECK_ERROR(infiniopGetAddRMSNormWorkspaceSize(desc, &workspace_size)); | ||
| std::shared_ptr<Memory> workspace = context::allocateMemory(workspace_size); | ||
|
|
||
| INFINICORE_CHECK_ERROR(infiniopAddRMSNorm( | ||
| desc, workspace->data(), workspace_size, | ||
| y->data(), a->data(), b->data(), weight->data(), residual_out->data(), context::getStream())); | ||
| } | ||
|
|
||
| static bool registered = []() { | ||
| AddRMSNorm::dispatcher().registerAll(&calculate, false); | ||
| return true; | ||
| }(); | ||
|
|
||
| } // namespace infinicore::op::add_rms_norm_impl::infiniop |
| Original file line number | Diff line number | Diff line change | ||||||||
|---|---|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,51 @@ | ||||||||||
| #pragma once | ||||||||||
|
|
||||||||||
| #include <pybind11/pybind11.h> | ||||||||||
|
|
||||||||||
| #include "infinicore/ops/add_rms_norm.hpp" | ||||||||||
|
|
||||||||||
| namespace py = pybind11; | ||||||||||
|
|
||||||||||
| namespace infinicore::ops { | ||||||||||
|
|
||||||||||
| inline void bind_add_rms_norm(py::module &m) { | ||||||||||
| m.def("add_rms_norm", | ||||||||||
| &op::add_rms_norm, | ||||||||||
| py::arg("a"), | ||||||||||
| py::arg("b"), | ||||||||||
| py::arg("weight"), | ||||||||||
| py::arg("epsilon") = 1e-5f, | ||||||||||
| R"doc(Fused Add and RMS Normalization. | ||||||||||
|
|
||||||||||
| Args: | ||||||||||
| a: First input tensor | ||||||||||
| b: Second input tensor | ||||||||||
| weight: Scale weights | ||||||||||
| epsilon: Small constant for numerical stability, default is 1e-5 | ||||||||||
|
|
||||||||||
| Returns: | ||||||||||
| Tuple of (normalized_result, add_result): (RMSNorm(a + b) * weight, a + b) | ||||||||||
|
||||||||||
| Tuple of (normalized_result, add_result): (RMSNorm(a + b) * weight, a + b) | |
| Tuple of (normalized_result, add_result): | |
| normalized_result = (a + b) * weight / sqrt(mean((a + b)^2) + epsilon) | |
| add_result = a + b |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,53 @@ | ||
| #ifndef ADD_RMS_NORM_H | ||
| #define ADD_RMS_NORM_H | ||
|
|
||
| #include "../../operator.h" | ||
| #include "info.h" | ||
|
|
||
/*
 * DESCRIPTOR(NAMESPACE) declares class op::add_rms_norm::NAMESPACE::Descriptor.
 *
 * The class stores an Opaque pointer (only forward-declared here), an
 * AddRMSNormInfo and the workspace size. Its constructor is private;
 * create() is the static entry point that fills in *desc_ptr, and
 * calculate() runs the operator. Only declarations are emitted: the
 * destructor, create() and calculate() are defined elsewhere.
 *
 * Comments are kept outside the macro body because a // comment on a
 * line-continued line would swallow the rest of the macro.
 */
#define DESCRIPTOR(NAMESPACE) \
    \
    namespace op::add_rms_norm::NAMESPACE { \
    class Descriptor final : public InfiniopDescriptor { \
        struct Opaque; \
        Opaque *_opaque; \
        AddRMSNormInfo _info; \
        size_t _workspace_size; \
        \
        Descriptor( \
            Opaque *opaque, \
            AddRMSNormInfo info, \
            size_t workspace_size, \
            infiniDevice_t device_type, \
            int device_id) \
            : InfiniopDescriptor{device_type, device_id}, \
              _opaque(opaque), \
              _info(info), \
              _workspace_size(workspace_size) {} \
        \
    public: \
        ~Descriptor(); \
        \
        size_t workspaceSize() const { return _workspace_size; } \
        \
        static infiniStatus_t create( \
            infiniopHandle_t handle, \
            Descriptor **desc_ptr, \
            infiniopTensorDescriptor_t y_desc, \
            infiniopTensorDescriptor_t a_desc, \
            infiniopTensorDescriptor_t b_desc, \
            infiniopTensorDescriptor_t weight_desc, \
            float epsilon, \
            infiniopTensorDescriptor_t residual_out_desc); \
        \
        infiniStatus_t calculate( \
            void *workspace, size_t workspace_size, \
            void *y, \
            const void *a, \
            const void *b, \
            const void *weight, \
            void *residual_out, \
            void *stream) const; \
    }; \
    }
|
|
||
| #endif // ADD_RMS_NORM_H |
Uh oh!
There was an error while loading. Please reload this page.