Skip to content

Commit b0d694b

Browse files
feat(security): add CORS and TrustedHost middlewares
- Added `CORSMiddleware` and `TrustedHostMiddleware` to `kafka_app.py`. - Configured `ALLOWED_ORIGINS` and `ALLOWED_HOSTS` from environment variables. - Implemented logic to disable `allow_credentials` when `ALLOWED_ORIGINS` contains `*`. - Added test `tests/controller/test_middleware.py` to verify middleware configuration. Co-authored-by: lgcorzo <46710567+lgcorzo@users.noreply.github.com>
1 parent ab9e880 commit b0d694b

2 files changed

Lines changed: 50 additions & 0 deletions

File tree

src/regression_model_template/controller/kafka_app.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
import uvicorn
1313
import pandas as pd
1414
from fastapi import FastAPI, HTTPException
15+
from fastapi.middleware.cors import CORSMiddleware
16+
from fastapi.middleware.trustedhost import TrustedHostMiddleware
1517
from pydantic import BaseModel
1618

1719
from confluent_kafka import Producer, Consumer, KafkaError, Message
@@ -29,6 +31,8 @@
2931
DEFAULT_OUTPUT_TOPIC = os.getenv("DEFAULT_OUTPUT_TOPIC", "output_topic")
3032
DEFAULT_FASTAPI_HOST = os.getenv("DEFAULT_FASTAPI_HOST", "127.0.0.1")
3133
DEFAULT_FASTAPI_PORT = int(os.getenv("DEFAULT_FASTAPI_PORT", 8100))
34+
ALLOWED_ORIGINS = os.getenv("ALLOWED_ORIGINS", "*").split(",")
35+
ALLOWED_HOSTS = os.getenv("ALLOWED_HOSTS", "*").split(",")
3236
LOGGING_FORMAT = "%(asctime)s - %(levelname)s - %(message)s"
3337

3438

@@ -43,6 +47,23 @@
4347
version="1.0.0",
4448
)
4549

50+
# Security Middleware Configuration
51+
# If ALLOWED_ORIGINS contains *, allow_credentials must be False to prevent browser errors
52+
allow_credentials = "*" not in ALLOWED_ORIGINS
53+
54+
app.add_middleware(
55+
CORSMiddleware,
56+
allow_origins=ALLOWED_ORIGINS,
57+
allow_credentials=allow_credentials,
58+
allow_methods=["*"],
59+
allow_headers=["*"],
60+
)
61+
62+
app.add_middleware(
63+
TrustedHostMiddleware,
64+
allowed_hosts=ALLOWED_HOSTS,
65+
)
66+
4667

4768
# Data Models
4869
class PredictionRequest(BaseModel):
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
from starlette.middleware.cors import CORSMiddleware
2+
from starlette.middleware.trustedhost import TrustedHostMiddleware
3+
from regression_model_template.controller.kafka_app import app
4+
5+
6+
def test_middleware_configuration():
7+
"""Test that CORS and TrustedHost middlewares are configured."""
8+
middlewares = [m.cls for m in app.user_middleware]
9+
10+
# Check for CORSMiddleware
11+
assert CORSMiddleware in middlewares, "CORSMiddleware is missing"
12+
13+
# Check for TrustedHostMiddleware
14+
assert TrustedHostMiddleware in middlewares, "TrustedHostMiddleware is missing"
15+
16+
17+
def test_cors_middleware_options():
18+
"""Test CORSMiddleware configuration details."""
19+
cors_middleware = next((m for m in app.user_middleware if m.cls == CORSMiddleware), None)
20+
assert cors_middleware is not None
21+
22+
# We can inspect the options passed to the middleware
23+
# Note: In Starlette/FastAPI, middleware options are stored in .kwargs or .options
24+
# For CORSMiddleware, we expect allow_origins, allow_credentials, etc.
25+
26+
options = cors_middleware.kwargs
27+
assert "allow_origins" in options
28+
assert "allow_methods" in options
29+
assert "allow_headers" in options

0 commit comments

Comments
 (0)