Skip to content

Commit 3034acb

Browse files
committed
type hint fixes for adaptive/learner/learner1D.py
1 parent 9137aca commit 3034acb

File tree

1 file changed

+28
-39
lines changed

1 file changed

+28
-39
lines changed

adaptive/learner/learner1D.py

Lines changed: 28 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,9 @@
22
import math
33
from collections.abc import Iterable
44
from copy import deepcopy
5-
from functools import partial
65
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
76

87
import numpy as np
9-
import sortedcollections
10-
import sortedcontainers
118
from sortedcollections.recipes import ItemSortedDict
129
from sortedcontainers.sorteddict import SortedDict
1310

@@ -19,7 +16,7 @@
1916

2017

2118
@uses_nth_neighbors(0)
22-
def uniform_loss(xs: Tuple[float, float], ys: Tuple[float, float],) -> float:
19+
def uniform_loss(xs: Tuple[float, float], ys: Tuple[float, float]) -> float:
2320
"""Loss function that samples the domain uniformly.
2421
2522
Works with `~adaptive.Learner1D` only.
@@ -59,7 +56,7 @@ def default_loss(
5956

6057

6158
@uses_nth_neighbors(1)
62-
def triangle_loss(xs: Any, ys: Any) -> float:
59+
def triangle_loss(xs: Tuple[float], ys: Tuple[Union[float, np.ndarray]]) -> float:
6360
xs = [x for x in xs if x is not None]
6461
ys = [y for y in ys if y is not None]
6562

@@ -77,7 +74,7 @@ def triangle_loss(xs: Any, ys: Any) -> float:
7774

7875

7976
def curvature_loss_function(
80-
area_factor: int = 1, euclid_factor: float = 0.02, horizontal_factor: float = 0.02
77+
area_factor: float = 1, euclid_factor: float = 0.02, horizontal_factor: float = 0.02
8178
) -> Callable:
8279
# XXX: add a doc-string
8380
@uses_nth_neighbors(1)
@@ -97,9 +94,7 @@ def curvature_loss(xs, ys):
9794
return curvature_loss
9895

9996

100-
def linspace(
101-
x_left: Union[int, float], x_right: Union[int, float], n: int,
102-
) -> Union[List[float], List[float]]:
97+
def linspace(x_left: float, x_right: float, n: int,) -> List[float]:
10398
"""This is equivalent to
10499
'np.linspace(x_left, x_right, n, endpoint=False)[1:]',
105100
but it is 15-30 times faster for small 'n'."""
@@ -118,12 +113,10 @@ def _get_neighbors_from_list(xs: np.ndarray) -> SortedDict:
118113
xs_left[0] = None
119114
xs_right[-1] = None
120115
neighbors = {x: [x_L, x_R] for x, x_L, x_R in zip(xs, xs_left, xs_right)}
121-
return sortedcontainers.SortedDict(neighbors)
116+
return SortedDict(neighbors)
122117

123118

124-
def _get_intervals(
125-
x: Union[int, float], neighbors: SortedDict, nth_neighbors: int
126-
) -> Any:
119+
def _get_intervals(x: float, neighbors: SortedDict, nth_neighbors: int) -> Any:
127120
nn = nth_neighbors
128121
i = neighbors.index(x)
129122
start = max(0, i - nn - 1)
@@ -178,8 +171,8 @@ class Learner1D(BaseLearner):
178171

179172
def __init__(
180173
self,
181-
function: Union[Callable, partial],
182-
bounds: Union[Tuple[int, int], Tuple[float, float], np.ndarray],
174+
function: Callable,
175+
bounds: Tuple[float, float],
183176
loss_per_interval: Optional[Callable] = None,
184177
) -> None:
185178
self.function = function
@@ -201,8 +194,8 @@ def __init__(
201194

202195
# A dict {x_n: [x_{n-1}, x_{n+1}]} for quick checking of local
203196
# properties.
204-
self.neighbors = sortedcontainers.SortedDict()
205-
self.neighbors_combined = sortedcontainers.SortedDict()
197+
self.neighbors = SortedDict()
198+
self.neighbors_combined = SortedDict()
206199

207200
# Bounding box [[minx, maxx], [miny, maxy]].
208201
self._bbox = [list(bounds), [np.inf, -np.inf]]
@@ -248,34 +241,32 @@ def npoints(self) -> int:
248241
return len(self.data)
249242

250243
@cache_latest
251-
def loss(self, real: bool = True) -> Union[int, float]:
244+
def loss(self, real: bool = True) -> float:
252245
losses = self.losses if real else self.losses_combined
253246
if not losses:
254247
return np.inf
255248
max_interval, max_loss = losses.peekitem(0)
256249
return max_loss
257250

258-
def _scale_x(self, x: Optional[Union[float, int]]) -> Optional[float]:
251+
def _scale_x(self, x: Optional[float]) -> Optional[float]:
259252
if x is None:
260253
return None
261254
return x / self._scale[0]
262255

263256
def _scale_y(
264-
self, y: Optional[Union[int, np.ndarray, float, float]]
257+
self, y: Optional[Union[float, np.ndarray]]
265258
) -> Optional[Union[float, np.ndarray]]:
266259
if y is None:
267260
return None
268261
y_scale = self._scale[1] or 1
269262
return y / y_scale
270263

271-
def _get_point_by_index(self, ind: int) -> Optional[Union[int, float, float]]:
264+
def _get_point_by_index(self, ind: int) -> Optional[float]:
272265
if ind < 0 or ind >= len(self.neighbors):
273266
return None
274267
return self.neighbors.keys()[ind]
275268

276-
def _get_loss_in_interval(
277-
self, x_left: Union[int, float], x_right: Union[int, float],
278-
) -> Union[int, float]:
269+
def _get_loss_in_interval(self, x_left: float, x_right: float,) -> float:
279270
assert x_left is not None and x_right is not None
280271

281272
if x_right - x_left < self._dx_eps:
@@ -296,7 +287,7 @@ def _get_loss_in_interval(
296287
return self.loss_per_interval(xs_scaled, ys_scaled)
297288

298289
def _update_interpolated_loss_in_interval(
299-
self, x_left: Union[int, float], x_right: Union[int, float],
290+
self, x_left: float, x_right: float,
300291
) -> None:
301292
if x_left is None or x_right is None:
302293
return
@@ -313,7 +304,7 @@ def _update_interpolated_loss_in_interval(
313304
self.losses_combined[a, b] = (b - a) * loss / dx
314305
a = b
315306

316-
def _update_losses(self, x: Union[int, float], real: bool = True) -> None:
307+
def _update_losses(self, x: float, real: bool = True) -> None:
317308
"""Update all losses that depend on x"""
318309
# When we add a new point x, we should update the losses
319310
# (x_left, x_right) are the "real" neighbors of 'x'.
@@ -356,7 +347,7 @@ def _update_losses(self, x: Union[int, float], real: bool = True) -> None:
356347
self.losses_combined[x, b] = float("inf")
357348

358349
@staticmethod
359-
def _find_neighbors(x: Union[int, float], neighbors: SortedDict) -> Any:
350+
def _find_neighbors(x: float, neighbors: SortedDict) -> Any:
360351
if x in neighbors:
361352
return neighbors[x]
362353
pos = neighbors.bisect_left(x)
@@ -365,16 +356,14 @@ def _find_neighbors(x: Union[int, float], neighbors: SortedDict) -> Any:
365356
x_right = keys[pos] if pos != len(neighbors) else None
366357
return x_left, x_right
367358

368-
def _update_neighbors(self, x: Union[int, float], neighbors: SortedDict) -> None:
359+
def _update_neighbors(self, x: float, neighbors: SortedDict) -> None:
369360
if x not in neighbors: # The point is new
370361
x_left, x_right = self._find_neighbors(x, neighbors)
371362
neighbors[x] = [x_left, x_right]
372363
neighbors.get(x_left, [None, None])[1] = x
373364
neighbors.get(x_right, [None, None])[0] = x
374365

375-
def _update_scale(
376-
self, x: Union[int, float], y: Union[float, int, float, np.ndarray],
377-
) -> None:
366+
def _update_scale(self, x: float, y: Union[float, np.ndarray]) -> None:
378367
"""Update the scale with which the x and y-values are scaled.
379368
380369
For a learner where the function returns a single scalar the scale
@@ -401,7 +390,7 @@ def _update_scale(
401390
self._bbox[1][1] = max(self._bbox[1][1], y)
402391
self._scale[1] = self._bbox[1][1] - self._bbox[1][0]
403392

404-
def tell(self, x: Union[int, float], y: Any) -> None:
393+
def tell(self, x: float, y: Union[float, np.ndarray]) -> None:
405394
if x in self.data:
406395
# The point is already evaluated before
407396
return
@@ -436,15 +425,15 @@ def tell(self, x: Union[int, float], y: Any) -> None:
436425

437426
self._oldscale = deepcopy(self._scale)
438427

439-
def tell_pending(self, x: Union[int, float]) -> None:
428+
def tell_pending(self, x: float) -> None:
440429
if x in self.data:
441430
# The point is already evaluated before
442431
return
443432
self.pending_points.add(x)
444433
self._update_neighbors(x, self.neighbors_combined)
445434
self._update_losses(x, real=False)
446435

447-
def tell_many(self, xs: Any, ys: Any, *, force=False) -> None:
436+
def tell_many(self, xs: List[float], ys: List[Any], *, force=False) -> None:
448437
if not force and not (len(xs) > 0.5 * len(self.data) and len(xs) > 2):
449438
# Only run this more efficient method if there are
450439
# at least 2 points and the amount of points added are
@@ -644,24 +633,24 @@ def remove_unfinished(self) -> None:
644633
self.losses_combined = deepcopy(self.losses)
645634
self.neighbors_combined = deepcopy(self.neighbors)
646635

647-
def _get_data(self) -> Dict[Union[int, float], float]:
636+
def _get_data(self) -> Dict[float, float]:
648637
return self.data
649638

650-
def _set_data(self, data: Dict[Union[int, float], float]) -> None:
639+
def _set_data(self, data: Dict[float, float]) -> None:
651640
if data:
652641
self.tell_many(*zip(*data.items()))
653642

654643

655-
def loss_manager(x_scale: Union[int, float]) -> ItemSortedDict:
644+
def loss_manager(x_scale: float) -> ItemSortedDict:
656645
def sort_key(ival, loss):
657646
loss, ival = finite_loss(ival, loss, x_scale)
658647
return -loss, ival
659648

660-
sorted_dict = sortedcollections.ItemSortedDict(sort_key)
649+
sorted_dict = ItemSortedDict(sort_key)
661650
return sorted_dict
662651

663652

664-
def finite_loss(ival: Any, loss: Union[int, float], x_scale: Union[int, float],) -> Any:
653+
def finite_loss(ival: Any, loss: float, x_scale: float) -> Any:
665654
"""Get the socalled finite_loss of an interval in order to be able to
666655
sort intervals that have infinite loss."""
667656
# If the loss is infinite we return the

0 commit comments

Comments
 (0)