Skip to content

Commit 4637270

Browse files
authored
Implement CosmosDB vector store (#1587)
1 parent e21a38f commit 4637270

File tree

7 files changed

+247
-10
lines changed

7 files changed

+247
-10
lines changed
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
{
2+
"type": "patch",
3+
"description": "add cosmosdb vector store"
4+
}

graphrag/config/defaults.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -93,16 +93,16 @@
9393
UPDATE_STORAGE_BASE_DIR = "update_output"
9494

9595
VECTOR_STORE = f"""
96-
type: {VectorStoreType.LanceDB.value}
96+
type: {VectorStoreType.LanceDB.value} # one of [lancedb, azure_ai_search, cosmosdb]
9797
db_uri: '{(Path(STORAGE_BASE_DIR) / "lancedb")!s}'
98-
container_name: default
98+
collection_name: default
9999
overwrite: true\
100100
"""
101101

102102
VECTOR_STORE_DICT = {
103103
"type": VectorStoreType.LanceDB.value,
104104
"db_uri": str(Path(STORAGE_BASE_DIR) / "lancedb"),
105-
"container_name": "default",
105+
"collection_name": "default",
106106
"overwrite": True,
107107
}
108108

graphrag/index/create_pipeline_config.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -379,8 +379,8 @@ def _get_storage_config(
379379
connection_string = storage_settings.connection_string
380380
base_dir = storage_settings.base_dir
381381
container_name = storage_settings.container_name
382-
if cosmosdb_account_url is None:
383-
msg = "CosmosDB account url must be provided for cosmosdb storage."
382+
if connection_string is None and cosmosdb_account_url is None:
383+
msg = "Connection string or cosmosDB account url must be provided for cosmosdb storage."
384384
raise ValueError(msg)
385385
if base_dir is None:
386386
msg = "Base directory must be provided for cosmosdb storage."

graphrag/storage/cosmosdb_pipeline_storage.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ class CosmosDBPipelineStorage(PipelineStorage):
3434
_database_name: str
3535
_container_name: str
3636
_encoding: str
37+
_no_id_prefixes: list[str]
3738

3839
def __init__(
3940
self,
@@ -66,6 +67,7 @@ def __init__(
6667
if cosmosdb_account_url
6768
else None
6869
)
70+
self._no_id_prefixes = []
6971
log.info(
7072
"creating cosmosdb storage with account: %s and database: %s and container: %s",
7173
self._cosmosdb_account_name,
@@ -208,6 +210,12 @@ async def get(
208210
items_df = pd.read_json(
209211
StringIO(items_json_str), orient="records", lines=False
210212
)
213+
214+
# Drop the "id" column if the original dataframe does not include it
215+
# TODO: Figure out optimal way to handle missing id keys in input dataframes
216+
if prefix in self._no_id_prefixes:
217+
items_df.drop(columns=["id"], axis=1, inplace=True)
218+
211219
return items_df.to_parquet()
212220
item = self._container_client.read_item(item=key, partition_key=key)
213221
item_body = item.get("body")
@@ -236,9 +244,14 @@ async def set(self, key: str, value: Any, encoding: str | None = None) -> None:
236244
log.exception("Error converting output %s to json", key)
237245
else:
238246
cosmosdb_item_list = json.loads(value_json)
239-
for cosmosdb_item in cosmosdb_item_list:
247+
for index, cosmosdb_item in enumerate(cosmosdb_item_list):
248+
# If the id key does not exist in the input dataframe json, create a unique id using the prefix and item index
249+
# TODO: Figure out optimal way to handle missing id keys in input dataframes
250+
if "id" not in cosmosdb_item:
251+
prefixed_id = f"{prefix}:{index}"
252+
self._no_id_prefixes.append(prefix)
240253
# Append an additional prefix to the id to force a unique identifier for the create_final_nodes rows
241-
if prefix == "create_final_nodes":
254+
elif prefix == "create_final_nodes":
242255
prefixed_id = f"{prefix}-community_{cosmosdb_item['community']}:{cosmosdb_item['id']}"
243256
else:
244257
prefixed_id = f"{prefix}:{cosmosdb_item['id']}"

graphrag/vector_stores/azure_ai_search.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
)
3434

3535

36-
class AzureAISearch(BaseVectorStore):
36+
class AzureAISearchVectorStore(BaseVectorStore):
3737
"""Azure AI Search vector storage implementation."""
3838

3939
index_client: SearchIndexClient

graphrag/vector_stores/cosmosdb.py

Lines changed: 216 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,216 @@
1+
# Copyright (c) 2024 Microsoft Corporation.
2+
# Licensed under the MIT License
3+
4+
"""A package containing the CosmosDB vector store implementation."""
5+
6+
import json
7+
from typing import Any
8+
9+
from azure.cosmos import ContainerProxy, CosmosClient, DatabaseProxy
10+
from azure.cosmos.partition_key import PartitionKey
11+
from azure.identity import DefaultAzureCredential
12+
13+
from graphrag.model.types import TextEmbedder
14+
from graphrag.vector_stores.base import (
15+
DEFAULT_VECTOR_SIZE,
16+
BaseVectorStore,
17+
VectorStoreDocument,
18+
VectorStoreSearchResult,
19+
)
20+
21+
22+
class CosmosDBVectoreStore(BaseVectorStore):
23+
"""Azure CosmosDB vector storage implementation."""
24+
25+
_cosmos_client: CosmosClient
26+
_database_client: DatabaseProxy
27+
_container_client: ContainerProxy
28+
29+
def __init__(self, **kwargs: Any) -> None:
30+
super().__init__(**kwargs)
31+
32+
def connect(self, **kwargs: Any) -> Any:
33+
"""Connect to CosmosDB vector storage."""
34+
connection_string = kwargs.get("connection_string")
35+
if connection_string:
36+
self._cosmos_client = CosmosClient.from_connection_string(connection_string)
37+
else:
38+
url = kwargs.get("url")
39+
if not url:
40+
msg = "Either connection_string or url must be provided."
41+
raise ValueError(msg)
42+
self._cosmos_client = CosmosClient(
43+
url=url, credential=DefaultAzureCredential()
44+
)
45+
46+
database_name = kwargs.get("database_name")
47+
if database_name is None:
48+
msg = "Database name must be provided."
49+
raise ValueError(msg)
50+
self._database_name = database_name
51+
collection_name = self.collection_name
52+
if collection_name is None:
53+
msg = "Collection name is empty or not provided."
54+
raise ValueError(msg)
55+
self._container_name = collection_name
56+
57+
self.vector_size = kwargs.get("vector_size", DEFAULT_VECTOR_SIZE)
58+
self._create_database()
59+
self._create_container()
60+
61+
def _create_database(self) -> None:
62+
"""Create the database if it doesn't exist."""
63+
self._cosmos_client.create_database_if_not_exists(id=self._database_name)
64+
self._database_client = self._cosmos_client.get_database_client(
65+
self._database_name
66+
)
67+
68+
def _delete_database(self) -> None:
69+
"""Delete the database if it exists."""
70+
if self._database_exists():
71+
self._cosmos_client.delete_database(self._database_name)
72+
73+
def _database_exists(self) -> bool:
74+
"""Check if the database exists."""
75+
existing_database_names = [
76+
database["id"] for database in self._cosmos_client.list_databases()
77+
]
78+
return self._database_name in existing_database_names
79+
80+
def _create_container(self) -> None:
81+
"""Create the container if it doesn't exist."""
82+
partition_key = PartitionKey(path="/id", kind="Hash")
83+
84+
# Define the container vector policy
85+
vector_embedding_policy = {
86+
"vectorEmbeddings": [
87+
{
88+
"path": "/vector",
89+
"dataType": "float32",
90+
"distanceFunction": "cosine",
91+
"dimensions": self.vector_size,
92+
}
93+
]
94+
}
95+
96+
# Define the vector indexing policy
97+
indexing_policy = {
98+
"indexingMode": "consistent",
99+
"automatic": True,
100+
"includedPaths": [{"path": "/*"}],
101+
"excludedPaths": [{"path": "/_etag/?"}, {"path": "/vector/*"}],
102+
"vectorIndexes": [{"path": "/vector", "type": "diskANN"}],
103+
}
104+
105+
# Create the container and container client
106+
self._database_client.create_container_if_not_exists(
107+
id=self._container_name,
108+
partition_key=partition_key,
109+
indexing_policy=indexing_policy,
110+
vector_embedding_policy=vector_embedding_policy,
111+
)
112+
self._container_client = self._database_client.get_container_client(
113+
self._container_name
114+
)
115+
116+
def _delete_container(self) -> None:
117+
"""Delete the vector store container in the database if it exists."""
118+
if self._container_exists():
119+
self._database_client.delete_container(self._container_name)
120+
121+
def _container_exists(self) -> bool:
122+
"""Check if the container name exists in the database."""
123+
existing_container_names = [
124+
container["id"] for container in self._database_client.list_containers()
125+
]
126+
return self._container_name in existing_container_names
127+
128+
def load_documents(
129+
self, documents: list[VectorStoreDocument], overwrite: bool = True
130+
) -> None:
131+
"""Load documents into CosmosDB."""
132+
# Create a CosmosDB container on overwrite
133+
if overwrite:
134+
self._delete_container()
135+
self._create_container()
136+
137+
if self._container_client is None:
138+
msg = "Container client is not initialized."
139+
raise ValueError(msg)
140+
141+
# Upload documents to CosmosDB
142+
for doc in documents:
143+
if doc.vector is not None:
144+
doc_json = {
145+
"id": doc.id,
146+
"vector": doc.vector,
147+
"text": doc.text,
148+
"attributes": json.dumps(doc.attributes),
149+
}
150+
self._container_client.upsert_item(doc_json)
151+
152+
def similarity_search_by_vector(
153+
self, query_embedding: list[float], k: int = 10, **kwargs: Any
154+
) -> list[VectorStoreSearchResult]:
155+
"""Perform a vector-based similarity search."""
156+
if self._container_client is None:
157+
msg = "Container client is not initialized."
158+
raise ValueError(msg)
159+
160+
query = f"SELECT TOP {k} c.id, c.text, c.vector, c.attributes, VectorDistance(c.vector, @embedding) AS SimilarityScore FROM c ORDER BY VectorDistance(c.vector, @embedding)" # noqa: S608
161+
query_params = [{"name": "@embedding", "value": query_embedding}]
162+
items = self._container_client.query_items(
163+
query=query,
164+
parameters=query_params,
165+
enable_cross_partition_query=True,
166+
)
167+
168+
return [
169+
VectorStoreSearchResult(
170+
document=VectorStoreDocument(
171+
id=item.get("id", ""),
172+
text=item.get("text", ""),
173+
vector=item.get("vector", []),
174+
attributes=(json.loads(item.get("attributes", "{}"))),
175+
),
176+
score=item.get("SimilarityScore", 0.0),
177+
)
178+
for item in items
179+
]
180+
181+
def similarity_search_by_text(
182+
self, text: str, text_embedder: TextEmbedder, k: int = 10, **kwargs: Any
183+
) -> list[VectorStoreSearchResult]:
184+
"""Perform a text-based similarity search."""
185+
query_embedding = text_embedder(text)
186+
if query_embedding:
187+
return self.similarity_search_by_vector(
188+
query_embedding=query_embedding, k=k
189+
)
190+
return []
191+
192+
def filter_by_id(self, include_ids: list[str] | list[int]) -> Any:
193+
"""Build a query filter to filter documents by a list of ids."""
194+
if include_ids is None or len(include_ids) == 0:
195+
self.query_filter = None
196+
else:
197+
if isinstance(include_ids[0], str):
198+
id_filter = ", ".join([f"'{id}'" for id in include_ids])
199+
else:
200+
id_filter = ", ".join([str(id) for id in include_ids])
201+
self.query_filter = f"SELECT * FROM c WHERE c.id IN ({id_filter})" # noqa: S608
202+
return self.query_filter
203+
204+
def search_by_id(self, id: str) -> VectorStoreDocument:
205+
"""Search for a document by id."""
206+
if self._container_client is None:
207+
msg = "Container client is not initialized."
208+
raise ValueError(msg)
209+
210+
item = self._container_client.read_item(item=id, partition_key=id)
211+
return VectorStoreDocument(
212+
id=item.get("id", ""),
213+
vector=item.get("vector", []),
214+
text=item.get("text", ""),
215+
attributes=(json.loads(item.get("attributes", "{}"))),
216+
)

graphrag/vector_stores/factory.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,9 @@
66
from enum import Enum
77
from typing import ClassVar
88

9-
from graphrag.vector_stores.azure_ai_search import AzureAISearch
9+
from graphrag.vector_stores.azure_ai_search import AzureAISearchVectorStore
1010
from graphrag.vector_stores.base import BaseVectorStore
11+
from graphrag.vector_stores.cosmosdb import CosmosDBVectoreStore
1112
from graphrag.vector_stores.lancedb import LanceDBVectorStore
1213

1314

@@ -16,6 +17,7 @@ class VectorStoreType(str, Enum):
1617

1718
LanceDB = "lancedb"
1819
AzureAISearch = "azure_ai_search"
20+
CosmosDB = "cosmosdb"
1921

2022

2123
class VectorStoreFactory:
@@ -40,7 +42,9 @@ def create_vector_store(
4042
case VectorStoreType.LanceDB:
4143
return LanceDBVectorStore(**kwargs)
4244
case VectorStoreType.AzureAISearch:
43-
return AzureAISearch(**kwargs)
45+
return AzureAISearchVectorStore(**kwargs)
46+
case VectorStoreType.CosmosDB:
47+
return CosmosDBVectoreStore(**kwargs)
4448
case _:
4549
if vector_store_type in cls.vector_store_types:
4650
return cls.vector_store_types[vector_store_type](**kwargs)

0 commit comments

Comments
 (0)