1+ # Copyright (c) "Neo4j"
2+ # Neo4j Sweden AB [https://neo4j.com]
3+ # #
4+ # Licensed under the Apache License, Version 2.0 (the "License");
5+ # you may not use this file except in compliance with the License.
6+ # You may obtain a copy of the License at
7+ # #
8+ # https://www.apache.org/licenses/LICENSE-2.0
9+ # #
10+ # Unless required by applicable law or agreed to in writing, software
11+ # distributed under the License is distributed on an "AS IS" BASIS,
12+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+ # See the License for the specific language governing permissions and
14+ # limitations under the License.
15+
116from __future__ import annotations
217
318from typing import Any
621
722
823class OpenAIEmbeddings (Embedder ):
9- def __init__ (self , * args : Any , ** kwargs : Any ) -> None :
24+ def __init__ (self , model : str = "text-embedding-ada-002" ) -> None :
1025 try :
1126 import openai
1227 except ImportError :
@@ -15,10 +30,11 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
1530 "Please install it with `pip install openai`."
1631 )
1732
18- self .model = openai .OpenAI (* args , ** kwargs )
33+ self .openai_model = openai .OpenAI ()
34+ self .model = model
1935
20- def embed_query (
21- self , text : str , model : str = "text-embedding-ada-002" , ** kwargs : Any
22- ) -> list [ float ]:
23- response = self . model . embeddings . create ( input = text , model = model , ** kwargs )
36+ def embed_query (self , text : str , ** kwargs : Any ) -> list [ float ]:
37+ response = self . openai_model . embeddings . create (
38+ input = text , model = self . model , ** kwargs
39+ )
2440 return response .data [0 ].embedding
0 commit comments