88
99import numpy as np
1010import sortedcontainers
11+ import sortedcollections
1112
1213from adaptive .learner .base_learner import BaseLearner
1314from adaptive .learner .learnerND import volume
@@ -225,9 +226,6 @@ def __init__(self, function, bounds, loss_per_interval=None):
225226
226227 self .loss_per_interval = loss_per_interval or default_loss
227228
228- # A dict storing the loss function for each interval x_n.
229- self .losses = {}
230- self .losses_combined = {}
231229
232230 # When the scale changes by a factor 2, the losses are
233231 # recomputed. This is tunable such that we can test
@@ -249,6 +247,10 @@ def __init__(self, function, bounds, loss_per_interval=None):
249247 self ._scale = [bounds [1 ] - bounds [0 ], 0 ]
250248 self ._oldscale = deepcopy (self ._scale )
251249
250+ # A LossManager storing the loss function for each interval x_n.
251+ self .losses = loss_manager (self ._scale [0 ])
252+ self .losses_combined = loss_manager (self ._scale [0 ])
253+
252254 # The precision in 'x' below which we set losses to 0.
253255 self ._dx_eps = 2 * max (np .abs (bounds )) * np .finfo (float ).eps
254256
@@ -284,7 +286,10 @@ def npoints(self):
284286 @cache_latest
285287 def loss (self , real = True ):
286288 losses = self .losses if real else self .losses_combined
287- return max (losses .values ()) if len (losses ) > 0 else float ('inf' )
289+ if not losses :
290+ return np .inf
291+ max_interval , max_loss = losses .peekitem (0 )
292+ return max_loss
288293
289294 def _scale_x (self , x ):
290295 if x is None :
@@ -454,8 +459,7 @@ def tell(self, x, y):
454459
455460 # If the scale has increased enough, recompute all losses.
456461 if self ._scale [1 ] > self ._recompute_losses_factor * self ._oldscale [1 ]:
457-
458- for interval in self .losses :
462+ for interval in reversed (self .losses ):
459463 self ._update_interpolated_loss_in_interval (* interval )
460464
461465 self ._oldscale = deepcopy (self ._scale )
@@ -504,18 +508,18 @@ def tell_many(self, xs, ys, *, force=False):
504508 for neighbors in (self .neighbors , self .neighbors_combined )]
505509
506510 # The the losses for the "real" intervals.
507- self .losses = {}
511+ self .losses = loss_manager ( self . _scale [ 0 ])
508512 for ival in intervals :
509513 self .losses [ival ] = self ._get_loss_in_interval (* ival )
510514
511515 # List with "real" intervals that have interpolated intervals inside
512516 to_interpolate = []
513517
514- self .losses_combined = {}
518+ self .losses_combined = loss_manager ( self . _scale [ 0 ])
515519 for ival in intervals_combined :
516520 # If this interval exists in 'losses' then copy it otherwise
517521 # calculate it.
518- if ival in self .losses :
522+ if ival in reversed ( self .losses ) :
519523 self .losses_combined [ival ] = self .losses [ival ]
520524 else :
521525 # Set all losses to inf now, later they might be udpdated if the
@@ -530,7 +534,7 @@ def tell_many(self, xs, ys, *, force=False):
530534 to_interpolate .append ((x_left , x_right ))
531535
532536 for ival in to_interpolate :
533- if ival in self .losses :
537+ if ival in reversed ( self .losses ) :
534538 # If this interval does not exist it should already
535539 # have an inf loss.
536540 self ._update_interpolated_loss_in_interval (* ival )
@@ -566,64 +570,57 @@ def _ask_points_without_adding(self, n):
566570 if len (missing_bounds ) >= n :
567571 return missing_bounds [:n ], [np .inf ] * n
568572
569- def finite_loss (loss , xs ):
570- # If the loss is infinite we return the
571- # distance between the two points.
572- if math .isinf (loss ):
573- loss = (xs [1 ] - xs [0 ]) / self ._scale [0 ]
574-
575- # We round the loss to 12 digits such that losses
576- # are equal up to numerical precision will be considered
577- # equal.
578- return round (loss , ndigits = 12 )
579-
580- quals = [(- finite_loss (loss , x ), x , 1 )
581- for x , loss in self .losses_combined .items ()]
582-
583573 # Add bound intervals to quals if bounds were missing.
584574 if len (self .data ) + len (self .pending_points ) == 0 :
585575 # We don't have any points, so return a linspace with 'n' points.
586576 return np .linspace (* self .bounds , n ).tolist (), [np .inf ] * n
587- elif len (missing_bounds ) > 0 :
577+
578+ quals = loss_manager (self ._scale [0 ])
579+ if len (missing_bounds ) > 0 :
588580 # There is at least one point in between the bounds.
589581 all_points = list (self .data .keys ()) + list (self .pending_points )
590582 intervals = [(self .bounds [0 ], min (all_points )),
591583 (max (all_points ), self .bounds [1 ])]
592584 for interval , bound in zip (intervals , self .bounds ):
593585 if bound in missing_bounds :
594- qual = (- finite_loss (np .inf , interval ), interval , 1 )
595- quals .append (qual )
596-
597- # Calculate how many points belong to each interval.
598- points , loss_improvements = self ._subdivide_quals (
599- quals , n - len (missing_bounds ))
600-
601- points = missing_bounds + points
602- loss_improvements = [np .inf ] * len (missing_bounds ) + loss_improvements
586+ quals [(* interval , 1 )] = np .inf
603587
604- return points , loss_improvements
588+ points_to_go = n - len ( missing_bounds )
605589
606- def _subdivide_quals (self , quals , n ):
607590 # Calculate how many points belong to each interval.
608- heapq .heapify (quals )
609-
610- for _ in range (n ):
611- quality , x , n = quals [0 ]
612- if abs (x [1 ] - x [0 ]) / (n + 1 ) <= self ._dx_eps :
613- # The interval is too small and should not be subdivided.
614- quality = np .inf
615- # XXX: see https://gitlab.kwant-project.org/qt/adaptive/issues/104
616- heapq .heapreplace (quals , (quality * n / (n + 1 ), x , n + 1 ))
591+ i , i_max = 0 , len (self .losses_combined )
592+ for _ in range (points_to_go ):
593+ qual , loss_qual = quals .peekitem (0 ) if quals else (None , 0 )
594+ ival , loss_ival = self .losses_combined .peekitem (i ) if i < i_max else (None , 0 )
595+
596+ if (qual is None
597+ or (ival is not None
598+ and self ._loss (self .losses_combined , ival )
599+ >= self ._loss (quals , qual ))):
600+ i += 1
601+ quals [(* ival , 2 )] = loss_ival / 2
602+ else :
603+ quals .pop (qual , None )
604+ * xs , n = qual
605+ quals [(* xs , n + 1 )] = loss_qual * n / (n + 1 )
617606
618607 points = list (itertools .chain .from_iterable (
619- linspace (* interval , n ) for quality , interval , n in quals ))
608+ linspace (* ival , n ) for ( * ival , n ) in quals ))
620609
621610 loss_improvements = list (itertools .chain .from_iterable (
622- itertools .repeat (- quality , n - 1 )
623- for quality , interval , n in quals ))
611+ itertools .repeat (quals [x0 , x1 , n ], n - 1 )
612+ for (x0 , x1 , n ) in quals ))
613+
614+ # add the missing bounds
615+ points = missing_bounds + points
616+ loss_improvements = [np .inf ] * len (missing_bounds ) + loss_improvements
624617
625618 return points , loss_improvements
626619
620+ def _loss (self , mapping , ival ):
621+ loss = mapping [ival ]
622+ return finite_loss (ival , loss , self ._scale [0 ])
623+
627624 def plot (self ):
628625 """Returns a plot of the evaluated data.
629626
@@ -658,3 +655,42 @@ def _get_data(self):
658655
659656 def _set_data (self , data ):
660657 self .tell_many (* zip (* data .items ()))
658+
659+
660+ def _fix_deepcopy (sorted_dict , x_scale ):
661+ # XXX: until https://github.com/grantjenks/sortedcollections/issues/5 is fixed
662+ import types
663+ def __deepcopy__ (self , memo ):
664+ items = deepcopy (list (self .items ()))
665+ lm = loss_manager (self .x_scale )
666+ lm .update (items )
667+ return lm
668+ sorted_dict .x_scale = x_scale
669+ sorted_dict .__deepcopy__ = types .MethodType (__deepcopy__ , sorted_dict )
670+
671+
672+ def loss_manager (x_scale ):
673+ def sort_key (ival , loss ):
674+ loss , ival = finite_loss (ival , loss , x_scale )
675+ return - loss , ival
676+ sorted_dict = sortedcollections .ItemSortedDict (sort_key )
677+ _fix_deepcopy (sorted_dict , x_scale )
678+ return sorted_dict
679+
680+
681+ def finite_loss (ival , loss , x_scale ):
682+ """Get the socalled finite_loss of an interval in order to be able to
683+ sort intervals that have infinite loss."""
684+ # If the loss is infinite we return the
685+ # distance between the two points.
686+ if math .isinf (loss ):
687+ loss = (ival [1 ] - ival [0 ]) / x_scale
688+ if len (ival ) == 3 :
689+ # Used when constructing quals. Last item is
690+ # the number of points inside the qual.
691+ loss /= ival [2 ]
692+
693+ # We round the loss to 12 digits such that losses
694+ # are equal up to numerical precision will be considered
695+ # equal.
696+ return round (loss , ndigits = 12 ), ival
0 commit comments