Skip to content
14 changes: 12 additions & 2 deletions ads/aqua/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# Copyright (c) 2024 Oracle and/or its affiliates.
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/

import json
import os
from dataclasses import fields
from typing import Dict, Union
Expand Down Expand Up @@ -135,6 +136,8 @@ def create_model_version_set(
description: str = None,
compartment_id: str = None,
project_id: str = None,
freeform_tags: dict = None,
defined_tags: dict = None,
**kwargs,
) -> tuple:
"""Creates ModelVersionSet from given ID or Name.
Expand All @@ -153,7 +156,10 @@ def create_model_version_set(
Project OCID.
tag: (str, optional)
calling tag, can be Tags.AQUA_FINE_TUNING or Tags.AQUA_EVALUATION

freeform_tags: (dict, optional)
Freeform tags for the model version set
defined_tags: (dict, optional)
Defined tags for the model version set
Returns
-------
tuple: (model_version_set_id, model_version_set_name)
Expand Down Expand Up @@ -182,13 +188,15 @@ def create_model_version_set(
mvs_freeform_tags = {
tag: tag,
}
mvs_freeform_tags = {**mvs_freeform_tags, **(freeform_tags or {})}
model_version_set = (
ModelVersionSet()
.with_compartment_id(compartment_id)
.with_project_id(project_id)
.with_name(model_version_set_name)
.with_description(description)
.with_freeform_tags(**mvs_freeform_tags)
.with_defined_tags(**(defined_tags or {}))
# TODO: decide what parameters will be needed
# when refactor eval to use this method, we need to pass tag here.
.create(**kwargs)
Expand Down Expand Up @@ -340,7 +348,9 @@ def build_cli(self) -> str:
"""
cmd = f"ads aqua {self._command}"
params = [
f"--{field.name} {getattr(self,field.name)}"
f"--{field.name} {json.dumps(getattr(self, field.name))}"
if isinstance(getattr(self, field.name), dict)
else f"--{field.name} {getattr(self, field.name)}"
for field in fields(self.__class__)
if getattr(self, field.name) is not None
]
Expand Down
6 changes: 6 additions & 0 deletions ads/aqua/evaluation/entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,10 @@ class CreateAquaEvaluationDetails(Serializable):
The metrics for the evaluation.
force_overwrite: (bool, optional). Defaults to `False`.
Whether to force overwrite the existing file in object storage.
freeform_tags: (dict, optional)
Freeform tags for the evaluation model
defined_tags: (dict, optional)
Defined tags for the evaluation model
"""

evaluation_source_id: str
Expand All @@ -85,6 +89,8 @@ class CreateAquaEvaluationDetails(Serializable):
log_id: Optional[str] = None
metrics: Optional[List[Dict[str, Any]]] = None
force_overwrite: Optional[bool] = False
freeform_tags: Optional[dict] = None
defined_tags: Optional[dict] = None

class Config:
extra = "ignore"
Expand Down
41 changes: 34 additions & 7 deletions ads/aqua/evaluation/evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,10 @@ def create(
evaluation_mvs_freeform_tags = {
Tags.AQUA_EVALUATION: Tags.AQUA_EVALUATION,
}
evaluation_mvs_freeform_tags = {
**evaluation_mvs_freeform_tags,
**(create_aqua_evaluation_details.freeform_tags or {}),
}

model_version_set = (
ModelVersionSet()
Expand All @@ -307,6 +311,9 @@ def create(
create_aqua_evaluation_details.experiment_description
)
.with_freeform_tags(**evaluation_mvs_freeform_tags)
.with_defined_tags(
**(create_aqua_evaluation_details.defined_tags or {})
)
# TODO: decide what parameters will be needed
.create(**kwargs)
)
Expand Down Expand Up @@ -369,6 +376,10 @@ def create(
Tags.AQUA_EVALUATION: Tags.AQUA_EVALUATION,
Tags.AQUA_EVALUATION_MODEL_ID: evaluation_model.id,
}
evaluation_job_freeform_tags = {
**evaluation_job_freeform_tags,
**(create_aqua_evaluation_details.freeform_tags or {}),
}

evaluation_job = Job(name=evaluation_model.display_name).with_infrastructure(
DataScienceJob()
Expand All @@ -379,6 +390,7 @@ def create(
.with_shape_name(create_aqua_evaluation_details.shape_name)
.with_block_storage_size(create_aqua_evaluation_details.block_storage_size)
.with_freeform_tag(**evaluation_job_freeform_tags)
.with_defined_tag(**(create_aqua_evaluation_details.defined_tags or {}))
)
if (
create_aqua_evaluation_details.memory_in_gbs
Expand Down Expand Up @@ -425,6 +437,7 @@ def create(
evaluation_job_run = evaluation_job.run(
name=evaluation_model.display_name,
freeform_tags=evaluation_job_freeform_tags,
defined_tags=(create_aqua_evaluation_details.defined_tags or {}),
wait=False,
)
logger.debug(
Expand All @@ -444,13 +457,23 @@ def create(
for metadata in evaluation_model_custom_metadata.to_dict()["data"]
]

evaluation_model_freeform_tags = {
Tags.AQUA_EVALUATION: Tags.AQUA_EVALUATION,
}
evaluation_model_freeform_tags = {
**evaluation_model_freeform_tags,
**(create_aqua_evaluation_details.freeform_tags or {}),
}
evaluation_model_defined_tags = (
create_aqua_evaluation_details.defined_tags or {}
)

self.ds_client.update_model(
model_id=evaluation_model.id,
update_model_details=UpdateModelDetails(
custom_metadata_list=updated_custom_metadata_list,
freeform_tags={
Tags.AQUA_EVALUATION: Tags.AQUA_EVALUATION,
},
freeform_tags=evaluation_model_freeform_tags,
defined_tags=evaluation_model_defined_tags,
),
)

Expand Down Expand Up @@ -520,10 +543,14 @@ def create(
),
),
tags={
"aqua_evaluation": Tags.AQUA_EVALUATION,
"evaluation_job_id": evaluation_job.id,
"evaluation_source": create_aqua_evaluation_details.evaluation_source_id,
"evaluation_experiment_id": experiment_model_version_set_id,
**{
"aqua_evaluation": Tags.AQUA_EVALUATION,
"evaluation_job_id": evaluation_job.id,
"evaluation_source": create_aqua_evaluation_details.evaluation_source_id,
"evaluation_experiment_id": experiment_model_version_set_id,
},
**evaluation_model_freeform_tags,
**evaluation_model_defined_tags,
},
parameters=AquaEvalParams(),
)
Expand Down
6 changes: 6 additions & 0 deletions ads/aqua/finetuning/entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,10 @@ class CreateFineTuningDetails(DataClassSerializable):
The log id for fine tuning job infrastructure.
force_overwrite: (bool, optional). Defaults to `False`.
Whether to force overwrite the existing file in object storage.
freeform_tags: (dict, optional)
Freeform tags for the fine-tuning model
defined_tags: (dict, optional)
Defined tags for the fine-tuning model
"""

