1111
1212from six .moves .queue import Queue
1313
14- from .copy import copy_file_internal
14+ from .copy import copy_file_internal , copy_modified_time
1515from .errors import BulkCopyFailed
16+ from .tools import copy_file_data
1617
1718if typing .TYPE_CHECKING :
1819 from .base import FS
1920 from types import TracebackType
20- from typing import List , Optional , Text , Type
21+ from typing import List , Optional , Text , Type , IO , Tuple
2122
2223
2324class _Worker (threading .Thread ):
@@ -55,40 +56,32 @@ def __call__(self):
5556class _CopyTask (_Task ):
5657 """A callable that copies from one file another."""
5758
58- def __init__ (
59- self ,
60- src_fs , # type: FS
61- src_path , # type: Text
62- dst_fs , # type: FS
63- dst_path , # type: Text
64- preserve_time , # type: bool
65- ):
66- # type: (...) -> None
67- self .src_fs = src_fs
68- self .src_path = src_path
69- self .dst_fs = dst_fs
70- self .dst_path = dst_path
71- self .preserve_time = preserve_time
59+ def __init__ (self , src_file , dst_file ):
60+ # type: (IO, IO) -> None
61+ self .src_file = src_file
62+ self .dst_file = dst_file
7263
7364 def __call__ (self ):
7465 # type: () -> None
75- copy_file_internal (
76- self .src_fs ,
77- self . src_path ,
78- self . dst_fs ,
79- self .dst_path ,
80- preserve_time = self . preserve_time ,
81- )
66+ try :
67+ copy_file_data ( self .src_file , self . dst_file , chunk_size = 1024 * 1024 )
68+ finally :
69+ try :
70+ self .src_file . close ()
71+ finally :
72+ self . dst_file . close ( )
8273
8374
8475class Copier (object ):
8576 """Copy files in worker threads."""
8677
87- def __init__ (self , num_workers = 4 ):
88- # type: (int) -> None
78+ def __init__ (self , num_workers = 4 , preserve_time = False ):
79+ # type: (int, bool ) -> None
8980 if num_workers < 0 :
9081 raise ValueError ("num_workers must be >= 0" )
9182 self .num_workers = num_workers
83+ self .preserve_time = preserve_time
84+ self .all_tasks = [] # type: List[Tuple[FS, Text, FS, Text]]
9285 self .queue = None # type: Optional[Queue[_Task]]
9386 self .workers = [] # type: List[_Worker]
9487 self .errors = [] # type: List[Exception]
@@ -97,7 +90,7 @@ def __init__(self, num_workers=4):
9790 def start (self ):
9891 """Start the workers."""
9992 if self .num_workers :
100- self .queue = Queue ()
93+ self .queue = Queue (maxsize = self . num_workers )
10194 self .workers = [_Worker (self ) for _ in range (self .num_workers )]
10295 for worker in self .workers :
10396 worker .start ()
@@ -106,10 +99,18 @@ def start(self):
10699 def stop (self ):
107100 """Stop the workers (will block until they are finished)."""
108101 if self .running and self .num_workers :
102+ # Notify the workers that all tasks have arrived
103+ # and wait for them to finish.
109104 for _worker in self .workers :
110105 self .queue .put (None )
111106 for worker in self .workers :
112107 worker .join ()
108+
109+ # If the "last modified" time is to be preserved, do it now.
110+ if self .preserve_time :
111+ for args in self .all_tasks :
112+ copy_modified_time (* args )
113+
113114 # Free up references held by workers
114115 del self .workers [:]
115116 self .queue .join ()
@@ -139,8 +140,15 @@ def copy(self, src_fs, src_path, dst_fs, dst_path, preserve_time=False):
139140 if self .queue is None :
140141 # This should be the most performant for a single-thread
141142 copy_file_internal (
142- src_fs , src_path , dst_fs , dst_path , preserve_time = preserve_time
143+ src_fs , src_path , dst_fs , dst_path , preserve_time = self . preserve_time
143144 )
144145 else :
145- task = _CopyTask (src_fs , src_path , dst_fs , dst_path , preserve_time )
146+ self .all_tasks .append ((src_fs , src_path , dst_fs , dst_path ))
147+ src_file = src_fs .openbin (src_path , "r" )
148+ try :
149+ dst_file = dst_fs .openbin (dst_path , "w" )
150+ except Exception :
151+ src_file .close ()
152+ raise
153+ task = _CopyTask (src_file , dst_file )
146154 self .queue .put (task )
0 commit comments