|
1 | 1 | from dataclasses import dataclass |
2 | 2 | from typing import Dict, List, Optional, Any, Union |
| 3 | +import asyncio |
| 4 | +import time |
3 | 5 |
|
4 | 6 | import telegramify_markdown |
5 | 7 | from aiogram import types |
6 | | -from aiohttp import ClientSession, ClientTimeout, ClientResponse |
| 8 | +from aiohttp import ClientSession, ClientTimeout, ClientResponse, ClientError, ServerTimeoutError, ClientConnectorError |
| 9 | +from aiohttp.client_exceptions import ClientResponseError, ClientPayloadError, ServerDisconnectedError |
7 | 10 | from orjson import dumps, loads |
8 | 11 |
|
9 | 12 | from manager import manager |
@@ -86,56 +89,192 @@ class ModelDescription: |
86 | 89 |
|
87 | 90 |
|
async def _api_request(url: str, data: Dict[str, Any], proxy_token: str) -> Dict[str, Any]:
    """Send a POST request to the LLM provider and return the decoded JSON body.

    The request is attempted up to three times. Transient failures (HTTP
    429/502/503/504, network errors, timeouts, undecodable bodies and the
    provider error codes ``rate_limit_exceeded`` / ``server_error``) are
    retried with exponential backoff; everything else fails immediately.

    Args:
        url: Endpoint to POST to.
        data: JSON-serialisable request payload. Its ``"model"`` entry (or
            ``DEFAULT_MODEL`` when absent) is only used to label log lines.
        proxy_token: Token sent as ``Authorization: Bearer <token>``.

    Returns:
        The decoded JSON object of a successful (HTTP 200, no ``"error"``
        key) response.

    Raises:
        ValueError: For every failure mode, carrying a user-facing message.
            The underlying exception, if any, is chained as ``__cause__``.
    """
    session = await manager.bot.session.create_session()  # type: ignore

    # The model name is only used to label log messages; the timeouts below
    # are the same for every model.
    model_name = data.get("model", DEFAULT_MODEL)

    timeout_config = ClientTimeout(
        total=180,  # 3 minutes overall
        connect=15,
        sock_read=170,  # 2 min 50 s for reading the response
        sock_connect=20,
    )

    # Retry configuration.
    max_retries = 3
    retry_delay = 1  # base delay in seconds, grown exponentially per attempt
    retryable_status_codes = {502, 503, 504, 429}
    # A tuple (not a set) so that an unhashable error code cannot raise.
    retryable_error_codes = ("rate_limit_exceeded", "server_error")

    for attempt in range(max_retries):
        is_last_attempt = attempt >= max_retries - 1
        backoff = retry_delay * (2 ** attempt)

        try:
            # Monotonic clock: elapsed time must not be affected by
            # wall-clock adjustments.
            start_time = time.monotonic()

            async with session.post(
                url,
                json=data,
                headers={"Authorization": f"Bearer {proxy_token}"},
                timeout=timeout_config,
            ) as response:
                request_time = time.monotonic() - start_time

                logger.info(f"API请求 - 模型: {model_name}, 状态码: {response.status}, "
                            f"用时: {request_time:.2f}s, 尝试: {attempt + 1}/{max_retries}")

                if response.status == 200:
                    # Keep this try limited to decoding the body. If it also
                    # covered the error-payload handling below, the
                    # ValueError raised there would be caught here and
                    # misreported as a JSON parsing failure.
                    try:
                        response_data = await response.json()
                        if not isinstance(response_data, dict):
                            raise TypeError(
                                f"unexpected JSON payload type: {type(response_data).__name__}"
                            )
                    except Exception as json_error:
                        logger.error(f"解析响应JSON失败 - 模型: {model_name}, 错误: {str(json_error)}")
                        if not is_last_attempt:
                            await asyncio.sleep(backoff)
                            continue
                        raise ValueError("响应格式错误,请稍后重试") from json_error

                    if "error" in response_data:
                        error_info = response_data["error"]
                        if isinstance(error_info, dict):
                            error_code = error_info.get("code", "unknown")
                            error_message = error_info.get("message", "Unknown error")
                        else:
                            # Some providers may return a bare string here.
                            error_code = "unknown"
                            error_message = str(error_info)

                        logger.error(f"API响应错误 - 模型: {model_name}, 错误代码: {error_code}, "
                                     f"错误信息: {error_message}")

                        # Only provider-side transient errors are retried.
                        if error_code in retryable_error_codes and not is_last_attempt:
                            await asyncio.sleep(backoff)
                            continue

                        raise ValueError(f"AI服务返回错误: {error_message}")

                    return response_data

                elif response.status == 400:
                    error_text = await response.text()
                    logger.error(f"请求参数错误 - 模型: {model_name}, 状态码: 400, 详情: {error_text}")
                    raise ValueError("请求参数错误,请检查输入内容")

                elif response.status == 401:
                    logger.error(f"认证失败 - 模型: {model_name}, 状态码: 401")
                    raise ValueError("AI服务认证失败,请检查配置")

                elif response.status == 403:
                    logger.error(f"权限不足 - 模型: {model_name}, 状态码: 403")
                    raise ValueError("无权访问AI服务,请检查权限配置")

                elif response.status == 429:
                    logger.warning(f"请求频率限制 - 模型: {model_name}, 状态码: 429, 尝试: {attempt + 1}")

                    if not is_last_attempt:
                        # Rate limiting backs off more aggressively (base 3).
                        retry_wait = retry_delay * (3 ** attempt)
                        logger.info(f"等待 {retry_wait}s 后重试...")
                        await asyncio.sleep(retry_wait)
                        continue

                    raise ValueError("请求过于频繁,请稍后重试")

                elif response.status in retryable_status_codes:
                    error_text = await response.text()
                    logger.warning(f"服务器临时错误 - 模型: {model_name}, 状态码: {response.status}, "
                                   f"尝试: {attempt + 1}, 详情: {error_text}")

                    if not is_last_attempt:
                        await asyncio.sleep(backoff)
                        continue

                    raise ValueError(f"AI服务暂时不可用 ({response.status}),请稍后重试")

                else:
                    # Any other HTTP status is treated as non-retryable.
                    error_text = await response.text()
                    logger.error(f"HTTP错误 - 模型: {model_name}, 状态码: {response.status}, "
                                 f"详情: {error_text}")
                    raise ValueError(f"服务器错误 ({response.status}),请稍后重试")

        # Network-level failures. Order matters: the specific aiohttp
        # exceptions must precede the generic ClientError handler.
        except ClientConnectorError as e:
            logger.error(f"连接错误 - 模型: {model_name}, 尝试: {attempt + 1}, 错误: {str(e)}")

            if not is_last_attempt:
                await asyncio.sleep(backoff)
                continue

            raise ValueError("无法连接到AI服务,请检查网络连接或服务地址") from e

        except ServerTimeoutError as e:
            logger.error(f"服务器超时 - 模型: {model_name}, 尝试: {attempt + 1}, 错误: {str(e)}")

            if not is_last_attempt:
                await asyncio.sleep(backoff)
                continue

            raise ValueError("服务器响应超时,请稍后重试") from e

        except asyncio.TimeoutError as e:
            logger.error(f"请求超时 - 模型: {model_name}, 尝试: {attempt + 1}, 错误: {str(e)}")

            if not is_last_attempt:
                await asyncio.sleep(backoff)
                continue

            raise ValueError("请求超时,请稍后重试") from e

        except ClientPayloadError as e:
            logger.error(f"请求负载错误 - 模型: {model_name}, 尝试: {attempt + 1}, 错误: {str(e)}")

            if not is_last_attempt:
                await asyncio.sleep(backoff)
                continue

            raise ValueError("请求数据格式错误,请重试") from e

        except ServerDisconnectedError as e:
            logger.error(f"服务器断开连接 - 模型: {model_name}, 尝试: {attempt + 1}, 错误: {str(e)}")

            if not is_last_attempt:
                await asyncio.sleep(backoff)
                continue

            raise ValueError("服务器连接中断,请稍后重试") from e

        except ClientError as e:
            logger.error(f"客户端错误 - 模型: {model_name}, 尝试: {attempt + 1}, 错误: {str(e)}")

            if not is_last_attempt:
                await asyncio.sleep(backoff)
                continue

            raise ValueError("网络请求失败,请稍后重试") from e

        except ValueError:
            # ValueError is this function's own failure signal; never retry it.
            raise

        except Exception as e:
            logger.exception(f"未预期的错误 - 模型: {model_name}, 尝试: {attempt + 1}, 错误: {str(e)}")

            if not is_last_attempt:
                await asyncio.sleep(backoff)
                continue

            raise ValueError(f"请求失败: {str(e)}") from e

    # Safeguard: every path above either returns, raises or continues, and
    # the last attempt never continues, so this should not be reached.
    raise ValueError("所有重试都失败了,请稍后重试")
139 | 278 |
|
140 | 279 |
|
141 | 280 | async def tg_generate_text(chat: types.Chat, member: types.User, prompt: str) -> Optional[str]: |
@@ -257,9 +396,7 @@ async def chat_completions(messages: List[Dict[str, Any]], model_name: Optional[ |
257 | 396 | **kwargs, |
258 | 397 | } |
259 | 398 |
|
260 | | - try: |
261 | | - response_data = await _api_request(url, data, proxy_token) |
262 | | - logger.info(f"generate txt use model {model_name}") |
263 | | - return response_data["choices"][0]["message"]["content"] |
264 | | - except ValueError as e: |
265 | | - return str(e) |
| 399 | + response_data = await _api_request(url, data, proxy_token) |
| 400 | + logger.info(f"generate txt use model {model_name}") |
| 401 | + return response_data["choices"][0]["message"]["content"] |
| 402 | + |
0 commit comments