11from __future__ import annotations
22
33import itertools
4+ import numbers
45from collections import defaultdict
56from collections .abc import Iterable
67from contextlib import suppress
78from functools import partial
89from operator import itemgetter
10+ from typing import Any , Callable , Dict , Sequence , Tuple , Union
911
1012import numpy as np
1113
1214from adaptive .learner .base_learner import BaseLearner
1315from adaptive .notebook_integration import ensure_holoviews
1416from adaptive .utils import cache_latest , named_product , restore
1517
18+ try :
19+ from typing import Literal , TypeAlias
20+ except ImportError :
21+ from typing_extensions import Literal , TypeAlias
22+
1623try :
1724 import pandas
1825
1926 with_pandas = True
20-
2127except ModuleNotFoundError :
2228 with_pandas = False
2329
2430
25- def dispatch (child_functions , arg ) :
31+ def dispatch (child_functions : list [ Callable ] , arg : Any ) -> Any :
2632 index , x = arg
2733 return child_functions [index ](x )
2834
2935
36+ STRATEGY_TYPE : TypeAlias = Literal ["loss_improvements" , "loss" , "npoints" , "cycle" ]
37+
38+ CDIMS_TYPE : TypeAlias = Union [
39+ Sequence [Dict [str , Any ]],
40+ Tuple [Sequence [str ], Sequence [Tuple [Any , ...]]],
41+ None ,
42+ ]
43+
44+
3045class BalancingLearner (BaseLearner ):
3146 r"""Choose the optimal points from a set of learners.
3247
@@ -78,13 +93,19 @@ class BalancingLearner(BaseLearner):
7893 behave in an undefined way. Change the `strategy` in that case.
7994 """
8095
81- def __init__ (self , learners , * , cdims = None , strategy = "loss_improvements" ):
96+ def __init__ (
97+ self ,
98+ learners : list [BaseLearner ],
99+ * ,
100+ cdims : CDIMS_TYPE = None ,
101+ strategy : STRATEGY_TYPE = "loss_improvements" ,
102+ ) -> None :
82103 self .learners = learners
83104
84105 # Naively we would make 'function' a method, but this causes problems
85106 # when using executors from 'concurrent.futures' because we have to
86107 # pickle the whole learner.
87- self .function = partial (dispatch , [l .function for l in self .learners ])
108+ self .function = partial (dispatch , [l .function for l in self .learners ]) # type: ignore
88109
89110 self ._ask_cache = {}
90111 self ._loss = {}
@@ -96,7 +117,7 @@ def __init__(self, learners, *, cdims=None, strategy="loss_improvements"):
96117 "A BalacingLearner can handle only one type" " of learners."
97118 )
98119
99- self .strategy = strategy
120+ self .strategy : STRATEGY_TYPE = strategy
100121
101122 def new (self ) -> BalancingLearner :
102123 """Create a new `BalancingLearner` with the same parameters."""
@@ -107,21 +128,21 @@ def new(self) -> BalancingLearner:
107128 )
108129
109130 @property
110- def data (self ):
131+ def data (self ) -> dict [ tuple [ int , Any ], Any ] :
111132 data = {}
112133 for i , l in enumerate (self .learners ):
113134 data .update ({(i , p ): v for p , v in l .data .items ()})
114135 return data
115136
116137 @property
117- def pending_points (self ):
138+ def pending_points (self ) -> set [ tuple [ int , Any ]] :
118139 pending_points = set ()
119140 for i , l in enumerate (self .learners ):
120141 pending_points .update ({(i , p ) for p in l .pending_points })
121142 return pending_points
122143
123144 @property
124- def npoints (self ):
145+ def npoints (self ) -> int :
125146 return sum (l .npoints for l in self .learners )
126147
127148 @property
@@ -134,7 +155,7 @@ def nsamples(self):
134155 )
135156
136157 @property
137- def strategy (self ):
158+ def strategy (self ) -> STRATEGY_TYPE :
138159 """Can be either 'loss_improvements' (default), 'loss', 'npoints', or
139160 'cycle'. The points that the `BalancingLearner` choses can be either
140161 based on: the best 'loss_improvements', the smallest total 'loss' of
@@ -145,7 +166,7 @@ def strategy(self):
145166 return self ._strategy
146167
147168 @strategy .setter
148- def strategy (self , strategy ) :
169+ def strategy (self , strategy : STRATEGY_TYPE ) -> None :
149170 self ._strategy = strategy
150171 if strategy == "loss_improvements" :
151172 self ._ask_and_tell = self ._ask_and_tell_based_on_loss_improvements
@@ -162,7 +183,9 @@ def strategy(self, strategy):
162183 ' strategy="npoints", or strategy="cycle" is implemented.'
163184 )
164185
165- def _ask_and_tell_based_on_loss_improvements (self , n ):
186+ def _ask_and_tell_based_on_loss_improvements (
187+ self , n : int
188+ ) -> tuple [list [tuple [int , Any ]], list [float ]]:
166189 selected = [] # tuples ((learner_index, point), loss_improvement)
167190 total_points = [l .npoints + len (l .pending_points ) for l in self .learners ]
168191 for _ in range (n ):
@@ -185,7 +208,9 @@ def _ask_and_tell_based_on_loss_improvements(self, n):
185208 points , loss_improvements = map (list , zip (* selected ))
186209 return points , loss_improvements
187210
188- def _ask_and_tell_based_on_loss (self , n ):
211+ def _ask_and_tell_based_on_loss (
212+ self , n : int
213+ ) -> tuple [list [tuple [int , Any ]], list [float ]]:
189214 selected = [] # tuples ((learner_index, point), loss_improvement)
190215 total_points = [l .npoints + len (l .pending_points ) for l in self .learners ]
191216 for _ in range (n ):
@@ -206,7 +231,9 @@ def _ask_and_tell_based_on_loss(self, n):
206231 points , loss_improvements = map (list , zip (* selected ))
207232 return points , loss_improvements
208233
209- def _ask_and_tell_based_on_npoints (self , n ):
234+ def _ask_and_tell_based_on_npoints (
235+ self , n : numbers .Integral
236+ ) -> tuple [list [tuple [numbers .Integral , Any ]], list [float ]]:
210237 selected = [] # tuples ((learner_index, point), loss_improvement)
211238 total_points = [l .npoints + len (l .pending_points ) for l in self .learners ]
212239 for _ in range (n ):
@@ -222,7 +249,9 @@ def _ask_and_tell_based_on_npoints(self, n):
222249 points , loss_improvements = map (list , zip (* selected ))
223250 return points , loss_improvements
224251
225- def _ask_and_tell_based_on_cycle (self , n ):
252+ def _ask_and_tell_based_on_cycle (
253+ self , n : int
254+ ) -> tuple [list [tuple [numbers .Integral , Any ]], list [float ]]:
226255 points , loss_improvements = [], []
227256 for _ in range (n ):
228257 index = next (self ._cycle )
@@ -233,7 +262,9 @@ def _ask_and_tell_based_on_cycle(self, n):
233262
234263 return points , loss_improvements
235264
236- def ask (self , n , tell_pending = True ):
265+ def ask (
266+ self , n : int , tell_pending : bool = True
267+ ) -> tuple [list [tuple [numbers .Integral , Any ]], list [float ]]:
237268 """Chose points for learners."""
238269 if n == 0 :
239270 return [], []
@@ -244,20 +275,20 @@ def ask(self, n, tell_pending=True):
244275 else :
245276 return self ._ask_and_tell (n )
246277
247- def tell (self , x , y ) :
278+ def tell (self , x : tuple [ numbers . Integral , Any ], y : Any ) -> None :
248279 index , x = x
249280 self ._ask_cache .pop (index , None )
250281 self ._loss .pop (index , None )
251282 self ._pending_loss .pop (index , None )
252283 self .learners [index ].tell (x , y )
253284
254- def tell_pending (self , x ) :
285+ def tell_pending (self , x : tuple [ numbers . Integral , Any ]) -> None :
255286 index , x = x
256287 self ._ask_cache .pop (index , None )
257288 self ._loss .pop (index , None )
258289 self .learners [index ].tell_pending (x )
259290
260- def _losses (self , real = True ):
291+ def _losses (self , real : bool = True ) -> list [ float ] :
261292 losses = []
262293 loss_dict = self ._loss if real else self ._pending_loss
263294
@@ -269,11 +300,16 @@ def _losses(self, real=True):
269300 return losses
270301
271302 @cache_latest
272- def loss (self , real = True ):
303+ def loss (self , real : bool = True ) -> float :
273304 losses = self ._losses (real )
274305 return max (losses )
275306
276- def plot (self , cdims = None , plotter = None , dynamic = True ):
307+ def plot (
308+ self ,
309+ cdims : CDIMS_TYPE = None ,
310+ plotter : Callable [[BaseLearner ], Any ] | None = None ,
311+ dynamic : bool = True ,
312+ ):
277313 """Returns a DynamicMap with sliders.
278314
279315 Parameters
@@ -346,13 +382,19 @@ def plot_function(*args):
346382 vals = {d .name : d .values for d in dm .dimensions () if d .values }
347383 return hv .HoloMap (dm .select (** vals ))
348384
349- def remove_unfinished (self ):
385+ def remove_unfinished (self ) -> None :
350386 """Remove uncomputed data from the learners."""
351387 for learner in self .learners :
352388 learner .remove_unfinished ()
353389
354390 @classmethod
355- def from_product (cls , f , learner_type , learner_kwargs , combos ):
391+ def from_product (
392+ cls ,
393+ f ,
394+ learner_type : BaseLearner ,
395+ learner_kwargs : dict [str , Any ],
396+ combos : dict [str , Sequence [Any ]],
397+ ) -> BalancingLearner :
356398 """Create a `BalancingLearner` with learners of all combinations of
357399 named variables’ values. The `cdims` will be set correctly, so calling
358400 `learner.plot` will be a `holoviews.core.HoloMap` with the correct labels.
@@ -448,7 +490,11 @@ def load_dataframe(
448490 for i , gr in df .groupby (index_name ):
449491 self .learners [i ].load_dataframe (gr , ** kwargs )
450492
451- def save (self , fname , compress = True ):
493+ def save (
494+ self ,
495+ fname : Callable [[BaseLearner ], str ] | Sequence [str ],
496+ compress : bool = True ,
497+ ) -> None :
452498 """Save the data of the child learners into pickle files
453499 in a directory.
454500
@@ -486,7 +532,11 @@ def save(self, fname, compress=True):
486532 for l in self .learners :
487533 l .save (fname (l ), compress = compress )
488534
489- def load (self , fname , compress = True ):
535+ def load (
536+ self ,
537+ fname : Callable [[BaseLearner ], str ] | Sequence [str ],
538+ compress : bool = True ,
539+ ) -> None :
490540 """Load the data of the child learners from pickle files
491541 in a directory.
492542
@@ -510,20 +560,20 @@ def load(self, fname, compress=True):
510560 for l in self .learners :
511561 l .load (fname (l ), compress = compress )
512562
513- def _get_data (self ):
563+ def _get_data (self ) -> list [ Any ] :
514564 return [l ._get_data () for l in self .learners ]
515565
516- def _set_data (self , data ):
566+ def _set_data (self , data : list [ Any ] ):
517567 for l , _data in zip (self .learners , data ):
518568 l ._set_data (_data )
519569
520- def __getstate__ (self ):
570+ def __getstate__ (self ) -> tuple [ list [ BaseLearner ], CDIMS_TYPE , STRATEGY_TYPE ] :
521571 return (
522572 self .learners ,
523573 self ._cdims_default ,
524574 self .strategy ,
525575 )
526576
527- def __setstate__ (self , state ):
577+ def __setstate__ (self , state : tuple [ list [ BaseLearner ], CDIMS_TYPE , STRATEGY_TYPE ] ):
528578 learners , cdims , strategy = state
529579 self .__init__ (learners , cdims = cdims , strategy = strategy )
0 commit comments