22import sys
33from multiprocessing import Pool , cpu_count , Manager
44import time
5- from typing import Callable , List , Any
6- from threading import Thread , Event , Lock
5+ from typing import Callable , List , Any , Union
6+ from threading import Lock , Thread , Event
77import shutil
88
9+ from .runner_arguments import RunnerArguments , AdditionalSwiftSourcesArguments
10+
11+
912class MonitoredFunction :
10- def __init__ (self , fn : Callable , running_tasks : ListProxy , updated_repos : ValueProxy , lock : Lock ):
13+ def __init__ (
14+ self ,
15+ fn : Callable ,
16+ running_tasks : ListProxy ,
17+ updated_repos : ValueProxy ,
18+ lock : Lock
19+ ):
1120 self .fn = fn
1221 self .running_tasks = running_tasks
1322 self .updated_repos = updated_repos
1423 self ._lock = lock
1524
16- def __call__ (self , * args ):
17- task_name = args [0 ][ 2 ]
25+ def __call__ (self , * args : Union [ RunnerArguments , AdditionalSwiftSourcesArguments ] ):
26+ task_name = args [0 ]. repo_name
1827 self .running_tasks .append (task_name )
28+ result = None
1929 try :
20- return self .fn (* args )
30+ result = self .fn (* args )
31+ except Exception as e :
32+ print (e )
2133 finally :
2234 self ._lock .acquire ()
23- self .running_tasks .remove (task_name )
35+ if task_name in self .running_tasks :
36+ self .running_tasks .remove (task_name )
2437 self .updated_repos .set (self .updated_repos .get () + 1 )
2538 self ._lock .release ()
39+ return result
2640
2741
2842class ParallelRunner :
29- def __init__ (self , fn : Callable , pool_args : List [List [Any ]], n_processes : int = 0 ):
43+ def __init__ (
44+ self ,
45+ fn : Callable ,
46+ pool_args : List [Union [RunnerArguments , AdditionalSwiftSourcesArguments ]],
47+ n_processes : int = 0 ,
48+ ):
3049 self ._monitor_polling_period = 0.1
3150 if n_processes == 0 :
3251 n_processes = cpu_count () * 2
3352 self ._terminal_width = shutil .get_terminal_size ().columns
3453 self ._n_processes = n_processes
3554 self ._pool_args = pool_args
55+ manager = Manager ()
56+ self ._lock = manager .Lock ()
57+ self ._running_tasks = manager .list ()
58+ self ._updated_repos = manager .Value ("i" , 0 )
3659 self ._fn = fn
37- self ._lock = Manager ().Lock ()
38- self ._pool = Pool (
39- processes = self ._n_processes , initializer = self ._child_init , initargs = (self ._lock ,)
40- )
41- self ._verbose = pool_args [0 ][len (pool_args [0 ]) - 1 ]
60+ self ._pool = Pool (processes = self ._n_processes )
61+ self ._verbose = pool_args [0 ].verbose
62+ self ._output_prefix = pool_args [0 ].output_prefix
4263 self ._nb_repos = len (pool_args )
4364 self ._stop_event = Event ()
44- self ._running_tasks = Manager (). list ()
45- self ._updated_repos = Manager (). Value ( 'i' , 0 )
46- self . _monitored_fn = MonitoredFunction ( self . _fn , self . _running_tasks , self . _updated_repos , self . _lock )
65+ self ._monitored_fn = MonitoredFunction (
66+ self ._fn , self . _running_tasks , self . _updated_repos , self . _lock
67+ )
4768
4869 def run (self ) -> List [Any ]:
4970 print (
5071 "Running ``%s`` with up to %d processes."
5172 % (self ._fn .__name__ , self ._n_processes )
5273 )
53-
5474 if self ._verbose :
5575 results = self ._pool .map_async (
5676 func = self ._fn , iterable = self ._pool_args
57- ).get ()
77+ ).get (timeout = 1800 )
5878 self ._pool .close ()
5979 self ._pool .join ()
6080 else :
6181 monitor_thread = Thread (target = self ._monitor , daemon = True )
6282 monitor_thread .start ()
6383 results = self ._pool .map_async (
6484 func = self ._monitored_fn , iterable = self ._pool_args
65- ).get ()
85+ ).get (timeout = 1800 )
6686 self ._pool .close ()
6787 self ._pool .join ()
6888 self ._stop_event .set ()
@@ -72,11 +92,14 @@ def run(self) -> List[Any]:
7292 def _monitor (self ):
7393 last_output = ""
7494 while not self ._stop_event .is_set ():
95+ self ._lock .acquire ()
7596 current = list (self ._running_tasks )
7697 current_line = ", " .join (current )
98+ updated_repos = self ._updated_repos .get ()
99+ self ._lock .release ()
77100
78101 if current_line != last_output :
79- truncated = f"Updating [ { self ._updated_repos . get () } /{ self ._nb_repos } ] ({ current_line } )"
102+ truncated = ( f" { self ._output_prefix } [ { updated_repos } /{ self ._nb_repos } ] ({ current_line } )")
80103 if len (truncated ) > self ._terminal_width :
81104 ellipsis_marker = " ..."
82105 truncated = (
@@ -89,17 +112,11 @@ def _monitor(self):
89112
90113 time .sleep (self ._monitor_polling_period )
91114
92- sys .stdout .write ("\r " + " " * len (last_output ) + "\r " )
115+ sys .stdout .write ("\r " + " " * len (last_output ) + "\r \n " )
93116 sys .stdout .flush ()
94117
95118 @staticmethod
96- def _clear_lines (n ):
97- for _ in range (n ):
98- sys .stdout .write ("\x1b [1A" )
99- sys .stdout .write ("\x1b [2K" )
100-
101- @staticmethod
102- def check_results (results , op ):
119+ def check_results (results , op ) -> int :
103120 """Function used to check the results of ParallelRunner.
104121
105122 NOTE: This function was originally located in the shell module of
@@ -123,7 +140,3 @@ def check_results(results, op):
123140 print (r .stderr .decode ())
124141 return fail_count
125142
126- @staticmethod
127- def _child_init (lck ):
128- global lock
129- lock = lck
0 commit comments