11#!/usr/bin/env python
2- # Copyright (c) 2024 Oracle and/or its affiliates.
2+ # Copyright (c) 2024, 2025 Oracle and/or its affiliates.
33# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
4- from dataclasses import dataclass , field
5- from typing import List , Optional
64
7- from ads . aqua . data import AquaJobSummary
8- from ads . common . serializer import DataClassSerializable
5+ import json
6+ from typing import List , Literal , Optional , Union
97
8+ from pydantic import Field , model_validator
109
11- @dataclass (repr = False )
12- class AquaFineTuningParams (DataClassSerializable ):
13- epochs : int
10+ from ads .aqua .common .errors import AquaValueError
11+ from ads .aqua .config .utils .serializer import Serializable
12+ from ads .aqua .data import AquaResourceIdentifier
13+ from ads .aqua .finetuning .constants import FineTuningRestrictedParams
14+
15+
16+ class AquaFineTuningParams (Serializable ):
17+ """Class for maintaining aqua fine-tuning model parameters"""
18+
19+ epochs : Optional [int ] = None
1420 learning_rate : Optional [float ] = None
15- sample_packing : Optional [bool ] = "auto"
21+ sample_packing : Union [bool , None , Literal [ "auto" ] ] = "auto"
1622 batch_size : Optional [int ] = (
1723 None # make it batch_size for user, but internally this is micro_batch_size
1824 )
@@ -22,21 +28,59 @@ class AquaFineTuningParams(DataClassSerializable):
2228 lora_alpha : Optional [int ] = None
2329 lora_dropout : Optional [float ] = None
2430 lora_target_linear : Optional [bool ] = None
25- lora_target_modules : Optional [List ] = None
31+ lora_target_modules : Optional [List [ str ] ] = None
2632 early_stopping_patience : Optional [int ] = None
2733 early_stopping_threshold : Optional [float ] = None
2834
35+ class Config :
36+ extra = "allow"
37+
38+ def to_dict (self ) -> dict :
39+ return json .loads (super ().to_json (exclude_none = True ))
40+
41+ @model_validator (mode = "before" )
42+ @classmethod
43+ def validate_restricted_fields (cls , data : dict ):
44+ # we may want to skip validation if loading data from config files instead of user entered parameters
45+ validate = data .pop ("_validate" , True )
46+ if not (validate and isinstance (data , dict )):
47+ return data
48+ restricted_params = [
49+ param for param in data if param in FineTuningRestrictedParams .values ()
50+ ]
51+ if restricted_params :
52+ raise AquaValueError (
53+ f"Found restricted parameter name: { restricted_params } "
54+ )
55+ return data
2956
30- @dataclass (repr = False )
31- class AquaFineTuningSummary (AquaJobSummary , DataClassSerializable ):
32- parameters : AquaFineTuningParams = field (default_factory = AquaFineTuningParams )
3357
58+ class AquaFineTuningSummary (Serializable ):
59+ """Represents a summary of Aqua Finetuning job."""
3460
35- @dataclass (repr = False )
36- class CreateFineTuningDetails (DataClassSerializable ):
37- """Dataclass to create aqua model fine tuning.
61+ id : str
62+ name : str
63+ console_url : str
64+ lifecycle_state : str
65+ lifecycle_details : str
66+ time_created : str
67+ tags : dict
68+ experiment : AquaResourceIdentifier = Field (default_factory = AquaResourceIdentifier )
69+ source : AquaResourceIdentifier = Field (default_factory = AquaResourceIdentifier )
70+ job : AquaResourceIdentifier = Field (default_factory = AquaResourceIdentifier )
71+ parameters : AquaFineTuningParams = Field (default_factory = AquaFineTuningParams )
3872
39- Fields
73+ class Config :
74+ extra = "ignore"
75+
76+ def to_dict (self ) -> dict :
77+ return json .loads (super ().to_json (exclude_none = True ))
78+
79+
80+ class CreateFineTuningDetails (Serializable ):
81+ """Class to create aqua model fine-tuning instance.
82+
83+ Properties
4084 ------
4185 ft_source_id: str
4286 The fine tuning source id. Must be model ocid.
@@ -107,3 +151,6 @@ class CreateFineTuningDetails(DataClassSerializable):
107151 force_overwrite : Optional [bool ] = False
108152 freeform_tags : Optional [dict ] = None
109153 defined_tags : Optional [dict ] = None
154+
155+ class Config :
156+ extra = "ignore"
0 commit comments