22import math
33from collections .abc import Iterable
44from copy import deepcopy
5- from functools import partial
65from typing import Any , Callable , Dict , List , Optional , Tuple , Union
76
87import numpy as np
9- import sortedcollections
10- import sortedcontainers
118from sortedcollections .recipes import ItemSortedDict
129from sortedcontainers .sorteddict import SortedDict
1310
1916
2017
2118@uses_nth_neighbors (0 )
22- def uniform_loss (xs : Tuple [float , float ], ys : Tuple [float , float ], ) -> float :
19+ def uniform_loss (xs : Tuple [float , float ], ys : Tuple [float , float ]) -> float :
2320 """Loss function that samples the domain uniformly.
2421
2522 Works with `~adaptive.Learner1D` only.
@@ -59,7 +56,7 @@ def default_loss(
5956
6057
6158@uses_nth_neighbors (1 )
62- def triangle_loss (xs : Any , ys : Any ) -> float :
59+ def triangle_loss (xs : Tuple [ float ] , ys : Tuple [ Union [ float , np . ndarray ]] ) -> float :
6360 xs = [x for x in xs if x is not None ]
6461 ys = [y for y in ys if y is not None ]
6562
@@ -77,7 +74,7 @@ def triangle_loss(xs: Any, ys: Any) -> float:
7774
7875
7976def curvature_loss_function (
80- area_factor : int = 1 , euclid_factor : float = 0.02 , horizontal_factor : float = 0.02
77+ area_factor : float = 1 , euclid_factor : float = 0.02 , horizontal_factor : float = 0.02
8178) -> Callable :
8279 # XXX: add a doc-string
8380 @uses_nth_neighbors (1 )
@@ -97,9 +94,7 @@ def curvature_loss(xs, ys):
9794 return curvature_loss
9895
9996
100- def linspace (
101- x_left : Union [int , float ], x_right : Union [int , float ], n : int ,
102- ) -> Union [List [float ], List [float ]]:
97+ def linspace (x_left : float , x_right : float , n : int ,) -> List [float ]:
10398 """This is equivalent to
10499 'np.linspace(x_left, x_right, n, endpoint=False)[1:]',
105100 but it is 15-30 times faster for small 'n'."""
@@ -118,12 +113,10 @@ def _get_neighbors_from_list(xs: np.ndarray) -> SortedDict:
118113 xs_left [0 ] = None
119114 xs_right [- 1 ] = None
120115 neighbors = {x : [x_L , x_R ] for x , x_L , x_R in zip (xs , xs_left , xs_right )}
121- return sortedcontainers . SortedDict (neighbors )
116+ return SortedDict (neighbors )
122117
123118
124- def _get_intervals (
125- x : Union [int , float ], neighbors : SortedDict , nth_neighbors : int
126- ) -> Any :
119+ def _get_intervals (x : float , neighbors : SortedDict , nth_neighbors : int ) -> Any :
127120 nn = nth_neighbors
128121 i = neighbors .index (x )
129122 start = max (0 , i - nn - 1 )
@@ -178,8 +171,8 @@ class Learner1D(BaseLearner):
178171
179172 def __init__ (
180173 self ,
181- function : Union [ Callable , partial ] ,
182- bounds : Union [ Tuple [int , int ], Tuple [ float , float ], np . ndarray ],
174+ function : Callable ,
175+ bounds : Tuple [float , float ],
183176 loss_per_interval : Optional [Callable ] = None ,
184177 ) -> None :
185178 self .function = function
@@ -201,8 +194,8 @@ def __init__(
201194
202195 # A dict {x_n: [x_{n-1}, x_{n+1}]} for quick checking of local
203196 # properties.
204- self .neighbors = sortedcontainers . SortedDict ()
205- self .neighbors_combined = sortedcontainers . SortedDict ()
197+ self .neighbors = SortedDict ()
198+ self .neighbors_combined = SortedDict ()
206199
207200 # Bounding box [[minx, maxx], [miny, maxy]].
208201 self ._bbox = [list (bounds ), [np .inf , - np .inf ]]
@@ -248,34 +241,32 @@ def npoints(self) -> int:
248241 return len (self .data )
249242
250243 @cache_latest
251- def loss (self , real : bool = True ) -> Union [ int , float ] :
244+ def loss (self , real : bool = True ) -> float :
252245 losses = self .losses if real else self .losses_combined
253246 if not losses :
254247 return np .inf
255248 max_interval , max_loss = losses .peekitem (0 )
256249 return max_loss
257250
258- def _scale_x (self , x : Optional [Union [ float , int ] ]) -> Optional [float ]:
251+ def _scale_x (self , x : Optional [float ]) -> Optional [float ]:
259252 if x is None :
260253 return None
261254 return x / self ._scale [0 ]
262255
263256 def _scale_y (
264- self , y : Optional [Union [int , np .ndarray , float , float ]]
257+ self , y : Optional [Union [float , np .ndarray ]]
265258 ) -> Optional [Union [float , np .ndarray ]]:
266259 if y is None :
267260 return None
268261 y_scale = self ._scale [1 ] or 1
269262 return y / y_scale
270263
271- def _get_point_by_index (self , ind : int ) -> Optional [Union [ int , float , float ] ]:
264+ def _get_point_by_index (self , ind : int ) -> Optional [float ]:
272265 if ind < 0 or ind >= len (self .neighbors ):
273266 return None
274267 return self .neighbors .keys ()[ind ]
275268
276- def _get_loss_in_interval (
277- self , x_left : Union [int , float ], x_right : Union [int , float ],
278- ) -> Union [int , float ]:
269+ def _get_loss_in_interval (self , x_left : float , x_right : float ,) -> float :
279270 assert x_left is not None and x_right is not None
280271
281272 if x_right - x_left < self ._dx_eps :
@@ -296,7 +287,7 @@ def _get_loss_in_interval(
296287 return self .loss_per_interval (xs_scaled , ys_scaled )
297288
298289 def _update_interpolated_loss_in_interval (
299- self , x_left : Union [ int , float ] , x_right : Union [ int , float ] ,
290+ self , x_left : float , x_right : float ,
300291 ) -> None :
301292 if x_left is None or x_right is None :
302293 return
@@ -313,7 +304,7 @@ def _update_interpolated_loss_in_interval(
313304 self .losses_combined [a , b ] = (b - a ) * loss / dx
314305 a = b
315306
316- def _update_losses (self , x : Union [ int , float ] , real : bool = True ) -> None :
307+ def _update_losses (self , x : float , real : bool = True ) -> None :
317308 """Update all losses that depend on x"""
318309 # When we add a new point x, we should update the losses
319310 # (x_left, x_right) are the "real" neighbors of 'x'.
@@ -356,7 +347,7 @@ def _update_losses(self, x: Union[int, float], real: bool = True) -> None:
356347 self .losses_combined [x , b ] = float ("inf" )
357348
358349 @staticmethod
359- def _find_neighbors (x : Union [ int , float ] , neighbors : SortedDict ) -> Any :
350+ def _find_neighbors (x : float , neighbors : SortedDict ) -> Any :
360351 if x in neighbors :
361352 return neighbors [x ]
362353 pos = neighbors .bisect_left (x )
@@ -365,16 +356,14 @@ def _find_neighbors(x: Union[int, float], neighbors: SortedDict) -> Any:
365356 x_right = keys [pos ] if pos != len (neighbors ) else None
366357 return x_left , x_right
367358
368- def _update_neighbors (self , x : Union [ int , float ] , neighbors : SortedDict ) -> None :
359+ def _update_neighbors (self , x : float , neighbors : SortedDict ) -> None :
369360 if x not in neighbors : # The point is new
370361 x_left , x_right = self ._find_neighbors (x , neighbors )
371362 neighbors [x ] = [x_left , x_right ]
372363 neighbors .get (x_left , [None , None ])[1 ] = x
373364 neighbors .get (x_right , [None , None ])[0 ] = x
374365
375- def _update_scale (
376- self , x : Union [int , float ], y : Union [float , int , float , np .ndarray ],
377- ) -> None :
366+ def _update_scale (self , x : float , y : Union [float , np .ndarray ]) -> None :
378367 """Update the scale with which the x and y-values are scaled.
379368
380369 For a learner where the function returns a single scalar the scale
@@ -401,7 +390,7 @@ def _update_scale(
401390 self ._bbox [1 ][1 ] = max (self ._bbox [1 ][1 ], y )
402391 self ._scale [1 ] = self ._bbox [1 ][1 ] - self ._bbox [1 ][0 ]
403392
404- def tell (self , x : Union [ int , float ] , y : Any ) -> None :
393+ def tell (self , x : float , y : Union [ float , np . ndarray ] ) -> None :
405394 if x in self .data :
406395 # The point is already evaluated before
407396 return
@@ -436,15 +425,15 @@ def tell(self, x: Union[int, float], y: Any) -> None:
436425
437426 self ._oldscale = deepcopy (self ._scale )
438427
439- def tell_pending (self , x : Union [ int , float ] ) -> None :
428+ def tell_pending (self , x : float ) -> None :
440429 if x in self .data :
441430 # The point is already evaluated before
442431 return
443432 self .pending_points .add (x )
444433 self ._update_neighbors (x , self .neighbors_combined )
445434 self ._update_losses (x , real = False )
446435
447- def tell_many (self , xs : Any , ys : Any , * , force = False ) -> None :
436+ def tell_many (self , xs : List [ float ] , ys : List [ Any ] , * , force = False ) -> None :
448437 if not force and not (len (xs ) > 0.5 * len (self .data ) and len (xs ) > 2 ):
449438 # Only run this more efficient method if there are
450439 # at least 2 points and the amount of points added are
@@ -644,24 +633,24 @@ def remove_unfinished(self) -> None:
644633 self .losses_combined = deepcopy (self .losses )
645634 self .neighbors_combined = deepcopy (self .neighbors )
646635
647- def _get_data (self ) -> Dict [Union [ int , float ] , float ]:
636+ def _get_data (self ) -> Dict [float , float ]:
648637 return self .data
649638
650- def _set_data (self , data : Dict [Union [ int , float ] , float ]) -> None :
639+ def _set_data (self , data : Dict [float , float ]) -> None :
651640 if data :
652641 self .tell_many (* zip (* data .items ()))
653642
654643
655- def loss_manager (x_scale : Union [ int , float ] ) -> ItemSortedDict :
644+ def loss_manager (x_scale : float ) -> ItemSortedDict :
656645 def sort_key (ival , loss ):
657646 loss , ival = finite_loss (ival , loss , x_scale )
658647 return - loss , ival
659648
660- sorted_dict = sortedcollections . ItemSortedDict (sort_key )
649+ sorted_dict = ItemSortedDict (sort_key )
661650 return sorted_dict
662651
663652
664- def finite_loss (ival : Any , loss : Union [ int , float ] , x_scale : Union [ int , float ], ) -> Any :
653+ def finite_loss (ival : Any , loss : float , x_scale : float ) -> Any :
665654 """Get the socalled finite_loss of an interval in order to be able to
666655 sort intervals that have infinite loss."""
667656 # If the loss is infinite we return the
0 commit comments