141141 "MaxParallelOfTests" : 5 ,
142142}
143143
# Container definition that the right-size tests expect ``create_model`` to
# receive as its ``primary_container`` argument.
IR_SAMPLE_PRIMARY_CONTAINER = {
    "Image": "model-image-for-ir",
    "Environment": {},
    "ModelDataUrl": "s3://bucket/model.tar.gz",
}
149+
# Production variants the deploy tests expect to be sent to
# ``endpoint_from_production_variants`` when ``deploy()`` is called without
# an instance type or instance count override.
IR_PRODUCTION_VARIANTS = [
    {
        "ModelName": "model-name-for-ir",
        "VariantName": "AllTraffic",
        "InitialVariantWeight": 1,
        "InitialInstanceCount": 1,
        "InstanceType": "ml.m5.xlarge",
    }
]
159+
# Production variants the deploy tests expect when the caller overrides both
# ``instance_type`` ("ml.c5.2xlarge") and ``initial_instance_count`` (5).
IR_OVERRIDDEN_PRODUCTION_VARIANTS = [
    {
        "ModelName": "model-name-for-ir",
        "VariantName": "AllTraffic",
        "InitialVariantWeight": 1,
        "InitialInstanceCount": 5,
        "InstanceType": "ml.c5.2xlarge",
    }
]
169+
# Production variants the deploy tests expect when a serverless inference
# config is supplied: no instance type/count, only a ``ServerlessConfig``.
# NOTE(review): 2048 MB / concurrency 5 presumably mirror the defaults of
# ``ServerlessInferenceConfig()`` -- confirm against that class.
IR_SERVERLESS_PRODUCTION_VARIANTS = [
    {
        "ModelName": "model-name-for-ir",
        "VariantName": "AllTraffic",
        "InitialVariantWeight": 1,
        "ServerlessConfig": {"MemorySizeInMB": 2048, "MaxConcurrency": 5},
    }
]
178+
144179
145180@pytest .fixture ()
146181def sagemaker_session ():
@@ -185,17 +220,17 @@ def test_right_size_default_with_model_name_successful(sagemaker_session, model)
185220 framework = IR_SAMPLE_FRAMEWORK ,
186221 )
187222
188- assert sagemaker_session .create_model .called_with (
223+ sagemaker_session .create_model .assert_called_with (
189224 name = ANY ,
190225 role = IR_ROLE_ARN ,
191226 container_defs = None ,
192- primary_container = {} ,
227+ primary_container = IR_SAMPLE_PRIMARY_CONTAINER ,
193228 vpc_config = None ,
194229 enable_network_isolation = False ,
195230 )
196231
197232 # assert that the create api has been called with default parameters with model name
198- assert sagemaker_session .create_inference_recommendations_job .called_with (
233+ sagemaker_session .create_inference_recommendations_job .assert_called_with (
199234 role = IR_ROLE_ARN ,
200235 job_name = IR_JOB_NAME ,
201236 job_type = "Default" ,
@@ -213,7 +248,9 @@ def test_right_size_default_with_model_name_successful(sagemaker_session, model)
213248 resource_limit = None ,
214249 )
215250
216- assert sagemaker_session .wait_for_inference_recommendations_job .called_with (IR_JOB_NAME )
251+ sagemaker_session .wait_for_inference_recommendations_job .assert_called_with (
252+ IR_JOB_NAME , log_level = "Verbose"
253+ )
217254
218255 # confirm that the IR instance attributes have been set
219256 assert (
@@ -232,6 +269,7 @@ def test_right_size_default_with_model_name_successful(sagemaker_session, model)
232269@patch ("uuid.uuid4" , MagicMock (return_value = "sample-unique-uuid" ))
233270def test_right_size_advanced_list_instances_model_name_successful (sagemaker_session , model ):
234271 inference_recommender_model = model .right_size (
272+ job_name = IR_JOB_NAME ,
235273 sample_payload_url = IR_SAMPLE_PAYLOAD_URL ,
236274 supported_content_types = IR_SUPPORTED_CONTENT_TYPES ,
237275 framework = "SAGEMAKER-SCIKIT-LEARN" ,
@@ -246,7 +284,7 @@ def test_right_size_advanced_list_instances_model_name_successful(sagemaker_sess
246284 )
247285
248286 # assert that the create api has been called with advanced parameters
249- assert sagemaker_session .create_inference_recommendations_job .called_with (
287+ sagemaker_session .create_inference_recommendations_job .assert_called_with (
250288 role = IR_ROLE_ARN ,
251289 job_name = IR_JOB_NAME ,
252290 job_type = "Advanced" ,
@@ -256,15 +294,17 @@ def test_right_size_advanced_list_instances_model_name_successful(sagemaker_sess
256294 framework = IR_SAMPLE_FRAMEWORK ,
257295 framework_version = None ,
258296 sample_payload_url = IR_SAMPLE_PAYLOAD_URL ,
259- supported_content_types = IR_SUPPORTED_CONTENT_TYPES ,
260- supported_instance_types = [ IR_SAMPLE_INSTANCE_TYPE ] ,
297+ supported_content_types = [ "text/csv" ] ,
298+ supported_instance_types = None ,
261299 endpoint_configurations = IR_SAMPLE_ENDPOINT_CONFIG ,
262300 traffic_pattern = IR_SAMPLE_TRAFFIC_PATTERN ,
263301 stopping_conditions = IR_SAMPLE_STOPPING_CONDITIONS ,
264302 resource_limit = IR_SAMPLE_RESOURCE_LIMIT ,
265303 )
266304
267- assert sagemaker_session .wait_for_inference_recommendations_job .called_with (IR_JOB_NAME )
305+ sagemaker_session .wait_for_inference_recommendations_job .assert_called_with (
306+ IR_JOB_NAME , log_level = "Verbose"
307+ )
268308
269309 # confirm that the IR instance attributes have been set
270310 assert (
@@ -283,6 +323,7 @@ def test_right_size_advanced_list_instances_model_name_successful(sagemaker_sess
283323@patch ("uuid.uuid4" , MagicMock (return_value = "sample-unique-uuid" ))
284324def test_right_size_advanced_single_instances_model_name_successful (sagemaker_session , model ):
285325 model .right_size (
326+ job_name = IR_JOB_NAME ,
286327 sample_payload_url = IR_SAMPLE_PAYLOAD_URL ,
287328 supported_content_types = IR_SUPPORTED_CONTENT_TYPES ,
288329 framework = "SAGEMAKER-SCIKIT-LEARN" ,
@@ -297,7 +338,7 @@ def test_right_size_advanced_single_instances_model_name_successful(sagemaker_se
297338 )
298339
299340 # assert that the create api has been called with advanced parameters
300- assert sagemaker_session .create_inference_recommendations_job .called_with (
341+ sagemaker_session .create_inference_recommendations_job .assert_called_with (
301342 role = IR_ROLE_ARN ,
302343 job_name = IR_JOB_NAME ,
303344 job_type = "Advanced" ,
@@ -308,7 +349,7 @@ def test_right_size_advanced_single_instances_model_name_successful(sagemaker_se
308349 framework_version = None ,
309350 sample_payload_url = IR_SAMPLE_PAYLOAD_URL ,
310351 supported_content_types = IR_SUPPORTED_CONTENT_TYPES ,
311- supported_instance_types = [ IR_SAMPLE_INSTANCE_TYPE ] ,
352+ supported_instance_types = None ,
312353 endpoint_configurations = IR_SAMPLE_ENDPOINT_CONFIG ,
313354 traffic_pattern = IR_SAMPLE_TRAFFIC_PATTERN ,
314355 stopping_conditions = IR_SAMPLE_STOPPING_CONDITIONS ,
@@ -326,7 +367,7 @@ def test_right_size_default_with_model_package_successful(sagemaker_session, mod
326367 )
327368
328369 # assert that the create api has been called with default parameters
329- assert sagemaker_session .create_inference_recommendations_job .called_with (
370+ sagemaker_session .create_inference_recommendations_job .assert_called_with (
330371 role = IR_ROLE_ARN ,
331372 job_name = IR_JOB_NAME ,
332373 job_type = "Default" ,
@@ -344,7 +385,9 @@ def test_right_size_default_with_model_package_successful(sagemaker_session, mod
344385 resource_limit = None ,
345386 )
346387
347- assert sagemaker_session .wait_for_inference_recommendations_job .called_with (IR_JOB_NAME )
388+ sagemaker_session .wait_for_inference_recommendations_job .assert_called_with (
389+ IR_JOB_NAME , log_level = "Verbose"
390+ )
348391
349392 # confirm that the IR instance attributes have been set
350393 assert (
@@ -364,6 +407,7 @@ def test_right_size_advanced_list_instances_model_package_successful(
364407 sagemaker_session , model_package
365408):
366409 inference_recommender_model_pkg = model_package .right_size (
410+ job_name = IR_JOB_NAME ,
367411 sample_payload_url = IR_SAMPLE_PAYLOAD_URL ,
368412 supported_content_types = IR_SUPPORTED_CONTENT_TYPES ,
369413 framework = "SAGEMAKER-SCIKIT-LEARN" ,
@@ -378,24 +422,27 @@ def test_right_size_advanced_list_instances_model_package_successful(
378422 )
379423
380424 # assert that the create api has been called with advanced parameters
381- assert sagemaker_session .create_inference_recommendations_job .called_with (
425+ sagemaker_session .create_inference_recommendations_job .assert_called_with (
382426 role = IR_ROLE_ARN ,
383427 job_name = IR_JOB_NAME ,
384428 job_type = "Advanced" ,
385429 job_duration_in_seconds = 7200 ,
430+ model_name = None ,
386431 model_package_version_arn = model_package .model_package_arn ,
387432 framework = IR_SAMPLE_FRAMEWORK ,
388433 framework_version = None ,
389434 sample_payload_url = IR_SAMPLE_PAYLOAD_URL ,
390435 supported_content_types = IR_SUPPORTED_CONTENT_TYPES ,
391- supported_instance_types = [ IR_SAMPLE_INSTANCE_TYPE ] ,
436+ supported_instance_types = None ,
392437 endpoint_configurations = IR_SAMPLE_ENDPOINT_CONFIG ,
393438 traffic_pattern = IR_SAMPLE_TRAFFIC_PATTERN ,
394439 stopping_conditions = IR_SAMPLE_STOPPING_CONDITIONS ,
395440 resource_limit = IR_SAMPLE_RESOURCE_LIMIT ,
396441 )
397442
398- assert sagemaker_session .wait_for_inference_recommendations_job .called_with (IR_JOB_NAME )
443+ sagemaker_session .wait_for_inference_recommendations_job .assert_called_with (
444+ IR_JOB_NAME , log_level = "Verbose"
445+ )
399446
400447 # confirm that the IR instance attributes have been set
401448 assert (
@@ -415,6 +462,7 @@ def test_right_size_advanced_single_instances_model_package_successful(
415462 sagemaker_session , model_package
416463):
417464 model_package .right_size (
465+ job_name = IR_JOB_NAME ,
418466 sample_payload_url = IR_SAMPLE_PAYLOAD_URL ,
419467 supported_content_types = IR_SUPPORTED_CONTENT_TYPES ,
420468 framework = "SAGEMAKER-SCIKIT-LEARN" ,
@@ -429,17 +477,18 @@ def test_right_size_advanced_single_instances_model_package_successful(
429477 )
430478
431479 # assert that the create api has been called with advanced parameters
432- assert sagemaker_session .create_inference_recommendations_job .called_with (
480+ sagemaker_session .create_inference_recommendations_job .assert_called_with (
433481 role = IR_ROLE_ARN ,
434482 job_name = IR_JOB_NAME ,
435483 job_type = "Advanced" ,
436484 job_duration_in_seconds = 7200 ,
485+ model_name = None ,
437486 model_package_version_arn = model_package .model_package_arn ,
438487 framework = IR_SAMPLE_FRAMEWORK ,
439488 framework_version = None ,
440489 sample_payload_url = IR_SAMPLE_PAYLOAD_URL ,
441490 supported_content_types = IR_SUPPORTED_CONTENT_TYPES ,
442- supported_instance_types = [ IR_SAMPLE_INSTANCE_TYPE ] ,
491+ supported_instance_types = None ,
443492 endpoint_configurations = IR_SAMPLE_ENDPOINT_CONFIG ,
444493 traffic_pattern = IR_SAMPLE_TRAFFIC_PATTERN ,
445494 stopping_conditions = IR_SAMPLE_STOPPING_CONDITIONS ,
@@ -451,6 +500,7 @@ def test_right_size_advanced_model_package_partial_params_successful(
451500 sagemaker_session , model_package
452501):
453502 model_package .right_size (
503+ job_name = IR_JOB_NAME ,
454504 sample_payload_url = IR_SAMPLE_PAYLOAD_URL ,
455505 supported_content_types = IR_SUPPORTED_CONTENT_TYPES ,
456506 framework = "SAGEMAKER-SCIKIT-LEARN" ,
@@ -463,17 +513,18 @@ def test_right_size_advanced_model_package_partial_params_successful(
463513 )
464514
465515 # assert that the create api has been called with advanced parameters
466- assert sagemaker_session .create_inference_recommendations_job .called_with (
516+ sagemaker_session .create_inference_recommendations_job .assert_called_with (
467517 role = IR_ROLE_ARN ,
468518 job_name = IR_JOB_NAME ,
469519 job_type = "Advanced" ,
470520 job_duration_in_seconds = 7200 ,
521+ model_name = None ,
471522 model_package_version_arn = model_package .model_package_arn ,
472523 framework = IR_SAMPLE_FRAMEWORK ,
473524 framework_version = None ,
474525 sample_payload_url = IR_SAMPLE_PAYLOAD_URL ,
475526 supported_content_types = IR_SUPPORTED_CONTENT_TYPES ,
476- supported_instance_types = [ IR_SAMPLE_INSTANCE_TYPE ] ,
527+ supported_instance_types = None ,
477528 endpoint_configurations = IR_SAMPLE_ENDPOINT_CONFIG ,
478529 traffic_pattern = IR_SAMPLE_TRAFFIC_PATTERN ,
479530 stopping_conditions = IR_SAMPLE_STOPPING_CONDITIONS ,
@@ -504,45 +555,42 @@ def test_right_size_invalid_hyperparameter_ranges(sagemaker_session, model_packa
504555# TODO check our framework mapping when we add in inference_recommendation_id support
505556
506557
def test_deploy_right_size_with_model_package_succeeds(
    sagemaker_session, default_right_sized_model
):
    """Deploying with only an endpoint name creates the endpoint from the default variants."""
    default_right_sized_model.name = MODEL_NAME
    default_right_sized_model.deploy(endpoint_name=IR_DEPLOY_ENDPOINT_NAME)

    # The expected endpoint name is tied to the constant passed to deploy()
    # rather than a duplicated string literal, so the two cannot drift apart.
    sagemaker_session.endpoint_from_production_variants.assert_called_with(
        async_inference_config_dict=None,
        data_capture_config_dict=None,
        kms_key=None,
        name=IR_DEPLOY_ENDPOINT_NAME,
        production_variants=IR_PRODUCTION_VARIANTS,
        tags=None,
        wait=True,
    )
525575
def test_deploy_right_size_with_both_overrides_succeeds(
    sagemaker_session, default_right_sized_model
):
    """Explicit instance type and count passed to deploy() end up in the production variants."""
    default_right_sized_model.name = MODEL_NAME
    default_right_sized_model.deploy(
        instance_type="ml.c5.2xlarge",
        initial_instance_count=5,
        endpoint_name=IR_DEPLOY_ENDPOINT_NAME,
    )

    # The expected endpoint name is tied to the constant passed to deploy()
    # rather than a duplicated string literal, so the two cannot drift apart.
    sagemaker_session.endpoint_from_production_variants.assert_called_with(
        async_inference_config_dict=None,
        data_capture_config_dict=None,
        kms_key=None,
        name=IR_DEPLOY_ENDPOINT_NAME,
        production_variants=IR_OVERRIDDEN_PRODUCTION_VARIANTS,
        tags=None,
        wait=True,
    )
547595
548596
@@ -576,41 +624,41 @@ def test_deploy_right_size_accelerator_type_fails(default_right_sized_model):
576624 default_right_sized_model .deploy (accelerator_type = "ml.eia.medium" )
577625
578626
@patch("sagemaker.utils.name_from_base", MagicMock(return_value=MODEL_NAME))
def test_deploy_right_size_serverless_override(sagemaker_session, default_right_sized_model):
    """Deploying with a serverless config sends the serverless production variants."""
    default_right_sized_model.name = MODEL_NAME
    serverless_inference_config = ServerlessInferenceConfig()
    default_right_sized_model.deploy(serverless_inference_config=serverless_inference_config)

    # No endpoint_name is passed to deploy(); the patched name_from_base
    # returns MODEL_NAME, which is the endpoint name asserted here.
    sagemaker_session.endpoint_from_production_variants.assert_called_with(
        name=MODEL_NAME,
        production_variants=IR_SERVERLESS_PRODUCTION_VARIANTS,
        tags=None,
        kms_key=None,
        wait=True,
        data_capture_config_dict=None,
        async_inference_config_dict=None,
    )
595642
596643
@patch("sagemaker.utils.name_from_base", MagicMock(return_value=MODEL_NAME))
def test_deploy_right_size_async_override(sagemaker_session, default_right_sized_model):
    """Deploying with an async config forwards its output path in the request dict."""
    default_right_sized_model.name = MODEL_NAME
    # Single source for the S3 path so the input config and the expected
    # request payload below cannot diverge.
    output_path = "s3://some-path"
    async_inference_config = AsyncInferenceConfig(output_path=output_path)
    default_right_sized_model.deploy(
        instance_type="ml.c5.2xlarge",
        initial_instance_count=1,
        async_inference_config=async_inference_config,
    )

    sagemaker_session.endpoint_from_production_variants.assert_called_with(
        name=MODEL_NAME,
        production_variants=[ANY],
        tags=None,
        kms_key=None,
        wait=True,
        data_capture_config_dict=None,
        async_inference_config_dict={"OutputConfig": {"S3OutputPath": output_path}},
    )
615663
616664
0 commit comments