|
1 | 1 | module LinearSolve |
2 | 2 |
|
3 | | -using Base: cache_dependencies, Bool |
4 | | -using SciMLBase: AbstractLinearAlgorithm, AbstractDiffEqOperator |
5 | 3 | using ArrayInterface: lu_instance |
6 | | -using UnPack |
7 | | -using Reexport |
| 4 | +using Base: cache_dependencies, Bool |
| 5 | +using Krylov |
8 | 6 | using LinearAlgebra |
| 7 | +using Reexport |
| 8 | +using SciMLBase: AbstractDiffEqOperator, AbstractLinearAlgorithm |
9 | 9 | using Setfield |
10 | | -@reexport using SciMLBase |
11 | | - |
12 | | -export LUFactorization, QRFactorization, SVDFactorization |
13 | | - |
14 | | -#mutable?# |
15 | | -struct LinearCache{TA,Tb,Tp,Talg,Tc,Tr,Tl} |
16 | | - A::TA |
17 | | - b::Tb |
18 | | - p::Tp |
19 | | - alg::Talg |
20 | | - cacheval::Tc |
21 | | - isfresh::Bool |
22 | | - Pr::Tr |
23 | | - Pl::Tl |
24 | | -end |
25 | | - |
26 | | -function set_A(cache, A) |
27 | | - @set! cache.A = A |
28 | | - @set! cache.isfresh = true |
29 | | -end |
30 | | - |
31 | | -function set_b(cache, b) |
32 | | - @set! cache.b = b |
33 | | -end |
34 | | - |
35 | | -function set_p(cache, p) |
36 | | - @set! cache.p = p |
37 | | - # @set! cache.isfresh = true |
38 | | -end |
39 | | - |
40 | | -function set_cacheval(cache::LinearCache,alg) |
41 | | - if cache.isfresh |
42 | | - @set! cache.cacheval = alg |
43 | | - @set! cache.isfresh = false |
44 | | - end |
45 | | - return cache |
46 | | -end |
47 | | - |
48 | | -function SciMLBase.init(prob::LinearProblem, alg; |
49 | | - alias_A = false, alias_b = false, |
50 | | - kwargs...) |
51 | | - @unpack A, b, p = prob |
52 | | - if alg isa LUFactorization |
53 | | - fact = lu_instance(A) |
54 | | - Tfact = typeof(fact) |
55 | | - else |
56 | | - fact = nothing |
57 | | - Tfact = Any |
58 | | - end |
59 | | - Pr = nothing |
60 | | - Pl = nothing |
61 | | - |
62 | | - A = alias_A ? A : copy(A) |
63 | | - b = alias_b ? b : copy(b) |
64 | | - |
65 | | - cache = LinearCache{typeof(A),typeof(b),typeof(p),typeof(alg),Tfact,typeof(Pr),typeof(Pl)}( |
66 | | - A, b, p, alg, fact, true, Pr, Pl |
67 | | - ) |
68 | | - return cache |
69 | | -end |
70 | | - |
71 | | -SciMLBase.solve(prob::LinearProblem, alg; kwargs...) = solve(init(prob, alg; kwargs...)) |
72 | | -SciMLBase.solve(cache) = solve(cache, cache.alg) |
73 | | - |
74 | | -struct LUFactorization{P} <: AbstractLinearAlgorithm |
75 | | - pivot::P |
76 | | -end |
77 | | -function LUFactorization() |
78 | | - pivot = @static if VERSION < v"1.7beta" |
79 | | - Val(true) |
80 | | - else |
81 | | - RowMaximum() |
82 | | - end |
83 | | - LUFactorization(pivot) |
84 | | -end |
85 | | - |
86 | | -function SciMLBase.solve(cache::LinearCache, alg::LUFactorization) |
87 | | - cache.A isa Union{AbstractMatrix, AbstractDiffEqOperator} || error("LU is not defined for $(typeof(prob.A))") |
88 | | - cache = set_cacheval(cache,lu!(cache.A, alg.pivot)) |
89 | | - ldiv!(cache.cacheval, cache.b) |
90 | | -end |
| 10 | +using UnPack |
91 | 11 |
|
92 | | -struct QRFactorization{P} <: AbstractLinearAlgorithm |
93 | | - pivot::P |
94 | | - blocksize::Int |
95 | | -end |
96 | | -function QRFactorization() |
97 | | - pivot = @static if VERSION < v"1.7beta" |
98 | | - Val(false) |
99 | | - else |
100 | | - NoPivot() |
101 | | - end |
102 | | - QRFactorization(pivot, 16) |
103 | | -end |
| 12 | +@reexport using SciMLBase |
104 | 13 |
|
105 | | -function SciMLBase.solve(cache::LinearCache, alg::QRFactorization) |
106 | | - cache.A isa Union{AbstractMatrix, AbstractDiffEqOperator} || error("QR is not defined for $(typeof(prob.A))") |
107 | | - cache = set_cacheval(cache,qr!(cache.A.A, alg.pivot; blocksize=alg.blocksize)) |
108 | | - ldiv!(cache.cacheval, cache.b) |
109 | | -end |
| 14 | +abstract type SciMLLinearSolveAlgorithm end |
110 | 15 |
|
111 | | -struct SVDFactorization{A} <: AbstractLinearAlgorithm |
112 | | - full::Bool |
113 | | - alg::A |
114 | | -end |
115 | | -SVDFactorization() = SVDFactorization(false, LinearAlgebra.DivideAndConquer()) |
| 16 | +include("common.jl") |
| 17 | +include("factorization.jl") |
| 18 | +include("krylov.jl") |
116 | 19 |
|
117 | | -function SciMLBase.solve(cache::LinearCache, alg::SVDFactorization) |
118 | | - cache.A isa Union{AbstractMatrix, AbstractDiffEqOperator} || error("SVD is not defined for $(typeof(prob.A))") |
119 | | - cache = set_cacheval(cache,svd!(cache.A; full=alg.full, alg=alg.alg)) |
120 | | - ldiv!(cache.cacheval, cache.b) |
121 | | -end |
| 20 | +export LUFactorization, SVDFactorization, QRFactorization |
| 21 | +export KrylovJL |
122 | 22 |
|
123 | 23 | end |
0 commit comments