Skip to content
This repository was archived by the owner on May 15, 2025. It is now read-only.

Commit c071dd3

Browse files
committed
Non allocating version of LBroyden for StaticArrays
1 parent 5161d6f commit c071dd3

File tree

2 files changed

+156
-18
lines changed

2 files changed

+156
-18
lines changed

src/nlsolve/lbroyden.jl

Lines changed: 156 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ end
3636

3737
fx = _get_fx(prob, x)
3838

39-
U, Vᵀ = __init_low_rank_jacobian(x, fx, threshold)
39+
U, Vᵀ = __init_low_rank_jacobian(x, fx, x isa StaticArray ? threshold : Val(η))
4040

4141
abstol, reltol, tc_cache = init_termination_cache(abstol, reltol, fx, x,
4242
termination_condition)
@@ -48,7 +48,7 @@ end
4848
@bb δf = copy(fx)
4949

5050
@bb vᵀ_cache = copy(x)
51-
Tcache = __lbroyden_threshold_cache(x, threshold)
51+
Tcache = __lbroyden_threshold_cache(x, x isa StaticArray ? threshold : Val(η))
5252
@bb mat_cache = copy(x)
5353

5454
for i in 1:maxiters
@@ -83,6 +83,105 @@ end
8383
return build_solution(prob, alg, x, fx; retcode = ReturnCode.MaxIters)
8484
end
8585

