diff --git a/README.md b/README.md index 6190b4e..38946fb 100644 --- a/README.md +++ b/README.md @@ -241,18 +241,35 @@ response = TextReRank.call( ### Image Generation +Image and video generation APIs use server-side asynchronous tasks: + +- `async_call()` submits a task and returns task information immediately. It is not a Python `async` coroutine. +- `call()` submits a task and blocks by polling task status until the task finishes. +- Use `fetch()` to query task status manually, or `wait()` to block until completion. +- Use `wait_timeout_seconds` with blocking calls to limit the maximum wait time. + ```python from dashscope import ImageSynthesis -# Async task pattern +# Submit a server-side async task response = ImageSynthesis.async_call( model="wanx-v1", prompt="A serene mountain landscape at sunset", ) -# Wait for result +# Query task status manually +status = ImageSynthesis.fetch(response) + +# Or wait for result result = ImageSynthesis.wait(response) +# Blocking call with timeout +result = ImageSynthesis.call( + model="wanx-v1", + prompt="A serene mountain landscape at sunset", + wait_timeout_seconds=60, +) + # Sync call (for wan2.2-t2i-flash/plus) result = ImageSynthesis.sync_call( model="wan2.2-t2i-flash", @@ -265,13 +282,21 @@ result = ImageSynthesis.sync_call( ```python from dashscope import VideoSynthesis -# Text-to-video +# Submit a server-side async task response = VideoSynthesis.async_call( model="wan2.7-t2v", prompt="A cat playing with a ball of yarn", ) +# Wait for result result = VideoSynthesis.wait(response) + +# Blocking call with timeout +result = VideoSynthesis.call( + model="wan2.7-t2v", + prompt="A cat playing with a ball of yarn", + wait_timeout_seconds=60, +) ``` ### Speech Synthesis (TTS) diff --git a/dashscope/__init__.py b/dashscope/__init__.py index 637d5de..19a32ca 100644 --- a/dashscope/__init__.py +++ b/dashscope/__init__.py @@ -28,6 +28,7 @@ base_compatible_api_url, base_http_api_url, base_websocket_api_url, + trust_env, ) from dashscope.finetune.deployments import Deployments from dashscope.finetune.finetunes import FineTunes @@ -46,7 +47,7 @@ from dashscope.files import Files from dashscope.models import Models from dashscope.nlp.understanding import Understanding -from dashscope.rerank.text_rerank import TextReRank +from dashscope.rerank import AioTextReRank, TextReRank from dashscope.threads import ( MessageFile, Messages, @@ -74,6 +75,7 @@ "base_websocket_api_url", "api_key", "api_key_file_path", + "trust_env", "save_api_key", "AioGeneration", "Conversation", @@ -106,6 +108,7 @@ "list_tokenizers", "Application", "TextReRank", + "AioTextReRank", "Assistants", "Threads", "Messages", diff --git a/dashscope/aigc/image_synthesis.py b/dashscope/aigc/image_synthesis.py index 41bd060..0491d7e 100644 --- a/dashscope/aigc/image_synthesis.py +++ b/dashscope/aigc/image_synthesis.py @@ -387,6 +387,7 @@ def wait( # type: ignore[override] task: Union[str, ImageSynthesisResponse], api_key: str = None, workspace: str = None, + **kwargs, ) -> ImageSynthesisResponse: """Wait for image(s) synthesis task to complete, and return the result. @@ -399,7 +400,12 @@ def wait( # type: ignore[override] Returns: ImageSynthesisResponse: The task result. """ - response = super().wait(task, api_key, workspace=workspace) + response = super().wait( + task, + api_key, + workspace=workspace, + **kwargs, + ) return ImageSynthesisResponse.from_api_response(response) @classmethod diff --git a/dashscope/aigc/video_synthesis.py b/dashscope/aigc/video_synthesis.py index 0bc5032..b976799 100644 --- a/dashscope/aigc/video_synthesis.py +++ b/dashscope/aigc/video_synthesis.py @@ -509,6 +509,7 @@ def wait( # type: ignore[override] task: Union[str, VideoSynthesisResponse], api_key: str = None, workspace: str = None, + **kwargs, ) -> VideoSynthesisResponse: """Wait for video synthesis task to complete, and return the result. @@ -521,7 +522,12 @@ def wait( # type: ignore[override] Returns: VideoSynthesisResponse: The task result. """ - response = super().wait(task, api_key, workspace=workspace) + response = super().wait( + task, + api_key, + workspace=workspace, + **kwargs, + ) return VideoSynthesisResponse.from_api_response(response) @classmethod diff --git a/dashscope/api_entities/aiohttp_request.py b/dashscope/api_entities/aiohttp_request.py index 75b3965..008bfa5 100644 --- a/dashscope/api_entities/aiohttp_request.py +++ b/dashscope/api_entities/aiohttp_request.py @@ -1,6 +1,7 @@ # -*- coding: utf-8 -*- # Copyright (c) Alibaba, Inc. and its affiliates. +import asyncio import json from http import HTTPStatus @@ -13,6 +14,7 @@ SSE_CONTENT_TYPE, HTTPMethod, ) +from dashscope.common.env import get_trust_env from dashscope.common.error import UnsupportedHTTPMethod from dashscope.common.logging import logger from dashscope.common.utils import async_to_sync @@ -107,25 +109,38 @@ async def aio_call(self): return result async def _handle_stream(self, response): - # TODO define done message. is_error = False status_code = HTTPStatus.BAD_REQUEST - async for line in response.content: - if line: - line = line.decode("utf8") - line = line.rstrip("\n").rstrip("\r") - if line.startswith("event:error"): - is_error = True - elif line.startswith("status:"): - status_code = line[len("status:") :] - status_code = int(status_code.strip()) - elif line.startswith("data:"): - line = line[len("data:") :] - yield (is_error, status_code, line) - if is_error: - break - else: - continue # ignore heartbeat... + event_type = None + try: + async for line in response.content: + if line: + line = line.decode("utf8") + line = line.rstrip("\n").rstrip("\r") + if line.startswith("event:"): + event_type = line[len("event:") :].strip() + if event_type == "error": + is_error = True + elif line.startswith("status:"): + status_code = line[len("status:") :] + status_code = int(status_code.strip()) + elif line.startswith("data:"): + line = line[len("data:") :] + if event_type == "done": + continue + yield (is_error, status_code, line) + if is_error: + break + else: + continue # ignore heartbeat... + except (aiohttp.ClientError, asyncio.TimeoutError): + logger.exception( + "Stream response interrupted while reading aiohttp SSE " + "response, status_code=%s, request_id=%s", + response.status, + response.headers.get("X-Request-Id"), + ) + raise # pylint: disable=too-many-statements async def _handle_response( # pylint: disable=too-many-branches @@ -249,6 +264,7 @@ async def _handle_request(self): async with aiohttp.ClientSession( timeout=aiohttp.ClientTimeout(total=self.timeout), headers=self.headers, + trust_env=get_trust_env(), ) as session: logger.debug("Starting request: %s", self.url) if self.method == HTTPMethod.POST: @@ -281,9 +297,23 @@ async def _handle_request(self): async with response: async for rsp in self._handle_response(response): yield rsp - except aiohttp.ClientConnectorError as e: - logger.error(e) - raise e - except Exception as e: - logger.error(e) - raise e + except (aiohttp.ClientError, asyncio.TimeoutError): + logger.exception( + "Aio HTTP request failed, url=%s, method=%s, stream=%s, " + "timeout=%s", + self.url, + self.method, + self.stream, + self.timeout, + ) + raise + except Exception: + logger.exception( + "Unexpected aio HTTP request error, url=%s, method=%s, " + "stream=%s, timeout=%s", + self.url, + self.method, + self.stream, + self.timeout, + ) + raise diff --git a/dashscope/api_entities/dashscope_response.py b/dashscope/api_entities/dashscope_response.py index 71f16b8..58ade6b 100644 --- a/dashscope/api_entities/dashscope_response.py +++ b/dashscope/api_entities/dashscope_response.py @@ -59,7 +59,12 @@ def setattr(self, attr, value): return super().__setitem__(attr, value) def __getattr__(self, attr): - return self[attr] + try: + return self[attr] + except KeyError: + raise AttributeError( + f"{type(self).__name__!r} object has no attribute {attr!r}", + ) from None def __setattr__(self, attr, value): self[attr] = value diff --git a/dashscope/api_entities/http_request.py b/dashscope/api_entities/http_request.py index 85ee959..203092c 100644 --- a/dashscope/api_entities/http_request.py +++ b/dashscope/api_entities/http_request.py @@ -18,6 +18,7 @@ HTTPMethod, ) from dashscope.common.error import UnsupportedHTTPMethod +from dashscope.common.env import get_trust_env from dashscope.common.logging import logger from dashscope.common.utils import ( _handle_aio_stream, @@ -176,6 +177,7 @@ async def _handle_aio_request(self): # pylint: disable=too-many-branches connector=connector, timeout=aiohttp.ClientTimeout(total=self.timeout), headers=self.headers, + trust_env=get_trust_env(), ) should_close = True @@ -223,12 +225,26 @@ async def _handle_aio_request(self): # pylint: disable=too-many-branches # Only close if we created the session if should_close: await session.close() - except aiohttp.ClientConnectorError as e: - logger.error(e) - raise e - except BaseException as e: - logger.error(e) - raise e + except aiohttp.ClientError: + logger.exception( + "Aio HTTP request failed, url=%s, method=%s, stream=%s, " + "timeout=%s", + self.url, + self.method, + self.stream, + self.timeout, + ) + raise + except Exception: + logger.exception( + "Unexpected aio HTTP request error, url=%s, method=%s, " + "stream=%s, timeout=%s", + self.url, + self.method, + self.stream, + self.timeout, + ) + raise @staticmethod def __handle_parameters(params: dict) -> dict: @@ -507,6 +523,23 @@ def _handle_request(self): # Only close if we created the session if should_close: session.close() - except BaseException as e: - logger.error(e) - raise e + except requests.exceptions.RequestException: + logger.exception( + "HTTP request failed, url=%s, method=%s, stream=%s, " + "timeout=%s", + self.url, + self.method, + self.stream, + self.timeout, + ) + raise + except Exception: + logger.exception( + "Unexpected HTTP request error, url=%s, method=%s, " + "stream=%s, timeout=%s", + self.url, + self.method, + self.stream, + self.timeout, + ) + raise diff --git a/dashscope/audio/qwen_tts_realtime/qwen_tts_realtime.py b/dashscope/audio/qwen_tts_realtime/qwen_tts_realtime.py index 7142a66..87fa290 100644 --- a/dashscope/audio/qwen_tts_realtime/qwen_tts_realtime.py +++ b/dashscope/audio/qwen_tts_realtime/qwen_tts_realtime.py @@ -97,13 +97,19 @@ def __init__( self.user_workspace = workspace self.model = model self.config = {} - self.callback = callback + self.callback = callback or QwenTtsRealtimeCallback() self.ws = None + self.thread = None self.session_id = None self.last_message = None self.last_response_id = None self.last_first_text_time = None self.last_first_audio_delay = None + self.last_error = None + self.close_status_code = None + self.close_msg = None + self.session_created_event = threading.Event() + self.websocket_closed_event = threading.Event() self.metrics = [] def _generate_event_id(self): @@ -131,6 +137,11 @@ def connect(self) -> None: """ connect to server, create session and return default session configuration # noqa: E501 """ + self.last_error = None + self.close_status_code = None + self.close_msg = None + self.session_created_event.clear() + self.websocket_closed_event.clear() self.ws = websocket.WebSocketApp( self.url, header=self._get_websocket_header(), @@ -141,23 +152,45 @@ def connect(self) -> None: self.thread = threading.Thread(target=self.ws.run_forever) self.thread.daemon = True self.thread.start() - timeout = 5 # 最长等待时间(秒) + timeout = 5 start_time = time.time() while ( not (self.ws.sock and self.ws.sock.connected) + and not self.websocket_closed_event.is_set() and (time.time() - start_time) < timeout ): - time.sleep(0.1) # 短暂休眠,避免密集轮询 - if not (self.ws.sock and self.ws.sock.connected): + time.sleep(0.1) + if not self._is_websocket_connected(): raise TimeoutError( "websocket connection could not established within 5s. " - "Please check your network connection, firewall settings, or server status.", # noqa: E501 # pylint: disable=line-too-long + f"{self._build_connection_state_message()}", + ) + if not self.session_created_event.wait(timeout): + raise TimeoutError( + "websocket session could not be created within 5s. " + f"{self._build_connection_state_message()}", ) self.callback.on_open() + def _is_websocket_connected(self): + return bool(self.ws and self.ws.sock and self.ws.sock.connected) + + def _build_connection_state_message(self): + return ( + f"close_status_code: {self.close_status_code}, " + f"close_msg: {self.close_msg}, " + f"last_error: {self.last_error}, " + f"last_message: {self.last_message}" + ) + def __send_str(self, data: str, enable_log: bool = True): if enable_log: logger.debug("[qwen tts realtime] send string: %s", data) + if not self._is_websocket_connected(): + raise ConnectionError( + "qwen tts realtime websocket connection is closed. " + f"{self._build_connection_state_message()}", + ) self.ws.send(data) def update_session( @@ -351,6 +384,7 @@ def on_message( # pylint: disable=unused-argument if "type" in message: if "session.created" == json_data["type"]: self.session_id = json_data["session"]["id"] + self.session_created_event.set() if "response.created" == json_data["type"]: self.last_response_id = json_data["response"]["id"] elif "response.audio.delta" == json_data["type"]: @@ -387,8 +421,11 @@ def on_close( # pylint: disable=unused-argument close_status_code, close_msg, ): + self.close_status_code = close_status_code + self.close_msg = close_msg + self.websocket_closed_event.set() logger.debug( - "[omni realtime] connection closed with code %s and message %s", # noqa: E501 + "[qwen tts realtime] connection closed with code %s and message %s", # noqa: E501 close_status_code, close_msg, ) @@ -396,9 +433,8 @@ def on_close( # pylint: disable=unused-argument # WebSocket发生错误的回调函数 def on_error(self, ws, error): # pylint: disable=unused-argument - print(f"websocket closed due to {error}") - # pylint: disable=broad-exception-raised - raise Exception(f"websocket closed due to {error}") + self.last_error = error + logger.error("[qwen tts realtime] websocket closed due to %s", error) # 获取上一个任务的taskId def get_session_id(self): diff --git a/dashscope/client/base_api.py b/dashscope/client/base_api.py index cc52fa2..4ad3192 100644 --- a/dashscope/client/base_api.py +++ b/dashscope/client/base_api.py @@ -20,7 +20,12 @@ TaskStatus, HTTPMethod, ) -from dashscope.common.error import InvalidParameter, InvalidTask, ModelRequired +from dashscope.common.error import ( + InvalidParameter, + InvalidTask, + ModelRequired, + TimeoutException, +) from dashscope.common.logging import logger from dashscope.common.utils import ( _handle_http_failed_response, @@ -148,6 +153,7 @@ async def call( workspace: str = None, **kwargs, ) -> DashScopeAPIResponse: + wait_timeout_seconds = kwargs.pop("wait_timeout_seconds", None) # call request service. response = await BaseAsyncAioApi.async_call( model, @@ -159,11 +165,14 @@ async def call( workspace, **kwargs, ) + wait_kwargs = kwargs.copy() + if wait_timeout_seconds is not None: + wait_kwargs["wait_timeout_seconds"] = wait_timeout_seconds response = await BaseAsyncAioApi.wait( response, api_key=api_key, workspace=workspace, - **kwargs, + **wait_kwargs, ) return response @@ -202,6 +211,10 @@ async def wait( Returns: DashScopeAPIResponse: The async task information. """ + wait_timeout_seconds = kwargs.pop("wait_timeout_seconds", None) + if wait_timeout_seconds is not None: + wait_timeout_seconds = float(wait_timeout_seconds) + start_time = time.monotonic() task_id = cls._get_task_id(task) wait_seconds = 1 max_wait_seconds = 5 @@ -236,6 +249,12 @@ async def wait( return rsp else: logger.info("The task %s is %s", task_id, task_status) + if ( + wait_timeout_seconds is not None + and time.monotonic() - start_time + >= wait_timeout_seconds + ): + raise TimeoutException(f"Wait task {task_id} timeout.") await asyncio.sleep(wait_seconds) # 异步等待 elif rsp.status_code in REPEATABLE_STATUS: logger.warning( @@ -246,6 +265,11 @@ async def wait( rsp.code, rsp.message, ) + if ( + wait_timeout_seconds is not None + and time.monotonic() - start_time >= wait_timeout_seconds + ): + raise TimeoutException(f"Wait task {task_id} timeout.") await asyncio.sleep(wait_seconds) # 异步等待 else: return rsp @@ -432,7 +456,7 @@ async def call( function (str, optional): The function of the task. Defaults to None. api_key (str, optional): The api api_key, if not present, - will get by default rule(TODO: api key doc). Defaults to None. + will use the default API key resolution rule. Defaults to None. api_protocol (str, optional): Api protocol websocket or http. Defaults to None. ws_stream_mode (str, optional): websocket stream mode, @@ -498,7 +522,7 @@ def call( function (str, optional): The function of the task. Defaults to None. api_key (str, optional): The api api_key, if not present, - will get by default rule(TODO: api key doc). Defaults to None. + will use the default API key resolution rule. Defaults to None. api_protocol (str, optional): Api protocol websocket or http. Defaults to None. ws_stream_mode (str, optional): websocket stream mode, @@ -599,16 +623,21 @@ def call( **kwargs, ) -> DashScopeAPIResponse: """Call service and get result.""" + wait_timeout_seconds = kwargs.pop("wait_timeout_seconds", None) task_response = cls.async_call( # type: ignore[misc] *args, api_key=api_key, workspace=workspace, **kwargs, ) + wait_kwargs = {} + if wait_timeout_seconds is not None: + wait_kwargs["wait_timeout_seconds"] = wait_timeout_seconds response = cls.wait( task_response, api_key=api_key, workspace=workspace, + **wait_kwargs, ) return response @@ -778,6 +807,10 @@ def wait( Returns: DashScopeAPIResponse: The async task information. """ + wait_timeout_seconds = kwargs.pop("wait_timeout_seconds", None) + if wait_timeout_seconds is not None: + wait_timeout_seconds = float(wait_timeout_seconds) + start_time = time.monotonic() task_id = cls._get_task_id(task) wait_seconds = 1 max_wait_seconds = 5 @@ -789,8 +822,8 @@ def wait( # the query interval after every 3(increment_steps) # intervals, until we hit the max waiting interval # of 5(seconds) - # TODO: investigate if we can use long-poll - # (server side return immediately when ready) + # Polling is used here because the task status API returns the + # current state for each request. if wait_seconds < max_wait_seconds and step % increment_steps == 0: wait_seconds = min(wait_seconds * 2, max_wait_seconds) rsp = cls._get(task_id, api_key, workspace=workspace, **kwargs) @@ -808,6 +841,12 @@ def wait( return rsp else: logger.info("The task %s is %s", task_id, task_status) + if ( + wait_timeout_seconds is not None + and time.monotonic() - start_time + >= wait_timeout_seconds + ): + raise TimeoutException(f"Wait task {task_id} timeout.") time.sleep(wait_seconds) elif rsp.status_code in REPEATABLE_STATUS: logger.warning( @@ -818,6 +857,11 @@ def wait( rsp.code, rsp.message, ) + if ( + wait_timeout_seconds is not None + and time.monotonic() - start_time >= wait_timeout_seconds + ): + raise TimeoutException(f"Wait task {task_id} timeout.") time.sleep(wait_seconds) else: return rsp @@ -844,7 +888,7 @@ def async_call( function (str, optional): The function of the task. Defaults to None. api_key (str, optional): The api api_key, if not present, - will get by default rule(TODO: api key doc). Defaults to None. + will use the default API key resolution rule. Defaults to None. Returns: DashScopeAPIResponse: The async task information, @@ -987,7 +1031,7 @@ def list( Args: api_key (str, optional): The api api_key, if not present, - will get by default rule(TODO: api key doc). Defaults to None. + will use the default API key resolution rule. Defaults to None. path (str, optional): The path of the api, if not default. page_no (int, optional): Page number. Defaults to 1. page_size (int, optional): Items per page. Defaults to 10. @@ -1063,7 +1107,7 @@ def get( Args: target (str): The target to get, such as model_id. api_key (str, optional): The api api_key, if not present, - will get by default rule(TODO: api key doc). Defaults to None. + will use the default API key resolution rule. Defaults to None. Returns: DashScopeAPIResponse: The object information in output. @@ -1104,7 +1148,7 @@ def get( Args: target (str): The target to get, such as model_id. api_key (str, optional): The api api_key, if not present, - will get by default rule(TODO: api key doc). Defaults to None. + will use the default API key resolution rule. Defaults to None. Returns: DashScopeAPIResponse: The object information in output. @@ -1144,7 +1188,7 @@ def delete( Args: target (str): The object to delete, . api_key (str, optional): The api api_key, if not present, - will get by default rule(TODO: api key doc). Defaults to None. + will use the default API key resolution rule. Defaults to None. Returns: DashScopeAPIResponse: The delete result. @@ -1193,7 +1237,7 @@ def call( Args: data (object): The create request json body. api_key (str, optional): The api api_key, if not present, - will get by default rule(TODO: api key doc). Defaults to None. + will use the default API key resolution rule. Defaults to None. Returns: DashScopeAPIResponse: The created object in output. @@ -1252,7 +1296,7 @@ def update( target (str): The target to update. json (object): The create request json body. api_key (str, optional): The api api_key, if not present, - will get by default rule(TODO: api key doc). Defaults to None. + will use the default API key resolution rule. Defaults to None. Returns: DashScopeAPIResponse: The updated object information in output. @@ -1316,7 +1360,7 @@ def put( target (str): The target to update. json (object): The create request json body. api_key (str, optional): The api api_key, if not present, - will get by default rule(TODO: api key doc). Defaults to None. + will use the default API key resolution rule. Defaults to None. Returns: DashScopeAPIResponse: The updated object information in output. @@ -1368,7 +1412,7 @@ def upload( # pylint: disable=unused-argument descriptions (list[str]): The file description messages. params (dict): The parameters api_key (str, optional): The api api_key, if not present, - will get by default rule(TODO: api key doc). Defaults to None. + will use the default API key resolution rule. Defaults to None. Returns: DashScopeAPIResponse: The uploaded file information in the output. @@ -1418,7 +1462,7 @@ def cancel( Args: target (str): The request params, key/value map. api_key (str, optional): The api api_key, if not present, - will get by default rule(TODO: api key doc). Defaults to None. + will use the default API key resolution rule. Defaults to None. Returns: DashScopeAPIResponse: The cancel result. @@ -1455,25 +1499,38 @@ def cancel( class StreamEventMixin: @classmethod def _handle_stream(cls, response: requests.Response): - # TODO define done message. is_error = False status_code = HTTPStatus.INTERNAL_SERVER_ERROR - for line in response.iter_lines(): - if line: - line = line.decode("utf8") - line = line.rstrip("\n").rstrip("\r") - if line.startswith("event:error"): - is_error = True - elif line.startswith("status:"): - status_code = line[len("status:") :] - status_code = int(status_code.strip()) - elif line.startswith("data:"): - line = line[len("data:") :] - yield (is_error, status_code, line) - if is_error: - break - else: - continue # ignore heartbeat... + event_type = None + try: + for line in response.iter_lines(): + if line: + line = line.decode("utf8") + line = line.rstrip("\n").rstrip("\r") + if line.startswith("event:"): + event_type = line[len("event:") :].strip() + if event_type == "error": + is_error = True + elif line.startswith("status:"): + status_code = line[len("status:") :] + status_code = int(status_code.strip()) + elif line.startswith("data:"): + line = line[len("data:") :] + if event_type == "done": + continue + yield (is_error, status_code, line) + if is_error: + break + else: + continue # ignore heartbeat... + except requests.exceptions.RequestException: + logger.exception( + "Stream response interrupted while reading SSE response, " + "status_code=%s, request_id=%s", + response.status_code, + response.headers.get("X-Request-Id"), + ) + raise @classmethod def _handle_response(cls, response: requests.Response): @@ -1529,7 +1586,7 @@ def stream_events( Args: target (str): The target to get, such as model_id. api_key (str, optional): The api api_key, if not present, - will get by default rule(TODO: api key doc). Defaults to None. + will use the default API key resolution rule. Defaults to None. Returns: DashScopeAPIResponse: The target outputs. diff --git a/dashscope/common/env.py b/dashscope/common/env.py index bf9a0f9..8498cc6 100644 --- a/dashscope/common/env.py +++ b/dashscope/common/env.py @@ -2,6 +2,7 @@ # Copyright (c) Alibaba, Inc. and its affiliates. import os +import sys from dashscope.common.constants import ( DASHSCOPE_API_KEY_ENV, @@ -15,6 +16,12 @@ # read the api key from env api_key = os.environ.get(DASHSCOPE_API_KEY_ENV) api_key_file_path = os.environ.get(DASHSCOPE_API_KEY_FILE_PATH_ENV) +trust_env = os.environ.get("DASHSCOPE_TRUST_ENV", "true").lower() in ( + "true", + "1", + "yes", +) + # define api base url, ensure end / base_http_api_url = os.environ.get( @@ -29,3 +36,8 @@ "DASHSCOPE_COMPATIBLE_BASE_URL", f"https://dashscope.aliyuncs.com/compatible-mode/{api_version}", ) + + +def get_trust_env() -> bool: + dashscope_module = sys.modules.get("dashscope") + return bool(getattr(dashscope_module, "trust_env", trust_env)) diff --git a/dashscope/common/utils.py b/dashscope/common/utils.py index d568446..f740c9b 100644 --- a/dashscope/common/utils.py +++ b/dashscope/common/utils.py @@ -229,36 +229,45 @@ def __init__( # pylint: disable=redefined-builtin def _handle_stream(response: requests.Response): - # TODO define done message. is_error = False status_code = HTTPStatus.BAD_REQUEST event = SSEEvent(None, None, None) # type: ignore[arg-type] eventType = None - for line in response.iter_lines(): - if line: - line = line.decode("utf8") - line = line.rstrip("\n").rstrip("\r") - if line.startswith("id:"): - id = line[len("id:") :] # pylint: disable=redefined-builtin - event.id = id.strip() - elif line.startswith("event:"): - eventType = line[len("event:") :] - event.eventType = eventType.strip() - if eventType == "error": - is_error = True - elif line.startswith("status:"): - status_code = line[len("status:") :] - status_code = int(status_code.strip()) - elif line.startswith("data:"): - line = line[len("data:") :] - event.data = line.strip() - if eventType is not None and eventType == "done": - continue - yield (is_error, status_code, event) - if is_error: - break - else: - continue # ignore heartbeat... + + try: + for line in response.iter_lines(): + if line: + line = line.decode("utf8") + line = line.rstrip("\n").rstrip("\r") + if line.startswith("id:"): + event_id = line[len("id:") :] + event.id = event_id.strip() + elif line.startswith("event:"): + eventType = line[len("event:") :].strip() + event.eventType = eventType + if eventType == "error": + is_error = True + elif line.startswith("status:"): + status_code = line[len("status:") :] + status_code = int(status_code.strip()) + elif line.startswith("data:"): + line = line[len("data:") :] + event.data = line.strip() + if eventType is not None and eventType == "done": + continue + yield (is_error, status_code, event) + if is_error: + break + else: + continue # ignore heartbeat... + except requests.exceptions.RequestException: + logger.exception( + "Stream response interrupted while reading SSE response, " + "status_code=%s, request_id=%s", + response.status_code, + response.headers.get("X-Request-Id"), + ) + raise def _handle_error_message(error, status_code, flattened_output, headers): @@ -331,25 +340,38 @@ def _handle_http_failed_response( async def _handle_aio_stream(response): - # TODO define done message. is_error = False status_code = HTTPStatus.BAD_REQUEST - async for line in response.content: - if line: - line = line.decode("utf8") - line = line.rstrip("\n").rstrip("\r") - if line.startswith("event:error"): - is_error = True - elif line.startswith("status:"): - status_code = line[len("status:") :] - status_code = int(status_code.strip()) - elif line.startswith("data:"): - line = line[len("data:") :] - yield (is_error, status_code, line) - if is_error: - break - else: - continue # ignore heartbeat... + event_type = None + try: + async for line in response.content: + if line: + line = line.decode("utf8") + line = line.rstrip("\n").rstrip("\r") + if line.startswith("event:"): + event_type = line[len("event:") :].strip() + if event_type == "error": + is_error = True + elif line.startswith("status:"): + status_code = line[len("status:") :] + status_code = int(status_code.strip()) + elif line.startswith("data:"): + line = line[len("data:") :] + if event_type == "done": + continue + yield (is_error, status_code, line) + if is_error: + break + else: + continue # ignore heartbeat... + except (aiohttp.ClientError, asyncio.TimeoutError): + logger.exception( + "Stream response interrupted while reading aiohttp SSE " + "response, status_code=%s, request_id=%s", + response.status, + response.headers.get("X-Request-Id"), + ) + raise async def _handle_aiohttp_failed_response( diff --git a/dashscope/embeddings/batch_text_embedding.py b/dashscope/embeddings/batch_text_embedding.py index d3af2d1..c2c9b9c 100644 --- a/dashscope/embeddings/batch_text_embedding.py +++ b/dashscope/embeddings/batch_text_embedding.py @@ -143,6 +143,7 @@ def wait( # type: ignore[override] task: Union[str, BatchTextEmbeddingResponse], api_key: str = None, workspace: str = None, + **kwargs, ) -> BatchTextEmbeddingResponse: """Wait for async text embedding task to complete, and return the result. # noqa: E501 @@ -155,7 +156,12 @@ def wait( # type: ignore[override] Returns: AsyncTextEmbeddingResponse: The task result. """ - response = super().wait(task, api_key, workspace=workspace) + response = super().wait( + task, + api_key, + workspace=workspace, + **kwargs, + ) return BatchTextEmbeddingResponse.from_api_response(response) @classmethod diff --git a/dashscope/rerank/__init__.py b/dashscope/rerank/__init__.py index e69de29..2f0829c 100644 --- a/dashscope/rerank/__init__.py +++ b/dashscope/rerank/__init__.py @@ -0,0 +1,6 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Alibaba, Inc. and its affiliates. + +from dashscope.rerank.text_rerank import AioTextReRank, TextReRank + +__all__ = ["AioTextReRank", "TextReRank"] diff --git a/dashscope/rerank/text_rerank.py b/dashscope/rerank/text_rerank.py index b152ae6..3a4b31e 100644 --- a/dashscope/rerank/text_rerank.py +++ b/dashscope/rerank/text_rerank.py @@ -1,13 +1,46 @@ # -*- coding: utf-8 -*- # Copyright (c) Alibaba, Inc. and its affiliates. -from typing import List +from typing import Any, Dict, List, Tuple from dashscope.api_entities.dashscope_response import ReRankResponse -from dashscope.client.base_api import BaseApi +from dashscope.client.base_api import BaseAioApi, BaseApi from dashscope.common.error import InputRequired, ModelRequired from dashscope.common.utils import _get_task_group_and_task +__all__ = ["TextReRank", "AioTextReRank"] + + +def _build_rerank_request( + model: str, + query: str, + documents: List[str], + return_documents: bool = None, + top_n: int = None, + instruct: str = None, + **kwargs, +) -> Tuple[str, str, Dict[str, Any], Dict[str, Any]]: + if query is None or documents is None or not documents: + raise InputRequired("query and documents are required!") + if model is None or not model: + raise ModelRequired("Model is required!") + + task_group, function = _get_task_group_and_task(__name__) + rerank_input = { + "query": query, + "documents": documents, + } + parameters = {} + if return_documents is not None: + parameters["return_documents"] = return_documents + if top_n is not None: + parameters["top_n"] = top_n + if instruct is not None: + parameters["instruct"] = instruct + parameters = {**parameters, **kwargs} + + return task_group, function, rerank_input, parameters + class TextReRank(BaseApi): task = "text-rerank" @@ -41,8 +74,8 @@ def call( # type: ignore[override] # pylint: disable=arguments-renamed documents (List[str]): The documents to rank. return_documents(bool, `optional`): enable return origin documents, system default is false. - top_n(int, `optional`): how many documents to return, default return # noqa: E501 - all the documents. + top_n(int, `optional`): how many documents to return, + default return all the documents. api_key (str, optional): The DashScope api key. Defaults to None. instruct (str, optional): Custom task instruction to guide ranking strategy. English recommended. @@ -55,23 +88,15 @@ def call( # type: ignore[override] # pylint: disable=arguments-renamed RerankResponse: The rerank result. """ - if query is None or documents is None or not documents: - raise InputRequired("query and documents are required!") - if model is None or not model: - raise ModelRequired("Model is required!") - task_group, function = _get_task_group_and_task(__name__) - input = { # pylint: disable=redefined-builtin - "query": query, - "documents": documents, - } - parameters = {} - if return_documents is not None: - parameters["return_documents"] = return_documents - if top_n is not None: - parameters["top_n"] = top_n - if instruct is not None: - parameters["instruct"] = instruct - parameters = {**parameters, **kwargs} + task_group, function, rerank_input, parameters = _build_rerank_request( + model=model, + query=query, + documents=documents, + return_documents=return_documents, + top_n=top_n, + instruct=instruct, + **kwargs, + ) response = super().call( model=model, @@ -79,7 +104,73 @@ def call( # type: ignore[override] # pylint: disable=arguments-renamed task=TextReRank.task, function=function, api_key=api_key, - input=input, + input=rerank_input, + **parameters, # type: ignore[arg-type] + ) + + return ReRankResponse.from_api_response(response) + + +class AioTextReRank(BaseAioApi): + task = "text-rerank" + """Async API for rerank models.""" + + Models = TextReRank.Models + + @classmethod + # pylint: disable=arguments-renamed + async def call( # type: ignore[override] + cls, + model: str, + query: str, + documents: List[str], + return_documents: bool = None, + top_n: int = None, + api_key: str = None, + workspace: str = None, + instruct: str = None, + **kwargs, + ) -> ReRankResponse: + """Calling rerank service asynchronously. + + Args: + model (str): The model to use. + query (str): The query string. + documents (List[str]): The documents to rank. + return_documents(bool, `optional`): enable return origin documents, + system default is false. + top_n(int, `optional`): how many documents to return, + default return all the documents. + api_key (str, optional): The DashScope api key. Defaults to None. + workspace (str, optional): The DashScope workspace id. + instruct (str, optional): Custom task instruction to guide + ranking strategy. English recommended. + + Raises: + InputRequired: The query and documents are required. + ModelRequired: The model is required. + + Returns: + RerankResponse: The rerank result. + """ + task_group, function, rerank_input, parameters = _build_rerank_request( + model=model, + query=query, + documents=documents, + return_documents=return_documents, + top_n=top_n, + instruct=instruct, + **kwargs, + ) + + response = await super().call( + model=model, + task_group=task_group, + task=AioTextReRank.task, + function=function, + api_key=api_key, + workspace=workspace, + input=rerank_input, **parameters, # type: ignore[arg-type] ) diff --git a/dashscope/utils/oss_utils.py b/dashscope/utils/oss_utils.py index 216333e..27d2683 100644 --- a/dashscope/utils/oss_utils.py +++ b/dashscope/utils/oss_utils.py @@ -129,6 +129,27 @@ def get_upload_certificate( return super().get(None, api_key, params=params, **kwargs) # type: ignore[return-value] # pylint: disable=line-too-long # noqa: E501 +def _resolve_file_uri_path(file_uri: str): + parse_result = urlparse(file_uri) + netloc = parse_result.netloc + if netloc.lower() in ("localhost", "127.0.0.1"): + netloc = "" + + if netloc: + file_path = netloc + unquote_plus(parse_result.path) + else: + file_path = unquote_plus(parse_result.path) + + if ( + file_path.startswith("/") + and len(file_path) > 2 + and file_path[2] == ":" + ): + file_path = file_path[1:] + + return os.path.expanduser(file_path) + + def upload_file( model: str, upload_path: str, @@ -136,11 +157,7 @@ def upload_file( upload_certificate: dict = None, ): if upload_path.startswith(FILE_PATH_SCHEMA): - parse_result = urlparse(upload_path) - if parse_result.netloc: - file_path = parse_result.netloc + unquote_plus(parse_result.path) - else: - file_path = unquote_plus(parse_result.path) + file_path = _resolve_file_uri_path(upload_path) if os.path.exists(file_path): file_url, _ = OssUtils.upload( model=model, @@ -154,7 +171,7 @@ def upload_file( ) return file_url else: - raise InvalidInput(f"The file: {file_path} is not exists!") + raise InvalidInput(f"The file: {file_path} does not exist!") return None @@ -184,11 +201,7 @@ def check_and_upload_local( is the certificate (newly obtained or passed in) """ if content.startswith(FILE_PATH_SCHEMA): - parse_result = urlparse(content) - if parse_result.netloc: - file_path = parse_result.netloc + unquote_plus(parse_result.path) - else: - file_path = unquote_plus(parse_result.path) + file_path = _resolve_file_uri_path(content) if os.path.isfile(file_path): file_url, cert = OssUtils.upload( model=model, @@ -201,9 +214,10 @@ def check_and_upload_local( f"Uploading file: {content} failed", ) return True, file_url, cert - elif content.startswith("oss://"): + raise InvalidInput(f"The file: {file_path} does not exist!") + if content.startswith("oss://"): return True, content, upload_certificate - elif not content.startswith("http"): + if not content.startswith("http"): content = os.path.expanduser(content) if os.path.isfile(content): file_url, cert = OssUtils.upload( diff --git a/tests/unit/test_async_custom_session.py b/tests/unit/test_async_custom_session.py index 180b478..012b77e 100644 --- a/tests/unit/test_async_custom_session.py +++ b/tests/unit/test_async_custom_session.py @@ -24,6 +24,7 @@ import certifi import pytest +import dashscope from dashscope.api_entities.http_request import HttpRequest from dashscope.api_entities.api_request_data import ApiRequestData from dashscope.common.constants import ApiProtocol, HTTPMethod @@ -577,6 +578,57 @@ async def mock_handle_response(_response): # 验证临时 aio_session 被关闭(原有行为) mock_session.close.assert_called_once() + @pytest.mark.asyncio + async def test_temporary_aio_session_uses_global_trust_env(self): + """测试临时 aio_session 会使用全局 trust_env 配置""" + mock_session = AsyncMock() + mock_response = AsyncMock() + mock_response.status = 200 + mock_response.headers = {"content-type": "application/json"} + mock_response.__aenter__ = AsyncMock(return_value=mock_response) + mock_response.__aexit__ = AsyncMock(return_value=None) + mock_session.request.return_value = mock_response + + http_request = HttpRequest( + url="http://example.com/api", + api_key="fake-api-key", + http_method=HTTPMethod.POST, + stream=False, + ) + http_request.data = ApiRequestData( + model="test-model", + task_group="test", + task="test", + function="test", + input_data={"test": "data"}, + form=None, + is_binary_input=False, + api_protocol=ApiProtocol.HTTPS, + ) + + original_trust_env = dashscope.trust_env + dashscope.trust_env = False + try: + + async def mock_handle_response(_response): + yield mock_response + + with patch( + "aiohttp.ClientSession", + return_value=mock_session, + ) as session_class: + with patch.object( + http_request, + "_handle_aio_response", + side_effect=mock_handle_response, + ): + _ = await http_request.aio_call() + + session_class.assert_called_once() + assert session_class.call_args.kwargs["trust_env"] is False + finally: + dashscope.trust_env = original_trust_env + class TestAsyncSessionLifecycle: """测试异步 Session 生命周期""" diff --git a/tests/unit/test_async_task_wait_timeout.py b/tests/unit/test_async_task_wait_timeout.py new file mode 100644 index 0000000..5421905 --- /dev/null +++ b/tests/unit/test_async_task_wait_timeout.py @@ -0,0 +1,269 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Alibaba, Inc. and its affiliates. + +from http import HTTPStatus +from unittest.mock import AsyncMock, patch + +import pytest + +from dashscope.aigc.image_synthesis import ImageSynthesis +from dashscope.aigc.video_synthesis import VideoSynthesis +from dashscope.api_entities.dashscope_response import DashScopeAPIResponse +from dashscope.client.base_api import BaseAsyncAioApi, BaseAsyncApi +from dashscope.embeddings.batch_text_embedding import BatchTextEmbedding +from dashscope.common.constants import TaskStatus +from dashscope.common.error import TimeoutException + + +class TimeoutWaitTestAsyncApi(BaseAsyncApi): + pass + + +class TimeoutCallTestAsyncApi(BaseAsyncApi): + captured_async_call_kwargs = {} + captured_wait_kwargs = {} + + @classmethod + def async_call(cls, *_args, **kwargs): + cls.captured_async_call_kwargs = kwargs + return DashScopeAPIResponse( + request_id="request-id", + status_code=HTTPStatus.OK, + code=None, + output={"task_id": "task-id"}, + usage=None, + message="", + ) + + @classmethod + def wait(cls, task, api_key=None, workspace=None, **kwargs): + cls.captured_wait_kwargs = kwargs + return task + + +class LegacyWaitSignatureTestAsyncApi(BaseAsyncApi): + @classmethod + def async_call(cls, *_args, **_kwargs): + return DashScopeAPIResponse( + request_id="request-id", + status_code=HTTPStatus.OK, + code=None, + output={"task_id": "task-id"}, + usage=None, + message="", + ) + + @classmethod + def wait(cls, task, api_key=None, workspace=None): + return task + + +class TimeoutTestAsyncAioApi(BaseAsyncAioApi): + pass + + +@pytest.fixture(autouse=True) +def reset_timeout_test_api(): + TimeoutCallTestAsyncApi.captured_async_call_kwargs = {} + TimeoutCallTestAsyncApi.captured_wait_kwargs = {} + + +class TestAsyncTaskWaitTimeout: + def test_base_async_api_wait_raises_timeout(self): + response = DashScopeAPIResponse( + request_id="request-id", + status_code=HTTPStatus.OK, + code=None, + output={"task_status": TaskStatus.RUNNING}, + usage=None, + message="", + ) + + with patch.object( + TimeoutWaitTestAsyncApi, + "_get", + return_value=response, + ): + with pytest.raises(TimeoutException): + TimeoutWaitTestAsyncApi.wait("task-id", wait_timeout_seconds=0) + + def test_base_async_api_wait_accepts_string_timeout(self): + response = DashScopeAPIResponse( + request_id="request-id", + status_code=HTTPStatus.OK, + code=None, + output={"task_status": TaskStatus.RUNNING}, + usage=None, + message="", + ) + + with patch.object( + TimeoutWaitTestAsyncApi, + "_get", + return_value=response, + ): + with pytest.raises(TimeoutException): + TimeoutWaitTestAsyncApi.wait( + "task-id", + wait_timeout_seconds="0", + ) + + @pytest.mark.asyncio + async def test_base_async_aio_api_wait_raises_timeout(self): + response = DashScopeAPIResponse( + request_id="request-id", + status_code=HTTPStatus.OK, + code=None, + output={"task_status": TaskStatus.RUNNING}, + usage=None, + message="", + ) + + with patch.object( + TimeoutTestAsyncAioApi, + "_get", + AsyncMock(return_value=response), + ): + with pytest.raises(TimeoutException): + await TimeoutTestAsyncAioApi.wait( + "task-id", + wait_timeout_seconds=0, + ) + + @pytest.mark.asyncio + async def test_base_async_aio_api_wait_accepts_string_timeout(self): + response = DashScopeAPIResponse( + request_id="request-id", + status_code=HTTPStatus.OK, + code=None, + output={"task_status": TaskStatus.RUNNING}, + usage=None, + message="", + ) + + with patch.object( + TimeoutTestAsyncAioApi, + "_get", + AsyncMock(return_value=response), + ): + with pytest.raises(TimeoutException): + await TimeoutTestAsyncAioApi.wait( + "task-id", + wait_timeout_seconds="0", + ) + + def test_base_async_call_does_not_pass_default_wait_timeout( + self, + ): + response = LegacyWaitSignatureTestAsyncApi.call( + "model", + "input", + api_key="api-key", + ) + + assert response.output["task_id"] == "task-id" + + def test_base_async_call_excludes_wait_timeout_from_request( + self, + ): + response = TimeoutCallTestAsyncApi.call( + "model", + "input", + api_key="api-key", + wait_timeout_seconds=10, + custom_param="custom-value", + ) + + assert response.output["task_id"] == "task-id" + assert ( + "wait_timeout_seconds" + not in TimeoutCallTestAsyncApi.captured_async_call_kwargs + ) + assert ( + TimeoutCallTestAsyncApi.captured_async_call_kwargs["custom_param"] + == "custom-value" + ) + assert ( + TimeoutCallTestAsyncApi.captured_wait_kwargs[ + "wait_timeout_seconds" + ] + == 10 + ) + + @pytest.mark.asyncio + async def test_base_async_aio_call_excludes_wait_timeout_from_request( + self, + ): + async_call_response = DashScopeAPIResponse( + request_id="request-id", + status_code=HTTPStatus.OK, + code=None, + output={"task_id": "task-id"}, + usage=None, + message="", + ) + wait_response = DashScopeAPIResponse( + request_id="request-id", + status_code=HTTPStatus.OK, + code=None, + output={"task_status": TaskStatus.SUCCEEDED}, + usage=None, + message="", + ) + + with patch.object( + BaseAsyncAioApi, + "async_call", + AsyncMock(return_value=async_call_response), + ) as async_call_mock: + with patch.object( + BaseAsyncAioApi, + "wait", + AsyncMock(return_value=wait_response), + ) as wait_mock: + response = await BaseAsyncAioApi.call( + "model", + "input", + "task-group", + api_key="api-key", + wait_timeout_seconds=10, + custom_param="custom-value", + ) + + assert response is wait_response + assert "wait_timeout_seconds" not in async_call_mock.call_args.kwargs + assert ( + async_call_mock.call_args.kwargs["custom_param"] == "custom-value" + ) + assert wait_mock.call_args.kwargs["wait_timeout_seconds"] == 10 + + @pytest.mark.parametrize( + "api_class", + [ImageSynthesis, VideoSynthesis, BatchTextEmbedding], + ) + def test_overridden_wait_accepts_wait_timeout( + self, + api_class, + ): + wait_response = DashScopeAPIResponse( + request_id="request-id", + status_code=HTTPStatus.BAD_REQUEST, + code="InvalidParameter", + output=None, + usage=None, + message="invalid parameter", + ) + + with patch.object( + BaseAsyncApi, + "wait", + return_value=wait_response, + ) as wait_mock: + response = api_class.wait( + "task-id", + api_key="api-key", + wait_timeout_seconds=10, + ) + + assert response.status_code == HTTPStatus.BAD_REQUEST + assert wait_mock.call_args.kwargs["wait_timeout_seconds"] == 10 diff --git a/tests/unit/test_dashscope_response.py b/tests/unit/test_dashscope_response.py new file mode 100644 index 0000000..ed7a135 --- /dev/null +++ b/tests/unit/test_dashscope_response.py @@ -0,0 +1,21 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Alibaba, Inc. and its affiliates. + +from dashscope.api_entities.dashscope_response import DictMixin + + +class TestDictMixin: + def test_getattr_missing_key_raises_attribute_error(self): + response = DictMixin(existing="value") + + try: + response.missing + except AttributeError: + return + + raise AssertionError("Missing attribute should raise AttributeError") + + def test_getattr_existing_key_returns_value(self): + response = DictMixin(existing="value") + + assert response.existing == "value" diff --git a/tests/unit/test_oss_utils.py b/tests/unit/test_oss_utils.py new file mode 100644 index 0000000..2d740ce --- /dev/null +++ b/tests/unit/test_oss_utils.py @@ -0,0 +1,189 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Alibaba, Inc. and its affiliates. + +from http import HTTPStatus + +import pytest + +from dashscope.common.error import InvalidInput +from dashscope.utils import oss_utils +from dashscope.utils.oss_utils import OssUtils + + +class FakeUploadResponse: + status_code = HTTPStatus.OK + headers = {} + + +class FakeSession: + captured_file = None + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, traceback): + return False + + def post(self, url, files, data, headers, timeout): + assert url == "https://oss.example.com" + assert data["key"] == "test-dir/dogs.jpg" + assert headers["Accept"] == "application/json" + assert timeout == 3600 + + FakeSession.captured_file = files["file"] + assert not FakeSession.captured_file.closed + return FakeUploadResponse() + + +class TestOssUtils: + def test_upload_closes_opened_file(self, monkeypatch): + upload_certificate = { + "oss_access_key_id": "access-key-id", + "signature": "signature", + "policy": "policy", + "upload_dir": "test-dir", + "x_oss_object_acl": "private", + "x_oss_forbid_overwrite": "true", + "upload_host": "https://oss.example.com", + } + FakeSession.captured_file = None + monkeypatch.setattr(oss_utils.requests, "Session", FakeSession) + + file_url, returned_certificate = OssUtils.upload( + model="test-model", + file_path="tests/data/dogs.jpg", + api_key="test-api-key", + upload_certificate=upload_certificate, + ) + + assert file_url == "oss://test-dir/dogs.jpg" + assert returned_certificate is upload_certificate + assert FakeSession.captured_file is not None + assert FakeSession.captured_file.closed + + def test_check_and_upload_local_uploads_relative_file_uri( + self, + monkeypatch, + ): + captured_file_path = {} + + def fake_isfile(file_path): + captured_file_path["value"] = file_path + return True + + def fake_upload(model, file_path, api_key, upload_certificate): + assert model == "test-model" + assert api_key == "test-api-key" + assert upload_certificate == {"cert": "value"} + assert file_path == "test_video_frames/frame_0000.jpg" + return "oss://test-dir/frame_0000.jpg", {"cert": "value"} + + monkeypatch.setattr(oss_utils.os.path, "isfile", fake_isfile) + monkeypatch.setattr(OssUtils, "upload", fake_upload) + + is_upload, file_url, certificate = oss_utils.check_and_upload_local( + model="test-model", + content="file://test_video_frames/frame_0000.jpg", + api_key="test-api-key", + upload_certificate={"cert": "value"}, + ) + + assert is_upload + assert file_url == "oss://test-dir/frame_0000.jpg" + assert certificate == {"cert": "value"} + assert ( + captured_file_path["value"] == "test_video_frames/frame_0000.jpg" + ) + + def test_check_and_upload_local_supports_windows_absolute_file_uri( + self, + monkeypatch, + ): + captured_file_path = {} + + def fake_isfile(file_path): + captured_file_path["value"] = file_path + return True + + def fake_upload( + model, + file_path, + api_key, + upload_certificate, + ): + assert model == "test-model" + assert file_path == "C:/Users/test/frame_0000.jpg" + assert api_key == "test-api-key" + return "oss://test-dir/frame_0000.jpg", upload_certificate + + monkeypatch.setattr(oss_utils.os.path, "isfile", fake_isfile) + monkeypatch.setattr(OssUtils, "upload", fake_upload) + + is_upload, file_url, _ = oss_utils.check_and_upload_local( + model="test-model", + content="file:///C:/Users/test/frame_0000.jpg", + api_key="test-api-key", + ) + + assert is_upload + assert file_url == "oss://test-dir/frame_0000.jpg" + assert captured_file_path["value"] == "C:/Users/test/frame_0000.jpg" + + @pytest.mark.parametrize( + "file_uri", + [ + "file://localhost/home/user/frame_0000.jpg", + "file://127.0.0.1/home/user/frame_0000.jpg", + ], + ) + def test_check_and_upload_local_treats_loopback_host_as_local_path( + self, + monkeypatch, + file_uri, + ): + captured_file_path = {} + + def fake_isfile(file_path): + captured_file_path["value"] = file_path + return True + + def fake_upload( + model, + file_path, + api_key, + upload_certificate, + ): + assert model == "test-model" + assert file_path == "/home/user/frame_0000.jpg" + assert api_key == "test-api-key" + return "oss://test-dir/frame_0000.jpg", upload_certificate + + monkeypatch.setattr(oss_utils.os.path, "isfile", fake_isfile) + monkeypatch.setattr(OssUtils, "upload", fake_upload) + + is_upload, file_url, _ = oss_utils.check_and_upload_local( + model="test-model", + content=file_uri, + api_key="test-api-key", + ) + + assert is_upload + assert file_url == "oss://test-dir/frame_0000.jpg" + assert captured_file_path["value"] == "/home/user/frame_0000.jpg" + + def test_check_and_upload_local_raises_when_file_uri_not_found( + self, + monkeypatch, + ): + monkeypatch.setattr( + oss_utils.os.path, + "isfile", + lambda file_path: False, + ) + + with pytest.raises(InvalidInput): + oss_utils.check_and_upload_local( + model="test-model", + content="file://missing/frame_0000.jpg", + api_key="test-api-key", + ) diff --git a/tests/unit/test_rerank.py b/tests/unit/test_rerank.py index d2afd34..b742794 100644 --- a/tests/unit/test_rerank.py +++ b/tests/unit/test_rerank.py @@ -1,10 +1,11 @@ # -*- coding: utf-8 -*- # Copyright (c) Alibaba, Inc. and its affiliates. +import asyncio import json import uuid -from dashscope import TextReRank +from dashscope import AioTextReRank, TextReRank from tests.unit.mock_request_base import MockServerBase from tests.unit.mock_server import MockServer @@ -62,3 +63,61 @@ def test_call(self, mock_server: MockServer): assert len(response.output["results"]) == 2 assert response.output["results"][0]["index"] == 1 assert response.output["results"][1]["document"]["text"] == "黑龙江离俄罗斯很近" + + def test_aio_call(self, mock_server: MockServer): + response_body = { + "output": { + "results": [ + { + "index": 1, + "relevance_score": 0.987654, + "document": { + "text": "哈尔滨是中国黑龙江省的省会,位于中国东北", + }, + }, + { + "index": 0, + "relevance_score": 0.876543, + "document": { + "text": "黑龙江离俄罗斯很近", + }, + }, + ], + }, + "usage": { + "input_tokens": 1279, + }, + "request_id": "b042e72d-7994-97dd-b3d2-7ee7e0140525", + } + mock_server.responses.put(json.dumps(response_body)) + model = str(uuid.uuid4()) + query = str(uuid.uuid4()) + documents = [ + str(uuid.uuid4()), + str(uuid.uuid4()), + str(uuid.uuid4()), + str(uuid.uuid4()), + ] + + response = asyncio.run( + AioTextReRank.call( + model=model, + query=query, + documents=documents, + return_documents=True, + top_n=2, + instruct="Rank the documents by relevance.", + ), + ) + + req = mock_server.requests.get(block=True) + assert req["path"] == "/api/v1/services/rerank/text-rerank/text-rerank" + assert req["body"]["parameters"] == { + "return_documents": True, + "top_n": 2, + "instruct": "Rank the documents by relevance.", + } + assert req["body"]["input"] == {"query": query, "documents": documents} + assert response.usage["input_tokens"] == 1279 + assert len(response.output["results"]) == 2 + assert response.output["results"][0]["index"] == 1