diff --git a/pyproject.toml b/pyproject.toml index 77842d9..25534de 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,7 +27,7 @@ classifiers = [ ] requires-python = ">= 3.9" dependencies = [ - "anyio >=4.8.0,<5.0.0", + "anyio >=4.10.0,<5.0.0", "anyioutils >=0.7.1,<0.8.0", "pyzmq >=26.0.0,<28.0.0", ] @@ -36,7 +36,7 @@ dependencies = [ test = [ "pytest >=8,<9", "pytest-timeout", - "trio >=0.27.0,<0.28", + "anyio[trio]", "mypy", "ruff", "coverage[toml] >=7,<8", diff --git a/src/zmq_anyio/_socket.py b/src/zmq_anyio/_socket.py index 164fb99..e56c15c 100644 --- a/src/zmq_anyio/_socket.py +++ b/src/zmq_anyio/_socket.py @@ -19,9 +19,11 @@ get_cancelled_exc_class, sleep, wait_readable, + ClosedResourceError, + notify_closing, ) from anyio.abc import TaskGroup, TaskStatus -from anyioutils import FIRST_COMPLETED, Future, create_task, wait +from anyioutils import Future, create_task import zmq from zmq import EVENTS, POLLIN, POLLOUT @@ -890,36 +892,36 @@ async def _start(self, *, task_status: TaskStatus[None] = TASK_STATUS_IGNORED): task_status.started() self.started.set() self._thread = get_ident() + + async def wait_or_cancel() -> None: + assert self.stopped is not None + await self.stopped.wait() + tg.cancel_scope.cancel() + + def fileno() -> int: + if self.closed: + return -1 + try: + return self._shadow_sock.fileno() + except zmq.ZMQError: + return -1 + try: - while True: - wait_stopped_task = create_task( - self.stopped.wait(), - self._task_group, - exception_handler=ignore_exceptions, - ) - tasks = [ - create_task( - wait_readable(self._shadow_sock), # type: ignore[arg-type] - self._task_group, - exception_handler=ignore_exceptions, - ), - wait_stopped_task, - ] - done, pending = await wait( - tasks, self._task_group, return_when=FIRST_COMPLETED - ) - for task in pending: - task.cancel() - if wait_stopped_task in done: + while (fd := fileno()) > 0: + async with create_task_group() as tg: + tg.start_soon(wait_or_cancel) + try: + await wait_readable(fd) + except ClosedResourceError: + break + finally: + tg.cancel_scope.cancel() + if self.stopped.is_set(): break await self._handle_events() - except BaseException: - pass finally: self._exited.set() - - assert self.stopped is not None - self.stopped.set() + self.stopped.set() async def stop(self): assert self._exited is not None @@ -933,11 +935,13 @@ async def stop(self): self.close() def close(self, linger: int | None = None) -> None: - try: - if not self.closed and self._fd is not None: + fd = self._fd + if not self.closed and fd is not None: + notify_closing(fd) + try: super().close(linger=linger) - except BaseException: - pass + except BaseException: + pass assert self.stopped is not None self.stopped.set() diff --git a/tests/conftest.py b/tests/conftest.py index 8f68ac6..99510ab 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -66,7 +66,7 @@ def context(contexts): @pytest.fixture -def sockets(contexts): +async def sockets(contexts): sockets = [] yield sockets # ensure any tracked sockets get their contexts cleaned up