Skip to content

Commit ae094bb

Browse files
Collapse create final relationships (#1158)
* Collapse pre/post embedding workflows * Semver * Fix smoke tests --------- Co-authored-by: Alonso Guevara <alonsog@microsoft.com>
1 parent bd2c1da commit ae094bb

File tree

9 files changed

+204
-61
lines changed

9 files changed

+204
-61
lines changed
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
{
2+
"type": "patch",
3+
"description": "Collapse create_final_relationships."
4+
}

graphrag/index/verbs/graph/compute_edge_combined_degree.py

Lines changed: 32 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -32,10 +32,34 @@ def compute_edge_combined_degree(
3232
- to: The name of the column to output the combined degree to. Default="rank"
3333
"""
3434
edge_df: pd.DataFrame = cast(pd.DataFrame, input.get_input())
35-
if to in edge_df.columns:
36-
return TableContainer(table=edge_df)
3735
node_degree_df = _get_node_degree_table(input, node_name_column, node_degree_column)
3836

37+
output_df = compute_edge_combined_degree_df(
38+
edge_df,
39+
node_degree_df,
40+
to,
41+
node_name_column,
42+
node_degree_column,
43+
edge_source_column,
44+
edge_target_column,
45+
)
46+
47+
return TableContainer(table=output_df)
48+
49+
50+
def compute_edge_combined_degree_df(
51+
edge_df: pd.DataFrame,
52+
node_degree_df: pd.DataFrame,
53+
to: str,
54+
node_name_column: str,
55+
node_degree_column: str,
56+
edge_source_column: str,
57+
edge_target_column: str,
58+
) -> pd.DataFrame:
59+
"""Compute the combined degree for each edge in a graph."""
60+
if to in edge_df.columns:
61+
return edge_df
62+
3963
def join_to_degree(df: pd.DataFrame, column: str) -> pd.DataFrame:
4064
degree_column = _degree_colname(column)
4165
result = df.merge(
@@ -48,14 +72,13 @@ def join_to_degree(df: pd.DataFrame, column: str) -> pd.DataFrame:
4872
result[degree_column] = result[degree_column].fillna(0)
4973
return result
5074

51-
edge_df = join_to_degree(edge_df, edge_source_column)
52-
edge_df = join_to_degree(edge_df, edge_target_column)
53-
edge_df[to] = (
54-
edge_df[_degree_colname(edge_source_column)]
55-
+ edge_df[_degree_colname(edge_target_column)]
75+
output_df = join_to_degree(edge_df, edge_source_column)
76+
output_df = join_to_degree(output_df, edge_target_column)
77+
output_df[to] = (
78+
output_df[_degree_colname(edge_source_column)]
79+
+ output_df[_degree_colname(edge_target_column)]
5680
)
57-
58-
return TableContainer(table=edge_df)
81+
return output_df
5982

6083

6184
def _degree_colname(column: str) -> str:

graphrag/index/workflows/v1/create_final_relationships.py

Lines changed: 9 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ def build_steps(
1616
1717
## Dependencies
1818
* `workflow:create_base_entity_graph`
19+
* `workflow:create_final_nodes`
1920
"""
2021
base_text_embed = config.get("text_embed", {})
2122
relationship_description_embed_config = config.get(
@@ -25,25 +26,12 @@ def build_steps(
2526

2627
return [
2728
{
28-
"verb": "unpack_graph",
29-
"args": {
30-
"column": "clustered_graph",
31-
"type": "edges",
32-
},
29+
"id": "pre_embedding",
30+
"verb": "create_final_relationships_pre_embedding",
3331
"input": {"source": "workflow:create_base_entity_graph"},
3432
},
3533
{
36-
"verb": "rename",
37-
"args": {"columns": {"source_id": "text_unit_ids"}},
38-
},
39-
{
40-
"verb": "filter",
41-
"args": {
42-
"column": "level",
43-
"criteria": [{"type": "value", "operator": "equals", "value": 0}],
44-
},
45-
},
46-
{
34+
"id": "description_embedding",
4735
"verb": "text_embed",
4836
"enabled": not skip_description_embedding,
4937
"args": {
@@ -54,41 +42,12 @@ def build_steps(
5442
},
5543
},
5644
{
57-
"id": "pruned_edges",
58-
"verb": "drop",
59-
"args": {"columns": ["level"]},
60-
},
61-
{
62-
"id": "filtered_nodes",
63-
"verb": "filter",
64-
"args": {
65-
"column": "level",
66-
"criteria": [{"type": "value", "operator": "equals", "value": 0}],
67-
},
68-
"input": "workflow:create_final_nodes",
69-
},
70-
{
71-
"verb": "compute_edge_combined_degree",
72-
"args": {"to": "rank"},
45+
"verb": "create_final_relationships_post_embedding",
7346
"input": {
74-
"source": "pruned_edges",
75-
"nodes": "filtered_nodes",
76-
},
77-
},
78-
{
79-
"verb": "convert",
80-
"args": {
81-
"column": "human_readable_id",
82-
"type": "string",
83-
"to": "human_readable_id",
84-
},
85-
},
86-
{
87-
"verb": "convert",
88-
"args": {
89-
"column": "text_unit_ids",
90-
"type": "array",
91-
"to": "text_unit_ids",
47+
"source": "pre_embedding"
48+
if skip_description_embedding
49+
else "description_embedding",
50+
"nodes": "workflow:create_final_nodes",
9251
},
9352
},
9453
]

graphrag/index/workflows/v1/subflows/__init__.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,17 @@
44
"""The Indexing Engine workflows -> subflows package root."""
55

66
from .create_final_communities import create_final_communities
7+
from .create_final_relationships_post_embedding import (
8+
create_final_relationships_post_embedding,
9+
)
10+
from .create_final_relationships_pre_embedding import (
11+
create_final_relationships_pre_embedding,
12+
)
713
from .create_final_text_units_pre_embedding import create_final_text_units_pre_embedding
814

915
__all__ = [
1016
"create_final_communities",
17+
"create_final_relationships_post_embedding",
18+
"create_final_relationships_pre_embedding",
1119
"create_final_text_units_pre_embedding",
1220
]
Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
# Copyright (c) 2024 Microsoft Corporation.
2+
# Licensed under the MIT License
3+
4+
"""All the steps to transform final relationships after they are embedded."""
5+
6+
from typing import Any, cast
7+
8+
import pandas as pd
9+
from datashaper import (
10+
Table,
11+
VerbInput,
12+
verb,
13+
)
14+
from datashaper.table_store.types import VerbResult, create_verb_result
15+
16+
from graphrag.index.utils.ds_util import get_required_input_table
17+
from graphrag.index.verbs.graph.compute_edge_combined_degree import (
18+
compute_edge_combined_degree_df,
19+
)
20+
21+
22+
@verb(
    name="create_final_relationships_post_embedding",
    treats_input_tables_as_immutable=True,
)
def create_final_relationships_post_embedding(
    input: VerbInput,
    **_kwargs: dict,
) -> VerbResult:
    """Finish the relationships table once the embedding stage is done.

    Reads the default input (the edge table) and the required ``nodes`` input,
    then:

    * removes the ``level`` column from the edges,
    * restricts the nodes to ``level == 0`` and keeps only ``title``/``degree``,
    * asks ``compute_edge_combined_degree_df`` to write the combined degree of
      each edge's ``source`` and ``target`` into a ``rank`` column,
    * stores ``human_readable_id`` as strings and ``text_unit_ids`` as lists
      (comma-delimited strings are split).
    """
    edges = cast(pd.DataFrame, input.get_input())
    node_table = cast(pd.DataFrame, get_required_input_table(input, "nodes").table)

    # drop() returns a new frame, so the (immutable) input table is untouched.
    edges_without_level = edges.drop(columns=["level"])

    # Only level-0 nodes take part in the degree lookup, and only the name and
    # degree columns are needed for it.
    is_level_zero = node_table["level"] == 0
    level_zero_nodes = cast(
        pd.DataFrame,
        node_table[is_level_zero].reset_index(drop=True)[["title", "degree"]],
    )

    ranked_edges = compute_edge_combined_degree_df(
        edges_without_level,
        level_zero_nodes,
        to="rank",
        node_name_column="title",
        node_degree_column="degree",
        edge_source_column="source",
        edge_target_column="target",
    )

    # Final column typing.
    readable_ids = ranked_edges["human_readable_id"]
    ranked_edges["human_readable_id"] = readable_ids.astype(str)
    text_units = ranked_edges["text_unit_ids"]
    ranked_edges["text_unit_ids"] = _to_array(text_units, ",")

    return create_verb_result(cast(Table, ranked_edges))
59+
60+
61+
# from datashaper, we should be able to inline this
62+
def _to_array(column, delimiter: str):
63+
def convert_value(value: Any) -> list:
64+
if pd.isna(value):
65+
return []
66+
if isinstance(value, list):
67+
return value
68+
if isinstance(value, str):
69+
return value.split(delimiter)
70+
return [value]
71+
72+
return column.apply(convert_value)
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
# Copyright (c) 2024 Microsoft Corporation.
2+
# Licensed under the MIT License
3+
4+
"""All the steps to transform final relationships before they are embedded."""
5+
6+
from typing import cast
7+
8+
import pandas as pd
9+
from datashaper import (
10+
Table,
11+
VerbCallbacks,
12+
VerbInput,
13+
verb,
14+
)
15+
from datashaper.table_store.types import VerbResult, create_verb_result
16+
17+
from graphrag.index.verbs.graph.unpack import unpack_graph_df
18+
19+
20+
@verb(
    name="create_final_relationships_pre_embedding",
    treats_input_tables_as_immutable=True,
)
def create_final_relationships_pre_embedding(
    input: VerbInput,
    callbacks: VerbCallbacks,
    **_kwargs: dict,
) -> VerbResult:
    """Prepare the relationship rows that are handed to the embedding stage.

    Unpacks the edges stored in the ``clustered_graph`` column of the input
    table, exposes the ``source_id`` column under the name ``text_unit_ids``,
    and keeps only the rows whose ``level`` equals 0 (re-indexed from zero).
    """
    source = cast(pd.DataFrame, input.get_input())

    edges = unpack_graph_df(source, callbacks, "clustered_graph", "edges")
    edges = edges.rename(columns={"source_id": "text_unit_ids"})

    level_zero_mask = edges["level"] == 0
    level_zero_edges = edges[level_zero_mask].reset_index(drop=True)

    return create_verb_result(cast(Table, level_zero_edges))

tests/fixtures/min-csv/config.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@
5252
1,
5353
2000
5454
],
55-
"subworkflows": 8,
55+
"subworkflows": 2,
5656
"max_runtime": 100
5757
},
5858
"create_final_nodes": {

tests/fixtures/text/config.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@
7171
1,
7272
2000
7373
],
74-
"subworkflows": 8,
74+
"subworkflows": 2,
7575
"max_runtime": 100
7676
},
7777
"create_final_nodes": {
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
# Copyright (c) 2024 Microsoft Corporation.
2+
# Licensed under the MIT License
3+
4+
from graphrag.index.workflows.v1.create_final_relationships import (
5+
build_steps,
6+
workflow_name,
7+
)
8+
9+
from .util import (
10+
compare_outputs,
11+
get_config_for_workflow,
12+
get_workflow_output,
13+
load_expected,
14+
load_input_tables,
15+
remove_disabled_steps,
16+
)
17+
18+
19+
async def test_create_final_relationships():
    """Run the create_final_relationships steps and compare with the stored output.

    The workflow is built with ``skip_description_embedding`` enabled and any
    disabled steps removed before execution.
    """
    dependencies = [
        "workflow:create_base_entity_graph",
        "workflow:create_final_nodes",
    ]
    input_tables = load_input_tables(dependencies)
    expected = load_expected(workflow_name)

    config = get_config_for_workflow(workflow_name)
    # NOTE(review): presumably set so the run needs no embedding backend - confirm.
    config["skip_description_embedding"] = True

    workflow = {"steps": remove_disabled_steps(build_steps(config))}
    actual = await get_workflow_output(input_tables, workflow)

    compare_outputs(actual, expected)

0 commit comments

Comments
 (0)