
Commit a113821

Fix/few shot selection (#530)
* try to always use at least 3 few shot examples
* add args for auto tune
* use content-based KNN to select most relevant chunks
* enforce at least 3 few shot examples for generated prompts
* utils for content-based KNN
* sem version
* fix callback arg
* fixes
* switch back to no op callbacks
* make n few shot, user controlled. default to 2
1 parent 2ddee65 commit a113821

File tree

5 files changed: +112 -12 lines changed

Lines changed: 4 additions & 0 deletions

@@ -0,0 +1,4 @@
+{
+    "type": "minor",
+    "description": "Add content-based KNN for selecting prompt tune few shot examples"
+}

graphrag/prompt_tune/__main__.py

Lines changed: 31 additions & 3 deletions
@@ -10,7 +10,7 @@
 from graphrag.prompt_tune.generator import MAX_TOKEN_COUNT
 from graphrag.prompt_tune.loader import MIN_CHUNK_SIZE

-from .cli import fine_tune
+from .cli import prompt_tune


 class DocSelectionType(Enum):
@@ -19,6 +19,7 @@ class DocSelectionType(Enum):
     ALL = "all"
     RANDOM = "random"
     TOP = "top"
+    AUTO = "auto"

     def __str__(self):
         """Return the string representation of the enum value."""
@@ -46,13 +47,29 @@ def __str__(self):

     parser.add_argument(
         "--method",
-        help="The method to select documents, one of: all, random or top",
+        help="The method to select documents, one of: all, random, top or auto",
         required=False,
         type=DocSelectionType,
         choices=list(DocSelectionType),
         default=DocSelectionType.RANDOM,
     )

+    parser.add_argument(
+        "--n_subset_max",
+        help="The number of text chunks to embed when using auto selection method",
+        required=False,
+        type=int,
+        default=300,
+    )
+
+    parser.add_argument(
+        "--k",
+        help="The maximum number of documents to select from each centroid when using auto selection method",
+        required=False,
+        type=int,
+        default=15,
+    )
+
     parser.add_argument(
         "--limit",
         help="The limit of files to load when doing random or top selection",
@@ -69,6 +86,14 @@ def __str__(self):
         default=MAX_TOKEN_COUNT,
     )

+    parser.add_argument(
+        "--min-examples-required",
+        help="The minimum number of examples required in entity extraction prompt",
+        type=int,
+        required=False,
+        default=2,
+    )
+
     parser.add_argument(
         "--chunk-size",
         help="Max token count for prompt generation",
@@ -106,7 +131,7 @@ def __str__(self):
     loop = asyncio.get_event_loop()

     loop.run_until_complete(
-        fine_tune(
+        prompt_tune(
             args.root,
             args.domain,
             str(args.method),
@@ -116,5 +141,8 @@ def __str__(self):
             args.language,
             args.no_entity_types,
             args.output,
+            args.n_subset_max,
+            args.k,
+            args.min_examples_required,
         )
     )
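
With these flags wired up, the auto path can be exercised from the command line. A hedged usage sketch: --method, --n_subset_max, --k, and --min-examples-required come from the argparse definitions above, while the module invocation path and the --root/--domain flags are assumptions for illustration:

    python -m graphrag.prompt_tune --root ./ragtest --domain "environmental news" --method auto --n_subset_max 300 --k 15 --min-examples-required 3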

graphrag/prompt_tune/cli.py

Lines changed: 25 additions & 6 deletions
@@ -32,7 +32,7 @@
 )


-async def fine_tune(
+async def prompt_tune(
     root: str,
     domain: str,
     select: str = "random",
@@ -42,8 +42,11 @@ async def fine_tune(
     language: str | None = None,
     skip_entity_types: bool = False,
     output: str = "prompts",
+    n_subset_max: int = 300,
+    k: int = 15,
+    min_examples_required: int = 2,
 ):
-    """Fine tune the model.
+    """Prompt tune the model.

     Parameters
     ----------
@@ -55,11 +58,13 @@ async def fine_tune(
     - chunk_size: The chunk token size to use.
     - skip_entity_types: Skip generating entity types.
     - output: The output folder to store the prompts.
+    - n_subset_max: The number of text chunks to embed when using auto selection method.
+    - k: The number of documents to select when using auto selection method.
     """
     reporter = PrintProgressReporter("")
     config = read_config_parameters(root, reporter)

-    await fine_tune_with_config(
+    await prompt_tune_with_config(
         root,
         config,
         domain,
@@ -71,10 +76,13 @@ async def fine_tune(
         skip_entity_types,
         output,
         reporter,
+        n_subset_max,
+        k,
+        min_examples_required,
     )


-async def fine_tune_with_config(
+async def prompt_tune_with_config(
     root: str,
     config: GraphRagConfig,
     domain: str,
@@ -86,8 +94,11 @@ async def fine_tune_with_config(
     skip_entity_types: bool = False,
     output: str = "prompts",
     reporter: ProgressReporter | None = None,
+    n_subset_max: int = 300,
+    k: int = 15,
+    min_examples_required: int = 2,
 ):
-    """Fine tune the model with a configuration.
+    """Prompt tune the model with a configuration.

     Parameters
     ----------
@@ -101,6 +112,8 @@ async def fine_tune_with_config(
     - skip_entity_types: Skip generating entity types.
     - output: The output folder to store the prompts.
     - reporter: The progress reporter.
+    - n_subset_max: The number of text chunks to embed when using auto selection method.
+    - k: The number of documents to select when using auto selection method.

     Returns
     -------
@@ -118,11 +131,13 @@ async def fine_tune_with_config(
         select_method=select,
         reporter=reporter,
         chunk_size=chunk_size,
+        n_subset_max=n_subset_max,
+        k=k,
     )

     # Create LLM from config
     llm = load_llm(
-        "fine_tuning",
+        "prompt_tuning",
         config.llm.type,
         NoopVerbCallbacks(),
         None,
@@ -139,6 +154,7 @@ async def fine_tune_with_config(
         language,
         max_tokens,
         skip_entity_types,
+        min_examples_required,
     )


@@ -152,6 +168,7 @@ async def generate_indexing_prompts(
     language: str | None = None,
     max_tokens: int = MAX_TOKEN_COUNT,
     skip_entity_types: bool = False,
+    min_examples_required: int = 2,
 ):
     """Generate indexing prompts.

@@ -165,6 +182,7 @@ async def generate_indexing_prompts(
     - domain: The domain to map the input documents to.
     - max_tokens: The maximum number of tokens to use on entity extraction prompts
     - skip_entity_types: Skip generating entity types.
+    - min_examples_required: The minimum number of examples required for entity extraction prompts.
     """
     if not domain:
         reporter.info("Generating domain...")
@@ -221,6 +239,7 @@ async def generate_indexing_prompts(
         output_path=output_path,
         encoding_model=config.encoding_model,
         max_token_count=max_tokens,
+        min_examples_required=min_examples_required,
     )
     reporter.info(f"Generated entity extraction prompt, stored in folder {output_path}")
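
With fine_tune renamed and the three new parameters threaded end to end, the entry point can also be driven directly from Python. A minimal sketch, assuming only the signature shown in this diff (the root path, domain, and values are illustrative):

    import asyncio

    from graphrag.prompt_tune.cli import prompt_tune

    # select="auto" routes chunk selection through the content-based KNN path;
    # n_subset_max bounds how many chunks get embedded, k caps how many are kept,
    # and min_examples_required sets the few-shot floor for the generated prompt.
    asyncio.run(
        prompt_tune(
            "./ragtest",           # root: project directory (illustrative path)
            "environmental news",  # domain (illustrative)
            select="auto",
            n_subset_max=300,
            k=15,
            min_examples_required=3,
        )
    )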

graphrag/prompt_tune/generator/entity_extraction_prompt.py

Lines changed: 4 additions & 2 deletions
@@ -27,6 +27,7 @@ def create_entity_extraction_prompt(
     encoding_model: str = defs.ENCODING_MODEL,
     json_mode: bool = False,
     output_path: Path | None = None,
+    min_examples_required: int = 2,
 ) -> str:
     """
     Create a prompt for entity extraction.
@@ -41,6 +42,7 @@
     - max_token_count (int): The maximum number of tokens to use for the prompt
     - json_mode (bool): Whether to use JSON mode for the prompt. Default is False
     - output_path (Path | None): The path to write the prompt to. Default is None. If None, the prompt is not written to a file. Default is None.
+    - min_examples_required (int): The minimum number of examples required. Default is 2.

     Returns
     -------
@@ -79,8 +81,8 @@

         example_tokens = num_tokens_from_string(example_formatted, model=encoding_model)

-        # Squeeze in at least one example
-        if i > 0 and example_tokens > tokens_left:
+        # Ensure at least min_examples_required examples are included
+        if i >= min_examples_required and example_tokens > tokens_left:
             break

         examples_prompt += example_formatted
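
The hunk above is the core behavioral change: the example-packing loop may now only stop once min_examples_required examples are in, instead of guaranteeing just one. A self-contained sketch of that selection logic (pack_examples is a hypothetical stand-in, and len() substitutes for the real num_tokens_from_string tokenizer):

    def pack_examples(
        formatted_examples: list[str],
        tokens_left: int,
        min_examples_required: int = 2,
    ) -> str:
        """Append examples until the token budget runs out, keeping a floor."""
        prompt = ""
        for i, example in enumerate(formatted_examples):
            example_tokens = len(example)  # stand-in tokenizer
            # Break only once the minimum count is met AND the next example
            # would not fit in the remaining budget.
            if i >= min_examples_required and example_tokens > tokens_left:
                break
            prompt += example
            tokens_left -= example_tokens
        return prompt

    # With a tight budget, the floor still forces the first two examples in.
    print(pack_examples(["aaaa", "bbbb", "cccc"], tokens_left=3))  # "aaaabbbb"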

graphrag/prompt_tune/loader/input.py

Lines changed: 48 additions & 1 deletion
@@ -5,16 +5,45 @@

 from typing import cast

+import numpy as np
 import pandas as pd
 from datashaper import NoopVerbCallbacks, TableContainer, VerbInput

 from graphrag.config.models.graph_rag_config import GraphRagConfig
 from graphrag.index.input import load_input
+from graphrag.index.llm import load_llm_embeddings
 from graphrag.index.progress.types import ProgressReporter
 from graphrag.index.verbs import chunk
+from graphrag.llm.types.llm_types import EmbeddingLLM

-MIN_CHUNK_SIZE = 200
 MIN_CHUNK_OVERLAP = 0
+MIN_CHUNK_SIZE = 200
+N_SUBSET_MAX = 300
+K = 15
+
+
+async def _embed_chunks(
+    text_chunks: pd.DataFrame,
+    embedding_llm: EmbeddingLLM,
+    n_subset_max: int = N_SUBSET_MAX,
+) -> tuple[pd.DataFrame, np.ndarray]:
+    """Convert text chunks into dense text embeddings."""
+    sampled_text_chunks = text_chunks.sample(n=min(n_subset_max, len(text_chunks)))
+    embeddings = await embedding_llm(sampled_text_chunks["chunks"].tolist())
+    return text_chunks, np.array(embeddings.output)
+
+
+def _sample_chunks_from_embeddings(
+    text_chunks: pd.DataFrame,
+    embeddings,
+    k: int = K,
+) -> pd.DataFrame:
+    """Sample text chunks from embeddings."""
+    center = np.mean(embeddings, axis=0)
+    distances = np.linalg.norm(embeddings - center, axis=1)
+    nearest_indices = np.argsort(distances)[:k]
+
+    return text_chunks.iloc[nearest_indices]


 async def load_docs_in_chunks(
@@ -24,6 +53,8 @@ async def load_docs_in_chunks(
     limit: int,
     reporter: ProgressReporter,
     chunk_size: int = MIN_CHUNK_SIZE,
+    n_subset_max: int = N_SUBSET_MAX,
+    k: int = K,
 ) -> list[str]:
     """Load docs into chunks for generating prompts."""
     dataset = await load_input(config.input, reporter, root)
@@ -57,6 +88,22 @@ async def load_docs_in_chunks(
         chunks_df = chunks_df[:limit]
     elif select_method == "random":
         chunks_df = chunks_df.sample(n=limit)
+    elif select_method == "auto":
+        if k is None or k <= 0:
+            msg = "k must be an integer > 0"
+            raise ValueError(msg)
+        embedding_llm = load_llm_embeddings(
+            name="prompt_tuning_embeddings",
+            llm_type=config.embeddings.resolved_strategy()["llm"]["type"],
+            callbacks=NoopVerbCallbacks(),
+            cache=None,
+            llm_config=config.embeddings.resolved_strategy()["llm"],
+        )
+
+        chunks_df, embeddings = await _embed_chunks(
+            chunks_df, embedding_llm, n_subset_max=n_subset_max
+        )
+        chunks_df = _sample_chunks_from_embeddings(chunks_df, embeddings, k=k)

     # Convert the dataset to list form, so we have a list of documents
     return chunks_df["chunks"].tolist()
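
The two helpers above implement the "auto" selection: embed a random sample of up to n_subset_max chunks, then keep the k chunks whose embeddings lie nearest (by Euclidean distance) to the mean embedding, preferring central, representative chunks over a uniform random draw. A self-contained sketch of that distance math on synthetic data (no embedding LLM involved; the names here are illustrative, not the module's API):

    import numpy as np
    import pandas as pd

    rng = np.random.default_rng(0)

    # Ten fake chunks with fake 4-dimensional embeddings.
    chunks = pd.DataFrame({"chunks": [f"chunk {i}" for i in range(10)]})
    embeddings = rng.normal(size=(10, 4))

    # Same steps as _sample_chunks_from_embeddings: centroid, distances,
    # then the k nearest rows.
    center = np.mean(embeddings, axis=0)
    distances = np.linalg.norm(embeddings - center, axis=1)
    nearest_indices = np.argsort(distances)[:3]  # k = 3 for this toy run

    print(chunks.iloc[nearest_indices])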
