@@ -30,12 +30,8 @@ def __init__(self, stiffness, mass, alpha_f=0.4, alpha_m=0.2) -> None:
3030 self .stiffness = stiffness
3131 self .mass = mass
3232
33- def rhs_eval_points (self , dt ) -> List [float ]:
34- return [(1 - self .alpha_f ) * dt ]
35-
36- def do_step (self , u , v , a , f , dt ) -> Tuple [float , float , float ]:
37- if isinstance (f , list ): # if f is list, turn it into a number
38- f = f [0 ]
33+ def do_step (self , u , v , a , rhs , dt ) -> Tuple [float , float , float ]:
34+ f = rhs ((1 - self .alpha_f ) * dt )
3935
4036 m = 3 * [None ]
4137 m [0 ] = (1 - self .alpha_m ) / (self .beta * dt ** 2 )
@@ -71,31 +67,25 @@ def __init__(self, ode_system) -> None:
7167 self .ode_system = ode_system
7268 pass
7369
74- def rhs_eval_points (self , dt ) -> List [float ]:
75- return [self .c [0 ] * dt , self .c [1 ] * dt , self .c [2 ] * dt , self .c [3 ] * dt ]
76-
77- def do_step (self , u , v , a , f , dt ) -> Tuple [float , float , float ]:
70+ def do_step (self , u , v , a , rhs , dt ) -> Tuple [float , float , float ]:
7871 assert (isinstance (u , type (v )))
7972
8073 n_stages = 4
8174
82- if isinstance (f , numbers .Number ): # if f is number, assume constant f
83- f = n_stages * [f ]
84-
8575 if isinstance (u , np .ndarray ):
8676 x = np .concatenate ([u , v ])
87- rhs = [ np .concatenate ([np .array ([0 , 0 ]), f [ i ]]) for i in range ( n_stages )]
77+ def f ( t ): return np .concatenate ([np .array ([0 , 0 ]), rhs ( t )])
8878 elif isinstance (u , numbers .Number ):
8979 x = np .array ([u , v ])
90- rhs = [ np .array ([0 , f [ i ]]) for i in range ( n_stages )]
80+ def f ( t ): return np .array ([0 , rhs ( t )])
9181 else :
9282 raise Exception (f"Cannot handle input type { type (u )} of u and v" )
9383
9484 s = n_stages * [None ]
95- s [0 ] = self .ode_system .dot (x ) + rhs [0 ]
96- s [1 ] = self .ode_system .dot (x + self .a [1 , 0 ] * s [0 ] * dt ) + rhs [1 ]
97- s [2 ] = self .ode_system .dot (x + self .a [2 , 1 ] * s [1 ] * dt ) + rhs [2 ]
98- s [3 ] = self .ode_system .dot (x + self .a [3 , 2 ] * s [2 ] * dt ) + rhs [3 ]
85+ s [0 ] = self .ode_system .dot (x ) + f ( self . c [0 ] * dt )
86+ s [1 ] = self .ode_system .dot (x + self .a [1 , 0 ] * s [0 ] * dt ) + f ( self . c [1 ] * dt )
87+ s [2 ] = self .ode_system .dot (x + self .a [2 , 1 ] * s [1 ] * dt ) + f ( self . c [2 ] * dt )
88+ s [3 ] = self .ode_system .dot (x + self .a [3 , 2 ] * s [2 ] * dt ) + f ( self . c [3 ] * dt )
9989
10090 x_new = x
10191
@@ -119,14 +109,7 @@ def __init__(self, ode_system) -> None:
119109 self .ode_system = ode_system
120110 pass
121111
122- def rhs_eval_points (self , dt ) -> List [float ]:
123- return np .linspace (0 , dt , 5 ) # will create an interpolant from this later
124-
125- def do_step (self , u , v , a , f , dt ) -> Tuple [float , float , float ]:
126- from brot .interpolation import do_lagrange_interpolation
127-
128- ts = self .rhs_eval_points (dt )
129-
112+ def do_step (self , u , v , a , rhs , dt ) -> Tuple [float , float , float ]:
130113 t0 = 0
131114
132115 assert (isinstance (u , type (v )))
@@ -135,25 +118,24 @@ def do_step(self, u, v, a, f, dt) -> Tuple[float, float, float]:
135118 x0 = np .concatenate ([u , v ])
136119 f = np .array (f )
137120 assert (u .shape [0 ] == f .shape [1 ])
138- def rhs_fun (t , x ): return np .concatenate ([np .array ([np .zeros_like (t ), np .zeros_like (t )]), [
139- do_lagrange_interpolation (t , ts , f [:, i ]) for i in range (u .shape [0 ])]])
121+ def rhs_fun (t ): return np .concatenate ([np .array ([np .zeros_like (t ), np .zeros_like (t )]), rhs (t )])
140122 elif isinstance (u , numbers .Number ):
141123 x0 = np .array ([u , v ])
142- def rhs_fun (t , x ): return np .array ([np .zeros_like (t ), do_lagrange_interpolation ( t , ts , f )])
124+ def rhs_fun (t ): return np .array ([np .zeros_like (t ), rhs ( t )])
143125 else :
144126 raise Exception (f"Cannot handle input type { type (u )} of u and v" )
145127
146128 def fun (t , x ):
147- return self .ode_system .dot (x ) + rhs_fun (t , x )
129+ return self .ode_system .dot (x ) + rhs_fun (t )
148130
149- # use large rtol and atol to circumvent error control.
131+ # use adaptive time stepping; dense_output=True allows us to sample from continuous function later
150132 ret = sp .integrate .solve_ivp (fun , [t0 , t0 + dt ], x0 , method = "Radau" ,
151- first_step = dt , max_step = dt , rtol = 10e10 , atol = 10e10 )
133+ dense_output = True , rtol = 10e-5 , atol = 10e-9 )
152134
153135 a_new = None
154136 if isinstance (u , np .ndarray ):
155137 u_new , v_new = ret .y [0 :2 , - 1 ], ret .y [2 :4 , - 1 ]
156138 elif isinstance (u , numbers .Number ):
157139 u_new , v_new = ret .y [:, - 1 ]
158140
159- return u_new , v_new , a_new
141+ return u_new , v_new , a_new , ret . sol
0 commit comments