Skip to content

Commit fbf11f3

Browse files
Optional embeddings (#1890)
* Make all tables optional for embeddings * Semver --------- Co-authored-by: Alonso Guevara <alonsog@microsoft.com>
1 parent 56e0fad commit fbf11f3

File tree

2 files changed

+30
-8
lines changed

2 files changed

+30
-8
lines changed
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
{
2+
"type": "patch",
3+
"description": "Align embeddings table loading with configured fields."
4+
}

graphrag/index/workflows/generate_text_embeddings.py

Lines changed: 26 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,11 @@
2525
from graphrag.index.operations.embed_text import embed_text
2626
from graphrag.index.typing.context import PipelineRunContext
2727
from graphrag.index.typing.workflow import WorkflowFunctionOutput
28-
from graphrag.utils.storage import load_table_from_storage, write_table_to_storage
28+
from graphrag.utils.storage import (
29+
load_table_from_storage,
30+
storage_has_table,
31+
write_table_to_storage,
32+
)
2933

3034
log = logging.getLogger(__name__)
3135

@@ -35,13 +39,23 @@ async def run_workflow(
3539
context: PipelineRunContext,
3640
) -> WorkflowFunctionOutput:
3741
"""All the steps to transform community reports."""
38-
documents = await load_table_from_storage("documents", context.storage)
39-
relationships = await load_table_from_storage("relationships", context.storage)
40-
text_units = await load_table_from_storage("text_units", context.storage)
41-
entities = await load_table_from_storage("entities", context.storage)
42-
community_reports = await load_table_from_storage(
43-
"community_reports", context.storage
44-
)
42+
documents = None
43+
relationships = None
44+
text_units = None
45+
entities = None
46+
community_reports = None
47+
if await storage_has_table("documents", context.storage):
48+
documents = await load_table_from_storage("documents", context.storage)
49+
if await storage_has_table("relationships", context.storage):
50+
relationships = await load_table_from_storage("relationships", context.storage)
51+
if await storage_has_table("text_units", context.storage):
52+
text_units = await load_table_from_storage("text_units", context.storage)
53+
if await storage_has_table("entities", context.storage):
54+
entities = await load_table_from_storage("entities", context.storage)
55+
if await storage_has_table("community_reports", context.storage):
56+
community_reports = await load_table_from_storage(
57+
"community_reports", context.storage
58+
)
4559

4660
embedded_fields = get_embedded_fields(config)
4761
text_embed = get_embedding_settings(config)
@@ -133,6 +147,10 @@ async def generate_text_embeddings(
133147
log.info("Creating embeddings")
134148
outputs = {}
135149
for field in embedded_fields:
150+
if embedding_param_map[field]["data"] is None:
151+
msg = f"Embedding {field} is specified but data table is not in storage."
152+
raise ValueError(msg)
153+
136154
outputs[field] = await _run_and_snapshot_embeddings(
137155
name=field,
138156
callbacks=callbacks,

0 commit comments

Comments
 (0)