diff --git a/cpp/include/ucxx/internal/request_am.h b/cpp/include/ucxx/internal/request_am.h index c259d4886..b23f088bf 100644 --- a/cpp/include/ucxx/internal/request_am.h +++ b/cpp/include/ucxx/internal/request_am.h @@ -1,5 +1,5 @@ /** - * SPDX-FileCopyrightText: Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. + * SPDX-FileCopyrightText: Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. * SPDX-License-Identifier: BSD-3-Clause */ #include @@ -38,6 +38,8 @@ class RecvAmMessage { std::shared_ptr _request{ nullptr}; ///< Request which will later be notified/delivered to user std::shared_ptr _buffer{nullptr}; ///< Buffer containing the received data + std::optional _receiverCallbackInfo{ + std::nullopt}; ///< Callback info with user header RecvAmMessage() = delete; RecvAmMessage(const RecvAmMessage&) = delete; @@ -50,18 +52,21 @@ class RecvAmMessage { * * Construct the object, setting attributes that are later needed by the callback. * - * @param[in] amData active messages worker data. - * @param[in] ep handle containing address of the reply endpoint (i.e., - endpoint where user is requesting to receive). - * @param[in] request request to be later notified/delivered to user. - * @param[in] buffer buffer containing the received data - * @param[in] receiverCallback receiver callback to execute when request completes. + * @param[in] amData active messages worker data. + * @param[in] ep handle containing address of the reply endpoint (i.e., + * endpoint where user is requesting to receive). + * @param[in] request request to be later notified/delivered to user. + * @param[in] buffer buffer containing the received data. + * @param[in] receiverCallback receiver callback to execute when request completes. + * @param[in] receiverCallbackInfo receiver callback info to execute when request completes, + * including user header. 
*/ RecvAmMessage(internal::AmData* amData, ucp_ep_h ep, std::shared_ptr request, std::shared_ptr buffer, - AmReceiverCallbackType receiverCallback = AmReceiverCallbackType()); + AmReceiverCallbackType receiverCallback = AmReceiverCallbackType(), + std::optional receiverCallbackInfo = std::nullopt); /** * @brief Set the UCP request. diff --git a/cpp/include/ucxx/request_am.h b/cpp/include/ucxx/request_am.h index a1aae9655..bc0a3a111 100644 --- a/cpp/include/ucxx/request_am.h +++ b/cpp/include/ucxx/request_am.h @@ -1,5 +1,5 @@ /** - * SPDX-FileCopyrightText: Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. + * SPDX-FileCopyrightText: Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. * SPDX-License-Identifier: BSD-3-Clause */ #pragma once @@ -34,6 +34,20 @@ class RequestAm : public Request { std::string _header{}; ///< Retain copy of header for send requests as workaround for ///< https://github.com/openucx/ucx/issues/10424 + /** + * @brief Common implementation for receiveData operations. + * + * @param[in] amData AM data containing pointer and length. + * @param[in] bufferPtr Pointer to the buffer to receive data into. + * @param[in] managedBuffer Optional managed buffer for lifetime management (nullptr for raw + * pointers). + * + * @returns A shared_ptr to a Request object for the delayed receive operation. + */ + std::shared_ptr receiveDataImpl(const AmData& amData, + void* bufferPtr, + Buffer* managedBuffer); + /** * @brief Private constructor of `ucxx::RequestAm`. * @@ -154,6 +168,65 @@ class RequestAm : public Request { const ucp_am_recv_param_t* param); [[nodiscard]] std::shared_ptr getRecvBuffer() override; + + /** + * @brief Receive delayed Active Message data into a user-provided buffer. + * + * This method is used for delayed receive operations where `delayReceive` was enabled. + * It takes a user-provided buffer and internally calls `ucp_am_recv_data_nbx` to receive + * the AM data that was stored when the message first arrived. 
Returns a `Request` object + * that the user can wait on for completion. + * + * @param[in] buffer The buffer to receive the AM data into. Must be large enough to hold + * the AM data (length available via `getAmDataLength()`). + * + * @returns A shared_ptr to a Request object that can be waited on for completion. + * Returns nullptr if this was not a delayed receive operation or AM data + * is not available. + */ + [[nodiscard]] std::shared_ptr receiveData(std::shared_ptr buffer); + + /** + * @brief Receive delayed Active Message data into a user-provided buffer pointer. + * + * This method is used for delayed receive operations where `delayReceive` was enabled. + * It takes a user-provided pointer to buffer and internally calls `ucp_am_recv_data_nbx` + * to receive the AM data that was stored when the message first arrived. Returns a + * `Request` object that the user can wait on for completion. + * + * @param[in] buffer The buffer pointer to receive the AM data into. Must be large enough + * to hold the AM data (length available via `getAmDataLength()`). + * + * @returns A shared_ptr to a Request object that can be waited on for completion. + * Returns nullptr if this was not a delayed receive operation or AM data + * is not available. + */ + [[nodiscard]] std::shared_ptr receiveData(void* buffer); + + /** + * @brief Get the length of delayed Active Message data. + * + * This method returns the length of the AM data that was stored when delayReceive + * was enabled and the message was received via the receive callback. This allows + * users to allocate appropriately sized buffers before calling receiveData(). + * + * @returns The length of the AM data in bytes, or 0 if this was not a delayed + * receive operation or AM data is not available. + */ + [[nodiscard]] size_t getAmDataLength(); + + /** + * @brief Check if this request uses delayed receive. 
+ * + * This method returns true if this request is a delayed receive operation where + * delayReceive was enabled and AM data is stored for later retrieval. If true, + * users should use getAmDataLength() and receiveData() to retrieve the data. + * If false, users should use getRecvBuffer() to access immediately received data. + * + * @returns True if this is a delayed receive operation with stored AM data, + * false for immediate/eager receives or if no AM data is available. + */ + [[nodiscard]] bool isDelayedReceive(); }; } // namespace ucxx diff --git a/cpp/include/ucxx/request_data.h b/cpp/include/ucxx/request_data.h index ca5b68bbc..d700dfcee 100644 --- a/cpp/include/ucxx/request_data.h +++ b/cpp/include/ucxx/request_data.h @@ -1,5 +1,5 @@ /** - * SPDX-FileCopyrightText: Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. + * SPDX-FileCopyrightText: Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. * SPDX-License-Identifier: BSD-3-Clause */ #pragma once @@ -61,6 +61,10 @@ class AmSend { class AmReceive { public: std::shared_ptr<::ucxx::Buffer> _buffer{nullptr}; ///< The AM received message buffer + std::optional<::ucxx::AmData> _amData{ + std::nullopt}; ///< The AM data pointer and length for delayed receiving + void* _rawBuffer{nullptr}; ///< Raw buffer pointer for delayed receiving + size_t _rawBufferSize{0}; ///< Size of the raw buffer /** * @brief Constructor for Active Message-specific receive data. diff --git a/cpp/include/ucxx/typedefs.h b/cpp/include/ucxx/typedefs.h index b4c76ed25..234fa9bdd 100644 --- a/cpp/include/ucxx/typedefs.h +++ b/cpp/include/ucxx/typedefs.h @@ -7,8 +7,12 @@ #include #include #include +#include +#include #include #include +#include +#include #include @@ -140,15 +144,6 @@ typedef RequestCallbackUserData EndpointCloseCallbackUserData; */ typedef std::function(size_t)> AmAllocatorType; -/** - * @brief Active Message receiver callback. 
- * - * Type for a custom Active Message receiver callback, executed by the remote worker upon - * Active Message request completion. The first parameter is the request that completed, - * the second is the handle of the UCX endpoint of the sender. - */ -typedef std::function, ucp_ep_h)> AmReceiverCallbackType; - /** * @brief Active Message receiver callback owner name. * @@ -173,6 +168,161 @@ typedef uint64_t AmReceiverCallbackIdType; */ typedef const std::string AmReceiverCallbackInfoSerialized; +/** + * @brief Active Message data information for delayed receiving. + * + * Structure containing the Active Message data pointer and length that will be used + * when the user chooses to delay receiving and handle ucp_am_recv_data_nbx manually. + */ +struct AmData { + void* data; ///< The Active Message data pointer from the receive callback + size_t length; ///< The length of the Active Message data + + AmData() : data(nullptr), length(0) {} + + /** + * @brief Construct an AmData object. + * + * @param[in] data The Active Message data pointer from the receive callback. + * @param[in] length The length of the Active Message data. + */ + AmData(void* data, size_t length) : data(data), length(length) {} +}; + +/** + * @brief Container for arbitrary user header data that can be attached to Active Messages. + * + * This class provides a type-safe interface for storing arbitrary user header data that will be + * serialized and transmitted with Active Messages. It supports common data types while + * also allowing direct access to the underlying byte storage for custom serialization. + */ +class AmUserHeader { + private: + std::vector _data; + + public: + AmUserHeader() = delete; + + /** + * @brief Construct AmUserHeader from a byte array. + * + * @param[in] data Pointer to the data to copy. + * @param[in] size Size of the data in bytes. + * + * @throws std::invalid_argument if data is null or if size is 0. 
+ */ + AmUserHeader(const void* data, size_t size) + { + if (size == 0) { throw std::invalid_argument("AmUserHeader size must be greater than zero"); } + if (data == nullptr) { + throw std::invalid_argument("AmUserHeader data pointer cannot be null"); + } + _data.assign(static_cast(data), static_cast(data) + size); + } + + /** + * @brief Construct AmUserHeader from a string. + * + * @param[in] str The string to store. + * + * @throws std::invalid_argument if the string is empty. + */ + explicit AmUserHeader(const std::string& str) : _data(str.begin(), str.end()) + { + if (str.empty()) { throw std::invalid_argument("AmUserHeader string cannot be empty"); } + } + + /** + * @brief Construct AmUserHeader from a vector of bytes. + * + * @param[in] data The byte vector to copy. + * + * @throws std::invalid_argument if the vector is empty. + */ + explicit AmUserHeader(const std::vector& data) : _data(data) + { + if (data.empty()) { throw std::invalid_argument("AmUserHeader vector cannot be empty"); } + } + + /** + * @brief Construct AmUserHeader from a vector of bytes (move constructor). + * + * @param[in] data The byte vector to move. + * + * @throws std::invalid_argument if the vector is empty. + */ + explicit AmUserHeader(std::vector&& data) : _data(std::move(data)) + { + if (_data.empty()) { throw std::invalid_argument("AmUserHeader vector cannot be empty"); } + } + + /** + * @brief Template constructor for POD types. + * + * @param[in] value The POD value to store. + */ + template + explicit AmUserHeader(const T& value) + : _data(reinterpret_cast(&value), + reinterpret_cast(&value) + sizeof(T)) + { + static_assert(std::is_trivially_copyable_v, "Type must be trivially copyable"); + static_assert(sizeof(T) > 0, "Type size must be greater than zero"); + } + + /** + * @brief Get the underlying data as a byte array. + * + * @returns Pointer to the underlying data. 
+ */ + [[nodiscard]] const uint8_t* data() const { return _data.data(); } + + /** + * @brief Get the size of the data in bytes. + * + * @returns Size of the data in bytes. + */ + [[nodiscard]] size_t size() const { return _data.size(); } + + /** + * @brief Check if the user data is empty. + * + * @returns True if no data is stored, false otherwise. + */ + [[nodiscard]] bool empty() const { return _data.empty(); } + + /** + * @brief Get the data as a string. + * + * @returns String representation of the data. + */ + [[nodiscard]] std::string asString() const { return std::string(_data.begin(), _data.end()); } + + /** + * @brief Get the data as a specific POD type. + * + * @returns Reference to the data interpreted as type T. + * @throws std::runtime_error if the size doesn't match. + */ + template + [[nodiscard]] const T& as() const + { + static_assert(std::is_trivially_copyable_v, "Type must be trivially copyable"); + if (_data.size() != sizeof(T)) { + throw std::runtime_error("AmUserHeader size mismatch: expected " + std::to_string(sizeof(T)) + + " bytes, got " + std::to_string(_data.size())); + } + return *reinterpret_cast(_data.data()); + } + + /** + * @brief Get a copy of the underlying byte vector. + * + * @returns Copy of the underlying data vector. + */ + [[nodiscard]] std::vector getBytes() const { return _data; } +}; + /** * @brief Information of an Active Message receiver callback. * @@ -182,18 +332,36 @@ class AmReceiverCallbackInfo { public: const AmReceiverCallbackOwnerType owner; ///< The owner name of the callback const AmReceiverCallbackIdType id; ///< The unique identifier of the callback + const bool delayReceive; ///< Whether to delay receiving data (user-controlled) + const std::optional userHeader; ///< Optional arbitrary user header data AmReceiverCallbackInfo() = delete; /** * @brief Construct an AmReceiverCallbackInfo object. * - * @param[in] owner The owner name of the callback. - * @param[in] id The unique identifier of the callback. 
+ * @param[in] owner The owner name of the callback. + * @param[in] id The unique identifier of the callback. + * @param[in] delayReceive Whether to delay receiving data, allowing user to control when + * ucp_am_recv_data_nbx is called. + * @param[in] userHeader Optional arbitrary user header data to be transmitted with the AM. */ - AmReceiverCallbackInfo(const AmReceiverCallbackOwnerType owner, AmReceiverCallbackIdType id); + AmReceiverCallbackInfo(const AmReceiverCallbackOwnerType owner, + AmReceiverCallbackIdType id, + bool delayReceive = false, + std::optional userHeader = std::nullopt); }; +/** + * @brief Active Message receiver callback. + * + * Type for a custom Active Message receiver callback, executed by the remote worker upon + * Active Message request completion. The first parameter is the request that completed, + * the second is the handle of the UCX endpoint of the sender, and the third is the callback info. + */ +typedef std::function, ucp_ep_h, AmReceiverCallbackInfo&)> + AmReceiverCallbackType; + /** * @brief Serialized form of a remote key. * diff --git a/cpp/src/internal/request_am.cpp b/cpp/src/internal/request_am.cpp index d45dee659..5b8ccafaa 100644 --- a/cpp/src/internal/request_am.cpp +++ b/cpp/src/internal/request_am.cpp @@ -1,5 +1,5 @@ /** - * SPDX-FileCopyrightText: Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. + * SPDX-FileCopyrightText: Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. 
* SPDX-License-Identifier: BSD-3-Clause */ #include @@ -18,8 +18,9 @@ RecvAmMessage::RecvAmMessage(internal::AmData* amData, ucp_ep_h ep, std::shared_ptr request, std::shared_ptr buffer, - AmReceiverCallbackType receiverCallback) - : _amData(amData), _ep(ep), _request(request) + AmReceiverCallbackType receiverCallback, + std::optional receiverCallbackInfo) + : _amData(amData), _ep(ep), _request(request), _receiverCallbackInfo(receiverCallbackInfo) { std::visit(data::dispatch{ [this, buffer](data::AmReceive& amReceive) { amReceive._buffer = buffer; }, @@ -29,7 +30,7 @@ RecvAmMessage::RecvAmMessage(internal::AmData* amData, if (receiverCallback) { _request->_callback = [this, receiverCallback](ucs_status_t, std::shared_ptr) { - receiverCallback(_request, _ep); + if (_receiverCallbackInfo) { receiverCallback(_request, _ep, *_receiverCallbackInfo); } }; } } diff --git a/cpp/src/request_am.cpp b/cpp/src/request_am.cpp index df3e3bbd5..d89d25a9c 100644 --- a/cpp/src/request_am.cpp +++ b/cpp/src/request_am.cpp @@ -8,6 +8,7 @@ #include #include #include +#include #include #include @@ -22,8 +23,10 @@ namespace ucxx { AmReceiverCallbackInfo::AmReceiverCallbackInfo(const AmReceiverCallbackOwnerType owner, - AmReceiverCallbackIdType id) - : owner(owner), id(id) + AmReceiverCallbackIdType id, + bool delayReceive, + std::optional userHeader) + : owner(owner), id(id), delayReceive(delayReceive), userHeader(userHeader) { } @@ -58,8 +61,26 @@ struct AmHeader { AmReceiverCallbackIdType id{}; decode(&id, sizeof(id)); - return AmHeader{.memoryType = memoryType, - .receiverCallbackInfo = AmReceiverCallbackInfo(owner, id)}; + bool delayReceive{false}; + decode(&delayReceive, sizeof(delayReceive)); + + // Check if user data is present + bool hasUserHeader{false}; + decode(&hasUserHeader, sizeof(hasUserHeader)); + + std::optional userHeader = std::nullopt; + if (hasUserHeader) { + size_t userHeaderSize{0}; + decode(&userHeaderSize, sizeof(userHeaderSize)); + + std::vector 
userHeaderBytes(userHeaderSize); + decode(userHeaderBytes.data(), userHeaderSize); + userHeader = AmUserHeader(std::move(userHeaderBytes)); + } + + return AmHeader{ + .memoryType = memoryType, + .receiverCallbackInfo = AmReceiverCallbackInfo(owner, id, delayReceive, userHeader)}; } return AmHeader{.memoryType = memoryType, .receiverCallbackInfo = std::nullopt}; @@ -69,9 +90,17 @@ struct AmHeader { { size_t offset{0}; bool hasReceiverCallback{static_cast(receiverCallbackInfo)}; - const size_t ownerSize = (receiverCallbackInfo) ? receiverCallbackInfo->owner.size() : 0; + const size_t ownerSize = (receiverCallbackInfo) ? receiverCallbackInfo->owner.size() : 0; + const size_t userHeaderSize = (receiverCallbackInfo && receiverCallbackInfo->userHeader) + ? receiverCallbackInfo->userHeader->size() + : 0; + const bool hasUserHeader = (receiverCallbackInfo && receiverCallbackInfo->userHeader); const size_t amReceiverCallbackInfoSize = - (receiverCallbackInfo) ? sizeof(ownerSize) + ownerSize + sizeof(receiverCallbackInfo->id) : 0; + (receiverCallbackInfo) + ? sizeof(ownerSize) + ownerSize + sizeof(receiverCallbackInfo->id) + + sizeof(receiverCallbackInfo->delayReceive) + sizeof(hasUserHeader) + + (hasUserHeader ? 
sizeof(userHeaderSize) + userHeaderSize : 0) + : 0; const size_t totalSize = sizeof(memoryType) + sizeof(hasReceiverCallback) + amReceiverCallbackInfoSize; std::string serialized(totalSize, 0); @@ -87,6 +116,12 @@ struct AmHeader { encode(&ownerSize, sizeof(ownerSize)); encode(receiverCallbackInfo->owner.c_str(), ownerSize); encode(&receiverCallbackInfo->id, sizeof(receiverCallbackInfo->id)); + encode(&receiverCallbackInfo->delayReceive, sizeof(receiverCallbackInfo->delayReceive)); + encode(&hasUserHeader, sizeof(hasUserHeader)); + if (hasUserHeader) { + encode(&userHeaderSize, sizeof(userHeaderSize)); + encode(receiverCallbackInfo->userHeader->data(), userHeaderSize); + } } return serialized; @@ -263,85 +298,125 @@ ucs_status_t RequestAm::recvCallback(void* arg, if (req->getStatus() != UCS_INPROGRESS) return req->getStatus(); if (is_rndv) { - if (amData->_allocators.find(amHeader.memoryType) == amData->_allocators.end()) { - // TODO: Is a hard failure better? - // ucxx_debug("Unsupported memory type %d", amHeader.memoryType); - // internal::RecvAmMessage recvAmMessage(amData, ep, req, nullptr); - // recvAmMessage.callback(nullptr, UCS_ERR_UNSUPPORTED); - // return UCS_ERR_UNSUPPORTED; - - ucxx_trace_req("No allocator registered for memory type %u, falling back to host memory.", - amHeader.memoryType); - amHeader.memoryType = UCS_MEMORY_TYPE_HOST; - } - - try { - buf = amData->_allocators.at(amHeader.memoryType)(length); - } catch (const std::exception& e) { - ucxx_debug("Exception calling allocator: %s", e.what()); - } - - auto recvAmMessage = - std::make_shared(amData, ep, req, buf, receiverCallback); + // Check if delayed receive is requested for rendezvous messages + bool shouldDelayReceive = + amHeader.receiverCallbackInfo && amHeader.receiverCallbackInfo->delayReceive; + + if (shouldDelayReceive) { + // For delayed receive: don't allocate buffer, don't call ucp_am_recv_data_nbx, + // store AM data pointer, return UCS_INPROGRESS + auto& amReceiveData = 
std::get(req->_requestData); + amReceiveData._amData = AmData(data, length); + + if (req->_enablePythonFuture) + ucxx_trace_req_f(ownerString.c_str(), + req.get(), + nullptr, + "amRecv rndv delayed", + "recvCallback, ep: %p, data: %p, size: %lu, future: %p, future handle: %p", + ep, + data, + length, + req->_future.get(), + req->_future->getHandle()); + else + ucxx_trace_req_f(ownerString.c_str(), + req.get(), + nullptr, + "amRecv rndv delayed", + "recvCallback, ep: %p, data: %p, size: %lu", + ep, + data, + length); + + // Execute receiver callback if present + if (receiverCallback) { receiverCallback(req, ep, *amHeader.receiverCallbackInfo); } - ucp_request_param_t requestParam = {.op_attr_mask = UCP_OP_ATTR_FIELD_CALLBACK | - UCP_OP_ATTR_FIELD_USER_DATA | - UCP_OP_ATTR_FLAG_NO_IMM_CMPL, - .cb = {.recv_am = _recvCompletedCallback}, - .user_data = recvAmMessage.get()}; - - if (buf == nullptr) { - ucxx_debug("Failed to allocate %lu bytes of memory", length); - recvAmMessage->_request->setStatus(UCS_ERR_NO_MEMORY); - return UCS_ERR_NO_MEMORY; - } - - ucs_status_ptr_t status = - ucp_am_recv_data_nbx(worker->getHandle(), data, buf->data(), length, &requestParam); + return UCS_INPROGRESS; + } else { + // Normal rendezvous receive path + if (amData->_allocators.find(amHeader.memoryType) == amData->_allocators.end()) { + // TODO: Is a hard failure better? 
+ // ucxx_debug("Unsupported memory type %d", amHeader.memoryType); + // internal::RecvAmMessage recvAmMessage(amData, ep, req, nullptr); + // recvAmMessage.callback(nullptr, UCS_ERR_UNSUPPORTED); + // return UCS_ERR_UNSUPPORTED; + + ucxx_trace_req("No allocator registered for memory type %u, falling back to host memory.", + amHeader.memoryType); + amHeader.memoryType = UCS_MEMORY_TYPE_HOST; + } - if (req->_enablePythonFuture) - ucxx_trace_req_f(ownerString.c_str(), - req.get(), - status, - "amRecv rndv", - "recvCallback, ep: %p, buffer: %p, size: %lu, future: %p, future handle: %p", - ep, - buf->data(), - length, - req->_future.get(), - req->_future->getHandle()); - else - ucxx_trace_req_f(ownerString.c_str(), - req.get(), - status, - "amRecv rndv", - "recvCallback, ep: %p, buffer: %p, size: %lu", - ep, - buf->data(), - length); + try { + buf = amData->_allocators.at(amHeader.memoryType)(length); + } catch (const std::exception& e) { + ucxx_debug("Exception calling allocator: %s", e.what()); + } - if (req->isCompleted()) { - // The request completed/errored immediately - ucs_status_t s = UCS_PTR_STATUS(status); - recvAmMessage->callback(nullptr, s); + auto recvAmMessage = std::make_shared( + amData, ep, req, buf, receiverCallback, amHeader.receiverCallbackInfo); - return s; - } else { - // The request will be handled by the callback - recvAmMessage->setUcpRequest(status); - amData->_registerInflightRequest(req); + ucp_request_param_t requestParam = {.op_attr_mask = UCP_OP_ATTR_FIELD_CALLBACK | + UCP_OP_ATTR_FIELD_USER_DATA | + UCP_OP_ATTR_FLAG_NO_IMM_CMPL, + .cb = {.recv_am = _recvCompletedCallback}, + .user_data = recvAmMessage.get()}; - { - std::lock_guard lock(amData->_mutex); - amData->_recvAmMessageMap.emplace(req.get(), recvAmMessage); + if (buf == nullptr) { + ucxx_debug("Failed to allocate %lu bytes of memory", length); + recvAmMessage->_request->setStatus(UCS_ERR_NO_MEMORY); + return UCS_ERR_NO_MEMORY; } - return UCS_INPROGRESS; + ucs_status_ptr_t status 
= + ucp_am_recv_data_nbx(worker->getHandle(), data, buf->data(), length, &requestParam); + + if (req->_enablePythonFuture) + ucxx_trace_req_f( + ownerString.c_str(), + req.get(), + status, + "amRecv rndv", + "recvCallback, ep: %p, buffer: %p, size: %lu, future: %p, future handle: %p", + ep, + buf->data(), + length, + req->_future.get(), + req->_future->getHandle()); + else + ucxx_trace_req_f(ownerString.c_str(), + req.get(), + status, + "amRecv rndv", + "recvCallback, ep: %p, buffer: %p, size: %lu", + ep, + buf->data(), + length); + + if (req->isCompleted()) { + // The request completed/errored immediately + ucs_status_t s = UCS_PTR_STATUS(status); + recvAmMessage->callback(nullptr, s); + + return s; + } else { + // The request will be handled by the callback + recvAmMessage->setUcpRequest(status); + amData->_registerInflightRequest(req); + + { + std::lock_guard lock(amData->_mutex); + amData->_recvAmMessageMap.emplace(req.get(), recvAmMessage); + } + + return UCS_INPROGRESS; + } } } else { buf = amData->_allocators.at(UCS_MEMORY_TYPE_HOST)(length); - internal::RecvAmMessage recvAmMessage(amData, ep, req, buf, receiverCallback); + internal::RecvAmMessage recvAmMessage( + amData, ep, req, buf, receiverCallback, amHeader.receiverCallbackInfo); if (buf == nullptr) { ucxx_debug("Failed to allocate %lu bytes of memory", length); recvAmMessage._request->setStatus(UCS_ERR_NO_MEMORY); @@ -386,6 +461,130 @@ std::shared_ptr RequestAm::getRecvBuffer() _requestData); } +std::shared_ptr RequestAm::receiveData(std::shared_ptr buffer) +{ + return std::visit( + data::dispatch{ + [this, &buffer](data::AmReceive amReceive) -> std::shared_ptr { + if (!amReceive._amData.has_value()) { + // No AM data available - not a delayed receive operation + return nullptr; + } + + auto& amData = amReceive._amData.value(); + + // Validate buffer size + if (buffer->getSize() < amData.length) { + throw std::runtime_error("Buffer too small for AM data"); + } + + return receiveDataImpl(amData, 
buffer->data(), buffer.get()); + }, + [](auto) -> std::shared_ptr { throw std::runtime_error("Unreachable"); }, + }, + _requestData); +} + +std::shared_ptr RequestAm::receiveData(void* data) +{ + return std::visit( + data::dispatch{ + [this, data](data::AmReceive amReceive) -> std::shared_ptr { + if (!amReceive._amData.has_value()) { + ucxx_trace_req_f(_ownerString.c_str(), + this, + _request, + _operationName.c_str(), + "receiveData, data %p", + data); + ucxx_warn("No AM data available, not a delayed receive operation"); + return nullptr; + } + + if (data == nullptr) { throw std::runtime_error("Buffer pointer cannot be null"); } + + auto& amData = amReceive._amData.value(); + return receiveDataImpl(amData, data, nullptr); + }, + [](auto) -> std::shared_ptr { throw std::runtime_error("Unreachable"); }, + }, + _requestData); +} + +std::shared_ptr RequestAm::receiveDataImpl(const AmData& amData, + void* bufferPtr, + Buffer* managedBuffer) +{ + // Create a new RequestAm for the delayed receive operation + auto request = std::shared_ptr(new RequestAm( + _worker, data::AmReceive(), "amDelayedReceive", _enablePythonFuture, nullptr, nullptr)); + + // Store the buffer information in the request + auto& requestAmReceiveData = std::get(request->_requestData); + if (managedBuffer) { + // For shared_ptr case, store the managed buffer + requestAmReceiveData._buffer = std::shared_ptr(managedBuffer, [](Buffer*) { + // Custom deleter that does nothing - the original shared_ptr manages the lifetime + }); + } else { + // For raw pointer case, store the raw buffer info + requestAmReceiveData._rawBuffer = bufferPtr; + requestAmReceiveData._rawBufferSize = amData.length; + } + + // Simple completion callback that directly calls request completion + auto callback = [](void* ucpRequest, ucs_status_t status, size_t /*length*/, void* userData) { + auto* req = static_cast(userData); + req->callback(ucpRequest, status); + }; + + // Set up UCP request parameters + ucp_request_param_t 
requestParam = { + .op_attr_mask = UCP_OP_ATTR_FIELD_CALLBACK | UCP_OP_ATTR_FIELD_USER_DATA, + .cb = {.recv_am = callback}, + .user_data = request.get()}; + + // Call ucp_am_recv_data_nbx + ucs_status_ptr_t status = ucp_am_recv_data_nbx( + _worker->getHandle(), amData.data, bufferPtr, amData.length, &requestParam); + + if (UCS_PTR_IS_ERR(status)) { + request->callback(nullptr, UCS_PTR_STATUS(status)); + } else if (status == nullptr) { + // Completed immediately + request->callback(nullptr, UCS_OK); + } else { + // Will be completed by callback - store UCP request + request->_request = status; + } + + return request; +} + +size_t RequestAm::getAmDataLength() +{ + return std::visit(data::dispatch{ + [](data::AmReceive amReceive) -> size_t { + if (!amReceive._amData.has_value()) { + return 0; // No AM data available + } + return amReceive._amData.value().length; + }, + [](auto) -> size_t { throw std::runtime_error("Unreachable"); }, + }, + _requestData); +} + +bool RequestAm::isDelayedReceive() +{ + return std::visit( + data::dispatch{ + [](data::AmReceive amReceive) -> bool { return amReceive._amData.has_value(); }, + [](auto) -> bool { return false; }, + }, + _requestData); +} + void RequestAm::request() { std::visit( diff --git a/cpp/tests/request.cpp b/cpp/tests/request.cpp index acbffc323..6a4a92b9a 100644 --- a/cpp/tests/request.cpp +++ b/cpp/tests/request.cpp @@ -5,15 +5,18 @@ #include #include #include +#include #include #include #include +#include #include #include #include #include +#include #include "include/utils.h" #include "ucxx/buffer.h" @@ -32,8 +35,8 @@ using ::testing::Values; typedef std::vector DataContainerType; -class RequestTest : public ::testing::TestWithParam< - std::tuple> { +class RequestTestBase : public ::testing::TestWithParam< + std::tuple> { protected: std::shared_ptr _context{nullptr}; std::shared_ptr _worker{nullptr}; @@ -163,7 +166,14 @@ class RequestTest : public ::testing::TestWithParam< } }; -TEST_P(RequestTest, ProgressAm) 
+class RequestTestAmAllocator : public RequestTestBase {}; + +class RequestTestNoAmAllocator : public RequestTestBase {}; + +// Limited test suite specifically for user header functionality to reduce test runtime +class RequestTestAmUserHeader : public RequestTestBase {}; + +TEST_P(RequestTestAmAllocator, ProgressAm) { if (_progressMode == ProgressMode::Wait) { GTEST_SKIP() << "Interrupting UCP worker progress operation in wait mode is not possible"; @@ -202,7 +212,7 @@ TEST_P(RequestTest, ProgressAm) ASSERT_THAT(_recv[0], ContainerEq(_send[0])); } -TEST_P(RequestTest, ProgressAmReceiverCallback) +TEST_P(RequestTestAmAllocator, ProgressAmReceiverCallback) { if (_progressMode == ProgressMode::Wait) { GTEST_SKIP() << "Interrupting UCP worker progress operation in wait mode is not possible"; @@ -228,7 +238,8 @@ TEST_P(RequestTest, ProgressAmReceiverCallback) // Define AM receiver callback and register with worker std::vector> receivedRequests; auto callback = ucxx::AmReceiverCallbackType( - [this, &receivedRequests, &mutex](std::shared_ptr req, ucp_ep_h) { + [this, &receivedRequests, &mutex]( + std::shared_ptr req, ucp_ep_h, const ucxx::AmReceiverCallbackInfo&) { { std::lock_guard lock(mutex); receivedRequests.push_back(req); @@ -264,7 +275,368 @@ TEST_P(RequestTest, ProgressAmReceiverCallback) ASSERT_THAT(_recv[0], ContainerEq(_send[0])); } -TEST_P(RequestTest, ProgressStream) +TEST_P(RequestTestNoAmAllocator, ProgressAmReceiverCallbackDelayedReceive) +{ + if (_progressMode == ProgressMode::Wait) { + GTEST_SKIP() << "Interrupting UCP worker progress operation in wait mode is not possible"; + } + + if (_memoryType == UCS_MEMORY_TYPE_CUDA) { +#if !UCXX_ENABLE_RMM + GTEST_SKIP() << "UCXX was not built with RMM support"; +#else + _worker->registerAmAllocator(UCS_MEMORY_TYPE_CUDA, [](size_t length) { + return std::make_shared(length); + }); +#endif + } + + // Define AM receiver callback's owner and id for callback with delayReceive enabled + ucxx::AmReceiverCallbackInfo 
receiverCallbackInfo("TestApp", 0, true); // delayReceive = true + + // Mutex required for blocking progress mode + std::mutex mutex; + + // Storage for the received request and receive operation + std::shared_ptr receivedRequest{nullptr}; + std::shared_ptr manualRecvBuffer{nullptr}; + std::unique_ptr rawBuffer{nullptr}; + std::shared_ptr receiveDataRequest{nullptr}; + + // Define AM receiver callback and register with worker + auto callback = ucxx::AmReceiverCallbackType( + [this, &receivedRequest, &manualRecvBuffer, &rawBuffer, &receiveDataRequest, &mutex]( + std::shared_ptr req, ucp_ep_h, const ucxx::AmReceiverCallbackInfo&) { + { + std::lock_guard lock(mutex); + receivedRequest = req; + + // Cast to RequestAm to access receiveData() method + auto requestAm = std::dynamic_pointer_cast(req); + ASSERT_NE(requestAm, nullptr) << "Request should be a RequestAm for AM operations"; + + // Check if this is a delayed receive operation + if (requestAm->isDelayedReceive()) { + // Delayed receive: use getAmDataLength() and receiveData() API + size_t messageLength = requestAm->getAmDataLength(); + ASSERT_GT(messageLength, 0) + << "AM data length should be greater than 0 for delayed receive"; + ASSERT_EQ(messageLength, _messageSize) + << "AM data length should match the sent message size"; + + // Allocate buffer based on the actual message length + if (_memoryType == UCS_MEMORY_TYPE_HOST) { + manualRecvBuffer = std::make_shared(messageLength); +#if UCXX_ENABLE_RMM + } else if (_memoryType == UCS_MEMORY_TYPE_CUDA) { + manualRecvBuffer = std::make_shared(messageLength); +#endif + } else { + FAIL() << "Unsupported memory type for test"; + } + + // Test both receiveData APIs: managed buffer and raw pointer + // First test null pointer validation for raw pointer API + EXPECT_THROW(std::ignore = requestAm->receiveData(nullptr), std::runtime_error); + + // Use the managed buffer receiveData() API + receiveDataRequest = requestAm->receiveData(manualRecvBuffer); + 
ASSERT_NE(receiveDataRequest, nullptr) + << "receiveData with managed buffer should return a valid request for delayed receive"; + } else { + // Immediate/eager receive: data is already available via getRecvBuffer() + manualRecvBuffer = requestAm->getRecvBuffer(); + ASSERT_NE(manualRecvBuffer, nullptr) + << "getRecvBuffer should return valid buffer for immediate receive"; + ASSERT_EQ(manualRecvBuffer->getSize(), _messageSize) + << "Received buffer size should match sent message size"; + + // For immediate receives, there's no additional request to wait on + receiveDataRequest = nullptr; + } + } + }); + _worker->registerAmReceiverCallback(receiverCallbackInfo, callback); + + allocate(1, false); + + // Submit and wait for transfers to complete + std::vector> requests; + requests.push_back(_ep->amSend(_sendPtr[0], _messageSize, _memoryType, receiverCallbackInfo)); + waitRequests(_worker, requests, _progressWorker); + + // Wait for the AM receiver callback to be called + while (receivedRequest == nullptr) + _progressWorker(); + + // Cast to check the receive type + auto requestAm = std::dynamic_pointer_cast(receivedRequest); + ASSERT_NE(requestAm, nullptr); + + if (requestAm->isDelayedReceive()) { + // For delayed receive: wait for receiveData request to be created and completed + while (receiveDataRequest == nullptr) + _progressWorker(); + + // Wait for the receive data request to complete + while (!receiveDataRequest->isCompleted()) + _progressWorker(); + + // Verify receive data request completed successfully + ASSERT_TRUE(receiveDataRequest->isCompleted()) << "Receive data request should be completed"; + ASSERT_EQ(receiveDataRequest->getStatus(), UCS_OK) + << "Receive data request should complete without error"; + } + // For immediate receives, no additional waiting needed - data is already available + + { + std::lock_guard lock(mutex); + + // Cast to RequestAm to verify it's the correct type + auto requestAm = std::dynamic_pointer_cast(receivedRequest); + 
ASSERT_NE(requestAm, nullptr) << "Request should be a RequestAm for AM operations"; + + if (requestAm->isDelayedReceive()) { + // Verify that the original delayed receive request has no regular receive buffer + ASSERT_EQ(requestAm->getRecvBuffer(), nullptr) + << "Delayed receive request should not have a receive buffer"; + + // Verify we have the manually received data + ASSERT_NE(manualRecvBuffer, nullptr) << "Manual receive buffer should be allocated"; + + // Verify buffer type matches expectation for delayed receive + ASSERT_THAT(manualRecvBuffer->getType(), + (_memoryType == UCS_MEMORY_TYPE_CUDA) ? _bufferType : ucxx::BufferType::Host); + } else { + // For immediate receives, the buffer should be available from getRecvBuffer() + ASSERT_NE(manualRecvBuffer, nullptr) << "Immediate receive buffer should be available"; + ASSERT_EQ(manualRecvBuffer, requestAm->getRecvBuffer()) + << "Buffer should match the one from getRecvBuffer()"; + + // Verify buffer type matches expectation for immediate receive + ASSERT_THAT(manualRecvBuffer->getType(), + (_messageSize >= _rndvThresh && _memoryType == UCS_MEMORY_TYPE_CUDA) + ? 
_bufferType + : ucxx::BufferType::Host); + } + + // Set up the received data for verification + _recvPtr[0] = manualRecvBuffer->data(); + } + + copyResults(); + + // Assert data correctness + ASSERT_THAT(_recv[0], ContainerEq(_send[0])); +} + +TEST_P(RequestTestNoAmAllocator, ProgressAmReceiverCallbackDelayedReceiveRawPointer) +{ + if (_progressMode == ProgressMode::Wait) { + GTEST_SKIP() << "Interrupting UCP worker progress operation in wait mode is not possible"; + } + + if (_memoryType == UCS_MEMORY_TYPE_CUDA) { +#if !UCXX_ENABLE_RMM + GTEST_SKIP() << "UCXX was not built with RMM support"; +#else + _worker->registerAmAllocator(UCS_MEMORY_TYPE_CUDA, [](size_t length) { + return std::make_shared(length); + }); +#endif + } + + // Only test with larger messages to ensure rendezvous/delayed receive + if (_messageSize < _rndvThresh) { + GTEST_SKIP() << "Test only runs with rendezvous messages for delayed receive"; + } + + // Define AM receiver callback with delayReceive enabled + ucxx::AmReceiverCallbackInfo receiverCallbackInfo("TestAppRaw", 0, true); + + std::mutex mutex; + std::shared_ptr receivedRequest{nullptr}; + std::unique_ptr rawBuffer{nullptr}; + std::shared_ptr receiveDataRequest{nullptr}; + + auto callback = ucxx::AmReceiverCallbackType( + [this, &receivedRequest, &rawBuffer, &receiveDataRequest, &mutex]( + std::shared_ptr req, ucp_ep_h, const ucxx::AmReceiverCallbackInfo&) { + { + std::lock_guard lock(mutex); + receivedRequest = req; + + auto requestAm = std::dynamic_pointer_cast(req); + ASSERT_NE(requestAm, nullptr); + + if (requestAm->isDelayedReceive()) { + size_t messageLength = requestAm->getAmDataLength(); + ASSERT_EQ(messageLength, _messageSize); + + // Allocate raw buffer and test the raw pointer receiveData API + rawBuffer = std::make_unique(messageLength); + receiveDataRequest = requestAm->receiveData(rawBuffer.get()); + ASSERT_NE(receiveDataRequest, nullptr); + } else { + FAIL() << "Expected delayed receive operation"; + } + } + }); + + 
_worker->registerAmReceiverCallback(receiverCallbackInfo, callback); + allocate(1, false); + + std::vector> requests; + requests.push_back(_ep->amSend(_sendPtr[0], _messageSize, _memoryType, receiverCallbackInfo)); + waitRequests(_worker, requests, _progressWorker); + + while (receivedRequest == nullptr) + _progressWorker(); + while (receiveDataRequest == nullptr) + _progressWorker(); + while (!receiveDataRequest->isCompleted()) + _progressWorker(); + + { + std::lock_guard lock(mutex); + ASSERT_TRUE(receiveDataRequest->isCompleted()); + ASSERT_EQ(receiveDataRequest->getStatus(), UCS_OK); + _recvPtr[0] = rawBuffer.get(); + } + + copyResults(); + ASSERT_THAT(_recv[0], ContainerEq(_send[0])); +} + +TEST_P(RequestTestAmUserHeader, ProgressAmReceiverCallbackWithUserHeader) +{ + if (_progressMode == ProgressMode::Wait) { + GTEST_SKIP() << "Interrupting UCP worker progress operation in wait mode is not possible"; + } + + if (_registerCustomAmAllocator && _memoryType == UCS_MEMORY_TYPE_CUDA) { +#if !UCXX_ENABLE_RMM + GTEST_SKIP() << "UCXX was not built with RMM support"; +#else + _worker->registerAmAllocator(UCS_MEMORY_TYPE_CUDA, [](size_t length) { + return std::make_shared(length); + }); +#endif + } + + // Create user header data - test different data types + std::string userHeaderString = "Hello from user header!"; + ucxx::AmUserHeader userHeader(userHeaderString); + std::optional receivedUserHeader = std::nullopt; + + // Define AM receiver callback's owner and id with user header + ucxx::AmReceiverCallbackInfo receiverCallbackInfo("TestAppWithHeader", 42, false, userHeader); + + // Mutex required for blocking progress mode, otherwise `receivedRequests` may be + // accessed before `push_back()` completed. 
+ std::mutex mutex; + + // Define AM receiver callback and register with worker + std::vector> receivedRequests; + auto callback = ucxx::AmReceiverCallbackType( + [this, &receivedRequests, &mutex, &receivedUserHeader]( + std::shared_ptr req, ucp_ep_h, ucxx::AmReceiverCallbackInfo& info) { + { + std::lock_guard lock(mutex); + receivedRequests.push_back(req); + // Keep the user header from the callback info so it can be checked after the transfer + receivedUserHeader = std::move(info.userHeader); + } + }); + + _worker->registerAmReceiverCallback(receiverCallbackInfo, callback); + + allocate(1, false); + + // Submit and wait for transfers to complete + std::vector> requests; + requests.push_back(_ep->amSend(_sendPtr[0], _messageSize, _memoryType, receiverCallbackInfo)); + waitRequests(_worker, requests, _progressWorker); + + while (receivedRequests.size() < 1) + _progressWorker(); + + { + std::lock_guard lock(mutex); + _recvPtr[0] = receivedRequests[0]->getRecvBuffer()->data(); + + // Messages larger than `_rndvThresh` are rendezvous and will use custom allocator, + // smaller messages are eager and will always be host-allocated. + ASSERT_THAT(receivedRequests[0]->getRecvBuffer()->getType(), + (_registerCustomAmAllocator && _messageSize >= _rndvThresh) + ? 
_bufferType + : ucxx::BufferType::Host); + } + + copyResults(); + + // Assert header and data correctness; value() throws (failing the test) if no header arrived + ASSERT_EQ(receivedUserHeader.value().asString(), userHeaderString); + ASSERT_THAT(_recv[0], ContainerEq(_send[0])); +} + +TEST(AmUserHeaderTest, Validation) +{ + // Test that AmUserHeader constructors properly validate input + + // Test valid constructions (should not throw) + ASSERT_NO_THROW([]() { ucxx::AmUserHeader h("test"); }()); + ASSERT_NO_THROW([]() { + std::string s = "test"; + ucxx::AmUserHeader h(s); + }()); + ASSERT_NO_THROW([]() { + std::vector data(3, 0); + ucxx::AmUserHeader h(data); + }()); + ASSERT_NO_THROW([]() { + std::vector data(3, 0); + ucxx::AmUserHeader h(std::move(data)); + }()); + ASSERT_NO_THROW([]() { + int value = 42; + ucxx::AmUserHeader h(value); + }()); + ASSERT_NO_THROW([]() { ucxx::AmUserHeader h("test", 4); }()); + + // Test invalid constructions (should throw) + EXPECT_THROW([]() { ucxx::AmUserHeader h(std::string("")); }(), std::invalid_argument); + EXPECT_THROW( + []() { + std::string empty = ""; + ucxx::AmUserHeader h(empty); + }(), + std::invalid_argument); + EXPECT_THROW( + []() { + std::vector emptyVec; + ucxx::AmUserHeader h(emptyVec); + }(), + std::invalid_argument); + EXPECT_THROW( + []() { + std::vector emptyVec; + ucxx::AmUserHeader h(std::move(emptyVec)); + }(), + std::invalid_argument); + EXPECT_THROW([]() { ucxx::AmUserHeader h(nullptr, 5); }(), std::invalid_argument); + EXPECT_THROW([]() { ucxx::AmUserHeader h("test", 0); }(), std::invalid_argument); + EXPECT_THROW([]() { ucxx::AmUserHeader h("", 0); }(), std::invalid_argument); + + // Test data access works correctly + std::string testStr = "Hello"; + ucxx::AmUserHeader header(testStr); + ASSERT_EQ(header.size(), 5); + ASSERT_EQ(header.asString(), "Hello"); + ASSERT_FALSE(header.empty()); +} + +TEST_P(RequestTestNoAmAllocator, ProgressStream) { allocate(); @@ -285,7 +657,7 @@ TEST_P(RequestTest, ProgressStream) } } -TEST_P(RequestTest, ProgressTag) 
+TEST_P(RequestTestNoAmAllocator, ProgressTag) { allocate(); @@ -301,7 +673,7 @@ TEST_P(RequestTest, ProgressTag) ASSERT_THAT(_recv[0], ContainerEq(_send[0])); } -TEST_P(RequestTest, ProgressTagMulti) +TEST_P(RequestTestNoAmAllocator, ProgressTagMulti) { if (_progressMode == ProgressMode::Wait) { GTEST_SKIP() << "Interrupting UCP worker progress operation in wait mode is not possible"; @@ -348,7 +720,7 @@ TEST_P(RequestTest, ProgressTagMulti) ASSERT_THAT(_recv[i], ContainerEq(_send[i])); } -TEST_P(RequestTest, TagUserCallback) +TEST_P(RequestTestNoAmAllocator, TagUserCallback) { allocate(); @@ -382,7 +754,7 @@ TEST_P(RequestTest, TagUserCallback) ASSERT_THAT(_recv[0], ContainerEq(_send[0])); } -TEST_P(RequestTest, TagUserCallbackDiscardReturn) +TEST_P(RequestTestNoAmAllocator, TagUserCallbackDiscardReturn) { allocate(); @@ -423,7 +795,7 @@ TEST_P(RequestTest, TagUserCallbackDiscardReturn) ASSERT_THAT(_recv[0], ContainerEq(_send[0])); } -TEST_P(RequestTest, MemoryGet) +TEST_P(RequestTestNoAmAllocator, MemoryGet) { allocate(); @@ -450,7 +822,7 @@ TEST_P(RequestTest, MemoryGet) ASSERT_THAT(_recv[0], ContainerEq(_send[0])); } -TEST_P(RequestTest, MemoryGetPreallocated) +TEST_P(RequestTestNoAmAllocator, MemoryGetPreallocated) { allocate(); @@ -473,7 +845,7 @@ TEST_P(RequestTest, MemoryGetPreallocated) ASSERT_THAT(_recv[0], ContainerEq(_send[0])); } -TEST_P(RequestTest, MemoryGetWithOffset) +TEST_P(RequestTestNoAmAllocator, MemoryGetWithOffset) { if (_messageLength < 2) GTEST_SKIP() << "Message too small to perform operations with offsets"; allocate(); @@ -509,7 +881,7 @@ TEST_P(RequestTest, MemoryGetWithOffset) ASSERT_THAT(recvOffset, sendOffset); } -TEST_P(RequestTest, MemoryPut) +TEST_P(RequestTestNoAmAllocator, MemoryPut) { allocate(); @@ -536,7 +908,7 @@ TEST_P(RequestTest, MemoryPut) ASSERT_THAT(_recv[0], ContainerEq(_send[0])); } -TEST_P(RequestTest, MemoryPutPreallocated) +TEST_P(RequestTestNoAmAllocator, MemoryPutPreallocated) { allocate(); @@ -559,7 +931,7 @@ 
TEST_P(RequestTest, MemoryPutPreallocated) ASSERT_THAT(_recv[0], ContainerEq(_send[0])); } -TEST_P(RequestTest, MemoryPutWithOffset) +TEST_P(RequestTestNoAmAllocator, MemoryPutWithOffset) { if (_messageLength < 2) GTEST_SKIP() << "Message too small to perform operations with offsets"; allocate(); @@ -595,8 +967,56 @@ TEST_P(RequestTest, MemoryPutWithOffset) ASSERT_THAT(recvOffset, sendOffset); } -INSTANTIATE_TEST_SUITE_P(ProgressModes, - RequestTest, +// Custom naming function for parameterized tests +std::string generateTestName( + const ::testing::TestParamInfo>& + info) +{ + auto [bufferType, + registerCustomAmAllocator, + enableDelayedSubmission, + progressMode, + messageLength] = info.param; + + std::string name; + + // Buffer type + name += (bufferType == ucxx::BufferType::Host) ? "Host" : "RMM"; + + // Custom AM allocator + if (registerCustomAmAllocator) { name += "_CustomAmAlloc"; } + + // Delayed submission + if (enableDelayedSubmission) { name += "_DelayedSubmission"; } + + // Progress mode + name += "_"; + switch (progressMode) { + case ProgressMode::Polling: name += "Polling"; break; + case ProgressMode::Blocking: name += "Blocking"; break; + case ProgressMode::Wait: name += "Wait"; break; + case ProgressMode::ThreadPolling: name += "ThreadPolling"; break; + case ProgressMode::ThreadBlocking: name += "ThreadBlocking"; break; + } + + // Message length, truncated to whole MB or KB so distinct sizes get distinct names + name += "_Msg"; + if (messageLength == 0) { + name += "Empty"; + } else if (messageLength >= 1048576) { + name += std::to_string(messageLength / 1048576) + "MB"; + } else if (messageLength >= 1024) { + name += std::to_string(messageLength / 1024) + "KB"; + } else { + name += std::to_string(messageLength) + "B"; + } + + return name; +} + +// Tests that support custom AM allocator +INSTANTIATE_TEST_SUITE_P(HostProgressModes, + RequestTestAmAllocator, Combine(Values(ucxx::BufferType::Host), Values(false), Values(false), @@ -605,19 +1025,21 @@ INSTANTIATE_TEST_SUITE_P(ProgressModes, // ProgressMode::Wait, // Hangs on Stream 
ProgressMode::ThreadPolling, ProgressMode::ThreadBlocking), - Values(0, 1, 1024, 2048, 1048576))); + Values(0, 1, 1024, 2048, 1048576)), + generateTestName); -INSTANTIATE_TEST_SUITE_P(DelayedSubmission, - RequestTest, +INSTANTIATE_TEST_SUITE_P(HostDelayedSubmission, + RequestTestAmAllocator, Combine(Values(ucxx::BufferType::Host), Values(false), Values(true), Values(ProgressMode::ThreadPolling, ProgressMode::ThreadBlocking), - Values(0, 1, 1024, 2048, 1048576))); + Values(0, 1, 1024, 2048, 1048576)), + generateTestName); #if UCXX_ENABLE_RMM INSTANTIATE_TEST_SUITE_P(RMMProgressModes, - RequestTest, + RequestTestAmAllocator, Combine(Values(ucxx::BufferType::RMM), Values(false, true), Values(false), @@ -626,15 +1048,85 @@ INSTANTIATE_TEST_SUITE_P(RMMProgressModes, // ProgressMode::Wait, // Hangs on Stream ProgressMode::ThreadPolling, ProgressMode::ThreadBlocking), - Values(0, 1, 1024, 2048, 1048576))); + Values(0, 1, 1024, 2048, 1048576)), + generateTestName); INSTANTIATE_TEST_SUITE_P(RMMDelayedSubmission, - RequestTest, + RequestTestAmAllocator, Combine(Values(ucxx::BufferType::RMM), Values(false, true), Values(true), Values(ProgressMode::ThreadPolling, ProgressMode::ThreadBlocking), - Values(0, 1, 1024, 2048, 1048576))); + Values(0, 1, 1024, 2048, 1048576)), + generateTestName); +#endif + +// Tests that do NOT support custom AM allocator (always false for _registerCustomAmAllocator) +INSTANTIATE_TEST_SUITE_P(HostProgressModes, + RequestTestNoAmAllocator, + Combine(Values(ucxx::BufferType::Host), + Values(false), // Never use custom AM allocator for these tests + Values(false), + Values(ProgressMode::Polling, + ProgressMode::Blocking, + // ProgressMode::Wait, // Hangs on Stream + ProgressMode::ThreadPolling, + ProgressMode::ThreadBlocking), + Values(0, 1, 1024, 2048, 1048576)), + generateTestName); + +INSTANTIATE_TEST_SUITE_P(HostDelayedSubmission, + RequestTestNoAmAllocator, + Combine(Values(ucxx::BufferType::Host), + Values(false), // Never use custom AM allocator 
for these tests + Values(true), + Values(ProgressMode::ThreadPolling, ProgressMode::ThreadBlocking), + Values(0, 1, 1024, 2048, 1048576)), + generateTestName); + +#if UCXX_ENABLE_RMM +INSTANTIATE_TEST_SUITE_P(RMMProgressModes, + RequestTestNoAmAllocator, + Combine(Values(ucxx::BufferType::RMM), + Values(false), // Never use custom AM allocator for these tests + Values(false), + Values(ProgressMode::Polling, + ProgressMode::Blocking, + // ProgressMode::Wait, // Hangs on Stream + ProgressMode::ThreadPolling, + ProgressMode::ThreadBlocking), + Values(0, 1, 1024, 2048, 1048576)), + generateTestName); + +INSTANTIATE_TEST_SUITE_P(RMMDelayedSubmission, + RequestTestNoAmAllocator, + Combine(Values(ucxx::BufferType::RMM), + Values(false), // Never use custom AM allocator for these tests + Values(true), + Values(ProgressMode::ThreadPolling, ProgressMode::ThreadBlocking), + Values(0, 1, 1024, 2048, 1048576)), + generateTestName); +#endif + +// Limited parameter set for user header tests - only test single message size +INSTANTIATE_TEST_SUITE_P(UserHeaderLimited, + RequestTestAmUserHeader, + Combine(Values(ucxx::BufferType::Host), + Values(false), + Values(false), + Values(ProgressMode::Polling, ProgressMode::Blocking), + Values(1024)), // Only test with 1024 byte messages + generateTestName); + +#if UCXX_ENABLE_RMM +INSTANTIATE_TEST_SUITE_P(UserHeaderLimitedRMM, + RequestTestAmUserHeader, + Combine(Values(ucxx::BufferType::RMM), + Values(false), + Values(false), + Values(ProgressMode::Polling, ProgressMode::Blocking), + Values(1024)), // Only test with 1024 byte messages + generateTestName); #endif } // namespace diff --git a/cpp/tests/worker.cpp b/cpp/tests/worker.cpp index 9bc6bfbbf..823dc3898 100644 --- a/cpp/tests/worker.cpp +++ b/cpp/tests/worker.cpp @@ -1,5 +1,5 @@ /** - * SPDX-FileCopyrightText: Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. + * SPDX-FileCopyrightText: Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. 
* SPDX-License-Identifier: BSD-3-Clause */ #include @@ -194,7 +194,8 @@ TEST_P(WorkerProgressTest, ProgressAmReceiverCallback) // Define AM receiver callback and register with worker std::vector> receivedRequests; auto callback = ucxx::AmReceiverCallbackType( - [this, &receivedRequests, &mutex](std::shared_ptr req, ucp_ep_h) { + [this, &receivedRequests, &mutex]( + std::shared_ptr req, ucp_ep_h, const ucxx::AmReceiverCallbackInfo&) { { std::lock_guard lock(mutex); receivedRequests.push_back(req);