@@ -2,6 +2,7 @@ module NonlinearSolveMINPACKExt
22
33using NonlinearSolve, DiffEqBase, SciMLBase
44using MINPACK
5+ import FastClosures: @closure
56
67function SciMLBase. __solve (prob:: Union {NonlinearProblem{uType, iip},
78 NonlinearLeastSquaresProblem{uType, iip}}, alg:: CMINPACK , args... ;
@@ -11,80 +12,42 @@ function SciMLBase.__solve(prob::Union{NonlinearProblem{uType, iip},
1112 @assert (termination_condition ===
1213 nothing )|| (termination_condition isa AbsNormTerminationMode) " CMINPACK does not support termination conditions!"
1314
14- if prob. u0 isa Number
15- u0 = [prob. u0]
16- else
17- u0 = NonlinearSolve. __maybe_unaliased (prob. u0, alias_u0)
18- end
15+ f!_, u0 = NonlinearSolve. __construct_f (prob; alias_u0)
16+ f! = @closure (du, u) -> (f!_ (du, u); Cint (0 ))
1917
20- sizeu = size (prob. u0)
21- p = prob. p
22-
23- # unwrapping alg params
24- show_trace = alg. show_trace || ShT
25- tracing = alg. tracing || StT
26-
27- if ! iip && prob. u0 isa Number
28- f! = (du, u) -> (du .= prob. f (first (u), p); Cint (0 ))
29- elseif ! iip && prob. u0 isa AbstractVector
30- f! = (du, u) -> (du .= prob. f (u, p); Cint (0 ))
31- elseif ! iip && prob. u0 isa AbstractArray
32- f! = (du, u) -> (du .= vec (prob. f (reshape (u, sizeu), p)); Cint (0 ))
33- elseif prob. u0 isa AbstractVector
34- f! = (du, u) -> prob. f (du, u, p)
35- else # Then it's an in-place function on an abstract array
36- f! = (du, u) -> (prob. f (reshape (du, sizeu), reshape (u, sizeu), p); du = vec (du); 0 )
37- end
38-
39- u = zero (u0)
40- resid = NonlinearSolve. evaluate_f (prob, u)
18+ resid = NonlinearSolve. evaluate_f (prob, prob. u0)
4119 m = length (resid)
42- size_jac = (length (resid), length (u))
4320
4421 method = ifelse (alg. method === :auto ,
4522 ifelse (prob isa NonlinearLeastSquaresProblem, :lm , :hybr ), alg. method)
4623
47- abstol = NonlinearSolve. DEFAULT_TOLERANCE (abstol, eltype (u))
24+ show_trace = alg. show_trace || ShT
25+ tracing = alg. tracing || StT
26+ tol = NonlinearSolve. DEFAULT_TOLERANCE (abstol, eltype (u0))
27+
28+ jac!_ = NonlinearSolve. __construct_jac (prob, alg, u0)
4829
49- if SciMLBase. has_jac (prob. f)
50- if ! iip && prob. u0 isa Number
51- g! = (du, u) -> (du .= prob. f. jac (first (u), p); Cint (0 ))
52- elseif ! iip && prob. u0 isa AbstractVector
53- g! = (du, u) -> (du .= prob. f. jac (u, p); Cint (0 ))
54- elseif ! iip && prob. u0 isa AbstractArray
55- g! = (du, u) -> (du .= vec (prob. f. jac (reshape (u, sizeu), p)); Cint (0 ))
56- elseif prob. u0 isa AbstractVector
57- g! = (du, u) -> prob. f. jac (du, u, p)
58- else # Then it's an in-place function on an abstract array
59- g! = function (du, u)
60- prob. f. jac (reshape (du, size_jac), reshape (u, sizeu), p)
61- return Cint (0 )
62- end
63- end
64- original = MINPACK. fsolve (f!, g!, vec (u0), m; tol = abstol, show_trace, tracing,
65- method, iterations = maxiters)
30+ if jac!_ === nothing
31+ original = MINPACK. fsolve (f!, u0, m; tol, show_trace, tracing, method,
32+ iterations = maxiters)
6633 else
67- original = MINPACK. fsolve (f!, vec (u0), m; tol = abstol, show_trace, tracing,
68- method, iterations = maxiters)
34+ jac! = @closure ((J, u) -> (jac!_ (J, u); Cint (0 )))
35+ original = MINPACK. fsolve (f!, jac!, u0, m; tol, show_trace, tracing, method,
36+ iterations = maxiters)
6937 end
7038
71- u = reshape (original. x, size (u))
72- resid = original. f
73- # retcode = original.converged ? ReturnCode.Success : ReturnCode.Failure
74- # MINPACK lies about convergence? or maybe uses some other criteria?
75- # We just check for absolute tolerance on the residual
76- objective = maximum (abs, resid)
77- retcode = ifelse (objective ≤ abstol, ReturnCode. Success, ReturnCode. Failure)
39+ u = original. x
40+ resid_ = original. f
41+ objective = maximum (abs, resid_)
42+ retcode = ifelse (objective ≤ tol, ReturnCode. Success, ReturnCode. Failure)
7843
79- # These are only meaningful if `tracing = true`
44+ # These are only meaningful if `store_trace = Val( true) `
8045 stats = SciMLBase. NLStats (original. trace. f_calls, original. trace. g_calls,
8146 original. trace. g_calls, original. trace. g_calls, - 1 )
8247
83- if prob. u0 isa Number
84- return SciMLBase. build_solution (prob, alg, u[1 ], resid[1 ]; stats, retcode, original)
85- else
86- return SciMLBase. build_solution (prob, alg, u, resid; stats, retcode, original)
87- end
48+ u_ = prob. u0 isa Number ? original. x[1 ] : reshape (original. x, size (prob. u0))
49+ resid_ = prob. u0 isa Number ? resid_[1 ] : reshape (resid_, size (resid))
50+ return SciMLBase. build_solution (prob, alg, u_, resid_; retcode, original, stats)
8851end
8952
9053end
0 commit comments