From 0ed9dabd58c25059f9ddb77f3fada3c6dd4a4d92 Mon Sep 17 00:00:00 2001 From: tranphuc8a Date: Thu, 6 Nov 2025 01:22:29 +0700 Subject: [PATCH] update 251106 --- backend/fastapi/requirements.txt | 2 + .../src/adapter/factory/service_factory.py | 5 + .../controllers/conversation_controller.py | 45 +++- .../input/controllers/gemini_controller.py | 29 ++- .../input/controllers/health_controller.py | 8 +- .../input/controllers/messages_controller.py | 85 ++++++++ .../input/controllers/response_utils.py | 23 +++ .../output/gemini/helper/gemini_client.py | 193 +++++++++++++----- .../output/gemini/service/gemini_service.py | 16 +- .../output/mysql/entities/message_entity.py | 14 +- .../repositories/conversation_repository.py | 19 ++ .../mysql/repositories/message_repository.py | 8 +- .../src/application/exceptions/exceptions.py | 47 +++++ .../ports/input/conversation_input_port.py | 5 + .../ports/output/gemini_output_port.py | 4 +- .../ports/output/message_output_port.py | 3 +- .../usecases/conversation_usecase.py | 56 ++++- .../application/usecases/gemini_usecase.py | 36 ++-- backend/fastapi/src/domain/enums/enums.py | 13 +- .../src/domain/models/conversation_domain.py | 14 +- .../src/domain/models/message_domain.py | 12 +- .../src/domain/vo/conversation_response.py | 22 +- .../fastapi/src/domain/vo/list_response.py | 5 +- .../fastapi/src/domain/vo/message_request.py | 9 +- .../fastapi/src/domain/vo/message_response.py | 5 +- .../src/domain/vo/message_update_request.py | 20 ++ backend/fastapi/src/main.py | 34 ++- 27 files changed, 584 insertions(+), 148 deletions(-) create mode 100644 backend/fastapi/src/adapter/input/controllers/messages_controller.py create mode 100644 backend/fastapi/src/adapter/input/controllers/response_utils.py create mode 100644 backend/fastapi/src/application/exceptions/exceptions.py create mode 100644 backend/fastapi/src/domain/vo/message_update_request.py diff --git a/backend/fastapi/requirements.txt b/backend/fastapi/requirements.txt index 97e8fb5..d7a32b3 100644 --- a/backend/fastapi/requirements.txt +++ b/backend/fastapi/requirements.txt @@ -28,6 +28,8 @@ typing-extensions>=4.0 # HTTP client and retry helpers httpx>=0.24 +# Enable HTTP/2 support for httpx when available +h2>=4.1 tenacity>=8.2 # Testing diff --git a/backend/fastapi/src/adapter/factory/service_factory.py b/backend/fastapi/src/adapter/factory/service_factory.py index db4f714..3018e04 100644 --- a/backend/fastapi/src/adapter/factory/service_factory.py +++ b/backend/fastapi/src/adapter/factory/service_factory.py @@ -51,6 +51,11 @@ def get_conversation_output_port() -> ConversationOutputPort: conv_repo, _, _, _ = _make_repos_and_ports() return conv_repo + @staticmethod + def get_message_output_port() -> MessageOutputPort: + _, msg_repo, _, _ = _make_repos_and_ports() + return msg_repo + @staticmethod def get_health_input_port() -> HealthInputPort: _, _, _, health_input = _make_repos_and_ports() diff --git a/backend/fastapi/src/adapter/input/controllers/conversation_controller.py b/backend/fastapi/src/adapter/input/controllers/conversation_controller.py index 828cf28..f8883e4 100644 --- a/backend/fastapi/src/adapter/input/controllers/conversation_controller.py +++ b/backend/fastapi/src/adapter/input/controllers/conversation_controller.py @@ -1,10 +1,12 @@ from fastapi import APIRouter, Query, Depends -from typing import Optional +from typing import Optional, List +from src.domain.vo.message_response import MessageResponse from src.application.ports.input.conversation_input_port import ConversationInputPort from src.domain.vo.conversation_response import ConversationResponse from src.domain.vo.conversation_update_request import ConversationUpdateRequest from src.domain.vo.list_response import ListResponse from src.adapter.factory.service_factory import ServiceFactory +from src.adapter.input.controllers.response_utils import success_response # Router organized as a grouped resource; prefix is applied when included in the app @@ -19,7 +21,8 @@ async def list_conversations( conversation_service: ConversationInputPort = Depends(ServiceFactory.get_conversation_input_port) ): """Cursor pagination: `after` is the id of the anchor element; return `limit` items after that anchor, skipping `offset` items, ordered by `created_at`.""" - return await conversation_service.get_conversation_list(limit=limit, after=after, order=order) + data = await conversation_service.get_conversation_list(limit=limit, after=after, order=order) + return success_response(data=data, message="ok", status_code=200) @router.get("/{conversation_id}", response_model=ConversationResponse) @@ -27,14 +30,16 @@ async def get_conversation( conversation_id: str, conversation_service: ConversationInputPort = Depends(ServiceFactory.get_conversation_input_port) ): - return await conversation_service.get_conversation_detail(conversation_id) + data = await conversation_service.get_conversation_detail(conversation_id) + return success_response(data=data, message="ok", status_code=200) @router.post("/", response_model=ConversationResponse) async def create_conversation( conversation_service: ConversationInputPort = Depends(ServiceFactory.get_conversation_input_port) ): - return await conversation_service.create_conversation() + data = await conversation_service.create_conversation() + return success_response(data=data, message="created", status_code=201) @router.put("/", response_model=ConversationResponse) @@ -42,7 +47,8 @@ async def update_conversation( request: ConversationUpdateRequest, conversation_service: ConversationInputPort = Depends(ServiceFactory.get_conversation_input_port) ): - return await conversation_service.update_conversation(request) + data = await conversation_service.update_conversation(request) + return success_response(data=data, message="updated", status_code=200) @router.delete("/{conversation_id}", response_model=ConversationResponse) @@ -50,7 +56,8 @@ async def delete_conversation( conversation_id: str, conversation_service: ConversationInputPort = Depends(ServiceFactory.get_conversation_input_port) ): - return await conversation_service.delete_conversation(conversation_id) + result = await conversation_service.delete_conversation(conversation_id) + return success_response(data={"deleted": result}, message="deleted", status_code=200) @router.post("/{conversation_id}/messages") @@ -59,4 +66,28 @@ async def post_message( content: str, conversation_service: ConversationInputPort = Depends(ServiceFactory.get_conversation_input_port) ): - return await conversation_service.post_message(conversation_id, content) + data = await conversation_service.post_message(conversation_id, content) + return success_response(data=data, message="created", status_code=201) + + +@router.get("/{conversation_id}/messages", response_model=ListResponse[MessageResponse]) +async def get_conversation_messages( + conversation_id: str, + after: Optional[str] = Query(None), + limit: int = Query(10, gt=0), + order: str = Query("desc", regex="^(asc|desc)$"), + conversation_service: ConversationInputPort = Depends(ServiceFactory.get_conversation_input_port) +): + """Get messages for a conversation (cursor pagination).""" + data = await conversation_service.get_conversation_messages(conversation_id=conversation_id, after=after, limit=limit, order=order) + return success_response(data=data, message="ok", status_code=200) + + +@router.get("/{conversation_id}/messages/recent", response_model=List[MessageResponse]) +async def get_recent_messages( + conversation_id: str, + k: int = Query(5, gt=0, description="Number of latest messages to return (default 5)"), + conversation_service: ConversationInputPort = Depends(ServiceFactory.get_conversation_input_port) +): + data = await conversation_service.get_recent_messages(conversation_id=conversation_id, k=k) + return success_response(data=data, message="ok", status_code=200) diff --git a/backend/fastapi/src/adapter/input/controllers/gemini_controller.py b/backend/fastapi/src/adapter/input/controllers/gemini_controller.py index 94bdd59..56c68b9 100644 --- a/backend/fastapi/src/adapter/input/controllers/gemini_controller.py +++ b/backend/fastapi/src/adapter/input/controllers/gemini_controller.py @@ -4,6 +4,8 @@ from src.application.ports.input.gemini_input_port import GeminiInputPort from src.domain.vo.message_request import MessageRequest from src.adapter.factory.service_factory import ServiceFactory +from src.adapter.input.controllers.response_utils import success_response +import json router = APIRouter(prefix="/gemini", tags=["gemini"]) @@ -15,7 +17,8 @@ async def query( gemini_service: GeminiInputPort = Depends(ServiceFactory.get_gemini_input_port), ): """Synchronous (non-streaming) Gemini query returning the full assistant text.""" - return await gemini_service.query(message_request) + resp = await gemini_service.query(message_request) + return success_response(data=resp, message="ok", status_code=200) @router.post("/stream") @@ -30,8 +33,24 @@ async def query_stream( """ async def generator() -> AsyncIterator[bytes]: + # Yield Server-Sent Events (SSE) style 'data:' frames so clients such as + # EventSource or curl can process parts immediately. Each event is + # terminated by a blank line. This also tends to reduce buffering in + # intermediate proxies. async for chunk in gemini_service.query_stream(message_request): - # encode each string chunk as utf-8 bytes - yield chunk.encode("utf-8") - - return StreamingResponse(generator(), media_type="text/plain; charset=utf-8") + if chunk is None: + continue + # Ensure chunk is a str and strip accidental newlines + text = str(chunk) + # SSE data frame + data = f"data: {json.dumps(text, ensure_ascii=False)}\n\n" + yield data.encode("utf-8") + + headers = { + # Prevent proxies from buffering the response + "Cache-Control": "no-cache, no-transform", + # For nginx / proxy buffering bypass + "X-Accel-Buffering": "no", + } + + return StreamingResponse(generator(), media_type="text/event-stream", headers=headers) diff --git a/backend/fastapi/src/adapter/input/controllers/health_controller.py b/backend/fastapi/src/adapter/input/controllers/health_controller.py index 6155331..94d82f0 100644 --- a/backend/fastapi/src/adapter/input/controllers/health_controller.py +++ b/backend/fastapi/src/adapter/input/controllers/health_controller.py @@ -1,14 +1,18 @@ from fastapi import APIRouter, Depends from src.application.ports.input.health_input_port import HealthInputPort from src.adapter.factory.service_factory import ServiceFactory +from src.adapter.input.controllers.response_utils import success_response router = APIRouter(prefix="/health", tags=["health"]) + @router.get("/", summary="Liveness probe") async def health(health_service: HealthInputPort = Depends(ServiceFactory.get_health_input_port)): - return await health_service.check_health() + data = await health_service.check_health() + return success_response(data=data, message="ok", status_code=200) @router.get("/ready", summary="Readiness probe") async def ready(health_service: HealthInputPort = Depends(ServiceFactory.get_health_input_port)): - return await health_service.check_readiness() + data = await health_service.check_readiness() + return success_response(data=data, message="ready", status_code=200) diff --git a/backend/fastapi/src/adapter/input/controllers/messages_controller.py b/backend/fastapi/src/adapter/input/controllers/messages_controller.py new file mode 100644 index 0000000..42fa7c0 --- /dev/null +++ b/backend/fastapi/src/adapter/input/controllers/messages_controller.py @@ -0,0 +1,85 @@ +from fastapi import APIRouter, Query, Depends +from typing import Optional, List +from src.adapter.factory.service_factory import ServiceFactory +from src.application.ports.output.message_output_port import MessageOutputPort +from src.application.ports.input.conversation_input_port import ConversationInputPort +from src.domain.vo.message_response import MessageResponse +from src.domain.vo.list_response import ListResponse +from src.domain.vo.message_update_request import MessageUpdateRequest +from src.domain.models.message_domain import MessageDomain +from src.adapter.input.controllers.response_utils import success_response + + +router = APIRouter(prefix="/messages", tags=["messages"]) + + +@router.get("/", response_model=ListResponse[MessageResponse]) +async def list_messages( + conversation_id: Optional[str] = Query(None, description="Filter by conversation id"), + after: Optional[str] = Query(None), + limit: int = Query(20, gt=0), + order: str = Query("desc", regex="^(asc|desc)$"), + message_repo: MessageOutputPort = Depends(ServiceFactory.get_message_output_port), +): + """List messages. If `conversation_id` is provided, returns messages for that conversation (cursor pagination).""" + if conversation_id: + messages, has_more = await message_repo.get_list_by_conversation(conversation_id, after, limit, order) + data: List[MessageResponse] = [MessageResponse.from_domain(m) for m in messages] + first_id = data[0].id if data else None + last_id = data[-1].id if data else None + payload = ListResponse[MessageResponse](data=data, first_id=first_id, last_id=last_id, has_more=has_more) + return success_response(data=payload, message="ok", status_code=200) + # If no conversation_id provided, return empty list + payload = ListResponse[MessageResponse](data=[], first_id=None, last_id=None, has_more=False) + return success_response(data=payload, message="ok", status_code=200) + + +@router.get("/{message_id}", response_model=MessageResponse) +async def get_message( + message_id: str, + message_repo: MessageOutputPort = Depends(ServiceFactory.get_message_output_port), +): + msg = await message_repo.get_by_id(message_id) + return success_response(data=MessageResponse.from_domain(msg), message="ok", status_code=200) + + +@router.put("/", response_model=MessageResponse) +async def update_message( + request: MessageUpdateRequest = Depends(MessageUpdateRequest.as_body), + message_repo: MessageOutputPort = Depends(ServiceFactory.get_message_output_port), +): + # retrieve existing message + existing = await message_repo.get_by_id(request.id) + # apply updates + updated = MessageDomain( + id=existing.id, + conversation_id=existing.conversation_id, + role=request.role if request.role is not None else existing.role, + content=request.content, + created_at=existing.created_at, + ) + saved = await message_repo.update(updated) + return success_response(data=MessageResponse.from_domain(saved), message="ok", status_code=200) + + +@router.delete("/{message_id}") +async def delete_message( + message_id: str, + message_repo: MessageOutputPort = Depends(ServiceFactory.get_message_output_port), +): + existing = await message_repo.get_by_id(message_id) + result = await message_repo.delete(existing) + return success_response(data={"deleted": result}, message="deleted", status_code=200) + + +# Helper endpoint: messages by conversation (convenience) +@router.get("/by-conversation/{conversation_id}", response_model=ListResponse[MessageResponse]) +async def get_messages_by_conversation( + conversation_id: str, + after: Optional[str] = Query(None), + limit: int = Query(20, gt=0), + order: str = Query("desc", regex="^(asc|desc)$"), + conversation_service: ConversationInputPort = Depends(ServiceFactory.get_conversation_input_port), +): + data = await conversation_service.get_conversation_messages(conversation_id=conversation_id, after=after, limit=limit, order=order) + return success_response(data=data, message="ok", status_code=200) diff --git a/backend/fastapi/src/adapter/input/controllers/response_utils.py b/backend/fastapi/src/adapter/input/controllers/response_utils.py new file mode 100644 index 0000000..02d1ffb --- /dev/null +++ b/backend/fastapi/src/adapter/input/controllers/response_utils.py @@ -0,0 +1,23 @@ +from fastapi.responses import JSONResponse +from typing import Any + +from pydantic import BaseModel + +def to_serializable(obj: Any): + if isinstance(obj, BaseModel): + return obj.model_dump() + if isinstance(obj, list): + return [to_serializable(o) for o in obj] + if isinstance(obj, dict): + return {k: to_serializable(v) for k, v in obj.items()} + return obj + + +def success_response(data: Any = None, message: str = "OK", status_code: int = 200) -> JSONResponse: + payload = {"status_code": status_code, "message": message, "data": to_serializable(data)} + return JSONResponse(content=payload, status_code=status_code) + + +def error_response(message: str = "Error", status_code: int = 500, data: Any = None) -> JSONResponse: + payload = {"status_code": status_code, "message": message, "data": to_serializable(data)} + return JSONResponse(content=payload, status_code=status_code) diff --git a/backend/fastapi/src/adapter/output/gemini/helper/gemini_client.py b/backend/fastapi/src/adapter/output/gemini/helper/gemini_client.py index 7099459..4345989 100644 --- a/backend/fastapi/src/adapter/output/gemini/helper/gemini_client.py +++ b/backend/fastapi/src/adapter/output/gemini/helper/gemini_client.py @@ -1,9 +1,13 @@ from typing import Any, Dict, Optional, AsyncIterator -from httpx import AsyncClient, HTTPStatusError, RequestError +from httpx import AsyncClient, HTTPStatusError, RequestError, Timeout, Limits +import logging +from urllib.parse import urlparse, urlunparse from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type from src.application.config.config import settings +from src.domain.enums.enums import ERole import asyncio import json +import re class GeminiClientError(RuntimeError): @@ -24,15 +28,35 @@ def __init__( self.url = url or settings.GEMINI_URL self.api_key = api_key or settings.GEMINI_API_KEY or "" self.timeout = timeout or settings.GEMINI_TIMEOUT_SECONDS - self.client: AsyncClient = AsyncClient(timeout=self.timeout) + # Enable HTTP/2 when available and set modest connection limits for better throughput. + # Keep a single shared client to reuse connections. + self.client: AsyncClient = AsyncClient( + timeout=self.timeout, + http2=True, + limits=Limits(max_connections=100, max_keepalive_connections=20, keepalive_expiry=60.0), + ) self.headers: Dict[str, str] = { "Content-Type": "application/json", - "x-goog-api-key": self.api_key + # Prefer JSON streaming as returned by Google for streamGenerateContent + "Accept": "application/json", + "x-goog-api-key": self.api_key, } async def stop(self) -> None: if self.client is not None: await self.client.aclose() + + async def health_check(self) -> bool: + """Perform a simple health check by sending a request to the base URL.""" + if not self.url: + raise GeminiClientError("GEMINI_URL is not configured") + try: + resp = await self.client.get(self.url, headers=self.headers) + resp.raise_for_status() + return True + except Exception as exc: + logging.error(f"GeminiClient health check failed: {exc}") + return False @retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=1, max=10), retry=retry_if_exception_type(RequestError)) @@ -45,7 +69,11 @@ def _get_payload(self, prompt: Any, model: Optional[str] = None, extra: Optional[Dict[str, Any]] = None) -> Dict[str, Any]: - content: Any = prompt if isinstance(prompt, list) else [{"role": "user", "parts": [{"text": prompt}]}] + content: Any = ( + prompt + if isinstance(prompt, list) + else [{"role": ERole.USER.value, "parts": [{"text": prompt}]}] + ) payload: Dict[str, Any] = {"contents": content} if model: payload["model"] = model @@ -53,6 +81,44 @@ def _get_payload(self, payload.update(extra) return payload + def _apply_model_to_url(self, base_url: str, model: Optional[str]) -> str: + """Replace the model segment in `.../models/:` with the provided model. + + If model is falsy or the URL doesn't match the expected format, return the original URL. + """ + if not model: + return base_url + try: + parsed = urlparse(base_url) + path = parsed.path or "" + marker = "/models/" + idx = path.find(marker) + if idx == -1: + return base_url + start = idx + len(marker) + colon = path.find(":", start) + if colon == -1: + return base_url + # substitute model between start and colon + new_path = path[:start] + str(model) + path[colon:] + return urlunparse((parsed.scheme, parsed.netloc, new_path, parsed.params, parsed.query, parsed.fragment)) + except Exception: + return base_url + + def _to_stream_url(self, base_url: str) -> str: + """Convert a non-stream generate URL to its streaming variant when applicable. + + Example: + .../models/gemini-2.5-flash:generateContent?key=... -> + .../models/gemini-2.5-flash:streamGenerateContent?key=... + """ + if ":streamGenerateContent" in base_url: + return base_url + if ":generateContent" in base_url: + return base_url.replace(":generateContent", ":streamGenerateContent") + # If method not present, assume caller already set a streaming-capable URL + return base_url + async def generate(self, prompt: Any, model: Optional[str] = None, @@ -60,9 +126,10 @@ async def generate(self, if not self.url: raise GeminiClientError("GEMINI_URL is not configured") payload: Dict[str, Any] = self._get_payload(prompt, model, extra) + url_to_use = self._apply_model_to_url(self.url, model) try: - return await self._post(self.url, payload, self.headers) + return await self._post(url_to_use, payload, self.headers) except RequestError as exc: raise GeminiClientError(f"Request error while calling Gemini API: {exc}") from exc except HTTPStatusError as exc: @@ -87,68 +154,82 @@ async def stream_generate(self, payload: Dict[str, Any] = self._get_payload(prompt, model, extra) headers = self.headers.copy() - headers.setdefault("Accept", "text/event-stream") try: - async with self.client.stream("POST", self.url, json=payload, headers=headers, timeout=self.timeout) as resp: + # For long-lived streams, disable read timeout to avoid premature disconnects. + stream_timeout = Timeout(connect=self.timeout, read=None, write=self.timeout, pool=self.timeout) + # Apply model override into URL first, then convert to streaming method + stream_url = self._to_stream_url(self._apply_model_to_url(self.url, model)) + async with self.client.stream("POST", stream_url, json=payload, headers=headers, timeout=stream_timeout) as resp: resp.raise_for_status() - event_lines: list[str] = [] + # Google streamGenerateContent may return NDJSON or SSE. Some providers split a single + # JSON object/array across multiple SSE events. We therefore accumulate SSE payloads + # and greedily extract any completed "text": "..." fields without echoing raw JSON. + sse_mode = False + json_buf: str = "" + processed_idx: int = 0 + text_pattern = re.compile(r'"text"\s*:\s*"((?:\\.|[^"\\])*)"') + async for line in resp.aiter_lines(): - if not line: - # empty line = end of an SSE event; flush if we have buffered data - if not event_lines: - continue - data = "\n".join(event_lines).strip() - event_lines = [] - else: - line = line.strip() - if line.startswith("data:"): - event_lines.append(line[len("data:"):].lstrip()) - continue - # not an SSE "data:" line — treat the line itself as a unit - data = line + if line is None: + continue + raw = line.rstrip("\n") - if not data: + # Detect SSE markers + if raw.startswith(":"): + # SSE comment/keepalive + continue + if raw.startswith("data:"): + sse_mode = True + sse_payload = raw[len("data:"):].lstrip() + if sse_payload == "[DONE]": + return + # Accumulate payload fragments; many providers split JSON across events + json_buf += (sse_payload + "\n") + + # Greedily extract any complete text fields from the rolling buffer + for m in text_pattern.finditer(json_buf, processed_idx): + raw_text = m.group(1) + try: + # Use json.loads on a quoted string to unescape sequences + decoded = json.loads(f'"{raw_text}"') + except Exception: + decoded = raw_text + if decoded: + yield decoded + processed_idx = m.end() + # Optionally trim buffer to keep memory bounded + if processed_idx > 0 and processed_idx > len(json_buf) // 2: + json_buf = json_buf[processed_idx:] + processed_idx = 0 continue - # termination sentinel used by many streaming APIs - if data == "[DONE]": - return + # Blank line between SSE events — ignore; we accumulate across events + if raw.strip() == "" and sse_mode: + continue - # Try to parse JSON payloads, then extract text parts if present - try: - obj = json.loads(data) - except Exception: - # not JSON — yield raw chunk - yield data + # NDJSON or multi-line JSON fallback: + # Accumulate raw lines into the rolling buffer and extract text fields with regex + data = raw.strip() + if not data: continue + if data == "[DONE]": + return - # try common Gemini-like shapes: top-level "candidates" -> candidate.content.parts[].text - candidates = obj.get("candidates") or [] - emitted = False - for cand in candidates: - content = cand.get("content", {}) if isinstance(cand, dict) else {} - parts = content.get("parts", []) if isinstance(content, dict) else [] - for part in parts: - if not isinstance(part, dict): - continue - text = part.get("text") - if text: - emitted = True - yield text - - # fallback: sometimes text may be at top-level fields like "text" or "message" - if not emitted: - fallback_text = None - if isinstance(obj, dict): - for k in ("text", "message", "content"): - v = obj.get(k) - if isinstance(v, str) and v: - fallback_text = v - break - if fallback_text: - yield fallback_text + json_buf += (data + "\n") + for m in text_pattern.finditer(json_buf, processed_idx): + raw_text = m.group(1) + try: + decoded = json.loads(f'"{raw_text}"') + except Exception: + decoded = raw_text + if decoded: + yield decoded + processed_idx = m.end() + if processed_idx > 0 and processed_idx > len(json_buf) // 2: + json_buf = json_buf[processed_idx:] + processed_idx = 0 except RequestError as exc: raise GeminiClientError(f"Request error while streaming from Gemini API: {exc}") from exc diff --git a/backend/fastapi/src/adapter/output/gemini/service/gemini_service.py b/backend/fastapi/src/adapter/output/gemini/service/gemini_service.py index d20ea1c..3754e41 100644 --- a/backend/fastapi/src/adapter/output/gemini/service/gemini_service.py +++ b/backend/fastapi/src/adapter/output/gemini/service/gemini_service.py @@ -12,13 +12,14 @@ class GeminiService(GeminiOutputPort): def __init__(self, gemini_client: GeminiClient): self.gemini_client = gemini_client - async def generate(self, model: str, message: MessageDomain, history: List[MessageDomain]) -> str: + async def generate(self, model: str, history: List[MessageDomain]) -> str: """Call the Gemini client and return the assistant text as a single string.""" - # Prepare prompt: convert history + message into contents list expected by the client + # Prepare prompt: convert history into contents list expected by the client contents = [] + history.sort(key=lambda x: x.created_at) for m in history: - contents.append({"role": m.role, "parts": [{"text": m.content}]}) - contents.append({"role": message.role, "parts": [{"text": message.content}]}) + role = m.role.value if hasattr(m.role, "value") else str(m.role) + contents.append({"role": role, "parts": [{"text": m.content}]}) try: raw = await self.gemini_client.generate(contents, model=model) @@ -47,11 +48,12 @@ async def generate(self, model: str, message: MessageDomain, history: List[Messa return "".join(texts) return str(raw) - async def stream_generate(self, model: str, message: MessageDomain, history: List[MessageDomain]) -> AsyncIterator[str]: + async def stream_generate(self, model: str, history: List[MessageDomain]) -> AsyncIterator[str]: contents = [] + history.sort(key=lambda x: x.created_at) for m in history: - contents.append({"role": m.role, "parts": [{"text": m.content}]}) - contents.append({"role": message.role, "parts": [{"text": message.content}]}) + role = m.role.value if hasattr(m.role, "value") else str(m.role) + contents.append({"role": role, "parts": [{"text": m.content}]}) async for part in self.gemini_client.stream_generate(contents, model=model): yield part diff --git a/backend/fastapi/src/adapter/output/mysql/entities/message_entity.py b/backend/fastapi/src/adapter/output/mysql/entities/message_entity.py index 848c755..655f503 100644 --- a/backend/fastapi/src/adapter/output/mysql/entities/message_entity.py +++ b/backend/fastapi/src/adapter/output/mysql/entities/message_entity.py @@ -13,7 +13,13 @@ class MessageEntity(Base, AbstractEntity[MessageDomain]): __tablename__ = "messages" id = Column(String(64), primary_key=True, default=lambda: str(uuid.uuid4())) conversation_id = Column(String(64), ForeignKey("conversations.id"), nullable=False) - role = Column(Enum(ERole), nullable=False) + # Store the enum by its value (e.g. 'user', 'model') and allow SQLAlchemy + # to map DB strings back to the Python Enum using the value, not the name. + # values_callable tells SQLAlchemy which values to expect in the DB. + role = Column( + Enum(ERole, values_callable=lambda enum: [e.value for e in enum], name="erole"), + nullable=False, + ) content = Column(Text, nullable=False) created_at = Column(Integer, nullable=False) @@ -31,8 +37,8 @@ def from_domain(cls, domain_obj: MessageDomain) -> "MessageEntity": return ent def to_domain(self) -> MessageDomain: - # normalize role to a plain string and created_at to an int timestamp - role_val = self.role.value if hasattr(self.role, "value") else str(self.role) + # normalize role to enum and created_at to an int timestamp + role_enum = self.role if isinstance(self.role, ERole) else ERole(str(self.role)) created = self.created_at if isinstance(created, datetime): created = int(created.timestamp()) @@ -40,7 +46,7 @@ def to_domain(self) -> MessageDomain: return MessageDomain( id=cast(str, self.id), conversation_id=cast(str, self.conversation_id), - role=cast(str, role_val), + role=role_enum, content=cast(str, self.content), created_at=cast(int, created), ) diff --git a/backend/fastapi/src/adapter/output/mysql/repositories/conversation_repository.py b/backend/fastapi/src/adapter/output/mysql/repositories/conversation_repository.py index f154ef2..9990f22 100644 --- a/backend/fastapi/src/adapter/output/mysql/repositories/conversation_repository.py +++ b/backend/fastapi/src/adapter/output/mysql/repositories/conversation_repository.py @@ -3,6 +3,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from src.application.ports.output.health_check_output_port import HealthCheckOutputPort from src.adapter.output.mysql.entities import ConversationEntity +from src.adapter.output.mysql.entities.message_entity import MessageEntity from src.domain.models.conversation_domain import ConversationDomain from src.application.ports.output.conversation_output_port import ConversationOutputPort @@ -58,6 +59,24 @@ async def get_all(self, limit: int, after: Optional[str], order: str) -> Tuple[L items = rows[:limit] # map to domain domains = [r.to_domain() for r in items] + + # Fetch message counts for all returned conversations in one query to + # avoid N+1 queries. If there are no conversations, skip the query. + if domains: + convo_ids = [d.id for d in domains] + count_stmt = ( + select(MessageEntity.conversation_id, func.count()) + .where(MessageEntity.conversation_id.in_(convo_ids)) + .group_by(MessageEntity.conversation_id) + ) + res = await self.db.execute(count_stmt) + rows = res.all() + # rows are tuples (conversation_id, count) + counts = {r[0]: int(r[1]) for r in rows} + # attach counts to domains (use 0 as default) + for d in domains: + d.messages_count = counts.get(d.id, 0) + return domains, has_more async def save(self, conversation: ConversationDomain) -> ConversationDomain: diff --git a/backend/fastapi/src/adapter/output/mysql/repositories/message_repository.py b/backend/fastapi/src/adapter/output/mysql/repositories/message_repository.py index 7fe1cb2..2ffe436 100644 --- a/backend/fastapi/src/adapter/output/mysql/repositories/message_repository.py +++ b/backend/fastapi/src/adapter/output/mysql/repositories/message_repository.py @@ -74,7 +74,7 @@ async def save(self, message: MessageDomain) -> MessageDomain: ent = None try: # message.id may be numeric-like or generated; try numeric primary-key lookup first - ent = await self.db.get(MessageEntity, int(message.id)) + ent = await self.db.get(MessageEntity, message.id) except Exception: ent = None @@ -96,7 +96,7 @@ async def delete(self, message: MessageDomain) -> bool: # find by PK ent = None try: - ent = await self.db.get(MessageEntity, int(message.id)) + ent = await self.db.get(MessageEntity, message.id) except Exception: # fallback: query by id equality stmt = select(MessageEntity).where(MessageEntity.id == message.id) @@ -126,7 +126,7 @@ async def get_latest_by_conversation(self, conversation_id: str, count: int) -> rows = res.scalars().all() return [r.to_domain() for r in rows] - async def get_by_id(self, message_id: int) -> MessageDomain: + async def get_by_id(self, message_id: str) -> MessageDomain: ent = await self.db.get(MessageEntity, message_id) if ent is None: raise ValueError(f"Message not found: {message_id}") @@ -136,7 +136,7 @@ async def update(self, message: MessageDomain) -> MessageDomain: # update by PK ent = None try: - ent = await self.db.get(MessageEntity, int(message.id)) + ent = await self.db.get(MessageEntity, message.id) except Exception: # fallback to query stmt = select(MessageEntity).where(MessageEntity.id == message.id) diff --git a/backend/fastapi/src/application/exceptions/exceptions.py b/backend/fastapi/src/application/exceptions/exceptions.py new file mode 100644 index 0000000..f8385b9 --- /dev/null +++ b/backend/fastapi/src/application/exceptions/exceptions.py @@ -0,0 +1,47 @@ +from typing import Any + + +class AppException(Exception): + """Base application exception with a status code and optional payload.""" + + def __init__(self, message: str = "Application error", status_code: int = 500, code: str | None = None, payload: Any = None): + super().__init__(message) + self.message = message + self.status_code = status_code + self.code = code + self.payload = payload + + +class NotFoundError(AppException): + def __init__(self, message: str = "Resource not found", payload: Any = None): + super().__init__(message=message, status_code=404, code="not_found", payload=payload) + + +class BadRequestError(AppException): + def __init__(self, message: str = "Bad request", payload: Any = None): + super().__init__(message=message, status_code=400, code="bad_request", payload=payload) + + +class UnauthorizedError(AppException): + def __init__(self, message: str = "Unauthorized", payload: Any = None): + super().__init__(message=message, status_code=401, code="unauthorized", payload=payload) + + +class ConflictError(AppException): + def __init__(self, message: str = "Conflict", payload: Any = None): + super().__init__(message=message, status_code=409, code="conflict", payload=payload) + + +class InternalServerError(AppException): + def __init__(self, message: str = "Internal server error", payload: Any = None): + super().__init__(message=message, status_code=500, code="internal_error", payload=payload) + + +class BadGatewayError(AppException): + def __init__(self, message: str = "Bad gateway", payload: Any = None): + super().__init__(message=message, status_code=502, code="bad_gateway", payload=payload) + + +class GatewayTimeoutError(AppException): + def __init__(self, message: str = "Gateway timeout", payload: Any = None): + super().__init__(message=message, status_code=504, code="gateway_timeout", payload=payload) diff --git a/backend/fastapi/src/application/ports/input/conversation_input_port.py b/backend/fastapi/src/application/ports/input/conversation_input_port.py index e5b32a3..094de6c 100644 --- a/backend/fastapi/src/application/ports/input/conversation_input_port.py +++ b/backend/fastapi/src/application/ports/input/conversation_input_port.py @@ -1,5 +1,6 @@ from abc import ABC, abstractmethod from typing import Optional +from typing import List from src.domain.vo.conversation_update_request import ConversationUpdateRequest from src.domain.vo.conversation_response import ConversationResponse from src.domain.vo.message_response import MessageResponse @@ -18,6 +19,10 @@ async def get_conversation_list(self, after: Optional[str] = None, limit: int = async def get_conversation_messages(self, conversation_id: str, after: Optional[str] = None, limit: int = 10, order: Optional[str] = "desc") -> ListResponse[MessageResponse]: pass + @abstractmethod + async def get_recent_messages(self, conversation_id: str, k: int = 5) -> List[MessageResponse]: + pass + @abstractmethod async def create_conversation(self) -> ConversationResponse: pass diff --git a/backend/fastapi/src/application/ports/output/gemini_output_port.py b/backend/fastapi/src/application/ports/output/gemini_output_port.py index 8493b03..bc5123c 100644 --- a/backend/fastapi/src/application/ports/output/gemini_output_port.py +++ b/backend/fastapi/src/application/ports/output/gemini_output_port.py @@ -11,7 +11,7 @@ class GeminiOutputPort(ABC): @abstractmethod async def generate( - self, model: str, message: MessageDomain, history: List[MessageDomain] + self, model: str, history: List[MessageDomain] ) -> str: """ Sinh một phản hồi hoàn chỉnh từ Gemini. @@ -23,7 +23,7 @@ async def generate( @abstractmethod async def stream_generate( - self, model: str, message: MessageDomain, history: List[MessageDomain] + self, model: str, history: List[MessageDomain] ) -> AsyncIterator[str]: """ Sinh phản hồi theo luồng (streaming). Trả về iterator async yield các phần text dần dần. diff --git a/backend/fastapi/src/application/ports/output/message_output_port.py b/backend/fastapi/src/application/ports/output/message_output_port.py index 0ee11b6..e89f047 100644 --- a/backend/fastapi/src/application/ports/output/message_output_port.py +++ b/backend/fastapi/src/application/ports/output/message_output_port.py @@ -34,7 +34,8 @@ async def get_latest_by_conversation(self, conversation_id: str, count: int) -> pass @abstractmethod - async def get_by_id(self, message_id: int) -> MessageDomain: + async def get_by_id(self, message_id: str) -> MessageDomain: + """Fetch a message by id. Implementations may accept numeric or string PKs; pass the raw id as string.""" pass @abstractmethod diff --git a/backend/fastapi/src/application/usecases/conversation_usecase.py b/backend/fastapi/src/application/usecases/conversation_usecase.py index 193d981..fe9fd86 100644 --- a/backend/fastapi/src/application/usecases/conversation_usecase.py +++ b/backend/fastapi/src/application/usecases/conversation_usecase.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Optional, List from src.domain.enums.enums import ERole from src.domain.models.message_domain import MessageDomain from src.application.ports.input.conversation_input_port import ConversationInputPort @@ -10,7 +10,7 @@ from src.domain.vo.conversation_response import ConversationResponse from src.domain.vo.list_response import ListResponse from src.domain.models.conversation_domain import ConversationDomain -from fastapi import HTTPException +from src.application.exceptions.exceptions import NotFoundError class ConversationUseCase(ConversationInputPort): @@ -20,16 +20,42 @@ class ConversationUseCase(ConversationInputPort): def __init__(self, conversation_repo: ConversationOutputPort, message_repo: MessageOutputPort): self.conversation_repo = conversation_repo self.message_repo = message_repo + + # produce truncated summary for each message to keep conversation detail lightweight + def _truncate(self, text: str, limit: int = 120) -> str: + if text is None: + return "" + if len(text) <= limit: + return text + return text[:limit].rstrip() + "..." async def get_conversation_detail(self, conversation_id: str) -> ConversationResponse: conv = await self.conversation_repo.get_by_id(conversation_id) if conv is None: - raise HTTPException(status_code=404, detail="Conversation not found") - - # fetch latest messages explicitly to avoid async lazy-loading issues + raise NotFoundError("Conversation not found") + messages, _ = await self.message_repo.get_list_by_conversation(conversation_id, limit=self.LATEST_MESSAGE_COUNT, after=None, order="desc") - conv.messages = messages - return ConversationResponse.from_domain(conv) + total_count = await self.message_repo.count_by_conversation(conversation_id) + + truncated_msgs = [ + MessageResponse( + id=m.id, + conversation_id=m.conversation_id, + role=m.role, + content=self._truncate(m.content, limit=120), + created_at=m.created_at, + ) + for m in messages + ] + + return ConversationResponse( + id=conv.id, + name=conv.name, + created_at=conv.created_at, + updated_at=conv.updated_at, + messages=truncated_msgs, + messages_count=total_count, + ) async def get_conversation_list(self, after: Optional[str] = None, @@ -56,7 +82,7 @@ async def get_conversation_messages(self, order: Optional[str] = "desc") -> ListResponse[MessageResponse]: conv = await self.conversation_repo.get_by_id(conversation_id) if conv is None: - raise HTTPException(status_code=404, detail="Conversation not found") + raise NotFoundError("Conversation not found") messageList, has_more = await self.message_repo\ .get_list_by_conversation(conversation_id, after=after, limit=limit, order=order if order else "desc") data = [MessageResponse.from_domain(m) for m in messageList] @@ -69,6 +95,14 @@ async def get_conversation_messages(self, has_more=has_more ) + async def get_recent_messages(self, conversation_id: str, k: int = 5) -> List[MessageResponse]: + conv = await self.conversation_repo.get_by_id(conversation_id) + if conv is None: + raise NotFoundError("Conversation not found") + + msgs = await self.message_repo.get_latest_by_conversation(conversation_id, k) + return [MessageResponse.from_domain(m) for m in msgs] + async def create_conversation(self) -> ConversationResponse: conv = ConversationDomain( id=generate_unique_id("conv"), @@ -83,7 +117,7 @@ async def create_conversation(self) -> ConversationResponse: async def update_conversation(self, request: ConversationUpdateRequest) -> ConversationResponse: conv = await self.conversation_repo.get_by_id(request.id) if conv is None: - raise HTTPException(status_code=404, detail="Conversation not found") + raise NotFoundError("Conversation not found") conv.name = request.name conv.updated_at = get_current_timestamp() updated = await self.conversation_repo.save(conv) @@ -92,14 +126,14 @@ async def update_conversation(self, request: ConversationUpdateRequest) -> Conve async def delete_conversation(self, conversation_id: str) -> bool: conv = await self.conversation_repo.get_by_id(conversation_id) if conv is None: - raise HTTPException(status_code=404, detail="Conversation not found") + raise NotFoundError("Conversation not found") await self.conversation_repo.delete(conv) return True async def post_message(self, conversation_id: str, content: str) -> MessageResponse: conv = await self.conversation_repo.get_by_id(conversation_id) if conv is None: - raise HTTPException(status_code=404, detail="Conversation not found") + raise NotFoundError("Conversation not found") message = MessageDomain( id=generate_unique_id("msg"), diff --git a/backend/fastapi/src/application/usecases/gemini_usecase.py b/backend/fastapi/src/application/usecases/gemini_usecase.py index 9b41cac..84b4e65 100644 --- a/backend/fastapi/src/application/usecases/gemini_usecase.py +++ b/backend/fastapi/src/application/usecases/gemini_usecase.py @@ -3,7 +3,9 @@ import asyncio import logging from typing import AsyncIterator, List, Optional +from io import StringIO from fastapi import HTTPException +from src.application.exceptions.exceptions import BadGatewayError, GatewayTimeoutError from src.application.ports.output.conversation_output_port import ConversationOutputPort from src.application.ports.output.gemini_output_port import GeminiOutputPort @@ -14,7 +16,7 @@ from src.domain.utils.utils import generate_unique_id, get_current_timestamp from src.application.config.config import settings from src.domain.models.message_domain import MessageDomain -from src.domain.enums.enums import EModel +from src.domain.enums.enums import EModel, ERole logger = logging.getLogger(__name__) @@ -51,7 +53,7 @@ async def _persist_assistant_message(self, conversation_id: Optional[str], text: msg = MessageDomain( id=generate_unique_id("msg"), conversation_id=conversation_id, - role="model", + role=ERole.MODEL, content=text, created_at=get_current_timestamp(), ) @@ -73,10 +75,10 @@ async def query(self, message_request: MessageRequest) -> str: # determine model try: - model = validate_model_name(model_hint or settings.GEMINI_URL or "gpt-4") + model = validate_model_name(model_hint or settings.GEMINI_URL or EModel.GEMINI_2_5_PRO) model_name = model.value if isinstance(model, EModel) else str(model) except Exception: - model_name = str(model_hint or "gpt-4") + model_name = str(model_hint or EModel.GEMINI_2_5_PRO) # build short history history: List[MessageDomain] = [] @@ -90,15 +92,15 @@ async def query(self, message_request: MessageRequest) -> str: try: timeout = getattr(settings, "GEMINI_TIMEOUT_SECONDS", 30) resp = await asyncio.wait_for( - self.gemini_output_port.generate(model_name, user_msg, history), timeout=timeout + self.gemini_output_port.generate(model_name, history), timeout=timeout ) except asyncio.TimeoutError: logger.exception("Gemini generate timed out") - raise HTTPException(status_code=504, detail="Gemini request timed out") + raise GatewayTimeoutError("Gemini request timed out") except Exception as exc: logger.exception("Gemini generate failed: %s", exc) # Map adapter failures to a 502 Bad Gateway so callers know it's an upstream problem - raise HTTPException(status_code=502, detail=f"Gemini service error: {exc}") + raise BadGatewayError(f"Gemini service error: {exc}") # persist assistant message and update conversation await self._persist_assistant_message(user_msg.conversation_id, resp) @@ -110,10 +112,10 @@ async def query_stream(self, message_request: MessageRequest) -> AsyncIterator[s await self._persist_user_message(user_msg) try: - model = validate_model_name(model_hint or settings.GEMINI_URL or "gpt-4") + model = validate_model_name(model_hint or settings.GEMINI_URL or EModel.GEMINI_2_5_PRO) model_name = model.value if isinstance(model, EModel) else str(model) except Exception: - model_name = str(model_hint or "gpt-4") + model_name = str(model_hint or EModel.GEMINI_2_5_PRO) history: List[MessageDomain] = [] if user_msg.conversation_id: @@ -125,15 +127,18 @@ async def query_stream(self, message_request: MessageRequest) -> AsyncIterator[s # get stream iterator try: # stream_generate returns an async iterator (async generator); do not await it - stream_iter = self.gemini_output_port.stream_generate(model_name, user_msg, history) + stream_iter = self.gemini_output_port.stream_generate(model_name, history) except Exception as exc: logger.exception("Gemini stream_generate not available or failed to start: %s", exc) - raise HTTPException(status_code=502, detail=f"Gemini stream error: {exc}") + raise BadGatewayError(f"Gemini stream error: {exc}") - parts: List[str] = [] + # Accumulate in a StringIO to reduce list growth overhead + buffer = StringIO() try: async for part in stream_iter: - parts.append(part) + if part is None: + continue + buffer.write(part) yield part except asyncio.CancelledError: logger.info("Streaming to client cancelled") @@ -143,6 +148,7 @@ async def query_stream(self, message_request: MessageRequest) -> AsyncIterator[s # Stop iteration; downstream StreamingResponse will close the connection. return - full = "".join(parts) - await self._persist_assistant_message(user_msg.conversation_id, full) + # Stream finished successfully, persist the assistant message once + full_text = buffer.getvalue() + await self._persist_assistant_message(user_msg.conversation_id, full_text) \ No newline at end of file diff --git a/backend/fastapi/src/domain/enums/enums.py b/backend/fastapi/src/domain/enums/enums.py index 4da8491..e65be12 100644 --- a/backend/fastapi/src/domain/enums/enums.py +++ b/backend/fastapi/src/domain/enums/enums.py @@ -46,13 +46,12 @@ def from_str(cls, value: Any) -> "ERole": class EModel(str, Enum): - # A small set of example model identifiers. Add more as needed. - GPT_4 = "gpt-4" - GPT_4O = "gpt-4o" - GPT_4O_MINI = "gpt-4o-mini" - GPT_3_5_TURBO = "gpt-3.5-turbo" - GEMINI_1_0 = "gemini-1.0" - GEMINI_1_5 = "gemini-1.5" + GEMINI_2_5_PRO = "gemini-2.5-pro" + GEMINI_2_5_FLASH = "gemini-2.5-flash" + GEMINI_2_5_FLASH_LITE = "gemini-2.5-flash-lite" + GEMINI_2_0_FLASH = "gemini-2.0-flash" + GEMINI_2_0_FLASH_LITE = "gemini-2.0-flash-lite" + GEMINI_FLASH_LATEST = "gemini-flash-latest" @classmethod def from_str(cls, value: Any) -> "EModel": diff --git a/backend/fastapi/src/domain/models/conversation_domain.py b/backend/fastapi/src/domain/models/conversation_domain.py index 0c744e6..f95958e 100644 --- a/backend/fastapi/src/domain/models/conversation_domain.py +++ b/backend/fastapi/src/domain/models/conversation_domain.py @@ -1,11 +1,17 @@ -from dataclasses import dataclass, field from typing import List, Optional +from pydantic import BaseModel, Field from .message_domain import MessageDomain -@dataclass -class ConversationDomain: + +class ConversationDomain(BaseModel): id: str name: str created_at: int updated_at: Optional[int] = None - messages: List[MessageDomain] = field(default_factory=list) + # avoid mutable default list + messages: List[MessageDomain] = Field(default_factory=list) + # optional DB-backed messages count (populated by repository) + messages_count: Optional[int] = None + + class Config: + from_attributes = True diff --git a/backend/fastapi/src/domain/models/message_domain.py b/backend/fastapi/src/domain/models/message_domain.py index e8184fe..fe51f45 100644 --- a/backend/fastapi/src/domain/models/message_domain.py +++ b/backend/fastapi/src/domain/models/message_domain.py @@ -1,9 +1,13 @@ -from dataclasses import dataclass +from pydantic import BaseModel +from src.domain.enums.enums import ERole -@dataclass -class MessageDomain: + +class MessageDomain(BaseModel): id: str conversation_id: str - role: str + role: ERole content: str created_at: int + + class Config: + from_attributes = True diff --git a/backend/fastapi/src/domain/vo/conversation_response.py b/backend/fastapi/src/domain/vo/conversation_response.py index b176877..308200e 100644 --- a/backend/fastapi/src/domain/vo/conversation_response.py +++ b/backend/fastapi/src/domain/vo/conversation_response.py @@ -9,18 +9,26 @@ class ConversationResponse(BaseModel): name: Optional[str] = Field(default="New Conversation") created_at: Optional[int] updated_at: Optional[int] - messages: List[MessageResponse] = [] + messages: List[MessageResponse] = Field(default_factory=list) + messages_count: int = 0 class Config: - orm_mode = True + from_attributes = True @classmethod def from_domain(cls, domain_obj: ConversationDomain): messages = [MessageResponse.from_domain(m) for m in getattr(domain_obj, "messages", [])] + # prefer repository-populated messages_count when available to avoid + # relying on the in-memory messages list (which may be truncated in list views) + messages_count = getattr(domain_obj, "messages_count", None) + if messages_count is None: + messages_count = len(messages) + return cls( - id=domain_obj.id, - name=domain_obj.name, - created_at=domain_obj.created_at, - updated_at=domain_obj.updated_at, - messages=messages + id=domain_obj.id, + name=domain_obj.name, + created_at=domain_obj.created_at, + updated_at=domain_obj.updated_at, + messages=messages, + messages_count=messages_count, ) diff --git a/backend/fastapi/src/domain/vo/list_response.py b/backend/fastapi/src/domain/vo/list_response.py index 228687a..4143fd4 100644 --- a/backend/fastapi/src/domain/vo/list_response.py +++ b/backend/fastapi/src/domain/vo/list_response.py @@ -1,10 +1,9 @@ from typing import Generic, List, Optional, TypeVar -from pydantic.generics import GenericModel +from pydantic import BaseModel T = TypeVar("T") - -class ListResponse(GenericModel, Generic[T]): +class ListResponse(BaseModel, Generic[T]): data: List[T] first_id: Optional[str] last_id: Optional[str] diff --git a/backend/fastapi/src/domain/vo/message_request.py b/backend/fastapi/src/domain/vo/message_request.py index 9b2b383..e016510 100644 --- a/backend/fastapi/src/domain/vo/message_request.py +++ b/backend/fastapi/src/domain/vo/message_request.py @@ -1,10 +1,9 @@ - -from fastapi.utils import generate_unique_id from pydantic import BaseModel from fastapi import Body from src.domain.models.message_domain import MessageDomain from src.domain.utils.utils import get_current_timestamp, generate_unique_id +from src.domain.enums.enums import ERole class MessageRequest(BaseModel): @@ -13,14 +12,14 @@ class MessageRequest(BaseModel): model: str class Config: - orm_mode = True + from_attributes = True def to_domain(self) -> tuple[MessageDomain, str]: return ( MessageDomain( id=generate_unique_id("msg"), conversation_id=self.conversation_id, - role="user", + role=ERole.USER, content=self.content, created_at=get_current_timestamp() ), @@ -31,7 +30,7 @@ def to_domain(self) -> tuple[MessageDomain, str]: def as_body( conversation_id: str = Body(..., description="ID cuộc chuyện (không rỗng)", min_length=1), content: str = Body(..., description="Nội dung tin nhắn (1-2000 ký tự)", min_length=1, max_length=2000), - model: str = Body(..., description="Tên mô hình (ví dụ: 'gpt-4')", min_length=1), + model: str = Body(..., description="Tên mô hình (ví dụ: 'gemini-2.5-pro')", min_length=1), ) -> "MessageRequest": return MessageRequest( conversation_id=conversation_id, diff --git a/backend/fastapi/src/domain/vo/message_response.py b/backend/fastapi/src/domain/vo/message_response.py index b4d5e89..c9aff3a 100644 --- a/backend/fastapi/src/domain/vo/message_response.py +++ b/backend/fastapi/src/domain/vo/message_response.py @@ -1,17 +1,18 @@ from pydantic import BaseModel from src.domain.models.message_domain import MessageDomain +from src.domain.enums.enums import ERole class MessageResponse(BaseModel): id: str conversation_id: str - role: str + role: ERole content: str created_at: int class Config: - orm_mode = True + from_attributes = True @classmethod def from_domain(cls, domain_obj: MessageDomain): diff --git a/backend/fastapi/src/domain/vo/message_update_request.py b/backend/fastapi/src/domain/vo/message_update_request.py new file mode 100644 index 0000000..bfe33c4 --- /dev/null +++ b/backend/fastapi/src/domain/vo/message_update_request.py @@ -0,0 +1,20 @@ +from fastapi import Body +from pydantic import BaseModel + + +class MessageUpdateRequest(BaseModel): + id: str + content: str + role: str | None = None + + @staticmethod + def as_body( + id: str = Body(..., min_length=1, description="Message ID (non-empty)"), + content: str = Body(..., min_length=0, description="Updated message content"), + role: str | None = Body(None, description="Optional role (user/model)") + ) -> "MessageUpdateRequest": + return MessageUpdateRequest( + id=id, + content=content, + role=role, + ) diff --git a/backend/fastapi/src/main.py b/backend/fastapi/src/main.py index 33765bf..9510679 100644 --- a/backend/fastapi/src/main.py +++ b/backend/fastapi/src/main.py @@ -1,5 +1,10 @@ from fastapi import FastAPI -from src.adapter.input.controllers import conversation_controller, health_controller, gemini_controller +from src.adapter.input.controllers import conversation_controller, health_controller, gemini_controller, messages_controller +from fastapi import Request +from fastapi.responses import JSONResponse +from src.application.exceptions.exceptions import AppException +from src.adapter.input.controllers.response_utils import error_response +from fastapi import HTTPException from src.adapter.output.mysql.db.base import init_db from src.application.config.config import settings @@ -12,9 +17,12 @@ ) # include routers under a API prefix +app.include_router(gemini_controller.router, prefix=settings.API_PREFIX) app.include_router(conversation_controller.router, prefix=settings.API_PREFIX) +app.include_router(messages_controller.router, prefix=settings.API_PREFIX) app.include_router(health_controller.router, prefix=settings.API_PREFIX) -app.include_router(gemini_controller.router, prefix=settings.API_PREFIX) + + # also expose health at root for backward compatibility (/health and /health/ready) app.include_router(health_controller.router) @@ -25,3 +33,25 @@ def startup(): init_db() except Exception: pass + + +@app.exception_handler(AppException) +async def app_exception_handler(request: Request, exc: AppException): + # Custom application exceptions use our unified envelope + return error_response(message=exc.message or "Error", status_code=exc.status_code, data=exc.payload) + + +@app.exception_handler(Exception) +async def generic_exception_handler(request: Request, exc: Exception): + # Map some common errors to nicer responses + # ValueError -> 404 (conservative mapping for repository lookups that raise ValueError) + if isinstance(exc, ValueError): + return error_response(message=str(exc), status_code=404, data=None) + # For FastAPI's HTTPException the framework normally handles it; but as a fallback wrap here + return error_response(message=str(exc), status_code=500, data=None) + + +@app.exception_handler(HTTPException) +async def http_exception_handler(request: Request, exc: HTTPException): + # Convert FastAPI HTTPException into our unified envelope + return error_response(message=exc.detail if hasattr(exc, "detail") else str(exc), status_code=exc.status_code, data=None)