Skip to content

Commit 038587b

Browse files
committed
supporting byrow with tuple of column indices
1 parent 8aed909 commit 038587b

File tree

9 files changed

+212
-12
lines changed

9 files changed

+212
-12
lines changed

CHANGELOG.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,12 @@
22

33
## New features
44

5+
* New functionality has been added to `byrow` for passing a Tuple of column indices. `byrow(ds, fun, cols)` calls `fun.(ds[:, cols[1]], ds[:, cols[2]], ...)` when `cols` is an `NTuple` of column indices.
6+
7+
# Version 0.7.6
8+
9+
## New features
10+
511
* Two new functions: `delete` and `delete!`. They should be compared to `filter` and `filter!`, respectively - [issue #63](https://github.com/sl-solution/InMemoryDatasets.jl/issues/63)
612
* Add `DLMReader` to `sysimage` in `IMD.create_sysimage`.
713

docs/src/man/byrow.md

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -237,7 +237,9 @@ One special function that can be used as `fun` in the `byrow` function is `mapre
237237
238238
## User defined operations
239239

240-
For user defined functions which return a single value, `byrow` treats each row as a vector of values, thus the user defined function must accept a vector and returns a single value. For instance to calculate `1 * col1 + 2 * col2 + 3 * col3` for each row in `ds` we can define the following function:
240+
For user defined functions which return a single value, `byrow` treats each row as a vector of values, thus the user defined function must accept a vector and return a single value.
241+
However, when the user defines a multivariate function and passes a Tuple of column indices as the `cols` argument of `byrow`, the `byrow` function simply calls `fun.(ds[:, cols[1]], ds[:, cols[2]], ...)`.
242+
For instance to calculate `1 * col1 + 2 * col2 + 3 * col3` for each row in `ds` we can define the following function:
241243

242244
```jldoctest
243245
julia> avg(x) = 1 * x[1] + 2 * x[2] + 3 * x[3]
@@ -258,6 +260,31 @@ julia> byrow(ds, avg, 1:3)
258260

259261
Note that `avg` is missing if any of the values in `x` is missing.
260262

263+
Below is an example of using `byrow` with a user defined multivariate function
264+
265+
```jldoctest
266+
julia> ds = Dataset(x1 = [1,2,1,2], x2 = [1,-2,-3,10], x3 = 1:4)
267+
4×3 Dataset
268+
Row │ x1 x2 x3
269+
│ identity identity identity
270+
│ Int64? Int64? Int64?
271+
─────┼──────────────────────────────
272+
1 │ 1 1 1
273+
2 │ 2 -2 2
274+
3 │ 1 -3 3
275+
4 │ 2 10 4
276+
277+
julia> fun(x,y,z)::Float64 = x == 1 ? y*z : y/z
278+
fun (generic function with 1 method)
279+
280+
julia> byrow(ds, fun, (:x1, :x2, :x3))
281+
4-element Vector{Float64}:
282+
1.0
283+
-1.0
284+
-9.0
285+
2.5
286+
```
287+
261288
## Special operations
262289

263290
`byrow` also supports a few optimised operations which return a vector of values for each row. The `fun` argument for these operations is one of the following:

src/byrow/byrow.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -200,6 +200,13 @@ function byrow(ds::AbstractDataset, f::Function, cols::MultiColumnIndex; threads
200200
length(colsidx) == 1 && return byrow(ds, f, colsidx[1]; threads = threads)
201201
threads ? hp_row_generic(ds, f, cols) : row_generic(ds, f, cols)
202202
end
203+
204+
# TODO do we need to make sure that the result is Union of Missing?
205+
# Apply the multivariate function `f` element-wise across the columns named by the
# tuple `cols`, i.e. `f.(column(cols[1]), column(cols[2]), ...)`.
#
# Each entry of `cols` is resolved to a position through `index(ds)` and the matching
# entry of `_columns(ds)` is passed to the broadcast. The columns are gathered into an
# `N`-tuple (instead of a heap-allocated `Vector` wrapped in a `view`) so that the number
# of broadcast arguments is known from the method's static parameter `N`.
function byrow(ds::AbstractDataset, f::Function, cols::NTuple{N, ColumnIndex}) where N
    selected = ntuple(i -> _columns(ds)[index(ds)[cols[i]]], Val(N))
    return f.(selected...)
end
209+
203210
function byrow(ds::AbstractDataset, f::Function, col::ColumnIndex; threads = nrow(ds)>1000, allowmissing::Bool = true)
204211
if threads
205212
T = Core.Compiler.return_type(f, Tuple{nonmissingtype(eltype(ds[!, col]))})

src/byrow/doc.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1238,8 +1238,9 @@ Return the result of calling `fun` on each row of `ds` selected by `cols`. The `
12381238
12391239
When the user passes a type as `fun` and a single column as `cols`, `byrow` converts the corresponding column to the type specified by `fun`.
12401240
1241-
For generic functions there are two special cases:
1241+
For generic functions there are the following special cases:
12421242
12431243
* When `cols` is a single column, `byrow(ds, fun, cols)` acts like `fun.(ds[:, cols])`
12441244
* When `cols` is referring to exactly two columns and it is possible to pass two vectors as arguments of `fun`, `byrow` returns `fun.(ds[:, col1], ds[:, col2])` when possible.
1245+
* When `cols` is a `Tuple` of column indices, `byrow(ds, fun, cols)` returns `fun.(ds[:, cols[1]], ds[:, cols[2]], ...)`, i.e. `fun` is a multivariate function which will be applied on each row of `cols`.
12451246
"""

src/dataset/combine.jl

Lines changed: 76 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,73 @@ function normalize_combine!(outidx::Index, idx,
2727
return ntuple(i -> _names(idx)[idx[src[i]]], N) => fun => Symbol(dst)
2828
end
2929

30+
# this is added to support byrow for multivariate functions
31+
# (col1, col2) => byrow(fun) => dst, the job is to create (col1, col2) => byrow(fun) => :dst
32+
# Normalize `(col1, col2, ...) => [byrow(fun)] => dst`.
# Only the first element of the expression vector is inspected; its head must be
# `:BYROW`, otherwise an `ArgumentError` is thrown. The returned value has the shape
# `(resolved columns...) => byrow expression => Symbol(dst)`, with every source column
# resolved through `outidx`.
function normalize_combine!(outidx::Index, idx,
    @nospecialize(sel::Pair{<:NTuple{N, ColumnIndex},
                            <:Pair{<:Vector{Expr},
                                <:Union{Symbol, AbstractString}}})
                                ) where N
    srccols = sel.first
    target = sel.second
    byrow_expr = target.first[1]
    byrow_expr.head == :BYROW ||
        throw(ArgumentError("only byrow is accepted when using expressions"))
    newname = Symbol(target.second)
    # NOTE(review): presumably validates/registers the destination name in `outidx` - confirm
    _check_ind_and_add!(outidx, newname)
    resolved = ntuple(k -> outidx[srccols[k]], length(srccols))
    return resolved => byrow_expr => newname
end
44+
# Normalize `(col1, col2, ...) => byrow(fun) => dst`, where the transformation is a
# single `Expr` (not a vector of expressions).
# The expression head must be `:BYROW`, otherwise an `ArgumentError` is thrown. The
# returned value has the shape `(resolved columns...) => byrow expression => Symbol(dst)`,
# with every source column resolved through `outidx`.
function normalize_combine!(outidx::Index, idx,
    @nospecialize(sel::Pair{<:NTuple{N, ColumnIndex},
                            <:Pair{<:Expr,
                                <:Union{Symbol, AbstractString}}})
                                ) where N
    srccols = sel.first
    target = sel.second
    # `target.first` is already the `Expr`; it must be forwarded as-is. Indexing it with
    # `[1]` (as one would for a `Vector{Expr}`) raises a `MethodError`, because `Expr`
    # does not support integer indexing.
    byrow_expr = target.first
    byrow_expr.head == :BYROW ||
        throw(ArgumentError("only byrow is accepted when using expressions"))
    newname = Symbol(target.second)
    # NOTE(review): presumably validates/registers the destination name in `outidx` - confirm
    _check_ind_and_add!(outidx, newname)
    resolved = ntuple(k -> outidx[srccols[k]], length(srccols))
    return resolved => byrow_expr => newname
end
56+
# Normalize `(col1, col2, ...) => [byrow(fun)]` (no destination given).
# At least two source columns are required. The destination name is generated as
# `<funname>_<first column>_<second column>`, with `_etc` appended when more than two
# columns are supplied. Only the first element of the expression vector is inspected and
# its head must be `:BYROW`, otherwise an `ArgumentError` is thrown.
function normalize_combine!(outidx::Index, idx,
    @nospecialize(sel::Pair{<:NTuple{N, ColumnIndex},
                            <:Vector{Expr}})
                                ) where N
    srccols = sel.first
    N >= 2 || throw(ArgumentError("For multivariate functions (Tuple of column names), the number of input columns must be greater than 1"))
    firstcol = outidx[srccols[1]]
    secondcol = outidx[srccols[2]]
    firstname = _names(outidx)[firstcol]
    secondname = _names(outidx)[secondcol]
    byrow_expr = sel.second[1]
    byrow_expr.head == :BYROW ||
        throw(ArgumentError("only byrow is accepted when using expressions"))
    # args[1] of the byrow expression is handed to `funname` to derive the name prefix
    stem = string(funname(byrow_expr.args[1]), "_", firstname, "_", secondname)
    nname = Symbol(N > 2 ? stem * "_etc" : stem)
    # NOTE(review): presumably validates/registers the generated name in `outidx` - confirm
    _check_ind_and_add!(outidx, nname)
    return ntuple(k -> outidx[srccols[k]], length(srccols)) => byrow_expr => nname
end
75+
# Normalize `(col1, col2, ...) => byrow(fun)` (single `Expr`, no destination given).
# At least two source columns are required. The destination name is generated as
# `<funname>_<first column>_<second column>`, with `_etc` appended when more than two
# columns are supplied. The expression head must be `:BYROW`, otherwise an
# `ArgumentError` is thrown.
function normalize_combine!(outidx::Index, idx,
    @nospecialize(sel::Pair{<:NTuple{N, ColumnIndex},
                            <:Expr})
                                ) where N
    srccols = sel.first
    N >= 2 || throw(ArgumentError("For multivariate functions (Tuple of column names), the number of input columns must be greater than 1"))
    firstcol = outidx[srccols[1]]
    secondcol = outidx[srccols[2]]
    firstname = _names(outidx)[firstcol]
    secondname = _names(outidx)[secondcol]
    byrow_expr = sel.second
    byrow_expr.head == :BYROW ||
        throw(ArgumentError("only byrow is accepted when using expressions"))
    # args[1] of the byrow expression is handed to `funname` to derive the name prefix
    stem = string(funname(byrow_expr.args[1]), "_", firstname, "_", secondname)
    nname = Symbol(N > 2 ? stem * "_etc" : stem)
    # NOTE(review): presumably validates/registers the generated name in `outidx` - confirm
    _check_ind_and_add!(outidx, nname)
    return ntuple(k -> outidx[srccols[k]], length(srccols)) => byrow_expr => nname
end
94+
95+
96+
3097
# col => fun, the job is to create col => fun => :colname
3198
function normalize_combine!(outidx::Index, idx,
3299
@nospecialize(sel::Pair{<:ColumnIndex,
@@ -244,8 +311,12 @@ function _is_byrow_valid(idx, ms)
244311
end
245312
for i in 1:length(ms)
246313
if (ms[i].second.first isa Expr) && ms[i].second.first.head == :BYROW
247-
248-
byrow_vars = idx[ms[i].first]
314+
# if the input vars are supposed to be used in a multivariate function
315+
if ms[i].first isa Tuple
316+
byrow_vars = [idx[ms[i].first[j]] for j in 1:length(ms[i].first)]
317+
else
318+
byrow_vars = idx[ms[i].first]
319+
end
249320
!all(byrow_vars .∈ Ref(righthands)) && return false
250321
end
251322
if haskey(idx, ms[i].second.second)
@@ -258,7 +329,7 @@ end
258329
function _check_mutliple_rows_for_each_group(ds, ms)
259330
for i in 1:length(ms)
260331
# byrow are not checked since they are not going to modify the number of rows
261-
if ms[i].first isa Tuple
332+
if ms[i].first isa Tuple && !(ms[i].second.first isa Expr)
262333
T = return_type(ms[i].second.first, ntuple(j-> ds[!, ms[i].first[j]].val, length(ms[i].first)))
263334
if T <: AbstractVector && T !== Union{}
264335
return i
@@ -670,7 +741,7 @@ function combine(ds::Dataset, @nospecialize(args...); dropgroupcols = false, thr
670741
_combine_f_barrier_special(special_res, ds[!, ms[i].first].val, newds, ms[i].first, ms[i].second.first, ms[i].second.second, newds_lookup, _first_vector_res,ngroups, new_lengths, total_lengths, threads)
671742
end
672743
else
673-
if ms[i].first isa Tuple
744+
if ms[i].first isa Tuple && !(ms[i].second.first isa Expr)
674745
_combine_f_barrier_tuple(ntuple(j->_columns(ds)[index(ds)[ms[i].first[j]]], length(ms[i].first)), newds, ms[i].first, ms[i].second.first, ms[i].second.second, newds_lookup, starts, ngroups, new_lengths, total_lengths, threads)
675746
else
676747
_combine_f_barrier(haskey(index(ds).lookup, ms[i].first) ? _columns(ds)[index(ds)[ms[i].first]] : _columns(ds)[1], newds, ms[i].first, ms[i].second.first, ms[i].second.second, newds_lookup, starts, ngroups, new_lengths, total_lengths, threads)
@@ -753,7 +824,7 @@ function combine_ds(ds::AbstractDataset, @nospecialize(args...); threads = true)
753824
_combine_f_barrier_special(special_res, ds[!, ms[i].first].val, newds, ms[i].first, ms[i].second.first, ms[i].second.second, newds_lookup, _first_vector_res,ngroups, new_lengths, total_lengths, threads)
754825
end
755826
else
756-
if ms[i].first isa Tuple
827+
if ms[i].first isa Tuple && !(ms[i].second.first isa Expr)
757828
_combine_f_barrier_tuple(ntuple(j->_columns(ds)[index(ds)[ms[i].first[j]]], length(ms[i].first)), newds, ms[i].first, ms[i].second.first, ms[i].second.second, newds_lookup, starts, ngroups, new_lengths, total_lengths, threads)
758829
else
759830
_combine_f_barrier(haskey(index(ds), ms[i].first) ? _columns(ds)[index(ds)[ms[i].first]] : _columns(ds)[1], newds, ms[i].first, ms[i].second.first, ms[i].second.second, newds_lookup, starts, ngroups, new_lengths, total_lengths, threads)

src/dataset/modify.jl

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,71 @@ function normalize_modify!(outidx::Index, idx,
5858
return ntuple(i->outidx[src[i]], N) => fun => Symbol(dst)
5959
end
6060

61+
# this is added to support byrow for multivariate functions
62+
# (col1, col2) => byrow(fun) => dst, the job is to create (col1, col2) => byrow(fun) => :dst
63+
# Normalize `(col1, col2, ...) => [byrow(fun)] => dst`.
# Only the first element of the expression vector is inspected; its head must be
# `:BYROW`, otherwise an `ArgumentError` is thrown. The returned value has the shape
# `(resolved columns...) => byrow expression => Symbol(dst)`, with every source column
# resolved through `outidx`.
function normalize_modify!(outidx::Index, idx,
    @nospecialize(sel::Pair{<:NTuple{N, ColumnIndex},
                            <:Pair{<:Vector{Expr},
                                <:Union{Symbol, AbstractString}}})
                                ) where N
    srccols = sel.first
    target = sel.second
    byrow_expr = target.first[1]
    byrow_expr.head == :BYROW ||
        throw(ArgumentError("only byrow is accepted when using expressions"))
    newname = Symbol(target.second)
    # NOTE(review): presumably validates/registers the destination name in `outidx` - confirm
    _check_ind_and_add!(outidx, newname)
    resolved = ntuple(k -> outidx[srccols[k]], length(srccols))
    return resolved => byrow_expr => newname
end
75+
# Normalize `(col1, col2, ...) => byrow(fun) => dst`, where the transformation is a
# single `Expr` (not a vector of expressions).
# The expression head must be `:BYROW`, otherwise an `ArgumentError` is thrown. The
# returned value has the shape `(resolved columns...) => byrow expression => Symbol(dst)`,
# with every source column resolved through `outidx`.
function normalize_modify!(outidx::Index, idx,
    @nospecialize(sel::Pair{<:NTuple{N, ColumnIndex},
                            <:Pair{<:Expr,
                                <:Union{Symbol, AbstractString}}})
                                ) where N
    srccols = sel.first
    target = sel.second
    # `target.first` is already the `Expr`; it must be forwarded as-is. Indexing it with
    # `[1]` (as one would for a `Vector{Expr}`) raises a `MethodError`, because `Expr`
    # does not support integer indexing.
    byrow_expr = target.first
    byrow_expr.head == :BYROW ||
        throw(ArgumentError("only byrow is accepted when using expressions"))
    newname = Symbol(target.second)
    # NOTE(review): presumably validates/registers the destination name in `outidx` - confirm
    _check_ind_and_add!(outidx, newname)
    resolved = ntuple(k -> outidx[srccols[k]], length(srccols))
    return resolved => byrow_expr => newname
end
87+
# Normalize `(col1, col2, ...) => [byrow(fun)]` (no destination given).
# At least two source columns are required. The destination name is generated as
# `<funname>_<first column>_<second column>`, with `_etc` appended when more than two
# columns are supplied. Only the first element of the expression vector is inspected and
# its head must be `:BYROW`, otherwise an `ArgumentError` is thrown.
function normalize_modify!(outidx::Index, idx,
    @nospecialize(sel::Pair{<:NTuple{N, ColumnIndex},
                            <:Vector{Expr}})
                                ) where N
    srccols = sel.first
    N >= 2 || throw(ArgumentError("For multivariate functions (Tuple of column names), the number of input columns must be greater than 1"))
    firstcol = outidx[srccols[1]]
    secondcol = outidx[srccols[2]]
    firstname = _names(outidx)[firstcol]
    secondname = _names(outidx)[secondcol]
    byrow_expr = sel.second[1]
    byrow_expr.head == :BYROW ||
        throw(ArgumentError("only byrow is accepted when using expressions"))
    # args[1] of the byrow expression is handed to `funname` to derive the name prefix
    stem = string(funname(byrow_expr.args[1]), "_", firstname, "_", secondname)
    nname = Symbol(N > 2 ? stem * "_etc" : stem)
    # NOTE(review): presumably validates/registers the generated name in `outidx` - confirm
    _check_ind_and_add!(outidx, nname)
    return ntuple(k -> outidx[srccols[k]], length(srccols)) => byrow_expr => nname
end
106+
# Normalize `(col1, col2, ...) => byrow(fun)` (single `Expr`, no destination given).
# At least two source columns are required. The destination name is generated as
# `<funname>_<first column>_<second column>`, with `_etc` appended when more than two
# columns are supplied. The expression head must be `:BYROW`, otherwise an
# `ArgumentError` is thrown.
function normalize_modify!(outidx::Index, idx,
    @nospecialize(sel::Pair{<:NTuple{N, ColumnIndex},
                            <:Expr})
                                ) where N
    srccols = sel.first
    N >= 2 || throw(ArgumentError("For multivariate functions (Tuple of column names), the number of input columns must be greater than 1"))
    firstcol = outidx[srccols[1]]
    secondcol = outidx[srccols[2]]
    firstname = _names(outidx)[firstcol]
    secondname = _names(outidx)[secondcol]
    byrow_expr = sel.second
    byrow_expr.head == :BYROW ||
        throw(ArgumentError("only byrow is accepted when using expressions"))
    # args[1] of the byrow expression is handed to `funname` to derive the name prefix
    stem = string(funname(byrow_expr.args[1]), "_", firstname, "_", secondname)
    nname = Symbol(N > 2 ? stem * "_etc" : stem)
    # NOTE(review): presumably validates/registers the generated name in `outidx` - confirm
    _check_ind_and_add!(outidx, nname)
    return ntuple(k -> outidx[srccols[k]], length(srccols)) => byrow_expr => nname
end
125+
61126
# col => fun, the job is to create col => fun => :colname
62127
function normalize_modify!(outidx::Index, idx,
63128
@nospecialize(sel::Pair{<:ColumnIndex,

src/sort/groupby.jl

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -177,7 +177,6 @@ function combine(gds::Union{GroupBy, GatherBy}, @nospecialize(args...); dropgrou
177177
# if this is not the case, throw ArgumentError and ask user to use modify instead
178178
newlookup, new_nm = _create_index_for_newds(gds.parent, ms, gds.groupcols)
179179
!(_is_byrow_valid(Index(newlookup, new_nm, Dict{Int, Function}()), ms)) && throw(ArgumentError("`byrow` must be used for aggregated columns, use `modify` otherwise"))
180-
181180
if _fast_gatherby_reduction(gds, ms)
182181
return _combine_fast_gatherby_reduction(gds, ms, newlookup, new_nm; dropgroupcols = dropgroupcols, threads = threads)
183182
end
@@ -263,13 +262,13 @@ function combine(gds::Union{GroupBy, GatherBy}, @nospecialize(args...); dropgrou
263262
end
264263

265264
if i == _first_vector_res
266-
if ms[i].first isa Tuple
265+
if ms[i].first isa Tuple && !(ms[i].second.first isa Expr)
267266
_combine_f_barrier_special_tuple(special_res, ntuple(j-> view(_columns(gds.parent)[index(gds.parent)[ms[i].first[j]]], a[1]), length(ms[i].first)), newds, ms[i].first, ms[i].second.first, ms[i].second.second, newds_lookup, _first_vector_res,ngroups, new_lengths, total_lengths, threads)
268267
else
269268
_combine_f_barrier_special(special_res, view(_columns(gds.parent)[index(gds.parent)[ms[i].first]], a[1]), newds, ms[i].first, ms[i].second.first, ms[i].second.second, newds_lookup, _first_vector_res,ngroups, new_lengths, total_lengths, threads)
270269
end
271270
else
272-
if ms[i].first isa Tuple
271+
if ms[i].first isa Tuple && !(ms[i].second.first isa Expr)
273272
_combine_f_barrier_tuple(ntuple(j-> _threaded_permute_for_groupby(_columns(gds.parent)[index(gds.parent)[ms[i].first[j]]], a[1], threads = threads), length(ms[i].first)), newds, ms[i].first, ms[i].second.first, ms[i].second.second, newds_lookup, starts, ngroups, new_lengths, total_lengths, threads)
274273
else
275274
_combine_f_barrier(!(ms[i].second.first isa Expr) && haskey(index(gds.parent), ms[i].first) ? curr_x : view(_columns(gds.parent)[1], a[1]), newds, ms[i].first, ms[i].second.first, ms[i].second.second, newds_lookup, starts, ngroups, new_lengths, total_lengths, threads)

test/byrow.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -407,3 +407,11 @@ end
407407
@test all(byrow(ds2, issorted, :))
408408
@test Matrix(ds2) == sort(Matrix(ds), dims = 2)
409409
end
410+
411+
@testset "byrow with NTuple as cols" begin
    ds = Dataset(x1 = [1,2,1,2], x2 = [1,-2,-3,10], x3 = 1:4)

    # three-argument function, columns addressed by position
    function mul_or_div(x, y, z)
        if x == 1
            return y*z
        else
            return y/z
        end
    end
    @test byrow(ds, mul_or_div, (1,2,3)) == [1,-1.0,-9,2.5]

    # two-argument predicate, columns addressed by name
    isflagged(x, y) = x == 1 && y < 0 ? true : false
    @test byrow(ds, isflagged, (:x1, :x2)) == [false, false, true, false]
end

test/grouping.jl

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -612,5 +612,21 @@ end
612612

613613
end
614614

615-
616-
615+
@testset "byrow with tuple input" begin
    ds = Dataset(x1 = [1,2,1,2], x2 = [1,-2,-3,10], x3 = 1:4)

    # modify with an auto-generated destination name, for each supported container
    expected_auto = Dataset(x1 = [1,2,1,2], x2 = [1,-2,-3,10], x3 = 1:4, function_x1_x2=[false, false, true, false])
    for container in (ds, view(ds, [1,2,3,4], [1,2,3]), groupby(ds, 3), gatherby(ds, 3))
        res = modify(container, (1,2) => byrow((x,y)-> x==1 && y<0 ? true : false))
        @test res == expected_auto
    end

    # modify with an explicit destination name
    named = modify(gatherby(ds, 3), (1,2) => byrow((x,y)-> x==1 && y<0 ? true : false)=>:newvar)
    @test named == Dataset(x1 = [1,2,1,2], x2 = [1,-2,-3,10], x3 = 1:4, newvar=[false, false, true, false])

    # combine: byrow over aggregated columns, auto-generated destination name
    agg_auto = combine(groupby(ds, 1), 2 => IMD.minimum, :x3 => IMD.minimum, (:minimum_x2, :minimum_x3) => byrow((x,y)->x/y))
    @test agg_auto == Dataset(x1 = [1,2], minimum_x2 = [-3,-2], minimum_x3 = [1,2], function_minimum_x2_minimum_x3 = [-3.0, -1.0])

    # combine: byrow over aggregated columns, explicit destination name
    agg_named = combine(groupby(ds, 1), 2 => IMD.minimum, :x3 => IMD.minimum, (:minimum_x2, :minimum_x3) => byrow((x,y)->x/y) => :newvar)
    @test agg_named == Dataset(x1 = [1,2], minimum_x2 = [-3,-2], minimum_x3 = [1,2], newvar = [-3.0, -1.0])
end

0 commit comments

Comments
 (0)