diff --git a/ads/aqua/extension/model_handler.py b/ads/aqua/extension/model_handler.py index 42f90ffef..8d5002150 100644 --- a/ads/aqua/extension/model_handler.py +++ b/ads/aqua/extension/model_handler.py @@ -133,6 +133,7 @@ def post(self, *args, **kwargs): # noqa: ARG002 ignore_patterns = input_data.get("ignore_patterns") freeform_tags = input_data.get("freeform_tags") defined_tags = input_data.get("defined_tags") + task = input_data.get("task") return self.finish( AquaModelApp().register( @@ -149,6 +150,7 @@ def post(self, *args, **kwargs): # noqa: ARG002 ignore_patterns=ignore_patterns, freeform_tags=freeform_tags, defined_tags=defined_tags, + task=task, ) ) diff --git a/ads/aqua/model/entities.py b/ads/aqua/model/entities.py index ecdb8b8e7..865ac14b6 100644 --- a/ads/aqua/model/entities.py +++ b/ads/aqua/model/entities.py @@ -293,6 +293,7 @@ class ImportModelDetails(CLIBuilderMixin): ignore_patterns: Optional[List[str]] = None freeform_tags: Optional[dict] = None defined_tags: Optional[dict] = None + task: Optional[str] = None def __post_init__(self): self._command = "model register" diff --git a/ads/aqua/model/model.py b/ads/aqua/model/model.py index 02e0df00f..c0d2345d3 100644 --- a/ads/aqua/model/model.py +++ b/ads/aqua/model/model.py @@ -376,7 +376,7 @@ def delete_model(self, model_id): f"Failed to delete model:{model_id}. Only registered models or finetuned model can be deleted." ) - @telemetry(entry_point="plugin=model&action=delete", name="aqua") + @telemetry(entry_point="plugin=model&action=edit", name="aqua") def edit_registered_model(self, id, inference_container, enable_finetuning, task): """Edits the default config of unverified registered model. @@ -1119,6 +1119,12 @@ def _validate_model( validation_result.tags = hf_tags except Exception: pass + else: + validation_result.tags = { + Tags.TASK: import_model_details.task + if import_model_details.task + else UNKNOWN + } validation_result.model_formats = model_formats diff --git a/tests/unitary/with_extras/aqua/test_model.py b/tests/unitary/with_extras/aqua/test_model.py index cabb8c523..c69a57aa2 100644 --- a/tests/unitary/with_extras/aqua/test_model.py +++ b/tests/unitary/with_extras/aqua/test_model.py @@ -1021,11 +1021,15 @@ def test_import_any_model_smc_container( inference_container="odsc-vllm-or-tgi-container", finetuning_container="odsc-llm-fine-tuning", download_from_hf=False, + task="text-generation", ) assert model.tags == { "aqua_custom_base_model": "true", "model_format": "SAFETENSORS", "ready_to_fine_tune": "true", + "license": "", + "organization": "", + "task": "text-generation", **ds_freeform_tags, } assert model.inference_container == "odsc-vllm-or-tgi-container" @@ -1162,6 +1166,7 @@ def test_import_model_with_input_tags( download_from_hf=False, freeform_tags={"ftag1": "fvalue1", "ftag2": "fvalue2"}, defined_tags={"dtag1": "dvalue1", "dtag2": "dvalue2"}, + task="image_text_to_text", ) assert model.tags == { "aqua_custom_base_model": "true", @@ -1171,6 +1176,9 @@ def test_import_model_with_input_tags( "dtag2": "dvalue2", "ftag1": "fvalue1", "ftag2": "fvalue2", + "license": "", + "organization": "", + "task": "image_text_to_text", **ds_freeform_tags, } diff --git a/tests/unitary/with_extras/aqua/test_model_handler.py b/tests/unitary/with_extras/aqua/test_model_handler.py index bf02174b9..fda60ae94 100644 --- a/tests/unitary/with_extras/aqua/test_model_handler.py +++ b/tests/unitary/with_extras/aqua/test_model_handler.py @@ -218,6 +218,7 @@ def test_register( ignore_patterns=ignore_patterns, freeform_tags=freeform_tags, defined_tags=defined_tags, + task=None, ) assert result["id"] == "test_id" assert result["inference_container"] == "odsc-tgi-serving"