diff --git a/.env.example b/.env.example index b415d11..d0fd76f 100644 --- a/.env.example +++ b/.env.example @@ -29,5 +29,17 @@ OAUTH_SIGNING_KEY= # uses a per-connector callback path under https://chatgpt.com/connector/oauth/). OAUTH_ALLOWED_REDIRECT_URIS=https://claude.ai/api/mcp/auth_callback,https://cowork.com/api/mcp/auth_callback,https://chatgpt.com/connector/oauth/* +# Rate limiting (optional; defaults shown). Only failed auth attempts count. +# Set a *_FAILURES value to 0 to disable that surface's limiter. +# TRUST_FORWARDED_FOR=true is correct behind CapRover's nginx; set false when +# the app is exposed directly (the header would be attacker-controlled). +TRUST_FORWARDED_FOR=true +RATE_LIMIT_AUTH_FAILURES=10 +RATE_LIMIT_AUTH_WINDOW_SECONDS=60 +RATE_LIMIT_CONSENT_FAILURES=5 +RATE_LIMIT_CONSENT_WINDOW_SECONDS=300 +RATE_LIMIT_TOKEN_FAILURES=10 +RATE_LIMIT_TOKEN_WINDOW_SECONDS=60 + # Misc LOG_LEVEL=INFO diff --git a/app/config.py b/app/config.py index 719e4b8..2de8cd1 100644 --- a/app/config.py +++ b/app/config.py @@ -42,6 +42,18 @@ class Settings(BaseSettings): "https://chatgpt.com/connector/oauth/*" ) + # Rate limiting (brute-force protection on auth surfaces). Only *failed* + # attempts count; an IP over the limit is rejected with 429 until the + # window expires. Limits are per uvicorn worker. Set a *_failures value + # below 1 to disable limiting on that surface. 
+ trust_forwarded_for: bool = True + rate_limit_auth_failures: int = 10 + rate_limit_auth_window_seconds: float = 60.0 + rate_limit_consent_failures: int = 5 + rate_limit_consent_window_seconds: float = 300.0 + rate_limit_token_failures: int = 10 + rate_limit_token_window_seconds: float = 60.0 + # Misc log_level: str = "INFO" diff --git a/app/main.py b/app/main.py index c72ceb8..6062fbe 100644 --- a/app/main.py +++ b/app/main.py @@ -13,6 +13,7 @@ from app.logging_setup import configure_logging from app.mcp_server import build_mcp from app.metrics import observe_request +from app.ratelimit import rate_limit_middleware from app.rest import check_qdrant from app.rest import router as rest_router @@ -50,6 +51,11 @@ async def lifespan(app: FastAPI): app = FastAPI(title="mem0 Memory Server", version="1.0.0", lifespan=lifespan) +# Registered before log_requests so logging wraps it (later-registered +# middleware is outermost): rate-limited 429s still get a request log line and +# a latency observation. +app.middleware("http")(rate_limit_middleware) + @app.middleware("http") async def log_requests(request: Request, call_next): diff --git a/app/metrics.py b/app/metrics.py index 939e643..82c458f 100644 --- a/app/metrics.py +++ b/app/metrics.py @@ -14,6 +14,28 @@ ) +# Brute-force signal: failed auth attempts and rate-limited rejections, by +# auth surface ("rest", "mcp", "oauth_consent", "oauth_token"). 
+AUTH_FAILURES = Counter( + "auth_failures_total", + "Failed authentication attempts.", + ["surface"], +) +RATE_LIMITED = Counter( + "rate_limited_requests_total", + "Requests rejected by the auth rate limiter.", + ["surface"], +) + + def observe_request(method: str, path: str, status: int, duration_s: float) -> None: REQUEST_COUNT.labels(method=method, path=path, status=str(status)).inc() REQUEST_LATENCY.labels(method=method, path=path).observe(duration_s) + + +def observe_auth_failure(surface: str) -> None: + AUTH_FAILURES.labels(surface=surface).inc() + + +def observe_rate_limited(surface: str) -> None: + RATE_LIMITED.labels(surface=surface).inc() diff --git a/app/ratelimit.py b/app/ratelimit.py new file mode 100644 index 0000000..30a756e --- /dev/null +++ b/app/ratelimit.py @@ -0,0 +1,180 @@ +"""Per-IP rate limiting of failed authentication attempts. + +Every surface that accepts a secret is an online-guessing oracle for +MEM0_API_KEY (or an OAuth code/refresh token): the REST bearer check, the MCP +token verifier, the OAuth consent form, and the OAuth token endpoint. This +module slows brute force to a crawl with a small in-process fixed-window +counter keyed by client IP. Only *failures* count toward the limit; an IP that +hits the limit is locked out of that surface (even with valid credentials) +until the window expires. + +Limits are per uvicorn worker (the Dockerfile runs --workers 2), so the +effective limit is roughly workers x the configured value. That is fine for a +single-user service: the point is to turn millions of guesses per day into +dozens, not to enforce an exact quota. 
+""" + +import math +import threading +import time +from functools import lru_cache + +import structlog +from fastapi import Request +from fastapi.responses import JSONResponse + +from app.config import get_settings +from app.metrics import observe_auth_failure, observe_rate_limited + +_log = structlog.get_logger() + +# Cap on tracked IPs per limiter; when reached, expired windows are pruned on +# the next recorded failure so an attacker rotating IPs can't grow the dict +# without bound. +_MAX_TRACKED_KEYS = 4096 + + +class RateLimiter: + """Fixed-window failure counter keyed by an arbitrary string (client IP). + + A key with `max_failures` failures inside `window_seconds` is limited until + the window that contains its first failure expires. `max_failures < 1` + disables limiting entirely (operator opt-out). + """ + + def __init__(self, max_failures: int, window_seconds: float): + self.max_failures = max_failures + self.window_seconds = window_seconds + self._lock = threading.Lock() + # key -> (window_start_monotonic, failure_count) + self._windows: dict[str, tuple[float, int]] = {} + + def retry_after(self, key: str, now: float | None = None) -> float | None: + """Seconds until `key` may retry, or None if it is not limited.""" + if self.max_failures < 1: + return None + now = time.monotonic() if now is None else now + with self._lock: + entry = self._windows.get(key) + if entry is None: + return None + start, count = entry + if now - start >= self.window_seconds: + del self._windows[key] + return None + if count >= self.max_failures: + return self.window_seconds - (now - start) + return None + + def record_failure(self, key: str, now: float | None = None) -> None: + if self.max_failures < 1: + return + now = time.monotonic() if now is None else now + with self._lock: + if len(self._windows) >= _MAX_TRACKED_KEYS: + self._prune(now) + start, count = self._windows.get(key, (now, 0)) + if now - start >= self.window_seconds: + start, count = now, 0 + 
self._windows[key] = (start, count + 1) + + def _prune(self, now: float) -> None: + expired = [ + key + for key, (start, _) in self._windows.items() + if now - start >= self.window_seconds + ] + for key in expired: + del self._windows[key] + + def reset(self) -> None: + with self._lock: + self._windows.clear() + + +def client_ip(request: Request) -> str: + """Best-effort client IP for rate-limit keying. + + Behind CapRover's nginx the peer address is the proxy, so the original + client is in X-Forwarded-For (first hop). TRUST_FORWARDED_FOR=false turns + that off for deployments exposed directly (e.g. docker-compose without a + proxy), where the header would be attacker-controlled. + """ + if get_settings().trust_forwarded_for: + forwarded = request.headers.get("x-forwarded-for", "") + first = forwarded.split(",")[0].strip() + if first: + return first + return request.client.host if request.client else "unknown" + + +# Response statuses that count as an authentication failure, per surface. +# REST/MCP only ever 401 on a bad bearer token; the consent form returns 401 on +# a wrong API key; the token endpoint returns 400 for guessed/expired codes and +# refresh tokens (per RFC 6749), so 400 counts there too. +_FAILURE_STATUSES = { + "rest": {401}, + "mcp": {401}, + "oauth_consent": {401}, + "oauth_token": {400, 401}, +} + + +def _surface(path: str, method: str) -> str | None: + p = path.rstrip("/") or "/" + if p == "/api/v1" or p.startswith("/api/v1/"): + return "rest" + if p == "/mcp": + return "mcp" + if p == "/oauth/authorize" and method == "POST": + # The GET form render checks no secret; only the POST submit does. 
async def rate_limit_middleware(request: Request, call_next):
    """Reject over-limit client IPs with 429 and count failed-auth responses.

    Requests whose path/method `_surface` does not classify pass straight
    through. For a classified surface, an IP that the surface's limiter
    reports as limited gets a 429 with a `Retry-After` header (at least 1
    second) without the request reaching the downstream handler. Otherwise the
    request is forwarded, and a response whose status is listed in
    `_FAILURE_STATUSES` for that surface is recorded as a failure.
    """
    surface = _surface(request.url.path, request.method)
    if surface is None:
        # Not an auth surface: nothing to limit or count.
        return await call_next(request)

    ip = client_ip(request)
    limiter = _limiter(surface)

    wait_seconds = limiter.retry_after(ip)
    if wait_seconds is None:
        # Under the limit: let the request through, then inspect the outcome.
        response = await call_next(request)
        status = response.status_code
        if status in _FAILURE_STATUSES[surface]:
            limiter.record_failure(ip)
            observe_auth_failure(surface)
            _log.warning("auth_failure", surface=surface, ip=ip, status=status)
        return response

    # Over the limit: short-circuit before any credential check runs.
    observe_rate_limited(surface)
    _log.warning("rate_limited", surface=surface, ip=ip)
    retry_header = str(max(1, math.ceil(wait_seconds)))
    body = {"detail": "Too many failed authentication attempts; try again later."}
    return JSONResponse(
        status_code=429,
        content=body,
        headers={"Retry-After": retry_header},
    )
