Skip to content

Commit 9f8fefd

Browse files
committed
fix #67
1 parent 18b5b69 commit 9f8fefd

File tree

3 files changed

+126
-77
lines changed

3 files changed

+126
-77
lines changed

src/InMemoryDatasets.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,7 @@ export
9494
stdze!,
9595
rescale,
9696
topk,
97+
topkperm,
9798
cummax,
9899
cummax!,
99100
cummin,

src/stat/non_hp_stat.jl

Lines changed: 95 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -361,11 +361,38 @@ end
361361
# finding k largest in an array with missing values
362362
swap!(x, i, j) = x[i], x[j] = x[j], x[i]
363363

364-
TOPK_ISLESS(x,y) = isless(x, y)
365-
TOPK_ISLESS(::Missing, y) = true
366-
TOPK_ISLESS(x, ::Missing) = true
367-
TOPK_ISLESS(::Missing, ::Missing) = false
368-
364+
function initiate_topk_res!(res, x)
365+
cnt = 1
366+
idx = 1
367+
for i in 1:length(x)
368+
idx = i
369+
if !ismissing(x[i])
370+
res[cnt] = x[i]
371+
cnt += 1
372+
if cnt > length(res)
373+
break
374+
end
375+
end
376+
end
377+
idx, cnt-1
378+
end
379+
function initiate_topk_res_perm!(perm, res, x)
380+
cnt = 1
381+
idx = 1
382+
for i in 1:length(x)
383+
idx = i
384+
if !ismissing(x[i])
385+
res[cnt] = x[i]
386+
perm[cnt] = i
387+
cnt += 1
388+
if cnt > length(res)
389+
break
390+
end
391+
end
392+
end
393+
idx, cnt-1
394+
end
395+
369396
Base.@propagate_inbounds function insert_fixed_sorted!(x, item, ord)
370397
if ord(item, x[end])
371398
x[end] = item
@@ -405,41 +432,41 @@ end
405432

406433
Base.@propagate_inbounds function k_largest(x::AbstractVector{T}, k::Int) where {T}
407434
k < 1 && throw(ArgumentError("k must be greater than 1"))
408-
k == 1 && return [maximum(identity, x)]
409-
all(ismissing, x) && return [missing]
410-
res = Vector{Union{Missing, T}}(undef, k)
411-
fill!(res, missing)
412-
cnt = 0
413-
for i in 1:length(x)
435+
k == 1 && return Union{Missing, T}[maximum(identity, x)]
436+
all(ismissing, x) && return Union{Missing, T}[missing]
437+
res = Vector{nonmissingtype(T)}(undef, k)
438+
idx, cnt = initiate_topk_res!(res, x)
439+
sort!(view(res,1:cnt), rev = true)
440+
for i in idx+1:length(x)
414441
if !ismissing(x[i])
415-
insert_fixed_sorted!(res, x[i], (y1, y2) -> TOPK_ISLESS(y2, y1))
442+
insert_fixed_sorted!(res, x[i], (y1, y2) -> isless(y2, y1))
416443
cnt += 1
417444
end
418445
end
419446
if cnt < k
420-
res[1:cnt]
447+
allowmissing(res[1:cnt])
421448
else
422-
res
449+
allowmissing(res)
423450
end
424451
end
425452

426453
Base.@propagate_inbounds function k_smallest(x::AbstractVector{T}, k::Int) where {T}
427454
k < 1 && throw(ArgumentError("k must be greater than 1"))
428-
k == 1 && return [minimum(identity, x)]
429-
all(ismissing, x) && return [missing]
430-
res = Vector{Union{Missing, T}}(undef, k)
431-
fill!(res, missing)
432-
cnt = 0
433-
for i in 1:length(x)
455+
k == 1 && return Union{Missing, T}[minimum(identity, x)]
456+
all(ismissing, x) && return Union{Missing, T}[missing]
457+
res = Vector{nonmissingtype(T)}(undef, k)
458+
idx, cnt = initiate_topk_res!(res, x)
459+
sort!(view(res,1:cnt))
460+
for i in idx+1:length(x)
434461
if !ismissing(x[i])
435-
insert_fixed_sorted!(res, x[i], (y1, y2) -> TOPK_ISLESS(y1, y2))
462+
insert_fixed_sorted!(res, x[i], (y1, y2) -> isless(y1, y2))
436463
cnt += 1
437464
end
438465
end
439466
if cnt < k
440-
res[1:cnt]
467+
allowmissing(res[1:cnt])
441468
else
442-
res
469+
allowmissing(res)
443470
end
444471
end
445472

