Skip to content

Commit 0814b7c

Browse files
moving params to AQUAFineTunedParams
1 parent 28c5d26 commit 0814b7c

File tree

3 files changed

+17
-24
lines changed

3 files changed

+17
-24
lines changed

ads/aqua/extension/finetune_handler.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
#!/usr/bin/env python
2-
# -*- coding: utf-8 -*-
32
# Copyright (c) 2024 Oracle and/or its affiliates.
43
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
54

@@ -9,8 +8,8 @@
98
from tornado.web import HTTPError
109

1110
from ads.aqua.common.decorator import handle_exceptions
12-
from ads.aqua.extension.errors import Errors
1311
from ads.aqua.extension.base_handler import AquaAPIhandler
12+
from ads.aqua.extension.errors import Errors
1413
from ads.aqua.extension.utils import validate_function_parameters
1514
from ads.aqua.finetuning import AquaFineTuningApp
1615
from ads.aqua.finetuning.entities import CreateFineTuningDetails

ads/aqua/finetuning/entities.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
#!/usr/bin/env python
2-
# -*- coding: utf-8 -*-
32
# Copyright (c) 2024 Oracle and/or its affiliates.
43
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
54
from dataclasses import dataclass, field
@@ -14,16 +13,18 @@ class AquaFineTuningParams(DataClassSerializable):
1413
epochs: int
1514
learning_rate: Optional[float] = None
1615
sample_packing: Optional[bool] = "auto"
17-
batch_size: Optional[
18-
int
19-
] = None # make it batch_size for user, but internally this is micro_batch_size
16+
batch_size: Optional[int] = (
17+
None # make it batch_size for user, but internally this is micro_batch_size
18+
)
2019
sequence_len: Optional[int] = None
2120
pad_to_sequence_len: Optional[bool] = None
2221
lora_r: Optional[int] = None
2322
lora_alpha: Optional[int] = None
2423
lora_dropout: Optional[float] = None
2524
lora_target_linear: Optional[bool] = None
2625
lora_target_modules: Optional[List] = None
26+
early_stopping_patience: Optional[int] = None
27+
early_stopping_threshold: Optional[float] = None
2728

2829

2930
@dataclass(repr=False)
@@ -100,5 +101,3 @@ class CreateFineTuningDetails(DataClassSerializable):
100101
log_id: Optional[str] = None
101102
log_group_id: Optional[str] = None
102103
force_overwrite: Optional[bool] = False
103-
early_stopping_patience: Optional[int] = None
104-
early_stopping_threshold: Optional[float] = 0.0

ads/aqua/finetuning/finetuning.py

Lines changed: 11 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,10 @@
11
#!/usr/bin/env python
2-
# -*- coding: utf-8 -*-
32
# Copyright (c) 2024 Oracle and/or its affiliates.
43
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
54

65
import json
76
import os
8-
from dataclasses import asdict, fields, MISSING
7+
from dataclasses import MISSING, asdict, fields
98
from typing import Dict
109

1110
from oci.data_science.models import (
@@ -14,14 +13,15 @@
1413
UpdateModelProvenanceDetails,
1514
)
1615

17-
from ads.aqua import ODSC_MODEL_COMPARTMENT_OCID, logger
16+
from ads.aqua import logger
1817
from ads.aqua.app import AquaApp
1918
from ads.aqua.common.enums import Resource, Tags
2019
from ads.aqua.common.errors import AquaFileExistsError, AquaValueError
2120
from ads.aqua.common.utils import (
2221
get_container_image,
2322
upload_local_to_os,
2423
)
24+
from ads.aqua.config.config import get_finetuning_config_defaults
2525
from ads.aqua.constants import (
2626
DEFAULT_FT_BATCH_SIZE,
2727
DEFAULT_FT_BLOCK_STORAGE_SIZE,
@@ -31,7 +31,6 @@
3131
UNKNOWN,
3232
UNKNOWN_DICT,
3333
)
34-
from ads.aqua.config.config import get_finetuning_config_defaults
3534
from ads.aqua.data import AquaResourceIdentifier
3635
from ads.aqua.finetuning.constants import *
3736
from ads.aqua.finetuning.entities import *
@@ -132,7 +131,7 @@ def create(
132131
or create_fine_tuning_details.validation_set_size >= 1
133132
):
134133
raise AquaValueError(
135-
f"Fine tuning validation set size should be a float number in between [0, 1)."
134+
"Fine tuning validation set size should be a float number in between [0, 1)."
136135
)
137136

138137
if create_fine_tuning_details.replica < DEFAULT_FT_REPLICA:
@@ -334,8 +333,6 @@ def create(
334333
parameters=ft_parameters,
335334
ft_container=ft_container,
336335
is_custom_container=is_custom_container,
337-
early_stopping_patience=create_fine_tuning_details.early_stopping_patience,
338-
early_stopping_threshold=create_fine_tuning_details.early_stopping_threshold
339336
)
340337
).create()
341338
logger.debug(
@@ -396,7 +393,7 @@ def create(
396393
)
397394
# track shapes that were used for fine-tune creation
398395
self.telemetry.record_event_async(
399-
category=f"aqua/service/finetune/create/shape/",
396+
category="aqua/service/finetune/create/shape/",
400397
action=f"{create_fine_tuning_details.shape_name}x{create_fine_tuning_details.replica}",
401398
**telemetry_kwargs,
402399
)
@@ -479,8 +476,6 @@ def _build_fine_tuning_runtime(
479476
ft_container: str = None,
480477
finetuning_params: str = None,
481478
is_custom_container: bool = False,
482-
early_stopping_patience: int = None,
483-
early_stopping_threshold: float = 0.0
484479
) -> Runtime:
485480
"""Builds fine tuning runtime for Job."""
486481
container = (
@@ -509,8 +504,6 @@ def _build_fine_tuning_runtime(
509504
val_set_size=val_set_size,
510505
parameters=parameters,
511506
finetuning_params=finetuning_params,
512-
early_stopping_patience=early_stopping_patience,
513-
early_stopping_threshold=early_stopping_threshold
514507
),
515508
"CONDA_BUCKET_NS": CONDA_BUCKET_NS,
516509
}
@@ -528,13 +521,9 @@ def _build_oci_launch_cmd(
528521
val_set_size: float,
529522
parameters: AquaFineTuningParams,
530523
finetuning_params: str = None,
531-
early_stopping_patience: int = None,
532-
early_stopping_threshold: float = 0.0
533524
) -> str:
534525
"""Builds the oci launch cmd for fine tuning container runtime."""
535526
oci_launch_cmd = f"--training_data {dataset_path} --output_dir {report_path} --val_set_size {val_set_size} "
536-
if early_stopping_patience:
537-
oci_launch_cmd += f"--early_stopping_patience {early_stopping_patience} --early_stopping_threshold {early_stopping_threshold} "
538527
for key, value in asdict(parameters).items():
539528
if value is not None:
540529
if key == "batch_size":
@@ -543,6 +532,12 @@ def _build_oci_launch_cmd(
543532
oci_launch_cmd += f"--num_{key} {value} "
544533
elif key == "lora_target_modules":
545534
oci_launch_cmd += f"--{key} {','.join(str(k) for k in value)} "
535+
elif key == "early_stopping_patience":
536+
if value != 0:
537+
oci_launch_cmd += f"--{key} {value} "
538+
elif key == "early_stopping_threshold":
539+
if "early_stopping_patience" in oci_launch_cmd:
540+
oci_launch_cmd += f"--{key} {value} "
546541
else:
547542
oci_launch_cmd += f"--{key} {value} "
548543

0 commit comments

Comments
 (0)