Skip to content

Commit 917981e

Browse files
committed
bug fix
1 parent abc7c03 commit 917981e

File tree

2 files changed

+43
-8
lines changed

2 files changed

+43
-8
lines changed

src/sort/gatherby.jl

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -138,11 +138,12 @@ end
138138

139139

140140

141-
function gatherby_mapreduce(gds::GatherBy, f, op, col::ColumnIndex, nt, init::T) where T
141+
function gatherby_mapreduce(gds::GatherBy, f, op, col::ColumnIndex, nt, init, ::Val{T}; promotetypes = false) where T
142142
CT = T
143-
T <: Base.SmallSigned ? CT = Int : nothing
144-
T <: Base.SmallUnsigned ? CT = UInt : nothing
145-
T <: Float64 ? CT = Float64 : nothing
143+
if promotetypes
144+
T <: Base.SmallSigned ? CT = Int : nothing
145+
T <: Base.SmallUnsigned ? CT = UInt : nothing
146+
end
146147
res = Tables.allocatecolumn(Union{CT, Missing}, gds.lastvalid)
147148
fill!(res, init)
148149
if Threads.nthreads() > 1 && gds.lastvalid > 100_000
@@ -153,9 +154,9 @@ function gatherby_mapreduce(gds::GatherBy, f, op, col::ColumnIndex, nt, init::T)
153154
res
154155
end
155156

156-
_gatherby_maximum(gds, col; f = identity, nt = Threads.nthreads()) = gatherby_mapreduce(gds, f, _stat_max_fun, col, nt, typemin(nonmissingtype(eltype(gds.parent[!, col]))))
157-
_gatherby_minimum(gds, col; f = identity, nt = Threads.nthreads()) = gatherby_mapreduce(gds, f, _stat_min_fun, col, nt, typemax(nonmissingtype(eltype(gds.parent[!, col]))))
158-
_gatherby_sum(gds, col; f = identity, nt = Threads.nthreads()) = gatherby_mapreduce(gds, f, _stat_add_sum, col, nt, zero(Core.Compiler.return_type(f, (eltype(gds.parent[!, col]), ))))
157+
_gatherby_maximum(gds, col; f = identity, nt = Threads.nthreads()) = gatherby_mapreduce(gds, f, _stat_max_fun, col, nt, missing, Val(nonmissingtype(eltype(gds.parent[!, col]))))
158+
_gatherby_minimum(gds, col; f = identity, nt = Threads.nthreads()) = gatherby_mapreduce(gds, f, _stat_min_fun, col, nt, missing, Val(nonmissingtype(eltype(gds.parent[!, col]))))
159+
_gatherby_sum(gds, col; f = identity, nt = Threads.nthreads()) = gatherby_mapreduce(gds, f, _stat_add_sum, col, nt, missing, Val(typeof(zero(Core.Compiler.return_type(f, (eltype(gds.parent[!, col]), ))))), promotetypes = true)
159160
_gatherby_n(gds, col; nt = Threads.nthreads()) = _gatherby_sum(gds, col, f = _stat_notmissing, nt = nt)
160161
_gatherby_length(gds, col; nt = Threads.nthreads()) = _gatherby_sum(gds, col, f = x->1, nt = nt)
161162
_gatherby_cntnan(gds, col; nt = Threads.nthreads()) = _gatherby_sum(gds, col, f = ISNAN, nt = nt)
@@ -221,7 +222,7 @@ function _gatherby_var(gds, col; dof = true, cal_std = false)
221222
t1 = Threads.@spawn _gatherby_cntnan(gds, col, nt = nt2)
222223
t2 = Threads.@spawn _gatherby_mean(gds, col, nt = nt2)
223224
meanval = fetch(t2)
224-
t3 = Threads.@spawn gatherby_mapreduce(gds, [x->abs2(x - meanval[i]) for i in 1:length(meanval)], _stat_add_sum, col, nt2, 0.0)
225+
t3 = Threads.@spawn gatherby_mapreduce(gds, [x->abs2(x - meanval[i]) for i in 1:length(meanval)], _stat_add_sum, col, nt2, missing, Val(Float64))
225226
t4 = Threads.@spawn _gatherby_n(gds, col, nt = nt2)
226227
countnan = fetch(t1)
227228
ss = fetch(t3)

test/grouping.jl

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -467,4 +467,38 @@ end
467467

468468
@test IMD.index(view(ds,[2,1,1,3,5,6,7,4,4],[:z,:x,:q,:q2])) == IMD.index(sds2)
469469

