Author: Dhar Rawal (@drawal1)
"""

-from abc import ABC, abstractmethod
from typing import List, Optional, Union

import backoff

try:
    import pinecone
except ImportError:
+    pinecone = None
+
+if pinecone is None:
    raise ImportError(
        "The pinecone library is required to use PineconeRM. Install it with `pip install dspy-ai[pinecone]`",
    )

except Exception:
    ERRORS = (openai.RateLimitError, openai.APIError)

-
-class CloudEmbedProvider(ABC):
-    def __init__(self, model, api_key=None):
-        self.model = model
-        self.api_key = api_key
-
-    @abstractmethod
-    def get_embeddings(self, queries: List[str]) -> List[List[float]]:
-        pass
-
-class OpenAIEmbed(CloudEmbedProvider):
-    def __init__(self, model="text-embedding-ada-002", api_key: Optional[str] = None, org: Optional[str] = None):
-        super().__init__(model, api_key)
-        self.org = org
-        if self.api_key:
-            openai.api_key = self.api_key
-        if self.org:
-            openai.organization = org
-
-
-    @backoff.on_exception(
-        backoff.expo,
-        ERRORS,
-        max_time=15,
-    )
-    def get_embeddings(self, queries: List[str]) -> List[List[float]]:
-        if OPENAI_LEGACY:
-            embedding = openai.Embedding.create(
-                input=queries, model=self.model,
-            )
-        else:
-            embedding = openai.embeddings.create(
-                input=queries, model=self.model,
-            ).model_dump()
-        return [embedding["embedding"] for embedding in embedding["data"]]
-
-class CohereEmbed(CloudEmbedProvider):
-    def __init__(self, model: str = "multilingual-22-12", api_key: Optional[str] = None):
-        try:
-            import cohere
-        except ImportError:
-            raise ImportError(
-                "The cohere library is required to use CohereEmbed. Install it with `pip install cohere`",
-            )
-        super().__init__(model, api_key)
-        self.client = cohere.Client(api_key)
-
-    @backoff.on_exception(
-        backoff.expo,
-        ERRORS,
-        max_time=15,
-    )
-    def get_embeddings(self, queries: List[str]) -> List[List[float]]:
-        embeddings = self.client.embed(texts=queries, model=self.model).embeddings
-        return embeddings
-
-
-
class PineconeRM(dspy.Retrieve):
    """
    A retrieval module that uses Pinecone to return the top passages for a given query.
@@ -99,8 +43,11 @@ class PineconeRM(dspy.Retrieve):
    Args:
        pinecone_index_name (str): The name of the Pinecone index to query against.
        pinecone_api_key (str, optional): The Pinecone API key. Defaults to None.
+        pinecone_env (str, optional): The Pinecone environment. Defaults to None.
        local_embed_model (str, optional): The local embedding model to use. A popular default is "sentence-transformers/all-mpnet-base-v2".
-        cloud_emded_provider (CloudEmbedProvider, optional): The cloud embedding provider to use. Defaults to None.
+        openai_embed_model (str, optional): The OpenAI embedding model to use. Defaults to "text-embedding-ada-002".
+        openai_api_key (str, optional): The API key for OpenAI. Defaults to None.
+        openai_org (str, optional): The organization for OpenAI. Defaults to None.
        k (int, optional): The number of top passages to retrieve. Defaults to 3.

    Returns:
@@ -110,7 +57,6 @@ class PineconeRM(dspy.Retrieve):
    Below is a code snippet that shows how to use this as the default retriever:
    ```python
    llm = dspy.OpenAI(model="gpt-3.5-turbo")
-    retriever_model = PineconeRM(index_name, cloud_emded_provider=OpenAIEmbed())
    retriever_model = PineconeRM(openai.api_key)
    dspy.settings.configure(lm=llm, rm=retriever_model)
    ```
@@ -125,8 +71,11 @@ def __init__(
        self,
        pinecone_index_name: str,
        pinecone_api_key: Optional[str] = None,
+        pinecone_env: Optional[str] = None,
        local_embed_model: Optional[str] = None,
-        cloud_emded_provider: Optional[CloudEmbedProvider] = None,
+        openai_embed_model: Optional[str] = "text-embedding-ada-002",
+        openai_api_key: Optional[str] = None,
+        openai_org: Optional[str] = None,
        k: int = 3,
    ):
        if local_embed_model is not None:
@@ -146,25 +95,69 @@ def __init__(
                'mps' if torch.backends.mps.is_available()
                else 'cpu',
            )
-
-        elif cloud_emded_provider is not None:
+        elif openai_embed_model is not None:
+            self._openai_embed_model = openai_embed_model
            self.use_local_model = False
-            self.cloud_emded_provider = cloud_emded_provider
-
+            # If not provided, defaults to env vars OPENAI_API_KEY and OPENAI_ORGANIZATION
+            if openai_api_key:
+                openai.api_key = openai_api_key
+            if openai_org:
+                openai.organization = openai_org
        else:
            raise ValueError(
-                "Either local_embed_model or cloud_embed_provider must be provided.",
+                "Either local_embed_model or openai_embed_model must be provided.",
            )

-        if pinecone_api_key is None:
-            self.pinecone_client = pinecone.Pinecone()
-        else:
-            self.pinecone_client = pinecone.Pinecone(api_key=pinecone_api_key)
-
-        self._pinecone_index = self.pinecone_client.Index(pinecone_index_name)
+        self._pinecone_index = self._init_pinecone(
+            pinecone_index_name, pinecone_api_key, pinecone_env,
+        )

        super().__init__(k=k)

+    def _init_pinecone(
+        self,
+        index_name: str,
+        api_key: Optional[str] = None,
+        environment: Optional[str] = None,
+        dimension: Optional[int] = None,
+        distance_metric: Optional[str] = None,
+    ) -> pinecone.Index:
+        """Initialize pinecone and return the loaded index.
+
+        Args:
+            index_name (str): The name of the index to load. If the index does not exist, it will be created.
+            api_key (str, optional): The Pinecone API key, defaults to env var PINECONE_API_KEY if not provided.
+            environment (str, optional): The environment (e.g. `us-west1-gcp` or `gcp-starter`). Defaults to env var PINECONE_ENVIRONMENT.
+
+        Raises:
+            ValueError: If api_key or environment is not provided and not set as an environment variable.
+
+        Returns:
+            pinecone.Index: The loaded index.
+        """
+
+        # Pinecone init overrides defaults if kwargs are present, so we need to exclude None values
+        kwargs = {}
+        if api_key:
+            kwargs["api_key"] = api_key
+        if environment:
+            kwargs["environment"] = environment
+        pinecone.init(**kwargs)
+
+        active_indexes = pinecone.list_indexes()
+        if index_name not in active_indexes:
+            if dimension is None or distance_metric is None:
+                raise ValueError(
+                    "dimension and distance_metric must be provided since the index provided does not exist.",
+                )
+
+            pinecone.create_index(
+                name=index_name,
+                dimension=dimension,
+                metric=distance_metric,
+            )
+
+        return pinecone.Index(index_name)

    def _mean_pooling(
        self,
@@ -182,7 +175,11 @@ def _mean_pooling(
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)

-
+    @backoff.on_exception(
+        backoff.expo,
+        ERRORS,
+        max_time=15,
+    )
    def _get_embeddings(
        self,
        queries: List[str],
@@ -195,16 +192,24 @@ def _get_embeddings(
        Returns:
            List[List[float]]: List of embeddings corresponding to each query.
        """
-        if not self.use_local_model:
-            return self.cloud_emded_provider.get_embeddings(queries)
-
        try:
            import torch
        except ImportError as exc:
            raise ModuleNotFoundError(
                "You need to install torch to use a local embedding model with PineconeRM.",
            ) from exc

+        if not self.use_local_model:
+            if OPENAI_LEGACY:
+                embedding = openai.Embedding.create(
+                    input=queries, model=self._openai_embed_model,
+                )
+            else:
+                embedding = openai.embeddings.create(
+                    input=queries, model=self._openai_embed_model,
+                ).model_dump()
+            return [embedding["embedding"] for embedding in embedding["data"]]
+
        # Use local model
        encoded_input = self._local_tokenizer(queries, padding=True, truncation=True, return_tensors="pt").to(self.device)
        with torch.no_grad():
@@ -217,55 +222,51 @@ def _get_embeddings(
        # we need a pooling strategy to get a single vector representation of the input
        # so the default is to take the mean of the hidden states

-    def forward(self, query_or_queries: Union[str, List[str]], k: Optional[int] = None) -> dspy.Prediction:
+    def forward(self, query_or_queries: Union[str, List[str]]) -> dspy.Prediction:
        """Search with pinecone for self.k top passages for query

        Args:
            query_or_queries (Union[str, List[str]]): The query or queries to search for.
-            k (Optional[int]): The number of top passages to retrieve. Defaults to self.k

        Returns:
            dspy.Prediction: An object containing the retrieved passages.
        """
-        k = k if k is not None else self.k
        queries = (
            [query_or_queries]
            if isinstance(query_or_queries, str)
            else query_or_queries
        )
        queries = [q for q in queries if q]  # Filter empty queries
        embeddings = self._get_embeddings(queries)
+
        # For single query, just look up the top k passages
        if len(queries) == 1:
            results_dict = self._pinecone_index.query(
-                vector=embeddings[0], top_k=k, include_metadata=True,
+                embeddings[0], top_k=self.k, include_metadata=True,
            )

            # Sort results by score
            sorted_results = sorted(
                results_dict["matches"], key=lambda x: x.get("score", 0.0), reverse=True,
            )
-
            passages = [result["metadata"]["text"] for result in sorted_results]
-            passages = [dotdict({"long_text": passage}) for passage in passages]
-            return passages
+            passages = [dotdict({"long_text": passage}) for passage in passages]
+            return dspy.Prediction(passages=passages)

        # For multiple queries, query each and return the highest scoring passages
        # If a passage is returned multiple times, the score is accumulated. For this reason we increase top_k by 3x
        passage_scores = {}
        for embedding in embeddings:
            results_dict = self._pinecone_index.query(
-                vector=embedding, top_k=k * 3, include_metadata=True,
+                embedding, top_k=self.k * 3, include_metadata=True,
            )
            for result in results_dict["matches"]:
                passage_scores[result["metadata"]["text"]] = (
                    passage_scores.get(result["metadata"]["text"], 0.0)
                    + result["score"]
                )
-
+
        sorted_passages = sorted(
            passage_scores.items(), key=lambda x: x[1], reverse=True,
-        )[: k]
-
-        passages = [dotdict({"long_text": passage}) for passage, _ in sorted_passages]
-        return passages
+        )[: self.k]
+        return dspy.Prediction(passages=[dotdict({"long_text": passage}) for passage, _ in sorted_passages])
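
For reference, a minimal usage sketch against the signature this commit restores. This is an assumption-laden example, not part of the commit: the index name and environment are placeholders, and the index is assumed to already exist, since `__init__` does not pass `dimension`/`distance_metric` through to `_init_pinecone`.

```python
import dspy
from dspy.retrieve.pinecone_rm import PineconeRM

# Keys can also be supplied via env vars (PINECONE_API_KEY, PINECONE_ENVIRONMENT, OPENAI_API_KEY).
retriever_model = PineconeRM(
    pinecone_index_name="my-index",  # hypothetical; must already exist
    pinecone_env="gcp-starter",      # example environment from the _init_pinecone docstring
    openai_embed_model="text-embedding-ada-002",
    k=3,
)
dspy.settings.configure(lm=dspy.OpenAI(model="gpt-3.5-turbo"), rm=retriever_model)

# forward() accepts a single query or a list of queries and returns a dspy.Prediction.
top_passages = retriever_model("What is Pinecone?").passages
```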
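The `OPENAI_LEGACY` branch added to `_get_embeddings` mirrors the split between the pre-1.0 and post-1.0 `openai` clients. Here is a standalone sketch of the same idea; the version check shown is an assumption (the library's own detection, elided from this diff, may differ):

```python
import openai

# Assumption: treat any 0.x release as the legacy client.
OPENAI_LEGACY = openai.__version__.startswith("0.")

def embed_queries(queries, model="text-embedding-ada-002"):
    if OPENAI_LEGACY:
        # openai<1.0: module-level Embedding API returning a plain dict
        response = openai.Embedding.create(input=queries, model=model)
    else:
        # openai>=1.0: returns a pydantic model; model_dump() restores dict-style access
        response = openai.embeddings.create(input=queries, model=model).model_dump()
    return [item["embedding"] for item in response["data"]]
```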
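The multi-query path in `forward` over-fetches `3 * k` matches per query and sums scores when the same passage comes back for more than one query. The strategy in isolation, with made-up scores:

```python
# Toy illustration of forward()'s multi-query accumulation: duplicate passages sum up.
k = 2
matches_per_query = [
    [("passage A", 0.9), ("passage B", 0.4)],  # matches for query 1 (hypothetical)
    [("passage A", 0.8), ("passage C", 0.5)],  # matches for query 2 (hypothetical)
]

passage_scores = {}
for matches in matches_per_query:
    for text, score in matches:
        passage_scores[text] = passage_scores.get(text, 0.0) + score

top = sorted(passage_scores.items(), key=lambda x: x[1], reverse=True)[:k]
print(top)  # [('passage A', 1.7), ('passage C', 0.5)]
```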