@@ -488,14 +488,13 @@ end
488488
489489# THE one big BLAS dispatch. This is split into two methods to improve latency
Base.@constprop :aggressive function generic_matmatmul_wrapper!(C::StridedMatrix{T}, tA, tB, A::StridedVecOrMat{T}, B::StridedVecOrMat{T},
        α::Number, β::Number, val::BlasFlag.SyrkHerkGemm) where {T<:Number}
    # Operand sizes as reported by `lapack_size` for the given transposition flags.
    rowsA, colsA = lapack_size(tA, A)
    rowsB, colsB = lapack_size(tB, B)
    has_empty_operand = any(iszero, size(A)) || any(iszero, size(B))
    if !(has_empty_operand || iszero(α))
        # Non-trivial product: hand over to the syrk/herk/gemm selection keyed on `val`.
        _syrk_herk_gemm_wrapper!(C, tA, tB, A, B, α, β, val)
        return C
    end
    # The product term vanishes; still validate the shapes, then only apply `β` to `C`.
    matmul_size_check(size(C), (rowsA, colsA), (rowsB, colsB))
    return _rmul_or_fill!(C, β)
end
@@ -570,6 +569,70 @@ Base.@constprop :aggressive function _symm_hemm_generic!(C, tA, tB, A, B, alpha,
570569 _generic_matmatmul! (C, wrap (A, tA), wrap (B, tB), alpha, beta)
571570end
572571
"""
    generic_syrk!(C::StridedMatrix{T}, A::StridedVecOrMat{T}, conjugate::Bool, aat::Bool, α, β) where {T<:Number}

Computes syrk/herk for generic number types. If `conjugate` is false computes syrk, i.e.,
``A transpose(A) α + C β`` if `aat` is true, and ``transpose(A) A α + C β`` otherwise.
If `conjugate` is true computes herk, i.e., ``A A' α + C β`` if `aat` is true, and
``A' A α + C β`` otherwise. Only the upper triangular is computed.

`C` is first passed through `_rmul_or_fill!(C, β)`; afterwards only entries `C[i, j]` with
`i ≤ j` are accumulated into. If `A` has no rows or no columns the product term is empty
and `C` is returned right after that `β` step.

Throws a `DimensionMismatch` if the order of `C` does not match `size(A, 1)` (when `aat`
is true) or `size(A, 2)` (otherwise). Returns `C`.
"""
function generic_syrk!(C::StridedMatrix{T}, A::StridedVecOrMat{T}, conjugate::Bool, aat::Bool, α, β) where {T<:Number}
    require_one_based_indexing(C, A)
    nC = checksquare(C)
    m, n = size(A, 1), size(A, 2)
    mA = aat ? m : n
    if nC != mA
        throw(DimensionMismatch(lazy"output matrix has size: $(size(C)), but should have size $((mA, mA))"))
    end

    _rmul_or_fill!(C, β)
    # An empty `A` contributes nothing to `C`. Returning here also keeps the `A[1, _]`
    # accumulator seeds in the loops below from reading out of bounds under `@inbounds`
    # when `A` has zero rows.
    (iszero(m) || iszero(n)) && return C
    @inbounds if !conjugate
        if aat
            # C += A * transpose(A) * α, accumulated one column of A (rank-1 update) at a time
            for k ∈ 1:n, j ∈ 1:m
                αA_jk = A[j, k] * α
                for i ∈ 1:j
                    C[i, j] += A[i, k] * αA_jk
                end
            end
        else
            # C += transpose(A) * A * α, one dot product of columns i and j per entry
            for j ∈ 1:n, i ∈ 1:j
                temp = A[1, i] * A[1, j]
                for k ∈ 2:m
                    temp += A[k, i] * A[k, j]
                end
                C[i, j] += temp * α
            end
        end
    else
        if aat
            # C += A * A' * α; the diagonal uses `abs2` so its update is real-valued
            for k ∈ 1:n, j ∈ 1:m
                αA_jk_bar = conj(A[j, k]) * α
                for i ∈ 1:j-1
                    C[i, j] += A[i, k] * αA_jk_bar
                end
                C[j, j] += abs2(A[j, k]) * α
            end
        else
            # C += A' * A * α; off-diagonal entries first, then the `abs2` diagonal
            for j ∈ 1:n
                for i ∈ 1:j-1
                    temp = conj(A[1, i]) * A[1, j]
                    for k ∈ 2:m
                        temp += conj(A[k, i]) * A[k, j]
                    end
                    C[i, j] += temp * α
                end
                temp = abs2(A[1, j])
                for k ∈ 2:m
                    temp += abs2(A[k, j])
                end
                C[j, j] += temp * α
            end
        end
    end
    return C
end
635+
573636# legacy method
574637Base. @constprop :aggressive generic_matmatmul! (C:: StridedMatrix{T} , tA, tB, A:: StridedVecOrMat{T} , B:: StridedVecOrMat{T} ,
575638 _add:: MulAddMul = MulAddMul ()) where {T<: BlasFloat } =
@@ -713,12 +776,27 @@ Base.@constprop :aggressive function syrk_wrapper!(C::StridedMatrix{T}, tA::Abst
713776 if (alpha isa Union{Bool,T} &&
714777 beta isa Union{Bool,T} &&
715778 stride (A, 1 ) == stride (C, 1 ) == 1 &&
716- _fullstride2 (A) && _fullstride2 (C))
717- return copytri! (BLAS. syrk! (' U' , tA, alpha, A, beta, C), ' U' )
779+ _fullstride2 (A) && _fullstride2 (C)) &&
780+ max (nA, mA) ≥ 4
781+ BLAS. syrk! (' U' , tA, alpha, A, beta, C)
782+ else
783+ generic_syrk! (C, A, false , tA_uc == ' N' , alpha, beta)
718784 end
785+ return copytri! (C, ' U' )
719786 end
720787 return gemm_wrapper! (C, tA, tAt, A, A, α, β)
721788end
Base.@constprop :aggressive function syrk_wrapper!(C::StridedMatrix{T}, tA::AbstractChar, A::StridedVecOrMat{T},
        α::Number, β::Number) where {T<:Number}
    # `uppercase` potentially strips a WrapperChar
    is_notrans = uppercase(tA) == 'N'
    # The triangular kernel is used only for real/complex element types, and only when
    # `C` does not contribute (β is zero) or is already symmetric.
    use_syrk = T <: Union{Real,Complex} && (iszero(β) || issymmetric(C))
    if use_syrk
        generic_syrk!(C, A, false, is_notrans, α, β)
        # mirror the computed upper triangle into the lower one
        return copytri!(C, 'U')
    end
    # Otherwise compute the full product with the second factor transposed relative to `tA`.
    flipped = is_notrans ? 'T' : 'N'
    return _generic_matmatmul!(C, wrap(A, tA), wrap(A, flipped), α, β)
end
# legacy method: unpacks the `MulAddMul` coefficients and forwards to the α/β method
function syrk_wrapper!(C::StridedMatrix{T}, tA::AbstractChar, A::StridedVecOrMat{T},
        _add::MulAddMul = MulAddMul()) where {T<:BlasFloat}
    return syrk_wrapper!(C, tA, A, _add.alpha, _add.beta)
end
@@ -746,12 +824,27 @@ Base.@constprop :aggressive function herk_wrapper!(C::StridedMatrix{TC}, tA::Abs
746824 alpha, beta = promote (α, β, zero (T))
747825 if (alpha isa T && beta isa T &&
748826 stride (A, 1 ) == stride (C, 1 ) == 1 &&
749- _fullstride2 (A) && _fullstride2 (C))
750- return copytri! (BLAS. herk! (' U' , tA, alpha, A, beta, C), ' U' , true )
827+ _fullstride2 (A) && _fullstride2 (C)) &&
828+ max (nA, mA) ≥ 4
829+ BLAS. herk! (' U' , tA, alpha, A, beta, C)
830+ else
831+ generic_syrk! (C, A, true , tA_uc == ' N' , alpha, beta)
751832 end
833+ return copytri! (C, ' U' , true )
752834 end
753835 return gemm_wrapper! (C, tA, tAt, A, A, α, β)
754836end
Base.@constprop :aggressive function herk_wrapper!(C::StridedMatrix{T}, tA::AbstractChar, A::StridedVecOrMat{T},
        α::Number, β::Number) where {T<:Number}
    # `uppercase` potentially strips a WrapperChar
    is_notrans = uppercase(tA) == 'N'
    # The triangular kernel is used only with real scalars, and only when `C` does not
    # contribute (β is zero) or is already Hermitian.
    use_herk = isreal(α) && isreal(β) && (iszero(β) || ishermitian(C))
    if use_herk
        generic_syrk!(C, A, true, is_notrans, α, β)
        # mirror the computed upper triangle into the lower one, conjugating
        return copytri!(C, 'U', true)
    end
    # Otherwise compute the full product with the second factor adjointed relative to `tA`.
    flipped = is_notrans ? 'C' : 'N'
    return _generic_matmatmul!(C, wrap(A, tA), wrap(A, flipped), α, β)
end
755848# legacy method
756849herk_wrapper! (C:: Union{StridedMatrix{T}, StridedMatrix{Complex{T}}} , tA:: AbstractChar , A:: Union{StridedVecOrMat{T}, StridedVecOrMat{Complex{T}}} ,
757850 _add:: MulAddMul = MulAddMul ()) where {T<: BlasReal } =
@@ -785,6 +878,7 @@ Base.@constprop :aggressive function gemm_wrapper!(C::StridedVecOrMat{T}, tA::Ab
785878 mB, nB = lapack_size (tB, B)
786879
787880 matmul_size_check (size (C), (mA, nA), (mB, nB))
881+ matmul2x2or3x3_nonzeroalpha! (C, tA, tB, A, B, α, β) && return C
788882
789883 if C === A || B === C
790884 throw (ArgumentError (" output matrix must not be aliased with input matrix" ))
# legacy entry point: unpacks the `MulAddMul` coefficients and forwards to the α/β method
function gemm_wrapper!(C::StridedVecOrMat{T}, tA::AbstractChar, tB::AbstractChar,
        A::StridedVecOrMat{T}, B::StridedVecOrMat{T},
        _add::MulAddMul = MulAddMul()) where {T<:BlasFloat}
    return gemm_wrapper!(C, tA, tB, A, B, _add.alpha, _add.beta)
end
# fallback for generic types
Base.@constprop :aggressive function gemm_wrapper!(C::StridedVecOrMat{T}, tA::AbstractChar, tB::AbstractChar,
        A::StridedVecOrMat{T}, B::StridedVecOrMat{T},
        α::Number, β::Number) where {T<:Number}
    # Try the small-matrix kernel first; it returns `true` when it has handled the product.
    if matmul2x2or3x3_nonzeroalpha!(C, tA, tB, A, B, α, β)
        return C
    end
    # Otherwise fall back to the generic multiplication on the wrapped operands.
    wrappedA = wrap(A, tA)
    wrappedB = wrap(B, tB)
    return _generic_matmatmul!(C, wrappedA, wrappedB, α, β)
end
806907
807908# Aggressive constprop helps propagate the values of tA and tB into wrap, which
808909# makes the calls concretely inferred
0 commit comments