Skip to content
This repository was archived by the owner on May 15, 2025. It is now read-only.

Commit 2eacfa8

Browse files
committed
Add Termination Conditions to Broyden
1 parent ecbe9d1 commit 2eacfa8

File tree

4 files changed

+107
-57
lines changed

4 files changed

+107
-57
lines changed

Project.toml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "SimpleNonlinearSolve"
22
uuid = "727e6d20-b764-4bd8-a329-72de5adea6c7"
33
authors = ["SciML"]
4-
version = "0.1.11"
4+
version = "0.1.12"
55

66
[deps]
77
ArrayInterfaceCore = "30b0a656-2188-435a-8636-2ec0e6a096e2"
@@ -23,7 +23,7 @@ SimpleBatchedNonlinearSolveExt = "NNlib"
2323

2424
[compat]
2525
ArrayInterfaceCore = "0.1.1"
26-
DiffEqBase = "6.114"
26+
DiffEqBase = "6.118.1"
2727
FiniteDiff = "2"
2828
ForwardDiff = "0.10.3"
2929
NNlib = "0.8"
@@ -43,4 +43,4 @@ StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
4343
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
4444

4545
[targets]
46-
test = ["BenchmarkTools", "SafeTestsets", "Pkg", "Test", "StaticArrays", "NNlib"]
46+
test = ["BenchmarkTools", "SafeTestsets", "Pkg", "Test", "StaticArrays", "NNlib", "DiffEqBase"]

ext/SimpleBatchedNonlinearSolveExt.jl

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
module SimpleBatchedNonlinearSolveExt
22

3-
using ArrayInterfaceCore, LinearAlgebra, SimpleNonlinearSolve, SciMLBase
3+
using ArrayInterfaceCore, DiffEqBase, LinearAlgebra, SimpleNonlinearSolve, SciMLBase
44
isdefined(Base, :get_extension) ? (using NNlib) : (using ..NNlib)
55

66
_batch_transpose(x) = reshape(x, 1, size(x)...)
@@ -31,6 +31,8 @@ end
3131

3232
function SciMLBase.__solve(prob::NonlinearProblem, alg::Broyden{true}, args...;
3333
abstol = nothing, reltol = nothing, maxiters = 1000, kwargs...)
34+
tc = alg.termination_condition
35+
mode = DiffEqBase.get_termination_mode(tc)
3436
f = Base.Fix2(prob.f, prob.p)
3537
x = float(prob.u0)
3638

@@ -47,8 +49,17 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::Broyden{true}, args...;
4749
end
4850

