@@ -425,7 +425,7 @@ def compile_model(
425425 LOGGER .info ("Creating compilation-job with name: %s" , job_name )
426426 self .sagemaker_client .create_compilation_job (** compilation_job_request )
427427
428- def tune (
428+ def tune ( # noqa: C901
429429 self ,
430430 job_name ,
431431 strategy ,
@@ -450,6 +450,9 @@ def tune(
450450 early_stopping_type = "Off" ,
451451 encrypt_inter_container_traffic = False ,
452452 vpc_config = None ,
453+ train_use_spot_instances = False ,
454+ checkpoint_s3_uri = None ,
455+ checkpoint_local_path = None ,
453456 ):
454457 """Create an Amazon SageMaker hyperparameter tuning job
455458
@@ -512,6 +515,18 @@ def tune(
512515 The key in vpc_config is 'Subnets'.
513516 * security_group_ids (list[str]): List of security group ids.
514517 The key in vpc_config is 'SecurityGroupIds'.
518+ train_use_spot_instances (bool): whether to use spot instances for training.
519+ checkpoint_s3_uri (str): The S3 URI in which to persist checkpoints
520+ that the algorithm persists (if any) during training. (default:
521+ ``None``).
522+ checkpoint_local_path (str): The local path that the algorithm
523+ writes its checkpoints to. SageMaker will persist all files
524+ under this path to `checkpoint_s3_uri` continually during
525+ training. On job startup the reverse happens - data from the
526+ s3 location is downloaded to this path before the algorithm is
527+ started. If the path is unset then SageMaker assumes the
528+ checkpoints will be provided under `/opt/ml/checkpoints/`.
529+ (default: ``None``).
515530
516531 """
517532 tune_request = {
@@ -569,6 +584,15 @@ def tune(
569584 if encrypt_inter_container_traffic :
570585 tune_request ["TrainingJobDefinition" ]["EnableInterContainerTrafficEncryption" ] = True
571586
587+ if train_use_spot_instances :
588+ tune_request ["TrainingJobDefinition" ]["EnableManagedSpotTraining" ] = True
589+
590+ if checkpoint_s3_uri :
591+ checkpoint_config = {"S3Uri" : checkpoint_s3_uri }
592+ if checkpoint_local_path :
593+ checkpoint_config ["LocalPath" ] = checkpoint_local_path
594+ tune_request ["TrainingJobDefinition" ]["CheckpointConfig" ] = checkpoint_config
595+
572596 LOGGER .info ("Creating hyperparameter tuning job with name: %s" , job_name )
573597 LOGGER .debug ("tune request: %s" , json .dumps (tune_request , indent = 4 ))
574598 self .sagemaker_client .create_hyper_parameter_tuning_job (** tune_request )
0 commit comments