+ ratelimit.py Per-IP fixed-window rate limiting of *failed* auth attempts, applied as the + rate_limit_middleware over four surfaces: REST (/api/v1), MCP (/mcp), OAuth + consent (POST /oauth/authorize) and token (/oauth/token). In-process state, + per worker; client_ip() honors X-Forwarded-For when TRUST_FORWARDED_FOR=true. metrics.py Prometheus Counter + Histogram and observe_request(). logging_setup.py structlog configuration. main.py Wiring: FastAPI app, request-logging middleware, router include, conditional @@ -251,6 +255,15 @@ label cardinality bounded; unmatched 404s bucket under `__unmatched__`. Under mu `/metrics` aggregates across workers when `PROMETHEUS_MULTIPROC_DIR` is set. The middleware never reads the `Authorization` header, so tokens are never logged. +`rate_limit_middleware` (`app/ratelimit.py`) is registered *before* `log_requests` so logging wraps +it and 429s still get a log line. It counts failed-auth responses (401s; also 400s on +`/oauth/token`, which is what RFC 6749 returns for guessed codes) per client IP per surface, and +rejects an over-limit IP with 429 + `Retry-After` before the request reaches auth. Two metrics +track it: `auth_failures_total{surface}` and `rate_limited_requests_total{surface}` — a spike in +either is a brute-force signal. State is in-process and per worker by design (single-user service; +no shared Redis). Tests must not leak limiter state: `tests/conftest.py` has an autouse fixture +calling `ratelimit.reset_all()`. + ## CI and deployment - **CI** (`.github/workflows/ci.yml`) runs on pushes to `main` and on all PRs: installs deps, then diff --git a/docs/USER_GUIDE.md b/docs/USER_GUIDE.md index 2d37127..9ba6efe 100644 --- a/docs/USER_GUIDE.md +++ b/docs/USER_GUIDE.md @@ -154,8 +154,26 @@ runs, or set these in the CapRover app's **App Configs** panel for production. | `PUBLIC_BASE_URL` | yes | — | Public URL, e.g. `https://mem0.your-domain.com`. Used in OAuth metadata. 
| | `OAUTH_SIGNING_KEY` | no | empty | PEM RSA private key. **Setting this enables Phase 2 OAuth.** Leave blank for Phase 1. | | `OAUTH_ALLOWED_REDIRECT_URIS` | no | claude.ai + cowork + chatgpt callbacks | Comma-separated allowlist for OAuth redirect URIs. An entry ending in `*` is a **path-prefix** match locked to an exact scheme + host — it must be a full `scheme://host/path/` prefix (e.g. `https://chatgpt.com/connector/oauth/*`). Host-only or bare wildcards (`https://chatgpt.com*`, `https://*`, `*`) are **ignored**, so a misconfigured entry can't match lookalike hosts like `chatgpt.com.evil.com`. | +| `TRUST_FORWARDED_FOR` | no | `true` | Use the first `X-Forwarded-For` hop as the client IP for rate limiting. Correct behind CapRover's nginx; set to `false` if the app is exposed directly (no reverse proxy), where the header would be attacker-controlled. | +| `RATE_LIMIT_AUTH_FAILURES` | no | `10` | Failed bearer-token attempts (REST + MCP, per surface) allowed per IP per window before 429s. `0` disables. | +| `RATE_LIMIT_AUTH_WINDOW_SECONDS` | no | `60` | Window for the above. | +| `RATE_LIMIT_CONSENT_FAILURES` | no | `5` | Failed OAuth consent (wrong API key) attempts per IP per window. `0` disables. | +| `RATE_LIMIT_CONSENT_WINDOW_SECONDS` | no | `300` | Window for the above. | +| `RATE_LIMIT_TOKEN_FAILURES` | no | `10` | Failed `/oauth/token` exchanges per IP per window. `0` disables. | +| `RATE_LIMIT_TOKEN_WINDOW_SECONDS` | no | `60` | Window for the above. | | `LOG_LEVEL` | no | `INFO` | Log level. | +#### Rate limiting + +Failed authentication attempts are rate-limited per client IP to slow down brute-force guessing of +`MEM0_API_KEY` (and OAuth codes). Only **failures** count — normal authenticated traffic is never +throttled — but once an IP crosses the limit, *all* its requests to that surface (even with the +correct token) get **HTTP 429** with a `Retry-After` header until the window expires. 
The four +surfaces (REST `/api/v1/...`, MCP `/mcp`, OAuth consent, OAuth token) are limited independently. +`/healthz` and `/metrics` are never limited. Limits are per uvicorn worker (the default image runs +2 workers), so the effective ceiling is about twice the configured value. If you lock yourself out +during testing, wait out the window or restart the app. + ### Phases - **Phase 1 (MVP)** — static bearer token only. Leave `OAUTH_SIGNING_KEY` blank. Works with Claude diff --git a/tests/conftest.py b/tests/conftest.py index e5aa6d2..93b7dfb 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -31,6 +31,17 @@ memory_mod.get_memory = lambda: FAKE_MEMORY +@pytest.fixture(autouse=True) +def _reset_rate_limiters(): + # Failed-auth counts are keyed by client IP and the TestClient always + # connects as "testclient", so limiter state must not leak across tests. + from app import ratelimit + + ratelimit.reset_all() + yield + ratelimit.reset_all() + + @pytest.fixture def mem(): FAKE_MEMORY.reset_mock() diff --git a/tests/test_ratelimit.py b/tests/test_ratelimit.py new file mode 100644 index 0000000..d9358ff --- /dev/null +++ b/tests/test_ratelimit.py @@ -0,0 +1,274 @@ +from unittest.mock import MagicMock + +import pytest +from fastapi import FastAPI, HTTPException +from fastapi.testclient import TestClient + +from app import ratelimit +from app.ratelimit import RateLimiter, client_ip, rate_limit_middleware + +# --------------------------------------------------------------------------- +# RateLimiter unit tests +# --------------------------------------------------------------------------- + + +def test_under_limit_not_limited(): + limiter = RateLimiter(max_failures=3, window_seconds=60) + limiter.record_failure("ip", now=0.0) + limiter.record_failure("ip", now=1.0) + assert limiter.retry_after("ip", now=2.0) is None + + +def test_limited_at_threshold_with_retry_after(): + limiter = RateLimiter(max_failures=3, window_seconds=60) + for t in (0.0, 1.0, 2.0): + 
limiter.record_failure("ip", now=t) + retry = limiter.retry_after("ip", now=10.0) + assert retry == pytest.approx(50.0) + + +def test_window_expiry_unblocks(): + limiter = RateLimiter(max_failures=2, window_seconds=60) + limiter.record_failure("ip", now=0.0) + limiter.record_failure("ip", now=1.0) + assert limiter.retry_after("ip", now=30.0) is not None + assert limiter.retry_after("ip", now=60.0) is None + + +def test_failure_after_expired_window_starts_fresh(): + limiter = RateLimiter(max_failures=2, window_seconds=60) + limiter.record_failure("ip", now=0.0) + limiter.record_failure("ip", now=1.0) + # New failure in a fresh window: count restarts at 1, not 3. + limiter.record_failure("ip", now=120.0) + assert limiter.retry_after("ip", now=121.0) is None + + +def test_keys_are_independent(): + limiter = RateLimiter(max_failures=1, window_seconds=60) + limiter.record_failure("a", now=0.0) + assert limiter.retry_after("a", now=1.0) is not None + assert limiter.retry_after("b", now=1.0) is None + + +def test_disabled_when_max_failures_below_one(): + limiter = RateLimiter(max_failures=0, window_seconds=60) + for t in range(100): + limiter.record_failure("ip", now=float(t)) + assert limiter.retry_after("ip", now=1.0) is None + + +def test_prune_caps_tracked_keys(): + limiter = RateLimiter(max_failures=5, window_seconds=60) + for i in range(ratelimit._MAX_TRACKED_KEYS): + limiter.record_failure(f"ip-{i}", now=0.0) + # All previous windows are expired by now=100; the next failure prunes them. 
+ limiter.record_failure("fresh", now=100.0) + assert len(limiter._windows) == 1 + + +def test_reset_clears_state(): + limiter = RateLimiter(max_failures=1, window_seconds=60) + limiter.record_failure("ip", now=0.0) + limiter.reset() + assert limiter.retry_after("ip", now=1.0) is None + + +# --------------------------------------------------------------------------- +# client_ip extraction +# --------------------------------------------------------------------------- + + +def _fake_request(xff: str | None, peer: str | None = "10.0.0.1"): + request = MagicMock() + request.headers = {"x-forwarded-for": xff} if xff is not None else {} + request.client = MagicMock(host=peer) if peer else None + return request + + +def test_client_ip_uses_first_forwarded_hop(): + assert client_ip(_fake_request("1.2.3.4, 5.6.7.8")) == "1.2.3.4" + + +def test_client_ip_falls_back_to_peer(): + assert client_ip(_fake_request(None)) == "10.0.0.1" + assert client_ip(_fake_request("")) == "10.0.0.1" + + +def test_client_ip_ignores_header_when_untrusted(monkeypatch): + class _S: + trust_forwarded_for = False + + monkeypatch.setattr(ratelimit, "get_settings", lambda: _S()) + assert client_ip(_fake_request("1.2.3.4")) == "10.0.0.1" + + +def test_client_ip_no_peer(): + assert client_ip(_fake_request(None, peer=None)) == "unknown" + + +# --------------------------------------------------------------------------- +# Surface classification +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize( + ("path", "method", "expected"), + [ + ("/api/v1/memories", "POST", "rest"), + ("/api/v1/memories/abc", "GET", "rest"), + ("/mcp", "POST", "mcp"), + ("/mcp/", "POST", "mcp"), + ("/oauth/authorize", "POST", "oauth_consent"), + ("/oauth/authorize", "GET", None), # form render checks no secret + ("/oauth/token", "POST", "oauth_token"), + ("/healthz", "GET", None), + ("/metrics", "GET", None), + ("/api/v1x", "GET", None), # prefix must be a path segment + ], 
+) +def test_surface_classification(path, method, expected): + assert ratelimit._surface(path, method) == expected + + +# --------------------------------------------------------------------------- +# Middleware behavior (stub app exercising every surface) +# --------------------------------------------------------------------------- + + +@pytest.fixture +def stub_client(): + app = FastAPI() + app.middleware("http")(rate_limit_middleware) + + @app.get("/api/v1/memories") + def rest_endpoint(fail: bool = False): + if fail: + raise HTTPException(status_code=401) + return {"ok": True} + + @app.post("/oauth/token") + def token_endpoint(fail: bool = False): + if fail: + raise HTTPException(status_code=400) + return {"ok": True} + + @app.post("/oauth/authorize") + def consent_endpoint(fail: bool = False): + if fail: + raise HTTPException(status_code=401) + return {"ok": True} + + @app.get("/healthz") + def health_endpoint(): + raise HTTPException(status_code=401) + + return TestClient(app) + + +def _hammer(client, n, path, ip, method="GET"): + for _ in range(n): + resp = client.request( + method, path, params={"fail": "true"}, headers={"X-Forwarded-For": ip} + ) + return resp + + +def test_rest_failures_trigger_429(stub_client): + resp = _hammer(stub_client, 10, "/api/v1/memories", "9.9.9.1") + assert resp.status_code == 401 + resp = stub_client.get( + "/api/v1/memories", headers={"X-Forwarded-For": "9.9.9.1"} + ) + assert resp.status_code == 429 + assert int(resp.headers["Retry-After"]) >= 1 + + +def test_successes_never_count(stub_client): + for _ in range(30): + resp = stub_client.get( + "/api/v1/memories", headers={"X-Forwarded-For": "9.9.9.2"} + ) + assert resp.status_code == 200 + + +def test_other_ips_unaffected(stub_client): + _hammer(stub_client, 10, "/api/v1/memories", "9.9.9.3") + resp = stub_client.get( + "/api/v1/memories", headers={"X-Forwarded-For": "9.9.9.4"} + ) + assert resp.status_code == 200 + + +def test_unclassified_paths_never_limited(stub_client): 
+ for _ in range(30): + resp = stub_client.get("/healthz", headers={"X-Forwarded-For": "9.9.9.5"}) + assert resp.status_code == 401 # the stub 401s, but never 429 + + +def test_oauth_token_400_counts_as_failure(stub_client): + resp = _hammer(stub_client, 10, "/oauth/token", "9.9.9.6", method="POST") + assert resp.status_code == 400 + resp = stub_client.post("/oauth/token", headers={"X-Forwarded-For": "9.9.9.6"}) + assert resp.status_code == 429 + + +def test_oauth_consent_stricter_limit(stub_client): + resp = _hammer(stub_client, 5, "/oauth/authorize", "9.9.9.7", method="POST") + assert resp.status_code == 401 + resp = stub_client.post( + "/oauth/authorize", headers={"X-Forwarded-For": "9.9.9.7"} + ) + assert resp.status_code == 429 + + +def test_surfaces_limited_independently(stub_client): + # Locking out the REST surface must not lock out the token endpoint. + _hammer(stub_client, 10, "/api/v1/memories", "9.9.9.8") + resp = stub_client.post("/oauth/token", headers={"X-Forwarded-For": "9.9.9.8"}) + assert resp.status_code == 200 + + +# --------------------------------------------------------------------------- +# Wired into the real app +# --------------------------------------------------------------------------- + + +def test_main_app_limits_bad_bearer_tokens(app_instance, mem): + client = TestClient(app_instance) + headers = { + "Authorization": "Bearer wrong-token", + "X-Forwarded-For": "8.8.8.1", + } + for _ in range(10): + resp = client.get("/api/v1/memories", headers=headers) + assert resp.status_code == 401 + resp = client.get("/api/v1/memories", headers=headers) + assert resp.status_code == 429 + # Even a valid token is locked out from that IP until the window expires. 
+ resp = client.get( + "/api/v1/memories", + headers={ + "Authorization": "Bearer test-bearer-token", + "X-Forwarded-For": "8.8.8.1", + }, + ) + assert resp.status_code == 429 + + +def test_main_app_healthz_unaffected_by_lockout(app_instance, mem, monkeypatch): + import app.main as main_mod + + async def _ok(): + return True + + monkeypatch.setattr(main_mod, "check_qdrant", _ok) + client = TestClient(app_instance) + headers = { + "Authorization": "Bearer wrong-token", + "X-Forwarded-For": "8.8.8.2", + } + for _ in range(11): + client.get("/api/v1/memories", headers=headers) + resp = client.get("/healthz", headers={"X-Forwarded-For": "8.8.8.2"}) + assert resp.status_code == 200