11from collections import defaultdict
22from copy import deepcopy
33from math import hypot
4+ from numbers import Number
5+ from typing import Dict , List , Sequence , Tuple , Union
46
57import numpy as np
68import scipy .stats
1012from adaptive .learner .learner1D import Learner1D , _get_intervals
1113from adaptive .notebook_integration import ensure_holoviews
1214
# A sample request: ``(seed, x)``, where ``seed`` is the integer index of the
# sample taken at the domain value ``x``.
Point = Tuple[int, Number]
# A list of ``(seed, x)`` sample requests.
Points = List[Point]
# A single function value: either a scalar or a sequence of scalars.
Value = Union[Number, Sequence[Number]]
18+
1319
1420class AverageLearner1D (Learner1D ):
1521 """Learns and predicts a noisy function 'f:ℝ → ℝ^N'.
@@ -77,7 +83,7 @@ def __init__(
7783 self .neighbor_sampling = neighbor_sampling
7884
7985 # Contains all samples f(x) for each
80- # point x in the form {x0:[ f_0(x0), f_1(x0), ...] , ...}
86+ # point x in the form {x0: {0: f_0(x0), 1: f_1(x0), ...} , ...}
8187 self ._data_samples = SortedDict ()
8288 # Contains the number of samples taken
8389 # at each point x in the form {x0: n0, x1: n1, ...}
@@ -95,17 +101,17 @@ def __init__(
95101 self .rescaled_error = decreasing_dict ()
96102
@property
def nsamples(self) -> int:
    """Return the total number of samples taken over all points.

    This is the sum of the per-point counts stored in
    ``self._number_samples``.
    """
    return sum(self._number_samples.values())
101107
@property
def min_samples_per_point(self) -> int:
    """Return the smallest per-point sample count.

    Returns 0 when no point has been sampled yet, i.e. when
    ``self._number_samples`` is empty.
    """
    return min(self._number_samples.values(), default=0)
107113
108- def ask (self , n , tell_pending = True ):
114+ def ask (self , n : int , tell_pending : bool = True ) -> Tuple [ Points , List [ float ]] :
109115 """Return 'n' points that are expected to maximally reduce the loss."""
110116 # If some point is undersampled, resample it
111117 if len (self ._undersampled_points ):
@@ -133,32 +139,34 @@ def ask(self, n, tell_pending=True):
133139
134140 return points , loss_improvements
135141
136- def _ask_for_more_samples (self , x , n ) :
142+ def _ask_for_more_samples (self , x : Number , n : int ) -> Tuple [ Points , List [ float ]] :
137143 """When asking for n points, the learner returns n times an existing point
138144 to be resampled, since in general n << min_samples and this point will
139145 need to be resampled many more times"""
140- points = [x ] * n
146+ n_existing = self ._number_samples .get (x , 0 )
147+ points = [(seed + n_existing , x ) for seed in range (n )]
148+
141149 loss_improvements = [0 ] * n # We set the loss_improvements of resamples to 0
142150 return points , loss_improvements
143151
144- def _ask_for_new_point (self , n ) :
152+ def _ask_for_new_point (self , n : int ) -> Tuple [ Points , List [ float ]] :
145153 """When asking for n new points, the learner returns n times a single
146154 new point, since in general n << min_samples and this point will need
147155 to be resampled many more times"""
148156 points , loss_improvements = self ._ask_points_without_adding (1 )
149- points = points * n
157+ points = [( seed , x ) for seed , x in zip ( range ( n ), n * points )]
150158 loss_improvements = loss_improvements + [0 ] * (n - 1 )
151159 return points , loss_improvements
152160
def tell_pending(self, seed_x: Point) -> None:
    """Register the request ``(seed, x)`` as pending.

    The full ``(seed, x)`` tuple is added to ``self.pending_points``. When
    ``x`` has no entry in ``self.data`` yet, the combined neighbors and
    the losses are also updated for ``x`` (with ``real=False``).
    """
    _, x = seed_x
    self.pending_points.add(seed_x)
    if x not in self.data:
        self._update_neighbors(x, self.neighbors_combined)
        self._update_losses(x, real=False)
160167
161- def tell (self , x , y ):
168+ def tell (self , seed_x : Point , y : Value ) -> None :
169+ seed , x = seed_x
162170 if y is None :
163171 raise TypeError (
164172 "Y-value may not be None, use learner.tell_pending(x)"
@@ -170,13 +178,13 @@ def tell(self, x, y):
170178
171179 if x not in self .data :
172180 self ._update_data (x , y , "new" )
173- self ._update_data_structures (x , y , "new" )
174- else :
181+ self ._update_data_structures (seed_x , y , "new" )
182+ elif seed not in self . _data_samples [ x ]: # check if the seed is new
175183 self ._update_data (x , y , "resampled" )
176- self ._update_data_structures (x , y , "resampled" )
177- self .pending_points .discard (x )
184+ self ._update_data_structures (seed_x , y , "resampled" )
185+ self .pending_points .discard (seed_x )
178186
179- def _update_rescaled_error_in_mean (self , x , point_type : str ) -> None :
187+ def _update_rescaled_error_in_mean (self , x : Number , point_type : str ) -> None :
180188 """Updates ``self.rescaled_error``.
181189
182190 Parameters
@@ -213,17 +221,18 @@ def _update_rescaled_error_in_mean(self, x, point_type: str) -> None:
213221 norm = min (d_left , d_right )
214222 self .rescaled_error [x ] = self .error [x ] / norm
215223
216- def _update_data (self , x , y , point_type : str ):
224+ def _update_data (self , x : Number , y : Value , point_type : str ) -> None :
217225 if point_type == "new" :
218226 self .data [x ] = y
219227 elif point_type == "resampled" :
220228 n = len (self ._data_samples [x ])
221229 new_average = self .data [x ] * n / (n + 1 ) + y / (n + 1 )
222230 self .data [x ] = new_average
223231
224- def _update_data_structures (self , x , y , point_type : str ):
232+ def _update_data_structures (self , seed_x : Point , y : Value , point_type : str ) -> None :
233+ seed , x = seed_x
225234 if point_type == "new" :
226- self ._data_samples [x ] = [ y ]
235+ self ._data_samples [x ] = { seed : y }
227236
228237 if not self .bounds [0 ] <= x <= self .bounds [1 ]:
229238 return
@@ -247,7 +256,7 @@ def _update_data_structures(self, x, y, point_type: str):
247256 self ._update_rescaled_error_in_mean (x , "new" )
248257
249258 elif point_type == "resampled" :
250- self ._data_samples [x ]. append ( y )
259+ self ._data_samples [x ][ seed ] = y
251260 ns = self ._number_samples
252261 ns [x ] += 1
253262 n = ns [x ]
@@ -268,7 +277,7 @@ def _update_data_structures(self, x, y, point_type: str):
268277 # the std of the mean multiplied by a t-Student factor to ensure that
269278 # the mean value lies within the correct interval of confidence
270279 y_avg = self .data [x ]
271- ys = self ._data_samples [x ]
280+ ys = self ._data_samples [x ]. values ()
272281 self .error [x ] = self ._calc_error_in_mean (ys , y_avg , n )
273282 self ._update_distances (x )
274283 self ._update_rescaled_error_in_mean (x , "resampled" )
@@ -288,15 +297,15 @@ def _update_data_structures(self, x, y, point_type: str):
288297 self ._update_interpolated_loss_in_interval (* interval )
289298 self ._oldscale = deepcopy (self ._scale )
290299
291- def _update_distances (self , x ) :
300+ def _update_distances (self , x : Number ) -> None :
292301 x_left , x_right = self .neighbors [x ]
293302 y = self .data [x ]
294303 if x_left is not None :
295304 self ._distances [x_left ] = hypot ((x - x_left ), (y - self .data [x_left ]))
296305 if x_right is not None :
297306 self ._distances [x ] = hypot ((x_right - x ), (self .data [x_right ] - y ))
298307
299- def _update_losses_resampling (self , x , real = True ):
308+ def _update_losses_resampling (self , x : Number , real = True ) -> None :
300309 """Update all losses that depend on x, whenever the new point is a re-sampled point."""
301310 # (x_left, x_right) are the "real" neighbors of 'x'.
302311 x_left , x_right = self ._find_neighbors (x , self .neighbors )
@@ -325,42 +334,43 @@ def _update_losses_resampling(self, x, real=True):
325334 if (b is not None ) and right_loss_is_unknown :
326335 self .losses_combined [x , b ] = float ("inf" )
327336
328- def _calc_error_in_mean (self , ys , y_avg , n ) :
337+ def _calc_error_in_mean (self , ys : Sequence [ Value ] , y_avg : Value , n : int ) -> float :
329338 variance_in_mean = sum ((y - y_avg ) ** 2 for y in ys ) / (n - 1 )
330339 t_student = scipy .stats .t .ppf (1 - self .alpha , df = n - 1 )
331340 return t_student * (variance_in_mean / n ) ** 0.5
332341
def tell_many(self, xs: Points, ys: Sequence[Value]) -> None:
    """Tell the learner about many samples at once.

    Parameters
    ----------
    xs : list of (seed, x) tuples
        Sample identifiers; ``xs[i]`` belongs to ``ys[i]``.
    ys : sequence of Value
        Sampled values.

    Raises
    ------
    ValueError
        If any ``x`` lies outside ``self.bounds``.
    """
    # Check that all x are within the bounds
    lower, upper = self.bounds[0], self.bounds[1]
    if not all(lower <= x <= upper for _, x in xs):
        raise ValueError(
            "x value out of bounds, "
            "remove x or enlarge the bounds of the learner"
        )

    # Group the samples per point: {x: {seed: y, ...}, ...}.
    # A later sample with the same (seed, x) overwrites an earlier one.
    mapping = defaultdict(dict)
    for (seed, x), y in zip(xs, ys):
        mapping[x][seed] = y

    for x, seed_y_mapping in mapping.items():
        if len(seed_y_mapping) == 1:
            seed, y = next(iter(seed_y_mapping.items()))
            self.tell((seed, x), y)
        else:
            # More than one sample was collected for this x: use the more
            # efficient routine that tells many samples simultaneously.
            self.tell_many_at_point(x, seed_y_mapping)
