@@ -128,20 +128,20 @@ def _create_train_job(toolkit, toolkit_version, framework):
128128 }
129129
130130
131- def test_create_tf_model (sagemaker_session , rl_coach_version ):
131+ def test_create_tf_model (sagemaker_session , rl_coach_tf_version ):
132132 container_log_level = '"logging.INFO"'
133133 source_dir = 's3://mybucket/source'
134134 rl = RLEstimator (entry_point = SCRIPT_PATH , role = ROLE , sagemaker_session = sagemaker_session ,
135135 train_instance_count = INSTANCE_COUNT , train_instance_type = INSTANCE_TYPE ,
136- toolkit = RLToolkit .COACH , toolkit_version = rl_coach_version ,
136+ toolkit = RLToolkit .COACH , toolkit_version = rl_coach_tf_version ,
137137 framework = RLFramework .TENSORFLOW , container_log_level = container_log_level ,
138138 source_dir = source_dir )
139139
140140 job_name = 'new_name'
141141 rl .fit (inputs = 's3://mybucket/train' , job_name = 'new_name' )
142142 model = rl .create_model ()
143143 supported_versions = TOOLKIT_FRAMEWORK_VERSION_MAP [RLToolkit .COACH .value ]
144- framework_version = supported_versions [rl_coach_version ][RLFramework .TENSORFLOW .value ]
144+ framework_version = supported_versions [rl_coach_tf_version ][RLFramework .TENSORFLOW .value ]
145145
146146 assert isinstance (model , tfs .Model )
147147 assert model .sagemaker_session == sagemaker_session
@@ -152,20 +152,20 @@ def test_create_tf_model(sagemaker_session, rl_coach_version):
152152 assert model .vpc_config is None
153153
154154
155- def test_create_mxnet_model (sagemaker_session , rl_coach_version ):
155+ def test_create_mxnet_model (sagemaker_session , rl_coach_mxnet_version ):
156156 container_log_level = '"logging.INFO"'
157157 source_dir = 's3://mybucket/source'
158158 rl = RLEstimator (entry_point = SCRIPT_PATH , role = ROLE , sagemaker_session = sagemaker_session ,
159159 train_instance_count = INSTANCE_COUNT , train_instance_type = INSTANCE_TYPE ,
160- toolkit = RLToolkit .COACH , toolkit_version = rl_coach_version ,
160+ toolkit = RLToolkit .COACH , toolkit_version = rl_coach_mxnet_version ,
161161 framework = RLFramework .MXNET , container_log_level = container_log_level ,
162162 source_dir = source_dir )
163163
164164 job_name = 'new_name'
165165 rl .fit (inputs = 's3://mybucket/train' , job_name = 'new_name' )
166166 model = rl .create_model ()
167167 supported_versions = TOOLKIT_FRAMEWORK_VERSION_MAP [RLToolkit .COACH .value ]
168- framework_version = supported_versions [rl_coach_version ][RLFramework .MXNET .value ]
168+ framework_version = supported_versions [rl_coach_mxnet_version ][RLFramework .MXNET .value ]
169169
170170 assert isinstance (model , MXNetModel )
171171 assert model .sagemaker_session == sagemaker_session
@@ -179,12 +179,12 @@ def test_create_mxnet_model(sagemaker_session, rl_coach_version):
179179 assert model .vpc_config is None
180180
181181
182- def test_create_model_with_optional_params (sagemaker_session , rl_coach_version ):
182+ def test_create_model_with_optional_params (sagemaker_session , rl_coach_mxnet_version ):
183183 container_log_level = '"logging.INFO"'
184184 source_dir = 's3://mybucket/source'
185185 rl = RLEstimator (entry_point = SCRIPT_PATH , role = ROLE , sagemaker_session = sagemaker_session ,
186186 train_instance_count = INSTANCE_COUNT , train_instance_type = INSTANCE_TYPE ,
187- toolkit = RLToolkit .COACH , toolkit_version = rl_coach_version ,
187+ toolkit = RLToolkit .COACH , toolkit_version = rl_coach_mxnet_version ,
188188 framework = RLFramework .MXNET , container_log_level = container_log_level ,
189189 source_dir = source_dir )
190190
@@ -226,10 +226,10 @@ def test_create_model_with_custom_image(sagemaker_session):
226226
227227@patch ('sagemaker.utils.create_tar_file' , MagicMock ())
228228@patch ('time.strftime' , return_value = TIMESTAMP )
229- def test_rl (strftime , sagemaker_session , rl_coach_version ):
229+ def test_rl (strftime , sagemaker_session , rl_coach_mxnet_version ):
230230 rl = RLEstimator (entry_point = SCRIPT_PATH , role = ROLE , sagemaker_session = sagemaker_session ,
231231 train_instance_count = INSTANCE_COUNT , train_instance_type = INSTANCE_TYPE ,
232- toolkit = RLToolkit .COACH , toolkit_version = rl_coach_version ,
232+ toolkit = RLToolkit .COACH , toolkit_version = rl_coach_mxnet_version ,
233233 framework = RLFramework .MXNET )
234234
235235 inputs = 's3://mybucket/train'
@@ -241,7 +241,7 @@ def test_rl(strftime, sagemaker_session, rl_coach_version):
241241 boto_call_names = [c [0 ] for c in sagemaker_session .boto_session .method_calls ]
242242 assert boto_call_names == ['resource' ]
243243
244- expected_train_args = _create_train_job (RLToolkit .COACH .value , rl_coach_version ,
244+ expected_train_args = _create_train_job (RLToolkit .COACH .value , rl_coach_mxnet_version ,
245245 RLFramework .MXNET .value )
246246 expected_train_args ['input_config' ][0 ]['DataSource' ]['S3DataSource' ]['S3Uri' ] = inputs
247247
@@ -250,7 +250,7 @@ def test_rl(strftime, sagemaker_session, rl_coach_version):
250250
251251 model = rl .create_model ()
252252 supported_versions = TOOLKIT_FRAMEWORK_VERSION_MAP [RLToolkit .COACH .value ]
253- framework_version = supported_versions [rl_coach_version ][RLFramework .MXNET .value ]
253+ framework_version = supported_versions [rl_coach_mxnet_version ][RLFramework .MXNET .value ]
254254
255255 expected_image_base = '520713654638.dkr.ecr.us-west-2.amazonaws.com/sagemaker-mxnet:{}-gpu-py3'
256256 submit_dir = 's3://notmybucket/sagemaker-rl-mxnet-{}/source/sourcedir.tar.gz' .format (TIMESTAMP )
@@ -266,17 +266,17 @@ def test_rl(strftime, sagemaker_session, rl_coach_version):
266266
267267
268268@patch ('sagemaker.utils.create_tar_file' , MagicMock ())
269- def test_deploy_mxnet (sagemaker_session , rl_coach_version ):
270- rl = _rl_estimator (sagemaker_session , RLToolkit .COACH , rl_coach_version , RLFramework .MXNET ,
269+ def test_deploy_mxnet (sagemaker_session , rl_coach_mxnet_version ):
270+ rl = _rl_estimator (sagemaker_session , RLToolkit .COACH , rl_coach_mxnet_version , RLFramework .MXNET ,
271271 train_instance_type = 'ml.g2.2xlarge' )
272272 rl .fit ()
273273 predictor = rl .deploy (1 , CPU )
274274 assert isinstance (predictor , MXNetPredictor )
275275
276276
277277@patch ('sagemaker.utils.create_tar_file' , MagicMock ())
278- def test_deploy_tfs (sagemaker_session , rl_coach_version ):
279- rl = _rl_estimator (sagemaker_session , RLToolkit .COACH , rl_coach_version , RLFramework .TENSORFLOW ,
278+ def test_deploy_tfs (sagemaker_session , rl_coach_tf_version ):
279+ rl = _rl_estimator (sagemaker_session , RLToolkit .COACH , rl_coach_tf_version , RLFramework .TENSORFLOW ,
280280 train_instance_type = 'ml.g2.2xlarge' )
281281 rl .fit ()
282282 predictor = rl .deploy (1 , GPU )
@@ -312,25 +312,25 @@ def test_train_image_cpu_instances(sagemaker_session, rl_ray_version):
312312 framework .value )
313313
314314
315- def test_train_image_gpu_instances (sagemaker_session , rl_coach_version ):
315+ def test_train_image_gpu_instances (sagemaker_session , rl_coach_mxnet_version ):
316316 toolkit = RLToolkit .COACH
317317 framework = RLFramework .MXNET
318- rl = _rl_estimator (sagemaker_session , toolkit , rl_coach_version , framework ,
318+ rl = _rl_estimator (sagemaker_session , toolkit , rl_coach_mxnet_version , framework ,
319319 train_instance_type = 'ml.g2.2xlarge' )
320- assert rl .train_image () == _get_full_gpu_image_uri (toolkit .value , rl_coach_version ,
320+ assert rl .train_image () == _get_full_gpu_image_uri (toolkit .value , rl_coach_mxnet_version ,
321321 framework .value )
322322
323- rl = _rl_estimator (sagemaker_session , toolkit , rl_coach_version , framework ,
323+ rl = _rl_estimator (sagemaker_session , toolkit , rl_coach_mxnet_version , framework ,
324324 train_instance_type = 'ml.p2.2xlarge' )
325- assert rl .train_image () == _get_full_gpu_image_uri (toolkit .value , rl_coach_version ,
325+ assert rl .train_image () == _get_full_gpu_image_uri (toolkit .value , rl_coach_mxnet_version ,
326326 framework .value )
327327
328328
329- def test_attach (sagemaker_session , rl_coach_version ):
329+ def test_attach (sagemaker_session , rl_coach_mxnet_version ):
330330 training_image = '1.dkr.ecr.us-west-2.amazonaws.com/sagemaker-rl-{}:{}{}-cpu-py3' \
331- .format (RLFramework .MXNET .value , RLToolkit .COACH .value , rl_coach_version )
331+ .format (RLFramework .MXNET .value , RLToolkit .COACH .value , rl_coach_mxnet_version )
332332 supported_versions = TOOLKIT_FRAMEWORK_VERSION_MAP [RLToolkit .COACH .value ]
333- framework_version = supported_versions [rl_coach_version ][RLFramework .MXNET .value ]
333+ framework_version = supported_versions [rl_coach_mxnet_version ][RLFramework .MXNET .value ]
334334 returned_job_description = {'AlgorithmSpecification' : {'TrainingInputMode' : 'File' ,
335335 'TrainingImage' : training_image },
336336 'HyperParameters' :
@@ -361,7 +361,7 @@ def test_attach(sagemaker_session, rl_coach_version):
361361 assert estimator .framework == RLFramework .MXNET .value
362362 assert estimator .toolkit == RLToolkit .COACH .value
363363 assert estimator .framework_version == framework_version
364- assert estimator .toolkit_version == rl_coach_version
364+ assert estimator .toolkit_version == rl_coach_mxnet_version
365365 assert estimator .role == 'arn:aws:iam::366:role/SageMakerRole'
366366 assert estimator .train_instance_count == 1
367367 assert estimator .train_max_run == 24 * 60 * 60
0 commit comments