11using LoopVectorization
22using TriangularSolve: ldiv!
33using LinearAlgebra: BlasInt, BlasFloat, LU, UnitLowerTriangular, checknonsingular, BLAS,
4- LinearAlgebra, Adjoint, Transpose, UpperTriangular, AbstractVecOrMat
4+ LinearAlgebra, Adjoint, Transpose, UpperTriangular, AbstractVecOrMat
55using StrideArraysCore
66using Polyester: @batch
77
@@ -41,32 +41,35 @@ init_pivot(::Val{true}, minmn) = Vector{BlasInt}(undef, minmn)
4141
if CUSTOMIZABLE_PIVOT && isdefined(LinearAlgebra, :_ipiv_cols!)
    # Only defined when the running LinearAlgebra actually provides `_ipiv_cols!`.
    #
    # For an `LU` whose pivot type is `NotIPIV`, this method ignores the factorization
    # and the range argument and returns `B` unchanged.
    # NOTE(review): presumably because an unpivoted factorization has no column
    # permutation to apply -- confirm against the definition of `NotIPIV`.
    function LinearAlgebra._ipiv_cols!(::LU{<:Any, <:Any, NotIPIV}, ::OrdinalRange,
            B::StridedVecOrMat)
        return B
    end
end
if CUSTOMIZABLE_PIVOT && isdefined(LinearAlgebra, :_ipiv_rows!)
    # Only defined when the running LinearAlgebra actually provides `_ipiv_rows!`.
    #
    # For an `LU{T, <:AbstractMatrix{T}, NotIPIV}`, this method ignores the
    # factorization and the range argument and returns `B` unchanged.
    # NOTE(review): presumably the row-permutation counterpart of `_ipiv_cols!`,
    # a no-op because nothing was pivoted -- confirm against `NotIPIV`.
    function LinearAlgebra._ipiv_rows!(
            ::(LU{T, <:AbstractMatrix{T}, NotIPIV} where {T}),
            ::OrdinalRange,
            B::StridedVecOrMat)
        return B
    end
end
if CUSTOMIZABLE_PIVOT
    # Solve with a `NotIPIV` LU factorization: apply the unit-lower-triangular
    # factor to `B` first, then the upper-triangular factor to that result.
    # Both triangles are read from the single packed `A.factors` matrix.
    #
    # The unqualified `ldiv!` calls below resolve to the `ldiv!` brought in by
    # `using TriangularSolve: ldiv!` at the top of the file, not to this
    # `LinearAlgebra.ldiv!` method. The value of the outer call is returned.
    # NOTE(review): expected to overwrite `B` in place -- confirm against
    # TriangularSolve's two-argument `ldiv!`.
    function LinearAlgebra.ldiv!(A::LU{T, <:StridedMatrix, <:NotIPIV},
            B::StridedVecOrMat{T}) where {T <: BlasFloat}
        ldiv!(UpperTriangular(A.factors), ldiv!(UnitLowerTriangular(A.factors), B))
    end
