From 68bd7272fc9d0276bb62c6b3acf1b2bf25e089e0 Mon Sep 17 00:00:00 2001 From: Gijs Segerink Date: Wed, 19 Feb 2025 19:24:08 +0100 Subject: [PATCH 1/3] documentdb support --- .gitignore | 1 + .vscode/extensions.json | 12 - .vscode/launch.json | 39 --- .vscode/settings.json | 52 --- graphrag/cache/factory.py | 3 + graphrag/config/enums.py | 4 + graphrag/config/init_content.py | 6 +- graphrag/config/models/cache_config.py | 3 + graphrag/config/models/output_config.py | 3 + graphrag/config/models/vector_store_config.py | 10 +- .../storage/documentdb_pipeline_storage.py | 328 ++++++++++++++++++ graphrag/storage/factory.py | 3 + graphrag/vector_stores/documentdb.py | 281 +++++++++++++++ graphrag/vector_stores/factory.py | 4 + .../storage/test_documentdb_storage.py | 138 ++++++++ tests/integration/storage/test_factory.py | 14 + tests/unit/config/utils.py | 4 + 17 files changed, 797 insertions(+), 108 deletions(-) delete mode 100644 .vscode/extensions.json delete mode 100644 .vscode/launch.json delete mode 100644 .vscode/settings.json create mode 100644 graphrag/storage/documentdb_pipeline_storage.py create mode 100644 graphrag/vector_stores/documentdb.py create mode 100644 tests/integration/storage/test_documentdb_storage.py diff --git a/.gitignore b/.gitignore index 707bc44711..e75a076761 100644 --- a/.gitignore +++ b/.gitignore @@ -20,6 +20,7 @@ output/lancedb venv/ .conda .tmp +.vscode .env build.zip diff --git a/.vscode/extensions.json b/.vscode/extensions.json deleted file mode 100644 index 2e5e67a214..0000000000 --- a/.vscode/extensions.json +++ /dev/null @@ -1,12 +0,0 @@ -{ - "recommendations": [ - "arcanis.vscode-zipfs", - "ms-python.python", - "charliermarsh.ruff", - "ms-python.vscode-pylance", - "bierner.markdown-mermaid", - "streetsidesoftware.code-spell-checker", - "ronnidc.nunjucks", - "lucien-martijn.parquet-visualizer", - ] -} diff --git a/.vscode/launch.json b/.vscode/launch.json deleted file mode 100644 index 2167063966..0000000000 --- a/.vscode/launch.json +++ /dev/null @@ -1,39 +0,0 @@ -{ - "_comment": "Use this file to configure the graphrag project for debugging. 
You may create other configuration profiles based on these or select one below to use.", - "version": "0.2.0", - "configurations": [ - { - "name": "Indexer", - "type": "debugpy", - "request": "launch", - "module": "poetry", - "args": [ - "poe", "index", - "--root", "" - ], - }, - { - "name": "Query", - "type": "debugpy", - "request": "launch", - "module": "poetry", - "args": [ - "poe", "query", - "--root", "", - "--method", "global", - "--query", "What are the top themes in this story", - ] - }, - { - "name": "Prompt Tuning", - "type": "debugpy", - "request": "launch", - "module": "poetry", - "args": [ - "poe", "prompt-tune", - "--config", - "/settings.yaml", - ] - } - ] -} \ No newline at end of file diff --git a/.vscode/settings.json b/.vscode/settings.json deleted file mode 100644 index 0b678d5d95..0000000000 --- a/.vscode/settings.json +++ /dev/null @@ -1,52 +0,0 @@ -{ - "search.exclude": { - "**/.yarn": true, - "**/.pnp.*": true - }, - "editor.formatOnSave": false, - "eslint.nodePath": ".yarn/sdks", - "typescript.tsdk": ".yarn/sdks/typescript/lib", - "typescript.enablePromptUseWorkspaceTsdk": true, - "javascript.preferences.importModuleSpecifier": "relative", - "javascript.preferences.importModuleSpecifierEnding": "js", - "typescript.preferences.importModuleSpecifier": "relative", - "typescript.preferences.importModuleSpecifierEnding": "js", - "explorer.fileNesting.enabled": true, - "explorer.fileNesting.patterns": { - "*.ts": "${capture}.ts, ${capture}.hooks.ts, ${capture}.hooks.tsx, ${capture}.contexts.ts, ${capture}.stories.tsx, ${capture}.story.tsx, ${capture}.spec.tsx, ${capture}.base.ts, ${capture}.base.tsx, ${capture}.types.ts, ${capture}.styles.ts, ${capture}.styles.tsx, ${capture}.utils.ts, ${capture}.utils.tsx, ${capture}.constants.ts, ${capture}.module.scss, ${capture}.module.css, ${capture}.md", - "*.js": "${capture}.js.map, ${capture}.min.js, ${capture}.d.ts", - "*.jsx": "${capture}.js", - "*.tsx": "${capture}.ts, ${capture}.hooks.ts, ${capture}.hooks.tsx, ${capture}.contexts.ts, ${capture}.stories.tsx, ${capture}.story.tsx, ${capture}.spec.tsx, ${capture}.base.ts, ${capture}.base.tsx, ${capture}.types.ts, ${capture}.styles.ts, ${capture}.styles.tsx, ${capture}.utils.ts, ${capture}.utils.tsx, ${capture}.constants.ts, ${capture}.module.scss, ${capture}.module.css, ${capture}.md, ${capture}.css", - "tsconfig.json": "tsconfig.*.json", - "package.json": "package-lock.json, turbo.json, tsconfig.json, rome.json, biome.json, .npmignore, dictionary.txt, cspell.config.yaml", - "README.md": "*.md, LICENSE, CODEOWNERS", - ".eslintrc": ".eslintignore", - ".prettierrc": ".prettierignore", - ".gitattributes": ".gitignore", - ".yarnrc.yml": "yarn.lock, .pnp.*", - "jest.config.js": "jest.setup.mjs", - "pyproject.toml": "poetry.lock, poetry.toml, mkdocs.yaml", - "cspell.config.yaml": "dictionary.txt" - }, - "azureFunctions.postDeployTask": "npm install (functions)", - "azureFunctions.projectLanguage": "TypeScript", - "azureFunctions.projectRuntime": "~4", - "debug.internalConsoleOptions": "neverOpen", - "azureFunctions.preDeployTask": "npm prune (functions)", - "appService.zipIgnorePattern": [ - "node_modules{,/**}", - ".vscode{,/**}" - ], - "python.defaultInterpreterPath": "python/services/.venv/bin/python", - "python.languageServer": "Pylance", - "cSpell.customDictionaries": { - "project-words": { - "name": "project-words", - "path": "${workspaceRoot}/dictionary.txt", - "description": "Words used in this project", - "addWords": true - }, - "custom": true, // Enable the `custom` 
dictionary - "internal-terms": true // Disable the `internal-terms` dictionary - } -} diff --git a/graphrag/cache/factory.py b/graphrag/cache/factory.py index f44c68953b..efdeb09fde 100644 --- a/graphrag/cache/factory.py +++ b/graphrag/cache/factory.py @@ -10,6 +10,7 @@ from graphrag.config.enums import CacheType from graphrag.storage.blob_pipeline_storage import create_blob_storage from graphrag.storage.cosmosdb_pipeline_storage import create_cosmosdb_storage +from graphrag.storage.documentdb_pipeline_storage import create_documentdb_storage from graphrag.storage.file_pipeline_storage import FilePipelineStorage if TYPE_CHECKING: @@ -56,6 +57,8 @@ def create_cache( return JsonPipelineCache(create_blob_storage(**kwargs)) case CacheType.cosmosdb: return JsonPipelineCache(create_cosmosdb_storage(**kwargs)) + case CacheType.documentdb: + return JsonPipelineCache(create_documentdb_storage(**kwargs)) case _: if cache_type in cls.cache_types: return cls.cache_types[cache_type](**kwargs) diff --git a/graphrag/config/enums.py b/graphrag/config/enums.py index 450ec6bda7..e3b363bd65 100644 --- a/graphrag/config/enums.py +++ b/graphrag/config/enums.py @@ -21,6 +21,8 @@ class CacheType(str, Enum): """The blob cache configuration type.""" cosmosdb = "cosmosdb" """The cosmosdb cache configuration type""" + documentdb = "documentdb" + """The documentdb cache configuration type""" def __repr__(self): """Get a string representation.""" @@ -64,6 +66,8 @@ class OutputType(str, Enum): """The blob output type.""" cosmosdb = "cosmosdb" """The cosmosdb output type""" + documentdb = "documentdb" + """The documentdb output type""" def __repr__(self): """Get a string representation.""" diff --git a/graphrag/config/init_content.py b/graphrag/config/init_content.py index c0d4cefcfe..35ecde0af1 100644 --- a/graphrag/config/init_content.py +++ b/graphrag/config/init_content.py @@ -80,15 +80,15 @@ ## connection_string and container_name must be provided cache: - type: {defs.CACHE_TYPE.value} # [file, blob, cosmosdb] + type: {defs.CACHE_TYPE.value} # [file, blob, cosmosdb, documentdb] base_dir: "{defs.CACHE_BASE_DIR}" reporting: - type: {defs.REPORTING_TYPE.value} # [file, blob, cosmosdb] + type: {defs.REPORTING_TYPE.value} # [file, blob, cosmosdb, documentdb] base_dir: "{defs.REPORTING_BASE_DIR}" output: - type: {defs.OUTPUT_TYPE.value} # [file, blob, cosmosdb] + type: {defs.OUTPUT_TYPE.value} # [file, blob, cosmosdb, documentdb] base_dir: "{defs.OUTPUT_BASE_DIR}" ### Workflow settings ### diff --git a/graphrag/config/models/cache_config.py b/graphrag/config/models/cache_config.py index fb4d22b229..ff2f5e8577 100644 --- a/graphrag/config/models/cache_config.py +++ b/graphrag/config/models/cache_config.py @@ -30,3 +30,6 @@ class CacheConfig(BaseModel): cosmosdb_account_url: str | None = Field( description="The cosmosdb account url to use.", default=None ) + documentdb_account_url: str | None = Field( + description="The documentdb account url to use.", default=None + ) diff --git a/graphrag/config/models/output_config.py b/graphrag/config/models/output_config.py index c237ef2567..7d83ef47db 100644 --- a/graphrag/config/models/output_config.py +++ b/graphrag/config/models/output_config.py @@ -31,3 +31,6 @@ class OutputConfig(BaseModel): cosmosdb_account_url: str | None = Field( description="The cosmosdb account url to use.", default=None ) + documentdb_account_url: str | None = Field( + description="The documentdb account url to use.", default=None + ) diff --git a/graphrag/config/models/vector_store_config.py 
b/graphrag/config/models/vector_store_config.py index 0bda54650c..e848dc5517 100644 --- a/graphrag/config/models/vector_store_config.py +++ b/graphrag/config/models/vector_store_config.py @@ -50,11 +50,17 @@ def _validate_url(self) -> None: ): msg = "vector_store.url is required when vector_store.type == cosmos_db. Please rerun `graphrag init` and select the correct vector store type." raise ValueError(msg) + + if self.type == VectorStoreType.DocumentDB and ( + self.url is None or self.url.strip() == "" + ): + msg = "vector_store.url is required when vector_store.type == document_db. Please rerun `graphrag init` and select the correct vector store type." + raise ValueError(msg) if self.type == VectorStoreType.LanceDB and ( self.url is not None and self.url.strip() != "" ): - msg = "vector_store.url is only used when vector_store.type == azure_ai_search or vector_store.type == cosmos_db. Please rerun `graphrag init` and select the correct vector store type." + msg = "vector_store.url is only used when vector_store.type == azure_ai_search or vector_store.type == cosmos_db or vector_store.type == document_db. Please rerun `graphrag init` and select the correct vector store type." raise ValueError(msg) api_key: str | None = Field( @@ -73,7 +79,7 @@ def _validate_url(self) -> None: ) database_name: str | None = Field( - description="The database name to use when type == cosmos_db.", default=None + description="The database name to use when type == cosmos_db or document_db.", default=None ) overwrite: bool = Field( diff --git a/graphrag/storage/documentdb_pipeline_storage.py b/graphrag/storage/documentdb_pipeline_storage.py new file mode 100644 index 0000000000..3f7c46c60b --- /dev/null +++ b/graphrag/storage/documentdb_pipeline_storage.py @@ -0,0 +1,328 @@ +# Copyright (c) 2024 Microsoft Corporation. 
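Illustration only, not part of the diff: a minimal sketch of the intended behaviour of the new url validation. It assumes VectorStoreConfig runs its validator on construction and that the new enum member's string value is "documentdb", matching the factory registration later in this series; field names are taken from the diff and the url value is a placeholder.

    from graphrag.config.models.vector_store_config import VectorStoreConfig

    # A url (the DocumentDB/PostgreSQL endpoint) is now required for this type,
    # and database_name is shared with the cosmos_db case.
    ok = VectorStoreConfig(
        type="documentdb",
        url="postgresql://localhost:9712",
        database_name="documentdb",
    )

    # Omitting url for type == "documentdb" should raise the new ValueError.
    try:
        VectorStoreConfig(type="documentdb")
    except ValueError as err:
        print(err)
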
+# Licensed under the MIT License + +"""Azure DocumentDB Storage implementation of PipelineStorage.""" + +import json +import logging +import re +from collections.abc import Iterator +from datetime import datetime, timezone +from io import BytesIO, StringIO +from typing import Any + +import pandas as pd + +from graphrag.logger.base import ProgressLogger +from graphrag.logger.progress import Progress +from graphrag.storage.pipeline_storage import ( + PipelineStorage, + get_timestamp_formatted_with_local_tz, +) + +import psycopg2 +from psycopg2.extras import RealDictCursor, Json + +log = logging.getLogger(__name__) + + +class DocumentDBPipelineStorage(PipelineStorage): + """The DocumentDB Storage Implementation.""" + + _connection: psycopg2.extensions.connection + _cursor: psycopg2.extensions.cursor + _database_name: str + _collection: str + _encoding: str + + def __init__( + self, + database_name: str, + collection: str, + user: str, + password: str, + host: str = "localhost", + port: int = 5432, + encoding: str = "utf-8", + ): + """Initialize the DocumentDB Storage.""" + self._connection = psycopg2.connect( + dbname="postgres", + user=user, + password=password, + host=host, + port=port + ) + self._cursor = self._connection.cursor(cursor_factory=RealDictCursor) + self._encoding = encoding + self._database_name = database_name + self._collection = collection + log.info( + "creating documentdb storage with database: %s and table: %s", + self._database_name, + self._collection, + ) + + set_first = 'SET search_path TO documentdb_api, documentdb_core;' + self._cursor.execute(set_first) + self._connection.commit() + self._create_collection() + + def _create_collection(self) -> None: + """Create the table if it doesn't exist.""" + self._cursor.execute(f""" + SELECT documentdb_api.create_collection('{self._database_name}', '{self._collection}'); + """) + self._connection.commit() + + def _delete_collection(self) -> None: + """Delete the table if it exists.""" + self._cursor.execute(f""" + SELECT documentdb_api.drop_collection('{self._database_name}', '{self._collection}'); + """) + self._connection.commit() + + def find( + self, + file_pattern: re.Pattern[str], + base_dir: str | None = None, + progress: ProgressLogger | None = None, + file_filter: dict[str, Any] | None = None, + max_count=-1, + ) -> Iterator[tuple[str, dict[str, Any]]]: + """Find documents in a DocumentDB table using a file pattern regex and custom file filter (optional).""" + base_dir = base_dir or "" + log.info( + "search table %s for documents matching %s", + self._collection, + file_pattern.pattern, + ) + + if not self._connection or not self._cursor: + return + + try: + find_query = { + "find" : self._collection, + "filter" : { + "$and": [ + { + "key": { + "$regex": file_pattern.pattern + } + } + ] + } + } + + if file_filter: + for key, value in file_filter.items(): + find_query["filter"]["$and"].append({key: value}) + + if max_count > 0: + find_query["$limit"] = max_count + + query = f""" + SELECT cursorPage->>'cursor.firstBatch' AS results + FROM documentdb_api.find_cursor_first_page('{self._database_name}', {Json(find_query)}); + """ + + self._cursor.execute(query) + item = self._cursor.fetchone() + items = json.loads(item.get('results', '[]')) + num_loaded = 0 + num_total = len(items) + + if num_total == 0: + return + + num_filtered = 0 + for item in items: + key = item["key"] + match = file_pattern.search(key) + if match: + group = match.groupdict() + yield (key, group) + num_loaded += 1 + if max_count > 0 and num_loaded 
>= max_count: + break + else: + num_filtered += 1 + else: + num_filtered += 1 + if progress is not None: + progress( + _create_progress_status(num_loaded, num_filtered, num_total) + ) + except Exception as ex: + log.exception( + "An error occurred while searching for documents in Document DB." + ) + + async def get( + self, key: str, as_bytes: bool | None = None, encoding: str | None = None + ) -> Any: + """Fetch an item from the table that matches the given key.""" + try: + find_query = { + "find" : self._collection, + "filter" : { + "key": key + } + } + + query = f""" + SELECT cursorPage->>'cursor.firstBatch' AS results + FROM documentdb_api.find_cursor_first_page('{self._database_name}', {Json(find_query)}); + """ + + self._cursor.execute(query) + item = self._cursor.fetchone() + items = json.loads(item.get('results', '[]')) + if len(items) == 0: + return None + + item = items[0] + + if item: + return json.dumps(item["value"]) + return None + except Exception: + log.exception("Error reading item %s", key) + return None + + async def set(self, key: str, value: Any, encoding: str | None = None) -> None: + """Insert or update the contents of a file into the DocumentDB table for the given filename key.""" + try: + insert_query = { + "key": key, + "value": json.loads(value), + "created_at": datetime.now(timezone.utc).isoformat() + } + self._cursor.execute(f""" + SELECT documentdb_api.insert_one('{self._database_name}', '{self._collection}', {Json(insert_query)}); + """) + self._connection.commit() + except Exception: + log.exception("Error writing item %s", key) + + async def has(self, key: str) -> bool: + """Check if the contents of the given filename key exist in the DocumentDB table.""" + aggregate_query = { + "aggregate": self._collection, + "pipeline": [ + { "$match": { + "key": key + } + }, + { + "$count": "key" + } + ] , "cursor": { "batchSize": 1 } } + + self._cursor.execute(f""" + SELECT jsonb_extract_path_text(results::jsonb, '0', 'key')::int AS count + FROM ( + SELECT cursorPage->>'cursor.firstBatch' AS results + FROM documentdb_api.aggregate_cursor_first_page('{self._database_name}', {Json(aggregate_query)}) + ); + """) + result = self._cursor.fetchone() + return result.get('count', 0) > 0 + + async def delete(self, key: str) -> None: + """Delete the item with the given filename key from the DocumentDB table.""" + try: + delete_query = { + "delete": self._collection, + "deletes": [ + { + "q": { + "key": key + }, "limit": 1 + } + ] + } + self._cursor.execute(f"SELECT documentdb_api.delete('{self._database_name}', {Json(delete_query)});") + self._connection.commit() + except Exception: + log.exception("Error deleting item %s", key) + + async def clear(self) -> None: + """Clear all contents from storage.""" + self._delete_collection() + self._create_collection() + + def keys(self) -> list[str]: + """Return the keys in the storage.""" + self._cursor.execute(f"SELECT key FROM {self._table_name}") + return [row["key"] for row in self._cursor.fetchall()] + + def child(self, name: str | None) -> PipelineStorage: + """Create a child storage instance.""" + return self + + def _get_prefix(self, key: str) -> str: + """Get the prefix of the filename key.""" + return key.split(".")[0] + + async def get_creation_date(self, key: str) -> str: + """Get the creation date of the item with the given key.""" + try: + find_query = { + "find" : self._collection, + "filter" : { + "key": key + } + } + + query = f""" + SELECT cursorPage->>'cursor.firstBatch' AS results + FROM 
documentdb_api.find_cursor_first_page('{self._database_name}', {Json(find_query)}); + """ + + self._cursor.execute(query) + item = self._cursor.fetchone() + items = json.loads(item.get('results', '[]')) + if len(items) == 0: + return "" + + item = items[0] + + if item: + return get_timestamp_formatted_with_local_tz( + datetime.fromisoformat(item["created_at"]) + ) + return "" + except Exception: + log.exception("Error getting key %s", key) + return "" + + +def create_documentdb_storage(**kwargs: Any) -> PipelineStorage: + """Create a DocumentDB storage instance.""" + log.info("Creating postgres storage") + database_name = kwargs["database_name"] + collection = kwargs["collection"] + user = kwargs["user"] + password = kwargs["password"] + host = kwargs.get("host", "localhost") + port = kwargs.get("port", 5432) + return DocumentDBPipelineStorage( + database_name=database_name, + collection=collection, + user=user, + password=password, + host=host, + port=port, + ) + + +def _create_progress_status( + num_loaded: int, num_filtered: int, num_total: int +) -> Progress: + return Progress( + total_items=num_total, + completed_items=num_loaded + num_filtered, + description=f"{num_loaded} files loaded ({num_filtered} filtered)", + ) diff --git a/graphrag/storage/factory.py b/graphrag/storage/factory.py index 8a6e0df4d6..a1d8c9601f 100644 --- a/graphrag/storage/factory.py +++ b/graphrag/storage/factory.py @@ -10,6 +10,7 @@ from graphrag.config.enums import OutputType from graphrag.storage.blob_pipeline_storage import create_blob_storage from graphrag.storage.cosmosdb_pipeline_storage import create_cosmosdb_storage +from graphrag.storage.documentdb_pipeline_storage import create_documentdb_storage from graphrag.storage.file_pipeline_storage import create_file_storage from graphrag.storage.memory_pipeline_storage import MemoryPipelineStorage @@ -43,6 +44,8 @@ def create_storage( return create_blob_storage(**kwargs) case OutputType.cosmosdb: return create_cosmosdb_storage(**kwargs) + case OutputType.documentdb: + return create_documentdb_storage(**kwargs) case OutputType.file: return create_file_storage(**kwargs) case OutputType.memory: diff --git a/graphrag/vector_stores/documentdb.py b/graphrag/vector_stores/documentdb.py new file mode 100644 index 0000000000..6d7e70e6af --- /dev/null +++ b/graphrag/vector_stores/documentdb.py @@ -0,0 +1,281 @@ +# Copyright (c) 2024 Microsoft Corporation. 
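For reference, outside the patch itself: a small end-to-end sketch of driving the new backend through StorageFactory, mirroring the integration tests further down. The credentials, host and port are placeholders; the async calls follow the PipelineStorage interface implemented above.

    import asyncio
    import json

    from graphrag.config.enums import OutputType
    from graphrag.storage.factory import StorageFactory

    async def main() -> None:
        storage = StorageFactory.create_storage(
            OutputType.documentdb,
            {
                "database_name": "documentdb",
                "collection": "pipeline_output",
                "user": "admin",
                "password": "admin",
                "host": "localhost",  # placeholder endpoint
                "port": 9712,
            },
        )
        # Round-trip a small JSON document through the collection.
        await storage.set("hello.json", json.dumps({"content": "Hello, World!"}))
        assert await storage.has("hello.json")
        print(await storage.get("hello.json"))
        print(await storage.get_creation_date("hello.json"))
        await storage.delete("hello.json")

    asyncio.run(main())
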
+# Licensed under the MIT License + +"""A package containing the DocumentDB vector store implementation.""" + +import json +from typing import Any +import psycopg2 +from psycopg2 import sql +from psycopg2.extras import Json + +from graphrag.model.types import TextEmbedder +from graphrag.vector_stores.base import ( + DEFAULT_VECTOR_SIZE, + BaseVectorStore, + VectorStoreDocument, + VectorStoreSearchResult, +) + +class DocumentDBVectoreStore(BaseVectorStore): + """Microsoft DocumentB (PostgreSQL) vector storage implementation.""" + + def __init__(self, **kwargs: Any) -> None: + super().__init__(**kwargs) + self.db_name = kwargs.get("database_name", "documentdb") + self.vector_options = kwargs.get("vector_options", { "kind": "vector-ivf", "similarity": "COS", "dimensions": DEFAULT_VECTOR_SIZE, "numLists": 3 }) + + def connect(self, **kwargs: Any) -> Any: + """Connect to DocumentDB (PostgreSQL) vector storage.""" + user = kwargs.get("user") + password = kwargs.get("password") + host = kwargs.get("host") + port = kwargs.get("port", 5432) + + if not all([self.db_name, user, password, host]): + raise ValueError("Database credentials must be provided.") + + self.db_connection = psycopg2.connect( + dbname="postgres", user=user, password=password, host=host, port=port + ) + cursor = self.db_connection.cursor() + + # replace this with a more general solution + x = sql.SQL("SET search_path TO documentdb_api, documentdb_core; SET documentdb_core.bsonUseEJson TO true;") + cursor.execute(x) + self.db_connection.commit() + + coll_query = { + "listCollections": 1, + "filter" : { + "collection_name": self.collection_name + } + } + try: + query = sql.SQL("SELECT cursorpage->>'cursor.firstBatch' FROM documentdb_api.list_collections_cursor_first_page(%s, %s);") + cursor.execute(query, [self.db_name, Json(coll_query)]) + result = cursor.fetchone() + if result and result[0]: + first_batch = json.loads(result[0]) + if self.collection_name in first_batch: + self.document_collection = self.collection_name + else: + self.document_collection = self.create_collection() + except Exception as e: + self.document_collection = self.collection_name + + def load_documents( + self, documents: list[VectorStoreDocument], overwrite: bool = True + ) -> None: + """Load documents into vector storage.""" + data = [ + { + "id": document.id, + "text": document.text, + "vector": document.vector, + "attributes": json.dumps(document.attributes), + } + for document in documents + if document.vector is not None + ] + + if len(data) == 0: + data = None + + # NOTE: If modifying the next section of code, ensure that the schema remains the same. + # The pyarrow format of the 'vector' field may change if the order of operations is changed + # and will break vector search. 
+ cursor = self.db_connection.cursor() + + if overwrite: + drop_query = sql.SQL("SELECT * FROM documentdb_api.drop_collection(%s, %s);") + create_query = sql.SQL("SELECT * FROM documentdb_api.create_collection(%s, %s);") + + cursor.execute(drop_query, [self.db_name, self.collection_name]) + self.db_connection.commit() + cursor.execute(create_query, [self.db_name, self.collection_name]) + self.db_connection.commit() + + self.create_vector_index("vector_index", "vector") + + if data: + for doc in data: + insert_query = sql.SQL("SELECT * FROM documentdb_api.insert_document(%s, %s, %s, %s, %s);") + cursor.execute(insert_query, [self.db_name, self.collection_name, Json(doc)]) + self.db_connection.commit() + + def create_collection(self) -> str: + """Create a collection in the database.""" + cursor = self.db_connection.cursor() + create_collection_query = sql.SQL("SELECT * FROM documentdb_api.create_collection(%s, %s);") + cursor.execute(create_collection_query, [self.db_name, self.collection_name]) + self.db_connection.commit() + + self.create_vector_index("vector_index", "vector") + return self.collection_name + + def create_vector_index(self, index_name: str, index_field: str) -> None: + """Create an index on the collection.""" + index_query = { + "createIndexes": self.collection_name, + "indexes": [ + { + "name": index_name, + "key": { + index_field: "cosmosSearch" + }, + "cosmosSearchOptions": self.vector_options + } + ] + } + cursor = self.db_connection.cursor() + # create_index_query = sql.SQL("SELECT documentdb_api.create_indexes_background(%s, %s);") + create_index_query = sql.SQL("SELECT documentdb_api_internal.create_indexes_non_concurrently(%s, %s, true);") + cursor.execute(create_index_query, [self.db_name, Json(index_query)]) + self.db_connection.commit() + + def filter_by_id(self, include_ids: list[str] | list[int]) -> Any: + """Build a query filter to filter documents by id.""" + if len(include_ids) == 0: + self.query_filter = None + else: + if isinstance(include_ids[0], str): + self.query_filter = { + "id": { + "$in": [f"'{id}'" for id in include_ids] + } + } + else: + self.query_filter = { + "id": { + "$in": include_ids + } + } + + return self.query_filter + + def similarity_search_by_vector( + self, query_embedding: list[float], k: int = 10, **kwargs: Any + ) -> list[VectorStoreSearchResult]: + """Perform a vector-based similarity search.""" + search_filter = { + "aggregate": self.collection_name, + "pipeline": [ + { + "$search": { + "cosmosSearch": { + "vector": query_embedding, + "path": "vector", + "k": k, + }, + "returnStoredSource": True + } + } + ], + "cursor": { + "batchSize": k + } + } + + if self.query_filter: + search_filter["pipeline"].insert(0, {"filter": self.query_filter}) + + cursor = self.db_connection.cursor() + search_query = sql.SQL("SELECT cursorpage->>'cursor.firstBatch' FROM documentdb_api.aggregate_cursor_first_page(%s, %s);") + cursor.execute(search_query, [self.db_name, Json(search_filter)]) + results = cursor.fetchall() + docs = [json.loads(result[0]) for result in results if result[0]] + + return [ + VectorStoreSearchResult( + document=VectorStoreDocument( + id=doc["id"], + text=doc["text"], + vector=doc["vector"], + attributes=json.loads(doc["attributes"]), + ), + score=1 - abs(float(doc["_distance"])), + ) + for doc in docs + ] + + def similarity_search_by_text( + self, text: str, text_embedder: TextEmbedder, k: int = 10, **kwargs: Any + ) -> list[VectorStoreSearchResult]: + """Perform a similarity search using a given input text.""" + 
query_embedding = text_embedder(text) + if query_embedding: + return self.similarity_search_by_vector(query_embedding, k) + return [] + + def search_by_id(self, id: str) -> VectorStoreDocument: + """Search for a document by id.""" + find_filter = { + "find": self.collection_name, + "filter": { + "id": id + } + } + + cursor = self.db_connection.cursor() + search_query = sql.SQL("SELECT cursorpage->>'cursor.firstBatch' FROM documentdb_api.find_cursor_first_page(%s, %s);") + cursor.execute(search_query, [self.db_name, Json(find_filter)]) + result = cursor.fetchone() + doc = json.loads(result[0]) if result and result[0] else None + + if doc: + return VectorStoreDocument( + id=doc[0]["id"], + text=doc[0]["text"], + vector=doc[0]["vector"], + attributes=json.loads(doc[0]["attributes"]), + ) + return VectorStoreDocument(id=id, text=None, vector=None) + + +if __name__ == "__main__": + from graphrag.model.entity import Entity + entities = [ + Entity( + id="2da37c7a-50a8-44d4-aa2c-fd401e19976c", + short_id="sid1", + title="t1", + rank=2, + ), + Entity( + id="c4f93564-4507-4ee4-b102-98add401a965", + short_id="sid2", + title="t22", + rank=4, + ), + Entity( + id="7c6f2bc9-47c9-4453-93a3-d2e174a02cd9", + short_id="sid3", + title="t333", + rank=1, + ), + Entity( + id="8fd6d72a-8e9d-4183-8a97-c38bcc971c83", + short_id="sid4", + title="t4444", + rank=3, + ), + ] + documents = [VectorStoreDocument(id=entity.id, text=entity.title, vector=[0]) for entity in entities] + + + kwargs = { + "collection_name": "default", + "database_name": "documentdb", + "vector_options": { "kind": "vector-ivf", "similarity": "COS", "dimensions": DEFAULT_VECTOR_SIZE, "numLists": 3 } + } + + store = DocumentDBVectoreStore(**kwargs) + store.connect(**{ + "user": "admin", + "password": "admin", + "host": "host.docker.internal", + "port": 9712 + }) + store.load_documents(documents, overwrite=False) + store.search_by_id("1") \ No newline at end of file diff --git a/graphrag/vector_stores/factory.py b/graphrag/vector_stores/factory.py index 1c37316d0c..066a1d2d94 100644 --- a/graphrag/vector_stores/factory.py +++ b/graphrag/vector_stores/factory.py @@ -9,6 +9,7 @@ from graphrag.vector_stores.azure_ai_search import AzureAISearchVectorStore from graphrag.vector_stores.base import BaseVectorStore from graphrag.vector_stores.cosmosdb import CosmosDBVectoreStore +from graphrag.vector_stores.documentdb import DocumentDBVectoreStore from graphrag.vector_stores.lancedb import LanceDBVectorStore @@ -18,6 +19,7 @@ class VectorStoreType(str, Enum): LanceDB = "lancedb" AzureAISearch = "azure_ai_search" CosmosDB = "cosmosdb" + DocumentDB = "documentdb" class VectorStoreFactory: @@ -45,6 +47,8 @@ def create_vector_store( return AzureAISearchVectorStore(**kwargs) case VectorStoreType.CosmosDB: return CosmosDBVectoreStore(**kwargs) + case VectorStoreType.DocumentDB: + return DocumentDBVectoreStore(**kwargs) case _: if vector_store_type in cls.vector_store_types: return cls.vector_store_types[vector_store_type](**kwargs) diff --git a/tests/integration/storage/test_documentdb_storage.py b/tests/integration/storage/test_documentdb_storage.py new file mode 100644 index 0000000000..c555953d17 --- /dev/null +++ b/tests/integration/storage/test_documentdb_storage.py @@ -0,0 +1,138 @@ +# Copyright (c) 2024 Microsoft Corporation. 
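Illustration only, not part of the diff: besides the inline smoke test above, the store can be created through VectorStoreFactory. The sketch below assumes a reachable DocumentDB endpoint and uses a stand-in embedder, so all connection values are placeholders.

    from graphrag.vector_stores.base import DEFAULT_VECTOR_SIZE, VectorStoreDocument
    from graphrag.vector_stores.factory import VectorStoreFactory, VectorStoreType

    store = VectorStoreFactory.create_vector_store(
        VectorStoreType.DocumentDB,
        {"collection_name": "default", "database_name": "documentdb"},
    )
    store.connect(user="admin", password="admin", host="localhost", port=9712)

    # Vectors must match the index "dimensions" (DEFAULT_VECTOR_SIZE unless overridden).
    store.load_documents([
        VectorStoreDocument(id="doc-1", text="hello graphrag", vector=[0.1] * DEFAULT_VECTOR_SIZE),
    ])

    # A stand-in embedder; in the pipeline a real TextEmbedder is injected here.
    results = store.similarity_search_by_text(
        "hello", text_embedder=lambda _text: [0.1] * DEFAULT_VECTOR_SIZE, k=1
    )
    print([r.document.id for r in results])
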
+# Licensed under the MIT License +"""CosmosDB Storage Tests.""" + +import json +import re +import sys +from datetime import datetime + +import pytest + +from graphrag.storage.documentdb_pipeline_storage import DocumentDBPipelineStorage + + +async def test_find(): + storage = DocumentDBPipelineStorage( + database_name = "documentdb", + collection = "testfindtable", + user = "admin", + password = "admin", + host = "host.docker.internal", + port = 9712, + ) + try: + try: + items = list(storage.find(file_pattern=re.compile(r".*\.json$"))) + items = [item[0] for item in items] + assert items == [] + + json_content = { + "content": "Merry Christmas!", + } + await storage.set( + "christmas.json", json.dumps(json_content), encoding="utf-8" + ) + items = list(storage.find(file_pattern=re.compile(r".*\.json$"))) + items = [item[0] for item in items] + assert items == ["christmas.json"] + + json_content = { + "content": "Hello, World!", + } + await storage.set("test.json", json.dumps(json_content), encoding="utf-8") + items = list(storage.find(file_pattern=re.compile(r".*\.json$"))) + items = [item[0] for item in items] + assert items == ["christmas.json", "test.json"] + + items = list(storage.find(file_pattern=re.compile(r".*\.json$"), file_filter={"key": "test.json"})) + items = [item[0] for item in items] + assert items == ["test.json"] + + output = await storage.get("test.json") + output_json = json.loads(output) + assert output_json["content"] == "Hello, World!" + + json_exists = await storage.has("christmas.json") + assert json_exists is True + json_exists = await storage.has("easter.json") + assert json_exists is False + finally: + await storage.delete("test.json") + output = await storage.get("test.json") + assert output is None + finally: + await storage.clear() + + +async def test_child(): + storage = DocumentDBPipelineStorage( + database_name = "documentdb", + collection = "testfindtable", + user = "admin", + password = "admin", + host = "host.docker.internal", + port = 9712, + ) + try: + child_storage = storage.child("child") + assert type(child_storage) is DocumentDBPipelineStorage + finally: + await storage.clear() + + +async def test_clear(): + storage = DocumentDBPipelineStorage( + database_name = "documentdb", + collection = "testfindtable", + user = "admin", + password = "admin", + host = "host.docker.internal", + port = 9712, + ) + try: + json_exists = { + "content": "Merry Christmas!", + } + await storage.set("christmas.json", json.dumps(json_exists), encoding="utf-8") + json_exists = { + "content": "Happy Easter!", + } + await storage.set("easter.json", json.dumps(json_exists), encoding="utf-8") + await storage.clear() + + items = list(storage.find(file_pattern=re.compile(r".*\.json$"))) + items = [item[0] for item in items] + assert items == [] + + output = await storage.get("easter.json") + assert output is None + finally: + # await storage.clear() + print("Table cleared") + + +async def test_get_creation_date(): + storage = DocumentDBPipelineStorage( + database_name = "documentdb", + collection = "testfindtable", + user = "admin", + password = "admin", + host = "host.docker.internal", + port = 9712, + ) + try: + json_content = { + "content": "Happy Easter!", + } + await storage.set("easter.json", json.dumps(json_content), encoding="utf-8") + + creation_date = await storage.get_creation_date("easter.json") + + datetime_format = "%Y-%m-%d %H:%M:%S %z" + parsed_datetime = datetime.strptime(creation_date, datetime_format).astimezone() + + assert 
parsed_datetime.strftime(datetime_format) == creation_date + + finally: + await storage.clear() diff --git a/tests/integration/storage/test_factory.py b/tests/integration/storage/test_factory.py index 5be62e9016..b2d5e79e26 100644 --- a/tests/integration/storage/test_factory.py +++ b/tests/integration/storage/test_factory.py @@ -12,6 +12,7 @@ from graphrag.config.enums import OutputType from graphrag.storage.blob_pipeline_storage import BlobPipelineStorage from graphrag.storage.cosmosdb_pipeline_storage import CosmosDBPipelineStorage +from graphrag.storage.documentdb_pipeline_storage import DocumentDBPipelineStorage from graphrag.storage.factory import StorageFactory from graphrag.storage.file_pipeline_storage import FilePipelineStorage from graphrag.storage.memory_pipeline_storage import MemoryPipelineStorage @@ -47,6 +48,19 @@ def test_create_cosmosdb_storage(): storage = StorageFactory.create_storage(OutputType.cosmosdb, kwargs) assert isinstance(storage, CosmosDBPipelineStorage) +def test_create_documentdb_storage(): + kwargs = { + "type": "documentdb", + "database_name": "postgres", + "collection": "testtable", + "user": "admin", + "password": "admin", + "host": "host.docker.internal", + "port": 9712, + } + storage = StorageFactory.create_storage(OutputType.documentdb, kwargs) + assert isinstance(storage, DocumentDBPipelineStorage) + def test_create_file_storage(): kwargs = {"type": "file", "base_dir": "/tmp/teststorage"} diff --git a/tests/unit/config/utils.py b/tests/unit/config/utils.py index b1a92e6a89..d26cc069fe 100644 --- a/tests/unit/config/utils.py +++ b/tests/unit/config/utils.py @@ -89,6 +89,7 @@ "container_name": None, "storage_account_blob_url": None, "cosmosdb_account_url": None, + "documentdb_account_url": None, }, "input": { "type": defs.INPUT_TYPE, @@ -321,6 +322,7 @@ def assert_output_configs(actual: OutputConfig, expected: OutputConfig) -> None: assert expected.container_name == actual.container_name assert expected.storage_account_blob_url == actual.storage_account_blob_url assert expected.cosmosdb_account_url == actual.cosmosdb_account_url + assert expected.documentdb_account_url == actual.documentdb_account_url def assert_update_output_configs(actual: OutputConfig, expected: OutputConfig) -> None: @@ -330,6 +332,7 @@ def assert_update_output_configs(actual: OutputConfig, expected: OutputConfig) - assert expected.container_name == actual.container_name assert expected.storage_account_blob_url == actual.storage_account_blob_url assert expected.cosmosdb_account_url == actual.cosmosdb_account_url + assert expected.documentdb_account_url == actual.documentdb_account_url def assert_cache_configs(actual: CacheConfig, expected: CacheConfig) -> None: @@ -339,6 +342,7 @@ def assert_cache_configs(actual: CacheConfig, expected: CacheConfig) -> None: assert actual.container_name == expected.container_name assert actual.storage_account_blob_url == expected.storage_account_blob_url assert actual.cosmosdb_account_url == expected.cosmosdb_account_url + assert actual.documentdb_account_url == expected.documentdb_account_url def assert_input_configs(actual: InputConfig, expected: InputConfig) -> None: From f0dd652a9519a67f16c595b1cbb2aa9103b33980 Mon Sep 17 00:00:00 2001 From: Gijs Segerink Date: Thu, 20 Feb 2025 08:38:28 +0100 Subject: [PATCH 2/3] .vscode restore --- .gitignore | 1 - .vscode/extensions.json | 12 ++++++++++ .vscode/launch.json | 39 +++++++++++++++++++++++++++++++ .vscode/settings.json | 52 +++++++++++++++++++++++++++++++++++++++++ 4 files changed, 103 
insertions(+), 1 deletion(-) create mode 100644 .vscode/extensions.json create mode 100644 .vscode/launch.json create mode 100644 .vscode/settings.json diff --git a/.gitignore b/.gitignore index e75a076761..707bc44711 100644 --- a/.gitignore +++ b/.gitignore @@ -20,7 +20,6 @@ output/lancedb venv/ .conda .tmp -.vscode .env build.zip diff --git a/.vscode/extensions.json b/.vscode/extensions.json new file mode 100644 index 0000000000..2e5e67a214 --- /dev/null +++ b/.vscode/extensions.json @@ -0,0 +1,12 @@ +{ + "recommendations": [ + "arcanis.vscode-zipfs", + "ms-python.python", + "charliermarsh.ruff", + "ms-python.vscode-pylance", + "bierner.markdown-mermaid", + "streetsidesoftware.code-spell-checker", + "ronnidc.nunjucks", + "lucien-martijn.parquet-visualizer", + ] +} diff --git a/.vscode/launch.json b/.vscode/launch.json new file mode 100644 index 0000000000..2167063966 --- /dev/null +++ b/.vscode/launch.json @@ -0,0 +1,39 @@ +{ + "_comment": "Use this file to configure the graphrag project for debugging. You may create other configuration profiles based on these or select one below to use.", + "version": "0.2.0", + "configurations": [ + { + "name": "Indexer", + "type": "debugpy", + "request": "launch", + "module": "poetry", + "args": [ + "poe", "index", + "--root", "" + ], + }, + { + "name": "Query", + "type": "debugpy", + "request": "launch", + "module": "poetry", + "args": [ + "poe", "query", + "--root", "", + "--method", "global", + "--query", "What are the top themes in this story", + ] + }, + { + "name": "Prompt Tuning", + "type": "debugpy", + "request": "launch", + "module": "poetry", + "args": [ + "poe", "prompt-tune", + "--config", + "/settings.yaml", + ] + } + ] +} \ No newline at end of file diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 0000000000..0b678d5d95 --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,52 @@ +{ + "search.exclude": { + "**/.yarn": true, + "**/.pnp.*": true + }, + "editor.formatOnSave": false, + "eslint.nodePath": ".yarn/sdks", + "typescript.tsdk": ".yarn/sdks/typescript/lib", + "typescript.enablePromptUseWorkspaceTsdk": true, + "javascript.preferences.importModuleSpecifier": "relative", + "javascript.preferences.importModuleSpecifierEnding": "js", + "typescript.preferences.importModuleSpecifier": "relative", + "typescript.preferences.importModuleSpecifierEnding": "js", + "explorer.fileNesting.enabled": true, + "explorer.fileNesting.patterns": { + "*.ts": "${capture}.ts, ${capture}.hooks.ts, ${capture}.hooks.tsx, ${capture}.contexts.ts, ${capture}.stories.tsx, ${capture}.story.tsx, ${capture}.spec.tsx, ${capture}.base.ts, ${capture}.base.tsx, ${capture}.types.ts, ${capture}.styles.ts, ${capture}.styles.tsx, ${capture}.utils.ts, ${capture}.utils.tsx, ${capture}.constants.ts, ${capture}.module.scss, ${capture}.module.css, ${capture}.md", + "*.js": "${capture}.js.map, ${capture}.min.js, ${capture}.d.ts", + "*.jsx": "${capture}.js", + "*.tsx": "${capture}.ts, ${capture}.hooks.ts, ${capture}.hooks.tsx, ${capture}.contexts.ts, ${capture}.stories.tsx, ${capture}.story.tsx, ${capture}.spec.tsx, ${capture}.base.ts, ${capture}.base.tsx, ${capture}.types.ts, ${capture}.styles.ts, ${capture}.styles.tsx, ${capture}.utils.ts, ${capture}.utils.tsx, ${capture}.constants.ts, ${capture}.module.scss, ${capture}.module.css, ${capture}.md, ${capture}.css", + "tsconfig.json": "tsconfig.*.json", + "package.json": "package-lock.json, turbo.json, tsconfig.json, rome.json, biome.json, .npmignore, dictionary.txt, cspell.config.yaml", + 
"README.md": "*.md, LICENSE, CODEOWNERS", + ".eslintrc": ".eslintignore", + ".prettierrc": ".prettierignore", + ".gitattributes": ".gitignore", + ".yarnrc.yml": "yarn.lock, .pnp.*", + "jest.config.js": "jest.setup.mjs", + "pyproject.toml": "poetry.lock, poetry.toml, mkdocs.yaml", + "cspell.config.yaml": "dictionary.txt" + }, + "azureFunctions.postDeployTask": "npm install (functions)", + "azureFunctions.projectLanguage": "TypeScript", + "azureFunctions.projectRuntime": "~4", + "debug.internalConsoleOptions": "neverOpen", + "azureFunctions.preDeployTask": "npm prune (functions)", + "appService.zipIgnorePattern": [ + "node_modules{,/**}", + ".vscode{,/**}" + ], + "python.defaultInterpreterPath": "python/services/.venv/bin/python", + "python.languageServer": "Pylance", + "cSpell.customDictionaries": { + "project-words": { + "name": "project-words", + "path": "${workspaceRoot}/dictionary.txt", + "description": "Words used in this project", + "addWords": true + }, + "custom": true, // Enable the `custom` dictionary + "internal-terms": true // Disable the `internal-terms` dictionary + } +} From df161c4ce962bbb5119692ea3903236226666120 Mon Sep 17 00:00:00 2001 From: Gijs Segerink Date: Thu, 20 Feb 2025 11:00:09 +0100 Subject: [PATCH 3/3] test keys --- .vscode/settings.json | 7 +- .../storage/documentdb_pipeline_storage.py | 138 +++++++------ graphrag/vector_stores/documentdb.py | 192 +++++++----------- .../storage/test_documentdb_storage.py | 24 +++ 4 files changed, 175 insertions(+), 186 deletions(-) diff --git a/.vscode/settings.json b/.vscode/settings.json index 0b678d5d95..2aa700444c 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -48,5 +48,10 @@ }, "custom": true, // Enable the `custom` dictionary "internal-terms": true // Disable the `internal-terms` dictionary - } + }, + "python.testing.pytestArgs": [ + "tests" + ], + "python.testing.unittestEnabled": false, + "python.testing.pytestEnabled": true } diff --git a/graphrag/storage/documentdb_pipeline_storage.py b/graphrag/storage/documentdb_pipeline_storage.py index 3f7c46c60b..54e4f98e04 100644 --- a/graphrag/storage/documentdb_pipeline_storage.py +++ b/graphrag/storage/documentdb_pipeline_storage.py @@ -3,15 +3,17 @@ """Azure DocumentDB Storage implementation of PipelineStorage.""" +import sys import json import logging import re from collections.abc import Iterator from datetime import datetime, timezone -from io import BytesIO, StringIO -from typing import Any +from typing import Any, Dict, List, Optional, Pattern, Tuple, Union import pandas as pd +import psycopg2 +from psycopg2.extras import RealDictCursor, Json from graphrag.logger.base import ProgressLogger from graphrag.logger.progress import Progress @@ -20,12 +22,8 @@ get_timestamp_formatted_with_local_tz, ) -import psycopg2 -from psycopg2.extras import RealDictCursor, Json - log = logging.getLogger(__name__) - class DocumentDBPipelineStorage(PipelineStorage): """The DocumentDB Storage Implementation.""" @@ -44,7 +42,7 @@ def __init__( host: str = "localhost", port: int = 5432, encoding: str = "utf-8", - ): + ) -> None: """Initialize the DocumentDB Storage.""" self._connection = psycopg2.connect( dbname="postgres", @@ -72,24 +70,24 @@ def _create_collection(self) -> None: """Create the table if it doesn't exist.""" self._cursor.execute(f""" SELECT documentdb_api.create_collection('{self._database_name}', '{self._collection}'); - """) + """) self._connection.commit() def _delete_collection(self) -> None: """Delete the table if it exists.""" self._cursor.execute(f""" 
SELECT documentdb_api.drop_collection('{self._database_name}', '{self._collection}'); - """) + """) self._connection.commit() def find( self, - file_pattern: re.Pattern[str], - base_dir: str | None = None, - progress: ProgressLogger | None = None, - file_filter: dict[str, Any] | None = None, - max_count=-1, - ) -> Iterator[tuple[str, dict[str, Any]]]: + file_pattern: Pattern[str], + base_dir: Optional[str] = None, + progress: Optional[ProgressLogger] = None, + file_filter: Optional[Dict[str, Any]] = None, + max_count: int = -1, + ) -> Iterator[Tuple[str, Dict[str, Any]]]: """Find documents in a DocumentDB table using a file pattern regex and custom file filter (optional).""" base_dir = base_dir or "" log.info( @@ -100,16 +98,16 @@ def find( if not self._connection or not self._cursor: return - + try: - find_query = { - "find" : self._collection, - "filter" : { + find_query = { + "find": self._collection, + "filter": { "$and": [ { "key": { "$regex": file_pattern.pattern - } + } } ] } @@ -135,7 +133,7 @@ def find( if num_total == 0: return - + num_filtered = 0 for item in items: key = item["key"] @@ -159,18 +157,16 @@ def find( "An error occurred while searching for documents in Document DB." ) - async def get( - self, key: str, as_bytes: bool | None = None, encoding: str | None = None - ) -> Any: + async def get(self, key: str, as_bytes: Optional[bool] = None, encoding: Optional[str] = None) -> Any: """Fetch an item from the table that matches the given key.""" try: - find_query = { - "find" : self._collection, - "filter" : { + find_query = { + "find": self._collection, + "filter": { "key": key } } - + query = f""" SELECT cursorPage->>'cursor.firstBatch' AS results FROM documentdb_api.find_cursor_first_page('{self._database_name}', {Json(find_query)}); @@ -181,7 +177,7 @@ async def get( items = json.loads(item.get('results', '[]')) if len(items) == 0: return None - + item = items[0] if item: @@ -191,10 +187,10 @@ async def get( log.exception("Error reading item %s", key) return None - async def set(self, key: str, value: Any, encoding: str | None = None) -> None: + async def set(self, key: str, value: Any, encoding: Optional[str] = None) -> None: """Insert or update the contents of a file into the DocumentDB table for the given filename key.""" try: - insert_query = { + insert_query = { "key": key, "value": json.loads(value), "created_at": datetime.now(timezone.utc).isoformat() @@ -208,17 +204,14 @@ async def set(self, key: str, value: Any, encoding: str | None = None) -> None: async def has(self, key: str) -> bool: """Check if the contents of the given filename key exist in the DocumentDB table.""" - aggregate_query = { - "aggregate": self._collection, - "pipeline": [ - { "$match": { - "key": key - } - }, - { - "$count": "key" - } - ] , "cursor": { "batchSize": 1 } } + aggregate_query = { + "aggregate": self._collection, + "pipeline": [ + {"$match": {"key": key}}, + {"$count": "key"} + ], + "cursor": {"batchSize": 1} + } self._cursor.execute(f""" SELECT jsonb_extract_path_text(results::jsonb, '0', 'key')::int AS count @@ -226,7 +219,7 @@ async def has(self, key: str) -> bool: SELECT cursorPage->>'cursor.firstBatch' AS results FROM documentdb_api.aggregate_cursor_first_page('{self._database_name}', {Json(aggregate_query)}) ); - """) + """) result = self._cursor.fetchone() return result.get('count', 0) > 0 @@ -234,12 +227,11 @@ async def delete(self, key: str) -> None: """Delete the item with the given filename key from the DocumentDB table.""" try: delete_query = { - "delete": 
self._collection, + "delete": self._collection, "deletes": [ { - "q": { - "key": key - }, "limit": 1 + "q": {"key": key}, + "limit": 1 } ] } @@ -253,12 +245,40 @@ async def clear(self) -> None: self._delete_collection() self._create_collection() - def keys(self) -> list[str]: + def keys(self) -> List[str]: """Return the keys in the storage.""" - self._cursor.execute(f"SELECT key FROM {self._table_name}") - return [row["key"] for row in self._cursor.fetchall()] + count_query = { + "aggregate": self._collection, + "pipeline": [ + { "$count": "key" } + ], + "cursor": { "batchSize": 1 } + } - def child(self, name: str | None) -> PipelineStorage: + self._cursor.execute(f""" + SELECT jsonb_extract_path_text(results::jsonb, '0', 'key')::int AS batch_size + FROM ( + SELECT cursorPage->>'cursor.firstBatch' AS results + FROM documentdb_api.aggregate_cursor_first_page( + 'documentdb', + {Json(count_query)} + ) + ) subquery; + """) + result = self._cursor.fetchone() + + keys_query = { + "aggregate": self._collection, + "pipeline": [ + { "$group": + { "_id": "$key", } + } + ] , "cursor": { "batchSize": result.get('batch_size', sys.maxsize) } } + self._cursor.execute(f"SELECT cursorpage->>'cursor.firstBatch' AS result FROM documentdb_api.aggregate_cursor_first_page('documentdb', {Json(keys_query)});") + result = self._cursor.fetchone() + return [row["_id"] for row in json.loads(result['result'])] + + def child(self, name: Optional[str]) -> PipelineStorage: """Create a child storage instance.""" return self @@ -269,13 +289,13 @@ def _get_prefix(self, key: str) -> str: async def get_creation_date(self, key: str) -> str: """Get the creation date of the item with the given key.""" try: - find_query = { - "find" : self._collection, - "filter" : { + find_query = { + "find": self._collection, + "filter": { "key": key } } - + query = f""" SELECT cursorPage->>'cursor.firstBatch' AS results FROM documentdb_api.find_cursor_first_page('{self._database_name}', {Json(find_query)}); @@ -286,7 +306,7 @@ async def get_creation_date(self, key: str) -> str: items = json.loads(item.get('results', '[]')) if len(items) == 0: return "" - + item = items[0] if item: @@ -298,7 +318,6 @@ async def get_creation_date(self, key: str) -> str: log.exception("Error getting key %s", key) return "" - def create_documentdb_storage(**kwargs: Any) -> PipelineStorage: """Create a DocumentDB storage instance.""" log.info("Creating postgres storage") @@ -317,10 +336,7 @@ def create_documentdb_storage(**kwargs: Any) -> PipelineStorage: port=port, ) - -def _create_progress_status( - num_loaded: int, num_filtered: int, num_total: int -) -> Progress: +def _create_progress_status(num_loaded: int, num_filtered: int, num_total: int) -> Progress: return Progress( total_items=num_total, completed_items=num_loaded + num_filtered, diff --git a/graphrag/vector_stores/documentdb.py b/graphrag/vector_stores/documentdb.py index 6d7e70e6af..0cc235f39c 100644 --- a/graphrag/vector_stores/documentdb.py +++ b/graphrag/vector_stores/documentdb.py @@ -4,10 +4,10 @@ """A package containing the DocumentDB vector store implementation.""" import json -from typing import Any +from typing import Any, List, Union import psycopg2 from psycopg2 import sql -from psycopg2.extras import Json +from psycopg2.extras import RealDictCursor, Json from graphrag.model.types import TextEmbedder from graphrag.vector_stores.base import ( @@ -17,15 +17,33 @@ VectorStoreSearchResult, ) -class DocumentDBVectoreStore(BaseVectorStore): - """Microsoft DocumentB (PostgreSQL) vector storage 
implementation.""" +class DocumentDBVectorStore(BaseVectorStore): + """Microsoft DocumentDB (PostgreSQL) vector storage implementation.""" + + _connection: psycopg2.extensions.connection + _cursor: psycopg2.extensions.cursor def __init__(self, **kwargs: Any) -> None: super().__init__(**kwargs) self.db_name = kwargs.get("database_name", "documentdb") - self.vector_options = kwargs.get("vector_options", { "kind": "vector-ivf", "similarity": "COS", "dimensions": DEFAULT_VECTOR_SIZE, "numLists": 3 }) - def connect(self, **kwargs: Any) -> Any: + kind = kwargs.get("kind", "vector-ivf") # Possible values: vector-hnsw, vector-ivf. + similarity = kwargs.get("similarity", "COS") # COS for cosine similarity, L2 for Euclidean distance, or IP for inner product. + dimensions = kwargs.get("dimensions", DEFAULT_VECTOR_SIZE) + num_lists = kwargs.get("numLists", 100) # Only IVF index requires this parameter. + m = kwargs.get("m", 16) # Only HNSW index requires this parameter. + ef_construction = kwargs.get("efConstruction", 64) # Only HNSW index requires this parameter. + + self.vector_options = kwargs.get("vector_options", { + "kind": kind, + "similarity": similarity, + "dimensions": dimensions, + "numLists": num_lists, + "m": m, + "efConstruction": ef_construction + }) + + def connect(self, **kwargs: Any) -> None: """Connect to DocumentDB (PostgreSQL) vector storage.""" user = kwargs.get("user") password = kwargs.get("password") @@ -34,39 +52,23 @@ def connect(self, **kwargs: Any) -> Any: if not all([self.db_name, user, password, host]): raise ValueError("Database credentials must be provided.") - - self.db_connection = psycopg2.connect( - dbname="postgres", user=user, password=password, host=host, port=port + + self._connection = psycopg2.connect( + dbname="postgres", + user=user, + password=password, + host=host, + port=port ) - cursor = self.db_connection.cursor() + self._cursor = self._connection.cursor(cursor_factory=RealDictCursor) - # replace this with a more general solution - x = sql.SQL("SET search_path TO documentdb_api, documentdb_core; SET documentdb_core.bsonUseEJson TO true;") - cursor.execute(x) - self.db_connection.commit() + set_first = 'SET search_path TO documentdb_api, documentdb_core;' + self._cursor.execute(set_first) + self._connection.commit() - coll_query = { - "listCollections": 1, - "filter" : { - "collection_name": self.collection_name - } - } - try: - query = sql.SQL("SELECT cursorpage->>'cursor.firstBatch' FROM documentdb_api.list_collections_cursor_first_page(%s, %s);") - cursor.execute(query, [self.db_name, Json(coll_query)]) - result = cursor.fetchone() - if result and result[0]: - first_batch = json.loads(result[0]) - if self.collection_name in first_batch: - self.document_collection = self.collection_name - else: - self.document_collection = self.create_collection() - except Exception as e: - self.document_collection = self.collection_name - - def load_documents( - self, documents: list[VectorStoreDocument], overwrite: bool = True - ) -> None: + self.document_collection = self.create_collection() + + def load_documents(self, documents: List[VectorStoreDocument], overwrite: bool = True) -> None: """Load documents into vector storage.""" data = [ { @@ -82,34 +84,28 @@ def load_documents( if len(data) == 0: data = None - # NOTE: If modifying the next section of code, ensure that the schema remains the same. - # The pyarrow format of the 'vector' field may change if the order of operations is changed - # and will break vector search. 
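A hedged sketch, not part of the diff: the kwargs read in __init__ above make the vector index configurable, for example selecting an HNSW index instead of the default IVF one. Keyword names follow the kwargs.get(...) calls, the class name assumes the rename in this commit, and the values are illustrative.

    from graphrag.vector_stores.documentdb import DocumentDBVectorStore

    store = DocumentDBVectorStore(
        collection_name="default",
        database_name="documentdb",
        kind="vector-hnsw",   # "vector-ivf" (default) or "vector-hnsw"
        similarity="COS",     # COS, L2 or IP
        dimensions=1536,      # illustrative embedding size
        m=16,                 # HNSW only
        efConstruction=64,    # HNSW only
    )
    # numLists is the IVF-specific counterpart to m / efConstruction.
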
- cursor = self.db_connection.cursor() - if overwrite: drop_query = sql.SQL("SELECT * FROM documentdb_api.drop_collection(%s, %s);") create_query = sql.SQL("SELECT * FROM documentdb_api.create_collection(%s, %s);") - cursor.execute(drop_query, [self.db_name, self.collection_name]) - self.db_connection.commit() - cursor.execute(create_query, [self.db_name, self.collection_name]) - self.db_connection.commit() + self._cursor.execute(drop_query, [self.db_name, self.collection_name]) + self._connection.commit() + self._cursor.execute(create_query, [self.db_name, self.collection_name]) + self._connection.commit() self.create_vector_index("vector_index", "vector") - + if data: for doc in data: - insert_query = sql.SQL("SELECT * FROM documentdb_api.insert_document(%s, %s, %s, %s, %s);") - cursor.execute(insert_query, [self.db_name, self.collection_name, Json(doc)]) - self.db_connection.commit() + insert_query = sql.SQL("SELECT documentdb_api.insert_one(%s, %s, %s);") + self._cursor.execute(insert_query, [self.db_name, self.collection_name, Json(doc)]) + self._connection.commit() def create_collection(self) -> str: """Create a collection in the database.""" - cursor = self.db_connection.cursor() create_collection_query = sql.SQL("SELECT * FROM documentdb_api.create_collection(%s, %s);") - cursor.execute(create_collection_query, [self.db_name, self.collection_name]) - self.db_connection.commit() + self._cursor.execute(create_collection_query, [self.db_name, self.collection_name]) + self._connection.commit() self.create_vector_index("vector_index", "vector") return self.collection_name @@ -128,11 +124,11 @@ def create_vector_index(self, index_name: str, index_field: str) -> None: } ] } - cursor = self.db_connection.cursor() - # create_index_query = sql.SQL("SELECT documentdb_api.create_indexes_background(%s, %s);") + + # see bug issue: https://github.com/microsoft/documentdb/issues/63 create_index_query = sql.SQL("SELECT documentdb_api_internal.create_indexes_non_concurrently(%s, %s, true);") - cursor.execute(create_index_query, [self.db_name, Json(index_query)]) - self.db_connection.commit() + self._cursor.execute(create_index_query, [self.db_name, Json(index_query)]) + self._connection.commit() def filter_by_id(self, include_ids: list[str] | list[int]) -> Any: """Build a query filter to filter documents by id.""" @@ -141,22 +137,20 @@ def filter_by_id(self, include_ids: list[str] | list[int]) -> Any: else: if isinstance(include_ids[0], str): self.query_filter = { - "id": { + "id": { "$in": [f"'{id}'" for id in include_ids] } } else: self.query_filter = { - "id": { - "$in": include_ids + "id": { + "$in": include_ids } } - + return self.query_filter - def similarity_search_by_vector( - self, query_embedding: list[float], k: int = 10, **kwargs: Any - ) -> list[VectorStoreSearchResult]: + def similarity_search_by_vector(self, query_embedding: List[float], k: int = 10, **kwargs: Any) -> List[VectorStoreSearchResult]: """Perform a vector-based similarity search.""" search_filter = { "aggregate": self.collection_name, @@ -180,10 +174,10 @@ def similarity_search_by_vector( if self.query_filter: search_filter["pipeline"].insert(0, {"filter": self.query_filter}) - cursor = self.db_connection.cursor() + self._cursor = self._connection.cursor() search_query = sql.SQL("SELECT cursorpage->>'cursor.firstBatch' FROM documentdb_api.aggregate_cursor_first_page(%s, %s);") - cursor.execute(search_query, [self.db_name, Json(search_filter)]) - results = cursor.fetchall() + self._cursor.execute(search_query, 
[self.db_name, Json(search_filter)]) + results = self._cursor.fetchall() docs = [json.loads(result[0]) for result in results if result[0]] return [ @@ -194,14 +188,12 @@ def similarity_search_by_vector( vector=doc["vector"], attributes=json.loads(doc["attributes"]), ), - score=1 - abs(float(doc["_distance"])), + score=abs(float(doc['__cosmos_meta__']['score'])), ) - for doc in docs + for doc in docs[0] ] - def similarity_search_by_text( - self, text: str, text_embedder: TextEmbedder, k: int = 10, **kwargs: Any - ) -> list[VectorStoreSearchResult]: + def similarity_search_by_text(self, text: str, text_embedder: TextEmbedder, k: int = 10, **kwargs: Any) -> List[VectorStoreSearchResult]: """Perform a similarity search using a given input text.""" query_embedding = text_embedder(text) if query_embedding: @@ -217,12 +209,12 @@ def search_by_id(self, id: str) -> VectorStoreDocument: } } - cursor = self.db_connection.cursor() + self._cursor = self._connection.cursor() search_query = sql.SQL("SELECT cursorpage->>'cursor.firstBatch' FROM documentdb_api.find_cursor_first_page(%s, %s);") - cursor.execute(search_query, [self.db_name, Json(find_filter)]) - result = cursor.fetchone() + self._cursor.execute(search_query, [self.db_name, Json(find_filter)]) + result = self._cursor.fetchone() doc = json.loads(result[0]) if result and result[0] else None - + if doc: return VectorStoreDocument( id=doc[0]["id"], @@ -230,52 +222,4 @@ def search_by_id(self, id: str) -> VectorStoreDocument: vector=doc[0]["vector"], attributes=json.loads(doc[0]["attributes"]), ) - return VectorStoreDocument(id=id, text=None, vector=None) - - -if __name__ == "__main__": - from graphrag.model.entity import Entity - entities = [ - Entity( - id="2da37c7a-50a8-44d4-aa2c-fd401e19976c", - short_id="sid1", - title="t1", - rank=2, - ), - Entity( - id="c4f93564-4507-4ee4-b102-98add401a965", - short_id="sid2", - title="t22", - rank=4, - ), - Entity( - id="7c6f2bc9-47c9-4453-93a3-d2e174a02cd9", - short_id="sid3", - title="t333", - rank=1, - ), - Entity( - id="8fd6d72a-8e9d-4183-8a97-c38bcc971c83", - short_id="sid4", - title="t4444", - rank=3, - ), - ] - documents = [VectorStoreDocument(id=entity.id, text=entity.title, vector=[0]) for entity in entities] - - - kwargs = { - "collection_name": "default", - "database_name": "documentdb", - "vector_options": { "kind": "vector-ivf", "similarity": "COS", "dimensions": DEFAULT_VECTOR_SIZE, "numLists": 3 } - } - - store = DocumentDBVectoreStore(**kwargs) - store.connect(**{ - "user": "admin", - "password": "admin", - "host": "host.docker.internal", - "port": 9712 - }) - store.load_documents(documents, overwrite=False) - store.search_by_id("1") \ No newline at end of file + return VectorStoreDocument(id=id, text=None, vector=None) \ No newline at end of file diff --git a/tests/integration/storage/test_documentdb_storage.py b/tests/integration/storage/test_documentdb_storage.py index c555953d17..bf93a48a3f 100644 --- a/tests/integration/storage/test_documentdb_storage.py +++ b/tests/integration/storage/test_documentdb_storage.py @@ -136,3 +136,27 @@ async def test_get_creation_date(): finally: await storage.clear() + +async def test_keys(): + storage = DocumentDBPipelineStorage( + database_name = "documentdb", + collection = "testkeystable", + user = "admin", + password = "admin", + host = "host.docker.internal", + port = 9712, + ) + try: + json_content = { + "content": "Happy Easter!", + } + await storage.set("easter.json", json.dumps(json_content), encoding="utf-8") + json_content = { + "content": 
"Happy Christmas!", + } + await storage.set("christmas.json", json.dumps(json_content), encoding="utf-8") + + keys = storage.keys() + assert len(keys) == 2 + finally: + await storage.clear() \ No newline at end of file