55from collections import defaultdict
66from math import sqrt
77from operator import attrgetter
8+ from typing import TYPE_CHECKING , Callable
89
910import cloudpickle
1011import numpy as np
2526 with_pandas = False
2627
2728
28- def _downdate (c , nans , depth ) :
29+ def _downdate (c : np . ndarray , nans : list [ int ] , depth : int ) -> np . ndarray :
2930 # This is algorithm 5 from the thesis of Pedro Gonnet.
3031 b = coeff .b_def [depth ].copy ()
3132 m = coeff .ns [depth ] - 1
@@ -45,7 +46,7 @@ def _downdate(c, nans, depth):
4546 return c
4647
4748
48- def _zero_nans (fx ) :
49+ def _zero_nans (fx : np . ndarray ) -> list [ int ] :
4950 """Caution: this function modifies fx."""
5051 nans = []
5152 for i in range (len (fx )):
@@ -55,7 +56,7 @@ def _zero_nans(fx):
5556 return nans
5657
5758
58- def _calc_coeffs (fx , depth ) :
59+ def _calc_coeffs (fx : np . ndarray , depth : int ) -> np . ndarray :
5960 """Caution: this function modifies fx."""
6061 nans = _zero_nans (fx )
6162 c_new = coeff .V_inv [depth ] @ fx
@@ -135,27 +136,32 @@ class _Interval:
135136 "removed" ,
136137 ]
137138
138- def __init__ (self , a , b , depth , rdepth ) :
139- self .children = []
140- self .data = {}
139+ def __init__ (self , a : int | float , b : int | float , depth : int , rdepth : int ) -> None :
140+ self .children : list [ _Interval ] = []
141+ self .data : dict [ float , float ] = {}
141142 self .a = a
142143 self .b = b
143144 self .depth = depth
144145 self .rdepth = rdepth
145- self .done_leaves = set ()
146- self .depth_complete = None
146+ self .done_leaves : set [ _Interval ] = set ()
147+ self .depth_complete : int | None = None
147148 self .removed = False
149+ if TYPE_CHECKING :
150+ self .ndiv : int
151+ self .parent : _Interval | None
152+ self .err : float
153+ self .c : np .ndarray
148154
149155 @classmethod
150- def make_first (cls , a , b , depth = 2 ) :
156+ def make_first (cls , a : int , b : int , depth : int = 2 ) -> _Interval :
151157 ival = _Interval (a , b , depth , rdepth = 1 )
152158 ival .ndiv = 0
153159 ival .parent = None
154160 ival .err = sys .float_info .max # needed because inf/2 == inf
155161 return ival
156162
157163 @property
158- def T (self ):
164+ def T (self ) -> np . ndarray :
159165 """Get the correct shift matrix.
160166
161167 Should only be called on children of a split interval.
@@ -166,24 +172,24 @@ def T(self):
166172 assert left != right
167173 return coeff .T_left if left else coeff .T_right
168174
169- def refinement_complete (self , depth ) :
175+ def refinement_complete (self , depth : int ) -> bool :
170176 """The interval has all the y-values to calculate the intergral."""
171177 if len (self .data ) < coeff .ns [depth ]:
172178 return False
173179 return all (p in self .data for p in self .points (depth ))
174180
175- def points (self , depth = None ):
181+ def points (self , depth : int | None = None ) -> np . ndarray :
176182 if depth is None :
177183 depth = self .depth
178184 a = self .a
179185 b = self .b
180186 return (a + b ) / 2 + (b - a ) * coeff .xi [depth ] / 2
181187
182- def refine (self ):
188+ def refine (self ) -> _Interval :
183189 self .depth += 1
184190 return self
185191
186- def split (self ):
192+ def split (self ) -> list [ _Interval ] :
187193 points = self .points ()
188194 m = points [len (points ) // 2 ]
189195 ivals = [
@@ -198,10 +204,10 @@ def split(self):
198204
199205 return ivals
200206
201- def calc_igral (self ):
207+ def calc_igral (self ) -> None :
202208 self .igral = (self .b - self .a ) * self .c [0 ] / sqrt (2 )
203209
204- def update_heuristic_err (self , value ) :
210+ def update_heuristic_err (self , value : float ) -> None :
205211 """Sets the error of an interval using a heuristic (half the error of
206212 the parent) when the actual error cannot be calculated due to its
207213 parents not being finished yet. This error is propagated down to its
@@ -214,7 +220,7 @@ def update_heuristic_err(self, value):
214220 continue
215221 child .update_heuristic_err (value / 2 )
216222
217- def calc_err (self , c_old ) :
223+ def calc_err (self , c_old : np . ndarray ) -> float :
218224 c_new = self .c
219225 c_diff = np .zeros (max (len (c_old ), len (c_new )))
220226 c_diff [: len (c_old )] = c_old
@@ -226,9 +232,9 @@ def calc_err(self, c_old):
226232 child .update_heuristic_err (self .err / 2 )
227233 return c_diff
228234
229- def calc_ndiv (self ):
235+ def calc_ndiv (self ) -> None :
230236 div = self .parent .c00 and self .c00 / self .parent .c00 > 2
231- self .ndiv += div
237+ self .ndiv += int ( div )
232238
233239 if self .ndiv > coeff .ndiv_max and 2 * self .ndiv > self .rdepth :
234240 raise DivergentIntegralError
@@ -237,15 +243,15 @@ def calc_ndiv(self):
237243 for child in self .children :
238244 child .update_ndiv_recursively ()
239245
240- def update_ndiv_recursively (self ):
246+ def update_ndiv_recursively (self ) -> None :
241247 self .ndiv += 1
242248 if self .ndiv > coeff .ndiv_max and 2 * self .ndiv > self .rdepth :
243249 raise DivergentIntegralError
244250
245251 for child in self .children :
246252 child .update_ndiv_recursively ()
247253
248- def complete_process (self , depth ) :
254+ def complete_process (self , depth : int ) -> tuple [ bool , bool ] | tuple [ bool , np . bool_ ] :
249255 """Calculate the integral contribution and error from this interval,
250256 and update the done leaves of all ancestor intervals."""
251257 assert self .depth_complete is None or self .depth_complete == depth - 1
@@ -322,7 +328,7 @@ def complete_process(self, depth):
322328
323329 return force_split , remove
324330
325- def __repr__ (self ):
331+ def __repr__ (self ) -> str :
326332 lst = [
327333 f"(a, b)=({ self .a :.5f} , { self .b :.5f} )" ,
328334 f"depth={ self .depth } " ,
@@ -334,7 +340,7 @@ def __repr__(self):
334340
335341
336342class IntegratorLearner (BaseLearner ):
337- def __init__ (self , function , bounds , tol ) :
343+ def __init__ (self , function : Callable , bounds : tuple [ int , int ], tol : float ) -> None :
338344 """
339345 Parameters
340346 ----------
@@ -368,16 +374,18 @@ def __init__(self, function, bounds, tol):
368374 plot : hv.Scatter
369375 Plots all the points that are evaluated.
370376 """
371- self .function = function
377+ self .function = function # type: ignore
372378 self .bounds = bounds
373379 self .tol = tol
374380 self .max_ivals = 1000
375- self .priority_split = []
381+ self .priority_split : list [ _Interval ] = []
376382 self .data = {}
377383 self .pending_points = set ()
378- self ._stack = []
379- self .x_mapping = defaultdict (lambda : SortedSet ([], key = attrgetter ("rdepth" )))
380- self .ivals = set ()
384+ self ._stack : list [float ] = []
385+ self .x_mapping : dict [float , SortedSet ] = defaultdict (
386+ lambda : SortedSet ([], key = attrgetter ("rdepth" ))
387+ )
388+ self .ivals : set [_Interval ] = set ()
381389 ival = _Interval .make_first (* self .bounds )
382390 self .add_ival (ival )
383391 self .first_ival = ival
@@ -387,10 +395,10 @@ def new(self) -> IntegratorLearner:
387395 return IntegratorLearner (self .function , self .bounds , self .tol )
388396
389397 @property
390- def approximating_intervals (self ):
398+ def approximating_intervals (self ) -> set [ _Interval ] :
391399 return self .first_ival .done_leaves
392400
393- def tell (self , point , value ) :
401+ def tell (self , point : float , value : float ) -> None :
394402 if point not in self .x_mapping :
395403 raise ValueError (f"Point { point } doesn't belong to any interval" )
396404 self .data [point ] = value
@@ -426,7 +434,7 @@ def tell(self, point, value):
426434 def tell_pending (self ):
427435 pass
428436
429- def propagate_removed (self , ival ) :
437+ def propagate_removed (self , ival : _Interval ) -> None :
430438 def _propagate_removed_down (ival ):
431439 ival .removed = True
432440 self .ivals .discard (ival )
@@ -436,7 +444,7 @@ def _propagate_removed_down(ival):
436444
437445 _propagate_removed_down (ival )
438446
439- def add_ival (self , ival ) :
447+ def add_ival (self , ival : _Interval ) -> None :
440448 for x in ival .points ():
441449 # Update the mappings
442450 self .x_mapping [x ].add (ival )
@@ -447,15 +455,15 @@ def add_ival(self, ival):
447455 self ._stack .append (x )
448456 self .ivals .add (ival )
449457
450- def ask (self , n , tell_pending = True ):
458+ def ask (self , n : int , tell_pending : bool = True ) -> tuple [ list [ float ], list [ float ]] :
451459 """Choose points for learners."""
452460 if not tell_pending :
453461 with restore (self ):
454462 return self ._ask_and_tell_pending (n )
455463 else :
456464 return self ._ask_and_tell_pending (n )
457465
458- def _ask_and_tell_pending (self , n ) :
466+ def _ask_and_tell_pending (self , n : int ) -> tuple [ list [ float ], list [ float ]] :
459467 points , loss_improvements = self .pop_from_stack (n )
460468 n_left = n - len (points )
461469 while n_left > 0 :
@@ -471,7 +479,7 @@ def _ask_and_tell_pending(self, n):
471479
472480 return points , loss_improvements
473481
474- def pop_from_stack (self , n ) :
482+ def pop_from_stack (self , n : int ) -> tuple [ list [ float ], list [ float ]] :
475483 points = self ._stack [:n ]
476484 self ._stack = self ._stack [n :]
477485 loss_improvements = [
@@ -482,7 +490,7 @@ def pop_from_stack(self, n):
482490 def remove_unfinished (self ):
483491 pass
484492
485- def _fill_stack (self ):
493+ def _fill_stack (self ) -> list [ float ] :
486494 # XXX: to-do if all the ivals have err=inf, take the interval
487495 # with the lowest rdepth and no children.
488496 force_split = bool (self .priority_split )
@@ -518,16 +526,16 @@ def _fill_stack(self):
518526 return self ._stack
519527
520528 @property
521- def npoints (self ):
529+ def npoints (self ) -> int :
522530 """Number of evaluated points."""
523531 return len (self .data )
524532
525533 @property
526- def igral (self ):
534+ def igral (self ) -> float :
527535 return sum (i .igral for i in self .approximating_intervals )
528536
529537 @property
530- def err (self ):
538+ def err (self ) -> float :
531539 if self .approximating_intervals :
532540 err = sum (i .err for i in self .approximating_intervals )
533541 if err > sys .float_info .max :
0 commit comments