diff --git a/app/errors.py b/app/errors.py new file mode 100644 index 0000000..5791ea3 --- /dev/null +++ b/app/errors.py @@ -0,0 +1,61 @@ +"""Map backend/provider failures to stable, sanitized HTTP error responses. + +mem0 calls Qdrant, an LLM provider, and an embedder synchronously inside +request handlers; without this module any of those failing surfaces as an +opaque 500 whose body may leak backend details (hosts, model names, key +prefixes). classify_exception() sorts the concrete SDK exception types into a +small taxonomy with fixed, content-free messages: + +- vector store / network failures -> 503 backend_unavailable +- LLM / embedder provider failures -> 502 upstream_provider_error +- anything else -> 500 internal_error + +The full exception is always logged server-side (with the request_id) and +never echoed to the client. New mem0 call sites need no wrapping — the +classifier runs from the app-level exception handler in app/main.py. +""" + +import anthropic +import httpx +import openai +from qdrant_client.http.exceptions import ( + ResponseHandlingException, + UnexpectedResponse, +) + +# Provider SDK errors are checked first: openai/anthropic connection errors +# wrap httpx exceptions, so the network-level check below would otherwise +# misfile them as vector-store trouble. +_PROVIDER_ERRORS = (openai.OpenAIError, anthropic.AnthropicError) + +# Qdrant client failures and raw transport errors. ConnectionError/TimeoutError +# cover the stdlib variants some client paths raise. +_BACKEND_ERRORS = ( + ResponseHandlingException, + UnexpectedResponse, + httpx.TransportError, + httpx.TimeoutException, + # Raised if a client path surfaces a raw non-2xx instead of wrapping it in + # the qdrant exceptions above. 
+ httpx.HTTPStatusError, + ConnectionError, + TimeoutError, +) + + +def classify_exception(exc: BaseException) -> tuple[int, str, str]: + """Return (status_code, error_code, client-safe detail) for an exception.""" + if isinstance(exc, _PROVIDER_ERRORS): + return ( + 502, + "upstream_provider_error", + "An upstream model provider (LLM or embedder) failed; " + "check provider keys and status.", + ) + if isinstance(exc, _BACKEND_ERRORS): + return ( + 503, + "backend_unavailable", + "The vector store is unreachable or returned an error; try again later.", + ) + return (500, "internal_error", "Internal server error.") diff --git a/app/main.py b/app/main.py index 6062fbe..9766f42 100644 --- a/app/main.py +++ b/app/main.py @@ -10,6 +10,7 @@ from starlette.routing import Route from app.config import get_settings +from app.errors import classify_exception from app.logging_setup import configure_logging from app.mcp_server import build_mcp from app.metrics import observe_request @@ -62,6 +63,9 @@ async def log_requests(request: Request, call_next): # The Authorization header is never read here, so tokens are never logged. request_id = request.headers.get("x-request-id") or uuid.uuid4().hex[:12] structlog.contextvars.bind_contextvars(request_id=request_id) + # Also stashed on request.state for the exception handler: by the time it + # runs, this middleware's finally block has already cleared the contextvars. + request.state.request_id = request_id start = time.perf_counter() status = 500 # if call_next raises, the request is logged as a 500 try: @@ -94,6 +98,29 @@ async def log_requests(request: Request, call_next): structlog.contextvars.clear_contextvars() +@app.exception_handler(Exception) +async def unhandled_exception_handler(request: Request, exc: Exception) -> JSONResponse: + """Translate unhandled errors into stable, sanitized JSON. + + Backend (Qdrant/network) failures become 503, model-provider failures 502, + everything else a generic 500. 
The response never includes exception text — + it carries the request_id instead, which correlates with the server-side + log line holding the full traceback. + """ + status, code, detail = classify_exception(exc) + request_id = getattr(request.state, "request_id", None) + _log.error( + "unhandled_exception", + request_id=request_id, + error_code=code, + exc_info=exc, + ) + return JSONResponse( + status_code=status, + content={"detail": detail, "error": code, "request_id": request_id}, + ) + + app.include_router(rest_router, prefix="/api/v1") if settings.oauth_enabled: diff --git a/app/mcp_server.py b/app/mcp_server.py index c6b8802..396a9fa 100644 --- a/app/mcp_server.py +++ b/app/mcp_server.py @@ -76,19 +76,40 @@ def list_memories(limit: int = 50, offset: int = 0) -> dict: filters={"user_id": default_user}, limit=limit, offset=offset ) + def _not_found(memory_id: str) -> dict: + # MCP tools return a structured error instead of raising, so the model + # sees a usable signal rather than an opaque tool exception. + return {"error": "not_found", "memory_id": memory_id} + @mcp.tool def get_memory(memory_id: str) -> dict: - """Fetch a single memory by ID.""" - return memory.get(memory_id=memory_id) + """Fetch a single memory by ID. + + Returns {"error": "not_found", ...} if no memory has that ID. + """ + result = memory.get(memory_id=memory_id) + # Explicit None check: a falsy-but-present result (e.g. {}) is a found + # memory, not a miss. + return _not_found(memory_id) if result is None else result @mcp.tool def update_memory(memory_id: str, content: str) -> dict: - """Replace the content of an existing memory.""" + """Replace the content of an existing memory. + + Returns {"error": "not_found", ...} if no memory has that ID. 
+ """ + if memory.get(memory_id=memory_id) is None: + return _not_found(memory_id) return memory.update(memory_id=memory_id, data=content) @mcp.tool def delete_memory(memory_id: str) -> dict: - """Permanently delete a memory.""" + """Permanently delete a memory. + + Returns {"error": "not_found", ...} if no memory has that ID. + """ + if memory.get(memory_id=memory_id) is None: + return _not_found(memory_id) memory.delete(memory_id=memory_id) return {"deleted": True, "memory_id": memory_id} diff --git a/app/rest.py b/app/rest.py index c7aad95..981bd36 100644 --- a/app/rest.py +++ b/app/rest.py @@ -2,7 +2,7 @@ import httpx from fastapi import APIRouter, Depends, HTTPException, Query -from pydantic import BaseModel, Field +from pydantic import BaseModel, ConfigDict, Field from app import memory as memory_mod from app.auth import require_bearer @@ -55,6 +55,59 @@ class UpdateMemoryRequest(BaseModel): content: str +# --- Response models --------------------------------------------------------- +# These document and validate the *stable* parts of mem0's payloads without +# freezing them: extra="allow" passes unexpected mem0 fields through untouched, +# and every route sets response_model_exclude_unset=True so fields mem0 didn't +# send aren't fabricated as nulls — the wire format stays exactly what mem0 +# returned, but /docs and client generators get a real schema. 
+ + +class MemoryItem(BaseModel): + model_config = ConfigDict(extra="allow") + + id: str | None = None + memory: str | None = None + hash: str | None = None + created_at: str | None = None + updated_at: str | None = None + user_id: str | None = None + agent_id: str | None = None + run_id: str | None = None + score: float | None = None + metadata: dict | None = None + + +class MemoryResults(BaseModel): + model_config = ConfigDict(extra="allow") + + results: list[MemoryItem] = Field(default_factory=list) + + +class AddMemoryResponse(MemoryResults): + # Set when the add was skipped because identical content already exists. + deduplicated: bool | None = None + memory_id: str | None = None + + +class UpdateResponse(BaseModel): + model_config = ConfigDict(extra="allow") + + # mem0's update() returns a success message, not the updated item. + message: str | None = None + + +class DeleteResponse(BaseModel): + deleted: bool + memory_id: str + + +class HistoryResponse(BaseModel): + model_config = ConfigDict(extra="allow") + + history: list = Field(default_factory=list) + + def _provenance_filters( source: str | None, confidence: str | None, review_status: str | None ) -> dict: @@ -75,7 +128,9 @@ def _scope_kwargs( return kwargs -@router.post("/memories") +@router.post( + "/memories", response_model=AddMemoryResponse, response_model_exclude_unset=True +) def add_memory(req: AddMemoryRequest) -> dict: if not req.content and not req.messages: raise HTTPException(status_code=422, detail="Provide either 'content' or 'messages'") @@ -86,7 +141,9 @@ def add_memory(req: AddMemoryRequest) -> dict: return memory_mod.add_memory(payload, dedup=req.dedup, **kwargs) -@router.post("/memories/search") +@router.post( + "/memories/search", response_model=MemoryResults, response_model_exclude_unset=True +) def search_memories(req: SearchRequest) -> dict: prov = _provenance_filters(req.source, req.confidence, req.review_status) if req.mode == "keyword": @@ -104,7 +161,9 @@ def 
search_memories(req: SearchRequest) -> dict: return memory_mod.drop_expired(results) if req.exclude_expired else results -@router.get("/memories") +@router.get( + "/memories", response_model=MemoryResults, response_model_exclude_unset=True +) def list_memories( user_id: str | None = None, agent_id: str | None = None, @@ -126,29 +185,48 @@ def list_memories( return memory_mod.drop_expired(results) if exclude_expired else results -@router.get("/memories/{memory_id}") -def get_memory_by_id(memory_id: str) -> dict: - memory = memory_mod.get_memory() +def _get_or_404(memory, memory_id: str) -> dict: result = memory.get(memory_id=memory_id) if not result: raise HTTPException(status_code=404, detail="Memory not found") return result -@router.put("/memories/{memory_id}") +@router.get( + "/memories/{memory_id}", + response_model=MemoryItem, + response_model_exclude_unset=True, +) +def get_memory_by_id(memory_id: str) -> dict: + return _get_or_404(memory_mod.get_memory(), memory_id) + + +@router.put( + "/memories/{memory_id}", + response_model=UpdateResponse, + response_model_exclude_unset=True, +) def update_memory(memory_id: str, req: UpdateMemoryRequest) -> dict: memory = memory_mod.get_memory() + # Depending on the mem0 version, update() on a missing id either raises or + # silently no-ops; pre-checking makes it a 404 like GET. 
+ _get_or_404(memory, memory_id) return memory.update(memory_id=memory_id, data=req.content) -@router.delete("/memories/{memory_id}") +@router.delete("/memories/{memory_id}", response_model=DeleteResponse) def delete_memory(memory_id: str) -> dict: memory = memory_mod.get_memory() + _get_or_404(memory, memory_id) memory.delete(memory_id=memory_id) return {"deleted": True, "memory_id": memory_id} -@router.get("/memories/{memory_id}/history") +@router.get( + "/memories/{memory_id}/history", + response_model=HistoryResponse, + response_model_exclude_unset=True, +) def memory_history(memory_id: str) -> dict: memory = memory_mod.get_memory() return {"history": memory.history(memory_id=memory_id)} diff --git a/docs/DEVELOPER_GUIDE.md b/docs/DEVELOPER_GUIDE.md index 61cf3c3..8e30277 100644 --- a/docs/DEVELOPER_GUIDE.md +++ b/docs/DEVELOPER_GUIDE.md @@ -83,6 +83,10 @@ app/ wiring, build_verifier() selecting Phase 1 vs Phase 2. oauth.py Phase 2 OAuth 2.1 + PKCE + DCR endpoints, JWT issuance, JWKS, AS/PR metadata. oauth_store.py SQLite store for OAuth clients, auth codes, refresh tokens (/app/data/oauth.db). + errors.py classify_exception(): maps concrete SDK exceptions (qdrant/httpx -> 503, + openai/anthropic -> 502, else 500) to sanitized JSON via the app-level + exception handler in main.py. New mem0 call sites need no wrapping; if a new + backend dependency is added, add its exception types to the tuples here. ratelimit.py Per-IP fixed-window rate limiting of *failed* auth attempts, applied as the rate_limit_middleware over four surfaces: REST (/api/v1), MCP (/mcp), OAuth consent (POST /oauth/authorize) and token (/oauth/token). In-process state, diff --git a/docs/USER_GUIDE.md b/docs/USER_GUIDE.md index 6400357..de463de 100644 --- a/docs/USER_GUIDE.md +++ b/docs/USER_GUIDE.md @@ -545,7 +545,25 @@ and drive the same six tools. ## REST API reference All endpoints live under `/api/v1` and require `Authorization: Bearer `. Request and -response bodies are JSON. 
`user_id` defaults to `MEM0_DEFAULT_USER_ID` if omitted. +response bodies are JSON. `user_id` defaults to `MEM0_DEFAULT_USER_ID` if omitted. Response +schemas are published in the interactive docs at `/docs` (OpenAPI). + +### Error responses + +Server-side failures (`5xx`) return the stable JSON shape below; `4xx` responses carry `detail` only: + +```json +{"detail": "human-readable summary", "error": "machine_code", "request_id": "abc123def456"} +``` + +| Status | `error` | Meaning | +|---|---|---| +| `401` | — | Missing/invalid bearer token (plain `detail` only). | +| `404` | — | Memory ID does not exist (`GET`/`PUT`/`DELETE` by ID). | +| `422` | — | Request validation failed (FastAPI's standard shape). | +| `502` | `upstream_provider_error` | The LLM or embedding provider failed — check provider keys/status. | +| `503` | `backend_unavailable` | Qdrant is unreachable or erroring — same condition `/healthz` reports. | +| `500` | `internal_error` | Unexpected failure. The body never contains internals; quote the `request_id` (also settable via an `X-Request-Id` request header) when digging through server logs. | ### Add a memory — `POST /api/v1/memories` @@ -650,11 +668,11 @@ Returns 404 if the memory does not exist. ### Update — `PUT /api/v1/memories/{memory_id}` -Body: `{"content": "new text"}`. +Body: `{"content": "new text"}`. Returns 404 if the memory does not exist. ### Delete — `DELETE /api/v1/memories/{memory_id}` -Returns `{"deleted": true, "memory_id": "…"}`. +Returns `{"deleted": true, "memory_id": "…"}`, or 404 if the memory does not exist. ### History — `GET /api/v1/memories/{memory_id}/history` diff --git a/tests/conftest.py b/tests/conftest.py index 93b7dfb..518da0b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -44,7 +44,10 @@ def _reset_rate_limiters(): @pytest.fixture def mem(): - FAKE_MEMORY.reset_mock() + # Full reset: plain reset_mock() keeps return_value/side_effect, which + # would leak one test's stubbing (e.g. get -> None, search -> raise) into + # every later test. 
+ FAKE_MEMORY.reset_mock(return_value=True, side_effect=True) # Default: no existing fingerprint, so add_memory()'s dedup check is a no-op # and proceeds to call .add(). Tests exercising dedup override this. FAKE_MEMORY.vector_store.list.return_value = ([], None) diff --git a/tests/test_errors.py b/tests/test_errors.py new file mode 100644 index 0000000..3c0ad43 --- /dev/null +++ b/tests/test_errors.py @@ -0,0 +1,102 @@ +import anthropic +import httpx +import openai +from fastapi.testclient import TestClient +from qdrant_client.http.exceptions import ResponseHandlingException + +from app.errors import classify_exception + +# --------------------------------------------------------------------------- +# classify_exception unit tests +# --------------------------------------------------------------------------- + + +def test_classifies_provider_errors_as_502(): + for exc in (openai.OpenAIError("boom"), anthropic.AnthropicError("boom")): + status, code, _ = classify_exception(exc) + assert (status, code) == (502, "upstream_provider_error") + + +def test_classifies_backend_errors_as_503(): + for exc in ( + httpx.ConnectError("refused"), + httpx.ReadTimeout("slow"), + ResponseHandlingException("qdrant glitch"), + httpx.HTTPStatusError( + "503", + request=httpx.Request("GET", "http://q"), + response=httpx.Response(503), + ), + ConnectionError("reset"), + TimeoutError(), + ): + status, code, _ = classify_exception(exc) + assert (status, code) == (503, "backend_unavailable") + + +def test_provider_wins_over_wrapped_network_error(): + # openai.APIConnectionError is an OpenAIError wrapping an httpx error; it + # must classify as a provider failure, not vector-store trouble. 
+ exc = openai.APIConnectionError(request=httpx.Request("GET", "http://x")) + status, code, _ = classify_exception(exc) + assert (status, code) == (502, "upstream_provider_error") + + +def test_unknown_errors_are_500(): + status, code, detail = classify_exception(RuntimeError("secret-host-name")) + assert (status, code) == (500, "internal_error") + assert "secret-host-name" not in detail + + +# --------------------------------------------------------------------------- +# Handler wired into the real app +# --------------------------------------------------------------------------- + + +def _failing_search(app_instance, mem, auth_header, exc): + mem.search.side_effect = exc + client = TestClient(app_instance, raise_server_exceptions=False) + return client.post( + "/api/v1/memories/search", json={"query": "x"}, headers=auth_header + ) + + +def test_qdrant_down_yields_503_json(app_instance, mem, auth_header): + resp = _failing_search( + app_instance, mem, auth_header, httpx.ConnectError("connection refused") + ) + assert resp.status_code == 503 + body = resp.json() + assert body["error"] == "backend_unavailable" + assert body["request_id"] + + +def test_provider_failure_yields_502(app_instance, mem, auth_header): + resp = _failing_search( + app_instance, mem, auth_header, anthropic.AnthropicError("bad key") + ) + assert resp.status_code == 502 + assert resp.json()["error"] == "upstream_provider_error" + + +def test_unexpected_error_yields_sanitized_500(app_instance, mem, auth_header): + resp = _failing_search( + app_instance, mem, auth_header, RuntimeError("qdrant.internal:6333 exploded") + ) + assert resp.status_code == 500 + body = resp.json() + assert body["error"] == "internal_error" + # No exception text leaks; the request_id is the correlation handle instead. 
+ assert "qdrant.internal" not in resp.text + assert body["request_id"] + + +def test_request_id_header_round_trips_into_error_body(app_instance, mem, auth_header): + mem.search.side_effect = RuntimeError("boom") + client = TestClient(app_instance, raise_server_exceptions=False) + resp = client.post( + "/api/v1/memories/search", + json={"query": "x"}, + headers={**auth_header, "X-Request-Id": "trace-me-123"}, + ) + assert resp.json()["request_id"] == "trace-me-123" diff --git a/tests/test_mcp.py b/tests/test_mcp.py index ef0e740..51cfbaa 100644 --- a/tests/test_mcp.py +++ b/tests/test_mcp.py @@ -137,3 +137,28 @@ async def test_delete_memory_tool(mcp, mem): async with Client(mcp) as client: await client.call_tool("delete_memory", {"memory_id": "xyz"}) mem.delete.assert_called_once_with(memory_id="xyz") + + +async def test_get_memory_tool_not_found(mcp, mem): + mem.get.return_value = None + async with Client(mcp) as client: + result = await client.call_tool("get_memory", {"memory_id": "ghost"}) + assert result.data == {"error": "not_found", "memory_id": "ghost"} + + +async def test_update_memory_tool_not_found(mcp, mem): + mem.get.return_value = None + async with Client(mcp) as client: + result = await client.call_tool( + "update_memory", {"memory_id": "ghost", "content": "x"} + ) + assert result.data == {"error": "not_found", "memory_id": "ghost"} + mem.update.assert_not_called() + + +async def test_delete_memory_tool_not_found(mcp, mem): + mem.get.return_value = None + async with Client(mcp) as client: + result = await client.call_tool("delete_memory", {"memory_id": "ghost"}) + assert result.data == {"error": "not_found", "memory_id": "ghost"} + mem.delete.assert_not_called() diff --git a/tests/test_rest.py b/tests/test_rest.py index 6314213..ba387c6 100644 --- a/tests/test_rest.py +++ b/tests/test_rest.py @@ -367,3 +367,31 @@ def test_healthz_unreachable(app_instance): resp = c.get("/healthz") assert resp.status_code == 503 assert resp.json()["ok"] is False + + 
+def test_update_missing_memory_404(app_instance, mem, auth_header): + mem.get.return_value = None + c = _client(app_instance) + resp = c.put( + "/api/v1/memories/missing", json={"content": "x"}, headers=auth_header + ) + assert resp.status_code == 404 + mem.update.assert_not_called() + + +def test_delete_missing_memory_404(app_instance, mem, auth_header): + mem.get.return_value = None + c = _client(app_instance) + resp = c.delete("/api/v1/memories/missing", headers=auth_header) + assert resp.status_code == 404 + mem.delete.assert_not_called() + + +def test_response_passes_through_unknown_mem0_fields(app_instance, mem, auth_header): + # extra="allow" + exclude_unset: unexpected mem0 fields survive, and fields + # mem0 didn't send are not fabricated as nulls. + mem.get.return_value = {"id": "abc", "memory": "hi", "brand_new_field": 7} + c = _client(app_instance) + body = c.get("/api/v1/memories/abc", headers=auth_header).json() + assert body == {"id": "abc", "memory": "hi", "brand_new_field": 7} + assert "agent_id" not in body