diff --git a/.env.example b/.env.example
index 7a96783..1c3385c 100644
--- a/.env.example
+++ b/.env.example
@@ -46,6 +46,7 @@ SERVICE_NAME=rootpilot-service
# ── OpenAI LLM Provider (prefix OPENAI_) ───────────────────────────────────
OPENAI_API_KEY= # REQUIRED for AI investigation
+# OPENAI_BASE_URL= # defaults to https://api.openai.com/v1
# OPENAI_DEFAULT_MODEL=gpt-4o
# OPENAI_DEFAULT_TEMPERATURE=0.1
# OPENAI_DEFAULT_MAX_TOKENS=4096
diff --git a/infrastructure/openai/openai_llm_provider.py b/infrastructure/openai/openai_llm_provider.py
index 04c968a..a946995 100644
--- a/infrastructure/openai/openai_llm_provider.py
+++ b/infrastructure/openai/openai_llm_provider.py
@@ -1,5 +1,6 @@
"""OpenAI LLM provider adapter implementing the LLMProvider contract."""
+import json
from typing import Any, TypeVar
from pydantic import BaseModel, Field
@@ -11,9 +12,12 @@
class OpenAIProviderConfig(BaseSettings):
- model_config = {"env_prefix": "OPENAI_"}
+ model_config = {"env_prefix": "OPENAI_", "env_file": ".env", "env_file_encoding": "utf-8", "extra": "ignore"}
api_key: str = Field(default="", description="OpenAI API key.")
+ base_url: str | None = Field(
+ default=None, description="OpenAI API base URL (defaults to https://api.openai.com/v1)."
+ )
default_model: str = Field(default="gpt-4o", description="Default model identifier.")
default_temperature: float = Field(default=0.1, ge=0.0, le=2.0, description="Default generation temperature.")
default_max_tokens: int = Field(default=4096, ge=1, description="Default max tokens per response.")
@@ -35,7 +39,10 @@ def __init__(self, config: OpenAIProviderConfig | None = None) -> None:
async def start(self) -> None:
from openai import AsyncOpenAI
- self._client = AsyncOpenAI(api_key=self._config.api_key, timeout=self._config.timeout_seconds)
+ kwargs: dict[str, Any] = {"api_key": self._config.api_key, "timeout": self._config.timeout_seconds}
+ if self._config.base_url is not None:
+ kwargs["base_url"] = self._config.base_url
+ self._client = AsyncOpenAI(**kwargs)
async def close(self) -> None:
self._client = None
@@ -63,7 +70,23 @@ async def generate_structured(
"json_schema": {"name": schema.__name__, "strict": True, "schema": schema.model_json_schema()},
}
response = await self._client.chat.completions.create(**kwargs)
- return schema.model_validate_json(response.choices[0].message.content)
+ raw = response.choices[0].message.content or ""
+ try:
+ cleaned = self._extract_json(raw)
+ return schema.model_validate_json(cleaned)
+ except (ValueError, Exception):
+ pass
+
+ # Fallback — model doesn't support response_format, retry with JSON instruction in system prompt
+ schema_json = json.dumps(schema.model_json_schema(), indent=2)
+ instruction = f"\n\nYou MUST respond with valid JSON matching this schema:\n```json\n{schema_json}\n```\nReturn ONLY the JSON object — no markdown, no explanations, no tags."
+ fallback = [
+ LLMMessage(role=m.role, content=m.content + instruction) if m.role == "system" else m for m in messages
+ ]
+ fbk = self._build_kwargs(fallback, model)
+ response = await self._client.chat.completions.create(**fbk)
+ cleaned = self._extract_json(response.choices[0].message.content or "")
+ return schema.model_validate_json(cleaned)
async def embed(self, text: str, model: str | None = None) -> list[float]:
model = model or "text-embedding-3-small"
@@ -84,6 +107,24 @@ def _build_kwargs(
"max_tokens": max_tokens or self._config.default_max_tokens,
}
+ @staticmethod
+ def _extract_json(content: str) -> str:
+ """Strip non-JSON wrappers (think tags, markdown fences, leading text) from LLM output."""
+ import re
+
+ # Remove ... blocks
+ content = re.sub(r".*?", "", content, flags=re.DOTALL).strip()
+ # Extract JSON from markdown code fence if present
+ m = re.search(r"```(?:json)?\s*(\{.*?\})\s*```", content, flags=re.DOTALL)
+ if m:
+ return m.group(1).strip()
+ # Fallback: find first { and last }
+ start = content.find("{")
+ end = content.rfind("}")
+ if start != -1 and end != -1 and end > start:
+ return content[start : end + 1]
+ return content
+
@staticmethod
def _to_response(raw: Any) -> LLMResponse:
choice = raw.choices[0]
diff --git a/infrastructure/openai/tests/test_openai_llm_provider.py b/infrastructure/openai/tests/test_openai_llm_provider.py
index 133fce7..1276b29 100644
--- a/infrastructure/openai/tests/test_openai_llm_provider.py
+++ b/infrastructure/openai/tests/test_openai_llm_provider.py
@@ -15,7 +15,7 @@ class _FakeSchema(BaseModel):
@pytest.fixture
def config() -> OpenAIProviderConfig:
- return OpenAIProviderConfig(api_key="test-key")
+ return OpenAIProviderConfig(api_key="test-key", base_url=None, _env_file=None) # type: ignore[call-arg]
@pytest.fixture
@@ -25,11 +25,12 @@ def provider(config: OpenAIProviderConfig) -> OpenAILLMProvider:
class TestOpenAIProviderConfig:
def test_defaults(self) -> None:
- cfg = OpenAIProviderConfig()
+ cfg = OpenAIProviderConfig(_env_file=None) # type: ignore[call-arg]
assert cfg.default_model == "gpt-4o"
assert cfg.default_temperature == 0.1
assert cfg.default_max_tokens == 4096
assert cfg.timeout_seconds == 60
+ assert cfg.base_url is None
def test_env_prefix(self) -> None:
assert OpenAIProviderConfig.model_config.get("env_prefix") == "OPENAI_"
@@ -43,6 +44,13 @@ async def test_start_creates_client(self, provider: OpenAILLMProvider) -> None:
mock.assert_called_once_with(api_key="test-key", timeout=60)
assert provider._client is not None
+ async def test_start_passes_base_url_when_set(self) -> None:
+ cfg = OpenAIProviderConfig(api_key="test-key", base_url="https://custom.openai.com/v1")
+ provider = OpenAILLMProvider(cfg)
+ with patch("openai.AsyncOpenAI") as mock:
+ await provider.start()
+ mock.assert_called_once_with(api_key="test-key", timeout=60, base_url="https://custom.openai.com/v1")
+
async def test_close_clears_client(self, provider: OpenAILLMProvider) -> None:
provider._client = MagicMock()
await provider.close()
@@ -81,6 +89,22 @@ async def test_generate_structured_returns_parsed_schema(self, provider: OpenAIL
assert isinstance(result, _FakeSchema)
assert result.result == "ok"
+ def test_extract_json_strips_think_tags(self) -> None:
+ raw = 'Let me analyze the data.{"result": "ok"}'
+ assert OpenAILLMProvider._extract_json(raw) == '{"result": "ok"}'
+
+ def test_extract_json_strips_markdown_fence(self) -> None:
+ raw = '```json\n{"result": "ok"}\n```'
+ assert OpenAILLMProvider._extract_json(raw) == '{"result": "ok"}'
+
+ def test_extract_json_falls_back_to_braces(self) -> None:
+ raw = 'Some text before {"result": "ok"} and after'
+ assert OpenAILLMProvider._extract_json(raw) == '{"result": "ok"}'
+
+ def test_extract_json_clean_json_passthrough(self) -> None:
+ raw = '{"result": "ok"}'
+ assert OpenAILLMProvider._extract_json(raw) == '{"result": "ok"}'
+
async def test_embed_returns_float_list(self, provider: OpenAILLMProvider) -> None:
provider._client = MagicMock()
mock_embedding = MagicMock()
diff --git a/scripts/simulate_incident.py b/scripts/simulate_incident.py
new file mode 100644
index 0000000..6fa460b
--- /dev/null
+++ b/scripts/simulate_incident.py
@@ -0,0 +1,163 @@
+"""Simulate a realistic database latency spike incident end-to-end.
+
+Sends telemetry to the ingestion service (port 8000), then builds an
+incident context and runs the investigation pipeline (port 8002).
+
+Usage:
+ python -m scripts.simulate_incident
+ # or: python scripts/simulate_incident.py
+"""
+
+import sys
+from datetime import UTC, datetime, timedelta
+
+import httpx
+
+INGESTION_URL = "http://localhost:8000/api/v1/ingest"
+INVESTIGATION_URL = "http://localhost:8002/investigate/run"
+
+
+def _ts(offset_seconds: int = 0) -> str:
+ return (datetime.now(UTC) - timedelta(seconds=offset_seconds)).isoformat()
+
+
+TELEMETRY_BATCH = [
+ # ── Phase 1: DB CPU climb (T-120s to T-60s) ──────────────────────
+ {"metric": "db.cpu.usage", "value": 72.0, "unit": "%", "source": "postgres-primary", "tags": {"service": "db", "region": "us-east-1"}, "timestamp": _ts(120)},
+ {"metric": "db.cpu.usage", "value": 78.0, "unit": "%", "source": "postgres-primary", "tags": {"service": "db", "region": "us-east-1"}, "timestamp": _ts(90)},
+ {"metric": "db.cpu.usage", "value": 85.0, "unit": "%", "source": "postgres-primary", "tags": {"service": "db", "region": "us-east-1"}, "timestamp": _ts(60)},
+ # ── Phase 2: Query latency degrades (T-60s to T-30s) ─────────────
+ {"metric": "db.query.latency", "value": 450.0, "unit": "ms", "source": "postgres-primary", "tags": {"service": "db", "query": "SELECT * FROM orders"}, "timestamp": _ts(60)},
+ {"metric": "db.query.latency", "value": 1200.0, "unit": "ms", "source": "postgres-primary", "tags": {"service": "db", "query": "SELECT * FROM orders"}, "timestamp": _ts(45)},
+ {"metric": "db.query.latency", "value": 2300.0, "unit": "ms", "source": "postgres-primary", "tags": {"service": "db", "query": "SELECT * FROM orders"}, "timestamp": _ts(30)},
+ {"metric": "db.connections.active", "value": 89.0, "unit": "count", "source": "postgres-primary", "tags": {"service": "db", "pool": "primary"}, "timestamp": _ts(30)},
+ # ── Phase 3: API timeouts (T-30s to T-10s) ───────────────────────
+ {"metric": "api.request.latency", "value": 3200.0, "unit": "ms", "source": "api-gateway", "tags": {"service": "api", "endpoint": "/api/orders"}, "timestamp": _ts(30)},
+ {"metric": "api.error.rate", "value": 12.5, "unit": "%", "source": "api-gateway", "tags": {"service": "api", "error": "timeout"}, "timestamp": _ts(25)},
+ {"metric": "api.error.rate", "value": 28.0, "unit": "%", "source": "api-gateway", "tags": {"service": "api", "error": "timeout"}, "timestamp": _ts(15)},
+ {"metric": "api.request.latency", "value": 5100.0, "unit": "ms", "source": "api-gateway", "tags": {"service": "api", "endpoint": "/api/orders"}, "timestamp": _ts(10)},
+ # ── Phase 4: Gateway 502s (T-10s to now) ─────────────────────────
+ {"metric": "gateway.upstream.errors", "value": 42.0, "unit": "count", "source": "gateway", "tags": {"service": "gateway", "upstream": "api", "status": "502"}, "timestamp": _ts(10)},
+ {"metric": "gateway.upstream.errors", "value": 87.0, "unit": "count", "source": "gateway", "tags": {"service": "gateway", "upstream": "api", "status": "502"}, "timestamp": _ts(5)},
+ {"metric": "gateway.latency", "value": 6200.0, "unit": "ms", "source": "gateway", "tags": {"service": "gateway", "endpoint": "/api/orders"}, "timestamp": _ts(3)},
+]
+
+
+def send_telemetry(client: httpx.Client) -> list[str]:
+ ids: list[str] = []
+ for point in TELEMETRY_BATCH:
+ resp = client.post(INGESTION_URL, json=point)
+ resp.raise_for_status()
+ data = resp.json()
+ ids.append(data["event_id"])
+ print(f" ✓ {point['metric']:30s} = {point['value']:>8} {point['unit'] or '':4s} → {data['event_id'][:8]}...")
+ return ids
+
+
+def build_incident_context(event_ids: list[str]) -> dict:
+ return {
+ "incident_id": f"sim-{datetime.now(UTC).strftime('%Y%m%d-%H%M%S')}",
+ "primary_service": "postgres-primary",
+ "severity": "CRITICAL",
+ "title": "Database latency spike cascading to API timeouts and gateway 502 errors",
+ "detected_at": datetime.now(UTC).isoformat(),
+ "event_count": len(event_ids),
+ "service_count": 3,
+ "trace_count": 0,
+ "ungrouped_events": [],
+ "correlation_groups": [
+ {
+ "group_id": "g-db",
+ "event_ids": event_ids[:7],
+ "composite_score": 0.89,
+ "signals": ["time_proximity"],
+ "services": ["postgres-primary"],
+ },
+ {
+ "group_id": "g-api",
+ "event_ids": event_ids[7:11],
+ "composite_score": 0.82,
+ "signals": ["time_proximity"],
+ "services": ["api-gateway"],
+ },
+ {
+ "group_id": "g-gateway",
+ "event_ids": event_ids[11:],
+ "composite_score": 0.75,
+ "signals": ["time_proximity", "dependency_chain"],
+ "services": ["gateway"],
+ },
+ ],
+ "impacts": [
+ {
+ "service": "api-gateway",
+ "upstream_causes": ["postgres-primary"],
+ "downstream_impact": ["gateway"],
+ "propagation_paths": [["postgres-primary", "api-gateway", "gateway"]],
+ },
+ {
+ "service": "gateway",
+ "upstream_causes": ["postgres-primary", "api-gateway"],
+ "downstream_impact": [],
+ "propagation_paths": [],
+ },
+ ],
+ }
+
+
+def run_investigation(client: httpx.Client, context: dict) -> dict:
+ resp = client.post(INVESTIGATION_URL, json=context, timeout=120)
+ resp.raise_for_status()
+ return resp.json()
+
+
+def check_service(url: str, name: str) -> bool:
+ try:
+ resp = httpx.get(url.replace("/api/v1/ingest", "/health").replace("/investigate/run", "/health"), timeout=5)
+ return resp.is_success
+ except httpx.RequestError as e:
+ print(f" ✗ {name} unreachable ({e})")
+ return False
+
+
+def main() -> None:
+ print("Checking services...")
+ health_ok = (
+ check_service(INGESTION_URL, "ingestion (port 8000)")
+ and check_service(INVESTIGATION_URL, "investigation (port 8002)")
+ )
+ if not health_ok:
+ sys.exit(1)
+
+ print("\n── Scenario: Database Latency Spike ──────────────────────")
+ print(" Primary database CPU spikes → query latency jumps to 2.3s\n")
+
+ with httpx.Client() as client:
+ print("Sending 14 telemetry points to ingestion...")
+ event_ids = send_telemetry(client)
+ print(f"\n ✓ {len(event_ids)} events accepted\n")
+
+ print("Building incident context and running investigation...")
+ context = build_incident_context(event_ids)
+ result = run_investigation(client, context)
+
+ print(f"\n── Investigation Result ──────────────────────────────")
+ summary = result["summary"]
+ print(f" Incident: {summary['incident_id']}")
+ print(f" Title: {summary['title']}")
+ print(f" Duration: {result['duration_ms']:.0f}ms")
+ print(f" Confidence: {summary['overall_confidence']:.0%}")
+ print(f"\n Root Causes ({len(summary['root_causes'])}):")
+ for rc in summary["root_causes"]:
+ print(f" • {rc['service']:25s} ({rc['confidence']:.0%} confidence)")
+ print(f" {rc['explanation'][:150]}")
+ print(f"\n Remediation ({len(summary.get('remediation', []))}):")
+ for step in summary.get("remediation", []):
+ print(f" • [{step['priority']:>8}] {step['action']}")
+ print(f"\n Timeline:")
+ print(f" {summary['progression']['timeline_summary'][:200]}")
+ print(f"\nDone.")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/services/ai-investigation-service/Dockerfile b/services/ai-investigation-service/Dockerfile
index 38fbdef..43a3bc9 100644
--- a/services/ai-investigation-service/Dockerfile
+++ b/services/ai-investigation-service/Dockerfile
@@ -11,4 +11,6 @@ RUN pip install --no-cache-dir -r requirements.txt \
COPY services/ai-investigation-service/app ./app
COPY services/ai-investigation-service/workflows ./workflows
-CMD ["python", "-c", "print('ai-investigation-service scaffold ready')"]
+EXPOSE 8002
+
+CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8002"]
diff --git a/services/ai-investigation-service/app/config.py b/services/ai-investigation-service/app/config.py
index 271e753..dc80da3 100644
--- a/services/ai-investigation-service/app/config.py
+++ b/services/ai-investigation-service/app/config.py
@@ -2,7 +2,7 @@
class InvestigationServiceConfig(BaseSettings):
- model_config = {"env_prefix": "INVESTIGATION_"}
+ model_config = {"env_prefix": "INVESTIGATION_", "env_file": ".env", "env_file_encoding": "utf-8", "extra": "ignore"}
service_name: str = "ai-investigation-service"
environment: str = "development"
diff --git a/services/ai-investigation-service/app/main.py b/services/ai-investigation-service/app/main.py
new file mode 100644
index 0000000..ac75f03
--- /dev/null
+++ b/services/ai-investigation-service/app/main.py
@@ -0,0 +1,43 @@
+from collections.abc import AsyncIterator
+from contextlib import asynccontextmanager
+
+from fastapi import FastAPI
+
+from app.config import InvestigationServiceConfig
+from app.pipeline import InvestigationPipeline
+from app.routers import health, investigate
+from infrastructure.openai.openai_llm_provider import OpenAILLMProvider, OpenAIProviderConfig
+from shared.config import load_settings
+
+logger = __import__("logging").getLogger(__name__)
+
+
+@asynccontextmanager
+async def lifespan(app: FastAPI) -> AsyncIterator[None]:
+ settings: InvestigationServiceConfig = load_settings(InvestigationServiceConfig)
+ app.state.settings = settings
+
+ provider = OpenAILLMProvider(OpenAIProviderConfig())
+ await provider.start()
+ app.state.llm = provider
+ app.state.pipeline = InvestigationPipeline(provider)
+
+ logger.info("Service started", extra={"service": settings.service_name})
+ yield
+
+ await provider.close()
+ logger.info("Service stopped", extra={"service": settings.service_name})
+
+
+def create_app() -> FastAPI:
+ app = FastAPI(
+ title="ai-investigation-service",
+ version="0.1.0",
+ lifespan=lifespan,
+ )
+ app.include_router(health.router)
+ app.include_router(investigate.router)
+ return app
+
+
+app = create_app()
diff --git a/services/ai-investigation-service/app/routers/__init__.py b/services/ai-investigation-service/app/routers/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/services/ai-investigation-service/app/routers/health.py b/services/ai-investigation-service/app/routers/health.py
new file mode 100644
index 0000000..d1a7c9b
--- /dev/null
+++ b/services/ai-investigation-service/app/routers/health.py
@@ -0,0 +1,22 @@
+from fastapi import APIRouter
+from pydantic import BaseModel
+
+from app.config import InvestigationServiceConfig
+
+router = APIRouter(tags=["health"])
+
+
+class HealthResponse(BaseModel):
+ status: str
+ service: str
+ environment: str
+
+
+@router.get("/health", response_model=HealthResponse)
+async def health() -> HealthResponse:
+ cfg = InvestigationServiceConfig()
+ return HealthResponse(
+ status="healthy",
+ service=cfg.service_name,
+ environment=cfg.environment,
+ )
diff --git a/services/ai-investigation-service/app/routers/investigate.py b/services/ai-investigation-service/app/routers/investigate.py
new file mode 100644
index 0000000..111dd2c
--- /dev/null
+++ b/services/ai-investigation-service/app/routers/investigate.py
@@ -0,0 +1,22 @@
+from fastapi import APIRouter, Depends, HTTPException, Request
+
+from app.pipeline import InvestigationPipeline
+from shared.domain.incident.context.models import IncidentContext
+from shared.domain.investigation.models import InvestigationResult
+
+router = APIRouter(prefix="/investigate", tags=["investigate"])
+
+
+def _get_pipeline(request: Request) -> InvestigationPipeline:
+ pipeline: InvestigationPipeline | None = getattr(request.app.state, "pipeline", None)
+ if pipeline is None:
+ raise HTTPException(status_code=503, detail="Investigation pipeline not available")
+ return pipeline
+
+
+@router.post("/run", response_model=InvestigationResult)
+async def run_investigation(
+ context: IncidentContext,
+ pipeline: InvestigationPipeline = Depends(_get_pipeline),
+) -> InvestigationResult:
+ return await pipeline.run(context)