from collections import defaultdict
from copy import deepcopy
from math import hypot
-from numbers import Number
-from typing import Dict, List, Sequence, Tuple, Union
+from typing import (
+    Callable,
+    DefaultDict,
+    Dict,
+    List,
+    Optional,
+    Sequence,
+    Set,
+    Tuple,
+    Union,
+)

import numpy as np
import scipy.stats
from adaptive.learner.learner1D import Learner1D, _get_intervals
from adaptive.notebook_integration import ensure_holoviews

-Point = Tuple[int, Number]
+number = Union[int, float, np.int_, np.float_]
+
+Point = Tuple[int, number]
Points = List[Point]
-Value = Union[Number, Sequence[Number]]
+Value = Union[number, Sequence[number], np.ndarray]
+
+__all__ = ["AverageLearner1D"]


class AverageLearner1D(Learner1D):
@@ -37,21 +50,21 @@ class AverageLearner1D(Learner1D):
        This parameter controls the resampling condition. A point is resampled
        if its uncertainty is larger than delta times the smallest neighboring
        interval.
-        We strongly recommend 0 < delta <= 1.
-    alpha : float (0 < alpha < 1)
+        We strongly recommend ``0 < delta <= 1``.
+    alpha : float (0 < alpha < 1), default 0.005
        The true value of the function at x is within the confidence interval
-        [self.data[x] - self.error[x], self.data[x] +
-        self.error[x]] with probability 1-2*alpha.
-        We recommend to keep alpha=0.005.
-    neighbor_sampling : float (0 < neighbor_sampling <= 1)
+        ``[self.data[x] - self.error[x], self.data[x] + self.error[x]]`` with
+        probability ``1-2*alpha``.
+        We recommend to keep ``alpha=0.005``.
+    neighbor_sampling : float (0 < neighbor_sampling <= 1), default 0.3
        Each new point is initially sampled at least a (neighbor_sampling*100)%
        of the average number of samples of its neighbors.
-    min_samples : int (min_samples > 0)
+    min_samples : int (min_samples > 0), default 50
        Minimum number of samples at each point x. Each new point is initially
        sampled at least min_samples times.
-    max_samples : int (min_samples < max_samples)
+    max_samples : int (min_samples < max_samples), default np.inf
        Maximum number of samples at each point x.
-    min_error : float (min_error >= 0)
+    min_error : float (min_error >= 0), default 0
        Minimum size of the confidence intervals. The true value of the
        function at x is within the confidence interval [self.data[x] -
        self.error[x], self.data[x] + self.error[x]] with
@@ -63,15 +76,17 @@ class AverageLearner1D(Learner1D):

    def __init__(
        self,
-        function,
-        bounds,
-        loss_per_interval=None,
-        delta=0.2,
-        alpha=0.005,
-        neighbor_sampling=0.3,
-        min_samples=50,
-        max_samples=np.inf,
-        min_error=0,
+        function: Callable[[Tuple[int, number]], Value],
+        bounds: Tuple[number, number],
+        loss_per_interval: Optional[
+            Callable[[Sequence[number], Sequence[number]], float]
+        ] = None,
+        delta: float = 0.2,
+        alpha: float = 0.005,
+        neighbor_sampling: float = 0.3,
+        min_samples: int = 50,
+        max_samples: int = np.inf,
+        min_error: float = 0,
    ):
        if not (0 < delta <= 1):
            raise ValueError("Learner requires 0 < delta <= 1.")
@@ -101,15 +116,15 @@ def __init__(
        self._number_samples = SortedDict()
        # This set contains the points x that have less than min_samples
        # samples or less than a (neighbor_sampling*100)% of their neighbors
-        self._undersampled_points = set()
+        self._undersampled_points: Set[number] = set()
        # Contains the error in the estimate of the
        # mean at each point x in the form {x0: error(x0), ...}
-        self.error = decreasing_dict()
+        self.error: ItemSortedDict[number, float] = decreasing_dict()
        # Distance between two neighboring points in the
        # form {xi: ((xii-xi)^2 + (yii-yi)^2)^0.5, ...}
-        self._distances = decreasing_dict()
+        self._distances: ItemSortedDict[number, float] = decreasing_dict()
        # {xii: error[xii]/min(_distances[xi], _distances[xii], ...}
-        self.rescaled_error = decreasing_dict()
+        self.rescaled_error: ItemSortedDict[number, float] = decreasing_dict()

    @property
    def nsamples(self) -> int:
@@ -151,7 +166,7 @@ def ask(self, n: int, tell_pending: bool = True) -> Tuple[Points, List[float]]:

        return points, loss_improvements

-    def _ask_for_more_samples(self, x: Number, n: int) -> Tuple[Points, List[float]]:
+    def _ask_for_more_samples(self, x: number, n: int) -> Tuple[Points, List[float]]:
        """When asking for n points, the learner returns n times an existing point
        to be resampled, since in general n << min_samples and this point will
        need to be resampled many more times"""
@@ -205,7 +220,7 @@ def tell(self, seed_x: Point, y: Value) -> None:
            self._update_data_structures(seed_x, y, "resampled")
        self.pending_points.discard(seed_x)

-    def _update_rescaled_error_in_mean(self, x: Number, point_type: str) -> None:
+    def _update_rescaled_error_in_mean(self, x: number, point_type: str) -> None:
        """Updates ``self.rescaled_error``.

        Parameters
@@ -242,7 +257,7 @@ def _update_rescaled_error_in_mean(self, x: Number, point_type: str) -> None:
        norm = min(d_left, d_right)
        self.rescaled_error[x] = self.error[x] / norm

-    def _update_data(self, x: Number, y: Value, point_type: str) -> None:
+    def _update_data(self, x: number, y: Value, point_type: str) -> None:
        if point_type == "new":
            self.data[x] = y
        elif point_type == "resampled":
@@ -318,15 +333,15 @@ def _update_data_structures(self, seed_x: Point, y: Value, point_type: str) -> None:
                    self._update_interpolated_loss_in_interval(*interval)
                self._oldscale = deepcopy(self._scale)

-    def _update_distances(self, x: Number) -> None:
+    def _update_distances(self, x: number) -> None:
        x_left, x_right = self.neighbors[x]
        y = self.data[x]
        if x_left is not None:
            self._distances[x_left] = hypot((x - x_left), (y - self.data[x_left]))
        if x_right is not None:
            self._distances[x] = hypot((x_right - x), (self.data[x_right] - y))

-    def _update_losses_resampling(self, x: Number, real=True) -> None:
+    def _update_losses_resampling(self, x: number, real=True) -> None:
        """Update all losses that depend on x, whenever the new point is a re-sampled point."""
        # (x_left, x_right) are the "real" neighbors of 'x'.
        x_left, x_right = self._find_neighbors(x, self.neighbors)
@@ -371,7 +386,9 @@ def tell_many(self, xs: Points, ys: Sequence[Value]) -> None:
            )

        # Create a mapping of points to a list of samples
-        mapping = defaultdict(lambda: defaultdict(dict))
+        mapping: DefaultDict[number, DefaultDict[int, Value]] = defaultdict(
+            lambda: defaultdict(dict)
+        )
        for (seed, x), y in zip(xs, ys):
            mapping[x][seed] = y

@@ -411,7 +428,7 @@ def tell_many_at_point(self, x: float, seed_y_mapping: Dict[int, Value]) -> None:
            self._update_data(x, y, "new")
            self._update_data_structures((seed, x), y, "new")

-        ys = list(seed_y_mapping.values())  # cast to list *and* make a copy
+        ys = np.array(list(seed_y_mapping.values()))

        # If x is not a new point or if there were more than 1 sample in ys:
        if len(ys) > 0:
@@ -441,10 +458,10 @@ def tell_many_at_point(self, x: float, seed_y_mapping: Dict[int, Value]) -> None
                    self._update_interpolated_loss_in_interval(*interval)
                self._oldscale = deepcopy(self._scale)

-    def _get_data(self) -> SortedDict:
+    def _get_data(self) -> SortedDict[number, Value]:
        return self._data_samples

-    def _set_data(self, data: SortedDict) -> None:
+    def _set_data(self, data: SortedDict[number, Value]) -> None:
        if data:
            for x, samples in data.items():
                self.tell_many_at_point(x, samples)
@@ -478,7 +495,7 @@ def plot(self):
        return p.redim(x=dict(range=plot_bounds))


-def decreasing_dict():
+def decreasing_dict() -> ItemSortedDict:
    """This initialization orders the dictionary from large to small values"""

    def sorting_rule(key, value):
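For context, here is a minimal usage sketch of `AverageLearner1D` with the signature annotated in this diff; it is not part of the change itself. It assumes a recent release of the `adaptive` package, and `noisy_parabola` is a made-up example function. As the new `function: Callable[[Tuple[int, number]], Value]` annotation indicates, the learner calls the function with a `(seed, x)` tuple, so the same x can be resampled with different seeds.

```python
# Hypothetical example (not part of the PR): drive AverageLearner1D by hand
# through the ask/tell interface typed in this diff.
import numpy as np
import adaptive


def noisy_parabola(seed_x):
    # The learner passes a (seed, x) tuple; using the seed keeps the noise reproducible.
    seed, x = seed_x
    rng = np.random.default_rng(seed)
    return x**2 + rng.normal(scale=0.1)


learner = adaptive.AverageLearner1D(
    noisy_parabola,
    bounds=(-1, 1),
    delta=0.2,       # resample a point whose error exceeds delta * smallest neighboring interval
    min_samples=50,  # sample every new x at least this many times
)

for _ in range(200):
    points, _ = learner.ask(10)  # points are (seed, x) tuples
    for seed_x in points:
        learner.tell(seed_x, noisy_parabola(seed_x))

print(learner.nsamples, learner.loss())
```

In practice one would typically let `adaptive.Runner` drive the learner instead of this manual loop; the loop is only meant to exercise `ask`/`tell` directly.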