Skip to content

Commit 26cda04

Browse files
mrDzurb, lu-ohai, and Aryanag2
authored
[AQUA] Group Model Deployment & Stacked Model Deployment (#1217)
Co-authored-by: Lu Peng <bolu.peng@oracle.com>
Co-authored-by: Aryan Gosaliya <aryan.gosaliya@oracle.com>
1 parent 42f297b commit 26cda04

27 files changed

+3956
-1179
lines changed

ads/aqua/app.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -290,6 +290,7 @@ def create_model_catalog(
290290
model_taxonomy_metadata: Union[ModelTaxonomyMetadata, Dict],
291291
compartment_id: str,
292292
project_id: str,
293+
freeform_tags: Dict = None,
293294
defined_tags: Dict = None,
294295
**kwargs,
295296
) -> DataScienceModel:
@@ -303,6 +304,7 @@ def create_model_catalog(
303304
.with_custom_metadata_list(model_custom_metadata)
304305
.with_defined_metadata_list(model_taxonomy_metadata)
305306
.with_provenance_metadata(ModelProvenanceMetadata(training_id=UNKNOWN))
307+
.with_freeform_tags(**(freeform_tags or {}))
306308
.with_defined_tags(
307309
**(defined_tags or {})
308310
) # Create defined tags when a model is created.

ads/aqua/common/enums.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -44,6 +44,8 @@ class Tags(ExtendedEnum):
4444
MODEL_FORMAT = "model_format"
4545
MODEL_ARTIFACT_FILE = "model_file"
4646
MULTIMODEL_TYPE_TAG = "aqua_multimodel"
47+
STACKED_MODEL_TYPE_TAG = "aqua_stacked_model"
48+
AQUA_FINE_TUNE_MODEL_VERSION = "fine_tune_model_version"
4749

4850

4951
class InferenceContainerType(ExtendedEnum):

ads/aqua/common/utils.py

Lines changed: 16 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -643,14 +643,18 @@ def get_resource_name(ocid: str) -> str:
643643
return name
644644

645645

646-
def get_model_by_reference_paths(model_file_description: dict):
646+
def get_model_by_reference_paths(
647+
model_file_description: dict, is_ft_model_v2: bool = False
648+
):
647649
"""Reads the model file description json dict and returns the base model path and fine-tuned path for
648650
models created by reference.
649651
650652
Parameters
651653
----------
652654
model_file_description: dict
653655
json dict containing model paths and objects for models created by reference.
656+
is_ft_model_v2: bool
657+
Flag to indicate if it's fine tuned model v2. Defaults to False.
654658
655659
Returns
656660
-------
@@ -666,8 +670,18 @@ def get_model_by_reference_paths(model_file_description: dict):
666670
"Please check if the model created by reference has the correct artifact."
667671
)
668672

673+
if is_ft_model_v2:
674+
# model_file_description json for fine tuned model v2 contains only fine tuned model artifacts
675+
# so first model is always the fine tuned model
676+
ft_model_artifact = models[0]
677+
fine_tune_output_path = f"oci://{ft_model_artifact['bucketName']}@{ft_model_artifact['namespace']}/{ft_model_artifact['prefix']}".rstrip(
678+
"/"
679+
)
680+
681+
return UNKNOWN, fine_tune_output_path
682+
669683
if len(models) > 0:
670-
# since the model_file_description json does not have a flag to identify the base model, we consider
684+
# since the model_file_description json for legacy fine tuned model does not have a flag to identify the base model, we consider
671685
# the first instance to be the base model.
672686
base_model_artifact = models[0]
673687
base_model_path = f"oci://{base_model_artifact['bucketName']}@{base_model_artifact['namespace']}/{base_model_artifact['prefix']}".rstrip(

ads/aqua/constants.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -45,6 +45,8 @@
4545
AQUA_TROUBLESHOOTING_LINK = "https://github.com/oracle-samples/oci-data-science-ai-samples/blob/main/ai-quick-actions/troubleshooting-tips.md"
4646
MODEL_FILE_DESCRIPTION_VERSION = "1.0"
4747
MODEL_FILE_DESCRIPTION_TYPE = "modelOSSReferenceDescription"
48+
AQUA_FINE_TUNE_MODEL_VERSION = "v2"
49+
INCLUDE_BASE_MODEL = 1
4850

4951
TRAINING_METRICS_FINAL = "training_metrics_final"
5052
VALIDATION_METRICS_FINAL = "validation_metrics_final"

ads/aqua/evaluation/evaluation.py

Lines changed: 29 additions & 21 deletions
Original file line number | Diff line number | Diff line change
@@ -99,6 +99,7 @@
9999
from ads.jobs.builders.runtimes.base import Runtime
100100
from ads.jobs.builders.runtimes.container_runtime import ContainerRuntime
101101
from ads.model.datascience_model import DataScienceModel
102+
from ads.model.datascience_model_group import DataScienceModelGroup
102103
from ads.model.deployment import ModelDeploymentContainerRuntime
103104
from ads.model.deployment.model_deployment import ModelDeployment
104105
from ads.model.generic_model import ModelDeploymentRuntimeType
@@ -254,7 +255,11 @@ def create(
254255
f"Make sure the {Tags.AQUA_MODEL_ID_TAG} tag is added to the deployment."
255256
)
256257

257-
aqua_model = DataScienceModel.from_id(multi_model_id)
258+
aqua_model = (
259+
DataScienceModelGroup.from_id(multi_model_id)
260+
if "datasciencemodelgroup" in multi_model_id
261+
else DataScienceModel.from_id(multi_model_id)
262+
)
258263
AquaEvaluationApp.validate_model_name(
259264
aqua_model, create_aqua_evaluation_details
260265
)
@@ -630,23 +635,23 @@ def create(
630635

631636
@staticmethod
632637
def validate_model_name(
633-
evaluation_source: DataScienceModel,
638+
evaluation_source: Union[DataScienceModel, DataScienceModelGroup],
634639
create_aqua_evaluation_details: CreateAquaEvaluationDetails,
635640
) -> None:
636641
"""
637642
Validates the user input for the model name when creating an Aqua evaluation.
638643
639644
This function verifies that:
640645
- The model group is not empty.
641-
- The model multi metadata is present in the DataScienceModel metadata.
646+
- The model multi metadata is present in the DataScienceModel or DataScienceModelGroup metadata.
642647
- The user provided a non-empty model name.
643-
- The provided model name exists in the DataScienceModel metadata.
648+
- The provided model name exists in the DataScienceModel or DataScienceModelGroup metadata.
644649
- The deployment configuration contains core metadata required for validation.
645650
646651
Parameters
647652
----------
648-
evaluation_source : DataScienceModel
649-
The DataScienceModel object containing metadata about each model in the deployment.
653+
evaluation_source : Union[DataScienceModel, DataScienceModelGroup]
654+
The DataScienceModel or DataScienceModelGroup object containing metadata about each model in the deployment.
650655
create_aqua_evaluation_details : CreateAquaEvaluationDetails
651656
Contains required and optional fields for creating the Aqua evaluation.
652657
@@ -711,27 +716,30 @@ def validate_model_name(
711716
logger.debug(error_message)
712717
raise AquaRuntimeError(error_message)
713718

714-
try:
715-
multi_model_metadata = json.loads(
716-
evaluation_source.dsc_model.get_custom_metadata_artifact(
717-
metadata_key_name=ModelCustomMetadataFields.MULTIMODEL_METADATA
718-
).decode("utf-8")
719-
)
720-
except Exception as ex:
721-
error_message = (
722-
f"Error fetching {ModelCustomMetadataFields.MULTIMODEL_METADATA} "
723-
f"from custom metadata for evaluation source ID '{evaluation_source.id}'. "
724-
f"Details: {ex}"
725-
)
726-
logger.error(error_message)
727-
raise AquaRuntimeError(error_message) from ex
719+
if isinstance(evaluation_source, DataScienceModel):
720+
try:
721+
multi_model_metadata = json.loads(
722+
evaluation_source.dsc_model.get_custom_metadata_artifact(
723+
metadata_key_name=ModelCustomMetadataFields.MULTIMODEL_METADATA
724+
).decode("utf-8")
725+
)
726+
except Exception as ex:
727+
error_message = (
728+
f"Error fetching {ModelCustomMetadataFields.MULTIMODEL_METADATA} "
729+
f"from custom metadata for evaluation source ID '{evaluation_source.id}'. "
730+
f"Details: {ex}"
731+
)
732+
logger.error(error_message)
733+
raise AquaRuntimeError(error_message) from ex
728734

729735
# Build the list of valid model names from custom metadata.
730736
model_names = []
731737
for metadata in multi_model_metadata:
732738
model = AquaMultiModelRef(**metadata)
733739
model_names.append(model.model_name)
734-
model_names.extend(ft.model_name for ft in (model.fine_tune_weights or []) if ft.model_name)
740+
model_names.extend(
741+
ft.model_name for ft in (model.fine_tune_weights or []) if ft.model_name
742+
)
735743

736744
# Check if the provided model name is among the valid names.
737745
if user_model_name not in model_names:

ads/aqua/extension/deployment_handler.py

Lines changed: 17 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -119,7 +119,15 @@ def post(self, *args, **kwargs): # noqa: ARG002
119119
if not input_data:
120120
raise HTTPError(400, Errors.NO_INPUT_DATA)
121121

122-
self.finish(AquaDeploymentApp().create(**input_data))
122+
model_deployment_id = input_data.pop("model_deployment_id", None)
123+
if model_deployment_id:
124+
self.finish(
125+
AquaDeploymentApp().update(
126+
model_deployment_id=model_deployment_id, **input_data
127+
)
128+
)
129+
else:
130+
self.finish(AquaDeploymentApp().create(**input_data))
123131

124132
def read(self, id):
125133
"""Read the information of an Aqua model deployment."""
@@ -436,7 +444,14 @@ def get(self, model_deployment_id):
436444
list_model_result = aqua_client.fetch_data()
437445
return self.finish(list_model_result)
438446
except Exception as ex:
439-
raise HTTPError(500, str(ex))
447+
error_type = type(ex).__name__
448+
error_message = (
449+
f"Error fetching data from endpoint '{endpoint}' [{error_type}]: {ex}"
450+
)
451+
logger.error(
452+
error_message, exc_info=True
453+
) # Log with stack trace for diagnostics
454+
raise HTTPError(500, error_message) from ex
440455

441456

442457
__handlers__ = [

ads/aqua/finetuning/constants.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -14,6 +14,7 @@ class FineTuneCustomMetadata(ExtendedEnum):
1414
SERVICE_MODEL_ARTIFACT_LOCATION = "artifact_location"
1515
SERVICE_MODEL_DEPLOYMENT_CONTAINER = "deployment-container"
1616
SERVICE_MODEL_FINE_TUNE_CONTAINER = "finetune-container"
17+
FINE_TUNE_INCLUDE_BASE_MODEL_ARTIFACT = "include_base_model_artifact"
1718

1819

1920
class FineTuningRestrictedParams(ExtendedEnum):

ads/aqua/finetuning/finetuning.py

Lines changed: 8 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -25,6 +25,7 @@
2525
upload_local_to_os,
2626
)
2727
from ads.aqua.constants import (
28+
AQUA_FINE_TUNE_MODEL_VERSION,
2829
DEFAULT_FT_BATCH_SIZE,
2930
DEFAULT_FT_BLOCK_STORAGE_SIZE,
3031
DEFAULT_FT_REPLICA,
@@ -304,6 +305,11 @@ def create(
304305
"val_set_size": create_fine_tuning_details.validation_set_size,
305306
"training_data": ft_dataset_path,
306307
}
308+
# needs to add 'fine_tune_model_version' tag when creating the ft model for the
309+
# ft container to block merging base model artifact with ft model artifact.
310+
ft_model_freeform_tags = {
311+
Tags.AQUA_FINE_TUNE_MODEL_VERSION: AQUA_FINE_TUNE_MODEL_VERSION
312+
}
307313

308314
ft_model = self.create_model_catalog(
309315
display_name=create_fine_tuning_details.ft_name,
@@ -314,6 +320,7 @@ def create(
314320
compartment_id=target_compartment,
315321
project_id=target_project,
316322
model_by_reference=True,
323+
freeform_tags=ft_model_freeform_tags,
317324
defined_tags=create_fine_tuning_details.defined_tags,
318325
)
319326
defined_metadata_dict = {}
@@ -446,6 +453,7 @@ def create(
446453

447454
model_freeform_tags = {
448455
**model_freeform_tags,
456+
**(ft_model.freeform_tags or {}),
449457
Tags.READY_TO_FINE_TUNE: "false",
450458
Tags.AQUA_TAG: UNKNOWN,
451459
Tags.AQUA_FINE_TUNED_MODEL_TAG: f"{source.id}#{source.display_name}",

0 commit comments

Comments
 (0)