Skip to content

Commit 62eaf88

Browse files
committed
removed adding model as a required argument, so that it aligns with prem's workflow
1 parent 47f9199 commit 62eaf88

File tree

1 file changed

+16
-11
lines changed

1 file changed

+16
-11
lines changed

dsp/modules/premai.py

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
try:
99
import premai
1010

11-
premai_error = premai.errors.UnexpectedStatus
11+
premai_api_error = premai.errors.UnexpectedStatus
1212
except ImportError:
1313
premai_api_error = Exception
1414
except AttributeError:
@@ -49,18 +49,18 @@ class PremAI(LM):
4949

5050
def __init__(
5151
self,
52-
model: str,
5352
project_id: int,
53+
model: Optional[str] = None,
5454
api_key: Optional[str] = None,
5555
session_id: Optional[int] = None,
5656
**kwargs,
5757
) -> None:
5858
"""Parameters
5959
60-
model: str
61-
The name of model name
6260
project_id: int
6361
"The project ID in which the experiments or deployments are carried out. can find all your projects here: https://app.premai.io/projects/"
62+
model: Optional[str]
63+
The name of model deployed on launchpad. When None, it will show 'default'
6464
api_key: Optional[str]
6565
Prem AI API key, to connect with the API. If not provided then it will check from env var by the name
6666
PREMAI_API_KEY
@@ -69,6 +69,7 @@ def __init__(
6969
**kwargs: dict
7070
Additional arguments to pass to the API provider
7171
"""
72+
model = "default" if model is None else model
7273
super().__init__(model)
7374
if premai_api_error == Exception:
7475
raise ImportError(
@@ -85,13 +86,18 @@ def __init__(
8586
self.history: list[dict[str, Any]] = []
8687

8788
self.kwargs = {
88-
"model": model,
8989
"temperature": 0.17,
9090
"max_tokens": 150,
9191
**kwargs,
9292
}
9393
if session_id is not None:
94-
kwargs["session_id"] = session_id
94+
self.kwargs["session_id"] = session_id
95+
96+
# However this is not recommended to change the model once
97+
# deployed from launchpad
98+
99+
if model != "default":
100+
self.kwargs["model"] = model
95101

96102
def _get_all_kwargs(self, **kwargs) -> dict:
97103
other_kwargs = {
@@ -111,7 +117,6 @@ def _get_all_kwargs(self, **kwargs) -> dict:
111117
"frequency_penalty",
112118
"presence_penalty",
113119
"tools",
114-
"model",
115120
]
116121

117122
for key in _keys_that_cannot_be_none:
@@ -122,15 +127,15 @@ def _get_all_kwargs(self, **kwargs) -> dict:
122127
def basic_request(self, prompt, **kwargs) -> str:
123128
"""Handles retrieval of completions from Prem AI whilst handling API errors."""
124129
all_kwargs = self._get_all_kwargs(**kwargs)
125-
message = []
130+
messages = []
126131

127132
if "system_prompt" in all_kwargs:
128-
message.append({"role": "system", "content": all_kwargs["system_prompt"]})
129-
message.append({"role": "user", "content": prompt})
133+
messages.append({"role": "system", "content": all_kwargs["system_prompt"]})
134+
messages.append({"role": "user", "content": prompt})
130135

131136
response = self.client.chat.completions.create(
132137
project_id=self.project_id,
133-
messages=message,
138+
messages=messages,
134139
**all_kwargs,
135140
)
136141
if not response.choices:

0 commit comments

Comments
 (0)