From a8526a4212369aa1ce27b706873791e8fe501d56 Mon Sep 17 00:00:00 2001 From: saurabh batham Date: Sun, 14 Jun 2026 17:47:00 +0530 Subject: [PATCH 1/2] feat: implement trace-aware event grouping with span relation scoring - Add span_id and parent_span_id fields to TimelineEvent model - Create TraceGroupingService with SpanNode/TraceTree/TraceGroup models for reconstructing span hierarchies from event telemetry - Implement SpanRelationStrategy with hierarchical scoring: parent-child (1.0) > sibling (0.8) > same trace (0.5) - Wire SpanRelationStrategy into CorrelationEngine default pipeline - Fix edge cases: orphan spans with unknown parents treated as roots, events without span_ids fall back to same-trace scoring RP-15 --- shared/domain/correlation/engine.py | 2 + shared/domain/correlation/enums.py | 3 + .../domain/correlation/grouping/__init__.py | 11 ++ shared/domain/correlation/grouping/models.py | 54 ++++++++ .../correlation/grouping/tests/__init__.py | 0 .../grouping/tests/test_trace_grouping.py | 131 ++++++++++++++++++ .../correlation/grouping/trace_grouping.py | 91 ++++++++++++ .../domain/correlation/strategies/__init__.py | 2 + .../correlation/strategies/span_relation.py | 79 +++++++++++ .../correlation/tests/test_span_relation.py | 112 +++++++++++++++ shared/domain/timeline/models.py | 2 + 11 files changed, 487 insertions(+) create mode 100644 shared/domain/correlation/grouping/__init__.py create mode 100644 shared/domain/correlation/grouping/models.py create mode 100644 shared/domain/correlation/grouping/tests/__init__.py create mode 100644 shared/domain/correlation/grouping/tests/test_trace_grouping.py create mode 100644 shared/domain/correlation/grouping/trace_grouping.py create mode 100644 shared/domain/correlation/strategies/span_relation.py create mode 100644 shared/domain/correlation/tests/test_span_relation.py diff --git a/shared/domain/correlation/engine.py b/shared/domain/correlation/engine.py index 0fb3084..865ee1a 100644 --- a/shared/domain/correlation/engine.py +++ b/shared/domain/correlation/engine.py @@ -5,6 +5,7 @@ DependencyStrategy, ErrorSignatureStrategy, RequestIdStrategy, + SpanRelationStrategy, TimeWindowStrategy, TraceIdStrategy, ) @@ -26,6 +27,7 @@ def __init__( built: list[CorrelationStrategy] = [ TimeWindowStrategy(window_seconds=default_window_seconds), TraceIdStrategy(), + SpanRelationStrategy(), RequestIdStrategy(), ErrorSignatureStrategy(), ] diff --git a/shared/domain/correlation/enums.py b/shared/domain/correlation/enums.py index 72a1e29..2f1603c 100644 --- a/shared/domain/correlation/enums.py +++ b/shared/domain/correlation/enums.py @@ -4,6 +4,7 @@ class CorrelationStrategyType(StrEnum): TIME_WINDOW = "time_window" TRACE_ID = "trace_id" + SPAN_RELATION = "span_relation" REQUEST_ID = "request_id" DEPENDENCY = "dependency" ERROR_SIGNATURE = "error_signature" @@ -12,6 +13,8 @@ class CorrelationStrategyType(StrEnum): class CorrelationSignal(StrEnum): TIME_PROXIMITY = "time_proximity" TRACE_MATCH = "trace_match" + SPAN_PARENT_CHILD = "span_parent_child" + SPAN_SIBLING = "span_sibling" REQUEST_MATCH = "request_match" DEPENDENCY_CHAIN = "dependency_chain" ERROR_PATTERN = "error_pattern" diff --git a/shared/domain/correlation/grouping/__init__.py b/shared/domain/correlation/grouping/__init__.py new file mode 100644 index 0000000..bb872c2 --- /dev/null +++ b/shared/domain/correlation/grouping/__init__.py @@ -0,0 +1,11 @@ +"""Trace-aware event grouping for telemetry correlation.""" + +from shared.domain.correlation.grouping.models import SpanNode, TraceGroup, TraceTree +from shared.domain.correlation.grouping.trace_grouping import TraceGroupingService + +__all__ = [ + "SpanNode", + "TraceGroup", + "TraceTree", + "TraceGroupingService", +] diff --git a/shared/domain/correlation/grouping/models.py b/shared/domain/correlation/grouping/models.py new file mode 100644 index 0000000..ed93901 --- /dev/null +++ b/shared/domain/correlation/grouping/models.py @@ -0,0 +1,54 @@ +from pydantic import BaseModel, Field + +from shared.domain.timeline.models import TimelineEvent + + +class SpanNode(BaseModel): + """A span node in a trace tree, with parent-child relationships.""" + + span_id: str = Field(description="Span identifier (16 hex chars).") + trace_id: str = Field(description="Trace identifier (32 hex chars).") + parent_span_id: str | None = Field(default=None, description="Parent span ID, if this is a child span.") + service_name: str = Field(default="", description="Service that produced this span.") + event_ids: list[str] = Field(default_factory=list, description="TimelineEvent IDs mapped to this span.") + children: list["SpanNode"] = Field(default_factory=list, description="Child spans.") + + @property + def is_root(self) -> bool: + return self.parent_span_id is None + + +class TraceTree(BaseModel): + """A reconstructed trace tree from a set of timeline events.""" + + trace_id: str = Field(description="Trace identifier (32 hex chars).") + root_spans: list[SpanNode] = Field(default_factory=list, description="Root spans (no parent).") + all_spans: list[SpanNode] = Field(default_factory=list, description="Flat list of all spans.") + event_ids: list[str] = Field(default_factory=list, description="All event IDs belonging to this trace.") + service_names: list[str] = Field(default_factory=list, description="Unique services involved.") + + @property + def span_count(self) -> int: + return len(self.all_spans) + + @property + def depth(self) -> int: + if not self.root_spans: + return 0 + return max(_max_depth(r) for r in self.root_spans) + + +class TraceGroup(BaseModel): + """A group of events sharing the same trace, with optional span relationship metadata.""" + + trace_id: str = Field(description="Trace identifier.") + event_ids: list[str] = Field(description="Event IDs in this trace group.") + tree: TraceTree | None = Field(default=None, description="Reconstructed span tree, if available.") + service_names: list[str] = Field(default_factory=list, description="Services involved.") + span_count: int = Field(default=0, description="Number of distinct spans.") + + +def _max_depth(node: SpanNode) -> int: + if not node.children: + return 1 + return 1 + max(_max_depth(c) for c in node.children) diff --git a/shared/domain/correlation/grouping/tests/__init__.py b/shared/domain/correlation/grouping/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/shared/domain/correlation/grouping/tests/test_trace_grouping.py b/shared/domain/correlation/grouping/tests/test_trace_grouping.py new file mode 100644 index 0000000..db46cb8 --- /dev/null +++ b/shared/domain/correlation/grouping/tests/test_trace_grouping.py @@ -0,0 +1,131 @@ +from datetime import datetime, timezone + +from shared.domain.correlation.grouping import TraceGroupingService +from shared.domain.correlation.grouping.models import SpanNode, TraceTree +from shared.domain.timeline.enums import TimelineEventCategory, TimelineEventSource +from shared.domain.timeline.models import TimelineEvent + + +def _event( + event_id: str, + trace_id: str | None = None, + span_id: str | None = None, + parent_span_id: str | None = None, + service: str = "api", + ts_offset: int = 0, +) -> TimelineEvent: + base = datetime(2026, 6, 14, 10, 0, 0, tzinfo=timezone.utc) + return TimelineEvent( + event_id=event_id, + category=TimelineEventCategory.METRIC_ANOMALY, + source=TimelineEventSource.TELEMETRY, + timestamp=base, + service_name=service, + title=f"event {event_id}", + trace_id=trace_id, + span_id=span_id, + parent_span_id=parent_span_id, + ) + + +class TestTraceGroupingService: + def setup_method(self) -> None: + self.service = TraceGroupingService() + + async def test_empty_events_returns_empty(self) -> None: + trees = self.service.build_trace_trees([]) + assert trees == [] + + async def test_events_without_trace_id_are_ignored(self) -> None: + trees = self.service.build_trace_trees([_event("a"), _event("b")]) + assert trees == [] + + async def test_single_trace_single_span(self) -> None: + trees = self.service.build_trace_trees([ + _event("a", trace_id="t1", span_id="s1"), + ]) + assert len(trees) == 1 + tree = trees[0] + assert tree.trace_id == "t1" + assert tree.span_count == 1 + assert len(tree.root_spans) == 1 + assert tree.root_spans[0].span_id == "s1" + + async def test_parent_child_span_hierarchy(self) -> None: + trees = self.service.build_trace_trees([ + _event("a", trace_id="t1", span_id="s1"), + _event("b", trace_id="t1", span_id="s2", parent_span_id="s1"), + _event("c", trace_id="t1", span_id="s3", parent_span_id="s2"), + ]) + assert len(trees) == 1 + tree = trees[0] + assert tree.span_count == 3 + assert len(tree.root_spans) == 1 + root = tree.root_spans[0] + assert root.span_id == "s1" + assert len(root.children) == 1 + assert root.children[0].span_id == "s2" + assert len(root.children[0].children) == 1 + assert root.children[0].children[0].span_id == "s3" + + async def test_multiple_root_spans(self) -> None: + trees = self.service.build_trace_trees([ + _event("a", trace_id="t1", span_id="s1"), + _event("b", trace_id="t1", span_id="s2"), + ]) + assert len(trees) == 1 + tree = trees[0] + assert len(tree.root_spans) == 2 + + async def test_multiple_traces(self) -> None: + trees = self.service.build_trace_trees([ + _event("a", trace_id="t1", span_id="s1"), + _event("b", trace_id="t2", span_id="s2"), + ]) + assert len(trees) == 2 + trace_ids = {t.trace_id for t in trees} + assert trace_ids == {"t1", "t2"} + + async def test_multiple_events_per_span(self) -> None: + trees = self.service.build_trace_trees([ + _event("a", trace_id="t1", span_id="s1"), + _event("b", trace_id="t1", span_id="s1"), + ]) + assert len(trees) == 1 + tree = trees[0] + root = tree.root_spans[0] + assert sorted(root.event_ids) == ["a", "b"] + + async def test_depth_calculation(self) -> None: + trees = self.service.build_trace_trees([ + _event("a", trace_id="t1", span_id="s1"), + _event("b", trace_id="t1", span_id="s2", parent_span_id="s1"), + _event("c", trace_id="t1", span_id="s3", parent_span_id="s2"), + ]) + assert trees[0].depth == 3 + + +class TestTraceGroup: + def setup_method(self) -> None: + self.service = TraceGroupingService() + + async def test_build_trace_groups(self) -> None: + groups = self.service.build_trace_groups([ + _event("a", trace_id="t1", span_id="s1", service="api"), + _event("b", trace_id="t1", span_id="s2", service="db"), + ]) + assert len(groups) == 1 + group = groups[0] + assert group.trace_id == "t1" + assert sorted(group.event_ids) == ["a", "b"] + assert sorted(group.service_names) == ["api", "db"] + assert group.span_count == 2 + assert group.tree is not None + + async def test_trace_group_no_spans(self) -> None: + groups = self.service.build_trace_groups([ + _event("a", trace_id="t1"), + _event("b", trace_id="t1"), + ]) + assert len(groups) == 1 + assert groups[0].span_count == 0 diff --git a/shared/domain/correlation/grouping/trace_grouping.py b/shared/domain/correlation/grouping/trace_grouping.py new file mode 100644 index 0000000..6bdc1f6 --- /dev/null +++ b/shared/domain/correlation/grouping/trace_grouping.py @@ -0,0 +1,91 @@ +"""Service for building trace trees from timeline events.""" + +from collections import defaultdict + +from shared.domain.correlation.grouping.models import SpanNode, TraceGroup, TraceTree +from shared.domain.timeline.models import TimelineEvent + + +class TraceGroupingService: + """Groups timeline events by trace ID and reconstructs span hierarchies.""" + + def build_trace_trees(self, events: list[TimelineEvent]) -> list[TraceTree]: + """Build a list of TraceTree objects from a set of timeline events.""" + by_trace: dict[str, list[TimelineEvent]] = defaultdict(list) + for ev in events: + if ev.trace_id: + by_trace[ev.trace_id].append(ev) + + return [self._build_single_tree(trace_id, evs) for trace_id, evs in by_trace.items()] + + def build_trace_groups(self, events: list[TimelineEvent]) -> list[TraceGroup]: + """Build TraceGroup objects (lighter weight than full trees when spans aren't needed).""" + by_trace: dict[str, list[TimelineEvent]] = defaultdict(list) + for ev in events: + if ev.trace_id: + by_trace[ev.trace_id].append(ev) + + groups: list[TraceGroup] = [] + for trace_id, evs in by_trace.items(): + span_ids = {ev.span_id for ev in evs if ev.span_id} + services = list({ev.service_name for ev in evs}) + tree = self._build_single_tree(trace_id, evs) + groups.append( + TraceGroup( + trace_id=trace_id, + event_ids=[ev.event_id for ev in evs], + tree=tree, + service_names=services, + span_count=len(span_ids), + ) + ) + return groups + + def _build_single_tree(self, trace_id: str, events: list[TimelineEvent]) -> TraceTree: + """Reconstruct a TraceTree from events belonging to a single trace.""" + span_map: dict[str, SpanNode] = {} + event_map: dict[str, list[str]] = defaultdict(list) + services: set[str] = set() + all_event_ids: list[str] = [] + + for ev in events: + all_event_ids.append(ev.event_id) + if ev.service_name: + services.add(ev.service_name) + if ev.span_id: + event_map[ev.span_id].append(ev.event_id) + else: + continue + + for ev in events: + if not ev.span_id: + continue + if ev.span_id not in span_map: + span_map[ev.span_id] = SpanNode( + span_id=ev.span_id, + trace_id=trace_id, + parent_span_id=ev.parent_span_id, + service_name=ev.service_name or "", + event_ids=event_map.get(ev.span_id, []), + ) + + children_map: dict[str, list[SpanNode]] = defaultdict(list) + for span in span_map.values(): + if span.parent_span_id and span.parent_span_id in span_map: + children_map[span.parent_span_id].append(span) + + for span in span_map.values(): + span.children = children_map.get(span.span_id, []) + + root_spans = [ + s for s in span_map.values() + if s.parent_span_id is None or s.parent_span_id not in span_map + ] + + return TraceTree( + trace_id=trace_id, + root_spans=root_spans, + all_spans=list(span_map.values()), + event_ids=all_event_ids, + service_names=sorted(services), + ) diff --git a/shared/domain/correlation/strategies/__init__.py b/shared/domain/correlation/strategies/__init__.py index 9ec590c..f621af0 100644 --- a/shared/domain/correlation/strategies/__init__.py +++ b/shared/domain/correlation/strategies/__init__.py @@ -4,6 +4,7 @@ from shared.domain.correlation.strategies.dependency import DependencyStrategy from shared.domain.correlation.strategies.error_signature import ErrorSignatureStrategy from shared.domain.correlation.strategies.request_id import RequestIdStrategy +from shared.domain.correlation.strategies.span_relation import SpanRelationStrategy from shared.domain.correlation.strategies.time_window import TimeWindowStrategy from shared.domain.correlation.strategies.trace_id import TraceIdStrategy @@ -12,6 +13,7 @@ "DependencyStrategy", "ErrorSignatureStrategy", "RequestIdStrategy", + "SpanRelationStrategy", "TimeWindowStrategy", "TraceIdStrategy", ] diff --git a/shared/domain/correlation/strategies/span_relation.py b/shared/domain/correlation/strategies/span_relation.py new file mode 100644 index 0000000..fb35e53 --- /dev/null +++ b/shared/domain/correlation/strategies/span_relation.py @@ -0,0 +1,79 @@ +from shared.domain.correlation.enums import CorrelationSignal, CorrelationStrategyType +from shared.domain.correlation.grouping import TraceGroupingService +from shared.domain.correlation.grouping.models import SpanNode, TraceTree +from shared.domain.correlation.models import CorrelationContext, CorrelationMatch +from shared.domain.correlation.strategies.base import CorrelationStrategy + + +class SpanRelationStrategy(CorrelationStrategy): + """Scores event pairs based on their span relationship within a trace. + + Hierarchy (highest → lowest score): + - Parent-child spans (direct call relationship): 1.0 + - Sibling spans (same parent): 0.8 + - Same trace, no direct span relation: 0.5 + """ + + strategy_type = CorrelationStrategyType.SPAN_RELATION + weight = 0.85 + + PARENT_CHILD_SCORE = 1.0 + SIBLING_SCORE = 0.8 + SAME_TRACE_SCORE = 0.5 + + def __init__(self, grouping_service: TraceGroupingService | None = None) -> None: + self._grouping = grouping_service or TraceGroupingService() + + async def correlate(self, context: CorrelationContext) -> list[CorrelationMatch]: + matches: list[CorrelationMatch] = [] + trace_groups = self._grouping.build_trace_groups(context.events) + + for group in trace_groups: + tree = group.tree + if tree is None: + continue + + event_ids = group.event_ids + if len(event_ids) < 2: + continue + + for i in range(len(event_ids)): + for j in range(i + 1, len(event_ids)): + eid_a = event_ids[i] + eid_b = event_ids[j] + score, signal = self._score_pair(eid_a, eid_b, tree) + matches.append( + CorrelationMatch( + event_id_a=eid_a, + event_id_b=eid_b, + strategy_type=self.strategy_type, + signal=signal, + score=score, + metadata={"trace_id": group.trace_id}, + ) + ) + + return matches + + def _score_pair(self, eid_a: str, eid_b: str, tree: TraceTree) -> tuple[float, CorrelationSignal]: + spans_a = [s for s in tree.all_spans if eid_a in s.event_ids] + spans_b = [s for s in tree.all_spans if eid_b in s.event_ids] + + if not spans_a or not spans_b: + return self.SAME_TRACE_SCORE, CorrelationSignal.TRACE_MATCH + + for sa in spans_a: + for sb in spans_b: + if sa.parent_span_id == sb.span_id: + return self.PARENT_CHILD_SCORE, CorrelationSignal.SPAN_PARENT_CHILD + if sb.parent_span_id == sa.span_id: + return self.PARENT_CHILD_SCORE, CorrelationSignal.SPAN_PARENT_CHILD + + for sa in spans_a: + for sb in spans_b: + if (sa.parent_span_id is not None + and sb.parent_span_id is not None + and sa.parent_span_id == sb.parent_span_id): + return self.SIBLING_SCORE, CorrelationSignal.SPAN_SIBLING + + return self.SAME_TRACE_SCORE, CorrelationSignal.TRACE_MATCH diff --git a/shared/domain/correlation/tests/test_span_relation.py b/shared/domain/correlation/tests/test_span_relation.py new file mode 100644 index 0000000..b38e643 --- /dev/null +++ b/shared/domain/correlation/tests/test_span_relation.py @@ -0,0 +1,112 @@ +from datetime import datetime, timezone + +from shared.domain.correlation.enums import CorrelationSignal, CorrelationStrategyType +from shared.domain.correlation.models import CorrelationContext +from shared.domain.correlation.strategies.span_relation import SpanRelationStrategy +from shared.domain.timeline.enums import TimelineEventCategory, TimelineEventSource +from shared.domain.timeline.models import TimelineEvent + +TRACE = "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa0" + + +def _event( + event_id: str, + trace_id: str | None = TRACE, + span_id: str | None = None, + parent_span_id: str | None = None, +) -> TimelineEvent: + return TimelineEvent( + event_id=event_id, + category=TimelineEventCategory.METRIC_ANOMALY, + source=TimelineEventSource.TELEMETRY, + timestamp=datetime(2026, 6, 14, 10, 0, 0, tzinfo=timezone.utc), + service_name="api", + title=f"event {event_id}", + trace_id=trace_id, + span_id=span_id, + parent_span_id=parent_span_id, + ) + + +class TestSpanRelationStrategy: + def setup_method(self) -> None: + self.strategy = SpanRelationStrategy() + + async def test_empty_events(self) -> None: + ctx = CorrelationContext(events=[]) + matches = await self.strategy.correlate(ctx) + assert matches == [] + + async def test_single_event_no_match(self) -> None: + ctx = CorrelationContext(events=[_event("a", span_id="s1")]) + matches = await self.strategy.correlate(ctx) + assert matches == [] + + async def test_parent_child_scores_highest(self) -> None: + ctx = CorrelationContext(events=[ + _event("a", span_id="s1"), + _event("b", span_id="s2", parent_span_id="s1"), + ]) + matches = await self.strategy.correlate(ctx) + assert len(matches) == 1 + assert matches[0].score == SpanRelationStrategy.PARENT_CHILD_SCORE + assert matches[0].signal == CorrelationSignal.SPAN_PARENT_CHILD + + async def test_siblings_score_middle(self) -> None: + ctx = CorrelationContext(events=[ + _event("a", span_id="s1", parent_span_id="s0"), + _event("b", span_id="s2", parent_span_id="s0"), + ]) + matches = await self.strategy.correlate(ctx) + assert len(matches) == 1 + assert matches[0].score == SpanRelationStrategy.SIBLING_SCORE + assert matches[0].signal == CorrelationSignal.SPAN_SIBLING + + async def test_same_trace_no_span_relation_scores_lowest(self) -> None: + ctx = CorrelationContext(events=[ + _event("a", trace_id=TRACE, span_id="s1"), + _event("b", trace_id=TRACE, span_id="s2"), + ]) + matches = await self.strategy.correlate(ctx) + assert len(matches) == 1 + assert matches[0].score == SpanRelationStrategy.SAME_TRACE_SCORE + assert matches[0].signal == CorrelationSignal.TRACE_MATCH + + async def test_events_without_span_ids_fall_back_to_same_trace(self) -> None: + ctx = CorrelationContext(events=[ + _event("a", trace_id=TRACE), + _event("b", trace_id=TRACE), + ]) + matches = await self.strategy.correlate(ctx) + assert len(matches) == 1 + assert matches[0].score == SpanRelationStrategy.SAME_TRACE_SCORE + + async def test_different_traces_no_match(self) -> None: + ctx = CorrelationContext(events=[ + _event("a", trace_id="t1", span_id="s1"), + _event("b", trace_id="t2", span_id="s2"), + ]) + matches = await self.strategy.correlate(ctx) + assert matches == [] + + async def test_metadata_includes_trace_id(self) -> None: + ctx = CorrelationContext(events=[ + _event("a", span_id="s1"), + _event("b", span_id="s2", parent_span_id="s1"), + ]) + matches = await self.strategy.correlate(ctx) + assert matches[0].metadata.get("trace_id") == TRACE + + async def test_in_engine_default_pipeline(self) -> None: + from shared.domain.correlation.engine import CorrelationEngine + + engine = CorrelationEngine() + events = [ + _event("a", span_id="s1"), + _event("b", span_id="s2", parent_span_id="s1"), + _event("c"), + ] + result = await engine.correlate(events) + assert result.total_events == 3 + assert len(result.groups) >= 1 + assert CorrelationStrategyType.SPAN_RELATION.value in result.strategy_counts diff --git a/shared/domain/timeline/models.py b/shared/domain/timeline/models.py index 43f62b2..a03bdb4 100644 --- a/shared/domain/timeline/models.py +++ b/shared/domain/timeline/models.py @@ -15,6 +15,8 @@ class TimelineEvent(BaseModel): title: str = Field(description="Short human-readable event summary.") description: str = Field(default="", description="Detailed event description.") trace_id: str | None = Field(default=None, description="Correlated trace identifier.") + span_id: str | None = Field(default=None, description="Span identifier within the trace (16 hex chars).") + parent_span_id: str | None = Field(default=None, description="Parent span identifier, if this span is a child.") request_id: str | None = Field(default=None, description="Correlated request identifier.") severity: Severity | None = Field(default=None, description="Severity level if applicable.") tags: dict[str, str] = Field(default_factory=dict, description="Dimension key-value pairs.") From c9004890ce96f2891801feed1e38a38c3790785f Mon Sep 17 00:00:00 2001 From: saurabh batham Date: Sun, 14 Jun 2026 18:08:23 +0530 Subject: [PATCH 2/2] chore: add ruff, mypy, and pre-commit hooks - Configure ruff (py313, 120 line-length, common rule sets) - Configure mypy (strict optional, check untyped defs) - Add .pre-commit-config.yaml with ruff --fix, ruff-format, mypy - Fix 106+ pre-existing lint issues (unused imports, datetime.UTC, import sorting, unsorted __all__, assert False, etc.) - Fix mypy issues: mock constructor kwargs, type narrowing, None guards - Add ruff/mypy/pre-commit to [dev] optional dependencies --- .pre-commit-config.yaml | 18 ++++ examples/event_contracts.py | 7 +- examples/provider_interfaces.py | 10 ++- .../elasticsearch/elasticsearch_log_store.py | 21 +++-- .../tests/test_elasticsearch_log_store.py | 18 ++-- infrastructure/monitoring/otel/__init__.py | 4 +- .../monitoring/otel/instrumentation.py | 6 +- .../monitoring/otel/otel_tracer_provider.py | 16 ++-- infrastructure/rabbitmq/rabbitmq_event_bus.py | 10 +-- .../rabbitmq/tests/test_rabbitmq_event_bus.py | 72 +++++---------- pyproject.toml | 45 ++++++++++ .../app/routers/correlation.py | 4 +- .../app/routers/timeline.py | 1 - services/correlation-service/app/schemas.py | 4 +- .../correlation-service/tests/conftest.py | 3 +- .../tests/test_timeline.py | 4 - .../ingestion-service/app/routers/health.py | 2 +- .../app/services/ingestion_service.py | 6 +- services/ingestion-service/tests/conftest.py | 4 +- .../ingestion-service/tests/test_config.py | 4 +- .../ingestion-service/tests/test_health.py | 8 +- .../ingestion-service/tests/test_ingest.py | 10 +-- shared/config/base.py | 10 +-- shared/config/tests/test_base.py | 2 +- shared/contracts/__init__.py | 4 +- shared/contracts/events/__init__.py | 5 +- shared/contracts/events/base.py | 4 +- shared/contracts/events/incident.py | 4 +- shared/contracts/events/investigation.py | 4 +- shared/contracts/events/telemetry.py | 4 +- .../contracts/events/tests/test_contracts.py | 5 +- shared/contracts/interfaces/__init__.py | 2 +- shared/contracts/interfaces/event_bus.py | 4 +- shared/contracts/interfaces/log_store.py | 4 +- shared/domain/__init__.py | 11 ++- shared/domain/correlation/__init__.py | 3 +- .../domain/correlation/grouping/__init__.py | 2 +- shared/domain/correlation/grouping/models.py | 2 - .../grouping/tests/test_trace_grouping.py | 89 +++++++++++-------- .../correlation/grouping/trace_grouping.py | 5 +- shared/domain/correlation/models.py | 4 +- shared/domain/correlation/pipeline.py | 8 +- .../correlation/strategies/span_relation.py | 10 ++- .../domain/correlation/tests/test_engine.py | 42 +++++---- .../domain/correlation/tests/test_pipeline.py | 4 +- .../correlation/tests/test_span_relation.py | 64 +++++++------ .../correlation/tests/test_strategies.py | 20 ++--- shared/domain/timeline/models.py | 8 +- .../domain/timeline/services/reconstructor.py | 6 +- shared/domain/timeline/tests/test_models.py | 32 +++---- .../timeline/tests/test_reconstructor.py | 24 +++-- shared/observability/tracing/__init__.py | 2 +- shared/observability/tracing/models.py | 1 - shared/observability/tracing/provider.py | 33 +++---- 54 files changed, 369 insertions(+), 330 deletions(-) create mode 100644 .pre-commit-config.yaml diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..cbc2624 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,18 @@ +repos: + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.11.6 + hooks: + - id: ruff + args: [--fix] + - id: ruff-format + + - repo: local + hooks: + - id: mypy + name: mypy + entry: .venv/Scripts/python.exe -m mypy + language: system + types: [python] + require_serial: true + pass_filenames: false + args: [--no-error-summary, shared, infrastructure] diff --git a/examples/event_contracts.py b/examples/event_contracts.py index e06bcdd..109d466 100644 --- a/examples/event_contracts.py +++ b/examples/event_contracts.py @@ -1,7 +1,7 @@ """Serialization and usage examples for RootPilot event contracts.""" import json -from datetime import datetime, timezone +from datetime import UTC, datetime from shared.contracts import ( Event, @@ -11,7 +11,6 @@ TelemetryEvent, ) - # --------------------------------------------------------------------------- # TelemetryEvent – serialize / deserialize # --------------------------------------------------------------------------- @@ -55,7 +54,7 @@ title="Gateway timeout spike above threshold", description="p99 latency exceeded 5s for 3 consecutive minutes.", source_event_ids=[envelope.id], - detected_at=datetime.now(timezone.utc), + detected_at=datetime.now(UTC), ) incident_json = incident.model_dump_json() @@ -122,4 +121,4 @@ with open("examples/_sample_events.json", "w") as f: json.dump(all_events, f, indent=2, default=str) -print(f"Bulk export written to examples/_sample_events.json") +print("Bulk export written to examples/_sample_events.json") diff --git a/examples/provider_interfaces.py b/examples/provider_interfaces.py index 4a018b7..ecd50fa 100644 --- a/examples/provider_interfaces.py +++ b/examples/provider_interfaces.py @@ -2,16 +2,17 @@ import asyncio from collections.abc import AsyncIterator -from datetime import datetime, timezone +from datetime import UTC, datetime from pydantic import BaseModel -from shared.contracts import Event, EventBus, LLMMessage, LLMProvider, LLMResponse, LogEntry, LogFilter, LogStore +from shared.contracts import Event, EventBus, LLMMessage, LLMProvider, LLMResponse, LogEntry, LogFilter, LogStore # --------------------------------------------------------------------------- # Example EventBus implementation # --------------------------------------------------------------------------- + class PrintBus(EventBus): async def publish(self, event: Event, topic: str | None = None) -> None: print(f"[{topic or event.topic}] {event.source}: {event.payload}") @@ -33,6 +34,7 @@ async def health(self) -> bool: # Example LogStore implementation # --------------------------------------------------------------------------- + class MemoryLogStore(LogStore): def __init__(self) -> None: self._logs: list[LogEntry] = [] @@ -56,6 +58,7 @@ async def health(self) -> bool: # Example LLMProvider implementation # --------------------------------------------------------------------------- + class EchoProvider(LLMProvider): async def generate( self, @@ -82,12 +85,13 @@ async def embed(self, text: str, model: str | None = None) -> list[float]: # Usage # --------------------------------------------------------------------------- + async def main() -> None: bus = PrintBus() await bus.publish(Event(source="test", topic="ping", payload={"msg": "hello"})) store = MemoryLogStore() - await store.write(LogEntry(timestamp=datetime.now(timezone.utc), service="svc", level="INFO", message="started")) + await store.write(LogEntry(timestamp=datetime.now(UTC), service="svc", level="INFO", message="started")) async for entry in store.query(LogFilter(service="svc")): print(f" Log: {entry.message}") diff --git a/infrastructure/elasticsearch/elasticsearch_log_store.py b/infrastructure/elasticsearch/elasticsearch_log_store.py index 26bdbd4..ab09ddd 100644 --- a/infrastructure/elasticsearch/elasticsearch_log_store.py +++ b/infrastructure/elasticsearch/elasticsearch_log_store.py @@ -4,13 +4,14 @@ import logging from collections.abc import AsyncIterator -from datetime import datetime, timezone +from datetime import UTC, datetime +from typing import Any -from elasticsearch import AsyncElasticsearch # type: ignore -from elasticsearch.helpers import async_bulk # type: ignore +from elasticsearch import AsyncElasticsearch +from elasticsearch.helpers import async_bulk from pydantic import BaseModel, Field -from shared.contracts.interfaces.log_store import LogEntry, LogFilter, LogStore, SortOrder +from shared.contracts.interfaces.log_store import LogEntry, LogFilter, LogStore logger = logging.getLogger(__name__) @@ -29,7 +30,7 @@ def _index_name(dt: datetime | None = None) -> str: """Return the target index name for a given timestamp (UTC daily bucket).""" - ts = dt or datetime.now(timezone.utc) + ts = dt or datetime.now(UTC) return f"{INDEX_PREFIX}-{ts.strftime('%Y.%m.%d')}" @@ -51,6 +52,7 @@ def _build_es_doc(entry: LogEntry) -> dict: # Applied automatically at startup to ensure consistent mappings. # ───────────────────────────────────────────────────────────────────────── + def _default_index_template() -> dict: return { "index_patterns": [f"{INDEX_PREFIX}-*"], @@ -117,6 +119,7 @@ def _default_ilm_policy() -> dict: # ── Query builder ───────────────────────────────────────────────────────── + def _build_query_body(filter: LogFilter) -> dict: """Translate a LogFilter into an Elasticsearch query body.""" must_clauses: list[dict] = [] @@ -157,6 +160,7 @@ def _build_query_body(filter: LogFilter) -> dict: # ── Elasticsearch Configuration ─────────────────────────────────────────── + class ElasticsearchConfig(BaseModel): hosts: str = Field( default="http://localhost:9200", @@ -182,6 +186,7 @@ class ElasticsearchConfig(BaseModel): # ── Elasticsearch LogStore Adapter ──────────────────────────────────────── + class ElasticsearchLogStore(LogStore): """LogStore implementation backed by Elasticsearch. @@ -266,7 +271,9 @@ async def _generate_actions(): index = _index_name(entry.timestamp) yield {"_index": index, "_source": doc} - success, errors = await async_bulk( + success: int + errors: list[Any] + success, errors = await async_bulk( # type: ignore[assignment] client=self._client, actions=_generate_actions(), chunk_size=self._config.bulk_batch_size, @@ -281,7 +288,7 @@ async def _generate_actions(): else: logger.debug("Bulk write succeeded", extra={"count": success}) - async def query(self, filter: LogFilter) -> AsyncIterator[LogEntry]: + async def query(self, filter: LogFilter) -> AsyncIterator[LogEntry]: # type: ignore[override,misc] assert self._client is not None, "ElasticsearchLogStore not started" body = _build_query_body(filter) diff --git a/infrastructure/elasticsearch/tests/test_elasticsearch_log_store.py b/infrastructure/elasticsearch/tests/test_elasticsearch_log_store.py index dd9ff05..bef617e 100644 --- a/infrastructure/elasticsearch/tests/test_elasticsearch_log_store.py +++ b/infrastructure/elasticsearch/tests/test_elasticsearch_log_store.py @@ -2,8 +2,8 @@ from __future__ import annotations -from datetime import datetime, timezone -from unittest.mock import AsyncMock, MagicMock, patch +from datetime import UTC, datetime +from unittest.mock import AsyncMock, patch import pytest @@ -27,7 +27,7 @@ def config() -> ElasticsearchConfig: @pytest.fixture def entry() -> LogEntry: return LogEntry( - timestamp=datetime(2026, 6, 13, 12, 0, 0, tzinfo=timezone.utc), + timestamp=datetime(2026, 6, 13, 12, 0, 0, tzinfo=UTC), service="ingestion-service", level="ERROR", message="Connection refused", @@ -39,7 +39,7 @@ def entry() -> LogEntry: class TestIndexNaming: def test_index_name_format(self) -> None: - dt = datetime(2026, 6, 13, 12, 0, 0, tzinfo=timezone.utc) + dt = datetime(2026, 6, 13, 12, 0, 0, tzinfo=UTC) name = _index_name(dt) assert name == "rp-tl-2026.06.13" @@ -48,11 +48,11 @@ def test_index_name_defaults_to_utc_now(self) -> None: assert name.startswith("rp-tl-") def test_index_name_pads_single_digit_month(self) -> None: - dt = datetime(2026, 1, 5, tzinfo=timezone.utc) + dt = datetime(2026, 1, 5, tzinfo=UTC) assert _index_name(dt) == "rp-tl-2026.01.05" def test_index_name_pads_single_digit_day(self) -> None: - dt = datetime(2026, 12, 1, tzinfo=timezone.utc) + dt = datetime(2026, 12, 1, tzinfo=UTC) assert _index_name(dt) == "rp-tl-2026.12.01" @@ -70,7 +70,7 @@ def test_build_doc_structure(self, entry: LogEntry) -> None: def test_build_doc_without_trace_span(self) -> None: entry = LogEntry( - timestamp=datetime(2026, 6, 13, tzinfo=timezone.utc), + timestamp=datetime(2026, 6, 13, tzinfo=UTC), service="test", level="INFO", message="hello", @@ -128,8 +128,8 @@ def test_filter_by_trace_id(self) -> None: assert body["query"]["bool"]["must"] == [{"term": {"trace_id": "abc123"}}] def test_filter_by_time_range(self) -> None: - start = datetime(2026, 6, 13, tzinfo=timezone.utc) - end = datetime(2026, 6, 14, tzinfo=timezone.utc) + start = datetime(2026, 6, 13, tzinfo=UTC) + end = datetime(2026, 6, 14, tzinfo=UTC) f = LogFilter(start_time=start, end_time=end) body = _build_query_body(f) time_range = body["query"]["bool"]["must"][0]["range"]["@timestamp"] diff --git a/infrastructure/monitoring/otel/__init__.py b/infrastructure/monitoring/otel/__init__.py index e882d24..0bf3a3e 100644 --- a/infrastructure/monitoring/otel/__init__.py +++ b/infrastructure/monitoring/otel/__init__.py @@ -2,8 +2,8 @@ from infrastructure.monitoring.otel.instrumentation import ( OpenTelemetryMiddleware, - setup_tracing, get_trace_context, + setup_tracing, ) from infrastructure.monitoring.otel.otel_tracer_provider import ( OTelSpan, @@ -16,6 +16,6 @@ "OTelTracer", "OTelTracerProvider", "OpenTelemetryMiddleware", - "setup_tracing", "get_trace_context", + "setup_tracing", ] diff --git a/infrastructure/monitoring/otel/instrumentation.py b/infrastructure/monitoring/otel/instrumentation.py index 07f1c24..ef3dd32 100644 --- a/infrastructure/monitoring/otel/instrumentation.py +++ b/infrastructure/monitoring/otel/instrumentation.py @@ -1,8 +1,6 @@ """Tracing setup utilities and ASGI middleware for FastAPI services.""" import logging -from collections.abc import Awaitable, Callable -from typing import Any from fastapi import FastAPI from opentelemetry import trace as otel_trace @@ -59,9 +57,7 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: await self.app(scope, receive, send) return - provider: OTelTracerProvider | None = getattr( - getattr(self.app, "state", None), "tracer_provider", None - ) + provider: OTelTracerProvider | None = getattr(getattr(self.app, "state", None), "tracer_provider", None) if provider is None: await self.app(scope, receive, send) return diff --git a/infrastructure/monitoring/otel/otel_tracer_provider.py b/infrastructure/monitoring/otel/otel_tracer_provider.py index 368e4ed..8d91240 100644 --- a/infrastructure/monitoring/otel/otel_tracer_provider.py +++ b/infrastructure/monitoring/otel/otel_tracer_provider.py @@ -8,10 +8,10 @@ from opentelemetry.propagators.composite import CompositeHTTPPropagator from opentelemetry.propagators.textmap import Setter, TextMapPropagator from opentelemetry.sdk.resources import Resource -from opentelemetry.sdk.trace import Span as OTelSDKSpan from opentelemetry.sdk.trace import TracerProvider as OTelSDKTracerProvider from opentelemetry.sdk.trace.export import BatchSpanProcessor -from opentelemetry.trace import NonRecordingSpan, SpanContext as OTelSpanContext +from opentelemetry.trace import NonRecordingSpan +from opentelemetry.trace import SpanContext as OTelSpanContext from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator from shared.observability.tracing import Span, SpanContext, SpanKind, SpanStatus, Tracer, TracerProvider @@ -108,9 +108,7 @@ def start_span( if context is not None: parent_octx = _to_span_context(context) if parent_octx is not None: - otel_ctx = otel_trace.set_span_in_context( - NonRecordingSpan(parent_octx) - ) + otel_ctx = otel_trace.set_span_in_context(NonRecordingSpan(parent_octx)) otel_span = self._tracer.start_span( name=name, @@ -137,9 +135,11 @@ def __init__( sdk_provider.add_span_processor(span_processor) self._provider = sdk_provider - self._propagator: TextMapPropagator = CompositeHTTPPropagator([ - TraceContextTextMapPropagator(), - ]) + self._propagator: TextMapPropagator = CompositeHTTPPropagator( + [ + TraceContextTextMapPropagator(), + ] + ) otel_trace.set_tracer_provider(sdk_provider) diff --git a/infrastructure/rabbitmq/rabbitmq_event_bus.py b/infrastructure/rabbitmq/rabbitmq_event_bus.py index d929d8b..de94464 100644 --- a/infrastructure/rabbitmq/rabbitmq_event_bus.py +++ b/infrastructure/rabbitmq/rabbitmq_event_bus.py @@ -46,7 +46,7 @@ class RabbitMQConfig(BaseModel): class _SubscriberInfo: - __slots__ = ("queue_name", "channel", "consumer_tag", "topic") + __slots__ = ("channel", "consumer_tag", "queue_name", "topic") def __init__( self, @@ -174,9 +174,7 @@ async def subscribe(self, topic: str, handler: EventHandler) -> None: async def health(self) -> bool: if self._closed: return False - if self._connection is None or self._connection.is_closed: - return False - return True + return not (self._connection is None or self._connection.is_closed) async def _resolve_exchange(self) -> AbstractExchange: async with self._lock: @@ -192,9 +190,7 @@ async def _resolve_exchange(self) -> AbstractExchange: await channel.close() return self._exchange - def _make_handler( - self, handler: EventHandler - ) -> Callable[[AbstractIncomingMessage], Awaitable[None]]: + def _make_handler(self, handler: EventHandler) -> Callable[[AbstractIncomingMessage], Awaitable[None]]: async def _on_message(message: AbstractIncomingMessage) -> None: async with message.process(requeue=True): try: diff --git a/infrastructure/rabbitmq/tests/test_rabbitmq_event_bus.py b/infrastructure/rabbitmq/tests/test_rabbitmq_event_bus.py index d095d4b..7616359 100644 --- a/infrastructure/rabbitmq/tests/test_rabbitmq_event_bus.py +++ b/infrastructure/rabbitmq/tests/test_rabbitmq_event_bus.py @@ -39,14 +39,12 @@ class TestRabbitMQEventBus: async def test_start_connects_and_declares_exchange(self, config: RabbitMQConfig) -> None: bus = RabbitMQEventBus(config=config) - mock_channel = AsyncMock() + mock_channel = AsyncMock(is_closed=False) mock_exchange = MagicMock() mock_channel.declare_exchange = AsyncMock(return_value=mock_exchange) - mock_channel.is_closed = False - mock_connection = AsyncMock() + mock_connection = AsyncMock(is_closed=False) mock_connection.channel = AsyncMock(return_value=mock_channel) - mock_connection.is_closed = False with patch("aio_pika.connect_robust", AsyncMock(return_value=mock_connection)): await bus.start() @@ -65,13 +63,11 @@ async def test_start_connects_and_declares_exchange(self, config: RabbitMQConfig async def test_start_is_idempotent(self, config: RabbitMQConfig) -> None: bus = RabbitMQEventBus(config=config) - mock_channel = AsyncMock() + mock_channel = AsyncMock(is_closed=False) mock_channel.declare_exchange = AsyncMock(return_value=MagicMock()) - mock_channel.is_closed = False - mock_connection = AsyncMock() + mock_connection = AsyncMock(is_closed=False) mock_connection.channel = AsyncMock(return_value=mock_channel) - mock_connection.is_closed = False with patch("aio_pika.connect_robust", AsyncMock(return_value=mock_connection)): await bus.start() @@ -81,26 +77,22 @@ async def test_start_is_idempotent(self, config: RabbitMQConfig) -> None: async def test_start_reconnects_if_connection_died(self, config: RabbitMQConfig) -> None: bus = RabbitMQEventBus(config=config) - mock_channel = AsyncMock() + mock_channel = AsyncMock(is_closed=False) mock_channel.declare_exchange = AsyncMock(return_value=MagicMock()) - mock_channel.is_closed = False - mock_connection = AsyncMock() + mock_connection = AsyncMock(is_closed=False) mock_connection.channel = AsyncMock(return_value=mock_channel) - mock_connection.is_closed = False with patch("aio_pika.connect_robust", AsyncMock(return_value=mock_connection)): await bus.start() - bus._connection.is_closed = True + mock_connection.is_closed = True - new_channel = AsyncMock() + new_channel = AsyncMock(is_closed=False) new_channel.declare_exchange = AsyncMock(return_value=MagicMock()) - new_channel.is_closed = False - new_connection = AsyncMock() + new_connection = AsyncMock(is_closed=False) new_connection.channel = AsyncMock(return_value=new_channel) - new_connection.is_closed = False with patch("aio_pika.connect_robust", AsyncMock(return_value=new_connection)): await bus.start() @@ -110,12 +102,8 @@ async def test_start_reconnects_if_connection_died(self, config: RabbitMQConfig) async def test_close_cleans_up_connection(self, config: RabbitMQConfig) -> None: bus = RabbitMQEventBus(config=config) - mock_channel = AsyncMock() - mock_channel.is_closed = False - - mock_connection = AsyncMock() - mock_connection.is_closed = False - + mock_channel = AsyncMock(is_closed=False) + mock_connection = AsyncMock(is_closed=False) bus._channel = mock_channel bus._connection = mock_connection bus._exchange = MagicMock() @@ -146,17 +134,12 @@ async def test_close_skips_if_already_closed(self, config: RabbitMQConfig) -> No async def test_close_closes_subscriber_channels(self, config: RabbitMQConfig) -> None: bus = RabbitMQEventBus(config=config) - sub_channel = AsyncMock() - sub_channel.is_closed = False - sub_info = AsyncMock() sub_info.close = AsyncMock() bus._subscribers = {"test.topic": sub_info} - bus._channel = AsyncMock() - bus._channel.is_closed = False - bus._connection = AsyncMock() - bus._connection.is_closed = False + bus._channel = AsyncMock(is_closed=False) + bus._connection = AsyncMock(is_closed=False) await bus.close() @@ -165,11 +148,9 @@ async def test_close_closes_subscriber_channels(self, config: RabbitMQConfig) -> async def test_publish_serializes_event_and_publishes(self, config: RabbitMQConfig, event: Event) -> None: bus = RabbitMQEventBus(config=config) - mock_exchange = AsyncMock() - mock_exchange.is_closed = False + mock_exchange = AsyncMock(is_closed=False) bus._exchange = mock_exchange - bus._connection = AsyncMock() - bus._connection.is_closed = False + bus._connection = AsyncMock(is_closed=False) await bus.publish(event) @@ -187,11 +168,9 @@ async def test_publish_serializes_event_and_publishes(self, config: RabbitMQConf async def test_publish_with_custom_topic(self, config: RabbitMQConfig, event: Event) -> None: bus = RabbitMQEventBus(config=config) - mock_exchange = AsyncMock() - mock_exchange.is_closed = False + mock_exchange = AsyncMock(is_closed=False) bus._exchange = mock_exchange - bus._connection = AsyncMock() - bus._connection.is_closed = False + bus._connection = AsyncMock(is_closed=False) await bus.publish(event, topic="custom.route") @@ -215,8 +194,7 @@ async def test_publish_raises_when_not_connected(self, config: RabbitMQConfig, e async def test_subscribe_declares_queue_and_binds(self, config: RabbitMQConfig) -> None: bus = RabbitMQEventBus(config=config) - mock_channel = AsyncMock() - mock_channel.is_closed = False + mock_channel = AsyncMock(is_closed=False) mock_queue = AsyncMock() mock_queue.name = "test-queue" mock_queue.consume = AsyncMock(return_value="consumer-tag") @@ -225,9 +203,8 @@ async def test_subscribe_declares_queue_and_binds(self, config: RabbitMQConfig) mock_exchange = MagicMock() - bus._connection = AsyncMock() + bus._connection = AsyncMock(is_closed=False) bus._connection.channel = AsyncMock(return_value=mock_channel) - bus._connection.is_closed = False handler: EventHandler = AsyncMock() @@ -256,8 +233,7 @@ async def test_subscribe_raises_when_not_connected(self, config: RabbitMQConfig) async def test_health_returns_true_when_connected(self, config: RabbitMQConfig) -> None: bus = RabbitMQEventBus(config=config) - bus._connection = MagicMock() - bus._connection.is_closed = False + bus._connection = MagicMock(is_closed=False) bus._closed = False assert await bus.health() is True @@ -274,8 +250,7 @@ async def test_health_returns_false_when_no_connection(self, config: RabbitMQCon async def test_health_returns_false_when_disconnected(self, config: RabbitMQConfig) -> None: bus = RabbitMQEventBus(config=config) - bus._connection = MagicMock() - bus._connection.is_closed = True + bus._connection = MagicMock(is_closed=True) bus._closed = False assert await bus.health() is False @@ -283,7 +258,7 @@ async def test_health_returns_false_when_disconnected(self, config: RabbitMQConf async def test_make_handler_calls_handler_on_message(self, config: RabbitMQConfig, event: Event) -> None: bus = RabbitMQEventBus(config=config) - handler: EventHandler = AsyncMock() + handler = AsyncMock() wrapped = bus._make_handler(handler) mock_message = AsyncMock() @@ -295,6 +270,7 @@ async def test_make_handler_calls_handler_on_message(self, config: RabbitMQConfi await wrapped(mock_message) handler.assert_awaited_once() + assert handler.await_args is not None called_event = handler.await_args.args[0] assert isinstance(called_event, Event) assert called_event.id == event.id @@ -304,7 +280,7 @@ async def test_make_handler_calls_handler_on_message(self, config: RabbitMQConfi async def test_make_handler_does_not_raise_on_failure(self, config: RabbitMQConfig, event: Event) -> None: bus = RabbitMQEventBus(config=config) - handler: EventHandler = AsyncMock(side_effect=ValueError("handler error")) + handler = AsyncMock(side_effect=ValueError("handler error")) wrapped = bus._make_handler(handler) mock_message = AsyncMock() diff --git a/pyproject.toml b/pyproject.toml index fcc748f..3568e0d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,6 +28,9 @@ dev = [ "pytest", "pytest-asyncio", "httpx", + "ruff", + "mypy", + "pre-commit", ] [tool.pytest.ini_options] @@ -47,6 +50,48 @@ python_functions = [ "test_*", ] +[tool.ruff] +target-version = "py313" +line-length = 120 +exclude = [".venv", "build", "dist", "datasets", "docs", "scripts"] + +[tool.ruff.lint] +select = ["E", "F", "I", "N", "W", "UP", "B", "SIM", "ARG", "RUF"] +ignore = ["B905"] +extend-ignore = [ + "E501", +] + +[tool.ruff.lint.per-file-ignores] +"__init__.py" = ["F401"] +"tests/*" = ["ARG"] +"services/*" = ["ARG", "B008"] +"examples/*" = ["ARG", "RUF003"] + +[tool.ruff.format] +quote-style = "double" +indent-style = "space" +line-ending = "auto" + +[tool.mypy] +python_version = "3.13" +ignore_missing_imports = true +check_untyped_defs = true +warn_unused_ignores = true +strict_optional = true +warn_redundant_casts = true +warn_unused_configs = true +warn_return_any = false +exclude = [ + ".venv/", + "build/", + "dist/", + "datasets/", + "examples/", + "scripts/", + "services/", +] + [tool.setuptools.packages.find] include = [ "shared*", diff --git a/services/correlation-service/app/routers/correlation.py b/services/correlation-service/app/routers/correlation.py index ee391b1..7fe130a 100644 --- a/services/correlation-service/app/routers/correlation.py +++ b/services/correlation-service/app/routers/correlation.py @@ -1,5 +1,5 @@ import logging -from datetime import datetime, timezone +from datetime import UTC, datetime from fastapi import APIRouter, Depends @@ -53,7 +53,7 @@ async def correlate_events( ) return CorrelateResponse( - correlation_id=f"corr-{datetime.now(timezone.utc).isoformat()}", + correlation_id=f"corr-{datetime.now(UTC).isoformat()}", total_events=ctx.total_events, groups=group_responses, ungrouped_event_ids=list(ctx.ungrouped_event_ids), diff --git a/services/correlation-service/app/routers/timeline.py b/services/correlation-service/app/routers/timeline.py index f4b2ae3..6d45444 100644 --- a/services/correlation-service/app/routers/timeline.py +++ b/services/correlation-service/app/routers/timeline.py @@ -47,7 +47,6 @@ def _request_event_to_domain(ev: TimelineEventResponse) -> TimelineEvent: def _domain_timeline_to_response(timeline: IncidentTimeline) -> IncidentTimelineResponse: - return IncidentTimelineResponse( incident_id=timeline.incident_id, service=timeline.service, diff --git a/services/correlation-service/app/schemas.py b/services/correlation-service/app/schemas.py index bcc64a6..6921c81 100644 --- a/services/correlation-service/app/schemas.py +++ b/services/correlation-service/app/schemas.py @@ -65,7 +65,9 @@ class CorrelateResponse(BaseModel): total_events: int = Field(description="Number of events processed.") groups: list[CorrelationGroupResponse] = Field(default_factory=list, description="Detected correlation groups.") ungrouped_event_ids: list[str] = Field(default_factory=list, description="Events that did not join any group.") - strategy_counts: dict[str, int] = Field(default_factory=dict, description="Number of matches produced per strategy.") + strategy_counts: dict[str, int] = Field( + default_factory=dict, description="Number of matches produced per strategy." + ) class CorrelateRequest(BaseModel): diff --git a/services/correlation-service/tests/conftest.py b/services/correlation-service/tests/conftest.py index abc94bc..6ca57db 100644 --- a/services/correlation-service/tests/conftest.py +++ b/services/correlation-service/tests/conftest.py @@ -10,6 +10,7 @@ from app.config import CorrelationServiceSettings from app.main import create_app + from shared.domain.correlation.engine import CorrelationEngine from shared.domain.timeline.services import TimelineReconstructor @@ -39,7 +40,7 @@ def app( @pytest.fixture -async def client(app: FastAPI) -> AsyncGenerator[AsyncClient, None]: +async def client(app: FastAPI) -> AsyncGenerator[AsyncClient]: transport = ASGITransport(app=app) async with AsyncClient(transport=transport, base_url="http://test") as ac: yield ac diff --git a/services/correlation-service/tests/test_timeline.py b/services/correlation-service/tests/test_timeline.py index 3d5699c..ed30e77 100644 --- a/services/correlation-service/tests/test_timeline.py +++ b/services/correlation-service/tests/test_timeline.py @@ -1,9 +1,5 @@ -from datetime import datetime, timezone - from httpx import AsyncClient -from shared.domain.timeline.enums import TimelineEventCategory, TimelineEventSource - class TestReconstructEndpoint: async def test_reconstruct_empty_events(self, client: AsyncClient) -> None: diff --git a/services/ingestion-service/app/routers/health.py b/services/ingestion-service/app/routers/health.py index 940d332..5340028 100644 --- a/services/ingestion-service/app/routers/health.py +++ b/services/ingestion-service/app/routers/health.py @@ -1,8 +1,8 @@ from fastapi import APIRouter, Depends from pydantic import BaseModel -from app.dependencies import get_event_bus, get_settings from app.config import IngestionServiceSettings +from app.dependencies import get_event_bus, get_settings from shared.contracts import EventBus router = APIRouter(tags=["health"]) diff --git a/services/ingestion-service/app/services/ingestion_service.py b/services/ingestion-service/app/services/ingestion_service.py index 649f4d7..78b502a 100644 --- a/services/ingestion-service/app/services/ingestion_service.py +++ b/services/ingestion-service/app/services/ingestion_service.py @@ -1,5 +1,5 @@ import logging -from datetime import datetime, timezone +from datetime import UTC, datetime from app.schemas import IngestRequest from shared.contracts import Event, EventBus, ServiceName, TelemetryEvent @@ -36,9 +36,9 @@ async def process_telemetry(self, request: IngestRequest) -> str: def _parse_timestamp(self, raw: str | None) -> datetime: if raw is None: - return datetime.now(timezone.utc) + return datetime.now(UTC) try: return datetime.fromisoformat(raw) except ValueError: logger.warning("Invalid timestamp format, falling back to now", extra={"raw": raw}) - return datetime.now(timezone.utc) + return datetime.now(UTC) diff --git a/services/ingestion-service/tests/conftest.py b/services/ingestion-service/tests/conftest.py index e618465..6f0ed36 100644 --- a/services/ingestion-service/tests/conftest.py +++ b/services/ingestion-service/tests/conftest.py @@ -10,8 +10,8 @@ sys.path.insert(0, str(Path(__file__).parent.parent)) from app.config import IngestionServiceSettings -from app.dependencies import get_event_bus, get_settings from app.main import create_app + from shared.contracts import EventBus @@ -48,7 +48,7 @@ def app( @pytest.fixture -async def client(app: FastAPI) -> AsyncGenerator[AsyncClient, None]: +async def client(app: FastAPI) -> AsyncGenerator[AsyncClient]: transport = ASGITransport(app=app) async with AsyncClient(transport=transport, base_url="http://test") as ac: yield ac diff --git a/services/ingestion-service/tests/test_config.py b/services/ingestion-service/tests/test_config.py index a7815d3..9526a2b 100644 --- a/services/ingestion-service/tests/test_config.py +++ b/services/ingestion-service/tests/test_config.py @@ -1,7 +1,5 @@ -import pytest -from pydantic import Field - from app.config import IngestionServiceSettings + from shared.config import BaseAppSettings, load_settings diff --git a/services/ingestion-service/tests/test_health.py b/services/ingestion-service/tests/test_health.py index fae5939..47d659c 100644 --- a/services/ingestion-service/tests/test_health.py +++ b/services/ingestion-service/tests/test_health.py @@ -3,9 +3,7 @@ class TestHealthEndpoint: - async def test_health_returns_healthy_when_event_bus_connected( - self, client: AsyncClient - ) -> None: + async def test_health_returns_healthy_when_event_bus_connected(self, client: AsyncClient) -> None: response = await client.get("/health") assert response.status_code == 200 data = response.json() @@ -14,9 +12,7 @@ async def test_health_returns_healthy_when_event_bus_connected( assert data["environment"] == "test" assert data["event_bus_connected"] is True - async def test_health_returns_degraded_when_event_bus_disconnected( - self, client: AsyncClient, app: FastAPI - ) -> None: + async def test_health_returns_degraded_when_event_bus_disconnected(self, client: AsyncClient, app: FastAPI) -> None: from unittest.mock import AsyncMock app.state.event_bus.health = AsyncMock(return_value=False) diff --git a/services/ingestion-service/tests/test_ingest.py b/services/ingestion-service/tests/test_ingest.py index 43b9bcb..73a5252 100644 --- a/services/ingestion-service/tests/test_ingest.py +++ b/services/ingestion-service/tests/test_ingest.py @@ -1,5 +1,3 @@ -import json - from httpx import AsyncClient @@ -65,14 +63,10 @@ async def test_ingest_accepts_optional_timestamp(self, client: AsyncClient) -> N response = await client.post("/api/v1/ingest", json=payload) assert response.status_code == 202 - async def test_ingest_handles_concurrent_requests( - self, client: AsyncClient - ) -> None: + async def test_ingest_handles_concurrent_requests(self, client: AsyncClient) -> None: import asyncio payload = {"metric": "cpu", "value": 50.0, "source": "test"} - responses = await asyncio.gather( - *[client.post("/api/v1/ingest", json=payload) for _ in range(5)] - ) + responses = await asyncio.gather(*[client.post("/api/v1/ingest", json=payload) for _ in range(5)]) for resp in responses: assert resp.status_code == 202 diff --git a/shared/config/base.py b/shared/config/base.py index 805b8d5..30c2631 100644 --- a/shared/config/base.py +++ b/shared/config/base.py @@ -1,7 +1,7 @@ """Base settings used by RootPilot services.""" from pathlib import Path -from typing import Literal, TypeVar +from typing import Any, Literal from pydantic import Field from pydantic_settings import BaseSettings, SettingsConfigDict @@ -9,8 +9,6 @@ Environment = Literal["local", "development", "staging", "production", "test"] LogLevel = Literal["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] -SettingsT = TypeVar("SettingsT", bound="BaseAppSettings") - class BaseAppSettings(BaseSettings): """Common strongly typed configuration shared by RootPilot services.""" @@ -55,11 +53,11 @@ def resolved_otel_service_name(self) -> str: return self.otel_service_name or self.service_name -def load_settings( - settings_cls: type[SettingsT] = BaseAppSettings, +def load_settings[SettingsT: BaseAppSettings]( + settings_cls: type[SettingsT] = BaseAppSettings, # type: ignore[assignment] *, env_file: str | Path | None = None, - **overrides: object, + **overrides: Any, ) -> SettingsT: """Create a settings instance without blocking async request paths. diff --git a/shared/config/tests/test_base.py b/shared/config/tests/test_base.py index 7aecd83..77d925b 100644 --- a/shared/config/tests/test_base.py +++ b/shared/config/tests/test_base.py @@ -1,7 +1,7 @@ from pathlib import Path -from pydantic import Field import pytest +from pydantic import Field from shared.config import BaseAppSettings, load_settings diff --git a/shared/contracts/__init__.py b/shared/contracts/__init__.py index b1881f2..73564fa 100644 --- a/shared/contracts/__init__.py +++ b/shared/contracts/__init__.py @@ -4,8 +4,8 @@ Event, IncidentDetectedEvent, InvestigationRequestedEvent, - Severity, ServiceName, + Severity, TelemetryEvent, ) from shared.contracts.interfaces import ( @@ -29,7 +29,7 @@ "LogEntry", "LogFilter", "LogStore", - "Severity", "ServiceName", + "Severity", "TelemetryEvent", ] diff --git a/shared/contracts/events/__init__.py b/shared/contracts/events/__init__.py index 0b4fbb5..fc027fb 100644 --- a/shared/contracts/events/__init__.py +++ b/shared/contracts/events/__init__.py @@ -1,7 +1,7 @@ """Event schemas for RootPilot messaging.""" from shared.contracts.events.base import Event -from shared.contracts.events.enums import Severity, ServiceName +from shared.contracts.events.enums import ServiceName, Severity from shared.contracts.events.incident import IncidentDetectedEvent from shared.contracts.events.investigation import InvestigationRequestedEvent from shared.contracts.events.telemetry import TelemetryEvent @@ -11,9 +11,8 @@ "Event", "IncidentDetectedEvent", "InvestigationRequestedEvent", - "Severity", "ServiceName", + "Severity", "TelemetryEvent", "TraceContext", ] - diff --git a/shared/contracts/events/base.py b/shared/contracts/events/base.py index 432e808..18c5561 100644 --- a/shared/contracts/events/base.py +++ b/shared/contracts/events/base.py @@ -1,6 +1,6 @@ """Base event type for provider-agnostic messaging.""" -from datetime import datetime, timezone +from datetime import UTC, datetime from typing import Any from uuid import uuid4 @@ -19,7 +19,7 @@ class Event(BaseModel): payload: dict[str, Any] = Field(default_factory=dict, description="Arbitrary event data.") id: str = Field(default_factory=_new_id, description="Unique event identifier (auto-generated hex UUID).") timestamp: datetime = Field( - default_factory=lambda: datetime.now(timezone.utc), + default_factory=lambda: datetime.now(UTC), description="When the event was created (UTC).", ) trace_context: SpanContext | None = Field( diff --git a/shared/contracts/events/incident.py b/shared/contracts/events/incident.py index c768d3f..730d8a6 100644 --- a/shared/contracts/events/incident.py +++ b/shared/contracts/events/incident.py @@ -1,6 +1,6 @@ """Incident event schemas for detection and lifecycle.""" -from datetime import datetime, timezone +from datetime import UTC, datetime from pydantic import BaseModel, Field @@ -18,6 +18,6 @@ class IncidentDetectedEvent(BaseModel): default_factory=list, description="IDs of the source telemetry events that triggered this." ) detected_at: datetime = Field( - default_factory=lambda: datetime.now(timezone.utc), + default_factory=lambda: datetime.now(UTC), description="When the incident was first detected (UTC).", ) diff --git a/shared/contracts/events/investigation.py b/shared/contracts/events/investigation.py index a4bf6f0..62064e7 100644 --- a/shared/contracts/events/investigation.py +++ b/shared/contracts/events/investigation.py @@ -1,6 +1,6 @@ """Investigation event schema for AI-driven root cause analysis.""" -from datetime import datetime, timezone +from datetime import UTC, datetime from typing import Any, Literal from pydantic import BaseModel, Field @@ -17,6 +17,6 @@ class InvestigationRequestedEvent(BaseModel): default="standard", description="Investigation depth: quick summary, standard analysis, or deep dive." ) requested_at: datetime = Field( - default_factory=lambda: datetime.now(timezone.utc), + default_factory=lambda: datetime.now(UTC), description="When the investigation was requested (UTC).", ) diff --git a/shared/contracts/events/telemetry.py b/shared/contracts/events/telemetry.py index 7cdd767..f9e46dc 100644 --- a/shared/contracts/events/telemetry.py +++ b/shared/contracts/events/telemetry.py @@ -1,6 +1,6 @@ """Telemetry event schema for ingestion and metric forwarding.""" -from datetime import datetime, timezone +from datetime import UTC, datetime from pydantic import BaseModel, Field @@ -13,6 +13,6 @@ class TelemetryEvent(BaseModel): tags: dict[str, str] = Field(default_factory=dict, description="Dimension key-value pairs.") source: str = Field(description="Service or component that produced the telemetry.") timestamp: datetime = Field( - default_factory=lambda: datetime.now(timezone.utc), + default_factory=lambda: datetime.now(UTC), description="When the measurement was taken (UTC).", ) diff --git a/shared/contracts/events/tests/test_contracts.py b/shared/contracts/events/tests/test_contracts.py index b802ba4..d29be6c 100644 --- a/shared/contracts/events/tests/test_contracts.py +++ b/shared/contracts/events/tests/test_contracts.py @@ -1,7 +1,6 @@ """Tests for telemetry and incident event contracts.""" import json -from datetime import datetime, timezone from shared.contracts.events import ( Event, @@ -75,9 +74,7 @@ def test_telemetry_in_envelope(self) -> None: assert envelope.payload["metric"] == "cpu" def test_incident_in_envelope(self) -> None: - incident = IncidentDetectedEvent( - incident_id="INC-001", severity=Severity.ERROR, service="svc", title="err" - ) + incident = IncidentDetectedEvent(incident_id="INC-001", severity=Severity.ERROR, service="svc", title="err") envelope = Event(source="svc", topic="incident.detected", payload=incident.model_dump()) assert envelope.payload["incident_id"] == "INC-001" diff --git a/shared/contracts/interfaces/__init__.py b/shared/contracts/interfaces/__init__.py index 574617b..f9bc3c2 100644 --- a/shared/contracts/interfaces/__init__.py +++ b/shared/contracts/interfaces/__init__.py @@ -1,7 +1,7 @@ """Provider-agnostic abstraction interfaces for RootPilot.""" from shared.contracts.interfaces.event_bus import EventBus -from shared.contracts.interfaces.llm_provider import LLMProvider, LLMMessage, LLMResponse +from shared.contracts.interfaces.llm_provider import LLMMessage, LLMProvider, LLMResponse from shared.contracts.interfaces.log_store import LogEntry, LogFilter, LogStore, SortOrder __all__ = [ diff --git a/shared/contracts/interfaces/event_bus.py b/shared/contracts/interfaces/event_bus.py index d4b2aaf..4e6922f 100644 --- a/shared/contracts/interfaces/event_bus.py +++ b/shared/contracts/interfaces/event_bus.py @@ -1,8 +1,8 @@ """Event bus abstraction for provider-agnostic async messaging.""" from abc import ABC, abstractmethod -from collections.abc import Coroutine -from typing import Any, Callable +from collections.abc import Callable, Coroutine +from typing import Any from shared.contracts.events.base import Event diff --git a/shared/contracts/interfaces/log_store.py b/shared/contracts/interfaces/log_store.py index 12633a6..9357949 100644 --- a/shared/contracts/interfaces/log_store.py +++ b/shared/contracts/interfaces/log_store.py @@ -3,7 +3,7 @@ from abc import ABC, abstractmethod from collections.abc import AsyncIterator from datetime import datetime -from enum import Enum +from enum import StrEnum from typing import Any from pydantic import BaseModel, Field @@ -19,7 +19,7 @@ class LogEntry(BaseModel): metadata: dict[str, Any] = Field(default_factory=dict, description="Arbitrary structured context.") -class SortOrder(str, Enum): +class SortOrder(StrEnum): ASC = "asc" DESC = "desc" diff --git a/shared/domain/__init__.py b/shared/domain/__init__.py index e7669a6..71144d8 100644 --- a/shared/domain/__init__.py +++ b/shared/domain/__init__.py @@ -15,7 +15,15 @@ TimeWindowStrategy, TraceIdStrategy, ) -from shared.domain.timeline import EventClassifier, IncidentTimeline, TimelineEvent, TimelineEventCategory, TimelineEventSource, TimelineReconstructor, TimelineWindow +from shared.domain.timeline import ( + EventClassifier, + IncidentTimeline, + TimelineEvent, + TimelineEventCategory, + TimelineEventSource, + TimelineReconstructor, + TimelineWindow, +) __all__ = [ "CorrelationEngine", @@ -39,4 +47,3 @@ "TimelineWindow", "TraceIdStrategy", ] - diff --git a/shared/domain/correlation/__init__.py b/shared/domain/correlation/__init__.py index a901bb7..9276783 100644 --- a/shared/domain/correlation/__init__.py +++ b/shared/domain/correlation/__init__.py @@ -19,8 +19,8 @@ ) __all__ = [ - "CorrelationEngine", "CorrelationContext", + "CorrelationEngine", "CorrelationGroup", "CorrelationMatch", "CorrelationPipeline", @@ -34,4 +34,3 @@ "TimeWindowStrategy", "TraceIdStrategy", ] - diff --git a/shared/domain/correlation/grouping/__init__.py b/shared/domain/correlation/grouping/__init__.py index bb872c2..15ea95f 100644 --- a/shared/domain/correlation/grouping/__init__.py +++ b/shared/domain/correlation/grouping/__init__.py @@ -6,6 +6,6 @@ __all__ = [ "SpanNode", "TraceGroup", - "TraceTree", "TraceGroupingService", + "TraceTree", ] diff --git a/shared/domain/correlation/grouping/models.py b/shared/domain/correlation/grouping/models.py index ed93901..1e601a7 100644 --- a/shared/domain/correlation/grouping/models.py +++ b/shared/domain/correlation/grouping/models.py @@ -1,7 +1,5 @@ from pydantic import BaseModel, Field -from shared.domain.timeline.models import TimelineEvent - class SpanNode(BaseModel): """A span node in a trace tree, with parent-child relationships.""" diff --git a/shared/domain/correlation/grouping/tests/test_trace_grouping.py b/shared/domain/correlation/grouping/tests/test_trace_grouping.py index db46cb8..2543ee4 100644 --- a/shared/domain/correlation/grouping/tests/test_trace_grouping.py +++ b/shared/domain/correlation/grouping/tests/test_trace_grouping.py @@ -1,7 +1,6 @@ -from datetime import datetime, timezone +from datetime import UTC, datetime from shared.domain.correlation.grouping import TraceGroupingService -from shared.domain.correlation.grouping.models import SpanNode, TraceTree from shared.domain.timeline.enums import TimelineEventCategory, TimelineEventSource from shared.domain.timeline.models import TimelineEvent @@ -12,9 +11,9 @@ def _event( span_id: str | None = None, parent_span_id: str | None = None, service: str = "api", - ts_offset: int = 0, + _ts_offset: int = 0, ) -> TimelineEvent: - base = datetime(2026, 6, 14, 10, 0, 0, tzinfo=timezone.utc) + base = datetime(2026, 6, 14, 10, 0, 0, tzinfo=UTC) return TimelineEvent( event_id=event_id, category=TimelineEventCategory.METRIC_ANOMALY, @@ -41,9 +40,11 @@ async def test_events_without_trace_id_are_ignored(self) -> None: assert trees == [] async def test_single_trace_single_span(self) -> None: - trees = self.service.build_trace_trees([ - _event("a", trace_id="t1", span_id="s1"), - ]) + trees = self.service.build_trace_trees( + [ + _event("a", trace_id="t1", span_id="s1"), + ] + ) assert len(trees) == 1 tree = trees[0] assert tree.trace_id == "t1" @@ -52,11 +53,13 @@ async def test_single_trace_single_span(self) -> None: assert tree.root_spans[0].span_id == "s1" async def test_parent_child_span_hierarchy(self) -> None: - trees = self.service.build_trace_trees([ - _event("a", trace_id="t1", span_id="s1"), - _event("b", trace_id="t1", span_id="s2", parent_span_id="s1"), - _event("c", trace_id="t1", span_id="s3", parent_span_id="s2"), - ]) + trees = self.service.build_trace_trees( + [ + _event("a", trace_id="t1", span_id="s1"), + _event("b", trace_id="t1", span_id="s2", parent_span_id="s1"), + _event("c", trace_id="t1", span_id="s3", parent_span_id="s2"), + ] + ) assert len(trees) == 1 tree = trees[0] assert tree.span_count == 3 @@ -69,39 +72,47 @@ async def test_parent_child_span_hierarchy(self) -> None: assert root.children[0].children[0].span_id == "s3" async def test_multiple_root_spans(self) -> None: - trees = self.service.build_trace_trees([ - _event("a", trace_id="t1", span_id="s1"), - _event("b", trace_id="t1", span_id="s2"), - ]) + trees = self.service.build_trace_trees( + [ + _event("a", trace_id="t1", span_id="s1"), + _event("b", trace_id="t1", span_id="s2"), + ] + ) assert len(trees) == 1 tree = trees[0] assert len(tree.root_spans) == 2 async def test_multiple_traces(self) -> None: - trees = self.service.build_trace_trees([ - _event("a", trace_id="t1", span_id="s1"), - _event("b", trace_id="t2", span_id="s2"), - ]) + trees = self.service.build_trace_trees( + [ + _event("a", trace_id="t1", span_id="s1"), + _event("b", trace_id="t2", span_id="s2"), + ] + ) assert len(trees) == 2 trace_ids = {t.trace_id for t in trees} assert trace_ids == {"t1", "t2"} async def test_multiple_events_per_span(self) -> None: - trees = self.service.build_trace_trees([ - _event("a", trace_id="t1", span_id="s1"), - _event("b", trace_id="t1", span_id="s1"), - ]) + trees = self.service.build_trace_trees( + [ + _event("a", trace_id="t1", span_id="s1"), + _event("b", trace_id="t1", span_id="s1"), + ] + ) assert len(trees) == 1 tree = trees[0] root = tree.root_spans[0] assert sorted(root.event_ids) == ["a", "b"] async def test_depth_calculation(self) -> None: - trees = self.service.build_trace_trees([ - _event("a", trace_id="t1", span_id="s1"), - _event("b", trace_id="t1", span_id="s2", parent_span_id="s1"), - _event("c", trace_id="t1", span_id="s3", parent_span_id="s2"), - ]) + trees = self.service.build_trace_trees( + [ + _event("a", trace_id="t1", span_id="s1"), + _event("b", trace_id="t1", span_id="s2", parent_span_id="s1"), + _event("c", trace_id="t1", span_id="s3", parent_span_id="s2"), + ] + ) assert trees[0].depth == 3 @@ -110,10 +121,12 @@ def setup_method(self) -> None: self.service = TraceGroupingService() async def test_build_trace_groups(self) -> None: - groups = self.service.build_trace_groups([ - _event("a", trace_id="t1", span_id="s1", service="api"), - _event("b", trace_id="t1", span_id="s2", service="db"), - ]) + groups = self.service.build_trace_groups( + [ + _event("a", trace_id="t1", span_id="s1", service="api"), + _event("b", trace_id="t1", span_id="s2", service="db"), + ] + ) assert len(groups) == 1 group = groups[0] assert group.trace_id == "t1" @@ -123,9 +136,11 @@ async def test_build_trace_groups(self) -> None: assert group.tree is not None async def test_trace_group_no_spans(self) -> None: - groups = self.service.build_trace_groups([ - _event("a", trace_id="t1"), - _event("b", trace_id="t1"), - ]) + groups = self.service.build_trace_groups( + [ + _event("a", trace_id="t1"), + _event("b", trace_id="t1"), + ] + ) assert len(groups) == 1 assert groups[0].span_count == 0 diff --git a/shared/domain/correlation/grouping/trace_grouping.py b/shared/domain/correlation/grouping/trace_grouping.py index 6bdc1f6..96bcbf6 100644 --- a/shared/domain/correlation/grouping/trace_grouping.py +++ b/shared/domain/correlation/grouping/trace_grouping.py @@ -77,10 +77,7 @@ def _build_single_tree(self, trace_id: str, events: list[TimelineEvent]) -> Trac for span in span_map.values(): span.children = children_map.get(span.span_id, []) - root_spans = [ - s for s in span_map.values() - if s.parent_span_id is None or s.parent_span_id not in span_map - ] + root_spans = [s for s in span_map.values() if s.parent_span_id is None or s.parent_span_id not in span_map] return TraceTree( trace_id=trace_id, diff --git a/shared/domain/correlation/models.py b/shared/domain/correlation/models.py index 35d1c36..fcd0f7a 100644 --- a/shared/domain/correlation/models.py +++ b/shared/domain/correlation/models.py @@ -1,4 +1,4 @@ -from datetime import datetime, timezone +from datetime import UTC, datetime from pydantic import BaseModel, Field @@ -34,7 +34,7 @@ class CorrelationResult(BaseModel): strategy_counts: dict[str, int] = Field(default_factory=dict, description="Matches per strategy.") duration_ms: float = Field(default=0.0, description="Pipeline execution time in ms.") created_at: datetime = Field( - default_factory=lambda: datetime.now(timezone.utc), + default_factory=lambda: datetime.now(UTC), description="When the result was produced.", ) diff --git a/shared/domain/correlation/pipeline.py b/shared/domain/correlation/pipeline.py index 984daa4..6db989d 100644 --- a/shared/domain/correlation/pipeline.py +++ b/shared/domain/correlation/pipeline.py @@ -48,7 +48,9 @@ async def run(self, context: CorrelationContext) -> CorrelationResult: duration_ms=round(duration, 2), ) - def _merge_into_groups(self, matches: list[CorrelationMatch], context: CorrelationContext) -> list[CorrelationGroup]: + def _merge_into_groups( + self, matches: list[CorrelationMatch], context: CorrelationContext + ) -> list[CorrelationGroup]: adj: dict[str, set[str]] = defaultdict(set) event_scores: dict[str, list[float]] = defaultdict(list) event_strategies: dict[str, set[CorrelationStrategyType]] = defaultdict(set) @@ -83,9 +85,7 @@ def _merge_into_groups(self, matches: list[CorrelationMatch], context: Correlati combined_scores: dict[str, float] = {} for eid in group: for strategy_key, sc in strategy_scores.get(eid, {}).items(): - combined_scores[strategy_key] = max( - combined_scores.get(strategy_key, 0.0), sc - ) + combined_scores[strategy_key] = max(combined_scores.get(strategy_key, 0.0), sc) groups.append( CorrelationGroup( group_id=uuid4().hex, diff --git a/shared/domain/correlation/strategies/span_relation.py b/shared/domain/correlation/strategies/span_relation.py index fb35e53..6a56cfb 100644 --- a/shared/domain/correlation/strategies/span_relation.py +++ b/shared/domain/correlation/strategies/span_relation.py @@ -1,6 +1,6 @@ from shared.domain.correlation.enums import CorrelationSignal, CorrelationStrategyType from shared.domain.correlation.grouping import TraceGroupingService -from shared.domain.correlation.grouping.models import SpanNode, TraceTree +from shared.domain.correlation.grouping.models import TraceTree from shared.domain.correlation.models import CorrelationContext, CorrelationMatch from shared.domain.correlation.strategies.base import CorrelationStrategy @@ -71,9 +71,11 @@ def _score_pair(self, eid_a: str, eid_b: str, tree: TraceTree) -> tuple[float, C for sa in spans_a: for sb in spans_b: - if (sa.parent_span_id is not None - and sb.parent_span_id is not None - and sa.parent_span_id == sb.parent_span_id): + if ( + sa.parent_span_id is not None + and sb.parent_span_id is not None + and sa.parent_span_id == sb.parent_span_id + ): return self.SIBLING_SCORE, CorrelationSignal.SPAN_SIBLING return self.SAME_TRACE_SCORE, CorrelationSignal.TRACE_MATCH diff --git a/shared/domain/correlation/tests/test_engine.py b/shared/domain/correlation/tests/test_engine.py index 3980364..3d863ca 100644 --- a/shared/domain/correlation/tests/test_engine.py +++ b/shared/domain/correlation/tests/test_engine.py @@ -15,7 +15,7 @@ def _event( ) -> TimelineEvent: import datetime - base = datetime.datetime(2026, 6, 12, 10, 0, 0, tzinfo=datetime.timezone.utc) + base = datetime.datetime(2026, 6, 12, 10, 0, 0, tzinfo=datetime.UTC) return TimelineEvent( event_id=event_id, category=TimelineEventCategory(category), @@ -43,19 +43,23 @@ async def test_single_event_no_group(self) -> None: async def test_two_correlated_events(self) -> None: engine = CorrelationEngine() - result = await engine.correlate([ - _event("a", trace_id="t1"), - _event("b", trace_id="t1"), - ]) + result = await engine.correlate( + [ + _event("a", trace_id="t1"), + _event("b", trace_id="t1"), + ] + ) assert result.total_events == 2 assert len(result.groups) == 1 async def test_noise_filtered(self) -> None: engine = CorrelationEngine() - result = await engine.correlate([ - _event("a", ts_offset=0), - _event("b", ts_offset=100), - ]) + result = await engine.correlate( + [ + _event("a", ts_offset=0), + _event("b", ts_offset=100), + ] + ) assert result.total_events == 2 assert len(result.groups) == 0 @@ -63,16 +67,20 @@ async def test_with_dependency_store(self) -> None: store = InMemoryGraphStore() await store.add_edge(DependencyEdge(source="api", target="db")) engine = CorrelationEngine(store=store) - result = await engine.correlate([ - _event("a", service="api"), - _event("b", service="db"), - ]) + result = await engine.correlate( + [ + _event("a", service="api"), + _event("b", service="db"), + ] + ) assert len(result.groups) == 1 async def test_strategy_counts_reported(self) -> None: engine = CorrelationEngine() - result = await engine.correlate([ - _event("a", trace_id="t1"), - _event("b", trace_id="t1"), - ]) + result = await engine.correlate( + [ + _event("a", trace_id="t1"), + _event("b", trace_id="t1"), + ] + ) assert result.strategy_counts.get("trace_id", 0) == 1 diff --git a/shared/domain/correlation/tests/test_pipeline.py b/shared/domain/correlation/tests/test_pipeline.py index 7c2fb3a..78b452f 100644 --- a/shared/domain/correlation/tests/test_pipeline.py +++ b/shared/domain/correlation/tests/test_pipeline.py @@ -2,7 +2,7 @@ from shared.domain.correlation.models import CorrelationContext from shared.domain.correlation.pipeline import CorrelationPipeline -from shared.domain.correlation.strategies import TraceIdStrategy, TimeWindowStrategy +from shared.domain.correlation.strategies import TimeWindowStrategy, TraceIdStrategy from shared.domain.timeline.enums import TimelineEventCategory, TimelineEventSource from shared.domain.timeline.models import TimelineEvent @@ -10,7 +10,7 @@ def _event(event_id: str, ts_offset: int = 0, trace_id: str | None = None) -> TimelineEvent: import datetime - base = datetime.datetime(2026, 6, 12, 10, 0, 0, tzinfo=datetime.timezone.utc) + base = datetime.datetime(2026, 6, 12, 10, 0, 0, tzinfo=datetime.UTC) return TimelineEvent( event_id=event_id, category=TimelineEventCategory.METRIC_ANOMALY, diff --git a/shared/domain/correlation/tests/test_span_relation.py b/shared/domain/correlation/tests/test_span_relation.py index b38e643..fc99a18 100644 --- a/shared/domain/correlation/tests/test_span_relation.py +++ b/shared/domain/correlation/tests/test_span_relation.py @@ -1,4 +1,4 @@ -from datetime import datetime, timezone +from datetime import UTC, datetime from shared.domain.correlation.enums import CorrelationSignal, CorrelationStrategyType from shared.domain.correlation.models import CorrelationContext @@ -19,7 +19,7 @@ def _event( event_id=event_id, category=TimelineEventCategory.METRIC_ANOMALY, source=TimelineEventSource.TELEMETRY, - timestamp=datetime(2026, 6, 14, 10, 0, 0, tzinfo=timezone.utc), + timestamp=datetime(2026, 6, 14, 10, 0, 0, tzinfo=UTC), service_name="api", title=f"event {event_id}", trace_id=trace_id, @@ -43,57 +43,69 @@ async def test_single_event_no_match(self) -> None: assert matches == [] async def test_parent_child_scores_highest(self) -> None: - ctx = CorrelationContext(events=[ - _event("a", span_id="s1"), - _event("b", span_id="s2", parent_span_id="s1"), - ]) + ctx = CorrelationContext( + events=[ + _event("a", span_id="s1"), + _event("b", span_id="s2", parent_span_id="s1"), + ] + ) matches = await self.strategy.correlate(ctx) assert len(matches) == 1 assert matches[0].score == SpanRelationStrategy.PARENT_CHILD_SCORE assert matches[0].signal == CorrelationSignal.SPAN_PARENT_CHILD async def test_siblings_score_middle(self) -> None: - ctx = CorrelationContext(events=[ - _event("a", span_id="s1", parent_span_id="s0"), - _event("b", span_id="s2", parent_span_id="s0"), - ]) + ctx = CorrelationContext( + events=[ + _event("a", span_id="s1", parent_span_id="s0"), + _event("b", span_id="s2", parent_span_id="s0"), + ] + ) matches = await self.strategy.correlate(ctx) assert len(matches) == 1 assert matches[0].score == SpanRelationStrategy.SIBLING_SCORE assert matches[0].signal == CorrelationSignal.SPAN_SIBLING async def test_same_trace_no_span_relation_scores_lowest(self) -> None: - ctx = CorrelationContext(events=[ - _event("a", trace_id=TRACE, span_id="s1"), - _event("b", trace_id=TRACE, span_id="s2"), - ]) + ctx = CorrelationContext( + events=[ + _event("a", trace_id=TRACE, span_id="s1"), + _event("b", trace_id=TRACE, span_id="s2"), + ] + ) matches = await self.strategy.correlate(ctx) assert len(matches) == 1 assert matches[0].score == SpanRelationStrategy.SAME_TRACE_SCORE assert matches[0].signal == CorrelationSignal.TRACE_MATCH async def test_events_without_span_ids_fall_back_to_same_trace(self) -> None: - ctx = CorrelationContext(events=[ - _event("a", trace_id=TRACE), - _event("b", trace_id=TRACE), - ]) + ctx = CorrelationContext( + events=[ + _event("a", trace_id=TRACE), + _event("b", trace_id=TRACE), + ] + ) matches = await self.strategy.correlate(ctx) assert len(matches) == 1 assert matches[0].score == SpanRelationStrategy.SAME_TRACE_SCORE async def test_different_traces_no_match(self) -> None: - ctx = CorrelationContext(events=[ - _event("a", trace_id="t1", span_id="s1"), - _event("b", trace_id="t2", span_id="s2"), - ]) + ctx = CorrelationContext( + events=[ + _event("a", trace_id="t1", span_id="s1"), + _event("b", trace_id="t2", span_id="s2"), + ] + ) matches = await self.strategy.correlate(ctx) assert matches == [] async def test_metadata_includes_trace_id(self) -> None: - ctx = CorrelationContext(events=[ - _event("a", span_id="s1"), - _event("b", span_id="s2", parent_span_id="s1"), - ]) + ctx = CorrelationContext( + events=[ + _event("a", span_id="s1"), + _event("b", span_id="s2", parent_span_id="s1"), + ] + ) matches = await self.strategy.correlate(ctx) assert matches[0].metadata.get("trace_id") == TRACE diff --git a/shared/domain/correlation/tests/test_strategies.py b/shared/domain/correlation/tests/test_strategies.py index 2afdee1..3c9f4dd 100644 --- a/shared/domain/correlation/tests/test_strategies.py +++ b/shared/domain/correlation/tests/test_strategies.py @@ -1,5 +1,3 @@ -import pytest - from shared.domain.correlation.models import CorrelationContext from shared.domain.correlation.strategies import ( DependencyStrategy, @@ -25,7 +23,7 @@ def _event( ) -> TimelineEvent: import datetime - base = datetime.datetime(2026, 6, 12, 10, 0, 0, tzinfo=datetime.timezone.utc) + base = datetime.datetime(2026, 6, 12, 10, 0, 0, tzinfo=datetime.UTC) return TimelineEvent( event_id=event_id, category=TimelineEventCategory(category), @@ -54,7 +52,9 @@ async def test_no_match_distant_events(self) -> None: assert len(matches) == 0 async def test_score_decays_with_distance(self) -> None: - ctx = CorrelationContext(events=[_event("a", ts_offset=0), _event("b", ts_offset=30), _event("c", ts_offset=55)]) + ctx = CorrelationContext( + events=[_event("a", ts_offset=0), _event("b", ts_offset=30), _event("c", ts_offset=55)] + ) s = TimeWindowStrategy(window_seconds=60) matches = await s.correlate(ctx) scores = {(m.event_id_a, m.event_id_b): m.score for m in matches} @@ -127,9 +127,7 @@ class TestDependencyStrategy: async def test_matches_dependent_services(self) -> None: store = InMemoryGraphStore() await store.add_edge(DependencyEdge(source="api", target="db")) - ctx = CorrelationContext( - events=[_event("a", service="api"), _event("b", service="db")] - ) + ctx = CorrelationContext(events=[_event("a", service="api"), _event("b", service="db")]) s = DependencyStrategy(store=store) matches = await s.correlate(ctx) assert len(matches) == 1 @@ -138,9 +136,7 @@ async def test_matches_dependent_services(self) -> None: async def test_no_match_independent_services(self) -> None: store = InMemoryGraphStore() await store.add_edge(DependencyEdge(source="api", target="db")) - ctx = CorrelationContext( - events=[_event("a", service="api"), _event("c", service="worker")] - ) + ctx = CorrelationContext(events=[_event("a", service="api"), _event("c", service="worker")]) s = DependencyStrategy(store=store) matches = await s.correlate(ctx) assert len(matches) == 0 @@ -148,9 +144,7 @@ async def test_no_match_independent_services(self) -> None: async def test_uses_edge_weight(self) -> None: store = InMemoryGraphStore() await store.add_edge(DependencyEdge(source="api", target="db", weight=0.5)) - ctx = CorrelationContext( - events=[_event("a", service="api"), _event("b", service="db")] - ) + ctx = CorrelationContext(events=[_event("a", service="api"), _event("b", service="db")]) s = DependencyStrategy(store=store) matches = await s.correlate(ctx) assert matches[0].score == 0.5 diff --git a/shared/domain/timeline/models.py b/shared/domain/timeline/models.py index a03bdb4..dae9a8e 100644 --- a/shared/domain/timeline/models.py +++ b/shared/domain/timeline/models.py @@ -1,4 +1,4 @@ -from datetime import datetime, timezone +from datetime import UTC, datetime from pydantic import BaseModel, Field @@ -26,7 +26,9 @@ class TimelineEvent(BaseModel): class TimelineWindow(BaseModel): window_start: datetime = Field(description="Start of the time window (UTC).") window_end: datetime = Field(description="End of the time window (UTC).") - events: list[TimelineEvent] = Field(default_factory=list, description="Events in this window, sorted chronologically.") + events: list[TimelineEvent] = Field( + default_factory=list, description="Events in this window, sorted chronologically." + ) @property def duration_seconds(self) -> float: @@ -43,7 +45,7 @@ class IncidentTimeline(BaseModel): windows: list[TimelineWindow] = Field(default_factory=list, description="Time-windowed event groups.") window_duration_seconds: int = Field(default=300, ge=1, description="Size of each time window in seconds.") created_at: datetime = Field( - default_factory=lambda: datetime.now(timezone.utc), + default_factory=lambda: datetime.now(UTC), description="When the timeline was built (UTC).", ) diff --git a/shared/domain/timeline/services/reconstructor.py b/shared/domain/timeline/services/reconstructor.py index 11a2618..a7c55d8 100644 --- a/shared/domain/timeline/services/reconstructor.py +++ b/shared/domain/timeline/services/reconstructor.py @@ -1,6 +1,6 @@ import re from collections.abc import Callable -from datetime import datetime, timedelta, timezone +from datetime import UTC, datetime, timedelta from uuid import uuid4 from shared.contracts.events.telemetry import TelemetryEvent @@ -10,7 +10,7 @@ ClassifierFn = Callable[[str, float, dict[str, str]], TimelineEventCategory] -def _default_classifier(metric: str, value: float, tags: dict[str, str]) -> TimelineEventCategory: +def _default_classifier(metric: str, _value: float, _tags: dict[str, str]) -> TimelineEventCategory: _patterns: dict[re.Pattern[str], TimelineEventCategory] = { re.compile(r"^(error|failure|exception|fault)", re.IGNORECASE): TimelineEventCategory.FAILURE, re.compile(r"^retry", re.IGNORECASE): TimelineEventCategory.RETRY, @@ -110,4 +110,4 @@ def _group_into_windows(self, events: list[TimelineEvent]) -> list[TimelineWindo def _floor_to_window(self, dt: datetime) -> datetime: epoch = int(dt.timestamp()) floored = epoch - (epoch % self._window_duration) - return datetime.fromtimestamp(floored, tz=timezone.utc) + return datetime.fromtimestamp(floored, tz=UTC) diff --git a/shared/domain/timeline/tests/test_models.py b/shared/domain/timeline/tests/test_models.py index 8809934..d495757 100644 --- a/shared/domain/timeline/tests/test_models.py +++ b/shared/domain/timeline/tests/test_models.py @@ -1,4 +1,4 @@ -from datetime import datetime, timedelta, timezone +from datetime import UTC, datetime, timedelta from shared.contracts.events.enums import Severity from shared.domain.timeline.enums import TimelineEventCategory, TimelineEventSource @@ -7,7 +7,7 @@ class TestTimelineEvent: async def test_minimal_event(self) -> None: - ts = datetime.now(timezone.utc) + ts = datetime.now(UTC) event = TimelineEvent( event_id="evt-1", category=TimelineEventCategory.FAILURE, @@ -23,7 +23,7 @@ async def test_minimal_event(self) -> None: assert event.tags == {} async def test_event_with_all_fields(self) -> None: - ts = datetime.now(timezone.utc) + ts = datetime.now(UTC) event = TimelineEvent( event_id="evt-2", category=TimelineEventCategory.DEPLOYMENT, @@ -43,7 +43,7 @@ async def test_event_with_all_fields(self) -> None: assert event.severity == Severity.INFO async def test_round_trip_json(self) -> None: - ts = datetime.now(timezone.utc) + ts = datetime.now(UTC) original = TimelineEvent( event_id="evt-3", category=TimelineEventCategory.RECOVERY, @@ -58,20 +58,20 @@ async def test_round_trip_json(self) -> None: class TestTimelineWindow: async def test_empty_window(self) -> None: - start = datetime(2026, 6, 12, 10, 0, 0, tzinfo=timezone.utc) - end = datetime(2026, 6, 12, 10, 5, 0, tzinfo=timezone.utc) + start = datetime(2026, 6, 12, 10, 0, 0, tzinfo=UTC) + end = datetime(2026, 6, 12, 10, 5, 0, tzinfo=UTC) window = TimelineWindow(window_start=start, window_end=end) assert window.event_count == 0 assert window.duration_seconds == 300.0 async def test_window_with_events(self) -> None: - start = datetime(2026, 6, 12, 10, 0, 0, tzinfo=timezone.utc) - end = datetime(2026, 6, 12, 10, 5, 0, tzinfo=timezone.utc) + start = datetime(2026, 6, 12, 10, 0, 0, tzinfo=UTC) + end = datetime(2026, 6, 12, 10, 5, 0, tzinfo=UTC) event = TimelineEvent( event_id="e1", category=TimelineEventCategory.FAILURE, source=TimelineEventSource.TELEMETRY, - timestamp=datetime(2026, 6, 12, 10, 2, 0, tzinfo=timezone.utc), + timestamp=datetime(2026, 6, 12, 10, 2, 0, tzinfo=UTC), service_name="api", title="failure", ) @@ -91,9 +91,9 @@ async def test_empty_timeline(self) -> None: assert timeline.end_time is None async def test_timeline_with_windows(self) -> None: - start = datetime(2026, 6, 12, 10, 0, 0, tzinfo=timezone.utc) - mid = datetime(2026, 6, 12, 10, 5, 0, tzinfo=timezone.utc) - end = datetime(2026, 6, 12, 10, 10, 0, tzinfo=timezone.utc) + start = datetime(2026, 6, 12, 10, 0, 0, tzinfo=UTC) + mid = datetime(2026, 6, 12, 10, 5, 0, tzinfo=UTC) + end = datetime(2026, 6, 12, 10, 10, 0, tzinfo=UTC) w1 = TimelineWindow( window_start=start, @@ -103,7 +103,7 @@ async def test_timeline_with_windows(self) -> None: event_id="e1", category=TimelineEventCategory.FAILURE, source=TimelineEventSource.TELEMETRY, - timestamp=datetime(2026, 6, 12, 10, 2, 0, tzinfo=timezone.utc), + timestamp=datetime(2026, 6, 12, 10, 2, 0, tzinfo=UTC), service_name="api", title="fail-1", ) @@ -117,7 +117,7 @@ async def test_timeline_with_windows(self) -> None: event_id="e2", category=TimelineEventCategory.RETRY, source=TimelineEventSource.TELEMETRY, - timestamp=datetime(2026, 6, 12, 10, 7, 0, tzinfo=timezone.utc), + timestamp=datetime(2026, 6, 12, 10, 7, 0, tzinfo=UTC), service_name="api", title="retry-1", ), @@ -125,7 +125,7 @@ async def test_timeline_with_windows(self) -> None: event_id="e3", category=TimelineEventCategory.RECOVERY, source=TimelineEventSource.TELEMETRY, - timestamp=datetime(2026, 6, 12, 10, 9, 0, tzinfo=timezone.utc), + timestamp=datetime(2026, 6, 12, 10, 9, 0, tzinfo=UTC), service_name="api", title="recovered", ), @@ -149,7 +149,7 @@ async def test_timeline_with_windows(self) -> None: assert [e.event_id for e in flattened] == ["e1", "e2", "e3"] async def test_json_round_trip(self) -> None: - ts = datetime(2026, 6, 12, 10, 0, 0, tzinfo=timezone.utc) + ts = datetime(2026, 6, 12, 10, 0, 0, tzinfo=UTC) timeline = IncidentTimeline( incident_id="inc-2", service="db", diff --git a/shared/domain/timeline/tests/test_reconstructor.py b/shared/domain/timeline/tests/test_reconstructor.py index 1fe1d9d..cec5aee 100644 --- a/shared/domain/timeline/tests/test_reconstructor.py +++ b/shared/domain/timeline/tests/test_reconstructor.py @@ -1,4 +1,6 @@ -from datetime import datetime, timezone +from datetime import UTC, datetime + +import pytest from shared.contracts.events.telemetry import TelemetryEvent from shared.domain.timeline.enums import TimelineEventCategory, TimelineEventSource @@ -14,7 +16,7 @@ async def test_empty_events(self) -> None: assert timeline.window_count == 0 async def test_single_event_single_window(self) -> None: - ts = datetime(2026, 6, 12, 10, 0, 0, tzinfo=timezone.utc) + ts = datetime(2026, 6, 12, 10, 0, 0, tzinfo=UTC) event = TimelineEvent( event_id="e1", category=TimelineEventCategory.FAILURE, @@ -30,7 +32,7 @@ async def test_single_event_single_window(self) -> None: async def test_events_grouped_into_separate_windows(self) -> None: r = TimelineReconstructor(window_duration_seconds=300) - base = datetime(2026, 6, 12, 10, 0, 0, tzinfo=timezone.utc) + base = datetime(2026, 6, 12, 10, 0, 0, tzinfo=UTC) events = [ TimelineEvent( @@ -66,7 +68,7 @@ async def test_events_grouped_into_separate_windows(self) -> None: async def test_events_sorted_chronologically(self) -> None: r = TimelineReconstructor(window_duration_seconds=300) - base = datetime(2026, 6, 12, 10, 0, 0, tzinfo=timezone.utc) + base = datetime(2026, 6, 12, 10, 0, 0, tzinfo=UTC) events = [ TimelineEvent( @@ -101,7 +103,7 @@ async def test_events_sorted_chronologically(self) -> None: async def test_window_boundary_alignment(self) -> None: r = TimelineReconstructor(window_duration_seconds=60) - ts = datetime(2026, 6, 12, 10, 2, 30, tzinfo=timezone.utc) + ts = datetime(2026, 6, 12, 10, 2, 30, tzinfo=UTC) event = TimelineEvent( event_id="e1", category=TimelineEventCategory.FAILURE, @@ -166,7 +168,7 @@ async def test_telemetry_metadata_preserved(self) -> None: class TestFullReconstruction: async def test_build_timeline_from_telemetry(self) -> None: r = TimelineReconstructor(window_duration_seconds=300) - base = datetime(2026, 6, 12, 10, 0, 0, tzinfo=timezone.utc) + base = datetime(2026, 6, 12, 10, 0, 0, tzinfo=UTC) telemetry_events = [ TelemetryEvent(metric="error.rate", value=0.3, source="api", timestamp=base), @@ -193,15 +195,9 @@ async def test_build_timeline_from_telemetry(self) -> None: class TestWindowDurationValidation: async def test_rejects_zero_window_duration(self) -> None: - try: + with pytest.raises(ValueError, match="window_duration_seconds must be >= 1"): TimelineReconstructor(window_duration_seconds=0) - assert False, "Should have raised ValueError" - except ValueError: - pass async def test_rejects_negative_window_duration(self) -> None: - try: + with pytest.raises(ValueError, match="window_duration_seconds must be >= 1"): TimelineReconstructor(window_duration_seconds=-1) - assert False, "Should have raised ValueError" - except ValueError: - pass diff --git a/shared/observability/tracing/__init__.py b/shared/observability/tracing/__init__.py index 2ca8f25..58f2fe3 100644 --- a/shared/observability/tracing/__init__.py +++ b/shared/observability/tracing/__init__.py @@ -4,10 +4,10 @@ from shared.observability.tracing.provider import Span, Tracer, TracerProvider __all__ = [ + "Span", "SpanContext", "SpanKind", "SpanStatus", - "Span", "Tracer", "TracerProvider", ] diff --git a/shared/observability/tracing/models.py b/shared/observability/tracing/models.py index 8fb2424..98a934e 100644 --- a/shared/observability/tracing/models.py +++ b/shared/observability/tracing/models.py @@ -1,5 +1,4 @@ from enum import StrEnum -from typing import Any from pydantic import BaseModel, Field diff --git a/shared/observability/tracing/provider.py b/shared/observability/tracing/provider.py index f2d85a7..23d4fba 100644 --- a/shared/observability/tracing/provider.py +++ b/shared/observability/tracing/provider.py @@ -8,25 +8,20 @@ class Span(ABC): """Represents a single unit of work within a distributed trace.""" @abstractmethod - def set_attribute(self, key: str, value: str | bool | float | int) -> None: - ... + def set_attribute(self, key: str, value: str | bool | float | int) -> None: ... @abstractmethod - def set_status(self, status: SpanStatus | int, description: str | None = None) -> None: - ... + def set_status(self, status: SpanStatus | int, description: str | None = None) -> None: ... @abstractmethod - def add_event(self, name: str, attributes: dict[str, Any] | None = None) -> None: - ... + def add_event(self, name: str, attributes: dict[str, Any] | None = None) -> None: ... @abstractmethod - def end(self) -> None: - ... + def end(self) -> None: ... @property @abstractmethod - def context(self) -> SpanContext: - ... + def context(self) -> SpanContext: ... class Tracer(ABC): @@ -39,29 +34,23 @@ def start_span( context: SpanContext | None = None, kind: SpanKind = SpanKind.INTERNAL, attributes: dict[str, Any] | None = None, - ) -> Span: - ... + ) -> Span: ... class TracerProvider(ABC): """Provider-agnostic entry point for distributed tracing.""" @abstractmethod - def get_tracer(self, name: str, version: str | None = None) -> Tracer: - ... + def get_tracer(self, name: str, version: str | None = None) -> Tracer: ... @abstractmethod - def inject(self, context: SpanContext, headers: dict[str, str]) -> dict[str, str]: - ... + def inject(self, context: SpanContext, headers: dict[str, str]) -> dict[str, str]: ... @abstractmethod - def extract(self, headers: dict[str, str]) -> SpanContext | None: - ... + def extract(self, headers: dict[str, str]) -> SpanContext | None: ... @abstractmethod - async def force_flush(self) -> None: - ... + async def force_flush(self) -> None: ... @abstractmethod - async def shutdown(self) -> None: - ... + async def shutdown(self) -> None: ...