Skip to content

Commit 3cb9075

Browse files
authored
Merge pull request #74 from MuRainBot/dev
合并Dev: 优化download_file_to_cache函数,为事件添加一些方便的属性方法,修复一些问题
2 parents 5ee5d70 + 2ad9cf5 commit 3cb9075

File tree

8 files changed

+327
-95
lines changed

8 files changed

+327
-95
lines changed

murainbot/common.py

Lines changed: 72 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,15 @@
33
"""
44
import inspect
55
import os.path
6+
import urllib.parse
67
import shutil
78
import sys
89
import threading
910
import time
1011
import uuid
1112
from collections import OrderedDict
1213
from io import BytesIO
14+
from pathlib import Path
1315
from typing import Callable
1416

1517
import requests
@@ -54,66 +56,102 @@ def restart() -> None:
5456
sys.exit()
5557

5658

57-
def download_file_to_cache(url: str, headers=None, file_name: str = "",
58-
download_path: str = None, stream=False, fake_headers: bool = True) -> str | None:
59+
def download_file_to_cache(url: str,
60+
headers=None,
61+
file_name: str = None,
62+
max_size: int = None,
63+
timeout: int = 30,
64+
download_path: str = paths.CACHE_PATH,
65+
stream=True,
66+
fake_headers: bool = True) -> str | None:
5967
"""
6068
下载文件到缓存
69+
**请自行保证下载链接的安全性**
6170
Args:
6271
url: 下载的url
6372
headers: 下载请求的请求头
6473
file_name: 文件名
74+
max_size: 最大大小,单位字节,None则为不限制
75+
timeout: 请求超时时间
6576
download_path: 下载路径
6677
stream: 是否使用流式传输
6778
fake_headers: 是否使用自动生成的假请求头
6879
Returns:
69-
文件路径
80+
文件路径,如果请求失败则返回None
7081
"""
7182
if headers is None:
7283
headers = {}
7384

7485
if fake_headers:
75-
headers["User-Agent"] = ("Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) "
76-
"Chrome/113.0.0.0 Safari/537.36 Edg/113.0.1774.42")
77-
headers["Accept-Language"] = "zh-CN,zh;q=0.9,en;q=0.8,da;q=0.7,ko;q=0.6"
78-
headers["Accept-Encoding"] = "gzip, deflate, br"
79-
headers["Accept"] = ("text/html,application/xhtml+xml,application/xml;q=0.9,image/webp,image/apng,*/*;q=0.8,"
80-
"application/signed-exchange;v=b3;q=0.7")
81-
headers["Connection"] = "keep-alive"
82-
headers["Upgrade-Insecure-Requests"] = "1"
83-
headers["Cache-Control"] = "max-age=0"
84-
headers["Sec-Fetch-Dest"] = "document"
85-
headers["Sec-Fetch-Mode"] = "navigate"
86-
headers["Sec-Fetch-Site"] = "none"
87-
headers["Sec-Fetch-User"] = "?1"
88-
headers["Sec-Ch-Ua"] = "\"Chromium\";v=\"113\", \"Not-A.Brand\";v=\"24\", \"Microsoft Edge\";v=\"113\""
89-
headers["Sec-Ch-Ua-Mobile"] = "?0"
90-
headers["Sec-Ch-Ua-Platform"] = "\"Windows\""
91-
headers["Host"] = url.split("/")[2]
86+
headers.update({
87+
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) "
88+
"Chrome/113.0.0.0 Safari/537.36 Edg/113.0.1774.42",
89+
"Accept-Language": "zh-CN,zh;q=0.9,en;q=0.8,da;q=0.7,ko;q=0.6",
90+
"Accept-Encoding": "gzip, deflate, br",
91+
"Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,image/webp,image/apng,*/*;q=0.8,"
92+
"application/signed-exchange;v=b3;q=0.7",
93+
"Connection": "keep-alive",
94+
"Host": urllib.parse.urlparse(url).hostname
95+
})
9296

9397
# 路径拼接
94-
if file_name == "":
98+
if file_name is None:
9599
file_name = uuid.uuid4().hex + ".cache"
96100

97-
if download_path is None:
98-
file_path = os.path.join(paths.CACHE_PATH, file_name)
99-
else:
100-
file_path = os.path.join(download_path, file_name)
101-
102-
# 路径不存在特判
103-
if not os.path.exists(paths.CACHE_PATH):
104-
os.makedirs(paths.CACHE_PATH)
101+
file_path = Path(download_path) / file_name
102+
if paths.CACHE_PATH in file_path.parents:
103+
try:
104+
if not file_path.resolve().is_relative_to(paths.CACHE_PATH.resolve()):
105+
logger.warning("下载文件失败: 文件路径解析后超出缓存目录")
106+
return None
107+
except FileNotFoundError:
108+
pass
109+
file_path = str(file_path)
105110

106111
try:
107112
# 下载
108113
if stream:
109-
with open(file_path, "wb") as f, requests.get(url, stream=True, headers=headers) as res:
110-
for chunk in res.iter_content(chunk_size=64 * 1024):
111-
if not chunk:
112-
break
113-
f.write(chunk)
114+
# 使用流式下载
115+
with requests.get(url, stream=True, timeout=timeout, headers=headers) as res:
116+
res.raise_for_status() # 请求失败则抛出异常
117+
118+
# 优先从Content-Length判断
119+
content_length = res.headers.get('Content-Length')
120+
if max_size and content_length and int(content_length) > max_size:
121+
logger.warning(f"下载中止: 文件大小 ({content_length} B) 超出限制 ({max_size} B)")
122+
return None
123+
124+
downloaded_size = 0
125+
with open(file_path, "wb") as f:
126+
for chunk in res.iter_content(chunk_size=8192):
127+
downloaded_size += len(chunk)
128+
if max_size and downloaded_size > max_size:
129+
logger.warning(f"下载中止: 文件在传输过程中超出大小限制 ({max_size} B)")
130+
f.close()
131+
os.remove(file_path)
132+
return None
133+
f.write(chunk)
114134
else:
115135
# 不使用流式传输
136+
if max_size is not None:
137+
# 获取响应头
138+
res = requests.head(url, timeout=timeout, headers=headers)
139+
if "Content-Length" in res.headers:
140+
# 获取响应头中的Content-Length
141+
content_length = int(res.headers["Content-Length"])
142+
if content_length > max_size:
143+
logger.warning(f"下载中止: 文件大小 ({content_length} B) 超出限制 ({max_size} B)")
144+
return None
145+
else:
146+
logger.warning(f"下载文件失败: HEAD请求未获取到文件大小,建议使用流式传输")
147+
return None
148+
116149
res = requests.get(url, headers=headers)
150+
res.raise_for_status()
151+
152+
if len(res.content) > max_size:
153+
logger.warning(f"下载中止: 文件在传输过程中超出大小限制 ({max_size} B)")
154+
return None
117155

118156
with open(file_path, "wb") as f:
119157
f.write(res.content)

murainbot/core/PluginManager.py

Lines changed: 37 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -7,18 +7,51 @@
77
import inspect
88
import os
99
import sys
10+
from types import ModuleType
11+
from typing import TypedDict
1012

1113
from murainbot.common import save_exc_dump
1214
from murainbot.paths import paths
1315
from murainbot.core import ConfigManager
14-
from murainbot.core.EventManager import event_listener
15-
from murainbot.core.ListenerServer import EscalationEvent
1616
from murainbot.utils.Logger import get_logger
1717

1818
logger = get_logger()
1919

20-
plugins: list[dict] = []
21-
found_plugins: list[dict] = []
20+
21+
@dataclasses.dataclass
22+
class PluginInfo:
23+
"""
24+
Plugin information: descriptive metadata fields plus enable/hidden flags.
25+
"""
26+
NAME: str # Plugin name
27+
AUTHOR: str # Plugin author
28+
VERSION: str # Plugin version
29+
DESCRIPTION: str # Plugin description
30+
HELP_MSG: str # Plugin help message
31+
ENABLED: bool = True # Whether the plugin is enabled
32+
IS_HIDDEN: bool = False # Whether the plugin is hidden (in the /help command)
33+
extra: dict | None = None # A dict for storing arbitrary information. Other plugins can collect specific information by agreeing on key names in the extra dict.
34+
35+
def __post_init__(self):
36+
if not self.ENABLED:
37+
raise NotEnabledPluginException # constructing the info of a plugin with a falsy ENABLED aborts with this exception
38+
if self.extra is None:
39+
self.extra = {} # replace the None default with a fresh dict created for this instance
40+
41+
42+
class PluginDict(TypedDict):
43+
"""
44+
Plugin information dictionary: the element type of the module-level `plugins` and `found_plugins` lists.
45+
"""
46+
name: str # presumably the plugin's name - TODO confirm against load_plugins
47+
plugin: ModuleType | None # presumably the imported plugin module, None while not loaded - TODO confirm
48+
info: PluginInfo | None # the plugin's PluginInfo, or None when not available
49+
file_path: str # NOTE(review): difference between file_path and path is not visible here - confirm
50+
path: str
51+
52+
53+
plugins: list[PluginDict] = []
54+
found_plugins: list[PluginDict] = []
2255

2356

2457
class NotEnabledPluginException(Exception):
@@ -137,27 +170,6 @@ def load_plugins():
137170
logger.debug(f"插件 {name}({full_path}) 加载成功!")
138171

139172

140-
@dataclasses.dataclass
141-
class PluginInfo:
142-
"""
143-
插件信息
144-
"""
145-
NAME: str # 插件名称
146-
AUTHOR: str # 插件作者
147-
VERSION: str # 插件版本
148-
DESCRIPTION: str # 插件描述
149-
HELP_MSG: str # 插件帮助
150-
ENABLED: bool = True # 插件是否启用
151-
IS_HIDDEN: bool = False # 插件是否隐藏(在/help命令中)
152-
extra: dict | None = None # 一个字典,可以用于存储任意信息。其他插件可以通过约定 extra 字典的键名来达成收集某些特殊信息的目的。
153-
154-
def __post_init__(self):
155-
if self.ENABLED is not True:
156-
raise NotEnabledPluginException
157-
if self.extra is None:
158-
self.extra = {}
159-
160-
161173
def requirement_plugin(plugin_name: str):
162174
"""
163175
插件依赖

