1+ """
2+ _InitialValue
3+
4+ A singleton type for representing "universal" initial value (identity element).
5+
6+ The idea is that, given `op` for `mapfoldl`, virtually, we define an "extended"
7+ version of it by
8+
9+ op′(::_InitialValue, x) = x
10+ op′(acc, x) = op(acc, x)
11+
12+ This is just a conceptually useful model to have in mind and we don't actually
13+ define `op′` here (yet?). But see `Base.BottomRF` for how it might work in
14+ action.
15+
16+ (It is related to that you can always turn a semigroup without an identity into
17+ a monoid by "adjoining" an element that acts as the identity.)
18+ """
19+ struct _InitialValue end
20+
121@inline _first (a1, as... ) = a1
222
323# ###############
86106# # mapreduce ##
87107# ##############
88108
89- @inline function mapreduce (f, op, a:: StaticArray , b:: StaticArray... ; dims= :,kw ... )
90- _mapreduce (f, op, dims, kw . data , same_size (a, b... ), a, b... )
109+ @inline function mapreduce (f, op, a:: StaticArray , b:: StaticArray... ; dims= :, init = _InitialValue () )
110+ _mapreduce (f, op, dims, init , same_size (a, b... ), a, b... )
91111end
92112
93- @generated function _mapreduce (f, op, dims:: Colon , nt:: NamedTuple{()} ,
94- :: Size{S} , a:: StaticArray... ) where {S}
113+ @inline _mapreduce (args:: Vararg{Any,N} ) where N = _mapfoldl (args... )
114+
115+ @generated function _mapfoldl (f, op, dims:: Colon , init, :: Size{S} , a:: StaticArray... ) where {S}
95116 tmp = [:(a[$ j][1 ]) for j ∈ 1 : length (a)]
96117 expr = :(f ($ (tmp... )))
97- for i ∈ 2 : prod (S)
98- tmp = [:(a[$ j][$ i]) for j ∈ 1 : length (a)]
99- expr = :(op ($ expr, f ($ (tmp... ))))
100- end
101- return quote
102- @_inline_meta
103- @inbounds return $ expr
118+ if init === _InitialValue
119+ expr = :(Base. reduce_first (op, $ expr))
120+ else
121+ expr = :(op (init, $ expr))
104122 end
105- end
106-
107- @generated function _mapreduce (f, op, dims:: Colon , nt:: NamedTuple{(:init,)} ,
108- :: Size{S} , a:: StaticArray... ) where {S}
109- expr = :(nt. init)
110- for i ∈ 1 : prod (S)
123+ for i ∈ 2 : prod (S)
111124 tmp = [:(a[$ j][$ i]) for j ∈ 1 : length (a)]
112125 expr = :(op ($ expr, f ($ (tmp... ))))
113126 end
@@ -117,24 +130,24 @@ end
117130 end
118131end
119132
120- @inline function _mapreduce (f, op, D:: Int , nt :: NamedTuple , sz:: Size{S} , a:: StaticArray ) where {S}
133+ @inline function _mapreduce (f, op, D:: Int , init , sz:: Size{S} , a:: StaticArray ) where {S}
121134 # Body of this function is split because constant propagation (at least
122135 # as of Julia 1.2) can't always correctly propagate here and
123136 # as a result the function is not type stable and very slow.
124137 # This makes it at least fast for three dimensions but people should use
125138 # for example any(a; dims=Val(1)) instead of any(a; dims=1) anyway.
126139 if D == 1
127- return _mapreduce (f, op, Val (1 ), nt , sz, a)
140+ return _mapreduce (f, op, Val (1 ), init , sz, a)
128141 elseif D == 2
129- return _mapreduce (f, op, Val (2 ), nt , sz, a)
142+ return _mapreduce (f, op, Val (2 ), init , sz, a)
130143 elseif D == 3
131- return _mapreduce (f, op, Val (3 ), nt , sz, a)
144+ return _mapreduce (f, op, Val (3 ), init , sz, a)
132145 else
133- return _mapreduce (f, op, Val (D), nt , sz, a)
146+ return _mapreduce (f, op, Val (D), init , sz, a)
134147 end
135148end
136149
137- @generated function _mapreduce (f, op, dims:: Val{D} , nt :: NamedTuple{()} ,
150+ @generated function _mapfoldl (f, op, dims:: Val{D} , init ,
138151 :: Size{S} , a:: StaticArray ) where {S,D}
139152 N = length (S)
140153 Snew = ([n== D ? 1 : S[n] for n = 1 : N]. .. ,)
@@ -143,32 +156,12 @@ end
143156 itr = [1 : n for n ∈ Snew]
144157 for i ∈ Base. product (itr... )
145158 expr = :(f (a[$ (i... )]))
146- for k = 2 : S[D]
147- ik = collect (i )
148- ik[D] = k
149- expr = :(op ($ expr, f (a[ $ (ik ... )]) ))
159+ if init === _InitialValue
160+ expr = :(Base . reduce_first (op, $ expr) )
161+ else
162+ expr = :(op (init, $ expr ))
150163 end
151-
152- exprs[i... ] = expr
153- end
154-
155- return quote
156- @_inline_meta
157- @inbounds elements = tuple ($ (exprs... ))
158- @inbounds return similar_type (a, eltype (elements), Size ($ Snew))(elements)
159- end
160- end
161-
162- @generated function _mapreduce (f, op, dims:: Val{D} , nt:: NamedTuple{(:init,)} ,
163- :: Size{S} , a:: StaticArray ) where {S,D}
164- N = length (S)
165- Snew = ([n== D ? 1 : S[n] for n = 1 : N]. .. ,)
166-
167- exprs = Array {Expr} (undef, Snew)
168- itr = [1 : n for n = Snew]
169- for i ∈ Base. product (itr... )
170- expr = :(nt. init)
171- for k = 1 : S[D]
164+ for k = 2 : S[D]
172165 ik = collect (i)
173166 ik[D] = k
174167 expr = :(op ($ expr, f (a[$ (ik... )])))
@@ -188,20 +181,33 @@ end
188181# # reduce ##
189182# ###########
190183
191- @inline reduce (op, a:: StaticArray ; dims= :, kw... ) = _reduce (op, a, dims, kw. data)
184+ @inline reduce (op, a:: StaticArray ; dims = :, init = _InitialValue ()) =
185+ _reduce (op, a, dims, init)
192186
193187# disambiguation
194188reduce (:: typeof (vcat), A:: StaticArray{<:Tuple,<:AbstractVecOrMat} ) =
195189 Base. _typed_vcat (mapreduce (eltype, promote_type, A), A)
196190reduce (:: typeof (vcat), A:: StaticArray{<:Tuple,<:StaticVecOrMatLike} ) =
197- _reduce (vcat, A, :, NamedTuple ())
191+ _reduce (vcat, A, :, _InitialValue ())
198192
199193reduce (:: typeof (hcat), A:: StaticArray{<:Tuple,<:AbstractVecOrMat} ) =
200194 Base. _typed_hcat (mapreduce (eltype, promote_type, A), A)
201195reduce (:: typeof (hcat), A:: StaticArray{<:Tuple,<:StaticVecOrMatLike} ) =
202- _reduce (hcat, A, :, NamedTuple ())
196+ _reduce (hcat, A, :, _InitialValue ())
203197
204- @inline _reduce (op, a:: StaticArray , dims, kw:: NamedTuple = NamedTuple ()) = _mapreduce (identity, op, dims, kw, Size (a), a)
198+ @inline _reduce (op, a:: StaticArray , dims, init = _InitialValue ()) =
199+ _mapreduce (identity, op, dims, init, Size (a), a)
200+
201+ # ###############
202+ # # (map)foldl ##
203+ # ###############
204+
205+ @inline mapfoldl (f, op:: R , a:: StaticArray ; init = _InitialValue ()) where {R} =
206+ _mapfoldl (f, op, :, init, Size (a), a)
207+ @inline foldl (op:: R , a:: StaticArray ; init = _InitialValue ()) where {R} =
208+ _foldl (op, a, :, init)
209+ @inline _foldl (op:: R , a, dims, init = _InitialValue ()) where {R} =
210+ _mapfoldl (identity, op, dims, init, Size (a), a)
205211
206212# ######################
207213# # related functions ##
@@ -227,37 +233,37 @@ reduce(::typeof(hcat), A::StaticArray{<:Tuple,<:StaticVecOrMatLike}) =
227233@inline iszero (a:: StaticArray{<:Tuple,T} ) where {T} = reduce ((x,y) -> x && iszero (y), a, init= true )
228234
229235@inline sum (a:: StaticArray{<:Tuple,T} ; dims= :) where {T} = _reduce (+ , a, dims)
230- @inline sum (f, a:: StaticArray{<:Tuple,T} ; dims= :) where {T} = _mapreduce (f, + , dims, NamedTuple (), Size (a), a)
231- @inline sum (f:: Union{Function, Type} , a:: StaticArray{<:Tuple,T} ; dims= :) where {T} = _mapreduce (f, + , dims, NamedTuple (), Size (a), a) # avoid ambiguity
236+ @inline sum (f, a:: StaticArray{<:Tuple,T} ; dims= :) where {T} = _mapreduce (f, + , dims, _InitialValue (), Size (a), a)
237+ @inline sum (f:: Union{Function, Type} , a:: StaticArray{<:Tuple,T} ; dims= :) where {T} = _mapreduce (f, + , dims, _InitialValue (), Size (a), a) # avoid ambiguity
232238
233239@inline prod (a:: StaticArray{<:Tuple,T} ; dims= :) where {T} = _reduce (* , a, dims)
234- @inline prod (f, a:: StaticArray{<:Tuple,T} ; dims= :) where {T} = _mapreduce (f, * , dims, NamedTuple (), Size (a), a)
235- @inline prod (f:: Union{Function, Type} , a:: StaticArray{<:Tuple,T} ; dims= :) where {T} = _mapreduce (f, * , dims, NamedTuple (), Size (a), a)
240+ @inline prod (f, a:: StaticArray{<:Tuple,T} ; dims= :) where {T} = _mapreduce (f, * , dims, _InitialValue (), Size (a), a)
241+ @inline prod (f:: Union{Function, Type} , a:: StaticArray{<:Tuple,T} ; dims= :) where {T} = _mapreduce (f, * , dims, _InitialValue (), Size (a), a)
236242
237243@inline count (a:: StaticArray{<:Tuple,Bool} ; dims= :) = _reduce (+ , a, dims)
238- @inline count (f, a:: StaticArray ; dims= :) = _mapreduce (x-> f (x):: Bool , + , dims, NamedTuple (), Size (a), a)
244+ @inline count (f, a:: StaticArray ; dims= :) = _mapreduce (x-> f (x):: Bool , + , dims, _InitialValue (), Size (a), a)
239245
240- @inline all (a:: StaticArray{<:Tuple,Bool} ; dims= :) = _reduce (& , a, dims, (init = true ,) ) # non-branching versions
241- @inline all (f:: Function , a:: StaticArray ; dims= :) = _mapreduce (x-> f (x):: Bool , & , dims, (init = true ,) , Size (a), a)
246+ @inline all (a:: StaticArray{<:Tuple,Bool} ; dims= :) = _reduce (& , a, dims, true ) # non-branching versions
247+ @inline all (f:: Function , a:: StaticArray ; dims= :) = _mapreduce (x-> f (x):: Bool , & , dims, true , Size (a), a)
242248
243- @inline any (a:: StaticArray{<:Tuple,Bool} ; dims= :) = _reduce (| , a, dims, (init = false ,) ) # (benchmarking needed)
244- @inline any (f:: Function , a:: StaticArray ; dims= :) = _mapreduce (x-> f (x):: Bool , | , dims, (init = false ,) , Size (a), a) # (benchmarking needed)
249+ @inline any (a:: StaticArray{<:Tuple,Bool} ; dims= :) = _reduce (| , a, dims, false ) # (benchmarking needed)
250+ @inline any (f:: Function , a:: StaticArray ; dims= :) = _mapreduce (x-> f (x):: Bool , | , dims, false , Size (a), a) # (benchmarking needed)
245251
246- @inline Base. in (x, a:: StaticArray ) = _mapreduce (== (x), | , :, (init = false ,) , Size (a), a)
252+ @inline Base. in (x, a:: StaticArray ) = _mapreduce (== (x), | , :, false , Size (a), a)
247253
248254_mean_denom (a, dims:: Colon ) = length (a)
249255_mean_denom (a, dims:: Int ) = size (a, dims)
250256_mean_denom (a, :: Val{D} ) where {D} = size (a, D)
251257_mean_denom (a, :: Type{Val{D}} ) where {D} = size (a, D)
252258
253259@inline mean (a:: StaticArray ; dims= :) = _reduce (+ , a, dims) / _mean_denom (a, dims)
254- @inline mean (f:: Function , a:: StaticArray ; dims= :) = _mapreduce (f, + , dims, NamedTuple (), Size (a), a) / _mean_denom (a, dims)
260+ @inline mean (f:: Function , a:: StaticArray ; dims= :) = _mapreduce (f, + , dims, _InitialValue (), Size (a), a) / _mean_denom (a, dims)
255261
256262@inline minimum (a:: StaticArray ; dims= :) = _reduce (min, a, dims) # base has mapreduce(idenity, scalarmin, a)
257- @inline minimum (f:: Function , a:: StaticArray ; dims= :) = _mapreduce (f, min, dims, NamedTuple (), Size (a), a)
263+ @inline minimum (f:: Function , a:: StaticArray ; dims= :) = _mapreduce (f, min, dims, _InitialValue (), Size (a), a)
258264
259265@inline maximum (a:: StaticArray ; dims= :) = _reduce (max, a, dims) # base has mapreduce(idenity, scalarmax, a)
260- @inline maximum (f:: Function , a:: StaticArray ; dims= :) = _mapreduce (f, max, dims, NamedTuple (), Size (a), a)
266+ @inline maximum (f:: Function , a:: StaticArray ; dims= :) = _mapreduce (f, max, dims, _InitialValue (), Size (a), a)
261267
262268# Diff is slightly different
263269@inline diff (a:: StaticArray ; dims) = _diff (Size (a), a, dims)
286292 end
287293end
288294
289- struct _InitialValue end
290-
291295_maybe_val (dims:: Integer ) = Val (Int (dims))
292296_maybe_val (dims) = dims
293297_valof (:: Val{D} ) where D = D
@@ -299,19 +303,18 @@ _valof(::Val{D}) where D = D
299303 _accumulate (op, a, _maybe_val (dims), init)
300304
301305@inline function _accumulate (op:: F , a:: StaticArray , dims:: Union{Val,Colon} , init) where {F}
302- # Adjoin the initial value to `op`:
306+ # Adjoin the initial value to `op` (one-line version of `Base.BottomRF`) :
303307 rf (x, y) = x isa _InitialValue ? Base. reduce_first (op, y) : op (x, y)
304308
305309 if isempty (a)
306310 T = return_type (rf, Tuple{typeof (init), eltype (a)})
307311 return similar_type (a, T)()
308312 end
309313
310- # StaticArrays' `reduce` is `foldl`:
311- results = _reduce (
314+ results = _foldl (
312315 a,
313316 dims,
314- (init = ( similar_type (a, Union{}, Size (0 ))(), init), ),
317+ (similar_type (a, Union{}, Size (0 ))(), init),
315318 ) do (ys, acc), x
316319 y = rf (acc, x)
317320 # Not using `push(ys, y)` here since we need to widen element type as
0 commit comments