@@ -448,70 +475,83 @@ end
448475

449476
Base.@propagate_inbounds function k_largest_perm(x::AbstractVector{T}, k::Int) where {T}
450477
k < 1 && throw(ArgumentError("k must be greater than 1"))
451-
k == 1 && return [maximum(identity, x)], [argmax(x)]
452-
all(ismissing, x) && return [missing], [missing]
453-
res = Vector{Union{Missing, T}}(undef, k)
478+
k == 1 && return Union{Missing, Int}[argmax(x)]
479+
all(ismissing, x) && return Union{Missing, Int}[missing]
480+
res = Vector{nonmissingtype(T)}(undef, k)
454481
perm = zeros(Int, k)
455-
fill!(res, missing)
456-
cnt = 0
457-
for i in 1:length(x)
482+
idx, cnt = initiate_topk_res_perm!(perm, res, x)
483+
sort_perm = sortperm(view(res,1:cnt), rev = true)
484+
permute!(view(res,1:cnt), sort_perm)
485+
permute!(view(perm,1:cnt), sort_perm)
486+
for i in idx+1:length(x)
458487
if !ismissing(x[i])
459-
insert_fixed_sorted_perm!(perm, res, i, x[i], (y1, y2) -> TOPK_ISLESS(y2, y1))
488+
insert_fixed_sorted_perm!(perm, res, i, x[i], (y1, y2) -> isless(y2, y1))
460489
cnt += 1
461490
end
462491
end
463492
if cnt < k
464-
res[1:cnt], perm[1:cnt]
493+
allowmissing(perm[1:cnt])
465494
else
466-
res, perm
495+
allowmissing(perm)
467496
end
468497
end
469498

470499
Base.@propagate_inbounds function k_smallest_perm(x::AbstractVector{T}, k::Int) where {T}
471500
k < 1 && throw(ArgumentError("k must be greater than 1"))
472-
k == 1 && return [minimum(identity, x)], [argmin(x)]
473-
all(ismissing, x) && return [missing], [missing]
474-
res = Vector{Union{Missing, T}}(undef, k)
501+
k == 1 && return Union{Missing, Int}[argmin(x)]
502+
all(ismissing, x) && return Union{Missing, Int}[missing]
503+
res = Vector{nonmissingtype(T)}(undef, k)
475504
perm = zeros(Int, k)
476-
fill!(res, missing)
477-
cnt = 0
478-
for i in 1:length(x)
505+
idx, cnt = initiate_topk_res_perm!(perm, res, x)
506+
sort_perm = sortperm(view(res,1:cnt))
507+
permute!(view(res,1:cnt), sort_perm)
508+
permute!(view(perm,1:cnt), sort_perm)
509+
for i in idx+1:length(x)
479510
if !ismissing(x[i])
480-
insert_fixed_sorted_perm!(perm, res, i, x[i], (y1, y2) -> TOPK_ISLESS(y1, y2))
511+
insert_fixed_sorted_perm!(perm, res, i, x[i], (y1, y2) -> isless(y1, y2))
481512
cnt += 1
482513
end
483514
end
484515
if cnt < k
485-
res[1:cnt], perm[1:cnt]
516+
allowmissing(perm[1:cnt])
486517
else
487-
res, perm
518+
allowmissing(perm)
488519
end
489520
end
490521

491522

