Skip to content

Commit 8005294

Browse files
merge ODSC-65657/ignore_config_validation changes
2 parents e3cb8d5 + 71269a7 commit 8005294

File tree

5 files changed

+151
-65
lines changed

5 files changed

+151
-65
lines changed

ads/aqua/extension/model_handler.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,10 @@ def post(self, *args, **kwargs): # noqa: ARG002
133133
ignore_patterns = input_data.get("ignore_patterns")
134134
freeform_tags = input_data.get("freeform_tags")
135135
defined_tags = input_data.get("defined_tags")
136+
ignore_model_artifact_check = (
137+
str(input_data.get("ignore_model_artifact_check", "false")).lower()
138+
== "true"
139+
)
136140

137141
return self.finish(
138142
AquaModelApp().register(
@@ -149,6 +153,7 @@ def post(self, *args, **kwargs): # noqa: ARG002
149153
ignore_patterns=ignore_patterns,
150154
freeform_tags=freeform_tags,
151155
defined_tags=defined_tags,
156+
ignore_model_artifact_check=ignore_model_artifact_check,
152157
)
153158
)
154159

ads/aqua/model/entities.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -293,6 +293,7 @@ class ImportModelDetails(CLIBuilderMixin):
293293
ignore_patterns: Optional[List[str]] = None
294294
freeform_tags: Optional[dict] = None
295295
defined_tags: Optional[dict] = None
296+
ignore_model_artifact_check: Optional[bool] = None
296297

297298
def __post_init__(self):
298299
self._command = "model register"

ads/aqua/model/model.py

