Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions .env.example
Original file line number Diff line number Diff line change
Expand Up @@ -29,5 +29,17 @@ OAUTH_SIGNING_KEY=
# uses a per-connector callback path under https://chatgpt.com/connector/oauth/).
OAUTH_ALLOWED_REDIRECT_URIS=https://claude.ai/api/mcp/auth_callback,https://cowork.com/api/mcp/auth_callback,https://chatgpt.com/connector/oauth/*

# Rate limiting (optional; defaults shown). Only failed auth attempts count.
# Set a *_FAILURES value to 0 to disable that surface's limiter.
# TRUST_FORWARDED_FOR=true is correct behind CapRover's nginx; set false when
# the app is exposed directly (the header would be attacker-controlled).
TRUST_FORWARDED_FOR=true
RATE_LIMIT_AUTH_FAILURES=10
RATE_LIMIT_AUTH_WINDOW_SECONDS=60
RATE_LIMIT_CONSENT_FAILURES=5
RATE_LIMIT_CONSENT_WINDOW_SECONDS=300
RATE_LIMIT_TOKEN_FAILURES=10
RATE_LIMIT_TOKEN_WINDOW_SECONDS=60

# Misc
LOG_LEVEL=INFO
12 changes: 12 additions & 0 deletions app/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,18 @@ class Settings(BaseSettings):
"https://chatgpt.com/connector/oauth/*"
)

# Rate limiting (brute-force protection on auth surfaces). Only *failed*
# attempts count; an IP over the limit is rejected with 429 until the
# window expires. Limits are per uvicorn worker. Set a *_failures value
# below 1 to disable limiting on that surface.
trust_forwarded_for: bool = True
rate_limit_auth_failures: int = 10
rate_limit_auth_window_seconds: float = 60.0
rate_limit_consent_failures: int = 5
rate_limit_consent_window_seconds: float = 300.0
rate_limit_token_failures: int = 10
rate_limit_token_window_seconds: float = 60.0

# Misc
log_level: str = "INFO"

Expand Down
6 changes: 6 additions & 0 deletions app/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from app.logging_setup import configure_logging
from app.mcp_server import build_mcp
from app.metrics import observe_request
from app.ratelimit import rate_limit_middleware
from app.rest import check_qdrant
from app.rest import router as rest_router

Expand Down Expand Up @@ -50,6 +51,11 @@ async def lifespan(app: FastAPI):

app = FastAPI(title="mem0 Memory Server", version="1.0.0", lifespan=lifespan)

# Registration order matters: later-registered middleware is outermost, so
# registering the rate limiter *before* log_requests makes log_requests wrap
# it. Rate-limited 429s therefore still get a request log line and a latency
# observation.
app.middleware("http")(rate_limit_middleware)


@app.middleware("http")
async def log_requests(request: Request, call_next):
Expand Down
22 changes: 22 additions & 0 deletions app/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,28 @@
)


# Brute-force signals, both labelled by auth surface ("rest", "mcp",
# "oauth_consent", "oauth_token"):
#   * AUTH_FAILURES - failed authentication attempts
#   * RATE_LIMITED  - requests rejected by the auth rate limiter
AUTH_FAILURES = Counter(
    name="auth_failures_total",
    documentation="Failed authentication attempts.",
    labelnames=["surface"],
)
RATE_LIMITED = Counter(
    name="rate_limited_requests_total",
    documentation="Requests rejected by the auth rate limiter.",
    labelnames=["surface"],
)


def observe_request(method: str, path: str, status: int, duration_s: float) -> None:
    """Record one handled request: bump the request counter and observe its duration.

    The status code is stringified before being used as a label value.
    """
    status_label = str(status)
    request_counter = REQUEST_COUNT.labels(
        method=method, path=path, status=status_label
    )
    request_counter.inc()
    latency_metric = REQUEST_LATENCY.labels(method=method, path=path)
    latency_metric.observe(duration_s)


def observe_auth_failure(surface: str) -> None:
    """Count one failed authentication attempt on the given auth surface."""
    failures_for_surface = AUTH_FAILURES.labels(surface=surface)
    failures_for_surface.inc()


def observe_rate_limited(surface: str) -> None:
    """Count one request rejected by the rate limiter on the given auth surface."""
    rejections_for_surface = RATE_LIMITED.labels(surface=surface)
    rejections_for_surface.inc()
