Skip to content

Commit 8f39de2

Browse files
committed
improve performance of flatten + add flatten!
1 parent 57503bc commit 8f39de2

File tree

4 files changed

+246
-0
lines changed

4 files changed

+246
-0
lines changed

src/InMemoryDatasets.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ export
5959
dropmissing,
6060
dropmissing!,
6161
flatten,
62+
flatten!,
6263
repeat!,
6364
select,
6465
select!,

src/dataset/transpose.jl

Lines changed: 189 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -574,3 +574,192 @@ end
574574

575575
Base.transpose(ds::Union{GroupBy, GatherBy}, cols::Tuple; id = nothing, renamecolid = nothing, renamerowid = _default_renamerowid_function, variable_name = nothing, default = missing, threads = true, mapformats = true) =
576576
ds_transpose(ds, cols, _groupcols(ds); id = id, renamecolid = renamecolid, renamerowid = renamerowid, variable_name = variable_name, threads = threads, default_fill = default, mapformats = mapformats)
577+
578+
579+
#### flatten
580+
581+
582+
"""
583+
flatten(ds::AbstractDataset, cols)
584+
585+
When columns `cols` of data set `ds` have iterable elements that define
586+
`length` (for example a `Vector` of `Vector`s), return a `Dataset` where each
587+
element of each `col` in `cols` is flattened, meaning the column corresponding
588+
to `col` becomes a longer vector where the original entries are concatenated.
589+
Elements of row `i` of `ds` in columns other than `cols` will be repeated
590+
according to the length of `ds[i, col]`. These lengths must therefore be the
591+
same for each `col` in `cols`, or else an error is raised. Note that these
592+
elements are not copied, and thus if they are mutable changing them in the
593+
returned `Dataset` will affect `ds`.
594+
595+
`cols` can be any column selector ($COLUMNINDEX_STR; $MULTICOLUMNINDEX_STR).
596+
597+
# Examples
598+
599+
```jldoctest
600+
julia> ds1 = Dataset(a = [1, 2], b = [[1, 2], [3, 4]], c = [[5, 6], [7, 8]])
601+
2×3 Dataset
602+
Row │ a b c
603+
│ identity identity identity
604+
│ Int64? Array…? Array…?
605+
─────┼──────────────────────────────
606+
1 │ 1 [1, 2] [5, 6]
607+
2 │ 2 [3, 4] [7, 8]
608+
609+
julia> flatten(ds1, :b)
610+
4×3 Dataset
611+
Row │ a b c
612+
│ identity identity identity
613+
│ Int64? Int64? Array…?
614+
─────┼──────────────────────────────
615+
1 │ 1 1 [5, 6]
616+
2 │ 1 2 [5, 6]
617+
3 │ 2 3 [7, 8]
618+
4 │ 2 4 [7, 8]
619+
620+
julia> flatten(ds1, [:b, :c])
621+
4×3 Dataset
622+
Row │ a b c
623+
│ identity identity identity
624+
│ Int64? Int64? Int64?
625+
─────┼──────────────────────────────
626+
1 │ 1 1 5
627+
2 │ 1 2 6
628+
3 │ 2 3 7
629+
4 │ 2 4 8
630+
631+
julia> ds2 = Dataset(a = [1, 2], b = [("p", "q"), ("r", "s")])
632+
2×2 Dataset
633+
Row │ a b
634+
│ identity identity
635+
│ Int64? Tuple…?
636+
─────┼──────────────────────
637+
1 │ 1 ("p", "q")
638+
2 │ 2 ("r", "s")
639+
640+
julia> flatten(ds2, :b)
641+
4×2 Dataset
642+
Row │ a b
643+
│ identity identity
644+
│ Int64? String?
645+
─────┼────────────────────
646+
1 │ 1 p
647+
2 │ 1 q
648+
3 │ 2 r
649+
4 │ 2 s
650+
651+
julia> ds3 = Dataset(a = [1, 2], b = [[1, 2], [3, 4]], c = [[5, 6], [7]])
652+
2×3 Dataset
653+
Row │ a b c
654+
│ identity identity identity
655+
│ Int64? Array…? Array…?
656+
─────┼──────────────────────────────
657+
1 │ 1 [1, 2] [5, 6]
658+
2 │ 2 [3, 4] [7]
659+
660+
julia> flatten(ds3, [:b, :c])
661+
ERROR: ArgumentError: Lengths of iterables stored in columns :b and :c are not the same in row 2
662+
```
663+
"""
664+
flatten(ds, cols)
665+
666+
667+
_ELTYPE(x) = eltype(x)
668+
_ELTYPE(::Missing) = Missing
669+
_LENGTH(x) = length(x)
670+
_LENGTH(::Missing) = 1
671+
672+
function flatten!(ds::Dataset,
673+
cols::Union{ColumnIndex, MultiColumnIndex})
674+
_check_consistency(ds)
675+
676+
idxcols = index(ds)[cols]
677+
isempty(idxcols) && return copy(ds)
678+
col1 = first(idxcols)
679+
lengths = _LENGTH.(_columns(ds)[col1])
680+
for col in idxcols
681+
v = _columns(ds)[col]
682+
if any(x -> _LENGTH(x[1]) != x[2], zip(v, lengths))
683+
r = findfirst(x -> x != 0, _LENGTH.(v) .- lengths)
684+
colnames = _names(ds)
685+
throw(ArgumentError("Lengths of iterables stored in columns :$(colnames[col1]) " *
686+
"and :$(colnames[col]) are not the same in row $r"))
687+
end
688+
end
689+
r_index = _create_index_for_repeat(lengths, nrow(ds) < typemax(Int32) ? Val(Int32) : Val(Int64))
690+
_permute_ds_after_sort!(ds, r_index, check = false, cols = Not(cols))
691+
new_total = sum(lengths)
692+
length(idxcols) > 1 && sort!(idxcols)
693+
for col in idxcols
694+
col_to_flatten = _columns(ds)[col]
695+
T = mapreduce(_ELTYPE, promote_type, col_to_flatten)
696+
_res = allocatecol(T, new_total)
697+
_fill_flatten!(_res, col_to_flatten, lengths)
698+
if length(idxcols) == ncol(ds)
699+
_columns(ds)[col] = _res
700+
else
701+
ds[!, col] = _res
702+
end
703+
end
704+
_reset_grouping_info!(ds)
705+
_modified(_attributes(ds))
706+
ds
707+
end
708+
709+
710+
function flatten(ds::AbstractDataset,
711+
cols::Union{ColumnIndex, MultiColumnIndex})
712+
_check_consistency(ds)
713+
714+
idxcols = index(ds)[cols]
715+
isempty(idxcols) && return copy(ds)
716+
col1 = first(idxcols)
717+
lengths = _LENGTH.(_columns(ds)[col1])
718+
for col in idxcols
719+
v = _columns(ds)[col]
720+
if any(x -> _LENGTH(x[1]) != x[2], zip(v, lengths))
721+
r = findfirst(x -> x != 0, _LENGTH.(v) .- lengths)
722+
colnames = _names(ds)
723+
throw(ArgumentError("Lengths of iterables stored in columns :$(colnames[col1]) " *
724+
"and :$(colnames[col]) are not the same in row $r"))
725+
end
726+
end
727+
new_total = sum(lengths)
728+
new_ds = similar(ds[!, Not(cols)], new_total)
729+
for name in _names(new_ds)
730+
repeat_lengths_v2!(new_ds[!, name].val, ds[!, name].val, lengths)
731+
end
732+
length(idxcols) > 1 && sort!(idxcols)
733+
for col in idxcols
734+
col_to_flatten = _columns(ds)[col]
735+
T = mapreduce(_ELTYPE, promote_type, col_to_flatten)
736+
_res = allocatecol(T, new_total)
737+
_fill_flatten!(_res, col_to_flatten, lengths)
738+
insertcols!(new_ds, col, _names(ds)[col] => _res)
739+
end
740+
setformat!(new_ds, copy(index(ds).format))
741+
setinfo!(new_ds, _attributes(ds).meta.info[])
742+
_reset_grouping_info!(new_ds)
743+
new_ds
744+
end
745+
746+
747+
function _fill_flatten!_barrier(_res, val, counter)
748+
for j in val
749+
_res[counter] = j
750+
counter += 1
751+
end
752+
counter
753+
end
754+
755+
function _fill_flatten!(_res, col_to_flatten, lengths)
756+
counter = 1
757+
for i in 1:length(col_to_flatten)
758+
if ismissing(col_to_flatten[i])
759+
_res[counter] = missing
760+
counter += 1
761+
else
762+
counter = _fill_flatten!_barrier(_res, col_to_flatten[i], counter)
763+
end
764+
end
765+
end

