From bde6c1013356881181564886322c667bb23a28c3 Mon Sep 17 00:00:00 2001 From: Dmitrii Cherkasov Date: Tue, 9 Sep 2025 15:52:21 -0700 Subject: [PATCH 1/3] AQUA. Support env var overrides and enhance multi-model entity structure --- ads/aqua/common/entities.py | 126 +++++++++++-- ads/aqua/common/utils.py | 167 +++++++++++------- ads/aqua/modeldeployment/deployment.py | 14 +- .../modeldeployment/model_group_config.py | 6 +- .../with_extras/aqua/test_common_entities.py | 133 +++++++++++++- 5 files changed, 362 insertions(+), 84 deletions(-) diff --git a/ads/aqua/common/entities.py b/ads/aqua/common/entities.py index 5973dd035..ba280e10e 100644 --- a/ads/aqua/common/entities.py +++ b/ads/aqua/common/entities.py @@ -3,7 +3,7 @@ # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ import re -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Union from oci.data_science.models import Model from pydantic import BaseModel, ConfigDict, Field, model_validator @@ -245,55 +245,71 @@ class AquaMultiModelRef(Serializable): """ Lightweight model descriptor used for multi-model deployment. - This class only contains essential details - required to fetch complete model metadata and deploy models. + This class holds essential details required to fetch model metadata and deploy + individual models as part of a multi-model deployment group. Attributes ---------- model_id : str - The unique identifier of the model. + The unique identifier (OCID) of the base model. model_name : Optional[str] - The name of the model. + Optional name for the model. gpu_count : Optional[int] - Number of GPUs required for deployment. + Number of GPUs required to allocate for this model during deployment. model_task : Optional[str] - The task that model operates on. Supported tasks are in MultiModelSupportedTaskType + The machine learning task this model performs (e.g., text-generation, summarization). 
+ Supported values are listed in `MultiModelSupportedTaskType`. env_var : Optional[Dict[str, Any]] - Optional environment variables to override during deployment. + Optional dictionary of environment variables to inject into the runtime environment + of the model container. + params : Optional[Dict[str, Any]] + Optional dictionary of container-specific inference parameters to override. + These are typically framework-level flags required by the runtime backend. + For example, in vLLM containers, valid params may include: + `--tensor-parallel-size`, `--enforce-eager`, `--max-model-len`, etc. artifact_location : Optional[str] - Artifact path of model in the multimodel group. + Relative path or URI of the model artifact inside the multi-model group folder. fine_tune_weights : Optional[List[LoraModuleSpec]] - For fine tuned models, the artifact path of the modified model weights + List of fine-tuned weight artifacts (e.g., LoRA modules) associated with this model. """ model_id: str = Field(..., description="The model OCID to deploy.") - model_name: Optional[str] = Field(None, description="The name of model.") + model_name: Optional[str] = Field(None, description="The name of the model.") gpu_count: Optional[int] = Field( - None, description="The gpu count allocation for the model." + None, description="The number of GPUs allocated for the model." ) model_task: Optional[str] = Field( None, - description="The task that model operates on. Supported tasks are in MultiModelSupportedTaskType", + description="The task this model performs. See `MultiModelSupportedTaskType` for supported values.", ) env_var: Optional[dict] = Field( - default_factory=dict, description="The environment variables of the model." + default_factory=dict, + description="Environment variables to override during container startup.", + ) + params: Optional[dict] = Field( + default_factory=dict, + description=( + "Framework-specific startup parameters required by the container runtime. 
" + "For example, vLLM models may use flags like `--tensor-parallel-size`, `--enforce-eager`, etc." + ), ) artifact_location: Optional[str] = Field( - None, description="Artifact path of model in the multimodel group." + None, + description="Path to the model artifact relative to the multi-model base folder.", ) fine_tune_weights: Optional[List[LoraModuleSpec]] = Field( None, - description="For fine tuned models, the artifact path of the modified model weights", + description="List of fine-tuned weight modules (e.g., LoRA) associated with this base model.", ) def all_model_ids(self) -> List[str]: """ - Returns all associated model OCIDs, including the base model and any fine-tuned models. + Returns all model OCIDs associated with this reference, including fine-tuned weights. Returns ------- List[str] - A list of all model OCIDs associated with this multi-model reference. + A list containing the base model OCID and any fine-tuned module OCIDs. """ ids = {self.model_id} if self.fine_tune_weights: @@ -302,8 +318,80 @@ def all_model_ids(self) -> List[str]: ) return list(ids) + @model_validator(mode="before") + @classmethod + def extract_params_from_env_var(cls, values: Dict[str, Any]) -> Dict[str, Any]: + """ + A model-level validator that extracts `PARAMS` from the `env_var` dictionary + and injects them into the `params` field as a dictionary. + + This is useful for backward compatibility where users pass CLI-style + parameters via environment variables, e.g.: + env_var = { "PARAMS": "--max-model-len 65536 --enable-streaming" } + + If `params` is already set, values from `PARAMS` in `env_var` are added + only if they do not override existing keys. 
+ """ + env = values.get("env_var", {}) + param_string = env.pop("PARAMS", None) + + if param_string: + parsed_params = cls._parse_params(params=param_string) + existing_params = values.get("params", {}) or {} + # Avoid overriding existing keys + for k, v in parsed_params.items(): + if k not in existing_params: + existing_params[k] = v + values["params"] = existing_params + values["env_var"] = env # cleaned up version without PARAMS + + return values + + @staticmethod + def _parse_params(params: Union[str, List[str]]) -> Dict[str, str]: + """ + Parses CLI-style parameters into a dictionary format. + + This method accepts either: + - A single string of parameters (e.g., "--key1 val1 --key2 val2") + - A list of strings (e.g., ["--key1", "val1", "--key2", "val2"]) + + Returns a dictionary of the form { "key1": "val1", "key2": "val2" }. + + Parameters + ---------- + params : Union[str, List[str]] + The parameters to parse. Can be a single string or a list of strings. + + Returns + ------- + Dict[str, str] + Dictionary with parameter names as keys and their corresponding values as strings. + """ + if not params: + return {} + + # Normalize string to list of "--key value" strings + if isinstance(params, str): + params_list = [ + f"--{param.strip()}" for param in params.split("--") if param.strip() + ] + else: + params_list = params + + parsed = {} + for item in params_list: + parts = item.strip().split() + if not parts: + continue + key = parts[0] + value = " ".join(parts[1:]) if len(parts) > 1 else "" + parsed[key] = value + + return parsed + class Config: - extra = "ignore" + extra = "allow" protected_namespaces = () diff --git a/ads/aqua/common/utils.py b/ads/aqua/common/utils.py index 1d16ac07c..4bfbdd32d 100644 --- a/ads/aqua/common/utils.py +++ b/ads/aqua/common/utils.py @@ -800,35 +800,49 @@ def is_service_managed_container(container): def get_params_list(params: str) -> List[str]: - """Parses the string parameter and returns a list of params. 
+ """ + Parses a string of CLI-style double-dash parameters and returns them as a list. Parameters ---------- - params - string parameters by separated by -- delimiter + params : str + A single string containing parameters separated by the `--` delimiter. Returns ------- - list of params + List[str] + A list of parameter strings, each starting with `--`. + Example + ------- + >>> get_params_list("--max-model-len 65536 --enforce-eager") + ['--max-model-len 65536', '--enforce-eager'] """ if not params: return [] - return ["--" + param.strip() for param in params.split("--")[1:]] + return [f"--{param.strip()}" for param in params.split("--") if param.strip()] def get_params_dict(params: Union[str, List[str]]) -> dict: - """Accepts a string or list of string of double-dash parameters and returns a dict with the parameter keys and values. + """ + Converts CLI-style double-dash parameters (as string or list) into a dictionary. Parameters ---------- - params: - List of parameters or parameter string separated by space. + params : Union[str, List[str]] + Parameters provided either as: + - a single string: "--key1 val1 --key2 val2" + - a list of strings: ["--key1 val1", "--key2 val2"] Returns ------- - dict containing parameter keys and values + dict + A dictionary mapping parameter names to their values. If no value is found, uses `UNKNOWN`. + Example + ------- + >>> get_params_dict("--max-model-len 65536 --enforce-eager") + {'--max-model-len': '65536', '--enforce-eager': ''} """ params_list = get_params_list(params) if isinstance(params, str) else params return { @@ -839,35 +853,43 @@ def get_params_dict(params: Union[str, List[str]]) -> dict: } -def get_combined_params(params1: str = None, params2: str = None) -> str: +def get_combined_params( + params1: Optional[str] = None, params2: Optional[str] = None +) -> str: """ - Combines string of double-dash parameters, and overrides the values from the second string in the first. 
+ Merges two double-dash parameter strings (`--param value`) into one, with values from `params2` + overriding any duplicates from `params1`. + Parameters ---------- - params1: - Parameter string with values - params2: - Parameter string with values that need to be overridden. + params1 : Optional[str] + The base parameter string. Can be None. + params2 : Optional[str] + The override parameter string. Parameters in this string will override those in `params1`. Returns ------- - A combined list with overridden values from params2. + str + A combined parameter string with deduplicated keys and overridden values from `params2`. """ + if not params1 and not params2: + return "" + + # If only one string is provided, return it directly if not params1: - return params2 + return params2.strip() if not params2: - return params1 - - # overwrite values from params2 into params1 - combined_params = [ - f"{key} {value}" if value else key - for key, value in { - **get_params_dict(params1), - **get_params_dict(params2), - }.items() - ] + return params1.strip() + + # Combine both dictionaries, with params2 overriding params1 + merged_dict = {**get_params_dict(params1), **get_params_dict(params2)} + + # Reconstruct the string + combined = " ".join( + f"{key} {value}" if value else key for key, value in merged_dict.items() + ) - return " ".join(combined_params) + return combined.strip() def find_restricted_params( @@ -905,28 +927,46 @@ def find_restricted_params( return restricted_params -def build_params_string(params: dict) -> str: - """Builds params string from params dict +def build_params_string(params: Optional[Dict[str, Any]]) -> str: + """ + Converts a dictionary of CLI parameters into a command-line friendly string. + + This is typically used to transform framework-specific model parameters (e.g., vLLM or TGI flags) + into a space-separated string that can be passed to container startup commands. 
Parameters ---------- - params: - Parameter dict with key-value pairs + params : Optional[Dict[str, Any]] + Dictionary containing parameter names as keys (e.g., "--max-model-len") and their corresponding values. + If a parameter does not require a value (e.g., a boolean flag), its value can be None or an empty string. Returns ------- - A params string. + str + A space-separated string of CLI arguments. + Returns "" if the input dictionary is None or empty. + + Example + ------- + >>> build_params_string({"--max-model-len": 4096, "--enforce-eager": None}) + '--max-model-len 4096 --enforce-eager' """ - return ( - " ".join( - f"{name} {value}" if value else f"{name}" for name, value in params.items() - ).strip() - if params - else UNKNOWN - ) + if not params: + return UNKNOWN + + parts = [] + for key, value in params.items(): + if value is None or value == "": + parts.append(str(key)) + else: + parts.append(f"{key} {value}") + + return " ".join(parts).strip() -def copy_model_config(artifact_path: str, os_path: str, auth: dict = None): +def copy_model_config( + artifact_path: str, os_path: str, auth: Optional[Dict[str, Any]] = None +): """Copies the aqua model config folder from the artifact path to the user provided object storage path. The config folder is overwritten if the files already exist at the destination path. @@ -1202,36 +1242,45 @@ def parse_cmd_var(cmd_list: List[str]) -> dict: return parsed_cmd -def validate_cmd_var(cmd_var: List[str], overrides: List[str]) -> List[str]: - """This function accepts two lists of parameters and combines them. If the second list shares the common parameter - names/keys, then it raises an error. +def validate_cmd_var( + cmd_var: Optional[List[str]], overrides: Optional[List[str]] +) -> List[str]: + """ + Validates and combines two lists of command-line parameters. Raises an error if any parameter + key in the `overrides` list already exists in the `cmd_var` list, preventing unintended overrides. 
+ Parameters ---------- - cmd_var: List[str] - Default list of parameters - overrides: List[str] - List of parameters to override + cmd_var : Optional[List[str]] + The default list of command-line parameters (e.g., ["--param1", "value1", "--flag"]). + overrides : Optional[List[str]] + The list of overriding command-line parameters. + Returns ------- - List[str] of combined parameters + List[str] + A validated and combined list of parameters, with overrides appended. + + Raises + ------ + AquaValueError + If `overrides` contains any parameter keys that already exist in `cmd_var`. """ - cmd_var = [str(x) for x in cmd_var] - if not overrides: - return cmd_var - overrides = [str(x) for x in overrides] + cmd_var = [str(x).strip() for x in cmd_var or []] + overrides = [str(x).strip() for x in overrides or []] cmd_dict = parse_cmd_var(cmd_var) overrides_dict = parse_cmd_var(overrides) - # check for conflicts - common_keys = set(cmd_dict.keys()) & set(overrides_dict.keys()) - if common_keys: + # Check for conflicting keys + conflicting_keys = set(cmd_dict.keys()) & set(overrides_dict.keys()) + if conflicting_keys: raise AquaValueError( - f"The following CMD input cannot be overridden for model deployment: {', '.join(common_keys)}" + f"Cannot override the following model deployment parameters: {', '.join(sorted(conflicting_keys))}" ) - combined_cmd_var = cmd_var + overrides - return combined_cmd_var + combined_params = cmd_var + overrides + return combined_params def build_pydantic_error_message(ex: ValidationError): diff --git a/ads/aqua/modeldeployment/deployment.py b/ads/aqua/modeldeployment/deployment.py index 52d7613b6..2710534bb 100644 --- a/ads/aqua/modeldeployment/deployment.py +++ b/ads/aqua/modeldeployment/deployment.py @@ -520,10 +520,20 @@ def _create( deployment_config = self.get_deployment_config(model_id=config_source_id) + # Loads framework-specific default params from the configuration config_params = deployment_config.configuration.get( 
create_deployment_details.instance_shape, ConfigurationItem() ).parameters.get(get_container_params_type(container_type_key), UNKNOWN) + # Loads default environment variables from the configuration + config_env = deployment_config.configuration.get( + create_deployment_details.instance_shape, ConfigurationItem() + ).env.get(get_container_params_type(container_type_key), {}) + + # Merges user-provided environment variables with the ones provided in the deployment config + # The values provided by the user will override the ones provided by the default config + env_var = {**config_env, **env_var} + # validate user provided params user_params = env_var.get("PARAMS", UNKNOWN) @@ -643,8 +653,8 @@ def _create_multi( env_var.update({AQUA_MULTI_MODEL_CONFIG: multi_model_config.model_dump_json()}) - env_vars = container_spec.env_vars if container_spec else [] - for env in env_vars: + container_spec_env_vars = container_spec.env_vars if container_spec else [] + for env in container_spec_env_vars: if isinstance(env, dict): env = {k: v for k, v in env.items() if v} for key, _ in env.items(): diff --git a/ads/aqua/modeldeployment/model_group_config.py b/ads/aqua/modeldeployment/model_group_config.py index e452ec7f5..ec9acb364 100644 --- a/ads/aqua/modeldeployment/model_group_config.py +++ b/ads/aqua/modeldeployment/model_group_config.py @@ -130,7 +130,7 @@ def _extract_model_params( Validates if user-provided parameters override pre-set parameters by AQUA. Updates model name and TP size parameters to user-provided parameters. """ - user_params = build_params_string(model.env_var) + user_params = build_params_string(model.params) if user_params: restricted_params = find_restricted_params( container_params, user_params, container_type_key @@ -138,8 +138,8 @@ def _extract_model_params( if restricted_params: selected_model = model.model_name or model.model_id raise AquaValueError( - f"Parameters {restricted_params} are set by Aqua " - f"and cannot be overridden or are invalid." 
+ f"Parameters {restricted_params} are set by AI Quick Actions " + f"and cannot be overridden or are invalid. " f"Select other parameters for model {selected_model}." ) diff --git a/tests/unitary/with_extras/aqua/test_common_entities.py b/tests/unitary/with_extras/aqua/test_common_entities.py index 778c07ff1..0c2b293b4 100644 --- a/tests/unitary/with_extras/aqua/test_common_entities.py +++ b/tests/unitary/with_extras/aqua/test_common_entities.py @@ -4,9 +4,16 @@ # Copyright (c) 2025 Oracle and/or its affiliates. # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ +from typing import Dict, List, Union +from unittest.mock import MagicMock, patch + import pytest -from ads.aqua.common.entities import ComputeShapeSummary, ContainerPath +from ads.aqua.common.entities import ( + AquaMultiModelRef, + ComputeShapeSummary, + ContainerPath, +) class TestComputeShapeSummary: @@ -119,3 +126,127 @@ class TestContainerPath: ) def test_positive(self, image_path, expected_result): assert ContainerPath(full_path=image_path).model_dump() == expected_result + + +class TestAquaMultiModelRef: + @pytest.mark.parametrize( + "env_var, params, expected_params", + [ + ( + {"PARAMS": "--max-model-len 8192 --enforce-eager"}, + {}, + {"--max-model-len": "8192", "--enforce-eager": "UNKNOWN"}, + ), + ( + {"PARAMS": "--a 1 --b 2"}, + {"--a": "existing"}, + {"--a": "existing", "--b": "2"}, + ), + ( + {"PARAMS": "--x 1"}, + None, + {"--x": "1"}, + ), + ( + {}, # No PARAMS key + {"--existing": "value"}, + {"--existing": "value"}, + ), + ], + ) + @patch.object(AquaMultiModelRef, "_parse_params") + def test_extract_params_from_env_var( + self, mock_parse_params, env_var, params, expected_params + ): + mock_parse_params.return_value = {k: v for k, v in expected_params.items()} + + values = { + "model_id": "ocid1.model.oc1..xxxx", + "env_var": dict(env_var), # copy + "params": params, + } + + result = AquaMultiModelRef.model_validate(values) + assert 
result.params == expected_params + assert "PARAMS" not in result.env_var + + @patch.object(AquaMultiModelRef, "_parse_params") + def test_extract_params_from_env_var_skips_override(self, mock_parse_params): + input_params = {"--max-model-len": "65536"} + env_var = {"PARAMS": "--max-model-len 8000 --new-flag yes"} + + mock_parse_params.return_value = { + "--max-model-len": "8000", + "--new-flag": "yes", + } + + values = { + "model_id": "ocid1.model.oc1..abcd", + "params": dict(input_params), + "env_var": dict(env_var), + } + + result = AquaMultiModelRef.model_validate(values) + assert result.params["--max-model-len"] == "65536" # original + assert result.params["--new-flag"] == "yes" + + def test_extract_params_from_env_var_missing_env(self): + values = { + "model_id": "ocid1.model.oc1..abcd", + } + result = AquaMultiModelRef.model_validate(values) + assert result.env_var == {} + assert result.params == {} + + def test_all_model_ids_no_finetunes(self): + model = AquaMultiModelRef(model_id="ocid1.model.oc1..base") + assert model.all_model_ids() == ["ocid1.model.oc1..base"] + + @patch.object(AquaMultiModelRef, "_parse_params") + def test_model_validator_with_other_fields(self, mock_parse_params): + values = { + "model_id": "ocid1.model.oc1..xyz", + "gpu_count": 2, + "artifact_location": "some/path", + "env_var": {"PARAMS": "--x abc"}, + } + + mock_parse_params.return_value = {"--x": "abc"} + + result = AquaMultiModelRef.model_validate(values) + + assert result.model_id == "ocid1.model.oc1..xyz" + assert result.gpu_count == 2 + assert result.artifact_location == "some/path" + assert result.params == {"--x": "abc"} + + @pytest.mark.parametrize( + "input_param,expected_dict", + [ + ( + "--max-model-len 65536 --enable-streaming", + {"--max-model-len": "65536", "--enable-streaming": ""}, + ), + ( + ["--max-model-len 4096", "--foo bar"], + {"--max-model-len": "4096", "--foo": "bar"}, + ), + ( + "", + {}, + ), + ( + None, + {}, + ), + ( + "--key1 value1 --key2 value with 
spaces", + {"--key1": "value1", "--key2": "value with spaces"}, + ), + ], + ) + def test_parse_params( + self, input_param: Union[str, List[str]], expected_dict: Dict[str, str] + ): + result = AquaMultiModelRef._parse_params(input_param) + assert result == expected_dict From bc939f9840ff3123f58a60d42fe3af35c5441f72 Mon Sep 17 00:00:00 2001 From: Dmitrii Cherkasov Date: Tue, 9 Sep 2025 16:40:54 -0700 Subject: [PATCH 2/3] Fixes unit tests --- tests/unitary/with_extras/aqua/test_deployment.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tests/unitary/with_extras/aqua/test_deployment.py b/tests/unitary/with_extras/aqua/test_deployment.py index c7ac40a71..2254bb6a5 100644 --- a/tests/unitary/with_extras/aqua/test_deployment.py +++ b/tests/unitary/with_extras/aqua/test_deployment.py @@ -486,6 +486,7 @@ class TestDataset: "models": [ { "env_var": {}, + "params": {}, "gpu_count": 2, "model_id": "test_model_id_1", "model_name": "test_model_1", @@ -495,6 +496,7 @@ class TestDataset: }, { "env_var": {}, + "params": {}, "gpu_count": 2, "model_id": "test_model_id_2", "model_name": "test_model_2", @@ -504,6 +506,7 @@ class TestDataset: }, { "env_var": {}, + "params": {}, "gpu_count": 2, "model_id": "test_model_id_3", "model_name": "test_model_3", @@ -985,6 +988,7 @@ class TestDataset: multi_model_deployment_model_attributes = [ { "env_var": {"--test_key_one": "test_value_one"}, + "params": {}, "gpu_count": 1, "model_id": "ocid1.compartment.oc1..", "model_name": "model_one", @@ -994,6 +998,7 @@ class TestDataset: }, { "env_var": {"--test_key_two": "test_value_two"}, + "params": {}, "gpu_count": 1, "model_id": "ocid1.compartment.oc1..", "model_name": "model_two", @@ -1003,6 +1008,7 @@ class TestDataset: }, { "env_var": {"--test_key_three": "test_value_three"}, + "params": {}, "gpu_count": 1, "model_id": "ocid1.compartment.oc1..", "model_name": "model_three", From bee73228991e1b0e8c0ad05846e6a9c706b0d060 Mon Sep 17 00:00:00 2001 From: Dmitrii Cherkasov Date: Tue, 9 Sep 
2025 16:45:53 -0700 Subject: [PATCH 3/3] Fixes by comments --- ads/aqua/common/entities.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ads/aqua/common/entities.py b/ads/aqua/common/entities.py index ba280e10e..f537b32f3 100644 --- a/ads/aqua/common/entities.py +++ b/ads/aqua/common/entities.py @@ -368,7 +368,7 @@ def _parse_params(params: Union[str, List[str]]) -> Dict[str, str]: Dict[str, str] Dictionary with parameter names as keys and their corresponding values as strings. """ - if not params: + if not params or not isinstance(params, (str, list)): return {} # Normalize string to list of "--key value" strings