diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py index 0e72f47e224..95aa85cfdb3 100644 --- a/litellm/proxy/utils.py +++ b/litellm/proxy/utils.py @@ -2925,6 +2925,10 @@ def __init__( self._reconnect_escalation_threshold: int = max( 1, int(os.getenv("PRISMA_RECONNECT_ESCALATION_THRESHOLD", "3")) ) + self._watchdog_failures_before_reconnect: int = max( + 1, int(os.getenv("PRISMA_WATCHDOG_FAILURES_BEFORE_RECONNECT", "1")) + ) + self._consecutive_probe_failures: int = 0 self._engine_pidfd: int = -1 self._engine_pid: int = 0 self._watching_engine: bool = False @@ -4746,11 +4750,12 @@ async def start_db_health_watchdog_task(self) -> None: self._db_health_watchdog_loop() ) verbose_proxy_logger.info( - "Started Prisma DB health watchdog (interval=%ss, reconnect_cooldown=%ss, probe_timeout=%ss, reconnect_timeout=%ss)", + "Started Prisma DB health watchdog (interval=%ss, reconnect_cooldown=%ss, probe_timeout=%ss, reconnect_timeout=%ss, failures_before_reconnect=%s)", self._db_health_watchdog_interval_seconds, self._db_reconnect_cooldown_seconds, self._db_health_watchdog_probe_timeout_seconds, self._db_watchdog_reconnect_timeout_seconds, + self._watchdog_failures_before_reconnect, ) await self._start_engine_watcher() @@ -4775,16 +4780,29 @@ async def _db_health_watchdog_loop(self) -> None: self.db.query_raw("SELECT 1"), timeout=self._db_health_watchdog_probe_timeout_seconds, ) + self._consecutive_probe_failures = 0 except asyncio.CancelledError: break except Exception as e: if isinstance( e, asyncio.TimeoutError ) or PrismaDBExceptionHandler.is_database_connection_error(e): - await self.attempt_db_reconnect( - reason="db_health_watchdog_connection_error", - timeout_seconds=self._db_watchdog_reconnect_timeout_seconds, - ) + self._consecutive_probe_failures += 1 + if ( + self._consecutive_probe_failures + >= self._watchdog_failures_before_reconnect + ): + self._consecutive_probe_failures = 0 + await self.attempt_db_reconnect( + reason="db_health_watchdog_connection_error", + timeout_seconds=self._db_watchdog_reconnect_timeout_seconds, + ) + else: + verbose_proxy_logger.debug( + "Prisma DB watchdog probe failure %d/%d; deferring reconnect.", + self._consecutive_probe_failures, + self._watchdog_failures_before_reconnect, + ) else: verbose_proxy_logger.debug( "Prisma DB health watchdog observed non-DB error: %s", e diff --git a/tests/test_litellm/proxy/db/test_prisma_self_heal.py b/tests/test_litellm/proxy/db/test_prisma_self_heal.py index 3f9ba6af3af..e7191023ece 100644 --- a/tests/test_litellm/proxy/db/test_prisma_self_heal.py +++ b/tests/test_litellm/proxy/db/test_prisma_self_heal.py @@ -486,3 +486,90 @@ async def test_engine_confirmed_dead_persists_across_failed_heavy_reconnect( # The flag must STILL be True so the next attempt re-enters the heavy # branch instead of silently demoting to the lightweight path. assert client._engine_confirmed_dead is True + + +# --------------------------------------------------------------------------- +# Consecutive-failures gate (PR 3) +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_db_health_watchdog_defers_reconnect_below_threshold(mock_proxy_logging): + """With threshold=3, two consecutive probe failures must NOT trigger reconnect.""" + client = PrismaClient( + database_url="mock://test", proxy_logging_obj=mock_proxy_logging + ) + client.db.query_raw = AsyncMock(side_effect=asyncio.TimeoutError()) + client.attempt_db_reconnect = AsyncMock(return_value=True) + client._db_health_watchdog_interval_seconds = 1 + client._db_health_watchdog_probe_timeout_seconds = 0.2 + client._watchdog_failures_before_reconnect = 3 + + with patch( + "litellm.proxy.utils.asyncio.sleep", + AsyncMock(side_effect=[None, None, asyncio.CancelledError()]), + ): + await client._db_health_watchdog_loop() + + client.attempt_db_reconnect.assert_not_awaited() + assert client._consecutive_probe_failures == 2 + + +@pytest.mark.asyncio +async def test_db_health_watchdog_triggers_reconnect_at_threshold(mock_proxy_logging): + """With threshold=3, exactly three consecutive failures must trigger one reconnect.""" + client = PrismaClient( + database_url="mock://test", proxy_logging_obj=mock_proxy_logging + ) + client.db.query_raw = AsyncMock(side_effect=asyncio.TimeoutError()) + client.attempt_db_reconnect = AsyncMock(return_value=True) + client._db_health_watchdog_interval_seconds = 1 + client._db_health_watchdog_probe_timeout_seconds = 0.2 + client._db_watchdog_reconnect_timeout_seconds = 7.0 + client._watchdog_failures_before_reconnect = 3 + + with patch( + "litellm.proxy.utils.asyncio.sleep", + AsyncMock(side_effect=[None, None, None, asyncio.CancelledError()]), + ): + await client._db_health_watchdog_loop() + + client.attempt_db_reconnect.assert_awaited_once_with( + reason="db_health_watchdog_connection_error", + timeout_seconds=7.0, + ) + + +@pytest.mark.asyncio +async def test_db_health_watchdog_resets_failure_counter_on_successful_probe( + mock_proxy_logging, +): + """A successful probe resets the failure counter so two failures before and two + after do not add up to a threshold of three.""" + client = PrismaClient( + database_url="mock://test", proxy_logging_obj=mock_proxy_logging + ) + client.db.query_raw = AsyncMock( + side_effect=[ + asyncio.TimeoutError(), + asyncio.TimeoutError(), + [{"1": 1}], + asyncio.TimeoutError(), + asyncio.TimeoutError(), + ] + ) + client.attempt_db_reconnect = AsyncMock(return_value=True) + client._db_health_watchdog_interval_seconds = 1 + client._db_health_watchdog_probe_timeout_seconds = 0.2 + client._watchdog_failures_before_reconnect = 3 + + with patch( + "litellm.proxy.utils.asyncio.sleep", + AsyncMock( + side_effect=[None, None, None, None, None, asyncio.CancelledError()] + ), + ): + await client._db_health_watchdog_loop() + + client.attempt_db_reconnect.assert_not_awaited() + assert client._consecutive_probe_failures == 2