diff --git a/cpp/include/ucxx/endpoint.h b/cpp/include/ucxx/endpoint.h index 7a3cd8604..94ab5daa6 100644 --- a/cpp/include/ucxx/endpoint.h +++ b/cpp/include/ucxx/endpoint.h @@ -60,9 +60,15 @@ class Endpoint : public Component { ///< that may run asynchronously on another thread. ucs_status_t _status{UCS_INPROGRESS}; ///< Endpoint status std::atomic _closing{false}; ///< Prevent calling close multiple concurrent times. + std::atomic _stopping{false}; ///< Signal whether endpoint is stopping. EndpointCloseCallbackUserFunction _closeCallback{nullptr}; ///< Close callback to call EndpointCloseCallbackUserData _closeCallbackArg{ nullptr}; ///< Argument to be passed to close callback + VoidCallbackUserFunction _cancelInflightCallback{ + nullptr}; ///< The wrapper to the callback registered via `cancelInflightRequests()` that will + ///< deregister once the callback is called. + VoidCallbackUserFunction _cancelInflightCallbackOriginal{ + nullptr}; ///< The original user callback registered via `cancelInflightRequests()` /** * @brief Private constructor of `ucxx::Endpoint`. @@ -271,9 +277,21 @@ class Endpoint : public Component { * progress the worker and check the result of `getCancelingSize()`, all requests are only * canceled when `getCancelingSize()` returns `0`. * + * Supports an optional callback function to be called exclusively if there are no + * more requests inflight or canceling. Be advised that before the callback is called the + * mutex that controls inflight requests is released to prevent deadlocks in case the + * callback happens to register a new inflight request, therefore there's no guarantee + * that another inflight request won't be registered between the time in which the mutex + * is released and the callback is executed, the user is thus responsible to prevent such + * situations and the use of `stop()` before `cancelInflightRequests()` is highly + * advisable. + * + * @param[in] callbackFunction function to be called upon termination and only if no + * further requests inflight or canceling remain. + * * @returns Number of requests that were scheduled for cancelation. */ - size_t cancelInflightRequests(); + size_t cancelInflightRequests(VoidCallbackUserFunction callbackFunction = nullptr); /** * @brief Check the number of inflight requests being canceled. @@ -286,6 +304,16 @@ class Endpoint : public Component { */ [[nodiscard]] size_t getCancelingSize() const; + /** + * @brief Check the number of inflight requests waiting for completion. + * + * Check the number of inflight requests that were posted but have not yet completed nor + * have been scheduled for cancelation. + * + * @returns Number of inflight requests that are waiting for completion. + */ + size_t getInflightSize() const; + /** * @brief Cancel inflight requests. * @@ -347,6 +375,8 @@ class Endpoint : public Component { * Python future is requested, the Python application must then await on this future to * ensure the transfer has completed. Requires UCXX Python support. * + * @throws ucxx::RejectedError if `stop()` was already called. + * * @note If a `callbackFunction` is specified, the lifetime of `callbackData` and of any * other objects used in the scope of `callbackFunction` must be guaranteed by the caller * until it executes or `isCompleted()` becomes true. The `callbackFunction` executes in @@ -467,6 +497,8 @@ class Endpoint : public Component { * Python future is requested, the Python application must then await on this future to * ensure the transfer has completed. Requires UCXX Python support. * + * @throws ucxx::RejectedError if `stop()` was already called. + * * @note If a `callbackFunction` is specified, the lifetime of `callbackData` and of any * other objects used in the scope of `callbackFunction` must be guaranteed by the caller * until it executes or `isCompleted()` becomes true. The `callbackFunction` executes in @@ -507,6 +539,8 @@ class Endpoint : public Component { * Python future is requested, the Python application must then await on this future to * ensure the transfer has completed. Requires UCXX Python support. * + * @throws ucxx::RejectedError if `stop()` was already called. + * * @note If a `callbackFunction` is specified, the lifetime of `callbackData` and of any * other objects used in the scope of `callbackFunction` must be guaranteed by the caller * until it executes or `isCompleted()` becomes true. The `callbackFunction` executes in @@ -549,6 +583,8 @@ class Endpoint : public Component { * Python future is requested, the Python application must then await on this future to * ensure the transfer has completed. Requires UCXX Python support. * + * @throws ucxx::RejectedError if `stop()` was already called. + * * @note If a `callbackFunction` is specified, the lifetime of `callbackData` and of any * other objects used in the scope of `callbackFunction` must be guaranteed by the caller * until it executes or `isCompleted()` becomes true. The `callbackFunction` executes in @@ -589,6 +625,8 @@ class Endpoint : public Component { * Python future is requested, the Python application must then await on this future to * ensure the transfer has completed. Requires UCXX Python support. * + * @throws ucxx::RejectedError if `stop()` was already called. + * * @note If a `callbackFunction` is specified, the lifetime of `callbackData` and of any * other objects used in the scope of `callbackFunction` must be guaranteed by the caller * until it executes or `isCompleted()` becomes true. The `callbackFunction` executes in @@ -631,6 +669,8 @@ class Endpoint : public Component { * Python future is requested, the Python application must then await on this future to * ensure the transfer has completed. Requires UCXX Python support. * + * @throws ucxx::RejectedError if `stop()` was already called. + * * @param[in] buffer a raw pointer to the data to be sent. * @param[in] length the size in bytes of the tag message to be sent. * @param[in] enablePythonFuture whether a python future should be created and @@ -640,7 +680,7 @@ class Endpoint : public Component { */ [[nodiscard]] std::shared_ptr streamSend(const void* const buffer, size_t length, - const bool enablePythonFuture); + const bool enablePythonFuture = false); /** * @brief Enqueue a stream receive operation. @@ -654,6 +694,8 @@ class Endpoint : public Component { * Python future is requested, the Python application must then await on this future to * ensure the transfer has completed. Requires UCXX Python support. * + * @throws ucxx::RejectedError if `stop()` was already called. + * * @param[in] buffer a raw pointer to pre-allocated memory where resulting * data will be stored. * @param[in] length the size in bytes of the tag message to be received. @@ -664,7 +706,7 @@ class Endpoint : public Component { */ [[nodiscard]] std::shared_ptr streamRecv(void* buffer, size_t length, - const bool enablePythonFuture); + const bool enablePythonFuture = false); /** * @brief Enqueue a tag send operation. @@ -678,6 +720,8 @@ class Endpoint : public Component { * Python future is requested, the Python application must then await on this future to * ensure the transfer has completed. Requires UCXX Python support. * + * @throws ucxx::RejectedError if `stop()` was already called. + * * @note If a `callbackFunction` is specified, the lifetime of `callbackData` and of any * other objects used in the scope of `callbackFunction` must be guaranteed by the caller * until it executes or `isCompleted()` becomes true. The `callbackFunction` executes in @@ -763,6 +807,7 @@ class Endpoint : public Component { * ensure the transfer has completed. Requires UCXX Python support. * * @throws std::runtime_error if sizes of `buffer`, `size` and `isCUDA` do not match. + * @throws ucxx::RejectedError if `stop()` was already called. * * @param[in] buffer a vector of raw pointers to the data frames to be sent. * @param[in] size a vector of size in bytes of each frame to be sent. @@ -780,7 +825,7 @@ class Endpoint : public Component { const std::vector& size, const std::vector& isCUDA, const Tag tag, - const bool enablePythonFuture); + const bool enablePythonFuture = false); /** * @brief Enqueue a multi-buffer tag receive operation. @@ -806,7 +851,7 @@ class Endpoint : public Component { */ [[nodiscard]] std::shared_ptr tagMultiRecv(const Tag tag, const TagMask tagMask, - const bool enablePythonFuture); + const bool enablePythonFuture = false); /** * @brief Enqueue a flush operation. @@ -857,7 +902,23 @@ class Endpoint : public Component { * * Enqueue a non-blocking endpoint close operation, which will close the endpoint without * requiring to destroy the object. This may be useful when other - * `std::shared_ptr` objects are still alive, such as inflight transfers. + * `std::shared_ptr` objects are still alive, such as inflight transfers, + * or the user wants to have more control over cancelation and closing order. + * + * @warning Unlike its `closeBlocking()` counterpart, this method does not cancel any + * inflight requests prior to submitting the UCP close request. Before scheduling the + * endpoint close request, the caller is advised to first call `stop()` to prevent new + * requests that require an active endpoint from being registered and once `stop()` is + * called, the user may call `cancelInflightRequests()` specifying a callback that can + * be used to submit a `close()` request, or may check for the number of inflight and + * canceling requests via `getInflightSize()` and `getCancelingSize()` methods, + * respectively, and issue the non-blocking `close()` of the worker once both return + * `0` or after a certain period of time has elapsed and the application cannot wait + * anymore for their completion. Note that `cancelInflightRequests()` callback is not + * guaranteed to be called, nor are `getCancelingSize()` and `getInflightSize()` + * guaranteed to go to `0` depending on the requests being handled, and thus the user is + * advised to provide a forceful termination mechanism in case the requests can never + * complete. * * This method returns a `std::shared` that can be later awaited and * checked for errors. This is a non-blocking operation, and the status of closing the @@ -885,11 +946,6 @@ class Endpoint : public Component { * in which case the callback will also execute immediately within the calling thread and * before the method returns. * - * @warning Unlike its `closeBlocking()` counterpart, this method does not cancel any - * inflight requests prior to submitting the UCP close request. Before scheduling the - * endpoint close request, the caller must first call `cancelInflightRequests()` and - * progress the worker until `getCancelingSize()` returns `0`. - * * @param[in] enablePythonFuture whether a python future should be created and * subsequently notified. * @param[in] callbackFunction user-defined callback function to call upon completion. @@ -928,6 +984,59 @@ class Endpoint : public Component { * if worker is running a progress thread and `period > 0`. */ void closeBlocking(uint64_t period = 0, uint64_t maxAttempts = 1); + + /** + * @brief Signal wish to close the endpoint, but does not close it. + * + * Signal the wish to close the endpoint without closing it such that no new requests can + * be issued that require the endpoint to complete. This method is useful when the user + * needs control of the non-blocking closing process with `close()`, and thus allow + * requests to complete before issuing the close request. Once this is called, the user + * may call `cancelInflightRequests()` specifying a callback that can be used to submit + * a `close()` request, or may check for the number of inflight and canceling requests + * via `getInflightSize()` and `getCancelingSize()` methods, respectively, and issue the + * non-blocking `close()` of the worker once both return `0` or after a certain period of + * time has elapsed and the application cannot wait anymore for their completion. Note + * that `cancelInflightRequests()` callback is not guaranteed to be called, nor are + * `getCancelingSize()` and `getInflightSize()` guaranteed to go to `0` depending on the + * requests being handled, and thus the user is advised to provide a forceful termination + * mechanism in case the requests can never complete. + * + * After this is called certain requests are not anymore accepted because they require a + * valid endpoint to complete. The requests that are not accepted are: + * + * - `amSend()` + * - `memGet()` + * - `memPut()` + * - `streamSend()` + * - `streamRecv()` + * - `tagSend()` + * - `tagMultiSend()` + * + * If the user attempts to call any of the methods above after this method a + * `ucxx::RejectError` exception is thrown. The user is also able to check whether this + * method was already called by checking the result of `isStopping()`. + * + * The following requests are handled by the underlying `ucxx::Worker` and only matched + * to the `ucxx::Endpoint` object and thus can still be called after the present method + * is called to allow draining them: + * + * - `amRecv()` + * - `tagRecv()` + * - `tagMultiRecv()` + */ + void stop(); + + /** + * @brief Check whether the endpoint was signaled the wish to close. + * + * Check whether the endpoint was signaled the wish to close after calling `stop()`. This + * result is useful to identify the stage where the endpoint finds itself, if this method + * returns `True`, the process to ultimately close the endpoint has begun. + * + * @return Whether the endpoint was signaled the wish to close. + */ + bool isStopping(); }; } // namespace ucxx diff --git a/cpp/include/ucxx/inflight_requests.h b/cpp/include/ucxx/inflight_requests.h index 0f9e13a03..9db9f7cf9 100644 --- a/cpp/include/ucxx/inflight_requests.h +++ b/cpp/include/ucxx/inflight_requests.h @@ -4,11 +4,14 @@ */ #pragma once +#include #include #include #include #include +#include + namespace ucxx { class Request; @@ -31,8 +34,8 @@ struct TrackedRequests { * Handle tracked requests, providing functionality so that its owner can modify those * requests, performing operations such as insertion, removal and cancelation. * - * Uses `std::unordered_map>` for O(1) amortized - * insert/remove that scales to thousands of concurrent inflight requests. + * Uses `std::unordered_set>` for O(1) amortized insert/remove that + * scales to thousands of concurrent inflight requests. */ class InflightRequests { private: @@ -40,6 +43,7 @@ class InflightRequests { std::unordered_set> _canceling{}; std::mutex _mutex{}; + std::atomic _cancelAllInProgress{false}; public: /** @@ -91,9 +95,19 @@ class InflightRequests { * be called when a request has completed and the `InflightRequests` owner does not need * to keep track of it anymore. * + * Supports an optional callback function to be called exclusively if there are no + * more requests inflight or canceling. Be advised that before the callback is called the + * mutex that controls inflight requests is released to prevent deadlocks in case the + * callback happens to register a new inflight request, therefore there's no guarantee + * that another inflight request won't be registered between the time in which the mutex + * is released and the callback is executed. + * * @param[in] request shared pointer to the request + * @param[in] callbackFunction function to be called upon termination and only if no + * further requests inflight or canceling remain. */ - void remove(const std::shared_ptr& request); + void remove(const std::shared_ptr& request, + VoidCallbackUserFunction callbackFunction = nullptr); /** * @brief Issue cancelation of all inflight requests and clear the internal container. @@ -101,9 +115,19 @@ class InflightRequests { * Issue cancelation of all inflight requests known to this object and clear the * internal container. The total number of canceled requests is returned. * + * Supports an optional callback function to be called exclusively if there are no + * more requests inflight or canceling. Be advised that before the callback is called the + * mutex that controls inflight requests is released to prevent deadlocks in case the + * callback happens to register a new inflight request, therefore there's no guarantee + * that another inflight request won't be registered between the time in which the mutex + * is released and the callback is executed. + * + * @param[in] callbackFunction function to be called upon termination and only if no + * further requests inflight or canceling remain. + * * @returns The total number of canceled requests. */ - size_t cancelAll(); + size_t cancelAll(VoidCallbackUserFunction callbackFunction = nullptr); /** * @brief Releases the internally-tracked containers. @@ -127,6 +151,16 @@ class InflightRequests { * @returns The count of requests that are in process of cancelation. */ [[nodiscard]] size_t getCancelingSize(); + + /** + * @brief Get count of inflight requests. + * + * Get the count of inflight requests that have not yet completed nor have been scheduled + * for cancelation. + * + * @returns The count of inflight requests. + */ + [[nodiscard]] size_t getInflightSize(); }; } // namespace ucxx diff --git a/cpp/include/ucxx/request.h b/cpp/include/ucxx/request.h index 6ad7607bd..6b8d8ed3d 100644 --- a/cpp/include/ucxx/request.h +++ b/cpp/include/ucxx/request.h @@ -16,6 +16,7 @@ #include #include #include +#include #define ucxx_trace_req_f(_owner, _req, _handle, _name, _message, ...) \ ucxx_trace_req("ucxx::Request: %p on %s, UCP handle: %p, op: %s, " _message, \ @@ -54,6 +55,10 @@ class Request : public Component { bool _enablePythonFuture{true}; ///< Whether Python future is enabled for this request RequestCallbackUserFunction _callback{nullptr}; ///< Completion callback RequestCallbackUserData _callbackData{nullptr}; ///< Completion callback data + DelayedSubmissionCallbackType + _cancelCallback{}; ///< Callback function to use when canceling in thread progress mode. + utils::CallbackNotifier _cancelCallbackNotifier{}; ///< Callback notifier to use when canceling + ///< in thread progress mode. /** * @brief Protected constructor of an abstract `ucxx::Request`. @@ -104,6 +109,27 @@ class Request : public Component { */ void setStatus(ucs_status_t status); + private: + /** + * @brief Remove reference to request from endpoint and worker. + * + * Remove the reference to the request from the endpoint and worker. This should be called + * when a request has completed and the parent `ucxx::Endpoint` or `ucxx::Worker` does not + * need to keep track of it anymore. This is called during `setStatus()` and also during + * `cancel()` in case the request was scheduled for cancelation while it completed, which + * may have duplicated the canceling request tracking. + */ + void removeInflightRequest(); + + /** + * @brief Implementation of the request cancelation. + * + * Cancel the request, called by `cancel()` which will submit it for execution from the + * worker progress thread when active as it is unsafe to do so from the application thread + * given it `ucp_request_cancel` requires the UCS spinlock. When no worker progress thread + */ + void cancelImpl(); + public: Request() = delete; Request(const Request&) = delete; diff --git a/cpp/include/ucxx/typedefs.h b/cpp/include/ucxx/typedefs.h index 63685fb3c..b642360e8 100644 --- a/cpp/include/ucxx/typedefs.h +++ b/cpp/include/ucxx/typedefs.h @@ -239,4 +239,14 @@ struct AmSendParams { */ typedef const std::string SerializedRemoteKey; +/** + * @brief A user-defined callback with no parameters or return value. + * + * Used where UCXX needs a simple notification hook, for example optional callbacks on + * `ucxx::InflightRequests::remove()` and `cancelAll()` when no inflight or canceling + * requests remain, and for endpoint-driven cancelation completion in + * `ucxx::Endpoint::cancelInflightRequests()`. + */ +typedef std::function VoidCallbackUserFunction; + } // namespace ucxx diff --git a/cpp/include/ucxx/utils/callback_notifier.h b/cpp/include/ucxx/utils/callback_notifier.h index 8256b4023..dccdec99b 100644 --- a/cpp/include/ucxx/utils/callback_notifier.h +++ b/cpp/include/ucxx/utils/callback_notifier.h @@ -1,7 +1,9 @@ /** - * SPDX-FileCopyrightText: Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. + * SPDX-FileCopyrightText: Copyright (c) 2023-2026, NVIDIA CORPORATION & AFFILIATES. * SPDX-License-Identifier: BSD-3-Clause */ +#pragma once + #include #include #include diff --git a/cpp/src/endpoint.cpp b/cpp/src/endpoint.cpp index 5d880c359..2b524dfde 100644 --- a/cpp/src/endpoint.cpp +++ b/cpp/src/endpoint.cpp @@ -402,9 +402,11 @@ std::shared_ptr Endpoint::registerInflightRequest(std::shared_ptr(request) == nullptr) { auto worker = ::ucxx::getWorker(_parent); worker->scheduleRequestCancel(_inflightRequests->release()); } @@ -414,10 +416,30 @@ std::shared_ptr Endpoint::registerInflightRequest(std::shared_ptr request) { - _inflightRequests->remove(request); + _inflightRequests->remove(request, _cancelInflightCallback); } -size_t Endpoint::cancelInflightRequests() { return _inflightRequests->cancelAll(); } +size_t Endpoint::cancelInflightRequests(VoidCallbackUserFunction callback) +{ + _cancelInflightCallbackOriginal = callback; + + /** + * Wrapper responsible for deregistering before callback is called. The member + * `_cancelInflightCallback` is moved into a local (`selfCopy`) before invoking the + * user callback so that any new inflight requests created during the callback (e.g., the + * close request) will NOT re-trigger this callback when they complete and are removed. + * `selfCopy` keeps the wrapper lambda alive for the duration of this invocation. + */ + _cancelInflightCallback = [this]() { + auto selfCopy = std::move(_cancelInflightCallback); + auto callback = std::exchange(_cancelInflightCallbackOriginal, nullptr); + if (callback) callback(); + }; + + auto canceled = _inflightRequests->cancelAll(_cancelInflightCallback); + + return canceled; +} size_t Endpoint::cancelInflightRequestsBlocking(uint64_t period, uint64_t maxAttempts) { @@ -458,6 +480,8 @@ size_t Endpoint::cancelInflightRequestsBlocking(uint64_t period, uint64_t maxAtt size_t Endpoint::getCancelingSize() const { return _inflightRequests->getCancelingSize(); } +size_t Endpoint::getInflightSize() const { return _inflightRequests->getInflightSize(); } + std::shared_ptr Endpoint::amSend( const void* const buffer, const size_t length, @@ -467,6 +491,8 @@ std::shared_ptr Endpoint::amSend( RequestCallbackUserFunction callbackFunction, RequestCallbackUserData callbackData) { + if (_stopping.load()) throw RejectedError("Endpoint is stopping."); + auto params = AmSendParams{}; params.memoryType = memoryType; params.receiverCallbackInfo = receiverCallbackInfo; @@ -481,6 +507,8 @@ std::shared_ptr Endpoint::amSend(const void* const buffer, RequestCallbackUserFunction callbackFunction, RequestCallbackUserData callbackData) { + if (_stopping.load()) throw RejectedError("Endpoint is stopping."); + auto endpoint = std::dynamic_pointer_cast(shared_from_this()); return registerInflightRequest(createRequestAm(endpoint, data::AmSend(buffer, length, params), @@ -495,6 +523,8 @@ std::shared_ptr Endpoint::amSend(std::vector iov, RequestCallbackUserFunction callbackFunction, RequestCallbackUserData callbackData) { + if (_stopping.load()) throw RejectedError("Endpoint is stopping."); + auto endpoint = std::dynamic_pointer_cast(shared_from_this()); return registerInflightRequest(createRequestAm(endpoint, data::AmSend(std::move(iov), params), @@ -520,6 +550,8 @@ std::shared_ptr Endpoint::memGet(void* buffer, RequestCallbackUserFunction callbackFunction, RequestCallbackUserData callbackData) { + if (_stopping.load()) throw RejectedError("Endpoint is stopping."); + auto endpoint = std::dynamic_pointer_cast(shared_from_this()); return registerInflightRequest(createRequestMem(endpoint, data::MemGet(buffer, length, remoteAddr, rkey), @@ -536,6 +568,8 @@ std::shared_ptr Endpoint::memGet(void* buffer, RequestCallbackUserFunction callbackFunction, RequestCallbackUserData callbackData) { + if (_stopping.load()) throw RejectedError("Endpoint is stopping."); + auto endpoint = std::dynamic_pointer_cast(shared_from_this()); return registerInflightRequest(createRequestMem( endpoint, @@ -554,6 +588,8 @@ std::shared_ptr Endpoint::memPut(const void* const buffer, RequestCallbackUserFunction callbackFunction, RequestCallbackUserData callbackData) { + if (_stopping.load()) throw RejectedError("Endpoint is stopping."); + auto endpoint = std::dynamic_pointer_cast(shared_from_this()); return registerInflightRequest(createRequestMem(endpoint, data::MemPut(buffer, length, remoteAddr, rkey), @@ -570,6 +606,8 @@ std::shared_ptr Endpoint::memPut(const void* const buffer, RequestCallbackUserFunction callbackFunction, RequestCallbackUserData callbackData) { + if (_stopping.load()) throw RejectedError("Endpoint is stopping."); + auto endpoint = std::dynamic_pointer_cast(shared_from_this()); return registerInflightRequest(createRequestMem( endpoint, @@ -584,6 +622,8 @@ std::shared_ptr Endpoint::streamSend(const void* const buffer, size_t length, const bool enablePythonFuture) { + if (_stopping.load()) throw RejectedError("Endpoint is stopping."); + auto endpoint = std::dynamic_pointer_cast(shared_from_this()); return registerInflightRequest( createRequestStream(endpoint, data::StreamSend(buffer, length), enablePythonFuture)); @@ -593,6 +633,8 @@ std::shared_ptr Endpoint::streamRecv(void* buffer, size_t length, const bool enablePythonFuture) { + if (_stopping.load()) throw RejectedError("Endpoint is stopping."); + auto endpoint = std::dynamic_pointer_cast(shared_from_this()); return registerInflightRequest( createRequestStream(endpoint, data::StreamReceive(buffer, length), enablePythonFuture)); @@ -605,6 +647,8 @@ std::shared_ptr Endpoint::tagSend(const void* const buffer, RequestCallbackUserFunction callbackFunction, RequestCallbackUserData callbackData) { + if (_stopping.load()) throw RejectedError("Endpoint is stopping."); + auto endpoint = std::dynamic_pointer_cast(shared_from_this()); return registerInflightRequest(createRequestTag(endpoint, data::TagSend(buffer, length, tag), @@ -635,6 +679,8 @@ std::shared_ptr Endpoint::tagMultiSend(const std::vector& const Tag tag, const bool enablePythonFuture) { + if (_stopping.load()) throw RejectedError("Endpoint is stopping."); + auto endpoint = std::dynamic_pointer_cast(shared_from_this()); return registerInflightRequest(createRequestTagMulti( endpoint, data::TagMultiSend(buffer, size, isCUDA, tag), enablePythonFuture)); @@ -660,4 +706,8 @@ std::shared_ptr Endpoint::flush(const bool enablePythonFuture, std::shared_ptr Endpoint::getWorker() { return ::ucxx::getWorker(_parent); } +void Endpoint::stop() { _stopping = true; } + +bool Endpoint::isStopping() { return _stopping.load(); } + } // namespace ucxx diff --git a/cpp/src/inflight_requests.cpp b/cpp/src/inflight_requests.cpp index 4c9b396e1..d5d4f2929 100644 --- a/cpp/src/inflight_requests.cpp +++ b/cpp/src/inflight_requests.cpp @@ -26,10 +26,24 @@ void InflightRequests::insert(const std::shared_ptr& request) _inflight.insert(request); } -void InflightRequests::remove(const std::shared_ptr& request) +void InflightRequests::remove(const std::shared_ptr& request, + VoidCallbackUserFunction callbackFunction) { - std::lock_guard lock(_mutex); - _inflight.erase(request); + bool shouldCallback = false; + { + std::lock_guard lock(_mutex); + _inflight.erase(request); + _canceling.erase(request); + if (!_cancelAllInProgress && callbackFunction && _inflight.empty() && _canceling.empty()) { + shouldCallback = true; + } + } + + if (shouldCallback) { + ucxx_trace( + "ucxx::InflightRequests::%s: %p, calling user cancel inflight callback", __func__, this); + callbackFunction(); + } } void InflightRequests::merge(TrackedRequests&& trackedRequests) @@ -41,7 +55,7 @@ void InflightRequests::merge(TrackedRequests&& trackedRequests) if (r) _canceling.insert(std::move(r)); } -size_t InflightRequests::cancelAll() +size_t InflightRequests::cancelAll(VoidCallbackUserFunction callbackFunction) { decltype(_inflight) toCancel; { @@ -50,19 +64,42 @@ size_t InflightRequests::cancelAll() } size_t total = toCancel.size(); - if (total == 0) return 0; + if (total == 0) { + bool shouldCallback = false; + { + std::lock_guard lock(_mutex); + if (callbackFunction && _inflight.empty() && _canceling.empty()) { shouldCallback = true; } + } + if (shouldCallback) { + ucxx_trace( + "ucxx::InflightRequests::%s: %p, calling user cancel inflight callback", __func__, this); + callbackFunction(); + } + return 0; + } ucxx_debug("ucxx::InflightRequests::%s, canceling %lu requests", __func__, total); + _cancelAllInProgress = true; for (auto& r : toCancel) { if (r) r->cancel(); } + _cancelAllInProgress = false; + bool shouldCallback = false; { std::lock_guard lock(_mutex); for (auto& r : toCancel) { - if (r && r->getStatus() == UCS_INPROGRESS) _canceling.insert(r); + if (r && r->getStatus() == UCS_INPROGRESS) + _canceling.insert(std::move(const_cast&>(r))); } + if (callbackFunction && _inflight.empty() && _canceling.empty()) { shouldCallback = true; } + } + + if (shouldCallback) { + ucxx_trace( + "ucxx::InflightRequests::%s: %p, calling user cancel inflight callback", __func__, this); + callbackFunction(); } return total; @@ -86,6 +123,12 @@ TrackedRequests InflightRequests::release() return result; } +size_t InflightRequests::getInflightSize() +{ + std::lock_guard lock(_mutex); + return _inflight.size(); +} + size_t InflightRequests::getCancelingSize() { std::lock_guard lock(_mutex); diff --git a/cpp/src/request.cpp b/cpp/src/request.cpp index 1551e4f23..297605b64 100644 --- a/cpp/src/request.cpp +++ b/cpp/src/request.cpp @@ -6,6 +6,7 @@ #include #include #include +#include #include #include @@ -68,10 +69,25 @@ Request::Request(std::shared_ptr endpointOrWorker, Request::~Request() { + if (_cancelCallback != nullptr) { + std::ignore = _cancelCallbackNotifier.wait(1000000000 /* 1s */); + _cancelCallback = nullptr; + } + if (UCS_PTR_IS_PTR(_request)) { + ucxx_warn("ucxx::Request (%s) freeing: %p", _operationName.c_str(), _request); + ucp_request_free(_request); + } ucxx_trace("ucxx::Request destroyed (%s): %p", _operationName.c_str(), this); } -void Request::cancel() +void Request::removeInflightRequest() +{ + auto requestPtr = std::dynamic_pointer_cast(shared_from_this()); + if (_endpoint != nullptr) _endpoint->removeInflightRequest(requestPtr); + _worker->removeInflightRequest(requestPtr); +} + +void Request::cancelImpl() { std::lock_guard lock(_mutex); if (_status == UCS_INPROGRESS) { @@ -85,8 +101,19 @@ void Request::cancel() status, ucs_status_string(status)); } else { - ucxx_trace_req_f(_ownerString.c_str(), this, _request, _operationName.c_str(), "canceling"); - if (_request != nullptr) ucp_request_cancel(_worker->getHandle(), _request); + if (_request != nullptr) { + ucxx_trace_req_f(_ownerString.c_str(), this, _request, _operationName.c_str(), "canceling"); + ucp_request_cancel(_worker->getHandle(), _request); + + /** + * Tag send requests cannot be canceled: https://github.com/openucx/ucx/issues/1162 + * This can be problematic for unmatched rendezvous tag send messages as it would + * otherwise not complete cancelation, so we forcefully "cancel" the requests which + * ultimately leads it reaching `ucp_request_free`. This currently causes the UCX + * warnings "was not returned to mpool ucp_requests", which is likely a UCX bug. + */ + if (_operationName == "tagSend") { setStatus(UCS_ERR_CANCELED); } + } } } else { ucxx_trace_req_f(_ownerString.c_str(), @@ -96,6 +123,35 @@ void Request::cancel() "already completed with status: %d (%s)", _status, ucs_status_string(_status)); + + /** + * Ensure the request is removed from the parent in case it got re-registered while it + * was completing. + */ + removeInflightRequest(); + } + + _cancelCallback = nullptr; +} + +void Request::cancel() +{ + if (_worker->isProgressThreadRunning()) { + _cancelCallback = [this]() { + /** + * FIXME: Check the callback hasn't run and object hasn't been destroyed. Long-term + * fix is to allow deregistering generic callbacks with the worker. + */ + if (_cancelCallback == nullptr) return; + + cancelImpl(); + _cancelCallbackNotifier.set(); + }; + // The Cancel callback is store in an attributed, thus we do not need to + // cancel it if it fails to run immediately. + std::ignore = _worker->registerGenericPre(_cancelCallback); + } else { + cancelImpl(); } } @@ -148,7 +204,10 @@ void Request::callback(void* request, ucs_status_t status) status, ucs_status_string(status)); - if (UCS_PTR_IS_PTR(_request)) ucp_request_free(request); + if (UCS_PTR_IS_PTR(_request)) { + ucp_request_free(request); + _request = nullptr; + } ucxx_trace_req_f(_ownerString.c_str(), this, _request, _operationName.c_str(), "completed"); setStatus(status); @@ -202,10 +261,6 @@ void Request::setStatus(ucs_status_t status) { std::lock_guard lock(_mutex); - auto requestPtr = std::dynamic_pointer_cast(shared_from_this()); - if (_endpoint != nullptr) _endpoint->removeInflightRequest(requestPtr); - _worker->removeInflightRequest(requestPtr); - ucxx_trace_req_f(_ownerString.c_str(), this, _request, @@ -214,15 +269,20 @@ void Request::setStatus(ucs_status_t status) status, ucs_status_string(status)); - if (_status != UCS_INPROGRESS) - ucxx_error( + if (_status != UCS_INPROGRESS) { + ucxx_warn( "ucxx::Request: %p, setStatus called with status: %d (%s) but status: %d (%s) was " - "already set", + "already set, ignoring", this, status, ucs_status_string(status), _status, ucs_status_string(_status)); + return; + } + + removeInflightRequest(); + _status = status; if (_enablePythonFuture) { diff --git a/cpp/src/request_am.cpp b/cpp/src/request_am.cpp index 3d4ad2dfb..dbea8ea4b 100644 --- a/cpp/src/request_am.cpp +++ b/cpp/src/request_am.cpp @@ -158,7 +158,9 @@ std::shared_ptr createRequestAm( callbackFunction, callbackData)); }; - return worker->getAmRecv(endpoint->getHandle(), createRequest); + auto req = worker->getAmRecv(endpoint->getHandle(), createRequest); + req->_endpoint = endpoint; + return req; }, }, requestData); @@ -363,14 +365,10 @@ ucs_status_t RequestAm::recvCallback(void* arg, buf->data(), length); - if (req->isCompleted()) { - // The request completed/errored immediately - ucs_status_t s = UCS_PTR_STATUS(status); - recvAmMessage->callback(nullptr, s); - - return s; - } else { - // The request will be handled by the callback + if (UCS_PTR_IS_PTR(status)) { + // The request is in progress, register for the completion callback. The + // recvAmMessage must be stored in the map to prevent use-after-free when + // the UCX callback fires. recvAmMessage->setUcpRequest(status); amData->_registerInflightRequest(req); @@ -380,6 +378,12 @@ ucs_status_t RequestAm::recvCallback(void* arg, } return UCS_INPROGRESS; + } else { + // The request completed/errored immediately + ucs_status_t s = UCS_PTR_RAW_STATUS(status); + recvAmMessage->callback(nullptr, s); + + return s; } } else { buf = amData->_allocators.at(UCS_MEMORY_TYPE_HOST)(length); diff --git a/cpp/src/request_endpoint_close.cpp b/cpp/src/request_endpoint_close.cpp index 3e1175824..629fe29ed 100644 --- a/cpp/src/request_endpoint_close.cpp +++ b/cpp/src/request_endpoint_close.cpp @@ -1,5 +1,5 @@ /** - * SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES. + * SPDX-FileCopyrightText: Copyright (c) 2024-2026, NVIDIA CORPORATION & AFFILIATES. * SPDX-License-Identifier: BSD-3-Clause */ #include "ucxx/request_data.h" @@ -71,7 +71,10 @@ void RequestEndpointClose::request() ucp_request_param_t param = { .op_attr_mask = UCP_OP_ATTR_FIELD_CALLBACK | UCP_OP_ATTR_FIELD_USER_DATA, .user_data = this}; - if (std::get(_requestData)._force) param.flags = UCP_EP_CLOSE_FLAG_FORCE; + if (std::get(_requestData)._force) { + param.op_attr_mask |= UCP_OP_ATTR_FIELD_FLAGS; + param.flags = UCP_EP_CLOSE_FLAG_FORCE; + } param.cb.send = endpointCloseCallback; if (_endpoint != nullptr) request = ucp_ep_close_nbx(_endpoint->getHandle(), ¶m); diff --git a/cpp/src/worker.cpp b/cpp/src/worker.cpp index 29701172c..a0d0da5d2 100644 --- a/cpp/src/worker.cpp +++ b/cpp/src/worker.cpp @@ -538,7 +538,7 @@ size_t Worker::cancelInflightRequests(uint64_t period, uint64_t maxAttempts) if (inflightRequestsToCancel->getCancelingSize() > 0) { std::lock_guard lock(_inflightRequestsMutex); - _inflightRequestsToCancel->merge(inflightRequestsToCancel->release()); + _inflightRequestsToCancel->merge(std::move(inflightRequestsToCancel->release())); } return canceled; diff --git a/cpp/tests/endpoint.cpp b/cpp/tests/endpoint.cpp index 16fa43708..08b76543f 100644 --- a/cpp/tests/endpoint.cpp +++ b/cpp/tests/endpoint.cpp @@ -1,16 +1,32 @@ /** - * SPDX-FileCopyrightText: Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. + * SPDX-FileCopyrightText: Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. * SPDX-License-Identifier: BSD-3-Clause */ +#include #include +#include +#include +#include #include +#include #include #include +#include +#include + +#include "include/utils.h" +#include "ucxx/exception.h" +#include "ucxx/typedefs.h" + namespace { +using ::testing::Combine; +using ::testing::ContainerEq; +using ::testing::Values; + class EndpointTest : public ::testing::Test { protected: std::shared_ptr _context{ @@ -27,12 +43,89 @@ class EndpointTest : public ::testing::Test { } }; +enum class TransferType { Am, Tag, Stream }; +typedef std::vector> RequestContainer; + +class EndpointCancelTest + : public ::testing::TestWithParam> { + protected: + std::shared_ptr _context{nullptr}; + std::shared_ptr _remoteContext{nullptr}; + std::shared_ptr _worker{nullptr}; + std::shared_ptr _remoteWorker{nullptr}; + ProgressMode _progressMode{}; + TransferType _transferType{}; + size_t _messageSize{}; + RequestContainer _requests{}; + std::vector _send{}, _recv{}; + bool _rndv{false}; + + virtual void SetUp() + { + std::tie(_progressMode, _transferType, _messageSize, _rndv) = GetParam(); + + _send.resize(_messageSize); + _recv.resize(_messageSize, 0); + std::iota(_send.begin(), _send.end(), 0); + + _context = ucxx::createContext({}, ucxx::Context::defaultFeatureFlags); + _remoteContext = ucxx::createContext({}, ucxx::Context::defaultFeatureFlags); + _worker = _context->createWorker(); + _remoteWorker = _remoteContext->createWorker(); + } + + RequestContainer buildPair(std::shared_ptr sendEp, + std::shared_ptr recvEp) + { + if (_transferType == TransferType::Tag) { + return RequestContainer{ + sendEp->tagSend(_send.data(), _send.size() * sizeof(int), ucxx::Tag{0}), + recvEp->tagRecv(_recv.data(), _recv.size() * sizeof(int), ucxx::Tag{0}, ucxx::TagMaskFull)}; + } else if (_transferType == TransferType::Am) { + return RequestContainer{ + sendEp->amSend(_send.data(), _send.size() * sizeof(int), UCS_MEMORY_TYPE_HOST), + recvEp->amRecv()}; + } else if (_transferType == TransferType::Stream) { + return RequestContainer{sendEp->streamSend(_send.data(), _send.size() * sizeof(int)), + recvEp->streamRecv(_recv.data(), _recv.size() * sizeof(int))}; + } + return RequestContainer{}; + } +}; + +static void wireup(std::shared_ptr ep1, + std::shared_ptr ep2, + std::function progressWorker) +{ + // wireup + std::vector wireupSend(1, 99); + std::vector wireupRecv(wireupSend.size(), 0); + + std::vector> wireupRequests; + wireupRequests.push_back( + ep1->tagSend(wireupSend.data(), wireupSend.size() * sizeof(int), ucxx::Tag{0})); + wireupRequests.push_back(ep2->tagRecv( + wireupRecv.data(), wireupRecv.size() * sizeof(int), ucxx::Tag{0}, ucxx::TagMaskFull)); + + while (!wireupRequests[0]->isCompleted() || !wireupRequests[1]->isCompleted()) + progressWorker(); + + ASSERT_EQ(wireupRequests[0]->getStatus(), UCS_OK); + ASSERT_EQ(wireupRequests[1]->getStatus(), UCS_OK); + ASSERT_THAT(wireupRecv, ContainerEq(wireupSend)); +} + +static size_t countIncomplete(const std::vector>& requests) +{ + return std::count_if(requests.begin(), requests.end(), [](auto r) { return !r->isCompleted(); }); +} + TEST_F(EndpointTest, HandleIsValid) { auto ep = _worker->createEndpointFromWorkerAddress(_worker->getAddress()); _worker->progress(); - ASSERT_TRUE(ep->getHandle() != nullptr); + ASSERT_NE(ep->getHandle(), nullptr); } TEST_F(EndpointTest, IsAlive) @@ -56,4 +149,190 @@ TEST_F(EndpointTest, IsAlive) ASSERT_FALSE(ep->isAlive()); } +TEST_F(EndpointTest, StoppingRejectRequests) +{ + auto progressWorker = getProgressFunction(_worker, ProgressMode::Blocking); + + auto ep = _worker->createEndpointFromWorkerAddress(_worker->getAddress()); + while (!ep->isAlive()) + progressWorker(); + + wireup(ep, ep, progressWorker); + + ep->stop(); + + ASSERT_TRUE(ep->isStopping()); + + std::vector tmp(10, 42); + + EXPECT_THROW( + static_cast(ep->amSend(tmp.data(), tmp.size() * sizeof(int), UCS_MEMORY_TYPE_HOST)), + ucxx::RejectedError); + EXPECT_THROW(static_cast(ep->tagSend(tmp.data(), tmp.size() * sizeof(int), ucxx::Tag{0})), + ucxx::RejectedError); + EXPECT_THROW(static_cast(ep->streamRecv(tmp.data(), tmp.size() * sizeof(int))), + ucxx::RejectedError); + EXPECT_THROW(static_cast(ep->streamSend(tmp.data(), tmp.size() * sizeof(int))), + ucxx::RejectedError); + + { + auto memoryHandle = _context->createMemoryHandle(tmp.size() * sizeof(int), nullptr); + + auto localRemoteKey = memoryHandle->createRemoteKey(); + auto serializedRemoteKey = localRemoteKey->serialize(); + auto remoteKey = ucxx::createRemoteKeyFromSerialized(ep, serializedRemoteKey); + + std::vector> requests; + EXPECT_THROW(static_cast(ep->memPut(tmp.data(), tmp.size() * sizeof(int), remoteKey)), + ucxx::RejectedError); + EXPECT_THROW( + static_cast(ep->memPut( + tmp.data(), tmp.size() * sizeof(int), remoteKey->getBaseAddress(), remoteKey->getHandle())), + ucxx::RejectedError); + EXPECT_THROW(static_cast(ep->memGet(tmp.data(), tmp.size() * sizeof(int), remoteKey)), + ucxx::RejectedError); + EXPECT_THROW( + static_cast(ep->memGet( + tmp.data(), tmp.size() * sizeof(int), remoteKey->getBaseAddress(), remoteKey->getHandle())), + ucxx::RejectedError); + } + + { + std::vector buffers{tmp.data()}; + std::vector sizes{tmp.size()}; + std::vector isCUDA{false}; + EXPECT_THROW(static_cast(ep->tagMultiSend(buffers, sizes, isCUDA, ucxx::Tag{0})), + ucxx::RejectedError); + } +} + +TEST_P(EndpointCancelTest, StoppingWaitCompletionThenCancel) +{ + if (_transferType == TransferType::Stream && _messageSize == 0) + GTEST_SKIP() << "Stream messages of size 0 are not supported."; + + // Get appropriate progress worker function depending on selected mode + auto progressWorker = getProgressFunction(_worker, _progressMode); + if (_progressMode == ProgressMode::ThreadPolling) + _worker->startProgressThread(true); + else if (_progressMode == ProgressMode::ThreadBlocking) + _worker->startProgressThread(false); + + // Create endpoint + auto ep = _worker->createEndpointFromWorkerAddress(_worker->getAddress()); + while (!ep->isAlive()) + progressWorker(); + + // Perform endpoint wireup + wireup(ep, ep, progressWorker); + + // Submit transfer requests + auto requests = buildPair(ep, ep); + + auto checkAfterSubmit = [this, &requests, &ep]() { + // Check requests completion statuses + if (_rndv) { + std::for_each( + requests.begin(), requests.end(), [](auto r) { ASSERT_FALSE(r->isCompleted()); }); + } else { + ASSERT_TRUE(requests[0]->isCompleted()); + // In thread progress mode it's not possible to determine completion time + if (_progressMode != ProgressMode::ThreadBlocking && + _progressMode != ProgressMode::ThreadPolling) + ASSERT_FALSE(requests[1]->isCompleted()); + } + + // Check no requests are being canceled + ASSERT_EQ(ep->getCancelingSize(), 0); + + // Check which requests are inflight or completed + if (_rndv) { + ASSERT_EQ(ep->getInflightSize(), requests.size()); + } else { + // Eager send request completes immediately, receive request still inflight, except + // for thread progress mode where it's not possible to determine completion time + if (_progressMode != ProgressMode::ThreadBlocking && + _progressMode != ProgressMode::ThreadPolling) + ASSERT_EQ(ep->getInflightSize(), 1); + } + }; + + // Check request statuses before stopping endpoint + checkAfterSubmit(); + + // Stop accepting new requests, except those handled by the worker + ep->stop(); + ASSERT_TRUE(ep->isStopping()); + + // Check that requests statuses haven't changed + checkAfterSubmit(); + + bool cancelationComplete = false; + auto cancelInflightCallback = [&ep, &progressWorker, &cancelationComplete]() { + // No more inflight or canceling requests, closing the endpoint is safe + auto close = ep->close(); + ASSERT_NE(close, nullptr); + while (!close->isCompleted()) + progressWorker(); + ASSERT_EQ(close->getStatus(), UCS_OK); + + cancelationComplete = true; + }; + + // Cancel inflight requests + ep->cancelInflightRequests(cancelInflightCallback); + ASSERT_EQ(ep->getCancelingSize(), countIncomplete(requests)); + ASSERT_EQ(ep->getInflightSize(), 0); + + // Wait for canceling requests to complete and `cancelInflightCallback` to run + while (!cancelationComplete) + progressWorker(); + + // `cancelInflightCallback` executed an the endpoint should be closed now. + ASSERT_FALSE(ep->isAlive()); + + // Check all requests have been canceled or completed + std::for_each(requests.begin(), requests.end(), [](auto r) { + auto status = r->getStatus(); + ASSERT_TRUE(status == UCS_ERR_CANCELED || status == UCS_OK); + }); + + // Check received message, if it wasn't canceled + if (requests[1]->getStatus() == UCS_OK) { + // Copy AM results back into a `std::vector` which can be checked with `ASSERT_THAT` + if (_transferType == TransferType::Am) { + auto recvBuffer = requests[1]->getRecvBuffer(); + std::copy(reinterpret_cast(recvBuffer->data()), + reinterpret_cast(recvBuffer->data()) + _recv.size(), + _recv.begin()); + } + + ASSERT_THAT(_recv, ContainerEq(_send)); + } + + // Check no more tracked requests exist + ASSERT_EQ(ep->getCancelingSize(), 0); + ASSERT_EQ(ep->getInflightSize(), 0); +} + +INSTANTIATE_TEST_SUITE_P(Eager, + EndpointCancelTest, + Combine(Values(ProgressMode::Polling, + ProgressMode::Blocking, + ProgressMode::ThreadPolling, + ProgressMode::ThreadBlocking), + Values(TransferType::Tag, TransferType::Am, TransferType::Stream), + Values(0, 1, 10), + Values(false))); + +INSTANTIATE_TEST_SUITE_P(Rndv, + EndpointCancelTest, + Combine(Values(ProgressMode::Polling, + ProgressMode::Blocking, + ProgressMode::ThreadPolling, + ProgressMode::ThreadBlocking), + Values(TransferType::Tag, TransferType::Am, TransferType::Stream), + Values(10485760, 104857600), + Values(true))); + } // namespace diff --git a/cpp/tests/include/utils.h b/cpp/tests/include/utils.h index b98d59731..34cddc087 100644 --- a/cpp/tests/include/utils.h +++ b/cpp/tests/include/utils.h @@ -1,5 +1,5 @@ /** - * SPDX-FileCopyrightText: Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. + * SPDX-FileCopyrightText: Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. * SPDX-License-Identifier: BSD-3-Clause */ #pragma once @@ -40,6 +40,15 @@ inline void waitRequests(std::shared_ptr worker, } } +template +inline void waitSingleRequest(const std::shared_ptr& request, + const std::function& progressWorker) +{ + while (!request->isCompleted()) + if (progressWorker) progressWorker(); + request->checkError(); +} + std::function getProgressFunction(std::shared_ptr worker, ProgressMode progressMode);