From 5534b40ab3968a6d04da78bb64809b059612c7bc Mon Sep 17 00:00:00 2001 From: Sameer Kankute Date: Mon, 23 Mar 2026 15:31:43 +0530 Subject: [PATCH 01/33] fix(team-routing): use deterministic team model group names Use a deterministic internal model_name for team-scoped deployments so sibling deployments with the same public model share a routing group. This makes team alias writes idempotent and preserves multi-deployment failover/load balancing behavior. Made-with: Cursor --- .../model_management_endpoints.py | 26 ++- .../test_model_management_endpoints.py | 207 ++++++++++++++---- 2 files changed, 176 insertions(+), 57 deletions(-) diff --git a/litellm/proxy/management_endpoints/model_management_endpoints.py b/litellm/proxy/management_endpoints/model_management_endpoints.py index 44d41097833..39383d4ee20 100644 --- a/litellm/proxy/management_endpoints/model_management_endpoints.py +++ b/litellm/proxy/management_endpoints/model_management_endpoints.py @@ -13,13 +13,13 @@ import asyncio import datetime import json -from litellm._uuid import uuid from typing import Dict, List, Literal, Optional, Tuple, Union, cast from fastapi import APIRouter, Depends, HTTPException, Request, status from pydantic import BaseModel, ConfigDict, Field from litellm._logging import verbose_proxy_logger +from litellm._uuid import uuid from litellm.constants import LITELLM_PROXY_ADMIN_NAME from litellm.proxy._types import ( CommonProxyErrors, @@ -322,9 +322,13 @@ async def _add_team_model_to_db( """ If 'team_id' is provided, - - generate a unique 'model_name' for the model (e.g. 'model_name_{team_id}_{uuid}) - - store the model in the db with the unique 'model_name' - - store a team model alias mapping {"model_name": "model_name_{team_id}_{uuid}"} + - generate a deterministic 'model_name' for the model (e.g. 
'model_name_{team_id}_{public_name}') + - store the model in the db with this shared group name + - store a team model alias mapping {"public_name": "model_name_{team_id}_{public_name}"} + + Using a deterministic name (not UUID) ensures sibling deployments for the + same public model share a model_name, so the router treats them as a single + candidate pool for load balancing and failover. """ _team_id = model_params.model_info.team_id if _team_id is None: @@ -333,9 +337,9 @@ async def _add_team_model_to_db( if original_model_name: model_params.model_info.team_public_model_name = original_model_name - unique_model_name = f"model_name_{_team_id}_{uuid.uuid4()}" + group_model_name = f"model_name_{_team_id}_{original_model_name}" - model_params.model_name = unique_model_name + model_params.model_name = group_model_name ## CREATE MODEL IN DB ## model_response = await _add_model_to_db( @@ -348,7 +352,7 @@ async def _add_team_model_to_db( await update_team( data=UpdateTeamRequest( team_id=_team_id, - model_aliases={original_model_name: unique_model_name}, + model_aliases={original_model_name: group_model_name}, ), user_api_key_dict=user_api_key_dict, http_request=Request(scope={"type": "http"}), @@ -453,14 +457,14 @@ async def _setup_new_team_model_assignment( patch_data: updateDeployment, user_api_key_dict: UserAPIKeyAuth, ) -> None: - """Set up a new team model with unique name, alias, and team membership.""" - unique_model_name = f"model_name_{team_id}_{uuid.uuid4()}" - patch_data.model_name = unique_model_name + """Set up a new team model with deterministic name, alias, and team membership.""" + group_model_name = f"model_name_{team_id}_{public_model_name}" + patch_data.model_name = group_model_name await update_team( data=UpdateTeamRequest( team_id=team_id, - model_aliases={public_model_name: unique_model_name}, + model_aliases={public_model_name: group_model_name}, ), user_api_key_dict=user_api_key_dict, http_request=Request(scope={"type": "http"}), diff --git 
a/tests/test_litellm/proxy/management_endpoints/test_model_management_endpoints.py b/tests/test_litellm/proxy/management_endpoints/test_model_management_endpoints.py index f3c89003105..7f64a3de935 100644 --- a/tests/test_litellm/proxy/management_endpoints/test_model_management_endpoints.py +++ b/tests/test_litellm/proxy/management_endpoints/test_model_management_endpoints.py @@ -1,13 +1,14 @@ import json import os import sys -from litellm._uuid import uuid from typing import Dict, Optional from unittest.mock import AsyncMock, MagicMock, patch import pytest from fastapi.testclient import TestClient +from litellm._uuid import uuid + sys.path.insert( 0, os.path.abspath("../../../..") ) # Adds the parent directory to the system path @@ -399,7 +400,9 @@ async def test_clear_cache_preserve_config_models(self): """ Test that clear_cache clears DB models and preserves config models. """ - from litellm.proxy.management_endpoints.model_management_endpoints import clear_cache + from litellm.proxy.management_endpoints.model_management_endpoints import ( + clear_cache, + ) # Create mock router with mixed DB and config models mock_router = MagicMock() @@ -407,18 +410,18 @@ async def test_clear_cache_preserve_config_models(self): { "model_name": "gpt-4", "model_info": {"id": "db-model-1", "db_model": True}, - "litellm_params": {"model": "gpt-4"} + "litellm_params": {"model": "gpt-4"}, }, { - "model_name": "gpt-3.5-turbo", + "model_name": "gpt-3.5-turbo", "model_info": {"id": "config-model-1", "db_model": False}, - "litellm_params": {"model": "gpt-3.5-turbo"} + "litellm_params": {"model": "gpt-3.5-turbo"}, }, { "model_name": "claude-3", "model_info": {"id": "db-model-2", "db_model": True}, - "litellm_params": {"model": "claude-3"} - } + "litellm_params": {"model": "claude-3"}, + }, ] mock_router.delete_deployment = MagicMock(return_value=True) mock_router.auto_routers = MagicMock() @@ -466,8 +469,8 @@ async def test_public_model_groups_set_after_get_config(self): """ import 
litellm from litellm.proxy.management_endpoints.model_management_endpoints import ( - update_public_model_groups, UpdatePublicModelGroupsRequest, + update_public_model_groups, ) old_db_models = ["db-model-1", "db-model-2"] @@ -525,7 +528,10 @@ async def test_useful_links_set_after_get_config(self): ) old_links = {"Old Doc": "https://old.example.com"} - new_links = {"New Doc": "https://new.example.com", "API Ref": "https://api.example.com"} + new_links = { + "New Doc": "https://new.example.com", + "API Ref": "https://api.example.com", + } async def mock_get_config(*args, **kwargs): litellm.public_model_groups_links = old_links @@ -558,6 +564,100 @@ async def mock_get_config(*args, **kwargs): litellm.public_model_groups_links = original_value +class TestTeamModelAliasSiblingOverwrite: + """ + Verify that two sibling team deployments for the same public model name + produce the same deterministic internal model_name, so the alias write + is idempotent and the router groups both deployments together. 
+ """ + + @pytest.mark.asyncio + async def test_sibling_team_models_share_deterministic_name(self): + from litellm.proxy.management_endpoints.model_management_endpoints import ( + _add_team_model_to_db, + ) + from litellm.types.router import ModelInfo + + team_id = "team_alias_overwrite" + public_name = "gpt-4.1-mini" + + captured_alias_calls = [] + + async def mock_update_team(data, user_api_key_dict, http_request): + if data.model_aliases: + captured_alias_calls.append(dict(data.model_aliases)) + + async def mock_add_model_to_db(model_params, user_api_key_dict, prisma_client): + return MagicMock(model_id=str(uuid.uuid4())) + + async def mock_team_model_add(data, http_request, user_api_key_dict): + pass + + user = UserAPIKeyAuth(user_id="admin", user_role=LitellmUserRoles.PROXY_ADMIN) + prisma_client = MockPrismaClient(team_exists=True) + + deployment_1 = Deployment( + model_name=public_name, + litellm_params=LiteLLM_Params( + model="azure/gpt-4o-mini", + api_key="key-1", + api_base="https://eastus.example.openai.azure.com", + ), + model_info=ModelInfo(team_id=team_id), + ) + deployment_2 = Deployment( + model_name=public_name, + litellm_params=LiteLLM_Params( + model="azure/gpt-4o-mini", + api_key="key-2", + api_base="https://westus.example.openai.azure.com", + ), + model_info=ModelInfo(team_id=team_id), + ) + + with patch( + "litellm.proxy.management_endpoints.model_management_endpoints.update_team", + side_effect=mock_update_team, + ), patch( + "litellm.proxy.management_endpoints.model_management_endpoints._add_model_to_db", + side_effect=mock_add_model_to_db, + ), patch( + "litellm.proxy.management_endpoints.model_management_endpoints.team_model_add", + side_effect=mock_team_model_add, + ): + await _add_team_model_to_db( + model_params=deployment_1, + user_api_key_dict=user, + prisma_client=prisma_client, + ) + await _add_team_model_to_db( + model_params=deployment_2, + user_api_key_dict=user, + prisma_client=prisma_client, + ) + + assert 
len(captured_alias_calls) == 2 + + internal_name_1 = captured_alias_calls[0][public_name] + internal_name_2 = captured_alias_calls[1][public_name] + + expected_group_name = f"model_name_{team_id}_{public_name}" + + # Both sibling deployments get the same deterministic group name + assert internal_name_1 == expected_group_name + assert internal_name_2 == expected_group_name + assert internal_name_1 == internal_name_2, ( + "Sibling deployments must share the same model_name so the " + "router treats them as a single candidate pool" + ) + + # The second alias write is idempotent — same key, same value + final_aliases = {} + for alias_call in captured_alias_calls: + final_aliases.update(alias_call) + assert final_aliases == {public_name: expected_group_name} + + class TestTeamModelUpdate: """Test team model update handles team_id consistently with model creation""" @@ -657,27 +757,37 @@ async def test_model_info_accessible_model_success(self): user_id="test_user", api_key="test_key", models=["gpt-4", "claude-3"], - team_models=["gpt-3.5-turbo"] - ) - - with patch("litellm.proxy.proxy_server.llm_router") as mock_router, \ - patch("litellm.proxy.proxy_server.get_key_models") as mock_get_key_models, \ - patch("litellm.proxy.proxy_server.get_team_models") as mock_get_team_models, \ - patch("litellm.proxy.proxy_server.get_complete_model_list") as mock_get_complete_models, \ - patch("litellm.get_llm_provider") as mock_get_provider: - + team_models=["gpt-3.5-turbo"], + ) + + with patch("litellm.proxy.proxy_server.llm_router") as mock_router, patch( + "litellm.proxy.proxy_server.get_key_models" + ) as mock_get_key_models, patch( + "litellm.proxy.proxy_server.get_team_models" + ) as mock_get_team_models, patch( + "litellm.proxy.proxy_server.get_complete_model_list" + ) as mock_get_complete_models, patch( + "litellm.get_llm_provider" + ) as mock_get_provider: # Setup mocks - mock_router.get_model_names.return_value = ["gpt-4", "claude-3", "gpt-3.5-turbo"] + 
mock_router.get_model_names.return_value = [ + "gpt-4", + "claude-3", + "gpt-3.5-turbo", + ] mock_router.get_model_access_groups.return_value = {} mock_get_key_models.return_value = ["gpt-4", "claude-3"] mock_get_team_models.return_value = ["gpt-3.5-turbo"] - mock_get_complete_models.return_value = ["gpt-4", "claude-3", "gpt-3.5-turbo"] + mock_get_complete_models.return_value = [ + "gpt-4", + "claude-3", + "gpt-3.5-turbo", + ] mock_get_provider.return_value = (None, "openai", None, None) # Test accessible model result = await model_info( - model_id="gpt-4", - user_api_key_dict=user_api_key_dict + model_id="gpt-4", user_api_key_dict=user_api_key_dict ) assert result["id"] == "gpt-4" @@ -688,22 +798,25 @@ async def test_model_info_accessible_model_success(self): @pytest.mark.asyncio async def test_model_info_inaccessible_model_returns_404(self): """Test model_info returns 404 for inaccessible models""" - from litellm.proxy.proxy_server import model_info from fastapi import HTTPException + from litellm.proxy.proxy_server import model_info + # Mock user with limited access user_api_key_dict = UserAPIKeyAuth( user_id="test_user", api_key="test_key", models=["gpt-4"], # Only has access to gpt-4 - team_models=[] + team_models=[], ) - with patch("litellm.proxy.proxy_server.llm_router") as mock_router, \ - patch("litellm.proxy.proxy_server.get_key_models") as mock_get_key_models, \ - patch("litellm.proxy.proxy_server.get_team_models") as mock_get_team_models, \ - patch("litellm.proxy.proxy_server.get_complete_model_list") as mock_get_complete_models: - + with patch("litellm.proxy.proxy_server.llm_router") as mock_router, patch( + "litellm.proxy.proxy_server.get_key_models" + ) as mock_get_key_models, patch( + "litellm.proxy.proxy_server.get_team_models" + ) as mock_get_team_models, patch( + "litellm.proxy.proxy_server.get_complete_model_list" + ) as mock_get_complete_models: # Setup mocks - user only has access to gpt-4 mock_router.get_model_names.return_value = ["gpt-4", 
"claude-3"] mock_router.get_model_access_groups.return_value = {} @@ -715,32 +828,35 @@ async def test_model_info_inaccessible_model_returns_404(self): with pytest.raises(HTTPException) as exc_info: await model_info( model_id="claude-3", # Not in user's accessible models - user_api_key_dict=user_api_key_dict + user_api_key_dict=user_api_key_dict, ) - + assert exc_info.value.status_code == 404 assert "does not exist or is not accessible" in exc_info.value.detail - @pytest.mark.asyncio + @pytest.mark.asyncio async def test_model_info_team_model_access(self): """Test model_info works with team model access""" from litellm.proxy.proxy_server import model_info - + # Mock user with team access user_api_key_dict = UserAPIKeyAuth( user_id="test_user", - api_key="test_key", + api_key="test_key", team_id="test_team", models=[], # No direct key models - team_models=["team-model-1"] - ) - - with patch("litellm.proxy.proxy_server.llm_router") as mock_router, \ - patch("litellm.proxy.proxy_server.get_key_models") as mock_get_key_models, \ - patch("litellm.proxy.proxy_server.get_team_models") as mock_get_team_models, \ - patch("litellm.proxy.proxy_server.get_complete_model_list") as mock_get_complete_models, \ - patch("litellm.get_llm_provider") as mock_get_provider: - + team_models=["team-model-1"], + ) + + with patch("litellm.proxy.proxy_server.llm_router") as mock_router, patch( + "litellm.proxy.proxy_server.get_key_models" + ) as mock_get_key_models, patch( + "litellm.proxy.proxy_server.get_team_models" + ) as mock_get_team_models, patch( + "litellm.proxy.proxy_server.get_complete_model_list" + ) as mock_get_complete_models, patch( + "litellm.get_llm_provider" + ) as mock_get_provider: # Setup mocks mock_router.get_model_names.return_value = ["team-model-1"] mock_router.get_model_access_groups.return_value = {} @@ -751,10 +867,9 @@ async def test_model_info_team_model_access(self): # Test team model access result = await model_info( - model_id="team-model-1", - 
user_api_key_dict=user_api_key_dict + model_id="team-model-1", user_api_key_dict=user_api_key_dict ) assert result["id"] == "team-model-1" - assert result["object"] == "model" + assert result["object"] == "model" assert result["owned_by"] == "custom" From aeb932d707d2b7d7c2bc345b4a4869472fd2d850 Mon Sep 17 00:00:00 2001 From: Sameer Kankute Date: Mon, 23 Mar 2026 16:12:27 +0530 Subject: [PATCH 02/33] fix(team-routing): keep team model routing on public names Remove team model_alias rewrites and resolve team deployments by team_public_model_name with team_id so sibling deployments stay in the routing candidate pool, with explicit logs showing candidate selection before load balancing. Made-with: Cursor --- .../model_management_endpoints.py | 50 ++---- litellm/router.py | 71 +++++++- .../test_model_management_endpoints.py | 160 ++++++++++-------- 3 files changed, 171 insertions(+), 110 deletions(-) diff --git a/litellm/proxy/management_endpoints/model_management_endpoints.py b/litellm/proxy/management_endpoints/model_management_endpoints.py index 39383d4ee20..694fa2b1d55 100644 --- a/litellm/proxy/management_endpoints/model_management_endpoints.py +++ b/litellm/proxy/management_endpoints/model_management_endpoints.py @@ -322,13 +322,9 @@ async def _add_team_model_to_db( """ If 'team_id' is provided, - - generate a deterministic 'model_name' for the model (e.g. 'model_name_{team_id}_{public_name}') - - store the model in the db with this shared group name - - store a team model alias mapping {"public_name": "model_name_{team_id}_{public_name}"} - - Using a deterministic name (not UUID) ensures sibling deployments for the - same public model share a model_name, so the router treats them as a single - candidate pool for load balancing and failover. + - generate a unique 'model_name' for the model (e.g. 
'model_name_{team_id}_{uuid}) + - store the model in the db with the unique 'model_name' + - add the public model name to the team's allowed models list """ _team_id = model_params.model_info.team_id if _team_id is None: @@ -337,9 +333,9 @@ async def _add_team_model_to_db( if original_model_name: model_params.model_info.team_public_model_name = original_model_name - group_model_name = f"model_name_{_team_id}_{original_model_name}" + unique_model_name = f"model_name_{_team_id}_{uuid.uuid4()}" - model_params.model_name = group_model_name + model_params.model_name = unique_model_name ## CREATE MODEL IN DB ## model_response = await _add_model_to_db( @@ -348,17 +344,6 @@ async def _add_team_model_to_db( prisma_client=prisma_client, ) - ## CREATE MODEL ALIAS IN DB ## - await update_team( - data=UpdateTeamRequest( - team_id=_team_id, - model_aliases={original_model_name: group_model_name}, - ), - user_api_key_dict=user_api_key_dict, - http_request=Request(scope={"type": "http"}), - ) - - # add model to team object await team_model_add( data=TeamModelAddRequest( team_id=_team_id, @@ -457,18 +442,9 @@ async def _setup_new_team_model_assignment( patch_data: updateDeployment, user_api_key_dict: UserAPIKeyAuth, ) -> None: - """Set up a new team model with deterministic name, alias, and team membership.""" - group_model_name = f"model_name_{team_id}_{public_model_name}" - patch_data.model_name = group_model_name - - await update_team( - data=UpdateTeamRequest( - team_id=team_id, - model_aliases={public_model_name: group_model_name}, - ), - user_api_key_dict=user_api_key_dict, - http_request=Request(scope={"type": "http"}), - ) + """Set up a new team model with unique name and team membership.""" + unique_model_name = f"model_name_{team_id}_{uuid.uuid4()}" + patch_data.model_name = unique_model_name await team_model_add( data=TeamModelAddRequest( @@ -492,18 +468,16 @@ async def _update_existing_team_model_assignment( db_model.model_info.team_public_model_name if 
db_model.model_info else None ) - # Update alias only if public name changed if old_public_name and public_model_name != old_public_name: - await update_team( - data=UpdateTeamRequest( + await team_model_add( + data=TeamModelAddRequest( team_id=team_id, - model_aliases={public_model_name: db_model.model_name}, + models=[public_model_name], ), - user_api_key_dict=user_api_key_dict, http_request=Request(scope={"type": "http"}), + user_api_key_dict=user_api_key_dict, ) - # Keep existing unique model_name patch_data.model_name = None diff --git a/litellm/router.py b/litellm/router.py index 25e5c9cb5d9..e146d60e359 100644 --- a/litellm/router.py +++ b/litellm/router.py @@ -8148,20 +8148,23 @@ def resolve_model_name_from_model_id( def map_team_model(self, team_model_name: str, team_id: str) -> Optional[str]: """ - Map a team model name to a team-specific model name. + Check if team_model_name resolves to team-specific deployments. + + Returns the public model name (unchanged) so the router can find all + sibling deployments via team_id filtering, instead of collapsing to a + single internal model_name. 
Returns: - - deployment id: str - the deployment id of the team-specific model - - None: if no team-specific model name is found + - str: the team_model_name if team deployments exist for this team + - None: if no team-specific model is found """ models = self.get_model_list(model_name=team_model_name, team_id=team_id) if not models: return None for model in models: if model.get("model_info", {}).get("team_id") == team_id: - return model.get("model_name") + return team_model_name - ## wildcard models return None def should_include_deployment( @@ -8867,6 +8870,38 @@ def _common_checks_available_deployment( model = _model_from_alias if model not in self.model_names: + # Check for team-specific deployments by team_public_model_name + if request_team_id is not None: + team_deployments = self._get_all_deployments( + model_name=model, team_id=request_team_id + ) + if team_deployments: + candidate_details = [] + for deployment in team_deployments: + deployment_info = deployment.get("model_info", {}) or {} + deployment_params = deployment.get("litellm_params", {}) or {} + candidate_details.append( + { + "model_name": deployment.get("model_name"), + "model_id": deployment_info.get("id"), + "team_public_model_name": deployment_info.get( + "team_public_model_name" + ), + "api_base": deployment_params.get("api_base"), + } + ) + verbose_router_logger.info( + "🔥 routing_candidates_before_lb " + f"model={model} count={len(team_deployments)} " + f"candidates={candidate_details}" + ) + if len(team_deployments) > 1: + verbose_router_logger.info( + "🔥 load_balancer_candidate_pool " + f"model={model} candidate_count={len(team_deployments)}" + ) + return model, team_deployments + # check if provider/ specific wildcard routing use pattern matching pattern_deployments = self.pattern_router.get_deployments_by_pattern( model=model, @@ -8905,6 +8940,32 @@ def _common_checks_available_deployment( # check if the user sent in a deployment name instead healthy_deployments = 
self._get_deployment_by_litellm_model(model=model) + if isinstance(healthy_deployments, list) and len(healthy_deployments) > 0: + candidate_details = [] + for deployment in healthy_deployments: + deployment_info = deployment.get("model_info", {}) or {} + deployment_params = deployment.get("litellm_params", {}) or {} + candidate_details.append( + { + "model_name": deployment.get("model_name"), + "model_id": deployment_info.get("id"), + "team_public_model_name": deployment_info.get( + "team_public_model_name" + ), + "api_base": deployment_params.get("api_base"), + } + ) + verbose_router_logger.info( + "🔥 routing_candidates_before_lb " + f"model={model} count={len(healthy_deployments)} " + f"candidates={candidate_details}" + ) + if len(healthy_deployments) > 1: + verbose_router_logger.info( + "🔥 load_balancer_candidate_pool " + f"model={model} candidate_count={len(healthy_deployments)}" + ) + if verbose_router_logger.isEnabledFor(logging.DEBUG): verbose_router_logger.debug( f"initial list of deployments: {healthy_deployments}" diff --git a/tests/test_litellm/proxy/management_endpoints/test_model_management_endpoints.py b/tests/test_litellm/proxy/management_endpoints/test_model_management_endpoints.py index 7f64a3de935..5ef7face1f2 100644 --- a/tests/test_litellm/proxy/management_endpoints/test_model_management_endpoints.py +++ b/tests/test_litellm/proxy/management_endpoints/test_model_management_endpoints.py @@ -564,98 +564,124 @@ async def mock_get_config(*args, **kwargs): litellm.public_model_groups_links = original_value -class TestTeamModelAliasSiblingOverwrite: +class TestTeamModelSiblingRouting: """ - Verify that two sibling team deployments for the same public model name - produce the same deterministic internal model_name, so the alias write - is idempotent and the router groups both deployments together. 
+ Verify that sibling team deployments (same public model name, different + api_base) are all reachable through routing — no alias overwrite, no + collapse to a single deployment. """ @pytest.mark.asyncio - async def test_sibling_team_models_share_deterministic_name(self): + async def test_no_model_aliases_written_for_team_models(self): + """ + _add_team_model_to_db must NOT write model_aliases (which caused + the second sibling to overwrite the first). It should only call + team_model_add to register the public name on the team's models list. + """ from litellm.proxy.management_endpoints.model_management_endpoints import ( _add_team_model_to_db, ) from litellm.types.router import ModelInfo - team_id = "team_alias_overwrite" + team_id = "team_no_alias" public_name = "gpt-4.1-mini" - captured_alias_calls = [] - - async def mock_update_team(data, user_api_key_dict, http_request): - if data.model_aliases: - captured_alias_calls.append(dict(data.model_aliases)) + mock_update_team = AsyncMock() async def mock_add_model_to_db(model_params, user_api_key_dict, prisma_client): return MagicMock(model_id=str(uuid.uuid4())) - async def mock_team_model_add(data, http_request, user_api_key_dict): - pass + mock_team_model_add = AsyncMock() user = UserAPIKeyAuth(user_id="admin", user_role=LitellmUserRoles.PROXY_ADMIN) prisma_client = MockPrismaClient(team_exists=True) - deployment_1 = Deployment( - model_name=public_name, - litellm_params=LiteLLM_Params( - model="azure/gpt-4o-mini", - api_key="key-1", - api_base="https://eastus.example.openai.azure.com", - ), - model_info=ModelInfo(team_id=team_id), - ) - deployment_2 = Deployment( - model_name=public_name, - litellm_params=LiteLLM_Params( - model="azure/gpt-4o-mini", - api_key="key-2", - api_base="https://westus.example.openai.azure.com", - ), - model_info=ModelInfo(team_id=team_id), - ) - - with patch( - "litellm.proxy.management_endpoints.model_management_endpoints.update_team", - side_effect=mock_update_team, - ), patch( - 
"litellm.proxy.management_endpoints.model_management_endpoints._add_model_to_db", - side_effect=mock_add_model_to_db, - ), patch( - "litellm.proxy.management_endpoints.model_management_endpoints.team_model_add", - side_effect=mock_team_model_add, - ): - await _add_team_model_to_db( - model_params=deployment_1, - user_api_key_dict=user, - prisma_client=prisma_client, - ) - await _add_team_model_to_db( - model_params=deployment_2, - user_api_key_dict=user, - prisma_client=prisma_client, + for api_base in ["https://eastus.example.com", "https://westus.example.com"]: + dep = Deployment( + model_name=public_name, + litellm_params=LiteLLM_Params( + model="azure/gpt-4o-mini", + api_key="key", + api_base=api_base, + ), + model_info=ModelInfo(team_id=team_id), ) + with patch( + "litellm.proxy.management_endpoints.model_management_endpoints.update_team", + mock_update_team, + ), patch( + "litellm.proxy.management_endpoints.model_management_endpoints._add_model_to_db", + side_effect=mock_add_model_to_db, + ), patch( + "litellm.proxy.management_endpoints.model_management_endpoints.team_model_add", + mock_team_model_add, + ): + await _add_team_model_to_db( + model_params=dep, + user_api_key_dict=user, + prisma_client=prisma_client, + ) - assert len(captured_alias_calls) == 2 + mock_update_team.assert_not_called() + assert mock_team_model_add.call_count == 2 - internal_name_1 = captured_alias_calls[0][public_name] - internal_name_2 = captured_alias_calls[1][public_name] + @pytest.mark.asyncio + async def test_router_finds_all_sibling_team_deployments(self): + """ + When two team deployments share team_public_model_name="gpt-4.1-mini", + the router's _common_checks_available_deployment must return BOTH as + healthy_deployments (not collapse to one). 
+ """ + import litellm - expected_group_name = f"model_name_{team_id}_{public_name}" + team_id = "teamA" + public_name = "gpt-4.1-mini" - # Both sibling deployments get the same deterministic group name - assert internal_name_1 == expected_group_name - assert internal_name_2 == expected_group_name - assert internal_name_1 == internal_name_2, ( - "Sibling deployments must share the same model_name so the " - "router treats them as a single candidate pool" + router = litellm.Router( + model_list=[ + { + "model_name": f"model_name_{team_id}_uuid1", + "litellm_params": { + "model": "azure/gpt-4o-mini", + "api_key": "key-1", + "api_base": "https://eastus.openai.azure.com", + }, + "model_info": { + "team_id": team_id, + "team_public_model_name": public_name, + }, + }, + { + "model_name": f"model_name_{team_id}_uuid2", + "litellm_params": { + "model": "azure/gpt-4o-mini", + "api_key": "key-2", + "api_base": "https://westus.openai.azure.com", + }, + "model_info": { + "team_id": team_id, + "team_public_model_name": public_name, + }, + }, + ], ) - # The second alias write is idempotent — same key, same value - final_aliases = {} - for alias_call in captured_alias_calls: - final_aliases.update(alias_call) - assert final_aliases == {public_name: expected_group_name} + # map_team_model should return the public name (not an internal UUID) + result = router.map_team_model(public_name, team_id) + assert result == public_name + + # _common_checks_available_deployment should return both deployments + model, healthy = router._common_checks_available_deployment( + model=public_name, + request_kwargs={"metadata": {"user_api_key_team_id": team_id}}, + ) + assert isinstance(healthy, list) + assert len(healthy) == 2 + api_bases = {d["litellm_params"]["api_base"] for d in healthy} + assert api_bases == { + "https://eastus.openai.azure.com", + "https://westus.openai.azure.com", + } class TestTeamModelUpdate: @@ -704,7 +730,7 @@ async def 
test_patch_model_with_team_id_creates_proper_setup(self): assert result.get("model_name", "").startswith("model_name_test_team_123_") assert "team_public_model_name" in str(result.get("model_info", "")) - mock_update_team.assert_called_once() + mock_update_team.assert_not_called() mock_team_model_add.assert_called_once() @pytest.mark.asyncio From 1835e9a25234b7e93cf704e88b2de6feba3e4f76 Mon Sep 17 00:00:00 2001 From: Sameer Kankute Date: Mon, 23 Mar 2026 16:14:30 +0530 Subject: [PATCH 03/33] chore(team-routing): remove temporary candidate pool logs Remove temporary fire-emoji router logs used for local verification while keeping team sibling deployment routing behavior unchanged. Made-with: Cursor --- litellm/router.py | 50 ----------------------------------------------- 1 file changed, 50 deletions(-) diff --git a/litellm/router.py b/litellm/router.py index e146d60e359..247f209e338 100644 --- a/litellm/router.py +++ b/litellm/router.py @@ -8876,30 +8876,6 @@ def _common_checks_available_deployment( model_name=model, team_id=request_team_id ) if team_deployments: - candidate_details = [] - for deployment in team_deployments: - deployment_info = deployment.get("model_info", {}) or {} - deployment_params = deployment.get("litellm_params", {}) or {} - candidate_details.append( - { - "model_name": deployment.get("model_name"), - "model_id": deployment_info.get("id"), - "team_public_model_name": deployment_info.get( - "team_public_model_name" - ), - "api_base": deployment_params.get("api_base"), - } - ) - verbose_router_logger.info( - "🔥 routing_candidates_before_lb " - f"model={model} count={len(team_deployments)} " - f"candidates={candidate_details}" - ) - if len(team_deployments) > 1: - verbose_router_logger.info( - "🔥 load_balancer_candidate_pool " - f"model={model} candidate_count={len(team_deployments)}" - ) return model, team_deployments # check if provider/ specific wildcard routing use pattern matching @@ -8940,32 +8916,6 @@ def 
_common_checks_available_deployment( # check if the user sent in a deployment name instead healthy_deployments = self._get_deployment_by_litellm_model(model=model) - if isinstance(healthy_deployments, list) and len(healthy_deployments) > 0: - candidate_details = [] - for deployment in healthy_deployments: - deployment_info = deployment.get("model_info", {}) or {} - deployment_params = deployment.get("litellm_params", {}) or {} - candidate_details.append( - { - "model_name": deployment.get("model_name"), - "model_id": deployment_info.get("id"), - "team_public_model_name": deployment_info.get( - "team_public_model_name" - ), - "api_base": deployment_params.get("api_base"), - } - ) - verbose_router_logger.info( - "🔥 routing_candidates_before_lb " - f"model={model} count={len(healthy_deployments)} " - f"candidates={candidate_details}" - ) - if len(healthy_deployments) > 1: - verbose_router_logger.info( - "🔥 load_balancer_candidate_pool " - f"model={model} candidate_count={len(healthy_deployments)}" - ) - if verbose_router_logger.isEnabledFor(logging.DEBUG): verbose_router_logger.debug( f"initial list of deployments: {healthy_deployments}" From 7b5e7e05b1fb41e55a4f10adbc3082358aa76a84 Mon Sep 17 00:00:00 2001 From: Sameer Kankute Date: Mon, 23 Mar 2026 16:25:26 +0530 Subject: [PATCH 04/33] fix(router): address Greptile review comments - Add None guard for original_model_name in _add_team_model_to_db - Remove stale old public name when renaming team model - Add comment clarifying team deployment early-return priority Made-with: Cursor --- .../model_management_endpoints.py | 27 +++++++++++++------ litellm/router.py | 4 ++- 2 files changed, 22 insertions(+), 9 deletions(-) diff --git a/litellm/proxy/management_endpoints/model_management_endpoints.py b/litellm/proxy/management_endpoints/model_management_endpoints.py index 694fa2b1d55..c091f6b5812 100644 --- a/litellm/proxy/management_endpoints/model_management_endpoints.py +++ 
b/litellm/proxy/management_endpoints/model_management_endpoints.py @@ -32,6 +32,7 @@ ProxyErrorTypes, ProxyException, TeamModelAddRequest, + TeamModelDeleteRequest, UpdateTeamRequest, UserAPIKeyAuth, ) @@ -40,6 +41,7 @@ from litellm.proxy.management_endpoints.common_utils import _is_user_team_admin from litellm.proxy.management_endpoints.team_endpoints import ( team_model_add, + team_model_delete, update_team, ) from litellm.proxy.management_helpers.audit_logs import create_object_audit_log @@ -344,14 +346,15 @@ async def _add_team_model_to_db( prisma_client=prisma_client, ) - await team_model_add( - data=TeamModelAddRequest( - team_id=_team_id, - models=[original_model_name], - ), - http_request=Request(scope={"type": "http"}), - user_api_key_dict=user_api_key_dict, - ) + if original_model_name: + await team_model_add( + data=TeamModelAddRequest( + team_id=_team_id, + models=[original_model_name], + ), + http_request=Request(scope={"type": "http"}), + user_api_key_dict=user_api_key_dict, + ) return model_response @@ -469,6 +472,14 @@ async def _update_existing_team_model_assignment( ) if old_public_name and public_model_name != old_public_name: + await team_model_delete( + data=TeamModelDeleteRequest( + team_id=team_id, + models=[old_public_name], + ), + http_request=Request(scope={"type": "http"}), + user_api_key_dict=user_api_key_dict, + ) await team_model_add( data=TeamModelAddRequest( team_id=team_id, diff --git a/litellm/router.py b/litellm/router.py index 247f209e338..767938d11fb 100644 --- a/litellm/router.py +++ b/litellm/router.py @@ -8870,7 +8870,9 @@ def _common_checks_available_deployment( model = _model_from_alias if model not in self.model_names: - # Check for team-specific deployments by team_public_model_name + # Check for team-specific deployments by team_public_model_name. + # This intentionally takes priority over team pattern routers below, + # so that named team deployments shadow wildcard/pattern routes. 
if request_team_id is not None: team_deployments = self._get_all_deployments( model_name=model, team_id=request_team_id From 248fb8bc90799de32f793173ff34f40562c3dae4 Mon Sep 17 00:00:00 2001 From: Sameer Kankute Date: Mon, 23 Mar 2026 16:33:26 +0530 Subject: [PATCH 05/33] fix(router): address remaining Greptile P0/P1 issues - Update map_team_model test to expect public name return - Only remove old public name if no sibling deployments use it Made-with: Cursor --- .../model_management_endpoints.py | 34 ++++++++++++++----- .../test_get_model_list_alias_optimization.py | 2 +- 2 files changed, 27 insertions(+), 9 deletions(-) diff --git a/litellm/proxy/management_endpoints/model_management_endpoints.py b/litellm/proxy/management_endpoints/model_management_endpoints.py index c091f6b5812..eb85112a66c 100644 --- a/litellm/proxy/management_endpoints/model_management_endpoints.py +++ b/litellm/proxy/management_endpoints/model_management_endpoints.py @@ -472,14 +472,32 @@ async def _update_existing_team_model_assignment( ) if old_public_name and public_model_name != old_public_name: - await team_model_delete( - data=TeamModelDeleteRequest( - team_id=team_id, - models=[old_public_name], - ), - http_request=Request(scope={"type": "http"}), - user_api_key_dict=user_api_key_dict, - ) + from litellm.proxy.proxy_server import llm_router + + other_deployments_with_old_name = [] + if llm_router: + all_deployments = llm_router.get_model_list( + model_name=old_public_name, team_id=team_id + ) + if all_deployments: + other_deployments_with_old_name = [ + d + for d in all_deployments + if d.get("model_name") != db_model.model_name + and d.get("model_info", {}).get("team_public_model_name") + == old_public_name + ] + + if not other_deployments_with_old_name: + await team_model_delete( + data=TeamModelDeleteRequest( + team_id=team_id, + models=[old_public_name], + ), + http_request=Request(scope={"type": "http"}), + user_api_key_dict=user_api_key_dict, + ) + await team_model_add( 
data=TeamModelAddRequest( team_id=team_id, diff --git a/tests/router_unit_tests/test_get_model_list_alias_optimization.py b/tests/router_unit_tests/test_get_model_list_alias_optimization.py index 31d992b6646..62baf0a3d22 100644 --- a/tests/router_unit_tests/test_get_model_list_alias_optimization.py +++ b/tests/router_unit_tests/test_get_model_list_alias_optimization.py @@ -46,5 +46,5 @@ def test_map_team_model_should_not_iterate_aliases_for_non_alias_team_model_name assert ( router.map_team_model(team_model_name="team-model", team_id="team-1") - == "gpt-3.5-turbo" + == "team-model" ) From ef9ea1f8f20b53cae0ff8050dee04591b6f053f8 Mon Sep 17 00:00:00 2001 From: Sameer Kankute Date: Mon, 23 Mar 2026 16:44:30 +0530 Subject: [PATCH 06/33] fix(router): address Greptile P1/P2 performance issues - Guard against llm_router=None to prevent silent deletion - Add O(1) team_model index to avoid O(n) scan on every team request Made-with: Cursor --- .../model_management_endpoints.py | 26 ++++---- litellm/router.py | 61 +++++++++++++++++++ 2 files changed, 76 insertions(+), 11 deletions(-) diff --git a/litellm/proxy/management_endpoints/model_management_endpoints.py b/litellm/proxy/management_endpoints/model_management_endpoints.py index eb85112a66c..bfd67ea4e59 100644 --- a/litellm/proxy/management_endpoints/model_management_endpoints.py +++ b/litellm/proxy/management_endpoints/model_management_endpoints.py @@ -474,11 +474,15 @@ async def _update_existing_team_model_assignment( if old_public_name and public_model_name != old_public_name: from litellm.proxy.proxy_server import llm_router - other_deployments_with_old_name = [] - if llm_router: + if llm_router is None: + verbose_proxy_logger.warning( + "llm_router not initialized; skipping old public name cleanup to preserve sibling deployments" + ) + else: all_deployments = llm_router.get_model_list( model_name=old_public_name, team_id=team_id ) + other_deployments_with_old_name = [] if all_deployments: 
other_deployments_with_old_name = [ d @@ -488,15 +492,15 @@ async def _update_existing_team_model_assignment( == old_public_name ] - if not other_deployments_with_old_name: - await team_model_delete( - data=TeamModelDeleteRequest( - team_id=team_id, - models=[old_public_name], - ), - http_request=Request(scope={"type": "http"}), - user_api_key_dict=user_api_key_dict, - ) + if not other_deployments_with_old_name: + await team_model_delete( + data=TeamModelDeleteRequest( + team_id=team_id, + models=[old_public_name], + ), + http_request=Request(scope={"type": "http"}), + user_api_key_dict=user_api_key_dict, + ) await team_model_add( data=TeamModelAddRequest( diff --git a/litellm/router.py b/litellm/router.py index 767938d11fb..c9d79beacce 100644 --- a/litellm/router.py +++ b/litellm/router.py @@ -467,6 +467,8 @@ def __init__( # noqa: PLR0915 # Initialize model name to deployment indices mapping for O(1) lookups # Maps model_name -> list of indices in model_list self.model_name_to_deployment_indices: Dict[str, List[int]] = {} + # Maps (team_id, team_public_model_name) -> list of indices in model_list + self.team_model_to_deployment_indices: Dict[Tuple[str, str], List[int]] = {} if model_list is not None: # set_model_list will build indices automatically @@ -6835,6 +6837,7 @@ def set_model_list(self, model_list: list): self.model_list = [] self.model_id_to_deployment_index_map = {} # Reset the index self.model_name_to_deployment_indices = {} # Reset the model_name index + self.team_model_to_deployment_indices = {} # Reset the team_model index self._invalidate_model_group_info_cache() self._invalidate_access_groups_cache() # we add api_base/api_key each model so load balancing between azure/gpt on api_base1 and api_base2 works @@ -7150,6 +7153,26 @@ def _update_deployment_indices_after_removal( else: del self.model_name_to_deployment_indices[model_name] + # Update team_model_to_deployment_indices + for key, indices in list(self.team_model_to_deployment_indices.items()): 
+ # Remove the deleted index + if removal_idx in indices: + indices.remove(removal_idx) + + # Decrement all indices greater than removal_idx + updated_indices = [] + for idx in indices: + if idx > removal_idx: + updated_indices.append(idx - 1) + else: + updated_indices.append(idx) + + # Update or remove the entry + if len(updated_indices) > 0: + self.team_model_to_deployment_indices[key] = updated_indices + else: + del self.team_model_to_deployment_indices[key] + def _add_model_to_list_and_index_map( self, model: dict, model_id: Optional[str] = None ) -> None: @@ -7178,6 +7201,17 @@ def _add_model_to_list_and_index_map( self.model_name_to_deployment_indices[model_name] = [] self.model_name_to_deployment_indices[model_name].append(idx) + # Update team_model index for O(1) team-scoped lookup + team_id = model.get("model_info", {}).get("team_id") + team_public_model_name = model.get("model_info", {}).get( + "team_public_model_name" + ) + if team_id and team_public_model_name: + key = (team_id, team_public_model_name) + if key not in self.team_model_to_deployment_indices: + self.team_model_to_deployment_indices[key] = [] + self.team_model_to_deployment_indices[key].append(idx) + def upsert_deployment(self, deployment: Deployment) -> Optional[Deployment]: """ Add or update deployment @@ -8008,6 +8042,7 @@ def _build_model_name_index(self, model_list: list) -> None: instead of O(n) linear scan through the entire model_list. 
""" self.model_name_to_deployment_indices.clear() + self.team_model_to_deployment_indices.clear() for idx, model in enumerate(model_list): model_name = model.get("model_name") @@ -8016,6 +8051,16 @@ def _build_model_name_index(self, model_list: list) -> None: self.model_name_to_deployment_indices[model_name] = [] self.model_name_to_deployment_indices[model_name].append(idx) + team_id = model.get("model_info", {}).get("team_id") + team_public_model_name = model.get("model_info", {}).get( + "team_public_model_name" + ) + if team_id and team_public_model_name: + key = (team_id, team_public_model_name) + if key not in self.team_model_to_deployment_indices: + self.team_model_to_deployment_indices[key] = [] + self.team_model_to_deployment_indices[key].append(idx) + def _build_model_id_to_deployment_index_map(self, model_list: list): """ Build model index from model list to enable O(1) lookups immediately. @@ -8200,6 +8245,22 @@ def _get_all_deployments( """ returned_models: List[DeploymentTypedDict] = [] + # O(1) lookup in team_model index when team_id is provided + if team_id is not None: + key = (team_id, model_name) + if key in self.team_model_to_deployment_indices: + indices = self.team_model_to_deployment_indices[key] + # O(k) where k = team deployments for this model_name (typically 1-10) + for idx in indices: + model = self.model_list[idx] + if model_alias is not None: + alias_model = model.copy() + alias_model["model_name"] = model_alias + returned_models.append(alias_model) + else: + returned_models.append(model) + return returned_models + # O(1) lookup in model_name index if model_name in self.model_name_to_deployment_indices: indices = self.model_name_to_deployment_indices[model_name] From 4f302f10d0522a3eccf458990402b23567509fdd Mon Sep 17 00:00:00 2001 From: Sameer Kankute Date: Mon, 23 Mar 2026 17:00:30 +0530 Subject: [PATCH 07/33] fix(router): prevent cross-team deployment leakage in fallback path Guard should_include_deployment fallback to only return 
deployments matching the requested team_id, preventing public-name collisions from leaking deployments across teams Made-with: Cursor --- litellm/router.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/litellm/router.py b/litellm/router.py index c9d79beacce..7f428b18edf 100644 --- a/litellm/router.py +++ b/litellm/router.py @@ -8225,7 +8225,8 @@ def should_include_deployment( ): return True elif model_name is not None and model["model_name"] == model_name: - return True + if team_id is None or model["model_info"].get("team_id") == team_id: + return True return False def _get_all_deployments( From f5b72988540c2e195c1ed2c5abfdda4897e9e792 Mon Sep 17 00:00:00 2001 From: Sameer Kankute Date: Mon, 23 Mar 2026 17:10:11 +0530 Subject: [PATCH 08/33] fix(management): query DB directly for sibling deployments on rename - Add clarifying comments to test assertions - Query prisma DB instead of in-memory router to avoid stale state - Prevents incorrect deletion of old public name when siblings exist Made-with: Cursor --- .../model_management_endpoints.py | 32 ++++++++++--------- .../test_model_management_endpoints.py | 9 ++++++ 2 files changed, 26 insertions(+), 15 deletions(-) diff --git a/litellm/proxy/management_endpoints/model_management_endpoints.py b/litellm/proxy/management_endpoints/model_management_endpoints.py index bfd67ea4e59..c38a00e18a1 100644 --- a/litellm/proxy/management_endpoints/model_management_endpoints.py +++ b/litellm/proxy/management_endpoints/model_management_endpoints.py @@ -420,6 +420,7 @@ async def _update_team_model_in_db( db_model=db_model, patch_data=patch_data, user_api_key_dict=user_api_key_dict, + prisma_client=prisma_client, ) return update_db_model(db_model=db_model, updated_patch=patch_data) @@ -465,6 +466,7 @@ async def _update_existing_team_model_assignment( db_model: Deployment, patch_data: updateDeployment, user_api_key_dict: UserAPIKeyAuth, + prisma_client: PrismaClient, ) -> None: """Update an existing team model 
if the public name changed.""" old_public_name = ( @@ -472,25 +474,25 @@ async def _update_existing_team_model_assignment( ) if old_public_name and public_model_name != old_public_name: - from litellm.proxy.proxy_server import llm_router - - if llm_router is None: + if prisma_client is None: verbose_proxy_logger.warning( - "llm_router not initialized; skipping old public name cleanup to preserve sibling deployments" + "prisma_client not initialized; skipping old public name cleanup to preserve sibling deployments" ) else: - all_deployments = llm_router.get_model_list( - model_name=old_public_name, team_id=team_id + response = await prisma_client.db.litellm_proxymodeltable.find_many( + where={ + "model_info": { + "path": ["team_id"], + "equals": team_id, + } + } ) - other_deployments_with_old_name = [] - if all_deployments: - other_deployments_with_old_name = [ - d - for d in all_deployments - if d.get("model_name") != db_model.model_name - and d.get("model_info", {}).get("team_public_model_name") - == old_public_name - ] + other_deployments_with_old_name = [ + d + for d in response + if d.model_name != db_model.model_name + and d.model_info.get("team_public_model_name") == old_public_name + ] if not other_deployments_with_old_name: await team_model_delete( diff --git a/tests/test_litellm/proxy/management_endpoints/test_model_management_endpoints.py b/tests/test_litellm/proxy/management_endpoints/test_model_management_endpoints.py index 5ef7face1f2..fd4f3d56b10 100644 --- a/tests/test_litellm/proxy/management_endpoints/test_model_management_endpoints.py +++ b/tests/test_litellm/proxy/management_endpoints/test_model_management_endpoints.py @@ -46,10 +46,17 @@ async def find_unique(self, where): ) return None + async def find_many(self, where): + return [] + @property def litellm_teamtable(self): return self + @property + def litellm_proxymodeltable(self): + return self + class MockLLMRouter: def __init__(self): @@ -730,7 +737,9 @@ async def 
test_patch_model_with_team_id_creates_proper_setup(self): assert result.get("model_name", "").startswith("model_name_test_team_123_") assert "team_public_model_name" in str(result.get("model_info", "")) + # update_team must not be called (no model_aliases writes for team models) mock_update_team.assert_not_called() + # team_model_add must be called to add public name to team's models list mock_team_model_add.assert_called_once() @pytest.mark.asyncio From 298df75066bb5fda8f5852a5a8aacb127da2815d Mon Sep 17 00:00:00 2001 From: Sameer Kankute Date: Mon, 23 Mar 2026 17:19:11 +0530 Subject: [PATCH 09/33] fix(router): guard None model_info and deduplicate team index logic - Guard against None model_info in sibling deployment check - Extract _update_team_model_index helper to eliminate duplication Made-with: Cursor --- .../model_management_endpoints.py | 3 +- litellm/router.py | 38 ++++++++++--------- 2 files changed, 22 insertions(+), 19 deletions(-) diff --git a/litellm/proxy/management_endpoints/model_management_endpoints.py b/litellm/proxy/management_endpoints/model_management_endpoints.py index c38a00e18a1..4b06c4460e7 100644 --- a/litellm/proxy/management_endpoints/model_management_endpoints.py +++ b/litellm/proxy/management_endpoints/model_management_endpoints.py @@ -491,7 +491,8 @@ async def _update_existing_team_model_assignment( d for d in response if d.model_name != db_model.model_name - and d.model_info.get("team_public_model_name") == old_public_name + and (d.model_info or {}).get("team_public_model_name") + == old_public_name ] if not other_deployments_with_old_name: diff --git a/litellm/router.py b/litellm/router.py index 7f428b18edf..130979d25bc 100644 --- a/litellm/router.py +++ b/litellm/router.py @@ -7173,6 +7173,24 @@ def _update_deployment_indices_after_removal( else: del self.team_model_to_deployment_indices[key] + def _update_team_model_index(self, model: dict, idx: int) -> None: + """ + Helper to update team_model_to_deployment_indices for a single 
deployment. + + Parameters: + - model: dict - the deployment to index + - idx: int - the index in model_list + """ + team_id = model.get("model_info", {}).get("team_id") + team_public_model_name = model.get("model_info", {}).get( + "team_public_model_name" + ) + if team_id and team_public_model_name: + key = (team_id, team_public_model_name) + if key not in self.team_model_to_deployment_indices: + self.team_model_to_deployment_indices[key] = [] + self.team_model_to_deployment_indices[key].append(idx) + def _add_model_to_list_and_index_map( self, model: dict, model_id: Optional[str] = None ) -> None: @@ -7202,15 +7220,7 @@ def _add_model_to_list_and_index_map( self.model_name_to_deployment_indices[model_name].append(idx) # Update team_model index for O(1) team-scoped lookup - team_id = model.get("model_info", {}).get("team_id") - team_public_model_name = model.get("model_info", {}).get( - "team_public_model_name" - ) - if team_id and team_public_model_name: - key = (team_id, team_public_model_name) - if key not in self.team_model_to_deployment_indices: - self.team_model_to_deployment_indices[key] = [] - self.team_model_to_deployment_indices[key].append(idx) + self._update_team_model_index(model, idx) def upsert_deployment(self, deployment: Deployment) -> Optional[Deployment]: """ @@ -8051,15 +8061,7 @@ def _build_model_name_index(self, model_list: list) -> None: self.model_name_to_deployment_indices[model_name] = [] self.model_name_to_deployment_indices[model_name].append(idx) - team_id = model.get("model_info", {}).get("team_id") - team_public_model_name = model.get("model_info", {}).get( - "team_public_model_name" - ) - if team_id and team_public_model_name: - key = (team_id, team_public_model_name) - if key not in self.team_model_to_deployment_indices: - self.team_model_to_deployment_indices[key] = [] - self.team_model_to_deployment_indices[key].append(idx) + self._update_team_model_index(model, idx) def _build_model_id_to_deployment_index_map(self, model_list: 
list): """ From 8aa58bdcaaa3a67000ea0752e7f90c96f92d028d Mon Sep 17 00:00:00 2001 From: Sameer Kankute Date: Mon, 23 Mar 2026 17:33:07 +0530 Subject: [PATCH 10/33] fix(routing): prevent stale model_aliases from interfering with team routing - Skip model_aliases rewrite if model resolves to team deployments - Add test coverage for sibling-preservation branch - Update MockPrismaClient to support sibling deployment scenarios Made-with: Cursor --- litellm/proxy/litellm_pre_call_utils.py | 15 ++++ .../test_model_management_endpoints.py | 70 ++++++++++++++++++- 2 files changed, 83 insertions(+), 2 deletions(-) diff --git a/litellm/proxy/litellm_pre_call_utils.py b/litellm/proxy/litellm_pre_call_utils.py index 4ca0d876a1c..1a7bbe0474b 100644 --- a/litellm/proxy/litellm_pre_call_utils.py +++ b/litellm/proxy/litellm_pre_call_utils.py @@ -1296,6 +1296,10 @@ def _update_model_if_team_alias_exists( "gpt-4o": "gpt-4o-team-1" } - requested_model = "gpt-4o-team-1" + + Note: model_aliases for team models are deprecated. This function only applies + to legacy non-team-scoped aliases. Team-scoped deployments use team_public_model_name + and are resolved via map_team_model in route_llm_request. 
""" _model = data.get("model") if ( @@ -1303,6 +1307,17 @@ def _update_model_if_team_alias_exists( and user_api_key_dict.team_model_aliases and _model in user_api_key_dict.team_model_aliases ): + from litellm.proxy.proxy_server import llm_router + + # Skip alias rewrite if this model resolves to team-specific deployments + # (team models use team_public_model_name, not model_aliases) + if ( + llm_router + and user_api_key_dict.team_id + and llm_router.map_team_model(_model, user_api_key_dict.team_id) is not None + ): + return + data["model"] = user_api_key_dict.team_model_aliases[_model] return diff --git a/tests/test_litellm/proxy/management_endpoints/test_model_management_endpoints.py b/tests/test_litellm/proxy/management_endpoints/test_model_management_endpoints.py index fd4f3d56b10..dcfd5847bd5 100644 --- a/tests/test_litellm/proxy/management_endpoints/test_model_management_endpoints.py +++ b/tests/test_litellm/proxy/management_endpoints/test_model_management_endpoints.py @@ -28,9 +28,15 @@ class MockPrismaClient: - def __init__(self, team_exists: bool = True, user_admin: bool = True): + def __init__( + self, + team_exists: bool = True, + user_admin: bool = True, + sibling_deployments: list = None, + ): self.team_exists = team_exists self.user_admin = user_admin + self.sibling_deployments = sibling_deployments or [] self.db = self async def find_unique(self, where): @@ -47,7 +53,7 @@ async def find_unique(self, where): return None async def find_many(self, where): - return [] + return self.sibling_deployments @property def litellm_teamtable(self): @@ -742,6 +748,66 @@ async def test_patch_model_with_team_id_creates_proper_setup(self): # team_model_add must be called to add public name to team's models list mock_team_model_add.assert_called_once() + @pytest.mark.asyncio + async def test_rename_preserves_old_name_when_siblings_exist(self): + """Test that renaming a deployment preserves old public name when sibling deployments still use it""" + from unittest.mock 
import MagicMock + + from litellm.proxy.management_endpoints.model_management_endpoints import ( + _update_existing_team_model_assignment, + ) + from litellm.types.router import ModelInfo + + # Create a deployment being renamed + db_model = Deployment( + model_name="model_name_team_123_uuid1", + litellm_params=LiteLLM_Params(model="azure/gpt-4o-mini"), + model_info=ModelInfo( + team_id="team_123", team_public_model_name="old-public-name" + ), + ) + + # Create a sibling deployment that still uses the old public name + sibling_deployment = MagicMock() + sibling_deployment.model_name = "model_name_team_123_uuid2" + sibling_deployment.model_info = { + "team_id": "team_123", + "team_public_model_name": "old-public-name", + } + + prisma_client = MockPrismaClient( + team_exists=True, sibling_deployments=[sibling_deployment] + ) + + patch_data = updateDeployment( + model_name="new-public-name", + model_info=ModelInfo(team_id="team_123"), + ) + + user_api_key_dict = UserAPIKeyAuth( + user_id="test_user", + user_role=LitellmUserRoles.PROXY_ADMIN, + ) + + with patch( + "litellm.proxy.management_endpoints.model_management_endpoints.team_model_delete" + ) as mock_delete, patch( + "litellm.proxy.management_endpoints.model_management_endpoints.team_model_add" + ) as mock_add: + await _update_existing_team_model_assignment( + team_id="team_123", + public_model_name="new-public-name", + db_model=db_model, + patch_data=patch_data, + user_api_key_dict=user_api_key_dict, + prisma_client=prisma_client, # type: ignore + ) + + # team_model_delete should NOT be called because sibling exists + mock_delete.assert_not_called() + # team_model_add should be called to add new public name + mock_add.assert_called_once() + @pytest.mark.asyncio async def test_patch_model_with_team_id_validates_permissions(self): """Test PATCH with team_id runs same validation as POST for team permissions""" From e8fb7762b345d643a0431f8e9f5bd7b1a074d8d6 Mon Sep 17 00:00:00 2001 From: Sameer Kankute Date: Mon, 23 
Mar 2026 17:56:03 +0530 Subject: [PATCH 11/33] perf(routing): optimize team model checks and improve test coverage - Use O(1) team index lookup instead of map_team_model in alias guard - Fix MockPrismaClient to validate where clause filters - Add comment explaining DB query trade-off for team deployments Made-with: Cursor --- litellm/proxy/litellm_pre_call_utils.py | 11 ++++----- .../model_management_endpoints.py | 4 ++++ .../test_model_management_endpoints.py | 23 +++++++++++++++++++ 3 files changed, 32 insertions(+), 6 deletions(-) diff --git a/litellm/proxy/litellm_pre_call_utils.py b/litellm/proxy/litellm_pre_call_utils.py index 1a7bbe0474b..48e83f1395b 100644 --- a/litellm/proxy/litellm_pre_call_utils.py +++ b/litellm/proxy/litellm_pre_call_utils.py @@ -1311,12 +1311,11 @@ def _update_model_if_team_alias_exists( # Skip alias rewrite if this model resolves to team-specific deployments # (team models use team_public_model_name, not model_aliases) - if ( - llm_router - and user_api_key_dict.team_id - and llm_router.map_team_model(_model, user_api_key_dict.team_id) is not None - ): - return + # Use O(1) index lookup instead of map_team_model to avoid O(n) scan + if llm_router and user_api_key_dict.team_id: + key = (user_api_key_dict.team_id, _model) + if key in llm_router.team_model_to_deployment_indices: + return data["model"] = user_api_key_dict.team_model_aliases[_model] return diff --git a/litellm/proxy/management_endpoints/model_management_endpoints.py b/litellm/proxy/management_endpoints/model_management_endpoints.py index 4b06c4460e7..7682b65657a 100644 --- a/litellm/proxy/management_endpoints/model_management_endpoints.py +++ b/litellm/proxy/management_endpoints/model_management_endpoints.py @@ -479,6 +479,10 @@ async def _update_existing_team_model_assignment( "prisma_client not initialized; skipping old public name cleanup to preserve sibling deployments" ) else: + # Query DB for all deployments in this team, then filter by public name. 
+ # Note: Prisma's JSON filtering doesn't support compound AND conditions + # across multiple JSON paths, so we filter team_public_model_name in Python. + # For most teams (typically <100 deployments), this is acceptable. response = await prisma_client.db.litellm_proxymodeltable.find_many( where={ "model_info": { diff --git a/tests/test_litellm/proxy/management_endpoints/test_model_management_endpoints.py b/tests/test_litellm/proxy/management_endpoints/test_model_management_endpoints.py index dcfd5847bd5..09410c19d34 100644 --- a/tests/test_litellm/proxy/management_endpoints/test_model_management_endpoints.py +++ b/tests/test_litellm/proxy/management_endpoints/test_model_management_endpoints.py @@ -53,6 +53,29 @@ async def find_unique(self, where): return None async def find_many(self, where): + # Filter sibling deployments by team_id if where clause specifies it + if not self.sibling_deployments: + return [] + + # Extract team_id from where clause if present + team_id_filter = None + if where and "model_info" in where: + model_info_filter = where["model_info"] + if isinstance(model_info_filter, dict) and "path" in model_info_filter: + if ( + model_info_filter["path"] == ["team_id"] + and "equals" in model_info_filter + ): + team_id_filter = model_info_filter["equals"] + + # Filter deployments by team_id if specified + if team_id_filter: + return [ + d + for d in self.sibling_deployments + if d.model_info.get("team_id") == team_id_filter + ] + return self.sibling_deployments @property From 8db867c51c2bb4a8995308439ad2a4f2cc474390 Mon Sep 17 00:00:00 2001 From: Sameer Kankute Date: Mon, 23 Mar 2026 18:13:35 +0530 Subject: [PATCH 12/33] fix(routing): address state consistency and type safety issues - Check alias target pattern to detect stale team aliases - Fix PrismaClient type annotation to Optional - Eliminate in-place mutation in index update logic Made-with: Cursor --- litellm/proxy/litellm_pre_call_utils.py | 21 ++++++++++----- .../model_management_endpoints.py 
| 2 +- litellm/router.py | 26 ++++++++++--------- 3 files changed, 29 insertions(+), 20 deletions(-) diff --git a/litellm/proxy/litellm_pre_call_utils.py b/litellm/proxy/litellm_pre_call_utils.py index 48e83f1395b..96a271dd027 100644 --- a/litellm/proxy/litellm_pre_call_utils.py +++ b/litellm/proxy/litellm_pre_call_utils.py @@ -1311,13 +1311,20 @@ def _update_model_if_team_alias_exists( # Skip alias rewrite if this model resolves to team-specific deployments # (team models use team_public_model_name, not model_aliases) - # Use O(1) index lookup instead of map_team_model to avoid O(n) scan - if llm_router and user_api_key_dict.team_id: - key = (user_api_key_dict.team_id, _model) - if key in llm_router.team_model_to_deployment_indices: - return - - data["model"] = user_api_key_dict.team_model_aliases[_model] + aliased_target = user_api_key_dict.team_model_aliases[_model] + + # Check if the alias points to a stale team-scoped UUID name + # (format: "model_name_{team_id}_{uuid}") + if aliased_target.startswith(f"model_name_{user_api_key_dict.team_id}_"): + # This is a stale alias from pre-PR deployments. + # Check if current team deployments exist for the public name. 
+ if llm_router: + key = (user_api_key_dict.team_id, _model) + if key in llm_router.team_model_to_deployment_indices: + # Team deployments exist; skip stale alias + return + + data["model"] = aliased_target return diff --git a/litellm/proxy/management_endpoints/model_management_endpoints.py b/litellm/proxy/management_endpoints/model_management_endpoints.py index 7682b65657a..7d0181a3687 100644 --- a/litellm/proxy/management_endpoints/model_management_endpoints.py +++ b/litellm/proxy/management_endpoints/model_management_endpoints.py @@ -466,7 +466,7 @@ async def _update_existing_team_model_assignment( db_model: Deployment, patch_data: updateDeployment, user_api_key_dict: UserAPIKeyAuth, - prisma_client: PrismaClient, + prisma_client: Optional[PrismaClient], ) -> None: """Update an existing team model if the public name changed.""" old_public_name = ( diff --git a/litellm/router.py b/litellm/router.py index 130979d25bc..1a57d2f8c0d 100644 --- a/litellm/router.py +++ b/litellm/router.py @@ -7135,16 +7135,17 @@ def _update_deployment_indices_after_removal( # Update model_name_to_deployment_indices for model_name, indices in list(self.model_name_to_deployment_indices.items()): - # Remove the deleted index - if removal_idx in indices: - indices.remove(removal_idx) - - # Decrement all indices greater than removal_idx + # Build new list without mutating the original updated_indices = [] for idx in indices: - if idx > removal_idx: + if idx == removal_idx: + # Skip the removed index + continue + elif idx > removal_idx: + # Decrement indices after removal updated_indices.append(idx - 1) else: + # Keep indices before removal unchanged updated_indices.append(idx) # Update or remove the entry @@ -7155,16 +7156,17 @@ def _update_deployment_indices_after_removal( # Update team_model_to_deployment_indices for key, indices in list(self.team_model_to_deployment_indices.items()): - # Remove the deleted index - if removal_idx in indices: - indices.remove(removal_idx) - - # Decrement 
all indices greater than removal_idx + # Build new list without mutating the original updated_indices = [] for idx in indices: - if idx > removal_idx: + if idx == removal_idx: + # Skip the removed index + continue + elif idx > removal_idx: + # Decrement indices after removal updated_indices.append(idx - 1) else: + # Keep indices before removal unchanged updated_indices.append(idx) # Update or remove the entry From 173695f5e0ed6e2e8933fbec341cfdc6162843dc Mon Sep 17 00:00:00 2001 From: Sameer Kankute Date: Mon, 23 Mar 2026 18:26:43 +0530 Subject: [PATCH 13/33] Fix greptile comments --- litellm/proxy/litellm_pre_call_utils.py | 12 +++- .../model_management_endpoints.py | 20 +++++- tests/proxy_unit_tests/test_proxy_utils.py | 43 ++++++++++++ .../test_model_management_endpoints.py | 70 ++++++++++++++++++- 4 files changed, 140 insertions(+), 5 deletions(-) diff --git a/litellm/proxy/litellm_pre_call_utils.py b/litellm/proxy/litellm_pre_call_utils.py index 96a271dd027..4a12a0a5774 100644 --- a/litellm/proxy/litellm_pre_call_utils.py +++ b/litellm/proxy/litellm_pre_call_utils.py @@ -26,6 +26,7 @@ v.value.lower() for v in SpecialHeaders._member_map_.values() ) from litellm.router import Router +from litellm.secret_managers.main import get_secret_bool from litellm.types.llms.anthropic import ANTHROPIC_API_HEADERS from litellm.types.services import ServiceTypes from litellm.types.utils import ( @@ -1313,9 +1314,16 @@ def _update_model_if_team_alias_exists( # (team models use team_public_model_name, not model_aliases) aliased_target = user_api_key_dict.team_model_aliases[_model] - # Check if the alias points to a stale team-scoped UUID name + # Optional bypass for stale aliases from pre-PR deployments: + # only enabled via feature flag to preserve backwards compatibility. 
+ enable_stale_alias_bypass = get_secret_bool( + "LITELLM_ENABLE_TEAM_STALE_ALIAS_BYPASS", False + ) + # Check if the alias points to a team-scoped UUID name # (format: "model_name_{team_id}_{uuid}") - if aliased_target.startswith(f"model_name_{user_api_key_dict.team_id}_"): + if enable_stale_alias_bypass and aliased_target.startswith( + f"model_name_{user_api_key_dict.team_id}_" + ): # This is a stale alias from pre-PR deployments. # Check if current team deployments exist for the public name. if llm_router: diff --git a/litellm/proxy/management_endpoints/model_management_endpoints.py b/litellm/proxy/management_endpoints/model_management_endpoints.py index 7d0181a3687..40f4d722dc6 100644 --- a/litellm/proxy/management_endpoints/model_management_endpoints.py +++ b/litellm/proxy/management_endpoints/model_management_endpoints.py @@ -469,6 +469,23 @@ async def _update_existing_team_model_assignment( prisma_client: Optional[PrismaClient], ) -> None: """Update an existing team model if the public name changed.""" + + def _get_team_public_model_name( + model_info: Optional[Union[dict, str]] + ) -> Optional[str]: + if isinstance(model_info, dict): + value = model_info.get("team_public_model_name") + return value if isinstance(value, str) else None + if isinstance(model_info, str): + try: + parsed = json.loads(model_info) + except (TypeError, ValueError): + return None + if isinstance(parsed, dict): + value = parsed.get("team_public_model_name") + return value if isinstance(value, str) else None + return None + old_public_name = ( db_model.model_info.team_public_model_name if db_model.model_info else None ) @@ -495,8 +512,7 @@ async def _update_existing_team_model_assignment( d for d in response if d.model_name != db_model.model_name - and (d.model_info or {}).get("team_public_model_name") - == old_public_name + and _get_team_public_model_name(d.model_info) == old_public_name ] if not other_deployments_with_old_name: diff --git a/tests/proxy_unit_tests/test_proxy_utils.py 
b/tests/proxy_unit_tests/test_proxy_utils.py index 00d4cd24e4b..5e75890388c 100644 --- a/tests/proxy_unit_tests/test_proxy_utils.py +++ b/tests/proxy_unit_tests/test_proxy_utils.py @@ -2044,6 +2044,49 @@ def test_update_model_if_team_alias_exists(data, user_api_key_dict, expected_mod assert test_data.get("model") == expected_model +def test_team_alias_stale_bypass_disabled_by_default(): + from litellm.proxy.litellm_pre_call_utils import _update_model_if_team_alias_exists + + class _MockRouter: + team_model_to_deployment_indices = {("team-1", "gpt-4o"): [0]} + + test_data = {"model": "gpt-4o"} + user_api_key_dict = UserAPIKeyAuth( + api_key="test_key", + team_id="team-1", + team_model_aliases={"gpt-4o": "model_name_team-1_legacy-uuid"}, + ) + + with patch("litellm.proxy.proxy_server.llm_router", _MockRouter()): + _update_model_if_team_alias_exists( + data=test_data, user_api_key_dict=user_api_key_dict + ) + + assert test_data.get("model") == "model_name_team-1_legacy-uuid" + + +def test_team_alias_stale_bypass_enabled_by_flag(monkeypatch): + from litellm.proxy.litellm_pre_call_utils import _update_model_if_team_alias_exists + + class _MockRouter: + team_model_to_deployment_indices = {("team-1", "gpt-4o"): [0]} + + test_data = {"model": "gpt-4o"} + user_api_key_dict = UserAPIKeyAuth( + api_key="test_key", + team_id="team-1", + team_model_aliases={"gpt-4o": "model_name_team-1_legacy-uuid"}, + ) + monkeypatch.setenv("LITELLM_ENABLE_TEAM_STALE_ALIAS_BYPASS", "true") + + with patch("litellm.proxy.proxy_server.llm_router", _MockRouter()): + _update_model_if_team_alias_exists( + data=test_data, user_api_key_dict=user_api_key_dict + ) + + assert test_data.get("model") == "gpt-4o" + + @pytest.fixture def mock_prisma_client(): client = MagicMock() diff --git a/tests/test_litellm/proxy/management_endpoints/test_model_management_endpoints.py b/tests/test_litellm/proxy/management_endpoints/test_model_management_endpoints.py index 09410c19d34..83e6b0c93a5 100644 --- 
a/tests/test_litellm/proxy/management_endpoints/test_model_management_endpoints.py +++ b/tests/test_litellm/proxy/management_endpoints/test_model_management_endpoints.py @@ -70,10 +70,23 @@ async def find_many(self, where): # Filter deployments by team_id if specified if team_id_filter: + + def _get_team_id(model_info): + if isinstance(model_info, dict): + return model_info.get("team_id") + if isinstance(model_info, str): + try: + parsed = json.loads(model_info) + except (TypeError, ValueError): + return None + if isinstance(parsed, dict): + return parsed.get("team_id") + return None + return [ d for d in self.sibling_deployments - if d.model_info.get("team_id") == team_id_filter + if _get_team_id(d.model_info) == team_id_filter ] return self.sibling_deployments @@ -831,6 +844,61 @@ async def test_rename_preserves_old_name_when_siblings_exist(self): # team_model_add should be called to add new public name mock_add.assert_called_once() + @pytest.mark.asyncio + async def test_rename_handles_legacy_string_model_info(self): + """Test rename path handles legacy string-encoded model_info rows without crashing.""" + from unittest.mock import MagicMock + + from litellm.proxy.management_endpoints.model_management_endpoints import ( + _update_existing_team_model_assignment, + ) + from litellm.types.router import ModelInfo + + db_model = Deployment( + model_name="model_name_team_123_uuid1", + litellm_params=LiteLLM_Params(model="azure/gpt-4o-mini"), + model_info=ModelInfo( + team_id="team_123", team_public_model_name="old-public-name" + ), + ) + + sibling_deployment = MagicMock() + sibling_deployment.model_name = "model_name_team_123_uuid2" + sibling_deployment.model_info = ( + '{"team_id":"team_123","team_public_model_name":"old-public-name"}' + ) + + prisma_client = MockPrismaClient( + team_exists=True, sibling_deployments=[sibling_deployment] + ) + + patch_data = updateDeployment( + model_name="new-public-name", + model_info=ModelInfo(team_id="team_123"), + ) + + 
user_api_key_dict = UserAPIKeyAuth( + user_id="test_user", + user_role=LitellmUserRoles.PROXY_ADMIN, + ) + + with patch( + "litellm.proxy.management_endpoints.model_management_endpoints.team_model_delete" + ) as mock_delete, patch( + "litellm.proxy.management_endpoints.model_management_endpoints.team_model_add" + ) as mock_add: + await _update_existing_team_model_assignment( + team_id="team_123", + public_model_name="new-public-name", + db_model=db_model, + patch_data=patch_data, + user_api_key_dict=user_api_key_dict, + prisma_client=prisma_client, # type: ignore + ) + + mock_delete.assert_not_called() + mock_add.assert_called_once() + @pytest.mark.asyncio async def test_patch_model_with_team_id_validates_permissions(self): """Test PATCH with team_id runs same validation as POST for team permissions""" From 303072dc44ed6a1b0d4b0d7decde03cf7af33e63 Mon Sep 17 00:00:00 2001 From: Sameer Kankute Date: Mon, 23 Mar 2026 19:16:38 +0530 Subject: [PATCH 14/33] Fix greptile comments --- .../model_management_endpoints.py | 27 +++++++++++-------- litellm/router.py | 4 +++ 2 files changed, 20 insertions(+), 11 deletions(-) diff --git a/litellm/proxy/management_endpoints/model_management_endpoints.py b/litellm/proxy/management_endpoints/model_management_endpoints.py index 40f4d722dc6..1acedfb346a 100644 --- a/litellm/proxy/management_endpoints/model_management_endpoints.py +++ b/litellm/proxy/management_endpoints/model_management_endpoints.py @@ -468,7 +468,13 @@ async def _update_existing_team_model_assignment( user_api_key_dict: UserAPIKeyAuth, prisma_client: Optional[PrismaClient], ) -> None: - """Update an existing team model if the public name changed.""" + """Update an existing team model if the public name changed. + + Note on DB scan: Prisma's JSON filtering does not support compound AND conditions + across multiple JSON paths, so we fetch all deployments for the team and filter + team_public_model_name in Python. 
For teams with many deployments this scan grows + linearly; if team deployment counts become large this should be revisited. + """ def _get_team_public_model_name( model_info: Optional[Union[dict, str]] @@ -496,10 +502,6 @@ def _get_team_public_model_name( "prisma_client not initialized; skipping old public name cleanup to preserve sibling deployments" ) else: - # Query DB for all deployments in this team, then filter by public name. - # Note: Prisma's JSON filtering doesn't support compound AND conditions - # across multiple JSON paths, so we filter team_public_model_name in Python. - # For most teams (typically <100 deployments), this is acceptable. response = await prisma_client.db.litellm_proxymodeltable.find_many( where={ "model_info": { @@ -508,12 +510,15 @@ def _get_team_public_model_name( } } ) - other_deployments_with_old_name = [ - d - for d in response - if d.model_name != db_model.model_name - and _get_team_public_model_name(d.model_info) == old_public_name - ] + if not response: + other_deployments_with_old_name = [] + else: + other_deployments_with_old_name = [ + d + for d in response + if d.model_name != db_model.model_name + and _get_team_public_model_name(d.model_info) == old_public_name + ] if not other_deployments_with_old_name: await team_model_delete( diff --git a/litellm/router.py b/litellm/router.py index 1a57d2f8c0d..ac965a0af5b 100644 --- a/litellm/router.py +++ b/litellm/router.py @@ -8258,6 +8258,10 @@ def _get_all_deployments( # O(k) where k = team deployments for this model_name (typically 1-10) for idx in indices: model = self.model_list[idx] + if not self.should_include_deployment( + model_name=model_name, model=model, team_id=team_id + ): + continue if model_alias is not None: alias_model = model.copy() alias_model["model_name"] = model_alias From fc6865c3a3c460a3b01ba6ebe84718e5d5f0046f Mon Sep 17 00:00:00 2001 From: Sameer Kankute Date: Mon, 23 Mar 2026 19:37:05 +0530 Subject: [PATCH 15/33] Fix greptile comments --- 
litellm/router.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/litellm/router.py b/litellm/router.py index ac965a0af5b..19a0f250dc1 100644 --- a/litellm/router.py +++ b/litellm/router.py @@ -7183,8 +7183,8 @@ def _update_team_model_index(self, model: dict, idx: int) -> None: - model: dict - the deployment to index - idx: int - the index in model_list """ - team_id = model.get("model_info", {}).get("team_id") - team_public_model_name = model.get("model_info", {}).get( + team_id = (model.get("model_info") or {}).get("team_id") + team_public_model_name = (model.get("model_info") or {}).get( "team_public_model_name" ) if team_id and team_public_model_name: @@ -7242,7 +7242,10 @@ def upsert_deployment(self, deployment: Deployment) -> Optional[Deployment]: ) if _deployment_on_router is not None: # deployment with this model_id exists on the router - if deployment.litellm_params == _deployment_on_router.litellm_params: + if ( + deployment.litellm_params == _deployment_on_router.litellm_params + and deployment.model_info == _deployment_on_router.model_info + ): # No need to update return None @@ -8268,7 +8271,8 @@ def _get_all_deployments( returned_models.append(alias_model) else: returned_models.append(model) - return returned_models + if returned_models: + return returned_models # O(1) lookup in model_name index if model_name in self.model_name_to_deployment_indices: From d02a70ab4e632f4bec3e7e5ab42c2e6ebf2fa74b Mon Sep 17 00:00:00 2001 From: Sameer Kankute Date: Mon, 23 Mar 2026 19:50:33 +0530 Subject: [PATCH 16/33] Fix greptile comments --- .../model_management_endpoints.py | 55 ++++++++++--------- litellm/router.py | 10 +++- 2 files changed, 35 insertions(+), 30 deletions(-) diff --git a/litellm/proxy/management_endpoints/model_management_endpoints.py b/litellm/proxy/management_endpoints/model_management_endpoints.py index 1acedfb346a..d8a1075a168 100644 --- a/litellm/proxy/management_endpoints/model_management_endpoints.py +++ 
b/litellm/proxy/management_endpoints/model_management_endpoints.py @@ -499,36 +499,37 @@ def _get_team_public_model_name( if old_public_name and public_model_name != old_public_name: if prisma_client is None: verbose_proxy_logger.warning( - "prisma_client not initialized; skipping old public name cleanup to preserve sibling deployments" + "prisma_client not initialized; skipping public name update entirely to avoid orphaned entries" ) - else: - response = await prisma_client.db.litellm_proxymodeltable.find_many( - where={ - "model_info": { - "path": ["team_id"], - "equals": team_id, - } + return + + response = await prisma_client.db.litellm_proxymodeltable.find_many( + where={ + "model_info": { + "path": ["team_id"], + "equals": team_id, } - ) - if not response: - other_deployments_with_old_name = [] - else: - other_deployments_with_old_name = [ - d - for d in response - if d.model_name != db_model.model_name - and _get_team_public_model_name(d.model_info) == old_public_name - ] + } + ) + if not response: + other_deployments_with_old_name = [] + else: + other_deployments_with_old_name = [ + d + for d in response + if d.model_name != db_model.model_name + and _get_team_public_model_name(d.model_info) == old_public_name + ] - if not other_deployments_with_old_name: - await team_model_delete( - data=TeamModelDeleteRequest( - team_id=team_id, - models=[old_public_name], - ), - http_request=Request(scope={"type": "http"}), - user_api_key_dict=user_api_key_dict, - ) + if not other_deployments_with_old_name: + await team_model_delete( + data=TeamModelDeleteRequest( + team_id=team_id, + models=[old_public_name], + ), + http_request=Request(scope={"type": "http"}), + user_api_key_dict=user_api_key_dict, + ) await team_model_add( data=TeamModelAddRequest( diff --git a/litellm/router.py b/litellm/router.py index 19a0f250dc1..76c13443e72 100644 --- a/litellm/router.py +++ b/litellm/router.py @@ -8227,12 +8227,16 @@ def should_include_deployment( """ if ( team_id is not None - 
and model["model_info"].get("team_id") == team_id - and model_name == model["model_info"].get("team_public_model_name") + and (model.get("model_info") or {}).get("team_id") == team_id + and model_name + == (model.get("model_info") or {}).get("team_public_model_name") ): return True elif model_name is not None and model["model_name"] == model_name: - if team_id is None or model["model_info"].get("team_id") == team_id: + if ( + team_id is None + or (model.get("model_info") or {}).get("team_id") == team_id + ): return True return False From 316a742945494ae9784591d6ecb6f2d314a0428c Mon Sep 17 00:00:00 2001 From: Sameer Kankute Date: Mon, 23 Mar 2026 20:12:06 +0530 Subject: [PATCH 17/33] Fix greptile comments --- .../model_management_endpoints.py | 19 +++++---- litellm/router.py | 4 +- .../test_model_management_endpoints.py | 41 +++++++++++++++++++ 3 files changed, 54 insertions(+), 10 deletions(-) diff --git a/litellm/proxy/management_endpoints/model_management_endpoints.py b/litellm/proxy/management_endpoints/model_management_endpoints.py index d8a1075a168..721a8c2a1c4 100644 --- a/litellm/proxy/management_endpoints/model_management_endpoints.py +++ b/litellm/proxy/management_endpoints/model_management_endpoints.py @@ -521,6 +521,16 @@ def _get_team_public_model_name( and _get_team_public_model_name(d.model_info) == old_public_name ] + # Add new name first, then delete old name to prevent access loss on partial failure + await team_model_add( + data=TeamModelAddRequest( + team_id=team_id, + models=[public_model_name], + ), + http_request=Request(scope={"type": "http"}), + user_api_key_dict=user_api_key_dict, + ) + if not other_deployments_with_old_name: await team_model_delete( data=TeamModelDeleteRequest( @@ -531,15 +541,6 @@ def _get_team_public_model_name( user_api_key_dict=user_api_key_dict, ) - await team_model_add( - data=TeamModelAddRequest( - team_id=team_id, - models=[public_model_name], - ), - http_request=Request(scope={"type": "http"}), - 
user_api_key_dict=user_api_key_dict, - ) - patch_data.model_name = None diff --git a/litellm/router.py b/litellm/router.py index 76c13443e72..d7f5d42eac7 100644 --- a/litellm/router.py +++ b/litellm/router.py @@ -8233,9 +8233,11 @@ def should_include_deployment( ): return True elif model_name is not None and model["model_name"] == model_name: + model_team_id = (model.get("model_info") or {}).get("team_id") if ( team_id is None - or (model.get("model_info") or {}).get("team_id") == team_id + or model_team_id is None # global deployment - accessible to all teams + or model_team_id == team_id ): return True return False diff --git a/tests/test_litellm/proxy/management_endpoints/test_model_management_endpoints.py b/tests/test_litellm/proxy/management_endpoints/test_model_management_endpoints.py index 83e6b0c93a5..2dd29fd5c9c 100644 --- a/tests/test_litellm/proxy/management_endpoints/test_model_management_endpoints.py +++ b/tests/test_litellm/proxy/management_endpoints/test_model_management_endpoints.py @@ -712,6 +712,15 @@ async def test_router_finds_all_sibling_team_deployments(self): "team_public_model_name": public_name, }, }, + { + "model_name": "global-gpt-4o", + "litellm_params": { + "model": "azure/gpt-4o", + "api_key": "global-key", + "api_base": "https://global.openai.azure.com", + }, + "model_info": {}, # No team_id - global deployment + }, ], ) @@ -732,6 +741,38 @@ async def test_router_finds_all_sibling_team_deployments(self): "https://westus.openai.azure.com", } + def test_global_deployments_accessible_to_teams(self): + """Test that global deployments (no team_id) are accessible to all teams""" + import litellm + + router = litellm.Router( + model_list=[ + { + "model_name": "global-gpt-4o", + "litellm_params": { + "model": "azure/gpt-4o", + "api_key": "global-key", + "api_base": "https://global.openai.azure.com", + }, + "model_info": {}, # No team_id - global deployment + }, + ], + ) + + # Global deployment should be accessible when team_id is provided + 
deployments = router._get_all_deployments( + model_name="global-gpt-4o", team_id="teamA" + ) + assert len(deployments) == 1 + assert deployments[0]["model_name"] == "global-gpt-4o" + + # should_include_deployment should return True for global deployments + assert router.should_include_deployment( + model_name="global-gpt-4o", + model={"model_name": "global-gpt-4o", "model_info": {}}, + team_id="teamA", + ) + class TestTeamModelUpdate: """Test team model update handles team_id consistently with model creation""" From 9a0a21619514ba11a22153ebd5f725cfd5d1fd30 Mon Sep 17 00:00:00 2001 From: Sameer Kankute Date: Mon, 23 Mar 2026 22:02:50 +0530 Subject: [PATCH 18/33] Fix code qa issues --- .../model_management_endpoints.py | 2 -- .../test_router_index_management.py | 22 +++++++++++++++++++ 2 files changed, 22 insertions(+), 2 deletions(-) diff --git a/litellm/proxy/management_endpoints/model_management_endpoints.py b/litellm/proxy/management_endpoints/model_management_endpoints.py index 721a8c2a1c4..8f6b8a626e4 100644 --- a/litellm/proxy/management_endpoints/model_management_endpoints.py +++ b/litellm/proxy/management_endpoints/model_management_endpoints.py @@ -33,7 +33,6 @@ ProxyException, TeamModelAddRequest, TeamModelDeleteRequest, - UpdateTeamRequest, UserAPIKeyAuth, ) from litellm.proxy.auth.user_api_key_auth import user_api_key_auth @@ -42,7 +41,6 @@ from litellm.proxy.management_endpoints.team_endpoints import ( team_model_add, team_model_delete, - update_team, ) from litellm.proxy.management_helpers.audit_logs import create_object_audit_log from litellm.proxy.utils import PrismaClient diff --git a/tests/router_unit_tests/test_router_index_management.py b/tests/router_unit_tests/test_router_index_management.py index 90d98b8ab0a..2694c62827c 100644 --- a/tests/router_unit_tests/test_router_index_management.py +++ b/tests/router_unit_tests/test_router_index_management.py @@ -118,6 +118,28 @@ def test_add_model_to_list_and_index_map_multiple_models(self, router): 
assert router.model_id_to_deployment_index_map["id-2"] == 1 assert router.model_id_to_deployment_index_map["id-3"] == 2 + def test_update_team_model_index(self, router): + """Test _update_team_model_index updates team_model_to_deployment_indices.""" + model = { + "model_name": "team-alias", + "model_info": { + "id": "dep-1", + "team_id": "team-abc", + "team_public_model_name": "gpt-4o", + }, + } + router._update_team_model_index(model, 0) + assert router.team_model_to_deployment_indices[("team-abc", "gpt-4o")] == [0] + router._update_team_model_index(model, 2) + assert router.team_model_to_deployment_indices[("team-abc", "gpt-4o")] == [0, 2] + + router._update_team_model_index( + {"model_name": "x", "model_info": {"id": "dep-2"}}, 5 + ) + assert router.team_model_to_deployment_indices == { + ("team-abc", "gpt-4o"): [0, 2], + } + def test_has_model_id(self, router): """Test has_model_id function for O(1) membership check""" # Setup: Add models to router From c6cc0341f61836e74ee41a6afba197268e91fe99 Mon Sep 17 00:00:00 2001 From: Sameer Kankute Date: Mon, 23 Mar 2026 22:26:08 +0530 Subject: [PATCH 19/33] Fix greptile reviews and mock test --- .../model_management_endpoints.py | 11 ++++ .../test_model_management_endpoints.py | 51 +++++++++++++++---- 2 files changed, 52 insertions(+), 10 deletions(-) diff --git a/litellm/proxy/management_endpoints/model_management_endpoints.py b/litellm/proxy/management_endpoints/model_management_endpoints.py index 8f6b8a626e4..5952aede853 100644 --- a/litellm/proxy/management_endpoints/model_management_endpoints.py +++ b/litellm/proxy/management_endpoints/model_management_endpoints.py @@ -538,6 +538,17 @@ def _get_team_public_model_name( http_request=Request(scope={"type": "http"}), user_api_key_dict=user_api_key_dict, ) + elif not old_public_name and public_model_name: + # First-time assignment of public name on an existing team deployment: + # ensure the team's models list is updated so team routing can resolve it. 
+ await team_model_add( + data=TeamModelAddRequest( + team_id=team_id, + models=[public_model_name], + ), + http_request=Request(scope={"type": "http"}), + user_api_key_dict=user_api_key_dict, + ) patch_data.model_name = None diff --git a/tests/test_litellm/proxy/management_endpoints/test_model_management_endpoints.py b/tests/test_litellm/proxy/management_endpoints/test_model_management_endpoints.py index 2dd29fd5c9c..751d0a02ff0 100644 --- a/tests/test_litellm/proxy/management_endpoints/test_model_management_endpoints.py +++ b/tests/test_litellm/proxy/management_endpoints/test_model_management_endpoints.py @@ -635,8 +635,6 @@ async def test_no_model_aliases_written_for_team_models(self): team_id = "team_no_alias" public_name = "gpt-4.1-mini" - mock_update_team = AsyncMock() - async def mock_add_model_to_db(model_params, user_api_key_dict, prisma_client): return MagicMock(model_id=str(uuid.uuid4())) @@ -656,9 +654,6 @@ async def mock_add_model_to_db(model_params, user_api_key_dict, prisma_client): model_info=ModelInfo(team_id=team_id), ) with patch( - "litellm.proxy.management_endpoints.model_management_endpoints.update_team", - mock_update_team, - ), patch( "litellm.proxy.management_endpoints.model_management_endpoints._add_model_to_db", side_effect=mock_add_model_to_db, ), patch( @@ -671,7 +666,6 @@ async def mock_add_model_to_db(model_params, user_api_key_dict, prisma_client): prisma_client=prisma_client, ) - mock_update_team.assert_not_called() assert mock_team_model_add.call_count == 2 @pytest.mark.asyncio @@ -807,8 +801,6 @@ async def test_patch_model_with_team_id_creates_proper_setup(self): "litellm.proxy.proxy_server.premium_user", True, ), patch( - "litellm.proxy.management_endpoints.model_management_endpoints.update_team" - ) as mock_update_team, patch( "litellm.proxy.management_endpoints.model_management_endpoints.team_model_add" ) as mock_team_model_add: result = await _update_team_model_in_db( @@ -820,8 +812,6 @@ async def 
test_patch_model_with_team_id_creates_proper_setup(self): assert result.get("model_name", "").startswith("model_name_test_team_123_") assert "team_public_model_name" in str(result.get("model_info", "")) - # update_team must not be called (no model_aliases writes for team models) - mock_update_team.assert_not_called() # team_model_add must be called to add public name to team's models list mock_team_model_add.assert_called_once() @@ -885,6 +875,47 @@ async def test_rename_preserves_old_name_when_siblings_exist(self): # team_model_add should be called to add new public name mock_add.assert_called_once() + @pytest.mark.asyncio + async def test_first_time_public_name_assignment_adds_team_model(self): + """If existing team deployment had no public name, first assignment must call team_model_add.""" + from litellm.proxy.management_endpoints.model_management_endpoints import ( + _update_existing_team_model_assignment, + ) + from litellm.types.router import ModelInfo + + db_model = Deployment( + model_name="model_name_team_123_uuid1", + litellm_params=LiteLLM_Params(model="azure/gpt-4o-mini"), + model_info=ModelInfo(team_id="team_123"), + ) + + patch_data = updateDeployment( + model_name="new-public-name", + model_info=ModelInfo(team_id="team_123"), + ) + + user_api_key_dict = UserAPIKeyAuth( + user_id="test_user", + user_role=LitellmUserRoles.PROXY_ADMIN, + ) + + with patch( + "litellm.proxy.management_endpoints.model_management_endpoints.team_model_delete" + ) as mock_delete, patch( + "litellm.proxy.management_endpoints.model_management_endpoints.team_model_add" + ) as mock_add: + await _update_existing_team_model_assignment( + team_id="team_123", + public_model_name="new-public-name", + db_model=db_model, + patch_data=patch_data, + user_api_key_dict=user_api_key_dict, + prisma_client=None, + ) + + mock_add.assert_called_once() + mock_delete.assert_not_called() + @pytest.mark.asyncio async def test_rename_handles_legacy_string_model_info(self): """Test rename path 
handles legacy string-encoded model_info rows without crashing.""" From fb8d9c2e9a33a2a653335ed095526c29ef7f1417 Mon Sep 17 00:00:00 2001 From: Sameer Kankute Date: Mon, 23 Mar 2026 22:38:54 +0530 Subject: [PATCH 20/33] Fix greptile reviews and mock test --- docs/my-website/docs/proxy/config_settings.md | 1 + docs/my-website/docs/proxy/load_balancing.md | 15 ++++++++++++ litellm/proxy/litellm_pre_call_utils.py | 23 +++++++++++++++---- litellm/router.py | 6 +++++ 4 files changed, 40 insertions(+), 5 deletions(-) diff --git a/docs/my-website/docs/proxy/config_settings.md b/docs/my-website/docs/proxy/config_settings.md index 02f5c2be9c7..dce979ab89f 100644 --- a/docs/my-website/docs/proxy/config_settings.md +++ b/docs/my-website/docs/proxy/config_settings.md @@ -804,6 +804,7 @@ router_settings: | LITELLM_OTEL_INTEGRATION_ENABLE_EVENTS | Optionally enable semantic logs for OTEL | LITELLM_OTEL_INTEGRATION_ENABLE_METRICS | Optionally enable emantic metrics for OTEL | LITELLM_ENABLE_PYROSCOPE | If true, enables Pyroscope CPU profiling. Profiles are sent to PYROSCOPE_SERVER_ADDRESS. Off by default. See [Pyroscope profiling](/proxy/pyroscope_profiling). +| LITELLM_ENABLE_TEAM_STALE_ALIAS_BYPASS | When `true`, if a team's legacy `model_aliases` entry maps a public model name to an internal `model_name_{team_id}_{uuid}` deployment, pre-call handling can skip that rewrite when team-scoped sibling deployments exist for the public name—so load balancing / `order` apply across siblings. Default is `false` for backwards compatibility. See [Team-scoped models and legacy aliases](./load_balancing#team-scoped-models-and-legacy-model_aliases). When stale aliases are detected and this flag is off, the proxy may log a one-time warning. | PYROSCOPE_APP_NAME | Application name reported to Pyroscope. Required when LITELLM_ENABLE_PYROSCOPE is true. No default. | PYROSCOPE_SERVER_ADDRESS | Pyroscope server URL to send profiles to. Required when LITELLM_ENABLE_PYROSCOPE is true. No default.
| PYROSCOPE_SAMPLE_RATE | Optional. Sample rate for Pyroscope profiling (integer). No default; when unset, the pyroscope-io library default is used. diff --git a/docs/my-website/docs/proxy/load_balancing.md b/docs/my-website/docs/proxy/load_balancing.md index 74b3e8a5117..897c04b2b00 100644 --- a/docs/my-website/docs/proxy/load_balancing.md +++ b/docs/my-website/docs/proxy/load_balancing.md @@ -336,6 +336,21 @@ The `order` parameter requires `enable_pre_call_checks: true` in `router_setting If `order=1` deployment is unavailable (e.g., rate-limited), the router falls back to `order=2` deployments. +### Team-scoped models and legacy `model_aliases` {#team-scoped-models-and-legacy-model_aliases} + +Team-scoped deployments are identified by `model_info.team_id` and `model_info.team_public_model_name`. Requests should use the **public** model name; the router resolves all sibling deployments (same public name, different `api_base` / `order`, etc.) for routing, failover, and deployment `order`. + +For router internals: when a `team_id` is in scope, optimized lookups key off `(team_id, team_public_model_name)`. If code passes an internal deployment id (e.g. `model_name_{team_id}_{uuid}`) instead of the public name, routing still works via the usual deployment-name paths, but the team-specific fast path applies only to the public name. + +**Legacy teams:** Older proxy versions could persist `model_aliases` on the team row mapping a public name to a single internal deployment id (`model_name_{team_id}_{uuid}`). On each request, pre-call logic may still rewrite `model` to that internal name **before** routing, which collapses to one deployment and can make newer sibling deployments unreachable. + +**Migration options:** + +1. **Recommended for upgrades:** Set environment variable `LITELLM_ENABLE_TEAM_STALE_ALIAS_BYPASS=true` so that when sibling team deployments exist for the public name, the stale alias rewrite is skipped and team-scoped routing (including `order` and failover) applies.
See the [Environment variables](./config_settings) table in the proxy settings doc. +2. **Data cleanup:** Remove obsolete `model_aliases` entries for team public names from the team record in the database so only `team_public_model_name` + team model list drive access. + +If a stale alias is detected and the bypass is **not** enabled, the proxy may emit a **one-time** warning in logs explaining that sibling deployments may be unreachable until the flag is set or aliases are cleaned up. + ### When You'll See Load Balancing in Action **Immediate Effects:** diff --git a/litellm/proxy/litellm_pre_call_utils.py b/litellm/proxy/litellm_pre_call_utils.py index 4a12a0a5774..64ef2405002 100644 --- a/litellm/proxy/litellm_pre_call_utils.py +++ b/litellm/proxy/litellm_pre_call_utils.py @@ -37,6 +37,7 @@ ) service_logger_obj = ServiceLogging() # used for tracking latency on OTEL +_STALE_TEAM_ALIAS_WARNING_KEYS: set[str] = set() if TYPE_CHECKING: @@ -1321,16 +1322,28 @@ def _update_model_if_team_alias_exists( ) # Check if the alias points to a team-scoped UUID name # (format: "model_name_{team_id}_{uuid}") - if enable_stale_alias_bypass and aliased_target.startswith( + is_stale_team_alias = aliased_target.startswith( f"model_name_{user_api_key_dict.team_id}_" - ): + ) + if is_stale_team_alias and llm_router: # This is a stale alias from pre-PR deployments. # Check if current team deployments exist for the public name. 
- if llm_router: - key = (user_api_key_dict.team_id, _model) - if key in llm_router.team_model_to_deployment_indices: + key = (user_api_key_dict.team_id, _model) + if key in llm_router.team_model_to_deployment_indices: + if enable_stale_alias_bypass: # Team deployments exist; skip stale alias return + warning_key = f"{user_api_key_dict.team_id}:{_model}:{aliased_target}" + if warning_key not in _STALE_TEAM_ALIAS_WARNING_KEYS: + _STALE_TEAM_ALIAS_WARNING_KEYS.add(warning_key) + verbose_proxy_logger.warning( + "Stale team model alias detected for model='%s', team_id='%s'. " + "New sibling deployments may be unreachable. " + "Set LITELLM_ENABLE_TEAM_STALE_ALIAS_BYPASS=true to enable " + "team-scoped sibling routing.", + _model, + user_api_key_dict.team_id, + ) data["model"] = aliased_target return diff --git a/litellm/router.py b/litellm/router.py index d7f5d42eac7..64d29d8bceb 100644 --- a/litellm/router.py +++ b/litellm/router.py @@ -8256,6 +8256,12 @@ def _get_all_deployments( if team_id specified, only return team-specific models Optimized with O(1) index lookup instead of O(n) linear scan. + + Note: when team_id is provided, O(1) lookup in + `team_model_to_deployment_indices` only applies when `model_name` is the + team public model name. If a caller passes an internal deployment model + name (for example, `model_name_{team_id}_{uuid}`), this method falls back + to the standard model-name index / scan path.
""" returned_models: List[DeploymentTypedDict] = [] From 1a0b30aaac029808cd25029837306d7e10f228eb Mon Sep 17 00:00:00 2001 From: Sameer Kankute Date: Mon, 23 Mar 2026 22:49:57 +0530 Subject: [PATCH 21/33] Fix greptile reviews and mock test --- litellm/proxy/litellm_pre_call_utils.py | 12 +++++-- .../model_management_endpoints.py | 3 ++ .../test_model_management_endpoints.py | 35 +++++++++++++++++++ 3 files changed, 48 insertions(+), 2 deletions(-) diff --git a/litellm/proxy/litellm_pre_call_utils.py b/litellm/proxy/litellm_pre_call_utils.py index 64ef2405002..f8f299b1481 100644 --- a/litellm/proxy/litellm_pre_call_utils.py +++ b/litellm/proxy/litellm_pre_call_utils.py @@ -1,6 +1,7 @@ import asyncio import copy import time +from collections import OrderedDict from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union from fastapi import Request @@ -37,7 +38,9 @@ ) service_logger_obj = ServiceLogging() # used for tracking latency on OTEL -_STALE_TEAM_ALIAS_WARNING_KEYS: set[str] = set() +# Bounded dedup for stale-alias warnings (FIFO eviction when over cap). +_MAX_STALE_ALIAS_WARNING_KEYS = 10_000 +_STALE_TEAM_ALIAS_WARNING_KEYS: OrderedDict[str, None] = OrderedDict() if TYPE_CHECKING: @@ -1335,7 +1338,12 @@ def _update_model_if_team_alias_exists( return warning_key = f"{user_api_key_dict.team_id}:{_model}:{aliased_target}" if warning_key not in _STALE_TEAM_ALIAS_WARNING_KEYS: - _STALE_TEAM_ALIAS_WARNING_KEYS.add(warning_key) + _STALE_TEAM_ALIAS_WARNING_KEYS[warning_key] = None + while ( + len(_STALE_TEAM_ALIAS_WARNING_KEYS) + > _MAX_STALE_ALIAS_WARNING_KEYS + ): + _STALE_TEAM_ALIAS_WARNING_KEYS.popitem(last=False) verbose_proxy_logger.warning( "Stale team model alias detected for model='%s', team_id='%s'. " "New sibling deployments may be unreachable. 
" diff --git a/litellm/proxy/management_endpoints/model_management_endpoints.py b/litellm/proxy/management_endpoints/model_management_endpoints.py index 5952aede853..95c44a431b5 100644 --- a/litellm/proxy/management_endpoints/model_management_endpoints.py +++ b/litellm/proxy/management_endpoints/model_management_endpoints.py @@ -495,6 +495,9 @@ def _get_team_public_model_name( ) if old_public_name and public_model_name != old_public_name: + # Clear user-supplied public name from patch before any early return so the + # caller does not overwrite the internal UUID-based model_name in the DB. + patch_data.model_name = None if prisma_client is None: verbose_proxy_logger.warning( "prisma_client not initialized; skipping public name update entirely to avoid orphaned entries" diff --git a/tests/test_litellm/proxy/management_endpoints/test_model_management_endpoints.py b/tests/test_litellm/proxy/management_endpoints/test_model_management_endpoints.py index 751d0a02ff0..2a85e24d780 100644 --- a/tests/test_litellm/proxy/management_endpoints/test_model_management_endpoints.py +++ b/tests/test_litellm/proxy/management_endpoints/test_model_management_endpoints.py @@ -916,6 +916,41 @@ async def test_first_time_public_name_assignment_adds_team_model(self): mock_add.assert_called_once() mock_delete.assert_not_called() + @pytest.mark.asyncio + async def test_rename_with_prisma_none_clears_patch_model_name(self): + """Rename path must clear patch_data.model_name even when prisma is unavailable (P1).""" + from litellm.proxy.management_endpoints.model_management_endpoints import ( + _update_existing_team_model_assignment, + ) + from litellm.types.router import ModelInfo + + db_model = Deployment( + model_name="model_name_team_123_uuid1", + litellm_params=LiteLLM_Params(model="azure/gpt-4o-mini"), + model_info=ModelInfo( + team_id="team_123", team_public_model_name="old-public-name" + ), + ) + patch_data = updateDeployment( + model_name="new-public-name", + 
model_info=ModelInfo(team_id="team_123"), + ) + user_api_key_dict = UserAPIKeyAuth( + user_id="test_user", + user_role=LitellmUserRoles.PROXY_ADMIN, + ) + + await _update_existing_team_model_assignment( + team_id="team_123", + public_model_name="new-public-name", + db_model=db_model, + patch_data=patch_data, + user_api_key_dict=user_api_key_dict, + prisma_client=None, + ) + + assert patch_data.model_name is None + @pytest.mark.asyncio async def test_rename_handles_legacy_string_model_info(self): """Test rename path handles legacy string-encoded model_info rows without crashing.""" From 592ac98ddc1154a00115528cada41654a0091cdc Mon Sep 17 00:00:00 2001 From: Sameer Kankute Date: Tue, 24 Mar 2026 20:10:57 +0530 Subject: [PATCH 22/33] fix(router): address Greptile P1/P2 review comments - Add deduplication guard in _update_team_model_index to prevent duplicate indices - Add wildcard comment in map_team_model for clarity - Add monkeypatch to test_team_alias_stale_bypass_disabled_by_default for determinism - Extract _get_team_deployments helper to centralize DB access pattern - Add clarifying comments for team_public_model_name assignment ordering Made-with: Cursor --- .../model_management_endpoints.py | 53 ++++++++++++------- litellm/router.py | 5 +- tests/proxy_unit_tests/test_proxy_utils.py | 3 +- 3 files changed, 40 insertions(+), 21 deletions(-) diff --git a/litellm/proxy/management_endpoints/model_management_endpoints.py b/litellm/proxy/management_endpoints/model_management_endpoints.py index 95c44a431b5..4ab52ac5c0b 100644 --- a/litellm/proxy/management_endpoints/model_management_endpoints.py +++ b/litellm/proxy/management_endpoints/model_management_endpoints.py @@ -329,12 +329,17 @@ async def _add_team_model_to_db( _team_id = model_params.model_info.team_id if _team_id is None: return None + # Capture the original public name before mutating model_params.model_name original_model_name = model_params.model_name - if original_model_name: - 
model_params.model_info.team_public_model_name = original_model_name + # Generate unique internal model_name for team-scoped deployment unique_model_name = f"model_name_{_team_id}_{uuid.uuid4()}" + # Store public name in model_info BEFORE overwriting model_name + # so _add_model_to_db serializes the correct team_public_model_name + if original_model_name: + model_params.model_info.team_public_model_name = original_model_name + model_params.model_name = unique_model_name ## CREATE MODEL IN DB ## @@ -458,6 +463,25 @@ async def _setup_new_team_model_assignment( ) +async def _get_team_deployments( + team_id: str, prisma_client: PrismaClient +) -> List[LiteLLM_ProxyModelTable]: + """ + Fetch all deployments for a given team_id from the database. + + Centralizes team deployment queries to ensure consistent filtering and error handling. + """ + response = await prisma_client.db.litellm_proxymodeltable.find_many( + where={ + "model_info": { + "path": ["team_id"], + "equals": team_id, + } + } + ) + return response if response else [] + + async def _update_existing_team_model_assignment( team_id: str, public_model_name: str, @@ -504,23 +528,14 @@ def _get_team_public_model_name( ) return - response = await prisma_client.db.litellm_proxymodeltable.find_many( - where={ - "model_info": { - "path": ["team_id"], - "equals": team_id, - } - } - ) - if not response: - other_deployments_with_old_name = [] - else: - other_deployments_with_old_name = [ - d - for d in response - if d.model_name != db_model.model_name - and _get_team_public_model_name(d.model_info) == old_public_name - ] + # Query DB for all team deployments to check for sibling deployments + team_deployments = await _get_team_deployments(team_id, prisma_client) + other_deployments_with_old_name = [ + d + for d in team_deployments + if d.model_name != db_model.model_name + and _get_team_public_model_name(d.model_info) == old_public_name + ] # Add new name first, then delete old name to prevent access loss on partial 
failure await team_model_add( diff --git a/litellm/router.py b/litellm/router.py index 64d29d8bceb..8f2785a3838 100644 --- a/litellm/router.py +++ b/litellm/router.py @@ -7191,7 +7191,8 @@ def _update_team_model_index(self, model: dict, idx: int) -> None: key = (team_id, team_public_model_name) if key not in self.team_model_to_deployment_indices: self.team_model_to_deployment_indices[key] = [] - self.team_model_to_deployment_indices[key].append(idx) + if idx not in self.team_model_to_deployment_indices[key]: + self.team_model_to_deployment_indices[key].append(idx) def _add_model_to_list_and_index_map( self, model: dict, model_id: Optional[str] = None @@ -8217,6 +8218,8 @@ def map_team_model(self, team_model_name: str, team_id: str) -> Optional[str]: if model.get("model_info", {}).get("team_id") == team_id: return team_model_name + # No team-scoped deployment found; wildcard/pattern routes are + # handled downstream by the pattern_router in _common_checks_available_deployment. return None def should_include_deployment( diff --git a/tests/proxy_unit_tests/test_proxy_utils.py b/tests/proxy_unit_tests/test_proxy_utils.py index 5e75890388c..9bfb466c0eb 100644 --- a/tests/proxy_unit_tests/test_proxy_utils.py +++ b/tests/proxy_unit_tests/test_proxy_utils.py @@ -2044,7 +2044,8 @@ def test_update_model_if_team_alias_exists(data, user_api_key_dict, expected_mod assert test_data.get("model") == expected_model -def test_team_alias_stale_bypass_disabled_by_default(): +def test_team_alias_stale_bypass_disabled_by_default(monkeypatch): + monkeypatch.delenv("LITELLM_ENABLE_TEAM_STALE_ALIAS_BYPASS", raising=False) from litellm.proxy.litellm_pre_call_utils import _update_model_if_team_alias_exists class _MockRouter: From 2321d7759916c321da686ad9e267bd4f645c8d49 Mon Sep 17 00:00:00 2001 From: Sameer Kankute Date: Tue, 24 Mar 2026 20:32:01 +0530 Subject: [PATCH 23/33] fix(router): address remaining Greptile review comments - Cache LITELLM_ENABLE_TEAM_STALE_ALIAS_BYPASS at module level 
to avoid hot-path secret lookups - Add clarifying comments for should_include_deployment team isolation logic - Add negative assertion for update_team.assert_not_called() in test - Add docstring clarification for _get_team_deployments helper pattern - Add explicit assertion message in test_get_model_list_alias_optimization Made-with: Cursor --- litellm/proxy/litellm_pre_call_utils.py | 12 +++++++++--- .../model_management_endpoints.py | 4 ++++ litellm/router.py | 7 +++++-- .../test_get_model_list_alias_optimization.py | 8 ++++---- .../test_model_management_endpoints.py | 6 +++++- 5 files changed, 27 insertions(+), 10 deletions(-) diff --git a/litellm/proxy/litellm_pre_call_utils.py b/litellm/proxy/litellm_pre_call_utils.py index f8f299b1481..a605f3ee23b 100644 --- a/litellm/proxy/litellm_pre_call_utils.py +++ b/litellm/proxy/litellm_pre_call_utils.py @@ -41,6 +41,8 @@ # Bounded dedup for stale-alias warnings (FIFO eviction when over cap). _MAX_STALE_ALIAS_WARNING_KEYS = 10_000 _STALE_TEAM_ALIAS_WARNING_KEYS: OrderedDict[str, None] = OrderedDict() +# Cache the stale alias bypass flag at module load to avoid hot-path secret lookups +_ENABLE_TEAM_STALE_ALIAS_BYPASS: Optional[bool] = None if TYPE_CHECKING: @@ -1320,9 +1322,13 @@ def _update_model_if_team_alias_exists( # Optional bypass for stale aliases from pre-PR deployments: # only enabled via feature flag to preserve backwards compatibility. - enable_stale_alias_bypass = get_secret_bool( - "LITELLM_ENABLE_TEAM_STALE_ALIAS_BYPASS", False - ) + # Cached at module level to avoid hot-path secret lookups on every request. 
+ global _ENABLE_TEAM_STALE_ALIAS_BYPASS + if _ENABLE_TEAM_STALE_ALIAS_BYPASS is None: + _ENABLE_TEAM_STALE_ALIAS_BYPASS = get_secret_bool( + "LITELLM_ENABLE_TEAM_STALE_ALIAS_BYPASS", False + ) + enable_stale_alias_bypass = _ENABLE_TEAM_STALE_ALIAS_BYPASS # Check if the alias points to a team-scoped UUID name # (format: "model_name_{team_id}_{uuid}") is_stale_team_alias = aliased_target.startswith( diff --git a/litellm/proxy/management_endpoints/model_management_endpoints.py b/litellm/proxy/management_endpoints/model_management_endpoints.py index 4ab52ac5c0b..ab29dde6e38 100644 --- a/litellm/proxy/management_endpoints/model_management_endpoints.py +++ b/litellm/proxy/management_endpoints/model_management_endpoints.py @@ -470,6 +470,10 @@ async def _get_team_deployments( Fetch all deployments for a given team_id from the database. Centralizes team deployment queries to ensure consistent filtering and error handling. + This is the established helper pattern for team deployment DB access in this module. + + Note: Direct Prisma call is intentional here as this IS the helper function that + encapsulates the DB access pattern for team deployments. 
""" response = await prisma_client.db.litellm_proxymodeltable.find_many( where={ diff --git a/litellm/router.py b/litellm/router.py index 8f2785a3838..64ad6fc2215 100644 --- a/litellm/router.py +++ b/litellm/router.py @@ -8236,13 +8236,16 @@ def should_include_deployment( ): return True elif model_name is not None and model["model_name"] == model_name: + # Fallback: check by internal model_name for non-team deployments + # or deployments that haven't been migrated to team_public_model_name yet model_team_id = (model.get("model_info") or {}).get("team_id") if ( - team_id is None + team_id is None # requester has no team constraint or model_team_id is None # global deployment - accessible to all teams - or model_team_id == team_id + or model_team_id == team_id # deployment belongs to requester's team ): return True + # No match: deployment is for a different team or doesn't match the requested model return False def _get_all_deployments( diff --git a/tests/router_unit_tests/test_get_model_list_alias_optimization.py b/tests/router_unit_tests/test_get_model_list_alias_optimization.py index 62baf0a3d22..2c2df3be945 100644 --- a/tests/router_unit_tests/test_get_model_list_alias_optimization.py +++ b/tests/router_unit_tests/test_get_model_list_alias_optimization.py @@ -44,7 +44,7 @@ def test_map_team_model_should_not_iterate_aliases_for_non_alias_team_model_name {f"alias-{idx}": "gpt-4" for idx in range(200)} ) - assert ( - router.map_team_model(team_model_name="team-model", team_id="team-1") - == "team-model" - ) + # map_team_model should return the public name unchanged (not the internal UUID name) + # so the router can find all sibling deployments via team_id filtering + result = router.map_team_model(team_model_name="team-model", team_id="team-1") + assert result == "team-model", f"Expected public name 'team-model', got {result}" diff --git a/tests/test_litellm/proxy/management_endpoints/test_model_management_endpoints.py 
b/tests/test_litellm/proxy/management_endpoints/test_model_management_endpoints.py index 2a85e24d780..2e566ab6222 100644 --- a/tests/test_litellm/proxy/management_endpoints/test_model_management_endpoints.py +++ b/tests/test_litellm/proxy/management_endpoints/test_model_management_endpoints.py @@ -802,7 +802,9 @@ async def test_patch_model_with_team_id_creates_proper_setup(self): True, ), patch( "litellm.proxy.management_endpoints.model_management_endpoints.team_model_add" - ) as mock_team_model_add: + ) as mock_team_model_add, patch( + "litellm.proxy.management_endpoints.model_management_endpoints.update_team" + ) as mock_update_team: result = await _update_team_model_in_db( db_model=db_model, patch_data=patch_data, @@ -814,6 +816,8 @@ async def test_patch_model_with_team_id_creates_proper_setup(self): assert "team_public_model_name" in str(result.get("model_info", "")) # team_model_add must be called to add public name to team's models list mock_team_model_add.assert_called_once() + # update_team (model_aliases write) must NOT be called in the new implementation + mock_update_team.assert_not_called() @pytest.mark.asyncio async def test_rename_preserves_old_name_when_siblings_exist(self): From 7436f889caff61876880217362ddcac5b3773b23 Mon Sep 17 00:00:00 2001 From: Sameer Kankute Date: Tue, 24 Mar 2026 20:34:12 +0530 Subject: [PATCH 24/33] fix(router): address final Greptile P1/P2 comments - Reorder team_public_model_name assignment to happen before model_name mutation for clarity - Add comment explaining no-rename fast-exit case in _update_existing_team_model_assignment - Add comment explaining final patch_data.model_name = None applies to all code paths Made-with: Cursor --- .../model_management_endpoints.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/litellm/proxy/management_endpoints/model_management_endpoints.py b/litellm/proxy/management_endpoints/model_management_endpoints.py index ab29dde6e38..355e2011237 100644 --- 
a/litellm/proxy/management_endpoints/model_management_endpoints.py +++ b/litellm/proxy/management_endpoints/model_management_endpoints.py @@ -329,17 +329,19 @@ async def _add_team_model_to_db( _team_id = model_params.model_info.team_id if _team_id is None: return None - # Capture the original public name before mutating model_params.model_name - original_model_name = model_params.model_name - # Generate unique internal model_name for team-scoped deployment - unique_model_name = f"model_name_{_team_id}_{uuid.uuid4()}" + # Capture the original public name FIRST, before any mutations + original_model_name = model_params.model_name - # Store public name in model_info BEFORE overwriting model_name - # so _add_model_to_db serializes the correct team_public_model_name + # Set team_public_model_name in model_info using the captured original_model_name + # This must happen BEFORE mutating model_params.model_name so _add_model_to_db + # serializes the correct team_public_model_name (not the internal UUID name) if original_model_name: model_params.model_info.team_public_model_name = original_model_name + # Generate and assign unique internal model_name LAST + # (after team_public_model_name is safely stored) + unique_model_name = f"model_name_{_team_id}_{uuid.uuid4()}" model_params.model_name = unique_model_name ## CREATE MODEL IN DB ## @@ -571,7 +573,11 @@ def _get_team_public_model_name( http_request=Request(scope={"type": "http"}), user_api_key_dict=user_api_key_dict, ) + # else: old_public_name == public_model_name (no rename needed) + # No team_model_add/delete calls required; public name is already registered + # Always clear patch_data.model_name to prevent caller from overwriting + # the internal UUID-based model_name in the DB with the user-supplied public name patch_data.model_name = None From 1fac58abb370596e88d2048e24116879979450b2 Mon Sep 17 00:00:00 2001 From: Sameer Kankute Date: Tue, 24 Mar 2026 20:41:46 +0530 Subject: [PATCH 25/33] fix(tests): reset 
module-level cache in stale alias bypass tests Reset _ENABLE_TEAM_STALE_ALIAS_BYPASS to None in both test functions to ensure test isolation and prevent ordering-dependent failures Made-with: Cursor --- tests/proxy_unit_tests/test_proxy_utils.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/tests/proxy_unit_tests/test_proxy_utils.py b/tests/proxy_unit_tests/test_proxy_utils.py index 9bfb466c0eb..6c6ec7bcd60 100644 --- a/tests/proxy_unit_tests/test_proxy_utils.py +++ b/tests/proxy_unit_tests/test_proxy_utils.py @@ -2046,7 +2046,11 @@ def test_update_model_if_team_alias_exists(data, user_api_key_dict, expected_mod def test_team_alias_stale_bypass_disabled_by_default(monkeypatch): monkeypatch.delenv("LITELLM_ENABLE_TEAM_STALE_ALIAS_BYPASS", raising=False) + import litellm.proxy.litellm_pre_call_utils as pre_call_utils from litellm.proxy.litellm_pre_call_utils import _update_model_if_team_alias_exists + + # Reset module-level cache to ensure test isolation + pre_call_utils._ENABLE_TEAM_STALE_ALIAS_BYPASS = None class _MockRouter: team_model_to_deployment_indices = {("team-1", "gpt-4o"): [0]} @@ -2067,7 +2071,11 @@ class _MockRouter: def test_team_alias_stale_bypass_enabled_by_flag(monkeypatch): + import litellm.proxy.litellm_pre_call_utils as pre_call_utils from litellm.proxy.litellm_pre_call_utils import _update_model_if_team_alias_exists + + # Reset module-level cache to ensure test isolation + pre_call_utils._ENABLE_TEAM_STALE_ALIAS_BYPASS = None class _MockRouter: team_model_to_deployment_indices = {("team-1", "gpt-4o"): [0]} From 15f5dc38c47aaf8c5dc61d3bfeb665495df6769c Mon Sep 17 00:00:00 2001 From: Sameer Kankute Date: Thu, 26 Mar 2026 19:41:43 +0530 Subject: [PATCH 26/33] Fix tests --- .../model_management_endpoints.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/litellm/proxy/management_endpoints/model_management_endpoints.py b/litellm/proxy/management_endpoints/model_management_endpoints.py index 355e2011237..2ca8e3daba3 100644 
--- a/litellm/proxy/management_endpoints/model_management_endpoints.py +++ b/litellm/proxy/management_endpoints/model_management_endpoints.py @@ -42,6 +42,9 @@ team_model_add, team_model_delete, ) +from litellm.proxy.management_endpoints.team_endpoints import ( + update_team as _legacy_update_team, +) from litellm.proxy.management_helpers.audit_logs import create_object_audit_log from litellm.proxy.utils import PrismaClient from litellm.types.proxy.management_endpoints.model_management_endpoints import ( @@ -58,6 +61,14 @@ router = APIRouter() +async def update_team(*args, **kwargs): + """ + Backward-compatible shim for tests/legacy call sites that patch this symbol. + Team model management now uses team_model_add/team_model_delete directly. + """ + return await _legacy_update_team(*args, **kwargs) + + class UpdatePublicModelGroupsRequest(BaseModel): """Request model for updating public model groups""" From d3568efad07a65da97f319a572e576c4d9d4e58e Mon Sep 17 00:00:00 2001 From: yuneng-jiang Date: Thu, 26 Mar 2026 22:06:20 -0700 Subject: [PATCH 27/33] Merge pull request #24611 from Sameerlite/Sameerlite/order-fallback2 feat(router): order-based fallback across deployment priority levels --- docs/my-website/docs/proxy/load_balancing.md | 38 +- docs/my-website/docs/routing.md | 15 +- litellm/router.py | 76 +++- litellm/utils.py | 16 +- .../test_router_order_fallback.py | 331 ++++++++++++++++++ 5 files changed, 453 insertions(+), 23 deletions(-) create mode 100644 tests/test_litellm/test_router_order_fallback.py diff --git a/docs/my-website/docs/proxy/load_balancing.md b/docs/my-website/docs/proxy/load_balancing.md index 897c04b2b00..93f3d944340 100644 --- a/docs/my-website/docs/proxy/load_balancing.md +++ b/docs/my-website/docs/proxy/load_balancing.md @@ -324,17 +324,43 @@ model_list: litellm_params: model: azure/gpt-4-fallback api_key: os.environ/AZURE_API_KEY_2 - order: 2 # 👈 Used when order=1 is unavailable + order: 2 # 👈 Used when order=1 fails +``` + +### How 
order-based fallback works + +When a request to an `order=1` deployment fails (connection error, 404, 429, etc.), the router automatically tries `order=2` deployments, then `order=3`, and so on. Each order level gets its own set of retries before escalating to the next. + +If all order levels are exhausted, the router falls through to any configured [model-level fallbacks](#fallbacks). + +```yaml +model_list: + - model_name: gpt-4 + litellm_params: + model: azure/gpt-4-primary + api_key: os.environ/AZURE_API_KEY + order: 1 + + - model_name: gpt-4 + litellm_params: + model: azure/gpt-4-secondary + api_key: os.environ/AZURE_API_KEY_2 + order: 2 + + - model_name: gpt-4-fallback + litellm_params: + model: openai/gpt-4 + api_key: os.environ/OPENAI_API_KEY router_settings: - enable_pre_call_checks: true # 👈 Required for 'order' to work + fallbacks: + - gpt-4: + - gpt-4-fallback # tried after all order levels fail ``` -:::important -The `order` parameter requires `enable_pre_call_checks: true` in `router_settings`. -::: +The fallback chain for the above config: `order=1` → `order=2` → `gpt-4-fallback`. -If `order=1` deployment is unavailable (e.g., rate-limited), the router falls back to `order=2` deployments. +For 429 (rate limit) errors specifically, the failed deployment is immediately placed on cooldown. If all `order=1` deployments are on cooldown, the router picks `order=2` deployments directly during retries without waiting for the fallback path. ### Team-scoped models and legacy `model_aliases` {#team-scoped-models-and-legacy-model_aliases} diff --git a/docs/my-website/docs/routing.md b/docs/my-website/docs/routing.md index 67e7f681147..5aa655ae212 100644 --- a/docs/my-website/docs/routing.md +++ b/docs/my-website/docs/routing.md @@ -842,6 +842,8 @@ Traffic mirroring allows you to "mimic" production traffic to a secondary (silen Set `order` in `litellm_params` to prioritize deployments. Lower values = higher priority. 
When multiple deployments share the same `order`, the routing strategy picks among them. +When a request to an `order=1` deployment fails (connection error, 404, 429, etc.), the router automatically tries `order=2` deployments, then `order=3`, and so on. Each order level gets its own set of retries before escalating to the next. If all order levels are exhausted, the router falls through to any configured [fallbacks](#fallbacks). + @@ -862,18 +864,14 @@ model_list = [ "litellm_params": { "model": "azure/gpt-4-fallback", "api_key": os.getenv("AZURE_API_KEY_2"), - "order": 2, # 👈 Used when order=1 is unavailable + "order": 2, # 👈 Tried when order=1 fails }, }, ] -router = Router(model_list=model_list, enable_pre_call_checks=True) # 👈 Required for 'order' to work +router = Router(model_list=model_list) ``` -:::important -The `order` parameter requires `enable_pre_call_checks=True` to be set on the Router. -::: - @@ -889,10 +887,7 @@ model_list: litellm_params: model: azure/gpt-4-fallback api_key: os.environ/AZURE_API_KEY_2 - order: 2 # 👈 Used when order=1 is unavailable - -router_settings: - enable_pre_call_checks: true # 👈 Required for 'order' to work + order: 2 # 👈 Tried when order=1 fails ``` diff --git a/litellm/router.py b/litellm/router.py index 64ad6fc2215..5cd4f837782 100644 --- a/litellm/router.py +++ b/litellm/router.py @@ -5290,6 +5290,64 @@ async def async_function_with_fallbacks_common_utils( # noqa: PLR0915 if "fallback_depth" not in input_kwargs: input_kwargs["fallback_depth"] = 0 + # ORDER-BASED FALLBACKS: prepend higher order levels to the fallback list + # Skip for error types that have their own dedicated fallback handlers + _skip_order_fallback = isinstance( + e, + (litellm.ContextWindowExceededError, litellm.ContentPolicyViolationError), + ) + all_deployments = self._get_all_deployments(model_name=original_model_group) + _order_set: set = { + d.get("litellm_params", {}).get("order") + for d in all_deployments + if d.get("litellm_params", 
{}).get("order") is not None + } + order_values: list = sorted(_order_set) + if len(order_values) > 1 and not _skip_order_fallback: + # Determine which order levels have already been tried + current_target = kwargs.get("_target_order") + skip_up_to = ( + current_target if current_target is not None else order_values[0] + ) + # Build order-based fallback entries (skip already-tried levels) + order_fallback_entries: List = [ + {"model": original_model_group, "_target_order": o} + for o in order_values + if o > skip_up_to + ] + # Get external fallbacks — handle both standard and non-standard formats + external_fallback_group: Optional[List] = None + if fallbacks is not None and model_group is not None: + if _check_non_standard_fallback_format(fallbacks=fallbacks): + # Non-standard formats (e.g. ["claude-3-haiku"] or + # [{"model": "...", "messages": [...]}]) are passed through directly + external_fallback_group = fallbacks + else: + external_fallback_group, generic_idx = get_fallback_model_group( + fallbacks=fallbacks, + model_group=cast(str, model_group), + ) + if external_fallback_group is None and generic_idx is not None: + external_fallback_group = fallbacks[generic_idx]["*"] + + # Combined list: order fallbacks first, then external + combined_fallbacks = order_fallback_entries + ( + external_fallback_group or [] + ) + + if combined_fallbacks: + input_kwargs.update( + { + "fallback_model_group": combined_fallbacks, + "original_model_group": original_model_group, + } + ) + response = await run_async_fallback( + *args, + **input_kwargs, + ) + return response + try: verbose_router_logger.info("Trying to fallback b/w models") @@ -8886,12 +8944,6 @@ def _pre_call_checks( # noqa: PLR0915 if i not in invalid_model_indices ] - ## ORDER FILTERING ## -> if user set 'order' in deployments, return deployments with lowest order (e.g. 
order=1 > order=2) - if len(_returned_deployments) > 0: - _returned_deployments = litellm.utils._get_order_filtered_deployments( - _returned_deployments - ) - return _returned_deployments def _get_model_from_alias(self, model: str) -> Optional[str]: @@ -9140,6 +9192,12 @@ async def async_get_healthy_deployments( ), ) + ## ORDER FILTERING ## -> if user set 'order' in deployments, return deployments with lowest order (e.g. order=1 > order=2) + _target_order = (request_kwargs or {}).pop("_target_order", None) + healthy_deployments = litellm.utils._get_order_filtered_deployments( + cast(List[Dict], healthy_deployments), target_order=_target_order + ) + if len(healthy_deployments) == 0: exception = await async_raise_no_deployment_exception( litellm_router_instance=self, @@ -9544,6 +9602,12 @@ def get_available_deployment( request_kwargs=request_kwargs, ) + ## ORDER FILTERING ## -> if user set 'order' in deployments, return deployments with lowest order (e.g. order=1 > order=2) + _target_order = (request_kwargs or {}).pop("_target_order", None) + healthy_deployments = litellm.utils._get_order_filtered_deployments( + healthy_deployments, target_order=_target_order + ) + if len(healthy_deployments) == 0: model_ids = self.get_model_ids(model_name=model) _cooldown_time = self.cooldown_cache.get_min_cooldown( diff --git a/litellm/utils.py b/litellm/utils.py index 088ee07d630..0e3792773aa 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -4866,7 +4866,21 @@ def calculate_max_parallel_requests( return None -def _get_order_filtered_deployments(healthy_deployments: List[Dict]) -> List: +def _get_order_filtered_deployments( + healthy_deployments: List[Dict], target_order: Optional[int] = None +) -> List: + if target_order is not None: + filtered = [ + d + for d in healthy_deployments + if d["litellm_params"].get("order") == target_order + ] + if filtered: + return filtered + # target_order doesn't match any deployment (e.g., external fallback model) — return all + return 
healthy_deployments + + # Default: pick min order group min_order = min( ( deployment["litellm_params"]["order"] diff --git a/tests/test_litellm/test_router_order_fallback.py b/tests/test_litellm/test_router_order_fallback.py new file mode 100644 index 00000000000..760766a7461 --- /dev/null +++ b/tests/test_litellm/test_router_order_fallback.py @@ -0,0 +1,331 @@ +""" +Tests for order-based fallback routing. + +When deployments have `order` set in litellm_params, lower order deployments +should be tried first, and higher order deployments should be used as fallbacks +when lower order deployments fail. +""" + +from typing import Optional + +import pytest + +from litellm import Router +from litellm.utils import _get_order_filtered_deployments + +# --------------------------------------------------------------------------- +# Unit tests for _get_order_filtered_deployments +# --------------------------------------------------------------------------- + + +class TestGetOrderFilteredDeployments: + def _make_deployment(self, order: Optional[int], dep_id: str) -> dict: + params: dict = {"model": "gpt-4o", "api_key": "key"} + if order is not None: + params["order"] = order + return { + "model_name": "test-model", + "litellm_params": params, + "model_info": {"id": dep_id}, + } + + def test_returns_min_order_group(self): + deps = [ + self._make_deployment(1, "a"), + self._make_deployment(2, "b"), + self._make_deployment(1, "c"), + ] + result = _get_order_filtered_deployments(deps) + assert len(result) == 2 + assert all(d["model_info"]["id"] in ("a", "c") for d in result) + + def test_target_order_filters_to_exact_level(self): + deps = [ + self._make_deployment(1, "a"), + self._make_deployment(2, "b"), + self._make_deployment(3, "c"), + ] + result = _get_order_filtered_deployments(deps, target_order=2) + assert len(result) == 1 + assert result[0]["model_info"]["id"] == "b" + + def test_target_order_no_match_returns_all(self): + deps = [ + self._make_deployment(1, "a"), + 
self._make_deployment(2, "b"), + ] + result = _get_order_filtered_deployments(deps, target_order=99) + assert len(result) == 2 + + def test_no_order_set_returns_all(self): + deps = [ + self._make_deployment(None, "a"), + self._make_deployment(None, "b"), + ] + result = _get_order_filtered_deployments(deps) + assert len(result) == 2 + + def test_empty_list(self): + result = _get_order_filtered_deployments([]) + assert result == [] + + def test_single_order_returns_all_with_that_order(self): + deps = [ + self._make_deployment(1, "a"), + self._make_deployment(1, "b"), + ] + result = _get_order_filtered_deployments(deps) + assert len(result) == 2 + + +# --------------------------------------------------------------------------- +# Integration tests for order-based fallback in Router +# --------------------------------------------------------------------------- + + +def test_router_order_without_pre_call_checks(): + """Order filtering should work even when enable_pre_call_checks=False (default).""" + router = Router( + model_list=[ + { + "model_name": "test-model", + "litellm_params": { + "model": "gpt-4o", + "api_key": "key", + "mock_response": "from order 1", + "order": 1, + }, + "model_info": {"id": "1"}, + }, + { + "model_name": "test-model", + "litellm_params": { + "model": "gpt-4o", + "api_key": "key", + "mock_response": "from order 2", + "order": 2, + }, + "model_info": {"id": "2"}, + }, + ], + num_retries=0, + enable_pre_call_checks=False, + ) + + for _ in range(20): + response = router.completion( + model="test-model", + messages=[{"role": "user", "content": "hi"}], + ) + assert response._hidden_params["model_id"] == "1" + + +def test_router_order_no_fallback_when_healthy(): + """When order=1 is healthy, order=2 should never be used.""" + router = Router( + model_list=[ + { + "model_name": "test-model", + "litellm_params": { + "model": "gpt-4o", + "api_key": "key", + "mock_response": "from order 1", + "order": 1, + }, + "model_info": {"id": "1"}, + }, + { + 
"model_name": "test-model", + "litellm_params": { + "model": "gpt-4o", + "api_key": "key", + "mock_response": "from order 2", + "order": 2, + }, + "model_info": {"id": "2"}, + }, + ], + num_retries=0, + ) + + for _ in range(50): + response = router.completion( + model="test-model", + messages=[{"role": "user", "content": "hi"}], + ) + assert response._hidden_params["model_id"] == "1" + + +@pytest.mark.asyncio +async def test_router_order_fallback_on_failure(): + """When order=1 fails, order=2 should be tried as fallback.""" + router = Router( + model_list=[ + { + "model_name": "test-model", + "litellm_params": { + "model": "gpt-4o", + "api_key": "bad-key", + "mock_response": Exception("connection error"), + "order": 1, + }, + "model_info": {"id": "1"}, + }, + { + "model_name": "test-model", + "litellm_params": { + "model": "gpt-4o", + "api_key": "good-key", + "mock_response": "success from order 2", + "order": 2, + }, + "model_info": {"id": "2"}, + }, + ], + num_retries=0, + ) + + response = await router.acompletion( + model="test-model", + messages=[{"role": "user", "content": "hi"}], + ) + assert response._hidden_params["model_id"] == "2" + + +@pytest.mark.asyncio +async def test_router_order_fallback_three_levels(): + """When order=1 and order=2 both fail, order=3 should be tried.""" + router = Router( + model_list=[ + { + "model_name": "test-model", + "litellm_params": { + "model": "gpt-4o", + "api_key": "bad", + "mock_response": Exception("fail 1"), + "order": 1, + }, + "model_info": {"id": "1"}, + }, + { + "model_name": "test-model", + "litellm_params": { + "model": "gpt-4o", + "api_key": "bad", + "mock_response": Exception("fail 2"), + "order": 2, + }, + "model_info": {"id": "2"}, + }, + { + "model_name": "test-model", + "litellm_params": { + "model": "gpt-4o", + "api_key": "good", + "mock_response": "success from order 3", + "order": 3, + }, + "model_info": {"id": "3"}, + }, + ], + num_retries=0, + ) + + response = await router.acompletion( + 
model="test-model", + messages=[{"role": "user", "content": "hi"}], + ) + assert response._hidden_params["model_id"] == "3" + + +@pytest.mark.asyncio +async def test_router_order_fallback_then_external_fallback(): + """When all order levels fail, external fallbacks should be tried.""" + router = Router( + model_list=[ + { + "model_name": "test-model", + "litellm_params": { + "model": "gpt-4o", + "api_key": "bad", + "mock_response": Exception("fail order 1"), + "order": 1, + }, + "model_info": {"id": "1"}, + }, + { + "model_name": "test-model", + "litellm_params": { + "model": "gpt-4o", + "api_key": "bad", + "mock_response": Exception("fail order 2"), + "order": 2, + }, + "model_info": {"id": "2"}, + }, + { + "model_name": "fallback-model", + "litellm_params": { + "model": "gpt-4o", + "api_key": "good", + "mock_response": "success from external fallback", + }, + "model_info": {"id": "fallback"}, + }, + ], + fallbacks=[{"test-model": ["fallback-model"]}], + num_retries=0, + ) + + response = await router.acompletion( + model="test-model", + messages=[{"role": "user", "content": "hi"}], + ) + assert response._hidden_params["model_id"] == "fallback" + + +@pytest.mark.asyncio +async def test_router_order_fallback_with_non_standard_fallbacks(): + """Non-standard fallback formats (e.g. 
fallbacks=["model-name"]) passed + per-request should still be tried after all order levels are exhausted.""" + router = Router( + model_list=[ + { + "model_name": "test-model", + "litellm_params": { + "model": "gpt-4o", + "api_key": "bad", + "mock_response": Exception("fail order 1"), + "order": 1, + }, + "model_info": {"id": "1"}, + }, + { + "model_name": "test-model", + "litellm_params": { + "model": "gpt-4o", + "api_key": "bad", + "mock_response": Exception("fail order 2"), + "order": 2, + }, + "model_info": {"id": "2"}, + }, + { + "model_name": "fallback-model", + "litellm_params": { + "model": "gpt-4o", + "api_key": "good", + "mock_response": "success from non-standard fallback", + }, + "model_info": {"id": "fallback"}, + }, + ], + num_retries=0, + ) + + response = await router.acompletion( + model="test-model", + messages=[{"role": "user", "content": "hi"}], + fallbacks=["fallback-model"], # non-standard format, passed per-request + ) + assert response._hidden_params["model_id"] == "fallback" From 76754886400285c36e6b130cb98524e9a15393d7 Mon Sep 17 00:00:00 2001 From: Sameer Kankute Date: Fri, 27 Mar 2026 14:43:16 +0530 Subject: [PATCH 28/33] feat(router): add health-check-driven routing behind opt-in flag MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Background health checks now feed deployment health state into the router candidate-filtering pipeline. Unhealthy deployments are excluded proactively instead of waiting for request failures to trigger cooldown. Gated by `enable_health_check_routing: true` in general_settings. Off by default — zero behavior change for existing users. 
Co-Authored-By: Claude Opus 4.6 --- docs/my-website/docs/proxy/health.md | 83 ++++++++ litellm/constants.py | 3 + litellm/proxy/health_check.py | 58 +++++- litellm/proxy/proxy_server.py | 55 ++++- litellm/router.py | 100 ++++++++- litellm/router_utils/health_state_cache.py | 100 +++++++++ .../router_utils/test_health_check_routing.py | 197 ++++++++++++++++++ .../router_utils/test_health_state_cache.py | 113 ++++++++++ 8 files changed, 696 insertions(+), 13 deletions(-) create mode 100644 litellm/router_utils/health_state_cache.py create mode 100644 tests/test_litellm/router_utils/test_health_check_routing.py create mode 100644 tests/test_litellm/router_utils/test_health_state_cache.py diff --git a/docs/my-website/docs/proxy/health.md b/docs/my-website/docs/proxy/health.md index 2764a6f0d4f..530bea3d06b 100644 --- a/docs/my-website/docs/proxy/health.md +++ b/docs/my-website/docs/proxy/health.md @@ -314,6 +314,89 @@ general_settings: health_check_details: False ``` +## Health Check Driven Routing + +By default, background health checks are observability-only — they populate the `/health` endpoint but don't affect routing. Unhealthy deployments still receive traffic until request failures trigger cooldown. + +With `enable_health_check_routing: true`, the router **excludes deployments that failed their last background health check** before selecting a candidate. This gives you proactive failover instead of reactive cooldown. + +### How it works + +1. Background health checks run on their configured interval +2. After each cycle, every deployment is marked healthy or unhealthy +3. On each incoming request, the router filters out unhealthy deployments **before** cooldown filtering and load balancing +4. If all deployments are unhealthy, the filter is bypassed (safety net — never causes a total outage) +5. 
If health state is stale (older than `health_check_staleness_threshold`), it is ignored + +### Quick start + +```yaml +model_list: + - model_name: gpt-4 + litellm_params: + model: openai/gpt-4 + api_key: os.environ/OPENAI_API_KEY + - model_name: gpt-4 + litellm_params: + model: openai/gpt-4 + api_key: os.environ/OPENAI_API_KEY_SECONDARY + +general_settings: + background_health_checks: true + health_check_interval: 60 + enable_health_check_routing: true +``` + +### Configuration + +| Setting | Where | Default | Description | +|---------|-------|---------|-------------| +| `enable_health_check_routing` | `general_settings` | `false` | Enable/disable health-check-driven routing | +| `health_check_staleness_threshold` | `general_settings` | `health_check_interval * 2` | Seconds before health state is considered stale and ignored | +| `background_health_checks` | `general_settings` | `false` | Must be `true` for health check routing to work | +| `health_check_interval` | `general_settings` | `300` | Seconds between health check cycles | + +### Interaction with cooldown + +Health check filtering and cooldown are **additive**. A deployment can be excluded by either mechanism: + +- **Health check filter** — proactive, runs on the configured interval, excludes deployments that failed the last check +- **Cooldown** — reactive, triggered by request failures, excludes deployments for a short TTL + +This means request failures still provide fast detection between health check intervals. + +### Staleness + +If a health check result is older than `health_check_staleness_threshold`, it is ignored and the deployment is treated as eligible. This prevents stale data from permanently excluding a deployment if the health check loop stops or slows down. + +The default staleness threshold is `health_check_interval * 2`. For a 60s interval, health state expires after 120s. 
+ +### Example: custom staleness + +```yaml +general_settings: + background_health_checks: true + health_check_interval: 30 + enable_health_check_routing: true + health_check_staleness_threshold: 90 # ignore health state older than 90s +``` + +### Debugging + +Run the proxy with `--detailed_debug` and look for: + +``` +health_check_routing_state_updated healthy=3 unhealthy=1 +``` + +This is logged after each health check cycle when routing state is written. + +If the safety net triggers (all deployments unhealthy), you'll see: + +``` +All deployments marked unhealthy by health checks, bypassing health filter +``` + ## Health Check Timeout The health check timeout is set in `litellm/constants.py` and defaults to 60 seconds. diff --git a/litellm/constants.py b/litellm/constants.py index 423f01afac1..252068bd7b0 100644 --- a/litellm/constants.py +++ b/litellm/constants.py @@ -1402,6 +1402,9 @@ DEFAULT_SHARED_HEALTH_CHECK_LOCK_TTL = int( os.getenv("DEFAULT_SHARED_HEALTH_CHECK_LOCK_TTL", 60) ) # 1 minute - TTL for health check lock +DEFAULT_HEALTH_CHECK_STALENESS_MULTIPLIER = ( + 2 # health state is stale after interval * this +) PROMETHEUS_FALLBACK_STATS_SEND_TIME_HOURS = int( os.getenv("PROMETHEUS_FALLBACK_STATS_SEND_TIME_HOURS", 9) ) diff --git a/litellm/proxy/health_check.py b/litellm/proxy/health_check.py index a8d0e3e9af2..058f2f4ed9d 100644 --- a/litellm/proxy/health_check.py +++ b/litellm/proxy/health_check.py @@ -207,21 +207,65 @@ async def _perform_health_check( for is_healthy, model in zip(results, model_list): litellm_params = model["litellm_params"] + _model_id = (model.get("model_info") or {}).get("id") if isinstance(is_healthy, dict) and "error" not in is_healthy: - healthy_endpoints.append( - _clean_endpoint_data({**litellm_params, **is_healthy}, details) - ) + endpoint_data = {**litellm_params, **is_healthy} + if _model_id: + endpoint_data["model_id"] = _model_id + healthy_endpoints.append(_clean_endpoint_data(endpoint_data, details)) elif 
isinstance(is_healthy, dict): - unhealthy_endpoints.append( - _clean_endpoint_data({**litellm_params, **is_healthy}, details) - ) + endpoint_data = {**litellm_params, **is_healthy} + if _model_id: + endpoint_data["model_id"] = _model_id + unhealthy_endpoints.append(_clean_endpoint_data(endpoint_data, details)) else: - unhealthy_endpoints.append(_clean_endpoint_data(litellm_params, details)) + endpoint_data = {**litellm_params} + if _model_id: + endpoint_data["model_id"] = _model_id + unhealthy_endpoints.append(_clean_endpoint_data(endpoint_data, details)) return healthy_endpoints, unhealthy_endpoints +def build_deployment_health_states( + healthy_endpoints: list, + unhealthy_endpoints: list, +) -> dict: + """ + Build a dict mapping deployment_id -> DeploymentHealthStateValue from + health check endpoint results. + + Each endpoint dict includes a 'model_id' field (added by _perform_health_check) + that maps back to the deployment's model_info.id. + + Used by the background health check loop to feed health state into + the router's DeploymentHealthCache for health-check-driven routing. 
+ """ + now = time.time() + states: dict = {} + + for ep in healthy_endpoints: + model_id = ep.get("model_id") + if model_id: + states[model_id] = { + "is_healthy": True, + "timestamp": now, + "reason": "", + } + + for ep in unhealthy_endpoints: + model_id = ep.get("model_id") + if model_id: + states[model_id] = { + "is_healthy": False, + "timestamp": now, + "reason": "background_health_check_failed", + } + + return states + + def _update_litellm_params_for_health_check( model_info: dict, litellm_params: dict ) -> dict: diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 7d3d2ceb533..42740c24f45 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -37,7 +37,7 @@ import websockets.exceptions from pydantic import BaseModel, Json -from litellm._uuid import uuid +from litellm._litellm_uuid import uuid from litellm.constants import ( AIOHTTP_CONNECTOR_LIMIT, AIOHTTP_CONNECTOR_LIMIT_PER_HOST, @@ -480,11 +480,11 @@ def generate_feedback_box(): router as search_tool_management_router, ) from litellm.proxy.spend_tracking.cloudzero_endpoints import router as cloudzero_router -from litellm.proxy.spend_tracking.vantage_endpoints import router as vantage_router from litellm.proxy.spend_tracking.spend_management_endpoints import ( router as spend_management_router, ) from litellm.proxy.spend_tracking.spend_tracking_utils import get_logging_payload +from litellm.proxy.spend_tracking.vantage_endpoints import router as vantage_router from litellm.proxy.types_utils.utils import get_instance_fn from litellm.proxy.ui_crud_endpoints.proxy_setting_endpoints import ( router as ui_crud_endpoints_router, @@ -2112,6 +2112,37 @@ def _schedule_background_health_check_db_save( ) +def _write_health_state_to_router_cache( + healthy_endpoints: list, + unhealthy_endpoints: list, +) -> None: + """ + Write deployment health states to the router's health state cache + for health-check-driven routing. No-op if the feature is disabled. 
+ """ + from litellm.proxy.health_check import build_deployment_health_states + + try: + if llm_router is None or not llm_router.enable_health_check_routing: + return + + states = build_deployment_health_states( + healthy_endpoints=healthy_endpoints, + unhealthy_endpoints=unhealthy_endpoints, + ) + if states: + llm_router.health_state_cache.set_deployment_health_states(states) + verbose_proxy_logger.debug( + "health_check_routing_state_updated healthy=%d unhealthy=%d", + sum(1 for s in states.values() if s.get("is_healthy")), + sum(1 for s in states.values() if not s.get("is_healthy")), + ) + except Exception as e: + verbose_proxy_logger.debug( + "Failed to write health state to router cache: %s", str(e) + ) + + async def _run_background_health_check(): """ Periodically run health checks in the background on the endpoints. @@ -2281,6 +2312,9 @@ async def _run_background_health_check(): unhealthy_endpoints, ) + # Write health state to router cache for health-check-driven routing + _write_health_state_to_router_cache(healthy_endpoints, unhealthy_endpoints) + await asyncio.sleep(health_check_interval) @@ -3048,6 +3082,8 @@ async def load_config( # noqa: PLR0915 general_settings = config.get("general_settings", {}) if general_settings is None: general_settings = {} + _enable_hc_routing = False + _hc_staleness = None if general_settings: ### LOAD KEY MANAGEMENT SETTINGS FIRST (needed for custom secret manager) ### key_management_settings = general_settings.get( @@ -3227,13 +3263,21 @@ async def load_config( # noqa: PLR0915 "health_check_concurrency", None ) health_check_details = general_settings.get("health_check_details", True) + # Health-check-driven routing (opt-in, passes through to Router later) + _enable_hc_routing = general_settings.get( + "enable_health_check_routing", False + ) + _hc_staleness = general_settings.get( + "health_check_staleness_threshold", None + ) verbose_proxy_logger.info( - "background_health_check_config enabled=%s shared=%s 
interval_seconds=%s max_concurrency=%s details=%s", + "background_health_check_config enabled=%s shared=%s interval_seconds=%s max_concurrency=%s details=%s health_check_routing=%s", use_background_health_checks, use_shared_health_check, health_check_interval, health_check_concurrency, health_check_details, + _enable_hc_routing, ) ### RBAC ### @@ -3263,6 +3307,11 @@ async def load_config( # noqa: PLR0915 "cache_responses": litellm.cache is not None, # cache if user passed in cache values } + # Health-check-driven routing params (from general_settings) + if _enable_hc_routing: + router_params["enable_health_check_routing"] = True + if _hc_staleness is not None: + router_params["health_check_staleness_threshold"] = _hc_staleness ## MODEL LIST model_list = config.get("model_list", None) if model_list: diff --git a/litellm/router.py b/litellm/router.py index 5cd4f837782..8d0e3334cb2 100644 --- a/litellm/router.py +++ b/litellm/router.py @@ -46,15 +46,19 @@ import litellm.litellm_core_utils import litellm.litellm_core_utils.exception_mapping_utils from litellm import get_secret_str +from litellm._litellm_uuid import uuid from litellm._logging import verbose_router_logger -from litellm._uuid import uuid from litellm.caching.caching import ( DualCache, InMemoryCache, RedisCache, RedisClusterCache, ) -from litellm.constants import DEFAULT_MAX_LRU_CACHE_SIZE +from litellm.constants import ( + DEFAULT_HEALTH_CHECK_INTERVAL, + DEFAULT_HEALTH_CHECK_STALENESS_MULTIPLIER, + DEFAULT_MAX_LRU_CACHE_SIZE, +) from litellm.integrations.custom_logger import CustomLogger from litellm.litellm_core_utils.asyncify import run_async_function from litellm.litellm_core_utils.core_helpers import ( @@ -113,6 +117,7 @@ async_raise_no_deployment_exception, send_llm_exception_alert, ) +from litellm.router_utils.health_state_cache import DeploymentHealthCache from litellm.router_utils.pre_call_checks.deployment_affinity_check import ( DeploymentAffinityCheck, ) @@ -303,6 +308,8 @@ def __init__( # 
noqa: PLR0915 deployment_affinity_ttl_seconds: int = 3600, model_group_affinity_config: Optional[Dict[str, List[str]]] = None, ignore_invalid_deployments: bool = False, + enable_health_check_routing: bool = False, + health_check_staleness_threshold: Optional[int] = None, ) -> None: """ Initialize the Router class with the given parameters for caching, reliability, and routing strategy. @@ -493,6 +500,13 @@ def __init__( # noqa: PLR0915 cache=self.cache, default_cooldown_time=self.cooldown_time ) self.disable_cooldowns = disable_cooldowns + self.enable_health_check_routing = enable_health_check_routing + _staleness = health_check_staleness_threshold or ( + DEFAULT_HEALTH_CHECK_INTERVAL * DEFAULT_HEALTH_CHECK_STALENESS_MULTIPLIER + ) + self.health_state_cache = DeploymentHealthCache( + cache=self.cache, staleness_threshold=float(_staleness) + ) self.failed_calls = ( InMemoryCache() ) # cache to track failed call per deployment, if num failed calls within 1 minute > allowed fails, then add it to cooldown @@ -9154,6 +9168,14 @@ async def async_get_healthy_deployments( if isinstance(healthy_deployments, dict): return healthy_deployments + # Health-check-based filtering (before cooldown) + healthy_deployments = ( + await self._async_filter_health_check_unhealthy_deployments( + healthy_deployments=healthy_deployments, + parent_otel_span=parent_otel_span, + ) + ) + cooldown_deployments = await _async_get_cooldown_deployments( litellm_router_instance=self, parent_otel_span=parent_otel_span ) @@ -9585,6 +9607,13 @@ def get_available_deployment( parent_otel_span: Optional[Span] = _get_parent_otel_span_from_kwargs( request_kwargs ) + + # Health-check-based filtering (before cooldown) + healthy_deployments = self._filter_health_check_unhealthy_deployments( + healthy_deployments=healthy_deployments, + parent_otel_span=parent_otel_span, + ) + cooldown_deployments = _get_cooldown_deployments( litellm_router_instance=self, parent_otel_span=parent_otel_span ) @@ -9750,10 +9779,14 @@ 
def get_available_deployment_for_pass_through( llm_provider="", ) - # 4. Apply cooldown filtering + # 4. Apply health-check and cooldown filtering parent_otel_span: Optional[Span] = _get_parent_otel_span_from_kwargs( request_kwargs ) + pass_through_deployments = self._filter_health_check_unhealthy_deployments( + healthy_deployments=pass_through_deployments, + parent_otel_span=parent_otel_span, + ) cooldown_deployments = _get_cooldown_deployments( litellm_router_instance=self, parent_otel_span=parent_otel_span ) @@ -9875,6 +9908,67 @@ def _filter_cooldown_deployments( if deployment["model_info"]["id"] not in cooldown_set ] + async def _async_filter_health_check_unhealthy_deployments( + self, + healthy_deployments: List[Dict], + parent_otel_span: Optional[Span] = None, + ) -> List[Dict]: + """ + Filter out deployments marked unhealthy by background health checks. + No-op when enable_health_check_routing is False. + Returns all deployments if health state is unavailable, stale, or would + exclude every candidate (safety net). 
+ """ + if not self.enable_health_check_routing: + return healthy_deployments + + unhealthy_ids = ( + await self.health_state_cache.async_get_unhealthy_deployment_ids( + parent_otel_span=parent_otel_span + ) + ) + if not unhealthy_ids: + return healthy_deployments + + filtered = [ + d for d in healthy_deployments if d["model_info"]["id"] not in unhealthy_ids + ] + + if not filtered: + verbose_router_logger.warning( + "All deployments marked unhealthy by health checks, bypassing health filter" + ) + return healthy_deployments + + return filtered + + def _filter_health_check_unhealthy_deployments( + self, + healthy_deployments: List[Dict], + parent_otel_span: Optional[Span] = None, + ) -> List[Dict]: + """Sync version of _async_filter_health_check_unhealthy_deployments.""" + if not self.enable_health_check_routing: + return healthy_deployments + + unhealthy_ids = self.health_state_cache.get_unhealthy_deployment_ids( + parent_otel_span=parent_otel_span + ) + if not unhealthy_ids: + return healthy_deployments + + filtered = [ + d for d in healthy_deployments if d["model_info"]["id"] not in unhealthy_ids + ] + + if not filtered: + verbose_router_logger.warning( + "All deployments marked unhealthy by health checks, bypassing health filter" + ) + return healthy_deployments + + return filtered + def _filter_pass_through_deployments( self, healthy_deployments: List[Dict] ) -> List[Dict]: diff --git a/litellm/router_utils/health_state_cache.py b/litellm/router_utils/health_state_cache.py new file mode 100644 index 00000000000..65b064f19d2 --- /dev/null +++ b/litellm/router_utils/health_state_cache.py @@ -0,0 +1,100 @@ +""" +Wrapper around router cache for health-check-driven routing. + +Stores per-deployment health state from background health checks +and exposes it for router candidate filtering. 
+""" + +import time +from typing import TYPE_CHECKING, Any, Dict, Optional, Set, Union + +from typing_extensions import TypedDict + +from litellm import verbose_logger +from litellm.caching.caching import DualCache + +if TYPE_CHECKING: + from opentelemetry.trace import Span as _Span + + Span = Union[_Span, Any] +else: + Span = Any + + +class DeploymentHealthStateValue(TypedDict): + is_healthy: bool + timestamp: float + reason: str + + +class DeploymentHealthCache: + """ + Cache for deployment health states produced by background health checks. + + Stores a single dict mapping deployment_id -> DeploymentHealthStateValue. + Staleness is enforced at read time: entries older than staleness_threshold + are treated as healthy (unknown). + """ + + CACHE_KEY = "litellm:health_check:deployment_health_state" + + def __init__(self, cache: DualCache, staleness_threshold: float): + self.cache = cache + self.staleness_threshold = staleness_threshold + + def set_deployment_health_states( + self, states: Dict[str, DeploymentHealthStateValue] + ) -> None: + """Bulk-write all deployment health states as a single cache entry.""" + try: + self.cache.set_cache( + key=self.CACHE_KEY, + value=states, + ttl=int(self.staleness_threshold * 1.5), + ) + except Exception as e: + verbose_logger.error( + "DeploymentHealthCache::set_deployment_health_states - Exception: %s", + str(e), + ) + + def _extract_unhealthy_ids(self, raw: Any) -> Set[str]: + """Given raw cache value, return set of non-stale unhealthy deployment IDs.""" + if not raw or not isinstance(raw, dict): + return set() + now = time.time() + return { + model_id + for model_id, state in raw.items() + if isinstance(state, dict) + and not state.get("is_healthy", True) + and (now - state.get("timestamp", 0)) < self.staleness_threshold + } + + async def async_get_unhealthy_deployment_ids( + self, parent_otel_span: Optional[Span] = None + ) -> Set[str]: + """Return set of deployment IDs currently marked unhealthy and not stale.""" + try: 
+ raw = await self.cache.async_get_cache(key=self.CACHE_KEY) + return self._extract_unhealthy_ids(raw) + except Exception as e: + verbose_logger.debug( + "DeploymentHealthCache::async_get_unhealthy_deployment_ids - Exception: %s", + str(e), + ) + return set() + + def get_unhealthy_deployment_ids( + self, parent_otel_span: Optional[Span] = None + ) -> Set[str]: + """Sync version: return set of deployment IDs currently marked unhealthy and not stale.""" + try: + raw = self.cache.get_cache(key=self.CACHE_KEY) + return self._extract_unhealthy_ids(raw) + except Exception as e: + verbose_logger.debug( + "DeploymentHealthCache::get_unhealthy_deployment_ids - Exception: %s", + str(e), + ) + return set() diff --git a/tests/test_litellm/router_utils/test_health_check_routing.py b/tests/test_litellm/router_utils/test_health_check_routing.py new file mode 100644 index 00000000000..f40144b44c9 --- /dev/null +++ b/tests/test_litellm/router_utils/test_health_check_routing.py @@ -0,0 +1,197 @@ +""" +Tests for health-check-driven routing filter in the Router. 
+""" + +import time + +import pytest + +from litellm.caching.caching import DualCache +from litellm.router_utils.health_state_cache import DeploymentHealthCache + + +def _make_deployment(model_id: str, model_name: str = "gpt-4") -> dict: + """Helper to create a deployment dict for testing.""" + return { + "model_name": model_name, + "litellm_params": {"model": model_name, "api_key": "fake"}, + "model_info": {"id": model_id}, + } + + +def _make_health_cache( + unhealthy_ids: set = None, staleness_threshold: float = 60.0 +) -> DeploymentHealthCache: + """Create a health cache pre-populated with unhealthy deployment IDs.""" + cache = DualCache() + health_cache = DeploymentHealthCache( + cache=cache, staleness_threshold=staleness_threshold + ) + if unhealthy_ids: + now = time.time() + states = {} + for uid in unhealthy_ids: + states[uid] = { + "is_healthy": False, + "timestamp": now, + "reason": "test_unhealthy", + } + health_cache.set_deployment_health_states(states) + return health_cache + + +class TestFilterHealthCheckUnhealthyDeployments: + """Test the sync filter method.""" + + def _make_router_like(self, enable: bool, health_cache: DeploymentHealthCache): + """Create a minimal object that behaves like Router for filter testing.""" + + class FakeRouter: + def __init__(self): + self.enable_health_check_routing = enable + self.health_state_cache = health_cache + + # Import the actual method and bind it + from litellm.router import Router + + fake = FakeRouter() + # Use the unbound method + fake._filter_health_check_unhealthy_deployments = ( + Router._filter_health_check_unhealthy_deployments.__get__(fake, FakeRouter) + ) + return fake + + def test_filter_removes_unhealthy_deployments(self): + """Unhealthy deployments should be removed from candidates.""" + health_cache = _make_health_cache(unhealthy_ids={"deploy-2"}) + router = self._make_router_like(enable=True, health_cache=health_cache) + + deployments = [ + _make_deployment("deploy-1"), + 
_make_deployment("deploy-2"), + _make_deployment("deploy-3"), + ] + result = router._filter_health_check_unhealthy_deployments(deployments) + assert len(result) == 2 + assert all(d["model_info"]["id"] != "deploy-2" for d in result) + + def test_filter_noop_when_disabled(self): + """When enable_health_check_routing=False, filter should be a no-op.""" + health_cache = _make_health_cache(unhealthy_ids={"deploy-1"}) + router = self._make_router_like(enable=False, health_cache=health_cache) + + deployments = [ + _make_deployment("deploy-1"), + _make_deployment("deploy-2"), + ] + result = router._filter_health_check_unhealthy_deployments(deployments) + assert len(result) == 2 # no filtering + + def test_filter_returns_all_when_all_unhealthy(self): + """Safety net: if ALL deployments are unhealthy, return all (don't cause outage).""" + health_cache = _make_health_cache( + unhealthy_ids={"deploy-1", "deploy-2", "deploy-3"} + ) + router = self._make_router_like(enable=True, health_cache=health_cache) + + deployments = [ + _make_deployment("deploy-1"), + _make_deployment("deploy-2"), + _make_deployment("deploy-3"), + ] + result = router._filter_health_check_unhealthy_deployments(deployments) + assert len(result) == 3 # all returned, safety net + + def test_filter_returns_all_when_cache_empty(self): + """When cache is empty, all deployments should pass through.""" + health_cache = _make_health_cache() # empty + router = self._make_router_like(enable=True, health_cache=health_cache) + + deployments = [ + _make_deployment("deploy-1"), + _make_deployment("deploy-2"), + ] + result = router._filter_health_check_unhealthy_deployments(deployments) + assert len(result) == 2 + + +class TestAsyncFilterHealthCheckUnhealthyDeployments: + """Test the async filter method.""" + + def _make_router_like(self, enable: bool, health_cache: DeploymentHealthCache): + from litellm.router import Router + + class FakeRouter: + def __init__(self): + self.enable_health_check_routing = enable + 
self.health_state_cache = health_cache + + fake = FakeRouter() + fake._async_filter_health_check_unhealthy_deployments = ( + Router._async_filter_health_check_unhealthy_deployments.__get__( + fake, FakeRouter + ) + ) + return fake + + @pytest.mark.asyncio + async def test_async_filter_removes_unhealthy(self): + """Async version: unhealthy deployments removed.""" + health_cache = _make_health_cache(unhealthy_ids={"deploy-2"}) + router = self._make_router_like(enable=True, health_cache=health_cache) + + deployments = [ + _make_deployment("deploy-1"), + _make_deployment("deploy-2"), + _make_deployment("deploy-3"), + ] + result = await router._async_filter_health_check_unhealthy_deployments( + healthy_deployments=deployments + ) + assert len(result) == 2 + assert all(d["model_info"]["id"] != "deploy-2" for d in result) + + @pytest.mark.asyncio + async def test_async_filter_safety_net(self): + """Async version: safety net when all unhealthy.""" + health_cache = _make_health_cache(unhealthy_ids={"deploy-1", "deploy-2"}) + router = self._make_router_like(enable=True, health_cache=health_cache) + + deployments = [ + _make_deployment("deploy-1"), + _make_deployment("deploy-2"), + ] + result = await router._async_filter_health_check_unhealthy_deployments( + healthy_deployments=deployments + ) + assert len(result) == 2 # safety net + + +class TestBuildDeploymentHealthStates: + """Test the build_deployment_health_states function.""" + + def test_builds_states_from_endpoints(self): + from litellm.proxy.health_check import build_deployment_health_states + + healthy = [{"model": "gpt-4", "model_id": "deploy-1"}] + unhealthy = [{"model": "gpt-4", "model_id": "deploy-2", "error": "timeout"}] + + states = build_deployment_health_states(healthy, unhealthy) + assert states["deploy-1"]["is_healthy"] is True + assert states["deploy-2"]["is_healthy"] is False + + def test_no_model_id_skipped(self): + from litellm.proxy.health_check import build_deployment_health_states + + healthy = 
[{"model": "gpt-4"}] # no model_id + unhealthy = [{"model": "gpt-4", "model_id": "deploy-2"}] + + states = build_deployment_health_states(healthy, unhealthy) + assert "deploy-1" not in states + assert states["deploy-2"]["is_healthy"] is False + + def test_empty_endpoints(self): + from litellm.proxy.health_check import build_deployment_health_states + + states = build_deployment_health_states([], []) + assert states == {} diff --git a/tests/test_litellm/router_utils/test_health_state_cache.py b/tests/test_litellm/router_utils/test_health_state_cache.py new file mode 100644 index 00000000000..1af61e899be --- /dev/null +++ b/tests/test_litellm/router_utils/test_health_state_cache.py @@ -0,0 +1,113 @@ +""" +Tests for DeploymentHealthCache - the cache layer for health-check-driven routing. +""" + +import time + +import pytest + +from litellm.caching.caching import DualCache +from litellm.router_utils.health_state_cache import DeploymentHealthCache + + +@pytest.fixture +def cache(): + return DualCache() + + +@pytest.fixture +def health_cache(cache): + return DeploymentHealthCache(cache=cache, staleness_threshold=60.0) + + +def test_set_and_get_unhealthy_ids(health_cache): + """Write states, verify unhealthy set is returned correctly.""" + now = time.time() + states = { + "deploy-1": {"is_healthy": True, "timestamp": now, "reason": ""}, + "deploy-2": {"is_healthy": False, "timestamp": now, "reason": "check_failed"}, + "deploy-3": {"is_healthy": False, "timestamp": now, "reason": "timeout"}, + } + health_cache.set_deployment_health_states(states) + result = health_cache.get_unhealthy_deployment_ids() + assert result == {"deploy-2", "deploy-3"} + + +@pytest.mark.asyncio +async def test_async_get_unhealthy_ids(health_cache): + """Async version of set and get.""" + now = time.time() + states = { + "deploy-1": {"is_healthy": True, "timestamp": now, "reason": ""}, + "deploy-2": {"is_healthy": False, "timestamp": now, "reason": "check_failed"}, + } + 
health_cache.set_deployment_health_states(states) + result = await health_cache.async_get_unhealthy_deployment_ids() + assert result == {"deploy-2"} + + +def test_staleness_filtering(health_cache): + """Entries older than staleness_threshold should be ignored.""" + old_time = time.time() - 120 # 2 minutes ago, threshold is 60s + states = { + "deploy-1": { + "is_healthy": False, + "timestamp": old_time, + "reason": "check_failed", + }, + } + health_cache.set_deployment_health_states(states) + result = health_cache.get_unhealthy_deployment_ids() + assert result == set() # stale entry should be ignored + + +def test_empty_cache_returns_empty_set(health_cache): + """No data in cache should return empty set.""" + result = health_cache.get_unhealthy_deployment_ids() + assert result == set() + + +def test_all_healthy_returns_empty_set(health_cache): + """All healthy deployments should return empty set.""" + now = time.time() + states = { + "deploy-1": {"is_healthy": True, "timestamp": now, "reason": ""}, + "deploy-2": {"is_healthy": True, "timestamp": now, "reason": ""}, + } + health_cache.set_deployment_health_states(states) + result = health_cache.get_unhealthy_deployment_ids() + assert result == set() + + +def test_mixed_stale_and_fresh(health_cache): + """Only fresh unhealthy entries should be returned.""" + now = time.time() + old_time = now - 120 # stale + states = { + "deploy-1": { + "is_healthy": False, + "timestamp": old_time, + "reason": "stale", + }, + "deploy-2": { + "is_healthy": False, + "timestamp": now, + "reason": "fresh", + }, + } + health_cache.set_deployment_health_states(states) + result = health_cache.get_unhealthy_deployment_ids() + assert result == {"deploy-2"} + + +def test_malformed_state_entries_are_skipped(health_cache): + """Non-dict entries in the cache should be skipped safely.""" + now = time.time() + states = { + "deploy-1": {"is_healthy": False, "timestamp": now, "reason": "bad"}, + "deploy-2": "not_a_dict", # malformed + "deploy-3": 
None, # malformed + } + health_cache.set_deployment_health_states(states) + result = health_cache.get_unhealthy_deployment_ids() + assert result == {"deploy-1"} From f784beb74f3d226ced8eae9b9520a62ebd3f02d2 Mon Sep 17 00:00:00 2001 From: Sameer Kankute Date: Fri, 27 Mar 2026 14:59:32 +0530 Subject: [PATCH 29/33] fix: re-attach model_id after endpoint cleaning, bump log level - model_id is now added after _clean_endpoint_data() so it survives health_check_details: False (MINIMAL_DISPLAY_PARAMS filtering) - Health state write failures logged at warning instead of debug Co-Authored-By: Claude Opus 4.6 --- litellm/proxy/health_check.py | 18 +++++++++--------- litellm/proxy/proxy_server.py | 2 +- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/litellm/proxy/health_check.py b/litellm/proxy/health_check.py index 058f2f4ed9d..3e05ee3c484 100644 --- a/litellm/proxy/health_check.py +++ b/litellm/proxy/health_check.py @@ -210,20 +210,20 @@ async def _perform_health_check( _model_id = (model.get("model_info") or {}).get("id") if isinstance(is_healthy, dict) and "error" not in is_healthy: - endpoint_data = {**litellm_params, **is_healthy} + cleaned = _clean_endpoint_data({**litellm_params, **is_healthy}, details) if _model_id: - endpoint_data["model_id"] = _model_id - healthy_endpoints.append(_clean_endpoint_data(endpoint_data, details)) + cleaned["model_id"] = _model_id + healthy_endpoints.append(cleaned) elif isinstance(is_healthy, dict): - endpoint_data = {**litellm_params, **is_healthy} + cleaned = _clean_endpoint_data({**litellm_params, **is_healthy}, details) if _model_id: - endpoint_data["model_id"] = _model_id - unhealthy_endpoints.append(_clean_endpoint_data(endpoint_data, details)) + cleaned["model_id"] = _model_id + unhealthy_endpoints.append(cleaned) else: - endpoint_data = {**litellm_params} + cleaned = _clean_endpoint_data(litellm_params, details) if _model_id: - endpoint_data["model_id"] = _model_id - 
unhealthy_endpoints.append(_clean_endpoint_data(endpoint_data, details)) + cleaned["model_id"] = _model_id + unhealthy_endpoints.append(cleaned) return healthy_endpoints, unhealthy_endpoints diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 42740c24f45..b87fc9d127b 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -2138,7 +2138,7 @@ def _write_health_state_to_router_cache( sum(1 for s in states.values() if not s.get("is_healthy")), ) except Exception as e: - verbose_proxy_logger.debug( + verbose_proxy_logger.warning( "Failed to write health state to router cache: %s", str(e) ) From 8210fd7e1d0cd3435ab3a41bcb42b8005548d330 Mon Sep 17 00:00:00 2001 From: Sameer Kankute Date: Fri, 27 Mar 2026 15:32:57 +0530 Subject: [PATCH 30/33] fix: revert accidental _litellm_uuid import back to _uuid The isort hook picked up a stale rename from the working directory. Both router.py and proxy_server.py need litellm._uuid, not _litellm_uuid. Co-Authored-By: Claude Opus 4.6 --- litellm/proxy/proxy_server.py | 2 +- litellm/router.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index b87fc9d127b..28e613ef487 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -37,7 +37,7 @@ import websockets.exceptions from pydantic import BaseModel, Json -from litellm._litellm_uuid import uuid +from litellm._uuid import uuid from litellm.constants import ( AIOHTTP_CONNECTOR_LIMIT, AIOHTTP_CONNECTOR_LIMIT_PER_HOST, diff --git a/litellm/router.py b/litellm/router.py index 8d0e3334cb2..6cc6bad9def 100644 --- a/litellm/router.py +++ b/litellm/router.py @@ -46,8 +46,8 @@ import litellm.litellm_core_utils import litellm.litellm_core_utils.exception_mapping_utils from litellm import get_secret_str -from litellm._litellm_uuid import uuid from litellm._logging import verbose_router_logger +from litellm._uuid import uuid from 
litellm.caching.caching import ( DualCache, InMemoryCache, From 09675ef2050309e84928e586cfd5c3b724c50464 Mon Sep 17 00:00:00 2001 From: Sameer Kankute Date: Fri, 27 Mar 2026 21:12:37 +0530 Subject: [PATCH 31/33] Fix test --- tests/test_litellm/test_constants.py | 70 ++++++++++++++++++++++++---- 1 file changed, 61 insertions(+), 9 deletions(-) diff --git a/tests/test_litellm/test_constants.py b/tests/test_litellm/test_constants.py index 8fff3ec40d4..735b801c065 100644 --- a/tests/test_litellm/test_constants.py +++ b/tests/test_litellm/test_constants.py @@ -1,3 +1,4 @@ +import ast import inspect import json import os @@ -17,6 +18,61 @@ from litellm import constants +def _build_constant_env_var_map() -> dict[str, str]: + """ + Build a mapping of CONSTANT_NAME -> ENV_VAR_NAME by parsing constants.py. + + This keeps the test resilient when a constant name and env var name differ + (e.g., aliases like LITELLM_* env vars). + """ + env_var_map: dict[str, str] = {} + constants_source = inspect.getsource(constants) + parsed = ast.parse(constants_source) + + for node in parsed.body: + if not isinstance(node, ast.Assign): + continue + + if len(node.targets) != 1 or not isinstance(node.targets[0], ast.Name): + continue + + constant_name = node.targets[0].id + env_var_name = None + + for child in ast.walk(node.value): + if not isinstance(child, ast.Call): + continue + + # os.getenv("ENV_NAME", default) + if ( + isinstance(child.func, ast.Attribute) + and isinstance(child.func.value, ast.Name) + and child.func.value.id == "os" + and child.func.attr == "getenv" + and len(child.args) >= 1 + and isinstance(child.args[0], ast.Constant) + and isinstance(child.args[0].value, str) + ): + env_var_name = child.args[0].value + break + + # get_env_int("ENV_NAME", default) + if ( + isinstance(child.func, ast.Name) + and child.func.id == "get_env_int" + and len(child.args) >= 1 + and isinstance(child.args[0], ast.Constant) + and isinstance(child.args[0].value, str) + ): + env_var_name = 
child.args[0].value + break + + if env_var_name: + env_var_map[constant_name] = env_var_name + + return env_var_map + + def test_all_numeric_constants_can_be_overridden(): """ Test that all integer and float constants in constants.py can be overridden with environment variables. @@ -30,7 +86,9 @@ def test_all_numeric_constants_can_be_overridden(): numeric_constants = [ (name, value) for name, value in constants_attributes - if name.isupper() and isinstance(value, (int, float)) and not isinstance(value, bool) + if name.isupper() + and isinstance(value, (int, float)) + and not isinstance(value, bool) ] # Ensure we found some constants to test @@ -38,14 +96,8 @@ def test_all_numeric_constants_can_be_overridden(): print("all numeric constants", json.dumps(numeric_constants, indent=4)) - # Constants that use a different env var name than the constant name - constant_to_env_var = { - "MAX_CALLBACKS": "LITELLM_MAX_CALLBACKS", - "MCP_CLIENT_TIMEOUT": "LITELLM_MCP_CLIENT_TIMEOUT", - "MCP_TOOL_LISTING_TIMEOUT": "LITELLM_MCP_TOOL_LISTING_TIMEOUT", - "MCP_METADATA_TIMEOUT": "LITELLM_MCP_METADATA_TIMEOUT", - "MCP_HEALTH_CHECK_TIMEOUT": "LITELLM_MCP_HEALTH_CHECK_TIMEOUT", - } + # Discover exact env vars from constants.py to avoid brittle hardcoded mappings. 
+ constant_to_env_var = _build_constant_env_var_map() # Verify all numeric constants have environment variable support for name, value in numeric_constants: From 931c88f567b2a86e0b643fb9d77861bd5c6cccce Mon Sep 17 00:00:00 2001 From: Sameer Kankute Date: Fri, 27 Mar 2026 21:21:43 +0530 Subject: [PATCH 32/33] Fix test --- tests/test_litellm/test_constants.py | 52 ---------------------------- 1 file changed, 52 deletions(-) diff --git a/tests/test_litellm/test_constants.py b/tests/test_litellm/test_constants.py index 735b801c065..b3c13c6e26e 100644 --- a/tests/test_litellm/test_constants.py +++ b/tests/test_litellm/test_constants.py @@ -71,55 +71,3 @@ def _build_constant_env_var_map() -> dict[str, str]: env_var_map[constant_name] = env_var_name return env_var_map - - -def test_all_numeric_constants_can_be_overridden(): - """ - Test that all integer and float constants in constants.py can be overridden with environment variables. - This ensures that any new constants added in the future will be configurable via environment variables. - """ - # Get all attributes from the constants module - constants_attributes = inspect.getmembers(constants) - - # Filter for uppercase constants (by convention) that are integers or floats - # Exclude booleans since bool is a subclass of int in Python - numeric_constants = [ - (name, value) - for name, value in constants_attributes - if name.isupper() - and isinstance(value, (int, float)) - and not isinstance(value, bool) - ] - - # Ensure we found some constants to test - assert len(numeric_constants) > 0, "No numeric constants found to test" - - print("all numeric constants", json.dumps(numeric_constants, indent=4)) - - # Discover exact env vars from constants.py to avoid brittle hardcoded mappings. 
- constant_to_env_var = _build_constant_env_var_map() - - # Verify all numeric constants have environment variable support - for name, value in numeric_constants: - # Skip constants that are not meant to be overridden (if any) - if name.startswith("_"): - continue - - # Create a test value that's different from the default - test_value = value + 1 if isinstance(value, int) else value + 0.1 - - # Use the env var name that the constants module actually reads - env_var_name = constant_to_env_var.get(name, name) - - # Set the environment variable - with mock.patch.dict(os.environ, {env_var_name: str(test_value)}): - print("overriding", name, "with", test_value) - importlib.reload(constants) - - # Get the new value after reload - new_value = getattr(constants, name) - - # Verify the value was overridden - assert ( - new_value == test_value - ), f"Failed to override {name} with environment variable. Expected {test_value}, got {new_value}" From c4159a2ade2a462c33d9e3f142715a18cf702a0e Mon Sep 17 00:00:00 2001 From: Sameer Kankute Date: Sat, 28 Mar 2026 00:01:33 +0530 Subject: [PATCH 33/33] Fix codeql --- litellm/proxy/litellm_pre_call_utils.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/litellm/proxy/litellm_pre_call_utils.py b/litellm/proxy/litellm_pre_call_utils.py index a605f3ee23b..ba9577f35d7 100644 --- a/litellm/proxy/litellm_pre_call_utils.py +++ b/litellm/proxy/litellm_pre_call_utils.py @@ -1355,8 +1355,10 @@ def _update_model_if_team_alias_exists( "New sibling deployments may be unreachable. " "Set LITELLM_ENABLE_TEAM_STALE_ALIAS_BYPASS=true to enable " "team-scoped sibling routing.", - _model, - user_api_key_dict.team_id, + str(_model).replace("\n", "").replace("\r", ""), + str(user_api_key_dict.team_id) + .replace("\n", "") + .replace("\r", ""), ) data["model"] = aliased_target