Skip to content

Commit 86ef7ad

Browse files
update logging for finetuning operations
1 parent 68f325a commit 86ef7ad

File tree

3 files changed

+71
-48
lines changed

3 files changed

+71
-48
lines changed

ads/aqua/common/utils.py

Lines changed: 48 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
#!/usr/bin/env python
2-
# Copyright (c) 2024 Oracle and/or its affiliates.
2+
# Copyright (c) 2024, 2025 Oracle and/or its affiliates.
33
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
44
"""AQUA utils and constants."""
55

@@ -12,11 +12,12 @@
1212
import re
1313
import shlex
1414
import subprocess
15+
from dataclasses import fields
1516
from datetime import datetime, timedelta
1617
from functools import wraps
1718
from pathlib import Path
1819
from string import Template
19-
from typing import List, Union
20+
from typing import Any, List, Optional, Type, TypeVar, Union
2021

2122
import fsspec
2223
import oci
@@ -74,6 +75,7 @@
7475
from ads.model import DataScienceModel, ModelVersionSet
7576

7677
logger = logging.getLogger("ads.aqua")
78+
T = TypeVar("T")
7779

7880

7981
class LifecycleStatus(str, metaclass=ExtendedEnumMeta):
@@ -788,7 +790,9 @@ def get_ocid_substring(ocid: str, key_len: int) -> str:
788790
return ocid[-key_len:] if ocid and len(ocid) > key_len else ""
789791

790792

791-
def upload_folder(os_path: str, local_dir: str, model_name: str, exclude_pattern: str = None) -> str:
793+
def upload_folder(
794+
os_path: str, local_dir: str, model_name: str, exclude_pattern: str = None
795+
) -> str:
792796
"""Upload the local folder to the object storage
793797
794798
Args:
@@ -1159,3 +1163,44 @@ def validate_cmd_var(cmd_var: List[str], overrides: List[str]) -> List[str]:
11591163

11601164
combined_cmd_var = cmd_var + overrides
11611165
return combined_cmd_var
1166+
1167+
1168+
def validate_dataclass_params(dataclass_type: Type[T], **kwargs: Any) -> Optional[T]:
1169+
"""This method tries to initialize a dataclass with the provided keyword arguments. It handles
1170+
errors related to missing, unexpected or invalid arguments.
1171+
1172+
Parameters
1173+
----------
1174+
dataclass_type (Type[T]):
1175+
the dataclass type to instantiate.
1176+
kwargs (Any):
1177+
the keyword arguments to initialize the dataclass
1178+
Returns
1179+
-------
1180+
Optional[T]
1181+
instance of dataclass if successfully initialized
1182+
"""
1183+
1184+
try:
1185+
return dataclass_type(**kwargs)
1186+
except TypeError as ex:
1187+
error_message = str(ex)
1188+
allowed_params = ", ".join(
1189+
field.name for field in fields(dataclass_type)
1190+
).rstrip()
1191+
if "__init__() missing" in error_message:
1192+
missing_params = error_message.split("missing ")[1]
1193+
raise AquaValueError(
1194+
"Error: Missing required parameters: "
1195+
f"{missing_params}. Allowable parameters are: {allowed_params}."
1196+
) from ex
1197+
elif "__init__() got an unexpected keyword argument" in error_message:
1198+
unexpected_param = error_message.split("argument '")[1].rstrip("'")
1199+
raise AquaValueError(
1200+
"Error: Unexpected parameter: "
1201+
f"{unexpected_param}. Allowable parameters are: {allowed_params}."
1202+
) from ex
1203+
else:
1204+
raise AquaValueError(
1205+
"Invalid parameters. Allowable parameters are: " f"{allowed_params}."
1206+
) from ex

ads/aqua/extension/finetune_handler.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
#!/usr/bin/env python
2-
# Copyright (c) 2024 Oracle and/or its affiliates.
2+
# Copyright (c) 2024, 2025 Oracle and/or its affiliates.
33
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
44

55

@@ -33,7 +33,7 @@ def get(self, id=""):
3333
raise HTTPError(400, f"The request {self.request.path} is invalid.")
3434

3535
@handle_exceptions
36-
def post(self, *args, **kwargs):
36+
def post(self, *args, **kwargs): # noqa: ARG002
3737
"""Handles post request for the fine-tuning API
3838
3939
Raises
@@ -43,8 +43,8 @@ def post(self, *args, **kwargs):
4343
"""
4444
try:
4545
input_data = self.get_json_body()
46-
except Exception:
47-
raise HTTPError(400, Errors.INVALID_INPUT_DATA_FORMAT)
46+
except Exception as ex:
47+
raise HTTPError(400, Errors.INVALID_INPUT_DATA_FORMAT) from ex
4848

4949
if not input_data:
5050
raise HTTPError(400, Errors.NO_INPUT_DATA)
@@ -71,7 +71,7 @@ def get(self, model_id):
7171
)
7272

7373
@handle_exceptions
74-
def post(self, *args, **kwargs):
74+
def post(self, *args, **kwargs): # noqa: ARG002
7575
"""Handles post request for the finetuning param handler API.
7676
7777
Raises
@@ -81,15 +81,15 @@ def post(self, *args, **kwargs):
8181
"""
8282
try:
8383
input_data = self.get_json_body()
84-
except Exception:
85-
raise HTTPError(400, Errors.INVALID_INPUT_DATA_FORMAT)
84+
except Exception as ex:
85+
raise HTTPError(400, Errors.INVALID_INPUT_DATA_FORMAT) from ex
8686

8787
if not input_data:
8888
raise HTTPError(400, Errors.NO_INPUT_DATA)
8989

9090
params = input_data.get("params", None)
9191
return self.finish(
92-
AquaFineTuningApp().validate_finetuning_params(
92+
AquaFineTuningApp.validate_finetuning_params(
9393
params=params,
9494
)
9595
)

ads/aqua/finetuning/finetuning.py

Lines changed: 15 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
#!/usr/bin/env python
2-
# Copyright (c) 2024 Oracle and/or its affiliates.
2+
# Copyright (c) 2024, 2025 Oracle and/or its affiliates.
33
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
44

55
import json
66
import os
7-
from dataclasses import MISSING, asdict, fields
7+
from dataclasses import asdict, fields
88
from typing import Dict
99

1010
from oci.data_science.models import (
@@ -20,6 +20,7 @@
2020
from ads.aqua.common.utils import (
2121
get_container_image,
2222
upload_local_to_os,
23+
validate_dataclass_params,
2324
)
2425
from ads.aqua.constants import (
2526
DEFAULT_FT_BATCH_SIZE,
@@ -102,26 +103,10 @@ def create(
102103
The instance of AquaFineTuningSummary.
103104
"""
104105
if not create_fine_tuning_details:
105-
try:
106-
create_fine_tuning_details = CreateFineTuningDetails(**kwargs)
107-
except Exception as ex:
108-
allowed_create_fine_tuning_details = ", ".join(
109-
field.name for field in fields(CreateFineTuningDetails)
110-
).rstrip()
111-
raise AquaValueError(
112-
"Invalid create fine tuning parameters. Allowable parameters are: "
113-
f"{allowed_create_fine_tuning_details}."
114-
) from ex
106+
validate_dataclass_params(CreateFineTuningDetails, **kwargs)
115107

