|
1 | 1 | # Copyright (c) 2024 Microsoft Corporation. |
2 | 2 | # Licensed under the MIT License |
3 | 3 |
|
4 | | -"""A module containing 'PipelineReportingConfig', 'PipelineFileReportingConfig' and 'PipelineConsoleReportingConfig' models.""" |
| 4 | +"""A module containing the pipeline reporter factory.""" |
5 | 5 |
|
6 | 6 | from __future__ import annotations |
7 | 7 |
|
8 | 8 | from pathlib import Path |
9 | | -from typing import TYPE_CHECKING, Generic, Literal, TypeVar, cast |
10 | | - |
11 | | -from pydantic import BaseModel, Field |
| 9 | +from typing import TYPE_CHECKING |
12 | 10 |
|
13 | 11 | from graphrag.callbacks.blob_workflow_callbacks import BlobWorkflowCallbacks |
14 | 12 | from graphrag.callbacks.console_workflow_callbacks import ConsoleWorkflowCallbacks |
15 | 13 | from graphrag.callbacks.file_workflow_callbacks import FileWorkflowCallbacks |
16 | 14 | from graphrag.config.enums import ReportingType |
| 15 | +from graphrag.config.models.reporting_config import ReportingConfig |
17 | 16 |
|
18 | 17 | if TYPE_CHECKING: |
19 | 18 | from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks |
20 | 19 |
|
21 | | -T = TypeVar("T") |
22 | | - |
23 | | - |
24 | | -class PipelineReportingConfig(BaseModel, Generic[T]): |
25 | | - """Represent the reporting configuration for the pipeline.""" |
26 | | - |
27 | | - type: T |
28 | | - |
29 | | - |
30 | | -class PipelineFileReportingConfig(PipelineReportingConfig[Literal[ReportingType.file]]): |
31 | | - """Represent the file reporting configuration for the pipeline.""" |
32 | | - |
33 | | - type: Literal[ReportingType.file] = ReportingType.file |
34 | | - """The type of reporting.""" |
35 | | - |
36 | | - base_dir: str | None = Field( |
37 | | - description="The base directory for the reporting.", default=None |
38 | | - ) |
39 | | - """The base directory for the reporting.""" |
40 | | - |
41 | | - |
42 | | -class PipelineConsoleReportingConfig( |
43 | | - PipelineReportingConfig[Literal[ReportingType.console]] |
44 | | -): |
45 | | - """Represent the console reporting configuration for the pipeline.""" |
46 | | - |
47 | | - type: Literal[ReportingType.console] = ReportingType.console |
48 | | - """The type of reporting.""" |
49 | | - |
50 | | - |
51 | | -class PipelineBlobReportingConfig(PipelineReportingConfig[Literal[ReportingType.blob]]): |
52 | | - """Represents the blob reporting configuration for the pipeline.""" |
53 | | - |
54 | | - type: Literal[ReportingType.blob] = ReportingType.blob |
55 | | - """The type of reporting.""" |
56 | | - |
57 | | - connection_string: str | None = Field( |
58 | | - description="The blob reporting connection string for the reporting.", |
59 | | - default=None, |
60 | | - ) |
61 | | - """The blob reporting connection string for the reporting.""" |
62 | | - |
63 | | - container_name: str = Field( |
64 | | - description="The container name for reporting", default="" |
65 | | - ) |
66 | | - """The container name for reporting""" |
67 | | - |
68 | | - storage_account_blob_url: str | None = Field( |
69 | | - description="The storage account blob url for reporting", default=None |
70 | | - ) |
71 | | - """The storage account blob url for reporting""" |
72 | | - |
73 | | - base_dir: str | None = Field( |
74 | | - description="The base directory for the reporting.", default=None |
75 | | - ) |
76 | | - """The base directory for the reporting.""" |
77 | | - |
78 | | - |
79 | | -PipelineReportingConfigTypes = ( |
80 | | - PipelineFileReportingConfig |
81 | | - | PipelineConsoleReportingConfig |
82 | | - | PipelineBlobReportingConfig |
83 | | -) |
84 | | - |
85 | 20 |
|
86 | 21 | def create_pipeline_reporter( |
87 | | - config: PipelineReportingConfig | None, root_dir: str | None |
| 22 | + config: ReportingConfig | None, root_dir: str | None |
88 | 23 | ) -> WorkflowCallbacks: |
89 | 24 | """Create a logger for the given pipeline config.""" |
90 | | - config = config or PipelineFileReportingConfig(base_dir="logs") |
91 | | - |
| 25 | + config = config or ReportingConfig(base_dir="logs", type=ReportingType.file) |
92 | 26 | match config.type: |
93 | 27 | case ReportingType.file: |
94 | | - config = cast("PipelineFileReportingConfig", config) |
95 | 28 | return FileWorkflowCallbacks( |
96 | 29 | str(Path(root_dir or "") / (config.base_dir or "")) |
97 | 30 | ) |
98 | 31 | case ReportingType.console: |
99 | 32 | return ConsoleWorkflowCallbacks() |
100 | 33 | case ReportingType.blob: |
101 | | - config = cast("PipelineBlobReportingConfig", config) |
102 | 34 | return BlobWorkflowCallbacks( |
103 | 35 | config.connection_string, |
104 | 36 | config.container_name, |
105 | 37 | base_dir=config.base_dir, |
106 | 38 | storage_account_blob_url=config.storage_account_blob_url, |
107 | 39 | ) |
108 | | - case _: |
109 | | - msg = f"Unknown reporting type: {config.type}" |
110 | | - raise ValueError(msg) |
0 commit comments