@@ -9,7 +9,7 @@ function defaultalg(A,b)
99 # Special case on Arrays: avoid BLAS for RecursiveFactorization.jl when
1010 # it makes sense according to the benchmarks, which is dependent on
1111 # whether MKL or OpenBLAS is being used
12- if A === nothing || A isa Matrix
12+ if ( A === nothing && ! isgpu (b)) || A isa Matrix
1313 if (A === nothing || eltype (A) <: Union{Float32,Float64,ComplexF32,ComplexF64} ) &&
1414 ArrayInterface. can_setindex (b) && (length (b) <= 100 ||
1515 (isopenblas () && length (b) <= 500 )
@@ -30,18 +30,15 @@ function defaultalg(A,b)
3030
3131 # This catches the cases where a factorization overload could exist
3232 # For example, BlockBandedMatrix
33- elseif ArrayInterface. isstructured (A)
33+ elseif A != = nothing && ArrayInterface. isstructured (A)
3434 alg = GenericFactorization ()
3535
3636 # This catches the case where A is a CuMatrix
3737 # Which does not have LU fully defined
38- elseif ! (A isa AbstractDiffEqOperator )
38+ elseif isgpu (A) || isgpu (b )
3939 alg = QRFactorization (false )
4040
4141 # Not factorizable operator, default to only using A*x
42- # IterativeSolvers is faster on CPU but not GPU-compatible
43- elseif cache. u isa Array
44- alg = IterativeSolversJL_GMRES ()
4542 else
4643 alg = KrylovJL_GMRES ()
4744 end
@@ -92,15 +89,12 @@ function SciMLBase.solve(cache::LinearCache, alg::Nothing,
9289
9390 # This catches the case where A is a CuMatrix
9491 # Which does not have LU fully defined
95- elseif ! (A isa AbstractDiffEqOperator )
92+ elseif isgpu (A )
9693 alg = QRFactorization (false )
9794 SciMLBase. solve (cache, alg, args... ; kwargs... )
9895
9996 # Not factorizable operator, default to only using A*x
10097 # IterativeSolvers is faster on CPU but not GPU-compatible
101- elseif cache. u isa Array
102- alg = IterativeSolversJL_GMRES ()
103- SciMLBase. solve (cache, alg, args... ; kwargs... )
10498 else
10599 alg = KrylovJL_GMRES ()
106100 SciMLBase. solve (cache, alg, args... ; kwargs... )
@@ -147,15 +141,12 @@ function init_cacheval(alg::Nothing, A, b, u, Pl, Pr, maxiters, abstol, reltol,
147141
148142 # This catches the case where A is a CuMatrix
149143 # Which does not have LU fully defined
150- elseif ! (A isa AbstractDiffEqOperator )
144+ elseif isgpu (A )
151145 alg = QRFactorization (false )
152146 init_cacheval (alg, A, b, u, Pl, Pr, maxiters, abstol, reltol, verbose)
153147
154148 # Not factorizable operator, default to only using A*x
155149 # IterativeSolvers is faster on CPU but not GPU-compatible
156- elseif u isa Array
157- alg = IterativeSolversJL_GMRES ()
158- init_cacheval (alg, A, b, u, Pl, Pr, maxiters, abstol, reltol, verbose)
159150 else
160151 alg = KrylovJL_GMRES ()
161152 init_cacheval (alg, A, b, u, Pl, Pr, maxiters, abstol, reltol, verbose)
0 commit comments