2323
2424from tests .integ import test_local_mode
2525from tests .unit import SAGEMAKER_CONFIG_TRANSFORM_JOB
26+ from sagemaker .model_monitor import DatasetFormat
27+ from sagemaker .workflow .quality_check_step import (
28+ ModelQualityCheckConfig ,
29+ )
30+ from sagemaker .workflow .check_job_config import CheckJobConfig
31+
32+ _CHECK_JOB_PREFIX = "CheckJobPrefix"
2633
2734ROLE = "DummyRole"
2835REGION = "us-west-2"
4956 "base_transform_job_name" : JOB_NAME ,
5057}
5158
59+ PROCESS_REQUEST_ARGS = {
60+ "inputs" : "processing_inputs" ,
61+ "output_config" : "output_config" ,
62+ "job_name" : "job_name" ,
63+ "resources" : "resource_config" ,
64+ "stopping_condition" : {"MaxRuntimeInSeconds" : 3600 },
65+ "app_specification" : "app_specification" ,
66+ "experiment_config" : {"ExperimentName" : "AnExperiment" },
67+ }
68+
5269MODEL_DESC_PRIMARY_CONTAINER = {"PrimaryContainer" : {"Image" : IMAGE_URI }}
5370
5471MODEL_DESC_CONTAINERS_ONLY = {"Containers" : [{"Image" : IMAGE_URI }]}
@@ -72,7 +89,7 @@ def mock_create_tar_file():
7289
7390@pytest .fixture ()
7491def sagemaker_session ():
75- boto_mock = Mock (name = "boto_session" )
92+ boto_mock = Mock (name = "boto_session" , region_name = REGION )
7693 session = Mock (
7794 name = "sagemaker_session" ,
7895 boto_session = boto_mock ,
@@ -764,6 +781,48 @@ def test_stop_transform_job(sagemaker_session, transformer):
764781 sagemaker_session .stop_transform_job .assert_called_once_with (name = JOB_NAME )
765782
766783
784+ @patch ("sagemaker.transformer.Transformer._retrieve_image_uri" , return_value = IMAGE_URI )
785+ @patch ("sagemaker.workflow.pipeline.Pipeline.upsert" , return_value = {})
786+ @patch ("sagemaker.workflow.pipeline.Pipeline.start" , return_value = Mock ())
787+ def test_transform_with_monitoring_create_and_starts_pipeline (
788+ pipeline_start , upsert , image_uri , sagemaker_session , transformer
789+ ):
790+
791+ config = CheckJobConfig (
792+ role = ROLE ,
793+ instance_count = 1 ,
794+ instance_type = "ml.m5.xlarge" ,
795+ volume_size_in_gb = 60 ,
796+ max_runtime_in_seconds = 1800 ,
797+ sagemaker_session = sagemaker_session ,
798+ base_job_name = _CHECK_JOB_PREFIX ,
799+ )
800+
801+ quality_check_config = ModelQualityCheckConfig (
802+ baseline_dataset = "s3://baseline_dataset_s3_url" ,
803+ dataset_format = DatasetFormat .csv (header = True ),
804+ problem_type = "BinaryClassification" ,
805+ inference_attribute = "quality_cfg_attr_value" ,
806+ probability_attribute = "quality_cfg_attr_value" ,
807+ ground_truth_attribute = "quality_cfg_attr_value" ,
808+ probability_threshold_attribute = "quality_cfg_attr_value" ,
809+ post_analytics_processor_script = "s3://my_bucket/data_quality/postprocessor.py" ,
810+ output_s3_uri = "s3://output_s3_uri" ,
811+ )
812+
813+ transformer .transform_with_monitoring (
814+ monitoring_config = quality_check_config ,
815+ monitoring_resource_config = config ,
816+ data = DATA ,
817+ content_type = "text/libsvm" ,
818+ supplied_baseline_constraints = "supplied_baseline_constraints" ,
819+ role = ROLE ,
820+ )
821+
822+ upsert .assert_called_once ()
823+ pipeline_start .assert_called_once ()
824+
825+
767826def test_stop_transform_job_no_transform_job (transformer ):
768827 with pytest .raises (ValueError ) as e :
769828 transformer .stop_transform_job ()
0 commit comments