test/transpose.jl

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -482,36 +482,54 @@ end
482482
ref = Dataset(a = [1, 1, 2, 2], b = [1, 2, 3, 4])
483483
@test flatten(ds_vec, :b) == flatten(ds_tup, :b) == ref
484484
@test flatten(ds_vec, "b") == flatten(ds_tup, "b") == ref
485+
@test flatten!(copy(ds_vec), :b) == flatten(ds_tup, :b) == ref
486+
@test flatten!(copy(ds_vec), "b") == flatten(ds_tup, "b") == ref
485487
ds_mixed_types = Dataset(a = [1, 2], b = [[1, 2], ["x", "y"]])
486488
ref_mixed_types = Dataset(a = [1, 1, 2, 2], b = [1, 2, "x", "y"])
487489
@test flatten(ds_mixed_types, :b) == ref_mixed_types
490+
@test flatten!(copy(ds_mixed_types), :b) == ref_mixed_types
491+
488492
ds_three = Dataset(a = [1, 2, 3], b = [[1, 2], [10, 20], [100, 200, 300]])
489493
ref_three = Dataset(a = [1, 1, 2, 2, 3, 3, 3], b = [1, 2, 10, 20, 100, 200, 300])
490494
@test flatten(ds_three, :b) == ref_three
491495
@test flatten(ds_three, "b") == ref_three
496+
@test flatten!(copy(ds_three), :b) == ref_three
497+
@test flatten!(copy(ds_three), "b") == ref_three
492498
ds_gen = Dataset(a = [1, 2], b = [(i for i in 1:5), (i for i in 6:10)])
493499
ref_gen = Dataset(a = [fill(1, 5); fill(2, 5)], b = collect(1:10))
494500
@test flatten(ds_gen, :b) == ref_gen
495501
@test flatten(ds_gen, "b") == ref_gen
502+
@test flatten!(copy(ds_gen), :b) == ref_gen
503+
@test flatten!(copy(ds_gen), "b") == ref_gen
496504
ds_miss = Dataset(a = [1, 2], b = [Union{Missing, Int}[1, 2], Union{Missing, Int}[3, 4]])
497505
ref = Dataset(a = [1, 1, 2, 2], b = [1, 2, 3, 4])
498506
@test flatten(ds_miss, :b) == ref
499507
@test flatten(ds_miss, "b") == ref
508+
@test flatten!(copy(ds_miss), :b) == ref
509+
@test flatten!(copy(ds_miss), "b") == ref
500510
v1 = [[1, 2], [3, 4]]
501511
v2 = [[5, 6], [7, 8]]
502512
v = [v1, v2]
503513
ds_vec_vec = Dataset(a = [1, 2], b = v)
504514
ref_vec_vec = Dataset(a = [1, 1, 2, 2], b = [v1 ; v2])
505515
@test flatten(ds_vec_vec, :b) == ref_vec_vec
506516
@test flatten(ds_vec_vec, "b") == ref_vec_vec
517+
@test flatten!(copy(ds_vec_vec), :b) == ref_vec_vec
518+
@test flatten!(copy(ds_vec_vec), "b") == ref_vec_vec
507519
ds_cat = Dataset(a = [1, 2], b = [CategoricalArray([1, 2]), CategoricalArray([1, 2])])
508520
ds_flat_cat = flatten(ds_cat, :b)
509521
ref_cat = Dataset(a = [1, 1, 2, 2], b = [1, 2, 1, 2])
510522
@test ds_flat_cat == ref_cat
511523
@test ds_flat_cat.b.val isa CategoricalArray
524+
flatten!(ds_cat, :b)
525+
ref_cat = Dataset(a = [1, 1, 2, 2], b = [1, 2, 1, 2])
526+
@test ds_cat == ref_cat
527+
@test ds_cat.b.val isa CategoricalArray
512528

