Skip to content

Commit cc25f33

Browse files
committed
INTPYTHON-752 Integrate pymongo-vectorsearch-utils
1 parent 982c1d4 commit cc25f33

File tree

5 files changed

+17
-48
lines changed

5 files changed

+17
-48
lines changed

libs/langchain-mongodb/langchain_mongodb/index.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,17 @@
22

33
import logging
44
from time import monotonic, sleep
5-
from typing import Any, Callable, Dict, List, Optional, Union
5+
from typing import Any, Callable, Dict, List, Optional
66

77
from pymongo.collection import Collection
8-
from pymongo.operations import SearchIndexModel
98

109
logger = logging.getLogger(__file__)
1110

1211

12+
# Don't break imports for modules that expect these functions
13+
# to be in this module.
14+
15+
1316
def _vector_search_index_definition(
1417
dimensions: int,
1518
path: str,

libs/langchain-mongodb/langchain_mongodb/utils.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@
2424
from typing import Any, Dict, List, Union
2525

2626
import numpy as np
27-
from pymongo import MongoClient
2827
from pymongo.driver_info import DriverInfo
2928

3029
logger = logging.getLogger(__name__)
@@ -33,11 +32,8 @@
3332

3433
DRIVER_METADATA = DriverInfo(name="Langchain", version=version("langchain-mongodb"))
3534

36-
37-
def _append_client_metadata(client: MongoClient) -> None:
38-
# append_metadata was added in PyMongo 4.14.0, but is a valid database name on earlier versions
39-
if callable(client.append_metadata):
40-
client.append_metadata(DRIVER_METADATA)
35+
# Don't break imports for modules that expect this function
36+
# to be in this module.
4137

4238

4339
def cosine_similarity(X: Matrix, Y: Matrix) -> np.ndarray:

libs/langchain-mongodb/langchain_mongodb/vectorstores.py

Lines changed: 7 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,10 @@
2222
from langchain_core.embeddings import Embeddings
2323
from langchain_core.runnables.config import run_in_executor
2424
from langchain_core.vectorstores import VectorStore
25-
from pymongo import MongoClient, ReplaceOne
25+
from pymongo import MongoClient
2626
from pymongo.collection import Collection
2727
from pymongo.errors import CollectionInvalid
28+
from pymongo_vectorsearch_utils import bulk_embed_and_insert_texts
2829

2930
from langchain_mongodb.index import (
3031
create_vector_search_index,
@@ -362,11 +363,11 @@ def add_texts(
362363
metadatas_batch.append(metadata)
363364
if (j + 1) % batch_size == 0 or size >= 47_000_000:
364365
if ids:
365-
batch_res = self.bulk_embed_and_insert_texts(
366+
batch_res = bulk_embed_and_insert_texts(
366367
texts_batch, metadatas_batch, ids[i : j + 1]
367368
)
368369
else:
369-
batch_res = self.bulk_embed_and_insert_texts(
370+
batch_res = bulk_embed_and_insert_texts(
370371
texts_batch, metadatas_batch
371372
)
372373
result_ids.extend(batch_res)
@@ -376,13 +377,11 @@ def add_texts(
376377
i = j + 1
377378
if texts_batch:
378379
if ids:
379-
batch_res = self.bulk_embed_and_insert_texts(
380+
batch_res = bulk_embed_and_insert_texts(
380381
texts_batch, metadatas_batch, ids[i : j + 1]
381382
)
382383
else:
383-
batch_res = self.bulk_embed_and_insert_texts(
384-
texts_batch, metadatas_batch
385-
)
384+
batch_res = bulk_embed_and_insert_texts(texts_batch, metadatas_batch)
386385
result_ids.extend(batch_res)
387386
return result_ids
388387

@@ -419,37 +418,6 @@ def get_by_ids(self, ids: Sequence[str], /) -> list[Document]:
419418
docs.append(Document(page_content=text, id=oid_to_str(_id), metadata=doc))
420419
return docs
421420

422-
def bulk_embed_and_insert_texts(
423-
self,
424-
texts: Union[List[str], Iterable[str]],
425-
metadatas: Union[List[dict], Generator[dict, Any, Any]],
426-
ids: Optional[List[str]] = None,
427-
) -> List[str]:
428-
"""Bulk insert single batch of texts, embeddings, and optionally ids.
429-
430-
See add_texts for additional details.
431-
"""
432-
if not texts:
433-
return []
434-
# Compute embedding vectors
435-
embeddings = self._embedding.embed_documents(list(texts))
436-
if not ids:
437-
ids = [str(ObjectId()) for _ in range(len(list(texts)))]
438-
docs = [
439-
{
440-
"_id": str_to_oid(i),
441-
self._text_key: t,
442-
self._embedding_key: embedding,
443-
**m,
444-
}
445-
for i, t, m, embedding in zip(ids, texts, metadatas, embeddings)
446-
]
447-
operations = [ReplaceOne({"_id": doc["_id"]}, doc, upsert=True) for doc in docs]
448-
# insert the documents in MongoDB Atlas
449-
result = self._collection.bulk_write(operations)
450-
assert result.upserted_ids is not None
451-
return [oid_to_str(_id) for _id in result.upserted_ids.values()]
452-
453421
def add_documents(
454422
self,
455423
documents: List[Document],
@@ -481,7 +449,7 @@ def add_documents(
481449
*[(doc.page_content, doc.metadata) for doc in documents[start:end]]
482450
)
483451
result_ids.extend(
484-
self.bulk_embed_and_insert_texts(
452+
bulk_embed_and_insert_texts(
485453
texts=texts, metadatas=metadatas, ids=ids[start:end]
486454
)
487455
)

libs/langchain-mongodb/pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ dependencies = [
1616
"langchain-text-splitters>=0.3",
1717
"numpy>=1.26",
1818
"lark<2.0.0,>=1.1.9",
19+
# "pymongo-vectorsearch-utils",
1920
]
2021

2122
[dependency-groups]

libs/langchain-mongodb/tests/utils.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from pymongo.driver_info import DriverInfo
2727
from pymongo.operations import SearchIndexModel
2828
from pymongo.results import BulkWriteResult, DeleteResult, InsertManyResult
29+
from pymongo_vectorsearch_utils import bulk_embed_and_insert_texts
2930

3031
from langchain_mongodb import MongoDBAtlasVectorSearch
3132
from langchain_mongodb.agent_toolkit.database import MongoDBDatabase
@@ -63,7 +64,7 @@ def bulk_embed_and_insert_texts(
6364
ids: Optional[List[str]] = None,
6465
) -> List:
6566
"""Patched insert_texts that waits for data to be indexed before returning"""
66-
ids_inserted = super().bulk_embed_and_insert_texts(texts, metadatas, ids)
67+
ids_inserted = bulk_embed_and_insert_texts(texts, metadatas, ids)
6768
n_docs = self.collection.count_documents({})
6869
start = monotonic()
6970
while monotonic() - start <= TIMEOUT:

0 commit comments

Comments
 (0)