Skip to content

Commit f897549

Browse files
authored
feature: Decode application/x-npz content type (#60)
1 parent deef6a8 commit f897549

File tree

4 files changed

+40
-3
lines changed

4 files changed

+40
-3
lines changed

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ def read_version():
2929

3030
packages = setuptools.find_packages(where="src", exclude=("test",))
3131

32-
required_packages = ["numpy", "six", "psutil", "retrying==1.3.3"]
32+
required_packages = ["numpy", "six", "psutil", "retrying==1.3.3", "scipy"]
3333

3434
# enum is introduced in Python 3.4. Installing enum back port
3535
if sys.version_info < (3, 4):

src/sagemaker_inference/content_types.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,4 +16,5 @@
1616
OCTET_STREAM = "application/octet-stream"
1717
ANY = "*/*"
1818
NPY = "application/x-npy"
19+
NPZ = "application/x-npz"
1920
UTF8_TYPES = [JSON, CSV]

src/sagemaker_inference/decoder.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import json
1818

1919
import numpy as np
20+
import scipy.sparse
2021
from six import BytesIO, StringIO
2122

2223
from sagemaker_inference import content_types, errors
@@ -70,22 +71,36 @@ def _npy_to_numpy(npy_array): # type: (object) -> np.array
7071
return np.load(stream, allow_pickle=True)
7172

7273

74+
def _npz_to_sparse(npz_bytes): # type: (object) -> scipy.sparse.spmatrix
75+
"""Convert .npz-formatted data to a sparse matrix.
76+
77+
Args:
78+
npz_bytes (object): Bytes encoding a sparse matrix in the .npz format.
79+
80+
Returns:
81+
(scipy.sparse.spmatrix): A sparse matrix.
82+
"""
83+
buffer = BytesIO(npz_bytes)
84+
return scipy.sparse.load_npz(buffer)
85+
86+
7387
# Dispatch table: content type -> function that decodes a payload of that type.
_decoder_map = {
    content_types.NPY: _npy_to_numpy,
    content_types.CSV: _csv_to_numpy,
    content_types.JSON: _json_to_numpy,
    content_types.NPZ: _npz_to_sparse,
}
7893

7994

8095
def decode(obj, content_type):
81-
"""Decode an object to one of the default content types to a numpy array.
96+
"""Decode an object that is encoded as one of the default content types.
8297
8398
Args:
8499
obj (object): to be decoded.
85100
content_type (str): content type to be used.
86101
87102
Returns:
88-
np.array: decoded object.
103+
object: decoded object for prediction.
89104
"""
90105
try:
91106
decoder = _decoder_map[content_type]

test/unit/test_decoder.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from mock import Mock, patch
1414
import numpy as np
1515
import pytest
16+
import scipy.sparse
1617
from six import BytesIO
1718

1819
from sagemaker_inference import content_types, decoder, errors
@@ -63,6 +64,26 @@ def test_csv_to_numpy(target, expected):
6364
np.testing.assert_equal(actual, expected)
6465

6566

67+
@pytest.mark.parametrize(
    "target",
    [
        scipy.sparse.csc_matrix(np.array([[0, 0, 3], [4, 0, 0]])),
        scipy.sparse.csr_matrix(np.array([[1, 0], [0, 7]])),
        scipy.sparse.coo_matrix(np.array([[6, 2], [5, 9]])),
    ],
)
def test_npz_to_sparse(target):
    """Round-trip a sparse matrix through save_npz and the .npz decoder."""
    # ``stream`` rather than ``buffer``: avoids shadowing the Python 2 builtin.
    stream = BytesIO()
    scipy.sparse.save_npz(stream, target)

    matrix = decoder._npz_to_sparse(stream.getvalue())

    # Comparing dense values alone would also pass for a dense result or a
    # different storage format, so pin sparsity and format explicitly.
    assert scipy.sparse.issparse(matrix)
    assert matrix.format == target.format
    np.testing.assert_equal(matrix.toarray(), target.toarray())
85+
86+
6687
def test_decode_error():
    """Decoding with an unsupported content type raises UnsupportedFormatError."""
    unsupported_type = content_types.OCTET_STREAM
    with pytest.raises(errors.UnsupportedFormatError):
        decoder.decode(42, unsupported_type)

0 commit comments

Comments
 (0)