# main.py
import os
import json
import time

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel

from app.crawler import crawl_site
from app.indexer import index_pages, load_index_and_meta
from app.qa import answer_question

app = FastAPI(title="RAG Service")

# --------------------------
# Request Models
# --------------------------
class CrawlRequest(BaseModel):
    start_url: str
    max_pages: int = 50
    max_depth: int = 3
    crawl_delay_ms: int = 500


class IndexRequest(BaseModel):
    source: str  # path to crawl folder
    chunk_size: int = 800
    chunk_overlap: int = 100
    embedding_model: str = "all-MiniLM-L6-v2"


class AskRequest(BaseModel):
    question: str
    top_k: int = 5

# --------------------------
# /crawl endpoint
# --------------------------
@app.post("/crawl")
async def crawl(req: CrawlRequest):
    """Crawl pages starting from a URL."""
    try:
        # Create a timestamped folder for this crawl
        crawl_dir = os.path.join("data", "crawls")
        os.makedirs(crawl_dir, exist_ok=True)
        timestamp = time.strftime("%Y%m%d_%H%M%S")
        crawl_path = os.path.join(crawl_dir, f"crawl_{timestamp}")
        os.makedirs(crawl_path, exist_ok=True)

        pages, skipped = await crawl_site(
            start_url=req.start_url,
            max_pages=req.max_pages,
            max_depth=req.max_depth,
            delay_ms=req.crawl_delay_ms,
        )
        if pages is None:
            raise HTTPException(status_code=500, detail="Crawl returned None")

        # Save each crawled page to crawl_path as JSON
        for i, p in enumerate(pages):
            fname = os.path.join(crawl_path, f"page_{i+1}.json")
            with open(fname, "w", encoding="utf-8") as f:
                json.dump(p, f, ensure_ascii=False)

        return {
            "page_count": len(pages),
            "skipped_count": skipped,
            "urls": [p["url"] for p in pages],
            "crawl_path": crawl_path,
        }
    except HTTPException:
        # Re-raise HTTPExceptions as-is so they aren't re-wrapped below
        # (matches the pattern used by /index and /ask)
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

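# A minimal client sketch for /crawl (not part of the original file; assumes
# the service runs locally on port 8000 and `requests` is installed):
#
#   import requests
#   resp = requests.post(
#       "http://localhost:8000/crawl",
#       json={"start_url": "https://example.com", "max_pages": 10},
#   )
#   print(resp.json()["crawl_path"])  # pass this path to /index as `source`
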
# --------------------------
# /index endpoint
# --------------------------
@app.post("/index")
def index(req: IndexRequest):
    """Index crawled pages into FAISS."""
    try:
        if not os.path.exists(req.source):
            raise HTTPException(status_code=400, detail=f"Source folder '{req.source}' does not exist")

        # Read all JSON pages from the crawl folder
        pages = []
        for fname in os.listdir(req.source):
            full_path = os.path.join(req.source, fname)
            if os.path.isfile(full_path) and fname.endswith(".json"):
                try:
                    with open(full_path, "r", encoding="utf-8") as f:
                        pages.append(json.load(f))
                except Exception as e:
                    print(f"Warning: Failed to read {full_path}: {e}")
        if not pages:
            raise HTTPException(status_code=400, detail="No JSON pages found in the source folder")

        # Index pages
        vector_count, errors = index_pages(
            pages,
            chunk_size=req.chunk_size,
            chunk_overlap=req.chunk_overlap,
            embedding_model=req.embedding_model,
        )
        print(f"Indexed {vector_count} vectors with {len(errors)} errors")
        return {
            "vector_count": vector_count,
            "errors": errors,
        }
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Unexpected error: {e}")

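# A minimal client sketch for /index (not part of the original file; assumes
# the service runs locally on port 8000). The `source` value is the
# `crawl_path` returned by /crawl; the path below is illustrative only:
#
#   import requests
#   resp = requests.post(
#       "http://localhost:8000/index",
#       json={"source": "data/crawls/crawl_20240101_120000"},
#   )
#   print(resp.json()["vector_count"])
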
# --------------------------
# /ask endpoint
# --------------------------
@app.post("/ask")
def ask(req: AskRequest):
    """Answer a question using the indexed vectors."""
    try:
        if not req.question or not req.question.strip():
            raise HTTPException(status_code=400, detail="Question cannot be empty")

        response = answer_question(req.question, top_k=req.top_k)
        if response is None:
            raise HTTPException(status_code=500, detail="Failed to generate an answer")

        # Log timings for observability
        timings = response.get("timings", {})
        print(f"Ask: retrieval={timings.get('retrieval_ms')}ms, "
              f"generation={timings.get('generation_ms')}ms, total={timings.get('total_ms')}ms")
        return response
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Unexpected error: {e}")
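
# A minimal client sketch for /ask (not part of the original file; assumes the
# service runs locally on port 8000):
#
#   import requests
#   resp = requests.post(
#       "http://localhost:8000/ask",
#       json={"question": "What does the site say about pricing?", "top_k": 5},
#   )
#   print(resp.json())

# Optional entry point so the module can be run directly. A sketch, not part
# of the original file; it assumes `uvicorn` is installed, the usual server
# for a FastAPI app:
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)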