@@ -31,19 +31,6 @@ Base.mapreduce(f, op, A::Broadcast.Broadcasted{<:AbstractGPUArrayStyle}, As::Abs
3131 dims= :, init= nothing ) = _mapreduce (f, op, A, As... ; dims= dims, init= init)
3232
3333function _mapreduce (f:: F , op:: OP , As:: Vararg{Any,N} ; dims:: D , init) where {F,OP,N,D}
34- # mapreduce should apply `f` like `map` does, consuming elements like iterators
35- bc = if allequal (size .(As)... )
36- Broadcast. instantiate (Broadcast. broadcasted (f, As... ))
37- else
38- # TODO : can we avoid the reshape + view?
39- indices = LinearIndices .(As)
40- common_length = minimum (length .(indices))
41- Bs = map (As) do A
42- view (reshape (A, length (A)), 1 : common_length)
43- end
44- Broadcast. instantiate (Broadcast. broadcasted (f, Bs... ))
45- end
46-
4734 # figure out the destination container type by looking at the initializer element,
4835 # or by relying on inference to reason through the map and reduce functions
4936 if init === nothing
@@ -57,16 +44,39 @@ function _mapreduce(f::F, op::OP, As::Vararg{Any,N}; dims::D, init) where {F,OP,
5744 ET = typeof (init)
5845 end
5946
60- sz = size (bc)
47+ # apply the mapping function to the input arrays
48+ if N == 1
49+ # ... with only a single input, we can defer this to the reduce step
50+ A = only (As)
51+ else
52+ # mapreduce should apply `f` like `map` does, consuming elements like iterators
53+ A = if allequal (size .(As)... )
54+ Broadcast. instantiate (Broadcast. broadcasted (f, As... ))
55+ else
56+ # TODO : can we avoid the reshape + view?
57+ indices = LinearIndices .(As)
58+ common_length = minimum (length .(indices))
59+ Bs = map (As) do A
60+ view (reshape (A, length (A)), 1 : common_length)
61+ end
62+ Broadcast. instantiate (Broadcast. broadcasted (f, Bs... ))
63+ end
64+ f = identity
65+ end
66+
67+ # allocate an output container
68+ sz = size (A)
6169 red = ntuple (i-> (dims== Colon () || i in dims) ? 1 : sz[i], length (sz))
62- R = similar (bc , ET, red)
70+ R = similar (A , ET, red)
6371
72+ # perform the reduction
6473 if prod (sz) == 0
6574 fill! (R, init)
6675 else
67- mapreducedim! (identity , op, R, bc; init = init)
76+ mapreducedim! (f , op, R, A; init)
6877 end
6978
79+ # return the result
7080 if dims === Colon ()
7181 @allowscalar R[]
7282 else
0 commit comments