Skip to content

Commit 41a5d23

Browse files
committed
bug fix - generalise topk
1 parent 7318eda commit 41a5d23

File tree

2 files changed

+63
-76
lines changed

2 files changed

+63
-76
lines changed

src/stat/non_hp_stat.jl

Lines changed: 27 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -361,11 +361,16 @@ end
361361
# finding k largest in an array with missing values
362362
# Exchange the elements stored at positions `i` and `j` of `x`, in place.
# Both elements are read before either is written; the value of the call is
# the tuple of the values now stored at `i` and `j`.
function swap!(x, i, j)
    at_j = x[j]
    at_i = x[i]
    x[i] = at_j
    x[j] = at_i
    return (at_j, at_i)
end
363363

364+
# Comparison helper used by the top-k routines below.
#   * Two non-missing values are compared with `isless`.
#   * If exactly one side is `missing`, the result is `true` regardless of
#     which side it is on. This is deliberately not a consistent ordering:
#     the result buffers are pre-filled with `missing` as a placeholder and
#     new items are tested against the last slot, so a placeholder must lose
#     to any real value for both the "largest" (arguments swapped) and the
#     "smallest" comparison.
#   * Two `missing` values never compare as less than each other.
TOPK_ISLESS(x,y) = isless(x, y)
365+
TOPK_ISLESS(::Missing, y) = true
366+
TOPK_ISLESS(x, ::Missing) = true
367+
TOPK_ISLESS(::Missing, ::Missing) = false
368+
364369
function insert_fixed_sorted!(x, item, ord)
365-
if ord((x[end]), (item))
366-
return
367-
else
370+
if ord(item, x[end])
368371
x[end] = item
372+
else
373+
return
369374
end
370375
i = length(x) - 1
371376
while i > 0
@@ -379,11 +384,11 @@ function insert_fixed_sorted!(x, item, ord)
379384
end
380385
# TODO: `x` is not strictly needed here; it is passed only because that makes the implementation easier. This may be fixed later.
381386
function insert_fixed_sorted_perm!(perm, x, idx, item, ord)
382-
if ord((x[end]), (item))
383-
return
384-
else
387+
if ord(item, x[end])
385388
x[end] = item
386389
perm[end] = idx
390+
else
391+
return
387392
end
388393
i = length(x) - 1
389394
while i > 0
@@ -397,29 +402,17 @@ function insert_fixed_sorted_perm!(perm, x, idx, item, ord)
397402
end
398403
end
399404

405+
400406
function k_largest(x::AbstractVector{T}, k::Int) where {T}
401-
k < 1 && throw(ArgumentError("k must be greater than 1"))
402-
k == 1 && return [maximum(identity, x)]
403-
if k > length(x)
404-
k = length(x)
405-
end
406-
res = Vector{T}(undef, k)
407-
fill!(res, typemin(T))
408-
for i in 1:length(x)
409-
insert_fixed_sorted!(res, x[i], (y1, y2) -> y1 > y2)
410-
end
411-
res
412-
end
413-
function k_largest(x::AbstractVector{Union{T,Missing}}, k::Int) where {T}
414407
k < 1 && throw(ArgumentError("k must be greater than 1"))
415408
k == 1 && return [maximum(identity, x)]
416409
all(ismissing, x) && return [missing]
417-
res = Vector{T}(undef, k)
418-
fill!(res, typemin(T))
410+
res = Vector{Union{Missing, T}}(undef, k)
411+
fill!(res, missing)
419412
cnt = 0
420413
for i in 1:length(x)
421414
if !ismissing(x[i])
422-
insert_fixed_sorted!(res, x[i], (y1, y2) -> y1 > y2)
415+
insert_fixed_sorted!(res, x[i], (y1, y2) -> TOPK_ISLESS(y2, y1))
423416
cnt += 1
424417
end
425418
end
@@ -431,28 +424,15 @@ function k_largest(x::AbstractVector{Union{T,Missing}}, k::Int) where {T}
431424
end
432425

