1515
1616from __future__ import annotations
1717
18- from typing import Any
18+ import abc
19+ from typing import TYPE_CHECKING , Any
20+
1921from neo4j_graphrag .embeddings .base import Embedder
2022
23+ if TYPE_CHECKING :
24+ import openai
2125
22- class OpenAIEmbeddings (Embedder ):
23- """
24- OpenAI embeddings class.
25- This class uses the OpenAI python client to generate embeddings for text data.
2626
27- Args :
28- model (str): The name of the OpenAI embedding model to use. Defaults to "text-embedding-ada-002".
29- kwargs: All other parameters will be passed to the openai. OpenAI init .
27+ class BaseOpenAIEmbeddings ( Embedder , abc . ABC ) :
28+ """
29+ Abstract base class for OpenAI embeddings .
3030 """
3131
32+ client : openai .OpenAI
33+
3234 def __init__ (self , model : str = "text-embedding-ada-002" , ** kwargs : Any ) -> None :
3335 try :
3436 import openai
@@ -39,23 +41,52 @@ def __init__(self, model: str = "text-embedding-ada-002", **kwargs: Any) -> None
3941 )
4042 self .openai = openai
4143 self .model = model
42- self .openai_client = self .openai .OpenAI (** kwargs )
44+ self .client = self ._initialize_client (** kwargs )
45+
46+ @abc .abstractmethod
47+ def _initialize_client (self , ** kwargs : Any ) -> Any :
48+ """
49+ Initialize the OpenAI client.
50+ Must be implemented by subclasses.
51+ """
52+ pass
4353
4454 def embed_query (self , text : str , ** kwargs : Any ) -> list [float ]:
4555 """
46- Generate embeddings for a given query using a OpenAI text embedding model.
56+ Generate embeddings for a given query using an OpenAI text embedding model.
4757
4858 Args:
4959 text (str): The text to generate an embedding for.
5060 **kwargs (Any): Additional arguments to pass to the OpenAI embedding generation function.
5161 """
52- response = self .openai_client .embeddings .create (
53- input = text , model = self .model , ** kwargs
54- )
55- return response .data [0 ].embedding
62+ response = self .client .embeddings .create (input = text , model = self .model , ** kwargs )
63+ embedding : list [float ] = response .data [0 ].embedding
64+ return embedding
5665
5766
58- class AzureOpenAIEmbeddings (OpenAIEmbeddings ):
59- def __init__ (self , model : str = "text-embedding-ada-002" , ** kwargs : Any ) -> None :
60- super ().__init__ (model , ** kwargs )
61- self .openai_client = self .openai .AzureOpenAI (** kwargs )
67+ class OpenAIEmbeddings (BaseOpenAIEmbeddings ):
68+ """
69+ OpenAI embeddings class.
70+ This class uses the OpenAI python client to generate embeddings for text data.
71+
72+ Args:
73+ model (str): The name of the OpenAI embedding model to use. Defaults to "text-embedding-ada-002".
74+ kwargs: All other parameters will be passed to the openai.OpenAI init.
75+ """
76+
77+ def _initialize_client (self , ** kwargs : Any ) -> Any :
78+ return self .openai .OpenAI (** kwargs )
79+
80+
81+ class AzureOpenAIEmbeddings (BaseOpenAIEmbeddings ):
82+ """
83+ Azure OpenAI embeddings class.
84+ This class uses the Azure OpenAI python client to generate embeddings for text data.
85+
86+ Args:
87+ model (str): The name of the Azure OpenAI embedding model to use. Defaults to "text-embedding-ada-002".
88+ kwargs: All other parameters will be passed to the openai.AzureOpenAI init.
89+ """
90+
91+ def _initialize_client (self , ** kwargs : Any ) -> Any :
92+ return self .openai .AzureOpenAI (** kwargs )
0 commit comments