Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions dashscope/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
HttpSpeechSynthesizer,
)
from dashscope.audio.tts.speech_synthesizer import SpeechSynthesizer
from dashscope.api_entities.aio_session import close_shared_aio_session
from dashscope.common.api_key import save_api_key
from dashscope.common.env import (
api_key,
Expand Down Expand Up @@ -75,6 +76,7 @@
"api_key",
"api_key_file_path",
"save_api_key",
"close_shared_aio_session",
"AioGeneration",
"Conversation",
"Generation",
Expand Down
113 changes: 67 additions & 46 deletions dashscope/api_entities/aiohttp_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import json
from http import HTTPStatus
from typing import Optional

import aiohttp

Expand Down Expand Up @@ -31,6 +32,7 @@ def __init__(
timeout: int = DEFAULT_REQUEST_TIMEOUT_SECONDS,
task_id: str = None,
user_agent: str = "",
session: Optional[aiohttp.ClientSession] = None,
) -> None:
"""HttpSSERequest, processing http server sent event stream.

Expand All @@ -45,16 +47,20 @@ def __init__(
Defaults to DEFAULT_REQUEST_TIMEOUT_SECONDS.
user_agent (str, optional): Additional user agent string to
append. Defaults to ''.
session (aiohttp.ClientSession, optional): External aiohttp
session to use instead of the shared session. The caller is
responsible for closing it. Defaults to None.
"""

super().__init__(user_agent=user_agent)
self.url = url
self.async_request = async_request
self._external_aio_session = session
self.headers = {
"Accept": "application/json",
"Authorization": f"Bearer {api_key}",
"Cache-Control": "no-cache",
**self.headers,
**self.headers, # type: ignore[has-type]
}
self.query = query
if self.async_request and self.query is False:
Expand Down Expand Up @@ -247,52 +253,67 @@ async def _handle_response( # pylint: disable=too-many-branches
message=msg.decode("utf-8"),
)

# pylint: disable=too-many-branches
async def _handle_request(self):
session = await get_shared_aio_session()
if self.stream:
request_timeout = aiohttp.ClientTimeout(
total=None,
sock_read=self.timeout,
)
else:
request_timeout = aiohttp.ClientTimeout(total=self.timeout)
try:
if self._external_aio_session is not None:
session = self._external_aio_session
should_close = False
else:
session = await get_shared_aio_session()
should_close = False

logger.debug("Starting request: %s", self.url)
if self.method == HTTPMethod.POST:
is_form, obj = False, {}
if hasattr(self, "data") and self.data is not None:
is_form, obj = self.data.get_aiohttp_payload()
if is_form:
headers = {**self.headers, **obj.headers}
response = await session.post(
url=self.url,
data=obj,
headers=headers,
timeout=request_timeout,
if self.stream:
request_timeout = aiohttp.ClientTimeout(
total=None,
sock_read=self.timeout,
)
else:
response = await session.request(
"POST",
url=self.url,
json=obj,
headers=self.headers,
timeout=request_timeout,
)
elif self.method == HTTPMethod.GET:
params = {}
if hasattr(self, "data") and self.data is not None:
params = getattr(self.data, "parameters", {})
response = await session.get(
url=self.url,
params=params,
headers=self.headers,
timeout=request_timeout,
)
else:
raise UnsupportedHTTPMethod(
f"Unsupported http method: {self.method}",
)
logger.debug("Response returned: %s", self.url)
async with response:
async for rsp in self._handle_response(response):
yield rsp
request_timeout = aiohttp.ClientTimeout(total=self.timeout)

try:
logger.debug("Starting request: %s", self.url)
if self.method == HTTPMethod.POST:
is_form, obj = False, {}
if hasattr(self, "data") and self.data is not None:
is_form, obj = self.data.get_aiohttp_payload()
if is_form:
headers = {**self.headers, **obj.headers}
response = await session.post(
url=self.url,
data=obj,
headers=headers,
timeout=request_timeout,
)
else:
response = await session.request(
"POST",
url=self.url,
json=obj,
headers=self.headers,
timeout=request_timeout,
)
elif self.method == HTTPMethod.GET:
params = {}
if hasattr(self, "data") and self.data is not None:
params = getattr(self.data, "parameters", {})
response = await session.get(
url=self.url,
params=params,
headers=self.headers,
timeout=request_timeout,
)
else:
raise UnsupportedHTTPMethod(
f"Unsupported http method: {self.method}",
)
logger.debug("Response returned: %s", self.url)
async with response:
async for rsp in self._handle_response(response):
yield rsp
finally:
if should_close:
await session.close()
except Exception as e:
logger.debug(e)
raise e
Loading
Loading