diff --git a/tensormap-backend/app/rate_limiter.py b/tensormap-backend/app/rate_limiter.py new file mode 100644 index 00000000..b44090b3 --- /dev/null +++ b/tensormap-backend/app/rate_limiter.py @@ -0,0 +1,58 @@ +"""Rate limiting utilities for API endpoints.""" + +import time +from collections import defaultdict + +from fastapi import HTTPException, Request +from starlette.middleware.base import BaseHTTPMiddleware + +from app.shared.logging_config import get_logger + +logger = get_logger(__name__) + + +class RateLimiter: + """Token bucket rate limiter keyed by client IP.""" + + def __init__(self, requests_per_minute: int = 60, window_seconds: int = 60): + self.requests_per_minute = requests_per_minute + self.window_seconds = window_seconds + self._clients: dict[str, list[float]] = defaultdict(list) + + def _get_client_id(self, request: Request) -> str: + forwarded = request.headers.get("X-Forwarded-For") + if forwarded: + return forwarded.split(",")[0].strip() + client_host = request.client.host if request.client else "unknown" + return client_host + + def check(self, request: Request) -> None: + """Raise HTTPException 429 if the client has exceeded the rate limit.""" + client_id = self._get_client_id(request) + now = time.time() + window_start = now - self.window_seconds + + timestamps = self._clients[client_id] + timestamps[:] = [t for t in timestamps if t > window_start] + + if len(timestamps) >= self.requests_per_minute: + logger.warning("Rate limit exceeded for client %s", client_id) + raise HTTPException(status_code=429, detail="Rate limit exceeded. Please slow down.") + + timestamps.append(now) + + +class RateLimitMiddleware(BaseHTTPMiddleware): + """Middleware that applies rate limiting to all incoming requests.""" + + def __init__(self, app, requests_per_minute: int = 60, window_seconds: int = 60): + super().__init__(app) + self.limiter = RateLimiter( + requests_per_minute=requests_per_minute, + window_seconds=window_seconds, + ) + + async def dispatch(self, request: Request, call_next): + self.limiter.check(request) + response = await call_next(request) + return response diff --git a/tensormap-backend/tests/test_rate_limiter.py b/tensormap-backend/tests/test_rate_limiter.py new file mode 100644 index 00000000..46bdf200 --- /dev/null +++ b/tensormap-backend/tests/test_rate_limiter.py @@ -0,0 +1,119 @@ +"""Tests for the rate limiter utility.""" + +import time +from unittest.mock import MagicMock, patch + +import pytest +from fastapi import HTTPException +from fastapi.testclient import TestClient +from starlette.applications import Starlette +from starlette.responses import JSONResponse +from starlette.routing import Route + +from app.rate_limiter import RateLimiter, RateLimitMiddleware + + +class TestRateLimiter: + def test_allows_requests_under_limit(self): + """Requests within the limit should be allowed.""" + limiter = RateLimiter(requests_per_minute=5, window_seconds=60) + mock_request = MagicMock() + mock_request.headers = {} + mock_request.client.host = "1.2.3.4" + + for _ in range(5): + limiter.check(mock_request) + + def test_blocks_requests_over_limit(self): + """Requests exceeding the limit should raise 429.""" + limiter = RateLimiter(requests_per_minute=3, window_seconds=60) + mock_request = MagicMock() + mock_request.headers = {} + mock_request.client.host = "1.2.3.4" + + for _ in range(3): + limiter.check(mock_request) + + with pytest.raises(HTTPException) as exc: + limiter.check(mock_request) + assert exc.value.status_code == 429 + + def test_allows_requests_after_window_expires(self): + """After the window expires, requests should be allowed again.""" + limiter = RateLimiter(requests_per_minute=2, window_seconds=60) + mock_request = MagicMock() + mock_request.headers = {} + mock_request.client.host = "1.2.3.4" + + limiter.check(mock_request) + limiter.check(mock_request) + + with pytest.raises(HTTPException): + limiter.check(mock_request) + + future = time.time() + 120 + with patch.object(time, "time", return_value=future): + limiter.check(mock_request) + + def test_uses_forwarded_for_header(self): + """If X-Forwarded-For is present, it should be used as the client ID.""" + limiter = RateLimiter(requests_per_minute=1, window_seconds=60) + mock_request = MagicMock() + mock_request.headers = {"X-Forwarded-For": "5.6.7.8, 9.10.11.12"} + mock_request.client.host = "1.2.3.4" + + limiter.check(mock_request) + with pytest.raises(HTTPException): + limiter.check(mock_request) + + def test_different_clients_have_separate_limits(self): + """Two different clients should not share rate limit state.""" + limiter = RateLimiter(requests_per_minute=2, window_seconds=60) + + client_a = MagicMock() + client_a.headers = {} + client_a.client.host = "1.1.1.1" + + client_b = MagicMock() + client_b.headers = {} + client_b.client.host = "2.2.2.2" + + limiter.check(client_a) + limiter.check(client_a) + + with pytest.raises(HTTPException): + limiter.check(client_a) + + limiter.check(client_b) + limiter.check(client_b) + + +class TestRateLimitMiddleware: + def test_middleware_allows_normal_requests(self): + """Normal requests within the limit should succeed.""" + app = Starlette( + routes=[ + Route("/health", lambda r: JSONResponse({"ok": True})), + ], + ) + app.add_middleware(RateLimitMiddleware, requests_per_minute=100, window_seconds=60) + client = TestClient(app) + + response = client.get("/health") + assert response.status_code == 200 + + def test_middleware_blocks_excessive_requests(self): + """Requests exceeding the limit should get 429.""" + app = Starlette( + routes=[ + Route("/health", lambda r: JSONResponse({"ok": True})), + ], + ) + app.add_middleware(RateLimitMiddleware, requests_per_minute=2, window_seconds=60) + client = TestClient(app) + + client.get("/health") + client.get("/health") + + response = client.get("/health") + assert response.status_code == 429