Commit 1c3b0f3

Chore/lib updates (#1477)
* Update dependencies and fix issues
* Format
* Semver
* Fix Pyright
* Pyright
* More Pyright
* Pyright
1 parent b1f2ca7 commit 1c3b0f3

71 files changed (+554, -537 lines)

Lines changed: 4 additions & 0 deletions
@@ -0,0 +1,4 @@
+{
+    "type": "patch",
+    "description": "Dependency updates"
+}

examples_notebooks/community_contrib/neo4j/graphrag_import_neo4j_cypher.ipynb

Lines changed: 10 additions & 9 deletions
@@ -209,15 +209,16 @@
 "source": [
 "# create constraints, idempotent operation\n",
 "\n",
-"statements = \"\"\"\n",
-"create constraint chunk_id if not exists for (c:__Chunk__) require c.id is unique;\n",
-"create constraint document_id if not exists for (d:__Document__) require d.id is unique;\n",
-"create constraint entity_id if not exists for (c:__Community__) require c.community is unique;\n",
-"create constraint entity_id if not exists for (e:__Entity__) require e.id is unique;\n",
-"create constraint entity_title if not exists for (e:__Entity__) require e.name is unique;\n",
-"create constraint entity_title if not exists for (e:__Covariate__) require e.title is unique;\n",
-"create constraint related_id if not exists for ()-[rel:RELATED]->() require rel.id is unique;\n",
-"\"\"\".split(\";\")\n",
+"statements = [\n",
+" \"\\ncreate constraint chunk_id if not exists for (c:__Chunk__) require c.id is unique\",\n",
+" \"\\ncreate constraint document_id if not exists for (d:__Document__) require d.id is unique\",\n",
+" \"\\ncreate constraint entity_id if not exists for (c:__Community__) require c.community is unique\",\n",
+" \"\\ncreate constraint entity_id if not exists for (e:__Entity__) require e.id is unique\",\n",
+" \"\\ncreate constraint entity_title if not exists for (e:__Entity__) require e.name is unique\",\n",
+" \"\\ncreate constraint entity_title if not exists for (e:__Covariate__) require e.title is unique\",\n",
+" \"\\ncreate constraint related_id if not exists for ()-[rel:RELATED]->() require rel.id is unique\",\n",
+" \"\\n\",\n",
+"]\n",
 "\n",
 "for statement in statements:\n",
 " if len((statement or \"\").strip()) > 0:\n",

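The rewritten cell builds an explicit Python list of constraint statements instead of splitting one triple-quoted string on ";". A minimal sketch of how such a list is consumed, assuming the official neo4j Python driver; the notebook's own connection helper may differ, and the URI and credentials below are placeholders:

from neo4j import GraphDatabase

# Placeholder connection details; the notebook reads its own NEO4J_* settings.
NEO4J_URI = "bolt://localhost:7687"
NEO4J_AUTH = ("neo4j", "password")

statements = [
    "\ncreate constraint chunk_id if not exists for (c:__Chunk__) require c.id is unique",
    "\ncreate constraint document_id if not exists for (d:__Document__) require d.id is unique",
    "\n",  # trailing entry is skipped by the emptiness check below
]

with GraphDatabase.driver(NEO4J_URI, auth=NEO4J_AUTH) as driver:
    with driver.session() as session:
        for statement in statements:
            # Same guard as the notebook: skip blank or whitespace-only entries.
            if len((statement or "").strip()) > 0:
                session.run(statement)
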
graphrag/api/query.py

Lines changed: 26 additions & 26 deletions
@@ -44,7 +44,7 @@
 read_indexer_reports,
 read_indexer_text_units,
 )
-from graphrag.query.structured_search.base import SearchResult # noqa: TCH001
+from graphrag.query.structured_search.base import SearchResult # noqa: TC001
 from graphrag.utils.cli import redact
 from graphrag.utils.embeddings import create_collection_name
 from graphrag.vector_stores.base import BaseVectorStore
@@ -90,14 +90,14 @@ async def global_search(
 ------
 TODO: Document any exceptions to expect.
 """
-_communities = read_indexer_communities(communities, nodes, community_reports)
+communities_ = read_indexer_communities(communities, nodes, community_reports)
 reports = read_indexer_reports(
 community_reports,
 nodes,
 community_level=community_level,
 dynamic_community_selection=dynamic_community_selection,
 )
-_entities = read_indexer_entities(nodes, entities, community_level=community_level)
+entities_ = read_indexer_entities(nodes, entities, community_level=community_level)
 map_prompt = _load_search_prompt(config.root_dir, config.global_search.map_prompt)
 reduce_prompt = _load_search_prompt(
 config.root_dir, config.global_search.reduce_prompt
@@ -109,8 +109,8 @@ async def global_search(
 search_engine = get_global_search_engine(
 config,
 reports=reports,
-entities=_entities,
-communities=_communities,
+entities=entities_,
+communities=communities_,
 response_type=response_type,
 dynamic_community_selection=dynamic_community_selection,
 map_system_prompt=map_prompt,
@@ -159,14 +159,14 @@ async def global_search_streaming(
 ------
 TODO: Document any exceptions to expect.
 """
-_communities = read_indexer_communities(communities, nodes, community_reports)
+communities_ = read_indexer_communities(communities, nodes, community_reports)
 reports = read_indexer_reports(
 community_reports,
 nodes,
 community_level=community_level,
 dynamic_community_selection=dynamic_community_selection,
 )
-_entities = read_indexer_entities(nodes, entities, community_level=community_level)
+entities_ = read_indexer_entities(nodes, entities, community_level=community_level)
 map_prompt = _load_search_prompt(config.root_dir, config.global_search.map_prompt)
 reduce_prompt = _load_search_prompt(
 config.root_dir, config.global_search.reduce_prompt
@@ -178,8 +178,8 @@ async def global_search_streaming(
 search_engine = get_global_search_engine(
 config,
 reports=reports,
-entities=_entities,
-communities=_communities,
+entities=entities_,
+communities=communities_,
 response_type=response_type,
 dynamic_community_selection=dynamic_community_selection,
 map_system_prompt=map_prompt,
@@ -258,17 +258,17 @@ async def local_search(
 embedding_name=entity_description_embedding,
 )

-_entities = read_indexer_entities(nodes, entities, community_level)
-_covariates = read_indexer_covariates(covariates) if covariates is not None else []
+entities_ = read_indexer_entities(nodes, entities, community_level)
+covariates_ = read_indexer_covariates(covariates) if covariates is not None else []
 prompt = _load_search_prompt(config.root_dir, config.local_search.prompt)

 search_engine = get_local_search_engine(
 config=config,
 reports=read_indexer_reports(community_reports, nodes, community_level),
 text_units=read_indexer_text_units(text_units),
-entities=_entities,
+entities=entities_,
 relationships=read_indexer_relationships(relationships),
-covariates={"claims": _covariates},
+covariates={"claims": covariates_},
 description_embedding_store=description_embedding_store, # type: ignore
 response_type=response_type,
 system_prompt=prompt,
@@ -334,17 +334,17 @@ async def local_search_streaming(
 embedding_name=entity_description_embedding,
 )

-_entities = read_indexer_entities(nodes, entities, community_level)
-_covariates = read_indexer_covariates(covariates) if covariates is not None else []
+entities_ = read_indexer_entities(nodes, entities, community_level)
+covariates_ = read_indexer_covariates(covariates) if covariates is not None else []
 prompt = _load_search_prompt(config.root_dir, config.local_search.prompt)

 search_engine = get_local_search_engine(
 config=config,
 reports=read_indexer_reports(community_reports, nodes, community_level),
 text_units=read_indexer_text_units(text_units),
-entities=_entities,
+entities=entities_,
 relationships=read_indexer_relationships(relationships),
-covariates={"claims": _covariates},
+covariates={"claims": covariates_},
 description_embedding_store=description_embedding_store, # type: ignore
 response_type=response_type,
 system_prompt=prompt,
@@ -424,15 +424,15 @@ async def drift_search(
 embedding_name=community_full_content_embedding,
 )

-_entities = read_indexer_entities(nodes, entities, community_level)
-_reports = read_indexer_reports(community_reports, nodes, community_level)
-read_indexer_report_embeddings(_reports, full_content_embedding_store)
+entities_ = read_indexer_entities(nodes, entities, community_level)
+reports = read_indexer_reports(community_reports, nodes, community_level)
+read_indexer_report_embeddings(reports, full_content_embedding_store)
 prompt = _load_search_prompt(config.root_dir, config.drift_search.prompt)
 search_engine = get_drift_search_engine(
 config=config,
-reports=_reports,
+reports=reports,
 text_units=read_indexer_text_units(text_units),
-entities=_entities,
+entities=entities_,
 relationships=read_indexer_relationships(relationships),
 description_embedding_store=description_embedding_store, # type: ignore
 local_system_prompt=prompt,
@@ -492,9 +492,9 @@ def _patch_vector_store(
 db_uri=config.embeddings.vector_store["db_uri"]
 )
 # dump embeddings from the entities list to the description_embedding_store
-_entities = read_indexer_entities(nodes, entities, community_level)
+entities_ = read_indexer_entities(nodes, entities, community_level)
 store_entity_semantic_embeddings(
-entities=_entities, vectorstore=description_embedding_store
+entities=entities_, vectorstore=description_embedding_store
 )

 if with_reports is not None:
@@ -506,7 +506,7 @@ def _patch_vector_store(
 community_reports = with_reports
 container_name = config.embeddings.vector_store["container_name"]
 # Store report embeddings
-_reports = read_indexer_reports(
+reports = read_indexer_reports(
 community_reports,
 nodes,
 community_level,
@@ -526,7 +526,7 @@ def _patch_vector_store(
 )
 # dump embeddings from the reports list to the full_content_embedding_store
 store_reports_semantic_embeddings(
-reports=_reports, vectorstore=full_content_embedding_store
+reports=reports, vectorstore=full_content_embedding_store
 )

 return config
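
A note on the pattern above: a leading underscore (as in _entities, _communities) conventionally marks a deliberately unused binding, which stricter Pyright/Ruff settings may flag when the value is in fact used; the trailing underscore keeps the short name without that signal and without shadowing the entities/communities parameters. The noqa code change from TCH001 to TC001 reflects Ruff's rename of the flake8-type-checking rule codes. A tiny, self-contained illustration of the naming convention (names here are invented for the example):

def summarize(entities: list[str]) -> dict:
    # `_entities` would read as "intentionally unused"; `entities_` keeps the
    # short name while avoiding a clash with the `entities` parameter.
    entities_ = [e.strip().lower() for e in entities]
    return {"count": len(entities_), "entities": entities_}

print(summarize([" Alpha ", "Beta"]))  # {'count': 2, 'entities': ['alpha', 'beta']}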

graphrag/cache/factory.py

Lines changed: 4 additions & 6 deletions
@@ -8,17 +8,15 @@
 from typing import TYPE_CHECKING, cast

 from graphrag.config.enums import CacheType
-from graphrag.index.config.cache import (
-PipelineBlobCacheConfig,
-PipelineFileCacheConfig,
-)
 from graphrag.storage.blob_pipeline_storage import BlobPipelineStorage
 from graphrag.storage.file_pipeline_storage import FilePipelineStorage

 if TYPE_CHECKING:
 from graphrag.cache.pipeline_cache import PipelineCache
 from graphrag.index.config.cache import (
+PipelineBlobCacheConfig,
 PipelineCacheConfig,
+PipelineFileCacheConfig,
 )

 from graphrag.cache.json_pipeline_cache import JsonPipelineCache
@@ -39,11 +37,11 @@ def create_cache(
 case CacheType.memory:
 return InMemoryCache()
 case CacheType.file:
-config = cast(PipelineFileCacheConfig, config)
+config = cast("PipelineFileCacheConfig", config)
 storage = FilePipelineStorage(root_dir).child(config.base_dir)
 return JsonPipelineCache(storage)
 case CacheType.blob:
-config = cast(PipelineBlobCacheConfig, config)
+config = cast("PipelineBlobCacheConfig", config)
 storage = BlobPipelineStorage(
 config.connection_string,
 config.container_name,
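
The change above moves the cache config imports under the existing if TYPE_CHECKING: block and quotes the cast() targets, so the classes are only imported for type checking and never at runtime. A minimal, self-contained sketch of the same pattern, with a stdlib class standing in for the graphrag config classes:

from __future__ import annotations

from typing import TYPE_CHECKING, cast

if TYPE_CHECKING:
    # Seen by the type checker only; never imported at runtime.
    from decimal import Decimal

def as_decimal(value: object) -> "Decimal":
    # cast() is a runtime no-op, so quoting the type name works even though
    # Decimal is not bound in this module at runtime.
    return cast("Decimal", value)

print(as_decimal(3))  # prints 3; cast() only informs the type checker, it does not convert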

graphrag/callbacks/factory.py

Lines changed: 2 additions & 2 deletions
@@ -27,14 +27,14 @@ def create_pipeline_reporter(

 match config.type:
 case ReportingType.file:
-config = cast(PipelineFileReportingConfig, config)
+config = cast("PipelineFileReportingConfig", config)
 return FileWorkflowCallbacks(
 str(Path(root_dir or "") / (config.base_dir or ""))
 )
 case ReportingType.console:
 return ConsoleWorkflowCallbacks()
 case ReportingType.blob:
-config = cast(PipelineBlobReportingConfig, config)
+config = cast("PipelineBlobReportingConfig", config)
 return BlobWorkflowCallbacks(
 config.connection_string,
 config.container_name,

graphrag/cli/main.py

Lines changed: 15 additions & 4 deletions
@@ -46,26 +46,37 @@ def wildcard_match(string: str, pattern: str) -> bool:
 regex = re.escape(pattern).replace(r"\?", ".").replace(r"\*", ".*")
 return re.fullmatch(regex, string) is not None

+from pathlib import Path
+
 def completer(incomplete: str) -> list[str]:
-items = os.listdir()
+# List items in the current directory as Path objects
+items = Path().iterdir()
 completions = []
+
 for item in items:
-if not file_okay and Path(item).is_file():
+# Filter based on file/directory properties
+if not file_okay and item.is_file():
 continue
-if not dir_okay and Path(item).is_dir():
+if not dir_okay and item.is_dir():
 continue
 if readable and not os.access(item, os.R_OK):
 continue
 if writable and not os.access(item, os.W_OK):
 continue
-completions.append(item)
+
+# Append the name of the matching item
+completions.append(item.name)
+
+# Apply wildcard matching if required
 if match_wildcard:
 completions = filter(
 lambda i: wildcard_match(i, match_wildcard)
 if match_wildcard
 else False,
 completions,
 )
+
+# Return completions that start with the given incomplete string
 return [i for i in completions if i.startswith(incomplete)]

 return completer
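
The completer now walks the current directory with pathlib and appends item.name rather than raw os.listdir() entries. A rough, standalone sketch of the same filtering flow; in the real code the file_okay/dir_okay/readable/writable flags come from the enclosing function, so they are plain parameters here and wildcard matching is omitted:

import os
from pathlib import Path

def complete_paths(
    incomplete: str,
    *,
    file_okay: bool = True,
    dir_okay: bool = True,
    readable: bool = True,
    writable: bool = False,
) -> list[str]:
    completions: list[str] = []
    for item in Path().iterdir():  # Path objects for the current directory
        if not file_okay and item.is_file():
            continue
        if not dir_okay and item.is_dir():
            continue
        if readable and not os.access(item, os.R_OK):
            continue
        if writable and not os.access(item, os.W_OK):
            continue
        completions.append(item.name)  # match on the entry name, not the Path
    return [name for name in completions if name.startswith(incomplete)]

print(complete_paths("gra"))  # e.g. ['graphrag'] when run from the repo root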

graphrag/config/create_graphrag_config.py

Lines changed: 1 addition & 1 deletion
@@ -64,7 +64,7 @@ def create_graphrag_config(
 values = values or {}
 root_dir = root_dir or str(Path.cwd())
 env = _make_env(root_dir)
-_token_replace(cast(dict, values))
+_token_replace(cast("dict", values))
 InputModelValidator.validate_python(values, strict=True)

 reader = EnvironmentReader(env)

graphrag/config/defaults.py

Lines changed: 7 additions & 0 deletions
@@ -97,6 +97,13 @@
 overwrite: true\
 """

+VECTOR_STORE_DICT = {
+    "type": VectorStoreType.LanceDB.value,
+    "db_uri": str(Path(STORAGE_BASE_DIR) / "lancedb"),
+    "container_name": "default",
+    "overwrite": True,
+}
+
 # Local Search
 LOCAL_SEARCH_TEXT_UNIT_PROP = 0.5
 LOCAL_SEARCH_COMMUNITY_PROP = 0.1

graphrag/config/models/graph_rag_config.py

Lines changed: 1 addition & 1 deletion
@@ -41,7 +41,7 @@ def __str__(self):
 return self.model_dump_json(indent=4)

 root_dir: str = Field(
-description="The root directory for the configuration.", default=None
+description="The root directory for the configuration.", default="."
 )

 reporting: ReportingConfig = Field(

graphrag/config/models/text_embedding_config.py

Lines changed: 1 addition & 1 deletion
@@ -26,7 +26,7 @@ class TextEmbeddingConfig(LLMConfig):
 )
 skip: list[str] = Field(description="The specific embeddings to skip.", default=[])
 vector_store: dict | None = Field(
-description="The vector storage configuration", default=defs.VECTOR_STORE
+description="The vector storage configuration", default=defs.VECTOR_STORE_DICT
 )
 strategy: dict | None = Field(
 description="The override strategy to use.", default=None
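
Taken together with the new VECTOR_STORE_DICT in defaults.py above, the vector_store field now defaults to a plain dict instead of the previous YAML string default. A rough, self-contained sketch of the resulting behaviour; the values below are illustrative stand-ins for defs.STORAGE_BASE_DIR and defs.VECTOR_STORE_DICT:

from pathlib import Path
from pydantic import BaseModel, Field

# Illustrative stand-ins for the graphrag defaults.
STORAGE_BASE_DIR = "output"
VECTOR_STORE_DICT = {
    "type": "lancedb",
    "db_uri": str(Path(STORAGE_BASE_DIR) / "lancedb"),
    "container_name": "default",
    "overwrite": True,
}

class TextEmbeddingConfig(BaseModel):
    # A dict default needs no YAML parsing step before it can be used.
    vector_store: dict | None = Field(
        description="The vector storage configuration", default=VECTOR_STORE_DICT
    )

print(TextEmbeddingConfig().vector_store["db_uri"])  # "output/lancedb" on POSIX systems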

0 commit comments
