1717import boto3
1818import pytest
1919from sagemaker .tensorflow import TensorFlow
20+ from sagemaker .tuner import HyperparameterTuner , IntegerParameter
2021from six .moves .urllib .parse import urlparse
2122
2223from test .integration .utils import processor , py_version , unique_name_from_base # noqa: F401
24+ from timeout import timeout
2325
2426
2527@pytest .mark .deploy_test
2628def test_mnist (sagemaker_session , ecr_image , instance_type , framework_version ):
27- resource_path = os .path .join (os .path .dirname (__file__ ), '../ ..' , 'resources' )
29+ resource_path = os .path .join (os .path .dirname (__file__ ), '..' , ' ..' , 'resources' )
2830 script = os .path .join (resource_path , 'mnist' , 'mnist.py' )
2931 estimator = TensorFlow (entry_point = script ,
3032 role = 'SageMakerRole' ,
@@ -42,7 +44,7 @@ def test_mnist(sagemaker_session, ecr_image, instance_type, framework_version):
4244
4345
4446def test_distributed_mnist_no_ps (sagemaker_session , ecr_image , instance_type , framework_version ):
45- resource_path = os .path .join (os .path .dirname (__file__ ), '../ ..' , 'resources' )
47+ resource_path = os .path .join (os .path .dirname (__file__ ), '..' , ' ..' , 'resources' )
4648 script = os .path .join (resource_path , 'mnist' , 'mnist.py' )
4749 estimator = TensorFlow (entry_point = script ,
4850 role = 'SageMakerRole' ,
@@ -110,6 +112,40 @@ def test_s3_plugin(sagemaker_session, ecr_image, instance_type, region, framewor
110112 _assert_checkpoint_exists (region , estimator .model_dir , 200 )
111113
112114
115+ def test_tuning (sagemaker_session , ecr_image , instance_type , framework_version ):
116+ resource_path = os .path .join (os .path .dirname (__file__ ), '..' , '..' , 'resources' )
117+ script = os .path .join (resource_path , 'mnist' , 'mnist.py' )
118+
119+ estimator = TensorFlow (entry_point = script ,
120+ role = 'SageMakerRole' ,
121+ train_instance_type = instance_type ,
122+ train_instance_count = 1 ,
123+ sagemaker_session = sagemaker_session ,
124+ image_name = ecr_image ,
125+ framework_version = framework_version ,
126+ script_mode = True )
127+
128+ hyperparameter_ranges = {'epochs' : IntegerParameter (1 , 2 )}
129+ objective_metric_name = 'accuracy'
130+ metric_definitions = [{'Name' : objective_metric_name , 'Regex' : 'accuracy = ([0-9\\ .]+)' }]
131+
132+ tuner = HyperparameterTuner (estimator ,
133+ objective_metric_name ,
134+ hyperparameter_ranges ,
135+ metric_definitions ,
136+ max_jobs = 2 ,
137+ max_parallel_jobs = 2 )
138+
139+ with timeout (minutes = 20 ):
140+ inputs = estimator .sagemaker_session .upload_data (
141+ path = os .path .join (resource_path , 'mnist' , 'data' ),
142+ key_prefix = 'scriptmode/mnist' )
143+
144+ tuning_job_name = unique_name_from_base ('test-tf-sm-tuning' , max_length = 32 )
145+ tuner .fit (inputs , job_name = tuning_job_name )
146+ tuner .wait ()
147+
148+
113149def _assert_checkpoint_exists (region , model_dir , checkpoint_number ):
114150 _assert_s3_file_exists (region , os .path .join (model_dir , 'graph.pbtxt' ))
115151 _assert_s3_file_exists (region ,
0 commit comments