Skip to content

Commit 3a9fa5b

Browse files
Addressing review comments
1 parent c6f87ec commit 3a9fa5b

File tree

3 files changed

+12
-15
lines changed

3 files changed

+12
-15
lines changed

ads/aqua/common/enums.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,8 +52,9 @@ class InferenceContainerTypeFamily(str, metaclass=ExtendedEnumMeta):
5252
AQUA_VLLM_CONTAINER_FAMILY = "odsc-vllm-serving"
5353
AQUA_TGI_CONTAINER_FAMILY = "odsc-tgi-serving"
5454
AQUA_LLAMA_CPP_CONTAINER_FAMILY = "odsc-llama-cpp-serving"
55-
AQUA_TEI_CONTAINER_FAMILY = "odsc-tei-serving"
5655

56+
class CustomInferenceContainerTypeFamily(str,metaclass=ExtendedEnumMeta):
57+
AQUA_TEI_CONTAINER_FAMILY="odsc-tei-serving"
5758

5859
class InferenceContainerParamType(str, metaclass=ExtendedEnumMeta):
5960
PARAM_TYPE_VLLM = "VLLM_PARAMS"

ads/aqua/extension/model_handler.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from tornado.web import HTTPError
99

1010
from ads.aqua.common.decorator import handle_exceptions
11-
from ads.aqua.common.enums import InferenceContainerTypeFamily
11+
from ads.aqua.common.enums import InferenceContainerTypeFamily,CustomInferenceContainerTypeFamily
1212
from ads.aqua.common.errors import AquaRuntimeError, AquaValueError
1313
from ads.aqua.common.utils import (
1414
get_hf_model_info,
@@ -166,11 +166,10 @@ def put(self, id):
166166
inference_container = input_data.get("inference_container")
167167
inference_container_uri = input_data.get("inference_container_uri")
168168
inference_containers = AquaModelApp.list_valid_inference_containers()
169+
inference_containers.extend(CustomInferenceContainerTypeFamily.values())
169170
if (
170171
inference_container is not None
171172
and inference_container not in inference_containers
172-
and inference_container
173-
!= InferenceContainerTypeFamily.AQUA_TEI_CONTAINER_FAMILY
174173
):
175174
raise HTTPError(
176175
400, Errors.INVALID_VALUE_OF_PARAMETER.format("inference_container")

ads/aqua/model/model.py

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from ads.aqua.common.enums import (
1818
FineTuningContainerTypeFamily,
1919
InferenceContainerTypeFamily,
20-
Tags,
20+
Tags, CustomInferenceContainerTypeFamily,
2121
)
2222
from ads.aqua.common.errors import AquaRuntimeError, AquaValueError
2323
from ads.aqua.common.utils import (
@@ -405,19 +405,18 @@ def edit_registered_model(
405405
if ds_model.freeform_tags.get(Tags.BASE_MODEL_CUSTOM, None):
406406
if ds_model.freeform_tags.get(Tags.AQUA_SERVICE_MODEL_TAG, None):
407407
raise AquaRuntimeError(
408-
f"Failed to edit model:{id}. Only registered unverified models can be edited."
408+
f"Only registered unverified models can be edited."
409409
)
410410
else:
411411
custom_metadata_list = ds_model.custom_metadata_list
412412
freeform_tags = ds_model.freeform_tags
413413
if inference_container:
414414
if (
415-
inference_container
416-
== InferenceContainerTypeFamily.AQUA_TEI_CONTAINER_FAMILY
415+
inference_container in CustomInferenceContainerTypeFamily.values()
417416
and inference_container_uri is None
418417
):
419418
raise AquaRuntimeError(
420-
f"Failed to edit model:{id}. Inference container URI must be provided."
419+
f"Inference container URI must be provided."
421420
)
422421
else:
423422
custom_metadata_list.add(
@@ -429,8 +428,7 @@ def edit_registered_model(
429428
)
430429
if inference_container_uri:
431430
if (
432-
inference_container
433-
== InferenceContainerTypeFamily.AQUA_TEI_CONTAINER_FAMILY
431+
inference_container in CustomInferenceContainerTypeFamily.values()
434432
or inference_container is None
435433
):
436434
custom_metadata_list.add(
@@ -442,7 +440,7 @@ def edit_registered_model(
442440
)
443441
else:
444442
raise AquaRuntimeError(
445-
f"Failed to edit model:{id}. Inference container URI can be edited only with TEI container."
443+
f"Inference container URI can be edited only with container values: {CustomInferenceContainerTypeFamily.values()}"
446444
)
447445

448446
if enable_finetuning is not None:
@@ -480,7 +478,7 @@ def edit_registered_model(
480478
AquaApp().update_model(id, update_model_details)
481479
else:
482480
raise AquaRuntimeError(
483-
f"Failed to edit model:{id}. Only registered unverified models can be edited."
481+
f"Only registered unverified models can be edited."
484482
)
485483

486484
def _fetch_metric_from_metadata(
@@ -900,8 +898,7 @@ def _create_model_catalog_entry(
900898
# only add cmd vars if inference container is not an SMC
901899
if (
902900
inference_container not in smc_container_set
903-
and inference_container
904-
== InferenceContainerTypeFamily.AQUA_TEI_CONTAINER_FAMILY
901+
and inference_container in CustomInferenceContainerTypeFamily.values()
905902
):
906903
cmd_vars = generate_tei_cmd_var(os_path)
907904
metadata.add(

0 commit comments

Comments
 (0)