Skip to content

Commit facf681

Browse files
Fix summarization and relationship grouping on Inc Indexing (#1768)
* Fix summarization for large descriptions on incremental indexing * Semver * Ruff
1 parent ede6a74 commit facf681

File tree

6 files changed

+118
-108
lines changed

6 files changed

+118
-108
lines changed
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
{
2+
"type": "patch",
3+
"description": "Fix summarization over large datasets for inc indexing. Fix relationship summarization"
4+
}

graphrag/index/update/entities.py

Lines changed: 0 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -3,19 +3,12 @@
33

44
"""Entity related operations and utils for Incremental Indexing."""
55

6-
import asyncio
76
import itertools
87

98
import numpy as np
109
import pandas as pd
1110

12-
from graphrag.cache.pipeline_cache import PipelineCache
13-
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
14-
from graphrag.config.models.graph_rag_config import GraphRagConfig
1511
from graphrag.data_model.schemas import ENTITIES_FINAL_COLUMNS
16-
from graphrag.index.operations.summarize_descriptions.graph_intelligence_strategy import (
17-
run_graph_intelligence as run_entity_summarization,
18-
)
1912

2013

2114
def _group_and_resolve_entities(
@@ -83,61 +76,3 @@ def _group_and_resolve_entities(
8376
resolved = resolved.loc[:, ENTITIES_FINAL_COLUMNS]
8477

8578
return resolved, id_mapping
86-
87-
88-
async def _run_entity_summarization(
89-
entities_df: pd.DataFrame,
90-
config: GraphRagConfig,
91-
cache: PipelineCache,
92-
callbacks: WorkflowCallbacks,
93-
) -> pd.DataFrame:
94-
"""Run entity summarization.
95-
96-
Parameters
97-
----------
98-
entities_df : pd.DataFrame
99-
The entities dataframe.
100-
config : GraphRagConfig
101-
The pipeline configuration.
102-
cache : PipelineCache
103-
Pipeline cache used during the summarization process.
104-
105-
Returns
106-
-------
107-
pd.DataFrame
108-
The updated entities dataframe with summarized descriptions.
109-
"""
110-
summarization_llm_settings = config.get_language_model_config(
111-
config.summarize_descriptions.model_id
112-
)
113-
summarization_strategy = config.summarize_descriptions.resolved_strategy(
114-
config.root_dir, summarization_llm_settings
115-
)
116-
117-
# Prepare tasks for async summarization where needed
118-
async def process_row(row):
119-
# Accessing attributes directly from the named tuple.
120-
description = row.description
121-
if isinstance(description, list) and len(description) > 1:
122-
# Run entity summarization asynchronously
123-
result = await run_entity_summarization(
124-
row.title,
125-
description,
126-
callbacks,
127-
cache,
128-
summarization_strategy,
129-
)
130-
return result.description
131-
# Handle case where description is a single-item list or not a list
132-
return description[0] if isinstance(description, list) else description
133-
134-
# Create a list of async tasks for summarization
135-
tasks = [
136-
process_row(row) for row in entities_df.itertuples(index=False, name="Entity")
137-
]
138-
results = await asyncio.gather(*tasks)
139-
140-
# Update the 'description' column in the DataFrame
141-
entities_df["description"] = results
142-
143-
return entities_df

graphrag/index/update/incremental_index.py

Lines changed: 69 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,9 @@
1818
)
1919
from graphrag.index.update.entities import (
2020
_group_and_resolve_entities,
21-
_run_entity_summarization,
2221
)
2322
from graphrag.index.update.relationships import _update_and_merge_relationships
23+
from graphrag.index.workflows.extract_graph import get_summarized_entities_relationships
2424
from graphrag.index.workflows.generate_text_embeddings import generate_text_embeddings
2525
from graphrag.logger.print_progress import ProgressLogger
2626
from graphrag.storage.pipeline_storage import PipelineStorage
@@ -104,18 +104,16 @@ async def update_dataframe_outputs(
104104
"documents", previous_storage, delta_storage, output_storage
105105
)
106106

107-
# Update entities and merge them
108-
progress_logger.info("Updating Entities")
109-
merged_entities_df, entity_id_mapping = await _update_entities(
107+
# Update entities, relationships and merge them
108+
progress_logger.info("Updating Entities and Relationships")
109+
(
110+
merged_entities_df,
111+
merged_relationships_df,
112+
entity_id_mapping,
113+
) = await _update_entities_and_relationships(
110114
previous_storage, delta_storage, output_storage, config, cache, callbacks
111115
)
112116

113-
# Update relationships with the entities id mapping
114-
progress_logger.info("Updating Relationships")
115-
merged_relationships_df = await _update_relationships(
116-
previous_storage, delta_storage, output_storage
117-
)
118-
119117
# Update and merge final text units
120118
progress_logger.info("Updating Text Units")
121119
merged_text_units = await _update_text_units(
@@ -166,8 +164,11 @@ async def update_dataframe_outputs(
166164

167165

168166
async def _update_community_reports(
169-
previous_storage, delta_storage, output_storage, community_id_mapping
170-
):
167+
previous_storage: PipelineStorage,
168+
delta_storage: PipelineStorage,
169+
output_storage: PipelineStorage,
170+
community_id_mapping: dict,
171+
) -> pd.DataFrame:
171172
"""Update the community reports output."""
172173
old_community_reports = await load_table_from_storage(
173174
"community_reports", previous_storage
@@ -186,7 +187,11 @@ async def _update_community_reports(
186187
return merged_community_reports
187188

188189

189-
async def _update_communities(previous_storage, delta_storage, output_storage):
190+
async def _update_communities(
191+
previous_storage: PipelineStorage,
192+
delta_storage: PipelineStorage,
193+
output_storage: PipelineStorage,
194+
) -> dict:
190195
"""Update the communities output."""
191196
old_communities = await load_table_from_storage("communities", previous_storage)
192197
delta_communities = await load_table_from_storage("communities", delta_storage)
@@ -199,7 +204,11 @@ async def _update_communities(previous_storage, delta_storage, output_storage):
199204
return community_id_mapping
200205

201206

202-
async def _update_covariates(previous_storage, delta_storage, output_storage):
207+
async def _update_covariates(
208+
previous_storage: PipelineStorage,
209+
delta_storage: PipelineStorage,
210+
output_storage: PipelineStorage,
211+
) -> None:
203212
"""Update the covariates output."""
204213
old_covariates = await load_table_from_storage("covariates", previous_storage)
205214
delta_covariates = await load_table_from_storage("covariates", delta_storage)
@@ -209,8 +218,11 @@ async def _update_covariates(previous_storage, delta_storage, output_storage):
209218

210219

211220
async def _update_text_units(
212-
previous_storage, delta_storage, output_storage, entity_id_mapping
213-
):
221+
previous_storage: PipelineStorage,
222+
delta_storage: PipelineStorage,
223+
output_storage: PipelineStorage,
224+
entity_id_mapping: dict,
225+
) -> pd.DataFrame:
214226
"""Update the text units output."""
215227
old_text_units = await load_table_from_storage("text_units", previous_storage)
216228
delta_text_units = await load_table_from_storage("text_units", delta_storage)
@@ -223,48 +235,65 @@ async def _update_text_units(
223235
return merged_text_units
224236

225237

226-
async def _update_relationships(previous_storage, delta_storage, output_storage):
227-
"""Update the relationships output."""
238+
async def _update_entities_and_relationships(
239+
previous_storage: PipelineStorage,
240+
delta_storage: PipelineStorage,
241+
output_storage: PipelineStorage,
242+
config: GraphRagConfig,
243+
cache: PipelineCache,
244+
callbacks: WorkflowCallbacks,
245+
) -> tuple[pd.DataFrame, pd.DataFrame, dict]:
246+
"""Update Final Entities and Relationships output."""
247+
old_entities = await load_table_from_storage("entities", previous_storage)
248+
delta_entities = await load_table_from_storage("entities", delta_storage)
249+
250+
merged_entities_df, entity_id_mapping = _group_and_resolve_entities(
251+
old_entities, delta_entities
252+
)
253+
254+
# Update Relationships
228255
old_relationships = await load_table_from_storage("relationships", previous_storage)
229256
delta_relationships = await load_table_from_storage("relationships", delta_storage)
230257
merged_relationships_df = _update_and_merge_relationships(
231258
old_relationships,
232259
delta_relationships,
233260
)
234261

235-
await write_table_to_storage(
236-
merged_relationships_df, "relationships", output_storage
262+
summarization_llm_settings = config.get_language_model_config(
263+
config.summarize_descriptions.model_id
237264
)
238-
239-
return merged_relationships_df
240-
241-
242-
async def _update_entities(
243-
previous_storage, delta_storage, output_storage, config, cache, callbacks
244-
):
245-
"""Update Final Entities output."""
246-
old_entities = await load_table_from_storage("entities", previous_storage)
247-
delta_entities = await load_table_from_storage("entities", delta_storage)
248-
249-
merged_entities_df, entity_id_mapping = _group_and_resolve_entities(
250-
old_entities, delta_entities
265+
summarization_strategy = config.summarize_descriptions.resolved_strategy(
266+
config.root_dir, summarization_llm_settings
251267
)
252268

253-
# Re-run description summarization
254-
merged_entities_df = await _run_entity_summarization(
269+
(
255270
merged_entities_df,
256-
config,
257-
cache,
258-
callbacks,
271+
merged_relationships_df,
272+
) = await get_summarized_entities_relationships(
273+
extracted_entities=merged_entities_df,
274+
extracted_relationships=merged_relationships_df,
275+
callbacks=callbacks,
276+
cache=cache,
277+
summarization_strategy=summarization_strategy,
278+
summarization_num_threads=summarization_llm_settings.concurrent_requests,
259279
)
260280

261281
# Save the updated entities back to storage
262282
await write_table_to_storage(merged_entities_df, "entities", output_storage)
263283

264-
return merged_entities_df, entity_id_mapping
284+
await write_table_to_storage(
285+
merged_relationships_df, "relationships", output_storage
286+
)
287+
288+
return merged_entities_df, merged_relationships_df, entity_id_mapping
265289

266290

267-
async def _concat_dataframes(name, previous_storage, delta_storage, output_storage):
291+
async def _concat_dataframes(
292+
name: str,
293+
previous_storage: PipelineStorage,
294+
delta_storage: PipelineStorage,
295+
output_storage: PipelineStorage,
296+
) -> pd.DataFrame:
268297
"""Concatenate dataframes."""
269298
old_df = await load_table_from_storage(name, previous_storage)
270299
delta_df = await load_table_from_storage(name, delta_storage)

graphrag/index/update/relationships.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33

44
"""Relationship related operations and utils for Incremental Indexing."""
55

6+
import itertools
7+
68
import numpy as np
79
import pandas as pd
810

@@ -42,10 +44,28 @@ def _update_and_merge_relationships(
4244
)
4345

4446
# Merge the DataFrames without copying if possible
45-
final_relationships = pd.concat(
47+
merged_relationships = pd.concat(
4648
[old_relationships, delta_relationships], ignore_index=True, copy=False
4749
)
4850

51+
# Group by title and resolve conflicts
52+
aggregated = (
53+
merged_relationships.groupby(["source", "target"])
54+
.agg({
55+
"id": "first",
56+
"human_readable_id": "first",
57+
"description": lambda x: list(x.astype(str)), # Ensure str
58+
# Concatenate nd.array into a single list
59+
"text_unit_ids": lambda x: list(itertools.chain(*x.tolist())),
60+
"weight": "mean",
61+
"combined_degree": "sum",
62+
})
63+
.reset_index()
64+
)
65+
66+
# Force the result into a DataFrame
67+
final_relationships: pd.DataFrame = pd.DataFrame(aggregated)
68+
4969
# Recalculate target and source degrees
5070
final_relationships["source_degree"] = final_relationships.groupby("source")[
5171
"target"

graphrag/index/validate_config.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,8 @@ def validate_config_names(logger: ProgressLogger, parameters: GraphRagConfig) ->
4040
embedding_llm_settings = parameters.get_language_model_config(
4141
parameters.embed_text.model_id
4242
)
43+
if embedding_llm_settings.max_retries == -1:
44+
embedding_llm_settings.max_retries = language_model_defaults.max_retries
4345
embed_llm = ModelManager().register_embedding(
4446
name="test-embed-llm",
4547
model_type=embedding_llm_settings.type,

graphrag/index/workflows/extract_graph.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,27 @@ async def extract_graph(
105105
callbacks.error(error_msg)
106106
raise ValueError(error_msg)
107107

108+
entities, relationships = await get_summarized_entities_relationships(
109+
extracted_entities=extracted_entities,
110+
extracted_relationships=extracted_relationships,
111+
callbacks=callbacks,
112+
cache=cache,
113+
summarization_strategy=summarization_strategy,
114+
summarization_num_threads=summarization_num_threads,
115+
)
116+
117+
return (entities, relationships)
118+
119+
120+
async def get_summarized_entities_relationships(
121+
extracted_entities: pd.DataFrame,
122+
extracted_relationships: pd.DataFrame,
123+
callbacks: WorkflowCallbacks,
124+
cache: PipelineCache,
125+
summarization_strategy: dict[str, Any] | None = None,
126+
summarization_num_threads: int = 4,
127+
) -> tuple[pd.DataFrame, pd.DataFrame]:
128+
"""Summarize the entities and relationships."""
108129
entity_summaries, relationship_summaries = await summarize_descriptions(
109130
entities_df=extracted_entities,
110131
relationships_df=extracted_relationships,
@@ -120,8 +141,7 @@ async def extract_graph(
120141

121142
extracted_entities.drop(columns=["description"], inplace=True)
122143
entities = extracted_entities.merge(entity_summaries, on="title", how="left")
123-
124-
return (entities, relationships)
144+
return entities, relationships
125145

126146

127147
def _validate_data(df: pd.DataFrame) -> bool:

0 commit comments

Comments
 (0)