Skip to content

Commit f94861f

Browse files
authored
Merge pull request #257 from jhyearsley/main
(250) - Add MongoDB Atlas Retrieval Model
2 parents f3dda13 + 797cf94 commit f94861f

File tree

3 files changed

+120
-2
lines changed

3 files changed

+120
-2
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ By default, DSPy depends on `openai==0.28`. However, if you install `openai>=1.0
6363
For the optional Pinecone, Qdrant, [chromadb](https://github.com/chroma-core/chroma), or [marqo](https://github.com/marqo-ai/marqo) retrieval integration(s), include the extra(s) below:
6464

6565
```
66-
pip install dspy-ai[pinecone] # or [qdrant] or [chromadb] or [marqo]
66+
pip install dspy-ai[pinecone] # or [qdrant] or [chromadb] or [marqo] or [mongodb]
6767
```
6868

6969
## 2) Documentation

dspy/retrieve/mongodb_atlas_rm.py

Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
from typing import List, Optional, Union, Any
2+
import dspy
3+
import os
4+
from openai import (
5+
OpenAI,
6+
APITimeoutError,
7+
InternalServerError,
8+
RateLimitError,
9+
UnprocessableEntityError,
10+
)
11+
import backoff
12+
13+
try:
14+
from pymongo import MongoClient
15+
from pymongo.errors import (
16+
ConnectionFailure,
17+
ConfigurationError,
18+
ServerSelectionTimeoutError,
19+
InvalidURI,
20+
OperationFailure,
21+
)
22+
except ImportError:
23+
raise ImportError(
24+
"Please install the pymongo package by running `pip install dspy-ai[mongodb]`"
25+
)
26+
27+
28+
def build_vector_search_pipeline(
29+
index_name: str, query_vector: List[float], num_candidates: int, limit: int
30+
) -> List[dict[str, Any]]:
31+
return [
32+
{
33+
"$vectorSearch": {
34+
"index": index_name,
35+
"path": "embedding",
36+
"queryVector": query_vector,
37+
"numCandidates": num_candidates,
38+
"limit": limit,
39+
}
40+
},
41+
{"$project": {"_id": 0, "text": 1, "score": {"$meta": "vectorSearchScore"}}},
42+
]
43+
44+
45+
class Embedder:
46+
def __init__(self, provider: str, model: str):
47+
if provider == "openai":
48+
api_key = os.getenv("OPENAI_API_KEY")
49+
if not api_key:
50+
raise ValueError("Environment variable OPENAI_API_KEY must be set")
51+
self.client = OpenAI()
52+
self.model = model
53+
54+
@backoff.on_exception(
55+
backoff.expo,
56+
(
57+
APITimeoutError,
58+
InternalServerError,
59+
RateLimitError,
60+
UnprocessableEntityError,
61+
),
62+
max_time=15,
63+
)
64+
def __call__(self, queries) -> Any:
65+
embedding = self.client.embeddings.create(input=queries, model=self.model)
66+
return [result.embedding for result in embedding.data]
67+
68+
69+
class MongoDBAtlasRM(dspy.Retrieve):
70+
def __init__(
71+
self,
72+
db_name: str,
73+
collection_name: str,
74+
index_name: str,
75+
k: int = 5,
76+
embedding_provider: str = "openai",
77+
embedding_model: str = "text-embedding-ada-002",
78+
):
79+
super().__init__(k=k)
80+
self.db_name = db_name
81+
self.collection_name = collection_name
82+
self.index_name = index_name
83+
self.username = os.getenv("ATLAS_USERNAME")
84+
self.password = os.getenv("ATLAS_PASSWORD")
85+
self.cluster_url = os.getenv("ATLAS_CLUSTER_URL")
86+
if not self.username:
87+
raise ValueError("Environment variable ATLAS_USERNAME must be set")
88+
if not self.password:
89+
raise ValueError("Environment variable ATLAS_PASSWORD must be set")
90+
if not self.cluster_url:
91+
raise ValueError("Environment variable ATLAS_CLUSTER_URL must be set")
92+
try:
93+
self.client = MongoClient(
94+
f"mongodb+srv://{self.username}:{self.password}@{self.cluster_url}/{self.db_name}"
95+
"?retryWrites=true&w=majority"
96+
)
97+
except (
98+
InvalidURI,
99+
ConfigurationError,
100+
ConnectionFailure,
101+
ServerSelectionTimeoutError,
102+
OperationFailure,
103+
) as e:
104+
raise ConnectionError("Failed to connect to MongoDB Atlas") from e
105+
106+
self.embedder = Embedder(provider=embedding_provider, model=embedding_model)
107+
108+
def forward(self, query_or_queries: str) -> dspy.Prediction:
109+
query_vector = self.embedder([query_or_queries])
110+
pipeline = build_vector_search_pipeline(
111+
index_name=self.index_name,
112+
query_vector=query_vector[0],
113+
num_candidates=self.k * 10,
114+
limit=self.k,
115+
)
116+
contents = self.client[self.db_name][self.collection_name].aggregate(pipeline)
117+
return dspy.Prediction(passages=list(contents))

setup.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,8 @@
2525
"pinecone": ["pinecone-client~=2.2.4"],
2626
"qdrant": ["qdrant-client~=1.6.2", "fastembed~=0.1.0"],
2727
"chromadb": ["chromadb~=0.4.14"],
28-
"marqo": ["marqo"]
28+
"marqo": ["marqo"],
29+
"mongodb": ["pymongo~=3.12.0"],
2930
},
3031
classifiers=[
3132
"Development Status :: 3 - Alpha",

0 commit comments

Comments
 (0)