1717import sagemaker
1818import sagemaker .predictor
1919import sagemaker .utils
20+ import tests .integ
21+ import tests .integ .timeout
2022from sagemaker .tensorflow .serving import Model , Predictor
21- from tests .integ .timeout import timeout_and_delete_endpoint_by_name
2223
2324
2425@pytest .fixture (scope = 'session' , params = ['ml.c5.xlarge' , 'ml.p3.2xlarge' ])
@@ -32,17 +33,21 @@ def tfs_predictor(instance_type, sagemaker_session, tf_full_version):
3233 model_data = sagemaker_session .upload_data (
3334 path = 'tests/data/tensorflow-serving-test-model.tar.gz' ,
3435 key_prefix = 'tensorflow-serving/models' )
35- with timeout_and_delete_endpoint_by_name (endpoint_name , sagemaker_session ):
36+ with tests . integ . timeout . timeout_and_delete_endpoint_by_name (endpoint_name , sagemaker_session ):
3637 model = Model (model_data = model_data , role = 'SageMakerRole' ,
3738 framework_version = tf_full_version ,
3839 sagemaker_session = sagemaker_session )
3940 predictor = model .deploy (1 , instance_type , endpoint_name = endpoint_name )
4041 yield predictor
4142
4243
43- # @pytest.mark.continuous_testing
44- # @pytest.mark.regional_testing
45- def test_predict (tfs_predictor ):
44+ @pytest .mark .continuous_testing
45+ @pytest .mark .regional_testing
46+ def test_predict (tfs_predictor , instance_type ):
47+ if ('p3' in instance_type ) and (
48+ tests .integ .REGION in tests .integ .HOSTING_P3_UNAVAILABLE_REGIONS ):
49+ pytest .skip ('no ml.p3 instances in this region' )
50+
4651 input_data = {'instances' : [1.0 , 2.0 , 5.0 ]}
4752 expected_result = {'predictions' : [3.5 , 4.0 , 5.5 ]}
4853
0 commit comments