4951
atol = abstol !== nothing ? abstol :
50-
real(oneunit(eltype(T))) * (eps(real(one(eltype(T)))))^(4 // 5)
51-
rtol = reltol !== nothing ? reltol : eps(real(one(eltype(T))))^(4 // 5)
52+
(tc.abstol !== nothing ? tc.abstol :
53+
real(oneunit(eltype(T))) * (eps(real(one(eltype(T)))))^(4 // 5))
54+
rtol = reltol !== nothing ? reltol :
55+
(tc.reltol !== nothing ? tc.reltol : eps(real(one(eltype(T))))^(4 // 5))
56+
57+
if mode DiffEqBase.SAFE_BEST_TERMINATION_MODES
58+
error("Broyden currently doesn't support SAFE_BEST termination modes")
59+
end
60+
61+
storage = mode DiffEqBase.SAFE_TERMINATION_MODES ? Dict() : nothing
62+
termination_condition = tc(storage)
5263

5364
xₙ = x
5465
xₙ₋₁ = x
@@ -63,14 +74,10 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::Broyden{true}, args...;
6374
(_batched_mul(_batch_transpose(Δxₙ), J⁻¹Δfₙ) .+ T(1e-5))),
6475
_batched_mul(_batch_transpose(Δxₙ), J⁻¹))
6576

66-
iszero(fₙ) &&
67-
return SciMLBase.build_solution(prob, alg, xₙ, fₙ;
68-
retcode = ReturnCode.Success)
69-
70-
if isapprox(xₙ, xₙ₋₁, atol = atol, rtol = rtol)
71-
return SciMLBase.build_solution(prob, alg, xₙ, fₙ;
72-
retcode = ReturnCode.Success)
77+
if termination_condition(fₙ, xₙ, xₙ₋₁, atol, rtol)
78+
return SciMLBase.build_solution(prob, alg, xₙ, fₙ; retcode = ReturnCode.Success)
7379
end
80+
7481
xₙ₋₁ = xₙ
7582
fₙ₋₁ = fₙ
7683
end

src/broyden.jl

Lines changed: 29 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
"""
2-
Broyden(; batched = false)
2+
Broyden(; batched = false,
3+
termination_condition = NLSolveTerminationCondition(NLSolveTerminationMode.NLSolveDefault;
4+
abstol = nothing, reltol = nothing))
35
46
A low-overhead implementation of Broyden. This method is non-allocating on scalar
57
and static array problems.
@@ -9,12 +11,22 @@ and static array problems.
911
To use the `batched` version, remember to load `NNlib`, i.e., `using NNlib` or
1012
`import NNlib` must be present in your code.
1113
"""
12-
struct Broyden{batched} <: AbstractSimpleNonlinearSolveAlgorithm
13-
Broyden(; batched = false) = new{batched}()
14+
struct Broyden{batched, TC <: NLSolveTerminationCondition} <:
15+
AbstractSimpleNonlinearSolveAlgorithm
16+
termination_condition::TC
17+
18+
function Broyden(; batched = false,
19+
termination_condition = NLSolveTerminationCondition(NLSolveTerminationMode.NLSolveDefault;
20+
abstol = nothing,
21+
reltol = nothing))
22+
return new{batched, typeof(termination_condition)}(termination_condition)
23+
end
1424
end
1525

1626
function SciMLBase.__solve(prob::NonlinearProblem, alg::Broyden{false}, args...;
1727
abstol = nothing, reltol = nothing, maxiters = 1000, kwargs...)
28+
tc = alg.termination_condition
29+
mode = DiffEqBase.get_termination_mode(tc)
1830
f = Base.Fix2(prob.f, prob.p)
1931
x = float(prob.u0)
2032

@@ -27,8 +39,17 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::Broyden{false}, args...;
2739
end
2840

2941
atol = abstol !== nothing ? abstol :
30-
real(oneunit(eltype(T))) * (eps(real(one(eltype(T)))))^(4 // 5)
31-
rtol = reltol !== nothing ? reltol : eps(real(one(eltype(T))))^(4 // 5)
42+
(tc.abstol !== nothing ? tc.abstol :
43+
real(oneunit(eltype(T))) * (eps(real(one(eltype(T)))))^(4 // 5))
44+
rtol = reltol !== nothing ? reltol :
45+
(tc.reltol !== nothing ? tc.reltol : eps(real(one(eltype(T))))^(4 // 5))
46+
47+
if mode DiffEqBase.SAFE_BEST_TERMINATION_MODES
48+
error("Broyden currently doesn't support SAFE_BEST termination modes")
49+
end
50+
51+
storage = mode DiffEqBase.SAFE_TERMINATION_MODES ? Dict() : nothing
52+
termination_condition = tc(storage)
3253

3354
xₙ = x
3455
xₙ₋₁ = x
@@ -41,14 +62,10 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::Broyden{false}, args...;
4162
J⁻¹Δfₙ = J⁻¹ * Δfₙ
4263
J⁻¹ += ((Δxₙ .- J⁻¹Δfₙ) ./ (Δxₙ' * J⁻¹Δfₙ)) * (Δxₙ' * J⁻¹)
4364

44-
iszero(fₙ) &&
45-
return SciMLBase.build_solution(prob, alg, xₙ, fₙ;
46-
retcode = ReturnCode.Success)
47-
48-
if isapprox(xₙ, xₙ₋₁, atol = atol, rtol = rtol)
49-
return SciMLBase.build_solution(prob, alg, xₙ, fₙ;
50-
retcode = ReturnCode.Success)
65+
if termination_condition(fₙ, xₙ, xₙ₋₁, atol, rtol)
66+
return SciMLBase.build_solution(prob, alg, xₙ, fₙ; retcode = ReturnCode.Success)
5167
end
68+
5269
xₙ₋₁ = xₙ
5370
fₙ₋₁ = fₙ
5471
end

test/basictests.jl

Lines changed: 58 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,25 @@
11
using SimpleNonlinearSolve
22
using StaticArrays
33
using BenchmarkTools
4+
using DiffEqBase
45
using Test
56

7+
const BATCHED_BROYDEN_SOLVERS = Broyden[]
8+
const BROYDEN_SOLVERS = Broyden[]
9+
10+
for mode in instances(NLSolveTerminationMode.T)
11+
if mode
12+
(NLSolveTerminationMode.SteadyStateDefault, NLSolveTerminationMode.RelSafeBest,
13+
NLSolveTerminationMode.AbsSafeBest)
14+
continue
15+
end
16+
17+
termination_condition = NLSolveTerminationCondition(mode; abstol = nothing,
18+
reltol = nothing)
19+
push!(BROYDEN_SOLVERS, Broyden(; batched = false, termination_condition))
20+
push!(BATCHED_BROYDEN_SOLVERS, Broyden(; batched = true, termination_condition))
21+
end
22+
623
# SimpleNewtonRaphson
724
function benchmark_scalar(f, u0)
825
probN = NonlinearProblem{false}(f, u0)
@@ -27,16 +44,19 @@ if VERSION >= v"1.7"
2744
end
2845

2946
# Broyden
30-
function benchmark_scalar(f, u0)
47+
function benchmark_scalar(f, u0, alg)
3148
probN = NonlinearProblem{false}(f, u0)
32-
sol = (solve(probN, Broyden()))
49+
sol = (solve(probN, alg))
3350
end
3451

35-
sol = benchmark_scalar(sf, csu0)
36-
@test sol.retcode === ReturnCode.Success
37-
@test sol.u * sol.u - 2 < 1e-9
38-
if VERSION >= v"1.7"
39-
@test (@ballocated benchmark_scalar(sf, csu0)) == 0
52+
for alg in BROYDEN_SOLVERS
53+
sol = benchmark_scalar(sf, csu0, alg)
54+
@test sol.retcode === ReturnCode.Success
55+
@test sol.u * sol.u - 2 < 1e-9
56+
# FIXME: Termination Condition Implementation is allocating. Not sure how to fix it.
57+
# if VERSION >= v"1.7"
58+
# @test (@ballocated benchmark_scalar($sf, $csu0, $termination_condition)) == 0
59+
# end
4060
end
4161

4262
# Klement
@@ -78,8 +98,8 @@ using ForwardDiff
7898
# Immutable
7999
f, u0 = (u, p) -> u .* u .- p, @SVector[1.0, 1.0]
80100

81-
for alg in (SimpleNewtonRaphson(), Broyden(), LBroyden(), Klement(), SimpleTrustRegion(),
82-
SimpleDFSane())
101+
for alg in (SimpleNewtonRaphson(), LBroyden(), Klement(), SimpleTrustRegion(),
102+
SimpleDFSane(), BROYDEN_SOLVERS...)
83103
g = function (p)
84104
probN = NonlinearProblem{false}(f, csu0, p)
85105
sol = solve(probN, alg, abstol = 1e-9)
@@ -94,8 +114,8 @@ end
94114

95115
# Scalar
96116
f, u0 = (u, p) -> u * u - p, 1.0
97-
for alg in (SimpleNewtonRaphson(), Broyden(), LBroyden(), Klement(), SimpleTrustRegion(),
98-
SimpleDFSane())
117+
for alg in (SimpleNewtonRaphson(), LBroyden(), Klement(), SimpleTrustRegion(),
118+
SimpleDFSane(), BROYDEN_SOLVERS...)
99119
g = function (p)
100120
probN = NonlinearProblem{false}(f, oftype(p, u0), p)
101121
sol = solve(probN, alg)
@@ -160,8 +180,8 @@ for alg in [Bisection(), Falsi(), Ridder(), Brent()]
160180
@test ForwardDiff.jacobian(g, p) ForwardDiff.jacobian(t, p)
161181
end
162182

163-
for alg in (SimpleNewtonRaphson(), Broyden(), LBroyden(), Klement(), SimpleTrustRegion(),
164-
SimpleDFSane())
183+
for alg in (SimpleNewtonRaphson(), LBroyden(), Klement(), SimpleTrustRegion(),
184+
SimpleDFSane(), BROYDEN_SOLVERS...)
165185
global g, p
166186
g = function (p)
167187
probN = NonlinearProblem{false}(f, 0.5, p)
@@ -176,33 +196,32 @@ end
176196
f, u0 = (u, p) -> u .* u .- 2.0, @SVector[1.0, 1.0]
177197
probN = NonlinearProblem(f, u0)
178198

179-
@test solve(probN, SimpleNewtonRaphson()).u[end] sqrt(2.0)
180-
@test solve(probN, SimpleNewtonRaphson(; autodiff = false)).u[end] sqrt(2.0)
181-
@test solve(probN, SimpleTrustRegion()).u[end] sqrt(2.0)
182-
@test solve(probN, SimpleTrustRegion(; autodiff = false)).u[end] sqrt(2.0)
183-
@test solve(probN, Broyden()).u[end] sqrt(2.0)
184-
@test solve(probN, LBroyden()).u[end] sqrt(2.0)
185-
@test solve(probN, Klement()).u[end] sqrt(2.0)
186-
@test solve(probN, SimpleDFSane()).u[end] sqrt(2.0)
199+
for alg in (SimpleNewtonRaphson(), SimpleNewtonRaphson(; autodiff = false),
200+
SimpleTrustRegion(),
201+
SimpleTrustRegion(; autodiff = false), LBroyden(), Klement(), SimpleDFSane(),
202+
BROYDEN_SOLVERS...)
203+
sol = solve(probN, alg)
204+
205+
@test sol.retcode == ReturnCode.Success
206+
@test sol.u[end] sqrt(2.0)
207+
end
187208

188209
for u0 in [1.0, [1, 1.0]]
189210
local f, probN, sol
190211
f = (u, p) -> u .* u .- 2.0
191212
probN = NonlinearProblem(f, u0)
192213
sol = sqrt(2) * u0
193214

194-
@test solve(probN, SimpleNewtonRaphson()).u sol
195-
@test solve(probN, SimpleNewtonRaphson()).u sol
196-
@test solve(probN, SimpleNewtonRaphson(; autodiff = false)).u sol
197-
198-
@test solve(probN, SimpleTrustRegion()).u sol
199-
@test solve(probN, SimpleTrustRegion()).u sol
200-
@test solve(probN, SimpleTrustRegion(; autodiff = false)).u sol
215+
for alg in (SimpleNewtonRaphson(), SimpleNewtonRaphson(; autodiff = false),
216+
SimpleTrustRegion(),
217+
SimpleTrustRegion(; autodiff = false), LBroyden(), Klement(),
218+
SimpleDFSane(),
219+
BROYDEN_SOLVERS...)
220+
sol2 = solve(probN, alg)
201221

202-
@test solve(probN, Broyden()).u sol
203-
@test solve(probN, LBroyden()).u sol
204-
@test solve(probN, Klement()).u sol
205-
@test solve(probN, SimpleDFSane()).u sol
222+
@test sol2.retcode == ReturnCode.Success
223+
@test sol2.u sol
224+
end
206225
end
207226

208227
# Bisection Tests
@@ -382,3 +401,10 @@ probN = NonlinearProblem{false}(f, u0, p);
382401
sol = solve(probN, Broyden(batched = true))
383402

384403
@test abs.(sol.u) sqrt.(p)
404+
405+
for alg in BATCHED_BROYDEN_SOLVERS
406+
sol = solve(probN, alg)
407+
408+
@test sol.retcode == ReturnCode.Success
409+
@test abs.(sol.u) sqrt.(p)
410+
end

0 commit comments

Comments
 (0)