|
99 | 99 | from ads.jobs.builders.runtimes.base import Runtime |
100 | 100 | from ads.jobs.builders.runtimes.container_runtime import ContainerRuntime |
101 | 101 | from ads.model.datascience_model import DataScienceModel |
| 102 | +from ads.model.datascience_model_group import DataScienceModelGroup |
102 | 103 | from ads.model.deployment import ModelDeploymentContainerRuntime |
103 | 104 | from ads.model.deployment.model_deployment import ModelDeployment |
104 | 105 | from ads.model.generic_model import ModelDeploymentRuntimeType |
@@ -254,7 +255,11 @@ def create( |
254 | 255 | f"Make sure the {Tags.AQUA_MODEL_ID_TAG} tag is added to the deployment." |
255 | 256 | ) |
256 | 257 |
|
257 | | - aqua_model = DataScienceModel.from_id(multi_model_id) |
| 258 | + aqua_model = ( |
| 259 | + DataScienceModelGroup.from_id(multi_model_id) |
| 260 | + if "datasciencemodelgroup" in multi_model_id |
| 261 | + else DataScienceModel.from_id(multi_model_id) |
| 262 | + ) |
258 | 263 | AquaEvaluationApp.validate_model_name( |
259 | 264 | aqua_model, create_aqua_evaluation_details |
260 | 265 | ) |
@@ -630,23 +635,23 @@ def create( |
630 | 635 |
|
631 | 636 | @staticmethod |
632 | 637 | def validate_model_name( |
633 | | - evaluation_source: DataScienceModel, |
| 638 | + evaluation_source: Union[DataScienceModel, DataScienceModelGroup], |
634 | 639 | create_aqua_evaluation_details: CreateAquaEvaluationDetails, |
635 | 640 | ) -> None: |
636 | 641 | """ |
637 | 642 | Validates the user input for the model name when creating an Aqua evaluation. |
638 | 643 |
|
639 | 644 | This function verifies that: |
640 | 645 | - The model group is not empty. |
641 | | - - The model multi metadata is present in the DataScienceModel metadata. |
| 646 | + - The model multi metadata is present in the DataScienceModel or DataScienceModelGroup metadata. |
642 | 647 | - The user provided a non-empty model name. |
643 | | - - The provided model name exists in the DataScienceModel metadata. |
| 648 | + - The provided model name exists in the DataScienceModel or DataScienceModelGroup metadata. |
644 | 649 | - The deployment configuration contains core metadata required for validation. |
645 | 650 |
|
646 | 651 | Parameters |
647 | 652 | ---------- |
648 | | - evaluation_source : DataScienceModel |
649 | | - The DataScienceModel object containing metadata about each model in the deployment. |
| 653 | + evaluation_source : Union[DataScienceModel, DataScienceModelGroup] |
| 654 | + The DataScienceModel or DataScienceModelGroup object containing metadata about each model in the deployment. |
650 | 655 | create_aqua_evaluation_details : CreateAquaEvaluationDetails |
651 | 656 | Contains required and optional fields for creating the Aqua evaluation. |
652 | 657 |
|
@@ -711,27 +716,30 @@ def validate_model_name( |
711 | 716 | logger.debug(error_message) |
712 | 717 | raise AquaRuntimeError(error_message) |
713 | 718 |
|
714 | | - try: |
715 | | - multi_model_metadata = json.loads( |
716 | | - evaluation_source.dsc_model.get_custom_metadata_artifact( |
717 | | - metadata_key_name=ModelCustomMetadataFields.MULTIMODEL_METADATA |
718 | | - ).decode("utf-8") |
719 | | - ) |
720 | | - except Exception as ex: |
721 | | - error_message = ( |
722 | | - f"Error fetching {ModelCustomMetadataFields.MULTIMODEL_METADATA} " |
723 | | - f"from custom metadata for evaluation source ID '{evaluation_source.id}'. " |
724 | | - f"Details: {ex}" |
725 | | - ) |
726 | | - logger.error(error_message) |
727 | | - raise AquaRuntimeError(error_message) from ex |
| 719 | + if isinstance(evaluation_source, DataScienceModel): |
| 720 | + try: |
| 721 | + multi_model_metadata = json.loads( |
| 722 | + evaluation_source.dsc_model.get_custom_metadata_artifact( |
| 723 | + metadata_key_name=ModelCustomMetadataFields.MULTIMODEL_METADATA |
| 724 | + ).decode("utf-8") |
| 725 | + ) |
| 726 | + except Exception as ex: |
| 727 | + error_message = ( |
| 728 | + f"Error fetching {ModelCustomMetadataFields.MULTIMODEL_METADATA} " |
| 729 | + f"from custom metadata for evaluation source ID '{evaluation_source.id}'. " |
| 730 | + f"Details: {ex}" |
| 731 | + ) |
| 732 | + logger.error(error_message) |
| 733 | + raise AquaRuntimeError(error_message) from ex |
728 | 734 |
|
729 | 735 | # Build the list of valid model names from custom metadata. |
730 | 736 | model_names = [] |
731 | 737 | for metadata in multi_model_metadata: |
732 | 738 | model = AquaMultiModelRef(**metadata) |
733 | 739 | model_names.append(model.model_name) |
734 | | - model_names.extend(ft.model_name for ft in (model.fine_tune_weights or []) if ft.model_name) |
| 740 | + model_names.extend( |
| 741 | + ft.model_name for ft in (model.fine_tune_weights or []) if ft.model_name |
| 742 | + ) |
735 | 743 |
|
736 | 744 | # Check if the provided model name is among the valid names. |
737 | 745 | if user_model_name not in model_names: |
|
0 commit comments