88try :
99 import premai
1010
11- premai_error = premai .errors .UnexpectedStatus
11+ premai_api_error = premai .errors .UnexpectedStatus
1212except ImportError :
1313 premai_api_error = Exception
1414except AttributeError :
@@ -49,18 +49,18 @@ class PremAI(LM):
4949
5050 def __init__ (
5151 self ,
52- model : str ,
5352 project_id : int ,
53+ model : Optional [str ] = None ,
5454 api_key : Optional [str ] = None ,
5555 session_id : Optional [int ] = None ,
5656 ** kwargs ,
5757 ) -> None :
5858 """Parameters
5959
60- model: str
61- The name of model name
6260 project_id: int
6361 "The project ID in which the experiments or deployments are carried out. can find all your projects here: https://app.premai.io/projects/"
62+ model: Optional[str]
63+ The name of model deployed on launchpad. When None, it will show 'default'
6464 api_key: Optional[str]
6565 Prem AI API key, to connect with the API. If not provided then it will check from env var by the name
6666 PREMAI_API_KEY
@@ -69,6 +69,7 @@ def __init__(
6969 **kwargs: dict
7070 Additional arguments to pass to the API provider
7171 """
72+ model = "default" if model is None else model
7273 super ().__init__ (model )
7374 if premai_api_error == Exception :
7475 raise ImportError (
@@ -85,13 +86,18 @@ def __init__(
8586 self .history : list [dict [str , Any ]] = []
8687
8788 self .kwargs = {
88- "model" : model ,
8989 "temperature" : 0.17 ,
9090 "max_tokens" : 150 ,
9191 ** kwargs ,
9292 }
9393 if session_id is not None :
94- kwargs ["session_id" ] = session_id
94+ self .kwargs ["session_id" ] = session_id
95+
96+ # However this is not recommended to change the model once
97+ # deployed from launchpad
98+
99+ if model != "default" :
100+ self .kwargs ["model" ] = model
95101
96102 def _get_all_kwargs (self , ** kwargs ) -> dict :
97103 other_kwargs = {
@@ -111,7 +117,6 @@ def _get_all_kwargs(self, **kwargs) -> dict:
111117 "frequency_penalty" ,
112118 "presence_penalty" ,
113119 "tools" ,
114- "model" ,
115120 ]
116121
117122 for key in _keys_that_cannot_be_none :
@@ -122,15 +127,15 @@ def _get_all_kwargs(self, **kwargs) -> dict:
122127 def basic_request (self , prompt , ** kwargs ) -> str :
123128 """Handles retrieval of completions from Prem AI whilst handling API errors."""
124129 all_kwargs = self ._get_all_kwargs (** kwargs )
125- message = []
130+ messages = []
126131
127132 if "system_prompt" in all_kwargs :
128- message .append ({"role" : "system" , "content" : all_kwargs ["system_prompt" ]})
129- message .append ({"role" : "user" , "content" : prompt })
133+ messages .append ({"role" : "system" , "content" : all_kwargs ["system_prompt" ]})
134+ messages .append ({"role" : "user" , "content" : prompt })
130135
131136 response = self .client .chat .completions .create (
132137 project_id = self .project_id ,
133- messages = message ,
138+ messages = messages ,
134139 ** all_kwargs ,
135140 )
136141 if not response .choices :
0 commit comments