11from multiprocessing .managers import ListProxy , ValueProxy
22import sys
3- from multiprocessing import Pool , cpu_count , Manager
3+ from multiprocessing import cpu_count , Manager
44import time
55from typing import Callable , List , Any , Union
66from threading import Lock , Thread , Event
7+ from concurrent .futures import ThreadPoolExecutor
78import shutil
89
910from .runner_arguments import RunnerArguments , AdditionalSwiftSourcesArguments
@@ -44,18 +45,17 @@ def __init__(
4445 self ,
4546 fn : Callable ,
4647 pool_args : List [Union [RunnerArguments , AdditionalSwiftSourcesArguments ]],
47- n_processes : int = 0 ,
48+ n_threads : int = 0 ,
4849 ):
49- if n_processes == 0 :
50- # Limit the number of processes as the performance regresses after
51- # if the number is too high.
52- n_processes = min (cpu_count () * 2 , 16 )
53- self ._n_processes = n_processes
50+ if n_threads == 0 :
51+ # Limit the number of threads as the performance regresses if the
52+ # number is too high.
53+ n_threads = min (cpu_count () * 2 , 16 )
54+ self ._n_threads = n_threads
5455 self ._monitor_polling_period = 0.1
5556 self ._terminal_width = shutil .get_terminal_size ().columns
5657 self ._pool_args = pool_args
5758 self ._fn = fn
58- self ._pool = Pool (processes = self ._n_processes )
5959 self ._output_prefix = pool_args [0 ].output_prefix
6060 self ._nb_repos = len (pool_args )
6161 self ._stop_event = Event ()
@@ -70,24 +70,15 @@ def __init__(
7070 )
7171
7272 def run (self ) -> List [Any ]:
73- print (
74- "Running ``%s`` with up to %d processes."
75- % (self ._fn .__name__ , self ._n_processes )
76- )
73+ print (f"Running ``{ self ._fn .__name__ } `` with up to { self ._n_threads } processes." )
7774 if self ._verbose :
78- results = self ._pool .map_async (
79- func = self ._fn , iterable = self ._pool_args
80- ).get (timeout = 1800 )
81- self ._pool .close ()
82- self ._pool .join ()
75+ with ThreadPoolExecutor (max_workers = self ._n_threads ) as pool :
76+ results = list (pool .map (self ._fn , self ._pool_args , timeout = 1800 ))
8377 else :
8478 monitor_thread = Thread (target = self ._monitor , daemon = True )
8579 monitor_thread .start ()
86- results = self ._pool .map_async (
87- func = self ._monitored_fn , iterable = self ._pool_args
88- ).get (timeout = 1800 )
89- self ._pool .close ()
90- self ._pool .join ()
80+ with ThreadPoolExecutor (max_workers = self ._n_threads ) as pool :
81+ results = list (pool .map (self ._monitored_fn , self ._pool_args , timeout = 1800 ))
9182 self ._stop_event .set ()
9283 monitor_thread .join ()
9384 return results
@@ -131,14 +122,18 @@ def check_results(results, op) -> int:
131122 if results is None :
132123 return 0
133124 for r in results :
134- if r is not None :
135- if fail_count == 0 :
136- print ("======%s FAILURES======" % op )
137- fail_count += 1
138- if isinstance (r , str ):
139- print (r )
140- continue
141- print ("%s failed (ret=%d): %s" % (r .repo_path , r .ret , r ))
142- if r .stderr :
143- print (r .stderr .decode ())
125+ if r is None :
126+ continue
127+ if fail_count == 0 :
128+ print ("======%s FAILURES======" % op )
129+ fail_count += 1
130+ if isinstance (r , str ):
131+ print (r )
132+ continue
133+ if not hasattr (r , "repo_path" ):
134+ # TODO: create a proper Exception class with these attributes
135+ continue
136+ print ("%s failed (ret=%d): %s" % (r .repo_path , r .ret , r ))
137+ if r .stderr :
138+ print (r .stderr .decode ())
144139 return fail_count
0 commit comments