@@ -205,6 +205,11 @@ class DataSinkInputSpec(DynamicTraitedSpec, BaseInterfaceInputSpec):
205205 'access' )
206206 encrypt_bucket_keys = traits .Bool (desc = 'Flag indicating whether to use S3 ' \
207207 'server-side AES-256 encryption' )
208+ # Set this if user wishes to override the bucket with their own
209+ bucket = traits .Generic (mandatory = False ,
210+ desc = 'Boto3 S3 bucket for manual override of bucket' )
211+ # Set this if user wishes to have local copy of files as well
212+ local_dir = traits .Str (desc = 'Copy files locally as well as to S3 bucket' )
208213
209214 # Set call-able inputs attributes
210215 def __setattr__ (self , key , value ):
@@ -385,7 +390,6 @@ def _check_s3_base_dir(self):
385390
386391 # Init variables
387392 s3_str = 's3://'
388- sep = os .path .sep
389393 base_directory = self .inputs .base_directory
390394
391395 # Explicitly lower-case the "s3"
@@ -396,11 +400,16 @@ def _check_s3_base_dir(self):
396400
397401 # Check if 's3://' in base dir
398402 if base_directory .startswith (s3_str ):
403+ # Attempt to access bucket
399404 try :
400405 # Expects bucket name to be 's3://bucket_name/base_dir/..'
401- bucket_name = base_directory .split (s3_str )[1 ].split (sep )[0 ]
406+ bucket_name = base_directory .split (s3_str )[1 ].split ('/' )[0 ]
402407 # Get the actual bucket object
403- self .bucket = self ._fetch_bucket (bucket_name )
408+ if self .inputs .bucket :
409+ self .bucket = self .inputs .bucket
410+ else :
411+ self .bucket = self ._fetch_bucket (bucket_name )
412+ # Report error in case of exception
404413 except Exception as exc :
405414 err_msg = 'Unable to access S3 bucket. Error:\n %s. Exiting...' \
406415 % exc
@@ -566,7 +575,7 @@ def _upload_to_s3(self, src, dst):
566575 bucket = self .bucket
567576 iflogger = logging .getLogger ('interface' )
568577 s3_str = 's3://'
569- s3_prefix = os . path . join ( s3_str , bucket .name )
578+ s3_prefix = s3_str + bucket .name
570579
571580 # Explicitly lower-case the "s3"
572581 if dst .lower ().startswith (s3_str ):
@@ -629,41 +638,53 @@ def _list_outputs(self):
629638 iflogger = logging .getLogger ('interface' )
630639 outputs = self .output_spec ().get ()
631640 out_files = []
632- outdir = self . inputs . base_directory
641+ # Use hardlink
633642 use_hardlink = str2bool (config .get ('execution' , 'try_hard_link_datasink' ))
634643
635- # If base directory isn't given, assume current directory
636- if not isdefined (outdir ):
637- outdir = '.'
644+ # Set local output directory if specified
645+ if isdefined (self .inputs .local_copy ):
646+ outdir = self .inputs .local_copy
647+ else :
648+ outdir = self .inputs .base_directory
649+ # If base directory isn't given, assume current directory
650+ if not isdefined (outdir ):
651+ outdir = '.'
638652
639- # Check if base directory reflects S3- bucket upload
653+ # Check if base directory reflects S3 bucket upload
640654 try :
641655 s3_flag = self ._check_s3_base_dir ()
656+ s3dir = self .inputs .base_directory
657+ if isdefined (self .inputs .container ):
658+ s3dir = os .path .join (s3dir , self .inputs .container )
642659 # If encountering an exception during bucket access, set output
643660 # base directory to a local folder
644661 except Exception as exc :
645- local_out_exception = os .path .join (os .path .expanduser ('~' ),
646- 'data_output' )
662+ if not isdefined (self .inputs .local_copy ):
663+ local_out_exception = os .path .join (os .path .expanduser ('~' ),
664+ 's3_datasink_' + self .bucket .name )
665+ outdir = local_out_exception
666+ else :
667+ outdir = self .inputs .local_copy
668+ # Log local copying directory
647669 iflogger .info ('Access to S3 failed! Storing outputs locally at: ' \
648- '%s\n Error: %s' % (local_out_exception , exc ))
649- self .inputs .base_directory = local_out_exception
650-
651- # If not accessing S3, just set outdir to local absolute path
652- if not s3_flag :
653- outdir = os .path .abspath (outdir )
670+ '%s\n Error: %s' % (outdir , exc ))
654671
655672 # If container input is given, append that to outdir
656673 if isdefined (self .inputs .container ):
657674 outdir = os .path .join (outdir , self .inputs .container )
658- # Create the directory if it doesn't exist
659- if not os .path .exists (outdir ):
660- try :
661- os .makedirs (outdir )
662- except OSError , inst :
663- if 'File exists' in inst :
664- pass
665- else :
666- raise (inst )
675+
676+ # If writing the output locally
677+ if not outdir .lower ().startswith ('s3://' ):
678+ outdir = os .path .abspath (outdir )
679+ # Create the directory if it doesn't exist
680+ if not os .path .exists (outdir ):
681+ try :
682+ os .makedirs (outdir )
683+ except OSError , inst :
684+ if 'File exists' in inst :
685+ pass
686+ else :
687+ raise (inst )
667688
668689 # Iterate through outputs attributes {key : path(s)}
669690 for key , files in self .inputs ._outputs .items ():
@@ -672,10 +693,14 @@ def _list_outputs(self):
672693 iflogger .debug ("key: %s files: %s" % (key , str (files )))
673694 files = filename_to_list (files )
674695 tempoutdir = outdir
696+ if s3_flag :
697+ s3tempoutdir = s3dir
675698 for d in key .split ('.' ):
676699 if d [0 ] == '@' :
677700 continue
678701 tempoutdir = os .path .join (tempoutdir , d )
702+ if s3_flag :
703+ s3tempoutdir = os .path .join (s3tempoutdir , d )
679704
680705 # flattening list
681706 if isinstance (files , list ):
@@ -690,25 +715,26 @@ def _list_outputs(self):
690715 src = os .path .join (src , '' )
691716 dst = self ._get_dst (src )
692717 dst = os .path .join (tempoutdir , dst )
718+ s3dst = os .path .join (s3tempoutdir , dst )
693719 dst = self ._substitute (dst )
694720 path , _ = os .path .split (dst )
695721
696- # Create output directory if it doesnt exist
697- if not os .path .exists (path ):
698- try :
699- os .makedirs (path )
700- except OSError , inst :
701- if 'File exists' in inst :
702- pass
703- else :
704- raise (inst )
705-
706722 # If we're uploading to S3
707723 if s3_flag :
724+ dst = dst .replace (outdir , self .inputs .base_directory )
708725 self ._upload_to_s3 (src , dst )
709726 out_files .append (dst )
710727 # Otherwise, copy locally src -> dst
711728 else :
729+ # Create output directory if it doesn't exist
730+ if not os .path .exists (path ):
731+ try :
732+ os .makedirs (path )
733+ except OSError , inst :
734+ if 'File exists' in inst :
735+ pass
736+ else :
737+ raise (inst )
712738 # If src is a file, copy it to dst
713739 if os .path .isfile (src ):
714740 iflogger .debug ('copyfile: %s %s' % (src , dst ))
0 commit comments