Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 58 additions & 0 deletions tensormap-backend/app/rate_limiter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
"""Rate limiting utilities for API endpoints."""

import time
from collections import defaultdict

from fastapi import HTTPException, Request
from starlette.middleware.base import BaseHTTPMiddleware

from app.shared.logging_config import get_logger

logger = get_logger(__name__)


class RateLimiter:
"""Token bucket rate limiter keyed by client IP."""

def __init__(self, requests_per_minute: int = 60, window_seconds: int = 60):
self.requests_per_minute = requests_per_minute
self.window_seconds = window_seconds
self._clients: dict[str, list[float]] = defaultdict(list)

def _get_client_id(self, request: Request) -> str:
forwarded = request.headers.get("X-Forwarded-For")
if forwarded:
return forwarded.split(",")[0].strip()
client_host = request.client.host if request.client else "unknown"
return client_host

def check(self, request: Request) -> None:
"""Raise HTTPException 429 if the client has exceeded the rate limit."""
client_id = self._get_client_id(request)
now = time.time()
window_start = now - self.window_seconds

timestamps = self._clients[client_id]
timestamps[:] = [t for t in timestamps if t > window_start]

if len(timestamps) >= self.requests_per_minute:
logger.warning("Rate limit exceeded for client %s", client_id)
raise HTTPException(status_code=429, detail="Rate limit exceeded. Please slow down.")

timestamps.append(now)


class RateLimitMiddleware(BaseHTTPMiddleware):
"""Middleware that applies rate limiting to all incoming requests."""

def __init__(self, app, requests_per_minute: int = 60, window_seconds: int = 60):
super().__init__(app)
self.limiter = RateLimiter(
requests_per_minute=requests_per_minute,
window_seconds=window_seconds,
)

async def dispatch(self, request: Request, call_next):
self.limiter.check(request)
response = await call_next(request)
return response
119 changes: 119 additions & 0 deletions tensormap-backend/tests/test_rate_limiter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
"""Tests for the rate limiter utility."""

import time
from unittest.mock import MagicMock, patch

import pytest
from fastapi import HTTPException
from fastapi.testclient import TestClient
from starlette.applications import Starlette
from starlette.responses import JSONResponse
from starlette.routing import Route

from app.rate_limiter import RateLimiter, RateLimitMiddleware


class TestRateLimiter:
def test_allows_requests_under_limit(self):
"""Requests within the limit should be allowed."""
limiter = RateLimiter(requests_per_minute=5, window_seconds=60)
mock_request = MagicMock()
mock_request.headers = {}
mock_request.client.host = "1.2.3.4"

for _ in range(5):
limiter.check(mock_request)

def test_blocks_requests_over_limit(self):
"""Requests exceeding the limit should raise 429."""
limiter = RateLimiter(requests_per_minute=3, window_seconds=60)
mock_request = MagicMock()
mock_request.headers = {}
mock_request.client.host = "1.2.3.4"

for _ in range(3):
limiter.check(mock_request)

with pytest.raises(HTTPException) as exc:
limiter.check(mock_request)
assert exc.value.status_code == 429

def test_allows_requests_after_window_expires(self):
"""After the window expires, requests should be allowed again."""
limiter = RateLimiter(requests_per_minute=2, window_seconds=60)
mock_request = MagicMock()
mock_request.headers = {}
mock_request.client.host = "1.2.3.4"

limiter.check(mock_request)
limiter.check(mock_request)

with pytest.raises(HTTPException):
limiter.check(mock_request)

future = time.time() + 120
with patch.object(time, "time", return_value=future):
limiter.check(mock_request)

def test_uses_forwarded_for_header(self):
"""If X-Forwarded-For is present, it should be used as the client ID."""
limiter = RateLimiter(requests_per_minute=1, window_seconds=60)
mock_request = MagicMock()
mock_request.headers = {"X-Forwarded-For": "5.6.7.8, 9.10.11.12"}
mock_request.client.host = "1.2.3.4"

limiter.check(mock_request)
with pytest.raises(HTTPException):
limiter.check(mock_request)

def test_different_clients_have_separate_limits(self):
"""Two different clients should not share rate limit state."""
limiter = RateLimiter(requests_per_minute=2, window_seconds=60)

client_a = MagicMock()
client_a.headers = {}
client_a.client.host = "1.1.1.1"

client_b = MagicMock()
client_b.headers = {}
client_b.client.host = "2.2.2.2"

limiter.check(client_a)
limiter.check(client_a)

with pytest.raises(HTTPException):
limiter.check(client_a)

limiter.check(client_b)
limiter.check(client_b)


class TestRateLimitMiddleware:
def test_middleware_allows_normal_requests(self):
"""Normal requests within the limit should succeed."""
app = Starlette(
routes=[
Route("/health", lambda r: JSONResponse({"ok": True})),
],
)
app.add_middleware(RateLimitMiddleware, requests_per_minute=100, window_seconds=60)
client = TestClient(app)

response = client.get("/health")
assert response.status_code == 200

def test_middleware_blocks_excessive_requests(self):
"""Requests exceeding the limit should get 429."""
app = Starlette(
routes=[
Route("/health", lambda r: JSONResponse({"ok": True})),
],
)
app.add_middleware(RateLimitMiddleware, requests_per_minute=2, window_seconds=60)
client = TestClient(app)

client.get("/health")
client.get("/health")

response = client.get("/health")
assert response.status_code == 429
Loading