116108
source = self.get_source(create_fine_tuning_details.ft_source_id)
117109

118-
# todo: revisit validation for fine tuned models
119-
# if source.compartment_id != ODSC_MODEL_COMPARTMENT_OCID:
120-
# raise AquaValueError(
121-
# f"Fine tuning is only supported for Aqua service models in {ODSC_MODEL_COMPARTMENT_OCID}. "
122-
# "Use a valid Aqua service model id instead."
123-
# )
124-
125110
target_compartment = (
126111
create_fine_tuning_details.compartment_id or COMPARTMENT_OCID
127112
)
@@ -401,13 +386,19 @@ def create(
401386
defined_tags=model_defined_tags,
402387
),
403388
)
389+
logger.debug(
390+
f"Successfully updated model custom metadata list and freeform tags for the model {ft_model.id}."
391+
)
404392

405393
self.update_model_provenance(
406394
model_id=ft_model.id,
407395
update_model_provenance_details=UpdateModelProvenanceDetails(
408396
training_id=ft_job_run.id
409397
),
410398
)
399+
logger.debug(
400+
f"Successfully updated model provenance for the model {ft_model.id}."
401+
)
411402

412403
# tracks the shape and replica used for fine-tuning the service models
413404
telemetry_kwargs = (
@@ -587,7 +578,7 @@ def get_finetuning_config(self, model_id: str) -> Dict:
587578
config = self.get_config(model_id, AQUA_MODEL_FINETUNING_CONFIG)
588579
if not config:
589580
logger.debug(
590-
f"Fine-tuning config for custom model: {model_id} is not available."
581+
f"Fine-tuning config for custom model: {model_id} is not available. Use defaults."
591582
)
592583
return config
593584

@@ -622,7 +613,8 @@ def get_finetuning_default_params(self, model_id: str) -> Dict:
622613

623614
return default_params
624615

625-
def validate_finetuning_params(self, params: Dict = None) -> Dict:
616+
@staticmethod
617+
def validate_finetuning_params(params: Dict = None) -> Dict:
626618
"""Validate if the fine-tuning parameters passed by the user can be overridden. Parameter values are not
627619
validated, only param keys are validated.
628620
@@ -633,21 +625,7 @@ def validate_finetuning_params(self, params: Dict = None) -> Dict:
633625
634626
Returns
635627
-------
636-
Return a list of restricted params.
628+
Return a dict with value true if valid, else raises AquaValueError.
637629
"""
638-
try:
639-
AquaFineTuningParams(
640-
**params,
641-
)
642-
except Exception as e:
643-
logger.debug(str(e))
644-
allowed_fine_tuning_parameters = ", ".join(
645-
f"{field.name} (required)" if field.default is MISSING else field.name
646-
for field in fields(AquaFineTuningParams)
647-
).rstrip()
648-
raise AquaValueError(
649-
f"Invalid fine tuning parameters. Allowable parameters are: "
650-
f"{allowed_fine_tuning_parameters}."
651-
) from e
652-
630+
validate_dataclass_params(AquaFineTuningParams, **(params or {}))
653631
return {"valid": True}

0 commit comments

Comments
 (0)