|
| 1 | +# Copyright (c) 2024 Microsoft Corporation. |
| 2 | +# Licensed under the MIT License |
| 3 | + |
| 4 | +"""A package containing the CosmosDB vector store implementation.""" |
| 5 | + |
| 6 | +import json |
| 7 | +from typing import Any |
| 8 | + |
| 9 | +from azure.cosmos import ContainerProxy, CosmosClient, DatabaseProxy |
| 10 | +from azure.cosmos.partition_key import PartitionKey |
| 11 | +from azure.identity import DefaultAzureCredential |
| 12 | + |
| 13 | +from graphrag.model.types import TextEmbedder |
| 14 | +from graphrag.vector_stores.base import ( |
| 15 | + DEFAULT_VECTOR_SIZE, |
| 16 | + BaseVectorStore, |
| 17 | + VectorStoreDocument, |
| 18 | + VectorStoreSearchResult, |
| 19 | +) |
| 20 | + |
| 21 | + |
| 22 | +class CosmosDBVectoreStore(BaseVectorStore): |
| 23 | + """Azure CosmosDB vector storage implementation.""" |
| 24 | + |
| 25 | + _cosmos_client: CosmosClient |
| 26 | + _database_client: DatabaseProxy |
| 27 | + _container_client: ContainerProxy |
| 28 | + |
| 29 | + def __init__(self, **kwargs: Any) -> None: |
| 30 | + super().__init__(**kwargs) |
| 31 | + |
| 32 | + def connect(self, **kwargs: Any) -> Any: |
| 33 | + """Connect to CosmosDB vector storage.""" |
| 34 | + connection_string = kwargs.get("connection_string") |
| 35 | + if connection_string: |
| 36 | + self._cosmos_client = CosmosClient.from_connection_string(connection_string) |
| 37 | + else: |
| 38 | + url = kwargs.get("url") |
| 39 | + if not url: |
| 40 | + msg = "Either connection_string or url must be provided." |
| 41 | + raise ValueError(msg) |
| 42 | + self._cosmos_client = CosmosClient( |
| 43 | + url=url, credential=DefaultAzureCredential() |
| 44 | + ) |
| 45 | + |
| 46 | + database_name = kwargs.get("database_name") |
| 47 | + if database_name is None: |
| 48 | + msg = "Database name must be provided." |
| 49 | + raise ValueError(msg) |
| 50 | + self._database_name = database_name |
| 51 | + collection_name = self.collection_name |
| 52 | + if collection_name is None: |
| 53 | + msg = "Collection name is empty or not provided." |
| 54 | + raise ValueError(msg) |
| 55 | + self._container_name = collection_name |
| 56 | + |
| 57 | + self.vector_size = kwargs.get("vector_size", DEFAULT_VECTOR_SIZE) |
| 58 | + self._create_database() |
| 59 | + self._create_container() |
| 60 | + |
| 61 | + def _create_database(self) -> None: |
| 62 | + """Create the database if it doesn't exist.""" |
| 63 | + self._cosmos_client.create_database_if_not_exists(id=self._database_name) |
| 64 | + self._database_client = self._cosmos_client.get_database_client( |
| 65 | + self._database_name |
| 66 | + ) |
| 67 | + |
| 68 | + def _delete_database(self) -> None: |
| 69 | + """Delete the database if it exists.""" |
| 70 | + if self._database_exists(): |
| 71 | + self._cosmos_client.delete_database(self._database_name) |
| 72 | + |
| 73 | + def _database_exists(self) -> bool: |
| 74 | + """Check if the database exists.""" |
| 75 | + existing_database_names = [ |
| 76 | + database["id"] for database in self._cosmos_client.list_databases() |
| 77 | + ] |
| 78 | + return self._database_name in existing_database_names |
| 79 | + |
| 80 | + def _create_container(self) -> None: |
| 81 | + """Create the container if it doesn't exist.""" |
| 82 | + partition_key = PartitionKey(path="/id", kind="Hash") |
| 83 | + |
| 84 | + # Define the container vector policy |
| 85 | + vector_embedding_policy = { |
| 86 | + "vectorEmbeddings": [ |
| 87 | + { |
| 88 | + "path": "/vector", |
| 89 | + "dataType": "float32", |
| 90 | + "distanceFunction": "cosine", |
| 91 | + "dimensions": self.vector_size, |
| 92 | + } |
| 93 | + ] |
| 94 | + } |
| 95 | + |
| 96 | + # Define the vector indexing policy |
| 97 | + indexing_policy = { |
| 98 | + "indexingMode": "consistent", |
| 99 | + "automatic": True, |
| 100 | + "includedPaths": [{"path": "/*"}], |
| 101 | + "excludedPaths": [{"path": "/_etag/?"}, {"path": "/vector/*"}], |
| 102 | + "vectorIndexes": [{"path": "/vector", "type": "diskANN"}], |
| 103 | + } |
| 104 | + |
| 105 | + # Create the container and container client |
| 106 | + self._database_client.create_container_if_not_exists( |
| 107 | + id=self._container_name, |
| 108 | + partition_key=partition_key, |
| 109 | + indexing_policy=indexing_policy, |
| 110 | + vector_embedding_policy=vector_embedding_policy, |
| 111 | + ) |
| 112 | + self._container_client = self._database_client.get_container_client( |
| 113 | + self._container_name |
| 114 | + ) |
| 115 | + |
| 116 | + def _delete_container(self) -> None: |
| 117 | + """Delete the vector store container in the database if it exists.""" |
| 118 | + if self._container_exists(): |
| 119 | + self._database_client.delete_container(self._container_name) |
| 120 | + |
| 121 | + def _container_exists(self) -> bool: |
| 122 | + """Check if the container name exists in the database.""" |
| 123 | + existing_container_names = [ |
| 124 | + container["id"] for container in self._database_client.list_containers() |
| 125 | + ] |
| 126 | + return self._container_name in existing_container_names |
| 127 | + |
| 128 | + def load_documents( |
| 129 | + self, documents: list[VectorStoreDocument], overwrite: bool = True |
| 130 | + ) -> None: |
| 131 | + """Load documents into CosmosDB.""" |
| 132 | + # Create a CosmosDB container on overwrite |
| 133 | + if overwrite: |
| 134 | + self._delete_container() |
| 135 | + self._create_container() |
| 136 | + |
| 137 | + if self._container_client is None: |
| 138 | + msg = "Container client is not initialized." |
| 139 | + raise ValueError(msg) |
| 140 | + |
| 141 | + # Upload documents to CosmosDB |
| 142 | + for doc in documents: |
| 143 | + if doc.vector is not None: |
| 144 | + doc_json = { |
| 145 | + "id": doc.id, |
| 146 | + "vector": doc.vector, |
| 147 | + "text": doc.text, |
| 148 | + "attributes": json.dumps(doc.attributes), |
| 149 | + } |
| 150 | + self._container_client.upsert_item(doc_json) |
| 151 | + |
| 152 | + def similarity_search_by_vector( |
| 153 | + self, query_embedding: list[float], k: int = 10, **kwargs: Any |
| 154 | + ) -> list[VectorStoreSearchResult]: |
| 155 | + """Perform a vector-based similarity search.""" |
| 156 | + if self._container_client is None: |
| 157 | + msg = "Container client is not initialized." |
| 158 | + raise ValueError(msg) |
| 159 | + |
| 160 | + query = f"SELECT TOP {k} c.id, c.text, c.vector, c.attributes, VectorDistance(c.vector, @embedding) AS SimilarityScore FROM c ORDER BY VectorDistance(c.vector, @embedding)" # noqa: S608 |
| 161 | + query_params = [{"name": "@embedding", "value": query_embedding}] |
| 162 | + items = self._container_client.query_items( |
| 163 | + query=query, |
| 164 | + parameters=query_params, |
| 165 | + enable_cross_partition_query=True, |
| 166 | + ) |
| 167 | + |
| 168 | + return [ |
| 169 | + VectorStoreSearchResult( |
| 170 | + document=VectorStoreDocument( |
| 171 | + id=item.get("id", ""), |
| 172 | + text=item.get("text", ""), |
| 173 | + vector=item.get("vector", []), |
| 174 | + attributes=(json.loads(item.get("attributes", "{}"))), |
| 175 | + ), |
| 176 | + score=item.get("SimilarityScore", 0.0), |
| 177 | + ) |
| 178 | + for item in items |
| 179 | + ] |
| 180 | + |
| 181 | + def similarity_search_by_text( |
| 182 | + self, text: str, text_embedder: TextEmbedder, k: int = 10, **kwargs: Any |
| 183 | + ) -> list[VectorStoreSearchResult]: |
| 184 | + """Perform a text-based similarity search.""" |
| 185 | + query_embedding = text_embedder(text) |
| 186 | + if query_embedding: |
| 187 | + return self.similarity_search_by_vector( |
| 188 | + query_embedding=query_embedding, k=k |
| 189 | + ) |
| 190 | + return [] |
| 191 | + |
| 192 | + def filter_by_id(self, include_ids: list[str] | list[int]) -> Any: |
| 193 | + """Build a query filter to filter documents by a list of ids.""" |
| 194 | + if include_ids is None or len(include_ids) == 0: |
| 195 | + self.query_filter = None |
| 196 | + else: |
| 197 | + if isinstance(include_ids[0], str): |
| 198 | + id_filter = ", ".join([f"'{id}'" for id in include_ids]) |
| 199 | + else: |
| 200 | + id_filter = ", ".join([str(id) for id in include_ids]) |
| 201 | + self.query_filter = f"SELECT * FROM c WHERE c.id IN ({id_filter})" # noqa: S608 |
| 202 | + return self.query_filter |
| 203 | + |
| 204 | + def search_by_id(self, id: str) -> VectorStoreDocument: |
| 205 | + """Search for a document by id.""" |
| 206 | + if self._container_client is None: |
| 207 | + msg = "Container client is not initialized." |
| 208 | + raise ValueError(msg) |
| 209 | + |
| 210 | + item = self._container_client.read_item(item=id, partition_key=id) |
| 211 | + return VectorStoreDocument( |
| 212 | + id=item.get("id", ""), |
| 213 | + vector=item.get("vector", []), |
| 214 | + text=item.get("text", ""), |
| 215 | + attributes=(json.loads(item.get("attributes", "{}"))), |
| 216 | + ) |
0 commit comments