diff --git a/.vscode/settings.json b/.vscode/settings.json index 0b678d5d95..2aa700444c 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -48,5 +48,10 @@ }, "custom": true, // Enable the `custom` dictionary "internal-terms": true // Disable the `internal-terms` dictionary - } + }, + "python.testing.pytestArgs": [ + "tests" + ], + "python.testing.unittestEnabled": false, + "python.testing.pytestEnabled": true } diff --git a/graphrag/cache/factory.py b/graphrag/cache/factory.py index f44c68953b..efdeb09fde 100644 --- a/graphrag/cache/factory.py +++ b/graphrag/cache/factory.py @@ -10,6 +10,7 @@ from graphrag.config.enums import CacheType from graphrag.storage.blob_pipeline_storage import create_blob_storage from graphrag.storage.cosmosdb_pipeline_storage import create_cosmosdb_storage +from graphrag.storage.documentdb_pipeline_storage import create_documentdb_storage from graphrag.storage.file_pipeline_storage import FilePipelineStorage if TYPE_CHECKING: @@ -56,6 +57,8 @@ def create_cache( return JsonPipelineCache(create_blob_storage(**kwargs)) case CacheType.cosmosdb: return JsonPipelineCache(create_cosmosdb_storage(**kwargs)) + case CacheType.documentdb: + return JsonPipelineCache(create_documentdb_storage(**kwargs)) case _: if cache_type in cls.cache_types: return cls.cache_types[cache_type](**kwargs) diff --git a/graphrag/config/enums.py b/graphrag/config/enums.py index 450ec6bda7..e3b363bd65 100644 --- a/graphrag/config/enums.py +++ b/graphrag/config/enums.py @@ -21,6 +21,8 @@ class CacheType(str, Enum): """The blob cache configuration type.""" cosmosdb = "cosmosdb" """The cosmosdb cache configuration type""" + documentdb = "documentdb" + """The documentdb cache configuration type""" def __repr__(self): """Get a string representation.""" @@ -64,6 +66,8 @@ class OutputType(str, Enum): """The blob output type.""" cosmosdb = "cosmosdb" """The cosmosdb output type""" + documentdb = "documentdb" + """The documentdb output type""" def __repr__(self): """Get a string representation.""" diff --git a/graphrag/config/init_content.py b/graphrag/config/init_content.py index c0d4cefcfe..35ecde0af1 100644 --- a/graphrag/config/init_content.py +++ b/graphrag/config/init_content.py @@ -80,15 +80,15 @@ ## connection_string and container_name must be provided cache: - type: {defs.CACHE_TYPE.value} # [file, blob, cosmosdb] + type: {defs.CACHE_TYPE.value} # [file, blob, cosmosdb, documentdb] base_dir: "{defs.CACHE_BASE_DIR}" reporting: - type: {defs.REPORTING_TYPE.value} # [file, blob, cosmosdb] + type: {defs.REPORTING_TYPE.value} # [file, blob, cosmosdb, documentdb] base_dir: "{defs.REPORTING_BASE_DIR}" output: - type: {defs.OUTPUT_TYPE.value} # [file, blob, cosmosdb] + type: {defs.OUTPUT_TYPE.value} # [file, blob, cosmosdb, documentdb] base_dir: "{defs.OUTPUT_BASE_DIR}" ### Workflow settings ### diff --git a/graphrag/config/models/cache_config.py b/graphrag/config/models/cache_config.py index fb4d22b229..ff2f5e8577 100644 --- a/graphrag/config/models/cache_config.py +++ b/graphrag/config/models/cache_config.py @@ -30,3 +30,6 @@ class CacheConfig(BaseModel): cosmosdb_account_url: str | None = Field( description="The cosmosdb account url to use.", default=None ) + documentdb_account_url: str | None = Field( + description="The documentdb account url to use.", default=None + ) diff --git a/graphrag/config/models/output_config.py b/graphrag/config/models/output_config.py index c237ef2567..7d83ef47db 100644 --- a/graphrag/config/models/output_config.py +++ 
b/graphrag/config/models/output_config.py @@ -31,3 +31,6 @@ class OutputConfig(BaseModel): cosmosdb_account_url: str | None = Field( description="The cosmosdb account url to use.", default=None ) + documentdb_account_url: str | None = Field( + description="The documentdb account url to use.", default=None + ) diff --git a/graphrag/config/models/vector_store_config.py b/graphrag/config/models/vector_store_config.py index 0bda54650c..e848dc5517 100644 --- a/graphrag/config/models/vector_store_config.py +++ b/graphrag/config/models/vector_store_config.py @@ -50,11 +50,17 @@ def _validate_url(self) -> None: ): msg = "vector_store.url is required when vector_store.type == cosmos_db. Please rerun `graphrag init` and select the correct vector store type." raise ValueError(msg) + + if self.type == VectorStoreType.DocumentDB and ( + self.url is None or self.url.strip() == "" + ): + msg = "vector_store.url is required when vector_store.type == document_db. Please rerun `graphrag init` and select the correct vector store type." + raise ValueError(msg) if self.type == VectorStoreType.LanceDB and ( self.url is not None and self.url.strip() != "" ): - msg = "vector_store.url is only used when vector_store.type == azure_ai_search or vector_store.type == cosmos_db. Please rerun `graphrag init` and select the correct vector store type." + msg = "vector_store.url is only used when vector_store.type == azure_ai_search or vector_store.type == cosmos_db or vector_store.type == document_db. Please rerun `graphrag init` and select the correct vector store type." raise ValueError(msg) api_key: str | None = Field( @@ -73,7 +79,7 @@ def _validate_url(self) -> None: ) database_name: str | None = Field( - description="The database name to use when type == cosmos_db.", default=None + description="The database name to use when type == cosmos_db or document_db.", default=None ) overwrite: bool = Field( diff --git a/graphrag/storage/documentdb_pipeline_storage.py b/graphrag/storage/documentdb_pipeline_storage.py new file mode 100644 index 0000000000..54e4f98e04 --- /dev/null +++ b/graphrag/storage/documentdb_pipeline_storage.py @@ -0,0 +1,344 @@ +# Copyright (c) 2024 Microsoft Corporation. 
+# Licensed under the MIT License + +"""Azure DocumentDB Storage implementation of PipelineStorage.""" + +import sys +import json +import logging +import re +from collections.abc import Iterator +from datetime import datetime, timezone +from typing import Any, Dict, List, Optional, Pattern, Tuple, Union + +import pandas as pd +import psycopg2 +from psycopg2.extras import RealDictCursor, Json + +from graphrag.logger.base import ProgressLogger +from graphrag.logger.progress import Progress +from graphrag.storage.pipeline_storage import ( + PipelineStorage, + get_timestamp_formatted_with_local_tz, +) + +log = logging.getLogger(__name__) + +class DocumentDBPipelineStorage(PipelineStorage): + """The DocumentDB Storage Implementation.""" + + _connection: psycopg2.extensions.connection + _cursor: psycopg2.extensions.cursor + _database_name: str + _collection: str + _encoding: str + + def __init__( + self, + database_name: str, + collection: str, + user: str, + password: str, + host: str = "localhost", + port: int = 5432, + encoding: str = "utf-8", + ) -> None: + """Initialize the DocumentDB Storage.""" + self._connection = psycopg2.connect( + dbname="postgres", + user=user, + password=password, + host=host, + port=port + ) + self._cursor = self._connection.cursor(cursor_factory=RealDictCursor) + self._encoding = encoding + self._database_name = database_name + self._collection = collection + log.info( + "creating documentdb storage with database: %s and table: %s", + self._database_name, + self._collection, + ) + + set_first = 'SET search_path TO documentdb_api, documentdb_core;' + self._cursor.execute(set_first) + self._connection.commit() + self._create_collection() + + def _create_collection(self) -> None: + """Create the table if it doesn't exist.""" + self._cursor.execute(f""" + SELECT documentdb_api.create_collection('{self._database_name}', '{self._collection}'); + """) + self._connection.commit() + + def _delete_collection(self) -> None: + """Delete the table if it exists.""" + self._cursor.execute(f""" + SELECT documentdb_api.drop_collection('{self._database_name}', '{self._collection}'); + """) + self._connection.commit() + + def find( + self, + file_pattern: Pattern[str], + base_dir: Optional[str] = None, + progress: Optional[ProgressLogger] = None, + file_filter: Optional[Dict[str, Any]] = None, + max_count: int = -1, + ) -> Iterator[Tuple[str, Dict[str, Any]]]: + """Find documents in a DocumentDB table using a file pattern regex and custom file filter (optional).""" + base_dir = base_dir or "" + log.info( + "search table %s for documents matching %s", + self._collection, + file_pattern.pattern, + ) + + if not self._connection or not self._cursor: + return + + try: + find_query = { + "find": self._collection, + "filter": { + "$and": [ + { + "key": { + "$regex": file_pattern.pattern + } + } + ] + } + } + + if file_filter: + for key, value in file_filter.items(): + find_query["filter"]["$and"].append({key: value}) + + if max_count > 0: + find_query["$limit"] = max_count + + query = f""" + SELECT cursorPage->>'cursor.firstBatch' AS results + FROM documentdb_api.find_cursor_first_page('{self._database_name}', {Json(find_query)}); + """ + + self._cursor.execute(query) + item = self._cursor.fetchone() + items = json.loads(item.get('results', '[]')) + num_loaded = 0 + num_total = len(items) + + if num_total == 0: + return + + num_filtered = 0 + for item in items: + key = item["key"] + match = file_pattern.search(key) + if match: + group = match.groupdict() + yield (key, group) + num_loaded 
+= 1 + if max_count > 0 and num_loaded >= max_count: + break + else: + num_filtered += 1 + else: + num_filtered += 1 + if progress is not None: + progress( + _create_progress_status(num_loaded, num_filtered, num_total) + ) + except Exception: + log.exception( + "An error occurred while searching for documents in Document DB." + ) + + async def get(self, key: str, as_bytes: Optional[bool] = None, encoding: Optional[str] = None) -> Any: + """Fetch an item from the table that matches the given key.""" + try: + find_query = { + "find": self._collection, + "filter": { + "key": key + } + } + + query = f""" + SELECT cursorPage->>'cursor.firstBatch' AS results + FROM documentdb_api.find_cursor_first_page('{self._database_name}', {Json(find_query)}); + """ + + self._cursor.execute(query) + item = self._cursor.fetchone() + items = json.loads(item.get('results', '[]')) + if len(items) == 0: + return None + + item = items[0] + + if item: + return json.dumps(item["value"]) + return None + except Exception: + log.exception("Error reading item %s", key) + return None + + async def set(self, key: str, value: Any, encoding: Optional[str] = None) -> None: + """Insert or update the contents of a file into the DocumentDB table for the given filename key.""" + try: + insert_query = { + "key": key, + "value": json.loads(value), + "created_at": datetime.now(timezone.utc).isoformat() + } + self._cursor.execute(f""" + SELECT documentdb_api.insert_one('{self._database_name}', '{self._collection}', {Json(insert_query)}); + """) + self._connection.commit() + except Exception: + log.exception("Error writing item %s", key) + + async def has(self, key: str) -> bool: + """Check if the contents of the given filename key exist in the DocumentDB table.""" + aggregate_query = { + "aggregate": self._collection, + "pipeline": [ + {"$match": {"key": key}}, + {"$count": "key"} + ], + "cursor": {"batchSize": 1} + } + + self._cursor.execute(f""" + SELECT jsonb_extract_path_text(results::jsonb, '0', 'key')::int AS count + FROM ( + SELECT cursorPage->>'cursor.firstBatch' AS results + FROM documentdb_api.aggregate_cursor_first_page('{self._database_name}', {Json(aggregate_query)}) + ) subquery; + """) + result = self._cursor.fetchone() + return (result.get('count') or 0) > 0 + + async def delete(self, key: str) -> None: + """Delete the item with the given filename key from the DocumentDB table.""" + try: + delete_query = { + "delete": self._collection, + "deletes": [ + { + "q": {"key": key}, + "limit": 1 + } + ] + } + self._cursor.execute(f"SELECT documentdb_api.delete('{self._database_name}', {Json(delete_query)});") + self._connection.commit() + except Exception: + log.exception("Error deleting item %s", key) + + async def clear(self) -> None: + """Clear all contents from storage.""" + self._delete_collection() + self._create_collection() + + def keys(self) -> List[str]: + """Return the keys in the storage.""" + count_query = { + "aggregate": self._collection, + "pipeline": [ + { "$count": "key" } + ], + "cursor": { "batchSize": 1 } + } + + self._cursor.execute(f""" + SELECT jsonb_extract_path_text(results::jsonb, '0', 'key')::int AS batch_size + FROM ( + SELECT cursorPage->>'cursor.firstBatch' AS results + FROM documentdb_api.aggregate_cursor_first_page( + '{self._database_name}', + {Json(count_query)} + ) + ) subquery; + """) + result = self._cursor.fetchone() + + keys_query = { + "aggregate": self._collection, + "pipeline": [ + { "$group": + { "_id": "$key", } + } + ] , "cursor": { "batchSize": result.get('batch_size') or sys.maxsize } } + 
self._cursor.execute(f"SELECT cursorpage->>'cursor.firstBatch' AS result FROM documentdb_api.aggregate_cursor_first_page('documentdb', {Json(keys_query)});") + result = self._cursor.fetchone() + return [row["_id"] for row in json.loads(result['result'])] + + def child(self, name: Optional[str]) -> PipelineStorage: + """Create a child storage instance.""" + return self + + def _get_prefix(self, key: str) -> str: + """Get the prefix of the filename key.""" + return key.split(".")[0] + + async def get_creation_date(self, key: str) -> str: + """Get the creation date of the item with the given key.""" + try: + find_query = { + "find": self._collection, + "filter": { + "key": key + } + } + + query = f""" + SELECT cursorPage->>'cursor.firstBatch' AS results + FROM documentdb_api.find_cursor_first_page('{self._database_name}', {Json(find_query)}); + """ + + self._cursor.execute(query) + item = self._cursor.fetchone() + items = json.loads(item.get('results', '[]')) + if len(items) == 0: + return "" + + item = items[0] + + if item: + return get_timestamp_formatted_with_local_tz( + datetime.fromisoformat(item["created_at"]) + ) + return "" + except Exception: + log.exception("Error getting key %s", key) + return "" + +def create_documentdb_storage(**kwargs: Any) -> PipelineStorage: + """Create a DocumentDB storage instance.""" + log.info("Creating postgres storage") + database_name = kwargs["database_name"] + collection = kwargs["collection"] + user = kwargs["user"] + password = kwargs["password"] + host = kwargs.get("host", "localhost") + port = kwargs.get("port", 5432) + return DocumentDBPipelineStorage( + database_name=database_name, + collection=collection, + user=user, + password=password, + host=host, + port=port, + ) + +def _create_progress_status(num_loaded: int, num_filtered: int, num_total: int) -> Progress: + return Progress( + total_items=num_total, + completed_items=num_loaded + num_filtered, + description=f"{num_loaded} files loaded ({num_filtered} filtered)", + ) diff --git a/graphrag/storage/factory.py b/graphrag/storage/factory.py index 8a6e0df4d6..a1d8c9601f 100644 --- a/graphrag/storage/factory.py +++ b/graphrag/storage/factory.py @@ -10,6 +10,7 @@ from graphrag.config.enums import OutputType from graphrag.storage.blob_pipeline_storage import create_blob_storage from graphrag.storage.cosmosdb_pipeline_storage import create_cosmosdb_storage +from graphrag.storage.documentdb_pipeline_storage import create_documentdb_storage from graphrag.storage.file_pipeline_storage import create_file_storage from graphrag.storage.memory_pipeline_storage import MemoryPipelineStorage @@ -43,6 +44,8 @@ def create_storage( return create_blob_storage(**kwargs) case OutputType.cosmosdb: return create_cosmosdb_storage(**kwargs) + case OutputType.documentdb: + return create_documentdb_storage(**kwargs) case OutputType.file: return create_file_storage(**kwargs) case OutputType.memory: diff --git a/graphrag/vector_stores/documentdb.py b/graphrag/vector_stores/documentdb.py new file mode 100644 index 0000000000..0cc235f39c --- /dev/null +++ b/graphrag/vector_stores/documentdb.py @@ -0,0 +1,225 @@ +# Copyright (c) 2024 Microsoft Corporation. 
+# Licensed under the MIT License + +"""A package containing the DocumentDB vector store implementation.""" + +import json +from typing import Any, List, Union +import psycopg2 +from psycopg2 import sql +from psycopg2.extras import RealDictCursor, Json + +from graphrag.model.types import TextEmbedder +from graphrag.vector_stores.base import ( + DEFAULT_VECTOR_SIZE, + BaseVectorStore, + VectorStoreDocument, + VectorStoreSearchResult, +) + +class DocumentDBVectorStore(BaseVectorStore): + """Microsoft DocumentDB (PostgreSQL) vector storage implementation.""" + + _connection: psycopg2.extensions.connection + _cursor: psycopg2.extensions.cursor + + def __init__(self, **kwargs: Any) -> None: + super().__init__(**kwargs) + self.db_name = kwargs.get("database_name", "documentdb") + + kind = kwargs.get("kind", "vector-ivf") # Possible values: vector-hnsw, vector-ivf. + similarity = kwargs.get("similarity", "COS") # COS for cosine similarity, L2 for Euclidean distance, or IP for inner product. + dimensions = kwargs.get("dimensions", DEFAULT_VECTOR_SIZE) + num_lists = kwargs.get("numLists", 100) # Only IVF index requires this parameter. + m = kwargs.get("m", 16) # Only HNSW index requires this parameter. + ef_construction = kwargs.get("efConstruction", 64) # Only HNSW index requires this parameter. + + self.vector_options = kwargs.get("vector_options", { + "kind": kind, + "similarity": similarity, + "dimensions": dimensions, + "numLists": num_lists, + "m": m, + "efConstruction": ef_construction + }) + + def connect(self, **kwargs: Any) -> None: + """Connect to DocumentDB (PostgreSQL) vector storage.""" + user = kwargs.get("user") + password = kwargs.get("password") + host = kwargs.get("host") + port = kwargs.get("port", 5432) + + if not all([self.db_name, user, password, host]): + raise ValueError("Database credentials must be provided.") + + self._connection = psycopg2.connect( + dbname="postgres", + user=user, + password=password, + host=host, + port=port + ) + self._cursor = self._connection.cursor(cursor_factory=RealDictCursor) + + set_first = 'SET search_path TO documentdb_api, documentdb_core;' + self._cursor.execute(set_first) + self._connection.commit() + + self.document_collection = self.create_collection() + + def load_documents(self, documents: List[VectorStoreDocument], overwrite: bool = True) -> None: + """Load documents into vector storage.""" + data = [ + { + "id": document.id, + "text": document.text, + "vector": document.vector, + "attributes": json.dumps(document.attributes), + } + for document in documents + if document.vector is not None + ] + + if len(data) == 0: + data = None + + if overwrite: + drop_query = sql.SQL("SELECT * FROM documentdb_api.drop_collection(%s, %s);") + create_query = sql.SQL("SELECT * FROM documentdb_api.create_collection(%s, %s);") + + self._cursor.execute(drop_query, [self.db_name, self.collection_name]) + self._connection.commit() + self._cursor.execute(create_query, [self.db_name, self.collection_name]) + self._connection.commit() + + self.create_vector_index("vector_index", "vector") + + if data: + for doc in data: + insert_query = sql.SQL("SELECT documentdb_api.insert_one(%s, %s, %s);") + self._cursor.execute(insert_query, [self.db_name, self.collection_name, Json(doc)]) + self._connection.commit() + + def create_collection(self) -> str: + """Create a collection in the database.""" + create_collection_query = sql.SQL("SELECT * FROM documentdb_api.create_collection(%s, %s);") + self._cursor.execute(create_collection_query, [self.db_name, 
self.collection_name]) + self._connection.commit() + + self.create_vector_index("vector_index", "vector") + return self.collection_name + + def create_vector_index(self, index_name: str, index_field: str) -> None: + """Create an index on the collection.""" + index_query = { + "createIndexes": self.collection_name, + "indexes": [ + { + "name": index_name, + "key": { + index_field: "cosmosSearch" + }, + "cosmosSearchOptions": self.vector_options + } + ] + } + + # see bug issue: https://github.com/microsoft/documentdb/issues/63 + create_index_query = sql.SQL("SELECT documentdb_api_internal.create_indexes_non_concurrently(%s, %s, true);") + self._cursor.execute(create_index_query, [self.db_name, Json(index_query)]) + self._connection.commit() + + def filter_by_id(self, include_ids: list[str] | list[int]) -> Any: + """Build a query filter to filter documents by id.""" + if len(include_ids) == 0: + self.query_filter = None + else: + if isinstance(include_ids[0], str): + self.query_filter = { + "id": { + "$in": [f"'{id}'" for id in include_ids] + } + } + else: + self.query_filter = { + "id": { + "$in": include_ids + } + } + + return self.query_filter + + def similarity_search_by_vector(self, query_embedding: List[float], k: int = 10, **kwargs: Any) -> List[VectorStoreSearchResult]: + """Perform a vector-based similarity search.""" + search_filter = { + "aggregate": self.collection_name, + "pipeline": [ + { + "$search": { + "cosmosSearch": { + "vector": query_embedding, + "path": "vector", + "k": k, + }, + "returnStoredSource": True + } + } + ], + "cursor": { + "batchSize": k + } + } + + if self.query_filter: + search_filter["pipeline"].insert(0, {"filter": self.query_filter}) + + self._cursor = self._connection.cursor() + search_query = sql.SQL("SELECT cursorpage->>'cursor.firstBatch' FROM documentdb_api.aggregate_cursor_first_page(%s, %s);") + self._cursor.execute(search_query, [self.db_name, Json(search_filter)]) + results = self._cursor.fetchall() + docs = [json.loads(result[0]) for result in results if result[0]] + + return [ + VectorStoreSearchResult( + document=VectorStoreDocument( + id=doc["id"], + text=doc["text"], + vector=doc["vector"], + attributes=json.loads(doc["attributes"]), + ), + score=abs(float(doc['__cosmos_meta__']['score'])), + ) + for doc in docs[0] + ] + + def similarity_search_by_text(self, text: str, text_embedder: TextEmbedder, k: int = 10, **kwargs: Any) -> List[VectorStoreSearchResult]: + """Perform a similarity search using a given input text.""" + query_embedding = text_embedder(text) + if query_embedding: + return self.similarity_search_by_vector(query_embedding, k) + return [] + + def search_by_id(self, id: str) -> VectorStoreDocument: + """Search for a document by id.""" + find_filter = { + "find": self.collection_name, + "filter": { + "id": id + } + } + + self._cursor = self._connection.cursor() + search_query = sql.SQL("SELECT cursorpage->>'cursor.firstBatch' FROM documentdb_api.find_cursor_first_page(%s, %s);") + self._cursor.execute(search_query, [self.db_name, Json(find_filter)]) + result = self._cursor.fetchone() + doc = json.loads(result[0]) if result and result[0] else None + + if doc: + return VectorStoreDocument( + id=doc[0]["id"], + text=doc[0]["text"], + vector=doc[0]["vector"], + attributes=json.loads(doc[0]["attributes"]), + ) + return VectorStoreDocument(id=id, text=None, vector=None) \ No newline at end of file diff --git a/graphrag/vector_stores/factory.py b/graphrag/vector_stores/factory.py index 1c37316d0c..066a1d2d94 100644 --- 
a/graphrag/vector_stores/factory.py +++ b/graphrag/vector_stores/factory.py @@ -9,6 +9,7 @@ from graphrag.vector_stores.azure_ai_search import AzureAISearchVectorStore from graphrag.vector_stores.base import BaseVectorStore from graphrag.vector_stores.cosmosdb import CosmosDBVectoreStore +from graphrag.vector_stores.documentdb import DocumentDBVectorStore from graphrag.vector_stores.lancedb import LanceDBVectorStore @@ -18,6 +19,7 @@ class VectorStoreType(str, Enum): LanceDB = "lancedb" AzureAISearch = "azure_ai_search" CosmosDB = "cosmosdb" + DocumentDB = "documentdb" class VectorStoreFactory: @@ -45,6 +47,8 @@ def create_vector_store( return AzureAISearchVectorStore(**kwargs) case VectorStoreType.CosmosDB: return CosmosDBVectoreStore(**kwargs) + case VectorStoreType.DocumentDB: + return DocumentDBVectorStore(**kwargs) case _: if vector_store_type in cls.vector_store_types: return cls.vector_store_types[vector_store_type](**kwargs) diff --git a/tests/integration/storage/test_documentdb_storage.py b/tests/integration/storage/test_documentdb_storage.py new file mode 100644 index 0000000000..bf93a48a3f --- /dev/null +++ b/tests/integration/storage/test_documentdb_storage.py @@ -0,0 +1,162 @@ +# Copyright (c) 2024 Microsoft Corporation. +# Licensed under the MIT License +"""DocumentDB Storage Tests.""" + +import json +import re +import sys +from datetime import datetime + +import pytest + +from graphrag.storage.documentdb_pipeline_storage import DocumentDBPipelineStorage + + +async def test_find(): + storage = DocumentDBPipelineStorage( + database_name = "documentdb", + collection = "testfindtable", + user = "admin", + password = "admin", + host = "host.docker.internal", + port = 9712, + ) + try: + try: + items = list(storage.find(file_pattern=re.compile(r".*\.json$"))) + items = [item[0] for item in items] + assert items == [] + + json_content = { + "content": "Merry Christmas!", + } + await storage.set( + "christmas.json", json.dumps(json_content), encoding="utf-8" + ) + items = list(storage.find(file_pattern=re.compile(r".*\.json$"))) + items = [item[0] for item in items] + assert items == ["christmas.json"] + + json_content = { + "content": "Hello, World!", + } + await storage.set("test.json", json.dumps(json_content), encoding="utf-8") + items = list(storage.find(file_pattern=re.compile(r".*\.json$"))) + items = [item[0] for item in items] + assert items == ["christmas.json", "test.json"] + + items = list(storage.find(file_pattern=re.compile(r".*\.json$"), file_filter={"key": "test.json"})) + items = [item[0] for item in items] + assert items == ["test.json"] + + output = await storage.get("test.json") + output_json = json.loads(output) + assert output_json["content"] == "Hello, World!"
+ + json_exists = await storage.has("christmas.json") + assert json_exists is True + json_exists = await storage.has("easter.json") + assert json_exists is False + finally: + await storage.delete("test.json") + output = await storage.get("test.json") + assert output is None + finally: + await storage.clear() + + +async def test_child(): + storage = DocumentDBPipelineStorage( + database_name = "documentdb", + collection = "testfindtable", + user = "admin", + password = "admin", + host = "host.docker.internal", + port = 9712, + ) + try: + child_storage = storage.child("child") + assert type(child_storage) is DocumentDBPipelineStorage + finally: + await storage.clear() + + +async def test_clear(): + storage = DocumentDBPipelineStorage( + database_name = "documentdb", + collection = "testfindtable", + user = "admin", + password = "admin", + host = "host.docker.internal", + port = 9712, + ) + try: + json_exists = { + "content": "Merry Christmas!", + } + await storage.set("christmas.json", json.dumps(json_exists), encoding="utf-8") + json_exists = { + "content": "Happy Easter!", + } + await storage.set("easter.json", json.dumps(json_exists), encoding="utf-8") + await storage.clear() + + items = list(storage.find(file_pattern=re.compile(r".*\.json$"))) + items = [item[0] for item in items] + assert items == [] + + output = await storage.get("easter.json") + assert output is None + finally: + # await storage.clear() + print("Table cleared") + + +async def test_get_creation_date(): + storage = DocumentDBPipelineStorage( + database_name = "documentdb", + collection = "testfindtable", + user = "admin", + password = "admin", + host = "host.docker.internal", + port = 9712, + ) + try: + json_content = { + "content": "Happy Easter!", + } + await storage.set("easter.json", json.dumps(json_content), encoding="utf-8") + + creation_date = await storage.get_creation_date("easter.json") + + datetime_format = "%Y-%m-%d %H:%M:%S %z" + parsed_datetime = datetime.strptime(creation_date, datetime_format).astimezone() + + assert parsed_datetime.strftime(datetime_format) == creation_date + + finally: + await storage.clear() + +async def test_keys(): + storage = DocumentDBPipelineStorage( + database_name = "documentdb", + collection = "testkeystable", + user = "admin", + password = "admin", + host = "host.docker.internal", + port = 9712, + ) + try: + json_content = { + "content": "Happy Easter!", + } + await storage.set("easter.json", json.dumps(json_content), encoding="utf-8") + json_content = { + "content": "Happy Christmas!", + } + await storage.set("christmas.json", json.dumps(json_content), encoding="utf-8") + + keys = storage.keys() + assert len(keys) == 2 + finally: + await storage.clear() \ No newline at end of file diff --git a/tests/integration/storage/test_factory.py b/tests/integration/storage/test_factory.py index 5be62e9016..b2d5e79e26 100644 --- a/tests/integration/storage/test_factory.py +++ b/tests/integration/storage/test_factory.py @@ -12,6 +12,7 @@ from graphrag.config.enums import OutputType from graphrag.storage.blob_pipeline_storage import BlobPipelineStorage from graphrag.storage.cosmosdb_pipeline_storage import CosmosDBPipelineStorage +from graphrag.storage.documentdb_pipeline_storage import DocumentDBPipelineStorage from graphrag.storage.factory import StorageFactory from graphrag.storage.file_pipeline_storage import FilePipelineStorage from graphrag.storage.memory_pipeline_storage import MemoryPipelineStorage @@ -47,6 +48,19 @@ def test_create_cosmosdb_storage(): storage = 
StorageFactory.create_storage(OutputType.cosmosdb, kwargs) assert isinstance(storage, CosmosDBPipelineStorage) +def test_create_documentdb_storage(): + kwargs = { + "type": "documentdb", + "database_name": "postgres", + "collection": "testtable", + "user": "admin", + "password": "admin", + "host": "host.docker.internal", + "port": 9712, + } + storage = StorageFactory.create_storage(OutputType.documentdb, kwargs) + assert isinstance(storage, DocumentDBPipelineStorage) + def test_create_file_storage(): kwargs = {"type": "file", "base_dir": "/tmp/teststorage"} diff --git a/tests/unit/config/utils.py b/tests/unit/config/utils.py index b1a92e6a89..d26cc069fe 100644 --- a/tests/unit/config/utils.py +++ b/tests/unit/config/utils.py @@ -89,6 +89,7 @@ "container_name": None, "storage_account_blob_url": None, "cosmosdb_account_url": None, + "documentdb_account_url": None, }, "input": { "type": defs.INPUT_TYPE, @@ -321,6 +322,7 @@ def assert_output_configs(actual: OutputConfig, expected: OutputConfig) -> None: assert expected.container_name == actual.container_name assert expected.storage_account_blob_url == actual.storage_account_blob_url assert expected.cosmosdb_account_url == actual.cosmosdb_account_url + assert expected.documentdb_account_url == actual.documentdb_account_url def assert_update_output_configs(actual: OutputConfig, expected: OutputConfig) -> None: @@ -330,6 +332,7 @@ def assert_update_output_configs(actual: OutputConfig, expected: OutputConfig) - assert expected.container_name == actual.container_name assert expected.storage_account_blob_url == actual.storage_account_blob_url assert expected.cosmosdb_account_url == actual.cosmosdb_account_url + assert expected.documentdb_account_url == actual.documentdb_account_url def assert_cache_configs(actual: CacheConfig, expected: CacheConfig) -> None: @@ -339,6 +342,7 @@ def assert_cache_configs(actual: CacheConfig, expected: CacheConfig) -> None: assert actual.container_name == expected.container_name assert actual.storage_account_blob_url == expected.storage_account_blob_url assert actual.cosmosdb_account_url == expected.cosmosdb_account_url + assert actual.documentdb_account_url == expected.documentdb_account_url def assert_input_configs(actual: InputConfig, expected: InputConfig) -> None:
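
Usage sketch, pipeline storage: the integration tests above exercise DocumentDBPipelineStorage directly; the snippet below condenses that flow into a standalone script. It is a minimal sketch only: the database name, collection, credentials, host, and port are illustrative placeholders, and a reachable DocumentDB endpoint (PostgreSQL with the documentdb extension) is assumed.

    import asyncio
    import json
    import re

    from graphrag.storage.documentdb_pipeline_storage import DocumentDBPipelineStorage


    async def main() -> None:
        # Hypothetical local endpoint; mirrors the fixtures used in the integration tests.
        storage = DocumentDBPipelineStorage(
            database_name="documentdb",
            collection="demo",
            user="admin",
            password="admin",
            host="localhost",
            port=9712,
        )
        try:
            # Values are stored as JSON documents keyed by filename.
            await storage.set("hello.json", json.dumps({"content": "Hello, World!"}), encoding="utf-8")
            assert await storage.has("hello.json")
            print(json.loads(await storage.get("hello.json")))  # {'content': 'Hello, World!'}
            # find() is synchronous and yields (key, regex-group) tuples.
            print([key for key, _ in storage.find(re.compile(r".*\.json$"))])
        finally:
            await storage.clear()


    asyncio.run(main())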
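Usage sketch, factory dispatch: with the new OutputType.documentdb branch, StorageFactory forwards the keyword arguments to create_documentdb_storage, which requires database_name, collection, user, and password (host and port are optional). A hedged example with placeholder connection values, patterned on test_create_documentdb_storage above:

    from graphrag.config.enums import OutputType
    from graphrag.storage.factory import StorageFactory

    kwargs = {
        "type": "documentdb",
        "database_name": "documentdb",
        "collection": "outputs",
        "user": "admin",
        "password": "admin",
        "host": "localhost",  # placeholder endpoint
        "port": 9712,
    }
    storage = StorageFactory.create_storage(OutputType.documentdb, kwargs)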
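Usage sketch, vector store: DocumentDBVectorStore follows the same BaseVectorStore contract as the other stores — construct with a collection_name, connect with PostgreSQL credentials, load documents, then search. A minimal sketch, assuming a local endpoint and substituting a dummy embedding of the default dimensionality (DEFAULT_VECTOR_SIZE, 1536) for a real embedder; all identifiers and credentials below are illustrative.

    from graphrag.vector_stores.base import VectorStoreDocument
    from graphrag.vector_stores.documentdb import DocumentDBVectorStore

    store = DocumentDBVectorStore(collection_name="entity_description_embeddings", database_name="documentdb")
    store.connect(user="admin", password="admin", host="localhost", port=9712)  # placeholder credentials

    # load_documents drops and recreates the collection (overwrite=True) and rebuilds the vector index.
    store.load_documents([
        VectorStoreDocument(id="1", text="example entity", vector=[0.1] * 1536, attributes={"title": "example"}),
    ])

    results = store.similarity_search_by_vector(query_embedding=[0.1] * 1536, k=5)
    for result in results:
        print(result.document.id, result.score)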