
Commit 604943e

Merge pull request #1027 from csellis/main
fix(pinecone_rm): refactored to use cloud_embed and fix pinecone init
2 parents 1351647 + e2115cf
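
In practice, the refactor replaces the old `openai_*` constructor arguments with a single `CloudEmbedProvider` object and drops `pinecone_env`. A minimal sketch of the new call pattern, assuming the Pinecone index already exists (the index name and key below are placeholders):

```python
import dspy
from dspy.retrieve.pinecone_rm import PineconeRM, OpenAIEmbed

# Embedding calls are delegated to a CloudEmbedProvider subclass instead of
# being wired into PineconeRM through openai_* keyword arguments.
retriever_model = PineconeRM(
    pinecone_index_name="my-index",       # placeholder: an existing Pinecone index
    pinecone_api_key=None,                # None lets pinecone.Pinecone() fall back to its env configuration
    cloud_emded_provider=OpenAIEmbed(),   # parameter name as spelled in this commit
    k=3,
)

llm = dspy.OpenAI(model="gpt-3.5-turbo")
dspy.settings.configure(lm=llm, rm=retriever_model)
```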

File tree

2 files changed: +715 −91 lines changed


dspy/retrieve/pinecone_rm.py

Lines changed: 90 additions & 91 deletions
@@ -3,6 +3,7 @@
 Author: Dhar Rawal (@drawal1)
 """
 
+from abc import ABC, abstractmethod
 from typing import List, Optional, Union
 
 import backoff
@@ -13,9 +14,6 @@
 try:
     import pinecone
 except ImportError:
-    pinecone = None
-
-if pinecone is None:
     raise ImportError(
         "The pinecone library is required to use PineconeRM. Install it with `pip install dspy-ai[pinecone]`",
     )
@@ -33,6 +31,64 @@
 except Exception:
     ERRORS = (openai.RateLimitError, openai.APIError)
 
+
+class CloudEmbedProvider(ABC):
+    def __init__ (self, model, api_key=None):
+        self.model = model
+        self.api_key = api_key
+
+    @abstractmethod
+    def get_embeddings(self, queries: List[str]) -> List[List[float]]:
+        pass
+
+class OpenAIEmbed(CloudEmbedProvider):
+    def __init__(self, model="text-embedding-ada-002", api_key: Optional[str]=None, org: Optional[str]=None):
+        super().__init__(model, api_key)
+        self.org = org
+        if self.api_key:
+            openai.api_key = self.api_key
+        if self.org:
+            openai.organization = org
+
+
+    @backoff.on_exception(
+        backoff.expo,
+        ERRORS,
+        max_time=15,
+    )
+    def get_embeddings(self, queries: List[str]) -> List[List[float]]:
+        if OPENAI_LEGACY:
+            embedding = openai.Embedding.create(
+                input=queries, model=self.model,
+            )
+        else:
+            embedding = openai.embeddings.create(
+                input=queries, model=self.model,
+            ).model_dump()
+        return [embedding["embedding"] for embedding in embedding["data"]]
+
+class CohereEmbed(CloudEmbedProvider):
+    def __init__(self, model: str = "multilingual-22-12", api_key: Optional[str] = None):
+        try:
+            import cohere
+        except ImportError:
+            raise ImportError(
+                "The cohere library is required to use CohereEmbed. Install it with `pip install cohere`",
+            )
+        super().__init__(model, api_key)
+        self.client = cohere.Client(api_key)
+
+    @backoff.on_exception(
+        backoff.expo,
+        ERRORS,
+        max_time=15,
+    )
+    def get_embeddings(self, queries: List[str]) -> List[List[float]]:
+        embeddings = self.client.embed(texts=queries, model=self.model).embeddings
+        return embeddings
+
+
+
 class PineconeRM(dspy.Retrieve):
     """
     A retrieval module that uses Pinecone to return the top passages for a given query.
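
Since `CloudEmbedProvider` only requires `get_embeddings`, further backends can be added by subclassing it. A hypothetical stub (not part of this commit) that satisfies the contract without any network calls, e.g. for tests:

```python
from typing import List

# Hypothetical example, not part of this commit: a no-network stub provider.
class FakeEmbed(CloudEmbedProvider):
    def __init__(self, dimension: int = 1536):
        super().__init__(model="fake", api_key=None)
        self.dimension = dimension

    def get_embeddings(self, queries: List[str]) -> List[List[float]]:
        # One fixed zero vector per query, matching the List[List[float]] contract.
        return [[0.0] * self.dimension for _ in queries]
```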
@@ -43,11 +99,8 @@ class PineconeRM(dspy.Retrieve):
     Args:
         pinecone_index_name (str): The name of the Pinecone index to query against.
         pinecone_api_key (str, optional): The Pinecone API key. Defaults to None.
-        pinecone_env (str, optional): The Pinecone environment. Defaults to None.
         local_embed_model (str, optional): The local embedding model to use. A popular default is "sentence-transformers/all-mpnet-base-v2".
-        openai_embed_model (str, optional): The OpenAI embedding model to use. Defaults to "text-embedding-ada-002".
-        openai_api_key (str, optional): The API key for OpenAI. Defaults to None.
-        openai_org (str, optional): The organization for OpenAI. Defaults to None.
+        cloud_emded_provider (CloudEmbedProvider, optional): The cloud embedding provider to use. Defaults to None.
         k (int, optional): The number of top passages to retrieve. Defaults to 3.
 
     Returns:
@@ -57,6 +110,7 @@ class PineconeRM(dspy.Retrieve):
     Below is a code snippet that shows how to use this as the default retriver:
     ```python
     llm = dspy.OpenAI(model="gpt-3.5-turbo")
+    retriever_model = PineconeRM(index_name, cloud_emded_provider=OpenAIEmbed())
     retriever_model = PineconeRM(openai.api_key)
     dspy.settings.configure(lm=llm, rm=retriever_model)
     ```
@@ -71,11 +125,8 @@ def __init__(
         self,
         pinecone_index_name: str,
         pinecone_api_key: Optional[str] = None,
-        pinecone_env: Optional[str] = None,
         local_embed_model: Optional[str] = None,
-        openai_embed_model: Optional[str] = "text-embedding-ada-002",
-        openai_api_key: Optional[str] = None,
-        openai_org: Optional[str] = None,
+        cloud_emded_provider: Optional[CloudEmbedProvider] = None,
         k: int = 3,
     ):
         if local_embed_model is not None:
@@ -95,69 +146,25 @@ def __init__(
                 'mps' if torch.backends.mps.is_available()
                 else 'cpu',
             )
-        elif openai_embed_model is not None:
-            self._openai_embed_model = openai_embed_model
+
+        elif cloud_emded_provider is not None:
             self.use_local_model = False
-            # If not provided, defaults to env vars OPENAI_API_KEY and OPENAI_ORGANIZATION
-            if openai_api_key:
-                openai.api_key = openai_api_key
-            if openai_org:
-                openai.organization = openai_org
+            self.cloud_emded_provider = cloud_emded_provider
+
         else:
             raise ValueError(
-                "Either local_embed_model or openai_embed_model must be provided.",
+                "Either local_embed_model or cloud_embed_provider must be provided.",
             )
 
-        self._pinecone_index = self._init_pinecone(
-            pinecone_index_name, pinecone_api_key, pinecone_env,
-        )
-
-        super().__init__(k=k)
-
-    def _init_pinecone(
-        self,
-        index_name: str,
-        api_key: Optional[str] = None,
-        environment: Optional[str] = None,
-        dimension: Optional[int] = None,
-        distance_metric: Optional[str] = None,
-    ) -> pinecone.Index:
-        """Initialize pinecone and return the loaded index.
-
-        Args:
-            index_name (str): The name of the index to load. If the index is not does not exist, it will be created.
-            api_key (str, optional): The Pinecone API key, defaults to env var PINECONE_API_KEY if not provided.
-            environment (str, optional): The environment (ie. `us-west1-gcp` or `gcp-starter`. Defaults to env PINECONE_ENVIRONMENT.
-
-        Raises:
-            ValueError: If api_key or environment is not provided and not set as an environment variable.
-
-        Returns:
-            pinecone.Index: The loaded index.
-        """
+        if pinecone_api_key is None:
+            self.pinecone_client = pinecone.Pinecone()
+        else:
+            self.pinecone_client = pinecone.Pinecone(api_key=pinecone_api_key)
 
-        # Pinecone init overrides default if kwargs are present, so we need to exclude if None
-        kwargs = {}
-        if api_key:
-            kwargs["api_key"] = api_key
-        if environment:
-            kwargs["environment"] = environment
-        pinecone.init(**kwargs)
-
-        active_indexes = pinecone.list_indexes()
-        if index_name not in active_indexes:
-            if dimension is None and distance_metric is None:
-                raise ValueError(
-                    "dimension and distance_metric must be provided since the index provided does not exist.",
-                )
+        self._pinecone_index = self.pinecone_client.Index(pinecone_index_name)
 
-            pinecone.create_index(
-                name=index_name,
-                dimension=dimension,
-                metric=distance_metric,
-            )
+        super().__init__(k=k)
 
-        return pinecone.Index(index_name)
 
     def _mean_pooling(
         self,
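
The init fix tracks the newer pinecone client API, where module-level `pinecone.init(...)` / `pinecone.Index(...)` calls were replaced by a `Pinecone` client object; as a side effect, `_init_pinecone` and its create-if-missing logic are gone, so the named index must already exist. The pattern the constructor now follows, shown in isolation (the key and index name are placeholders):

```python
import pinecone

# Newer pinecone-client style: construct a client object; calling
# pinecone.Pinecone() with no api_key falls back to its environment configuration.
client = pinecone.Pinecone(api_key="YOUR_PINECONE_API_KEY")  # placeholder key
index = client.Index("my-index")  # placeholder: must be an existing index
```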
@@ -175,11 +182,7 @@ def _mean_pooling(
         input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
         return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
 
-    @backoff.on_exception(
-        backoff.expo,
-        ERRORS,
-        max_time=15,
-    )
+
     def _get_embeddings(
         self,
         queries: List[str],
@@ -192,24 +195,16 @@ def _get_embeddings(
         Returns:
             List[List[float]]: List of embeddings corresponding to each query.
         """
+        if not self.use_local_model:
+            return self.cloud_emded_provider.get_embeddings(queries)
+
         try:
             import torch
         except ImportError as exc:
             raise ModuleNotFoundError(
                 "You need to install torch to use a local embedding model with PineconeRM.",
             ) from exc
 
-        if not self.use_local_model:
-            if OPENAI_LEGACY:
-                embedding = openai.Embedding.create(
-                    input=queries, model=self._openai_embed_model,
-                )
-            else:
-                embedding = openai.embeddings.create(
-                    input=queries, model=self._openai_embed_model,
-                ).model_dump()
-            return [embedding["embedding"] for embedding in embedding["data"]]
-
         # Use local model
         encoded_input = self._local_tokenizer(queries, padding=True, truncation=True, return_tensors="pt").to(self.device)
         with torch.no_grad():
@@ -222,51 +217,55 @@ def _get_embeddings(
         # we need a pooling strategy to get a single vector representation of the input
         # so the default is to take the mean of the hidden states
 
-    def forward(self, query_or_queries: Union[str, List[str]]) -> dspy.Prediction:
+    def forward(self, query_or_queries: Union[str, List[str]], k: Optional[int] = None) -> dspy.Prediction:
         """Search with pinecone for self.k top passages for query
 
         Args:
             query_or_queries (Union[str, List[str]]): The query or queries to search for.
+            k (Optional[int]): The number of top passages to retrieve. Defaults to self.k
 
         Returns:
             dspy.Prediction: An object containing the retrieved passages.
         """
+        k = k if k is not None else self.k
         queries = (
             [query_or_queries]
             if isinstance(query_or_queries, str)
             else query_or_queries
         )
         queries = [q for q in queries if q]  # Filter empty queries
         embeddings = self._get_embeddings(queries)
-
         # For single query, just look up the top k passages
         if len(queries) == 1:
             results_dict = self._pinecone_index.query(
-                embeddings[0], top_k=self.k, include_metadata=True,
+                vector=embeddings[0], top_k=k, include_metadata=True,
             )
 
             # Sort results by score
             sorted_results = sorted(
                 results_dict["matches"], key=lambda x: x.get("scores", 0.0), reverse=True,
             )
+
             passages = [result["metadata"]["text"] for result in sorted_results]
-            passages = [dotdict({"long_text": passage for passage in passages})]
-            return dspy.Prediction(passages=passages)
+            passages = [dotdict({"long_text": passage}) for passage in passages]
+            return passages
 
         # For multiple queries, query each and return the highest scoring passages
         # If a passage is returned multiple times, the score is accumulated. For this reason we increase top_k by 3x
         passage_scores = {}
         for embedding in embeddings:
             results_dict = self._pinecone_index.query(
-                embedding, top_k=self.k * 3, include_metadata=True,
+                vector=embedding, top_k=k * 3, include_metadata=True,
             )
             for result in results_dict["matches"]:
                 passage_scores[result["metadata"]["text"]] = (
                     passage_scores.get(result["metadata"]["text"], 0.0)
                     + result["score"]
                 )
-
+
         sorted_passages = sorted(
             passage_scores.items(), key=lambda x: x[1], reverse=True,
-        )[: self.k]
-        return dspy.Prediction(passages=[dotdict({"long_text": passage}) for passage, _ in sorted_passages])
+        )[: k]
+
+        passages=[dotdict({"long_text": passage}) for passage, _ in sorted_passages]
+        return passages
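
Two behavioral notes on the new `forward`: the embedding is now passed as the keyword argument `vector=` with a per-call `k` override, and both branches now return a plain list of `dotdict` passages rather than a `dspy.Prediction`, even though the return type hint and docstring still promise the latter. A hedged usage sketch, reusing `retriever_model` from the earlier snippet (the query text is a placeholder):

```python
# Assumes retriever_model from the earlier sketch; the query is a placeholder.
passages = retriever_model.forward("What is retrieval-augmented generation?", k=5)
for passage in passages:
    print(passage["long_text"])
```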
