@@ -13,6 +13,7 @@ StaticArrayStyle{M}(::Val{N}) where {M,N} = StaticArrayStyle{N}()
1313BroadcastStyle (:: Type{<:StaticArray{<:Tuple, <:Any, N}} ) where {N} = StaticArrayStyle {N} ()
1414BroadcastStyle (:: Type{<:Transpose{<:Any, <:StaticArray{<:Tuple, <:Any, N}}} ) where {N} = StaticArrayStyle {N} ()
1515BroadcastStyle (:: Type{<:Adjoint{<:Any, <:StaticArray{<:Tuple, <:Any, N}}} ) where {N} = StaticArrayStyle {N} ()
16+ BroadcastStyle (:: Type{<:Diagonal{<:Any, <:StaticArray{<:Tuple, <:Any, 1}}} ) = StaticArrayStyle {2} ()
1617# Precedence rules
1718BroadcastStyle (:: StaticArrayStyle{M} , :: DefaultArrayStyle{N} ) where {M,N} =
1819 DefaultArrayStyle (Val (max (M, N)))
@@ -97,7 +98,7 @@ scalar_getindex(x) = x
9798scalar_getindex (x:: Ref ) = x[]
9899
99100@generated function _broadcast (f, :: Size{newsize} , s:: Tuple{Vararg{Size}} , a... ) where newsize
100- first_staticarray = a[findfirst (ai -> ai <: Union{StaticArray, Transpose{<:Any, <:StaticArray}, Adjoint{<:Any, <:StaticArray}} , a)]
101+ first_staticarray = a[findfirst (ai -> ai <: Union{StaticArray, Transpose{<:Any, <:StaticArray}, Adjoint{<:Any, <:StaticArray}, Diagonal{<:Any, <:StaticArray} } , a)]
101102
102103 if prod (newsize) == 0
103104 # Use inference to get eltype in empty case (see also comments in _map)
0 commit comments