Skip to content

Commit 0baaaf0

Browse files
committed
argmin and argmax return missing if all values are missing
1 parent 755501b commit 0baaaf0

File tree

5 files changed

+43
-48
lines changed

5 files changed

+43
-48
lines changed

src/byrow/byrow.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,10 +61,10 @@ byrow(ds::AbstractDataset, ::typeof(maximum), col::ColumnIndex; by = identity, t
6161
byrow(ds::AbstractDataset, ::typeof(minimum), cols::MultiColumnIndex = names(ds, Union{Missing, Number}); by = identity, threads = nrow(ds)>1000) = threads ? hp_row_minimum(ds, by, cols) : row_minimum(ds, by, cols)
6262
byrow(ds::AbstractDataset, ::typeof(minimum), col::ColumnIndex; by = identity, threads = nrow(ds)>1000) = byrow(ds, minimum, [col]; by = by, threads = threads)
6363

64-
byrow(ds::AbstractDataset, ::typeof(argmin), cols::MultiColumnIndex = names(ds, Union{Missing, Number}); by = identity, threads = nrow(ds)>1000) = threads ? hp_row_argmin(ds, by, cols) : row_argmin(ds, by, cols)
64+
byrow(ds::AbstractDataset, ::typeof(argmin), cols::MultiColumnIndex = names(ds, Union{Missing, Number}); by = identity, threads = nrow(ds)>1000) = row_argmin(ds, by, cols, threads = threads)
6565
byrow(ds::AbstractDataset, ::typeof(argmin), col::ColumnIndex; by = identity, threads = nrow(ds)>1000) = byrow(ds, argmin, [col]; by = by, threads = threads)
6666

67-
byrow(ds::AbstractDataset, ::typeof(argmax), cols::MultiColumnIndex = names(ds, Union{Missing, Number}); by = identity, threads = nrow(ds)>1000) = threads ? hp_row_argmax(ds, by, cols) : row_argmax(ds, by, cols)
67+
byrow(ds::AbstractDataset, ::typeof(argmax), cols::MultiColumnIndex = names(ds, Union{Missing, Number}); by = identity, threads = nrow(ds)>1000) = row_argmax(ds, by, cols, threads = threads)
6868
byrow(ds::AbstractDataset, ::typeof(argmax), col::ColumnIndex; by = identity, threads = nrow(ds)>1000) = byrow(ds, argmax, [col]; by = by, threads = threads)
6969

7070
byrow(ds::AbstractDataset, ::typeof(var), cols::MultiColumnIndex = names(ds, Union{Missing, Number}); by = identity, dof = true, threads = nrow(ds)>1000) = threads ? hp_row_var(ds, by, cols; dof = dof) : row_var(ds, by, cols; dof = dof)

src/byrow/hp_row_functions.jl

Lines changed: 0 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -154,40 +154,6 @@ end
154154
hp_row_maximum(ds::AbstractDataset, cols = names(ds, Union{Missing, Number})) = hp_row_maximum(ds, identity, cols)
155155

156156

157-
function hp_op_for_argminmax!(x, y, f, vals, idx)
158-
idx[] += 1
159-
Threads.@threads for i in 1:length(x)
160-
if isequal(vals[i], f(y[i])) && ismissing(x[i])
161-
x[i] = idx[]
162-
end
163-
end
164-
x
165-
end
166-
167-
function hp_row_argmin(ds::AbstractDataset, f::Function, cols = names(ds, Union{Missing, Number}))
168-
colsidx = index(ds)[cols]
169-
minvals = hp_row_minimum(ds, f, cols)
170-
colnames_pa = PooledArray(names(ds, colsidx))
171-
idx = Ref{Int}(0)
172-
res = mapreduce(identity, (x,y)->hp_op_for_argminmax!(x,y,f, minvals, idx), view(_columns(ds),colsidx), init = missings(eltype(colnames_pa.refs), nrow(ds)))
173-
colnames_pa.refs = res
174-
colnames_pa
175-
end
176-
hp_row_argmin(ds::AbstractDataset, cols = names(ds, Union{Missing, Number})) = hp_row_argmin(ds, identity, cols)
177-
178-
function hp_row_argmax(ds::AbstractDataset, f::Function, cols = names(ds, Union{Missing, Number}))
179-
colsidx = index(ds)[cols]
180-
maxvals = hp_row_maximum(ds, f, cols)
181-
colnames_pa = PooledArray(names(ds, colsidx))
182-
idx = Ref{Int}(0)
183-
res = mapreduce(identity, (x,y)->hp_op_for_argminmax!(x,y,f, maxvals, idx), view(_columns(ds),colsidx), init = missings(eltype(colnames_pa.refs), nrow(ds)))
184-
colnames_pa.refs = res
185-
colnames_pa
186-
end
187-
hp_row_argmax(ds::AbstractDataset, cols = names(ds, Union{Missing, Number})) = hp_row_argmax(ds, identity, cols)
188-
189-
190-
191157
function hp_row_var(ds::AbstractDataset, f::Function, cols = names(ds, Union{Missing, Number}); dof = true)
192158
colsidx = index(ds)[cols]
193159
CT = mapreduce(eltype, promote_type, view(_columns(ds),colsidx))

src/byrow/row_functions.jl

Lines changed: 35 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -205,37 +205,62 @@ function row_maximum(ds::AbstractDataset, f::Function, cols = names(ds, Union{Mi
205205
end
206206
row_maximum(ds::AbstractDataset, cols = names(ds, Union{Missing, Number})) = row_maximum(ds, identity, cols)
207207

208-
function _op_for_argminmax!(x, y, f, vals, idx)
208+
function _op_for_argminmax!(x, y, f, vals, idx, missref)
209209
idx[] += 1
210210
for i in 1:length(x)
211-
if isequal(vals[i], f(y[i])) && ismissing(x[i])
211+
if !ismissing(vals[i]) && isequal(vals[i], f(y[i])) && isequal(x[i], missref)
212212
x[i] = idx[]
213213
end
214214
end
215215
x
216216
end
217+
function hp_op_for_argminmax!(x, y, f, vals, idx, missref)
218+
idx[] += 1
219+
Threads.@threads for i in 1:length(x)
220+
if !ismissing(vals[i]) && isequal(vals[i], f(y[i])) && isequal(x[i], missref)
221+
x[i] = idx[]
222+
end
223+
end
224+
x
225+
end
226+
217227

218-
function row_argmin(ds::AbstractDataset, f::Function, cols = names(ds, Union{Missing, Number}))
228+
function row_argmin(ds::AbstractDataset, f::Function, cols = names(ds, Union{Missing, Number}); threads = true)
219229
colsidx = index(ds)[cols]
220230
minvals = row_minimum(ds, f, cols)
221-
colnames_pa = PooledArray(names(ds, colsidx))
231+
colnames_pa = allowmissing(PooledArray(names(ds, colsidx)))
232+
push!(colnames_pa, missing)
233+
missref = get(colnames_pa.invpool, missing, missing)
234+
init0 = fill(missref, nrow(ds))
222235
idx = Ref{Int}(0)
223-
res = mapreduce(identity, (x,y)->_op_for_argminmax!(x,y, f, minvals, idx), view(_columns(ds),colsidx), init = missings(eltype(colnames_pa.refs), nrow(ds)))
236+
if threads
237+
res = mapreduce(identity, (x,y)->hp_op_for_argminmax!(x,y, f, minvals, idx, missref), view(_columns(ds),colsidx), init = init0)
238+
239+
else
240+
res = mapreduce(identity, (x,y)->_op_for_argminmax!(x,y, f, minvals, idx, missref), view(_columns(ds),colsidx), init = init0)
241+
end
224242
colnames_pa.refs = res
225243
colnames_pa
226244
end
227-
row_argmin(ds::AbstractDataset, cols = names(ds, Union{Missing, Number})) = row_argmin(ds, identity, cols)
245+
row_argmin(ds::AbstractDataset, cols = names(ds, Union{Missing, Number}); threads = true) = row_argmin(ds, identity, cols, threads = threads)
228246

229-
function row_argmax(ds::AbstractDataset, f::Function, cols = names(ds, Union{Missing, Number}))
247+
function row_argmax(ds::AbstractDataset, f::Function, cols = names(ds, Union{Missing, Number}); threads = true)
230248
colsidx = index(ds)[cols]
231249
maxvals = row_maximum(ds, f, cols)
232-
colnames_pa = PooledArray(names(ds, colsidx))
250+
colnames_pa = allowmissing(PooledArray(names(ds, colsidx)))
251+
push!(colnames_pa, missing)
252+
missref = get(colnames_pa.invpool, missing, missing)
253+
init0 = fill(missref, nrow(ds))
233254
idx = Ref{Int}(0)
234-
res = mapreduce(identity, (x,y)->_op_for_argminmax!(x,y,f, maxvals, idx), view(_columns(ds),colsidx), init = missings(eltype(colnames_pa.refs), nrow(ds)))
255+
if threads
256+
res = mapreduce(identity, (x,y)->hp_op_for_argminmax!(x,y,f, maxvals, idx, missref), view(_columns(ds),colsidx), init = init0)
257+
else
258+
res = mapreduce(identity, (x,y)->_op_for_argminmax!(x,y,f, maxvals, idx, missref), view(_columns(ds),colsidx), init = init0)
259+
end
235260
colnames_pa.refs = res
236261
colnames_pa
237262
end
238-
row_argmax(ds::AbstractDataset, cols = names(ds, Union{Missing, Number})) = row_argmax(ds, identity, cols)
263+
row_argmax(ds::AbstractDataset, cols = names(ds, Union{Missing, Number}); threads = true) = row_argmax(ds, identity, cols, threads = threads)
239264

240265

241266
# TODO better function for the first component of operator

src/stat/non_hp_stat.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ end
6767
function stat_findmax(f, x::AbstractArray{T,1}) where T
6868
isempty(x) && throw(ArgumentError("input vector cannot be empty"))
6969
maxval = stat_maximum(f, x)
70+
ismissing(maxval) && return (missing, missing)
7071
(maxval, _arg_minmax_barrier(x, maxval, f))
7172
end
7273
stat_findmax(x::AbstractArray{T,1}) where T = stat_findmax(identity, x)

test/byrow.jl

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,8 @@
2020
@test isequal(byrow(ds, minimum, r"float", threads = true) , [1.2, missing, -1.4,2.3,-100.0])
2121

2222
@test byrow(ds, argmax, :) == ["x2_int", "x2_int", "x2_float", "x2_int", "x1_float"]
23-
@test byrow(ds, argmin, r"float") == ["x1_float", "x1_float", "x3_float", "x1_float", "x3_float"]
24-
@test byrow(ds, argmin, r"float", by = abs) == ["x1_float", "x1_float", "x1_float", "x1_float", "x1_float"]
23+
@test isequal(byrow(ds, argmin, r"float"), ["x1_float", missing, "x3_float", "x1_float", "x3_float"])
24+
@test isequal(byrow(ds, argmin, r"float", by = abs) , ["x1_float", missing, "x1_float", "x1_float", "x1_float"])
2525
@test byrow(ds, coalesce, ["x2_float", "x1_float", "x1_int"]) == [1.2,0,3.0,2.3,10]
2626
@test isequal(byrow(ds, var, r"float"), [missing, missing, 5.92, 0.24499999999999922, 6050.0])
2727
@test isequal(byrow(ds, var, r"float", dof = false), [0.0, missing, 3.9466666666666663, 0.12249999999999961, 3025.0])
@@ -31,6 +31,9 @@
3131
@test byrow(sds, sum, [:g, :x2_float]) == [2,2.0,2,4,4,4,1,1,1,2,2,2,2,2]
3232
@test byrow(sds, argmax, [3,2,4,1]) == ["x1_float","x1_float","x1_float","x2_float","x2_float","x2_float","g", "g", "g", "x1_float","x1_float","x1_float","x1_float","x1_float"]
3333
@test byrow(sds, argmax, [1,2,4,3], by = ismissing) == ["x2_float","x2_float","x2_float", "x1_float","x1_float","x1_float", "x1_float","x1_float","x1_float", "x2_float", "x2_float", "x2_float", "x2_float", "x2_float"]
34+
@test isequal(byrow(sds, argmax, [3,4]), ["x1_int","x1_int","x1_int","x2_float","x2_float","x2_float","x1_int","x1_int","x1_int",missing, missing, missing, missing, missing])
35+
@test isequal(byrow(sds, argmin, [3,4], threads = true), ["x1_int","x1_int","x1_int","x1_int","x1_int","x1_int","x1_int","x1_int","x1_int",missing, missing, missing, missing, missing])
36+
3437
@test byrow(sds, any, :, by = ismissing) == [true, true, true, false, false, false, true, true, true, true, true, true, true, true]
3538

3639
ds = Dataset(x1 = [1,2,3,4,missing], x2 = [3,2,4,5, missing])

0 commit comments

Comments
 (0)