1515from ..utils import cache_latest
1616
1717
18+ def uses_nth_neighbors (n ):
19+ """Decorator to specify how many neighboring intervals the loss function uses.
20+
21+ Wraps loss functions to indicate that they expect intervals together
22+ with ``n`` nearest neighbors
23+
24+ The loss function is then guaranteed to receive the data of at least the
25+ N nearest neighbors (``nth_neighbors``) in a dict that tells you what the
26+ neighboring points of these are. And the `~adaptive.Learner1D` will
27+ then make sure that the loss is updated whenever one of the
28+ ``nth_neighbors`` changes.
29+
30+ Examples
31+ --------
32+
33+ The next function is a part of the `get_curvature_loss` function.
34+
35+ >>> @uses_nth_neighbors(1)
36+ ... def triangle_loss(interval, scale, data, neighbors):
37+ ... x_left, x_right = interval
38+ ... xs = [neighbors[x_left][0], x_left, x_right, neighbors[x_right][1]]
39+ ... # at the boundary, neighbors[<left boundary x>] is (None, <some other x>)
40+ ... xs = [x for x in xs if x is not None]
41+ ... if len(xs) <= 2:
42+ ... return (x_right - x_left) / scale[0]
43+ ...
44+ ... y_scale = scale[1] or 1
45+ ... ys_scaled = [data[x] / y_scale for x in xs]
46+ ... xs_scaled = [x / scale[0] for x in xs]
47+ ... N = len(xs) - 2
48+ ... pts = [(x, y) for x, y in zip(xs_scaled, ys_scaled)]
49+ ... return sum(volume(pts[i:i+3]) for i in range(N)) / N
50+
51+ Or you may define a loss that favours the (local) minima of a function.
52+
53+ >>> @uses_nth_neighbors(1)
54+ ... def local_minima_resolving_loss(interval, scale, data, neighbors):
55+ ... x_left, x_right = interval
56+ ... n_left = neighbors[x_left][0]
57+ ... n_right = neighbors[x_right][1]
58+ ... loss = (x_right - x_left) / scale[0]
59+ ...
60+ ... if not ((n_left is not None and data[x_left] > data[n_left])
61+ ... or (n_right is not None and data[x_right] > data[n_right])):
62+ ... return loss * 100
63+ ...
64+ ... return loss
65+ """
66+ def _wrapped (loss_per_interval ):
67+ loss_per_interval .nth_neighbors = n
68+ return loss_per_interval
69+ return _wrapped
70+
71+
72+ @uses_nth_neighbors (0 )
1873def uniform_loss (interval , scale , data , neighbors ):
1974 """Loss function that samples the domain uniformly.
2075
@@ -36,6 +91,7 @@ def uniform_loss(interval, scale, data, neighbors):
3691 return dx
3792
3893
94+ @uses_nth_neighbors (0 )
3995def default_loss (interval , scale , data , neighbors ):
4096 """Calculate loss on a single interval.
4197
@@ -70,6 +126,7 @@ def _loss_of_multi_interval(xs, ys):
70126 return sum (vol (pts [i :i + 3 ]) for i in range (N )) / N
71127
72128
129+ @uses_nth_neighbors (1 )
73130def triangle_loss (interval , scale , data , neighbors ):
74131 x_left , x_right = interval
75132 xs = [neighbors [x_left ][0 ], x_left , x_right , neighbors [x_right ][1 ]]
@@ -85,6 +142,7 @@ def triangle_loss(interval, scale, data, neighbors):
85142
86143
87144def get_curvature_loss (area_factor = 1 , euclid_factor = 0.02 , horizontal_factor = 0.02 ):
145+ @uses_nth_neighbors (1 )
88146 def curvature_loss (interval , scale , data , neighbors ):
89147 triangle_loss_ = triangle_loss (interval , scale , data , neighbors )
90148 default_loss_ = default_loss (interval , scale , data , neighbors )
@@ -118,8 +176,8 @@ def _get_neighbors_from_list(xs):
118176 return sortedcontainers .SortedDict (neighbors )
119177
120178
121- def _get_intervals (x , neighbors , nn_neighbors ):
122- nn = nn_neighbors
179+ def _get_intervals (x , neighbors , nth_neighbors ):
180+ nn = nth_neighbors
123181 i = neighbors .index (x )
124182 start = max (0 , i - nn - 1 )
125183 end = min (len (neighbors ), i + nn + 2 )
@@ -141,10 +199,6 @@ class Learner1D(BaseLearner):
141199 A function that returns the loss for a single interval of the domain.
142200 If not provided, then a default is used, which uses the scaled distance
143201 in the x-y plane as the loss. See the notes for more details.
144- nn_neighbors : int, default: 0
145- The number of neighboring intervals that the loss function
146- takes into account. If ``loss_per_interval`` doesn't use the neighbors
147- at all, then it should be 0.
148202
149203 Attributes
150204 ----------
@@ -170,16 +224,25 @@ class Learner1D(BaseLearner):
170224 A map containing points as keys to its neighbors as a tuple.
171225 At the left ``x_left`` and right ``x_left`` most boundary it has
172226 ``x_left: (None, float)`` and ``x_right: (float, None)``.
227+
228+ The `loss_per_interval` function should also have
229+ an attribute `nth_neighbors` that indicates how many of the neighboring
230+ intervals to `interval` are used. If `loss_per_interval` doesn't
231+ have such an attribute, it's assumed that is uses **no** neighboring
232+ intervals. Also see the `uses_nth_neighbors` decorator.
233+ **WARNING**: When modifying the `data` and `neighbors` datastructures
234+ the learner will behave in an undefined way.
173235 """
174236
175- def __init__ (self , function , bounds , loss_per_interval = None , nn_neighbors = 0 ):
237+ def __init__ (self , function , bounds , loss_per_interval = None ):
176238 self .function = function
177- self .nn_neighbors = nn_neighbors
178239
179- if nn_neighbors == 0 :
180- self .loss_per_interval = loss_per_interval or default_loss
240+ if hasattr ( loss_per_interval , 'nth_neighbors' ) :
241+ self .nth_neighbors = loss_per_interval . nth_neighbors
181242 else :
182- self .loss_per_interval = loss_per_interval or get_curvature_loss ()
243+ self .nth_neighbors = 0
244+
245+ self .loss_per_interval = loss_per_interval or default_loss
183246
184247 # A dict storing the loss function for each interval x_n.
185248 self .losses = {}
@@ -278,10 +341,10 @@ def _update_losses(self, x, real=True):
278341
279342 if real :
280343 # We need to update all interpolated losses in the interval
281- # (x_left, x), (x, x_right) and the nn_neighbors nearest
344+ # (x_left, x), (x, x_right) and the nth_neighbors nearest
282345 # neighboring intervals. Since the addition of the
283346 # point 'x' could change their loss.
284- for ival in _get_intervals (x , self .neighbors , self .nn_neighbors ):
347+ for ival in _get_intervals (x , self .neighbors , self .nth_neighbors ):
285348 self ._update_interpolated_loss_in_interval (* ival )
286349
287350 # Since 'x' is in between (x_left, x_right),
0 commit comments