@@ -210,21 +210,6 @@ def fit(self, inputs=None, wait=True, logs=True, job_name=None):
210210 if wait :
211211 self .latest_training_job .wait (logs = logs )
212212
213- @classmethod
214- def _from_training_job (cls , init_params , hyperparameters , image , sagemaker_session ):
215- """Create an Estimator from existing training job data.
216-
217- Args:
218- init_params (dict): The init_params the training job was created with.
219- hyperparameters (dict): The hyperparameters the training job was created with.
220- image (str): Container image (if any) the training job was created with
221- sagemaker_session (sagemaker.session.Session): A sagemaker Session to pass to the estimator.
222-
223- Returns: An instance of the calling Estimator Class.
224-
225- """
226- raise NotImplementedError ()
227-
228213 @classmethod
229214 def attach (cls , training_job_name , sagemaker_session = None , model_channel_name = 'model' ):
230215 """Attach to an existing training job.
@@ -262,7 +247,7 @@ def attach(cls, training_job_name, sagemaker_session=None, model_channel_name='m
262247
263248 estimator = cls (sagemaker_session = sagemaker_session , ** init_params )
264249 estimator .latest_training_job = _TrainingJob (sagemaker_session = sagemaker_session ,
265- training_job_name = init_params ['base_job_name' ])
250+ job_name = init_params ['base_job_name' ])
266251 estimator .latest_training_job .wait ()
267252 return estimator
268253
@@ -425,9 +410,6 @@ def _ensure_latest_training_job(self, error_message='Estimator is not associated
425410
426411
427412class _TrainingJob (_Job ):
428- def __init__ (self , sagemaker_session , training_job_name ):
429- super (_TrainingJob , self ).__init__ (sagemaker_session , training_job_name )
430-
431413 @classmethod
432414 def start_new (cls , estimator , inputs ):
433415 """Create a new Amazon SageMaker training job from the estimator.
@@ -627,12 +609,10 @@ class Framework(EstimatorBase):
627609 such as training/deployment images and predictor instances.
628610 """
629611
630- _DISTRIBUTION_SUPPORTED_FRAMEWORKS = ('mxnet' ,)
631- LAUNCH_PS_ENV_NAME = 'sagemaker_parameter_server_enabled'
612+ __framework_name__ = None
632613
633614 def __init__ (self , entry_point , source_dir = None , hyperparameters = None , enable_cloudwatch_metrics = False ,
634- container_log_level = logging .INFO , code_location = None , image_name = None ,
635- distributions = None , ** kwargs ):
615+ container_log_level = logging .INFO , code_location = None , image_name = None , ** kwargs ):
636616 """Base class initializer. Subclasses which override ``__init__`` should invoke ``super()``
637617
638618 Args:
@@ -654,8 +634,6 @@ def __init__(self, entry_point, source_dir=None, hyperparameters=None, enable_cl
654634 image_name (str): An alternate image name to use instead of the official Sagemaker image
655635 for the framework. This is useful to run one of the Sagemaker supported frameworks
656636 with an image containing custom dependencies.
657- distributions (dict): A dictionary with information on how to run distributed training
658- (default: None).
659637 **kwargs: Additional kwargs passed to the ``EstimatorBase`` constructor.
660638 """
661639 super (Framework , self ).__init__ (** kwargs )
@@ -670,22 +648,6 @@ def __init__(self, entry_point, source_dir=None, hyperparameters=None, enable_cl
670648 self .image_name = image_name
671649
672650 self ._hyperparameters = hyperparameters or {}
673- self ._configure_distributions (distributions )
674-
675- def _configure_distributions (self , distributions ):
676- if distributions is None :
677- return
678-
679- if self .__framework_name__ not in self ._DISTRIBUTION_SUPPORTED_FRAMEWORKS :
680- raise ValueError ('This framework does not support the distributions option.' )
681-
682- if self .framework_version .split ('.' ) < self ._LOWEST_SCRIPT_MODE_VERSION :
683- raise ValueError ('The distributions option is valid for only versions {} and higher'
684- .format ('.' .join (self ._LOWEST_SCRIPT_MODE_VERSION )))
685-
686- if 'parameter_server' in distributions :
687- enabled = distributions ['parameter_server' ].get ('enabled' , False )
688- self ._hyperparameters [self .LAUNCH_PS_ENV_NAME ] = enabled
689651
690652 def _prepare_for_training (self , job_name = None ):
691653 """Set hyperparameters needed for training. This method will also validate ``source_dir``.
@@ -810,8 +772,11 @@ def train_image(self):
810772 if self .image_name :
811773 return self .image_name
812774 else :
813- return create_image_uri (self .sagemaker_session .boto_region_name , self .__framework_name__ ,
814- self .train_instance_type , self .framework_version , py_version = self .py_version )
775+ return create_image_uri (self .sagemaker_session .boto_region_name ,
776+ self .__framework_name__ ,
777+ self .train_instance_type ,
778+ self .framework_version , # pylint: disable=no-member
779+ py_version = self .py_version ) # pylint: disable=no-member
815780
816781 @classmethod
817782 def attach (cls , training_job_name , sagemaker_session = None , model_channel_name = 'model' ):
@@ -844,7 +809,11 @@ def attach(cls, training_job_name, sagemaker_session=None, model_channel_name='m
844809 Instance of the calling ``Estimator`` Class with the attached training job.
845810 """
846811 estimator = super (Framework , cls ).attach (training_job_name , sagemaker_session , model_channel_name )
847- estimator .uploaded_code = UploadedCode (estimator .source_dir , estimator .entry_point )
812+
813+ # pylint gets confused thinking that estimator is an EstimatorBase instance, but it actually
814+ # is a Framework or any of its derived classes. We can safely ignore the no-member errors.
815+ estimator .uploaded_code = UploadedCode (
816+ estimator .source_dir , estimator .entry_point ) # pylint: disable=no-member
848817 return estimator
849818
850819 @staticmethod
0 commit comments