Skip to content

Commit 0b8d994

Browse files
add model tests
1 parent 71b1e9f commit 0b8d994

File tree

1 file changed

+84
-1
lines changed

1 file changed

+84
-1
lines changed

tests/unitary/with_extras/aqua/test_model.py

Lines changed: 84 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
AquaModel,
2525
ModelValidationResult,
2626
)
27+
from ads.aqua.common.utils import get_hf_model_info
2728
import ads.common
2829
import ads.common.oci_client
2930
import ads.config
@@ -64,7 +65,7 @@ def mock_get_container_config():
6465
yield mock_config
6566

6667

67-
@pytest.fixture(autouse=True, scope="class")
68+
@pytest.fixture(autouse=True, scope="function")
6869
def mock_get_hf_model_info():
6970
with patch.object(HfApi, "model_info") as mock_get_hf_model_info:
7071
test_hf_model_info = ModelInfo(
@@ -266,6 +267,7 @@ def teardown_method(self):
266267
self.create_signer_patch.stop()
267268
self.validate_config_patch.stop()
268269
self.create_client_patch.stop()
270+
get_hf_model_info.cache_clear()
269271

270272
@classmethod
271273
def setup_class(cls):
@@ -1012,6 +1014,87 @@ def test_import_any_model_smc_container(
10121014
assert model.ready_to_deploy is True
10131015
assert model.ready_to_finetune is True
10141016

1017+
@pytest.mark.parametrize(
1018+
"download_from_hf",
1019+
[True, False],
1020+
)
1021+
@patch.object(AquaModelApp, "_find_matching_aqua_model")
1022+
@patch("ads.common.object_storage_details.ObjectStorageDetails.list_objects")
1023+
@patch("ads.aqua.common.utils.load_config", return_value={})
1024+
@patch("huggingface_hub.snapshot_download")
1025+
@patch("subprocess.check_call")
1026+
def test_import_tei_model_byoc(
1027+
self,
1028+
mock_subprocess,
1029+
mock_snapshot_download,
1030+
mock_load_config,
1031+
mock_list_objects,
1032+
mock__find_matching_aqua_model,
1033+
download_from_hf,
1034+
mock_get_hf_model_info,
1035+
):
1036+
ObjectStorageDetails.is_bucket_versioned = MagicMock(return_value=True)
1037+
ads.common.oci_datascience.OCIDataScienceMixin.init_client = MagicMock()
1038+
DataScienceModel.upload_artifact = MagicMock()
1039+
DataScienceModel.sync = MagicMock()
1040+
OCIDataScienceModel.create = MagicMock()
1041+
1042+
artifact_path = "service_models/model-name/commit-id/artifact"
1043+
obj1 = MagicMock(etag="12345-1234-1234-1234-123456789", size=150)
1044+
obj1.name = f"{artifact_path}/config.json"
1045+
objects = [obj1]
1046+
mock_list_objects.return_value = MagicMock(objects=objects)
1047+
ds_model = DataScienceModel()
1048+
os_path = "oci://aqua-bkt@aqua-ns/prefix/path"
1049+
model_name = "oracle/aqua-1t-mega-model"
1050+
ds_freeform_tags = {
1051+
"OCI_AQUA": "ACTIVE",
1052+
"license": "aqua-license",
1053+
"organization": "oracle",
1054+
"task": "text_embedding",
1055+
}
1056+
ds_model = (
1057+
ds_model.with_compartment_id("test_model_compartment_id")
1058+
.with_project_id("test_project_id")
1059+
.with_display_name(model_name)
1060+
.with_description("test_description")
1061+
.with_model_version_set_id("test_model_version_set_id")
1062+
.with_freeform_tags(**ds_freeform_tags)
1063+
.with_version_id("ocid1.version.id")
1064+
)
1065+
custom_metadata_list = ModelCustomMetadata()
1066+
custom_metadata_list.add(
1067+
**{"key": "deployment-container", "value": "odsc-tei-serving"}
1068+
)
1069+
ds_model.with_custom_metadata_list(custom_metadata_list)
1070+
ds_model.set_spec(ds_model.CONST_MODEL_FILE_DESCRIPTION, {})
1071+
DataScienceModel.from_id = MagicMock(return_value=ds_model)
1072+
mock__find_matching_aqua_model.return_value = None
1073+
reload(ads.aqua.model.model)
1074+
app = AquaModelApp()
1075+
1076+
if download_from_hf:
1077+
with tempfile.TemporaryDirectory() as tmpdir:
1078+
model: AquaModel = app.register(
1079+
model=model_name,
1080+
os_path=os_path,
1081+
local_dir=str(tmpdir),
1082+
download_from_hf=True,
1083+
inference_container="odsc-tei-serving",
1084+
inference_container_uri="region.ocir.io/your_tenancy/your_image",
1085+
)
1086+
else:
1087+
model: AquaModel = app.register(
1088+
model="ocid1.datasciencemodel.xxx.xxxx.",
1089+
os_path=os_path,
1090+
download_from_hf=False,
1091+
inference_container="odsc-tei-serving",
1092+
inference_container_uri="region.ocir.io/your_tenancy/your_image",
1093+
)
1094+
assert model.inference_container == "odsc-tei-serving"
1095+
assert model.ready_to_deploy is True
1096+
assert model.ready_to_finetune is False
1097+
10151098
@pytest.mark.parametrize(
10161099
"data, expected_output",
10171100
[

0 commit comments

Comments
 (0)