Commit e2115cf

ruff fixes

Mechanical lint cleanup: imports re-sorted into isort-style sections, and trailing commas added to multi-line argument lists.

Parent: 5db1853

File tree: 1 file changed (+22, −20 lines)

1 file changed

+22
-20
lines changed

dspy/retrieve/pinecone_rm.py

Lines changed: 22 additions & 20 deletions
@@ -3,21 +3,23 @@
 Author: Dhar Rawal (@drawal1)
 """
 
-import os
-from dsp.utils import dotdict
-from typing import Optional, List, Union, Any
-import dspy
-import backoff
 from abc import ABC, abstractmethod
+from typing import List, Optional, Union
+
+import backoff
+
+import dspy
+from dsp.utils import dotdict
 
 try:
     import pinecone
 except ImportError:
     raise ImportError(
-        "The pinecone library is required to use PineconeRM. Install it with `pip install dspy-ai[pinecone]`"
+        "The pinecone library is required to use PineconeRM. Install it with `pip install dspy-ai[pinecone]`",
     )
 
 import openai
+
 try:
     OPENAI_LEGACY = int(openai.version.__version__[0]) == 0
 except Exception:
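
The hunk ends at the version probe: OPENAI_LEGACY is derived from the first character of the SDK version string, so openai 0.x selects the legacy API surface. An equivalent, slightly more explicit sketch (an illustration, not part of the commit):

import openai

# "0.28.1" -> major 0 -> legacy SDK; "1.12.0" -> major 1 -> modern SDK.
# Splitting on the dot also survives a hypothetical major version >= 10.
major = int(openai.version.__version__.split(".", 1)[0])
OPENAI_LEGACY = major == 0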
@@ -57,11 +59,11 @@ def __init__(self, model="text-embedding-ada-002", api_key: Optional[str]=None,
     def get_embeddings(self, queries: List[str]) -> List[List[float]]:
         if OPENAI_LEGACY:
             embedding = openai.Embedding.create(
-                input=queries, model=self.model
+                input=queries, model=self.model,
             )
         else:
             embedding = openai.embeddings.create(
-                input=queries, model=self.model
+                input=queries, model=self.model,
             ).model_dump()
         return [embedding["embedding"] for embedding in embedding["data"]]
 
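A hedged usage sketch of the branch above (not part of the commit), assuming OPENAI_API_KEY is set in the environment; both SDK generations accept a batch of strings and, after model_dump() on the modern side, yield the same dict shape:

import openai

OPENAI_LEGACY = int(openai.version.__version__[0]) == 0
queries = ["vector databases", "retrieval augmented generation"]

if OPENAI_LEGACY:
    # openai<1.0: module-level resource returning a plain dict
    response = openai.Embedding.create(input=queries, model="text-embedding-ada-002")
else:
    # openai>=1.0: pydantic response; model_dump() recovers the dict shape
    response = openai.embeddings.create(
        input=queries, model="text-embedding-ada-002",
    ).model_dump()

vectors = [item["embedding"] for item in response["data"]]  # one vector per query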

@@ -71,7 +73,7 @@ def __init__(self, model: str = "multilingual-22-12", api_key: Optional[str] = N
             import cohere
         except ImportError:
             raise ImportError(
-                "The cohere library is required to use CohereEmbed. Install it with `pip install cohere`"
+                "The cohere library is required to use CohereEmbed. Install it with `pip install cohere`",
             )
         super().__init__(model, api_key)
         self.client = cohere.Client(api_key)
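
For orientation, a hedged sketch of how a client constructed this way is typically used, assuming the pre-v5 cohere SDK that this code targets; verify the call signature against your installed version:

import os
import cohere

client = cohere.Client(os.environ["COHERE_API_KEY"])
response = client.embed(
    texts=["what is a vector database?"],
    model="multilingual-22-12",
)
vectors = response.embeddings  # one embedding per input text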
@@ -133,7 +135,7 @@ def __init__(
             from transformers import AutoModel, AutoTokenizer
         except ImportError as exc:
             raise ModuleNotFoundError(
-                "You need to install Hugging Face transformers library to use a local embedding model with PineconeRM."
+                "You need to install Hugging Face transformers library to use a local embedding model with PineconeRM.",
             ) from exc
 
         self._local_embed_model = AutoModel.from_pretrained(local_embed_model)
@@ -142,7 +144,7 @@ def __init__(
             self.device = torch.device(
                 'cuda:0' if torch.cuda.is_available() else
                 'mps' if torch.backends.mps.is_available()
-                else 'cpu'
+                else 'cpu',
             )
 
         elif cloud_emded_provider is not None:
@@ -151,7 +153,7 @@ def __init__(
 
         else:
             raise ValueError(
-                "Either local_embed_model or cloud_embed_provider must be provided."
+                "Either local_embed_model or cloud_embed_provider must be provided.",
             )
 
         if pinecone_api_key is None:
@@ -167,13 +169,13 @@ def __init__(
     def _mean_pooling(
         self,
         model_output,
-        attention_mask
+        attention_mask,
     ):
         try:
             import torch
         except ImportError as exc:
             raise ModuleNotFoundError(
-                "You need to install torch to use a local embedding model with PineconeRM."
+                "You need to install torch to use a local embedding model with PineconeRM.",
             ) from exc
 
         token_embeddings = model_output[0] # First element of model_output contains all token embeddings
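
For context, _mean_pooling implements masked mean pooling: average the token embeddings, counting only positions the attention mask marks as real tokens. A minimal self-contained sketch of the technique (illustrative names, assuming torch is installed; not the commit's code):

import torch

def mean_pool(token_embeddings: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    # token_embeddings: (batch, seq_len, dim); attention_mask: (batch, seq_len) of 0/1
    mask = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    summed = (token_embeddings * mask).sum(dim=1)   # padding contributes zero
    counts = mask.sum(dim=1).clamp(min=1e-9)        # real-token count per sequence
    return summed / counts                          # (batch, dim) pooled vectors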
@@ -183,7 +185,7 @@ def _mean_pooling(
 
     def _get_embeddings(
         self,
-        queries: List[str]
+        queries: List[str],
     ) -> List[List[float]]:
         """Return query vector after creating embedding using OpenAI
 
@@ -200,7 +202,7 @@ def _get_embeddings(
             import torch
         except ImportError as exc:
             raise ModuleNotFoundError(
-                "You need to install torch to use a local embedding model with PineconeRM."
+                "You need to install torch to use a local embedding model with PineconeRM.",
             ) from exc
 
         # Use local model
@@ -236,12 +238,12 @@ def forward(self, query_or_queries: Union[str, List[str]], k: Optional[int] = No
         # For single query, just look up the top k passages
         if len(queries) == 1:
             results_dict = self._pinecone_index.query(
-                vector=embeddings[0], top_k=k, include_metadata=True
+                vector=embeddings[0], top_k=k, include_metadata=True,
             )
 
             # Sort results by score
             sorted_results = sorted(
-                results_dict["matches"], key=lambda x: x.get("scores", 0.0), reverse=True
+                results_dict["matches"], key=lambda x: x.get("scores", 0.0), reverse=True,
             )
 
             passages = [result["metadata"]["text"] for result in sorted_results]
@@ -253,7 +255,7 @@ def forward(self, query_or_queries: Union[str, List[str]], k: Optional[int] = No
             passage_scores = {}
             for embedding in embeddings:
                 results_dict = self._pinecone_index.query(
-                    vector=embedding, top_k=k * 3, include_metadata=True
+                    vector=embedding, top_k=k * 3, include_metadata=True,
                 )
                 for result in results_dict["matches"]:
                     passage_scores[result["metadata"]["text"]] = (
@@ -262,7 +264,7 @@ def forward(self, query_or_queries: Union[str, List[str]], k: Optional[int] = No
                 )
 
             sorted_passages = sorted(
-                passage_scores.items(), key=lambda x: x[1], reverse=True
+                passage_scores.items(), key=lambda x: x[1], reverse=True,
             )[: k]
 
             passages=[dotdict({"long_text": passage}) for passage, _ in sorted_passages]
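
Taken together, the multi-query branch above accumulates each passage's scores across all query embeddings, then keeps the k best overall. A hedged sketch of that aggregation pattern (aggregate_matches is an illustrative helper, not a function in this file; Pinecone matches are treated as plain dicts):

from collections import defaultdict

def aggregate_matches(per_query_matches, k):
    passage_scores = defaultdict(float)
    for matches in per_query_matches:        # one result list per query embedding
        for match in matches:
            passage_scores[match["metadata"]["text"]] += match.get("score", 0.0)
    # Highest accumulated score first, truncated to the top k passages
    return sorted(passage_scores.items(), key=lambda x: x[1], reverse=True)[:k]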
