11#!/usr/bin/env python
2- # -*- coding: utf-8 -*-
32# Copyright (c) 2024 Oracle and/or its affiliates.
43# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
54
65import json
76import os
8- from dataclasses import asdict , fields , MISSING
7+ from dataclasses import MISSING , asdict , fields
98from typing import Dict
109
1110from oci .data_science .models import (
1413 UpdateModelProvenanceDetails ,
1514)
1615
17- from ads .aqua import ODSC_MODEL_COMPARTMENT_OCID , logger
16+ from ads .aqua import logger
1817from ads .aqua .app import AquaApp
1918from ads .aqua .common .enums import Resource , Tags
2019from ads .aqua .common .errors import AquaFileExistsError , AquaValueError
2120from ads .aqua .common .utils import (
2221 get_container_image ,
2322 upload_local_to_os ,
2423)
24+ from ads .aqua .config .config import get_finetuning_config_defaults
2525from ads .aqua .constants import (
2626 DEFAULT_FT_BATCH_SIZE ,
2727 DEFAULT_FT_BLOCK_STORAGE_SIZE ,
3131 UNKNOWN ,
3232 UNKNOWN_DICT ,
3333)
34- from ads .aqua .config .config import get_finetuning_config_defaults
3534from ads .aqua .data import AquaResourceIdentifier
3635from ads .aqua .finetuning .constants import *
3736from ads .aqua .finetuning .entities import *
@@ -132,7 +131,7 @@ def create(
132131 or create_fine_tuning_details .validation_set_size >= 1
133132 ):
134133 raise AquaValueError (
135- f "Fine tuning validation set size should be a float number in between [0, 1)."
134+ "Fine tuning validation set size should be a float number in between [0, 1)."
136135 )
137136
138137 if create_fine_tuning_details .replica < DEFAULT_FT_REPLICA :
@@ -334,8 +333,6 @@ def create(
334333 parameters = ft_parameters ,
335334 ft_container = ft_container ,
336335 is_custom_container = is_custom_container ,
337- early_stopping_patience = create_fine_tuning_details .early_stopping_patience ,
338- early_stopping_threshold = create_fine_tuning_details .early_stopping_threshold
339336 )
340337 ).create ()
341338 logger .debug (
@@ -396,7 +393,7 @@ def create(
396393 )
397394 # track shapes that were used for fine-tune creation
398395 self .telemetry .record_event_async (
399- category = f "aqua/service/finetune/create/shape/" ,
396+ category = "aqua/service/finetune/create/shape/" ,
400397 action = f"{ create_fine_tuning_details .shape_name } x{ create_fine_tuning_details .replica } " ,
401398 ** telemetry_kwargs ,
402399 )
@@ -479,8 +476,6 @@ def _build_fine_tuning_runtime(
479476 ft_container : str = None ,
480477 finetuning_params : str = None ,
481478 is_custom_container : bool = False ,
482- early_stopping_patience : int = None ,
483- early_stopping_threshold : float = 0.0
484479 ) -> Runtime :
485480 """Builds fine tuning runtime for Job."""
486481 container = (
@@ -509,8 +504,6 @@ def _build_fine_tuning_runtime(
509504 val_set_size = val_set_size ,
510505 parameters = parameters ,
511506 finetuning_params = finetuning_params ,
512- early_stopping_patience = early_stopping_patience ,
513- early_stopping_threshold = early_stopping_threshold
514507 ),
515508 "CONDA_BUCKET_NS" : CONDA_BUCKET_NS ,
516509 }
@@ -528,13 +521,9 @@ def _build_oci_launch_cmd(
528521 val_set_size : float ,
529522 parameters : AquaFineTuningParams ,
530523 finetuning_params : str = None ,
531- early_stopping_patience : int = None ,
532- early_stopping_threshold : float = 0.0
533524 ) -> str :
534525 """Builds the oci launch cmd for fine tuning container runtime."""
535526 oci_launch_cmd = f"--training_data { dataset_path } --output_dir { report_path } --val_set_size { val_set_size } "
536- if early_stopping_patience :
537- oci_launch_cmd += f"--early_stopping_patience { early_stopping_patience } --early_stopping_threshold { early_stopping_threshold } "
538527 for key , value in asdict (parameters ).items ():
539528 if value is not None :
540529 if key == "batch_size" :
@@ -543,6 +532,12 @@ def _build_oci_launch_cmd(
543532 oci_launch_cmd += f"--num_{ key } { value } "
544533 elif key == "lora_target_modules" :
545534 oci_launch_cmd += f"--{ key } { ',' .join (str (k ) for k in value )} "
535+ elif key == "early_stopping_patience" :
536+ if value != 0 :
537+ oci_launch_cmd += f"--{ key } { value } "
538+ elif key == "early_stopping_threshold" :
539+ if "early_stopping_patience" in oci_launch_cmd :
540+ oci_launch_cmd += f"--{ key } { value } "
546541 else :
547542 oci_launch_cmd += f"--{ key } { value } "
548543
0 commit comments