11import os
2- from typing import Callable , Dict , List , Optional
2+ from typing import Any , Callable , Dict , List , Optional
33
44from tenacity import retry , stop_after_attempt , wait_random_exponential
55from tenacity .retry import retry_if_not_exception_type
@@ -19,7 +19,7 @@ class OpenAITextVectorizer(BaseVectorizer):
1919 in the `api_config` dictionary or through the `OPENAI_API_KEY` environment
2020 variable. Users must obtain an API key from OpenAI's website
2121 (https://api.openai.com/). Additionally, the `openai` python client must be
22- installed with `pip install openai==0.28.1 `.
22+ installed with `pip install "openai>=1.13.0"`.
2323
2424 The vectorizer supports both synchronous and asynchronous operations,
2525 allowing for batch processing of texts and flexibility in handling
@@ -42,6 +42,8 @@ class OpenAITextVectorizer(BaseVectorizer):
4242
4343 """
4444
45+ aclient : Any # Since the OpenAI module is loaded dynamically
46+
4547 def __init__ (
4648 self , model : str = "text-embedding-ada-002" , api_config : Optional [Dict ] = None
4749 ):
@@ -59,7 +61,7 @@ def __init__(
5961 """
6062 # Dynamic import of the openai module
6163 try :
62- import openai
64+ from openai import AsyncOpenAI , OpenAI
6365 except ImportError :
6466 raise ImportError (
6567 "OpenAI vectorizer requires the openai library. \
@@ -77,17 +79,19 @@ def __init__(
7779 environment variable."
7880 )
7981
80- openai .api_key = api_key
81- client = openai .Embedding
82+ client = OpenAI (api_key = api_key )
8283 dims = self ._set_model_dims (client , model )
8384 super ().__init__ (model = model , dims = dims , client = client )
85+ self .aclient = AsyncOpenAI (api_key = api_key )
8486
8587 @staticmethod
8688 def _set_model_dims (client , model ) -> int :
8789 try :
88- embedding = client .create (input = ["dimension test" ], engine = model )["data" ][
89- 0
90- ]["embedding" ]
90+ embedding = (
91+ client .embeddings .create (input = ["dimension test" ], model = model )
92+ .data [0 ]
93+ .embedding
94+ )
9195 except (KeyError , IndexError ) as ke :
9296 raise ValueError (f"Unexpected response from the OpenAI API: { str (ke )} " )
9397 except Exception as e : # pylint: disable=broad-except
@@ -132,10 +136,9 @@ def embed_many(
132136
133137 embeddings : List = []
134138 for batch in self .batchify (texts , batch_size , preprocess ):
135- response = self .client .create (input = batch , engine = self .model )
139+ response = self .client .embeddings . create (input = batch , model = self .model )
136140 embeddings += [
137- self ._process_embedding (r ["embedding" ], as_buffer )
138- for r in response ["data" ]
141+ self ._process_embedding (r .embedding , as_buffer ) for r in response .data
139142 ]
140143 return embeddings
141144
@@ -171,8 +174,8 @@ def embed(
171174
172175 if preprocess :
173176 text = preprocess (text )
174- result = self .client .create (input = [text ], engine = self .model )
175- return self ._process_embedding (result [ " data" ] [0 ][ " embedding" ] , as_buffer )
177+ result = self .client .embeddings . create (input = [text ], model = self .model )
178+ return self ._process_embedding (result . data [0 ]. embedding , as_buffer )
176179
177180 @retry (
178181 wait = wait_random_exponential (min = 1 , max = 60 ),
@@ -211,10 +214,11 @@ async def aembed_many(
211214
212215 embeddings : List = []
213216 for batch in self .batchify (texts , batch_size , preprocess ):
214- response = await self .client .acreate (input = batch , engine = self .model )
217+ response = await self .aclient .embeddings .create (
218+ input = batch , model = self .model
219+ )
215220 embeddings += [
216- self ._process_embedding (r ["embedding" ], as_buffer )
217- for r in response ["data" ]
221+ self ._process_embedding (r .embedding , as_buffer ) for r in response .data
218222 ]
219223 return embeddings
220224
@@ -250,5 +254,5 @@ async def aembed(
250254
251255 if preprocess :
252256 text = preprocess (text )
253- result = await self .client . acreate (input = [text ], engine = self .model )
254- return self ._process_embedding (result [ " data" ] [0 ][ " embedding" ] , as_buffer )
257+ result = await self .aclient . embeddings . create (input = [text ], model = self .model )
258+ return self ._process_embedding (result . data [0 ]. embedding , as_buffer )
0 commit comments