33import heapq
44import itertools
55import math
6+ from collections import Iterable
67
78import numpy as np
89import sortedcontainers
910
1011from .base_learner import BaseLearner
12+ from .learnerND import volume
13+ from .triangulation import simplex_volume_in_embedding
1114from ..notebook_integration import ensure_holoviews
1215from ..utils import cache_latest
1316
1417
15- def uniform_loss (interval , scale , function_values ):
18+ def uses_nth_neighbors (n ):
19+ """Decorator to specify how many neighboring intervals the loss function uses.
20+
21+ Wraps loss functions to indicate that they expect intervals together
22+ with ``n`` nearest neighbors
23+
24+ The loss function will then receive the data of the N nearest neighbors
25+ (``nth_neighbors``) aling with the data of the interval itself in a dict.
26+ The `~adaptive.Learner1D` will also make sure that the loss is updated
27+ whenever one of the ``nth_neighbors`` changes.
28+
29+ Examples
30+ --------
31+
32+ The next function is a part of the `curvature_loss_function` function.
33+
34+ >>> @uses_nth_neighbors(1)
35+ ...def triangle_loss(xs, ys):
36+ ... xs = [x for x in xs if x is not None]
37+ ... ys = [y for y in ys if y is not None]
38+ ...
39+ ... if len(xs) == 2: # we do not have enough points for a triangle
40+ ... return xs[1] - xs[0]
41+ ...
42+ ... N = len(xs) - 2 # number of constructed triangles
43+ ... if isinstance(ys[0], Iterable):
44+ ... pts = [(x, *y) for x, y in zip(xs, ys)]
45+ ... vol = simplex_volume_in_embedding
46+ ... else:
47+ ... pts = [(x, y) for x, y in zip(xs, ys)]
48+ ... vol = volume
49+ ... return sum(vol(pts[i:i+3]) for i in range(N)) / N
50+
51+ Or you may define a loss that favours the (local) minima of a function,
52+ assuming that you know your function will have a single float as output.
53+
54+ >>> @uses_nth_neighbors(1)
55+ ... def local_minima_resolving_loss(xs, ys):
56+ ... dx = xs[2] - xs[1] # the width of the interval of interest
57+ ...
58+ ... if not ((ys[0] is not None and ys[0] > ys[1])
59+ ... or (ys[3] is not None and ys[3] > ys[2])):
60+ ... return loss * 100
61+ ...
62+ ... return loss
63+ """
64+ def _wrapped (loss_per_interval ):
65+ loss_per_interval .nth_neighbors = n
66+ return loss_per_interval
67+ return _wrapped
68+
69+ @uses_nth_neighbors (0 )
70+ def uniform_loss (xs , ys ):
1671 """Loss function that samples the domain uniformly.
1772
1873 Works with `~adaptive.Learner1D` only.
@@ -27,33 +82,58 @@ def uniform_loss(interval, scale, function_values):
2782 ... loss_per_interval=uniform_sampling_1d)
2883 >>>
2984 """
30- x_left , x_right = interval
31- x_scale , _ = scale
32- dx = (x_right - x_left ) / x_scale
85+ dx = xs [1 ] - xs [0 ]
3386 return dx
3487
3588
36- def default_loss (interval , scale , function_values ):
89+ @uses_nth_neighbors (0 )
90+ def default_loss (xs , ys ):
3791 """Calculate loss on a single interval.
3892
3993 Currently returns the rescaled length of the interval. If one of the
4094 y-values is missing, returns 0 (so the intervals with missing data are
4195 never touched. This behavior should be improved later.
4296 """
43- x_left , x_right = interval
44- y_right , y_left = function_values [x_right ], function_values [x_left ]
45- x_scale , y_scale = scale
46- dx = (x_right - x_left ) / x_scale
47- if y_scale == 0 :
48- loss = dx
97+ dx = xs [1 ] - xs [0 ]
98+ if isinstance (ys [0 ], Iterable ):
99+ dy = [abs (a - b ) for a , b in zip (* ys )]
100+ return np .hypot (dx , dy ).max ()
101+ else :
102+ dy = ys [1 ] - ys [0 ]
103+ return np .hypot (dx , dy )
104+
105+
106+ @uses_nth_neighbors (1 )
107+ def triangle_loss (xs , ys ):
108+ xs = [x for x in xs if x is not None ]
109+ ys = [y for y in ys if y is not None ]
110+
111+ if len (xs ) == 2 : # we do not have enough points for a triangle
112+ return xs [1 ] - xs [0 ]
113+
114+ N = len (xs ) - 2 # number of constructed triangles
115+ if isinstance (ys [0 ], Iterable ):
116+ pts = [(x , * y ) for x , y in zip (xs , ys )]
117+ vol = simplex_volume_in_embedding
49118 else :
50- dy = (y_right - y_left ) / y_scale
51- try :
52- len (dy )
53- loss = np .hypot (dx , dy ).max ()
54- except TypeError :
55- loss = math .hypot (dx , dy )
56- return loss
119+ pts = [(x , y ) for x , y in zip (xs , ys )]
120+ vol = volume
121+ return sum (vol (pts [i :i + 3 ]) for i in range (N )) / N
122+
123+
124+ def curvature_loss_function (area_factor = 1 , euclid_factor = 0.02 , horizontal_factor = 0.02 ):
125+ @uses_nth_neighbors (1 )
126+ def curvature_loss (xs , ys ):
127+ xs_middle = xs [1 :3 ]
128+ ys_middle = xs [1 :3 ]
129+
130+ triangle_loss_ = triangle_loss (xs , ys )
131+ default_loss_ = default_loss (xs_middle , ys_middle )
132+ dx = xs_middle [0 ] - xs_middle [0 ]
133+ return (area_factor * (triangle_loss_ ** 0.5 )
134+ + euclid_factor * default_loss_
135+ + horizontal_factor * dx )
136+ return curvature_loss
57137
58138
59139def linspace (x_left , x_right , n ):
@@ -79,6 +159,15 @@ def _get_neighbors_from_list(xs):
79159 return sortedcontainers .SortedDict (neighbors )
80160
81161
162+ def _get_intervals (x , neighbors , nth_neighbors ):
163+ nn = nth_neighbors
164+ i = neighbors .index (x )
165+ start = max (0 , i - nn - 1 )
166+ end = min (len (neighbors ), i + nn + 2 )
167+ points = neighbors .keys ()[start :end ]
168+ return list (zip (points , points [1 :]))
169+
170+
82171class Learner1D (BaseLearner ):
83172 """Learns and predicts a function 'f:ℝ → ℝ^N'.
84173
@@ -103,21 +192,34 @@ class Learner1D(BaseLearner):
103192
104193 Notes
105194 -----
106- `loss_per_interval` takes 3 parameters: ``interval``, ``scale``, and
107- ``function_values``, and returns a scalar; the loss over the interval.
108-
109- interval : (float, float)
110- The bounds of the interval.
111- scale : (float, float)
112- The x and y scale over all the intervals, useful for rescaling the
113- interval loss.
114- function_values : dict(float → float)
115- A map containing evaluated function values. It is guaranteed
116- to have values for both of the points in 'interval'.
195+ `loss_per_interval` takes 2 parameters: ``xs`` and ``ys``, and returns a
196+ scalar; the loss over the interval.
197+ xs : tuple of floats
198+ The x values of the interval, if `nth_neighbors` is greater than zero it
199+ also contains the x-values of the neighbors of the interval, in ascending
200+ order. The interval we want to know the loss of is then the middle
201+ interval. If no neighbor is available (at the edges of the domain) then
202+ `None` will take the place of the x-value of the neighbor.
203+ ys : tuple of function values
204+ The output values of the function when evaluated at the `xs`. This is
205+ either a float or a tuple of floats in the case of vector output.
206+
207+
208+ The `loss_per_interval` function may also have an attribute `nth_neighbors`
209+ that indicates how many of the neighboring intervals to `interval` are used.
210+ If `loss_per_interval` doesn't have such an attribute, it's assumed that is
211+ uses **no** neighboring intervals. Also see the `uses_nth_neighbors`
212+ decorator for more information.
117213 """
118214
119215 def __init__ (self , function , bounds , loss_per_interval = None ):
120216 self .function = function
217+
218+ if hasattr (loss_per_interval , 'nth_neighbors' ):
219+ self .nth_neighbors = loss_per_interval .nth_neighbors
220+ else :
221+ self .nth_neighbors = 0
222+
121223 self .loss_per_interval = loss_per_interval or default_loss
122224
123225 # A dict storing the loss function for each interval x_n.
@@ -176,25 +278,60 @@ def loss(self, real=True):
176278 losses = self .losses if real else self .losses_combined
177279 return max (losses .values ()) if len (losses ) > 0 else float ('inf' )
178280
281+ def _scale_x (self , x ):
282+ if x is None :
283+ return None
284+ return x / self ._scale [0 ]
285+
286+ def _scale_y (self , y ):
287+ if y is None :
288+ return None
289+ y_scale = self ._scale [1 ] or 1
290+ return y / y_scale
291+
292+ def _get_point_by_index (self , ind ):
293+ if ind < 0 or ind >= len (self .neighbors ):
294+ return None
295+ return self .neighbors .keys ()[ind ]
296+
297+ def _get_loss_in_interval (self , x_left , x_right ):
298+ assert x_left is not None and x_right is not None
299+
300+ if x_right - x_left < self ._dx_eps :
301+ return 0
302+
303+ nn = self .nth_neighbors
304+ i = self .neighbors .index (x_left )
305+ start = i - nn
306+ end = i + nn + 2
307+
308+ xs = [self ._get_point_by_index (i ) for i in range (start , end )]
309+ ys = [self .data .get (x , None ) for x in xs ]
310+
311+ xs_scaled = tuple (self ._scale_x (x ) for x in xs )
312+ ys_scaled = tuple (self ._scale_y (y ) for y in ys )
313+
314+ # we need to compute the loss for this interval
315+ return self .loss_per_interval (xs_scaled , ys_scaled )
316+
179317 def _update_interpolated_loss_in_interval (self , x_left , x_right ):
180- if x_left is not None and x_right is not None :
181- dx = x_right - x_left
182- if dx < self ._dx_eps :
183- loss = 0
184- else :
185- loss = self .loss_per_interval ((x_left , x_right ),
186- self ._scale , self .data )
187- self .losses [x_left , x_right ] = loss
188-
189- # Iterate over all interpolated intervals in between
190- # x_left and x_right and set the newly interpolated loss.
191- a , b = x_left , None
192- while b != x_right :
193- b = self .neighbors_combined [a ][1 ]
194- self .losses_combined [a , b ] = (b - a ) * loss / dx
195- a = b
318+ if x_left is None or x_right is None :
319+ return
320+
321+ loss = self ._get_loss_in_interval (x_left , x_right )
322+ self .losses [x_left , x_right ] = loss
323+
324+ # Iterate over all interpolated intervals in between
325+ # x_left and x_right and set the newly interpolated loss.
326+ a , b = x_left , None
327+ dx = x_right - x_left
328+ while b != x_right :
329+ b = self .neighbors_combined [a ][1 ]
330+ self .losses_combined [a , b ] = (b - a ) * loss / dx
331+ a = b
196332
197333 def _update_losses (self , x , real = True ):
334+ """Update all losses that depend on x"""
198335 # When we add a new point x, we should update the losses
199336 # (x_left, x_right) are the "real" neighbors of 'x'.
200337 x_left , x_right = self ._find_neighbors (x , self .neighbors )
@@ -207,10 +344,11 @@ def _update_losses(self, x, real=True):
207344
208345 if real :
209346 # We need to update all interpolated losses in the interval
210- # (x_left, x) and (x, x_right). Since the addition of the point
211- # 'x' could change their loss.
212- self ._update_interpolated_loss_in_interval (x_left , x )
213- self ._update_interpolated_loss_in_interval (x , x_right )
347+ # (x_left, x), (x, x_right) and the nth_neighbors nearest
348+ # neighboring intervals. Since the addition of the
349+ # point 'x' could change their loss.
350+ for ival in _get_intervals (x , self .neighbors , self .nth_neighbors ):
351+ self ._update_interpolated_loss_in_interval (* ival )
214352
215353 # Since 'x' is in between (x_left, x_right),
216354 # we get rid of the interval.
@@ -284,6 +422,9 @@ def tell(self, x, y):
284422 if x in self .data :
285423 # The point is already evaluated before
286424 return
425+ if y is None :
426+ raise TypeError ("Y-value may not be None, use learner.tell_pending(x)"
427+ "to indicate that this value is currently being calculated" )
287428
288429 # either it is a float/int, if not, try casting to a np.array
289430 if not isinstance (y , (float , int )):
@@ -356,10 +497,8 @@ def tell_many(self, xs, ys, *, force=False):
356497
357498 # The the losses for the "real" intervals.
358499 self .losses = {}
359- for x_left , x_right in intervals :
360- self .losses [x_left , x_right ] = (
361- self .loss_per_interval ((x_left , x_right ), self ._scale , self .data )
362- if x_right - x_left >= self ._dx_eps else 0 )
500+ for ival in intervals :
501+ self .losses [ival ] = self ._get_loss_in_interval (* ival )
363502
364503 # List with "real" intervals that have interpolated intervals inside
365504 to_interpolate = []
0 commit comments