
Commit 8c81cc1

Update Index as workflows (#1908)
* Incremental index as workflow
* Update function docs
* Fix state management
* Remove update workflows when specifying workflows in the config
* Fix ruff errors
* Add semver
* Remove callbacks param
1 parent 832abf1 commit 8c81cc1

15 files changed: +611 -329 lines changed
Lines changed: 4 additions & 0 deletions
@@ -0,0 +1,4 @@
+{
+    "type": "patch",
+    "description": "Update as workflows"
+}

graphrag/api/index.py

Lines changed: 1 addition & 1 deletion
@@ -65,7 +65,7 @@ async def build_index(
     if memory_profile:
         log.warning("New pipeline does not yet support memory profiling.")
 
-    pipeline = PipelineFactory.create_pipeline(config, method)
+    pipeline = PipelineFactory.create_pipeline(config, method, is_update_run)
 
     workflow_callbacks.pipeline_start(pipeline.names())
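The caller-facing change in the API is that `PipelineFactory.create_pipeline` now also receives `is_update_run`, so the factory decides which workflows run rather than the caller wiring in update steps (the commit also removes update workflows from the user-specified workflow config). A minimal sketch of that pattern, with hypothetical workflow names and a simplified registry that is not graphrag's actual factory:

```python
# Sketch only: a registry-style factory that appends update steps for incremental runs.
# The names "update_docs" / "update_graph" and the list-of-pairs return type are
# illustrative stand-ins, not graphrag's real PipelineFactory internals.
from collections.abc import Awaitable, Callable

WorkflowFn = Callable[..., Awaitable[object]]


class SketchPipelineFactory:
    workflows: dict[str, WorkflowFn] = {}

    @classmethod
    def register(cls, name: str, fn: WorkflowFn) -> None:
        cls.workflows[name] = fn

    @classmethod
    def create_pipeline(
        cls, workflow_names: list[str], is_update_run: bool
    ) -> list[tuple[str, WorkflowFn]]:
        names = list(workflow_names)
        if is_update_run:
            # incremental-update steps are appended automatically instead of being
            # listed by the user in the config
            names += ["update_docs", "update_graph"]
        return [(name, cls.workflows[name]) for name in names]
```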

graphrag/index/run/run_pipeline.py

Lines changed: 23 additions & 37 deletions
@@ -13,19 +13,14 @@
 
 import pandas as pd
 
-from graphrag.cache.pipeline_cache import PipelineCache
-from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks
 from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
 from graphrag.config.models.graph_rag_config import GraphRagConfig
 from graphrag.index.input.factory import create_input
 from graphrag.index.run.utils import create_run_context
 from graphrag.index.typing.context import PipelineRunContext
 from graphrag.index.typing.pipeline import Pipeline
 from graphrag.index.typing.pipeline_run_result import PipelineRunResult
-from graphrag.index.update.incremental_index import (
-    get_delta_docs,
-    update_dataframe_outputs,
-)
+from graphrag.index.update.incremental_index import get_delta_docs
 from graphrag.logger.base import ProgressLogger
 from graphrag.logger.progress import Progress
 from graphrag.storage.pipeline_storage import PipelineStorage
@@ -50,6 +45,10 @@ async def run_pipeline(
 
     dataset = await create_input(config.input, logger, root_dir)
 
+    # load existing state in case any workflows are stateful
+    state_json = await storage.get("context.json")
+    state = json.loads(state_json) if state_json else {}
+
     if is_update_run:
         logger.info("Running incremental indexing.")
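State handling moves up a level: `run_pipeline` loads any existing `context.json` once and threads the resulting dict through `create_run_context` for both the standard and incremental paths. The load is a plain JSON round-trip against pipeline storage; a small sketch of that pattern in isolation (the `save_state` counterpart below is illustrative and assumes `PipelineStorage.set`, it is not part of this diff):

```python
import json

from graphrag.storage.pipeline_storage import PipelineStorage


async def load_state(storage: PipelineStorage) -> dict:
    # mirrors the diff: read context.json if present, otherwise start with empty state
    state_json = await storage.get("context.json")
    return json.loads(state_json) if state_json else {}


async def save_state(storage: PipelineStorage, state: dict) -> None:
    # hypothetical counterpart: persist mutated state so the next run can pick it up
    await storage.set("context.json", json.dumps(state, indent=4))
```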

@@ -62,48 +61,45 @@
         else:
             update_storage = create_storage_from_config(config.update_index_output)
             # we use this to store the new subset index, and will merge its content with the previous index
-            timestamped_storage = update_storage.child(time.strftime("%Y%m%d-%H%M%S"))
+            update_timestamp = time.strftime("%Y%m%d-%H%M%S")
+            timestamped_storage = update_storage.child(update_timestamp)
             delta_storage = timestamped_storage.child("delta")
             # copy the previous output to a backup folder, so we can replace it with the update
             # we'll read from this later when we merge the old and new indexes
             previous_storage = timestamped_storage.child("previous")
             await _copy_previous_output(storage, previous_storage)
 
+            state["update_timestamp"] = update_timestamp
+
+            context = create_run_context(
+                storage=delta_storage, cache=cache, callbacks=callbacks, state=state
+            )
+
             # Run the pipeline on the new documents
             async for table in _run_pipeline(
                 pipeline=pipeline,
                 config=config,
                 dataset=delta_dataset.new_inputs,
-                cache=cache,
-                storage=delta_storage,
-                callbacks=callbacks,
                 logger=logger,
+                context=context,
             ):
                 yield table
 
             logger.success("Finished running workflows on new documents.")
 
-            await update_dataframe_outputs(
-                previous_storage=previous_storage,
-                delta_storage=delta_storage,
-                output_storage=storage,
-                config=config,
-                cache=cache,
-                callbacks=NoopWorkflowCallbacks(),
-                progress_logger=logger,
-            )
-
     else:
         logger.info("Running standard indexing.")
 
+        context = create_run_context(
+            storage=storage, cache=cache, callbacks=callbacks, state=state
+        )
+
         async for table in _run_pipeline(
             pipeline=pipeline,
             config=config,
             dataset=dataset,
-            cache=cache,
-            storage=storage,
-            callbacks=callbacks,
             logger=logger,
+            context=context,
         ):
             yield table
 
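The incremental branch no longer calls `update_dataframe_outputs` after the run; it only builds the delta index into a timestamped sub-storage and records the timestamp in state, leaving the merge to the pipeline's own update workflows. A sketch of the layout it sets up, assuming file-backed storage where `child()` maps to a subdirectory (the helper name is illustrative):

```python
import time

from graphrag.storage.pipeline_storage import PipelineStorage


def layout_update_run(
    update_storage: PipelineStorage, state: dict
) -> tuple[PipelineStorage, PipelineStorage]:
    """Sketch: the folders one incremental run creates under update_index_output."""
    update_timestamp = time.strftime("%Y%m%d-%H%M%S")
    timestamped = update_storage.child(update_timestamp)  # <update_index_output>/<timestamp>/
    delta_storage = timestamped.child("delta")            # index built from the new documents only
    previous_storage = timestamped.child("previous")      # copy of the prior output, merged later
    # the timestamp is kept in pipeline state so later update workflows can find these folders
    state["update_timestamp"] = update_timestamp
    return delta_storage, previous_storage
```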

@@ -112,21 +108,11 @@ async def _run_pipeline(
     pipeline: Pipeline,
     config: GraphRagConfig,
     dataset: pd.DataFrame,
-    cache: PipelineCache,
-    storage: PipelineStorage,
-    callbacks: WorkflowCallbacks,
     logger: ProgressLogger,
+    context: PipelineRunContext,
 ) -> AsyncIterable[PipelineRunResult]:
     start_time = time.time()
 
-    # load existing state in case any workflows are stateful
-    state_json = await storage.get("context.json")
-    state = json.loads(state_json) if state_json else {}
-
-    context = create_run_context(
-        storage=storage, cache=cache, callbacks=callbacks, state=state
-    )
-
     log.info("Final # of rows loaded: %s", len(dataset))
     context.stats.num_documents = len(dataset)
     last_workflow = "starting documents"
@@ -138,11 +124,11 @@
         for name, workflow_function in pipeline.run():
             last_workflow = name
             progress = logger.child(name, transient=False)
-            callbacks.workflow_start(name, None)
+            context.callbacks.workflow_start(name, None)
             work_time = time.time()
             result = await workflow_function(config, context)
             progress(Progress(percent=1))
-            callbacks.workflow_end(name, result)
+            context.callbacks.workflow_end(name, result)
             yield PipelineRunResult(
                 workflow=name, result=result.result, state=context.state, errors=None
             )
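With callbacks reached through `context.callbacks` and each step invoked as `workflow_function(config, context)`, a workflow gets storage, cache, callbacks, and state from a single object. A hedged sketch of a function matching that calling convention (the `Output` dataclass is a stand-in for graphrag's actual workflow result type, of which only the `.result` attribute is visible in this diff):

```python
from dataclasses import dataclass
from typing import Any

from graphrag.config.models.graph_rag_config import GraphRagConfig
from graphrag.index.typing.context import PipelineRunContext


@dataclass
class Output:
    result: Any  # the runner reads `result.result` and passes the object to workflow_end


async def count_documents(config: GraphRagConfig, context: PipelineRunContext) -> Output:
    # everything a workflow needs now hangs off the run context
    total = context.stats.num_documents
    context.state["document_count"] = total  # stateful workflows read/write context.state
    return Output(result=total)
```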
@@ -154,7 +140,7 @@
 
     except Exception as e:
         log.exception("error running workflow %s", last_workflow)
-        callbacks.error("Error running pipeline!", e, traceback.format_exc())
+        context.callbacks.error("Error running pipeline!", e, traceback.format_exc())
         yield PipelineRunResult(
             workflow=last_workflow, result=None, state=context.state, errors=[e]
         )
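Errors also flow through `context.callbacks`, so anything registered on the callback chain (see `create_callback_chain` in utils.py below) observes failures alongside workflow start/end events. A sketch of a custom callback, assuming method signatures that match the calls visible in this diff:

```python
import logging

from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks

log = logging.getLogger(__name__)


class AuditWorkflowCallbacks(NoopWorkflowCallbacks):
    """Illustrative callback that records what the pipeline reports."""

    def workflow_start(self, name: str, instance: object) -> None:
        log.info("workflow started: %s", name)

    def workflow_end(self, name: str, instance: object) -> None:
        log.info("workflow finished: %s", name)

    def error(self, message: str, cause=None, stack=None, details=None) -> None:
        log.error("pipeline error: %s (%s)", message, cause)
```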

graphrag/index/run/utils.py

Lines changed: 15 additions & 0 deletions
@@ -9,12 +9,14 @@
 from graphrag.callbacks.progress_workflow_callbacks import ProgressWorkflowCallbacks
 from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
 from graphrag.callbacks.workflow_callbacks_manager import WorkflowCallbacksManager
+from graphrag.config.models.graph_rag_config import GraphRagConfig
 from graphrag.index.typing.context import PipelineRunContext
 from graphrag.index.typing.state import PipelineState
 from graphrag.index.typing.stats import PipelineRunStats
 from graphrag.logger.base import ProgressLogger
 from graphrag.storage.memory_pipeline_storage import MemoryPipelineStorage
 from graphrag.storage.pipeline_storage import PipelineStorage
+from graphrag.utils.api import create_storage_from_config
 
 
 def create_run_context(
@@ -44,3 +46,16 @@ def create_callback_chain(
     if progress is not None:
         manager.register(ProgressWorkflowCallbacks(progress))
     return manager
+
+
+def get_update_storages(
+    config: GraphRagConfig, timestamp: str
+) -> tuple[PipelineStorage, PipelineStorage, PipelineStorage]:
+    """Get storage objects for the update index run."""
+    output_storage = create_storage_from_config(config.output)
+    update_storage = create_storage_from_config(config.update_index_output)
+    timestamped_storage = update_storage.child(timestamp)
+    delta_storage = timestamped_storage.child("delta")
+    previous_storage = timestamped_storage.child("previous")
+
+    return output_storage, previous_storage, delta_storage
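`get_update_storages` packages the same output/previous/delta layout used in `run_pipeline`, so an update workflow can rebuild the storages from config plus the timestamp carried in pipeline state. A usage sketch (the `merge_step` workflow name is hypothetical):

```python
from graphrag.config.models.graph_rag_config import GraphRagConfig
from graphrag.index.run.utils import get_update_storages
from graphrag.index.typing.context import PipelineRunContext


async def merge_step(config: GraphRagConfig, context: PipelineRunContext) -> None:
    """Sketch of an update workflow recovering this run's storages."""
    output_storage, previous_storage, delta_storage = get_update_storages(
        config, context.state["update_timestamp"]
    )
    # previous_storage: the pre-update index copied aside earlier in the run
    # delta_storage:    the index built from the new documents only
    # output_storage:   where the merged index should be written
```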
