Skip to content

Commit ede6a74

Browse files
Pipeline callbacks (#1729)
* Add pipeline_start and pipeline_end callbacks * Collapse redundant callback/logger logic * Remove redundant reporting config classes * Remove a few out-of-date type ignores * Semver --------- Co-authored-by: Alonso Guevara <alonsog@microsoft.com>
1 parent e404761 commit ede6a74

File tree

15 files changed

+150
-156
lines changed

15 files changed

+150
-156
lines changed
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
{
2+
"type": "minor",
3+
"description": "Add pipeline_start and pipeline_end callbacks."
4+
}

graphrag/api/index.py

Lines changed: 20 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -10,15 +10,17 @@
1010

1111
import logging
1212

13-
from graphrag.cache.noop_pipeline_cache import NoopPipelineCache
1413
from graphrag.callbacks.reporting import create_pipeline_reporter
1514
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
16-
from graphrag.config.enums import CacheType, IndexingMethod
15+
from graphrag.config.enums import IndexingMethod
1716
from graphrag.config.models.graph_rag_config import GraphRagConfig
17+
from graphrag.index.run.pipeline_run_result import PipelineRunResult
1818
from graphrag.index.run.run_pipeline import run_pipeline
19-
from graphrag.index.typing import PipelineRunResult, WorkflowFunction
19+
from graphrag.index.run.utils import create_callback_chain
20+
from graphrag.index.typing import WorkflowFunction
2021
from graphrag.index.workflows.factory import PipelineFactory
2122
from graphrag.logger.base import ProgressLogger
23+
from graphrag.logger.null_progress import NullProgressLogger
2224

2325
log = logging.getLogger(__name__)
2426

@@ -51,36 +53,37 @@ async def build_index(
5153
list[PipelineRunResult]
5254
The list of pipeline run results
5355
"""
54-
pipeline_cache = (
55-
NoopPipelineCache() if config.cache.type == CacheType.none is None else None
56-
)
56+
logger = progress_logger or NullProgressLogger()
5757
# create a pipeline reporter and add to any additional callbacks
58-
# TODO: remove the type ignore once the new config engine has been refactored
5958
callbacks = callbacks or []
60-
callbacks.append(create_pipeline_reporter(config.reporting, None)) # type: ignore
59+
callbacks.append(create_pipeline_reporter(config.reporting, None))
60+
61+
workflow_callbacks = create_callback_chain(callbacks, logger)
62+
6163
outputs: list[PipelineRunResult] = []
6264

6365
if memory_profile:
6466
log.warning("New pipeline does not yet support memory profiling.")
6567

6668
pipeline = PipelineFactory.create_pipeline(config, method)
6769

70+
workflow_callbacks.pipeline_start(pipeline.names())
71+
6872
async for output in run_pipeline(
6973
pipeline,
7074
config,
71-
cache=pipeline_cache,
72-
callbacks=callbacks,
73-
logger=progress_logger,
75+
callbacks=workflow_callbacks,
76+
logger=logger,
7477
is_update_run=is_update_run,
7578
):
7679
outputs.append(output)
77-
if progress_logger:
78-
if output.errors and len(output.errors) > 0:
79-
progress_logger.error(output.workflow)
80-
else:
81-
progress_logger.success(output.workflow)
82-
progress_logger.info(str(output.result))
80+
if output.errors and len(output.errors) > 0:
81+
logger.error(output.workflow)
82+
else:
83+
logger.success(output.workflow)
84+
logger.info(str(output.result))
8385

86+
workflow_callbacks.pipeline_end(outputs)
8487
return outputs
8588

8689

graphrag/api/query.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -435,10 +435,11 @@ def local_search_streaming(
435435
vector_store_args = {}
436436
for index, store in config.vector_store.items():
437437
vector_store_args[index] = store.model_dump()
438-
logger.info(f"Vector Store Args: {redact(vector_store_args)}") # type: ignore # noqa
438+
msg = f"Vector Store Args: {redact(vector_store_args)}"
439+
logger.info(msg)
439440

440441
description_embedding_store = get_embedding_store(
441-
config_args=vector_store_args, # type: ignore
442+
config_args=vector_store_args,
442443
embedding_name=entity_description_embedding,
443444
)
444445

@@ -453,7 +454,7 @@ def local_search_streaming(
453454
entities=entities_,
454455
relationships=read_indexer_relationships(relationships),
455456
covariates={"claims": covariates_},
456-
description_embedding_store=description_embedding_store, # type: ignore
457+
description_embedding_store=description_embedding_store,
457458
response_type=response_type,
458459
system_prompt=prompt,
459460
callbacks=callbacks,
@@ -789,15 +790,16 @@ def drift_search_streaming(
789790
vector_store_args = {}
790791
for index, store in config.vector_store.items():
791792
vector_store_args[index] = store.model_dump()
792-
logger.info(f"Vector Store Args: {redact(vector_store_args)}") # type: ignore # noqa
793+
msg = f"Vector Store Args: {redact(vector_store_args)}"
794+
logger.info(msg)
793795

794796
description_embedding_store = get_embedding_store(
795-
config_args=vector_store_args, # type: ignore
797+
config_args=vector_store_args,
796798
embedding_name=entity_description_embedding,
797799
)
798800

799801
full_content_embedding_store = get_embedding_store(
800-
config_args=vector_store_args, # type: ignore
802+
config_args=vector_store_args,
801803
embedding_name=community_full_content_embedding,
802804
)
803805

@@ -815,7 +817,7 @@ def drift_search_streaming(
815817
text_units=read_indexer_text_units(text_units),
816818
entities=entities_,
817819
relationships=read_indexer_relationships(relationships),
818-
description_embedding_store=description_embedding_store, # type: ignore
820+
description_embedding_store=description_embedding_store,
819821
local_system_prompt=prompt,
820822
reduce_system_prompt=reduce_prompt,
821823
response_type=response_type,
@@ -1104,10 +1106,11 @@ def basic_search_streaming(
11041106
vector_store_args = {}
11051107
for index, store in config.vector_store.items():
11061108
vector_store_args[index] = store.model_dump()
1107-
logger.info(f"Vector Store Args: {redact(vector_store_args)}") # type: ignore # noqa
1109+
msg = f"Vector Store Args: {redact(vector_store_args)}"
1110+
logger.info(msg)
11081111

11091112
description_embedding_store = get_embedding_store(
1110-
config_args=vector_store_args, # type: ignore
1113+
config_args=vector_store_args,
11111114
embedding_name=text_unit_text_embedding,
11121115
)
11131116

graphrag/callbacks/blob_workflow_callbacks.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,11 @@ class BlobWorkflowCallbacks(NoopWorkflowCallbacks):
2424
def __init__(
2525
self,
2626
connection_string: str | None,
27-
container_name: str,
27+
container_name: str | None,
2828
blob_name: str = "",
2929
base_dir: str | None = None,
3030
storage_account_blob_url: str | None = None,
31-
): # type: ignore
31+
):
3232
"""Create a new instance of the BlobStorageReporter class."""
3333
if container_name is None:
3434
msg = "No container name provided for blob storage."

graphrag/callbacks/noop_workflow_callbacks.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,19 @@
44
"""A no-op implementation of WorkflowCallbacks."""
55

66
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
7+
from graphrag.index.run.pipeline_run_result import PipelineRunResult
78
from graphrag.logger.progress import Progress
89

910

1011
class NoopWorkflowCallbacks(WorkflowCallbacks):
1112
"""A no-op implementation of WorkflowCallbacks."""
1213

14+
def pipeline_start(self, names: list[str]) -> None:
15+
"""Execute this callback when the entire pipeline starts."""
16+
17+
def pipeline_end(self, results: list[PipelineRunResult]) -> None:
18+
"""Execute this callback when the entire pipeline ends."""
19+
1320
def workflow_start(self, name: str, instance: object) -> None:
1421
"""Execute this callback when a workflow starts."""
1522

graphrag/callbacks/reporting.py

Lines changed: 5 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -1,110 +1,39 @@
11
# Copyright (c) 2024 Microsoft Corporation.
22
# Licensed under the MIT License
33

4-
"""A module containing 'PipelineReportingConfig', 'PipelineFileReportingConfig' and 'PipelineConsoleReportingConfig' models."""
4+
"""A module containing the pipeline reporter factory."""
55

66
from __future__ import annotations
77

88
from pathlib import Path
9-
from typing import TYPE_CHECKING, Generic, Literal, TypeVar, cast
10-
11-
from pydantic import BaseModel, Field
9+
from typing import TYPE_CHECKING
1210

1311
from graphrag.callbacks.blob_workflow_callbacks import BlobWorkflowCallbacks
1412
from graphrag.callbacks.console_workflow_callbacks import ConsoleWorkflowCallbacks
1513
from graphrag.callbacks.file_workflow_callbacks import FileWorkflowCallbacks
1614
from graphrag.config.enums import ReportingType
15+
from graphrag.config.models.reporting_config import ReportingConfig
1716

1817
if TYPE_CHECKING:
1918
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
2019

21-
T = TypeVar("T")
22-
23-
24-
class PipelineReportingConfig(BaseModel, Generic[T]):
25-
"""Represent the reporting configuration for the pipeline."""
26-
27-
type: T
28-
29-
30-
class PipelineFileReportingConfig(PipelineReportingConfig[Literal[ReportingType.file]]):
31-
"""Represent the file reporting configuration for the pipeline."""
32-
33-
type: Literal[ReportingType.file] = ReportingType.file
34-
"""The type of reporting."""
35-
36-
base_dir: str | None = Field(
37-
description="The base directory for the reporting.", default=None
38-
)
39-
"""The base directory for the reporting."""
40-
41-
42-
class PipelineConsoleReportingConfig(
43-
PipelineReportingConfig[Literal[ReportingType.console]]
44-
):
45-
"""Represent the console reporting configuration for the pipeline."""
46-
47-
type: Literal[ReportingType.console] = ReportingType.console
48-
"""The type of reporting."""
49-
50-
51-
class PipelineBlobReportingConfig(PipelineReportingConfig[Literal[ReportingType.blob]]):
52-
"""Represents the blob reporting configuration for the pipeline."""
53-
54-
type: Literal[ReportingType.blob] = ReportingType.blob
55-
"""The type of reporting."""
56-
57-
connection_string: str | None = Field(
58-
description="The blob reporting connection string for the reporting.",
59-
default=None,
60-
)
61-
"""The blob reporting connection string for the reporting."""
62-
63-
container_name: str = Field(
64-
description="The container name for reporting", default=""
65-
)
66-
"""The container name for reporting"""
67-
68-
storage_account_blob_url: str | None = Field(
69-
description="The storage account blob url for reporting", default=None
70-
)
71-
"""The storage account blob url for reporting"""
72-
73-
base_dir: str | None = Field(
74-
description="The base directory for the reporting.", default=None
75-
)
76-
"""The base directory for the reporting."""
77-
78-
79-
PipelineReportingConfigTypes = (
80-
PipelineFileReportingConfig
81-
| PipelineConsoleReportingConfig
82-
| PipelineBlobReportingConfig
83-
)
84-
8520

8621
def create_pipeline_reporter(
87-
config: PipelineReportingConfig | None, root_dir: str | None
22+
config: ReportingConfig | None, root_dir: str | None
8823
) -> WorkflowCallbacks:
8924
"""Create a logger for the given pipeline config."""
90-
config = config or PipelineFileReportingConfig(base_dir="logs")
91-
25+
config = config or ReportingConfig(base_dir="logs", type=ReportingType.file)
9226
match config.type:
9327
case ReportingType.file:
94-
config = cast("PipelineFileReportingConfig", config)
9528
return FileWorkflowCallbacks(
9629
str(Path(root_dir or "") / (config.base_dir or ""))
9730
)
9831
case ReportingType.console:
9932
return ConsoleWorkflowCallbacks()
10033
case ReportingType.blob:
101-
config = cast("PipelineBlobReportingConfig", config)
10234
return BlobWorkflowCallbacks(
10335
config.connection_string,
10436
config.container_name,
10537
base_dir=config.base_dir,
10638
storage_account_blob_url=config.storage_account_blob_url,
10739
)
108-
case _:
109-
msg = f"Unknown reporting type: {config.type}"
110-
raise ValueError(msg)

graphrag/callbacks/workflow_callbacks.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
from typing import Protocol
77

8+
from graphrag.index.run.pipeline_run_result import PipelineRunResult
89
from graphrag.logger.progress import Progress
910

1011

@@ -15,6 +16,14 @@ class WorkflowCallbacks(Protocol):
1516
This base class is a "noop" implementation so that clients may implement just the callbacks they need.
1617
"""
1718

19+
def pipeline_start(self, names: list[str]) -> None:
20+
"""Execute this callback to signal when the entire pipeline starts.

names is the ordered list of workflow names about to be run.
"""
21+
...
22+
23+
def pipeline_end(self, results: list[PipelineRunResult]) -> None:
24+
"""Execute this callback to signal when the entire pipeline ends.

results holds one PipelineRunResult per executed workflow.
"""
25+
...
26+
1827
def workflow_start(self, name: str, instance: object) -> None:
1928
"""Execute this callback when a workflow starts."""
2029
...

graphrag/callbacks/workflow_callbacks_manager.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
"""A module containing the WorkflowCallbacks registry."""
55

66
from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
7+
from graphrag.index.run.pipeline_run_result import PipelineRunResult
78
from graphrag.logger.progress import Progress
89

910

@@ -20,6 +21,18 @@ def register(self, callbacks: WorkflowCallbacks) -> None:
2021
"""Register a new WorkflowCallbacks type."""
2122
self._callbacks.append(callbacks)
2223

24+
def pipeline_start(self, names: list[str]) -> None:
25+
"""Execute this callback when the entire pipeline starts."""
26+
for callback in self._callbacks:
27+
# hasattr guard: registered callbacks may not implement these optional pipeline hooks
if hasattr(callback, "pipeline_start"):
28+
callback.pipeline_start(names)
29+
30+
def pipeline_end(self, results: list[PipelineRunResult]) -> None:
31+
"""Execute this callback when the entire pipeline ends."""
32+
for callback in self._callbacks:
33+
# same defensive hasattr check as pipeline_start
if hasattr(callback, "pipeline_end"):
34+
callback.pipeline_end(results)
35+
2336
def workflow_start(self, name: str, instance: object) -> None:
2437
"""Execute this callback when a workflow starts."""
2538
for callback in self._callbacks:

graphrag/cli/query.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -527,7 +527,7 @@ def _resolve_output_files(
527527
return dataframe_dict
528528
# Loading output files for single-index search
529529
dataframe_dict["multi-index"] = False
530-
output_config = config.output.model_dump() # type: ignore
530+
output_config = config.output.model_dump()
531531
storage_obj = StorageFactory().create_storage(
532532
storage_type=output_config["type"], kwargs=output_config
533533
)

graphrag/index/run/pipeline.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
# Copyright (c) 2024 Microsoft Corporation.
2+
# Licensed under the MIT License
3+
4+
"""A module containing the Pipeline class."""
5+
6+
from collections.abc import Generator
7+
8+
from graphrag.index.typing import Workflow
9+
10+
11+
class Pipeline:
    """An ordered collection of workflows, executed as a single unit."""

    def __init__(self, workflows: list[Workflow]):
        """Create a Pipeline wrapping the given ordered list of workflows."""
        self.workflows = workflows

    def run(self) -> Generator[Workflow]:
        """Yield each workflow of the pipeline, in order."""
        for workflow in self.workflows:
            yield workflow

    def names(self) -> list[str]:
        """Return the name of every workflow, in execution order."""
        # each Workflow is a (name, function) pair; keep only the names
        return [workflow_name for workflow_name, _ in self.workflows]

0 commit comments

Comments
 (0)