Skip to content

Commit 1864168

Browse files
committed
Update openai imports and client methods to match upgraded python client
1 parent 8e71272 commit 1864168

File tree

1 file changed

+16
-9
lines changed

1 file changed

+16
-9
lines changed

dspy/retrieve/mongodb_atlas_rm.py

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,13 @@
11
from typing import List, Optional, Union, Any
22
import dspy
33
import os
4-
import openai
4+
from openai import (
5+
OpenAI,
6+
APITimeoutError,
7+
InternalServerError,
8+
RateLimitError,
9+
UnprocessableEntityError,
10+
)
511
import backoff
612

713
try:
@@ -39,24 +45,25 @@ def build_vector_search_pipeline(
3945
class Embedder:
4046
def __init__(self, provider: str, model: str):
4147
if provider == "openai":
42-
openai.api_key = os.getenv("OPENAI_API_KEY")
43-
if not openai.api_key:
48+
api_key = os.getenv("OPENAI_API_KEY")
49+
if not api_key:
4450
raise ValueError("Environment variable OPENAI_API_KEY must be set")
45-
self.client = openai
51+
self.client = OpenAI()
4652
self.model = model
4753

4854
@backoff.on_exception(
4955
backoff.expo,
5056
(
51-
openai.error.RateLimitError,
52-
openai.error.ServiceUnavailableError,
53-
openai.error.APIError,
57+
APITimeoutError,
58+
InternalServerError,
59+
RateLimitError,
60+
UnprocessableEntityError,
5461
),
5562
max_time=15,
5663
)
5764
def __call__(self, queries) -> Any:
58-
embedding = self.client.Embedding.create(input=queries, model=self.model)
59-
return [embedding["embedding"] for embedding in embedding["data"]]
65+
embedding = self.client.embeddings.create(input=queries, model=self.model)
66+
return [result.embedding for result in embedding.data]
6067

6168

6269
class MongoDBAtlasRM(dspy.Retrieve):

0 commit comments

Comments
 (0)