Skip to content

Commit 8228c67

Browse files
Merge pull request #470 from SciML/ap/overhead
[Do Not Merge] Preallocate more caches
2 parents e37a60a + 085ba9d commit 8228c67

File tree

5 files changed

+43
-10
lines changed

5 files changed

+43
-10
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "LinearSolve"
22
uuid = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
33
authors = ["SciML"]
4-
version = "2.23.2"
4+
version = "2.23.3"
55

66
[deps]
77
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"

src/common.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,8 @@ function Base.setproperty!(cache::LinearCache, name::Symbol, x)
9090
update_cacheval!(cache, :b, x)
9191
elseif name === :cacheval && cache.alg isa DefaultLinearSolver
9292
@assert cache.cacheval isa DefaultLinearSolverInit
93-
return setfield!(cache.cacheval, Symbol(cache.alg.alg), x)
93+
return __setfield!(cache.cacheval, cache.alg, x)
94+
# return setfield!(cache.cacheval, Symbol(cache.alg.alg), x)
9495
end
9596
setfield!(cache, name, x)
9697
end

src/default.jl

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,23 @@ mutable struct DefaultLinearSolverInit{T1, T2, T3, T4, T5, T6, T7, T8, T9, T10,
2424
KrylovJL_LSMR::T21
2525
end
2626

27+
@generated function __setfield!(cache::DefaultLinearSolverInit, alg::DefaultLinearSolver, v)
28+
ex = :()
29+
for alg in first.(EnumX.symbol_map(DefaultAlgorithmChoice.T))
30+
newex = quote
31+
setfield!(cache, $(Meta.quot(alg)), v)
32+
end
33+
alg_enum = getproperty(LinearSolve.DefaultAlgorithmChoice, alg)
34+
ex = if ex == :()
35+
Expr(:elseif, :(alg.alg == $(alg_enum)), newex,
36+
:(error("Algorithm Choice not Allowed")))
37+
else
38+
Expr(:elseif, :(alg.alg == $(alg_enum)), newex, ex)
39+
end
40+
end
41+
ex = Expr(:if, ex.args...)
42+
end
43+
2744
# Legacy fallback
2845
# For SciML algorithms already using `defaultalg`, all assume square matrix.
2946
defaultalg(A, b) = defaultalg(A, b, OperatorAssumptions(true))
@@ -159,7 +176,7 @@ function defaultalg(A, b, assump::OperatorAssumptions{Bool})
159176
(__conditioning(assump) === OperatorCondition.IllConditioned ||
160177
__conditioning(assump) === OperatorCondition.WellConditioned)
161178
if length(b) <= 10
162-
DefaultAlgorithmChoice.GenericLUFactorization
179+
DefaultAlgorithmChoice.RFLUFactorization
163180
elseif appleaccelerate_isavailable()
164181
DefaultAlgorithmChoice.AppleAccelerateLUFactorization
165182
elseif (length(b) <= 100 || (isopenblas() && length(b) <= 500) ||
@@ -345,11 +362,12 @@ end
345362
retcode = sol.retcode,
346363
iters = sol.iters, stats = sol.stats)
347364
end
365+
alg_enum = getproperty(LinearSolve.DefaultAlgorithmChoice, alg)
348366
ex = if ex == :()
349-
Expr(:elseif, :(Symbol(alg.alg) === $(Meta.quot(alg))), newex,
367+
Expr(:elseif, :(alg.alg == $(alg_enum)), newex,
350368
:(error("Algorithm Choice not Allowed")))
351369
else
352-
Expr(:elseif, :(Symbol(alg.alg) === $(Meta.quot(alg))), newex, ex)
370+
Expr(:elseif, :(alg.alg == $(alg_enum)), newex, ex)
353371
end
354372
end
355373
ex = Expr(:if, ex.args...)

src/factorization.jl

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -173,18 +173,24 @@ function init_cacheval(alg::QRFactorization, A, b, u, Pl, Pr,
173173
ArrayInterface.qr_instance(convert(AbstractMatrix, A), alg.pivot)
174174
end
175175

176+
const PREALLOCATED_QR_ColumnNorm = ArrayInterface.qr_instance(rand(1, 1), ColumnNorm())
177+
178+
function init_cacheval(alg::QRFactorization{ColumnNorm}, A::Matrix{Float64}, b, u, Pl, Pr,
179+
maxiters::Int, abstol, reltol, verbose::Bool, assumptions::OperatorAssumptions)
180+
return PREALLOCATED_QR_ColumnNorm
181+
end
182+
176183
function init_cacheval(alg::QRFactorization, A::Union{<:Adjoint, <:Transpose}, b, u, Pl, Pr,
177184
maxiters::Int, abstol, reltol, verbose::Bool, assumptions::OperatorAssumptions)
178185
A isa GPUArraysCore.AnyGPUArray && return qr(A)
179186
return qr(A, alg.pivot)
180187
end
181188

182-
const PREALLOCATED_QR = ArrayInterface.qr_instance(rand(1, 1))
189+
const PREALLOCATED_QR_NoPivot = ArrayInterface.qr_instance(rand(1, 1))
183190

184191
function init_cacheval(alg::QRFactorization{NoPivot}, A::Matrix{Float64}, b, u, Pl, Pr,
185-
maxiters::Int, abstol, reltol, verbose::Bool,
186-
assumptions::OperatorAssumptions)
187-
PREALLOCATED_QR
192+
maxiters::Int, abstol, reltol, verbose::Bool, assumptions::OperatorAssumptions)
193+
return PREALLOCATED_QR_NoPivot
188194
end
189195

190196
function init_cacheval(alg::QRFactorization, A::AbstractSciMLOperator, b, u, Pl, Pr,
@@ -1013,6 +1019,14 @@ function init_cacheval(alg::NormalCholeskyFactorization, A, b, u, Pl, Pr,
10131019
return ArrayInterface.cholesky_instance(Symmetric(Matrix{eltype(A)}(undef,0,0)), alg.pivot)
10141020
end
10151021

1022+
const PREALLOCATED_NORMALCHOLESKY_SYMMETRIC = ArrayInterface.cholesky_instance(
1023+
Symmetric(rand(1, 1)), NoPivot())
1024+
1025+
function init_cacheval(alg::NormalCholeskyFactorization, A::Matrix{Float64}, b, u, Pl, Pr,
1026+
maxiters::Int, abstol, reltol, verbose::Bool, assumptions::OperatorAssumptions)
1027+
return PREALLOCATED_NORMALCHOLESKY_SYMMETRIC
1028+
end
1029+
10161030
function init_cacheval(alg::NormalCholeskyFactorization,
10171031
A::Union{Diagonal, AbstractSciMLOperator}, b, u, Pl, Pr,
10181032
maxiters::Int, abstol, reltol, verbose::Bool,

test/default_algs.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
using LinearSolve, LinearAlgebra, SparseArrays, Test, JET
22
@test LinearSolve.defaultalg(nothing, zeros(3)).alg ===
3-
LinearSolve.DefaultAlgorithmChoice.GenericLUFactorization
3+
LinearSolve.DefaultAlgorithmChoice.RFLUFactorization
44
prob = LinearProblem(rand(3, 3), rand(3))
55
solve(prob)
66

0 commit comments

Comments
 (0)