Skip to content

Commit 987b60d

Browse files
committed
refactored _get_task
1 parent 7897b38 commit 987b60d

File tree

2 files changed

+25
-8
lines changed

2 files changed

+25
-8
lines changed

ads/aqua/model/model.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -317,10 +317,7 @@ def create_multi(
317317

318318
display_name_list.append(display_name)
319319

320-
model_task = source_model.freeform_tags.get(Tags.TASK, UNKNOWN)
321-
322-
if model_task != UNKNOWN:
323-
self._get_task(model, model_task)
320+
self._get_task(model, source_model)
324321

325322
# Retrieve model artifact
326323
model_artifact_path = source_model.artifact
@@ -710,15 +707,23 @@ def edit_registered_model(
710707
else:
711708
raise AquaRuntimeError("Only registered unverified models can be edited.")
712709

713-
def _get_task(self, model: AquaMultiModelRef, freeform_task_tag: str) -> str:
710+
def _get_task(
711+
self,
712+
model: AquaMultiModelRef,
713+
source_model: DataScienceModel,
714+
) -> str:
714715
"""In a Multi Model Deployment, will set model_task parameter in AquaMultiModelRef from freeform tags or user"""
715-
task_tag = re.sub(r"-", "_", freeform_task_tag)
716+
# user does not supply model task, we extract from model metadata
717+
if not model.model_task:
718+
model.model_task = source_model.freeform_tags.get(Tags.TASK, UNKNOWN)
719+
720+
task_tag = re.sub(r"-", "_", model.model_task)
716721

717722
if task_tag in MultiModelSupportedTaskType:
718723
model.model_task = task_tag
719724
else:
720725
raise AquaValueError(
721-
f"{freeform_task_tag} is not supported. Valid model_task inputs are: {MultiModelSupportedTaskType.values()}."
726+
f"{task_tag} is not supported. Valid model_task inputs are: {MultiModelSupportedTaskType.values()}."
722727
)
723728

724729
def _fetch_metric_from_metadata(

tests/unitary/with_extras/aqua/test_model.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -440,7 +440,8 @@ def test_create_multimodel(
440440
mock_model.custom_metadata_list = custom_metadata_list
441441
mock_from_id.return_value = mock_model
442442

443-
mock_model.freeform_tags["task"] = "invalid_task"
443+
# testing _get_task when a user passes an invalid task to AquaMultiModelRef
444+
model_info_1.model_task = "invalid_task"
444445

445446
with pytest.raises(AquaValueError):
446447
model = self.app.create_multi(
@@ -449,7 +450,18 @@ def test_create_multimodel(
449450
compartment_id="test_compartment_id",
450451
)
451452

453+
# testing if a user tries to invoke a model with a task mode that is not yet supported
454+
model_info_1.model_task = None
455+
mock_model.freeform_tags["task"] = "unsupported_task"
456+
with pytest.raises(AquaValueError):
457+
model = self.app.create_multi(
458+
models=[model_info_1, model_info_2],
459+
project_id="test_project_id",
460+
compartment_id="test_compartment_id",
461+
)
462+
452463
mock_model.freeform_tags["task"] = "text-generation"
464+
model_info_1.model_task = "text_embedding"
453465

454466
# will create a multi-model group
455467
model = self.app.create_multi(

0 commit comments

Comments
 (0)