From 0aaa1df9d5ff02e957b53d5f23e41da5fa016c08 Mon Sep 17 00:00:00 2001 From: Yassin Kortam Date: Wed, 17 Jun 2026 09:02:18 -0700 Subject: [PATCH] feat(proxy): add --max_requests_before_restart_jitter to stagger worker restarts Setting --max_requests_before_restart alone recycles every worker at almost the same time once they have served a similar number of requests, which under sustained load can drop a whole pod's capacity at once roughly every 7-10 days. This exposes a jitter knob that adds a random amount in [0, jitter] to the restart threshold per worker so restarts are staggered. It maps to uvicorn's limit_max_requests_jitter and gunicorn's max_requests_jitter. uvicorn only gained limit_max_requests_jitter in 0.41.0 while litellm still allows uvicorn>=0.33.0, so the uvicorn path feature-detects the parameter via the Config signature and warns instead of crashing on older versions. The flag has no effect without --max_requests_before_restart, so the kwarg is not forwarded in that case and a warning is printed on both the uvicorn and gunicorn paths. Resolves LIT-3774 --- litellm/proxy/proxy_cli.py | 62 ++++++ tests/test_litellm/proxy/test_proxy_cli.py | 225 +++++++++++++++++++++ 2 files changed, 287 insertions(+) diff --git a/litellm/proxy/proxy_cli.py b/litellm/proxy/proxy_cli.py index e1fb65074cd..9c4d7b1bb5d 100644 --- a/litellm/proxy/proxy_cli.py +++ b/litellm/proxy/proxy_cli.py @@ -205,6 +205,37 @@ def _get_default_unvicorn_init_args( ) return uvicorn_args + @staticmethod + def _apply_uvicorn_max_requests_jitter( + uvicorn_args: dict, + max_requests_before_restart: Optional[int], + jitter: int, + ) -> None: + """ + Stagger uvicorn worker restarts via limit_max_requests_jitter (uvicorn>=0.41.0). + """ + import inspect + + import uvicorn + + if max_requests_before_restart is None: + print( + "\033[1;33mLiteLLM Proxy: --max_requests_before_restart_jitter " + "has no effect without --max_requests_before_restart\033[0m\n" + ) + return + if ( + "limit_max_requests_jitter" + in inspect.signature(uvicorn.Config.__init__).parameters + ): + uvicorn_args["limit_max_requests_jitter"] = jitter + else: + print( + f"\033[1;33mLiteLLM Proxy: --max_requests_before_restart_jitter " + f"requires uvicorn>=0.41.0, but installed uvicorn=={uvicorn.__version__}. " + f"Ignoring the flag.\033[0m" + ) + @staticmethod def _get_reload_options(config_path: Optional[str]) -> dict: """Build uvicorn reload kwargs so --reload also reacts to .env and YAML edits.""" @@ -387,6 +418,7 @@ def _run_gunicorn_server( ssl_certfile_path: str, ssl_keyfile_path: str, max_requests_before_restart: Optional[int] = None, + max_requests_before_restart_jitter: Optional[int] = None, ): """ Run litellm with `gunicorn` @@ -467,6 +499,16 @@ def load(self): # Optional: recycle workers after N requests to mitigate memory growth if max_requests_before_restart is not None: gunicorn_options["max_requests"] = max_requests_before_restart + if max_requests_before_restart_jitter is not None: + if max_requests_before_restart is None: + print( + "\033[1;33mLiteLLM Proxy: --max_requests_before_restart_jitter " + "has no effect without --max_requests_before_restart\033[0m\n" + ) + else: + gunicorn_options["max_requests_jitter"] = ( + max_requests_before_restart_jitter + ) # Clean up prometheus .db files when a worker exits (prevents ghost gauge values) if os.environ.get("PROMETHEUS_MULTIPROC_DIR"): @@ -791,6 +833,18 @@ def _maybe_setup_prometheus_multiproc_dir( help="Restart worker after this many requests (uvicorn: limit_max_requests, gunicorn: max_requests)", envvar="MAX_REQUESTS_BEFORE_RESTART", ) +@click.option( + "--max_requests_before_restart_jitter", + default=None, + type=int, + help=( + "Stagger worker restarts by adding a random amount in [0, jitter] to " + "--max_requests_before_restart so workers do not recycle at the same time " + "(uvicorn: limit_max_requests_jitter, requires uvicorn>=0.41.0; gunicorn: max_requests_jitter). " + "Has no effect without --max_requests_before_restart." + ), + envvar="MAX_REQUESTS_BEFORE_RESTART_JITTER", +) @click.option( "--enforce_prisma_migration_check", is_flag=True, @@ -858,6 +912,7 @@ def run_server( keepalive_timeout, timeout_worker_healthcheck, max_requests_before_restart, + max_requests_before_restart_jitter: Optional[int], enforce_prisma_migration_check: bool, use_v2_migration_resolver: bool, reload: bool, @@ -1260,6 +1315,12 @@ def run_server( if max_requests_before_restart is not None: uvicorn_args["limit_max_requests"] = max_requests_before_restart if run_gunicorn is False and run_hypercorn is False and run_granian is False: + if max_requests_before_restart_jitter is not None: + ProxyInitializationHelpers._apply_uvicorn_max_requests_jitter( + uvicorn_args=uvicorn_args, + max_requests_before_restart=max_requests_before_restart, + jitter=max_requests_before_restart_jitter, + ) if ssl_certfile_path is not None and ssl_keyfile_path is not None: print( f"\033[1;32mLiteLLM Proxy: Using SSL with certfile: {ssl_certfile_path} and keyfile: {ssl_keyfile_path}\033[0m\n" @@ -1287,6 +1348,7 @@ def run_server( ssl_certfile_path=ssl_certfile_path, ssl_keyfile_path=ssl_keyfile_path, max_requests_before_restart=max_requests_before_restart, + max_requests_before_restart_jitter=max_requests_before_restart_jitter, ) elif run_hypercorn is True: ProxyInitializationHelpers._init_hypercorn_server( diff --git a/tests/test_litellm/proxy/test_proxy_cli.py b/tests/test_litellm/proxy/test_proxy_cli.py index 34c88e2fd33..56627c5be88 100644 --- a/tests/test_litellm/proxy/test_proxy_cli.py +++ b/tests/test_litellm/proxy/test_proxy_cli.py @@ -1178,6 +1178,231 @@ def test_max_requests_before_restart_flag( call_args = mock_uvicorn_run.call_args assert call_args[1]["limit_max_requests"] == 123 + @patch("uvicorn.run") + @patch("builtins.print") + @patch("litellm.proxy.db.prisma_client.PrismaManager.setup_database") + def test_max_requests_before_restart_jitter_flag( + self, mock_setup_db, mock_print, mock_uvicorn_run + ): + """--max_requests_before_restart_jitter maps to uvicorn limit_max_requests_jitter""" + from click.testing import CliRunner + + from litellm.proxy.proxy_cli import run_server + + class _NewUvicornConfig: + def __init__(self, limit_max_requests=None, limit_max_requests_jitter=0): + pass + + runner = CliRunner() + clean_env = { + k: v + for k, v in os.environ.items() + if k not in ("DATABASE_URL", "DIRECT_URL") + } + with ( + patch.dict(os.environ, clean_env, clear=True), + patch("uvicorn.Config", _NewUvicornConfig), + patch.dict( + "sys.modules", + { + "proxy_server": MagicMock( + app=MagicMock(), + ProxyConfig=MagicMock(), + KeyManagementSettings=MagicMock(), + save_worker_config=MagicMock(), + ) + }, + ), + patch( + "litellm.proxy.proxy_cli.ProxyInitializationHelpers._get_default_unvicorn_init_args" + ) as mock_get_args, + ): + mock_get_args.return_value = { + "app": "litellm.proxy.proxy_server:app", + "host": "localhost", + "port": 8000, + } + + result = runner.invoke( + run_server, + [ + "--local", + "--max_requests_before_restart", + "1000", + "--max_requests_before_restart_jitter", + "50", + ], + ) + + assert ( + result.exit_code == 0 + ), f"exit_code={result.exit_code}, output={result.output}" + mock_uvicorn_run.assert_called_once() + call_args = mock_uvicorn_run.call_args + assert call_args[1]["limit_max_requests"] == 1000 + assert call_args[1]["limit_max_requests_jitter"] == 50 + + @patch("litellm.proxy.proxy_cli.ProxyInitializationHelpers._run_gunicorn_server") + @patch("uvicorn.run") + @patch("builtins.print") + @patch("litellm.proxy.db.prisma_client.PrismaManager.setup_database") + def test_run_gunicorn_passes_max_requests_jitter( + self, mock_setup_db, mock_print, mock_uvicorn_run, mock_run_gunicorn + ): + """--run_gunicorn threads jitter into _run_gunicorn_server, not uvicorn.run""" + from click.testing import CliRunner + + from litellm.proxy.proxy_cli import run_server + + runner = CliRunner() + clean_env = { + k: v + for k, v in os.environ.items() + if k not in ("DATABASE_URL", "DIRECT_URL") + } + with ( + patch.dict(os.environ, clean_env, clear=True), + patch.dict( + "sys.modules", + { + "proxy_server": MagicMock( + app=MagicMock(), + ProxyConfig=MagicMock(), + KeyManagementSettings=MagicMock(), + save_worker_config=MagicMock(), + ) + }, + ), + patch( + "litellm.proxy.proxy_cli.ProxyInitializationHelpers._get_default_unvicorn_init_args" + ) as mock_get_args, + ): + mock_get_args.return_value = { + "app": "litellm.proxy.proxy_server:app", + "host": "localhost", + "port": 8000, + } + + result = runner.invoke( + run_server, + [ + "--local", + "--run_gunicorn", + "--max_requests_before_restart", + "900", + "--max_requests_before_restart_jitter", + "75", + ], + ) + + assert ( + result.exit_code == 0 + ), f"exit_code={result.exit_code}, output={result.output}" + mock_uvicorn_run.assert_not_called() + mock_run_gunicorn.assert_called_once() + g_kwargs = mock_run_gunicorn.call_args[1] + assert g_kwargs["max_requests_before_restart"] == 900 + assert g_kwargs["max_requests_before_restart_jitter"] == 75 + + @pytest.mark.skipif(os.name == "nt", reason="gunicorn server path skips Windows") + def test_gunicorn_options_include_max_requests_jitter(self): + """_run_gunicorn_server puts max_requests_jitter into the gunicorn options""" + pytest.importorskip("gunicorn") + + captured: dict = {} + + def capture_run(self): + captured["options"] = dict(self.options) + + with patch("gunicorn.app.base.BaseApplication.run", capture_run): + ProxyInitializationHelpers._run_gunicorn_server( + host="127.0.0.1", + port=4010, + app=MagicMock(), + num_workers=2, + ssl_certfile_path=None, + ssl_keyfile_path=None, + max_requests_before_restart=1000, + max_requests_before_restart_jitter=50, + ) + + assert captured["options"]["max_requests"] == 1000 + assert captured["options"]["max_requests_jitter"] == 50 + + @pytest.mark.skipif(os.name == "nt", reason="gunicorn server path skips Windows") + def test_gunicorn_jitter_without_base_warns(self): + """gunicorn path warns when jitter is set without --max_requests_before_restart""" + pytest.importorskip("gunicorn") + + captured: dict = {} + + def capture_run(self): + captured["options"] = dict(self.options) + + with ( + patch("gunicorn.app.base.BaseApplication.run", capture_run), + patch("builtins.print") as mock_print, + ): + ProxyInitializationHelpers._run_gunicorn_server( + host="127.0.0.1", + port=4011, + app=MagicMock(), + num_workers=2, + ssl_certfile_path=None, + ssl_keyfile_path=None, + max_requests_before_restart=None, + max_requests_before_restart_jitter=50, + ) + + assert "max_requests" not in captured["options"] + assert "max_requests_jitter" not in captured["options"] + assert any("has no effect" in str(c) for c in mock_print.call_args_list) + + def test_apply_uvicorn_jitter_sets_arg_when_supported(self): + class _NewUvicornConfig: + def __init__(self, limit_max_requests=None, limit_max_requests_jitter=0): + pass + + uvicorn_args: dict = {} + with patch("uvicorn.Config", _NewUvicornConfig): + ProxyInitializationHelpers._apply_uvicorn_max_requests_jitter( + uvicorn_args=uvicorn_args, + max_requests_before_restart=1000, + jitter=50, + ) + assert uvicorn_args["limit_max_requests_jitter"] == 50 + + def test_apply_uvicorn_jitter_skipped_on_old_uvicorn(self): + class _FakeUvicornConfig: + def __init__(self, limit_max_requests=None): + pass + + uvicorn_args: dict = {} + with ( + patch("uvicorn.Config", _FakeUvicornConfig), + patch("builtins.print") as mock_print, + ): + ProxyInitializationHelpers._apply_uvicorn_max_requests_jitter( + uvicorn_args=uvicorn_args, + max_requests_before_restart=1000, + jitter=50, + ) + + assert "limit_max_requests_jitter" not in uvicorn_args + assert any("0.41.0" in str(c) for c in mock_print.call_args_list) + + def test_apply_uvicorn_jitter_without_base_warns(self): + uvicorn_args: dict = {} + with patch("builtins.print") as mock_print: + ProxyInitializationHelpers._apply_uvicorn_max_requests_jitter( + uvicorn_args=uvicorn_args, + max_requests_before_restart=None, + jitter=50, + ) + + assert "limit_max_requests_jitter" not in uvicorn_args + assert any("has no effect" in str(c) for c in mock_print.call_args_list) + @patch.dict(os.environ, {}, clear=True) def test_construct_database_url_from_env_vars(self): """Test the construct_database_url_from_env_vars function with various scenarios"""