Commit 56e0fad

NLP graph parity (#1888)
* Update stopwords config
* Minor edits
* Update PMI
* Format
* Perf improvements
* Semver
* Remove edge collection apply
* Remove source/target apply
* Add edge weight to graph snapshot
* Revert breaking optimizations
* Add perf fixes back in
* Format/types
* Update defaults
* Fix source/target ordering
* Fix test
1 parent 25b605b commit 56e0fad

File tree

8 files changed: +45 -50 lines changed
Lines changed: 4 additions & 0 deletions

```diff
@@ -0,0 +1,4 @@
+{
+    "type": "patch",
+    "description": "Brings parity with our latest NLP extraction approaches."
+}
```

docs/config/yaml.md

Lines changed: 1 addition & 1 deletion

```diff
@@ -251,7 +251,7 @@ Parameters for manual graph pruning. This can be used to optimize the modularity
 - max_node_freq_std **float | None** - The maximum standard deviation of node frequency to allow.
 - min_node_degree **int** - The minimum node degree to allow.
 - max_node_degree_std **float | None** - The maximum standard deviation of node degree to allow.
-- min_edge_weight_pct **int** - The minimum edge weight percentile to allow.
+- min_edge_weight_pct **float** - The minimum edge weight percentile to allow.
 - remove_ego_nodes **bool** - Remove ego nodes.
 - lcc_only **bool** - Only use largest connected component.
 
```

graphrag/config/defaults.py

Lines changed: 6 additions & 3 deletions

```diff
@@ -19,6 +19,9 @@
     ReportingType,
     TextEmbeddingTarget,
 )
+from graphrag.index.operations.build_noun_graph.np_extractors.stop_words import (
+    EN_STOP_WORDS,
+)
 from graphrag.vector_stores.factory import VectorStoreType
 
 DEFAULT_OUTPUT_BASE_DIR = "output"
@@ -186,7 +189,7 @@ class TextAnalyzerDefaults:
     max_word_length: int = 15
     word_delimiter: str = " "
     include_named_entities: bool = True
-    exclude_nouns: None = None
+    exclude_nouns: list[str] = field(default_factory=lambda: EN_STOP_WORDS)
     exclude_entity_tags: list[str] = field(default_factory=lambda: ["DATE"])
     exclude_pos_tags: list[str] = field(
         default_factory=lambda: ["DET", "PRON", "INTJ", "X"]
@@ -317,8 +320,8 @@ class PruneGraphDefaults:
     max_node_freq_std: None = None
     min_node_degree: int = 1
     max_node_degree_std: None = None
-    min_edge_weight_pct: int = 40
-    remove_ego_nodes: bool = False
+    min_edge_weight_pct: float = 40.0
+    remove_ego_nodes: bool = True
     lcc_only: bool = False
 
 
```
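A side note on the `exclude_nouns` change: swapping the `None` sentinel for a real list means the field has to go through `default_factory`, since dataclasses reject mutable defaults. A generic illustration of the pattern (the stop-word list below is an invented subset, not graphrag's actual `EN_STOP_WORDS`):

```python
from dataclasses import dataclass, field

EN_STOP_WORDS = ["the", "a", "an", "and", "or"]  # illustrative subset only

@dataclass
class TextAnalyzerDefaults:
    # A bare `exclude_nouns: list[str] = EN_STOP_WORDS` would raise
    # ValueError (mutable default), so the value is deferred behind
    # default_factory instead.
    exclude_nouns: list[str] = field(default_factory=lambda: EN_STOP_WORDS)

print(len(TextAnalyzerDefaults().exclude_nouns))  # 5
```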
graphrag/index/operations/build_noun_graph/build_noun_graph.py

Lines changed: 22 additions & 43 deletions

```diff
@@ -3,8 +3,9 @@
 
 """Graph extraction using NLP."""
 
-import math
+from itertools import combinations
 
+import numpy as np
 import pandas as pd
 
 from graphrag.cache.noop_pipeline_cache import NoopPipelineCache
@@ -30,7 +31,6 @@ async def build_noun_graph(
         text_units, text_analyzer, num_threads=num_threads, cache=cache
     )
     edges_df = _extract_edges(nodes_df, normalize_edge_weights=normalize_edge_weights)
-
     return (nodes_df, edges_df)
 
 
@@ -69,7 +69,7 @@ async def extract(row):
     noun_node_df = text_unit_df.explode("noun_phrases")
     noun_node_df = noun_node_df.rename(
         columns={"noun_phrases": "title", "id": "text_unit_id"}
-    ).drop_duplicates()
+    )
 
     # group by title and count the number of text units
     grouped_node_df = (
@@ -94,64 +94,44 @@ def _extract_edges(
     """
     text_units_df = nodes_df.explode("text_unit_ids")
    text_units_df = text_units_df.rename(columns={"text_unit_ids": "text_unit_id"})
+
     text_units_df = (
-        text_units_df.groupby("text_unit_id").agg({"title": list}).reset_index()
+        text_units_df.groupby("text_unit_id")
+        .agg({"title": lambda x: list(x) if len(x) > 1 else np.nan})
+        .reset_index()
     )
-    text_units_df["edges"] = text_units_df["title"].apply(
-        lambda x: _create_relationships(x)
+    text_units_df = text_units_df.dropna()
+    titles = text_units_df["title"].tolist()
+    all_edges: list[list[tuple[str, str]]] = [list(combinations(t, 2)) for t in titles]
+
+    text_units_df = text_units_df.assign(edges=all_edges)  # type: ignore
+    edge_df = text_units_df.explode("edges")[["edges", "text_unit_id"]]
+
+    edge_df[["source", "target"]] = edge_df.loc[:, "edges"].to_list()
+    edge_df["min_source"] = edge_df[["source", "target"]].min(axis=1)
+    edge_df["max_target"] = edge_df[["source", "target"]].max(axis=1)
+    edge_df = edge_df.drop(columns=["source", "target"]).rename(
+        columns={"min_source": "source", "max_target": "target"}  # type: ignore
     )
-    edge_df = text_units_df.explode("edges").loc[:, ["edges", "text_unit_id"]]
 
-    edge_df["source"] = edge_df["edges"].apply(
-        lambda x: x[0] if isinstance(x, tuple) else None
-    )
-    edge_df["target"] = edge_df["edges"].apply(
-        lambda x: x[1] if isinstance(x, tuple) else None
-    )
     edge_df = edge_df[(edge_df.source.notna()) & (edge_df.target.notna())]
     edge_df = edge_df.drop(columns=["edges"])
-
-    # make sure source is always smaller than target
-    edge_df["source"], edge_df["target"] = zip(
-        *edge_df.apply(
-            lambda x: (x["source"], x["target"])
-            if x["source"] < x["target"]
-            else (x["target"], x["source"]),
-            axis=1,
-        ),
-        strict=False,
-    )
-
-    # group by source and target, count the number of text units and collect their ids
+    # group by source and target, count the number of text units
     grouped_edge_df = (
         edge_df.groupby(["source", "target"]).agg({"text_unit_id": list}).reset_index()
     )
     grouped_edge_df = grouped_edge_df.rename(columns={"text_unit_id": "text_unit_ids"})
     grouped_edge_df["weight"] = grouped_edge_df["text_unit_ids"].apply(len)
-
     grouped_edge_df = grouped_edge_df.loc[
         :, ["source", "target", "weight", "text_unit_ids"]
     ]
-
     if normalize_edge_weights:
         # use PMI weight instead of raw weight
         grouped_edge_df = _calculate_pmi_edge_weights(nodes_df, grouped_edge_df)
 
     return grouped_edge_df
 
 
-def _create_relationships(
-    noun_phrases: list[str],
-) -> list[tuple[str, str]]:
-    """Create a (source, target) tuple pairwise for all noun phrases in a list."""
-    relationships = []
-    if len(noun_phrases) >= 2:
-        for i in range(len(noun_phrases) - 1):
-            for j in range(i + 1, len(noun_phrases)):
-                relationships.extend([(noun_phrases[i], noun_phrases[j])])
-    return relationships
-
-
 def _calculate_pmi_edge_weights(
     nodes_df: pd.DataFrame,
     edges_df: pd.DataFrame,
@@ -192,8 +172,7 @@ def _calculate_pmi_edge_weights(
         .drop(columns=[node_name_col])
         .rename(columns={"prop_occurrence": "target_prop"})
     )
-    edges_df[edge_weight_col] = edges_df.apply(
-        lambda x: math.log2(x["prop_weight"] / (x["source_prop"] * x["target_prop"])),
-        axis=1,
+    edges_df[edge_weight_col] = edges_df["prop_weight"] * np.log2(
+        edges_df["prop_weight"] / (edges_df["source_prop"] * edges_df["target_prop"])
     )
     return edges_df.drop(columns=["prop_weight", "source_prop", "target_prop"])
```
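The shape of this refactor is easier to see on toy data. The sketch below is a minimal, self-contained rendering of the new edge-extraction path; the column names follow the diff, the data is invented. `itertools.combinations` replaces the removed nested-loop `_create_relationships` helper, and a vectorized min/max replaces the row-wise `apply` that canonically ordered each (source, target) pair:

```python
from itertools import combinations

import numpy as np
import pandas as pd

# Toy nodes table: each noun phrase with the text units it appears in.
nodes_df = pd.DataFrame({
    "title": ["apple", "banana", "cherry"],
    "text_unit_ids": [["t1", "t2"], ["t1"], ["t1", "t2"]],
})

# Invert to one row per text unit, keeping only units with 2+ phrases.
tu = nodes_df.explode("text_unit_ids").rename(columns={"text_unit_ids": "text_unit_id"})
tu = (
    tu.groupby("text_unit_id")
    .agg({"title": lambda x: list(x) if len(x) > 1 else np.nan})
    .reset_index()
    .dropna()
)

# combinations() yields each unordered pair exactly once -- no nested loops.
tu["edges"] = [list(combinations(t, 2)) for t in tu["title"]]
edge_df = tu.explode("edges")[["edges", "text_unit_id"]]
edge_df[["source", "target"]] = edge_df["edges"].to_list()

# Vectorized canonical ordering: lexicographic min/max per row replaces
# the old apply() that swapped endpoints whenever source > target.
edge_df["source"], edge_df["target"] = (
    edge_df[["source", "target"]].min(axis=1),
    edge_df[["source", "target"]].max(axis=1),
)

weights = edge_df.groupby(["source", "target"]).size().reset_index(name="weight")
print(weights)
#    source  target  weight
# 0   apple  banana       1
# 1   apple  cherry       2
# 2  banana  cherry       1
```

The PMI hunk follows the same pattern: the row-wise `math.log2` apply becomes one vectorized `np.log2` expression over whole columns. Note the new expression also multiplies by `prop_weight`, so with normalization enabled the edge weight is p(s,t) * log2(p(s,t) / (p(s) * p(t))) rather than raw PMI, presumably to damp the inflated scores that very rare co-occurrences otherwise receive (the "Update PMI" bullet in the commit message).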

graphrag/index/operations/graph_to_dataframes.py

Lines changed: 8 additions & 0 deletions

```diff
@@ -21,6 +21,14 @@ def graph_to_dataframes(
 
     edges = nx.to_pandas_edgelist(graph)
 
+    # we don't deal in directed graphs, but we do need to ensure consistent ordering for df joins
+    # nx loses the initial ordering
+    edges["min_source"] = edges[["source", "target"]].min(axis=1)
+    edges["max_target"] = edges[["source", "target"]].max(axis=1)
+    edges = edges.drop(columns=["source", "target"]).rename(
+        columns={"min_source": "source", "max_target": "target"}  # type: ignore
+    )
+
     if node_columns:
         nodes = nodes.loc[:, node_columns]
 
```
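A quick illustration of why this normalization matters: `networkx.Graph` stores undirected edges without a guaranteed endpoint order, so the (source, target) columns coming back from `to_pandas_edgelist` need not match the dataframe the graph was built from, and a later join on those columns would silently drop rows. A minimal sketch (toy graph; the normalization lines mirror the diff):

```python
import networkx as nx

g = nx.Graph()
g.add_edge("b", "a", weight=2.0)  # inserted with endpoints "reversed"

edges = nx.to_pandas_edgelist(g)  # endpoint order reflects nx internals, not any canon

# Same fix as the diff: order endpoints lexicographically so that
# (source, target) is deterministic and safe to join on.
edges["min_source"] = edges[["source", "target"]].min(axis=1)
edges["max_target"] = edges[["source", "target"]].max(axis=1)
edges = edges.drop(columns=["source", "target"]).rename(
    columns={"min_source": "source", "max_target": "target"}
)
print(edges)  # source='a', target='b', weight=2.0 -- stable regardless of insertion order
```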
graphrag/index/operations/prune_graph.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -21,7 +21,7 @@ def prune_graph(
     max_node_freq_std: float | None = None,
     min_node_degree: int = 1,
     max_node_degree_std: float | None = None,
-    min_edge_weight_pct: float = 0,
+    min_edge_weight_pct: float = 40,
     remove_ego_nodes: bool = False,
     lcc_only: bool = False,
 ) -> nx.Graph:
```
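For context, `min_edge_weight_pct` is a percentile over the observed edge weights, not an absolute weight cutoff, which is why one default (40) can work across corpora with very different weight scales. `prune_graph`'s body isn't shown in this diff; the helper below is a hypothetical sketch of how such a percentile cutoff can be applied, not graphrag's actual implementation:

```python
import networkx as nx
import numpy as np

def prune_low_weight_edges(graph: nx.Graph, min_edge_weight_pct: float) -> nx.Graph:
    """Hypothetical helper: drop edges below the given weight percentile."""
    weights = [w for _, _, w in graph.edges(data="weight", default=1.0)]
    cutoff = np.percentile(weights, min_edge_weight_pct)
    pruned = graph.copy()
    pruned.remove_edges_from([
        (u, v)
        for u, v, w in graph.edges(data="weight", default=1.0)
        if w < cutoff
    ])
    return pruned

# With the new default of 40, roughly the lightest 40% of edges are removed.
```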

graphrag/index/workflows/finalize_graph.py

Lines changed: 2 additions & 1 deletion

```diff
@@ -38,7 +38,8 @@ async def run_workflow(
 
     if config.snapshots.graphml:
         # todo: extract graphs at each level, and add in meta like descriptions
-        graph = create_graph(relationships)
+        graph = create_graph(final_relationships, edge_attr=["weight"])
+
         await snapshot_graphml(
             graph,
             name="graph",
```
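This one line carries two fixes: the snapshot is now built from `final_relationships` rather than the pre-finalization `relationships`, and `edge_attr=["weight"]` is what actually lands edge weights in the GraphML snapshot (the "Add edge weight to graph snapshot" bullet). Assuming `create_graph` is a thin wrapper over `networkx.from_pandas_edgelist` (its body isn't shown here), the effect of `edge_attr` looks like this:

```python
import networkx as nx
import pandas as pd

final_relationships = pd.DataFrame({
    "source": ["a", "b"],
    "target": ["b", "c"],
    "weight": [2.0, 1.0],
})

# Without edge_attr, only the endpoints survive, so the GraphML
# snapshot would carry no weight attribute on its edges.
graph = nx.from_pandas_edgelist(final_relationships, edge_attr=["weight"])
print(list(graph.edges(data=True)))
# [('a', 'b', {'weight': 2.0}), ('b', 'c', {'weight': 1.0})]
```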

tests/verbs/test_prune_graph.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -28,4 +28,4 @@ async def test_prune_graph():
 
     nodes_actual = await load_table_from_storage("entities", context.storage)
 
-    assert len(nodes_actual) == 21
+    assert len(nodes_actual) == 20
```