ft_source_id: str
Expand All @@ -101,3 +105,5 @@ class CreateFineTuningDetails(DataClassSerializable):
log_id: Optional[str] = None
log_group_id: Optional[str] = None
force_overwrite: Optional[bool] = False
freeform_tags: Optional[dict] = None
defined_tags: Optional[dict] = None
86 changes: 52 additions & 34 deletions ads/aqua/finetuning/finetuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,11 @@
ENV_AQUA_FINE_TUNING_CONTAINER,
FineTuneCustomMetadata,
)
from ads.aqua.finetuning.entities import *
from ads.aqua.finetuning.entities import (
AquaFineTuningParams,
AquaFineTuningSummary,
CreateFineTuningDetails,
)
from ads.common.auth import default_signer
from ads.common.object_storage_details import ObjectStorageDetails
from ads.common.utils import get_console_link
Expand Down Expand Up @@ -100,14 +104,14 @@ def create(
if not create_fine_tuning_details:
try:
create_fine_tuning_details = CreateFineTuningDetails(**kwargs)
except:
except Exception as ex:
allowed_create_fine_tuning_details = ", ".join(
field.name for field in fields(CreateFineTuningDetails)
).rstrip()
raise AquaValueError(
"Invalid create fine tuning parameters. Allowable parameters are: "
f"{allowed_create_fine_tuning_details}."
)
) from ex

source = self.get_source(create_fine_tuning_details.ft_source_id)

Expand Down Expand Up @@ -148,28 +152,27 @@ def create(
"Specify the subnet id via API or environment variable AQUA_JOB_SUBNET_ID."
)

if create_fine_tuning_details.replica > DEFAULT_FT_REPLICA:
if not (
create_fine_tuning_details.log_id
and create_fine_tuning_details.log_group_id
):
raise AquaValueError(
f"Logging is required for fine tuning if replica is larger than {DEFAULT_FT_REPLICA}."
)
if create_fine_tuning_details.replica > DEFAULT_FT_REPLICA and not (
create_fine_tuning_details.log_id
and create_fine_tuning_details.log_group_id
):
raise AquaValueError(
f"Logging is required for fine tuning if replica is larger than {DEFAULT_FT_REPLICA}."
)

ft_parameters = None
try:
ft_parameters = AquaFineTuningParams(
**create_fine_tuning_details.ft_parameters,
)
except:
except Exception as ex:
allowed_fine_tuning_parameters = ", ".join(
field.name for field in fields(AquaFineTuningParams)
).rstrip()
raise AquaValueError(
"Invalid fine tuning parameters. Fine tuning parameters should "
f"be a dictionary with keys: {allowed_fine_tuning_parameters}."
)
) from ex

