Skip to content

Commit 4472805

Browse files
committed
Added support to edit multi deployment.
1 parent 5cfa052 commit 4472805

File tree

4 files changed

+329
-248
lines changed

4 files changed

+329
-248
lines changed

ads/aqua/modeldeployment/deployment.py

Lines changed: 133 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -377,7 +377,7 @@ def create(
377377
combined_model_names,
378378
) = self._build_model_group_configs(
379379
models=create_deployment_details.models,
380-
create_deployment_details=create_deployment_details,
380+
deployment_details=create_deployment_details,
381381
model_config_summary=model_config_summary,
382382
freeform_tags=freeform_tags,
383383
source_models=source_models,
@@ -429,7 +429,9 @@ def _validate_input_models(
429429
def _build_model_group_configs(
430430
self,
431431
models: List[AquaMultiModelRef],
432-
create_deployment_details: CreateModelDeploymentDetails,
432+
deployment_details: Union[
433+
CreateModelDeploymentDetails, UpdateModelDeploymentDetails
434+
],
433435
model_config_summary: ModelDeploymentConfigSummary,
434436
freeform_tags: Optional[Dict] = None,
435437
source_models: Optional[Dict[str, DataScienceModel]] = None,
@@ -442,9 +444,9 @@ def _build_model_group_configs(
442444
----------
443445
models : List[AquaMultiModelRef]
444446
List of AquaMultiModelRef instances for creating a multi-model group.
445-
create_deployment_details : CreateModelDeploymentDetails
446-
An instance of CreateModelDeploymentDetails containing all required and optional
447-
fields for creating a model deployment via Aqua.
447+
deployment_details : Union[CreateModelDeploymentDetails, UpdateModelDeploymentDetails]
448+
An instance of CreateModelDeploymentDetails or UpdateModelDeploymentDetails containing all required and optional
449+
fields for creating or updating a model deployment via Aqua.
448450
model_config_summary : ModelConfigSummary
449451
Summary Model Deployment configuration for the group of models.
450452
freeform_tags : Optional[Dict]
@@ -667,7 +669,7 @@ def _build_model_group_configs(
667669
model_custom_metadata.add(
668670
key=AQUA_MULTI_MODEL_CONFIG,
669671
value=self._build_model_group_config(
670-
create_deployment_details=create_deployment_details,
672+
deployment_details=deployment_details,
671673
model_config_summary=model_config_summary,
672674
deployment_container=deployment_container,
673675
).model_dump_json(),
@@ -719,21 +721,21 @@ def _extract_model_task(
719721

720722
def _build_model_group_config(
721723
self,
722-
create_deployment_details,
724+
deployment_details: Union[
725+
CreateModelDeploymentDetails, UpdateModelDeploymentDetails
726+
],
723727
model_config_summary,
724728
deployment_container: str,
725729
) -> ModelGroupConfig:
726730
"""Builds model group config required to deploy multi models."""
727-
container_type_key = (
728-
create_deployment_details.container_family or deployment_container
729-
)
731+
container_type_key = deployment_details.container_family or deployment_container
730732
container_config = self.get_container_config_item(container_type_key)
731733
container_spec = container_config.spec if container_config else UNKNOWN
732734

733735
container_params = container_spec.cli_param if container_spec else UNKNOWN
734736

735-
multi_model_config = ModelGroupConfig.from_create_model_deployment_details(
736-
create_deployment_details,
737+
multi_model_config = ModelGroupConfig.from_model_deployment_details(
738+
deployment_details,
737739
model_config_summary,
738740
container_type_key,
739741
container_params,
@@ -1305,7 +1307,7 @@ def update(
13051307

13061308
# updates model group if fine tuned weights changed.
13071309
model = self._update_model_group(
1308-
runtime.model_group_id, update_model_deployment_details
1310+
runtime.model_group_id, update_model_deployment_details, model_deployment
13091311
)
13101312

13111313
# updates model group deployment infrastructure
@@ -1356,7 +1358,9 @@ def update(
13561358
# applies LIVE update if model group id has been changed
13571359
if runtime.model_group_id != model.id:
13581360
runtime.with_model_group_id(model.id)
1359-
update_type = ModelDeploymentUpdateType.LIVE
1361+
if model.dsc_model_group.model_group_details.type == DeploymentType.STACKED:
1362+
# only applies LIVE update for stacked deployment
1363+
update_type = ModelDeploymentUpdateType.LIVE
13601364

13611365
freeform_tags = (
13621366
update_model_deployment_details.freeform_tags
@@ -1395,6 +1399,7 @@ def _update_model_group(
13951399
self,
13961400
model_group_id: str,
13971401
update_model_deployment_details: UpdateModelDeploymentDetails,
1402+
model_deployment: ModelDeployment,
13981403
) -> DataScienceModelGroup:
13991404
"""Creates a new model group if fine tuned weights changed.
14001405
@@ -1405,69 +1410,133 @@ def _update_model_group(
14051410
update_model_deployment_details: UpdateModelDeploymentDetails
14061411
An instance of UpdateModelDeploymentDetails containing all optional
14071412
fields for updating a model deployment via Aqua.
1413+
model_deployment: ModelDeployment
1414+
An instance of ModelDeployment.
14081415
14091416
Returns
14101417
-------
14111418
DataScienceModelGroup
14121419
The instance of DataScienceModelGroup.
14131420
"""
14141421
model_group = DataScienceModelGroup.from_id(model_group_id)
1415-
if (
1416-
model_group.dsc_model_group.model_group_details.type
1417-
!= DeploymentType.STACKED
1418-
):
1419-
raise AquaValueError(
1420-
"Invalid 'model_deployment_id'. Only stacked deployment is supported to update."
1421-
)
1422-
# create a new model group if fine tune weights changed as member models in ds model group is inmutable
14231422
if update_model_deployment_details.models:
1424-
if len(update_model_deployment_details.models) != 1:
1425-
raise AquaValueError(
1426-
"Invalid 'models' provided. Only one base model is required for updating model stack deployment."
1427-
)
14281423
# validates input base and fine tune models
1429-
self._validate_input_models(update_model_deployment_details)
1430-
target_stacked_model = update_model_deployment_details.models[0]
1431-
target_base_model_id = target_stacked_model.model_id
1432-
if model_group.base_model_id != target_base_model_id:
1433-
raise AquaValueError(
1434-
"Invalid parameter 'models'. Base model id can't be changed for stacked model deployment."
1424+
source_models, _ = self._validate_input_models(
1425+
update_model_deployment_details
1426+
)
1427+
if (
1428+
model_group.dsc_model_group.model_group_details.type
1429+
== DeploymentType.STACKED
1430+
):
1431+
# create a new model group if fine tune weights changed as member models in ds model group is inmutable
1432+
if len(update_model_deployment_details.models) != 1:
1433+
raise AquaValueError(
1434+
"Invalid 'models' provided. Only one base model is required for updating model stack deployment."
1435+
)
1436+
target_stacked_model = update_model_deployment_details.models[0]
1437+
target_base_model_id = target_stacked_model.model_id
1438+
if model_group.base_model_id != target_base_model_id:
1439+
raise AquaValueError(
1440+
"Invalid parameter 'models'. Base model id can't be changed for stacked model deployment."
1441+
)
1442+
1443+
# add member models
1444+
member_models = [
1445+
{
1446+
"inference_key": fine_tune_weight.model_name,
1447+
"model_id": fine_tune_weight.model_id,
1448+
}
1449+
for fine_tune_weight in target_stacked_model.fine_tune_weights
1450+
]
1451+
# add base model
1452+
member_models.append(
1453+
{
1454+
"inference_key": target_stacked_model.model_name,
1455+
"model_id": target_base_model_id,
1456+
}
14351457
)
14361458

1437-
# add member models
1438-
member_models = [
1439-
{
1440-
"inference_key": fine_tune_weight.model_name,
1441-
"model_id": fine_tune_weight.model_id,
1442-
}
1443-
for fine_tune_weight in target_stacked_model.fine_tune_weights
1444-
]
1445-
# add base model
1446-
member_models.append(
1447-
{
1448-
"inference_key": target_stacked_model.model_name,
1449-
"model_id": target_base_model_id,
1450-
}
1451-
)
1459+
# creates a model group with the same configurations from original model group except member models
1460+
model_group = (
1461+
DataScienceModelGroup()
1462+
.with_compartment_id(model_group.compartment_id)
1463+
.with_project_id(model_group.project_id)
1464+
.with_display_name(model_group.display_name)
1465+
.with_description(model_group.description)
1466+
.with_freeform_tags(**(model_group.freeform_tags or {}))
1467+
.with_defined_tags(**(model_group.defined_tags or {}))
1468+
.with_custom_metadata_list(model_group.custom_metadata_list)
1469+
.with_base_model_id(target_base_model_id)
1470+
.with_member_models(member_models)
1471+
.create()
1472+
)
14521473

1453-
# creates a model group with the same configurations from original model group except member models
1454-
model_group = (
1455-
DataScienceModelGroup()
1456-
.with_compartment_id(model_group.compartment_id)
1457-
.with_project_id(model_group.project_id)
1458-
.with_display_name(model_group.display_name)
1459-
.with_description(model_group.description)
1460-
.with_freeform_tags(**(model_group.freeform_tags or {}))
1461-
.with_defined_tags(**(model_group.defined_tags or {}))
1462-
.with_custom_metadata_list(model_group.custom_metadata_list)
1463-
.with_base_model_id(target_base_model_id)
1464-
.with_member_models(member_models)
1465-
.create()
1466-
)
1474+
logger.info(
1475+
f"Model group of base model {target_base_model_id} has been updated: {model_group.id}."
1476+
)
1477+
else:
1478+
compartment_id = model_deployment.infrastructure.compartment_id
1479+
project_id = model_deployment.infrastructure.project_id
1480+
freeform_tags = (
1481+
update_model_deployment_details.freeform_tags
1482+
or model_deployment.freeform_tags
1483+
)
1484+
defined_tags = (
1485+
update_model_deployment_details.defined_tags
1486+
or model_deployment.defined_tags
1487+
)
1488+
# needs instance shape here for building the multi model config from update_model_deployment_details
1489+
update_model_deployment_details.instance_shape = (
1490+
model_deployment.infrastructure.shape_name
1491+
)
14671492

1468-
logger.info(
1469-
f"Model group of base model {target_base_model_id} has been updated: {model_group.id}."
1470-
)
1493+
# rebuilds MULTI_MODEL_CONFIG and creates model group
1494+
base_model_ids = [
1495+
model.model_id for model in update_model_deployment_details.models
1496+
]
1497+
1498+
try:
1499+
model_config_summary = self.get_multimodel_deployment_config(
1500+
model_ids=base_model_ids, compartment_id=compartment_id
1501+
)
1502+
if not model_config_summary.gpu_allocation:
1503+
raise AquaValueError(model_config_summary.error_message)
1504+
1505+
update_model_deployment_details.validate_multimodel_deployment_feasibility(
1506+
models_config_summary=model_config_summary
1507+
)
1508+
except ConfigValidationError as err:
1509+
raise AquaValueError(f"{err}") from err
1510+
1511+
(
1512+
model_group_display_name,
1513+
model_group_description,
1514+
tags,
1515+
model_custom_metadata,
1516+
combined_model_names,
1517+
) = self._build_model_group_configs(
1518+
models=update_model_deployment_details.models,
1519+
deployment_details=update_model_deployment_details,
1520+
model_config_summary=model_config_summary,
1521+
freeform_tags=freeform_tags,
1522+
source_models=source_models,
1523+
)
1524+
1525+
model_group = AquaModelApp().create_multi(
1526+
models=update_model_deployment_details.models,
1527+
model_custom_metadata=model_custom_metadata,
1528+
model_group_display_name=model_group_display_name,
1529+
model_group_description=model_group_description,
1530+
tags=tags,
1531+
combined_model_names=combined_model_names,
1532+
compartment_id=compartment_id,
1533+
project_id=project_id,
1534+
defined_tags=defined_tags,
1535+
)
1536+
1537+
logger.info(
1538+
f"Model group of multi model deployment {model_deployment.id} has been updated: {model_group.id}."
1539+
)
14711540

14721541
return model_group
14731542

0 commit comments

Comments
 (0)