1919
2020
2121@uses_nth_neighbors (0 )
22- def uniform_loss (
23- xs : Union [Tuple [float , float ], Tuple [float , float ]],
24- ys : Union [Tuple [float , float ], Tuple [float , float ]],
25- ) -> Union [float , float ]:
22+ def uniform_loss (xs : Tuple [float , float ], ys : Tuple [float , float ],) -> float :
2623 """Loss function that samples the domain uniformly.
2724
2825 Works with `~adaptive.Learner1D` only.
@@ -62,7 +59,7 @@ def default_loss(
6259
6360
6461@uses_nth_neighbors (1 )
65- def triangle_loss (xs : Any , ys : Any ) -> Union [ float , float ] :
62+ def triangle_loss (xs : Any , ys : Any ) -> float :
6663 xs = [x for x in xs if x is not None ]
6764 ys = [y for y in ys if y is not None ]
6865
@@ -101,7 +98,7 @@ def curvature_loss(xs, ys):
10198
10299
103100def linspace (
104- x_left : Union [int , float , float ], x_right : Union [int , float , float ], n : int ,
101+ x_left : Union [int , float ], x_right : Union [int , float ], n : int ,
105102) -> Union [List [float ], List [float ]]:
106103 """This is equivalent to
107104 'np.linspace(x_left, x_right, n, endpoint=False)[1:]',
@@ -125,7 +122,7 @@ def _get_neighbors_from_list(xs: np.ndarray) -> SortedDict:
125122
126123
127124def _get_intervals (
128- x : Union [int , float , float ], neighbors : SortedDict , nth_neighbors : int
125+ x : Union [int , float ], neighbors : SortedDict , nth_neighbors : int
129126) -> Any :
130127 nn = nth_neighbors
131128 i = neighbors .index (x )
@@ -251,23 +248,21 @@ def npoints(self) -> int:
251248 return len (self .data )
252249
253250 @cache_latest
254- def loss (self , real : bool = True ) -> Union [int , float , float ]:
251+ def loss (self , real : bool = True ) -> Union [int , float ]:
255252 losses = self .losses if real else self .losses_combined
256253 if not losses :
257254 return np .inf
258255 max_interval , max_loss = losses .peekitem (0 )
259256 return max_loss
260257
261- def _scale_x (
262- self , x : Optional [Union [float , int , float ]]
263- ) -> Optional [Union [float , float ]]:
258+ def _scale_x (self , x : Optional [Union [float , int ]]) -> Optional [float ]:
264259 if x is None :
265260 return None
266261 return x / self ._scale [0 ]
267262
268263 def _scale_y (
269264 self , y : Optional [Union [int , np .ndarray , float , float ]]
270- ) -> Optional [Union [float , float , np .ndarray ]]:
265+ ) -> Optional [Union [float , np .ndarray ]]:
271266 if y is None :
272267 return None
273268 y_scale = self ._scale [1 ] or 1
@@ -279,8 +274,8 @@ def _get_point_by_index(self, ind: int) -> Optional[Union[int, float, float]]:
279274 return self .neighbors .keys ()[ind ]
280275
281276 def _get_loss_in_interval (
282- self , x_left : Union [int , float , float ], x_right : Union [int , float , float ],
283- ) -> Union [int , float , float ]:
277+ self , x_left : Union [int , float ], x_right : Union [int , float ],
278+ ) -> Union [int , float ]:
284279 assert x_left is not None and x_right is not None
285280
286281 if x_right - x_left < self ._dx_eps :
@@ -301,7 +296,7 @@ def _get_loss_in_interval(
301296 return self .loss_per_interval (xs_scaled , ys_scaled )
302297
303298 def _update_interpolated_loss_in_interval (
304- self , x_left : Union [int , float , float ], x_right : Union [int , float , float ],
299+ self , x_left : Union [int , float ], x_right : Union [int , float ],
305300 ) -> None :
306301 if x_left is None or x_right is None :
307302 return
@@ -318,7 +313,7 @@ def _update_interpolated_loss_in_interval(
318313 self .losses_combined [a , b ] = (b - a ) * loss / dx
319314 a = b
320315
321- def _update_losses (self , x : Union [int , float , float ], real : bool = True ) -> None :
316+ def _update_losses (self , x : Union [int , float ], real : bool = True ) -> None :
322317 """Update all losses that depend on x"""
323318 # When we add a new point x, we should update the losses
324319 # (x_left, x_right) are the "real" neighbors of 'x'.
@@ -361,7 +356,7 @@ def _update_losses(self, x: Union[int, float, float], real: bool = True) -> None
361356 self .losses_combined [x , b ] = float ("inf" )
362357
363358 @staticmethod
364- def _find_neighbors (x : Union [int , float , float ], neighbors : SortedDict ) -> Any :
359+ def _find_neighbors (x : Union [int , float ], neighbors : SortedDict ) -> Any :
365360 if x in neighbors :
366361 return neighbors [x ]
367362 pos = neighbors .bisect_left (x )
@@ -370,17 +365,15 @@ def _find_neighbors(x: Union[int, float, float], neighbors: SortedDict) -> Any:
370365 x_right = keys [pos ] if pos != len (neighbors ) else None
371366 return x_left , x_right
372367
373- def _update_neighbors (
374- self , x : Union [int , float , float ], neighbors : SortedDict
375- ) -> None :
368+ def _update_neighbors (self , x : Union [int , float ], neighbors : SortedDict ) -> None :
376369 if x not in neighbors : # The point is new
377370 x_left , x_right = self ._find_neighbors (x , neighbors )
378371 neighbors [x ] = [x_left , x_right ]
379372 neighbors .get (x_left , [None , None ])[1 ] = x
380373 neighbors .get (x_right , [None , None ])[0 ] = x
381374
382375 def _update_scale (
383- self , x : Union [int , float , float ], y : Union [float , int , float , np .ndarray ],
376+ self , x : Union [int , float ], y : Union [float , int , float , np .ndarray ],
384377 ) -> None :
385378 """Update the scale with which the x and y-values are scaled.
386379
@@ -408,7 +401,7 @@ def _update_scale(
408401 self ._bbox [1 ][1 ] = max (self ._bbox [1 ][1 ], y )
409402 self ._scale [1 ] = self ._bbox [1 ][1 ] - self ._bbox [1 ][0 ]
410403
411- def tell (self , x : Union [int , float , float ], y : Any ) -> None :
404+ def tell (self , x : Union [int , float ], y : Any ) -> None :
412405 if x in self .data :
413406 # The point is already evaluated before
414407 return
@@ -443,7 +436,7 @@ def tell(self, x: Union[int, float, float], y: Any) -> None:
443436
444437 self ._oldscale = deepcopy (self ._scale )
445438
446- def tell_pending (self , x : Union [int , float , float ]) -> None :
439+ def tell_pending (self , x : Union [int , float ]) -> None :
447440 if x in self .data :
448441 # The point is already evaluated before
449442 return
@@ -659,7 +652,7 @@ def _set_data(self, data: Dict[Union[int, float], float]) -> None:
659652 self .tell_many (* zip (* data .items ()))
660653
661654
662- def loss_manager (x_scale : Union [int , float , float ]) -> ItemSortedDict :
655+ def loss_manager (x_scale : Union [int , float ]) -> ItemSortedDict :
663656 def sort_key (ival , loss ):
664657 loss , ival = finite_loss (ival , loss , x_scale )
665658 return - loss , ival
@@ -668,9 +661,7 @@ def sort_key(ival, loss):
668661 return sorted_dict
669662
670663
671- def finite_loss (
672- ival : Any , loss : Union [int , float , float ], x_scale : Union [int , float , float ],
673- ) -> Any :
664+ def finite_loss (ival : Any , loss : Union [int , float ], x_scale : Union [int , float ],) -> Any :
674665 """Get the socalled finite_loss of an interval in order to be able to
675666 sort intervals that have infinite loss."""
676667 # If the loss is infinite we return the
0 commit comments