22# Copyright 2017 Christoph Groth
33
44from collections import defaultdict
5- from fractions import Fraction as Frac
5+ from fractions import Fraction
6+ from typing import Callable , List , Tuple , Union
67
78import numpy as np
89from numpy .testing import assert_allclose
1112eps = np .spacing (1 )
1213
1314
14- def legendre (n ) :
15+ def legendre (n : int ) -> List [ List [ Fraction ]] :
1516 """Return the first n Legendre polynomials.
1617
1718 The polynomials have *standard* normalization, i.e.
1819 int_{-1}^1 dx L_n(x) L_m(x) = delta(m, n) * 2 / (2 * n + 1).
1920
2021 The return value is a list of list of fraction.Fraction instances.
2122 """
22- result = [[Frac (1 )], [Frac (0 ), Frac (1 )]]
23+ result = [[Fraction (1 )], [Fraction (0 ), Fraction (1 )]]
2324 if n <= 2 :
2425 return result [:n ]
2526 for i in range (2 , n ):
2627 # Use Bonnet's recursion formula.
27- new = (i + 1 ) * [Frac (0 )]
28+ new = (i + 1 ) * [Fraction (0 )]
2829 new [1 :] = (r * (2 * i - 1 ) for r in result [- 1 ])
2930 new [:- 2 ] = (n - r * (i - 1 ) for n , r in zip (new [:- 2 ], result [- 2 ]))
3031 new [:] = (n / i for n in new )
3132 result .append (new )
3233 return result
3334
3435
35- def newton (n ) :
36+ def newton (n : int ) -> np . ndarray :
3637 """Compute the monomial coefficients of the Newton polynomial over the
3738 nodes of the n-point Clenshaw-Curtis quadrature rule.
3839 """
@@ -89,7 +90,7 @@ def newton(n):
8990 return cf
9091
9192
92- def scalar_product (a , b ) :
93+ def scalar_product (a : List [ Fraction ] , b : List [ Fraction ]) -> Fraction :
9394 """Compute the polynomial scalar product int_-1^1 dx a(x) b(x).
9495
9596 The args must be sequences of polynomial coefficients. This
@@ -110,7 +111,7 @@ def scalar_product(a, b):
110111 return 2 * sum (c [i ] / (i + 1 ) for i in range (0 , lc , 2 ))
111112
112113
113- def calc_bdef (ns ) :
114+ def calc_bdef (ns : Tuple [ int , int , int , int ]) -> List [ np . ndarray ] :
114115 """Calculate the decompositions of Newton polynomials (over the nodes
115116 of the n-point Clenshaw-Curtis quadrature rule) in terms of
116117 Legandre polynomials.
@@ -123,7 +124,7 @@ def calc_bdef(ns):
123124 result = []
124125 for n in ns :
125126 poly = []
126- a = list (map (Frac , newton (n )))
127+ a = list (map (Fraction , newton (n )))
127128 for b in legs [: n + 1 ]:
128129 igral = scalar_product (a , b )
129130
@@ -145,7 +146,7 @@ def calc_bdef(ns):
145146b_def = calc_bdef (n )
146147
147148
148- def calc_V (xi , n ) :
149+ def calc_V (xi : np . ndarray , n : int ) -> np . ndarray :
149150 V = [np .ones (xi .shape ), xi .copy ()]
150151 for i in range (2 , n ):
151152 V .append ((2 * i - 1 ) / i * xi * V [- 1 ] - (i - 1 ) / i * V [- 2 ])
@@ -183,7 +184,7 @@ def calc_V(xi, n):
183184gamma = np .concatenate ([[0 , 0 ], np .sqrt (k [2 :] ** 2 / (4 * k [2 :] ** 2 - 1 ))])
184185
185186
186- def _downdate (c , nans , depth ) :
187+ def _downdate (c : np . ndarray , nans : List [ int ] , depth : int ) -> None :
187188 # This is algorithm 5 from the thesis of Pedro Gonnet.
188189 b = b_def [depth ].copy ()
189190 m = n [depth ] - 1
@@ -200,7 +201,7 @@ def _downdate(c, nans, depth):
200201 m -= 1
201202
202203
203- def _zero_nans (fx ) :
204+ def _zero_nans (fx : np . ndarray ) -> List [ int ] :
204205 nans = []
205206 for i in range (len (fx )):
206207 if not np .isfinite (fx [i ]):
@@ -209,7 +210,7 @@ def _zero_nans(fx):
209210 return nans
210211
211212
212- def _calc_coeffs (fx , depth ) :
213+ def _calc_coeffs (fx : np . ndarray , depth : int ) -> np . ndarray :
213214 """Caution: this function modifies fx."""
214215 nans = _zero_nans (fx )
215216 c_new = V_inv [depth ] @ fx
@@ -220,7 +221,7 @@ def _calc_coeffs(fx, depth):
220221
221222
222223class DivergentIntegralError (ValueError ):
223- def __init__ (self , msg , igral , err , nr_points ) :
224+ def __init__ (self , msg : str , igral : float , err : None , nr_points : int ) -> None :
224225 self .igral = igral
225226 self .err = err
226227 self .nr_points = nr_points
@@ -230,19 +231,23 @@ def __init__(self, msg, igral, err, nr_points):
230231class _Interval :
231232 __slots__ = ["a" , "b" , "c" , "fx" , "igral" , "err" , "depth" , "rdepth" , "ndiv" , "c00" ]
232233
233- def __init__ (self , a , b , depth , rdepth ):
234+ def __init__ (
235+ self , a : Union [int , float ], b : Union [int , float ], depth : int , rdepth : int
236+ ) -> None :
234237 self .a = a
235238 self .b = b
236239 self .depth = depth
237240 self .rdepth = rdepth
238241
239- def points (self ):
242+ def points (self ) -> np . ndarray :
240243 a = self .a
241244 b = self .b
242245 return (a + b ) / 2 + (b - a ) * xi [self .depth ] / 2
243246
244247 @classmethod
245- def make_first (cls , f , a , b , depth = 2 ):
248+ def make_first (
249+ cls , f : Callable , a : int , b : int , depth : int = 2
250+ ) -> Tuple ["_Interval" , int ]:
246251 ival = _Interval (a , b , depth , 1 )
247252 fx = f (ival .points ())
248253 ival .c = _calc_coeffs (fx , depth )
@@ -251,7 +256,7 @@ def make_first(cls, f, a, b, depth=2):
251256 ival .ndiv = 0
252257 return ival , n [depth ]
253258
254- def calc_igral_and_err (self , c_old ) :
259+ def calc_igral_and_err (self , c_old : np . ndarray ) -> float :
255260 self .c = c_new = _calc_coeffs (self .fx , self .depth )
256261 c_diff = np .zeros (max (len (c_old ), len (c_new )))
257262 c_diff [: len (c_old )] = c_old
@@ -262,7 +267,9 @@ def calc_igral_and_err(self, c_old):
262267 self .err = w * c_diff
263268 return c_diff
264269
265- def split (self , f ):
270+ def split (
271+ self , f : Callable
272+ ) -> Union [Tuple [Tuple [float , float , float ], int ], Tuple [List ["_Interval" ], int ]]:
266273 m = (self .a + self .b ) / 2
267274 f_center = self .fx [(len (self .fx ) - 1 ) // 2 ]
268275
@@ -287,7 +294,7 @@ def split(self, f):
287294
288295 return ivals , nr_points
289296
290- def refine (self , f ) :
297+ def refine (self , f : Callable ) -> Tuple [ np . ndarray , bool , int ] :
291298 """Increase degree of interval."""
292299 self .depth = depth = self .depth + 1
293300 points = self .points ()
@@ -299,7 +306,9 @@ def refine(self, f):
299306 return points , split , n [depth ] - n [depth - 1 ]
300307
301308
302- def algorithm_4 (f , a , b , tol , N_loops = int (1e9 )):
309+ def algorithm_4 (
310+ f : Callable , a : int , b : int , tol : float , N_loops : int = int (1e9 )
311+ ) -> Tuple [float , float , int , List ["_Interval" ]]:
303312 """ALGORITHM_4 evaluates an integral using adaptive quadrature. The
304313 algorithm uses Clenshaw-Curtis quadrature rules of increasing
305314 degree in each interval and bisects the interval if either the
@@ -403,37 +412,39 @@ def algorithm_4(f, a, b, tol, N_loops=int(1e9)):
403412 return igral , err , nr_points , ivals
404413
405414
406- ################ Tests ################
415+ # ############### Tests ################
407416
408417
409- def f0 (x ) :
418+ def f0 (x : Union [ float , np . ndarray ]) -> Union [ float , np . ndarray ] :
410419 return x * np .sin (1 / x ) * np .sqrt (abs (1 - x ))
411420
412421
413422def f7 (x ):
414423 return x ** - 0.5
415424
416425
417- def f24 (x ) :
426+ def f24 (x : Union [ float , np . ndarray ]) -> Union [ float , np . ndarray ] :
418427 return np .floor (np .exp (x ))
419428
420429
421- def f21 (x ) :
430+ def f21 (x : Union [ float , np . ndarray ]) -> Union [ float , np . ndarray ] :
422431 y = 0
423432 for i in range (1 , 4 ):
424433 y += 1 / np .cosh (20 ** i * (x - 2 * i / 10 ))
425434 return y
426435
427436
428- def f63 (x , alpha , beta ):
437+ def f63 (
438+ x : Union [float , np .ndarray ], alpha : float , beta : float
439+ ) -> Union [float , np .ndarray ]:
429440 return abs (x - beta ) ** alpha
430441
431442
432443def F63 (x , alpha , beta ):
433444 return (x - beta ) * abs (x - beta ) ** alpha / (alpha + 1 )
434445
435446
436- def fdiv (x ) :
447+ def fdiv (x : Union [ float , np . ndarray ]) -> Union [ float , np . ndarray ] :
437448 return abs (x - 0.987654321 ) ** - 1.1
438449
439450
@@ -461,7 +472,9 @@ def test_scalar_product(n=33):
461472 selection = [0 , 5 , 7 , n - 1 ]
462473 for i in selection :
463474 for j in selection :
464- assert scalar_product (legs [i ], legs [j ]) == ((i == j ) and Frac (2 , 2 * i + 1 ))
475+ assert scalar_product (legs [i ], legs [j ]) == (
476+ (i == j ) and Fraction (2 , 2 * i + 1 )
477+ )
465478
466479
467480def simple_newton (n ):
0 commit comments