@@ -78,6 +78,8 @@ function do_factorization(alg::LUFactorization, A, b, u)
7878 if A isa AbstractSparseMatrixCSC
7979 return lu (SparseMatrixCSC (size (A)... , getcolptr (A), rowvals (A), nonzeros (A)),
8080 check = false )
81+ elseif A isa GPUArraysCore. AnyGPUArray
82+ fact = lu (A; check = false )
8183 elseif ! ArrayInterface. can_setindex (typeof (A))
8284 fact = lu (A, alg. pivot, check = false )
8385 else
@@ -98,6 +100,17 @@ function init_cacheval(alg::Union{LUFactorization, GenericLUFactorization}, A, b
98100 ArrayInterface. lu_instance (convert (AbstractMatrix, A))
99101end
100102
103+ function init_cacheval (alg:: Union{LUFactorization, GenericLUFactorization} ,
104+ A:: Union{<:Adjoint, <:Transpose} , b, u, Pl, Pr, maxiters:: Int , abstol, reltol,
105+ verbose:: Bool , assumptions:: OperatorAssumptions )
106+ if alg isa LUFactorization
107+ return lu (A; check= false )
108+ else
109+ A isa GPUArraysCore. AnyGPUArray && return nothing
110+ return LinearAlgebra. generic_lufact! (copy (A), alg. pivot; check= false )
111+ end
112+ end
113+
101114const PREALLOCATED_LU = ArrayInterface. lu_instance (rand (1 , 1 ))
102115
103116function init_cacheval (alg:: Union{LUFactorization, GenericLUFactorization} ,
143156function do_factorization (alg:: QRFactorization , A, b, u)
144157 A = convert (AbstractMatrix, A)
145158 if ArrayInterface. can_setindex (typeof (A))
146- if alg. inplace && ! (A isa SparseMatrixCSC) && ! (A isa GPUArraysCore. AbstractGPUArray )
159+ if alg. inplace && ! (A isa SparseMatrixCSC) && ! (A isa GPUArraysCore. AnyGPUArray )
147160 fact = qr! (A, alg. pivot)
148161 else
149162 fact = qr (A) # CUDA.jl does not allow other args!
@@ -160,6 +173,12 @@ function init_cacheval(alg::QRFactorization, A, b, u, Pl, Pr,
160173 ArrayInterface. qr_instance (convert (AbstractMatrix, A), alg. pivot)
161174end
162175
176+ function init_cacheval (alg:: QRFactorization , A:: Union{<:Adjoint, <:Transpose} , b, u, Pl, Pr,
177+ maxiters:: Int , abstol, reltol, verbose:: Bool , assumptions:: OperatorAssumptions )
178+ A isa GPUArraysCore. AnyGPUArray && return qr (A)
179+ return qr (A, alg. pivot)
180+ end
181+
163182const PREALLOCATED_QR = ArrayInterface. qr_instance (rand (1 , 1 ))
164183
165184function init_cacheval (alg:: QRFactorization{NoPivot} , A:: Matrix{Float64} , b, u, Pl, Pr,
@@ -204,6 +223,8 @@ function do_factorization(alg::CholeskyFactorization, A, b, u)
204223 A = convert (AbstractMatrix, A)
205224 if A isa SparseMatrixCSC
206225 fact = cholesky (A; shift = alg. shift, check = false , perm = alg. perm)
226+ elseif A isa GPUArraysCore. AnyGPUArray
227+ fact = cholesky (A; check = false )
207228 elseif alg. pivot === Val (false ) || alg. pivot === NoPivot ()
208229 fact = cholesky! (A, alg. pivot; check = false )
209230 else
@@ -218,9 +239,13 @@ function init_cacheval(alg::CholeskyFactorization, A::SMatrix{S1, S2}, b, u, Pl,
218239 cholesky (A)
219240end
220241
242+ function init_cacheval (alg:: CholeskyFactorization , A:: GPUArraysCore.AnyGPUArray , b, u, Pl,
243+ Pr, maxiters:: Int , abstol, reltol, verbose:: Bool , assumptions:: OperatorAssumptions )
244+ cholesky (A; check= false )
245+ end
246+
221247function init_cacheval (alg:: CholeskyFactorization , A, b, u, Pl, Pr,
222- maxiters:: Int , abstol, reltol, verbose:: Bool ,
223- assumptions:: OperatorAssumptions )
248+ maxiters:: Int , abstol, reltol, verbose:: Bool , assumptions:: OperatorAssumptions )
224249 ArrayInterface. cholesky_instance (convert (AbstractMatrix, A), alg. pivot)
225250end
226251
@@ -968,7 +993,7 @@ default_alias_b(::NormalCholeskyFactorization, ::Any, ::Any) = true
968993const PREALLOCATED_NORMALCHOLESKY = ArrayInterface. cholesky_instance (rand (1 , 1 ), NoPivot ())
969994
970995function init_cacheval (alg:: NormalCholeskyFactorization ,
971- A:: Union {AbstractSparseArray, GPUArraysCore. AbstractGPUArray ,
996+ A:: Union {AbstractSparseArray, GPUArraysCore. AnyGPUArray ,
972997 Symmetric{<: Number , <: AbstractSparseArray }}, b, u, Pl, Pr,
973998 maxiters:: Int , abstol, reltol, verbose:: Bool ,
974999 assumptions:: OperatorAssumptions )
@@ -999,7 +1024,7 @@ function SciMLBase.solve!(cache::LinearCache, alg::NormalCholeskyFactorization;
9991024 A = cache. A
10001025 A = convert (AbstractMatrix, A)
10011026 if cache. isfresh
1002- if A isa SparseMatrixCSC || A isa GPUArraysCore. AbstractGPUArray || A isa SMatrix
1027+ if A isa SparseMatrixCSC || A isa GPUArraysCore. AnyGPUArray || A isa SMatrix
10031028 fact = cholesky (Symmetric ((A)' * A); check = false )
10041029 else
10051030 fact = cholesky (Symmetric ((A)' * A), alg. pivot; check = false )
0 commit comments