diff --git a/multi_llm_chatbot_backend/app/api/routes/chat.py b/multi_llm_chatbot_backend/app/api/routes/chat.py index 5e27b9b..f575b01 100644 --- a/multi_llm_chatbot_backend/app/api/routes/chat.py +++ b/multi_llm_chatbot_backend/app/api/routes/chat.py @@ -25,6 +25,7 @@ # Enhanced data models class UserInput(BaseModel): user_input: str + chat_session_id: Optional[str] = None class ChatMessage(BaseModel): user_input: str @@ -62,6 +63,42 @@ def to_ndjson(self) -> str: return json.dumps(self.model_dump(mode="json"), ensure_ascii=False) + "\n" +def build_user_persist_message( + content: str, + **extra, +) -> Dict[str, Any]: + """Build the dict persisted to MongoDB for a user message.""" + msg = { + "id": str(ObjectId()), + "type": "user", + "content": content, + } + msg.update(extra) + return msg + + +def build_advisor_persist_message( + persona_id: str, + persona_name: str, + content: str, + used_documents: bool = False, + document_chunks_used: int = 0, + **extra, +) -> Dict[str, Any]: + """Build the dict persisted to MongoDB for an advisor/orchestrator response.""" + msg = { + "id": str(ObjectId()), + "type": "advisor", + "persona_id": persona_id, + "advisorName": persona_name, + "content": content, + "used_documents": used_documents, + "document_chunks_used": document_chunks_used, + } + msg.update(extra) + return msg + + @router.post("/chat-stream") async def chat_stream( message: ChatMessage, @@ -95,11 +132,10 @@ async def _event_generator(): # Append user message to in-memory session and persist to MongoDB session.append_message("user", message.user_input) if message.chat_session_id: - await persist_message(message.chat_session_id, { - "id": str(ObjectId()), - "type": "user", - "content": message.user_input, - }) + await persist_message( + message.chat_session_id, + build_user_persist_message(content=message.user_input), + ) if await chat_orchestrator.needs_clarification_improved(session, message.user_input): clar = await chat_orchestrator.generate_contextual_clarification(message.user_input) @@ -120,7 +156,17 @@ async def _event_generator(): # directly and skip persona generation. tool_result = await chat_orchestrator.get_tool_response(message.user_input) if tool_result.used_tool: + # Append user message to in-memory session and persist to MongoDB session.append_message("orchestrator", tool_result.text) + if message.chat_session_id: + await persist_message( + message.chat_session_id, + build_advisor_persist_message( + persona_id="orchestrator", + persona_name="Orchestrator", + content=tool_result.text, + ), + ) yield ChatStreamLine( type="advisor", data={ @@ -167,6 +213,17 @@ async def _run(pid: str) -> None: for _ in range(len(tasks)): result = await done_queue.get() + if message.chat_session_id: + await persist_message( + message.chat_session_id, + build_advisor_persist_message( + persona_id=result["persona_id"], + persona_name=result["persona_name"], + content=result["response"], + used_documents=result.get("used_documents", False), + document_chunks_used=result.get("document_chunks_used", 0), + ), + ) line = ChatStreamLine( type="advisor", data={ @@ -323,7 +380,16 @@ async def chat_with_specific_advisor(persona_id: str, input: UserInput, request: # Use async session management session_id = await get_or_create_session_for_request_async(request) - + + if input.chat_session_id: + await persist_message( + input.chat_session_id, + build_user_persist_message( + content=input.user_input, + isExpandRequest=True, + ), + ) + result = await chat_orchestrator.chat_with_persona( user_input=input.user_input, persona_id=persona_id, @@ -333,12 +399,32 @@ async def chat_with_specific_advisor(persona_id: str, input: UserInput, request: # Handle response structure if result.get("type") == "single_persona_response" and "persona" in result: persona_data = result["persona"] + if input.chat_session_id: + await persist_message( + input.chat_session_id, + build_advisor_persist_message( + persona_id=persona_data["persona_id"], + persona_name=persona_data["persona_name"], + content=persona_data["response"], + isExpansion=True, + ), + ) return { "persona": persona_data["persona_name"], "persona_id": persona_data["persona_id"], "response": persona_data["response"] } elif "persona_id" in result and "response" in result: + if input.chat_session_id: + await persist_message( + input.chat_session_id, + build_advisor_persist_message( + persona_id=result["persona_id"], + persona_name=result["persona_name"], + content=result["response"], + isExpansion=True, + ), + ) return { "persona": result["persona_name"], "persona_id": result["persona_id"], @@ -373,7 +459,19 @@ async def reply_to_advisor(reply: ReplyToAdvisor, request: Request): session_id = await get_or_create_session_for_request_async(request) session = session_manager.get_session(session_id) - + + if reply.chat_session_id: + await persist_message( + reply.chat_session_id, + build_user_persist_message( + content=reply.user_input, + replyTo={ + "advisorId": reply.advisor_id, + "messageId": reply.original_message_id, + }, + ), + ) + # Find the original message being replied to for context original_message = None if reply.original_message_id: @@ -396,6 +494,16 @@ async def reply_to_advisor(reply: ReplyToAdvisor, request: Request): # Handle response structure if result.get("type") == "single_persona_response" and "persona" in result: persona_data = result["persona"] + if reply.chat_session_id: + await persist_message( + reply.chat_session_id, + build_advisor_persist_message( + persona_id=persona_data["persona_id"], + persona_name=persona_data["persona_name"], + content=persona_data["response"], + isReply=True, + ), + ) return { "type": "advisor_reply", "persona": persona_data["persona_name"], @@ -404,6 +512,16 @@ async def reply_to_advisor(reply: ReplyToAdvisor, request: Request): "original_message_id": reply.original_message_id } elif "persona_id" in result and "response" in result: + if reply.chat_session_id: + await persist_message( + reply.chat_session_id, + build_advisor_persist_message( + persona_id=result["persona_id"], + persona_name=result["persona_name"], + content=result["response"], + isReply=True, + ), + ) return { "type": "advisor_reply", "persona": result["persona_name"], diff --git a/multi_llm_chatbot_backend/app/api/routes/documents.py b/multi_llm_chatbot_backend/app/api/routes/documents.py index 4756e27..51d7fae 100644 --- a/multi_llm_chatbot_backend/app/api/routes/documents.py +++ b/multi_llm_chatbot_backend/app/api/routes/documents.py @@ -9,6 +9,7 @@ from app.utils.file_export import prepare_export_response, generate_pdf_file_from_blocks from app.core.session_manager import get_session_manager from app.core.bootstrap import chat_orchestrator +from app.api.routes.chat_sessions import persist_message from app.core.auth import get_current_active_user from app.core.database import get_database from app.models.user import User @@ -217,6 +218,13 @@ async def upload_document( f"Document uploaded: '{doc_title}' ({file.filename}) - {rag_result['chunks_created']} sections processed, ~{rag_result['total_tokens']} tokens analyzed. You can now ask questions about this document by referencing it by name." ) + if chat_session_id: + await persist_message(chat_session_id, { + "id": str(ObjectId()), + "type": "document_upload", + "content": f"Document uploaded: {file.filename} ({rag_result['chunks_created']} sections processed)", + }) + # Return session info for frontend tracking return { "message": f"Document '{file.filename}' uploaded and processed successfully.", diff --git a/multi_llm_chatbot_backend/app/tests/unit/test_chat_stream_persistence.py b/multi_llm_chatbot_backend/app/tests/unit/test_chat_stream_persistence.py new file mode 100644 index 0000000..6340191 --- /dev/null +++ b/multi_llm_chatbot_backend/app/tests/unit/test_chat_stream_persistence.py @@ -0,0 +1,173 @@ +import unittest + +from bson import ObjectId + +# The conftest stubs ``app.api.routes.chat`` with a MagicMock so that other +# test modules can import ``app.api.routes`` without booting the LLM stack. +# Pop the stub so we can import the real helpers, then leave the real module +# in sys.modules so it doesn't interfere with other test files. +import sys + +sys.modules.pop("app.api.routes.chat", None) + +from app.api.routes.chat import ( # noqa: E402 + build_advisor_persist_message, + build_user_persist_message, +) + + +REQUIRED_FIELDS = {"id", "type", "persona_id", "advisorName", "content", + "used_documents", "document_chunks_used"} + + +# ------------------------------------------------------------------ +# build_advisor_persist_message +# ------------------------------------------------------------------ + + +class TestBuildAdvisorPersistMessage(unittest.TestCase): + + def test_includes_all_required_fields(self): + msg = build_advisor_persist_message( + persona_id="advisor_a", + persona_name="Advisor A", + content="Some advice.", + ) + self.assertTrue(REQUIRED_FIELDS.issubset(msg.keys()), + f"Missing fields: {REQUIRED_FIELDS - msg.keys()}") + + def test_type_is_advisor(self): + msg = build_advisor_persist_message( + persona_id="x", persona_name="X", content="c", + ) + self.assertEqual(msg["type"], "advisor") + + def test_maps_persona_name_to_advisorName(self): + msg = build_advisor_persist_message( + persona_id="methodologist", + persona_name="Dr. Method", + content="content", + ) + self.assertEqual(msg["advisorName"], "Dr. Method") + self.assertEqual(msg["persona_id"], "methodologist") + + def test_defaults_for_document_fields(self): + msg = build_advisor_persist_message( + persona_id="x", persona_name="X", content="c", + ) + self.assertFalse(msg["used_documents"]) + self.assertEqual(msg["document_chunks_used"], 0) + + def test_explicit_document_fields(self): + msg = build_advisor_persist_message( + persona_id="x", + persona_name="X", + content="c", + used_documents=True, + document_chunks_used=5, + ) + self.assertTrue(msg["used_documents"]) + self.assertEqual(msg["document_chunks_used"], 5) + + def test_id_is_valid_objectid_string(self): + msg = build_advisor_persist_message( + persona_id="x", persona_name="X", content="c", + ) + ObjectId(msg["id"]) # raises if invalid + + def test_each_call_generates_unique_id(self): + ids = { + build_advisor_persist_message( + persona_id="x", persona_name="X", content="c", + )["id"] + for _ in range(10) + } + self.assertEqual(len(ids), 10) + + def test_extra_kwargs_included(self): + msg = build_advisor_persist_message( + persona_id="x", + persona_name="X", + content="c", + isReply=True, + ) + self.assertTrue(msg["isReply"]) + + def test_orchestrator_message_shape(self): + msg = build_advisor_persist_message( + persona_id="orchestrator", + persona_name="Orchestrator", + content="Tool output here", + ) + self.assertEqual(msg["persona_id"], "orchestrator") + self.assertEqual(msg["advisorName"], "Orchestrator") + self.assertEqual(msg["content"], "Tool output here") + self.assertEqual(msg["type"], "advisor") + + def test_expansion_flag_included(self): + msg = build_advisor_persist_message( + persona_id="theorist", + persona_name="Dr. Theory", + content="Here is a deeper explanation...", + isExpansion=True, + ) + self.assertEqual(msg["type"], "advisor") + self.assertTrue(msg["isExpansion"]) + self.assertEqual(msg["persona_id"], "theorist") + + +# ------------------------------------------------------------------ +# build_user_persist_message +# ------------------------------------------------------------------ + +USER_REQUIRED_FIELDS = {"id", "type", "content"} + + +class TestBuildUserPersistMessage(unittest.TestCase): + + def test_includes_required_fields(self): + msg = build_user_persist_message(content="hello") + self.assertTrue(USER_REQUIRED_FIELDS.issubset(msg.keys()), + f"Missing fields: {USER_REQUIRED_FIELDS - msg.keys()}") + + def test_type_is_user(self): + msg = build_user_persist_message(content="hello") + self.assertEqual(msg["type"], "user") + + def test_content_preserved(self): + msg = build_user_persist_message(content="Tell me more") + self.assertEqual(msg["content"], "Tell me more") + + def test_id_is_valid_objectid_string(self): + msg = build_user_persist_message(content="hello") + ObjectId(msg["id"]) + + def test_each_call_generates_unique_id(self): + ids = { + build_user_persist_message(content="hello")["id"] + for _ in range(10) + } + self.assertEqual(len(ids), 10) + + def test_reply_to_metadata_included(self): + msg = build_user_persist_message( + content="I disagree", + replyTo={ + "advisorId": "methodologist", + "messageId": "msg_123", + }, + ) + self.assertEqual(msg["replyTo"]["advisorId"], "methodologist") + self.assertEqual(msg["replyTo"]["messageId"], "msg_123") + + def test_plain_message_has_no_replyTo(self): + msg = build_user_persist_message(content="hello") + self.assertNotIn("replyTo", msg) + + def test_expand_request_shape(self): + msg = build_user_persist_message( + content="Please expand on your previous response...", + isExpandRequest=True, + ) + self.assertEqual(msg["type"], "user") + self.assertTrue(msg["isExpandRequest"])