|
## Default algorithm: unwrap the wrapped operator and re-dispatch on the
## underlying matrix, forwarding the caller's assumptions unchanged.
function defaultalg(A::DiffEqArrayOperator, b, assumptions::OperatorAssumptions)
    return defaultalg(A.A, b, assumptions)
end
2 | 2 |
|
# Ambiguity handling: wrapped operator with unknown squareness. Unwrap and
# re-dispatch so the `OperatorAssumptions{nothing}` path can inspect `A.A`.
function defaultalg(A::DiffEqArrayOperator, b, assumptions::OperatorAssumptions{nothing})
    return defaultalg(A.A, b, assumptions)
end
28 | 5 |
|
# No squareness assumption was provided: determine it from the operator's
# dimensions, then re-dispatch with a concrete assumption.
function defaultalg(A, b, ::OperatorAssumptions{nothing})
    issquare = size(A, 1) == size(A, 2)
    return defaultalg(A, b, OperatorAssumptions(Val(issquare)))
end
41 | 10 |
|
# These few cases ensure the choice is optimal without the dynamic
# dispatching cost of `factorize`.
function defaultalg(A::Tridiagonal, b, ::OperatorAssumptions{true})
    return GenericFactorization(; fact_alg = lu!)
end
function defaultalg(A::Tridiagonal, b, ::OperatorAssumptions{false})
    return GenericFactorization(; fact_alg = qr!)
end
function defaultalg(A::SymTridiagonal, b, ::OperatorAssumptions{true})
    return GenericFactorization(; fact_alg = ldlt!)
end
46 | 14 |
|
# Sparse square systems: KLU tends to win on smaller problems,
# UMFPACK on larger ones (threshold chosen from benchmarks).
function defaultalg(A::SparseMatrixCSC, b, ::OperatorAssumptions{true})
    return length(b) <= 10_000 ? KLUFactorization() : UMFPACKFactorization()
end
55 | 22 |
|
# This catches the case where A is a GPU matrix (e.g. CuMatrix), which does
# not have LU fully defined before Julia 1.8 — fall back to QR there.
function defaultalg(A::GPUArraysCore.AbstractGPUArray, b, ::OperatorAssumptions{true})
    return VERSION >= v"1.8-" ? LUFactorization() : QRFactorization()
end
62 | 30 |
|
# Same GPU handling when only `b` lives on the GPU: LU needs Julia >= 1.8,
# otherwise QR is the safe choice.
function defaultalg(A, b::GPUArraysCore.AbstractGPUArray, ::OperatorAssumptions{true})
    return VERSION >= v"1.8-" ? LUFactorization() : QRFactorization()
end
64 | 38 |
|
# Handle ambiguity: both A and b on the GPU must be more specific than
# either single-GPU method above.
function defaultalg(A::GPUArraysCore.AbstractGPUArray, b::GPUArraysCore.AbstractGPUArray,
                    ::OperatorAssumptions{true})
    return VERSION >= v"1.8-" ? LUFactorization() : QRFactorization()
end
# Non-square GPU problems always go through QR (least-squares capable).
defaultalg(A::GPUArraysCore.AbstractGPUArray, b, ::OperatorAssumptions{false}) = QRFactorization()
defaultalg(A, b::GPUArraysCore.AbstractGPUArray, ::OperatorAssumptions{false}) = QRFactorization()

# Handle ambiguity: both arguments on the GPU.
function defaultalg(A::GPUArraysCore.AbstractGPUArray, b::GPUArraysCore.AbstractGPUArray,
                    ::OperatorAssumptions{false})
    return QRFactorization()
end
| 60 | + |
# Allows A === nothing as a stand-in for dense matrix.
# NOTE(fix): was `::Assumptions{true}`, an undefined type — every sibling
# method dispatches on `OperatorAssumptions{true}`.
function defaultalg(A, b, ::OperatorAssumptions{true})
    # Special case on Arrays: avoid BLAS for RecursiveFactorization.jl when
    # it makes sense according to the benchmarks, which is dependent on
    # whether MKL or OpenBLAS is being used
    if (A === nothing && !(b isa GPUArraysCore.AbstractGPUArray)) || A isa Matrix
        if (A === nothing || eltype(A) <: Union{Float32, Float64, ComplexF32, ComplexF64}) &&
           ArrayInterfaceCore.can_setindex(b)
            if length(b) <= 10
                alg = GenericLUFactorization()
            elseif (length(b) <= 100 || (isopenblas() && length(b) <= 500)) &&
                   eltype(A) <: Union{Float32, Float64}
                alg = RFLUFactorization()
            #elseif A === nothing || A isa Matrix
            #    alg = FastLUFactorization()
            else
                alg = LUFactorization()
            end
        else
            # b cannot be mutated in place (or eltype unsupported): plain LU.
            alg = LUFactorization()
        end

    # This catches the cases where a factorization overload could exist,
    # for example BlockBandedMatrix.
    elseif A !== nothing && ArrayInterfaceCore.isstructured(A)
        alg = GenericFactorization()

    # Not a factorizable operator: default to only using A*x via GMRES.
    else
        alg = KrylovJL_GMRES()
    end
    return alg
end
139 | 94 |
|
# Generic non-square problems default to QR, which handles least-squares.
# NOTE(fix): was dispatching on `::Val{false}`, which is unreachable — the
# `OperatorAssumptions{nothing}` path re-dispatches with
# `OperatorAssumptions(Val(issquare))`, so this method must take
# `::OperatorAssumptions{false}` to be found.
function defaultalg(A, b, ::OperatorAssumptions{false})
    return QRFactorization()
end
## Catch high level interface

# Entry point when no algorithm is specified: pick one via `defaultalg`
# (NOTE(fix): was calling `default_alg`, which does not exist) and solve.
function SciMLBase.solve(cache::LinearCache, alg::Nothing,
                         args...; assumptions::OperatorAssumptions = OperatorAssumptions(),
                         kwargs...)
    @unpack A, b = cache
    return SciMLBase.solve(cache, defaultalg(A, b, assumptions), args...; kwargs...)
end
192 | 106 |
|
# Cache initialization when no algorithm is specified: resolve the default
# algorithm first, then initialize its cache.
# NOTE(fix): was `default_alg(A,b)` — wrong name (no underscore) and it
# dropped the `assumptions` argument that every `defaultalg` method requires.
function init_cacheval(alg::Nothing, A, b, u, Pl, Pr, maxiters::Int, abstol, reltol,
                       verbose::Bool, assumptions::OperatorAssumptions)
    return init_cacheval(defaultalg(A, b, assumptions), A, b, u, Pl, Pr, maxiters,
                         abstol, reltol, verbose, assumptions)
end
0 commit comments