From 56c3ae705ad67de92885a13788cc66f412321f96 Mon Sep 17 00:00:00 2001 From: Lawrence Mitchell Date: Wed, 6 Dec 2023 11:05:31 +0000 Subject: [PATCH 1/3] Use a weakref finalizer to stop notifier thread When creating our own application context with the Python notifier thread enabled, we previously had to arrange to stop it before dropping the reference to the context we created. The consequence of not doing so is that our application hangs at exit waiting for a thread to join that never will. Applying the principle of least surprise, an object should clean up all of its resources when dropping out of scope. To ensure this, change ApplicationContext.stop_notifier_thread to a staticmethod and attach it as a weakref finalizer when starting the notifier thread. This way, we don't have to handle the case where there is no notifier thread, nor do we have to remember to call the stop function manually. This additionally allows us to clean up the top-level reset function since there is no longer a need to stop the notifier "manually" during reset. --- python/ucxx/_lib_async/application_context.py | 46 +++++++++++-------- python/ucxx/core.py | 12 +---- 2 files changed, 29 insertions(+), 29 deletions(-) diff --git a/python/ucxx/_lib_async/application_context.py b/python/ucxx/_lib_async/application_context.py index 698439911..956a3a9d5 100644 --- a/python/ucxx/_lib_async/application_context.py +++ b/python/ucxx/_lib_async/application_context.py @@ -1,6 +1,7 @@ # SPDX-FileCopyrightText: Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. 
# SPDX-License-Identifier: BSD-3-Clause +import functools import logging import os import threading @@ -144,12 +145,23 @@ def start_notifier_thread(self): name="UCX-Py Async Notifier Thread", ) self.notifier_thread.start() + weakref.finalize( + self, + functools.partial( + self.stop_notifier_thread, + self.notifier_thread_q, + self.notifier_thread, + ), + ) else: logger.debug( "UCXX not compiled with UCXX_ENABLE_PYTHON, disabling notifier thread" ) - def stop_notifier_thread(self): + # Must be a staticmethod so that it can be used in a weakref + # finalizer on the application context + @staticmethod + def stop_notifier_thread(queue, thread): """ Stop Python future notifier thread @@ -157,24 +169,22 @@ def stop_notifier_thread(self): notification enabled via `UCXPY_ENABLE_PYTHON_FUTURE=1` or `ucxx.init(..., enable_python_future=True)`. - .. warning:: When the notifier thread is enabled it may be necessary to - explicitly call this method before shutting down the process or - or application, otherwise it may block indefinitely waiting for - the thread to terminate. Executing `ucxx.reset()` will also run - this method, so it's not necessary to have both. + The application context arranges to call this function + automatically in a weakref finalizer when it goes out of + scope. If using the global application context, `ucxx.reset()` + will drop the reference and cause the notifier thread to be + stopped. For a user-maintained context, one must just ensure + that the reference is dropped. """ - if self.notifier_thread_q and self.notifier_thread: - self.notifier_thread_q.put("shutdown") - while True: - # Having a timeout is required. During the notifier thread shutdown - # it may require the GIL, which will cause a deadlock with the `join()` - # call otherwise. 
- self.notifier_thread.join(timeout=0.01) - if not self.notifier_thread.is_alive(): - break - logger.debug("Notifier thread stopped") - else: - logger.debug("Notifier thread not running") + queue.put("shutdown") + while True: + # Having a timeout is required. During the notifier thread shutdown + # it may require the GIL, which will cause a deadlock with the `join()` + # call otherwise. + thread.join(timeout=0.01) + if not thread.is_alive(): + break + logger.debug("Notifier thread stopped") def create_listener( self, diff --git a/python/ucxx/core.py b/python/ucxx/core.py index e290e002e..e29b61e29 100644 --- a/python/ucxx/core.py +++ b/python/ucxx/core.py @@ -94,7 +94,6 @@ def reset(): The library is initiated at next API call. """ - stop_notifier_thread() global _ctx if _ctx is not None: weakref_ctx = weakref.ref(_ctx) @@ -104,21 +103,13 @@ def reset(): msg = ( "Trying to reset UCX but not all Endpoints and/or Listeners " "are closed(). The following objects are still referencing " - "ApplicationContext: " + f"the global ApplicationContext {weakref_ctx()}: " ) for o in gc.get_referrers(weakref_ctx()): msg += "\n %s" % str(o) raise UCXError(msg) -def stop_notifier_thread(): - global _ctx - if _ctx: - _ctx.stop_notifier_thread() - else: - logger.debug("UCX is not initialized.") - - def get_ucx_version(): """Return the version of the underlying UCX installation @@ -248,4 +239,3 @@ async def recv(buffer, tag): create_endpoint.__doc__ = ApplicationContext.create_endpoint.__doc__ continuous_ucx_progress.__doc__ = ApplicationContext.continuous_ucx_progress.__doc__ get_ucp_worker.__doc__ = ApplicationContext.get_ucp_worker.__doc__ -stop_notifier_thread.__doc__ = ApplicationContext.stop_notifier_thread.__doc__ From 3ea304352bb4076edc52c3600c473e5c8c9e052e Mon Sep 17 00:00:00 2001 From: Lawrence Mitchell Date: Wed, 6 Dec 2023 12:14:58 +0000 Subject: [PATCH 2/3] Reinstate top-level stop_notifier_thread --- python/ucxx/core.py | 9 +++++++++ 1 file changed, 9 insertions(+) 
diff --git a/python/ucxx/core.py b/python/ucxx/core.py index e29b61e29..b85a2433a 100644 --- a/python/ucxx/core.py +++ b/python/ucxx/core.py @@ -110,6 +110,14 @@ def reset(): raise UCXError(msg) +def stop_notifier_thread(): + global _ctx + if _ctx and _ctx.notifier_thread is not None: + _ctx.stop_notifier_thread(_ctx.notifier_thread_q, _ctx.notifier_thread) + else: + logger.debug("UCX is not initialized.") + + def get_ucx_version(): """Return the version of the underlying UCX installation @@ -239,3 +247,4 @@ async def recv(buffer, tag): create_endpoint.__doc__ = ApplicationContext.create_endpoint.__doc__ continuous_ucx_progress.__doc__ = ApplicationContext.continuous_ucx_progress.__doc__ get_ucp_worker.__doc__ = ApplicationContext.get_ucp_worker.__doc__ +stop_notifier_thread.__doc__ = ApplicationContext.stop_notifier_thread.__doc__ From e5170e7c6da615a44c1e81679bc11bb314ee5d72 Mon Sep 17 00:00:00 2001 From: Lawrence Mitchell Date: Wed, 6 Dec 2023 15:01:44 +0000 Subject: [PATCH 3/3] Squash, hopefully, a few more refcycles for cleanup --- .../distributed-ucxx/distributed_ucxx/ucxx.py | 23 +++++++++++-------- .../_lib_async/tests/test_custom_send_recv.py | 11 +++++---- 2 files changed, 20 insertions(+), 14 deletions(-) diff --git a/python/distributed-ucxx/distributed_ucxx/ucxx.py b/python/distributed-ucxx/distributed_ucxx/ucxx.py index 62ff11bab..e63f93e71 100644 --- a/python/distributed-ucxx/distributed_ucxx/ucxx.py +++ b/python/distributed-ucxx/distributed_ucxx/ucxx.py @@ -566,24 +566,27 @@ def address(self): return f"{self.prefix}{self.ip}:{self.port}" async def start(self): - async def serve_forever(client_ep): - ucx = self.comm_class( + async def serve_forever(client_ep, *, selfref): + ucx = selfref().comm_class( client_ep, - local_addr=self.address, - peer_addr=self.address, - deserialize=self.deserialize, + local_addr=selfref().address, + peer_addr=selfref().address, + deserialize=selfref().deserialize, ) - ucx.allow_offload = self.allow_offload + 
ucx.allow_offload = selfref().allow_offload try: - await self.on_connection(ucx) + await selfref().on_connection(ucx) except CommClosedError: logger.debug("Connection closed before handshake completed") return - if self.comm_handler: - await self.comm_handler(ucx) + if selfref().comm_handler: + await selfref().comm_handler(ucx) init_once() - self.ucxx_server = ucxx.create_listener(serve_forever, port=self._input_port) + self.ucxx_server = ucxx.create_listener( + functools.partial(serve_forever, selfref=weakref.ref(self)), + port=self._input_port, + ) def stop(self): self.ucxx_server = None diff --git a/python/ucxx/_lib_async/tests/test_custom_send_recv.py b/python/ucxx/_lib_async/tests/test_custom_send_recv.py index 56018cab9..e247c731a 100644 --- a/python/ucxx/_lib_async/tests/test_custom_send_recv.py +++ b/python/ucxx/_lib_async/tests/test_custom_send_recv.py @@ -2,7 +2,9 @@ # SPDX-License-Identifier: BSD-3-Clause import asyncio +import functools import pickle +import weakref import numpy as np import pytest @@ -98,11 +100,12 @@ def __init__(self): self.comm = None def start(self): - async def serve_forever(ep): - ucx = UCX(ep) - self.comm = ucx + async def serve_forever(ep, *, selfref): + selfref().comm = UCX(ep) - self.ucxx_server = ucxx.create_listener(serve_forever) + self.ucxx_server = ucxx.create_listener( + functools.partial(serve_forever, selfref=weakref.ref(self)) + ) uu = UCXListener() uu.start()