@@ -1118,16 +1118,16 @@ end
11181118function _mapreduce_kernel_commutative (f, op, A, init, inds, leading= (), trailing= ())
11191119 i1, iN = firstindex (inds), lastindex (inds)
11201120 n = length (inds)
1121- @nexprs 4 N-> a_N = @inbounds A[leading... , inds[i1+ (N- 1 )], trailing... ]
1122- @nexprs 4 N-> v_N = _mapreduce_start (f, op, A, init, a_N)
1123- for batch in 1 : (n>> 2 )- 1
1124- i = i1 + batch* 4
1125- @nexprs 4 N-> a_N = @inbounds A[leading... , inds[i+ (N- 1 )], trailing... ]
1126- @nexprs 4 N-> fa_N = f (a_N)
1127- @nexprs 4 N-> v_N = op (v_N, fa_N)
1128- end
1129- v = op (op (v_1, v_2), op (v_3, v_4))
1130- i = i1 + (n>> 2 ) * 4 - 1
1121+ @nexprs 8 N-> a_N = @inbounds A[leading... , inds[i1+ (N- 1 )], trailing... ]
1122+ @nexprs 8 N-> v_N = _mapreduce_start (f, op, A, init, a_N)
1123+ for batch in 1 : (n>> 3 )- 1
1124+ i = i1 + batch* 8
1125+ @nexprs 8 N-> a_N = @inbounds A[leading... , inds[i+ (N- 1 )], trailing... ]
1126+ @nexprs 8 N-> fa_N = f (a_N)
1127+ @nexprs 8 N-> v_N = op (v_N, fa_N)
1128+ end
1129+ v = op (op (op ( v_1, v_2), op (v_3, v_4)), op ( op (v_5, v_6), op (v_7, v_8) ))
1130+ i = i1 + (n>> 3 ) * 8 - 1
11311131 i == iN && return v
11321132 for i in i+ 1 : iN
11331133 ai = @inbounds A[leading... , inds[i], trailing... ]
@@ -1180,23 +1180,23 @@ function mapreduce_kernel_commutative(f, op, itr, init, ::Union{HasLength, HasSh
11801180 end
11811181 return v_1, s
11821182 end
1183- @nexprs 3 n-> begin
1183+ @nexprs 7 n-> begin
11841184 it = iterate (itr, s)
11851185 it === nothing && _throw_iterator_assertion_error ()
11861186 a, s = it
11871187 v_{n+ 1 } = _mapreduce_start (f, op, itr, init, a)
11881188 end
1189- i = 4
1190- for outer i in 8 : 4 : n
1191- @nexprs 4 n-> begin
1189+ i = 8
1190+ for outer i in 16 : 8 : n
1191+ @nexprs 8 n-> begin
11921192 it = iterate (itr, s)
11931193 it === nothing && _throw_iterator_assertion_error ()
11941194 a_n, s = it
11951195 end
1196- @nexprs 4 n-> fa_n = f (a_n)
1197- @nexprs 4 n-> v_n = op (v_n, fa_n)
1196+ @nexprs 8 n-> fa_n = f (a_n)
1197+ @nexprs 8 n-> v_n = op (v_n, fa_n)
11981198 end
1199- v = op (op (v_1, v_2), op (v_3, v_4))
1199+ v = op (op (op ( v_1, v_2), op (v_3, v_4)), op ( op (v_5, v_6), op (v_7, v_8) ))
12001200 for _ in i+ 1 : n
12011201 it = iterate (itr, s)
12021202 it === nothing && _throw_iterator_assertion_error ()
@@ -1222,23 +1222,43 @@ function mapreduce_kernel_commutative(f, op, itr, init, ::IteratorSize, n, state
12221222 it === nothing && return Some (op (op (v_1, v_2), v_3))
12231223 a, s = it
12241224 v_4 = _mapreduce_start (f, op, itr, init, a)
1225- for _ in 2 : n>> 2
1226- @nexprs 4 N-> begin
1225+ it = iterate (itr, s)
1226+ it === nothing && return Some (op (op (v_1, v_2), op (v_3, v_4)))
1227+ a, s = it
1228+ v_5 = _mapreduce_start (f, op, itr, init, a)
1229+ it = iterate (itr, s)
1230+ it === nothing && return Some (op (op (op (v_1, v_2), op (v_3, v_4)), v_5))
1231+ a, s = it
1232+ v_6 = _mapreduce_start (f, op, itr, init, a)
1233+ it = iterate (itr, s)
1234+ it === nothing && return Some (op (op (op (v_1, v_2), op (v_3, v_4)), op (v_5, v_6)))
1235+ a, s = it
1236+ v_7 = _mapreduce_start (f, op, itr, init, a)
1237+ it = iterate (itr, s)
1238+ it === nothing && return Some (op (op (op (v_1, v_2), op (v_3, v_4)), op (op (v_5, v_6), v_7)))
1239+ a, s = it
1240+ v_8 = _mapreduce_start (f, op, itr, init, a)
1241+ for _ in 3 : n>> 3
1242+ @nexprs 8 N-> begin
12271243 it = iterate (itr, s)
12281244 if it === nothing
1245+ N > 7 && (v_7 = op (v_7, f (a_7)))
1246+ N > 6 && (v_6 = op (v_6, f (a_6)))
1247+ N > 5 && (v_5 = op (v_5, f (a_5)))
1248+ N > 4 && (v_4 = op (v_4, f (a_4)))
12291249 N > 3 && (v_3 = op (v_3, f (a_3)))
12301250 N > 2 && (v_2 = op (v_2, f (a_2)))
12311251 N > 1 && (v_1 = op (v_1, f (a_1)))
1232- return Some (op (op (v_1, v_2), op (v_3, v_4)))
1252+ return Some (op (op (op ( v_1, v_2), op (v_3, v_4)), op ( op (v_5, v_6), op (v_7, v_8) )))
12331253 end
12341254 a_N, s = it
12351255 end
1236- @nexprs 4 N-> fa_N = f (a_N)
1237- @nexprs 4 N-> v_N = op (v_N, fa_N)
1256+ @nexprs 8 N-> fa_N = f (a_N)
1257+ @nexprs 8 N-> v_N = op (v_N, fa_N)
12381258 end
1239- v = op (op (v_1, v_2), op (v_3, v_4))
1240- i = (n>> 2 ) * 4
1241- @nexprs 4 N-> begin
1259+ v = op (op (op ( v_1, v_2), op (v_3, v_4)), op ( op (v_5, v_6), op (v_7, v_8) ))
1260+ i = (n>> 3 ) * 8
1261+ @nexprs 8 N-> begin
12421262 it = iterate (itr, s)
12431263 if it === nothing
12441264 return Some (v)
0 commit comments