1414# limitations under the License.
1515from __future__ import annotations
1616
17- from typing import Any , Optional , Type
17+ import abc
18+ from typing import Any , Optional
1819
1920from ..exceptions import LLMGenerationError
2021from .base import LLMInterface
2627 openai = None # type: ignore
2728
2829
29- class OpenAILLM (LLMInterface ):
30- client_class : Type [ openai . OpenAI ] = openai . OpenAI
31- async_client_class : Type [ openai . AsyncOpenAI ] = openai . AsyncOpenAI
30+ class BaseOpenAILLM (LLMInterface , abc . ABC ):
31+ client : Any
32+ async_client : Any
3233
3334 def __init__ (
3435 self ,
3536 model_name : str ,
3637 model_params : Optional [dict [str , Any ]] = None ,
37- ** kwargs : Any ,
3838 ):
3939 """
40+ Base class for OpenAI LLM.
41+
42+ Makes sure the openai Python client is installed during init.
4043
4144 Args:
4245 model_name (str):
4346 model_params (str): Parameters like temperature that will be passed to the model when text is sent to it
44- kwargs: All other parameters will be passed to the openai.OpenAI init.
45-
4647 """
4748 if openai is None :
4849 raise ImportError (
4950 "Could not import openai Python client. "
5051 "Please install it with `pip install openai`."
5152 )
52- super ().__init__ (model_name , model_params , ** kwargs )
53- self .client = self .client_class (** kwargs )
54- self .async_client = self .async_client_class (** kwargs )
53+ super ().__init__ (model_name , model_params )
5554
5655 def get_messages (
5756 self ,
@@ -76,7 +75,7 @@ def invoke(self, input: str) -> LLMResponse:
7675 """
7776 try :
7877 response = self .client .chat .completions .create (
79- messages = self .get_messages (input ), # type: ignore
78+ messages = self .get_messages (input ),
8079 model = self .model_name ,
8180 ** self .model_params ,
8281 )
@@ -100,7 +99,7 @@ async def ainvoke(self, input: str) -> LLMResponse:
10099 """
101100 try :
102101 response = await self .async_client .chat .completions .create (
103- messages = self .get_messages (input ), # type: ignore
102+ messages = self .get_messages (input ),
104103 model = self .model_name ,
105104 ** self .model_params ,
106105 )
@@ -110,6 +109,42 @@ async def ainvoke(self, input: str) -> LLMResponse:
110109 raise LLMGenerationError (e )
111110
112111
113- class AzureOpenAILLM (OpenAILLM ):
114- client_class : Type [openai .OpenAI ] = openai .AzureOpenAI
115- async_client_class : Type [openai .AsyncOpenAI ] = openai .AsyncAzureOpenAI
class OpenAILLM(BaseOpenAILLM):
    def __init__(
        self,
        model_name: str,
        model_params: Optional[dict[str, Any]] = None,
        **kwargs: Any,
    ):
        """OpenAI LLM

        Wrapper for the openai Python client LLM.

        Args:
            model_name (str): Name of the OpenAI model passed to the chat
                completions API.
            model_params (Optional[dict[str, Any]]): Parameters like
                temperature that will be passed to the model when text
                is sent to it.
            kwargs: All other parameters will be passed to the
                openai.OpenAI init.

        Raises:
            ImportError: If the openai Python client is not installed
                (raised by the base class init).
        """
        # Base class validates that the openai package is importable.
        super().__init__(model_name, model_params)
        self.client = openai.OpenAI(**kwargs)
        self.async_client = openai.AsyncOpenAI(**kwargs)
131+
132+
class AzureOpenAILLM(BaseOpenAILLM):
    def __init__(
        self,
        model_name: str,
        model_params: Optional[dict[str, Any]] = None,
        **kwargs: Any,
    ):
        """Azure OpenAI LLM. Use this class when using an OpenAI model
        hosted on Microsoft Azure.

        Args:
            model_name (str): Name of the Azure OpenAI model (deployment)
                passed to the chat completions API.
            model_params (Optional[dict[str, Any]]): Parameters like
                temperature that will be passed to the model when text
                is sent to it.
            kwargs: All other parameters will be passed to the
                openai.AzureOpenAI init.

        Raises:
            ImportError: If the openai Python client is not installed
                (raised by the base class init).
        """
        # Base class validates that the openai package is importable.
        super().__init__(model_name, model_params)
        self.client = openai.AzureOpenAI(**kwargs)
        self.async_client = openai.AsyncAzureOpenAI(**kwargs)
0 commit comments