Skip to content

Commit b205d06

Browse files
author
Balaji Veeramani
authored
feature: Support multiple accept types (#61)
1 parent e35b489 commit b205d06

File tree

5 files changed

+51
-6
lines changed

5 files changed

+51
-6
lines changed

src/sagemaker_inference/default_inference_handler.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
"""
1616
import textwrap
1717

18-
from sagemaker_inference import decoder, encoder
18+
from sagemaker_inference import decoder, encoder, errors, utils
1919

2020

2121
class DefaultInferenceHandler(object):
@@ -85,4 +85,7 @@ def default_output_fn(self, prediction, accept): # pylint: disable=no-self-use
8585
obj: prediction data.
8686
8787
"""
88-
return encoder.encode(prediction, accept), accept
88+
for content_type in utils.parse_accept(accept):
89+
if content_type in encoder.SUPPORTED_CONTENT_TYPES:
90+
return encoder.encode(prediction, content_type), content_type
91+
raise errors.UnsupportedFormatError(accept)

src/sagemaker_inference/encoder.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,9 @@ def _array_to_csv(array_like):
8787
}
8888

8989

90+
SUPPORTED_CONTENT_TYPES = set(_encoder_map.keys())
91+
92+
9093
def encode(array_like, content_type):
9194
"""Encode an array-like object in a specific content_type to a numpy array.
9295

src/sagemaker_inference/utils.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,3 +66,16 @@ def retrieve_content_type_header(request_property):
6666
return request_property[key]
6767

6868
return None
69+
70+
71+
def parse_accept(accept):
72+
"""Parses the Accept header sent with a request.
73+
74+
Args:
75+
accept (str): the value of an Accept header.
76+
77+
Returns:
78+
(list): A list containing the MIME types that the client is able to
79+
understand.
80+
"""
81+
return accept.replace(" ", "").split(",")

test/unit/test_default_inference_handler.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,19 @@ def test_default_input_fn(loads):
2424
loads.assert_called_with(42, content_types.JSON)
2525

2626

27+
@pytest.mark.parametrize(
28+
"accept, expected_content_type",
29+
[
30+
("text/csv", "text/csv"),
31+
("text/csv, application/json", "text/csv"),
32+
("unsupported/type, text/csv", "text/csv"),
33+
],
34+
)
2735
@patch("sagemaker_inference.encoder.encode", lambda prediction, accept: prediction ** 2)
28-
def test_default_output_fn():
29-
result, accept = DefaultInferenceHandler().default_output_fn(2, content_types.CSV)
36+
def test_default_output_fn(accept, expected_content_type):
37+
result, content_type = DefaultInferenceHandler().default_output_fn(2, accept)
3038
assert result == 4
31-
assert accept == content_types.CSV
39+
assert content_type == expected_content_type
3240

3341

3442
def test_default_model_fn():

test/unit/test_utils.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,12 @@
1313
from mock import Mock, mock_open, patch
1414
import pytest
1515

16-
from sagemaker_inference.utils import read_file, retrieve_content_type_header, write_file
16+
from sagemaker_inference.utils import (
17+
parse_accept,
18+
read_file,
19+
retrieve_content_type_header,
20+
write_file,
21+
)
1722

1823
TEXT = "text"
1924
CONTENT_TYPE = "content_type"
@@ -74,3 +79,16 @@ def test_content_type_header(content_type_key):
7479
result = retrieve_content_type_header(request_property)
7580

7681
assert result == CONTENT_TYPE
82+
83+
84+
@pytest.mark.parametrize(
85+
"input, expected",
86+
[
87+
("application/json", ["application/json"]),
88+
("application/json, text/csv", ["application/json", "text/csv"]),
89+
("application/json,text/csv", ["application/json", "text/csv"]),
90+
],
91+
)
92+
def test_parse_accept(input, expected):
93+
actual = parse_accept(input)
94+
assert actual == expected

0 commit comments

Comments
 (0)