Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 34 additions & 44 deletions code_review_graph/flows.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from typing import Optional

from .constants import SECURITY_KEYWORDS as _SECURITY_KEYWORDS
from .graph import GraphNode, GraphStore, _sanitize_name
from .graph import FlowAdjacency, GraphNode, GraphStore, _sanitize_name

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -201,7 +201,7 @@ def detect_entry_points(


def _trace_single_flow(
store: GraphStore,
adj: FlowAdjacency,
ep: GraphNode,
max_depth: int = 15,
) -> Optional[dict]:
Expand All @@ -210,18 +210,14 @@ def _trace_single_flow(
Returns a flow dict (see :func:`trace_flows` for the schema) or ``None``
if the flow is trivial (single-node, no outgoing CALLS that resolve).
"""
path_ids: list[int] = []
path_qnames: list[str] = []
visited: set[str] = set()
queue: deque[tuple[str, int]] = deque()

# Seed with the entry point itself.
queue.append((ep.qualified_name, 0))
visited.add(ep.qualified_name)
path_ids.append(ep.id)
path_qnames.append(ep.qualified_name)
path_ids: list[int] = [ep.id]
path_qnames: list[str] = [ep.qualified_name]
visited: set[str] = {ep.qualified_name}
queue: deque[tuple[str, int]] = deque([(ep.qualified_name, 0)])

actual_depth = 0
nodes_by_qn = adj.nodes_by_qn
calls_out = adj.calls_out

while queue:
current_qn, depth = queue.popleft()
Expand All @@ -230,16 +226,10 @@ def _trace_single_flow(
if depth >= max_depth:
continue

# Follow forward CALLS edges.
edges = store.get_edges_by_source(current_qn)
for edge in edges:
if edge.kind != "CALLS":
continue
target_qn = edge.target_qualified
for target_qn in calls_out.get(current_qn, ()):
if target_qn in visited:
continue
# Resolve the target node to get its id.
target_node = store.get_node(target_qn)
target_node = nodes_by_qn.get(target_qn)
if target_node is None:
continue
visited.add(target_qn)
Expand All @@ -254,7 +244,7 @@ def _trace_single_flow(
files = list({
n.file_path
for qn in path_qnames
if (n := store.get_node(qn)) is not None
if (n := nodes_by_qn.get(qn)) is not None
})

flow: dict = {
Expand All @@ -268,7 +258,7 @@ def _trace_single_flow(
"files": files,
"criticality": 0.0,
}
flow["criticality"] = compute_criticality(flow, store)
flow["criticality"] = compute_criticality(flow, adj)
return flow


Expand All @@ -291,10 +281,14 @@ def trace_flows(
- criticality: computed criticality score (0.0-1.0)
"""
entry_points = detect_entry_points(store, include_tests=include_tests)
if not entry_points:
return []

adj = store.load_flow_adjacency()
flows: list[dict] = []

for ep in entry_points:
flow = _trace_single_flow(store, ep, max_depth)
flow = _trace_single_flow(adj, ep, max_depth)
if flow is not None:
flows.append(flow)

Expand All @@ -308,7 +302,7 @@ def trace_flows(
# ---------------------------------------------------------------------------


def compute_criticality(flow: dict, store: GraphStore) -> float:
def compute_criticality(flow: dict, adj: FlowAdjacency) -> float:
"""Score a flow from 0.0 to 1.0 based on multiple weighted factors.

Weights:
Expand All @@ -322,13 +316,14 @@ def compute_criticality(flow: dict, store: GraphStore) -> float:
if not node_ids:
return 0.0

# Resolve nodes once.
nodes: list[GraphNode] = []
for nid in node_ids:
n = store.get_node_by_id(nid)
if n:
nodes.append(n)
nodes_by_id = adj.nodes_by_id
nodes_by_qn = adj.nodes_by_qn
calls_out = adj.calls_out
has_tested_by = adj.has_tested_by

nodes: list[GraphNode] = [
n for nid in node_ids if (n := nodes_by_id.get(nid)) is not None
]
if not nodes:
return 0.0

Expand All @@ -341,9 +336,8 @@ def compute_criticality(flow: dict, store: GraphStore) -> float:
# Calls that target nodes NOT in the graph are considered external.
external_count = 0
for n in nodes:
edges = store.get_edges_by_source(n.qualified_name)
for e in edges:
if e.kind == "CALLS" and store.get_node(e.target_qualified) is None:
for target_qn in calls_out.get(n.qualified_name, ()):
if target_qn not in nodes_by_qn:
external_count += 1
# Normalize: 0 => 0.0, 5+ => 1.0
external_score = min(external_count / 5.0, 1.0)
Expand All @@ -360,13 +354,7 @@ def compute_criticality(flow: dict, store: GraphStore) -> float:
security_score = min(security_hits / max(len(nodes), 1), 1.0)

# --- Test coverage gap (0.0 - 1.0) ---
tested_count = 0
for n in nodes:
tested_edges = store.get_edges_by_target(n.qualified_name)
for te in tested_edges:
if te.kind == "TESTED_BY":
tested_count += 1
break
tested_count = sum(1 for n in nodes if n.qualified_name in has_tested_by)
coverage = tested_count / max(len(nodes), 1)
test_gap = 1.0 - coverage

Expand Down Expand Up @@ -514,10 +502,12 @@ def incremental_trace_flows(
# 5. BFS-trace each relevant entry point
# ------------------------------------------------------------------
new_flows: list[dict] = []
for ep in relevant_eps:
flow = _trace_single_flow(store, ep, max_depth)
if flow is not None:
new_flows.append(flow)
if relevant_eps:
adj = store.load_flow_adjacency()
for ep in relevant_eps:
flow = _trace_single_flow(adj, ep, max_depth)
if flow is not None:
new_flows.append(flow)

# ------------------------------------------------------------------
# 6. INSERT new flows without clearing unrelated ones
Expand Down
50 changes: 50 additions & 0 deletions code_review_graph/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,20 @@ class GraphEdge:
confidence_tier: str = "EXTRACTED"


@dataclass
class FlowAdjacency:
    """In-memory adjacency structure for flow tracing.

    Loaded once via :meth:`GraphStore.load_flow_adjacency` and passed to
    ``trace_flows`` / ``compute_criticality`` to avoid per-edge SQLite
    point queries on large graphs.
    """
    # Forward CALLS adjacency: source qualified name -> list of target
    # qualified names (one entry per CALLS edge, duplicates preserved).
    calls_out: dict[str, list[str]]
    # Qualified names that appear as the *target* of at least one
    # TESTED_BY edge, i.e. nodes that have test coverage.
    has_tested_by: set[str]
    # Every graph node keyed by its qualified name.
    nodes_by_qn: dict[str, "GraphNode"]
    # Every graph node keyed by its integer id.
    nodes_by_id: dict[int, "GraphNode"]


@dataclass
class GraphStats:
total_nodes: int
Expand Down Expand Up @@ -1199,6 +1213,42 @@ def _batch_get_nodes(self, qualified_names: set[str]) -> list[GraphNode]:
results.extend(self._row_to_node(r) for r in rows)
return results

def load_flow_adjacency(self) -> "FlowAdjacency":
    """Build an in-memory :class:`FlowAdjacency` from the graph tables.

    Streams the full ``nodes`` table and the CALLS/TESTED_BY subset of
    the ``edges`` table in two queries, returning node lookups keyed by
    qualified name and by id plus a forward CALLS adjacency map. This
    replaces the tens of millions of single-row SQLite point queries
    that otherwise dominate ``trace_flows`` / ``compute_criticality``
    runtime on large graphs (~500k nodes / 3M edges fits in a few
    hundred MB).
    """
    by_qn: dict[str, GraphNode] = {}
    by_id: dict[int, GraphNode] = {}
    # Single streaming pass over every node; index it both ways.
    for raw in self._conn.execute("SELECT * FROM nodes"):
        gn = self._row_to_node(raw)
        by_qn[gn.qualified_name] = gn
        by_id[gn.id] = gn

    outgoing_calls: dict[str, list[str]] = {}
    tested_targets: set[str] = set()
    edge_rows = self._conn.execute(
        "SELECT kind, source_qualified, target_qualified FROM edges "
        "WHERE kind IN ('CALLS', 'TESTED_BY')"
    )
    for raw in edge_rows:
        if raw["kind"] == "CALLS":
            targets = outgoing_calls.setdefault(raw["source_qualified"], [])
            targets.append(raw["target_qualified"])
        else:  # TESTED_BY
            tested_targets.add(raw["target_qualified"])

    return FlowAdjacency(
        calls_out=outgoing_calls,
        has_tested_by=tested_targets,
        nodes_by_qn=by_qn,
        nodes_by_id=by_id,
    )

# --- Internal helpers ---

def _build_networkx_graph(self) -> nx.DiGraph:
Expand Down