Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 45 additions & 0 deletions tests/test_async_.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from zha import async_ as zha_async
from zha.application.gateway import Gateway
from zha.async_ import AsyncUtilMixin, ZHAJob, ZHAJobType, create_eager_task
from zha.decorators import periodic


@pytest.mark.parametrize("eager_start", [True, False])
Expand Down Expand Up @@ -452,6 +453,50 @@ async def _increment_runs_if_in_time():
assert results == [2, 2, -1, -1]


async def test_gather_with_limited_concurrency_zero_limit_rejected() -> None:
"""Test gather_with_limited_concurrency rejects a zero concurrency limit."""
task = asyncio.create_task(asyncio.sleep(0, result=1))

# Issue being validated:
# a limit of 0 builds Semaphore(0), so no wrapped task can ever acquire it.
# The call deadlocks instead of failing fast with an argument error.
#
# Why this is a problem:
# invalid configuration creates indefinite hangs and stuck shutdown paths instead
# of a clear ValueError that callers can handle immediately.
with pytest.raises(ValueError):
await asyncio.wait_for(
zha_async.gather_with_limited_concurrency(0, task),
timeout=0.05,
)


async def test_periodic_cancellation_propagates() -> None:
"""Test periodic-decorated tasks propagate cancellation."""
started = asyncio.Event()
release = asyncio.Event()

class Poller:
@periodic((1, 1), run_immediately=True)
async def poll(self) -> None:
started.set()
await release.wait()

# Issue being validated:
# periodic() catches CancelledError raised from inside the wrapped function and
# breaks the loop instead of re-raising.
#
# Why this is a problem:
# callers cannot observe real cancellation semantics (task is "finished" instead
# of "cancelled"), which breaks cancellation-aware orchestration and tests.
task = asyncio.create_task(Poller().poll())
await started.wait()
task.cancel()

with pytest.raises(asyncio.CancelledError):
await task


async def test_create_eager_task_312(zha_gateway: Gateway) -> None: # pylint: disable=unused-argument
"""Test create_eager_task schedules a task eagerly in the event loop.

Expand Down
53 changes: 53 additions & 0 deletions tests/test_event.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,59 @@ async def test_event_emit_with_context():
async_callback.assert_awaited_once_with("test", "data")


async def test_event_base_once_async_multiple_emits_same_tick() -> None:
"""Test that async once listeners only run once with back-to-back emits."""
event = EventGenerator()
callback = AsyncMock()

# Issue being validated:
# async "once" listeners are unsubscribed inside an async task, not synchronously.
# When two emits happen back-to-back in the same tick, both see the listener as
# still subscribed and both schedule callback execution.
#
# Why this is a problem:
# call sites rely on "once" semantics for idempotency (pairing, lifecycle hooks,
# cleanup paths). Running twice can cause duplicate writes, duplicate state
# transitions, and hard-to-reproduce race conditions.
event.once("test", callback)
event.emit("test", "first")
event.emit("test", "second")

await asyncio.gather(*event._event_tasks)

assert callback.await_count == 1
assert callback.await_args_list == [call("first")]


async def test_event_base_emit_async_callable_object() -> None:
"""Test that async callable objects are awaited when emitted."""

class AsyncCallable:
def __init__(self) -> None:
self.calls: list[str] = []

async def __call__(self, data: str) -> None:
self.calls.append(data)

event = EventGenerator()
callback = AsyncCallable()

# Issue being validated:
# emit() only checks inspect.iscoroutinefunction(listener.callback), which is False
# for callable objects whose __call__ is async. The returned coroutine is never
# awaited/scheduled.
#
# Why this is a problem:
# listeners silently do not run, and Python emits "coroutine was never awaited"
# warnings. That means missed automations/events plus noisy runtime warnings.
event.on_event("test", callback)

event.emit("test", "payload")
await asyncio.sleep(0)

assert callback.calls == ["payload"]


def test_handle_event_protocol():
"""Test event base class."""

Expand Down
3 changes: 3 additions & 0 deletions zha/async_.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,9 @@ async def gather_with_limited_concurrency(

From: https://stackoverflow.com/a/61478547/9127614
"""
if limit <= 0:
raise ValueError("limit must be > 0")

semaphore = Semaphore(limit)

async def sem_task(task: Awaitable[Any]) -> Any:
Expand Down
2 changes: 1 addition & 1 deletion zha/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ async def wrapper(*args: Any, **kwargs: Any) -> None:
asyncio.current_task(),
method_info,
)
break
raise
except Exception as ex: # pylint: disable=broad-except
_LOGGER.warning(
"[%s] Failed to poll using method %s",
Expand Down
22 changes: 7 additions & 15 deletions zha/event.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,22 +64,14 @@ def once(
self, event_name: str, callback: Callable, with_context: bool = False
) -> Callable:
"""Listen for an event exactly once."""
if inspect.iscoroutinefunction(callback):

async def async_event_listener(*args, **kwargs) -> None:
unsub()
task = asyncio.create_task(callback(*args, **kwargs))
self._event_tasks.append(task)
task.add_done_callback(self._event_tasks.remove)

unsub = self.on_event(
event_name, async_event_listener, with_context=with_context
)
return unsub

def event_listener(*args, **kwargs) -> None:
unsub()
callback(*args, **kwargs)
call = callback(*args, **kwargs)
if inspect.isawaitable(call):
task: asyncio.Task[Any] = asyncio.create_task(call)
self._event_tasks.append(task)
task.add_done_callback(self._event_tasks.remove)

unsub = self.on_event(event_name, event_listener, with_context=with_context)
return unsub
Expand All @@ -100,8 +92,8 @@ def emit(self, event_name: str, data=None) -> None:
else:
call = listener.callback(data)

if inspect.iscoroutinefunction(listener.callback):
task = asyncio.create_task(call)
if inspect.isawaitable(call):
task: asyncio.Task[Any] = asyncio.create_task(call)
self._event_tasks.append(task)
task.add_done_callback(self._event_tasks.remove)

Expand Down
Loading