diff --git a/tests/test_sensor.py b/tests/test_sensor.py index 6e96304df..2b0009a78 100644 --- a/tests/test_sensor.py +++ b/tests/test_sensor.py @@ -2190,3 +2190,109 @@ async def test_ubisys_polled_em_keeps_polling_when_disabled( # Polling task must still be running (no duplicate created) assert entity._polling_task is not None assert not entity._polling_task.done() + + +async def test_pollable_sensor_enable_non_idempotent_disable_leaves_orphan_poll_task( + zha_gateway: Gateway, +) -> None: + """Test PollableSensor enable/disable lifecycle is idempotent.""" + zigpy_device = elec_measurement_zigpy_device_mock(zha_gateway) + zha_device = await join_zigpy_device(zha_gateway, zigpy_device) + entity = get_entity( + zha_device, + platform=Platform.SENSOR, + exact_entity_type=sensor.PolledElectricalMeasurement, + ) + + assert entity._polling_task is not None + + first_task: asyncio.Task | None = None + second_task: asyncio.Task | None = None + + # Issue being validated: + # PollableSensor.enable() always calls maybe_start_polling() and does not guard + # against an existing polling task, so repeated enable() calls create extra tasks. + # PollableSensor.disable() then only cancels/removes self._polling_task (latest). + # + # Why this is a problem: + # A stale poll task can outlive disable(), continue background work, and leak task + # ownership because only the newest handle is tracked for cancellation. + try: + # Reset baseline from on_add() auto-started polling. + entity.disable() + await asyncio.sleep(0) + + entity.enable() + first_task = entity._polling_task + assert first_task is not None + assert not first_task.done() + + entity.enable() + second_task = entity._polling_task + assert second_task is not None + assert second_task is first_task + assert not second_task.done() + + entity.disable() + await asyncio.sleep(0) + + assert first_task.cancelled() + assert second_task.cancelled() + assert first_task not in entity._tracked_tasks + assert second_task not in entity._tracked_tasks + finally: + for task in (first_task, second_task): + if task is None: + continue + if task in entity._tracked_tasks: + entity._tracked_tasks.remove(task) + if not task.done(): + task.cancel() + await asyncio.gather( + *(task for task in (first_task, second_task) if task is not None), + return_exceptions=True, + ) + + +async def test_pollable_sensor_replaces_completed_polling_task( + zha_gateway: Gateway, +) -> None: + """Test completed poll task handles are replaced cleanly.""" + zigpy_device = elec_measurement_zigpy_device_mock(zha_gateway) + zha_device = await join_zigpy_device(zha_gateway, zigpy_device) + entity = get_entity( + zha_device, + platform=Platform.SENSOR, + exact_entity_type=sensor.PolledElectricalMeasurement, + ) + + # Reset baseline from on_add() auto-started polling. + entity.disable() + await asyncio.sleep(0) + + completed_task = asyncio.create_task(asyncio.sleep(0)) + await completed_task + entity._polling_task = completed_task + entity._tracked_tasks.append(completed_task) + + # Issue being validated: + # maybe_start_polling() must remove a completed poll task from tracking before + # creating the next task handle. + # + # Why this is a problem: + # Completed task handles left in _tracked_tasks can accumulate stale lifecycle + # state and interfere with deterministic cleanup in disable/remove paths. + replacement_task: asyncio.Task | None = None + try: + entity.maybe_start_polling() + replacement_task = entity._polling_task + + assert replacement_task is not None + assert replacement_task is not completed_task + assert completed_task not in entity._tracked_tasks + assert replacement_task in entity._tracked_tasks + finally: + entity.disable() + await asyncio.sleep(0) + if replacement_task and replacement_task in entity._tracked_tasks: + entity._tracked_tasks.remove(replacement_task) diff --git a/zha/application/platforms/sensor/__init__.py b/zha/application/platforms/sensor/__init__.py index f89bd7457..42ddca130 100644 --- a/zha/application/platforms/sensor/__init__.py +++ b/zha/application/platforms/sensor/__init__.py @@ -418,18 +418,27 @@ def should_poll(self) -> bool: def maybe_start_polling(self) -> None: """Start polling if necessary.""" - if self.should_poll: - self._polling_task = self.device.gateway.async_create_background_task( - self._refresh(), - name=f"sensor_state_poller_{self.unique_id}_{self.__class__.__name__}", - eager_start=True, - untracked=True, - ) - self._tracked_tasks.append(self._polling_task) - self.debug( - "started polling with refresh interval of %s", - getattr(self, "__polling_interval"), - ) + if not self.should_poll: + return + + if self._polling_task and not self._polling_task.done(): + return + + if self._polling_task and self._polling_task.done(): + with contextlib.suppress(ValueError): + self._tracked_tasks.remove(self._polling_task) + + self._polling_task = self.device.gateway.async_create_background_task( + self._refresh(), + name=f"sensor_state_poller_{self.unique_id}_{self.__class__.__name__}", + eager_start=True, + untracked=True, + ) + self._tracked_tasks.append(self._polling_task) + self.debug( + "started polling with refresh interval of %s", + getattr(self, "__polling_interval"), + ) def enable(self) -> None: """Enable the entity.""" @@ -440,7 +449,8 @@ def disable(self) -> None: """Disable the entity.""" super().disable() if self._polling_task: - self._tracked_tasks.remove(self._polling_task) + with contextlib.suppress(ValueError): + self._tracked_tasks.remove(self._polling_task) self._polling_task.cancel() self._polling_task = None