diff --git a/libs/langchain-mongodb/langchain_mongodb/index.py b/libs/langchain-mongodb/langchain_mongodb/index.py index fd53956c..f2c99c94 100644 --- a/libs/langchain-mongodb/langchain_mongodb/index.py +++ b/libs/langchain-mongodb/langchain_mongodb/index.py @@ -2,10 +2,15 @@ import logging from time import monotonic, sleep -from typing import Any, Callable, Dict, List, Optional, Union +from typing import Any, Callable, Dict, List, Optional from pymongo.collection import Collection -from pymongo.operations import SearchIndexModel +from pymongo_search_utils import ( + create_fulltext_search_index, # noqa: F401 + create_vector_search_index, # noqa: F401 + drop_vector_search_index, # noqa: F401 + update_vector_search_index, # noqa: F401 +) logger = logging.getLogger(__file__) @@ -34,135 +39,6 @@ def _vector_search_index_definition( return definition -def create_vector_search_index( - collection: Collection, - index_name: str, - dimensions: int, - path: str, - similarity: str, - filters: Optional[List[str]] = None, - *, - wait_until_complete: Optional[float] = None, - **kwargs: Any, -) -> None: - """Experimental Utility function to create a vector search index - - Args: - collection (Collection): MongoDB Collection - index_name (str): Name of Index - dimensions (int): Number of dimensions in embedding - path (str): field with vector embedding - similarity (str): The similarity score used for the index - filters (List[str]): Fields/paths to index to allow filtering in $vectorSearch - wait_until_complete (Optional[float]): If provided, number of seconds to wait - until search index is ready. - kwargs: Keyword arguments supplying any additional options to SearchIndexModel. - """ - logger.info("Creating Search Index %s on %s", index_name, collection.name) - - if collection.name not in collection.database.list_collection_names( - authorizedCollections=True - ): - collection.database.create_collection(collection.name) - - result = collection.create_search_index( - SearchIndexModel( - definition=_vector_search_index_definition( - dimensions=dimensions, - path=path, - similarity=similarity, - filters=filters, - **kwargs, - ), - name=index_name, - type="vectorSearch", - ) - ) - - if wait_until_complete: - _wait_for_predicate( - predicate=lambda: _is_index_ready(collection, index_name), - err=f"{index_name=} did not complete in {wait_until_complete}!", - timeout=wait_until_complete, - ) - logger.info(result) - - -def drop_vector_search_index( - collection: Collection, - index_name: str, - *, - wait_until_complete: Optional[float] = None, -) -> None: - """Drop a created vector search index - - Args: - collection (Collection): MongoDB Collection with index to be dropped - index_name (str): Name of the MongoDB index - wait_until_complete (Optional[float]): If provided, number of seconds to wait - until search index is ready. - """ - logger.info( - "Dropping Search Index %s from Collection: %s", index_name, collection.name - ) - collection.drop_search_index(index_name) - if wait_until_complete: - _wait_for_predicate( - predicate=lambda: len(list(collection.list_search_indexes())) == 0, - err=f"Index {index_name} did not drop in {wait_until_complete}!", - timeout=wait_until_complete, - ) - logger.info("Vector Search index %s.%s dropped", collection.name, index_name) - - -def update_vector_search_index( - collection: Collection, - index_name: str, - dimensions: int, - path: str, - similarity: str, - filters: Optional[List[str]] = None, - *, - wait_until_complete: Optional[float] = None, - **kwargs: Any, -) -> None: - """Update a search index. - - Replace the existing index definition with the provided definition. - - Args: - collection (Collection): MongoDB Collection - index_name (str): Name of Index - dimensions (int): Number of dimensions in embedding - path (str): field with vector embedding - similarity (str): The similarity score used for the index. - filters (List[str]): Fields/paths to index to allow filtering in $vectorSearch - wait_until_complete (Optional[float]): If provided, number of seconds to wait - until search index is ready. - kwargs: Keyword arguments supplying any additional options to SearchIndexModel. - """ - logger.info( - "Updating Search Index %s from Collection: %s", index_name, collection.name - ) - collection.update_search_index( - name=index_name, - definition=_vector_search_index_definition( - dimensions=dimensions, - path=path, - similarity=similarity, - filters=filters, - **kwargs, - ), - ) - if wait_until_complete: - _wait_for_predicate( - predicate=lambda: _is_index_ready(collection, index_name), - err=f"Index {index_name} update did not complete in {wait_until_complete}!", - timeout=wait_until_complete, - ) - logger.info("Update succeeded") - - def _is_index_ready(collection: Collection, index_name: str) -> bool: """Check for the index name in the list of available search indexes to see if the specified index is of status READY @@ -199,50 +75,3 @@ def _wait_for_predicate( if monotonic() - start > timeout: raise TimeoutError(err) sleep(interval) - - -def create_fulltext_search_index( - collection: Collection, - index_name: str, - field: Union[str, List[str]], - *, - wait_until_complete: Optional[float] = None, - **kwargs: Any, -) -> None: - """Experimental Utility function to create an Atlas Search index - - Args: - collection (Collection): MongoDB Collection - index_name (str): Name of Index - field (str): Field to index - wait_until_complete (Optional[float]): If provided, number of seconds to wait - until search index is ready - kwargs: Keyword arguments supplying any additional options to SearchIndexModel. - """ - logger.info("Creating Search Index %s on %s", index_name, collection.name) - - if collection.name not in collection.database.list_collection_names( - authorizedCollections=True - ): - collection.database.create_collection(collection.name) - - if isinstance(field, str): - fields_definition = {field: [{"type": "string"}]} - else: - fields_definition = {f: [{"type": "string"}] for f in field} - definition = {"mappings": {"dynamic": False, "fields": fields_definition}} - result = collection.create_search_index( - SearchIndexModel( - definition=definition, - name=index_name, - type="search", - **kwargs, - ) - ) - if wait_until_complete: - _wait_for_predicate( - predicate=lambda: _is_index_ready(collection, index_name), - err=f"{index_name=} did not complete in {wait_until_complete}!", - timeout=wait_until_complete, - ) - logger.info(result) diff --git a/libs/langchain-mongodb/langchain_mongodb/pipelines.py b/libs/langchain-mongodb/langchain_mongodb/pipelines.py index bc5fff18..d6b09ea9 100644 --- a/libs/langchain-mongodb/langchain_mongodb/pipelines.py +++ b/libs/langchain-mongodb/langchain_mongodb/pipelines.py @@ -9,6 +9,13 @@ from typing import Any, Dict, List, Optional, Union +from pymongo_search_utils import ( + combine_pipelines, # noqa: F401 + final_hybrid_stage, # noqa: F401 + reciprocal_rank_stage, # noqa: F401 + vector_search_stage, # noqa: F401 +) + def text_search_stage( query: str, @@ -48,115 +55,3 @@ def text_search_stage( pipeline.append({"$limit": limit}) # type: ignore return pipeline # type: ignore - - -def vector_search_stage( - query_vector: List[float], - search_field: str, - index_name: str, - top_k: int = 4, - filter: Optional[Dict[str, Any]] = None, - oversampling_factor: int = 10, - **kwargs: Any, -) -> Dict[str, Any]: # noqa: E501 - """Vector Search Stage without Scores. - - Scoring is applied later depending on strategy. - vector search includes a vectorSearchScore that is typically used. - hybrid uses Reciprocal Rank Fusion. - - Args: - query_vector: List of embedding vector - search_field: Field in Collection containing embedding vectors - index_name: Name of Atlas Vector Search Index tied to Collection - top_k: Number of documents to return - oversampling_factor: this times limit is the number of candidates - filter: MQL match expression comparing an indexed field. - Some operators are not supported. - See `vectorSearch filter docs `_ - - - Returns: - Dictionary defining the $vectorSearch - """ - stage = { - "index": index_name, - "path": search_field, - "queryVector": query_vector, - "numCandidates": top_k * oversampling_factor, - "limit": top_k, - } - if filter: - stage["filter"] = filter - return {"$vectorSearch": stage} - - -def combine_pipelines( - pipeline: List[Any], stage: List[Dict[str, Any]], collection_name: str -) -> None: - """Combines two aggregations into a single result set in-place.""" - if pipeline: - pipeline.append({"$unionWith": {"coll": collection_name, "pipeline": stage}}) - else: - pipeline.extend(stage) - - -def reciprocal_rank_stage( - score_field: str, penalty: float = 0, weight: float = 1, **kwargs: Any -) -> List[Dict[str, Any]]: - """ - Stage adds Weighted Reciprocal Rank Fusion (WRRF) scoring. - - First, it groups documents into an array, assigns rank by array index, - and then computes a weighted RRF score. - - Args: - score_field: A unique string to identify the search being ranked. - penalty: A non-negative float (e.g., 60 for RRF-60). Controls the denominator. - weight: A float multiplier for this source's importance. - **kwargs: Ignored; allows future extensions or passthrough args. - - Returns: - Aggregation pipeline stage for weighted RRF scoring. - """ - - return [ - {"$group": {"_id": None, "docs": {"$push": "$$ROOT"}}}, - {"$unwind": {"path": "$docs", "includeArrayIndex": "rank"}}, - { - "$addFields": { - f"docs.{score_field}": { - "$multiply": [ - weight, - {"$divide": [1.0, {"$add": ["$rank", penalty, 1]}]}, - ] - }, - "docs.rank": "$rank", - "_id": "$docs._id", - } - }, - {"$replaceRoot": {"newRoot": "$docs"}}, - ] - - -def final_hybrid_stage( - scores_fields: List[str], limit: int, **kwargs: Any -) -> List[Dict[str, Any]]: - """Sum weighted scores, sort, and apply limit. - - Args: - scores_fields: List of fields given to scores of vector and text searches - limit: Number of documents to return - - Returns: - Final aggregation stages - """ - - return [ - {"$group": {"_id": "$_id", "docs": {"$mergeObjects": "$$ROOT"}}}, - {"$replaceRoot": {"newRoot": "$docs"}}, - {"$set": {score: {"$ifNull": [f"${score}", 0]} for score in scores_fields}}, - {"$addFields": {"score": {"$add": [f"${score}" for score in scores_fields]}}}, - {"$sort": {"score": -1}}, - {"$limit": limit}, - ] diff --git a/libs/langchain-mongodb/langchain_mongodb/utils.py b/libs/langchain-mongodb/langchain_mongodb/utils.py index 826ce56e..8e4f717e 100644 --- a/libs/langchain-mongodb/langchain_mongodb/utils.py +++ b/libs/langchain-mongodb/langchain_mongodb/utils.py @@ -26,6 +26,7 @@ import numpy as np from pymongo import MongoClient from pymongo.driver_info import DriverInfo +from pymongo_search_utils import append_client_metadata logger = logging.getLogger(__name__) @@ -35,9 +36,7 @@ def _append_client_metadata(client: MongoClient) -> None: - # append_metadata was added in PyMongo 4.14.0, but is a valid database name on earlier versions - if callable(client.append_metadata): - client.append_metadata(DRIVER_METADATA) + append_client_metadata(client=client, driver_info=DRIVER_METADATA) def cosine_similarity(X: Matrix, Y: Matrix) -> np.ndarray: diff --git a/libs/langchain-mongodb/langchain_mongodb/vectorstores.py b/libs/langchain-mongodb/langchain_mongodb/vectorstores.py index a1adf7e5..15b593b9 100644 --- a/libs/langchain-mongodb/langchain_mongodb/vectorstores.py +++ b/libs/langchain-mongodb/langchain_mongodb/vectorstores.py @@ -22,9 +22,10 @@ from langchain_core.embeddings import Embeddings from langchain_core.runnables.config import run_in_executor from langchain_core.vectorstores import VectorStore -from pymongo import MongoClient, ReplaceOne +from pymongo import MongoClient from pymongo.collection import Collection from pymongo.errors import CollectionInvalid +from pymongo_search_utils import bulk_embed_and_insert_texts from langchain_mongodb.index import ( create_vector_search_index, @@ -429,28 +430,15 @@ def bulk_embed_and_insert_texts( See add_texts for additional details. """ - if not texts: - return [] - # Compute embedding vectors - embeddings = self._embedding.embed_documents(list(texts)) - if not ids: - ids = [str(ObjectId()) for _ in range(len(list(texts)))] - docs = [ - { - "_id": str_to_oid(i), - self._text_key: t, - self._embedding_key: embedding, - **m, - } - for i, t, m, embedding in zip( - ids, texts, metadatas, embeddings, strict=True - ) - ] - operations = [ReplaceOne({"_id": doc["_id"]}, doc, upsert=True) for doc in docs] - # insert the documents in MongoDB Atlas - result = self._collection.bulk_write(operations) - assert result.upserted_ids is not None - return [oid_to_str(_id) for _id in result.upserted_ids.values()] + return bulk_embed_and_insert_texts( + texts=texts, + metadatas=metadatas, + embedding_func=self._embedding.embed_documents, + collection=self._collection, + text_key=self._text_key, + embedding_key=self._embedding_key, + ids=ids, + ) def add_documents( self, diff --git a/libs/langchain-mongodb/pyproject.toml b/libs/langchain-mongodb/pyproject.toml index fe6d7426..90f607ee 100644 --- a/libs/langchain-mongodb/pyproject.toml +++ b/libs/langchain-mongodb/pyproject.toml @@ -17,6 +17,7 @@ dependencies = [ "langchain-text-splitters>=1.0", "numpy>=1.26", "lark<2.0.0,>=1.1.9", + "pymongo-search-utils>=0.1.0", ] [dependency-groups] @@ -65,6 +66,7 @@ filterwarnings = [ [tool.mypy] disallow_untyped_defs = true +disable_error_code = ["import-untyped"] [[tool.mypy.overrides]] module = ["tests.*"] diff --git a/libs/langchain-mongodb/uv.lock b/libs/langchain-mongodb/uv.lock index ba0b4acb..6835769d 100644 --- a/libs/langchain-mongodb/uv.lock +++ b/libs/langchain-mongodb/uv.lock @@ -799,6 +799,7 @@ dependencies = [ { name = "numpy", version = "2.2.6", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, { name = "numpy", version = "2.3.2", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, { name = "pymongo" }, + { name = "pymongo-search-utils" }, ] [package.dev-dependencies] @@ -836,6 +837,7 @@ requires-dist = [ { name = "lark", specifier = ">=1.1.9,<2.0.0" }, { name = "numpy", specifier = ">=1.26" }, { name = "pymongo", specifier = ">=4.6.1" }, + { name = "pymongo-search-utils", specifier = ">=0.1.0" }, ] [package.metadata.requires-dev] @@ -1872,6 +1874,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/2d/fa/68b1555e62ed3ee87f8a2de99d5fb840cf045748da4488870b4dced44a95/pymongo-4.14.0-cp313-cp313t-win_amd64.whl", hash = "sha256:e506af9b25aac77cc5c5ea4a72f81764e4f5ea90ca799aac43d665ab269f291d", size = 1011181, upload-time = "2025-08-06T13:40:48.641Z" }, ] +[[package]] +name = "pymongo-search-utils" +version = "0.1.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pymongo" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/d9/17/161b9a61a4a66d267a2eecab7157352b02a2b8fff7349311183be84f4c96/pymongo_search_utils-0.1.0.tar.gz", hash = "sha256:e2e2adc7292a8ae6031ab5c935fccb8c40542a45d3d217afea56995659da79e5", size = 11976, upload-time = "2025-11-24T15:12:12.642Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/94/05/7944a1cfb4a844d75a5c28f19ad94c3facf520e494e3e8fcd31b17f085c3/pymongo_search_utils-0.1.0-py3-none-any.whl", hash = "sha256:44f7601a99e8d979bb7ef7be611863c1a98943c92fb192bfa549ac8b1c281580", size = 17186, upload-time = "2025-11-24T15:12:11.2Z" }, +] + [[package]] name = "pypdf" version = "6.0.0" diff --git a/uv.lock b/uv.lock index fd8fa82f..8bffb3c8 100644 --- a/uv.lock +++ b/uv.lock @@ -793,6 +793,7 @@ dependencies = [ { name = "lark" }, { name = "numpy" }, { name = "pymongo" }, + { name = "pymongo-search-utils" }, ] [package.metadata] @@ -804,6 +805,7 @@ requires-dist = [ { name = "lark", specifier = ">=1.1.9,<2.0.0" }, { name = "numpy", specifier = ">=1.26" }, { name = "pymongo", specifier = ">=4.6.1" }, + { name = "pymongo-search-utils", specifier = ">=0.1.0" }, ] [package.metadata.requires-dev] @@ -1805,6 +1807,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/b5/9c/00301a6df26f0f8d5c5955192892241e803742e7c3da8c2c222efabc0df6/pymongo-4.13.2-cp313-cp313t-win_amd64.whl", hash = "sha256:c38168263ed94a250fc5cf9c6d33adea8ab11c9178994da1c3481c2a49d235f8", size = 1011057, upload-time = "2025-06-16T18:16:07.917Z" }, ] +[[package]] +name = "pymongo-search-utils" +version = "0.1.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pymongo" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/d9/17/161b9a61a4a66d267a2eecab7157352b02a2b8fff7349311183be84f4c96/pymongo_search_utils-0.1.0.tar.gz", hash = "sha256:e2e2adc7292a8ae6031ab5c935fccb8c40542a45d3d217afea56995659da79e5", size = 11976, upload-time = "2025-11-24T15:12:12.642Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/94/05/7944a1cfb4a844d75a5c28f19ad94c3facf520e494e3e8fcd31b17f085c3/pymongo_search_utils-0.1.0-py3-none-any.whl", hash = "sha256:44f7601a99e8d979bb7ef7be611863c1a98943c92fb192bfa549ac8b1c281580", size = 17186, upload-time = "2025-11-24T15:12:11.2Z" }, +] + [[package]] name = "python-dotenv" version = "1.1.0"