513529
ds = Dataset(a = [1, 2], b = [[1, 2], [3, 4]], c = [[5, 6], [7, 8]])
514530
@test flatten(ds, []) == ds
531+
@test flatten!(copy(ds), []) == ds
532+
515533
ref = Dataset(a = [1, 1, 2, 2], b = [1, 2, 3, 4], c = [5, 6, 7, 8])
516534
@test flatten(ds, [:b, :c]) == ref
517535
@test flatten(ds, [:c, :b]) == ref
@@ -521,16 +539,46 @@ end
521539
@test flatten(ds, r"[bc]") == ref
522540
@test flatten(ds, Not(:a)) == ref
523541
@test flatten(ds, Between(:b, :c)) == ref
542+
543+
@test flatten!(copy(ds), [:b, :c]) == ref
544+
@test flatten!(copy(ds), [:c, :b]) == ref
545+
@test flatten!(copy(ds), ["b", "c"]) == ref
546+
@test flatten!(copy(ds), ["c", "b"]) == ref
547+
@test flatten!(copy(ds), 2:3) == ref
548+
@test flatten!(copy(ds), r"[bc]") == ref
549+
@test flatten!(copy(ds), Not(:a)) == ref
550+
@test flatten!(copy(ds), Between(:b, :c)) == ref
551+
524552
ds_allcols = Dataset(b = [[1, 2], [3, 4]], c = [[5, 6], [7, 8]])
525553
ref_allcols = Dataset(b = [1, 2, 3, 4], c = [5, 6, 7, 8])
526554
@test flatten(ds_allcols, :) == ref_allcols
555+
@test flatten!(copy(ds_allcols), :) == ref_allcols
527556
ds_bad = Dataset(a = [1, 2], b = [[1, 2], [3, 4]], c = [[5, 6], [7]])
528557
@test_throws ArgumentError flatten(ds_bad, [:b, :c])
558+
@test_throws ArgumentError flatten!(copy(ds_bad), [:b, :c])
529559
ds_vec = Dataset(a = [1, missing], b = [[1, missing], [3, 4]])
530560
ds_tup = Dataset(a = [1, missing], b = [(1, missing), (3, 4)])
531561
ref = Dataset(a = [1, 1, missing, missing], b = [1, missing, 3, 4])
532562
@test flatten(ds_vec, :b) == flatten(ds_tup, :b) == ref
533563
@test flatten(ds_vec, "b") == flatten(ds_tup, "b") == ref
564+
@test flatten!(copy(ds_vec), :b) == flatten(ds_tup, :b) == ref
565+
@test flatten!(copy(ds_vec), "b") == flatten(ds_tup, "b") == ref
566+
567+
ds_cat = Dataset(a = [1, 2], b = [PooledArray([1, 2]), PooledArray([1, 2])])
568+
repeat!(ds_cat, 1000)
569+
ds_flat_cat = flatten(ds_cat, :b)
570+
ref_cat = Dataset(a = repeat([1, 1, 2, 2],1000), b = repeat([1, 2, 1, 2],1000))
571+
@test ds_flat_cat == ref_cat
572+
flatten!(ds_cat, :b)
573+
@test ds_cat == ref_cat
574+
575+
ds_cat = Dataset(a = [1, 2], b = [CategoricalArray([1, 2]), CategoricalArray([1, 2])])
576+
repeat!(ds_cat, 1000)
577+
ds_flat_cat = flatten(ds_cat, :b)
578+
ref_cat = Dataset(a = repeat([1, 1, 2, 2],1000), b = repeat([1, 2, 1, 2],1000))
579+
@test ds_flat_cat == ref_cat
580+
flatten!(ds_cat, :b)
581+
@test ds_cat == ref_cat
534582
end
535583

