@@ -377,7 +377,7 @@ def create(
377377 combined_model_names ,
378378 ) = self ._build_model_group_configs (
379379 models = create_deployment_details .models ,
380- create_deployment_details = create_deployment_details ,
380+ deployment_details = create_deployment_details ,
381381 model_config_summary = model_config_summary ,
382382 freeform_tags = freeform_tags ,
383383 source_models = source_models ,
@@ -429,7 +429,9 @@ def _validate_input_models(
429429 def _build_model_group_configs (
430430 self ,
431431 models : List [AquaMultiModelRef ],
432- create_deployment_details : CreateModelDeploymentDetails ,
432+ deployment_details : Union [
433+ CreateModelDeploymentDetails , UpdateModelDeploymentDetails
434+ ],
433435 model_config_summary : ModelDeploymentConfigSummary ,
434436 freeform_tags : Optional [Dict ] = None ,
435437 source_models : Optional [Dict [str , DataScienceModel ]] = None ,
@@ -442,9 +444,9 @@ def _build_model_group_configs(
442444 ----------
443445 models : List[AquaMultiModelRef]
444446 List of AquaMultiModelRef instances for creating a multi-model group.
445- create_deployment_details : CreateModelDeploymentDetails
446- An instance of CreateModelDeploymentDetails containing all required and optional
447- fields for creating a model deployment via Aqua.
447+ deployment_details : Union[ CreateModelDeploymentDetails, UpdateModelDeploymentDetails]
448+ An instance of CreateModelDeploymentDetails or UpdateModelDeploymentDetails containing all required and optional
449+ fields for creating or updating a model deployment via Aqua.
448450 model_config_summary : ModelConfigSummary
449451 Summary Model Deployment configuration for the group of models.
450452 freeform_tags : Optional[Dict]
@@ -667,7 +669,7 @@ def _build_model_group_configs(
667669 model_custom_metadata .add (
668670 key = AQUA_MULTI_MODEL_CONFIG ,
669671 value = self ._build_model_group_config (
670- create_deployment_details = create_deployment_details ,
672+ deployment_details = deployment_details ,
671673 model_config_summary = model_config_summary ,
672674 deployment_container = deployment_container ,
673675 ).model_dump_json (),
@@ -719,21 +721,21 @@ def _extract_model_task(
719721
720722 def _build_model_group_config (
721723 self ,
722- create_deployment_details ,
724+ deployment_details : Union [
725+ CreateModelDeploymentDetails , UpdateModelDeploymentDetails
726+ ],
723727 model_config_summary ,
724728 deployment_container : str ,
725729 ) -> ModelGroupConfig :
726730 """Builds model group config required to deploy multi models."""
727- container_type_key = (
728- create_deployment_details .container_family or deployment_container
729- )
731+ container_type_key = deployment_details .container_family or deployment_container
730732 container_config = self .get_container_config_item (container_type_key )
731733 container_spec = container_config .spec if container_config else UNKNOWN
732734
733735 container_params = container_spec .cli_param if container_spec else UNKNOWN
734736
735- multi_model_config = ModelGroupConfig .from_create_model_deployment_details (
736- create_deployment_details ,
737+ multi_model_config = ModelGroupConfig .from_model_deployment_details (
738+ deployment_details ,
737739 model_config_summary ,
738740 container_type_key ,
739741 container_params ,
@@ -1305,7 +1307,7 @@ def update(
13051307
13061308 # updates model group if fine tuned weights changed.
13071309 model = self ._update_model_group (
1308- runtime .model_group_id , update_model_deployment_details
1310+ runtime .model_group_id , update_model_deployment_details , model_deployment
13091311 )
13101312
13111313 # updates model group deployment infrastructure
@@ -1356,7 +1358,9 @@ def update(
13561358 # applies LIVE update if model group id has been changed
13571359 if runtime .model_group_id != model .id :
13581360 runtime .with_model_group_id (model .id )
1359- update_type = ModelDeploymentUpdateType .LIVE
1361+ if model .dsc_model_group .model_group_details .type == DeploymentType .STACKED :
1362+ # only applies LIVE update for stacked deployment
1363+ update_type = ModelDeploymentUpdateType .LIVE
13601364
13611365 freeform_tags = (
13621366 update_model_deployment_details .freeform_tags
@@ -1395,6 +1399,7 @@ def _update_model_group(
13951399 self ,
13961400 model_group_id : str ,
13971401 update_model_deployment_details : UpdateModelDeploymentDetails ,
1402+ model_deployment : ModelDeployment ,
13981403 ) -> DataScienceModelGroup :
13991404 """Creates a new model group if fine tuned weights changed.
14001405
@@ -1405,69 +1410,133 @@ def _update_model_group(
14051410 update_model_deployment_details: UpdateModelDeploymentDetails
14061411 An instance of UpdateModelDeploymentDetails containing all optional
14071412 fields for updating a model deployment via Aqua.
1413+ model_deployment: ModelDeployment
1414+ An instance of ModelDeployment.
14081415
14091416 Returns
14101417 -------
14111418 DataScienceModelGroup
14121419 The instance of DataScienceModelGroup.
14131420 """
14141421 model_group = DataScienceModelGroup .from_id (model_group_id )
1415- if (
1416- model_group .dsc_model_group .model_group_details .type
1417- != DeploymentType .STACKED
1418- ):
1419- raise AquaValueError (
1420- "Invalid 'model_deployment_id'. Only stacked deployment is supported to update."
1421- )
1422- # create a new model group if fine tune weights changed as member models in ds model group is inmutable
14231422 if update_model_deployment_details .models :
1424- if len (update_model_deployment_details .models ) != 1 :
1425- raise AquaValueError (
1426- "Invalid 'models' provided. Only one base model is required for updating model stack deployment."
1427- )
14281423 # validates input base and fine tune models
1429- self ._validate_input_models (update_model_deployment_details )
1430- target_stacked_model = update_model_deployment_details .models [0 ]
1431- target_base_model_id = target_stacked_model .model_id
1432- if model_group .base_model_id != target_base_model_id :
1433- raise AquaValueError (
1434- "Invalid parameter 'models'. Base model id can't be changed for stacked model deployment."
1424+ source_models , _ = self ._validate_input_models (
1425+ update_model_deployment_details
1426+ )
1427+ if (
1428+ model_group .dsc_model_group .model_group_details .type
1429+ == DeploymentType .STACKED
1430+ ):
1431+ # create a new model group if fine tune weights changed as member models in ds model group is inmutable
1432+ if len (update_model_deployment_details .models ) != 1 :
1433+ raise AquaValueError (
1434+ "Invalid 'models' provided. Only one base model is required for updating model stack deployment."
1435+ )
1436+ target_stacked_model = update_model_deployment_details .models [0 ]
1437+ target_base_model_id = target_stacked_model .model_id
1438+ if model_group .base_model_id != target_base_model_id :
1439+ raise AquaValueError (
1440+ "Invalid parameter 'models'. Base model id can't be changed for stacked model deployment."
1441+ )
1442+
1443+ # add member models
1444+ member_models = [
1445+ {
1446+ "inference_key" : fine_tune_weight .model_name ,
1447+ "model_id" : fine_tune_weight .model_id ,
1448+ }
1449+ for fine_tune_weight in target_stacked_model .fine_tune_weights
1450+ ]
1451+ # add base model
1452+ member_models .append (
1453+ {
1454+ "inference_key" : target_stacked_model .model_name ,
1455+ "model_id" : target_base_model_id ,
1456+ }
14351457 )
14361458
1437- # add member models
1438- member_models = [
1439- {
1440- "inference_key" : fine_tune_weight .model_name ,
1441- "model_id" : fine_tune_weight .model_id ,
1442- }
1443- for fine_tune_weight in target_stacked_model .fine_tune_weights
1444- ]
1445- # add base model
1446- member_models .append (
1447- {
1448- "inference_key" : target_stacked_model .model_name ,
1449- "model_id" : target_base_model_id ,
1450- }
1451- )
1459+ # creates a model group with the same configurations from original model group except member models
1460+ model_group = (
1461+ DataScienceModelGroup ()
1462+ .with_compartment_id (model_group .compartment_id )
1463+ .with_project_id (model_group .project_id )
1464+ .with_display_name (model_group .display_name )
1465+ .with_description (model_group .description )
1466+ .with_freeform_tags (** (model_group .freeform_tags or {}))
1467+ .with_defined_tags (** (model_group .defined_tags or {}))
1468+ .with_custom_metadata_list (model_group .custom_metadata_list )
1469+ .with_base_model_id (target_base_model_id )
1470+ .with_member_models (member_models )
1471+ .create ()
1472+ )
14521473
1453- # creates a model group with the same configurations from original model group except member models
1454- model_group = (
1455- DataScienceModelGroup ()
1456- .with_compartment_id (model_group .compartment_id )
1457- .with_project_id (model_group .project_id )
1458- .with_display_name (model_group .display_name )
1459- .with_description (model_group .description )
1460- .with_freeform_tags (** (model_group .freeform_tags or {}))
1461- .with_defined_tags (** (model_group .defined_tags or {}))
1462- .with_custom_metadata_list (model_group .custom_metadata_list )
1463- .with_base_model_id (target_base_model_id )
1464- .with_member_models (member_models )
1465- .create ()
1466- )
1474+ logger .info (
1475+ f"Model group of base model { target_base_model_id } has been updated: { model_group .id } ."
1476+ )
1477+ else :
1478+ compartment_id = model_deployment .infrastructure .compartment_id
1479+ project_id = model_deployment .infrastructure .project_id
1480+ freeform_tags = (
1481+ update_model_deployment_details .freeform_tags
1482+ or model_deployment .freeform_tags
1483+ )
1484+ defined_tags = (
1485+ update_model_deployment_details .defined_tags
1486+ or model_deployment .defined_tags
1487+ )
1488+ # needs instance shape here for building the multi model config from update_model_deployment_details
1489+ update_model_deployment_details .instance_shape = (
1490+ model_deployment .infrastructure .shape_name
1491+ )
14671492
1468- logger .info (
1469- f"Model group of base model { target_base_model_id } has been updated: { model_group .id } ."
1470- )
1493+ # rebuilds MULTI_MODEL_CONFIG and creates model group
1494+ base_model_ids = [
1495+ model .model_id for model in update_model_deployment_details .models
1496+ ]
1497+
1498+ try :
1499+ model_config_summary = self .get_multimodel_deployment_config (
1500+ model_ids = base_model_ids , compartment_id = compartment_id
1501+ )
1502+ if not model_config_summary .gpu_allocation :
1503+ raise AquaValueError (model_config_summary .error_message )
1504+
1505+ update_model_deployment_details .validate_multimodel_deployment_feasibility (
1506+ models_config_summary = model_config_summary
1507+ )
1508+ except ConfigValidationError as err :
1509+ raise AquaValueError (f"{ err } " ) from err
1510+
1511+ (
1512+ model_group_display_name ,
1513+ model_group_description ,
1514+ tags ,
1515+ model_custom_metadata ,
1516+ combined_model_names ,
1517+ ) = self ._build_model_group_configs (
1518+ models = update_model_deployment_details .models ,
1519+ deployment_details = update_model_deployment_details ,
1520+ model_config_summary = model_config_summary ,
1521+ freeform_tags = freeform_tags ,
1522+ source_models = source_models ,
1523+ )
1524+
1525+ model_group = AquaModelApp ().create_multi (
1526+ models = update_model_deployment_details .models ,
1527+ model_custom_metadata = model_custom_metadata ,
1528+ model_group_display_name = model_group_display_name ,
1529+ model_group_description = model_group_description ,
1530+ tags = tags ,
1531+ combined_model_names = combined_model_names ,
1532+ compartment_id = compartment_id ,
1533+ project_id = project_id ,
1534+ defined_tags = defined_tags ,
1535+ )
1536+
1537+ logger .info (
1538+ f"Model group of multi model deployment { model_deployment .id } has been updated: { model_group .id } ."
1539+ )
14711540
14721541 return model_group
14731542
0 commit comments