@@ -2,51 +2,58 @@ module NonlinearSolveSIAMFANLEquationsExt
22
33using NonlinearSolve, SciMLBase
44using SIAMFANLEquations
5- import ConcreteStructs: @concrete
65import UnPack: @unpack
76
8- function SciMLBase. __solve (prob:: NonlinearProblem , alg:: SIAMFANLEquationsJL , args... ; abstol = nothing ,
9- reltol = nothing , alias_u0:: Bool = false , maxiters = 1000 , termination_condition = nothing , kwargs... )
10- @assert (termination_condition === nothing ) || (termination_condition isa AbsNormTerminationMode) " SIAMFANLEquationsJL does not support termination conditions!"
7+ @inline function __siam_fanl_equations_retcode_mapping (sol)
8+ if sol. errcode == 0
9+ return ReturnCode. Success
10+ elseif sol. errcode == 10
11+ return ReturnCode. MaxIters
12+ elseif sol. errcode == 1
13+ return ReturnCode. Failure
14+ elseif sol. errcode == - 1
15+ return ReturnCode. Default
16+ end
17+ end
18+
19+ # pseudo transient continuation has a fixed cost per iteration, iteration statistics are
20+ # not interesting here.
21+ @inline function __siam_fanl_equations_stats_mapping (method, sol)
22+ method === :pseudotransient && return nothing
23+ return SciMLBase. NLStats (sum (sol. stats. ifun), sum (sol. stats. ijac), 0 , 0 ,
24+ sum (sol. stats. iarm))
25+ end
1126
12- @unpack method, show_trace, delta, linsolve = alg
27+ function SciMLBase. __solve (prob:: NonlinearProblem , alg:: SIAMFANLEquationsJL , args... ;
28+ abstol = nothing , reltol = nothing , alias_u0:: Bool = false , maxiters = 1000 ,
29+ termination_condition = nothing , show_trace:: Val{ShT} = Val (false ),
30+ kwargs... ) where {ShT}
31+ @assert (termination_condition ===
32+ nothing )|| (termination_condition isa AbsNormTerminationMode) " SIAMFANLEquationsJL does not support termination conditions!"
33+
34+ @unpack method, delta, linsolve = alg
1335
1436 iip = SciMLBase. isinplace (prob)
15- T = eltype (prob. u0)
1637
17- atol = abstol === nothing ? real ( oneunit (T)) * ( eps ( real ( one (T)))) ^ ( 4 // 5 ) : abstol
18- rtol = reltol === nothing ? real ( oneunit (T)) * ( eps ( real ( one (T)))) ^ ( 4 // 5 ) : reltol
38+ atol = NonlinearSolve . DEFAULT_TOLERANCE ( abstol, eltype (prob . u0))
39+ rtol = NonlinearSolve . DEFAULT_TOLERANCE ( reltol, eltype (prob . u0))
1940
2041 if prob. u0 isa Number
21- f! = if iip
22- function (u)
23- du = similar (u)
24- prob. f (du, u, prob. p)
25- return du
26- end
27- else
28- u -> prob. f (u, prob. p)
29- end
42+ f = (u) -> prob. f (u, prob. p)
3043
3144 if method == :newton
32- sol = nsolsc (f! , prob. u0; maxit = maxiters, atol = atol , rtol = rtol , printerr = show_trace )
45+ sol = nsolsc (f, prob. u0; maxit = maxiters, atol, rtol, printerr = ShT )
3346 elseif method == :pseudotransient
34- sol = ptcsolsc (f!, prob. u0; delta0 = delta, maxit = maxiters, atol = atol, rtol= rtol, printerr = show_trace)
47+ sol = ptcsolsc (f, prob. u0; delta0 = delta, maxit = maxiters, atol, rtol,
48+ printerr = ShT)
3549 elseif method == :secant
36- sol = secant (f! , prob. u0; maxit = maxiters, atol = atol , rtol = rtol , printerr = show_trace )
50+ sol = secant (f, prob. u0; maxit = maxiters, atol, rtol, printerr = ShT )
3751 end
3852
39- if sol. errcode == 0
40- retcode = ReturnCode. Success
41- elseif sol. errcode == 10
42- retcode = ReturnCode. MaxIters
43- elseif sol. errcode == 1
44- retcode = ReturnCode. Failure
45- elseif sol. errcode == - 1
46- retcode = ReturnCode. Default
47- end
48- stats = method == :pseudotransient ? nothing : (SciMLBase. NLStats (sum (sol. stats. ifun), sum (sol. stats. ijac), 0 , 0 , sum (sol. stats. iarm)))
49- return SciMLBase. build_solution (prob, alg, sol. solution, sol. history; retcode, stats, original = sol)
53+ retcode = __siam_fanl_equations_retcode_mapping (sol)
54+ stats = __siam_fanl_equations_stats_mapping (method, sol)
55+ return SciMLBase. build_solution (prob, alg, sol. solution, sol. history; retcode,
56+ stats, original = sol)
5057 else
5158 u = NonlinearSolve. __maybe_unaliased (prob. u0, alias_u0)
5259 end
@@ -71,67 +78,50 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::SIAMFANLEquationsJL, arg
7178 if linsolve != = nothing
7279 # Allocate ahead for Krylov basis
7380 JVS = linsolve == :gmres ? zeros (T, N, 3 ) : zeros (T, N)
74- # `linsolve` as a Symbol to keep unified interface with other EXTs, SIAMFANLEquations directly use String to choose between different linear solvers
81+ # `linsolve` as a Symbol to keep unified interface with other EXTs,
82+ # SIAMFANLEquations directly use String to choose between different linear solvers
7583 linsolve_alg = String (linsolve)
7684
7785 if method == :newton
78- sol = nsoli (f!, u, FS, JVS; lsolver = linsolve_alg, maxit = maxiters, atol = atol, rtol = rtol, printerr = show_trace)
86+ sol = nsoli (f!, u, FS, JVS; lsolver = linsolve_alg, maxit = maxiters, atol,
87+ rtol, printerr = ShT)
7988 elseif method == :pseudotransient
80- sol = ptcsoli (f!, u, FS, JVS; lsolver = linsolve_alg, maxit = maxiters, atol = atol, rtol = rtol, printerr = show_trace)
81- end
82-
83- if sol. errcode == 0
84- retcode = ReturnCode. Success
85- elseif sol. errcode == 10
86- retcode = ReturnCode. MaxIters
87- elseif sol. errcode == 1
88- retcode = ReturnCode. Failure
89- elseif sol. errcode == - 1
90- retcode = ReturnCode. Default
89+ sol = ptcsoli (f!, u, FS, JVS; lsolver = linsolve_alg, maxit = maxiters, atol,
90+ rtol, printerr = ShT)
9191 end
92- stats = method == :pseudotransient ? nothing : (SciMLBase. NLStats (sum (sol. stats. ifun), sum (sol. stats. ijac), 0 , 0 , sum (sol. stats. iarm)))
93- return SciMLBase. build_solution (prob, alg, sol. solution, sol. history; retcode, stats, original = sol)
92+
93+ retcode = __siam_fanl_equations_retcode_mapping (sol)
94+ stats = __siam_fanl_equations_stats_mapping (method, sol)
95+ return SciMLBase. build_solution (prob, alg, sol. solution, sol. history; retcode,
96+ stats, original = sol)
9497 end
9598
9699 # Allocate ahead for Jacobian
97100 FPS = zeros (T, N, N)
98101 if prob. f. jac === nothing
99102 # Use the built-in Jacobian machinery
100103 if method == :newton
101- sol = nsol (f!, u, FS, FPS;
102- sham= 1 , atol = atol, rtol = rtol, maxit = maxiters,
103- printerr = show_trace)
104+ sol = nsol (f!, u, FS, FPS; sham = 1 , atol, rtol, maxit = maxiters,
105+ printerr = ShT)
104106 elseif method == :pseudotransient
105- sol = ptcsol (f!, u, FS, FPS;
106- atol = atol, rtol = rtol, maxit = maxiters,
107- delta0 = delta, printerr = show_trace)
107+ sol = ptcsol (f!, u, FS, FPS; atol, rtol, maxit = maxiters,
108+ delta0 = delta, printerr = ShT)
108109 end
109110 else
110111 AJ! (J, u, x) = prob. f. jac (J, x, prob. p)
111112 if method == :newton
112- sol = nsol (f!, u, FS, FPS, AJ!;
113- sham= 1 , atol = atol, rtol = rtol, maxit = maxiters,
114- printerr = show_trace)
113+ sol = nsol (f!, u, FS, FPS, AJ!; sham = 1 , atol, rtol, maxit = maxiters,
114+ printerr = ShT)
115115 elseif method == :pseudotransient
116- sol = ptcsol (f!, u, FS, FPS, AJ!;
117- atol = atol, rtol = rtol, maxit = maxiters,
118- delta0 = delta, printerr = show_trace)
116+ sol = ptcsol (f!, u, FS, FPS, AJ!; atol, rtol, maxit = maxiters,
117+ delta0 = delta, printerr = ShT)
119118 end
120119 end
121120
122- if sol. errcode == 0
123- retcode = ReturnCode. Success
124- elseif sol. errcode == 10
125- retcode = ReturnCode. MaxIters
126- elseif sol. errcode == 1
127- retcode = ReturnCode. Failure
128- elseif sol. errcode == - 1
129- retcode = ReturnCode. Default
130- end
131-
132- # pseudo transient continuation has a fixed cost per iteration, iteration statistics are not interesting here.
133- stats = method == :pseudotransient ? nothing : (SciMLBase. NLStats (sum (sol. stats. ifun), sum (sol. stats. ijac), 0 , 0 , sum (sol. stats. iarm)))
134- return SciMLBase. build_solution (prob, alg, sol. solution, sol. history; retcode, stats, original = sol)
121+ retcode = __siam_fanl_equations_retcode_mapping (sol)
122+ stats = __siam_fanl_equations_stats_mapping (method, sol)
123+ return SciMLBase. build_solution (prob, alg, sol. solution, sol. history; retcode, stats,
124+ original = sol)
135125end
136126
137- end
127+ end
0 commit comments