diff --git a/.github/workflows/build-and-push.yml b/.github/workflows/build-and-push.yml new file mode 100644 index 0000000..67ae5f8 --- /dev/null +++ b/.github/workflows/build-and-push.yml @@ -0,0 +1,63 @@ +name: Build and Push Docker Image + +on: + push: + branches: + - main + - add_load_factor + tags: + - 'v*' + pull_request: + branches: + - main + workflow_dispatch: + +env: + REGISTRY: ghcr.io + IMAGE_NAME: ${{ github.repository }} + +jobs: + build-and-push: + runs-on: ubuntu-latest + permissions: + contents: read + packages: write + + steps: + - name: Checkout repository + uses: actions/checkout@v4 + + - name: Set up Docker Buildx + uses: docker/setup-buildx-action@v3 + + - name: Log in to GitHub Container Registry + if: github.event_name != 'pull_request' + uses: docker/login-action@v3 + with: + registry: ${{ env.REGISTRY }} + username: ${{ github.actor }} + password: ${{ secrets.GITHUB_TOKEN }} + + - name: Extract metadata (tags, labels) + id: meta + uses: docker/metadata-action@v5 + with: + images: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }} + tags: | + type=ref,event=branch + type=ref,event=pr + type=semver,pattern={{version}} + type=semver,pattern={{major}}.{{minor}} + type=semver,pattern={{major}} + type=sha,prefix={{branch}}- + type=raw,value=latest,enable={{is_default_branch}} + + - name: Build and push Docker image + uses: docker/build-push-action@v5 + with: + context: . + push: ${{ github.event_name != 'pull_request' }} + tags: ${{ steps.meta.outputs.tags }} + labels: ${{ steps.meta.outputs.labels }} + cache-from: type=gha + cache-to: type=gha,mode=max \ No newline at end of file diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000..10204eb --- /dev/null +++ b/Dockerfile @@ -0,0 +1,15 @@ +FROM ghcr.io/astral-sh/uv:python3.12-bookworm-slim +ENV UV_COMPILE_BYTECODE=1 UV_LINK_MODE=copy +ENV UV_PYTHON_DOWNLOADS=0 + +WORKDIR / +COPY ./openmockllm/ /openmockllm +RUN --mount=type=cache,target=/root/.cache/uv \ + uv venv +RUN --mount=type=cache,target=/root/.cache/uv \ + --mount=type=bind,source=pyproject.toml,target=pyproject.toml \ + uv pip install "." + +ENV PATH="/.venv/bin:$PATH" + +ENTRYPOINT ["python", "-m", "openmockllm.main", "--backend", "vllm", "--port", "8000"] \ No newline at end of file diff --git a/openmockllm/logger.py b/openmockllm/logger.py index bae96e2..06ad716 100644 --- a/openmockllm/logger.py +++ b/openmockllm/logger.py @@ -1,5 +1,167 @@ -from logging import Formatter, Logger, StreamHandler, getLogger +import json +import logging import sys +import time +import uuid +from datetime import datetime, timezone +from logging import Formatter, Logger, StreamHandler, getLogger +from typing import Any, Dict, Optional, List +from typing import Callable + +from fastapi import Request, Response +from starlette.middleware.base import BaseHTTPMiddleware + + +class JsonFormatter(logging.Formatter): + AVAILABLE_FIELDS = ["timestamp", "level", "logger", "message", "file", "line", "function", "process", "thread", "exception"] + INTERNAL_FIELDS = { + "name", + "msg", + "args", + "created", + "filename", + "funcName", + "levelname", + "levelno", + "lineno", + "module", + "msecs", + "message", + "pathname", + "process", + "processName", + "relativeCreated", + "thread", + "threadName", + "exc_info", + "exc_text", + "stack_info", + } + + def __init__( + self, + fields: Optional[List[str]] = None, + include_extra: bool = True, + ): + super().__init__() + self.fields: List[str] = fields if fields else self.AVAILABLE_FIELDS.copy() + self.include_extra = include_extra + + def format(self, record: logging.LogRecord) -> str: + field_mapping = { + "timestamp": datetime.now(timezone.utc).isoformat().replace("+00:00", "Z"), + "level": record.levelname, + "logger": record.name, + "message": record.getMessage(), + "file": record.pathname, + "line": record.lineno, + "function": record.funcName, + "process": record.process, + "process_name": record.processName, + "thread": record.thread, + "thread_name": record.threadName, + "exception": self.formatException(record.exc_info) if record.exc_info else None, + } + + log_data = self.filtrer_fields_from_log_record(field_mapping, record) + + return json.dumps(log_data, ensure_ascii=False) + + def filtrer_fields_from_log_record(self, field_mapping: dict[str, Any], record: logging.LogRecord) -> dict[str, Any]: + available_data: Dict[str, Any] = {} + record_as_dict: Dict[str, Any] = record.__dict__ + + for field in self.fields: + if field in field_mapping.keys(): + available_data[field] = field_mapping[field] + else: + if field in record_as_dict.keys(): + available_data[field] = record_as_dict[field] + fields = list(available_data.keys()) + [field for field in self.INTERNAL_FIELDS if field not in list(available_data.keys())] + extra_fields = [extra_field for extra_field in record_as_dict.keys() if extra_field not in fields] + for extra_field in extra_fields: + available_data[extra_field] = record_as_dict[extra_field] + return available_data + + +def init_json_logger( + name: str, + level: str = "INFO", + fields: Optional[List[str]] = None, + include_extra: bool = True, +) -> logging.Logger: + logger = logging.getLogger(name) + logger.setLevel(getattr(logging, level.upper())) + + logger.handlers.clear() + + handler = logging.StreamHandler(sys.stdout) + handler.setFormatter( + JsonFormatter( + fields=fields, + include_extra=include_extra, + ) + ) + logger.addHandler(handler) + + logger.propagate = False + + return logger + + +class LoggingMiddleware(BaseHTTPMiddleware): + async def dispatch(self, request: Request, call_next: Callable) -> Response: + request_id = str(uuid.uuid4()) + + request.state.request_id = request_id + + start_time = time.time() + + logger = logging.getLogger("api") + logger.info( + "Requête entrante", + extra={ + "request_id": request_id, + "method": request.method, + "path": request.url.path, + "client_ip": request.client.host if request.client else None, + "user_agent": request.headers.get("user-agent"), + }, + ) + + try: + response = await call_next(request) + + duration = time.time() - start_time + + logger.info( + "Requête complétée", + extra={ + "request_id": request_id, + "method": request.method, + "path": request.url.path, + "status_code": response.status_code, + "duration_ms": round(duration * 1000, 2), + }, + ) + + response.headers["X-Request-ID"] = request_id + + return response + + except Exception as e: + duration = time.time() - start_time + + logger.exception( + "Erreur lors du traitement de la requête", + extra={ + "request_id": request_id, + "method": request.method, + "path": request.url.path, + "duration_ms": round(duration * 1000, 2), + }, + ) + raise class ColoredFormatter(Formatter): diff --git a/openmockllm/main.py b/openmockllm/main.py index a12ec9c..cabbae7 100644 --- a/openmockllm/main.py +++ b/openmockllm/main.py @@ -1,12 +1,25 @@ import argparse +import logging +import os -from fastapi import FastAPI import uvicorn +from fastapi import FastAPI -from openmockllm.logger import init_logger +from openmockllm.logger import LoggingMiddleware, init_json_logger from openmockllm.settings import settings -logger = init_logger("openmockllm") + +def init_logger_from_env(name: str) -> logging.Logger: + settings.log_level = os.getenv("LOG_LEVEL", "INFO") + fields = os.getenv("LOG_FIELDS") + settings.log_fields = [f.strip() for f in fields.split(",")] if fields else None + settings.log_include_extra = os.getenv("LOG_INCLUDE_EXTRA", "true").lower() == "true" + return init_json_logger( + name, + level=settings.log_level, + fields=settings.log_fields, + include_extra=settings.log_include_extra, + ) def parse_args(): @@ -26,7 +39,7 @@ def parse_args(): return parser.parse_args() -def create_app(args): +def create_app(args, logger: logging.Logger): """Create and configure FastAPI application""" if args.api_key: settings.api_key = args.api_key @@ -42,6 +55,7 @@ def create_app(args): description="Mock LLM API Server supporting vllm and mistral", version="1.0.0", ) + app.add_middleware(LoggingMiddleware) # Store configuration in app state app.state.backend = args.backend @@ -86,7 +100,7 @@ def create_app(args): def main(): """Main entry point""" args = parse_args() - + logger = init_logger_from_env("api") logger.info("=" * 60) logger.info("OpenMockLLM API Server") logger.info("=" * 60) @@ -101,12 +115,18 @@ def main(): logger.info(f"Faker seed instance: {args.faker_seed_instance if args.faker_seed_instance else 'Disabled'}") logger.info("=" * 60) - app = create_app(args) + app = create_app(args, logger) logger.info(f"Starting server on http://0.0.0.0:{args.port}") logger.info(f"API documentation: http://0.0.0.0:{args.port}/docs") - uvicorn.run(app, host="0.0.0.0", port=args.port) + uvicorn.run( + app, + host="0.0.0.0", + port=args.port, + log_config=None, + access_log=False, + ) if __name__ == "__main__": diff --git a/openmockllm/settings.py b/openmockllm/settings.py index 70b3d99..04a33aa 100644 --- a/openmockllm/settings.py +++ b/openmockllm/settings.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Optional, List class Settings: @@ -6,6 +6,8 @@ class Settings: tiktoken_encoder: str = "cl100k_base" faker_langage: str = "fr_FR" faker_seed_instance: Optional[int] = None + log_fields: List[str] = [] + log_level: str = "INFO" settings = Settings() diff --git a/openmockllm/vllm/endpoints/chat.py b/openmockllm/vllm/endpoints/chat.py index 06c9462..d814597 100644 --- a/openmockllm/vllm/endpoints/chat.py +++ b/openmockllm/vllm/endpoints/chat.py @@ -5,7 +5,7 @@ from fastapi import APIRouter, Depends, Request from fastapi.responses import StreamingResponse -from openmockllm.logger import init_logger +from openmockllm.main import init_logger_from_env from openmockllm.security import check_api_key from openmockllm.vllm.exceptions import NotFoundError from openmockllm.vllm.schemas.chat import ( @@ -15,9 +15,18 @@ Message, Usage, ) -from openmockllm.vllm.utils.chat import calculate_realistic_delay, count_tokens, generate_random_response, generate_stream_response +from openmockllm.vllm.utils.chat import ( + calculate_realistic_delay, + count_tokens, + generate_random_response, + generate_stream_response, + get_active_requests, + increment_active_requests, + decrement_active_requests, +) + +logger = init_logger_from_env(__name__) -logger = init_logger(__name__) router = APIRouter(prefix="/v1", tags=["chat"]) @@ -25,31 +34,43 @@ async def chat_completions(request: Request, body: ChatRequest): """Handle chat completion requests with streaming and non-streaming support""" request_id = f"chatcmpl-{uuid.uuid4().hex}" + start_time = time.perf_counter() + try: + increment_active_requests() + logger.info("beginning of request", extra={"active requests": get_active_requests()}) + + if body.model and body.model != request.app.state.model_name: + raise NotFoundError(f"The model `{body.model}` does not exist.") + last_message = body.messages[-1].content if body.messages else "" + simulated_response = generate_random_response(last_message, body.temperature, body.max_tokens) + + if body.stream: + return StreamingResponse(generate_stream_response(simulated_response, body.model, body.temperature), media_type="text/event-stream") + else: + prompt_tokens = sum(count_tokens(msg.content) for msg in body.messages) + completion_tokens = count_tokens(simulated_response) + + delay = calculate_realistic_delay(completion_tokens, body.temperature) + await asyncio.sleep(delay) - # Use the model from the request or fall back to the default - model = body.model or request.app.state.model_name - - # Validate model if specified - if body.model and body.model != request.app.state.model_name: - raise NotFoundError(f"The model `{body.model}` does not exist.") - last_message = body.messages[-1].content if body.messages else "" - simulated_response = generate_random_response(last_message, body.temperature, body.max_tokens) - - if body.stream: - return StreamingResponse(generate_stream_response(simulated_response, body.model, body.temperature), media_type="text/event-stream") - else: - prompt_tokens = sum(count_tokens(msg.content) for msg in body.messages) - completion_tokens = count_tokens(simulated_response) - - delay = calculate_realistic_delay(completion_tokens, body.temperature) - await asyncio.sleep(delay) - - response = ChatResponse( - id=request_id, - object="chat.completion", - created=int(time.time()), - model=body.model, - choices=[ChatResponseChoice(index=0, message=Message(role="assistant", content=simulated_response), finish_reason="stop")], - usage=Usage(prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, total_tokens=prompt_tokens + completion_tokens), + response = ChatResponse( + id=request_id, + object="chat.completion", + created=int(time.time()), + model=body.model, + choices=[ChatResponseChoice(index=0, message=Message(role="assistant", content=simulated_response), finish_reason="stop")], + usage=Usage(prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, total_tokens=prompt_tokens + completion_tokens), + ) + return response + finally: + decrement_active_requests() + execution_time = time.perf_counter() - start_time + logger.info( + "end of request", + extra={ + "requestMessage": last_message, + "activeRequests": get_active_requests(), + "execution_time_seconds": round(execution_time, 3), + "request_id": request_id, + }, ) - return response diff --git a/openmockllm/vllm/utils/chat.py b/openmockllm/vllm/utils/chat.py index 4881e1a..0784240 100644 --- a/openmockllm/vllm/utils/chat.py +++ b/openmockllm/vllm/utils/chat.py @@ -1,6 +1,7 @@ import asyncio import json import random +import threading import time from typing import AsyncGenerator, Optional import uuid @@ -12,6 +13,41 @@ fake = Faker("fr_FR") fake.seed_instance() +_active_requests = 0 +_lock = threading.Lock() + + +def increment_active_requests() -> int: + global _active_requests + with _lock: + _active_requests += 1 + return _active_requests + + +def decrement_active_requests() -> int: + global _active_requests + with _lock: + _active_requests = max(0, _active_requests - 1) + return _active_requests + + +def get_active_requests() -> int: + global _active_requests + with _lock: + return _active_requests + + +def calculate_load_factor() -> float: + active = get_active_requests() + if active <= 1: + return 1.0 + elif active <= 5: + return 1.0 + (active - 1) * 0.2 + elif active <= 10: + return 1.8 + (active - 5) * 0.4 + else: + return 3.8 + (active - 10) * 0.6 + def count_tokens(text: str) -> int: return len(tokenizer.encode(text)) @@ -53,7 +89,10 @@ def calculate_realistic_delay(completion_tokens: int, temperature: Optional[floa total_delay = (base_delay + startup_delay) * variation - return max(0.1, total_delay) + load_factor = calculate_load_factor() + total_delay *= load_factor + + return total_delay async def generate_stream_response(response_text: str, model: str, temperature: Optional[float] = 0.7) -> AsyncGenerator[str, None]: @@ -86,7 +125,6 @@ async def generate_stream_response(response_text: str, model: str, temperature: "choices": [{"index": 0, "delta": {"content": token_text}, "finish_reason": None}], } yield f"data: {json.dumps(chunk)}\n\n" - await asyncio.sleep(token_delay) final_chunk = {