From f1d00ef7b4b6065562dbdd71b5af25893b005e07 Mon Sep 17 00:00:00 2001 From: Peter Andreas Entschev Date: Thu, 6 Apr 2023 06:55:57 -0700 Subject: [PATCH 1/5] Add endpoint error handler flag to Python async listener constructor --- python/ucxx/_lib/libucxx.pyx | 4 +++- python/ucxx/_lib_async/application_context.py | 1 + 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/python/ucxx/_lib/libucxx.pyx b/python/ucxx/_lib/libucxx.pyx index 98ae657ad..8e2853d0c 100644 --- a/python/ucxx/_lib/libucxx.pyx +++ b/python/ucxx/_lib/libucxx.pyx @@ -1080,7 +1080,7 @@ cdef void _listener_callback(ucp_conn_request_h conn_request, void *args) with g cb_data['cb_func']( ( cb_data['listener'].create_endpoint_from_conn_request( - int(conn_request), True + int(conn_request), cb_data['endpoint_error_handling'] ) if 'listener' in cb_data else int(conn_request) ), @@ -1112,6 +1112,7 @@ cdef class UCXListener(): cls, UCXWorker worker, uint16_t port, + bint endpoint_error_handling, cb_func, tuple cb_args=None, dict cb_kwargs=None, @@ -1130,6 +1131,7 @@ cdef class UCXListener(): "cb_func": cb_func, "cb_args": cb_args, "cb_kwargs": cb_kwargs, + "endpoint_error_handling": endpoint_error_handling, } with nogil: diff --git a/python/ucxx/_lib_async/application_context.py b/python/ucxx/_lib_async/application_context.py index 2091d531c..be66cce5b 100644 --- a/python/ucxx/_lib_async/application_context.py +++ b/python/ucxx/_lib_async/application_context.py @@ -205,6 +205,7 @@ def create_listener( ucx_api.UCXListener.create( worker=self.worker, port=port, + endpoint_error_handling=endpoint_error_handling, cb_func=_listener_handler, cb_args=( loop, From 8a9c91e33660ef99209d4c3682040145b56b57fd Mon Sep 17 00:00:00 2001 From: Peter Andreas Entschev Date: Thu, 6 Apr 2023 09:41:00 -0700 Subject: [PATCH 2/5] Set listener's endpoint error handler in Python tests/benchmarks --- python/ucxx/_lib/tests/test_cancel.py | 2 +- python/ucxx/_lib/tests/test_endpoint.py | 2 +- python/ucxx/_lib/tests/test_listener.py | 2 +- python/ucxx/_lib/tests/test_probe.py | 2 +- python/ucxx/_lib/tests/test_server_client.py | 2 +- python/ucxx/benchmarks/backends/ucxx_core.py | 5 ++++- 6 files changed, 9 insertions(+), 6 deletions(-) diff --git a/python/ucxx/_lib/tests/test_cancel.py b/python/ucxx/_lib/tests/test_cancel.py index 865aaf691..196259abe 100644 --- a/python/ucxx/_lib/tests/test_cancel.py +++ b/python/ucxx/_lib/tests/test_cancel.py @@ -27,7 +27,7 @@ def _listener_handler(conn_request): ep[0] = listener.create_endpoint_from_conn_request(conn_request, True) listener = ucx_api.UCXListener.create( - worker=worker, port=0, cb_func=_listener_handler + worker=worker, port=0, endpoint_error_handling=True, cb_func=_listener_handler ) queue.put(listener.port) diff --git a/python/ucxx/_lib/tests/test_endpoint.py b/python/ucxx/_lib/tests/test_endpoint.py index 6717c1bab..997023a67 100644 --- a/python/ucxx/_lib/tests/test_endpoint.py +++ b/python/ucxx/_lib/tests/test_endpoint.py @@ -42,7 +42,7 @@ def _listener_handler(conn_request): listener_finished[0] = True listener = ucx_api.UCXListener.create( - worker=worker, port=0, cb_func=_listener_handler + worker=worker, port=0, endpoint_error_handling=True, cb_func=_listener_handler ) queue.put(listener.port) diff --git a/python/ucxx/_lib/tests/test_listener.py b/python/ucxx/_lib/tests/test_listener.py index d814372ba..d4fe1c86e 100644 --- a/python/ucxx/_lib/tests/test_listener.py +++ b/python/ucxx/_lib/tests/test_listener.py @@ -12,7 +12,7 @@ def _listener_handler(conn_request): pass listener = ucx_api.UCXListener.create( - worker=worker, port=0, cb_func=_listener_handler + worker=worker, port=0, endpoint_error_handling=True, cb_func=_listener_handler ) assert isinstance(listener.ip, str) and listener.ip diff --git a/python/ucxx/_lib/tests/test_probe.py b/python/ucxx/_lib/tests/test_probe.py index 5cb742236..1652993c1 100644 --- a/python/ucxx/_lib/tests/test_probe.py +++ b/python/ucxx/_lib/tests/test_probe.py @@ -32,7 +32,7 @@ def _listener_handler(conn_request): ) listener = ucx_api.UCXListener.create( - worker=worker, port=0, cb_func=_listener_handler + worker=worker, port=0, endpoint_error_handling=True, cb_func=_listener_handler ) queue.put(listener.port) diff --git a/python/ucxx/_lib/tests/test_server_client.py b/python/ucxx/_lib/tests/test_server_client.py index d7cbcf52a..4426e3938 100644 --- a/python/ucxx/_lib/tests/test_server_client.py +++ b/python/ucxx/_lib/tests/test_server_client.py @@ -60,7 +60,7 @@ def _listener_handler(conn_request): ep[0] = listener.create_endpoint_from_conn_request(conn_request, True) listener = ucx_api.UCXListener.create( - worker=worker, port=0, cb_func=_listener_handler + worker=worker, port=0, endpoint_error_handling=True, cb_func=_listener_handler ) put_queue.put(listener.port) diff --git a/python/ucxx/benchmarks/backends/ucxx_core.py b/python/ucxx/benchmarks/backends/ucxx_core.py index 9be1becd3..33cd2245a 100644 --- a/python/ucxx/benchmarks/backends/ucxx_core.py +++ b/python/ucxx/benchmarks/backends/ucxx_core.py @@ -145,7 +145,10 @@ def _listener_handler(conn_request): ep = listener.create_endpoint_from_conn_request(conn_request, True) listener = ucx_api.UCXListener.create( - worker=worker, port=self.args.port or 0, cb_func=_listener_handler + worker=worker, + port=self.args.port or 0, + endpoint_error_handling=True, + cb_func=_listener_handler, ) self.queue.put(listener.port) From 897f4c7f48d1d8bffa78dee372b01b5c0e9b0035 Mon Sep 17 00:00:00 2001 From: Peter Andreas Entschev Date: Wed, 19 Apr 2023 13:04:21 -0700 Subject: [PATCH 3/5] Fix flags for disabled endpoint error handling --- cpp/src/endpoint.cpp | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/cpp/src/endpoint.cpp b/cpp/src/endpoint.cpp index 92eda07ef..47499d832 100644 --- a/cpp/src/endpoint.cpp +++ b/cpp/src/endpoint.cpp @@ -66,8 +66,9 @@ std::shared_ptr createEndpointFromHostname(std::shared_ptr wor struct hostent* hostname = gethostbyname(ipAddress.c_str()); if (hostname == nullptr) throw ucxx::Error(std::string("Invalid IP address or hostname")); - params->field_mask = UCP_EP_PARAM_FIELD_FLAGS | UCP_EP_PARAM_FIELD_SOCK_ADDR | - UCP_EP_PARAM_FIELD_ERR_HANDLING_MODE | UCP_EP_PARAM_FIELD_ERR_HANDLER; + params->field_mask = UCP_EP_PARAM_FIELD_FLAGS | UCP_EP_PARAM_FIELD_SOCK_ADDR; + if (endpointErrorHandling) + params->field_mask |= UCP_EP_PARAM_FIELD_ERR_HANDLING_MODE | UCP_EP_PARAM_FIELD_ERR_HANDLER; params->flags = UCP_EP_PARAMS_FLAGS_CLIENT_SERVER; if (ucxx::utils::sockaddr_set(¶ms->sockaddr, hostname->h_name, port)) throw std::bad_alloc(); @@ -82,8 +83,9 @@ std::shared_ptr createEndpointFromConnRequest(std::shared_ptr(new ucp_ep_params_t); - params->field_mask = UCP_EP_PARAM_FIELD_FLAGS | UCP_EP_PARAM_FIELD_CONN_REQUEST | - UCP_EP_PARAM_FIELD_ERR_HANDLING_MODE | UCP_EP_PARAM_FIELD_ERR_HANDLER; + params->field_mask = UCP_EP_PARAM_FIELD_FLAGS | UCP_EP_PARAM_FIELD_CONN_REQUEST; + if (endpointErrorHandling) + params->field_mask |= UCP_EP_PARAM_FIELD_ERR_HANDLING_MODE | UCP_EP_PARAM_FIELD_ERR_HANDLER; params->flags = UCP_EP_PARAMS_FLAGS_NO_LOOPBACK; params->conn_request = connRequest; @@ -101,8 +103,9 @@ std::shared_ptr createEndpointFromWorkerAddress(std::shared_ptr(new ucp_ep_params_t); - params->field_mask = UCP_EP_PARAM_FIELD_REMOTE_ADDRESS | UCP_EP_PARAM_FIELD_ERR_HANDLING_MODE | - UCP_EP_PARAM_FIELD_ERR_HANDLER; + params->field_mask = UCP_EP_PARAM_FIELD_REMOTE_ADDRESS; + if (endpointErrorHandling) + params->field_mask |= UCP_EP_PARAM_FIELD_ERR_HANDLING_MODE | UCP_EP_PARAM_FIELD_ERR_HANDLER; params->address = address->getHandle(); return std::shared_ptr(new Endpoint(worker, std::move(params), endpointErrorHandling)); @@ -122,7 +125,7 @@ void Endpoint::close() ucxx_debug("Endpoint %p canceled %lu requests", _handle, canceled); // Close the endpoint - unsigned closeMode = UCP_EP_CLOSE_MODE_FORCE; + unsigned closeMode = UCP_EP_CLOSE_MODE_FLUSH; if (_endpointErrorHandling && _callbackData->status != UCS_OK) { // We force close endpoint if endpoint error handling is enabled and // the endpoint status is not UCS_OK From 531fb2b41ebbc85df795d95dd4d023f2534d7474 Mon Sep 17 00:00:00 2001 From: Peter Andreas Entschev Date: Fri, 11 Aug 2023 02:05:23 -0700 Subject: [PATCH 4/5] Fix construction of `ucp_ep_params_t` --- cpp/src/endpoint.cpp | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/cpp/src/endpoint.cpp b/cpp/src/endpoint.cpp index d8dfb4688..a3de98bd4 100644 --- a/cpp/src/endpoint.cpp +++ b/cpp/src/endpoint.cpp @@ -86,11 +86,9 @@ std::shared_ptr createEndpointFromHostname(std::shared_ptr wor auto info = ucxx::utils::get_addrinfo(ipAddress.c_str(), port); ucp_ep_params_t params = {.field_mask = UCP_EP_PARAM_FIELD_FLAGS | UCP_EP_PARAM_FIELD_SOCK_ADDR, .flags = UCP_EP_PARAMS_FLAGS_CLIENT_SERVER, - .sockaddr = {.addrlen = info->ai_addrlen, .addr = info->ai_addr}}; + .sockaddr = {.addr = info->ai_addr, .addrlen = info->ai_addrlen}}; if (endpointErrorHandling) - params->field_mask |= UCP_EP_PARAM_FIELD_ERR_HANDLING_MODE | UCP_EP_PARAM_FIELD_ERR_HANDLER; - - if (ucxx::utils::sockaddr_set(¶ms->sockaddr, hostname->h_name, port)) throw std::bad_alloc(); + params.field_mask |= UCP_EP_PARAM_FIELD_ERR_HANDLING_MODE | UCP_EP_PARAM_FIELD_ERR_HANDLER; return std::shared_ptr(new Endpoint(worker, ¶ms, endpointErrorHandling)); } @@ -107,7 +105,7 @@ std::shared_ptr createEndpointFromConnRequest(std::shared_ptrfield_mask |= UCP_EP_PARAM_FIELD_ERR_HANDLING_MODE | UCP_EP_PARAM_FIELD_ERR_HANDLER; + params.field_mask |= UCP_EP_PARAM_FIELD_ERR_HANDLING_MODE | UCP_EP_PARAM_FIELD_ERR_HANDLER; return std::shared_ptr(new Endpoint(listener, ¶ms, endpointErrorHandling)); } @@ -124,7 +122,7 @@ std::shared_ptr createEndpointFromWorkerAddress(std::shared_ptrgetHandle()}; if (endpointErrorHandling) - params->field_mask |= UCP_EP_PARAM_FIELD_ERR_HANDLING_MODE | UCP_EP_PARAM_FIELD_ERR_HANDLER; + params.field_mask |= UCP_EP_PARAM_FIELD_ERR_HANDLING_MODE | UCP_EP_PARAM_FIELD_ERR_HANDLER; return std::shared_ptr(new Endpoint(worker, ¶ms, endpointErrorHandling)); } From 9b554b5b47e4ba940eede5136c53058fa8eaf113 Mon Sep 17 00:00:00 2001 From: Peter Andreas Entschev Date: Fri, 25 Apr 2025 07:30:48 -0700 Subject: [PATCH 5/5] Remove duplicate test file --- python/ucxx/_lib/tests/test_probe.py | 93 ---------------------------- 1 file changed, 93 deletions(-) delete mode 100644 python/ucxx/_lib/tests/test_probe.py diff --git a/python/ucxx/_lib/tests/test_probe.py b/python/ucxx/_lib/tests/test_probe.py deleted file mode 100644 index 94c7801e2..000000000 --- a/python/ucxx/_lib/tests/test_probe.py +++ /dev/null @@ -1,93 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. -# SPDX-License-Identifier: BSD-3-Clause - -import multiprocessing as mp - -from ucxx._lib import libucxx as ucx_api -from ucxx._lib.arr import Array -from ucxx.testing import terminate_process, wait_requests - -mp = mp.get_context("spawn") - -WireupMessage = bytearray(b"wireup") -DataMessage = bytearray(b"0" * 10) - - -def _server_probe(queue): - """Server that probes and receives message after client disconnected. - - Note that since it is illegal to call progress() in callback functions, - we keep a reference to the endpoint after the listener callback has - terminated, this way we can progress even after Python blocking calls. - """ - ctx = ucx_api.UCXContext(feature_flags=(ucx_api.Feature.TAG,)) - worker = ucx_api.UCXWorker(ctx) - - # Keep endpoint to be used from outside the listener callback - ep = [None] - - def _listener_handler(conn_request): - ep[0] = listener.create_endpoint_from_conn_request( - conn_request, endpoint_error_handling=True - ) - - listener = ucx_api.UCXListener.create( - worker=worker, port=0, endpoint_error_handling=True, cb_func=_listener_handler - ) - queue.put(listener.port) - - while ep[0] is None: - worker.progress() - - ep = ep[0] - - # Ensure wireup and inform client before it can disconnect - wireup = bytearray(len(WireupMessage)) - wait_requests(worker, "blocking", ep.tag_recv(Array(wireup), tag=0)) - queue.put("wireup completed") - - # Ensure client has disconnected -- endpoint is not alive anymore - while ep.is_alive() is True: - worker.progress() - - # Probe/receive message even after the remote endpoint has disconnected - while worker.tag_probe(0) is False: - worker.progress() - received = bytearray(len(DataMessage)) - wait_requests(worker, "blocking", ep.tag_recv(Array(received), tag=0)) - - assert wireup == WireupMessage - assert received == DataMessage - - -def _client_probe(queue): - ctx = ucx_api.UCXContext(feature_flags=(ucx_api.Feature.TAG,)) - worker = ucx_api.UCXWorker(ctx) - port = queue.get() - ep = ucx_api.UCXEndpoint.create( - worker, - "127.0.0.1", - port, - endpoint_error_handling=True, - ) - - requests = [ - ep.tag_send(Array(WireupMessage), tag=0), - ep.tag_send(Array(DataMessage), tag=0), - ] - wait_requests(worker, "blocking", requests) - - # Wait for wireup before disconnecting - assert queue.get() == "wireup completed" - - -def test_message_probe(): - queue = mp.Queue() - server = mp.Process(target=_server_probe, args=(queue,)) - server.start() - client = mp.Process(target=_client_probe, args=(queue,)) - client.start() - client.join(timeout=10) - server.join(timeout=10) - terminate_process(client) - terminate_process(server)