11using LoopVectorization
22using TriangularSolve: ldiv!
33using LinearAlgebra: BlasInt, BlasFloat, LU, UnitLowerTriangular, checknonsingular, BLAS,
4- LinearAlgebra, Adjoint, Transpose
4+ LinearAlgebra, Adjoint, Transpose, UpperTriangular
55using StrideArraysCore
66using Polyester: @batch
77
@@ -41,16 +41,22 @@ init_pivot(::Val{true}, minmn) = Vector{BlasInt}(undef, minmn)
4141
4242if CUSTOMIZABLE_PIVOT && isdefined (LinearAlgebra, :_ipiv_cols! )
4343 function LinearAlgebra. _ipiv_cols! (:: LU{<:Any, <:Any, NotIPIV} , :: OrdinalRange ,
44- B:: StridedVecOrMat )
44+ B:: StridedVecOrMat )
4545 return B
4646 end
4747end
4848if CUSTOMIZABLE_PIVOT && isdefined (LinearAlgebra, :_ipiv_rows! )
4949 function LinearAlgebra. _ipiv_rows! (:: LU{<:Any, <:Any, NotIPIV} , :: OrdinalRange ,
50- B:: StridedVecOrMat )
50+ B:: StridedVecOrMat )
5151 return B
5252 end
5353end
54+ if CUSTOMIZABLE_PIVOT
55+ function LinearAlgebra. ldiv! (A:: LU{T, <:StridedMatrix, <:NotIPIV} ,
56+ B:: StridedVecOrMat{T} ) where {T <: BlasFloat }
57+ ldiv! (UpperTriangular (A. factors), ldiv! (UnitLowerTriangular (A. factors), B))
58+ end
59+ end
5460
5561function lu! (A, pivot = Val (true ), thread = Val (true ); check = true , kwargs... )
5662 m, n = size (A)
@@ -80,11 +86,11 @@ recurse(_) = false
8086_ptrarray (ipiv) = PtrArray (ipiv)
8187_ptrarray (ipiv:: NotIPIV ) = ipiv
8288function lu! (A:: AbstractMatrix{T} , ipiv:: AbstractVector{<:Integer} ,
83- pivot = Val (true ), thread = Val (true );
84- check:: Bool = true ,
85- # the performance is not sensitive wrt blocksize, and 8 is a good default
86- blocksize:: Integer = length (A) ≥ 40_000 ? 8 : 16 ,
87- threshold:: Integer = pick_threshold ()) where {T}
89+ pivot = Val (true ), thread = Val (true );
90+ check:: Bool = true ,
91+ # the performance is not sensitive wrt blocksize, and 8 is a good default
92+ blocksize:: Integer = length (A) ≥ 40_000 ? 8 : 16 ,
93+ threshold:: Integer = pick_threshold ()) where {T}
8894 pivot = normalize_pivot (pivot)
8995 info = zero (BlasInt)
9096 m, n = size (A)
@@ -94,10 +100,12 @@ function lu!(A::AbstractMatrix{T}, ipiv::AbstractVector{<:Integer},
94100 end
95101 if recurse (A) && mnmin > threshold
96102 if T <: Union{Float32, Float64}
97- GC. @preserve ipiv A begin info = recurse! (view (PtrArray (A), axes (A)... ), pivot,
98- m, n, mnmin,
99- _ptrarray (ipiv), info, blocksize,
100- thread) end
103+ GC. @preserve ipiv A begin
104+ info = recurse! (view (PtrArray (A), axes (A)... ), pivot,
105+ m, n, mnmin,
106+ _ptrarray (ipiv), info, blocksize,
107+ thread)
108+ end
101109 else
102110 info = recurse! (A, pivot, m, n, mnmin, ipiv, info, blocksize, thread)
103111 end
@@ -109,7 +117,7 @@ function lu!(A::AbstractMatrix{T}, ipiv::AbstractVector{<:Integer},
109117end
110118
111119@inline function recurse! (A, :: Val{Pivot} , m, n, mnmin, ipiv, info, blocksize,
112- :: Val{true} ) where {Pivot}
120+ :: Val{true} ) where {Pivot}
113121 if length (A) * _sizeof (eltype (A)) >
114122 0.92 * LoopVectorization. VectorizationBase. cache_size (Val (2 ))
115123 _recurse! (A, Val {Pivot} (), m, n, mnmin, ipiv, info, blocksize, Val (true ))
@@ -118,11 +126,11 @@ end
118126 end
119127end
120128@inline function recurse! (A, :: Val{Pivot} , m, n, mnmin, ipiv, info, blocksize,
121- :: Val{false} ) where {Pivot}
129+ :: Val{false} ) where {Pivot}
122130 _recurse! (A, Val {Pivot} (), m, n, mnmin, ipiv, info, blocksize, Val (false ))
123131end
124132@inline function _recurse! (A, :: Val{Pivot} , m, n, mnmin, ipiv, info, blocksize,
125- :: Val{Thread} ) where {Pivot, Thread}
133+ :: Val{Thread} ) where {Pivot, Thread}
126134 info = reckernel! (A, Val (Pivot), m, mnmin, ipiv, info, blocksize, Val (Thread)):: Int
127135 @inbounds if m < n # fat matrix
128136 # [AL AR]
@@ -166,7 +174,7 @@ Base.@propagate_inbounds function apply_permutation!(P, A, ::Val{false})
166174 nothing
167175end
168176function reckernel! (A:: AbstractMatrix{T} , pivot:: Val{Pivot} , m, n, ipiv, info, blocksize,
169- thread):: BlasInt where {T, Pivot}
177+ thread):: BlasInt where {T, Pivot}
170178 @inbounds begin
171179 if n <= max (blocksize, 1 )
172180 info = _generic_lufact! (A, Val (Pivot), ipiv, info)
@@ -262,44 +270,46 @@ end
262270function _generic_lufact! (A, :: Val{Pivot} , ipiv, info) where {Pivot}
263271 m, n = size (A)
264272 minmn = length (ipiv)
265- @inbounds begin for k in 1 : minmn
266- # find index max
267- kp = k
268- if Pivot
269- amax = abs (zero (eltype (A)))
270- for i in k: m
271- absi = abs (A[i, k])
272- if absi > amax
273- kp = i
274- amax = absi
273+ @inbounds begin
274+ for k in 1 : minmn
275+ # find index max
276+ kp = k
277+ if Pivot
278+ amax = abs (zero (eltype (A)))
279+ for i in k: m
280+ absi = abs (A[i, k])
281+ if absi > amax
282+ kp = i
283+ amax = absi
284+ end
275285 end
286+ ipiv[k] = kp
276287 end
277- ipiv[k] = kp
278- end
279- if ! iszero (A[kp, k])
280- if k != kp
281- # Interchange
282- @simd for i in 1 : n
283- tmp = A[k, i]
284- A[k, i] = A[kp, i]
285- A[kp, i] = tmp
288+ if ! iszero (A[kp, k])
289+ if k != kp
290+ # Interchange
291+ @simd for i in 1 : n
292+ tmp = A[k, i]
293+ A[k, i] = A[kp, i]
294+ A[kp, i] = tmp
295+ end
286296 end
297+ # Scale first column
298+ Akkinv = inv (A[k, k])
299+ @turbo check_empty= true warn_check_args= false for i in (k + 1 ): m
300+ A[i, k] *= Akkinv
301+ end
302+ elseif info == 0
303+ info = k
287304 end
288- # Scale first column
289- Akkinv = inv (A[k, k])
290- @turbo check_empty= true warn_check_args= false for i in (k + 1 ): m
291- A[i, k] *= Akkinv
292- end
293- elseif info == 0
294- info = k
295- end
296- k == minmn && break
297- # Update the rest
298- @turbo warn_check_args= false for j in (k + 1 ): n
299- for i in (k + 1 ): m
300- A[i, j] -= A[i, k] * A[k, j]
305+ k == minmn && break
306+ # Update the rest
307+ @turbo warn_check_args= false for j in (k + 1 ): n
308+ for i in (k + 1 ): m
309+ A[i, j] -= A[i, k] * A[k, j]
310+ end
301311 end
302312 end
303- end end
313+ end
304314 return info
305315end
0 commit comments