Commit 9dfb673

Merge pull request #1007 from Anindyadeep/modules/premai

feat(dspy) PremAI python sdk

2 parents d63d24f + 4044d38 commit 9dfb673

7 files changed: +328 −5 lines changed
Lines changed: 70 additions & 0 deletions

@@ -0,0 +1,70 @@
---
sidebar_position: 5
---

# dsp.PremAI

[PremAI](https://app.premai.io) is an all-in-one platform that simplifies the process of creating robust, production-ready applications powered by Generative AI. By streamlining the development process, PremAI allows you to concentrate on enhancing user experience and driving overall growth for your application.

### Prerequisites

Refer to the [quick start](https://docs.premai.io/introduction) guide to get started with the PremAI platform: create your first project and grab your API key.

### Usage

Make sure you have the `premai` Python SDK installed. If not, install it with:

```bash
pip install -U premai
```

Here is a quick example of how to use the PremAI client with DSPy:

```python
from dspy import PremAI

llm = PremAI(model='mistral-tiny', project_id=123, api_key="your-premai-api-key")
print(llm("what is a large language model"))
```

> Please note: project ID 123 is just an example. You can find your project ID inside the platform, under the project you created.

### Constructor

The constructor initializes the base class `LM` and verifies the `api_key` provided or defined through the `PREMAI_API_KEY` environment variable.

```python
class PremAI(LM):
    def __init__(
        self,
        model: str,
        project_id: int,
        api_key: str,
        base_url: Optional[str] = None,
        session_id: Optional[int] = None,
        **kwargs,
    ) -> None:
```

**Parameters:**

- `model` (_str_): Models supported by PremAI. Example: `mistral-tiny`. We recommend using the model selected in [project launchpad](https://docs.premai.io/get-started/launchpad).
- `project_id` (_int_): The [project id](https://docs.premai.io/get-started/projects) which contains the model of choice.
- `api_key` (_Optional[str]_, _optional_): PremAI API key. Defaults to None.
- `session_id` (_Optional[int]_, _optional_): The ID of the session to use. It helps to track the chat history.
- `**kwargs`: Additional language model arguments to pass to the API provider.

### Methods

#### `__call__(self, prompt: str, **kwargs) -> str`

Retrieves completions from PremAI by calling `request`.

Internally, the method handles the specifics of preparing the request prompt and corresponding payload to obtain the response.

After generation, the completions are post-processed based on the `model_type` parameter.

**Parameters:**

- `prompt` (_str_): Prompt to send to PremAI.
- `**kwargs`: Additional keyword arguments for the completion request, e.g. `temperature` and `max_tokens`. You can find all the additional kwargs [here](https://docs.premai.io/get-started/sdk#optional-parameters).
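Per-call kwargs like these take precedence over any defaults set in the constructor. A minimal sketch of that precedence (the helper name `resolve_kwargs` is illustrative, not part of DSPy or the PremAI SDK):

```python
# Illustrative only: shows how per-call kwargs override constructor
# defaults. `resolve_kwargs` is a hypothetical helper, not a DSPy API.
def resolve_kwargs(defaults: dict, call_kwargs: dict) -> dict:
    """Later (per-call) keys win over earlier (default) keys."""
    return {**defaults, **call_kwargs}

defaults = {"temperature": 0.17, "max_tokens": 150}
print(resolve_kwargs(defaults, {"temperature": 0.7}))
# {'temperature': 0.7, 'max_tokens': 150}
```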

docs/docs/building-blocks/1-language_models.md

Lines changed: 2 additions & 1 deletion

@@ -137,6 +137,7 @@ lm = dspy.{provider_listed_below}(model="your model", model_request_kwargs="...")
 4. `dspy.Together` for hosted various open source models.
+5. `dspy.PremAI` for hosted best open source and closed source models.
 ### Local LMs.
@@ -173,4 +174,4 @@ model = 'dist/prebuilt/mlc-chat-Llama-2-7b-chat-hf-q4f16_1'
 model_path = 'dist/prebuilt/lib/Llama-2-7b-chat-hf-q4f16_1-cuda.so'
 
 llama = dspy.ChatModuleClient(model=model, model_path=model_path)
-```
+```
Lines changed: 70 additions & 0 deletions

@@ -0,0 +1,70 @@
## PremAI

[PremAI](https://app.premai.io) is an all-in-one platform that simplifies the process of creating robust, production-ready applications powered by Generative AI. By streamlining the development process, PremAI allows you to concentrate on enhancing user experience and driving overall growth for your application.

### Prerequisites

Refer to the [quick start](https://docs.premai.io/introduction) guide to get started with the PremAI platform: create your first project and grab your API key.

### Setting up the PremAI Client

The constructor initializes the base class `LM` to support prompting requests to supported PremAI hosted models. It requires the following parameters:

- `model` (_str_): Models supported by PremAI. Example: `mistral-tiny`. We recommend using the model selected in [project launchpad](https://docs.premai.io/get-started/launchpad).
- `project_id` (_int_): The [project id](https://docs.premai.io/get-started/projects) which contains the model of choice.
- `api_key` (_Optional[str]_, _optional_): PremAI API key. Defaults to None.
- `session_id` (_Optional[int]_, _optional_): The ID of the session to use. It helps to track the chat history.
- `**kwargs`: Additional language model arguments to pass to the API provider.

Example of the PremAI constructor:

```python
class PremAI(LM):
    def __init__(
        self,
        model: str,
        project_id: int,
        api_key: str,
        base_url: Optional[str] = None,
        session_id: Optional[int] = None,
        **kwargs,
    ) -> None:
```

### Under the Hood

#### `__call__(self, prompt: str, **kwargs) -> str`

**Parameters:**
- `prompt` (_str_): Prompt to send to PremAI.
- `**kwargs`: Additional keyword arguments for the completion request.

**Returns:**
- `str`: Completion string from the chosen LLM provider.

Internally, the method handles the specifics of preparing the request prompt and corresponding payload to obtain the response.

### Using the PremAI client

```python
premai_client = dspy.PremAI(project_id=1111)
```

Please note that this is a dummy `project_id`. You need to change it to the project ID you want to use with DSPy.

1) Generate responses through DSPy modules.

```python
dspy.configure(lm=premai_client)

# Example DSPy CoT QA program
qa = dspy.ChainOfThought('question -> answer')

response = qa(question="What is the capital of France?")
print(response.answer)
```

2) Generate responses using the client directly.

```python
response = premai_client(prompt='What is the capital of France?')
print(response)
```
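Every call made through the client is also recorded on its `history` attribute. A sketch of the shape of one history entry (the values here are made up; the keys mirror the client's bookkeeping):

```python
# Example (made-up values) of one entry appended to the client's
# `history` list after each completion call.
history_entry = {
    "prompt": "What is the capital of France?",
    "response": "Paris.",
    "kwargs": {"temperature": 0.17, "max_tokens": 150},
    "raw_kwargs": {},
}
print(sorted(history_entry))
# ['kwargs', 'prompt', 'raw_kwargs', 'response']
```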

dsp/modules/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -20,6 +20,7 @@
 from .hf_client import Anyscale, HFClientTGI, Together
 from .mistral import *
 from .ollama import *
+from .premai import PremAI
 from .pyserini import *
 from .sbert import *
 from .sentence_vectorizer import *

dsp/modules/lm.py

Lines changed: 11 additions & 4 deletions

@@ -52,10 +52,15 @@ def inspect_history(self, n: int = 1, skip: int = 0):
                 or provider == "groq"
                 or provider == "Bedrock"
                 or provider == "Sagemaker"
+                or provider == "premai"
             ):
                 printed.append((prompt, x["response"]))
             elif provider == "anthropic":
-                blocks = [{"text": block.text} for block in x["response"].content if block.type == "text"]
+                blocks = [
+                    {"text": block.text}
+                    for block in x["response"].content
+                    if block.type == "text"
+                ]
                 printed.append((prompt, blocks))
             elif provider == "cohere":
                 printed.append((prompt, x["response"].text))
@@ -85,7 +90,7 @@ def inspect_history(self, n: int = 1, skip: int = 0):
             if provider == "cohere" or provider == "Bedrock" or provider == "Sagemaker":
                 text = choices
             elif provider == "openai" or provider == "ollama":
-                text = ' ' + self._get_choice_text(choices[0]).strip()
+                text = " " + self._get_choice_text(choices[0]).strip()
             elif provider == "clarifai" or provider == "claude":
                 text = choices
             elif provider == "groq":
@@ -96,14 +101,16 @@ def inspect_history(self, n: int = 1, skip: int = 0):
                 text = choices[0].message.content
             elif provider == "cloudflare":
                 text = choices[0]
-            elif provider == "ibm":
+            elif provider == "ibm" or provider == "premai":
                 text = choices
             else:
                 text = choices[0]["text"]
             printing_value += self.print_green(text, end="")
 
             if len(choices) > 1 and isinstance(choices, list):
-                printing_value += self.print_red(f" \t (and {len(choices)-1} other completions)", end="")
+                printing_value += self.print_red(
+                    f" \t (and {len(choices)-1} other completions)", end="",
+                )
 
             printing_value += "\n\n\n"
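The change above routes `premai` responses through the same plain-text branch as `ibm`, since both providers store the completion as a bare string rather than a list of choice objects. A simplified sketch of that dispatch (not the full `inspect_history` logic):

```python
# Simplified mirror of the provider dispatch in LM.inspect_history:
# "ibm" and "premai" return the completion text directly, while the
# fallback branch expects a list of {"text": ...} choice dicts.
def extract_text(provider: str, choices):
    if provider in ("ibm", "premai"):
        return choices
    return choices[0]["text"]

print(extract_text("premai", "Hello from Prem"))
print(extract_text("other", [{"text": "hello"}]))
```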

dsp/modules/premai.py

Lines changed: 173 additions & 0 deletions

@@ -0,0 +1,173 @@
import os
from typing import Any, Optional

import backoff

from dsp.modules.lm import LM

try:
    import premai

    premai_api_error = premai.errors.UnexpectedStatus
except (ImportError, AttributeError):
    # If the SDK is missing (or its error class moved), fall back to a
    # generic Exception so importing this module still succeeds.
    premai_api_error = Exception


def backoff_hdlr(details) -> None:
    """Handler for the backoff package.

    See more at: https://pypi.org/project/backoff/
    """
    print(
        "Backing off {wait:0.1f} seconds after {tries} tries calling function {target} with kwargs {kwargs}".format(
            **details,
        ),
    )


def giveup_hdlr(details) -> bool:
    """Decide whether to give up on a retry; keep retrying on rate limits."""
    if "rate limits" in details.message:
        return False
    return True


def get_premai_api_key(api_key: Optional[str] = None) -> str:
    """Retrieve the PremAI API key from a passed argument or environment variable."""
    api_key = api_key or os.environ.get("PREMAI_API_KEY")
    if api_key is None:
        raise RuntimeError(
            "No API key found. See the quick start guide at https://docs.premai.io/introduction to get your API key.",
        )
    return api_key


class PremAI(LM):
    """Wrapper around Prem AI's API."""

    def __init__(
        self,
        project_id: int,
        model: Optional[str] = None,
        api_key: Optional[str] = None,
        session_id: Optional[int] = None,
        **kwargs,
    ) -> None:
        """Parameters

        project_id: int
            The project ID in which the experiments or deployments are carried out.
            You can find all your projects here: https://app.premai.io/projects/
        model: Optional[str]
            The name of the model deployed on launchpad. When None, 'default' is used.
        api_key: Optional[str]
            Prem AI API key to connect with the API. If not provided, it is read
            from the PREMAI_API_KEY environment variable.
        session_id: Optional[int]
            The ID of the session to use. It helps to track the chat history.
        **kwargs: dict
            Additional arguments to pass to the API provider.
        """
        model = "default" if model is None else model
        super().__init__(model)
        if premai_api_error == Exception:
            raise ImportError(
                "Not loading Prem AI because it is not installed. Install it with `pip install premai`.",
            )

        self.project_id = project_id
        self.session_id = session_id

        api_key = get_premai_api_key(api_key=api_key)
        self.client = premai.Prem(api_key=api_key)
        self.provider = "premai"
        self.history: list[dict[str, Any]] = []

        self.kwargs = {
            "temperature": 0.17,
            "max_tokens": 150,
            **kwargs,
        }
        if session_id is not None:
            self.kwargs["session_id"] = session_id

        # Changing the model once it is deployed from launchpad is not
        # recommended.
        if model != "default":
            self.kwargs["model"] = model

    def _get_all_kwargs(self, **kwargs) -> dict:
        other_kwargs = {
            "seed": None,
            "logit_bias": None,
            "tools": None,
            "system_prompt": None,
        }
        all_kwargs = {
            **self.kwargs,
            **other_kwargs,
            **kwargs,
        }

        _keys_that_cannot_be_none = [
            "system_prompt",
            "frequency_penalty",
            "presence_penalty",
            "tools",
        ]

        for key in _keys_that_cannot_be_none:
            if all_kwargs.get(key) is None:
                all_kwargs.pop(key, None)
        return all_kwargs

    def basic_request(self, prompt, **kwargs) -> str:
        """Retrieve completions from Prem AI and record the call in history."""
        all_kwargs = self._get_all_kwargs(**kwargs)
        messages = []

        if "system_prompt" in all_kwargs:
            messages.append({"role": "system", "content": all_kwargs["system_prompt"]})
        messages.append({"role": "user", "content": prompt})

        response = self.client.chat.completions.create(
            project_id=self.project_id,
            messages=messages,
            **all_kwargs,
        )
        if not response.choices:
            raise premai_api_error("ChatResponse must have at least one candidate")

        content = response.choices[0].message.content
        if not content:
            raise premai_api_error("ChatResponse is none")

        output_text = content or ""

        self.history.append(
            {
                "prompt": prompt,
                "response": content,
                "kwargs": all_kwargs,
                "raw_kwargs": kwargs,
            },
        )

        return output_text

    @backoff.on_exception(
        backoff.expo,
        (premai_api_error),
        max_time=1000,
        on_backoff=backoff_hdlr,
        giveup=giveup_hdlr,
    )
    def request(self, prompt, **kwargs) -> str:
        """Retrieve completions from Prem AI, retrying on API errors."""
        return self.basic_request(prompt=prompt, **kwargs)

    def __call__(self, prompt, **kwargs) -> str:
        return self.request(prompt, **kwargs)
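`_get_all_kwargs` above drops any optional key that is still `None`, so unset options are never sent to the API as explicit nulls. A standalone sketch of that filtering step:

```python
# Mirrors the None-filtering in _get_all_kwargs: optional keys that
# were never set are removed rather than passed through as None.
def drop_none_optionals(all_kwargs: dict) -> dict:
    keys_that_cannot_be_none = [
        "system_prompt",
        "frequency_penalty",
        "presence_penalty",
        "tools",
    ]
    for key in keys_that_cannot_be_none:
        if all_kwargs.get(key) is None:
            all_kwargs.pop(key, None)
    return all_kwargs

merged = {"temperature": 0.17, "tools": None, "system_prompt": "Be brief"}
print(drop_none_optionals(merged))
# {'temperature': 0.17, 'system_prompt': 'Be brief'}
```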

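The retry policy in `dsp/modules/premai.py` backs off exponentially but gives up immediately on anything that is not a rate-limit error. A minimal sketch of the give-up decision, using a stand-in exception class so it runs without the SDK (`FakeAPIError` is hypothetical):

```python
# Stand-in for the giveup decision wired into @backoff.on_exception:
# returning False means "keep retrying", True means "give up now".
class FakeAPIError(Exception):
    def __init__(self, message):
        super().__init__(message)
        self.message = message

def giveup_hdlr(details) -> bool:
    if "rate limits" in details.message:
        return False
    return True

print(giveup_hdlr(FakeAPIError("hit rate limits, slow down")))  # False
print(giveup_hdlr(FakeAPIError("bad request")))                 # True
```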
dspy/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -45,6 +45,7 @@
 AWSMeta = dsp.AWSMeta
 
 Watsonx = dsp.Watsonx
+PremAI = dsp.PremAI
 
 configure = settings.configure
 context = settings.context