import json
from typing import List
import uuid
import os
from fastapi import HTTPException, status
import requests
from dotenv import load_dotenv
import chromadb
from chromadb.errors import NotFoundError
from langchain_classic.embeddings import CacheBackedEmbeddings
from langchain_classic.embeddings.base import Embeddings as LangChainEmbeddings
from langchain_classic.storage import LocalFileStore
from tqdm import tqdm
from utils.hash import make_hash
load_dotenv()

# On-disk byte store used to cache embedding vectors between runs
store = LocalFileStore("./tmp/cache/")


class SiliconFlowEmbeddings(LangChainEmbeddings):
    """Custom SiliconFlow embedding class, compatible with the LangChain Embeddings interface."""

    def __init__(self):
        self.model = "BAAI/bge-m3"
        self.base_url = "https://api.siliconflow.cn/v1"
        self.api_key = os.getenv("siliconflow_token")

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed a batch of documents."""
        url = f"{self.base_url}/embeddings"
        payload = {
            "model": self.model,
            "input": texts,
        }
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }
        # Timeout so a hung request does not block the caller indefinitely
        response = requests.post(url, json=payload, headers=headers, timeout=60)
        if not response.ok:
            if response.status_code == 401:
                raise HTTPException(
                    status_code=status.HTTP_401_UNAUTHORIZED,
                    detail="Invalid API key for Embedding service",
                )
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail=f"Embedding API request failed with status code {response.status_code}",
            )
        data = response.json()
        # Some error responses still carry a "message" field; surface them as a
        # server error rather than relying on assert (which `python -O` strips).
        message = data.get("message")
        if message:
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail=f"Embedding API error: {message}",
            )
        return [item["embedding"] for item in data["data"]] if data["data"] else []

    def embed_query(self, text: str) -> List[float]:
        """Embed a single query string."""
        return self.embed_documents([text])[0]


class CustomDocument:
    def __init__(self, content: str, doc_name: str, page_number):
        self.content = content
        self.doc_name = doc_name
        self.page_number = page_number

    def to_dict(self):
        return {
            "content": self.content,
            "doc_name": self.doc_name,
            "page_number": self.page_number,
        }


# Create the cache-backed embedding instance
underlying_embeddings = SiliconFlowEmbeddings()
cached_embeddings = CacheBackedEmbeddings.from_bytes_store(
    underlying_embeddings,
    store,
    key_encoder=make_hash,
)
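# Note (assumption about LangChain's CacheBackedEmbeddings, not stated in the
# original): embed_documents results are cached under ./tmp/cache/ keyed by
# make_hash of the text, so repeated texts should not re-hit the API. Depending
# on the LangChain version, embed_query may bypass this cache unless a
# query_embedding_cache is enabled in from_bytes_store.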

# Persistent Chroma client; reuse the existing collection or create it on first run
chroma_client = chromadb.PersistentClient(path="./chroma_db")
try:
    collection = chroma_client.get_collection("my_collection")
except NotFoundError:
    collection = chroma_client.create_collection(name="my_collection")
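# Assumption, not from the original source: with Chroma's defaults ("hnsw:space"
# unset), collection.query() returns L2 (squared Euclidean) distances, so smaller
# values mean more similar; similarity_search below sorts on that basis.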


class DocumentRecord:
    metadata: dict

    def __init__(self, content: str, metadata: str | dict):
        self.content = content
        if isinstance(metadata, str):
            # Metadata passed as a JSON string: parse it, or wrap the raw text on failure
            try:
                self.metadata = json.loads(metadata)
            except json.JSONDecodeError:
                self.metadata = {"text": metadata}
        elif isinstance(metadata, dict):
            # Unwrap a nested {"metadata": ...} dict if present
            self.metadata = metadata["metadata"] if "metadata" in metadata else metadata
            if isinstance(self.metadata, str):
                try:
                    self.metadata = json.loads(self.metadata)
                except json.JSONDecodeError:
                    self.metadata = {"text": self.metadata}
        else:
            raise ValueError("Metadata must be either a string or a dictionary.")

    def to_dict(self):
        """Return a plain-dict representation of the record."""
        return {
            "content": self.content,
            "metadata": self.metadata,
        }
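
# Illustrative examples of the metadata normalization above (made up, not from
# the original source):
#   DocumentRecord("hello", '{"doc_name": "a.pdf", "page_number": 3}').metadata
#       -> {"doc_name": "a.pdf", "page_number": 3}
#   DocumentRecord("hello", "not json").metadata
#       -> {"text": "not json"}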


def insert_record(doc: DocumentRecord) -> str:
    """Insert a single document record and return its generated id."""
    record_id = str(uuid.uuid4())
    doc_embedding = cached_embeddings.embed_query(doc.content)
    collection.add(
        ids=[record_id],
        documents=[doc.content],
        embeddings=[doc_embedding],
        metadatas=[doc.metadata],
    )
    return record_id


def insert_records_batch(docs: List[DocumentRecord], batch_size: int = 32) -> List[str]:
    """Insert document records in batches."""
    all_ids = []
    for i in range(0, len(docs), batch_size):
        batch = docs[i : i + batch_size]
        # Generate ids for the batch
        batch_ids = [str(uuid.uuid4()) for _ in batch]
        # Generate embeddings for the batch
        batch_contents = [doc.content for doc in batch]
        batch_embeddings = cached_embeddings.embed_documents(batch_contents)
        # Insert the batch into the collection
        collection.add(
            ids=batch_ids,
            documents=batch_contents,
            embeddings=batch_embeddings,  # type: ignore
            metadatas=[doc.metadata for doc in batch],
        )
        all_ids.extend(batch_ids)
    return all_ids


def similarity_search(query: str, limit: int = 5) -> List[DocumentRecord]:
    # Generate the query vector with the cache-backed embeddings
    query_embedding = cached_embeddings.embed_query(query)
    query_results = collection.query(
        query_embeddings=[query_embedding],
        n_results=limit,
    )
    if (not query_results["documents"]) or (not query_results["metadatas"]):
        return []
    documents_list = query_results["documents"][0]
    metadatas_list = query_results["metadatas"][0]
    if "distances" in query_results and query_results["distances"]:
        distances = query_results["distances"][0]
        # Build a list of (distance, content, metadata) tuples
        items = list(zip(distances, documents_list, metadatas_list))
        # Sort by distance ascending (smaller distance = more similar)
        items.sort(key=lambda x: x[0])
    else:
        # No distance information; keep the original order
        items = list(zip([0] * len(documents_list), documents_list, metadatas_list))
    # Build the DocumentRecord list
    documents = [
        DocumentRecord(
            content=content,
            metadata=dict(metadata),  # convert to a plain dict
        )
        for _, content, metadata in items
    ]
    return documents


def extract_docs_has_single_term(term: str) -> List[DocumentRecord]:
    """Retrieve documents whose content contains the given term."""
    documents = similarity_search(f"`{term}`", limit=20)
    results: List[DocumentRecord] = []
    for doc in documents:
        if term in doc.content:
            results.append(doc)
    print("Retrieved Documents:")
    print(f"Found {len(results)} documents containing the term.")
    return results


def extract_docs_has_both_term(term_pair: tuple) -> List[DocumentRecord]:
    """Retrieve documents whose content contains both terms."""
    documents = similarity_search(f"`{term_pair[0]}`和`{term_pair[1]}`", limit=20)
    results: List[DocumentRecord] = []
    for doc in documents:
        if term_pair[0] in doc.content and term_pair[1] in doc.content:
            results.append(doc)
    print("Retrieved Documents:")
    print(f"Found {len(results)} documents containing both terms.")
    return results
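

# Minimal usage sketch (illustrative only, not part of the original module). It
# assumes a valid siliconflow_token in the environment; the sample texts and
# metadata below are made up. It shows the intended flow: build DocumentRecord
# objects, insert them in batches, then query by similarity.
if __name__ == "__main__":
    sample_docs = [
        DocumentRecord(
            content="Chroma is a vector database for embedding-based retrieval.",
            metadata={"doc_name": "example.md", "page_number": 1},
        ),
        DocumentRecord(
            content="BAAI/bge-m3 is a multilingual embedding model.",
            metadata={"doc_name": "example.md", "page_number": 2},
        ),
    ]
    inserted_ids = insert_records_batch(sample_docs)
    print(f"Inserted {len(inserted_ids)} records.")

    hits = similarity_search("What is Chroma used for?", limit=2)
    for hit in hits:
        print(hit.metadata, hit.content[:60])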