@@ -83,22 +83,26 @@ def fixture_sagemaker_session():
8383 return session
8484
8585
86- def _get_full_gpu_image_uri (version , instance_type , training_compiler_config ):
86+ def _get_full_gpu_image_uri (version , instance_type , training_compiler_config , py_version ):
8787 return image_uris .retrieve (
8888 "pytorch-training-compiler" ,
8989 REGION ,
9090 version = version ,
91- py_version = "py38" ,
91+ py_version = py_version ,
9292 instance_type = instance_type ,
9393 image_scope = "training" ,
9494 container_version = None ,
9595 training_compiler_config = training_compiler_config ,
9696 )
9797
9898
99- def _create_train_job (version , instance_type , training_compiler_config , instance_count = 1 ):
99+ def _create_train_job (
100+ version , instance_type , training_compiler_config , py_version , instance_count = 1
101+ ):
100102 return {
101- "image_uri" : _get_full_gpu_image_uri (version , instance_type , training_compiler_config ),
103+ "image_uri" : _get_full_gpu_image_uri (
104+ version , instance_type , training_compiler_config , py_version
105+ ),
102106 "input_mode" : "File" ,
103107 "input_config" : [
104108 {
@@ -303,15 +307,20 @@ def test_unsupported_distribution(
303307@patch ("time.time" , return_value = TIME )
304308@pytest .mark .parametrize ("instance_class" , SUPPORTED_GPU_INSTANCE_CLASSES )
305309def test_pytorchxla_distribution (
306- time , name_from_base , sagemaker_session , pytorch_training_compiler_version , instance_class
310+ time ,
311+ name_from_base ,
312+ sagemaker_session ,
313+ pytorch_training_compiler_version ,
314+ instance_class ,
315+ pytorch_training_compiler_py_version ,
307316):
308317 if Version (pytorch_training_compiler_version ) < Version ("1.12" ):
309318 pytest .skip ("This test is intended for PyTorch 1.12 and above" )
310319 compiler_config = TrainingCompilerConfig ()
311320 instance_type = f"ml.{ instance_class } .xlarge"
312321
313322 pt = PyTorch (
314- py_version = "py38" ,
323+ py_version = pytorch_training_compiler_py_version ,
315324 entry_point = SCRIPT_PATH ,
316325 role = ROLE ,
317326 sagemaker_session = sagemaker_session ,
@@ -333,7 +342,11 @@ def test_pytorchxla_distribution(
333342 assert boto_call_names == ["resource" ]
334343
335344 expected_train_args = _create_train_job (
336- pytorch_training_compiler_version , instance_type , compiler_config , instance_count = 2
345+ pytorch_training_compiler_version ,
346+ instance_type ,
347+ compiler_config ,
348+ pytorch_training_compiler_py_version ,
349+ instance_count = 2 ,
337350 )
338351 expected_train_args ["input_config" ][0 ]["DataSource" ]["S3DataSource" ]["S3Uri" ] = inputs
339352 expected_train_args ["enable_sagemaker_metrics" ] = False
@@ -357,13 +370,17 @@ def test_pytorchxla_distribution(
357370@patch ("time.time" , return_value = TIME )
358371@pytest .mark .parametrize ("instance_class" , SUPPORTED_GPU_INSTANCE_CLASSES )
359372def test_default_compiler_config (
360- time , name_from_base , sagemaker_session , pytorch_training_compiler_version , instance_class
373+ time ,
374+ name_from_base ,
375+ sagemaker_session ,
376+ pytorch_training_compiler_version ,
377+ instance_class ,
378+ pytorch_training_compiler_py_version ,
361379):
362380 compiler_config = TrainingCompilerConfig ()
363381 instance_type = f"ml.{ instance_class } .xlarge"
364-
365382 pt = PyTorch (
366- py_version = "py38" ,
383+ py_version = pytorch_training_compiler_py_version ,
367384 entry_point = SCRIPT_PATH ,
368385 role = ROLE ,
369386 sagemaker_session = sagemaker_session ,
@@ -384,7 +401,10 @@ def test_default_compiler_config(
384401 assert boto_call_names == ["resource" ]
385402
386403 expected_train_args = _create_train_job (
387- pytorch_training_compiler_version , instance_type , compiler_config
404+ pytorch_training_compiler_version ,
405+ instance_type ,
406+ compiler_config ,
407+ pytorch_training_compiler_py_version ,
388408 )
389409 expected_train_args ["input_config" ][0 ]["DataSource" ]["S3DataSource" ]["S3Uri" ] = inputs
390410 expected_train_args ["enable_sagemaker_metrics" ] = False
@@ -406,12 +426,16 @@ def test_default_compiler_config(
406426@patch ("sagemaker.estimator.name_from_base" , return_value = JOB_NAME )
407427@patch ("time.time" , return_value = TIME )
408428def test_debug_compiler_config (
409- time , name_from_base , sagemaker_session , pytorch_training_compiler_version
429+ time ,
430+ name_from_base ,
431+ sagemaker_session ,
432+ pytorch_training_compiler_version ,
433+ pytorch_training_compiler_py_version ,
410434):
411435 compiler_config = TrainingCompilerConfig (debug = True )
412436
413437 pt = PyTorch (
414- py_version = "py38" ,
438+ py_version = pytorch_training_compiler_py_version ,
415439 entry_point = SCRIPT_PATH ,
416440 role = ROLE ,
417441 sagemaker_session = sagemaker_session ,
@@ -432,7 +456,10 @@ def test_debug_compiler_config(
432456 assert boto_call_names == ["resource" ]
433457
434458 expected_train_args = _create_train_job (
435- pytorch_training_compiler_version , INSTANCE_TYPE , compiler_config
459+ pytorch_training_compiler_version ,
460+ INSTANCE_TYPE ,
461+ compiler_config ,
462+ pytorch_training_compiler_py_version ,
436463 )
437464 expected_train_args ["input_config" ][0 ]["DataSource" ]["S3DataSource" ]["S3Uri" ] = inputs
438465 expected_train_args ["enable_sagemaker_metrics" ] = False
@@ -454,12 +481,16 @@ def test_debug_compiler_config(
454481@patch ("sagemaker.estimator.name_from_base" , return_value = JOB_NAME )
455482@patch ("time.time" , return_value = TIME )
456483def test_disable_compiler_config (
457- time , name_from_base , sagemaker_session , pytorch_training_compiler_version
484+ time ,
485+ name_from_base ,
486+ sagemaker_session ,
487+ pytorch_training_compiler_version ,
488+ pytorch_training_compiler_py_version ,
458489):
459490 compiler_config = TrainingCompilerConfig (enabled = False )
460491
461492 pt = PyTorch (
462- py_version = "py38" ,
493+ py_version = pytorch_training_compiler_py_version ,
463494 entry_point = SCRIPT_PATH ,
464495 role = ROLE ,
465496 sagemaker_session = sagemaker_session ,
@@ -480,7 +511,10 @@ def test_disable_compiler_config(
480511 assert boto_call_names == ["resource" ]
481512
482513 expected_train_args = _create_train_job (
483- pytorch_training_compiler_version , INSTANCE_TYPE , compiler_config
514+ pytorch_training_compiler_version ,
515+ INSTANCE_TYPE ,
516+ compiler_config ,
517+ pytorch_training_compiler_py_version ,
484518 )
485519 expected_train_args ["input_config" ][0 ]["DataSource" ]["S3DataSource" ]["S3Uri" ] = inputs
486520 expected_train_args ["enable_sagemaker_metrics" ] = False
@@ -508,7 +542,10 @@ def test_attach(sagemaker_session, compiler_enabled, debug_enabled):
508542 "py38-cu113-ubuntu20.04"
509543 )
510544 returned_job_description = {
511- "AlgorithmSpecification" : {"TrainingInputMode" : "File" , "TrainingImage" : training_image },
545+ "AlgorithmSpecification" : {
546+ "TrainingInputMode" : "File" ,
547+ "TrainingImage" : training_image ,
548+ },
512549 "HyperParameters" : {
513550 "sagemaker_submit_directory" : '"s3://some/sourcedir.tar.gz"' ,
514551 "sagemaker_program" : '"iris-dnn-classifier.py"' ,
@@ -530,7 +567,10 @@ def test_attach(sagemaker_session, compiler_enabled, debug_enabled):
530567 "TrainingJobName" : "trcomp" ,
531568 "TrainingJobStatus" : "Completed" ,
532569 "TrainingJobArn" : "arn:aws:sagemaker:us-west-2:336:training-job/trcomp" ,
533- "OutputDataConfig" : {"KmsKeyId" : "" , "S3OutputPath" : "s3://place/output/trcomp" },
570+ "OutputDataConfig" : {
571+ "KmsKeyId" : "" ,
572+ "S3OutputPath" : "s3://place/output/trcomp" ,
573+ },
534574 "TrainingJobOutput" : {"S3TrainingJobOutput" : "s3://here/output.tar.gz" },
535575 }
536576 sagemaker_session .sagemaker_client .describe_training_job = Mock (
0 commit comments