diff --git a/dashscope/__init__.py b/dashscope/__init__.py index 92cd259..672f24c 100644 --- a/dashscope/__init__.py +++ b/dashscope/__init__.py @@ -21,6 +21,7 @@ HttpSpeechSynthesizer, ) from dashscope.audio.tts.speech_synthesizer import SpeechSynthesizer +from dashscope.api_entities.aio_session import close_shared_aio_session from dashscope.common.api_key import save_api_key from dashscope.common.env import ( api_key, @@ -75,6 +76,7 @@ "api_key", "api_key_file_path", "save_api_key", + "close_shared_aio_session", "AioGeneration", "Conversation", "Generation", diff --git a/dashscope/api_entities/aiohttp_request.py b/dashscope/api_entities/aiohttp_request.py index 598be25..e5487b5 100644 --- a/dashscope/api_entities/aiohttp_request.py +++ b/dashscope/api_entities/aiohttp_request.py @@ -3,6 +3,7 @@ import json from http import HTTPStatus +from typing import Optional import aiohttp @@ -31,6 +32,7 @@ def __init__( timeout: int = DEFAULT_REQUEST_TIMEOUT_SECONDS, task_id: str = None, user_agent: str = "", + session: Optional[aiohttp.ClientSession] = None, ) -> None: """HttpSSERequest, processing http server sent event stream. @@ -45,16 +47,20 @@ def __init__( Defaults to DEFAULT_REQUEST_TIMEOUT_SECONDS. user_agent (str, optional): Additional user agent string to append. Defaults to ''. + session (aiohttp.ClientSession, optional): External aiohttp + session to use instead of the shared session. The caller is + responsible for closing it. Defaults to None. """ super().__init__(user_agent=user_agent) self.url = url self.async_request = async_request + self._external_aio_session = session self.headers = { "Accept": "application/json", "Authorization": f"Bearer {api_key}", "Cache-Control": "no-cache", - **self.headers, + **self.headers, # type: ignore[has-type] } self.query = query if self.async_request and self.query is False: @@ -247,52 +253,67 @@ async def _handle_response( # pylint: disable=too-many-branches message=msg.decode("utf-8"), ) + # pylint: disable=too-many-branches async def _handle_request(self): - session = await get_shared_aio_session() - if self.stream: - request_timeout = aiohttp.ClientTimeout( - total=None, - sock_read=self.timeout, - ) - else: - request_timeout = aiohttp.ClientTimeout(total=self.timeout) + try: + if self._external_aio_session is not None: + session = self._external_aio_session + should_close = False + else: + session = await get_shared_aio_session() + should_close = False - logger.debug("Starting request: %s", self.url) - if self.method == HTTPMethod.POST: - is_form, obj = False, {} - if hasattr(self, "data") and self.data is not None: - is_form, obj = self.data.get_aiohttp_payload() - if is_form: - headers = {**self.headers, **obj.headers} - response = await session.post( - url=self.url, - data=obj, - headers=headers, - timeout=request_timeout, + if self.stream: + request_timeout = aiohttp.ClientTimeout( + total=None, + sock_read=self.timeout, ) else: - response = await session.request( - "POST", - url=self.url, - json=obj, - headers=self.headers, - timeout=request_timeout, - ) - elif self.method == HTTPMethod.GET: - params = {} - if hasattr(self, "data") and self.data is not None: - params = getattr(self.data, "parameters", {}) - response = await session.get( - url=self.url, - params=params, - headers=self.headers, - timeout=request_timeout, - ) - else: - raise UnsupportedHTTPMethod( - f"Unsupported http method: {self.method}", - ) - logger.debug("Response returned: %s", self.url) - async with response: - async for rsp in self._handle_response(response): - yield rsp + request_timeout = aiohttp.ClientTimeout(total=self.timeout) + + try: + logger.debug("Starting request: %s", self.url) + if self.method == HTTPMethod.POST: + is_form, obj = False, {} + if hasattr(self, "data") and self.data is not None: + is_form, obj = self.data.get_aiohttp_payload() + if is_form: + headers = {**self.headers, **obj.headers} + response = await session.post( + url=self.url, + data=obj, + headers=headers, + timeout=request_timeout, + ) + else: + response = await session.request( + "POST", + url=self.url, + json=obj, + headers=self.headers, + timeout=request_timeout, + ) + elif self.method == HTTPMethod.GET: + params = {} + if hasattr(self, "data") and self.data is not None: + params = getattr(self.data, "parameters", {}) + response = await session.get( + url=self.url, + params=params, + headers=self.headers, + timeout=request_timeout, + ) + else: + raise UnsupportedHTTPMethod( + f"Unsupported http method: {self.method}", + ) + logger.debug("Response returned: %s", self.url) + async with response: + async for rsp in self._handle_response(response): + yield rsp + finally: + if should_close: + await session.close() + except Exception as e: + logger.debug(e) + raise e diff --git a/dashscope/api_entities/http_request.py b/dashscope/api_entities/http_request.py index 3142328..1208bbc 100644 --- a/dashscope/api_entities/http_request.py +++ b/dashscope/api_entities/http_request.py @@ -161,62 +161,72 @@ async def aio_call(self): return result async def _handle_aio_request(self): # pylint: disable=too-many-branches - # Use external aio_session if provided, - # otherwise use shared session with connection pooling - if self._external_aio_session is not None: - session = self._external_aio_session - else: - session = await get_shared_aio_session() + try: + # Use external aio_session if provided, + # otherwise use shared session with connection pooling + if self._external_aio_session is not None: + session = self._external_aio_session + should_close = False + else: + session = await get_shared_aio_session() + should_close = False - if self.stream: - request_timeout = aiohttp.ClientTimeout( - total=None, - sock_read=self.timeout, - ) - else: - request_timeout = aiohttp.ClientTimeout(total=self.timeout) + try: + if self.stream: + request_timeout = aiohttp.ClientTimeout( + total=None, + sock_read=self.timeout, + ) + else: + request_timeout = aiohttp.ClientTimeout(total=self.timeout) - logger.debug("Starting request: %s", self.url) - if self.method == HTTPMethod.POST: - is_form, obj = False, {} - if hasattr(self, "data") and self.data is not None: - is_form, obj = self.data.get_aiohttp_payload() - if is_form: - headers = {**self.headers, **obj.headers} - response = await session.post( - url=self.url, - data=obj, - headers=headers, - timeout=request_timeout, - ) - else: - response = await session.request( - "POST", - url=self.url, - json=obj, - headers=self.headers, - timeout=request_timeout, - ) - elif self.method == HTTPMethod.GET: - params = {} - if hasattr(self, "data") and self.data is not None: - params = getattr(self.data, "parameters", {}) - if params: - params = self.__handle_parameters(params) - response = await session.get( - url=self.url, - params=params, - headers=self.headers, - timeout=request_timeout, - ) - else: - raise UnsupportedHTTPMethod( - f"Unsupported http method: {self.method}", - ) - logger.debug("Response returned: %s", self.url) - async with response: - async for rsp in self._handle_aio_response(response): - yield rsp + logger.debug("Starting request: %s", self.url) + if self.method == HTTPMethod.POST: + is_form, obj = False, {} + if hasattr(self, "data") and self.data is not None: + is_form, obj = self.data.get_aiohttp_payload() + if is_form: + headers = {**self.headers, **obj.headers} + response = await session.post( + url=self.url, + data=obj, + headers=headers, + timeout=request_timeout, + ) + else: + response = await session.request( + "POST", + url=self.url, + json=obj, + headers=self.headers, + timeout=request_timeout, + ) + elif self.method == HTTPMethod.GET: + params = {} + if hasattr(self, "data") and self.data is not None: + params = getattr(self.data, "parameters", {}) + if params: + params = self.__handle_parameters(params) + response = await session.get( + url=self.url, + params=params, + headers=self.headers, + timeout=request_timeout, + ) + else: + raise UnsupportedHTTPMethod( + f"Unsupported http method: {self.method}", + ) + logger.debug("Response returned: %s", self.url) + async with response: + async for rsp in self._handle_aio_response(response): + yield rsp + finally: + if should_close: + await session.close() + except Exception as e: + logger.debug(e) + raise e @staticmethod def __handle_parameters(params: dict) -> dict: @@ -445,57 +455,61 @@ def _handle_response( # pylint: disable=too-many-branches else: yield _handle_http_failed_response(response) - def _handle_request(self): - # Use external session if provided, - # otherwise create temporary session - if self._external_session is not None: - session = self._external_session - should_close = False - else: - session = requests.Session() - should_close = True - + def _handle_request(self): # pylint: disable=too-many-branches try: - if self.method == HTTPMethod.POST: - is_form, form, obj = False, None, {} - if hasattr(self, "data") and self.data is not None: - is_form, form, obj = self.data.get_http_payload() - if is_form: - headers = {**self.headers} - headers.pop("Content-Type") - response = session.post( + # Use external session if provided, + # otherwise create temporary session + if self._external_session is not None: + session = self._external_session + should_close = False + else: + session = requests.Session() + should_close = True + + try: + if self.method == HTTPMethod.POST: + is_form, form, obj = False, None, {} + if hasattr(self, "data") and self.data is not None: + is_form, form, obj = self.data.get_http_payload() + if is_form: + headers = {**self.headers} + headers.pop("Content-Type") + response = session.post( + url=self.url, + data=obj, + files=form, + headers=headers, + timeout=self.timeout, + ) + else: + logger.debug("Request body: %s", obj) + response = session.post( + url=self.url, + stream=self.stream, + json=obj, + headers={**self.headers}, + timeout=self.timeout, + ) + elif self.method == HTTPMethod.GET: + params = {} + if hasattr(self, "data") and self.data is not None: + params = getattr(self.data, "parameters", {}) + response = session.get( url=self.url, - data=obj, - files=form, - headers=headers, + params=params, + headers=self.headers, timeout=self.timeout, ) else: - logger.debug("Request body: %s", obj) - response = session.post( - url=self.url, - stream=self.stream, - json=obj, - headers={**self.headers}, - timeout=self.timeout, + raise UnsupportedHTTPMethod( + f"Unsupported http method: {self.method}", ) - elif self.method == HTTPMethod.GET: - params = {} - if hasattr(self, "data") and self.data is not None: - params = getattr(self.data, "parameters", {}) - response = session.get( - url=self.url, - params=params, - headers=self.headers, - timeout=self.timeout, - ) - else: - raise UnsupportedHTTPMethod( - f"Unsupported http method: {self.method}", - ) - for rsp in self._handle_response(response): - yield rsp - finally: - # Only close if we created the session - if should_close: - session.close() + for rsp in self._handle_response(response): + yield rsp + finally: + # Only close if we created the session + if should_close: + session.close() + except Exception as e: + logger.debug(e) + raise e