1010import oci
1111from cachetools import TTLCache
1212from huggingface_hub import snapshot_download
13- from oci .data_science .models import JobRun , Model
13+ from oci .data_science .models import JobRun , Metadata , Model , UpdateModelDetails
1414
1515from ads .aqua import ODSC_MODEL_COMPARTMENT_OCID , logger
1616from ads .aqua .app import AquaApp
17- from ads .aqua .common .enums import InferenceContainerTypeFamily , Tags
17+ from ads .aqua .common .enums import (
18+ FineTuningContainerTypeFamily ,
19+ InferenceContainerTypeFamily ,
20+ Tags ,
21+ )
1822from ads .aqua .common .errors import AquaRuntimeError , AquaValueError
1923from ads .aqua .common .utils import (
2024 LifecycleStatus ,
7579 TENANCY_OCID ,
7680)
7781from ads .model import DataScienceModel
78- from ads .model .model_metadata import ModelCustomMetadata , ModelCustomMetadataItem
82+ from ads .model .model_metadata import (
83+ MetadataCustomCategory ,
84+ ModelCustomMetadata ,
85+ ModelCustomMetadataItem ,
86+ )
7987from ads .telemetry import telemetry
8088
8189
@@ -323,6 +331,97 @@ def get(self, model_id: str, load_model_card: Optional[bool] = True) -> "AquaMod
323331
324332 return model_details
325333
334+ @telemetry (entry_point = "plugin=model&action=delete" , name = "aqua" )
335+ def delete_model (self , model_id ):
336+ ds_model = DataScienceModel .from_id (model_id )
337+ is_registered_model = ds_model .freeform_tags .get (Tags .BASE_MODEL_CUSTOM , None )
338+ is_fine_tuned_model = ds_model .freeform_tags .get (
339+ Tags .AQUA_FINE_TUNED_MODEL_TAG , None
340+ )
341+ if is_registered_model or is_fine_tuned_model :
342+ return ds_model .delete ()
343+ else :
344+ raise AquaRuntimeError (
345+ f"Failed to delete model:{ model_id } . Only registered models or finetuned model can be deleted."
346+ )
347+
348+ @telemetry (entry_point = "plugin=model&action=delete" , name = "aqua" )
349+ def edit_registered_model (self , id , inference_container , enable_finetuning , task ):
350+ """Edits the default config of unverified registered model.
351+
352+ Parameters
353+ ----------
354+ id: str
355+ The model OCID.
356+ inference_container: str.
357+ The inference container family name
358+ enable_finetuning: str
359+ Flag to enable or disable finetuning over the model. Defaults to None
360+ task:
361+ The usecase type of the model. e.g , text-generation , text_embedding etc.
362+
363+ Returns
364+ -------
365+ Model:
366+ The instance of oci.data_science.models.Model.
367+
368+ """
369+ ds_model = DataScienceModel .from_id (id )
370+ if ds_model .freeform_tags .get (Tags .BASE_MODEL_CUSTOM , None ):
371+ if ds_model .freeform_tags .get (Tags .AQUA_SERVICE_MODEL_TAG , None ):
372+ raise AquaRuntimeError (
373+ f"Failed to edit model:{ id } . Only registered unverified models can be edited."
374+ )
375+ else :
376+ custom_metadata_list = ds_model .custom_metadata_list
377+ freeform_tags = ds_model .freeform_tags
378+ if inference_container :
379+ custom_metadata_list .add (
380+ key = ModelCustomMetadataFields .DEPLOYMENT_CONTAINER ,
381+ value = inference_container ,
382+ category = MetadataCustomCategory .OTHER ,
383+ description = "Deployment container mapping for SMC" ,
384+ replace = True ,
385+ )
386+ if enable_finetuning is not None :
387+ if enable_finetuning .lower () == "true" :
388+ custom_metadata_list .add (
389+ key = ModelCustomMetadataFields .FINETUNE_CONTAINER ,
390+ value = FineTuningContainerTypeFamily .AQUA_FINETUNING_CONTAINER_FAMILY ,
391+ category = MetadataCustomCategory .OTHER ,
392+ description = "Fine-tuning container mapping for SMC" ,
393+ replace = True ,
394+ )
395+ freeform_tags .update ({Tags .READY_TO_FINE_TUNE : "true" })
396+ elif enable_finetuning .lower () == "false" :
397+ try :
398+ custom_metadata_list .remove (
399+ ModelCustomMetadataFields .FINETUNE_CONTAINER
400+ )
401+ freeform_tags .pop (Tags .READY_TO_FINE_TUNE )
402+ except Exception as ex :
403+ raise AquaRuntimeError (
404+ f"The given model already doesn't support finetuning: { ex } "
405+ )
406+
407+ custom_metadata_list .remove ("modelDescription" )
408+ if task :
409+ freeform_tags .update ({Tags .TASK : task })
410+
411+ updated_custom_metadata_list = [
412+ Metadata (** metadata )
413+ for metadata in custom_metadata_list .to_dict ()["data" ]
414+ ]
415+ update_model_details = UpdateModelDetails (
416+ custom_metadata_list = updated_custom_metadata_list ,
417+ freeform_tags = freeform_tags ,
418+ )
419+ return AquaApp ().update_model (id , update_model_details ).data
420+ else :
421+ raise AquaRuntimeError (
422+ f"Failed to edit model:{ id } . Only registered unverified models can be edited."
423+ )
424+
326425 def _fetch_metric_from_metadata (
327426 self ,
328427 custom_metadata_list : ModelCustomMetadata ,
@@ -935,38 +1034,39 @@ def _validate_model(
9351034 # gguf extension exist.
9361035 if {ModelFormat .SAFETENSORS , ModelFormat .GGUF }.issubset (set (model_formats )):
9371036 if (
938- import_model_details .inference_container .lower () == InferenceContainerTypeFamily .AQUA_LLAMA_CPP_CONTAINER_FAMILY
1037+ import_model_details .inference_container .lower ()
1038+ == InferenceContainerTypeFamily .AQUA_LLAMA_CPP_CONTAINER_FAMILY
9391039 ):
9401040 self ._validate_gguf_format (
9411041 import_model_details = import_model_details ,
9421042 verified_model = verified_model ,
9431043 gguf_model_files = gguf_model_files ,
9441044 validation_result = validation_result ,
945- model_name = model_name
1045+ model_name = model_name ,
9461046 )
9471047 else :
9481048 self ._validate_safetensor_format (
9491049 import_model_details = import_model_details ,
9501050 verified_model = verified_model ,
9511051 validation_result = validation_result ,
9521052 hf_download_config_present = hf_download_config_present ,
953- model_name = model_name
1053+ model_name = model_name ,
9541054 )
9551055 elif ModelFormat .SAFETENSORS in model_formats :
9561056 self ._validate_safetensor_format (
9571057 import_model_details = import_model_details ,
9581058 verified_model = verified_model ,
9591059 validation_result = validation_result ,
9601060 hf_download_config_present = hf_download_config_present ,
961- model_name = model_name
1061+ model_name = model_name ,
9621062 )
9631063 elif ModelFormat .GGUF in model_formats :
9641064 self ._validate_gguf_format (
9651065 import_model_details = import_model_details ,
9661066 verified_model = verified_model ,
9671067 gguf_model_files = gguf_model_files ,
9681068 validation_result = validation_result ,
969- model_name = model_name
1069+ model_name = model_name ,
9701070 )
9711071
9721072 return validation_result
@@ -977,7 +1077,7 @@ def _validate_safetensor_format(
9771077 verified_model : DataScienceModel = None ,
9781078 validation_result : ModelValidationResult = None ,
9791079 hf_download_config_present : bool = None ,
980- model_name : str = None
1080+ model_name : str = None ,
9811081 ):
9821082 if import_model_details .download_from_hf :
9831083 # validates config.json exists for safetensors model from hugginface
@@ -1004,20 +1104,13 @@ def _validate_safetensor_format(
10041104 ) from ex
10051105 else :
10061106 try :
1007- metadata_model_type = (
1008- verified_model .custom_metadata_list .get (
1009- AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE
1010- ).value
1011- )
1107+ metadata_model_type = verified_model .custom_metadata_list .get (
1108+ AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE
1109+ ).value
10121110 if metadata_model_type :
1013- if (
1014- AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE
1015- in model_config
1016- ):
1111+ if AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE in model_config :
10171112 if (
1018- model_config [
1019- AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE
1020- ]
1113+ model_config [AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE ]
10211114 != metadata_model_type
10221115 ):
10231116 raise AquaRuntimeError (
@@ -1035,9 +1128,7 @@ def _validate_safetensor_format(
10351128 except Exception :
10361129 pass
10371130 if verified_model :
1038- validation_result .telemetry_model_name = (
1039- verified_model .display_name
1040- )
1131+ validation_result .telemetry_model_name = verified_model .display_name
10411132 elif (
10421133 model_config is not None
10431134 and AQUA_MODEL_ARTIFACT_CONFIG_MODEL_NAME in model_config
@@ -1049,9 +1140,7 @@ def _validate_safetensor_format(
10491140 ):
10501141 validation_result .telemetry_model_name = f"{ AQUA_MODEL_TYPE_CUSTOM } _{ model_config [AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE ]} "
10511142 else :
1052- validation_result .telemetry_model_name = (
1053- AQUA_MODEL_TYPE_CUSTOM
1054- )
1143+ validation_result .telemetry_model_name = AQUA_MODEL_TYPE_CUSTOM
10551144
10561145 @staticmethod
10571146 def _validate_gguf_format (
0 commit comments