433426
function k_smallest(x::AbstractVector{T}, k::Int) where {T}
434-
k < 1 && throw(ArgumentError("k must be greater than 1"))
435-
k == 1 && return [minimum(identity, x)]
436-
if k > length(x)
437-
k = length(x)
438-
end
439-
res = Vector{T}(undef, k)
440-
fill!(res, typemax(T))
441-
for i in 1:length(x)
442-
insert_fixed_sorted!(res, x[i], (y1, y2) -> y1 < y2)
443-
end
444-
res
445-
end
446-
function k_smallest(x::AbstractVector{Union{T,Missing}}, k::Int) where {T}
447427
k < 1 && throw(ArgumentError("k must be greater than 1"))
448428
k == 1 && return [minimum(identity, x)]
449429
all(ismissing, x) && return [missing]
450-
res = Vector{T}(undef, k)
451-
fill!(res, typemax(T))
430+
res = Vector{Union{Missing, T}}(undef, k)
431+
fill!(res, missing)
452432
cnt = 0
453433
for i in 1:length(x)
454434
if !ismissing(x[i])
455-
insert_fixed_sorted!(res, x[i], (y1, y2) -> y1 < y2)
435+
insert_fixed_sorted!(res, x[i], (y1, y2) -> TOPK_ISLESS(y1, y2))
456436
cnt += 1
457437
end
458438
end
@@ -465,31 +445,18 @@ end
465445

466446

