2020
2121@uses_nth_neighbors (0 )
2222def uniform_loss (
23- xs : Union [Tuple [float , float ], Tuple [np . float64 , np . float64 ]],
24- ys : Union [Tuple [float , float ], Tuple [np . float64 , np . float64 ]],
25- ) -> Union [np . float64 , float ]:
23+ xs : Union [Tuple [float , float ], Tuple [float , float ]],
24+ ys : Union [Tuple [float , float ], Tuple [float , float ]],
25+ ) -> Union [float , float ]:
2626 """Loss function that samples the domain uniformly.
2727
2828 Works with `~adaptive.Learner1D` only.
@@ -43,18 +43,9 @@ def uniform_loss(
4343
4444@uses_nth_neighbors (0 )
4545def default_loss (
46- xs : Union [
47- Tuple [float , float ],
48- Tuple [np .float64 , float ],
49- Tuple [np .float64 , np .float64 ],
50- Tuple [float , np .float64 ],
51- ],
52- ys : Union [
53- Tuple [float , float ],
54- Tuple [np .ndarray , np .ndarray ],
55- Tuple [np .float64 , np .float64 ],
56- ],
57- ) -> np .float64 :
46+ xs : Tuple [float , float ],
47+ ys : Union [Tuple [np .ndarray , np .ndarray ], Tuple [float , float ]],
48+ ) -> float :
5849 """Calculate loss on a single interval.
5950
6051 Currently returns the rescaled length of the interval. If one of the
@@ -71,7 +62,7 @@ def default_loss(
7162
7263
7364@uses_nth_neighbors (1 )
74- def triangle_loss (xs : Any , ys : Any ) -> Union [np . float64 , float ]:
65+ def triangle_loss (xs : Any , ys : Any ) -> Union [float , float ]:
7566 xs = [x for x in xs if x is not None ]
7667 ys = [y for y in ys if y is not None ]
7768
@@ -110,10 +101,8 @@ def curvature_loss(xs, ys):
110101
111102
112103def linspace (
113- x_left : Union [int , np .float64 , float ],
114- x_right : Union [int , np .float64 , float ],
115- n : int ,
116- ) -> Union [List [float ], List [np .float64 ]]:
104+ x_left : Union [int , float , float ], x_right : Union [int , float , float ], n : int ,
105+ ) -> Union [List [float ], List [float ]]:
117106 """This is equivalent to
118107 'np.linspace(x_left, x_right, n, endpoint=False)[1:]',
119108 but it is 15-30 times faster for small 'n'."""
@@ -136,7 +125,7 @@ def _get_neighbors_from_list(xs: np.ndarray) -> SortedDict:
136125
137126
138127def _get_intervals (
139- x : Union [int , np . float64 , float ], neighbors : SortedDict , nth_neighbors : int
128+ x : Union [int , float , float ], neighbors : SortedDict , nth_neighbors : int
140129) -> Any :
141130 nn = nth_neighbors
142131 i = neighbors .index (x )
@@ -262,38 +251,36 @@ def npoints(self) -> int:
262251 return len (self .data )
263252
264253 @cache_latest
265- def loss (self , real : bool = True ) -> Union [int , np . float64 , float ]:
254+ def loss (self , real : bool = True ) -> Union [int , float , float ]:
266255 losses = self .losses if real else self .losses_combined
267256 if not losses :
268257 return np .inf
269258 max_interval , max_loss = losses .peekitem (0 )
270259 return max_loss
271260
272261 def _scale_x (
273- self , x : Optional [Union [float , int , np . float64 ]]
274- ) -> Optional [Union [float , np . float64 ]]:
262+ self , x : Optional [Union [float , int , float ]]
263+ ) -> Optional [Union [float , float ]]:
275264 if x is None :
276265 return None
277266 return x / self ._scale [0 ]
278267
279268 def _scale_y (
280- self , y : Optional [Union [int , np .ndarray , np . float64 , float ]]
281- ) -> Optional [Union [float , np . float64 , np .ndarray ]]:
269+ self , y : Optional [Union [int , np .ndarray , float , float ]]
270+ ) -> Optional [Union [float , float , np .ndarray ]]:
282271 if y is None :
283272 return None
284273 y_scale = self ._scale [1 ] or 1
285274 return y / y_scale
286275
287- def _get_point_by_index (self , ind : int ) -> Optional [Union [int , np . float64 , float ]]:
276+ def _get_point_by_index (self , ind : int ) -> Optional [Union [int , float , float ]]:
288277 if ind < 0 or ind >= len (self .neighbors ):
289278 return None
290279 return self .neighbors .keys ()[ind ]
291280
292281 def _get_loss_in_interval (
293- self ,
294- x_left : Union [int , np .float64 , float ],
295- x_right : Union [int , np .float64 , float ],
296- ) -> Union [int , np .float64 , float ]:
282+ self , x_left : Union [int , float , float ], x_right : Union [int , float , float ],
283+ ) -> Union [int , float , float ]:
297284 assert x_left is not None and x_right is not None
298285
299286 if x_right - x_left < self ._dx_eps :
@@ -314,9 +301,7 @@ def _get_loss_in_interval(
314301 return self .loss_per_interval (xs_scaled , ys_scaled )
315302
316303 def _update_interpolated_loss_in_interval (
317- self ,
318- x_left : Union [int , np .float64 , float ],
319- x_right : Union [int , np .float64 , float ],
304+ self , x_left : Union [int , float , float ], x_right : Union [int , float , float ],
320305 ) -> None :
321306 if x_left is None or x_right is None :
322307 return
@@ -333,9 +318,7 @@ def _update_interpolated_loss_in_interval(
333318 self .losses_combined [a , b ] = (b - a ) * loss / dx
334319 a = b
335320
336- def _update_losses (
337- self , x : Union [int , np .float64 , float ], real : bool = True
338- ) -> None :
321+ def _update_losses (self , x : Union [int , float , float ], real : bool = True ) -> None :
339322 """Update all losses that depend on x"""
340323 # When we add a new point x, we should update the losses
341324 # (x_left, x_right) are the "real" neighbors of 'x'.
@@ -378,7 +361,7 @@ def _update_losses(
378361 self .losses_combined [x , b ] = float ("inf" )
379362
380363 @staticmethod
381- def _find_neighbors (x : Union [int , np . float64 , float ], neighbors : SortedDict ) -> Any :
364+ def _find_neighbors (x : Union [int , float , float ], neighbors : SortedDict ) -> Any :
382365 if x in neighbors :
383366 return neighbors [x ]
384367 pos = neighbors .bisect_left (x )
@@ -388,7 +371,7 @@ def _find_neighbors(x: Union[int, np.float64, float], neighbors: SortedDict) ->
388371 return x_left , x_right
389372
390373 def _update_neighbors (
391- self , x : Union [int , np . float64 , float ], neighbors : SortedDict
374+ self , x : Union [int , float , float ], neighbors : SortedDict
392375 ) -> None :
393376 if x not in neighbors : # The point is new
394377 x_left , x_right = self ._find_neighbors (x , neighbors )
@@ -397,9 +380,7 @@ def _update_neighbors(
397380 neighbors .get (x_right , [None , None ])[0 ] = x
398381
399382 def _update_scale (
400- self ,
401- x : Union [int , np .float64 , float ],
402- y : Union [float , int , np .float64 , np .ndarray ],
383+ self , x : Union [int , float , float ], y : Union [float , int , float , np .ndarray ],
403384 ) -> None :
404385 """Update the scale with which the x and y-values are scaled.
405386
@@ -427,7 +408,7 @@ def _update_scale(
427408 self ._bbox [1 ][1 ] = max (self ._bbox [1 ][1 ], y )
428409 self ._scale [1 ] = self ._bbox [1 ][1 ] - self ._bbox [1 ][0 ]
429410
430- def tell (self , x : Union [int , np . float64 , float ], y : Any ) -> None :
411+ def tell (self , x : Union [int , float , float ], y : Any ) -> None :
431412 if x in self .data :
432413 # The point is already evaluated before
433414 return
@@ -462,7 +443,7 @@ def tell(self, x: Union[int, np.float64, float], y: Any) -> None:
462443
463444 self ._oldscale = deepcopy (self ._scale )
464445
465- def tell_pending (self , x : Union [int , np . float64 , float ]) -> None :
446+ def tell_pending (self , x : Union [int , float , float ]) -> None :
466447 if x in self .data :
467448 # The point is already evaluated before
468449 return
@@ -678,7 +659,7 @@ def _set_data(self, data: Dict[Union[int, float], float]) -> None:
678659 self .tell_many (* zip (* data .items ()))
679660
680661
681- def loss_manager (x_scale : Union [int , np . float64 , float ]) -> ItemSortedDict :
662+ def loss_manager (x_scale : Union [int , float , float ]) -> ItemSortedDict :
682663 def sort_key (ival , loss ):
683664 loss , ival = finite_loss (ival , loss , x_scale )
684665 return - loss , ival
@@ -688,9 +669,7 @@ def sort_key(ival, loss):
688669
689670
690671def finite_loss (
691- ival : Any ,
692- loss : Union [int , np .float64 , float ],
693- x_scale : Union [int , np .float64 , float ],
672+ ival : Any , loss : Union [int , float , float ], x_scale : Union [int , float , float ],
694673) -> Any :
695674 """Get the socalled finite_loss of an interval in order to be able to
696675 sort intervals that have infinite loss."""
0 commit comments