Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions infrastructure/openai/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
"""OpenAI provider implementation for RootPilot."""

from infrastructure.openai.openai_llm_provider import OpenAILLMProvider, OpenAIProviderConfig

# Names re-exported as the public API of this package.
__all__ = [
"OpenAILLMProvider",
"OpenAIProviderConfig",
]
101 changes: 101 additions & 0 deletions infrastructure/openai/openai_llm_provider.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
"""OpenAI LLM provider adapter implementing the LLMProvider contract."""

from typing import Any, TypeVar

from pydantic import BaseModel, Field
from pydantic_settings import BaseSettings

from shared.contracts.interfaces.llm_provider import LLMMessage, LLMProvider, LLMResponse

T = TypeVar("T", bound=BaseModel)


class OpenAIProviderConfig(BaseSettings):
    """Settings for the OpenAI provider.

    Values are read from environment variables carrying the ``OPENAI_``
    prefix; any field without an override keeps the default declared here.
    """

    model_config = {"env_prefix": "OPENAI_"}

    api_key: str = Field(
        default="",
        description="OpenAI API key.",
    )
    default_model: str = Field(
        default="gpt-4o",
        description="Default model identifier.",
    )
    default_temperature: float = Field(
        default=0.1,
        ge=0.0,
        le=2.0,
        description="Default generation temperature.",
    )
    default_max_tokens: int = Field(
        default=4096,
        ge=1,
        description="Default max tokens per response.",
    )
    timeout_seconds: int = Field(
        default=60,
        ge=1,
        description="Request timeout in seconds.",
    )


