@@ -178,6 +178,7 @@ def __init__(
178178 container_entry_point : Optional [List [str ]] = None ,
179179 container_arguments : Optional [List [str ]] = None ,
180180 disable_output_compression : bool = False ,
181+ enable_remote_debug : Optional [Union [bool , PipelineVariable ]] = None ,
181182 ** kwargs ,
182183 ):
183184 """Initialize an ``EstimatorBase`` instance.
@@ -540,6 +541,8 @@ def __init__(
540541 to Amazon S3 without compression after training finishes.
541542 enable_infra_check (bool or PipelineVariable): Optional.
542543 Specifies whether it is running Sagemaker built-in infra check jobs.
544+ enable_remote_debug (bool or PipelineVariable): Optional.
545+ Specifies whether RemoteDebug is enabled for the training job.
543546 """
544547 instance_count = renamed_kwargs (
545548 "train_instance_count" , "instance_count" , instance_count , kwargs
@@ -777,6 +780,8 @@ def __init__(
777780
778781 self .tensorboard_app = TensorBoardApp (region = self .sagemaker_session .boto_region_name )
779782
783+ self ._enable_remote_debug = enable_remote_debug
784+
780785 @abstractmethod
781786 def training_image_uri (self ):
782787 """Return the Docker image to use for training.
@@ -1958,6 +1963,11 @@ def _prepare_init_params_from_job_description(cls, job_details, model_channel_na
19581963 max_wait = job_details .get ("StoppingCondition" , {}).get ("MaxWaitTimeInSeconds" )
19591964 if max_wait :
19601965 init_params ["max_wait" ] = max_wait
1966+
1967+ if "RemoteDebugConfig" in job_details :
1968+ init_params ["enable_remote_debug" ] = job_details ["RemoteDebugConfig" ].get (
1969+ "EnableRemoteDebug"
1970+ )
19611971 return init_params
19621972
19631973 def _get_instance_type (self ):
@@ -2292,6 +2302,32 @@ def update_profiler(
22922302
22932303 _TrainingJob .update (self , profiler_rule_configs , profiler_config_request_dict )
22942304
def get_remote_debug_config(self):
    """dict: Return the RemoteDebug configuration for this estimator.

    Returns:
        dict or None: ``None`` when ``_enable_remote_debug`` is ``None``;
        otherwise a dict with the single key ``EnableRemoteDebug`` mapped
        to the stored value.
    """
    remote_debug_flag = self._enable_remote_debug
    if remote_debug_flag is None:
        return None
    return {"EnableRemoteDebug": remote_debug_flag}
2312+
def enable_remote_debug(self):
    """Enable remote debug for a training job.

    Thin wrapper that forwards ``True`` to ``_update_remote_debug``.
    """
    self._update_remote_debug(enable_remote_debug=True)
2316+
def disable_remote_debug(self):
    """Disable remote debug for a training job.

    Thin wrapper that forwards ``False`` to ``_update_remote_debug``.
    """
    self._update_remote_debug(enable_remote_debug=False)
2320+
def _update_remote_debug(self, enable_remote_debug: bool):
    """Enable or disable remote debug for a training job and record the setting.

    Calls ``_ensure_latest_training_job``, sends an update carrying the
    requested ``EnableRemoteDebug`` value through ``_TrainingJob.update``,
    and then stores the value on ``_enable_remote_debug``.

    Args:
        enable_remote_debug (bool): ``True`` to enable remote debug,
            ``False`` to disable it.
    """
    self._ensure_latest_training_job()
    requested_config = {"EnableRemoteDebug": enable_remote_debug}
    _TrainingJob.update(self, remote_debug_config=requested_config)
    # The cached flag is only overwritten after the update call has returned.
    self._enable_remote_debug = enable_remote_debug
2330+
22952331 def get_app_url (
22962332 self ,
22972333 app_type ,
@@ -2520,6 +2556,9 @@ def _get_train_args(cls, estimator, inputs, experiment_config):
25202556 if estimator .profiler_config :
25212557 train_args ["profiler_config" ] = estimator .profiler_config ._to_request_dict ()
25222558
2559+ if estimator .get_remote_debug_config () is not None :
2560+ train_args ["remote_debug_config" ] = estimator .get_remote_debug_config ()
2561+
25232562 return train_args
25242563
25252564 @classmethod
@@ -2549,7 +2588,12 @@ def _is_local_channel(cls, input_uri):
25492588
25502589 @classmethod
25512590 def update (
2552- cls , estimator , profiler_rule_configs = None , profiler_config = None , resource_config = None
2591+ cls ,
2592+ estimator ,
2593+ profiler_rule_configs = None ,
2594+ profiler_config = None ,
2595+ resource_config = None ,
2596+ remote_debug_config = None ,
25532597 ):
25542598 """Update a running Amazon SageMaker training job.
25552599
@@ -2562,20 +2606,31 @@ def update(
25622606 resource_config (dict): Configuration of the resources for the training job. You can
25632607 update the keep-alive period if the warm pool status is `Available`. No other fields
25642608 can be updated. (default: None).
2609+ remote_debug_config (dict): Configuration for RemoteDebug (default: None).
2610+ The dict can contain the key 'EnableRemoteDebug' (bool).
2611+ For example:
2612+
2613+ .. code:: python
2614+
2615+ remote_debug_config = {
2616+ "EnableRemoteDebug": True,
2617+ }
25652618
25662619 Returns:
25672620 sagemaker.estimator._TrainingJob: Constructed object that captures
25682621 all information about the updated training job.
25692622 """
25702623 update_args = cls ._get_update_args (
2571- estimator , profiler_rule_configs , profiler_config , resource_config
2624+ estimator , profiler_rule_configs , profiler_config , resource_config , remote_debug_config
25722625 )
25732626 estimator .sagemaker_session .update_training_job (** update_args )
25742627
25752628 return estimator .latest_training_job
25762629
25772630 @classmethod
2578- def _get_update_args (cls , estimator , profiler_rule_configs , profiler_config , resource_config ):
2631+ def _get_update_args (
2632+ cls , estimator , profiler_rule_configs , profiler_config , resource_config , remote_debug_config
2633+ ):
25792634 """Constructs a dict of arguments for updating an Amazon SageMaker training job.
25802635
25812636 Args:
@@ -2596,6 +2651,7 @@ def _get_update_args(cls, estimator, profiler_rule_configs, profiler_config, res
25962651 update_args .update (build_dict ("profiler_rule_configs" , profiler_rule_configs ))
25972652 update_args .update (build_dict ("profiler_config" , profiler_config ))
25982653 update_args .update (build_dict ("resource_config" , resource_config ))
2654+ update_args .update (build_dict ("remote_debug_config" , remote_debug_config ))
25992655
26002656 return update_args
26012657
@@ -2694,6 +2750,7 @@ def __init__(
26942750 container_arguments : Optional [List [str ]] = None ,
26952751 disable_output_compression : bool = False ,
26962752 enable_infra_check : Optional [Union [bool , PipelineVariable ]] = None ,
2753+ enable_remote_debug : Optional [Union [bool , PipelineVariable ]] = None ,
26972754 ** kwargs ,
26982755 ):
26992756 """Initialize an ``Estimator`` instance.
@@ -3055,6 +3112,8 @@ def __init__(
30553112 to Amazon S3 without compression after training finishes.
30563113 enable_infra_check (bool or PipelineVariable): Optional.
30573114 Specifies whether it is running Sagemaker built-in infra check jobs.
3115+ enable_remote_debug (bool or PipelineVariable): Optional.
3116+ Specifies whether RemoteDebug is enabled for the training job.
30583117 """
30593118 self .image_uri = image_uri
30603119 self ._hyperparameters = hyperparameters .copy () if hyperparameters else {}
@@ -3106,6 +3165,7 @@ def __init__(
31063165 container_entry_point = container_entry_point ,
31073166 container_arguments = container_arguments ,
31083167 disable_output_compression = disable_output_compression ,
3168+ enable_remote_debug = enable_remote_debug ,
31093169 ** kwargs ,
31103170 )
31113171
0 commit comments