4646from ads .config import (
4747 AQUA_CONFIG_FOLDER ,
4848 AQUA_DEPLOYMENT_CONTAINER_METADATA_NAME ,
49- AQUA_DEPLOYMENT_CONTAINER_OVERRIDE_FLAG_METADATA_NAME ,
5049 AQUA_MODEL_DEPLOYMENT_CONFIG ,
5150 AQUA_MODEL_DEPLOYMENT_CONFIG_DEFAULTS ,
5251 COMPARTMENT_OCID ,
@@ -87,27 +86,27 @@ class AquaDeploymentApp(AquaApp):
8786
8887 @telemetry (entry_point = "plugin=deployment&action=create" , name = "aqua" )
8988 def create (
90- self ,
91- model_id : str ,
92- instance_shape : str ,
93- display_name : str ,
94- instance_count : int = None ,
95- log_group_id : str = None ,
96- access_log_id : str = None ,
97- predict_log_id : str = None ,
98- compartment_id : str = None ,
99- project_id : str = None ,
100- description : str = None ,
101- bandwidth_mbps : int = None ,
102- web_concurrency : int = None ,
103- server_port : int = None ,
104- health_check_port : int = None ,
105- env_var : Dict = None ,
106- container_family : str = None ,
107- memory_in_gbs : Optional [float ] = None ,
108- ocpus : Optional [float ] = None ,
109- model_file : Optional [str ] = None ,
110- private_endpoint_id : Optional [str ] = None ,
89+ self ,
90+ model_id : str ,
91+ instance_shape : str ,
92+ display_name : str ,
93+ instance_count : int = None ,
94+ log_group_id : str = None ,
95+ access_log_id : str = None ,
96+ predict_log_id : str = None ,
97+ compartment_id : str = None ,
98+ project_id : str = None ,
99+ description : str = None ,
100+ bandwidth_mbps : int = None ,
101+ web_concurrency : int = None ,
102+ server_port : int = None ,
103+ health_check_port : int = None ,
104+ env_var : Dict = None ,
105+ container_family : str = None ,
106+ memory_in_gbs : Optional [float ] = None ,
107+ ocpus : Optional [float ] = None ,
108+ model_file : Optional [str ] = None ,
109+ private_endpoint_id : Optional [str ] = None ,
111110 ) -> "AquaDeployment" :
112111 """
113112 Creates a new Aqua deployment
@@ -179,6 +178,7 @@ def create(
179178 tags [tag ] = aqua_model .freeform_tags [tag ]
180179
181180 tags .update ({Tags .AQUA_MODEL_NAME_TAG : aqua_model .display_name })
181+ tags .update ({Tags .TASK : aqua_model .freeform_tags .get (Tags .TASK , None )})
182182
183183 # Set up info to get deployment config
184184 config_source_id = model_id
@@ -235,8 +235,7 @@ def create(
235235 env_var .update ({"FT_MODEL" : f"{ fine_tune_output_path } " })
236236
237237 container_type_key = self ._get_container_type_key (
238- model = aqua_model ,
239- container_family = container_family
238+ model = aqua_model , container_family = container_family
240239 )
241240
242241 # fetch image name from config
@@ -252,7 +251,11 @@ def create(
252251 model_format = model_formats_str .split ("," )
253252
254253 # Figure out a better way to handle this in future release
255- if ModelFormat .GGUF .value in model_format and container_type_key .lower () == InferenceContainerTypeFamily .AQUA_LLAMA_CPP_CONTAINER_FAMILY :
254+ if (
255+ ModelFormat .GGUF .value in model_format
256+ and container_type_key .lower ()
257+ == InferenceContainerTypeFamily .AQUA_LLAMA_CPP_CONTAINER_FAMILY
258+ ):
256259 if model_file is not None :
257260 logger .info (
258261 f"Overriding { model_file } as model_file for model { aqua_model .id } ."
@@ -303,8 +306,8 @@ def create(
303306 if user_params :
304307 # todo: remove this check in the future version, logic to be moved to container_index
305308 if (
306- container_type_key .lower ()
307- == InferenceContainerTypeFamily .AQUA_LLAMA_CPP_CONTAINER_FAMILY
309+ container_type_key .lower ()
310+ == InferenceContainerTypeFamily .AQUA_LLAMA_CPP_CONTAINER_FAMILY
308311 ):
309312 # AQUA_LLAMA_CPP_CONTAINER_FAMILY container uses uvicorn that required model/server params
310313 # to be set as env vars
@@ -427,9 +430,8 @@ def _get_container_type_key(model: DataScienceModel, container_family: str) -> s
427430 f"for model { model .id } . For unverified Aqua models, { AQUA_DEPLOYMENT_CONTAINER_METADATA_NAME } should be"
428431 f"set and value can be one of { ', ' .join (InferenceContainerTypeFamily .values ())} ."
429432 ) from err
430-
433+
431434 return container_type_key
432-
433435
434436 @telemetry (entry_point = "plugin=deployment&action=list" , name = "aqua" )
435437 def list (self , ** kwargs ) -> List ["AquaDeployment" ]:
@@ -458,8 +460,8 @@ def list(self, **kwargs) -> List["AquaDeployment"]:
458460 for model_deployment in model_deployments :
459461 oci_aqua = (
460462 (
461- Tags .AQUA_TAG in model_deployment .freeform_tags
462- or Tags .AQUA_TAG .lower () in model_deployment .freeform_tags
463+ Tags .AQUA_TAG in model_deployment .freeform_tags
464+ or Tags .AQUA_TAG .lower () in model_deployment .freeform_tags
463465 )
464466 if model_deployment .freeform_tags
465467 else False
@@ -513,8 +515,8 @@ def get(self, model_deployment_id: str, **kwargs) -> "AquaDeploymentDetail":
513515
514516 oci_aqua = (
515517 (
516- Tags .AQUA_TAG in model_deployment .freeform_tags
517- or Tags .AQUA_TAG .lower () in model_deployment .freeform_tags
518+ Tags .AQUA_TAG in model_deployment .freeform_tags
519+ or Tags .AQUA_TAG .lower () in model_deployment .freeform_tags
518520 )
519521 if model_deployment .freeform_tags
520522 else False
@@ -531,8 +533,8 @@ def get(self, model_deployment_id: str, **kwargs) -> "AquaDeploymentDetail":
531533 log_group_name = ""
532534
533535 logs = (
534- model_deployment .category_log_details .access
535- or model_deployment .category_log_details .predict
536+ model_deployment .category_log_details .access
537+ or model_deployment .category_log_details .predict
536538 )
537539 if logs :
538540 log_id = logs .log_id
@@ -587,9 +589,9 @@ def get_deployment_config(self, model_id: str) -> Dict:
587589 return config
588590
589591 def get_deployment_default_params (
590- self ,
591- model_id : str ,
592- instance_shape : str ,
592+ self ,
593+ model_id : str ,
594+ instance_shape : str ,
593595 ) -> List [str ]:
594596 """Gets the default params set in the deployment configs for the given model and instance shape.
595597
@@ -621,8 +623,8 @@ def get_deployment_default_params(
621623 )
622624
623625 if (
624- container_type_key
625- and container_type_key in InferenceContainerTypeFamily .values ()
626+ container_type_key
627+ and container_type_key in InferenceContainerTypeFamily .values ()
626628 ):
627629 deployment_config = self .get_deployment_config (model_id )
628630 config_params = (
@@ -645,10 +647,10 @@ def get_deployment_default_params(
645647 return default_params
646648
647649 def validate_deployment_params (
648- self ,
649- model_id : str ,
650- params : List [str ] = None ,
651- container_family : str = None ,
650+ self ,
651+ model_id : str ,
652+ params : List [str ] = None ,
653+ container_family : str = None ,
652654 ) -> Dict :
653655 """Validate if the deployment parameters passed by the user can be overridden. Parameter values are not
654656 validated, only param keys are validated.
@@ -671,8 +673,7 @@ def validate_deployment_params(
671673 if params :
672674 model = DataScienceModel .from_id (model_id )
673675 container_type_key = self ._get_container_type_key (
674- model = model ,
675- container_family = container_family
676+ model = model , container_family = container_family
676677 )
677678
678679 container_config = get_container_config ()
@@ -694,9 +695,9 @@ def validate_deployment_params(
694695
695696 @staticmethod
696697 def _find_restricted_params (
697- default_params : Union [str , List [str ]],
698- user_params : Union [str , List [str ]],
699- container_family : str ,
698+ default_params : Union [str , List [str ]],
699+ user_params : Union [str , List [str ]],
700+ container_family : str ,
700701 ) -> List [str ]:
701702 """Returns a list of restricted params that user chooses to override when creating an Aqua deployment.
702703 The default parameters coming from the container index json file cannot be overridden.
0 commit comments