1313from adaptive .utils import cache_latest , named_product , restore
1414
1515
16- def dispatch (
17- child_functions : List [Callable ], arg : Any ,
18- ) -> Union [int , np .float64 , float ]:
16+ def dispatch (child_functions : List [Callable ], arg : Any ,) -> Union [Any ]:
1917 index , x = arg
2018 return child_functions [index ](x )
2119
@@ -94,14 +92,14 @@ def __init__(
9492 self .strategy = strategy
9593
9694 @property
97- def data (self ) -> Dict [Tuple [int , int ], int ]:
95+ def data (self ) -> Dict [Tuple [int , Any ], Any ]:
9896 data = {}
9997 for i , l in enumerate (self .learners ):
10098 data .update ({(i , p ): v for p , v in l .data .items ()})
10199 return data
102100
103101 @property
104- def pending_points (self ) -> Set [Tuple [int , int ]]:
102+ def pending_points (self ) -> Set [Tuple [int , Any ]]:
105103 pending_points = set ()
106104 for i , l in enumerate (self .learners ):
107105 pending_points .update ({(i , p ) for p in l .pending_points })
@@ -140,7 +138,9 @@ def strategy(self, strategy):
140138 ' strategy="npoints", or strategy="cycle" is implemented.'
141139 )
142140
143- def _ask_and_tell_based_on_loss_improvements (self , n : int ) -> Any :
141+ def _ask_and_tell_based_on_loss_improvements (
142+ self , n : int
143+ ) -> Tuple [List [Tuple [int , Any ]], List [float ]]:
144144 selected = [] # tuples ((learner_index, point), loss_improvement)
145145 total_points = [l .npoints + len (l .pending_points ) for l in self .learners ]
146146 for _ in range (n ):
@@ -165,11 +165,7 @@ def _ask_and_tell_based_on_loss_improvements(self, n: int) -> Any:
165165
166166 def _ask_and_tell_based_on_loss (
167167 self , n : int
168- ) -> Union [
169- Tuple [List [Tuple [int , float ]], List [np .float64 ]],
170- Tuple [List [Union [Tuple [int , int ], Tuple [int , float ]]], List [float ]],
171- Tuple [List [Tuple [int , int ]], List [float ]],
172- ]:
168+ ) -> Tuple [List [Tuple [int , Any ]], List [float ]]:
173169 selected = [] # tuples ((learner_index, point), loss_improvement)
174170 total_points = [l .npoints + len (l .pending_points ) for l in self .learners ]
175171 for _ in range (n ):
@@ -192,11 +188,7 @@ def _ask_and_tell_based_on_loss(
192188
193189 def _ask_and_tell_based_on_npoints (
194190 self , n : int
195- ) -> Union [
196- Tuple [List [Union [Tuple [np .int64 , int ], Tuple [np .int64 , float ]]], List [float ]],
197- Tuple [List [Tuple [np .int64 , float ]], List [np .float64 ]],
198- Tuple [List [Tuple [np .int64 , int ]], List [float ]],
199- ]:
191+ ) -> Tuple [List [Tuple [int , Any ]], List [float ]]:
200192 selected = [] # tuples ((learner_index, point), loss_improvement)
201193 total_points = [l .npoints + len (l .pending_points ) for l in self .learners ]
202194 for _ in range (n ):
@@ -214,11 +206,7 @@ def _ask_and_tell_based_on_npoints(
214206
215207 def _ask_and_tell_based_on_cycle (
216208 self , n : int
217- ) -> Union [
218- Tuple [List [Tuple [int , float ]], List [np .float64 ]],
219- Tuple [List [Union [Tuple [int , int ], Tuple [int , float ]]], List [float ]],
220- Tuple [List [Tuple [int , int ]], List [float ]],
221- ]:
209+ ) -> Tuple [List [Tuple [int , Any ]], List [float ]]:
222210 points , loss_improvements = [], []
223211 for _ in range (n ):
224212 index = next (self ._cycle )
@@ -229,7 +217,9 @@ def _ask_and_tell_based_on_cycle(
229217
230218 return points , loss_improvements
231219
232- def ask (self , n : int , tell_pending : bool = True ) -> Tuple [List [Any ], List [float ]]:
220+ def ask (
221+ self , n : int , tell_pending : bool = True
222+ ) -> Tuple [List [Tuple [int , Any ]], List [float ]]:
233223 """Chose points for learners."""
234224 if n == 0 :
235225 return [], []
@@ -240,26 +230,20 @@ def ask(self, n: int, tell_pending: bool = True) -> Tuple[List[Any], List[float]
240230 else :
241231 return self ._ask_and_tell (n )
242232
243- def tell (
244- self ,
245- x : Any ,
246- y : Union [int , np .float64 , float , Tuple [int , int ], Tuple [np .int64 , int ]],
247- ) -> None :
233+ def tell (self , x : Tuple [int , Any ], y : Any ,) -> None :
248234 index , x = x
249235 self ._ask_cache .pop (index , None )
250236 self ._loss .pop (index , None )
251237 self ._pending_loss .pop (index , None )
252238 self .learners [index ].tell (x , y )
253239
254- def tell_pending (self , x : Any ) -> None :
240+ def tell_pending (self , x : Tuple [ int , Any ] ) -> None :
255241 index , x = x
256242 self ._ask_cache .pop (index , None )
257243 self ._loss .pop (index , None )
258244 self .learners [index ].tell_pending (x )
259245
260- def _losses (
261- self , real : bool = True
262- ) -> Union [List [float ], List [np .float64 ], List [Union [float , np .float64 ]]]:
246+ def _losses (self , real : bool = True ) -> List [float ]:
263247 losses = []
264248 loss_dict = self ._loss if real else self ._pending_loss
265249
@@ -271,7 +255,7 @@ def _losses(
271255 return losses
272256
273257 @cache_latest
274- def loss (self , real : bool = True ) -> Union [np . float64 , float ]:
258+ def loss (self , real : bool = True ) -> Union [float ]:
275259 losses = self ._losses (real )
276260 return max (losses )
277261
0 commit comments