@@ -190,11 +190,13 @@ def embed_many(
190190 if len (texts ) > 0 and not isinstance (texts [0 ], str ):
191191 raise TypeError ("Must pass in a list of str values to embed." )
192192
193+ dtype = kwargs .pop ("dtype" , None )
194+
193195 embeddings : List = []
194196 for batch in self .batchify (texts , batch_size , preprocess ):
195197 response = self ._client .embeddings .create (input = batch , model = self .model )
196198 embeddings += [
197- self ._process_embedding (r .embedding , as_buffer , ** kwargs )
199+ self ._process_embedding (r .embedding , as_buffer , dtype )
198200 for r in response .data
199201 ]
200202 return embeddings
@@ -231,8 +233,11 @@ def embed(
231233
232234 if preprocess :
233235 text = preprocess (text )
236+
237+ dtype = kwargs .pop ("dtype" , None )
238+
234239 result = self ._client .embeddings .create (input = [text ], model = self .model )
235- return self ._process_embedding (result .data [0 ].embedding , as_buffer , ** kwargs )
240+ return self ._process_embedding (result .data [0 ].embedding , as_buffer , dtype )
236241
237242 @retry (
238243 wait = wait_random_exponential (min = 1 , max = 60 ),
@@ -269,13 +274,15 @@ async def aembed_many(
269274 if len (texts ) > 0 and not isinstance (texts [0 ], str ):
270275 raise TypeError ("Must pass in a list of str values to embed." )
271276
277+ dtype = kwargs .pop ("dtype" , None )
278+
272279 embeddings : List = []
273280 for batch in self .batchify (texts , batch_size , preprocess ):
274281 response = await self ._aclient .embeddings .create (
275282 input = batch , model = self .model
276283 )
277284 embeddings += [
278- self ._process_embedding (r .embedding , as_buffer , ** kwargs )
285+ self ._process_embedding (r .embedding , as_buffer , dtype )
279286 for r in response .data
280287 ]
281288 return embeddings
@@ -312,8 +319,11 @@ async def aembed(
312319
313320 if preprocess :
314321 text = preprocess (text )
322+
323+ dtype = kwargs .pop ("dtype" , None )
324+
315325 result = await self ._aclient .embeddings .create (input = [text ], model = self .model )
316- return self ._process_embedding (result .data [0 ].embedding , as_buffer , ** kwargs )
326+ return self ._process_embedding (result .data [0 ].embedding , as_buffer , dtype )
317327
318328 @property
319329 def type (self ) -> str :
0 commit comments