From 296f94ed6190d6f1a84d9b40b5e4c7aa3fd0c22e Mon Sep 17 00:00:00 2001 From: Peter Andreas Entschev Date: Sun, 5 May 2024 04:31:49 -0700 Subject: [PATCH 01/39] Use `weak_ptr` in `ErrorCallbackData` A race could cause the `Endpoint::errorCallback` to dereference an already closed and destroyed `ucxx::Endpoint`, leading to segmentation faults and sometimes deadlocks. With the use of a `weak_ptr` that can be prevented by returning before dereferencing the pointer which should be safe since the owner has already been destroyed anyway. --- cpp/include/ucxx/endpoint.h | 7 ++--- cpp/src/endpoint.cpp | 61 ++++++++++++++++++++++++++++--------- 2 files changed, 50 insertions(+), 18 deletions(-) diff --git a/cpp/include/ucxx/endpoint.h b/cpp/include/ucxx/endpoint.h index 54332bba8..433dac8ab 100644 --- a/cpp/include/ucxx/endpoint.h +++ b/cpp/include/ucxx/endpoint.h @@ -48,8 +48,8 @@ struct EpParamsDeleter { * callback to modify the `ucxx::Endpoint` with information relevant to the error occurred. */ struct ErrorCallbackData { - Endpoint* endpoint{ - nullptr}; ///< Pointer to the `ucxx::Endpoint` that owns this object, used only for logging. + std::weak_ptr + endpoint{}; ///< Pointer to the `ucxx::Endpoint` that owns this object, used only for logging. std::mutex mutex{std::mutex()}; ///< Mutex used to prevent race conditions with ///< `ucxx::Endpoint::setCloseCallback()`. 
ucs_status_t status{UCS_INPROGRESS}; ///< Endpoint status @@ -60,8 +60,7 @@ struct ErrorCallbackData { nullptr}; ///< Argument to be passed to close callback std::shared_ptr worker{nullptr}; ///< Worker the endpoint has been created from - ErrorCallbackData(Endpoint* endpoint, - std::shared_ptr inflightRequests, + ErrorCallbackData(std::shared_ptr inflightRequests, std::shared_ptr worker); ErrorCallbackData() = delete; diff --git a/cpp/src/endpoint.cpp b/cpp/src/endpoint.cpp index 01237ba28..5c274407e 100644 --- a/cpp/src/endpoint.cpp +++ b/cpp/src/endpoint.cpp @@ -34,10 +34,9 @@ namespace ucxx { -ErrorCallbackData::ErrorCallbackData(Endpoint* endpoint, - std::shared_ptr inflightRequests, +ErrorCallbackData::ErrorCallbackData(std::shared_ptr inflightRequests, std::shared_ptr worker) - : endpoint(endpoint), inflightRequests(inflightRequests), worker(worker) + : inflightRequests(inflightRequests), worker(worker) { } @@ -67,12 +66,17 @@ Endpoint::Endpoint(std::shared_ptr workerOrListener, setParent(workerOrListener); - _callbackData = std::make_unique(this, _inflightRequests, worker); + _callbackData = std::make_unique(_inflightRequests, worker); - params->err_mode = - (endpointErrorHandling ? 
UCP_ERR_HANDLING_MODE_PEER : UCP_ERR_HANDLING_MODE_NONE); - params->err_handler.cb = Endpoint::errorCallback; - params->err_handler.arg = _callbackData.get(); + if (endpointErrorHandling) { + params->err_mode = UCP_ERR_HANDLING_MODE_PEER; + params->err_handler.cb = Endpoint::errorCallback; + params->err_handler.arg = _callbackData.get(); + } else { + params->err_mode = UCP_ERR_HANDLING_MODE_NONE; + params->err_handler.cb = nullptr; + params->err_handler.arg = nullptr; + } if (worker->isProgressThreadRunning()) { ucs_status_t status = UCS_INPROGRESS; @@ -120,7 +124,9 @@ std::shared_ptr createEndpointFromHostname(std::shared_ptr wor params.sockaddr.addrlen = info->ai_addrlen; params.sockaddr.addr = info->ai_addr; - return std::shared_ptr(new Endpoint(worker, &params, endpointErrorHandling)); + auto ep = std::shared_ptr(new Endpoint(worker, &params, endpointErrorHandling)); + ep->_callbackData->endpoint = ep; + return ep; } std::shared_ptr createEndpointFromConnRequest(std::shared_ptr listener, @@ -136,7 +142,9 @@ std::shared_ptr createEndpointFromConnRequest(std::shared_ptr(new Endpoint(listener, &params, endpointErrorHandling)); + auto ep = std::shared_ptr(new Endpoint(listener, &params, endpointErrorHandling)); + ep->_callbackData->endpoint = ep; + return ep; } std::shared_ptr createEndpointFromWorkerAddress(std::shared_ptr worker, @@ -153,7 +161,9 @@ std::shared_ptr createEndpointFromWorkerAddress(std::shared_ptrgetHandle()}; - return std::shared_ptr(new Endpoint(worker, &params, endpointErrorHandling)); + auto ep = std::shared_ptr(new Endpoint(worker, &params, endpointErrorHandling)); + ep->_callbackData->endpoint = ep; + return ep; } Endpoint::~Endpoint() @@ -566,7 +576,28 @@ std::shared_ptr Endpoint::getWorker() { return ::ucxx::getWorker(_parent void Endpoint::errorCallback(void* arg, ucp_ep_h ep, ucs_status_t status) { ErrorCallbackData* data = reinterpret_cast(arg); + + // Unable to cast to `ErrorCallbackData*`: invalid `arg`. 
+ if (data == nullptr) { + ucxx_error("ucxx::Endpoint::%s, UCP handle: %p, error callback called with status %d: %s", + __func__, + ep, + status, + ucs_status_string(status)); + return; + } + + std::shared_ptr endpoint{nullptr}; + try { + endpoint = data->endpoint.lock(); + } catch (std::bad_weak_ptr& exception) { + // Unable to acquire `std::shared_ptr`: owner was already destroyed. + return; + } + + // Endpoint is already closing. if (data->closing.exchange(true)) return; + data->status = status; data->worker->scheduleRequestCancel(data->inflightRequests->release()); { @@ -574,7 +605,7 @@ void Endpoint::errorCallback(void* arg, ucp_ep_h ep, ucs_status_t status) if (data->closeCallback) { ucxx_debug("ucxx::Endpoint::%s: %p, UCP handle: %p, calling user close callback", __func__, - data->endpoint, + endpoint.get(), ep); data->closeCallback(status, data->closeCallbackArg); data->closeCallback = nullptr; @@ -585,14 +616,16 @@ void Endpoint::errorCallback(void* arg, ucp_ep_h ep, ucs_status_t status) // Connection reset and timeout often represent just a normal remote // endpoint disconnect, log only in debug mode. 
if (status == UCS_ERR_CONNECTION_RESET || status == UCS_ERR_ENDPOINT_TIMEOUT) - ucxx_debug("ucxx::Endpoint::%s, UCP handle: %p, error callback called with status %d: %s", + ucxx_debug("ucxx::Endpoint::%s: %p, UCP handle: %p, error callback called with status %d: %s", __func__, + endpoint.get(), ep, status, ucs_status_string(status)); else - ucxx_error("ucxx::Endpoint::%s, UCP handle: %p, error callback called with status %d: %s", + ucxx_error("ucxx::Endpoint::%s: %p, UCP handle: %p, error callback called with status %d: %s", __func__, + endpoint.get(), ep, status, ucs_status_string(status)); From f4b7dc4b9be68d455916034c91087339d22d11f6 Mon Sep 17 00:00:00 2001 From: Peter Andreas Entschev Date: Tue, 7 May 2024 03:01:46 -0700 Subject: [PATCH 02/39] EP cancel/non-blocking close --- cpp/src/endpoint.cpp | 10 +++ cpp/tests/endpoint.cpp | 148 +++++++++++++++++++++++++++++++++++++- cpp/tests/include/utils.h | 9 +++ 3 files changed, 166 insertions(+), 1 deletion(-) diff --git a/cpp/src/endpoint.cpp b/cpp/src/endpoint.cpp index 5c274407e..48140e8a1 100644 --- a/cpp/src/endpoint.cpp +++ b/cpp/src/endpoint.cpp @@ -407,6 +407,8 @@ std::shared_ptr Endpoint::amSend( RequestCallbackUserFunction callbackFunction, RequestCallbackUserData callbackData) { + if (_callbackData->closing.load()) return nullptr; + auto endpoint = std::dynamic_pointer_cast(shared_from_this()); return registerInflightRequest( createRequestAm(endpoint, @@ -497,6 +499,8 @@ std::shared_ptr Endpoint::streamSend(void* buffer, size_t length, const bool enablePythonFuture) { + if (_callbackData->closing.load()) return nullptr; + auto endpoint = std::dynamic_pointer_cast(shared_from_this()); return registerInflightRequest( createRequestStream(endpoint, data::StreamSend(buffer, length), enablePythonFuture)); @@ -506,6 +510,8 @@ std::shared_ptr Endpoint::streamRecv(void* buffer, size_t length, const bool enablePythonFuture) { + if (_callbackData->closing.load()) return nullptr; + auto endpoint = 
std::dynamic_pointer_cast(shared_from_this()); return registerInflightRequest( createRequestStream(endpoint, data::StreamReceive(buffer, length), enablePythonFuture)); @@ -518,6 +524,8 @@ std::shared_ptr Endpoint::tagSend(void* buffer, RequestCallbackUserFunction callbackFunction, RequestCallbackUserData callbackData) { + if (_callbackData->closing.load()) return nullptr; + auto endpoint = std::dynamic_pointer_cast(shared_from_this()); return registerInflightRequest(createRequestTag(endpoint, data::TagSend(buffer, length, tag), @@ -548,6 +556,8 @@ std::shared_ptr Endpoint::tagMultiSend(const std::vector& buffer const Tag tag, const bool enablePythonFuture) { + if (_callbackData->closing.load()) return nullptr; + auto endpoint = std::dynamic_pointer_cast(shared_from_this()); return registerInflightRequest(createRequestTagMulti( endpoint, data::TagMultiSend(buffer, size, isCUDA, tag), enablePythonFuture)); diff --git a/cpp/tests/endpoint.cpp b/cpp/tests/endpoint.cpp index 16fa43708..76deacc88 100644 --- a/cpp/tests/endpoint.cpp +++ b/cpp/tests/endpoint.cpp @@ -5,12 +5,20 @@ #include #include +#include #include #include +#include +#include + +#include "include/utils.h" + namespace { +using ::testing::ContainerEq; + class EndpointTest : public ::testing::Test { protected: std::shared_ptr _context{ @@ -32,7 +40,7 @@ TEST_F(EndpointTest, HandleIsValid) auto ep = _worker->createEndpointFromWorkerAddress(_worker->getAddress()); _worker->progress(); - ASSERT_TRUE(ep->getHandle() != nullptr); + ASSERT_NE(ep->getHandle(), nullptr); } TEST_F(EndpointTest, IsAlive) @@ -56,4 +64,142 @@ TEST_F(EndpointTest, IsAlive) ASSERT_FALSE(ep->isAlive()); } +TEST_F(EndpointTest, CancelCloseNonBlocking) +{ + auto ep = _worker->createEndpointFromWorkerAddress(_worker->getAddress()); + _worker->progress(); + + ASSERT_TRUE(ep->isAlive()); + + std::vector send(100 * 1024 * 1024, 42); + std::vector recv(send.size(), 0); + + // Submit and wait for transfers to complete + std::vector> requests; + 
requests.push_back(ep->tagSend(send.data(), send.size() * sizeof(int), ucxx::Tag{0})); + requests.push_back( + ep->tagRecv(recv.data(), recv.size() * sizeof(int), ucxx::Tag{0}, ucxx::TagMaskFull)); + // waitRequests(_worker, requests, _progressWorker); + + // copyResults(); + + // Assert data correctness + // ASSERT_THAT(_recv[0], ContainerEq(_send[0])); + + size_t inflightCount = 0; + for (const auto& req : requests) { + inflightCount += !req->isCompleted(); + } + std::cout << "inflightCount: " << inflightCount << std::endl; + + size_t tmpCount = 0; + do { + tmpCount = 0; + for (const auto& req : requests) { + tmpCount += req->isCompleted(); + req->cancel(); + } + _worker->progress(); + std::cout << "tmpCount: " << tmpCount << std::endl; + + std::this_thread::sleep_for(std::chrono::milliseconds(1000)); + } while (tmpCount != 2); + + auto canceling = ep->cancelInflightRequests(); + // ASSERT_EQ(canceling, inflightCount); + std::cout << canceling << std::endl; + + while (ep->getCancelingSize() > 0) { + std::cout << ep->getCancelingSize() << std::endl; + _worker->progress(); + + std::this_thread::sleep_for(std::chrono::milliseconds(1000)); + + for (const auto& req : requests) + std::cout << req->isCompleted() << " "; + std::cout << std::endl; + } + + auto close = ep->close(); + while (!close->isCompleted()) + _worker->progress(); + ASSERT_EQ(close->getStatus(), UCS_OK); +} + +TEST_F(EndpointTest, CloseNonBlockingCancel) +{ + auto progressWorker = getProgressFunction(_worker, ProgressMode::Blocking); + auto progressRemoteWorker = getProgressFunction(_remoteWorker, ProgressMode::Blocking); + auto progressAllWorkers = [&progressWorker, progressRemoteWorker]() { + progressWorker(); + progressRemoteWorker(); + }; + + auto ep = _worker->createEndpointFromWorkerAddress(_remoteWorker->getAddress()); + while (!ep->isAlive()) + progressWorker(); + + auto remoteEp = _worker->createEndpointFromWorkerAddress(_worker->getAddress()); + while (!remoteEp->isAlive()) + 
progressRemoteWorker(); + + std::vector send(100 * 1024 * 1024, 42); + std::vector recv(send.size(), 0); + + // Submit and wait for transfers to complete + std::vector> requests, remoteRequests; + requests.push_back(ep->tagSend(send.data(), send.size() * sizeof(int), ucxx::Tag{0})); + // requests.push_back(remoteEp->tagRecv(recv.data(), recv.size() * sizeof(int), ucxx::Tag{0}, + // ucxx::TagMaskFull)); + requests.push_back(_remoteWorker->tagRecv( + recv.data(), recv.size() * sizeof(int), ucxx::Tag{0}, ucxx::TagMaskFull)); + + // copyResults(); + + // Assert data correctness + // ASSERT_THAT(_recv[0], ContainerEq(_send[0])); + + // Endpoints that are _receiving_ may close before the request completes + waitRequests(_remoteWorker, requests, progressAllWorkers); + auto closeRemote = remoteEp->close(); + // waitRequests(_remoteWorker, requests, progressRemoteWorker); + waitSingleRequest(closeRemote, progressRemoteWorker); + ASSERT_FALSE(remoteEp->isAlive()); + + // Endpoints that are _sending_ may _not_ close before the request completes + auto close = ep->close(); + waitRequests(_worker, requests, progressWorker); + waitSingleRequest(close, progressRemoteWorker); + ASSERT_FALSE(ep->isAlive()); + + // requests.clear(); + // requests.push_back(ep->tagRecv(recv.data(), recv.size() * sizeof(int), ucxx::Tag{0}, + // ucxx::TagMaskFull)); waitRequests(_worker, requests, progressWorker); + + ASSERT_THAT(recv, ContainerEq(send)); + + // size_t inflightCount = 0; + // for (const auto& req : requests) { + // inflightCount += !req->isCompleted(); + // } + + // // ASSERT_EQ(canceling, inflightCount); + // std::cout << canceling << std::endl; + + // while (ep->getCancelingSize() > 0) { + // std::cout << ep->getCancelingSize() << std::endl; + // _worker->progress(); + + // std::this_thread::sleep_for(std::chrono::milliseconds(1000)); + + // for (const auto& req : requests) + // std::cout << req->isCompleted() << " "; + // std::cout << std::endl; + // } + + // 
ASSERT_EQ(close->getStatus(), UCS_OK); + + // auto canceling = ep->cancelInflightRequests(); +} + } // namespace diff --git a/cpp/tests/include/utils.h b/cpp/tests/include/utils.h index b98d59731..8d357e780 100644 --- a/cpp/tests/include/utils.h +++ b/cpp/tests/include/utils.h @@ -40,6 +40,15 @@ inline void waitRequests(std::shared_ptr worker, } } +template +inline void waitSingleRequest(const std::shared_ptr& request, + const std::function& progressWorker) +{ + while (!request->isCompleted()) + if (progressWorker) progressWorker(); + request->checkError(); +} + std::function getProgressFunction(std::shared_ptr worker, ProgressMode progressMode); From 230b81be242c1b9114556606f4466cc87ad078d5 Mon Sep 17 00:00:00 2001 From: Peter Andreas Entschev Date: Tue, 14 May 2024 09:43:26 -0700 Subject: [PATCH 03/39] Allow querying inflight request count in `ucxx::InflightRequests` --- cpp/include/ucxx/inflight_requests.h | 10 ++++++++++ cpp/src/inflight_requests.cpp | 11 +++++++++++ 2 files changed, 21 insertions(+) diff --git a/cpp/include/ucxx/inflight_requests.h b/cpp/include/ucxx/inflight_requests.h index de27584cd..cb4af2088 100644 --- a/cpp/include/ucxx/inflight_requests.h +++ b/cpp/include/ucxx/inflight_requests.h @@ -167,6 +167,16 @@ class InflightRequests { * @returns The count of requests that are in process of cancelation. */ size_t getCancelingSize(); + + /** + * @brief Get count of inflight requests. + * + * Get the count of inflight requests that have not yet completed nor have been scheduled + * for cancelation. + * + * @returns The count of inflight requests. 
+ */ + size_t getInflightSize(); }; } // namespace ucxx diff --git a/cpp/src/inflight_requests.cpp b/cpp/src/inflight_requests.cpp index 9862c45e0..25a1675c2 100644 --- a/cpp/src/inflight_requests.cpp +++ b/cpp/src/inflight_requests.cpp @@ -105,6 +105,17 @@ size_t InflightRequests::getCancelingSize() return cancelingSize; } +size_t InflightRequests::getInflightSize() +{ + size_t inflightSize = 0; + { + std::scoped_lock lock{_mutex}; + inflightSize = _trackedRequests->_inflight->size(); + } + + return inflightSize; +} + size_t InflightRequests::cancelAll() { decltype(_trackedRequests->_inflight) toCancel; From 8641ec3315e54db52919b55a2014d6639f6b9cad Mon Sep 17 00:00:00 2001 From: Peter Andreas Entschev Date: Tue, 14 May 2024 09:47:05 -0700 Subject: [PATCH 04/39] Allow querying inflight request count in `ucxx::Endpoint` --- cpp/include/ucxx/endpoint.h | 10 ++++++++++ cpp/src/endpoint.cpp | 2 ++ 2 files changed, 12 insertions(+) diff --git a/cpp/include/ucxx/endpoint.h b/cpp/include/ucxx/endpoint.h index 433dac8ab..73cf77131 100644 --- a/cpp/include/ucxx/endpoint.h +++ b/cpp/include/ucxx/endpoint.h @@ -289,6 +289,16 @@ class Endpoint : public Component { */ size_t getCancelingSize() const; + /** + * @brief Check the number of inflight requests waiting for completion. + * + * Check the number of inflight requests that were posted but have not yet completed nor + * have been scheduled for cancelation. + * + * @returns Number of inflight requests that are waiting for completion. + */ + size_t getInflightSize() const; + /** * @brief Cancel inflight requests. 
* diff --git a/cpp/src/endpoint.cpp b/cpp/src/endpoint.cpp index 48140e8a1..2ed09c72f 100644 --- a/cpp/src/endpoint.cpp +++ b/cpp/src/endpoint.cpp @@ -398,6 +398,8 @@ size_t Endpoint::cancelInflightRequestsBlocking(uint64_t period, uint64_t maxAtt size_t Endpoint::getCancelingSize() const { return _inflightRequests->getCancelingSize(); } +size_t Endpoint::getInflightSize() const { return _inflightRequests->getInflightSize(); } + std::shared_ptr Endpoint::amSend( void* buffer, const size_t length, From 8f5d0b2d77a02afa8319858b151da02034df57bc Mon Sep 17 00:00:00 2001 From: Peter Andreas Entschev Date: Tue, 14 May 2024 10:13:05 -0700 Subject: [PATCH 05/39] Add methods to initiate `ucxx::Endpoint` stop process --- cpp/include/ucxx/endpoint.h | 66 +++++++++++++++++++++++++++++++++++++ cpp/src/endpoint.cpp | 22 ++++++++++--- 2 files changed, 83 insertions(+), 5 deletions(-) diff --git a/cpp/include/ucxx/endpoint.h b/cpp/include/ucxx/endpoint.h index 73cf77131..1118be4bd 100644 --- a/cpp/include/ucxx/endpoint.h +++ b/cpp/include/ucxx/endpoint.h @@ -54,6 +54,7 @@ struct ErrorCallbackData { ///< `ucxx::Endpoint::setCloseCallback()`. ucs_status_t status{UCS_INPROGRESS}; ///< Endpoint status std::atomic closing{false}; ///< Prevent calling close multiple concurrent times. + std::atomic stopping{false}; ///< Signal whether endpoint is stopping. std::shared_ptr inflightRequests{nullptr}; ///< Endpoint inflight requests EndpointCloseCallbackUserFunction closeCallback{nullptr}; ///< Close callback to call EndpointCloseCallbackUserData closeCallbackArg{ @@ -360,6 +361,8 @@ class Endpoint : public Component { * Python future is requested, the Python application must then await on this future to * ensure the transfer has completed. Requires UCXX Python support. * + * @throws ucxx::RejectedError if `stop()` was already called. + * * @param[in] buffer a raw pointer to the data to be sent. * @param[in] length the size in bytes of the tag message to be sent. 
* @param[in] memoryType the memory type of the buffer. @@ -419,6 +422,8 @@ class Endpoint : public Component { * Python future is requested, the Python application must then await on this future to * ensure the transfer has completed. Requires UCXX Python support. * + * @throws ucxx::RejectedError if `stop()` was already called. + * * @param[in] buffer a raw pointer to the data to be sent. * @param[in] length the size in bytes of the tag message to be sent. * @param[in] remoteAddr the destination remote memory address to write to. @@ -449,6 +454,8 @@ class Endpoint : public Component { * Python future is requested, the Python application must then await on this future to * ensure the transfer has completed. Requires UCXX Python support. * + * @throws ucxx::RejectedError if `stop()` was already called. + * * @param[in] buffer a raw pointer to the data to be sent. * @param[in] length the size in bytes of the tag message to be sent. * @param[in] remoteKey the remote memory key associated with the remote memory @@ -481,6 +488,8 @@ class Endpoint : public Component { * Python future is requested, the Python application must then await on this future to * ensure the transfer has completed. Requires UCXX Python support. * + * @throws ucxx::RejectedError if `stop()` was already called. + * * @param[in] buffer a raw pointer to the data to be sent. * @param[in] length the size in bytes of the tag message to be sent. * @param[in] remoteAddr the source remote memory address to read from. @@ -511,6 +520,8 @@ class Endpoint : public Component { * Python future is requested, the Python application must then await on this future to * ensure the transfer has completed. Requires UCXX Python support. * + * @throws ucxx::RejectedError if `stop()` was already called. + * * @param[in] buffer a raw pointer to the data to be sent. * @param[in] length the size in bytes of the tag message to be sent. 
* @param[in] remoteKey the remote memory key associated with the remote memory @@ -543,6 +554,8 @@ class Endpoint : public Component { * Python future is requested, the Python application must then await on this future to * ensure the transfer has completed. Requires UCXX Python support. * + * @throws ucxx::RejectedError if `stop()` was already called. + * * @param[in] buffer a raw pointer to the data to be sent. * @param[in] length the size in bytes of the tag message to be sent. * @param[in] enablePythonFuture whether a python future should be created and @@ -564,6 +577,8 @@ class Endpoint : public Component { * Python future is requested, the Python application must then await on this future to * ensure the transfer has completed. Requires UCXX Python support. * + * @throws ucxx::RejectedError if `stop()` was already called. + * * @param[in] buffer a raw pointer to pre-allocated memory where resulting * data will be stored. * @param[in] length the size in bytes of the tag message to be received. @@ -586,6 +601,8 @@ class Endpoint : public Component { * Python future is requested, the Python application must then await on this future to * ensure the transfer has completed. Requires UCXX Python support. * + * @throws ucxx::RejectedError if `stop()` was already called. + * * @param[in] buffer a raw pointer to the data to be sent. * @param[in] length the size in bytes of the tag message to be sent. * @param[in] tag the tag to match. @@ -655,6 +672,7 @@ class Endpoint : public Component { * ensure the transfer has completed. Requires UCXX Python support. * * @throws std::runtime_error if sizes of `buffer`, `size` and `isCUDA` do not match. + * @throws ucxx::RejectedError if `stop()` was already called. * * @param[in] buffer a vector of raw pointers to the data frames to be sent. * @param[in] size a vector of size in bytes of each frame to be sent. @@ -817,6 +835,54 @@ class Endpoint : public Component { * if worker is running a progress thread and `period > 0`. 
*/ void closeBlocking(uint64_t period = 0, uint64_t maxAttempts = 1); + + /** + * @brief Signal wish to close the endpoint, but does not close it. + * + * Signal the wish to close the endpoint without closing it such that no new requests can + * be issued that require the endpoint to complete. This method is useful when the user + * needs control of the non-blocking closing process with `close()`, and thus allow + * requests to complete before issuing the close request. Once this is called, the user + * may check for the number of inflight and canceling requests via `getInflightSize()` + * and `getCancelingSize()` methods, respectively, and issue the non-blocking close of + * the endpoint once both return `0` or after a certain period of time has elapsed and + * the application cannot wait anymore for their completion. + * + * After this is called certain requests are no longer accepted because they require a + * valid endpoint to complete. The requests that are not accepted are: + * + * - `amSend()` + * - `memGet()` + * - `memPut()` + * - `streamSend()` + * - `streamRecv()` + * - `tagSend()` + * - `tagMultiSend()` + * + * If the user attempts to call any of the methods above after this method a + * `ucxx::RejectedError` exception is thrown. The user is also able to check whether this + * method was already called by checking the result of `isStopping()`. + * + * The following requests are handled by the underlying `ucxx::Worker` and only matched + * to the `ucxx::Endpoint` object and thus can still be called after the present method + * is called to allow draining them: + * + * - `amRecv()` + * - `tagRecv()` + * - `tagMultiRecv()` + */ + void stop(); + + /** + * @brief Check whether the endpoint was signaled the wish to close. + * + * Check whether the endpoint was signaled the wish to close after calling `stop()`. 
This + * result is useful to identify the stage where the endpoint finds itself, if this method + * returns `True`, the process to ultimately close the endpoint has begun. + * + * @return Whether the endpoint was signaled the wish to close. + */ + bool isStopping(); }; } // namespace ucxx diff --git a/cpp/src/endpoint.cpp b/cpp/src/endpoint.cpp index 2ed09c72f..7f9f0db26 100644 --- a/cpp/src/endpoint.cpp +++ b/cpp/src/endpoint.cpp @@ -409,7 +409,7 @@ std::shared_ptr Endpoint::amSend( RequestCallbackUserFunction callbackFunction, RequestCallbackUserData callbackData) { - if (_callbackData->closing.load()) return nullptr; + if (_callbackData->stopping.load()) throw RejectedError("Endpoint is stopping."); auto endpoint = std::dynamic_pointer_cast(shared_from_this()); return registerInflightRequest( @@ -437,6 +437,8 @@ std::shared_ptr Endpoint::memGet(void* buffer, RequestCallbackUserFunction callbackFunction, RequestCallbackUserData callbackData) { + if (_callbackData->stopping.load()) throw RejectedError("Endpoint is stopping."); + auto endpoint = std::dynamic_pointer_cast(shared_from_this()); return registerInflightRequest(createRequestMem(endpoint, data::MemGet(buffer, length, remoteAddr, rkey), @@ -453,6 +455,8 @@ std::shared_ptr Endpoint::memGet(void* buffer, RequestCallbackUserFunction callbackFunction, RequestCallbackUserData callbackData) { + if (_callbackData->stopping.load()) throw RejectedError("Endpoint is stopping."); + auto endpoint = std::dynamic_pointer_cast(shared_from_this()); return registerInflightRequest(createRequestMem( endpoint, @@ -471,6 +475,8 @@ std::shared_ptr Endpoint::memPut(void* buffer, RequestCallbackUserFunction callbackFunction, RequestCallbackUserData callbackData) { + if (_callbackData->stopping.load()) throw RejectedError("Endpoint is stopping."); + auto endpoint = std::dynamic_pointer_cast(shared_from_this()); return registerInflightRequest(createRequestMem(endpoint, data::MemPut(buffer, length, remoteAddr, rkey), @@ -487,6 
+493,8 @@ std::shared_ptr Endpoint::memPut(void* buffer, RequestCallbackUserFunction callbackFunction, RequestCallbackUserData callbackData) { + if (_callbackData->stopping.load()) throw RejectedError("Endpoint is stopping."); + auto endpoint = std::dynamic_pointer_cast(shared_from_this()); return registerInflightRequest(createRequestMem( endpoint, @@ -501,7 +509,7 @@ std::shared_ptr Endpoint::streamSend(void* buffer, size_t length, const bool enablePythonFuture) { - if (_callbackData->closing.load()) return nullptr; + if (_callbackData->stopping.load()) throw RejectedError("Endpoint is stopping."); auto endpoint = std::dynamic_pointer_cast(shared_from_this()); return registerInflightRequest( @@ -512,7 +520,7 @@ std::shared_ptr Endpoint::streamRecv(void* buffer, size_t length, const bool enablePythonFuture) { - if (_callbackData->closing.load()) return nullptr; + if (_callbackData->stopping.load()) throw RejectedError("Endpoint is stopping."); auto endpoint = std::dynamic_pointer_cast(shared_from_this()); return registerInflightRequest( @@ -526,7 +534,7 @@ std::shared_ptr Endpoint::tagSend(void* buffer, RequestCallbackUserFunction callbackFunction, RequestCallbackUserData callbackData) { - if (_callbackData->closing.load()) return nullptr; + if (_callbackData->stopping.load()) throw RejectedError("Endpoint is stopping."); auto endpoint = std::dynamic_pointer_cast(shared_from_this()); return registerInflightRequest(createRequestTag(endpoint, @@ -558,7 +566,7 @@ std::shared_ptr Endpoint::tagMultiSend(const std::vector& buffer const Tag tag, const bool enablePythonFuture) { - if (_callbackData->closing.load()) return nullptr; + if (_callbackData->stopping.load()) throw RejectedError("Endpoint is stopping."); auto endpoint = std::dynamic_pointer_cast(shared_from_this()); return registerInflightRequest(createRequestTagMulti( @@ -643,4 +651,8 @@ void Endpoint::errorCallback(void* arg, ucp_ep_h ep, ucs_status_t status) ucs_status_string(status)); } +void Endpoint::stop() 
{ _callbackData->stopping = true; } + +bool Endpoint::isStopping() { return _callbackData->stopping.load(); } + } // namespace ucxx From 3c983c65a2cbb027917c82d28827cb3988f594f9 Mon Sep 17 00:00:00 2001 From: Peter Andreas Entschev Date: Tue, 14 May 2024 10:18:29 -0700 Subject: [PATCH 06/39] Add missing `enablePythonFuture` defaults --- cpp/include/ucxx/endpoint.h | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/cpp/include/ucxx/endpoint.h b/cpp/include/ucxx/endpoint.h index 1118be4bd..0de560e23 100644 --- a/cpp/include/ucxx/endpoint.h +++ b/cpp/include/ucxx/endpoint.h @@ -563,7 +563,9 @@ class Endpoint : public Component { * * @returns Request to be subsequently checked for the completion and its state. */ - std::shared_ptr streamSend(void* buffer, size_t length, const bool enablePythonFuture); + std::shared_ptr streamSend(void* buffer, + size_t length, + const bool enablePythonFuture = false); /** * @brief Enqueue a stream receive operation. @@ -587,7 +589,9 @@ class Endpoint : public Component { * * @returns Request to be subsequently checked for the completion and its state. */ - std::shared_ptr streamRecv(void* buffer, size_t length, const bool enablePythonFuture); + std::shared_ptr streamRecv(void* buffer, + size_t length, + const bool enablePythonFuture = false); /** * @brief Enqueue a tag send operation. 
From e87baf5df5d0076ebb1d6d989b118bc912774007 Mon Sep 17 00:00:00 2001 From: Peter Andreas Entschev Date: Tue, 14 May 2024 10:21:06 -0700 Subject: [PATCH 07/39] Fix missing status set in `ucxx::Endpoint::closeBlocking()` --- cpp/src/endpoint.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/src/endpoint.cpp b/cpp/src/endpoint.cpp index 7f9f0db26..746fa522c 100644 --- a/cpp/src/endpoint.cpp +++ b/cpp/src/endpoint.cpp @@ -270,7 +270,6 @@ void Endpoint::closeBlocking(uint64_t period, uint64_t maxAttempts) ucs_status_t s; while ((s = ucp_request_check_status(status)) == UCS_INPROGRESS) worker->progress(); - _callbackData->status = s; } else if (UCS_PTR_STATUS(status) != UCS_OK) { ucxx_error( "ucxx::Endpoint::%s, Endpoint: %p, UCP handle: %p, Error while closing endpoint: %s", @@ -279,6 +278,7 @@ void Endpoint::closeBlocking(uint64_t period, uint64_t maxAttempts) _handle, ucs_status_string(UCS_PTR_STATUS(status))); } + _callbackData->status = UCS_PTR_STATUS(status); } ucxx_trace("ucxx::Endpoint::%s, Endpoint: %p, UCP handle: %p, closed", __func__, this, _handle); From 3b22f5458263b1f65f1ebbe7551b0f170c03fab0 Mon Sep 17 00:00:00 2001 From: Peter Andreas Entschev Date: Tue, 14 May 2024 14:27:22 -0700 Subject: [PATCH 08/39] Request fixes --- cpp/src/request.cpp | 23 ++++++++++++++++++++--- 1 file changed, 20 insertions(+), 3 deletions(-) diff --git a/cpp/src/request.cpp b/cpp/src/request.cpp index 981984ffe..95e90361c 100644 --- a/cpp/src/request.cpp +++ b/cpp/src/request.cpp @@ -63,6 +63,11 @@ Request::Request(std::shared_ptr endpointOrWorker, Request::~Request() { + if (UCS_PTR_IS_PTR(_request)) { + // ucxx_warn("ucxx::Request (%s) freeing: %p", _operationName.c_str(), _request); + ucp_request_free(_request); + } + // ucxx_warn("ucxx::Request destroyed (%s): %p", _operationName.c_str(), this); ucxx_trace("ucxx::Request destroyed (%s): %p", _operationName.c_str(), this); } @@ -80,8 +85,16 @@ void Request::cancel() status, 
ucs_status_string(status)); } else { - ucxx_trace_req_f(_ownerString.c_str(), this, _request, _operationName.c_str(), "canceling"); - if (_request != nullptr) ucp_request_cancel(_worker->getHandle(), _request); + if (_request != nullptr) { + ucs_status_t status = UCS_PTR_STATUS(_request); + // ucxx_warn("status1: %d (%s)", status, ucs_status_string(status)); + ucxx_trace_req_f(_ownerString.c_str(), this, _request, _operationName.c_str(), "canceling"); + ucp_request_cancel(_worker->getHandle(), _request); + status = UCS_PTR_STATUS(_request); + // ucxx_warn("status2: %d (%s)", status, ucs_status_string(status)); + // TODO: Cancel all send operations? amSend/streamSend probably do not need it. + if (_operationName == "tagSend") { setStatus(UCS_ERR_CANCELED); } + } } } else { ucxx_trace_req_f(_ownerString.c_str(), @@ -143,7 +156,11 @@ void Request::callback(void* request, ucs_status_t status) status, ucs_status_string(status)); - if (UCS_PTR_IS_PTR(_request)) ucp_request_free(request); + if (UCS_PTR_IS_PTR(_request)) { + // ucxx_warn("ucxx::Request (%s) callback: %p %p", _operationName.c_str(), _request, request); + ucp_request_free(request); + _request = nullptr; + } ucxx_trace_req_f(_ownerString.c_str(), this, _request, _operationName.c_str(), "completed"); setStatus(status); From 2d54ebd79d5d884c8c721e0e3e34d35cb1287fe5 Mon Sep 17 00:00:00 2001 From: Peter Andreas Entschev Date: Wed, 15 May 2024 05:26:01 -0700 Subject: [PATCH 09/39] Remove request from `ucxx::InflightRequests` canceling map Additionally comment out use of `dropCanceled()` which is probably now redundant but requires confirmation. 
--- cpp/src/inflight_requests.cpp | 38 ++++++++++++++++++++--------------- 1 file changed, 22 insertions(+), 16 deletions(-) diff --git a/cpp/src/inflight_requests.cpp b/cpp/src/inflight_requests.cpp index 25a1675c2..371ca554c 100644 --- a/cpp/src/inflight_requests.cpp +++ b/cpp/src/inflight_requests.cpp @@ -33,6 +33,24 @@ void InflightRequests::merge(TrackedRequestsPtr trackedRequests) } } +static void findAndRemove(InflightRequestsMap* requestsMap, const Request* const request) +{ + auto search = requestsMap->find(request); + decltype(search->second) tmpRequest; + if (search != requestsMap->end()) { + /** + * If this is the last request to hold `std::shared_ptr` erasing it + * may cause the `ucxx::Endpoint`s destructor and subsequently the `closeBlocking()` + * method to be called which will in turn call `cancelAll()` and attempt to take the + * mutexes. For this reason we should make a temporary copy of the request being + * erased from `_trackedRequests->_inflight` to allow unlocking the mutexes and only then + * destroy the object upon this method's return. + */ + tmpRequest = search->second; + requestsMap->erase(search); + } +} + void InflightRequests::remove(const Request* const request) { do { @@ -51,20 +69,8 @@ void InflightRequests::remove(const Request* const request) if (result == 0) { return; } else if (result == -1) { - auto search = _trackedRequests->_inflight->find(request); - decltype(search->second) tmpRequest; - if (search != _trackedRequests->_inflight->end()) { - /** - * If this is the last request to hold `std::shared_ptr` erasing it - * may cause the `ucxx::Endpoint`s destructor and subsequently the `closeBlocking()` - * method to be called which will in turn call `cancelAll()` and attempt to take the - * mutexes. For this reason we should make a temporary copy of the request being - * erased from `_trackedRequests->_inflight` to allow unlocking the mutexes and only then - * destroy the object upon this method's return. 
- */ - tmpRequest = search->second; - _trackedRequests->_inflight->erase(search); - } + findAndRemove(_trackedRequests->_inflight.get(), request); + findAndRemove(_trackedRequests->_canceling.get(), request); _cancelMutex.unlock(); _mutex.unlock(); return; @@ -95,7 +101,7 @@ size_t InflightRequests::dropCanceled() size_t InflightRequests::getCancelingSize() { - dropCanceled(); + // dropCanceled(); size_t cancelingSize = 0; { std::scoped_lock lock{_cancelMutex}; @@ -142,7 +148,7 @@ size_t InflightRequests::cancelAll() std::scoped_lock lock{_cancelMutex, _mutex}; _trackedRequests->_canceling->merge(*toCancel); } - dropCanceled(); + // dropCanceled(); return total; } From 27ac23fa201b5d751f409d5371048b4d3ca679d9 Mon Sep 17 00:00:00 2001 From: Peter Andreas Entschev Date: Wed, 15 May 2024 08:01:39 -0700 Subject: [PATCH 10/39] Do not track requests that completed immediately in `cancelAll` --- cpp/src/inflight_requests.cpp | 30 +++++++++++------------------- 1 file changed, 11 insertions(+), 19 deletions(-) diff --git a/cpp/src/inflight_requests.cpp b/cpp/src/inflight_requests.cpp index 371ca554c..456e36024 100644 --- a/cpp/src/inflight_requests.cpp +++ b/cpp/src/inflight_requests.cpp @@ -124,30 +124,22 @@ size_t InflightRequests::getInflightSize() size_t InflightRequests::cancelAll() { - decltype(_trackedRequests->_inflight) toCancel; - size_t total; - { - std::scoped_lock lock{_cancelMutex, _mutex}; - total = _trackedRequests->_inflight->size(); - - // Fast path when no requests have been registered or the map has been - // previously released. - if (total == 0) return 0; - - toCancel = std::exchange(_trackedRequests->_inflight, std::make_unique()); - } + std::scoped_lock lock{_cancelMutex, _mutex}; + auto total = _trackedRequests->_inflight->size(); - ucxx_debug("ucxx::InflightRequests::%s, canceling %lu requests", __func__, total); + // Fast path when no requests have been registered or the map has been + // previously released. 
+ if (total == 0) return 0; - for (auto& r : *toCancel) { + for (auto& r : *_trackedRequests->_inflight) { auto request = r.second; - if (request != nullptr) { request->cancel(); } + if (request != nullptr) { + request->cancel(); + if (!request->isCompleted()) _trackedRequests->_canceling->insert({request.get(), request}); + } } + _trackedRequests->_inflight->clear(); - { - std::scoped_lock lock{_cancelMutex, _mutex}; - _trackedRequests->_canceling->merge(*toCancel); - } // dropCanceled(); return total; From a68798846c3bfdbdcc779a30ae3d53cc5c33c570 Mon Sep 17 00:00:00 2001 From: Peter Andreas Entschev Date: Wed, 15 May 2024 08:55:42 -0700 Subject: [PATCH 11/39] Fix `RequestEndpointClose` force mode --- cpp/src/request_endpoint_close.cpp | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/cpp/src/request_endpoint_close.cpp b/cpp/src/request_endpoint_close.cpp index 537f4f2fd..a4dca9ad0 100644 --- a/cpp/src/request_endpoint_close.cpp +++ b/cpp/src/request_endpoint_close.cpp @@ -64,7 +64,10 @@ void RequestEndpointClose::request() ucp_request_param_t param = { .op_attr_mask = UCP_OP_ATTR_FIELD_CALLBACK | UCP_OP_ATTR_FIELD_USER_DATA, .user_data = this}; - if (std::get(_requestData)._force) param.flags = UCP_EP_CLOSE_FLAG_FORCE; + if (std::get(_requestData)._force) { + param.op_attr_mask |= UCP_OP_ATTR_FIELD_FLAGS; + param.flags = UCP_EP_CLOSE_FLAG_FORCE; + } param.cb.send = endpointCloseCallback; if (_endpoint != nullptr) request = ucp_ep_close_nbx(_endpoint->getHandle(), ¶m); From dfe975e458ae3114968bae7d366b1217bbc209e5 Mon Sep 17 00:00:00 2001 From: Peter Andreas Entschev Date: Tue, 14 May 2024 14:26:43 -0700 Subject: [PATCH 12/39] Endpoint cancelation tests --- cpp/tests/endpoint.cpp | 409 ++++++++++++++++++++++++++++++++++++----- 1 file changed, 360 insertions(+), 49 deletions(-) diff --git a/cpp/tests/endpoint.cpp b/cpp/tests/endpoint.cpp index 76deacc88..b7f751bfb 100644 --- a/cpp/tests/endpoint.cpp +++ b/cpp/tests/endpoint.cpp @@ -3,6 +3,9 @@ * 
SPDX-License-Identifier: BSD-3-Clause */ #include +#include +#include +#include #include #include @@ -17,7 +20,9 @@ namespace { +using ::testing::Combine; using ::testing::ContainerEq; +using ::testing::Values; class EndpointTest : public ::testing::Test { protected: @@ -35,6 +40,81 @@ class EndpointTest : public ::testing::Test { } }; +enum class TransferType { Am, Tag, Stream }; +typedef std::vector> RequestContainer; + +class EndpointCancelTest : public ::testing::TestWithParam> { + protected: + std::shared_ptr _context{nullptr}; + std::shared_ptr _remoteContext{nullptr}; + std::shared_ptr _worker{nullptr}; + std::shared_ptr _remoteWorker{nullptr}; + TransferType _transferType{}; + size_t _messageSize{}; + RequestContainer _requests{}; + std::vector _send{}, _recv{}; + bool _rndv{false}; + + virtual void SetUp() + { + std::tie(_transferType, _messageSize, _rndv) = GetParam(); + + _send.resize(_messageSize); + _recv.resize(_messageSize, 0); + std::iota(_send.begin(), _send.end(), 0); + + _context = ucxx::createContext({}, ucxx::Context::defaultFeatureFlags); + _remoteContext = ucxx::createContext({}, ucxx::Context::defaultFeatureFlags); + _worker = _context->createWorker(); + _remoteWorker = _remoteContext->createWorker(); + } + + RequestContainer buildPair(std::shared_ptr sendEp, + std::shared_ptr recvEp) + { + if (_transferType == TransferType::Tag) { + return RequestContainer{ + sendEp->tagSend(_send.data(), _send.size() * sizeof(int), ucxx::Tag{0}), + recvEp->tagRecv(_recv.data(), _recv.size() * sizeof(int), ucxx::Tag{0}, ucxx::TagMaskFull)}; + } else if (_transferType == TransferType::Am) { + return RequestContainer{ + sendEp->amSend(_send.data(), _send.size() * sizeof(int), UCS_MEMORY_TYPE_HOST), + recvEp->amRecv()}; + } else if (_transferType == TransferType::Stream) { + return RequestContainer{sendEp->streamSend(_send.data(), _send.size() * sizeof(int)), + recvEp->streamRecv(_recv.data(), _recv.size() * sizeof(int))}; + } + return RequestContainer{}; + } 
+}; + +static void wireup(std::shared_ptr ep1, + std::shared_ptr ep2, + std::function progressWorker) +{ + // wireup + std::vector wireupSend(1, 99); + std::vector wireupRecv(wireupSend.size(), 0); + + std::vector> wireupRequests; + wireupRequests.push_back( + ep1->tagSend(wireupSend.data(), wireupSend.size() * sizeof(int), ucxx::Tag{0})); + wireupRequests.push_back(ep2->tagRecv( + wireupRecv.data(), wireupRecv.size() * sizeof(int), ucxx::Tag{0}, ucxx::TagMaskFull)); + + while (!wireupRequests[0]->isCompleted() || !wireupRequests[1]->isCompleted()) + progressWorker(); + + ASSERT_EQ(wireupRequests[0]->getStatus(), UCS_OK); + ASSERT_EQ(wireupRequests[1]->getStatus(), UCS_OK); + ASSERT_THAT(wireupRecv, ContainerEq(wireupSend)); +} + +static size_t countIncomplete(const std::vector>& requests) +{ + return std::count_if(requests.begin(), requests.end(), [](auto r) { return !r->isCompleted(); }); +} + TEST_F(EndpointTest, HandleIsValid) { auto ep = _worker->createEndpointFromWorkerAddress(_worker->getAddress()); @@ -64,67 +144,67 @@ TEST_F(EndpointTest, IsAlive) ASSERT_FALSE(ep->isAlive()); } -TEST_F(EndpointTest, CancelCloseNonBlocking) -{ - auto ep = _worker->createEndpointFromWorkerAddress(_worker->getAddress()); - _worker->progress(); +// TEST_F(EndpointTest, CancelCloseNonBlocking) +// { +// auto ep = _worker->createEndpointFromWorkerAddress(_worker->getAddress()); +// _worker->progress(); - ASSERT_TRUE(ep->isAlive()); +// ASSERT_TRUE(ep->isAlive()); - std::vector send(100 * 1024 * 1024, 42); - std::vector recv(send.size(), 0); +// std::vector send(100 * 1024 * 1024, 42); +// std::vector recv(send.size(), 0); - // Submit and wait for transfers to complete - std::vector> requests; - requests.push_back(ep->tagSend(send.data(), send.size() * sizeof(int), ucxx::Tag{0})); - requests.push_back( - ep->tagRecv(recv.data(), recv.size() * sizeof(int), ucxx::Tag{0}, ucxx::TagMaskFull)); - // waitRequests(_worker, requests, _progressWorker); +// // Submit and wait for transfers 
to complete +// std::vector> requests; +// requests.push_back(ep->tagSend(send.data(), send.size() * sizeof(int), ucxx::Tag{0})); +// requests.push_back( +// ep->tagRecv(recv.data(), recv.size() * sizeof(int), ucxx::Tag{0}, ucxx::TagMaskFull)); +// // waitRequests(_worker, requests, _progressWorker); - // copyResults(); +// // copyResults(); - // Assert data correctness - // ASSERT_THAT(_recv[0], ContainerEq(_send[0])); +// // Assert data correctness +// // ASSERT_THAT(_recv[0], ContainerEq(_send[0])); - size_t inflightCount = 0; - for (const auto& req : requests) { - inflightCount += !req->isCompleted(); - } - std::cout << "inflightCount: " << inflightCount << std::endl; - - size_t tmpCount = 0; - do { - tmpCount = 0; - for (const auto& req : requests) { - tmpCount += req->isCompleted(); - req->cancel(); - } - _worker->progress(); - std::cout << "tmpCount: " << tmpCount << std::endl; +// size_t inflightCount = 0; +// for (const auto& req : requests) { +// inflightCount += !req->isCompleted(); +// } +// std::cout << "inflightCount: " << inflightCount << std::endl; - std::this_thread::sleep_for(std::chrono::milliseconds(1000)); - } while (tmpCount != 2); +// size_t tmpCount = 0; +// do { +// tmpCount = 0; +// for (const auto& req : requests) { +// tmpCount += req->isCompleted(); +// req->cancel(); +// } +// _worker->progress(); +// std::cout << "tmpCount: " << tmpCount << std::endl; - auto canceling = ep->cancelInflightRequests(); - // ASSERT_EQ(canceling, inflightCount); - std::cout << canceling << std::endl; +// std::this_thread::sleep_for(std::chrono::milliseconds(1000)); +// } while (tmpCount != 2); - while (ep->getCancelingSize() > 0) { - std::cout << ep->getCancelingSize() << std::endl; - _worker->progress(); +// auto canceling = ep->cancelInflightRequests(); +// // ASSERT_EQ(canceling, inflightCount); +// std::cout << canceling << std::endl; - std::this_thread::sleep_for(std::chrono::milliseconds(1000)); +// while (ep->getCancelingSize() > 0) { +// std::cout 
<< ep->getCancelingSize() << std::endl; +// _worker->progress(); - for (const auto& req : requests) - std::cout << req->isCompleted() << " "; - std::cout << std::endl; - } +// std::this_thread::sleep_for(std::chrono::milliseconds(1000)); - auto close = ep->close(); - while (!close->isCompleted()) - _worker->progress(); - ASSERT_EQ(close->getStatus(), UCS_OK); -} +// for (const auto& req : requests) +// std::cout << req->isCompleted() << " "; +// std::cout << std::endl; +// } + +// auto close = ep->close(); +// while (!close->isCompleted()) +// _worker->progress(); +// ASSERT_EQ(close->getStatus(), UCS_OK); +// } TEST_F(EndpointTest, CloseNonBlockingCancel) { @@ -202,4 +282,235 @@ TEST_F(EndpointTest, CloseNonBlockingCancel) // auto canceling = ep->cancelInflightRequests(); } +// TEST_F(EndpointTest, CancelTmp) +// { +// auto progressWorker = getProgressFunction(_worker, ProgressMode::Blocking); +// auto progressRemoteWorker = getProgressFunction(_remoteWorker, ProgressMode::Blocking); +// auto progressAllWorkers = [&progressWorker, progressRemoteWorker]() { +// progressWorker(); +// progressRemoteWorker(); +// }; + +// auto ep = _worker->createEndpointFromWorkerAddress(_remoteWorker->getAddress()); +// while (!ep->isAlive()) +// progressWorker(); + +// auto remoteEp = _worker->createEndpointFromWorkerAddress(_worker->getAddress()); +// while (!remoteEp->isAlive()) +// progressRemoteWorker(); + +// std::vector send(100 * 1024 * 1024, 42); +// std::vector recv(send.size(), 0); + +// // Submit and wait for transfers to complete +// std::vector> requests, remoteRequests; +// requests.push_back(ep->tagSend(send.data(), send.size() * sizeof(int), ucxx::Tag{0})); +// remoteRequests.push_back(remoteEp->tagRecv(recv.data(), recv.size() * sizeof(int), +// ucxx::Tag{0}, ucxx::TagMaskFull)); +// // requests.push_back(_remoteWorker->tagRecv( +// // recv.data(), recv.size() * sizeof(int), ucxx::Tag{0}, ucxx::TagMaskFull)); + +// // copyResults(); + +// // Assert data correctness 
+// // ASSERT_THAT(_recv[0], ContainerEq(_send[0])); + +// // Endpoints that are _receiving_ may close before the request completes +// ASSERT_FALSE(requests[0]->isCompleted()); +// ASSERT_FALSE(remoteRequests[0]->isCompleted()); + +// // ep->cancelInflightRequests(); +// // remoteEp->cancelInflightRequests(); + +// // ASSERT_FALSE(requests[0]->isCompleted()); +// // ASSERT_FALSE(remoteRequests[0]->isCompleted()); + +// while (!requests[0]->isCompleted() || ! remoteRequests[0]->isCompleted()) { +// std::cout << "progress: " << requests[0]->isCompleted() << " " << +// remoteRequests[0]->isCompleted() << std::endl; progressWorker(); progressRemoteWorker(); +// std::this_thread::sleep_for(std::chrono::seconds(1)); +// } + +// ASSERT_TRUE(requests[0]->isCompleted()); +// ASSERT_TRUE(remoteRequests[0]->isCompleted()); +// ASSERT_EQ(requests[0]->getStatus(), UCS_OK); +// ASSERT_EQ(remoteRequests[0]->isCompleted(), UCS_OK); + +// ASSERT_THAT(recv, ContainerEq(send)); +// } + +TEST_F(EndpointTest, CancelSingleTmp) +{ + // auto progressWorker = getProgressFunction(_worker, ProgressMode::Blocking); + auto progressWorker = getProgressFunction(_worker, ProgressMode::Polling); + + auto ep = _worker->createEndpointFromWorkerAddress(_worker->getAddress()); + while (!ep->isAlive()) + progressWorker(); + + wireup(ep, ep, progressWorker); + + std::vector send(100 * 1024 * 1024, 42); + std::vector recv(send.size(), 0); + + // Submit and wait for transfers to complete + std::vector> requests; + requests.push_back(ep->tagSend(send.data(), send.size() * sizeof(int), ucxx::Tag{0})); + requests.push_back( + ep->tagRecv(recv.data(), recv.size() * sizeof(int), ucxx::Tag{0}, ucxx::TagMaskFull)); + + ASSERT_FALSE(requests[0]->isCompleted()); + ASSERT_FALSE(requests[1]->isCompleted()); + + progressWorker(); + ep->cancelInflightRequests(); + + // ASSERT_FALSE(requests[0]->isCompleted()); + // ASSERT_FALSE(requests[1]->isCompleted()); + + while (!requests[0]->isCompleted() || 
!requests[1]->isCompleted()) { + std::cout << "progress1: " << requests[0]->isCompleted() << " " << requests[1]->isCompleted() + << std::endl; + std::cout << "progress2: " << requests[0]->getStatus() << " " << requests[1]->getStatus() + << std::endl; + progressWorker(); + std::this_thread::sleep_for(std::chrono::seconds(1)); + } + std::cout << "progress3: " << requests[0]->isCompleted() << " " << requests[1]->isCompleted() + << std::endl; + std::cout << "progress4: " << requests[0]->getStatus() << " " << requests[1]->getStatus() + << std::endl; + + ASSERT_TRUE(requests[0]->isCompleted()); + ASSERT_TRUE(requests[1]->isCompleted()); + ASSERT_EQ(requests[0]->getStatus(), UCS_OK); + ASSERT_EQ(requests[1]->getStatus(), UCS_OK); + + ASSERT_THAT(recv, ContainerEq(send)); +} + +TEST_F(EndpointTest, StoppingRejection) +{ + auto progressWorker = getProgressFunction(_worker, ProgressMode::Blocking); + + auto ep = _worker->createEndpointFromWorkerAddress(_worker->getAddress()); + while (!ep->isAlive()) + progressWorker(); + + wireup(ep, ep, progressWorker); + + ep->stop(); + + ASSERT_TRUE(ep->isStopping()); + + std::vector tmp(10, 42); + + EXPECT_THROW(ep->amSend(tmp.data(), tmp.size() * sizeof(int), UCS_MEMORY_TYPE_HOST), + ucxx::RejectedError); + EXPECT_THROW(ep->tagSend(tmp.data(), tmp.size() * sizeof(int), ucxx::Tag{0}), + ucxx::RejectedError); + EXPECT_THROW(ep->streamRecv(tmp.data(), tmp.size() * sizeof(int)), ucxx::RejectedError); + EXPECT_THROW(ep->streamSend(tmp.data(), tmp.size() * sizeof(int)), ucxx::RejectedError); + + // TODO: tagMultiSend, memPut, memGet +} + +TEST_P(EndpointCancelTest, StoppingWaitCompletionThenCancel) +{ + if (_transferType == TransferType::Stream && _messageSize == 0) + GTEST_SKIP() << "Stream messages of size 0 are not supported."; + + auto progressWorker = getProgressFunction(_worker, ProgressMode::Blocking); + + auto ep = _worker->createEndpointFromWorkerAddress(_worker->getAddress()); + while (!ep->isAlive()) + progressWorker(); + + 
wireup(ep, ep, progressWorker); + + // Submit and wait for transfers to complete + auto requests = buildPair(ep, ep); + + auto checkAfterSubmit = [this, &requests, &ep]() { + // Check requests completion statuses + if (_rndv) { + std::for_each( + requests.begin(), requests.end(), [](auto r) { ASSERT_FALSE(r->isCompleted()); }); + } else { + ASSERT_TRUE(requests[0]->isCompleted()); + ASSERT_FALSE(requests[1]->isCompleted()); + } + + // Check no requests are being canceled + ASSERT_EQ(ep->getCancelingSize(), 0); + + // Check which requests are inflight or completed + if (_rndv) + ASSERT_EQ(ep->getInflightSize(), requests.size()); + else + // Eager send request completes immediately, receive request still inflight + ASSERT_EQ(ep->getInflightSize(), 1); + }; + + // Check request statuses before stopping endpoint + checkAfterSubmit(); + + // Stop accepting new requests, except those handled by the worker + ep->stop(); + ASSERT_TRUE(ep->isStopping()); + + // Check that requests statuses haven't changed + checkAfterSubmit(); + + // Cancel inflight requests + ep->cancelInflightRequests(); + ASSERT_EQ(ep->getCancelingSize(), countIncomplete(requests)); + ASSERT_EQ(ep->getInflightSize(), 0); + + // Wait for canceling requests to complete + while (countIncomplete(requests) > 0) + progressWorker(); + + // Check all requests have been canceled + std::for_each(requests.begin(), requests.end(), [](auto r) { + auto status = r->getStatus(); + ASSERT_TRUE(status == UCS_ERR_CANCELED || status == UCS_OK); + }); + + // Check received message, if it wasn't canceled + if (requests[1]->getStatus() == UCS_OK) { + // Copy AM results back into a `std::vector` which can be checked with `ASSERT_THAT` + if (_transferType == TransferType::Am) { + auto recvBuffer = requests[1]->getRecvBuffer(); + std::copy(reinterpret_cast(recvBuffer->data()), + reinterpret_cast(recvBuffer->data()) + _recv.size(), + _recv.begin()); + } + + ASSERT_THAT(_recv, ContainerEq(_send)); + } + + // Check no more tracked 
requests exist + ASSERT_EQ(ep->getCancelingSize(), 0); + ASSERT_EQ(ep->getInflightSize(), 0); + + auto close = ep->close(); + ASSERT_NE(close, nullptr); + while (!close->isCompleted()) + progressWorker(); + ASSERT_EQ(close->getStatus(), UCS_OK); +} + +INSTANTIATE_TEST_SUITE_P(Eager, + EndpointCancelTest, + Combine(Values(TransferType::Tag, TransferType::Am, TransferType::Stream), + Values(0, 1, 10), + Values(false))); + +INSTANTIATE_TEST_SUITE_P(Rndv, + EndpointCancelTest, + Combine(Values(TransferType::Tag, TransferType::Am, TransferType::Stream), + Values(10485760, 104857600), + Values(true))); + } // namespace From 1fee6b46ca933d64916bae3e5c1cba3e32bf4680 Mon Sep 17 00:00:00 2001 From: Peter Andreas Entschev Date: Wed, 15 May 2024 14:27:27 -0700 Subject: [PATCH 13/39] Notify callback when all inflight/canceling requests complete --- cpp/include/ucxx/inflight_requests.h | 25 ++++++- cpp/include/ucxx/typedefs.h | 2 + cpp/src/inflight_requests.cpp | 97 ++++++++++++++++++++++------ 3 files changed, 102 insertions(+), 22 deletions(-) diff --git a/cpp/include/ucxx/inflight_requests.h b/cpp/include/ucxx/inflight_requests.h index cb4af2088..0a938f6d8 100644 --- a/cpp/include/ucxx/inflight_requests.h +++ b/cpp/include/ucxx/inflight_requests.h @@ -9,6 +9,8 @@ #include #include +#include + namespace ucxx { class Request; @@ -131,9 +133,18 @@ class InflightRequests { * the raw pointer address is used as key to the requests reference, and this is called * called from the object's destructor. * + * Supports an optional callback function to be called exclusively if there are no + * more requests inflight or canceling. Be advised that before the callback is called the + * mutex that controls inflight requests is released to prevent deadlocks in case the + * callback happens to register a new inflight request, therefore there's no guarantee + * that another inflight request won't be registered between the time in which the mutex + * is released and the callback is executed. 
+ * * @param[in] request raw pointer to the request + * @param[in] callbackFunction function to be called upon termination and only if no + * further requests inflight or canceling remain. */ - void remove(const Request* const request); + void remove(const Request* const request, GenericCallbackUserFunction callbackFunction = nullptr); /** * @brief Issue cancelation of all inflight requests and clear the internal container. @@ -141,9 +152,19 @@ class InflightRequests { * Issue cancelation of all inflight requests known to this object and clear the * internal container. The total number of canceled requests is returned. * + * Supports an optional callback function to be called exclusively if there are no + * more requests inflight or canceling. Be advised that before the callback is called the + * mutex that controls inflight requests is released to prevent deadlocks in case the + * callback happens to register a new inflight request, therefore there's no guarantee + * that another inflight request won't be registered between the time in which the mutex + * is released and the callback is executed. + * + * @param[in] callbackFunction function to be called upon termination and only if no + * further requests inflight or canceling remain. + * * @returns The total number of canceled requests. */ - size_t cancelAll(); + size_t cancelAll(GenericCallbackUserFunction callbackFunction = nullptr); /** * @brief Releases the internally-tracked containers. 
diff --git a/cpp/include/ucxx/typedefs.h b/cpp/include/ucxx/typedefs.h index 2cbf7960a..aa81940bc 100644 --- a/cpp/include/ucxx/typedefs.h +++ b/cpp/include/ucxx/typedefs.h @@ -162,4 +162,6 @@ class AmReceiverCallbackInfo { typedef const std::string SerializedRemoteKey; +typedef std::function GenericCallbackUserFunction; + } // namespace ucxx diff --git a/cpp/src/inflight_requests.cpp b/cpp/src/inflight_requests.cpp index 456e36024..a120c1cb4 100644 --- a/cpp/src/inflight_requests.cpp +++ b/cpp/src/inflight_requests.cpp @@ -51,9 +51,10 @@ static void findAndRemove(InflightRequestsMap* requestsMap, const Request* const } } -void InflightRequests::remove(const Request* const request) +void InflightRequests::remove(const Request* const request, + GenericCallbackUserFunction cancelInflightCallback) { - do { + while (true) { int result = std::try_lock(_cancelMutex, _mutex); /** @@ -71,11 +72,31 @@ void InflightRequests::remove(const Request* const request) } else if (result == -1) { findAndRemove(_trackedRequests->_inflight.get(), request); findAndRemove(_trackedRequests->_canceling.get(), request); - _cancelMutex.unlock(); + + size_t trackedRequestsCount = + _trackedRequests->_inflight->size() + _trackedRequests->_canceling->size(); + + /** + * Unlock `_mutex` before calling the user callback to prevent deadlocks in case the + * user callback happens to register another inflight request. 
+ */ _mutex.unlock(); - return; + try { + if (cancelInflightCallback && trackedRequestsCount == 0) { + ucxx_debug("ucxx::InflightRequests::%s: %p, calling user cancel inflight callback", + __func__, + this); + cancelInflightCallback(); + } + _cancelMutex.unlock(); + return; + } catch (const std::exception& e) { + ucxx_warn("Exception in callback"); + _cancelMutex.unlock(); + throw(e); + } } - } while (true); + } } size_t InflightRequests::dropCanceled() @@ -122,27 +143,63 @@ size_t InflightRequests::getInflightSize() return inflightSize; } -size_t InflightRequests::cancelAll() +size_t InflightRequests::cancelAll(GenericCallbackUserFunction cancelInflightCallback) { - std::scoped_lock lock{_cancelMutex, _mutex}; - auto total = _trackedRequests->_inflight->size(); + size_t total = 0; + + while (true) { + /** + * -1: both mutexes were locked. + */ + if (std::try_lock(_cancelMutex, _mutex) == -1) { + auto total = _trackedRequests->_inflight->size(); + + // Fast path when no requests have been registered or the map has been + // previously released. + if (total == 0) break; + + // ucxx_debug("ucxx::InflightRequests::%s, canceling %lu requests", __func__, total); + + for (auto& r : *_trackedRequests->_inflight) { + auto request = r.second; + if (request != nullptr) { + request->cancel(); + if (!request->isCompleted()) { + _trackedRequests->_canceling->insert({request.get(), request}); + } else { + auto status = request->getStatus(); + } + } + } + _trackedRequests->_inflight->clear(); - // Fast path when no requests have been registered or the map has been - // previously released. 
- if (total == 0) return 0; + // dropCanceled(); - for (auto& r : *_trackedRequests->_inflight) { - auto request = r.second; - if (request != nullptr) { - request->cancel(); - if (!request->isCompleted()) _trackedRequests->_canceling->insert({request.get(), request}); + break; } } - _trackedRequests->_inflight->clear(); - // dropCanceled(); - - return total; + size_t trackedRequestsCount = + _trackedRequests->_inflight->size() + _trackedRequests->_canceling->size(); + + /** + * Unlock `_mutex` before calling the user callback to prevent deadlocks in case the + * user callback happens to register another inflight request. + */ + _mutex.unlock(); + try { + if (cancelInflightCallback && trackedRequestsCount == 0) { + ucxx_debug( + "ucxx::InflightRequests::%s: %p, calling user cancel inflight callback", __func__, this); + cancelInflightCallback(); + } + _cancelMutex.unlock(); + return total; + } catch (const std::exception& e) { + ucxx_warn("Exception in callback"); + _cancelMutex.unlock(); + throw(e); + } } TrackedRequestsPtr InflightRequests::release() From 4dc81c51c762925f37c69273d08f0b3f0b5ba3c4 Mon Sep 17 00:00:00 2001 From: Peter Andreas Entschev Date: Wed, 15 May 2024 14:50:41 -0700 Subject: [PATCH 14/39] Add user callback to `ucxx::Endpoint::cancelInflightRequests()` --- cpp/include/ucxx/endpoint.h | 19 ++++++++++++++++++- cpp/src/endpoint.cpp | 22 +++++++++++++++++++--- 2 files changed, 37 insertions(+), 4 deletions(-) diff --git a/cpp/include/ucxx/endpoint.h b/cpp/include/ucxx/endpoint.h index 0de560e23..46fcb6ad6 100644 --- a/cpp/include/ucxx/endpoint.h +++ b/cpp/include/ucxx/endpoint.h @@ -87,6 +87,11 @@ class Endpoint : public Component { nullptr}; ///< Data struct to pass to endpoint error handling callback std::shared_ptr _inflightRequests{ std::make_shared()}; ///< The inflight requests + GenericCallbackUserFunction _cancelInflightCallback{ + nullptr}; ///< The wrapper to the callback registered via `cancelInflightRequests()` that will + ///< deregister 
once the callback is called. + GenericCallbackUserFunction _cancelInflightCallbackOriginal{ + nullptr}; ///< The original user callback registered via `cancelInflightRequests()` /** * @brief Private constructor of `ucxx::Endpoint`. @@ -275,9 +280,21 @@ class Endpoint : public Component { * progress the worker and check the result of `getCancelingSize()`, all requests are only * canceled when `getCancelingSize()` returns `0`. * + * Supports an optional callback function to be called exclusively if there are no + * more requests inflight or canceling. Be advised that before the callback is called the + * mutex that controls inflight requests is released to prevent deadlocks in case the + * callback happens to register a new inflight request, therefore there's no guarantee + * that another inflight request won't be registered between the time in which the mutex + * is released and the callback is executed, the user is thus responsible to prevent such + * situations and the use of `stop()` before `cancelInflightRequests()` is highly + * advisable. + * + * @param[in] callbackFunction function to be called upon termination and only if no + * further requests inflight or canceling remain. + * * @returns Number of requests that were scheduled for cancelation. */ - size_t cancelInflightRequests(); + size_t cancelInflightRequests(GenericCallbackUserFunction callbackFunction = nullptr); /** * @brief Check the number of inflight requests being canceled. 
diff --git a/cpp/src/endpoint.cpp b/cpp/src/endpoint.cpp index 746fa522c..79d16c951 100644 --- a/cpp/src/endpoint.cpp +++ b/cpp/src/endpoint.cpp @@ -341,9 +341,11 @@ std::shared_ptr Endpoint::registerInflightRequest(std::shared_ptrstatus != UCS_INPROGRESS) + if (_callbackData->status != UCS_INPROGRESS && + std::dynamic_pointer_cast(request) == nullptr) _callbackData->worker->scheduleRequestCancel(_inflightRequests->release()); return request; @@ -354,7 +356,21 @@ void Endpoint::removeInflightRequest(const Request* const request) _inflightRequests->remove(request); } -size_t Endpoint::cancelInflightRequests() { return _inflightRequests->cancelAll(); } +size_t Endpoint::cancelInflightRequests(GenericCallbackUserFunction callback) +{ + _cancelInflightCallbackOriginal = callback; + + // Wrapper responsible for deregistering after callback is called. + _cancelInflightCallback = [this]() { + if (_cancelInflightCallbackOriginal != nullptr) _cancelInflightCallbackOriginal(); + _cancelInflightCallback = nullptr; + _cancelInflightCallbackOriginal = nullptr; + }; + + auto canceled = _inflightRequests->cancelAll(_cancelInflightCallback); + + return canceled; +} size_t Endpoint::cancelInflightRequestsBlocking(uint64_t period, uint64_t maxAttempts) { From 0c2b72d1f01ef88fa49edcc1c8fba09e01df8f55 Mon Sep 17 00:00:00 2001 From: Peter Andreas Entschev Date: Wed, 15 May 2024 14:54:39 -0700 Subject: [PATCH 15/39] Endpoint test cancel inflight requests callback --- cpp/tests/endpoint.cpp | 71 ++++++++++++++++++++++++++++++------------ 1 file changed, 51 insertions(+), 20 deletions(-) diff --git a/cpp/tests/endpoint.cpp b/cpp/tests/endpoint.cpp index b7f751bfb..1991be8b4 100644 --- a/cpp/tests/endpoint.cpp +++ b/cpp/tests/endpoint.cpp @@ -17,6 +17,7 @@ #include #include "include/utils.h" +#include "ucxx/typedefs.h" namespace { @@ -43,12 +44,14 @@ class EndpointTest : public ::testing::Test { enum class TransferType { Am, Tag, Stream }; typedef std::vector> RequestContainer; -class 
EndpointCancelTest : public ::testing::TestWithParam> { +class EndpointCancelTest + : public ::testing::TestWithParam> { protected: std::shared_ptr _context{nullptr}; std::shared_ptr _remoteContext{nullptr}; std::shared_ptr _worker{nullptr}; std::shared_ptr _remoteWorker{nullptr}; + ProgressMode _progressMode{}; TransferType _transferType{}; size_t _messageSize{}; RequestContainer _requests{}; @@ -57,7 +60,7 @@ class EndpointCancelTest : public ::testing::TestWithParamstartProgressThread(true); + else if (_progressMode == ProgressMode::ThreadBlocking) + _worker->startProgressThread(false); auto ep = _worker->createEndpointFromWorkerAddress(_worker->getAddress()); while (!ep->isAlive()) @@ -438,18 +445,25 @@ TEST_P(EndpointCancelTest, StoppingWaitCompletionThenCancel) requests.begin(), requests.end(), [](auto r) { ASSERT_FALSE(r->isCompleted()); }); } else { ASSERT_TRUE(requests[0]->isCompleted()); - ASSERT_FALSE(requests[1]->isCompleted()); + // In thread progress mode it's not possible to determine completion time + if (_progressMode != ProgressMode::ThreadBlocking && + _progressMode != ProgressMode::ThreadPolling) + ASSERT_FALSE(requests[1]->isCompleted()); } // Check no requests are being canceled ASSERT_EQ(ep->getCancelingSize(), 0); // Check which requests are inflight or completed - if (_rndv) + if (_rndv) { ASSERT_EQ(ep->getInflightSize(), requests.size()); - else - // Eager send request completes immediately, receive request still inflight - ASSERT_EQ(ep->getInflightSize(), 1); + } else { + // Eager send request completes immediately, receive request still inflight, except + // for thread progress mode where it's not possible to determine completion time + if (_progressMode != ProgressMode::ThreadBlocking && + _progressMode != ProgressMode::ThreadPolling) + ASSERT_EQ(ep->getInflightSize(), 1); + } }; // Check request statuses before stopping endpoint @@ -462,16 +476,31 @@ TEST_P(EndpointCancelTest, StoppingWaitCompletionThenCancel) // Check that requests 
statuses haven't changed checkAfterSubmit(); + bool cancelationComplete = false; + auto cancelInflightCallback = [&ep, &progressWorker, &cancelationComplete]() { + // Now that cancelation is complete, closing the endpoint is safe + auto close = ep->close(); + ASSERT_NE(close, nullptr); + while (!close->isCompleted()) + progressWorker(); + ASSERT_EQ(close->getStatus(), UCS_OK); + + cancelationComplete = true; + }; + // Cancel inflight requests - ep->cancelInflightRequests(); + ep->cancelInflightRequests(cancelInflightCallback); ASSERT_EQ(ep->getCancelingSize(), countIncomplete(requests)); ASSERT_EQ(ep->getInflightSize(), 0); - // Wait for canceling requests to complete - while (countIncomplete(requests) > 0) + // Wait for canceling requests to complete and `cancelInflightCallback` to run + while (!cancelationComplete) progressWorker(); - // Check all requests have been canceled + // `cancelInflightCallback` executed and the endpoint should be closed now. + ASSERT_FALSE(ep->isAlive()); + + // Check all requests have been canceled or completed std::for_each(requests.begin(), requests.end(), [](auto r) { auto status = r->getStatus(); ASSERT_TRUE(status == UCS_ERR_CANCELED || status == UCS_OK); @@ -493,23 +522,25 @@ TEST_P(EndpointCancelTest, StoppingWaitCompletionThenCancel) // Check no more tracked requests exist ASSERT_EQ(ep->getCancelingSize(), 0); ASSERT_EQ(ep->getInflightSize(), 0); - - auto close = ep->close(); - ASSERT_NE(close, nullptr); - while (!close->isCompleted()) - progressWorker(); - ASSERT_EQ(close->getStatus(), UCS_OK); } INSTANTIATE_TEST_SUITE_P(Eager, EndpointCancelTest, - Combine(Values(TransferType::Tag, TransferType::Am, TransferType::Stream), + Combine(Values(ProgressMode::Polling, + ProgressMode::Blocking, + ProgressMode::ThreadPolling, + ProgressMode::ThreadBlocking), + Values(TransferType::Tag, TransferType::Am, TransferType::Stream), Values(0, 1, 10), Values(false))); INSTANTIATE_TEST_SUITE_P(Rndv, EndpointCancelTest, - 
Combine(Values(TransferType::Tag, TransferType::Am, TransferType::Stream), + Combine(Values(ProgressMode::Polling, + ProgressMode::Blocking, + ProgressMode::ThreadPolling, + ProgressMode::ThreadBlocking), + Values(TransferType::Tag, TransferType::Am, TransferType::Stream), Values(10485760, 104857600), Values(true))); From 82e6f5fbb82ce8516865303fe2a240d4d9354fe8 Mon Sep 17 00:00:00 2001 From: Peter Andreas Entschev Date: Thu, 16 May 2024 01:58:46 -0700 Subject: [PATCH 16/39] Use `scoped_lock`s in `ucxx::InflightRequests` --- cpp/src/inflight_requests.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/cpp/src/inflight_requests.cpp b/cpp/src/inflight_requests.cpp index a120c1cb4..e28271642 100644 --- a/cpp/src/inflight_requests.cpp +++ b/cpp/src/inflight_requests.cpp @@ -16,7 +16,7 @@ size_t InflightRequests::size() { return _trackedRequests->_inflight->size(); } void InflightRequests::insert(std::shared_ptr request) { - std::lock_guard lock(_mutex); + std::scoped_lock lock(_mutex); _trackedRequests->_inflight->insert({request.get(), request}); } @@ -24,11 +24,11 @@ void InflightRequests::insert(std::shared_ptr request) void InflightRequests::merge(TrackedRequestsPtr trackedRequests) { { - std::lock_guard lock(_mutex); + std::scoped_lock lock(_mutex); _trackedRequests->_inflight->merge(*(trackedRequests->_inflight)); } { - std::lock_guard lock(_cancelMutex); + std::scoped_lock lock(_cancelMutex); _trackedRequests->_canceling->merge(*(trackedRequests->_canceling)); } } From 15dbb91eeef5b835a6e8f978154c01e2d8266490 Mon Sep 17 00:00:00 2001 From: Peter Andreas Entschev Date: Thu, 16 May 2024 02:28:29 -0700 Subject: [PATCH 17/39] Print exception in callback warning message --- cpp/src/inflight_requests.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cpp/src/inflight_requests.cpp b/cpp/src/inflight_requests.cpp index e28271642..0d8b205ae 100644 --- a/cpp/src/inflight_requests.cpp +++ b/cpp/src/inflight_requests.cpp @@ 
-91,7 +91,7 @@ void InflightRequests::remove(const Request* const request, _cancelMutex.unlock(); return; } catch (const std::exception& e) { - ucxx_warn("Exception in callback"); + ucxx_warn("Exception in callback: %s", e.what()); _cancelMutex.unlock(); throw(e); } @@ -196,7 +196,7 @@ size_t InflightRequests::cancelAll(GenericCallbackUserFunction cancelInflightCal _cancelMutex.unlock(); return total; } catch (const std::exception& e) { - ucxx_warn("Exception in callback"); + ucxx_warn("Exception in callback: %s", e.what()); _cancelMutex.unlock(); throw(e); } From 621abc46b15a565d9940008f33fced0d9b381e36 Mon Sep 17 00:00:00 2001 From: Peter Andreas Entschev Date: Thu, 16 May 2024 02:29:36 -0700 Subject: [PATCH 18/39] Remove `ucxx::InflightRequests::dropCanceled()` --- cpp/include/ucxx/inflight_requests.h | 9 --------- cpp/src/inflight_requests.cpp | 24 ------------------------ 2 files changed, 33 deletions(-) diff --git a/cpp/include/ucxx/inflight_requests.h b/cpp/include/ucxx/inflight_requests.h index 0a938f6d8..84c02eec1 100644 --- a/cpp/include/ucxx/inflight_requests.h +++ b/cpp/include/ucxx/inflight_requests.h @@ -71,15 +71,6 @@ class InflightRequests { std::mutex _cancelMutex{}; ///< Mutex to allow cancelation and prevent removing requests simultaneously - /** - * @brief Drop references to requests that completed cancelation. - * - * Drops references to requests that completed cancelation and stop tracking them. - * - * @returns The number of requests that have completed cancelation since last call. - */ - size_t dropCanceled(); - public: /** * @brief Default constructor. 
diff --git a/cpp/src/inflight_requests.cpp b/cpp/src/inflight_requests.cpp index 0d8b205ae..cf60548cb 100644 --- a/cpp/src/inflight_requests.cpp +++ b/cpp/src/inflight_requests.cpp @@ -99,30 +99,8 @@ void InflightRequests::remove(const Request* const request, } } -size_t InflightRequests::dropCanceled() -{ - size_t removed = 0; - - { - std::scoped_lock lock{_cancelMutex}; - for (auto it = _trackedRequests->_canceling->begin(); - it != _trackedRequests->_canceling->end();) { - auto request = it->second; - if (request != nullptr && request->getStatus() != UCS_INPROGRESS) { - it = _trackedRequests->_canceling->erase(it); - ++removed; - } else { - ++it; - } - } - } - - return removed; -} - size_t InflightRequests::getCancelingSize() { - // dropCanceled(); size_t cancelingSize = 0; { std::scoped_lock lock{_cancelMutex}; @@ -173,8 +151,6 @@ size_t InflightRequests::cancelAll(GenericCallbackUserFunction cancelInflightCal } _trackedRequests->_inflight->clear(); - // dropCanceled(); - break; } } From 981469ae127a9bb7032616c22c29b850bd756ca4 Mon Sep 17 00:00:00 2001 From: Peter Andreas Entschev Date: Thu, 16 May 2024 02:35:29 -0700 Subject: [PATCH 19/39] `ucxx::InflightRequests` cleanup --- cpp/src/inflight_requests.cpp | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/cpp/src/inflight_requests.cpp b/cpp/src/inflight_requests.cpp index cf60548cb..1a4a64564 100644 --- a/cpp/src/inflight_requests.cpp +++ b/cpp/src/inflight_requests.cpp @@ -83,7 +83,7 @@ void InflightRequests::remove(const Request* const request, _mutex.unlock(); try { if (cancelInflightCallback && trackedRequestsCount == 0) { - ucxx_debug("ucxx::InflightRequests::%s: %p, calling user cancel inflight callback", + ucxx_trace("ucxx::InflightRequests::%s: %p, calling user cancel inflight callback", __func__, this); cancelInflightCallback(); @@ -126,9 +126,7 @@ size_t InflightRequests::cancelAll(GenericCallbackUserFunction cancelInflightCal size_t total = 0; while (true) { - /** - * -1: 
both mutexes were locked. - */ + // -1: both mutexes were locked. if (std::try_lock(_cancelMutex, _mutex) == -1) { auto total = _trackedRequests->_inflight->size(); @@ -136,8 +134,6 @@ size_t InflightRequests::cancelAll(GenericCallbackUserFunction cancelInflightCal // previously released. if (total == 0) break; - // ucxx_debug("ucxx::InflightRequests::%s, canceling %lu requests", __func__, total); - for (auto& r : *_trackedRequests->_inflight) { auto request = r.second; if (request != nullptr) { @@ -165,7 +161,7 @@ size_t InflightRequests::cancelAll(GenericCallbackUserFunction cancelInflightCal _mutex.unlock(); try { if (cancelInflightCallback && trackedRequestsCount == 0) { - ucxx_debug( + ucxx_trace( "ucxx::InflightRequests::%s: %p, calling user cancel inflight callback", __func__, this); cancelInflightCallback(); } From dfdcb90adc16b5b37743d084d6a826e887b42437 Mon Sep 17 00:00:00 2001 From: Peter Andreas Entschev Date: Thu, 16 May 2024 03:54:00 -0700 Subject: [PATCH 20/39] Add missing rejection tests and improve comments --- cpp/tests/endpoint.cpp | 282 +++++------------------------------------ 1 file changed, 33 insertions(+), 249 deletions(-) diff --git a/cpp/tests/endpoint.cpp b/cpp/tests/endpoint.cpp index 1991be8b4..a4d3de6fc 100644 --- a/cpp/tests/endpoint.cpp +++ b/cpp/tests/endpoint.cpp @@ -17,6 +17,7 @@ #include #include "include/utils.h" +#include "ucxx/exception.h" #include "ucxx/typedefs.h" namespace { @@ -147,252 +148,7 @@ TEST_F(EndpointTest, IsAlive) ASSERT_FALSE(ep->isAlive()); } -// TEST_F(EndpointTest, CancelCloseNonBlocking) -// { -// auto ep = _worker->createEndpointFromWorkerAddress(_worker->getAddress()); -// _worker->progress(); - -// ASSERT_TRUE(ep->isAlive()); - -// std::vector send(100 * 1024 * 1024, 42); -// std::vector recv(send.size(), 0); - -// // Submit and wait for transfers to complete -// std::vector> requests; -// requests.push_back(ep->tagSend(send.data(), send.size() * sizeof(int), ucxx::Tag{0})); -// requests.push_back( -// 
ep->tagRecv(recv.data(), recv.size() * sizeof(int), ucxx::Tag{0}, ucxx::TagMaskFull)); -// // waitRequests(_worker, requests, _progressWorker); - -// // copyResults(); - -// // Assert data correctness -// // ASSERT_THAT(_recv[0], ContainerEq(_send[0])); - -// size_t inflightCount = 0; -// for (const auto& req : requests) { -// inflightCount += !req->isCompleted(); -// } -// std::cout << "inflightCount: " << inflightCount << std::endl; - -// size_t tmpCount = 0; -// do { -// tmpCount = 0; -// for (const auto& req : requests) { -// tmpCount += req->isCompleted(); -// req->cancel(); -// } -// _worker->progress(); -// std::cout << "tmpCount: " << tmpCount << std::endl; - -// std::this_thread::sleep_for(std::chrono::milliseconds(1000)); -// } while (tmpCount != 2); - -// auto canceling = ep->cancelInflightRequests(); -// // ASSERT_EQ(canceling, inflightCount); -// std::cout << canceling << std::endl; - -// while (ep->getCancelingSize() > 0) { -// std::cout << ep->getCancelingSize() << std::endl; -// _worker->progress(); - -// std::this_thread::sleep_for(std::chrono::milliseconds(1000)); - -// for (const auto& req : requests) -// std::cout << req->isCompleted() << " "; -// std::cout << std::endl; -// } - -// auto close = ep->close(); -// while (!close->isCompleted()) -// _worker->progress(); -// ASSERT_EQ(close->getStatus(), UCS_OK); -// } - -TEST_F(EndpointTest, CloseNonBlockingCancel) -{ - auto progressWorker = getProgressFunction(_worker, ProgressMode::Blocking); - auto progressRemoteWorker = getProgressFunction(_remoteWorker, ProgressMode::Blocking); - auto progressAllWorkers = [&progressWorker, progressRemoteWorker]() { - progressWorker(); - progressRemoteWorker(); - }; - - auto ep = _worker->createEndpointFromWorkerAddress(_remoteWorker->getAddress()); - while (!ep->isAlive()) - progressWorker(); - - auto remoteEp = _worker->createEndpointFromWorkerAddress(_worker->getAddress()); - while (!remoteEp->isAlive()) - progressRemoteWorker(); - - std::vector send(100 * 
1024 * 1024, 42); - std::vector recv(send.size(), 0); - - // Submit and wait for transfers to complete - std::vector> requests, remoteRequests; - requests.push_back(ep->tagSend(send.data(), send.size() * sizeof(int), ucxx::Tag{0})); - // requests.push_back(remoteEp->tagRecv(recv.data(), recv.size() * sizeof(int), ucxx::Tag{0}, - // ucxx::TagMaskFull)); - requests.push_back(_remoteWorker->tagRecv( - recv.data(), recv.size() * sizeof(int), ucxx::Tag{0}, ucxx::TagMaskFull)); - - // copyResults(); - - // Assert data correctness - // ASSERT_THAT(_recv[0], ContainerEq(_send[0])); - - // Endpoints that are _receiving_ may close before the request completes - waitRequests(_remoteWorker, requests, progressAllWorkers); - auto closeRemote = remoteEp->close(); - // waitRequests(_remoteWorker, requests, progressRemoteWorker); - waitSingleRequest(closeRemote, progressRemoteWorker); - ASSERT_FALSE(remoteEp->isAlive()); - - // Endpoints that are _sending_ may _not_ close before the request completes - auto close = ep->close(); - waitRequests(_worker, requests, progressWorker); - waitSingleRequest(close, progressRemoteWorker); - ASSERT_FALSE(ep->isAlive()); - - // requests.clear(); - // requests.push_back(ep->tagRecv(recv.data(), recv.size() * sizeof(int), ucxx::Tag{0}, - // ucxx::TagMaskFull)); waitRequests(_worker, requests, progressWorker); - - ASSERT_THAT(recv, ContainerEq(send)); - - // size_t inflightCount = 0; - // for (const auto& req : requests) { - // inflightCount += !req->isCompleted(); - // } - - // // ASSERT_EQ(canceling, inflightCount); - // std::cout << canceling << std::endl; - - // while (ep->getCancelingSize() > 0) { - // std::cout << ep->getCancelingSize() << std::endl; - // _worker->progress(); - - // std::this_thread::sleep_for(std::chrono::milliseconds(1000)); - - // for (const auto& req : requests) - // std::cout << req->isCompleted() << " "; - // std::cout << std::endl; - // } - - // ASSERT_EQ(close->getStatus(), UCS_OK); - - // auto canceling = 
ep->cancelInflightRequests(); -} - -// TEST_F(EndpointTest, CancelTmp) -// { -// auto progressWorker = getProgressFunction(_worker, ProgressMode::Blocking); -// auto progressRemoteWorker = getProgressFunction(_remoteWorker, ProgressMode::Blocking); -// auto progressAllWorkers = [&progressWorker, progressRemoteWorker]() { -// progressWorker(); -// progressRemoteWorker(); -// }; - -// auto ep = _worker->createEndpointFromWorkerAddress(_remoteWorker->getAddress()); -// while (!ep->isAlive()) -// progressWorker(); - -// auto remoteEp = _worker->createEndpointFromWorkerAddress(_worker->getAddress()); -// while (!remoteEp->isAlive()) -// progressRemoteWorker(); - -// std::vector send(100 * 1024 * 1024, 42); -// std::vector recv(send.size(), 0); - -// // Submit and wait for transfers to complete -// std::vector> requests, remoteRequests; -// requests.push_back(ep->tagSend(send.data(), send.size() * sizeof(int), ucxx::Tag{0})); -// remoteRequests.push_back(remoteEp->tagRecv(recv.data(), recv.size() * sizeof(int), -// ucxx::Tag{0}, ucxx::TagMaskFull)); -// // requests.push_back(_remoteWorker->tagRecv( -// // recv.data(), recv.size() * sizeof(int), ucxx::Tag{0}, ucxx::TagMaskFull)); - -// // copyResults(); - -// // Assert data correctness -// // ASSERT_THAT(_recv[0], ContainerEq(_send[0])); - -// // Endpoints that are _receiving_ may close before the request completes -// ASSERT_FALSE(requests[0]->isCompleted()); -// ASSERT_FALSE(remoteRequests[0]->isCompleted()); - -// // ep->cancelInflightRequests(); -// // remoteEp->cancelInflightRequests(); - -// // ASSERT_FALSE(requests[0]->isCompleted()); -// // ASSERT_FALSE(remoteRequests[0]->isCompleted()); - -// while (!requests[0]->isCompleted() || ! 
remoteRequests[0]->isCompleted()) { -// std::cout << "progress: " << requests[0]->isCompleted() << " " << -// remoteRequests[0]->isCompleted() << std::endl; progressWorker(); progressRemoteWorker(); -// std::this_thread::sleep_for(std::chrono::seconds(1)); -// } - -// ASSERT_TRUE(requests[0]->isCompleted()); -// ASSERT_TRUE(remoteRequests[0]->isCompleted()); -// ASSERT_EQ(requests[0]->getStatus(), UCS_OK); -// ASSERT_EQ(remoteRequests[0]->isCompleted(), UCS_OK); - -// ASSERT_THAT(recv, ContainerEq(send)); -// } - -TEST_F(EndpointTest, CancelSingleTmp) -{ - // auto progressWorker = getProgressFunction(_worker, ProgressMode::Blocking); - auto progressWorker = getProgressFunction(_worker, ProgressMode::Polling); - - auto ep = _worker->createEndpointFromWorkerAddress(_worker->getAddress()); - while (!ep->isAlive()) - progressWorker(); - - wireup(ep, ep, progressWorker); - - std::vector send(100 * 1024 * 1024, 42); - std::vector recv(send.size(), 0); - - // Submit and wait for transfers to complete - std::vector> requests; - requests.push_back(ep->tagSend(send.data(), send.size() * sizeof(int), ucxx::Tag{0})); - requests.push_back( - ep->tagRecv(recv.data(), recv.size() * sizeof(int), ucxx::Tag{0}, ucxx::TagMaskFull)); - - ASSERT_FALSE(requests[0]->isCompleted()); - ASSERT_FALSE(requests[1]->isCompleted()); - - progressWorker(); - ep->cancelInflightRequests(); - - // ASSERT_FALSE(requests[0]->isCompleted()); - // ASSERT_FALSE(requests[1]->isCompleted()); - - while (!requests[0]->isCompleted() || !requests[1]->isCompleted()) { - std::cout << "progress1: " << requests[0]->isCompleted() << " " << requests[1]->isCompleted() - << std::endl; - std::cout << "progress2: " << requests[0]->getStatus() << " " << requests[1]->getStatus() - << std::endl; - progressWorker(); - std::this_thread::sleep_for(std::chrono::seconds(1)); - } - std::cout << "progress3: " << requests[0]->isCompleted() << " " << requests[1]->isCompleted() - << std::endl; - std::cout << "progress4: " << 
requests[0]->getStatus() << " " << requests[1]->getStatus() - << std::endl; - - ASSERT_TRUE(requests[0]->isCompleted()); - ASSERT_TRUE(requests[1]->isCompleted()); - ASSERT_EQ(requests[0]->getStatus(), UCS_OK); - ASSERT_EQ(requests[1]->getStatus(), UCS_OK); - - ASSERT_THAT(recv, ContainerEq(send)); -} - -TEST_F(EndpointTest, StoppingRejection) +TEST_F(EndpointTest, StoppingRejectRequests) { auto progressWorker = getProgressFunction(_worker, ProgressMode::Blocking); @@ -415,7 +171,32 @@ TEST_F(EndpointTest, StoppingRejection) EXPECT_THROW(ep->streamRecv(tmp.data(), tmp.size() * sizeof(int)), ucxx::RejectedError); EXPECT_THROW(ep->streamSend(tmp.data(), tmp.size() * sizeof(int)), ucxx::RejectedError); - // TODO: tagMultiSend, memPut, memGet + { + auto memoryHandle = _context->createMemoryHandle(tmp.size() * sizeof(int), nullptr); + + auto localRemoteKey = memoryHandle->createRemoteKey(); + auto serializedRemoteKey = localRemoteKey->serialize(); + auto remoteKey = ucxx::createRemoteKeyFromSerialized(ep, serializedRemoteKey); + + std::vector> requests; + EXPECT_THROW(ep->memPut(tmp.data(), tmp.size() * sizeof(int), remoteKey), ucxx::RejectedError); + EXPECT_THROW( + ep->memPut( + tmp.data(), tmp.size() * sizeof(int), remoteKey->getBaseAddress(), remoteKey->getHandle()), + ucxx::RejectedError); + EXPECT_THROW(ep->memGet(tmp.data(), tmp.size() * sizeof(int), remoteKey), ucxx::RejectedError); + EXPECT_THROW( + ep->memGet( + tmp.data(), tmp.size() * sizeof(int), remoteKey->getBaseAddress(), remoteKey->getHandle()), + ucxx::RejectedError); + } + + { + std::vector buffers{tmp.data()}; + std::vector sizes{tmp.size()}; + std::vector isCUDA{false}; + EXPECT_THROW(ep->tagMultiSend(buffers, sizes, isCUDA, ucxx::Tag{0}), ucxx::RejectedError); + } } TEST_P(EndpointCancelTest, StoppingWaitCompletionThenCancel) @@ -423,19 +204,22 @@ TEST_P(EndpointCancelTest, StoppingWaitCompletionThenCancel) if (_transferType == TransferType::Stream && _messageSize == 0) GTEST_SKIP() << "Stream 
messages of size 0 are not supported."; + // Get appropriate progress worker function depending on selected mode auto progressWorker = getProgressFunction(_worker, _progressMode); if (_progressMode == ProgressMode::ThreadPolling) _worker->startProgressThread(true); else if (_progressMode == ProgressMode::ThreadBlocking) _worker->startProgressThread(false); + // Create endpoint auto ep = _worker->createEndpointFromWorkerAddress(_worker->getAddress()); while (!ep->isAlive()) progressWorker(); + // Perform endpoint wireup wireup(ep, ep, progressWorker); - // Submit and wait for transfers to complete + // Submit transfer requests auto requests = buildPair(ep, ep); auto checkAfterSubmit = [this, &requests, &ep]() { @@ -478,7 +262,7 @@ TEST_P(EndpointCancelTest, StoppingWaitCompletionThenCancel) bool cancelationComplete = false; auto cancelInflightCallback = [&ep, &progressWorker, &cancelationComplete]() { - // Now that cancelation is complete, closing the endpoint is safe + // No more inflight or canceling requests, closing the endpoint is safe auto close = ep->close(); ASSERT_NE(close, nullptr); while (!close->isCompleted()) From cef478b43667851af3d5b814740d4e58ad8bb47c Mon Sep 17 00:00:00 2001 From: Peter Andreas Entschev Date: Thu, 16 May 2024 04:16:01 -0700 Subject: [PATCH 21/39] Update `ucxx::Endpoint::close()` docstring --- cpp/include/ucxx/endpoint.h | 40 ++++++++++++++++++++++++++----------- 1 file changed, 28 insertions(+), 12 deletions(-) diff --git a/cpp/include/ucxx/endpoint.h b/cpp/include/ucxx/endpoint.h index 46fcb6ad6..e9ca9158a 100644 --- a/cpp/include/ucxx/endpoint.h +++ b/cpp/include/ucxx/endpoint.h @@ -711,7 +711,7 @@ class Endpoint : public Component { const std::vector& size, const std::vector& isCUDA, const Tag tag, - const bool enablePythonFuture); + const bool enablePythonFuture = false); /** * @brief Enqueue a multi-buffer tag receive operation. 
@@ -737,7 +737,7 @@ class Endpoint : public Component { */ std::shared_ptr tagMultiRecv(const Tag tag, const TagMask tagMask, - const bool enablePythonFuture); + const bool enablePythonFuture = false); /** * @brief Enqueue a flush operation. @@ -793,7 +793,23 @@ class Endpoint : public Component { * * Enqueue a non-blocking endpoint close operation, which will close the endpoint without * requiring to destroy the object. This may be useful when other - * `std::shared_ptr` objects are still alive, such as inflight transfers. + * `std::shared_ptr` objects are still alive, such as inflight transfers, + * or the user wants to have more control over cancelation and closing order. + * + * @warning Unlike its `closeBlocking()` counterpart, this method does not cancel any + * inflight requests prior to submitting the UCP close request. Before scheduling the + * endpoint close request, the caller is advised to first call `stop()` to prevent new + * requests that require an active endpoint from being registered and once `stop()` is + * called, the user may call `cancelInflightRequests()` specifying a callback that can + * be used to submit a `close()` request, or may check for the number of inflight and + * canceling requests via `getInflightSize()` and `getCancelingSize()` methods, + * respectively, and issue the non-blocking `close()` of the worker once both return + * `0` or after a certain period of time has elapsed and the application cannot wait + * anymore for their completion. Note that `cancelInflightRequests()` callback is not + * guaranteed to be called, nor are `getCancelingSize()` and `getInflightSize()` + * guaranteed to go to `0` depending on the requests being handled, and thus the user is + * advised to provide a forceful termination mechanism in case the requests can never + * complete. * * This method returns a `std::shared` that can be later awaited and * checked for errors. 
This is a non-blocking operation, and the status of closing the @@ -814,11 +830,6 @@ class Endpoint : public Component { * Python future is requested, the Python application must then await on this future to * ensure the close operation has completed. Requires UCXX Python support. * - * @warning Unlike its `closeBlocking()` counterpart, this method does not cancel any - * inflight requests prior to submitting the UCP close request. Before scheduling the - * endpoint close request, the caller must first call `cancelInflightRequests()` and - * progress the worker until `getCancelingSize()` returns `0`. - * * @param[in] enablePythonFuture whether a python future should be created and * subsequently notified. * @param[in] callbackFunction user-defined callback function to call upon completion. @@ -864,10 +875,15 @@ class Endpoint : public Component { * be issued that require the endpoint to complete. This method is useful when the user * needs control of the non-blocking closing process with `close()`, and thus allow * requests to complete before issuing the close request. Once this is called, the user - * may check for the number of inflight and canceling requests via `getInflightSize()` - * and `getCancelingSize()` methods, respectively, and issue the non-blocking close of - * the worker once both return `0` or after a certain period of time has elapsed and - * the application cannot wait anymore for their completion. + * may call `cancelInflightRequests()` specifying a callback that can be used to submit + * a `close()` request, or may check for the number of inflight and canceling requests + * via `getInflightSize()` and `getCancelingSize()` methods, respectively, and issue the + * non-blocking `close()` of the worker once both return `0` or after a certain period of + * time has elapsed and the application cannot wait anymore for their completion. 
Note + * that `cancelInflightRequests()` callback is not guaranteed to be called, nor are + * `getCancelingSize()` and `getInflightSize()` guaranteed to go to `0` depending on the + * requests being handled, and thus the user is advised to provide a forceful termination + * mechanism in case the requests can never complete. * * After this is called certain requests are not anymore accepted because they require a * valid endpoint to complete. The requests that are not accepted are: From 5a72fe1854c0ee937bc7fd7c0dd040730bab21ee Mon Sep 17 00:00:00 2001 From: Peter Andreas Entschev Date: Thu, 16 May 2024 04:52:04 -0700 Subject: [PATCH 22/39] Update `ucxx::Requests` comments and logging --- cpp/src/request.cpp | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/cpp/src/request.cpp b/cpp/src/request.cpp index 95e90361c..8de2ab5ad 100644 --- a/cpp/src/request.cpp +++ b/cpp/src/request.cpp @@ -64,10 +64,9 @@ Request::Request(std::shared_ptr endpointOrWorker, Request::~Request() { if (UCS_PTR_IS_PTR(_request)) { - // ucxx_warn("ucxx::Request (%s) freeing: %p", _operationName.c_str(), _request); + ucxx_warn("ucxx::Request (%s) freeing: %p", _operationName.c_str(), _request); ucp_request_free(_request); } - // ucxx_warn("ucxx::Request destroyed (%s): %p", _operationName.c_str(), this); ucxx_trace("ucxx::Request destroyed (%s): %p", _operationName.c_str(), this); } @@ -87,12 +86,17 @@ void Request::cancel() } else { if (_request != nullptr) { ucs_status_t status = UCS_PTR_STATUS(_request); - // ucxx_warn("status1: %d (%s)", status, ucs_status_string(status)); ucxx_trace_req_f(_ownerString.c_str(), this, _request, _operationName.c_str(), "canceling"); ucp_request_cancel(_worker->getHandle(), _request); status = UCS_PTR_STATUS(_request); - // ucxx_warn("status2: %d (%s)", status, ucs_status_string(status)); - // TODO: Cancel all send operations? amSend/streamSend probably do not need it. 
+ + /** + * Tag send requests cannot be canceled: https://github.com/openucx/ucx/issues/1162 + * This can be problematic for unmatched rendezvous tag send messages as it would + * otherwise not complete cancelation, so we forcefully "cancel" the requests which + * ultimately leads it reaching `ucp_request_free`. This currently causes the UCX + * warnings "was not returned to mpool ucp_requests", which is likely a UCX bug. + */ if (_operationName == "tagSend") { setStatus(UCS_ERR_CANCELED); } } } @@ -157,7 +161,6 @@ void Request::callback(void* request, ucs_status_t status) ucs_status_string(status)); if (UCS_PTR_IS_PTR(_request)) { - // ucxx_warn("ucxx::Request (%s) callback: %p %p", _operationName.c_str(), _request, request); ucp_request_free(request); _request = nullptr; } From e0206f0e71eec935c9c2f185a9a5b3c9ddf7b404 Mon Sep 17 00:00:00 2001 From: Peter Andreas Entschev Date: Thu, 16 May 2024 08:35:37 -0700 Subject: [PATCH 23/39] Add missing callback in `ucxx::Endpoint::removeInflightRequest` --- cpp/src/endpoint.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/src/endpoint.cpp b/cpp/src/endpoint.cpp index 656edffb3..a86118c22 100644 --- a/cpp/src/endpoint.cpp +++ b/cpp/src/endpoint.cpp @@ -353,7 +353,7 @@ std::shared_ptr Endpoint::registerInflightRequest(std::shared_ptrremove(request); + _inflightRequests->remove(request, _cancelInflightCallback); } size_t Endpoint::cancelInflightRequests(GenericCallbackUserFunction callback) From 65671d282ce1bdae422d375568b1a5ea52de3a23 Mon Sep 17 00:00:00 2001 From: Peter Andreas Entschev Date: Fri, 17 May 2024 00:04:07 -0700 Subject: [PATCH 24/39] Add `ucxx::Endpoint` reference back to `ucxx::EndpointErrorCallback` --- cpp/include/ucxx/endpoint.h | 3 ++- cpp/src/endpoint.cpp | 8 +++++--- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/cpp/include/ucxx/endpoint.h b/cpp/include/ucxx/endpoint.h index 11ab77912..b992af0e1 100644 --- a/cpp/include/ucxx/endpoint.h +++ 
b/cpp/include/ucxx/endpoint.h @@ -61,7 +61,8 @@ struct ErrorCallbackData { nullptr}; ///< Argument to be passed to close callback std::shared_ptr worker{nullptr}; ///< Worker the endpoint has been created from - ErrorCallbackData(std::shared_ptr inflightRequests, + ErrorCallbackData(std::shared_ptr endpoint, + std::shared_ptr inflightRequests, std::shared_ptr worker); ErrorCallbackData() = delete; diff --git a/cpp/src/endpoint.cpp b/cpp/src/endpoint.cpp index a86118c22..65d1f36ad 100644 --- a/cpp/src/endpoint.cpp +++ b/cpp/src/endpoint.cpp @@ -34,9 +34,10 @@ namespace ucxx { -ErrorCallbackData::ErrorCallbackData(std::shared_ptr inflightRequests, +ErrorCallbackData::ErrorCallbackData(std::shared_ptr endpoint, + std::shared_ptr inflightRequests, std::shared_ptr worker) - : inflightRequests(inflightRequests), worker(worker) + : endpoint(endpoint), inflightRequests(inflightRequests), worker(worker) { } @@ -68,7 +69,8 @@ Endpoint::Endpoint(std::shared_ptr workerOrListener, bool endpointErr void Endpoint::create(ucp_ep_params_t* params) { auto worker = ::ucxx::getWorker(_parent); - _callbackData = std::make_unique(_inflightRequests, worker); + _callbackData = std::make_unique( + std::dynamic_pointer_cast(shared_from_this()), _inflightRequests, worker); if (_endpointErrorHandling) { params->err_mode = UCP_ERR_HANDLING_MODE_PEER; From 71a8f1274765feb8939cbf567327962d122535ad Mon Sep 17 00:00:00 2001 From: Peter Andreas Entschev Date: Fri, 17 May 2024 03:25:24 -0700 Subject: [PATCH 25/39] Issue `ucxx::Request::cancel()` in worker thread when active --- cpp/include/ucxx/request.h | 15 +++++++++++++++ cpp/include/ucxx/utils/callback_notifier.h | 2 ++ cpp/src/request.cpp | 18 +++++++++++++++++- 3 files changed, 34 insertions(+), 1 deletion(-) diff --git a/cpp/include/ucxx/request.h b/cpp/include/ucxx/request.h index c164ecf67..476e44c63 100644 --- a/cpp/include/ucxx/request.h +++ b/cpp/include/ucxx/request.h @@ -16,6 +16,7 @@ #include #include #include +#include #define 
ucxx_trace_req_f(_owner, _req, _handle, _name, _message, ...) \ ucxx_trace_req("ucxx::Request: %p on %s, UCP handle: %p, op: %s, " _message, \ @@ -54,6 +55,10 @@ class Request : public Component { "request_undefined"}; ///< Human-readable operation name, mostly used for log messages std::recursive_mutex _mutex{}; ///< Mutex to prevent checking status while it's being set bool _enablePythonFuture{true}; ///< Whether Python future is enabled for this request + DelayedSubmissionCallbackType + _cancelCallback{}; ///< Callback function to use when canceling in thread progress mode. + utils::CallbackNotifier _cancelCallbackNotifier{}; ///< Callback notifier to use when canceling + ///< in thread progress mode. /** * @brief Protected constructor of an abstract `ucxx::Request`. @@ -100,6 +105,16 @@ class Request : public Component { */ void setStatus(ucs_status_t status); + private: + /** + * @brief Implementation of the request cancelatron. + * + * Cancel the request, called by `cancel()` which will submit it for execution from the + * worker progress thread when active as it is unsafe to do so from the application thread + * given it `ucp_request_cancel` requires the UCS spinlock. When no worker progress thread + */ + void cancelImpl(); + public: Request() = delete; Request(const Request&) = delete; diff --git a/cpp/include/ucxx/utils/callback_notifier.h b/cpp/include/ucxx/utils/callback_notifier.h index 354210828..7d9188d47 100644 --- a/cpp/include/ucxx/utils/callback_notifier.h +++ b/cpp/include/ucxx/utils/callback_notifier.h @@ -2,6 +2,8 @@ * SPDX-FileCopyrightText: Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. 
* SPDX-License-Identifier: BSD-3-Clause */ +#pragma once + #include #include #include diff --git a/cpp/src/request.cpp b/cpp/src/request.cpp index 8de2ab5ad..39c2471a6 100644 --- a/cpp/src/request.cpp +++ b/cpp/src/request.cpp @@ -63,6 +63,7 @@ Request::Request(std::shared_ptr endpointOrWorker, Request::~Request() { + if (_cancelCallback != nullptr) { _cancelCallbackNotifier.wait(1000000000 /* 1s */); } if (UCS_PTR_IS_PTR(_request)) { ucxx_warn("ucxx::Request (%s) freeing: %p", _operationName.c_str(), _request); ucp_request_free(_request); @@ -70,7 +71,7 @@ Request::~Request() ucxx_trace("ucxx::Request destroyed (%s): %p", _operationName.c_str(), this); } -void Request::cancel() +void Request::cancelImpl() { std::lock_guard lock(_mutex); if (_status == UCS_INPROGRESS) { @@ -109,6 +110,21 @@ void Request::cancel() _status, ucs_status_string(_status)); } + + _cancelCallback = nullptr; +} + +void Request::cancel() +{ + if (_worker->isProgressThreadRunning()) { + _cancelCallback = [this]() { + cancelImpl(); + _cancelCallbackNotifier.set(); + }; + _worker->registerGenericPre(_cancelCallback); + } else { + cancelImpl(); + } } ucs_status_t Request::getStatus() From 21dd0a98ef05a915615cf31679b60594802cc058 Mon Sep 17 00:00:00 2001 From: Peter Andreas Entschev Date: Fri, 17 May 2024 10:42:52 -0700 Subject: [PATCH 26/39] Remove inflight request if its status has been already set --- cpp/include/ucxx/request.h | 13 ++++++++++++- cpp/src/request.cpp | 15 +++++++++++++-- 2 files changed, 25 insertions(+), 3 deletions(-) diff --git a/cpp/include/ucxx/request.h b/cpp/include/ucxx/request.h index 476e44c63..05e4c2a47 100644 --- a/cpp/include/ucxx/request.h +++ b/cpp/include/ucxx/request.h @@ -107,7 +107,18 @@ class Request : public Component { private: /** - * @brief Implementation of the request cancelatron. + * @brief Remove reference to request from endpoint and worker. + * + * Remove the reference to the request from the endpoint and worker. 
This should be called + * when a request has completed and the parent `ucxx::Endpoint` or `ucxx::Worker` does not + * need to keep track of it anymore. This is called during `setStatus()` and also during + * `cancel()` in case the request was scheduled for cancelation while it completed, which + * may have duplicated the canceling request tracking. + */ + void removeInflightRequest(); + + /** + * @brief Implementation of the request cancelation. * * Cancel the request, called by `cancel()` which will submit it for execution from the * worker progress thread when active as it is unsafe to do so from the application thread diff --git a/cpp/src/request.cpp b/cpp/src/request.cpp index 39c2471a6..de7182514 100644 --- a/cpp/src/request.cpp +++ b/cpp/src/request.cpp @@ -71,6 +71,12 @@ Request::~Request() ucxx_trace("ucxx::Request destroyed (%s): %p", _operationName.c_str(), this); } +void Request::removeInflightRequest() +{ + if (_endpoint != nullptr) _endpoint->removeInflightRequest(this); + _worker->removeInflightRequest(this); +} + void Request::cancelImpl() { std::lock_guard lock(_mutex); @@ -109,6 +115,12 @@ void Request::cancelImpl() "already completed with status: %d (%s)", _status, ucs_status_string(_status)); + + /** + * Ensure the request is removed from the parent in case it got re-registered while it + * was completing. 
+ */ + removeInflightRequest(); } _cancelCallback = nullptr; @@ -233,8 +245,7 @@ void Request::setStatus(ucs_status_t status) { std::lock_guard lock(_mutex); - if (_endpoint != nullptr) _endpoint->removeInflightRequest(this); - _worker->removeInflightRequest(this); + removeInflightRequest(); ucxx_trace_req_f(_ownerString.c_str(), this, From 6d113a3d45cfd2caa1a800b3a5037485d35723ca Mon Sep 17 00:00:00 2001 From: Peter Andreas Entschev Date: Fri, 17 May 2024 13:49:56 -0700 Subject: [PATCH 27/39] Keep track of `ucxx::Request` upon `findAndRemove()` return --- cpp/src/inflight_requests.cpp | 25 +++++++++++++++++-------- 1 file changed, 17 insertions(+), 8 deletions(-) diff --git a/cpp/src/inflight_requests.cpp b/cpp/src/inflight_requests.cpp index 1a4a64564..af1770dbc 100644 --- a/cpp/src/inflight_requests.cpp +++ b/cpp/src/inflight_requests.cpp @@ -33,22 +33,26 @@ void InflightRequests::merge(TrackedRequestsPtr trackedRequests) } } -static void findAndRemove(InflightRequestsMap* requestsMap, const Request* const request) +static InflightRequestsMapPtr findAndRemove(InflightRequestsMap* requestsMap, + const Request* const request) { - auto search = requestsMap->find(request); - decltype(search->second) tmpRequest; + auto removed = std::make_unique(); + auto search = requestsMap->find(request); if (search != requestsMap->end()) { /** * If this is the last request to hold `std::shared_ptr` erasing it * may cause the `ucxx::Endpoint`s destructor and subsequently the `closeBlocking()` * method to be called which will in turn call `cancelAll()` and attempt to take the * mutexes. For this reason we should make a temporary copy of the request being - * erased from `_trackedRequests->_inflight` to allow unlocking the mutexes and only then - * destroy the object upon this method's return. + * erased from `_trackedRequests->_inflight` in `removed` to allow the caller to unlock + * the mutexes and only then destroy the object. 
*/ - tmpRequest = search->second; + removed->insert({request, search->second}); + requestsMap->erase(search); } + + return removed; } void InflightRequests::remove(const Request* const request, @@ -70,8 +74,13 @@ void InflightRequests::remove(const Request* const request, if (result == 0) { return; } else if (result == -1) { - findAndRemove(_trackedRequests->_inflight.get(), request); - findAndRemove(_trackedRequests->_canceling.get(), request); + /** + * Retain references to removed pointers to prevent their refcounts from going to + * zero while locks are held, which may trigger a chain effect and cause `this` itself + * to be destroyed and thus call `cancelAll()` which will then cause a deadlock. + */ + auto removedInflight = findAndRemove(_trackedRequests->_inflight.get(), request); + auto removedCanceling = findAndRemove(_trackedRequests->_canceling.get(), request); size_t trackedRequestsCount = _trackedRequests->_inflight->size() + _trackedRequests->_canceling->size(); From 110c8d19fdb562a64b0d754aca394470b67f6090 Mon Sep 17 00:00:00 2001 From: Peter Andreas Entschev Date: Mon, 20 May 2024 03:13:07 -0700 Subject: [PATCH 28/39] Register `ucxx::Endpoint` during `createRequestAm()` --- cpp/src/request_am.cpp | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/cpp/src/request_am.cpp b/cpp/src/request_am.cpp index cae3cea49..99c22c4c7 100644 --- a/cpp/src/request_am.cpp +++ b/cpp/src/request_am.cpp @@ -123,7 +123,9 @@ std::shared_ptr createRequestAm( return std::shared_ptr(new RequestAm( endpoint, amReceive, "amReceive", enablePythonFuture, callbackFunction, callbackData)); }; - return worker->getAmRecv(endpoint->getHandle(), createRequest); + auto req = worker->getAmRecv(endpoint->getHandle(), createRequest); + req->_endpoint = endpoint; + return req; }, }, requestData); From 50072c1a8a4586db7e10a5e8e444a4fa5bc14e63 Mon Sep 17 00:00:00 2001 From: Peter Andreas Entschev Date: Mon, 20 May 2024 12:23:38 -0700 Subject: [PATCH 29/39] Check and reset
`ucxx::Request::_cancelCallback` --- cpp/src/request.cpp | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/cpp/src/request.cpp b/cpp/src/request.cpp index de7182514..bb927aec1 100644 --- a/cpp/src/request.cpp +++ b/cpp/src/request.cpp @@ -63,7 +63,10 @@ Request::Request(std::shared_ptr endpointOrWorker, Request::~Request() { - if (_cancelCallback != nullptr) { _cancelCallbackNotifier.wait(1000000000 /* 1s */); } + if (_cancelCallback != nullptr) { + auto completed = _cancelCallbackNotifier.wait(1000000000 /* 1s */); + _cancelCallback = nullptr; + } if (UCS_PTR_IS_PTR(_request)) { ucxx_warn("ucxx::Request (%s) freeing: %p", _operationName.c_str(), _request); ucp_request_free(_request); @@ -130,6 +133,12 @@ void Request::cancel() { if (_worker->isProgressThreadRunning()) { _cancelCallback = [this]() { + /** + * FIXME: Check the callback hasn't run and object hasn't been destroyed. Long-term + * fix is to allow deregistering generic callbacks with the worker. + */ + if (_cancelCallback == nullptr) return; + cancelImpl(); _cancelCallbackNotifier.set(); }; From cae6d04ed72bf31781b181110b2262f8e9679017 Mon Sep 17 00:00:00 2001 From: Peter Andreas Entschev Date: Tue, 31 Mar 2026 10:06:32 -0700 Subject: [PATCH 30/39] Refactor InflightRequests --- cpp/include/ucxx/inflight_requests.h | 99 +++++----- cpp/include/ucxx/worker.h | 2 +- cpp/src/inflight_requests.cpp | 261 ++++++++++++++++----------- cpp/src/worker.cpp | 6 +- 4 files changed, 205 insertions(+), 163 deletions(-) diff --git a/cpp/include/ucxx/inflight_requests.h b/cpp/include/ucxx/inflight_requests.h index 10b2b9d50..78cc0eb4a 100644 --- a/cpp/include/ucxx/inflight_requests.h +++ b/cpp/include/ucxx/inflight_requests.h @@ -4,75 +4,72 @@ */ #pragma once -#include +#include #include #include -#include +#include +#include namespace ucxx { class Request; /** - * @brief An inflight request map. + * @brief Container for transferring tracked requests between InflightRequests instances.
* - * A map of inflight requests, where keys are a shared pointer to the request and - * value is the reference-counted `ucxx::Request`, using owner-based comparison. + * Used by `InflightRequests::release()` and `InflightRequests::merge()` to move + * request ownership between instances (e.g., from an endpoint to the worker during + * endpoint close). */ -typedef std:: - map, std::shared_ptr, std::owner_less>> - InflightRequestsMap; - -/** - * @brief A container for the different types of tracked requests. - * - * A container encapsulating the different types of handled tracked requests, currently - * those still valid (inflight), and those scheduled for cancelation (canceling). - */ -typedef struct TrackedRequests { - InflightRequestsMap _inflight{}; ///< Valid requests awaiting completion. - InflightRequestsMap _canceling{}; ///< Requests scheduled for cancelation. - std::mutex _mutex{}; ///< Mutex to control access to inflight requests container - std::mutex - _cancelMutex{}; ///< Mutex to allow cancelation and prevent removing requests simultaneously -} TrackedRequests; - -/** - * @brief Pre-defined type for a pointer to a container of tracked requests. - * - * A pre-defined type for a pointer to a container of tracked requests, used as a - * convenience type. - */ -typedef std::unique_ptr TrackedRequestsPtr; +struct TrackedRequests { + std::vector> inflight{}; ///< Valid requests awaiting completion. + std::vector> canceling{}; ///< Requests scheduled for cancelation. +}; /** * @brief Handle tracked requests. * * Handle tracked requests, providing functionality so that its owner can modify those * requests, performing operations such as insertion, removal and cancelation. + * + * Two container backends are available, selected at construction time via the + * `UCXX_INFLIGHT_REQUESTS_BACKEND` environment variable: + * - `vector` (default): best latency for small request counts, cache-friendly, + * no per-insert heap allocation, O(n) removal. 
+ * - `map`: O(1) amortized insert/remove, scales to thousands of concurrent + * inflight requests. */ class InflightRequests { private: - TrackedRequestsPtr _trackedRequests{ - std::make_unique()}; ///< Container storing pointers to all inflight - ///< and in cancelation process requests known to - ///< the owner of this object - std::recursive_mutex _mutex{}; ///< Mutex to control access to class resources + bool _useMap{false}; - /** - * @brief Drop references to requests that completed cancelation. - * - * Drops references to requests that completed cancelation and stop tracking them. - * - * @returns The number of requests that have completed cancelation since last call. - */ - size_t dropCanceled(); + std::vector> _inflightVec{}; + std::vector> _cancelingVec{}; + + std::unordered_map> _inflightMap{}; + std::unordered_map> _cancelingMap{}; + + std::mutex _mutex{}; + + void _doInsert(const std::shared_ptr& request); + void _doRemove(const std::shared_ptr& request); + size_t _doInflightSize() const; + std::vector> _doTakeInflight(); + void _doPutCanceling(std::vector>* requests); + size_t _doDropCanceled(); + size_t _doCancelingSize() const; + void _doMergeInflight(std::vector>* requests); + void _doMergeCanceling(std::vector>* requests); + std::vector> _doTakeCanceling(); public: /** - * @brief Default constructor. + * @brief Construct with backend selected from UCXX_INFLIGHT_REQUESTS_BACKEND env var. + * + * Reads the `UCXX_INFLIGHT_REQUESTS_BACKEND` environment variable to select the + * container backend. Valid values are `vector` (default) and `map`. */ - InflightRequests() = default; + InflightRequests(); InflightRequests(const InflightRequests&) = delete; InflightRequests& operator=(InflightRequests const&) = delete; @@ -94,11 +91,11 @@ class InflightRequests { [[nodiscard]] size_t size(); /** - * @brief Insert an inflight requests to the container. + * @brief Insert an inflight request into the container. 
* * @param[in] request a `std::shared_ptr` with the inflight request. */ - void insert(std::shared_ptr request); + void insert(const std::shared_ptr& request); /** * @brief Merge containers of inflight requests with the internal containers. @@ -109,7 +106,7 @@ class InflightRequests { * @param[in] trackedRequests containers of tracked inflight requests to merge with the * internal tracked inflight requests. */ - void merge(TrackedRequestsPtr trackedRequests); + void merge(TrackedRequests&& trackedRequests); /** * @brief Remove an inflight request from the internal container. @@ -120,7 +117,7 @@ class InflightRequests { * * @param[in] request shared pointer to the request */ - void remove(std::shared_ptr request); + void remove(const std::shared_ptr& request); /** * @brief Issue cancelation of all inflight requests and clear the internal container. @@ -139,9 +136,9 @@ class InflightRequests { * `InflightRequests` object with `InflightRequests::merge()`. Effectively leaves the * internal state as a clean, new object. * - * @returns The internally-tracked containers. + * @returns The tracked requests as vectors for transfer. */ - [[nodiscard]] TrackedRequestsPtr release(); + [[nodiscard]] TrackedRequests release(); /** * @brief Get count of requests in process of cancelation. diff --git a/cpp/include/ucxx/worker.h b/cpp/include/ucxx/worker.h index 231c0dea4..7a1a3f4c3 100644 --- a/cpp/include/ucxx/worker.h +++ b/cpp/include/ucxx/worker.h @@ -656,7 +656,7 @@ class Worker : public Component { * @param[in] trackedRequests the requests tracked by a child of this class to be * scheduled for cancelation. */ - void scheduleRequestCancel(TrackedRequestsPtr trackedRequests); + void scheduleRequestCancel(TrackedRequests trackedRequests); /** * @brief Remove reference to request from internal container. 
diff --git a/cpp/src/inflight_requests.cpp b/cpp/src/inflight_requests.cpp index 07d8ab50d..fecd7d445 100644 --- a/cpp/src/inflight_requests.cpp +++ b/cpp/src/inflight_requests.cpp @@ -2,7 +2,13 @@ * SPDX-FileCopyrightText: Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. * SPDX-License-Identifier: BSD-3-Clause */ +#include +#include +#include #include +#include +#include +#include #include #include @@ -10,160 +16,199 @@ namespace ucxx { -InflightRequests::~InflightRequests() { cancelAll(); } +// ---- Private backend helpers ------------------------------------------------ -size_t InflightRequests::size() +void InflightRequests::_doInsert(const std::shared_ptr& request) { - std::scoped_lock localLock{_mutex}; - std::lock_guard lock(_trackedRequests->_mutex); - return _trackedRequests->_inflight.size(); + if (_useMap) + _inflightMap.emplace(request.get(), request); + else + _inflightVec.push_back(request); } -void InflightRequests::insert(std::shared_ptr request) +void InflightRequests::_doRemove(const std::shared_ptr& request) { - std::scoped_lock localLock{_mutex}; - std::lock_guard lock(_trackedRequests->_mutex); - - _trackedRequests->_inflight.insert({request, request}); + if (_useMap) { + _inflightMap.erase(request.get()); + } else { + auto it = std::find(_inflightVec.begin(), _inflightVec.end(), request); + if (it != _inflightVec.end()) { + *it = std::move(_inflightVec.back()); + _inflightVec.pop_back(); + } + } } -void InflightRequests::merge(TrackedRequestsPtr trackedRequests) +size_t InflightRequests::_doInflightSize() const { - { - if (trackedRequests == nullptr) return; - - std::scoped_lock localLock{_mutex}; - std::scoped_lock lock{_trackedRequests->_cancelMutex, - _trackedRequests->_mutex, - trackedRequests->_cancelMutex, - trackedRequests->_mutex}; + return _useMap ? 
_inflightMap.size() : _inflightVec.size(); +} - _trackedRequests->_inflight.merge(trackedRequests->_inflight); - _trackedRequests->_canceling.merge(trackedRequests->_canceling); +std::vector> InflightRequests::_doTakeInflight() +{ + std::vector> result; + if (_useMap) { + result.reserve(_inflightMap.size()); + for (auto& kv : _inflightMap) + result.push_back(std::move(kv.second)); + _inflightMap.clear(); + } else { + result = std::exchange(_inflightVec, {}); } + return result; } -void InflightRequests::remove(std::shared_ptr request) -{ - do { - std::scoped_lock localLock{_mutex}; - int result = std::try_lock(_trackedRequests->_cancelMutex, _trackedRequests->_mutex); - - /** - * `result` can be have one of three values: - * -1 (both arguments were locked): Remove request and return. - * 0 (failed to lock argument 0): Failed acquiring `_cancelMutex`, cancel in - * progress, nothing to do but return. The method was - * called during execution of `cancelAll()` and the - * `Request*` callback was invoked. - * 1 (failed to lock argument 1): Only `_cancelMutex` was acquired, another - * operation in progress, retry. - */ - if (result == 0) { - return; - } else if (result == -1) { - auto search = _trackedRequests->_inflight.find(request); - decltype(search->second) tmpRequest; - if (search != _trackedRequests->_inflight.end()) { - /** - * If this is the last request to hold `std::shared_ptr` erasing it - * may cause the `ucxx::Endpoint`s destructor and subsequently the `closeBlocking()` - * method to be called which will in turn call `cancelAll()` and attempt to take the - * mutexes. For this reason we should make a temporary copy of the request being - * erased from `_trackedRequests->_inflight` to allow unlocking the mutexes and only then - * destroy the object upon this method's return. 
- */ - tmpRequest = search->second; - _trackedRequests->_inflight.erase(search); - } - _trackedRequests->_cancelMutex.unlock(); - _trackedRequests->_mutex.unlock(); - return; - } - } while (true); +void InflightRequests::_doPutCanceling(std::vector>* requests) +{ + if (_useMap) { + for (auto& r : *requests) + if (r) _cancelingMap.emplace(r.get(), std::move(r)); + } else { + _cancelingVec.insert(_cancelingVec.end(), + std::make_move_iterator(requests->begin()), + std::make_move_iterator(requests->end())); + } } -size_t InflightRequests::dropCanceled() +size_t InflightRequests::_doDropCanceled() { size_t removed = 0; - - { - std::scoped_lock localLock{_mutex}; - std::scoped_lock lock{_trackedRequests->_cancelMutex}; - for (auto it = _trackedRequests->_canceling.begin(); - it != _trackedRequests->_canceling.end();) { - auto request = it->second; - if (request != nullptr && request->getStatus() != UCS_INPROGRESS) { - it = _trackedRequests->_canceling.erase(it); + if (_useMap) { + for (auto it = _cancelingMap.begin(); it != _cancelingMap.end();) { + if (it->second && it->second->getStatus() != UCS_INPROGRESS) { + it = _cancelingMap.erase(it); ++removed; } else { ++it; } } + } else { + auto newEnd = std::remove_if( + _cancelingVec.begin(), _cancelingVec.end(), [](const std::shared_ptr& r) { + return r && r->getStatus() != UCS_INPROGRESS; + }); + removed = static_cast(std::distance(newEnd, _cancelingVec.end())); + _cancelingVec.erase(newEnd, _cancelingVec.end()); } - return removed; } -size_t InflightRequests::getCancelingSize() +size_t InflightRequests::_doCancelingSize() const { - dropCanceled(); - size_t cancelingSize = 0; - { - std::scoped_lock localLock{_mutex}; - std::scoped_lock lock{_trackedRequests->_cancelMutex}; - cancelingSize = _trackedRequests->_canceling.size(); + return _useMap ? 
_cancelingMap.size() : _cancelingVec.size(); +} + +void InflightRequests::_doMergeInflight(std::vector>* requests) +{ + if (_useMap) { + for (auto& r : *requests) + if (r) _inflightMap.emplace(r.get(), std::move(r)); + } else { + _inflightVec.insert(_inflightVec.end(), + std::make_move_iterator(requests->begin()), + std::make_move_iterator(requests->end())); } +} - return cancelingSize; +void InflightRequests::_doMergeCanceling(std::vector>* requests) +{ + _doPutCanceling(requests); +} + +std::vector> InflightRequests::_doTakeCanceling() +{ + std::vector> result; + if (_useMap) { + result.reserve(_cancelingMap.size()); + for (auto& kv : _cancelingMap) + result.push_back(std::move(kv.second)); + _cancelingMap.clear(); + } else { + result = std::exchange(_cancelingVec, {}); + } + return result; +} + +// ---- Public API ------------------------------------------------------------- + +InflightRequests::InflightRequests() +{ + const char* env = std::getenv("UCXX_INFLIGHT_REQUESTS_BACKEND"); + if (env != nullptr && std::strcmp(env, "map") == 0) _useMap = true; +} + +InflightRequests::~InflightRequests() { cancelAll(); } + +size_t InflightRequests::size() +{ + std::lock_guard lock(_mutex); + return _doInflightSize(); +} + +void InflightRequests::insert(const std::shared_ptr& request) +{ + std::lock_guard lock(_mutex); + _doInsert(request); +} + +void InflightRequests::remove(const std::shared_ptr& request) +{ + std::lock_guard lock(_mutex); + _doRemove(request); +} + +void InflightRequests::merge(TrackedRequests&& trackedRequests) +{ + std::lock_guard lock(_mutex); + _doMergeInflight(&trackedRequests.inflight); + _doMergeCanceling(&trackedRequests.canceling); } size_t InflightRequests::cancelAll() { - decltype(_trackedRequests->_inflight) toCancel; - size_t total; + std::vector> toCancel; { - std::scoped_lock localLock{_mutex}; - std::scoped_lock lock{_trackedRequests->_cancelMutex, _trackedRequests->_mutex}; - total = _trackedRequests->_inflight.size(); - - // Fast path 
when no requests have been registered or the map has been - // previously released. - if (total == 0) return 0; + std::lock_guard lock(_mutex); + toCancel = _doTakeInflight(); + } - toCancel = std::exchange(_trackedRequests->_inflight, InflightRequestsMap()); + size_t total = toCancel.size(); + if (total == 0) return 0; - ucxx_debug("ucxx::InflightRequests::%s, canceling %lu requests", __func__, total); + ucxx_debug("ucxx::InflightRequests::%s, canceling %lu requests", __func__, total); - for (auto& r : toCancel) { - auto request = r.second; - if (request != nullptr) { request->cancel(); } - } + for (auto& r : toCancel) { + if (r) r->cancel(); + } - _trackedRequests->_canceling.merge(toCancel); + { + std::lock_guard lock(_mutex); - // Drop canceled requests. Do not call `dropCanceled()` to prevent locking mutexes - // again. - for (auto it = _trackedRequests->_canceling.begin(); - it != _trackedRequests->_canceling.end();) { - auto request = it->second; - if (request != nullptr && request->getStatus() != UCS_INPROGRESS) { - it = _trackedRequests->_canceling.erase(it); - } else { - ++it; - } + // Keep requests that are still in progress; drop completed ones. 
+ std::vector> stillCanceling; + for (auto& r : toCancel) { + if (r && r->getStatus() == UCS_INPROGRESS) stillCanceling.push_back(std::move(r)); } + _doPutCanceling(&stillCanceling); } return total; } -TrackedRequestsPtr InflightRequests::release() +TrackedRequests InflightRequests::release() { - std::scoped_lock localLock{_mutex}; - std::scoped_lock lock{_trackedRequests->_cancelMutex, _trackedRequests->_mutex}; + std::lock_guard lock(_mutex); + TrackedRequests result; + result.inflight = _doTakeInflight(); + result.canceling = _doTakeCanceling(); + return result; +} - return std::exchange(_trackedRequests, std::make_unique()); +size_t InflightRequests::getCancelingSize() +{ + std::lock_guard lock(_mutex); + _doDropCanceled(); + return _doCancelingSize(); } } // namespace ucxx diff --git a/cpp/src/worker.cpp b/cpp/src/worker.cpp index f4134b080..a0d0da5d2 100644 --- a/cpp/src/worker.cpp +++ b/cpp/src/worker.cpp @@ -538,13 +538,13 @@ size_t Worker::cancelInflightRequests(uint64_t period, uint64_t maxAttempts) if (inflightRequestsToCancel->getCancelingSize() > 0) { std::lock_guard lock(_inflightRequestsMutex); - _inflightRequestsToCancel->merge(inflightRequestsToCancel->release()); + _inflightRequestsToCancel->merge(std::move(inflightRequestsToCancel->release())); } return canceled; } -void Worker::scheduleRequestCancel(TrackedRequestsPtr trackedRequests) +void Worker::scheduleRequestCancel(TrackedRequests trackedRequests) { { std::lock_guard lock(_inflightRequestsMutex); @@ -554,7 +554,7 @@ void Worker::scheduleRequestCancel(TrackedRequestsPtr trackedRequests) __func__, this, _handle, - trackedRequests->_inflight.size() + trackedRequests->_canceling.size()); + trackedRequests.inflight.size() + trackedRequests.canceling.size()); _inflightRequestsToCancel->merge(std::move(trackedRequests)); } } From 062d60ee8c8a00bfc6aa59ec0f8ffc39f7f3df81 Mon Sep 17 00:00:00 2001 From: Peter Andreas Entschev Date: Tue, 31 Mar 2026 10:56:22 -0700 Subject: [PATCH 31/39] Remove 
std::vector-based option --- cpp/include/ucxx/inflight_requests.h | 36 +----- cpp/src/inflight_requests.cpp | 172 +++++---------------------- 2 files changed, 38 insertions(+), 170 deletions(-) diff --git a/cpp/include/ucxx/inflight_requests.h b/cpp/include/ucxx/inflight_requests.h index 78cc0eb4a..3a9aa93f0 100644 --- a/cpp/include/ucxx/inflight_requests.h +++ b/cpp/include/ucxx/inflight_requests.h @@ -4,7 +4,6 @@ */ #pragma once -#include #include #include #include @@ -32,44 +31,21 @@ struct TrackedRequests { * Handle tracked requests, providing functionality so that its owner can modify those * requests, performing operations such as insertion, removal and cancelation. * - * Two container backends are available, selected at construction time via the - * `UCXX_INFLIGHT_REQUESTS_BACKEND` environment variable: - * - `vector` (default): best latency for small request counts, cache-friendly, - * no per-insert heap allocation, O(n) removal. - * - `map`: O(1) amortized insert/remove, scales to thousands of concurrent - * inflight requests. + * Uses `std::unordered_map>` for O(1) amortized + * insert/remove that scales to thousands of concurrent inflight requests. */ class InflightRequests { private: - bool _useMap{false}; - - std::vector> _inflightVec{}; - std::vector> _cancelingVec{}; - - std::unordered_map> _inflightMap{}; - std::unordered_map> _cancelingMap{}; + std::unordered_map> _inflight{}; + std::unordered_map> _canceling{}; std::mutex _mutex{}; - void _doInsert(const std::shared_ptr& request); - void _doRemove(const std::shared_ptr& request); - size_t _doInflightSize() const; - std::vector> _doTakeInflight(); - void _doPutCanceling(std::vector>* requests); - size_t _doDropCanceled(); - size_t _doCancelingSize() const; - void _doMergeInflight(std::vector>* requests); - void _doMergeCanceling(std::vector>* requests); - std::vector> _doTakeCanceling(); - public: /** - * @brief Construct with backend selected from UCXX_INFLIGHT_REQUESTS_BACKEND env var. 
- * - * Reads the `UCXX_INFLIGHT_REQUESTS_BACKEND` environment variable to select the - * container backend. Valid values are `vector` (default) and `map`. + * @brief Default constructor. */ - InflightRequests(); + InflightRequests() = default; InflightRequests(const InflightRequests&) = delete; InflightRequests& operator=(InflightRequests const&) = delete; diff --git a/cpp/src/inflight_requests.cpp b/cpp/src/inflight_requests.cpp index fecd7d445..80d88aec9 100644 --- a/cpp/src/inflight_requests.cpp +++ b/cpp/src/inflight_requests.cpp @@ -2,11 +2,7 @@ * SPDX-FileCopyrightText: Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. * SPDX-License-Identifier: BSD-3-Clause */ -#include -#include -#include #include -#include #include #include @@ -16,152 +12,33 @@ namespace ucxx { -// ---- Private backend helpers ------------------------------------------------ - -void InflightRequests::_doInsert(const std::shared_ptr& request) -{ - if (_useMap) - _inflightMap.emplace(request.get(), request); - else - _inflightVec.push_back(request); -} - -void InflightRequests::_doRemove(const std::shared_ptr& request) -{ - if (_useMap) { - _inflightMap.erase(request.get()); - } else { - auto it = std::find(_inflightVec.begin(), _inflightVec.end(), request); - if (it != _inflightVec.end()) { - *it = std::move(_inflightVec.back()); - _inflightVec.pop_back(); - } - } -} - -size_t InflightRequests::_doInflightSize() const -{ - return _useMap ? 
_inflightMap.size() : _inflightVec.size(); -} - -std::vector> InflightRequests::_doTakeInflight() -{ - std::vector> result; - if (_useMap) { - result.reserve(_inflightMap.size()); - for (auto& kv : _inflightMap) - result.push_back(std::move(kv.second)); - _inflightMap.clear(); - } else { - result = std::exchange(_inflightVec, {}); - } - return result; -} - -void InflightRequests::_doPutCanceling(std::vector>* requests) -{ - if (_useMap) { - for (auto& r : *requests) - if (r) _cancelingMap.emplace(r.get(), std::move(r)); - } else { - _cancelingVec.insert(_cancelingVec.end(), - std::make_move_iterator(requests->begin()), - std::make_move_iterator(requests->end())); - } -} - -size_t InflightRequests::_doDropCanceled() -{ - size_t removed = 0; - if (_useMap) { - for (auto it = _cancelingMap.begin(); it != _cancelingMap.end();) { - if (it->second && it->second->getStatus() != UCS_INPROGRESS) { - it = _cancelingMap.erase(it); - ++removed; - } else { - ++it; - } - } - } else { - auto newEnd = std::remove_if( - _cancelingVec.begin(), _cancelingVec.end(), [](const std::shared_ptr& r) { - return r && r->getStatus() != UCS_INPROGRESS; - }); - removed = static_cast(std::distance(newEnd, _cancelingVec.end())); - _cancelingVec.erase(newEnd, _cancelingVec.end()); - } - return removed; -} - -size_t InflightRequests::_doCancelingSize() const -{ - return _useMap ? 
_cancelingMap.size() : _cancelingVec.size(); -} - -void InflightRequests::_doMergeInflight(std::vector>* requests) -{ - if (_useMap) { - for (auto& r : *requests) - if (r) _inflightMap.emplace(r.get(), std::move(r)); - } else { - _inflightVec.insert(_inflightVec.end(), - std::make_move_iterator(requests->begin()), - std::make_move_iterator(requests->end())); - } -} - -void InflightRequests::_doMergeCanceling(std::vector>* requests) -{ - _doPutCanceling(requests); -} - -std::vector> InflightRequests::_doTakeCanceling() -{ - std::vector> result; - if (_useMap) { - result.reserve(_cancelingMap.size()); - for (auto& kv : _cancelingMap) - result.push_back(std::move(kv.second)); - _cancelingMap.clear(); - } else { - result = std::exchange(_cancelingVec, {}); - } - return result; -} - -// ---- Public API ------------------------------------------------------------- - -InflightRequests::InflightRequests() -{ - const char* env = std::getenv("UCXX_INFLIGHT_REQUESTS_BACKEND"); - if (env != nullptr && std::strcmp(env, "map") == 0) _useMap = true; -} - InflightRequests::~InflightRequests() { cancelAll(); } size_t InflightRequests::size() { std::lock_guard lock(_mutex); - return _doInflightSize(); + return _inflight.size(); } void InflightRequests::insert(const std::shared_ptr& request) { std::lock_guard lock(_mutex); - _doInsert(request); + _inflight.emplace(request.get(), request); } void InflightRequests::remove(const std::shared_ptr& request) { std::lock_guard lock(_mutex); - _doRemove(request); + _inflight.erase(request.get()); } void InflightRequests::merge(TrackedRequests&& trackedRequests) { std::lock_guard lock(_mutex); - _doMergeInflight(&trackedRequests.inflight); - _doMergeCanceling(&trackedRequests.canceling); + for (auto& r : trackedRequests.inflight) + if (r) _inflight.emplace(r.get(), std::move(r)); + for (auto& r : trackedRequests.canceling) + if (r) _canceling.emplace(r.get(), std::move(r)); } size_t InflightRequests::cancelAll() @@ -169,7 +46,10 @@ size_t 
InflightRequests::cancelAll() std::vector> toCancel; { std::lock_guard lock(_mutex); - toCancel = _doTakeInflight(); + toCancel.reserve(_inflight.size()); + for (auto& kv : _inflight) + toCancel.push_back(std::move(kv.second)); + _inflight.clear(); } size_t total = toCancel.size(); @@ -183,13 +63,9 @@ size_t InflightRequests::cancelAll() { std::lock_guard lock(_mutex); - - // Keep requests that are still in progress; drop completed ones. - std::vector> stillCanceling; for (auto& r : toCancel) { - if (r && r->getStatus() == UCS_INPROGRESS) stillCanceling.push_back(std::move(r)); + if (r && r->getStatus() == UCS_INPROGRESS) _canceling.emplace(r.get(), std::move(r)); } - _doPutCanceling(&stillCanceling); } return total; @@ -199,16 +75,32 @@ TrackedRequests InflightRequests::release() { std::lock_guard lock(_mutex); TrackedRequests result; - result.inflight = _doTakeInflight(); - result.canceling = _doTakeCanceling(); + + result.inflight.reserve(_inflight.size()); + for (auto& kv : _inflight) + result.inflight.push_back(std::move(kv.second)); + _inflight.clear(); + + result.canceling.reserve(_canceling.size()); + for (auto& kv : _canceling) + result.canceling.push_back(std::move(kv.second)); + _canceling.clear(); + return result; } size_t InflightRequests::getCancelingSize() { std::lock_guard lock(_mutex); - _doDropCanceled(); - return _doCancelingSize(); + + for (auto it = _canceling.begin(); it != _canceling.end();) { + if (it->second && it->second->getStatus() != UCS_INPROGRESS) + it = _canceling.erase(it); + else + ++it; + } + + return _canceling.size(); } } // namespace ucxx From 30b0d8defb7dcef97cbf39147eb6c1295f71ca11 Mon Sep 17 00:00:00 2001 From: Peter Andreas Entschev Date: Tue, 31 Mar 2026 12:09:55 -0700 Subject: [PATCH 32/39] Simplify to use `std::unordered_set` --- cpp/include/ucxx/inflight_requests.h | 6 +++--- cpp/src/inflight_requests.cpp | 28 +++++++++++++--------------- 2 files changed, 16 insertions(+), 18 deletions(-) diff --git 
a/cpp/include/ucxx/inflight_requests.h b/cpp/include/ucxx/inflight_requests.h index 3a9aa93f0..0f9e13a03 100644 --- a/cpp/include/ucxx/inflight_requests.h +++ b/cpp/include/ucxx/inflight_requests.h @@ -6,7 +6,7 @@ #include #include -#include +#include #include namespace ucxx { @@ -36,8 +36,8 @@ struct TrackedRequests { */ class InflightRequests { private: - std::unordered_map> _inflight{}; - std::unordered_map> _canceling{}; + std::unordered_set> _inflight{}; + std::unordered_set> _canceling{}; std::mutex _mutex{}; diff --git a/cpp/src/inflight_requests.cpp b/cpp/src/inflight_requests.cpp index 80d88aec9..4a411eb7b 100644 --- a/cpp/src/inflight_requests.cpp +++ b/cpp/src/inflight_requests.cpp @@ -23,33 +23,30 @@ size_t InflightRequests::size() void InflightRequests::insert(const std::shared_ptr& request) { std::lock_guard lock(_mutex); - _inflight.emplace(request.get(), request); + _inflight.insert(request); } void InflightRequests::remove(const std::shared_ptr& request) { std::lock_guard lock(_mutex); - _inflight.erase(request.get()); + _inflight.erase(request); } void InflightRequests::merge(TrackedRequests&& trackedRequests) { std::lock_guard lock(_mutex); for (auto& r : trackedRequests.inflight) - if (r) _inflight.emplace(r.get(), std::move(r)); + if (r) _inflight.insert(std::move(r)); for (auto& r : trackedRequests.canceling) - if (r) _canceling.emplace(r.get(), std::move(r)); + if (r) _canceling.insert(std::move(r)); } size_t InflightRequests::cancelAll() { - std::vector> toCancel; + decltype(_inflight) toCancel; { std::lock_guard lock(_mutex); - toCancel.reserve(_inflight.size()); - for (auto& kv : _inflight) - toCancel.push_back(std::move(kv.second)); - _inflight.clear(); + toCancel = std::exchange(_inflight, {}); } size_t total = toCancel.size(); @@ -64,7 +61,8 @@ size_t InflightRequests::cancelAll() { std::lock_guard lock(_mutex); for (auto& r : toCancel) { - if (r && r->getStatus() == UCS_INPROGRESS) _canceling.emplace(r.get(), std::move(r)); + if (r && 
r->getStatus() == UCS_INPROGRESS) + _canceling.insert(std::move(const_cast&>(r))); } } @@ -77,13 +75,13 @@ TrackedRequests InflightRequests::release() TrackedRequests result; result.inflight.reserve(_inflight.size()); - for (auto& kv : _inflight) - result.inflight.push_back(std::move(kv.second)); + for (auto& r : _inflight) + result.inflight.push_back(r); _inflight.clear(); result.canceling.reserve(_canceling.size()); - for (auto& kv : _canceling) - result.canceling.push_back(std::move(kv.second)); + for (auto& r : _canceling) + result.canceling.push_back(r); _canceling.clear(); return result; @@ -94,7 +92,7 @@ size_t InflightRequests::getCancelingSize() std::lock_guard lock(_mutex); for (auto it = _canceling.begin(); it != _canceling.end();) { - if (it->second && it->second->getStatus() != UCS_INPROGRESS) + if (*it && (*it)->getStatus() != UCS_INPROGRESS) it = _canceling.erase(it); else ++it; From de5e5c58d001e1265a783d2716bcff302f8e0e56 Mon Sep 17 00:00:00 2001 From: Peter Andreas Entschev Date: Thu, 2 Apr 2026 02:31:47 -0700 Subject: [PATCH 33/39] Fix build errors --- cpp/tests/endpoint.cpp | 33 ++++++++++++++++++++------------- 1 file changed, 20 insertions(+), 13 deletions(-) diff --git a/cpp/tests/endpoint.cpp b/cpp/tests/endpoint.cpp index a4d3de6fc..08b76543f 100644 --- a/cpp/tests/endpoint.cpp +++ b/cpp/tests/endpoint.cpp @@ -1,7 +1,8 @@ /** - * SPDX-FileCopyrightText: Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. + * SPDX-FileCopyrightText: Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. 
* SPDX-License-Identifier: BSD-3-Clause */ +#include #include #include #include @@ -164,12 +165,15 @@ TEST_F(EndpointTest, StoppingRejectRequests) std::vector tmp(10, 42); - EXPECT_THROW(ep->amSend(tmp.data(), tmp.size() * sizeof(int), UCS_MEMORY_TYPE_HOST), + EXPECT_THROW( + static_cast(ep->amSend(tmp.data(), tmp.size() * sizeof(int), UCS_MEMORY_TYPE_HOST)), + ucxx::RejectedError); + EXPECT_THROW(static_cast(ep->tagSend(tmp.data(), tmp.size() * sizeof(int), ucxx::Tag{0})), ucxx::RejectedError); - EXPECT_THROW(ep->tagSend(tmp.data(), tmp.size() * sizeof(int), ucxx::Tag{0}), + EXPECT_THROW(static_cast(ep->streamRecv(tmp.data(), tmp.size() * sizeof(int))), + ucxx::RejectedError); + EXPECT_THROW(static_cast(ep->streamSend(tmp.data(), tmp.size() * sizeof(int))), ucxx::RejectedError); - EXPECT_THROW(ep->streamRecv(tmp.data(), tmp.size() * sizeof(int)), ucxx::RejectedError); - EXPECT_THROW(ep->streamSend(tmp.data(), tmp.size() * sizeof(int)), ucxx::RejectedError); { auto memoryHandle = _context->createMemoryHandle(tmp.size() * sizeof(int), nullptr); @@ -179,23 +183,26 @@ TEST_F(EndpointTest, StoppingRejectRequests) auto remoteKey = ucxx::createRemoteKeyFromSerialized(ep, serializedRemoteKey); std::vector> requests; - EXPECT_THROW(ep->memPut(tmp.data(), tmp.size() * sizeof(int), remoteKey), ucxx::RejectedError); + EXPECT_THROW(static_cast(ep->memPut(tmp.data(), tmp.size() * sizeof(int), remoteKey)), + ucxx::RejectedError); EXPECT_THROW( - ep->memPut( - tmp.data(), tmp.size() * sizeof(int), remoteKey->getBaseAddress(), remoteKey->getHandle()), + static_cast(ep->memPut( + tmp.data(), tmp.size() * sizeof(int), remoteKey->getBaseAddress(), remoteKey->getHandle())), ucxx::RejectedError); - EXPECT_THROW(ep->memGet(tmp.data(), tmp.size() * sizeof(int), remoteKey), ucxx::RejectedError); + EXPECT_THROW(static_cast(ep->memGet(tmp.data(), tmp.size() * sizeof(int), remoteKey)), + ucxx::RejectedError); EXPECT_THROW( - ep->memGet( - tmp.data(), tmp.size() * sizeof(int), 
remoteKey->getBaseAddress(), remoteKey->getHandle()), + static_cast(ep->memGet( + tmp.data(), tmp.size() * sizeof(int), remoteKey->getBaseAddress(), remoteKey->getHandle())), ucxx::RejectedError); } { - std::vector buffers{tmp.data()}; + std::vector buffers{tmp.data()}; std::vector sizes{tmp.size()}; std::vector isCUDA{false}; - EXPECT_THROW(ep->tagMultiSend(buffers, sizes, isCUDA, ucxx::Tag{0}), ucxx::RejectedError); + EXPECT_THROW(static_cast(ep->tagMultiSend(buffers, sizes, isCUDA, ucxx::Tag{0})), + ucxx::RejectedError); } } From 638b91e889cc28eb55803228d3f5634e83f37137 Mon Sep 17 00:00:00 2001 From: Peter Andreas Entschev Date: Thu, 2 Apr 2026 02:37:39 -0700 Subject: [PATCH 34/39] Fix linting --- cpp/include/ucxx/inflight_requests.h | 3 ++- cpp/include/ucxx/typedefs.h | 8 ++++++++ cpp/include/ucxx/utils/callback_notifier.h | 2 +- cpp/src/endpoint.cpp | 10 +++++----- cpp/src/request.cpp | 2 +- cpp/src/request_endpoint_close.cpp | 2 +- cpp/tests/include/utils.h | 2 +- 7 files changed, 19 insertions(+), 10 deletions(-) diff --git a/cpp/include/ucxx/inflight_requests.h b/cpp/include/ucxx/inflight_requests.h index 4a00338fc..a8d4ea2dd 100644 --- a/cpp/include/ucxx/inflight_requests.h +++ b/cpp/include/ucxx/inflight_requests.h @@ -131,7 +131,8 @@ class InflightRequests { * @param[in] callbackFunction function to be called upon termination and only if no * further requests inflight or canceling remain. */ - void remove(std::shared_ptr request, GenericCallbackUserFunction callbackFunction = nullptr); + void remove(std::shared_ptr request, + GenericCallbackUserFunction callbackFunction = nullptr); /** * @brief Issue cancelation of all inflight requests and clear the internal container. 
diff --git a/cpp/include/ucxx/typedefs.h b/cpp/include/ucxx/typedefs.h index b34993e74..4772c92fc 100644 --- a/cpp/include/ucxx/typedefs.h +++ b/cpp/include/ucxx/typedefs.h @@ -239,6 +239,14 @@ struct AmSendParams { */ typedef const std::string SerializedRemoteKey; +/** + * @brief A user-defined callback with no parameters or return value. + * + * Used where UCXX needs a simple notification hook, for example optional callbacks on + * `ucxx::InflightRequests::remove()` and `cancelAll()` when no inflight or canceling + * requests remain, and for endpoint-driven cancelation completion in + * `ucxx::Endpoint::cancelInflightRequests()`. + */ typedef std::function GenericCallbackUserFunction; } // namespace ucxx diff --git a/cpp/include/ucxx/utils/callback_notifier.h b/cpp/include/ucxx/utils/callback_notifier.h index 9a4e39049..dccdec99b 100644 --- a/cpp/include/ucxx/utils/callback_notifier.h +++ b/cpp/include/ucxx/utils/callback_notifier.h @@ -1,5 +1,5 @@ /** - * SPDX-FileCopyrightText: Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. + * SPDX-FileCopyrightText: Copyright (c) 2023-2026, NVIDIA CORPORATION & AFFILIATES. 
* SPDX-License-Identifier: BSD-3-Clause */ #pragma once diff --git a/cpp/src/endpoint.cpp b/cpp/src/endpoint.cpp index 23674ba77..cdd89ce31 100644 --- a/cpp/src/endpoint.cpp +++ b/cpp/src/endpoint.cpp @@ -495,11 +495,11 @@ std::shared_ptr Endpoint::amSend( } std::shared_ptr Endpoint::amSend(const void* const buffer, - const size_t length, - const AmSendParams& params, - const bool enablePythonFuture, - RequestCallbackUserFunction callbackFunction, - RequestCallbackUserData callbackData) + const size_t length, + const AmSendParams& params, + const bool enablePythonFuture, + RequestCallbackUserFunction callbackFunction, + RequestCallbackUserData callbackData) { if (_stopping.load()) throw RejectedError("Endpoint is stopping."); diff --git a/cpp/src/request.cpp b/cpp/src/request.cpp index 780a448a9..e21cf6446 100644 --- a/cpp/src/request.cpp +++ b/cpp/src/request.cpp @@ -70,7 +70,7 @@ Request::Request(std::shared_ptr endpointOrWorker, Request::~Request() { if (_cancelCallback != nullptr) { - std::ignore = _cancelCallbackNotifier.wait(1000000000 /* 1s */); + std::ignore = _cancelCallbackNotifier.wait(1000000000 /* 1s */); _cancelCallback = nullptr; } if (UCS_PTR_IS_PTR(_request)) { diff --git a/cpp/src/request_endpoint_close.cpp b/cpp/src/request_endpoint_close.cpp index 1caa5dd31..629fe29ed 100644 --- a/cpp/src/request_endpoint_close.cpp +++ b/cpp/src/request_endpoint_close.cpp @@ -1,5 +1,5 @@ /** - * SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES. + * SPDX-FileCopyrightText: Copyright (c) 2024-2026, NVIDIA CORPORATION & AFFILIATES. * SPDX-License-Identifier: BSD-3-Clause */ #include "ucxx/request_data.h" diff --git a/cpp/tests/include/utils.h b/cpp/tests/include/utils.h index 8d357e780..34cddc087 100644 --- a/cpp/tests/include/utils.h +++ b/cpp/tests/include/utils.h @@ -1,5 +1,5 @@ /** - * SPDX-FileCopyrightText: Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. 
+ * SPDX-FileCopyrightText: Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. * SPDX-License-Identifier: BSD-3-Clause */ #pragma once From ec1a3e12d8b174b1dc0d59e71f78250e42cf55aa Mon Sep 17 00:00:00 2001 From: Peter Andreas Entschev Date: Thu, 2 Apr 2026 03:09:27 -0700 Subject: [PATCH 35/39] Rename to VoidCallbackUserFunction --- cpp/include/ucxx/endpoint.h | 6 +++--- cpp/include/ucxx/inflight_requests.h | 4 ++-- cpp/include/ucxx/typedefs.h | 2 +- cpp/src/endpoint.cpp | 2 +- cpp/src/inflight_requests.cpp | 4 ++-- 5 files changed, 9 insertions(+), 9 deletions(-) diff --git a/cpp/include/ucxx/endpoint.h b/cpp/include/ucxx/endpoint.h index 9d742e947..94ab5daa6 100644 --- a/cpp/include/ucxx/endpoint.h +++ b/cpp/include/ucxx/endpoint.h @@ -64,10 +64,10 @@ class Endpoint : public Component { EndpointCloseCallbackUserFunction _closeCallback{nullptr}; ///< Close callback to call EndpointCloseCallbackUserData _closeCallbackArg{ nullptr}; ///< Argument to be passed to close callback - GenericCallbackUserFunction _cancelInflightCallback{ + VoidCallbackUserFunction _cancelInflightCallback{ nullptr}; ///< The wrapper to the callback registered via `cancelInflightRequests()` that will ///< deregister once the callback is called. - GenericCallbackUserFunction _cancelInflightCallbackOriginal{ + VoidCallbackUserFunction _cancelInflightCallbackOriginal{ nullptr}; ///< The original user callback registered via `cancelInflightRequests()` /** @@ -291,7 +291,7 @@ class Endpoint : public Component { * * @returns Number of requests that were scheduled for cancelation. */ - size_t cancelInflightRequests(GenericCallbackUserFunction callbackFunction = nullptr); + size_t cancelInflightRequests(VoidCallbackUserFunction callbackFunction = nullptr); /** * @brief Check the number of inflight requests being canceled. 
diff --git a/cpp/include/ucxx/inflight_requests.h b/cpp/include/ucxx/inflight_requests.h index a8d4ea2dd..cf4957146 100644 --- a/cpp/include/ucxx/inflight_requests.h +++ b/cpp/include/ucxx/inflight_requests.h @@ -132,7 +132,7 @@ class InflightRequests { * further requests inflight or canceling remain. */ void remove(std::shared_ptr request, - GenericCallbackUserFunction callbackFunction = nullptr); + VoidCallbackUserFunction callbackFunction = nullptr); /** * @brief Issue cancelation of all inflight requests and clear the internal container. @@ -152,7 +152,7 @@ class InflightRequests { * * @returns The total number of canceled requests. */ - size_t cancelAll(GenericCallbackUserFunction callbackFunction = nullptr); + size_t cancelAll(VoidCallbackUserFunction callbackFunction = nullptr); /** * @brief Releases the internally-tracked containers. diff --git a/cpp/include/ucxx/typedefs.h b/cpp/include/ucxx/typedefs.h index 4772c92fc..b642360e8 100644 --- a/cpp/include/ucxx/typedefs.h +++ b/cpp/include/ucxx/typedefs.h @@ -247,6 +247,6 @@ typedef const std::string SerializedRemoteKey; * requests remain, and for endpoint-driven cancelation completion in * `ucxx::Endpoint::cancelInflightRequests()`. 
*/ -typedef std::function GenericCallbackUserFunction; +typedef std::function VoidCallbackUserFunction; } // namespace ucxx diff --git a/cpp/src/endpoint.cpp b/cpp/src/endpoint.cpp index cdd89ce31..6bab9992f 100644 --- a/cpp/src/endpoint.cpp +++ b/cpp/src/endpoint.cpp @@ -419,7 +419,7 @@ void Endpoint::removeInflightRequest(std::shared_ptr request) _inflightRequests->remove(request, _cancelInflightCallback); } -size_t Endpoint::cancelInflightRequests(GenericCallbackUserFunction callback) +size_t Endpoint::cancelInflightRequests(VoidCallbackUserFunction callback) { _cancelInflightCallbackOriginal = callback; diff --git a/cpp/src/inflight_requests.cpp b/cpp/src/inflight_requests.cpp index 674f333d3..042e7c74f 100644 --- a/cpp/src/inflight_requests.cpp +++ b/cpp/src/inflight_requests.cpp @@ -66,7 +66,7 @@ static std::unique_ptr findAndRemove(InflightRequestsMap* r } void InflightRequests::remove(std::shared_ptr request, - GenericCallbackUserFunction cancelInflightCallback) + VoidCallbackUserFunction cancelInflightCallback) { do { std::scoped_lock localLock{_mutex}; @@ -164,7 +164,7 @@ size_t InflightRequests::getInflightSize() return inflightSize; } -size_t InflightRequests::cancelAll(GenericCallbackUserFunction cancelInflightCallback) +size_t InflightRequests::cancelAll(VoidCallbackUserFunction cancelInflightCallback) { size_t total = 0; From 7bbf55d3f9d05726240ea647db559df0f7471014 Mon Sep 17 00:00:00 2001 From: Peter Andreas Entschev Date: Thu, 2 Apr 2026 04:18:02 -0700 Subject: [PATCH 36/39] Fix deadlock during InflightRequests::remove --- cpp/src/inflight_requests.cpp | 21 ++++++++++++++++----- 1 file changed, 16 insertions(+), 5 deletions(-) diff --git a/cpp/src/inflight_requests.cpp b/cpp/src/inflight_requests.cpp index 042e7c74f..b217c7b76 100644 --- a/cpp/src/inflight_requests.cpp +++ b/cpp/src/inflight_requests.cpp @@ -3,6 +3,7 @@ * SPDX-License-Identifier: BSD-3-Clause */ #include +#include #include #include @@ -69,7 +70,7 @@ void 
InflightRequests::remove(std::shared_ptr request, VoidCallbackUserFunction cancelInflightCallback) { do { - std::scoped_lock localLock{_mutex}; + std::unique_lock localLock{_mutex}; int result = std::try_lock(_trackedRequests->_cancelMutex, _trackedRequests->_mutex); /** @@ -97,10 +98,14 @@ void InflightRequests::remove(std::shared_ptr request, _trackedRequests->_inflight.size() + _trackedRequests->_canceling.size(); /** - * Unlock `_mutex` before calling the user callback to prevent deadlocks in case the - * user callback happens to register another inflight request. + * Unlock the outer mutex before calling the user callback to prevent deadlocks in + * case the user callback happens to register another inflight request. Use + * `unique_lock` so we do not double-unlock when this scope ends. + * + * `std::try_lock` locked `_cancelMutex` then `_trackedRequests->_mutex`; unlock both + * in reverse order before returning. */ - _mutex.unlock(); + localLock.unlock(); try { if (cancelInflightCallback && trackedRequestsCount == 0) { ucxx_trace("ucxx::InflightRequests::%s: %p, calling user cancel inflight callback", @@ -108,12 +113,18 @@ void InflightRequests::remove(std::shared_ptr request, this); cancelInflightCallback(); } + _trackedRequests->_mutex.unlock(); _trackedRequests->_cancelMutex.unlock(); return; } catch (const std::exception& e) { ucxx_warn("Exception in callback: %s", e.what()); + _trackedRequests->_mutex.unlock(); + _trackedRequests->_cancelMutex.unlock(); + throw; + } catch (...) 
{ + _trackedRequests->_mutex.unlock(); _trackedRequests->_cancelMutex.unlock(); - throw(e); + throw; } } } while (true); From 0172a2da2fea95bb05e5b1464a5d0dbe2a187e3f Mon Sep 17 00:00:00 2001 From: Peter Andreas Entschev Date: Thu, 2 Apr 2026 13:50:12 -0700 Subject: [PATCH 37/39] Fix that fixes first part --- cpp/include/ucxx/inflight_requests.h | 2 ++ cpp/src/endpoint.cpp | 14 ++++++++++---- cpp/src/inflight_requests.cpp | 6 +++++- 3 files changed, 17 insertions(+), 5 deletions(-) diff --git a/cpp/include/ucxx/inflight_requests.h b/cpp/include/ucxx/inflight_requests.h index 392d5eb1c..9db9f7cf9 100644 --- a/cpp/include/ucxx/inflight_requests.h +++ b/cpp/include/ucxx/inflight_requests.h @@ -4,6 +4,7 @@ */ #pragma once +#include #include #include #include @@ -42,6 +43,7 @@ class InflightRequests { std::unordered_set> _canceling{}; std::mutex _mutex{}; + std::atomic _cancelAllInProgress{false}; public: /** diff --git a/cpp/src/endpoint.cpp b/cpp/src/endpoint.cpp index 6bab9992f..2b524dfde 100644 --- a/cpp/src/endpoint.cpp +++ b/cpp/src/endpoint.cpp @@ -423,11 +423,17 @@ size_t Endpoint::cancelInflightRequests(VoidCallbackUserFunction callback) { _cancelInflightCallbackOriginal = callback; - // Wrapper responsible for deregistering after callback is called. + /** + * Wrapper responsible for deregistering before callback is called. The member + * `_cancelInflightCallback` is moved into a local (`selfCopy`) before invoking the + * user callback so that any new inflight requests created during the callback (e.g., the + * close request) will NOT re-trigger this callback when they complete and are removed. + * `selfCopy` keeps the wrapper lambda alive for the duration of this invocation. 
+ */ _cancelInflightCallback = [this]() { - if (_cancelInflightCallbackOriginal != nullptr) _cancelInflightCallbackOriginal(); - _cancelInflightCallback = nullptr; - _cancelInflightCallbackOriginal = nullptr; + auto selfCopy = std::move(_cancelInflightCallback); + auto callback = std::exchange(_cancelInflightCallbackOriginal, nullptr); + if (callback) callback(); }; auto canceled = _inflightRequests->cancelAll(_cancelInflightCallback); diff --git a/cpp/src/inflight_requests.cpp b/cpp/src/inflight_requests.cpp index 1b94648c4..d5d4f2929 100644 --- a/cpp/src/inflight_requests.cpp +++ b/cpp/src/inflight_requests.cpp @@ -34,7 +34,9 @@ void InflightRequests::remove(const std::shared_ptr& request, std::lock_guard lock(_mutex); _inflight.erase(request); _canceling.erase(request); - if (callbackFunction && _inflight.empty() && _canceling.empty()) { shouldCallback = true; } + if (!_cancelAllInProgress && callbackFunction && _inflight.empty() && _canceling.empty()) { + shouldCallback = true; + } } if (shouldCallback) { @@ -78,9 +80,11 @@ size_t InflightRequests::cancelAll(VoidCallbackUserFunction callbackFunction) ucxx_debug("ucxx::InflightRequests::%s, canceling %lu requests", __func__, total); + _cancelAllInProgress = true; for (auto& r : toCancel) { if (r) r->cancel(); } + _cancelAllInProgress = false; bool shouldCallback = false; { From 29c991f1140f72c8afe8f6c3b0601592e16e333b Mon Sep 17 00:00:00 2001 From: Peter Andreas Entschev Date: Fri, 3 Apr 2026 10:46:01 -0700 Subject: [PATCH 38/39] Supposedly fix std::bad_variant exception --- cpp/src/request.cpp | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/cpp/src/request.cpp b/cpp/src/request.cpp index e21cf6446..297605b64 100644 --- a/cpp/src/request.cpp +++ b/cpp/src/request.cpp @@ -261,8 +261,6 @@ void Request::setStatus(ucs_status_t status) { std::lock_guard lock(_mutex); - removeInflightRequest(); - ucxx_trace_req_f(_ownerString.c_str(), this, _request, @@ -271,15 +269,20 @@ void 
Request::setStatus(ucs_status_t status) status, ucs_status_string(status)); - if (_status != UCS_INPROGRESS) - ucxx_error( + if (_status != UCS_INPROGRESS) { + ucxx_warn( "ucxx::Request: %p, setStatus called with status: %d (%s) but status: %d (%s) was " - "already set", + "already set, ignoring", this, status, ucs_status_string(status), _status, ucs_status_string(_status)); + return; + } + + removeInflightRequest(); + _status = status; if (_enablePythonFuture) { From 4fe4ba82a964af087b1a1548b94b253da110c846 Mon Sep 17 00:00:00 2001 From: Peter Andreas Entschev Date: Sun, 5 Apr 2026 08:53:07 -0700 Subject: [PATCH 39/39] Fix check of in progress AM requests --- cpp/src/request_am.cpp | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/cpp/src/request_am.cpp b/cpp/src/request_am.cpp index bc22142f7..dbea8ea4b 100644 --- a/cpp/src/request_am.cpp +++ b/cpp/src/request_am.cpp @@ -365,14 +365,10 @@ ucs_status_t RequestAm::recvCallback(void* arg, buf->data(), length); - if (req->isCompleted()) { - // The request completed/errored immediately - ucs_status_t s = UCS_PTR_STATUS(status); - recvAmMessage->callback(nullptr, s); - - return s; - } else { - // The request will be handled by the callback + if (UCS_PTR_IS_PTR(status)) { + // The request is in progress, register for the completion callback. The + // recvAmMessage must be stored in the map to prevent use-after-free when + // the UCX callback fires. recvAmMessage->setUcpRequest(status); amData->_registerInflightRequest(req); @@ -382,6 +378,12 @@ ucs_status_t RequestAm::recvCallback(void* arg, } return UCS_INPROGRESS; + } else { + // The request completed/errored immediately + ucs_status_t s = UCS_PTR_RAW_STATUS(status); + recvAmMessage->callback(nullptr, s); + + return s; } } else { buf = amData->_allocators.at(UCS_MEMORY_TYPE_HOST)(length);