|
24 | 24 | AquaModel, |
25 | 25 | ModelValidationResult, |
26 | 26 | ) |
| 27 | +from ads.aqua.common.utils import get_hf_model_info |
27 | 28 | import ads.common |
28 | 29 | import ads.common.oci_client |
29 | 30 | import ads.config |
@@ -64,7 +65,7 @@ def mock_get_container_config(): |
64 | 65 | yield mock_config |
65 | 66 |
|
66 | 67 |
|
67 | | -@pytest.fixture(autouse=True, scope="class") |
| 68 | +@pytest.fixture(autouse=True, scope="function") |
68 | 69 | def mock_get_hf_model_info(): |
69 | 70 | with patch.object(HfApi, "model_info") as mock_get_hf_model_info: |
70 | 71 | test_hf_model_info = ModelInfo( |
@@ -266,6 +267,7 @@ def teardown_method(self): |
266 | 267 | self.create_signer_patch.stop() |
267 | 268 | self.validate_config_patch.stop() |
268 | 269 | self.create_client_patch.stop() |
| 270 | + get_hf_model_info.cache_clear() |
269 | 271 |
|
270 | 272 | @classmethod |
271 | 273 | def setup_class(cls): |
@@ -1012,6 +1014,87 @@ def test_import_any_model_smc_container( |
1012 | 1014 | assert model.ready_to_deploy is True |
1013 | 1015 | assert model.ready_to_finetune is True |
1014 | 1016 |
|
| 1017 | + @pytest.mark.parametrize( |
| 1018 | + "download_from_hf", |
| 1019 | + [True, False], |
| 1020 | + ) |
| 1021 | + @patch.object(AquaModelApp, "_find_matching_aqua_model") |
| 1022 | + @patch("ads.common.object_storage_details.ObjectStorageDetails.list_objects") |
| 1023 | + @patch("ads.aqua.common.utils.load_config", return_value={}) |
| 1024 | + @patch("huggingface_hub.snapshot_download") |
| 1025 | + @patch("subprocess.check_call") |
| 1026 | + def test_import_tei_model_byoc( |
| 1027 | + self, |
| 1028 | + mock_subprocess, |
| 1029 | + mock_snapshot_download, |
| 1030 | + mock_load_config, |
| 1031 | + mock_list_objects, |
| 1032 | + mock__find_matching_aqua_model, |
| 1033 | + download_from_hf, |
| 1034 | + mock_get_hf_model_info, |
| 1035 | + ): |
| 1036 | + ObjectStorageDetails.is_bucket_versioned = MagicMock(return_value=True) |
| 1037 | + ads.common.oci_datascience.OCIDataScienceMixin.init_client = MagicMock() |
| 1038 | + DataScienceModel.upload_artifact = MagicMock() |
| 1039 | + DataScienceModel.sync = MagicMock() |
| 1040 | + OCIDataScienceModel.create = MagicMock() |
| 1041 | + |
| 1042 | + artifact_path = "service_models/model-name/commit-id/artifact" |
| 1043 | + obj1 = MagicMock(etag="12345-1234-1234-1234-123456789", size=150) |
| 1044 | + obj1.name = f"{artifact_path}/config.json" |
| 1045 | + objects = [obj1] |
| 1046 | + mock_list_objects.return_value = MagicMock(objects=objects) |
| 1047 | + ds_model = DataScienceModel() |
| 1048 | + os_path = "oci://aqua-bkt@aqua-ns/prefix/path" |
| 1049 | + model_name = "oracle/aqua-1t-mega-model" |
| 1050 | + ds_freeform_tags = { |
| 1051 | + "OCI_AQUA": "ACTIVE", |
| 1052 | + "license": "aqua-license", |
| 1053 | + "organization": "oracle", |
| 1054 | + "task": "text_embedding", |
| 1055 | + } |
| 1056 | + ds_model = ( |
| 1057 | + ds_model.with_compartment_id("test_model_compartment_id") |
| 1058 | + .with_project_id("test_project_id") |
| 1059 | + .with_display_name(model_name) |
| 1060 | + .with_description("test_description") |
| 1061 | + .with_model_version_set_id("test_model_version_set_id") |
| 1062 | + .with_freeform_tags(**ds_freeform_tags) |
| 1063 | + .with_version_id("ocid1.version.id") |
| 1064 | + ) |
| 1065 | + custom_metadata_list = ModelCustomMetadata() |
| 1066 | + custom_metadata_list.add( |
| 1067 | + **{"key": "deployment-container", "value": "odsc-tei-serving"} |
| 1068 | + ) |
| 1069 | + ds_model.with_custom_metadata_list(custom_metadata_list) |
| 1070 | + ds_model.set_spec(ds_model.CONST_MODEL_FILE_DESCRIPTION, {}) |
| 1071 | + DataScienceModel.from_id = MagicMock(return_value=ds_model) |
| 1072 | + mock__find_matching_aqua_model.return_value = None |
| 1073 | + reload(ads.aqua.model.model) |
| 1074 | + app = AquaModelApp() |
| 1075 | + |
| 1076 | + if download_from_hf: |
| 1077 | + with tempfile.TemporaryDirectory() as tmpdir: |
| 1078 | + model: AquaModel = app.register( |
| 1079 | + model=model_name, |
| 1080 | + os_path=os_path, |
| 1081 | + local_dir=str(tmpdir), |
| 1082 | + download_from_hf=True, |
| 1083 | + inference_container="odsc-tei-serving", |
| 1084 | + inference_container_uri="region.ocir.io/your_tenancy/your_image", |
| 1085 | + ) |
| 1086 | + else: |
| 1087 | + model: AquaModel = app.register( |
| 1088 | + model="ocid1.datasciencemodel.xxx.xxxx.", |
| 1089 | + os_path=os_path, |
| 1090 | + download_from_hf=False, |
| 1091 | + inference_container="odsc-tei-serving", |
| 1092 | + inference_container_uri="region.ocir.io/your_tenancy/your_image", |
| 1093 | + ) |
| 1094 | + assert model.inference_container == "odsc-tei-serving" |
| 1095 | + assert model.ready_to_deploy is True |
| 1096 | + assert model.ready_to_finetune is False |
| 1097 | + |
1015 | 1098 | @pytest.mark.parametrize( |
1016 | 1099 | "data, expected_output", |
1017 | 1100 | [ |
|
0 commit comments