22
33import sys
44from collections import defaultdict
5- from functools import partial
65from math import sqrt
76from operator import attrgetter
8- from typing import Any , Callable , List , Optional , Set , Tuple , Union
7+ from typing import Callable , List , Optional , Set , Tuple , Union
98
109import numpy as np
11- from numpy import ufunc
1210from scipy .linalg import norm
1311from sortedcontainers import SortedSet
1412
@@ -142,11 +140,7 @@ class _Interval:
142140 ]
143141
144142 def __init__ (
145- self ,
146- a : Union [int , np .float64 ],
147- b : Union [int , np .float64 ],
148- depth : int ,
149- rdepth : int ,
143+ self , a : Union [int , float ], b : Union [int , float ], depth : int , rdepth : int ,
150144 ) -> None :
151145 self .children = []
152146 self .data = {}
@@ -213,7 +207,7 @@ def split(self) -> List["_Interval"]:
213207 def calc_igral (self ) -> None :
214208 self .igral = (self .b - self .a ) * self .c [0 ] / sqrt (2 )
215209
216- def update_heuristic_err (self , value : Union [ np . float64 , float ] ) -> None :
210+ def update_heuristic_err (self , value : float ) -> None :
217211 """Sets the error of an interval using a heuristic (half the error of
218212 the parent) when the actual error cannot be calculated due to its
219213 parents not being finished yet. This error is propagated down to its
@@ -347,10 +341,7 @@ def __repr__(self) -> str:
347341
348342class IntegratorLearner (BaseLearner ):
349343 def __init__ (
350- self ,
351- function : Union [partial , ufunc , Callable ],
352- bounds : Tuple [int , int ],
353- tol : float ,
344+ self , function : Callable , bounds : Tuple [int , int ], tol : float ,
354345 ) -> None :
355346 """
356347 Parameters
@@ -403,7 +394,7 @@ def __init__(
403394 def approximating_intervals (self ) -> Set ["_Interval" ]:
404395 return self .first_ival .done_leaves
405396
406- def tell (self , point : np . float64 , value : np . float64 ) -> None :
397+ def tell (self , point : float , value : float ) -> None :
407398 if point not in self .x_mapping :
408399 raise ValueError (f"Point { point } doesn't belong to any interval" )
409400 self .data [point ] = value
@@ -460,23 +451,15 @@ def add_ival(self, ival: "_Interval") -> None:
460451 self ._stack .append (x )
461452 self .ivals .add (ival )
462453
463- def ask (
464- self , n : int , tell_pending : bool = True
465- ) -> Union [
466- Tuple [List [np .float64 ], List [np .float64 ]], Tuple [List [np .float64 ], List [float ]]
467- ]:
454+ def ask (self , n : int , tell_pending : bool = True ) -> Tuple [List [float ], List [float ]]:
468455 """Choose points for learners."""
469456 if not tell_pending :
470457 with restore (self ):
471458 return self ._ask_and_tell_pending (n )
472459 else :
473460 return self ._ask_and_tell_pending (n )
474461
475- def _ask_and_tell_pending (
476- self , n : int
477- ) -> Union [
478- Tuple [List [np .float64 ], List [np .float64 ]], Tuple [List [np .float64 ], List [float ]]
479- ]:
462+ def _ask_and_tell_pending (self , n : int ) -> Tuple [List [float ], List [float ]]:
480463 points , loss_improvements = self .pop_from_stack (n )
481464 n_left = n - len (points )
482465 while n_left > 0 :
@@ -492,13 +475,7 @@ def _ask_and_tell_pending(
492475
493476 return points , loss_improvements
494477
495- def pop_from_stack (
496- self , n : int
497- ) -> Union [
498- Tuple [List [np .float64 ], List [np .float64 ]],
499- Tuple [List [Any ], List [Any ]],
500- Tuple [List [np .float64 ], List [float ]],
501- ]:
478+ def pop_from_stack (self , n : int ) -> Tuple [List [float ], List [float ]]:
502479 points = self ._stack [:n ]
503480 self ._stack = self ._stack [n :]
504481 loss_improvements = [
@@ -509,7 +486,7 @@ def pop_from_stack(
509486 def remove_unfinished (self ):
510487 pass
511488
512- def _fill_stack (self ) -> List [np . float64 ]:
489+ def _fill_stack (self ) -> List [float ]:
513490 # XXX: to-do if all the ivals have err=inf, take the interval
514491 # with the lowest rdepth and no children.
515492 force_split = bool (self .priority_split )
@@ -550,11 +527,11 @@ def npoints(self) -> int:
550527 return len (self .data )
551528
552529 @property
553- def igral (self ) -> np . float64 :
530+ def igral (self ) -> float :
554531 return sum (i .igral for i in self .approximating_intervals )
555532
556533 @property
557- def err (self ) -> np . float64 :
534+ def err (self ) -> float :
558535 if self .approximating_intervals :
559536 err = sum (i .err for i in self .approximating_intervals )
560537 if err > sys .float_info .max :
0 commit comments