4646from ads .config import (
4747 AQUA_CONFIG_FOLDER ,
4848 AQUA_DEPLOYMENT_CONTAINER_METADATA_NAME ,
49- AQUA_DEPLOYMENT_CONTAINER_OVERRIDE_FLAG_METADATA_NAME ,
5049 AQUA_MODEL_DEPLOYMENT_CONFIG ,
5150 AQUA_MODEL_DEPLOYMENT_CONFIG_DEFAULTS ,
5251 COMPARTMENT_OCID ,
@@ -87,26 +86,27 @@ class AquaDeploymentApp(AquaApp):
8786
8887 @telemetry (entry_point = "plugin=deployment&action=create" , name = "aqua" )
8988 def create (
90- self ,
91- model_id : str ,
92- instance_shape : str ,
93- display_name : str ,
94- instance_count : int = None ,
95- log_group_id : str = None ,
96- access_log_id : str = None ,
97- predict_log_id : str = None ,
98- compartment_id : str = None ,
99- project_id : str = None ,
100- description : str = None ,
101- bandwidth_mbps : int = None ,
102- web_concurrency : int = None ,
103- server_port : int = None ,
104- health_check_port : int = None ,
105- env_var : Dict = None ,
106- container_family : str = None ,
107- memory_in_gbs : Optional [float ] = None ,
108- ocpus : Optional [float ] = None ,
109- model_file : Optional [str ] = None ,
89+ self ,
90+ model_id : str ,
91+ instance_shape : str ,
92+ display_name : str ,
93+ instance_count : int = None ,
94+ log_group_id : str = None ,
95+ access_log_id : str = None ,
96+ predict_log_id : str = None ,
97+ compartment_id : str = None ,
98+ project_id : str = None ,
99+ description : str = None ,
100+ bandwidth_mbps : int = None ,
101+ web_concurrency : int = None ,
102+ server_port : int = None ,
103+ health_check_port : int = None ,
104+ env_var : Dict = None ,
105+ container_family : str = None ,
106+ memory_in_gbs : Optional [float ] = None ,
107+ ocpus : Optional [float ] = None ,
108+ model_file : Optional [str ] = None ,
109+ cmd_var : List [str ] = None ,
110110 ) -> "AquaDeployment" :
111111 """
112112 Creates a new Aqua deployment
@@ -153,6 +153,8 @@ def create(
153153 The ocpu count for the shape selected.
154154 model_file: str
155155 The file used for model deployment.
156+ cmd_var: List[str]
157+ The cmd of model deployment container runtime.
156158 Returns
157159 -------
158160 AquaDeployment
@@ -231,8 +233,7 @@ def create(
231233 env_var .update ({"FT_MODEL" : f"{ fine_tune_output_path } " })
232234
233235 container_type_key = self ._get_container_type_key (
234- model = aqua_model ,
235- container_family = container_family
236+ model = aqua_model , container_family = container_family
236237 )
237238
238239 # fetch image name from config
@@ -248,7 +249,11 @@ def create(
248249 model_format = model_formats_str .split ("," )
249250
250251 # Figure out a better way to handle this in future release
251- if ModelFormat .GGUF .value in model_format and container_type_key .lower () == InferenceContainerTypeFamily .AQUA_LLAMA_CPP_CONTAINER_FAMILY :
252+ if (
253+ ModelFormat .GGUF .value in model_format
254+ and container_type_key .lower ()
255+ == InferenceContainerTypeFamily .AQUA_LLAMA_CPP_CONTAINER_FAMILY
256+ ):
252257 if model_file is not None :
253258 logger .info (
254259 f"Overriding { model_file } as model_file for model { aqua_model .id } ."
@@ -299,8 +304,8 @@ def create(
299304 if user_params :
300305 # todo: remove this check in the future version, logic to be moved to container_index
301306 if (
302- container_type_key .lower ()
303- == InferenceContainerTypeFamily .AQUA_LLAMA_CPP_CONTAINER_FAMILY
307+ container_type_key .lower ()
308+ == InferenceContainerTypeFamily .AQUA_LLAMA_CPP_CONTAINER_FAMILY
304309 ):
305310 # AQUA_LLAMA_CPP_CONTAINER_FAMILY container uses uvicorn, which requires model/server params
306311 # to be set as env vars
@@ -369,6 +374,8 @@ def create(
369374 .with_overwrite_existing_artifact (True )
370375 .with_remove_existing_artifact (True )
371376 )
377+ if cmd_var :
378+ container_runtime .with_cmd (cmd_var )
372379
373380 # configure model deployment and deploy model on container runtime
374381 deployment = (
@@ -422,9 +429,8 @@ def _get_container_type_key(model: DataScienceModel, container_family: str) -> s
422429 f"for model { model .id } . For unverified Aqua models, { AQUA_DEPLOYMENT_CONTAINER_METADATA_NAME } should be"
423430 f"set and value can be one of { ', ' .join (InferenceContainerTypeFamily .values ())} ."
424431 ) from err
425-
432+
426433 return container_type_key
427-
428434
429435 @telemetry (entry_point = "plugin=deployment&action=list" , name = "aqua" )
430436 def list (self , ** kwargs ) -> List ["AquaDeployment" ]:
@@ -453,8 +459,8 @@ def list(self, **kwargs) -> List["AquaDeployment"]:
453459 for model_deployment in model_deployments :
454460 oci_aqua = (
455461 (
456- Tags .AQUA_TAG in model_deployment .freeform_tags
457- or Tags .AQUA_TAG .lower () in model_deployment .freeform_tags
462+ Tags .AQUA_TAG in model_deployment .freeform_tags
463+ or Tags .AQUA_TAG .lower () in model_deployment .freeform_tags
458464 )
459465 if model_deployment .freeform_tags
460466 else False
@@ -508,8 +514,8 @@ def get(self, model_deployment_id: str, **kwargs) -> "AquaDeploymentDetail":
508514
509515 oci_aqua = (
510516 (
511- Tags .AQUA_TAG in model_deployment .freeform_tags
512- or Tags .AQUA_TAG .lower () in model_deployment .freeform_tags
517+ Tags .AQUA_TAG in model_deployment .freeform_tags
518+ or Tags .AQUA_TAG .lower () in model_deployment .freeform_tags
513519 )
514520 if model_deployment .freeform_tags
515521 else False
@@ -526,8 +532,8 @@ def get(self, model_deployment_id: str, **kwargs) -> "AquaDeploymentDetail":
526532 log_group_name = ""
527533
528534 logs = (
529- model_deployment .category_log_details .access
530- or model_deployment .category_log_details .predict
535+ model_deployment .category_log_details .access
536+ or model_deployment .category_log_details .predict
531537 )
532538 if logs :
533539 log_id = logs .log_id
@@ -582,9 +588,9 @@ def get_deployment_config(self, model_id: str) -> Dict:
582588 return config
583589
584590 def get_deployment_default_params (
585- self ,
586- model_id : str ,
587- instance_shape : str ,
591+ self ,
592+ model_id : str ,
593+ instance_shape : str ,
588594 ) -> List [str ]:
589595 """Gets the default params set in the deployment configs for the given model and instance shape.
590596
@@ -616,8 +622,8 @@ def get_deployment_default_params(
616622 )
617623
618624 if (
619- container_type_key
620- and container_type_key in InferenceContainerTypeFamily .values ()
625+ container_type_key
626+ and container_type_key in InferenceContainerTypeFamily .values ()
621627 ):
622628 deployment_config = self .get_deployment_config (model_id )
623629 config_params = (
@@ -640,10 +646,10 @@ def get_deployment_default_params(
640646 return default_params
641647
642648 def validate_deployment_params (
643- self ,
644- model_id : str ,
645- params : List [str ] = None ,
646- container_family : str = None ,
649+ self ,
650+ model_id : str ,
651+ params : List [str ] = None ,
652+ container_family : str = None ,
647653 ) -> Dict :
648654 """Validate if the deployment parameters passed by the user can be overridden. Parameter values are not
649655 validated, only param keys are validated.
@@ -666,8 +672,7 @@ def validate_deployment_params(
666672 if params :
667673 model = DataScienceModel .from_id (model_id )
668674 container_type_key = self ._get_container_type_key (
669- model = model ,
670- container_family = container_family
675+ model = model , container_family = container_family
671676 )
672677
673678 container_config = get_container_config ()
@@ -689,9 +694,9 @@ def validate_deployment_params(
689694
690695 @staticmethod
691696 def _find_restricted_params (
692- default_params : Union [str , List [str ]],
693- user_params : Union [str , List [str ]],
694- container_family : str ,
697+ default_params : Union [str , List [str ]],
698+ user_params : Union [str , List [str ]],
699+ container_family : str ,
695700 ) -> List [str ]:
696701 """Returns a list of restricted params that user chooses to override when creating an Aqua deployment.
697702 The default parameters coming from the container index json file cannot be overridden.
0 commit comments