end
6161
62- function lu! (A, pivot = Val (true ), thread = Val (false ); check = true , kwargs... )
62+ function lu! (A, pivot = Val (true ), thread = Val (false );
63+ check:: Union{Bool, Val{true}, Val{false}} = Val (true ), kwargs... )
6364 m, n = size (A)
6465 minmn = min (m, n)
6566 npivot = normalize_pivot (pivot)
6667 # we want the type on both branches to match. When pivot = Val(false), we construct
6768 # a `NotIPIV`, which `LinearAlgebra.generic_lufact!` does not.
6869 F = if pivot === Val (true ) && minmn < 10 # avx introduces small performance degradation
69- LinearAlgebra. generic_lufact! (A, to_stdlib_pivot (pivot); check = check)
70+ LinearAlgebra. generic_lufact! (A, to_stdlib_pivot (pivot);
71+ check = ((check isa Bool && check) || (check === Val (true )))
72+ )
7073 else
7174 lu! (A, init_pivot (npivot, minmn), npivot, thread; check = check,
7275 kwargs... )
@@ -87,11 +90,11 @@ recurse(_) = false
# Wrap a pivot container in a `PtrArray`; a `NotIPIV` placeholder is passed
# through unchanged instead of being wrapped.
_ptrarray(ipiv) = PtrArray(ipiv)
_ptrarray(ipiv::NotIPIV) = ipiv
8992function lu! (A:: AbstractMatrix{T} , ipiv:: AbstractVector{<:Integer} ,
90- pivot = Val (true ), thread = Val (false );
91- check:: Bool = true ,
92- # the performance is not sensitive wrt blocksize, and 8 is a good default
93- blocksize:: Integer = length (A) ≥ 40_000 ? 8 : 16 ,
94- threshold:: Integer = pick_threshold ()) where {T}
93+ pivot = Val (true ), thread = Val (false );
94+ check:: Union{ Bool, Val{true}, Val{false}} = Val ( true ) ,
95+ # the performance is not sensitive wrt blocksize, and 8 is a good default
96+ blocksize:: Integer = length (A) ≥ 40_000 ? 8 : 16 ,
97+ threshold:: Integer = pick_threshold ()) where {T}
9598 pivot = normalize_pivot (pivot)
9699 info = zero (BlasInt)
97100 m, n = size (A)
@@ -113,12 +116,12 @@ function lu!(A::AbstractMatrix{T}, ipiv::AbstractVector{<:Integer},
113116 else # generic fallback
114117 info = _generic_lufact! (A, pivot, ipiv, info)
115118 end
116- check && checknonsingular (info)
119+ (( check isa Bool && check) || (check === Val ( true ))) && checknonsingular (info)
117120 LU (A, ipiv, info)
118121end
119122
120123@inline function recurse! (A, :: Val{Pivot} , m, n, mnmin, ipiv, info, blocksize,
121- :: Val{true} ) where {Pivot}
124+ :: Val{true} ) where {Pivot}
122125 if length (A) * _sizeof (eltype (A)) >
123126 0.92 * LoopVectorization. VectorizationBase. cache_size (Val (2 ))
124127 _recurse! (A, Val {Pivot} (), m, n, mnmin, ipiv, info, blocksize, Val (true ))
@@ -127,11 +130,11 @@ end
127130 end
128131end
# `recurse!` method selected when the trailing argument is `Val(false)`.
# It forwards every argument straight to `_recurse!`, re-wrapping `Pivot` as
# `Val{Pivot}()` and passing `Val(false)` as the final flag; it does no other work.
# NOTE(review): the trailing `Val` looks like the threading switch (it binds to
# `Thread` in `_recurse!`'s signature) -- confirm against the callers.
@inline function recurse!(A, ::Val{Pivot}, m, n, mnmin, ipiv, info, blocksize,
        ::Val{false}) where {Pivot}
    _recurse!(A, Val{Pivot}(), m, n, mnmin, ipiv, info, blocksize, Val(false))
end
133136@inline function _recurse! (A, :: Val{Pivot} , m, n, mnmin, ipiv, info, blocksize,
134- :: Val{Thread} ) where {Pivot, Thread}
137+ :: Val{Thread} ) where {Pivot, Thread}
135138 info = reckernel! (A, Val (Pivot), m, mnmin, ipiv, info, blocksize, Val (Thread)):: Int
136139 @inbounds if m < n # fat matrix
137140 # [AL AR]
@@ -175,7 +178,7 @@ Base.@propagate_inbounds function apply_permutation!(P, A, ::Val{false})
175178 nothing
176179end
177180function reckernel! (A:: AbstractMatrix{T} , pivot:: Val{Pivot} , m, n, ipiv, info, blocksize,
178- thread):: BlasInt where {T, Pivot}
181+ thread):: BlasInt where {T, Pivot}
179182 @inbounds begin
180183 if n <= max (blocksize, 1 )
181184 info = _generic_lufact! (A, Val (Pivot), ipiv, info)
0 commit comments