
Commit bcebff1

Merge pull request #795 from drawal1/main
Added support for models on AWS Bedrock and Sagemaker
2 parents ff6d813 + d08fa38

File tree

12 files changed: +737 −273 lines
Lines changed: 86 additions & 0 deletions
---
sidebar_position: 9
---

# dspy.AWSMistral, dspy.AWSAnthropic, dspy.AWSMeta

### Usage

```python
# Notes:
# 1. Install boto3 to use AWS models.
# 2. Configure your AWS credentials with the AWS CLI before using these models.

# Initialize the Bedrock AWS provider.
bedrock = dspy.Bedrock(region_name="us-west-2")
# For Mixtral on Bedrock
lm = dspy.AWSMistral(bedrock, "mistral.mixtral-8x7b-instruct-v0:1", **kwargs)
# For Haiku on Bedrock
lm = dspy.AWSAnthropic(bedrock, "anthropic.claude-3-haiku-20240307-v1:0", **kwargs)
# For Llama 2 on Bedrock
lm = dspy.AWSMeta(bedrock, "meta.llama2-13b-chat-v1", **kwargs)

# Initialize the Sagemaker AWS provider.
sagemaker = dspy.Sagemaker(region_name="us-west-2")
# For Mistral on Sagemaker
# Note: you need to create a Sagemaker endpoint for the Mistral model first.
lm = dspy.AWSMistral(sagemaker, "<YOUR_MISTRAL_ENDPOINT_NAME>", **kwargs)
```

### Constructor

The `AWSMistral` constructor initializes the base class `AWSModel`, which itself inherits from the `LM` class.

```python
class AWSMistral(AWSModel):
    """Mistral family of models."""

    def __init__(
        self,
        aws_provider: AWSProvider,
        model: str,
        max_context_size: int = 32768,
        max_new_tokens: int = 1500,
        **kwargs
    ) -> None:
```

**Parameters:**
- `aws_provider` (AWSProvider): The AWS provider to use. One of `dspy.Bedrock` or `dspy.Sagemaker`.
- `model` (_str_): The Mistral AI pretrained model to use. For Bedrock, this is the Model ID listed in https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids.html#model-ids-arns. For Sagemaker, this is the endpoint name.
- `max_context_size` (_Optional[int]_, _optional_): Max context size for this model. Defaults to 32768.
- `max_new_tokens` (_Optional[int]_, _optional_): Max new tokens possible for this model. Defaults to 1500.
- `**kwargs`: Additional language model arguments to pass to the API provider.
56+
### Methods
57+
58+
```python
59+
def _format_prompt(self, raw_prompt: str) -> str:
60+
```
61+
This function formats the prompt for the model. Refer to the model card for the specific formatting required.
62+
63+
<br/>
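
As an illustration only (not dspy's actual implementation), a formatter following the Mistral-instruct convention from the model card might look like:

```python
# Hypothetical sketch of a Mistral-style prompt formatter; the real
# _format_prompt() lives in dsp/modules/aws_models.py and may differ.
def format_mistral_prompt(raw_prompt: str) -> str:
    """Wrap the raw prompt in the [INST] tags used by Mistral instruct models."""
    return f"<s>[INST] {raw_prompt} [/INST]"
```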

```python
def _create_body(self, prompt: str, **kwargs) -> tuple[int, dict[str, str | float]]:
```

This function builds the request body for the model. It takes the prompt and any additional keyword arguments and returns a tuple of the number of tokens to generate and a dictionary of request parameters, including the prompt itself.

<br/>
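
A minimal sketch of the shape of such a return value, assuming hypothetical parameter names modeled on a typical text-generation request body:

```python
# Illustrative only: the real _create_body() is model- and provider-specific.
def create_body(prompt: str, max_new_tokens: int = 1500, **kwargs) -> tuple[int, dict]:
    """Return (number of tokens to generate, request-body dict) for a prompt."""
    n_tokens = int(kwargs.pop("max_tokens", max_new_tokens))
    body = {"prompt": prompt, "max_tokens": n_tokens, **kwargs}
    return n_tokens, body
```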

```python
def _call_model(self, body: str) -> str:
```

This function calls the model via the provider's `call_model()` method and extracts the generated text (completion) from the provider-specific response.

<br/>
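
For example, extracting the completion from a provider's JSON response might look like the sketch below; the field names are illustrative, since each provider nests the generated text differently:

```python
import json

# Illustrative sketch: the actual response layout varies by provider and model family.
def extract_completion(response_body: str) -> str:
    """Pull the generated text out of a JSON response string.

    Assumes a layout like {"outputs": [{"text": "..."}]}.
    """
    data = json.loads(response_body)
    return data["outputs"][0]["text"]
```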

The model-specific methods above are called by `AWSModel::basic_request()`, the main method for querying the model. It takes the prompt and any additional keyword arguments and calls `AWSModel::_simple_api_call()`, which delegates to the model-specific `_create_body()` and `_call_model()` methods to build the request body, call the model, and extract the generated text.

Refer to the [`dspy.OpenAI`](https://dspy-docs.vercel.app/api/language_model_clients/OpenAI) documentation for information on the `LM` base class functionality.

<br/>

`AWSAnthropic` and `AWSMeta` work exactly the same as `AWSMistral`.
Lines changed: 53 additions & 0 deletions
---
sidebar_position: 9
---

# dspy.Bedrock, dspy.Sagemaker

### Usage

The `AWSProvider` class is the base class for the AWS providers `dspy.Bedrock` and `dspy.Sagemaker`. An instance of one of these providers is passed to the constructor of an AWS model class (e.g., `dspy.AWSMistral`), which is ultimately used to query the model.

```python
# Notes:
# 1. Install boto3 to use AWS models.
# 2. Configure your AWS credentials with the AWS CLI before using these models.

# Initialize the Bedrock AWS provider.
bedrock = dspy.Bedrock(region_name="us-west-2")

# Initialize the Sagemaker AWS provider.
sagemaker = dspy.Sagemaker(region_name="us-west-2")
```

### Constructor

The `Bedrock` constructor initializes the base class `AWSProvider`.

```python
class Bedrock(AWSProvider):
    """This class adds support for Bedrock models."""

    def __init__(
        self,
        region_name: str,
        profile_name: Optional[str] = None,
        batch_n_enabled: bool = False,  # This has to be set up manually on Bedrock.
    ) -> None:
```

**Parameters:**
- `region_name` (str): The AWS region where this LM is hosted.
- `profile_name` (str, optional): The boto3 credentials profile to use.
- `batch_n_enabled` (bool): If False, call the LM N times rather than batching.

### Methods

```python
def call_model(self, model_id: str, body: str) -> str:
```

This function performs the actual invocation of the model on AWS via boto3.

<br/>
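
A sketch of what that invocation might look like with a boto3 `bedrock-runtime` client; the client is passed in explicitly here (the actual `Bedrock` provider manages its own boto3 session), and error handling is omitted:

```python
# Illustrative sketch of invoking a Bedrock-hosted model.
# In practice: client = boto3.client("bedrock-runtime", region_name="us-west-2")
def call_model(client, model_id: str, body: str) -> str:
    """Invoke `model_id` with a JSON request body and return the raw response body."""
    response = client.invoke_model(
        modelId=model_id,
        body=body,
        accept="application/json",
        contentType="application/json",
    )
    return response["body"].read().decode("utf-8")
```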
`Sagemaker` works exactly the same as `Bedrock`.

dsp/modules/__init__.py

Lines changed: 5 additions & 2 deletions
@@ -1,6 +1,10 @@
 from .anthropic import Claude
+from .aws_models import AWSAnthropic, AWSMeta, AWSMistral, AWSModel
+
+# Below is obsolete. It has been replaced with the Bedrock class in dsp/modules/aws_providers.py
+# from .bedrock import *
+from .aws_providers import Bedrock, Sagemaker
 from .azure_openai import AzureOpenAI
-from .bedrock import *
 from .cache_utils import *
 from .clarifai import *
 from .cohere import *
@@ -17,4 +21,3 @@
 from .pyserini import *
 from .sbert import *
 from .sentence_vectorizer import *
-

dsp/modules/aws_lm.py

Lines changed: 0 additions & 186 deletions
This file was deleted.
