@@ -348,6 +348,87 @@ async def local_search_streaming(
348348 yield stream_chunk
349349
350350
351+ @validate_call (config = {"arbitrary_types_allowed" : True })
352+ async def drift_search_streaming (
353+ config : GraphRagConfig ,
354+ nodes : pd .DataFrame ,
355+ entities : pd .DataFrame ,
356+ community_reports : pd .DataFrame ,
357+ text_units : pd .DataFrame ,
358+ relationships : pd .DataFrame ,
359+ community_level : int ,
360+ response_type : str ,
361+ query : str ,
362+ ) -> AsyncGenerator :
363+ """Perform a DRIFT search and return the context data and response.
364+
365+ Parameters
366+ ----------
367+ - config (GraphRagConfig): A graphrag configuration (from settings.yaml)
368+ - nodes (pd.DataFrame): A DataFrame containing the final nodes (from create_final_nodes.parquet)
369+ - entities (pd.DataFrame): A DataFrame containing the final entities (from create_final_entities.parquet)
370+ - community_reports (pd.DataFrame): A DataFrame containing the final community reports (from create_final_community_reports.parquet)
371+ - text_units (pd.DataFrame): A DataFrame containing the final text units (from create_final_text_units.parquet)
372+ - relationships (pd.DataFrame): A DataFrame containing the final relationships (from create_final_relationships.parquet)
373+ - community_level (int): The community level to search at.
374+ - query (str): The user query to search for.
375+
376+ Returns
377+ -------
378+ TODO: Document the search response type and format.
379+
380+ Raises
381+ ------
382+ TODO: Document any exceptions to expect.
383+ """
384+ vector_store_args = config .embeddings .vector_store
385+ logger .info (f"Vector Store Args: { redact (vector_store_args )} " ) # type: ignore # noqa
386+
387+ description_embedding_store = _get_embedding_store (
388+ config_args = vector_store_args , # type: ignore
389+ embedding_name = entity_description_embedding ,
390+ )
391+
392+ full_content_embedding_store = _get_embedding_store (
393+ config_args = vector_store_args , # type: ignore
394+ embedding_name = community_full_content_embedding ,
395+ )
396+
397+ entities_ = read_indexer_entities (nodes , entities , community_level )
398+ reports = read_indexer_reports (community_reports , nodes , community_level )
399+ read_indexer_report_embeddings (reports , full_content_embedding_store )
400+ prompt = _load_search_prompt (config .root_dir , config .drift_search .prompt )
401+ reduce_prompt = _load_search_prompt (
402+ config .root_dir , config .drift_search .reduce_prompt
403+ )
404+
405+ search_engine = get_drift_search_engine (
406+ config = config ,
407+ reports = reports ,
408+ text_units = read_indexer_text_units (text_units ),
409+ entities = entities_ ,
410+ relationships = read_indexer_relationships (relationships ),
411+ description_embedding_store = description_embedding_store , # type: ignore
412+ local_system_prompt = prompt ,
413+ reduce_system_prompt = reduce_prompt ,
414+ response_type = response_type ,
415+ )
416+
417+ search_result = search_engine .astream_search (query = query )
418+
419+ # when streaming results, a context data object is returned as the first result
420+ # and the query response in subsequent tokens
421+ context_data = None
422+ get_context_data = True
423+ async for stream_chunk in search_result :
424+ if get_context_data :
425+ context_data = _reformat_context_data (stream_chunk ) # type: ignore
426+ yield context_data
427+ get_context_data = False
428+ else :
429+ yield stream_chunk
430+
431+
351432@validate_call (config = {"arbitrary_types_allowed" : True })
352433async def drift_search (
353434 config : GraphRagConfig ,
@@ -357,6 +438,7 @@ async def drift_search(
357438 text_units : pd .DataFrame ,
358439 relationships : pd .DataFrame ,
359440 community_level : int ,
441+ response_type : str ,
360442 query : str ,
361443) -> tuple [
362444 str | dict [str , Any ] | list [dict [str , Any ]],
@@ -400,6 +482,10 @@ async def drift_search(
400482 reports = read_indexer_reports (community_reports , nodes , community_level )
401483 read_indexer_report_embeddings (reports , full_content_embedding_store )
402484 prompt = _load_search_prompt (config .root_dir , config .drift_search .prompt )
485+ reduce_prompt = _load_search_prompt (
486+ config .root_dir , config .drift_search .reduce_prompt
487+ )
488+
403489 search_engine = get_drift_search_engine (
404490 config = config ,
405491 reports = reports ,
@@ -408,21 +494,15 @@ async def drift_search(
408494 relationships = read_indexer_relationships (relationships ),
409495 description_embedding_store = description_embedding_store , # type: ignore
410496 local_system_prompt = prompt ,
497+ reduce_system_prompt = reduce_prompt ,
498+ response_type = response_type ,
411499 )
412500
413501 result : SearchResult = await search_engine .asearch (query = query )
414502 response = result .response
415503 context_data = _reformat_context_data (result .context_data ) # type: ignore
416504
417- # TODO: Map/reduce the response to a single string with a comprehensive answer including all follow-ups
418- # For the time being, return highest scoring response (position 0) and context data
419- match response :
420- case dict ():
421- return response ["nodes" ][0 ]["answer" ], context_data # type: ignore
422- case str ():
423- return response , context_data
424- case list ():
425- return response , context_data
505+ return response , context_data
426506
427507
428508@validate_call (config = {"arbitrary_types_allowed" : True })
0 commit comments