Skip to content

Commit 77bda8e

Browse files
maxc01 (Xingchen Ma) authored
feature: Add autotune for HyperparameterTuner (#3892)
* feature: Add autotune for HyperparameterTuner * change: Rename keep_static, make comments more clean * change: further clean comments related to Autotune --------- Co-authored-by: Xingchen Ma <xgchenma@amazon.com>
1 parent e981daa commit 77bda8e

File tree

6 files changed

+495
-6
lines changed

6 files changed

+495
-6
lines changed

src/sagemaker/session.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2521,6 +2521,8 @@ def tune( # noqa: C901
25212521
random_seed=None,
25222522
environment=None,
25232523
hpo_resource_config=None,
2524+
autotune=False,
2525+
auto_parameters=None,
25242526
):
25252527
"""Create an Amazon SageMaker hyperparameter tuning job.
25262528
@@ -2626,6 +2628,11 @@ def tune( # noqa: C901
26262628
* volume_kms_key_id: The AWS Key Management Service (AWS KMS) key
26272629
that Amazon SageMaker uses to encrypt data on the storage
26282630
volume attached to the ML compute instance(s) that run the training job.
2631+
autotune (bool): Whether the parameter ranges or other unset settings of a tuning job
2632+
should be chosen automatically (default: False).
2633+
auto_parameters (dict[str, str]): Dictionary of auto parameters. The keys are names
2634+
of auto parameters and values are example values of auto parameters
2635+
(default: ``None``).
26292636
"""
26302637

26312638
tune_request = {
@@ -2642,6 +2649,7 @@ def tune( # noqa: C901
26422649
random_seed=random_seed,
26432650
strategy_config=strategy_config,
26442651
completion_criteria_config=completion_criteria_config,
2652+
auto_parameters=auto_parameters,
26452653
),
26462654
"TrainingJobDefinition": self._map_training_config(
26472655
static_hyperparameters=static_hyperparameters,
@@ -2668,6 +2676,9 @@ def tune( # noqa: C901
26682676
if warm_start_config is not None:
26692677
tune_request["WarmStartConfig"] = warm_start_config
26702678

2679+
if autotune:
2680+
tune_request["Autotune"] = {"Mode": "Enabled"}
2681+
26712682
tags = _append_project_tags(tags)
26722683
if tags is not None:
26732684
tune_request["Tags"] = tags
@@ -2684,6 +2695,7 @@ def create_tuning_job(
26842695
training_config_list=None,
26852696
warm_start_config=None,
26862697
tags=None,
2698+
autotune=False,
26872699
):
26882700
"""Create an Amazon SageMaker hyperparameter tuning job.
26892701
@@ -2703,6 +2715,8 @@ def create_tuning_job(
27032715
other required configurations.
27042716
tags (list[dict]): List of tags for labeling the tuning job. For more, see
27052717
https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html.
2718+
autotune (bool): Whether the parameter ranges or other unset settings of a tuning job
2719+
should be chosen automatically.
27062720
"""
27072721

27082722
if training_config is None and training_config_list is None:
@@ -2719,6 +2733,7 @@ def create_tuning_job(
27192733
training_config_list=training_config_list,
27202734
warm_start_config=warm_start_config,
27212735
tags=tags,
2736+
autotune=autotune,
27222737
)
27232738

27242739
def submit(request):
@@ -2736,6 +2751,7 @@ def _get_tuning_request(
27362751
training_config_list=None,
27372752
warm_start_config=None,
27382753
tags=None,
2754+
autotune=False,
27392755
):
27402756
"""Construct CreateHyperParameterTuningJob request
27412757
@@ -2751,13 +2767,17 @@ def _get_tuning_request(
27512767
other required configurations.
27522768
tags (list[dict]): List of tags for labeling the tuning job. For more, see
27532769
https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html.
2770+
autotune (bool): Whether the parameter ranges or other unset settings of a tuning job
2771+
should be chosen automatically.
27542772
Returns:
27552773
dict: A dictionary for CreateHyperParameterTuningJob request
27562774
"""
27572775
tune_request = {
27582776
"HyperParameterTuningJobName": job_name,
27592777
"HyperParameterTuningJobConfig": self._map_tuning_config(**tuning_config),
27602778
}
2779+
if autotune:
2780+
tune_request["Autotune"] = {"Mode": "Enabled"}
27612781

27622782
if training_config is not None:
27632783
tune_request["TrainingJobDefinition"] = self._map_training_config(**training_config)
@@ -2803,6 +2823,7 @@ def _map_tuning_config(
28032823
random_seed=None,
28042824
strategy_config=None,
28052825
completion_criteria_config=None,
2826+
auto_parameters=None,
28062827
):
28072828
"""Construct tuning job configuration dictionary.
28082829
@@ -2829,6 +2850,8 @@ def _map_tuning_config(
28292850
strategy.
28302851
completion_criteria_config (dict): A configuration
28312852
for the completion criteria.
2853+
auto_parameters (dict): Dictionary of auto parameters. The keys are names of auto
2854+
parameters and values are example values of auto parameters.
28322855
28332856
Returns:
28342857
A dictionary of tuning job configuration. For format details, please refer to
@@ -2858,6 +2881,13 @@ def _map_tuning_config(
28582881
if parameter_ranges is not None:
28592882
tuning_config["ParameterRanges"] = parameter_ranges
28602883

2884+
if auto_parameters is not None:
2885+
if parameter_ranges is None:
2886+
tuning_config["ParameterRanges"] = {}
2887+
tuning_config["ParameterRanges"]["AutoParameters"] = [
2888+
{"Name": name, "ValueHint": value} for name, value in auto_parameters.items()
2889+
]
2890+
28612891
if strategy_config is not None:
28622892
tuning_config["StrategyConfig"] = strategy_config
28632893

@@ -2919,6 +2949,7 @@ def _map_training_config(
29192949
checkpoint_local_path=None,
29202950
max_retry_attempts=None,
29212951
environment=None,
2952+
auto_parameters=None,
29222953
):
29232954
"""Construct a dictionary of training job configuration from the arguments.
29242955
@@ -3039,6 +3070,13 @@ def _map_training_config(
30393070
if parameter_ranges is not None:
30403071
training_job_definition["HyperParameterRanges"] = parameter_ranges
30413072

3073+
if auto_parameters is not None:
3074+
if parameter_ranges is None:
3075+
training_job_definition["HyperParameterRanges"] = {}
3076+
training_job_definition["HyperParameterRanges"]["AutoParameters"] = [
3077+
{"Name": name, "ValueHint": value} for name, value in auto_parameters.items()
3078+
]
3079+
30423080
if max_retry_attempts is not None:
30433081
training_job_definition["RetryStrategy"] = {"MaximumRetryAttempts": max_retry_attempts}
30443082

0 commit comments

Comments (0)