diff --git a/docs/my-website/docs/proxy/config_settings.md b/docs/my-website/docs/proxy/config_settings.md index 02f5c2be9c7..dce979ab89f 100644 --- a/docs/my-website/docs/proxy/config_settings.md +++ b/docs/my-website/docs/proxy/config_settings.md @@ -804,6 +804,7 @@ router_settings: | LITELLM_OTEL_INTEGRATION_ENABLE_EVENTS | Optionally enable semantic logs for OTEL | LITELLM_OTEL_INTEGRATION_ENABLE_METRICS | Optionally enable semantic metrics for OTEL | LITELLM_ENABLE_PYROSCOPE | If true, enables Pyroscope CPU profiling. Profiles are sent to PYROSCOPE_SERVER_ADDRESS. Off by default. See [Pyroscope profiling](/proxy/pyroscope_profiling). +| LITELLM_ENABLE_TEAM_STALE_ALIAS_BYPASS | When `true`, if a team's legacy `model_aliases` entry maps a public model name to an internal `model_name_{team_id}_{uuid}` deployment, pre-call handling can skip that rewrite when team-scoped sibling deployments exist for the public name, so that load balancing and `order` apply across siblings. Default is `false` for backwards compatibility. See [Team-scoped models and legacy aliases](./load_balancing#team-scoped-models-and-legacy-model_aliases). When stale aliases are detected and this flag is off, the proxy may log a one-time warning. | PYROSCOPE_APP_NAME | Application name reported to Pyroscope. Required when LITELLM_ENABLE_PYROSCOPE is true. No default. | PYROSCOPE_SERVER_ADDRESS | Pyroscope server URL to send profiles to. Required when LITELLM_ENABLE_PYROSCOPE is true. No default. | PYROSCOPE_SAMPLE_RATE | Optional. Sample rate for Pyroscope profiling (integer). No default; when unset, the pyroscope-io library default is used. 
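The decision that `LITELLM_ENABLE_TEAM_STALE_ALIAS_BYPASS` controls can be sketched in isolation. This is a simplified illustration, not the proxy's code: `resolve_team_model` and its parameters are hypothetical names, and the set of `(team_id, public_name)` keys stands in for the router's team deployment index. It assumes a stale alias is one whose target starts with `model_name_{team_id}_`.

```python
from typing import Dict, Set, Tuple


def resolve_team_model(
    requested_model: str,
    team_id: str,
    team_model_aliases: Dict[str, str],
    team_deployment_keys: Set[Tuple[str, str]],
    bypass_enabled: bool,
) -> str:
    """Hypothetical helper: return the model name that would be routed on."""
    aliased_target = team_model_aliases.get(requested_model)
    if aliased_target is None:
        # No alias for this name: nothing to rewrite.
        return requested_model

    # Assumption: legacy team aliases point at "model_name_{team_id}_{uuid}".
    is_stale_team_alias = aliased_target.startswith(f"model_name_{team_id}_")
    # Sibling deployments exist if the team index knows the public name.
    has_siblings = (team_id, requested_model) in team_deployment_keys

    if is_stale_team_alias and has_siblings and bypass_enabled:
        # Keep the public name so all sibling deployments stay reachable.
        return requested_model

    # Default (backwards-compatible) behaviour: apply the alias rewrite.
    return aliased_target
```

With the flag off, the alias still wins, which matches the backwards-compatible default described in the table row.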
diff --git a/docs/my-website/docs/proxy/health.md b/docs/my-website/docs/proxy/health.md index 2764a6f0d4f..530bea3d06b 100644 --- a/docs/my-website/docs/proxy/health.md +++ b/docs/my-website/docs/proxy/health.md @@ -314,6 +314,89 @@ general_settings: health_check_details: False ``` +## Health Check Driven Routing + +By default, background health checks are observability-only — they populate the `/health` endpoint but don't affect routing. Unhealthy deployments still receive traffic until request failures trigger cooldown. + +With `enable_health_check_routing: true`, the router **excludes deployments that failed their last background health check** before selecting a candidate. This gives you proactive failover instead of reactive cooldown. + +### How it works + +1. Background health checks run on their configured interval +2. After each cycle, every deployment is marked healthy or unhealthy +3. On each incoming request, the router filters out unhealthy deployments **before** cooldown filtering and load balancing +4. If all deployments are unhealthy, the filter is bypassed (safety net — never causes a total outage) +5. 
If health state is stale (older than `health_check_staleness_threshold`), it is ignored + +### Quick start + +```yaml +model_list: + - model_name: gpt-4 + litellm_params: + model: openai/gpt-4 + api_key: os.environ/OPENAI_API_KEY + - model_name: gpt-4 + litellm_params: + model: openai/gpt-4 + api_key: os.environ/OPENAI_API_KEY_SECONDARY + +general_settings: + background_health_checks: true + health_check_interval: 60 + enable_health_check_routing: true +``` + +### Configuration + +| Setting | Where | Default | Description | +|---------|-------|---------|-------------| +| `enable_health_check_routing` | `general_settings` | `false` | Enable/disable health-check-driven routing | +| `health_check_staleness_threshold` | `general_settings` | `health_check_interval * 2` | Seconds before health state is considered stale and ignored | +| `background_health_checks` | `general_settings` | `false` | Must be `true` for health check routing to work | +| `health_check_interval` | `general_settings` | `300` | Seconds between health check cycles | + +### Interaction with cooldown + +Health check filtering and cooldown are **additive**. A deployment can be excluded by either mechanism: + +- **Health check filter** — proactive, runs on the configured interval, excludes deployments that failed the last check +- **Cooldown** — reactive, triggered by request failures, excludes deployments for a short TTL + +This means request failures still provide fast detection between health check intervals. + +### Staleness + +If a health check result is older than `health_check_staleness_threshold`, it is ignored and the deployment is treated as eligible. This prevents stale data from permanently excluding a deployment if the health check loop stops or slows down. + +The default staleness threshold is `health_check_interval * 2`. For a 60s interval, health state expires after 120s. 
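The filtering and staleness rules described above can be sketched as a small standalone function. This is an illustration under stated assumptions, not the router's implementation: `filter_by_health` is a hypothetical name, the state dicts use the `is_healthy` and `timestamp` keys written by the background health check, and treating a deployment with no recorded state as eligible is an assumption.

```python
import time
from typing import Dict, List, Optional


def filter_by_health(
    deployment_ids: List[str],
    health_states: Dict[str, dict],
    staleness_threshold: float,
    now: Optional[float] = None,
) -> List[str]:
    """Hypothetical helper: drop deployments whose last check failed and is still fresh."""
    now = time.time() if now is None else now
    eligible: List[str] = []
    for dep_id in deployment_ids:
        state = health_states.get(dep_id)
        if state is None:
            # Assumption: a deployment that was never checked stays eligible.
            eligible.append(dep_id)
            continue
        # Stale state is ignored, so the deployment is treated as eligible.
        is_stale = (now - state["timestamp"]) > staleness_threshold
        if state["is_healthy"] or is_stale:
            eligible.append(dep_id)
    # Safety net: if every deployment was filtered out, bypass the filter.
    return eligible if eligible else list(deployment_ids)
```

For example, with a 120s threshold, a deployment that failed its check 50s ago is excluded, while one that failed 200s ago is considered again.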
+ +### Example: custom staleness + +```yaml +general_settings: + background_health_checks: true + health_check_interval: 30 + enable_health_check_routing: true + health_check_staleness_threshold: 90 # ignore health state older than 90s +``` + +### Debugging + +Run the proxy with `--detailed_debug` and look for: + +``` +health_check_routing_state_updated healthy=3 unhealthy=1 +``` + +This is logged after each health check cycle when routing state is written. + +If the safety net triggers (all deployments unhealthy), you'll see: + +``` +All deployments marked unhealthy by health checks, bypassing health filter +``` + ## Health Check Timeout The health check timeout is set in `litellm/constants.py` and defaults to 60 seconds. diff --git a/docs/my-website/docs/proxy/load_balancing.md b/docs/my-website/docs/proxy/load_balancing.md index 74b3e8a5117..93f3d944340 100644 --- a/docs/my-website/docs/proxy/load_balancing.md +++ b/docs/my-website/docs/proxy/load_balancing.md @@ -324,17 +324,58 @@ model_list: litellm_params: model: azure/gpt-4-fallback api_key: os.environ/AZURE_API_KEY_2 - order: 2 # 👈 Used when order=1 is unavailable + order: 2 # 👈 Used when order=1 fails +``` + +### How order-based fallback works + +When a request to an `order=1` deployment fails (connection error, 404, 429, etc.), the router automatically tries `order=2` deployments, then `order=3`, and so on. Each order level gets its own set of retries before escalating to the next. + +If all order levels are exhausted, the router falls through to any configured [model-level fallbacks](#fallbacks). 
+ +```yaml +model_list: + - model_name: gpt-4 + litellm_params: + model: azure/gpt-4-primary + api_key: os.environ/AZURE_API_KEY + order: 1 + + - model_name: gpt-4 + litellm_params: + model: azure/gpt-4-secondary + api_key: os.environ/AZURE_API_KEY_2 + order: 2 + + - model_name: gpt-4-fallback + litellm_params: + model: openai/gpt-4 + api_key: os.environ/OPENAI_API_KEY router_settings: - enable_pre_call_checks: true # 👈 Required for 'order' to work + fallbacks: + - gpt-4: + - gpt-4-fallback # tried after all order levels fail ``` -:::important -The `order` parameter requires `enable_pre_call_checks: true` in `router_settings`. -::: +The fallback chain for the above config: `order=1` → `order=2` → `gpt-4-fallback`. + +For 429 (rate limit) errors specifically, the failed deployment is immediately placed on cooldown. If all `order=1` deployments are on cooldown, the router picks `order=2` deployments directly during retries without waiting for the fallback path. + +### Team-scoped models and legacy `model_aliases` {#team-scoped-models-and-legacy-model_aliases} + +Team-scoped deployments are identified by `model_info.team_id` and `model_info.team_public_model_name`. Requests should use the **public** model name; the router resolves all sibling deployments (same public name, different `api_base` / `order`, etc.) for routing, failover, and deployment `order`. + +For router internals: when a `team_id` is in scope, optimized lookups key off `(team_id, team_public_model_name)`. If code passes an internal deployment id (e.g. `model_name_{team_id}_{uuid}`) instead of the public name, routing still works via the usual deployment-name paths, but the team-specific fast path applies only to the public name. + +**Legacy teams:** Older proxy versions could persist `model_aliases` on the team row mapping a public name to a single internal deployment id (`model_name_{team_id}_{uuid}`). 
On each request, pre-call logic may still rewrite `model` to that internal name **before** routing, which collapses to one deployment and can make newer sibling deployments unreachable. + +**Migration options:** + +1. **Recommended for upgrades:** Set environment variable `LITELLM_ENABLE_TEAM_STALE_ALIAS_BYPASS=true` so that when sibling team deployments exist for the public name, the stale alias rewrite is skipped and team-scoped routing (including `order` and failover) applies. See the [Environment variables](./config_settings) table in the proxy settings doc. +2. **Data cleanup:** Remove obsolete `model_aliases` entries for team public names from the team record in the database so only `team_public_model_name` + team model list drive access. -If `order=1` deployment is unavailable (e.g., rate-limited), the router falls back to `order=2` deployments. +If a stale alias is detected and the bypass is **not** enabled, the proxy may emit a **one-time** warning in logs explaining that sibling deployments may be unreachable until the flag is set or aliases are cleaned up. ### When You'll See Load Balancing in Action diff --git a/docs/my-website/docs/routing.md b/docs/my-website/docs/routing.md index 67e7f681147..5aa655ae212 100644 --- a/docs/my-website/docs/routing.md +++ b/docs/my-website/docs/routing.md @@ -842,6 +842,8 @@ Traffic mirroring allows you to "mimic" production traffic to a secondary (silen Set `order` in `litellm_params` to prioritize deployments. Lower values = higher priority. When multiple deployments share the same `order`, the routing strategy picks among them. +When a request to an `order=1` deployment fails (connection error, 404, 429, etc.), the router automatically tries `order=2` deployments, then `order=3`, and so on. Each order level gets its own set of retries before escalating to the next. If all order levels are exhausted, the router falls through to any configured [fallbacks](#fallbacks). 
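The escalation through order levels described above can be sketched as a function that builds the list of targets to try next. This is a simplified illustration under stated assumptions, not the router's actual method: `build_fallback_chain` is a hypothetical name, `external_fallbacks` stands for the already-resolved model-level fallback list, and the exemption for context-window and content-policy errors is left out.

```python
from typing import Any, Dict, List, Optional


def build_fallback_chain(
    deployments: List[Dict[str, Any]],
    model_group: str,
    external_fallbacks: Optional[List[Any]] = None,
    current_target_order: Optional[int] = None,
) -> List[Any]:
    """Hypothetical helper: remaining order levels first, then model-level fallbacks."""
    # Collect the distinct order values configured for this model group.
    order_values = sorted(
        {
            d.get("litellm_params", {}).get("order")
            for d in deployments
            if d.get("litellm_params", {}).get("order") is not None
        }
    )

    order_entries: List[Any] = []
    if len(order_values) > 1:
        # Levels up to the one just attempted are skipped; the lowest
        # order is assumed to have been tried on the first attempt.
        skip_up_to = (
            current_target_order
            if current_target_order is not None
            else order_values[0]
        )
        order_entries = [
            {"model": model_group, "_target_order": o}
            for o in order_values
            if o > skip_up_to
        ]

    # Order-based entries come before any configured model-level fallbacks.
    return order_entries + list(external_fallbacks or [])
```

For deployments with `order` 1 and 2 plus a `gpt-4-fallback` model-level fallback, the first failure yields the `order=2` entry followed by `gpt-4-fallback`; once `order=2` has been the target, only `gpt-4-fallback` remains.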
+ @@ -862,18 +864,14 @@ model_list = [ "litellm_params": { "model": "azure/gpt-4-fallback", "api_key": os.getenv("AZURE_API_KEY_2"), - "order": 2, # 👈 Used when order=1 is unavailable + "order": 2, # 👈 Tried when order=1 fails }, }, ] -router = Router(model_list=model_list, enable_pre_call_checks=True) # 👈 Required for 'order' to work +router = Router(model_list=model_list) ``` -:::important -The `order` parameter requires `enable_pre_call_checks=True` to be set on the Router. -::: - @@ -889,10 +887,7 @@ model_list: litellm_params: model: azure/gpt-4-fallback api_key: os.environ/AZURE_API_KEY_2 - order: 2 # 👈 Used when order=1 is unavailable - -router_settings: - enable_pre_call_checks: true # 👈 Required for 'order' to work + order: 2 # 👈 Tried when order=1 fails ``` diff --git a/litellm/constants.py b/litellm/constants.py index 423f01afac1..252068bd7b0 100644 --- a/litellm/constants.py +++ b/litellm/constants.py @@ -1402,6 +1402,9 @@ DEFAULT_SHARED_HEALTH_CHECK_LOCK_TTL = int( os.getenv("DEFAULT_SHARED_HEALTH_CHECK_LOCK_TTL", 60) ) # 1 minute - TTL for health check lock +DEFAULT_HEALTH_CHECK_STALENESS_MULTIPLIER = ( + 2 # health state is stale after interval * this +) PROMETHEUS_FALLBACK_STATS_SEND_TIME_HOURS = int( os.getenv("PROMETHEUS_FALLBACK_STATS_SEND_TIME_HOURS", 9) ) diff --git a/litellm/proxy/health_check.py b/litellm/proxy/health_check.py index a8d0e3e9af2..3e05ee3c484 100644 --- a/litellm/proxy/health_check.py +++ b/litellm/proxy/health_check.py @@ -207,21 +207,65 @@ async def _perform_health_check( for is_healthy, model in zip(results, model_list): litellm_params = model["litellm_params"] + _model_id = (model.get("model_info") or {}).get("id") if isinstance(is_healthy, dict) and "error" not in is_healthy: - healthy_endpoints.append( - _clean_endpoint_data({**litellm_params, **is_healthy}, details) - ) + cleaned = _clean_endpoint_data({**litellm_params, **is_healthy}, details) + if _model_id: + cleaned["model_id"] = _model_id + 
healthy_endpoints.append(cleaned) elif isinstance(is_healthy, dict): - unhealthy_endpoints.append( - _clean_endpoint_data({**litellm_params, **is_healthy}, details) - ) + cleaned = _clean_endpoint_data({**litellm_params, **is_healthy}, details) + if _model_id: + cleaned["model_id"] = _model_id + unhealthy_endpoints.append(cleaned) else: - unhealthy_endpoints.append(_clean_endpoint_data(litellm_params, details)) + cleaned = _clean_endpoint_data(litellm_params, details) + if _model_id: + cleaned["model_id"] = _model_id + unhealthy_endpoints.append(cleaned) return healthy_endpoints, unhealthy_endpoints +def build_deployment_health_states( + healthy_endpoints: list, + unhealthy_endpoints: list, +) -> dict: + """ + Build a dict mapping deployment_id -> DeploymentHealthStateValue from + health check endpoint results. + + Each endpoint dict includes a 'model_id' field (added by _perform_health_check) + that maps back to the deployment's model_info.id. + + Used by the background health check loop to feed health state into + the router's DeploymentHealthCache for health-check-driven routing. 
+ """ + now = time.time() + states: dict = {} + + for ep in healthy_endpoints: + model_id = ep.get("model_id") + if model_id: + states[model_id] = { + "is_healthy": True, + "timestamp": now, + "reason": "", + } + + for ep in unhealthy_endpoints: + model_id = ep.get("model_id") + if model_id: + states[model_id] = { + "is_healthy": False, + "timestamp": now, + "reason": "background_health_check_failed", + } + + return states + + def _update_litellm_params_for_health_check( model_info: dict, litellm_params: dict ) -> dict: diff --git a/litellm/proxy/litellm_pre_call_utils.py b/litellm/proxy/litellm_pre_call_utils.py index 4ca0d876a1c..ba9577f35d7 100644 --- a/litellm/proxy/litellm_pre_call_utils.py +++ b/litellm/proxy/litellm_pre_call_utils.py @@ -1,6 +1,7 @@ import asyncio import copy import time +from collections import OrderedDict from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union from fastapi import Request @@ -26,6 +27,7 @@ v.value.lower() for v in SpecialHeaders._member_map_.values() ) from litellm.router import Router +from litellm.secret_managers.main import get_secret_bool from litellm.types.llms.anthropic import ANTHROPIC_API_HEADERS from litellm.types.services import ServiceTypes from litellm.types.utils import ( @@ -36,6 +38,11 @@ ) service_logger_obj = ServiceLogging() # used for tracking latency on OTEL +# Bounded dedup for stale-alias warnings (FIFO eviction when over cap). +_MAX_STALE_ALIAS_WARNING_KEYS = 10_000 +_STALE_TEAM_ALIAS_WARNING_KEYS: OrderedDict[str, None] = OrderedDict() +# Cache the stale alias bypass flag at module load to avoid hot-path secret lookups +_ENABLE_TEAM_STALE_ALIAS_BYPASS: Optional[bool] = None if TYPE_CHECKING: @@ -1296,6 +1303,10 @@ def _update_model_if_team_alias_exists( "gpt-4o": "gpt-4o-team-1" } - requested_model = "gpt-4o-team-1" + + Note: model_aliases for team models are deprecated. This function only applies + to legacy non-team-scoped aliases. 
Team-scoped deployments use team_public_model_name + and are resolved via map_team_model in route_llm_request. """ _model = data.get("model") if ( @@ -1303,7 +1314,54 @@ def _update_model_if_team_alias_exists( and user_api_key_dict.team_model_aliases and _model in user_api_key_dict.team_model_aliases ): - data["model"] = user_api_key_dict.team_model_aliases[_model] + from litellm.proxy.proxy_server import llm_router + + # Skip alias rewrite if this model resolves to team-specific deployments + # (team models use team_public_model_name, not model_aliases) + aliased_target = user_api_key_dict.team_model_aliases[_model] + + # Optional bypass for stale aliases from pre-PR deployments: + # only enabled via feature flag to preserve backwards compatibility. + # Cached at module level to avoid hot-path secret lookups on every request. + global _ENABLE_TEAM_STALE_ALIAS_BYPASS + if _ENABLE_TEAM_STALE_ALIAS_BYPASS is None: + _ENABLE_TEAM_STALE_ALIAS_BYPASS = get_secret_bool( + "LITELLM_ENABLE_TEAM_STALE_ALIAS_BYPASS", False + ) + enable_stale_alias_bypass = _ENABLE_TEAM_STALE_ALIAS_BYPASS + # Check if the alias points to a team-scoped UUID name + # (format: "model_name_{team_id}_{uuid}") + is_stale_team_alias = aliased_target.startswith( + f"model_name_{user_api_key_dict.team_id}_" + ) + if is_stale_team_alias and llm_router: + # This is a stale alias from pre-PR deployments. + # Check if current team deployments exist for the public name. 
+ key = (user_api_key_dict.team_id, _model) + if key in llm_router.team_model_to_deployment_indices: + if enable_stale_alias_bypass: + # Team deployments exist; skip stale alias + return + warning_key = f"{user_api_key_dict.team_id}:{_model}:{aliased_target}" + if warning_key not in _STALE_TEAM_ALIAS_WARNING_KEYS: + _STALE_TEAM_ALIAS_WARNING_KEYS[warning_key] = None + while ( + len(_STALE_TEAM_ALIAS_WARNING_KEYS) + > _MAX_STALE_ALIAS_WARNING_KEYS + ): + _STALE_TEAM_ALIAS_WARNING_KEYS.popitem(last=False) + verbose_proxy_logger.warning( + "Stale team model alias detected for model='%s', team_id='%s'. " + "New sibling deployments may be unreachable. " + "Set LITELLM_ENABLE_TEAM_STALE_ALIAS_BYPASS=true to enable " + "team-scoped sibling routing.", + str(_model).replace("\n", "").replace("\r", ""), + str(user_api_key_dict.team_id) + .replace("\n", "") + .replace("\r", ""), + ) + + data["model"] = aliased_target return diff --git a/litellm/proxy/management_endpoints/model_management_endpoints.py b/litellm/proxy/management_endpoints/model_management_endpoints.py index 44d41097833..2ca8e3daba3 100644 --- a/litellm/proxy/management_endpoints/model_management_endpoints.py +++ b/litellm/proxy/management_endpoints/model_management_endpoints.py @@ -13,13 +13,13 @@ import asyncio import datetime import json -from litellm._uuid import uuid from typing import Dict, List, Literal, Optional, Tuple, Union, cast from fastapi import APIRouter, Depends, HTTPException, Request, status from pydantic import BaseModel, ConfigDict, Field from litellm._logging import verbose_proxy_logger +from litellm._uuid import uuid from litellm.constants import LITELLM_PROXY_ADMIN_NAME from litellm.proxy._types import ( CommonProxyErrors, @@ -32,7 +32,7 @@ ProxyErrorTypes, ProxyException, TeamModelAddRequest, - UpdateTeamRequest, + TeamModelDeleteRequest, UserAPIKeyAuth, ) from litellm.proxy.auth.user_api_key_auth import user_api_key_auth @@ -40,7 +40,10 @@ from 
litellm.proxy.management_endpoints.common_utils import _is_user_team_admin from litellm.proxy.management_endpoints.team_endpoints import ( team_model_add, - update_team, + team_model_delete, +) +from litellm.proxy.management_endpoints.team_endpoints import ( + update_team as _legacy_update_team, ) from litellm.proxy.management_helpers.audit_logs import create_object_audit_log from litellm.proxy.utils import PrismaClient @@ -58,6 +61,14 @@ router = APIRouter() +async def update_team(*args, **kwargs): + """ + Backward-compatible shim for tests/legacy call sites that patch this symbol. + Team model management now uses team_model_add/team_model_delete directly. + """ + return await _legacy_update_team(*args, **kwargs) + + class UpdatePublicModelGroupsRequest(BaseModel): """Request model for updating public model groups""" @@ -324,17 +335,24 @@ async def _add_team_model_to_db( - generate a unique 'model_name' for the model (e.g. 'model_name_{team_id}_{uuid}) - store the model in the db with the unique 'model_name' - - store a team model alias mapping {"model_name": "model_name_{team_id}_{uuid}"} + - add the public model name to the team's allowed models list """ _team_id = model_params.model_info.team_id if _team_id is None: return None + + # Capture the original public name FIRST, before any mutations original_model_name = model_params.model_name + + # Set team_public_model_name in model_info using the captured original_model_name + # This must happen BEFORE mutating model_params.model_name so _add_model_to_db + # serializes the correct team_public_model_name (not the internal UUID name) if original_model_name: model_params.model_info.team_public_model_name = original_model_name + # Generate and assign unique internal model_name LAST + # (after team_public_model_name is safely stored) unique_model_name = f"model_name_{_team_id}_{uuid.uuid4()}" - model_params.model_name = unique_model_name ## CREATE MODEL IN DB ## @@ -344,25 +362,15 @@ async def _add_team_model_to_db( 
prisma_client=prisma_client, ) - ## CREATE MODEL ALIAS IN DB ## - await update_team( - data=UpdateTeamRequest( - team_id=_team_id, - model_aliases={original_model_name: unique_model_name}, - ), - user_api_key_dict=user_api_key_dict, - http_request=Request(scope={"type": "http"}), - ) - - # add model to team object - await team_model_add( - data=TeamModelAddRequest( - team_id=_team_id, - models=[original_model_name], - ), - http_request=Request(scope={"type": "http"}), - user_api_key_dict=user_api_key_dict, - ) + if original_model_name: + await team_model_add( + data=TeamModelAddRequest( + team_id=_team_id, + models=[original_model_name], + ), + http_request=Request(scope={"type": "http"}), + user_api_key_dict=user_api_key_dict, + ) return model_response @@ -428,6 +436,7 @@ async def _update_team_model_in_db( db_model=db_model, patch_data=patch_data, user_api_key_dict=user_api_key_dict, + prisma_client=prisma_client, ) return update_db_model(db_model=db_model, updated_patch=patch_data) @@ -453,19 +462,10 @@ async def _setup_new_team_model_assignment( patch_data: updateDeployment, user_api_key_dict: UserAPIKeyAuth, ) -> None: - """Set up a new team model with unique name, alias, and team membership.""" + """Set up a new team model with unique name and team membership.""" unique_model_name = f"model_name_{team_id}_{uuid.uuid4()}" patch_data.model_name = unique_model_name - await update_team( - data=UpdateTeamRequest( - team_id=team_id, - model_aliases={public_model_name: unique_model_name}, - ), - user_api_key_dict=user_api_key_dict, - http_request=Request(scope={"type": "http"}), - ) - await team_model_add( data=TeamModelAddRequest( team_id=team_id, @@ -476,30 +476,119 @@ async def _setup_new_team_model_assignment( ) +async def _get_team_deployments( + team_id: str, prisma_client: PrismaClient +) -> List[LiteLLM_ProxyModelTable]: + """ + Fetch all deployments for a given team_id from the database. 
+ + Centralizes team deployment queries to ensure consistent filtering and error handling. + This is the established helper pattern for team deployment DB access in this module. + + Note: Direct Prisma call is intentional here as this IS the helper function that + encapsulates the DB access pattern for team deployments. + """ + response = await prisma_client.db.litellm_proxymodeltable.find_many( + where={ + "model_info": { + "path": ["team_id"], + "equals": team_id, + } + } + ) + return response if response else [] + + async def _update_existing_team_model_assignment( team_id: str, public_model_name: str, db_model: Deployment, patch_data: updateDeployment, user_api_key_dict: UserAPIKeyAuth, + prisma_client: Optional[PrismaClient], ) -> None: - """Update an existing team model if the public name changed.""" + """Update an existing team model if the public name changed. + + Note on DB scan: Prisma's JSON filtering does not support compound AND conditions + across multiple JSON paths, so we fetch all deployments for the team and filter + team_public_model_name in Python. For teams with many deployments this scan grows + linearly; if team deployment counts become large this should be revisited. 
+ """ + + def _get_team_public_model_name( + model_info: Optional[Union[dict, str]] + ) -> Optional[str]: + if isinstance(model_info, dict): + value = model_info.get("team_public_model_name") + return value if isinstance(value, str) else None + if isinstance(model_info, str): + try: + parsed = json.loads(model_info) + except (TypeError, ValueError): + return None + if isinstance(parsed, dict): + value = parsed.get("team_public_model_name") + return value if isinstance(value, str) else None + return None + old_public_name = ( db_model.model_info.team_public_model_name if db_model.model_info else None ) - # Update alias only if public name changed if old_public_name and public_model_name != old_public_name: - await update_team( - data=UpdateTeamRequest( + # Clear user-supplied public name from patch before any early return so the + # caller does not overwrite the internal UUID-based model_name in the DB. + patch_data.model_name = None + if prisma_client is None: + verbose_proxy_logger.warning( + "prisma_client not initialized; skipping public name update entirely to avoid orphaned entries" + ) + return + + # Query DB for all team deployments to check for sibling deployments + team_deployments = await _get_team_deployments(team_id, prisma_client) + other_deployments_with_old_name = [ + d + for d in team_deployments + if d.model_name != db_model.model_name + and _get_team_public_model_name(d.model_info) == old_public_name + ] + + # Add new name first, then delete old name to prevent access loss on partial failure + await team_model_add( + data=TeamModelAddRequest( team_id=team_id, - model_aliases={public_model_name: db_model.model_name}, + models=[public_model_name], ), + http_request=Request(scope={"type": "http"}), user_api_key_dict=user_api_key_dict, + ) + + if not other_deployments_with_old_name: + await team_model_delete( + data=TeamModelDeleteRequest( + team_id=team_id, + models=[old_public_name], + ), + http_request=Request(scope={"type": "http"}), + 
user_api_key_dict=user_api_key_dict, + ) + elif not old_public_name and public_model_name: + # First-time assignment of public name on an existing team deployment: + # ensure the team's models list is updated so team routing can resolve it. + await team_model_add( + data=TeamModelAddRequest( + team_id=team_id, + models=[public_model_name], + ), http_request=Request(scope={"type": "http"}), + user_api_key_dict=user_api_key_dict, ) + # else: old_public_name == public_model_name (no rename needed) + # No team_model_add/delete calls required; public name is already registered - # Keep existing unique model_name + # Always clear patch_data.model_name to prevent caller from overwriting + # the internal UUID-based model_name in the DB with the user-supplied public name patch_data.model_name = None diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 7d3d2ceb533..28e613ef487 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -480,11 +480,11 @@ def generate_feedback_box(): router as search_tool_management_router, ) from litellm.proxy.spend_tracking.cloudzero_endpoints import router as cloudzero_router -from litellm.proxy.spend_tracking.vantage_endpoints import router as vantage_router from litellm.proxy.spend_tracking.spend_management_endpoints import ( router as spend_management_router, ) from litellm.proxy.spend_tracking.spend_tracking_utils import get_logging_payload +from litellm.proxy.spend_tracking.vantage_endpoints import router as vantage_router from litellm.proxy.types_utils.utils import get_instance_fn from litellm.proxy.ui_crud_endpoints.proxy_setting_endpoints import ( router as ui_crud_endpoints_router, @@ -2112,6 +2112,37 @@ def _schedule_background_health_check_db_save( ) +def _write_health_state_to_router_cache( + healthy_endpoints: list, + unhealthy_endpoints: list, +) -> None: + """ + Write deployment health states to the router's health state cache + for health-check-driven routing. 
No-op if the feature is disabled. + """ + from litellm.proxy.health_check import build_deployment_health_states + + try: + if llm_router is None or not llm_router.enable_health_check_routing: + return + + states = build_deployment_health_states( + healthy_endpoints=healthy_endpoints, + unhealthy_endpoints=unhealthy_endpoints, + ) + if states: + llm_router.health_state_cache.set_deployment_health_states(states) + verbose_proxy_logger.debug( + "health_check_routing_state_updated healthy=%d unhealthy=%d", + sum(1 for s in states.values() if s.get("is_healthy")), + sum(1 for s in states.values() if not s.get("is_healthy")), + ) + except Exception as e: + verbose_proxy_logger.warning( + "Failed to write health state to router cache: %s", str(e) + ) + + async def _run_background_health_check(): """ Periodically run health checks in the background on the endpoints. @@ -2281,6 +2312,9 @@ async def _run_background_health_check(): unhealthy_endpoints, ) + # Write health state to router cache for health-check-driven routing + _write_health_state_to_router_cache(healthy_endpoints, unhealthy_endpoints) + await asyncio.sleep(health_check_interval) @@ -3048,6 +3082,8 @@ async def load_config( # noqa: PLR0915 general_settings = config.get("general_settings", {}) if general_settings is None: general_settings = {} + _enable_hc_routing = False + _hc_staleness = None if general_settings: ### LOAD KEY MANAGEMENT SETTINGS FIRST (needed for custom secret manager) ### key_management_settings = general_settings.get( @@ -3227,13 +3263,21 @@ async def load_config( # noqa: PLR0915 "health_check_concurrency", None ) health_check_details = general_settings.get("health_check_details", True) + # Health-check-driven routing (opt-in, passes through to Router later) + _enable_hc_routing = general_settings.get( + "enable_health_check_routing", False + ) + _hc_staleness = general_settings.get( + "health_check_staleness_threshold", None + ) verbose_proxy_logger.info( - "background_health_check_config 
enabled=%s shared=%s interval_seconds=%s max_concurrency=%s details=%s", + "background_health_check_config enabled=%s shared=%s interval_seconds=%s max_concurrency=%s details=%s health_check_routing=%s", use_background_health_checks, use_shared_health_check, health_check_interval, health_check_concurrency, health_check_details, + _enable_hc_routing, ) ### RBAC ### @@ -3263,6 +3307,11 @@ async def load_config( # noqa: PLR0915 "cache_responses": litellm.cache is not None, # cache if user passed in cache values } + # Health-check-driven routing params (from general_settings) + if _enable_hc_routing: + router_params["enable_health_check_routing"] = True + if _hc_staleness is not None: + router_params["health_check_staleness_threshold"] = _hc_staleness ## MODEL LIST model_list = config.get("model_list", None) if model_list: diff --git a/litellm/router.py b/litellm/router.py index 25e5c9cb5d9..6cc6bad9def 100644 --- a/litellm/router.py +++ b/litellm/router.py @@ -54,7 +54,11 @@ RedisCache, RedisClusterCache, ) -from litellm.constants import DEFAULT_MAX_LRU_CACHE_SIZE +from litellm.constants import ( + DEFAULT_HEALTH_CHECK_INTERVAL, + DEFAULT_HEALTH_CHECK_STALENESS_MULTIPLIER, + DEFAULT_MAX_LRU_CACHE_SIZE, +) from litellm.integrations.custom_logger import CustomLogger from litellm.litellm_core_utils.asyncify import run_async_function from litellm.litellm_core_utils.core_helpers import ( @@ -113,6 +117,7 @@ async_raise_no_deployment_exception, send_llm_exception_alert, ) +from litellm.router_utils.health_state_cache import DeploymentHealthCache from litellm.router_utils.pre_call_checks.deployment_affinity_check import ( DeploymentAffinityCheck, ) @@ -303,6 +308,8 @@ def __init__( # noqa: PLR0915 deployment_affinity_ttl_seconds: int = 3600, model_group_affinity_config: Optional[Dict[str, List[str]]] = None, ignore_invalid_deployments: bool = False, + enable_health_check_routing: bool = False, + health_check_staleness_threshold: Optional[int] = None, ) -> None: """ 
Initialize the Router class with the given parameters for caching, reliability, and routing strategy. @@ -467,6 +474,8 @@ def __init__( # noqa: PLR0915 # Initialize model name to deployment indices mapping for O(1) lookups # Maps model_name -> list of indices in model_list self.model_name_to_deployment_indices: Dict[str, List[int]] = {} + # Maps (team_id, team_public_model_name) -> list of indices in model_list + self.team_model_to_deployment_indices: Dict[Tuple[str, str], List[int]] = {} if model_list is not None: # set_model_list will build indices automatically @@ -491,6 +500,13 @@ def __init__( # noqa: PLR0915 cache=self.cache, default_cooldown_time=self.cooldown_time ) self.disable_cooldowns = disable_cooldowns + self.enable_health_check_routing = enable_health_check_routing + _staleness = health_check_staleness_threshold or ( + DEFAULT_HEALTH_CHECK_INTERVAL * DEFAULT_HEALTH_CHECK_STALENESS_MULTIPLIER + ) + self.health_state_cache = DeploymentHealthCache( + cache=self.cache, staleness_threshold=float(_staleness) + ) self.failed_calls = ( InMemoryCache() ) # cache to track failed call per deployment, if num failed calls within 1 minute > allowed fails, then add it to cooldown @@ -5288,6 +5304,64 @@ async def async_function_with_fallbacks_common_utils( # noqa: PLR0915 if "fallback_depth" not in input_kwargs: input_kwargs["fallback_depth"] = 0 + # ORDER-BASED FALLBACKS: prepend higher order levels to the fallback list + # Skip for error types that have their own dedicated fallback handlers + _skip_order_fallback = isinstance( + e, + (litellm.ContextWindowExceededError, litellm.ContentPolicyViolationError), + ) + all_deployments = self._get_all_deployments(model_name=original_model_group) + _order_set: set = { + d.get("litellm_params", {}).get("order") + for d in all_deployments + if d.get("litellm_params", {}).get("order") is not None + } + order_values: list = sorted(_order_set) + if len(order_values) > 1 and not _skip_order_fallback: + # Determine which order 
levels have already been tried + current_target = kwargs.get("_target_order") + skip_up_to = ( + current_target if current_target is not None else order_values[0] + ) + # Build order-based fallback entries (skip already-tried levels) + order_fallback_entries: List = [ + {"model": original_model_group, "_target_order": o} + for o in order_values + if o > skip_up_to + ] + # Get external fallbacks — handle both standard and non-standard formats + external_fallback_group: Optional[List] = None + if fallbacks is not None and model_group is not None: + if _check_non_standard_fallback_format(fallbacks=fallbacks): + # Non-standard formats (e.g. ["claude-3-haiku"] or + # [{"model": "...", "messages": [...]}]) are passed through directly + external_fallback_group = fallbacks + else: + external_fallback_group, generic_idx = get_fallback_model_group( + fallbacks=fallbacks, + model_group=cast(str, model_group), + ) + if external_fallback_group is None and generic_idx is not None: + external_fallback_group = fallbacks[generic_idx]["*"] + + # Combined list: order fallbacks first, then external + combined_fallbacks = order_fallback_entries + ( + external_fallback_group or [] + ) + + if combined_fallbacks: + input_kwargs.update( + { + "fallback_model_group": combined_fallbacks, + "original_model_group": original_model_group, + } + ) + response = await run_async_fallback( + *args, + **input_kwargs, + ) + return response + try: verbose_router_logger.info("Trying to fallback b/w models") @@ -6835,6 +6909,7 @@ def set_model_list(self, model_list: list): self.model_list = [] self.model_id_to_deployment_index_map = {} # Reset the index self.model_name_to_deployment_indices = {} # Reset the model_name index + self.team_model_to_deployment_indices = {} # Reset the team_model index self._invalidate_model_group_info_cache() self._invalidate_access_groups_cache() # we add api_base/api_key each model so load balancing between azure/gpt on api_base1 and api_base2 works @@ -7132,16 +7207,17 @@ 
def _update_deployment_indices_after_removal( # Update model_name_to_deployment_indices for model_name, indices in list(self.model_name_to_deployment_indices.items()): - # Remove the deleted index - if removal_idx in indices: - indices.remove(removal_idx) - - # Decrement all indices greater than removal_idx + # Build new list without mutating the original updated_indices = [] for idx in indices: - if idx > removal_idx: + if idx == removal_idx: + # Skip the removed index + continue + elif idx > removal_idx: + # Decrement indices after removal updated_indices.append(idx - 1) else: + # Keep indices before removal unchanged updated_indices.append(idx) # Update or remove the entry @@ -7150,6 +7226,46 @@ def _update_deployment_indices_after_removal( else: del self.model_name_to_deployment_indices[model_name] + # Update team_model_to_deployment_indices + for key, indices in list(self.team_model_to_deployment_indices.items()): + # Build new list without mutating the original + updated_indices = [] + for idx in indices: + if idx == removal_idx: + # Skip the removed index + continue + elif idx > removal_idx: + # Decrement indices after removal + updated_indices.append(idx - 1) + else: + # Keep indices before removal unchanged + updated_indices.append(idx) + + # Update or remove the entry + if len(updated_indices) > 0: + self.team_model_to_deployment_indices[key] = updated_indices + else: + del self.team_model_to_deployment_indices[key] + + def _update_team_model_index(self, model: dict, idx: int) -> None: + """ + Helper to update team_model_to_deployment_indices for a single deployment. 
+ + Parameters: + - model: dict - the deployment to index + - idx: int - the index in model_list + """ + team_id = (model.get("model_info") or {}).get("team_id") + team_public_model_name = (model.get("model_info") or {}).get( + "team_public_model_name" + ) + if team_id and team_public_model_name: + key = (team_id, team_public_model_name) + if key not in self.team_model_to_deployment_indices: + self.team_model_to_deployment_indices[key] = [] + if idx not in self.team_model_to_deployment_indices[key]: + self.team_model_to_deployment_indices[key].append(idx) + def _add_model_to_list_and_index_map( self, model: dict, model_id: Optional[str] = None ) -> None: @@ -7178,6 +7294,9 @@ def _add_model_to_list_and_index_map( self.model_name_to_deployment_indices[model_name] = [] self.model_name_to_deployment_indices[model_name].append(idx) + # Update team_model index for O(1) team-scoped lookup + self._update_team_model_index(model, idx) + def upsert_deployment(self, deployment: Deployment) -> Optional[Deployment]: """ Add or update deployment @@ -7196,7 +7315,10 @@ def upsert_deployment(self, deployment: Deployment) -> Optional[Deployment]: ) if _deployment_on_router is not None: # deployment with this model_id exists on the router - if deployment.litellm_params == _deployment_on_router.litellm_params: + if ( + deployment.litellm_params == _deployment_on_router.litellm_params + and deployment.model_info == _deployment_on_router.model_info + ): # No need to update return None @@ -8008,6 +8130,7 @@ def _build_model_name_index(self, model_list: list) -> None: instead of O(n) linear scan through the entire model_list. 
""" self.model_name_to_deployment_indices.clear() + self.team_model_to_deployment_indices.clear() for idx, model in enumerate(model_list): model_name = model.get("model_name") @@ -8016,6 +8139,8 @@ def _build_model_name_index(self, model_list: list) -> None: self.model_name_to_deployment_indices[model_name] = [] self.model_name_to_deployment_indices[model_name].append(idx) + self._update_team_model_index(model, idx) + def _build_model_id_to_deployment_index_map(self, model_list: list): """ Build model index from model list to enable O(1) lookups immediately. @@ -8148,20 +8273,25 @@ def resolve_model_name_from_model_id( def map_team_model(self, team_model_name: str, team_id: str) -> Optional[str]: """ - Map a team model name to a team-specific model name. + Check if team_model_name resolves to team-specific deployments. + + Returns the public model name (unchanged) so the router can find all + sibling deployments via team_id filtering, instead of collapsing to a + single internal model_name. Returns: - - deployment id: str - the deployment id of the team-specific model - - None: if no team-specific model name is found + - str: the team_model_name if team deployments exist for this team + - None: if no team-specific model is found """ models = self.get_model_list(model_name=team_model_name, team_id=team_id) if not models: return None for model in models: if model.get("model_info", {}).get("team_id") == team_id: - return model.get("model_name") + return team_model_name - ## wildcard models + # No team-scoped deployment found; wildcard/pattern routes are + # handled downstream by the pattern_router in _common_checks_available_deployment. 
return None def should_include_deployment( @@ -8172,12 +8302,22 @@ def should_include_deployment( """ if ( team_id is not None - and model["model_info"].get("team_id") == team_id - and model_name == model["model_info"].get("team_public_model_name") + and (model.get("model_info") or {}).get("team_id") == team_id + and model_name + == (model.get("model_info") or {}).get("team_public_model_name") ): return True elif model_name is not None and model["model_name"] == model_name: - return True + # Fallback: check by internal model_name for non-team deployments + # or deployments that haven't been migrated to team_public_model_name yet + model_team_id = (model.get("model_info") or {}).get("team_id") + if ( + team_id is None # requester has no team constraint + or model_team_id is None # global deployment - accessible to all teams + or model_team_id == team_id # deployment belongs to requester's team + ): + return True + # No match: deployment is for a different team or doesn't match the requested model return False def _get_all_deployments( @@ -8194,9 +8334,36 @@ def _get_all_deployments( if team_id specified, only return team-specific models Optimized with O(1) index lookup instead of O(n) linear scan. + + Note: when team_id is provided, O(1) lookup in + `team_model_to_deployment_indices` only applies when `model_name` is the + team public model name. If a caller passes an internal deployment model + name (for example, `model_name__`), this method falls back + to the standard model-name index / scan path. 
""" returned_models: List[DeploymentTypedDict] = [] + # O(1) lookup in team_model index when team_id is provided + if team_id is not None: + key = (team_id, model_name) + if key in self.team_model_to_deployment_indices: + indices = self.team_model_to_deployment_indices[key] + # O(k) where k = team deployments for this model_name (typically 1-10) + for idx in indices: + model = self.model_list[idx] + if not self.should_include_deployment( + model_name=model_name, model=model, team_id=team_id + ): + continue + if model_alias is not None: + alias_model = model.copy() + alias_model["model_name"] = model_alias + returned_models.append(alias_model) + else: + returned_models.append(model) + if returned_models: + return returned_models + # O(1) lookup in model_name index if model_name in self.model_name_to_deployment_indices: indices = self.model_name_to_deployment_indices[model_name] @@ -8791,12 +8958,6 @@ def _pre_call_checks( # noqa: PLR0915 if i not in invalid_model_indices ] - ## ORDER FILTERING ## -> if user set 'order' in deployments, return deployments with lowest order (e.g. order=1 > order=2) - if len(_returned_deployments) > 0: - _returned_deployments = litellm.utils._get_order_filtered_deployments( - _returned_deployments - ) - return _returned_deployments def _get_model_from_alias(self, model: str) -> Optional[str]: @@ -8867,6 +9028,16 @@ def _common_checks_available_deployment( model = _model_from_alias if model not in self.model_names: + # Check for team-specific deployments by team_public_model_name. + # This intentionally takes priority over team pattern routers below, + # so that named team deployments shadow wildcard/pattern routes. 
+ if request_team_id is not None: + team_deployments = self._get_all_deployments( + model_name=model, team_id=request_team_id + ) + if team_deployments: + return model, team_deployments + # check if provider/ specific wildcard routing use pattern matching pattern_deployments = self.pattern_router.get_deployments_by_pattern( model=model, @@ -8997,6 +9168,14 @@ async def async_get_healthy_deployments( if isinstance(healthy_deployments, dict): return healthy_deployments + # Health-check-based filtering (before cooldown) + healthy_deployments = ( + await self._async_filter_health_check_unhealthy_deployments( + healthy_deployments=healthy_deployments, + parent_otel_span=parent_otel_span, + ) + ) + cooldown_deployments = await _async_get_cooldown_deployments( litellm_router_instance=self, parent_otel_span=parent_otel_span ) @@ -9035,6 +9214,12 @@ async def async_get_healthy_deployments( ), ) + ## ORDER FILTERING ## -> if user set 'order' in deployments, return deployments with lowest order (e.g. 
order=1 > order=2) + _target_order = (request_kwargs or {}).pop("_target_order", None) + healthy_deployments = litellm.utils._get_order_filtered_deployments( + cast(List[Dict], healthy_deployments), target_order=_target_order + ) + if len(healthy_deployments) == 0: exception = await async_raise_no_deployment_exception( litellm_router_instance=self, @@ -9422,6 +9607,13 @@ def get_available_deployment( parent_otel_span: Optional[Span] = _get_parent_otel_span_from_kwargs( request_kwargs ) + + # Health-check-based filtering (before cooldown) + healthy_deployments = self._filter_health_check_unhealthy_deployments( + healthy_deployments=healthy_deployments, + parent_otel_span=parent_otel_span, + ) + cooldown_deployments = _get_cooldown_deployments( litellm_router_instance=self, parent_otel_span=parent_otel_span ) @@ -9439,6 +9631,12 @@ def get_available_deployment( request_kwargs=request_kwargs, ) + ## ORDER FILTERING ## -> if user set 'order' in deployments, return deployments with lowest order (e.g. order=1 > order=2) + _target_order = (request_kwargs or {}).pop("_target_order", None) + healthy_deployments = litellm.utils._get_order_filtered_deployments( + healthy_deployments, target_order=_target_order + ) + if len(healthy_deployments) == 0: model_ids = self.get_model_ids(model_name=model) _cooldown_time = self.cooldown_cache.get_min_cooldown( @@ -9581,10 +9779,14 @@ def get_available_deployment_for_pass_through( llm_provider="", ) - # 4. Apply cooldown filtering + # 4. 
Apply health-check and cooldown filtering parent_otel_span: Optional[Span] = _get_parent_otel_span_from_kwargs( request_kwargs ) + pass_through_deployments = self._filter_health_check_unhealthy_deployments( + healthy_deployments=pass_through_deployments, + parent_otel_span=parent_otel_span, + ) cooldown_deployments = _get_cooldown_deployments( litellm_router_instance=self, parent_otel_span=parent_otel_span ) @@ -9706,6 +9908,67 @@ def _filter_cooldown_deployments( if deployment["model_info"]["id"] not in cooldown_set ] + async def _async_filter_health_check_unhealthy_deployments( + self, + healthy_deployments: List[Dict], + parent_otel_span: Optional[Span] = None, + ) -> List[Dict]: + """ + Filter out deployments marked unhealthy by background health checks. + No-op when enable_health_check_routing is False. + Returns all deployments if health state is unavailable, stale, or would + exclude every candidate (safety net). + """ + if not self.enable_health_check_routing: + return healthy_deployments + + unhealthy_ids = ( + await self.health_state_cache.async_get_unhealthy_deployment_ids( + parent_otel_span=parent_otel_span + ) + ) + if not unhealthy_ids: + return healthy_deployments + + filtered = [ + d for d in healthy_deployments if d["model_info"]["id"] not in unhealthy_ids + ] + + if not filtered: + verbose_router_logger.warning( + "All deployments marked unhealthy by health checks, bypassing health filter" + ) + return healthy_deployments + + return filtered + + def _filter_health_check_unhealthy_deployments( + self, + healthy_deployments: List[Dict], + parent_otel_span: Optional[Span] = None, + ) -> List[Dict]: + """Sync version of _async_filter_health_check_unhealthy_deployments.""" + if not self.enable_health_check_routing: + return healthy_deployments + + unhealthy_ids = self.health_state_cache.get_unhealthy_deployment_ids( + parent_otel_span=parent_otel_span + ) + if not unhealthy_ids: + return healthy_deployments + + filtered = [ + d for d in 
healthy_deployments if d["model_info"]["id"] not in unhealthy_ids + ] + + if not filtered: + verbose_router_logger.warning( + "All deployments marked unhealthy by health checks, bypassing health filter" + ) + return healthy_deployments + + return filtered + def _filter_pass_through_deployments( self, healthy_deployments: List[Dict] ) -> List[Dict]: diff --git a/litellm/router_utils/health_state_cache.py b/litellm/router_utils/health_state_cache.py new file mode 100644 index 00000000000..65b064f19d2 --- /dev/null +++ b/litellm/router_utils/health_state_cache.py @@ -0,0 +1,100 @@ +""" +Wrapper around router cache for health-check-driven routing. + +Stores per-deployment health state from background health checks +and exposes it for router candidate filtering. +""" + +import time +from typing import TYPE_CHECKING, Any, Dict, Optional, Set, Union + +from typing_extensions import TypedDict + +from litellm import verbose_logger +from litellm.caching.caching import DualCache + +if TYPE_CHECKING: + from opentelemetry.trace import Span as _Span + + Span = Union[_Span, Any] +else: + Span = Any + + +class DeploymentHealthStateValue(TypedDict): + is_healthy: bool + timestamp: float + reason: str + + +class DeploymentHealthCache: + """ + Cache for deployment health states produced by background health checks. + + Stores a single dict mapping deployment_id -> DeploymentHealthStateValue. + Staleness is enforced at read time: entries older than staleness_threshold + are treated as healthy (unknown). 
+ """ + + CACHE_KEY = "litellm:health_check:deployment_health_state" + + def __init__(self, cache: DualCache, staleness_threshold: float): + self.cache = cache + self.staleness_threshold = staleness_threshold + + def set_deployment_health_states( + self, states: Dict[str, DeploymentHealthStateValue] + ) -> None: + """Bulk-write all deployment health states as a single cache entry.""" + try: + self.cache.set_cache( + key=self.CACHE_KEY, + value=states, + ttl=int(self.staleness_threshold * 1.5), + ) + except Exception as e: + verbose_logger.error( + "DeploymentHealthCache::set_deployment_health_states - Exception: %s", + str(e), + ) + + def _extract_unhealthy_ids(self, raw: Any) -> Set[str]: + """Given raw cache value, return set of non-stale unhealthy deployment IDs.""" + if not raw or not isinstance(raw, dict): + return set() + now = time.time() + return { + model_id + for model_id, state in raw.items() + if isinstance(state, dict) + and not state.get("is_healthy", True) + and (now - state.get("timestamp", 0)) < self.staleness_threshold + } + + async def async_get_unhealthy_deployment_ids( + self, parent_otel_span: Optional[Span] = None + ) -> Set[str]: + """Return set of deployment IDs currently marked unhealthy and not stale.""" + try: + raw = await self.cache.async_get_cache(key=self.CACHE_KEY) + return self._extract_unhealthy_ids(raw) + except Exception as e: + verbose_logger.debug( + "DeploymentHealthCache::async_get_unhealthy_deployment_ids - Exception: %s", + str(e), + ) + return set() + + def get_unhealthy_deployment_ids( + self, parent_otel_span: Optional[Span] = None + ) -> Set[str]: + """Sync version: return set of deployment IDs currently marked unhealthy and not stale.""" + try: + raw = self.cache.get_cache(key=self.CACHE_KEY) + return self._extract_unhealthy_ids(raw) + except Exception as e: + verbose_logger.debug( + "DeploymentHealthCache::get_unhealthy_deployment_ids - Exception: %s", + str(e), + ) + return set() diff --git a/litellm/utils.py 
b/litellm/utils.py index 088ee07d630..0e3792773aa 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -4866,7 +4866,21 @@ def calculate_max_parallel_requests( return None -def _get_order_filtered_deployments(healthy_deployments: List[Dict]) -> List: +def _get_order_filtered_deployments( + healthy_deployments: List[Dict], target_order: Optional[int] = None +) -> List: + if target_order is not None: + filtered = [ + d + for d in healthy_deployments + if d["litellm_params"].get("order") == target_order + ] + if filtered: + return filtered + # target_order doesn't match any deployment (e.g., external fallback model) — return all + return healthy_deployments + + # Default: pick min order group min_order = min( ( deployment["litellm_params"]["order"] diff --git a/tests/proxy_unit_tests/test_proxy_utils.py b/tests/proxy_unit_tests/test_proxy_utils.py index 00d4cd24e4b..6c6ec7bcd60 100644 --- a/tests/proxy_unit_tests/test_proxy_utils.py +++ b/tests/proxy_unit_tests/test_proxy_utils.py @@ -2044,6 +2044,58 @@ def test_update_model_if_team_alias_exists(data, user_api_key_dict, expected_mod assert test_data.get("model") == expected_model +def test_team_alias_stale_bypass_disabled_by_default(monkeypatch): + monkeypatch.delenv("LITELLM_ENABLE_TEAM_STALE_ALIAS_BYPASS", raising=False) + import litellm.proxy.litellm_pre_call_utils as pre_call_utils + from litellm.proxy.litellm_pre_call_utils import _update_model_if_team_alias_exists + + # Reset module-level cache to ensure test isolation + pre_call_utils._ENABLE_TEAM_STALE_ALIAS_BYPASS = None + + class _MockRouter: + team_model_to_deployment_indices = {("team-1", "gpt-4o"): [0]} + + test_data = {"model": "gpt-4o"} + user_api_key_dict = UserAPIKeyAuth( + api_key="test_key", + team_id="team-1", + team_model_aliases={"gpt-4o": "model_name_team-1_legacy-uuid"}, + ) + + with patch("litellm.proxy.proxy_server.llm_router", _MockRouter()): + _update_model_if_team_alias_exists( + data=test_data, user_api_key_dict=user_api_key_dict + ) + 
+ assert test_data.get("model") == "model_name_team-1_legacy-uuid" + + +def test_team_alias_stale_bypass_enabled_by_flag(monkeypatch): + import litellm.proxy.litellm_pre_call_utils as pre_call_utils + from litellm.proxy.litellm_pre_call_utils import _update_model_if_team_alias_exists + + # Reset module-level cache to ensure test isolation + pre_call_utils._ENABLE_TEAM_STALE_ALIAS_BYPASS = None + + class _MockRouter: + team_model_to_deployment_indices = {("team-1", "gpt-4o"): [0]} + + test_data = {"model": "gpt-4o"} + user_api_key_dict = UserAPIKeyAuth( + api_key="test_key", + team_id="team-1", + team_model_aliases={"gpt-4o": "model_name_team-1_legacy-uuid"}, + ) + monkeypatch.setenv("LITELLM_ENABLE_TEAM_STALE_ALIAS_BYPASS", "true") + + with patch("litellm.proxy.proxy_server.llm_router", _MockRouter()): + _update_model_if_team_alias_exists( + data=test_data, user_api_key_dict=user_api_key_dict + ) + + assert test_data.get("model") == "gpt-4o" + + @pytest.fixture def mock_prisma_client(): client = MagicMock() diff --git a/tests/router_unit_tests/test_get_model_list_alias_optimization.py b/tests/router_unit_tests/test_get_model_list_alias_optimization.py index 31d992b6646..2c2df3be945 100644 --- a/tests/router_unit_tests/test_get_model_list_alias_optimization.py +++ b/tests/router_unit_tests/test_get_model_list_alias_optimization.py @@ -44,7 +44,7 @@ def test_map_team_model_should_not_iterate_aliases_for_non_alias_team_model_name {f"alias-{idx}": "gpt-4" for idx in range(200)} ) - assert ( - router.map_team_model(team_model_name="team-model", team_id="team-1") - == "gpt-3.5-turbo" - ) + # map_team_model should return the public name unchanged (not the internal UUID name) + # so the router can find all sibling deployments via team_id filtering + result = router.map_team_model(team_model_name="team-model", team_id="team-1") + assert result == "team-model", f"Expected public name 'team-model', got {result}" diff --git 
a/tests/router_unit_tests/test_router_index_management.py b/tests/router_unit_tests/test_router_index_management.py index 90d98b8ab0a..2694c62827c 100644 --- a/tests/router_unit_tests/test_router_index_management.py +++ b/tests/router_unit_tests/test_router_index_management.py @@ -118,6 +118,28 @@ def test_add_model_to_list_and_index_map_multiple_models(self, router): assert router.model_id_to_deployment_index_map["id-2"] == 1 assert router.model_id_to_deployment_index_map["id-3"] == 2 + def test_update_team_model_index(self, router): + """Test _update_team_model_index updates team_model_to_deployment_indices.""" + model = { + "model_name": "team-alias", + "model_info": { + "id": "dep-1", + "team_id": "team-abc", + "team_public_model_name": "gpt-4o", + }, + } + router._update_team_model_index(model, 0) + assert router.team_model_to_deployment_indices[("team-abc", "gpt-4o")] == [0] + router._update_team_model_index(model, 2) + assert router.team_model_to_deployment_indices[("team-abc", "gpt-4o")] == [0, 2] + + router._update_team_model_index( + {"model_name": "x", "model_info": {"id": "dep-2"}}, 5 + ) + assert router.team_model_to_deployment_indices == { + ("team-abc", "gpt-4o"): [0, 2], + } + def test_has_model_id(self, router): """Test has_model_id function for O(1) membership check""" # Setup: Add models to router diff --git a/tests/test_litellm/proxy/management_endpoints/test_model_management_endpoints.py b/tests/test_litellm/proxy/management_endpoints/test_model_management_endpoints.py index f3c89003105..2e566ab6222 100644 --- a/tests/test_litellm/proxy/management_endpoints/test_model_management_endpoints.py +++ b/tests/test_litellm/proxy/management_endpoints/test_model_management_endpoints.py @@ -1,13 +1,14 @@ import json import os import sys -from litellm._uuid import uuid from typing import Dict, Optional from unittest.mock import AsyncMock, MagicMock, patch import pytest from fastapi.testclient import TestClient +from litellm._uuid import uuid + 
sys.path.insert( 0, os.path.abspath("../../../..") ) # Adds the parent directory to the system path @@ -27,9 +28,15 @@ class MockPrismaClient: - def __init__(self, team_exists: bool = True, user_admin: bool = True): + def __init__( + self, + team_exists: bool = True, + user_admin: bool = True, + sibling_deployments: list = None, + ): self.team_exists = team_exists self.user_admin = user_admin + self.sibling_deployments = sibling_deployments or [] self.db = self async def find_unique(self, where): @@ -45,10 +52,53 @@ async def find_unique(self, where): ) return None + async def find_many(self, where): + # Filter sibling deployments by team_id if where clause specifies it + if not self.sibling_deployments: + return [] + + # Extract team_id from where clause if present + team_id_filter = None + if where and "model_info" in where: + model_info_filter = where["model_info"] + if isinstance(model_info_filter, dict) and "path" in model_info_filter: + if ( + model_info_filter["path"] == ["team_id"] + and "equals" in model_info_filter + ): + team_id_filter = model_info_filter["equals"] + + # Filter deployments by team_id if specified + if team_id_filter: + + def _get_team_id(model_info): + if isinstance(model_info, dict): + return model_info.get("team_id") + if isinstance(model_info, str): + try: + parsed = json.loads(model_info) + except (TypeError, ValueError): + return None + if isinstance(parsed, dict): + return parsed.get("team_id") + return None + + return [ + d + for d in self.sibling_deployments + if _get_team_id(d.model_info) == team_id_filter + ] + + return self.sibling_deployments + @property def litellm_teamtable(self): return self + @property + def litellm_proxymodeltable(self): + return self + class MockLLMRouter: def __init__(self): @@ -399,7 +449,9 @@ async def test_clear_cache_preserve_config_models(self): """ Test that clear_cache clears DB models and preserves config models. 
""" - from litellm.proxy.management_endpoints.model_management_endpoints import clear_cache + from litellm.proxy.management_endpoints.model_management_endpoints import ( + clear_cache, + ) # Create mock router with mixed DB and config models mock_router = MagicMock() @@ -407,18 +459,18 @@ async def test_clear_cache_preserve_config_models(self): { "model_name": "gpt-4", "model_info": {"id": "db-model-1", "db_model": True}, - "litellm_params": {"model": "gpt-4"} + "litellm_params": {"model": "gpt-4"}, }, { - "model_name": "gpt-3.5-turbo", + "model_name": "gpt-3.5-turbo", "model_info": {"id": "config-model-1", "db_model": False}, - "litellm_params": {"model": "gpt-3.5-turbo"} + "litellm_params": {"model": "gpt-3.5-turbo"}, }, { "model_name": "claude-3", "model_info": {"id": "db-model-2", "db_model": True}, - "litellm_params": {"model": "claude-3"} - } + "litellm_params": {"model": "claude-3"}, + }, ] mock_router.delete_deployment = MagicMock(return_value=True) mock_router.auto_routers = MagicMock() @@ -466,8 +518,8 @@ async def test_public_model_groups_set_after_get_config(self): """ import litellm from litellm.proxy.management_endpoints.model_management_endpoints import ( - update_public_model_groups, UpdatePublicModelGroupsRequest, + update_public_model_groups, ) old_db_models = ["db-model-1", "db-model-2"] @@ -525,7 +577,10 @@ async def test_useful_links_set_after_get_config(self): ) old_links = {"Old Doc": "https://old.example.com"} - new_links = {"New Doc": "https://new.example.com", "API Ref": "https://api.example.com"} + new_links = { + "New Doc": "https://new.example.com", + "API Ref": "https://api.example.com", + } async def mock_get_config(*args, **kwargs): litellm.public_model_groups_links = old_links @@ -558,6 +613,161 @@ async def mock_get_config(*args, **kwargs): litellm.public_model_groups_links = original_value +class TestTeamModelSiblingRouting: + """ + Verify that sibling team deployments (same public model name, different + api_base) are all 
reachable through routing — no alias overwrite, no + collapse to a single deployment. + """ + + @pytest.mark.asyncio + async def test_no_model_aliases_written_for_team_models(self): + """ + _add_team_model_to_db must NOT write model_aliases (which caused + the second sibling to overwrite the first). It should only call + team_model_add to register the public name on the team's models list. + """ + from litellm.proxy.management_endpoints.model_management_endpoints import ( + _add_team_model_to_db, + ) + from litellm.types.router import ModelInfo + + team_id = "team_no_alias" + public_name = "gpt-4.1-mini" + + async def mock_add_model_to_db(model_params, user_api_key_dict, prisma_client): + return MagicMock(model_id=str(uuid.uuid4())) + + mock_team_model_add = AsyncMock() + + user = UserAPIKeyAuth(user_id="admin", user_role=LitellmUserRoles.PROXY_ADMIN) + prisma_client = MockPrismaClient(team_exists=True) + + for api_base in ["https://eastus.example.com", "https://westus.example.com"]: + dep = Deployment( + model_name=public_name, + litellm_params=LiteLLM_Params( + model="azure/gpt-4o-mini", + api_key="key", + api_base=api_base, + ), + model_info=ModelInfo(team_id=team_id), + ) + with patch( + "litellm.proxy.management_endpoints.model_management_endpoints._add_model_to_db", + side_effect=mock_add_model_to_db, + ), patch( + "litellm.proxy.management_endpoints.model_management_endpoints.team_model_add", + mock_team_model_add, + ): + await _add_team_model_to_db( + model_params=dep, + user_api_key_dict=user, + prisma_client=prisma_client, + ) + + assert mock_team_model_add.call_count == 2 + + @pytest.mark.asyncio + async def test_router_finds_all_sibling_team_deployments(self): + """ + When two team deployments share team_public_model_name="gpt-4.1-mini", + the router's _common_checks_available_deployment must return BOTH as + healthy_deployments (not collapse to one). 
+ """ + import litellm + + team_id = "teamA" + public_name = "gpt-4.1-mini" + + router = litellm.Router( + model_list=[ + { + "model_name": f"model_name_{team_id}_uuid1", + "litellm_params": { + "model": "azure/gpt-4o-mini", + "api_key": "key-1", + "api_base": "https://eastus.openai.azure.com", + }, + "model_info": { + "team_id": team_id, + "team_public_model_name": public_name, + }, + }, + { + "model_name": f"model_name_{team_id}_uuid2", + "litellm_params": { + "model": "azure/gpt-4o-mini", + "api_key": "key-2", + "api_base": "https://westus.openai.azure.com", + }, + "model_info": { + "team_id": team_id, + "team_public_model_name": public_name, + }, + }, + { + "model_name": "global-gpt-4o", + "litellm_params": { + "model": "azure/gpt-4o", + "api_key": "global-key", + "api_base": "https://global.openai.azure.com", + }, + "model_info": {}, # No team_id - global deployment + }, + ], + ) + + # map_team_model should return the public name (not an internal UUID) + result = router.map_team_model(public_name, team_id) + assert result == public_name + + # _common_checks_available_deployment should return both deployments + model, healthy = router._common_checks_available_deployment( + model=public_name, + request_kwargs={"metadata": {"user_api_key_team_id": team_id}}, + ) + assert isinstance(healthy, list) + assert len(healthy) == 2 + api_bases = {d["litellm_params"]["api_base"] for d in healthy} + assert api_bases == { + "https://eastus.openai.azure.com", + "https://westus.openai.azure.com", + } + + def test_global_deployments_accessible_to_teams(self): + """Test that global deployments (no team_id) are accessible to all teams""" + import litellm + + router = litellm.Router( + model_list=[ + { + "model_name": "global-gpt-4o", + "litellm_params": { + "model": "azure/gpt-4o", + "api_key": "global-key", + "api_base": "https://global.openai.azure.com", + }, + "model_info": {}, # No team_id - global deployment + }, + ], + ) + + # Global deployment should be accessible when 
team_id is provided + deployments = router._get_all_deployments( + model_name="global-gpt-4o", team_id="teamA" + ) + assert len(deployments) == 1 + assert deployments[0]["model_name"] == "global-gpt-4o" + + # should_include_deployment should return True for global deployments + assert router.should_include_deployment( + model_name="global-gpt-4o", + model={"model_name": "global-gpt-4o", "model_info": {}}, + team_id="teamA", + ) + + class TestTeamModelUpdate: """Test team model update handles team_id consistently with model creation""" @@ -591,10 +801,10 @@ async def test_patch_model_with_team_id_creates_proper_setup(self): "litellm.proxy.proxy_server.premium_user", True, ), patch( - "litellm.proxy.management_endpoints.model_management_endpoints.update_team" - ) as mock_update_team, patch( "litellm.proxy.management_endpoints.model_management_endpoints.team_model_add" - ) as mock_team_model_add: + ) as mock_team_model_add, patch( + "litellm.proxy.management_endpoints.model_management_endpoints.update_team" + ) as mock_update_team: result = await _update_team_model_in_db( db_model=db_model, patch_data=patch_data, @@ -604,8 +814,201 @@ async def test_patch_model_with_team_id_creates_proper_setup(self): assert result.get("model_name", "").startswith("model_name_test_team_123_") assert "team_public_model_name" in str(result.get("model_info", "")) - mock_update_team.assert_called_once() + # team_model_add must be called to add public name to team's models list mock_team_model_add.assert_called_once() + # update_team (model_aliases write) must NOT be called in the new implementation + mock_update_team.assert_not_called() + + @pytest.mark.asyncio + async def test_rename_preserves_old_name_when_siblings_exist(self): + """Test that renaming a deployment preserves old public name when sibling deployments still use it""" + from unittest.mock import MagicMock + + from litellm.proxy.management_endpoints.model_management_endpoints import ( + _update_existing_team_model_assignment, 
+ ) + from litellm.types.router import ModelInfo + + # Create a deployment being renamed + db_model = Deployment( + model_name="model_name_team_123_uuid1", + litellm_params=LiteLLM_Params(model="azure/gpt-4o-mini"), + model_info=ModelInfo( + team_id="team_123", team_public_model_name="old-public-name" + ), + ) + + # Create a sibling deployment that still uses the old public name + sibling_deployment = MagicMock() + sibling_deployment.model_name = "model_name_team_123_uuid2" + sibling_deployment.model_info = { + "team_id": "team_123", + "team_public_model_name": "old-public-name", + } + + prisma_client = MockPrismaClient( + team_exists=True, sibling_deployments=[sibling_deployment] + ) + + patch_data = updateDeployment( + model_name="new-public-name", + model_info=ModelInfo(team_id="team_123"), + ) + + user_api_key_dict = UserAPIKeyAuth( + user_id="test_user", + user_role=LitellmUserRoles.PROXY_ADMIN, + ) + + with patch( + "litellm.proxy.management_endpoints.model_management_endpoints.team_model_delete" + ) as mock_delete, patch( + "litellm.proxy.management_endpoints.model_management_endpoints.team_model_add" + ) as mock_add: + await _update_existing_team_model_assignment( + team_id="team_123", + public_model_name="new-public-name", + db_model=db_model, + patch_data=patch_data, + user_api_key_dict=user_api_key_dict, + prisma_client=prisma_client, # type: ignore + ) + + # team_model_delete should NOT be called because sibling exists + mock_delete.assert_not_called() + # team_model_add should be called to add new public name + mock_add.assert_called_once() + + @pytest.mark.asyncio + async def test_first_time_public_name_assignment_adds_team_model(self): + """If existing team deployment had no public name, first assignment must call team_model_add.""" + from litellm.proxy.management_endpoints.model_management_endpoints import ( + _update_existing_team_model_assignment, + ) + from litellm.types.router import ModelInfo + + db_model = Deployment( + 
model_name="model_name_team_123_uuid1", + litellm_params=LiteLLM_Params(model="azure/gpt-4o-mini"), + model_info=ModelInfo(team_id="team_123"), + ) + + patch_data = updateDeployment( + model_name="new-public-name", + model_info=ModelInfo(team_id="team_123"), + ) + + user_api_key_dict = UserAPIKeyAuth( + user_id="test_user", + user_role=LitellmUserRoles.PROXY_ADMIN, + ) + + with patch( + "litellm.proxy.management_endpoints.model_management_endpoints.team_model_delete" + ) as mock_delete, patch( + "litellm.proxy.management_endpoints.model_management_endpoints.team_model_add" + ) as mock_add: + await _update_existing_team_model_assignment( + team_id="team_123", + public_model_name="new-public-name", + db_model=db_model, + patch_data=patch_data, + user_api_key_dict=user_api_key_dict, + prisma_client=None, + ) + + mock_add.assert_called_once() + mock_delete.assert_not_called() + + @pytest.mark.asyncio + async def test_rename_with_prisma_none_clears_patch_model_name(self): + """Rename path must clear patch_data.model_name even when prisma is unavailable (P1).""" + from litellm.proxy.management_endpoints.model_management_endpoints import ( + _update_existing_team_model_assignment, + ) + from litellm.types.router import ModelInfo + + db_model = Deployment( + model_name="model_name_team_123_uuid1", + litellm_params=LiteLLM_Params(model="azure/gpt-4o-mini"), + model_info=ModelInfo( + team_id="team_123", team_public_model_name="old-public-name" + ), + ) + patch_data = updateDeployment( + model_name="new-public-name", + model_info=ModelInfo(team_id="team_123"), + ) + user_api_key_dict = UserAPIKeyAuth( + user_id="test_user", + user_role=LitellmUserRoles.PROXY_ADMIN, + ) + + await _update_existing_team_model_assignment( + team_id="team_123", + public_model_name="new-public-name", + db_model=db_model, + patch_data=patch_data, + user_api_key_dict=user_api_key_dict, + prisma_client=None, + ) + + assert patch_data.model_name is None + + @pytest.mark.asyncio + async def 
test_rename_handles_legacy_string_model_info(self): + """Test rename path handles legacy string-encoded model_info rows without crashing.""" + from unittest.mock import MagicMock + + from litellm.proxy.management_endpoints.model_management_endpoints import ( + _update_existing_team_model_assignment, + ) + from litellm.types.router import ModelInfo + + db_model = Deployment( + model_name="model_name_team_123_uuid1", + litellm_params=LiteLLM_Params(model="azure/gpt-4o-mini"), + model_info=ModelInfo( + team_id="team_123", team_public_model_name="old-public-name" + ), + ) + + sibling_deployment = MagicMock() + sibling_deployment.model_name = "model_name_team_123_uuid2" + sibling_deployment.model_info = ( + '{"team_id":"team_123","team_public_model_name":"old-public-name"}' + ) + + prisma_client = MockPrismaClient( + team_exists=True, sibling_deployments=[sibling_deployment] + ) + + patch_data = updateDeployment( + model_name="new-public-name", + model_info=ModelInfo(team_id="team_123"), + ) + + user_api_key_dict = UserAPIKeyAuth( + user_id="test_user", + user_role=LitellmUserRoles.PROXY_ADMIN, + ) + + with patch( + "litellm.proxy.management_endpoints.model_management_endpoints.team_model_delete" + ) as mock_delete, patch( + "litellm.proxy.management_endpoints.model_management_endpoints.team_model_add" + ) as mock_add: + await _update_existing_team_model_assignment( + team_id="team_123", + public_model_name="new-public-name", + db_model=db_model, + patch_data=patch_data, + user_api_key_dict=user_api_key_dict, + prisma_client=prisma_client, # type: ignore + ) + + mock_delete.assert_not_called() + mock_add.assert_called_once() @pytest.mark.asyncio async def test_patch_model_with_team_id_validates_permissions(self): @@ -657,27 +1060,37 @@ async def test_model_info_accessible_model_success(self): user_id="test_user", api_key="test_key", models=["gpt-4", "claude-3"], - team_models=["gpt-3.5-turbo"] + team_models=["gpt-3.5-turbo"], ) - with 
patch("litellm.proxy.proxy_server.llm_router") as mock_router, \ - patch("litellm.proxy.proxy_server.get_key_models") as mock_get_key_models, \ - patch("litellm.proxy.proxy_server.get_team_models") as mock_get_team_models, \ - patch("litellm.proxy.proxy_server.get_complete_model_list") as mock_get_complete_models, \ - patch("litellm.get_llm_provider") as mock_get_provider: - + with patch("litellm.proxy.proxy_server.llm_router") as mock_router, patch( + "litellm.proxy.proxy_server.get_key_models" + ) as mock_get_key_models, patch( + "litellm.proxy.proxy_server.get_team_models" + ) as mock_get_team_models, patch( + "litellm.proxy.proxy_server.get_complete_model_list" + ) as mock_get_complete_models, patch( + "litellm.get_llm_provider" + ) as mock_get_provider: # Setup mocks - mock_router.get_model_names.return_value = ["gpt-4", "claude-3", "gpt-3.5-turbo"] + mock_router.get_model_names.return_value = [ + "gpt-4", + "claude-3", + "gpt-3.5-turbo", + ] mock_router.get_model_access_groups.return_value = {} mock_get_key_models.return_value = ["gpt-4", "claude-3"] mock_get_team_models.return_value = ["gpt-3.5-turbo"] - mock_get_complete_models.return_value = ["gpt-4", "claude-3", "gpt-3.5-turbo"] + mock_get_complete_models.return_value = [ + "gpt-4", + "claude-3", + "gpt-3.5-turbo", + ] mock_get_provider.return_value = (None, "openai", None, None) # Test accessible model result = await model_info( - model_id="gpt-4", - user_api_key_dict=user_api_key_dict + model_id="gpt-4", user_api_key_dict=user_api_key_dict ) assert result["id"] == "gpt-4" @@ -688,22 +1101,25 @@ async def test_model_info_accessible_model_success(self): @pytest.mark.asyncio async def test_model_info_inaccessible_model_returns_404(self): """Test model_info returns 404 for inaccessible models""" - from litellm.proxy.proxy_server import model_info from fastapi import HTTPException + from litellm.proxy.proxy_server import model_info + # Mock user with limited access user_api_key_dict = UserAPIKeyAuth( 
user_id="test_user", api_key="test_key", models=["gpt-4"], # Only has access to gpt-4 - team_models=[] + team_models=[], ) - with patch("litellm.proxy.proxy_server.llm_router") as mock_router, \ - patch("litellm.proxy.proxy_server.get_key_models") as mock_get_key_models, \ - patch("litellm.proxy.proxy_server.get_team_models") as mock_get_team_models, \ - patch("litellm.proxy.proxy_server.get_complete_model_list") as mock_get_complete_models: - + with patch("litellm.proxy.proxy_server.llm_router") as mock_router, patch( + "litellm.proxy.proxy_server.get_key_models" + ) as mock_get_key_models, patch( + "litellm.proxy.proxy_server.get_team_models" + ) as mock_get_team_models, patch( + "litellm.proxy.proxy_server.get_complete_model_list" + ) as mock_get_complete_models: # Setup mocks - user only has access to gpt-4 mock_router.get_model_names.return_value = ["gpt-4", "claude-3"] mock_router.get_model_access_groups.return_value = {} @@ -715,32 +1131,35 @@ async def test_model_info_inaccessible_model_returns_404(self): with pytest.raises(HTTPException) as exc_info: await model_info( model_id="claude-3", # Not in user's accessible models - user_api_key_dict=user_api_key_dict + user_api_key_dict=user_api_key_dict, ) - + assert exc_info.value.status_code == 404 assert "does not exist or is not accessible" in exc_info.value.detail - @pytest.mark.asyncio + @pytest.mark.asyncio async def test_model_info_team_model_access(self): """Test model_info works with team model access""" from litellm.proxy.proxy_server import model_info - + # Mock user with team access user_api_key_dict = UserAPIKeyAuth( user_id="test_user", - api_key="test_key", + api_key="test_key", team_id="test_team", models=[], # No direct key models - team_models=["team-model-1"] + team_models=["team-model-1"], ) - with patch("litellm.proxy.proxy_server.llm_router") as mock_router, \ - patch("litellm.proxy.proxy_server.get_key_models") as mock_get_key_models, \ - patch("litellm.proxy.proxy_server.get_team_models") 
as mock_get_team_models, \ - patch("litellm.proxy.proxy_server.get_complete_model_list") as mock_get_complete_models, \ - patch("litellm.get_llm_provider") as mock_get_provider: - + with patch("litellm.proxy.proxy_server.llm_router") as mock_router, patch( + "litellm.proxy.proxy_server.get_key_models" + ) as mock_get_key_models, patch( + "litellm.proxy.proxy_server.get_team_models" + ) as mock_get_team_models, patch( + "litellm.proxy.proxy_server.get_complete_model_list" + ) as mock_get_complete_models, patch( + "litellm.get_llm_provider" + ) as mock_get_provider: # Setup mocks mock_router.get_model_names.return_value = ["team-model-1"] mock_router.get_model_access_groups.return_value = {} @@ -751,10 +1170,9 @@ async def test_model_info_team_model_access(self): # Test team model access result = await model_info( - model_id="team-model-1", - user_api_key_dict=user_api_key_dict + model_id="team-model-1", user_api_key_dict=user_api_key_dict ) assert result["id"] == "team-model-1" - assert result["object"] == "model" + assert result["object"] == "model" assert result["owned_by"] == "custom" diff --git a/tests/test_litellm/router_utils/test_health_check_routing.py b/tests/test_litellm/router_utils/test_health_check_routing.py new file mode 100644 index 00000000000..f40144b44c9 --- /dev/null +++ b/tests/test_litellm/router_utils/test_health_check_routing.py @@ -0,0 +1,197 @@ +""" +Tests for health-check-driven routing filter in the Router. 
+""" + +import time + +import pytest + +from litellm.caching.caching import DualCache +from litellm.router_utils.health_state_cache import DeploymentHealthCache + + +def _make_deployment(model_id: str, model_name: str = "gpt-4") -> dict: + """Helper to create a deployment dict for testing.""" + return { + "model_name": model_name, + "litellm_params": {"model": model_name, "api_key": "fake"}, + "model_info": {"id": model_id}, + } + + +def _make_health_cache( + unhealthy_ids: set = None, staleness_threshold: float = 60.0 +) -> DeploymentHealthCache: + """Create a health cache pre-populated with unhealthy deployment IDs.""" + cache = DualCache() + health_cache = DeploymentHealthCache( + cache=cache, staleness_threshold=staleness_threshold + ) + if unhealthy_ids: + now = time.time() + states = {} + for uid in unhealthy_ids: + states[uid] = { + "is_healthy": False, + "timestamp": now, + "reason": "test_unhealthy", + } + health_cache.set_deployment_health_states(states) + return health_cache + + +class TestFilterHealthCheckUnhealthyDeployments: + """Test the sync filter method.""" + + def _make_router_like(self, enable: bool, health_cache: DeploymentHealthCache): + """Create a minimal object that behaves like Router for filter testing.""" + + class FakeRouter: + def __init__(self): + self.enable_health_check_routing = enable + self.health_state_cache = health_cache + + # Import the actual method and bind it + from litellm.router import Router + + fake = FakeRouter() + # Use the unbound method + fake._filter_health_check_unhealthy_deployments = ( + Router._filter_health_check_unhealthy_deployments.__get__(fake, FakeRouter) + ) + return fake + + def test_filter_removes_unhealthy_deployments(self): + """Unhealthy deployments should be removed from candidates.""" + health_cache = _make_health_cache(unhealthy_ids={"deploy-2"}) + router = self._make_router_like(enable=True, health_cache=health_cache) + + deployments = [ + _make_deployment("deploy-1"), + 
_make_deployment("deploy-2"), + _make_deployment("deploy-3"), + ] + result = router._filter_health_check_unhealthy_deployments(deployments) + assert len(result) == 2 + assert all(d["model_info"]["id"] != "deploy-2" for d in result) + + def test_filter_noop_when_disabled(self): + """When enable_health_check_routing=False, filter should be a no-op.""" + health_cache = _make_health_cache(unhealthy_ids={"deploy-1"}) + router = self._make_router_like(enable=False, health_cache=health_cache) + + deployments = [ + _make_deployment("deploy-1"), + _make_deployment("deploy-2"), + ] + result = router._filter_health_check_unhealthy_deployments(deployments) + assert len(result) == 2 # no filtering + + def test_filter_returns_all_when_all_unhealthy(self): + """Safety net: if ALL deployments are unhealthy, return all (don't cause outage).""" + health_cache = _make_health_cache( + unhealthy_ids={"deploy-1", "deploy-2", "deploy-3"} + ) + router = self._make_router_like(enable=True, health_cache=health_cache) + + deployments = [ + _make_deployment("deploy-1"), + _make_deployment("deploy-2"), + _make_deployment("deploy-3"), + ] + result = router._filter_health_check_unhealthy_deployments(deployments) + assert len(result) == 3 # all returned, safety net + + def test_filter_returns_all_when_cache_empty(self): + """When cache is empty, all deployments should pass through.""" + health_cache = _make_health_cache() # empty + router = self._make_router_like(enable=True, health_cache=health_cache) + + deployments = [ + _make_deployment("deploy-1"), + _make_deployment("deploy-2"), + ] + result = router._filter_health_check_unhealthy_deployments(deployments) + assert len(result) == 2 + + +class TestAsyncFilterHealthCheckUnhealthyDeployments: + """Test the async filter method.""" + + def _make_router_like(self, enable: bool, health_cache: DeploymentHealthCache): + from litellm.router import Router + + class FakeRouter: + def __init__(self): + self.enable_health_check_routing = enable + 
self.health_state_cache = health_cache + + fake = FakeRouter() + fake._async_filter_health_check_unhealthy_deployments = ( + Router._async_filter_health_check_unhealthy_deployments.__get__( + fake, FakeRouter + ) + ) + return fake + + @pytest.mark.asyncio + async def test_async_filter_removes_unhealthy(self): + """Async version: unhealthy deployments removed.""" + health_cache = _make_health_cache(unhealthy_ids={"deploy-2"}) + router = self._make_router_like(enable=True, health_cache=health_cache) + + deployments = [ + _make_deployment("deploy-1"), + _make_deployment("deploy-2"), + _make_deployment("deploy-3"), + ] + result = await router._async_filter_health_check_unhealthy_deployments( + healthy_deployments=deployments + ) + assert len(result) == 2 + assert all(d["model_info"]["id"] != "deploy-2" for d in result) + + @pytest.mark.asyncio + async def test_async_filter_safety_net(self): + """Async version: safety net when all unhealthy.""" + health_cache = _make_health_cache(unhealthy_ids={"deploy-1", "deploy-2"}) + router = self._make_router_like(enable=True, health_cache=health_cache) + + deployments = [ + _make_deployment("deploy-1"), + _make_deployment("deploy-2"), + ] + result = await router._async_filter_health_check_unhealthy_deployments( + healthy_deployments=deployments + ) + assert len(result) == 2 # safety net + + +class TestBuildDeploymentHealthStates: + """Test the build_deployment_health_states function.""" + + def test_builds_states_from_endpoints(self): + from litellm.proxy.health_check import build_deployment_health_states + + healthy = [{"model": "gpt-4", "model_id": "deploy-1"}] + unhealthy = [{"model": "gpt-4", "model_id": "deploy-2", "error": "timeout"}] + + states = build_deployment_health_states(healthy, unhealthy) + assert states["deploy-1"]["is_healthy"] is True + assert states["deploy-2"]["is_healthy"] is False + + def test_no_model_id_skipped(self): + from litellm.proxy.health_check import build_deployment_health_states + + healthy = 
[{"model": "gpt-4"}] # no model_id + unhealthy = [{"model": "gpt-4", "model_id": "deploy-2"}] + + states = build_deployment_health_states(healthy, unhealthy) + assert "deploy-1" not in states + assert states["deploy-2"]["is_healthy"] is False + + def test_empty_endpoints(self): + from litellm.proxy.health_check import build_deployment_health_states + + states = build_deployment_health_states([], []) + assert states == {} diff --git a/tests/test_litellm/router_utils/test_health_state_cache.py b/tests/test_litellm/router_utils/test_health_state_cache.py new file mode 100644 index 00000000000..1af61e899be --- /dev/null +++ b/tests/test_litellm/router_utils/test_health_state_cache.py @@ -0,0 +1,113 @@ +""" +Tests for DeploymentHealthCache - the cache layer for health-check-driven routing. +""" + +import time + +import pytest + +from litellm.caching.caching import DualCache +from litellm.router_utils.health_state_cache import DeploymentHealthCache + + +@pytest.fixture +def cache(): + return DualCache() + + +@pytest.fixture +def health_cache(cache): + return DeploymentHealthCache(cache=cache, staleness_threshold=60.0) + + +def test_set_and_get_unhealthy_ids(health_cache): + """Write states, verify unhealthy set is returned correctly.""" + now = time.time() + states = { + "deploy-1": {"is_healthy": True, "timestamp": now, "reason": ""}, + "deploy-2": {"is_healthy": False, "timestamp": now, "reason": "check_failed"}, + "deploy-3": {"is_healthy": False, "timestamp": now, "reason": "timeout"}, + } + health_cache.set_deployment_health_states(states) + result = health_cache.get_unhealthy_deployment_ids() + assert result == {"deploy-2", "deploy-3"} + + +@pytest.mark.asyncio +async def test_async_get_unhealthy_ids(health_cache): + """Async version of set and get.""" + now = time.time() + states = { + "deploy-1": {"is_healthy": True, "timestamp": now, "reason": ""}, + "deploy-2": {"is_healthy": False, "timestamp": now, "reason": "check_failed"}, + } + 
health_cache.set_deployment_health_states(states) + result = await health_cache.async_get_unhealthy_deployment_ids() + assert result == {"deploy-2"} + + +def test_staleness_filtering(health_cache): + """Entries older than staleness_threshold should be ignored.""" + old_time = time.time() - 120 # 2 minutes ago, threshold is 60s + states = { + "deploy-1": { + "is_healthy": False, + "timestamp": old_time, + "reason": "check_failed", + }, + } + health_cache.set_deployment_health_states(states) + result = health_cache.get_unhealthy_deployment_ids() + assert result == set() # stale entry should be ignored + + +def test_empty_cache_returns_empty_set(health_cache): + """No data in cache should return empty set.""" + result = health_cache.get_unhealthy_deployment_ids() + assert result == set() + + +def test_all_healthy_returns_empty_set(health_cache): + """All healthy deployments should return empty set.""" + now = time.time() + states = { + "deploy-1": {"is_healthy": True, "timestamp": now, "reason": ""}, + "deploy-2": {"is_healthy": True, "timestamp": now, "reason": ""}, + } + health_cache.set_deployment_health_states(states) + result = health_cache.get_unhealthy_deployment_ids() + assert result == set() + + +def test_mixed_stale_and_fresh(health_cache): + """Only fresh unhealthy entries should be returned.""" + now = time.time() + old_time = now - 120 # stale + states = { + "deploy-1": { + "is_healthy": False, + "timestamp": old_time, + "reason": "stale", + }, + "deploy-2": { + "is_healthy": False, + "timestamp": now, + "reason": "fresh", + }, + } + health_cache.set_deployment_health_states(states) + result = health_cache.get_unhealthy_deployment_ids() + assert result == {"deploy-2"} + + +def test_malformed_state_entries_are_skipped(health_cache): + """Non-dict entries in the cache should be skipped safely.""" + now = time.time() + states = { + "deploy-1": {"is_healthy": False, "timestamp": now, "reason": "bad"}, + "deploy-2": "not_a_dict", # malformed + "deploy-3": 
None, # malformed + } + health_cache.set_deployment_health_states(states) + result = health_cache.get_unhealthy_deployment_ids() + assert result == {"deploy-1"} diff --git a/tests/test_litellm/test_constants.py b/tests/test_litellm/test_constants.py index 8fff3ec40d4..b3c13c6e26e 100644 --- a/tests/test_litellm/test_constants.py +++ b/tests/test_litellm/test_constants.py @@ -1,3 +1,4 @@ +import ast import inspect import json import os @@ -17,57 +18,56 @@ from litellm import constants -def test_all_numeric_constants_can_be_overridden(): +def _build_constant_env_var_map() -> dict[str, str]: """ - Test that all integer and float constants in constants.py can be overridden with environment variables. - This ensures that any new constants added in the future will be configurable via environment variables. - """ - # Get all attributes from the constants module - constants_attributes = inspect.getmembers(constants) - - # Filter for uppercase constants (by convention) that are integers or floats - # Exclude booleans since bool is a subclass of int in Python - numeric_constants = [ - (name, value) - for name, value in constants_attributes - if name.isupper() and isinstance(value, (int, float)) and not isinstance(value, bool) - ] - - # Ensure we found some constants to test - assert len(numeric_constants) > 0, "No numeric constants found to test" + Build a mapping of CONSTANT_NAME -> ENV_VAR_NAME by parsing constants.py. 
- print("all numeric constants", json.dumps(numeric_constants, indent=4)) - - # Constants that use a different env var name than the constant name - constant_to_env_var = { - "MAX_CALLBACKS": "LITELLM_MAX_CALLBACKS", - "MCP_CLIENT_TIMEOUT": "LITELLM_MCP_CLIENT_TIMEOUT", - "MCP_TOOL_LISTING_TIMEOUT": "LITELLM_MCP_TOOL_LISTING_TIMEOUT", - "MCP_METADATA_TIMEOUT": "LITELLM_MCP_METADATA_TIMEOUT", - "MCP_HEALTH_CHECK_TIMEOUT": "LITELLM_MCP_HEALTH_CHECK_TIMEOUT", - } + This keeps the test resilient when a constant name and env var name differ + (e.g., aliases like LITELLM_* env vars). + """ + env_var_map: dict[str, str] = {} + constants_source = inspect.getsource(constants) + parsed = ast.parse(constants_source) - # Verify all numeric constants have environment variable support - for name, value in numeric_constants: - # Skip constants that are not meant to be overridden (if any) - if name.startswith("_"): + for node in parsed.body: + if not isinstance(node, ast.Assign): continue - # Create a test value that's different from the default - test_value = value + 1 if isinstance(value, int) else value + 0.1 - - # Use the env var name that the constants module actually reads - env_var_name = constant_to_env_var.get(name, name) - - # Set the environment variable - with mock.patch.dict(os.environ, {env_var_name: str(test_value)}): - print("overriding", name, "with", test_value) - importlib.reload(constants) - - # Get the new value after reload - new_value = getattr(constants, name) + if len(node.targets) != 1 or not isinstance(node.targets[0], ast.Name): + continue - # Verify the value was overridden - assert ( - new_value == test_value - ), f"Failed to override {name} with environment variable. 
Expected {test_value}, got {new_value}" + constant_name = node.targets[0].id + env_var_name = None + + for child in ast.walk(node.value): + if not isinstance(child, ast.Call): + continue + + # os.getenv("ENV_NAME", default) + if ( + isinstance(child.func, ast.Attribute) + and isinstance(child.func.value, ast.Name) + and child.func.value.id == "os" + and child.func.attr == "getenv" + and len(child.args) >= 1 + and isinstance(child.args[0], ast.Constant) + and isinstance(child.args[0].value, str) + ): + env_var_name = child.args[0].value + break + + # get_env_int("ENV_NAME", default) + if ( + isinstance(child.func, ast.Name) + and child.func.id == "get_env_int" + and len(child.args) >= 1 + and isinstance(child.args[0], ast.Constant) + and isinstance(child.args[0].value, str) + ): + env_var_name = child.args[0].value + break + + if env_var_name: + env_var_map[constant_name] = env_var_name + + return env_var_map diff --git a/tests/test_litellm/test_router_order_fallback.py b/tests/test_litellm/test_router_order_fallback.py new file mode 100644 index 00000000000..760766a7461 --- /dev/null +++ b/tests/test_litellm/test_router_order_fallback.py @@ -0,0 +1,331 @@ +""" +Tests for order-based fallback routing. + +When deployments have `order` set in litellm_params, lower order deployments +should be tried first, and higher order deployments should be used as fallbacks +when lower order deployments fail. 
+""" + +from typing import Optional + +import pytest + +from litellm import Router +from litellm.utils import _get_order_filtered_deployments + +# --------------------------------------------------------------------------- +# Unit tests for _get_order_filtered_deployments +# --------------------------------------------------------------------------- + + +class TestGetOrderFilteredDeployments: + def _make_deployment(self, order: Optional[int], dep_id: str) -> dict: + params: dict = {"model": "gpt-4o", "api_key": "key"} + if order is not None: + params["order"] = order + return { + "model_name": "test-model", + "litellm_params": params, + "model_info": {"id": dep_id}, + } + + def test_returns_min_order_group(self): + deps = [ + self._make_deployment(1, "a"), + self._make_deployment(2, "b"), + self._make_deployment(1, "c"), + ] + result = _get_order_filtered_deployments(deps) + assert len(result) == 2 + assert all(d["model_info"]["id"] in ("a", "c") for d in result) + + def test_target_order_filters_to_exact_level(self): + deps = [ + self._make_deployment(1, "a"), + self._make_deployment(2, "b"), + self._make_deployment(3, "c"), + ] + result = _get_order_filtered_deployments(deps, target_order=2) + assert len(result) == 1 + assert result[0]["model_info"]["id"] == "b" + + def test_target_order_no_match_returns_all(self): + deps = [ + self._make_deployment(1, "a"), + self._make_deployment(2, "b"), + ] + result = _get_order_filtered_deployments(deps, target_order=99) + assert len(result) == 2 + + def test_no_order_set_returns_all(self): + deps = [ + self._make_deployment(None, "a"), + self._make_deployment(None, "b"), + ] + result = _get_order_filtered_deployments(deps) + assert len(result) == 2 + + def test_empty_list(self): + result = _get_order_filtered_deployments([]) + assert result == [] + + def test_single_order_returns_all_with_that_order(self): + deps = [ + self._make_deployment(1, "a"), + self._make_deployment(1, "b"), + ] + result = 
_get_order_filtered_deployments(deps) + assert len(result) == 2 + + +# --------------------------------------------------------------------------- +# Integration tests for order-based fallback in Router +# --------------------------------------------------------------------------- + + +def test_router_order_without_pre_call_checks(): + """Order filtering should work even when enable_pre_call_checks=False (default).""" + router = Router( + model_list=[ + { + "model_name": "test-model", + "litellm_params": { + "model": "gpt-4o", + "api_key": "key", + "mock_response": "from order 1", + "order": 1, + }, + "model_info": {"id": "1"}, + }, + { + "model_name": "test-model", + "litellm_params": { + "model": "gpt-4o", + "api_key": "key", + "mock_response": "from order 2", + "order": 2, + }, + "model_info": {"id": "2"}, + }, + ], + num_retries=0, + enable_pre_call_checks=False, + ) + + for _ in range(20): + response = router.completion( + model="test-model", + messages=[{"role": "user", "content": "hi"}], + ) + assert response._hidden_params["model_id"] == "1" + + +def test_router_order_no_fallback_when_healthy(): + """When order=1 is healthy, order=2 should never be used.""" + router = Router( + model_list=[ + { + "model_name": "test-model", + "litellm_params": { + "model": "gpt-4o", + "api_key": "key", + "mock_response": "from order 1", + "order": 1, + }, + "model_info": {"id": "1"}, + }, + { + "model_name": "test-model", + "litellm_params": { + "model": "gpt-4o", + "api_key": "key", + "mock_response": "from order 2", + "order": 2, + }, + "model_info": {"id": "2"}, + }, + ], + num_retries=0, + ) + + for _ in range(50): + response = router.completion( + model="test-model", + messages=[{"role": "user", "content": "hi"}], + ) + assert response._hidden_params["model_id"] == "1" + + +@pytest.mark.asyncio +async def test_router_order_fallback_on_failure(): + """When order=1 fails, order=2 should be tried as fallback.""" + router = Router( + model_list=[ + { + "model_name": 
"test-model", + "litellm_params": { + "model": "gpt-4o", + "api_key": "bad-key", + "mock_response": Exception("connection error"), + "order": 1, + }, + "model_info": {"id": "1"}, + }, + { + "model_name": "test-model", + "litellm_params": { + "model": "gpt-4o", + "api_key": "good-key", + "mock_response": "success from order 2", + "order": 2, + }, + "model_info": {"id": "2"}, + }, + ], + num_retries=0, + ) + + response = await router.acompletion( + model="test-model", + messages=[{"role": "user", "content": "hi"}], + ) + assert response._hidden_params["model_id"] == "2" + + +@pytest.mark.asyncio +async def test_router_order_fallback_three_levels(): + """When order=1 and order=2 both fail, order=3 should be tried.""" + router = Router( + model_list=[ + { + "model_name": "test-model", + "litellm_params": { + "model": "gpt-4o", + "api_key": "bad", + "mock_response": Exception("fail 1"), + "order": 1, + }, + "model_info": {"id": "1"}, + }, + { + "model_name": "test-model", + "litellm_params": { + "model": "gpt-4o", + "api_key": "bad", + "mock_response": Exception("fail 2"), + "order": 2, + }, + "model_info": {"id": "2"}, + }, + { + "model_name": "test-model", + "litellm_params": { + "model": "gpt-4o", + "api_key": "good", + "mock_response": "success from order 3", + "order": 3, + }, + "model_info": {"id": "3"}, + }, + ], + num_retries=0, + ) + + response = await router.acompletion( + model="test-model", + messages=[{"role": "user", "content": "hi"}], + ) + assert response._hidden_params["model_id"] == "3" + + +@pytest.mark.asyncio +async def test_router_order_fallback_then_external_fallback(): + """When all order levels fail, external fallbacks should be tried.""" + router = Router( + model_list=[ + { + "model_name": "test-model", + "litellm_params": { + "model": "gpt-4o", + "api_key": "bad", + "mock_response": Exception("fail order 1"), + "order": 1, + }, + "model_info": {"id": "1"}, + }, + { + "model_name": "test-model", + "litellm_params": { + "model": "gpt-4o", + 
"api_key": "bad", + "mock_response": Exception("fail order 2"), + "order": 2, + }, + "model_info": {"id": "2"}, + }, + { + "model_name": "fallback-model", + "litellm_params": { + "model": "gpt-4o", + "api_key": "good", + "mock_response": "success from external fallback", + }, + "model_info": {"id": "fallback"}, + }, + ], + fallbacks=[{"test-model": ["fallback-model"]}], + num_retries=0, + ) + + response = await router.acompletion( + model="test-model", + messages=[{"role": "user", "content": "hi"}], + ) + assert response._hidden_params["model_id"] == "fallback" + + +@pytest.mark.asyncio +async def test_router_order_fallback_with_non_standard_fallbacks(): + """Non-standard fallback formats (e.g. fallbacks=["model-name"]) passed + per-request should still be tried after all order levels are exhausted.""" + router = Router( + model_list=[ + { + "model_name": "test-model", + "litellm_params": { + "model": "gpt-4o", + "api_key": "bad", + "mock_response": Exception("fail order 1"), + "order": 1, + }, + "model_info": {"id": "1"}, + }, + { + "model_name": "test-model", + "litellm_params": { + "model": "gpt-4o", + "api_key": "bad", + "mock_response": Exception("fail order 2"), + "order": 2, + }, + "model_info": {"id": "2"}, + }, + { + "model_name": "fallback-model", + "litellm_params": { + "model": "gpt-4o", + "api_key": "good", + "mock_response": "success from non-standard fallback", + }, + "model_info": {"id": "fallback"}, + }, + ], + num_retries=0, + ) + + response = await router.acompletion( + model="test-model", + messages=[{"role": "user", "content": "hi"}], + fallbacks=["fallback-model"], # non-standard format, passed per-request + ) + assert response._hidden_params["model_id"] == "fallback"