diff --git a/ads/aqua/modeldeployment/deployment.py b/ads/aqua/modeldeployment/deployment.py index aab79b008..52d7613b6 100644 --- a/ads/aqua/modeldeployment/deployment.py +++ b/ads/aqua/modeldeployment/deployment.py @@ -68,10 +68,7 @@ ModelDeploymentConfigSummary, MultiModelDeploymentConfigLoader, ) -from ads.aqua.modeldeployment.constants import ( - DEFAULT_POLL_INTERVAL, - DEFAULT_WAIT_TIME, -) +from ads.aqua.modeldeployment.constants import DEFAULT_POLL_INTERVAL, DEFAULT_WAIT_TIME from ads.aqua.modeldeployment.entities import ( AquaDeployment, AquaDeploymentDetail, @@ -529,6 +526,7 @@ def _create( # validate user provided params user_params = env_var.get("PARAMS", UNKNOWN) + if user_params: # todo: remove this check in the future version, logic to be moved to container_index if ( @@ -554,6 +552,18 @@ def _create( deployment_params = get_combined_params(config_params, user_params) params = f"{params} {deployment_params}".strip() + + if create_deployment_details.model_name: + # Replace existing --served-model-name argument if present, otherwise add it + if "--served-model-name" in params: + params = re.sub( + r"--served-model-name\s+\S+", + f"--served-model-name {create_deployment_details.model_name}", + params, + ) + else: + params += f" --served-model-name {create_deployment_details.model_name}" + if params: env_var.update({"PARAMS": params}) env_vars = container_spec.env_vars if container_spec else [] diff --git a/ads/aqua/modeldeployment/entities.py b/ads/aqua/modeldeployment/entities.py index 4429b0472..ee030bc4a 100644 --- a/ads/aqua/modeldeployment/entities.py +++ b/ads/aqua/modeldeployment/entities.py @@ -233,6 +233,9 @@ class CreateModelDeploymentDetails(BaseModel): None, description="The description of the deployment." ) model_id: Optional[str] = Field(None, description="The model OCID to deploy.") + model_name: Optional[str] = Field( + None, description="The model name specified by user to deploy." + ) models: Optional[List[AquaMultiModelRef]] = Field( None, description="List of models for multimodel deployment." diff --git a/tests/unitary/with_extras/aqua/test_common_utils.py b/tests/unitary/with_extras/aqua/test_common_utils.py index e10e146ed..85cb57941 100644 --- a/tests/unitary/with_extras/aqua/test_common_utils.py +++ b/tests/unitary/with_extras/aqua/test_common_utils.py @@ -5,6 +5,7 @@ # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ import pytest + from ads.aqua.common.utils import get_preferred_compatible_family @@ -14,15 +15,15 @@ class TestCommonUtils: [ ( {"odsc-vllm-serving", "odsc-vllm-serving-v1"}, - "odsc-vllm-serving-v1", + "odsc-vllm-serving-openai", ), ( {"odsc-vllm-serving", "odsc-vllm-serving-llama4"}, - "odsc-vllm-serving-llama4", + "odsc-vllm-serving-openai", ), ( {"odsc-vllm-serving-v1", "odsc-vllm-serving-llama4"}, - "odsc-vllm-serving-llama4", + "odsc-vllm-serving-openai", ), ( { @@ -30,7 +31,7 @@ class TestCommonUtils: "odsc-vllm-serving-v1", "odsc-vllm-serving-llama4", }, - "odsc-vllm-serving-llama4", + "odsc-vllm-serving-openai", ), ({"odsc-tgi-serving", "odsc-vllm-serving"}, None), ({"non-existing-one", "odsc-tgi-serving"}, None),