@@ -74,8 +74,9 @@ def fit(self, inputs=None, wait=True, logs=True, job_name=None):
7474 """Create an AutoML Job with the input dataset.
7575
7676 Args:
77- inputs (str or list[str]): Local path or S3 Uri where the training data is stored. If a
78- local path is provided, the dataset will be uploaded to an S3 location.
77+ inputs (str or list[str] or AutoMLInput): Local path or S3 Uri where the training data
78+ is stored. Or an AutoMLInput object. If a local path is provided, the dataset will
79+ be uploaded to an S3 location.
7980 wait (bool): Whether the call should wait until the job completes (default: True).
8081 logs (bool): Whether to show the logs produced by the job.
8182 Only meaningful when wait is True (default: True).
@@ -95,7 +96,7 @@ def fit(self, inputs=None, wait=True, logs=True, job_name=None):
9596 inputs = self .sagemaker_session .upload_data (inputs , key_prefix = "auto-ml-input-data" )
9697 self ._prepare_for_auto_ml_job (job_name = job_name )
9798
98- self .latest_auto_ml_job = _AutoMLJob .start_new (self , inputs ) # pylint: disable=W0201
99+ self .latest_auto_ml_job = AutoMLJob .start_new (self , inputs ) # pylint: disable=W0201
99100 if wait :
100101 self .latest_auto_ml_job .wait (logs = logs )
101102
@@ -385,9 +386,48 @@ def _prepare_for_auto_ml_job(self, job_name=None):
385386 self .output_path = "s3://{}/" .format (self .sagemaker_session .default_bucket ())
386387
387388
388- class _AutoMLJob (_Job ):
389+ class AutoMLInput (object ):
390+ """Accepts parameters that specify an S3 input for an auto ml job and provides
391+ a method to turn those parameters into a dictionary."""
392+
393+ def __init__ (self , inputs , target_attribute_name , compression = None ):
394+ """Convert an S3 Uri or a list of S3 Uri to an AutoMLInput object.
395+
396+ :param inputs (str, list[str]): a string or a list of string that points to (a)
397+ S3 location(s) where input data is stored.
398+ :param target_attribute_name (str): the target attribute name for regression
399+ or classification.
400+ :param compression (str): if training data is compressed, the compression type.
401+ The default value is None.
402+ """
403+ self .inputs = inputs
404+ self .target_attribute_name = target_attribute_name
405+ self .compression = compression
406+
407+ def to_request_dict (self ):
408+ """Generates a request dictionary using the parameters provided to the class."""
409+ # Create the request dictionary.
410+ auto_ml_input = []
411+ if isinstance (self .inputs , string_types ):
412+ self .inputs = [self .inputs ]
413+ for entry in self .inputs :
414+ input_entry = {
415+ "DataSource" : {"S3DataSource" : {"S3DataType" : "S3Prefix" , "S3Uri" : entry }},
416+ "TargetAttributeName" : self .target_attribute_name ,
417+ }
418+ if self .compression is not None :
419+ input_entry ["CompressionType" ] = self .compression
420+ auto_ml_input .append (input_entry )
421+ return auto_ml_input
422+
423+
424+ class AutoMLJob (_Job ):
389425 """A class for interacting with CreateAutoMLJob API."""
390426
427+ def __init__ (self , sagemaker_session , job_name , inputs ):
428+ self .inputs = inputs
429+ super (AutoMLJob , self ).__init__ (sagemaker_session = sagemaker_session , job_name = job_name )
430+
391431 @classmethod
392432 def start_new (cls , auto_ml , inputs ):
393433 """Create a new Amazon SageMaker AutoML job from auto_ml.
@@ -399,7 +439,7 @@ def start_new(cls, auto_ml, inputs):
399439 :meth:`~sagemaker.automl.AutoML.fit`.
400440
401441 Returns:
402- sagemaker.automl._AutoMLJob : Constructed object that captures
442+ sagemaker.automl.AutoMLJob : Constructed object that captures
403443 all information about the started AutoML job.
404444 """
405445 config = cls ._load_config (inputs , auto_ml )
@@ -410,7 +450,7 @@ def start_new(cls, auto_ml, inputs):
410450 auto_ml_args ["tags" ] = auto_ml .tags
411451
412452 auto_ml .sagemaker_session .auto_ml (** auto_ml_args )
413- return cls (auto_ml .sagemaker_session , auto_ml ._current_job_name )
453+ return cls (auto_ml .sagemaker_session , auto_ml ._current_job_name , inputs )
414454
415455 @classmethod
416456 def _load_config (cls , inputs , auto_ml , expand_role = True , validate_uri = True ):
@@ -432,9 +472,12 @@ def _load_config(cls, inputs, auto_ml, expand_role=True, validate_uri=True):
432472 # InputDataConfig
433473 # OutputConfig
434474
435- input_config = cls ._format_inputs_to_input_config (
436- inputs , validate_uri , auto_ml .compression_type , auto_ml .target_attribute_name
437- )
475+ if isinstance (inputs , AutoMLInput ):
476+ input_config = inputs .to_request_dict ()
477+ else :
478+ input_config = cls ._format_inputs_to_input_config (
479+ inputs , validate_uri , auto_ml .compression_type , auto_ml .target_attribute_name
480+ )
438481 output_config = _Job ._prepare_output_config (auto_ml .output_path , auto_ml .output_kms_key )
439482
440483 role = auto_ml .sagemaker_session .expand_role (auto_ml .role ) if expand_role else auto_ml .role
@@ -486,7 +529,9 @@ def _format_inputs_to_input_config(
486529 return None
487530
488531 channels = []
489- if isinstance (inputs , string_types ):
532+ if isinstance (inputs , AutoMLInput ):
533+ channels .append (inputs .to_request_dict ())
534+ elif isinstance (inputs , string_types ):
490535 channel = _Job ._format_string_uri_input (
491536 inputs ,
492537 validate_uri ,
@@ -540,6 +585,10 @@ def _prepare_auto_ml_stop_condition(
540585
541586 return stopping_condition
542587
588+ def describe (self ):
589+ """Prints out a response from the DescribeAutoMLJob API call."""
590+ return self .sagemaker_session .describe_auto_ml_job (self .job_name )
591+
543592 def wait (self , logs = True ):
544593 """Wait for the AutoML job to finish.
545594 Args:
0 commit comments