@@ -14,17 +14,17 @@ if VERSION >= v"1.7.0-DEV.1188"
1414 to_stdlib_pivot (:: Val{false} ) = LinearAlgebra. NoPivot ()
1515end
1616
17- function lu (A:: AbstractMatrix , pivot = Val (true ); kwargs... )
18- return lu! (copy (A), normalize_pivot (pivot); kwargs... )
17+ function lu (A:: AbstractMatrix , pivot = Val (true ), thread = Val ( true ) ; kwargs... )
18+ return lu! (copy (A), normalize_pivot (pivot), thread ; kwargs... )
1919end
2020
21- function lu! (A, pivot = Val (true ); check= true , kwargs... )
21+ function lu! (A, pivot = Val (true ), thread = Val ( true ) ; check= true , kwargs... )
2222 m, n = size (A)
2323 minmn = min (m, n)
2424 F = if minmn < 10 # avx introduces small performance degradation
2525 LinearAlgebra. generic_lufact! (A, to_stdlib_pivot (pivot); check= check)
2626 else
27- lu! (A, Vector {BlasInt} (undef, minmn), normalize_pivot (pivot); check= check, kwargs... )
27+ lu! (A, Vector {BlasInt} (undef, minmn), normalize_pivot (pivot), thread ; check= check, kwargs... )
2828 end
2929 return F
3030end
@@ -46,7 +46,7 @@ recurse(_) = false
4646
4747function lu! (
4848 A:: AbstractMatrix{T} , ipiv:: AbstractVector{<:Integer} ,
49- pivot = Val (true );
49+ pivot = Val (true ), thread = Val ( true ) ;
5050 check:: Bool = true ,
5151 # the performance is not sensitive wrt blocksize, and 8 is a good default
5252 blocksize:: Integer = length (A) ≥ 40_000 ? 8 : 16 ,
@@ -59,10 +59,10 @@ function lu!(
5959 if recurse (A) && mnmin > threshold
6060 if T <: Union{Float32,Float64}
6161 GC. @preserve ipiv A begin
62- info = recurse! (PtrArray (A), pivot, m, n, mnmin, PtrArray (ipiv), info, blocksize)
62+ info = recurse! (PtrArray (A), pivot, m, n, mnmin, PtrArray (ipiv), info, blocksize, thread )
6363 end
6464 else
65- info = recurse! (A, pivot, m, n, mnmin, ipiv, info, blocksize)
65+ info = recurse! (A, pivot, m, n, mnmin, ipiv, info, blocksize, thread )
6666 end
6767 else # generic fallback
6868 info = _generic_lufact! (A, pivot, ipiv, info)
@@ -71,26 +71,36 @@ function lu!(
7171 LU {T, typeof(A)} (A, ipiv, info)
7272end
7373
74- @inline function recurse! (A, :: Val{Pivot} , m, n, mnmin, ipiv, info, blocksize) where {Pivot}
75- thread = length (A) * _sizeof (eltype (A)) > 0.92 * LoopVectorization. VectorizationBase. cache_size (Val (1 ))
76- info = reckernel! (A, Val (Pivot), m, mnmin, ipiv, info, blocksize, thread)
74+ @inline function recurse! (A, :: Val{Pivot} , m, n, mnmin, ipiv, info, blocksize, :: Val{true} ) where {Pivot}
75+ if length (A) * _sizeof (eltype (A)) > 0.92 * LoopVectorization. VectorizationBase. cache_size (Val (1 ))
76+ _recurse! (A, Val {Pivot} (), m, n, mnmin, ipiv, info, blocksize, Val (true ))
77+ else
78+ _recurse! (A, Val {Pivot} (), m, n, mnmin, ipiv, info, blocksize, Val (false ))
79+ end
80+ end
81+ @inline function recurse! (A, :: Val{Pivot} , m, n, mnmin, ipiv, info, blocksize, :: Val{false} ) where {Pivot}
82+ _recurse! (A, Val {Pivot} (), m, n, mnmin, ipiv, info, blocksize, Val (false ))
83+ end
84+ @inline function _recurse! (A, :: Val{Pivot} , m, n, mnmin, ipiv, info, blocksize, :: Val{Thread} ) where {Pivot,Thread}
85+ info = reckernel! (A, Val (Pivot), m, mnmin, ipiv, info, blocksize, Val (Thread))
7786 @inbounds if m < n # fat matrix
7887 # [AL AR]
7988 AL = @view A[:, 1 : m]
8089 AR = @view A[:, m+ 1 : n]
81- apply_permutation! (ipiv, AR, thread )
82- ldiv! (UnitLowerTriangular (AL), AR)
90+ apply_permutation! (ipiv, AR, Val (Thread) )
91+ ldiv! (UnitLowerTriangular (AL), AR, Val (Thread) )
8392 end
8493 info
8594end
8695
96+
8797@inline function nsplit (:: Type{T} , n) where T
8898 k = 512 ÷ (isbitstype (T) ? sizeof (T) : 8 )
8999 k_2 = k ÷ 2
90100 return n >= k ? ((n + k_2) ÷ k) * k_2 : n ÷ 2
91101end
92102
93- function apply_permutation_threaded ! (P, A)
103+ function apply_permutation ! (P, A, :: Val{true} )
94104 batchsize = cld (2000 , length (P))
95105 @batch minbatch= batchsize for j in axes (A, 2 )
96106 @inbounds for i in axes (P, 1 )
@@ -103,9 +113,7 @@ function apply_permutation_threaded!(P, A)
103113 nothing
104114end
105115_sizeof (:: Type{T} ) where {T} = Base. isbitstype (T) ? sizeof (T) : sizeof (Int)
106- Base. @propagate_inbounds function apply_permutation! (P, A, thread)
107- thread && return apply_permutation_threaded! (P, A)
108- # length(A) * _sizeof(eltype(A)) > 0.92 * LoopVectorization.VectorizationBase.cache_size(Val(1)) && return apply_permutation_threaded!(P, A)
116+ Base. @propagate_inbounds function apply_permutation! (P, A, :: Val{false} )
109117 for i in axes (P, 1 )
110118 i′ = P[i]
111119 i′ == i && continue
@@ -162,7 +170,7 @@ function reckernel!(A::AbstractMatrix{T}, pivot::Val{Pivot}, m, n, ipiv, info, b
162170 # [ A22 ] [ 0 ] [ A22 ]
163171 Pivot && apply_permutation! (P1, AR, thread)
164172 # A12 = L11 U12 => U12 = L11 \ A12
165- ldiv! (UnitLowerTriangular (A11), A12)
173+ ldiv! (UnitLowerTriangular (A11), A12, thread )
166174 # Schur complement:
167175 # We have A22 = L21 U12 + A′22, hence
168176 # A′22 = A22 - L21 U12
@@ -176,7 +184,7 @@ function reckernel!(A::AbstractMatrix{T}, pivot::Val{Pivot}, m, n, ipiv, info, b
176184 Pivot && apply_permutation! (P2, A21, thread)
177185
178186 info != previnfo && (info += n1)
179- @avx for i in 1 : n2
187+ @turbo warn_check_args = false for i in 1 : n2
180188 P2[i] += n1
181189 end
182190 return info
@@ -226,15 +234,15 @@ function _generic_lufact!(A, ::Val{Pivot}, ipiv, info) where Pivot
226234 end
227235 # Scale first column
228236 Akkinv = inv (A[k,k])
229- @avx check_empty= true for i = k+ 1 : m
237+ @turbo check_empty= true warn_check_args = false for i = k+ 1 : m
230238 A[i,k] *= Akkinv
231239 end
232240 elseif info == 0
233241 info = k
234242 end
235243 k == minmn && break
236244 # Update the rest
237- @avx for j = k+ 1 : n
245+ @turbo warn_check_args = false for j = k+ 1 : n
238246 for i = k+ 1 : m
239247 A[i,j] -= A[i,k]* A[k,j]
240248 end
0 commit comments