Skip to content

Commit 7d83381

Browse files
committed
Check the legality of zero-index based initialization.
And add an optimized fallback for `keys::AbstractArray` (almost 2x faster with cheap `f`.)
1 parent a549929 commit 7d83381

File tree

3 files changed

+116
-40
lines changed

3 files changed

+116
-40
lines changed

base/reducedim.jl

Lines changed: 69 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1025,21 +1025,25 @@ for (fname, op) in [(:sum, :add_sum), (:prod, :mul_prod),
10251025
end
10261026

10271027
##### findmin & findmax #####
1028+
# `iterate` based fallback for compatibility.
10281029
# The initial values of Rval are not used if the corresponding indices in Rind are 0.
1029-
#
1030-
function findminmax!(f, op, Rval, Rind, A::AbstractArray{T,N}) where {T,N}
1030+
function _findminmax!(f, op, Rval, Rind, A::AbstractArray{T,N}, keys, init = true) where {T,N}
10311031
(isempty(Rval) || isempty(A)) && return Rval, Rind
10321032
lsiz = check_reducedims(Rval, A)
10331033
for i = 1:N
10341034
axes(Rval, i) == axes(Rind, i) || throw(DimensionMismatch("Find-reduction: outputs must have the same indices"))
10351035
end
1036+
zi = zero(eltype(keys))
1037+
zi in keys && throw(ArgumentError(LazyString("`keys` containing ", zi, " is not supported!")))
1038+
if init
1039+
fill!(Rind, zi)
1040+
fill!(Rval, f(first(A)))
1041+
end
10361042
# If we're reducing along dimension 1, for efficiency we can make use of a temporary.
10371043
# Otherwise, keep the result in Rval/Rind so that we traverse A in storage order.
10381044
indsAt, indsRt = safe_tail(axes(A)), safe_tail(axes(Rval))
10391045
keep, Idefault = Broadcast.shapeindexer(indsRt)
1040-
ks = keys(A)
1041-
y = iterate(ks)
1042-
zi = zero(eltype(ks))
1046+
y = iterate(keys)
10431047
if reducedim1(Rval, A)
10441048
i1 = first(axes1(Rval))
10451049
@inbounds for IA in CartesianIndices(indsAt)
@@ -1053,7 +1057,7 @@ function findminmax!(f, op, Rval, Rind, A::AbstractArray{T,N}) where {T,N}
10531057
tmpRv = tmpAv
10541058
tmpRi = k
10551059
end
1056-
y = iterate(ks, kss)
1060+
y = iterate(keys, kss)
10571061
end
10581062
Rval[i1,IR] = tmpRv
10591063
Rind[i1,IR] = tmpRi
@@ -1070,7 +1074,57 @@ function findminmax!(f, op, Rval, Rind, A::AbstractArray{T,N}) where {T,N}
10701074
Rval[i,IR] = tmpAv
10711075
Rind[i,IR] = k
10721076
end
1073-
y = iterate(ks, kss)
1077+
y = iterate(keys, kss)
1078+
end
1079+
end
1080+
end
1081+
Rval, Rind
1082+
end
1083+
1084+
# Optimized fallback for `keys` with the same axes, e.g. `LinearIndices`, `CartesianIndices`.
1085+
# We initialize `Rval`/`Rind` via `map/copyfirst!` to support non-1 based `A`.
1086+
function _findminmax!(f, op, Rval, Rind, A::AbstractArray{T,N}, keys::AbstractArray{<:Any,N}, init = true) where {T,N}
1087+
axes(keys) == axes(A) || return @invoke _findminmax!(f, op, Rval, Rind, A::AbstractArray, keys::Any, init)
1088+
(isempty(Rval) || isempty(A)) && return Rval, Rind
1089+
lsiz = check_reducedims(Rval, A)
1090+
for i = 1:N
1091+
axes(Rval, i) == axes(Rind, i) || throw(DimensionMismatch("Find-reduction: outputs must have the same indices"))
1092+
end
1093+
if init
1094+
copyfirst!(Rind, keys)
1095+
mapfirst!(f, Rval, A)
1096+
end
1097+
# If we're reducing along dimension 1, for efficiency we can make use of a temporary.
1098+
# Otherwise, keep the result in Rval/Rind so that we traverse A in storage order.
1099+
indsAt, indsRt = safe_tail(axes(A)), safe_tail(axes(Rval))
1100+
keep, Idefault = Broadcast.shapeindexer(indsRt)
1101+
if reducedim1(Rval, A)
1102+
i1 = first(axes1(Rval))
1103+
@inbounds for IA in CartesianIndices(indsAt)
1104+
IR = Broadcast.newindex(IA, keep, Idefault)
1105+
tmpRv = Rval[i1,IR]
1106+
tmpRi = Rind[i1,IR]
1107+
for i in axes(A,1)
1108+
tmpAv = f(A[i,IA])
1109+
if op(tmpRv, tmpAv)
1110+
tmpRv = tmpAv
1111+
tmpRi = keys[i,IA]
1112+
end
1113+
end
1114+
Rval[i1,IR] = tmpRv
1115+
Rind[i1,IR] = tmpRi
1116+
end
1117+
else
1118+
@inbounds for IA in CartesianIndices(indsAt)
1119+
IR = Broadcast.newindex(IA, keep, Idefault)
1120+
for i in axes(A, 1)
1121+
tmpAv = f(A[i,IA])
1122+
tmpRv = Rval[i,IR]
1123+
tmpRi = Rind[i,IR]
1124+
if op(tmpRv, tmpAv)
1125+
Rval[i,IR] = tmpAv
1126+
Rind[i,IR] = keys[i,IA]
1127+
end
10741128
end
10751129
end
10761130
end
@@ -1086,7 +1140,7 @@ dimensions of `rval` and `rind`, and store the results in `rval` and `rind`.
10861140
"""
10871141
function findmin!(rval::AbstractArray, rind::AbstractArray, A::AbstractArray;
10881142
init::Bool=true)
1089-
findminmax!(identity, isgreater, init && !isempty(A) ? fill!(rval, first(A)) : rval, fill!(rind,zero(eltype(keys(A)))), A)
1143+
_findminmax!(identity, isgreater, rval, rind, A, keys(A), init)
10901144
end
10911145

10921146
"""
@@ -1142,9 +1196,9 @@ function _findmin(f, A, region)
11421196
end
11431197
similar(A, promote_op(f, eltype(A)), ri), zeros(eltype(keys(A)), ri)
11441198
else
1145-
fA = f(first(A))
1146-
findminmax!(f, isgreater, fill!(similar(A, _findminmax_inittype(f, A), ri), fA),
1147-
zeros(eltype(keys(A)), ri), A)
1199+
rval = similar(A, _findminmax_inittype(f, A), ri)
1200+
rind = similar(A, eltype(keys(A)), ri)
1201+
_findminmax!(f, isgreater, rval, rind, A, keys(A))
11481202
end
11491203
end
11501204

@@ -1157,7 +1211,7 @@ dimensions of `rval` and `rind`, and store the results in `rval` and `rind`.
11571211
"""
11581212
function findmax!(rval::AbstractArray, rind::AbstractArray, A::AbstractArray;
11591213
init::Bool=true)
1160-
findminmax!(identity, isless, init && !isempty(A) ? fill!(rval, first(A)) : rval, fill!(rind,zero(eltype(keys(A)))), A)
1214+
_findminmax!(identity, isless, rval, rind, A, keys(A), init)
11611215
end
11621216

11631217
"""
@@ -1213,9 +1267,9 @@ function _findmax(f, A, region)
12131267
end
12141268
similar(A, promote_op(f, eltype(A)), ri), zeros(eltype(keys(A)), ri)
12151269
else
1216-
fA = f(first(A))
1217-
findminmax!(f, isless, fill!(similar(A, _findminmax_inittype(f, A), ri), fA),
1218-
zeros(eltype(keys(A)), ri), A)
1270+
rval = similar(A, _findminmax_inittype(f, A), ri)
1271+
rind = similar(A, eltype(keys(A)), ri)
1272+
_findminmax!(f, isless, rval, rind, A, keys(A))
12191273
end
12201274
end
12211275

test/offsetarray.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -822,3 +822,11 @@ end
822822
# this is fixed in #40038, so the evaluation of its CartesianIndices should work
823823
@test CartesianIndices(A) == CartesianIndices(B)
824824
end
825+
826+
# issue #38660
827+
@testset "`findmin/max` for OffsetArray" begin
828+
ov = OffsetVector([-1, 1], 0:1)
829+
@test @inferred(findmin(ov; dims = 1)) .|> first == (-1, 0)
830+
ov = OffsetVector([-1, 1], -1:0)
831+
@test @inferred(findmax(ov; dims = 1)) .|> first == (1, 0)
832+
end

test/reducedim.jl

Lines changed: 39 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -273,33 +273,47 @@ end
273273
end
274274
end
275275

276+
struct WithIteratorsKeys{T,N} <: AbstractArray{T,N}
277+
data::Array{T,N}
278+
end
279+
Base.size(a::WithIteratorsKeys) = size(a.data)
280+
Base.getindex(a::WithIteratorsKeys, inds...) = a.data[inds...]
281+
struct IteratorsKeys{T,A<:AbstractArray{T}}
282+
iter::A
283+
IteratorsKeys(iter) = new{eltype(iter),typeof(iter)}(iter)
284+
end
285+
Base.iterate(a::IteratorsKeys, state...) = Base.iterate(a.data, state...)
286+
Base.keys(a::WithIteratorsKeys) = IteratorsKeys(keys(a.data))
287+
276288
# findmin/findmax function arguments: output type inference
277289
@testset "findmin/findmax output type inference" begin
278-
A = ["1" "22"; "333" "4444"]
279-
for (tup, rval, rind) in [((1,), [1 2], [CartesianIndex(1, 1) CartesianIndex(1, 2)]),
280-
((2,), reshape([1, 3], 2, 1), reshape([CartesianIndex(1, 1), CartesianIndex(2, 1)], 2, 1)),
281-
((1,2), fill(1,1,1), fill(CartesianIndex(1,1),1,1))]
282-
rval′, rind′ = findmin(length, A, dims=tup)
283-
@test (rval, rind) == (rval′, rind′)
284-
@test typeof(rval′) == Matrix{Int}
285-
end
286-
for (tup, rval, rind) in [((1,), [3 4], [CartesianIndex(2, 1) CartesianIndex(2, 2)]),
287-
((2,), reshape([2, 4], 2, 1), reshape([CartesianIndex(1, 2), CartesianIndex(2, 2)], 2, 1)),
288-
((1,2), fill(4,1,1), fill(CartesianIndex(2,2),1,1))]
289-
rval′, rind′ = findmax(length, A, dims=tup)
290-
@test (rval, rind) == (rval′, rind′)
291-
@test typeof(rval) == Matrix{Int}
292-
end
293-
B = [1.5 1.0; 5.5 6.0]
294-
for (tup, rval, rind) in [((1,), [3//2 1//1], [CartesianIndex(1, 1) CartesianIndex(1, 2)]),
295-
((2,), reshape([1//1, 11//2], 2, 1), reshape([CartesianIndex(1, 2), CartesianIndex(2, 1)], 2, 1)),
296-
((1,2), fill(1//1,1,1), fill(CartesianIndex(1,2),1,1))]
297-
rval′, rind′ = findmin(Rational, B, dims=tup)
298-
@test (rval, rind) == (rval′, rind′)
299-
@test typeof(rval) == Matrix{Rational{Int}}
300-
rval′, rind′ = findmin(Rational abs complex, B, dims=tup)
301-
@test (rval, rind) == (rval′, rind′)
302-
@test typeof(rval) == Matrix{Rational{Int}}
290+
for wrapper in (identity, WithIteratorsKeys)
291+
A = wrapper(["1" "22"; "333" "4444"])
292+
for (tup, rval, rind) in [((1,), [1 2], [CartesianIndex(1, 1) CartesianIndex(1, 2)]),
293+
((2,), reshape([1, 3], 2, 1), reshape([CartesianIndex(1, 1), CartesianIndex(2, 1)], 2, 1)),
294+
((1,2), fill(1,1,1), fill(CartesianIndex(1,1),1,1))]
295+
rval′, rind′ = @inferred findmin(length, A, dims=tup)
296+
@test (rval, rind) == (rval′, rind′)
297+
@test typeof(rval′) == Matrix{Int}
298+
end
299+
for (tup, rval, rind) in [((1,), [3 4], [CartesianIndex(2, 1) CartesianIndex(2, 2)]),
300+
((2,), reshape([2, 4], 2, 1), reshape([CartesianIndex(1, 2), CartesianIndex(2, 2)], 2, 1)),
301+
((1,2), fill(4,1,1), fill(CartesianIndex(2,2),1,1))]
302+
rval′, rind′ = @inferred findmax(length, A, dims=tup)
303+
@test (rval, rind) == (rval′, rind′)
304+
@test typeof(rval) == Matrix{Int}
305+
end
306+
B = wrapper([1.5 1.0; 5.5 6.0])
307+
for (tup, rval, rind) in [((1,), [3//2 1//1], [CartesianIndex(1, 1) CartesianIndex(1, 2)]),
308+
((2,), reshape([1//1, 11//2], 2, 1), reshape([CartesianIndex(1, 2), CartesianIndex(2, 1)], 2, 1)),
309+
((1,2), fill(1//1,1,1), fill(CartesianIndex(1,2),1,1))]
310+
rval′, rind′ = @inferred findmin(Rational, B, dims=tup)
311+
@test (rval, rind) == (rval′, rind′)
312+
@test typeof(rval) == Matrix{Rational{Int}}
313+
rval′, rind′ = @inferred findmin(Rational abs complex, B, dims=tup)
314+
@test (rval, rind) == (rval′, rind′)
315+
@test typeof(rval) == Matrix{Rational{Int}}
316+
end
303317
end
304318
end
305319

0 commit comments

Comments
 (0)