6969from tests .unit import DATA_DIR
7070from tests .unit .sagemaker .workflow .conftest import ROLE , BUCKET , IMAGE_URI , INSTANCE_TYPE
7171
72+ HF_INSTANCE_TYPE = "ml.p3.2xlarge"
7273DUMMY_S3_SCRIPT_PATH = "s3://dummy-s3/dummy_script.py"
7374LOCAL_SCRIPT_PATH = os .path .join (DATA_DIR , "workflow/abalone/preprocessing.py" )
7475SPARK_APP_JAR_PATH = os .path .join (
142143 pytorch_version = "1.7" ,
143144 role = ROLE ,
144145 instance_count = 1 ,
145- instance_type = "ml.p3.2xlarge" ,
146+ instance_type = HF_INSTANCE_TYPE ,
146147 ),
147148 {"code" : DUMMY_S3_SCRIPT_PATH },
148149 ),
@@ -446,8 +447,15 @@ def test_processing_step_with_framework_processor(
446447):
447448
448449 processor , run_inputs = framework_processor
450+ default_instance_type = (
451+ HF_INSTANCE_TYPE if type (processor ) is HuggingFaceProcessor else INSTANCE_TYPE
452+ )
453+ instance_type_param = ParameterString (
454+ name = "ProcessingInstanceType" , default_value = default_instance_type
455+ )
449456 processor .sagemaker_session = pipeline_session
450457 processor .role = ROLE
458+ processor .instance_type = instance_type_param
451459
452460 processor .volume_kms_key = "volume-kms-key"
453461 processor .network_config = network_config
@@ -465,6 +473,7 @@ def test_processing_step_with_framework_processor(
465473 name = "MyPipeline" ,
466474 steps = [step ],
467475 sagemaker_session = pipeline_session ,
476+ parameters = [instance_type_param ],
468477 )
469478
470479 step_args = get_step_args_helper (step_args , "Processing" )
@@ -475,6 +484,12 @@ def test_processing_step_with_framework_processor(
475484 step_args ["ProcessingOutputConfig" ]["Outputs" ][0 ]["S3Output" ]["S3Uri" ]
476485 == processing_output .destination
477486 )
487+ assert (
488+ type (step_args ["ProcessingResources" ]["ClusterConfig" ]["InstanceType" ]) is ParameterString
489+ )
490+ step_args ["ProcessingResources" ]["ClusterConfig" ]["InstanceType" ] = step_args [
491+ "ProcessingResources"
492+ ]["ClusterConfig" ]["InstanceType" ].expr
478493
479494 del step_args ["ProcessingInputs" ][0 ]["S3Input" ]["S3Uri" ]
480495 del step_def ["Arguments" ]["ProcessingInputs" ][0 ]["S3Input" ]["S3Uri" ]
0 commit comments