1414
1515import os
1616
17+ import boto3
1718import pytest
1819from botocore .exceptions import WaiterError
1920
21+ from sagemaker .workflow import ParameterString
2022from sagemaker .workflow .automl_step import AutoMLStep
2123from sagemaker .automl .automl import AutoML , AutoMLInput
2224
23- from sagemaker import utils , get_execution_role
24- from sagemaker .utils import unique_name_from_base
25+ from sagemaker import utils , get_execution_role , ModelMetrics , MetricsSource
2526from sagemaker .workflow .model_step import ModelStep
2627from sagemaker .workflow .pipeline import Pipeline
2728
@@ -50,10 +51,8 @@ def test_automl_step(pipeline_session, role, pipeline_name):
5051 role = role ,
5152 target_attribute_name = TARGET_ATTRIBUTE_NAME ,
5253 sagemaker_session = pipeline_session ,
53- max_candidates = 1 ,
5454 mode = MODE ,
5555 )
56- job_name = unique_name_from_base ("auto-ml" , max_length = 32 )
5756 s3_input_training = pipeline_session .upload_data (
5857 path = TRAINING_DATA , key_prefix = PREFIX + "/input"
5958 )
@@ -72,27 +71,56 @@ def test_automl_step(pipeline_session, role, pipeline_name):
7271 )
7372 inputs = [input_training , input_validation ]
7473
75- step_args = auto_ml .fit (inputs = inputs , job_name = job_name )
74+ step_args = auto_ml .fit (inputs = inputs )
7675
7776 automl_step = AutoMLStep (
7877 name = "MyAutoMLStep" ,
7978 step_args = step_args ,
8079 )
8180
8281 automl_model = automl_step .get_best_auto_ml_model (sagemaker_session = pipeline_session , role = role )
83-
8482 step_args_create_model = automl_model .create (
8583 instance_type = "c4.4xlarge" ,
8684 )
87-
8885 automl_model_step = ModelStep (
8986 name = "MyAutoMLModelStep" ,
9087 step_args = step_args_create_model ,
9188 )
9289
90+ model_package_group_name = ParameterString (
91+ name = "ModelPackageName" , default_value = "AutoMlModelPackageGroup"
92+ )
93+ model_approval_status = ParameterString (name = "ModelApprovalStatus" , default_value = "Approved" )
94+ model_metrics = ModelMetrics (
95+ model_statistics = MetricsSource (
96+ s3_uri = automl_step .properties .BestCandidateProperties .ModelInsightsJsonReportPath ,
97+ content_type = "application/json" ,
98+ ),
99+ explainability = MetricsSource (
100+ s3_uri = automl_step .properties .BestCandidateProperties .ExplainabilityJsonReportPath ,
101+ content_type = "application/json" ,
102+ ),
103+ )
104+ step_args_register_model = automl_model .register (
105+ content_types = ["text/csv" ],
106+ response_types = ["text/csv" ],
107+ inference_instances = ["ml.m5.xlarge" ],
108+ transform_instances = ["ml.m5.xlarge" ],
109+ model_package_group_name = model_package_group_name ,
110+ approval_status = model_approval_status ,
111+ model_metrics = model_metrics ,
112+ )
113+ register_model_step = ModelStep (
114+ name = "ModelRegistrationStep" , step_args = step_args_register_model
115+ )
116+
93117 pipeline = Pipeline (
94118 name = pipeline_name ,
95- steps = [automl_step , automl_model_step ],
119+ parameters = [
120+ model_approval_status ,
121+ model_package_group_name ,
122+ ],
123+ steps = [automl_step , automl_model_step , register_model_step ],
96124 sagemaker_session = pipeline_session ,
97125 )
98126
@@ -114,9 +142,20 @@ def test_automl_step(pipeline_session, role, pipeline_name):
114142 assert step ["Metadata" ]["AutoMLJob" ]["Arn" ] is not None
115143
116144 assert has_automl_job
117- assert len (execution_steps ) == 2
145+ assert len (execution_steps ) == 3
118146 finally :
119147 try :
148+ sagemaker_client = boto3 .client ("sagemaker" )
149+ for model_package in sagemaker_client .list_model_packages (
150+ ModelPackageGroupName = "AutoMlModelPackageGroup"
151+ )["ModelPackageSummaryList" ]:
152+ sagemaker_client .delete_model_package (
153+ ModelPackageName = model_package ["ModelPackageArn" ]
154+ )
155+ sagemaker_client .delete_model_package_group (
156+ ModelPackageGroupName = "AutoMlModelPackageGroup"
157+ )
158+
120159 pipeline .delete ()
121160 except Exception :
122161 pass
0 commit comments