From 04ef3a63abddfc792ab93ef7b4b5ae16e48a2f97 Mon Sep 17 00:00:00 2001 From: saurabh batham Date: Sun, 14 Jun 2026 22:24:31 +0530 Subject: [PATCH] feat(ai-investigation): operationalize with FastAPI entry point and LLM fallback - Add main.py with lifespan (creates OpenAILLMProvider, InvestigationPipeline) - Add health and investigate routers (GET /health, POST /investigate/run) - Add generate_structured fallback: try response_format first, retry with prompt-based JSON instruction when model doesn't support structured output - Add _extract_json to strip think tags, markdown fences, and leading text - Add scripts/simulate_incident.py for end-to-end testing - Update Dockerfile with EXPOSE 8002 and uvicorn CMD - Add env_file and extra=ignore to OpenAI and investigation configs - Update .env with OPENAI_BASE_URL and infrastructure credentials --- .env.example | 1 + infrastructure/openai/openai_llm_provider.py | 47 ++++- .../openai/tests/test_openai_llm_provider.py | 28 ++- scripts/simulate_incident.py | 163 ++++++++++++++++++ services/ai-investigation-service/Dockerfile | 4 +- .../ai-investigation-service/app/config.py | 2 +- services/ai-investigation-service/app/main.py | 43 +++++ .../app/routers/__init__.py | 0 .../app/routers/health.py | 22 +++ .../app/routers/investigate.py | 22 +++ 10 files changed, 325 insertions(+), 7 deletions(-) create mode 100644 scripts/simulate_incident.py create mode 100644 services/ai-investigation-service/app/main.py create mode 100644 services/ai-investigation-service/app/routers/__init__.py create mode 100644 services/ai-investigation-service/app/routers/health.py create mode 100644 services/ai-investigation-service/app/routers/investigate.py diff --git a/.env.example b/.env.example index 7a96783..1c3385c 100644 --- a/.env.example +++ b/.env.example @@ -46,6 +46,7 @@ SERVICE_NAME=rootpilot-service # ── OpenAI LLM Provider (prefix OPENAI_) ─────────────────────────────────── OPENAI_API_KEY= # REQUIRED for AI investigation +# OPENAI_BASE_URL= # defaults to https://api.openai.com/v1 # OPENAI_DEFAULT_MODEL=gpt-4o # OPENAI_DEFAULT_TEMPERATURE=0.1 # OPENAI_DEFAULT_MAX_TOKENS=4096 diff --git a/infrastructure/openai/openai_llm_provider.py b/infrastructure/openai/openai_llm_provider.py index 04c968a..a946995 100644 --- a/infrastructure/openai/openai_llm_provider.py +++ b/infrastructure/openai/openai_llm_provider.py @@ -1,5 +1,6 @@ """OpenAI LLM provider adapter implementing the LLMProvider contract.""" +import json from typing import Any, TypeVar from pydantic import BaseModel, Field @@ -11,9 +12,12 @@ class OpenAIProviderConfig(BaseSettings): - model_config = {"env_prefix": "OPENAI_"} + model_config = {"env_prefix": "OPENAI_", "env_file": ".env", "env_file_encoding": "utf-8", "extra": "ignore"} api_key: str = Field(default="", description="OpenAI API key.") + base_url: str | None = Field( + default=None, description="OpenAI API base URL (defaults to https://api.openai.com/v1)." + ) default_model: str = Field(default="gpt-4o", description="Default model identifier.") default_temperature: float = Field(default=0.1, ge=0.0, le=2.0, description="Default generation temperature.") default_max_tokens: int = Field(default=4096, ge=1, description="Default max tokens per response.") @@ -35,7 +39,10 @@ def __init__(self, config: OpenAIProviderConfig | None = None) -> None: async def start(self) -> None: from openai import AsyncOpenAI - self._client = AsyncOpenAI(api_key=self._config.api_key, timeout=self._config.timeout_seconds) + kwargs: dict[str, Any] = {"api_key": self._config.api_key, "timeout": self._config.timeout_seconds} + if self._config.base_url is not None: + kwargs["base_url"] = self._config.base_url + self._client = AsyncOpenAI(**kwargs) async def close(self) -> None: self._client = None @@ -63,7 +70,23 @@ async def generate_structured( "json_schema": {"name": schema.__name__, "strict": True, "schema": schema.model_json_schema()}, } response = await self._client.chat.completions.create(**kwargs) - return schema.model_validate_json(response.choices[0].message.content) + raw = response.choices[0].message.content or "" + try: + cleaned = self._extract_json(raw) + return schema.model_validate_json(cleaned) + except (ValueError, Exception): + pass + + # Fallback — model doesn't support response_format, retry with JSON instruction in system prompt + schema_json = json.dumps(schema.model_json_schema(), indent=2) + instruction = f"\n\nYou MUST respond with valid JSON matching this schema:\n```json\n{schema_json}\n```\nReturn ONLY the JSON object — no markdown, no explanations, no tags." + fallback = [ + LLMMessage(role=m.role, content=m.content + instruction) if m.role == "system" else m for m in messages + ] + fbk = self._build_kwargs(fallback, model) + response = await self._client.chat.completions.create(**fbk) + cleaned = self._extract_json(response.choices[0].message.content or "") + return schema.model_validate_json(cleaned) async def embed(self, text: str, model: str | None = None) -> list[float]: model = model or "text-embedding-3-small" @@ -84,6 +107,24 @@ def _build_kwargs( "max_tokens": max_tokens or self._config.default_max_tokens, } + @staticmethod + def _extract_json(content: str) -> str: + """Strip non-JSON wrappers (think tags, markdown fences, leading text) from LLM output.""" + import re + + # Remove ... blocks + content = re.sub(r".*?", "", content, flags=re.DOTALL).strip() + # Extract JSON from markdown code fence if present + m = re.search(r"```(?:json)?\s*(\{.*?\})\s*```", content, flags=re.DOTALL) + if m: + return m.group(1).strip() + # Fallback: find first { and last } + start = content.find("{") + end = content.rfind("}") + if start != -1 and end != -1 and end > start: + return content[start : end + 1] + return content + @staticmethod def _to_response(raw: Any) -> LLMResponse: choice = raw.choices[0] diff --git a/infrastructure/openai/tests/test_openai_llm_provider.py b/infrastructure/openai/tests/test_openai_llm_provider.py index 133fce7..1276b29 100644 --- a/infrastructure/openai/tests/test_openai_llm_provider.py +++ b/infrastructure/openai/tests/test_openai_llm_provider.py @@ -15,7 +15,7 @@ class _FakeSchema(BaseModel): @pytest.fixture def config() -> OpenAIProviderConfig: - return OpenAIProviderConfig(api_key="test-key") + return OpenAIProviderConfig(api_key="test-key", base_url=None, _env_file=None) # type: ignore[call-arg] @pytest.fixture @@ -25,11 +25,12 @@ def provider(config: OpenAIProviderConfig) -> OpenAILLMProvider: class TestOpenAIProviderConfig: def test_defaults(self) -> None: - cfg = OpenAIProviderConfig() + cfg = OpenAIProviderConfig(_env_file=None) # type: ignore[call-arg] assert cfg.default_model == "gpt-4o" assert cfg.default_temperature == 0.1 assert cfg.default_max_tokens == 4096 assert cfg.timeout_seconds == 60 + assert cfg.base_url is None def test_env_prefix(self) -> None: assert OpenAIProviderConfig.model_config.get("env_prefix") == "OPENAI_" @@ -43,6 +44,13 @@ async def test_start_creates_client(self, provider: OpenAILLMProvider) -> None: mock.assert_called_once_with(api_key="test-key", timeout=60) assert provider._client is not None + async def test_start_passes_base_url_when_set(self) -> None: + cfg = OpenAIProviderConfig(api_key="test-key", base_url="https://custom.openai.com/v1") + provider = OpenAILLMProvider(cfg) + with patch("openai.AsyncOpenAI") as mock: + await provider.start() + mock.assert_called_once_with(api_key="test-key", timeout=60, base_url="https://custom.openai.com/v1") + async def test_close_clears_client(self, provider: OpenAILLMProvider) -> None: provider._client = MagicMock() await provider.close() @@ -81,6 +89,22 @@ async def test_generate_structured_returns_parsed_schema(self, provider: OpenAIL assert isinstance(result, _FakeSchema) assert result.result == "ok" + def test_extract_json_strips_think_tags(self) -> None: + raw = 'Let me analyze the data.{"result": "ok"}' + assert OpenAILLMProvider._extract_json(raw) == '{"result": "ok"}' + + def test_extract_json_strips_markdown_fence(self) -> None: + raw = '```json\n{"result": "ok"}\n```' + assert OpenAILLMProvider._extract_json(raw) == '{"result": "ok"}' + + def test_extract_json_falls_back_to_braces(self) -> None: + raw = 'Some text before {"result": "ok"} and after' + assert OpenAILLMProvider._extract_json(raw) == '{"result": "ok"}' + + def test_extract_json_clean_json_passthrough(self) -> None: + raw = '{"result": "ok"}' + assert OpenAILLMProvider._extract_json(raw) == '{"result": "ok"}' + async def test_embed_returns_float_list(self, provider: OpenAILLMProvider) -> None: provider._client = MagicMock() mock_embedding = MagicMock() diff --git a/scripts/simulate_incident.py b/scripts/simulate_incident.py new file mode 100644 index 0000000..6fa460b --- /dev/null +++ b/scripts/simulate_incident.py @@ -0,0 +1,163 @@ +"""Simulate a realistic database latency spike incident end-to-end. + +Sends telemetry to the ingestion service (port 8000), then builds an +incident context and runs the investigation pipeline (port 8002). + +Usage: + python -m scripts.simulate_incident + # or: python scripts/simulate_incident.py +""" + +import sys +from datetime import UTC, datetime, timedelta + +import httpx + +INGESTION_URL = "http://localhost:8000/api/v1/ingest" +INVESTIGATION_URL = "http://localhost:8002/investigate/run" + + +def _ts(offset_seconds: int = 0) -> str: + return (datetime.now(UTC) - timedelta(seconds=offset_seconds)).isoformat() + + +TELEMETRY_BATCH = [ + # ── Phase 1: DB CPU climb (T-120s to T-60s) ────────────────────── + {"metric": "db.cpu.usage", "value": 72.0, "unit": "%", "source": "postgres-primary", "tags": {"service": "db", "region": "us-east-1"}, "timestamp": _ts(120)}, + {"metric": "db.cpu.usage", "value": 78.0, "unit": "%", "source": "postgres-primary", "tags": {"service": "db", "region": "us-east-1"}, "timestamp": _ts(90)}, + {"metric": "db.cpu.usage", "value": 85.0, "unit": "%", "source": "postgres-primary", "tags": {"service": "db", "region": "us-east-1"}, "timestamp": _ts(60)}, + # ── Phase 2: Query latency degrades (T-60s to T-30s) ───────────── + {"metric": "db.query.latency", "value": 450.0, "unit": "ms", "source": "postgres-primary", "tags": {"service": "db", "query": "SELECT * FROM orders"}, "timestamp": _ts(60)}, + {"metric": "db.query.latency", "value": 1200.0, "unit": "ms", "source": "postgres-primary", "tags": {"service": "db", "query": "SELECT * FROM orders"}, "timestamp": _ts(45)}, + {"metric": "db.query.latency", "value": 2300.0, "unit": "ms", "source": "postgres-primary", "tags": {"service": "db", "query": "SELECT * FROM orders"}, "timestamp": _ts(30)}, + {"metric": "db.connections.active", "value": 89.0, "unit": "count", "source": "postgres-primary", "tags": {"service": "db", "pool": "primary"}, "timestamp": _ts(30)}, + # ── Phase 3: API timeouts (T-30s to T-10s) ─────────────────────── + {"metric": "api.request.latency", "value": 3200.0, "unit": "ms", "source": "api-gateway", "tags": {"service": "api", "endpoint": "/api/orders"}, "timestamp": _ts(30)}, + {"metric": "api.error.rate", "value": 12.5, "unit": "%", "source": "api-gateway", "tags": {"service": "api", "error": "timeout"}, "timestamp": _ts(25)}, + {"metric": "api.error.rate", "value": 28.0, "unit": "%", "source": "api-gateway", "tags": {"service": "api", "error": "timeout"}, "timestamp": _ts(15)}, + {"metric": "api.request.latency", "value": 5100.0, "unit": "ms", "source": "api-gateway", "tags": {"service": "api", "endpoint": "/api/orders"}, "timestamp": _ts(10)}, + # ── Phase 4: Gateway 502s (T-10s to now) ───────────────────────── + {"metric": "gateway.upstream.errors", "value": 42.0, "unit": "count", "source": "gateway", "tags": {"service": "gateway", "upstream": "api", "status": "502"}, "timestamp": _ts(10)}, + {"metric": "gateway.upstream.errors", "value": 87.0, "unit": "count", "source": "gateway", "tags": {"service": "gateway", "upstream": "api", "status": "502"}, "timestamp": _ts(5)}, + {"metric": "gateway.latency", "value": 6200.0, "unit": "ms", "source": "gateway", "tags": {"service": "gateway", "endpoint": "/api/orders"}, "timestamp": _ts(3)}, +] + + +def send_telemetry(client: httpx.Client) -> list[str]: + ids: list[str] = [] + for point in TELEMETRY_BATCH: + resp = client.post(INGESTION_URL, json=point) + resp.raise_for_status() + data = resp.json() + ids.append(data["event_id"]) + print(f" ✓ {point['metric']:30s} = {point['value']:>8} {point['unit'] or '':4s} → {data['event_id'][:8]}...") + return ids + + +def build_incident_context(event_ids: list[str]) -> dict: + return { + "incident_id": f"sim-{datetime.now(UTC).strftime('%Y%m%d-%H%M%S')}", + "primary_service": "postgres-primary", + "severity": "CRITICAL", + "title": "Database latency spike cascading to API timeouts and gateway 502 errors", + "detected_at": datetime.now(UTC).isoformat(), + "event_count": len(event_ids), + "service_count": 3, + "trace_count": 0, + "ungrouped_events": [], + "correlation_groups": [ + { + "group_id": "g-db", + "event_ids": event_ids[:7], + "composite_score": 0.89, + "signals": ["time_proximity"], + "services": ["postgres-primary"], + }, + { + "group_id": "g-api", + "event_ids": event_ids[7:11], + "composite_score": 0.82, + "signals": ["time_proximity"], + "services": ["api-gateway"], + }, + { + "group_id": "g-gateway", + "event_ids": event_ids[11:], + "composite_score": 0.75, + "signals": ["time_proximity", "dependency_chain"], + "services": ["gateway"], + }, + ], + "impacts": [ + { + "service": "api-gateway", + "upstream_causes": ["postgres-primary"], + "downstream_impact": ["gateway"], + "propagation_paths": [["postgres-primary", "api-gateway", "gateway"]], + }, + { + "service": "gateway", + "upstream_causes": ["postgres-primary", "api-gateway"], + "downstream_impact": [], + "propagation_paths": [], + }, + ], + } + + +def run_investigation(client: httpx.Client, context: dict) -> dict: + resp = client.post(INVESTIGATION_URL, json=context, timeout=120) + resp.raise_for_status() + return resp.json() + + +def check_service(url: str, name: str) -> bool: + try: + resp = httpx.get(url.replace("/api/v1/ingest", "/health").replace("/investigate/run", "/health"), timeout=5) + return resp.is_success + except httpx.RequestError as e: + print(f" ✗ {name} unreachable ({e})") + return False + + +def main() -> None: + print("Checking services...") + health_ok = ( + check_service(INGESTION_URL, "ingestion (port 8000)") + and check_service(INVESTIGATION_URL, "investigation (port 8002)") + ) + if not health_ok: + sys.exit(1) + + print("\n── Scenario: Database Latency Spike ──────────────────────") + print(" Primary database CPU spikes → query latency jumps to 2.3s\n") + + with httpx.Client() as client: + print("Sending 14 telemetry points to ingestion...") + event_ids = send_telemetry(client) + print(f"\n ✓ {len(event_ids)} events accepted\n") + + print("Building incident context and running investigation...") + context = build_incident_context(event_ids) + result = run_investigation(client, context) + + print(f"\n── Investigation Result ──────────────────────────────") + summary = result["summary"] + print(f" Incident: {summary['incident_id']}") + print(f" Title: {summary['title']}") + print(f" Duration: {result['duration_ms']:.0f}ms") + print(f" Confidence: {summary['overall_confidence']:.0%}") + print(f"\n Root Causes ({len(summary['root_causes'])}):") + for rc in summary["root_causes"]: + print(f" • {rc['service']:25s} ({rc['confidence']:.0%} confidence)") + print(f" {rc['explanation'][:150]}") + print(f"\n Remediation ({len(summary.get('remediation', []))}):") + for step in summary.get("remediation", []): + print(f" • [{step['priority']:>8}] {step['action']}") + print(f"\n Timeline:") + print(f" {summary['progression']['timeline_summary'][:200]}") + print(f"\nDone.") + + +if __name__ == "__main__": + main() diff --git a/services/ai-investigation-service/Dockerfile b/services/ai-investigation-service/Dockerfile index 38fbdef..43a3bc9 100644 --- a/services/ai-investigation-service/Dockerfile +++ b/services/ai-investigation-service/Dockerfile @@ -11,4 +11,6 @@ RUN pip install --no-cache-dir -r requirements.txt \ COPY services/ai-investigation-service/app ./app COPY services/ai-investigation-service/workflows ./workflows -CMD ["python", "-c", "print('ai-investigation-service scaffold ready')"] +EXPOSE 8002 + +CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8002"] diff --git a/services/ai-investigation-service/app/config.py b/services/ai-investigation-service/app/config.py index 271e753..dc80da3 100644 --- a/services/ai-investigation-service/app/config.py +++ b/services/ai-investigation-service/app/config.py @@ -2,7 +2,7 @@ class InvestigationServiceConfig(BaseSettings): - model_config = {"env_prefix": "INVESTIGATION_"} + model_config = {"env_prefix": "INVESTIGATION_", "env_file": ".env", "env_file_encoding": "utf-8", "extra": "ignore"} service_name: str = "ai-investigation-service" environment: str = "development" diff --git a/services/ai-investigation-service/app/main.py b/services/ai-investigation-service/app/main.py new file mode 100644 index 0000000..ac75f03 --- /dev/null +++ b/services/ai-investigation-service/app/main.py @@ -0,0 +1,43 @@ +from collections.abc import AsyncIterator +from contextlib import asynccontextmanager + +from fastapi import FastAPI + +from app.config import InvestigationServiceConfig +from app.pipeline import InvestigationPipeline +from app.routers import health, investigate +from infrastructure.openai.openai_llm_provider import OpenAILLMProvider, OpenAIProviderConfig +from shared.config import load_settings + +logger = __import__("logging").getLogger(__name__) + + +@asynccontextmanager +async def lifespan(app: FastAPI) -> AsyncIterator[None]: + settings: InvestigationServiceConfig = load_settings(InvestigationServiceConfig) + app.state.settings = settings + + provider = OpenAILLMProvider(OpenAIProviderConfig()) + await provider.start() + app.state.llm = provider + app.state.pipeline = InvestigationPipeline(provider) + + logger.info("Service started", extra={"service": settings.service_name}) + yield + + await provider.close() + logger.info("Service stopped", extra={"service": settings.service_name}) + + +def create_app() -> FastAPI: + app = FastAPI( + title="ai-investigation-service", + version="0.1.0", + lifespan=lifespan, + ) + app.include_router(health.router) + app.include_router(investigate.router) + return app + + +app = create_app() diff --git a/services/ai-investigation-service/app/routers/__init__.py b/services/ai-investigation-service/app/routers/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/services/ai-investigation-service/app/routers/health.py b/services/ai-investigation-service/app/routers/health.py new file mode 100644 index 0000000..d1a7c9b --- /dev/null +++ b/services/ai-investigation-service/app/routers/health.py @@ -0,0 +1,22 @@ +from fastapi import APIRouter +from pydantic import BaseModel + +from app.config import InvestigationServiceConfig + +router = APIRouter(tags=["health"]) + + +class HealthResponse(BaseModel): + status: str + service: str + environment: str + + +@router.get("/health", response_model=HealthResponse) +async def health() -> HealthResponse: + cfg = InvestigationServiceConfig() + return HealthResponse( + status="healthy", + service=cfg.service_name, + environment=cfg.environment, + ) diff --git a/services/ai-investigation-service/app/routers/investigate.py b/services/ai-investigation-service/app/routers/investigate.py new file mode 100644 index 0000000..111dd2c --- /dev/null +++ b/services/ai-investigation-service/app/routers/investigate.py @@ -0,0 +1,22 @@ +from fastapi import APIRouter, Depends, HTTPException, Request + +from app.pipeline import InvestigationPipeline +from shared.domain.incident.context.models import IncidentContext +from shared.domain.investigation.models import InvestigationResult + +router = APIRouter(prefix="/investigate", tags=["investigate"]) + + +def _get_pipeline(request: Request) -> InvestigationPipeline: + pipeline: InvestigationPipeline | None = getattr(request.app.state, "pipeline", None) + if pipeline is None: + raise HTTPException(status_code=503, detail="Investigation pipeline not available") + return pipeline + + +@router.post("/run", response_model=InvestigationResult) +async def run_investigation( + context: IncidentContext, + pipeline: InvestigationPipeline = Depends(_get_pipeline), +) -> InvestigationResult: + return await pipeline.run(context)