180 changes: 180 additions & 0 deletions app/ratelimit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,180 @@
"""Per-IP rate limiting of failed authentication attempts.

Every surface that accepts a secret is an online-guessing oracle for
MEM0_API_KEY (or an OAuth code/refresh token): the REST bearer check, the MCP
token verifier, the OAuth consent form, and the OAuth token endpoint. This
module slows brute force to a crawl with a small in-process fixed-window
counter keyed by client IP. Only *failures* count toward the limit; an IP that
hits the limit is locked out of that surface (even with valid credentials)
until the window expires.

Limits are per uvicorn worker (the Dockerfile runs --workers 2), so the
effective limit is roughly workers x the configured value. That is fine for a
single-user service: the point is to turn millions of guesses per day into
dozens, not to enforce an exact quota.
"""

import math
import threading
import time
from functools import lru_cache

import structlog
from fastapi import Request
from fastapi.responses import JSONResponse

from app.config import get_settings
from app.metrics import observe_auth_failure, observe_rate_limited

_log = structlog.get_logger()

# Cap on tracked IPs per limiter; when reached, expired windows are pruned on
# the next recorded failure so an attacker rotating IPs can't grow the dict
# without bound.
_MAX_TRACKED_KEYS = 4096


class RateLimiter:
"""Fixed-window failure counter keyed by an arbitrary string (client IP).

A key with `max_failures` failures inside `window_seconds` is limited until
the window that contains its first failure expires. `max_failures < 1`
disables limiting entirely (operator opt-out).
"""

def __init__(self, max_failures: int, window_seconds: float):
self.max_failures = max_failures
self.window_seconds = window_seconds
self._lock = threading.Lock()
# key -> (window_start_monotonic, failure_count)
self._windows: dict[str, tuple[float, int]] = {}

def retry_after(self, key: str, now: float | None = None) -> float | None:
"""Seconds until `key` may retry, or None if it is not limited."""
if self.max_failures < 1:
return None
now = time.monotonic() if now is None else now
with self._lock:
entry = self._windows.get(key)
if entry is None:
return None
start, count = entry
if now - start >= self.window_seconds:
del self._windows[key]
return None
if count >= self.max_failures:
return self.window_seconds - (now - start)
return None

def record_failure(self, key: str, now: float | None = None) -> None:
if self.max_failures < 1:
return
now = time.monotonic() if now is None else now
with self._lock:
if len(self._windows) >= _MAX_TRACKED_KEYS:
self._prune(now)
start, count = self._windows.get(key, (now, 0))
if now - start >= self.window_seconds:
start, count = now, 0
self._windows[key] = (start, count + 1)

def _prune(self, now: float) -> None:
expired = [
key
for key, (start, _) in self._windows.items()
if now - start >= self.window_seconds
]
for key in expired:
del self._windows[key]

def reset(self) -> None:
with self._lock:
self._windows.clear()


def client_ip(request: Request) -> str:
    """Pick the client address used as the rate-limit key (best effort).

    When the `trust_forwarded_for` setting is on, the first comma-separated
    entry of the X-Forwarded-For header is used if it is non-empty after
    stripping. Otherwise (setting off, header missing, or first entry blank)
    the peer address of the connection is used, or "unknown" when the request
    carries no client information.

    The setting exists because behind CapRover's nginx the peer address is the
    proxy itself, while on a directly exposed deployment (e.g. docker-compose
    without a proxy) the header is attacker-controlled and must be ignored.

    NOTE(review): a proxy that *appends* to an incoming X-Forwarded-For (nginx
    `$proxy_add_x_forwarded_for` does) leaves the first entry client-supplied.
    Confirm the fronting proxy overwrites the header; if it appends, the first
    hop is spoofable and the entry added by the proxy should be used instead.
    """
    if get_settings().trust_forwarded_for:
        header_value = request.headers.get("x-forwarded-for", "")
        leading_hop, _, _ = header_value.partition(",")
        leading_hop = leading_hop.strip()
        if leading_hop:
            return leading_hop
    peer = request.client
    if peer:
        return peer.host
    return "unknown"


# Response statuses that count as an authentication failure, keyed by surface
# name (the values returned by _surface()). The middleware looks only at the
# status code, so any response with a listed status counts, whatever caused it.
#   * rest / mcp:     only ever 401 on a bad bearer token.
#   * oauth_consent:  the consent form returns 401 on a wrong API key.
#   * oauth_token:    the token endpoint returns 400 for guessed/expired codes
#                     and refresh tokens (per RFC 6749), so 400 counts there
#                     as well as 401.
_FAILURE_STATUSES = {
    "rest": {401},
    "mcp": {401},
    "oauth_consent": {401},
    "oauth_token": {400, 401},
}


