@@ -51,9 +51,9 @@ def wait(self):
5151 pass
5252
5353 @staticmethod
54- def _load_config (inputs , estimator ):
55- input_config = _Job ._format_inputs_to_input_config (inputs )
56- role = estimator .sagemaker_session .expand_role (estimator .role )
54+ def _load_config (inputs , estimator , expand_role = True , validate_uri = True ):
55+ input_config = _Job ._format_inputs_to_input_config (inputs , validate_uri )
56+ role = estimator .sagemaker_session .expand_role (estimator .role ) if expand_role else estimator . role
5757 output_config = _Job ._prepare_output_config (estimator .output_path , estimator .output_kms_key )
5858 resource_config = _Job ._prepare_resource_config (estimator .train_instance_count ,
5959 estimator .train_instance_type ,
@@ -62,7 +62,8 @@ def _load_config(inputs, estimator):
6262 stop_condition = _Job ._prepare_stop_condition (estimator .train_max_run )
6363 vpc_config = estimator .get_vpc_config ()
6464
65- model_channel = _Job ._prepare_model_channel (input_config , estimator .model_uri , estimator .model_channel_name )
65+ model_channel = _Job ._prepare_model_channel (input_config , estimator .model_uri , estimator .model_channel_name ,
66+ validate_uri )
6667 if model_channel :
6768 input_config = [] if input_config is None else input_config
6869 input_config .append (model_channel )
@@ -75,7 +76,7 @@ def _load_config(inputs, estimator):
7576 'vpc_config' : vpc_config }
7677
7778 @staticmethod
78- def _format_inputs_to_input_config (inputs ):
79+ def _format_inputs_to_input_config (inputs , validate_uri = True ):
7980 if inputs is None :
8081 return None
8182
@@ -86,14 +87,14 @@ def _format_inputs_to_input_config(inputs):
8687
8788 input_dict = {}
8889 if isinstance (inputs , string_types ):
89- input_dict ['training' ] = _Job ._format_string_uri_input (inputs )
90+ input_dict ['training' ] = _Job ._format_string_uri_input (inputs , validate_uri )
9091 elif isinstance (inputs , s3_input ):
9192 input_dict ['training' ] = inputs
9293 elif isinstance (inputs , file_input ):
9394 input_dict ['training' ] = inputs
9495 elif isinstance (inputs , dict ):
9596 for k , v in inputs .items ():
96- input_dict [k ] = _Job ._format_string_uri_input (v )
97+ input_dict [k ] = _Job ._format_string_uri_input (v , validate_uri )
9798 elif isinstance (inputs , list ):
9899 input_dict = _Job ._format_record_set_list_input (inputs )
99100 else :
@@ -111,15 +112,16 @@ def _convert_input_to_channel(channel_name, channel_s3_input):
111112 return channel_config
112113
113114 @staticmethod
114- def _format_string_uri_input (uri_input ):
115- if isinstance (uri_input , str ):
116- if uri_input .startswith ('s3://' ):
117- return s3_input (uri_input )
118- elif uri_input .startswith ('file://' ):
119- return file_input (uri_input )
120- else :
121- raise ValueError ('Training input data must be a valid S3 or FILE URI: must start with "s3://" or '
122- '"file://"' )
115+ def _format_string_uri_input (uri_input , validate_uri = True ):
116+ if isinstance (uri_input , str ) and validate_uri and uri_input .startswith ('s3://' ):
117+ return s3_input (uri_input )
118+ elif isinstance (uri_input , str ) and validate_uri and uri_input .startswith ('file://' ):
119+ return file_input (uri_input )
120+ elif isinstance (uri_input , str ) and validate_uri :
121+ raise ValueError ('Training input data must be a valid S3 or FILE URI: must start with "s3://" or '
122+ '"file://"' )
123+ elif isinstance (uri_input , str ):
124+ return s3_input (uri_input )
123125 elif isinstance (uri_input , s3_input ):
124126 return uri_input
125127 elif isinstance (uri_input , file_input ):
@@ -128,7 +130,7 @@ def _format_string_uri_input(uri_input):
128130 raise ValueError ('Cannot format input {}. Expecting one of str, s3_input, or file_input' .format (uri_input ))
129131
130132 @staticmethod
131- def _prepare_model_channel (input_config , model_uri = None , model_channel_name = None ):
133+ def _prepare_model_channel (input_config , model_uri = None , model_channel_name = None , validate_uri = True ):
132134 if not model_uri :
133135 return
134136 elif not model_channel_name :
@@ -139,22 +141,24 @@ def _prepare_model_channel(input_config, model_uri=None, model_channel_name=None
139141 if channel ['ChannelName' ] == model_channel_name :
140142 raise ValueError ('Duplicate channels not allowed.' )
141143
142- model_input = _Job ._format_model_uri_input (model_uri )
144+ model_input = _Job ._format_model_uri_input (model_uri , validate_uri )
143145 model_channel = _Job ._convert_input_to_channel (model_channel_name , model_input )
144146
145147 return model_channel
146148
147149 @staticmethod
def _format_model_uri_input(model_uri, validate_uri=True):
    """Convert a model URI string into the input object used for the model channel.

    Args:
        model_uri: Location of the model artifacts. Must be a string
            (checked against ``string_types``).
        validate_uri (bool): When True (default), the URI scheme is checked:
            ``s3://`` URIs are wrapped in ``s3_input`` and ``file://`` URIs in
            ``file_input``; anything else is rejected. When False, no scheme
            check is performed and the string is always wrapped in
            ``s3_input`` (even if it starts with ``file://``).

    Returns:
        An ``s3_input`` built with ``input_mode='File'``,
        ``distribution='FullyReplicated'`` and
        ``content_type='application/x-sagemaker-model'``, or a ``file_input``
        for validated ``file://`` URIs.

    Raises:
        ValueError: If ``model_uri`` is not a string, or if ``validate_uri``
            is True and the URI starts with neither ``s3://`` nor ``file://``.
    """
    if not isinstance(model_uri, string_types):
        raise ValueError('Cannot format model URI {}. Expecting str'.format(model_uri))

    if validate_uri:
        if model_uri.startswith('file://'):
            return file_input(model_uri)
        if not model_uri.startswith('s3://'):
            # The closing double quote after file:// was missing in the original message.
            raise ValueError('Model URI must be a valid S3 or FILE URI: must start with "s3://" or '
                             '"file://"')

    # Reached for validated s3:// URIs and for every string when validation is
    # disabled; both cases build the same S3 model channel input.
    return s3_input(model_uri, input_mode='File', distribution='FullyReplicated',
                    content_type='application/x-sagemaker-model')
160164
0 commit comments