99import traceback
1010import warnings
1111from _asyncio import Future , Task
12- from concurrent .futures .process import ProcessPoolExecutor
1312from contextlib import suppress
14- from typing import Any , Callable , List , Optional , Set , Tuple , Union
15-
16- from distributed .cfexecutor import ClientExecutor
17- from distributed .client import Client
18- from ipyparallel .client .asyncresult import AsyncResult
19- from ipyparallel .client .view import ViewExecutor
13+ from typing import Any , Callable , Dict , List , Optional , Set , Tuple , Union
2014
2115from adaptive .learner import BaseLearner
2216from adaptive .notebook_integration import in_ipynb , live_info , live_plot
2317
18+ _ThirdPartyClient = []
19+ _ThirdPartyExecutor = []
20+
2421try :
2522 import ipyparallel
23+ from ipyparallel .client .asyncresult import AsyncResult
2624
2725 with_ipyparallel = True
26+ _ThirdPartyClient .append (ipyparallel .Client )
27+ _ThirdPartyExecutor .append (ipyparallel .client .view .ViewExecutor )
2828except ModuleNotFoundError :
2929 with_ipyparallel = False
3030
3131try :
3232 import distributed
3333
3434 with_distributed = True
35+ _ThirdPartyClient .append (distributed .client .Client )
36+ _ThirdPartyExecutor .append (distributed .cfexecutor .ClientExecutor )
3537except ModuleNotFoundError :
3638 with_distributed = False
3739
3840try :
3941 import mpi4py .futures
4042
4143 with_mpi4py = True
44+ _ThirdPartyExecutor .append (mpi4py .futures .MPIPoolExecutor )
4245except ModuleNotFoundError :
4346 with_mpi4py = False
4447
4750
4851 asyncio .set_event_loop_policy (uvloop .EventLoopPolicy ())
4952
53+ ThirdPartyClient = Union [tuple (_ThirdPartyClient )]
54+ ThirdPartyExecutor = Union [tuple (_ThirdPartyExecutor )]
5055
5156if os .name == "nt" :
5257 if with_distributed :
@@ -72,8 +77,80 @@ def _default_executor(*args, **kwargs):
7277 _default_executor_kwargs = {}
7378
7479
80+ # -- Internal executor-related, things
81+
82+
83+ class SequentialExecutor (concurrent .Executor ):
84+ """A trivial executor that runs functions synchronously.
85+
86+ This executor is mainly for testing.
87+ """
88+
89+ def submit (self , fn : Callable , * args , ** kwargs ) -> Future :
90+ fut = concurrent .Future ()
91+ try :
92+ fut .set_result (fn (* args , ** kwargs ))
93+ except Exception as e :
94+ fut .set_exception (e )
95+ return fut
96+
97+ def map (self , fn , * iterable , timeout = None , chunksize = 1 ):
98+ return map (fn , iterable )
99+
100+ def shutdown (self , wait = True ):
101+ pass
102+
103+
104+ def _ensure_executor (
105+ executor : Optional [Union [ThirdPartyClient , concurrent .Executor ]]
106+ ) -> concurrent .Executor :
107+ if executor is None :
108+ executor = _default_executor (** _default_executor_kwargs )
109+
110+ if isinstance (executor , concurrent .Executor ):
111+ return executor
112+ elif with_ipyparallel and isinstance (executor , ipyparallel .Client ):
113+ return executor .executor ()
114+ elif with_distributed and isinstance (executor , distributed .Client ):
115+ return executor .get_executor ()
116+ else :
117+ raise TypeError (
118+ "Only a concurrent.futures.Executor, distributed.Client,"
119+ " or ipyparallel.Client can be used."
120+ )
121+
122+
123+ def _get_ncores (
124+ ex : Union [
125+ ThirdPartyExecutor ,
126+ concurrent .ProcessPoolExecutor ,
127+ concurrent .ThreadPoolExecutor ,
128+ SequentialExecutor ,
129+ ]
130+ ) -> int :
131+ """Return the maximum number of cores that an executor can use."""
132+ if with_ipyparallel and isinstance (ex , ipyparallel .client .view .ViewExecutor ):
133+ return len (ex .view )
134+ elif isinstance (
135+ ex , (concurrent .ProcessPoolExecutor , concurrent .ThreadPoolExecutor )
136+ ):
137+ return ex ._max_workers # not public API!
138+ elif isinstance (ex , SequentialExecutor ):
139+ return 1
140+ elif with_distributed and isinstance (ex , distributed .cfexecutor .ClientExecutor ):
141+ return sum (n for n in ex ._client .ncores ().values ())
142+ elif with_mpi4py and isinstance (ex , mpi4py .futures .MPIPoolExecutor ):
143+ ex .bootup () # wait until all workers are up and running
144+ return ex ._pool .size # not public API!
145+ else :
146+ raise TypeError (f"Cannot get number of cores for { ex .__class__ } " )
147+
148+
149+ # -- Runner definitions
150+
151+
75152class BaseRunner (metaclass = abc .ABCMeta ):
76- r"""Base class for runners that use `concurrent.futures.Executors` .
153+ r"""Base class for runners that use `concurrent.futures.Executor`\'s .
77154
78155 Parameters
79156 ----------
@@ -133,12 +210,17 @@ def __init__(
133210 learner : BaseLearner ,
134211 goal : Callable ,
135212 * ,
136- executor = None ,
137- ntasks = None ,
138- log = False ,
139- shutdown_executor = False ,
140- retries = 0 ,
141- raise_if_retries_exceeded = True ,
213+ executor : Union [
214+ ThirdPartyExecutor ,
215+ concurrent .ProcessPoolExecutor ,
216+ concurrent .ThreadPoolExecutor ,
217+ SequentialExecutor ,
218+ ] = None ,
219+ ntasks : int = None ,
220+ log : bool = False ,
221+ shutdown_executor : bool = False ,
222+ retries : int = 0 ,
223+ raise_if_retries_exceeded : bool = True ,
142224 ) -> None :
143225
144226 self .executor = _ensure_executor (executor )
@@ -216,7 +298,12 @@ def overhead(self):
216298 return (1 - t_function / t_total ) * 100
217299
218300 def _process_futures (
219- self , done_futs : Union [Set [Future ], Set [Future ], Set [AsyncResult ], Set [Task ]]
301+ self ,
302+ done_futs : Union [
303+ Set [Future ],
304+ Set [AsyncResult ], # XXX: AsyncResult might not be imported
305+ Set [Task ],
306+ ],
220307 ) -> None :
221308 for fut in done_futs :
222309 x = self .pending_points .pop (fut )
@@ -240,7 +327,11 @@ def _process_futures(
240327
241328 def _get_futures (
242329 self ,
243- ) -> Union [List [Task ], List [Future ], List [Future ], List [AsyncResult ]]:
330+ ) -> Union [
331+ List [Task ],
332+ List [Future ],
333+ List [AsyncResult ], # XXX: AsyncResult might not be imported
334+ ]:
244335 # Launch tasks to replace the ones that completed
245336 # on the last iteration, making sure to fill workers
246337 # that have started since the last iteration.
@@ -363,8 +454,13 @@ def __init__(
363454 learner : BaseLearner ,
364455 goal : Callable ,
365456 * ,
366- executor = None ,
367- ntasks = None ,
457+ executor : Union [
458+ ThirdPartyExecutor ,
459+ concurrent .ProcessPoolExecutor ,
460+ concurrent .ThreadPoolExecutor ,
461+ SequentialExecutor ,
462+ ] = None ,
463+ ntasks : Optional [int ] = None ,
368464 log = False ,
369465 shutdown_executor = False ,
370466 retries = 0 ,
@@ -386,9 +482,7 @@ def __init__(
386482 )
387483 self ._run ()
388484
389- def _submit (
390- self , x : Union [Tuple [int , int ], int , Tuple [float , float ], float ]
391- ) -> Union [Future , AsyncResult ]:
485+ def _submit (self , x : Union [Tuple [float , ...], float , int ]) -> Future :
392486 return self .executor .submit (self .learner .function , x )
393487
394488 def _run (self ) -> None :
@@ -494,13 +588,18 @@ def __init__(
494588 learner : BaseLearner ,
495589 goal : Optional [Callable ] = None ,
496590 * ,
497- executor = None ,
498- ntasks = None ,
499- log = False ,
500- shutdown_executor = False ,
591+ executor : Union [
592+ ThirdPartyExecutor ,
593+ concurrent .ProcessPoolExecutor ,
594+ concurrent .ThreadPoolExecutor ,
595+ SequentialExecutor ,
596+ ] = None ,
597+ ntasks : Optional [int ] = None ,
598+ log : bool = False ,
599+ shutdown_executor : bool = False ,
501600 ioloop = None ,
502- retries = 0 ,
503- raise_if_retries_exceeded = True ,
601+ retries : int = 0 ,
602+ raise_if_retries_exceeded : bool = True ,
504603 ) -> None :
505604
506605 if goal is None :
@@ -640,7 +739,7 @@ async def _run(self) -> None:
640739 await asyncio .wait (remaining )
641740 self ._cleanup ()
642741
643- def elapsed_time (self ):
742+ def elapsed_time (self ) -> float :
644743 """Return the total time elapsed since the runner
645744 was started."""
646745 if self .task .done ():
@@ -653,7 +752,7 @@ def elapsed_time(self):
653752 end_time = time .time ()
654753 return end_time - self .start_time
655754
656- def start_periodic_saving (self , save_kwargs , interval ):
755+ def start_periodic_saving (self , save_kwargs : Dict [ str , Any ], interval : int ):
657756 """Periodically save the learner's data.
658757
659758 Parameters
@@ -711,16 +810,7 @@ def simple(learner: BaseLearner, goal: Callable) -> None:
711810 learner .tell (x , y )
712811
713812
714- def replay_log (
715- learner : BaseLearner ,
716- log : List [
717- Union [
718- Tuple [str , int ],
719- Tuple [str , Tuple [int , int , int ], float ],
720- Tuple [str , Tuple [float , float , float ], float ],
721- ]
722- ],
723- ) -> None :
813+ def replay_log (learner : BaseLearner , log ) -> None :
724814 """Apply a sequence of method calls to a learner.
725815
726816 This is useful for debugging runners.
@@ -771,67 +861,3 @@ def stop_after(*, seconds=0, minutes=0, hours=0) -> Callable:
771861 """
772862 stop_time = time .time () + seconds + 60 * minutes + 3600 * hours
773863 return lambda _ : time .time () > stop_time
774-
775-
776- # -- Internal executor-related, things
777-
778-
779- class SequentialExecutor (concurrent .Executor ):
780- """A trivial executor that runs functions synchronously.
781-
782- This executor is mainly for testing.
783- """
784-
785- def submit (self , fn : Callable , * args , ** kwargs ) -> Future :
786- fut = concurrent .Future ()
787- try :
788- fut .set_result (fn (* args , ** kwargs ))
789- except Exception as e :
790- fut .set_exception (e )
791- return fut
792-
793- def map (self , fn , * iterable , timeout = None , chunksize = 1 ):
794- return map (fn , iterable )
795-
796- def shutdown (self , wait = True ):
797- pass
798-
799-
800- def _ensure_executor (
801- executor : Optional [Union [Client , Client , ProcessPoolExecutor , SequentialExecutor ]]
802- ) -> Union [SequentialExecutor , ProcessPoolExecutor , ViewExecutor , ClientExecutor ]:
803- if executor is None :
804- executor = _default_executor (** _default_executor_kwargs )
805-
806- if isinstance (executor , concurrent .Executor ):
807- return executor
808- elif with_ipyparallel and isinstance (executor , ipyparallel .Client ):
809- return executor .executor ()
810- elif with_distributed and isinstance (executor , distributed .Client ):
811- return executor .get_executor ()
812- else :
813- raise TypeError (
814- "Only a concurrent.futures.Executor, distributed.Client,"
815- " or ipyparallel.Client can be used."
816- )
817-
818-
819- def _get_ncores (
820- ex : Union [SequentialExecutor , ProcessPoolExecutor , ViewExecutor , ClientExecutor ]
821- ) -> int :
822- """Return the maximum number of cores that an executor can use."""
823- if with_ipyparallel and isinstance (ex , ipyparallel .client .view .ViewExecutor ):
824- return len (ex .view )
825- elif isinstance (
826- ex , (concurrent .ProcessPoolExecutor , concurrent .ThreadPoolExecutor )
827- ):
828- return ex ._max_workers # not public API!
829- elif isinstance (ex , SequentialExecutor ):
830- return 1
831- elif with_distributed and isinstance (ex , distributed .cfexecutor .ClientExecutor ):
832- return sum (n for n in ex ._client .ncores ().values ())
833- elif with_mpi4py and isinstance (ex , mpi4py .futures .MPIPoolExecutor ):
834- ex .bootup () # wait until all workers are up and running
835- return ex ._pool .size # not public API!
836- else :
837- raise TypeError (f"Cannot get number of cores for { ex .__class__ } " )
0 commit comments