55from collections import defaultdict
66from copy import deepcopy
77from math import hypot
8- from numbers import Integral as Int
9- from numbers import Real
108from typing import Callable , DefaultDict , Iterable , List , Sequence , Tuple
119
1210import numpy as np
1614
1715from adaptive .learner .learner1D import Learner1D , _get_intervals
1816from adaptive .notebook_integration import ensure_holoviews
17+ from adaptive .types import Int , Real
1918from adaptive .utils import assign_defaults , partial_function_from_dataframe
2019
2120try :
@@ -99,7 +98,7 @@ def __init__(
9998 if min_samples > max_samples :
10099 raise ValueError ("max_samples should be larger than min_samples." )
101100
102- super ().__init__ (function , bounds , loss_per_interval )
101+ super ().__init__ (function , bounds , loss_per_interval ) # type: ignore[arg-type]
103102
104103 self .delta = delta
105104 self .alpha = alpha
@@ -110,7 +109,7 @@ def __init__(
110109
111110 # Contains all samples f(x) for each
112111 # point x in the form {x0: {0: f_0(x0), 1: f_1(x0), ...}, ...}
113- self ._data_samples = SortedDict ()
112+ self ._data_samples : SortedDict [ float , dict [ int , Real ]] = SortedDict ()
114113 # Contains the number of samples taken
115114 # at each point x in the form {x0: n0, x1: n1, ...}
116115 self ._number_samples = SortedDict ()
@@ -124,15 +123,14 @@ def __init__(
124123 # form {xi: ((xii-xi)^2 + (yii-yi)^2)^0.5, ...}
125124 self ._distances : dict [Real , float ] = decreasing_dict ()
126125 # {xii: error[xii]/min(_distances[xi], _distances[xii], ...}
127- self .rescaled_error : dict [Real , float ] = decreasing_dict ()
128- self ._check_required_attributes ()
126+ self .rescaled_error : ItemSortedDict [Real , float ] = decreasing_dict ()
129127
130128 def new (self ) -> AverageLearner1D :
131129 """Create a copy of `~adaptive.AverageLearner1D` without the data."""
132130 return AverageLearner1D (
133131 self .function ,
134132 self .bounds ,
135- self .loss_per_interval ,
133+ self .loss_per_interval , # type: ignore[arg-type]
136134 self .delta ,
137135 self .alpha ,
138136 self .neighbor_sampling ,
@@ -164,7 +162,7 @@ def to_numpy(self, mean: bool = False) -> np.ndarray:
164162 ]
165163 )
166164
167- def to_dataframe (
165+ def to_dataframe ( # type: ignore[override]
168166 self ,
169167 mean : bool = False ,
170168 with_default_function_args : bool = True ,
@@ -202,10 +200,10 @@ def to_dataframe(
202200 if not with_pandas :
203201 raise ImportError ("pandas is not installed." )
204202 if mean :
205- data = sorted (self .data .items ())
203+ data : list [ tuple [ Real , Real ]] = sorted (self .data .items ())
206204 columns = [x_name , y_name ]
207205 else :
208- data = [
206+ data : list [ tuple [ int , Real , Real ]] = [ # type: ignore[no-redef]
209207 (seed , x , y )
210208 for x , seed_y in sorted (self ._data_samples .items ())
211209 for seed , y in sorted (seed_y .items ())
@@ -218,7 +216,7 @@ def to_dataframe(
218216 assign_defaults (self .function , df , function_prefix )
219217 return df
220218
221- def load_dataframe (
219+ def load_dataframe ( # type: ignore[override]
222220 self ,
223221 df : pandas .DataFrame ,
224222 with_default_function_args : bool = True ,
@@ -258,7 +256,7 @@ def load_dataframe(
258256 self .function , df , function_prefix
259257 )
260258
261- def ask (self , n : int , tell_pending : bool = True ) -> tuple [Points , list [float ]]:
259+ def ask (self , n : int , tell_pending : bool = True ) -> tuple [Points , list [float ]]: # type: ignore[override]
262260 """Return 'n' points that are expected to maximally reduce the loss."""
263261 # If some point is undersampled, resample it
264262 if len (self ._undersampled_points ):
@@ -311,18 +309,18 @@ def _ask_for_new_point(self, n: int) -> tuple[Points, list[float]]:
311309 new point, since in general n << min_samples and this point will need
312310 to be resampled many more times"""
313311 points , (loss_improvement ,) = self ._ask_points_without_adding (1 )
314- points = [(seed , x ) for seed , x in zip (range (n ), n * points )]
312+ seed_points = [(seed , x ) for seed , x in zip (range (n ), n * points )]
315313 loss_improvements = [loss_improvement / n ] * n
316- return points , loss_improvements
314+ return seed_points , loss_improvements # type: ignore[return-value]
317315
318- def tell_pending (self , seed_x : Point ) -> None :
316+ def tell_pending (self , seed_x : Point ) -> None : # type: ignore[override]
319317 _ , x = seed_x
320318 self .pending_points .add (seed_x )
321319 if x not in self .data :
322320 self ._update_neighbors (x , self .neighbors_combined )
323321 self ._update_losses (x , real = False )
324322
325- def tell (self , seed_x : Point , y : Real ) -> None :
323+ def tell (self , seed_x : Point , y : Real ) -> None : # type: ignore[override]
326324 seed , x = seed_x
327325 if y is None :
328326 raise TypeError (
@@ -493,7 +491,7 @@ def _calc_error_in_mean(self, ys: Iterable[Real], y_avg: Real, n: int) -> float:
493491 t_student = scipy .stats .t .ppf (1 - self .alpha , df = n - 1 )
494492 return t_student * (variance_in_mean / n ) ** 0.5
495493
496- def tell_many (
494+ def tell_many ( # type: ignore[override]
497495 self , xs : Points | np .ndarray , ys : Sequence [Real ] | np .ndarray
498496 ) -> None :
499497 # Check that all x are within the bounds
@@ -578,10 +576,10 @@ def tell_many_at_point(self, x: Real, seed_y_mapping: dict[int, Real]) -> None:
578576 self ._update_interpolated_loss_in_interval (* interval )
579577 self ._oldscale = deepcopy (self ._scale )
580578
581- def _get_data (self ) -> dict [Real , dict [Int , Real ]]:
579+ def _get_data (self ) -> dict [Real , dict [Int , Real ]]: # type: ignore[override]
582580 return self ._data_samples
583581
584- def _set_data (self , data : dict [Real , dict [Int , Real ]]) -> None :
582+ def _set_data (self , data : dict [Real , dict [Int , Real ]]) -> None : # type: ignore[override]
585583 if data :
586584 for x , samples in data .items ():
587585 self .tell_many_at_point (x , samples )
@@ -616,7 +614,7 @@ def plot(self):
616614 return p .redim (x = {"range" : plot_bounds })
617615
618616
619- def decreasing_dict () -> dict :
617+ def decreasing_dict () -> ItemSortedDict :
620618 """This initialization orders the dictionary from large to small values"""
621619
622620 def sorting_rule (key , value ):
0 commit comments