Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .env.example
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ SERVICE_NAME=rootpilot-service
# ── OpenAI LLM Provider (prefix OPENAI_) ───────────────────────────────────

OPENAI_API_KEY= # REQUIRED for AI investigation
# OPENAI_BASE_URL= # defaults to https://api.openai.com/v1
# OPENAI_DEFAULT_MODEL=gpt-4o
# OPENAI_DEFAULT_TEMPERATURE=0.1
# OPENAI_DEFAULT_MAX_TOKENS=4096
Expand Down
47 changes: 44 additions & 3 deletions infrastructure/openai/openai_llm_provider.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""OpenAI LLM provider adapter implementing the LLMProvider contract."""

import json
from typing import Any, TypeVar

from pydantic import BaseModel, Field
Expand All @@ -11,9 +12,12 @@


class OpenAIProviderConfig(BaseSettings):
model_config = {"env_prefix": "OPENAI_"}
model_config = {"env_prefix": "OPENAI_", "env_file": ".env", "env_file_encoding": "utf-8", "extra": "ignore"}

api_key: str = Field(default="", description="OpenAI API key.")
base_url: str | None = Field(
default=None, description="OpenAI API base URL (defaults to https://api.openai.com/v1)."
)
default_model: str = Field(default="gpt-4o", description="Default model identifier.")
default_temperature: float = Field(default=0.1, ge=0.0, le=2.0, description="Default generation temperature.")
default_max_tokens: int = Field(default=4096, ge=1, description="Default max tokens per response.")
Expand All @@ -35,7 +39,10 @@ def __init__(self, config: OpenAIProviderConfig | None = None) -> None:
async def start(self) -> None:
    """Create the async OpenAI client from the provider configuration.

    ``base_url`` is forwarded only when it holds a non-empty value, so an
    unset or blank setting leaves the SDK on its own default endpoint instead
    of overriding it with an empty string.
    """
    # The SDK is imported at call time, not at module import time.
    from openai import AsyncOpenAI

    client_kwargs: dict[str, Any] = {
        "api_key": self._config.api_key,
        "timeout": self._config.timeout_seconds,
    }
    if self._config.base_url:
        client_kwargs["base_url"] = self._config.base_url
    # Build the client exactly once; a second construction would discard the first.
    self._client = AsyncOpenAI(**client_kwargs)

async def close(self) -> None:
self._client = None
Expand Down Expand Up @@ -63,7 +70,23 @@ async def generate_structured(
"json_schema": {"name": schema.__name__, "strict": True, "schema": schema.model_json_schema()},
}
response = await self._client.chat.completions.create(**kwargs)
return schema.model_validate_json(response.choices[0].message.content)
raw = response.choices[0].message.content or ""
try:
cleaned = self._extract_json(raw)
return schema.model_validate_json(cleaned)
except (ValueError, Exception):
pass

# Fallback — model doesn't support response_format, retry with JSON instruction in system prompt
schema_json = json.dumps(schema.model_json_schema(), indent=2)
instruction = f"\n\nYou MUST respond with valid JSON matching this schema:\n```json\n{schema_json}\n```\nReturn ONLY the JSON object — no markdown, no explanations, no tags."
fallback = [
LLMMessage(role=m.role, content=m.content + instruction) if m.role == "system" else m for m in messages
]
fbk = self._build_kwargs(fallback, model)
response = await self._client.chat.completions.create(**fbk)
cleaned = self._extract_json(response.choices[0].message.content or "")
return schema.model_validate_json(cleaned)

async def embed(self, text: str, model: str | None = None) -> list[float]:
model = model or "text-embedding-3-small"
Expand All @@ -84,6 +107,24 @@ def _build_kwargs(
"max_tokens": max_tokens or self._config.default_max_tokens,
}

@staticmethod
def _extract_json(content: str) -> str:
    """Strip non-JSON wrappers (think tags, markdown fences, leading text) from LLM output.

    Args:
        content: Raw text returned by the model.

    Returns:
        The substring most likely to be the JSON object. When no object can be
        located, the think-stripped, whitespace-trimmed text is returned as-is
        so the caller's validation reports the failure.
    """
    import json
    import re

    # Remove <think>...</think> blocks
    content = re.sub(r"<think>.*?</think>", "", content, flags=re.DOTALL).strip()

    # Extract JSON from markdown code fence if present
    fenced = re.search(r"```(?:json)?\s*(\{.*?\})\s*```", content, flags=re.DOTALL)
    if fenced:
        return fenced.group(1).strip()

    # Prefer the first "{...}" span that actually parses as JSON, so braces in
    # surrounding prose (before or after the object) do not corrupt the result.
    decoder = json.JSONDecoder()
    first_brace = content.find("{")
    start = first_brace
    while start != -1:
        try:
            _, end = decoder.raw_decode(content, start)
        except ValueError:
            start = content.find("{", start + 1)
        else:
            return content[start:end]

    # Last resort: first { to last }, left for the caller to validate.
    last_brace = content.rfind("}")
    if first_brace != -1 and last_brace > first_brace:
        return content[first_brace : last_brace + 1]
    return content

@staticmethod
def _to_response(raw: Any) -> LLMResponse:
choice = raw.choices[0]
Expand Down
28 changes: 26 additions & 2 deletions infrastructure/openai/tests/test_openai_llm_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ class _FakeSchema(BaseModel):

@pytest.fixture
def config() -> OpenAIProviderConfig:
    """Provide a provider config that does not depend on a local ``.env`` file.

    ``_env_file=None`` turns off dotenv loading for this instance and
    ``base_url`` is pinned to ``None`` explicitly.
    """
    return OpenAIProviderConfig(api_key="test-key", base_url=None, _env_file=None)  # type: ignore[call-arg]


@pytest.fixture
Expand All @@ -25,11 +25,12 @@ def provider(config: OpenAIProviderConfig) -> OpenAILLMProvider:

class TestOpenAIProviderConfig:
def test_defaults(self) -> None:
    """Field defaults apply when dotenv loading is disabled."""
    # Construct the config once, with _env_file=None, so a developer's .env cannot leak in.
    cfg = OpenAIProviderConfig(_env_file=None)  # type: ignore[call-arg]
    assert cfg.default_model == "gpt-4o"
    assert cfg.default_temperature == 0.1
    assert cfg.default_max_tokens == 4096
    assert cfg.timeout_seconds == 60
    assert cfg.base_url is None

def test_env_prefix(self) -> None:
assert OpenAIProviderConfig.model_config.get("env_prefix") == "OPENAI_"
Expand All @@ -43,6 +44,13 @@ async def test_start_creates_client(self, provider: OpenAILLMProvider) -> None:
mock.assert_called_once_with(api_key="test-key", timeout=60)
assert provider._client is not None

async def test_start_passes_base_url_when_set(self) -> None:
    """``start`` forwards a configured ``base_url`` to ``AsyncOpenAI``."""
    # _env_file=None keeps this test independent of any local .env file.
    cfg = OpenAIProviderConfig(  # type: ignore[call-arg]
        api_key="test-key",
        base_url="https://custom.openai.com/v1",
        _env_file=None,
    )
    provider = OpenAILLMProvider(cfg)
    with patch("openai.AsyncOpenAI") as mock:
        await provider.start()
    mock.assert_called_once_with(api_key="test-key", timeout=60, base_url="https://custom.openai.com/v1")

async def test_close_clears_client(self, provider: OpenAILLMProvider) -> None:
provider._client = MagicMock()
await provider.close()
Expand Down Expand Up @@ -81,6 +89,22 @@ async def test_generate_structured_returns_parsed_schema(self, provider: OpenAIL
assert isinstance(result, _FakeSchema)
assert result.result == "ok"

def test_extract_json_strips_think_tags(self) -> None:
    """A leading ``<think>`` block is dropped, leaving only the JSON object."""
    extracted = OpenAILLMProvider._extract_json('<think>Let me analyze the data.</think>{"result": "ok"}')
    assert extracted == '{"result": "ok"}'

def test_extract_json_strips_markdown_fence(self) -> None:
    """A JSON object wrapped in a json code fence is returned without the fence."""
    extracted = OpenAILLMProvider._extract_json('```json\n{"result": "ok"}\n```')
    assert extracted == '{"result": "ok"}'

def test_extract_json_falls_back_to_braces(self) -> None:
    """Text surrounding the object is discarded when no fence is present."""
    extracted = OpenAILLMProvider._extract_json('Some text before {"result": "ok"} and after')
    assert extracted == '{"result": "ok"}'

def test_extract_json_clean_json_passthrough(self) -> None:
    """Input that is already a bare JSON object comes back unchanged."""
    expected = '{"result": "ok"}'
    assert OpenAILLMProvider._extract_json(expected) == expected

async def test_embed_returns_float_list(self, provider: OpenAILLMProvider) -> None:
provider._client = MagicMock()
mock_embedding = MagicMock()
Expand Down
163 changes: 163 additions & 0 deletions scripts/simulate_incident.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,163 @@
"""Simulate a realistic database latency spike incident end-to-end.

