Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 62 additions & 0 deletions litellm/proxy/proxy_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,37 @@ def _get_default_unvicorn_init_args(
)
return uvicorn_args

@staticmethod
def _apply_uvicorn_max_requests_jitter(
uvicorn_args: dict,
max_requests_before_restart: Optional[int],
jitter: int,
) -> None:
"""
Stagger uvicorn worker restarts via limit_max_requests_jitter (uvicorn>=0.41.0).
"""
import inspect

import uvicorn

if max_requests_before_restart is None:
print(
"\033[1;33mLiteLLM Proxy: --max_requests_before_restart_jitter "
"has no effect without --max_requests_before_restart\033[0m\n"
)
return
if (
"limit_max_requests_jitter"
in inspect.signature(uvicorn.Config.__init__).parameters
):
uvicorn_args["limit_max_requests_jitter"] = jitter
Comment thread
greptile-apps[bot] marked this conversation as resolved.
else:
print(
f"\033[1;33mLiteLLM Proxy: --max_requests_before_restart_jitter "
f"requires uvicorn>=0.41.0, but installed uvicorn=={uvicorn.__version__}. "
f"Ignoring the flag.\033[0m"
)

@staticmethod
def _get_reload_options(config_path: Optional[str]) -> dict:
"""Build uvicorn reload kwargs so --reload also reacts to .env and YAML edits."""
Expand Down Expand Up @@ -387,6 +418,7 @@ def _run_gunicorn_server(
ssl_certfile_path: str,
ssl_keyfile_path: str,
max_requests_before_restart: Optional[int] = None,
max_requests_before_restart_jitter: Optional[int] = None,
):
"""
Run litellm with `gunicorn`
Expand Down Expand Up @@ -467,6 +499,16 @@ def load(self):
# Optional: recycle workers after N requests to mitigate memory growth
if max_requests_before_restart is not None:
gunicorn_options["max_requests"] = max_requests_before_restart
if max_requests_before_restart_jitter is not None:
if max_requests_before_restart is None:
print(
"\033[1;33mLiteLLM Proxy: --max_requests_before_restart_jitter "
"has no effect without --max_requests_before_restart\033[0m\n"
)
else:
gunicorn_options["max_requests_jitter"] = (
max_requests_before_restart_jitter
)

# Clean up prometheus .db files when a worker exits (prevents ghost gauge values)
if os.environ.get("PROMETHEUS_MULTIPROC_DIR"):
Expand Down Expand Up @@ -791,6 +833,18 @@ def _maybe_setup_prometheus_multiproc_dir(
help="Restart worker after this many requests (uvicorn: limit_max_requests, gunicorn: max_requests)",
envvar="MAX_REQUESTS_BEFORE_RESTART",
)
@click.option(
"--max_requests_before_restart_jitter",
default=None,
type=int,
help=(
"Stagger worker restarts by adding a random amount in [0, jitter] to "
"--max_requests_before_restart so workers do not recycle at the same time "
"(uvicorn: limit_max_requests_jitter, requires uvicorn>=0.41.0; gunicorn: max_requests_jitter). "
"Has no effect without --max_requests_before_restart."
),
envvar="MAX_REQUESTS_BEFORE_RESTART_JITTER",
)
@click.option(
"--enforce_prisma_migration_check",
is_flag=True,
Expand Down Expand Up @@ -858,6 +912,7 @@ def run_server(
keepalive_timeout,
timeout_worker_healthcheck,
max_requests_before_restart,
max_requests_before_restart_jitter: Optional[int],
enforce_prisma_migration_check: bool,
use_v2_migration_resolver: bool,
reload: bool,
Expand Down Expand Up @@ -1260,6 +1315,12 @@ def run_server(
if max_requests_before_restart is not None:
uvicorn_args["limit_max_requests"] = max_requests_before_restart
if run_gunicorn is False and run_hypercorn is False and run_granian is False:
if max_requests_before_restart_jitter is not None:
ProxyInitializationHelpers._apply_uvicorn_max_requests_jitter(
uvicorn_args=uvicorn_args,
max_requests_before_restart=max_requests_before_restart,
jitter=max_requests_before_restart_jitter,
)
if ssl_certfile_path is not None and ssl_keyfile_path is not None:
print(
f"\033[1;32mLiteLLM Proxy: Using SSL with certfile: {ssl_certfile_path} and keyfile: {ssl_keyfile_path}\033[0m\n"
Expand Down Expand Up @@ -1287,6 +1348,7 @@ def run_server(
ssl_certfile_path=ssl_certfile_path,
ssl_keyfile_path=ssl_keyfile_path,
max_requests_before_restart=max_requests_before_restart,
max_requests_before_restart_jitter=max_requests_before_restart_jitter,
)
elif run_hypercorn is True:
ProxyInitializationHelpers._init_hypercorn_server(
Expand Down
225 changes: 225 additions & 0 deletions tests/test_litellm/proxy/test_proxy_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -1178,6 +1178,231 @@ def test_max_requests_before_restart_flag(
call_args = mock_uvicorn_run.call_args
assert call_args[1]["limit_max_requests"] == 123

@patch("uvicorn.run")
@patch("builtins.print")
@patch("litellm.proxy.db.prisma_client.PrismaManager.setup_database")
def test_max_requests_before_restart_jitter_flag(
self, mock_setup_db, mock_print, mock_uvicorn_run
):
"""--max_requests_before_restart_jitter maps to uvicorn limit_max_requests_jitter"""
from click.testing import CliRunner

from litellm.proxy.proxy_cli import run_server

class _NewUvicornConfig:
def __init__(self, limit_max_requests=None, limit_max_requests_jitter=0):
pass

runner = CliRunner()
clean_env = {
k: v
for k, v in os.environ.items()
if k not in ("DATABASE_URL", "DIRECT_URL")
}
with (
patch.dict(os.environ, clean_env, clear=True),
patch("uvicorn.Config", _NewUvicornConfig),
patch.dict(
"sys.modules",
{
"proxy_server": MagicMock(
app=MagicMock(),
ProxyConfig=MagicMock(),
KeyManagementSettings=MagicMock(),
save_worker_config=MagicMock(),
)
},
),
patch(
"litellm.proxy.proxy_cli.ProxyInitializationHelpers._get_default_unvicorn_init_args"
) as mock_get_args,
):
mock_get_args.return_value = {
"app": "litellm.proxy.proxy_server:app",
"host": "localhost",
"port": 8000,
}

result = runner.invoke(
run_server,
[
"--local",
"--max_requests_before_restart",
"1000",
"--max_requests_before_restart_jitter",
"50",
],
)

assert (
result.exit_code == 0
), f"exit_code={result.exit_code}, output={result.output}"
mock_uvicorn_run.assert_called_once()
call_args = mock_uvicorn_run.call_args
assert call_args[1]["limit_max_requests"] == 1000
assert call_args[1]["limit_max_requests_jitter"] == 50

@patch("litellm.proxy.proxy_cli.ProxyInitializationHelpers._run_gunicorn_server")
@patch("uvicorn.run")
@patch("builtins.print")
@patch("litellm.proxy.db.prisma_client.PrismaManager.setup_database")
def test_run_gunicorn_passes_max_requests_jitter(
self, mock_setup_db, mock_print, mock_uvicorn_run, mock_run_gunicorn
):
"""--run_gunicorn threads jitter into _run_gunicorn_server, not uvicorn.run"""
from click.testing import CliRunner

from litellm.proxy.proxy_cli import run_server

runner = CliRunner()
clean_env = {
k: v
for k, v in os.environ.items()
if k not in ("DATABASE_URL", "DIRECT_URL")
}
with (
patch.dict(os.environ, clean_env, clear=True),
patch.dict(
"sys.modules",
{
"proxy_server": MagicMock(
app=MagicMock(),
ProxyConfig=MagicMock(),
KeyManagementSettings=MagicMock(),
save_worker_config=MagicMock(),
)
},
),
patch(
"litellm.proxy.proxy_cli.ProxyInitializationHelpers._get_default_unvicorn_init_args"
) as mock_get_args,
):
mock_get_args.return_value = {
"app": "litellm.proxy.proxy_server:app",
"host": "localhost",
"port": 8000,
}

result = runner.invoke(
run_server,
[
"--local",
"--run_gunicorn",
"--max_requests_before_restart",
"900",
"--max_requests_before_restart_jitter",
"75",
],
)

assert (
result.exit_code == 0
), f"exit_code={result.exit_code}, output={result.output}"
mock_uvicorn_run.assert_not_called()
mock_run_gunicorn.assert_called_once()
g_kwargs = mock_run_gunicorn.call_args[1]
assert g_kwargs["max_requests_before_restart"] == 900
assert g_kwargs["max_requests_before_restart_jitter"] == 75

@pytest.mark.skipif(os.name == "nt", reason="gunicorn server path skips Windows")
def test_gunicorn_options_include_max_requests_jitter(self):
"""_run_gunicorn_server puts max_requests_jitter into the gunicorn options"""
pytest.importorskip("gunicorn")

captured: dict = {}

def capture_run(self):
captured["options"] = dict(self.options)

with patch("gunicorn.app.base.BaseApplication.run", capture_run):
ProxyInitializationHelpers._run_gunicorn_server(
host="127.0.0.1",
port=4010,
app=MagicMock(),
num_workers=2,
ssl_certfile_path=None,
ssl_keyfile_path=None,
max_requests_before_restart=1000,
max_requests_before_restart_jitter=50,
)

assert captured["options"]["max_requests"] == 1000
assert captured["options"]["max_requests_jitter"] == 50

@pytest.mark.skipif(os.name == "nt", reason="gunicorn server path skips Windows")
def test_gunicorn_jitter_without_base_warns(self):
"""gunicorn path warns when jitter is set without --max_requests_before_restart"""
pytest.importorskip("gunicorn")

captured: dict = {}

def capture_run(self):
captured["options"] = dict(self.options)

with (
patch("gunicorn.app.base.BaseApplication.run", capture_run),
patch("builtins.print") as mock_print,
):
ProxyInitializationHelpers._run_gunicorn_server(
host="127.0.0.1",
port=4011,
app=MagicMock(),
num_workers=2,
ssl_certfile_path=None,
ssl_keyfile_path=None,
max_requests_before_restart=None,
max_requests_before_restart_jitter=50,
)

assert "max_requests" not in captured["options"]
assert "max_requests_jitter" not in captured["options"]
assert any("has no effect" in str(c) for c in mock_print.call_args_list)

def test_apply_uvicorn_jitter_sets_arg_when_supported(self):
class _NewUvicornConfig:
def __init__(self, limit_max_requests=None, limit_max_requests_jitter=0):
pass

uvicorn_args: dict = {}
with patch("uvicorn.Config", _NewUvicornConfig):
ProxyInitializationHelpers._apply_uvicorn_max_requests_jitter(
uvicorn_args=uvicorn_args,
max_requests_before_restart=1000,
jitter=50,
)
assert uvicorn_args["limit_max_requests_jitter"] == 50

def test_apply_uvicorn_jitter_skipped_on_old_uvicorn(self):
class _FakeUvicornConfig:
def __init__(self, limit_max_requests=None):
pass

uvicorn_args: dict = {}
with (
patch("uvicorn.Config", _FakeUvicornConfig),
patch("builtins.print") as mock_print,
):
ProxyInitializationHelpers._apply_uvicorn_max_requests_jitter(
uvicorn_args=uvicorn_args,
max_requests_before_restart=1000,
jitter=50,
)

assert "limit_max_requests_jitter" not in uvicorn_args
assert any("0.41.0" in str(c) for c in mock_print.call_args_list)

def test_apply_uvicorn_jitter_without_base_warns(self):
uvicorn_args: dict = {}
with patch("builtins.print") as mock_print:
ProxyInitializationHelpers._apply_uvicorn_max_requests_jitter(
uvicorn_args=uvicorn_args,
max_requests_before_restart=None,
jitter=50,
)

assert "limit_max_requests_jitter" not in uvicorn_args
assert any("has no effect" in str(c) for c in mock_print.call_args_list)

@patch.dict(os.environ, {}, clear=True)
def test_construct_database_url_from_env_vars(self):
"""Test the construct_database_url_from_env_vars function with various scenarios"""
Expand Down
Loading