86+
# Non-allocating StaticArrays version of SimpleLimitedMemoryBroyden is actually quite
87+
# finicky, so we'll implement it separately from the generic version
88+
# We make an exception here and don't support termination conditions
89+
@views function SciMLBase.__solve(prob::NonlinearProblem{<:SArray},
90+
alg::SimpleLimitedMemoryBroyden, args...; abstol = nothing,
91+
termination_condition = nothing,
92+
maxiters = 1000, kwargs...)
93+
if termination_condition !== nothing &&
94+
!(termination_condition isa AbsNormTerminationMode)
95+
error("SimpleLimitedMemoryBroyden with StaticArrays does not support termination \
96+
conditions!")
97+
end
98+
99+
x = prob.u0
100+
fx = _get_fx(prob, x)
101+
threshold = __get_threshold(alg)
102+
103+
U, Vᵀ = __init_low_rank_jacobian(x, fx, threshold)
104+
105+
abstol = DiffEqBase._get_tolerance(abstol, eltype(x))
106+
107+
xo, δx, fo, δf = x, -fx, fx, fx
108+
109+
converged, res = __unrolled_lbroyden_initial_iterations(prob, xo, fo, δx, abstol, U, Vᵀ,
110+
threshold)
111+
112+
converged &&
113+
return build_solution(prob, alg, res.x, res.fx; retcode = ReturnCode.Success)
114+
115+
xo, fo, δx = res.x, res.fx, res.δx
116+
117+
for i in 1:(maxiters - SciMLBase._unwrap_val(threshold))
118+
x = xo .+ δx
119+
fx = prob.f(x, prob.p)
120+
δf = fx - fo
121+
122+
maximum(abs, fx) abstol &&
123+
return build_solution(prob, alg, x, fx; retcode = ReturnCode.Success)
124+
125+
vᵀ = _restructure(x, _rmatvec!!(U, Vᵀ, vec(δx)))
126+
mvec = _restructure(x, _matvec!!(U, Vᵀ, vec(δf)))
127+
128+
d = dot(vᵀ, δf)
129+
δx = @. (δx - mvec) / d
130+
131+
U = Base.setindex(U, vec(δx), mod1(i, SciMLBase._unwrap_val(threshold)))
132+
Vᵀ = Base.setindex(Vᵀ, vec(vᵀ), mod1(i, SciMLBase._unwrap_val(threshold)))
133+
134+
δx = -_restructure(fx, _matvec!!(U, Vᵀ, vec(fx)))
135+
136+
xo = x
137+
fo = fx
138+
end
139+
140+
return build_solution(prob, alg, xo, fo; retcode = ReturnCode.MaxIters)
141+
end
142+
143+
@generated function __unrolled_lbroyden_initial_iterations(prob, xo, fo, δx, abstol, U,
144+
Vᵀ, ::Val{threshold}) where {threshold}
145+
calls = []
146+
for i in 1:threshold
147+
static_idx, static_idx_p1 = Val(i - 1), Val(i)
148+
push!(calls,
149+
quote
150+
x = xo .+ δx
151+
fx = prob.f(x, prob.p)
152+
δf = fx - fo
153+
154+
maximum(abs, fx) abstol && return true, (; x, fx, δx)
155+
156+
_U = __first_n_getindex(U, $(static_idx))
157+
_Vᵀ = __first_n_getindex(Vᵀ, $(static_idx))
158+
159+
vᵀ = _restructure(x, _rmatvec!!(_U, _Vᵀ, vec(δx)))
160+
mvec = _restructure(x, _matvec!!(_U, _Vᵀ, vec(δf)))
161+
162+
d = dot(vᵀ, δf)
163+
δx = @. (δx - mvec) / d
164+
165+
U = Base.setindex(U, vec(δx), $(i))
166+
Vᵀ = Base.setindex(Vᵀ, vec(vᵀ), $(i))
167+
168+
_U = __first_n_getindex(U, $(static_idx_p1))
169+
_Vᵀ = __first_n_getindex(Vᵀ, $(static_idx_p1))
170+
δx = -_restructure(fx, _matvec!!(_U, _Vᵀ, vec(fx)))
171+
172+
xo = x
173+
fo = fx
174+
end)
175+
end
176+
push!(calls, quote
177+
# Termination Check
178+
maximum(abs, fx) abstol && return true, (; x, fx, δx)
179+
180+
return false, (; x, fx, δx)
181+
end)
182+
return Expr(:block, calls...)
183+
end
184+
86185
function _rmatvec!!(y, xᵀU, U, Vᵀ, x)
87186
# xᵀ × (-I + UVᵀ)
88187
η = size(U, 2)
@@ -98,6 +197,9 @@ function _rmatvec!!(y, xᵀU, U, Vᵀ, x)
98197
return y
99198
end
100199

200+
@inline _rmatvec!!(::Nothing, Vᵀ, x) = -x
201+
@inline _rmatvec!!(U, Vᵀ, x) = __mapTdot(__mapdot(x, U), Vᵀ) .- x
202+
101203
function _matvec!!(y, Vᵀx, U, Vᵀ, x)
102204
# (-I + UVᵀ) × x
103205
η = size(U, 2)
@@ -113,7 +215,57 @@ function _matvec!!(y, Vᵀx, U, Vᵀ, x)
113215
return y
114216
end
115217

218+
@inline _matvec!!(::Nothing, Vᵀ, x) = -x
219+
@inline _matvec!!(U, Vᵀ, x) = __mapTdot(__mapdot(x, Vᵀ), U) .- x
220+
221+
function __mapdot(x::SVector{S1}, Y::SVector{S2, <:SVector{S1}}) where {S1, S2}
222+
return map(Base.Fix1(dot, x), Y)
223+
end
224+
@generated function __mapTdot(x::SVector{S1}, Y::SVector{S1, <:SVector{S2}}) where {S1, S2}
225+
calls = []
226+
syms = [gensym("m$(i)") for i in 1:length(Y)]
227+
for i in 1:length(Y)
228+
push!(calls, :($(syms[i]) = x[$(i)] .* Y[$i]))
229+
end
230+
push!(calls, :(return .+($(syms...))))
231+
return Expr(:block, calls...)
232+
end
233+
234+
@generated function __first_n_getindex(x::SVector{L, T}, ::Val{N}) where {L, T, N}
235+
@assert N L
236+
getcalls = ntuple(i -> :(x[$i]), N)
237+
N == 0 && return :(return nothing)
238+
return :(return SVector{$N, $T}(($(getcalls...))))
239+
end
240+
116241
__lbroyden_threshold_cache(x, ::Val{threshold}) where {threshold} = similar(x, threshold)
117-
function __lbroyden_threshold_cache(x::SArray, ::Val{threshold}) where {threshold}
118-
return SArray{Tuple{threshold}, eltype(x)}(ntuple(_ -> zero(eltype(x)), threshold))
242+
function __lbroyden_threshold_cache(x::StaticArray, ::Val{threshold}) where {threshold}
243+
return zeros(MArray{Tuple{threshold}, eltype(x)})
244+
end
245+
__lbroyden_threshold_cache(x::SArray, ::Val{threshold}) where {threshold} = nothing
246+
247+
function __init_low_rank_jacobian(u::StaticArray{S1, T1}, fu::StaticArray{S2, T2},
248+
::Val{threshold}) where {S1, S2, T1, T2, threshold}
249+
T = promote_type(T1, T2)
250+
fuSize, uSize = Size(fu), Size(u)
251+
Vᵀ = MArray{Tuple{threshold, prod(uSize)}, T}(undef)
252+
U = MArray{Tuple{prod(fuSize), threshold}, T}(undef)
253+
return U, Vᵀ
254+
end
255+
@generated function __init_low_rank_jacobian(u::SArray{S1, T1}, fu::SArray{S2, T2},
256+
::Val{threshold}) where {S1, S2, T1, T2, threshold}
257+
T = promote_type(T1, T2)
258+
Lfu, Lu = prod(Size(fu)), prod(Size(u))
259+
inner_inits_Vᵀ = [zeros(SVector{Lu, T}) for i in 1:threshold]
260+
inner_inits_U = [zeros(SVector{Lfu, T}) for i in 1:threshold]
261+
return quote
262+
Vᵀ = SVector($(inner_inits_Vᵀ...))
263+
U = SVector($(inner_inits_U...))
264+
return U, Vᵀ
265+
end
266+
end
267+
function __init_low_rank_jacobian(u, fu, ::Val{threshold}) where {threshold}
268+
Vᵀ = similar(u, threshold, length(u))
269+
U = similar(u, length(fu), threshold)
270+
return U, Vᵀ
119271
end

src/utils.jl

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -243,20 +243,6 @@ function __init_identity_jacobian!!(J::SVector{S1}) where {S1}
243243
return ones(SVector{S1, eltype(J)})
244244
end
245245

246-
function __init_low_rank_jacobian(u::StaticArray{S1, T1}, fu::StaticArray{S2, T2},
247-
::Val{threshold}) where {S1, S2, T1, T2, threshold}
248-
T = promote_type(T1, T2)
249-
fuSize, uSize = Size(fu), Size(u)
250-
Vᵀ = MArray{Tuple{threshold, prod(uSize)}, T}(undef)
251-
U = MArray{Tuple{prod(fuSize), threshold}, T}(undef)
252-
return U, Vᵀ
253-
end
254-
function __init_low_rank_jacobian(u, fu, ::Val{threshold}) where {threshold}
255-
Vᵀ = similar(u, threshold, length(u))
256-
U = similar(u, length(fu), threshold)
257-
return U, Vᵀ
258-
end
259-
260246
@inline _vec(v) = vec(v)
261247
@inline _vec(v::Number) = v
262248
@inline _vec(v::AbstractVector) = v

0 commit comments

Comments
 (0)