492523
"""
493-
topk(x, k; rev = false, output_indices = false)
524+
topk(x, k; rev = false)
494525
495526
Return upto `k` largest nonmissing elements of `x`. When `rev = true` it returns upto `k` smallest nonmissing elements of `x`. When all elements are missing, the function returns `[missing]`.
496527
497-
If `output_indices = true`, the function returns the values and their indices.
498-
499528
> The `topk` function uses `isless` for comparing values
529+
530+
Also see [`topkperm`](@ref)
500531
"""
501-
function topk(x::AbstractVector, k::Int; rev::Bool=false, output_indices::Bool = false)
532+
function topk(x::AbstractVector, k::Int; rev::Bool=false)
502533
@assert firstindex(x) == 1 "topk only supports 1-based indexing"
503534
if rev
504-
if output_indices
505-
k_smallest_perm(x, k)
506-
else
507-
k_smallest(x, k)
508-
end
535+
k_smallest(x, k)
509536
else
510-
if output_indices
511-
k_largest_perm(x, k)
512-
else
513-
k_largest(x, k)
514-
end
537+
k_largest(x, k)
538+
end
539+
end
540+
"""
541+
topkperm(x, k; rev = false)
542+
543+
Return the indices of upto `k` largest nonmissing elements of `x`. When `rev = true` it returns the indices of upto `k` smallest nonmissing elements of `x`. When all elements are missing, the function returns `[missing]`.
544+
545+
> The `topkperm` function uses `isless` for comparing values
546+
547+
Also see [`topk`](@ref)
548+
"""
549+
function topkperm(x::AbstractVector, k::Int; rev::Bool=false)
550+
@assert firstindex(x) == 1 "topkperm only supports 1-based indexing"
551+
if rev
552+
k_smallest_perm(x, k)
553+
else
554+
k_largest_perm(x, k)
515555
end
516556
end
517557

test/stats.jl

