33
44from unittest import TestCase
55from unittest .mock import Mock , patch
6- import datetime
76
87import pytest
98from sagemaker .jumpstart .constants import (
1716 get_prototype_manifest ,
1817 get_prototype_model_spec ,
1918)
20- from tests .unit .sagemaker .jumpstart .constants import BASE_PROPRIETARY_MANIFEST
2119from sagemaker .jumpstart .enums import JumpStartModelType
2220from sagemaker .jumpstart .notebook_utils import (
2321 _generate_jumpstart_model_versions ,
@@ -227,10 +225,6 @@ def test_list_jumpstart_models_simple_case(
227225 patched_get_manifest .assert_called ()
228226 patched_get_model_specs .assert_not_called ()
229227
230- @pytest .mark .skipif (
231- datetime .datetime .now () < datetime .datetime (year = 2024 , month = 5 , day = 1 ),
232- reason = "Contact JumpStart team to fix flaky test." ,
233- )
234228 @patch ("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest" )
235229 @patch ("sagemaker.jumpstart.notebook_utils.DEFAULT_JUMPSTART_SAGEMAKER_SESSION.read_s3_file" )
236230 def test_list_jumpstart_models_script_filter (
@@ -246,23 +240,25 @@ def test_list_jumpstart_models_script_filter(
246240 manifest_length = len (get_prototype_manifest ())
247241 vals = [True , False ]
248242 for val in vals :
249- kwargs = {"filter" : f"training_supported == { val } " }
243+ kwargs = {"filter" : And ( f"training_supported == { val } " , "model_type is open_weights" ) }
250244 list_jumpstart_models (** kwargs )
251245 assert patched_read_s3_file .call_count == manifest_length
252- patched_get_manifest .assert_called_once ()
246+ assert patched_get_manifest .call_count == 2
253247
254248 patched_get_manifest .reset_mock ()
255249 patched_read_s3_file .reset_mock ()
256250
257- kwargs = {"filter" : f"training_supported != { val } " }
251+ kwargs = {"filter" : And ( f"training_supported != { val } " , "model_type is open_weights" ) }
258252 list_jumpstart_models (** kwargs )
259253 assert patched_read_s3_file .call_count == manifest_length
260254 assert patched_get_manifest .call_count == 2
261255
262256 patched_get_manifest .reset_mock ()
263257 patched_read_s3_file .reset_mock ()
264-
265- kwargs = {"filter" : f"training_supported in { vals } " , "list_versions" : True }
258+ kwargs = {
259+ "filter" : And (f"training_supported != { val } " , "model_type is open_weights" ),
260+ "list_versions" : True ,
261+ }
266262 assert list_jumpstart_models (** kwargs ) == [
267263 ("catboost-classification-model" , "1.0.0" ),
268264 ("huggingface-spc-bert-base-cased" , "1.0.0" ),
@@ -279,7 +275,7 @@ def test_list_jumpstart_models_script_filter(
279275 patched_get_manifest .reset_mock ()
280276 patched_read_s3_file .reset_mock ()
281277
282- kwargs = {"filter" : f"training_supported not in { vals } " }
278+ kwargs = {"filter" : And ( f"training_supported not in { vals } " , "model_type is open_weights" ) }
283279 models = list_jumpstart_models (** kwargs )
284280 assert [] == models
285281 assert patched_read_s3_file .call_count == manifest_length
@@ -518,10 +514,6 @@ def get_manifest_more_versions(region: str = JUMPSTART_DEFAULT_REGION_NAME):
518514 list_old_models = False , list_versions = True
519515 ) == list_jumpstart_models (list_versions = True )
520516
521- @pytest .mark .skipif (
522- datetime .datetime .now () < datetime .datetime (year = 2024 , month = 5 , day = 1 ),
523- reason = "Contact JumpStart team to fix flaky test." ,
524- )
525517 @patch ("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest" )
526518 @patch ("sagemaker.jumpstart.notebook_utils.DEFAULT_JUMPSTART_SAGEMAKER_SESSION.read_s3_file" )
527519 def test_list_jumpstart_models_vulnerable_models (
@@ -547,12 +539,15 @@ def vulnerable_training_model_spec(bucket, key, *args, **kwargs):
547539 patched_read_s3_file .side_effect = vulnerable_inference_model_spec
548540
549541 num_specs = len (PROTOTYPICAL_MODEL_SPECS_DICT )
550- num_prop_specs = len (BASE_PROPRIETARY_MANIFEST )
551542 assert [] == list_jumpstart_models (
552- And ("inference_vulnerable is false" , "training_vulnerable is false" )
543+ And (
544+ "inference_vulnerable is false" ,
545+ "training_vulnerable is false" ,
546+ "model_type is open_weights" ,
547+ )
553548 )
554549
555- assert patched_read_s3_file .call_count == num_specs + num_prop_specs
550+ assert patched_read_s3_file .call_count == num_specs
556551 assert patched_get_manifest .call_count == 2
557552
558553 patched_get_manifest .reset_mock ()
@@ -561,10 +556,14 @@ def vulnerable_training_model_spec(bucket, key, *args, **kwargs):
561556 patched_read_s3_file .side_effect = vulnerable_training_model_spec
562557
563558 assert [] == list_jumpstart_models (
564- And ("inference_vulnerable is false" , "training_vulnerable is false" )
559+ And (
560+ "inference_vulnerable is false" ,
561+ "training_vulnerable is false" ,
562+ "model_type is open_weights" ,
563+ )
565564 )
566565
567- assert patched_read_s3_file .call_count == num_specs + num_prop_specs
566+ assert patched_read_s3_file .call_count == num_specs
568567 assert patched_get_manifest .call_count == 2
569568
570569 patched_get_manifest .reset_mock ()
@@ -574,10 +573,6 @@ def vulnerable_training_model_spec(bucket, key, *args, **kwargs):
574573
575574 assert patched_read_s3_file .call_count == 0
576575
577- @pytest .mark .skipif (
578- datetime .datetime .now () < datetime .datetime (year = 2024 , month = 5 , day = 1 ),
579- reason = "Contact JumpStart team to fix flaky test." ,
580- )
581576 @patch ("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest" )
582577 @patch ("sagemaker.jumpstart.notebook_utils.DEFAULT_JUMPSTART_SAGEMAKER_SESSION.read_s3_file" )
583578 def test_list_jumpstart_models_deprecated_models (
@@ -598,10 +593,11 @@ def deprecated_model_spec(bucket, key, *args, **kwargs) -> str:
598593 patched_read_s3_file .side_effect = deprecated_model_spec
599594
600595 num_specs = len (PROTOTYPICAL_MODEL_SPECS_DICT )
601- num_prop_specs = len (BASE_PROPRIETARY_MANIFEST )
602- assert [] == list_jumpstart_models ("deprecated equals false" )
596+ assert [] == list_jumpstart_models (
597+ And ("deprecated equals false" , "model_type is open_weights" )
598+ )
603599
604- assert patched_read_s3_file .call_count == num_specs + num_prop_specs
600+ assert patched_read_s3_file .call_count == num_specs
605601 assert patched_get_manifest .call_count == 2
606602
607603 patched_get_manifest .reset_mock ()
0 commit comments