
Commit f7cd155

Fix/encoding model config (#1527)

* fix: include encoding_model option when initializing LLMParameters
* chore: add semver patch description
* Fix encoding model parsing
* Fix unit tests

Co-authored-by: Nico Reinartz <nico.reinartz@rwth-aachen.de>

1 parent 329b83c commit f7cd155

File tree

6 files changed: +38 -11 lines changed
Lines changed: 4 additions & 0 deletions

@@ -0,0 +1,4 @@
+{
+    "type": "patch",
+    "description": "Respect encoding_model option"
+}
Lines changed: 4 additions & 0 deletions

@@ -0,0 +1,4 @@
+{
+    "type": "patch",
+    "description": "Fix encoding model config parsing"
+}
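The two JSON blocks above are semversioner changeset files (their paths are not preserved in this extract). Assuming the project follows the usual semversioner workflow, each one is generated with a command along the lines of "semversioner add-change --type patch --description ..." and is consumed at release time to bump the patch version and assemble the changelog.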

graphrag/config/create_graphrag_config.py

Lines changed: 24 additions & 5 deletions
@@ -85,6 +85,7 @@ def hydrate_llm_params(
         deployment_name = (
             reader.str(Fragment.deployment_name) or base.deployment_name
         )
+        encoding_model = reader.str(Fragment.encoding_model) or base.encoding_model

         if api_key is None and not _is_azure(llm_type):
             raise ApiKeyMissingError
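The added line follows the fallback idiom used throughout this file: reader.str returns None when a key is absent, so "or" falls back to the base LLM parameters. A minimal sketch of the semantics (the function name is a stand-in, not the real reader API); note that "or" also treats an empty string as unset:

def pick(section_value: str | None, base_value: str | None) -> str | None:
    # Mirrors `reader.str(Fragment.encoding_model) or base.encoding_model`
    return section_value or base_value

assert pick(None, "cl100k_base") == "cl100k_base"         # absent key: inherit from base
assert pick("o200k_base", "cl100k_base") == "o200k_base"  # explicit value wins
assert pick("", "cl100k_base") == "cl100k_base"           # empty string also inherits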
@@ -106,6 +107,7 @@ def hydrate_llm_params(
         organization=reader.str("organization") or base.organization,
         proxy=reader.str("proxy") or base.proxy,
         model=reader.str("model") or base.model,
+        encoding_model=encoding_model,
         max_tokens=reader.int(Fragment.max_tokens) or base.max_tokens,
         temperature=reader.float(Fragment.temperature) or base.temperature,
         top_p=reader.float(Fragment.top_p) or base.top_p,
@@ -155,6 +157,7 @@ def hydrate_embeddings_params(
         api_proxy = reader.str("proxy") or base.proxy
         audience = reader.str(Fragment.audience) or base.audience
         deployment_name = reader.str(Fragment.deployment_name)
+        encoding_model = reader.str(Fragment.encoding_model) or base.encoding_model

         if api_key is None and not _is_azure(api_type):
             raise ApiKeyMissingError(embedding=True)
@@ -176,6 +179,7 @@ def hydrate_embeddings_params(
         organization=api_organization,
         proxy=api_proxy,
         model=reader.str(Fragment.model) or defs.EMBEDDING_MODEL,
+        encoding_model=encoding_model,
         request_timeout=reader.float(Fragment.request_timeout)
         or defs.LLM_REQUEST_TIMEOUT,
         audience=audience,
@@ -217,6 +221,9 @@ def hydrate_parallelization_params(
     fallback_oai_base = reader.str(Fragment.api_base) or fallback_oai_base
     fallback_oai_version = reader.str(Fragment.api_version) or fallback_oai_version
     fallback_oai_proxy = reader.str(Fragment.api_proxy)
+    global_encoding_model = (
+        reader.str(Fragment.encoding_model) or defs.ENCODING_MODEL
+    )

     with reader.envvar_prefix(Section.llm):
         with reader.use(values.get("llm")):
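Combining that idiom across levels gives the resolution order this hunk introduces: a per-section encoding_model wins over the root-level one, which wins over the library default. A sketch under the assumption that defs.ENCODING_MODEL is the tiktoken default ("cl100k_base" is illustrative, not confirmed by this diff):

DEFAULT_ENCODING_MODEL = "cl100k_base"  # assumed stand-in for defs.ENCODING_MODEL

def resolve(section: str | None, root: str | None) -> str:
    # global_encoding_model = reader.str(Fragment.encoding_model) or defs.ENCODING_MODEL
    global_encoding_model = root or DEFAULT_ENCODING_MODEL
    # per-section: reader.str(Fragment.encoding_model) or global_encoding_model
    return section or global_encoding_model

assert resolve(None, None) == "cl100k_base"               # nothing set: library default
assert resolve(None, "o200k_base") == "o200k_base"        # root-level setting applies everywhere
assert resolve("p50k_base", "o200k_base") == "p50k_base"  # section override wins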
@@ -231,6 +238,9 @@ def hydrate_parallelization_params(
             api_proxy = reader.str(Fragment.api_proxy) or fallback_oai_proxy
             audience = reader.str(Fragment.audience)
             deployment_name = reader.str(Fragment.deployment_name)
+            encoding_model = (
+                reader.str(Fragment.encoding_model) or global_encoding_model
+            )

             if api_key is None and not _is_azure(llm_type):
                 raise ApiKeyMissingError
@@ -252,6 +262,7 @@ def hydrate_parallelization_params(
                 proxy=api_proxy,
                 type=llm_type,
                 model=reader.str(Fragment.model) or defs.LLM_MODEL,
+                encoding_model=encoding_model,
                 max_tokens=reader.int(Fragment.max_tokens) or defs.LLM_MAX_TOKENS,
                 temperature=reader.float(Fragment.temperature)
                 or defs.LLM_TEMPERATURE,
@@ -396,12 +407,15 @@ def hydrate_parallelization_params(
         group_by_columns = reader.list("group_by_columns", "BY_COLUMNS")
         if group_by_columns is None:
             group_by_columns = defs.CHUNK_GROUP_BY_COLUMNS
+        encoding_model = (
+            reader.str(Fragment.encoding_model) or global_encoding_model
+        )

         chunks_model = ChunkingConfig(
             size=reader.int("size") or defs.CHUNK_SIZE,
             overlap=reader.int("overlap") or defs.CHUNK_OVERLAP,
             group_by_columns=group_by_columns,
-            encoding_model=reader.str(Fragment.encoding_model),
+            encoding_model=encoding_model,
         )
     with (
         reader.envvar_prefix(Section.snapshot),
@@ -428,6 +442,9 @@ def hydrate_parallelization_params(
             if max_gleanings is not None
             else defs.ENTITY_EXTRACTION_MAX_GLEANINGS
         )
+        encoding_model = (
+            reader.str(Fragment.encoding_model) or global_encoding_model
+        )

         entity_extraction_model = EntityExtractionConfig(
             llm=hydrate_llm_params(entity_extraction_config, llm_model),
@@ -440,7 +457,7 @@ def hydrate_parallelization_params(
             max_gleanings=max_gleanings,
             prompt=reader.str("prompt", Fragment.prompt_file),
             strategy=entity_extraction_config.get("strategy"),
-            encoding_model=reader.str(Fragment.encoding_model),
+            encoding_model=encoding_model,
         )

     claim_extraction_config = values.get("claim_extraction") or {}
@@ -452,6 +469,9 @@ def hydrate_parallelization_params(
         max_gleanings = (
             max_gleanings if max_gleanings is not None else defs.CLAIM_MAX_GLEANINGS
         )
+        encoding_model = (
+            reader.str(Fragment.encoding_model) or global_encoding_model
+        )
         claim_extraction_model = ClaimExtractionConfig(
             enabled=reader.bool(Fragment.enabled) or defs.CLAIM_EXTRACTION_ENABLED,
             llm=hydrate_llm_params(claim_extraction_config, llm_model),
@@ -462,7 +482,7 @@ def hydrate_parallelization_params(
             description=reader.str("description") or defs.CLAIM_DESCRIPTION,
             prompt=reader.str("prompt", Fragment.prompt_file),
             max_gleanings=max_gleanings,
-            encoding_model=reader.str(Fragment.encoding_model),
+            encoding_model=encoding_model,
         )

     community_report_config = values.get("community_reports") or {}
@@ -603,7 +623,6 @@ def hydrate_parallelization_params(
             or defs.DRIFT_LOCAL_SEARCH_LLM_MAX_TOKENS,
         )

-        encoding_model = reader.str(Fragment.encoding_model) or defs.ENCODING_MODEL
         skip_workflows = reader.list("skip_workflows") or []

     return GraphRagConfig(
@@ -626,7 +645,7 @@ def hydrate_parallelization_params(
         summarize_descriptions=summarize_descriptions_model,
         umap=umap_model,
         cluster_graph=cluster_graph_model,
-        encoding_model=encoding_model,
+        encoding_model=global_encoding_model,
         skip_workflows=skip_workflows,
         local_search=local_search_model,
         global_search=global_search_model,
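Taken together, the changes make one root-level encoding_model flow into every sub-config unless a section overrides it. A hedged end-to-end sketch (the dict shape mirrors the sections touched in this diff and is an assumption rather than documented API; an api_key is supplied because hydration raises ApiKeyMissingError without one, and the import path assumes the function is exported from graphrag.config as in recent releases):

from graphrag.config import create_graphrag_config

config = create_graphrag_config({
    "encoding_model": "o200k_base",               # global fallback
    "llm": {"api_key": "sk-placeholder", "model": "gpt-4o"},
    "chunks": {"encoding_model": "cl100k_base"},  # per-section override
})

assert config.encoding_model == "o200k_base"                    # GraphRagConfig keeps the global value
assert config.chunks.encoding_model == "cl100k_base"            # override respected
assert config.entity_extraction.encoding_model == "o200k_base"  # falls back to the global value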

graphrag/config/models/chunking_config.py

Lines changed: 2 additions & 2 deletions
@@ -27,7 +27,7 @@ class ChunkingConfig(BaseModel):
         default=None, description="The encoding model to use."
     )

-    def resolved_strategy(self, encoding_model: str) -> dict:
+    def resolved_strategy(self, encoding_model: str | None) -> dict:
         """Get the resolved chunking strategy."""
         from graphrag.index.operations.chunk_text import ChunkStrategyType

@@ -36,5 +36,5 @@ def resolved_strategy(self, encoding_model: str) -> dict:
             "chunk_size": self.size,
             "chunk_overlap": self.overlap,
             "group_by_columns": self.group_by_columns,
-            "encoding_name": self.encoding_model or encoding_model,
+            "encoding_name": encoding_model or self.encoding_model,
         }
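Note the swapped precedence in resolved_strategy: previously the section attribute won over the passed argument; now the argument (already resolved with the global fallback at config-creation time) wins. A minimal sketch of just that line's behavior:

def encoding_name_before(self_value: str | None, arg: str | None) -> str | None:
    return self_value or arg   # old: `self.encoding_model or encoding_model`

def encoding_name_after(self_value: str | None, arg: str | None) -> str | None:
    return arg or self_value   # new: `encoding_model or self.encoding_model`

# When both are set, the resolution flips:
assert encoding_name_before("cl100k_base", "o200k_base") == "cl100k_base"
assert encoding_name_after("cl100k_base", "o200k_base") == "o200k_base"

The same swap is applied to ClaimExtractionConfig and EntityExtractionConfig below.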

graphrag/config/models/claim_extraction_config.py

Lines changed: 2 additions & 2 deletions
@@ -35,7 +35,7 @@ class ClaimExtractionConfig(LLMConfig):
         default=None, description="The encoding model to use."
     )

-    def resolved_strategy(self, root_dir: str, encoding_model: str) -> dict:
+    def resolved_strategy(self, root_dir: str, encoding_model: str | None) -> dict:
         """Get the resolved claim extraction strategy."""
         from graphrag.index.operations.extract_covariates import (
             ExtractClaimsStrategyType,

@@ -52,5 +52,5 @@ def resolved_strategy(self, root_dir: str, encoding_model: str) -> dict:
             else None,
             "claim_description": self.description,
             "max_gleanings": self.max_gleanings,
-            "encoding_name": self.encoding_model or encoding_model,
+            "encoding_name": encoding_model or self.encoding_model,
         }

graphrag/config/models/entity_extraction_config.py

Lines changed: 2 additions & 2 deletions
@@ -32,7 +32,7 @@ class EntityExtractionConfig(LLMConfig):
         default=None, description="The encoding model to use."
     )

-    def resolved_strategy(self, root_dir: str, encoding_model: str) -> dict:
+    def resolved_strategy(self, root_dir: str, encoding_model: str | None) -> dict:
         """Get the resolved entity extraction strategy."""
         from graphrag.index.operations.extract_entities import (
             ExtractEntityStrategyType,

@@ -49,6 +49,6 @@ def resolved_strategy(self, root_dir: str, encoding_model: str) -> dict:
             else None,
             "max_gleanings": self.max_gleanings,
             # It's prechunked in create_base_text_units
-            "encoding_name": self.encoding_model or encoding_model,
+            "encoding_name": encoding_model or self.encoding_model,
             "prechunked": True,
         }
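With the parameter widened to str | None, callers may now pass None and rely on the section attribute alone. A hedged illustration (assuming the model's other fields all have defaults, as the pydantic definitions suggest):

from graphrag.config.models.entity_extraction_config import EntityExtractionConfig

cfg = EntityExtractionConfig(encoding_model="cl100k_base")
strategy = cfg.resolved_strategy(root_dir=".", encoding_model=None)
assert strategy["encoding_name"] == "cl100k_base"  # None argument falls back to the attribute
assert strategy["prechunked"] is True              # chunking already happened upstream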
