2222import time
2323
2424from sagemaker .estimator import Framework
25- from sagemaker .fw_utils import framework_name_from_image , framework_version_from_tag , \
26- empty_framework_version_warning
25+ import sagemaker .fw_utils as fw
2726from sagemaker .tensorflow .defaults import TF_VERSION
2827from sagemaker .tensorflow .model import TensorFlowModel
2928from sagemaker .tensorflow .serving import Model
3433LOGGER = logging .getLogger ('sagemaker' )
3534
3635
36+ _FRAMEWORK_MODE_ARGS = ('training_steps' , 'evaluation_steps' , 'requirements_file' , 'checkpoint_path' )
37+ _SCRIPT_MODE = 'tensorflow-scriptmode'
38+ _SCRIPT_MODE_SERVING_ERROR_MSG = 'Script mode containers does not support serving yet. ' \
39+ 'Please use our new tensorflow-serving container by creating the model ' \
40+ 'with \' endpoint_type\' set to \' tensorflow-serving\' .'
41+ _SCRIPT_MODE_TENSORBOARD_WARNING = 'Tensorboard is not supported with script mode. You can run the following ' \
42+ 'command: tensorboard --logdir {} --host localhost --port 6006 This can be ' \
43+ 'run from anywhere with access to the S3 URI used as the logdir.'
44+
45+
3746class Tensorboard (threading .Thread ):
3847 def __init__ (self , estimator , logdir = None ):
3948 """Initialize ``Tensorboard`` instance.
@@ -163,9 +172,9 @@ class TensorFlow(Framework):
163172
164173 __framework_name__ = 'tensorflow'
165174
166- def __init__ (self , training_steps = None , evaluation_steps = None , checkpoint_path = None ,
167- py_version = 'py2' , framework_version = None , requirements_file = '' , image_name = None ,
168- ** kwargs ):
175+ def __init__ (self , training_steps = None , evaluation_steps = None , checkpoint_path = None , py_version = 'py2' ,
176+ framework_version = None , model_dir = None , requirements_file = '' , image_name = None ,
177+ script_mode = False , distributions = None , ** kwargs ):
169178 """Initialize an ``TensorFlow`` estimator.
170179 Args:
171180 training_steps (int): Perform this many steps of training. `None`, the default means train forever.
@@ -176,6 +185,9 @@ def __init__(self, training_steps=None, evaluation_steps=None, checkpoint_path=N
176185 py_version (str): Python version you want to use for executing your model training code (default: 'py2').
177186 framework_version (str): TensorFlow version you want to use for executing your model training code.
178187 List of supported versions https://github.com/aws/sagemaker-python-sdk#tensorflow-sagemaker-estimators
188+ model_dir (str): S3 location where the checkpoint data and models can be exported to during training
189+ (default: None). If not specified a default S3 URI will be generated. It will be passed in the
190+ training script as one of the command line arguments.
179191 requirements_file (str): Path to a ``requirements.txt`` file (default: ''). The path should be within and
180192 relative to ``source_dir``. Details on the format can be found in the
181193 `Pip User Guide <https://pip.pypa.io/en/stable/reference/pip_install/#requirements-file-format>`_.
@@ -185,21 +197,61 @@ def __init__(self, training_steps=None, evaluation_steps=None, checkpoint_path=N
185197 Examples:
186198 123.dkr.ecr.us-west-2.amazonaws.com/my-custom-image:1.0
187199 custom-image:latest.
200+ script_mode (bool): If set to True will the estimator will use the Script Mode containers (default: False).
201+ This will be ignored if py_version is set to 'py3'.
202+ distribution (dict): A dictionary with information on how to run distributed training
203+ (default: None). Currently we only support distributed training with parameter servers. To enable it
204+ use the following setup:
205+ {
206+ 'parameter_server':
207+ {
208+ 'enabled': True
209+ }
210+ }
188211 **kwargs: Additional kwargs passed to the Framework constructor.
189212 """
190213 if framework_version is None :
191- LOGGER .warning (empty_framework_version_warning (TF_VERSION , TF_VERSION ))
214+ LOGGER .warning (fw . empty_framework_version_warning (TF_VERSION , TF_VERSION ))
192215 self .framework_version = framework_version or TF_VERSION
193216
194217 super (TensorFlow , self ).__init__ (image_name = image_name , ** kwargs )
195218 self .checkpoint_path = checkpoint_path
196219 self .py_version = py_version
197220 self .training_steps = training_steps
198221 self .evaluation_steps = evaluation_steps
222+ self .model_dir = model_dir
223+ self .script_mode = script_mode
224+ self .distributions = distributions or {}
199225
226+ self ._validate_args (py_version = py_version , script_mode = script_mode , framework_version = framework_version ,
227+ training_steps = training_steps , evaluation_steps = evaluation_steps ,
228+ requirements_file = requirements_file , checkpoint_path = checkpoint_path )
200229 self ._validate_requirements_file (requirements_file )
201230 self .requirements_file = requirements_file
202231
232+ def _validate_args (self , py_version , script_mode , framework_version , training_steps ,
233+ evaluation_steps , requirements_file , checkpoint_path ):
234+
235+ if py_version == 'py3' or script_mode :
236+
237+ if framework_version is None :
238+ raise AttributeError (fw .EMPTY_FRAMEWORK_VERSION_ERROR )
239+
240+ found_args = []
241+ if training_steps :
242+ found_args .append ('training_steps' )
243+ if evaluation_steps :
244+ found_args .append ('evaluation_steps' )
245+ if requirements_file :
246+ found_args .append ('requirements_file' )
247+ if checkpoint_path :
248+ found_args .append ('checkpoint_path' )
249+ if found_args :
250+ raise AttributeError (
251+ '{} are deprecated in script mode. Please do not set {}.'
252+ .format (', ' .join (_FRAMEWORK_MODE_ARGS ), ', ' .join (found_args ))
253+ )
254+
203255 def _validate_requirements_file (self , requirements_file ):
204256 if not requirements_file :
205257 return
@@ -245,7 +297,10 @@ def fit_super():
245297 if run_tensorboard_locally and wait is False :
246298 raise ValueError ("Tensorboard is not supported with async fit" )
247299
248- if run_tensorboard_locally :
300+ if self ._script_mode_enabled () and run_tensorboard_locally :
301+ LOGGER .warning (_SCRIPT_MODE_TENSORBOARD_WARNING .format (self .model_dir ))
302+ fit_super ()
303+ elif run_tensorboard_locally :
249304 tensorboard = Tensorboard (self )
250305 tensorboard .validate_requirements ()
251306
@@ -275,13 +330,13 @@ def _prepare_init_params_from_job_description(cls, job_details, model_channel_na
275330 model_channel_name )
276331
277332 # Move some of the tensorflow specific init params from hyperparameters into the main init params.
278- for argument in [ 'checkpoint_path' , 'training_steps' , 'evaluation_steps' ] :
333+ for argument in ( 'checkpoint_path' , 'training_steps' , 'evaluation_steps' , 'model_dir' ) :
279334 value = init_params ['hyperparameters' ].pop (argument , None )
280335 if value is not None :
281336 init_params [argument ] = value
282337
283338 image_name = init_params .pop ('image' )
284- framework , py_version , tag = framework_name_from_image (image_name )
339+ framework , py_version , tag = fw . framework_name_from_image (image_name )
285340 if not framework :
286341 # If we were unable to parse the framework name from the image it is not one of our
287342 # officially supported images, in this case just add the image to the init params.
@@ -294,7 +349,7 @@ def _prepare_init_params_from_job_description(cls, job_details, model_channel_na
294349 # containing framework version, device type and python version (e.g. '1.5-gpu-py2').
295350 # For backward compatibility map deprecated image tag '1.0' to a '1.4' framework version
296351 # otherwise extract framework version from the tag itself.
297- init_params ['framework_version' ] = '1.4' if tag == '1.0' else framework_version_from_tag (
352+ init_params ['framework_version' ] = '1.4' if tag == '1.0' else fw . framework_version_from_tag (
298353 tag )
299354
300355 training_job_name = init_params ['base_job_name' ]
@@ -328,7 +383,7 @@ def create_model(self, model_server_workers=None, role=None,
328383 """
329384
330385 role = role or self .role
331- if endpoint_type == 'tensorflow-serving' :
386+ if endpoint_type == 'tensorflow-serving' or self . _script_mode_enabled () :
332387 return self ._create_tfs_model (role = role , vpc_config_override = vpc_config_override )
333388
334389 return self ._create_default_model (model_server_workers = model_server_workers , role = role ,
@@ -362,18 +417,39 @@ def hyperparameters(self):
362417 """Return hyperparameters used by your custom TensorFlow code during model training."""
363418 hyperparameters = super (TensorFlow , self ).hyperparameters ()
364419
365- if not self .checkpoint_path :
366- local_code = get_config_value ('local.local_code' , self .sagemaker_session .config )
367- if self .sagemaker_session .local_mode and local_code :
368- self .checkpoint_path = '/opt/ml/shared/checkpoints'
369- else :
370- self .checkpoint_path = os .path .join (self .output_path ,
371- self ._current_job_name , 'checkpoints' )
420+ self .checkpoint_path = self .checkpoint_path or self ._default_s3_path ('checkpoints' )
372421
373- additional_hyperparameters = {'checkpoint_path' : self .checkpoint_path ,
374- 'training_steps' : self .training_steps ,
375- 'evaluation_steps' : self .evaluation_steps ,
376- 'sagemaker_requirements' : self .requirements_file }
422+ if self ._script_mode_enabled ():
423+ self .model_dir = self .model_dir or self ._default_s3_path ('model' )
424+ additional_hyperparameters = {'model_dir' : self .model_dir }
425+ if 'parameter_server' in self .distributions :
426+ enabled = self .distributions ['parameter_server' ].get ('enabled' , False )
427+ additional_hyperparameters [self .LAUNCH_PS_ENV_NAME ] = enabled
428+ else :
429+ additional_hyperparameters = {'checkpoint_path' : self .checkpoint_path ,
430+ 'training_steps' : self .training_steps ,
431+ 'evaluation_steps' : self .evaluation_steps ,
432+ 'sagemaker_requirements' : self .requirements_file }
377433
378434 hyperparameters .update (Framework ._json_encode_hyperparameters (additional_hyperparameters ))
379435 return hyperparameters
436+
437+ def _default_s3_path (self , directory ):
438+ local_code = get_config_value ('local.local_code' , self .sagemaker_session .config )
439+ if self .sagemaker_session .local_mode and local_code :
440+ return '/opt/ml/shared/{}' .format (directory )
441+ else :
442+ return os .path .join (self .output_path , self ._current_job_name , directory )
443+
444+ def _script_mode_enabled (self ):
445+ return self .py_version == 'py3' or self .script_mode
446+
447+ def train_image (self ):
448+ if self .image_name :
449+ return self .image_name
450+
451+ if self ._script_mode_enabled ():
452+ return fw .create_image_uri (self .sagemaker_session .boto_region_name , _SCRIPT_MODE ,
453+ self .train_instance_type , self .framework_version , self .py_version )
454+
455+ return super (TensorFlow , self ).train_image ()
0 commit comments