diff --git a/tests/test_async_.py b/tests/test_async_.py index d049cec34..fa77e1849 100644 --- a/tests/test_async_.py +++ b/tests/test_async_.py @@ -10,6 +10,7 @@ from zha import async_ as zha_async from zha.application.gateway import Gateway from zha.async_ import AsyncUtilMixin, ZHAJob, ZHAJobType, create_eager_task +from zha.decorators import periodic @pytest.mark.parametrize("eager_start", [True, False]) @@ -452,6 +453,50 @@ async def _increment_runs_if_in_time(): assert results == [2, 2, -1, -1] +async def test_gather_with_limited_concurrency_zero_limit_rejected() -> None: + """Test gather_with_limited_concurrency rejects a zero concurrency limit.""" + task = asyncio.create_task(asyncio.sleep(0, result=1)) + + # Issue being validated: + # a limit of 0 builds Semaphore(0), so no wrapped task can ever acquire it. + # The call deadlocks instead of failing fast with an argument error. + # + # Why this is a problem: + # invalid configuration creates indefinite hangs and stuck shutdown paths instead + # of a clear ValueError that callers can handle immediately. + with pytest.raises(ValueError): + await asyncio.wait_for( + zha_async.gather_with_limited_concurrency(0, task), + timeout=0.05, + ) + + +async def test_periodic_cancellation_propagates() -> None: + """Test periodic-decorated tasks propagate cancellation.""" + started = asyncio.Event() + release = asyncio.Event() + + class Poller: + @periodic((1, 1), run_immediately=True) + async def poll(self) -> None: + started.set() + await release.wait() + + # Issue being validated: + # periodic() catches CancelledError raised from inside the wrapped function and + # breaks the loop instead of re-raising. + # + # Why this is a problem: + # callers cannot observe real cancellation semantics (task is "finished" instead + # of "cancelled"), which breaks cancellation-aware orchestration and tests. + task = asyncio.create_task(Poller().poll()) + await started.wait() + task.cancel() + + with pytest.raises(asyncio.CancelledError): + await task + + async def test_create_eager_task_312(zha_gateway: Gateway) -> None: # pylint: disable=unused-argument """Test create_eager_task schedules a task eagerly in the event loop. diff --git a/tests/test_event.py b/tests/test_event.py index 7c27543a7..19aa5d67e 100644 --- a/tests/test_event.py +++ b/tests/test_event.py @@ -188,6 +188,59 @@ async def test_event_emit_with_context(): async_callback.assert_awaited_once_with("test", "data") +async def test_event_base_once_async_multiple_emits_same_tick() -> None: + """Test that async once listeners only run once with back-to-back emits.""" + event = EventGenerator() + callback = AsyncMock() + + # Issue being validated: + # async "once" listeners are unsubscribed inside an async task, not synchronously. + # When two emits happen back-to-back in the same tick, both see the listener as + # still subscribed and both schedule callback execution. + # + # Why this is a problem: + # call sites rely on "once" semantics for idempotency (pairing, lifecycle hooks, + # cleanup paths). Running twice can cause duplicate writes, duplicate state + # transitions, and hard-to-reproduce race conditions. + event.once("test", callback) + event.emit("test", "first") + event.emit("test", "second") + + await asyncio.gather(*event._event_tasks) + + assert callback.await_count == 1 + assert callback.await_args_list == [call("first")] + + +async def test_event_base_emit_async_callable_object() -> None: + """Test that async callable objects are awaited when emitted.""" + + class AsyncCallable: + def __init__(self) -> None: + self.calls: list[str] = [] + + async def __call__(self, data: str) -> None: + self.calls.append(data) + + event = EventGenerator() + callback = AsyncCallable() + + # Issue being validated: + # emit() only checks inspect.iscoroutinefunction(listener.callback), which is False + # for callable objects whose __call__ is async. The returned coroutine is never + # awaited/scheduled. + # + # Why this is a problem: + # listeners silently do not run, and Python emits "coroutine was never awaited" + # warnings. That means missed automations/events plus noisy runtime warnings. + event.on_event("test", callback) + + event.emit("test", "payload") + await asyncio.sleep(0) + + assert callback.calls == ["payload"] + + def test_handle_event_protocol(): """Test event base class.""" diff --git a/zha/async_.py b/zha/async_.py index 99a72bc1a..f6ad19576 100644 --- a/zha/async_.py +++ b/zha/async_.py @@ -57,6 +57,9 @@ async def gather_with_limited_concurrency( From: https://stackoverflow.com/a/61478547/9127614 """ + if limit <= 0: + raise ValueError("limit must be > 0") + semaphore = Semaphore(limit) async def sem_task(task: Awaitable[Any]) -> Any: diff --git a/zha/decorators.py b/zha/decorators.py index bc5214393..0f632a79b 100644 --- a/zha/decorators.py +++ b/zha/decorators.py @@ -87,7 +87,7 @@ async def wrapper(*args: Any, **kwargs: Any) -> None: asyncio.current_task(), method_info, ) - break + raise except Exception as ex: # pylint: disable=broad-except _LOGGER.warning( "[%s] Failed to poll using method %s", diff --git a/zha/event.py b/zha/event.py index 6a31f775b..55bb3a848 100644 --- a/zha/event.py +++ b/zha/event.py @@ -64,22 +64,14 @@ def once( self, event_name: str, callback: Callable, with_context: bool = False ) -> Callable: """Listen for an event exactly once.""" - if inspect.iscoroutinefunction(callback): - - async def async_event_listener(*args, **kwargs) -> None: - unsub() - task = asyncio.create_task(callback(*args, **kwargs)) - self._event_tasks.append(task) - task.add_done_callback(self._event_tasks.remove) - - unsub = self.on_event( - event_name, async_event_listener, with_context=with_context - ) - return unsub def event_listener(*args, **kwargs) -> None: unsub() - callback(*args, **kwargs) + call = callback(*args, **kwargs) + if inspect.isawaitable(call): + task: asyncio.Task[Any] = asyncio.create_task(call) + self._event_tasks.append(task) + task.add_done_callback(self._event_tasks.remove) unsub = self.on_event(event_name, event_listener, with_context=with_context) return unsub @@ -100,8 +92,8 @@ def emit(self, event_name: str, data=None) -> None: else: call = listener.callback(data) - if inspect.iscoroutinefunction(listener.callback): - task = asyncio.create_task(call) + if inspect.isawaitable(call): + task: asyncio.Task[Any] = asyncio.create_task(call) self._event_tasks.append(task) task.add_done_callback(self._event_tasks.remove)