Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
214 changes: 213 additions & 1 deletion tests/test_discover.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,13 +54,14 @@
zigpy_device_from_json,
)
from zha.application import Platform
from zha.application.discovery import discover_device_entities
from zha.application.discovery import discover_device_entities, discover_group_entities
from zha.application.gateway import Gateway
from zha.application.helpers import DeviceOverridesConfiguration
from zha.application.platforms import PlatformEntity, binary_sensor, sensor
from zha.application.platforms.light import HueLight
from zha.application.platforms.number import BaseNumber, NumberMode
from zha.zigbee.cluster_handlers.const import PHILLIPS_REMOTE_CLUSTER
from zha.zigbee.group import GroupMemberReference


def _get_identify_cluster(zigpy_device):
Expand All @@ -69,6 +70,50 @@ def _get_identify_cluster(zigpy_device):
return endpoint.identify


def test_discover_device_entities_continues_after_endpoint_exception() -> None:
    """Test endpoint discovery exception does not stop later endpoints."""
    zha_device = mock.MagicMock()
    zha_device.ieee = "00:0d:6f:00:0a:90:69:e7"
    zha_device.name = "FakeManufacturer FakeModel"
    zha_device.is_active_coordinator = False

    # Two fake endpoints on the same device: discovery for the first one
    # raises, the second one must still be processed.
    endpoints = {}
    for endpoint_id in (1, 2):
        endpoint = mock.MagicMock()
        endpoint.id = endpoint_id
        endpoint.device.ieee = zha_device.ieee
        endpoints[endpoint_id] = endpoint
    zha_device.endpoints = endpoints

    discovered_entity = mock.sentinel.discovered_entity

    def _discover(endpoint):
        # Simulate a failure on endpoint 1 only.
        if endpoint.id == 1:
            raise RuntimeError("endpoint discovery failed")
        return iter([discovered_entity])

    patch_target = "zha.application.discovery.discover_entities_for_endpoint"
    with mock.patch(patch_target, side_effect=_discover):
        # Issue being validated:
        # discover_device_entities() wraps iteration-level exceptions, but an
        # endpoint discovery exception currently terminates the underlying
        # generator, so later endpoints are never processed.
        #
        # Why this is a problem:
        # a single bad endpoint prevents discovery for all remaining endpoints
        # on the device, causing partial entity loss.
        entities = list(discover_device_entities(zha_device))

    assert entities == [discovered_entity]


