Skip to content

Commit cfe2082

Browse files
Fix/llm bugs empty extraction (#1533)
* Add LLM singleton and check for empty extraction
* Semver
* Tests and spellcheck
* Move the singletons to a proper place
* Remove leftover print
* Ruff
1 parent f7cd155 commit cfe2082

File tree

4 files changed

+79
-2
lines changed

4 files changed

+79
-2
lines changed
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
{
2+
"type": "patch",
3+
"description": "Manage llm instances inside a cached singleton. Check for empty dfs after entity/relationship extraction"
4+
}

graphrag/index/flows/extract_graph.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,18 @@ async def extract_graph(
5252
num_threads=extraction_num_threads,
5353
)
5454

55+
if not _validate_data(entity_dfs):
56+
error_msg = "Entity Extraction failed. No entities detected during extraction."
57+
callbacks.error(error_msg)
58+
raise ValueError(error_msg)
59+
60+
if not _validate_data(relationship_dfs):
61+
error_msg = (
62+
"Entity Extraction failed. No relationships detected during extraction."
63+
)
64+
callbacks.error(error_msg)
65+
raise ValueError(error_msg)
66+
5567
merged_entities = _merge_entities(entity_dfs)
5668
merged_relationships = _merge_relationships(relationship_dfs)
5769

@@ -145,3 +157,10 @@ def _compute_degree(graph: nx.Graph) -> pd.DataFrame:
145157
{"name": node, "degree": int(degree)}
146158
for node, degree in graph.degree # type: ignore
147159
])
160+
161+
162+
def _validate_data(df_list: list[pd.DataFrame]) -> bool:
    """Validate that the dataframe list is valid. At least one dataframe must contain data."""
    # Row count (len) is checked rather than .empty because some of these
    # dataframes carry a column schema while holding zero rows.
    for df in df_list:
        if len(df) > 0:
            return True
    return False

graphrag/index/llm/load_llm.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
import graphrag.config.defaults as defs
2525
from graphrag.config.enums import LLMType
2626
from graphrag.config.models.llm_parameters import LLMParameters
27+
from graphrag.index.llm.manager import ChatLLMSingleton, EmbeddingsLLMSingleton
2728

2829
from .mock_llm import MockChatLLM
2930

@@ -110,6 +111,10 @@ def load_llm(
110111
chat_only=False,
111112
) -> ChatLLM:
112113
"""Load the LLM for the entity extraction chain."""
114+
singleton_llm = ChatLLMSingleton().get_llm(name)
115+
if singleton_llm is not None:
116+
return singleton_llm
117+
113118
on_error = _create_error_handler(callbacks)
114119
llm_type = config.type
115120

@@ -119,7 +124,9 @@ def load_llm(
119124
raise ValueError(msg)
120125

121126
loader = loaders[llm_type]
122-
return loader["load"](on_error, create_cache(cache, name), config)
127+
llm_instance = loader["load"](on_error, create_cache(cache, name), config)
128+
ChatLLMSingleton().set_llm(name, llm_instance)
129+
return llm_instance
123130

124131
msg = f"Unknown LLM type {llm_type}"
125132
raise ValueError(msg)
@@ -134,15 +141,21 @@ def load_llm_embeddings(
134141
chat_only=False,
135142
) -> EmbeddingsLLM:
136143
"""Load the LLM for the entity extraction chain."""
144+
singleton_llm = EmbeddingsLLMSingleton().get_llm(name)
145+
if singleton_llm is not None:
146+
return singleton_llm
147+
137148
on_error = _create_error_handler(callbacks)
138149
llm_type = llm_config.type
139150
if llm_type in loaders:
140151
if chat_only and not loaders[llm_type]["chat"]:
141152
msg = f"LLM type {llm_type} does not support chat"
142153
raise ValueError(msg)
143-
return loaders[llm_type]["load"](
154+
llm_instance = loaders[llm_type]["load"](
144155
on_error, create_cache(cache, name), llm_config or {}
145156
)
157+
EmbeddingsLLMSingleton().set_llm(name, llm_instance)
158+
return llm_instance
146159

147160
msg = f"Unknown LLM type {llm_type}"
148161
raise ValueError(msg)
@@ -198,6 +211,7 @@ def _create_openai_config(config: LLMParameters, azure: bool) -> OpenAIConfig:
198211
n=config.n,
199212
temperature=config.temperature,
200213
)
214+
201215
if azure:
202216
if config.api_base is None:
203217
msg = "Azure OpenAI Chat LLM requires an API base"

graphrag/index/llm/manager.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
# Copyright (c) 2024 Microsoft Corporation.
# Licensed under the MIT License

"""LLM Manager singleton."""

from functools import cache
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Imported for type annotations only; keeping this under TYPE_CHECKING
    # avoids a hard runtime dependency on fnllm when this module is imported.
    from fnllm import ChatLLM, EmbeddingsLLM


@cache
class ChatLLMSingleton:
    """A singleton class for the chat LLM instances.

    ``functools.cache`` on the class caches the zero-argument constructor
    call, so every ``ChatLLMSingleton()`` call returns the same instance
    (and therefore the same registry of LLMs).
    """

    def __init__(self) -> None:
        # Maps an LLM name to its chat LLM instance.
        self.llm_dict: "dict[str, ChatLLM]" = {}

    def set_llm(self, name: str, llm: "ChatLLM") -> None:
        """Add an LLM to the dictionary."""
        self.llm_dict[name] = llm

    def get_llm(self, name: str) -> "ChatLLM | None":
        """Get an LLM from the dictionary, or None if not registered."""
        return self.llm_dict.get(name)


@cache
class EmbeddingsLLMSingleton:
    """A singleton class for the embeddings LLM instances.

    ``functools.cache`` on the class caches the zero-argument constructor
    call, so every ``EmbeddingsLLMSingleton()`` call returns the same
    instance (and therefore the same registry of LLMs).
    """

    def __init__(self) -> None:
        # Maps an LLM name to its embeddings LLM instance.
        self.llm_dict: "dict[str, EmbeddingsLLM]" = {}

    def set_llm(self, name: str, llm: "EmbeddingsLLM") -> None:
        """Add an LLM to the dictionary."""
        self.llm_dict[name] = llm

    def get_llm(self, name: str) -> "EmbeddingsLLM | None":
        """Get an LLM from the dictionary, or None if not registered."""
        return self.llm_dict.get(name)

0 commit comments

Comments
 (0)