Skip to content
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
185 changes: 7 additions & 178 deletions libs/langchain-mongodb/langchain_mongodb/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,15 @@

import logging
from time import monotonic, sleep
from typing import Any, Callable, Dict, List, Optional, Union
from typing import Any, Callable, Dict, List, Optional

from pymongo.collection import Collection
from pymongo.operations import SearchIndexModel
from pymongo_search_utils import (
create_fulltext_search_index, # noqa: F401
create_vector_search_index, # noqa: F401
drop_vector_search_index, # noqa: F401
update_vector_search_index, # noqa: F401
)

logger = logging.getLogger(__file__)

Expand Down Expand Up @@ -34,135 +39,6 @@ def _vector_search_index_definition(
return definition


def create_vector_search_index(
    collection: Collection,
    index_name: str,
    dimensions: int,
    path: str,
    similarity: str,
    filters: Optional[List[str]] = None,
    *,
    wait_until_complete: Optional[float] = None,
    **kwargs: Any,
) -> None:
    """Experimental utility that creates an Atlas Vector Search index.

    Args:
        collection (Collection): MongoDB Collection to index.
        index_name (str): Name of the index to create.
        dimensions (int): Number of dimensions in the embedding vectors.
        path (str): Field holding the vector embedding.
        similarity (str): Similarity metric used by the index.
        filters (List[str]): Fields/paths indexed for $vectorSearch pre-filtering.
        wait_until_complete (Optional[float]): If provided, seconds to block
            until the search index reports READY.
        kwargs: Extra options forwarded to the index definition.
    """
    logger.info("Creating Search Index %s on %s", index_name, collection.name)

    # The collection must exist before a search index can be attached to it.
    database = collection.database
    known_collections = database.list_collection_names(authorizedCollections=True)
    if collection.name not in known_collections:
        database.create_collection(collection.name)

    definition = _vector_search_index_definition(
        dimensions=dimensions,
        path=path,
        similarity=similarity,
        filters=filters,
        **kwargs,
    )
    model = SearchIndexModel(
        definition=definition,
        name=index_name,
        type="vectorSearch",
    )
    result = collection.create_search_index(model)

    if wait_until_complete:
        # Poll until the server reports the index as READY, or time out.
        _wait_for_predicate(
            predicate=lambda: _is_index_ready(collection, index_name),
            err=f"{index_name=} did not complete in {wait_until_complete}!",
            timeout=wait_until_complete,
        )
    logger.info(result)


def drop_vector_search_index(
    collection: Collection,
    index_name: str,
    *,
    wait_until_complete: Optional[float] = None,
) -> None:
    """Drop an existing vector search index from a collection.

    Args:
        collection (Collection): MongoDB Collection holding the index.
        index_name (str): Name of the MongoDB index to drop.
        wait_until_complete (Optional[float]): If provided, seconds to block
            until the index has disappeared from the collection.
    """
    logger.info(
        "Dropping Search Index %s from Collection: %s", index_name, collection.name
    )
    collection.drop_search_index(index_name)

    if wait_until_complete:

        def _all_indexes_gone() -> bool:
            # Dropping is asynchronous server-side; done once none remain.
            return not list(collection.list_search_indexes())

        _wait_for_predicate(
            predicate=_all_indexes_gone,
            err=f"Index {index_name} did not drop in {wait_until_complete}!",
            timeout=wait_until_complete,
        )
    logger.info("Vector Search index %s.%s dropped", collection.name, index_name)


def update_vector_search_index(
    collection: Collection,
    index_name: str,
    dimensions: int,
    path: str,
    similarity: str,
    filters: Optional[List[str]] = None,
    *,
    wait_until_complete: Optional[float] = None,
    **kwargs: Any,
) -> None:
    """Replace an existing vector search index definition in place.

    Args:
        collection (Collection): MongoDB Collection holding the index.
        index_name (str): Name of the index to update.
        dimensions (int): Number of dimensions in the embedding vectors.
        path (str): Field holding the vector embedding.
        similarity (str): Similarity metric used by the index.
        filters (List[str]): Fields/paths indexed for $vectorSearch pre-filtering.
        wait_until_complete (Optional[float]): If provided, seconds to block
            until the updated index reports READY.
        kwargs: Extra options forwarded to the index definition.
    """
    logger.info(
        "Updating Search Index %s from Collection: %s", index_name, collection.name
    )
    new_definition = _vector_search_index_definition(
        dimensions=dimensions,
        path=path,
        similarity=similarity,
        filters=filters,
        **kwargs,
    )
    collection.update_search_index(name=index_name, definition=new_definition)

    if wait_until_complete:
        # Rebuilding after an update is asynchronous; poll until READY.
        _wait_for_predicate(
            predicate=lambda: _is_index_ready(collection, index_name),
            err=f"Index {index_name} update did not complete in {wait_until_complete}!",
            timeout=wait_until_complete,
        )
    logger.info("Update succeeded")


def _is_index_ready(collection: Collection, index_name: str) -> bool:
"""Check for the index name in the list of available search indexes to see if the
specified index is of status READY
Expand Down Expand Up @@ -199,50 +75,3 @@ def _wait_for_predicate(
if monotonic() - start > timeout:
raise TimeoutError(err)
sleep(interval)


def create_fulltext_search_index(
    collection: Collection,
    index_name: str,
    field: Union[str, List[str]],
    *,
    wait_until_complete: Optional[float] = None,
    **kwargs: Any,
) -> None:
    """Experimental utility that creates an Atlas (full-text) Search index.

    Args:
        collection (Collection): MongoDB Collection to index.
        index_name (str): Name of the index to create.
        field (Union[str, List[str]]): Field, or list of fields, to index
            as type "string".
        wait_until_complete (Optional[float]): If provided, seconds to block
            until the search index reports READY.
        kwargs: Keyword arguments supplying any additional options to
            SearchIndexModel.
    """
    logger.info("Creating Search Index %s on %s", index_name, collection.name)

    # The collection must exist before a search index can be attached to it.
    database = collection.database
    if collection.name not in database.list_collection_names(
        authorizedCollections=True
    ):
        database.create_collection(collection.name)

    # Normalize to a list so single- and multi-field cases share one path.
    field_names = [field] if isinstance(field, str) else field
    fields_definition = {name: [{"type": "string"}] for name in field_names}
    definition = {"mappings": {"dynamic": False, "fields": fields_definition}}

    result = collection.create_search_index(
        SearchIndexModel(
            definition=definition,
            name=index_name,
            type="search",
            **kwargs,
        )
    )

    if wait_until_complete:
        _wait_for_predicate(
            predicate=lambda: _is_index_ready(collection, index_name),
            err=f"{index_name=} did not complete in {wait_until_complete}!",
            timeout=wait_until_complete,
        )
    logger.info(result)
119 changes: 7 additions & 112 deletions libs/langchain-mongodb/langchain_mongodb/pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,13 @@

from typing import Any, Dict, List, Optional, Union

from pymongo_search_utils import (
combine_pipelines, # noqa: F401
final_hybrid_stage, # noqa: F401
reciprocal_rank_stage, # noqa: F401
vector_search_stage, # noqa: F401
)


def text_search_stage(
query: str,
Expand Down Expand Up @@ -48,115 +55,3 @@ def text_search_stage(
pipeline.append({"$limit": limit}) # type: ignore

return pipeline # type: ignore


def vector_search_stage(
    query_vector: List[float],
    search_field: str,
    index_name: str,
    top_k: int = 4,
    filter: Optional[Dict[str, Any]] = None,
    oversampling_factor: int = 10,
    **kwargs: Any,
) -> Dict[str, Any]:
    """Build a $vectorSearch stage without score projection.

    Scoring is applied later depending on the retrieval strategy:
    plain vector search typically uses vectorSearchScore, while hybrid
    search applies Reciprocal Rank Fusion.

    Args:
        query_vector: The embedding vector to search with.
        search_field: Field in the Collection containing embedding vectors.
        index_name: Name of the Atlas Vector Search Index tied to the Collection.
        top_k: Number of documents to return.
        oversampling_factor: Multiplied by top_k to set the candidate pool size.
        filter: MQL match expression comparing an indexed field.
            Some operators are not supported.
            See `vectorSearch filter docs <https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-stage/#atlas-vector-search-pre-filter>`_

    Returns:
        Dictionary defining the $vectorSearch stage.
    """
    body: Dict[str, Any] = dict(
        index=index_name,
        path=search_field,
        queryVector=query_vector,
        numCandidates=top_k * oversampling_factor,
        limit=top_k,
    )
    # Only attach the pre-filter when one was actually supplied.
    if filter:
        body["filter"] = filter
    return {"$vectorSearch": body}


def combine_pipelines(
    pipeline: List[Any], stage: List[Dict[str, Any]], collection_name: str
) -> None:
    """Merge a second aggregation into the first, in place.

    An empty pipeline simply adopts the new stages; a non-empty one
    appends a $unionWith stage that runs them against the collection.
    """
    if not pipeline:
        pipeline.extend(stage)
    else:
        pipeline.append(
            {"$unionWith": {"coll": collection_name, "pipeline": stage}}
        )


def reciprocal_rank_stage(
    score_field: str, penalty: float = 0, weight: float = 1, **kwargs: Any
) -> List[Dict[str, Any]]:
    """Build the Weighted Reciprocal Rank Fusion (WRRF) scoring stages.

    Documents are collected into an array, ranked by array index, and
    assigned a weighted RRF score of weight * 1 / (rank + penalty + 1).

    Args:
        score_field: A unique string identifying the search being ranked.
        penalty: A non-negative float (e.g., 60 for RRF-60) added to the
            denominator.
        weight: A float multiplier for this source's importance.
        **kwargs: Ignored; allows future extensions or passthrough args.

    Returns:
        Aggregation pipeline stages for weighted RRF scoring.
    """
    # Gather everything into one array so array index can serve as rank.
    collect_stage = {"$group": {"_id": None, "docs": {"$push": "$$ROOT"}}}
    rank_stage = {"$unwind": {"path": "$docs", "includeArrayIndex": "rank"}}
    # weight * 1 / (rank + penalty + 1)
    rrf_expression = {
        "$multiply": [
            weight,
            {"$divide": [1.0, {"$add": ["$rank", penalty, 1]}]},
        ]
    }
    score_stage = {
        "$addFields": {
            f"docs.{score_field}": rrf_expression,
            "docs.rank": "$rank",
            "_id": "$docs._id",
        }
    }
    restore_stage = {"$replaceRoot": {"newRoot": "$docs"}}
    return [collect_stage, rank_stage, score_stage, restore_stage]


def final_hybrid_stage(
    scores_fields: List[str], limit: int, **kwargs: Any
) -> List[Dict[str, Any]]:
    """Build the closing hybrid-search stages: sum scores, sort, limit.

    Args:
        scores_fields: Field names holding the vector and text search scores.
        limit: Number of documents to return.

    Returns:
        Final aggregation stages.
    """
    # Documents missing a score from one source contribute 0 for it.
    fill_missing = {
        "$set": {name: {"$ifNull": [f"${name}", 0]} for name in scores_fields}
    }
    total_score = {
        "$addFields": {"score": {"$add": [f"${name}" for name in scores_fields]}}
    }
    return [
        # Collapse duplicates found by multiple sources into one document.
        {"$group": {"_id": "$_id", "docs": {"$mergeObjects": "$$ROOT"}}},
        {"$replaceRoot": {"newRoot": "$docs"}},
        fill_missing,
        total_score,
        {"$sort": {"score": -1}},
        {"$limit": limit},
    ]
5 changes: 2 additions & 3 deletions libs/langchain-mongodb/langchain_mongodb/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import numpy as np
from pymongo import MongoClient
from pymongo.driver_info import DriverInfo
from pymongo_search_utils import append_client_metadata

logger = logging.getLogger(__name__)

Expand All @@ -35,9 +36,7 @@


def _append_client_metadata(client: MongoClient) -> None:
# append_metadata was added in PyMongo 4.14.0, but is a valid database name on earlier versions
if callable(client.append_metadata):
client.append_metadata(DRIVER_METADATA)
append_client_metadata(client=client, driver_info=DRIVER_METADATA)


def cosine_similarity(X: Matrix, Y: Matrix) -> np.ndarray:
Expand Down
Loading
Loading