99from urllib .parse import urlparse
1010
1111from curl_cffi .requests import AsyncSession
12- from fastapi import APIRouter , Depends , HTTPException , Query , WebSocket , WebSocketDisconnect
12+ from fastapi import APIRouter , Depends , HTTPException , Query , Request , WebSocket , WebSocketDisconnect
1313from fastapi .responses import JSONResponse , StreamingResponse
1414
15- from ..core .auth import verify_api_key_flexible
15+ from ..core .auth import AuthManager , verify_api_key_flexible
1616from ..core .logger import debug_logger
1717from ..core .model_resolver import get_base_model_aliases , resolve_model_name
1818from ..core .models import (
2323)
2424from ..services .generation_handler import MODEL_CONFIG , GenerationHandler
2525from ..services .browser_captcha_extension import ExtensionCaptchaService
26- from fastapi import WebSocket , WebSocketDisconnect
2726
2827router = APIRouter ()
2928
@@ -486,6 +485,7 @@ async def _collect_non_stream_result(
486485 model : str ,
487486 prompt : str ,
488487 images : List [bytes ],
488+ base_url_override : Optional [str ] = None ,
489489 video_media_id : Optional [str ] = None ,
490490) -> str :
491491 handler = _ensure_generation_handler ()
@@ -495,6 +495,7 @@ async def _collect_non_stream_result(
495495 prompt = prompt ,
496496 images = images if images else None ,
497497 stream = False ,
498+ base_url_override = base_url_override ,
498499 video_media_id = video_media_id ,
499500 ):
500501 result = chunk
@@ -723,6 +724,7 @@ async def _iterate_openai_stream(
723724 prompt = normalized .prompt ,
724725 images = normalized .images if normalized .images else None ,
725726 stream = True ,
727+ base_url_override = base_url_override ,
726728 video_media_id = normalized .video_media_id ,
727729 ):
728730 if chunk .startswith ("data: " ):
@@ -746,6 +748,7 @@ async def _iterate_gemini_stream(
746748 prompt = normalized .prompt ,
747749 images = normalized .images if normalized .images else None ,
748750 stream = True ,
751+ base_url_override = base_url_override ,
749752 video_media_id = normalized .video_media_id ,
750753 ):
751754 if chunk .startswith ("data: " ):
@@ -874,6 +877,7 @@ async def create_chat_completion(
874877 normalized .model ,
875878 normalized .prompt ,
876879 normalized .images ,
880+ base_url_override = request_base_url ,
877881 video_media_id = normalized .video_media_id ,
878882 )
879883 )
@@ -907,7 +911,8 @@ async def generate_content(
907911 normalized .model ,
908912 normalized .prompt ,
909913 normalized .images ,
910- request_base_url ,
914+ base_url_override = request_base_url ,
915+ video_media_id = normalized .video_media_id ,
911916 )
912917 )
913918 )
@@ -970,6 +975,20 @@ async def stream_generate_content(
970975@router .websocket ("/captcha_ws" )
971976async def captcha_websocket_endpoint (websocket : WebSocket ):
972977 from ..core .logger import debug_logger
978+ api_key = (
979+ websocket .query_params .get ("key" )
980+ or websocket .query_params .get ("api_key" )
981+ or websocket .headers .get ("x-goog-api-key" )
982+ or ""
983+ ).strip ()
984+ authorization = (websocket .headers .get ("authorization" ) or "" ).strip ()
985+ if authorization .lower ().startswith ("bearer " ):
986+ api_key = authorization [7 :].strip ()
987+
988+ if not api_key or not AuthManager .verify_api_key (api_key ):
989+ await websocket .close (code = 1008 )
990+ return
991+
973992 service = await ExtensionCaptchaService .get_instance ()
974993 await service .connect (websocket )
975994 try :
0 commit comments