1818import pytest
1919
2020from sagemaker .serve .utils .exceptions import TaskNotFoundException
21+ from sagemaker_schema_inference_artifacts .huggingface import remote_schema_retriever
2122from tests .integ .sagemaker .serve .constants import (
2223 PYTHON_VERSION_IS_NOT_310 ,
2324 SERVE_SAGEMAKER_ENDPOINT_TIMEOUT ,
3132logger = logging .getLogger (__name__ )
3233
3334
def test_model_builder_happy_path_with_only_model_id_text_generation(sagemaker_session):
    """Build from only a Hub model ID and check the schema picked for text generation.

    No ``HF_TASK`` metadata is passed to ``ModelBuilder``; the test verifies that
    ``build`` still leaves a schema builder in place and compares its samples with
    the local ``text-generation`` schema.

    Args:
        sagemaker_session: Session object forwarded to ``ModelBuilder.build``.
    """
    model_builder = ModelBuilder(model="HuggingFaceH4/zephyr-7b-beta")

    model = model_builder.build(sagemaker_session=sagemaker_session)

    assert model is not None
    assert model_builder.schema_builder is not None

    inputs, outputs = task.retrieve_local_schemas("text-generation")
    # Only the "inputs" entry of the sample input is compared; any other keys in
    # the sample input are deliberately left unchecked.
    assert model_builder.schema_builder.sample_input["inputs"] == inputs["inputs"]
    assert model_builder.schema_builder.sample_output == outputs
4546
4647
def test_model_builder_negative_path(sagemaker_session):
    """Expect ``build`` to raise ``TaskNotFoundException`` for an unsupported model-task combo.

    Args:
        sagemaker_session: Session object forwarded to ``ModelBuilder.build``.
    """
    # This model-task combination (e.g. text-to-video) is unsupported by both the
    # local and the remote schema fallback options.
    unsupported_builder = ModelBuilder(model="ByteDance/AnimateDiff-Lightning")
    expected_error = (
        "Error Message: HuggingFace Schema builder samples for text-to-video could not be found locally or "
        "via remote."
    )

    with pytest.raises(TaskNotFoundException, match=expected_error):
        unsupported_builder.build(sagemaker_session=sagemaker_session)
57+
58+
4759@pytest .mark .skipif (
4860 PYTHON_VERSION_IS_NOT_310 ,
49- reason = "Testing Schema Builder Simplification feature" ,
61+ reason = "Testing Schema Builder Simplification feature - Local Schema " ,
5062)
51- def test_model_builder_happy_path_with_only_model_id_question_answering (
52- sagemaker_session , gpu_instance_type
63+ @pytest .mark .parametrize (
64+ "model_id, task_provided, instance_type_provided, container_startup_timeout" ,
65+ [
66+ (
67+ "distilbert/distilbert-base-uncased-finetuned-sst-2-english" ,
68+ "text-classification" ,
69+ "ml.m5.xlarge" ,
70+ None ,
71+ ),
72+ (
73+ "cardiffnlp/twitter-roberta-base-sentiment-latest" ,
74+ "text-classification" ,
75+ "ml.m5.xlarge" ,
76+ None ,
77+ ),
78+ ("HuggingFaceH4/zephyr-7b-beta" , "text-generation" , "ml.g5.2xlarge" , 900 ),
79+ ("HuggingFaceH4/zephyr-7b-alpha" , "text-generation" , "ml.g5.2xlarge" , 900 ),
80+ ],
81+ )
82+ def test_model_builder_happy_path_with_task_provided_local_schema_mode (
83+ model_id , task_provided , sagemaker_session , instance_type_provided , container_startup_timeout
5384):
54- model_builder = ModelBuilder (model = "bert-large-uncased-whole-word-masking-finetuned-squad" )
85+ model_builder = ModelBuilder (
86+ model = model_id ,
87+ model_metadata = {"HF_TASK" : task_provided },
88+ instance_type = instance_type_provided ,
89+ )
5590
5691 model = model_builder .build (sagemaker_session = sagemaker_session )
5792
5893 assert model is not None
5994 assert model_builder .schema_builder is not None
6095
61- inputs , outputs = task .retrieve_local_schemas ("question-answering" )
62- assert model_builder .schema_builder .sample_input == inputs
96+ inputs , outputs = task .retrieve_local_schemas (task_provided )
97+ if task_provided == "text-generation" :
98+ # ignore 'tokens' and other metadata in this case
99+ assert model_builder .schema_builder .sample_input ["inputs" ] == inputs ["inputs" ]
100+ else :
101+ assert model_builder .schema_builder .sample_input == inputs
63102 assert model_builder .schema_builder .sample_output == outputs
64103
65104 with timeout (minutes = SERVE_SAGEMAKER_ENDPOINT_TIMEOUT ):
@@ -69,9 +108,17 @@ def test_model_builder_happy_path_with_only_model_id_question_answering(
69108 role_arn = iam_client .get_role (RoleName = "SageMakerRole" )["Role" ]["Arn" ]
70109
71110 logger .info ("Deploying and predicting in SAGEMAKER_ENDPOINT mode..." )
72- predictor = model .deploy (
73- role = role_arn , instance_count = 1 , instance_type = gpu_instance_type
74- )
111+ if container_startup_timeout :
112+ predictor = model .deploy (
113+ role = role_arn ,
114+ instance_count = 1 ,
115+ instance_type = instance_type_provided ,
116+ container_startup_health_check_timeout = container_startup_timeout ,
117+ )
118+ else :
119+ predictor = model .deploy (
120+ role = role_arn , instance_count = 1 , instance_type = instance_type_provided
121+ )
75122
76123 predicted_outputs = predictor .predict (inputs )
77124 assert predicted_outputs is not None
@@ -91,38 +138,38 @@ def test_model_builder_happy_path_with_only_model_id_question_answering(
91138 ), f"{ caught_ex } was thrown when running transformers sagemaker endpoint test"
92139
93140
94- def test_model_builder_negative_path (sagemaker_session ):
95- model_builder = ModelBuilder (model = "CompVis/stable-diffusion-v1-4" )
96-
97- with pytest .raises (
98- TaskNotFoundException ,
99- match = "Error Message: Schema builder for text-to-image could not be found." ,
100- ):
101- model_builder .build (sagemaker_session = sagemaker_session )
102-
103-
104141@pytest .mark .skipif (
105142 PYTHON_VERSION_IS_NOT_310 ,
106- reason = "Testing Schema Builder Simplification feature" ,
143+ reason = "Testing Schema Builder Simplification feature - Remote Schema " ,
107144)
108145@pytest .mark .parametrize (
109- "model_id, task_provided" ,
146+ "model_id, task_provided, instance_type_provided " ,
110147 [
111- ("bert-base-uncased" , "fill-mask" ),
112- ("bert-large-uncased-whole-word-masking-finetuned-squad" , "question-answering" ),
148+ ("google-bert/bert-base-uncased" , "fill-mask" , "ml.m5.xlarge" ),
149+ ("google-bert/bert-base-cased" , "fill-mask" , "ml.m5.xlarge" ),
150+ (
151+ "google-bert/bert-large-uncased-whole-word-masking-finetuned-squad" ,
152+ "question-answering" ,
153+ "ml.m5.xlarge" ,
154+ ),
155+ ("deepset/roberta-base-squad2" , "question-answering" , "ml.m5.xlarge" ),
113156 ],
114157)
115- def test_model_builder_happy_path_with_task_provided (
116- model_id , task_provided , sagemaker_session , gpu_instance_type
158+ def test_model_builder_happy_path_with_task_provided_remote_schema_mode (
159+ model_id , task_provided , sagemaker_session , instance_type_provided
117160):
118- model_builder = ModelBuilder (model = model_id , model_metadata = {"HF_TASK" : task_provided })
119-
161+ model_builder = ModelBuilder (
162+ model = model_id ,
163+ model_metadata = {"HF_TASK" : task_provided },
164+ instance_type = instance_type_provided ,
165+ )
120166 model = model_builder .build (sagemaker_session = sagemaker_session )
121167
122168 assert model is not None
123169 assert model_builder .schema_builder is not None
124170
125- inputs , outputs = task .retrieve_local_schemas (task_provided )
171+ remote_hf_schema_helper = remote_schema_retriever .RemoteSchemaRetriever ()
172+ inputs , outputs = remote_hf_schema_helper .get_resolved_hf_schema_for_task (task_provided )
126173 assert model_builder .schema_builder .sample_input == inputs
127174 assert model_builder .schema_builder .sample_output == outputs
128175
@@ -134,7 +181,7 @@ def test_model_builder_happy_path_with_task_provided(
134181
135182 logger .info ("Deploying and predicting in SAGEMAKER_ENDPOINT mode..." )
136183 predictor = model .deploy (
137- role = role_arn , instance_count = 1 , instance_type = gpu_instance_type
184+ role = role_arn , instance_count = 1 , instance_type = instance_type_provided
138185 )
139186
140187 predicted_outputs = predictor .predict (inputs )
@@ -162,6 +209,7 @@ def test_model_builder_negative_path_with_invalid_task(sagemaker_session):
162209
163210 with pytest .raises (
164211 TaskNotFoundException ,
165- match = "Error Message: Schema builder for invalid-task could not be found." ,
212+ match = "Error Message: HuggingFace Schema builder samples for invalid-task could not be found locally or "
213+ "via remote." ,
166214 ):
167215 model_builder .build (sagemaker_session = sagemaker_session )
0 commit comments