Skip to content

Commit e2a4481

Browse files
natoverse and ha2trinh authored
Fix/minor query fixes (#1893)
* fixed token count for drift search * basic search fixes * updated basic search prompt * fixed text splitting logic * Lint/format * Semver * Fix text splitting tests --------- Co-authored-by: ha2trinh <trinhha@microsoft.com>
1 parent ad4cdd6 commit e2a4481

File tree

10 files changed

+107
-30
lines changed

10 files changed

+107
-30
lines changed
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
{
2+
"type": "patch",
3+
"description": "Fixes to basic search."
4+
}

graphrag/config/defaults.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ class BasicSearchDefaults:
4242

4343
prompt: None = None
4444
k: int = 10
45+
max_context_tokens: int = 12_000
4546
chat_model_id: str = DEFAULT_CHAT_MODEL_ID
4647
embedding_model_id: str = DEFAULT_EMBEDDING_MODEL_ID
4748

graphrag/config/models/basic_search_config.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,3 +27,7 @@ class BasicSearchConfig(BaseModel):
2727
description="The number of text units to include in search context.",
2828
default=graphrag_config_defaults.basic_search.k,
2929
)
30+
max_context_tokens: int = Field(
31+
description="The maximum tokens.",
32+
default=graphrag_config_defaults.basic_search.max_context_tokens,
33+
)

