
Commit fbe22a2

docs(aws.md): added commentary on AWSModel methods
1 parent: f33eeff

2 files changed (+35 -10 lines)


docs/api/language_model_clients/aws.md

Lines changed: 29 additions & 4 deletions
````diff
@@ -30,7 +30,7 @@ lm = dspy.AWSMistral(sagemaker, "<YOUR_MISTRAL_ENDPOINT_NAME>", **kwargs)
 
 ### Constructor
 
-The constructor initializes the base class `LM` and the `AWSModel` class.
+The `AWSMistral` constructor initializes the base class `AWSModel` which itself inherits from the `LM` class.
 
 ```python
 class AWSMistral(AWSModel):
@@ -47,15 +47,40 @@ class AWSMistral(AWSModel):
 ```
 
 **Parameters:**
-- `aws_provider` (AWSProvider): The aws provider to use. One of `Bedrock` or `Sagemaker`.
-- `model` (_str_): Mistral AI pretrained models. Defaults to `mistral-medium-latest`.
+- `aws_provider` (AWSProvider): The aws provider to use. One of `dspy.Bedrock` or `dspy.Sagemaker`.
+- `model` (_str_): Mistral AI pretrained models. For Bedrock, this is the Model ID in https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids.html#model-ids-arns. For Sagemaker, this is the endpoint name.
 - `max_context_size` (_Optional[int]_, _optional_): Max context size for this model. Defaults to 32768.
 - `max_new_tokens` (_Optional[int]_, _optional_): Max new tokens possible for this model. Defaults to 1500.
 - `**kwargs`: Additional language model arguments to pass to the API provider.
 
 ### Methods
 
-Refer to [`dspy.OpenAI`](https://dspy-docs.vercel.app/api/language_model_clients/OpenAI) documentation.
+```python
+def _format_prompt(self, raw_prompt: str) -> str:
+```
+This function formats the prompt for the model. Refer to the model card for the specific formatting required.
+
+<br/>
+
+```python
+def _create_body(self, prompt: str, **kwargs) -> tuple[int, dict[str, str | float]]:
+```
+This function creates the body of the request to the model. It takes the prompt and any additional keyword arguments and returns a tuple of the number of tokens to generate and a dictionary of keys including the prompt used to create the body of the request.
+
+<br/>
+
+```python
+def _call_model(self, body: str) -> str:
+```
+This function calls the model using the provider `call_model()` function and extracts the generated text (completion) from the provider-specific response.
+
+<br/>
+
+The above model-specific methods are called by the `AWSModel::basic_request()` method, which is the main method for querying the model. This method takes the prompt and any additional keyword arguments and calls the `AWSModel::_simple_api_call()` which then delegates to the model-specific `_create_body()` and `_call_model()` methods to create the body of the request, call the model and extract the generated text.
+
+
+Refer to [`dspy.OpenAI`](https://dspy-docs.vercel.app/api/language_model_clients/OpenAI) documentation for information on the `LM` base class functionality.
 
+<br/>
 
 `AWSAnthropic` and `AWSMeta` work exactly the same as `AWSMistral`.
````
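
As a companion to the method commentary added above, here is a minimal sketch of how `_format_prompt()`, `_create_body()`, and `_call_model()` might compose inside `basic_request()`. It is not the actual `dsp/modules/aws_models.py` source: the provider interface, the response shape, and the prompt template are simplified assumptions used purely for illustration.

```python
# Illustrative sketch only -- the real implementation lives in dsp/modules/aws_models.py.
# The provider object, its call_model() signature, and the response shape are assumptions.
import json


class SketchAWSModel:
    """Toy stand-in showing how the documented methods fit together."""

    def __init__(self, aws_provider, model: str, max_new_tokens: int = 1500, **kwargs):
        self.aws_provider = aws_provider  # Bedrock- or Sagemaker-like provider (assumed interface)
        self.model = model
        self.max_new_tokens = max_new_tokens
        self.kwargs = kwargs              # default generation arguments

    def _format_prompt(self, raw_prompt: str) -> str:
        # Model-specific prompt template; a Mistral-instruct-style wrapper is shown here.
        return f"<s> [INST] {raw_prompt} [/INST]"

    def _create_body(self, prompt: str, **kwargs) -> tuple[int, dict]:
        # Merge per-call kwargs over the defaults and build the request body.
        args = {**self.kwargs, **kwargs}
        n_tokens = args.pop("max_tokens", self.max_new_tokens)
        body = {"prompt": self._format_prompt(prompt), "max_tokens": n_tokens, **args}
        return n_tokens, body

    def _call_model(self, body: str) -> str:
        # Delegate to the provider and pull the completion out of its (assumed) response shape.
        response = self.aws_provider.call_model(model_id=self.model, body=body)
        return json.loads(response)["outputs"][0]["text"]

    def basic_request(self, prompt: str, **kwargs) -> str:
        # The flow documented above: build the body, call the model, return the generated text.
        _, body = self._create_body(prompt, **kwargs)
        return self._call_model(json.dumps(body))
```

A toy provider exposing a matching `call_model()` could be swapped in to exercise this flow without touching AWS.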

dsp/modules/aws_models.py

Lines changed: 6 additions & 6 deletions
````diff
@@ -21,13 +21,13 @@ class AWSModel(LM):
 Instead, use one of the subclasses - AWSMistral, AWSAnthropic, or AWSMeta.
 The subclasses implement the abstract methods _create_body and _call_model and work in conjunction with the AWSProvider classes Bedrock and Sagemaker.
 Usage Example:
-bedrock = Bedrock(region_name="us-west-2")
-bedrock_mixtral = AWSMistral(bedrock, "mistral.mixtral-8x7b-instruct-v0:1", **kwargs)
-bedrock_haiku = AWSAnthropic(bedrock, "anthropic.claude-3-haiku-20240307-v1:0", **kwargs)
-bedrock_llama2 = AWSMeta(bedrock, "meta.llama2-13b-chat-v1", **kwargs)
+bedrock = dspy.Bedrock(region_name="us-west-2")
+bedrock_mixtral = dspy.AWSMistral(bedrock, "mistral.mixtral-8x7b-instruct-v0:1", **kwargs)
+bedrock_haiku = dspy.AWSAnthropic(bedrock, "anthropic.claude-3-haiku-20240307-v1:0", **kwargs)
+bedrock_llama2 = dspy.AWSMeta(bedrock, "meta.llama2-13b-chat-v1", **kwargs)
 
-sagemaker = Sagemaker(region_name="us-west-2")
-sagemaker_model = AWSMistral(sagemaker, "<YOUR_ENDPOINT_NAME>", **kwargs)
+sagemaker = dspy.Sagemaker(region_name="us-west-2")
+sagemaker_model = dspy.AWSMistral(sagemaker, "<YOUR_ENDPOINT_NAME>", **kwargs)
 """
 
 def __init__(
````
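
The docstring's usage example stops at constructing a client. A short follow-on sketch of actually using one of these models through DSPy might look like the snippet below, assuming valid AWS credentials and Bedrock model access in the chosen region.

```python
# Usage sketch -- assumes AWS credentials and Bedrock model access are already set up.
import dspy

bedrock = dspy.Bedrock(region_name="us-west-2")
lm = dspy.AWSMistral(bedrock, "mistral.mixtral-8x7b-instruct-v0:1", max_new_tokens=1000)

# Register the client as the default LM and run a simple predictor through it.
dspy.settings.configure(lm=lm)
qa = dspy.Predict("question -> answer")
print(qa(question="What is the capital of France?").answer)
```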
