88from typing import Any , Callable , List , Optional , Set , Tuple , Union
99
1010import numpy as np
11- from numpy import bool_ , float64 , ndarray , ufunc
11+ from numpy import ufunc
1212from scipy .linalg import norm
1313from sortedcontainers import SortedSet
1414
3333)
3434
3535
36- def _downdate (c : ndarray , nans : List [int ], depth : int ) -> ndarray :
36+ def _downdate (c : np . ndarray , nans : List [int ], depth : int ) -> np . ndarray :
3737 # This is algorithm 5 from the thesis of Pedro Gonnet.
3838 b = b_def [depth ].copy ()
3939 m = ns [depth ] - 1
@@ -51,7 +51,7 @@ def _downdate(c: ndarray, nans: List[int], depth: int) -> ndarray:
5151 return c
5252
5353
54- def _zero_nans (fx : ndarray ) -> List [int ]:
54+ def _zero_nans (fx : np . ndarray ) -> List [int ]:
5555 """Caution: this function modifies fx."""
5656 nans = []
5757 for i in range (len (fx )):
@@ -61,7 +61,7 @@ def _zero_nans(fx: ndarray) -> List[int]:
6161 return nans
6262
6363
64- def _calc_coeffs (fx : ndarray , depth : int ) -> ndarray :
64+ def _calc_coeffs (fx : np . ndarray , depth : int ) -> np . ndarray :
6565 """Caution: this function modifies fx."""
6666 nans = _zero_nans (fx )
6767 c_new = V_inv [depth ] @ fx
@@ -142,7 +142,11 @@ class _Interval:
142142 ]
143143
144144 def __init__ (
145- self , a : Union [int , float64 ], b : Union [int , float64 ], depth : int , rdepth : int
145+ self ,
146+ a : Union [int , np .float64 ],
147+ b : Union [int , np .float64 ],
148+ depth : int ,
149+ rdepth : int ,
146150 ) -> None :
147151 self .children = []
148152 self .data = {}
@@ -163,7 +167,7 @@ def make_first(cls, a: int, b: int, depth: int = 2) -> "_Interval":
163167 return ival
164168
165169 @property
166- def T (self ) -> ndarray :
170+ def T (self ) -> np . ndarray :
167171 """Get the correct shift matrix.
168172
169173 Should only be called on children of a split interval.
@@ -180,7 +184,7 @@ def refinement_complete(self, depth: int) -> bool:
180184 return False
181185 return all (p in self .data for p in self .points (depth ))
182186
183- def points (self , depth : Optional [int ] = None ) -> ndarray :
187+ def points (self , depth : Optional [int ] = None ) -> np . ndarray :
184188 if depth is None :
185189 depth = self .depth
186190 a = self .a
@@ -209,7 +213,7 @@ def split(self) -> List["_Interval"]:
209213 def calc_igral (self ) -> None :
210214 self .igral = (self .b - self .a ) * self .c [0 ] / sqrt (2 )
211215
212- def update_heuristic_err (self , value : Union [float64 , float ]) -> None :
216+ def update_heuristic_err (self , value : Union [np . float64 , float ]) -> None :
213217 """Sets the error of an interval using a heuristic (half the error of
214218 the parent) when the actual error cannot be calculated due to its
215219 parents not being finished yet. This error is propagated down to its
@@ -222,7 +226,7 @@ def update_heuristic_err(self, value: Union[float64, float]) -> None:
222226 continue
223227 child .update_heuristic_err (value / 2 )
224228
225- def calc_err (self , c_old : ndarray ) -> float :
229+ def calc_err (self , c_old : np . ndarray ) -> float :
226230 c_new = self .c
227231 c_diff = np .zeros (max (len (c_old ), len (c_new )))
228232 c_diff [: len (c_old )] = c_old
@@ -255,7 +259,7 @@ def update_ndiv_recursively(self) -> None:
255259
256260 def complete_process (
257261 self , depth : int
258- ) -> Union [Tuple [bool , bool ], Tuple [bool , bool_ ]]:
262+ ) -> Union [Tuple [bool , bool ], Tuple [bool , np . bool_ ]]:
259263 """Calculate the integral contribution and error from this interval,
260264 and update the done leaves of all ancestor intervals."""
261265 assert self .depth_complete is None or self .depth_complete == depth - 1
@@ -399,7 +403,7 @@ def __init__(
399403 def approximating_intervals (self ) -> Set ["_Interval" ]:
400404 return self .first_ival .done_leaves
401405
402- def tell (self , point : float64 , value : float64 ) -> None :
406+ def tell (self , point : np . float64 , value : np . float64 ) -> None :
403407 if point not in self .x_mapping :
404408 raise ValueError (f"Point { point } doesn't belong to any interval" )
405409 self .data [point ] = value
@@ -458,7 +462,9 @@ def add_ival(self, ival: "_Interval") -> None:
458462
459463 def ask (
460464 self , n : int , tell_pending : bool = True
461- ) -> Union [Tuple [List [float64 ], List [float64 ]], Tuple [List [float64 ], List [float ]]]:
465+ ) -> Union [
466+ Tuple [List [np .float64 ], List [np .float64 ]], Tuple [List [np .float64 ], List [float ]]
467+ ]:
462468 """Choose points for learners."""
463469 if not tell_pending :
464470 with restore (self ):
@@ -468,7 +474,9 @@ def ask(
468474
469475 def _ask_and_tell_pending (
470476 self , n : int
471- ) -> Union [Tuple [List [float64 ], List [float64 ]], Tuple [List [float64 ], List [float ]]]:
477+ ) -> Union [
478+ Tuple [List [np .float64 ], List [np .float64 ]], Tuple [List [np .float64 ], List [float ]]
479+ ]:
472480 points , loss_improvements = self .pop_from_stack (n )
473481 n_left = n - len (points )
474482 while n_left > 0 :
@@ -487,9 +495,9 @@ def _ask_and_tell_pending(
487495 def pop_from_stack (
488496 self , n : int
489497 ) -> Union [
490- Tuple [List [float64 ], List [float64 ]],
498+ Tuple [List [np . float64 ], List [np . float64 ]],
491499 Tuple [List [Any ], List [Any ]],
492- Tuple [List [float64 ], List [float ]],
500+ Tuple [List [np . float64 ], List [float ]],
493501 ]:
494502 points = self ._stack [:n ]
495503 self ._stack = self ._stack [n :]
@@ -501,7 +509,7 @@ def pop_from_stack(
501509 def remove_unfinished (self ):
502510 pass
503511
504- def _fill_stack (self ) -> List [float64 ]:
512+ def _fill_stack (self ) -> List [np . float64 ]:
505513 # XXX: to-do if all the ivals have err=inf, take the interval
506514 # with the lowest rdepth and no children.
507515 force_split = bool (self .priority_split )
@@ -542,11 +550,11 @@ def npoints(self) -> int:
542550 return len (self .data )
543551
544552 @property
545- def igral (self ) -> float64 :
553+ def igral (self ) -> np . float64 :
546554 return sum (i .igral for i in self .approximating_intervals )
547555
548556 @property
549- def err (self ) -> float64 :
557+ def err (self ) -> np . float64 :
550558 if self .approximating_intervals :
551559 err = sum (i .err for i in self .approximating_intervals )
552560 if err > sys .float_info .max :
0 commit comments