Commit e798ea8

update finetuning module
1 parent e58d514 commit e798ea8

File tree

1 file changed: +34 -52 lines changed

ads/aqua/finetuning/finetuning.py

Lines changed: 34 additions & 52 deletions
@@ -1,17 +1,17 @@
 #!/usr/bin/env python
-# Copyright (c) 2024 Oracle and/or its affiliates.
+# Copyright (c) 2024, 2025 Oracle and/or its affiliates.
 # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/

 import json
 import os
-from dataclasses import MISSING, asdict, fields
 from typing import Dict

 from oci.data_science.models import (
     Metadata,
     UpdateModelDetails,
     UpdateModelProvenanceDetails,
 )
+from pydantic import ValidationError

 from ads.aqua import logger
 from ads.aqua.app import AquaApp
@@ -104,24 +104,16 @@ def create(
         if not create_fine_tuning_details:
             try:
                 create_fine_tuning_details = CreateFineTuningDetails(**kwargs)
-            except Exception as ex:
-                allowed_create_fine_tuning_details = ", ".join(
-                    field.name for field in fields(CreateFineTuningDetails)
-                ).rstrip()
+            except ValidationError as ex:
+                custom_errors = {
+                    ".".join(map(str, e["loc"])): e["msg"] for e in ex.errors()
+                }
                 raise AquaValueError(
-                    "Invalid create fine tuning parameters. Allowable parameters are: "
-                    f"{allowed_create_fine_tuning_details}."
+                    f"Invalid parameters for creating a fine-tuned model. Error details: {custom_errors}."
                 ) from ex

         source = self.get_source(create_fine_tuning_details.ft_source_id)

-        # todo: revisit validation for fine tuned models
-        # if source.compartment_id != ODSC_MODEL_COMPARTMENT_OCID:
-        #     raise AquaValueError(
-        #         f"Fine tuning is only supported for Aqua service models in {ODSC_MODEL_COMPARTMENT_OCID}. "
-        #         "Use a valid Aqua service model id instead."
-        #     )
-
         target_compartment = (
             create_fine_tuning_details.compartment_id or COMPARTMENT_OCID
         )
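How the custom_errors mapping above is built: pydantic's ValidationError.errors() returns one dict per failure, each carrying a "loc" tuple (the field path) and a "msg" string, and the comprehension joins the path into a dotted key. A minimal, self-contained sketch, with ExampleDetails as a hypothetical stand-in for CreateFineTuningDetails:

from pydantic import BaseModel, ValidationError

class ExampleDetails(BaseModel):
    ft_source_id: str   # required field
    replica: int = 1    # optional field with a default

try:
    ExampleDetails(replica="not-a-number")  # missing ft_source_id, non-integer replica
except ValidationError as ex:
    # Each entry of ex.errors() has "loc" (tuple of path segments) and "msg".
    custom_errors = {".".join(map(str, e["loc"])): e["msg"] for e in ex.errors()}
    print(custom_errors)
    # e.g. {'ft_source_id': 'Field required', 'replica': 'Input should be a valid integer, ...'}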
@@ -160,19 +152,9 @@ def create(
                 f"Logging is required for fine tuning if replica is larger than {DEFAULT_FT_REPLICA}."
             )

-        ft_parameters = None
-        try:
-            ft_parameters = AquaFineTuningParams(
-                **create_fine_tuning_details.ft_parameters,
-            )
-        except Exception as ex:
-            allowed_fine_tuning_parameters = ", ".join(
-                field.name for field in fields(AquaFineTuningParams)
-            ).rstrip()
-            raise AquaValueError(
-                "Invalid fine tuning parameters. Fine tuning parameters should "
-                f"be a dictionary with keys: {allowed_fine_tuning_parameters}."
-            ) from ex
+        ft_parameters = self._validate_finetuning_params(
+            create_fine_tuning_details.ft_parameters
+        )

         experiment_model_version_set_id = create_fine_tuning_details.experiment_id
         experiment_model_version_set_name = create_fine_tuning_details.experiment_name
@@ -481,11 +463,7 @@ def create(
                 **model_freeform_tags,
                 **model_defined_tags,
             },
-            parameters={
-                key: value
-                for key, value in asdict(ft_parameters).items()
-                if value is not None
-            },
+            parameters=ft_parameters,
         )

     def _build_fine_tuning_runtime(
@@ -548,7 +526,7 @@ def _build_oci_launch_cmd(
     ) -> str:
         """Builds the oci launch cmd for fine tuning container runtime."""
         oci_launch_cmd = f"--training_data {dataset_path} --output_dir {report_path} --val_set_size {val_set_size} "
-        for key, value in asdict(parameters).items():
+        for key, value in parameters.to_dict().items():
             if value is not None:
                 if key == "batch_size":
                     oci_launch_cmd += f"--micro_{key} {value} "
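For the parameters.to_dict() change above: each non-None fine-tuning parameter becomes a CLI flag, with batch_size emitted as --micro_batch_size. A simplified sketch with hypothetical values; the fallback for keys other than batch_size is assumed here, since only that branch is visible in the hunk:

params = {"epochs": 3, "learning_rate": 2e-5, "batch_size": 1, "lora_r": None}
oci_launch_cmd = "--training_data <dataset_path> --output_dir <report_path> --val_set_size 0.1 "
for key, value in params.items():
    if value is not None:
        if key == "batch_size":
            oci_launch_cmd += f"--micro_{key} {value} "  # batch_size -> --micro_batch_size
        else:
            oci_launch_cmd += f"--{key} {value} "        # assumed handling of other keys
# lora_r is skipped because its value is None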
@@ -613,15 +591,33 @@ def get_finetuning_default_params(self, model_id: str) -> Dict:
         default_params = {"params": {}}
         finetuning_config = self.get_finetuning_config(model_id)
         config_parameters = finetuning_config.get("configuration", UNKNOWN_DICT)
-        dataclass_fields = {field.name for field in fields(AquaFineTuningParams)}
+        config_parameters["_validate"] = False
+        dataclass_fields = AquaFineTuningParams(**config_parameters).to_dict()
         for name, value in config_parameters.items():
-            if name == "micro_batch_size":
-                name = "batch_size"
             if name in dataclass_fields:
+                if name == "micro_batch_size":
+                    name = "batch_size"
                 default_params["params"][name] = value

         return default_params

+    @staticmethod
+    def _validate_finetuning_params(params: Dict = None) -> AquaFineTuningParams:
+        try:
+            finetuning_params = AquaFineTuningParams(**params)
+        except ValidationError as ex:
+            # combine both loc and msg for errors where loc (field) is present in error details, else only build error
+            # message using msg field. Added to handle error messages from pydantic model validator handler.
+            custom_errors = {
+                ".".join(map(str, e["loc"])): e["msg"]
+                for e in ex.errors()
+                if "loc" in e and e["loc"]
+            } or "; ".join(e["msg"] for e in ex.errors())
+            raise AquaValueError(
+                f"Invalid finetuning parameters. Error details: {custom_errors}."
+            ) from ex
+        return finetuning_params
+
     def validate_finetuning_params(self, params: Dict = None) -> Dict:
         """Validate if the fine-tuning parameters passed by the user can be overridden. Parameter values are not
         validated, only param keys are validated.
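The new _validate_finetuning_params helper added above falls back to a plain joined message when pydantic reports errors without a field location, as model-level validators do. A standalone sketch of just that formatting logic, using hand-built dicts shaped like ex.errors() output; format_errors is a hypothetical helper, not part of the module:

def format_errors(errors):
    # Prefer a {field: message} mapping; when no error carries a non-empty "loc",
    # the dict is empty (falsy) and the joined messages are returned instead.
    return {
        ".".join(map(str, e["loc"])): e["msg"]
        for e in errors
        if "loc" in e and e["loc"]
    } or "; ".join(e["msg"] for e in errors)

print(format_errors([{"loc": ("epochs",), "msg": "Input should be a valid integer"}]))
# -> {'epochs': 'Input should be a valid integer'}
print(format_errors([{"loc": (), "msg": "Value error, unrecognized parameter"}]))
# -> Value error, unrecognized parameter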
@@ -635,19 +631,5 @@ def validate_finetuning_params(self, params: Dict = None) -> Dict:
         -------
         Return a list of restricted params.
         """
-        try:
-            AquaFineTuningParams(
-                **params,
-            )
-        except Exception as e:
-            logger.debug(str(e))
-            allowed_fine_tuning_parameters = ", ".join(
-                f"{field.name} (required)" if field.default is MISSING else field.name
-                for field in fields(AquaFineTuningParams)
-            ).rstrip()
-            raise AquaValueError(
-                f"Invalid fine tuning parameters. Allowable parameters are: "
-                f"{allowed_fine_tuning_parameters}."
-            ) from e
-
+        self._validate_finetuning_params(params or {})
         return {"valid": True}
