1515import io
1616import json
1717import logging
18-
1918import pytest
2019from mock import Mock
21-
2220from sagemaker .tensorflow import TensorFlow
2321from sagemaker .tensorflow .predictor import csv_serializer
2422from sagemaker .tensorflow .serving import Model , Predictor
@@ -167,12 +165,12 @@ def test_predictor_classify(sagemaker_session):
167165 mock_response (json .dumps (CLASSIFY_RESPONSE ).encode ('utf-8' ), sagemaker_session )
168166 result = predictor .classify (CLASSIFY_INPUT )
169167
170- assert_invoked (sagemaker_session ,
171- EndpointName = 'endpoint' ,
172- ContentType = JSON_CONTENT_TYPE ,
173- Accept = JSON_CONTENT_TYPE ,
174- CustomAttributes = 'tfs-method=classify' ,
175- Body = json .dumps (CLASSIFY_INPUT ))
168+ assert_invoked_with_body_dict (sagemaker_session ,
169+ EndpointName = 'endpoint' ,
170+ ContentType = JSON_CONTENT_TYPE ,
171+ Accept = JSON_CONTENT_TYPE ,
172+ CustomAttributes = 'tfs-method=classify' ,
173+ Body = json .dumps (CLASSIFY_INPUT ))
176174
177175 assert CLASSIFY_RESPONSE == result
178176
@@ -183,12 +181,12 @@ def test_predictor_regress(sagemaker_session):
183181 mock_response (json .dumps (REGRESS_RESPONSE ).encode ('utf-8' ), sagemaker_session )
184182 result = predictor .regress (REGRESS_INPUT )
185183
186- assert_invoked (sagemaker_session ,
187- EndpointName = 'endpoint' ,
188- ContentType = JSON_CONTENT_TYPE ,
189- Accept = JSON_CONTENT_TYPE ,
190- CustomAttributes = 'tfs-method=regress,tfs-model-name=model,tfs-model-version=123' ,
191- Body = json .dumps (REGRESS_INPUT ))
184+ assert_invoked_with_body_dict (sagemaker_session ,
185+ EndpointName = 'endpoint' ,
186+ ContentType = JSON_CONTENT_TYPE ,
187+ Accept = JSON_CONTENT_TYPE ,
188+ CustomAttributes = 'tfs-method=regress,tfs-model-name=model,tfs-model-version=123' ,
189+ Body = json .dumps (REGRESS_INPUT ))
192190
193191 assert REGRESS_RESPONSE == result
194192
@@ -208,12 +206,23 @@ def test_predictor_classify_bad_content_type():
208206
209207
210208def assert_invoked (sagemaker_session , ** kwargs ):
209+ sagemaker_session .sagemaker_runtime_client .invoke_endpoint .assert_called_once_with (** kwargs )
210+
211+
212+ def assert_invoked_with_body_dict (sagemaker_session , ** kwargs ):
211213 call = sagemaker_session .sagemaker_runtime_client .invoke_endpoint .call_args
212214 cargs , ckwargs = call
213215 assert not cargs
214216 assert len (kwargs ) == len (ckwargs )
215217 for k in ckwargs :
216- assert kwargs [k ] == ckwargs [k ]
218+ if k != 'Body' :
219+ assert kwargs [k ] == ckwargs [k ]
220+ else :
221+ actual_body = json .loads (ckwargs [k ])
222+ expected_body = json .loads (kwargs [k ])
223+ assert len (actual_body ) == len (expected_body )
224+ for k2 in actual_body :
225+ assert actual_body [k2 ] == expected_body [k2 ]
217226
218227
219228def mock_response (expected_response , sagemaker_session , content_type = JSON_CONTENT_TYPE ):
0 commit comments