@pytest.mark.parametrize("override_platform", [Platform.SWITCH, Platform.LIGHT])
async def test_device_override(
zha_gateway: Gateway, override_platform: Platform
Expand Down Expand Up @@ -673,6 +718,173 @@ async def test_quirks_v2_fallback_name(zha_gateway: Gateway) -> None:
assert entity.fallback_name == "Fallback name"


async def test_discover_group_entities_member_drop_runs_group_entity_on_remove(
    zha_gateway: Gateway,
) -> None:
    """Test group member drop cleanup runs GroupEntity.on_remove."""
    # Minimal on/off switch signature shared by both group members.
    switch_endpoints = {
        1: {
            SIG_EP_INPUT: [
                zigpy.zcl.clusters.general.Basic.cluster_id,
                zigpy.zcl.clusters.general.OnOff.cluster_id,
                zigpy.zcl.clusters.general.Groups.cluster_id,
            ],
            SIG_EP_OUTPUT: [],
            SIG_EP_TYPE: zigpy.profiles.zha.DeviceType.ON_OFF_SWITCH,
            SIG_EP_PROFILE: zigpy.profiles.zha.PROFILE_ID,
        }
    }

    zigpy_devices = [
        create_mock_zigpy_device(zha_gateway, switch_endpoints, ieee=ieee)
        for ieee in ("01:2d:6f:00:0a:90:69:e8", "02:2d:6f:00:0a:90:69:e9")
    ]
    zha_devices = [
        await join_zigpy_device(zha_gateway, zigpy_device)
        for zigpy_device in zigpy_devices
    ]

    members = [
        GroupMemberReference(ieee=device.ieee, endpoint_id=1)
        for device in zha_devices
    ]
    zha_group = await zha_gateway.async_create_zigpy_group("Test Group", members)
    zha_group.clear_caches()

    for entity in discover_group_entities(zha_group):
        entity.on_add()

    assert zha_group.group_entities
    (group_entity,) = tuple(zha_group.group_entities.values())

    try:
        wrapped_on_remove = AsyncMock(wraps=group_entity.on_remove)
        with mock.patch.object(
            group_entity, "on_remove", new=wrapped_on_remove
        ) as mocked_on_remove:
            # Shrink the group below the two-member minimum.
            for member in list(zha_group.zigpy_group.members)[1:]:
                zha_group.zigpy_group.members.pop(member)
            zha_group.clear_caches()

            # Issue being validated:
            # when group members drop below 2, discover_group_entities()
            # currently clears group.group_entities directly.
            #
            # Why this is a problem:
            # bypassing GroupEntity.on_remove() skips entity-level cleanup
            # (task/handle cancellation and unregister logic), leaking group
            # entity lifecycle state.
            list(discover_group_entities(zha_group))

        assert mocked_on_remove.await_count == 1
    finally:
        await group_entity.on_remove()


async def test_discover_group_entities_platform_quorum_drop_prunes_stale_platform_entity(
    zha_gateway: Gateway,
) -> None:
    """Test platform quorum drop cleanup removes stale group platform entities."""

    def _single_endpoint_signature(input_clusters, device_type):
        # Build a one-endpoint device signature with the given input clusters.
        return {
            1: {
                SIG_EP_INPUT: input_clusters,
                SIG_EP_OUTPUT: [],
                SIG_EP_TYPE: device_type,
                SIG_EP_PROFILE: zigpy.profiles.zha.PROFILE_ID,
            }
        }

    light_endpoints = _single_endpoint_signature(
        [
            zigpy.zcl.clusters.general.Basic.cluster_id,
            zigpy.zcl.clusters.general.OnOff.cluster_id,
            zigpy.zcl.clusters.general.LevelControl.cluster_id,
            zigpy.zcl.clusters.general.Groups.cluster_id,
        ],
        zigpy.profiles.zha.DeviceType.DIMMABLE_LIGHT,
    )
    switch_endpoints = _single_endpoint_signature(
        [
            zigpy.zcl.clusters.general.Basic.cluster_id,
            zigpy.zcl.clusters.general.OnOff.cluster_id,
            zigpy.zcl.clusters.general.Groups.cluster_id,
        ],
        zigpy.profiles.zha.DeviceType.ON_OFF_SWITCH,
    )

    zigpy_light_1 = create_mock_zigpy_device(
        zha_gateway,
        light_endpoints,
        ieee="11:2d:6f:00:0a:90:69:e1",
    )
    zigpy_light_2 = create_mock_zigpy_device(
        zha_gateway,
        light_endpoints,
        ieee="12:2d:6f:00:0a:90:69:e2",
    )
    zigpy_switch = create_mock_zigpy_device(
        zha_gateway,
        switch_endpoints,
        ieee="13:2d:6f:00:0a:90:69:e3",
    )

    zha_light_1 = await join_zigpy_device(zha_gateway, zigpy_light_1)
    zha_light_2 = await join_zigpy_device(zha_gateway, zigpy_light_2)
    zha_switch = await join_zigpy_device(zha_gateway, zigpy_switch)

    members = [
        GroupMemberReference(ieee=device.ieee, endpoint_id=1)
        for device in (zha_light_1, zha_light_2, zha_switch)
    ]
    zha_group = await zha_gateway.async_create_zigpy_group("Test Group", members)
    stale_light_group_entity = None
    try:
        zha_group.clear_caches()

        for entity in discover_group_entities(zha_group):
            entity.on_add()

        light_group_entities = [
            entity
            for entity in zha_group.group_entities.values()
            if entity.PLATFORM == Platform.LIGHT
        ]
        assert len(light_group_entities) == 1
        stale_light_group_entity = light_group_entities[0]

        # Drop one light so the LIGHT platform falls below its two-member
        # quorum while the group as a whole still has two members.
        light_member = next(
            (
                member
                for member in list(zha_group.zigpy_group.members)
                if member[0] == zha_light_2.ieee
            ),
            None,
        )
        assert light_member is not None
        zha_group.zigpy_group.members.pop(light_member)
        zha_group.clear_caches()

        # Issue being validated:
        # when a single platform's member count drops below 2 (for example,
        # LIGHT), discover_group_entities() yields no replacement but does not
        # prune stale group entities for that platform if total members remain
        # >= 2.
        #
        # Why this is a problem:
        # stale platform group entities remain exposed and continue receiving
        # updates even though the group no longer has enough members for that
        # platform.
        list(discover_group_entities(zha_group))

        assert stale_light_group_entity.unique_id not in zha_group.group_entities
    finally:
        if stale_light_group_entity is not None:
            await stale_light_group_entity.on_remove()
        for entity in tuple(zha_group.group_entities.values()):
            await entity.on_remove()
        zha_group.clear_caches()


def pytest_generate_tests(metafunc):
"""Generate tests for all device files."""
if "file_path" in metafunc.fixturenames:
Expand Down
30 changes: 28 additions & 2 deletions zha/application/discovery.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,14 @@ def discover_device_entities(device: Device) -> Iterator[BaseEntity]:
endpoint.id,
)

yield from discover_entities_for_endpoint(endpoint)
try:
yield from discover_entities_for_endpoint(endpoint)
except Exception: # pylint: disable=broad-exception-caught
_LOGGER.exception(
"Failed to discover entities for endpoint %s on device %s",
endpoint.id,
str(device.ieee),
)

yield from discover_quirks_v2_entities(device)

Expand Down Expand Up @@ -194,14 +201,24 @@ def discover_coordinator_device_entities(
@ignore_exceptions_during_iteration
def discover_group_entities(group: Group) -> Iterator[GroupEntity]:
"""Process a group and create any entities that are needed."""

def _schedule_group_entity_cleanup(group_entity: GroupEntity) -> None:
"""Schedule lifecycle cleanup for an existing group entity."""
group.gateway.async_create_task(
group_entity.on_remove(),
name=f"zha.discovery-remove-group-entity-{group_entity.unique_id}",
eager_start=True,
)

# only create a group entity if there are 2 or more members in a group
if len(group.members) < 2:
_LOGGER.debug(
"Group: %s:0x%04x has less than 2 members - skipping entity discovery",
group.name,
group.group_id,
)
group.group_entities.clear()
for group_entity in tuple(group.group_entities.values()):
_schedule_group_entity_cleanup(group_entity)
return

# We only create groups with two or more devices
Expand All @@ -214,6 +231,15 @@ def discover_group_entities(group: Group) -> Iterator[GroupEntity]:
for entity in member.associated_entities:
platform_counts[entity.PLATFORM] += 1

eligible_platforms = {
platform for platform, count in platform_counts.items() if count >= 2
}

for group_entity in tuple(group.group_entities.values()):
if group_entity.PLATFORM in eligible_platforms:
continue
_schedule_group_entity_cleanup(group_entity)

for platform, count in platform_counts.items():
if count < 2:
continue
Expand Down
Loading