@@ -678,6 +678,116 @@ def test_jumpstart_model_package_arn_unsupported_region(
678678 "us-east-2. Please try one of the following regions: us-west-2, us-east-1."
679679 )
680680
681+ @mock .patch ("sagemaker.utils.sagemaker_timestamp" )
682+ @mock .patch ("sagemaker.jumpstart.model.is_valid_model_id" )
683+ @mock .patch ("sagemaker.jumpstart.factory.model.Session" )
684+ @mock .patch ("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs" )
685+ @mock .patch ("sagemaker.jumpstart.model.Model.__init__" )
686+ @mock .patch ("sagemaker.jumpstart.model.Model.deploy" )
687+ @mock .patch ("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME" , region )
688+ @mock .patch ("sagemaker.jumpstart.factory.model.JUMPSTART_LOGGER.info" )
689+ def test_model_data_s3_prefix_override (
690+ self ,
691+ mock_js_info_logger : mock .Mock ,
692+ mock_model_deploy : mock .Mock ,
693+ mock_model_init : mock .Mock ,
694+ mock_get_model_specs : mock .Mock ,
695+ mock_session : mock .Mock ,
696+ mock_is_valid_model_id : mock .Mock ,
697+ mock_sagemaker_timestamp : mock .Mock ,
698+ ):
699+ mock_model_deploy .return_value = default_predictor
700+
701+ mock_sagemaker_timestamp .return_value = "7777"
702+
703+ mock_is_valid_model_id .return_value = True
704+ model_id , _ = "js-trainable-model" , "*"
705+
706+ mock_get_model_specs .side_effect = get_special_model_spec
707+
708+ mock_session .return_value = sagemaker_session
709+
710+ JumpStartModel (model_id = model_id , model_data = "s3://some-bucket/path/to/prefix/" )
711+
712+ mock_model_init .assert_called_once_with (
713+ image_uri = "763104351884.dkr.ecr.us-west-2.amazonaws.com/"
714+ "autogluon-inference:0.4.3-gpu-py38" ,
715+ model_data = {
716+ "S3DataSource" : {
717+ "S3Uri" : "s3://some-bucket/path/to/prefix/" ,
718+ "S3DataType" : "S3Prefix" ,
719+ "CompressionType" : "None" ,
720+ }
721+ },
722+ source_dir = "s3://jumpstart-cache-prod-us-west-2/source-directory-"
723+ "tarballs/autogluon/inference/classification/v1.0.0/sourcedir.tar.gz" ,
724+ entry_point = "inference.py" ,
725+ env = {
726+ "SAGEMAKER_PROGRAM" : "inference.py" ,
727+ "ENDPOINT_SERVER_TIMEOUT" : "3600" ,
728+ "MODEL_CACHE_ROOT" : "/opt/ml/model" ,
729+ "SAGEMAKER_ENV" : "1" ,
730+ "SAGEMAKER_MODEL_SERVER_WORKERS" : "1" ,
731+ },
732+ predictor_cls = Predictor ,
733+ role = execution_role ,
734+ sagemaker_session = sagemaker_session ,
735+ enable_network_isolation = False ,
736+ name = "blahblahblah-7777" ,
737+ )
738+
739+ mock_js_info_logger .assert_called_with (
740+ "S3 prefix model_data detected for JumpStartModel: '%s'. "
741+ "Converting to S3DataSource dictionary: '%s'." ,
742+ "s3://some-bucket/path/to/prefix/" ,
743+ '{"S3DataSource": {"S3Uri": "s3://some-bucket/path/to/prefix/", '
744+ '"S3DataType": "S3Prefix", "CompressionType": "None"}}' ,
745+ )
746+
747+ @mock .patch ("sagemaker.jumpstart.model.is_valid_model_id" )
748+ @mock .patch ("sagemaker.jumpstart.factory.model.Session" )
749+ @mock .patch ("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs" )
750+ @mock .patch ("sagemaker.jumpstart.model.Model.__init__" )
751+ @mock .patch ("sagemaker.jumpstart.model.Model.deploy" )
752+ @mock .patch ("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME" , region )
753+ @mock .patch ("sagemaker.jumpstart.factory.model.JUMPSTART_LOGGER.info" )
754+ def test_model_data_s3_prefix_model (
755+ self ,
756+ mock_js_info_logger : mock .Mock ,
757+ mock_model_deploy : mock .Mock ,
758+ mock_model_init : mock .Mock ,
759+ mock_get_model_specs : mock .Mock ,
760+ mock_session : mock .Mock ,
761+ mock_is_valid_model_id : mock .Mock ,
762+ ):
763+ mock_model_deploy .return_value = default_predictor
764+
765+ mock_is_valid_model_id .return_value = True
766+ model_id , _ = "model_data_s3_prefix_model" , "*"
767+
768+ mock_get_model_specs .side_effect = get_special_model_spec
769+
770+ mock_session .return_value = sagemaker_session
771+
772+ JumpStartModel (model_id = model_id , instance_type = "ml.p2.xlarge" )
773+
774+ mock_model_init .assert_called_once_with (
775+ image_uri = "763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-inference:1.12.0-gpu-py38" ,
776+ model_data = {
777+ "S3DataSource" : {
778+ "S3Uri" : "s3://jumpstart-cache-prod-us-west-2/huggingface-infer/prepack/v1.0.1/" ,
779+ "S3DataType" : "S3Prefix" ,
780+ "CompressionType" : "None" ,
781+ }
782+ },
783+ predictor_cls = Predictor ,
784+ role = execution_role ,
785+ sagemaker_session = sagemaker_session ,
786+ enable_network_isolation = False ,
787+ )
788+
789+ mock_js_info_logger .assert_not_called ()
790+
681791
682792def test_jumpstart_model_requires_model_id ():
683793 with pytest .raises (ValueError ):
0 commit comments