3535 ENV_AQUA_FINE_TUNING_CONTAINER ,
3636 FineTuneCustomMetadata ,
3737)
38- from ads .aqua .finetuning .entities import *
38+ from ads .aqua .finetuning .entities import (
39+ AquaFineTuningParams ,
40+ AquaFineTuningSummary ,
41+ CreateFineTuningDetails ,
42+ )
3943from ads .common .auth import default_signer
4044from ads .common .object_storage_details import ObjectStorageDetails
4145from ads .common .utils import get_console_link
@@ -100,14 +104,14 @@ def create(
100104 if not create_fine_tuning_details :
101105 try :
102106 create_fine_tuning_details = CreateFineTuningDetails (** kwargs )
103- except :
107+ except Exception as ex :
104108 allowed_create_fine_tuning_details = ", " .join (
105109 field .name for field in fields (CreateFineTuningDetails )
106110 ).rstrip ()
107111 raise AquaValueError (
108112 "Invalid create fine tuning parameters. Allowable parameters are: "
109113 f"{ allowed_create_fine_tuning_details } ."
110- )
114+ ) from ex
111115
112116 source = self .get_source (create_fine_tuning_details .ft_source_id )
113117
@@ -148,28 +152,27 @@ def create(
148152 "Specify the subnet id via API or environment variable AQUA_JOB_SUBNET_ID."
149153 )
150154
151- if create_fine_tuning_details .replica > DEFAULT_FT_REPLICA :
152- if not (
153- create_fine_tuning_details .log_id
154- and create_fine_tuning_details .log_group_id
155- ):
156- raise AquaValueError (
157- f"Logging is required for fine tuning if replica is larger than { DEFAULT_FT_REPLICA } ."
158- )
155+ if create_fine_tuning_details .replica > DEFAULT_FT_REPLICA and not (
156+ create_fine_tuning_details .log_id
157+ and create_fine_tuning_details .log_group_id
158+ ):
159+ raise AquaValueError (
160+ f"Logging is required for fine tuning if replica is larger than { DEFAULT_FT_REPLICA } ."
161+ )
159162
160163 ft_parameters = None
161164 try :
162165 ft_parameters = AquaFineTuningParams (
163166 ** create_fine_tuning_details .ft_parameters ,
164167 )
165- except :
168+ except Exception as ex :
166169 allowed_fine_tuning_parameters = ", " .join (
167170 field .name for field in fields (AquaFineTuningParams )
168171 ).rstrip ()
169172 raise AquaValueError (
170173 "Invalid fine tuning parameters. Fine tuning parameters should "
171174 f"be a dictionary with keys: { allowed_fine_tuning_parameters } ."
172- )
175+ ) from ex
173176
174177 experiment_model_version_set_id = create_fine_tuning_details .experiment_id
175178 experiment_model_version_set_name = create_fine_tuning_details .experiment_name
@@ -197,11 +200,11 @@ def create(
197200 auth = default_signer (),
198201 force_overwrite = create_fine_tuning_details .force_overwrite ,
199202 )
200- except FileExistsError :
203+ except FileExistsError as fe :
201204 raise AquaFileExistsError (
202205 f"Dataset { dataset_file } already exists in { create_fine_tuning_details .report_path } . "
203206 "Please use a new dataset file name, report path or set `force_overwrite` as True."
204- )
207+ ) from fe
205208 logger .debug (
206209 f"Uploaded local file { ft_dataset_path } to object storage { dst_uri } ."
207210 )
@@ -222,6 +225,8 @@ def create(
222225 description = create_fine_tuning_details .experiment_description ,
223226 compartment_id = target_compartment ,
224227 project_id = target_project ,
228+ freeform_tags = create_fine_tuning_details .freeform_tags ,
229+ defined_tags = create_fine_tuning_details .defined_tags ,
225230 )
226231
227232 ft_model_custom_metadata = ModelCustomMetadata ()
@@ -273,6 +278,10 @@ def create(
273278 Tags .AQUA_TAG : UNKNOWN ,
274279 Tags .AQUA_FINE_TUNED_MODEL_TAG : f"{ source .id } #{ source .display_name } " ,
275280 }
281+ ft_job_freeform_tags = {
282+ ** ft_job_freeform_tags ,
283+ ** (create_fine_tuning_details .freeform_tags or {}),
284+ }
276285
277286 ft_job = Job (name = ft_model .display_name ).with_infrastructure (
278287 DataScienceJob ()
@@ -286,6 +295,7 @@ def create(
286295 or DEFAULT_FT_BLOCK_STORAGE_SIZE
287296 )
288297 .with_freeform_tag (** ft_job_freeform_tags )
298+ .with_defined_tag (** (create_fine_tuning_details .defined_tags or {}))
289299 )
290300
291301 if not subnet_id :
@@ -353,6 +363,7 @@ def create(
353363 ft_job_run = ft_job .run (
354364 name = ft_model .display_name ,
355365 freeform_tags = ft_job_freeform_tags ,
366+ defined_tags = create_fine_tuning_details .defined_tags or {},
356367 wait = False ,
357368 )
358369 logger .debug (
@@ -372,22 +383,25 @@ def create(
372383 for metadata in ft_model_custom_metadata .to_dict ()["data" ]
373384 ]
374385
375- source_freeform_tags = source .freeform_tags or {}
376- source_freeform_tags .pop (Tags .LICENSE , None )
377- source_freeform_tags .update ({Tags .READY_TO_FINE_TUNE : "false" })
378- source_freeform_tags .update ({Tags .AQUA_TAG : UNKNOWN })
379- source_freeform_tags .pop (Tags .BASE_MODEL_CUSTOM , None )
386+ model_freeform_tags = source .freeform_tags or {}
387+ model_freeform_tags .pop (Tags .LICENSE , None )
388+ model_freeform_tags .pop (Tags .BASE_MODEL_CUSTOM , None )
389+
390+ model_freeform_tags = {
391+ ** model_freeform_tags ,
392+ Tags .READY_TO_FINE_TUNE : "false" ,
393+ Tags .AQUA_TAG : UNKNOWN ,
394+ Tags .AQUA_FINE_TUNED_MODEL_TAG : f"{ source .id } #{ source .display_name } " ,
395+ ** (create_fine_tuning_details .freeform_tags or {}),
396+ }
397+ model_defined_tags = create_fine_tuning_details .defined_tags or {}
380398
381399 self .update_model (
382400 model_id = ft_model .id ,
383401 update_model_details = UpdateModelDetails (
384402 custom_metadata_list = updated_custom_metadata_list ,
385- freeform_tags = {
386- Tags .AQUA_FINE_TUNED_MODEL_TAG : (
387- f"{ source .id } #{ source .display_name } "
388- ),
389- ** source_freeform_tags ,
390- },
403+ freeform_tags = model_freeform_tags ,
404+ defined_tags = model_defined_tags ,
391405 ),
392406 )
393407
@@ -462,12 +476,16 @@ def create(
462476 region = self .region ,
463477 ),
464478 ),
465- tags = dict (
466- aqua_finetuning = Tags .AQUA_FINE_TUNING ,
467- finetuning_job_id = ft_job .id ,
468- finetuning_source = source .id ,
469- finetuning_experiment_id = experiment_model_version_set_id ,
470- ),
479+ tags = {
480+ ** {
481+ "aqua_finetuning" : Tags .AQUA_FINE_TUNING ,
482+ "finetuning_job_id" : ft_job .id ,
483+ "finetuning_source" : source .id ,
484+ "finetuning_experiment_id" : experiment_model_version_set_id ,
485+ },
486+ ** model_freeform_tags ,
487+ ** model_defined_tags ,
488+ },
471489 parameters = {
472490 key : value
473491 for key , value in asdict (ft_parameters ).items ()
@@ -635,6 +653,6 @@ def validate_finetuning_params(self, params: Dict = None) -> Dict:
635653 raise AquaValueError (
636654 f"Invalid fine tuning parameters. Allowable parameters are: "
637655 f"{ allowed_fine_tuning_parameters } ."
638- )
656+ ) from e
639657
640- return dict ( valid = True )
658+ return { " valid" : True }
0 commit comments