@@ -6,14 +6,19 @@ import ConcreteStructs: @concrete
66import UnPack: @unpack
77import FiniteDiff, ForwardDiff
88
9- function SciMLBase. __solve (prob:: NonlinearProblem , alg:: SIAMFANLEquationsJL , args... ; abstol = 1e-8 ,
10- reltol = 1e-8 , alias_u0:: Bool = false , maxiters = 1000 , termination_condition = nothing , kwargs... )
9+ function SciMLBase. __solve (prob:: NonlinearProblem , alg:: SIAMFANLEquationsJL , args... ; abstol = nothing ,
10+ reltol = nothing , alias_u0:: Bool = false , maxiters = 1000 , termination_condition = nothing , kwargs... )
1111 @assert (termination_condition === nothing ) || (termination_condition isa AbsNormTerminationMode) " SIAMFANLEquationsJL does not support termination conditions!"
1212
1313 @unpack method, autodiff, show_trace, delta, linsolve = alg
1414
1515 iip = SciMLBase. isinplace (prob)
16- if typeof (prob. u0) <: Number
16+ T = eltype (u0)
17+
18+ atol = abstol === nothing ? real (oneunit (T)) * (eps (real (one (T))))^ (4 // 5 ) : abstol
19+ rtol = reltol === nothing ? real (oneunit (T)) * (eps (real (one (T))))^ (4 // 5 ) : reltol
20+
21+ if prob. u0 isa Number
1722 f! = if iip
1823 function (u)
1924 du = similar (u)
@@ -25,11 +30,11 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::SIAMFANLEquationsJL, arg
2530 end
2631
2732 if method == :newton
28- sol = nsolsc (f!, prob. u0; maxit = maxiters, atol = abstol , rtol = reltol , printerr = show_trace)
33+ sol = nsolsc (f!, prob. u0; maxit = maxiters, atol = atol , rtol = rtol , printerr = show_trace)
2934 elseif method == :pseudotransient
30- sol = ptcsolsc (f!, prob. u0; delta0 = delta, maxit = maxiters, atol = abstol , rtol= reltol , printerr = show_trace)
35+ sol = ptcsolsc (f!, prob. u0; delta0 = delta, maxit = maxiters, atol = atol , rtol= rtol , printerr = show_trace)
3136 elseif method == :secant
32- sol = secant (f!, prob. u0; maxit = maxiters, atol = abstol , rtol = reltol , printerr = show_trace)
37+ sol = secant (f!, prob. u0; maxit = maxiters, atol = atol , rtol = rtol , printerr = show_trace)
3338 end
3439
3540 if sol. errcode == 0
@@ -61,22 +66,21 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::SIAMFANLEquationsJL, arg
6166 end
6267 end
6368
64- # Allocate ahead for function and Jacobian
69+ # Allocate ahead for function
6570 N = length (u)
66- FS = zeros (eltype (u), N)
67- FPS = zeros (eltype (u), N, N)
68- # Allocate ahead for Krylov basis
71+ FS = zeros (T, N)
6972
7073 # Jacobian free Newton Krylov
7174 if linsolve != = nothing
72- JVS = linsolve == :gmres ? zeros (eltype (u), N, 3 ) : zeros (eltype (u), N)
75+ # Allocate ahead for Krylov basis
76+ JVS = linsolve == :gmres ? zeros (T, N, 3 ) : zeros (T, N)
7377 # `linsolve` as a Symbol to keep unified interface with other EXTs, SIAMFANLEquations directly use String to choose between different linear solvers
7478 linsolve_alg = String (linsolve)
7579
7680 if method == :newton
77- sol = nsoli (f!, u, FS, JVS; lsolver = linsolve_alg, maxit = maxiters, atol = abstol , rtol = reltol , printerr = show_trace)
81+ sol = nsoli (f!, u, FS, JVS; lsolver = linsolve_alg, maxit = maxiters, atol = atol , rtol = rtol , printerr = show_trace)
7882 elseif method == :pseudotransient
79- sol = ptcsoli (f!, u, FS, JVS; lsolver = linsolve_alg, maxit = maxiters, atol = abstol , rtol = reltol , printerr = show_trace)
83+ sol = ptcsoli (f!, u, FS, JVS; lsolver = linsolve_alg, maxit = maxiters, atol = atol , rtol = rtol , printerr = show_trace)
8084 end
8185
8286 if sol. errcode == 0
@@ -92,64 +96,30 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::SIAMFANLEquationsJL, arg
9296 return SciMLBase. build_solution (prob, alg, sol. solution, sol. history; retcode, stats, original = sol)
9397 end
9498
99+ # Allocate ahead for Jacobian
100+ FPS = zeros (T, N, N)
95101 if prob. f. jac === nothing
96- use_forward_diff = if alg. autodiff === nothing
97- ForwardDiff. can_dual (eltype (u))
98- else
99- alg. autodiff isa AutoForwardDiff
100- end
101- uf = SciMLBase. JacobianWrapper {iip} (prob. f, prob. p)
102- if use_forward_diff
103- cache = iip ? ForwardDiff. JacobianConfig (uf, fu, u) :
104- ForwardDiff. JacobianConfig (uf, u)
105- else
106- cache = FiniteDiff. JacobianCache (u, fu)
107- end
108- J! = if iip
109- if use_forward_diff
110- fu_cache = similar (fu)
111- function (J, x, p)
112- uf. p = p
113- ForwardDiff. jacobian! (J, uf, fu_cache, x, cache)
114- return J
115- end
116- else
117- function (J, x, p)
118- uf. p = p
119- FiniteDiff. finite_difference_jacobian! (J, uf, x, cache)
120- return J
121- end
122- end
123- else
124- if use_forward_diff
125- function (J, x, p)
126- uf. p = p
127- ForwardDiff. jacobian! (J, uf, x, cache)
128- return J
129- end
130- else
131- function (J, x, p)
132- uf. p = p
133- J_ = FiniteDiff. finite_difference_jacobian (uf, x, cache)
134- copyto! (J, J_)
135- return J
136- end
137- end
102+ # Use the built-in Jacobian machinery
103+ if method == :newton
104+ sol = nsol (f!, u, FS, FPS;
105+ sham= 1 , atol = atol, rtol = rtol, maxit = maxiters,
106+ printerr = show_trace)
107+ elseif method == :pseudotransient
108+ sol = ptcsol (f!, u, FS, FPS;
109+ atol = atol, rtol = rtol, maxit = maxiters,
110+ delta0 = delta, printerr = show_trace)
138111 end
139112 else
140- J! = prob. f. jac
141- end
142-
143- AJ! (J, u, x) = J! (J, x, prob. p)
144-
145- if method == :newton
146- sol = nsol (f!, u, FS, FPS, AJ!;
147- sham= 1 , rtol = reltol, atol = abstol, maxit = maxiters,
148- printerr = show_trace)
149- elseif method == :pseudotransient
150- sol = ptcsol (f!, u, FS, FPS, AJ!;
151- rtol = reltol, atol = abstol, maxit = maxiters,
152- delta0 = delta, printerr = show_trace)
113+ AJ! (J, u, x) = prob. f. jac (J, x, prob. p)
114+ if method == :newton
115+ sol = nsol (f!, u, FS, FPS, AJ!;
116+ sham= 1 , atol = atol, rtol = rtol, maxit = maxiters,
117+ printerr = show_trace)
118+ elseif method == :pseudotransient
119+ sol = ptcsol (f!, u, FS, FPS, AJ!;
120+ atol = atol, rtol = rtol, maxit = maxiters,
121+ delta0 = delta, printerr = show_trace)
122+ end
153123 end
154124
155125 if sol. errcode == 0
0 commit comments