diff --git a/tests/test_send_recv_am.py b/tests/test_send_recv_am.py index 252d817ca..eb9d82cc8 100644 --- a/tests/test_send_recv_am.py +++ b/tests/test_send_recv_am.py @@ -114,3 +114,50 @@ async def test_send_recv_bytes(size, blocking_progress_mode, recv_wait, data): assert recv[0] == bytearray(msg.get()) else: data["validator"](recv[0], msg) + + +@pytest.mark.skipif( + not ucp._libs.ucx_api.is_am_supported(), reason="AM only supported in UCX >= 1.11" +) +@pytest.mark.asyncio +@pytest.mark.parametrize("size", msg_sizes) +@pytest.mark.parametrize("blocking_progress_mode", [True, False]) +@pytest.mark.parametrize("recv_wait", [True, False]) +@pytest.mark.parametrize("data", get_data()) +async def test_send_recv_bytes_callback(size, blocking_progress_mode, recv_wait, data): + rndv_thresh = 8192 + ucp.init( + options={"RNDV_THRESH": str(rndv_thresh)}, + blocking_progress_mode=blocking_progress_mode, + ) + + recv = [] + + async def _cb(recv_obj, exception, ep): + recv.append(recv_obj) + + ucp.register_am_allocator(data["allocator"], data["memory_type"]) + ucp.register_am_recv_callback(_cb) + msg = data["generator"](size) + + num_clients = 1 + clients = [ + await ucp.create_endpoint_from_worker_address(ucp.get_worker_address()) + for i in range(num_clients) + ] + for c in clients: + if recv_wait: + # By sleeping here we ensure that the listener's + # ep.am_recv call will have to wait, rather than return + # immediately as receive data is already available. + await asyncio.sleep(1) + await c.am_send(msg) + for c in clients: + await c.close() + + if data["memory_type"] == "cuda" and msg.nbytes < rndv_thresh: + # Eager messages are always received on the host, if no host + # allocator is registered UCX-Py defaults to `bytearray`. + assert recv[0] == bytearray(msg.get()) + else: + data["validator"](recv[0], msg) diff --git a/ucp/_libs/tests/test_server_client_am.py b/ucp/_libs/tests/test_server_client_am.py index b2b336fe4..e0a9f1b13 100644 --- a/ucp/_libs/tests/test_server_client_am.py +++ b/ucp/_libs/tests/test_server_client_am.py @@ -1,5 +1,6 @@ import multiprocessing as mp import os +import pickle from functools import partial from queue import Empty as QueueIsEmpty @@ -49,7 +50,9 @@ def get_data(): return ret -def _echo_server(get_queue, put_queue, msg_size, datatype, endpoint_error_handling): +def _echo_server( + get_queue, put_queue, msg_size, datatype, user_callback, endpoint_error_handling +): """Server that send received message back to the client Notice, since it is illegal to call progress() in call-back functions, @@ -64,10 +67,6 @@ def _echo_server(get_queue, put_queue, msg_size, datatype, endpoint_error_handli worker = ucx_api.UCXWorker(ctx) worker.register_am_allocator(data["allocator"], data["memory_type"]) - # A reference to listener's endpoint is stored to prevent it from going - # out of scope too early. - ep = None - def _send_handle(request, exception, msg): # Notice, we pass `msg` to the handler in order to make sure # it doesn't go out of scope prematurely. @@ -78,20 +77,28 @@ def _recv_handle(recv_obj, exception, ep): msg = Array(recv_obj) ucx_api.am_send_nbx(ep, msg, msg.nbytes, cb_func=_send_handle, cb_args=(msg,)) - def _listener_handler(conn_request): - global ep - ep = ucx_api.UCXEndpoint.create_from_conn_request( - worker, conn_request, endpoint_error_handling=endpoint_error_handling, - ) + if user_callback is True: + worker.register_am_recv_callback(_recv_handle) + put_queue.put(pickle.dumps(worker.get_address())) + else: + # A reference to listener's endpoint is stored to prevent it from going + # out of scope too early. + ep = None - # Wireup - ucx_api.am_recv_nb(ep, cb_func=_recv_handle, cb_args=(ep,)) + def _listener_handler(conn_request): + global ep + ep = ucx_api.UCXEndpoint.create_from_conn_request( + worker, conn_request, endpoint_error_handling=endpoint_error_handling, + ) - # Data - ucx_api.am_recv_nb(ep, cb_func=_recv_handle, cb_args=(ep,)) + # Wireup + ucx_api.am_recv_nb(ep, cb_func=_recv_handle, cb_args=(ep,)) - listener = ucx_api.UCXListener(worker=worker, port=0, cb_func=_listener_handler) - put_queue.put(listener.port) + # Data + ucx_api.am_recv_nb(ep, cb_func=_recv_handle, cb_args=(ep,)) + + listener = ucx_api.UCXListener(worker=worker, port=0, cb_func=_listener_handler) + put_queue.put(listener.port) while True: worker.progress() @@ -103,7 +110,9 @@ def _listener_handler(conn_request): break -def _echo_client(msg_size, datatype, port, endpoint_error_handling): +def _echo_client( + msg_size, datatype, user_callback, server_info, endpoint_error_handling +): data = get_data()[datatype] ctx = ucx_api.UCXContext( @@ -113,9 +122,16 @@ def _echo_client(msg_size, datatype, port, endpoint_error_handling): worker = ucx_api.UCXWorker(ctx) worker.register_am_allocator(data["allocator"], data["memory_type"]) - ep = ucx_api.UCXEndpoint.create( - worker, "localhost", port, endpoint_error_handling=endpoint_error_handling, - ) + if user_callback is True: + server_worker_addr = pickle.loads(server_info) + ep = ucx_api.UCXEndpoint.create_from_worker_address( + worker, server_worker_addr, endpoint_error_handling=endpoint_error_handling, + ) + else: + port = server_info + ep = ucx_api.UCXEndpoint.create( + worker, "localhost", port, endpoint_error_handling=endpoint_error_handling, + ) # The wireup message is sent to ensure endpoints are connected, otherwise # UCX may not perform any rendezvous transfers. @@ -134,10 +150,14 @@ def _echo_client(msg_size, datatype, port, endpoint_error_handling): recv_wireup = bytearray(recv_wireup) assert bytearray(recv_wireup) == send_wireup - if data["memory_type"] == "cuda" and send_data.nbytes < RNDV_THRESH: + if ( + data["memory_type"] == ucx_api.AllocatorType.CUDA + and send_data.nbytes < RNDV_THRESH + ): # Eager messages are always received on the host, if no host # allocator is registered UCX-Py defaults to `bytearray`. assert recv_data == bytearray(send_data.get()) + else: data["validator"](recv_data, send_data) @@ -146,18 +166,27 @@ def _echo_client(msg_size, datatype, port, endpoint_error_handling): ) @pytest.mark.parametrize("msg_size", [10, 2 ** 24]) @pytest.mark.parametrize("datatype", get_data().keys()) -def test_server_client(msg_size, datatype): +@pytest.mark.parametrize("user_callback", [False, True]) +def test_server_client(msg_size, datatype, user_callback): endpoint_error_handling = ucx_api.get_ucx_version() >= (1, 10, 0) put_queue, get_queue = mp.Queue(), mp.Queue() server = mp.Process( target=_echo_server, - args=(put_queue, get_queue, msg_size, datatype, endpoint_error_handling), + args=( + put_queue, + get_queue, + msg_size, + datatype, + user_callback, + endpoint_error_handling, + ), ) server.start() - port = get_queue.get() + server_info = get_queue.get() client = mp.Process( - target=_echo_client, args=(msg_size, datatype, port, endpoint_error_handling) + target=_echo_client, + args=(msg_size, datatype, user_callback, server_info, endpoint_error_handling), ) client.start() client.join(timeout=10) diff --git a/ucp/_libs/transfer_am.pyx b/ucp/_libs/transfer_am.pyx index 1beb4ffe0..85d7ad8cc 100644 --- a/ucp/_libs/transfer_am.pyx +++ b/ucp/_libs/transfer_am.pyx @@ -210,6 +210,10 @@ def am_recv_nb( IF CY_UCP_AM_SUPPORTED: worker = ep.worker + if worker.is_am_recv_callback_registered(): + raise RuntimeError("`am_recv_nb` cannot be used when a callback was " + "registered to worker with `register_am_recv_callback`") + if cb_args is None: cb_args = () if cb_kwargs is None: diff --git a/ucp/_libs/ucx_endpoint.pyx b/ucp/_libs/ucx_endpoint.pyx index 373002226..aa90b8373 100644 --- a/ucp/_libs/ucx_endpoint.pyx +++ b/ucp/_libs/ucx_endpoint.pyx @@ -135,45 +135,65 @@ cdef class UCXEndpoint(UCXObject): def __init__( self, UCXWorker worker, - uintptr_t params_as_int, - bint endpoint_error_handling + uintptr_t params_as_int=0, + endpoint_error_handling=None, + uintptr_t ep_as_int=0, ): """The Constructor""" + cdef ucp_err_handler_cb_t err_cb + cdef uintptr_t ep_status + cdef ucp_ep_params_t *params = NULL + cdef ucp_ep_h ucp_ep + cdef ucs_status_t status + assert worker.initialized self.worker = worker self._inflight_msgs = set() + self._endpoint_error_handling = endpoint_error_handling - cdef ucp_err_handler_cb_t err_cb - cdef uintptr_t ep_status - err_cb, ep_status = ( - _get_error_callback(worker._context._config["TLS"], endpoint_error_handling) - ) + if params_as_int == 0 and ep_as_int == 0: + raise ValueError("At least one of `params_as_int` or `ep` must be set.") + elif params_as_int != 0 and ep_as_int != 0: + raise ValueError("`params_as_int` and `ep` are mutually exclusive.") + elif params_as_int != 0: + err_cb, ep_status = ( + _get_error_callback( + worker._context._config["TLS"], + endpoint_error_handling + ) + ) - cdef ucp_ep_params_t *params = params_as_int - if err_cb == NULL: - params.err_mode = UCP_ERR_HANDLING_MODE_NONE + params = params_as_int + if err_cb == NULL: + params.err_mode = UCP_ERR_HANDLING_MODE_NONE + else: + params.err_mode = UCP_ERR_HANDLING_MODE_PEER + params.err_handler.cb = err_cb + params.err_handler.arg = self + + status = ucp_ep_create(worker._handle, params, &ucp_ep) + assert_ucs_status(status) + + self._handle = ucp_ep + self._status = ep_status + self.add_handle_finalizer( + _ucx_endpoint_finalizer, + int(ucp_ep), + int(ep_status), + endpoint_error_handling, + worker, + self._inflight_msgs, + ) + worker.add_child(self) else: - params.err_mode = UCP_ERR_HANDLING_MODE_PEER - params.err_handler.cb = err_cb - params.err_handler.arg = self - - cdef ucp_ep_h ucp_ep - cdef ucs_status_t status = ucp_ep_create(worker._handle, params, &ucp_ep) - assert_ucs_status(status) - - self._handle = ucp_ep - self._status = ep_status - self._endpoint_error_handling = endpoint_error_handling - self.add_handle_finalizer( - _ucx_endpoint_finalizer, - int(ucp_ep), - int(ep_status), - endpoint_error_handling, - worker, - self._inflight_msgs, - ) - worker.add_child(self) + self._handle = ep_as_int + self._status = UCS_OK + self.add_handle_finalizer( + lambda handle: None, + int(ucp_ep), + ) + worker.add_child(self) @classmethod def create( @@ -200,7 +220,11 @@ cdef class UCXEndpoint(UCXObject): raise MemoryError("Failed allocation of sockaddr") try: - return cls(worker, params, endpoint_error_handling) + return cls( + worker, + params_as_int=params, + endpoint_error_handling=endpoint_error_handling + ) finally: c_util_sockaddr_free(¶ms.sockaddr) free(params) @@ -221,7 +245,11 @@ cdef class UCXEndpoint(UCXObject): params.address = address._address try: - return cls(worker, params, endpoint_error_handling) + return cls( + worker, + params_as_int=params, + endpoint_error_handling=endpoint_error_handling + ) finally: free(params) @@ -243,10 +271,23 @@ cdef class UCXEndpoint(UCXObject): params.conn_request = conn_request try: - return cls(worker, params, endpoint_error_handling) + return cls( + worker, + params_as_int=params, + endpoint_error_handling=endpoint_error_handling + ) finally: free(params) + @classmethod + def create_from_handle( + cls, + UCXWorker worker, + uintptr_t ep, + ): + assert worker.initialized + return cls(worker, ep_as_int=ep) + def info(self): assert self.initialized diff --git a/ucp/_libs/ucx_worker.pyx b/ucp/_libs/ucx_worker.pyx index 786ce657d..1df194b8c 100644 --- a/ucp/_libs/ucx_worker.pyx +++ b/ucp/_libs/ucx_worker.pyx @@ -99,6 +99,7 @@ cdef class UCXWorker(UCXObject): IF CY_UCP_AM_SUPPORTED: dict _am_recv_pool dict _am_recv_wait + object _am_user_cb object _am_host_allocator object _am_cuda_allocator @@ -125,6 +126,7 @@ cdef class UCXWorker(UCXObject): if Feature.AM in context._feature_flags: self._am_recv_pool = dict() self._am_recv_wait = dict() + self._am_user_cb = None self._am_host_allocator = bytearray self._am_cuda_allocator = None am_handler_param.field_mask = ( @@ -179,7 +181,33 @@ cdef class UCXWorker(UCXObject): raise UCXError("Allocator type not supported") else: raise RuntimeError("UCX-Py needs to be built against and running with " - "UCX >= 1.11 to support am_send_nbx.") + "UCX >= 1.11 to support active messages.") + + def register_am_recv_callback(self, object callback): + """Register callback to be executed when an Active Message is received. + + When a callback is registered with this function, each time a message + arrives via Active Message API the callback function is executed. + Note that when this is used, `am_recv_nb` cannot be used. + + Parameters + ---------- + callback: callable + A callback function accepting exactly three arguments: the received + message, an exception, and a `UCXEndpoint`. The received message is a + buffer allocated with host or CUDA allocator, the exception contains + the error that occurred while receiving or `None` if it was successful, + and the `UCXEndpoint` object that can be used to send a reply back to + remote worker. + """ + if is_am_supported(): + self._am_user_cb = callback + else: + raise RuntimeError("UCX-Py needs to be built against and running with " + "UCX >= 1.11 to support active messages.") + + def is_am_recv_callback_registered(self): + return self._am_user_cb is not None def init_blocking_progress_mode(self): assert self.initialized diff --git a/ucp/_libs/ucx_worker_cb.pyx b/ucp/_libs/ucx_worker_cb.pyx index 42b02e558..ea70a3375 100644 --- a/ucp/_libs/ucx_worker_cb.pyx +++ b/ucp/_libs/ucx_worker_cb.pyx @@ -92,6 +92,7 @@ IF CY_UCP_AM_SUPPORTED: cdef UCXWorker worker = arg cdef dict am_recv_pool = worker._am_recv_pool cdef dict am_recv_wait = worker._am_recv_wait + cdef object am_user_cb = worker._am_user_cb cdef set inflight_msgs = worker._inflight_msgs assert worker.initialized assert param.recv_attr & UCP_AM_RECV_ATTR_FIELD_REPLY_EP @@ -111,7 +112,15 @@ IF CY_UCP_AM_SUPPORTED: cdef int allocator_type = (header)[0] def _push_result(buf, exception, recv_type): - if ( + if am_user_cb is not None: + ep_obj = UCXEndpoint.create_from_handle(worker, ep_as_int) + logger.debug("running am callback on %s message in ep %s" % ( + recv_type, + hex(int(ep_obj.handle)), + )) + + am_user_cb(buf, exception, ep_obj) + elif ( ep_as_int in am_recv_wait and len(am_recv_wait[ep_as_int]) > 0 ): diff --git a/ucp/core.py b/ucp/core.py index f820ed377..8662ed3e2 100644 --- a/ucp/core.py +++ b/ucp/core.py @@ -185,6 +185,16 @@ def _listener_handler(conn_request, callback_func, ctx, endpoint_error_handling) ) +def _am_user_callback_handler(callback_func, recv_obj, exception, ep): + async def _handler(cb_func, recv_obj, exception, ep): + if asyncio.iscoroutinefunction(cb_func): + await cb_func(recv_obj, exception, ep) + else: + cb_func(recv_obj, exception, ep) + + asyncio.ensure_future(_handler(callback_func, recv_obj, exception, ep,)) + + def _epoll_fd_finalizer(epoll_fd, progress_tasks): assert epoll_fd >= 0 # Notice, progress_tasks must be cleared before we close @@ -470,6 +480,10 @@ def register_am_allocator(self, allocator, allocator_type): allocator_type = ucx_api.AllocatorType.UNSUPPORTED self.worker.register_am_allocator(allocator, allocator_type) + def register_am_recv_callback(self, callback): + callback = partial(_am_user_callback_handler, callback) + self.worker.register_am_recv_callback(callback) + @nvtx_annotate("UCXPY_WORKER_RECV", color="red", domain="ucxpy") async def recv(self, buffer, tag): """Receive directly on worker without a local Endpoint into `buffer`. @@ -579,25 +593,34 @@ async def close(self): return self._shutting_down_peer = True - # Send a shutdown message to the peer - msg = CtrlMsg.serialize(opcode=1, close_after_n_recv=self._send_count) - msg_arr = Array(msg) - log = "[Send shutdown] ep: %s, tag: %s, close_after_n_recv: %d" % ( - hex(self.uid), - hex(self._tags["ctrl_send"]), - self._send_count, - ) - logger.debug(log) - try: - await comm.tag_send( - self._ep, msg_arr, msg_arr.nbytes, self._tags["ctrl_send"], name=log - ) - # The peer might already be shutting down thus we can ignore any send errors - except UCXError as e: - logging.warning( - "UCX failed closing worker %s (probably already closed): %s" - % (hex(self.uid), repr(e)) + # A shutdown message can only be sent to the remote peer if peer + # (i.e., tag) information was exchanged. This doesn't apply to + # AM-exclusive message exchange. + if self._tags is not None: + # Send a shutdown message to the peer + msg = CtrlMsg.serialize(opcode=1, close_after_n_recv=self._send_count) + msg_arr = Array(msg) + log = "[Send shutdown] ep: %s, tag: %s, close_after_n_recv: %d" % ( + hex(self.uid), + hex(self._tags["ctrl_send"]), + self._send_count, ) + logger.debug(log) + try: + await comm.tag_send( + self._ep, + msg_arr, + msg_arr.nbytes, + self._tags["ctrl_send"], + name=log, + ) + # The peer might already be shutting down thus we can ignore any send + # errors + except UCXError as e: + logging.warning( + "UCX failed closing worker %s (probably already closed): %s" + % (hex(self.uid), repr(e)) + ) finally: if not self.closed(): # Give all current outstanding send() calls a chance to return @@ -668,10 +691,9 @@ async def am_send(self, buffer): if not isinstance(buffer, Array): buffer = Array(buffer) nbytes = buffer.nbytes - log = "[AM Send #%03d] ep: %s, tag: %s, nbytes: %d, type: %s" % ( + log = "[AM Send #%03d] ep: %s, nbytes: %d, type: %s" % ( self._send_count, hex(self.uid), - hex(self._tags["msg_send"]), nbytes, type(buffer.obj), ) @@ -965,6 +987,10 @@ def register_am_allocator(allocator, allocator_type): return _get_ctx().register_am_allocator(allocator, allocator_type) +def register_am_recv_callback(callback): + return _get_ctx().register_am_recv_callback(callback) + + def create_listener(callback_func, port=None, endpoint_error_handling=None): return _get_ctx().create_listener( callback_func, port, endpoint_error_handling=endpoint_error_handling,