@@ -103,10 +103,23 @@ end
103103 end
104104end
105105
106- @inline _mapreduce (f, op, D:: Int , nt:: NamedTuple , sz:: Size{S} , a:: StaticArray ) where {S} =
107- _mapreduce (f, op, Val (D), nt, sz, a)
106+ @inline function _mapreduce (f, op, D:: Int , nt:: NamedTuple , sz:: Size{S} , a:: StaticArray ) where {S}
107+ # Body of this function is split because constant propagation (at least
108+ # as of Julia 1.2) can't always correctly propagate here and
109+ # as a result the function is not type stable and very slow.
110+ # This makes it at least fast for three dimensions but people should use
111+ # for example any(a; dims=Val(1)) instead of any(a; dims=1) anyway.
112+ if D == 1
113+ return _mapreduce (f, op, Val (1 ), nt, sz, a)
114+ elseif D == 2
115+ return _mapreduce (f, op, Val (2 ), nt, sz, a)
116+ elseif D == 3
117+ return _mapreduce (f, op, Val (3 ), nt, sz, a)
118+ else
119+ return _mapreduce (f, op, Val (D), nt, sz, a)
120+ end
121+ end
108122
109-
110123@generated function _mapreduce (f, op, dims:: Val{D} , nt:: NamedTuple{()} ,
111124 :: Size{S} , a:: StaticArray ) where {S,D}
112125 N = length (S)
161174# # reduce ##
162175# ###########
163176
164- @inline reduce (op, a:: StaticArray ; kw... ) = mapreduce (identity, op, a; kw... )
177+ @inline reduce (op, a:: StaticArray ; dims= :, kw... ) = _reduce (op, a, dims, kw. data)
178+
179+ @inline _reduce (op, a:: StaticArray , dims, kw:: NamedTuple = NamedTuple ()) = _mapreduce (identity, op, dims, kw, Size (a), a)
165180
166181# ######################
167182# # related functions ##
@@ -186,38 +201,38 @@ end
186201# TODO : change to use Base.reduce_empty/Base.reduce_first
187202@inline iszero (a:: StaticArray{<:Tuple,T} ) where {T} = reduce ((x,y) -> x && iszero (y), a, init= true )
188203
189- @inline sum (a:: StaticArray{<:Tuple,T} ; dims= :) where {T} = reduce (+ , a; dims = dims)
190- @inline sum (f, a:: StaticArray{<:Tuple,T} ; dims= :) where {T} = mapreduce (f, + , a; dims = dims )
191- @inline sum (f:: Union{Function, Type} , a:: StaticArray{<:Tuple,T} ; dims= :) where {T} = mapreduce (f, + , a; dims = dims ) # avoid ambiguity
204+ @inline sum (a:: StaticArray{<:Tuple,T} ; dims= :) where {T} = _reduce (+ , a, dims)
205+ @inline sum (f, a:: StaticArray{<:Tuple,T} ; dims= :) where {T} = _mapreduce (f, + , dims, NamedTuple (), Size (a), a )
206+ @inline sum (f:: Union{Function, Type} , a:: StaticArray{<:Tuple,T} ; dims= :) where {T} = _mapreduce (f, + , dims, NamedTuple (), Size (a), a ) # avoid ambiguity
192207
193- @inline prod (a:: StaticArray{<:Tuple,T} ; dims= :) where {T} = reduce (* , a; dims = dims)
194- @inline prod (f, a:: StaticArray{<:Tuple,T} ; dims= :) where {T} = mapreduce (f, * , a; dims = dims )
195- @inline prod (f:: Union{Function, Type} , a:: StaticArray{<:Tuple,T} ; dims= :) where {T} = mapreduce (f, * , a; dims = dims )
208+ @inline prod (a:: StaticArray{<:Tuple,T} ; dims= :) where {T} = _reduce (* , a, dims)
209+ @inline prod (f, a:: StaticArray{<:Tuple,T} ; dims= :) where {T} = _mapreduce (f, * , dims, NamedTuple (), Size (a), a )
210+ @inline prod (f:: Union{Function, Type} , a:: StaticArray{<:Tuple,T} ; dims= :) where {T} = _mapreduce (f, * , dims, NamedTuple (), Size (a), a )
196211
197- @inline count (a:: StaticArray{<:Tuple,Bool} ; dims= :) = reduce (+ , a; dims = dims)
198- @inline count (f, a:: StaticArray ; dims= :) = mapreduce (x-> f (x):: Bool , + , a; dims = dims )
212+ @inline count (a:: StaticArray{<:Tuple,Bool} ; dims= :) = _reduce (+ , a, dims)
213+ @inline count (f, a:: StaticArray ; dims= :) = _mapreduce (x-> f (x):: Bool , + , dims, NamedTuple (), Size (a), a )
199214
200- @inline all (a:: StaticArray{<:Tuple,Bool} ; dims= :) = reduce (& , a; dims= dims, init= true ) # non-branching versions
201- @inline all (f:: Function , a:: StaticArray ; dims= :) = mapreduce (x-> f (x):: Bool , & , a; dims= dims, init= true )
215+ @inline all (a:: StaticArray{<:Tuple,Bool} ; dims= :) = _reduce (& , a, dims, ( init= true ,) ) # non-branching versions
216+ @inline all (f:: Function , a:: StaticArray ; dims= :) = _mapreduce (x-> f (x):: Bool , & , dims, ( init= true ,), Size (a), a )
202217
203- @inline any (a:: StaticArray{<:Tuple,Bool} ; dims= :) = reduce (| , a; dims= dims, init= false ) # (benchmarking needed)
204- @inline any (f:: Function , a:: StaticArray ; dims= :) = mapreduce (x-> f (x):: Bool , | , a; dims= dims, init= false ) # (benchmarking needed)
218+ @inline any (a:: StaticArray{<:Tuple,Bool} ; dims= :) = _reduce (| , a, dims, ( init= false ,) ) # (benchmarking needed)
219+ @inline any (f:: Function , a:: StaticArray ; dims= :) = _mapreduce (x-> f (x):: Bool , | , dims, ( init= false ,), Size (a), a ) # (benchmarking needed)
205220
206- @inline Base. in (x, a:: StaticArray ) = mapreduce (== (x), | , a, init= false )
221+ @inline Base. in (x, a:: StaticArray ) = _mapreduce (== (x), | , :, ( init= false ,), Size (a), a )
207222
208223_mean_denom (a, dims:: Colon ) = length (a)
209224_mean_denom (a, dims:: Int ) = size (a, dims)
210225_mean_denom (a, :: Val{D} ) where {D} = size (a, D)
211226_mean_denom (a, :: Type{Val{D}} ) where {D} = size (a, D)
212227
213- @inline mean (a:: StaticArray ; dims= :) = sum (a; dims= dims ) / _mean_denom (a,dims)
214- @inline mean (f:: Function , a:: StaticArray ;dims= :) = sum (f, a; dims= dims) / _mean_denom (a,dims)
228+ @inline mean (a:: StaticArray ; dims= :) = _reduce ( + , a, dims) / _mean_denom (a, dims)
229+ @inline mean (f:: Function , a:: StaticArray ; dims= :) = _mapreduce (f, + , dims, NamedTuple (), Size (a), a) / _mean_denom (a, dims)
215230
216- @inline minimum (a:: StaticArray ; dims= :) = reduce (min, a; dims = dims) # base has mapreduce(idenity, scalarmin, a)
217- @inline minimum (f:: Function , a:: StaticArray ; dims= :) = mapreduce (f, min, a; dims = dims )
231+ @inline minimum (a:: StaticArray ; dims= :) = _reduce (min, a, dims) # base has mapreduce(idenity, scalarmin, a)
232+ @inline minimum (f:: Function , a:: StaticArray ; dims= :) = _mapreduce (f, min, dims, NamedTuple (), Size (a), a )
218233
219- @inline maximum (a:: StaticArray ; dims= :) = reduce (max, a; dims = dims) # base has mapreduce(idenity, scalarmax, a)
220- @inline maximum (f:: Function , a:: StaticArray ; dims= :) = mapreduce (f, max, a; dims = dims )
234+ @inline maximum (a:: StaticArray ; dims= :) = _reduce (max, a, dims) # base has mapreduce(idenity, scalarmax, a)
235+ @inline maximum (f:: Function , a:: StaticArray ; dims= :) = _mapreduce (f, max, dims, NamedTuple (), Size (a), a )
221236
222237# Diff is slightly different
223238@inline diff (a:: StaticArray ; dims) = _diff (Size (a), a, dims)
0 commit comments