Skip to content

Commit 0c6b422

Browse files
Adding early stopping params for FT
1 parent eb785ca commit 0c6b422

File tree

2 files changed

+11
-1
lines changed

2 files changed

+11
-1
lines changed

ads/aqua/finetuning/entities.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,3 +100,5 @@ class CreateFineTuningDetails(DataClassSerializable):
100100
log_id: Optional[str] = None
101101
log_group_id: Optional[str] = None
102102
force_overwrite: Optional[bool] = False
103+
early_stopping_patience: Optional[int] = None
104+
early_stopping_threshold: Optional[float] = 0.0

ads/aqua/finetuning/finetuning.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -334,6 +334,8 @@ def create(
334334
parameters=ft_parameters,
335335
ft_container=ft_container,
336336
is_custom_container=is_custom_container,
337+
early_stopping_patience=create_fine_tuning_details.early_stopping_patience,
338+
early_stopping_threshold=create_fine_tuning_details.early_stopping_threshold
337339
)
338340
).create()
339341
logger.debug(
@@ -477,6 +479,8 @@ def _build_fine_tuning_runtime(
477479
ft_container: str = None,
478480
finetuning_params: str = None,
479481
is_custom_container: bool = False,
482+
early_stopping_patience: int = None,
483+
early_stopping_threshold: float = 0.0
480484
) -> Runtime:
481485
"""Builds fine tuning runtime for Job."""
482486
container = (
@@ -505,6 +509,8 @@ def _build_fine_tuning_runtime(
505509
val_set_size=val_set_size,
506510
parameters=parameters,
507511
finetuning_params=finetuning_params,
512+
early_stopping_patience=early_stopping_patience,
513+
early_stopping_threshold=early_stopping_threshold
508514
),
509515
"CONDA_BUCKET_NS": CONDA_BUCKET_NS,
510516
}
@@ -522,9 +528,11 @@ def _build_oci_launch_cmd(
522528
val_set_size: float,
523529
parameters: AquaFineTuningParams,
524530
finetuning_params: str = None,
531+
early_stopping_patience: int = None,
532+
early_stopping_threshold: float = 0.0
525533
) -> str:
526534
"""Builds the oci launch cmd for fine tuning container runtime."""
527-
oci_launch_cmd = f"--training_data {dataset_path} --output_dir {report_path} --val_set_size {val_set_size} "
535+
oci_launch_cmd = f"--training_data {dataset_path} --output_dir {report_path} --val_set_size {val_set_size} --early_stopping_patience {early_stopping_patience} --early_stopping_threshold {early_stopping_threshold}"
528536
for key, value in asdict(parameters).items():
529537
if value is not None:
530538
if key == "batch_size":

0 commit comments

Comments
 (0)