Skip to content

Commit f83f906

Browse files
authored
Merge pull request #261 from DanielUH2019/backwards-compatibility-for-rm
OpenAI 0.28 backwards compatibility for Pinecone and ChromaDB
2 parents 30f240a + e4d5bd2 commit f83f906

File tree

2 files changed

+28
-7
lines changed

2 files changed

+28
-7
lines changed

dspy/retrieve/chromadb_rm.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,12 @@
88
import backoff
99
from dsp.utils import dotdict
1010

11+
try:
12+
import openai.error
13+
ERRORS = (openai.error.RateLimitError, openai.error.ServiceUnavailableError, openai.error.APIError)
14+
except Exception:
15+
ERRORS = (openai.RateLimitError, openai.APIError)
16+
1117
try:
1218
import chromadb
1319
from chromadb.config import Settings
@@ -108,7 +114,7 @@ def _init_chromadb(
108114

109115
@backoff.on_exception(
110116
backoff.expo,
111-
(openai.RateLimitError),
117+
ERRORS,
112118
max_time=15,
113119
)
114120
def _get_embeddings(self, queries: List[str]) -> List[List[float]]:

dspy/retrieve/pinecone_rm.py

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55

66
from dsp.utils import dotdict
77
from typing import Optional, List, Union
8-
import openai
98
import dspy
109
import backoff
1110

@@ -19,6 +18,17 @@
1918
"The pinecone library is required to use PineconeRM. Install it with `pip install dspy-ai[pinecone]`"
2019
)
2120

21+
import openai
22+
try:
23+
OPENAI_LEGACY = int(openai.version.__version__[0]) == 0
24+
except Exception:
25+
OPENAI_LEGACY = True
26+
27+
try:
28+
import openai.error
29+
ERRORS = (openai.error.RateLimitError, openai.error.ServiceUnavailableError, openai.error.APIError)
30+
except Exception:
31+
ERRORS = (openai.RateLimitError, openai.APIError)
2232

2333
class PineconeRM(dspy.Retrieve):
2434
"""
@@ -164,7 +174,7 @@ def _mean_pooling(
164174

165175
@backoff.on_exception(
166176
backoff.expo,
167-
(openai.RateLimitError),
177+
ERRORS,
168178
max_time=15,
169179
)
170180
def _get_embeddings(
@@ -187,10 +197,15 @@ def _get_embeddings(
187197
) from exc
188198

189199
if not self.use_local_model:
190-
embedding = openai.embeddings.create(
191-
input=queries, model=self._openai_embed_model
192-
)
193-
return [embedding.embedding for embedding in embedding.data]
200+
if OPENAI_LEGACY:
201+
embedding = openai.Embedding.create(
202+
input=queries, model=self._openai_embed_model
203+
)
204+
else:
205+
embedding = openai.embeddings.create(
206+
input=queries, model=self._openai_embed_model
207+
).model_dump()
208+
return [embedding["embedding"] for embedding in embedding["data"]]
194209

195210
# Use local model
196211
encoded_input = self._local_tokenizer(queries, padding=True, truncation=True, return_tensors="pt").to(self.device)

0 commit comments

Comments
 (0)