1+ module NonlinearSolveSIAMFANLEquationsExt
2+
3+ using NonlinearSolve, SciMLBase
4+ using SIAMFANLEquations
5+ import ConcreteStructs: @concrete
6+ import UnPack: @unpack
7+
8+ function SciMLBase. __solve (prob:: NonlinearProblem , alg:: SIAMFANLEquationsJL , args... ; abstol = nothing ,
9+ reltol = nothing , alias_u0:: Bool = false , maxiters = 1000 , termination_condition = nothing , kwargs... )
10+ @assert (termination_condition === nothing ) || (termination_condition isa AbsNormTerminationMode) " SIAMFANLEquationsJL does not support termination conditions!"
11+
12+ @unpack method, show_trace, delta, linsolve = alg
13+
14+ iip = SciMLBase. isinplace (prob)
15+ T = eltype (prob. u0)
16+
17+ atol = abstol === nothing ? real (oneunit (T)) * (eps (real (one (T))))^ (4 // 5 ) : abstol
18+ rtol = reltol === nothing ? real (oneunit (T)) * (eps (real (one (T))))^ (4 // 5 ) : reltol
19+
20+ if prob. u0 isa Number
21+ f! = if iip
22+ function (u)
23+ du = similar (u)
24+ prob. f (du, u, prob. p)
25+ return du
26+ end
27+ else
28+ u -> prob. f (u, prob. p)
29+ end
30+
31+ if method == :newton
32+ sol = nsolsc (f!, prob. u0; maxit = maxiters, atol = atol, rtol = rtol, printerr = show_trace)
33+ elseif method == :pseudotransient
34+ sol = ptcsolsc (f!, prob. u0; delta0 = delta, maxit = maxiters, atol = atol, rtol= rtol, printerr = show_trace)
35+ elseif method == :secant
36+ sol = secant (f!, prob. u0; maxit = maxiters, atol = atol, rtol = rtol, printerr = show_trace)
37+ end
38+
39+ if sol. errcode == 0
40+ retcode = ReturnCode. Success
41+ elseif sol. errcode == 10
42+ retcode = ReturnCode. MaxIters
43+ elseif sol. errcode == 1
44+ retcode = ReturnCode. Failure
45+ elseif sol. errcode == - 1
46+ retcode = ReturnCode. Default
47+ end
48+ stats = method == :pseudotransient ? nothing : (SciMLBase. NLStats (sum (sol. stats. ifun), sum (sol. stats. ijac), 0 , 0 , sum (sol. stats. iarm)))
49+ return SciMLBase. build_solution (prob, alg, sol. solution, sol. history; retcode, stats, original = sol)
50+ else
51+ u = NonlinearSolve. __maybe_unaliased (prob. u0, alias_u0)
52+ end
53+
54+ if iip
55+ f! = function (du, u)
56+ prob. f (du, u, prob. p)
57+ return du
58+ end
59+ else
60+ f! = function (du, u)
61+ du .= prob. f (u, prob. p)
62+ return du
63+ end
64+ end
65+
66+ # Allocate ahead for function
67+ N = length (u)
68+ FS = zeros (T, N)
69+
70+ # Jacobian free Newton Krylov
71+ if linsolve != = nothing
72+ # Allocate ahead for Krylov basis
73+ JVS = linsolve == :gmres ? zeros (T, N, 3 ) : zeros (T, N)
74+ # `linsolve` as a Symbol to keep unified interface with other EXTs, SIAMFANLEquations directly use String to choose between different linear solvers
75+ linsolve_alg = String (linsolve)
76+
77+ if method == :newton
78+ sol = nsoli (f!, u, FS, JVS; lsolver = linsolve_alg, maxit = maxiters, atol = atol, rtol = rtol, printerr = show_trace)
79+ elseif method == :pseudotransient
80+ sol = ptcsoli (f!, u, FS, JVS; lsolver = linsolve_alg, maxit = maxiters, atol = atol, rtol = rtol, printerr = show_trace)
81+ end
82+
83+ if sol. errcode == 0
84+ retcode = ReturnCode. Success
85+ elseif sol. errcode == 10
86+ retcode = ReturnCode. MaxIters
87+ elseif sol. errcode == 1
88+ retcode = ReturnCode. Failure
89+ elseif sol. errcode == - 1
90+ retcode = ReturnCode. Default
91+ end
92+ stats = method == :pseudotransient ? nothing : (SciMLBase. NLStats (sum (sol. stats. ifun), sum (sol. stats. ijac), 0 , 0 , sum (sol. stats. iarm)))
93+ return SciMLBase. build_solution (prob, alg, sol. solution, sol. history; retcode, stats, original = sol)
94+ end
95+
96+ # Allocate ahead for Jacobian
97+ FPS = zeros (T, N, N)
98+ if prob. f. jac === nothing
99+ # Use the built-in Jacobian machinery
100+ if method == :newton
101+ sol = nsol (f!, u, FS, FPS;
102+ sham= 1 , atol = atol, rtol = rtol, maxit = maxiters,
103+ printerr = show_trace)
104+ elseif method == :pseudotransient
105+ sol = ptcsol (f!, u, FS, FPS;
106+ atol = atol, rtol = rtol, maxit = maxiters,
107+ delta0 = delta, printerr = show_trace)
108+ end
109+ else
110+ AJ! (J, u, x) = prob. f. jac (J, x, prob. p)
111+ if method == :newton
112+ sol = nsol (f!, u, FS, FPS, AJ!;
113+ sham= 1 , atol = atol, rtol = rtol, maxit = maxiters,
114+ printerr = show_trace)
115+ elseif method == :pseudotransient
116+ sol = ptcsol (f!, u, FS, FPS, AJ!;
117+ atol = atol, rtol = rtol, maxit = maxiters,
118+ delta0 = delta, printerr = show_trace)
119+ end
120+ end
121+
122+ if sol. errcode == 0
123+ retcode = ReturnCode. Success
124+ elseif sol. errcode == 10
125+ retcode = ReturnCode. MaxIters
126+ elseif sol. errcode == 1
127+ retcode = ReturnCode. Failure
128+ elseif sol. errcode == - 1
129+ retcode = ReturnCode. Default
130+ end
131+
132+ # pseudo transient continuation has a fixed cost per iteration, iteration statistics are not interesting here.
133+ stats = method == :pseudotransient ? nothing : (SciMLBase. NLStats (sum (sol. stats. ifun), sum (sol. stats. ijac), 0 , 0 , sum (sol. stats. iarm)))
134+ return SciMLBase. build_solution (prob, alg, sol. solution, sol. history; retcode, stats, original = sol)
135+ end
136+
137+ end
0 commit comments