Skip to content

Commit 0e96712

Browse files
committed
(250) - Add MongoDB Atlas Retrieval Model
1 parent a08b4ac commit 0e96712

File tree

3 files changed

+113
-2
lines changed

3 files changed

+113
-2
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ Or open our intro notebook in Google Colab: [<img align="center" src="https://co
6262
For the optional Pinecone, Qdrant, [chromadb](https://github.com/chroma-core/chroma), or [marqo](https://github.com/marqo-ai/marqo) retrieval integration(s), include the extra(s) below:
6363

6464
```
65-
pip install dspy-ai[pinecone] # or [qdrant] or [chromadb] or [marqo]
65+
pip install dspy-ai[pinecone] # or [qdrant] or [chromadb] or [marqo] or [mongodb]
6666
```
6767

6868
## 2) Syntax: You're in charge of the workflow—it's free-form Python code!

dspy/retrieve/mongodb_atlas_rm.py

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
from typing import List, Optional, Union, Any
2+
import dspy
3+
import os
4+
import openai
5+
import backoff
6+
7+
try:
8+
from pymongo import MongoClient
9+
from pymongo.errors import (
10+
ConnectionFailure,
11+
ConfigurationError,
12+
ServerSelectionTimeoutError,
13+
InvalidURI,
14+
OperationFailure,
15+
)
16+
except ImportError:
17+
raise ImportError(
18+
"Please install the pymongo package by running `pip install dspy-ai[mongodb]`"
19+
)
20+
21+
22+
def build_vector_search_pipeline(
23+
index_name: str, query_vector: List[float], num_candidates: int, limit: int
24+
) -> List[dict[str, Any]]:
25+
return [
26+
{
27+
"$vectorSearch": {
28+
"index": index_name,
29+
"path": "embedding",
30+
"queryVector": query_vector,
31+
"numCandidates": num_candidates,
32+
"limit": limit,
33+
}
34+
},
35+
{"$project": {"_id": 0, "text": 1, "score": {"$meta": "vectorSearchScore"}}},
36+
]
37+
38+
39+
class Embedder:
40+
def __init__(self, provider: str, model: str):
41+
if provider == "openai":
42+
openai.api_key = os.getenv("OPENAI_API_KEY")
43+
if not openai.api_key:
44+
raise ValueError("Environment variable OPENAI_API_KEY must be set")
45+
self.client = openai
46+
self.model = model
47+
48+
@backoff.on_exception(
49+
backoff.expo,
50+
(
51+
openai.error.RateLimitError,
52+
openai.error.ServiceUnavailableError,
53+
openai.error.APIError,
54+
),
55+
max_time=15,
56+
)
57+
def __call__(self, queries) -> Any:
58+
embedding = self.client.Embedding.create(input=queries, model=self.model)
59+
return [embedding["embedding"] for embedding in embedding["data"]]
60+
61+
62+
class MongoDBAtlasRM(dspy.Retrieve):
63+
def __init__(
64+
self,
65+
db_name: str,
66+
collection_name: str,
67+
index_name: str,
68+
k: int = 5,
69+
embedding_provider: str = "openai",
70+
embedding_model: str = "text-embedding-ada-002",
71+
):
72+
super().__init__(k=k)
73+
self.db_name = db_name
74+
self.collection_name = collection_name
75+
self.index_name = index_name
76+
self.username = os.getenv("ATLAS_USERNAME")
77+
self.password = os.getenv("ATLAS_PASSWORD")
78+
self.cluster_url = os.getenv("ATLAS_CLUSTER_URL")
79+
if not self.username:
80+
raise ValueError("Environment variable ATLAS_USERNAME must be set")
81+
if not self.password:
82+
raise ValueError("Environment variable ATLAS_PASSWORD must be set")
83+
if not self.cluster_url:
84+
raise ValueError("Environment variable ATLAS_CLUSTER_URL must be set")
85+
try:
86+
self.client = MongoClient(
87+
f"mongodb+srv://{self.username}:{self.password}@{self.cluster_url}/{self.db_name}"
88+
"?retryWrites=true&w=majority"
89+
)
90+
except (
91+
InvalidURI,
92+
ConfigurationError,
93+
ConnectionFailure,
94+
ServerSelectionTimeoutError,
95+
OperationFailure,
96+
) as e:
97+
raise ConnectionError("Failed to connect to MongoDB Atlas") from e
98+
99+
self.embedder = Embedder(provider=embedding_provider, model=embedding_model)
100+
101+
def forward(self, query_or_queries: str) -> dspy.Prediction:
102+
query_vector = self.embedder([query_or_queries])
103+
pipeline = build_vector_search_pipeline(
104+
index_name=self.index_name,
105+
query_vector=query_vector[0],
106+
num_candidates=self.k * 10,
107+
limit=self.k,
108+
)
109+
contents = self.client[self.db_name][self.collection_name].aggregate(pipeline)
110+
return dspy.Prediction(passages=list(contents))

setup.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,8 @@
2525
"pinecone": ["pinecone-client~=2.2.4"],
2626
"qdrant": ["qdrant-client~=1.6.2", "fastembed~=0.1.0"],
2727
"chromadb": ["chromadb~=0.4.14"],
28-
"marqo": ["marqo"]
28+
"marqo": ["marqo"],
29+
"mongodb": ["pymongo~=3.12.0"],
2930
},
3031
classifiers=[
3132
"Development Status :: 3 - Alpha",

0 commit comments

Comments
 (0)