Skip to content

Commit ee1da4b

Browse files
committed
add a new keyword argument to update/!
the `op` keyword argument is for passing a user defined function for updating rows in the main data set
1 parent 3e1fe76 commit ee1da4b

File tree

5 files changed

+245
-165
lines changed

5 files changed

+245
-165
lines changed

docs/src/man/joins.md

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -447,7 +447,9 @@ julia> contains(dsl, dsr, on = [1=>1, 2=>(2,3)], strict_inequality = true)
447447

448448
## Update a data set by values from another data set
449449

450-
`update!` updates a data set values by using values from a transaction data set. The function uses the given keys (`on = ...`) to select rows for updating. By default, the missing values in transaction data set wouldn't replace the values in the main data set, however, using `allowmissing = true` changes this behaviour. If there are multiple rows in the main data set which match the key(s), using `mode = :all` causes all of them to be updated, `mode = :missing` causes only the ones which are missing in the main data set to be updated, and `mode = fun` updates the values which calling `fun` on them returns `true`. If there are multiple rows in the transaction data set which match the key, only the last one (given `stable = true` is passed) will be used to update the main data set.
450+
`update!` updates the values of a data set using values from a transaction data set. The function uses the given keys (`on = ...`) to select the rows to update. By default, missing values in the transaction data set do not replace values in the main data set; passing `allowmissing = true` changes this behaviour. If multiple rows in the main data set match the key(s), `mode = :all` causes all of them to be updated, `mode = :missings` causes only the values which are missing in the main data set to be updated, and `mode = fun` updates only the values for which `fun` returns `true`. If multiple rows in the transaction data set match the key, only the last one (provided `stable = true` is passed) is used to update the main data set.
451+
452+
By default, `update!` replaces the old values with the new values from the transaction data set. However, the user may pass any function via the `op` keyword argument to combine the values of the two data sets instead. In this case, `update!` replaces each value in the main data set with `op(old, new)`, where `old` is the value from the main data set and `new` is the value from the transaction data set.
451453

452454
The `update!` functions replace the main data set with the updated version, however, if a copy of the updated data set is required, the `update` function can be used instead.
453455