class OpenAILLMProvider(LLMProvider):
    """OpenAI-backed LLM provider.

    Lifecycle:
        Call ``start()`` before first use to initialise the client, and
        ``close()`` to release resources. Calling ``generate``,
        ``generate_structured`` or ``embed`` while no client exists (before
        ``start()`` or after ``close()``) fails with ``AttributeError``.
    """

    def __init__(self, config: OpenAIProviderConfig | None = None) -> None:
        """Store the configuration; the client itself is created by ``start()``.

        Args:
            config: Provider settings. When omitted, a default
                ``OpenAIProviderConfig`` is constructed.
        """
        self._config = config or OpenAIProviderConfig()
        self._client: Any = None

    async def start(self) -> None:
        """Create the async OpenAI client from the stored configuration."""
        # Imported here rather than at module level, so ``openai`` is only
        # needed once a provider is actually started.
        from openai import AsyncOpenAI

        self._client = AsyncOpenAI(api_key=self._config.api_key, timeout=self._config.timeout_seconds)

    async def close(self) -> None:
        """Close the underlying client, if any, and drop the reference.

        Safe to call on a provider that was never started or is already
        closed; in that case it does nothing.
        """
        import inspect

        # Detach first so the provider is left without a client even if
        # closing raises.
        client, self._client = self._client, None
        if client is None:
            return

        outcome = client.close()
        # ``AsyncOpenAI.close`` is a coroutine function. Awaiting only when an
        # awaitable came back keeps synchronous stand-ins (e.g. plain mocks)
        # working as well.
        if inspect.isawaitable(outcome):
            await outcome

    async def generate(
        self,
        messages: list[LLMMessage],
        model: str | None = None,
        temperature: float | None = None,
        max_tokens: int | None = None,
    ) -> LLMResponse:
        """Run a chat completion and return it as an ``LLMResponse``.

        Args:
            messages: Conversation to send, in order.
            model: Model identifier; falls back to the configured default.
            temperature: Sampling temperature; falls back to the configured
                default when ``None``.
            max_tokens: Response token limit; falls back to the configured
                default.

        Returns:
            The first choice of the completion, with token usage when the
            API reported it.
        """
        kwargs = self._build_kwargs(messages, model, temperature, max_tokens)
        response = await self._client.chat.completions.create(**kwargs)
        return self._to_response(response)

    async def generate_structured(
        self,
        messages: list[LLMMessage],
        schema: type[T],
        model: str | None = None,
    ) -> T:
        """Run a chat completion constrained to ``schema`` and parse the result.

        Temperature and max tokens always come from the configured defaults.

        Args:
            messages: Conversation to send, in order.
            schema: Pydantic model class describing the expected JSON output.
            model: Model identifier; falls back to the configured default.

        Returns:
            An instance of ``schema`` validated from the first choice's content.
        """
        kwargs = self._build_kwargs(messages, model)
        # NOTE(review): OpenAI's strict structured-output mode is documented to
        # require ``additionalProperties: false`` on every object in the schema;
        # confirm the models passed here produce that (e.g. forbid extra fields).
        kwargs["response_format"] = {
            "type": "json_schema",
            "json_schema": {"name": schema.__name__, "strict": True, "schema": schema.model_json_schema()},
        }
        response = await self._client.chat.completions.create(**kwargs)
        # NOTE(review): ``content`` is passed to validation as-is; presumably it
        # can be ``None`` (e.g. on a refusal), which would surface as a
        # validation failure here. Confirm against the API behaviour.
        return schema.model_validate_json(response.choices[0].message.content)

    async def embed(self, text: str, model: str | None = None) -> list[float]:
        """Return the embedding vector for ``text``.

        Args:
            text: Input to embed.
            model: Embedding model; defaults to ``text-embedding-3-small``.
        """
        model = model or "text-embedding-3-small"
        response = await self._client.embeddings.create(input=text, model=model)
        return response.data[0].embedding

    def _build_kwargs(
        self,
        messages: list[LLMMessage],
        model: str | None = None,
        temperature: float | None = None,
        max_tokens: int | None = None,
    ) -> dict[str, Any]:
        """Assemble the keyword arguments for a chat-completion request.

        ``temperature`` is compared against ``None`` so an explicit ``0.0`` is
        honoured; ``model`` and ``max_tokens`` fall back to the configured
        defaults for any falsy value.
        """
        return {
            "model": model or self._config.default_model,
            "messages": [{"role": m.role, "content": m.content} for m in messages],
            "temperature": temperature if temperature is not None else self._config.default_temperature,
            "max_tokens": max_tokens or self._config.default_max_tokens,
        }

    @staticmethod
    def _to_response(raw: Any) -> LLMResponse:
        """Convert a raw chat-completion object into an ``LLMResponse``.

        Only the first choice is used. Missing content becomes an empty
        string, and ``usage`` is ``None`` when the raw response has no usage.
        """
        choice = raw.choices[0]
        usage = raw.usage
        return LLMResponse(
            content=choice.message.content or "",
            finish_reason=choice.finish_reason,
            usage={
                "prompt_tokens": usage.prompt_tokens,
                "completion_tokens": usage.completion_tokens,
                "total_tokens": usage.total_tokens,
            }
            if usage
            else None,
        )
95 changes: 95 additions & 0 deletions infrastructure/openai/tests/test_openai_llm_provider.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
"""Tests for the OpenAI LLM provider adapter."""

from unittest.mock import AsyncMock, MagicMock, patch

import pytest
from pydantic import BaseModel, Field

from infrastructure.openai.openai_llm_provider import OpenAILLMProvider, OpenAIProviderConfig
from shared.contracts.interfaces.llm_provider import LLMMessage


class _FakeSchema(BaseModel):
    # Minimal one-field model used as the target schema in the
    # structured-output test.
    result: str = Field(
        description="Test field",
    )


@pytest.fixture
def config() -> OpenAIProviderConfig:
    """Provide provider settings carrying a dummy API key."""
    settings = OpenAIProviderConfig(api_key="test-key")
    return settings


@pytest.fixture
def provider(config: OpenAIProviderConfig) -> OpenAILLMProvider:
    """Provide an unstarted provider built from the ``config`` fixture."""
    instance = OpenAILLMProvider(config)
    return instance


