@@ -1076,7 +1076,7 @@ def test_build_negative_path_when_schema_builder_not_present(
10761076
10771077 model_builder = ModelBuilder (model = "CompVis/stable-diffusion-v1-4" )
10781078
1079- self .assertRaisesRegexp (
1079+ self .assertRaisesRegex (
10801080 TaskNotFoundException ,
10811081 "Error Message: Schema builder for text-to-image could not be found." ,
10821082 lambda : model_builder .build (sagemaker_session = mock_session ),
@@ -1593,3 +1593,126 @@ def test_total_inference_model_size_mib_throws(
15931593 model_builder .build (sagemaker_session = mock_session )
15941594
15951595 self .assertEqual (model_builder ._can_fit_on_single_gpu (), False )
1596+
1597+ @patch ("sagemaker.serve.builder.tgi_builder.HuggingFaceModel" )
1598+ @patch ("sagemaker.image_uris.retrieve" )
1599+ @patch ("sagemaker.djl_inference.model.urllib" )
1600+ @patch ("sagemaker.djl_inference.model.json" )
1601+ @patch ("sagemaker.huggingface.llm_utils.urllib" )
1602+ @patch ("sagemaker.huggingface.llm_utils.json" )
1603+ @patch ("sagemaker.model_uris.retrieve" )
1604+ @patch ("sagemaker.serve.builder.model_builder._ServeSettings" )
1605+ def test_build_happy_path_override_with_task_provided (
1606+ self ,
1607+ mock_serveSettings ,
1608+ mock_model_uris_retrieve ,
1609+ mock_llm_utils_json ,
1610+ mock_llm_utils_urllib ,
1611+ mock_model_json ,
1612+ mock_model_urllib ,
1613+ mock_image_uris_retrieve ,
1614+ mock_hf_model ,
1615+ ):
1616+ # Setup mocks
1617+
1618+ mock_setting_object = mock_serveSettings .return_value
1619+ mock_setting_object .role_arn = mock_role_arn
1620+ mock_setting_object .s3_model_data_url = mock_s3_model_data_url
1621+
1622+ # HF Pipeline Tag
1623+ mock_model_uris_retrieve .side_effect = KeyError
1624+ mock_llm_utils_json .load .return_value = {"pipeline_tag" : "fill-mask" }
1625+ mock_llm_utils_urllib .request .Request .side_effect = Mock ()
1626+
1627+ # HF Model config
1628+ mock_model_json .load .return_value = {"some" : "config" }
1629+ mock_model_urllib .request .Request .side_effect = Mock ()
1630+
1631+ mock_image_uris_retrieve .return_value = "https://some-image-uri"
1632+
1633+ model_builder = ModelBuilder (
1634+ model = "bert-base-uncased" , model_metadata = {"HF_TASK" : "text-generation" }
1635+ )
1636+ model_builder .build (sagemaker_session = mock_session )
1637+
1638+ self .assertIsNotNone (model_builder .schema_builder )
1639+ sample_inputs , sample_outputs = task .retrieve_local_schemas ("text-generation" )
1640+ self .assertEqual (
1641+ sample_inputs ["inputs" ], model_builder .schema_builder .sample_input ["inputs" ]
1642+ )
1643+ self .assertEqual (sample_outputs , model_builder .schema_builder .sample_output )
1644+
1645+ @patch ("sagemaker.image_uris.retrieve" )
1646+ @patch ("sagemaker.djl_inference.model.urllib" )
1647+ @patch ("sagemaker.djl_inference.model.json" )
1648+ @patch ("sagemaker.huggingface.llm_utils.urllib" )
1649+ @patch ("sagemaker.huggingface.llm_utils.json" )
1650+ @patch ("sagemaker.model_uris.retrieve" )
1651+ @patch ("sagemaker.serve.builder.model_builder._ServeSettings" )
1652+ def test_build_task_override_with_invalid_task_provided (
1653+ self ,
1654+ mock_serveSettings ,
1655+ mock_model_uris_retrieve ,
1656+ mock_llm_utils_json ,
1657+ mock_llm_utils_urllib ,
1658+ mock_model_json ,
1659+ mock_model_urllib ,
1660+ mock_image_uris_retrieve ,
1661+ ):
1662+ # Setup mocks
1663+
1664+ mock_setting_object = mock_serveSettings .return_value
1665+ mock_setting_object .role_arn = mock_role_arn
1666+ mock_setting_object .s3_model_data_url = mock_s3_model_data_url
1667+
1668+ # HF Pipeline Tag
1669+ mock_model_uris_retrieve .side_effect = KeyError
1670+ mock_llm_utils_json .load .return_value = {"pipeline_tag" : "fill-mask" }
1671+ mock_llm_utils_urllib .request .Request .side_effect = Mock ()
1672+
1673+ # HF Model config
1674+ mock_model_json .load .return_value = {"some" : "config" }
1675+ mock_model_urllib .request .Request .side_effect = Mock ()
1676+
1677+ mock_image_uris_retrieve .return_value = "https://some-image-uri"
1678+ model_ids_with_invalid_task = {
1679+ "bert-base-uncased" : "invalid-task" ,
1680+ "bert-large-uncased-whole-word-masking-finetuned-squad" : "" ,
1681+ }
1682+ for model_id in model_ids_with_invalid_task :
1683+ provided_task = model_ids_with_invalid_task [model_id ]
1684+ model_builder = ModelBuilder (model = model_id , model_metadata = {"HF_TASK" : provided_task })
1685+
1686+ self .assertRaisesRegex (
1687+ TaskNotFoundException ,
1688+ f"Error Message: Schema builder for { provided_task } could not be found." ,
1689+ lambda : model_builder .build (sagemaker_session = mock_session ),
1690+ )
1691+
1692+ @patch ("sagemaker.image_uris.retrieve" )
1693+ @patch ("sagemaker.model_uris.retrieve" )
1694+ @patch ("sagemaker.serve.builder.model_builder._ServeSettings" )
1695+ def test_build_task_override_with_invalid_model_provided (
1696+ self ,
1697+ mock_serveSettings ,
1698+ mock_model_uris_retrieve ,
1699+ mock_image_uris_retrieve ,
1700+ ):
1701+ # Setup mocks
1702+
1703+ mock_setting_object = mock_serveSettings .return_value
1704+ mock_setting_object .role_arn = mock_role_arn
1705+ mock_setting_object .s3_model_data_url = mock_s3_model_data_url
1706+
1707+ # HF Pipeline Tag
1708+ mock_model_uris_retrieve .side_effect = KeyError
1709+
1710+ mock_image_uris_retrieve .return_value = "https://some-image-uri"
1711+ invalid_model_id = ""
1712+ provided_task = "fill-mask"
1713+
1714+ model_builder = ModelBuilder (
1715+ model = invalid_model_id , model_metadata = {"HF_TASK" : provided_task }
1716+ )
1717+ with self .assertRaises (Exception ):
1718+ model_builder .build (sagemaker_session = mock_session )
0 commit comments