Sends telemetry to the ingestion service (port 8000), then builds an
incident context and runs the investigation pipeline (port 8002).

Usage:
python -m scripts.simulate_incident
# or: python scripts/simulate_incident.py
"""

import sys
from datetime import UTC, datetime, timedelta

import httpx

# Ingestion service endpoint that each telemetry point is POSTed to.
INGESTION_URL = "http://localhost:8000/api/v1/ingest"
# Investigation service endpoint that the incident context is POSTed to.
INVESTIGATION_URL = "http://localhost:8002/investigate/run"


def _ts(offset_seconds: int = 0) -> str:
return (datetime.now(UTC) - timedelta(seconds=offset_seconds)).isoformat()


# Scripted telemetry for the scenario, listed oldest to newest within each phase.
# Every timestamp is produced by a _ts() call evaluated when this module is
# imported, as an offset back from that moment.
# NOTE(review): build_incident_context slices the returned event IDs by position
# ([:7], [7:11], [11:]), so the order and per-phase counts here must stay in sync with it.
TELEMETRY_BATCH = [
# ── Phase 1: DB CPU climb (T-120s to T-60s) ──────────────────────
{"metric": "db.cpu.usage", "value": 72.0, "unit": "%", "source": "postgres-primary", "tags": {"service": "db", "region": "us-east-1"}, "timestamp": _ts(120)},
{"metric": "db.cpu.usage", "value": 78.0, "unit": "%", "source": "postgres-primary", "tags": {"service": "db", "region": "us-east-1"}, "timestamp": _ts(90)},
{"metric": "db.cpu.usage", "value": 85.0, "unit": "%", "source": "postgres-primary", "tags": {"service": "db", "region": "us-east-1"}, "timestamp": _ts(60)},
# ── Phase 2: Query latency degrades (T-60s to T-30s) ─────────────
{"metric": "db.query.latency", "value": 450.0, "unit": "ms", "source": "postgres-primary", "tags": {"service": "db", "query": "SELECT * FROM orders"}, "timestamp": _ts(60)},
{"metric": "db.query.latency", "value": 1200.0, "unit": "ms", "source": "postgres-primary", "tags": {"service": "db", "query": "SELECT * FROM orders"}, "timestamp": _ts(45)},
{"metric": "db.query.latency", "value": 2300.0, "unit": "ms", "source": "postgres-primary", "tags": {"service": "db", "query": "SELECT * FROM orders"}, "timestamp": _ts(30)},
{"metric": "db.connections.active", "value": 89.0, "unit": "count", "source": "postgres-primary", "tags": {"service": "db", "pool": "primary"}, "timestamp": _ts(30)},
# ── Phase 3: API timeouts (T-30s to T-10s) ───────────────────────
{"metric": "api.request.latency", "value": 3200.0, "unit": "ms", "source": "api-gateway", "tags": {"service": "api", "endpoint": "/api/orders"}, "timestamp": _ts(30)},
{"metric": "api.error.rate", "value": 12.5, "unit": "%", "source": "api-gateway", "tags": {"service": "api", "error": "timeout"}, "timestamp": _ts(25)},
{"metric": "api.error.rate", "value": 28.0, "unit": "%", "source": "api-gateway", "tags": {"service": "api", "error": "timeout"}, "timestamp": _ts(15)},
{"metric": "api.request.latency", "value": 5100.0, "unit": "ms", "source": "api-gateway", "tags": {"service": "api", "endpoint": "/api/orders"}, "timestamp": _ts(10)},
# ── Phase 4: Gateway 502s (T-10s to now) ─────────────────────────
{"metric": "gateway.upstream.errors", "value": 42.0, "unit": "count", "source": "gateway", "tags": {"service": "gateway", "upstream": "api", "status": "502"}, "timestamp": _ts(10)},
{"metric": "gateway.upstream.errors", "value": 87.0, "unit": "count", "source": "gateway", "tags": {"service": "gateway", "upstream": "api", "status": "502"}, "timestamp": _ts(5)},
{"metric": "gateway.latency", "value": 6200.0, "unit": "ms", "source": "gateway", "tags": {"service": "gateway", "endpoint": "/api/orders"}, "timestamp": _ts(3)},
]


def send_telemetry(client: httpx.Client) -> list[str]:
    """Post every point in ``TELEMETRY_BATCH`` to the ingestion endpoint.

    Each point is sent as its own request and a non-success response raises
    through ``raise_for_status``. One confirmation line is printed per point.

    Returns:
        The ``event_id`` values from the responses, in send order.
    """
    event_ids: list[str] = []
    for point in TELEMETRY_BATCH:
        response = client.post(INGESTION_URL, json=point)
        response.raise_for_status()
        data = response.json()
        event_ids.append(data["event_id"])
        print(f" ✓ {point['metric']:30s} = {point['value']:>8} {point['unit'] or '':4s} → {data['event_id'][:8]}...")
    return event_ids


def build_incident_context(event_ids: list[str]) -> dict:
return {
"incident_id": f"sim-{datetime.now(UTC).strftime('%Y%m%d-%H%M%S')}",
"primary_service": "postgres-primary",
"severity": "CRITICAL",
"title": "Database latency spike cascading to API timeouts and gateway 502 errors",
"detected_at": datetime.now(UTC).isoformat(),
"event_count": len(event_ids),
"service_count": 3,
"trace_count": 0,
"ungrouped_events": [],
"correlation_groups": [
{
"group_id": "g-db",
"event_ids": event_ids[:7],
"composite_score": 0.89,
"signals": ["time_proximity"],
"services": ["postgres-primary"],
},
{
"group_id": "g-api",
"event_ids": event_ids[7:11],
"composite_score": 0.82,
"signals": ["time_proximity"],
"services": ["api-gateway"],
},
{
"group_id": "g-gateway",
"event_ids": event_ids[11:],
"composite_score": 0.75,
"signals": ["time_proximity", "dependency_chain"],
"services": ["gateway"],
},
],
"impacts": [
{
"service": "api-gateway",
"upstream_causes": ["postgres-primary"],
"downstream_impact": ["gateway"],
"propagation_paths": [["postgres-primary", "api-gateway", "gateway"]],
},
{
"service": "gateway",
"upstream_causes": ["postgres-primary", "api-gateway"],
"downstream_impact": [],
"propagation_paths": [],
},
],
}


def run_investigation(client: httpx.Client, context: dict) -> dict:
    """Post the incident context to the investigation endpoint and return the decoded JSON reply.

    The request uses a 120-second timeout; a non-success response raises
    through ``raise_for_status``.
    """
    response = client.post(INVESTIGATION_URL, json=context, timeout=120)
    response.raise_for_status()
    payload = response.json()
    return payload


def check_service(url: str, name: str) -> bool:
    """Probe a service's ``/health`` endpoint and report whether it is healthy.

    The health URL is derived from ``url`` by swapping the known ingest and
    investigate paths for ``/health``.

    Args:
        url: Endpoint URL of the service to probe.
        name: Human-readable service name used in failure messages.

    Returns:
        ``True`` when the health endpoint answers with a success status;
        ``False`` when it is unreachable or answers with any other status.
        A reason is printed in both failure cases.
    """
    health_url = url.replace("/api/v1/ingest", "/health").replace("/investigate/run", "/health")
    try:
        resp = httpx.get(health_url, timeout=5)
    except httpx.RequestError as e:
        print(f" ✗ {name} unreachable ({e})")
        return False
    if not resp.is_success:
        # Without this message the script would exit with no hint of what failed.
        print(f" ✗ {name} unhealthy (HTTP {resp.status_code})")
        return False
    return True


def main() -> None:
    """Run the simulated incident end-to-end and print the investigation result.

    Checks both services first and exits with status 1 if either health check
    fails. Otherwise sends the telemetry batch, builds the incident context,
    runs the investigation and prints a summary of the response.
    """
    print("Checking services...")
    # Probe both services before combining the results so every failing
    # service is reported, not only the first one.
    ingestion_ok = check_service(INGESTION_URL, "ingestion (port 8000)")
    investigation_ok = check_service(INVESTIGATION_URL, "investigation (port 8002)")
    if not (ingestion_ok and investigation_ok):
        sys.exit(1)

    print("\n── Scenario: Database Latency Spike ──────────────────────")
    print(" Primary database CPU spikes → query latency jumps to 2.3s\n")

    with httpx.Client() as client:
        # Report the real batch size rather than a hard-coded count.
        print(f"Sending {len(TELEMETRY_BATCH)} telemetry points to ingestion...")
        event_ids = send_telemetry(client)
        print(f"\n ✓ {len(event_ids)} events accepted\n")

        print("Building incident context and running investigation...")
        context = build_incident_context(event_ids)
        result = run_investigation(client, context)

    print("\n── Investigation Result ──────────────────────────────")
    summary = result["summary"]
    print(f" Incident: {summary['incident_id']}")
    print(f" Title: {summary['title']}")
    print(f" Duration: {result['duration_ms']:.0f}ms")
    print(f" Confidence: {summary['overall_confidence']:.0%}")
    print(f"\n Root Causes ({len(summary['root_causes'])}):")
    for rc in summary["root_causes"]:
        print(f" • {rc['service']:25s} ({rc['confidence']:.0%} confidence)")
        print(f" {rc['explanation'][:150]}")
    # Remediation is optional in the response; treat a missing key as no steps.
    remediation = summary.get("remediation", [])
    print(f"\n Remediation ({len(remediation)}):")
    for step in remediation:
        print(f" • [{step['priority']:>8}] {step['action']}")
    print("\n Timeline:")
    print(f" {summary['progression']['timeline_summary'][:200]}")
    print("\nDone.")


if __name__ == "__main__":
main()
4 changes: 3 additions & 1 deletion services/ai-investigation-service/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,6 @@ RUN pip install --no-cache-dir -r requirements.txt \
COPY services/ai-investigation-service/app ./app
COPY services/ai-investigation-service/workflows ./workflows

CMD ["python", "-c", "print('ai-investigation-service scaffold ready')"]
EXPOSE 8002

CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8002"]
2 changes: 1 addition & 1 deletion services/ai-investigation-service/app/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@


class InvestigationServiceConfig(BaseSettings):
model_config = {"env_prefix": "INVESTIGATION_"}
model_config = {"env_prefix": "INVESTIGATION_", "env_file": ".env", "env_file_encoding": "utf-8", "extra": "ignore"}

service_name: str = "ai-investigation-service"
environment: str = "development"
Expand Down
43 changes: 43 additions & 0 deletions services/ai-investigation-service/app/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
import logging
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager

from fastapi import FastAPI

from app.config import InvestigationServiceConfig
from app.pipeline import InvestigationPipeline
from app.routers import health, investigate
from infrastructure.openai.openai_llm_provider import OpenAILLMProvider, OpenAIProviderConfig
from shared.config import load_settings

# Module-level logger named after this module.
logger = logging.getLogger(__name__)


@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncIterator[None]:
    """Manage startup and shutdown of the service's shared resources.

    On startup this loads the service settings, starts the OpenAI provider and
    attaches ``settings``, ``llm`` and ``pipeline`` to ``app.state``. On
    shutdown the provider is closed, including when the application exits
    with an error.
    """
    settings: InvestigationServiceConfig = load_settings(InvestigationServiceConfig)
    app.state.settings = settings

    provider = OpenAILLMProvider(OpenAIProviderConfig())
    await provider.start()
    try:
        app.state.llm = provider
        app.state.pipeline = InvestigationPipeline(provider)

        logger.info("Service started", extra={"service": settings.service_name})
        yield
    finally:
        # Release the started provider even if an exception propagates through the yield.
        await provider.close()
        logger.info("Service stopped", extra={"service": settings.service_name})


def create_app() -> FastAPI:
    """Build the FastAPI application with its lifespan handler and both routers."""
    application = FastAPI(title="ai-investigation-service", version="0.1.0", lifespan=lifespan)
    for module in (health, investigate):
        application.include_router(module.router)
    return application


# Module-level application instance, created at import time.
app = create_app()
Empty file.
22 changes: 22 additions & 0 deletions services/ai-investigation-service/app/routers/health.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
from fastapi import APIRouter
from pydantic import BaseModel

from app.config import InvestigationServiceConfig

router = APIRouter(tags=["health"])


class HealthResponse(BaseModel):
    """Response body returned by the ``/health`` endpoint."""

    # Health indicator string.
    status: str
    # Name of the service reporting its health.
    service: str
    # Environment name the service is configured with.
    environment: str


@router.get("/health", response_model=HealthResponse)
async def health() -> HealthResponse:
    """Return a "healthy" status together with the configured service name and environment."""
    settings = InvestigationServiceConfig()
    payload = {
        "status": "healthy",
        "service": settings.service_name,
        "environment": settings.environment,
    }
    return HealthResponse(**payload)
Loading
Loading