diff --git a/src/regression_model_template/controller/kafka_app.py b/src/regression_model_template/controller/kafka_app.py index b04c716..2cd5aed 100644 --- a/src/regression_model_template/controller/kafka_app.py +++ b/src/regression_model_template/controller/kafka_app.py @@ -12,6 +12,8 @@ import uvicorn import pandas as pd from fastapi import FastAPI, HTTPException +from fastapi.middleware.cors import CORSMiddleware +from fastapi.middleware.trustedhost import TrustedHostMiddleware from pydantic import BaseModel from confluent_kafka import Producer, Consumer, KafkaError, Message @@ -29,6 +31,10 @@ DEFAULT_OUTPUT_TOPIC = os.getenv("DEFAULT_OUTPUT_TOPIC", "output_topic") DEFAULT_FASTAPI_HOST = os.getenv("DEFAULT_FASTAPI_HOST", "127.0.0.1") DEFAULT_FASTAPI_PORT = int(os.getenv("DEFAULT_FASTAPI_PORT", 8100)) +# Security Configuration +ALLOWED_ORIGINS = os.getenv("ALLOWED_ORIGINS", "*").split(",") +ALLOWED_HOSTS = os.getenv("ALLOWED_HOSTS", "*").split(",") + LOGGING_FORMAT = "%(asctime)s - %(levelname)s - %(message)s" @@ -43,6 +49,20 @@ version="1.0.0", ) +# Add Security Middleware +app.add_middleware( + TrustedHostMiddleware, + allowed_hosts=ALLOWED_HOSTS, +) + +app.add_middleware( + CORSMiddleware, + allow_origins=ALLOWED_ORIGINS, + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + # Data Models class PredictionRequest(BaseModel): diff --git a/tests/controller/test_kafka_app_middleware.py b/tests/controller/test_kafka_app_middleware.py new file mode 100644 index 0000000..0d8c48f --- /dev/null +++ b/tests/controller/test_kafka_app_middleware.py @@ -0,0 +1,61 @@ +from regression_model_template.controller import kafka_app +from fastapi.middleware.cors import CORSMiddleware +from fastapi.middleware.trustedhost import TrustedHostMiddleware +import os +from unittest.mock import patch +from importlib import reload + + +def test_middleware_present(): + """Test that security middlewares are correctly added to the FastAPI app with defaults.""" + # Ensure we start with a clean state (defaults) + reload(kafka_app) + app = kafka_app.app + + # Create a map of middleware class to the middleware object + middlewares = {m.cls: m for m in app.user_middleware} + + # Assert CORSMiddleware is present + assert CORSMiddleware in middlewares, "CORSMiddleware should be added to the app" + cors_middleware = middlewares[CORSMiddleware] + + # Verify CORS configuration (checking kwargs as per memory instruction) + assert cors_middleware.kwargs["allow_origins"] == ["*"] + assert cors_middleware.kwargs["allow_credentials"] is True + assert cors_middleware.kwargs["allow_methods"] == ["*"] + assert cors_middleware.kwargs["allow_headers"] == ["*"] + + # Assert TrustedHostMiddleware is present + assert TrustedHostMiddleware in middlewares, "TrustedHostMiddleware should be added to the app" + trusted_host_middleware = middlewares[TrustedHostMiddleware] + + # Verify TrustedHost configuration + assert trusted_host_middleware.kwargs["allowed_hosts"] == ["*"] + + +def test_middleware_configuration_from_env(): + """Test that middleware configuration respects environment variables.""" + # Set environment variables + with patch.dict( + os.environ, + { + "ALLOWED_ORIGINS": "https://example.com,https://api.example.com", + "ALLOWED_HOSTS": "example.com,api.example.com", + }, + ): + # Reload the module to pick up new env vars + reload(kafka_app) + app = kafka_app.app + + middlewares = {m.cls: m for m in app.user_middleware} + + # Verify CORS configuration + cors_middleware = middlewares[CORSMiddleware] + assert set(cors_middleware.kwargs["allow_origins"]) == {"https://example.com", "https://api.example.com"} + + # Verify TrustedHost configuration + trusted_host_middleware = middlewares[TrustedHostMiddleware] + assert set(trusted_host_middleware.kwargs["allowed_hosts"]) == {"example.com", "api.example.com"} + + # Restore default state for other tests if necessary + reload(kafka_app)