-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathRAGRetriever.py
More file actions
213 lines (171 loc) · 7.31 KB
/
RAGRetriever.py
File metadata and controls
213 lines (171 loc) · 7.31 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
import httpx
import json
from typing import List, Dict, Optional
from Configurator import AppConfig
class RAGRetriever:
    """
    Retrieval component for RAG pipeline.

    Embeds a natural-language query with the supplied embedder and runs a
    similarity search against a Qdrant collection through Qdrant's REST API.
    """

    def __init__(self, config: "AppConfig", embedder):
        """
        Initialize the RAG retriever.

        Args:
            config (AppConfig): Application configuration. This class reads
                ``qdrant_url``, ``qdrant_api_key``, ``qdrant_collection`` and
                (when no explicit ``top_k`` is passed) ``rag.top_k`` from it.
            embedder: Embedder instance exposing ``embed_texts(texts)``,
                used for query embedding.
        """
        self.config = config
        self.embedder = embedder
        self.qdrant_url = config.qdrant_url
        self.qdrant_api_key = config.qdrant_api_key
        self.collection_name = config.qdrant_collection

        # Setup headers for Qdrant API. The api-key header is only sent when a
        # key is actually configured: httpx cannot encode a header whose value
        # is None, so an unconditional entry would make every request fail
        # against an unauthenticated Qdrant instance.
        self.headers = {"Content-Type": "application/json"}
        if self.qdrant_api_key:
            self.headers["api-key"] = self.qdrant_api_key

    async def search_similar_chunks(self, query: str, top_k: Optional[int] = None) -> List[Dict]:
        """
        Search for similar chunks in Qdrant using query embedding.

        Args:
            query (str): User's natural language query
            top_k (int, optional): Number of top results to retrieve.
                Falls back to ``config.rag.top_k`` when omitted.

        Returns:
            List[Dict]: List of retrieved chunks with metadata and scores

        Raises:
            Exception: If embedding, searching or formatting fails. The
                original error is attached as ``__cause__``.
        """
        if top_k is None:
            top_k = self.config.rag.top_k

        try:
            # Step 1: Embed the query
            query_embedding = await self._embed_query(query)
            # Step 2: Search in Qdrant
            search_results = await self._search_qdrant(query_embedding, top_k)
            # Step 3: Format results
            return self._format_search_results(search_results)
        except Exception as e:
            raise Exception(f"RAG retrieval failed: {str(e)}") from e

    async def _embed_query(self, query: str) -> List[float]:
        """
        Embed the user query using the configured embedding model.

        Args:
            query (str): User's natural language query

        Returns:
            List[float]: Query embedding vector

        Raises:
            Exception: If the embedder fails or returns no embedding.
        """
        try:
            # NOTE(review): embed_texts is called synchronously; if it does
            # blocking work (model inference, HTTP) it will stall the event
            # loop for that time - confirm against the embedder in use.
            embeddings = self.embedder.embed_texts([query])
            vector = embeddings[0]
            # Array-like vectors (e.g. numpy) are not JSON serialisable, and
            # this vector is sent in a JSON request body; plain lists have no
            # ``tolist`` and pass through unchanged.
            if hasattr(vector, "tolist"):
                vector = vector.tolist()
            return vector
        except Exception as e:
            raise Exception(f"Query embedding failed: {str(e)}") from e

    @staticmethod
    def _build_search_payload(query_vector: List[float], top_k: int) -> Dict:
        """Build the JSON body for a Qdrant ``points/search`` request."""
        # No "score_threshold" key on purpose: the intent is "no threshold",
        # but an explicit 0.0 is a real cut-off in Qdrant (it drops negative
        # cosine/dot scores and nearly everything for distance metrics).
        return {
            "vector": query_vector,
            "limit": top_k,
            "with_payload": True,
            "with_vector": False,  # We don't need the vectors back
        }

    async def _search_qdrant(self, query_vector: List[float], top_k: int) -> Dict:
        """
        Perform similarity search in Qdrant.

        Args:
            query_vector (List[float]): Query embedding vector
            top_k (int): Number of results to retrieve

        Returns:
            Dict: Raw search results from Qdrant (decoded JSON response body)

        Raises:
            Exception: On transport errors, non-200 responses or any other
                failure. The original error is attached as ``__cause__``.
        """
        search_url = f"{self.qdrant_url}/collections/{self.collection_name}/points/search"
        search_payload = self._build_search_payload(query_vector, top_k)

        try:
            async with httpx.AsyncClient() as client:
                response = await client.post(
                    search_url,
                    headers=self.headers,
                    json=search_payload,
                    timeout=30.0,
                )
                if response.status_code == 200:
                    return response.json()
                # Raised inside the try on purpose: it is re-wrapped by the
                # generic handler below, which keeps the final message text
                # (including the "HTTP <code>" part) unchanged for callers.
                raise Exception(f"Qdrant search failed: HTTP {response.status_code} - {response.text}")
        except httpx.RequestError as e:
            raise Exception(f"Qdrant connection error: {str(e)}") from e
        except Exception as e:
            raise Exception(f"Qdrant search error: {str(e)}") from e

    def _format_search_results(self, search_results: Dict) -> List[Dict]:
        """
        Format Qdrant search results into standardized format.

        Args:
            search_results (Dict): Raw results from Qdrant

        Returns:
            List[Dict]: One dict per hit with the keys ``id``, ``score``,
            ``text``, ``source``, ``index`` and ``meta``. Empty when the
            response has no (or a null) ``result`` entry.
        """
        formatted_results = []
        # ``or`` also covers an explicit null, which ``.get(key, default)``
        # alone would pass through.
        for result in search_results.get("result") or []:
            payload = result.get("payload") or {}
            # Extract chunk information from payload
            formatted_results.append({
                "id": result.get("id"),
                "score": result.get("score", 0.0),
                "text": payload.get("text", ""),
                "source": payload.get("source", ""),
                "index": payload.get("index", 0),
                "meta": payload.get("meta", {}),
            })
        return formatted_results

    async def get_collection_info(self) -> Dict:
        """
        Get information about the Qdrant collection.

        Returns:
            Dict: Collection information (decoded JSON response body)

        Raises:
            Exception: On any failure. For non-200 responses the message
                contains ``HTTP <status code>``, which
                ``validate_connection`` relies on to detect a 404.
        """
        info_url = f"{self.qdrant_url}/collections/{self.collection_name}"

        try:
            async with httpx.AsyncClient() as client:
                response = await client.get(
                    info_url,
                    headers=self.headers,
                    timeout=10.0,
                )
                if response.status_code == 200:
                    return response.json()
                raise Exception(f"Failed to get collection info: HTTP {response.status_code}")
        except Exception as e:
            raise Exception(f"Collection info retrieval failed: {str(e)}") from e

    async def validate_connection(self) -> Dict:
        """
        Validate connection to Qdrant and collection.

        Returns:
            Dict: Validation results with the keys ``valid`` (bool),
            ``errors`` (list of str), ``warnings`` (list of str) and
            ``collection_info`` (raw collection info, a ``not_found``
            placeholder for a missing collection, or None on other errors).
        """
        validation_results = {
            "valid": True,
            "errors": [],
            "warnings": [],
            "collection_info": None,
        }

        try:
            # Test connection to Qdrant
            collection_info = await self.get_collection_info()
            validation_results["collection_info"] = collection_info

            # Qdrant's REST API nests the collection details under a
            # top-level "result" object; reading "points_count" from the top
            # level always gave 0. A flat dict is still accepted as fallback.
            result_section = collection_info.get("result")
            details = result_section if isinstance(result_section, dict) else collection_info

            # Check if collection has points (a missing/null count is 0)
            points_count = details.get("points_count") or 0
            if points_count == 0:
                validation_results["warnings"].append("Collection is empty - no documents to search")
        except Exception as e:
            error_msg = str(e)
            if "HTTP 404" in error_msg:
                # Collection doesn't exist yet - this is expected for new setups
                validation_results["warnings"].append("Qdrant collection does not exist yet - needs to be created and populated")
                validation_results["collection_info"] = {"status": "not_found", "points_count": 0}
            else:
                # Other connection errors
                validation_results["valid"] = False
                validation_results["errors"].append(f"Connection validation failed: {error_msg}")

        return validation_results