22# # broadcast! ##
33# ###############
44
5- import Base. Broadcast:
6- BroadcastStyle, AbstractArrayStyle, Broadcasted, DefaultArrayStyle, materialize!
5+ using Base. Broadcast: AbstractArrayStyle, DefaultArrayStyle, Style, Broadcasted
6+ using Base. Broadcast: broadcast_shape, _broadcast_getindex, combine_axes
7+ import Base. Broadcast: BroadcastStyle, materialize!, instantiate
78import Base. Broadcast: _bcs1 # for SOneTo axis information
89using Base. Broadcast: _bcsm
910# Add a new BroadcastStyle for StaticArrays, derived from AbstractArrayStyle
1011# A constructor that changes the style parameter N (array dimension) is also required
1112struct StaticArrayStyle{N} <: AbstractArrayStyle{N} end
1213StaticArrayStyle {M} (:: Val{N} ) where {M,N} = StaticArrayStyle {N} ()
1314BroadcastStyle (:: Type{<:StaticArray{<:Tuple, <:Any, N}} ) where {N} = StaticArrayStyle {N} ()
14- BroadcastStyle (:: Type{<:Transpose{<:Any, <:StaticArray{<:Tuple, <:Any, N}}} ) where {N} = StaticArrayStyle {N } ()
15- BroadcastStyle (:: Type{<:Adjoint{<:Any, <:StaticArray{<:Tuple, <:Any, N}}} ) where {N} = StaticArrayStyle {N } ()
15+ BroadcastStyle (:: Type{<:Transpose{<:Any, <:StaticArray}} ) = StaticArrayStyle {2 } ()
16+ BroadcastStyle (:: Type{<:Adjoint{<:Any, <:StaticArray}} ) = StaticArrayStyle {2 } ()
1617BroadcastStyle (:: Type{<:Diagonal{<:Any, <:StaticArray{<:Tuple, <:Any, 1}}} ) = StaticArrayStyle {2} ()
1718# Precedence rules
1819BroadcastStyle (:: StaticArrayStyle{M} , :: DefaultArrayStyle{N} ) where {M,N} =
1920 DefaultArrayStyle (Val (max (M, N)))
2021BroadcastStyle (:: StaticArrayStyle{M} , :: DefaultArrayStyle{0} ) where {M} =
2122 StaticArrayStyle {M} ()
23+
24+ # combine_axes overload (for Tuple)
25+ @inline static_combine_axes (A, B... ) = broadcast_shape (static_axes (A), static_combine_axes (B... ))
26+ static_combine_axes (A) = static_axes (A)
27+ static_axes (A) = axes (A)
28+ static_axes (x:: Tuple ) = (SOneTo {length(x)} (),)
29+ static_axes (bc:: Broadcasted{Style{Tuple}} ) = static_combine_axes (bc. args... )
30+ Broadcast. _axes (bc:: Broadcasted{<:StaticArrayStyle} , :: Nothing ) = static_combine_axes (bc. args... )
31+
32+ # instantiate overload
33+ @inline function instantiate (B:: Broadcasted{StaticArrayStyle{M}} ) where M
34+ if B. axes isa Tuple{Vararg{SOneTo}} || B. axes isa Tuple && length (B. axes) > M
35+ return invoke (instantiate, Tuple{Broadcasted}, B)
36+ elseif B. axes isa Nothing
37+ ax = static_combine_axes (B. args... )
38+ return Broadcasted {StaticArrayStyle{M}} (B. f, B. args, ax)
39+ else
40+ # We need to update B.axes for `broadcast!` if it's not static and `ndims(dest) < M`.
41+ ax = static_check_broadcast_shape (B. axes, static_combine_axes (B. args... ))
42+ return Broadcasted {StaticArrayStyle{M}} (B. f, B. args, ax)
43+ end
44+ end
45+ @inline function static_check_broadcast_shape (shp:: Tuple , Ashp:: Tuple{Vararg{SOneTo}} )
46+ ax1 = if length (Ashp[1 ]) == 1
47+ shp[1 ]
48+ elseif Ashp[1 ] == shp[1 ]
49+ Ashp[1 ]
50+ else
51+ throw (DimensionMismatch (" array could not be broadcast to match destination" ))
52+ end
53+ return (ax1, static_check_broadcast_shape (Base. tail (shp), Base. tail (Ashp))... )
54+ end
55+ static_check_broadcast_shape (:: Tuple{} , :: Tuple{SOneTo,Vararg{SOneTo}} ) =
56+ throw (DimensionMismatch (" cannot broadcast array to have fewer non-singleton dimensions" ))
57+ static_check_broadcast_shape (:: Tuple{} , :: Tuple{SOneTo{1},Vararg{SOneTo{1}}} ) = ()
58+ static_check_broadcast_shape (:: Tuple{} , :: Tuple{} ) = ()
2259# copy overload
2360@inline function Base. copy (B:: Broadcasted{StaticArrayStyle{M}} ) where M
2461 flat = Broadcast. flatten (B); as = flat. args; f = flat. f
2562 argsizes = broadcast_sizes (as... )
26- destsize = combine_sizes (argsizes)
27- _broadcast (f, destsize, argsizes, as... )
63+ ax = axes (B)
64+ ax isa Tuple{Vararg{SOneTo}} || error (" Dimension is not static. Please file a bug." )
65+ return _broadcast (f, Size (map (length, ax)), argsizes, as... )
2866end
2967# copyto! overloads
3068@inline Base. copyto! (dest, B:: Broadcasted{<:StaticArrayStyle} ) = _copyto! (dest, B)
3169@inline Base. copyto! (dest:: AbstractArray , B:: Broadcasted{<:StaticArrayStyle} ) = _copyto! (dest, B)
3270@inline function _copyto! (dest, B:: Broadcasted{StaticArrayStyle{M}} ) where M
3371 flat = Broadcast. flatten (B); as = flat. args; f = flat. f
3472 argsizes = broadcast_sizes (as... )
35- destsize = combine_sizes (( Size (dest), argsizes ... ) )
36- if Length (destsize) === Length {Dynamic()} ()
37- # destination dimension cannot be determined statically; fall back to generic broadcast!
38- return copyto! (dest, convert (Broadcasted{DefaultArrayStyle{M}}, B) )
73+ ax = axes (B )
74+ if ax isa Tuple{Vararg{SOneTo}}
75+ @boundscheck axes (dest) == ax || Broadcast . throwdm ( axes (dest), ax)
76+ return _broadcast! (f, Size ( map (length, ax)), dest, argsizes, as ... )
3977 end
40- _broadcast! (f, destsize, dest, argsizes, as... )
78+ # destination dimension cannot be determined statically; fall back to generic broadcast!
79+ return copyto! (dest, convert (Broadcasted{DefaultArrayStyle{M}}, B))
4180end
4281
4382# Resolving priority between dynamic and static axes
4483_bcs1 (a:: SOneTo , b:: SOneTo ) = _bcsm (b, a) ? b : (_bcsm (a, b) ? a : throw (DimensionMismatch (" arrays could not be broadcast to a common size" )))
45- _bcs1 (a:: SOneTo , b:: Base.OneTo ) = _bcs1 (Base. OneTo (a), b)
46- _bcs1 (a:: Base.OneTo , b:: SOneTo ) = _bcs1 (a, Base. OneTo (b))
84+ function _bcs1 (a:: SOneTo , b:: Base.OneTo )
85+ length (a) == 1 && return b
86+ if length (b) != length (a) && length (b) != 1
87+ throw (DimensionMismatch (" arrays could not be broadcast to a common size" ))
88+ end
89+ return a
90+ end
91+ _bcs1 (a:: Base.OneTo , b:: SOneTo ) = _bcs1 (b, a)
4792
4893# ##################################################
4994# # Internal broadcast machinery for StaticArrays ##
@@ -56,45 +101,13 @@ _bcs1(a::Base.OneTo, b::SOneTo) = _bcs1(a, Base.OneTo(b))
56101@inline broadcast_size (a:: AbstractArray ) = Size (a)
57102@inline broadcast_size (a:: Tuple ) = Size (length (a))
58103
59- function broadcasted_index (oldsize, newindex)
60- index = ones (Int, length (oldsize))
61- for i = 1 : length (oldsize)
62- if oldsize[i] != 1
63- index[i] = newindex[i]
64- end
65- end
66- return LinearIndices (oldsize)[index... ]
67- end
68-
69- # similar to Base.Broadcast.combine_indices:
70- @generated function combine_sizes (s:: Tuple{Vararg{Size}} )
71- sizes = [sz. parameters[1 ] for sz ∈ s. parameters]
72- ndims = 0
73- for i = 1 : length (sizes)
74- ndims = max (ndims, length (sizes[i]))
75- end
76- newsize = StaticDimension[Dynamic () for _ = 1 : ndims]
77- for i = 1 : length (sizes)
78- s = sizes[i]
79- for j = 1 : length (s)
80- if s[j] isa Dynamic
81- continue
82- elseif newsize[j] isa Dynamic || newsize[j] == 1
83- newsize[j] = s[j]
84- elseif newsize[j] ≠ s[j] && s[j] ≠ 1
85- throw (DimensionMismatch (" Tried to broadcast on inputs sized $sizes " ))
86- end
87- end
88- end
89- quote
90- @_inline_meta
91- Size ($ (tuple (newsize... )))
92- end
104+ broadcast_getindex (:: Tuple{} , i:: Int , I:: CartesianIndex ) = return :(_broadcast_getindex (a[$ i], $ I))
105+ function broadcast_getindex (oldsize:: Tuple , i:: Int , newindex:: CartesianIndex )
106+ li = LinearIndices (oldsize)
107+ ind = _broadcast_getindex (li, newindex)
108+ return :(a[$ i][$ ind])
93109end
94110
95- scalar_getindex (x) = x
96- scalar_getindex (x:: Ref ) = x[]
97-
98111isstatic (:: StaticArray ) = true
99112isstatic (:: Transpose{<:Any, <:StaticArray} ) = true
100113isstatic (:: Adjoint{<:Any, <:StaticArray} ) = true
@@ -118,13 +131,11 @@ end
118131
119132@generated function __broadcast (f, :: Size{newsize} , s:: Tuple{Vararg{Size}} , a... ) where newsize
120133 sizes = [sz. parameters[1 ] for sz ∈ s. parameters]
134+
121135 indices = CartesianIndices (newsize)
122136 exprs = similar (indices, Expr)
123137 for (j, current_ind) ∈ enumerate (indices)
124- exprs_vals = [
125- (! (a[i] <: AbstractArray || a[i] <: Tuple ) ? :(scalar_getindex (a[$ i])) : :(a[$ i][$ (broadcasted_index (sizes[i], current_ind))]))
126- for i = 1 : length (sizes)
127- ]
138+ exprs_vals = (broadcast_getindex (sz, i, current_ind) for (i, sz) in enumerate (sizes))
128139 exprs[j] = :(f ($ (exprs_vals... )))
129140 end
130141
@@ -138,27 +149,18 @@ end
138149# # Internal broadcast! machinery for StaticArrays ##
139150# ###################################################
140151
141- @generated function _broadcast! (f, :: Size{newsize} , dest:: AbstractArray , s:: Tuple{Vararg{Size}} , as... ) where {newsize}
142- sizes = [sz. parameters[1 ] for sz ∈ s. parameters]
143- sizes = tuple (sizes... )
144-
145- # TODO : this could also be done outside the generated function:
146- sizematch (Size {newsize} (), Size (dest)) ||
147- throw (DimensionMismatch (" Tried to broadcast to destination sized $newsize from inputs sized $sizes " ))
152+ @generated function _broadcast! (f, :: Size{newsize} , dest:: AbstractArray , s:: Tuple{Vararg{Size}} , a... ) where {newsize}
153+ sizes = [sz. parameters[1 ] for sz in s. parameters]
148154
149155 indices = CartesianIndices (newsize)
150156 exprs = similar (indices, Expr)
151157 for (j, current_ind) ∈ enumerate (indices)
152- exprs_vals = [
153- (! (as[i] <: AbstractArray || as[i] <: Tuple ) ? :(as[$ i][]) : :(as[$ i][$ (broadcasted_index (sizes[i], current_ind))]))
154- for i = 1 : length (sizes)
155- ]
158+ exprs_vals = (broadcast_getindex (sz, i, current_ind) for (i, sz) in enumerate (sizes))
156159 exprs[j] = :(dest[$ j] = f ($ (exprs_vals... )))
157160 end
158161
159162 return quote
160- @_propagate_inbounds_meta
161- @boundscheck sizematch ($ (Size {newsize} ()), dest) || throw (DimensionMismatch (" array could not be broadcast to match destination" ))
163+ @_inline_meta
162164 @inbounds $ (Expr (:block , exprs... ))
163165 return dest
164166 end
0 commit comments