Skip to content

Commit b1281a8

Browse files
refactor: add torchserve handler to adopt tf-serving/kf-serving v1 standard and read data from instances key in the body
1 parent 2dbb786 commit b1281a8

File tree

2 files changed

+7
-6
lines changed

2 files changed

+7
-6
lines changed

deployment/pytorch/my_handler.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import logging
2-
import json
32

43
import torch
54
from my_model import InferenceAutoencoder
@@ -25,11 +24,11 @@ def initialize(self, context):
2524
self.model = InferenceAutoencoder(input_shape=(51,), l2_lambda=1e-4)
2625
self.model.load_state_dict(state_dict)
2726
self.model.eval()
28-
logger.info("✅ Model Loaded Successfully!")
2927

3028
def preprocess(self, data):
3129
"""Convert input data to tensor"""
32-
input_data = torch.tensor(data[0]['body'], dtype=torch.float32)
30+
input_data = data[0].get("data") or data[0].get("body")
31+
input_data = torch.tensor(input_data.get("instances"), dtype=torch.float32)
3332
return input_data
3433

3534
def inference(self, data):
@@ -41,8 +40,9 @@ def inference(self, data):
4140
def postprocess(self, data):
4241
"""Convert output to JSON format"""
4342
# We have to return the same length as the input:
44-
# If our input is: [[1,2,3], [1,2,3]]
45-
# Ou output has to be [list|str|int|float, list|str|int|float]
43+
# If our input is: [[1,2,3], [1,2,3], ...]
44+
# Our output has to be [list|str|int|float, list|str|int|float, ...]
45+
# This way, the webser will but the payload directly on the response body
4646
reconstructed, avg_mse = data
4747
payload = []
4848
for idx in range(len(avg_mse)):

deployment/pytorch/test_predictions.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,8 @@
88
def inference_over_http(data: np.array):
99
url = "http://localhost:8080/predictions/my_model"
1010
headers = {"Content-Type": "application/json"}
11-
response = requests.post(url, headers=headers, json=data.tolist())
11+
body = {"instances": data.tolist()}
12+
response = requests.post(url, headers=headers, json=body)
1213
response.raise_for_status()
1314
return response.json()
1415

0 commit comments

Comments
 (0)