@@ -43,9 +43,8 @@ function SciMLBase.__solve(prob::NonlinearProblem,
4343 f = Base. Fix2 (prob. f, prob. p)
4444 x = float (prob. u0)
4545 fx = f (x)
46- # fx = float(prob.u0)
47- if ! isa (fx, Number) || ! isa (x, Number)
48- error (" Halley currently only supports scalar-valued single-variable functions" )
46+ if isa (x, AbstractArray)
47+ n = length (x)
4948 end
5049 T = typeof (x)
5150
@@ -65,22 +64,45 @@ function SciMLBase.__solve(prob::NonlinearProblem,
6564
6665 for i in 1 : maxiters
6766 if alg_autodiff (alg)
68- fx = f (x)
69- dfdx (x) = ForwardDiff. derivative (f, x)
70- dfx = dfdx (x)
71- d2fx = ForwardDiff. derivative (dfdx, x)
67+ if isa (x, Number)
68+ fx = f (x)
69+ dfx = ForwardDiff. derivative (f, x)
70+ d2fx = ForwardDiff. derivative (x -> ForwardDiff. derivative (f, x), x)
71+ else
72+ fx = f (x)
73+ dfx = ForwardDiff. jacobian (f, x)
74+ d2fx = ForwardDiff. jacobian (x -> ForwardDiff. jacobian (f, x), x)
75+ ai = - (dfx \ fx)
76+ A = reshape (d2fx * ai, (n, n))
77+ bi = (dfx) \ (A * ai)
78+ ci = (ai .* ai) ./ (ai .+ (0.5 .* bi))
79+ end
7280 else
73- fx = f (x)
74- dfx = FiniteDiff. finite_difference_derivative (f, x, diff_type (alg), eltype (x),
75- fx)
76- d2fx = FiniteDiff. finite_difference_derivative (x -> FiniteDiff. finite_difference_derivative (f,
77- x),
78- x, diff_type (alg), eltype (x), fx)
81+ if isa (x, Number)
82+ fx = f (x)
83+ dfx = FiniteDiff. finite_difference_derivative (f, x, diff_type (alg), eltype (x))
84+ d2fx = FiniteDiff. finite_difference_derivative (x -> FiniteDiff. finite_difference_derivative (f, x), x,
85+ diff_type (alg), eltype (x))
86+ else
87+ fx = f (x)
88+ dfx = FiniteDiff. finite_difference_jacobian (f, x, diff_type (alg), eltype (x))
89+ d2fx = FiniteDiff. finite_difference_jacobian (x -> FiniteDiff. finite_difference_jacobian (f, x), x,
90+ diff_type (alg), eltype (x))
91+ ai = - (dfx \ fx)
92+ A = reshape (d2fx * ai, (n, n))
93+ bi = (dfx) \ (A * ai)
94+ ci = (ai .* ai) ./ (ai .+ (0.5 .* bi))
95+ end
7996 end
8097 iszero (fx) &&
8198 return SciMLBase. build_solution (prob, alg, x, fx; retcode = ReturnCode. Success)
82- Δx = (2 * dfx^ 2 - fx * d2fx) \ (2 fx * dfx)
83- x -= Δx
99+ if isa (x, Number)
100+ Δx = (2 * dfx^ 2 - fx * d2fx) \ (2 fx * dfx)
101+ x -= Δx
102+ else
103+ Δx = ci
104+ x += Δx
105+ end
84106 if isapprox (x, xo, atol = atol, rtol = rtol)
85107 return SciMLBase. build_solution (prob, alg, x, fx; retcode = ReturnCode. Success)
86108 end
0 commit comments