11import math
2+ import sys
23from collections import defaultdict
34from copy import deepcopy
45from math import hypot
2324from adaptive .notebook_integration import ensure_holoviews
2425
2526number = Union [int , float , np .int_ , np .float_ ]
26-
2727Point = Tuple [int , number ]
2828Points = List [Point ]
29- Value = Union [number , Sequence [number ], np .ndarray ]
3029
31- __all__ = ["AverageLearner1D" ]
30+ __all__ : List [ str ] = ["AverageLearner1D" ]
3231
3332
3433class AverageLearner1D (Learner1D ):
@@ -76,7 +75,7 @@ class AverageLearner1D(Learner1D):
7675
7776 def __init__ (
7877 self ,
79- function : Callable [[Tuple [int , number ]], Value ],
78+ function : Callable [[Tuple [int , number ]], number ],
8079 bounds : Tuple [number , number ],
8180 loss_per_interval : Optional [
8281 Callable [[Sequence [number ], Sequence [number ]], float ]
@@ -85,7 +84,7 @@ def __init__(
8584 alpha : float = 0.005 ,
8685 neighbor_sampling : float = 0.3 ,
8786 min_samples : int = 50 ,
88- max_samples : int = np . inf ,
87+ max_samples : int = sys . maxsize ,
8988 min_error : float = 0 ,
9089 ):
9190 if not (0 < delta <= 1 ):
@@ -201,16 +200,13 @@ def tell_pending(self, seed_x: Point) -> None:
201200 self ._update_neighbors (x , self .neighbors_combined )
202201 self ._update_losses (x , real = False )
203202
204- def tell (self , seed_x : Point , y : Value ) -> None :
203+ def tell (self , seed_x : Point , y : number ) -> None :
205204 seed , x = seed_x
206205 if y is None :
207206 raise TypeError (
208207 "Y-value may not be None, use learner.tell_pending(x)"
209208 "to indicate that this value is currently being calculated"
210209 )
211- # either it is a float/int, if not, try casting to a np.array
212- if not isinstance (y , (float , int )):
213- y = np .asarray (y , dtype = float )
214210
215211 if x not in self .data :
216212 self ._update_data (x , y , "new" )
@@ -257,15 +253,17 @@ def _update_rescaled_error_in_mean(self, x: number, point_type: str) -> None:
257253 norm = min (d_left , d_right )
258254 self .rescaled_error [x ] = self .error [x ] / norm
259255
260- def _update_data (self , x : number , y : Value , point_type : str ) -> None :
256+ def _update_data (self , x : number , y : number , point_type : str ) -> None :
261257 if point_type == "new" :
262258 self .data [x ] = y
263259 elif point_type == "resampled" :
264260 n = len (self ._data_samples [x ])
265261 new_average = self .data [x ] * n / (n + 1 ) + y / (n + 1 )
266262 self .data [x ] = new_average
267263
268- def _update_data_structures (self , seed_x : Point , y : Value , point_type : str ) -> None :
264+ def _update_data_structures (
265+ self , seed_x : Point , y : number , point_type : str
266+ ) -> None :
269267 seed , x = seed_x
270268 if point_type == "new" :
271269 self ._data_samples [x ] = {seed : y }
@@ -370,12 +368,12 @@ def _update_losses_resampling(self, x: number, real=True) -> None:
370368 if (b is not None ) and right_loss_is_unknown :
371369 self .losses_combined [x , b ] = float ("inf" )
372370
373- def _calc_error_in_mean (self , ys : Sequence [Value ], y_avg : Value , n : int ) -> float :
371+ def _calc_error_in_mean (self , ys : Sequence [number ], y_avg : number , n : int ) -> float :
374372 variance_in_mean = sum ((y - y_avg ) ** 2 for y in ys ) / (n - 1 )
375373 t_student = scipy .stats .t .ppf (1 - self .alpha , df = n - 1 )
376374 return t_student * (variance_in_mean / n ) ** 0.5
377375
378- def tell_many (self , xs : Points , ys : Sequence [Value ]) -> None :
376+ def tell_many (self , xs : Points , ys : Sequence [number ]) -> None :
379377 # Check that all x are within the bounds
380378 # TODO: remove this requirement, all other learners add the data
381379 # but ignore it going forward.
@@ -386,7 +384,7 @@ def tell_many(self, xs: Points, ys: Sequence[Value]) -> None:
386384 )
387385
388386 # Create a mapping of points to a list of samples
389- mapping : DefaultDict [number , DefaultDict [int , Value ]] = defaultdict (
387+ mapping : DefaultDict [number , DefaultDict [int , number ]] = defaultdict (
390388 lambda : defaultdict (dict )
391389 )
392390 for (seed , x ), y in zip (xs , ys ):
@@ -402,14 +400,14 @@ def tell_many(self, xs: Points, ys: Sequence[Value]) -> None:
402400 # simultaneously, before we move on to a new x
403401 self .tell_many_at_point (x , seed_y_mapping )
404402
405- def tell_many_at_point (self , x : float , seed_y_mapping : Dict [int , Value ]) -> None :
403+ def tell_many_at_point (self , x : number , seed_y_mapping : Dict [int , number ]) -> None :
406404 """Tell the learner about many samples at a certain location x.
407405
408406 Parameters
409407 ----------
410408 x : float
411409 Value from the function domain.
412- seed_y_mapping : Dict[int, Value ]
410+ seed_y_mapping : Dict[int, number ]
413411 Dictionary of ``seed`` -> ``y`` at ``x``.
414412 """
415413 # Check x is within the bounds
@@ -458,10 +456,10 @@ def tell_many_at_point(self, x: float, seed_y_mapping: Dict[int, Value]) -> None
458456 self ._update_interpolated_loss_in_interval (* interval )
459457 self ._oldscale = deepcopy (self ._scale )
460458
461- def _get_data (self ) -> SortedDict [number , Value ]:
459+ def _get_data (self ) -> SortedDict [number , number ]:
462460 return self ._data_samples
463461
464- def _set_data (self , data : SortedDict [number , Value ]) -> None :
462+ def _set_data (self , data : SortedDict [number , number ]) -> None :
465463 if data :
466464 for x , samples in data .items ():
467465 self .tell_many_at_point (x , samples )
0 commit comments