Skip to content
This repository was archived by the owner on May 15, 2025. It is now read-only.

Commit 0b25978

Browse files
Merge pull request #116 from SciML/ap/ls
Add Line Search to (L)Broyden
2 parents ade88a2 + 469afbf commit 0b25978

18 files changed

+301
-104
lines changed

Project.toml

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
11
name = "SimpleNonlinearSolve"
22
uuid = "727e6d20-b764-4bd8-a329-72de5adea6c7"
33
authors = ["SciML"]
4-
version = "1.2.1"
4+
version = "1.3.0"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
88
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
99
ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471"
1010
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
11+
FastClosures = "9aa1b823-49e4-5ca5-8b0f-3971ec8bab6a"
1112
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
1213
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
1314
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
@@ -17,17 +18,20 @@ Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
1718
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
1819
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
1920

20-
[extensions]
21-
SimpleNonlinearSolvePolyesterForwardDiffExt = "PolyesterForwardDiff"
22-
2321
[weakdeps]
2422
PolyesterForwardDiff = "98d1487c-24ca-40b6-b7ab-df2af84e126b"
23+
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
24+
25+
[extensions]
26+
SimpleNonlinearSolvePolyesterForwardDiffExt = "PolyesterForwardDiff"
27+
SimpleNonlinearSolveStaticArraysExt = "StaticArrays"
2528

2629
[compat]
2730
ADTypes = "0.2.6"
2831
ArrayInterface = "7"
2932
ConcreteStructs = "0.2"
3033
DiffEqBase = "6.126"
34+
FastClosures = "0.3"
3135
FiniteDiff = "2"
3236
ForwardDiff = "0.10.3"
3337
LinearAlgebra = "1.9"
@@ -36,4 +40,5 @@ PrecompileTools = "1"
3640
Reexport = "1"
3741
SciMLBase = "2.7"
3842
StaticArraysCore = "1.4"
43+
StaticArrays = "1"
3944
julia = "1.9"
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
# Package extension for StaticArrays. Its only job is to flip the
# `__is_extension_loaded` query for `Val(:StaticArrays)` to `true`.
module SimpleNonlinearSolveStaticArraysExt

using SimpleNonlinearSolve

# Adds a method for the `:StaticArrays` tag that reports the extension as loaded.
@inline function SimpleNonlinearSolve.__is_extension_loaded(::Val{:StaticArrays})
    return true
end

end

src/SimpleNonlinearSolve.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,17 +3,16 @@ module SimpleNonlinearSolve
33
import PrecompileTools: @compile_workload, @setup_workload, @recompile_invalidations
44

55
@recompile_invalidations begin
6-
using ADTypes,
7-
ArrayInterface, ConcreteStructs, DiffEqBase, Reexport, LinearAlgebra, SciMLBase
6+
using ADTypes, ArrayInterface, ConcreteStructs, DiffEqBase, FastClosures, FiniteDiff,
7+
ForwardDiff, Reexport, LinearAlgebra, SciMLBase
88

99
import DiffEqBase: AbstractNonlinearTerminationMode,
1010
AbstractSafeNonlinearTerminationMode, AbstractSafeBestNonlinearTerminationMode,
1111
NonlinearSafeTerminationReturnCode, get_termination_mode,
12-
NONLINEARSOLVE_DEFAULT_NORM, _get_tolerance
13-
using FiniteDiff, ForwardDiff
12+
NONLINEARSOLVE_DEFAULT_NORM
1413
import ForwardDiff: Dual
1514
import MaybeInplace: @bb, setindex_trait, CanSetindex, CannotSetindex
16-
import SciMLBase: AbstractNonlinearAlgorithm, build_solution, isinplace
15+
import SciMLBase: AbstractNonlinearAlgorithm, build_solution, isinplace, _unwrap_val
1716
import StaticArraysCore: StaticArray, SVector, SMatrix, SArray, MArray, MMatrix, Size
1817
end
1918

@@ -26,6 +25,7 @@ abstract type AbstractNewtonAlgorithm <: AbstractSimpleNonlinearSolveAlgorithm e
2625
@inline __is_extension_loaded(::Val) = false
2726

2827
include("utils.jl")
28+
include("linesearch.jl")
2929

3030
## Nonlinear Solvers
3131
include("nlsolve/raphson.jl")

