11using LoopVectorization
22using TriangularSolve: ldiv!
33using LinearAlgebra: BlasInt, BlasFloat, LU, UnitLowerTriangular, checknonsingular, BLAS,
4- LinearAlgebra, Adjoint, Transpose
4+ LinearAlgebra, Adjoint, Transpose, UpperTriangular, AbstractVecOrMat
55using StrideArraysCore
66using Polyester: @batch
77
@@ -41,16 +41,23 @@ init_pivot(::Val{true}, minmn) = Vector{BlasInt}(undef, minmn)
4141
4242if CUSTOMIZABLE_PIVOT && isdefined (LinearAlgebra, :_ipiv_cols! )
4343 function LinearAlgebra. _ipiv_cols! (:: LU{<:Any, <:Any, NotIPIV} , :: OrdinalRange ,
44- B:: StridedVecOrMat )
44+ B:: StridedVecOrMat )
4545 return B
4646 end
4747end
4848if CUSTOMIZABLE_PIVOT && isdefined (LinearAlgebra, :_ipiv_rows! )
49- function LinearAlgebra. _ipiv_rows! (:: LU{<:Any, <:Any, NotIPIV} , :: OrdinalRange ,
50- B:: StridedVecOrMat )
49+ function LinearAlgebra. _ipiv_rows! (:: (LU{T, <:AbstractMatrix{T}, NotIPIV} where {T}) ,
50+ :: OrdinalRange ,
51+ B:: StridedVecOrMat )
5152 return B
5253 end
5354end
55+ if CUSTOMIZABLE_PIVOT
56+ function LinearAlgebra. ldiv! (A:: LU{T, <:StridedMatrix, <:NotIPIV} ,
57+ B:: StridedVecOrMat{T} ) where {T <: BlasFloat }
58+ ldiv! (UpperTriangular (A. factors), ldiv! (UnitLowerTriangular (A. factors), B))
59+ end
60+ end
5461
5562function lu! (A, pivot = Val (true ), thread = Val (true ); check = true , kwargs... )
5663 m, n = size (A)
@@ -80,11 +87,11 @@ recurse(_) = false
8087_ptrarray (ipiv) = PtrArray (ipiv)
8188_ptrarray (ipiv:: NotIPIV ) = ipiv
8289function lu! (A:: AbstractMatrix{T} , ipiv:: AbstractVector{<:Integer} ,
83- pivot = Val (true ), thread = Val (true );
84- check:: Bool = true ,
85- # the performance is not sensitive wrt blocksize, and 8 is a good default
86- blocksize:: Integer = length (A) ≥ 40_000 ? 8 : 16 ,
87- threshold:: Integer = pick_threshold ()) where {T}
90+ pivot = Val (true ), thread = Val (true );
91+ check:: Bool = true ,
92+ # the performance is not sensitive wrt blocksize, and 8 is a good default
93+ blocksize:: Integer = length (A) ≥ 40_000 ? 8 : 16 ,
94+ threshold:: Integer = pick_threshold ()) where {T}
8895 pivot = normalize_pivot (pivot)
8996 info = zero (BlasInt)
9097 m, n = size (A)
@@ -94,10 +101,12 @@ function lu!(A::AbstractMatrix{T}, ipiv::AbstractVector{<:Integer},
94101 end
95102 if recurse (A) && mnmin > threshold
96103 if T <: Union{Float32, Float64}
97- GC. @preserve ipiv A begin info = recurse! (view (PtrArray (A), axes (A)... ), pivot,
98- m, n, mnmin,
99- _ptrarray (ipiv), info, blocksize,
100- thread) end
104+ GC. @preserve ipiv A begin
105+ info = recurse! (view (PtrArray (A), axes (A)... ), pivot,
106+ m, n, mnmin,
107+ _ptrarray (ipiv), info, blocksize,
108+ thread)
109+ end
101110 else
102111 info = recurse! (A, pivot, m, n, mnmin, ipiv, info, blocksize, thread)
103112 end
@@ -109,7 +118,7 @@ function lu!(A::AbstractMatrix{T}, ipiv::AbstractVector{<:Integer},
109118end
110119
111120@inline function recurse! (A, :: Val{Pivot} , m, n, mnmin, ipiv, info, blocksize,
112- :: Val{true} ) where {Pivot}
121+ :: Val{true} ) where {Pivot}
113122 if length (A) * _sizeof (eltype (A)) >
114123 0.92 * LoopVectorization. VectorizationBase. cache_size (Val (2 ))
115124 _recurse! (A, Val {Pivot} (), m, n, mnmin, ipiv, info, blocksize, Val (true ))
@@ -118,11 +127,11 @@ end
118127 end
119128end
120129@inline function recurse! (A, :: Val{Pivot} , m, n, mnmin, ipiv, info, blocksize,
121- :: Val{false} ) where {Pivot}
130+ :: Val{false} ) where {Pivot}
122131 _recurse! (A, Val {Pivot} (), m, n, mnmin, ipiv, info, blocksize, Val (false ))
123132end
124133@inline function _recurse! (A, :: Val{Pivot} , m, n, mnmin, ipiv, info, blocksize,
125- :: Val{Thread} ) where {Pivot, Thread}
134+ :: Val{Thread} ) where {Pivot, Thread}
126135 info = reckernel! (A, Val (Pivot), m, mnmin, ipiv, info, blocksize, Val (Thread)):: Int
127136 @inbounds if m < n # fat matrix
128137 # [AL AR]
@@ -166,7 +175,7 @@ Base.@propagate_inbounds function apply_permutation!(P, A, ::Val{false})
166175 nothing
167176end
168177function reckernel! (A:: AbstractMatrix{T} , pivot:: Val{Pivot} , m, n, ipiv, info, blocksize,
169- thread):: BlasInt where {T, Pivot}
178+ thread):: BlasInt where {T, Pivot}
170179 @inbounds begin
171180 if n <= max (blocksize, 1 )
172181 info = _generic_lufact! (A, Val (Pivot), ipiv, info)
@@ -262,44 +271,46 @@ end
262271function _generic_lufact! (A, :: Val{Pivot} , ipiv, info) where {Pivot}
263272 m, n = size (A)
264273 minmn = length (ipiv)
265- @inbounds begin for k in 1 : minmn
266- # find index max
267- kp = k
268- if Pivot
269- amax = abs (zero (eltype (A)))
270- for i in k: m
271- absi = abs (A[i, k])
272- if absi > amax
273- kp = i
274- amax = absi
274+ @inbounds begin
275+ for k in 1 : minmn
276+ # find index max
277+ kp = k
278+ if Pivot
279+ amax = abs (zero (eltype (A)))
280+ for i in k: m
281+ absi = abs (A[i, k])
282+ if absi > amax
283+ kp = i
284+ amax = absi
285+ end
275286 end
287+ ipiv[k] = kp
276288 end
277- ipiv[k] = kp
278- end
279- if ! iszero (A[kp, k])
280- if k != kp
281- # Interchange
282- @simd for i in 1 : n
283- tmp = A[k, i]
284- A[k, i] = A[kp, i]
285- A[kp, i] = tmp
289+ if ! iszero (A[kp, k])
290+ if k != kp
291+ # Interchange
292+ @simd for i in 1 : n
293+ tmp = A[k, i]
294+ A[k, i] = A[kp, i]
295+ A[kp, i] = tmp
296+ end
286297 end
298+ # Scale first column
299+ Akkinv = inv (A[k, k])
300+ @turbo check_empty= true warn_check_args= false for i in (k + 1 ): m
301+ A[i, k] *= Akkinv
302+ end
303+ elseif info == 0
304+ info = k
287305 end
288- # Scale first column
289- Akkinv = inv (A[k, k])
290- @turbo check_empty= true warn_check_args= false for i in (k + 1 ): m
291- A[i, k] *= Akkinv
292- end
293- elseif info == 0
294- info = k
295- end
296- k == minmn && break
297- # Update the rest
298- @turbo warn_check_args= false for j in (k + 1 ): n
299- for i in (k + 1 ): m
300- A[i, j] -= A[i, k] * A[k, j]
306+ k == minmn && break
307+ # Update the rest
308+ @turbo warn_check_args= false for j in (k + 1 ): n
309+ for i in (k + 1 ): m
310+ A[i, j] -= A[i, k] * A[k, j]
311+ end
301312 end
302313 end
303- end end
314+ end
304315 return info
305316end
0 commit comments