@@ -176,7 +176,8 @@ class TensorFlow(Framework):
176176 def __init__ (self , training_steps = None , evaluation_steps = None , checkpoint_path = None , py_version = 'py2' ,
177177 framework_version = None , model_dir = None , requirements_file = '' , image_name = None ,
178178 script_mode = False , distributions = None , ** kwargs ):
179- """Initialize an ``TensorFlow`` estimator.
179+ """Initialize a ``TensorFlow`` estimator.
180+
180181 Args:
181182 training_steps (int): Perform this many steps of training. `None`, the default means train forever.
182183 evaluation_steps (int): Perform this many steps of evaluation. `None`, the default means that evaluation
@@ -195,26 +196,36 @@ def __init__(self, training_steps=None, evaluation_steps=None, checkpoint_path=N
195196 image_name (str): If specified, the estimator will use this image for training and hosting, instead of
196197 selecting the appropriate SageMaker official image based on framework_version and py_version. It can
197198 be an ECR url or dockerhub image and tag.
198- Examples:
199- 123.dkr.ecr.us-west-2.amazonaws.com/my-custom-image:1.0
200- custom-image:latest.
199+
200+ Examples:
201+ 123.dkr.ecr.us-west-2.amazonaws.com/my-custom-image:1.0
202+ custom-image:latest.
201203 script_mode (bool): If set to True will the estimator will use the Script Mode containers (default: False).
202204 This will be ignored if py_version is set to 'py3'.
203205 distributions (dict): A dictionary with information on how to run distributed training
204- (default: None). Currently we support distributed training with parameter servers and MPI. To enable
205- parameter server use the following setup:
206+ (default: None). Currently we support distributed training with parameter servers and MPI.
207+ To enable parameter server use the following setup:
208+
209+ .. code:: python
210+
211+ {
206212 'parameter_server':
207213 {
208214 'enabled': True
209215 }
210216 }
217+
211218 To enable MPI:
219+
220+ .. code:: python
221+
212222 {
213223 'mpi':
214224 {
215225 'enabled': True
216226 }
217227 }
228+
218229 **kwargs: Additional kwargs passed to the Framework constructor.
219230 """
220231 if framework_version is None :
@@ -281,13 +292,15 @@ def fit(self, inputs=None, wait=True, logs=True, job_name=None, run_tensorboard_
281292 Args:
282293 inputs (str or dict or sagemaker.session.s3_input): Information about the training data.
283294 This can be one of three types:
284- (str) - the S3 location where training data is saved.
285- (dict[str, str] or dict[str, sagemaker.session.s3_input]) - If using multiple channels for
295+
296+ * (str) - the S3 location where training data is saved.
297+ * (dict[str, str] or dict[str, sagemaker.session.s3_input]) - If using multiple channels for
286298 training data, you can specify a dict mapping channel names
287299 to strings or :func:`~sagemaker.session.s3_input` objects.
288- (sagemaker.session.s3_input) - channel configuration for S3 data sources that can provide
300+ * (sagemaker.session.s3_input) - channel configuration for S3 data sources that can provide
289301 additional information as well as the path to the training dataset.
290302 See :func:`sagemaker.session.s3_input` for full details.
303+
291304 wait (bool): Whether the call should wait until the job completes (default: True).
292305 logs (bool): Whether to show the logs produced by the job.
293306 Only meaningful when wait is True (default: True).
0 commit comments