diff --git a/python/distributed-ucxx/distributed_ucxx/ucxx.py b/python/distributed-ucxx/distributed_ucxx/ucxx.py index 62ff11bab..e63f93e71 100644 --- a/python/distributed-ucxx/distributed_ucxx/ucxx.py +++ b/python/distributed-ucxx/distributed_ucxx/ucxx.py @@ -566,24 +566,27 @@ def address(self): return f"{self.prefix}{self.ip}:{self.port}" async def start(self): - async def serve_forever(client_ep): - ucx = self.comm_class( + async def serve_forever(client_ep, *, selfref): + ucx = selfref().comm_class( client_ep, - local_addr=self.address, - peer_addr=self.address, - deserialize=self.deserialize, + local_addr=selfref().address, + peer_addr=selfref().address, + deserialize=selfref().deserialize, ) - ucx.allow_offload = self.allow_offload + ucx.allow_offload = selfref().allow_offload try: - await self.on_connection(ucx) + await selfref().on_connection(ucx) except CommClosedError: logger.debug("Connection closed before handshake completed") return - if self.comm_handler: - await self.comm_handler(ucx) + if selfref().comm_handler: + await selfref().comm_handler(ucx) init_once() - self.ucxx_server = ucxx.create_listener(serve_forever, port=self._input_port) + self.ucxx_server = ucxx.create_listener( + functools.partial(serve_forever, selfref=weakref.ref(self)), + port=self._input_port, + ) def stop(self): self.ucxx_server = None diff --git a/python/ucxx/_lib_async/application_context.py b/python/ucxx/_lib_async/application_context.py index 698439911..956a3a9d5 100644 --- a/python/ucxx/_lib_async/application_context.py +++ b/python/ucxx/_lib_async/application_context.py @@ -1,6 +1,7 @@ # SPDX-FileCopyrightText: Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. # SPDX-License-Identifier: BSD-3-Clause +import functools import logging import os import threading @@ -144,12 +145,23 @@ def start_notifier_thread(self): name="UCX-Py Async Notifier Thread", ) self.notifier_thread.start() + weakref.finalize( + self, + functools.partial( + self.stop_notifier_thread, + self.notifier_thread_q, + self.notifier_thread, + ), + ) else: logger.debug( "UCXX not compiled with UCXX_ENABLE_PYTHON, disabling notifier thread" ) - def stop_notifier_thread(self): + # Must be a staticmethod so that it can be used in a weakref + # finalizer on the application context + @staticmethod + def stop_notifier_thread(queue, thread): """ Stop Python future notifier thread @@ -157,24 +169,22 @@ def stop_notifier_thread(self): notification enabled via `UCXPY_ENABLE_PYTHON_FUTURE=1` or `ucxx.init(..., enable_python_future=True)`. - .. warning:: When the notifier thread is enabled it may be necessary to - explicitly call this method before shutting down the process or - or application, otherwise it may block indefinitely waiting for - the thread to terminate. Executing `ucxx.reset()` will also run - this method, so it's not necessary to have both. + The application context arranges to call this function + automatically in a weakref finalizer when it goes out of + scope. If using the global application context, `ucxx.reset()` + will drop the reference and cause the notifier thread to be + stopped. For a user-maintained context, one must just ensure + that the reference is dropped. """ - if self.notifier_thread_q and self.notifier_thread: - self.notifier_thread_q.put("shutdown") - while True: - # Having a timeout is required. During the notifier thread shutdown - # it may require the GIL, which will cause a deadlock with the `join()` - # call otherwise. - self.notifier_thread.join(timeout=0.01) - if not self.notifier_thread.is_alive(): - break - logger.debug("Notifier thread stopped") - else: - logger.debug("Notifier thread not running") + queue.put("shutdown") + while True: + # Having a timeout is required. During the notifier thread shutdown + # it may require the GIL, which will cause a deadlock with the `join()` + # call otherwise. + thread.join(timeout=0.01) + if not thread.is_alive(): + break + logger.debug("Notifier thread stopped") def create_listener( self, diff --git a/python/ucxx/_lib_async/tests/test_custom_send_recv.py b/python/ucxx/_lib_async/tests/test_custom_send_recv.py index 56018cab9..e247c731a 100644 --- a/python/ucxx/_lib_async/tests/test_custom_send_recv.py +++ b/python/ucxx/_lib_async/tests/test_custom_send_recv.py @@ -2,7 +2,9 @@ # SPDX-License-Identifier: BSD-3-Clause import asyncio +import functools import pickle +import weakref import numpy as np import pytest @@ -98,11 +100,12 @@ def __init__(self): self.comm = None def start(self): - async def serve_forever(ep): - ucx = UCX(ep) - self.comm = ucx + async def serve_forever(ep, *, selfref): + selfref().comm = UCX(ep) - self.ucxx_server = ucxx.create_listener(serve_forever) + self.ucxx_server = ucxx.create_listener( + functools.partial(serve_forever, selfref=weakref.ref(self)) + ) uu = UCXListener() uu.start() diff --git a/python/ucxx/core.py b/python/ucxx/core.py index e290e002e..b85a2433a 100644 --- a/python/ucxx/core.py +++ b/python/ucxx/core.py @@ -94,7 +94,6 @@ def reset(): The library is initiated at next API call. """ - stop_notifier_thread() global _ctx if _ctx is not None: weakref_ctx = weakref.ref(_ctx) @@ -104,7 +103,7 @@ def reset(): msg = ( "Trying to reset UCX but not all Endpoints and/or Listeners " "are closed(). The following objects are still referencing " - "ApplicationContext: " + f"the global ApplicationContext {weakref_ctx()}: " ) for o in gc.get_referrers(weakref_ctx()): msg += "\n %s" % str(o) @@ -113,8 +112,8 @@ def reset(): def stop_notifier_thread(): global _ctx - if _ctx: - _ctx.stop_notifier_thread() + if _ctx and _ctx.notifier_thread is not None: + _ctx.stop_notifier_thread(_ctx.notifier_thread_q, _ctx.notifier_thread) else: logger.debug("UCX is not initialized.")