experiment_model_version_set_id = create_fine_tuning_details.experiment_id
experiment_model_version_set_name = create_fine_tuning_details.experiment_name
Expand Down Expand Up @@ -197,11 +200,11 @@ def create(
auth=default_signer(),
force_overwrite=create_fine_tuning_details.force_overwrite,
)
except FileExistsError:
except FileExistsError as fe:
raise AquaFileExistsError(
f"Dataset {dataset_file} already exists in {create_fine_tuning_details.report_path}. "
"Please use a new dataset file name, report path or set `force_overwrite` as True."
)
) from fe
logger.debug(
f"Uploaded local file {ft_dataset_path} to object storage {dst_uri}."
)
Expand All @@ -222,6 +225,8 @@ def create(
description=create_fine_tuning_details.experiment_description,
compartment_id=target_compartment,
project_id=target_project,
freeform_tags=create_fine_tuning_details.freeform_tags,
defined_tags=create_fine_tuning_details.defined_tags,
)

ft_model_custom_metadata = ModelCustomMetadata()
Expand Down Expand Up @@ -273,6 +278,10 @@ def create(
Tags.AQUA_TAG: UNKNOWN,
Tags.AQUA_FINE_TUNED_MODEL_TAG: f"{source.id}#{source.display_name}",
}
ft_job_freeform_tags = {
**ft_job_freeform_tags,
**(create_fine_tuning_details.freeform_tags or {}),
}

ft_job = Job(name=ft_model.display_name).with_infrastructure(
DataScienceJob()
Expand All @@ -286,6 +295,7 @@ def create(
or DEFAULT_FT_BLOCK_STORAGE_SIZE
)
.with_freeform_tag(**ft_job_freeform_tags)
.with_defined_tag(**(create_fine_tuning_details.defined_tags or {}))
)

if not subnet_id:
Expand Down Expand Up @@ -353,6 +363,7 @@ def create(
ft_job_run = ft_job.run(
name=ft_model.display_name,
freeform_tags=ft_job_freeform_tags,
defined_tags=create_fine_tuning_details.defined_tags or {},
wait=False,
)
logger.debug(
Expand All @@ -372,22 +383,25 @@ def create(
for metadata in ft_model_custom_metadata.to_dict()["data"]
]

source_freeform_tags = source.freeform_tags or {}
source_freeform_tags.pop(Tags.LICENSE, None)
source_freeform_tags.update({Tags.READY_TO_FINE_TUNE: "false"})
source_freeform_tags.update({Tags.AQUA_TAG: UNKNOWN})
source_freeform_tags.pop(Tags.BASE_MODEL_CUSTOM, None)
model_freeform_tags = source.freeform_tags or {}
model_freeform_tags.pop(Tags.LICENSE, None)
model_freeform_tags.pop(Tags.BASE_MODEL_CUSTOM, None)

model_freeform_tags = {
**model_freeform_tags,
Tags.READY_TO_FINE_TUNE: "false",
Tags.AQUA_TAG: UNKNOWN,
Tags.AQUA_FINE_TUNED_MODEL_TAG: f"{source.id}#{source.display_name}",
**(create_fine_tuning_details.freeform_tags or {}),
}
model_defined_tags = create_fine_tuning_details.defined_tags or {}

self.update_model(
model_id=ft_model.id,
update_model_details=UpdateModelDetails(
custom_metadata_list=updated_custom_metadata_list,
freeform_tags={
Tags.AQUA_FINE_TUNED_MODEL_TAG: (
f"{source.id}#{source.display_name}"
),
**source_freeform_tags,
},
freeform_tags=model_freeform_tags,
defined_tags=model_defined_tags,
),
)

Expand Down Expand Up @@ -462,12 +476,16 @@ def create(
region=self.region,
),
),
tags=dict(
aqua_finetuning=Tags.AQUA_FINE_TUNING,
finetuning_job_id=ft_job.id,
finetuning_source=source.id,
finetuning_experiment_id=experiment_model_version_set_id,
),
tags={
**{
"aqua_finetuning": Tags.AQUA_FINE_TUNING,
"finetuning_job_id": ft_job.id,
"finetuning_source": source.id,
"finetuning_experiment_id": experiment_model_version_set_id,
},
**model_freeform_tags,
**model_defined_tags,
},
parameters={
key: value
for key, value in asdict(ft_parameters).items()
Expand Down Expand Up @@ -635,6 +653,6 @@ def validate_finetuning_params(self, params: Dict = None) -> Dict:
raise AquaValueError(
f"Invalid fine tuning parameters. Allowable parameters are: "
f"{allowed_fine_tuning_parameters}."
)
) from e

return dict(valid=True)
return {"valid": True}
2 changes: 2 additions & 0 deletions ads/aqua/model/entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,8 @@ class ImportModelDetails(CLIBuilderMixin):
inference_container_uri: Optional[str] = None
allow_patterns: Optional[List[str]] = None
ignore_patterns: Optional[List[str]] = None
freeform_tags: Optional[dict] = None
defined_tags: Optional[dict] = None

def __post_init__(self):
self._command = "model register"
Loading
Loading