Skip to content

Commit a2abc5c

Browse files
Adding delete/activate/deactivate support for model deployment and registered models (#972)
2 parents b441ea7 + 3582437 commit a2abc5c

File tree

7 files changed

+265
-30
lines changed

7 files changed

+265
-30
lines changed

ads/aqua/extension/deployment_handler.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,33 @@ def get(self, id=""):
5454
else:
5555
raise HTTPError(400, f"The request {self.request.path} is invalid.")
5656

57+
@handle_exceptions
58+
def delete(self, model_deployment_id):
59+
return self.finish(AquaDeploymentApp().delete(model_deployment_id))
60+
61+
@handle_exceptions
62+
def put(self, *args, **kwargs):
63+
"""
64+
Handles put request for the activating and deactivating OCI datascience model deployments
65+
Raises
66+
------
67+
HTTPError
68+
Raises HTTPError if inputs are missing or are invalid
69+
"""
70+
url_parse = urlparse(self.request.path)
71+
paths = url_parse.path.strip("/").split("/")
72+
if len(paths) != 4 or paths[0] != "aqua" or paths[1] != "deployments":
73+
raise HTTPError(400, f"The request {self.request.path} is invalid.")
74+
75+
model_deployment_id = paths[2]
76+
action = paths[3]
77+
if action == "activate":
78+
return self.finish(AquaDeploymentApp().activate(model_deployment_id))
79+
elif action == "deactivate":
80+
return self.finish(AquaDeploymentApp().deactivate(model_deployment_id))
81+
else:
82+
raise HTTPError(400, f"The request {self.request.path} is invalid.")
83+
5784
@handle_exceptions
5885
def post(self, *args, **kwargs):
5986
"""
@@ -264,5 +291,7 @@ def post(self, *args, **kwargs):
264291
("deployments/?([^/]*)/params", AquaDeploymentParamsHandler),
265292
("deployments/config/?([^/]*)", AquaDeploymentHandler),
266293
("deployments/?([^/]*)", AquaDeploymentHandler),
294+
("deployments/?([^/]*)/activate", AquaDeploymentHandler),
295+
("deployments/?([^/]*)/deactivate", AquaDeploymentHandler),
267296
("inference", AquaDeploymentInferenceHandler),
268297
]

ads/aqua/extension/errors.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,3 +8,4 @@ class Errors(str):
88
NO_INPUT_DATA = "No input data provided."
99
MISSING_REQUIRED_PARAMETER = "Missing required parameter: '{}'"
1010
MISSING_ONEOF_REQUIRED_PARAMETER = "Either '{}' or '{}' is required."
11+
INVALID_VALUE_OF_PARAMETER = "Invalid value of parameter: '{}'"

ads/aqua/extension/model_handler.py

Lines changed: 39 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,16 @@
99

1010
from ads.aqua.common.decorator import handle_exceptions
1111
from ads.aqua.common.errors import AquaRuntimeError, AquaValueError
12-
from ads.aqua.common.utils import get_hf_model_info, list_hf_models
12+
from ads.aqua.common.utils import (
13+
get_container_config,
14+
get_hf_model_info,
15+
list_hf_models,
16+
)
1317
from ads.aqua.extension.base_handler import AquaAPIhandler
1418
from ads.aqua.extension.errors import Errors
1519
from ads.aqua.model import AquaModelApp
1620
from ads.aqua.model.entities import AquaModelSummary, HFModelSummary
17-
from ads.aqua.ui import ModelFormat
21+
from ads.aqua.ui import AquaContainerConfig, ModelFormat
1822

1923

2024
class AquaModelHandler(AquaAPIhandler):
@@ -73,6 +77,8 @@ def delete(self, id=""):
7377
paths = url_parse.path.strip("/")
7478
if paths.startswith("aqua/model/cache"):
7579
return self.finish(AquaModelApp().clear_model_list_cache())
80+
elif id:
81+
return self.finish(AquaModelApp().delete_model(id))
7682
else:
7783
raise HTTPError(400, f"The request {self.request.path} is invalid.")
7884

@@ -137,6 +143,37 @@ def post(self, *args, **kwargs):
137143
)
138144
)
139145

146+
@handle_exceptions
147+
def put(self, id):
148+
try:
149+
input_data = self.get_json_body()
150+
except Exception as ex:
151+
raise HTTPError(400, Errors.INVALID_INPUT_DATA_FORMAT) from ex
152+
153+
if not input_data:
154+
raise HTTPError(400, Errors.NO_INPUT_DATA)
155+
156+
inference_container = input_data.get("inference_container")
157+
containers = list(
158+
AquaContainerConfig.from_container_index_json(
159+
config=get_container_config(), enable_spec=True
160+
).inference.values()
161+
)
162+
family_values = [item.family for item in containers]
163+
164+
if inference_container is not None and inference_container not in family_values:
165+
raise HTTPError(
166+
400, Errors.INVALID_VALUE_OF_PARAMETER.format("inference_container")
167+
)
168+
169+
enable_finetuning = input_data.get("enable_finetuning")
170+
task = input_data.get("task")
171+
return self.finish(
172+
AquaModelApp().edit_registered_model(
173+
id, inference_container, enable_finetuning, task
174+
)
175+
)
176+
140177

141178
class AquaModelLicenseHandler(AquaAPIhandler):
142179
"""Handler for Aqua Model license REST APIs."""

ads/aqua/model/model.py

Lines changed: 116 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,15 @@
1010
import oci
1111
from cachetools import TTLCache
1212
from huggingface_hub import snapshot_download
13-
from oci.data_science.models import JobRun, Model
13+
from oci.data_science.models import JobRun, Metadata, Model, UpdateModelDetails
1414

1515
from ads.aqua import ODSC_MODEL_COMPARTMENT_OCID, logger
1616
from ads.aqua.app import AquaApp
17-
from ads.aqua.common.enums import InferenceContainerTypeFamily, Tags
17+
from ads.aqua.common.enums import (
18+
FineTuningContainerTypeFamily,
19+
InferenceContainerTypeFamily,
20+
Tags,
21+
)
1822
from ads.aqua.common.errors import AquaRuntimeError, AquaValueError
1923
from ads.aqua.common.utils import (
2024
LifecycleStatus,
@@ -75,7 +79,11 @@
7579
TENANCY_OCID,
7680
)
7781
from ads.model import DataScienceModel
78-
from ads.model.model_metadata import ModelCustomMetadata, ModelCustomMetadataItem
82+
from ads.model.model_metadata import (
83+
MetadataCustomCategory,
84+
ModelCustomMetadata,
85+
ModelCustomMetadataItem,
86+
)
7987
from ads.telemetry import telemetry
8088

8189

@@ -323,6 +331,97 @@ def get(self, model_id: str, load_model_card: Optional[bool] = True) -> "AquaMod
323331

324332
return model_details
325333

334+
@telemetry(entry_point="plugin=model&action=delete", name="aqua")
335+
def delete_model(self, model_id):
336+
ds_model = DataScienceModel.from_id(model_id)
337+
is_registered_model = ds_model.freeform_tags.get(Tags.BASE_MODEL_CUSTOM, None)
338+
is_fine_tuned_model = ds_model.freeform_tags.get(
339+
Tags.AQUA_FINE_TUNED_MODEL_TAG, None
340+
)
341+
if is_registered_model or is_fine_tuned_model:
342+
return ds_model.delete()
343+
else:
344+
raise AquaRuntimeError(
345+
f"Failed to delete model:{model_id}. Only registered models or finetuned model can be deleted."
346+
)
347+
348+
@telemetry(entry_point="plugin=model&action=delete", name="aqua")
349+
def edit_registered_model(self, id, inference_container, enable_finetuning, task):
350+
"""Edits the default config of unverified registered model.
351+
352+
Parameters
353+
----------
354+
id: str
355+
The model OCID.
356+
inference_container: str.
357+
The inference container family name
358+
enable_finetuning: str
359+
Flag to enable or disable finetuning over the model. Defaults to None
360+
task:
361+
The usecase type of the model. e.g , text-generation , text_embedding etc.
362+
363+
Returns
364+
-------
365+
Model:
366+
The instance of oci.data_science.models.Model.
367+
368+
"""
369+
ds_model = DataScienceModel.from_id(id)
370+
if ds_model.freeform_tags.get(Tags.BASE_MODEL_CUSTOM, None):
371+
if ds_model.freeform_tags.get(Tags.AQUA_SERVICE_MODEL_TAG, None):
372+
raise AquaRuntimeError(
373+
f"Failed to edit model:{id}. Only registered unverified models can be edited."
374+
)
375+
else:
376+
custom_metadata_list = ds_model.custom_metadata_list
377+
freeform_tags = ds_model.freeform_tags
378+
if inference_container:
379+
custom_metadata_list.add(
380+
key=ModelCustomMetadataFields.DEPLOYMENT_CONTAINER,
381+
value=inference_container,
382+
category=MetadataCustomCategory.OTHER,
383+
description="Deployment container mapping for SMC",
384+
replace=True,
385+
)
386+
if enable_finetuning is not None:
387+
if enable_finetuning.lower() == "true":
388+
custom_metadata_list.add(
389+
key=ModelCustomMetadataFields.FINETUNE_CONTAINER,
390+
value=FineTuningContainerTypeFamily.AQUA_FINETUNING_CONTAINER_FAMILY,
391+
category=MetadataCustomCategory.OTHER,
392+
description="Fine-tuning container mapping for SMC",
393+
replace=True,
394+
)
395+
freeform_tags.update({Tags.READY_TO_FINE_TUNE: "true"})
396+
elif enable_finetuning.lower() == "false":
397+
try:
398+
custom_metadata_list.remove(
399+
ModelCustomMetadataFields.FINETUNE_CONTAINER
400+
)
401+
freeform_tags.pop(Tags.READY_TO_FINE_TUNE)
402+
except Exception as ex:
403+
raise AquaRuntimeError(
404+
f"The given model already doesn't support finetuning: {ex}"
405+
)
406+
407+
custom_metadata_list.remove("modelDescription")
408+
if task:
409+
freeform_tags.update({Tags.TASK: task})
410+
411+
updated_custom_metadata_list = [
412+
Metadata(**metadata)
413+
for metadata in custom_metadata_list.to_dict()["data"]
414+
]
415+
update_model_details = UpdateModelDetails(
416+
custom_metadata_list=updated_custom_metadata_list,
417+
freeform_tags=freeform_tags,
418+
)
419+
return AquaApp().update_model(id, update_model_details).data
420+
else:
421+
raise AquaRuntimeError(
422+
f"Failed to edit model:{id}. Only registered unverified models can be edited."
423+
)
424+
326425
def _fetch_metric_from_metadata(
327426
self,
328427
custom_metadata_list: ModelCustomMetadata,
@@ -935,38 +1034,39 @@ def _validate_model(
9351034
# gguf extension exist.
9361035
if {ModelFormat.SAFETENSORS, ModelFormat.GGUF}.issubset(set(model_formats)):
9371036
if (
938-
import_model_details.inference_container.lower() == InferenceContainerTypeFamily.AQUA_LLAMA_CPP_CONTAINER_FAMILY
1037+
import_model_details.inference_container.lower()
1038+
== InferenceContainerTypeFamily.AQUA_LLAMA_CPP_CONTAINER_FAMILY
9391039
):
9401040
self._validate_gguf_format(
9411041
import_model_details=import_model_details,
9421042
verified_model=verified_model,
9431043
gguf_model_files=gguf_model_files,
9441044
validation_result=validation_result,
945-
model_name=model_name
1045+
model_name=model_name,
9461046
)
9471047
else:
9481048
self._validate_safetensor_format(
9491049
import_model_details=import_model_details,
9501050
verified_model=verified_model,
9511051
validation_result=validation_result,
9521052
hf_download_config_present=hf_download_config_present,
953-
model_name=model_name
1053+
model_name=model_name,
9541054
)
9551055
elif ModelFormat.SAFETENSORS in model_formats:
9561056
self._validate_safetensor_format(
9571057
import_model_details=import_model_details,
9581058
verified_model=verified_model,
9591059
validation_result=validation_result,
9601060
hf_download_config_present=hf_download_config_present,
961-
model_name=model_name
1061+
model_name=model_name,
9621062
)
9631063
elif ModelFormat.GGUF in model_formats:
9641064
self._validate_gguf_format(
9651065
import_model_details=import_model_details,
9661066
verified_model=verified_model,
9671067
gguf_model_files=gguf_model_files,
9681068
validation_result=validation_result,
969-
model_name=model_name
1069+
model_name=model_name,
9701070
)
9711071

9721072
return validation_result
@@ -977,7 +1077,7 @@ def _validate_safetensor_format(
9771077
verified_model: DataScienceModel = None,
9781078
validation_result: ModelValidationResult = None,
9791079
hf_download_config_present: bool = None,
980-
model_name: str = None
1080+
model_name: str = None,
9811081
):
9821082
if import_model_details.download_from_hf:
9831083
# validates config.json exists for safetensors model from hugginface
@@ -1004,20 +1104,13 @@ def _validate_safetensor_format(
10041104
) from ex
10051105
else:
10061106
try:
1007-
metadata_model_type = (
1008-
verified_model.custom_metadata_list.get(
1009-
AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE
1010-
).value
1011-
)
1107+
metadata_model_type = verified_model.custom_metadata_list.get(
1108+
AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE
1109+
).value
10121110
if metadata_model_type:
1013-
if (
1014-
AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE
1015-
in model_config
1016-
):
1111+
if AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE in model_config:
10171112
if (
1018-
model_config[
1019-
AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE
1020-
]
1113+
model_config[AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE]
10211114
!= metadata_model_type
10221115
):
10231116
raise AquaRuntimeError(
@@ -1035,9 +1128,7 @@ def _validate_safetensor_format(
10351128
except Exception:
10361129
pass
10371130
if verified_model:
1038-
validation_result.telemetry_model_name = (
1039-
verified_model.display_name
1040-
)
1131+
validation_result.telemetry_model_name = verified_model.display_name
10411132
elif (
10421133
model_config is not None
10431134
and AQUA_MODEL_ARTIFACT_CONFIG_MODEL_NAME in model_config
@@ -1049,9 +1140,7 @@ def _validate_safetensor_format(
10491140
):
10501141
validation_result.telemetry_model_name = f"{AQUA_MODEL_TYPE_CUSTOM}_{model_config[AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE]}"
10511142
else:
1052-
validation_result.telemetry_model_name = (
1053-
AQUA_MODEL_TYPE_CUSTOM
1054-
)
1143+
validation_result.telemetry_model_name = AQUA_MODEL_TYPE_CUSTOM
10551144

10561145
@staticmethod
10571146
def _validate_gguf_format(

ads/aqua/modeldeployment/deployment.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -485,6 +485,18 @@ def list(self, **kwargs) -> List["AquaDeployment"]:
485485

486486
return results
487487

488+
@telemetry(entry_point="plugin=deployment&action=delete", name="aqua")
489+
def delete(self,model_deployment_id:str):
490+
return self.ds_client.delete_model_deployment(model_deployment_id=model_deployment_id).data
491+
492+
@telemetry(entry_point="plugin=deployment&action=deactivate",name="aqua")
493+
def deactivate(self,model_deployment_id:str):
494+
return self.ds_client.deactivate_model_deployment(model_deployment_id=model_deployment_id).data
495+
496+
@telemetry(entry_point="plugin=deployment&action=activate",name="aqua")
497+
def activate(self,model_deployment_id:str):
498+
return self.ds_client.activate_model_deployment(model_deployment_id=model_deployment_id).data
499+
488500
@telemetry(entry_point="plugin=deployment&action=get", name="aqua")
489501
def get(self, model_deployment_id: str, **kwargs) -> "AquaDeploymentDetail":
490502
"""Gets the information of Aqua model deployment.

tests/unitary/with_extras/aqua/test_deployment_handler.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,30 @@ def test_get_deployment(self, mock_get):
9292
self.deployment_handler.get(id="mock-model-id")
9393
mock_get.assert_called()
9494

95+
@patch("ads.aqua.modeldeployment.AquaDeploymentApp.delete")
96+
def test_delete_deployment(self, mock_delete):
97+
self.deployment_handler.request.path = "aqua/deployments"
98+
self.deployment_handler.delete("mock-model-id")
99+
mock_delete.assert_called()
100+
101+
@patch("ads.aqua.modeldeployment.AquaDeploymentApp.activate")
102+
def test_activate_deployment(self, mock_activate):
103+
self.deployment_handler.request.path = (
104+
"aqua/deployments/ocid1.datasciencemodeldeployment.oc1.iad.xxx/activate"
105+
)
106+
mock_activate.return_value = {"lifecycle_state": "UPDATING"}
107+
self.deployment_handler.put()
108+
mock_activate.assert_called()
109+
110+
@patch("ads.aqua.modeldeployment.AquaDeploymentApp.deactivate")
111+
def test_deactivate_deployment(self, mock_deactivate):
112+
self.deployment_handler.request.path = (
113+
"aqua/deployments/ocid1.datasciencemodeldeployment.oc1.iad.xxx/deactivate"
114+
)
115+
mock_deactivate.return_value = {"lifecycle_state": "UPDATING"}
116+
self.deployment_handler.put()
117+
mock_deactivate.assert_called()
118+
95119
@patch("ads.aqua.modeldeployment.AquaDeploymentApp.list")
96120
def test_list_deployment(self, mock_list):
97121
"""Test get method to return a list of model deployments."""

0 commit comments

Comments
 (0)