99import numpy as np
1010from scipy import interpolate
1111import scipy .spatial
12+ from sortedcontainers import SortedKeyList
1213
1314from adaptive .learner .base_learner import BaseLearner
1415from adaptive .notebook_integration import ensure_holoviews , ensure_plotly
@@ -91,7 +92,6 @@ def choose_point_in_simplex(simplex, transform=None):
9192 distance_matrix = scipy .spatial .distance .squareform (distances )
9293 i , j = np .unravel_index (np .argmax (distance_matrix ),
9394 distance_matrix .shape )
94-
9595 point = (simplex [i , :] + simplex [j , :]) / 2
9696
9797 if transform is not None :
@@ -100,6 +100,15 @@ def choose_point_in_simplex(simplex, transform=None):
100100 return point
101101
102102
103+ def _simplex_evaluation_priority (key ):
104+ # We round the loss to 8 digits such that losses
105+ # are equal up to numerical precision will be considered
106+ # to be equal. This is needed because we want the learner
107+ # to behave in a deterministic fashion.
108+ loss , simplex , subsimplex = key
109+ return - round (loss , ndigits = 8 ), simplex , subsimplex or (0 ,)
110+
111+
103112class LearnerND (BaseLearner ):
104113 """Learns and predicts a function 'f: ℝ^N → ℝ^M'.
105114
@@ -200,7 +209,7 @@ def __init__(self, func, bounds, loss_per_simplex=None):
200209 # so when popping an item, you should check that the simplex that has
201210 # been returned has not been deleted. This checking is done by
202211 # _pop_highest_existing_simplex
203- self ._simplex_queue = [] # heap
212+ self ._simplex_queue = SortedKeyList ( key = _simplex_evaluation_priority )
204213
205214 @property
206215 def npoints (self ):
@@ -344,9 +353,7 @@ def _update_subsimplex_losses(self, simplex, new_subsimplices):
344353 subtriangulation = self ._subtriangulations [simplex ]
345354 for subsimplex in new_subsimplices :
346355 subloss = subtriangulation .volume (subsimplex ) * loss_density
347- subloss = round (subloss , ndigits = 8 )
348- heapq .heappush (self ._simplex_queue ,
349- (- subloss , simplex , subsimplex ))
356+ self ._simplex_queue .add ((subloss , simplex , subsimplex ))
350357
351358 def _ask_and_tell_pending (self , n = 1 ):
352359 xs , losses = zip (* (self ._ask () for _ in range (n )))
@@ -386,7 +393,7 @@ def _pop_highest_existing_simplex(self):
386393 # find the simplex with the highest loss, we do need to check that the
387394 # simplex hasn't been deleted yet
388395 while len (self ._simplex_queue ):
389- loss , simplex , subsimplex = heapq . heappop ( self ._simplex_queue )
396+ loss , simplex , subsimplex = self ._simplex_queue . pop ( 0 )
390397 if (subsimplex is None
391398 and simplex in self .tri .simplices
392399 and simplex not in self ._subtriangulations ):
@@ -462,8 +469,7 @@ def _update_losses(self, to_delete: set, to_add: set):
462469 self ._try_adding_pending_point_to_simplex (p , simplex )
463470
464471 if simplex not in self ._subtriangulations :
465- loss = round (loss , ndigits = 8 )
466- heapq .heappush (self ._simplex_queue , (- loss , simplex , None ))
472+ self ._simplex_queue .add ((loss , simplex , None ))
467473 continue
468474
469475 self ._update_subsimplex_losses (
@@ -488,7 +494,7 @@ def _recompute_all_losses(self):
488494 return
489495
490496 # reset the _simplex_queue
491- self ._simplex_queue = []
497+ self ._simplex_queue = SortedKeyList ( key = _simplex_evaluation_priority )
492498
493499 # recompute all losses
494500 for simplex in self .tri .simplices :
@@ -497,8 +503,7 @@ def _recompute_all_losses(self):
497503
498504 # now distribute it around the the children if they are present
499505 if simplex not in self ._subtriangulations :
500- loss = round (loss , ndigits = 8 )
501- heapq .heappush (self ._simplex_queue , (- loss , simplex , None ))
506+ self ._simplex_queue .add ((loss , simplex , None ))
502507 continue
503508
504509 self ._update_subsimplex_losses (
0 commit comments