@@ -24,19 +24,19 @@ Should pair with `parent`.
2424"""
2525struct TSize{S,T}
2626 function TSize {S,T} () where {S,T}
27- new {S::Tuple{Vararg{StaticDimension}},T::Bool } ()
27+ new {S::Tuple{Vararg{StaticDimension}},T::Symbol } ()
2828 end
2929end
30- TSize (A:: Type{<:Transpose{<:Any,<:StaticArray}} ) = TSize {size(A),true} ()
31- TSize (A:: Type{<:Adjoint{<:Real,<:StaticArray}} ) = TSize {size(A),true} () # can't handle complex adjoints yet
32- TSize (A:: Type{<:StaticArray} ) = TSize {size(A),false} ()
30+ TSize (A:: Type{<:StaticArrayLike} ) = TSize {size(A), gen_by_access(identity, A)} ()
3331TSize (A:: StaticArrayLike ) = TSize (typeof (A))
34- TSize (S:: Size{s} , T= false ) where s = TSize {s,T} ()
32+ TSize (S:: Size{s} , T= :any ) where s = TSize {s,T} ()
3533TSize (s:: Number ) = TSize (Size (s))
36- istranpose (:: TSize{<:Any,T} ) where T = T
34+ istranspose (:: TSize{<:Any,T} ) where T = (T === :transpose )
3735size (:: TSize{S} ) where S = S
3836Size (:: TSize{S} ) where S = Size {S} ()
39- Base. transpose (:: TSize{S,T} ) where {S,T} = TSize {reverse(S),!T} ()
37+ access_type (:: TSize{<:Any,T} ) where T = T
38+ Base. transpose (:: TSize{S,:transpose} ) where {S,T} = TSize {reverse(S),:any} ()
39+ Base. transpose (:: TSize{S,:any} ) where {S,T} = TSize {reverse(S),:transpose} ()
4040
4141# Get the parent of transposed arrays, or the array itself if it has no parent
4242# Different from Base.parent because we only want to get rid of Transpose and Adjoint
9797" Obtain an expression for the linear index of var[k,j], taking transposes into account"
9898@inline _lind (A:: Type{<:TSize} , k:: Int , j:: Int ) = _lind (:a , A, k, j)
9999function _lind (var:: Symbol , A:: Type{TSize{sa,tA}} , k:: Int , j:: Int ) where {sa,tA}
100- if tA
101- return :($ var[$ (LinearIndices (reverse (sa))[j, k])])
102- else
103- return :($ var[$ (LinearIndices (sa)[k, j])])
104- end
100+ return uplo_access (sa, var, k, j, tA)
105101end
106102
103+
104+
107105# Matrix-vector multiplication
108106@generated function _mul! (Sc:: TSize{sc} , c:: StaticVecOrMatLike , Sa:: TSize{sa} , Sb:: TSize{sb} ,
109107 a:: StaticMatrix , b:: StaticVector , _add:: MulAddMul ,
@@ -133,14 +131,21 @@ end
133131end
134132
135133# Outer product
136- @generated function _mul! (:: TSize{sc} , c:: StaticMatrix , :: TSize{sa,false } , :: TSize{sb,true } ,
134+ @generated function _mul! (:: TSize{sc} , c:: StaticMatrix , :: TSize{sa,:any } , tsb :: Union{ TSize{sb,:transpose},TSize{sb,:adjoint} } ,
137135 a:: StaticVector , b:: StaticVector , _add:: MulAddMul ) where {sa, sb, sc}
138136 if sc[1 ] != sa[1 ] || sc[2 ] != sb[2 ]
139137 throw (DimensionMismatch (" Tried to multiply arrays of size $sa and $sb and assign to array of size $sc " ))
140138 end
141139
140+ conjugate_b = isa (tsb, TSize{sb,:adjoint })
141+
142142 lhs = [:(c[$ (LinearIndices (sc)[i,j])]) for i = 1 : sa[1 ], j = 1 : sb[2 ]]
143- ab = [:(a[$ i] * b[$ j]) for i = 1 : sa[1 ], j = 1 : sb[2 ]]
143+ if conjugate_b
144+ ab = [:(a[$ i] * adjoint (b[$ j])) for i = 1 : sa[1 ], j = 1 : sb[2 ]]
145+ else
146+ ab = [:(a[$ i] * b[$ j]) for i = 1 : sa[1 ], j = 1 : sb[2 ]]
147+ end
148+
144149 exprs = _muladd_expr (lhs, ab, _add)
145150
146151 return quote
@@ -267,17 +272,18 @@ end
267272@inline _get_raw_data (A:: SizedArray ) = A. data
268273@inline _get_raw_data (A:: StaticArray ) = A
269274
270- function mul_blas! (:: TSize{<:Any,false} , c:: StaticMatrix , :: TSize{<:Any,tA} , :: TSize{<:Any,tB} ,
271- a:: StaticMatrix , b:: StaticMatrix , _add:: MulAddMul ) where {tA,tB}
272- mat_char (tA) = tA ? ' T' : ' N'
275+ function mul_blas! (:: TSize{<:Any,:any} , c:: StaticMatrix ,
276+ Sa:: Union{TSize{<:Any,:any}, TSize{<:Any,:transpose}} , Sb:: Union{TSize{<:Any,:any}, TSize{<:Any,:transpose}} ,
277+ a:: StaticMatrix , b:: StaticMatrix , _add:: MulAddMul )
278+ mat_char (s) = istranspose (s) ? ' T' : ' N'
273279 T = eltype (a)
274280 A = _get_raw_data (a)
275281 B = _get_raw_data (b)
276282 C = _get_raw_data (c)
277- BLAS. gemm! (mat_char (tA ), mat_char (tB ), T (alpha (_add)), A, B, T (beta (_add)), C)
283+ BLAS. gemm! (mat_char (Sa ), mat_char (Sb ), T (alpha (_add)), A, B, T (beta (_add)), C)
278284end
279285
280286# if C is transposed, transpose the entire expression
281- @inline mul_blas! (Sc:: TSize{<:Any,true } , c:: StaticMatrix , Sa:: TSize , Sb:: TSize ,
287+ @inline mul_blas! (Sc:: TSize{<:Any,:transpose } , c:: StaticMatrix , Sa:: TSize , Sb:: TSize ,
282288 a:: StaticMatrix , b:: StaticMatrix , _add:: MulAddMul ) =
283289 mul_blas! (transpose (Sc), c, transpose (Sb), transpose (Sa), b, a, _add)
0 commit comments