Commit 27c5468

KennyZhang1 and Kenny Zhang authored
Load query from blob (#1095)
* Moved query loading from file to a helper function
* Added loading parquets from blob to the function
* Resolved adlfs async error
* Debugging cleanup and small fixes
* Added connection string support
* Semversioner and ruff fixes
* Completed testing for merge with main
* More ruff changes
* Fixed unbound vars warning
* Rewrote function to use storage utils
* Removed unused vars

Co-authored-by: Kenny Zhang <zhangken@microsoft.com>
1 parent 044516f commit 27c5468

2 files changed: +76, -27 lines
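The substance of the change: instead of reading the pipeline's parquet outputs directly off the local filesystem with pd.read_parquet, the query CLI now resolves them through graphrag's storage layer, so the same code path serves a local output directory or an Azure Blob Storage container. The sketch below illustrates that loading pattern only; FakeBlobStorage and load_table are hypothetical stand-ins for an async storage backend, not graphrag's actual PipelineStorage API.

import asyncio
from io import BytesIO

import pandas as pd


class FakeBlobStorage:
    """Hypothetical stand-in for an async storage backend such as a blob container."""

    def __init__(self, blobs: dict[str, bytes]):
        self._blobs = blobs

    async def has(self, name: str) -> bool:
        return name in self._blobs

    async def get(self, name: str) -> bytes:
        return self._blobs[name]


async def load_table(name: str, storage: FakeBlobStorage) -> pd.DataFrame:
    """Fetch raw parquet bytes from storage and parse them into a DataFrame."""
    raw = await storage.get(name)
    return pd.read_parquet(BytesIO(raw))


if __name__ == "__main__":
    # Build an in-memory "blob" holding a small parquet table.
    buf = BytesIO()
    pd.DataFrame({"id": [1, 2], "title": ["a", "b"]}).to_parquet(buf)
    storage = FakeBlobStorage({"create_final_nodes.parquet": buf.getvalue()})

    # The CLI drives the async load synchronously, one table at a time.
    nodes = asyncio.run(load_table("create_final_nodes.parquet", storage))
    print(nodes)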
Lines changed: 4 additions & 0 deletions
@@ -0,0 +1,4 @@
+{
+    "type": "patch",
+    "description": "add querying from azure blob storage"
+}

graphrag/query/cli.py

Lines changed: 72 additions & 27 deletions
@@ -9,8 +9,14 @@
 
 import pandas as pd
 
-from graphrag.config import load_config, resolve_path
+from graphrag.config import (
+    GraphRagConfig,
+    load_config,
+    resolve_path,
+)
+from graphrag.index.create_pipeline_config import create_pipeline_config
 from graphrag.index.progress import PrintProgressReporter
+from graphrag.utils.storage import _create_storage, _load_table_from_storage
 
 from . import api
 
@@ -36,17 +42,21 @@ def run_global_search(
     if data_dir:
         config.storage.base_dir = str(resolve_path(data_dir, root))
 
-    data_path = Path(config.storage.base_dir).resolve()
-
-    final_nodes: pd.DataFrame = pd.read_parquet(
-        data_path / "create_final_nodes.parquet"
-    )
-    final_entities: pd.DataFrame = pd.read_parquet(
-        data_path / "create_final_entities.parquet"
-    )
-    final_community_reports: pd.DataFrame = pd.read_parquet(
-        data_path / "create_final_community_reports.parquet"
+    dataframe_dict = _resolve_parquet_files(
+        root_dir=root_dir,
+        config=config,
+        parquet_list=[
+            "create_final_nodes.parquet",
+            "create_final_entities.parquet",
+            "create_final_community_reports.parquet",
+        ],
+        optional_list=[],
     )
+    final_nodes: pd.DataFrame = dataframe_dict["create_final_nodes"]
+    final_entities: pd.DataFrame = dataframe_dict["create_final_entities"]
+    final_community_reports: pd.DataFrame = dataframe_dict[
+        "create_final_community_reports"
+    ]
 
     # call the Query API
     if streaming:
@@ -112,23 +122,26 @@ def run_local_search(
     if data_dir:
         config.storage.base_dir = str(resolve_path(data_dir, root))
 
-    data_path = Path(config.storage.base_dir).resolve()
-
-    final_nodes = pd.read_parquet(data_path / "create_final_nodes.parquet")
-    final_community_reports = pd.read_parquet(
-        data_path / "create_final_community_reports.parquet"
-    )
-    final_text_units = pd.read_parquet(data_path / "create_final_text_units.parquet")
-    final_relationships = pd.read_parquet(
-        data_path / "create_final_relationships.parquet"
-    )
-    final_entities = pd.read_parquet(data_path / "create_final_entities.parquet")
-    final_covariates_path = data_path / "create_final_covariates.parquet"
-    final_covariates = (
-        pd.read_parquet(final_covariates_path)
-        if final_covariates_path.exists()
-        else None
+    dataframe_dict = _resolve_parquet_files(
+        root_dir=root_dir,
+        config=config,
+        parquet_list=[
+            "create_final_nodes.parquet",
+            "create_final_community_reports.parquet",
+            "create_final_text_units.parquet",
+            "create_final_relationships.parquet",
+            "create_final_entities.parquet",
+        ],
+        optional_list=["create_final_covariates.parquet"],
     )
+    final_nodes: pd.DataFrame = dataframe_dict["create_final_nodes"]
+    final_community_reports: pd.DataFrame = dataframe_dict[
+        "create_final_community_reports"
+    ]
+    final_text_units: pd.DataFrame = dataframe_dict["create_final_text_units"]
+    final_relationships: pd.DataFrame = dataframe_dict["create_final_relationships"]
+    final_entities: pd.DataFrame = dataframe_dict["create_final_entities"]
+    final_covariates: pd.DataFrame | None = dataframe_dict["create_final_covariates"]
 
     # call the Query API
     if streaming:
@@ -179,3 +192,35 @@ async def run_streaming_search():
     # NOTE: we return the response and context data here purely as a complete demonstration of the API.
     # External users should use the API directly to get the response and context data.
     return response, context_data
+
+
+def _resolve_parquet_files(
+    root_dir: str,
+    config: GraphRagConfig,
+    parquet_list: list[str],
+    optional_list: list[str],
+) -> dict[str, pd.DataFrame]:
+    """Read parquet files to a dataframe dict."""
+    dataframe_dict = {}
+    pipeline_config = create_pipeline_config(config)
+    storage_obj = _create_storage(root_dir=root_dir, config=pipeline_config.storage)
+    for parquet_file in parquet_list:
+        df_key = parquet_file.split(".")[0]
+        df_value = asyncio.run(
+            _load_table_from_storage(name=parquet_file, storage=storage_obj)
+        )
+        dataframe_dict[df_key] = df_value
+
+    # for optional parquet files, set the dict entry to None instead of erroring out if it does not exist
+    for optional_file in optional_list:
+        file_exists = asyncio.run(storage_obj.has(optional_file))
+        df_key = optional_file.split(".")[0]
+        if file_exists:
+            df_value = asyncio.run(
+                _load_table_from_storage(name=optional_file, storage=storage_obj)
+            )
+            dataframe_dict[df_key] = df_value
+        else:
+            dataframe_dict[df_key] = None
+
+    return dataframe_dict
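For orientation, a hedged usage sketch of the new helper as the local-search path above exercises it. This is a fragment, not a standalone script: config is assumed to be the GraphRagConfig the CLI has already loaded, and the root value is a placeholder. When the config's storage section points at Azure Blob Storage rather than a local directory, the same call reads the tables from the configured container.

# Usage fragment: `config` comes from the CLI's config loading above, and
# "./ragtest" is a placeholder project root, not a value from this commit.
tables = _resolve_parquet_files(
    root_dir="./ragtest",
    config=config,
    parquet_list=[
        "create_final_nodes.parquet",
        "create_final_entities.parquet",
    ],
    optional_list=["create_final_covariates.parquet"],
)
final_covariates = tables["create_final_covariates"]  # None when the table was not produced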
