Skip to content

Commit c7ff584

Browse files
Adding edit registered model api
1 parent 75a5f5c commit c7ff584

File tree

3 files changed

+93
-3
lines changed

3 files changed

+93
-3
lines changed

ads/aqua/extension/errors.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,3 +8,4 @@ class Errors(str):
88
NO_INPUT_DATA = "No input data provided."
99
MISSING_REQUIRED_PARAMETER = "Missing required parameter: '{}'"
1010
MISSING_ONEOF_REQUIRED_PARAMETER = "Either '{}' or '{}' is required."
11+
INVALID_VALUE_OF_PARAMETER = "Invalid value of parameter: '{}'"

ads/aqua/extension/model_handler.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from tornado.web import HTTPError
99

1010
from ads.aqua.common.decorator import handle_exceptions
11+
from ads.aqua.common.enums import InferenceContainerTypeFamily
1112
from ads.aqua.common.errors import AquaRuntimeError, AquaValueError
1213
from ads.aqua.common.utils import get_hf_model_info, list_hf_models
1314
from ads.aqua.extension.base_handler import AquaAPIhandler
@@ -139,6 +140,28 @@ def post(self, *args, **kwargs):
139140
)
140141
)
141142

143+
@handle_exceptions
144+
def put(self,id):
145+
try:
146+
input_data = self.get_json_body()
147+
except Exception as ex:
148+
raise HTTPError(400, Errors.INVALID_INPUT_DATA_FORMAT) from ex
149+
150+
if not input_data:
151+
raise HTTPError(400, Errors.NO_INPUT_DATA)
152+
153+
inference_container=input_data.get('inference_container')
154+
if inference_container is not None and inference_container not in [
155+
InferenceContainerTypeFamily.AQUA_TGI_CONTAINER_FAMILY,
156+
InferenceContainerTypeFamily.AQUA_VLLM_CONTAINER_FAMILY,
157+
InferenceContainerTypeFamily.AQUA_LLAMA_CPP_CONTAINER_FAMILY
158+
]:
159+
raise HTTPError(400,Errors.INVALID_VALUE_OF_PARAMETER.format("inference_container"))
160+
enable_finetuning=input_data.get('enable_finetuning')
161+
task=input_data.get('task')
162+
return self.finish(AquaModelApp().edit_registered_model(id,inference_container,enable_finetuning,task))
163+
164+
142165

143166
class AquaModelLicenseHandler(AquaAPIhandler):
144167
"""Handler for Aqua Model license REST APIs."""

ads/aqua/model/model.py

Lines changed: 69 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,11 @@
1010
import oci
1111
from cachetools import TTLCache
1212
from huggingface_hub import snapshot_download
13-
from oci.data_science.models import JobRun, Model
13+
from oci.data_science.models import JobRun, Model, UpdateModelDetails, Metadata
1414

1515
from ads.aqua import ODSC_MODEL_COMPARTMENT_OCID, logger
1616
from ads.aqua.app import AquaApp
17-
from ads.aqua.common.enums import InferenceContainerTypeFamily, Tags
17+
from ads.aqua.common.enums import InferenceContainerTypeFamily, Tags, FineTuningContainerTypeFamily
1818
from ads.aqua.common.errors import AquaRuntimeError, AquaValueError
1919
from ads.aqua.common.utils import (
2020
LifecycleStatus,
@@ -75,7 +75,7 @@
7575
TENANCY_OCID,
7676
)
7777
from ads.model import DataScienceModel
78-
from ads.model.model_metadata import ModelCustomMetadata, ModelCustomMetadataItem
78+
from ads.model.model_metadata import ModelCustomMetadata, ModelCustomMetadataItem, MetadataCustomCategory
7979
from ads.telemetry import telemetry
8080

8181

@@ -332,6 +332,72 @@ def delete_registered_model(self,model_id):
332332
else:
333333
raise AquaRuntimeError(f"Failed to delete model:{model_id}. Only registered models can be deleted.")
334334

335+
@telemetry(entry_point="plugin=model&action=delete", name="aqua")
336+
def edit_registered_model(self,id,inference_container,enable_finetuning,task):
337+
"""Edits the default config of unverified registered model.
338+
339+
Parameters
340+
----------
341+
id: str
342+
The model OCID.
343+
inference_container: str.
344+
The inference container family name
345+
enable_finetuning: str
346+
Flag to enable or disable finetuning over the model. Defaults to None
347+
348+
Returns
349+
-------
350+
Model:
351+
The instance of oci.data_science.models.Model.
352+
353+
"""
354+
ds_model=DataScienceModel.from_id(id)
355+
if ds_model.freeform_tags.get(Tags.BASE_MODEL_CUSTOM,None):
356+
if ds_model.freeform_tags.get(Tags.AQUA_SERVICE_MODEL_TAG,None):
357+
raise AquaRuntimeError(f"Failed to edit model:{id}. Only registered unverified models can be edited.")
358+
else:
359+
custom_metadata_list=ds_model.custom_metadata_list
360+
freeform_tags=ds_model.freeform_tags
361+
if inference_container:
362+
custom_metadata_list.add(key=ModelCustomMetadataFields.DEPLOYMENT_CONTAINER,
363+
value=inference_container,
364+
category=MetadataCustomCategory.OTHER,
365+
description="Deployment container mapping for SMC",
366+
replace=True
367+
)
368+
if enable_finetuning is not None:
369+
if enable_finetuning.lower()=="true":
370+
custom_metadata_list.add(key=ModelCustomMetadataFields.FINETUNE_CONTAINER,
371+
value=FineTuningContainerTypeFamily.AQUA_FINETUNING_CONTAINER_FAMILY,
372+
category=MetadataCustomCategory.OTHER,
373+
description="Fine-tuning container mapping for SMC",
374+
replace=True
375+
)
376+
freeform_tags.update({Tags.READY_TO_FINE_TUNE:"true"})
377+
elif enable_finetuning.lower()=="false":
378+
try:
379+
custom_metadata_list.remove(ModelCustomMetadataFields.FINETUNE_CONTAINER)
380+
freeform_tags.pop(Tags.READY_TO_FINE_TUNE)
381+
except Exception as ex:
382+
raise AquaRuntimeError(f"The given model already doesn't support finetuning: {ex}")
383+
384+
custom_metadata_list.remove("modelDescription")
385+
if task:
386+
freeform_tags.update({"task":task})
387+
388+
updated_custom_metadata_list = [
389+
Metadata(**metadata)
390+
for metadata in custom_metadata_list.to_dict()["data"]
391+
]
392+
update_model_details = UpdateModelDetails(
393+
custom_metadata_list=updated_custom_metadata_list,
394+
freeform_tags=freeform_tags
395+
)
396+
return self.ds_client.update_model(id,update_model_details).data
397+
else:
398+
raise AquaRuntimeError(f"Failed to edit model:{id}. Only registered unverified models can be deleted.")
399+
400+
335401
def _fetch_metric_from_metadata(
336402
self,
337403
custom_metadata_list: ModelCustomMetadata,

0 commit comments

Comments
 (0)