Lines changed: 30 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,72 +1,80 @@
1+
using Random
12
@testset "topk" begin
23
# general usage
34
for i in 1:100
45
x = rand(Int, 11)
56
for j in 1:11
67
@test partialsort(x, 1:j) == topk(x, j, rev=true)
78
@test partialsort(x, 1:j, rev=true) == topk(x, j)
8-
@test partialsortperm(x, 1:j) == topk(x, j, rev=true, output_indices=true)[2]
9-
@test partialsortperm(x, 1:j, rev=true) == topk(x, j, output_indices=true)[2]
9+
@test partialsortperm(x, 1:j) == topkperm(x, j, rev = true)
10+
@test partialsortperm(x, 1:j, rev=true) == topkperm(x, j)
1011
end
1112
x = rand(11)
1213
for j in 1:11
1314
@test partialsort(x, 1:j) == topk(x, j, rev=true)
1415
@test partialsort(x, 1:j, rev=true) == topk(x, j)
15-
@test partialsortperm(x, 1:j) == topk(x, j, rev=true, output_indices=true)[2]
16-
@test partialsortperm(x, 1:j, rev=true) == topk(x, j, output_indices=true)[2]
16+
@test partialsortperm(x, 1:j) == topkperm(x, j, rev=true)
17+
@test partialsortperm(x, 1:j, rev=true) == topkperm(x, j)
1718
end
1819
x = randn(11)
1920
for j in 1:11
2021
@test partialsort(x, 1:j) == topk(x, j, rev=true)
2122
@test partialsort(x, 1:j, rev=true) == topk(x, j)
22-
@test partialsortperm(x, 1:j) == topk(x, j, rev=true, output_indices=true)[2]
23-
@test partialsortperm(x, 1:j, rev=true) == topk(x, j, output_indices=true)[2]
23+
@test partialsortperm(x, 1:j) == topkperm(x, j, rev=true)
24+
@test partialsortperm(x, 1:j, rev=true) == topkperm(x, j)
2425
end
2526
x = rand(Int8, 10000)
2627
for j in 1:15
2728
@test partialsort(x, 1:j) == topk(x, j, rev=true)
2829
@test partialsort(x, 1:j, rev=true) == topk(x, j)
29-
@test partialsortperm(x, 1:j) == topk(x, j, rev=true, output_indices=true)[2]
30-
@test partialsortperm(x, 1:j, rev=true) == topk(x, j, output_indices=true)[2]
30+
@test partialsortperm(x, 1:j) == topkperm(x, j, rev=true)
31+
@test partialsortperm(x, 1:j, rev=true) == topkperm(x, j)
3132
end
3233
x = zeros(Bool, 11)
3334
for j in 1:15
3435
@test partialsort(x, 1:min(11, j)) == topk(x, j, rev=true)
3536
@test partialsort(x, 1:min(j, 11), rev=true) == topk(x, j)
36-
@test partialsortperm(x, 1:min(11, j)) == topk(x, j, rev=true, output_indices=true)[2]
37-
@test partialsortperm(x, 1:min(11, j), rev=true) == topk(x, j, output_indices=true)[2]
37+
@test partialsortperm(x, 1:min(11, j)) == topkperm(x, j, rev=true)
38+
@test partialsortperm(x, 1:min(11, j), rev=true) == topkperm(x, j)
3839
end
3940
x = ones(Bool, 11)
4041
for j in 1:15
4142
@test partialsort(x, 1:min(11, j)) == topk(x, j, rev=true)
4243
@test partialsort(x, 1:min(j, 11), rev=true) == topk(x, j)
43-
@test partialsortperm(x, 1:min(11, j)) == topk(x, j, rev=true, output_indices=true)[2]
44-
@test partialsortperm(x, 1:min(11, j), rev=true) == topk(x, j, output_indices=true)[2]
44+
@test partialsortperm(x, 1:min(11, j)) == topkperm(x, j, rev=true)
45+
@test partialsortperm(x, 1:min(11, j), rev=true) == topkperm(x, j)
46+
end
47+
x = [randstring() for _ in 1:101]
48+
for j in 1:15
49+
@test partialsort(x, 1:j) == topk(x, j, rev=true)
50+
@test partialsort(x, 1:j, rev=true) == topk(x, j)
51+
@test partialsortperm(x, 1:j) == topkperm(x, j, rev=true)
52+
@test partialsortperm(x, 1:j, rev=true) == topkperm(x, j)
4553
end
4654
end
4755
x = [1, 10, missing, 100, -1000, 32, 54, 0, missing, missing, -1]
4856
@test topk(x, 2) == [100, 54]
4957
@test topk(x, 2, rev=true) == [-1000, -1]
50-
@test topk(x, 2, output_indices=true)[2] == [4, 7]
51-
@test topk(x, 2, rev=true, output_indices=true)[2] == [5, 11]
58+
@test topkperm(x, 2) == [4, 7]
59+
@test topkperm(x, 2, rev=true) == [5, 11]
5260
@test topk(x, 10) == [100, 54, 32, 10, 1, 0, -1, -1000]
5361
@test topk(x, 10, rev=true) == [-1000, -1, 0, 1, 10, 32, 54, 100]
54-
@test topk(x, 10, output_indices=true)[2] == [4, 7, 6, 2, 1, 8, 11, 5]
55-
@test topk(x, 10, rev=true, output_indices=true)[2] == [5, 11, 8, 1, 2, 6, 7, 4]
62+
@test topkperm(x, 10) == [4, 7, 6, 2, 1, 8, 11, 5]
63+
@test topkperm(x, 10, rev=true) == [5, 11, 8, 1, 2, 6, 7, 4]
5664
@test isequal(topk([missing, missing], 2), [missing])
5765
@test isequal(topk([missing, missing], 2, rev = true), [missing])
58-
@test isequal(topk([missing, missing], 2, output_indices=true)[2], [missing])
59-
@test isequal(topk([missing, missing], 2, rev=true, output_indices=true)[2], [missing])
66+
@test isequal(topkperm([missing, missing], 2), [missing])
67+
@test isequal(topkperm([missing, missing], 2, rev=true), [missing])
6068
x = Int8[-128, -128, -128]
6169
y = Union{Int8, Missing}[-128, -128, missing, missing, -128]
6270

6371
@test topk(x, 2) == [-128, -128]
6472
@test topk(x, 2, rev = true) == [-128, -128]
65-
@test topk(x, 2, rev = true, output_indices = true) == ([-128, -128], [1,2])
66-
@test topk(x, 2, output_indices = true) == ([-128, -128], [1,2])
73+
@test topkperm(x, 2, rev = true) == [1,2]
74+
@test topkperm(x, 2) == [1,2]
6775

6876
@test topk(y, 3) == [-128, -128, -128]
6977
@test topk(y, 3, rev = true) == [-128, -128, -128]
70-
@test topk(y, 3, rev = true, output_indices = true) == ([-128, -128, -128], [1, 2, 5])
71-
@test topk(y, 3, output_indices = true) == ([-128, -128, -128], [1, 2, 5])
78+
@test topkperm(y, 3, rev = true) == [1, 2, 5]
79+
@test topkperm(y, 3) == [1, 2, 5]
7280
end

0 commit comments

Comments
 (0)