@@ -486,7 +488,7 @@ julia> transaction = Dataset(group = ["G1", "G2"], id = [2, 1],
486488
487489
488490
julia> update(main, transaction, on = [:group, :id],
489-
allowmissing = false, mode = :missing)
491+
allowmissing = false, mode = :missings)
490492
7×4 Dataset
491493
Row │ group id x1 x2
492494
│ identity identity identity identity
@@ -530,6 +532,20 @@ julia> update(main, transaction, on = [:group, :id],
530532
5 │ G2 1 1.3 1
531533
6 │ G2 1 2.1 missing
532534
7 │ G2 2 0.0 2
535+
536+
julia> update(main, transaction, on = [:group, :id], op = +) # add values of transaction to main, when op is set mode = :all is default
537+
7×4 Dataset
538+
Row │ group id x1 x2
539+
│ identity identity identity identity
540+
│ String? Int64? Float64? Int64?
541+
─────┼─────────────────────────────────────────
542+
1 │ G1 1 1.2 5
543+
2 │ G1 1 2.3 4
544+
3 │ G1 2 missing 4
545+
4 │ G1 2 4.8 2
546+
5 │ G2 1 1.3 4
547+
6 │ G2 1 2.1 missing
548+
7 │ G2 2 0.0 2
533549
```
534550

535551
## `compare`

src/join/join_dict.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -503,7 +503,7 @@ function _join_outer_dict(dsl, dsr, ranges, onleft, onright, oncols_left, oncols
503503

504504
end
505505

506-
function _update!_dict(dsl, dsr, ranges, onleft, onright, right_cols, ::Val{T}; allowmissing = true, mode = :all, mapformats = [true, true], stable = false, alg = HeapSort, threads = threads) where T
506+
function _update!_dict(dsl, dsr, ranges, onleft, onright, right_cols, ::Val{T}; allowmissing = true, mode = :all, mapformats = [true, true], stable = false, alg = HeapSort, threads = threads, op = nothing) where T
507507
_fl = _date_valueidentity
508508
_fr = _date_valueidentity
509509
if mapformats[1]
@@ -534,7 +534,7 @@ function _update!_dict(dsl, dsr, ranges, onleft, onright, right_cols, ::Val{T};
534534
TL = nonmissingtype(eltype(_columns(dsl)[left_cols_idx]))
535535
TR = nonmissingtype(eltype(_columns(dsr)[right_cols[j]]))
536536
if promote_type(TR, TL) <: TL
537-
_update_left_with_right!(_columns(dsl)[left_cols_idx], _columns(dsr)[right_cols[j]], ranges, allowmissing, f_mode, threads = threads)
537+
_update_left_with_right!(_columns(dsl)[left_cols_idx], _columns(dsr)[right_cols[j]], ranges, allowmissing, f_mode, threads = threads, op = op)
538538
end
539539
end
540540
end

src/join/main.jl

Lines changed: 28 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1482,7 +1482,7 @@ function closejoin!(dsl::Dataset, dsr::AbstractDataset; on = nothing, direction
14821482
end
14831483

14841484
"""
1485-
update!(dsmain::Dataset, dsupdate::AbstractDataset; on=nothing, allowmissing=false, mode=:missings, mapformats=true, alg=HeapSort, stable=true, accelerate = false, method = :sort, threads = true)
1485+
update!(dsmain::Dataset, dsupdate::AbstractDataset; on=nothing, allowmissing=false, mode=:missings, op = nothing, mapformats=true, alg=HeapSort, stable=true, accelerate = false, method = :sort, threads = true)
14861486
14871487
Update a `Dataset` `dsmain` with another `Dataset` `dsupdate` based `on` given keys for matching rows,
14881488
and change the left `Dataset` after updating.
@@ -1498,8 +1498,9 @@ the order of selected observation from the right table.
14981498
- `on`: can be a single column name, a vector of column names or a vector of pairs of column names; these are the keys on which the update is based.
14991499
- `allowmissing`: is set to `false` by default, so `missing` values in `dsupdate` will not replace the values in `dsmain`;
15001500
changing this to `true` allows `missing` values in `dsupdate` to replace values in `dsmain`.
1501-
- `mode`: by default is set to `:missings`, means that only rows in `dsmain` with `missing` values will be updated.
1501+
- `mode`: defaults to `:missings` when `op` is not passed, meaning that only rows in `dsmain` with `missing` values will be updated; when `op` is passed, the default is `:all`.
15021502
Changing it to `:all` means all matching rows based `on` keys will be updated. Alternatively, a function can be passed as `mode` to update only the observations for which calling `mode` on them returns `true`.
1503+
- `op`: by default, `update!` replaces the values in `dsmain` with the values from `dsupdate`; however, the user can pass any binary function as `op` to replace each value in `dsmain` with `op(left_value, right_value)`, i.e. the result of calling `op` on the old value and the new value.
15031504
$_JOINMAPFORMATSDOC
15041505
$_JOINTHREADSDOC
15051506
$_JOINMETHODDOCSORT
@@ -1552,12 +1553,33 @@ julia> update!(dsmain, dsupdate, on = [:group, :id], mode = :missings) # Only mi
15521553
5 │ G2 1 1.3 1
15531554
6 │ G2 1 2.1 3
15541555
7 │ G2 2 0.0 2
1556+
1557+
julia> dsmain = Dataset(group = ["G1", "G1", "G1", "G1", "G2", "G2", "G2"],
1558+
id = [ 1 , 1 , 2 , 2 , 1 , 1 , 2 ],
1559+
x1 = [1.2, 2.3,missing, 2.3, 1.3, 2.1 , 0.0 ],
1560+
x2 = [ 5 , 4 , 4 , 2 , 1 ,missing, 2 ]);
1561+
julia> dsupdate = Dataset(group = ["G1", "G2"], id = [2, 1],
1562+
x1 = [2.5, missing], x2 = [missing, 3]);
1563+
1564+
julia> update!(dsmain, dsupdate, on = [:group, :id], op = +, mode = :all)
1565+
7×4 Dataset
1566+
Row │ group id x1 x2
1567+
│ identity identity identity identity
1568+
│ String? Int64? Float64? Int64?
1569+
─────┼─────────────────────────────────────────
1570+
1 │ G1 1 1.2 5
1571+
2 │ G1 1 2.3 4
1572+
3 │ G1 2 missing 4
1573+
4 │ G1 2 4.8 2
1574+
5 │ G2 1 1.3 4
1575+
6 │ G2 1 2.1 missing
1576+
7 │ G2 2 0.0 2
15551577
```
15561578
"""
1557-
function update!(dsmain::Dataset, dsupdate::AbstractDataset; on = nothing, allowmissing = false, mode::Union{Symbol, Function} = :missings, mapformats::Union{Bool, Vector{Bool}} = true, stable = true, alg = HeapSort, accelerate = false, method = :sort, threads::Bool = true)
1579+
function update!(dsmain::Dataset, dsupdate::AbstractDataset; on = nothing, allowmissing = false, op = nothing, mode::Union{Symbol, Function} = op === nothing ? :missings : :all, mapformats::Union{Bool, Vector{Bool}} = true, stable = true, alg = HeapSort, accelerate = false, method = :sort, threads::Bool = true)
15581580
!(method in (:hash, :sort)) && throw(ArgumentError("method must be :hash or :sort"))
15591581
on === nothing && throw(ArgumentError("`on` keyword must be specified"))
1560-
mode isa Symbol && !(mode (:all, :missing, :missings)) && throw(ArgumentError("`mode` can be either :all or :missing"))
1582+
mode isa Symbol && !(mode (:all, :missing, :missings)) && throw(ArgumentError("`mode` can be either :all, :missing, or a function"))
15611583
if !(on isa AbstractVector)
15621584
on = [on]
15631585
else
@@ -1577,7 +1599,7 @@ function update!(dsmain::Dataset, dsupdate::AbstractDataset; on = nothing, allow
15771599
else
15781600
throw(ArgumentError("`on` keyword must be a vector of column names or a vector of pairs of column names"))
15791601
end
1580-
_update!(dsmain, dsupdate, nrow(dsupdate) < typemax(Int32) ? Val(Int32) : Val(Int64), onleft = onleft, onright = onright, allowmissing = allowmissing, mode = mode, mapformats = mapformats, stable = stable, alg = alg, accelerate = accelerate, method = method, threads = threads)
1602+
_update!(dsmain, dsupdate, nrow(dsupdate) < typemax(Int32) ? Val(Int32) : Val(Int64), onleft = onleft, onright = onright, allowmissing = allowmissing, mode = mode, mapformats = mapformats, stable = stable, alg = alg, accelerate = accelerate, method = method, threads = threads, op = op)
15811603

15821604
dsmain
15831605
end
@@ -1586,8 +1608,7 @@ end
15861608
15871609
Variant of `update!` that returns an updated copy of `dsmain` leaving `dsmain` itself unmodified.
15881610
"""
1589-
update(dsmain::AbstractDataset, dsupdate::AbstractDataset; on = nothing, allowmissing = false, mode = :all, mapformats::Union{Bool, Vector{Bool}} = true, stable = true, alg = HeapSort, accelerate = false, method = :sort, threads = true) = update!(copy(dsmain), dsupdate; on = on, allowmissing = allowmissing, mode = mode, mapformats = mapformats, stable = stable, alg = alg, accelerate = accelerate, method = method, threads = threads)
1590-
1611+
update(dsmain::AbstractDataset, dsupdate::AbstractDataset; on = nothing, allowmissing = false, op = nothing, mode = op === nothing ? :missings : :all, mapformats::Union{Bool, Vector{Bool}} = true, stable = true, alg = HeapSort, accelerate = false, method = :sort, threads = true,) = update!(copy(dsmain), dsupdate; on = on, allowmissing = allowmissing, mode = mode, mapformats = mapformats, stable = stable, alg = alg, accelerate = accelerate, method = method, threads = threads, op = op)
15911612

15921613

15931614
# TODO the docstring is very limited, we need a more comprehensive docs here / and more examples

src/join/update.jl

Lines changed: 25 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,32 @@
1-
function _update_left_with_right!(x, y, ranges, allowmissing, mode::F; threads = true) where F
2-
@_threadsfor threads for i in 1:length(x)
3-
if length(ranges[i]) > 0
4-
if mode(x[i])
5-
if !allowmissing && !ismissing(y[ranges[i].stop])
6-
x[i] = y[ranges[i].stop]
7-
elseif allowmissing
8-
x[i] = y[ranges[i].stop]
1+
function _update_left_with_right!(x, y, ranges, allowmissing, mode::F; threads = true, op = nothing) where F
2+
if op === nothing
3+
@_threadsfor threads for i in 1:length(x)
4+
if length(ranges[i]) > 0
5+
if mode(x[i])
6+
if !allowmissing && !ismissing(y[ranges[i].stop])
7+
x[i] = y[ranges[i].stop]
8+
elseif allowmissing
9+
x[i] = y[ranges[i].stop]
10+
end
11+
end
12+
end
13+
end
14+
else
15+
@_threadsfor threads for i in 1:length(x)
16+
if length(ranges[i]) > 0
17+
if mode(x[i])
18+
if !allowmissing && !ismissing(y[ranges[i].stop])
19+
x[i] = op(x[i], y[ranges[i].stop])
20+
elseif allowmissing
21+
x[i] = op(x[i], y[ranges[i].stop])
22+
end
923
end
1024
end
1125
end
1226
end
1327
end
1428

15-
function _update!(dsl::Dataset, dsr::AbstractDataset, ::Val{T}; onleft, onright, check = true, allowmissing = true, mode = :all, mapformats = [true, true], stable = false, alg = HeapSort, accelerate = false, usehash = true, method = :sort, threads = true) where T
29+
function _update!(dsl::Dataset, dsr::AbstractDataset, ::Val{T}; onleft, onright, check = true, allowmissing = true, mode = :all, mapformats = [true, true], stable = false, alg = HeapSort, accelerate = false, usehash = true, method = :sort, threads = true, op = nothing) where T
1630
isempty(dsl) && return dsl
1731
if method == :hash
1832
ranges, a, idx, minval, reps, sz, right_cols = _find_ranges_for_join_using_hash(dsl, dsr, onleft, onright, mapformats, true, Val(T); threads = threads)
@@ -23,7 +37,7 @@ function _update!(dsl::Dataset, dsr::AbstractDataset, ::Val{T}; onleft, onright,
2337

2438
ranges = Vector{UnitRange{T}}(undef, nrow(dsl))
2539
if usehash && length(oncols_left) == 1 && nrow(dsr)>1
26-
success, result = _update!_dict(dsl, dsr, ranges, oncols_left, oncols_right, right_cols, Val(T); mapformats = mapformats, allowmissing = allowmissing, mode = mode, threads = threads)
40+
success, result = _update!_dict(dsl, dsr, ranges, oncols_left, oncols_right, right_cols, Val(T); mapformats = mapformats, allowmissing = allowmissing, mode = mode, threads = threads, op = op)
2741
if success
2842
return result
2943
end
@@ -48,7 +62,7 @@ function _update!(dsl::Dataset, dsr::AbstractDataset, ::Val{T}; onleft, onright,
4862
TL = nonmissingtype(eltype(_columns(dsl)[left_cols_idx]))
4963
TR = nonmissingtype(eltype(_columns(dsr)[right_cols[j]]))
5064
if promote_type(TR, TL) <: TL
51-
_update_left_with_right!(_columns(dsl)[left_cols_idx], view(_columns(dsr)[right_cols[j]], idx), ranges, allowmissing, f_mode, threads = threads)
65+
_update_left_with_right!(_columns(dsl)[left_cols_idx], view(_columns(dsr)[right_cols[j]], idx), ranges, allowmissing, f_mode, threads = threads, op = op)
5266
end
5367
end
5468
end

0 commit comments

Comments
 (0)