
Commit 41d0f0e

Merge pull request #1062 from stanfordnlp/revert-1027-main
Revert "fix(pinecone_rm): refactored to use cloud_embed and fix pinecone init"
2 parents 604943e + a377ffc commit 41d0f0e

2 files changed (+91, -715 lines)

dspy/retrieve/pinecone_rm.py

Lines changed: 91 additions & 90 deletions
@@ -3,7 +3,6 @@
 Author: Dhar Rawal (@drawal1)
 """
 
-from abc import ABC, abstractmethod
 from typing import List, Optional, Union
 
 import backoff
@@ -14,6 +13,9 @@
 try:
     import pinecone
 except ImportError:
+    pinecone = None
+
+if pinecone is None:
     raise ImportError(
         "The pinecone library is required to use PineconeRM. Install it with `pip install dspy-ai[pinecone]`",
     )
@@ -31,64 +33,6 @@
 except Exception:
     ERRORS = (openai.RateLimitError, openai.APIError)
 
-
-class CloudEmbedProvider(ABC):
-    def __init__ (self, model, api_key=None):
-        self.model = model
-        self.api_key = api_key
-
-    @abstractmethod
-    def get_embeddings(self, queries: List[str]) -> List[List[float]]:
-        pass
-
-class OpenAIEmbed(CloudEmbedProvider):
-    def __init__(self, model="text-embedding-ada-002", api_key: Optional[str]=None, org: Optional[str]=None):
-        super().__init__(model, api_key)
-        self.org = org
-        if self.api_key:
-            openai.api_key = self.api_key
-        if self.org:
-            openai.organization = org
-
-
-    @backoff.on_exception(
-        backoff.expo,
-        ERRORS,
-        max_time=15,
-    )
-    def get_embeddings(self, queries: List[str]) -> List[List[float]]:
-        if OPENAI_LEGACY:
-            embedding = openai.Embedding.create(
-                input=queries, model=self.model,
-            )
-        else:
-            embedding = openai.embeddings.create(
-                input=queries, model=self.model,
-            ).model_dump()
-        return [embedding["embedding"] for embedding in embedding["data"]]
-
-class CohereEmbed(CloudEmbedProvider):
-    def __init__(self, model: str = "multilingual-22-12", api_key: Optional[str] = None):
-        try:
-            import cohere
-        except ImportError:
-            raise ImportError(
-                "The cohere library is required to use CohereEmbed. Install it with `pip install cohere`",
-            )
-        super().__init__(model, api_key)
-        self.client = cohere.Client(api_key)
-
-    @backoff.on_exception(
-        backoff.expo,
-        ERRORS,
-        max_time=15,
-    )
-    def get_embeddings(self, queries: List[str]) -> List[List[float]]:
-        embeddings = self.client.embed(texts=queries, model=self.model).embeddings
-        return embeddings
-
-
-
 class PineconeRM(dspy.Retrieve):
     """
     A retrieval module that uses Pinecone to return the top passages for a given query.
@@ -99,8 +43,11 @@ class PineconeRM(dspy.Retrieve):
     Args:
         pinecone_index_name (str): The name of the Pinecone index to query against.
         pinecone_api_key (str, optional): The Pinecone API key. Defaults to None.
+        pinecone_env (str, optional): The Pinecone environment. Defaults to None.
         local_embed_model (str, optional): The local embedding model to use. A popular default is "sentence-transformers/all-mpnet-base-v2".
-        cloud_emded_provider (CloudEmbedProvider, optional): The cloud embedding provider to use. Defaults to None.
+        openai_embed_model (str, optional): The OpenAI embedding model to use. Defaults to "text-embedding-ada-002".
+        openai_api_key (str, optional): The API key for OpenAI. Defaults to None.
+        openai_org (str, optional): The organization for OpenAI. Defaults to None.
        k (int, optional): The number of top passages to retrieve. Defaults to 3.

    Returns:
@@ -110,7 +57,6 @@ class PineconeRM(dspy.Retrieve):
     Below is a code snippet that shows how to use this as the default retriever:
     ```python
     llm = dspy.OpenAI(model="gpt-3.5-turbo")
-    retriever_model = PineconeRM(index_name, cloud_emded_provider=OpenAIEmbed())
     retriever_model = PineconeRM(openai.api_key)
     dspy.settings.configure(lm=llm, rm=retriever_model)
     ```
@@ -125,8 +71,11 @@ def __init__(
         self,
         pinecone_index_name: str,
         pinecone_api_key: Optional[str] = None,
+        pinecone_env: Optional[str] = None,
         local_embed_model: Optional[str] = None,
-        cloud_emded_provider: Optional[CloudEmbedProvider] = None,
+        openai_embed_model: Optional[str] = "text-embedding-ada-002",
+        openai_api_key: Optional[str] = None,
+        openai_org: Optional[str] = None,
         k: int = 3,
     ):
         if local_embed_model is not None:
@@ -146,25 +95,69 @@ def __init__(
                 'mps' if torch.backends.mps.is_available()
                 else 'cpu',
             )
-
-        elif cloud_emded_provider is not None:
+        elif openai_embed_model is not None:
+            self._openai_embed_model = openai_embed_model
             self.use_local_model = False
-            self.cloud_emded_provider = cloud_emded_provider
-
+            # If not provided, defaults to env vars OPENAI_API_KEY and OPENAI_ORGANIZATION
+            if openai_api_key:
+                openai.api_key = openai_api_key
+            if openai_org:
+                openai.organization = openai_org
         else:
             raise ValueError(
-                "Either local_embed_model or cloud_embed_provider must be provided.",
+                "Either local_embed_model or openai_embed_model must be provided.",
             )
 
-        if pinecone_api_key is None:
-            self.pinecone_client = pinecone.Pinecone()
-        else:
-            self.pinecone_client = pinecone.Pinecone(api_key=pinecone_api_key)
-
-        self._pinecone_index = self.pinecone_client.Index(pinecone_index_name)
+        self._pinecone_index = self._init_pinecone(
+            pinecone_index_name, pinecone_api_key, pinecone_env,
+        )
 
         super().__init__(k=k)
 
+    def _init_pinecone(
+        self,
+        index_name: str,
+        api_key: Optional[str] = None,
+        environment: Optional[str] = None,
+        dimension: Optional[int] = None,
+        distance_metric: Optional[str] = None,
+    ) -> pinecone.Index:
+        """Initialize pinecone and return the loaded index.
+
+        Args:
+            index_name (str): The name of the index to load. If the index does not exist, it will be created.
+            api_key (str, optional): The Pinecone API key, defaults to env var PINECONE_API_KEY if not provided.
+            environment (str, optional): The environment (e.g. `us-west1-gcp` or `gcp-starter`). Defaults to env PINECONE_ENVIRONMENT.
+
+        Raises:
+            ValueError: If api_key or environment is not provided and not set as an environment variable.
+
+        Returns:
+            pinecone.Index: The loaded index.
+        """
+
+        # Pinecone init overrides default if kwargs are present, so we need to exclude if None
+        kwargs = {}
+        if api_key:
+            kwargs["api_key"] = api_key
+        if environment:
+            kwargs["environment"] = environment
+        pinecone.init(**kwargs)
+
+        active_indexes = pinecone.list_indexes()
+        if index_name not in active_indexes:
+            if dimension is None or distance_metric is None:
+                raise ValueError(
+                    "dimension and distance_metric must be provided since the index provided does not exist.",
+                )
+
+            pinecone.create_index(
+                name=index_name,
+                dimension=dimension,
+                metric=distance_metric,
+            )
+
+        return pinecone.Index(index_name)
 
     def _mean_pooling(
         self,
@@ -182,7 +175,11 @@ def _mean_pooling(
         input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
         return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
 
-
+    @backoff.on_exception(
+        backoff.expo,
+        ERRORS,
+        max_time=15,
+    )
     def _get_embeddings(
         self,
         queries: List[str],
@@ -195,16 +192,24 @@ def _get_embeddings(
         Returns:
             List[List[float]]: List of embeddings corresponding to each query.
         """
-        if not self.use_local_model:
-            return self.cloud_emded_provider.get_embeddings(queries)
-
         try:
             import torch
         except ImportError as exc:
             raise ModuleNotFoundError(
                 "You need to install torch to use a local embedding model with PineconeRM.",
             ) from exc
 
+        if not self.use_local_model:
+            if OPENAI_LEGACY:
+                embedding = openai.Embedding.create(
+                    input=queries, model=self._openai_embed_model,
+                )
+            else:
+                embedding = openai.embeddings.create(
+                    input=queries, model=self._openai_embed_model,
+                ).model_dump()
+            return [embedding["embedding"] for embedding in embedding["data"]]
+
         # Use local model
         encoded_input = self._local_tokenizer(queries, padding=True, truncation=True, return_tensors="pt").to(self.device)
         with torch.no_grad():
@@ -217,55 +222,51 @@ def _get_embeddings(
         # we need a pooling strategy to get a single vector representation of the input
         # so the default is to take the mean of the hidden states
 
-    def forward(self, query_or_queries: Union[str, List[str]], k: Optional[int] = None) -> dspy.Prediction:
+    def forward(self, query_or_queries: Union[str, List[str]]) -> dspy.Prediction:
         """Search with pinecone for self.k top passages for query
 
         Args:
             query_or_queries (Union[str, List[str]]): The query or queries to search for.
-            k (Optional[int]): The number of top passages to retrieve. Defaults to self.k
 
         Returns:
             dspy.Prediction: An object containing the retrieved passages.
         """
-        k = k if k is not None else self.k
         queries = (
             [query_or_queries]
             if isinstance(query_or_queries, str)
             else query_or_queries
         )
         queries = [q for q in queries if q] # Filter empty queries
         embeddings = self._get_embeddings(queries)
+
         # For single query, just look up the top k passages
         if len(queries) == 1:
             results_dict = self._pinecone_index.query(
-                vector=embeddings[0], top_k=k, include_metadata=True,
+                embeddings[0], top_k=self.k, include_metadata=True,
             )
 
             # Sort results by score
             sorted_results = sorted(
                 results_dict["matches"], key=lambda x: x.get("score", 0.0), reverse=True,
             )
-
             passages = [result["metadata"]["text"] for result in sorted_results]
-            passages = [dotdict({"long_text": passage}) for passage in passages]
-            return passages
+            passages = [dotdict({"long_text": passage}) for passage in passages]
+            return dspy.Prediction(passages=passages)
 
         # For multiple queries, query each and return the highest scoring passages
         # If a passage is returned multiple times, the score is accumulated. For this reason we increase top_k by 3x
         passage_scores = {}
         for embedding in embeddings:
             results_dict = self._pinecone_index.query(
-                vector=embedding, top_k=k * 3, include_metadata=True,
+                embedding, top_k=self.k * 3, include_metadata=True,
             )
             for result in results_dict["matches"]:
                 passage_scores[result["metadata"]["text"]] = (
                     passage_scores.get(result["metadata"]["text"], 0.0)
                     + result["score"]
                 )
-
+
         sorted_passages = sorted(
             passage_scores.items(), key=lambda x: x[1], reverse=True,
-        )[: k]
-
-        passages=[dotdict({"long_text": passage}) for passage, _ in sorted_passages]
-        return passages
+        )[: self.k]
+        return dspy.Prediction(passages=[dotdict({"long_text": passage}) for passage, _ in sorted_passages])
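For orientation, here is a minimal usage sketch of the module as it stands after this revert, mirroring the docstring snippet above. It is illustrative rather than an official example: the index name `quickstart` and environment `gcp-starter` are hypothetical placeholders, and the sketch assumes `PINECONE_API_KEY` and `OPENAI_API_KEY` are set as environment variables so no keys need to be passed explicitly.

```python
# Minimal usage sketch for the post-revert PineconeRM (placeholders marked).
import dspy
from dspy.retrieve.pinecone_rm import PineconeRM

# "quickstart" and "gcp-starter" are hypothetical placeholders; the restored
# _init_pinecone falls back to the PINECONE_API_KEY / PINECONE_ENVIRONMENT
# env vars when pinecone_api_key / pinecone_env are omitted.
retriever_model = PineconeRM(
    pinecone_index_name="quickstart",
    pinecone_env="gcp-starter",
    openai_embed_model="text-embedding-ada-002",  # default OpenAI embedder
    k=3,
)

llm = dspy.OpenAI(model="gpt-3.5-turbo")
dspy.settings.configure(lm=llm, rm=retriever_model)

# forward() takes a query (or list of queries) and returns a dspy.Prediction;
# each retrieved passage is a dotdict with a "long_text" field.
prediction = retriever_model("What is retrieval-augmented generation?")
print(prediction.passages)
```

Note the design choice visible in the multi-query branch of `forward`: each query is searched with `top_k = self.k * 3`, and the scores of passages returned by more than one query are summed before the final top-k cut, so passages that answer several queries rank higher.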

0 commit comments
