
Commit f7f4ee7

fix(pinecone_rm): refactored to use cloud_embed and fix pinecone init
source is from https://github.com/stanfordnlp/dspy/pull/342/files
1 parent 2a97d20 commit f7f4ee7
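
For reference, a minimal sketch of how the refactored retriever is constructed after this change. The index name is hypothetical, and the sketch assumes `OPENAI_API_KEY` and `PINECONE_API_KEY` are available in the environment; note that the keyword argument is spelled `cloud_emded_provider` in this commit.

```python
import dspy
from dspy.retrieve.pinecone_rm import PineconeRM, OpenAIEmbed

llm = dspy.OpenAI(model="gpt-3.5-turbo")

# "my-index" is a hypothetical, pre-existing Pinecone index; the refactored code
# no longer creates indexes, it only connects to one via pinecone.Pinecone().
retriever_model = PineconeRM(
    "my-index",
    cloud_emded_provider=OpenAIEmbed(),  # replaces the old openai_embed_model/openai_api_key/openai_org arguments
    k=3,
)
dspy.settings.configure(lm=llm, rm=retriever_model)
```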

File tree

2 files changed: +728 −106 lines changed


dspy/retrieve/pinecone_rm.py

Lines changed: 103 additions & 106 deletions
@@ -3,25 +3,21 @@
 Author: Dhar Rawal (@drawal1)
 """

-from typing import List, Optional, Union
-
-import backoff
-
-import dspy
+import os
 from dsp.utils import dotdict
+from typing import Optional, List, Union, Any
+import dspy
+import backoff
+from abc import ABC, abstractmethod

 try:
     import pinecone
 except ImportError:
-    pinecone = None
-
-if pinecone is None:
     raise ImportError(
-        "The pinecone library is required to use PineconeRM. Install it with `pip install dspy-ai[pinecone]`",
+        "The pinecone library is required to use PineconeRM. Install it with `pip install dspy-ai[pinecone]`"
     )

 import openai
-
 try:
     OPENAI_LEGACY = int(openai.version.__version__[0]) == 0
 except Exception:
@@ -33,6 +29,64 @@
 except Exception:
     ERRORS = (openai.RateLimitError, openai.APIError)

+
+class CloudEmbedProvider(ABC):
+    def __init__ (self, model, api_key=None):
+        self.model = model
+        self.api_key = api_key
+
+    @abstractmethod
+    def get_embeddings(self, queries: List[str]) -> List[List[float]]:
+        pass
+
+class OpenAIEmbed(CloudEmbedProvider):
+    def __init__(self, model="text-embedding-ada-002", api_key: Optional[str]=None, org: Optional[str]=None):
+        super().__init__(model, api_key)
+        self.org = org
+        if self.api_key:
+            openai.api_key = self.api_key
+        if self.org:
+            openai.organization = org
+
+
+    @backoff.on_exception(
+        backoff.expo,
+        ERRORS,
+        max_time=15,
+    )
+    def get_embeddings(self, queries: List[str]) -> List[List[float]]:
+        if OPENAI_LEGACY:
+            embedding = openai.Embedding.create(
+                input=queries, model=self.model
+            )
+        else:
+            embedding = openai.embeddings.create(
+                input=queries, model=self.model
+            ).model_dump()
+        return [embedding["embedding"] for embedding in embedding["data"]]
+
+class CohereEmbed(CloudEmbedProvider):
+    def __init__(self, model: str = "multilingual-22-12", api_key: Optional[str] = None):
+        try:
+            import cohere
+        except ImportError:
+            raise ImportError(
+                "The cohere library is required to use CohereEmbed. Install it with `pip install cohere`"
+            )
+        super().__init__(model, api_key)
+        self.client = cohere.Client(api_key)
+
+    @backoff.on_exception(
+        backoff.expo,
+        ERRORS,
+        max_time=15,
+    )
+    def get_embeddings(self, queries: List[str]) -> List[List[float]]:
+        embeddings = self.client.embed(texts=queries, model=self.model).embeddings
+        return embeddings
+
+
+
 class PineconeRM(dspy.Retrieve):
     """
     A retrieval module that uses Pinecone to return the top passages for a given query.
@@ -43,11 +97,8 @@ class PineconeRM(dspy.Retrieve):
     Args:
         pinecone_index_name (str): The name of the Pinecone index to query against.
         pinecone_api_key (str, optional): The Pinecone API key. Defaults to None.
-        pinecone_env (str, optional): The Pinecone environment. Defaults to None.
         local_embed_model (str, optional): The local embedding model to use. A popular default is "sentence-transformers/all-mpnet-base-v2".
-        openai_embed_model (str, optional): The OpenAI embedding model to use. Defaults to "text-embedding-ada-002".
-        openai_api_key (str, optional): The API key for OpenAI. Defaults to None.
-        openai_org (str, optional): The organization for OpenAI. Defaults to None.
+        cloud_emded_provider (CloudEmbedProvider, optional): The cloud embedding provider to use. Defaults to None.
         k (int, optional): The number of top passages to retrieve. Defaults to 3.

     Returns:
@@ -57,6 +108,7 @@ class PineconeRM(dspy.Retrieve):
     Below is a code snippet that shows how to use this as the default retriver:
     ```python
     llm = dspy.OpenAI(model="gpt-3.5-turbo")
+    retriever_model = PineconeRM(index_name, cloud_emded_provider=OpenAIEmbed())
     retriever_model = PineconeRM(openai.api_key)
     dspy.settings.configure(lm=llm, rm=retriever_model)
     ```
@@ -71,11 +123,8 @@ def __init__(
         self,
         pinecone_index_name: str,
         pinecone_api_key: Optional[str] = None,
-        pinecone_env: Optional[str] = None,
         local_embed_model: Optional[str] = None,
-        openai_embed_model: Optional[str] = "text-embedding-ada-002",
-        openai_api_key: Optional[str] = None,
-        openai_org: Optional[str] = None,
+        cloud_emded_provider: Optional[CloudEmbedProvider] = None,
         k: int = 3,
     ):
         if local_embed_model is not None:
@@ -84,7 +133,7 @@ def __init__(
                 from transformers import AutoModel, AutoTokenizer
             except ImportError as exc:
                 raise ModuleNotFoundError(
-                    "You need to install Hugging Face transformers library to use a local embedding model with PineconeRM.",
+                    "You need to install Hugging Face transformers library to use a local embedding model with PineconeRM."
                 ) from exc

             self._local_embed_model = AutoModel.from_pretrained(local_embed_model)
@@ -93,96 +142,48 @@ def __init__(
             self.device = torch.device(
                 'cuda:0' if torch.cuda.is_available() else
                 'mps' if torch.backends.mps.is_available()
-                else 'cpu',
+                else 'cpu'
             )
-        elif openai_embed_model is not None:
-            self._openai_embed_model = openai_embed_model
+
+        elif cloud_emded_provider is not None:
             self.use_local_model = False
-            # If not provided, defaults to env vars OPENAI_API_KEY and OPENAI_ORGANIZATION
-            if openai_api_key:
-                openai.api_key = openai_api_key
-            if openai_org:
-                openai.organization = openai_org
+            self.cloud_emded_provider = cloud_emded_provider
+
         else:
             raise ValueError(
-                "Either local_embed_model or openai_embed_model must be provided.",
+                "Either local_embed_model or cloud_embed_provider must be provided."
             )

-        self._pinecone_index = self._init_pinecone(
-            pinecone_index_name, pinecone_api_key, pinecone_env,
-        )
-
-        super().__init__(k=k)
-
-    def _init_pinecone(
-        self,
-        index_name: str,
-        api_key: Optional[str] = None,
-        environment: Optional[str] = None,
-        dimension: Optional[int] = None,
-        distance_metric: Optional[str] = None,
-    ) -> pinecone.Index:
-        """Initialize pinecone and return the loaded index.
-
-        Args:
-            index_name (str): The name of the index to load. If the index is not does not exist, it will be created.
-            api_key (str, optional): The Pinecone API key, defaults to env var PINECONE_API_KEY if not provided.
-            environment (str, optional): The environment (ie. `us-west1-gcp` or `gcp-starter`. Defaults to env PINECONE_ENVIRONMENT.
-
-        Raises:
-            ValueError: If api_key or environment is not provided and not set as an environment variable.
-
-        Returns:
-            pinecone.Index: The loaded index.
-        """
+        if pinecone_api_key is None:
+            self.pinecone_client = pinecone.Pinecone()
+        else:
+            self.pinecone_client = pinecone.Pinecone(api_key=pinecone_api_key)

-        # Pinecone init overrides default if kwargs are present, so we need to exclude if None
-        kwargs = {}
-        if api_key:
-            kwargs["api_key"] = api_key
-        if environment:
-            kwargs["environment"] = environment
-        pinecone.init(**kwargs)
-
-        active_indexes = pinecone.list_indexes()
-        if index_name not in active_indexes:
-            if dimension is None and distance_metric is None:
-                raise ValueError(
-                    "dimension and distance_metric must be provided since the index provided does not exist.",
-                )
+        self._pinecone_index = self.pinecone_client.Index(pinecone_index_name)

-            pinecone.create_index(
-                name=index_name,
-                dimension=dimension,
-                metric=distance_metric,
-            )
+        super().__init__(k=k)

-        return pinecone.Index(index_name)

     def _mean_pooling(
         self,
         model_output,
-        attention_mask,
+        attention_mask
     ):
         try:
             import torch
         except ImportError as exc:
             raise ModuleNotFoundError(
-                "You need to install torch to use a local embedding model with PineconeRM.",
+                "You need to install torch to use a local embedding model with PineconeRM."
             ) from exc

         token_embeddings = model_output[0] # First element of model_output contains all token embeddings
         input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
         return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)

-    @backoff.on_exception(
-        backoff.expo,
-        ERRORS,
-        max_time=15,
-    )
+
     def _get_embeddings(
         self,
-        queries: List[str],
+        queries: List[str]
     ) -> List[List[float]]:
         """Return query vector after creating embedding using OpenAI

@@ -192,24 +193,16 @@ def _get_embeddings(
         Returns:
             List[List[float]]: List of embeddings corresponding to each query.
         """
+        if not self.use_local_model:
+            return self.cloud_emded_provider.get_embeddings(queries)
+
         try:
             import torch
         except ImportError as exc:
             raise ModuleNotFoundError(
-                "You need to install torch to use a local embedding model with PineconeRM.",
+                "You need to install torch to use a local embedding model with PineconeRM."
             ) from exc

-        if not self.use_local_model:
-            if OPENAI_LEGACY:
-                embedding = openai.Embedding.create(
-                    input=queries, model=self._openai_embed_model,
-                )
-            else:
-                embedding = openai.embeddings.create(
-                    input=queries, model=self._openai_embed_model,
-                ).model_dump()
-            return [embedding["embedding"] for embedding in embedding["data"]]
-
         # Use local model
         encoded_input = self._local_tokenizer(queries, padding=True, truncation=True, return_tensors="pt").to(self.device)
         with torch.no_grad():
@@ -222,51 +215,55 @@ def _get_embeddings(
         # we need a pooling strategy to get a single vector representation of the input
         # so the default is to take the mean of the hidden states

-    def forward(self, query_or_queries: Union[str, List[str]]) -> dspy.Prediction:
+    def forward(self, query_or_queries: Union[str, List[str]], k: Optional[int] = None) -> dspy.Prediction:
         """Search with pinecone for self.k top passages for query

         Args:
             query_or_queries (Union[str, List[str]]): The query or queries to search for.
+            k (Optional[int]): The number of top passages to retrieve. Defaults to self.k

         Returns:
             dspy.Prediction: An object containing the retrieved passages.
         """
+        k = k if k is not None else self.k
         queries = (
             [query_or_queries]
             if isinstance(query_or_queries, str)
             else query_or_queries
         )
         queries = [q for q in queries if q] # Filter empty queries
         embeddings = self._get_embeddings(queries)
-
         # For single query, just look up the top k passages
         if len(queries) == 1:
             results_dict = self._pinecone_index.query(
-                embeddings[0], top_k=self.k, include_metadata=True,
+                vector=embeddings[0], top_k=k, include_metadata=True
             )

             # Sort results by score
             sorted_results = sorted(
-                results_dict["matches"], key=lambda x: x.get("scores", 0.0), reverse=True,
+                results_dict["matches"], key=lambda x: x.get("scores", 0.0), reverse=True
             )
+
             passages = [result["metadata"]["text"] for result in sorted_results]
-            passages = [dotdict({"long_text": passage for passage in passages})]
-            return dspy.Prediction(passages=passages)
+            passages = [dotdict({"long_text": passage}) for passage in passages]
+            return passages

         # For multiple queries, query each and return the highest scoring passages
         # If a passage is returned multiple times, the score is accumulated. For this reason we increase top_k by 3x
         passage_scores = {}
         for embedding in embeddings:
             results_dict = self._pinecone_index.query(
-                embedding, top_k=self.k * 3, include_metadata=True,
+                vector=embedding, top_k=k * 3, include_metadata=True
             )
             for result in results_dict["matches"]:
                 passage_scores[result["metadata"]["text"]] = (
                     passage_scores.get(result["metadata"]["text"], 0.0)
                     + result["score"]
                 )
-
+
         sorted_passages = sorted(
-            passage_scores.items(), key=lambda x: x[1], reverse=True,
-        )[: self.k]
-        return dspy.Prediction(passages=[dotdict({"long_text": passage}) for passage, _ in sorted_passages])
+            passage_scores.items(), key=lambda x: x[1], reverse=True
+        )[: k]
+
+        passages=[dotdict({"long_text": passage}) for passage, _ in sorted_passages]
+        return passages
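
Because `CloudEmbedProvider` is an abstract base class with a single `get_embeddings` method, other embedding backends can be plugged in by subclassing it. Below is a minimal sketch of such a provider; it is not part of this commit and assumes the `sentence-transformers` package and the "all-MiniLM-L6-v2" model name.

```python
from typing import List, Optional

from dspy.retrieve.pinecone_rm import CloudEmbedProvider, PineconeRM

# Hypothetical provider, not part of the commit; it only reuses the
# CloudEmbedProvider base class defined in dspy/retrieve/pinecone_rm.py above.
class SentenceTransformersEmbed(CloudEmbedProvider):
    def __init__(self, model: str = "all-MiniLM-L6-v2", api_key: Optional[str] = None):
        try:
            from sentence_transformers import SentenceTransformer
        except ImportError:
            raise ImportError(
                "The sentence-transformers library is required to use SentenceTransformersEmbed. "
                "Install it with `pip install sentence-transformers`"
            )
        super().__init__(model, api_key)  # api_key is unused here; kept to match the base signature
        self._encoder = SentenceTransformer(model)

    def get_embeddings(self, queries: List[str]) -> List[List[float]]:
        # encode() returns a numpy array; PineconeRM expects plain lists of floats
        return self._encoder.encode(queries).tolist()

# Usage mirrors the built-in providers (index name is hypothetical):
# retriever_model = PineconeRM("my-index", cloud_emded_provider=SentenceTransformersEmbed())
```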
