@@ -8,6 +8,46 @@ to avoid allocations and does not require libblastrampoline.
88"""
99struct MKLLUFactorization <: AbstractFactorization end
1010
11+ function getrf! (A:: AbstractMatrix{<:ComplexF64} ;
12+ ipiv = similar (A, BlasInt, min (size (A, 1 ), size (A, 2 ))),
13+ info = Ref {BlasInt} (),
14+ check = false )
15+ require_one_based_indexing (A)
16+ check && chkfinite (A)
17+ chkstride1 (A)
18+ m, n = size (A)
19+ lda = max (1 , stride (A, 2 ))
20+ if isempty (ipiv)
21+ ipiv = similar (A, BlasInt, min (size (A, 1 ), size (A, 2 )))
22+ end
23+ ccall ((@blasfunc (zgetrf_), MKL_jll. libmkl_rt), Cvoid,
24+ (Ref{BlasInt}, Ref{BlasInt}, Ptr{ComplexF64},
25+ Ref{BlasInt}, Ptr{BlasInt}, Ptr{BlasInt}),
26+ m, n, A, lda, ipiv, info)
27+ chkargsok (info[])
28+ A, ipiv, info[], info # Error code is stored in LU factorization type
29+ end
30+
31+ function getrf! (A:: AbstractMatrix{<:ComplexF32} ;
32+ ipiv = similar (A, BlasInt, min (size (A, 1 ), size (A, 2 ))),
33+ info = Ref {BlasInt} (),
34+ check = false )
35+ require_one_based_indexing (A)
36+ check && chkfinite (A)
37+ chkstride1 (A)
38+ m, n = size (A)
39+ lda = max (1 , stride (A, 2 ))
40+ if isempty (ipiv)
41+ ipiv = similar (A, BlasInt, min (size (A, 1 ), size (A, 2 )))
42+ end
43+ ccall ((@blasfunc (cgetrf_), MKL_jll. libmkl_rt), Cvoid,
44+ (Ref{BlasInt}, Ref{BlasInt}, Ptr{ComplexF32},
45+ Ref{BlasInt}, Ptr{BlasInt}, Ptr{BlasInt}),
46+ m, n, A, lda, ipiv, info)
47+ chkargsok (info[])
48+ A, ipiv, info[], info # Error code is stored in LU factorization type
49+ end
50+
1151function getrf! (A:: AbstractMatrix{<:Float64} ;
1252 ipiv = similar (A, BlasInt, min (size (A, 1 ), size (A, 2 ))),
1353 info = Ref {BlasInt} (),
@@ -48,6 +88,56 @@ function getrf!(A::AbstractMatrix{<:Float32};
4888 A, ipiv, info[], info # Error code is stored in LU factorization type
4989end
5090
91+ function getrs! (trans:: AbstractChar ,
92+ A:: AbstractMatrix{<:ComplexF64} ,
93+ ipiv:: AbstractVector{BlasInt} ,
94+ B:: AbstractVecOrMat{<:ComplexF64} ;
95+ info = Ref {BlasInt} ())
96+ require_one_based_indexing (A, ipiv, B)
97+ LinearAlgebra. LAPACK. chktrans (trans)
98+ chkstride1 (A, B, ipiv)
99+ n = LinearAlgebra. checksquare (A)
100+ if n != size (B, 1 )
101+ throw (DimensionMismatch (" B has leading dimension $(size (B,1 )) , but needs $n " ))
102+ end
103+ if n != length (ipiv)
104+ throw (DimensionMismatch (" ipiv has length $(length (ipiv)) , but needs to be $n " ))
105+ end
106+ nrhs = size (B, 2 )
107+ ccall ((" zgetrs_" , MKL_jll. libmkl_rt), Cvoid,
108+ (Ref{UInt8}, Ref{BlasInt}, Ref{BlasInt}, Ptr{ComplexF64}, Ref{BlasInt},
109+ Ptr{BlasInt}, Ptr{ComplexF64}, Ref{BlasInt}, Ptr{BlasInt}, Clong),
110+ trans, n, size (B, 2 ), A, max (1 , stride (A, 2 )), ipiv, B, max (1 , stride (B, 2 )), info,
111+ 1 )
112+ LinearAlgebra. LAPACK. chklapackerror (BlasInt (info[]))
113+ B
114+ end
115+
116+ function getrs! (trans:: AbstractChar ,
117+ A:: AbstractMatrix{<:ComplexF32} ,
118+ ipiv:: AbstractVector{BlasInt} ,
119+ B:: AbstractVecOrMat{<:ComplexF32} ;
120+ info = Ref {BlasInt} ())
121+ require_one_based_indexing (A, ipiv, B)
122+ LinearAlgebra. LAPACK. chktrans (trans)
123+ chkstride1 (A, B, ipiv)
124+ n = LinearAlgebra. checksquare (A)
125+ if n != size (B, 1 )
126+ throw (DimensionMismatch (" B has leading dimension $(size (B,1 )) , but needs $n " ))
127+ end
128+ if n != length (ipiv)
129+ throw (DimensionMismatch (" ipiv has length $(length (ipiv)) , but needs to be $n " ))
130+ end
131+ nrhs = size (B, 2 )
132+ ccall ((" cgetrs_" , MKL_jll. libmkl_rt), Cvoid,
133+ (Ref{UInt8}, Ref{BlasInt}, Ref{BlasInt}, Ptr{ComplexF32}, Ref{BlasInt},
134+ Ptr{BlasInt}, Ptr{ComplexF32}, Ref{BlasInt}, Ptr{BlasInt}, Clong),
135+ trans, n, size (B, 2 ), A, max (1 , stride (A, 2 )), ipiv, B, max (1 , stride (B, 2 )), info,
136+ 1 )
137+ LinearAlgebra. LAPACK. chklapackerror (BlasInt (info[]))
138+ B
139+ end
140+
51141function getrs! (trans:: AbstractChar ,
52142 A:: AbstractMatrix{<:Float64} ,
53143 ipiv:: AbstractVector{BlasInt} ,
@@ -106,12 +196,19 @@ const PREALLOCATED_MKL_LU = begin
106196 luinst = ArrayInterface. lu_instance (A), Ref {BlasInt} ()
107197end
108198
109- function init_cacheval (alg:: MKLLUFactorization , A, b, u, Pl, Pr,
199+ function LinearSolve . init_cacheval (alg:: MKLLUFactorization , A, b, u, Pl, Pr,
110200 maxiters:: Int , abstol, reltol, verbose:: Bool ,
111201 assumptions:: OperatorAssumptions )
112202 PREALLOCATED_MKL_LU
113203end
114204
205+ function LinearSolve. init_cacheval (alg:: MKLLUFactorization , A:: AbstractMatrix{<:Union{Float32,ComplexF32,ComplexF64}} , b, u, Pl, Pr,
206+ maxiters:: Int , abstol, reltol, verbose:: Bool ,
207+ assumptions:: OperatorAssumptions )
208+ A = rand (eltype (A), 0 , 0 )
209+ ArrayInterface. lu_instance (A), Ref {BlasInt} ()
210+ end
211+
115212function SciMLBase. solve! (cache:: LinearCache , alg:: MKLLUFactorization ;
116213 kwargs... )
117214 A = cache. A
0 commit comments