@@ -337,7 +337,7 @@ julia> lmul!(F.Q, B)
337337lmul! (A, B)
338338
339339# THE one big BLAS dispatch
340- @inline function generic_matmatmul! (C:: StridedMatrix{T} , tA, tB, A:: StridedVecOrMat{T} , B:: StridedVecOrMat{T} ,
340+ Base . @constprop :aggressive function generic_matmatmul! (C:: StridedMatrix{T} , tA, tB, A:: StridedVecOrMat{T} , B:: StridedVecOrMat{T} ,
341341 _add:: MulAddMul = MulAddMul ()) where {T<: BlasFloat }
342342 if all (in ((' N' , ' T' , ' C' )), (tA, tB))
343343 if tA == ' T' && tB == ' N' && A === B
@@ -364,16 +364,16 @@ lmul!(A, B)
364364 return BLAS. hemm! (' R' , tB == ' H' ? ' U' : ' L' , alpha, B, A, beta, C)
365365 end
366366 end
367- return _generic_matmatmul! (C, ' N ' , ' N ' , wrap (A, tA), wrap (B, tB), _add)
367+ return _generic_matmatmul! (C, wrap (A, tA), wrap (B, tB), _add)
368368end
369369
370370# Complex matrix times (transposed) real matrix. Reinterpret the first matrix to real for efficiency.
371- @inline function generic_matmatmul! (C:: StridedVecOrMat{Complex{T}} , tA, tB, A:: StridedVecOrMat{Complex{T}} , B:: StridedVecOrMat{T} ,
371+ Base . @constprop :aggressive function generic_matmatmul! (C:: StridedVecOrMat{Complex{T}} , tA, tB, A:: StridedVecOrMat{Complex{T}} , B:: StridedVecOrMat{T} ,
372372 _add:: MulAddMul = MulAddMul ()) where {T<: BlasReal }
373373 if all (in ((' N' , ' T' , ' C' )), (tA, tB))
374374 gemm_wrapper! (C, tA, tB, A, B, _add)
375375 else
376- _generic_matmatmul! (C, ' N ' , ' N ' , wrap (A, tA), wrap (B, tB), _add)
376+ _generic_matmatmul! (C, wrap (A, tA), wrap (B, tB), _add)
377377 end
378378end
379379
@@ -563,11 +563,11 @@ function gemm_wrapper(tA::AbstractChar, tB::AbstractChar,
563563 if all (in ((' N' , ' T' , ' C' )), (tA, tB))
564564 gemm_wrapper! (C, tA, tB, A, B)
565565 else
566- _generic_matmatmul! (C, ' N ' , ' N ' , wrap (A, tA), wrap (B, tB), _add)
566+ _generic_matmatmul! (C, wrap (A, tA), wrap (B, tB), _add)
567567 end
568568end
569569
570- function gemm_wrapper! (C:: StridedVecOrMat{T} , tA:: AbstractChar , tB:: AbstractChar ,
570+ Base . @constprop :aggressive function gemm_wrapper! (C:: StridedVecOrMat{T} , tA:: AbstractChar , tB:: AbstractChar ,
571571 A:: StridedVecOrMat{T} , B:: StridedVecOrMat{T} ,
572572 _add = MulAddMul ()) where {T<: BlasFloat }
573573 mA, nA = lapack_size (tA, A)
@@ -604,10 +604,10 @@ function gemm_wrapper!(C::StridedVecOrMat{T}, tA::AbstractChar, tB::AbstractChar
604604 stride (C, 2 ) >= size (C, 1 ))
605605 return BLAS. gemm! (tA, tB, alpha, A, B, beta, C)
606606 end
607- _generic_matmatmul! (C, tA, tB, A, B , _add)
607+ _generic_matmatmul! (C, wrap (A, tA), wrap (B, tB) , _add)
608608end
609609
610- function gemm_wrapper! (C:: StridedVecOrMat{Complex{T}} , tA:: AbstractChar , tB:: AbstractChar ,
610+ Base . @constprop :aggressive function gemm_wrapper! (C:: StridedVecOrMat{Complex{T}} , tA:: AbstractChar , tB:: AbstractChar ,
611611 A:: StridedVecOrMat{Complex{T}} , B:: StridedVecOrMat{T} ,
612612 _add = MulAddMul ()) where {T<: BlasReal }
613613 mA, nA = lapack_size (tA, A)
@@ -647,7 +647,7 @@ function gemm_wrapper!(C::StridedVecOrMat{Complex{T}}, tA::AbstractChar, tB::Abs
647647 BLAS. gemm! (tA, tB, alpha, reinterpret (T, A), B, beta, reinterpret (T, C))
648648 return C
649649 end
650- _generic_matmatmul! (C, tA, tB, A, B , _add)
650+ _generic_matmatmul! (C, wrap (A, tA), wrap (B, tB) , _add)
651651end
652652
653653# blas.jl defines matmul for floats; other integer and mixed precision
@@ -764,197 +764,65 @@ end
764764
765765const tilebufsize = 10800 # Approximately 32k/3
766766
767- function generic_matmatmul! (C:: AbstractVecOrMat , tA, tB, A:: AbstractVecOrMat , B:: AbstractVecOrMat , _add:: MulAddMul )
768- mA, nA = lapack_size (tA, A)
769- mB, nB = lapack_size (tB, B)
770- mC, nC = size (C)
771-
772- if iszero (_add. alpha)
773- return _rmul_or_fill! (C, _add. beta)
774- end
775- if mA == nA == mB == nB == mC == nC == 2
776- return matmul2x2! (C, tA, tB, A, B, _add)
777- end
778- if mA == nA == mB == nB == mC == nC == 3
779- return matmul3x3! (C, tA, tB, A, B, _add)
780- end
781- A, tA = tA in (' H' , ' h' , ' S' , ' s' ) ? (wrap (A, tA), ' N' ) : (A, tA)
782- B, tB = tB in (' H' , ' h' , ' S' , ' s' ) ? (wrap (B, tB), ' N' ) : (B, tB)
783- _generic_matmatmul! (C, tA, tB, A, B, _add)
784- end
767+ Base. @constprop :aggressive generic_matmatmul! (C:: AbstractVecOrMat , tA, tB, A:: AbstractVecOrMat , B:: AbstractVecOrMat , _add:: MulAddMul ) =
768+ _generic_matmatmul! (C, wrap (A, tA), wrap (B, tB), _add)
785769
786- function _generic_matmatmul! (C:: AbstractVecOrMat{R} , tA, tB , A:: AbstractVecOrMat{T} , B:: AbstractVecOrMat{S} ,
770+ @noinline function _generic_matmatmul! (C:: AbstractVecOrMat{R} , A:: AbstractVecOrMat{T} , B:: AbstractVecOrMat{S} ,
787771 _add:: MulAddMul ) where {T,S,R}
788- @assert tA in (' N' , ' T' , ' C' ) && tB in (' N' , ' T' , ' C' )
789- require_one_based_indexing (C, A, B)
790-
791- mA, nA = lapack_size (tA, A)
792- mB, nB = lapack_size (tB, B)
793- if mB != nA
794- throw (DimensionMismatch (lazy " matrix A has dimensions ($mA,$nA), matrix B has dimensions ($mB,$nB)" ))
795- end
796- if size (C,1 ) != mA || size (C,2 ) != nB
797- throw (DimensionMismatch (lazy " result C has dimensions $(size(C)), needs ($mA,$nB)" ))
798- end
799-
800- if iszero (_add. alpha) || isempty (A) || isempty (B)
801- return _rmul_or_fill! (C, _add. beta)
802- end
803-
804- tile_size = 0
805- if isbitstype (R) && isbitstype (T) && isbitstype (S) && (tA == ' N' || tB != ' N' )
806- tile_size = floor (Int, sqrt (tilebufsize / max (sizeof (R), sizeof (S), sizeof (T), 1 )))
807- end
808- @inbounds begin
809- if tile_size > 0
810- sz = (tile_size, tile_size)
811- Atile = Array {T} (undef, sz)
812- Btile = Array {S} (undef, sz)
813-
814- z1 = zero (A[1 , 1 ]* B[1 , 1 ] + A[1 , 1 ]* B[1 , 1 ])
815- z = convert (promote_type (typeof (z1), R), z1)
816-
817- if mA < tile_size && nA < tile_size && nB < tile_size
818- copy_transpose! (Atile, 1 : nA, 1 : mA, tA, A, 1 : mA, 1 : nA)
819- copyto! (Btile, 1 : mB, 1 : nB, tB, B, 1 : mB, 1 : nB)
820- for j = 1 : nB
821- boff = (j- 1 )* tile_size
822- for i = 1 : mA
823- aoff = (i- 1 )* tile_size
824- s = z
825- for k = 1 : nA
826- s += Atile[aoff+ k] * Btile[boff+ k]
827- end
828- _modify! (_add, s, C, (i,j))
829- end
830- end
831- else
832- Ctile = Array {R} (undef, sz)
833- for jb = 1 : tile_size: nB
834- jlim = min (jb+ tile_size- 1 ,nB)
835- jlen = jlim- jb+ 1
836- for ib = 1 : tile_size: mA
837- ilim = min (ib+ tile_size- 1 ,mA)
838- ilen = ilim- ib+ 1
839- fill! (Ctile, z)
840- for kb = 1 : tile_size: nA
841- klim = min (kb+ tile_size- 1 ,mB)
842- klen = klim- kb+ 1
843- copy_transpose! (Atile, 1 : klen, 1 : ilen, tA, A, ib: ilim, kb: klim)
844- copyto! (Btile, 1 : klen, 1 : jlen, tB, B, kb: klim, jb: jlim)
845- for j= 1 : jlen
846- bcoff = (j- 1 )* tile_size
847- for i = 1 : ilen
848- aoff = (i- 1 )* tile_size
849- s = z
850- for k = 1 : klen
851- s += Atile[aoff+ k] * Btile[bcoff+ k]
852- end
853- Ctile[bcoff+ i] += s
854- end
855- end
856- end
857- if isone (_add. alpha) && iszero (_add. beta)
858- copyto! (C, ib: ilim, jb: jlim, Ctile, 1 : ilen, 1 : jlen)
859- else
860- C[ib: ilim, jb: jlim] .= @views _add .(Ctile[1 : ilen, 1 : jlen], C[ib: ilim, jb: jlim])
861- end
862- end
772+ AxM = axes (A, 1 )
773+ AxK = axes (A, 2 ) # we use two `axes` calls in case of `AbstractVector`
774+ BxK = axes (B, 1 )
775+ BxN = axes (B, 2 )
776+ CxM = axes (C, 1 )
777+ CxN = axes (C, 2 )
778+ if AxM != CxM
779+ throw (DimensionMismatch (lazy " matrix A has axes ($AxM,$AxK), matrix C has axes ($CxM,$CxN)" ))
780+ end
781+ if AxK != BxK
782+ throw (DimensionMismatch (lazy " matrix A has axes ($AxM,$AxK), matrix B has axes ($BxK,$CxN)" ))
783+ end
784+ if BxN != CxN
785+ throw (DimensionMismatch (lazy " matrix B has axes ($BxK,$BxN), matrix C has axes ($CxM,$CxN)" ))
786+ end
787+ if isbitstype (R) && sizeof (R) ≤ 16 && ! (A isa Adjoint || A isa Transpose)
788+ _rmul_or_fill! (C, _add. beta)
789+ (iszero (_add. alpha) || isempty (A) || isempty (B)) && return C
790+ @inbounds for n in BxN, k in BxK
791+ Balpha = B[k,n]* _add. alpha
792+ @simd for m in AxM
793+ C[m,n] = muladd (A[m,k], Balpha, C[m,n])
863794 end
864795 end
796+ elseif isbitstype (R) && sizeof (R) ≤ 16 && ((A isa Adjoint && B isa Adjoint) || (A isa Transpose && B isa Transpose))
797+ _rmul_or_fill! (C, _add. beta)
798+ (iszero (_add. alpha) || isempty (A) || isempty (B)) && return C
799+ t = wrapperop (A)
800+ pB = parent (B)
801+ pA = parent (A)
802+ tmp = similar (C, CxN)
803+ ci = first (CxM)
804+ ta = t (_add. alpha)
805+ for i in AxM
806+ mul! (tmp, pB, view (pA, :, i))
807+ C[ci,:] .+ = t .(ta .* tmp)
808+ ci += 1
809+ end
865810 else
866- # Multiplication for non-plain-data uses the naive algorithm
867- if tA == ' N'
868- if tB == ' N'
869- for i = 1 : mA, j = 1 : nB
870- z2 = zero (A[i, 1 ]* B[1 , j] + A[i, 1 ]* B[1 , j])
871- Ctmp = convert (promote_type (R, typeof (z2)), z2)
872- for k = 1 : nA
873- Ctmp += A[i, k]* B[k, j]
874- end
875- _modify! (_add, Ctmp, C, (i,j))
876- end
877- elseif tB == ' T'
878- for i = 1 : mA, j = 1 : nB
879- z2 = zero (A[i, 1 ]* transpose (B[j, 1 ]) + A[i, 1 ]* transpose (B[j, 1 ]))
880- Ctmp = convert (promote_type (R, typeof (z2)), z2)
881- for k = 1 : nA
882- Ctmp += A[i, k] * transpose (B[j, k])
883- end
884- _modify! (_add, Ctmp, C, (i,j))
885- end
886- else
887- for i = 1 : mA, j = 1 : nB
888- z2 = zero (A[i, 1 ]* B[j, 1 ]' + A[i, 1 ]* B[j, 1 ]' )
889- Ctmp = convert (promote_type (R, typeof (z2)), z2)
890- for k = 1 : nA
891- Ctmp += A[i, k]* B[j, k]'
892- end
893- _modify! (_add, Ctmp, C, (i,j))
894- end
895- end
896- elseif tA == ' T'
897- if tB == ' N'
898- for i = 1 : mA, j = 1 : nB
899- z2 = zero (transpose (A[1 , i])* B[1 , j] + transpose (A[1 , i])* B[1 , j])
900- Ctmp = convert (promote_type (R, typeof (z2)), z2)
901- for k = 1 : nA
902- Ctmp += transpose (A[k, i]) * B[k, j]
903- end
904- _modify! (_add, Ctmp, C, (i,j))
905- end
906- elseif tB == ' T'
907- for i = 1 : mA, j = 1 : nB
908- z2 = zero (transpose (A[1 , i])* transpose (B[j, 1 ]) + transpose (A[1 , i])* transpose (B[j, 1 ]))
909- Ctmp = convert (promote_type (R, typeof (z2)), z2)
910- for k = 1 : nA
911- Ctmp += transpose (A[k, i]) * transpose (B[j, k])
912- end
913- _modify! (_add, Ctmp, C, (i,j))
914- end
915- else
916- for i = 1 : mA, j = 1 : nB
917- z2 = zero (transpose (A[1 , i])* B[j, 1 ]' + transpose (A[1 , i])* B[j, 1 ]' )
918- Ctmp = convert (promote_type (R, typeof (z2)), z2)
919- for k = 1 : nA
920- Ctmp += transpose (A[k, i]) * adjoint (B[j, k])
921- end
922- _modify! (_add, Ctmp, C, (i,j))
923- end
924- end
925- else
926- if tB == ' N'
927- for i = 1 : mA, j = 1 : nB
928- z2 = zero (A[1 , i]' * B[1 , j] + A[1 , i]' * B[1 , j])
929- Ctmp = convert (promote_type (R, typeof (z2)), z2)
930- for k = 1 : nA
931- Ctmp += A[k, i]' B[k, j]
932- end
933- _modify! (_add, Ctmp, C, (i,j))
934- end
935- elseif tB == ' T'
936- for i = 1 : mA, j = 1 : nB
937- z2 = zero (A[1 , i]' * transpose (B[j, 1 ]) + A[1 , i]' * transpose (B[j, 1 ]))
938- Ctmp = convert (promote_type (R, typeof (z2)), z2)
939- for k = 1 : nA
940- Ctmp += adjoint (A[k, i]) * transpose (B[j, k])
941- end
942- _modify! (_add, Ctmp, C, (i,j))
943- end
944- else
945- for i = 1 : mA, j = 1 : nB
946- z2 = zero (A[1 , i]' * B[j, 1 ]' + A[1 , i]' * B[j, 1 ]' )
947- Ctmp = convert (promote_type (R, typeof (z2)), z2)
948- for k = 1 : nA
949- Ctmp += A[k, i]' B[j, k]'
950- end
951- _modify! (_add, Ctmp, C, (i,j))
952- end
811+ if iszero (_add. alpha) || isempty (A) || isempty (B)
812+ return _rmul_or_fill! (C, _add. beta)
813+ end
814+ a1 = first (AxK)
815+ b1 = first (BxK)
816+ @inbounds for i in AxM, j in BxN
817+ z2 = zero (A[i, a1]* B[b1, j] + A[i, a1]* B[b1, j])
818+ Ctmp = convert (promote_type (R, typeof (z2)), z2)
819+ @simd for k in AxK
820+ Ctmp = muladd (A[i, k], B[k, j], Ctmp)
953821 end
822+ _modify! (_add, Ctmp, C, (i,j))
954823 end
955824 end
956- end # @inbounds
957- C
825+ return C
958826end
959827
960828
@@ -963,7 +831,7 @@ function matmul2x2(tA, tB, A::AbstractMatrix{T}, B::AbstractMatrix{S}) where {T,
963831 matmul2x2! (similar (B, promote_op (matprod, T, S), 2 , 2 ), tA, tB, A, B)
964832end
965833
966- function matmul2x2! (C:: AbstractMatrix , tA, tB, A:: AbstractMatrix , B:: AbstractMatrix ,
834+ Base . @constprop :aggressive function matmul2x2! (C:: AbstractMatrix , tA, tB, A:: AbstractMatrix , B:: AbstractMatrix ,
967835 _add:: MulAddMul = MulAddMul ())
968836 require_one_based_indexing (C, A, B)
969837 if ! (size (A) == size (B) == size (C) == (2 ,2 ))
@@ -1030,7 +898,7 @@ function matmul3x3(tA, tB, A::AbstractMatrix{T}, B::AbstractMatrix{S}) where {T,
1030898 matmul3x3! (similar (B, promote_op (matprod, T, S), 3 , 3 ), tA, tB, A, B)
1031899end
1032900
1033- function matmul3x3! (C:: AbstractMatrix , tA, tB, A:: AbstractMatrix , B:: AbstractMatrix ,
901+ Base . @constprop :aggressive function matmul3x3! (C:: AbstractMatrix , tA, tB, A:: AbstractMatrix , B:: AbstractMatrix ,
1034902 _add:: MulAddMul = MulAddMul ())
1035903 require_one_based_indexing (C, A, B)
1036904 if ! (size (A) == size (B) == size (C) == (3 ,3 ))
0 commit comments