77
88from ads .aqua .app import AquaApp , logger
99from ads .aqua .common .enums import (
10- InferenceContainerType ,
1110 InferenceContainerTypeFamily ,
1211 Tags ,
1312)
2221 get_params_dict ,
2322 get_params_list ,
2423 get_resource_name ,
24+ get_restricted_params_by_container ,
2525 load_config ,
2626)
2727from ads .aqua .constants import (
3434from ads .aqua .data import AquaResourceIdentifier
3535from ads .aqua .finetuning .finetuning import FineTuneCustomMetadata
3636from ads .aqua .model import AquaModelApp
37- from ads .aqua .modeldeployment .constants import (
38- TGIInferenceRestrictedParams ,
39- VLLMInferenceRestrictedParams ,
40- )
4137from ads .aqua .modeldeployment .entities import (
4238 AquaDeployment ,
4339 AquaDeploymentDetail ,
@@ -567,19 +563,27 @@ def get_deployment_default_params(
567563 f"{ AQUA_DEPLOYMENT_CONTAINER_METADATA_NAME } key is not available in the custom metadata field for model { model_id } ."
568564 )
569565
570- if container_type_key :
571- container_type_key = container_type_key .lower ()
572- if container_type_key in InferenceContainerTypeFamily .values ():
573- deployment_config = self .get_deployment_config (model_id )
574- params = (
575- deployment_config .get ("configuration" , UNKNOWN_DICT )
576- .get (instance_shape , UNKNOWN_DICT )
577- .get ("parameters" , UNKNOWN_DICT )
578- .get (get_container_params_type (container_type_key ))
566+ if (
567+ container_type_key
568+ and container_type_key in InferenceContainerTypeFamily .values ()
569+ ):
570+ deployment_config = self .get_deployment_config (model_id )
571+ config_params = (
572+ deployment_config .get ("configuration" , UNKNOWN_DICT )
573+ .get (instance_shape , UNKNOWN_DICT )
574+ .get ("parameters" , UNKNOWN_DICT )
575+ .get (get_container_params_type (container_type_key ), UNKNOWN )
576+ )
577+ if config_params :
578+ params_list = get_params_list (config_params )
579+ restricted_params_set = get_restricted_params_by_container (
580+ container_type_key
579581 )
580- if params :
581- # account for param that can have --arg but no values, e.g. --trust-remote-code
582- default_params .extend (get_params_list (params ))
582+
583+ # remove restricted params from the list as user cannot override them during deployment
584+ for params in params_list :
585+ if params .split ()[0 ] not in restricted_params_set :
586+ default_params .append (params )
583587
584588 return default_params
585589
@@ -651,8 +655,7 @@ def _find_restricted_params(
651655 container_family : str ,
652656 ) -> List [str ]:
653657 """Returns a list of restricted params that user chooses to override when creating an Aqua deployment.
654- The default parameters coming from the container index json file cannot be overridden. In addition to this,
655- a set of parameters maintained in
658+ The default parameters coming from the container index json file cannot be overridden.
656659
657660 Parameters
658661 ----------
@@ -673,18 +676,9 @@ def _find_restricted_params(
673676 default_params_dict = get_params_dict (default_params )
674677 user_params_dict = get_params_dict (user_params )
675678
679+ restricted_params_set = get_restricted_params_by_container (container_family )
676680 for key , _items in user_params_dict .items ():
677- if (
678- key in default_params_dict
679- or (
680- InferenceContainerType .CONTAINER_TYPE_VLLM in container_family
681- and key in VLLMInferenceRestrictedParams
682- )
683- or (
684- InferenceContainerType .CONTAINER_TYPE_TGI in container_family
685- and key in TGIInferenceRestrictedParams
686- )
687- ):
681+ if key in default_params_dict or key in restricted_params_set :
688682 restricted_params .append (key .lstrip ("-" ))
689683
690684 return restricted_params
0 commit comments