11from __future__ import annotations
22
33import itertools
4+ import numbers
45from collections import defaultdict
56from collections .abc import Iterable
67from contextlib import suppress
78from functools import partial
89from operator import itemgetter
10+ from typing import Any , Callable , Dict , Sequence , Tuple , Union
911
1012import numpy as np
1113
1214from adaptive .learner .base_learner import BaseLearner
1315from adaptive .notebook_integration import ensure_holoviews
1416from adaptive .utils import cache_latest , named_product , restore
1517
18+ try :
19+ from typing import Literal , TypeAlias
20+ except ImportError :
21+ from typing_extensions import Literal , TypeAlias
22+
1623try :
1724 import pandas
1825
1926 with_pandas = True
20-
2127except ModuleNotFoundError :
2228 with_pandas = False
2329
2430
25- def dispatch (child_functions , arg ) :
31+ def dispatch (child_functions : list [ Callable ] , arg : Any ) -> Any :
2632 index , x = arg
2733 return child_functions [index ](x )
2834
2935
36+ STRATEGY_TYPE : TypeAlias = Literal ["loss_improvements" , "loss" , "npoints" , "cycle" ]
37+
38+ CDIMS_TYPE : TypeAlias = Union [
39+ Sequence [Dict [str , Any ]],
40+ Tuple [Sequence [str ], Sequence [Tuple [Any , ...]]],
41+ ]
42+
43+
3044class BalancingLearner (BaseLearner ):
3145 r"""Choose the optimal points from a set of learners.
3246
@@ -78,13 +92,19 @@ class BalancingLearner(BaseLearner):
7892 behave in an undefined way. Change the `strategy` in that case.
7993 """
8094
81- def __init__ (self , learners , * , cdims = None , strategy = "loss_improvements" ):
95+ def __init__ (
96+ self ,
97+ learners : list [BaseLearner ],
98+ * ,
99+ cdims : CDIMS_TYPE | None = None ,
100+ strategy : STRATEGY_TYPE = "loss_improvements" ,
101+ ) -> None :
82102 self .learners = learners
83103
84104 # Naively we would make 'function' a method, but this causes problems
85105 # when using executors from 'concurrent.futures' because we have to
86106 # pickle the whole learner.
87- self .function = partial (dispatch , [l .function for l in self .learners ])
107+ self .function = partial (dispatch , [l .function for l in self .learners ]) # type: ignore
88108
89109 self ._ask_cache = {}
90110 self ._loss = {}
@@ -96,7 +116,7 @@ def __init__(self, learners, *, cdims=None, strategy="loss_improvements"):
96116 "A BalacingLearner can handle only one type" " of learners."
97117 )
98118
99- self .strategy = strategy
119+ self .strategy : STRATEGY_TYPE = strategy
100120
101121 def new (self ) -> BalancingLearner :
102122 """Create a new `BalancingLearner` with the same parameters."""
@@ -107,21 +127,21 @@ def new(self) -> BalancingLearner:
107127 )
108128
109129 @property
110- def data (self ):
130+ def data (self ) -> dict [ tuple [ int , Any ], Any ] :
111131 data = {}
112132 for i , l in enumerate (self .learners ):
113133 data .update ({(i , p ): v for p , v in l .data .items ()})
114134 return data
115135
116136 @property
117- def pending_points (self ):
137+ def pending_points (self ) -> set [ tuple [ int , Any ]] :
118138 pending_points = set ()
119139 for i , l in enumerate (self .learners ):
120140 pending_points .update ({(i , p ) for p in l .pending_points })
121141 return pending_points
122142
123143 @property
124- def npoints (self ):
144+ def npoints (self ) -> int :
125145 return sum (l .npoints for l in self .learners )
126146
127147 @property
@@ -134,7 +154,7 @@ def nsamples(self):
134154 )
135155
136156 @property
137- def strategy (self ):
157+ def strategy (self ) -> STRATEGY_TYPE :
138158 """Can be either 'loss_improvements' (default), 'loss', 'npoints', or
139159 'cycle'. The points that the `BalancingLearner` choses can be either
140160 based on: the best 'loss_improvements', the smallest total 'loss' of
@@ -145,7 +165,7 @@ def strategy(self):
145165 return self ._strategy
146166
147167 @strategy .setter
148- def strategy (self , strategy ) :
168+ def strategy (self , strategy : STRATEGY_TYPE ) -> None :
149169 self ._strategy = strategy
150170 if strategy == "loss_improvements" :
151171 self ._ask_and_tell = self ._ask_and_tell_based_on_loss_improvements
@@ -162,7 +182,9 @@ def strategy(self, strategy):
162182 ' strategy="npoints", or strategy="cycle" is implemented.'
163183 )
164184
165- def _ask_and_tell_based_on_loss_improvements (self , n ):
185+ def _ask_and_tell_based_on_loss_improvements (
186+ self , n : int
187+ ) -> tuple [list [tuple [int , Any ]], list [float ]]:
166188 selected = [] # tuples ((learner_index, point), loss_improvement)
167189 total_points = [l .npoints + len (l .pending_points ) for l in self .learners ]
168190 for _ in range (n ):
@@ -185,7 +207,9 @@ def _ask_and_tell_based_on_loss_improvements(self, n):
185207 points , loss_improvements = map (list , zip (* selected ))
186208 return points , loss_improvements
187209
188- def _ask_and_tell_based_on_loss (self , n ):
210+ def _ask_and_tell_based_on_loss (
211+ self , n : int
212+ ) -> tuple [list [tuple [int , Any ]], list [float ]]:
189213 selected = [] # tuples ((learner_index, point), loss_improvement)
190214 total_points = [l .npoints + len (l .pending_points ) for l in self .learners ]
191215 for _ in range (n ):
@@ -206,7 +230,9 @@ def _ask_and_tell_based_on_loss(self, n):
206230 points , loss_improvements = map (list , zip (* selected ))
207231 return points , loss_improvements
208232
209- def _ask_and_tell_based_on_npoints (self , n ):
233+ def _ask_and_tell_based_on_npoints (
234+ self , n : numbers .Integral
235+ ) -> tuple [list [tuple [numbers .Integral , Any ]], list [float ]]:
210236 selected = [] # tuples ((learner_index, point), loss_improvement)
211237 total_points = [l .npoints + len (l .pending_points ) for l in self .learners ]
212238 for _ in range (n ):
@@ -222,7 +248,9 @@ def _ask_and_tell_based_on_npoints(self, n):
222248 points , loss_improvements = map (list , zip (* selected ))
223249 return points , loss_improvements
224250
225- def _ask_and_tell_based_on_cycle (self , n ):
251+ def _ask_and_tell_based_on_cycle (
252+ self , n : int
253+ ) -> tuple [list [tuple [numbers .Integral , Any ]], list [float ]]:
226254 points , loss_improvements = [], []
227255 for _ in range (n ):
228256 index = next (self ._cycle )
@@ -233,7 +261,9 @@ def _ask_and_tell_based_on_cycle(self, n):
233261
234262 return points , loss_improvements
235263
236- def ask (self , n , tell_pending = True ):
264+ def ask (
265+ self , n : int , tell_pending : bool = True
266+ ) -> tuple [list [tuple [numbers .Integral , Any ]], list [float ]]:
237267 """Chose points for learners."""
238268 if n == 0 :
239269 return [], []
@@ -244,20 +274,20 @@ def ask(self, n, tell_pending=True):
244274 else :
245275 return self ._ask_and_tell (n )
246276
247- def tell (self , x , y ) :
277+ def tell (self , x : tuple [ numbers . Integral , Any ], y : Any ) -> None :
248278 index , x = x
249279 self ._ask_cache .pop (index , None )
250280 self ._loss .pop (index , None )
251281 self ._pending_loss .pop (index , None )
252282 self .learners [index ].tell (x , y )
253283
254- def tell_pending (self , x ) :
284+ def tell_pending (self , x : tuple [ numbers . Integral , Any ]) -> None :
255285 index , x = x
256286 self ._ask_cache .pop (index , None )
257287 self ._loss .pop (index , None )
258288 self .learners [index ].tell_pending (x )
259289
260- def _losses (self , real = True ):
290+ def _losses (self , real : bool = True ) -> list [ float ] :
261291 losses = []
262292 loss_dict = self ._loss if real else self ._pending_loss
263293
@@ -269,11 +299,16 @@ def _losses(self, real=True):
269299 return losses
270300
271301 @cache_latest
272- def loss (self , real = True ):
302+ def loss (self , real : bool = True ) -> float :
273303 losses = self ._losses (real )
274304 return max (losses )
275305
276- def plot (self , cdims = None , plotter = None , dynamic = True ):
306+ def plot (
307+ self ,
308+ cdims : CDIMS_TYPE | None = None ,
309+ plotter : Callable [[BaseLearner ], Any ] | None = None ,
310+ dynamic : bool = True ,
311+ ):
277312 """Returns a DynamicMap with sliders.
278313
279314 Parameters
@@ -346,13 +381,19 @@ def plot_function(*args):
346381 vals = {d .name : d .values for d in dm .dimensions () if d .values }
347382 return hv .HoloMap (dm .select (** vals ))
348383
349- def remove_unfinished (self ):
384+ def remove_unfinished (self ) -> None :
350385 """Remove uncomputed data from the learners."""
351386 for learner in self .learners :
352387 learner .remove_unfinished ()
353388
354389 @classmethod
355- def from_product (cls , f , learner_type , learner_kwargs , combos ):
390+ def from_product (
391+ cls ,
392+ f ,
393+ learner_type : BaseLearner ,
394+ learner_kwargs : dict [str , Any ],
395+ combos : dict [str , Sequence [Any ]],
396+ ) -> BalancingLearner :
356397 """Create a `BalancingLearner` with learners of all combinations of
357398 named variables’ values. The `cdims` will be set correctly, so calling
358399 `learner.plot` will be a `holoviews.core.HoloMap` with the correct labels.
@@ -448,7 +489,11 @@ def load_dataframe(
448489 for i , gr in df .groupby (index_name ):
449490 self .learners [i ].load_dataframe (gr , ** kwargs )
450491
451- def save (self , fname , compress = True ):
492+ def save (
493+ self ,
494+ fname : Callable [[BaseLearner ], str ] | Sequence [str ],
495+ compress : bool = True ,
496+ ) -> None :
452497 """Save the data of the child learners into pickle files
453498 in a directory.
454499
@@ -486,7 +531,11 @@ def save(self, fname, compress=True):
486531 for l in self .learners :
487532 l .save (fname (l ), compress = compress )
488533
489- def load (self , fname , compress = True ):
534+ def load (
535+ self ,
536+ fname : Callable [[BaseLearner ], str ] | Sequence [str ],
537+ compress : bool = True ,
538+ ) -> None :
490539 """Load the data of the child learners from pickle files
491540 in a directory.
492541
@@ -510,20 +559,20 @@ def load(self, fname, compress=True):
510559 for l in self .learners :
511560 l .load (fname (l ), compress = compress )
512561
513- def _get_data (self ):
562+ def _get_data (self ) -> list [ Any ] :
514563 return [l ._get_data () for l in self .learners ]
515564
516- def _set_data (self , data ):
565+ def _set_data (self , data : list [ Any ] ):
517566 for l , _data in zip (self .learners , data ):
518567 l ._set_data (_data )
519568
520- def __getstate__ (self ):
569+ def __getstate__ (self ) -> tuple [ list [ BaseLearner ], CDIMS_TYPE , STRATEGY_TYPE ] :
521570 return (
522571 self .learners ,
523572 self ._cdims_default ,
524573 self .strategy ,
525574 )
526575
527- def __setstate__ (self , state ):
576+ def __setstate__ (self , state : tuple [ list [ BaseLearner ], CDIMS_TYPE , STRATEGY_TYPE ] ):
528577 learners , cdims , strategy = state
529578 self .__init__ (learners , cdims = cdims , strategy = strategy )
0 commit comments