Skip to content

Commit f4c0951

Browse files
committed
Add python server-side batching example
1 parent 18ab26c commit f4c0951

File tree

7 files changed

+134
-4
lines changed

7 files changed

+134
-4
lines changed

test/apis/pytorch/iris-classifier/model.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,6 @@
22
import torch.nn as nn
33
import torch.nn.functional as F
44
from torch.autograd import Variable
5-
from sklearn.datasets import load_iris
6-
from sklearn.model_selection import train_test_split
7-
from sklearn.metrics import accuracy_score
85

96

107
class IrisNet(nn.Module):
@@ -24,6 +21,10 @@ def forward(self, X):
2421

2522

2623
if __name__ == "__main__":
24+
from sklearn.datasets import load_iris
25+
from sklearn.model_selection import train_test_split
26+
from sklearn.metrics import accuracy_score
27+
2728
iris = load_iris()
2829
X, y = iris.data, iris.target
2930
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.8, random_state=42)
Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1 @@
11
torch
2-
scikit-learn
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
- name: iris-classifier
2+
kind: RealtimeAPI
3+
predictor:
4+
type: python
5+
path: predictor.py
6+
config:
7+
model: s3://cortex-examples/pytorch/iris-classifier/weights.pth
8+
server_side_batching:
9+
max_batch_size: 8
10+
batch_interval: 0.1s
11+
threads_per_process: 8
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
import torch
2+
import torch.nn as nn
3+
import torch.nn.functional as F
4+
from torch.autograd import Variable
5+
6+
7+
class IrisNet(nn.Module):
8+
def __init__(self):
9+
super(IrisNet, self).__init__()
10+
self.fc1 = nn.Linear(4, 100)
11+
self.fc2 = nn.Linear(100, 100)
12+
self.fc3 = nn.Linear(100, 3)
13+
self.softmax = nn.Softmax(dim=1)
14+
15+
def forward(self, X):
16+
X = F.relu(self.fc1(X))
17+
X = self.fc2(X)
18+
X = self.fc3(X)
19+
X = self.softmax(X)
20+
return X
21+
22+
23+
if __name__ == "__main__":
24+
from sklearn.datasets import load_iris
25+
from sklearn.model_selection import train_test_split
26+
from sklearn.metrics import accuracy_score
27+
28+
iris = load_iris()
29+
X, y = iris.data, iris.target
30+
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.8, random_state=42)
31+
32+
train_X = Variable(torch.Tensor(X_train).float())
33+
test_X = Variable(torch.Tensor(X_test).float())
34+
train_y = Variable(torch.Tensor(y_train).long())
35+
test_y = Variable(torch.Tensor(y_test).long())
36+
37+
model = IrisNet()
38+
39+
criterion = nn.CrossEntropyLoss()
40+
41+
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
42+
43+
for epoch in range(1000):
44+
optimizer.zero_grad()
45+
out = model(train_X)
46+
loss = criterion(out, train_y)
47+
loss.backward()
48+
optimizer.step()
49+
50+
if epoch % 100 == 0:
51+
print("number of epoch {} loss {}".format(epoch, loss))
52+
53+
predict_out = model(test_X)
54+
_, predict_y = torch.max(predict_out, 1)
55+
56+
print("prediction accuracy {}".format(accuracy_score(test_y.data, predict_y.data)))
57+
58+
torch.save(model.state_dict(), "weights.pth")
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
import re
2+
import torch
3+
import os
4+
import boto3
5+
from botocore import UNSIGNED
6+
from botocore.client import Config
7+
from model import IrisNet
8+
9+
labels = ["setosa", "versicolor", "virginica"]
10+
11+
12+
class PythonPredictor:
13+
def __init__(self, config):
14+
# download the model
15+
bucket, key = re.match("s3://(.+?)/(.+)", config["model"]).groups()
16+
17+
if os.environ.get("AWS_ACCESS_KEY_ID"):
18+
s3 = boto3.client("s3") # client will use your credentials if available
19+
else:
20+
s3 = boto3.client("s3", config=Config(signature_version=UNSIGNED)) # anonymous client
21+
22+
s3.download_file(bucket, key, "/tmp/model.pth")
23+
24+
# initialize the model
25+
model = IrisNet()
26+
model.load_state_dict(torch.load("/tmp/model.pth"))
27+
model.eval()
28+
29+
self.model = model
30+
31+
def predict(self, payload):
32+
responses = []
33+
34+
# note: this is not the most efficient way, it's just to test server-side batching
35+
for sample in payload:
36+
# Convert the request to a tensor and pass it into the model
37+
input_tensor = torch.FloatTensor(
38+
[
39+
[
40+
sample["sepal_length"],
41+
sample["sepal_width"],
42+
sample["petal_length"],
43+
sample["petal_width"],
44+
]
45+
]
46+
)
47+
48+
# Run the prediction
49+
output = self.model(input_tensor)
50+
51+
# Translate the model output to the corresponding label string
52+
responses.append(labels[torch.argmax(output[0])])
53+
54+
return responses
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
torch
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
{
2+
"sepal_length": 2.2,
3+
"sepal_width": 3.6,
4+
"petal_length": 1.4,
5+
"petal_width": 3.3
6+
}

0 commit comments

Comments
 (0)