1+ """
2+ _InitialValue
3+
4+ A singleton type for representing "universal" initial value (identity element).
5+
6+ The idea is that, given `op` for `mapfoldl`, virtually, we define an "extended"
7+ version of it by
8+
9+ op′(::_InitialValue, x) = x
10+ op′(acc, x) = op(acc, x)
11+
12+ This is just a conceptually useful model to have in mind and we don't actually
13+ define `op′` here (yet?). But see `Base.BottomRF` for how it might work in
14+ action.
15+
16+ (It is related to that you can always turn a semigroup without an identity into
17+ a monoid by "adjoining" an element that acts as the identity.)
18+ """
19+ struct _InitialValue end
20+
121@inline _first (a1, as... ) = a1
222
323# ###############
86106# # mapreduce ##
87107# ##############
88108
89- @inline function mapreduce (f, op, a:: StaticArray , b:: StaticArray... ; dims= :,kw ... )
90- _mapreduce (f, op, dims, kw . data , same_size (a, b... ), a, b... )
109+ @inline function mapreduce (f, op, a:: StaticArray , b:: StaticArray... ; dims= :, init = _InitialValue () )
110+ _mapreduce (f, op, dims, init , same_size (a, b... ), a, b... )
91111end
92112
93- @generated function _mapreduce (f, op, dims:: Colon , nt:: NamedTuple{()} ,
94- :: Size{S} , a:: StaticArray... ) where {S}
113+ @inline _mapreduce (args:: Vararg{Any,N} ) where N = _mapfoldl (args... )
114+
115+ @generated function _mapfoldl (f, op, dims:: Colon , init, :: Size{S} , a:: StaticArray... ) where {S}
95116 tmp = [:(a[$ j][1 ]) for j ∈ 1 : length (a)]
96117 expr = :(f ($ (tmp... )))
97- for i ∈ 2 : prod (S)
98- tmp = [:(a[$ j][$ i]) for j ∈ 1 : length (a)]
99- expr = :(op ($ expr, f ($ (tmp... ))))
100- end
101- return quote
102- @_inline_meta
103- @inbounds return $ expr
118+ if init === _InitialValue
119+ expr = :(Base. reduce_first (op, $ expr))
120+ else
121+ expr = :(op (init, $ expr))
104122 end
105- end
106-
107- @generated function _mapreduce (f, op, dims:: Colon , nt:: NamedTuple{(:init,)} ,
108- :: Size{S} , a:: StaticArray... ) where {S}
109- expr = :(nt. init)
110- for i ∈ 1 : prod (S)
123+ for i ∈ 2 : prod (S)
111124 tmp = [:(a[$ j][$ i]) for j ∈ 1 : length (a)]
112125 expr = :(op ($ expr, f ($ (tmp... ))))
113126 end
@@ -117,24 +130,24 @@ end
117130 end
118131end
119132
120- @inline function _mapreduce (f, op, D:: Int , nt :: NamedTuple , sz:: Size{S} , a:: StaticArray ) where {S}
133+ @inline function _mapreduce (f, op, D:: Int , init , sz:: Size{S} , a:: StaticArray ) where {S}
121134 # Body of this function is split because constant propagation (at least
122135 # as of Julia 1.2) can't always correctly propagate here and
123136 # as a result the function is not type stable and very slow.
124137 # This makes it at least fast for three dimensions but people should use
125138 # for example any(a; dims=Val(1)) instead of any(a; dims=1) anyway.
126139 if D == 1
127- return _mapreduce (f, op, Val (1 ), nt , sz, a)
140+ return _mapreduce (f, op, Val (1 ), init , sz, a)
128141 elseif D == 2
129- return _mapreduce (f, op, Val (2 ), nt , sz, a)
142+ return _mapreduce (f, op, Val (2 ), init , sz, a)
130143 elseif D == 3
131- return _mapreduce (f, op, Val (3 ), nt , sz, a)
144+ return _mapreduce (f, op, Val (3 ), init , sz, a)
132145 else
133- return _mapreduce (f, op, Val (D), nt , sz, a)
146+ return _mapreduce (f, op, Val (D), init , sz, a)
134147 end
135148end
136149
137- @generated function _mapreduce (f, op, dims:: Val{D} , nt :: NamedTuple{()} ,
150+ @generated function _mapfoldl (f, op, dims:: Val{D} , init ,
138151 :: Size{S} , a:: StaticArray ) where {S,D}
139152 N = length (S)
140153 Snew = ([n== D ? 1 : S[n] for n = 1 : N]. .. ,)
@@ -143,32 +156,12 @@ end
143156 itr = [1 : n for n ∈ Snew]
144157 for i ∈ Base. product (itr... )
145158 expr = :(f (a[$ (i... )]))
146- for k = 2 : S[D]
147- ik = collect (i )
148- ik[D] = k
149- expr = :(op ($ expr, f (a[ $ (ik ... )]) ))
159+ if init === _InitialValue
160+ expr = :(Base . reduce_first (op, $ expr) )
161+ else
162+ expr = :(op (init, $ expr ))
150163 end
151-
152- exprs[i... ] = expr
153- end
154-
155- return quote
156- @_inline_meta
157- @inbounds elements = tuple ($ (exprs... ))
158- @inbounds return similar_type (a, eltype (elements), Size ($ Snew))(elements)
159- end
160- end
161-
162- @generated function _mapreduce (f, op, dims:: Val{D} , nt:: NamedTuple{(:init,)} ,
163- :: Size{S} , a:: StaticArray ) where {S,D}
164- N = length (S)
165- Snew = ([n== D ? 1 : S[n] for n = 1 : N]. .. ,)
166-
167- exprs = Array {Expr} (undef, Snew)
168- itr = [1 : n for n = Snew]
169- for i ∈ Base. product (itr... )
170- expr = :(nt. init)
171- for k = 1 : S[D]
164+ for k = 2 : S[D]
172165 ik = collect (i)
173166 ik[D] = k
174167 expr = :(op ($ expr, f (a[$ (ik... )])))
@@ -188,20 +181,37 @@ end
188181# # reduce ##
189182# ###########
190183
191- @inline reduce (op, a:: StaticArray ; dims= :, kw... ) = _reduce (op, a, dims, kw. data)
184+ @inline reduce (op, a:: StaticArray ; dims = :, init = _InitialValue ()) =
185+ _reduce (op, a, dims, init)
192186
193187# disambiguation
194188reduce (:: typeof (vcat), A:: StaticArray{<:Tuple,<:AbstractVecOrMat} ) =
195189 Base. _typed_vcat (mapreduce (eltype, promote_type, A), A)
196190reduce (:: typeof (vcat), A:: StaticArray{<:Tuple,<:StaticVecOrMatLike} ) =
197- _reduce (vcat, A, :, NamedTuple ())
191+ _reduce (vcat, A, :, _InitialValue ())
198192
199193reduce (:: typeof (hcat), A:: StaticArray{<:Tuple,<:AbstractVecOrMat} ) =
200194 Base. _typed_hcat (mapreduce (eltype, promote_type, A), A)
201195reduce (:: typeof (hcat), A:: StaticArray{<:Tuple,<:StaticVecOrMatLike} ) =
202- _reduce (hcat, A, :, NamedTuple ())
196+ _reduce (hcat, A, :, _InitialValue ())
197+
198+ @inline _reduce (op, a:: StaticArray , dims, init = _InitialValue ()) =
199+ _mapreduce (identity, op, dims, init, Size (a), a)
203200
204- @inline _reduce (op, a:: StaticArray , dims, kw:: NamedTuple = NamedTuple ()) = _mapreduce (identity, op, dims, kw, Size (a), a)
201+ # ###############
202+ # # (map)foldl ##
203+ # ###############
204+
205+ # Using `where {R}` to force specialization. See:
206+ # https://docs.julialang.org/en/v1.5-dev/manual/performance-tips/#Be-aware-of-when-Julia-avoids-specializing-1
207+ # https://github.com/JuliaLang/julia/pull/33917
208+
209+ @inline mapfoldl (f:: F , op:: R , a:: StaticArray ; init = _InitialValue ()) where {F,R} =
210+ _mapfoldl (f, op, :, init, Size (a), a)
211+ @inline foldl (op:: R , a:: StaticArray ; init = _InitialValue ()) where {R} =
212+ _foldl (op, a, :, init)
213+ @inline _foldl (op:: R , a, dims, init = _InitialValue ()) where {R} =
214+ _mapfoldl (identity, op, dims, init, Size (a), a)
205215
206216# ######################
207217# # related functions ##
@@ -227,37 +237,37 @@ reduce(::typeof(hcat), A::StaticArray{<:Tuple,<:StaticVecOrMatLike}) =
227237@inline iszero (a:: StaticArray{<:Tuple,T} ) where {T} = reduce ((x,y) -> x && iszero (y), a, init= true )
228238
229239@inline sum (a:: StaticArray{<:Tuple,T} ; dims= :) where {T} = _reduce (+ , a, dims)
230- @inline sum (f, a:: StaticArray{<:Tuple,T} ; dims= :) where {T} = _mapreduce (f, + , dims, NamedTuple (), Size (a), a)
231- @inline sum (f:: Union{Function, Type} , a:: StaticArray{<:Tuple,T} ; dims= :) where {T} = _mapreduce (f, + , dims, NamedTuple (), Size (a), a) # avoid ambiguity
240+ @inline sum (f, a:: StaticArray{<:Tuple,T} ; dims= :) where {T} = _mapreduce (f, + , dims, _InitialValue (), Size (a), a)
241+ @inline sum (f:: Union{Function, Type} , a:: StaticArray{<:Tuple,T} ; dims= :) where {T} = _mapreduce (f, + , dims, _InitialValue (), Size (a), a) # avoid ambiguity
232242
233243@inline prod (a:: StaticArray{<:Tuple,T} ; dims= :) where {T} = _reduce (* , a, dims)
234- @inline prod (f, a:: StaticArray{<:Tuple,T} ; dims= :) where {T} = _mapreduce (f, * , dims, NamedTuple (), Size (a), a)
235- @inline prod (f:: Union{Function, Type} , a:: StaticArray{<:Tuple,T} ; dims= :) where {T} = _mapreduce (f, * , dims, NamedTuple (), Size (a), a)
244+ @inline prod (f, a:: StaticArray{<:Tuple,T} ; dims= :) where {T} = _mapreduce (f, * , dims, _InitialValue (), Size (a), a)
245+ @inline prod (f:: Union{Function, Type} , a:: StaticArray{<:Tuple,T} ; dims= :) where {T} = _mapreduce (f, * , dims, _InitialValue (), Size (a), a)
236246
237247@inline count (a:: StaticArray{<:Tuple,Bool} ; dims= :) = _reduce (+ , a, dims)
238- @inline count (f, a:: StaticArray ; dims= :) = _mapreduce (x-> f (x):: Bool , + , dims, NamedTuple (), Size (a), a)
248+ @inline count (f, a:: StaticArray ; dims= :) = _mapreduce (x-> f (x):: Bool , + , dims, _InitialValue (), Size (a), a)
239249
240- @inline all (a:: StaticArray{<:Tuple,Bool} ; dims= :) = _reduce (& , a, dims, (init = true ,) ) # non-branching versions
241- @inline all (f:: Function , a:: StaticArray ; dims= :) = _mapreduce (x-> f (x):: Bool , & , dims, (init = true ,) , Size (a), a)
250+ @inline all (a:: StaticArray{<:Tuple,Bool} ; dims= :) = _reduce (& , a, dims, true ) # non-branching versions
251+ @inline all (f:: Function , a:: StaticArray ; dims= :) = _mapreduce (x-> f (x):: Bool , & , dims, true , Size (a), a)
242252
243- @inline any (a:: StaticArray{<:Tuple,Bool} ; dims= :) = _reduce (| , a, dims, (init = false ,) ) # (benchmarking needed)
244- @inline any (f:: Function , a:: StaticArray ; dims= :) = _mapreduce (x-> f (x):: Bool , | , dims, (init = false ,) , Size (a), a) # (benchmarking needed)
253+ @inline any (a:: StaticArray{<:Tuple,Bool} ; dims= :) = _reduce (| , a, dims, false ) # (benchmarking needed)
254+ @inline any (f:: Function , a:: StaticArray ; dims= :) = _mapreduce (x-> f (x):: Bool , | , dims, false , Size (a), a) # (benchmarking needed)
245255
246- @inline Base. in (x, a:: StaticArray ) = _mapreduce (== (x), | , :, (init = false ,) , Size (a), a)
256+ @inline Base. in (x, a:: StaticArray ) = _mapreduce (== (x), | , :, false , Size (a), a)
247257
248258_mean_denom (a, dims:: Colon ) = length (a)
249259_mean_denom (a, dims:: Int ) = size (a, dims)
250260_mean_denom (a, :: Val{D} ) where {D} = size (a, D)
251261_mean_denom (a, :: Type{Val{D}} ) where {D} = size (a, D)
252262
253263@inline mean (a:: StaticArray ; dims= :) = _reduce (+ , a, dims) / _mean_denom (a, dims)
254- @inline mean (f:: Function , a:: StaticArray ; dims= :) = _mapreduce (f, + , dims, NamedTuple (), Size (a), a) / _mean_denom (a, dims)
264+ @inline mean (f:: Function , a:: StaticArray ; dims= :) = _mapreduce (f, + , dims, _InitialValue (), Size (a), a) / _mean_denom (a, dims)
255265
256266@inline minimum (a:: StaticArray ; dims= :) = _reduce (min, a, dims) # base has mapreduce(idenity, scalarmin, a)
257- @inline minimum (f:: Function , a:: StaticArray ; dims= :) = _mapreduce (f, min, dims, NamedTuple (), Size (a), a)
267+ @inline minimum (f:: Function , a:: StaticArray ; dims= :) = _mapreduce (f, min, dims, _InitialValue (), Size (a), a)
258268
259269@inline maximum (a:: StaticArray ; dims= :) = _reduce (max, a, dims) # base has mapreduce(idenity, scalarmax, a)
260- @inline maximum (f:: Function , a:: StaticArray ; dims= :) = _mapreduce (f, max, dims, NamedTuple (), Size (a), a)
270+ @inline maximum (f:: Function , a:: StaticArray ; dims= :) = _mapreduce (f, max, dims, _InitialValue (), Size (a), a)
261271
262272# Diff is slightly different
263273@inline diff (a:: StaticArray ; dims) = _diff (Size (a), a, dims)
286296 end
287297end
288298
289- struct _InitialValue end
290-
291299_maybe_val (dims:: Integer ) = Val (Int (dims))
292300_maybe_val (dims) = dims
293301_valof (:: Val{D} ) where D = D
@@ -299,19 +307,18 @@ _valof(::Val{D}) where D = D
299307 _accumulate (op, a, _maybe_val (dims), init)
300308
301309@inline function _accumulate (op:: F , a:: StaticArray , dims:: Union{Val,Colon} , init) where {F}
302- # Adjoin the initial value to `op`:
310+ # Adjoin the initial value to `op` (one-line version of `Base.BottomRF`) :
303311 rf (x, y) = x isa _InitialValue ? Base. reduce_first (op, y) : op (x, y)
304312
305313 if isempty (a)
306314 T = return_type (rf, Tuple{typeof (init), eltype (a)})
307315 return similar_type (a, T)()
308316 end
309317
310- # StaticArrays' `reduce` is `foldl`:
311- results = _reduce (
318+ results = _foldl (
312319 a,
313320 dims,
314- (init = ( similar_type (a, Union{}, Size (0 ))(), init), ),
321+ (similar_type (a, Union{}, Size (0 ))(), init),
315322 ) do (ys, acc), x
316323 y = rf (acc, x)
317324 # Not using `push(ys, y)` here since we need to widen element type as
0 commit comments