From c19cda3694c4bebb3e38826ad8adb65e8830a84d Mon Sep 17 00:00:00 2001 From: saurabh batham Date: Sun, 14 Jun 2026 21:07:09 +0530 Subject: [PATCH 1/2] feat(ai-investigation): implement AI investigation pipeline with OpenAI provider - Add InvestigationPipeline for deterministic RCA orchestration - Add OpenAILLMProvider adapter with generate/generate_structured/embed - Add RCA prompt templates in modular RCAPrompts class - Add investigation domain models (RCASummary, RootCause, etc.) - Fix Pylance warnings: use .get() for SettingsConfigDict, cast AsyncMock, use Any for kwargs --- infrastructure/openai/__init__.py | 8 ++ infrastructure/openai/openai_llm_provider.py | 101 +++++++++++++++ .../openai/tests/__init__.py | 0 .../openai/tests/test_openai_llm_provider.py | 95 ++++++++++++++ .../ai-investigation-service/app/__init__.py | 1 + .../ai-investigation-service/app/config.py | 14 ++ .../ai-investigation-service/app/pipeline.py | 41 ++++++ .../prompts/__init__.py | 5 + .../ai-investigation-service/prompts/rca.py | 91 +++++++++++++ .../{prompts/.gitkeep => tests/__init__.py} | 0 .../tests/conftest.py | 10 ++ .../tests/test_pipeline.py | 114 ++++++++++++++++ .../tests/test_prompts.py | 122 ++++++++++++++++++ shared/domain/investigation/__init__.py | 17 +++ shared/domain/investigation/models.py | 53 ++++++++ .../domain/investigation/tests/__init__.py | 0 .../domain/investigation/tests/test_models.py | 88 +++++++++++++ 17 files changed, 760 insertions(+) create mode 100644 infrastructure/openai/__init__.py create mode 100644 infrastructure/openai/openai_llm_provider.py rename services/ai-investigation-service/app/.gitkeep => infrastructure/openai/tests/__init__.py (100%) create mode 100644 infrastructure/openai/tests/test_openai_llm_provider.py create mode 100644 services/ai-investigation-service/app/__init__.py create mode 100644 services/ai-investigation-service/app/config.py create mode 100644 services/ai-investigation-service/app/pipeline.py create mode 100644 services/ai-investigation-service/prompts/__init__.py create mode 100644 services/ai-investigation-service/prompts/rca.py rename services/ai-investigation-service/{prompts/.gitkeep => tests/__init__.py} (100%) create mode 100644 services/ai-investigation-service/tests/conftest.py create mode 100644 services/ai-investigation-service/tests/test_pipeline.py create mode 100644 services/ai-investigation-service/tests/test_prompts.py create mode 100644 shared/domain/investigation/__init__.py create mode 100644 shared/domain/investigation/models.py rename services/ai-investigation-service/tests/.gitkeep => shared/domain/investigation/tests/__init__.py (100%) create mode 100644 shared/domain/investigation/tests/test_models.py diff --git a/infrastructure/openai/__init__.py b/infrastructure/openai/__init__.py new file mode 100644 index 0000000..a18b8be --- /dev/null +++ b/infrastructure/openai/__init__.py @@ -0,0 +1,8 @@ +"""OpenAI provider implementation for RootPilot.""" + +from infrastructure.openai.openai_llm_provider import OpenAILLMProvider, OpenAIProviderConfig + +__all__ = [ + "OpenAILLMProvider", + "OpenAIProviderConfig", +] diff --git a/infrastructure/openai/openai_llm_provider.py b/infrastructure/openai/openai_llm_provider.py new file mode 100644 index 0000000..04c968a --- /dev/null +++ b/infrastructure/openai/openai_llm_provider.py @@ -0,0 +1,101 @@ +"""OpenAI LLM provider adapter implementing the LLMProvider contract.""" + +from typing import Any, TypeVar + +from pydantic import BaseModel, Field +from pydantic_settings import BaseSettings + +from shared.contracts.interfaces.llm_provider import LLMMessage, LLMProvider, LLMResponse + +T = TypeVar("T", bound=BaseModel) + + +class OpenAIProviderConfig(BaseSettings): + model_config = {"env_prefix": "OPENAI_"} + + api_key: str = Field(default="", description="OpenAI API key.") + default_model: str = Field(default="gpt-4o", description="Default model identifier.") + default_temperature: float = Field(default=0.1, ge=0.0, le=2.0, description="Default generation temperature.") + default_max_tokens: int = Field(default=4096, ge=1, description="Default max tokens per response.") + timeout_seconds: int = Field(default=60, ge=1, description="Request timeout in seconds.") + + +class OpenAILLMProvider(LLMProvider): + """OpenAI-backed LLM provider. + + Lifecycle: + Call ``start()`` before first use to initialise the client, and + ``close()`` to release resources. + """ + + def __init__(self, config: OpenAIProviderConfig | None = None) -> None: + self._config = config or OpenAIProviderConfig() + self._client: Any = None + + async def start(self) -> None: + from openai import AsyncOpenAI + + self._client = AsyncOpenAI(api_key=self._config.api_key, timeout=self._config.timeout_seconds) + + async def close(self) -> None: + self._client = None + + async def generate( + self, + messages: list[LLMMessage], + model: str | None = None, + temperature: float | None = None, + max_tokens: int | None = None, + ) -> LLMResponse: + kwargs = self._build_kwargs(messages, model, temperature, max_tokens) + response = await self._client.chat.completions.create(**kwargs) + return self._to_response(response) + + async def generate_structured( + self, + messages: list[LLMMessage], + schema: type[T], + model: str | None = None, + ) -> T: + kwargs = self._build_kwargs(messages, model) + kwargs["response_format"] = { + "type": "json_schema", + "json_schema": {"name": schema.__name__, "strict": True, "schema": schema.model_json_schema()}, + } + response = await self._client.chat.completions.create(**kwargs) + return schema.model_validate_json(response.choices[0].message.content) + + async def embed(self, text: str, model: str | None = None) -> list[float]: + model = model or "text-embedding-3-small" + response = await self._client.embeddings.create(input=text, model=model) + return response.data[0].embedding + + def _build_kwargs( + self, + messages: list[LLMMessage], + model: str | None = None, + temperature: float | None = None, + max_tokens: int | None = None, + ) -> dict[str, Any]: + return { + "model": model or self._config.default_model, + "messages": [{"role": m.role, "content": m.content} for m in messages], + "temperature": temperature if temperature is not None else self._config.default_temperature, + "max_tokens": max_tokens or self._config.default_max_tokens, + } + + @staticmethod + def _to_response(raw: Any) -> LLMResponse: + choice = raw.choices[0] + usage = raw.usage + return LLMResponse( + content=choice.message.content or "", + finish_reason=choice.finish_reason, + usage={ + "prompt_tokens": usage.prompt_tokens, + "completion_tokens": usage.completion_tokens, + "total_tokens": usage.total_tokens, + } + if usage + else None, + ) diff --git a/services/ai-investigation-service/app/.gitkeep b/infrastructure/openai/tests/__init__.py similarity index 100% rename from services/ai-investigation-service/app/.gitkeep rename to infrastructure/openai/tests/__init__.py diff --git a/infrastructure/openai/tests/test_openai_llm_provider.py b/infrastructure/openai/tests/test_openai_llm_provider.py new file mode 100644 index 0000000..133fce7 --- /dev/null +++ b/infrastructure/openai/tests/test_openai_llm_provider.py @@ -0,0 +1,95 @@ +"""Tests for the OpenAI LLM provider adapter.""" + +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from pydantic import BaseModel, Field + +from infrastructure.openai.openai_llm_provider import OpenAILLMProvider, OpenAIProviderConfig +from shared.contracts.interfaces.llm_provider import LLMMessage + + +class _FakeSchema(BaseModel): + result: str = Field(description="Test field") + + +@pytest.fixture +def config() -> OpenAIProviderConfig: + return OpenAIProviderConfig(api_key="test-key") + + +@pytest.fixture +def provider(config: OpenAIProviderConfig) -> OpenAILLMProvider: + return OpenAILLMProvider(config) + + +class TestOpenAIProviderConfig: + def test_defaults(self) -> None: + cfg = OpenAIProviderConfig() + assert cfg.default_model == "gpt-4o" + assert cfg.default_temperature == 0.1 + assert cfg.default_max_tokens == 4096 + assert cfg.timeout_seconds == 60 + + def test_env_prefix(self) -> None: + assert OpenAIProviderConfig.model_config.get("env_prefix") == "OPENAI_" + + +class TestOpenAILLMProvider: + async def test_start_creates_client(self, provider: OpenAILLMProvider) -> None: + assert provider._client is None + with patch("openai.AsyncOpenAI") as mock: + await provider.start() + mock.assert_called_once_with(api_key="test-key", timeout=60) + assert provider._client is not None + + async def test_close_clears_client(self, provider: OpenAILLMProvider) -> None: + provider._client = MagicMock() + await provider.close() + assert provider._client is None + + async def test_generate_returns_response(self, provider: OpenAILLMProvider) -> None: + provider._client = MagicMock() + mock_choice = MagicMock() + mock_choice.message.content = "Hello" + mock_choice.finish_reason = "stop" + mock_usage = MagicMock() + mock_usage.prompt_tokens = 10 + mock_usage.completion_tokens = 5 + mock_usage.total_tokens = 15 + provider._client.chat.completions.create = AsyncMock( + return_value=MagicMock(choices=[mock_choice], usage=mock_usage) + ) + + messages = [LLMMessage(role="user", content="hi")] + response = await provider.generate(messages) + + assert response.content == "Hello" + assert response.finish_reason == "stop" + assert response.usage == {"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15} + + async def test_generate_structured_returns_parsed_schema(self, provider: OpenAILLMProvider) -> None: + provider._client = MagicMock() + mock_choice = MagicMock() + mock_choice.message.content = '{"result": "ok"}' + mock_choice.finish_reason = "stop" + provider._client.chat.completions.create = AsyncMock(return_value=MagicMock(choices=[mock_choice], usage=None)) + + messages = [LLMMessage(role="user", content="test")] + result = await provider.generate_structured(messages, _FakeSchema) + + assert isinstance(result, _FakeSchema) + assert result.result == "ok" + + async def test_embed_returns_float_list(self, provider: OpenAILLMProvider) -> None: + provider._client = MagicMock() + mock_embedding = MagicMock() + mock_embedding.embedding = [0.1, 0.2, 0.3] + provider._client.embeddings.create = AsyncMock(return_value=MagicMock(data=[mock_embedding])) + + result = await provider.embed("test text") + assert result == [0.1, 0.2, 0.3] + + async def test_generate_raises_if_not_started(self, provider: OpenAILLMProvider) -> None: + with pytest.raises(AttributeError): + await provider.generate([LLMMessage(role="user", content="hi")]) diff --git a/services/ai-investigation-service/app/__init__.py b/services/ai-investigation-service/app/__init__.py new file mode 100644 index 0000000..fc77d14 --- /dev/null +++ b/services/ai-investigation-service/app/__init__.py @@ -0,0 +1 @@ +"""AI investigation service — deterministic RCA pipeline.""" diff --git a/services/ai-investigation-service/app/config.py b/services/ai-investigation-service/app/config.py new file mode 100644 index 0000000..271e753 --- /dev/null +++ b/services/ai-investigation-service/app/config.py @@ -0,0 +1,14 @@ +from pydantic_settings import BaseSettings + + +class InvestigationServiceConfig(BaseSettings): + model_config = {"env_prefix": "INVESTIGATION_"} + + service_name: str = "ai-investigation-service" + environment: str = "development" + debug: bool = False + log_level: str = "INFO" + llm_provider: str = "openai" + llm_model: str = "gpt-4o" + llm_temperature: float = 0.1 + llm_max_tokens: int = 4096 diff --git a/services/ai-investigation-service/app/pipeline.py b/services/ai-investigation-service/app/pipeline.py new file mode 100644 index 0000000..b53c779 --- /dev/null +++ b/services/ai-investigation-service/app/pipeline.py @@ -0,0 +1,41 @@ +"""InvestigationPipeline — deterministic RCA orchestration. + +Assembles incident context into prompts, calls the LLM for structured output, +and returns an InvestigationResult. +""" + +import time + +from prompts.rca import RCAPrompts + +from shared.contracts.interfaces.llm_provider import LLMProvider +from shared.domain.incident.context.models import IncidentContext +from shared.domain.investigation.models import InvestigationResult, RCASummary + + +class InvestigationPipeline: + """Deterministic investigation pipeline. + + Usage: + pipeline = InvestigationPipeline(llm_provider) + result = await pipeline.run(incident_context) + """ + + def __init__(self, llm_provider: LLMProvider) -> None: + self._llm = llm_provider + + async def run(self, context: IncidentContext) -> InvestigationResult: + """Run the full investigation pipeline for a given incident context.""" + start = time.perf_counter() + + messages = RCAPrompts.build_rca_messages(context) + + summary: RCASummary = await self._llm.generate_structured(messages, RCASummary) + + duration = (time.perf_counter() - start) * 1000 + + return InvestigationResult( + summary=summary, + raw_output=None, + duration_ms=round(duration, 2), + ) diff --git a/services/ai-investigation-service/prompts/__init__.py b/services/ai-investigation-service/prompts/__init__.py new file mode 100644 index 0000000..179ea51 --- /dev/null +++ b/services/ai-investigation-service/prompts/__init__.py @@ -0,0 +1,5 @@ +"""Modular prompt templates for AI investigation workflows.""" + +from prompts.rca import RCAPrompts + +__all__ = ["RCAPrompts"] diff --git a/services/ai-investigation-service/prompts/rca.py b/services/ai-investigation-service/prompts/rca.py new file mode 100644 index 0000000..cf25042 --- /dev/null +++ b/services/ai-investigation-service/prompts/rca.py @@ -0,0 +1,91 @@ +"""RCA prompt templates — modular, testable, and version-controlled.""" + +from shared.contracts.interfaces.llm_provider import LLMMessage +from shared.domain.incident.context.models import IncidentContext + + +class RCAPrompts: + """Factory for building RCA prompt messages from incident context.""" + + SYSTEM_PROMPT = """You are a senior Site Reliability Engineer conducting a root cause analysis. + +Your task is to analyse the provided incident context and produce a structured RCA summary. + +Rules: +- Base your analysis strictly on the evidence provided. Do not speculate beyond the data. +- Rank root causes by confidence based on correlation scores, impact chains, and supporting evidence. +- If evidence is insufficient, reflect low confidence rather than guessing. +- Remediation steps must be concrete, actionable, and ordered by priority. +- Use technical precision — be specific about services, signals, and time windows.""" + + @classmethod + def build_rca_messages(cls, context: IncidentContext) -> list[LLMMessage]: + """Build the message list for RCA generation.""" + return [ + LLMMessage(role="system", content=cls.SYSTEM_PROMPT), + LLMMessage(role="user", content=cls._format_context(context)), + ] + + @classmethod + def _format_context(cls, context: IncidentContext) -> str: + """Format the incident context as a structured text prompt.""" + lines: list[str] = [ + f"## Incident: {context.title or 'Untitled'}", + f"ID: {context.incident_id}", + f"Primary Service: {context.primary_service}", + f"Severity: {context.severity}", + f"Detected At: {context.detected_at.isoformat()}", + "", + f"Total Events: {context.event_count}", + f"Services Involved: {context.service_count}", + f"Unique Traces: {context.trace_count}", + "", + ] + + if context.timeline: + timeline = context.timeline + lines.append("### Timeline") + lines.append(f"Duration: {timeline.duration_seconds or 'N/A'} seconds") + lines.append(f"Windows: {timeline.window_count}") + lines.append(f"Total Events: {timeline.total_events}") + lines.append("") + + if context.correlation_groups: + lines.append("### Correlation Groups") + for i, group in enumerate(context.correlation_groups, 1): + services = ", ".join(group.services) if group.services else "unknown" + lines.append( + f" {i}. Score: {group.composite_score} | Signals: {[s.value for s in group.signals]} | Services: [{services}]" + ) + if group.trace_id: + lines.append(f" Trace: {group.trace_id} ({group.span_count} spans)") + lines.append("") + + if context.impacts: + lines.append("### Impact Analysis") + for impact in context.impacts: + lines.append(f" Service: {impact.service}") + if impact.upstream_causes: + lines.append(f" Upstream Causes: {', '.join(impact.upstream_causes)}") + if impact.downstream_impact: + lines.append(f" Downstream Impact: {', '.join(impact.downstream_impact)}") + if impact.propagation_paths: + paths = [" -> ".join(p) for p in impact.propagation_paths] + lines.append(f" Propagation Paths: {' | '.join(paths)}") + lines.append("") + + if context.trace_groups: + lines.append("### Trace Groups") + for tg in context.trace_groups: + services = ", ".join(tg.service_names) if tg.service_names else "N/A" + lines.append( + f" Trace: {tg.trace_id} | Services: [{services}] | Spans: {tg.span_count} | Events: {len(tg.event_ids)}" + ) + lines.append("") + + if context.ungrouped_events: + lines.append(f"### Ungrouped Events: {len(context.ungrouped_events)} events below correlation threshold") + lines.append("") + + lines.append("Provide your analysis following the output schema exactly.") + return "\n".join(lines) diff --git a/services/ai-investigation-service/prompts/.gitkeep b/services/ai-investigation-service/tests/__init__.py similarity index 100% rename from services/ai-investigation-service/prompts/.gitkeep rename to services/ai-investigation-service/tests/__init__.py diff --git a/services/ai-investigation-service/tests/conftest.py b/services/ai-investigation-service/tests/conftest.py new file mode 100644 index 0000000..2203819 --- /dev/null +++ b/services/ai-investigation-service/tests/conftest.py @@ -0,0 +1,10 @@ +"""Shared fixtures for ai-investigation-service tests.""" + +import sys +from pathlib import Path + +# Add service root and project root to sys.path +_service_root = Path(__file__).parent.parent +_project_root = _service_root.parent.parent +sys.path.insert(0, str(_service_root)) +sys.path.insert(0, str(_project_root)) diff --git a/services/ai-investigation-service/tests/test_pipeline.py b/services/ai-investigation-service/tests/test_pipeline.py new file mode 100644 index 0000000..e7720eb --- /dev/null +++ b/services/ai-investigation-service/tests/test_pipeline.py @@ -0,0 +1,114 @@ +"""Tests for InvestigationPipeline.""" + +from datetime import UTC, datetime +from unittest.mock import AsyncMock + +import pytest +from app.pipeline import InvestigationPipeline + +from shared.contracts.interfaces.llm_provider import LLMProvider +from shared.domain.incident.context.models import IncidentContext +from shared.domain.investigation.models import ( + IncidentProgression, + InvestigationResult, + RCASummary, + RemediationStep, + RootCause, +) + + +@pytest.fixture +def mock_llm() -> LLMProvider: + provider = AsyncMock(spec=LLMProvider) + + async def structured_side_effect(messages, schema, model=None): + rc = RootCause(service="api", confidence=0.85, evidence=["5xx spike"], explanation="timeout cascade") + prog = IncidentProgression( + sequence=["deploy", "degradation"], timeline_summary="post-deploy regression", key_transitions=["paging"] + ) + step = RemediationStep(action="rollback", service="api", priority="critical", expected_impact="restore service") + return RCASummary( + incident_id="inc-001", + title="API degradation", + root_causes=[rc], + progression=prog, + remediation=[step], + overall_confidence=0.85, + ) + + provider.generate_structured = AsyncMock(side_effect=structured_side_effect) + return provider + + +@pytest.fixture +def incident_context() -> IncidentContext: + return IncidentContext( + incident_id="inc-001", + primary_service="api", + severity="CRITICAL", + title="API degradation", + detected_at=datetime.now(UTC), + ) + + +class TestInvestigationPipeline: + async def test_run_returns_investigation_result( + self, mock_llm: LLMProvider, incident_context: IncidentContext + ) -> None: + pipeline = InvestigationPipeline(mock_llm) + result = await pipeline.run(incident_context) + assert isinstance(result, InvestigationResult) + assert isinstance(result.summary, RCASummary) + + async def test_run_populates_summary_fields(self, mock_llm: LLMProvider, incident_context: IncidentContext) -> None: + pipeline = InvestigationPipeline(mock_llm) + result = await pipeline.run(incident_context) + assert result.summary.incident_id == "inc-001" + assert result.summary.title == "API degradation" + assert len(result.summary.root_causes) == 1 + assert result.summary.root_causes[0].service == "api" + assert len(result.summary.remediation) == 1 + + async def test_run_measures_duration(self, mock_llm: LLMProvider, incident_context: IncidentContext) -> None: + pipeline = InvestigationPipeline(mock_llm) + result = await pipeline.run(incident_context) + assert result.duration_ms > 0 + + async def test_run_passes_context_to_llm(self, mock_llm: LLMProvider, incident_context: IncidentContext) -> None: + from typing import cast + from unittest.mock import AsyncMock + + pipeline = InvestigationPipeline(mock_llm) + await pipeline.run(incident_context) + generated = cast(AsyncMock, mock_llm.generate_structured) + generated.assert_awaited_once() + assert generated.await_args is not None + args, kwargs = generated.await_args + assert len(args[0]) == 2 # system + user message + assert args[1] == RCASummary + + async def test_run_with_complex_context(self, mock_llm: LLMProvider) -> None: + from shared.domain.correlation.grouping.models import TraceGroup + + ctx = IncidentContext( + incident_id="inc-002", + primary_service="gateway", + severity="HIGH", + title="Gateway timeout", + detected_at=datetime.now(UTC), + event_count=100, + service_count=3, + trace_count=5, + correlation_groups=[], + trace_groups=[ + TraceGroup( + trace_id="t1", + event_ids=["a", "b", "c"], + service_names=["gateway"], + span_count=10, + ) + ], + ) + pipeline = InvestigationPipeline(mock_llm) + result = await pipeline.run(ctx) + assert result.summary.incident_id == "inc-001" # mock returns inc-001 regardless of context diff --git a/services/ai-investigation-service/tests/test_prompts.py b/services/ai-investigation-service/tests/test_prompts.py new file mode 100644 index 0000000..e72dbf6 --- /dev/null +++ b/services/ai-investigation-service/tests/test_prompts.py @@ -0,0 +1,122 @@ +"""Tests for RCA prompt templates.""" + +from typing import Any + +from prompts.rca import RCAPrompts + +from shared.domain.incident.context.models import ( + AggregatedCorrelationGroup, + AggregatedTimeline, + ImpactAnalysis, + IncidentContext, +) + + +def _make_context(**overrides: Any) -> IncidentContext: + from datetime import UTC, datetime + + defaults: dict = dict( + incident_id="inc-001", + primary_service="api", + severity="CRITICAL", + title="API degradation", + detected_at=datetime.now(UTC), + ) + defaults.update(overrides) + return IncidentContext(**defaults) + + +class TestRCAPrompts: + def test_build_rca_messages_returns_system_and_user(self) -> None: + ctx = _make_context() + messages = RCAPrompts.build_rca_messages(ctx) + assert len(messages) == 2 + assert messages[0].role == "system" + assert messages[1].role == "user" + + def test_system_prompt_includes_sre_context(self) -> None: + ctx = _make_context() + messages = RCAPrompts.build_rca_messages(ctx) + assert "senior Site Reliability Engineer" in messages[0].content + + def test_format_context_includes_incident_metadata(self) -> None: + ctx = _make_context(incident_id="inc-999", title="Test incident", severity="HIGH") + text = RCAPrompts._format_context(ctx) + assert "inc-999" in text + assert "Test incident" in text + assert "HIGH" in text + + def test_format_context_with_timeline(self) -> None: + from datetime import UTC, datetime + + ctx = _make_context() + ctx.timeline = AggregatedTimeline( + incident_id="inc-001", + primary_service="api", + total_events=10, + window_count=2, + start_time=datetime.now(UTC), + end_time=datetime.now(UTC), + duration_seconds=300.0, + ) + text = RCAPrompts._format_context(ctx) + assert "300" in text + assert "10" in text + + def test_format_context_with_correlation_groups(self) -> None: + from shared.domain.correlation.enums import CorrelationSignal + + ctx = _make_context() + ctx.correlation_groups = [ + AggregatedCorrelationGroup( + group_id="g1", + event_ids=["a", "b"], + composite_score=0.95, + signals=[CorrelationSignal.TRACE_MATCH], + services=["api", "db"], + ) + ] + ctx.event_count = 5 + text = RCAPrompts._format_context(ctx) + assert "0.95" in text + assert "trace_match" in text + assert "api" in text + + def test_format_context_with_impacts(self) -> None: + ctx = _make_context() + ctx.impacts = [ + ImpactAnalysis( + service="api", + upstream_causes=["gateway"], + downstream_impact=["cache"], + propagation_paths=[["gateway", "api", "cache"]], + ) + ] + text = RCAPrompts._format_context(ctx) + assert "gateway" in text + assert "cache" in text + assert "propagation" in text.lower() + + def test_format_context_with_trace_groups(self) -> None: + from shared.domain.correlation.grouping.models import TraceGroup + + ctx = _make_context() + ctx.trace_groups = [ + TraceGroup(trace_id="t1", event_ids=["a", "b"], service_names=["api-gateway"], span_count=5), + ] + text = RCAPrompts._format_context(ctx) + assert "t1" in text + assert "api-gateway" in text + + def test_format_context_with_ungrouped_events(self) -> None: + ctx = _make_context() + ctx.ungrouped_events = ["e1", "e2"] + text = RCAPrompts._format_context(ctx) + assert "ungrouped" in text.lower() + assert "e1" in text or "2 events" in text + + def test_format_context_empty_minimal(self) -> None: + ctx = _make_context() + text = RCAPrompts._format_context(ctx) + assert "## Incident" in text + assert "Incident" in text diff --git a/shared/domain/investigation/__init__.py b/shared/domain/investigation/__init__.py new file mode 100644 index 0000000..357ce7e --- /dev/null +++ b/shared/domain/investigation/__init__.py @@ -0,0 +1,17 @@ +"""Domain models for AI-powered incident investigation.""" + +from shared.domain.investigation.models import ( + IncidentProgression, + InvestigationResult, + RCASummary, + RemediationStep, + RootCause, +) + +__all__ = [ + "IncidentProgression", + "InvestigationResult", + "RCASummary", + "RemediationStep", + "RootCause", +] diff --git a/shared/domain/investigation/models.py b/shared/domain/investigation/models.py new file mode 100644 index 0000000..56e3785 --- /dev/null +++ b/shared/domain/investigation/models.py @@ -0,0 +1,53 @@ +from datetime import UTC, datetime + +from pydantic import BaseModel, Field + + +class RootCause(BaseModel): + """A probable root cause identified during investigation.""" + + service: str = Field(description="Service identified as the root cause.") + confidence: float = Field(ge=0.0, le=1.0, description="Confidence in this root cause identification.") + evidence: list[str] = Field(description="Evidence supporting this root cause.") + explanation: str = Field(description="Natural language explanation of why this is the root cause.") + + +class IncidentProgression(BaseModel): + """Chronological narrative of how the incident unfolded.""" + + sequence: list[str] = Field(description="Chronological sequence of key events.") + timeline_summary: str = Field(description="Narrative summary of how the incident unfolded over time.") + key_transitions: list[str] = Field(description="Key state changes or escalation points.") + + +class RemediationStep(BaseModel): + """A suggested remediation action.""" + + action: str = Field(description="Action to take.") + service: str = Field(description="Target service for this action.") + priority: str = Field(description="Priority level: critical, high, medium, low.") + expected_impact: str = Field(description="Expected result of this action.") + + +class RCASummary(BaseModel): + """Complete root cause analysis summary for an incident.""" + + incident_id: str = Field(description="Incident identifier.") + title: str = Field(description="Short human-readable incident summary.") + root_causes: list[RootCause] = Field(description="Identified probable root causes, ranked by confidence.") + progression: IncidentProgression = Field(description="Incident progression narrative.") + remediation: list[RemediationStep] = Field( + default_factory=list, description="Suggested remediation steps, ordered by priority." + ) + overall_confidence: float = Field(ge=0.0, le=1.0, description="Overall confidence in the analysis.") + generated_at: datetime = Field( + default_factory=lambda: datetime.now(UTC), description="When this summary was generated." + ) + + +class InvestigationResult(BaseModel): + """The complete output of an investigation pipeline run.""" + + summary: RCASummary = Field(description="The structured RCA summary.") + raw_output: str | None = Field(default=None, description="Raw LLM response text for debugging.") + duration_ms: float = Field(default=0.0, description="Pipeline execution time in milliseconds.") diff --git a/services/ai-investigation-service/tests/.gitkeep b/shared/domain/investigation/tests/__init__.py similarity index 100% rename from services/ai-investigation-service/tests/.gitkeep rename to shared/domain/investigation/tests/__init__.py diff --git a/shared/domain/investigation/tests/test_models.py b/shared/domain/investigation/tests/test_models.py new file mode 100644 index 0000000..20f40b5 --- /dev/null +++ b/shared/domain/investigation/tests/test_models.py @@ -0,0 +1,88 @@ +"""Tests for investigation domain models.""" + +from shared.domain.investigation.models import ( + IncidentProgression, + InvestigationResult, + RCASummary, + RemediationStep, + RootCause, +) + + +class TestRootCause: + def test_minimal(self) -> None: + rc = RootCause( + service="api-gateway", + confidence=0.85, + evidence=["error budget burned"], + explanation="gateway timeout cascade", + ) + assert rc.service == "api-gateway" + assert rc.confidence == 0.85 + assert len(rc.evidence) == 1 + + def test_confidence_rejects_out_of_range(self) -> None: + import pytest + from pydantic import ValidationError + + with pytest.raises(ValidationError): + RootCause(service="db", confidence=1.5, evidence=[], explanation="") + + +class TestRemediationStep: + def test_minimal(self) -> None: + step = RemediationStep(action="restart", service="api", priority="high", expected_impact="resume serving") + assert step.priority == "high" + + +class TestIncidentProgression: + def test_minimal(self) -> None: + prog = IncidentProgression( + sequence=["error spike", "degradation"], timeline_summary="escalated quickly", key_transitions=["paged"] + ) + assert len(prog.sequence) == 2 + + +class TestRCASummary: + def test_minimal(self) -> None: + rc = RootCause(service="auth", confidence=0.9, evidence=["5xx spike"], explanation="auth timeout") + prog = IncidentProgression( + sequence=["deploy", "errors"], timeline_summary="post-deploy regression", key_transitions=["rollback"] + ) + summary = RCASummary( + incident_id="inc-001", + title="Auth degradation", + root_causes=[rc], + progression=prog, + overall_confidence=0.9, + ) + assert summary.incident_id == "inc-001" + assert len(summary.root_causes) == 1 + assert summary.generated_at.tzinfo is not None + + def test_with_remediation(self) -> None: + rc = RootCause(service="db", confidence=0.7, evidence=["slow queries"], explanation="index missing") + prog = IncidentProgression(sequence=["latency"], timeline_summary="degraded", key_transitions=[]) + step = RemediationStep( + action="add index", service="db", priority="critical", expected_impact="restore performance" + ) + summary = RCASummary( + incident_id="inc-002", + title="DB slowdown", + root_causes=[rc], + progression=prog, + remediation=[step], + overall_confidence=0.7, + ) + assert len(summary.remediation) == 1 + + +class TestInvestigationResult: + def test_minimal(self) -> None: + rc = RootCause(service="srv", confidence=0.5, evidence=["err"], explanation="cause") + prog = IncidentProgression(sequence=["a"], timeline_summary="b", key_transitions=["c"]) + summary = RCASummary(incident_id="i1", title="t", root_causes=[rc], progression=prog, overall_confidence=0.5) + result = InvestigationResult(summary=summary) + assert result.summary.incident_id == "i1" + assert result.duration_ms == 0.0 + assert result.raw_output is None From 65d4c07f8083e22d06788823e97c54652618e0d3 Mon Sep 17 00:00:00 2001 From: saurabh batham Date: Sun, 14 Jun 2026 21:14:19 +0530 Subject: [PATCH 2/2] fix(ai-investigation): move prompts into app/ to fix CI import resolution ModuleNotFoundError occurred in CI because app/pipeline.py used absolute import 'from prompts.rca import RCAPrompts' which relied on sys.path containing the service root. Moved prompts/ into app/ so imports always resolve via 'from app.prompts.rca import RCAPrompts' within the app package, regardless of how the test image is built. --- services/ai-investigation-service/Dockerfile | 1 - services/ai-investigation-service/app/pipeline.py | 3 +-- .../ai-investigation-service/{ => app}/prompts/__init__.py | 2 +- services/ai-investigation-service/{ => app}/prompts/rca.py | 0 services/ai-investigation-service/tests/test_prompts.py | 2 +- 5 files changed, 3 insertions(+), 5 deletions(-) rename services/ai-investigation-service/{ => app}/prompts/__init__.py (69%) rename services/ai-investigation-service/{ => app}/prompts/rca.py (100%) diff --git a/services/ai-investigation-service/Dockerfile b/services/ai-investigation-service/Dockerfile index 96e2427..38fbdef 100644 --- a/services/ai-investigation-service/Dockerfile +++ b/services/ai-investigation-service/Dockerfile @@ -9,7 +9,6 @@ RUN pip install --no-cache-dir -r requirements.txt \ && pip install --no-cache-dir --no-deps . COPY services/ai-investigation-service/app ./app -COPY services/ai-investigation-service/prompts ./prompts COPY services/ai-investigation-service/workflows ./workflows CMD ["python", "-c", "print('ai-investigation-service scaffold ready')"] diff --git a/services/ai-investigation-service/app/pipeline.py b/services/ai-investigation-service/app/pipeline.py index b53c779..ed10cbc 100644 --- a/services/ai-investigation-service/app/pipeline.py +++ b/services/ai-investigation-service/app/pipeline.py @@ -6,8 +6,7 @@ import time -from prompts.rca import RCAPrompts - +from app.prompts.rca import RCAPrompts from shared.contracts.interfaces.llm_provider import LLMProvider from shared.domain.incident.context.models import IncidentContext from shared.domain.investigation.models import InvestigationResult, RCASummary diff --git a/services/ai-investigation-service/prompts/__init__.py b/services/ai-investigation-service/app/prompts/__init__.py similarity index 69% rename from services/ai-investigation-service/prompts/__init__.py rename to services/ai-investigation-service/app/prompts/__init__.py index 179ea51..29d20e0 100644 --- a/services/ai-investigation-service/prompts/__init__.py +++ b/services/ai-investigation-service/app/prompts/__init__.py @@ -1,5 +1,5 @@ """Modular prompt templates for AI investigation workflows.""" -from prompts.rca import RCAPrompts +from app.prompts.rca import RCAPrompts __all__ = ["RCAPrompts"] diff --git a/services/ai-investigation-service/prompts/rca.py b/services/ai-investigation-service/app/prompts/rca.py similarity index 100% rename from services/ai-investigation-service/prompts/rca.py rename to services/ai-investigation-service/app/prompts/rca.py diff --git a/services/ai-investigation-service/tests/test_prompts.py b/services/ai-investigation-service/tests/test_prompts.py index e72dbf6..0730143 100644 --- a/services/ai-investigation-service/tests/test_prompts.py +++ b/services/ai-investigation-service/tests/test_prompts.py @@ -2,7 +2,7 @@ from typing import Any -from prompts.rca import RCAPrompts +from app.prompts.rca import RCAPrompts from shared.domain.incident.context.models import ( AggregatedCorrelationGroup,