Skip to content

Commit d0bcfbb

Browse files
authored
Merge branch 'microsoft:main' into main
2 parents 321ad29 + bb9afeb commit d0bcfbb

File tree

12 files changed

+791
-132
lines changed

12 files changed

+791
-132
lines changed
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
{
2+
"type": "patch",
3+
"description": "extract_graph_nlp streaming"
4+
}
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
{
2+
"type": "patch",
3+
"description": "filter phantom relationships in graph"
4+
}
Lines changed: 91 additions & 89 deletions
Original file line numberDiff line numberDiff line change
@@ -1,141 +1,143 @@
1-
# Copyright (c) 2024 Microsoft Corporation.
1+
# Copyright (C) 2026 Microsoft
22
# Licensed under the MIT License
33

44
"""Graph extraction using NLP."""
55

6+
import logging
7+
from collections import defaultdict
68
from itertools import combinations
79

8-
import numpy as np
910
import pandas as pd
1011
from graphrag_cache import Cache
12+
from graphrag_storage.tables.table import Table
1113

12-
from graphrag.config.enums import AsyncType
1314
from graphrag.graphs.edge_weights import calculate_pmi_edge_weights
1415
from graphrag.index.operations.build_noun_graph.np_extractors.base import (
1516
BaseNounPhraseExtractor,
1617
)
17-
from graphrag.index.utils.derive_from_rows import derive_from_rows
1818
from graphrag.index.utils.hashing import gen_sha512_hash
1919

20+
logger = logging.getLogger(__name__)
21+
2022

2123
async def build_noun_graph(
    text_unit_table: Table,
    text_analyzer: BaseNounPhraseExtractor,
    normalize_edge_weights: bool,
    cache: Cache,
) -> tuple[pd.DataFrame, pd.DataFrame]:
    """Build a noun graph from text units.

    Parameters
    ----------
    text_unit_table:
        Streaming table of text units; each row must provide ``id``
        and ``text`` fields.
    text_analyzer:
        Noun-phrase extractor used to derive node titles.
    normalize_edge_weights:
        When True, raw co-occurrence counts are replaced with PMI
        edge weights.
    cache:
        Cache used to skip repeated NLP extraction of identical text.

    Returns
    -------
    tuple[pd.DataFrame, pd.DataFrame]
        ``(nodes_df, edges_df)`` where nodes have schema
        ``[title, frequency, text_unit_ids]`` and edges have schema
        ``[source, target, weight, text_unit_ids]``.
    """
    title_to_ids = await _extract_nodes(
        text_unit_table,
        text_analyzer,
        cache=cache,
    )

    # Materialize the node table. Passing `columns=` explicitly keeps
    # the schema stable even when no noun phrases were extracted.
    nodes_df = pd.DataFrame(
        [
            {
                "title": title,
                "frequency": len(ids),
                "text_unit_ids": ids,
            }
            for title, ids in title_to_ids.items()
        ],
        columns=["title", "frequency", "text_unit_ids"],
    )

    edges_df = _extract_edges(
        title_to_ids,
        nodes_df=nodes_df,
        normalize_edge_weights=normalize_edge_weights,
    )
    return (nodes_df, edges_df)
4054

4155

4256
async def _extract_nodes(
    text_unit_table: Table,
    text_analyzer: BaseNounPhraseExtractor,
    cache: Cache,
) -> dict[str, list[str]]:
    """Extract noun-phrase nodes from text units.

    NLP extraction is CPU-bound (spaCy/TextBlob), so threading
    provides no benefit under the GIL. We process rows
    sequentially, relying on the cache to skip repeated work.

    Parameters
    ----------
    text_unit_table:
        Streaming table of text units with ``id`` and ``text`` fields.
    text_analyzer:
        Extractor producing an iterable of noun phrases per text.
    cache:
        Parent cache; a ``extract_noun_phrases`` child scope is used.

    Returns
    -------
    dict[str, list[str]]
        Mapping of noun-phrase title to the list of text-unit ids the
        phrase appears in (duplicates per unit are preserved as-is).
    """
    extraction_cache = cache.child("extract_noun_phrases")
    total = await text_unit_table.length()
    title_to_ids: dict[str, list[str]] = defaultdict(list)
    completed = 0

    async for row in text_unit_table:
        text_unit_id = row["id"]
        text = row["text"]

        # Cache key covers both the text and the analyzer configuration,
        # so switching analyzers never serves stale extractions.
        attrs = {"text": text, "analyzer": str(text_analyzer)}
        key = gen_sha512_hash(attrs, attrs.keys())
        result = await extraction_cache.get(key)
        if not result:
            # NOTE(review): a cached *empty* result is falsy and is
            # re-extracted on every run — presumably acceptable since
            # empty extractions are cheap; confirm Cache.get semantics.
            result = text_analyzer.extract(text)
            await extraction_cache.set(key, result)

        for phrase in result:
            title_to_ids[phrase].append(text_unit_id)

        completed += 1
        # Log every 100 rows and at the end, replacing the old
        # per-row progress callback from derive_from_rows.
        if completed % 100 == 0 or completed == total:
            logger.info(
                "extract noun phrases progress: %d/%d",
                completed,
                total,
            )

    # Return a plain dict so callers don't accidentally rely on
    # defaultdict auto-insertion.
    return dict(title_to_ids)