467447
# ktop permutation
448+
468449
function k_largest_perm(x::AbstractVector{T}, k::Int) where {T}
469-
k < 1 && throw(ArgumentError("k must be greater than 1"))
470-
k == 1 && return [maximum(identity, x)], [argmax(x)]
471-
if k > length(x)
472-
k = length(x)
473-
end
474-
res = Vector{T}(undef, k)
475-
perm = zeros(Int, k)
476-
fill!(res, typemin(T))
477-
for i in 1:length(x)
478-
insert_fixed_sorted_perm!(perm, res, i, x[i], (y1, y2) -> y1 > y2)
479-
end
480-
res, perm
481-
end
482-
function k_largest_perm(x::AbstractVector{Union{T,Missing}}, k::Int) where {T}
483450
k < 1 && throw(ArgumentError("k must be greater than 1"))
484451
k == 1 && return [maximum(identity, x)], [argmax(x)]
485452
all(ismissing, x) && return [missing], [missing]
486-
res = Vector{T}(undef, k)
453+
res = Vector{Union{Missing, T}}(undef, k)
487454
perm = zeros(Int, k)
488-
fill!(res, typemin(T))
455+
fill!(res, missing)
489456
cnt = 0
490457
for i in 1:length(x)
491458
if !ismissing(x[i])
492-
insert_fixed_sorted_perm!(perm, res, i, x[i], (y1, y2) -> y1 > y2)
459+
insert_fixed_sorted_perm!(perm, res, i, x[i], (y1, y2) -> TOPK_ISLESS(y2, y1))
493460
cnt += 1
494461
end
495462
end
@@ -499,31 +466,18 @@ function k_largest_perm(x::AbstractVector{Union{T,Missing}}, k::Int) where {T}
499466
res, perm
500467
end
501468
end
469+
502470
function k_smallest_perm(x::AbstractVector{T}, k::Int) where {T}
503-
k < 1 && throw(ArgumentError("k must be greater than 1"))
504-
k == 1 && return [minimum(identity, x)], [argmin(x)]
505-
if k > length(x)
506-
k = length(x)
507-
end
508-
res = Vector{T}(undef, k)
509-
perm = zeros(Int, k)
510-
fill!(res, typemax(T))
511-
for i in 1:length(x)
512-
insert_fixed_sorted_perm!(perm, res, i, x[i], (y1, y2) -> y1 < y2)
513-
end
514-
res, perm
515-
end
516-
function k_smallest_perm(x::AbstractVector{Union{T,Missing}}, k::Int) where {T}
517471
k < 1 && throw(ArgumentError("k must be greater than 1"))
518472
k == 1 && return [minimum(identity, x)], [argmin(x)]
519473
all(ismissing, x) && return [missing], [missing]
520-
res = Vector{T}(undef, k)
474+
res = Vector{Union{Missing, T}}(undef, k)
521475
perm = zeros(Int, k)
522-
fill!(res, typemax(T))
476+
fill!(res, missing)
523477
cnt = 0
524478
for i in 1:length(x)
525479
if !ismissing(x[i])
526-
insert_fixed_sorted_perm!(perm, res, i, x[i], (y1, y2) -> y1 < y2)
480+
insert_fixed_sorted_perm!(perm, res, i, x[i], (y1, y2) -> TOPK_ISLESS(y1, y2))
527481
cnt += 1
528482
end
529483
end
@@ -542,7 +496,7 @@ Return up to `k` largest nonmissing elements of `x`. When `rev = true` it returns
542496
543497
If `output_indices = true`, the function returns the values and their indices.
544498
545-
> The `topk` function uses `<` for comparing values
499+
> The `topk` function uses `isless` for comparing values
546500
"""
547501
function topk(x::AbstractVector, k::Int; rev::Bool=false, output_indices::Bool = false)
548502
@assert firstindex(x) == 1 "topk only supports 1-based indexing"

test/stats.jl

Lines changed: 36 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,14 @@
55
for j in 1:11
66
@test partialsort(x, 1:j) == topk(x, j, rev=true)
77
@test partialsort(x, 1:j, rev=true) == topk(x, j)
8-
@test partialsortperm(x, 1:j) == topk(x, j, rev=true, output_indices = true)[2]
9-
@test partialsortperm(x, 1:j, rev=true) == topk(x, j, output_indices = true)[2]
8+
@test partialsortperm(x, 1:j) == topk(x, j, rev=true, output_indices=true)[2]
9+
@test partialsortperm(x, 1:j, rev=true) == topk(x, j, output_indices=true)[2]
1010
end
1111
x = rand(11)
1212
for j in 1:11
1313
@test partialsort(x, 1:j) == topk(x, j, rev=true)
1414
@test partialsort(x, 1:j, rev=true) == topk(x, j)
15-
@test partialsortperm(x, 1:j) == topk(x, j, rev=true, output_indices = true)[2]
15+
@test partialsortperm(x, 1:j) == topk(x, j, rev=true, output_indices=true)[2]
1616
@test partialsortperm(x, 1:j, rev=true) == topk(x, j, output_indices=true)[2]
1717
end
1818
x = randn(11)
@@ -22,6 +22,27 @@
2222
@test partialsortperm(x, 1:j) == topk(x, j, rev=true, output_indices=true)[2]
2323
@test partialsortperm(x, 1:j, rev=true) == topk(x, j, output_indices=true)[2]
2424
end
25+
x = rand(Int8, 10000)
26+
for j in 1:15
27+
@test partialsort(x, 1:j) == topk(x, j, rev=true)
28+
@test partialsort(x, 1:j, rev=true) == topk(x, j)
29+
@test partialsortperm(x, 1:j) == topk(x, j, rev=true, output_indices=true)[2]
30+
@test partialsortperm(x, 1:j, rev=true) == topk(x, j, output_indices=true)[2]
31+
end
32+
x = zeros(Bool, 11)
33+
for j in 1:15
34+
@test partialsort(x, 1:min(11, j)) == topk(x, j, rev=true)
35+
@test partialsort(x, 1:min(j, 11), rev=true) == topk(x, j)
36+
@test partialsortperm(x, 1:min(11, j)) == topk(x, j, rev=true, output_indices=true)[2]
37+
@test partialsortperm(x, 1:min(11, j), rev=true) == topk(x, j, output_indices=true)[2]
38+
end
39+
x = ones(Bool, 11)
40+
for j in 1:15
41+
@test partialsort(x, 1:min(11, j)) == topk(x, j, rev=true)
42+
@test partialsort(x, 1:min(j, 11), rev=true) == topk(x, j)
43+
@test partialsortperm(x, 1:min(11, j)) == topk(x, j, rev=true, output_indices=true)[2]
44+
@test partialsortperm(x, 1:min(11, j), rev=true) == topk(x, j, output_indices=true)[2]
45+
end
2546
end
2647
x = [1, 10, missing, 100, -1000, 32, 54, 0, missing, missing, -1]
2748
@test topk(x, 2) == [100, 54]
@@ -36,4 +57,16 @@
3657
@test isequal(topk([missing, missing], 2, rev = true), [missing])
3758
@test isequal(topk([missing, missing], 2, output_indices=true)[2], [missing])
3859
@test isequal(topk([missing, missing], 2, rev=true, output_indices=true)[2], [missing])
60+
x = Int8[-128, -128, -128]
61+
y = Union{Int8, Missing}[-128, -128, missing, missing, -128]
62+
63+
@test topk(x, 2) == [-128, -128]
64+
@test topk(x, 2, rev = true) == [-128, -128]
65+
@test topk(x, 2, rev = true, output_indices = true) == ([-128, -128], [1,2])
66+
@test topk(x, 2, output_indices = true) == ([-128, -128], [1,2])
67+
68+
@test topk(y, 3) == [-128, -128, -128]
69+
@test topk(y, 3, rev = true) == [-128, -128, -128]
70+
@test topk(y, 3, rev = true, output_indices = true) == ([-128, -128, -128], [1, 2, 5])
71+
@test topk(y, 3, output_indices = true) == ([-128, -128, -128], [1, 2, 5])
3972
end

0 commit comments

Comments
 (0)