Skip to content

Commit b441ea7

Browse files
Passing Early stopping params to SMC container for finetuning (#970)
2 parents a8780ea + 0814b7c commit b441ea7

File tree

3 files changed

+17
-12
lines changed

3 files changed

+17
-12
lines changed

ads/aqua/extension/finetune_handler.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
#!/usr/bin/env python
2-
# -*- coding: utf-8 -*-
32
# Copyright (c) 2024 Oracle and/or its affiliates.
43
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
54

@@ -9,8 +8,8 @@
98
from tornado.web import HTTPError
109

1110
from ads.aqua.common.decorator import handle_exceptions
12-
from ads.aqua.extension.errors import Errors
1311
from ads.aqua.extension.base_handler import AquaAPIhandler
12+
from ads.aqua.extension.errors import Errors
1413
from ads.aqua.extension.utils import validate_function_parameters
1514
from ads.aqua.finetuning import AquaFineTuningApp
1615
from ads.aqua.finetuning.entities import CreateFineTuningDetails

ads/aqua/finetuning/entities.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
#!/usr/bin/env python
2-
# -*- coding: utf-8 -*-
32
# Copyright (c) 2024 Oracle and/or its affiliates.
43
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
54
from dataclasses import dataclass, field
@@ -14,16 +13,18 @@ class AquaFineTuningParams(DataClassSerializable):
1413
epochs: int
1514
learning_rate: Optional[float] = None
1615
sample_packing: Optional[bool] = "auto"
17-
batch_size: Optional[
18-
int
19-
] = None # make it batch_size for user, but internally this is micro_batch_size
16+
batch_size: Optional[int] = (
17+
None # make it batch_size for user, but internally this is micro_batch_size
18+
)
2019
sequence_len: Optional[int] = None
2120
pad_to_sequence_len: Optional[bool] = None
2221
lora_r: Optional[int] = None
2322
lora_alpha: Optional[int] = None
2423
lora_dropout: Optional[float] = None
2524
lora_target_linear: Optional[bool] = None
2625
lora_target_modules: Optional[List] = None
26+
early_stopping_patience: Optional[int] = None
27+
early_stopping_threshold: Optional[float] = None
2728

2829

2930
@dataclass(repr=False)

ads/aqua/finetuning/finetuning.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,10 @@
11
#!/usr/bin/env python
2-
# -*- coding: utf-8 -*-
32
# Copyright (c) 2024 Oracle and/or its affiliates.
43
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
54

65
import json
76
import os
8-
from dataclasses import asdict, fields, MISSING
7+
from dataclasses import MISSING, asdict, fields
98
from typing import Dict
109

1110
from oci.data_science.models import (
@@ -14,14 +13,15 @@
1413
UpdateModelProvenanceDetails,
1514
)
1615

17-
from ads.aqua import ODSC_MODEL_COMPARTMENT_OCID, logger
16+
from ads.aqua import logger
1817
from ads.aqua.app import AquaApp
1918
from ads.aqua.common.enums import Resource, Tags
2019
from ads.aqua.common.errors import AquaFileExistsError, AquaValueError
2120
from ads.aqua.common.utils import (
2221
get_container_image,
2322
upload_local_to_os,
2423
)
24+
from ads.aqua.config.config import get_finetuning_config_defaults
2525
from ads.aqua.constants import (
2626
DEFAULT_FT_BATCH_SIZE,
2727
DEFAULT_FT_BLOCK_STORAGE_SIZE,
@@ -31,7 +31,6 @@
3131
UNKNOWN,
3232
UNKNOWN_DICT,
3333
)
34-
from ads.aqua.config.config import get_finetuning_config_defaults
3534
from ads.aqua.data import AquaResourceIdentifier
3635
from ads.aqua.finetuning.constants import *
3736
from ads.aqua.finetuning.entities import *
@@ -132,7 +131,7 @@ def create(
132131
or create_fine_tuning_details.validation_set_size >= 1
133132
):
134133
raise AquaValueError(
135-
f"Fine tuning validation set size should be a float number in between [0, 1)."
134+
"Fine tuning validation set size should be a float number in between [0, 1)."
136135
)
137136

138137
if create_fine_tuning_details.replica < DEFAULT_FT_REPLICA:
@@ -394,7 +393,7 @@ def create(
394393
)
395394
# track shapes that were used for fine-tune creation
396395
self.telemetry.record_event_async(
397-
category=f"aqua/service/finetune/create/shape/",
396+
category="aqua/service/finetune/create/shape/",
398397
action=f"{create_fine_tuning_details.shape_name}x{create_fine_tuning_details.replica}",
399398
**telemetry_kwargs,
400399
)
@@ -533,6 +532,12 @@ def _build_oci_launch_cmd(
533532
oci_launch_cmd += f"--num_{key} {value} "
534533
elif key == "lora_target_modules":
535534
oci_launch_cmd += f"--{key} {','.join(str(k) for k in value)} "
535+
elif key == "early_stopping_patience":
536+
if value != 0:
537+
oci_launch_cmd += f"--{key} {value} "
538+
elif key == "early_stopping_threshold":
539+
if "early_stopping_patience" in oci_launch_cmd:
540+
oci_launch_cmd += f"--{key} {value} "
536541
else:
537542
oci_launch_cmd += f"--{key} {value} "
538543

0 commit comments

Comments
 (0)