 Author: Dhar Rawal (@drawal1)
 """

-from typing import List, Optional, Union
-
-import backoff
-
-import dspy
+import os
 from dsp.utils import dotdict
+from typing import Optional, List, Union, Any
+import dspy
+import backoff
+from abc import ABC, abstractmethod

 try:
     import pinecone
 except ImportError:
-    pinecone = None
-
-if pinecone is None:
     raise ImportError(
-        "The pinecone library is required to use PineconeRM. Install it with `pip install dspy-ai[pinecone]`",
+        "The pinecone library is required to use PineconeRM. Install it with `pip install dspy-ai[pinecone]`"
     )

 import openai
-
 try:
     OPENAI_LEGACY = int(openai.version.__version__[0]) == 0
 except Exception:
     OPENAI_LEGACY = True

 try:
     import openai.error
     ERRORS = (openai.error.RateLimitError, openai.error.APIError)
 except Exception:
     ERRORS = (openai.RateLimitError, openai.APIError)

+
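For readers unfamiliar with `backoff`: the `@backoff.on_exception(backoff.expo, ERRORS, max_time=15)` decorator used on the providers below retries a call with exponentially growing sleeps whenever it raises one of `ERRORS`, giving up once roughly 15 seconds have elapsed. A minimal sketch of the same policy on a hypothetical function:

```python
import backoff

@backoff.on_exception(backoff.expo, ERRORS, max_time=15)
def fetch_embeddings_once():
    # Hypothetical flaky call: a raised rate-limit error triggers a retry
    # with exponential backoff until ~15s total have passed.
    ...
```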
+class CloudEmbedProvider(ABC):
+    """Abstract base class for hosted embedding providers used by PineconeRM."""
+
+    def __init__(self, model, api_key=None):
+        self.model = model
+        self.api_key = api_key
+
+    @abstractmethod
+    def get_embeddings(self, queries: List[str]) -> List[List[float]]:
+        pass
+
+class OpenAIEmbed(CloudEmbedProvider):
+    def __init__(self, model="text-embedding-ada-002", api_key: Optional[str] = None, org: Optional[str] = None):
+        super().__init__(model, api_key)
+        self.org = org
+        # If not provided, the key and organization default to the
+        # OPENAI_API_KEY and OPENAI_ORGANIZATION environment variables.
+        if self.api_key:
+            openai.api_key = self.api_key
+        if self.org:
+            openai.organization = self.org
+
+    @backoff.on_exception(
+        backoff.expo,
+        ERRORS,
+        max_time=15,
+    )
+    def get_embeddings(self, queries: List[str]) -> List[List[float]]:
+        if OPENAI_LEGACY:
+            response = openai.Embedding.create(
+                input=queries, model=self.model
+            )
+        else:
+            response = openai.embeddings.create(
+                input=queries, model=self.model
+            ).model_dump()
+        return [record["embedding"] for record in response["data"]]
+
+class CohereEmbed(CloudEmbedProvider):
+    def __init__(self, model: str = "multilingual-22-12", api_key: Optional[str] = None):
+        try:
+            import cohere
+        except ImportError:
+            raise ImportError(
+                "The cohere library is required to use CohereEmbed. Install it with `pip install cohere`"
+            )
+        super().__init__(model, api_key)
+        self.client = cohere.Client(api_key)
+
+    @backoff.on_exception(
+        backoff.expo,
+        ERRORS,
+        max_time=15,
+    )
+    def get_embeddings(self, queries: List[str]) -> List[List[float]]:
+        return self.client.embed(texts=queries, model=self.model).embeddings
+
+
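Any embedding backend can be plugged in by subclassing `CloudEmbedProvider`. A minimal sketch of the contract (the class and values here are hypothetical, for illustration only):

```python
class FakeEmbed(CloudEmbedProvider):
    """Hypothetical provider returning fixed-size zero vectors; handy in tests."""

    def __init__(self, dim: int = 1536):
        super().__init__(model="fake")
        self.dim = dim

    def get_embeddings(self, queries: List[str]) -> List[List[float]]:
        # One constant vector per query, matching the index dimension.
        return [[0.0] * self.dim for _ in queries]
```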
 class PineconeRM(dspy.Retrieve):
     """
     A retrieval module that uses Pinecone to return the top passages for a given query.
@@ -43,11 +97,8 @@ class PineconeRM(dspy.Retrieve):
     Args:
         pinecone_index_name (str): The name of the Pinecone index to query against.
         pinecone_api_key (str, optional): The Pinecone API key. Defaults to None.
-        pinecone_env (str, optional): The Pinecone environment. Defaults to None.
         local_embed_model (str, optional): The local embedding model to use. A popular default is "sentence-transformers/all-mpnet-base-v2".
-        openai_embed_model (str, optional): The OpenAI embedding model to use. Defaults to "text-embedding-ada-002".
-        openai_api_key (str, optional): The API key for OpenAI. Defaults to None.
-        openai_org (str, optional): The organization for OpenAI. Defaults to None.
+        cloud_embed_provider (CloudEmbedProvider, optional): The cloud embedding provider to use. Defaults to None.
         k (int, optional): The number of top passages to retrieve. Defaults to 3.

     Returns:
@@ -57,6 +108,7 @@ class PineconeRM(dspy.Retrieve):
         Below is a code snippet that shows how to use this as the default retriever:
         ```python
         llm = dspy.OpenAI(model="gpt-3.5-turbo")
-        retriever_model = PineconeRM(openai.api_key)
+        retriever_model = PineconeRM(index_name, cloud_embed_provider=OpenAIEmbed())
         dspy.settings.configure(lm=llm, rm=retriever_model)
         ```
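The same constructor works with the Cohere provider or with a purely local model; a short sketch (the index name is a placeholder, and `PINECONE_API_KEY` is assumed to be set in the environment):

```python
# Hosted Cohere embeddings (assumes a valid Cohere API key):
retriever_model = PineconeRM("my-index", cloud_embed_provider=CohereEmbed(api_key="..."))

# Or embed locally, with no cloud provider at all:
retriever_model = PineconeRM("my-index", local_embed_model="sentence-transformers/all-mpnet-base-v2")
```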
@@ -71,11 +123,8 @@ def __init__(
         self,
         pinecone_index_name: str,
         pinecone_api_key: Optional[str] = None,
-        pinecone_env: Optional[str] = None,
         local_embed_model: Optional[str] = None,
-        openai_embed_model: Optional[str] = "text-embedding-ada-002",
-        openai_api_key: Optional[str] = None,
-        openai_org: Optional[str] = None,
+        cloud_embed_provider: Optional[CloudEmbedProvider] = None,
         k: int = 3,
     ):
         if local_embed_model is not None:
@@ -84,7 +133,7 @@ def __init__(
                 from transformers import AutoModel, AutoTokenizer
             except ImportError as exc:
                 raise ModuleNotFoundError(
-                    "You need to install Hugging Face transformers library to use a local embedding model with PineconeRM.",
+                    "You need to install Hugging Face transformers library to use a local embedding model with PineconeRM."
                 ) from exc

             self._local_embed_model = AutoModel.from_pretrained(local_embed_model)
@@ -93,96 +142,48 @@ def __init__(
             self.device = torch.device(
                 'cuda:0' if torch.cuda.is_available() else
                 'mps' if torch.backends.mps.is_available()
-                else 'cpu',
+                else 'cpu'
             )
-        elif openai_embed_model is not None:
-            self._openai_embed_model = openai_embed_model
+
+        elif cloud_embed_provider is not None:
             self.use_local_model = False
-            # If not provided, defaults to env vars OPENAI_API_KEY and OPENAI_ORGANIZATION
-            if openai_api_key:
-                openai.api_key = openai_api_key
-            if openai_org:
-                openai.organization = openai_org
+            self.cloud_embed_provider = cloud_embed_provider
+
         else:
             raise ValueError(
-                "Either local_embed_model or openai_embed_model must be provided.",
+                "Either local_embed_model or cloud_embed_provider must be provided."
             )

-        self._pinecone_index = self._init_pinecone(
-            pinecone_index_name, pinecone_api_key, pinecone_env,
-        )
-
-        super().__init__(k=k)
-
-    def _init_pinecone(
-        self,
-        index_name: str,
-        api_key: Optional[str] = None,
-        environment: Optional[str] = None,
-        dimension: Optional[int] = None,
-        distance_metric: Optional[str] = None,
-    ) -> pinecone.Index:
-        """Initialize pinecone and return the loaded index.
-
-        Args:
-            index_name (str): The name of the index to load. If the index does not exist, it will be created.
-            api_key (str, optional): The Pinecone API key, defaults to env var PINECONE_API_KEY if not provided.
-            environment (str, optional): The environment (i.e. `us-west1-gcp` or `gcp-starter`). Defaults to env PINECONE_ENVIRONMENT.
-
-        Raises:
-            ValueError: If api_key or environment is not provided and not set as an environment variable.
-
-        Returns:
-            pinecone.Index: The loaded index.
-        """
+        if pinecone_api_key is None:
+            self.pinecone_client = pinecone.Pinecone()
+        else:
+            self.pinecone_client = pinecone.Pinecone(api_key=pinecone_api_key)

-        # Pinecone init overrides default if kwargs are present, so we need to exclude if None
-        kwargs = {}
-        if api_key:
-            kwargs["api_key"] = api_key
-        if environment:
-            kwargs["environment"] = environment
-        pinecone.init(**kwargs)
-
-        active_indexes = pinecone.list_indexes()
-        if index_name not in active_indexes:
-            if dimension is None and distance_metric is None:
-                raise ValueError(
-                    "dimension and distance_metric must be provided since the index provided does not exist.",
-                )
+        self._pinecone_index = self.pinecone_client.Index(pinecone_index_name)

-            pinecone.create_index(
-                name=index_name,
-                dimension=dimension,
-                metric=distance_metric,
-            )
+        super().__init__(k=k)

-        return pinecone.Index(index_name)

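The rewritten constructor assumes the index already exists (`pinecone.init` and the implicit index creation are gone with the v3 client). If it does not, it can be created out of band; a sketch assuming the serverless flavor of pinecone-client v3+ (name, dimension, cloud, and region are placeholders):

```python
from pinecone import Pinecone, ServerlessSpec

pc = Pinecone(api_key="...")  # or rely on the PINECONE_API_KEY env var
if "my-index" not in pc.list_indexes().names():
    pc.create_index(
        name="my-index",
        dimension=1536,  # must match the embedding model's output size
        metric="cosine",
        spec=ServerlessSpec(cloud="aws", region="us-east-1"),
    )
```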
     def _mean_pooling(
         self,
         model_output,
-        attention_mask,
+        attention_mask
     ):
         try:
             import torch
         except ImportError as exc:
             raise ModuleNotFoundError(
-                "You need to install torch to use a local embedding model with PineconeRM.",
+                "You need to install torch to use a local embedding model with PineconeRM."
             ) from exc

         token_embeddings = model_output[0]  # First element of model_output contains all token embeddings
         input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
         return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
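To make the pooling arithmetic concrete, a self-contained sketch of the same masked mean (shapes and values are illustrative):

```python
import torch

# Two sequences of three tokens with 2-dim embeddings; the second sequence
# has one padding token (mask 0) that must not contribute to the mean.
token_embeddings = torch.tensor([[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],
                                 [[2.0, 2.0], [4.0, 4.0], [0.0, 0.0]]])
attention_mask = torch.tensor([[1, 1, 1], [1, 1, 0]])

mask = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
pooled = (token_embeddings * mask).sum(1) / mask.sum(1).clamp(min=1e-9)
print(pooled)  # tensor([[3., 4.], [3., 3.]]) -- padding excluded from the mean
```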

-    @backoff.on_exception(
-        backoff.expo,
-        ERRORS,
-        max_time=15,
-    )
+
     def _get_embeddings(
         self,
-        queries: List[str],
+        queries: List[str]
     ) -> List[List[float]]:
         """Return query vectors after creating embeddings with the configured cloud provider or local model.

@@ -192,24 +193,16 @@ def _get_embeddings(
         Returns:
             List[List[float]]: List of embeddings corresponding to each query.
         """
+        if not self.use_local_model:
+            return self.cloud_embed_provider.get_embeddings(queries)
+
         try:
             import torch
         except ImportError as exc:
             raise ModuleNotFoundError(
-                "You need to install torch to use a local embedding model with PineconeRM.",
+                "You need to install torch to use a local embedding model with PineconeRM."
             ) from exc

-        if not self.use_local_model:
-            if OPENAI_LEGACY:
-                embedding = openai.Embedding.create(
-                    input=queries, model=self._openai_embed_model,
-                )
-            else:
-                embedding = openai.embeddings.create(
-                    input=queries, model=self._openai_embed_model,
-                ).model_dump()
-            return [embedding["embedding"] for embedding in embedding["data"]]
-
         # Use local model
         encoded_input = self._local_tokenizer(queries, padding=True, truncation=True, return_tensors="pt").to(self.device)
         with torch.no_grad():
@@ -222,51 +215,55 @@ def _get_embeddings(
         # we need a pooling strategy to get a single vector representation of the input
         # so the default is to take the mean of the hidden states

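Putting the local path together, the flow below mirrors what `_get_embeddings` does for a local model (using the model name suggested in the class docstring):

```python
import torch
from transformers import AutoModel, AutoTokenizer

name = "sentence-transformers/all-mpnet-base-v2"
tokenizer = AutoTokenizer.from_pretrained(name)
model = AutoModel.from_pretrained(name)

encoded = tokenizer(["What is DSPy?"], padding=True, truncation=True, return_tensors="pt")
with torch.no_grad():
    output = model(**encoded)

# Masked mean over token embeddings, as in _mean_pooling above.
mask = encoded["attention_mask"].unsqueeze(-1).expand(output[0].size()).float()
embedding = (output[0] * mask).sum(1) / mask.sum(1).clamp(min=1e-9)
print(embedding.shape)  # torch.Size([1, 768])
```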
-    def forward(self, query_or_queries: Union[str, List[str]]) -> dspy.Prediction:
+    def forward(self, query_or_queries: Union[str, List[str]], k: Optional[int] = None) -> List[dotdict]:
         """Search with pinecone for self.k top passages for query

         Args:
             query_or_queries (Union[str, List[str]]): The query or queries to search for.
+            k (Optional[int]): The number of top passages to retrieve. Defaults to self.k.

         Returns:
-            dspy.Prediction: An object containing the retrieved passages.
+            List[dotdict]: The retrieved passages, each as a dotdict with a `long_text` field.
         """
+        k = k if k is not None else self.k
         queries = (
             [query_or_queries]
             if isinstance(query_or_queries, str)
             else query_or_queries
         )
         queries = [q for q in queries if q]  # Filter empty queries
         embeddings = self._get_embeddings(queries)
-
         # For single query, just look up the top k passages
         if len(queries) == 1:
             results_dict = self._pinecone_index.query(
-                embeddings[0], top_k=self.k, include_metadata=True,
+                vector=embeddings[0], top_k=k, include_metadata=True
             )

             # Sort results by score
             sorted_results = sorted(
-                results_dict["matches"], key=lambda x: x.get("scores", 0.0), reverse=True,
+                results_dict["matches"], key=lambda x: x.get("score", 0.0), reverse=True
             )
+
             passages = [result["metadata"]["text"] for result in sorted_results]
-            passages = [dotdict({"long_text": passage for passage in passages})]
-            return dspy.Prediction(passages=passages)
+            passages = [dotdict({"long_text": passage}) for passage in passages]
+            return passages

         # For multiple queries, query each and return the highest scoring passages
         # If a passage is returned multiple times, the score is accumulated. For this reason we increase top_k by 3x
         passage_scores = {}
         for embedding in embeddings:
             results_dict = self._pinecone_index.query(
-                embedding, top_k=self.k * 3, include_metadata=True,
+                vector=embedding, top_k=k * 3, include_metadata=True
             )
             for result in results_dict["matches"]:
                 passage_scores[result["metadata"]["text"]] = (
                     passage_scores.get(result["metadata"]["text"], 0.0)
                     + result["score"]
                 )

         sorted_passages = sorted(
-            passage_scores.items(), key=lambda x: x[1], reverse=True,
-        )[: self.k]
-        return dspy.Prediction(passages=[dotdict({"long_text": passage}) for passage, _ in sorted_passages])
+            passage_scores.items(), key=lambda x: x[1], reverse=True
+        )[:k]
+
+        passages = [dotdict({"long_text": passage}) for passage, _ in sorted_passages]
+        return passages
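A short usage sketch of the per-call `k` override (the index name is hypothetical):

```python
rm = PineconeRM("my-index", cloud_embed_provider=OpenAIEmbed())

# Retrieve 5 passages for one query; a list of dotdicts comes back.
passages = rm.forward("What is DSPy?", k=5)
print(passages[0].long_text)

# Multiple queries: scores for passages returned by more than one query
# are accumulated before the top-k cut.
passages = rm.forward(["query one", "query two"], k=3)
```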