Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion check_env.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import mlflow
import os
import sys

import mlflow
from confluent_kafka import Producer


def check_env():
print(f"Tracking URI: {mlflow.get_tracking_uri()}")
try:
Expand Down
4 changes: 3 additions & 1 deletion promote_model.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import os

import mlflow
from mlflow.tracking import MlflowClient
import os


def promote_model():
tracking_uri = os.getenv("MLFLOW_TRACKING_URI", "http://mlflow.llm-apps.svc.cluster.local:5000")
Expand Down
31 changes: 23 additions & 8 deletions src/regression_model_template/controller/kafka_app.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,26 @@
"""FastAPI and Kafka Service for Predictions with Logging."""

import json
import logging
import os
import signal
import threading
import logging
import time
import json
import typing as T
from typing import Any, Dict, Callable
from typing import Any, Callable, Dict

import uvicorn
import pandas as pd
import uvicorn
from confluent_kafka import Consumer, KafkaError, Message, Producer
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.middleware.trustedhost import TrustedHostMiddleware
from pydantic import BaseModel

from confluent_kafka import Producer, Consumer, KafkaError, Message

from regression_model_template.core.schemas import InputsSchema, Outputs
from regression_model_template.io import services, registries
from regression_model_template.io import registries, services
from regression_model_template.io.registries import CustomLoader


# Constants
DEFAULT_KAFKA_SERVER = os.getenv("DEFAULT_KAFKA_SERVER", "kafka_server:9092")
DEFAULT_GROUP_ID = os.getenv("DEFAULT_GROUP_ID", "llmops-regression")
Expand All @@ -29,6 +29,8 @@
DEFAULT_OUTPUT_TOPIC = os.getenv("DEFAULT_OUTPUT_TOPIC", "output_topic")
DEFAULT_FASTAPI_HOST = os.getenv("DEFAULT_FASTAPI_HOST", "127.0.0.1")
DEFAULT_FASTAPI_PORT = int(os.getenv("DEFAULT_FASTAPI_PORT", 8100))
ALLOWED_ORIGINS = os.getenv("ALLOWED_ORIGINS", "*").split(",")
ALLOWED_HOSTS = os.getenv("ALLOWED_HOSTS", "*").split(",")
LOGGING_FORMAT = "%(asctime)s - %(levelname)s - %(message)s"


Expand All @@ -43,6 +45,19 @@
version="1.0.0",
)

app.add_middleware(
CORSMiddleware,
allow_origins=ALLOWED_ORIGINS,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)

app.add_middleware(
TrustedHostMiddleware,
allowed_hosts=ALLOWED_HOSTS,
)


# Data Models
class PredictionRequest(BaseModel):
Expand Down
2 changes: 1 addition & 1 deletion src/regression_model_template/core/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@
import typing as T

import mlflow
from mlflow.metrics import MetricValue
import pandas as pd
import pydantic as pdt
from mlflow.metrics import MetricValue
from sklearn import metrics

from regression_model_template.core import models, schemas
Expand Down
15 changes: 7 additions & 8 deletions src/regression_model_template/io/services.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,27 +6,26 @@

import abc
import contextlib as ctx
import logging
import sys
import typing as T
from typing import ClassVar

import loguru
import logging
import mlflow
import mlflow.tracking as mt
import pydantic as pdt

from plyer import notification
from opentelemetry import trace
from opentelemetry._logs import set_logger_provider

from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import BatchSpanProcessor
from opentelemetry.exporter.otlp.proto.http._log_exporter import OTLPLogExporter
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter

from opentelemetry.sdk._logs import LoggerProvider, LoggingHandler
from opentelemetry.sdk._logs.export import BatchLogRecordProcessor
from opentelemetry.exporter.otlp.proto.http._log_exporter import OTLPLogExporter
from opentelemetry.sdk.resources import Resource
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import BatchSpanProcessor
from plyer import notification

from regression_model_template.io.osvariables import Env


Expand Down
2 changes: 1 addition & 1 deletion src/regression_model_template/jobs/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,12 @@