354364
355- def tell_many_at_point (self , x , ys ) :
365+ def tell_many_at_point (self , x : float , seed_y_mapping : Dict [ int , Value ]) -> None :
356366 """Tell the learner about many samples at a certain location x.
357367
358368 Parameters
359369 ----------
360370 x : float
361371 Value from the function domain.
362- ys : List[float ]
363- List of data samples at ``x``.
372+ seed_y_mapping : Dict[int, Value ]
373+ Dictionary of ``seed`` -> ``y`` at ``x``.
364374 """
365375 # Check x is within the bounds
366376 if not np .prod (x >= self .bounds [0 ] and x <= self .bounds [1 ]):
@@ -369,16 +379,20 @@ def tell_many_at_point(self, x, ys):
369379 "remove x or enlarge the bounds of the learner"
370380 )
371381
372- ys = list (ys ) # cast to list *and* make a copy
373382 # If x is a new point:
374383 if x not in self .data :
375- y = ys .pop (0 )
384+ # we make a copy because we don't want to modify the original dict
385+ seed_y_mapping = seed_y_mapping .copy ()
386+ seed = next (iter (seed_y_mapping ))
387+ y = seed_y_mapping .pop (seed )
376388 self ._update_data (x , y , "new" )
377- self ._update_data_structures (x , y , "new" )
389+ self ._update_data_structures ((seed , x ), y , "new" )
390+
391+ ys = list (seed_y_mapping .values ()) # cast to list *and* make a copy
378392
379393 # If x is not a new point or if there were more than 1 sample in ys:
380394 if len (ys ) > 0 :
381- self ._data_samples [x ].extend ( ys )
395+ self ._data_samples [x ].update ( seed_y_mapping )
382396 n = len (ys ) + self ._number_samples [x ]
383397 self .data [x ] = (
384398 np .mean (ys ) * len (ys ) + self .data [x ] * self ._number_samples [x ]
@@ -390,24 +404,24 @@ def tell_many_at_point(self, x, ys):
390404 if n > self .min_samples :
391405 self ._undersampled_points .discard (x )
392406 self .error [x ] = self ._calc_error_in_mean (
393- self ._data_samples [x ], self .data [x ], n
407+ self ._data_samples [x ]. values () , self .data [x ], n
394408 )
395409 self ._update_distances (x )
396410 self ._update_rescaled_error_in_mean (x , "resampled" )
397411 if self .error [x ] <= self .min_error or n >= self .max_samples :
398412 self .rescaled_error .pop (x , None )
399- self ._update_scale (x , min (self ._data_samples [x ]))
400- self ._update_scale (x , max (self ._data_samples [x ]))
413+ self ._update_scale (x , min (self ._data_samples [x ]. values () ))
414+ self ._update_scale (x , max (self ._data_samples [x ]. values () ))
401415 self ._update_losses_resampling (x , real = True )
402416 if self ._scale [1 ] > self ._recompute_losses_factor * self ._oldscale [1 ]:
403417 for interval in reversed (self .losses ):
404418 self ._update_interpolated_loss_in_interval (* interval )
405419 self ._oldscale = deepcopy (self ._scale )
406420
407- def _get_data (self ):
421+ def _get_data (self ) -> SortedDict :
408422 return self ._data_samples
409423
410- def _set_data (self , data ) :
424+ def _set_data (self , data : SortedDict ) -> None :
411425 if data :
412426 for x , samples in data .items ():
413427 self .tell_many_at_point (x , samples )
0 commit comments