 import yaml
 
 import sagemaker
+import sagemaker.local.data
+import sagemaker.local.utils
+import sagemaker.utils
 
 CONTAINER_PREFIX = 'algo'
 DOCKER_COMPOSE_FILENAME = 'docker-compose.yaml'
@@ -78,7 +81,7 @@ def __init__(self, instance_type, instance_count, image, sagemaker_session=None)
         self.container_root = None
         self.container = None
 
-    def train(self, input_data_config, hyperparameters, job_name):
+    def train(self, input_data_config, output_data_config, hyperparameters, job_name):
         """Run a training job locally using docker-compose.
         Args:
             input_data_config (dict): The Input Data Configuration, this contains data such as the
@@ -126,23 +129,17 @@ def train(self, input_data_config, hyperparameters, job_name):
             msg = "Failed to run: %s, %s" % (compose_command, str(e))
             raise RuntimeError(msg)
 
-        s3_artifacts = self.retrieve_artifacts(compose_data)
+        artifacts = self.retrieve_artifacts(compose_data, output_data_config, job_name)
 
         # free up the training data directory as it may contain
         # lots of data downloaded from S3. This doesn't delete any local
         # data that was just mounted to the container.
-        _delete_tree(data_dir)
-        _delete_tree(shared_dir)
-        # Also free the container config files.
-        for host in self.hosts:
-            container_config_path = os.path.join(self.container_root, host)
-            _delete_tree(container_config_path)
-
-        self._cleanup()
-        # Print our Job Complete line to have a simmilar experience to training on SageMaker where you
+        dirs_to_delete = [data_dir, shared_dir]
+        self._cleanup(dirs_to_delete)
+        # Print our Job Complete line to have a similar experience to training on SageMaker where you
         # see this line at the end.
         print('===== Job Complete =====')
-        return s3_artifacts
+        return artifacts
 
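Note: with the changes above, train() now threads the output configuration through and returns an artifacts URI rather than a bare model directory. A minimal sketch of the new call shape, where `container` stands in for a _SageMakerContainer instance and every value is illustrative:

    output_data_config = {'S3OutputPath': ''}   # empty path keeps artifacts on the local filesystem
    artifacts = container.train(input_data_config, output_data_config,
                                hyperparameters, job_name='local-training-job')
    # 'artifacts' now ends in model.tar.gz (a local file path or an S3 URI),
    # as produced by retrieve_artifacts() below.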
     def serve(self, model_dir, environment):
         """Host a local endpoint using docker-compose.
@@ -188,7 +185,7 @@ def stop_serving(self):
         # for serving we can delete everything in the container root.
         _delete_tree(self.container_root)
 
-    def retrieve_artifacts(self, compose_data):
+    def retrieve_artifacts(self, compose_data, output_data_config, job_name):
         """Get the model artifacts from all the container nodes.
 
         Used after training completes to gather the data from all the individual containers. As the
@@ -201,26 +198,49 @@ def retrieve_artifacts(self, compose_data):
         Returns: Local path to the collected model artifacts.
 
         """
-        # Grab the model artifacts from all the Nodes.
-        s3_artifacts = os.path.join(self.container_root, 's3_artifacts')
-        os.mkdir(s3_artifacts)
+        # We need a directory to store the artifacts from all the nodes
+        # and another one to contain the compressed final artifacts
+        artifacts = os.path.join(self.container_root, 'artifacts')
+        compressed_artifacts = os.path.join(self.container_root, 'compressed_artifacts')
+        os.mkdir(artifacts)
+
+        model_artifacts = os.path.join(artifacts, 'model')
+        output_artifacts = os.path.join(artifacts, 'output')
 
-        s3_model_artifacts = os.path.join(s3_artifacts, 'model')
-        s3_output_artifacts = os.path.join(s3_artifacts, 'output')
-        os.mkdir(s3_model_artifacts)
-        os.mkdir(s3_output_artifacts)
+        artifact_dirs = [model_artifacts, output_artifacts, compressed_artifacts]
+        for d in artifact_dirs:
+            os.mkdir(d)
 
+        # Gather the artifacts from all nodes into artifacts/model and artifacts/output
         for host in self.hosts:
             volumes = compose_data['services'][str(host)]['volumes']
-
             for volume in volumes:
                 host_dir, container_dir = volume.split(':')
                 if container_dir == '/opt/ml/model':
-                    sagemaker.local.utils.recursive_copy(host_dir, s3_model_artifacts)
+                    sagemaker.local.utils.recursive_copy(host_dir, model_artifacts)
                 elif container_dir == '/opt/ml/output':
-                    sagemaker.local.utils.recursive_copy(host_dir, s3_output_artifacts)
+                    sagemaker.local.utils.recursive_copy(host_dir, output_artifacts)
 
-        return s3_model_artifacts
+        # Tar Artifacts -> model.tar.gz and output.tar.gz
+        model_files = [os.path.join(model_artifacts, name) for name in os.listdir(model_artifacts)]
+        output_files = [os.path.join(output_artifacts, name) for name in os.listdir(output_artifacts)]
+        sagemaker.utils.create_tar_file(model_files, os.path.join(compressed_artifacts, 'model.tar.gz'))
+        sagemaker.utils.create_tar_file(output_files, os.path.join(compressed_artifacts, 'output.tar.gz'))
+
+        if output_data_config['S3OutputPath'] == '':
+            output_data = 'file://%s' % compressed_artifacts
+        else:
+            # Now we just need to move the compressed artifacts to wherever they are required
+            output_data = sagemaker.local.utils.move_to_destination(
+                compressed_artifacts,
+                output_data_config['S3OutputPath'],
+                job_name,
+                self.sagemaker_session)
+
+        _delete_tree(model_artifacts)
+        _delete_tree(output_artifacts)
+
+        return os.path.join(output_data, 'model.tar.gz')
 
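Both branches above yield the same shape of result; a hedged sketch of consuming it, assuming the variables from train() are in scope and that move_to_destination() returns the destination it copied to, as its use here implies:

    from urllib.parse import urlparse

    artifacts = self.retrieve_artifacts(compose_data, output_data_config, job_name)
    parsed = urlparse(artifacts)
    if parsed.scheme in ('', 'file'):
        model_tarball = parsed.path   # <container_root>/compressed_artifacts/model.tar.gz
    else:
        model_tarball = artifacts     # s3://... destination chosen by move_to_destination()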
     def write_config_files(self, host, hyperparameters, input_data_config):
         """Write the config files for the training containers.
@@ -235,17 +255,22 @@ def write_config_files(self, host, hyperparameters, input_data_config):
         Returns: None
 
         """
-
         config_path = os.path.join(self.container_root, host, 'input', 'config')
 
         resource_config = {
             'current_host': host,
             'hosts': self.hosts
         }
 
-        json_input_data_config = {
-            c['ChannelName']: {'ContentType': 'application/octet-stream'} for c in input_data_config
-        }
+        print(input_data_config)
+        json_input_data_config = {}
+        for c in input_data_config:
+            channel_name = c['ChannelName']
+            json_input_data_config[channel_name] = {
+                'TrainingInputMode': 'File'
+            }
+            if 'ContentType' in c:
+                json_input_data_config[channel_name]['ContentType'] = c['ContentType']
 
         _write_json_file(os.path.join(config_path, 'hyperparameters.json'), hyperparameters)
         _write_json_file(os.path.join(config_path, 'resourceconfig.json'), resource_config)
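To make the new loop concrete, this is the inputdataconfig.json content it would produce for one hypothetical channel (the old dict comprehension hard-coded 'application/octet-stream' instead):

    input_data_config = [{'ChannelName': 'training', 'ContentType': 'text/csv'}]
    # The loop above yields:
    #   {'training': {'TrainingInputMode': 'File', 'ContentType': 'text/csv'}}
    # A channel without ContentType gets just {'TrainingInputMode': 'File'}.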
@@ -261,29 +286,13 @@ def _prepare_training_volumes(self, data_dir, input_data_config, hyperparameters
         # mount the local directory to the container. For S3 Data we will download the S3 data
         # first.
         for channel in input_data_config:
-            if channel['DataSource'] and 'S3DataSource' in channel['DataSource']:
-                uri = channel['DataSource']['S3DataSource']['S3Uri']
-            elif channel['DataSource'] and 'FileDataSource' in channel['DataSource']:
-                uri = channel['DataSource']['FileDataSource']['FileUri']
-            else:
-                raise ValueError('Need channel[\'DataSource\'] to have'
-                                 ' [\'S3DataSource\'] or [\'FileDataSource\']')
-
-            parsed_uri = urlparse(uri)
-            key = parsed_uri.path.lstrip('/')
-
+            uri = channel['DataUri']
             channel_name = channel['ChannelName']
             channel_dir = os.path.join(data_dir, channel_name)
             os.mkdir(channel_dir)
 
-            if parsed_uri.scheme == 's3':
-                bucket_name = parsed_uri.netloc
-                sagemaker.utils.download_folder(bucket_name, key, channel_dir, self.sagemaker_session)
-            elif parsed_uri.scheme == 'file':
-                path = parsed_uri.path
-                volumes.append(_Volume(path, channel=channel_name))
-            else:
-                raise ValueError('Unknown URI scheme {}'.format(parsed_uri.scheme))
+            data_source = sagemaker.local.data.get_data_source_instance(uri, self.sagemaker_session)
+            volumes.append(_Volume(data_source.get_root_dir(), channel=channel_name))
 
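The channel dicts this method receives are now expected to carry a plain DataUri, with scheme handling delegated to sagemaker.local.data; a small before/after sketch with made-up values (sagemaker_session assumed to be available):

    # Old shape, parsed inline with urlparse:
    #   {'ChannelName': 'training',
    #    'DataSource': {'S3DataSource': {'S3Uri': 's3://my-bucket/train'}}}
    # New shape handled above:
    channel = {'ChannelName': 'training', 'DataUri': 'file:///tmp/train-data'}
    data_source = sagemaker.local.data.get_data_source_instance(channel['DataUri'], sagemaker_session)
    volume = _Volume(data_source.get_root_dir(), channel=channel['ChannelName'])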
         # If there is a training script directory and it is a local directory,
         # mount it to the container.
@@ -301,25 +310,20 @@ def _prepare_serving_volumes(self, model_location):
         volumes = []
         host = self.hosts[0]
         # Make the model available to the container. If this is a local file just mount it to
-        # the container as a volume. If it is an S3 location download it and extract the tar file.
+        # the container as a volume. If it is an S3 location, the DataSource will download it, we
+        # just need to extract the tar file.
         host_dir = os.path.join(self.container_root, host)
         os.makedirs(host_dir)
 
-        if model_location.startswith('s3'):
-            container_model_dir = os.path.join(self.container_root, host, 'model')
-            os.makedirs(container_model_dir)
+        model_data_source = sagemaker.local.data.get_data_source_instance(
+            model_location, self.sagemaker_session)
 
-            parsed_uri = urlparse(model_location)
-            filename = os.path.basename(parsed_uri.path)
-            tar_location = os.path.join(container_model_dir, filename)
-            sagemaker.utils.download_file(parsed_uri.netloc, parsed_uri.path, tar_location, self.sagemaker_session)
+        for filename in model_data_source.get_file_list():
+            if tarfile.is_tarfile(filename):
+                with tarfile.open(filename) as tar:
+                    tar.extractall(path=model_data_source.get_root_dir())
 
-            if tarfile.is_tarfile(tar_location):
-                with tarfile.open(tar_location) as tar:
-                    tar.extractall(path=container_model_dir)
-            volumes.append(_Volume(container_model_dir, '/opt/ml/model'))
-        else:
-            volumes.append(_Volume(model_location, '/opt/ml/model'))
+        volumes.append(_Volume(model_data_source.get_root_dir(), '/opt/ml/model'))
 
         return volumes
 
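The same DataSource abstraction now covers both URI schemes for serving; a hedged sketch of the behavior, with illustrative URIs and a sagemaker_session assumed to be in scope:

    import sagemaker.local.data

    for model_location in ('s3://my-bucket/model/model.tar.gz', 'file:///tmp/my_model'):
        source = sagemaker.local.data.get_data_source_instance(model_location, sagemaker_session)
        files = source.get_file_list()   # local paths; S3 data has already been downloaded
        root = source.get_root_dir()     # this directory ends up mounted at /opt/ml/model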
@@ -368,7 +372,6 @@ def _generate_compose_file(self, command, additional_volumes=None, additional_en
             'networks': {
                 'sagemaker-local': {'name': 'sagemaker-local'}
             }
-
         }
 
         docker_compose_path = os.path.join(self.container_root, DOCKER_COMPOSE_FILENAME)
@@ -469,9 +472,15 @@ def _build_optml_volumes(self, host, subdirs):
 
         return volumes
 
-    def _cleanup(self):
-        # we don't need to cleanup anything at the moment
-        pass
+    def _cleanup(self, dirs_to_delete=None):
+        if dirs_to_delete:
+            for d in dirs_to_delete:
+                _delete_tree(d)
+
+        # Free the container config files.
+        for host in self.hosts:
+            container_config_path = os.path.join(self.container_root, host)
+            _delete_tree(container_config_path)
 
 
 class _HostingContainer(Thread):
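A brief illustration of the widened _cleanup() contract, matching the call made from train() above and the default added here (directory names are whatever train() created):

    self._cleanup(dirs_to_delete=[data_dir, shared_dir])   # training path: data dirs + per-host config dirs
    self._cleanup()                                         # no argument: only the per-host config dirs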
@@ -610,7 +619,7 @@ def _aws_credentials(session):
                 'AWS_SECRET_ACCESS_KEY=%s' % (str(secret_key))
             ]
         elif not _aws_credentials_available_in_metadata_service():
-            logger.warn("Using the short-lived AWS credentials found in session. They might expire while running.")
+            logger.warning("Using the short-lived AWS credentials found in session. They might expire while running.")
             return [
                 'AWS_ACCESS_KEY_ID=%s' % (str(access_key)),
                 'AWS_SECRET_ACCESS_KEY=%s' % (str(secret_key)),