@@ -4,7 +4,7 @@ using ArrayInterface, NonlinearSolve, SciMLBase
44import ConcreteStructs: @concrete
55import FastClosures: @closure
66import FastLevenbergMarquardt as FastLM
7- import StaticArraysCore: StaticArray
7+ import StaticArraysCore: SArray
88
99@inline function _fast_lm_solver (:: FastLevenbergMarquardtJL{linsolve} , x) where {linsolve}
1010 if linsolve === :cholesky
@@ -15,53 +15,54 @@ import StaticArraysCore: StaticArray
1515 throw (ArgumentError (" Unknown FastLevenbergMarquardt Linear Solver: $linsolve " ))
1616 end
1717end
18+ @inline _fast_lm_solver (:: FastLevenbergMarquardtJL{linsolve} , :: SArray ) where {linsolve} = linsolve
1819
19- # TODO : Implement reinit
20- @concrete struct FastLevenbergMarquardtJLCache
21- f!
22- J!
23- prob
24- alg
25- lmworkspace
26- solver
27- kwargs
28- end
29-
30- function SciMLBase. __init (prob:: NonlinearLeastSquaresProblem ,
20+ function SciMLBase. __solve (prob:: NonlinearLeastSquaresProblem ,
3121 alg:: FastLevenbergMarquardtJL , args... ; alias_u0 = false , abstol = nothing ,
3222 reltol = nothing , maxiters = 1000 , termination_condition = nothing , kwargs... )
3323 NonlinearSolve. __test_termination_condition (termination_condition,
3424 :FastLevenbergMarquardt )
35- if prob. u0 isa StaticArray # FIXME
36- error (" FastLevenbergMarquardtJL does not support StaticArrays yet." )
37- end
3825
39- _f!, u, resid = NonlinearSolve. __construct_extension_f (prob; alias_u0)
40- f! = @closure (du, u, p) -> _f! (du, u)
26+ fn, u, resid = NonlinearSolve. __construct_extension_f (prob; alias_u0,
27+ can_handle_oop = Val (prob. u0 isa SArray))
28+ f = if prob. u0 isa SArray
29+ @closure (u, p) -> fn (u)
30+ else
31+ @closure (du, u, p) -> fn (du, u)
32+ end
4133 abstol = NonlinearSolve. DEFAULT_TOLERANCE (abstol, eltype (u))
4234 reltol = NonlinearSolve. DEFAULT_TOLERANCE (reltol, eltype (u))
4335
44- _J! = NonlinearSolve. __construct_extension_jac (prob, alg, u, resid; alg. autodiff)
45- J! = @closure (J, u, p) -> _J! (J, u)
46- J = prob. f. jac_prototype === nothing ? similar (u, length (resid), length (u)) :
47- zero (prob. f. jac_prototype)
36+ _jac_fn = NonlinearSolve. __construct_extension_jac (prob, alg, u, resid; alg. autodiff,
37+ can_handle_oop = Val (prob. u0 isa SArray))
38+ jac_fn = if prob. u0 isa SArray
39+ @closure (u, p) -> _jac_fn (u)
40+ else
41+ @closure (J, u, p) -> _jac_fn (J, u)
42+ end
4843
49- solver = _fast_lm_solver (alg, u)
50- LM = FastLM. LMWorkspace (u, resid, J)
44+ solver_kwargs = (; xtol = reltol, ftol = reltol, gtol = abstol, maxit = maxiters,
45+ alg. factor, alg. factoraccept, alg. factorreject, alg. minscale, alg. maxscale,
46+ alg. factorupdate, alg. minfactor, alg. maxfactor)
5147
52- return FastLevenbergMarquardtJLCache (f!, J!, prob, alg, LM, solver,
53- (; xtol = reltol, ftol = reltol, gtol = abstol, maxit = maxiters, alg. factor,
54- alg. factoraccept, alg. factorreject, alg. minscale, alg. maxscale,
55- alg. factorupdate, alg. minfactor, alg. maxfactor))
56- end
48+ if prob. u0 isa SArray
49+ res, fx, info, iter, nfev, njev = FastLM. lmsolve (f, jac_fn, prob. u0;
50+ solver_kwargs... )
51+ LM, solver = nothing , nothing
52+ else
53+ J = prob. f. jac_prototype === nothing ? similar (u, length (resid), length (u)) :
54+ zero (prob. f. jac_prototype)
55+ solver = _fast_lm_solver (alg, u)
56+ LM = FastLM. LMWorkspace (u, resid, J)
57+
58+ res, fx, info, iter, nfev, njev, LM, solver = FastLM. lmsolve! (f, jac_fn, LM;
59+ solver, solver_kwargs... )
60+ end
5761
58- function SciMLBase. solve! (cache:: FastLevenbergMarquardtJLCache )
59- res, fx, info, iter, nfev, njev, LM, solver = FastLM. lmsolve! (cache. f!, cache. J!,
60- cache. lmworkspace; cache. solver, cache. kwargs... )
6162 stats = SciMLBase. NLStats (nfev, njev, - 1 , - 1 , iter)
6263 retcode = info == - 1 ? ReturnCode. MaxIters : ReturnCode. Success
63- return SciMLBase. build_solution (cache . prob, cache . alg, res, fx;
64- retcode, original = (res, fx, info, iter, nfev, njev, LM, solver), stats)
64+ return SciMLBase. build_solution (prob, alg, res, fx; retcode,
65+ original = (res, fx, info, iter, nfev, njev, LM, solver), stats)
6566end
6667
6768end
0 commit comments