def _surface(path: str, method: str) -> str | None:
p = path.rstrip("/") or "/"
if p == "/api/v1" or p.startswith("/api/v1/"):
return "rest"
if p == "/mcp":
return "mcp"
if p == "/oauth/authorize" and method == "POST":
# The GET form render checks no secret; only the POST submit does.
return "oauth_consent"
if p == "/oauth/token":
return "oauth_token"
return None


@lru_cache
def _limiter(surface: str) -> RateLimiter:
    """Return the limiter for `surface`, building it from settings on first use.

    The result is cached per surface name, so each surface keeps its own
    RateLimiter instance. "oauth_consent" and "oauth_token" use their dedicated
    settings; every other surface name uses the generic auth settings.
    """
    settings = get_settings()
    if surface == "oauth_token":
        max_failures = settings.rate_limit_token_failures
        window_seconds = settings.rate_limit_token_window_seconds
    elif surface == "oauth_consent":
        max_failures = settings.rate_limit_consent_failures
        window_seconds = settings.rate_limit_consent_window_seconds
    else:
        max_failures = settings.rate_limit_auth_failures
        window_seconds = settings.rate_limit_auth_window_seconds
    return RateLimiter(max_failures, window_seconds)


def reset_all() -> None:
    """Drop all limiter state. Test hook.

    Clears the cache on `_limiter`, so the next request on each surface builds
    a fresh, empty RateLimiter from whatever `get_settings()` returns then.
    NOTE(review): settings are only actually re-read if `get_settings()` is
    not itself cached (or its cache is cleared separately) - confirm.
    """
    _limiter.cache_clear()


async def rate_limit_middleware(request: Request, call_next):
    """HTTP middleware that throttles clients with too many failed auth attempts.

    Requests outside the auth surfaces pass straight through. On an auth
    surface, a client IP that is currently limited gets a 429 JSON response
    with a Retry-After header (whole seconds, at least 1) without the request
    reaching the rest of the app. Otherwise the request is processed normally
    and, if the response status is one of the surface's failure statuses, the
    failure is recorded against the IP, counted in metrics and logged.
    """
    surface = _surface(request.url.path, request.method)
    if surface is None:
        return await call_next(request)

    ip = client_ip(request)
    limiter = _limiter(surface)
    wait_seconds = limiter.retry_after(ip)

    if wait_seconds is None:
        response = await call_next(request)
        status = response.status_code
        if status in _FAILURE_STATUSES[surface]:
            limiter.record_failure(ip)
            observe_auth_failure(surface)
            _log.warning("auth_failure", surface=surface, ip=ip, status=status)
        return response

    # Locked out: reject before the request reaches any credential check.
    observe_rate_limited(surface)
    _log.warning("rate_limited", surface=surface, ip=ip)
    retry_header = str(max(1, math.ceil(wait_seconds)))
    body = {"detail": "Too many failed authentication attempts; try again later."}
    return JSONResponse(
        status_code=429,
        content=body,
        headers={"Retry-After": retry_header},
    )
13 changes: 13 additions & 0 deletions docs/DEVELOPER_GUIDE.md
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,10 @@ app/
wiring, build_verifier() selecting Phase 1 vs Phase 2.
oauth.py Phase 2 OAuth 2.1 + PKCE + DCR endpoints, JWT issuance, JWKS, AS/PR metadata.
oauth_store.py SQLite store for OAuth clients, auth codes, refresh tokens (/app/data/oauth.db).
ratelimit.py Per-IP fixed-window rate limiting of *failed* auth attempts, applied as the
rate_limit_middleware over four surfaces: REST (/api/v1), MCP (/mcp), OAuth
consent (POST /oauth/authorize) and token (/oauth/token). In-process state,
per worker; client_ip() honors X-Forwarded-For when TRUST_FORWARDED_FOR=true.
metrics.py Prometheus Counter + Histogram and observe_request().
logging_setup.py structlog configuration.
main.py Wiring: FastAPI app, request-logging middleware, router include, conditional
Expand Down Expand Up @@ -251,6 +255,15 @@ label cardinality bounded; unmatched 404s bucket under `__unmatched__`. Under mu
`/metrics` aggregates across workers when `PROMETHEUS_MULTIPROC_DIR` is set. The middleware never
reads the `Authorization` header, so tokens are never logged.