graphrag/index/text_splitting/text_splitting.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,8 @@ def split_single_text_on_tokens(text: str, tokenizer: Tokenizer) -> list[str]:
152152
while start_idx < len(input_ids):
153153
chunk_text = tokenizer.decode(list(chunk_ids))
154154
result.append(chunk_text) # Append chunked text as string
155+
if cur_idx == len(input_ids):
156+
break
155157
start_idx += tokenizer.tokens_per_chunk - tokenizer.chunk_overlap
156158
cur_idx = min(start_idx + tokenizer.tokens_per_chunk, len(input_ids))
157159
chunk_ids = input_ids[start_idx:cur_idx]
@@ -186,6 +188,8 @@ def split_multiple_texts_on_tokens(
186188
chunk_text = tokenizer.decode([id for _, id in chunk_ids])
187189
doc_indices = list({doc_idx for doc_idx, _ in chunk_ids})
188190
result.append(TextChunk(chunk_text, doc_indices, len(chunk_ids)))
191+
if cur_idx == len(input_ids):
192+
break
189193
start_idx += tokenizer.tokens_per_chunk - tokenizer.chunk_overlap
190194
cur_idx = min(start_idx + tokenizer.tokens_per_chunk, len(input_ids))
191195
chunk_ids = input_ids[start_idx:cur_idx]

graphrag/prompts/query/basic_search_system_prompt.py

Lines changed: 17 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -11,23 +11,25 @@
1111
1212
---Goal---
1313
14-
Generate a response of the target length and format that responds to the user's question, summarizing all information in the input data tables appropriate for the response length and format, and incorporating any relevant general knowledge.
14+
Generate a response of the target length and format that responds to the user's question, summarizing all relevant information in the input data tables appropriate for the response length and format.
1515
16-
If you don't know the answer, just say so. Do not make anything up.
16+
You should use the data provided in the data tables below as the primary context for generating the response.
17+
18+
If you don't know the answer or if the input data tables do not contain sufficient information to provide an answer, just say so. Do not make anything up.
1719
1820
Points supported by data should list their data references as follows:
1921
20-
"This is an example sentence supported by multiple text references [Data: Sources (record ids)]."
22+
"This is an example sentence supported by multiple data references [Data: Sources (record ids)]."
2123
2224
Do not list more than 5 record ids in a single reference. Instead, list the top 5 most relevant record ids and add "+more" to indicate that there are more.
2325
2426
For example:
2527
26-
"Person X is the owner of Company Y and subject to many allegations of wrongdoing [Data: Sources (15, 16)]."
28+
"Person X is the owner of Company Y and subject to many allegations of wrongdoing [Data: Sources (2, 7, 64, 46, 34, +more)]. He is also CEO of company X [Data: Sources (1, 3)]"
2729
28-
where 15 and 16 represent the id (not the index) of the relevant data record.
30+
where 1, 2, 3, 7, 34, 46, and 64 represent the source id taken from the "source_id" column in the provided tables.
2931
30-
Do not include information where the supporting text for it is not provided.
32+
Do not include information where the supporting evidence for it is not provided.
3133
3234
3335
---Target response length and format---
@@ -42,23 +44,26 @@
4244
4345
---Goal---
4446
45-
Generate a response of the target length and format that responds to the user's question, summarizing all information in the input data tables appropriate for the response length and format, and incorporating any relevant general knowledge.
47+
Generate a response of the target length and format that responds to the user's question, summarizing all relevant information in the input data appropriate for the response length and format.
48+
49+
You should use the data provided in the data tables below as the primary context for generating the response.
4650
47-
If you don't know the answer, just say so. Do not make anything up.
51+
If you don't know the answer or if the input data tables do not contain sufficient information to provide an answer, just say so. Do not make anything up.
4852
4953
Points supported by data should list their data references as follows:
5054
51-
"This is an example sentence supported by multiple text references [Data: Sources (record ids)]."
55+
"This is an example sentence supported by multiple data references [Data: Sources (record ids)]."
5256
5357
Do not list more than 5 record ids in a single reference. Instead, list the top 5 most relevant record ids and add "+more" to indicate that there are more.
5458
5559
For example:
5660
57-
"Person X is the owner of Company Y and subject to many allegations of wrongdoing [Data: Sources (15, 16)]."
61+
"Person X is the owner of Company Y and subject to many allegations of wrongdoing [Data: Sources (2, 7, 64, 46, 34, +more)]. He is also CEO of company X [Data: Sources (1, 3)]"
62+
63+
where 1, 2, 3, 7, 34, 46, and 64 represent the source id taken from the "source_id" column in the provided tables.
5864
59-
where 15 and 16 represent the id (not the index) of the relevant data record.
65+
Do not include information where the supporting evidence for it is not provided.
6066
61-
Do not include information where the supporting text for it is not provided.
6267
6368
---Target response length and format---
6469

graphrag/query/factory.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -275,6 +275,7 @@ def get_basic_search_engine(
275275
text_unit_embeddings: BaseVectorStore,
276276
config: GraphRagConfig,
277277
system_prompt: str | None = None,
278+
response_type: str = "multiple paragraphs",
278279
callbacks: list[QueryCallbacks] | None = None,
279280
) -> BasicSearch:
280281
"""Create a basic search engine based on data + configuration."""
@@ -312,6 +313,7 @@ def get_basic_search_engine(
312313
return BasicSearch(
313314
model=chat_model,
314315
system_prompt=system_prompt,
316+
response_type=response_type,
315317
context_builder=BasicSearchContext(
316318
text_embedder=embedding_model,
317319
text_unit_embeddings=text_unit_embeddings,
@@ -323,6 +325,7 @@ def get_basic_search_engine(
323325
context_builder_params={
324326
"embedding_vectorstore_key": "id",
325327
"k": bs_config.k,
328+
"max_context_tokens": bs_config.max_context_tokens,
326329
},
327330
callbacks=callbacks,
328331
)

graphrag/query/structured_search/basic_search/basic_context.py

Lines changed: 67 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,9 @@
33

44
"""Basic Context Builder implementation."""
55

6+
import logging
7+
from typing import cast
8+
69
import pandas as pd
710
import tiktoken
811

@@ -13,8 +16,11 @@
1316
ContextBuilderResult,
1417
)
1518
from graphrag.query.context_builder.conversation_history import ConversationHistory
19+
from graphrag.query.llm.text_utils import num_tokens
1620
from graphrag.vector_stores.base import BaseVectorStore
1721

22+
log = logging.getLogger(__name__)
23+
1824

1925
class BasicSearchContext(BasicContextBuilder):
2026
"""Class representing the Basic Search Context Builder."""
@@ -32,30 +38,76 @@ def __init__(
3238
self.text_units = text_units
3339
self.text_unit_embeddings = text_unit_embeddings
3440
self.embedding_vectorstore_key = embedding_vectorstore_key
41+
self.text_id_map = self._map_ids()
3542

3643
def build_context(
3744
self,
3845
query: str,
3946
conversation_history: ConversationHistory | None = None,
47+
k: int = 10,
48+
max_context_tokens: int = 12_000,
49+
context_name: str = "Sources",
50+
column_delimiter: str = "|",
51+
text_id_col: str = "source_id",
52+
text_col: str = "text",
4053
**kwargs,
4154
) -> ContextBuilderResult:
42-
"""Build the context for the local search mode."""
43-
search_results = self.text_unit_embeddings.similarity_search_by_text(
44-
text=query,
45-
text_embedder=lambda t: self.text_embedder.embed(t),
46-
k=kwargs.get("k", 10),
55+
"""Build the context for the basic search mode."""
56+
if query != "":
57+
related_texts = self.text_unit_embeddings.similarity_search_by_text(
58+
text=query,
59+
text_embedder=lambda t: self.text_embedder.embed(t),
60+
k=k,
61+
)
62+
related_text_list = [
63+
{
64+
text_id_col: self.text_id_map[f"{chunk.document.id}"],
65+
text_col: chunk.document.text,
66+
}
67+
for chunk in related_texts
68+
]
69+
related_text_df = pd.DataFrame(related_text_list)
70+
else:
71+
related_text_df = pd.DataFrame({
72+
text_id_col: [],
73+
text_col: [],
74+
})
75+
76+
# add these related text chunks into context until we fill up the context window
77+
current_tokens = 0
78+
text_ids = []
79+
current_tokens = num_tokens(
80+
text_id_col + column_delimiter + text_col + "\n", self.token_encoder
4781
)
48-
# we don't have a friendly id on text_units, so just copy the index
49-
sources = [
50-
{"id": str(search_results.index(r)), "text": r.document.text}
51-
for r in search_results
52-
]
53-
# make a delimited table for the context; this imitates graphrag context building
54-
table = ["id|text"] + [f"{s['id']}|{s['text']}" for s in sources]
82+
for i, row in related_text_df.iterrows():
83+
text = row[text_id_col] + column_delimiter + row[text_col] + "\n"
84+
tokens = num_tokens(text, self.token_encoder)
85+
if current_tokens + tokens > max_context_tokens:
86+
msg = f"Reached token limit: {current_tokens + tokens}. Reverting to previous context state"
87+
log.info(msg)
88+
break
5589

56-
columns = pd.Index(["id", "text"])
90+
current_tokens += tokens
91+
text_ids.append(i)
92+
final_text_df = cast(
93+
"pd.DataFrame",
94+
related_text_df[related_text_df.index.isin(text_ids)].reset_index(
95+
drop=True
96+
),
97+
)
98+
final_text = final_text_df.to_csv(
99+
index=False, escapechar="\\", sep=column_delimiter
100+
)
57101

58102
return ContextBuilderResult(
59-
context_chunks="\n\n".join(table),
60-
context_records={"sources": pd.DataFrame(sources, columns=columns)},
103+
context_chunks=final_text,
104+
context_records={context_name: final_text_df},
61105
)
106+
107+
def _map_ids(self) -> dict[str, str]:
108+
"""Map id to short id in the text units."""
109+
id_map = {}
110+
text_units = self.text_units or []
111+
for unit in text_units:
112+
id_map[unit.id] = unit.short_id
113+
return id_map

graphrag/query/structured_search/basic_search/search.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,9 @@ async def search(
108108
llm_calls=1,
109109
prompt_tokens=num_tokens(search_prompt, self.token_encoder),
110110
output_tokens=sum(output_tokens.values()),
111+
llm_calls_categories=llm_calls,
112+
prompt_tokens_categories=prompt_tokens,
113+
output_tokens_categories=output_tokens,
111114
)
112115

113116
except Exception:
@@ -120,6 +123,9 @@ async def search(
120123
llm_calls=1,
121124
prompt_tokens=num_tokens(search_prompt, self.token_encoder),
122125
output_tokens=0,
126+
llm_calls_categories=llm_calls,
127+
prompt_tokens_categories=prompt_tokens,
128+
output_tokens_categories=output_tokens,
123129
)
124130

125131
async def stream_search(

graphrag/query/structured_search/drift_search/search.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -213,7 +213,7 @@ async def search(
213213
primer_context, token_ct = await self.context_builder.build_context(query)
214214
llm_calls["build_context"] = token_ct["llm_calls"]
215215
prompt_tokens["build_context"] = token_ct["prompt_tokens"]
216-
output_tokens["build_context"] = token_ct["prompt_tokens"]
216+
output_tokens["build_context"] = token_ct["output_tokens"]
217217

218218
primer_response = await self.primer.search(
219219
query=query, top_k_reports=primer_context

tests/unit/indexing/text_splitting/test_text_splitting.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,6 @@ def test_split_single_text_on_tokens():
136136
" by this t",
137137
"his test o",
138138
"est only.",
139-
"nly.",
140139
]
141140

142141
result = split_single_text_on_tokens(text=text, tokenizer=tokenizer)
@@ -197,7 +196,6 @@ def decode(tokens: list[int]) -> str:
197196
" this test",
198197
" test only",
199198
" only.",
200-
".",
201199
]
202200

203201
result = split_single_text_on_tokens(text=text, tokenizer=tokenizer)

0 commit comments

Comments (0)