Skip to content
This repository was archived by the owner on May 15, 2025. It is now read-only.

Commit 9f5788e

Browse files
Merge branch 'main' into ap/termination_broyden
2 parents 2eacfa8 + f250799 commit 9f5788e

File tree

6 files changed

+133
-10
lines changed

6 files changed

+133
-10
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ authors = ["SciML"]
44
version = "0.1.12"
55

66
[deps]
7-
ArrayInterfaceCore = "30b0a656-2188-435a-8636-2ec0e6a096e2"
7+
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
88
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
99
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
1010
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
@@ -22,8 +22,8 @@ NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
2222
SimpleBatchedNonlinearSolveExt = "NNlib"
2323

2424
[compat]
25+
ArrayInterface = "6, 7"
2526
ArrayInterfaceCore = "0.1.1"
26-
DiffEqBase = "6.118.1"
2727
FiniteDiff = "2"
2828
ForwardDiff = "0.10.3"
2929
NNlib = "0.8"

ext/SimpleBatchedNonlinearSolveExt.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
module SimpleBatchedNonlinearSolveExt
22

3-
using ArrayInterfaceCore, DiffEqBase, LinearAlgebra, SimpleNonlinearSolve, SciMLBase
3+
using ArrayInterface, DiffEqBase, LinearAlgebra, SimpleNonlinearSolve, SciMLBase
4+
45
isdefined(Base, :get_extension) ? (using NNlib) : (using ..NNlib)
56