470+
ds = Dataset(x = [3,1,2,2,missing,3,3], y = [1.1, missing, -1.0, -3.0, missing, 4.0, 5.0], z = [11,15,7,-11,12,0,0])
471+
sds2 = view(ds, [2,1,1,3,5,6,7,4,4], 1:2)
472+
@test combine(gatherby(sds2, 1), :y=>sum) == Dataset(x=[1,3,2, missing], sum_y=[missing, 11.2,-7.0,missing])
473+
@test combine(gatherby(sds2, 1), :y=>(x->sum(x))=>:sum_y) == Dataset(x=[1,3,2, missing], sum_y=[missing, 11.2,-7.0,missing])
474+
@test combine(groupby(sds2, 1), :y=>sum) == Dataset(x=[1,2,3, missing], sum_y=[missing, -7.0, 11.2,missing])
475+
@test combine(gatherby(sds2, 1, isgathered = true), :y=>sum) == Dataset(x = [1,3,2,missing,3,2], sum_y=[missing, 2.2, -1,missing, 9,-6])
476+
477+
ds = Dataset(x = [1,2,1,2,3], y1 = Union{Int8, Missing}[1,2,missing,4,missing], y2 = Union{Int32, Missing}[1,2,3,4,missing], y3=Union{Int16, Missing}[100,20,3000,4,missing], y4=Float16.(rand(5)), y5=rand(BigFloat, 5))
478+
sds = view(ds, [1,2,3,4,5], [1,2,3,4,5,6])
479+
480+
@test combine(gatherby(sds, 1), 2:4 .=>Ref([sum, mean, maximum, minimum, IMD.n, IMD.nmissing])) == Dataset([Union{Missing, Int64}[1, 2, 3], Union{Missing, Int64}[1, 6, missing], Union{Missing, Float64}[1.0, 3.0, missing], Union{Missing, Int8}[1, 4, missing], Union{Missing, Int8}[1, 2, missing], Union{Missing, Int64}[1, 2, 0], Union{Missing, Int64}[1, 0, 1], Union{Missing, Int64}[4, 6, missing], Union{Missing, Float64}[2.0, 3.0, missing], Union{Missing, Int32}[3, 4, missing], Union{Missing, Int32}[1, 2, missing], Union{Missing, Int64}[2, 2, 0], Union{Missing, Int64}[0, 0, 1], Union{Missing, Int64}[3100, 24, missing], Union{Missing, Float64}[1550.0, 12.0, missing], Union{Missing, Int16}[3000, 20, missing], Union{Missing, Int16}[100, 4, missing], Union{Missing, Int64}[2, 2, 0], Union{Missing, Int64}[0, 0, 1]], ["x", "sum_y1", "mean_y1", "maximum_y1", "minimum_y1", "n_y1", "nmissing_y1", "sum_y2", "mean_y2", "maximum_y2", "minimum_y2", "n_y2", "nmissing_y2", "sum_y3", "mean_y3", "maximum_y3", "minimum_y3", "n_y3", "nmissing_y3"])
481+
482+
var1(x) = var(x)
483+
std1(x) = std(x)
484+
median1(x) = median(x)
485+
c1 =combine(gatherby(sds, 1), 2:4 .=>Ref([var, std, median]))
486+
c2 = combine(gatherby(copy(sds), 1), 2:4 .=> Ref([var1, std1, median1]))
487+
@test byrow(compare(c1, c2, on = names(c1) .=> names(c2)) , all)|>all
488+
489+
c3 = combine(gatherby(sds,1), :y4=>sum)
490+
@test eltype(c3.sum_y4) == Union{Missing, Float16}
491+
c3 = combine(gatherby(sds,1), :y5=>sum)
492+
@test eltype(c3.sum_y5) == Union{Missing, BigFloat}
493+
494+
ds = Dataset(rand(1:10, 300_000,3), :auto)
495+
insertcols!(ds, 1, :g => repeat(1:150_000, inner = 2))
496+
map!(ds, x->rand()<.7 ? missing : x, r"x")
497+
sds = view(ds, :, :)
498+
499+
@test combine(gatherby(sds, 1), r"x"=>sum) == combine(groupby(sds, 1), r"x"=>sum)
500+
@test combine(gatherby(sds, 1), r"x"=>maximum) == combine(groupby(sds, 1), r"x"=>maximum)
501+
@test combine(gatherby(sds, 1), r"x"=>minimum) == combine(groupby(sds, 1), r"x"=>minimum)
502+
@test combine(gatherby(sds, 1), r"x"=>var) == combine(groupby(sds, 1), r"x"=>var)
503+
470504
end

0 commit comments

Comments
 (0)