|
1 | 1 | module NonlinearSolveMINPACKExt |
2 | 2 |
|
3 | | -using NonlinearSolve, SciMLBase |
| 3 | +using NonlinearSolve, DiffEqBase, SciMLBase |
4 | 4 | using MINPACK |
5 | 5 |
|
6 | 6 | function SciMLBase.__solve(prob::Union{NonlinearProblem{uType, iip}, |
7 | 7 | NonlinearLeastSquaresProblem{uType, iip}}, alg::CMINPACK, args...; |
8 | 8 | abstol = 1e-6, maxiters = 100000, alias_u0::Bool = false, |
9 | 9 | termination_condition = nothing, kwargs...) where {uType, iip} |
10 | | - @assert termination_condition===nothing "CMINPACK does not support termination conditions!" |
| 10 | + @assert (termination_condition === |
| 11 | + nothing)||(termination_condition isa AbsNormTerminationMode) "CMINPACK does not support termination conditions!" |
11 | 12 |
|
12 | 13 | if prob.u0 isa Number |
13 | 14 | u0 = [prob.u0] |
@@ -57,22 +58,26 @@ function SciMLBase.__solve(prob::Union{NonlinearProblem{uType, iip}, |
57 | 58 | return Cint(0) |
58 | 59 | end |
59 | 60 | end |
60 | | - original = MINPACK.fsolve(f!, g!, u0, m; tol = abstol, show_trace, tracing, method, |
61 | | - iterations = maxiters, kwargs...) |
| 61 | + original = MINPACK.fsolve(f!, g!, vec(u0), m; tol = abstol, show_trace, tracing, |
| 62 | + method, iterations = maxiters, kwargs...) |
62 | 63 | else |
63 | | - original = MINPACK.fsolve(f!, u0, m; tol = abstol, show_trace, tracing, method, |
64 | | - iterations = maxiters, kwargs...) |
| 64 | + original = MINPACK.fsolve(f!, vec(u0), m; tol = abstol, show_trace, tracing, |
| 65 | + method, iterations = maxiters, kwargs...) |
65 | 66 | end |
66 | 67 |
|
67 | 68 | u = reshape(original.x, size(u)) |
68 | 69 | resid = original.f |
69 | 70 | # retcode = original.converged ? ReturnCode.Success : ReturnCode.Failure |
70 | 71 | # MINPACK lies about convergence? or maybe uses some other criteria? |
71 | 72 | # We just check for absolute tolerance on the residual |
72 | | - objective = NonlinearSolve.DEFAULT_NORM(resid) |
| 73 | + objective = maximum(abs, resid) |
73 | 74 | retcode = ifelse(objective ≤ abstol, ReturnCode.Success, ReturnCode.Failure) |
74 | 75 |
|
75 | | - return SciMLBase.build_solution(prob, alg, u, resid; retcode, original) |
| 76 | + # These are only meaningful if `tracing = true` |
| 77 | + stats = SciMLBase.NLStats(original.trace.f_calls, original.trace.g_calls, |
| 78 | + original.trace.g_calls, original.trace.g_calls, -1) |
| 79 | + |
| 80 | + return SciMLBase.build_solution(prob, alg, u, resid; stats, retcode, original) |
76 | 81 | end |
77 | 82 |
|
78 | 83 | end |
0 commit comments