Lines changed: 76 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,11 @@
1919
InferenceContainerTypeFamily,
2020
Tags,
2121
)
22-
from ads.aqua.common.errors import AquaRuntimeError, AquaValueError
22+
from ads.aqua.common.errors import (
23+
AquaFileNotFoundError,
24+
AquaRuntimeError,
25+
AquaValueError,
26+
)
2327
from ads.aqua.common.utils import (
2428
LifecycleStatus,
2529
_build_resource_identifier,
@@ -994,13 +998,23 @@ def get_model_files(os_path: str, model_format: ModelFormat) -> List[str]:
994998
# todo: revisit this logic to account for .bin files. In the current state, .bin and .safetensor models
995999
# are grouped in one category and validation checks for config.json files only.
9961000
if model_format == ModelFormat.SAFETENSORS:
1001+
model_files.extend(
1002+
list_os_files_with_extension(oss_path=os_path, extension=".safetensors")
1003+
)
9971004
try:
9981005
load_config(
9991006
file_path=os_path,
10001007
config_file_name=AQUA_MODEL_ARTIFACT_CONFIG,
10011008
)
1002-
except Exception:
1003-
pass
1009+
except Exception as ex:
1010+
message = (
1011+
f"The model path {os_path} does not contain the file config.json. "
1012+
f"Please check if the path is correct or the model artifacts are available at this location."
1013+
)
1014+
logger.warning(
1015+
f"{message}\n"
1016+
f"Details: {ex.reason if isinstance(ex, AquaFileNotFoundError) else str(ex)}\n"
1017+
)
10041018
else:
10051019
model_files.append(AQUA_MODEL_ARTIFACT_CONFIG)
10061020

@@ -1047,10 +1061,12 @@ def get_hf_model_files(model_name: str, model_format: ModelFormat) -> List[str]:
10471061

10481062
for model_sibling in model_siblings:
10491063
extension = pathlib.Path(model_sibling.rfilename).suffix[1:].upper()
1050-
if model_format == ModelFormat.SAFETENSORS:
1051-
if model_sibling.rfilename == AQUA_MODEL_ARTIFACT_CONFIG:
1052-
model_files.append(model_sibling.rfilename)
1053-
elif extension == model_format.value:
1064+
if (
1065+
model_format == ModelFormat.SAFETENSORS
1066+
and model_sibling.rfilename == AQUA_MODEL_ARTIFACT_CONFIG
1067+
):
1068+
model_files.append(model_sibling.rfilename)
1069+
if extension == model_format.value:
10541070
model_files.append(model_sibling.rfilename)
10551071

10561072
logger.debug(
@@ -1089,7 +1105,10 @@ def _validate_model(
10891105
safetensors_model_files = self.get_hf_model_files(
10901106
model_name, ModelFormat.SAFETENSORS
10911107
)
1092-
if safetensors_model_files:
1108+
if (
1109+
safetensors_model_files
1110+
and AQUA_MODEL_ARTIFACT_CONFIG in safetensors_model_files
1111+
):
10931112
hf_download_config_present = True
10941113
gguf_model_files = self.get_hf_model_files(model_name, ModelFormat.GGUF)
10951114
else:
@@ -1145,11 +1164,11 @@ def _validate_model(
11451164
Tags.LICENSE: license_value,
11461165
}
11471166
validation_result.tags = hf_tags
1148-
except Exception:
1167+
except Exception as ex:
11491168
logger.debug(
1150-
f"Could not process tags from Hugging Face model details for model {model_name}."
1169+
f"An error occurred while getting tag information for model {model_name}. "
1170+
f"Error: {str(ex)}"
11511171
)
1152-
pass
11531172

11541173
validation_result.model_formats = model_formats
11551174

@@ -1204,40 +1223,55 @@ def _validate_safetensor_format(
12041223
model_name: str = None,
12051224
):
12061225
if import_model_details.download_from_hf:
1207-
# validates config.json exists for safetensors model from hugginface
1208-
if not hf_download_config_present:
1226+
# validates config.json exists for safetensors model from huggingface
1227+
if not (
1228+
hf_download_config_present
1229+
or import_model_details.ignore_model_artifact_check
1230+
):
12091231
raise AquaRuntimeError(
12101232
f"The model {model_name} does not contain {AQUA_MODEL_ARTIFACT_CONFIG} file as required "
12111233
f"by {ModelFormat.SAFETENSORS.value} format model."
12121234
f" Please check if the model name is correct in Hugging Face repository."
12131235
)
1236+
validation_result.telemetry_model_name = model_name
12141237
else:
1238+
# validate if config.json is available from object storage, and get model name for telemetry
1239+
model_config = None
12151240
try:
12161241
model_config = load_config(
12171242
file_path=import_model_details.os_path,
12181243
config_file_name=AQUA_MODEL_ARTIFACT_CONFIG,
12191244
)
12201245
except Exception as ex:
1221-
logger.error(
1222-
f"Exception occurred while loading config file from {import_model_details.os_path}"
1223-
f"Exception message: {ex}"
1224-
)
1225-
raise AquaRuntimeError(
1246+
message = (
12261247
f"The model path {import_model_details.os_path} does not contain the file config.json. "
12271248
f"Please check if the path is correct or the model artifacts are available at this location."
1228-
) from ex
1229-
else:
1249+
)
1250+
if not import_model_details.ignore_model_artifact_check:
1251+
logger.error(
1252+
f"{message}\n"
1253+
f"Details: {ex.reason if isinstance(ex, AquaFileNotFoundError) else str(ex)}"
1254+
)
1255+
raise AquaRuntimeError(message) from ex
1256+
else:
1257+
logger.warning(
1258+
f"{message}\n"
1259+
f"Proceeding with model registration as ignore_model_artifact_check field is set."
1260+
)
1261+
1262+
if verified_model:
1263+
# model_type validation, log message if metadata field doesn't match.
12301264
try:
12311265
metadata_model_type = verified_model.custom_metadata_list.get(
12321266
AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE
12331267
).value
1234-
if metadata_model_type:
1268+
if metadata_model_type and model_config is not None:
12351269
if AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE in model_config:
12361270
if (
12371271
model_config[AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE]
12381272
!= metadata_model_type
12391273
):
1240-
raise AquaRuntimeError(
1274+
logger.debug(
12411275
f"The {AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE} attribute in {AQUA_MODEL_ARTIFACT_CONFIG}"
12421276
f" at {import_model_details.os_path} is invalid, expected {metadata_model_type} for "
12431277
f"the model {model_name}. Please check if the path is correct or "
@@ -1249,22 +1283,26 @@ def _validate_safetensor_format(
12491283
f"Could not find {AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE} attribute in "
12501284
f"{AQUA_MODEL_ARTIFACT_CONFIG}. Proceeding with model registration."
12511285
)
1252-
except Exception:
1253-
pass
1254-
if verified_model:
1255-
validation_result.telemetry_model_name = verified_model.display_name
1256-
elif (
1257-
model_config is not None
1258-
and AQUA_MODEL_ARTIFACT_CONFIG_MODEL_NAME in model_config
1259-
):
1260-
validation_result.telemetry_model_name = f"{AQUA_MODEL_TYPE_CUSTOM}_{model_config[AQUA_MODEL_ARTIFACT_CONFIG_MODEL_NAME]}"
1261-
elif (
1262-
model_config is not None
1263-
and AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE in model_config
1264-
):
1265-
validation_result.telemetry_model_name = f"{AQUA_MODEL_TYPE_CUSTOM}_{model_config[AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE]}"
1266-
else:
1267-
validation_result.telemetry_model_name = AQUA_MODEL_TYPE_CUSTOM
1286+
except Exception as ex:
1287+
# todo: raise exception if model_type doesn't match. Currently log message and pass since service
1288+
# models do not have this metadata.
1289+
logger.debug(
1290+
f"Error occurred while processing metadata for model {model_name}. "
1291+
f"Exception: {str(ex)}"
1292+
)
1293+
validation_result.telemetry_model_name = verified_model.display_name
1294+
elif (
1295+
model_config is not None
1296+
and AQUA_MODEL_ARTIFACT_CONFIG_MODEL_NAME in model_config
1297+
):
1298+
validation_result.telemetry_model_name = f"{AQUA_MODEL_TYPE_CUSTOM}_{model_config[AQUA_MODEL_ARTIFACT_CONFIG_MODEL_NAME]}"
1299+
elif (
1300+
model_config is not None
1301+
and AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE in model_config
1302+
):
1303+
validation_result.telemetry_model_name = f"{AQUA_MODEL_TYPE_CUSTOM}_{model_config[AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE]}"
1304+
else:
1305+
validation_result.telemetry_model_name = AQUA_MODEL_TYPE_CUSTOM
12681306

12691307
@staticmethod
12701308
def _validate_gguf_format(
@@ -1454,7 +1492,6 @@ def register(
14541492
).rstrip("/")
14551493
else:
14561494
artifact_path = import_model_details.os_path.rstrip("/")
1457-
14581495
# Create Model catalog entry with pass by reference
14591496
ds_model = self._create_model_catalog_entry(
14601497
os_path=artifact_path,

tests/unitary/with_extras/aqua/test_model.py

Lines changed: 61 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -920,10 +920,18 @@ def test_import_model_with_project_compartment_override(
920920
assert model.project_id == project_override
921921

922922
@pytest.mark.parametrize(
923-
"download_from_hf",
924-
[True, False],
923+
("ignore_artifact_check", "download_from_hf"),
924+
[
925+
(True, True),
926+
(True, False),
927+
(False, True),
928+
(False, False),
929+
(None, False),
930+
(None, True),
931+
],
925932
)
926933
@patch("ads.model.service.oci_datascience_model.OCIDataScienceModel.create")
934+
@patch("ads.model.datascience_model.DataScienceModel.sync")
927935
@patch("ads.model.datascience_model.DataScienceModel.upload_artifact")
928936
@patch("ads.common.object_storage_details.ObjectStorageDetails.list_objects")
929937
@patch("ads.aqua.common.utils.load_config", side_effect=AquaFileNotFoundError)
@@ -936,45 +944,65 @@ def test_import_model_with_missing_config(
936944
mock_load_config,
937945
mock_list_objects,
938946
mock_upload_artifact,
947+
mock_sync,
939948
mock_ocidsc_create,
940-
mock_get_container_config,
949+
ignore_artifact_check,
941950
download_from_hf,
942951
mock_get_hf_model_info,
943952
mock_init_client,
944953
):
945-
"""Test for validating if error is returned when model artifacts are incomplete or not available."""
946-
947-
os_path = "oci://aqua-bkt@aqua-ns/prefix/path"
948-
model_name = "oracle/aqua-1t-mega-model"
954+
my_model = "oracle/aqua-1t-mega-model"
949955
ObjectStorageDetails.is_bucket_versioned = MagicMock(return_value=True)
950-
mock_list_objects.return_value = MagicMock(objects=[])
951-
reload(ads.aqua.model.model)
952-
app = AquaModelApp()
953-
app.list = MagicMock(return_value=[])
956+
# set object list from OSS without config.json
957+
os_path = "oci://aqua-bkt@aqua-ns/prefix/path"
954958

959+
# set object list from HF without config.json
955960
if download_from_hf:
956-
with pytest.raises(AquaValueError):
957-
mock_get_hf_model_info.return_value.siblings = []
958-
with tempfile.TemporaryDirectory() as tmpdir:
959-
model: AquaModel = app.register(
960-
model=model_name,
961-
os_path=os_path,
962-
local_dir=str(tmpdir),
963-
download_from_hf=True,
964-
)
961+
mock_get_hf_model_info.return_value.siblings = [
962+
MagicMock(rfilename="model.safetensors")
963+
]
965964
else:
966-
with pytest.raises(AquaRuntimeError):
965+
obj1 = MagicMock(etag="12345-1234-1234-1234-123456789", size=150)
966+
obj1.name = f"prefix/path/model.safetensors"
967+
objects = [obj1]
968+
mock_list_objects.return_value = MagicMock(objects=objects)
969+
970+
reload(ads.aqua.model.model)
971+
app = AquaModelApp()
972+
with patch.object(AquaModelApp, "list") as aqua_model_mock_list:
973+
aqua_model_mock_list.return_value = [
974+
AquaModelSummary(
975+
id="test_id1",
976+
name="organization1/name1",
977+
organization="organization1",
978+
)
979+
]
980+
981+
if ignore_artifact_check:
967982
model: AquaModel = app.register(
968-
model=model_name,
983+
model=my_model,
969984
os_path=os_path,
970-
download_from_hf=False,
985+
inference_container="odsc-vllm-or-tgi-container",
986+
finetuning_container="odsc-llm-fine-tuning",
987+
download_from_hf=download_from_hf,
988+
ignore_model_artifact_check=ignore_artifact_check,
971989
)
990+
assert model.ready_to_deploy is True
991+
else:
992+
with pytest.raises(AquaRuntimeError):
993+
model: AquaModel = app.register(
994+
model=my_model,
995+
os_path=os_path,
996+
inference_container="odsc-vllm-or-tgi-container",
997+
finetuning_container="odsc-llm-fine-tuning",
998+
download_from_hf=download_from_hf,
999+
ignore_model_artifact_check=ignore_artifact_check,
1000+
)
9721001

9731002
@patch("ads.model.service.oci_datascience_model.OCIDataScienceModel.create")
9741003
@patch("ads.model.datascience_model.DataScienceModel.sync")
9751004
@patch("ads.model.datascience_model.DataScienceModel.upload_artifact")
9761005
@patch("ads.common.object_storage_details.ObjectStorageDetails.list_objects")
977-
@patch.object(HfApi, "model_info")
9781006
@patch("ads.aqua.common.utils.load_config", return_value={})
9791007
def test_import_any_model_smc_container(
9801008
self,
@@ -1230,6 +1258,15 @@ def test_import_model_with_input_tags(
12301258
"--download_from_hf True --inference_container odsc-vllm-serving --freeform_tags "
12311259
'{"ftag1": "fvalue1", "ftag2": "fvalue2"} --defined_tags {"dtag1": "dvalue1", "dtag2": "dvalue2"}',
12321260
),
1261+
(
1262+
{
1263+
"os_path": "oci://aqua-bkt@aqua-ns/path",
1264+
"model": "oracle/oracle-1it",
1265+
"inference_container": "odsc-vllm-serving",
1266+
"ignore_model_artifact_check": True,
1267+
},
1268+
"ads aqua model register --model oracle/oracle-1it --os_path oci://aqua-bkt@aqua-ns/path --download_from_hf True --inference_container odsc-vllm-serving --ignore_model_artifact_check True",
1269+
),
12331270
],
12341271
)
12351272
def test_import_cli(self, data, expected_output):

0 commit comments

Comments
 (0)