diff --git a/ads/aqua/modeldeployment/deployment.py b/ads/aqua/modeldeployment/deployment.py index 654e00dc8..6a534ac4a 100644 --- a/ads/aqua/modeldeployment/deployment.py +++ b/ads/aqua/modeldeployment/deployment.py @@ -46,7 +46,6 @@ from ads.config import ( AQUA_CONFIG_FOLDER, AQUA_DEPLOYMENT_CONTAINER_METADATA_NAME, - AQUA_DEPLOYMENT_CONTAINER_OVERRIDE_FLAG_METADATA_NAME, AQUA_MODEL_DEPLOYMENT_CONFIG, AQUA_MODEL_DEPLOYMENT_CONFIG_DEFAULTS, COMPARTMENT_OCID, @@ -87,26 +86,26 @@ class AquaDeploymentApp(AquaApp): @telemetry(entry_point="plugin=deployment&action=create", name="aqua") def create( - self, - model_id: str, - instance_shape: str, - display_name: str, - instance_count: int = None, - log_group_id: str = None, - access_log_id: str = None, - predict_log_id: str = None, - compartment_id: str = None, - project_id: str = None, - description: str = None, - bandwidth_mbps: int = None, - web_concurrency: int = None, - server_port: int = None, - health_check_port: int = None, - env_var: Dict = None, - container_family: str = None, - memory_in_gbs: Optional[float] = None, - ocpus: Optional[float] = None, - model_file: Optional[str] = None, + self, + model_id: str, + instance_shape: str, + display_name: str, + instance_count: int = None, + log_group_id: str = None, + access_log_id: str = None, + predict_log_id: str = None, + compartment_id: str = None, + project_id: str = None, + description: str = None, + bandwidth_mbps: int = None, + web_concurrency: int = None, + server_port: int = None, + health_check_port: int = None, + env_var: Dict = None, + container_family: str = None, + memory_in_gbs: Optional[float] = None, + ocpus: Optional[float] = None, + model_file: Optional[str] = None, ) -> "AquaDeployment": """ Creates a new Aqua deployment @@ -175,6 +174,7 @@ def create( tags[tag] = aqua_model.freeform_tags[tag] tags.update({Tags.AQUA_MODEL_NAME_TAG: aqua_model.display_name}) + tags.update({Tags.TASK: aqua_model.freeform_tags.get(Tags.TASK, None)}) # Set up info to get deployment config config_source_id = model_id @@ -231,8 +231,7 @@ def create( env_var.update({"FT_MODEL": f"{fine_tune_output_path}"}) container_type_key = self._get_container_type_key( - model=aqua_model, - container_family=container_family + model=aqua_model, container_family=container_family ) # fetch image name from config @@ -248,7 +247,11 @@ def create( model_format = model_formats_str.split(",") # Figure out a better way to handle this in future release - if ModelFormat.GGUF.value in model_format and container_type_key.lower() == InferenceContainerTypeFamily.AQUA_LLAMA_CPP_CONTAINER_FAMILY: + if ( + ModelFormat.GGUF.value in model_format + and container_type_key.lower() + == InferenceContainerTypeFamily.AQUA_LLAMA_CPP_CONTAINER_FAMILY + ): if model_file is not None: logger.info( f"Overriding {model_file} as model_file for model {aqua_model.id}." @@ -299,8 +302,8 @@ def create( if user_params: # todo: remove this check in the future version, logic to be moved to container_index if ( - container_type_key.lower() - == InferenceContainerTypeFamily.AQUA_LLAMA_CPP_CONTAINER_FAMILY + container_type_key.lower() + == InferenceContainerTypeFamily.AQUA_LLAMA_CPP_CONTAINER_FAMILY ): # AQUA_LLAMA_CPP_CONTAINER_FAMILY container uses uvicorn that required model/server params # to be set as env vars @@ -422,9 +425,8 @@ def _get_container_type_key(model: DataScienceModel, container_family: str) -> s f"for model {model.id}. For unverified Aqua models, {AQUA_DEPLOYMENT_CONTAINER_METADATA_NAME} should be" f"set and value can be one of {', '.join(InferenceContainerTypeFamily.values())}." ) from err - + return container_type_key - @telemetry(entry_point="plugin=deployment&action=list", name="aqua") def list(self, **kwargs) -> List["AquaDeployment"]: @@ -453,8 +455,8 @@ def list(self, **kwargs) -> List["AquaDeployment"]: for model_deployment in model_deployments: oci_aqua = ( ( - Tags.AQUA_TAG in model_deployment.freeform_tags - or Tags.AQUA_TAG.lower() in model_deployment.freeform_tags + Tags.AQUA_TAG in model_deployment.freeform_tags + or Tags.AQUA_TAG.lower() in model_deployment.freeform_tags ) if model_deployment.freeform_tags else False @@ -508,8 +510,8 @@ def get(self, model_deployment_id: str, **kwargs) -> "AquaDeploymentDetail": oci_aqua = ( ( - Tags.AQUA_TAG in model_deployment.freeform_tags - or Tags.AQUA_TAG.lower() in model_deployment.freeform_tags + Tags.AQUA_TAG in model_deployment.freeform_tags + or Tags.AQUA_TAG.lower() in model_deployment.freeform_tags ) if model_deployment.freeform_tags else False @@ -526,8 +528,8 @@ def get(self, model_deployment_id: str, **kwargs) -> "AquaDeploymentDetail": log_group_name = "" logs = ( - model_deployment.category_log_details.access - or model_deployment.category_log_details.predict + model_deployment.category_log_details.access + or model_deployment.category_log_details.predict ) if logs: log_id = logs.log_id @@ -582,9 +584,9 @@ def get_deployment_config(self, model_id: str) -> Dict: return config def get_deployment_default_params( - self, - model_id: str, - instance_shape: str, + self, + model_id: str, + instance_shape: str, ) -> List[str]: """Gets the default params set in the deployment configs for the given model and instance shape. @@ -616,8 +618,8 @@ def get_deployment_default_params( ) if ( - container_type_key - and container_type_key in InferenceContainerTypeFamily.values() + container_type_key + and container_type_key in InferenceContainerTypeFamily.values() ): deployment_config = self.get_deployment_config(model_id) config_params = ( @@ -640,10 +642,10 @@ def get_deployment_default_params( return default_params def validate_deployment_params( - self, - model_id: str, - params: List[str] = None, - container_family: str = None, + self, + model_id: str, + params: List[str] = None, + container_family: str = None, ) -> Dict: """Validate if the deployment parameters passed by the user can be overridden. Parameter values are not validated, only param keys are validated. @@ -666,8 +668,7 @@ def validate_deployment_params( if params: model = DataScienceModel.from_id(model_id) container_type_key = self._get_container_type_key( - model=model, - container_family=container_family + model=model, container_family=container_family ) container_config = get_container_config() @@ -689,9 +690,9 @@ def validate_deployment_params( @staticmethod def _find_restricted_params( - default_params: Union[str, List[str]], - user_params: Union[str, List[str]], - container_family: str, + default_params: Union[str, List[str]], + user_params: Union[str, List[str]], + container_family: str, ) -> List[str]: """Returns a list of restricted params that user chooses to override when creating an Aqua deployment. The default parameters coming from the container index json file cannot be overridden.