1- from multiprocessing .managers import ListProxy , ValueProxy
21import sys
3- from multiprocessing import cpu_count , Manager
2+ from multiprocessing import cpu_count
43import time
5- from typing import Callable , List , Any , Union
4+ from typing import Callable , List , Any , Tuple , Union
65from threading import Lock , Thread , Event
76from concurrent .futures import ThreadPoolExecutor
87import shutil
98
109from .runner_arguments import RunnerArguments , AdditionalSwiftSourcesArguments
1110
1211
12+ class TaskTracker :
13+ _running_tasks : List [str ]
14+ _done_task_counter : int
15+ _lock : Lock
16+
17+ def __init__ (self ):
18+ self ._running_tasks = []
19+ self ._done_task_counter = 0
20+ self ._lock = Lock ()
21+
22+ def mark_task_as_running (self , task_name : str ):
23+ self ._lock .acquire ()
24+ self ._running_tasks .append (task_name )
25+ self ._lock .release ()
26+
27+ def mark_task_as_done (self , task_name : str ):
28+ self ._lock .acquire ()
29+ if task_name in self ._running_tasks :
30+ self ._running_tasks .remove (task_name )
31+ self ._done_task_counter += 1
32+ self ._lock .release ()
33+
34+ def status (self ) -> Tuple [List [str ], int ]:
35+ self ._lock .acquire ()
36+ running_tasks_str = ", " .join (self .running_tasks )
37+ done_tasks = self .done_task_counter
38+ self ._lock .release ()
39+ return running_tasks_str , done_tasks
40+
41+ @property
42+ def running_tasks (self ) -> List [str ]:
43+ return self ._running_tasks
44+
45+ @property
46+ def done_task_counter (self ) -> int :
47+ return self ._done_task_counter
48+
49+
1350class MonitoredFunction :
1451 def __init__ (
1552 self ,
1653 fn : Callable ,
17- running_tasks : ListProxy ,
18- updated_repos : ValueProxy ,
19- lock : Lock ,
54+ task_tracker : TaskTracker ,
2055 ):
2156 self .fn = fn
22- self .running_tasks = running_tasks
23- self .updated_repos = updated_repos
24- self ._lock = lock
57+ self ._task_tracker = task_tracker
2558
2659 def __call__ (self , * args : Union [RunnerArguments , AdditionalSwiftSourcesArguments ]):
2760 task_name = args [0 ].repo_name
28- self .running_tasks . append (task_name )
61+ self ._task_tracker . mark_task_as_running (task_name )
2962 result = None
3063 try :
3164 result = self .fn (* args )
3265 except Exception as e :
3366 print (e )
3467 finally :
35- self ._lock .acquire ()
36- if task_name in self .running_tasks :
37- self .running_tasks .remove (task_name )
38- self .updated_repos .set (self .updated_repos .get () + 1 )
39- self ._lock .release ()
68+ self ._task_tracker .mark_task_as_done (task_name )
4069 return result
4170
4271
@@ -61,13 +90,8 @@ def __init__(
6190 self ._stop_event = Event ()
6291 self ._verbose = pool_args [0 ].verbose
6392 if not self ._verbose :
64- manager = Manager ()
65- self ._lock = manager .Lock ()
66- self ._running_tasks = manager .list ()
67- self ._updated_repos = manager .Value ("i" , 0 )
68- self ._monitored_fn = MonitoredFunction (
69- self ._fn , self ._running_tasks , self ._updated_repos , self ._lock
70- )
93+ self ._task_tracker = TaskTracker ()
94+ self ._monitored_fn = MonitoredFunction (self ._fn , self ._task_tracker )
7195
7296 def run (self ) -> List [Any ]:
7397 print (f"Running ``{ self ._fn .__name__ } `` with up to { self ._n_threads } processes." )
@@ -86,12 +110,7 @@ def run(self) -> List[Any]:
86110 def _monitor (self ):
87111 last_output = ""
88112 while not self ._stop_event .is_set ():
89- self ._lock .acquire ()
90- current = list (self ._running_tasks )
91- current_line = ", " .join (current )
92- updated_repos = self ._updated_repos .get ()
93- self ._lock .release ()
94-
113+ current_line , updated_repos = self ._task_tracker .status ()
95114 if current_line != last_output :
96115 truncated = f"{ self ._output_prefix } [{ updated_repos } /{ self ._nb_repos } ] ({ current_line } )"
97116 if len (truncated ) > self ._terminal_width :
0 commit comments