
Commit 3defab2

Reduce Drift Response and Streaming endpoint (#1624)
* Adding basic wrappers for reduce in drift
* Add response_type parameter to run_drift_search and enhance reduce response functionality
* Add streaming endpoint
* Semver
* Spellcheck
* Ruff checks
* Count tokens on reduce
* Use list comprehension and remove llm_params map in favor of just using kwargs
1 parent 4637270 commit 3defab2

16 files changed: +809 −581 lines changed

Lines changed: 4 additions & 0 deletions
@@ -0,0 +1,4 @@
+{
+    "type": "patch",
+    "description": "Add Drift Reduce response and streaming endpoint"
+}

graphrag/api/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -13,6 +13,7 @@
     basic_search,
     basic_search_streaming,
     drift_search,
+    drift_search_streaming,
     global_search,
     global_search_streaming,
     local_search,
@@ -29,6 +30,7 @@
     "local_search",
     "local_search_streaming",
     "drift_search",
+    "drift_search_streaming",
     "basic_search",
     "basic_search_streaming",
     # prompt tuning API

graphrag/api/query.py

Lines changed: 89 additions & 9 deletions
@@ -348,6 +348,87 @@ async def local_search_streaming(
         yield stream_chunk
 
 
+@validate_call(config={"arbitrary_types_allowed": True})
+async def drift_search_streaming(
+    config: GraphRagConfig,
+    nodes: pd.DataFrame,
+    entities: pd.DataFrame,
+    community_reports: pd.DataFrame,
+    text_units: pd.DataFrame,
+    relationships: pd.DataFrame,
+    community_level: int,
+    response_type: str,
+    query: str,
+) -> AsyncGenerator:
+    """Perform a DRIFT search and return the context data and response.
+
+    Parameters
+    ----------
+    - config (GraphRagConfig): A graphrag configuration (from settings.yaml)
+    - nodes (pd.DataFrame): A DataFrame containing the final nodes (from create_final_nodes.parquet)
+    - entities (pd.DataFrame): A DataFrame containing the final entities (from create_final_entities.parquet)
+    - community_reports (pd.DataFrame): A DataFrame containing the final community reports (from create_final_community_reports.parquet)
+    - text_units (pd.DataFrame): A DataFrame containing the final text units (from create_final_text_units.parquet)
+    - relationships (pd.DataFrame): A DataFrame containing the final relationships (from create_final_relationships.parquet)
+    - community_level (int): The community level to search at.
+    - query (str): The user query to search for.
+
+    Returns
+    -------
+    TODO: Document the search response type and format.
+
+    Raises
+    ------
+    TODO: Document any exceptions to expect.
+    """
+    vector_store_args = config.embeddings.vector_store
+    logger.info(f"Vector Store Args: {redact(vector_store_args)}")  # type: ignore # noqa
+
+    description_embedding_store = _get_embedding_store(
+        config_args=vector_store_args,  # type: ignore
+        embedding_name=entity_description_embedding,
+    )
+
+    full_content_embedding_store = _get_embedding_store(
+        config_args=vector_store_args,  # type: ignore
+        embedding_name=community_full_content_embedding,
+    )
+
+    entities_ = read_indexer_entities(nodes, entities, community_level)
+    reports = read_indexer_reports(community_reports, nodes, community_level)
+    read_indexer_report_embeddings(reports, full_content_embedding_store)
+    prompt = _load_search_prompt(config.root_dir, config.drift_search.prompt)
+    reduce_prompt = _load_search_prompt(
+        config.root_dir, config.drift_search.reduce_prompt
+    )
+
+    search_engine = get_drift_search_engine(
+        config=config,
+        reports=reports,
+        text_units=read_indexer_text_units(text_units),
+        entities=entities_,
+        relationships=read_indexer_relationships(relationships),
+        description_embedding_store=description_embedding_store,  # type: ignore
+        local_system_prompt=prompt,
+        reduce_system_prompt=reduce_prompt,
+        response_type=response_type,
+    )
+
+    search_result = search_engine.astream_search(query=query)
+
+    # when streaming results, a context data object is returned as the first result
+    # and the query response in subsequent tokens
+    context_data = None
+    get_context_data = True
+    async for stream_chunk in search_result:
+        if get_context_data:
+            context_data = _reformat_context_data(stream_chunk)  # type: ignore
+            yield context_data
+            get_context_data = False
+        else:
+            yield stream_chunk
+
+
 @validate_call(config={"arbitrary_types_allowed": True})
 async def drift_search(
     config: GraphRagConfig,
@@ -357,6 +438,7 @@ async def drift_search(
     text_units: pd.DataFrame,
     relationships: pd.DataFrame,
     community_level: int,
+    response_type: str,
     query: str,
 ) -> tuple[
     str | dict[str, Any] | list[dict[str, Any]],
@@ -400,6 +482,10 @@ async def drift_search(
     reports = read_indexer_reports(community_reports, nodes, community_level)
     read_indexer_report_embeddings(reports, full_content_embedding_store)
     prompt = _load_search_prompt(config.root_dir, config.drift_search.prompt)
+    reduce_prompt = _load_search_prompt(
+        config.root_dir, config.drift_search.reduce_prompt
+    )
+
     search_engine = get_drift_search_engine(
         config=config,
         reports=reports,
@@ -408,21 +494,15 @@ async def drift_search(
         relationships=read_indexer_relationships(relationships),
         description_embedding_store=description_embedding_store,  # type: ignore
         local_system_prompt=prompt,
+        reduce_system_prompt=reduce_prompt,
+        response_type=response_type,
     )
 
     result: SearchResult = await search_engine.asearch(query=query)
     response = result.response
     context_data = _reformat_context_data(result.context_data)  # type: ignore
 
-    # TODO: Map/reduce the response to a single string with a comprehensive answer including all follow-ups
-    # For the time being, return highest scoring response (position 0) and context data
-    match response:
-        case dict():
-            return response["nodes"][0]["answer"], context_data  # type: ignore
-        case str():
-            return response, context_data
-        case list():
-            return response, context_data
+    return response, context_data
 
 
 @validate_call(config={"arbitrary_types_allowed": True})
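
Usage note: a minimal consumer sketch for the new drift_search_streaming endpoint, following the contract above (the first yielded item is the reformatted context data, every later item is a response token). The parquet file names come from the docstring; the output directory, community_level=2, the response_type value, and the untyped config parameter are illustrative assumptions, not part of this change.

from pathlib import Path

import pandas as pd

import graphrag.api as api


async def stream_drift_answer(config, output_dir: Path, query: str):
    # config: a GraphRagConfig loaded from the project's settings.yaml
    # (loading mechanism omitted; use whatever your project already does).
    def load(name: str) -> pd.DataFrame:
        return pd.read_parquet(output_dir / name)

    full_response = ""
    context_data = None
    first_chunk = True
    async for chunk in api.drift_search_streaming(
        config=config,
        nodes=load("create_final_nodes.parquet"),
        entities=load("create_final_entities.parquet"),
        community_reports=load("create_final_community_reports.parquet"),
        text_units=load("create_final_text_units.parquet"),
        relationships=load("create_final_relationships.parquet"),
        community_level=2,  # illustrative community level
        response_type="Multiple Paragraphs",  # illustrative response type
        query=query,
    ):
        if first_chunk:
            # first yield: the context data object
            context_data = chunk
            first_chunk = False
        else:
            # subsequent yields: response tokens
            full_response += chunk
            print(chunk, end="", flush=True)
    print()
    return full_response, context_data

# run with, e.g.: asyncio.run(stream_drift_answer(config, Path("./ragtest/output"), "What are the key themes?"))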

graphrag/cli/initialize.py

Lines changed: 5 additions & 1 deletion
@@ -14,7 +14,10 @@
 from graphrag.prompts.index.entity_extraction import GRAPH_EXTRACTION_PROMPT
 from graphrag.prompts.index.summarize_descriptions import SUMMARIZE_PROMPT
 from graphrag.prompts.query.basic_search_system_prompt import BASIC_SEARCH_SYSTEM_PROMPT
-from graphrag.prompts.query.drift_search_system_prompt import DRIFT_LOCAL_SYSTEM_PROMPT
+from graphrag.prompts.query.drift_search_system_prompt import (
+    DRIFT_LOCAL_SYSTEM_PROMPT,
+    DRIFT_REDUCE_PROMPT,
+)
 from graphrag.prompts.query.global_search_knowledge_system_prompt import (
     GENERAL_KNOWLEDGE_INSTRUCTION,
 )
@@ -57,6 +60,7 @@ def initialize_project_at(path: Path) -> None:
         "claim_extraction": CLAIM_EXTRACTION_PROMPT,
         "community_report": COMMUNITY_REPORT_PROMPT,
         "drift_search_system_prompt": DRIFT_LOCAL_SYSTEM_PROMPT,
+        "drift_reduce_prompt": DRIFT_REDUCE_PROMPT,
         "global_search_map_system_prompt": MAP_SYSTEM_PROMPT,
         "global_search_reduce_system_prompt": REDUCE_SYSTEM_PROMPT,
         "global_search_knowledge_system_prompt": GENERAL_KNOWLEDGE_INSTRUCTION,

graphrag/cli/main.py

Lines changed: 2 additions & 1 deletion
@@ -460,7 +460,8 @@ def _query_cli(
                 data_dir=data,
                 root_dir=root,
                 community_level=community_level,
-                streaming=False,  # Drift search does not support streaming (yet)
+                streaming=streaming,
+                response_type=response_type,
                 query=query,
             )
         case SearchType.BASIC:

graphrag/cli/query.py

Lines changed: 29 additions & 4 deletions
@@ -202,6 +202,7 @@ def run_drift_search(
     data_dir: Path | None,
     root_dir: Path,
     community_level: int,
+    response_type: str,
     streaming: bool,
     query: str,
 ):
@@ -234,8 +235,33 @@ def run_drift_search(
 
     # call the Query API
     if streaming:
-        error_msg = "Streaming is not supported yet for DRIFT search."
-        raise NotImplementedError(error_msg)
+
+        async def run_streaming_search():
+            full_response = ""
+            context_data = None
+            get_context_data = True
+            async for stream_chunk in api.drift_search_streaming(
+                config=config,
+                nodes=final_nodes,
+                entities=final_entities,
+                community_reports=final_community_reports,
+                text_units=final_text_units,
+                relationships=final_relationships,
+                community_level=community_level,
+                response_type=response_type,
+                query=query,
+            ):
+                if get_context_data:
+                    context_data = stream_chunk
+                    get_context_data = False
+                else:
+                    full_response += stream_chunk
+                    print(stream_chunk, end="")  # noqa: T201
+                    sys.stdout.flush()  # flush output buffer to display text immediately
+            print()  # noqa: T201
+            return full_response, context_data
+
+        return asyncio.run(run_streaming_search())
 
     # not streaming
     response, context_data = asyncio.run(
@@ -247,6 +273,7 @@ def run_drift_search(
             text_units=final_text_units,
             relationships=final_relationships,
             community_level=community_level,
+            response_type=response_type,
            query=query,
        )
    )
@@ -281,8 +308,6 @@ def run_basic_search(
     )
     final_text_units: pd.DataFrame = dataframe_dict["create_final_text_units"]
 
-    print(streaming)  # noqa: T201
-
     # # call the Query API
     if streaming:
graphrag/config/create_graphrag_config.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -589,6 +589,7 @@ def hydrate_parallelization_params(
589589
):
590590
drift_search_model = DRIFTSearchConfig(
591591
prompt=reader.str("prompt") or None,
592+
reduce_prompt=reader.str("reduce_prompt") or None,
592593
temperature=reader.float("llm_temperature")
593594
or defs.DRIFT_SEARCH_LLM_TEMPERATURE,
594595
top_p=reader.float("llm_top_p") or defs.DRIFT_SEARCH_LLM_TOP_P,
@@ -597,6 +598,10 @@ def hydrate_parallelization_params(
597598
or defs.DRIFT_SEARCH_MAX_TOKENS,
598599
data_max_tokens=reader.int("data_max_tokens")
599600
or defs.DRIFT_SEARCH_DATA_MAX_TOKENS,
601+
reduce_max_tokens=reader.int("reduce_max_tokens")
602+
or defs.DRIFT_SEARCH_REDUCE_MAX_TOKENS,
603+
reduce_temperature=reader.float("reduce_temperature")
604+
or defs.DRIFT_SEARCH_REDUCE_LLM_TEMPERATURE,
600605
concurrency=reader.int("concurrency") or defs.DRIFT_SEARCH_CONCURRENCY,
601606
drift_k_followups=reader.int("drift_k_followups")
602607
or defs.DRIFT_SEARCH_K_FOLLOW_UPS,

graphrag/config/defaults.py

Lines changed: 3 additions & 0 deletions
@@ -149,6 +149,9 @@
 DRIFT_SEARCH_PRIMER_FOLDS = 5
 DRIFT_SEARCH_PRIMER_MAX_TOKENS = 12_000
 
+DRIFT_SEARCH_REDUCE_LLM_TEMPERATURE = 0
+DRIFT_SEARCH_REDUCE_MAX_TOKENS = 2_000
+
 DRIFT_LOCAL_SEARCH_TEXT_UNIT_PROP = 0.9
 DRIFT_LOCAL_SEARCH_COMMUNITY_PROP = 0.1
 DRIFT_LOCAL_SEARCH_TOP_K_MAPPED_ENTITIES = 10

graphrag/config/init_content.py

Lines changed: 1 addition & 0 deletions
@@ -132,6 +132,7 @@
 
 drift_search:
   prompt: "prompts/drift_search_system_prompt.txt"
+  reduce_prompt: "prompts/drift_search_reduce_prompt.txt"
 
 basic_search:
   prompt: "prompts/basic_search_system_prompt.txt"

graphrag/config/models/drift_search_config.py

Lines changed: 13 additions & 0 deletions
@@ -14,6 +14,9 @@ class DRIFTSearchConfig(BaseModel):
     prompt: str | None = Field(
         description="The drift search prompt to use.", default=None
     )
+    reduce_prompt: str | None = Field(
+        description="The drift search reduce prompt to use.", default=None
+    )
     temperature: float = Field(
         description="The temperature to use for token generation.",
         default=defs.DRIFT_SEARCH_LLM_TEMPERATURE,
@@ -35,6 +38,16 @@ class DRIFTSearchConfig(BaseModel):
         default=defs.DRIFT_SEARCH_DATA_MAX_TOKENS,
     )
 
+    reduce_max_tokens: int = Field(
+        description="The reduce llm maximum tokens response to produce.",
+        default=defs.DRIFT_SEARCH_REDUCE_MAX_TOKENS,
+    )
+
+    reduce_temperature: float = Field(
+        description="The temperature to use for token generation in reduce.",
+        default=defs.DRIFT_SEARCH_REDUCE_LLM_TEMPERATURE,
+    )
+
     concurrency: int = Field(
         description="The number of concurrent requests.",
         default=defs.DRIFT_SEARCH_CONCURRENCY,
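
A minimal sketch of setting the new reduce fields programmatically; the override values are illustrative, and anything left unset falls back to the defaults added in graphrag/config/defaults.py (DRIFT_SEARCH_REDUCE_MAX_TOKENS = 2_000, DRIFT_SEARCH_REDUCE_LLM_TEMPERATURE = 0).

from graphrag.config.models.drift_search_config import DRIFTSearchConfig

# Illustrative overrides for the new reduce-phase settings; fields left unset
# keep the defaults shown above.
drift_config = DRIFTSearchConfig(
    reduce_prompt="prompts/drift_search_reduce_prompt.txt",
    reduce_max_tokens=1_500,
    reduce_temperature=0.0,
)
print(drift_config.reduce_max_tokens, drift_config.reduce_temperature)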
