-
Notifications
You must be signed in to change notification settings - Fork 53
Allow AM receive callback to control data allocation/copy #479
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: branch-0.46
Are you sure you want to change the base?
Changes from all commits
b706657
0490c1d
90bc524
ca30f43
8365739
002fe76
2924c6e
da80a5b
d20f401
8bc98ae
5f4d24d
d9a04ee
04eded7
7567b11
0710668
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -1,5 +1,5 @@ | ||||||||||||||
| /** | ||||||||||||||
| * SPDX-FileCopyrightText: Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. | ||||||||||||||
| * SPDX-FileCopyrightText: Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. | ||||||||||||||
| * SPDX-License-Identifier: BSD-3-Clause | ||||||||||||||
| */ | ||||||||||||||
| #pragma once | ||||||||||||||
|
|
@@ -34,6 +34,20 @@ class RequestAm : public Request { | |||||||||||||
| std::string _header{}; ///< Retain copy of header for send requests as workaround for | ||||||||||||||
| ///< https://github.com/openucx/ucx/issues/10424 | ||||||||||||||
|
|
||||||||||||||
| /** | ||||||||||||||
| * @brief Common implementation for receiveData operations. | ||||||||||||||
| * | ||||||||||||||
| * @param[in] amData AM data containing pointer and length. | ||||||||||||||
| * @param[in] bufferPtr Pointer to the buffer to receive data into. | ||||||||||||||
| * @param[in] managedBuffer Optional managed buffer for lifetime management (nullptr for raw | ||||||||||||||
| * pointers). | ||||||||||||||
| * | ||||||||||||||
| * @returns A shared_ptr to a Request object for the delayed receive operation. | ||||||||||||||
| */ | ||||||||||||||
| std::shared_ptr<Request> receiveDataImpl(const AmData& amData, | ||||||||||||||
| void* bufferPtr, | ||||||||||||||
| Buffer* managedBuffer); | ||||||||||||||
|
|
||||||||||||||
| /** | ||||||||||||||
| * @brief Private constructor of `ucxx::RequestAm`. | ||||||||||||||
| * | ||||||||||||||
|
|
@@ -154,6 +168,65 @@ class RequestAm : public Request { | |||||||||||||
| const ucp_am_recv_param_t* param); | ||||||||||||||
|
|
||||||||||||||
| [[nodiscard]] std::shared_ptr<Buffer> getRecvBuffer() override; | ||||||||||||||
|
|
||||||||||||||
| /** | ||||||||||||||
| * @brief Receive delayed Active Message data into a user-provided buffer. | ||||||||||||||
| * | ||||||||||||||
| * This method is used for delayed receive operations where `delayReceive` was enabled. | ||||||||||||||
| * It takes a user-provided buffer and internally calls `ucp_am_recv_data_nbx` to receive | ||||||||||||||
| * the AM data that was stored when the message first arrived. Returns a `Request` object | ||||||||||||||
|
Comment on lines
+175
to
+177
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
Start in markdown, finish in rst :) |
||||||||||||||
| * that the user can wait on for completion. | ||||||||||||||
| * | ||||||||||||||
| * @param[in] buffer The buffer to receive the AM data into. Must be large enough to hold | ||||||||||||||
| * the AM data (length available via the original AM callback). | ||||||||||||||
| * | ||||||||||||||
| * @returns A shared_ptr to a Request object that can be waited on for completion. | ||||||||||||||
| * Returns nullptr if this was not a delayed receive operation or AM data | ||||||||||||||
| * is not available. | ||||||||||||||
| */ | ||||||||||||||
| [[nodiscard]] std::shared_ptr<Request> receiveData(std::shared_ptr<Buffer> buffer); | ||||||||||||||
|
|
||||||||||||||
| /** | ||||||||||||||
| * @brief Receive delayed Active Message data into a user-provided buffer pointer. | ||||||||||||||
| * | ||||||||||||||
| * This method is used for delayed receive operations where `delayReceive` was enabled. | ||||||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||
| * It takes a user-provided pointer to buffer and internally calls `ucp_am_recv_data_nbx` | ||||||||||||||
| * to receive the AM data that was stored when the message first arrived. Returns a | ||||||||||||||
| * `Request` object that the user can wait on for completion. | ||||||||||||||
| * | ||||||||||||||
| * @param[in] buffer The buffer pointer to receive the AM data into. Must be large enough | ||||||||||||||
| * to hold the AM data (length available via the original AM callback). | ||||||||||||||
| * | ||||||||||||||
| * @returns A shared_ptr to a Request object that can be waited on for completion. | ||||||||||||||
| * Returns nullptr if this was not a delayed receive operation or AM data | ||||||||||||||
| * is not available. | ||||||||||||||
| */ | ||||||||||||||
| [[nodiscard]] std::shared_ptr<Request> receiveData(void* buffer); | ||||||||||||||
|
|
||||||||||||||
| /** | ||||||||||||||
| * @brief Get the length of delayed Active Message data. | ||||||||||||||
| * | ||||||||||||||
| * This method returns the length of the AM data that was stored when delayReceive | ||||||||||||||
| * was enabled and the message was received via the receive callback. This allows | ||||||||||||||
| * users to allocate appropriately sized buffers before calling receiveData(). | ||||||||||||||
| * | ||||||||||||||
| * @returns The length of the AM data in bytes, or 0 if this was not a delayed | ||||||||||||||
| * receive operation or AM data is not available. | ||||||||||||||
| */ | ||||||||||||||
| [[nodiscard]] size_t getAmDataLength(); | ||||||||||||||
|
|
||||||||||||||
| /** | ||||||||||||||
| * @brief Check if this request uses delayed receive. | ||||||||||||||
| * | ||||||||||||||
| * This method returns true if this request is a delayed receive operation where | ||||||||||||||
| * delayReceive was enabled and AM data is stored for later retrieval. If true, | ||||||||||||||
| * users should use getAmDataLength() and receiveData() to retrieve the data. | ||||||||||||||
| * If false, users should use getRecvBuffer() to access immediately received data. | ||||||||||||||
| * | ||||||||||||||
| * @returns True if this is a delayed receive operation with stored AM data, | ||||||||||||||
| * false for immediate/eager receives or if no AM data is available. | ||||||||||||||
| */ | ||||||||||||||
| [[nodiscard]] bool isDelayedReceive(); | ||||||||||||||
| }; | ||||||||||||||
|
|
||||||||||||||
| } // namespace ucxx | ||||||||||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -7,8 +7,12 @@ | |
| #include <functional> | ||
| #include <limits> | ||
| #include <memory> | ||
| #include <optional> | ||
| #include <stdexcept> | ||
| #include <string> | ||
| #include <unordered_map> | ||
| #include <utility> | ||
| #include <vector> | ||
|
|
||
| #include <ucp/api/ucp.h> | ||
|
|
||
|
|
@@ -140,15 +144,6 @@ typedef RequestCallbackUserData EndpointCloseCallbackUserData; | |
| */ | ||
| typedef std::function<std::shared_ptr<Buffer>(size_t)> AmAllocatorType; | ||
|
|
||
| /** | ||
| * @brief Active Message receiver callback. | ||
| * | ||
| * Type for a custom Active Message receiver callback, executed by the remote worker upon | ||
| * Active Message request completion. The first parameter is the request that completed, | ||
| * the second is the handle of the UCX endpoint of the sender. | ||
| */ | ||
| typedef std::function<void(std::shared_ptr<Request>, ucp_ep_h)> AmReceiverCallbackType; | ||
|
|
||
| /** | ||
| * @brief Active Message receiver callback owner name. | ||
| * | ||
|
|
@@ -173,6 +168,161 @@ typedef uint64_t AmReceiverCallbackIdType; | |
| */ | ||
| typedef const std::string AmReceiverCallbackInfoSerialized; | ||
|
|
||
| /** | ||
| * @brief Active Message data information for delayed receiving. | ||
| * | ||
| * Structure containing the Active Message data pointer and length that will be used | ||
| * when the user chooses to delay receiving and handle ucp_am_recv_data_nbx manually. | ||
| * | ||
| * A default-constructed AmData holds a null data pointer and a length of zero. | ||
| */ | ||
| struct AmData { | ||
| void* data; ///< The Active Message data pointer from the receive callback | ||
| size_t length; ///< The length of the Active Message data | ||
|
|
||
| AmData() : data(nullptr), length(0) {} | ||
|
|
||
| /** | ||
| * @brief Construct an AmData object. | ||
| * | ||
| * @param[in] data The Active Message data pointer from the receive callback. | ||
| * @param[in] length The length of the Active Message data. | ||
| */ | ||
| AmData(void* data, size_t length) : data(data), length(length) {} | ||
| }; | ||
|
|
||
| /** | ||
| * @brief Container for arbitrary user header data that can be attached to Active Messages. | ||
| * | ||
| * This class provides a type-safe interface for storing arbitrary user header data that will be | ||
| * serialized and transmitted with Active Messages. It supports common data types while | ||
| * also allowing direct access to the underlying byte storage for custom serialization. | ||
| */ | ||
| class AmUserHeader { | ||
| private: | ||
| std::vector<uint8_t> _data; | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. suggestion: use Also, do you need this to be a resizeable thing, or is it sufficient for this object to manage a |
||
|
|
||
| public: | ||
| AmUserHeader() = delete; | ||
|
|
||
| /** | ||
| * @brief Construct AmUserHeader from a byte array. | ||
| * | ||
| * @param[in] data Pointer to the data to copy. | ||
| * @param[in] size Size of the data in bytes. | ||
| * | ||
| * @throws std::invalid_argument if data is null or if size is 0. | ||
| */ | ||
| AmUserHeader(const void* data, size_t size) | ||
| { | ||
| if (size == 0) { throw std::invalid_argument("AmUserHeader size must be greater than zero"); } | ||
| if (data == nullptr) { | ||
| throw std::invalid_argument("AmUserHeader data pointer cannot be null"); | ||
| } | ||
| _data.assign(static_cast<const uint8_t*>(data), static_cast<const uint8_t*>(data) + size); | ||
| } | ||
|
|
||
| /** | ||
| * @brief Construct AmUserHeader from a string. | ||
| * | ||
| * Only the characters of the string are copied into the byte storage; no terminating | ||
| * null byte is appended. The emptiness check runs in the constructor body, after the | ||
| * storage has been initialized from the string. | ||
| * | ||
| * @param[in] str The string to store. | ||
| * | ||
| * @throws std::invalid_argument if the string is empty. | ||
| */ | ||
| explicit AmUserHeader(const std::string& str) : _data(str.begin(), str.end()) | ||
| { | ||
| if (str.empty()) { throw std::invalid_argument("AmUserHeader string cannot be empty"); } | ||
| } | ||
|
|
||
| /** | ||
| * @brief Construct AmUserHeader from a vector of bytes. | ||
| * | ||
| * @param[in] data The byte vector to copy. | ||
| * | ||
| * @throws std::invalid_argument if the vector is empty. | ||
| */ | ||
| explicit AmUserHeader(const std::vector<uint8_t>& data) : _data(data) | ||
| { | ||
| if (data.empty()) { throw std::invalid_argument("AmUserHeader vector cannot be empty"); } | ||
| } | ||
|
|
||
| /** | ||
| * @brief Construct AmUserHeader from a vector of bytes (move constructor). | ||
| * | ||
| * @param[in] data The byte vector to move. | ||
| * | ||
| * @throws std::invalid_argument if the vector is empty. | ||
| */ | ||
| explicit AmUserHeader(std::vector<uint8_t>&& data) : _data(std::move(data)) | ||
| { | ||
| if (_data.empty()) { throw std::invalid_argument("AmUserHeader vector cannot be empty"); } | ||
| } | ||
|
|
||
| /** | ||
| * @brief Template constructor for POD types. | ||
| * | ||
| * @param[in] value The POD value to store. | ||
| */ | ||
| template <typename T> | ||
| explicit AmUserHeader(const T& value) | ||
| : _data(reinterpret_cast<const uint8_t*>(&value), | ||
| reinterpret_cast<const uint8_t*>(&value) + sizeof(T)) | ||
|
Comment on lines
+306
to
+307
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is UB due to aliasing I think, you need to |
||
| { | ||
| static_assert(std::is_trivially_copyable_v<T>, "Type must be trivially copyable"); | ||
| static_assert(sizeof(T) > 0, "Type size must be greater than zero"); | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: No C++ type has zero size ( |
||
| } | ||
|
|
||
| /** | ||
| * @brief Get the underlying data as a byte array. | ||
| * | ||
| * @returns Pointer to the underlying data. | ||
| */ | ||
| [[nodiscard]] const uint8_t* data() const { return _data.data(); } | ||
|
|
||
| /** | ||
| * @brief Get the size of the data in bytes. | ||
| * | ||
| * @returns Size of the data in bytes. | ||
| */ | ||
| [[nodiscard]] size_t size() const { return _data.size(); } | ||
|
|
||
| /** | ||
| * @brief Check if the user data is empty. | ||
| * | ||
| * @returns True if no data is stored, false otherwise. | ||
| */ | ||
| [[nodiscard]] bool empty() const { return _data.empty(); } | ||
|
|
||
| /** | ||
| * @brief Get the data as a string. | ||
| * | ||
| * @returns String representation of the data. | ||
| */ | ||
| [[nodiscard]] std::string asString() const { return std::string(_data.begin(), _data.end()); } | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This copies the values, is that what you want or would a |
||
|
|
||
| /** | ||
| * @brief Get the data as a specific POD type. | ||
| * | ||
| * @returns Reference to the data interpreted as type T. | ||
| * @throws std::runtime_error if the size doesn't match. | ||
| */ | ||
| template <typename T> | ||
| [[nodiscard]] const T& as() const | ||
| { | ||
| static_assert(std::is_trivially_copyable_v<T>, "Type must be trivially copyable"); | ||
| if (_data.size() != sizeof(T)) { | ||
| throw std::runtime_error("AmUserHeader size mismatch: expected " + std::to_string(sizeof(T)) + | ||
| " bytes, got " + std::to_string(_data.size())); | ||
| } | ||
| return *reinterpret_cast<const T*>(_data.data()); | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This needs to be |
||
| } | ||
|
|
||
| /** | ||
| * @brief Get a copy of the underlying byte vector. | ||
| * | ||
| * @returns Copy of the underlying data vector. | ||
| */ | ||
| [[nodiscard]] std::vector<uint8_t> getBytes() const { return _data; } | ||
| }; | ||
|
|
||
| /** | ||
| * @brief Information of an Active Message receiver callback. | ||
| * | ||
|
|
@@ -182,18 +332,36 @@ class AmReceiverCallbackInfo { | |
| public: | ||
| const AmReceiverCallbackOwnerType owner; ///< The owner name of the callback | ||
| const AmReceiverCallbackIdType id; ///< The unique identifier of the callback | ||
| const bool delayReceive; ///< Whether to delay receiving data (user-controlled) | ||
| const std::optional<AmUserHeader> userHeader; ///< Optional arbitrary user header data | ||
|
|
||
| AmReceiverCallbackInfo() = delete; | ||
|
|
||
| /** | ||
| * @brief Construct an AmReceiverCallbackInfo object. | ||
| * | ||
| * @param[in] owner The owner name of the callback. | ||
| * @param[in] id The unique identifier of the callback. | ||
| * @param[in] owner The owner name of the callback. | ||
| * @param[in] id The unique identifier of the callback. | ||
| * @param[in] delayReceive Whether to delay receiving data, allowing user to control when | ||
| * ucp_am_recv_data_nbx is called. | ||
| * @param[in] userHeader Optional arbitrary user header data to be transmitted with the AM. | ||
| */ | ||
| AmReceiverCallbackInfo(const AmReceiverCallbackOwnerType owner, AmReceiverCallbackIdType id); | ||
| AmReceiverCallbackInfo(const AmReceiverCallbackOwnerType owner, | ||
| AmReceiverCallbackIdType id, | ||
| bool delayReceive = false, | ||
| std::optional<AmUserHeader> userHeader = std::nullopt); | ||
| }; | ||
|
|
||
| /** | ||
| * @brief Active Message receiver callback. | ||
| * | ||
| * Type for a custom Active Message receiver callback, executed by the remote worker upon | ||
| * Active Message request completion. The first parameter is the request that completed, | ||
| * the second is the handle of the UCX endpoint of the sender. | ||
| */ | ||
| typedef std::function<void(std::shared_ptr<Request>, ucp_ep_h, AmReceiverCallbackInfo&)> | ||
| AmReceiverCallbackType; | ||
|
|
||
| /** | ||
| * @brief Serialized form of a remote key. | ||
| * | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't understand this docstring addition. AFAICT the
receiverCallbackInfo is not a function, so how can it be executed?