Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 23 additions & 5 deletions litellm/proxy/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2925,6 +2925,10 @@ def __init__(
self._reconnect_escalation_threshold: int = max(
1, int(os.getenv("PRISMA_RECONNECT_ESCALATION_THRESHOLD", "3"))
)
self._watchdog_failures_before_reconnect: int = max(
1, int(os.getenv("PRISMA_WATCHDOG_FAILURES_BEFORE_RECONNECT", "1"))
)
self._consecutive_probe_failures: int = 0
self._engine_pidfd: int = -1
self._engine_pid: int = 0
self._watching_engine: bool = False
Expand Down Expand Up @@ -4746,11 +4750,12 @@ async def start_db_health_watchdog_task(self) -> None:
self._db_health_watchdog_loop()
)
verbose_proxy_logger.info(
"Started Prisma DB health watchdog (interval=%ss, reconnect_cooldown=%ss, probe_timeout=%ss, reconnect_timeout=%ss)",
"Started Prisma DB health watchdog (interval=%ss, reconnect_cooldown=%ss, probe_timeout=%ss, reconnect_timeout=%ss, failures_before_reconnect=%s)",
self._db_health_watchdog_interval_seconds,
self._db_reconnect_cooldown_seconds,
self._db_health_watchdog_probe_timeout_seconds,
self._db_watchdog_reconnect_timeout_seconds,
self._watchdog_failures_before_reconnect,
)
await self._start_engine_watcher()

Expand All @@ -4775,16 +4780,29 @@ async def _db_health_watchdog_loop(self) -> None:
self.db.query_raw("SELECT 1"),
timeout=self._db_health_watchdog_probe_timeout_seconds,
)
self._consecutive_probe_failures = 0
except asyncio.CancelledError:
break
except Exception as e:
if isinstance(
e, asyncio.TimeoutError
) or PrismaDBExceptionHandler.is_database_connection_error(e):
await self.attempt_db_reconnect(
reason="db_health_watchdog_connection_error",
timeout_seconds=self._db_watchdog_reconnect_timeout_seconds,
)
self._consecutive_probe_failures += 1
if (
self._consecutive_probe_failures
>= self._watchdog_failures_before_reconnect
):
self._consecutive_probe_failures = 0
await self.attempt_db_reconnect(
reason="db_health_watchdog_connection_error",
timeout_seconds=self._db_watchdog_reconnect_timeout_seconds,
)
else:
verbose_proxy_logger.debug(
"Prisma DB watchdog probe failure %d/%d; deferring reconnect.",
self._consecutive_probe_failures,
self._watchdog_failures_before_reconnect,
)
else:
verbose_proxy_logger.debug(
"Prisma DB health watchdog observed non-DB error: %s", e
Expand Down
87 changes: 87 additions & 0 deletions tests/test_litellm/proxy/db/test_prisma_self_heal.py
Original file line number Diff line number Diff line change
Expand Up @@ -486,3 +486,90 @@ async def test_engine_confirmed_dead_persists_across_failed_heavy_reconnect(
# The flag must STILL be True so the next attempt re-enters the heavy
# branch instead of silently demoting to the lightweight path.
assert client._engine_confirmed_dead is True


# ---------------------------------------------------------------------------
# Consecutive-failures gate (PR 3)
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
async def test_db_health_watchdog_defers_reconnect_below_threshold(mock_proxy_logging):
"""With threshold=3, two consecutive probe failures must NOT trigger reconnect."""
client = PrismaClient(
database_url="mock://test", proxy_logging_obj=mock_proxy_logging
)
client.db.query_raw = AsyncMock(side_effect=asyncio.TimeoutError())
client.attempt_db_reconnect = AsyncMock(return_value=True)
client._db_health_watchdog_interval_seconds = 1
client._db_health_watchdog_probe_timeout_seconds = 0.2
client._watchdog_failures_before_reconnect = 3

with patch(
"litellm.proxy.utils.asyncio.sleep",
AsyncMock(side_effect=[None, None, asyncio.CancelledError()]),
):
await client._db_health_watchdog_loop()

client.attempt_db_reconnect.assert_not_awaited()
assert client._consecutive_probe_failures == 2


@pytest.mark.asyncio
async def test_db_health_watchdog_triggers_reconnect_at_threshold(mock_proxy_logging):
"""With threshold=3, exactly three consecutive failures must trigger one reconnect."""
client = PrismaClient(
database_url="mock://test", proxy_logging_obj=mock_proxy_logging
)
client.db.query_raw = AsyncMock(side_effect=asyncio.TimeoutError())
client.attempt_db_reconnect = AsyncMock(return_value=True)
client._db_health_watchdog_interval_seconds = 1
client._db_health_watchdog_probe_timeout_seconds = 0.2
client._db_watchdog_reconnect_timeout_seconds = 7.0
client._watchdog_failures_before_reconnect = 3

with patch(
"litellm.proxy.utils.asyncio.sleep",
AsyncMock(side_effect=[None, None, None, asyncio.CancelledError()]),
):
await client._db_health_watchdog_loop()

client.attempt_db_reconnect.assert_awaited_once_with(
reason="db_health_watchdog_connection_error",
timeout_seconds=7.0,
)


@pytest.mark.asyncio
async def test_db_health_watchdog_resets_failure_counter_on_successful_probe(
mock_proxy_logging,
):
"""A successful probe resets the failure counter so two failures before and two
after do not add up to a threshold of three."""
client = PrismaClient(
database_url="mock://test", proxy_logging_obj=mock_proxy_logging
)
client.db.query_raw = AsyncMock(
side_effect=[
asyncio.TimeoutError(),
asyncio.TimeoutError(),
[{"1": 1}],
asyncio.TimeoutError(),
asyncio.TimeoutError(),
]
)
client.attempt_db_reconnect = AsyncMock(return_value=True)
client._db_health_watchdog_interval_seconds = 1
client._db_health_watchdog_probe_timeout_seconds = 0.2
client._watchdog_failures_before_reconnect = 3

with patch(
"litellm.proxy.utils.asyncio.sleep",
AsyncMock(
side_effect=[None, None, None, None, None, asyncio.CancelledError()]
),
):
await client._db_health_watchdog_loop()

client.attempt_db_reconnect.assert_not_awaited()
assert client._consecutive_probe_failures == 2
Loading