2020from pandas import DataFrame
2121
2222from sagemaker .serve import SchemaBuilder , CustomPayloadTranslator
23- from sagemaker .serve .builder .schema_builder import JSONSerializerWrapper
23+ from sagemaker .serve .builder .schema_builder import JSONSerializerWrapper , CSVSerializerWrapper
2424from sagemaker .deserializers import (
2525 BytesDeserializer ,
2626 NumpyDeserializer ,
3030from sagemaker .serializers import (
3131 DataSerializer ,
3232 NumpySerializer ,
33- CSVSerializer ,
3433)
3534
3635NUMPY_CONTENT_TYPE = "application/x-npy"
@@ -94,13 +93,9 @@ def custom_translator():
9493 return MyPayloadTranslator ()
9594
9695
97- # @pytest.fixture
98- # def torch_tensor():
99- # return torch.rand(3, 4)
100-
101-
10296def test_schema_builder_with_numpy (numpy_array ):
10397 schema_builder = SchemaBuilder (numpy_array , numpy_array )
98+ _validate_marshalling_function (schema_builder = schema_builder )
10499 assert isinstance (schema_builder .input_serializer , NumpySerializer )
105100 assert isinstance (schema_builder .output_serializer , NumpySerializer )
106101 assert isinstance (schema_builder .input_deserializer ._deserializer , NumpyDeserializer )
@@ -111,8 +106,9 @@ def test_schema_builder_with_numpy(numpy_array):
111106
112107def test_schema_builder_with_pandas_dataframe (pandas_df ):
113108 schema_builder = SchemaBuilder (pandas_df , pandas_df )
114- assert isinstance (schema_builder .input_serializer , CSVSerializer )
115- assert isinstance (schema_builder .output_serializer , CSVSerializer )
109+ _validate_marshalling_function (schema_builder = schema_builder )
110+ assert isinstance (schema_builder .input_serializer , CSVSerializerWrapper )
111+ assert isinstance (schema_builder .output_serializer , CSVSerializerWrapper )
116112 assert isinstance (schema_builder .input_deserializer ._deserializer , PandasDeserializer )
117113 assert schema_builder .input_deserializer .ACCEPT == DATAFRAME_CONTENT_TYPE
118114 assert isinstance (schema_builder .output_deserializer ._deserializer , PandasDeserializer )
@@ -121,6 +117,7 @@ def test_schema_builder_with_pandas_dataframe(pandas_df):
121117
122118def test_schema_builder_with_jsonable (jsonable_obj ):
123119 schema_builder = SchemaBuilder (jsonable_obj , jsonable_obj )
120+ _validate_marshalling_function (schema_builder = schema_builder )
124121 assert isinstance (schema_builder .input_serializer , JSONSerializerWrapper )
125122 assert isinstance (schema_builder .output_serializer , JSONSerializerWrapper )
126123 assert isinstance (schema_builder .input_deserializer ._deserializer , JSONDeserializer )
@@ -131,13 +128,14 @@ def test_schema_builder_with_jsonable(jsonable_obj):
131128
132129def test_schema_builder_with_bytes (some_bytes ):
133130 schema_builder = SchemaBuilder (some_bytes , some_bytes )
131+ _validate_marshalling_function (schema_builder = schema_builder )
134132 assert isinstance (schema_builder .input_serializer , DataSerializer )
135133 assert isinstance (schema_builder .output_serializer , DataSerializer )
136134 assert isinstance (schema_builder .input_deserializer ._deserializer , BytesDeserializer )
137135 assert isinstance (schema_builder .output_deserializer ._deserializer , BytesDeserializer )
138136
139137
140- def test_schema_builder_with_cloudpickle (unsupported_object ):
138+ def test_schema_builder_unsupported_type (unsupported_object ):
141139 with pytest .raises (ValueError , match = "SchemaBuilder cannot determine" ):
142140 SchemaBuilder (unsupported_object , unsupported_object )
143141
@@ -149,6 +147,19 @@ def test_json_serializer_wrapper(jsonable):
149147 JSONDeserializer ().deserialize (stream , content_type = "application/json" )
150148
151149
150+ def _validate_marshalling_function (schema_builder : SchemaBuilder ):
151+ """Invoke serializer and deserializer to validate the payload"""
152+ # Validate sample_input
153+ b = schema_builder .input_serializer .serialize (schema_builder .sample_input )
154+ stream = BytesIO (b )
155+ schema_builder .input_deserializer .deserialize (stream = stream )
156+
157+ # Validate sample_output
158+ b = schema_builder .output_serializer .serialize (schema_builder .sample_output )
159+ stream = BytesIO (b )
160+ schema_builder .output_deserializer .deserialize (stream = stream )
161+
162+
152163def test_schema_builder_with_payload_translator (custom_translator ):
153164 payload = "payload"
154165 schema_builder = SchemaBuilder (
0 commit comments