Author: Dhar Rawal (@drawal1)
"""

from abc import ABC, abstractmethod
from typing import List, Optional, Union

import backoff

import dspy
from dsp.utils import dotdict

try:
    import pinecone
except ImportError:
    raise ImportError(
        "The pinecone library is required to use PineconeRM. Install it with `pip install dspy-ai[pinecone]`",
    )

import openai

try:
    OPENAI_LEGACY = int(openai.version.__version__[0]) == 0
except Exception:
    OPENAI_LEGACY = True

try:
    import openai.error
    ERRORS = (openai.error.RateLimitError, openai.error.ServiceUnavailableError, openai.error.APIError)
except Exception:
    ERRORS = (openai.RateLimitError, openai.APIError)


class CloudEmbedProvider(ABC):
    def __init__(self, model, api_key=None):
        self.model = model
        self.api_key = api_key

    @abstractmethod
    def get_embeddings(self, queries: List[str]) -> List[List[float]]:
        pass

class OpenAIEmbed(CloudEmbedProvider):
    def __init__(self, model="text-embedding-ada-002", api_key: Optional[str] = None, org: Optional[str] = None):
        super().__init__(model, api_key)
        self.org = org
        if self.api_key:
            openai.api_key = self.api_key
        if self.org:
            openai.organization = self.org

    @backoff.on_exception(
        backoff.expo,
        ERRORS,
        max_time=15,
    )
    def get_embeddings(self, queries: List[str]) -> List[List[float]]:
        # OPENAI_LEGACY distinguishes the pre-1.0 `openai.Embedding.create` API from
        # the 1.x `openai.embeddings.create` API; both yield a payload whose "data"
        # entries each carry an "embedding" vector.
        if OPENAI_LEGACY:
            response = openai.Embedding.create(
                input=queries, model=self.model,
            )
        else:
            response = openai.embeddings.create(
                input=queries, model=self.model,
            ).model_dump()
        return [embedding["embedding"] for embedding in response["data"]]


class CohereEmbed(CloudEmbedProvider):
    def __init__(self, model: str = "multilingual-22-12", api_key: Optional[str] = None):
        try:
            import cohere
        except ImportError:
            raise ImportError(
                "The cohere library is required to use CohereEmbed. Install it with `pip install cohere`",
            )
        super().__init__(model, api_key)
        self.client = cohere.Client(api_key)

    @backoff.on_exception(
        backoff.expo,
        ERRORS,
        max_time=15,
    )
    def get_embeddings(self, queries: List[str]) -> List[List[float]]:
        return self.client.embed(texts=queries, model=self.model).embeddings
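

# Illustrative sketch (not part of the original module): any embedding service can be
# plugged into PineconeRM by subclassing CloudEmbedProvider. The class and callable
# below are hypothetical, shown only to demonstrate the interface.
class CallableEmbed(CloudEmbedProvider):
    def __init__(self, embed_fn, model: str = "custom"):
        super().__init__(model)
        self.embed_fn = embed_fn  # any callable mapping List[str] -> List[List[float]]

    def get_embeddings(self, queries: List[str]) -> List[List[float]]:
        return self.embed_fn(queries)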


class PineconeRM(dspy.Retrieve):
    """
    A retrieval module that uses Pinecone to return the top passages for a given query.

    Assumes that the Pinecone index has been created and populated with the following metadata:
        - text: The text of the passage

    Args:
        pinecone_index_name (str): The name of the Pinecone index to query against.
        pinecone_api_key (str, optional): The Pinecone API key. Defaults to None, in which
            case the Pinecone client reads the PINECONE_API_KEY environment variable.
        local_embed_model (str, optional): The local embedding model to use. A popular default is "sentence-transformers/all-mpnet-base-v2".
        cloud_embed_provider (CloudEmbedProvider, optional): The cloud embedding provider to use, e.g. OpenAIEmbed() or CohereEmbed(). Defaults to None.
        k (int, optional): The number of top passages to retrieve. Defaults to 3.

    Returns:
        list[dotdict]: A list of the retrieved passages, each with a `long_text` field.

    Examples:
        Below is a code snippet that shows how to use this as the default retriever:
        ```python
        llm = dspy.OpenAI(model="gpt-3.5-turbo")
        retriever_model = PineconeRM(index_name, cloud_embed_provider=OpenAIEmbed())
        dspy.settings.configure(lm=llm, rm=retriever_model)
        ```
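
        A different cloud provider can be plugged in the same way (assuming the
        corresponding client library, here `cohere`, is installed):
        ```python
        retriever_model = PineconeRM(index_name, cloud_embed_provider=CohereEmbed())
        ```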
    """

    def __init__(
        self,
        pinecone_index_name: str,
        pinecone_api_key: Optional[str] = None,
        local_embed_model: Optional[str] = None,
        cloud_embed_provider: Optional[CloudEmbedProvider] = None,
        k: int = 3,
    ):
        if local_embed_model is not None:
            try:
                import torch
                from transformers import AutoModel, AutoTokenizer
            except ImportError as exc:
                raise ModuleNotFoundError(
                    "You need to install torch and transformers to use a local embedding model with PineconeRM.",
                ) from exc

            self._local_embed_model = AutoModel.from_pretrained(local_embed_model)
            self._local_tokenizer = AutoTokenizer.from_pretrained(local_embed_model)
            self.use_local_model = True
            self.device = torch.device(
                'cuda:0' if torch.cuda.is_available() else
                'mps' if torch.backends.mps.is_available()
                else 'cpu',
            )

        elif cloud_embed_provider is not None:
            self.use_local_model = False
            self.cloud_embed_provider = cloud_embed_provider

        else:
            raise ValueError(
                "Either local_embed_model or cloud_embed_provider must be provided.",
            )

        # If no Pinecone API key is given, the client falls back to the
        # PINECONE_API_KEY environment variable.
        if pinecone_api_key is None:
            self.pinecone_client = pinecone.Pinecone()
        else:
            self.pinecone_client = pinecone.Pinecone(api_key=pinecone_api_key)

        self._pinecone_index = self.pinecone_client.Index(pinecone_index_name)

        super().__init__(k=k)

    def _mean_pooling(
        self,
        model_output,
        attention_mask,
    ):
        try:
            import torch
        except ImportError as exc:
            raise ModuleNotFoundError(
                "You need to install torch to use a local embedding model with PineconeRM.",
            ) from exc

        token_embeddings = model_output[0]  # First element of model_output contains all token embeddings
        # Average the token embeddings, weighted by the attention mask, so that
        # padding tokens do not contribute to the sentence vector.
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
    def _get_embeddings(
        self,
        queries: List[str],
    ) -> List[List[float]]:
        """Return query vectors from the configured cloud provider or the local model.

        Args:
            queries (List[str]): List of query strings to embed.

        Returns:
            List[List[float]]: List of embeddings corresponding to each query.
        """
        if not self.use_local_model:
            return self.cloud_embed_provider.get_embeddings(queries)

        try:
            import torch
        except ImportError as exc:
            raise ModuleNotFoundError(
                "You need to install torch to use a local embedding model with PineconeRM.",
            ) from exc

        # Use local model
        encoded_input = self._local_tokenizer(queries, padding=True, truncation=True, return_tensors="pt").to(self.device)
        with torch.no_grad():
            model_output = self._local_embed_model(**encoded_input)

        # We need a pooling strategy to get a single vector representation of the
        # input, so the default is to take the mean of the hidden states.
        embeddings = self._mean_pooling(model_output, encoded_input['attention_mask'])
        normalized_embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)

        return normalized_embeddings.cpu().numpy().tolist()

    def forward(self, query_or_queries: Union[str, List[str]], k: Optional[int] = None) -> List[dotdict]:
        """Search with pinecone for the top k passages for the query or queries.

        Args:
            query_or_queries (Union[str, List[str]]): The query or queries to search for.
            k (Optional[int]): The number of top passages to retrieve. Defaults to self.k.

        Returns:
            list[dotdict]: The retrieved passages, each with a `long_text` field.
        """
        k = k if k is not None else self.k
        queries = (
            [query_or_queries]
            if isinstance(query_or_queries, str)
            else query_or_queries
        )
        queries = [q for q in queries if q]  # Filter empty queries
        embeddings = self._get_embeddings(queries)

        # For a single query, just look up the top k passages
        if len(queries) == 1:
            results_dict = self._pinecone_index.query(
                vector=embeddings[0], top_k=k, include_metadata=True,
            )

            # Sort results by score
            sorted_results = sorted(
                results_dict["matches"], key=lambda x: x.get("score", 0.0), reverse=True,
            )

            passages = [result["metadata"]["text"] for result in sorted_results]
            passages = [dotdict({"long_text": passage}) for passage in passages]
            return passages

        # For multiple queries, query each one and return the highest scoring passages.
        # If a passage is returned by several queries, its scores are accumulated; we
        # over-fetch (top_k = k * 3) so recurring passages have a chance to surface.
        passage_scores = {}
        for embedding in embeddings:
            results_dict = self._pinecone_index.query(
                vector=embedding, top_k=k * 3, include_metadata=True,
            )
            for result in results_dict["matches"]:
                passage_scores[result["metadata"]["text"]] = (
                    passage_scores.get(result["metadata"]["text"], 0.0)
                    + result["score"]
                )

        sorted_passages = sorted(
            passage_scores.items(), key=lambda x: x[1], reverse=True,
        )[:k]

        passages = [dotdict({"long_text": passage}) for passage, _ in sorted_passages]
        return passages
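

# Minimal usage sketch (assumptions: an existing, populated index named "my-index",
# and PINECONE_API_KEY / OPENAI_API_KEY set in the environment):
#
#     rm = PineconeRM("my-index", cloud_embed_provider=OpenAIEmbed())
#     for passage in rm("What is retrieval-augmented generation?", k=5):
#         print(passage.long_text)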