@@ -5,7 +5,7 @@ import ConcreteStructs: @concrete
55import FastLevenbergMarquardt as FastLM
66import FiniteDiff, ForwardDiff
77
8- function _fast_lm_solver (:: FastLevenbergMarquardtJL{linsolve} , x) where {linsolve}
8+ @inline function _fast_lm_solver (:: FastLevenbergMarquardtJL{linsolve} , x) where {linsolve}
99 if linsolve === :cholesky
1010 return FastLM. CholeskySolver (ArrayInterface. undefmatrix (x))
1111 elseif linsolve === :qr
@@ -15,6 +15,7 @@ function _fast_lm_solver(::FastLevenbergMarquardtJL{linsolve}, x) where {linsolv
1515 end
1616end
1717
18+ # TODO : Implement reinit
1819@concrete struct FastLevenbergMarquardtJLCache
1920 f!
2021 J!
2526 kwargs
2627end
2728
28- @concrete struct InplaceFunction{iip} <: Function
29- f
30- end
31-
32- (f:: InplaceFunction{true} )(fx, x, p) = f. f (fx, x, p)
33- (f:: InplaceFunction{false} )(fx, x, p) = (fx .= f. f (x, p))
34-
3529function SciMLBase. __init (prob:: NonlinearLeastSquaresProblem ,
36- alg:: FastLevenbergMarquardtJL , args... ; alias_u0 = false , abstol = 1e-8 ,
37- reltol = 1e-8 , maxiters = 1000 , kwargs... )
30+ alg:: FastLevenbergMarquardtJL , args... ; alias_u0 = false , abstol = nothing ,
31+ reltol = nothing , maxiters = 1000 , kwargs... )
32+ # FIXME : Support scalar u0
33+ prob. u0 isa Number &&
34+ throw (ArgumentError (" FastLevenbergMarquardtJL does not support scalar `u0`" ))
3835 iip = SciMLBase. isinplace (prob)
3936 u = NonlinearSolve. __maybe_unaliased (prob. u0, alias_u0)
4037 fu = NonlinearSolve. evaluate_f (prob, u)
4138
42- f! = InplaceFunction {iip} (prob. f)
39+ f! = NonlinearSolve. __make_inplace {iip} (prob. f, nothing )
40+
41+ abstol = NonlinearSolve. DEFAULT_TOLERANCE (abstol, eltype (u))
42+ reltol = NonlinearSolve. DEFAULT_TOLERANCE (reltol, eltype (u))
4343
4444 if prob. f. jac === nothing
45- use_forward_diff = if alg. autodiff === nothing
46- ForwardDiff. can_dual (eltype (u))
47- else
48- alg. autodiff isa AutoForwardDiff
49- end
50- uf = SciMLBase. JacobianWrapper {iip} (prob. f, prob. p)
51- if use_forward_diff
52- cache = iip ? ForwardDiff. JacobianConfig (uf, fu, u) :
53- ForwardDiff. JacobianConfig (uf, u)
54- else
55- cache = FiniteDiff. JacobianCache (u, fu)
56- end
57- J! = if iip
58- if use_forward_diff
59- fu_cache = similar (fu)
60- function (J, x, p)
61- uf. p = p
62- ForwardDiff. jacobian! (J, uf, fu_cache, x, cache)
63- return J
64- end
65- else
66- function (J, x, p)
67- uf. p = p
68- FiniteDiff. finite_difference_jacobian! (J, uf, x, cache)
69- return J
70- end
71- end
72- else
73- if use_forward_diff
74- function (J, x, p)
75- uf. p = p
76- ForwardDiff. jacobian! (J, uf, x, cache)
77- return J
78- end
79- else
80- function (J, x, p)
81- uf. p = p
82- J_ = FiniteDiff. finite_difference_jacobian (uf, x, cache)
83- copyto! (J, J_)
84- return J
85- end
86- end
87- end
45+ alg = NonlinearSolve. get_concrete_algorithm (alg, prob)
46+ J! = NonlinearSolve. __construct_jac (prob, alg, u;
47+ can_handle_arbitrary_dims = Val (true ))
8848 else
89- J! = InplaceFunction {iip} (prob. f. jac)
49+ J! = NonlinearSolve . __make_inplace {iip} (prob. f. jac, nothing )
9050 end
9151
9252 J = similar (u, length (fu), length (u))
@@ -95,17 +55,16 @@ function SciMLBase.__init(prob::NonlinearLeastSquaresProblem,
9555 LM = FastLM. LMWorkspace (u, fu, J)
9656
9757 return FastLevenbergMarquardtJLCache (f!, J!, prob, alg, LM, solver,
98- (; xtol = abstol , ftol = reltol, maxit = maxiters, alg. factor, alg . factoraccept ,
99- alg. factorreject , alg. minscale , alg. maxscale , alg. factorupdate, alg . minfactor ,
100- alg. maxfactor, kwargs ... ))
58+ (; xtol = reltol , ftol = reltol, gtol = abstol, maxit = maxiters, alg. factor,
59+ alg. factoraccept , alg. factorreject , alg. minscale , alg. maxscale ,
60+ alg. factorupdate, alg . minfactor, alg . maxfactor ))
10161end
10262
10363function SciMLBase. solve! (cache:: FastLevenbergMarquardtJLCache )
10464 res, fx, info, iter, nfev, njev, LM, solver = FastLM. lmsolve! (cache. f!, cache. J!,
10565 cache. lmworkspace, cache. prob. p; cache. solver, cache. kwargs... )
10666 stats = SciMLBase. NLStats (nfev, njev, - 1 , - 1 , iter)
107- retcode = info == 1 ? ReturnCode. Success :
108- (info == - 1 ? ReturnCode. MaxIters : ReturnCode. Default)
67+ retcode = info == - 1 ? ReturnCode. MaxIters : ReturnCode. Success
10968 return SciMLBase. build_solution (cache. prob, cache. alg, res, fx;
11069 retcode, original = (res, fx, info, iter, nfev, njev, LM, solver), stats)
11170end
0 commit comments