11#!/usr/bin/env python
2- # Copyright (c) 2024 Oracle and/or its affiliates.
2+ # Copyright (c) 2024, 2025 Oracle and/or its affiliates.
33# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
44
55import json
66import os
7- from dataclasses import MISSING , asdict , fields
87from typing import Dict
98
109from oci .data_science .models import (
1110 Metadata ,
1211 UpdateModelDetails ,
1312 UpdateModelProvenanceDetails ,
1413)
14+ from pydantic import ValidationError
1515
1616from ads .aqua import logger
1717from ads .aqua .app import AquaApp
@@ -104,24 +104,16 @@ def create(
104104 if not create_fine_tuning_details :
105105 try :
106106 create_fine_tuning_details = CreateFineTuningDetails (** kwargs )
107- except Exception as ex :
108- allowed_create_fine_tuning_details = ", " . join (
109- field . name for field in fields ( CreateFineTuningDetails )
110- ). rstrip ()
107+ except ValidationError as ex :
108+ custom_errors = {
109+ "." . join ( map ( str , e [ "loc" ])): e [ "msg" ] for e in ex . errors ( )
110+ }
111111 raise AquaValueError (
112- "Invalid create fine tuning parameters. Allowable parameters are: "
113- f"{ allowed_create_fine_tuning_details } ."
112+ f"Invalid parameters for creating a fine-tuned model. Error details: { custom_errors } ."
114113 ) from ex
115114
116115 source = self .get_source (create_fine_tuning_details .ft_source_id )
117116
118- # todo: revisit validation for fine tuned models
119- # if source.compartment_id != ODSC_MODEL_COMPARTMENT_OCID:
120- # raise AquaValueError(
121- # f"Fine tuning is only supported for Aqua service models in {ODSC_MODEL_COMPARTMENT_OCID}. "
122- # "Use a valid Aqua service model id instead."
123- # )
124-
125117 target_compartment = (
126118 create_fine_tuning_details .compartment_id or COMPARTMENT_OCID
127119 )
@@ -160,19 +152,9 @@ def create(
160152 f"Logging is required for fine tuning if replica is larger than { DEFAULT_FT_REPLICA } ."
161153 )
162154
163- ft_parameters = None
164- try :
165- ft_parameters = AquaFineTuningParams (
166- ** create_fine_tuning_details .ft_parameters ,
167- )
168- except Exception as ex :
169- allowed_fine_tuning_parameters = ", " .join (
170- field .name for field in fields (AquaFineTuningParams )
171- ).rstrip ()
172- raise AquaValueError (
173- "Invalid fine tuning parameters. Fine tuning parameters should "
174- f"be a dictionary with keys: { allowed_fine_tuning_parameters } ."
175- ) from ex
155+ ft_parameters = self ._validate_finetuning_params (
156+ create_fine_tuning_details .ft_parameters
157+ )
176158
177159 experiment_model_version_set_id = create_fine_tuning_details .experiment_id
178160 experiment_model_version_set_name = create_fine_tuning_details .experiment_name
@@ -481,11 +463,7 @@ def create(
481463 ** model_freeform_tags ,
482464 ** model_defined_tags ,
483465 },
484- parameters = {
485- key : value
486- for key , value in asdict (ft_parameters ).items ()
487- if value is not None
488- },
466+ parameters = ft_parameters ,
489467 )
490468
491469 def _build_fine_tuning_runtime (
@@ -548,7 +526,7 @@ def _build_oci_launch_cmd(
548526 ) -> str :
549527 """Builds the oci launch cmd for fine tuning container runtime."""
550528 oci_launch_cmd = f"--training_data { dataset_path } --output_dir { report_path } --val_set_size { val_set_size } "
551- for key , value in asdict ( parameters ).items ():
529+ for key , value in parameters . to_dict ( ).items ():
552530 if value is not None :
553531 if key == "batch_size" :
554532 oci_launch_cmd += f"--micro_{ key } { value } "
@@ -613,15 +591,33 @@ def get_finetuning_default_params(self, model_id: str) -> Dict:
613591 default_params = {"params" : {}}
614592 finetuning_config = self .get_finetuning_config (model_id )
615593 config_parameters = finetuning_config .get ("configuration" , UNKNOWN_DICT )
616- dataclass_fields = {field .name for field in fields (AquaFineTuningParams )}
594+ config_parameters ["_validate" ] = False
595+ dataclass_fields = AquaFineTuningParams (** config_parameters ).to_dict ()
617596 for name , value in config_parameters .items ():
618- if name == "micro_batch_size" :
619- name = "batch_size"
620597 if name in dataclass_fields :
598+ if name == "micro_batch_size" :
599+ name = "batch_size"
621600 default_params ["params" ][name ] = value
622601
623602 return default_params
624603
604+ @staticmethod
605+ def _validate_finetuning_params (params : Dict = None ) -> AquaFineTuningParams :
606+ try :
607+ finetuning_params = AquaFineTuningParams (** params )
608+ except ValidationError as ex :
609+ # combine both loc and msg for errors where loc (field) is present in error details, else only build error
610+ # message using msg field. Added to handle error messages from pydantic model validator handler.
611+ custom_errors = {
612+ "." .join (map (str , e ["loc" ])): e ["msg" ]
613+ for e in ex .errors ()
614+ if "loc" in e and e ["loc" ]
615+ } or "; " .join (e ["msg" ] for e in ex .errors ())
616+ raise AquaValueError (
617+ f"Invalid finetuning parameters. Error details: { custom_errors } ."
618+ ) from ex
619+ return finetuning_params
620+
625621 def validate_finetuning_params (self , params : Dict = None ) -> Dict :
626622 """Validate if the fine-tuning parameters passed by the user can be overridden. Parameter values are not
627623 validated, only param keys are validated.
@@ -635,19 +631,5 @@ def validate_finetuning_params(self, params: Dict = None) -> Dict:
635631 -------
636632 Return a list of restricted params.
637633 """
638- try :
639- AquaFineTuningParams (
640- ** params ,
641- )
642- except Exception as e :
643- logger .debug (str (e ))
644- allowed_fine_tuning_parameters = ", " .join (
645- f"{ field .name } (required)" if field .default is MISSING else field .name
646- for field in fields (AquaFineTuningParams )
647- ).rstrip ()
648- raise AquaValueError (
649- f"Invalid fine tuning parameters. Allowable parameters are: "
650- f"{ allowed_fine_tuning_parameters } ."
651- ) from e
652-
634+ self ._validate_finetuning_params (params or {})
653635 return {"valid" : True }
0 commit comments