4040 VulnerableJumpStartModelError ,
4141)
4242from sagemaker .jumpstart .types import JumpStartModelHeader , JumpStartVersionedModelId
43- from tests .unit .sagemaker .jumpstart .utils import get_spec_from_base_spec
43+ from tests .unit .sagemaker .jumpstart .utils import get_spec_from_base_spec , get_prototype_manifest
4444from mock import MagicMock
4545
4646
@@ -1178,7 +1178,7 @@ def test_mime_type_enum_from_str():
11781178class TestIsValidModelId (TestCase ):
11791179 @patch ("sagemaker.jumpstart.utils.accessors.JumpStartModelsAccessor._get_manifest" )
11801180 @patch ("sagemaker.jumpstart.utils.accessors.JumpStartModelsAccessor.get_model_specs" )
1181- def test_validate_model_id_and_get_type_true (
1181+ def test_validate_model_id_and_get_type_open_weights (
11821182 self ,
11831183 mock_get_model_specs : Mock ,
11841184 mock_get_manifest : Mock ,
@@ -1197,11 +1197,11 @@ def test_validate_model_id_and_get_type_true(
11971197 )
11981198
11991199 with patch ("sagemaker.jumpstart.utils.validate_model_id_and_get_type" , patched ):
1200- self . assertTrue ( utils .validate_model_id_and_get_type ("bee" ))
1200+ assert utils .validate_model_id_and_get_type ("bee" ) == JumpStartModelType . OPEN_WEIGHTS
12011201 mock_get_manifest .assert_called_with (
12021202 region = JUMPSTART_DEFAULT_REGION_NAME ,
12031203 s3_client = mock_s3_client_value ,
1204- model_type = JumpStartModelType .PROPRIETARY ,
1204+ model_type = JumpStartModelType .OPEN_WEIGHTS ,
12051205 )
12061206 mock_get_model_specs .assert_not_called ()
12071207
@@ -1215,25 +1215,30 @@ def test_validate_model_id_and_get_type_true(
12151215 ]
12161216
12171217 mock_get_model_specs .return_value = Mock (training_supported = True )
1218- self .assertTrue (
1218+ self .assertIsNone (
1219+ utils .validate_model_id_and_get_type (
1220+ "invalid" , script = JumpStartScriptScope .TRAINING
1221+ )
1222+ )
1223+ assert (
12191224 utils .validate_model_id_and_get_type ("bee" , script = JumpStartScriptScope .TRAINING )
1225+ == JumpStartModelType .OPEN_WEIGHTS
12201226 )
1227+
12211228 mock_get_manifest .assert_called_with (
12221229 region = JUMPSTART_DEFAULT_REGION_NAME ,
12231230 s3_client = mock_s3_client_value ,
1224- model_type = JumpStartModelType .PROPRIETARY ,
1231+ model_type = JumpStartModelType .OPEN_WEIGHTS ,
12251232 )
12261233
12271234 @patch ("sagemaker.jumpstart.utils.accessors.JumpStartModelsAccessor._get_manifest" )
12281235 @patch ("sagemaker.jumpstart.utils.accessors.JumpStartModelsAccessor.get_model_specs" )
1229- def test_validate_model_id_and_get_type_false (
1236+ def test_validate_model_id_and_get_type_invalid (
12301237 self , mock_get_model_specs : Mock , mock_get_manifest : Mock
12311238 ):
1232- mock_get_manifest .return_value = [
1233- Mock (model_id = "ay" ),
1234- Mock (model_id = "bee" ),
1235- Mock (model_id = "see" ),
1236- ]
1239+ mock_get_manifest .side_effect = (
1240+ lambda region , model_type , * args , ** kwargs : get_prototype_manifest (region , model_type )
1241+ )
12371242
12381243 mock_session_value = DEFAULT_JUMPSTART_SAGEMAKER_SESSION
12391244 mock_s3_client_value = mock_session_value .s3_client
@@ -1244,10 +1249,10 @@ def test_validate_model_id_and_get_type_false(
12441249
12451250 with patch ("sagemaker.jumpstart.utils.validate_model_id_and_get_type" , patched ):
12461251
1247- self .assertFalse (utils .validate_model_id_and_get_type ("dee" ))
1248- self .assertFalse (utils .validate_model_id_and_get_type ("" ))
1249- self .assertFalse (utils .validate_model_id_and_get_type (None ))
1250- self .assertFalse (utils .validate_model_id_and_get_type (set ()))
1252+ self .assertIsNone (utils .validate_model_id_and_get_type ("dee" ))
1253+ self .assertIsNone (utils .validate_model_id_and_get_type ("" ))
1254+ self .assertIsNone (utils .validate_model_id_and_get_type (None ))
1255+ self .assertIsNone (utils .validate_model_id_and_get_type (set ()))
12511256
12521257 mock_get_manifest .assert_called ()
12531258
@@ -1256,53 +1261,44 @@ def test_validate_model_id_and_get_type_false(
12561261 mock_get_manifest .reset_mock ()
12571262 mock_get_model_specs .reset_mock ()
12581263
1259- mock_get_manifest .return_value = [
1260- Mock (model_id = "ay" ),
1261- Mock (model_id = "bee" ),
1262- Mock (model_id = "see" ),
1263- ]
1264- self .assertFalse (
1265- utils .validate_model_id_and_get_type ("dee" , script = JumpStartScriptScope .TRAINING )
1264+ assert (
1265+ utils .validate_model_id_and_get_type ("ai21-summarization" )
1266+ == JumpStartModelType .PROPRIETARY
12661267 )
1268+ self .assertIsNone (utils .validate_model_id_and_get_type ("ai21-summarization-2" ))
1269+
12671270 mock_get_manifest .assert_called_with (
12681271 region = JUMPSTART_DEFAULT_REGION_NAME ,
12691272 s3_client = mock_s3_client_value ,
12701273 model_type = JumpStartModelType .PROPRIETARY ,
12711274 )
12721275
1273- mock_get_manifest .reset_mock ()
1274-
1275- self .assertFalse (
1276+ self .assertIsNone (
12761277 utils .validate_model_id_and_get_type ("dee" , script = JumpStartScriptScope .TRAINING )
12771278 )
1278- self .assertFalse (
1279+ self .assertIsNone (
12791280 utils .validate_model_id_and_get_type ("" , script = JumpStartScriptScope .TRAINING )
12801281 )
1281- self .assertFalse (
1282+ self .assertIsNone (
12821283 utils .validate_model_id_and_get_type (None , script = JumpStartScriptScope .TRAINING )
12831284 )
1284- self .assertFalse (
1285+ self .assertIsNone (
12851286 utils .validate_model_id_and_get_type (set (), script = JumpStartScriptScope .TRAINING )
12861287 )
12871288
1288- mock_get_model_specs .assert_not_called ()
1289+ assert (
1290+ utils .validate_model_id_and_get_type ("pytorch-eqa-bert-base-cased" )
1291+ == JumpStartModelType .OPEN_WEIGHTS
1292+ )
12891293 mock_get_manifest .assert_called_with (
12901294 region = JUMPSTART_DEFAULT_REGION_NAME ,
12911295 s3_client = mock_s3_client_value ,
1292- model_type = JumpStartModelType .PROPRIETARY ,
1296+ model_type = JumpStartModelType .OPEN_WEIGHTS ,
12931297 )
12941298
1295- mock_get_manifest .reset_mock ()
1296- mock_get_model_specs .reset_mock ()
1297-
1298- mock_get_model_specs .return_value = Mock (training_supported = False )
1299- self .assertTrue (
1300- utils .validate_model_id_and_get_type ("ay" , script = JumpStartScriptScope .TRAINING )
1301- )
1302- mock_get_manifest .assert_called_with (
1303- region = JUMPSTART_DEFAULT_REGION_NAME ,
1304- s3_client = mock_s3_client_value ,
1305- model_type = JumpStartModelType .PROPRIETARY ,
1299+ with pytest .raises (ValueError ):
1300+ utils .validate_model_id_and_get_type (
1301+ "ai21-summarization" , script = JumpStartScriptScope .TRAINING
13061302 )
13071303
13081304
0 commit comments