import typing as T

import pandas as pd
import pydantic as pdt

from regression_model_template.core import schemas
from regression_model_template.io import datasets, registries
from regression_model_template.jobs import base
import pandas as pd

# %% JOBS

Expand Down
7 changes: 3 additions & 4 deletions src/regression_model_template/scripts.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,12 @@
import json
import sys

from regression_model_template import settings
from regression_model_template.io import configs

# %% WARNINGS

import warnings

from regression_model_template import settings
from regression_model_template.io import configs

# disable annoying mlflow warnings
warnings.filterwarnings(action="ignore", category=UserWarning)

Expand Down
2 changes: 1 addition & 1 deletion src/regression_model_template/utils/searchers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
# %% IMPORTS

import abc
from typing import Union
import typing as T
from typing import Union

import pandas as pd
import pydantic as pdt
Expand Down
7 changes: 4 additions & 3 deletions tests/controller/simulated_integration_test.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import requests
import time
import subprocess
import os
import subprocess
import sys
import time

import requests

# This script will run the actual FastAPI app but mock the Kafka part to allow it to start
# without a real Kafka server.
Expand Down
14 changes: 6 additions & 8 deletions tests/controller/test_kafka_app.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,22 @@
import pytest
from unittest.mock import patch, MagicMock
import json
import os
import signal
from unittest.mock import MagicMock, patch

from fastapi import HTTPException

import pytest
from confluent_kafka import KafkaError

from fastapi import HTTPException

# Assuming the code you provided is in a file named 'app.py'
from regression_model_template.controller.kafka_app import (
DEFAULT_FASTAPI_HOST,
DEFAULT_FASTAPI_PORT,
FastAPIKafkaService,
PredictionRequest,
PredictionResponse,
app,
health_check,
predict,
app,
DEFAULT_FASTAPI_HOST,
DEFAULT_FASTAPI_PORT,
)


Expand Down
3 changes: 2 additions & 1 deletion tests/controller/test_kafka_app_leakage.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from unittest.mock import MagicMock
import json
from unittest.mock import MagicMock

from regression_model_template.controller.kafka_app import FastAPIKafkaService


Expand Down
5 changes: 3 additions & 2 deletions tests/controller/test_kafka_app_security.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import pytest
import asyncio
from unittest.mock import MagicMock, patch

import pytest
from fastapi import HTTPException
from regression_model_template.controller.kafka_app import (
PredictionRequest,
PredictionService,
predict,
)
import asyncio


def test_prediction_service_sanitization():
Expand Down
35 changes: 35 additions & 0 deletions tests/controller/test_middleware_security.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
from fastapi.middleware.cors import CORSMiddleware
from fastapi.middleware.trustedhost import TrustedHostMiddleware
from regression_model_template.controller.kafka_app import app


def test_cors_middleware_present():
"""Test that CORSMiddleware is added to the application."""
middleware_types = [m.cls for m in app.user_middleware]
assert CORSMiddleware in middleware_types


def test_trusted_host_middleware_present():
"""Test that TrustedHostMiddleware is added to the application."""
middleware_types = [m.cls for m in app.user_middleware]
assert TrustedHostMiddleware in middleware_types


def test_cors_configuration():
"""Test CORSMiddleware configuration."""
cors_middleware = next(m for m in app.user_middleware if m.cls == CORSMiddleware)
# In this environment, it might be .kwargs or .options
kwargs = getattr(cors_middleware, "kwargs", getattr(cors_middleware, "options", {}))

assert kwargs["allow_origins"] == ["*"]
assert kwargs["allow_credentials"] is True
assert kwargs["allow_methods"] == ["*"]
assert kwargs["allow_headers"] == ["*"]


def test_trusted_host_configuration():
"""Test TrustedHostMiddleware configuration."""
trusted_host_middleware = next(m for m in app.user_middleware if m.cls == TrustedHostMiddleware)
kwargs = getattr(trusted_host_middleware, "kwargs", getattr(trusted_host_middleware, "options", {}))

assert kwargs["allowed_hosts"] == ["*"]