@@ -106,6 +106,8 @@ def train(self, input_data_config, output_data_config, hyperparameters, job_name
106106 data_dir = self ._create_tmp_folder ()
107107 volumes = self ._prepare_training_volumes (data_dir , input_data_config , output_data_config ,
108108 hyperparameters )
109+ # If local, source directory needs to be updated to mounted /opt/ml/code path
110+ hyperparameters = self ._update_local_src_path (hyperparameters , key = sagemaker .estimator .DIR_PARAM_NAME )
109111
110112 # Create the configuration files for each container that we will create
111113 # Each container will map the additional local volumes (if any).
@@ -169,6 +171,9 @@ def serve(self, model_dir, environment):
169171 parsed_uri = urlparse (script_dir )
170172 if parsed_uri .scheme == 'file' :
171173 volumes .append (_Volume (parsed_uri .path , '/opt/ml/code' ))
174+ # Update path to mount location
175+ environment = environment .copy ()
176+ environment [sagemaker .estimator .DIR_PARAM_NAME .upper ()] = '/opt/ml/code'
172177
173178 if _ecr_login_if_needed (self .sagemaker_session .boto_session , self .image ):
174179 _pull_image (self .image )
@@ -302,7 +307,7 @@ def _prepare_training_volumes(self, data_dir, input_data_config, output_data_con
302307 volumes .append (_Volume (data_source .get_root_dir (), channel = channel_name ))
303308
304309 # If there is a training script directory and it is a local directory,
305- # mount it to the container.
310+ # mount it to the container.
306311 if sagemaker .estimator .DIR_PARAM_NAME in hyperparameters :
307312 training_dir = json .loads (hyperparameters [sagemaker .estimator .DIR_PARAM_NAME ])
308313 parsed_uri = urlparse (training_dir )
@@ -321,6 +326,16 @@ def _prepare_training_volumes(self, data_dir, input_data_config, output_data_con
321326
322327 return volumes
323328
329+ def _update_local_src_path (self , params , key ):
330+ if key in params :
331+ src_dir = json .loads (params [key ])
332+ parsed_uri = urlparse (src_dir )
333+ if parsed_uri .scheme == 'file' :
334+ new_params = params .copy ()
335+ new_params [key ] = json .dumps ('/opt/ml/code' )
336+ return new_params
337+ return params
338+
324339 def _prepare_serving_volumes (self , model_location ):
325340 volumes = []
326341 host = self .hosts [0 ]
0 commit comments