Skip to content

Commit 7d4d469

Browse files
committed
flatten/!: bug fix - improve performance
1 parent 662f59f commit 7d4d469

File tree

5 files changed

+146
-109
lines changed

5 files changed

+146
-109
lines changed

CHANGELOG.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
* Add a new function `eachgroup`. It allows iteration over each group of a grouped data set.
77
* `op` is a new keyword argument for the `update/!` functions which allows passing a user defined function to control how the value of the main data set should be updated by the values from the transaction data set. ([issue #55](https://github.com/sl-solution/InMemoryDatasets.jl/issues/55))
88
* Support for the `mapformats` keyword argument in `flatten/!`. Now users can flatten a data set based on the formatted values. ([issue #57](https://github.com/sl-solution/InMemoryDatasets.jl/issues/57))
9+
* Support for the `threads` keyword argument in `flatten/!`.
910

1011
## Fixes
1112

@@ -14,6 +15,12 @@
1415
* `update` and `update!` have the same `mode` option by default.
1516
* Fix the problem with preserving format of `SubDataset` in `flatten/!`
1617
* Fix the problem that caused `flatten!` to produce a copy of data when an empty data set was passed to it.
18+
* Fix the bug in `flatten!` related to flattening the first column.
19+
* Fix the bug in `flatten` that caused a segmentation fault for views of data sets.
20+
21+
## Performance
22+
23+
* Faster `flatten/!`
1724

1825
# Version 0.7.0
1926

src/abstractdataset/abstractdataset.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1450,7 +1450,7 @@ function Base.hash(ds::AbstractDataset, h::UInt)
14501450
h += hashds_seed
14511451
h += hash(size(ds))
14521452
for i in 1:size(ds, 2)
1453-
h = hash(ds[!, i], h)
1453+
h = hash(_columns(ds)[i], h)
14541454
end
14551455
return h
14561456
end

src/byrow/byrow.jl

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -203,14 +203,22 @@ function byrow(ds::AbstractDataset, f::Function, cols::MultiColumnIndex; threads
203203
length(colsidx) == 1 && return byrow(ds, f, colsidx[1]; threads = threads)
204204
threads ? hp_row_generic(ds, f, cols) : row_generic(ds, f, cols)
205205
end
206-
function byrow(ds::AbstractDataset, f::Function, col::ColumnIndex; threads = nrow(ds)>1000)
206+
function byrow(ds::AbstractDataset, f::Function, col::ColumnIndex; threads = nrow(ds)>1000, forcemissing::Bool = true)
207207
if threads
208208
T = Core.Compiler.return_type(f, Tuple{nonmissingtype(eltype(ds[!, col]))})
209-
res = Vector{Union{Missing, T}}(undef, nrow(ds))
209+
if forcemissing
210+
res = Vector{Union{Missing, T}}(undef, nrow(ds))
211+
else
212+
res = Vector{T}(undef, nrow(ds))
213+
end
210214
_hp_map_a_function!(res, f, _columns(ds)[index(ds)[col]])
211215
else
212216
T = Core.Compiler.return_type(f, Tuple{nonmissingtype(eltype(ds[!, col]))})
213-
res = Vector{Union{Missing, T}}(undef, nrow(ds))
217+
if forcemissing
218+
res = Vector{Union{Missing, T}}(undef, nrow(ds))
219+
else
220+
res = Vector{T}(undef, nrow(ds))
221+
end
214222
map!(f, res, _columns(ds)[index(ds)[col]])
215223
end
216224
res

src/dataset/transpose.jl

Lines changed: 111 additions & 105 deletions
Original file line numberDiff line numberDiff line change
@@ -593,7 +593,7 @@ Base.transpose(ds::Union{GroupBy, GatherBy}, cols::Tuple; id = nothing, renameco
593593

594594

595595
"""
596-
flatten(ds::AbstractDataset, cols; mapformats = false)
596+
flatten(ds::AbstractDataset, cols; mapformats = false, threads = true)
597597
598598
When columns `cols` of data set `ds` have iterable elements that define
599599
`length` (for example a `Vector` of `Vector`s), return a `Dataset` where each
@@ -609,6 +609,8 @@ When `mapformats = true`, the function uses the formatted values of `cols`.
609609
610610
`cols` can be any column selector ($COLUMNINDEX_STR; $MULTICOLUMNINDEX_STR).
611611
612+
To turn off multithreaded computations pass `threads = false`.
613+
612614
See [`flatten!`](@ref)
613615
614616
# Examples
@@ -743,96 +745,87 @@ julia> flatten(ds, 2:3, mapformats = true)
743745
flatten(ds, cols)
744746

745747
"""
746-
flatten!(ds, cols; mapformats = false)
748+
flatten!(ds, cols; mapformats = false, threads = true)
747749
748750
Variant of `flatten` that does flatten `ds` in-place.
749751
"""
750752
flatten!
751753

752-
function _ELTYPE(x; fmt = identity)
753-
if fmt == identity
754-
eltype(x)
755-
else
756-
eltype(fmt(x))
757-
end
754+
function _ELTYPE(x)
755+
eltype(x)
758756
end
759-
function _ELTYPE(x::Missing; fmt = identity)
760-
if fmt == identity
761-
Missing
762-
elseif ismissing(fmt(x))
763-
Missing
764-
else
765-
eltype(fmt(x))
766-
end
757+
function _ELTYPE(x::Missing)
758+
Missing
767759
end
768760

769761

770-
function _LENGTH(x; fmt = identity)
771-
if fmt == identity
772-
res = length(x)
773-
else
774-
res = length(fmt(x))
775-
end
776-
res
762+
function _LENGTH(x)
763+
length(x)
777764
end
778765

779-
function _LENGTH(x::Missing; fmt = identity)
780-
if fmt == identity
781-
res = 1
782-
elseif ismissing(fmt(x))
783-
res = 1
784-
else
785-
res = length(fmt(x))
786-
end
787-
res
766+
function _LENGTH(x::Missing)
767+
1
788768
end
789769

790770

791771
function flatten!(ds::Dataset,
792-
cols::Union{ColumnIndex, MultiColumnIndex}; mapformats = false)
772+
cols::Union{ColumnIndex, MultiColumnIndex}; mapformats = false, threads = true)
793773
_check_consistency(ds)
794774

795775
idxcols = index(ds)[cols]
796776
isempty(idxcols) && return ds
797777
col1 = first(idxcols)
778+
all_idxcols = Any[]
798779
if mapformats
799780
f_fmt = getformat(ds, col1)
800-
lengths = _LENGTH.(_columns(ds)[col1], fmt = f_fmt)
781+
push!(all_idxcols, byrow(ds, f_fmt, col1, threads = threads))
801782
else
802-
lengths = _LENGTH.(_columns(ds)[col1])
803-
end
804-
for col in idxcols
805-
v = _columns(ds)[col]
806-
if mapformats
807-
f_fmt = getformat(ds, col)
808-
else
809-
f_fmt = identity
810-
end
811-
if any(x -> _LENGTH(x[1], fmt = f_fmt) != x[2], zip(v, lengths))
812-
r = findfirst(x -> x != 0, _LENGTH.(v, fmt = f_fmt) .- lengths)
813-
colnames = _names(ds)
814-
throw(ArgumentError("Lengths of iterables stored in columns :$(colnames[col1]) " *
815-
"and :$(colnames[col]) are not the same in row $r"))
783+
push!(all_idxcols, _columns(ds)[col1])
784+
end
785+
lengths = byrow(Dataset(all_idxcols, [:x], copycols = false), _LENGTH, 1, threads = threads, forcemissing = false)
786+
if length(idxcols) > 1
787+
for col in 2:length(idxcols)
788+
if mapformats
789+
f_fmt = getformat(ds, idxcols[col])
790+
push!(all_idxcols, byrow(ds, f_fmt, idxcols[col], threads = threads)) # `threads` is a keyword of `byrow`; `push!` accepts no keywords
791+
else
792+
push!(all_idxcols, _columns(ds)[idxcols[col]])
793+
end
794+
v = all_idxcols[col]
795+
if any(x -> _LENGTH(x[1]) != x[2], zip(v, lengths))
796+
r = findfirst(x -> x != 0, _LENGTH.(v) .- lengths)
797+
colnames = _names(ds)
798+
throw(ArgumentError("Lengths of iterables stored in columns :$(colnames[col1]) " *
799+
"and :$(colnames[idxcols[col]]) are not the same in row $r"))
800+
end
816801
end
817802
end
818803
r_index = _create_index_for_repeat(lengths, nrow(ds) < typemax(Int32) ? Val(Int32) : Val(Int64))
819-
_permute_ds_after_sort!(ds, r_index, check = false, cols = Not(cols))
820-
new_total = sum(lengths)
821-
length(idxcols) > 1 && sort!(idxcols)
822-
for col in idxcols
823-
col_to_flatten = _columns(ds)[col]
824-
if mapformats
825-
f_fmt = getformat(ds, col)
826-
else
827-
f_fmt = identity
828-
end
829-
T = mapreduce(x->_ELTYPE(x, fmt = f_fmt), promote_type, col_to_flatten)
804+
_permute_ds_after_sort!(ds, r_index, check = false, cols = Not(cols), threads = threads)
805+
if threads
806+
new_total = hp_sum(lengths)
807+
else
808+
new_total = sum(lengths)
809+
end
810+
if length(idxcols) > 1
811+
sort_permute_idxcols = sortperm(idxcols)
812+
idxcols_sorted = idxcols[sort_permute_idxcols]
813+
else
814+
sort_permute_idxcols = [1]
815+
idxcols_sorted = idxcols
816+
end
817+
cumsum!(lengths, lengths)
818+
for col in 1:length(idxcols_sorted)
819+
col_to_flatten = all_idxcols[sort_permute_idxcols[col]]
820+
821+
T = mapreduce(_ELTYPE, promote_type, col_to_flatten)
830822
_res = allocatecol(T, new_total)
831-
_fill_flatten!(_res, col_to_flatten, lengths; fmt = f_fmt)
823+
_fill_flatten!(_res, col_to_flatten, lengths, threads = threads)
832824
if length(idxcols) == ncol(ds)
833-
_columns(ds)[col] = _res
825+
_columns(ds)[idxcols_sorted[col]] = _res
834826
else
835-
ds[!, col] = _res
827+
deleteat!(_columns(ds), idxcols_sorted[col])
828+
insert!(_columns(ds), idxcols_sorted[col], _res)
836829
end
837830
end
838831
_reset_grouping_info!(ds)
@@ -842,49 +835,62 @@ end
842835

843836

844837
function flatten(ds::AbstractDataset,
845-
cols::Union{ColumnIndex, MultiColumnIndex}; mapformats = false)
838+
cols::Union{ColumnIndex, MultiColumnIndex}; mapformats = false, threads = true)
846839
_check_consistency(ds)
847840

848841
idxcols = index(ds)[cols]
849842
isempty(idxcols) && return copy(ds)
850843
col1 = first(idxcols)
844+
all_idxcols = Any[]
851845
if mapformats
852846
f_fmt = getformat(ds, col1)
853-
lengths = _LENGTH.(_columns(ds)[col1], fmt = f_fmt)
847+
push!(all_idxcols, byrow(ds, f_fmt, col1, threads = threads))
854848
else
855-
lengths = _LENGTH.(_columns(ds)[col1])
856-
end
857-
for col in idxcols
858-
v = _columns(ds)[col]
859-
if mapformats
860-
f_fmt = getformat(ds, col)
861-
else
862-
f_fmt = identity
863-
end
864-
if any(x -> _LENGTH(x[1], fmt = f_fmt) != x[2], zip(v, lengths))
865-
r = findfirst(x -> x != 0, _LENGTH.(v, fmt = f_fmt) .- lengths)
866-
colnames = _names(ds)
867-
throw(ArgumentError("Lengths of iterables stored in columns :$(colnames[col1]) " *
868-
"and :$(colnames[col]) are not the same in row $r"))
849+
push!(all_idxcols, _columns(ds)[col1])
850+
end
851+
lengths = byrow(Dataset(all_idxcols, [:x], copycols = false), _LENGTH, 1, threads = threads, forcemissing = false)
852+
if length(idxcols) > 1
853+
for col in 2:length(idxcols)
854+
if mapformats
855+
f_fmt = getformat(ds, idxcols[col])
856+
push!(all_idxcols, byrow(ds, f_fmt, idxcols[col], threads = threads)) # `threads` is a keyword of `byrow`; `push!` accepts no keywords
857+
else
858+
push!(all_idxcols, _columns(ds)[idxcols[col]])
859+
end
860+
v = all_idxcols[col]
861+
if any(x -> _LENGTH(x[1]) != x[2], zip(v, lengths))
862+
r = findfirst(x -> x != 0, _LENGTH.(v) .- lengths)
863+
colnames = _names(ds)
864+
throw(ArgumentError("Lengths of iterables stored in columns :$(colnames[col1]) " *
865+
"and :$(colnames[idxcols[col]]) are not the same in row $r"))
866+
end
869867
end
870868
end
871-
new_total = sum(lengths)
869+
if threads
870+
new_total = hp_sum(lengths)
871+
else
872+
new_total = sum(lengths)
873+
end
872874
new_ds = similar(ds[!, Not(cols)], new_total)
873875
for name in _names(new_ds)
874-
repeat_lengths_v2!(new_ds[!, name].val, ds[!, name].val, lengths)
876+
col_name = index(ds)[name]
877+
repeat_lengths_v2!(new_ds[!, name].val, _columns(ds)[col_name], lengths)
875878
end
876-
length(idxcols) > 1 && sort!(idxcols)
877-
for col in idxcols
878-
col_to_flatten = _columns(ds)[col]
879-
if mapformats
880-
f_fmt = getformat(ds, col)
881-
else
882-
f_fmt = identity
883-
end
884-
T = mapreduce(x->_ELTYPE(x, fmt = f_fmt), promote_type, col_to_flatten)
879+
if length(idxcols) > 1
880+
sort_permute_idxcols = sortperm(idxcols)
881+
idxcols_sorted = idxcols[sort_permute_idxcols]
882+
else
883+
sort_permute_idxcols = [1]
884+
idxcols_sorted = idxcols
885+
end
886+
cumsum!(lengths, lengths)
887+
for col in 1:length(idxcols_sorted)
888+
col_to_flatten = all_idxcols[sort_permute_idxcols[col]]
889+
890+
T = mapreduce(_ELTYPE, promote_type, col_to_flatten)
885891
_res = allocatecol(T, new_total)
886-
_fill_flatten!(_res, col_to_flatten, lengths; fmt = f_fmt)
887-
insertcols!(new_ds, col, _names(ds)[col] => _res, unsupported_copy_cols = false)
892+
_fill_flatten!(_res, col_to_flatten, lengths, threads = threads)
893+
insertcols!(new_ds, idxcols_sorted[col], _names(ds)[idxcols_sorted[col]] => _res, unsupported_copy_cols = false)
888894
end
889895
for j in setdiff(1:ncol(ds), idxcols)
890896
setformat!(new_ds, j=>getformat(ds, j))
@@ -895,22 +901,22 @@ function flatten(ds::AbstractDataset,
895901
end
896902

897903

898-
function _fill_flatten!_barrier(_res, val, counter; fmt = identity)
899-
for j in fmt(val)
900-
_res[counter] = j
901-
counter += 1
904+
function _fill_flatten!_barrier(_res, val, lo)
905+
if ismissing(val)
906+
_res[lo] = missing
907+
else
908+
909+
cnt = 0
910+
for j in val
911+
_res[lo+cnt] = j
912+
cnt += 1
913+
end
902914
end
903-
counter
904915
end
905916

906-
function _fill_flatten!(_res, col_to_flatten, lengths; fmt = identity)
907-
counter = 1
908-
for i in 1:length(col_to_flatten)
909-
if ismissing(fmt(col_to_flatten[i]))
910-
_res[counter] = missing
911-
counter += 1
912-
else
913-
counter = _fill_flatten!_barrier(_res, col_to_flatten[i], counter; fmt = fmt)
914-
end
917+
function _fill_flatten!(_res, col_to_flatten, lengths; threads = false)
918+
@_threadsfor threads for i in 1:length(col_to_flatten)
919+
i == 1 ? lo = 1 : lo = lengths[i-1]+1
920+
_fill_flatten!_barrier(_res, col_to_flatten[i], lo)
915921
end
916922
end

test/transpose.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -589,6 +589,22 @@ end
589589
@test flatten(view(ds, :, [2,1]), :y, mapformats = true) == Dataset(reverse([Union{Missing, Int64}[1, 1, 2, 3, 3, 4], Union{Missing, SubString{String}}["ab", "bc", "d", "ef", "gh", missing]]), reverse([:x, :y]))
590590
flatten!(ds, :y, mapformats = true)
591591
@test ds == Dataset([Union{Missing, Int64}[1, 1, 2, 3, 3, 4], Union{Missing, SubString{String}}["ab", "bc", "d", "ef", "gh", missing]], [:x, :y])
592+
593+
for i in 1:10
594+
ds = Dataset(x=rand(10000),y=[rand(1:100, rand(1:5)) for _ in 1:10000],z=rand(["12,34", "2312,343","32423,,343", missing], 10000))
595+
fmt__(x) = split(x, ",")
596+
fmt__(::Missing) = missing
597+
setformat!(ds, :z => fmt__)
598+
@test flatten(ds, :y) == flatten!(copy(ds), :y)
599+
@test flatten(ds, :z, mapformats = true) == flatten!(copy(ds), :z, mapformats = true)
600+
@test flatten(view(ds, [1,2,5,10], [2,3,1]), :y) == flatten!(ds[[1,2,5,10], [2,3,1]], :y)
601+
@test flatten(view(ds, [1,2,5,10], [2,3,1]), :z, mapformats = true) == flatten!(ds[[1,2,5,10], [2,3,1]], :z, mapformats = true)
602+
@test flatten(ds, :y) == flatten!(copy(ds), :y, threads = false)
603+
@test flatten(ds, :z, mapformats = true, threads = false) == flatten!(copy(ds), :z, mapformats = true)
604+
@test flatten(view(ds, [1,2,5,10], [2,3,1]), :y, threads = true) == flatten!(ds[[1,2,5,10], [2,3,1]], :y, threads = false)
605+
@test flatten(view(ds, [1,2,5,10], [2,3,1]), :z, mapformats = true, threads = false) == flatten!(ds[[1,2,5,10], [2,3,1]], :z, mapformats = true, threads = true)
606+
end
607+
592608
end
593609

594610
@testset "transpose - views" begin

0 commit comments

Comments
 (0)