@@ -107,8 +107,8 @@ def __init__(
107107 """
108108 self ._default_bucket = None
109109 self ._default_bucket_name_override = default_bucket
110-
111- # currently is used for local_code in local mode
110+ self . s3_resource = None
111+ self . s3_client = None
112112 self .config = None
113113
114114 self ._initialize (
@@ -199,7 +199,10 @@ def upload_data(self, path, bucket=None, key_prefix="data", extra_args=None):
199199 key_suffix = name
200200
201201 bucket = bucket or self .default_bucket ()
202- s3 = self .boto_session .resource ("s3" )
202+ if self .s3_resource is None :
203+ s3 = self .boto_session .resource ("s3" , region_name = self .boto_region_name )
204+ else :
205+ s3 = self .s3_resource
203206
204207 for local_path , s3_key in files :
205208 s3 .Object (bucket , s3_key ).upload_file (local_path , ExtraArgs = extra_args )
@@ -227,7 +230,11 @@ def upload_string_as_file_body(self, body, bucket, key, kms_key=None):
227230 str: The S3 URI of the uploaded file.
228231 The URI format is: ``s3://{bucket name}/{key}``.
229232 """
230- s3 = self .boto_session .resource ("s3" )
233+ if self .s3_resource is None :
234+ s3 = self .boto_session .resource ("s3" , region_name = self .boto_region_name )
235+ else :
236+ s3 = self .s3_resource
237+
231238 s3_object = s3 .Object (bucket_name = bucket , key = key )
232239
233240 if kms_key is not None :
@@ -254,7 +261,10 @@ def download_data(self, path, bucket, key_prefix="", extra_args=None):
254261
255262 """
256263 # Initialize the S3 client.
257- s3 = self .boto_session .client ("s3" )
264+ if self .s3_client is None :
265+ s3 = self .boto_session .client ("s3" , region_name = self .boto_region_name )
266+ else :
267+ s3 = self .s3_client
258268
259269 # Initialize the variables used to loop through the contents of the S3 bucket.
260270 keys = []
@@ -299,7 +309,10 @@ def read_s3_file(self, bucket, key_prefix):
299309 str: The body of the s3 file as a string.
300310
301311 """
302- s3 = self .boto_session .client ("s3" )
312+ if self .s3_client is None :
313+ s3 = self .boto_session .client ("s3" , region_name = self .boto_region_name )
314+ else :
315+ s3 = self .s3_client
303316
304317 # Explicitly passing a None kms_key to boto3 throws a validation error.
305318 s3_object = s3 .get_object (Bucket = bucket , Key = key_prefix )
@@ -317,7 +330,10 @@ def list_s3_files(self, bucket, key_prefix):
317330 [str]: The list of files at the S3 path.
318331
319332 """
320- s3 = self .boto_session .resource ("s3" )
333+ if self .s3_resource is None :
334+ s3 = self .boto_session .resource ("s3" , region_name = self .boto_region_name )
335+ else :
336+ s3 = self .s3_resource
321337
322338 s3_bucket = s3 .Bucket (name = bucket )
323339 s3_objects = s3_bucket .objects .filter (Prefix = key_prefix ).all ()
@@ -330,6 +346,7 @@ def default_bucket(self):
330346 str: The name of the default bucket, which is of the form:
331347 ``sagemaker-{region}-{AWS account ID}``.
332348 """
349+
333350 if self ._default_bucket :
334351 return self ._default_bucket
335352
@@ -364,10 +381,14 @@ def _create_s3_bucket_if_it_does_not_exist(self, bucket_name, region):
364381 already being created, no exception is raised.
365382
366383 """
367- bucket = self .boto_session .resource ("s3" , region_name = region ).Bucket (name = bucket_name )
384+ if self .s3_resource is None :
385+ s3 = self .boto_session .resource ("s3" , region_name = region )
386+ else :
387+ s3 = self .s3_resource
388+
389+ bucket = s3 .Bucket (name = bucket_name )
368390 if bucket .creation_date is None :
369391 try :
370- s3 = self .boto_session .resource ("s3" , region_name = region )
371392 if region == "us-east-1" :
372393 # 'us-east-1' cannot be specified because it is the default region:
373394 # https://github.com/boto/boto3/issues/125
0 commit comments