-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathchatbot_fast.py
More file actions
250 lines (201 loc) · 7.31 KB
/
chatbot_fast.py
File metadata and controls
250 lines (201 loc) · 7.31 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
import os
import uuid
from fastapi import FastAPI, UploadFile, File, Form, Depends, Request
from fastapi.responses import StreamingResponse, JSONResponse
from pydantic import BaseModel
from fastapi.staticfiles import StaticFiles
from fastapi.responses import FileResponse
from sqlalchemy.orm import Session
from starlette.middleware.base import BaseHTTPMiddleware
from langchain_core.prompts import ChatPromptTemplate
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain.chains import create_history_aware_retriever
from langchain_core.messages import HumanMessage, AIMessage
from config import (
llm, get_vector_store, UPLOAD_DIR, ensure_collection,
qdrant_client, QDRANT_COLLECTION, REDIS_URL, DATABASE_URL,
)
from database import create_tables, get_db, engine
from models import IngestionJob, JobStatus
from redis_client import (
get_cached_response,
set_cached_response,
get_chat_history,
append_chat_history,
get_job_queue,
get_redis,
)
from logger import get_logger, request_id_var, new_request_id
log = get_logger("api")  # module-level logger tagged "api"; request IDs injected via middleware

# FastAPI application serving both the ingestion API and the chat frontend.
app = FastAPI(title="Dealer RAG API")
app.mount("/static", StaticFiles(directory="static"), name="static")
# Ensure the upload target directory exists before the first /upload request.
os.makedirs(UPLOAD_DIR, exist_ok=True)
class RequestIDMiddleware(BaseHTTPMiddleware):
    """Tag every request/response pair with an X-Request-ID and log both ends."""

    async def dispatch(self, request: Request, call_next):
        # Honor a caller-supplied X-Request-ID; otherwise mint a fresh one so
        # all log lines for this request share a correlation id.
        incoming = request.headers.get("X-Request-ID")
        rid = incoming if incoming is not None else new_request_id()
        request_id_var.set(rid)

        method, path = request.method, request.url.path
        log.info(f"{method} {path}")
        response = await call_next(request)
        # Echo the id back so clients can report it when filing issues.
        response.headers["X-Request-ID"] = rid
        log.info(f"{method} {path} -> {response.status_code}")
        return response
app.add_middleware(RequestIDMiddleware)  # installed app-wide so every route gets a request id
@app.on_event("startup")
def startup():
    """Prepare backing stores (DB tables, vector collection) before serving.

    NOTE(review): ``@app.on_event`` is deprecated in recent FastAPI releases in
    favor of lifespan handlers — fine to keep until the FastAPI version is bumped.
    """
    create_tables()      # Postgres tables (presumably create-if-missing — verify in database.py)
    ensure_collection()  # Qdrant collection (presumably create-if-missing — verify in config.py)
    log.info("API started, tables and collection ready")
@app.get("/")
def serve_frontend():
    """Serve the single-page chat UI shipped under ./static."""
    return FileResponse("static/index.html")
@app.get("/health")
def health():
    """Liveness probe: the process is up and able to answer HTTP."""
    payload = {"status": "ok"}
    return payload
@app.get("/ready")
def readiness():
    """Readiness probe: verify Redis, Postgres, and Qdrant are reachable.

    Returns 200 with per-dependency statuses when all are up, 503 otherwise.
    """
    def _probe_redis():
        get_redis().ping()

    def _probe_postgres():
        # Open and immediately close a connection to prove the DB is reachable.
        engine.connect().close()

    def _probe_qdrant():
        qdrant_client.get_collections()

    # Table-driven: one (name, probe) pair per dependency, checked in order.
    probes = (
        ("redis", _probe_redis),
        ("postgres", _probe_postgres),
        ("qdrant", _probe_qdrant),
    )
    checks = {}
    for name, probe in probes:
        try:
            probe()
            checks[name] = "ok"
        except Exception as e:
            checks[name] = f"error: {e}"

    all_ok = all(v == "ok" for v in checks.values())
    return JSONResponse(
        content={"status": "ready" if all_ok else "degraded", "checks": checks},
        status_code=200 if all_ok else 503,
    )
class QuestionRequest(BaseModel):
    """Request body for POST /ask."""
    question: str  # the user's natural-language question
    session_id: str = "default"  # keys the per-conversation chat history in Redis
@app.post("/upload")
async def upload_pdf(
    file: UploadFile = File(...),
    ocr: str = Form("false"),
    db: Session = Depends(get_db),
):
    """Accept a document upload, persist it to disk, and enqueue ingestion.

    Creates a PENDING IngestionJob row, hands the file off to the RQ worker
    (``worker.process_ingestion``), and returns 202 with a ``job_id`` the
    client can poll via GET /job/{job_id}.
    """
    enable_ocr = ocr.lower() == "true"
    job_id = str(uuid.uuid4())
    log.info(f"Upload received: {file.filename}, job_id={job_id}, ocr={enable_ocr}")
    # file.filename can be None on malformed multipart requests; guard so
    # os.path.splitext doesn't raise TypeError.
    original_name = file.filename or ""
    file_ext = os.path.splitext(original_name)[1]
    # Store under the job id rather than the client-supplied name: avoids
    # path traversal and collisions between concurrent uploads.
    safe_filename = f"{job_id}{file_ext}"
    file_path = os.path.join(UPLOAD_DIR, safe_filename)
    # Stream to disk in 1 MiB chunks instead of buffering the whole upload
    # in memory (the previous single `await file.read()` did).
    with open(file_path, "wb") as f:
        while chunk := await file.read(1024 * 1024):
            f.write(chunk)
    job = IngestionJob(
        id=job_id,
        filename=file.filename,
        status=JobStatus.PENDING,
    )
    db.add(job)
    db.commit()
    log.info(f"Job {job_id} created as PENDING")
    queue = get_job_queue()
    queue.enqueue(
        "worker.process_ingestion",
        job_id,
        file_path,
        file.filename,
        enable_ocr,
        job_timeout="30m",  # OCR on large PDFs can legitimately take this long
    )
    log.info(f"Job {job_id} enqueued to Redis")
    return JSONResponse(
        content={"job_id": job_id, "status": "PENDING"},
        status_code=202,
    )
