-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathapi.py
More file actions
160 lines (123 loc) · 4.79 KB
/
api.py
File metadata and controls
160 lines (123 loc) · 4.79 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
"""
api.py – FastAPI server that exposes Baymax RAG chat for a VS Code extension.
Start with:
uvicorn api:app --host 127.0.0.1 --port 8888 --reload
Endpoints
---------
POST /chat – send a message and get an answer
GET /history/{session_id} – retrieve conversation history for a session
DELETE /history/{session_id} – clear conversation history for a session
GET /health – liveness check
"""
import os
import threading
import uuid
from collections import defaultdict
from typing import Optional

from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel

from data_manager import finalize_chroma_swap
from chat import Chat
# ---------------------------------------------------------------------------
# Boot-time setup
# ---------------------------------------------------------------------------
# Load .env first, letting its values override already-exported variables, so
# later os.getenv() lookups (e.g. CHAT_MODEL in /health) see the .env values.
load_dotenv(override=True)
# Same startup guard as streamlit.py.
# NOTE(review): presumably completes a pending Chroma store swap — confirm in
# data_manager.finalize_chroma_swap.
finalize_chroma_swap()

app = FastAPI(
    title="Baymax RAG API",
    version="1.0.0",
    description=(
        "REST interface for the Baymax RAG chat system. "
        "Designed for consumption by the VS Code extension."
    ),
)

# Allow VS Code webview / localhost callers.
# NOTE(review): a "*" origin lets any web page open in a local browser call
# this API on 127.0.0.1; consider restricting to the webview origin(s).
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["*"],
)
# ---------------------------------------------------------------------------
# Singleton Chat instance (lazily created, shared across all sessions)
# ---------------------------------------------------------------------------
# NOTE(review): every request calls query() on this one shared object —
# confirm Chat is safe for concurrent use.
_chat: Optional[Chat] = None
# FastAPI runs plain ``def`` endpoints in a worker thread pool, so the first
# requests can enter get_chat() concurrently; this lock ensures only one Chat
# is ever constructed.
_chat_lock = threading.Lock()


def get_chat() -> Chat:
    """Return the process-wide Chat instance, creating it on first use.

    Uses double-checked locking: once the instance exists, callers take the
    unlocked fast path and never contend on the lock.

    Returns:
        The shared ``Chat("api")`` instance.

    Raises:
        Whatever ``Chat("api")`` raises if construction fails; the next call
        will retry construction.
    """
    global _chat
    if _chat is None:
        with _chat_lock:
            # Re-check under the lock: another thread may have finished
            # constructing the instance while we were waiting.
            if _chat is None:
                _chat = Chat("api")
    return _chat
# ---------------------------------------------------------------------------
# In-memory session history { session_id -> [ {role, content}, … ] }
# ---------------------------------------------------------------------------
# Process-local and unbounded: entries live until DELETE /history/{session_id}
# removes them or the server restarts.
# Because this is a defaultdict, *indexing* an unknown id silently creates an
# empty list — read-only lookups should use ``in`` or ``.get()`` instead.
_history: dict[str, list[dict]] = defaultdict(list)
# ---------------------------------------------------------------------------
# Request / Response schemas
# ---------------------------------------------------------------------------
class ChatRequest(BaseModel):
    """Request body of ``POST /chat``."""

    message: str  # user prompt; blank/whitespace-only is rejected with 400 by the endpoint
    session_id: Optional[str] = None  # omit → auto-generated fresh session
class MessageRecord(BaseModel):
    """One stored turn of a conversation."""

    role: str  # "user" | "assistant" (not enforced by the str type)
    content: str  # the text of that turn
class ChatResponse(BaseModel):
    """Response body of ``POST /chat``."""

    session_id: str  # the supplied session id, or the newly generated one
    message: str  # the assistant's answer to this request
    history: list[MessageRecord]  # full session log, including this exchange
class HistoryResponse(BaseModel):
    """Response body of ``GET /history/{session_id}``."""

    session_id: str  # echo of the requested session id
    history: list[MessageRecord]  # all stored turns, oldest first
class HealthResponse(BaseModel):
    """Response body of ``GET /health``."""

    status: str  # always "ok" when produced by the /health endpoint
    chat_model: str  # CHAT_MODEL env var, or "default" when unset
    use_graph: bool  # mirrors the shared Chat instance's use_graph attribute
# ---------------------------------------------------------------------------
# Routes
# ---------------------------------------------------------------------------
@app.get("/health", response_model=HealthResponse, tags=["Utility"])
def health():
    """Quick liveness / config check.

    Reports the configured chat model name and whether the shared Chat
    instance runs in graph mode. If no other request has touched the Chat
    instance yet, this call constructs it and may therefore be slow.
    """
    shared_chat = get_chat()
    model_name = os.getenv("CHAT_MODEL", "default")
    return HealthResponse(
        status="ok",
        chat_model=model_name,
        use_graph=shared_chat.use_graph,
    )
@app.post("/chat", response_model=ChatResponse, tags=["Chat"])
def chat_endpoint(req: ChatRequest):
    """
    Send a message to Baymax and receive an answer.
    - **message**: the user prompt (must not be blank)
    - **session_id**: optional; supply the value returned by a previous call
      to keep appending to the same conversation in the history log.
      If omitted a new session UUID is created.

    Only `message` itself is sent to the model; earlier turns in the history
    log are not forwarded by this endpoint. The exchange is recorded only if
    the query succeeds.

    Responds with 400 for a blank message and 500 ("Chat error: …") if the
    chat backend fails to initialise or answer.
    """
    if not req.message.strip():
        raise HTTPException(status_code=400, detail="message cannot be empty")
    session_id = req.session_id or str(uuid.uuid4())
    try:
        # Lazy Chat construction can fail too; report it like a query
        # failure instead of letting it escape as an opaque 500.
        answer = get_chat().query(req.message)
    except Exception as exc:
        raise HTTPException(status_code=500, detail=f"Chat error: {exc}") from exc
    # Record the exchange only after success, so a failed query leaves
    # neither an unanswered user turn nor an empty phantom session behind.
    # One extend() call keeps the user/assistant pair adjacent (in CPython)
    # even when requests for the same session run concurrently.
    turns = _history[session_id]
    turns.extend(
        [
            {"role": "user", "content": req.message},
            {"role": "assistant", "content": answer},
        ]
    )
    return ChatResponse(
        session_id=session_id,
        message=answer,
        history=[MessageRecord(**m) for m in turns],
    )
@app.get("/history/{session_id}", response_model=HistoryResponse, tags=["Chat"])
def get_history(session_id: str):
    """Return the full conversation history for a session.

    Responds with 404 when the session id is unknown or has been cleared.
    """
    # .get() rather than indexing: indexing the defaultdict would silently
    # create an empty session for an unknown id.
    turns = _history.get(session_id)
    if turns is None:
        raise HTTPException(status_code=404, detail="Session not found")
    records = [MessageRecord(**turn) for turn in turns]
    return HistoryResponse(session_id=session_id, history=records)
@app.delete("/history/{session_id}", tags=["Chat"])
def clear_history(session_id: str):
    """Clear conversation history for a session.

    Idempotent: an unknown or already-cleared session id still succeeds.
    """
    # Single pop() with a default: no KeyError and no check-then-delete race.
    _history.pop(session_id, None)
    confirmation = f"History for session {session_id!r} cleared."
    return {"detail": confirmation}