src/ad.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@ function SciMLBase.solve(prob::NonlinearProblem{<:Union{Number, <:AbstractArray}
77
sol.original)
88
end
99

10-
# Handle Ambiguities
1110
for algType in (Bisection, Brent, Alefeld, Falsi, ITP, Ridder)
1211
@eval begin
1312
function SciMLBase.solve(prob::IntervalNonlinearProblem{uType, iip,

src/bracketing/bisection.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ function SciMLBase.solve(prob::IntervalNonlinearProblem, alg::Bisection, args...
2626
left, right = prob.tspan
2727
fl, fr = f(left), f(right)
2828

29-
abstol = _get_tolerance(abstol,
29+
abstol = __get_tolerance(nothing, abstol,
3030
promote_type(eltype(first(prob.tspan)), eltype(last(prob.tspan))))
3131

3232
if iszero(fl)

src/bracketing/brent.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ function SciMLBase.solve(prob::IntervalNonlinearProblem, alg::Brent, args...;
1313
fl, fr = f(left), f(right)
1414
ϵ = eps(convert(typeof(fl), 1))
1515

16-
abstol = _get_tolerance(abstol,
16+
abstol = __get_tolerance(nothing, abstol,
1717
promote_type(eltype(first(prob.tspan)), eltype(last(prob.tspan))))
1818

1919
if iszero(fl)

src/bracketing/falsi.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ function SciMLBase.solve(prob::IntervalNonlinearProblem, alg::Falsi, args...;
1212
left, right = prob.tspan
1313
fl, fr = f(left), f(right)
1414

15-
abstol = _get_tolerance(abstol,
15+
abstol = __get_tolerance(nothing, abstol,
1616
promote_type(eltype(first(prob.tspan)), eltype(last(prob.tspan))))
1717

1818
if iszero(fl)

src/bracketing/itp.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ function SciMLBase.solve(prob::IntervalNonlinearProblem, alg::ITP, args...;
5858
left, right = prob.tspan
5959
fl, fr = f(left), f(right)
6060

61-
abstol = _get_tolerance(abstol,
61+
abstol = __get_tolerance(nothing, abstol,
6262
promote_type(eltype(first(prob.tspan)), eltype(last(prob.tspan))))
6363

6464
if iszero(fl)

src/bracketing/ridder.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ function SciMLBase.solve(prob::IntervalNonlinearProblem, alg::Ridder, args...;
1212
left, right = prob.tspan
1313
fl, fr = f(left), f(right)
1414

15-
abstol = _get_tolerance(abstol,
15+
abstol = __get_tolerance(nothing, abstol,
1616
promote_type(eltype(first(prob.tspan)), eltype(last(prob.tspan))))
1717

1818
if iszero(fl)

src/linesearch.jl

Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
# This is a copy of the version in NonlinearSolve.jl. Temporarily kept here till we move
2+
# line searches into a dedicated package.
3+
# Keyword-constructible settings for the Li-Fukushima line search.
#
# Calling an instance as `alg(prob, fu, u)` builds a line-search cache; the numeric
# settings are converted to `promote_type(eltype(fu), eltype(u))` at that point.
@kwdef @concrete struct LiFukushimaLineSearch
    # Step length the backtracking loop starts from (stored as `λ₀` in the caches).
    lambda_0 = 1
    # Factor the trial step is multiplied by after every rejected trial (`β`).
    beta = 0.5
    # Coefficient of the `λ² * ‖δu‖²` term in the backtracking acceptance test (`σ₁`).
    sigma_1 = 0.001
    # Coefficient of the `‖δu‖²` term in the early-termination test (`σ₂`).
    sigma_2 = 0.001
    # Relaxation of the acceptance test: the bound uses `(1 + η) * ‖f(u)‖`.
    eta = 0.1
    # Scale applied to `‖f(u)‖` in the early-termination test (`ρ`).
    rho = 0.1
    # Cap on the NaN-recovery backtracking steps. `missing` becomes `5` in the generic
    # cache; `missing`/`nothing` let `Number`/`SArray` inputs use the static cache.
    nan_maxiters = missing
    # Upper bound on the number of backtracking trials.
    maxiters::Int = 100
end
13+
14+
# Mutable cache produced by `__generic_init`. Calling it as `cache(u, δu)` returns a
# step length. The type parameter `T` restricts `nan_maxiters` to `Nothing` or `Int`.
@concrete mutable struct LiFukushimaLineSearchCache{T <: Union{Nothing, Int}}
    # Merit function `(u, δu, α) -> norm of the residual evaluated at u + α * δu`.
    ϕ
    # Initial trial step for the backtracking loop.
    λ₀
    # Shrink factor applied to the trial step after each rejected trial.
    β
    # Coefficient of the `λ² * ‖δu‖²` term in the acceptance test.
    σ₁
    # Coefficient of the `‖δu‖²` term in the early-termination test.
    σ₂
    # Relaxation factor: the acceptance bound uses `(1 + η) * ‖f(u)‖`.
    η
    # Scale applied to `‖f(u)‖` in the early-termination test.
    ρ
    # Step tried first and returned whenever the search bails out.
    α
    # Cap on NaN-recovery steps; `nothing` skips the NaN-recovery phase.
    nan_maxiters::T
    # Upper bound on the number of backtracking trials.
    maxiters::Int
end
26+
27+
# Immutable cache produced by `__static_init`. It stores the residual function and its
# parameters directly and is called as `cache(u, δu)` to obtain a step length.
@concrete struct StaticLiFukushimaLineSearchCache
    # Residual function, invoked as `f(u, p)`.
    f
    # Parameters forwarded to `f`.
    p
    # Initial trial step for the backtracking loop.
    λ₀
    # Shrink factor applied to the trial step after each rejected trial.
    β
    # Coefficient of the `λ² * ‖δu‖²` term in the acceptance test.
    σ₁
    # Coefficient of the `‖δu‖²` term in the early-termination test.
    σ₂
    # Relaxation factor: the acceptance bound uses `(1 + η) * ‖f(u)‖`.
    η
    # Scale applied to `‖f(u)‖` in the early-termination test.
    ρ
    # Upper bound on the number of backtracking trials.
    maxiters::Int
end
38+
39+
# Build a line-search cache for `prob` from the residual `fu` and the iterate `u`.
# The fallback method always produces the generic cache.
function (alg::LiFukushimaLineSearch)(prob, fu, u)
    return __generic_init(alg, prob, fu, u)
end

# Specialisation for scalar / `SArray` inputs: use the static cache unless the user
# supplied a concrete `nan_maxiters`, in which case warn once and fall back to the
# generic cache.
function (alg::LiFukushimaLineSearch)(prob, fu::Union{Number, SArray},
        u::Union{Number, SArray})
    use_static = alg.nan_maxiters === missing || alg.nan_maxiters === nothing
    if use_static
        return __static_init(alg, prob, fu, u)
    end
    @warn "`LiFukushimaLineSearch` with NaN checking is not non-allocating" maxlog=1
    return __generic_init(alg, prob, fu, u)
end
47+
48+
# Construct the `LiFukushimaLineSearchCache` for `prob`.
#
# Scratch buffers shaped like `u` and `fu` are created once and captured by the merit
# closure, which evaluates the residual norm at `u + α * δu`. All numeric settings are
# converted to `T = promote_type(eltype(fu), eltype(u))`.
function __generic_init(alg::LiFukushimaLineSearch, prob, fu, u)
    T = promote_type(eltype(fu), eltype(u))

    @bb trial_u = similar(u)
    @bb trial_fu = similar(fu)

    # Merit function: norm of the residual at the trial point `x + λ * dir`.
    merit = @closure (x, dir, λ) -> begin
        @bb @. trial_u = x + λ * dir
        return NONLINEARSOLVE_DEFAULT_NORM(__eval_f(prob, trial_fu, trial_u))
    end

    # An unspecified (`missing`) NaN-recovery cap becomes 5.
    nan_maxiters = alg.nan_maxiters === missing ? 5 : alg.nan_maxiters

    λ₀ = T(alg.lambda_0)
    β = T(alg.beta)
    σ₁ = T(alg.sigma_1)
    σ₂ = T(alg.sigma_2)
    η = T(alg.eta)
    ρ = T(alg.rho)
    α = T(true)  # the step tried first and returned on bail-out is 1 in type `T`

    return LiFukushimaLineSearchCache(merit, λ₀, β, σ₁, σ₂, η, ρ, α, nan_maxiters,
        alg.maxiters)
end
63+
64+
# Construct the `StaticLiFukushimaLineSearchCache` for `prob`: store `prob.f` and
# `prob.p` as-is and convert every numeric setting to
# `T = promote_type(eltype(fu), eltype(u))`.
function __static_init(alg::LiFukushimaLineSearch, prob, fu, u)
    T = promote_type(eltype(fu), eltype(u))

    λ₀ = T(alg.lambda_0)
    β = T(alg.beta)
    σ₁ = T(alg.sigma_1)
    σ₂ = T(alg.sigma_2)
    η = T(alg.eta)
    ρ = T(alg.rho)

    return StaticLiFukushimaLineSearchCache(prob.f, prob.p, λ₀, β, σ₁, σ₂, η, ρ,
        alg.maxiters)
end
69+
70+
# Run the Li-Fukushima backtracking search from `u` along the direction `δu` and
# return the accepted step length.
#
# Phases:
#   1. Bail out with `cache.α` if the residual norm at `u` itself is NaN.
#   2. Accept `cache.α` immediately if it passes the early-termination test (Eq. 2.7).
#   3. If the residual norm at the initial step `λ₀` is NaN (and `nan_maxiters` is not
#      `nothing`), shrink the step by `β` up to `nan_maxiters` times until the norm
#      is no longer NaN; give up with `cache.α` if that never happens.
#   4. Backtrack by `β` for at most `maxiters` trials until the relaxed decrease
#      condition holds; fall back to `cache.α` if no trial is accepted.
function (cache::LiFukushimaLineSearchCache)(u, δu)
    T = promote_type(eltype(u), eltype(δu))
    ϕ = @closure α -> cache.ϕ(u, δu, α)
    fx_norm = ϕ(T(0))

    # Non-Blocking exit if the norm is NaN or Inf
    DiffEqBase.NAN_CHECK(fx_norm) && return cache.α

    # Early Terminate based on Eq. 2.7
    du_norm = NONLINEARSOLVE_DEFAULT_NORM(δu)
    fxλ_norm = ϕ(cache.α)
    fxλ_norm ≤ cache.ρ * fx_norm - cache.σ₂ * du_norm^2 && return cache.α

    λ₂ = cache.λ₀
    fxλp_norm = ϕ(λ₂)

    if cache.nan_maxiters !== nothing
        if DiffEqBase.NAN_CHECK(fxλp_norm)
            nan_converged = false
            for _ in 1:(cache.nan_maxiters)
                λ₂ = cache.β * λ₂
                fxλp_norm = ϕ(λ₂)
                # Recovery succeeds once the trial norm is *no longer* NaN.
                nan_converged = !(DiffEqBase.NAN_CHECK(fxλp_norm)::Bool)
                nan_converged && break
            end
            nan_converged || return cache.α
        end
    end

    for _ in 1:(cache.maxiters)
        fxλp_norm = ϕ(λ₂)
        converged = fxλp_norm ≤ (1 + cache.η) * fx_norm - cache.σ₁ * λ₂^2 * du_norm^2
        converged && return λ₂
        λ₂ = cache.β * λ₂
    end

    return cache.α
end
108+
109+
# Li-Fukushima backtracking search for the static cache: evaluates `cache.f(·, cache.p)`
# out-of-place at each trial point and returns the accepted step length.
#
# The full step (`1` in the promoted element type `T`) is returned when it passes the
# early-termination test (Eq. 2.7) and also as the fallback when no trial step is
# accepted within `maxiters` backtracking trials. This path performs no NaN checks.
function (cache::StaticLiFukushimaLineSearchCache)(u, δu)
    T = promote_type(eltype(u), eltype(δu))

    # Early Terminate based on Eq. 2.7
    fx_norm = NONLINEARSOLVE_DEFAULT_NORM(cache.f(u, cache.p))
    du_norm = NONLINEARSOLVE_DEFAULT_NORM(δu)
    fxλ_norm = NONLINEARSOLVE_DEFAULT_NORM(cache.f(u .+ δu, cache.p))
    fxλ_norm ≤ cache.ρ * fx_norm - cache.σ₂ * du_norm^2 && return T(true)

    λ₂ = cache.λ₀

    for _ in 1:(cache.maxiters)
        fxλp_norm = NONLINEARSOLVE_DEFAULT_NORM(cache.f(u .+ λ₂ .* δu, cache.p))
        converged = fxλp_norm ≤ (1 + cache.η) * fx_norm - cache.σ₁ * λ₂^2 * du_norm^2
        converged && return λ₂
        λ₂ = cache.β * λ₂
    end

    return T(true)
end

0 commit comments

Comments
 (0)