@app.get("/job/{job_id}")
async def get_job_status(job_id: str, db: Session = Depends(get_db)):
    """Return the current state of an ingestion job, or 404 if unknown."""
    found = (
        db.query(IngestionJob)
        .filter(IngestionJob.id == job_id)
        .first()
    )
    # Happy path first: serialize the row when it exists.
    if found:
        return JSONResponse(content=found.to_dict())
    return JSONResponse(
        content={"error": "Job not found"}, status_code=404
    )
@app.post("/ask")
async def ask_question(request: QuestionRequest):
    """Answer a question over the ingested documents, streamed as plain text.

    Flow: exact-question cache check -> load per-session chat history from
    Redis -> history-aware retrieval from the vector store -> stuff-documents
    QA chain streamed to the client, with the full answer written back to the
    cache and chat history once the stream completes.
    """
    session_id = request.session_id
    log.info(f"Question from session={session_id}: {request.question[:80]}")
    # NOTE(review): the cache is keyed on the question text alone, ignoring
    # session and chat history — two different conversations asking the same
    # question share one cached answer. Confirm that is intended.
    cached = get_cached_response(request.question)
    if cached:
        log.info("Cache hit, returning cached response")
        return StreamingResponse(
            iter([cached]), media_type="text/plain"
        )
    # Rebuild LangChain message objects from the Redis-stored history entries
    # (dicts with "role" and "content"; anything non-"human" becomes an AIMessage).
    history_data = get_chat_history(session_id)
    chat_history = []
    for entry in history_data:
        if entry["role"] == "human":
            chat_history.append(HumanMessage(content=entry["content"]))
        else:
            chat_history.append(AIMessage(content=entry["content"]))
    vector_store = get_vector_store()
    retriever = vector_store.as_retriever(search_kwargs={"k": 8})  # top-8 chunks
    # Pass 1: let the LLM rewrite the question into standalone form using the
    # chat history, then retrieve with the rewritten query.
    contextualize_prompt = ChatPromptTemplate.from_messages([
        (
            "system",
            "Rephrase the user question into standalone form if needed.",
        ),
        ("placeholder", "{chat_history}"),
        ("human", "{input}"),
    ])
    history_aware_retriever = create_history_aware_retriever(
        llm, retriever, contextualize_prompt
    )
    retrieved_docs = history_aware_retriever.invoke({
        "input": request.question,
        "chat_history": chat_history,
    })
    log.info(f"Retrieved {len(retrieved_docs)} docs for context")
    # Pass 2: strict extractive-QA system prompt; the retrieved docs are
    # stuffed into {context} by create_stuff_documents_chain below.
    system_prompt = (
        "Rules:\n"
        "1. Answer ONLY using the provided document context.\n"
        "2. Do NOT use prior knowledge.\n"
        "3. If the answer is not explicitly present, reply EXACTLY:\n"
        "   Not mentioned in the document.\n"
        "4. Do NOT include phrases like 'based on the document' or 'it seems'.\n"
        "5. Be precise, factual, and direct.\n"
        "Formatting Rules:\n"
        "6. If the answer contains multiple facts, return them as bullet points.\n"
        "7. If numerical values are present, include them exactly as written.\n"
        "8. Prefer structured outputs over paragraphs.\n"
        "Behavior:\n"
        "9. Do NOT explain reasoning.\n"
        "10. Do NOT add extra commentary.\n"
        "11. Extract, don't generate.\n\n"
        "12. If possible, include the source section or page reference.\n"
        "Document Context:\n{context}"
    )
    qa_prompt = ChatPromptTemplate.from_messages([
        ("system", system_prompt),
        ("placeholder", "{chat_history}"),
        ("human", "{input}"),
    ])
    question_answer_chain = create_stuff_documents_chain(llm, qa_prompt)
    def stream_response():
        # Accumulate the full answer while streaming chunks to the client;
        # cache/history writes happen only after the stream finishes, so a
        # client disconnect or mid-stream error skips them.
        full_response = ""
        for chunk in question_answer_chain.stream({
            "input": request.question,
            "chat_history": chat_history,
            "context": retrieved_docs,
        }):
            full_response += chunk
            yield chunk
        set_cached_response(request.question, full_response)
        append_chat_history(session_id, request.question, full_response)
        log.info(f"Response streamed, length={len(full_response)}")
    return StreamingResponse(stream_response(), media_type="text/plain")