33from collections import defaultdict
44from copy import deepcopy
55from math import hypot
6- from typing import (
7- Callable ,
8- DefaultDict ,
9- Dict ,
10- List ,
11- Optional ,
12- Sequence ,
13- Set ,
14- Tuple ,
15- Union ,
16- )
6+ from typing import Callable , DefaultDict , Dict , List , Optional , Sequence , Set , Tuple
177
188import numpy as np
199import scipy .stats
2212
2313from adaptive .learner .learner1D import Learner1D , _get_intervals
2414from adaptive .notebook_integration import ensure_holoviews
15+ from adaptive .types import Real
2516
26- number = Union [int , float , np .int_ , np .float_ ]
27- Point = Tuple [int , number ]
17+ Point = Tuple [int , Real ]
2818Points = List [Point ]
2919
3020__all__ : List [str ] = ["AverageLearner1D" ]
@@ -45,7 +35,7 @@ class AverageLearner1D(Learner1D):
4535 If not provided, then a default is used, which uses the scaled distance
4636 in the x-y plane as the loss. See the notes for more details
4737 of `adaptive.Learner1D` for more details.
48- delta : float
38+ delta : float, optional, default 0.2
4939 This parameter controls the resampling condition. A point is resampled
5040 if its uncertainty is larger than delta times the smallest neighboring
5141 interval.
@@ -75,10 +65,10 @@ class AverageLearner1D(Learner1D):
7565
7666 def __init__ (
7767 self ,
78- function : Callable [[Tuple [int , number ]], number ],
79- bounds : Tuple [number , number ],
68+ function : Callable [[Tuple [int , Real ]], Real ],
69+ bounds : Tuple [Real , Real ],
8070 loss_per_interval : Optional [
81- Callable [[Sequence [number ], Sequence [number ]], float ]
71+ Callable [[Sequence [Real ], Sequence [Real ]], float ]
8272 ] = None ,
8373 delta : float = 0.2 ,
8474 alpha : float = 0.005 ,
@@ -115,15 +105,15 @@ def __init__(
115105 self ._number_samples = SortedDict ()
116106 # This set contains the points x that have less than min_samples
117107 # samples or less than a (neighbor_sampling*100)% of their neighbors
118- self ._undersampled_points : Set [number ] = set ()
108+ self ._undersampled_points : Set [Real ] = set ()
119109 # Contains the error in the estimate of the
120110 # mean at each point x in the form {x0: error(x0), ...}
121- self .error : ItemSortedDict [number , float ] = decreasing_dict ()
111+ self .error : ItemSortedDict [Real , float ] = decreasing_dict ()
122112 # Distance between two neighboring points in the
123113 # form {xi: ((xii-xi)^2 + (yii-yi)^2)^0.5, ...}
124- self ._distances : ItemSortedDict [number , float ] = decreasing_dict ()
114+ self ._distances : ItemSortedDict [Real , float ] = decreasing_dict ()
125115 # {xii: error[xii]/min(_distances[xi], _distances[xii], ...}
126- self .rescaled_error : ItemSortedDict [number , float ] = decreasing_dict ()
116+ self .rescaled_error : ItemSortedDict [Real , float ] = decreasing_dict ()
127117
128118 @property
129119 def nsamples (self ) -> int :
@@ -165,7 +155,7 @@ def ask(self, n: int, tell_pending: bool = True) -> Tuple[Points, List[float]]:
165155
166156 return points , loss_improvements
167157
168- def _ask_for_more_samples (self , x : number , n : int ) -> Tuple [Points , List [float ]]:
158+ def _ask_for_more_samples (self , x : Real , n : int ) -> Tuple [Points , List [float ]]:
169159 """When asking for n points, the learner returns n times an existing point
170160 to be resampled, since in general n << min_samples and this point will
171161 need to be resampled many more times"""
@@ -200,7 +190,7 @@ def tell_pending(self, seed_x: Point) -> None:
200190 self ._update_neighbors (x , self .neighbors_combined )
201191 self ._update_losses (x , real = False )
202192
203- def tell (self , seed_x : Point , y : number ) -> None :
193+ def tell (self , seed_x : Point , y : Real ) -> None :
204194 seed , x = seed_x
205195 if y is None :
206196 raise TypeError (
@@ -216,7 +206,7 @@ def tell(self, seed_x: Point, y: number) -> None:
216206 self ._update_data_structures (seed_x , y , "resampled" )
217207 self .pending_points .discard (seed_x )
218208
219- def _update_rescaled_error_in_mean (self , x : number , point_type : str ) -> None :
209+ def _update_rescaled_error_in_mean (self , x : Real , point_type : str ) -> None :
220210 """Updates ``self.rescaled_error``.
221211
222212 Parameters
@@ -253,17 +243,15 @@ def _update_rescaled_error_in_mean(self, x: number, point_type: str) -> None:
253243 norm = min (d_left , d_right )
254244 self .rescaled_error [x ] = self .error [x ] / norm
255245
256- def _update_data (self , x : number , y : number , point_type : str ) -> None :
246+ def _update_data (self , x : Real , y : Real , point_type : str ) -> None :
257247 if point_type == "new" :
258248 self .data [x ] = y
259249 elif point_type == "resampled" :
260250 n = len (self ._data_samples [x ])
261251 new_average = self .data [x ] * n / (n + 1 ) + y / (n + 1 )
262252 self .data [x ] = new_average
263253
264- def _update_data_structures (
265- self , seed_x : Point , y : number , point_type : str
266- ) -> None :
254+ def _update_data_structures (self , seed_x : Point , y : Real , point_type : str ) -> None :
267255 seed , x = seed_x
268256 if point_type == "new" :
269257 self ._data_samples [x ] = {seed : y }
@@ -331,15 +319,15 @@ def _update_data_structures(
331319 self ._update_interpolated_loss_in_interval (* interval )
332320 self ._oldscale = deepcopy (self ._scale )
333321
334- def _update_distances (self , x : number ) -> None :
322+ def _update_distances (self , x : Real ) -> None :
335323 x_left , x_right = self .neighbors [x ]
336324 y = self .data [x ]
337325 if x_left is not None :
338326 self ._distances [x_left ] = hypot ((x - x_left ), (y - self .data [x_left ]))
339327 if x_right is not None :
340328 self ._distances [x ] = hypot ((x_right - x ), (self .data [x_right ] - y ))
341329
342- def _update_losses_resampling (self , x : number , real = True ) -> None :
330+ def _update_losses_resampling (self , x : Real , real = True ) -> None :
343331 """Update all losses that depend on x, whenever the new point is a re-sampled point."""
344332 # (x_left, x_right) are the "real" neighbors of 'x'.
345333 x_left , x_right = self ._find_neighbors (x , self .neighbors )
@@ -368,12 +356,12 @@ def _update_losses_resampling(self, x: number, real=True) -> None:
368356 if (b is not None ) and right_loss_is_unknown :
369357 self .losses_combined [x , b ] = float ("inf" )
370358
371- def _calc_error_in_mean (self , ys : Sequence [number ], y_avg : number , n : int ) -> float :
359+ def _calc_error_in_mean (self , ys : Sequence [Real ], y_avg : Real , n : int ) -> float :
372360 variance_in_mean = sum ((y - y_avg ) ** 2 for y in ys ) / (n - 1 )
373361 t_student = scipy .stats .t .ppf (1 - self .alpha , df = n - 1 )
374362 return t_student * (variance_in_mean / n ) ** 0.5
375363
376- def tell_many (self , xs : Points , ys : Sequence [number ]) -> None :
364+ def tell_many (self , xs : Points , ys : Sequence [Real ]) -> None :
377365 # Check that all x are within the bounds
378366 # TODO: remove this requirement, all other learners add the data
379367 # but ignore it going forward.
@@ -384,7 +372,7 @@ def tell_many(self, xs: Points, ys: Sequence[number]) -> None:
384372 )
385373
386374 # Create a mapping of points to a list of samples
387- mapping : DefaultDict [number , DefaultDict [int , number ]] = defaultdict (
375+ mapping : DefaultDict [Real , DefaultDict [int , Real ]] = defaultdict (
388376 lambda : defaultdict (dict )
389377 )
390378 for (seed , x ), y in zip (xs , ys ):
@@ -400,14 +388,14 @@ def tell_many(self, xs: Points, ys: Sequence[number]) -> None:
400388 # simultaneously, before we move on to a new x
401389 self .tell_many_at_point (x , seed_y_mapping )
402390
403- def tell_many_at_point (self , x : number , seed_y_mapping : Dict [int , number ]) -> None :
391+ def tell_many_at_point (self , x : Real , seed_y_mapping : Dict [int , Real ]) -> None :
404392 """Tell the learner about many samples at a certain location x.
405393
406394 Parameters
407395 ----------
408396 x : float
409397 Value from the function domain.
410- seed_y_mapping : Dict[int, number ]
398+ seed_y_mapping : Dict[int, Real ]
411399 Dictionary of ``seed`` -> ``y`` at ``x``.
412400 """
413401 # Check x is within the bounds
@@ -456,10 +444,10 @@ def tell_many_at_point(self, x: number, seed_y_mapping: Dict[int, number]) -> No
456444 self ._update_interpolated_loss_in_interval (* interval )
457445 self ._oldscale = deepcopy (self ._scale )
458446
459- def _get_data (self ) -> SortedDict [number , number ]:
447+ def _get_data (self ) -> SortedDict [Real , Real ]:
460448 return self ._data_samples
461449
462- def _set_data (self , data : SortedDict [number , number ]) -> None :
450+ def _set_data (self , data : SortedDict [Real , Real ]) -> None :
463451 if data :
464452 for x , samples in data .items ():
465453 self .tell_many_at_point (x , samples )
0 commit comments