-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathRAGProcessor.py
More file actions
270 lines (209 loc) · 9.01 KB
/
RAGProcessor.py
File metadata and controls
270 lines (209 loc) · 9.01 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
from typing import List, Dict
from Configurator import AppConfig, RerankMethod
class RAGProcessor:
    """
    Post-processing component for RAG pipeline.
    Handles re-ranking, context assembly, and prompt construction.
    """

    # Rough words-to-tokens conversion factor; shared by every token
    # estimate so context assembly and estimate_tokens() cannot drift apart.
    _TOKENS_PER_WORD = 1.3

    def __init__(self, config: "AppConfig"):
        """
        Initialize the RAG processor.

        Args:
            config (AppConfig): Application configuration. Only the
                ``config.rag`` section is read (rerank_method, final_chunks,
                mmr_diversity_threshold, context_max_tokens, system_prompt).
        """
        self.config = config

    def rerank_chunks(self, chunks: List[Dict], query: str) -> List[Dict]:
        """
        Re-rank retrieved chunks based on the configured method.

        Args:
            chunks (List[Dict]): Retrieved chunks with scores
            query (str): Original user query

        Returns:
            List[Dict]: Re-ranked chunks (at most ``final_chunks`` items)
        """
        if not chunks:
            return chunks

        rerank_method = self.config.rag.rerank_method
        if rerank_method == RerankMethod.MMR:
            return self._rerank_mmr(chunks, query)
        elif rerank_method == RerankMethod.CROSS_ENCODER:
            return self._rerank_cross_encoder(chunks, query)
        elif rerank_method == RerankMethod.LLM_RERANK:
            return self._rerank_llm(chunks, query)
        else:
            # Default: return top chunks by score
            ranked = sorted(chunks, key=lambda x: x.get("score", 0), reverse=True)
            return ranked[:self.config.rag.final_chunks]

    def _rerank_mmr(self, chunks: List[Dict], query: str) -> List[Dict]:
        """
        Re-rank using Maximal Marginal Relevance (MMR).
        Balances relevance and diversity.

        Args:
            chunks (List[Dict]): Retrieved chunks
            query (str): Original query

        Returns:
            List[Dict]: MMR re-ranked chunks
        """
        if len(chunks) <= self.config.rag.final_chunks:
            return chunks

        # Sort by score first; the single most relevant chunk is always selected.
        chunks = sorted(chunks, key=lambda x: x.get("score", 0), reverse=True)
        selected_chunks = [chunks[0]]
        remaining_chunks = chunks[1:]

        lam = self.config.rag.mmr_diversity_threshold

        # MMR selection for remaining chunks
        while len(selected_chunks) < self.config.rag.final_chunks and remaining_chunks:
            best_chunk = None
            # -inf, not -1: relevance scores may be negative (e.g. raw
            # inner-product similarity), in which case every MMR score can
            # fall below -1 and a -1 sentinel would reject all candidates.
            best_mmr_score = float("-inf")

            for chunk in remaining_chunks:
                # MMR score: lambda * relevance - (1 - lambda) * max_similarity
                relevance = chunk.get("score", 0)
                max_similarity = self._calculate_max_similarity(chunk, selected_chunks)
                mmr_score = lam * relevance - (1 - lam) * max_similarity
                if mmr_score > best_mmr_score:
                    best_mmr_score = mmr_score
                    best_chunk = chunk

            # Explicit None check: a falsy chunk (e.g. an empty dict) must
            # still count as a valid selection.
            if best_chunk is None:
                break
            selected_chunks.append(best_chunk)
            remaining_chunks.remove(best_chunk)

        return selected_chunks

    def _rerank_cross_encoder(self, chunks: List[Dict], query: str) -> List[Dict]:
        """
        Re-rank using cross-encoder (placeholder implementation).
        TODO: Implement actual cross-encoder model.

        Args:
            chunks (List[Dict]): Retrieved chunks
            query (str): Original query

        Returns:
            List[Dict]: Cross-encoder re-ranked chunks
        """
        # For now, fall back to MMR
        return self._rerank_mmr(chunks, query)

    def _rerank_llm(self, chunks: List[Dict], query: str) -> List[Dict]:
        """
        Re-rank using LLM (placeholder implementation).
        TODO: Implement LLM-based re-ranking.

        Args:
            chunks (List[Dict]): Retrieved chunks
            query (str): Original query

        Returns:
            List[Dict]: LLM re-ranked chunks
        """
        # For now, fall back to MMR
        return self._rerank_mmr(chunks, query)

    def _calculate_max_similarity(self, chunk: Dict, selected_chunks: List[Dict]) -> float:
        """
        Calculate maximum similarity between a chunk and selected chunks.
        Simple text-based similarity for now.

        Args:
            chunk (Dict): Chunk to compare
            selected_chunks (List[Dict]): Already selected chunks

        Returns:
            float: Maximum similarity score
        """
        if not selected_chunks:
            return 0.0

        chunk_text = chunk.get("text", "").lower()
        return max(
            self._calculate_text_similarity(chunk_text, sel.get("text", "").lower())
            for sel in selected_chunks
        )

    def _calculate_text_similarity(self, text1: str, text2: str) -> float:
        """
        Calculate simple text similarity based on word overlap (Jaccard).

        Args:
            text1 (str): First text
            text2 (str): Second text

        Returns:
            float: Similarity score between 0 and 1
        """
        words1 = set(text1.split())
        words2 = set(text2.split())
        if not words1 or not words2:
            return 0.0
        # Both sets are non-empty here, so the union is never empty.
        return len(words1 & words2) / len(words1 | words2)

    def assemble_context(self, chunks: List[Dict]) -> str:
        """
        Assemble selected chunks into a single context block.

        Chunks are appended in order until the configured token budget
        (``context_max_tokens``) would be exceeded. The first chunk is
        always kept, even if it alone exceeds the budget, so the context
        is never empty when chunks were provided.

        Args:
            chunks (List[Dict]): Selected chunks after re-ranking

        Returns:
            str: Assembled context with source citations
        """
        if not chunks:
            return ""

        context_parts = []
        current_tokens = 0
        max_tokens = self.config.rag.context_max_tokens

        for chunk in chunks:
            chunk_text = chunk.get("text", "")
            source = chunk.get("source", "unknown")
            index = chunk.get("index", 0)

            # Use the shared estimator so the budget check stays consistent
            # with estimate_tokens() used elsewhere in the pipeline.
            chunk_tokens = self.estimate_tokens(chunk_text)

            # Check if adding this chunk would exceed token limit
            if current_tokens + chunk_tokens > max_tokens and context_parts:
                break

            # Add chunk with source citation
            citation = f"[Source: {source}, Chunk: {index}]"
            context_parts.append(f"{chunk_text}\n{citation}")
            current_tokens += chunk_tokens

        return "\n\n".join(context_parts)

    def construct_prompt(self, query: str, context: str, conversation_context: str = "") -> List[Dict]:
        """
        Construct the prompt for the LLM with system, context, and user messages.

        Args:
            query (str): User's original query
            context (str): Assembled context from retrieved chunks
            conversation_context (str): Previous conversation context (optional)

        Returns:
            List[Dict]: Messages for the LLM (system message, then user message)
        """
        # Build system message with conversation context if available
        system_content = self.config.rag.system_prompt
        if conversation_context:
            system_content = f"{system_content}\n\n{conversation_context}"

        system_message = {
            "role": "system",
            "content": system_content
        }
        context_message = {
            "role": "user",
            "content": f"Context:\n{context}\n\nQuery: {query}"
        }
        return [system_message, context_message]

    def extract_sources(self, chunks: List[Dict]) -> List[Dict]:
        """
        Extract source information from chunks for citation.

        Args:
            chunks (List[Dict]): Selected chunks

        Returns:
            List[Dict]: Source information for citations
        """
        sources = []
        for chunk in chunks:
            text = chunk.get("text", "")
            # Truncate long chunk text to a 200-character preview.
            snippet = text[:200] + "..." if len(text) > 200 else text
            sources.append({
                "source": chunk.get("source", "unknown"),
                "chunk_index": chunk.get("index", 0),
                "snippet": snippet,
                "score": chunk.get("score", 0.0),
            })
        return sources

    def estimate_tokens(self, text: str) -> int:
        """
        Estimate token count for text.

        Rough approximation: whitespace-separated word count * 1.3.

        Args:
            text (str): Text to estimate

        Returns:
            int: Estimated token count
        """
        return int(len(text.split()) * self._TOKENS_PER_WORD)