21 changes: 13 additions & 8 deletions cpp/include/ucxx/internal/request_am.h
@@ -1,5 +1,5 @@
/**
* SPDX-FileCopyrightText: Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES.
* SPDX-FileCopyrightText: Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES.
* SPDX-License-Identifier: BSD-3-Clause
*/
#include <functional>
@@ -38,6 +38,8 @@ class RecvAmMessage {
std::shared_ptr<RequestAm> _request{
nullptr}; ///< Request which will later be notified/delivered to user
std::shared_ptr<Buffer> _buffer{nullptr}; ///< Buffer containing the received data
std::optional<AmReceiverCallbackInfo> _receiverCallbackInfo{
std::nullopt}; ///< Callback info with user header

RecvAmMessage() = delete;
RecvAmMessage(const RecvAmMessage&) = delete;
@@ -50,18 +52,21 @@ class RecvAmMessage {
*
* Construct the object, setting attributes that are later needed by the callback.
*
* @param[in] amData active messages worker data.
* @param[in] ep handle containing address of the reply endpoint (i.e.,
endpoint where user is requesting to receive).
* @param[in] request request to be later notified/delivered to user.
* @param[in] buffer buffer containing the received data
* @param[in] receiverCallback receiver callback to execute when request completes.
* @param[in] amData active messages worker data.
* @param[in] ep handle containing address of the reply endpoint (i.e.,
* endpoint where user is requesting to receive).
* @param[in] request request to be later notified/delivered to user.
* @param[in] buffer buffer containing the received data.
* @param[in] receiverCallback receiver callback to execute when request completes.
* @param[in] receiverCallbackInfo receiver callback info to execute when request completes,
* including user header.
Comment on lines +61 to +62
Contributor

I don't understand this docstring addition. AFAICT the receiverCallbackInfo is not a function, so how can it be executed?

*/
RecvAmMessage(internal::AmData* amData,
ucp_ep_h ep,
std::shared_ptr<RequestAm> request,
std::shared_ptr<Buffer> buffer,
AmReceiverCallbackType receiverCallback = AmReceiverCallbackType());
AmReceiverCallbackType receiverCallback = AmReceiverCallbackType(),
std::optional<AmReceiverCallbackInfo> receiverCallbackInfo = std::nullopt);

/**
* @brief Set the UCP request.
75 changes: 74 additions & 1 deletion cpp/include/ucxx/request_am.h
@@ -1,5 +1,5 @@
/**
* SPDX-FileCopyrightText: Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES.
* SPDX-FileCopyrightText: Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES.
* SPDX-License-Identifier: BSD-3-Clause
*/
#pragma once
@@ -34,6 +34,20 @@ class RequestAm : public Request {
std::string _header{}; ///< Retain copy of header for send requests as workaround for
///< https://github.com/openucx/ucx/issues/10424

/**
* @brief Common implementation for receiveData operations.
*
* @param[in] amData AM data containing pointer and length.
* @param[in] bufferPtr Pointer to the buffer to receive data into.
* @param[in] managedBuffer Optional managed buffer for lifetime management (nullptr for raw
* pointers).
*
* @returns A shared_ptr to a Request object for the delayed receive operation.
*/
std::shared_ptr<Request> receiveDataImpl(const AmData& amData,
void* bufferPtr,
Buffer* managedBuffer);

/**
* @brief Private constructor of `ucxx::RequestAm`.
*
@@ -154,6 +168,65 @@ class RequestAm : public Request {
const ucp_am_recv_param_t* param);

[[nodiscard]] std::shared_ptr<Buffer> getRecvBuffer() override;

/**
* @brief Receive delayed Active Message data into a user-provided buffer.
*
* This method is used for delayed receive operations where `delayReceive`` was enabled.
* It takes a user-provided buffer and internally calls `ucp_am_recv_data_nbx`` to receive
* the AM data that was stored when the message first arrived. Returns a `Request`` object
Comment on lines +175 to +177
Contributor

Suggested change
* This method is used for delayed receive operations where `delayReceive`` was enabled.
* It takes a user-provided buffer and internally calls `ucp_am_recv_data_nbx`` to receive
* the AM data that was stored when the message first arrived. Returns a `Request`` object
* This method is used for delayed receive operations where `delayReceive` was enabled.
* It takes a user-provided buffer and internally calls `ucp_am_recv_data_nbx` to receive
* the AM data that was stored when the message first arrived. Returns a `Request` object

Start in markdown, finish in rst :)

* that the user can wait on for completion.
*
* @param[in] buffer The buffer to receive the AM data into. Must be large enough to hold
* the AM data (length available via the original AM callback).
*
* @returns A shared_ptr to a Request object that can be waited on for completion.
* Returns nullptr if this was not a delayed receive operation or AM data
* is not available.
*/
[[nodiscard]] std::shared_ptr<Request> receiveData(std::shared_ptr<Buffer> buffer);

/**
* @brief Receive delayed Active Message data into a user-provided buffer pointer.
*
* This method is used for delayed receive operations where `delayReceive`` was enabled.
Contributor

Suggested change
* This method is used for delayed receive operations where `delayReceive`` was enabled.
* This method is used for delayed receive operations where `delayReceive` was enabled.

* It takes a user-provided pointer to buffer and internally calls `ucp_am_recv_data_nbx`
* to receive the AM data that was stored when the message first arrived. Returns a
* `Request` object that the user can wait on for completion.
*
* @param[in] buffer The buffer pointer to receive the AM data into. Must be large enough
* to hold the AM data (length available via the original AM callback).
*
* @returns A shared_ptr to a Request object that can be waited on for completion.
* Returns nullptr if this was not a delayed receive operation or AM data
* is not available.
*/
[[nodiscard]] std::shared_ptr<Request> receiveData(void* buffer);

/**
* @brief Get the length of delayed Active Message data.
*
* This method returns the length of the AM data that was stored when delayReceive
* was enabled and the message was received via the receive callback. This allows
* users to allocate appropriately sized buffers before calling receiveData().
*
* @returns The length of the AM data in bytes, or 0 if this was not a delayed
* receive operation or AM data is not available.
*/
[[nodiscard]] size_t getAmDataLength();

/**
* @brief Check if this request uses delayed receive.
*
* This method returns true if this request is a delayed receive operation where
* delayReceive was enabled and AM data is stored for later retrieval. If true,
* users should use getAmDataLength() and receiveData() to retrieve the data.
* If false, users should use getRecvBuffer() to access immediately received data.
*
* @returns True if this is a delayed receive operation with stored AM data,
* false for immediate/eager receives or if no AM data is available.
*/
[[nodiscard]] bool isDelayedReceive();
};

} // namespace ucxx
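
The docstrings above describe two retrieval paths: `getRecvBuffer()` for eager receives, and `getAmDataLength()` followed by `receiveData()` for delayed receives. The sketch below illustrates that branching only. `MockRequestAm` and `fetchPayload` are hypothetical stand-ins, not UCXX types, and the mock copies data synchronously, whereas the real `receiveData()` is documented to return a `Request` that the caller waits on.

```cpp
#include <cassert>
#include <cstddef>
#include <cstring>
#include <vector>

// Hypothetical stand-in for ucxx::RequestAm. It mimics only the four accessors
// documented in this header so that the control flow can be compiled and run.
struct MockRequestAm {
  std::vector<char> eagerBuffer;  // stands in for the buffer behind getRecvBuffer()
  std::vector<char> pendingData;  // stands in for AM data held back for delayed receive
  bool delayed{false};

  bool isDelayedReceive() const { return delayed; }
  std::size_t getAmDataLength() const { return delayed ? pendingData.size() : 0; }

  // Copies the pending data into `buffer`; returns false when nothing is pending.
  // Assumption: the real method is asynchronous and returns a Request instead.
  bool receiveData(void* buffer)
  {
    if (!delayed || buffer == nullptr) return false;
    if (!pendingData.empty()) std::memcpy(buffer, pendingData.data(), pendingData.size());
    return true;
  }

  const std::vector<char>& getRecvBuffer() const { return eagerBuffer; }
};

// Retrieve the payload regardless of which receive mode was used.
std::vector<char> fetchPayload(MockRequestAm& req)
{
  if (!req.isDelayedReceive()) return req.getRecvBuffer();

  // Delayed path: size the buffer first, then pull the data into it.
  std::vector<char> out(req.getAmDataLength());
  if (!out.empty()) req.receiveData(out.data());
  return out;
}
```

With the real API the delayed branch would additionally need to wait for the returned request to complete before reading the buffer.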
6 changes: 5 additions & 1 deletion cpp/include/ucxx/request_data.h
@@ -1,5 +1,5 @@
/**
* SPDX-FileCopyrightText: Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES.
* SPDX-FileCopyrightText: Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES.
* SPDX-License-Identifier: BSD-3-Clause
*/
#pragma once
@@ -61,6 +61,10 @@ class AmSend {
class AmReceive {
public:
std::shared_ptr<::ucxx::Buffer> _buffer{nullptr}; ///< The AM received message buffer
std::optional<::ucxx::AmData> _amData{
std::nullopt}; ///< The AM data pointer and length for delayed receiving
void* _rawBuffer{nullptr}; ///< Raw buffer pointer for delayed receiving
size_t _rawBufferSize{0}; ///< Size of the raw buffer

/**
* @brief Constructor for Active Message-specific receive data.
192 changes: 180 additions & 12 deletions cpp/include/ucxx/typedefs.h
@@ -7,8 +7,12 @@
#include <functional>
#include <limits>
#include <memory>
#include <optional>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include <ucp/api/ucp.h>

@@ -140,15 +144,6 @@ typedef RequestCallbackUserData EndpointCloseCallbackUserData;
*/
typedef std::function<std::shared_ptr<Buffer>(size_t)> AmAllocatorType;

/**
* @brief Active Message receiver callback.
*
* Type for a custom Active Message receiver callback, executed by the remote worker upon
* Active Message request completion. The first parameter is the request that completed,
* the second is the handle of the UCX endpoint of the sender.
*/
typedef std::function<void(std::shared_ptr<Request>, ucp_ep_h)> AmReceiverCallbackType;

/**
* @brief Active Message receiver callback owner name.
*
@@ -173,6 +168,161 @@ typedef uint64_t AmReceiverCallbackIdType;
*/
typedef const std::string AmReceiverCallbackInfoSerialized;

/**
* @brief Active Message data information for delayed receiving.
*
* Structure containing the Active Message data pointer and length that will be used
* when the user chooses to delay receiving and handle ucp_am_recv_data_nbx manually.
*/
struct AmData {
void* data; ///< The Active Message data pointer from the receive callback
size_t length; ///< The length of the Active Message data

AmData() : data(nullptr), length(0) {}

/**
* @brief Construct an AmData object.
*
* @param[in] data The Active Message data pointer from the receive callback.
* @param[in] length The length of the Active Message data.
*/
AmData(void* data, size_t length) : data(data), length(length) {}
};

/**
* @brief Container for arbitrary user header data that can be attached to Active Messages.
*
* This class provides a type-safe interface for storing arbitrary user header data that will be
* serialized and transmitted with Active Messages. It supports common data types while
* also allowing direct access to the underlying byte storage for custom serialization.
*/
class AmUserHeader {
private:
std::vector<uint8_t> _data;
Contributor

suggestion: use std::byte (these aren't characters...).

Also, do you need this to be a resizeable thing, or is it sufficient for this object to manage a std::byte *?
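
For illustration, a `std::byte`-backed variant could look like the sketch below. `ByteHeader` is a hypothetical name and not part of UCXX; it keeps `std::vector` for ownership simply because that is the smallest change, which leaves the question of whether a non-resizable buffer would suffice open.

```cpp
#include <cassert>
#include <cstddef>
#include <stdexcept>
#include <vector>

// Hypothetical variant of AmUserHeader that stores std::byte instead of uint8_t.
// The class name and its reduced interface are illustrative only.
class ByteHeader {
 private:
  std::vector<std::byte> _data;

 public:
  ByteHeader(const void* data, std::size_t size)
  {
    if (size == 0) { throw std::invalid_argument("size must be greater than zero"); }
    if (data == nullptr) { throw std::invalid_argument("data pointer cannot be null"); }
    // std::byte may alias any object representation, like unsigned char.
    const auto* first = static_cast<const std::byte*>(data);
    _data.assign(first, first + size);
  }

  [[nodiscard]] const std::byte* data() const { return _data.data(); }
  [[nodiscard]] std::size_t size() const { return _data.size(); }
};
```

Callers that need numeric values can convert individual bytes with `std::to_integer`.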


public:
AmUserHeader() = delete;

/**
* @brief Construct AmUserHeader from a byte array.
*
* @param[in] data Pointer to the data to copy.
* @param[in] size Size of the data in bytes.
*
* @throws std::invalid_argument if data is null or if size is 0.
*/
AmUserHeader(const void* data, size_t size)
{
if (size == 0) { throw std::invalid_argument("AmUserHeader size must be greater than zero"); }
if (data == nullptr) {
throw std::invalid_argument("AmUserHeader data pointer cannot be null");
}
_data.assign(static_cast<const uint8_t*>(data), static_cast<const uint8_t*>(data) + size);
}

/**
* @brief Construct AmUserHeader from a string.
*
* @param[in] str The string to store.
*
* @throws std::invalid_argument if the string is empty.
*/
explicit AmUserHeader(const std::string& str) : _data(str.begin(), str.end())
{
if (str.empty()) { throw std::invalid_argument("AmUserHeader string cannot be empty"); }
}

/**
* @brief Construct AmUserHeader from a vector of bytes.
*
* @param[in] data The byte vector to copy.
*
* @throws std::invalid_argument if the vector is empty.
*/
explicit AmUserHeader(const std::vector<uint8_t>& data) : _data(data)
{
if (data.empty()) { throw std::invalid_argument("AmUserHeader vector cannot be empty"); }
}

/**
* @brief Construct AmUserHeader from a vector of bytes (move constructor).
*
* @param[in] data The byte vector to move.
*
* @throws std::invalid_argument if the vector is empty.
*/
explicit AmUserHeader(std::vector<uint8_t>&& data) : _data(std::move(data))
{
if (_data.empty()) { throw std::invalid_argument("AmUserHeader vector cannot be empty"); }
}

/**
* @brief Template constructor for POD types.
*
* @param[in] value The POD value to store.
*/
template <typename T>
explicit AmUserHeader(const T& value)
: _data(reinterpret_cast<const uint8_t*>(&value),
reinterpret_cast<const uint8_t*>(&value) + sizeof(T))
Comment on lines +306 to +307
Contributor

I think this is UB due to aliasing; you need to use std::memcpy. The compiler will turn it into the obvious thing.
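
A `std::memcpy`-based serialization, which sidesteps any aliasing question, might be sketched as follows. `toBytes` is a hypothetical free function used here for illustration; in the PR the same copy would live inside the templated constructor.

```cpp
#include <cassert>
#include <cstdint>
#include <cstring>
#include <type_traits>
#include <vector>

// Hypothetical helper, not part of UCXX: copy the object representation of a
// trivially copyable value into a byte vector without any pointer casts.
template <typename T>
std::vector<std::uint8_t> toBytes(const T& value)
{
  static_assert(std::is_trivially_copyable_v<T>, "Type must be trivially copyable");
  std::vector<std::uint8_t> bytes(sizeof(T));
  std::memcpy(bytes.data(), &value, sizeof(T));
  return bytes;
}
```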

{
static_assert(std::is_trivially_copyable_v<T>, "Type must be trivially copyable");
static_assert(sizeof(T) > 0, "Type size must be greater than zero");
Contributor

nit: No C++ type has zero size (using Empty = T[0] for some T is UB)
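
As a small check of this point, even a struct without data members has a nonzero size, so the `sizeof(T) > 0` assertion can never fire. `Empty` and `emptySize` are hypothetical names used only for this sketch.

```cpp
#include <cassert>
#include <cstddef>

struct Empty {};  // hypothetical type with no data members

// Every complete object type occupies at least one byte.
static_assert(sizeof(Empty) >= 1, "even an empty struct has nonzero size");

constexpr std::size_t emptySize() { return sizeof(Empty); }
```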

}

/**
* @brief Get the underlying data as a byte array.
*
* @returns Pointer to the underlying data.
*/
[[nodiscard]] const uint8_t* data() const { return _data.data(); }

/**
* @brief Get the size of the data in bytes.
*
* @returns Size of the data in bytes.
*/
[[nodiscard]] size_t size() const { return _data.size(); }

/**
* @brief Check if the user data is empty.
*
* @returns True if no data is stored, false otherwise.
*/
[[nodiscard]] bool empty() const { return _data.empty(); }

/**
* @brief Get the data as a string.
*
* @returns String representation of the data.
*/
[[nodiscard]] std::string asString() const { return std::string(_data.begin(), _data.end()); }
Contributor

This copies the values. Is that what you want, or would a std::string_view do?
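
A non-copying accessor could be sketched along these lines. `asStringView` is a hypothetical free function, not part of UCXX, and the returned view is valid only while the underlying vector is alive and unmodified, which is the trade-off against returning a `std::string`.

```cpp
#include <cassert>
#include <cstdint>
#include <string_view>
#include <vector>

// Hypothetical helper, not part of UCXX: view the stored bytes as text without
// copying them. Reading through `const char*` is permitted for any object.
inline std::string_view asStringView(const std::vector<std::uint8_t>& bytes)
{
  return std::string_view(reinterpret_cast<const char*>(bytes.data()), bytes.size());
}
```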


/**
* @brief Get the data as a specific POD type.
*
* @returns Reference to the data interpreted as type T.
* @throws std::runtime_error if the size doesn't match.
*/
template <typename T>
[[nodiscard]] const T& as() const
{
static_assert(std::is_trivially_copyable_v<T>, "Type must be trivially copyable");
if (_data.size() != sizeof(T)) {
throw std::runtime_error("AmUserHeader size mismatch: expected " + std::to_string(sizeof(T)) +
" bytes, got " + std::to_string(_data.size()));
}
return *reinterpret_cast<const T*>(_data.data());
Contributor

This needs to be std::memcpy as well, I think, due to aliasing issues.
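
A `std::memcpy`-based read might look like the sketch below. `fromBytes` is a hypothetical free function, not part of UCXX. It returns the value by copy rather than by `const T&`, which changes the signature of `as()`, but it avoids both the aliasing problem and any alignment assumption about the byte storage.

```cpp
#include <cassert>
#include <cstdint>
#include <cstring>
#include <stdexcept>
#include <string>
#include <type_traits>
#include <vector>

// Hypothetical helper, not part of UCXX: rebuild a trivially copyable value
// from its byte representation by copying into a properly typed object.
template <typename T>
T fromBytes(const std::vector<std::uint8_t>& bytes)
{
  static_assert(std::is_trivially_copyable_v<T>, "Type must be trivially copyable");
  static_assert(std::is_default_constructible_v<T>, "Type must be default constructible");
  if (bytes.size() != sizeof(T)) {
    throw std::runtime_error("size mismatch: expected " + std::to_string(sizeof(T)) +
                             " bytes, got " + std::to_string(bytes.size()));
  }
  T value;
  std::memcpy(&value, bytes.data(), sizeof(T));
  return value;
}
```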

}

/**
* @brief Get a copy of the underlying byte vector.
*
* @returns Copy of the underlying data vector.
*/
[[nodiscard]] std::vector<uint8_t> getBytes() const { return _data; }
};

/**
* @brief Information of an Active Message receiver callback.
*
@@ -182,18 +332,36 @@ class AmReceiverCallbackInfo {
public:
const AmReceiverCallbackOwnerType owner; ///< The owner name of the callback
const AmReceiverCallbackIdType id; ///< The unique identifier of the callback
const bool delayReceive; ///< Whether to delay receiving data (user-controlled)
const std::optional<AmUserHeader> userHeader; ///< Optional arbitrary user header data

AmReceiverCallbackInfo() = delete;

/**
* @brief Construct an AmReceiverCallbackInfo object.
*
* @param[in] owner The owner name of the callback.
* @param[in] id The unique identifier of the callback.
* @param[in] owner The owner name of the callback.
* @param[in] id The unique identifier of the callback.
* @param[in] delayReceive Whether to delay receiving data, allowing user to control when
* ucp_am_recv_data_nbx is called.
* @param[in] userHeader Optional arbitrary user header data to be transmitted with the AM.
*/
AmReceiverCallbackInfo(const AmReceiverCallbackOwnerType owner, AmReceiverCallbackIdType id);
AmReceiverCallbackInfo(const AmReceiverCallbackOwnerType owner,
AmReceiverCallbackIdType id,
bool delayReceive = false,
std::optional<AmUserHeader> userHeader = std::nullopt);
};

/**
* @brief Active Message receiver callback.
*
* Type for a custom Active Message receiver callback, executed by the remote worker upon
* Active Message request completion. The first parameter is the request that completed,
* the second is the handle of the UCX endpoint of the sender, and the third is the
* receiver callback information, including the optional user header.
*/
typedef std::function<void(std::shared_ptr<Request>, ucp_ep_h, AmReceiverCallbackInfo&)>
AmReceiverCallbackType;
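
To illustrate the extended three-argument callback shape, the sketch below defines a callback that reads the optional user header. Every `Mock*` type and `makeHeaderLogger` are hypothetical stand-ins so that the example compiles without UCX or UCXX headers; in particular, the header is modelled as a plain `std::string` rather than an `AmUserHeader`.

```cpp
#include <cassert>
#include <cstdint>
#include <functional>
#include <memory>
#include <optional>
#include <string>

// Hypothetical stand-ins mirroring the shapes shown in the diff.
struct MockRequest {};
using MockEpHandle = void*;  // stands in for ucp_ep_h

struct MockCallbackInfo {
  std::string owner;
  std::uint64_t id;
  bool delayReceive;
  std::optional<std::string> userHeader;  // stands in for std::optional<AmUserHeader>
};

using MockReceiverCallback =
  std::function<void(std::shared_ptr<MockRequest>, MockEpHandle, MockCallbackInfo&)>;

// Builds a callback that records the user header, or a placeholder if none was sent.
// `sink` must outlive the returned callback because it is captured by reference.
MockReceiverCallback makeHeaderLogger(std::string& sink)
{
  return [&sink](std::shared_ptr<MockRequest>, MockEpHandle, MockCallbackInfo& info) {
    sink = info.userHeader.value_or("<no header>");
  };
}
```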

/**
* @brief Serialized form of a remote key.
*