536584
@testset "transpose - views" begin

test/utils.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,4 +105,12 @@ end
105105
@test repeat(ds, freq = :x2) == ds[repeat([1,2,2,2,3,4,4], 1000), :]
106106
@test repeat(ds, freq = :x4) == ds[repeat([1,1,2,3,4],1000), :]
107107
@test repeat(ds, freq = ds[!, :x4]) == ds[repeat([1,1,2,3,4],1000), :]
108+
109+
ds = Dataset(a = 0:2, b = 2:4)
110+
@test repeat(ds, freq = :a) == Dataset(a = [1,2,2], b = [3,4,4])
111+
@test repeat(ds, freq = [2,0,2]) == Dataset(a = [0,0,2,2], b=[2,2,4,4])
112+
113+
@test repeat(ds, freq = [1000, 1000, 0]) == Dataset(a = [fill(0,1000);fill(1, 1000)], b=[fill(2,1000);fill(3, 1000)])
114+
@test repeat(view(ds,[1,2,3], [1,2]), freq = [1000, 1000, 0]) == Dataset(a = [fill(0,1000);fill(1, 1000)], b=[fill(2,1000);fill(3, 1000)])
115+
@test repeat(view(ds,[1,2,3], [1,2]), freq = [1000, 1000, 0], view = true) == Dataset(a = [fill(0,1000);fill(1, 1000)], b=[fill(2,1000);fill(3, 1000)])
108116
end

0 commit comments

Comments
 (0)