`rate_limit_middleware` (`app/ratelimit.py`) is registered *before* `log_requests` so logging wraps
it and 429s still get a log line. It counts failed-auth responses (401s; also 400s on
`/oauth/token`, which is what RFC 6749 returns for guessed codes) per client IP per surface, and
rejects an over-limit IP with 429 + `Retry-After` before the request reaches auth. Two metrics
track it: `auth_failures_total{surface}` and `rate_limited_requests_total{surface}` — a spike in
either is a brute-force signal. State is in-process and per worker by design (single-user service;
no shared Redis). Tests must not leak limiter state: `tests/conftest.py` has an autouse fixture
calling `ratelimit.reset_all()`.

## CI and deployment

- **CI** (`.github/workflows/ci.yml`) runs on pushes to `main` and on all PRs: installs deps, then
Expand Down
18 changes: 18 additions & 0 deletions docs/USER_GUIDE.md
Original file line number Diff line number Diff line change
Expand Up @@ -154,8 +154,26 @@ runs, or set these in the CapRover app's **App Configs** panel for production.
| `PUBLIC_BASE_URL` | yes | — | Public URL, e.g. `https://mem0.your-domain.com`. Used in OAuth metadata. |
| `OAUTH_SIGNING_KEY` | no | empty | PEM RSA private key. **Setting this enables Phase 2 OAuth.** Leave blank for Phase 1. |
| `OAUTH_ALLOWED_REDIRECT_URIS` | no | claude.ai + cowork + chatgpt callbacks | Comma-separated allowlist for OAuth redirect URIs. An entry ending in `*` is a **path-prefix** match locked to an exact scheme + host — it must be a full `scheme://host/path/` prefix (e.g. `https://chatgpt.com/connector/oauth/*`). Host-only or bare wildcards (`https://chatgpt.com*`, `https://*`, `*`) are **ignored**, so a misconfigured entry can't match lookalike hosts like `chatgpt.com.evil.com`. |
| `TRUST_FORWARDED_FOR` | no | `true` | Use the first `X-Forwarded-For` hop as the client IP for rate limiting. Correct behind CapRover's nginx; set to `false` if the app is exposed directly (no reverse proxy), where the header would be attacker-controlled. |
| `RATE_LIMIT_AUTH_FAILURES` | no | `10` | Failed bearer-token attempts (REST + MCP, per surface) allowed per IP per window before 429s. `0` disables. |
| `RATE_LIMIT_AUTH_WINDOW_SECONDS` | no | `60` | Window for the above. |
| `RATE_LIMIT_CONSENT_FAILURES` | no | `5` | Failed OAuth consent (wrong API key) attempts per IP per window. `0` disables. |
| `RATE_LIMIT_CONSENT_WINDOW_SECONDS` | no | `300` | Window for the above. |
| `RATE_LIMIT_TOKEN_FAILURES` | no | `10` | Failed `/oauth/token` exchanges per IP per window. `0` disables. |
| `RATE_LIMIT_TOKEN_WINDOW_SECONDS` | no | `60` | Window for the above. |
| `LOG_LEVEL` | no | `INFO` | Log level. |

#### Rate limiting

Failed authentication attempts are rate-limited per client IP to slow down brute-force guessing of
`MEM0_API_KEY` (and OAuth codes). Only **failures** count toward the limit — successful requests
never do — but once an IP crosses the limit, *all* its requests to that surface (even with the
correct token) get **HTTP 429** with a `Retry-After` header until the window expires. The four
surfaces (REST `/api/v1/...`, MCP `/mcp`, OAuth consent, OAuth token) are limited independently.
`/healthz` and `/metrics` are never limited. Limits are per uvicorn worker (the default image runs
2 workers), so the effective ceiling is about twice the configured value. If you lock yourself out
during testing, wait out the window or restart the app.

### Phases

- **Phase 1 (MVP)** — static bearer token only. Leave `OAUTH_SIGNING_KEY` blank. Works with Claude
Expand Down
11 changes: 11 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,17 @@
memory_mod.get_memory = lambda: FAKE_MEMORY


@pytest.fixture(autouse=True)
def _reset_rate_limiters():
    """Clear rate-limiter state before and after every test.

    Failed-auth counts are keyed by client IP and the TestClient always
    connects as "testclient", so without this one test's failures would leak
    into the next.
    """
    import app.ratelimit as limiter_module

    limiter_module.reset_all()
    yield
    limiter_module.reset_all()


@pytest.fixture
def mem():
FAKE_MEMORY.reset_mock()
Expand Down
Loading
Loading