From 9cc430b345cc6d79ae911d3566c76e165e568c38 Mon Sep 17 00:00:00 2001 From: Vipul Date: Wed, 11 Dec 2024 17:26:44 -0800 Subject: [PATCH 01/18] add ignore validation flag while registering model --- ads/aqua/extension/model_handler.py | 5 ++ ads/aqua/model/entities.py | 1 + ads/aqua/model/model.py | 74 ++++++++++------- tests/unitary/with_extras/aqua/test_model.py | 85 ++++++++++++++------ 4 files changed, 113 insertions(+), 52 deletions(-) diff --git a/ads/aqua/extension/model_handler.py b/ads/aqua/extension/model_handler.py index 42f90ffef..c7d050e84 100644 --- a/ads/aqua/extension/model_handler.py +++ b/ads/aqua/extension/model_handler.py @@ -133,6 +133,10 @@ def post(self, *args, **kwargs): # noqa: ARG002 ignore_patterns = input_data.get("ignore_patterns") freeform_tags = input_data.get("freeform_tags") defined_tags = input_data.get("defined_tags") + ignore_model_artifact_check = ( + str(input_data.get("ignore_model_artifact_check", "false")).lower() + == "true" + ) return self.finish( AquaModelApp().register( @@ -149,6 +153,7 @@ def post(self, *args, **kwargs): # noqa: ARG002 ignore_patterns=ignore_patterns, freeform_tags=freeform_tags, defined_tags=defined_tags, + ignore_model_artifact_check=ignore_model_artifact_check, ) ) diff --git a/ads/aqua/model/entities.py b/ads/aqua/model/entities.py index ecdb8b8e7..2d6d93cd8 100644 --- a/ads/aqua/model/entities.py +++ b/ads/aqua/model/entities.py @@ -293,6 +293,7 @@ class ImportModelDetails(CLIBuilderMixin): ignore_patterns: Optional[List[str]] = None freeform_tags: Optional[dict] = None defined_tags: Optional[dict] = None + ignore_model_artifact_check: Optional[bool] = None def __post_init__(self): self._command = "model register" diff --git a/ads/aqua/model/model.py b/ads/aqua/model/model.py index 02e0df00f..857a6135d 100644 --- a/ads/aqua/model/model.py +++ b/ads/aqua/model/model.py @@ -972,6 +972,9 @@ def get_model_files(os_path: str, model_format: ModelFormat) -> List[str]: # todo: revisit this logic to account for .bin files. In the current state, .bin and .safetensor models # are grouped in one category and validation checks for config.json files only. if model_format == ModelFormat.SAFETENSORS: + model_files.extend( + list_os_files_with_extension(oss_path=os_path, extension=".safetensors") + ) try: load_config( file_path=os_path, @@ -1022,10 +1025,12 @@ def get_hf_model_files(model_name: str, model_format: ModelFormat) -> List[str]: for model_sibling in model_siblings: extension = pathlib.Path(model_sibling.rfilename).suffix[1:].upper() - if model_format == ModelFormat.SAFETENSORS: - if model_sibling.rfilename == AQUA_MODEL_ARTIFACT_CONFIG: - model_files.append(model_sibling.rfilename) - elif extension == model_format.value: + if ( + model_format == ModelFormat.SAFETENSORS + and model_sibling.rfilename == AQUA_MODEL_ARTIFACT_CONFIG + ): + model_files.append(model_sibling.rfilename) + if extension == model_format.value: model_files.append(model_sibling.rfilename) return model_files @@ -1061,7 +1066,10 @@ def _validate_model( safetensors_model_files = self.get_hf_model_files( model_name, ModelFormat.SAFETENSORS ) - if safetensors_model_files: + if ( + safetensors_model_files + and AQUA_MODEL_ARTIFACT_CONFIG in safetensors_model_files + ): hf_download_config_present = True gguf_model_files = self.get_hf_model_files(model_name, ModelFormat.GGUF) else: @@ -1173,14 +1181,20 @@ def _validate_safetensor_format( model_name: str = None, ): if import_model_details.download_from_hf: - # validates config.json exists for safetensors model from hugginface - if not hf_download_config_present: + # validates config.json exists for safetensors model from huggingface + if not ( + hf_download_config_present + or import_model_details.ignore_model_artifact_check + ): raise AquaRuntimeError( f"The model {model_name} does not contain {AQUA_MODEL_ARTIFACT_CONFIG} file as required " f"by {ModelFormat.SAFETENSORS.value} format model." f" Please check if the model name is correct in Hugging Face repository." ) + validation_result.telemetry_model_name = model_name else: + # validate if config.json is available from object storage, and get model name for telemetry + model_config = None try: model_config = load_config( file_path=import_model_details.os_path, @@ -1191,22 +1205,25 @@ def _validate_safetensor_format( f"Exception occurred while loading config file from {import_model_details.os_path}" f"Exception message: {ex}" ) - raise AquaRuntimeError( - f"The model path {import_model_details.os_path} does not contain the file config.json. " - f"Please check if the path is correct or the model artifacts are available at this location." - ) from ex - else: + if not import_model_details.ignore_model_artifact_check: + raise AquaRuntimeError( + f"The model path {import_model_details.os_path} does not contain the file config.json. " + f"Please check if the path is correct or the model artifacts are available at this location." + ) from ex + + if verified_model: + # model_type validation, log message if metadata field doesn't match. try: metadata_model_type = verified_model.custom_metadata_list.get( AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE ).value - if metadata_model_type: + if metadata_model_type and model_config is not None: if AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE in model_config: if ( model_config[AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE] != metadata_model_type ): - raise AquaRuntimeError( + logger.debug( f"The {AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE} attribute in {AQUA_MODEL_ARTIFACT_CONFIG}" f" at {import_model_details.os_path} is invalid, expected {metadata_model_type} for " f"the model {model_name}. Please check if the path is correct or " @@ -1219,21 +1236,22 @@ def _validate_safetensor_format( f"{AQUA_MODEL_ARTIFACT_CONFIG}. Proceeding with model registration." ) except Exception: + # todo: raise exception if model_type doesn't match. Currently log message and pass since service + # models do not have this metadata. pass - if verified_model: - validation_result.telemetry_model_name = verified_model.display_name - elif ( - model_config is not None - and AQUA_MODEL_ARTIFACT_CONFIG_MODEL_NAME in model_config - ): - validation_result.telemetry_model_name = f"{AQUA_MODEL_TYPE_CUSTOM}_{model_config[AQUA_MODEL_ARTIFACT_CONFIG_MODEL_NAME]}" - elif ( - model_config is not None - and AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE in model_config - ): - validation_result.telemetry_model_name = f"{AQUA_MODEL_TYPE_CUSTOM}_{model_config[AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE]}" - else: - validation_result.telemetry_model_name = AQUA_MODEL_TYPE_CUSTOM + validation_result.telemetry_model_name = verified_model.display_name + elif ( + model_config is not None + and AQUA_MODEL_ARTIFACT_CONFIG_MODEL_NAME in model_config + ): + validation_result.telemetry_model_name = f"{AQUA_MODEL_TYPE_CUSTOM}_{model_config[AQUA_MODEL_ARTIFACT_CONFIG_MODEL_NAME]}" + elif ( + model_config is not None + and AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE in model_config + ): + validation_result.telemetry_model_name = f"{AQUA_MODEL_TYPE_CUSTOM}_{model_config[AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE]}" + else: + validation_result.telemetry_model_name = AQUA_MODEL_TYPE_CUSTOM @staticmethod def _validate_gguf_format( diff --git a/tests/unitary/with_extras/aqua/test_model.py b/tests/unitary/with_extras/aqua/test_model.py index cabb8c523..569902c02 100644 --- a/tests/unitary/with_extras/aqua/test_model.py +++ b/tests/unitary/with_extras/aqua/test_model.py @@ -920,10 +920,18 @@ def test_import_model_with_project_compartment_override( assert model.project_id == project_override @pytest.mark.parametrize( - "download_from_hf", - [True, False], + ("ignore_artifact_check", "download_from_hf"), + [ + (True, True), + (True, False), + (False, True), + (False, False), + (None, False), + (None, True), + ], ) @patch("ads.model.service.oci_datascience_model.OCIDataScienceModel.create") + @patch("ads.model.datascience_model.DataScienceModel.sync") @patch("ads.model.datascience_model.DataScienceModel.upload_artifact") @patch("ads.common.object_storage_details.ObjectStorageDetails.list_objects") @patch("ads.aqua.common.utils.load_config", side_effect=AquaFileNotFoundError) @@ -936,45 +944,65 @@ def test_import_model_with_missing_config( mock_load_config, mock_list_objects, mock_upload_artifact, + mock_sync, mock_ocidsc_create, - mock_get_container_config, + ignore_artifact_check, download_from_hf, mock_get_hf_model_info, mock_init_client, ): - """Test for validating if error is returned when model artifacts are incomplete or not available.""" - - os_path = "oci://aqua-bkt@aqua-ns/prefix/path" - model_name = "oracle/aqua-1t-mega-model" + my_model = "oracle/aqua-1t-mega-model" ObjectStorageDetails.is_bucket_versioned = MagicMock(return_value=True) - mock_list_objects.return_value = MagicMock(objects=[]) - reload(ads.aqua.model.model) - app = AquaModelApp() - app.list = MagicMock(return_value=[]) + # set object list from OSS without config.json + os_path = "oci://aqua-bkt@aqua-ns/prefix/path" + # set object list from HF without config.json if download_from_hf: - with pytest.raises(AquaValueError): - mock_get_hf_model_info.return_value.siblings = [] - with tempfile.TemporaryDirectory() as tmpdir: - model: AquaModel = app.register( - model=model_name, - os_path=os_path, - local_dir=str(tmpdir), - download_from_hf=True, - ) + mock_get_hf_model_info.return_value.siblings = [ + MagicMock(rfilename="model.safetensors") + ] else: - with pytest.raises(AquaRuntimeError): + obj1 = MagicMock(etag="12345-1234-1234-1234-123456789", size=150) + obj1.name = f"prefix/path/model.safetensors" + objects = [obj1] + mock_list_objects.return_value = MagicMock(objects=objects) + + reload(ads.aqua.model.model) + app = AquaModelApp() + with patch.object(AquaModelApp, "list") as aqua_model_mock_list: + aqua_model_mock_list.return_value = [ + AquaModelSummary( + id="test_id1", + name="organization1/name1", + organization="organization1", + ) + ] + + if ignore_artifact_check: model: AquaModel = app.register( - model=model_name, + model=my_model, os_path=os_path, - download_from_hf=False, + inference_container="odsc-vllm-or-tgi-container", + finetuning_container="odsc-llm-fine-tuning", + download_from_hf=download_from_hf, + ignore_model_artifact_check=ignore_artifact_check, ) + assert model.ready_to_deploy is True + else: + with pytest.raises(AquaRuntimeError): + model: AquaModel = app.register( + model=my_model, + os_path=os_path, + inference_container="odsc-vllm-or-tgi-container", + finetuning_container="odsc-llm-fine-tuning", + download_from_hf=download_from_hf, + ignore_model_artifact_check=ignore_artifact_check, + ) @patch("ads.model.service.oci_datascience_model.OCIDataScienceModel.create") @patch("ads.model.datascience_model.DataScienceModel.sync") @patch("ads.model.datascience_model.DataScienceModel.upload_artifact") @patch("ads.common.object_storage_details.ObjectStorageDetails.list_objects") - @patch.object(HfApi, "model_info") @patch("ads.aqua.common.utils.load_config", return_value={}) def test_import_any_model_smc_container( self, @@ -1230,6 +1258,15 @@ def test_import_model_with_input_tags( "--download_from_hf True --inference_container odsc-vllm-serving --freeform_tags " '{"ftag1": "fvalue1", "ftag2": "fvalue2"} --defined_tags {"dtag1": "dvalue1", "dtag2": "dvalue2"}', ), + ( + { + "os_path": "oci://aqua-bkt@aqua-ns/path", + "model": "oracle/oracle-1it", + "inference_container": "odsc-vllm-serving", + "ignore_model_artifact_check": True, + }, + "ads aqua model register --model oracle/oracle-1it --os_path oci://aqua-bkt@aqua-ns/path --download_from_hf True --inference_container odsc-vllm-serving --ignore_model_artifact_check True", + ), ], ) def test_import_cli(self, data, expected_output): From bd4205e65a32a98f0dd1bcee9585284a5a6fc98f Mon Sep 17 00:00:00 2001 From: Vipul Date: Thu, 12 Dec 2024 12:10:07 -0800 Subject: [PATCH 02/18] update logging --- ads/aqua/model/model.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/ads/aqua/model/model.py b/ads/aqua/model/model.py index 857a6135d..89949b68f 100644 --- a/ads/aqua/model/model.py +++ b/ads/aqua/model/model.py @@ -980,8 +980,8 @@ def get_model_files(os_path: str, model_format: ModelFormat) -> List[str]: file_path=os_path, config_file_name=AQUA_MODEL_ARTIFACT_CONFIG, ) - except Exception: - pass + except Exception as ex: + logger.warning(str(ex)) else: model_files.append(AQUA_MODEL_ARTIFACT_CONFIG) @@ -1125,8 +1125,11 @@ def _validate_model( Tags.LICENSE: license_value, } validation_result.tags = hf_tags - except Exception: - pass + except Exception as ex: + logger.debug( + f"An error occurred while getting tag information for model {model_name}. " + f"Error: {str(ex)}" + ) validation_result.model_formats = model_formats @@ -1235,10 +1238,13 @@ def _validate_safetensor_format( f"Could not find {AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE} attribute in " f"{AQUA_MODEL_ARTIFACT_CONFIG}. Proceeding with model registration." ) - except Exception: + except Exception as ex: # todo: raise exception if model_type doesn't match. Currently log message and pass since service # models do not have this metadata. - pass + logger.debug( + f"Error occurred while processing metadata for model {model_name}. " + f"Exception: {str(ex)}" + ) validation_result.telemetry_model_name = verified_model.display_name elif ( model_config is not None From 6aaa3ef57c43c698f49455dfb629e78ae4e4646f Mon Sep 17 00:00:00 2001 From: Vipul Date: Sat, 14 Dec 2024 11:55:24 -0800 Subject: [PATCH 03/18] improve error message logging --- ads/aqua/model/model.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/ads/aqua/model/model.py b/ads/aqua/model/model.py index 89949b68f..6dc76488b 100644 --- a/ads/aqua/model/model.py +++ b/ads/aqua/model/model.py @@ -19,7 +19,11 @@ InferenceContainerTypeFamily, Tags, ) -from ads.aqua.common.errors import AquaRuntimeError, AquaValueError +from ads.aqua.common.errors import ( + AquaFileNotFoundError, + AquaRuntimeError, + AquaValueError, +) from ads.aqua.common.utils import ( LifecycleStatus, _build_resource_identifier, @@ -1206,7 +1210,9 @@ def _validate_safetensor_format( except Exception as ex: logger.error( f"Exception occurred while loading config file from {import_model_details.os_path}" - f"Exception message: {ex}" + ) + logger.error( + ex.reason if isinstance(ex, AquaFileNotFoundError) else str(ex) ) if not import_model_details.ignore_model_artifact_check: raise AquaRuntimeError( From 6b2e4aa28357b2ef678f1329c7bb26fa59f4f372 Mon Sep 17 00:00:00 2001 From: Vipul Date: Sat, 14 Dec 2024 12:25:08 -0800 Subject: [PATCH 04/18] update handler tests --- tests/unitary/with_extras/aqua/test_model_handler.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/tests/unitary/with_extras/aqua/test_model_handler.py b/tests/unitary/with_extras/aqua/test_model_handler.py index bf02174b9..16202f477 100644 --- a/tests/unitary/with_extras/aqua/test_model_handler.py +++ b/tests/unitary/with_extras/aqua/test_model_handler.py @@ -132,7 +132,7 @@ def test_list(self, mock_list): @parameterized.expand( [ - (None, None, False, None, None, None, None, None), + (None, None, False, None, None, None, None, None, True), ( "odsc-llm-fine-tuning", None, @@ -142,8 +142,9 @@ def test_list(self, mock_list): ["test.json"], None, None, + False, ), - (None, "test.gguf", True, None, ["*.json"], None, None, None), + (None, "test.gguf", True, None, ["*.json"], None, None, None, False), ( None, None, @@ -153,6 +154,7 @@ def test_list(self, mock_list): ["test.json"], None, None, + False, ), ( None, @@ -163,6 +165,7 @@ def test_list(self, mock_list): None, {"ftag1": "fvalue1"}, {"dtag1": "dvalue1"}, + False, ), ], ) @@ -178,6 +181,7 @@ def test_register( ignore_patterns, freeform_tags, defined_tags, + ignore_model_artifact_check, mock_register, mock_finish, ): @@ -201,6 +205,7 @@ def test_register( ignore_patterns=ignore_patterns, freeform_tags=freeform_tags, defined_tags=defined_tags, + ignore_model_artifact_check=ignore_model_artifact_check, ) ) result = self.model_handler.post() @@ -218,6 +223,7 @@ def test_register( ignore_patterns=ignore_patterns, freeform_tags=freeform_tags, defined_tags=defined_tags, + ignore_model_artifact_check=ignore_model_artifact_check, ) assert result["id"] == "test_id" assert result["inference_container"] == "odsc-tgi-serving" From 71269a7164f0bd0f593a5a56928f6d62c9af3a6c Mon Sep 17 00:00:00 2001 From: Vipul Date: Mon, 16 Dec 2024 13:27:03 -0800 Subject: [PATCH 05/18] update logging level --- ads/aqua/model/model.py | 32 +++++++++++++++++++++----------- 1 file changed, 21 insertions(+), 11 deletions(-) diff --git a/ads/aqua/model/model.py b/ads/aqua/model/model.py index 6dc76488b..a7952dde3 100644 --- a/ads/aqua/model/model.py +++ b/ads/aqua/model/model.py @@ -985,7 +985,14 @@ def get_model_files(os_path: str, model_format: ModelFormat) -> List[str]: config_file_name=AQUA_MODEL_ARTIFACT_CONFIG, ) except Exception as ex: - logger.warning(str(ex)) + message = ( + f"The model path {os_path} does not contain the file config.json. " + f"Please check if the path is correct or the model artifacts are available at this location." + ) + logger.warning( + f"{message}\n" + f"Details: {ex.reason if isinstance(ex, AquaFileNotFoundError) else str(ex)}\n" + ) else: model_files.append(AQUA_MODEL_ARTIFACT_CONFIG) @@ -1208,17 +1215,21 @@ def _validate_safetensor_format( config_file_name=AQUA_MODEL_ARTIFACT_CONFIG, ) except Exception as ex: - logger.error( - f"Exception occurred while loading config file from {import_model_details.os_path}" - ) - logger.error( - ex.reason if isinstance(ex, AquaFileNotFoundError) else str(ex) + message = ( + f"The model path {import_model_details.os_path} does not contain the file config.json. " + f"Please check if the path is correct or the model artifacts are available at this location." ) if not import_model_details.ignore_model_artifact_check: - raise AquaRuntimeError( - f"The model path {import_model_details.os_path} does not contain the file config.json. " - f"Please check if the path is correct or the model artifacts are available at this location." - ) from ex + logger.error( + f"{message}\n" + f"Details: {ex.reason if isinstance(ex, AquaFileNotFoundError) else str(ex)}" + ) + raise AquaRuntimeError(message) from ex + else: + logger.warning( + f"{message}\n" + f"Proceeding with model registration as ignore_model_artifact_check field is set." + ) if verified_model: # model_type validation, log message if metadata field doesn't match. @@ -1446,7 +1457,6 @@ def register( ).rstrip("/") else: artifact_path = import_model_details.os_path.rstrip("/") - # Create Model catalog entry with pass by reference ds_model = self._create_model_catalog_entry( os_path=artifact_path, From 59a042b9d6b96e2d017537fe79f413bf5b6c2384 Mon Sep 17 00:00:00 2001 From: Vipul Date: Fri, 3 Jan 2025 12:36:06 -0800 Subject: [PATCH 06/18] log request ids --- ads/aqua/extension/base_handler.py | 21 ++++++++++++--------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/ads/aqua/extension/base_handler.py b/ads/aqua/extension/base_handler.py index 5bd9f7091..68cf57b5a 100644 --- a/ads/aqua/extension/base_handler.py +++ b/ads/aqua/extension/base_handler.py @@ -1,6 +1,5 @@ #!/usr/bin/env python -# -*- coding: utf-8 -*- -# Copyright (c) 2024 Oracle and/or its affiliates. +# Copyright (c) 2025 Oracle and/or its affiliates. # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ @@ -35,7 +34,7 @@ def __init__( self.telemetry = TelemetryClient( bucket=AQUA_TELEMETRY_BUCKET, namespace=AQUA_TELEMETRY_BUCKET_NS ) - except: + except Exception: pass @staticmethod @@ -82,19 +81,23 @@ def write_error(self, status_code, **kwargs): "message": message, "service_payload": service_payload, "reason": reason, + "request_id": str(uuid.uuid4()), } exc_info = kwargs.get("exc_info") if exc_info: - logger.error("".join(traceback.format_exception(*exc_info))) + logger.error( + f"Error Request ID: {reply['request_id']}\n" + f"Error: {''.join(traceback.format_exception(*exc_info))}" + ) e = exc_info[1] if isinstance(e, HTTPError): reply["message"] = e.log_message or message reply["reason"] = e.reason if e.reason else reply["reason"] - reply["request_id"] = str(uuid.uuid4()) - else: - reply["request_id"] = str(uuid.uuid4()) - logger.warning(reply["message"]) + logger.error( + f"Error Request ID: {reply['request_id']}\n" + f"Error: {reply['message']} {reply['reason']}" + ) # telemetry may not be present if there is an error while initializing if hasattr(self, "telemetry"): @@ -103,7 +106,7 @@ def write_error(self, status_code, **kwargs): category="aqua/error", action=str(status_code), value=reason, - **aqua_api_details + **aqua_api_details, ) self.finish(json.dumps(reply)) From a8321a1eda8a779b4fb89585cd59e20d1af13ffa Mon Sep 17 00:00:00 2001 From: Vipul Date: Fri, 3 Jan 2025 12:41:54 -0800 Subject: [PATCH 07/18] update logging for model operations --- ads/aqua/model/model.py | 86 +++++++++++++++++++++++++++++++---------- 1 file changed, 65 insertions(+), 21 deletions(-) diff --git a/ads/aqua/model/model.py b/ads/aqua/model/model.py index 02e0df00f..1ac055522 100644 --- a/ads/aqua/model/model.py +++ b/ads/aqua/model/model.py @@ -1,5 +1,5 @@ #!/usr/bin/env python -# Copyright (c) 2024 Oracle and/or its affiliates. +# Copyright (c) 2025 Oracle and/or its affiliates. # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ import os import pathlib @@ -160,7 +160,7 @@ def create( target_compartment = compartment_id or COMPARTMENT_OCID if service_model.compartment_id != ODSC_MODEL_COMPARTMENT_OCID: - logger.debug( + logger.info( f"Aqua Model {model_id} already exists in user's compartment." "Skipped copying." ) @@ -191,8 +191,8 @@ def create( # TODO: decide what kwargs will be needed. .create(model_by_reference=True, **kwargs) ) - logger.debug( - f"Aqua Model {custom_model.id} created with the service model {model_id}" + logger.info( + f"Aqua Model {custom_model.id} created with the service model {model_id}." ) # tracks unique models that were created in the user compartment @@ -223,11 +223,16 @@ def get(self, model_id: str, load_model_card: Optional[bool] = True) -> "AquaMod cached_item = self._service_model_details_cache.get(model_id) if cached_item: + logger.info(f"Fetching model details for model {model_id} from cache.") return cached_item + logger.info(f"Fetching model details for model {model_id}.") ds_model = DataScienceModel.from_id(model_id) if not self._if_show(ds_model): - raise AquaRuntimeError(f"Target model `{ds_model.id} `is not Aqua model.") + raise AquaRuntimeError( + f"Target model `{ds_model.id} `is not an Aqua model as it does not contain " + f"{Tags.AQUA_TAG} tag." + ) is_fine_tuned_model = bool( ds_model.freeform_tags @@ -246,16 +251,21 @@ def get(self, model_id: str, load_model_card: Optional[bool] = True) -> "AquaMod ds_model.custom_metadata_list._to_oci_metadata() ) if artifact_path != UNKNOWN: + model_card_path = ( + f"{artifact_path.rstrip('/')}/config/{README}" + if is_verified_type + else f"{artifact_path.rstrip('/')}/{README}" + ) model_card = str( read_file( - file_path=( - f"{artifact_path.rstrip('/')}/config/{README}" - if is_verified_type - else f"{artifact_path.rstrip('/')}/{README}" - ), + file_path=model_card_path, auth=default_signer(), ) ) + if not model_card: + logger.warn( + f"Model card for {model_id} is empty or could not be loaded from {model_card_path}." + ) inference_container = ds_model.custom_metadata_list.get( ModelCustomMetadataFields.DEPLOYMENT_CONTAINER, @@ -301,9 +311,10 @@ def get(self, model_id: str, load_model_card: Optional[bool] = True) -> "AquaMod try: jobrun_ocid = ds_model.provenance_metadata.training_id jobrun = self.ds_client.get_job_run(jobrun_ocid).data - except Exception: + except Exception as e: logger.debug( f"Missing jobrun information in the provenance metadata of the given model {model_id}." + f"\nError: {str(e)}" ) jobrun = None @@ -312,7 +323,10 @@ def get(self, model_id: str, load_model_card: Optional[bool] = True) -> "AquaMod FineTuningCustomMetadata.FT_SOURCE ).value except ValueError as e: - logger.debug(str(e)) + logger.debug( + f"Custom metadata is missing {FineTuningCustomMetadata.FT_SOURCE} key for " + f"model {model_id}.\nError: {str(e)}" + ) source_id = UNKNOWN try: @@ -320,7 +334,10 @@ def get(self, model_id: str, load_model_card: Optional[bool] = True) -> "AquaMod FineTuningCustomMetadata.FT_SOURCE_NAME ).value except ValueError as e: - logger.debug(str(e)) + logger.debug( + f"Custom metadata is missing {FineTuningCustomMetadata.FT_SOURCE_NAME} key for " + f"model {model_id}.\nError: {str(e)}" + ) source_name = UNKNOWN source_identifier = _build_resource_identifier( @@ -370,6 +387,7 @@ def delete_model(self, model_id): Tags.AQUA_FINE_TUNED_MODEL_TAG, None ) if is_registered_model or is_fine_tuned_model: + logger.info(f"Deleting model {model_id}.") return ds_model.delete() else: raise AquaRuntimeError( @@ -447,6 +465,7 @@ def edit_registered_model(self, id, inference_container, enable_finetuning, task freeform_tags=freeform_tags, ) AquaApp().update_model(id, update_model_details) + logger.info(f"Updated model details for the model {id}.") else: raise AquaRuntimeError( f"Failed to edit model:{id}. Only registered unverified models can be edited." @@ -706,7 +725,7 @@ def list( ) logger.info( - f"Fetch {len(models)} model in compartment_id={compartment_id or ODSC_MODEL_COMPARTMENT_OCID}." + f"Fetched {len(models)} model in compartment_id={compartment_id or ODSC_MODEL_COMPARTMENT_OCID}." ) aqua_models = [] @@ -736,10 +755,12 @@ def clear_model_list_cache( dict with the key used, and True if cache has the key that needs to be deleted. """ res = {} - logger.info("Clearing _service_models_cache") with self._cache_lock: if ODSC_MODEL_COMPARTMENT_OCID in self._service_models_cache: self._service_models_cache.pop(key=ODSC_MODEL_COMPARTMENT_OCID) + logger.info( + f"Cleared models cache for service compartment {ODSC_MODEL_COMPARTMENT_OCID}." + ) res = { "key": { "compartment_id": ODSC_MODEL_COMPARTMENT_OCID, @@ -756,10 +777,10 @@ def clear_model_details_cache(self, model_id): dict with the key used, and True if cache has the key that needs to be deleted. """ res = {} - logger.info(f"Clearing _service_model_details_cache for {model_id}") with self._cache_lock: if model_id in self._service_model_details_cache: self._service_model_details_cache.pop(key=model_id) + logger.info(f"Clearing model details cache for model {model_id}.") res = {"key": {"model_id": model_id}, "cache_deleted": True} return res @@ -844,7 +865,8 @@ def _create_model_catalog_entry( metadata = ModelCustomMetadata() if not inference_container: raise AquaRuntimeError( - f"Require Inference container information. Model: {model_name} does not have associated inference container defaults. Check docs for more information on how to pass inference container." + f"Require Inference container information. Model: {model_name} does not have associated inference " + f"container defaults. Check docs for more information on how to pass inference container." ) metadata.add( key=AQUA_DEPLOYMENT_CONTAINER_METADATA_NAME, @@ -920,7 +942,7 @@ def _create_model_catalog_entry( artifact_path = metadata.get(MODEL_BY_REFERENCE_OSS_PATH_KEY).value logger.info( f"Found model artifact in the service bucket. " - f"Using artifact from service bucket instead of {os_path}" + f"Using artifact from service bucket instead of {os_path}." ) # todo: implement generic copy_folder method @@ -952,7 +974,7 @@ def _create_model_catalog_entry( .with_freeform_tags(**tags) .with_defined_tags(**(defined_tags or {})) ).create(model_by_reference=True) - logger.debug(model) + logger.debug(f"Created model catalog entry for the model:\n{model}") return model @staticmethod @@ -986,6 +1008,9 @@ def get_model_files(os_path: str, model_format: ModelFormat) -> List[str]: model_files.extend( list_os_files_with_extension(oss_path=os_path, extension=".gguf") ) + logger.debug( + f"Fetched {len(model_files)} model files from {os_path} for model format {model_format}." + ) return model_files @staticmethod @@ -1028,6 +1053,9 @@ def get_hf_model_files(model_name: str, model_format: ModelFormat) -> List[str]: elif extension == model_format.value: model_files.append(model_sibling.rfilename) + logger.debug( + f"Fetched {len(model_files)} model files for the model {model_name} for model format {model_format}." + ) return model_files def _validate_model( @@ -1118,6 +1146,9 @@ def _validate_model( } validation_result.tags = hf_tags except Exception: + logger.debug( + f"Could not process tags from Hugging Face model details for model {model_name}." + ) pass validation_result.model_formats = model_formats @@ -1329,6 +1360,9 @@ def _download_model_from_hf( local_dir = os.path.join(os.path.expanduser("~"), "cached-model") local_dir = os.path.join(local_dir, model_name) os.makedirs(local_dir, exist_ok=True) + logger.debug( + f"Downloading artifacts from Hugging Face to local directory {local_dir}." + ) snapshot_download( repo_id=model_name, local_dir=local_dir, @@ -1336,6 +1370,9 @@ def _download_model_from_hf( ignore_patterns=ignore_patterns, ) # Upload to object storage and skip .cache/huggingface/ folder + logger.debug( + f"Uploading local artifacts from local directory {local_dir} to {os_path}." + ) model_artifact_path = upload_folder( os_path=os_path, local_dir=local_dir, @@ -1379,6 +1416,7 @@ def register( import_model_details.model.startswith("ocid") and "datasciencemodel" in import_model_details.model ): + logger.info(f"Fetching details for model {import_model_details.model}.") verified_model = DataScienceModel.from_id(import_model_details.model) else: # If users passes model name, check if there is model with the same name in the service model catalog. If it is there, then use that model @@ -1501,7 +1539,7 @@ def _rqs(self, compartment_id: str, model_type="FT", **kwargs): elif model_type == ModelType.BASE: filter_tag = Tags.BASE_MODEL_CUSTOM else: - raise ValueError( + raise AquaValueError( f"Model of type {model_type} is unknown. The values should be in {ModelType.values()}" ) @@ -1541,7 +1579,10 @@ def load_license(self, model_id: str) -> AquaModelLicense: oci_model = self.ds_client.get_model(model_id).data artifact_path = get_artifact_path(oci_model.custom_metadata_list) if not artifact_path: - raise AquaRuntimeError("Failed to get artifact path from custom metadata.") + raise AquaRuntimeError( + f"License could not be loaded. Failed to get artifact path from custom metadata for" + f"the model {model_id}." + ) content = str( read_file( @@ -1572,6 +1613,9 @@ def _find_matching_aqua_model(self, model_id: str) -> Optional[str]: for aqua_model_summary in aqua_model_list: if aqua_model_summary.name.lower() == model_id_lower: + logger.info( + f"Found matching verified model id {aqua_model_summary.id} for the model {model_id}" + ) return aqua_model_summary.id return None From ee8dbf865368f62ce62128ebfca46b0416cbfb97 Mon Sep 17 00:00:00 2001 From: Vipul Date: Fri, 3 Jan 2025 15:12:57 -0800 Subject: [PATCH 08/18] update logging for deployment operations --- ads/aqua/app.py | 4 ++-- ads/aqua/extension/base_handler.py | 2 +- ads/aqua/model/model.py | 2 +- ads/aqua/modeldeployment/deployment.py | 30 ++++++++++++++++++-------- 4 files changed, 25 insertions(+), 13 deletions(-) diff --git a/ads/aqua/app.py b/ads/aqua/app.py index a7a6165d8..253996268 100644 --- a/ads/aqua/app.py +++ b/ads/aqua/app.py @@ -1,5 +1,5 @@ #!/usr/bin/env python -# Copyright (c) 2024 Oracle and/or its affiliates. +# Copyright (c) 2024, 2025 Oracle and/or its affiliates. # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ import json @@ -298,7 +298,7 @@ def get_config(self, model_id: str, config_file_name: str) -> Dict: config = {} artifact_path = get_artifact_path(oci_model.custom_metadata_list) if not artifact_path: - logger.error( + logger.debug( f"Failed to get artifact path from custom metadata for the model: {model_id}" ) return config diff --git a/ads/aqua/extension/base_handler.py b/ads/aqua/extension/base_handler.py index 68cf57b5a..19dda9ce5 100644 --- a/ads/aqua/extension/base_handler.py +++ b/ads/aqua/extension/base_handler.py @@ -1,5 +1,5 @@ #!/usr/bin/env python -# Copyright (c) 2025 Oracle and/or its affiliates. +# Copyright (c) 2024, 2025 Oracle and/or its affiliates. # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ diff --git a/ads/aqua/model/model.py b/ads/aqua/model/model.py index 1ac055522..93139939d 100644 --- a/ads/aqua/model/model.py +++ b/ads/aqua/model/model.py @@ -1,5 +1,5 @@ #!/usr/bin/env python -# Copyright (c) 2025 Oracle and/or its affiliates. +# Copyright (c) 2024, 2025 Oracle and/or its affiliates. # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ import os import pathlib diff --git a/ads/aqua/modeldeployment/deployment.py b/ads/aqua/modeldeployment/deployment.py index b7787ea21..179af9a7f 100644 --- a/ads/aqua/modeldeployment/deployment.py +++ b/ads/aqua/modeldeployment/deployment.py @@ -1,8 +1,7 @@ #!/usr/bin/env python -# Copyright (c) 2024 Oracle and/or its affiliates. +# Copyright (c) 2024, 2025 Oracle and/or its affiliates. # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ -import logging import shlex from typing import Dict, List, Optional, Union @@ -271,7 +270,7 @@ def create( f"field. Either re-register the model with custom container URI, or set container_image_uri " f"parameter when creating this deployment." ) from err - logging.info( + logger.info( f"Aqua Image used for deploying {aqua_model.id} : {container_image_uri}" ) @@ -282,14 +281,14 @@ def create( default_cmd_var = shlex.split(cmd_var_string) if default_cmd_var: cmd_var = validate_cmd_var(default_cmd_var, cmd_var) - logging.info(f"CMD used for deploying {aqua_model.id} :{cmd_var}") + logger.info(f"CMD used for deploying {aqua_model.id} :{cmd_var}") except ValueError: - logging.debug( + logger.debug( f"CMD will be ignored for this deployment as {AQUA_DEPLOYMENT_CONTAINER_CMD_VAR_METADATA_NAME} " f"key is not available in the custom metadata field for this model." ) except Exception as e: - logging.error( + logger.error( f"There was an issue processing CMD arguments. Error: {str(e)}" ) @@ -385,7 +384,7 @@ def create( if key not in env_var: env_var.update(env) - logging.info(f"Env vars used for deploying {aqua_model.id} :{env_var}") + logger.info(f"Env vars used for deploying {aqua_model.id} :{env_var}") # Start model deployment # configure model deployment infrastructure @@ -440,10 +439,14 @@ def create( .with_runtime(container_runtime) ).deploy(wait_for_completion=False) + deployment_id = deployment.dsc_model_deployment.id + logger.info( + f"Aqua model deployment {deployment_id} created for model {aqua_model.id}." + ) model_type = ( AQUA_MODEL_TYPE_CUSTOM if is_fine_tuned_model else AQUA_MODEL_TYPE_SERVICE ) - deployment_id = deployment.dsc_model_deployment.id + # we arbitrarily choose last 8 characters of OCID to identify MD in telemetry telemetry_kwargs = {"ocid": get_ocid_substring(deployment_id, key_len=8)} @@ -539,6 +542,9 @@ def list(self, **kwargs) -> List["AquaDeployment"]: value=state, ) + logger.info( + f"Fetched {len(results)} model deployments from compartment_id={compartment_id}." + ) # tracks number of times deployment listing was called self.telemetry.record_event_async(category="aqua/deployment", action="list") @@ -546,18 +552,21 @@ def list(self, **kwargs) -> List["AquaDeployment"]: @telemetry(entry_point="plugin=deployment&action=delete", name="aqua") def delete(self, model_deployment_id: str): + logger.info(f"Deleting model deployment {model_deployment_id}.") return self.ds_client.delete_model_deployment( model_deployment_id=model_deployment_id ).data @telemetry(entry_point="plugin=deployment&action=deactivate", name="aqua") def deactivate(self, model_deployment_id: str): + logger.info(f"Deactivating model deployment {model_deployment_id}.") return self.ds_client.deactivate_model_deployment( model_deployment_id=model_deployment_id ).data @telemetry(entry_point="plugin=deployment&action=activate", name="aqua") def activate(self, model_deployment_id: str): + logger.info(f"Activating model deployment {model_deployment_id}.") return self.ds_client.activate_model_deployment( model_deployment_id=model_deployment_id ).data @@ -579,6 +588,8 @@ def get(self, model_deployment_id: str, **kwargs) -> "AquaDeploymentDetail": AquaDeploymentDetail: The instance of the Aqua model deployment details. """ + logger.info(f"Fetching model deployment details for {model_deployment_id}.") + model_deployment = self.ds_client.get_model_deployment( model_deployment_id=model_deployment_id, **kwargs ).data @@ -594,7 +605,8 @@ def get(self, model_deployment_id: str, **kwargs) -> "AquaDeploymentDetail": if not oci_aqua: raise AquaRuntimeError( - f"Target deployment {model_deployment_id} is not Aqua deployment." + f"Target deployment {model_deployment_id} is not Aqua deployment as it does not contain " + f"{Tags.AQUA_TAG} tag." ) log_id = "" From 68f325a332dfb13227097668d73dedfa59a2ca04 Mon Sep 17 00:00:00 2001 From: Vipul Date: Fri, 3 Jan 2025 16:22:06 -0800 Subject: [PATCH 09/18] update logging for deployment operations --- ads/aqua/modeldeployment/deployment.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ads/aqua/modeldeployment/deployment.py b/ads/aqua/modeldeployment/deployment.py index 179af9a7f..c65858b53 100644 --- a/ads/aqua/modeldeployment/deployment.py +++ b/ads/aqua/modeldeployment/deployment.py @@ -664,7 +664,7 @@ def get_deployment_config(self, model_id: str) -> Dict: config = self.get_config(model_id, AQUA_MODEL_DEPLOYMENT_CONFIG) if not config: logger.debug( - f"Deployment config for custom model: {model_id} is not available." + f"Deployment config for custom model: {model_id} is not available. Use defaults." ) return config From 86ef7ada9d5cb1160868c5b0dc00ebfe8cfbfc68 Mon Sep 17 00:00:00 2001 From: Vipul Date: Fri, 3 Jan 2025 16:28:00 -0800 Subject: [PATCH 10/18] update logging for finetuning operations --- ads/aqua/common/utils.py | 51 +++++++++++++++++++++++-- ads/aqua/extension/finetune_handler.py | 16 ++++---- ads/aqua/finetuning/finetuning.py | 52 ++++++++------------------ 3 files changed, 71 insertions(+), 48 deletions(-) diff --git a/ads/aqua/common/utils.py b/ads/aqua/common/utils.py index 6e1e09aca..4e885c32f 100644 --- a/ads/aqua/common/utils.py +++ b/ads/aqua/common/utils.py @@ -1,5 +1,5 @@ #!/usr/bin/env python -# Copyright (c) 2024 Oracle and/or its affiliates. +# Copyright (c) 2024, 2025 Oracle and/or its affiliates. # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ """AQUA utils and constants.""" @@ -12,11 +12,12 @@ import re import shlex import subprocess +from dataclasses import fields from datetime import datetime, timedelta from functools import wraps from pathlib import Path from string import Template -from typing import List, Union +from typing import Any, List, Optional, Type, TypeVar, Union import fsspec import oci @@ -74,6 +75,7 @@ from ads.model import DataScienceModel, ModelVersionSet logger = logging.getLogger("ads.aqua") +T = TypeVar("T") class LifecycleStatus(str, metaclass=ExtendedEnumMeta): @@ -788,7 +790,9 @@ def get_ocid_substring(ocid: str, key_len: int) -> str: return ocid[-key_len:] if ocid and len(ocid) > key_len else "" -def upload_folder(os_path: str, local_dir: str, model_name: str, exclude_pattern: str = None) -> str: +def upload_folder( + os_path: str, local_dir: str, model_name: str, exclude_pattern: str = None +) -> str: """Upload the local folder to the object storage Args: @@ -1159,3 +1163,44 @@ def validate_cmd_var(cmd_var: List[str], overrides: List[str]) -> List[str]: combined_cmd_var = cmd_var + overrides return combined_cmd_var + + +def validate_dataclass_params(dataclass_type: Type[T], **kwargs: Any) -> Optional[T]: + """This method tries to initialize a dataclass with the provided keyword arguments. It handles + errors related to missing, unexpected or invalid arguments. + + Parameters + ---------- + dataclass_type (Type[T]): + the dataclass type to instantiate. + kwargs (Any): + the keyword arguments to initialize the dataclass + Returns + ------- + Optional[T] + instance of dataclass if successfully initialized + """ + + try: + return dataclass_type(**kwargs) + except TypeError as ex: + error_message = str(ex) + allowed_params = ", ".join( + field.name for field in fields(dataclass_type) + ).rstrip() + if "__init__() missing" in error_message: + missing_params = error_message.split("missing ")[1] + raise AquaValueError( + "Error: Missing required parameters: " + f"{missing_params}. Allowable parameters are: {allowed_params}." + ) from ex + elif "__init__() got an unexpected keyword argument" in error_message: + unexpected_param = error_message.split("argument '")[1].rstrip("'") + raise AquaValueError( + "Error: Unexpected parameter: " + f"{unexpected_param}. Allowable parameters are: {allowed_params}." + ) from ex + else: + raise AquaValueError( + "Invalid parameters. Allowable parameters are: " f"{allowed_params}." + ) from ex diff --git a/ads/aqua/extension/finetune_handler.py b/ads/aqua/extension/finetune_handler.py index c8ebc5916..50400e04c 100644 --- a/ads/aqua/extension/finetune_handler.py +++ b/ads/aqua/extension/finetune_handler.py @@ -1,5 +1,5 @@ #!/usr/bin/env python -# Copyright (c) 2024 Oracle and/or its affiliates. +# Copyright (c) 2024, 2025 Oracle and/or its affiliates. # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ @@ -33,7 +33,7 @@ def get(self, id=""): raise HTTPError(400, f"The request {self.request.path} is invalid.") @handle_exceptions - def post(self, *args, **kwargs): + def post(self, *args, **kwargs): # noqa: ARG002 """Handles post request for the fine-tuning API Raises @@ -43,8 +43,8 @@ def post(self, *args, **kwargs): """ try: input_data = self.get_json_body() - except Exception: - raise HTTPError(400, Errors.INVALID_INPUT_DATA_FORMAT) + except Exception as ex: + raise HTTPError(400, Errors.INVALID_INPUT_DATA_FORMAT) from ex if not input_data: raise HTTPError(400, Errors.NO_INPUT_DATA) @@ -71,7 +71,7 @@ def get(self, model_id): ) @handle_exceptions - def post(self, *args, **kwargs): + def post(self, *args, **kwargs): # noqa: ARG002 """Handles post request for the finetuning param handler API. Raises @@ -81,15 +81,15 @@ def post(self, *args, **kwargs): """ try: input_data = self.get_json_body() - except Exception: - raise HTTPError(400, Errors.INVALID_INPUT_DATA_FORMAT) + except Exception as ex: + raise HTTPError(400, Errors.INVALID_INPUT_DATA_FORMAT) from ex if not input_data: raise HTTPError(400, Errors.NO_INPUT_DATA) params = input_data.get("params", None) return self.finish( - AquaFineTuningApp().validate_finetuning_params( + AquaFineTuningApp.validate_finetuning_params( params=params, ) ) diff --git a/ads/aqua/finetuning/finetuning.py b/ads/aqua/finetuning/finetuning.py index 5ff03276b..9a9811817 100644 --- a/ads/aqua/finetuning/finetuning.py +++ b/ads/aqua/finetuning/finetuning.py @@ -1,10 +1,10 @@ #!/usr/bin/env python -# Copyright (c) 2024 Oracle and/or its affiliates. +# Copyright (c) 2024, 2025 Oracle and/or its affiliates. # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ import json import os -from dataclasses import MISSING, asdict, fields +from dataclasses import asdict, fields from typing import Dict from oci.data_science.models import ( @@ -20,6 +20,7 @@ from ads.aqua.common.utils import ( get_container_image, upload_local_to_os, + validate_dataclass_params, ) from ads.aqua.constants import ( DEFAULT_FT_BATCH_SIZE, @@ -102,26 +103,10 @@ def create( The instance of AquaFineTuningSummary. """ if not create_fine_tuning_details: - try: - create_fine_tuning_details = CreateFineTuningDetails(**kwargs) - except Exception as ex: - allowed_create_fine_tuning_details = ", ".join( - field.name for field in fields(CreateFineTuningDetails) - ).rstrip() - raise AquaValueError( - "Invalid create fine tuning parameters. Allowable parameters are: " - f"{allowed_create_fine_tuning_details}." - ) from ex + validate_dataclass_params(CreateFineTuningDetails, **kwargs) source = self.get_source(create_fine_tuning_details.ft_source_id) - # todo: revisit validation for fine tuned models - # if source.compartment_id != ODSC_MODEL_COMPARTMENT_OCID: - # raise AquaValueError( - # f"Fine tuning is only supported for Aqua service models in {ODSC_MODEL_COMPARTMENT_OCID}. " - # "Use a valid Aqua service model id instead." - # ) - target_compartment = ( create_fine_tuning_details.compartment_id or COMPARTMENT_OCID ) @@ -401,6 +386,9 @@ def create( defined_tags=model_defined_tags, ), ) + logger.debug( + f"Successfully updated model custom metadata list and freeform tags for the model {ft_model.id}." + ) self.update_model_provenance( model_id=ft_model.id, @@ -408,6 +396,9 @@ def create( training_id=ft_job_run.id ), ) + logger.debug( + f"Successfully updated model provenance for the model {ft_model.id}." + ) # tracks the shape and replica used for fine-tuning the service models telemetry_kwargs = ( @@ -587,7 +578,7 @@ def get_finetuning_config(self, model_id: str) -> Dict: config = self.get_config(model_id, AQUA_MODEL_FINETUNING_CONFIG) if not config: logger.debug( - f"Fine-tuning config for custom model: {model_id} is not available." + f"Fine-tuning config for custom model: {model_id} is not available. Use defaults." ) return config @@ -622,7 +613,8 @@ def get_finetuning_default_params(self, model_id: str) -> Dict: return default_params - def validate_finetuning_params(self, params: Dict = None) -> Dict: + @staticmethod + def validate_finetuning_params(params: Dict = None) -> Dict: """Validate if the fine-tuning parameters passed by the user can be overridden. Parameter values are not validated, only param keys are validated. @@ -633,21 +625,7 @@ def validate_finetuning_params(self, params: Dict = None) -> Dict: Returns ------- - Return a list of restricted params. + Return a dict with value true if valid, else raises AquaValueError. """ - try: - AquaFineTuningParams( - **params, - ) - except Exception as e: - logger.debug(str(e)) - allowed_fine_tuning_parameters = ", ".join( - f"{field.name} (required)" if field.default is MISSING else field.name - for field in fields(AquaFineTuningParams) - ).rstrip() - raise AquaValueError( - f"Invalid fine tuning parameters. Allowable parameters are: " - f"{allowed_fine_tuning_parameters}." - ) from e - + validate_dataclass_params(AquaFineTuningParams, **(params or {})) return {"valid": True} From adace880fb886f85237405676cf84e9726a10fb8 Mon Sep 17 00:00:00 2001 From: Vipul Date: Fri, 3 Jan 2025 16:50:06 -0800 Subject: [PATCH 11/18] update logging for finetuning operations --- ads/aqua/finetuning/finetuning.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/ads/aqua/finetuning/finetuning.py b/ads/aqua/finetuning/finetuning.py index 9a9811817..04354db0c 100644 --- a/ads/aqua/finetuning/finetuning.py +++ b/ads/aqua/finetuning/finetuning.py @@ -103,7 +103,9 @@ def create( The instance of AquaFineTuningSummary. """ if not create_fine_tuning_details: - validate_dataclass_params(CreateFineTuningDetails, **kwargs) + create_fine_tuning_details = validate_dataclass_params( + CreateFineTuningDetails, **kwargs + ) source = self.get_source(create_fine_tuning_details.ft_source_id) From e3cb8d5432a13004abf62fe765f97756f4a8b435 Mon Sep 17 00:00:00 2001 From: Vipul Date: Mon, 6 Jan 2025 12:11:18 -0800 Subject: [PATCH 12/18] update evaluation validation for create API --- ads/aqua/common/utils.py | 8 +++++++- ads/aqua/evaluation/evaluation.py | 16 +++++----------- 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/ads/aqua/common/utils.py b/ads/aqua/common/utils.py index 4e885c32f..04650db20 100644 --- a/ads/aqua/common/utils.py +++ b/ads/aqua/common/utils.py @@ -31,6 +31,7 @@ ) from oci.data_science.models import JobRun, Model from oci.object_storage.models import ObjectSummary +from pydantic import BaseModel, ValidationError from ads.aqua.common.enums import ( InferenceContainerParamType, @@ -75,7 +76,7 @@ from ads.model import DataScienceModel, ModelVersionSet logger = logging.getLogger("ads.aqua") -T = TypeVar("T") +T = TypeVar("T", bound=Union[BaseModel, Any]) class LifecycleStatus(str, metaclass=ExtendedEnumMeta): @@ -1204,3 +1205,8 @@ def validate_dataclass_params(dataclass_type: Type[T], **kwargs: Any) -> Optiona raise AquaValueError( "Invalid parameters. Allowable parameters are: " f"{allowed_params}." ) from ex + except ValidationError as ex: + custom_errors = {".".join(map(str, e["loc"])): e["msg"] for e in ex.errors()} + raise AquaValueError( + f"Invalid parameters. Error details: {custom_errors}." + ) from ex diff --git a/ads/aqua/evaluation/evaluation.py b/ads/aqua/evaluation/evaluation.py index 0b7cb7773..12ed45da4 100644 --- a/ads/aqua/evaluation/evaluation.py +++ b/ads/aqua/evaluation/evaluation.py @@ -1,5 +1,5 @@ #!/usr/bin/env python -# Copyright (c) 2024 Oracle and/or its affiliates. +# Copyright (c) 2024, 2025 Oracle and/or its affiliates. # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ import base64 import json @@ -43,6 +43,7 @@ get_container_image, is_valid_ocid, upload_local_to_os, + validate_dataclass_params, ) from ads.aqua.config.config import get_evaluation_service_config from ads.aqua.constants import ( @@ -155,16 +156,9 @@ def create( The instance of AquaEvaluationSummary. """ if not create_aqua_evaluation_details: - try: - create_aqua_evaluation_details = CreateAquaEvaluationDetails(**kwargs) - except Exception as ex: - custom_errors = { - ".".join(map(str, e["loc"])): e["msg"] - for e in json.loads(ex.json()) - } - raise AquaValueError( - f"Invalid create evaluation parameters. Error details: {custom_errors}." - ) from ex + create_aqua_evaluation_details = validate_dataclass_params( + CreateAquaEvaluationDetails, **kwargs + ) if not is_valid_ocid(create_aqua_evaluation_details.evaluation_source_id): raise AquaValueError( From b0827e5dd883e663ef7c78981f68efdf01db8e5e Mon Sep 17 00:00:00 2001 From: Vipul Date: Mon, 6 Jan 2025 14:03:20 -0800 Subject: [PATCH 13/18] update evaluation logging --- ads/aqua/evaluation/evaluation.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/ads/aqua/evaluation/evaluation.py b/ads/aqua/evaluation/evaluation.py index 12ed45da4..9521d10c8 100644 --- a/ads/aqua/evaluation/evaluation.py +++ b/ads/aqua/evaluation/evaluation.py @@ -193,11 +193,11 @@ def create( eval_inference_configuration = ( container.spec.evaluation_configuration ) - except Exception: + except Exception as ex: logger.debug( f"Could not load inference config details for the evaluation source id: " f"{create_aqua_evaluation_details.evaluation_source_id}. Please check if the container" - f" runtime has the correct SMC image information." + f" runtime has the correct SMC image information.\nError: {str(ex)}" ) elif ( DataScienceResource.MODEL @@ -283,7 +283,7 @@ def create( f"Invalid experiment name. Please provide an experiment with `{Tags.AQUA_EVALUATION}` in tags." ) except Exception: - logger.debug( + logger.info( f"Model version set {experiment_model_version_set_name} doesn't exist. " "Creating new model version set." ) @@ -705,14 +705,16 @@ def get(self, eval_id) -> AquaEvaluationDetail: try: log = utils.query_resource(log_id, return_all=False) log_name = log.display_name if log else "" - except Exception: + except Exception as ex: + logger.debug(f"Failed to get associated log name. Error: {ex}") pass if loggroup_id: try: loggroup = utils.query_resource(loggroup_id, return_all=False) loggroup_name = loggroup.display_name if loggroup else "" - except Exception: + except Exception as ex: + logger.debug(f"Failed to get associated loggroup name. Error: {ex}") pass try: @@ -1041,6 +1043,7 @@ def download_report(self, eval_id) -> AquaEvalReport: return report with tempfile.TemporaryDirectory() as temp_dir: + logger.info(f"Downloading evaluation artifact for {eval_id}.") DataScienceModel.from_id(eval_id).download_artifact( temp_dir, auth=self._auth, @@ -1194,6 +1197,7 @@ def _delete_job_and_model(job, model): def load_evaluation_config(self, container: Optional[str] = None) -> Dict: """Loads evaluation config.""" + logger.info("Loading evaluation container config.") # retrieve the evaluation config by container family name evaluation_config = get_evaluation_service_config(container) @@ -1273,9 +1277,9 @@ def _get_source( raise AquaRuntimeError( f"Not supported source type: {resource_type}" ) - except Exception: + except Exception as ex: logger.debug( - f"Failed to retrieve source information for evaluation {evaluation.identifier}." + f"Failed to retrieve source information for evaluation {evaluation.identifier}.\nError: {ex}" ) source_name = "" From ed9504d2a98313fcea86b234e549e5101b756746 Mon Sep 17 00:00:00 2001 From: Vipul Date: Mon, 6 Jan 2025 14:39:30 -0800 Subject: [PATCH 14/18] update evaluation logging --- ads/aqua/evaluation/evaluation.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/ads/aqua/evaluation/evaluation.py b/ads/aqua/evaluation/evaluation.py index 9521d10c8..8dde31778 100644 --- a/ads/aqua/evaluation/evaluation.py +++ b/ads/aqua/evaluation/evaluation.py @@ -721,7 +721,11 @@ def get(self, eval_id) -> AquaEvaluationDetail: introspection = json.loads( self._get_attribute_from_model_metadata(resource, "ArtifactTestResults") ) - except Exception: + except Exception as ex: + logger.debug( + f"There was an issue loading the model attribute as json object for evaluation {eval_id}. " + f"Setting introspection to empty.\n Error:{ex}" + ) introspection = {} summary = AquaEvaluationDetail( @@ -874,13 +878,13 @@ def get_status(self, eval_id: str) -> dict: try: log_id = job_run_details.log_details.log_id except Exception as e: - logger.debug(f"Failed to get associated log. {str(e)}") + logger.debug(f"Failed to get associated log.\nError: {str(e)}") log_id = "" try: loggroup_id = job_run_details.log_details.log_group_id except Exception as e: - logger.debug(f"Failed to get associated log. {str(e)}") + logger.debug(f"Failed to get associated log.\nError: {str(e)}") loggroup_id = "" loggroup_url = get_log_links(region=self.region, log_group_id=loggroup_id) @@ -954,7 +958,7 @@ def load_metrics(self, eval_id: str) -> AquaEvalMetrics: ) except Exception as e: logger.debug( - "Failed to load `report.json` from evaluation artifact" f"{str(e)}" + f"Failed to load `report.json` from evaluation artifact.\nError: {str(e)}" ) json_report = {} @@ -1279,7 +1283,7 @@ def _get_source( ) except Exception as ex: logger.debug( - f"Failed to retrieve source information for evaluation {evaluation.identifier}.\nError: {ex}" + f"Failed to retrieve source information for evaluation {evaluation.identifier}.\nError: {str(ex)}" ) source_name = "" From ca02f03cd8bc25cff2e4e22d81ecaadd8b88cf39 Mon Sep 17 00:00:00 2001 From: Vipul Date: Mon, 6 Jan 2025 15:01:29 -0800 Subject: [PATCH 15/18] update tests --- .../unitary/with_extras/aqua/test_handlers.py | 21 ++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/tests/unitary/with_extras/aqua/test_handlers.py b/tests/unitary/with_extras/aqua/test_handlers.py index a4ae749e9..6cbffe23e 100644 --- a/tests/unitary/with_extras/aqua/test_handlers.py +++ b/tests/unitary/with_extras/aqua/test_handlers.py @@ -1,6 +1,6 @@ #!/usr/bin/env python # -*- coding: utf-8 -*-- -# Copyright (c) 2024 Oracle and/or its affiliates. +# Copyright (c) 2024, 2025 Oracle and/or its affiliates. # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ import json @@ -131,9 +131,13 @@ def test_finish(self, name, payload, expected_call, mock_super_finish): ), aqua_api_details=dict( aqua_api_name="TestDataset.create", - oci_api_name=TestDataset.mock_service_payload_create["operation_name"], - service_endpoint=TestDataset.mock_service_payload_create["request_endpoint"] - ) + oci_api_name=TestDataset.mock_service_payload_create[ + "operation_name" + ], + service_endpoint=TestDataset.mock_service_payload_create[ + "request_endpoint" + ], + ), ), "Authorization Failed: The resource you're looking for isn't accessible. Operation Name: get_job_run.", ], @@ -171,10 +175,13 @@ def test_write_error(self, name, input, expected_msg, mock_uuid, mock_logger): input.get("status_code"), ), value=input.get("reason"), - **aqua_api_details + **aqua_api_details, ) - - mock_logger.warning.assert_called_with(expected_msg) + error_message = ( + f"Error Request ID: {expected_reply['request_id']}\n" + f"Error: {expected_reply['message']} {expected_reply['reason']}" + ) + mock_logger.error.assert_called_with(error_message) class TestHandlers(unittest.TestCase): From ed4098d89cb3a9dced272458bfe5f634a08bf81c Mon Sep 17 00:00:00 2001 From: Vipul Date: Fri, 10 Jan 2025 10:50:48 -0800 Subject: [PATCH 16/18] add missing request id --- ads/aqua/extension/aqua_ws_msg_handler.py | 21 ++++++++++++++------- ads/cli.py | 13 +++++++------ tests/unitary/with_extras/aqua/test_cli.py | 10 +++++++++- 3 files changed, 30 insertions(+), 14 deletions(-) diff --git a/ads/aqua/extension/aqua_ws_msg_handler.py b/ads/aqua/extension/aqua_ws_msg_handler.py index 04ff651f4..1fcbbf946 100644 --- a/ads/aqua/extension/aqua_ws_msg_handler.py +++ b/ads/aqua/extension/aqua_ws_msg_handler.py @@ -1,10 +1,10 @@ #!/usr/bin/env python -# -*- coding: utf-8 -*-- -# Copyright (c) 2024 Oracle and/or its affiliates. +# Copyright (c) 2024, 2025 Oracle and/or its affiliates. # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ import traceback +import uuid from abc import abstractmethod from http.client import responses from typing import List @@ -34,7 +34,7 @@ def __init__(self, message: str): self.telemetry = TelemetryClient( bucket=AQUA_TELEMETRY_BUCKET, namespace=AQUA_TELEMETRY_BUCKET_NS ) - except: + except Exception: pass @staticmethod @@ -66,16 +66,23 @@ def write_error(self, status_code, **kwargs): "message": message, "service_payload": service_payload, "reason": reason, + "request_id": str(uuid.uuid4()), } exc_info = kwargs.get("exc_info") if exc_info: - logger.error("".join(traceback.format_exception(*exc_info))) + logger.error( + f"Error Request ID: {reply['request_id']}\n" + f"Error: {''.join(traceback.format_exception(*exc_info))}" + ) e = exc_info[1] if isinstance(e, HTTPError): reply["message"] = e.log_message or message reply["reason"] = e.reason - else: - logger.warning(reply["message"]) + + logger.error( + f"Error Request ID: {reply['request_id']}\n" + f"Error: {reply['message']} {reply['reason']}" + ) # telemetry may not be present if there is an error while initializing if hasattr(self, "telemetry"): aqua_api_details = kwargs.get("aqua_api_details", {}) @@ -83,7 +90,7 @@ def write_error(self, status_code, **kwargs): category="aqua/error", action=str(status_code), value=reason, - **aqua_api_details + **aqua_api_details, ) response = AquaWsError( status=status_code, diff --git a/ads/cli.py b/ads/cli.py index 872e7d177..249920eef 100644 --- a/ads/cli.py +++ b/ads/cli.py @@ -1,12 +1,12 @@ #!/usr/bin/env python -# -*- coding: utf-8 -*-- -# Copyright (c) 2021, 2024 Oracle and/or its affiliates. +# Copyright (c) 2021, 2025 Oracle and/or its affiliates. # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ +import logging import sys import traceback -from dataclasses import is_dataclass +import uuid import fire @@ -27,7 +27,7 @@ ) logger.debug(ex) logger.debug(traceback.format_exc()) - exit() + sys.exit() # https://packaging.python.org/en/latest/guides/single-sourcing-package-version/#single-sourcing-the-package-version if sys.version_info >= (3, 8): @@ -122,8 +122,9 @@ def exit_program(ex: Exception, logger: "logging.Logger") -> None: ... exit_program(e, logger) """ - logger.debug(traceback.format_exc()) - logger.error(str(ex)) + request_id = str(uuid.uuid4()) + logger.debug(f"Error Request ID: {request_id}\nError: {traceback.format_exc()}") + logger.error(f"Error Request ID: {request_id}\n" f"Error: {str(ex)}") exit_code = getattr(ex, "exit_code", 1) logger.error(f"Exit code: {exit_code}") diff --git a/tests/unitary/with_extras/aqua/test_cli.py b/tests/unitary/with_extras/aqua/test_cli.py index 6c3c97cc8..4a2c5aed5 100644 --- a/tests/unitary/with_extras/aqua/test_cli.py +++ b/tests/unitary/with_extras/aqua/test_cli.py @@ -1,12 +1,13 @@ #!/usr/bin/env python # -*- coding: utf-8 -*-- -# Copyright (c) 2024 Oracle and/or its affiliates. +# Copyright (c) 2024, 2025 Oracle and/or its affiliates. # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ import logging import os import subprocess +import uuid from importlib import reload from unittest import TestCase from unittest.mock import call, patch @@ -148,6 +149,7 @@ def test_aqua_cli(self, mock_logger, mock_aqua_command, mock_fire, mock_serializ ] ) @patch("sys.argv", ["ads", "aqua", "--error-option"]) + @patch("uuid.uuid4") @patch("fire.Fire") @patch("ads.aqua.cli.AquaCommand") @patch("ads.aqua.logger.error") @@ -162,11 +164,17 @@ def test_aqua_cli_with_error( mock_logger_error, mock_aqua_command, mock_fire, + mock_uuid, ): """Tests when Aqua Cli gracefully exit when error raised.""" mock_fire.side_effect = mock_side_effect from ads.cli import cli + uuid_value = "12345678-1234-5678-1234-567812345678" + mock_uuid.return_value = uuid.UUID(uuid_value) + expected_logging_message = type(expected_logging_message)( + f"Error Request ID: {uuid_value}\nError: {str(expected_logging_message)}" + ) cli() calls = [ call(expected_logging_message), From 4cbd7e05fd67ef4d33f57b3edeab1b962fb45337 Mon Sep 17 00:00:00 2001 From: Vipul Date: Tue, 14 Jan 2025 12:01:41 -0800 Subject: [PATCH 17/18] revert to previous validation --- ads/aqua/common/utils.py | 51 +------------------------- ads/aqua/evaluation/evaluation.py | 14 +++++-- ads/aqua/extension/finetune_handler.py | 2 +- ads/aqua/finetuning/finetuning.py | 37 ++++++++++++++----- 4 files changed, 40 insertions(+), 64 deletions(-) diff --git a/ads/aqua/common/utils.py b/ads/aqua/common/utils.py index 04650db20..5a69aa594 100644 --- a/ads/aqua/common/utils.py +++ b/ads/aqua/common/utils.py @@ -12,12 +12,11 @@ import re import shlex import subprocess -from dataclasses import fields from datetime import datetime, timedelta from functools import wraps from pathlib import Path from string import Template -from typing import Any, List, Optional, Type, TypeVar, Union +from typing import List, Union import fsspec import oci @@ -31,7 +30,6 @@ ) from oci.data_science.models import JobRun, Model from oci.object_storage.models import ObjectSummary -from pydantic import BaseModel, ValidationError from ads.aqua.common.enums import ( InferenceContainerParamType, @@ -76,7 +74,6 @@ from ads.model import DataScienceModel, ModelVersionSet logger = logging.getLogger("ads.aqua") -T = TypeVar("T", bound=Union[BaseModel, Any]) class LifecycleStatus(str, metaclass=ExtendedEnumMeta): @@ -1164,49 +1161,3 @@ def validate_cmd_var(cmd_var: List[str], overrides: List[str]) -> List[str]: combined_cmd_var = cmd_var + overrides return combined_cmd_var - - -def validate_dataclass_params(dataclass_type: Type[T], **kwargs: Any) -> Optional[T]: - """This method tries to initialize a dataclass with the provided keyword arguments. It handles - errors related to missing, unexpected or invalid arguments. - - Parameters - ---------- - dataclass_type (Type[T]): - the dataclass type to instantiate. - kwargs (Any): - the keyword arguments to initialize the dataclass - Returns - ------- - Optional[T] - instance of dataclass if successfully initialized - """ - - try: - return dataclass_type(**kwargs) - except TypeError as ex: - error_message = str(ex) - allowed_params = ", ".join( - field.name for field in fields(dataclass_type) - ).rstrip() - if "__init__() missing" in error_message: - missing_params = error_message.split("missing ")[1] - raise AquaValueError( - "Error: Missing required parameters: " - f"{missing_params}. Allowable parameters are: {allowed_params}." - ) from ex - elif "__init__() got an unexpected keyword argument" in error_message: - unexpected_param = error_message.split("argument '")[1].rstrip("'") - raise AquaValueError( - "Error: Unexpected parameter: " - f"{unexpected_param}. Allowable parameters are: {allowed_params}." - ) from ex - else: - raise AquaValueError( - "Invalid parameters. Allowable parameters are: " f"{allowed_params}." - ) from ex - except ValidationError as ex: - custom_errors = {".".join(map(str, e["loc"])): e["msg"] for e in ex.errors()} - raise AquaValueError( - f"Invalid parameters. Error details: {custom_errors}." - ) from ex diff --git a/ads/aqua/evaluation/evaluation.py b/ads/aqua/evaluation/evaluation.py index 8dde31778..13adf0bdb 100644 --- a/ads/aqua/evaluation/evaluation.py +++ b/ads/aqua/evaluation/evaluation.py @@ -43,7 +43,6 @@ get_container_image, is_valid_ocid, upload_local_to_os, - validate_dataclass_params, ) from ads.aqua.config.config import get_evaluation_service_config from ads.aqua.constants import ( @@ -156,9 +155,16 @@ def create( The instance of AquaEvaluationSummary. """ if not create_aqua_evaluation_details: - create_aqua_evaluation_details = validate_dataclass_params( - CreateAquaEvaluationDetails, **kwargs - ) + try: + create_aqua_evaluation_details = CreateAquaEvaluationDetails(**kwargs) + except Exception as ex: + custom_errors = { + ".".join(map(str, e["loc"])): e["msg"] + for e in json.loads(ex.json()) + } + raise AquaValueError( + f"Invalid create evaluation parameters. Error details: {custom_errors}." + ) from ex if not is_valid_ocid(create_aqua_evaluation_details.evaluation_source_id): raise AquaValueError( diff --git a/ads/aqua/extension/finetune_handler.py b/ads/aqua/extension/finetune_handler.py index 50400e04c..32f3cae8f 100644 --- a/ads/aqua/extension/finetune_handler.py +++ b/ads/aqua/extension/finetune_handler.py @@ -89,7 +89,7 @@ def post(self, *args, **kwargs): # noqa: ARG002 params = input_data.get("params", None) return self.finish( - AquaFineTuningApp.validate_finetuning_params( + AquaFineTuningApp().validate_finetuning_params( params=params, ) ) diff --git a/ads/aqua/finetuning/finetuning.py b/ads/aqua/finetuning/finetuning.py index 04354db0c..0aca320e1 100644 --- a/ads/aqua/finetuning/finetuning.py +++ b/ads/aqua/finetuning/finetuning.py @@ -4,7 +4,7 @@ import json import os -from dataclasses import asdict, fields +from dataclasses import MISSING, asdict, fields from typing import Dict from oci.data_science.models import ( @@ -20,7 +20,6 @@ from ads.aqua.common.utils import ( get_container_image, upload_local_to_os, - validate_dataclass_params, ) from ads.aqua.constants import ( DEFAULT_FT_BATCH_SIZE, @@ -103,9 +102,16 @@ def create( The instance of AquaFineTuningSummary. """ if not create_fine_tuning_details: - create_fine_tuning_details = validate_dataclass_params( - CreateFineTuningDetails, **kwargs - ) + try: + create_fine_tuning_details = CreateFineTuningDetails(**kwargs) + except Exception as ex: + allowed_create_fine_tuning_details = ", ".join( + field.name for field in fields(CreateFineTuningDetails) + ).rstrip() + raise AquaValueError( + "Invalid create fine tuning parameters. Allowable parameters are: " + f"{allowed_create_fine_tuning_details}." + ) from ex source = self.get_source(create_fine_tuning_details.ft_source_id) @@ -615,8 +621,7 @@ def get_finetuning_default_params(self, model_id: str) -> Dict: return default_params - @staticmethod - def validate_finetuning_params(params: Dict = None) -> Dict: + def validate_finetuning_params(self, params: Dict = None) -> Dict: """Validate if the fine-tuning parameters passed by the user can be overridden. Parameter values are not validated, only param keys are validated. @@ -627,7 +632,21 @@ def validate_finetuning_params(params: Dict = None) -> Dict: Returns ------- - Return a dict with value true if valid, else raises AquaValueError. + Return a list of restricted params. """ - validate_dataclass_params(AquaFineTuningParams, **(params or {})) + try: + AquaFineTuningParams( + **params, + ) + except Exception as e: + logger.debug(str(e)) + allowed_fine_tuning_parameters = ", ".join( + f"{field.name} (required)" if field.default is MISSING else field.name + for field in fields(AquaFineTuningParams) + ).rstrip() + raise AquaValueError( + f"Invalid fine tuning parameters. Allowable parameters are: " + f"{allowed_fine_tuning_parameters}." + ) from e + return {"valid": True} From 7616ff24400abc213f724f8ab2f36c0f22873f3d Mon Sep 17 00:00:00 2001 From: Vipul Date: Fri, 31 Jan 2025 20:16:25 +0530 Subject: [PATCH 18/18] fix tests after merge --- tests/unitary/with_extras/aqua/test_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unitary/with_extras/aqua/test_model.py b/tests/unitary/with_extras/aqua/test_model.py index 2a0475175..4cd59afb9 100644 --- a/tests/unitary/with_extras/aqua/test_model.py +++ b/tests/unitary/with_extras/aqua/test_model.py @@ -1282,7 +1282,7 @@ def test_import_model_with_input_tags( "inference_container": "odsc-vllm-serving", "ignore_model_artifact_check": True, }, - "ads aqua model register --model oracle/oracle-1it --os_path oci://aqua-bkt@aqua-ns/path --download_from_hf True --inference_container odsc-vllm-serving --ignore_model_artifact_check True", + "ads aqua model register --model oracle/oracle-1it --os_path oci://aqua-bkt@aqua-ns/path --download_from_hf True --cleanup_model_cache True --inference_container odsc-vllm-serving --ignore_model_artifact_check True", ), ], )