diff --git a/docs/ucxx/source/api.rst b/docs/ucxx/source/api.rst index e5c6e1e87..eb9a76d85 100644 --- a/docs/ucxx/source/api.rst +++ b/docs/ucxx/source/api.rst @@ -26,6 +26,10 @@ API _lib_async.Endpoint.close_after_n_recv _lib_async.Endpoint.get_ucp_endpoint _lib_async.Endpoint.get_ucp_worker + _lib_async.Endpoint.am_recv + _lib_async.Endpoint.am_recv_with_header + _lib_async.Endpoint.am_send + _lib_async.Endpoint.am_send_iov _lib_async.Endpoint.recv _lib_async.Endpoint.send _lib_async.Endpoint.uid diff --git a/python/ucxx/ucxx/_lib/libucxx.pyx b/python/ucxx/ucxx/_lib/libucxx.pyx index 32cb5ece5..c8c26e2fe 100644 --- a/python/ucxx/ucxx/_lib/libucxx.pyx +++ b/python/ucxx/ucxx/_lib/libucxx.pyx @@ -25,7 +25,6 @@ from libcpp.memory cimport ( shared_ptr, static_pointer_cast, ) -from libcpp.optional cimport nullopt from libcpp.string cimport string from libcpp.string_view cimport string_view from libcpp.utility cimport move @@ -340,6 +339,11 @@ class Feature(enum.Enum): AM = UCP_FEATURE_AM +class PythonAmSendMemoryTypePolicy(enum.Enum): + FallbackToHost = AmSendMemoryTypePolicy.FallbackToHost + ErrorOnUnsupported = AmSendMemoryTypePolicy.ErrorOnUnsupported + + class PythonRequestNotifierWaitState(enum.Enum): Ready = RequestNotifierWaitState.Ready Timeout = RequestNotifierWaitState.Timeout @@ -994,6 +998,20 @@ cdef class UCXRequest(): elif bufType == BufferType.Host: return _get_host_buffer(buf.get()) + @property + def recv_header(self) -> bytes: + """Get the user-defined header from an AM receive request. + + Returns the opaque header bytes sent by the peer. Returns empty bytes + if no user header was sent or for non-AM requests. + """ + cdef string header + + with nogil: + header = self._request.get().getRecvHeader() + + return header + def is_completed(self) -> bool: warnings.warn( "UCXRequest.is_completed() is deprecated and will soon be removed, " @@ -1451,21 +1469,99 @@ cdef class UCXEndpoint(): return ep_matched - def am_send(self, Array arr) -> UCXRequest: + def am_send( + self, Array arr, memory_type_policy=None, user_header=None + ) -> UCXRequest: cdef void* buf = arr.ptr cdef size_t nbytes = arr.nbytes cdef bint cuda_array = arr.cuda cdef shared_ptr[Request] req + cdef AmSendParams params + cdef bytes user_header_bytes + cdef const char* user_header_ptr if not self._context_feature_flags & Feature.AM.value: raise ValueError("UCXContext must be created with `Feature.AM`") + params.memoryType = ( + UCS_MEMORY_TYPE_CUDA if cuda_array else UCS_MEMORY_TYPE_HOST + ) + if memory_type_policy is not None: + params.memoryTypePolicy = ( + memory_type_policy.value + ) + if user_header is not None: + if not isinstance(user_header, bytes): + raise TypeError("user_header must be bytes") + user_header_bytes = user_header + user_header_ptr = user_header_bytes + params.setUserHeader(user_header_ptr, len(user_header_bytes)) + with nogil: req = self._endpoint.get().amSend( buf, nbytes, - UCS_MEMORY_TYPE_CUDA if cuda_array else UCS_MEMORY_TYPE_HOST, - nullopt, + params, + self._enable_python_future + ) + + return UCXRequest(&req, self._enable_python_future) + + def am_send_iov( + self, list arrays, memory_type_policy=None, user_header=None + ) -> UCXRequest: + cdef vector[ucp_dt_iov_t] iov_vec + cdef ucp_dt_iov_t entry + cdef shared_ptr[Request] req + cdef AmSendParams params + cdef bytes user_header_bytes + cdef const char* user_header_ptr + + if not self._context_feature_flags & Feature.AM.value: + raise ValueError("UCXContext must be created with `Feature.AM`") + + if len(arrays) == 0: + raise ValueError("IOV segment list must not be empty") + + cdef list wrapped = [] + for buf in arrays: + if not isinstance(buf, Array): + buf = Array(buf) + wrapped.append(buf) + + # Validate all segments have the same memory type + cdef bint first_cuda = (wrapped[0]).cuda + for arr_obj in wrapped: + if (arr_obj).cuda != first_cuda: + raise ValueError( + "All IOV segments must have the same memory type " + "(all host or all CUDA)" + ) + + for arr_obj in wrapped: + entry.buffer = (arr_obj).ptr + entry.length = (arr_obj).nbytes + iov_vec.push_back(entry) + + params.datatype = UCP_DATATYPE_IOV + params.memoryType = ( + UCS_MEMORY_TYPE_CUDA if first_cuda else UCS_MEMORY_TYPE_HOST + ) + if memory_type_policy is not None: + params.memoryTypePolicy = ( + memory_type_policy.value + ) + if user_header is not None: + if not isinstance(user_header, bytes): + raise TypeError("user_header must be bytes") + user_header_bytes = user_header + user_header_ptr = user_header_bytes + params.setUserHeader(user_header_ptr, len(user_header_bytes)) + + with nogil: + req = self._endpoint.get().amSend( + move(iov_vec), + params, self._enable_python_future ) diff --git a/python/ucxx/ucxx/_lib/tests/test_server_client.py b/python/ucxx/ucxx/_lib/tests/test_server_client.py index 414889185..0f3c8da28 100644 --- a/python/ucxx/ucxx/_lib/tests/test_server_client.py +++ b/python/ucxx/ucxx/_lib/tests/test_server_client.py @@ -17,9 +17,9 @@ WireupMessageSize = 10 -def _send(ep, api, message): +def _send(ep, api, message, memory_type_policy=None): if api == "am": - return ep.am_send(message) + return ep.am_send(message, memory_type_policy=memory_type_policy) elif api == "stream": return ep.stream_send(message) else: @@ -177,6 +177,535 @@ def _echo_client(transfer_api, msg_size, progress_mode, port): worker.stop_progress_thread() +def _echo_server_am_params( + get_queue, put_queue, msg_size, progress_mode, memory_type_policy +): + """Server that echoes AM messages using the AmSendParams code path.""" + feature_flags = (ucx_api.Feature.WAKEUP, ucx_api.Feature.AM) + ctx = ucx_api.UCXContext(feature_flags=feature_flags) + worker = ucx_api.UCXWorker(ctx) + + if progress_mode == "blocking": + worker.init_blocking_progress_mode() + else: + worker.start_progress_thread() + + ep = [None] + + def _listener_handler(conn_request): + ep[0] = listener.create_endpoint_from_conn_request(conn_request, True) + + listener = ucx_api.UCXListener.create( + worker=worker, port=0, cb_func=_listener_handler + ) + put_queue.put(listener.port) + + while ep[0] is None: + if progress_mode == "blocking": + worker.progress() + + msg = Array(bytearray(msg_size)) + requests = [ep[0].am_recv()] + wait_requests(worker, progress_mode, requests) + msg = Array(requests[0].recv_buffer) + requests = [ep[0].am_send(msg, memory_type_policy=memory_type_policy)] + wait_requests(worker, progress_mode, requests) + + while True: + try: + get_queue.get(block=True, timeout=0.1) + except QueueIsEmpty: + continue + else: + break + + if progress_mode == "thread": + worker.stop_progress_thread() + + +def _echo_client_am_params(msg_size, progress_mode, memory_type_policy, port): + """Client that sends and receives AM messages using AmSendParams.""" + feature_flags = (ucx_api.Feature.WAKEUP, ucx_api.Feature.AM) + ctx = ucx_api.UCXContext(feature_flags=feature_flags) + worker = ucx_api.UCXWorker(ctx) + + if progress_mode == "blocking": + worker.init_blocking_progress_mode() + else: + worker.start_progress_thread() + + ep = ucx_api.UCXEndpoint.create( + worker, + "127.0.0.1", + port, + endpoint_error_handling=True, + ) + + if progress_mode == "blocking": + worker.progress() + + send_msg = bytes(os.urandom(msg_size)) + + requests = [ + ep.am_send(Array(send_msg), memory_type_policy=memory_type_policy), + ep.am_recv(), + ] + wait_requests(worker, progress_mode, requests) + + recv_msg = requests[1].recv_buffer + assert bytes(recv_msg) == send_msg + + if progress_mode == "thread": + worker.stop_progress_thread() + + +@pytest.mark.parametrize( + "memory_type_policy", + [None, ucx_api.PythonAmSendMemoryTypePolicy.FallbackToHost], +) +@pytest.mark.parametrize("msg_size", [10, 2**24]) +@pytest.mark.parametrize("progress_mode", ["blocking", "thread"]) +def test_server_client_am_params(msg_size, progress_mode, memory_type_policy): + put_queue, get_queue = mp.Queue(), mp.Queue() + server_error_q = mp.Queue() + client_error_q = mp.Queue() + server = mp.Process( + target=run_in_subprocess, + args=( + _echo_server_am_params, + server_error_q, + put_queue, + get_queue, + msg_size, + progress_mode, + memory_type_policy, + ), + ) + server.start() + port = get_queue.get() + client = mp.Process( + target=run_in_subprocess, + args=( + _echo_client_am_params, + client_error_q, + msg_size, + progress_mode, + memory_type_policy, + port, + ), + ) + client.start() + client.join(timeout=60) + terminate_process(client, error_queue=client_error_q) + put_queue.put("Finished") + server.join(timeout=10) + terminate_process(server, error_queue=server_error_q) + + +def _echo_server_am_iov(get_queue, put_queue, msg_size, progress_mode): + """Server that receives an IOV AM message and echoes it back.""" + feature_flags = (ucx_api.Feature.WAKEUP, ucx_api.Feature.AM) + ctx = ucx_api.UCXContext(feature_flags=feature_flags) + worker = ucx_api.UCXWorker(ctx) + + if progress_mode == "blocking": + worker.init_blocking_progress_mode() + else: + worker.start_progress_thread() + + ep = [None] + + def _listener_handler(conn_request): + ep[0] = listener.create_endpoint_from_conn_request(conn_request, True) + + listener = ucx_api.UCXListener.create( + worker=worker, port=0, cb_func=_listener_handler + ) + put_queue.put(listener.port) + + while ep[0] is None: + if progress_mode == "blocking": + worker.progress() + + # Receive the IOV message (arrives as a single contiguous buffer) + requests = [ep[0].am_recv()] + wait_requests(worker, progress_mode, requests) + msg = Array(requests[0].recv_buffer) + # Echo back as a regular contiguous send + requests = [ep[0].am_send(msg)] + wait_requests(worker, progress_mode, requests) + + while True: + try: + get_queue.get(block=True, timeout=0.1) + except QueueIsEmpty: + continue + else: + break + + if progress_mode == "thread": + worker.stop_progress_thread() + + +def _echo_client_am_iov(msg_size, progress_mode, port): + """Client that sends an IOV AM message and receives the echo.""" + feature_flags = (ucx_api.Feature.WAKEUP, ucx_api.Feature.AM) + ctx = ucx_api.UCXContext(feature_flags=feature_flags) + worker = ucx_api.UCXWorker(ctx) + + if progress_mode == "blocking": + worker.init_blocking_progress_mode() + else: + worker.start_progress_thread() + + ep = ucx_api.UCXEndpoint.create( + worker, + "127.0.0.1", + port, + endpoint_error_handling=True, + ) + + if progress_mode == "blocking": + worker.progress() + + send_msg = bytes(os.urandom(msg_size)) + + # Split the message into two segments for IOV send + mid = msg_size // 2 + seg1 = Array(send_msg[:mid]) + seg2 = Array(send_msg[mid:]) + + requests = [ + ep.am_send_iov([seg1, seg2]), + ep.am_recv(), + ] + wait_requests(worker, progress_mode, requests) + + recv_msg = requests[1].recv_buffer + assert bytes(recv_msg) == send_msg + + if progress_mode == "thread": + worker.stop_progress_thread() + + +@pytest.mark.parametrize("msg_size", [10, 2**24]) +@pytest.mark.parametrize("progress_mode", ["blocking", "thread"]) +def test_server_client_am_iov(msg_size, progress_mode): + put_queue, get_queue = mp.Queue(), mp.Queue() + server_error_q = mp.Queue() + client_error_q = mp.Queue() + server = mp.Process( + target=run_in_subprocess, + args=( + _echo_server_am_iov, + server_error_q, + put_queue, + get_queue, + msg_size, + progress_mode, + ), + ) + server.start() + port = get_queue.get() + client = mp.Process( + target=run_in_subprocess, + args=( + _echo_client_am_iov, + client_error_q, + msg_size, + progress_mode, + port, + ), + ) + client.start() + client.join(timeout=60) + terminate_process(client, error_queue=client_error_q) + put_queue.put("Finished") + server.join(timeout=10) + terminate_process(server, error_queue=server_error_q) + + +def _user_header_server(get_queue, put_queue, msg_size, progress_mode): + """Server that receives an AM message with user header and echoes both back.""" + feature_flags = (ucx_api.Feature.WAKEUP, ucx_api.Feature.AM) + ctx = ucx_api.UCXContext(feature_flags=feature_flags) + worker = ucx_api.UCXWorker(ctx) + + if progress_mode == "blocking": + worker.init_blocking_progress_mode() + else: + worker.start_progress_thread() + + ep = [None] + + def _listener_handler(conn_request): + ep[0] = listener.create_endpoint_from_conn_request(conn_request, True) + + listener = ucx_api.UCXListener.create( + worker=worker, port=0, cb_func=_listener_handler + ) + put_queue.put(listener.port) + + while ep[0] is None: + if progress_mode == "blocking": + worker.progress() + + # Receive the message and its user header + requests = [ep[0].am_recv()] + wait_requests(worker, progress_mode, requests) + recv_header = requests[0].recv_header + msg = Array(requests[0].recv_buffer) + + # Echo back with the same user header + requests = [ep[0].am_send(msg, user_header=recv_header)] + wait_requests(worker, progress_mode, requests) + + while True: + try: + get_queue.get(block=True, timeout=0.1) + except QueueIsEmpty: + continue + else: + break + + if progress_mode == "thread": + worker.stop_progress_thread() + + +def _user_header_client(msg_size, progress_mode, port): + """Client that sends AM with user header and validates the echo.""" + feature_flags = (ucx_api.Feature.WAKEUP, ucx_api.Feature.AM) + ctx = ucx_api.UCXContext(feature_flags=feature_flags) + worker = ucx_api.UCXWorker(ctx) + + if progress_mode == "blocking": + worker.init_blocking_progress_mode() + else: + worker.start_progress_thread() + + ep = ucx_api.UCXEndpoint.create( + worker, + "127.0.0.1", + port, + endpoint_error_handling=True, + ) + + if progress_mode == "blocking": + worker.progress() + + send_msg = bytes(os.urandom(msg_size)) + user_header = b"test-header-\x00\x01\xff" + + requests = [ + ep.am_send(Array(send_msg), user_header=user_header), + ep.am_recv(), + ] + wait_requests(worker, progress_mode, requests) + + recv_msg = requests[1].recv_buffer + recv_header = requests[1].recv_header + assert bytes(recv_msg) == send_msg + assert recv_header == user_header + + if progress_mode == "thread": + worker.stop_progress_thread() + + +def _user_header_iov_client(msg_size, progress_mode, port): + """Client that sends IOV AM with user header and validates the echo.""" + feature_flags = (ucx_api.Feature.WAKEUP, ucx_api.Feature.AM) + ctx = ucx_api.UCXContext(feature_flags=feature_flags) + worker = ucx_api.UCXWorker(ctx) + + if progress_mode == "blocking": + worker.init_blocking_progress_mode() + else: + worker.start_progress_thread() + + ep = ucx_api.UCXEndpoint.create( + worker, + "127.0.0.1", + port, + endpoint_error_handling=True, + ) + + if progress_mode == "blocking": + worker.progress() + + send_msg = bytes(os.urandom(msg_size)) + user_header = b"iov-header-data" + + mid = msg_size // 2 + seg1 = Array(send_msg[:mid]) + seg2 = Array(send_msg[mid:]) + + requests = [ + ep.am_send_iov([seg1, seg2], user_header=user_header), + ep.am_recv(), + ] + wait_requests(worker, progress_mode, requests) + + recv_msg = requests[1].recv_buffer + recv_header = requests[1].recv_header + assert bytes(recv_msg) == send_msg + assert recv_header == user_header + + if progress_mode == "thread": + worker.stop_progress_thread() + + +def _empty_user_header_client(msg_size, progress_mode, port): + """Client that sends AM without user header and validates empty recv_header.""" + feature_flags = (ucx_api.Feature.WAKEUP, ucx_api.Feature.AM) + ctx = ucx_api.UCXContext(feature_flags=feature_flags) + worker = ucx_api.UCXWorker(ctx) + + if progress_mode == "blocking": + worker.init_blocking_progress_mode() + else: + worker.start_progress_thread() + + ep = ucx_api.UCXEndpoint.create( + worker, + "127.0.0.1", + port, + endpoint_error_handling=True, + ) + + if progress_mode == "blocking": + worker.progress() + + send_msg = bytes(os.urandom(msg_size)) + + requests = [ + ep.am_send(Array(send_msg)), + ep.am_recv(), + ] + wait_requests(worker, progress_mode, requests) + + recv_msg = requests[1].recv_buffer + recv_header = requests[1].recv_header + assert bytes(recv_msg) == send_msg + assert recv_header == b"" + + if progress_mode == "thread": + worker.stop_progress_thread() + + +@pytest.mark.parametrize("msg_size", [10, 2**24]) +@pytest.mark.parametrize("progress_mode", ["blocking", "thread"]) +def test_server_client_am_user_header(msg_size, progress_mode): + put_queue, get_queue = mp.Queue(), mp.Queue() + server_error_q = mp.Queue() + client_error_q = mp.Queue() + server = mp.Process( + target=run_in_subprocess, + args=( + _user_header_server, + server_error_q, + put_queue, + get_queue, + msg_size, + progress_mode, + ), + ) + server.start() + port = get_queue.get() + client = mp.Process( + target=run_in_subprocess, + args=( + _user_header_client, + client_error_q, + msg_size, + progress_mode, + port, + ), + ) + client.start() + client.join(timeout=60) + terminate_process(client, error_queue=client_error_q) + put_queue.put("Finished") + server.join(timeout=10) + terminate_process(server, error_queue=server_error_q) + + +@pytest.mark.parametrize("msg_size", [10, 2**24]) +@pytest.mark.parametrize("progress_mode", ["blocking", "thread"]) +def test_server_client_am_iov_user_header(msg_size, progress_mode): + put_queue, get_queue = mp.Queue(), mp.Queue() + server_error_q = mp.Queue() + client_error_q = mp.Queue() + server = mp.Process( + target=run_in_subprocess, + args=( + _user_header_server, + server_error_q, + put_queue, + get_queue, + msg_size, + progress_mode, + ), + ) + server.start() + port = get_queue.get() + client = mp.Process( + target=run_in_subprocess, + args=( + _user_header_iov_client, + client_error_q, + msg_size, + progress_mode, + port, + ), + ) + client.start() + client.join(timeout=60) + terminate_process(client, error_queue=client_error_q) + put_queue.put("Finished") + server.join(timeout=10) + terminate_process(server, error_queue=server_error_q) + + +@pytest.mark.parametrize("msg_size", [10, 2**24]) +@pytest.mark.parametrize("progress_mode", ["blocking", "thread"]) +def test_server_client_am_empty_user_header(msg_size, progress_mode): + """Test that recv_header is empty bytes when no user header is sent.""" + put_queue, get_queue = mp.Queue(), mp.Queue() + server_error_q = mp.Queue() + client_error_q = mp.Queue() + # Reuse the echo server that doesn't set user_header + server = mp.Process( + target=run_in_subprocess, + args=( + _echo_server_am_params, + server_error_q, + put_queue, + get_queue, + msg_size, + progress_mode, + None, + ), + ) + server.start() + port = get_queue.get() + client = mp.Process( + target=run_in_subprocess, + args=( + _empty_user_header_client, + client_error_q, + msg_size, + progress_mode, + port, + ), + ) + client.start() + client.join(timeout=60) + terminate_process(client, error_queue=client_error_q) + put_queue.put("Finished") + server.join(timeout=10) + terminate_process(server, error_queue=server_error_q) + + @pytest.mark.parametrize("transfer_api", ["am", "stream", "tag"]) @pytest.mark.parametrize("msg_size", [0, 10, 2**24]) @pytest.mark.parametrize("progress_mode", ["blocking", "thread"]) diff --git a/python/ucxx/ucxx/_lib/ucxx_api.pxd b/python/ucxx/ucxx/_lib/ucxx_api.pxd index f1ff85c65..eabc54f37 100644 --- a/python/ucxx/ucxx/_lib/ucxx_api.pxd +++ b/python/ucxx/ucxx/_lib/ucxx_api.pxd @@ -4,7 +4,7 @@ from posix cimport fcntl -from libc.stdint cimport int64_t, uint16_t, uint64_t +from libc.stdint cimport int64_t, uint16_t, uint32_t, uint64_t from libcpp cimport bool as cpp_bool from libcpp.functional cimport function from libcpp.memory cimport shared_ptr, unique_ptr @@ -56,6 +56,14 @@ cdef extern from "ucp/api/ucp.h" nogil: ctypedef uint64_t ucp_tag_t + ctypedef uint64_t ucp_datatype_t + + ctypedef struct ucp_dt_iov_t: + void* buffer + size_t length + + ucp_datatype_t UCP_DATATYPE_IOV + ctypedef struct ucp_tag_recv_info_t: pass @@ -197,6 +205,17 @@ cdef extern from "" namespace "ucxx" nogil: cdef cppclass AmReceiverCallbackInfo: pass + cdef enum class AmSendMemoryTypePolicy: + FallbackToHost + ErrorOnUnsupported + + cdef cppclass AmSendParams: + uint32_t flags + ucp_datatype_t datatype + ucs_memory_type_t memoryType + AmSendMemoryTypePolicy memoryTypePolicy + void setUserHeader(const void* data, size_t size) + # Using function[Buffer] here doesn't seem possible due to Cython bugs/limitations. # The workaround is to use a raw C function pointer and let it be parsed by the # compiler. @@ -310,6 +329,17 @@ cdef extern from "" namespace "ucxx" nogil: nullopt_t receiver_callback_info, bint enable_python_future ) except +raise_py_error + shared_ptr[Request] amSend( + const void* const buffer, + size_t length, + AmSendParams params, + bint enable_python_future + ) except +raise_py_error + shared_ptr[Request] amSend( + vector[ucp_dt_iov_t] iov, + AmSendParams params, + bint enable_python_future + ) except +raise_py_error shared_ptr[Request] amRecv( bint enable_python_future ) except +raise_py_error @@ -366,6 +396,7 @@ cdef extern from "" namespace "ucxx" nogil: void checkError() except +raise_py_error void* getFuture() except +raise_py_error shared_ptr[Buffer] getRecvBuffer() except +raise_py_error + string getRecvHeader() except +raise_py_error void cancel() diff --git a/python/ucxx/ucxx/_lib_async/endpoint.py b/python/ucxx/ucxx/_lib_async/endpoint.py index c5c61e6b3..839312b52 100644 --- a/python/ucxx/ucxx/_lib_async/endpoint.py +++ b/python/ucxx/ucxx/_lib_async/endpoint.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. # SPDX-License-Identifier: BSD-3-Clause @@ -167,7 +167,7 @@ async def close(self, period=10**10, max_attempts=1): await asyncio.sleep(0) self.abort(period=period, max_attempts=max_attempts) - async def am_send(self, buffer): + async def am_send(self, buffer, memory_type_policy=None, user_header=None): """Send `buffer` to connected peer via active messages. Parameters @@ -175,6 +175,12 @@ async def am_send(self, buffer): buffer: exposing the buffer protocol or array/cuda interface The buffer to send. Raise ValueError if buffer is smaller than nbytes. + memory_type_policy: PythonAmSendMemoryTypePolicy, optional + Policy controlling receiver-side allocation when no allocator is + registered for the sender's memory type. Default ``None`` uses + ``FallbackToHost``. + user_header: bytes, optional + Opaque user-defined header bytes to send alongside the message. """ self._ep.raise_on_error() if self.closed: @@ -196,7 +202,11 @@ async def am_send(self, buffer): self._send_count += 1 try: - request = self._ep.am_send(buffer) + request = self._ep.am_send( + buffer, + memory_type_policy=memory_type_policy, + user_header=user_header, + ) return await request.wait() except UCXCanceled as e: # If self._ep has already been closed and destroyed, we reraise the @@ -204,6 +214,51 @@ async def am_send(self, buffer): if self._ep is None: raise e + async def am_send_iov(self, buffers, memory_type_policy=None, user_header=None): + """Send multiple buffers as a single IOV active message. + + Parameters + ---------- + buffers: list + List of buffers exposing the buffer protocol or array/cuda + interface. All buffers must have the same memory type (all host + or all CUDA). + memory_type_policy: PythonAmSendMemoryTypePolicy, optional + Policy controlling receiver-side allocation when no allocator is + registered for the sender's memory type. Default ``None`` uses + ``FallbackToHost``. + user_header: bytes, optional + Opaque user-defined header bytes to send alongside the message. + """ + self._ep.raise_on_error() + if self.closed: + raise UCXCloseError("Endpoint closed") + if not (isinstance(buffers, list) or isinstance(buffers, tuple)): + raise ValueError("The `buffers` argument must be a `list` or `tuple`") + arrays = [Array(b) if not isinstance(b, Array) else b for b in buffers] + + if logger.isEnabledFor(logging.DEBUG): + log = "[AM Send IOV #%03d] ep: 0x%x, segments: %d, nbytes: %s" % ( + self._send_count, + self.uid, + len(arrays), + tuple([b.nbytes for b in arrays]), + ) + logger.debug(log) + + self._send_count += 1 + + try: + request = self._ep.am_send_iov( + arrays, + memory_type_policy=memory_type_policy, + user_header=user_header, + ) + return await request.wait() + except UCXCanceled as e: + if self._ep is None: + raise e + # @ucx_api.nvtx_annotate("UCXPY_SEND", color="green", domain="ucxpy") async def send(self, buffer, tag=None, force_tag=False): """Send `buffer` to connected peer. @@ -335,8 +390,8 @@ async def send_obj(self, obj, tag=None): await self.send(nbytes, tag=tag) await self.send(obj, tag=tag) - async def am_recv(self): - """Receive from connected peer via active messages.""" + async def _am_recv_request(self): + """Internal helper: receive AM request, return (buffer, request).""" if not self._ep.am_probe(): self._ep.raise_on_error() if self.closed: @@ -371,8 +426,37 @@ async def am_recv(self): and self._finished_recv_count >= self._close_after_n_recv ): self.abort() + return buffer, req + + async def am_recv(self): + """Receive from connected peer via active messages. + + Returns + ------- + buffer + The received data buffer. + + See Also + -------- + am_recv_with_header : Also returns the user-defined header. + """ + buffer, _ = await self._am_recv_request() return buffer + async def am_recv_with_header(self): + """Receive from connected peer via active messages, including + the user-defined header. + + Returns + ------- + tuple of (buffer, header) + - buffer: The received data buffer. + - header: bytes, the user-defined header sent by the peer. + Empty bytes if no user header was sent. + """ + buffer, req = await self._am_recv_request() + return buffer, req.recv_header + def tag_probe(self, tag=None, force_tag=False, remove=False): """Probe for tag messages without receiving them. diff --git a/python/ucxx/ucxx/_lib_async/tests/test_send_recv_am.py b/python/ucxx/ucxx/_lib_async/tests/test_send_recv_am.py index 1032d89ce..c9743201e 100644 --- a/python/ucxx/ucxx/_lib_async/tests/test_send_recv_am.py +++ b/python/ucxx/ucxx/_lib_async/tests/test_send_recv_am.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright (c) 2023-2025, NVIDIA CORPORATION. +# SPDX-FileCopyrightText: Copyright (c) 2023-2026, NVIDIA CORPORATION. # SPDX-License-Identifier: BSD-3-Clause import asyncio @@ -8,9 +8,11 @@ import pytest import ucxx +from ucxx._lib.libucxx import PythonAmSendMemoryTypePolicy from ucxx._lib_async.utils_test import wait_listener_client_handlers msg_sizes = [0] + [2**i for i in range(0, 25, 4)] +iov_msg_sizes = [10] + [2**i for i in range(4, 25, 4)] def _bytearray_assert_equal(a, b): @@ -54,10 +56,10 @@ def get_data(): return ret -def simple_server(size, recv): +def simple_server(size, recv, memory_type_policy=None): async def server(ep): recv = await ep.am_recv() - await ep.am_send(recv) + await ep.am_send(recv, memory_type_policy=memory_type_policy) await ep.close() return server @@ -67,14 +69,20 @@ async def server(ep): @pytest.mark.parametrize("size", msg_sizes) @pytest.mark.parametrize("recv_wait", [True, False]) @pytest.mark.parametrize("data", get_data()) -async def test_send_recv_am(size, recv_wait, data): +@pytest.mark.parametrize( + "memory_type_policy", + [None, PythonAmSendMemoryTypePolicy.FallbackToHost], +) +async def test_send_recv_am(size, recv_wait, data, memory_type_policy): rndv_thresh = 8192 ucxx.init(options={"RNDV_THRESH": str(rndv_thresh)}) msg = data["generator"](size) recv = [] - listener = ucxx.create_listener(simple_server(size, recv)) + listener = ucxx.create_listener( + simple_server(size, recv, memory_type_policy=memory_type_policy) + ) num_clients = 1 clients = [ await ucxx.create_endpoint(ucxx.get_address(), listener.port) @@ -85,7 +93,9 @@ async def test_send_recv_am(size, recv_wait, data): # ep.am_recv call will have to wait, rather than return # immediately as receive data is already available. await asyncio.sleep(1) - await asyncio.gather(*(c.am_send(msg) for c in clients)) + await asyncio.gather( + *(c.am_send(msg, memory_type_policy=memory_type_policy) for c in clients) + ) recv_msgs = await asyncio.gather(*(c.am_recv() for c in clients)) for recv_msg in recv_msgs: @@ -98,3 +108,132 @@ async def test_send_recv_am(size, recv_wait, data): await asyncio.gather(*(c.close() for c in clients)) await wait_listener_client_handlers(listener) + + +def simple_user_header_server(user_header_echo): + async def server(ep): + recv_buffer, recv_header = await ep.am_recv_with_header() + # Append before am_send so the client cannot observe the list before this + # (asyncio scheduling). + user_header_echo.append(recv_header) + await ep.am_send(recv_buffer, user_header=recv_header) + await ep.close() + + return server + + +@pytest.mark.asyncio +@pytest.mark.parametrize("size", msg_sizes) +async def test_send_recv_am_user_header(size): + ucxx.init() + + msg = bytearray(b"m" * size) + user_header = b"test-header-\x00\x01\xff" + + server_received_headers = [] + listener = ucxx.create_listener(simple_user_header_server(server_received_headers)) + ep = await ucxx.create_endpoint(ucxx.get_address(), listener.port) + + await ep.am_send(msg, user_header=user_header) + recv_msg, recv_header = await ep.am_recv_with_header() + + assert bytes(recv_msg) == bytes(msg) + assert recv_header == user_header + + # Verify the server also received the header correctly + assert len(server_received_headers) == 1 + assert server_received_headers[0] == user_header + + await ep.close() + await wait_listener_client_handlers(listener) + + +@pytest.mark.asyncio +@pytest.mark.parametrize("size", msg_sizes) +async def test_send_recv_am_empty_user_header(size): + """Test that recv_header is empty bytes when no user header is sent.""" + ucxx.init() + + msg = bytearray(b"m" * size) + + listener = ucxx.create_listener(simple_server(size, [], memory_type_policy=None)) + ep = await ucxx.create_endpoint(ucxx.get_address(), listener.port) + + await ep.am_send(msg) + recv_msg, recv_header = await ep.am_recv_with_header() + + assert bytes(recv_msg) == bytes(msg) + assert recv_header == b"" + + await ep.close() + await wait_listener_client_handlers(listener) + + +def simple_iov_server(): + async def server(ep): + recv = await ep.am_recv() + await ep.am_send(recv) + await ep.close() + + return server + + +@pytest.mark.asyncio +@pytest.mark.parametrize("size", iov_msg_sizes) +async def test_send_recv_am_iov(size): + ucxx.init() + + msg = bytearray(b"m" * size) + mid = size // 2 + seg1 = msg[:mid] + seg2 = msg[mid:] + + listener = ucxx.create_listener(simple_iov_server()) + ep = await ucxx.create_endpoint(ucxx.get_address(), listener.port) + + await ep.am_send_iov([seg1, seg2]) + recv_msg = await ep.am_recv() + + assert bytes(recv_msg) == bytes(msg) + + await ep.close() + await wait_listener_client_handlers(listener) + + +def simple_iov_user_header_server(server_received_headers): + async def server(ep): + recv_buffer, recv_header = await ep.am_recv_with_header() + server_received_headers.append(recv_header) + await ep.am_send(recv_buffer, user_header=recv_header) + await ep.close() + + return server + + +@pytest.mark.asyncio +@pytest.mark.parametrize("size", iov_msg_sizes) +async def test_send_recv_am_iov_user_header(size): + ucxx.init() + + msg = bytearray(b"m" * size) + mid = size // 2 + seg1 = msg[:mid] + seg2 = msg[mid:] + user_header = b"iov-header-data" + + server_received_headers = [] + listener = ucxx.create_listener( + simple_iov_user_header_server(server_received_headers) + ) + ep = await ucxx.create_endpoint(ucxx.get_address(), listener.port) + + await ep.am_send_iov([seg1, seg2], user_header=user_header) + recv_msg, recv_header = await ep.am_recv_with_header() + + assert bytes(recv_msg) == bytes(msg) + assert recv_header == user_header + assert len(server_received_headers) == 1 + assert server_received_headers[0] == user_header + + await ep.close() + await wait_listener_client_handlers(listener)