 
 import pandas as pd
 
-from graphrag.config import load_config, resolve_path
+from graphrag.config import (
+    GraphRagConfig,
+    load_config,
+    resolve_path,
+)
+from graphrag.index.create_pipeline_config import create_pipeline_config
 from graphrag.index.progress import PrintProgressReporter
+from graphrag.utils.storage import _create_storage, _load_table_from_storage
 
 from . import api
 
@@ -36,17 +42,21 @@ def run_global_search(
     if data_dir:
         config.storage.base_dir = str(resolve_path(data_dir, root))
 
-    data_path = Path(config.storage.base_dir).resolve()
-
-    final_nodes: pd.DataFrame = pd.read_parquet(
-        data_path / "create_final_nodes.parquet"
-    )
-    final_entities: pd.DataFrame = pd.read_parquet(
-        data_path / "create_final_entities.parquet"
-    )
-    final_community_reports: pd.DataFrame = pd.read_parquet(
-        data_path / "create_final_community_reports.parquet"
+    dataframe_dict = _resolve_parquet_files(
+        root_dir=root_dir,
+        config=config,
+        parquet_list=[
+            "create_final_nodes.parquet",
+            "create_final_entities.parquet",
+            "create_final_community_reports.parquet",
+        ],
+        optional_list=[],
     )
+    final_nodes: pd.DataFrame = dataframe_dict["create_final_nodes"]
+    final_entities: pd.DataFrame = dataframe_dict["create_final_entities"]
+    final_community_reports: pd.DataFrame = dataframe_dict[
+        "create_final_community_reports"
+    ]
 
     # call the Query API
     if streaming:
@@ -112,23 +122,26 @@ def run_local_search(
     if data_dir:
         config.storage.base_dir = str(resolve_path(data_dir, root))
 
-    data_path = Path(config.storage.base_dir).resolve()
-
-    final_nodes = pd.read_parquet(data_path / "create_final_nodes.parquet")
-    final_community_reports = pd.read_parquet(
-        data_path / "create_final_community_reports.parquet"
-    )
-    final_text_units = pd.read_parquet(data_path / "create_final_text_units.parquet")
-    final_relationships = pd.read_parquet(
-        data_path / "create_final_relationships.parquet"
-    )
-    final_entities = pd.read_parquet(data_path / "create_final_entities.parquet")
-    final_covariates_path = data_path / "create_final_covariates.parquet"
-    final_covariates = (
-        pd.read_parquet(final_covariates_path)
-        if final_covariates_path.exists()
-        else None
+    dataframe_dict = _resolve_parquet_files(
+        root_dir=root_dir,
+        config=config,
+        parquet_list=[
+            "create_final_nodes.parquet",
+            "create_final_community_reports.parquet",
+            "create_final_text_units.parquet",
+            "create_final_relationships.parquet",
+            "create_final_entities.parquet",
+        ],
+        optional_list=["create_final_covariates.parquet"],
     )
+    final_nodes: pd.DataFrame = dataframe_dict["create_final_nodes"]
+    final_community_reports: pd.DataFrame = dataframe_dict[
+        "create_final_community_reports"
+    ]
+    final_text_units: pd.DataFrame = dataframe_dict["create_final_text_units"]
+    final_relationships: pd.DataFrame = dataframe_dict["create_final_relationships"]
+    final_entities: pd.DataFrame = dataframe_dict["create_final_entities"]
+    final_covariates: pd.DataFrame | None = dataframe_dict["create_final_covariates"]
 
     # call the Query API
     if streaming:
@@ -179,3 +192,35 @@ async def run_streaming_search():
     # NOTE: we return the response and context data here purely as a complete demonstration of the API.
     # External users should use the API directly to get the response and context data.
     return response, context_data
+
+
+def _resolve_parquet_files(
+    root_dir: str,
+    config: GraphRagConfig,
+    parquet_list: list[str],
+    optional_list: list[str],
+) -> dict[str, pd.DataFrame]:
+    """Read parquet files to a dataframe dict."""
+    dataframe_dict = {}
+    pipeline_config = create_pipeline_config(config)
+    storage_obj = _create_storage(root_dir=root_dir, config=pipeline_config.storage)
+    for parquet_file in parquet_list:
+        df_key = parquet_file.split(".")[0]
+        df_value = asyncio.run(
+            _load_table_from_storage(name=parquet_file, storage=storage_obj)
+        )
+        dataframe_dict[df_key] = df_value
+
+    # for optional parquet files, set the dict entry to None instead of erroring out if it does not exist
+    for optional_file in optional_list:
+        file_exists = asyncio.run(storage_obj.has(optional_file))
+        df_key = optional_file.split(".")[0]
+        if file_exists:
+            df_value = asyncio.run(
+                _load_table_from_storage(name=optional_file, storage=storage_obj)
+            )
+            dataframe_dict[df_key] = df_value
+        else:
+            dataframe_dict[df_key] = None
+
+    return dataframe_dict
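For illustration only, a minimal sketch of how the new helper is consumed (not part of the diff). It assumes `root_dir` and `config` were already produced by the CLI's existing `load_config` / `resolve_path` setup shown above, and that the private helpers keep the signatures introduced in this change:

    # hypothetical usage, mirroring the calls added to run_global_search / run_local_search
    dataframes = _resolve_parquet_files(
        root_dir=root_dir,  # assumed: resolved project root, as in the CLI entry points
        config=config,      # assumed: a loaded GraphRagConfig instance
        parquet_list=[
            "create_final_nodes.parquet",      # loaded unconditionally from pipeline storage
            "create_final_entities.parquet",
        ],
        optional_list=[
            "create_final_covariates.parquet",  # entry becomes None if the file is absent
        ],
    )
    nodes: pd.DataFrame = dataframes["create_final_nodes"]
    covariates: pd.DataFrame | None = dataframes["create_final_covariates"]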