@@ -58,7 +58,7 @@ static_check_broadcast_shape(::Tuple{}, ::Tuple{SOneTo{1},Vararg{SOneTo{1}}}) =
5858static_check_broadcast_shape (:: Tuple{} , :: Tuple{} ) = ()
5959# copy overload
6060@inline function Base. copy (B:: Broadcasted{StaticArrayStyle{M}} ) where M
61- flat = Broadcast . flatten (B); as = flat. args; f = flat. f
61+ flat = broadcast_flatten (B); as = flat. args; f = flat. f
6262 argsizes = broadcast_sizes (as... )
6363 ax = axes (B)
6464 ax isa Tuple{Vararg{SOneTo}} || error (" Dimension is not static. Please file a bug." )
6868@inline Base. copyto! (dest, B:: Broadcasted{<:StaticArrayStyle} ) = _copyto! (dest, B)
6969@inline Base. copyto! (dest:: AbstractArray , B:: Broadcasted{<:StaticArrayStyle} ) = _copyto! (dest, B)
7070@inline function _copyto! (dest, B:: Broadcasted{StaticArrayStyle{M}} ) where M
71- flat = Broadcast . flatten (B); as = flat. args; f = flat. f
71+ flat = broadcast_flatten (B); as = flat. args; f = flat. f
7272 argsizes = broadcast_sizes (as... )
7373 ax = axes (B)
7474 if ax isa Tuple{Vararg{SOneTo}}
165165 return dest
166166 end
167167end
168+
169+ # Work around for https://github.com/JuliaLang/julia/issues/27988
170+ # The following code is borrowed from https://github.com/JuliaLang/julia/pull/43322
171+ # with some modification to make it also works on 1.6.
172+ # TODO : make `broadcast_flatten` call `Broadcast.flatten` once julia#43322 is merged.
173+ module StableFlatten
174+
175+ export broadcast_flatten
176+
177+ using Base: tail
178+ using Base. Broadcast: isflat, Broadcasted
179+
180+ maybeconstructor (f) = f
181+ maybeconstructor (:: Type{F} ) where {F} = (args... ; kwargs... ) -> F (args... ; kwargs... )
182+
183+ function broadcast_flatten (bc:: Broadcasted{Style} ) where {Style}
184+ isflat (bc) && return bc
185+ args = cat_nested (bc)
186+ len = Val {length(args)} ()
187+ makeargs = make_makeargs (bc. args, len, ntuple (_-> true , len))
188+ f = maybeconstructor (bc. f)
189+ @inline newf (args... ) = f (prepare_args (makeargs, args)... )
190+ return Broadcasted {Style} (newf, args, bc. axes)
191+ end
192+
193+ cat_nested (bc:: Broadcasted ) = cat_nested_args (bc. args)
194+ cat_nested_args (:: Tuple{} ) = ()
195+ cat_nested_args (t:: Tuple ) = (cat_nested (t[1 ])... , cat_nested_args (tail (t))... )
196+ cat_nested (@nospecialize (a)) = (a,)
197+
198+ function make_makeargs (args:: Tuple , len, flags)
199+ makeargs, r = _make_makeargs (args, len, flags)
200+ r isa Tuple{} || error (" Internal error. Please file a bug" )
201+ return makeargs
202+ end
203+
204+ # We build `makeargs` by traversing the broadcast nodes recursively.
205+ # note: `len` isa `Val` indicates the length of whole flattened argument list.
206+ # `flags` is a tuple of `Bool` with the same length of the rest arguments.
207+ @inline function _make_makeargs (args:: Tuple , len:: Val , flags:: Tuple )
208+ head, flags′ = _make_makeargs1 (args[1 ], len, flags)
209+ rest, flags″ = _make_makeargs (tail (args), len, flags′)
210+ (head, rest... ), flags″
211+ end
212+ _make_makeargs (:: Tuple{} , :: Val , x:: Tuple ) = (), x
213+
214+ # For flat nodes:
215+ # 1. we just consume one argument, and return the "pick" function
216+ @inline function _make_makeargs1 (@nospecialize (a), :: Val{N} , flags:: Tuple ) where {N}
217+ pickargs (:: Val{N} ) where {N} = (@nospecialize (x:: Tuple )) -> x[N]
218+ return pickargs (Val {N-length(flags)+1} ()), tail (flags)
219+ end
220+
221+ # For nested nodes, we form the `makeargs1` based on the child `makeargs` (n += length(cat_nested(bc)))
222+ @inline function _make_makeargs1 (bc:: Broadcasted , len:: Val , flags:: Tuple )
223+ makeargs, flags′ = _make_makeargs (bc. args, len, flags)
224+ f = maybeconstructor (bc. f)
225+ @inline makeargs1 (@nospecialize (args:: Tuple )) = f (prepare_args (makeargs, args)... )
226+ makeargs1, flags′
227+ end
228+
229+ prepare_args (:: Tuple{} , @nospecialize (:: Tuple )) = ()
230+ @inline prepare_args (makeargs:: Tuple , @nospecialize (x:: Tuple )) = (makeargs[1 ](x), prepare_args (tail (makeargs), x)... )
231+ end
232+ using . StableFlatten
0 commit comments