11"""
2- ```julia
3- SimpleDFSane(; σ_min::Real = 1e-10, σ_max::Real = 1e10, σ_1::Real = 1.0,
4- M::Int = 10, γ::Real = 1e-4, τ_min::Real = 0.1, τ_max::Real = 0.5,
5- nexp::Int = 2, η_strategy::Function = (f_1, k, x, F) -> f_1 / k^2)
6- ```
2+ SimpleDFSane(; σ_min::Real = 1e-10, σ_max::Real = 1e10, σ_1::Real = 1.0,
3+ M::Int = 10, γ::Real = 1e-4, τ_min::Real = 0.1, τ_max::Real = 0.5,
4+ nexp::Int = 2, η_strategy::Function = (f_1, k, x, F) -> f_1 ./ k^2,
5+ termination_condition = NLSolveTerminationCondition(NLSolveTerminationMode.NLSolveDefault;
6+ abstol = nothing,
7+ reltol = nothing),
8+ batched::Bool = false,
9+ max_inner_iterations::Int = 1000)
710
811A low-overhead implementation of the df-sane method for solving large-scale nonlinear
912systems of equations. For in depth information about all the parameters and the algorithm,
@@ -39,8 +42,16 @@ Computation, 75, 1429-1448.](https://www.researchgate.net/publication/220576479_
3942 ``f_1=||F(x_1)||^{nexp}``, `k` is the iteration number, `x` is the current `x`-value and
4043 `F` the current residual. Should satisfy ``η_k > 0`` and ``∑ₖ ηₖ < ∞``. Defaults to
4144 ``||F||^2 / k^2``.
45+ - `termination_condition`: a `NLSolveTerminationCondition` that determines when the solver
46+ should terminate. Defaults to `NLSolveTerminationCondition(NLSolveTerminationMode.NLSolveDefault;
47+ abstol = nothing, reltol = nothing)`.
48+ - `batched`: if `true`, the algorithm will use a batched version of the algorithm that treats each
49+ column of `x` as a separate problem. This can be useful for nonlinear problems involving neural
50+ networks. Defaults to `false`.
51+ - `max_inner_iterations`: the maximum number of iterations allowed for the inner loop of the
52+ algorithm. Used exclusively in `batched` mode. Defaults to `1000`.
4253"""
43- struct SimpleDFSane{T } <: AbstractSimpleNonlinearSolveAlgorithm
54+ struct SimpleDFSane{batched, T, TC } <: AbstractSimpleNonlinearSolveAlgorithm
4455 σ_min:: T
4556 σ_max:: T
4657 σ_1:: T
@@ -50,106 +61,187 @@ struct SimpleDFSane{T} <: AbstractSimpleNonlinearSolveAlgorithm
5061 τ_max:: T
5162 nexp:: Int
5263 η_strategy:: Function
64+ termination_condition:: TC
65+ max_inner_iterations:: Int
5366
5467 function SimpleDFSane (; σ_min:: Real = 1e-10 , σ_max:: Real = 1e10 , σ_1:: Real = 1.0 ,
5568 M:: Int = 10 , γ:: Real = 1e-4 , τ_min:: Real = 0.1 , τ_max:: Real = 0.5 ,
56- nexp:: Int = 2 , η_strategy:: Function = (f_1, k, x, F) -> f_1 / k^ 2 )
57- new {typeof(σ_min)} (σ_min, σ_max, σ_1, M, γ, τ_min, τ_max, nexp, η_strategy)
69+ nexp:: Int = 2 , η_strategy:: Function = (f_1, k, x, F) -> f_1 ./ k^ 2 ,
70+ termination_condition = NLSolveTerminationCondition (NLSolveTerminationMode. NLSolveDefault;
71+ abstol = nothing ,
72+ reltol = nothing ),
73+ batched:: Bool = false ,
74+ max_inner_iterations = 1000 )
75+ return new {batched, typeof(σ_min), typeof(termination_condition)} (σ_min,
76+ σ_max,
77+ σ_1,
78+ M,
79+ γ,
80+ τ_min,
81+ τ_max,
82+ nexp,
83+ η_strategy,
84+ termination_condition,
85+ max_inner_iterations)
5886 end
5987end
6088
# Solve an out-of-place `NonlinearProblem` with the derivative-free spectral residual
# (DF-SANE) iteration configured by `alg`.
#
# Keyword arguments:
#   - `abstol`, `reltol`: tolerances handed to the termination condition. When `nothing`,
#     the value stored in `alg.termination_condition` is used, and if that is also
#     `nothing` the fallback is `eps^(4 // 5)` of the real element type of `x`.
#   - `maxiters`: maximum number of outer iterations.
#
# Returns the result of `SciMLBase.build_solution` with `ReturnCode.Success` when the
# termination condition fires, and with `ReturnCode.MaxIters` otherwise.
#
# When the type parameter `batched` is `true`, `x` must be a matrix; the merit values,
# spectral parameter and step lengths are kept as `1 × size(x, 2)` rows (one entry per
# column of `x`).
function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleDFSane{batched},
    args...; abstol = nothing, reltol = nothing, maxiters = 1000,
    kwargs...) where {batched}
    tc = alg.termination_condition
    mode = DiffEqBase.get_termination_mode(tc)

    f = Base.Fix2(prob.f, prob.p)
    x = float(prob.u0)

    if batched
        # Check the layout before the second dimension is interpreted as the batch size.
        @assert ndims(x)==2 "Batched SimpleDFSane only supports 2D arrays"
        batch_size = size(x, 2)
    end

    T = eltype(x)
    σ_min = float(alg.σ_min)
    σ_max = float(alg.σ_max)
    σ_k = batched ? fill(float(alg.σ_1), 1, batch_size) : float(alg.σ_1)

    M = alg.M
    γ = float(alg.γ)
    τ_min = float(alg.τ_min)
    τ_max = float(alg.τ_max)
    nexp = alg.nexp
    η_strategy = alg.η_strategy

    if SciMLBase.isinplace(prob)
        error("SimpleDFSane currently only supports out-of-place nonlinear problems")
    end

    # Tolerance precedence: solve keyword > termination-condition field > eps-based default.
    atol = abstol !== nothing ? abstol :
           (tc.abstol !== nothing ? tc.abstol :
            real(oneunit(eltype(T))) * (eps(real(one(eltype(T)))))^(4 // 5))
    rtol = reltol !== nothing ? reltol :
           (tc.reltol !== nothing ? tc.reltol : eps(real(one(eltype(T))))^(4 // 5))

    if mode ∈ DiffEqBase.SAFE_BEST_TERMINATION_MODES
        error("SimpleDFSane currently doesn't support SAFE_BEST termination modes")
    end

    storage = mode ∈ DiffEqBase.SAFE_TERMINATION_MODES ? NLSolveSafeTerminationResult() :
              nothing
    termination_condition = tc(storage)

    # Evaluate the residual `F` and the merit value `‖F‖^nexp`
    # (computed column-wise as a `1 × batch` row when batched).
    function ff(x)
        F = f(x)
        f_k = if batched
            sum(abs2, F; dims = 1) .^ (nexp / 2)
        else
            norm(F)^nexp
        end
        return f_k, F
    end

    # Build the buffer of the last `M` merit values, initially filled with `f_k`
    # (an `M × length(f_k)` array when batched, a length-`M` vector otherwise).
    function generate_history(f_k, M)
        if batched
            history = similar(f_k, (M, length(f_k)))
            history .= reshape(f_k, 1, :)
            return history
        else
            return fill(f_k, M)
        end
    end

    f_k, F_k = ff(x)
    α_1 = convert(T, 1.0)
    f_1 = f_k
    history_f_k = generate_history(f_k, M)

    for k in 1:maxiters
        # Spectral parameter range check
        if batched
            @. σ_k = sign(σ_k) * clamp(abs(σ_k), σ_min, σ_max)
        else
            σ_k = sign(σ_k) * clamp(abs(σ_k), σ_min, σ_max)
        end

        # Line search direction
        d = -σ_k .* F_k

        η = η_strategy(f_1, k, x, F_k)
        f̄ = batched ? maximum(history_f_k; dims = 1) : maximum(history_f_k)
        α_p = α_1
        α_m = α_1
        x_new = @. x + α_p * d

        f_new, F_new = ff(x_new)

        inner_iterations = 0
        while true
            inner_iterations += 1

            # Acceptance test for the trial point `x + α_p * d`.
            if batched
                criteria = @. f̄ + η - γ * α_p^2 * f_k
                # NOTE: This is simply a heuristic, ideally we check using `all` but that is
                #       typically very expensive for large problems
                (sum(f_new .≤ criteria) ≥ batch_size ÷ 2) && break
            else
                criteria = f̄ + η - γ * α_p^2 * f_k
                f_new ≤ criteria && break
            end

            α_tp = @. α_p^2 * f_k / (f_new + (2 * α_p - 1) * f_k)
            x_new = @. x - α_m * d
            f_new, F_new = ff(x_new)

            # Acceptance test for the trial point `x - α_m * d`. The threshold has to be
            # rebuilt from `α_m`: `α_p` and `α_m` are updated independently below, so the
            # threshold computed above for `α_p` is not valid for this trial point.
            if batched
                criteria = @. f̄ + η - γ * α_m^2 * f_k
                # NOTE: This is simply a heuristic, ideally we check using `all` but that is
                #       typically very expensive for large problems
                (sum(f_new .≤ criteria) ≥ batch_size ÷ 2) && break
            else
                criteria = f̄ + η - γ * α_m^2 * f_k
                f_new ≤ criteria && break
            end

            α_tm = @. α_m^2 * f_k / (f_new + (2 * α_m - 1) * f_k)
            α_p = @. clamp(α_tp, τ_min * α_p, τ_max * α_p)
            α_m = @. clamp(α_tm, τ_min * α_m, τ_max * α_m)
            x_new = @. x + α_p * d
            f_new, F_new = ff(x_new)

            # NOTE: The original algorithm runs till either condition is satisfied, however,
            #       for most batched problems like neural networks we only care about
            #       approximate convergence
            batched && (inner_iterations ≥ alg.max_inner_iterations) && break
        end

        if termination_condition(F_new, x_new, x, atol, rtol)
            return SciMLBase.build_solution(prob,
                alg,
                x_new,
                F_new;
                retcode = ReturnCode.Success)
        end

        # Update spectral parameter
        s_k = @. x_new - x
        y_k = @. F_new - F_k

        if batched
            # Column-wise ⟨s, s⟩ / ⟨s, y⟩; the `1e-5` offset is added to the denominator.
            σ_k = sum(abs2, s_k; dims = 1) ./ (sum(s_k .* y_k; dims = 1) .+ T(1e-5))
        else
            σ_k = (s_k' * s_k) / (s_k' * y_k)
        end

        # Take step
        x = x_new
        F_k = F_new
        f_k = f_new

        # Store function value (the buffer slot cycles through `1:M`)
        if batched
            history_f_k[k % M + 1, :] .= vec(f_new)
        else
            history_f_k[k % M + 1] = f_new
        end
    end
    return SciMLBase.build_solution(prob, alg, x, F_k; retcode = ReturnCode.MaxIters)
end
0 commit comments