|
1 | | -module LinearSolveCUDAExt |
2 | | - |
3 | | -using CUDA |
4 | | -using LinearSolve |
5 | | -using LinearSolve.LinearAlgebra, LinearSolve.SciMLBase, LinearSolve.ArrayInterface |
6 | | -using SciMLBase: AbstractSciMLOperator |
7 | | - |
8 | | -function LinearSolve.defaultalg(A::CUDA.CUSPARSE.CuSparseMatrixCSR{Tv, Ti}, b, |
9 | | - assump::OperatorAssumptions{Bool}) where {Tv, Ti} |
10 | | - if LinearSolve.cudss_loaded(A) |
11 | | - LinearSolve.DefaultLinearSolver(LinearSolve.DefaultAlgorithmChoice.LUFactorization) |
12 | | - else |
13 | | - if !LinearSolve.ALREADY_WARNED_CUDSS[] |
14 | | - @warn("CUDSS.jl is required for LU Factorizations on CuSparseMatrixCSR. Please load this library. Falling back to Krylov") |
15 | | - LinearSolve.ALREADY_WARNED_CUDSS[] = true |
16 | | - end |
17 | | - LinearSolve.DefaultLinearSolver(LinearSolve.DefaultAlgorithmChoice.KrylovJL_GMRES) |
18 | | - end |
19 | | -end |
20 | | - |
21 | | -function LinearSolve.error_no_cudss_lu(A::CUDA.CUSPARSE.CuSparseMatrixCSR) |
22 | | - if !LinearSolve.CUDSS_LOADED[] |
23 | | - error("CUDSS.jl is required for LU Factorizations on CuSparseMatrixCSR. Please load this library.") |
24 | | - end |
25 | | - nothing |
26 | | -end |
27 | | - |
28 | | -function SciMLBase.solve!(cache::LinearSolve.LinearCache, alg::CudaOffloadFactorization; |
29 | | - kwargs...) |
30 | | - if cache.isfresh |
31 | | - fact = qr(CUDA.CuArray(cache.A)) |
32 | | - cache.cacheval = fact |
33 | | - cache.isfresh = false |
34 | | - end |
35 | | - y = Array(ldiv!(CUDA.CuArray(cache.u), cache.cacheval, CUDA.CuArray(cache.b))) |
36 | | - cache.u .= y |
37 | | - SciMLBase.build_linear_solution(alg, y, nothing, cache) |
38 | | -end |
39 | | - |
40 | | -function LinearSolve.init_cacheval(alg::CudaOffloadFactorization, A, b, u, Pl, Pr, |
41 | | - maxiters::Int, abstol, reltol, verbose::Bool, |
42 | | - assumptions::OperatorAssumptions) |
43 | | - qr(CUDA.CuArray(A)) |
44 | | -end |
45 | | - |
46 | | -function LinearSolve.init_cacheval(::SparspakFactorization, A::CUDA.CUSPARSE.CuSparseMatrixCSR, b, u, |
47 | | - Pl, Pr, maxiters::Int, abstol, reltol, verbose::Bool, assumptions::OperatorAssumptions) |
48 | | - nothing |
49 | | -end |
50 | | - |
51 | | -function LinearSolve.init_cacheval(::KLUFactorization, A::CUDA.CUSPARSE.CuSparseMatrixCSR, b, u, |
52 | | - Pl, Pr, maxiters::Int, abstol, reltol, verbose::Bool, assumptions::OperatorAssumptions) |
53 | | - nothing |
54 | | -end |
55 | | - |
56 | | -function LinearSolve.init_cacheval(::UMFPACKFactorization, A::CUDA.CUSPARSE.CuSparseMatrixCSR, b, u, |
57 | | - Pl, Pr, maxiters::Int, abstol, reltol, verbose::Bool, assumptions::OperatorAssumptions) |
58 | | - nothing |
59 | | -end |
60 | | - |
61 | | -end |
| 1 | +module LinearSolveCUDAExt |
| 2 | + |
| 3 | +using CUDA |
| 4 | +using LinearSolve |
| 5 | +using LinearSolve.LinearAlgebra, LinearSolve.SciMLBase, LinearSolve.ArrayInterface |
| 6 | +using SciMLBase: AbstractSciMLOperator |
| 7 | + |
| 8 | +function LinearSolve.defaultalg(A::CUDA.CUSPARSE.CuSparseMatrixCSR{Tv, Ti}, b, |
| 9 | + assump::OperatorAssumptions{Bool}) where {Tv, Ti} |
| 10 | + if LinearSolve.cudss_loaded(A) |
| 11 | + LinearSolve.DefaultLinearSolver(LinearSolve.DefaultAlgorithmChoice.LUFactorization) |
| 12 | + else |
| 13 | + if !LinearSolve.ALREADY_WARNED_CUDSS[] |
| 14 | + @warn("CUDSS.jl is required for LU Factorizations on CuSparseMatrixCSR. Please load this library. Falling back to Krylov") |
| 15 | + LinearSolve.ALREADY_WARNED_CUDSS[] = true |
| 16 | + end |
| 17 | + LinearSolve.DefaultLinearSolver(LinearSolve.DefaultAlgorithmChoice.KrylovJL_GMRES) |
| 18 | + end |
| 19 | +end |
| 20 | + |
| 21 | +function LinearSolve.error_no_cudss_lu(A::CUDA.CUSPARSE.CuSparseMatrixCSR) |
| 22 | + if !LinearSolve.CUDSS_LOADED[] |
| 23 | + error("CUDSS.jl is required for LU Factorizations on CuSparseMatrixCSR. Please load this library.") |
| 24 | + end |
| 25 | + nothing |
| 26 | +end |
| 27 | + |
| 28 | +function SciMLBase.solve!(cache::LinearSolve.LinearCache, alg::CudaOffloadFactorization; |
| 29 | + kwargs...) |
| 30 | + if cache.isfresh |
| 31 | + fact = qr(CUDA.CuArray(cache.A)) |
| 32 | + cache.cacheval = fact |
| 33 | + cache.isfresh = false |
| 34 | + end |
| 35 | + y = Array(ldiv!(CUDA.CuArray(cache.u), cache.cacheval, CUDA.CuArray(cache.b))) |
| 36 | + cache.u .= y |
| 37 | + SciMLBase.build_linear_solution(alg, y, nothing, cache) |
| 38 | +end |
| 39 | + |
| 40 | +function LinearSolve.init_cacheval(alg::CudaOffloadFactorization, A, b, u, Pl, Pr, |
| 41 | + maxiters::Int, abstol, reltol, verbose::Bool, |
| 42 | + assumptions::OperatorAssumptions) |
| 43 | + qr(CUDA.CuArray(A)) |
| 44 | +end |
| 45 | + |
| 46 | +function LinearSolve.init_cacheval( |
| 47 | + ::SparspakFactorization, A::CUDA.CUSPARSE.CuSparseMatrixCSR, b, u, |
| 48 | + Pl, Pr, maxiters::Int, abstol, reltol, verbose::Bool, assumptions::OperatorAssumptions) |
| 49 | + nothing |
| 50 | +end |
| 51 | + |
| 52 | +function LinearSolve.init_cacheval( |
| 53 | + ::KLUFactorization, A::CUDA.CUSPARSE.CuSparseMatrixCSR, b, u, |
| 54 | + Pl, Pr, maxiters::Int, abstol, reltol, verbose::Bool, assumptions::OperatorAssumptions) |
| 55 | + nothing |
| 56 | +end |
| 57 | + |
| 58 | +function LinearSolve.init_cacheval( |
| 59 | + ::UMFPACKFactorization, A::CUDA.CUSPARSE.CuSparseMatrixCSR, b, u, |
| 60 | + Pl, Pr, maxiters::Int, abstol, reltol, verbose::Bool, assumptions::OperatorAssumptions) |
| 61 | + nothing |
| 62 | +end |
| 63 | + |
| 64 | +end |
0 commit comments