1010import oci
1111from cachetools import TTLCache
1212from huggingface_hub import snapshot_download
13- from oci .data_science .models import JobRun , Model , UpdateModelDetails , Metadata
13+ from oci .data_science .models import JobRun , Metadata , Model , UpdateModelDetails
1414
1515from ads .aqua import ODSC_MODEL_COMPARTMENT_OCID , logger
1616from ads .aqua .app import AquaApp
17- from ads .aqua .common .enums import InferenceContainerTypeFamily , Tags , FineTuningContainerTypeFamily
17+ from ads .aqua .common .enums import (
18+ FineTuningContainerTypeFamily ,
19+ InferenceContainerTypeFamily ,
20+ Tags ,
21+ )
1822from ads .aqua .common .errors import AquaRuntimeError , AquaValueError
1923from ads .aqua .common .utils import (
2024 LifecycleStatus ,
7579 TENANCY_OCID ,
7680)
7781from ads .model import DataScienceModel
78- from ads .model .model_metadata import ModelCustomMetadata , ModelCustomMetadataItem , MetadataCustomCategory
82+ from ads .model .model_metadata import (
83+ MetadataCustomCategory ,
84+ ModelCustomMetadata ,
85+ ModelCustomMetadataItem ,
86+ )
7987from ads .telemetry import telemetry
8088
8189
@@ -324,16 +332,21 @@ def get(self, model_id: str, load_model_card: Optional[bool] = True) -> "AquaMod
324332 return model_details
325333
326334 @telemetry (entry_point = "plugin=model&action=delete" , name = "aqua" )
327- def delete_registered_model (self ,model_id ):
328- ds_model = DataScienceModel .from_id (model_id )
329- is_registered_model = ds_model .freeform_tags .get (Tags .BASE_MODEL_CUSTOM ,None )
330- if is_registered_model :
335+ def delete_model (self , model_id ):
336+ ds_model = DataScienceModel .from_id (model_id )
337+ is_registered_model = ds_model .freeform_tags .get (Tags .BASE_MODEL_CUSTOM , None )
338+ is_fine_tuned_model = ds_model .freeform_tags .get (
339+ Tags .AQUA_FINE_TUNED_MODEL_TAG , None
340+ )
341+ if is_registered_model or is_fine_tuned_model :
331342 return ds_model .delete ()
332343 else :
333- raise AquaRuntimeError (f"Failed to delete model:{ model_id } . Only registered models can be deleted." )
344+ raise AquaRuntimeError (
345+ f"Failed to delete model:{ model_id } . Only registered models or finetuned model can be deleted."
346+ )
334347
335348 @telemetry (entry_point = "plugin=model&action=delete" , name = "aqua" )
336- def edit_registered_model (self ,id ,inference_container ,enable_finetuning ,task ):
349+ def edit_registered_model (self , id , inference_container , enable_finetuning , task ):
337350 """Edits the default config of unverified registered model.
338351
339352 Parameters
@@ -351,52 +364,61 @@ def edit_registered_model(self,id,inference_container,enable_finetuning,task):
351364 The instance of oci.data_science.models.Model.
352365
353366 """
354- ds_model = DataScienceModel .from_id (id )
355- if ds_model .freeform_tags .get (Tags .BASE_MODEL_CUSTOM ,None ):
356- if ds_model .freeform_tags .get (Tags .AQUA_SERVICE_MODEL_TAG ,None ):
357- raise AquaRuntimeError (f"Failed to edit model:{ id } . Only registered unverified models can be edited." )
367+ ds_model = DataScienceModel .from_id (id )
368+ if ds_model .freeform_tags .get (Tags .BASE_MODEL_CUSTOM , None ):
369+ if ds_model .freeform_tags .get (Tags .AQUA_SERVICE_MODEL_TAG , None ):
370+ raise AquaRuntimeError (
371+ f"Failed to edit model:{ id } . Only registered unverified models can be edited."
372+ )
358373 else :
359- custom_metadata_list = ds_model .custom_metadata_list
360- freeform_tags = ds_model .freeform_tags
374+ custom_metadata_list = ds_model .custom_metadata_list
375+ freeform_tags = ds_model .freeform_tags
361376 if inference_container :
362- custom_metadata_list .add (key = ModelCustomMetadataFields .DEPLOYMENT_CONTAINER ,
363- value = inference_container ,
364- category = MetadataCustomCategory .OTHER ,
365- description = "Deployment container mapping for SMC" ,
366- replace = True
367- )
377+ custom_metadata_list .add (
378+ key = ModelCustomMetadataFields .DEPLOYMENT_CONTAINER ,
379+ value = inference_container ,
380+ category = MetadataCustomCategory .OTHER ,
381+ description = "Deployment container mapping for SMC" ,
382+ replace = True ,
383+ )
368384 if enable_finetuning is not None :
369- if enable_finetuning .lower ()== "true" :
370- custom_metadata_list .add (key = ModelCustomMetadataFields .FINETUNE_CONTAINER ,
371- value = FineTuningContainerTypeFamily .AQUA_FINETUNING_CONTAINER_FAMILY ,
372- category = MetadataCustomCategory .OTHER ,
373- description = "Fine-tuning container mapping for SMC" ,
374- replace = True
375- )
376- freeform_tags .update ({Tags .READY_TO_FINE_TUNE :"true" })
377- elif enable_finetuning .lower ()== "false" :
385+ if enable_finetuning .lower () == "true" :
386+ custom_metadata_list .add (
387+ key = ModelCustomMetadataFields .FINETUNE_CONTAINER ,
388+ value = FineTuningContainerTypeFamily .AQUA_FINETUNING_CONTAINER_FAMILY ,
389+ category = MetadataCustomCategory .OTHER ,
390+ description = "Fine-tuning container mapping for SMC" ,
391+ replace = True ,
392+ )
393+ freeform_tags .update ({Tags .READY_TO_FINE_TUNE : "true" })
394+ elif enable_finetuning .lower () == "false" :
378395 try :
379- custom_metadata_list .remove (ModelCustomMetadataFields .FINETUNE_CONTAINER )
396+ custom_metadata_list .remove (
397+ ModelCustomMetadataFields .FINETUNE_CONTAINER
398+ )
380399 freeform_tags .pop (Tags .READY_TO_FINE_TUNE )
381400 except Exception as ex :
382- raise AquaRuntimeError (f"The given model already doesn't support finetuning: { ex } " )
401+ raise AquaRuntimeError (
402+ f"The given model already doesn't support finetuning: { ex } "
403+ )
383404
384405 custom_metadata_list .remove ("modelDescription" )
385406 if task :
386- freeform_tags .update ({"task" :task })
407+ freeform_tags .update ({"task" : task })
387408
388409 updated_custom_metadata_list = [
389410 Metadata (** metadata )
390411 for metadata in custom_metadata_list .to_dict ()["data" ]
391412 ]
392413 update_model_details = UpdateModelDetails (
393414 custom_metadata_list = updated_custom_metadata_list ,
394- freeform_tags = freeform_tags
415+ freeform_tags = freeform_tags ,
395416 )
396- return self . ds_client . update_model (id ,update_model_details ).data
417+ return AquaApp (). update_model (id , update_model_details ).data
397418 else :
398- raise AquaRuntimeError (f"Failed to edit model:{ id } . Only registered unverified models can be deleted." )
399-
419+ raise AquaRuntimeError (
420+ f"Failed to edit model:{ id } . Only registered unverified models can be deleted."
421+ )
400422
401423 def _fetch_metric_from_metadata (
402424 self ,
@@ -1010,38 +1032,39 @@ def _validate_model(
10101032 # gguf extension exist.
10111033 if {ModelFormat .SAFETENSORS , ModelFormat .GGUF }.issubset (set (model_formats )):
10121034 if (
1013- import_model_details .inference_container .lower () == InferenceContainerTypeFamily .AQUA_LLAMA_CPP_CONTAINER_FAMILY
1035+ import_model_details .inference_container .lower ()
1036+ == InferenceContainerTypeFamily .AQUA_LLAMA_CPP_CONTAINER_FAMILY
10141037 ):
10151038 self ._validate_gguf_format (
10161039 import_model_details = import_model_details ,
10171040 verified_model = verified_model ,
10181041 gguf_model_files = gguf_model_files ,
10191042 validation_result = validation_result ,
1020- model_name = model_name
1043+ model_name = model_name ,
10211044 )
10221045 else :
10231046 self ._validate_safetensor_format (
10241047 import_model_details = import_model_details ,
10251048 verified_model = verified_model ,
10261049 validation_result = validation_result ,
10271050 hf_download_config_present = hf_download_config_present ,
1028- model_name = model_name
1051+ model_name = model_name ,
10291052 )
10301053 elif ModelFormat .SAFETENSORS in model_formats :
10311054 self ._validate_safetensor_format (
10321055 import_model_details = import_model_details ,
10331056 verified_model = verified_model ,
10341057 validation_result = validation_result ,
10351058 hf_download_config_present = hf_download_config_present ,
1036- model_name = model_name
1059+ model_name = model_name ,
10371060 )
10381061 elif ModelFormat .GGUF in model_formats :
10391062 self ._validate_gguf_format (
10401063 import_model_details = import_model_details ,
10411064 verified_model = verified_model ,
10421065 gguf_model_files = gguf_model_files ,
10431066 validation_result = validation_result ,
1044- model_name = model_name
1067+ model_name = model_name ,
10451068 )
10461069
10471070 return validation_result
@@ -1052,7 +1075,7 @@ def _validate_safetensor_format(
10521075 verified_model : DataScienceModel = None ,
10531076 validation_result : ModelValidationResult = None ,
10541077 hf_download_config_present : bool = None ,
1055- model_name : str = None
1078+ model_name : str = None ,
10561079 ):
10571080 if import_model_details .download_from_hf :
10581081 # validates config.json exists for safetensors model from hugginface
@@ -1079,20 +1102,13 @@ def _validate_safetensor_format(
10791102 ) from ex
10801103 else :
10811104 try :
1082- metadata_model_type = (
1083- verified_model .custom_metadata_list .get (
1084- AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE
1085- ).value
1086- )
1105+ metadata_model_type = verified_model .custom_metadata_list .get (
1106+ AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE
1107+ ).value
10871108 if metadata_model_type :
1088- if (
1089- AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE
1090- in model_config
1091- ):
1109+ if AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE in model_config :
10921110 if (
1093- model_config [
1094- AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE
1095- ]
1111+ model_config [AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE ]
10961112 != metadata_model_type
10971113 ):
10981114 raise AquaRuntimeError (
@@ -1110,9 +1126,7 @@ def _validate_safetensor_format(
11101126 except Exception :
11111127 pass
11121128 if verified_model :
1113- validation_result .telemetry_model_name = (
1114- verified_model .display_name
1115- )
1129+ validation_result .telemetry_model_name = verified_model .display_name
11161130 elif (
11171131 model_config is not None
11181132 and AQUA_MODEL_ARTIFACT_CONFIG_MODEL_NAME in model_config
@@ -1124,9 +1138,7 @@ def _validate_safetensor_format(
11241138 ):
11251139 validation_result .telemetry_model_name = f"{ AQUA_MODEL_TYPE_CUSTOM } _{ model_config [AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE ]} "
11261140 else :
1127- validation_result .telemetry_model_name = (
1128- AQUA_MODEL_TYPE_CUSTOM
1129- )
1141+ validation_result .telemetry_model_name = AQUA_MODEL_TYPE_CUSTOM
11301142
11311143 @staticmethod
11321144 def _validate_gguf_format (
0 commit comments