55from collections import OrderedDict
66from copy import copy
77from math import sqrt
8+ from typing import Callable , Iterable
89
910import cloudpickle
1011import numpy as np
1112from scipy import interpolate
13+ from scipy .interpolate .interpnd import LinearNDInterpolator
1214
1315from adaptive .learner .base_learner import BaseLearner
1416from adaptive .learner .triangulation import simplex_volume_in_embedding
1517from adaptive .notebook_integration import ensure_holoviews
18+ from adaptive .types import Bool , Float , Real
1619from adaptive .utils import (
1720 assign_defaults ,
1821 cache_latest ,
3033# Learner2D and helper functions.
3134
3235
33- def deviations (ip ) :
36+ def deviations (ip : LinearNDInterpolator ) -> list [ np . ndarray ] :
3437 """Returns the deviation of the linear estimate.
3538
3639 Is useful when defining custom loss functions.
@@ -68,7 +71,7 @@ def deviation(p, v, g):
6871 return devs
6972
7073
71- def areas (ip ) :
74+ def areas (ip : LinearNDInterpolator ) -> np . ndarray :
7275 """Returns the area per triangle of the triangulation inside
7376 a `LinearNDInterpolator` instance.
7477
@@ -89,7 +92,7 @@ def areas(ip):
8992 return areas
9093
9194
92- def uniform_loss (ip ) :
95+ def uniform_loss (ip : LinearNDInterpolator ) -> np . ndarray :
9396 """Loss function that samples the domain uniformly.
9497
9598 Works with `~adaptive.Learner2D` only.
@@ -120,7 +123,9 @@ def uniform_loss(ip):
120123 return np .sqrt (areas (ip ))
121124
122125
123- def resolution_loss_function (min_distance = 0 , max_distance = 1 ):
126+ def resolution_loss_function (
127+ min_distance : float = 0 , max_distance : float = 1
128+ ) -> Callable [[LinearNDInterpolator ], np .ndarray ]:
124129 """Loss function that is similar to the `default_loss` function, but you
125130 can set the maximimum and minimum size of a triangle.
126131
@@ -159,7 +164,7 @@ def resolution_loss(ip):
159164 return resolution_loss
160165
161166
162- def minimize_triangle_surface_loss (ip ) :
167+ def minimize_triangle_surface_loss (ip : LinearNDInterpolator ) -> np . ndarray :
163168 """Loss function that is similar to the distance loss function in the
164169 `~adaptive.Learner1D`. The loss is the area spanned by the 3D
165170 vectors of the vertices.
@@ -205,7 +210,7 @@ def _get_vectors(points):
205210 return np .linalg .norm (np .cross (a , b ) / 2 , axis = 1 )
206211
207212
208- def default_loss (ip ) :
213+ def default_loss (ip : LinearNDInterpolator ) -> np . ndarray :
209214 """Loss function that combines `deviations` and `areas` of the triangles.
210215
211216 Works with `~adaptive.Learner2D` only.
@@ -225,7 +230,7 @@ def default_loss(ip):
225230 return losses
226231
227232
228- def choose_point_in_triangle (triangle , max_badness ) :
233+ def choose_point_in_triangle (triangle : np . ndarray , max_badness : int ) -> np . ndarray :
229234 """Choose a new point in inside a triangle.
230235
231236 If the ratio of the longest edge of the triangle squared
@@ -364,7 +369,12 @@ class Learner2D(BaseLearner):
364369 over each triangle.
365370 """
366371
367- def __init__ (self , function , bounds , loss_per_triangle = None ):
372+ def __init__ (
373+ self ,
374+ function : Callable ,
375+ bounds : tuple [tuple [Real , Real ], tuple [Real , Real ]],
376+ loss_per_triangle : Callable | None = None ,
377+ ) -> None :
368378 self .ndim = len (bounds )
369379 self ._vdim = None
370380 self .loss_per_triangle = loss_per_triangle or default_loss
@@ -379,7 +389,7 @@ def __init__(self, function, bounds, loss_per_triangle=None):
379389
380390 self ._bounds_points = list (itertools .product (* bounds ))
381391 self ._stack .update ({p : np .inf for p in self ._bounds_points })
382- self .function = function
392+ self .function = function # type: ignore
383393 self ._ip = self ._ip_combined = None
384394
385395 self .stack_size = 10
@@ -388,7 +398,7 @@ def new(self) -> Learner2D:
388398 return Learner2D (self .function , self .bounds , self .loss_per_triangle )
389399
390400 @property
391- def xy_scale (self ):
401+ def xy_scale (self ) -> np . ndarray :
392402 xy_scale = self ._xy_scale
393403 if self .aspect_ratio == 1 :
394404 return xy_scale
@@ -486,21 +496,21 @@ def load_dataframe(
486496 self .function , df , function_prefix
487497 )
488498
489- def _scale (self , points ) :
499+ def _scale (self , points : list [ tuple [ float , float ]] | np . ndarray ) -> np . ndarray :
490500 points = np .asarray (points , dtype = float )
491501 return (points - self .xy_mean ) / self .xy_scale
492502
493- def _unscale (self , points ) :
503+ def _unscale (self , points : np . ndarray ) -> np . ndarray :
494504 points = np .asarray (points , dtype = float )
495505 return points * self .xy_scale + self .xy_mean
496506
497507 @property
498- def npoints (self ):
508+ def npoints (self ) -> int :
499509 """Number of evaluated points."""
500510 return len (self .data )
501511
502512 @property
503- def vdim (self ):
513+ def vdim (self ) -> int :
504514 """Length of the output of ``learner.function``.
505515 If the output is unsized (when it's a scalar)
506516 then `vdim = 1`.
@@ -516,12 +526,14 @@ def vdim(self):
516526 return self ._vdim or 1
517527
518528 @property
519- def bounds_are_done (self ):
529+ def bounds_are_done (self ) -> bool :
520530 return not any (
521531 (p in self .pending_points or p in self ._stack ) for p in self ._bounds_points
522532 )
523533
524- def interpolated_on_grid (self , n = None ):
534+ def interpolated_on_grid (
535+ self , n : int = None
536+ ) -> tuple [np .ndarray , np .ndarray , np .ndarray ]:
525537 """Get the interpolated data on a grid.
526538
527539 Parameters
@@ -553,7 +565,7 @@ def interpolated_on_grid(self, n=None):
553565 xs , ys = self ._unscale (np .vstack ([xs , ys ]).T ).T
554566 return xs , ys , zs
555567
556- def _data_in_bounds (self ):
568+ def _data_in_bounds (self ) -> tuple [ np . ndarray , np . ndarray ] :
557569 if self .data :
558570 points = np .array (list (self .data .keys ()))
559571 values = np .array (list (self .data .values ()), dtype = float )
@@ -562,7 +574,7 @@ def _data_in_bounds(self):
562574 return points [inds ], values [inds ].reshape (- 1 , self .vdim )
563575 return np .zeros ((0 , 2 )), np .zeros ((0 , self .vdim ), dtype = float )
564576
565- def _data_interp (self ):
577+ def _data_interp (self ) -> tuple [ np . ndarray | list [ tuple [ float , float ]], np . ndarray ] :
566578 if self .pending_points :
567579 points = list (self .pending_points )
568580 if self .bounds_are_done :
@@ -575,7 +587,7 @@ def _data_interp(self):
575587 return points , values
576588 return np .zeros ((0 , 2 )), np .zeros ((0 , self .vdim ), dtype = float )
577589
578- def _data_combined (self ):
590+ def _data_combined (self ) -> tuple [ np . ndarray , np . ndarray ] :
579591 points , values = self ._data_in_bounds ()
580592 if not self .pending_points :
581593 return points , values
@@ -584,7 +596,7 @@ def _data_combined(self):
584596 values_combined = np .vstack ([values , values_interp ])
585597 return points_combined , values_combined
586598
587- def ip (self ):
599+ def ip (self ) -> LinearNDInterpolator :
588600 """Deprecated, use `self.interpolator(scaled=True)`"""
589601 warnings .warn (
590602 "`learner.ip()` is deprecated, use `learner.interpolator(scaled=True)`."
@@ -593,7 +605,7 @@ def ip(self):
593605 )
594606 return self .interpolator (scaled = True )
595607
596- def interpolator (self , * , scaled = False ):
608+ def interpolator (self , * , scaled : bool = False ) -> LinearNDInterpolator :
597609 """A `scipy.interpolate.LinearNDInterpolator` instance
598610 containing the learner's data.
599611
@@ -624,7 +636,7 @@ def interpolator(self, *, scaled=False):
624636 points , values = self ._data_in_bounds ()
625637 return interpolate .LinearNDInterpolator (points , values )
626638
627- def _interpolator_combined (self ):
639+ def _interpolator_combined (self ) -> LinearNDInterpolator :
628640 """A `scipy.interpolate.LinearNDInterpolator` instance
629641 containing the learner's data *and* interpolated data of
630642 the `pending_points`."""
@@ -634,12 +646,12 @@ def _interpolator_combined(self):
634646 self ._ip_combined = interpolate .LinearNDInterpolator (points , values )
635647 return self ._ip_combined
636648
637- def inside_bounds (self , xy ) :
649+ def inside_bounds (self , xy : tuple [ float , float ]) -> Bool :
638650 x , y = xy
639651 (xmin , xmax ), (ymin , ymax ) = self .bounds
640652 return xmin <= x <= xmax and ymin <= y <= ymax
641653
642- def tell (self , point , value ) :
654+ def tell (self , point : tuple [ float , float ], value : float | Iterable [ float ]) -> None :
643655 point = tuple (point )
644656 self .data [point ] = value
645657 if not self .inside_bounds (point ):
@@ -648,15 +660,17 @@ def tell(self, point, value):
648660 self ._ip = None
649661 self ._stack .pop (point , None )
650662
651- def tell_pending (self , point ) :
663+ def tell_pending (self , point : tuple [ float , float ]) -> None :
652664 point = tuple (point )
653665 if not self .inside_bounds (point ):
654666 return
655667 self .pending_points .add (point )
656668 self ._ip_combined = None
657669 self ._stack .pop (point , None )
658670
659- def _fill_stack (self , stack_till = 1 ):
671+ def _fill_stack (
672+ self , stack_till : int = 1
673+ ) -> tuple [list [tuple [float , float ]], list [float ]]:
660674 if len (self .data ) + len (self .pending_points ) < self .ndim + 1 :
661675 raise ValueError ("too few points..." )
662676
@@ -695,7 +709,9 @@ def _fill_stack(self, stack_till=1):
695709
696710 return points_new , losses_new
697711
698- def ask (self , n , tell_pending = True ):
712+ def ask (
713+ self , n : int , tell_pending : bool = True
714+ ) -> tuple [list [tuple [float , float ] | np .ndarray ], list [float ]]:
699715 # Even if tell_pending is False we add the point such that _fill_stack
700716 # will return new points, later we remove these points if needed.
701717 points = list (self ._stack .keys ())
@@ -726,14 +742,14 @@ def ask(self, n, tell_pending=True):
726742 return points [:n ], loss_improvements [:n ]
727743
728744 @cache_latest
729- def loss (self , real = True ):
745+ def loss (self , real : bool = True ) -> float :
730746 if not self .bounds_are_done :
731747 return np .inf
732748 ip = self .interpolator (scaled = True ) if real else self ._interpolator_combined ()
733749 losses = self .loss_per_triangle (ip )
734750 return losses .max ()
735751
736- def remove_unfinished (self ):
752+ def remove_unfinished (self ) -> None :
737753 self .pending_points = set ()
738754 for p in self ._bounds_points :
739755 if p not in self .data :
@@ -807,10 +823,10 @@ def plot(self, n=None, tri_alpha=0):
807823
808824 return im .opts (style = im_opts ) * tris .opts (style = tri_opts , ** no_hover )
809825
810- def _get_data (self ):
826+ def _get_data (self ) -> dict [ tuple [ float , float ], Float | np . ndarray ] :
811827 return self .data
812828
813- def _set_data (self , data ) :
829+ def _set_data (self , data : dict [ tuple [ float , float ], Float | np . ndarray ]) -> None :
814830 self .data = data
815831 # Remove points from stack if they already exist
816832 for point in copy (self ._stack ):
0 commit comments