8897

8998

9099
def _extract_edges(
100+
title_to_ids: dict[str, list[str]],
91101
nodes_df: pd.DataFrame,
92102
normalize_edge_weights: bool = True,
93103
) -> pd.DataFrame:
94-
"""
95-
Extract edges from nodes.
104+
"""Build co-occurrence edges between noun phrases.
96105
97-
Nodes appear in the same text unit are connected.
98-
Input: nodes_df with schema [id, title, frequency, text_unit_ids]
99-
Returns: edges_df with schema [source, target, weight, text_unit_ids]
106+
Nodes that appear in the same text unit are connected.
107+
Returns edges with schema [source, target, weight, text_unit_ids].
100108
"""
101-
if nodes_df.empty:
102-
return pd.DataFrame(columns=["source", "target", "weight", "text_unit_ids"])
103-
104-
text_units_df = nodes_df.explode("text_unit_ids")
105-
text_units_df = text_units_df.rename(columns={"text_unit_ids": "text_unit_id"})
106-
text_units_df = (
107-
text_units_df
108-
.groupby("text_unit_id")
109-
.agg({"title": lambda x: list(x) if len(x) > 1 else np.nan})
110-
.reset_index()
111-
)
112-
text_units_df = text_units_df.dropna()
113-
titles = text_units_df["title"].tolist()
114-
all_edges: list[list[tuple[str, str]]] = [list(combinations(t, 2)) for t in titles]
115-
116-
text_units_df = text_units_df.assign(edges=all_edges) # type: ignore
117-
edge_df = text_units_df.explode("edges")[["edges", "text_unit_id"]]
118-
119-
edge_df[["source", "target"]] = edge_df.loc[:, "edges"].to_list()
120-
edge_df["min_source"] = edge_df[["source", "target"]].min(axis=1)
121-
edge_df["max_target"] = edge_df[["source", "target"]].max(axis=1)
122-
edge_df = edge_df.drop(columns=["source", "target"]).rename(
123-
columns={"min_source": "source", "max_target": "target"} # type: ignore
109+
if not title_to_ids:
110+
return pd.DataFrame(
111+
columns=["source", "target", "weight", "text_unit_ids"],
112+
)
113+
114+
text_unit_to_titles: dict[str, list[str]] = defaultdict(list)
115+
for title, tu_ids in title_to_ids.items():
116+
for tu_id in tu_ids:
117+
text_unit_to_titles[tu_id].append(title)
118+
119+
edge_map: dict[tuple[str, str], list[str]] = defaultdict(list)
120+
for tu_id, titles in text_unit_to_titles.items():
121+
if len(titles) < 2:
122+
continue
123+
for pair in combinations(sorted(set(titles)), 2):
124+
edge_map[pair].append(tu_id)
125+
126+
records = [
127+
{
128+
"source": src,
129+
"target": tgt,
130+
"weight": len(tu_ids),
131+
"text_unit_ids": tu_ids,
132+
}
133+
for (src, tgt), tu_ids in edge_map.items()
134+
]
135+
edges_df = pd.DataFrame(
136+
records,
137+
columns=["source", "target", "weight", "text_unit_ids"],
124138
)
125139

126-
edge_df = edge_df[(edge_df.source.notna()) & (edge_df.target.notna())]
127-
edge_df = edge_df.drop(columns=["edges"])
128-
# group by source and target, count the number of text units
129-
grouped_edge_df = (
130-
edge_df.groupby(["source", "target"]).agg({"text_unit_id": list}).reset_index()
131-
)
132-
grouped_edge_df = grouped_edge_df.rename(columns={"text_unit_id": "text_unit_ids"})
133-
grouped_edge_df["weight"] = grouped_edge_df["text_unit_ids"].apply(len)
134-
grouped_edge_df = grouped_edge_df.loc[
135-
:, ["source", "target", "weight", "text_unit_ids"]
136-
]
137-
if normalize_edge_weights:
138-
# use PMI weight instead of raw weight
139-
grouped_edge_df = calculate_pmi_edge_weights(nodes_df, grouped_edge_df)
140+
if normalize_edge_weights and not edges_df.empty:
141+
edges_df = calculate_pmi_edge_weights(nodes_df, edges_df)
140142

141-
return grouped_edge_df
143+
return edges_df

packages/graphrag/graphrag/index/operations/embed_text/embed_text.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,12 +33,18 @@ async def embed_text(
3333
id_column: str = "id",
3434
output_table: Table | None = None,
3535
) -> int:
36-
"""Embed text from a streaming Table into a vector store."""
36+
"""Embed text from a streaming Table into a vector store.
37+
38+
Rows are buffered before flushing to ``run_embed_text``,
39+
which dispatches API batches concurrently up to
40+
``num_threads``. The buffer is sized so each flush produces
41+
enough batches to saturate the concurrency limit.
42+
"""
3743
vector_store.create_index()
3844

3945
buffer: list[dict[str, Any]] = []
4046
total_rows = 0
41-
flush_size = batch_size * 4
47+
flush_size = batch_size * num_threads
4248

4349
async for row in input_table:
4450
text = row.get(embed_column)

packages/graphrag/graphrag/index/operations/extract_graph/extract_graph.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
1212
from graphrag.config.enums import AsyncType
1313
from graphrag.index.operations.extract_graph.graph_extractor import GraphExtractor
14+
from graphrag.index.operations.extract_graph.utils import filter_orphan_relationships
1415
from graphrag.index.utils.derive_from_rows import derive_from_rows
1516

1617
if TYPE_CHECKING:
@@ -67,6 +68,7 @@ async def run_strategy(row):
6768

6869
entities = _merge_entities(entity_dfs)
6970
relationships = _merge_relationships(relationship_dfs)
71+
relationships = filter_orphan_relationships(relationships, entities)
7072

7173
return (entities, relationships)
7274

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
# Copyright (C) 2026 Microsoft Corporation.
2+
# Licensed under the MIT License
3+
4+
"""Utility functions for graph extraction operations."""
5+
6+
import logging
7+
8+
import pandas as pd
9+
10+
logger = logging.getLogger(__name__)
11+
12+
13+
def filter_orphan_relationships(
    relationships: pd.DataFrame,
    entities: pd.DataFrame,
) -> pd.DataFrame:
    """Remove relationships whose source or target has no entity entry.

    After LLM graph extraction, the model may hallucinate entity
    names in relationships that have no corresponding entity row.
    This function drops those dangling references so downstream
    processing never encounters broken graph edges.

    Parameters
    ----------
    relationships:
        Merged relationship DataFrame with at least ``source``
        and ``target`` columns.
    entities:
        Merged entity DataFrame with at least a ``title`` column.

    Returns
    -------
    pd.DataFrame
        Relationships filtered to only those whose ``source``
        and ``target`` both appear in ``entities["title"]``.
    """
    log = logging.getLogger(__name__)

    # No relationships: nothing to filter, nothing to warn about.
    if relationships.empty:
        return relationships.reset_index(drop=True)

    # No entities: every relationship is an orphan. Fix: this path
    # previously dropped everything *silently*, inconsistent with the
    # partial-drop warning below — now it warns too.
    if entities.empty:
        log.warning(
            "Dropped %d relationship(s) referencing non-existent entities.",
            len(relationships),
        )
        return relationships.iloc[0:0].reset_index(drop=True)

    # Set membership gives O(1) lookups for the vectorized isin masks.
    entity_titles = set(entities["title"])
    before_count = len(relationships)
    mask = relationships["source"].isin(entity_titles) & relationships[
        "target"
    ].isin(entity_titles)
    filtered = relationships[mask].reset_index(drop=True)
    dropped = before_count - len(filtered)
    if dropped > 0:
        log.warning(
            "Dropped %d relationship(s) referencing non-existent entities.",
            dropped,
        )
    return filtered

0 commit comments

Comments
 (0)