From 6629e96837abf8817c78a739bad42dae158b6246 Mon Sep 17 00:00:00 2001 From: Peter Andreas Entschev Date: Mon, 11 Nov 2024 13:35:18 -0800 Subject: [PATCH] Switch to AM in `exchange_peer_info()` --- .../ucxx/_lib_async/application_context.py | 2 +- .../ucxx/_lib_async/exchange_peer_info.py | 22 +++++++++---------- python/ucxx/ucxx/_lib_async/listener.py | 2 +- 3 files changed, 13 insertions(+), 13 deletions(-) diff --git a/python/ucxx/ucxx/_lib_async/application_context.py b/python/ucxx/ucxx/_lib_async/application_context.py index 4a4883094..27e22c74e 100644 --- a/python/ucxx/ucxx/_lib_async/application_context.py +++ b/python/ucxx/ucxx/_lib_async/application_context.py @@ -365,7 +365,7 @@ async def create_endpoint( msg_tag=msg_tag, ctrl_tag=ctrl_tag, listener=False, - stream_timeout=exchange_peer_info_timeout, + timeout=exchange_peer_info_timeout, ) except UCXMessageTruncatedError as e: # A truncated message occurs if the remote endpoint closed before diff --git a/python/ucxx/ucxx/_lib_async/exchange_peer_info.py b/python/ucxx/ucxx/_lib_async/exchange_peer_info.py index 5995605eb..0492554db 100644 --- a/python/ucxx/ucxx/_lib_async/exchange_peer_info.py +++ b/python/ucxx/ucxx/_lib_async/exchange_peer_info.py @@ -13,28 +13,28 @@ logger = logging.getLogger("ucx") -async def exchange_peer_info(endpoint, msg_tag, ctrl_tag, listener, stream_timeout=5.0): +async def exchange_peer_info(endpoint, msg_tag, ctrl_tag, listener, timeout=5.0): """Help function that exchange endpoint information""" # Pack peer information incl. a checksum fmt = "QQQ" my_info = struct.pack(fmt, msg_tag, ctrl_tag, hash64bits(msg_tag, ctrl_tag)) - peer_info = bytearray(len(my_info)) my_info_arr = Array(my_info) - peer_info_arr = Array(peer_info) # Send/recv peer information. Notice, we force an `await` between the two # streaming calls (see ) if listener is True: - req = endpoint.stream_send(my_info_arr) - await asyncio.wait_for(req.wait(), timeout=stream_timeout) - req = endpoint.stream_recv(peer_info_arr) - await asyncio.wait_for(req.wait(), timeout=stream_timeout) + req = endpoint.am_send(my_info_arr) + await asyncio.wait_for(req.wait(), timeout=timeout) + req = endpoint.am_recv() + await asyncio.wait_for(req.wait(), timeout=timeout) + peer_info = req.recv_buffer else: - req = endpoint.stream_recv(peer_info_arr) - await asyncio.wait_for(req.wait(), timeout=stream_timeout) - req = endpoint.stream_send(my_info_arr) - await asyncio.wait_for(req.wait(), timeout=stream_timeout) + req = endpoint.am_recv() + await asyncio.wait_for(req.wait(), timeout=timeout) + peer_info = req.recv_buffer + req = endpoint.am_send(my_info_arr) + await asyncio.wait_for(req.wait(), timeout=timeout) # Unpacking and sanity check of the peer information ret = {} diff --git a/python/ucxx/ucxx/_lib_async/listener.py b/python/ucxx/ucxx/_lib_async/listener.py index 9531d86c3..50dd318b0 100644 --- a/python/ucxx/ucxx/_lib_async/listener.py +++ b/python/ucxx/ucxx/_lib_async/listener.py @@ -146,7 +146,7 @@ async def _listener_handler_coroutine( msg_tag=msg_tag, ctrl_tag=ctrl_tag, listener=True, - stream_timeout=exchange_peer_info_timeout, + timeout=exchange_peer_info_timeout, ) except UCXMessageTruncatedError: # A truncated message occurs if the remote endpoint closed before