Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions ads/aqua/extension/model_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ def post(self, *args, **kwargs): # noqa: ARG002
ignore_patterns = input_data.get("ignore_patterns")
freeform_tags = input_data.get("freeform_tags")
defined_tags = input_data.get("defined_tags")
task = input_data.get("task")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we can use freeform_tags input here for task, right? Then just pop the dict later on and check if the tags contain task.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes correct . closing this PR.
We can use freeform tags for task input for OSS flow based unverified model registration


return self.finish(
AquaModelApp().register(
Expand All @@ -149,6 +150,7 @@ def post(self, *args, **kwargs): # noqa: ARG002
ignore_patterns=ignore_patterns,
freeform_tags=freeform_tags,
defined_tags=defined_tags,
task=task,
)
)

Expand Down
1 change: 1 addition & 0 deletions ads/aqua/model/entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,7 @@ class ImportModelDetails(CLIBuilderMixin):
ignore_patterns: Optional[List[str]] = None
freeform_tags: Optional[dict] = None
defined_tags: Optional[dict] = None
task: Optional[str] = None

def __post_init__(self):
self._command = "model register"
8 changes: 7 additions & 1 deletion ads/aqua/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,7 +376,7 @@ def delete_model(self, model_id):
f"Failed to delete model:{model_id}. Only registered models or finetuned model can be deleted."
)

@telemetry(entry_point="plugin=model&action=delete", name="aqua")
@telemetry(entry_point="plugin=model&action=edit", name="aqua")
def edit_registered_model(self, id, inference_container, enable_finetuning, task):
"""Edits the default config of unverified registered model.

Expand Down Expand Up @@ -1119,6 +1119,12 @@ def _validate_model(
validation_result.tags = hf_tags
except Exception:
pass
else:
validation_result.tags = {
Tags.TASK: import_model_details.task
if import_model_details.task
else UNKNOWN
}

validation_result.model_formats = model_formats

Expand Down
8 changes: 8 additions & 0 deletions tests/unitary/with_extras/aqua/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -1021,11 +1021,15 @@ def test_import_any_model_smc_container(
inference_container="odsc-vllm-or-tgi-container",
finetuning_container="odsc-llm-fine-tuning",
download_from_hf=False,
task="text-generation",
)
assert model.tags == {
"aqua_custom_base_model": "true",
"model_format": "SAFETENSORS",
"ready_to_fine_tune": "true",
"license": "",
"organization": "",
"task": "text-generation",
**ds_freeform_tags,
}
assert model.inference_container == "odsc-vllm-or-tgi-container"
Expand Down Expand Up @@ -1162,6 +1166,7 @@ def test_import_model_with_input_tags(
download_from_hf=False,
freeform_tags={"ftag1": "fvalue1", "ftag2": "fvalue2"},
defined_tags={"dtag1": "dvalue1", "dtag2": "dvalue2"},
task="image_text_to_text",
)
assert model.tags == {
"aqua_custom_base_model": "true",
Expand All @@ -1171,6 +1176,9 @@ def test_import_model_with_input_tags(
"dtag2": "dvalue2",
"ftag1": "fvalue1",
"ftag2": "fvalue2",
"license": "",
"organization": "",
"task": "image_text_to_text",
**ds_freeform_tags,
}

Expand Down
1 change: 1 addition & 0 deletions tests/unitary/with_extras/aqua/test_model_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,7 @@ def test_register(
ignore_patterns=ignore_patterns,
freeform_tags=freeform_tags,
defined_tags=defined_tags,
task=None,
)
assert result["id"] == "test_id"
assert result["inference_container"] == "odsc-tgi-serving"
Expand Down
Loading