Skip to content

Commit 7897b38

Browse files
committed
added validation logic for model_task and unit test in test_model.py
1 parent 4e2bda0 commit 7897b38

File tree

4 files changed

+44
-30
lines changed

4 files changed

+44
-30
lines changed

ads/aqua/model/enums.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,6 @@ class FineTuningCustomMetadata(ExtendedEnum):
2828
class MultiModelSupportedTaskType(ExtendedEnum):
2929
TEXT_GENERATION = "text-generation"
3030
TEXT_GENERATION_ALT = "text_generation"
31-
EMBEDDING_ALT = "text_embedding"
32-
33-
class MultiModelConfigMode(ExtendedEnum):
34-
EMBEDDING = "embedding"
35-
DEFAULT = "completion"
31+
IMAGE_TEXT_TO_TEXT = "image_text_to_text"
32+
CODE_SYNTHESIS = "code_synthesis"
33+
EMBEDDING = "text_embedding"

ads/aqua/model/model.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import json
55
import os
66
import pathlib
7+
import re
78
from datetime import datetime, timedelta
89
from threading import Lock
910
from typing import Any, Dict, List, Optional, Set, Union
@@ -80,7 +81,7 @@
8081
ImportModelDetails,
8182
ModelValidationResult,
8283
)
83-
from ads.aqua.model.enums import MultiModelSupportedTaskType, MultiModelConfigMode
84+
from ads.aqua.model.enums import MultiModelSupportedTaskType
8485
from ads.common.auth import default_signer
8586
from ads.common.oci_resource import SEARCH_TYPE, OCIResource
8687
from ads.common.utils import (
@@ -709,14 +710,16 @@ def edit_registered_model(
709710
else:
710711
raise AquaRuntimeError("Only registered unverified models can be edited.")
711712

712-
def _get_task(
713-
self,
714-
model: AquaMultiModelRef,
715-
freeform_task_tag: str
716-
) -> str:
717-
"""In a Multi Model Deployment, will set model task if freeform task tag from model needs a non-completion endpoint (embedding)"""
718-
if freeform_task_tag == MultiModelSupportedTaskType.EMBEDDING_ALT:
719-
model.model_task = MultiModelConfigMode.EMBEDDING
713+
def _get_task(self, model: AquaMultiModelRef, freeform_task_tag: str) -> str:
714+
"""In a Multi Model Deployment, will set model_task parameter in AquaMultiModelRef from freeform tags or user"""
715+
task_tag = re.sub(r"-", "_", freeform_task_tag)
716+
717+
if task_tag in MultiModelSupportedTaskType:
718+
model.model_task = task_tag
719+
else:
720+
raise AquaValueError(
721+
f"{freeform_task_tag} is not supported. Valid model_task inputs are: {MultiModelSupportedTaskType.values()}."
722+
)
720723

721724
def _fetch_metric_from_metadata(
722725
self,

tests/unitary/with_extras/aqua/test_deployment.py

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
ModelDeploymentConfigSummary,
4646
ModelParams,
4747
)
48+
from ads.aqua.model.enums import MultiModelSupportedTaskType
4849
from ads.aqua.modeldeployment.utils import MultiModelDeploymentConfigLoader
4950
from ads.model.datascience_model import DataScienceModel
5051
from ads.model.deployment.model_deployment import ModelDeployment
@@ -276,7 +277,7 @@ class TestDataset:
276277
"environment_configuration_type": "OCIR_CONTAINER",
277278
"environment_variables": {
278279
"MODEL_DEPLOY_PREDICT_ENDPOINT": "/v1/completions",
279-
"MULTI_MODEL_CONFIG": '{ "models": [{ "params": "--served-model-name model_one --tensor-parallel-size 1 --max-model-len 2096", "model_path": "models/model_one/5be6479/artifact/", "model_task": "embedding"}, {"params": "--served-model-name model_two --tensor-parallel-size 1 --max-model-len 2096", "model_path": "models/model_two/83e9aa1/artifact/"}, {"params": "--served-model-name model_three --tensor-parallel-size 1 --max-model-len 2096", "model_path": "models/model_three/83e9aa1/artifact/"}]}',
280+
"MULTI_MODEL_CONFIG": '{ "models": [{ "params": "--served-model-name model_one --tensor-parallel-size 1 --max-model-len 2096", "model_path": "models/model_one/5be6479/artifact/", "model_task": "text_embedding"}, {"params": "--served-model-name model_two --tensor-parallel-size 1 --max-model-len 2096", "model_path": "models/model_two/83e9aa1/artifact/", "model_task": "image_text_to_text"}, {"params": "--served-model-name model_three --tensor-parallel-size 1 --max-model-len 2096", "model_path": "models/model_three/83e9aa1/artifact/", "model_task": "code_synthesis"}]}',
280281
},
281282
"health_check_port": 8080,
282283
"image": "dsmc://image-name:1.0.0.0",
@@ -486,30 +487,30 @@ class TestDataset:
486487
"gpu_count": 2,
487488
"model_id": "test_model_id_1",
488489
"model_name": "test_model_1",
489-
"model_task": "embedding",
490+
"model_task": "text_embedding",
490491
"artifact_location": "test_location_1",
491492
},
492493
{
493494
"env_var": {},
494495
"gpu_count": 2,
495496
"model_id": "test_model_id_2",
496497
"model_name": "test_model_2",
497-
"model_task": None,
498+
"model_task": "image_text_to_text",
498499
"artifact_location": "test_location_2",
499500
},
500501
{
501502
"env_var": {},
502503
"gpu_count": 2,
503504
"model_id": "test_model_id_3",
504505
"model_name": "test_model_3",
505-
"model_task": None,
506+
"model_task": "code_synthesis",
506507
"artifact_location": "test_location_3",
507508
},
508509
],
509510
"model_id": "ocid1.datasciencemodel.oc1.<region>.<OCID>",
510511
"environment_variables": {
511512
"MODEL_DEPLOY_PREDICT_ENDPOINT": "/v1/completions",
512-
"MULTI_MODEL_CONFIG": '{ "models": [{ "params": "--served-model-name model_one --tensor-parallel-size 1 --max-model-len 2096", "model_path": "models/model_one/5be6479/artifact/", "model_task": "embedding"}, {"params": "--served-model-name model_two --tensor-parallel-size 1 --max-model-len 2096", "model_path": "models/model_two/83e9aa1/artifact/"}, {"params": "--served-model-name model_three --tensor-parallel-size 1 --max-model-len 2096", "model_path": "models/model_three/83e9aa1/artifact/"}]}',
513+
"MULTI_MODEL_CONFIG": '{ "models": [{ "params": "--served-model-name model_one --tensor-parallel-size 1 --max-model-len 2096", "model_path": "models/model_one/5be6479/artifact/", "model_task": "text_embedding"}, {"params": "--served-model-name model_two --tensor-parallel-size 1 --max-model-len 2096", "model_path": "models/model_two/83e9aa1/artifact/", "model_task": "image_text_to_text"}, {"params": "--served-model-name model_three --tensor-parallel-size 1 --max-model-len 2096", "model_path": "models/model_three/83e9aa1/artifact/", "model_task": "code_synthesis"}]}',
513514
},
514515
"cmd": [],
515516
"console_link": "https://cloud.oracle.com/data-science/model-deployments/ocid1.datasciencemodeldeployment.oc1.<region>.<MD_OCID>?region=region-name",
@@ -968,23 +969,23 @@ class TestDataset:
968969
"gpu_count": 1,
969970
"model_id": "ocid1.compartment.oc1..<OCID>",
970971
"model_name": "model_one",
971-
"model_task": "embedding",
972+
"model_task": "text_embedding",
972973
"artifact_location": "artifact_location_one",
973974
},
974975
{
975976
"env_var": {"--test_key_two": "test_value_two"},
976977
"gpu_count": 1,
977978
"model_id": "ocid1.compartment.oc1..<OCID>",
978979
"model_name": "model_two",
979-
"model_task": None,
980+
"model_task": "image_text_to_text",
980981
"artifact_location": "artifact_location_two",
981982
},
982983
{
983984
"env_var": {"--test_key_three": "test_value_three"},
984985
"gpu_count": 1,
985986
"model_id": "ocid1.compartment.oc1..<OCID>",
986987
"model_name": "model_three",
987-
"model_task": None,
988+
"model_task": "code_synthesis",
988989
"artifact_location": "artifact_location_three",
989990
},
990991
]
@@ -1793,23 +1794,23 @@ def test_create_deployment_for_multi_model(
17931794
model_info_1 = AquaMultiModelRef(
17941795
model_id="test_model_id_1",
17951796
model_name="test_model_1",
1796-
model_task="embedding",
1797+
model_task="text_embedding",
17971798
gpu_count=2,
17981799
artifact_location="test_location_1",
17991800
)
18001801

18011802
model_info_2 = AquaMultiModelRef(
18021803
model_id="test_model_id_2",
18031804
model_name="test_model_2",
1804-
model_task=None,
1805+
model_task="image_text_to_text",
18051806
gpu_count=2,
18061807
artifact_location="test_location_2",
18071808
)
18081809

18091810
model_info_3 = AquaMultiModelRef(
18101811
model_id="test_model_id_3",
18111812
model_name="test_model_3",
1812-
model_task=None,
1813+
model_task="code_synthesis",
18131814
gpu_count=2,
18141815
artifact_location="test_location_3",
18151816
)

tests/unitary/with_extras/aqua/test_model.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
import json
77
import os
8+
import re
89
import shlex
910
import tempfile
1011
from dataclasses import asdict
@@ -13,17 +14,14 @@
1314

1415
import oci
1516
import pytest
16-
17-
from ads.aqua.app import AquaApp
18-
from ads.aqua.config.container_config import AquaContainerConfig
1917
from huggingface_hub.hf_api import HfApi, ModelInfo
2018
from parameterized import parameterized
2119

2220
import ads.aqua.model
2321
import ads.common
2422
import ads.common.oci_client
2523
import ads.config
26-
24+
from ads.aqua.app import AquaApp
2725
from ads.aqua.common.entities import AquaMultiModelRef
2826
from ads.aqua.common.enums import ModelFormat, Tags
2927
from ads.aqua.common.errors import (
@@ -32,6 +30,7 @@
3230
AquaValueError,
3331
)
3432
from ads.aqua.common.utils import get_hf_model_info
33+
from ads.aqua.config.container_config import AquaContainerConfig
3534
from ads.aqua.constants import HF_METADATA_FOLDER
3635
from ads.aqua.model import AquaModelApp
3736
from ads.aqua.model.entities import (
@@ -40,14 +39,14 @@
4039
ImportModelDetails,
4140
ModelValidationResult,
4241
)
42+
from ads.aqua.model.enums import MultiModelSupportedTaskType
4343
from ads.common.object_storage_details import ObjectStorageDetails
4444
from ads.model.datascience_model import DataScienceModel
4545
from ads.model.model_metadata import (
4646
ModelCustomMetadata,
4747
ModelProvenanceMetadata,
4848
ModelTaxonomyMetadata,
4949
)
50-
5150
from tests.unitary.with_extras.aqua.utils import ServiceManagedContainers
5251

5352

@@ -397,12 +396,14 @@ def test_create_multimodel(
397396
model_info_1 = AquaMultiModelRef(
398397
model_id="test_model_id_1",
399398
gpu_count=2,
399+
model_task = "text_embedding",
400400
env_var={"params": "--trust-remote-code --max-model-len 60000"},
401401
)
402402

403403
model_info_2 = AquaMultiModelRef(
404404
model_id="test_model_id_2",
405405
gpu_count=2,
406+
model_task = "image_text_to_text",
406407
env_var={"params": "--trust-remote-code --max-model-len 32000"},
407408
)
408409

@@ -439,6 +440,17 @@ def test_create_multimodel(
439440
mock_model.custom_metadata_list = custom_metadata_list
440441
mock_from_id.return_value = mock_model
441442

443+
mock_model.freeform_tags["task"] = "invalid_task"
444+
445+
with pytest.raises(AquaValueError):
446+
model = self.app.create_multi(
447+
models=[model_info_1, model_info_2],
448+
project_id="test_project_id",
449+
compartment_id="test_compartment_id",
450+
)
451+
452+
mock_model.freeform_tags["task"] = "text-generation"
453+
442454
# will create a multi-model group
443455
model = self.app.create_multi(
444456
models=[model_info_1, model_info_2],

0 commit comments

Comments
 (0)