class TestOpenAIProviderConfig:
    """Checks on ``OpenAIProviderConfig`` defaults and its environment prefix."""

    def test_defaults(self, monkeypatch: pytest.MonkeyPatch) -> None:
        """Declared defaults apply when no ``OPENAI_*`` override is present."""
        # The settings class reads ``OPENAI_``-prefixed environment variables,
        # so clear the ones asserted below; otherwise a developer or CI
        # environment that exports them would make this test fail.
        for name in (
            "OPENAI_DEFAULT_MODEL",
            "OPENAI_DEFAULT_TEMPERATURE",
            "OPENAI_DEFAULT_MAX_TOKENS",
            "OPENAI_TIMEOUT_SECONDS",
        ):
            monkeypatch.delenv(name, raising=False)

        cfg = OpenAIProviderConfig()
        assert cfg.default_model == "gpt-4o"
        assert cfg.default_temperature == 0.1
        assert cfg.default_max_tokens == 4096
        assert cfg.timeout_seconds == 60

    def test_env_prefix(self) -> None:
        """The settings class is wired to the ``OPENAI_`` prefix."""
        assert OpenAIProviderConfig.model_config.get("env_prefix") == "OPENAI_"


class TestOpenAILLMProvider:
    """Behavioural tests for ``OpenAILLMProvider`` against a mocked client."""

    async def test_start_creates_client(self, provider: OpenAILLMProvider) -> None:
        """``start()`` builds the SDK client from the configured key and timeout."""
        assert provider._client is None
        with patch("openai.AsyncOpenAI") as client_cls:
            await provider.start()
            client_cls.assert_called_once_with(api_key="test-key", timeout=60)
        assert provider._client is not None

    async def test_close_clears_client(self, provider: OpenAILLMProvider) -> None:
        """``close()`` leaves the provider without a client."""
        provider._client = MagicMock()
        await provider.close()
        assert provider._client is None

    async def test_generate_returns_response(self, provider: OpenAILLMProvider) -> None:
        """``generate()`` maps content, finish reason and usage of the first choice."""
        choice = MagicMock()
        choice.message.content = "Hello"
        choice.finish_reason = "stop"
        usage = MagicMock()
        usage.prompt_tokens = 10
        usage.completion_tokens = 5
        usage.total_tokens = 15
        completion = MagicMock(choices=[choice], usage=usage)

        provider._client = MagicMock()
        provider._client.chat.completions.create = AsyncMock(return_value=completion)

        response = await provider.generate([LLMMessage(role="user", content="hi")])

        expected_usage = {"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15}
        assert response.content == "Hello"
        assert response.finish_reason == "stop"
        assert response.usage == expected_usage

    async def test_generate_structured_returns_parsed_schema(self, provider: OpenAILLMProvider) -> None:
        """``generate_structured()`` validates the JSON content into the schema."""
        choice = MagicMock()
        choice.message.content = '{"result": "ok"}'
        choice.finish_reason = "stop"
        completion = MagicMock(choices=[choice], usage=None)

        provider._client = MagicMock()
        provider._client.chat.completions.create = AsyncMock(return_value=completion)

        parsed = await provider.generate_structured([LLMMessage(role="user", content="test")], _FakeSchema)

        assert isinstance(parsed, _FakeSchema)
        assert parsed.result == "ok"

    async def test_embed_returns_float_list(self, provider: OpenAILLMProvider) -> None:
        """``embed()`` returns the vector of the first embedding item."""
        item = MagicMock()
        item.embedding = [0.1, 0.2, 0.3]
        embeddings_response = MagicMock(data=[item])

        provider._client = MagicMock()
        provider._client.embeddings.create = AsyncMock(return_value=embeddings_response)

        vector = await provider.embed("test text")
        assert vector == [0.1, 0.2, 0.3]

    async def test_generate_raises_if_not_started(self, provider: OpenAILLMProvider) -> None:
        """Using the provider without a client fails with ``AttributeError``."""
        prompt = [LLMMessage(role="user", content="hi")]
        with pytest.raises(AttributeError):
            await provider.generate(prompt)
1 change: 0 additions & 1 deletion services/ai-investigation-service/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ RUN pip install --no-cache-dir -r requirements.txt \
&& pip install --no-cache-dir --no-deps .

COPY services/ai-investigation-service/app ./app
COPY services/ai-investigation-service/prompts ./prompts
COPY services/ai-investigation-service/workflows ./workflows

CMD ["python", "-c", "print('ai-investigation-service scaffold ready')"]
1 change: 1 addition & 0 deletions services/ai-investigation-service/app/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""AI investigation service — deterministic RCA pipeline."""
14 changes: 14 additions & 0 deletions services/ai-investigation-service/app/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from pydantic_settings import BaseSettings


class InvestigationServiceConfig(BaseSettings):
    """Settings for the AI investigation service.

    Values are read from environment variables carrying the
    ``INVESTIGATION_`` prefix; unset fields keep the defaults below.
    """

    model_config = {"env_prefix": "INVESTIGATION_"}

    # Service identity and runtime behaviour.
    service_name: str = "ai-investigation-service"
    environment: str = "development"
    debug: bool = False
    log_level: str = "INFO"

    # LLM selection and generation parameters.
    llm_provider: str = "openai"
    llm_model: str = "gpt-4o"
    llm_temperature: float = 0.1
    llm_max_tokens: int = 4096
40 changes: 40 additions & 0 deletions services/ai-investigation-service/app/pipeline.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
"""InvestigationPipeline — deterministic RCA orchestration.

Assembles incident context into prompts, calls the LLM for structured output,
and returns an InvestigationResult.
"""

import time

from app.prompts.rca import RCAPrompts
from shared.contracts.interfaces.llm_provider import LLMProvider
from shared.domain.incident.context.models import IncidentContext
from shared.domain.investigation.models import InvestigationResult, RCASummary


class InvestigationPipeline:
    """Deterministic investigation pipeline.

    Builds RCA prompt messages from an incident context, asks the LLM
    provider for a structured summary, and wraps it in a result object.

    Usage:
        pipeline = InvestigationPipeline(llm_provider)
        result = await pipeline.run(incident_context)
    """

    def __init__(self, llm_provider: LLMProvider) -> None:
        """Keep the provider used for all LLM calls made by this pipeline."""
        self._llm = llm_provider

    async def run(self, context: IncidentContext) -> InvestigationResult:
        """Run the full investigation pipeline for a given incident context.

        The reported duration covers prompt assembly and the LLM call, in
        milliseconds rounded to two decimals.
        """
        started_at = time.perf_counter()

        prompt_messages = RCAPrompts.build_rca_messages(context)
        summary: RCASummary = await self._llm.generate_structured(prompt_messages, RCASummary)

        elapsed_ms = (time.perf_counter() - started_at) * 1000
        result = InvestigationResult(
            summary=summary,
            raw_output=None,
            duration_ms=round(elapsed_ms, 2),
        )
        return result
5 changes: 5 additions & 0 deletions services/ai-investigation-service/app/prompts/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
"""Modular prompt templates for AI investigation workflows."""

from app.prompts.rca import RCAPrompts

__all__ = ["RCAPrompts"]
91 changes: 91 additions & 0 deletions services/ai-investigation-service/app/prompts/rca.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
"""RCA prompt templates — modular, testable, and version-controlled."""

from shared.contracts.interfaces.llm_provider import LLMMessage
from shared.domain.incident.context.models import IncidentContext


class RCAPrompts:
    """Factory for building RCA prompt messages from incident context."""

    SYSTEM_PROMPT = """You are a senior Site Reliability Engineer conducting a root cause analysis.

Your task is to analyse the provided incident context and produce a structured RCA summary.

Rules:
- Base your analysis strictly on the evidence provided. Do not speculate beyond the data.
- Rank root causes by confidence based on correlation scores, impact chains, and supporting evidence.
- If evidence is insufficient, reflect low confidence rather than guessing.
- Remediation steps must be concrete, actionable, and ordered by priority.
- Use technical precision — be specific about services, signals, and time windows."""

    @classmethod
    def build_rca_messages(cls, context: IncidentContext) -> list[LLMMessage]:
        """Build the message list for RCA generation.

        Returns a system message carrying ``SYSTEM_PROMPT`` followed by a user
        message carrying the formatted incident context.
        """
        return [
            LLMMessage(role="system", content=cls.SYSTEM_PROMPT),
            LLMMessage(role="user", content=cls._format_context(context)),
        ]

    @classmethod
    def _format_context(cls, context: IncidentContext) -> str:
        """Format the incident context as a structured text prompt.

        The header is always present; every other section is emitted only
        when the corresponding part of the context is non-empty. The text
        ends with the instruction to follow the output schema.
        """
        section_builders = (
            cls._header_lines,
            cls._timeline_lines,
            cls._correlation_lines,
            cls._impact_lines,
            cls._trace_group_lines,
            cls._ungrouped_lines,
        )
        lines: list[str] = []
        for build_section in section_builders:
            lines.extend(build_section(context))

        lines.append("Provide your analysis following the output schema exactly.")
        return "\n".join(lines)

    @staticmethod
    def _header_lines(context: IncidentContext) -> list[str]:
        """Return the incident identity and headline counts."""
        return [
            f"## Incident: {context.title or 'Untitled'}",
            f"ID: {context.incident_id}",
            f"Primary Service: {context.primary_service}",
            f"Severity: {context.severity}",
            f"Detected At: {context.detected_at.isoformat()}",
            "",
            f"Total Events: {context.event_count}",
            f"Services Involved: {context.service_count}",
            f"Unique Traces: {context.trace_count}",
            "",
        ]

    @staticmethod
    def _timeline_lines(context: IncidentContext) -> list[str]:
        """Return the timeline section, or nothing when there is no timeline."""
        timeline = context.timeline
        if not timeline:
            return []

        # Compare against None rather than truthiness so that a genuine
        # zero-second duration is reported as 0 instead of "N/A".
        duration = timeline.duration_seconds
        duration_text = "N/A" if duration is None else duration
        return [
            "### Timeline",
            f"Duration: {duration_text} seconds",
            f"Windows: {timeline.window_count}",
            f"Total Events: {timeline.total_events}",
            "",
        ]

    @staticmethod
    def _correlation_lines(context: IncidentContext) -> list[str]:
        """Return one numbered entry per correlation group, or nothing."""
        if not context.correlation_groups:
            return []

        lines = ["### Correlation Groups"]
        for i, group in enumerate(context.correlation_groups, 1):
            services = ", ".join(group.services) if group.services else "unknown"
            lines.append(
                f" {i}. Score: {group.composite_score} | Signals: {[s.value for s in group.signals]} | Services: [{services}]"
            )
            if group.trace_id:
                lines.append(f" Trace: {group.trace_id} ({group.span_count} spans)")
        lines.append("")
        return lines

    @staticmethod
    def _impact_lines(context: IncidentContext) -> list[str]:
        """Return the per-service impact section, or nothing."""
        if not context.impacts:
            return []

        lines = ["### Impact Analysis"]
        for impact in context.impacts:
            lines.append(f" Service: {impact.service}")
            if impact.upstream_causes:
                lines.append(f" Upstream Causes: {', '.join(impact.upstream_causes)}")
            if impact.downstream_impact:
                lines.append(f" Downstream Impact: {', '.join(impact.downstream_impact)}")
            if impact.propagation_paths:
                paths = [" -> ".join(p) for p in impact.propagation_paths]
                lines.append(f" Propagation Paths: {' | '.join(paths)}")
        lines.append("")
        return lines

    @staticmethod
    def _trace_group_lines(context: IncidentContext) -> list[str]:
        """Return one line per trace group, or nothing."""
        if not context.trace_groups:
            return []

        lines = ["### Trace Groups"]
        for tg in context.trace_groups:
            services = ", ".join(tg.service_names) if tg.service_names else "N/A"
            lines.append(
                f" Trace: {tg.trace_id} | Services: [{services}] | Spans: {tg.span_count} | Events: {len(tg.event_ids)}"
            )
        lines.append("")
        return lines

    @staticmethod
    def _ungrouped_lines(context: IncidentContext) -> list[str]:
        """Return the count of ungrouped events, or nothing when there are none."""
        if not context.ungrouped_events:
            return []
        return [
            f"### Ungrouped Events: {len(context.ungrouped_events)} events below correlation threshold",
            "",
        ]
10 changes: 10 additions & 0 deletions services/ai-investigation-service/tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
"""Shared fixtures for ai-investigation-service tests."""

import sys
from pathlib import Path

# Add service root and project root to sys.path
# The service root is the directory above this file's directory; the project
# root sits two levels above the service root.
_service_root = Path(__file__).parent.parent
_project_root = _service_root.parent.parent
# Prepend both in one step, leaving the project root ahead of the service root.
sys.path[:0] = [str(_project_root), str(_service_root)]
Loading
Loading