murainbot/utils/CommandManager.py

Lines changed: 4 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -609,7 +609,7 @@ def send(self, message: QQRichText.QQRichText | str):
609609
return Actions.SendMsg(
610610
message=message,
611611
**{"group_id": self["group_id"]}
612-
if self.message_type == "group" else
612+
if self.is_group else
613613
{"user_id": self.user_id}
614614
).call()
615615

@@ -630,17 +630,10 @@ def reply(self, message: QQRichText.QQRichText | str):
630630
message
631631
),
632632
**{"group_id": self["group_id"]}
633-
if self.message_type == "group" else
633+
if self.is_group else
634634
{"user_id": self.user_id}
635635
).call()
636636

637-
@property
638-
def is_group(self) -> bool:
639-
"""
640-
判断是否为群消息
641-
"""
642-
return self.message_type == "group"
643-
644637

645638
class CommandManager:
646639
"""
@@ -851,9 +844,9 @@ def match(self, event_data: CommandEvent, rules_kwargs: dict):
851844

852845
# 检测依赖注入
853846
if isinstance(event_data, EventClassifier.MessageEvent):
854-
if event_data.message_type == "private":
847+
if event_data.is_private:
855848
state_id = f"u{event_data.user_id}"
856-
elif event_data.message_type == "group":
849+
elif event_data.is_group:
857850
state_id = f"g{event_data["group_id"]}_u{event_data.user_id}"
858851
else:
859852
state_id = None

0 commit comments

Comments
 (0)