Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 49 additions & 48 deletions ads/aqua/modeldeployment/deployment.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@
from ads.config import (
AQUA_CONFIG_FOLDER,
AQUA_DEPLOYMENT_CONTAINER_METADATA_NAME,
AQUA_DEPLOYMENT_CONTAINER_OVERRIDE_FLAG_METADATA_NAME,
AQUA_MODEL_DEPLOYMENT_CONFIG,
AQUA_MODEL_DEPLOYMENT_CONFIG_DEFAULTS,
COMPARTMENT_OCID,
Expand Down Expand Up @@ -87,26 +86,26 @@ class AquaDeploymentApp(AquaApp):

@telemetry(entry_point="plugin=deployment&action=create", name="aqua")
def create(
self,
model_id: str,
instance_shape: str,
display_name: str,
instance_count: int = None,
log_group_id: str = None,
access_log_id: str = None,
predict_log_id: str = None,
compartment_id: str = None,
project_id: str = None,
description: str = None,
bandwidth_mbps: int = None,
web_concurrency: int = None,
server_port: int = None,
health_check_port: int = None,
env_var: Dict = None,
container_family: str = None,
memory_in_gbs: Optional[float] = None,
ocpus: Optional[float] = None,
model_file: Optional[str] = None,
self,
model_id: str,
instance_shape: str,
display_name: str,
instance_count: int = None,
log_group_id: str = None,
access_log_id: str = None,
predict_log_id: str = None,
compartment_id: str = None,
project_id: str = None,
description: str = None,
bandwidth_mbps: int = None,
web_concurrency: int = None,
server_port: int = None,
health_check_port: int = None,
env_var: Dict = None,
container_family: str = None,
memory_in_gbs: Optional[float] = None,
ocpus: Optional[float] = None,
model_file: Optional[str] = None,
) -> "AquaDeployment":
"""
Creates a new Aqua deployment
Expand Down Expand Up @@ -175,6 +174,7 @@ def create(
tags[tag] = aqua_model.freeform_tags[tag]

tags.update({Tags.AQUA_MODEL_NAME_TAG: aqua_model.display_name})
tags.update({Tags.TASK: aqua_model.freeform_tags.get(Tags.TASK, None)})

# Set up info to get deployment config
config_source_id = model_id
Expand Down Expand Up @@ -231,8 +231,7 @@ def create(
env_var.update({"FT_MODEL": f"{fine_tune_output_path}"})

container_type_key = self._get_container_type_key(
model=aqua_model,
container_family=container_family
model=aqua_model, container_family=container_family
)

# fetch image name from config
Expand All @@ -248,7 +247,11 @@ def create(
model_format = model_formats_str.split(",")

# Figure out a better way to handle this in future release
if ModelFormat.GGUF.value in model_format and container_type_key.lower() == InferenceContainerTypeFamily.AQUA_LLAMA_CPP_CONTAINER_FAMILY:
if (
ModelFormat.GGUF.value in model_format
and container_type_key.lower()
== InferenceContainerTypeFamily.AQUA_LLAMA_CPP_CONTAINER_FAMILY
):
if model_file is not None:
logger.info(
f"Overriding {model_file} as model_file for model {aqua_model.id}."
Expand Down Expand Up @@ -299,8 +302,8 @@ def create(
if user_params:
# todo: remove this check in the future version, logic to be moved to container_index
if (
container_type_key.lower()
== InferenceContainerTypeFamily.AQUA_LLAMA_CPP_CONTAINER_FAMILY
container_type_key.lower()
== InferenceContainerTypeFamily.AQUA_LLAMA_CPP_CONTAINER_FAMILY
):
# AQUA_LLAMA_CPP_CONTAINER_FAMILY container uses uvicorn that required model/server params
# to be set as env vars
Expand Down Expand Up @@ -422,9 +425,8 @@ def _get_container_type_key(model: DataScienceModel, container_family: str) -> s
f"for model {model.id}. For unverified Aqua models, {AQUA_DEPLOYMENT_CONTAINER_METADATA_NAME} should be"
f"set and value can be one of {', '.join(InferenceContainerTypeFamily.values())}."
) from err

return container_type_key


@telemetry(entry_point="plugin=deployment&action=list", name="aqua")
def list(self, **kwargs) -> List["AquaDeployment"]:
Expand Down Expand Up @@ -453,8 +455,8 @@ def list(self, **kwargs) -> List["AquaDeployment"]:
for model_deployment in model_deployments:
oci_aqua = (
(
Tags.AQUA_TAG in model_deployment.freeform_tags
or Tags.AQUA_TAG.lower() in model_deployment.freeform_tags
Tags.AQUA_TAG in model_deployment.freeform_tags
or Tags.AQUA_TAG.lower() in model_deployment.freeform_tags
)
if model_deployment.freeform_tags
else False
Expand Down Expand Up @@ -508,8 +510,8 @@ def get(self, model_deployment_id: str, **kwargs) -> "AquaDeploymentDetail":

oci_aqua = (
(
Tags.AQUA_TAG in model_deployment.freeform_tags
or Tags.AQUA_TAG.lower() in model_deployment.freeform_tags
Tags.AQUA_TAG in model_deployment.freeform_tags
or Tags.AQUA_TAG.lower() in model_deployment.freeform_tags
)
if model_deployment.freeform_tags
else False
Expand All @@ -526,8 +528,8 @@ def get(self, model_deployment_id: str, **kwargs) -> "AquaDeploymentDetail":
log_group_name = ""

logs = (
model_deployment.category_log_details.access
or model_deployment.category_log_details.predict
model_deployment.category_log_details.access
or model_deployment.category_log_details.predict
)
if logs:
log_id = logs.log_id
Expand Down Expand Up @@ -582,9 +584,9 @@ def get_deployment_config(self, model_id: str) -> Dict:
return config

def get_deployment_default_params(
self,
model_id: str,
instance_shape: str,
self,
model_id: str,
instance_shape: str,
) -> List[str]:
"""Gets the default params set in the deployment configs for the given model and instance shape.

Expand Down Expand Up @@ -616,8 +618,8 @@ def get_deployment_default_params(
)

if (
container_type_key
and container_type_key in InferenceContainerTypeFamily.values()
container_type_key
and container_type_key in InferenceContainerTypeFamily.values()
):
deployment_config = self.get_deployment_config(model_id)
config_params = (
Expand All @@ -640,10 +642,10 @@ def get_deployment_default_params(
return default_params

def validate_deployment_params(
self,
model_id: str,
params: List[str] = None,
container_family: str = None,
self,
model_id: str,
params: List[str] = None,
container_family: str = None,
) -> Dict:
"""Validate if the deployment parameters passed by the user can be overridden. Parameter values are not
validated, only param keys are validated.
Expand All @@ -666,8 +668,7 @@ def validate_deployment_params(
if params:
model = DataScienceModel.from_id(model_id)
container_type_key = self._get_container_type_key(
model=model,
container_family=container_family
model=model, container_family=container_family
)

container_config = get_container_config()
Expand All @@ -689,9 +690,9 @@ def validate_deployment_params(

@staticmethod
def _find_restricted_params(
default_params: Union[str, List[str]],
user_params: Union[str, List[str]],
container_family: str,
default_params: Union[str, List[str]],
user_params: Union[str, List[str]],
container_family: str,
) -> List[str]:
"""Returns a list of restricted params that user chooses to override when creating an Aqua deployment.
The default parameters coming from the container index json file cannot be overridden.
Expand Down
Loading