Skip to content

Commit 119bca2

Browse files
authored
Optional model for predictor (#549)
* Introduce optional model path to predictor API * Fix sklearn mpg model path * Allow folders to be accessed correctly after being downloaded from s3 * Update example readme and docs * Address PR comments
1 parent 4b1e8d0 commit 119bca2

File tree

15 files changed

+87
-54
lines changed

15 files changed

+87
-54
lines changed

docs/deployments/apis.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ In addition to supporting Python models via the Predictor interface, Cortex can
1515
endpoint: <string> # the endpoint for the API (default: /<deployment_name>/<api_name>)
1616
predictor:
1717
path: <string> # path to the predictor Python file, relative to the Cortex root (required)
18+
model: <string> # S3 path to a file or directory (e.g. s3://my-bucket/exported_model) (optional)
1819
python_path: <string> # path to the root of your Python folder that will be appended to PYTHONPATH (default: folder containing cortex.yaml)
1920
metadata: <string: value> # dictionary that can be used to configure custom values (optional)
2021
tracker:

docs/deployments/predictor.md

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,13 @@ Global variables can be shared across functions safely because each replica hand
1111
```python
1212
# initialization code and variables can be declared here in global scope
1313

14-
def init(metadata):
14+
def init(model_path, metadata):
1515
"""Called once before the API is made available. Setup for model serving such
1616
as downloading/initializing the model or downloading vocabulary can be done here.
1717
Optional.
1818
1919
Args:
20+
model_path: Local path to model file or directory if specified by user in API configuration, otherwise None.
2021
metadata: Custom dictionary specified by the user in API configuration.
2122
"""
2223
pass
@@ -45,14 +46,9 @@ from my_model import IrisNet
4546
labels = ["iris-setosa", "iris-versicolor", "iris-virginica"]
4647
model = IrisNet()
4748

48-
def init(metadata):
49-
# Download model from S3 (location specified in the API's metadata)
50-
s3 = boto3.client("s3")
51-
bucket, key = re.match(r"s3:\/\/(.+?)\/(.+)", metadata["model"]).groups()
52-
s3.download_file(bucket, key, "iris_model.pth")
53-
49+
def init(model_path, metadata):
5450
# Initialize the model
55-
model.load_state_dict(torch.load("iris_model.pth"))
51+
model.load_state_dict(torch.load(model_path))
5652
model.eval()
5753

5854

examples/pytorch/iris-classifier/README.md

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -8,26 +8,22 @@ We implement Cortex's Predictor interface to load the model and make predictions
88

99
### Initialization
1010

11-
We can place our code to download and initialize the model in the `init()` function. The PyTorch model class is defined in [src/my_model.py](./src/my_model.py), and we assume that we've already trained the model and uploaded the state_dict (weights) to S3.
11+
We can place our code to download and initialize the model in the `init()` function. The PyTorch model class is defined in [src/model.py](./src/model.py), and we assume that we've already trained the model and uploaded the state_dict (weights) to S3.
1212

1313
```python
1414
# predictor.py
1515

16-
from my_model import IrisNet
16+
from model import IrisNet
1717

1818
# instantiate the model
1919
model = IrisNet()
2020

2121
# define the labels
2222
labels = ["iris-setosa", "iris-versicolor", "iris-virginica"]
2323

24-
def init(metadata):
25-
# download the model from S3 (location specified in the metadata field of our api configuration)
26-
s3 = boto3.client("s3")
27-
bucket, key = re.match(r"s3:\/\/(.+?)\/(.+)", metadata["model"]).groups()
28-
s3.download_file(bucket, key, "weights.pth")
29-
30-
model.load_state_dict(torch.load("weights.pth"))
24+
def init(model_path, metadata):
25+
# model_path is a local path pointing to your model weights file
26+
model.load_state_dict(torch.load(model_path))
3127
model.eval()
3228
```
3329

@@ -71,8 +67,7 @@ A `deployment` specifies a set of resources that are deployed together. An `api`
7167
predictor:
7268
path: src/predictor.py
7369
python_path: src/
74-
metadata:
75-
model: s3://cortex-examples/pytorch/iris-classifier/weights.pth
70+
model: s3://cortex-examples/pytorch/iris-classifier/weights.pth
7671
tracker:
7772
model_type: classification
7873
```

examples/pytorch/iris-classifier/cortex.yaml

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
predictor:
77
path: src/predictor.py
88
python_path: src/
9-
metadata:
10-
model: s3://cortex-examples/pytorch/iris-classifier/weights.pth
9+
model: s3://cortex-examples/pytorch/iris-classifier/weights.pth
1110
tracker:
1211
model_type: classification

examples/pytorch/iris-classifier/src/predictor.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import boto3
21
import re
32
import torch
43
from model import IrisNet
@@ -8,11 +7,8 @@
87
model = IrisNet()
98

109

11-
def init(metadata):
12-
s3 = boto3.client("s3")
13-
bucket, key = re.match(r"s3:\/\/(.+?)\/(.+)", metadata["model"]).groups()
14-
s3.download_file(bucket, key, "weights.pth")
15-
model.load_state_dict(torch.load("weights.pth"))
10+
def init(model_path, metadata):
11+
model.load_state_dict(torch.load(model_path))
1612
model.eval()
1713

1814

examples/pytorch/text-generator/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ model.eval()
2020
# download the tokenizer
2121
tokenizer = GPT2Tokenizer.from_pretrained("distilgpt2")
2222

23-
def init(metadata):
23+
def init(model_path, metadata):
2424
# load the model onto the device specified in the metadata field of our api configuration
2525
model.to(metadata["device"])
2626
```

examples/pytorch/text-generator/predictor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float("Inf")
4646
return logits
4747

4848

49-
def init(metadata):
49+
def init(model_path, metadata):
5050
model.to(metadata["device"])
5151

5252

examples/sklearn/mpg-estimation/README.md

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -8,21 +8,17 @@ We implement Cortex's Predictor interface to load the model and make predictions
88

99
### Initialization
1010

11-
We can place our code to download and initialize the model in the `init()` function:
11+
We can place our code to initialize the model in the `init()` function:
1212

1313
```python
1414
# predictor.py
1515

1616
model = None
1717

18-
def init(metadata):
18+
def init(model_path, metadata):
1919
global model
20-
21-
# download the model from S3 (path specified in the metadata field of our api configuration)
22-
s3 = boto3.client("s3")
23-
bucket, key = re.match(r"s3:\/\/(.+?)\/(.+)", metadata["model"]).groups()
24-
s3.download_file(bucket, key, "linreg.joblib")
25-
model = load("linreg.joblib")
20+
# model_path is a local path pointing to your model
21+
model = load(model_path)
2622
```
2723

2824
### Predict
@@ -61,8 +57,7 @@ A `deployment` specifies a set of resources that are deployed together. An `api`
6157
name: mpg
6258
predictor:
6359
path: predictor.py
64-
metadata:
65-
model: s3://cortex-examples/sklearn/mpg-estimation/linreg.joblib
60+
model: s3://cortex-examples/sklearn/mpg-estimation/linreg.joblib
6661
tracker:
6762
model_type: regression
6863
```

examples/sklearn/mpg-estimation/cortex.yaml

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
name: mpg
66
predictor:
77
path: predictor.py
8-
metadata:
9-
model: s3://cortex-examples/sklearn/mpg-estimation/linreg.joblib
8+
model: s3://cortex-examples/sklearn/mpg-estimation/linreg.joblib
109
tracker:
1110
model_type: regression

examples/sklearn/mpg-estimation/predictor.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,13 @@
1-
import boto3
2-
import re
3-
import numpy as np
41
from joblib import load
2+
import numpy as np
53

64

75
model = None
86

97

10-
def init(metadata):
8+
def init(model_path, metadata):
119
global model
12-
s3 = boto3.client("s3")
13-
bucket, key = re.match(r"s3:\/\/(.+?)\/(.+)", metadata["model"]).groups()
14-
s3.download_file(bucket, key, "linreg.joblib")
15-
model = load("linreg.joblib")
10+
model = load(model_path)
1611

1712

1813
def predict(sample, metadata):

0 commit comments

Comments
 (0)