2424from sagemaker .utils import unique_name_from_base
2525
2626import tests .integ
27+ from tests .integ import timeout
2728
2829ROLE = 'SageMakerRole'
2930
30- RESOURCE_PATH = os .path .join (os .path .dirname (__file__ ), '..' , 'data' , 'tensorflow_mnist' )
31- SCRIPT = os .path .join (RESOURCE_PATH , 'mnist.py' )
31+ RESOURCE_PATH = os .path .join (os .path .dirname (__file__ ), '..' , 'data' )
32+ MNIST_RESOURCE_PATH = os .path .join (RESOURCE_PATH , 'tensorflow_mnist' )
33+ TFS_RESOURCE_PATH = os .path .join (RESOURCE_PATH , 'tfs' , 'tfs-test-entrypoint-with-handler' )
34+
35+ SCRIPT = os .path .join (MNIST_RESOURCE_PATH , 'mnist.py' )
3236PARAMETER_SERVER_DISTRIBUTION = {'parameter_server' : {'enabled' : True }}
3337MPI_DISTRIBUTION = {'mpi' : {'enabled' : True }}
3438TAGS = [{'Key' : 'some-key' , 'Value' : 'some-value' }]
@@ -57,7 +61,7 @@ def test_mnist(sagemaker_session, instance_type):
5761 metric_definitions = [
5862 {'Name' : 'train:global_steps' , 'Regex' : r'global_step\/sec:\s(.*)' }])
5963 inputs = estimator .sagemaker_session .upload_data (
60- path = os .path .join (RESOURCE_PATH , 'data' ),
64+ path = os .path .join (MNIST_RESOURCE_PATH , 'data' ),
6165 key_prefix = 'scriptmode/mnist' )
6266
6367 with tests .integ .timeout .timeout (minutes = tests .integ .TRAINING_DEFAULT_TIMEOUT_MINUTES ):
@@ -88,7 +92,7 @@ def test_server_side_encryption(sagemaker_session):
8892 output_kms_key = kms_key )
8993
9094 inputs = estimator .sagemaker_session .upload_data (
91- path = os .path .join (RESOURCE_PATH , 'data' ),
95+ path = os .path .join (MNIST_RESOURCE_PATH , 'data' ),
9296 key_prefix = 'scriptmode/mnist' )
9397
9498 with tests .integ .timeout .timeout (minutes = tests .integ .TRAINING_DEFAULT_TIMEOUT_MINUTES ):
@@ -110,7 +114,7 @@ def test_mnist_distributed(sagemaker_session, instance_type):
110114 framework_version = TensorFlow .LATEST_VERSION ,
111115 distributions = PARAMETER_SERVER_DISTRIBUTION )
112116 inputs = estimator .sagemaker_session .upload_data (
113- path = os .path .join (RESOURCE_PATH , 'data' ),
117+ path = os .path .join (MNIST_RESOURCE_PATH , 'data' ),
114118 key_prefix = 'scriptmode/distributed_mnist' )
115119
116120 with tests .integ .timeout .timeout (minutes = tests .integ .TRAINING_DEFAULT_TIMEOUT_MINUTES ):
@@ -129,7 +133,7 @@ def test_mnist_async(sagemaker_session):
129133 framework_version = TensorFlow .LATEST_VERSION ,
130134 tags = TAGS )
131135 inputs = estimator .sagemaker_session .upload_data (
132- path = os .path .join (RESOURCE_PATH , 'data' ),
136+ path = os .path .join (MNIST_RESOURCE_PATH , 'data' ),
133137 key_prefix = 'scriptmode/mnist' )
134138 estimator .fit (inputs = inputs , wait = False , job_name = unique_name_from_base ('test-tf-sm-async' ))
135139 training_job_name = estimator .latest_training_job .name
@@ -150,6 +154,35 @@ def test_mnist_async(sagemaker_session):
150154 estimator .latest_training_job .name , TAGS )
151155
152156
157+ @pytest .mark .skipif (tests .integ .PYTHON_VERSION != 'py3' ,
158+ reason = "Script Mode tests are only configured to run with Python 3" )
159+ def test_deploy_with_input_handlers (sagemaker_session , instance_type ):
160+ estimator = TensorFlow (entry_point = 'inference.py' ,
161+ source_dir = TFS_RESOURCE_PATH ,
162+ role = ROLE ,
163+ train_instance_count = 1 ,
164+ train_instance_type = instance_type ,
165+ sagemaker_session = sagemaker_session ,
166+ py_version = 'py3' ,
167+ framework_version = TensorFlow .LATEST_VERSION ,
168+ tags = TAGS )
169+
170+ estimator .fit (job_name = unique_name_from_base ('test-tf-tfs-deploy' ))
171+
172+ endpoint_name = estimator .latest_training_job .name
173+
174+ with timeout .timeout_and_delete_endpoint_by_name (endpoint_name , sagemaker_session ):
175+
176+ predictor = estimator .deploy (initial_instance_count = 1 , instance_type = instance_type ,
177+ endpoint_name = endpoint_name )
178+
179+ input_data = {'instances' : [1.0 , 2.0 , 5.0 ]}
180+ expected_result = {'predictions' : [4.0 , 4.5 , 6.0 ]}
181+
182+ result = predictor .predict (input_data )
183+ assert expected_result == result
184+
185+
153186def _assert_s3_files_exist (s3_url , files ):
154187 parsed_url = urlparse (s3_url )
155188 s3 = boto3 .client ('s3' )
0 commit comments