@@ -96,6 +96,64 @@ def test_transformer_fails_without_model():
9696 )
9797
9898
99+ def test_transformer_init (sagemaker_session ):
100+ transformer = Transformer (
101+ MODEL_NAME , INSTANCE_COUNT , INSTANCE_TYPE , sagemaker_session = sagemaker_session
102+ )
103+
104+ assert transformer .model_name == MODEL_NAME
105+ assert transformer .instance_count == INSTANCE_COUNT
106+ assert transformer .instance_type == INSTANCE_TYPE
107+ assert transformer .sagemaker_session == sagemaker_session
108+
109+ assert transformer ._current_job_name is None
110+ assert transformer .latest_transform_job is None
111+ assert transformer ._reset_output_path is False
112+
113+
114+ def test_transformer_init_optional_params (sagemaker_session ):
115+ strategy = "MultiRecord"
116+ assemble_with = "Line"
117+ accept = "text/csv"
118+ max_concurrent_transforms = 100
119+ max_payload = 100
120+ tags = {"Key" : "foo" , "Value" : "bar" }
121+ env = {"FOO" : "BAR" }
122+
123+ transformer = Transformer (
124+ MODEL_NAME ,
125+ INSTANCE_COUNT ,
126+ INSTANCE_TYPE ,
127+ strategy = strategy ,
128+ assemble_with = assemble_with ,
129+ output_path = OUTPUT_PATH ,
130+ output_kms_key = KMS_KEY_ID ,
131+ accept = accept ,
132+ max_concurrent_transforms = max_concurrent_transforms ,
133+ max_payload = max_payload ,
134+ tags = tags ,
135+ env = env ,
136+ base_transform_job_name = JOB_NAME ,
137+ sagemaker_session = sagemaker_session ,
138+ volume_kms_key = KMS_KEY_ID ,
139+ )
140+
141+ assert transformer .model_name == MODEL_NAME
142+ assert transformer .strategy == strategy
143+ assert transformer .env == env
144+ assert transformer .output_path == OUTPUT_PATH
145+ assert transformer .output_kms_key == KMS_KEY_ID
146+ assert transformer .accept == accept
147+ assert transformer .assemble_with == assemble_with
148+ assert transformer .instance_count == INSTANCE_COUNT
149+ assert transformer .instance_type == INSTANCE_TYPE
150+ assert transformer .volume_kms_key == KMS_KEY_ID
151+ assert transformer .max_concurrent_transforms == max_concurrent_transforms
152+ assert transformer .max_payload == max_payload
153+ assert transformer .tags == tags
154+ assert transformer .base_transform_job_name == JOB_NAME
155+
156+
99157@patch ("sagemaker.transformer._TransformJob.start_new" )
100158def test_transform_with_all_params (start_new_job , transformer ):
101159 content_type = "text/csv"
@@ -333,29 +391,78 @@ def test_prepare_init_params_from_job_description_all_keys(transformer):
333391
334392
335393# _TransformJob tests
336- def test_start_new (transformer , sagemaker_session ):
394+ @patch ("sagemaker.transformer._TransformJob._load_config" )
395+ @patch ("sagemaker.transformer._TransformJob._prepare_data_processing" )
396+ def test_start_new (prepare_data_processing , load_config , sagemaker_session ):
397+ input_config = "input"
398+ output_config = "output"
399+ resource_config = "resource"
400+ load_config .return_value = {
401+ "input_config" : input_config ,
402+ "output_config" : output_config ,
403+ "resource_config" : resource_config ,
404+ }
405+
406+ strategy = "MultiRecord"
407+ max_concurrent_transforms = 100
408+ max_payload = 100
409+ tags = {"Key" : "foo" , "Value" : "bar" }
410+ env = {"FOO" : "BAR" }
411+
412+ transformer = Transformer (
413+ MODEL_NAME ,
414+ INSTANCE_COUNT ,
415+ INSTANCE_TYPE ,
416+ strategy = strategy ,
417+ output_path = OUTPUT_PATH ,
418+ max_concurrent_transforms = max_concurrent_transforms ,
419+ max_payload = max_payload ,
420+ tags = tags ,
421+ env = env ,
422+ sagemaker_session = sagemaker_session ,
423+ )
337424 transformer ._current_job_name = JOB_NAME
338425
339- job = _TransformJob (sagemaker_session , JOB_NAME )
340- started_job = job .start_new (
341- transformer ,
342- DATA ,
343- S3_DATA_TYPE ,
344- None ,
345- None ,
346- None ,
347- None ,
348- None ,
349- None ,
350- {"ExperimentName" : "exp" },
426+ content_type = "text/csv"
427+ compression_type = "Gzip"
428+ split_type = "Line"
429+ io_filter = "$"
430+ join_source = "Input"
431+ job = _TransformJob .start_new (
432+ transformer = transformer ,
433+ data = DATA ,
434+ data_type = S3_DATA_TYPE ,
435+ content_type = content_type ,
436+ compression_type = compression_type ,
437+ split_type = split_type ,
438+ input_filter = io_filter ,
439+ output_filter = io_filter ,
440+ join_source = join_source ,
441+ experiment_config = {"ExperimentName" : "exp" },
351442 )
352443
353- assert started_job .sagemaker_session == sagemaker_session
354- sagemaker_session . transform . assert_called_once ()
444+ assert job .sagemaker_session == sagemaker_session
445+ assert job . job_name == JOB_NAME
355446
356- called_args = sagemaker_session .transform .call_args
447+ load_config .assert_called_with (
448+ DATA , S3_DATA_TYPE , content_type , compression_type , split_type , transformer
449+ )
450+ prepare_data_processing .assert_called_with (io_filter , io_filter , join_source )
357451
358- assert called_args [1 ]["experiment_config" ] == {"ExperimentName" : "exp" }
452+ sagemaker_session .transform .assert_called_with (
453+ job_name = JOB_NAME ,
454+ model_name = MODEL_NAME ,
455+ strategy = strategy ,
456+ max_concurrent_transforms = max_concurrent_transforms ,
457+ max_payload = max_payload ,
458+ env = env ,
459+ input_config = input_config ,
460+ output_config = output_config ,
461+ resource_config = resource_config ,
462+ experiment_config = {"ExperimentName" : "exp" },
463+ tags = tags ,
464+ data_processing = prepare_data_processing .return_value ,
465+ )
359466
360467
361468def test_load_config (transformer ):
0 commit comments