diff --git a/ads/aqua/common/enums.py b/ads/aqua/common/enums.py index 4a423788d..f1d1cd661 100644 --- a/ads/aqua/common/enums.py +++ b/ads/aqua/common/enums.py @@ -52,6 +52,9 @@ class InferenceContainerTypeFamily(str, metaclass=ExtendedEnumMeta): AQUA_VLLM_CONTAINER_FAMILY = "odsc-vllm-serving" AQUA_TGI_CONTAINER_FAMILY = "odsc-tgi-serving" AQUA_LLAMA_CPP_CONTAINER_FAMILY = "odsc-llama-cpp-serving" + + +class CustomInferenceContainerTypeFamily(str, metaclass=ExtendedEnumMeta): AQUA_TEI_CONTAINER_FAMILY = "odsc-tei-serving" diff --git a/ads/aqua/extension/model_handler.py b/ads/aqua/extension/model_handler.py index 42f90ffef..68c9cedaa 100644 --- a/ads/aqua/extension/model_handler.py +++ b/ads/aqua/extension/model_handler.py @@ -8,6 +8,9 @@ from tornado.web import HTTPError from ads.aqua.common.decorator import handle_exceptions +from ads.aqua.common.enums import ( + CustomInferenceContainerTypeFamily, +) from ads.aqua.common.errors import AquaRuntimeError, AquaValueError from ads.aqua.common.utils import ( get_hf_model_info, @@ -163,7 +166,9 @@ def put(self, id): raise HTTPError(400, Errors.NO_INPUT_DATA) inference_container = input_data.get("inference_container") + inference_container_uri = input_data.get("inference_container_uri") inference_containers = AquaModelApp.list_valid_inference_containers() + inference_containers.extend(CustomInferenceContainerTypeFamily.values()) if ( inference_container is not None and inference_container not in inference_containers @@ -176,7 +181,13 @@ def put(self, id): task = input_data.get("task") app = AquaModelApp() self.finish( - app.edit_registered_model(id, inference_container, enable_finetuning, task) + app.edit_registered_model( + id, + inference_container, + inference_container_uri, + enable_finetuning, + task, + ) ) app.clear_model_details_cache(model_id=id) diff --git a/ads/aqua/model/model.py b/ads/aqua/model/model.py index 02e0df00f..064f5fab4 100644 --- a/ads/aqua/model/model.py +++ b/ads/aqua/model/model.py @@ -15,6 +15,7 @@ from ads.aqua import ODSC_MODEL_COMPARTMENT_OCID, logger from ads.aqua.app import AquaApp from ads.aqua.common.enums import ( + CustomInferenceContainerTypeFamily, FineTuningContainerTypeFamily, InferenceContainerTypeFamily, Tags, @@ -376,8 +377,10 @@ def delete_model(self, model_id): f"Failed to delete model:{model_id}. Only registered models or finetuned model can be deleted." ) - @telemetry(entry_point="plugin=model&action=delete", name="aqua") - def edit_registered_model(self, id, inference_container, enable_finetuning, task): + @telemetry(entry_point="plugin=model&action=edit", name="aqua") + def edit_registered_model( + self, id, inference_container, inference_container_uri, enable_finetuning, task + ): """Edits the default config of unverified registered model. Parameters @@ -386,6 +389,8 @@ def edit_registered_model(self, id, inference_container, enable_finetuning, task The model OCID. inference_container: str. The inference container family name + inference_container_uri: str + The inference container uri for embedding models enable_finetuning: str Flag to enable or disable finetuning over the model. Defaults to None task: @@ -401,19 +406,44 @@ def edit_registered_model(self, id, inference_container, enable_finetuning, task if ds_model.freeform_tags.get(Tags.BASE_MODEL_CUSTOM, None): if ds_model.freeform_tags.get(Tags.AQUA_SERVICE_MODEL_TAG, None): raise AquaRuntimeError( - f"Failed to edit model:{id}. Only registered unverified models can be edited." + "Only registered unverified models can be edited." ) else: custom_metadata_list = ds_model.custom_metadata_list freeform_tags = ds_model.freeform_tags if inference_container: - custom_metadata_list.add( - key=ModelCustomMetadataFields.DEPLOYMENT_CONTAINER, - value=inference_container, - category=MetadataCustomCategory.OTHER, - description="Deployment container mapping for SMC", - replace=True, - ) + if ( + inference_container in CustomInferenceContainerTypeFamily + and inference_container_uri is None + ): + raise AquaRuntimeError( + "Inference container URI must be provided." + ) + else: + custom_metadata_list.add( + key=ModelCustomMetadataFields.DEPLOYMENT_CONTAINER, + value=inference_container, + category=MetadataCustomCategory.OTHER, + description="Deployment container mapping for SMC", + replace=True, + ) + if inference_container_uri: + if ( + inference_container in CustomInferenceContainerTypeFamily + or inference_container is None + ): + custom_metadata_list.add( + key=ModelCustomMetadataFields.DEPLOYMENT_CONTAINER_URI, + value=inference_container_uri, + category=MetadataCustomCategory.OTHER, + description=f"Inference container URI for {ds_model.display_name}", + replace=True, + ) + else: + raise AquaRuntimeError( + f"Inference container URI can be edited only with container values: {CustomInferenceContainerTypeFamily.values()}" + ) + if enable_finetuning is not None: if enable_finetuning.lower() == "true": custom_metadata_list.add( @@ -448,9 +478,7 @@ def edit_registered_model(self, id, inference_container, enable_finetuning, task ) AquaApp().update_model(id, update_model_details) else: - raise AquaRuntimeError( - f"Failed to edit model:{id}. Only registered unverified models can be edited." - ) + raise AquaRuntimeError("Only registered unverified models can be edited.") def _fetch_metric_from_metadata( self, @@ -869,8 +897,7 @@ def _create_model_catalog_entry( # only add cmd vars if inference container is not an SMC if ( inference_container not in smc_container_set - and inference_container - == InferenceContainerTypeFamily.AQUA_TEI_CONTAINER_FAMILY + and inference_container in CustomInferenceContainerTypeFamily.values() ): cmd_vars = generate_tei_cmd_var(os_path) metadata.add(