Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion requirements_test.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ pre-commit
pytest-cov
pytest-sugar
pytest-timeout
pytest-asyncio<1.0
pytest-asyncio>=1.3.0
pytest-xdist
pytest
python-slugify
Expand Down
22 changes: 14 additions & 8 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Test configuration for the ZHA component."""

import asyncio
from collections.abc import Callable, Generator
from collections.abc import AsyncGenerator, Callable, Generator
from contextlib import contextmanager
import logging
import os
Expand All @@ -12,6 +12,7 @@

import looptime
import pytest
import pytest_asyncio
import zigpy
from zigpy.application import ControllerApplication
import zigpy.config
Expand Down Expand Up @@ -151,18 +152,18 @@ def expected_lingering_timers() -> bool:
return False


@pytest.fixture(autouse=True)
def verify_cleanup(
event_loop: asyncio.AbstractEventLoop,
@pytest_asyncio.fixture(autouse=True)
async def verify_cleanup(
expected_lingering_tasks: bool, # pylint: disable=redefined-outer-name
expected_lingering_timers: bool, # pylint: disable=redefined-outer-name
) -> Generator[None, None, None]:
) -> AsyncGenerator[None, None]:
"""Verify that the test has cleaned up resources correctly."""
event_loop = asyncio.get_running_loop()
threads_before = frozenset(threading.enumerate())
tasks_before = asyncio.all_tasks(event_loop)
yield

event_loop.run_until_complete(event_loop.shutdown_default_executor())
await event_loop.shutdown_default_executor()

if len(INSTANCES) >= 2:
count = len(INSTANCES)
Expand All @@ -172,15 +173,20 @@ def verify_cleanup(

# Warn and clean-up lingering tasks and timers
# before moving on to the next test.
tasks = asyncio.all_tasks(event_loop) - tasks_before
current_task = asyncio.current_task()
tasks = {
task
for task in asyncio.all_tasks(event_loop) - tasks_before
if task is not current_task
}
for task in tasks:
if expected_lingering_tasks:
_LOGGER.warning("Lingering task after test %r", task)
else:
pytest.fail(f"Lingering task after test {task!r}")
task.cancel()
if tasks:
event_loop.run_until_complete(asyncio.wait(tasks))
await asyncio.wait(tasks)

for handle in event_loop._scheduled:
if not handle.cancelled():
Expand Down
3 changes: 2 additions & 1 deletion zha/async_.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import enum
import functools
from functools import cached_property
import inspect
import logging
import time
from typing import (
Expand Down Expand Up @@ -133,7 +134,7 @@ def get_zhajob_callable_job_type(target: Callable[..., Any]) -> ZHAJobType:
while isinstance(target, functools.partial):
target = target.func

if asyncio.iscoroutinefunction(target):
if inspect.iscoroutinefunction(target):
return ZHAJobType.Coroutinefunction
else:
return ZHAJobType.Callback
Expand Down
Loading