11using LoopVectorization
22using TriangularSolve: ldiv!
3- using LinearAlgebra: BlasInt, BlasFloat, LU, UnitLowerTriangular, checknonsingular, BLAS, LinearAlgebra, Adjoint, Transpose
3+ using LinearAlgebra: BlasInt, BlasFloat, LU, UnitLowerTriangular, checknonsingular, BLAS,
4+ LinearAlgebra, Adjoint, Transpose
45using StrideArraysCore
56using Polyester: @batch
67
8+ @generated function _unit_lower_triangular (B:: A ) where {T, A <: AbstractMatrix{T} }
9+ Expr (:new , UnitLowerTriangular{T, A}, :B )
10+ end
711# 1.7 compat
8- normalize_pivot (t:: Val{T} ) where T = t
12+ normalize_pivot (t:: Val{T} ) where {T} = t
913to_stdlib_pivot (t) = t
1014if VERSION >= v " 1.7.0-DEV.1188"
1115 normalize_pivot (:: LinearAlgebra.RowMaximum ) = Val (true )
@@ -18,19 +22,20 @@ function lu(A::AbstractMatrix, pivot = Val(true), thread = Val(true); kwargs...)
1822 return lu! (copy (A), normalize_pivot (pivot), thread; kwargs... )
1923end
2024
21- function lu! (A, pivot = Val (true ), thread = Val (true ); check= true , kwargs... )
22- m, n = size (A)
25+ function lu! (A, pivot = Val (true ), thread = Val (true ); check = true , kwargs... )
26+ m, n = size (A)
2327 minmn = min (m, n)
2428 F = if minmn < 10 # avx introduces small performance degradation
25- LinearAlgebra. generic_lufact! (A, to_stdlib_pivot (pivot); check= check)
29+ LinearAlgebra. generic_lufact! (A, to_stdlib_pivot (pivot); check = check)
2630 else
27- lu! (A, Vector {BlasInt} (undef, minmn), normalize_pivot (pivot), thread; check= check, kwargs... )
31+ lu! (A, Vector {BlasInt} (undef, minmn), normalize_pivot (pivot), thread; check = check,
32+ kwargs... )
2833 end
2934 return F
3035end
3136
3237for (f, T) in [(:adjoint , :Adjoint ), (:transpose , :Transpose )], lu in (:lu , :lu! )
33- @eval $ lu (A:: $T , args... ; kwargs... ) = $ f ($ lu (parent (A), args... ; kwargs... ))
38+ @eval $ lu (A:: $T , args... ; kwargs... ) = $ f ($ lu (parent (A), args... ; kwargs... ))
3439end
3540
3641const RECURSION_THRESHOLD = Ref (- 1 )
4449recurse (:: StridedArray ) = true
4550recurse (_) = false
4651
47- function lu! (
48- A:: AbstractMatrix{T} , ipiv:: AbstractVector{<:Integer} ,
49- pivot = Val (true ), thread = Val (true );
50- check:: Bool = true ,
51- # the performance is not sensitive wrt blocksize, and 8 is a good default
52- blocksize:: Integer = length (A) ≥ 40_000 ? 8 : 16 ,
53- threshold:: Integer = pick_threshold ()
54- ) where T
52+ function lu! (A:: AbstractMatrix{T} , ipiv:: AbstractVector{<:Integer} ,
53+ pivot = Val (true ), thread = Val (true );
54+ check:: Bool = true ,
55+ # the performance is not sensitive wrt blocksize, and 8 is a good default
56+ blocksize:: Integer = length (A) ≥ 40_000 ? 8 : 16 ,
57+ threshold:: Integer = pick_threshold ()) where {T}
5558 pivot = normalize_pivot (pivot)
5659 info = zero (BlasInt)
5760 m, n = size (A)
5861 mnmin = min (m, n)
5962 if recurse (A) && mnmin > threshold
60- if T <: Union{Float32,Float64}
61- GC. @preserve ipiv A begin
62- info = recurse! ( PtrArray (A), pivot, m, n, mnmin, PtrArray (ipiv), info, blocksize, thread)
63- end
63+ if T <: Union{Float32, Float64}
64+ GC. @preserve ipiv A begin info = recurse! ( PtrArray (A), pivot, m, n, mnmin,
65+ PtrArray (ipiv), info, blocksize,
66+ thread) end
6467 else
6568 info = recurse! (A, pivot, m, n, mnmin, ipiv, info, blocksize, thread)
6669 end
@@ -71,30 +74,33 @@ function lu!(
7174 LU {T, typeof(A)} (A, ipiv, info)
7275end
7376
74- @inline function recurse! (A, :: Val{Pivot} , m, n, mnmin, ipiv, info, blocksize, :: Val{true} ) where {Pivot}
75- if length (A) * _sizeof (eltype (A)) > 0.92 * LoopVectorization. VectorizationBase. cache_size (Val (1 ))
76- _recurse! (A, Val {Pivot} (), m, n, mnmin, ipiv, info, blocksize, Val (true ))
77- else
78- _recurse! (A, Val {Pivot} (), m, n, mnmin, ipiv, info, blocksize, Val (false ))
79- end
77+ @inline function recurse! (A, :: Val{Pivot} , m, n, mnmin, ipiv, info, blocksize,
78+ :: Val{true} ) where {Pivot}
79+ if length (A) * _sizeof (eltype (A)) >
80+ 0.92 * LoopVectorization. VectorizationBase. cache_size (Val (1 ))
81+ _recurse! (A, Val {Pivot} (), m, n, mnmin, ipiv, info, blocksize, Val (true ))
82+ else
83+ _recurse! (A, Val {Pivot} (), m, n, mnmin, ipiv, info, blocksize, Val (false ))
84+ end
8085end
81- @inline function recurse! (A, :: Val{Pivot} , m, n, mnmin, ipiv, info, blocksize, :: Val{false} ) where {Pivot}
82- _recurse! (A, Val {Pivot} (), m, n, mnmin, ipiv, info, blocksize, Val (false ))
86+ @inline function recurse! (A, :: Val{Pivot} , m, n, mnmin, ipiv, info, blocksize,
87+ :: Val{false} ) where {Pivot}
88+ _recurse! (A, Val {Pivot} (), m, n, mnmin, ipiv, info, blocksize, Val (false ))
8389end
84- @inline function _recurse! (A, :: Val{Pivot} , m, n, mnmin, ipiv, info, blocksize, :: Val{Thread} ) where {Pivot,Thread}
85- info = reckernel! (A, Val (Pivot), m, mnmin, ipiv, info, blocksize, Val (Thread))
86- @inbounds if m < n # fat matrix
87- # [AL AR]
88- AL = @view A[:, 1 : m]
89- AR = @view A[:, m+ 1 : n]
90- apply_permutation! (ipiv, AR, Val (Thread))
91- ldiv! (UnitLowerTriangular (AL), AR, Val (Thread))
92- end
93- info
90+ @inline function _recurse! (A, :: Val{Pivot} , m, n, mnmin, ipiv, info, blocksize,
91+ :: Val{Thread} ) where {Pivot, Thread}
92+ info = reckernel! (A, Val (Pivot), m, mnmin, ipiv, info, blocksize, Val (Thread)):: Int
93+ @inbounds if m < n # fat matrix
94+ # [AL AR]
95+ AL = @view A[:, 1 : m]
96+ AR = @view A[:, (m + 1 ): n]
97+ apply_permutation! (ipiv, AR, Val (Thread))
98+ ldiv! (_unit_lower_triangular (AL), AR, Val (Thread))
99+ end
100+ info
94101end
95102
96-
97- @inline function nsplit (:: Type{T} , n) where T
103+ @inline function nsplit (:: Type{T} , n) where {T}
98104 k = 512 ÷ (isbitstype (T) ? sizeof (T) : 8 )
99105 k_2 = k ÷ 2
100106 return n >= k ? ((n + k_2) ÷ k) * k_2 : n ÷ 2
@@ -125,8 +131,8 @@ Base.@propagate_inbounds function apply_permutation!(P, A, ::Val{false})
125131 end
126132 nothing
127133end
128-
129- function reckernel! (A :: AbstractMatrix{T} , pivot :: Val{Pivot} , m, n, ipiv, info, blocksize, thread):: BlasInt where {T,Pivot}
134+ function reckernel! (A :: AbstractMatrix{T} , pivot :: Val{Pivot} , m, n, ipiv, info, blocksize,
135+ thread):: BlasInt where {T, Pivot}
130136 @inbounds begin
131137 if n <= max (blocksize, 1 )
132138 info = _generic_lufact! (A, Val (Pivot), ipiv, info)
@@ -147,18 +153,18 @@ function reckernel!(A::AbstractMatrix{T}, pivot::Val{Pivot}, m, n, ipiv, info, b
147153 # Partition the matrix A
148154 # [AL AR]
149155 AL = @view A[:, 1 : n1]
150- AR = @view A[:, n1 + 1 : n]
156+ AR = @view A[:, (n1 + 1 ) : n]
151157 # AL AR
152158 # [A11 A12]
153159 # [A21 A22]
154160 A11 = @view A[1 : n1, 1 : n1]
155- A12 = @view A[1 : n1, n1 + 1 : n]
156- A21 = @view A[n1 + 1 : m, 1 : n1]
157- A22 = @view A[n1 + 1 : m, n1 + 1 : n]
161+ A12 = @view A[1 : n1, (n1 + 1 ) : n]
162+ A21 = @view A[(n1 + 1 ) : m, 1 : n1]
163+ A22 = @view A[(n1 + 1 ) : m, (n1 + 1 ) : n]
158164 # [P1]
159165 # [P2]
160166 P1 = @view ipiv[1 : n1]
161- P2 = @view ipiv[n1 + 1 : n]
167+ P2 = @view ipiv[(n1 + 1 ) : n]
162168 # ========================================
163169
164170 # [ A11 ] [ L11 ]
@@ -170,7 +176,7 @@ function reckernel!(A::AbstractMatrix{T}, pivot::Val{Pivot}, m, n, ipiv, info, b
170176 # [ A22 ] [ 0 ] [ A22 ]
171177 Pivot && apply_permutation! (P1, AR, thread)
172178 # A12 = L11 U12 => U12 = L11 \ A12
173- ldiv! (UnitLowerTriangular (A11), A12, thread)
179+ ldiv! (_unit_lower_triangular (A11), A12, thread)
174180 # Schur complement:
175181 # We have A22 = L21 U12 + A′22, hence
176182 # A′22 = A22 - L21 U12
@@ -191,23 +197,23 @@ function reckernel!(A::AbstractMatrix{T}, pivot::Val{Pivot}, m, n, ipiv, info, b
191197 end # inbounds
192198end
193199
194- function schur_complement! (𝐂, 𝐀, 𝐁,:: Val{THREAD} = Val (true )) where {THREAD}
200+ function schur_complement! (𝐂, 𝐀, 𝐁, :: Val{THREAD} = Val (true )) where {THREAD}
195201 # mul!(𝐂,𝐀,𝐁,-1,1)
196202 if THREAD
197- @tturbo warn_check_args= false for m ∈ 1 : size (𝐀,1 ), n ∈ 1 : size (𝐁,2 )
203+ @tturbo warn_check_args= false for m in 1 : size (𝐀, 1 ), n in 1 : size (𝐁, 2 )
198204 𝐂ₘₙ = zero (eltype (𝐂))
199- for k ∈ 1 : size (𝐀,2 )
200- 𝐂ₘₙ -= 𝐀[m,k] * 𝐁[k,n]
205+ for k in 1 : size (𝐀, 2 )
206+ 𝐂ₘₙ -= 𝐀[m, k] * 𝐁[k, n]
201207 end
202- 𝐂[m,n] = 𝐂ₘₙ + 𝐂[m,n]
208+ 𝐂[m, n] = 𝐂ₘₙ + 𝐂[m, n]
203209 end
204210 else
205- @turbo warn_check_args= false for m ∈ 1 : size (𝐀,1 ), n ∈ 1 : size (𝐁,2 )
211+ @turbo warn_check_args= false for m in 1 : size (𝐀, 1 ), n in 1 : size (𝐁, 2 )
206212 𝐂ₘₙ = zero (eltype (𝐂))
207- for k ∈ 1 : size (𝐀,2 )
208- 𝐂ₘₙ -= 𝐀[m,k] * 𝐁[k,n]
213+ for k in 1 : size (𝐀, 2 )
214+ 𝐂ₘₙ -= 𝐀[m, k] * 𝐁[k, n]
209215 end
210- 𝐂[m,n] = 𝐂ₘₙ + 𝐂[m,n]
216+ 𝐂[m, n] = 𝐂ₘₙ + 𝐂[m, n]
211217 end
212218 end
213219end
@@ -216,49 +222,47 @@ end
216222 Modified from https://github.com/JuliaLang/julia/blob/b56a9f07948255dfbe804eef25bdbada06ec2a57/stdlib/LinearAlgebra/src/lu.jl
217223 License is MIT: https://julialang.org/license
218224=#
219- function _generic_lufact! (A, :: Val{Pivot} , ipiv, info) where Pivot
225+ function _generic_lufact! (A, :: Val{Pivot} , ipiv, info) where { Pivot}
220226 m, n = size (A)
221227 minmn = length (ipiv)
222- @inbounds begin
223- for k = 1 : minmn
224- # find index max
225- kp = k
226- if Pivot
227- amax = abs (zero (eltype (A)))
228- for i = k: m
229- absi = abs (A[i,k])
230- if absi > amax
231- kp = i
232- amax = absi
233- end
234- end
235- end
236- ipiv[k] = kp
237- if ! iszero (A[kp,k])
238- if k != kp
239- # Interchange
240- @simd for i = 1 : n
241- tmp = A[k,i]
242- A[k,i] = A[kp,i]
243- A[kp,i] = tmp
244- end
228+ @inbounds begin for k in 1 : minmn
229+ # find index max
230+ kp = k
231+ if Pivot
232+ amax = abs (zero (eltype (A)))
233+ for i in k: m
234+ absi = abs (A[i, k])
235+ if absi > amax
236+ kp = i
237+ amax = absi
245238 end
246- # Scale first column
247- Akkinv = inv (A[k,k])
248- @turbo check_empty= true warn_check_args= false for i = k+ 1 : m
249- A[i,k] *= Akkinv
250- end
251- elseif info == 0
252- info = k
253239 end
254- k == minmn && break
255- # Update the rest
256- @turbo warn_check_args= false for j = k+ 1 : n
257- for i = k+ 1 : m
258- A[i,j] -= A[i,k]* A[k,j]
240+ end
241+ ipiv[k] = kp
242+ if ! iszero (A[kp, k])
243+ if k != kp
244+ # Interchange
245+ @simd for i in 1 : n
246+ tmp = A[k, i]
247+ A[k, i] = A[kp, i]
248+ A[kp, i] = tmp
259249 end
260250 end
251+ # Scale first column
252+ Akkinv = inv (A[k, k])
253+ @turbo check_empty= true warn_check_args= false for i in (k + 1 ): m
254+ A[i, k] *= Akkinv
255+ end
256+ elseif info == 0
257+ info = k
261258 end
262- end
259+ k == minmn && break
260+ # Update the rest
261+ @turbo warn_check_args= false for j in (k + 1 ): n
262+ for i in (k + 1 ): m
263+ A[i, j] -= A[i, k] * A[k, j]
264+ end
265+ end
266+ end end
263267 return info
264268end
0 commit comments