33Author: Dhar Rawal (@drawal1)
44"""
55
6- import os
7- from dsp .utils import dotdict
8- from typing import Optional , List , Union , Any
9- import dspy
10- import backoff
116from abc import ABC , abstractmethod
7+ from typing import List , Optional , Union
8+
9+ import backoff
10+
11+ import dspy
12+ from dsp .utils import dotdict
1213
1314try :
1415 import pinecone
1516except ImportError :
1617 raise ImportError (
17- "The pinecone library is required to use PineconeRM. Install it with `pip install dspy-ai[pinecone]`"
18+ "The pinecone library is required to use PineconeRM. Install it with `pip install dspy-ai[pinecone]`" ,
1819 )
1920
2021import openai
22+
2123try :
2224 OPENAI_LEGACY = int (openai .version .__version__ [0 ]) == 0
2325except Exception :
@@ -57,11 +59,11 @@ def __init__(self, model="text-embedding-ada-002", api_key: Optional[str]=None,
5759 def get_embeddings (self , queries : List [str ]) -> List [List [float ]]:
5860 if OPENAI_LEGACY :
5961 embedding = openai .Embedding .create (
60- input = queries , model = self .model
62+ input = queries , model = self .model ,
6163 )
6264 else :
6365 embedding = openai .embeddings .create (
64- input = queries , model = self .model
66+ input = queries , model = self .model ,
6567 ).model_dump ()
6668 return [embedding ["embedding" ] for embedding in embedding ["data" ]]
6769
@@ -71,7 +73,7 @@ def __init__(self, model: str = "multilingual-22-12", api_key: Optional[str] = N
7173 import cohere
7274 except ImportError :
7375 raise ImportError (
74- "The cohere library is required to use CohereEmbed. Install it with `pip install cohere`"
76+ "The cohere library is required to use CohereEmbed. Install it with `pip install cohere`" ,
7577 )
7678 super ().__init__ (model , api_key )
7779 self .client = cohere .Client (api_key )
@@ -133,7 +135,7 @@ def __init__(
133135 from transformers import AutoModel , AutoTokenizer
134136 except ImportError as exc :
135137 raise ModuleNotFoundError (
136- "You need to install Hugging Face transformers library to use a local embedding model with PineconeRM."
138+ "You need to install Hugging Face transformers library to use a local embedding model with PineconeRM." ,
137139 ) from exc
138140
139141 self ._local_embed_model = AutoModel .from_pretrained (local_embed_model )
@@ -142,7 +144,7 @@ def __init__(
142144 self .device = torch .device (
143145 'cuda:0' if torch .cuda .is_available () else
144146 'mps' if torch .backends .mps .is_available ()
145- else 'cpu'
147+ else 'cpu' ,
146148 )
147149
148150 elif cloud_emded_provider is not None :
@@ -151,7 +153,7 @@ def __init__(
151153
152154 else :
153155 raise ValueError (
154- "Either local_embed_model or cloud_embed_provider must be provided."
156+ "Either local_embed_model or cloud_embed_provider must be provided." ,
155157 )
156158
157159 if pinecone_api_key is None :
@@ -167,13 +169,13 @@ def __init__(
167169 def _mean_pooling (
168170 self ,
169171 model_output ,
170- attention_mask
172+ attention_mask ,
171173 ):
172174 try :
173175 import torch
174176 except ImportError as exc :
175177 raise ModuleNotFoundError (
176- "You need to install torch to use a local embedding model with PineconeRM."
178+ "You need to install torch to use a local embedding model with PineconeRM." ,
177179 ) from exc
178180
179181 token_embeddings = model_output [0 ] # First element of model_output contains all token embeddings
@@ -183,7 +185,7 @@ def _mean_pooling(
183185
184186 def _get_embeddings (
185187 self ,
186- queries : List [str ]
188+ queries : List [str ],
187189 ) -> List [List [float ]]:
188190 """Return query vector after creating embedding using OpenAI
189191
@@ -200,7 +202,7 @@ def _get_embeddings(
200202 import torch
201203 except ImportError as exc :
202204 raise ModuleNotFoundError (
203- "You need to install torch to use a local embedding model with PineconeRM."
205+ "You need to install torch to use a local embedding model with PineconeRM." ,
204206 ) from exc
205207
206208 # Use local model
@@ -236,12 +238,12 @@ def forward(self, query_or_queries: Union[str, List[str]], k: Optional[int] = No
236238 # For single query, just look up the top k passages
237239 if len (queries ) == 1 :
238240 results_dict = self ._pinecone_index .query (
239- vector = embeddings [0 ], top_k = k , include_metadata = True
241+ vector = embeddings [0 ], top_k = k , include_metadata = True ,
240242 )
241243
242244 # Sort results by score
243245 sorted_results = sorted (
244- results_dict ["matches" ], key = lambda x : x .get ("scores" , 0.0 ), reverse = True
246+ results_dict ["matches" ], key = lambda x : x .get ("scores" , 0.0 ), reverse = True ,
245247 )
246248
247249 passages = [result ["metadata" ]["text" ] for result in sorted_results ]
@@ -253,7 +255,7 @@ def forward(self, query_or_queries: Union[str, List[str]], k: Optional[int] = No
253255 passage_scores = {}
254256 for embedding in embeddings :
255257 results_dict = self ._pinecone_index .query (
256- vector = embedding , top_k = k * 3 , include_metadata = True
258+ vector = embedding , top_k = k * 3 , include_metadata = True ,
257259 )
258260 for result in results_dict ["matches" ]:
259261 passage_scores [result ["metadata" ]["text" ]] = (
@@ -262,7 +264,7 @@ def forward(self, query_or_queries: Union[str, List[str]], k: Optional[int] = No
262264 )
263265
264266 sorted_passages = sorted (
265- passage_scores .items (), key = lambda x : x [1 ], reverse = True
267+ passage_scores .items (), key = lambda x : x [1 ], reverse = True ,
266268 )[: k ]
267269
268270 passages = [dotdict ({"long_text" : passage }) for passage , _ in sorted_passages ]
0 commit comments