2424 ERR_STR_BOTH_OR_NONE_INSTANCEGROUPS_OR_INSTANCEFLEETS ,
2525 ERR_STR_WITH_BOTH_CLUSTER_ID_AND_CLUSTER_CFG ,
2626 ERR_STR_WITHOUT_CLUSTER_ID_AND_CLUSTER_CFG ,
27+ ERR_STR_WITH_EXEC_ROLE_ARN_AND_WITHOUT_CLUSTER_ID ,
2728)
2829from sagemaker .workflow .steps import CacheConfig
2930from sagemaker .workflow .pipeline import Pipeline , PipelineGraph
3031from sagemaker .workflow .parameters import ParameterString
3132from tests .unit .sagemaker .workflow .helpers import CustomStep , ordered
3233
3334
34- def test_emr_step_with_one_step_config (sagemaker_session ):
35+ @pytest .mark .parametrize ("execution_role_arn" , [None , "arn:aws:iam:000000000000:role/runtime-role" ])
36+ def test_emr_step_with_one_step_config (sagemaker_session , execution_role_arn ):
3537 emr_step_config = EMRStepConfig (
3638 jar = "s3:/script-runner/script-runner.jar" ,
3739 args = ["--arg_0" , "arg_0_value" ],
@@ -47,9 +49,11 @@ def test_emr_step_with_one_step_config(sagemaker_session):
4749 step_config = emr_step_config ,
4850 depends_on = ["TestStep" ],
4951 cache_config = CacheConfig (enable_caching = True , expire_after = "PT1H" ),
52+ execution_role_arn = execution_role_arn ,
5053 )
5154 emr_step .add_depends_on (["SecondTestStep" ])
52- assert emr_step .to_request () == {
55+
56+ expected_request = {
5357 "Name" : "MyEMRStep" ,
5458 "Type" : "EMR" ,
5559 "Arguments" : {
@@ -72,7 +76,16 @@ def test_emr_step_with_one_step_config(sagemaker_session):
7276 "CacheConfig" : {"Enabled" : True , "ExpireAfter" : "PT1H" },
7377 }
7478
79+ if execution_role_arn is not None :
80+ expected_request ["Arguments" ]["ExecutionRoleArn" ] = execution_role_arn
81+
82+ assert emr_step .to_request () == expected_request
7583 assert emr_step .properties .ClusterId == "MyClusterID"
84+ assert (
85+ emr_step .properties .ExecutionRoleArn == execution_role_arn
86+ if execution_role_arn is not None
87+ else True
88+ )
7689 assert emr_step .properties .ActionOnFailure .expr == {"Get" : "Steps.MyEMRStep.ActionOnFailure" }
7790 assert emr_step .properties .Config .Args .expr == {"Get" : "Steps.MyEMRStep.Config.Args" }
7891 assert emr_step .properties .Config .Jar .expr == {"Get" : "Steps.MyEMRStep.Config.Jar" }
@@ -239,6 +252,27 @@ def test_emr_step_throws_exception_when_both_cluster_id_and_cluster_config_are_n
239252 assert actual_error_msg == expected_error_msg
240253
241254
255+ def test_emr_step_throws_exception_when_both_execution_role_arn_and_cluster_config_are_present ():
256+ with pytest .raises (ValueError ) as exceptionInfo :
257+ EMRStep (
258+ name = g_emr_step_name ,
259+ display_name = "MyEMRStep" ,
260+ description = "MyEMRStepDescription" ,
261+ step_config = g_emr_step_config ,
262+ cluster_id = None ,
263+ cluster_config = g_cluster_config ,
264+ depends_on = ["TestStep" ],
265+ cache_config = CacheConfig (enable_caching = True , expire_after = "PT1H" ),
266+ execution_role_arn = "arn:aws:iam:000000000000:role/some-role" ,
267+ )
268+ expected_error_msg = ERR_STR_WITH_EXEC_ROLE_ARN_AND_WITHOUT_CLUSTER_ID .format (
269+ step_name = g_emr_step_name
270+ )
271+ actual_error_msg = exceptionInfo .value .args [0 ]
272+
273+ assert actual_error_msg == expected_error_msg
274+
275+
242276def test_emr_step_with_valid_cluster_config ():
243277 emr_step = EMRStep (
244278 name = g_emr_step_name ,
0 commit comments