Skip to content

Commit e58d514

Browse files
update dataclasses
1 parent b643faa commit e58d514

File tree

3 files changed

+69
-36
lines changed

3 files changed

+69
-36
lines changed

ads/aqua/data.py

Lines changed: 2 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,8 @@
11
#!/usr/bin/env python
2-
# -*- coding: utf-8 -*-
3-
# Copyright (c) 2024 Oracle and/or its affiliates.
2+
# Copyright (c) 2024, 2025 Oracle and/or its affiliates.
43
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
54

6-
from dataclasses import dataclass, field
5+
from dataclasses import dataclass
76

87
from ads.common.serializer import DataClassSerializable
98

@@ -13,19 +12,3 @@ class AquaResourceIdentifier(DataClassSerializable):
1312
id: str = ""
1413
name: str = ""
1514
url: str = ""
16-
17-
18-
@dataclass(repr=False)
19-
class AquaJobSummary(DataClassSerializable):
20-
"""Represents an Aqua job summary."""
21-
22-
id: str
23-
name: str
24-
console_url: str
25-
lifecycle_state: str
26-
lifecycle_details: str
27-
time_created: str
28-
tags: dict
29-
experiment: AquaResourceIdentifier = field(default_factory=AquaResourceIdentifier)
30-
source: AquaResourceIdentifier = field(default_factory=AquaResourceIdentifier)
31-
job: AquaResourceIdentifier = field(default_factory=AquaResourceIdentifier)

ads/aqua/finetuning/constants.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
#!/usr/bin/env python
2-
# -*- coding: utf-8 -*-
3-
# Copyright (c) 2024 Oracle and/or its affiliates.
2+
# Copyright (c) 2024, 2025 Oracle and/or its affiliates.
43
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
54

65
from ads.common.extended_enum import ExtendedEnumMeta
@@ -17,4 +16,8 @@ class FineTuneCustomMetadata(str, metaclass=ExtendedEnumMeta):
1716
SERVICE_MODEL_FINE_TUNE_CONTAINER = "finetune-container"
1817

1918

19+
class FineTuningForbiddenParams(str, metaclass=ExtendedEnumMeta):
20+
OPTIMIZER = "optimizer"
21+
22+
2023
ENV_AQUA_FINE_TUNING_CONTAINER = "AQUA_FINE_TUNING_CONTAINER"

ads/aqua/finetuning/entities.py

Lines changed: 62 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,22 @@
11
#!/usr/bin/env python
2-
# Copyright (c) 2024 Oracle and/or its affiliates.
2+
# Copyright (c) 2024, 2025 Oracle and/or its affiliates.
33
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
4-
from dataclasses import dataclass, field
4+
5+
import json
56
from typing import List, Optional
67

7-
from ads.aqua.data import AquaJobSummary
8-
from ads.common.serializer import DataClassSerializable
8+
from pydantic import Field, model_validator
9+
10+
from ads.aqua.common.errors import AquaValueError
11+
from ads.aqua.config.utils.serializer import Serializable
12+
from ads.aqua.data import AquaResourceIdentifier
13+
from ads.aqua.finetuning.constants import FineTuningForbiddenParams
14+
915

16+
class AquaFineTuningParams(Serializable):
17+
"""Class for maintaining aqua fine-tuning model parameters"""
1018

11-
@dataclass(repr=False)
12-
class AquaFineTuningParams(DataClassSerializable):
13-
epochs: int
19+
epochs: Optional[int] = None
1420
learning_rate: Optional[float] = None
1521
sample_packing: Optional[bool] = "auto"
1622
batch_size: Optional[int] = (
@@ -22,21 +28,59 @@ class AquaFineTuningParams(DataClassSerializable):
2228
lora_alpha: Optional[int] = None
2329
lora_dropout: Optional[float] = None
2430
lora_target_linear: Optional[bool] = None
25-
lora_target_modules: Optional[List] = None
31+
lora_target_modules: Optional[List[str]] = None
2632
early_stopping_patience: Optional[int] = None
2733
early_stopping_threshold: Optional[float] = None
34+
load_best_model_at_end: Optional[bool] = None
35+
metric_for_best_model: Optional[str] = None
36+
37+
class Config:
38+
extra = "allow"
39+
40+
def to_dict(self) -> dict:
41+
return json.loads(super().to_json(exclude_none=True))
2842

43+
@model_validator(mode="before")
44+
@classmethod
45+
def validate_forbidden_fields(cls, data: dict):
46+
# we may want to skip validation if loading data from config files instead of user entered parameters
47+
validate = data.pop("_validate", True)
48+
if not (validate and isinstance(data, dict)):
49+
return data
50+
forbidden_params = [
51+
param for param in data if param in FineTuningForbiddenParams.values()
52+
]
53+
if forbidden_params:
54+
raise AquaValueError(f"Found restricted parameter name: {forbidden_params}")
55+
return data
2956

30-
@dataclass(repr=False)
31-
class AquaFineTuningSummary(AquaJobSummary, DataClassSerializable):
32-
parameters: AquaFineTuningParams = field(default_factory=AquaFineTuningParams)
3357

58+
class AquaFineTuningSummary(Serializable):
59+
"""Represents a summary of Aqua Finetuning job."""
3460

35-
@dataclass(repr=False)
36-
class CreateFineTuningDetails(DataClassSerializable):
37-
"""Dataclass to create aqua model fine tuning.
61+
id: str
62+
name: str
63+
console_url: str
64+
lifecycle_state: str
65+
lifecycle_details: str
66+
time_created: str
67+
tags: dict
68+
experiment: AquaResourceIdentifier = Field(default_factory=AquaResourceIdentifier)
69+
source: AquaResourceIdentifier = Field(default_factory=AquaResourceIdentifier)
70+
job: AquaResourceIdentifier = Field(default_factory=AquaResourceIdentifier)
71+
parameters: AquaFineTuningParams = Field(default_factory=AquaFineTuningParams)
3872

39-
Fields
73+
class Config:
74+
extra = "ignore"
75+
76+
def to_dict(self) -> dict:
77+
return json.loads(super().to_json(exclude_none=True))
78+
79+
80+
class CreateFineTuningDetails(Serializable):
81+
"""Class to create aqua model fine-tuning instance.
82+
83+
Properties
4084
------
4185
ft_source_id: str
4286
The fine tuning source id. Must be model ocid.
@@ -107,3 +151,6 @@ class CreateFineTuningDetails(DataClassSerializable):
107151
force_overwrite: Optional[bool] = False
108152
freeform_tags: Optional[dict] = None
109153
defined_tags: Optional[dict] = None
154+
155+
class Config:
156+
extra = "ignore"

0 commit comments

Comments
 (0)