67
_batch_transpose(x) = reshape(x, 1, size(x)...)
@@ -20,7 +21,7 @@ function _batched_mul(x::AbstractArray{T1, 3}, y::AbstractArray{T2, 3}) where {T
2021
end
2122

2223
function _init_J_batched(x::AbstractMatrix{T}) where {T}
23-
J = ArrayInterfaceCore.zeromatrix(x[:, 1])
24+
J = ArrayInterface.zeromatrix(x[:, 1])
2425
if ismutable(x)
2526
J[diagind(J)] .= one(eltype(x))
2627
else

src/SimpleNonlinearSolve.jl

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ using FiniteDiff, ForwardDiff
55
using ForwardDiff: Dual
66
using StaticArraysCore
77
using LinearAlgebra
8-
import ArrayInterfaceCore
8+
import ArrayInterface
99
using DiffEqBase
1010

1111
@reexport using SciMLBase
@@ -37,12 +37,14 @@ include("ridder.jl")
3737
include("brent.jl")
3838
include("dfsane.jl")
3939
include("ad.jl")
40+
include("halley.jl")
4041

4142
import SnoopPrecompile
4243

4344
SnoopPrecompile.@precompile_all_calls begin for T in (Float32, Float64)
4445
prob_no_brack = NonlinearProblem{false}((u, p) -> u .* u .- p, T(0.1), T(2))
45-
for alg in (SimpleNewtonRaphson, Broyden, Klement, SimpleTrustRegion, SimpleDFSane)
46+
for alg in (SimpleNewtonRaphson, Halley, Broyden, Klement, SimpleTrustRegion,
47+
SimpleDFSane)
4648
solve(prob_no_brack, alg(), abstol = T(1e-2))
4749
end
4850

@@ -63,7 +65,7 @@ SnoopPrecompile.@precompile_all_calls begin for T in (Float32, Float64)
6365
end end
6466

6567
# DiffEq styled algorithms
66-
export Bisection, Brent, Broyden, LBroyden, SimpleDFSane, Falsi, Klement,
68+
export Bisection, Brent, Broyden, LBroyden, SimpleDFSane, Falsi, Halley, Klement,
6769
Ridder, SimpleNewtonRaphson, SimpleTrustRegion
6870

6971
end # module

src/halley.jl

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
"""
2+
```julia
3+
Halley(; chunk_size = Val{0}(), autodiff = Val{true}(),
4+
diff_type = Val{:forward})
5+
```
6+
7+
A low-overhead implementation of Halley's Method. This method is non-allocating on scalar
8+
and static array problems.
9+
10+
!!! note
11+
12+
As part of the decreased overhead, this method omits some of the higher level error
13+
catching of the other methods. Thus, to see better error messages, use one of the other
14+
methods like `NewtonRaphson`
15+
16+
### Keyword Arguments
17+
18+
- `chunk_size`: the chunk size used by the internal ForwardDiff.jl automatic differentiation
19+
system. This allows for multiple derivative columns to be computed simultaneously,
20+
improving performance. Defaults to `0`, which is equivalent to using ForwardDiff.jl's
21+
default chunk size mechanism. For more details, see the documentation for
22+
[ForwardDiff.jl](https://juliadiff.org/ForwardDiff.jl/stable/).
23+
- `autodiff`: whether to use forward-mode automatic differentiation for the Jacobian.
24+
Note that this argument is ignored if an analytical Jacobian is passed; as that will be
25+
used instead. Defaults to `Val{true}`, which means ForwardDiff.jl is used by default.
26+
If `Val{false}`, then FiniteDiff.jl is used for finite differencing.
27+
- `diff_type`: the type of finite differencing used if `autodiff = false`. Defaults to
28+
`Val{:forward}` for forward finite differences. For more details on the choices, see the
29+
[FiniteDiff.jl](https://github.com/JuliaDiff/FiniteDiff.jl) documentation.
30+
"""
31+
struct Halley{CS, AD, FDT} <: AbstractNewtonAlgorithm{CS, AD, FDT}
32+
function Halley(; chunk_size = Val{0}(), autodiff = Val{true}(),
33+
diff_type = Val{:forward})
34+
new{SciMLBase._unwrap_val(chunk_size), SciMLBase._unwrap_val(autodiff),
35+
SciMLBase._unwrap_val(diff_type)}()
36+
end
37+
end
38+
39+
function SciMLBase.__solve(prob::NonlinearProblem,
40+
alg::Halley, args...; abstol = nothing,
41+
reltol = nothing,
42+
maxiters = 1000, kwargs...)
43+
f = Base.Fix2(prob.f, prob.p)
44+
x = float(prob.u0)
45+
fx = f(x)
46+
# fx = float(prob.u0)
47+
if !isa(fx, Number) || !isa(x, Number)
48+
error("Halley currently only supports scalar-valued single-variable functions")
49+
end
50+
T = typeof(x)
51+
52+
if SciMLBase.isinplace(prob)
53+
error("Halley currently only supports out-of-place nonlinear problems")
54+
end
55+
56+
atol = abstol !== nothing ? abstol :
57+
real(oneunit(eltype(T))) * (eps(real(one(eltype(T)))))^(4 // 5)
58+
rtol = reltol !== nothing ? reltol : eps(real(one(eltype(T))))^(4 // 5)
59+
60+
if typeof(x) <: Number
61+
xo = oftype(one(eltype(x)), Inf)
62+
else
63+
xo = map(x -> oftype(one(eltype(x)), Inf), x)
64+
end
65+
66+
for i in 1:maxiters
67+
if alg_autodiff(alg)
68+
fx = f(x)
69+
dfdx(x) = ForwardDiff.derivative(f, x)
70+
dfx = dfdx(x)
71+
d2fx = ForwardDiff.derivative(dfdx, x)
72+
else
73+
fx = f(x)
74+
dfx = FiniteDiff.finite_difference_derivative(f, x, diff_type(alg), eltype(x),
75+
fx)
76+
d2fx = FiniteDiff.finite_difference_derivative(x -> FiniteDiff.finite_difference_derivative(f,
77+
x),
78+
x, diff_type(alg), eltype(x), fx)
79+
end
80+
iszero(fx) &&
81+
return SciMLBase.build_solution(prob, alg, x, fx; retcode = ReturnCode.Success)
82+
Δx = (2 * dfx^2 - fx * d2fx) \ (2fx * dfx)
83+
x -= Δx
84+
if isapprox(x, xo, atol = atol, rtol = rtol)
85+
return SciMLBase.build_solution(prob, alg, x, fx; retcode = ReturnCode.Success)
86+
end
87+
xo = x
88+
end
89+
90+
return SciMLBase.build_solution(prob, alg, x, fx; retcode = ReturnCode.MaxIters)
91+
end

src/utils.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ value(x::Dual) = ForwardDiff.value(x)
3535
value(x::AbstractArray{<:Dual}) = map(ForwardDiff.value, x)
3636

3737
function init_J(x)
38-
J = ArrayInterfaceCore.zeromatrix(x)
38+
J = ArrayInterface.zeromatrix(x)
3939
if ismutable(x)
4040
J[diagind(J)] .= one(eltype(x))
4141
else

test/basictests.jl

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,29 @@ if VERSION >= v"1.7"
4343
@test (@ballocated benchmark_scalar(sf, csu0)) == 0
4444
end
4545

46+
# Halley
47+
function benchmark_scalar(f, u0)
48+
probN = NonlinearProblem{false}(f, u0)
49+
sol = (solve(probN, Halley()))
50+
end
51+
52+
# function ff(u, p)
53+
# u .* u .- 2
54+
# end
55+
# const cu0 = @SVector[1.0, 1.0]
56+
function sf(u, p)
57+
u * u - 2
58+
end
59+
const csu0 = 1.0
60+
61+
sol = benchmark_scalar(sf, csu0)
62+
@test sol.retcode === ReturnCode.Success
63+
@test sol.u * sol.u - 2 < 1e-9
64+
65+
if VERSION >= v"1.7"
66+
@test (@ballocated benchmark_scalar(sf, csu0)) == 0
67+
end
68+
4669
# Broyden
4770
function benchmark_scalar(f, u0, alg)
4871
probN = NonlinearProblem{false}(f, u0)
@@ -115,7 +138,7 @@ end
115138
# Scalar
116139
f, u0 = (u, p) -> u * u - p, 1.0
117140
for alg in (SimpleNewtonRaphson(), LBroyden(), Klement(), SimpleTrustRegion(),
118-
SimpleDFSane(), BROYDEN_SOLVERS...)
141+
SimpleDFSane(), Halley(), BROYDEN_SOLVERS...)
119142
g = function (p)
120143
probN = NonlinearProblem{false}(f, oftype(p, u0), p)
121144
sol = solve(probN, alg)
@@ -181,7 +204,7 @@ for alg in [Bisection(), Falsi(), Ridder(), Brent()]
181204
end
182205

183206
for alg in (SimpleNewtonRaphson(), LBroyden(), Klement(), SimpleTrustRegion(),
184-
SimpleDFSane(), BROYDEN_SOLVERS...)
207+
SimpleDFSane(), Halley(), BROYDEN_SOLVERS...)
185208
global g, p
186209
g = function (p)
187210
probN = NonlinearProblem{false}(f, 0.5, p)
@@ -206,6 +229,12 @@ for alg in (SimpleNewtonRaphson(), SimpleNewtonRaphson(; autodiff = false),
206229
@test sol.u[end] sqrt(2.0)
207230
end
208231

232+
# Separate Error check for Halley; will be included in above error checks for the improved Halley
233+
f, u0 = (u, p) -> u * u - 2.0, 1.0
234+
probN = NonlinearProblem(f, u0)
235+
236+
@test solve(probN, Halley()).u sqrt(2.0)
237+
209238
for u0 in [1.0, [1, 1.0]]
210239
local f, probN, sol
211240
f = (u, p) -> u .* u .- 2.0

0 commit comments

Comments
 (0)