diff --git a/ci/run_cpp.sh b/ci/run_cpp.sh index 0d0a4ab1..bf725e8f 100755 --- a/ci/run_cpp.sh +++ b/ci/run_cpp.sh @@ -45,7 +45,7 @@ else fi run_cpp_tests() { - CMD_LINE="python ${TIMEOUT_TOOL_PATH} $((10*60)) ${GTESTS_PATH}/UCXX_TEST" + CMD_LINE="python ${TIMEOUT_TOOL_PATH} $((20*60)) ${GTESTS_PATH}/UCXX_TEST" log_command "${CMD_LINE}" UCX_TCP_CM_REUSEADDR=y ${CMD_LINE} diff --git a/cpp/include/ucxx/experimental/worker_builder.h b/cpp/include/ucxx/experimental/worker_builder.h index b8f9115b..1a9596cd 100644 --- a/cpp/include/ucxx/experimental/worker_builder.h +++ b/cpp/include/ucxx/experimental/worker_builder.h @@ -90,6 +90,19 @@ class WorkerBuilder final { */ WorkerBuilder& pythonFuture(bool enable = true); + /** + * @brief Configure request attributes querying. + * + * When enabled, each `ucxx::Request` created from the worker will have its UCP + * attributes (such as the debug string) queried immediately after submission, making + * them available via `ucxx::Request::getRequestAttributes()`. This may have + * non-negligible runtime cost and is therefore disabled by default. + * + * @param[in] enable whether request attributes querying is enabled (default: true). + * @return Reference to this builder for method chaining. + */ + WorkerBuilder& requestAttributes(bool enable = true); + /** * @brief Configure the preferred buffer type for CUDA allocations. * diff --git a/cpp/include/ucxx/internal/request_am.h b/cpp/include/ucxx/internal/request_am.h index fe3e8ac4..37d7dd97 100644 --- a/cpp/include/ucxx/internal/request_am.h +++ b/cpp/include/ucxx/internal/request_am.h @@ -67,15 +67,6 @@ class RecvAmMessage { AmReceiverCallbackType receiverCallback = AmReceiverCallbackType(), std::vector userHeader = {}); - /** - * @brief Set the UCP request. - * - * Set the underlying UCP request (`_request` attribute) of the `RequestAm`. - * - * @param[in] request the UCP request associated to the active message receive operation. - */ - void setUcpRequest(void* request); - /** * @brief Execute the `ucxx::Request::callback()`. * diff --git a/cpp/include/ucxx/request.h b/cpp/include/ucxx/request.h index 6ad7607b..c718ebca 100644 --- a/cpp/include/ucxx/request.h +++ b/cpp/include/ucxx/request.h @@ -8,6 +8,7 @@ #include #include #include +#include #include @@ -37,6 +38,14 @@ namespace ucxx { */ class Request : public Component { protected: + /** + * @brief Request attributes reported by `ucp_request_query`. + */ + struct Attributes { + ucs_memory_type memoryType{UCS_MEMORY_TYPE_UNKNOWN}; ///< Memory type of the request + std::string debugString{}; ///< Stored debug string + }; + ucs_status_t _status{UCS_INPROGRESS}; ///< Requests status std::string _status_msg{}; ///< Human-readable status message void* _request{nullptr}; ///< Pointer to UCP request @@ -54,6 +63,9 @@ class Request : public Component { bool _enablePythonFuture{true}; ///< Whether Python future is enabled for this request RequestCallbackUserFunction _callback{nullptr}; ///< Completion callback RequestCallbackUserData _callbackData{nullptr}; ///< Completion callback data + Attributes _requestAttr{}; ///< Request attributes queried when request is posted; the + ///< default `memoryType == UCS_MEMORY_TYPE_UNKNOWN` doubles + ///< as the "not populated yet" sentinel /** * @brief Protected constructor of an abstract `ucxx::Request`. @@ -235,6 +247,48 @@ class Request : public Component { * @return The received user header (if applicable) or an empty string. */ [[nodiscard]] virtual std::string getRecvHeader(); + + /** + * @brief Get the requests's attributes. + * + * Returns the request attributes as a struct. The owning `ucxx::Worker` must have been + * created with request attributes querying enabled (see + * `ucxx::experimental::WorkerBuilder::requestAttributes()`); otherwise the attributes + * are never populated and this method throws. Querying the underlying UCP request is + * an implementation detail performed eagerly when the request is submitted. All + * non-status fields exposed by UCP are queried, use `getStatus()` to obtain the status. + * + * @throw ucxx::UnsupportedError if the owning worker was not built with request + * attributes querying enabled. Requires `Worker` + * created with + * `ucxx::experimental::WorkerBuilder::requestAttributes(true)`. + * @throw ucxx::NoElemError if attributes are unavailable for this specific + * request: either because UCX took an inline-completion + * path that produced no UCP request to query, or because + * the request has not completed yet. Callers can + * distinguish the latter from the former by checking + * `isCompleted()`. + * + * @return An `Attributes` containing the request attributes. + */ + [[nodiscard]] Attributes queryAttributes(); + + protected: + /** + * @brief Publish the UCP request handle and capture its attributes. + * + * Single critical section that stores the UCP request pointer in `_request` and, when + * the owning worker has request attributes querying enabled, immediately queries those + * attributes. The completion path frees the UCP request inside `setStatus` under the + * same `_mutex`, so this helper guarantees the query and the free are mutually + * exclusive and that there are no use-after-free in threaded progress modes. + * + * Every submit site calls this after obtaining the request handle from the corresponding + * `ucp_*_nbx` function. + * + * @param[in] request the UCP request pointer returned by a non-blocking submit. + */ + void publishRequest(void* request); }; } // namespace ucxx diff --git a/cpp/include/ucxx/worker.h b/cpp/include/ucxx/worker.h index 661e4eb4..20e2705f 100644 --- a/cpp/include/ucxx/worker.h +++ b/cpp/include/ucxx/worker.h @@ -79,6 +79,8 @@ class Worker : public Component { protected: bool _enableFuture{ false}; ///< Boolean identifying whether the worker was created with future capability + bool _enableRequestAttributes{ + false}; ///< Whether request attributes (e.g. UCP debug info) are queried for each request std::mutex _futuresPoolMutex{}; ///< Mutex to access the futures pool std::queue> _futuresPool{}; ///< Futures pool to prevent running out of fresh futures @@ -507,6 +509,19 @@ class Worker : public Component { */ [[nodiscard]] bool isFutureEnabled() const; + /** + * @brief Inquire if worker has been created with request attributes querying enabled. + * + * Check whether the worker has been created with request attributes querying enabled. + * When enabled, each `ucxx::Request` will have its UCP attributes (such as the debug + * string) queried immediately after submission, making them available via + * `ucxx::Request::queryAttributes()`. Querying request attributes has a + * non-negligible runtime cost and is therefore disabled by default. + * + * @returns `true` if request attributes querying is enabled, `false` otherwise. + */ + [[nodiscard]] bool isRequestAttributesEnabled() const noexcept; + /** * @brief Get the preferred buffer type for CUDA allocations. * @@ -1000,7 +1015,7 @@ class Worker : public Component { * * Using a Python future may be requested by specifying `enablePythonFuture`. If a * Python future is requested, the Python application must then await on this future to - * ensure the transfer has completed. Requires UCXX Python support. + * ensure the transfer has completed. * * @note If a `callbackFunction` is specified, the lifetime of `callbackData` and of any * other objects used in the scope of `callbackFunction` must be guaranteed by the caller @@ -1020,6 +1035,32 @@ class Worker : public Component { const bool enablePythonFuture = false, RequestCallbackUserFunction callbackFunction = nullptr, RequestCallbackUserData callbackData = nullptr); + + /** + * @brief Worker attributes reported by `ucp_worker_query`. + */ + struct Attributes { + /// Thread safety level the worker was created with. + ucs_thread_mode_t threadMode{UCS_THREAD_MODE_MULTI}; + /// Maximum allowed header size for `ucp_am_send_nbx`. + size_t maxAmHeader{0}; + /// Worker name used by tracing and analysis tools. + std::string name{}; + /// Maximum debug-string buffer size accepted by `ucp_request_query`. + size_t maxDebugString{0}; + }; + + /** + * @brief Get the worker's attributes. + * + * Returns the worker attributes as a struct, querying UCP via `ucp_worker_query` under + * the hood. All non-address fields exposed by UCP are queried, use `getAddress()` to + * obtain the address. + * + * @returns An `Attributes` filled with all queried fields. + * @throws ucxx::Error if an error occurred while querying worker attributes. + */ + [[nodiscard]] Attributes queryAttributes() const; }; /** diff --git a/cpp/src/experimental/worker_builder.cpp b/cpp/src/experimental/worker_builder.cpp index a8c780aa..1efd0994 100644 --- a/cpp/src/experimental/worker_builder.cpp +++ b/cpp/src/experimental/worker_builder.cpp @@ -17,6 +17,7 @@ struct WorkerBuilder::Impl { std::shared_ptr context; bool enableDelayedSubmission{false}; bool enableFuture{false}; + bool enableRequestAttributes{false}; BufferType cudaBufferType{BufferType::Invalid}; explicit Impl(std::shared_ptr ctx) : context(std::move(ctx)) {} @@ -41,6 +42,12 @@ WorkerBuilder& WorkerBuilder::pythonFuture(bool enable) return *this; } +WorkerBuilder& WorkerBuilder::requestAttributes(bool enable) +{ + _impl->enableRequestAttributes = enable; + return *this; +} + WorkerBuilder& WorkerBuilder::cudaBufferType(BufferType bufferType) { _impl->cudaBufferType = bufferType; @@ -51,6 +58,7 @@ std::shared_ptr WorkerBuilder::build() const { auto worker = ucxx::createWorker(_impl->context, _impl->enableDelayedSubmission, _impl->enableFuture); + worker->_enableRequestAttributes = _impl->enableRequestAttributes; if (_impl->cudaBufferType != BufferType::Invalid) worker->setCudaBufferType(_impl->cudaBufferType); return worker; diff --git a/cpp/src/internal/request_am.cpp b/cpp/src/internal/request_am.cpp index e2eb65c2..6fbe7d92 100644 --- a/cpp/src/internal/request_am.cpp +++ b/cpp/src/internal/request_am.cpp @@ -40,8 +40,6 @@ RecvAmMessage::RecvAmMessage(internal::AmData* amData, } } -void RecvAmMessage::setUcpRequest(void* request) { _request->_request = request; } - void RecvAmMessage::callback(void* request, ucs_status_t status) { std::visit(data::dispatch{ diff --git a/cpp/src/request.cpp b/cpp/src/request.cpp index 1551e4f2..05eb48f5 100644 --- a/cpp/src/request.cpp +++ b/cpp/src/request.cpp @@ -3,6 +3,7 @@ * SPDX-License-Identifier: BSD-3-Clause */ #include +#include #include #include #include @@ -140,7 +141,7 @@ void Request::callback(void* request, ucs_status_t status) if (_status != UCS_INPROGRESS) ucxx_trace_req_f(_ownerString.c_str(), this, - _request, + request, _operationName.c_str(), "has status already set to %d (%s), callback setting %d (%s)", _status, @@ -148,12 +149,10 @@ void Request::callback(void* request, ucs_status_t status) status, ucs_status_string(status)); - if (UCS_PTR_IS_PTR(_request)) ucp_request_free(request); - - ucxx_trace_req_f(_ownerString.c_str(), this, _request, _operationName.c_str(), "completed"); + ucxx_trace_req_f(_ownerString.c_str(), this, request, _operationName.c_str(), "completed"); setStatus(status); ucxx_trace_req_f( - _ownerString.c_str(), this, _request, _operationName.c_str(), "isCompleted: %d", isCompleted()); + _ownerString.c_str(), this, request, _operationName.c_str(), "isCompleted: %d", isCompleted()); } void Request::process() @@ -235,11 +234,69 @@ void Request::setStatus(ucs_status_t status) _ownerString.c_str(), this, _request, _operationName.c_str(), "invoking user callback"); _callback(status, _callbackData); } + + // Free the UCP request inside the lock so it is mutually exclusive with + // `publishRequest()`/`queryRequestAttributes()` on the submit thread. + if (UCS_PTR_IS_PTR(_request)) { + ucp_request_free(_request); + _request = nullptr; + } } } const std::string& Request::getOwnerString() const { return _ownerString; } +void Request::publishRequest(void* request) +{ + if (!_worker->isRequestAttributesEnabled()) { + std::lock_guard lock(_mutex); + _request = request; + return; + } + + std::lock_guard lock(_mutex); + _request = request; + + if (_requestAttr.memoryType != UCS_MEMORY_TYPE_UNKNOWN) return; + + ucp_request_attr_t result; + + auto worker_attr = _worker->queryAttributes(); + + std::string debugString(worker_attr.maxDebugString, '\0'); + + result.field_mask = UCP_REQUEST_ATTR_FIELD_MEM_TYPE | UCP_REQUEST_ATTR_FIELD_INFO_STRING | + UCP_REQUEST_ATTR_FIELD_INFO_STRING_SIZE; + + result.debug_string = debugString.data(); + result.debug_string_size = debugString.size(); + + if (UCS_PTR_IS_PTR(_request)) { + auto queryStatus = ucp_request_query(_request, &result); + if (queryStatus == UCS_OK && result.debug_string != nullptr) { + debugString.resize(std::strlen(debugString.c_str())); + _requestAttr.debugString = std::move(debugString); + _requestAttr.memoryType = result.mem_type; + } + } +} + +Request::Attributes Request::queryAttributes() +{ + if (!_worker->isRequestAttributesEnabled()) + throw ucxx::UnsupportedError( + "Request attributes querying is disabled on the owning worker; build the worker " + "with `ucxx::experimental::WorkerBuilder::requestAttributes(true)` to enable it"); + + std::lock_guard lock(_mutex); + + if (_requestAttr.memoryType != UCS_MEMORY_TYPE_UNKNOWN) return _requestAttr; + + throw ucxx::NoElemError( + "Request attributes are not available for this request: UCX took an inline-completion " + "path with no queryable UCP request, or the request has not completed yet"); +} + std::shared_ptr Request::getRecvBuffer() { return nullptr; } std::string Request::getRecvHeader() { return {}; } diff --git a/cpp/src/request_am.cpp b/cpp/src/request_am.cpp index d77ac2bb..4ccdae33 100644 --- a/cpp/src/request_am.cpp +++ b/cpp/src/request_am.cpp @@ -368,7 +368,7 @@ ucs_status_t RequestAm::recvCallback(void* arg, return s; } else { // The request will be handled by the callback - recvAmMessage->setUcpRequest(status); + req->publishRequest(status); amData->_registerInflightRequest(req); { @@ -470,8 +470,7 @@ void RequestAm::request() amSend._count, ¶m); - std::lock_guard lock(_mutex); - _request = request; + publishRequest(request); }, [](auto) { throw ucxx::UnsupportedError("Only send active messages can call request()"); }, }, diff --git a/cpp/src/request_endpoint_close.cpp b/cpp/src/request_endpoint_close.cpp index 3e117582..05f83845 100644 --- a/cpp/src/request_endpoint_close.cpp +++ b/cpp/src/request_endpoint_close.cpp @@ -1,5 +1,5 @@ /** - * SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES. + * SPDX-FileCopyrightText: Copyright (c) 2024-2026, NVIDIA CORPORATION & AFFILIATES. * SPDX-License-Identifier: BSD-3-Clause */ #include "ucxx/request_data.h" @@ -78,8 +78,7 @@ void RequestEndpointClose::request() else throw ucxx::Error("A valid endpoint or worker is required for a close operation."); - std::lock_guard lock(_mutex); - _request = request; + publishRequest(request); } void RequestEndpointClose::populateDelayedSubmission() diff --git a/cpp/src/request_flush.cpp b/cpp/src/request_flush.cpp index 1dcb3a93..bb20491b 100644 --- a/cpp/src/request_flush.cpp +++ b/cpp/src/request_flush.cpp @@ -1,5 +1,5 @@ /** - * SPDX-FileCopyrightText: Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. + * SPDX-FileCopyrightText: Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. * SPDX-License-Identifier: BSD-3-Clause */ #include @@ -77,8 +77,7 @@ void RequestFlush::request() else throw ucxx::Error("A valid endpoint or worker is required for a flush operation."); - std::lock_guard lock(_mutex); - _request = request; + publishRequest(request); } void RequestFlush::populateDelayedSubmission() diff --git a/cpp/src/request_mem.cpp b/cpp/src/request_mem.cpp index bff6caee..16f157e5 100644 --- a/cpp/src/request_mem.cpp +++ b/cpp/src/request_mem.cpp @@ -1,5 +1,5 @@ /** - * SPDX-FileCopyrightText: Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. + * SPDX-FileCopyrightText: Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. * SPDX-License-Identifier: BSD-3-Clause */ #include @@ -127,8 +127,7 @@ void RequestMem::request() }, _requestData); - std::lock_guard lock(_mutex); - _request = request; + publishRequest(request); } void RequestMem::populateDelayedSubmission() diff --git a/cpp/src/request_stream.cpp b/cpp/src/request_stream.cpp index 4327cc40..30b42f60 100644 --- a/cpp/src/request_stream.cpp +++ b/cpp/src/request_stream.cpp @@ -1,5 +1,5 @@ /** - * SPDX-FileCopyrightText: Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. + * SPDX-FileCopyrightText: Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. * SPDX-License-Identifier: BSD-3-Clause */ #include @@ -91,8 +91,7 @@ void RequestStream::request() }, _requestData); - std::lock_guard lock(_mutex); - _request = request; + publishRequest(request); } void RequestStream::populateDelayedSubmission() diff --git a/cpp/src/request_tag.cpp b/cpp/src/request_tag.cpp index a4424a4f..edc0f862 100644 --- a/cpp/src/request_tag.cpp +++ b/cpp/src/request_tag.cpp @@ -1,5 +1,5 @@ /** - * SPDX-FileCopyrightText: Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. + * SPDX-FileCopyrightText: Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. * SPDX-License-Identifier: BSD-3-Clause */ #include @@ -161,8 +161,7 @@ void RequestTag::request() }, _requestData); - std::lock_guard lock(_mutex); - _request = request; + publishRequest(request); } void RequestTag::populateDelayedSubmission() diff --git a/cpp/src/worker.cpp b/cpp/src/worker.cpp index 9e6b556a..b9b9063d 100644 --- a/cpp/src/worker.cpp +++ b/cpp/src/worker.cpp @@ -3,6 +3,7 @@ * SPDX-License-Identifier: BSD-3-Clause */ #include +#include #include #include #include @@ -200,6 +201,22 @@ std::string Worker::getInfo() return utils::decodeTextFileDescriptor(TextFileDescriptor); } +Worker::Attributes Worker::queryAttributes() const +{ + ucp_worker_attr_t attr = { + .field_mask = UCP_WORKER_ATTR_FIELD_THREAD_MODE | UCP_WORKER_ATTR_FIELD_MAX_AM_HEADER | + UCP_WORKER_ATTR_FIELD_NAME | UCP_WORKER_ATTR_FIELD_MAX_INFO_STRING}; + + utils::ucsErrorThrow(ucp_worker_query(_handle, &attr)); + + return Attributes{ + .threadMode = attr.thread_mode, + .maxAmHeader = attr.max_am_header, + .name = std::string(attr.name, ::strnlen(attr.name, sizeof(attr.name))), + .maxDebugString = attr.max_debug_string, + }; +} + bool Worker::isDelayedRequestSubmissionEnabled() const { return _delayedSubmissionCollection->isDelayedRequestSubmissionEnabled(); @@ -207,6 +224,8 @@ bool Worker::isDelayedRequestSubmissionEnabled() const bool Worker::isFutureEnabled() const { return _enableFuture; } +bool Worker::isRequestAttributesEnabled() const noexcept { return _enableRequestAttributes; } + BufferType Worker::getCudaBufferType() const { return _cudaBufferType; } void Worker::setCudaBufferType(BufferType bufferType) diff --git a/cpp/tests/request.cpp b/cpp/tests/request.cpp index 6997e6c4..e28584df 100644 --- a/cpp/tests/request.cpp +++ b/cpp/tests/request.cpp @@ -62,6 +62,40 @@ class RequestTest : public ::testing::TestWithParam< std::vector _sendPtr{nullptr}; std::vector _recvPtr{nullptr}; + void buildWorker(bool enableRequestAttributes) + { + auto builder = ucxx::experimental::createWorker(_context) + .delayedSubmission(_enableDelayedSubmission) + .requestAttributes(enableRequestAttributes); + + if (_bufferType == ucxx::BufferType::RMM || _bufferType == ucxx::BufferType::CCCL) + builder.cudaBufferType(_bufferType); + + _worker = builder.build(); + + if (_progressMode == ProgressMode::Blocking) { + _worker->initBlockingProgressMode(); + } else if (_progressMode == ProgressMode::ThreadPolling || + _progressMode == ProgressMode::ThreadBlocking) { + _worker->setProgressThreadStartCallback(::createCudaContextCallback, nullptr); + + if (_progressMode == ProgressMode::ThreadPolling) _worker->startProgressThread(true); + if (_progressMode == ProgressMode::ThreadBlocking) _worker->startProgressThread(false); + } + + _progressWorker = getProgressFunction(_worker, _progressMode); + + _ep = _worker->createEndpointFromWorkerAddress(_worker->getAddress()); + } + + void rebuildWorker(bool enableRequestAttributes) + { + if (_worker && _worker->isProgressThreadRunning()) _worker->stopProgressThread(); + _ep.reset(); + _worker.reset(); + buildWorker(enableRequestAttributes); + } + void SetUp() { std::tie(_bufferType, @@ -88,25 +122,7 @@ class RequestTest : public ::testing::TestWithParam< _context = ucxx::createContext({{"RNDV_THRESH", std::to_string(_rndvThresh)}}, ucxx::Context::defaultFeatureFlags); - auto builder = - ucxx::experimental::createWorker(_context).delayedSubmission(_enableDelayedSubmission); - if (_bufferType == ucxx::BufferType::RMM || _bufferType == ucxx::BufferType::CCCL) - builder.cudaBufferType(_bufferType); - _worker = builder.build(); - - if (_progressMode == ProgressMode::Blocking) { - _worker->initBlockingProgressMode(); - } else if (_progressMode == ProgressMode::ThreadPolling || - _progressMode == ProgressMode::ThreadBlocking) { - _worker->setProgressThreadStartCallback(::createCudaContextCallback, nullptr); - - if (_progressMode == ProgressMode::ThreadPolling) _worker->startProgressThread(true); - if (_progressMode == ProgressMode::ThreadBlocking) _worker->startProgressThread(false); - } - - _progressWorker = getProgressFunction(_worker, _progressMode); - - _ep = _worker->createEndpointFromWorkerAddress(_worker->getAddress()); + buildWorker(false); } void TearDown() @@ -222,8 +238,8 @@ TEST_P(RequestTest, ProgressAm) auto recvReq = requests[1]; _recvPtr[0] = recvReq->getRecvBuffer()->data(); - // Messages larger than `_rndvThresh` are rendezvous and will use custom allocator, - // smaller messages are eager and will always be host-allocated. + // Messages of size `_rndvThresh` or larger are rendezvous and will use the custom + // allocator, smaller messages are eager and will always be host-allocated. ASSERT_THAT(recvReq->getRecvBuffer()->getType(), (_registerCustomAmAllocator && _messageSize >= _rndvThresh) ? _bufferType : ucxx::BufferType::Host); @@ -533,6 +549,261 @@ TEST_P(RequestTest, ProgressTag) ASSERT_THAT(_recv[0], ContainerEq(_send[0])); } +TEST_P(RequestTest, ProgressTagRequestAttributes) +{ + if (_messageSize < _rndvThresh) + GTEST_SKIP() << "Eager messages do not create a ucp_request and thus no debug info"; + + rebuildWorker(true); + + allocate(); + + std::vector> requests; + requests.push_back(_ep->tagSend(_sendPtr[0], _messageSize, ucxx::Tag{0})); + requests.push_back(_ep->tagRecv(_recvPtr[0], _messageSize, ucxx::Tag{0}, ucxx::TagMaskFull)); + waitRequests(_worker, requests, _progressWorker); + + for (const auto& request : requests) { + auto debugString = request->queryAttributes().debugString; + ASSERT_THAT(debugString, ::testing::HasSubstr("length " + std::to_string(_messageSize))); + ASSERT_THAT(debugString, + ::testing::HasSubstr(_memoryType == UCS_MEMORY_TYPE_HOST ? "host" : "cuda")); + } + + copyResults(); + ASSERT_THAT(_recv[0], ContainerEq(_send[0])); +} + +class RequestAttributesDisabledTest : public ::testing::Test { + protected: + static constexpr size_t kMessageLength = 1024; + static constexpr size_t kMessageSize = kMessageLength * sizeof(int); + + std::shared_ptr _context; + std::shared_ptr _worker; + std::shared_ptr _ep; + std::function _progressWorker; + std::vector _sendBuf; + std::vector _recvBuf; + + void SetUp() override + { + _context = ucxx::createContext({}, ucxx::Context::defaultFeatureFlags); + _worker = ucxx::experimental::createWorker(_context).build(); + ASSERT_FALSE(_worker->isRequestAttributesEnabled()); + + _ep = _worker->createEndpointFromWorkerAddress(_worker->getAddress()); + _progressWorker = getProgressFunction(_worker, ProgressMode::Polling); + + _sendBuf.resize(kMessageLength); + _recvBuf.resize(kMessageLength); + std::iota(_sendBuf.begin(), _sendBuf.end(), 0); + } + + void expectAllThrow(const std::vector>& requests) const + { + for (const auto& request : requests) { + EXPECT_THROW(std::ignore = request->queryAttributes(), ucxx::UnsupportedError); + } + } +}; + +TEST_F(RequestAttributesDisabledTest, Tag) +{ + std::vector> requests; + requests.push_back(_ep->tagSend(_sendBuf.data(), kMessageSize, ucxx::Tag{0})); + requests.push_back(_ep->tagRecv(_recvBuf.data(), kMessageSize, ucxx::Tag{0}, ucxx::TagMaskFull)); + waitRequests(_worker, requests, _progressWorker); + + expectAllThrow(requests); + ASSERT_THAT(_recvBuf, ::testing::ContainerEq(_sendBuf)); +} + +TEST_F(RequestAttributesDisabledTest, Stream) +{ + std::vector> requests; + requests.push_back(_ep->streamSend(_sendBuf.data(), kMessageSize, 0)); + requests.push_back(_ep->streamRecv(_recvBuf.data(), kMessageSize, 0)); + waitRequests(_worker, requests, _progressWorker); + + expectAllThrow(requests); + ASSERT_THAT(_recvBuf, ::testing::ContainerEq(_sendBuf)); +} + +TEST_F(RequestAttributesDisabledTest, Am) +{ + std::vector> requests; + requests.push_back(_ep->amSend(_sendBuf.data(), kMessageSize, UCS_MEMORY_TYPE_HOST)); + requests.push_back(_ep->amRecv()); + waitRequests(_worker, requests, _progressWorker); + + expectAllThrow(requests); + + auto recvBuffer = requests[1]->getRecvBuffer(); + ASSERT_EQ(recvBuffer->getSize(), kMessageSize); + std::vector received(reinterpret_cast(recvBuffer->data()), + reinterpret_cast(recvBuffer->data()) + kMessageLength); + ASSERT_THAT(received, ::testing::ContainerEq(_sendBuf)); +} + +TEST_F(RequestAttributesDisabledTest, MemoryGet) +{ + auto memoryHandle = _context->createMemoryHandle(kMessageSize, nullptr, UCS_MEMORY_TYPE_HOST); + std::memcpy( + reinterpret_cast(memoryHandle->getBaseAddress()), _sendBuf.data(), kMessageSize); + + auto serializedRemoteKey = memoryHandle->createRemoteKey()->serialize(); + auto remoteKey = ucxx::createRemoteKeyFromSerialized(_ep, serializedRemoteKey); + + auto request = _ep->memGet(_recvBuf.data(), kMessageSize, remoteKey); + std::vector> requests{request, _ep->flush()}; + waitRequests(_worker, requests, _progressWorker); + + expectAllThrow({request}); + ASSERT_THAT(_recvBuf, ::testing::ContainerEq(_sendBuf)); +} + +TEST_F(RequestAttributesDisabledTest, MemoryPut) +{ + auto memoryHandle = _context->createMemoryHandle(kMessageSize, nullptr, UCS_MEMORY_TYPE_HOST); + + auto serializedRemoteKey = memoryHandle->createRemoteKey()->serialize(); + auto remoteKey = ucxx::createRemoteKeyFromSerialized(_ep, serializedRemoteKey); + + auto request = _ep->memPut(_sendBuf.data(), kMessageSize, remoteKey); + std::vector> requests{request, _ep->flush()}; + waitRequests(_worker, requests, _progressWorker); + + expectAllThrow({request}); + + std::memcpy( + _recvBuf.data(), reinterpret_cast(memoryHandle->getBaseAddress()), kMessageSize); + ASSERT_THAT(_recvBuf, ::testing::ContainerEq(_sendBuf)); +} + +TEST_P(RequestTest, ProgressStreamRequestAttributes) +{ + if (_messageSize == 0) GTEST_SKIP() << "Stream rejects zero-length transfers"; + + rebuildWorker(true); + + allocate(); + + auto sendRequest = _ep->streamSend(_sendPtr[0], _messageSize, 0); + auto recvRequest = _ep->streamRecv(_recvPtr[0], _messageSize, 0); + std::vector> requests{sendRequest, recvRequest}; + waitRequests(_worker, requests, _progressWorker); + + try { + auto sendDebug = sendRequest->queryAttributes().debugString; + EXPECT_FALSE(sendDebug.empty()); + EXPECT_THAT(sendDebug, ::testing::HasSubstr("length " + std::to_string(_messageSize))); + } catch (const ucxx::NoElemError&) { + // Send completed inline; no UCP request handle to query. + } + + try { + auto recvDebug = recvRequest->queryAttributes().debugString; + EXPECT_THAT(recvDebug, ::testing::HasSubstr("no debug info")); + } catch (const ucxx::NoElemError&) { + // Recv completed inline; no UCP request handle to query. + } + + copyResults(); + ASSERT_THAT(_recv[0], ContainerEq(_send[0])); +} + +TEST_P(RequestTest, ProgressAmRequestAttributes) +{ + if (_messageSize < _rndvThresh) + GTEST_SKIP() << "Eager messages complete inline without a UCP request to query"; + if (_progressMode == ProgressMode::Wait) + GTEST_SKIP() << "Interrupting UCP worker progress operation in wait mode is not possible"; + + rebuildWorker(true); + + allocate(1, false); + + std::vector> requests; + requests.push_back(_ep->amSend(_sendPtr[0], _messageSize, _memoryType)); + requests.push_back(_ep->amRecv()); + waitRequests(_worker, requests, _progressWorker); + + for (const auto& request : requests) { + auto debugString = request->queryAttributes().debugString; + ASSERT_THAT(debugString, ::testing::HasSubstr("length " + std::to_string(_messageSize))); + } + + auto recvReq = requests[1]; + _recvPtr[0] = recvReq->getRecvBuffer()->data(); + copyResults(); + ASSERT_THAT(_recv[0], ContainerEq(_send[0])); +} + +TEST_P(RequestTest, MemoryGetRequestAttributes) +{ + if (_messageSize == 0) GTEST_SKIP() << "Zero-length memGet completes without a UCP request"; + + rebuildWorker(true); + + allocate(); + + auto memoryHandle = _context->createMemoryHandle(_messageSize, nullptr, _memoryType); + copyMemoryTypeAware( + reinterpret_cast(memoryHandle->getBaseAddress()), _sendPtr[0], _messageSize); + + auto localRemoteKey = memoryHandle->createRemoteKey(); + auto serializedRemoteKey = localRemoteKey->serialize(); + auto remoteKey = ucxx::createRemoteKeyFromSerialized(_ep, serializedRemoteKey); + + auto request = _ep->memGet(_recvPtr[0], _messageSize, remoteKey); + std::vector> requests; + requests.push_back(request); + requests.push_back(_ep->flush()); + waitRequests(_worker, requests, _progressWorker); + + auto debugString = request->queryAttributes().debugString; + ASSERT_THAT(debugString, ::testing::HasSubstr("length " + std::to_string(_messageSize))); + + copyResults(); + ASSERT_THAT(_recv[0], ContainerEq(_send[0])); +} + +TEST_P(RequestTest, MemoryPutRequestAttributes) +{ + if (_messageSize == 0) GTEST_SKIP() << "Zero-length memPut completes without a UCP request"; + + rebuildWorker(true); + + allocate(); + + auto memoryHandle = _context->createMemoryHandle(_messageSize, nullptr, _memoryType); + + auto localRemoteKey = memoryHandle->createRemoteKey(); + auto serializedRemoteKey = localRemoteKey->serialize(); + auto remoteKey = ucxx::createRemoteKeyFromSerialized(_ep, serializedRemoteKey); + + auto request = _ep->memPut(_sendPtr[0], _messageSize, remoteKey); + std::vector> requests; + requests.push_back(request); + requests.push_back(_ep->flush()); + waitRequests(_worker, requests, _progressWorker); + + try { + auto debugString = request->queryAttributes().debugString; + EXPECT_FALSE(debugString.empty()); + EXPECT_THAT(debugString, ::testing::HasSubstr("length " + std::to_string(_messageSize))); + } catch (const ucxx::NoElemError&) { + // Request completed inline; no UCP request handle to query. + } + + copyMemoryTypeAware( + _recvPtr[0], reinterpret_cast(memoryHandle->getBaseAddress()), _messageSize); + + copyResults(); + ASSERT_THAT(_recv[0], ContainerEq(_send[0])); +} + TEST_P(RequestTest, ProgressTagMulti) { if (_progressMode == ProgressMode::Wait) { diff --git a/cpp/tests/worker.cpp b/cpp/tests/worker.cpp index b7350a93..1749c8c7 100644 --- a/cpp/tests/worker.cpp +++ b/cpp/tests/worker.cpp @@ -108,6 +108,21 @@ class WorkerGenericCallbackSingleTest : public WorkerProgressTest {}; TEST_F(WorkerTest, HandleIsValid) { ASSERT_TRUE(_worker->getHandle() != nullptr); } +TEST_F(WorkerTest, QueryAttributes) +{ + auto attrs = _worker->queryAttributes(); + + // The worker was created with UCS_THREAD_MODE_MULTI in the constructor. + EXPECT_EQ(attrs.threadMode, UCS_THREAD_MODE_MULTI); + + // The remaining fields are determined by UCX configuration, so the strongest + // portable assertion is that they were populated with non-zero / non-empty + // values. + EXPECT_GT(attrs.maxAmHeader, 0u); + EXPECT_FALSE(attrs.name.empty()); + EXPECT_GT(attrs.maxDebugString, 0u); +} + TEST_P(WorkerCapabilityTest, CheckCapability) { ASSERT_EQ(_worker->isDelayedRequestSubmissionEnabled(), _enableDelayedSubmission); @@ -876,6 +891,35 @@ TEST(WorkerBuilderTest, BuilderBackwardCompatibility) ASSERT_TRUE(worker2->isFutureEnabled()); } +TEST(WorkerBuilderTest, RequestAttributesDefaultDisabled) +{ + auto context = ucxx::experimental::createContext(ucxx::Context::defaultFeatureFlags).build(); + auto worker = ucxx::experimental::createWorker(context).build(); + + ASSERT_TRUE(worker != nullptr); + ASSERT_FALSE(worker->isRequestAttributesEnabled()); +} + +TEST(WorkerBuilderTest, RequestAttributesEnabled) +{ + auto context = ucxx::experimental::createContext(ucxx::Context::defaultFeatureFlags).build(); + auto worker = ucxx::experimental::createWorker(context).requestAttributes(true).build(); + + ASSERT_TRUE(worker != nullptr); + ASSERT_TRUE(worker->isRequestAttributesEnabled()); + ASSERT_FALSE(worker->isDelayedRequestSubmissionEnabled()); + ASSERT_FALSE(worker->isFutureEnabled()); +} + +TEST(WorkerBuilderTest, RequestAttributesExplicitDisable) +{ + auto context = ucxx::experimental::createContext(ucxx::Context::defaultFeatureFlags).build(); + auto worker = ucxx::experimental::createWorker(context).requestAttributes(false).build(); + + ASSERT_TRUE(worker != nullptr); + ASSERT_FALSE(worker->isRequestAttributesEnabled()); +} + TEST(AmReceiverCallbackOwnerTypeTest, DefaultConstructsEmpty) { ucxx::AmReceiverCallbackOwnerType owner;