@@ -593,7 +593,7 @@ Base.transpose(ds::Union{GroupBy, GatherBy}, cols::Tuple; id = nothing, renameco
593593
594594
595595"""
596- flatten(ds::AbstractDataset, cols; mapformats = false)
596+ flatten(ds::AbstractDataset, cols; mapformats = false, threads = true )
597597
598598When columns `cols` of data set `ds` have iterable elements that define
599599`length` (for example a `Vector` of `Vector`s), return a `Dataset` where each
@@ -609,6 +609,8 @@ When `mapformats = true`, the function uses the formatted values of `cols`.
609609
610610`cols` can be any column selector ($COLUMNINDEX_STR ; $MULTICOLUMNINDEX_STR ).
611611
612+ To turn off multithreaded computations pass `threads = false`.
613+
612614See [`flatten!`](@ref)
613615
614616# Examples
@@ -743,96 +745,87 @@ julia> flatten(ds, 2:3, mapformats = true)
743745flatten (ds, cols)
744746
745747"""
746- flatten!(ds, cols; mapformats = false)
748+ flatten!(ds, cols; mapformats = false, threads = true )
747749
748750Variant of `flatten` that flattens `ds` in place.
749751"""
750752flatten!
751753
# Element type of a single (non-missing) cell of a column being flattened.
# Thin wrapper over `eltype` so that a dedicated method can special-case `missing`.
_ELTYPE(x) = eltype(x)
# A `missing` cell contributes the type `Missing` itself when the element types
# of a column's cells are combined.
_ELTYPE(::Missing) = Missing
768760
769761
# Number of elements held by a single (non-missing) cell; delegates to `length`.
_LENGTH(x) = length(x)
778765
# A `missing` cell counts as exactly one element.
_LENGTH(::Missing) = 1
789769
790770
791771function flatten! (ds:: Dataset ,
792- cols:: Union{ColumnIndex, MultiColumnIndex} ; mapformats = false )
772+ cols:: Union{ColumnIndex, MultiColumnIndex} ; mapformats = false , threads = true )
793773 _check_consistency (ds)
794774
795775 idxcols = index (ds)[cols]
796776 isempty (idxcols) && return ds
797777 col1 = first (idxcols)
778+ all_idxcols = Any[]
798779 if mapformats
799780 f_fmt = getformat (ds, col1)
800- lengths = _LENGTH .( _columns (ds)[ col1], fmt = f_fmt )
781+ push! (all_idxcols, byrow (ds, f_fmt, col1, threads = threads) )
801782 else
802- lengths = _LENGTH .(_columns (ds)[col1])
803- end
804- for col in idxcols
805- v = _columns (ds)[col]
806- if mapformats
807- f_fmt = getformat (ds, col)
808- else
809- f_fmt = identity
810- end
811- if any (x -> _LENGTH (x[1 ], fmt = f_fmt) != x[2 ], zip (v, lengths))
812- r = findfirst (x -> x != 0 , _LENGTH .(v, fmt = f_fmt) .- lengths)
813- colnames = _names (ds)
814- throw (ArgumentError (" Lengths of iterables stored in columns :$(colnames[col1]) " *
815- " and :$(colnames[col]) are not the same in row $r " ))
783+ push! (all_idxcols, _columns (ds)[col1])
784+ end
785+ lengths = byrow (Dataset (all_idxcols, [:x ], copycols = false ), _LENGTH, 1 , threads = threads, forcemissing = false )
if length(idxcols) > 1
    # Gather the remaining columns (formatted when `mapformats = true`) and check
    # that, row by row, their cells have the same length as the first column's.
    for col in 2:length(idxcols)
        if mapformats
            f_fmt = getformat(ds, idxcols[col])
            # `threads` is a keyword of `byrow`, not of `push!`; handing it to
            # `push!` raises a MethodError as soon as this branch is reached.
            push!(all_idxcols, byrow(ds, f_fmt, idxcols[col], threads = threads))
        else
            push!(all_idxcols, _columns(ds)[idxcols[col]])
        end
        v = all_idxcols[col]
        if any(x -> _LENGTH(x[1]) != x[2], zip(v, lengths))
            # Locate the first offending row only once a mismatch is known to exist.
            r = findfirst(x -> x != 0, _LENGTH.(v) .- lengths)
            colnames = _names(ds)
            throw(ArgumentError("Lengths of iterables stored in columns :$(colnames[col1]) " *
                                "and :$(colnames[idxcols[col]]) are not the same in row $r"))
        end
    end
end
818803 r_index = _create_index_for_repeat (lengths, nrow (ds) < typemax (Int32) ? Val (Int32) : Val (Int64))
819- _permute_ds_after_sort! (ds, r_index, check = false , cols = Not (cols))
820- new_total = sum (lengths)
821- length (idxcols) > 1 && sort! (idxcols)
822- for col in idxcols
823- col_to_flatten = _columns (ds)[col]
824- if mapformats
825- f_fmt = getformat (ds, col)
826- else
827- f_fmt = identity
828- end
829- T = mapreduce (x-> _ELTYPE (x, fmt = f_fmt), promote_type, col_to_flatten)
804+ _permute_ds_after_sort! (ds, r_index, check = false , cols = Not (cols), threads = threads)
805+ if threads
806+ new_total = hp_sum (lengths)
807+ else
808+ new_total = sum (lengths)
809+ end
810+ if length (idxcols) > 1
811+ sort_permute_idxcols = sortperm (idxcols)
812+ idxcols_sorted = idxcols[sort_permute_idxcols]
813+ else
814+ sort_permute_idxcols = [1 ]
815+ idxcols_sorted = idxcols
816+ end
817+ cumsum! (lengths, lengths)
818+ for col in 1 : length (idxcols_sorted)
819+ col_to_flatten = all_idxcols[sort_permute_idxcols[col]]
820+
821+ T = mapreduce (_ELTYPE, promote_type, col_to_flatten)
830822 _res = allocatecol (T, new_total)
831- _fill_flatten! (_res, col_to_flatten, lengths; fmt = f_fmt )
823+ _fill_flatten! (_res, col_to_flatten, lengths, threads = threads )
832824 if length (idxcols) == ncol (ds)
833- _columns (ds)[col] = _res
825+ _columns (ds)[idxcols_sorted[ col] ] = _res
834826 else
835- ds[! , col] = _res
827+ deleteat! (_columns (ds), idxcols_sorted[col])
828+ insert! (_columns (ds), idxcols_sorted[col], _res)
836829 end
837830 end
838831 _reset_grouping_info! (ds)
@@ -842,49 +835,62 @@ end
842835
843836
844837function flatten (ds:: AbstractDataset ,
845- cols:: Union{ColumnIndex, MultiColumnIndex} ; mapformats = false )
838+ cols:: Union{ColumnIndex, MultiColumnIndex} ; mapformats = false , threads = true )
846839 _check_consistency (ds)
847840
848841 idxcols = index (ds)[cols]
849842 isempty (idxcols) && return copy (ds)
850843 col1 = first (idxcols)
844+ all_idxcols = Any[]
851845 if mapformats
852846 f_fmt = getformat (ds, col1)
853- lengths = _LENGTH .( _columns (ds)[ col1], fmt = f_fmt )
847+ push! (all_idxcols, byrow (ds, f_fmt, col1, threads = threads) )
854848 else
855- lengths = _LENGTH .(_columns (ds)[col1])
856- end
857- for col in idxcols
858- v = _columns (ds)[col]
859- if mapformats
860- f_fmt = getformat (ds, col)
861- else
862- f_fmt = identity
863- end
864- if any (x -> _LENGTH (x[1 ], fmt = f_fmt) != x[2 ], zip (v, lengths))
865- r = findfirst (x -> x != 0 , _LENGTH .(v, fmt = f_fmt) .- lengths)
866- colnames = _names (ds)
867- throw (ArgumentError (" Lengths of iterables stored in columns :$(colnames[col1]) " *
868- " and :$(colnames[col]) are not the same in row $r " ))
849+ push! (all_idxcols, _columns (ds)[col1])
850+ end
851+ lengths = byrow (Dataset (all_idxcols, [:x ], copycols = false ), _LENGTH, 1 , threads = threads, forcemissing = false )
if length(idxcols) > 1
    # Gather the remaining columns (formatted when `mapformats = true`) and check
    # that, row by row, their cells have the same length as the first column's.
    for col in 2:length(idxcols)
        if mapformats
            f_fmt = getformat(ds, idxcols[col])
            # `threads` is a keyword of `byrow`, not of `push!`; handing it to
            # `push!` raises a MethodError as soon as this branch is reached.
            push!(all_idxcols, byrow(ds, f_fmt, idxcols[col], threads = threads))
        else
            push!(all_idxcols, _columns(ds)[idxcols[col]])
        end
        v = all_idxcols[col]
        if any(x -> _LENGTH(x[1]) != x[2], zip(v, lengths))
            # Locate the first offending row only once a mismatch is known to exist.
            r = findfirst(x -> x != 0, _LENGTH.(v) .- lengths)
            colnames = _names(ds)
            throw(ArgumentError("Lengths of iterables stored in columns :$(colnames[col1]) " *
                                "and :$(colnames[idxcols[col]]) are not the same in row $r"))
        end
    end
end
871- new_total = sum (lengths)
869+ if threads
870+ new_total = hp_sum (lengths)
871+ else
872+ new_total = sum (lengths)
873+ end
872874 new_ds = similar (ds[! , Not (cols)], new_total)
873875 for name in _names (new_ds)
874- repeat_lengths_v2! (new_ds[! , name]. val, ds[! , name]. val, lengths)
876+ col_name = index (ds)[name]
877+ repeat_lengths_v2! (new_ds[! , name]. val, _columns (ds)[col_name], lengths)
875878 end
876- length (idxcols) > 1 && sort! (idxcols)
877- for col in idxcols
878- col_to_flatten = _columns (ds)[col]
879- if mapformats
880- f_fmt = getformat (ds, col)
881- else
882- f_fmt = identity
883- end
884- T = mapreduce (x-> _ELTYPE (x, fmt = f_fmt), promote_type, col_to_flatten)
879+ if length (idxcols) > 1
880+ sort_permute_idxcols = sortperm (idxcols)
881+ idxcols_sorted = idxcols[sort_permute_idxcols]
882+ else
883+ sort_permute_idxcols = [1 ]
884+ idxcols_sorted = idxcols
885+ end
886+ cumsum! (lengths, lengths)
887+ for col in 1 : length (idxcols_sorted)
888+ col_to_flatten = all_idxcols[sort_permute_idxcols[col]]
889+
890+ T = mapreduce (_ELTYPE, promote_type, col_to_flatten)
885891 _res = allocatecol (T, new_total)
886- _fill_flatten! (_res, col_to_flatten, lengths; fmt = f_fmt )
887- insertcols! (new_ds, col, _names (ds)[col] => _res, unsupported_copy_cols = false )
892+ _fill_flatten! (_res, col_to_flatten, lengths, threads = threads )
893+ insertcols! (new_ds, idxcols_sorted[ col] , _names (ds)[idxcols_sorted[ col] ] => _res, unsupported_copy_cols = false )
888894 end
889895 for j in setdiff (1 : ncol (ds), idxcols)
890896 setformat! (new_ds, j=> getformat (ds, j))
@@ -895,22 +901,22 @@ function flatten(ds::AbstractDataset,
895901end
896902
897903
# Write the contents of one cell `val` into `_res`, starting at position `lo`.
# A `missing` cell occupies the single slot `lo`; any other cell is iterated and
# its items are stored in consecutive slots `lo`, `lo + 1`, ...
function _fill_flatten!_barrier(_res, val, lo)
    if !ismissing(val)
        for (offset, item) in enumerate(val)
            # `enumerate` counts from 1, so the first item lands exactly on `lo`.
            _res[lo + offset - 1] = item
        end
    else
        _res[lo] = missing
    end
end
905916
# Fill `_res` with the flattened contents of `col_to_flatten`.
#
# `lengths[i]` is read as the last position in `_res` belonging to row `i`
# (the callers in this file apply `cumsum!` to `lengths` before calling), so
# row `i` starts right after `lengths[i - 1]`, and row 1 starts at position 1.
# `threads` is forwarded to `@_threadsfor` to switch multithreading on or off.
function _fill_flatten!(_res, col_to_flatten, lengths; threads = false)
    @_threadsfor threads for i in 1:length(col_to_flatten)
        lo = i == 1 ? 1 : lengths[i - 1] + 1
        _fill_flatten!_barrier(_res, col_to_flatten[i], lo)
    end
end
0 commit comments