Skip to content

Commit 81fecbd

Browse files
committed
add source option to outerjoin
1 parent f69b09e commit 81fecbd

File tree

4 files changed

+68
-5
lines changed

4 files changed

+68
-5
lines changed

src/join/join.jl

Lines changed: 33 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -699,8 +699,33 @@ function _fill_oncols_left_table_left_outer!(res, x, notinleft, en, total)
699699
end
700700
end
701701

702+
"""
    _fill_source_for_outer!(res, ranges, notinleft, lval, rval, en, total)

Overwrite, in place, the entries of `res` that should not keep their current
value and return `res`.

The function walks `ranges` while tracking an output position:

- an empty `ranges[i]` advances the position by one and stores `lval` there;
- a non-empty `ranges[i]` advances the position by `length(ranges[i])` and
  leaves those entries of `res` untouched.

Afterwards every index in `en[end]+1:total` is set to `rval`.

# Arguments
- `res`: indexable, mutable container that is written to.
- `ranges`: collection of ranges, one per input row (presumably the matching
  rows of the right table for each left row; confirm against the caller).
- `notinleft`: currently unused; kept so the call signature stays unchanged.
- `lval`, `rval`: values stored for the two kinds of unmatched rows.
- `en`: non-empty indexable collection; only its last element is read.
- `total`: last index of `res` that receives `rval`.
"""
function _fill_source_for_outer!(res, ranges, notinleft, lval, rval, en, total)
    pos = 0
    # `eachindex` keeps the loop valid for any index style of `ranges`.
    for i in eachindex(ranges)
        nmatch = length(ranges[i])
        if nmatch == 0
            # No match for this row: it occupies exactly one output slot.
            pos += 1
            res[pos] = lval
        else
            # Matched rows keep whatever `res` already holds; just skip them.
            pos += nmatch
        end
    end
    # Everything after the last cumulative end is marked with `rval`.
    for j in (en[end] + 1):total
        res[j] = rval
    end
    return res
end
716+
717+
718+
719+
# Build the pooled "source" indicator column for an outer join.
#
# A pooled array holding the three labels "left", "right" and "both" is
# created, its reference vector is resized to `total_length` and filled with
# the code of "both"; `_fill_source_for_outer!` then overwrites the positions
# that should carry the "left" / "right" codes. The pooled array is returned.
#
# `ranges`, `notinleft` and `en` are forwarded unchanged to
# `_fill_source_for_outer!`.
function _create_source_for_outer(ranges, notinleft, total_length, en)
    labels = allowmissing(PooledArray(["left", "right", "both"]))
    codes = labels.refs
    lookup = labels.invpool

    # Start with every row flagged as "both".
    resize!(codes, total_length)
    fill!(codes, get(lookup, "both", missing))

    left_code = get(lookup, "left", missing)
    right_code = get(lookup, "right", missing)
    _fill_source_for_outer!(codes, ranges, notinleft, left_code, right_code, en, total_length)

    labels
end
726+
702727

703-
function _join_outer(dsl, dsr::AbstractDataset, ::Val{T}; onleft, onright, makeunique = false, mapformats = [true, true], stable = false, alg = HeapSort, check = true, accelerate = false, method = :sort, threads = true) where T
728+
function _join_outer(dsl, dsr::AbstractDataset, ::Val{T}; onleft, onright, makeunique = false, mapformats = [true, true], stable = false, alg = HeapSort, check = true, accelerate = false, method = :sort, threads = true, source::Bool = false, source_col_name = :source) where T
704729
isempty(dsl) || isempty(dsr) && throw(ArgumentError("in `outerjoin` both left and right tables must be non-empty"))
705730
oncols_left = onleft
706731
oncols_right = onright
@@ -713,7 +738,7 @@ function _join_outer(dsl, dsr::AbstractDataset, ::Val{T}; onleft, onright, makeu
713738
end
714739
ranges = Vector{UnitRange{T}}(undef, nrow(dsl))
715740
if length(oncols_left) == 1 && nrow(dsr)>1
716-
success, result = _join_outer_dict(dsl, dsr, ranges, oncols_left, oncols_right, oncols_left, oncols_right, right_cols, Val(T); makeunique = makeunique, mapformats = mapformats, check = check, threads = threads)
741+
success, result = _join_outer_dict(dsl, dsr, ranges, oncols_left, oncols_right, oncols_left, oncols_right, right_cols, Val(T); makeunique = makeunique, mapformats = mapformats, check = check, threads = threads, source = source, source_col_name = source_col_name)
717742
if success
718743
return result
719744
end
@@ -727,6 +752,9 @@ function _join_outer(dsl, dsr::AbstractDataset, ::Val{T}; onleft, onright, makeu
727752
notinleft = _find_right_not_in_left(ranges, nrow(dsr), idx)
728753
cumsum!(new_ends, new_ends)
729754
total_length = new_ends[end] + length(notinleft)
755+
if source
756+
source_col = _create_source_for_outer(ranges, notinleft, total_length, new_ends)
757+
end
730758
if check
731759
@assert total_length < 10*nrow(dsl) "the output data set will be very large ($(total_length)×$(ncol(dsl)+length(right_cols))) compared to the left data set size ($(nrow(dsl))×$(ncol(dsl))), make sure that the `on` keyword is selected properly, alternatively, pass `check = false` to ignore this error."
732760
end
@@ -765,6 +793,9 @@ function _join_outer(dsl, dsr::AbstractDataset, ::Val{T}; onleft, onright, makeu
765793
push!(index(newds), new_var_name)
766794
setformat!(newds, index(newds)[new_var_name], getformat(dsr, _names(dsr)[right_cols[j]]))
767795
end
796+
if source
797+
insertcols!(newds, source_col_name => source_col)
798+
end
768799
newds
769800

770801
end

src/join/join_dict.jl

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -353,7 +353,7 @@ function _join_inner_dict(dsl, dsr, ranges, onleft, onright, right_cols, ::Val{T
353353

354354
end
355355

356-
function _join_outer_dict(dsl, dsr, ranges, onleft, onright, oncols_left, oncols_right, right_cols, ::Val{T}; makeunique = makeunique, mapformats = mapformats, check = check, threads = true) where T
356+
function _join_outer_dict(dsl, dsr, ranges, onleft, onright, oncols_left, oncols_right, right_cols, ::Val{T}; makeunique = makeunique, mapformats = mapformats, check = check, threads = true, source::Bool = false, source_col_name = :source) where T
357357
_fl = _date_valueidentity
358358
_fr = _date_valueidentity
359359
if mapformats[1]
@@ -372,6 +372,10 @@ function _join_outer_dict(dsl, dsr, ranges, onleft, onright, oncols_left, oncols
372372
notinleft = _find_right_not_in_left(ranges, nrow(dsr), 1:nrow(dsr))
373373
cumsum!(new_ends, new_ends)
374374
total_length = new_ends[end] + length(notinleft)
375+
if source
376+
source_col = _create_source_for_outer(ranges, notinleft, total_length, new_ends)
377+
end
378+
375379
if check
376380
@assert total_length < 10*nrow(dsl) "the output data set will be very large ($(total_length)×$(ncol(dsl)+length(right_cols))) compared to the left data set size ($(nrow(dsl))×$(ncol(dsl))), make sure that the `on` keyword is selected properly, alternatively, pass `check = false` to ignore this error."
377381
end
@@ -410,6 +414,9 @@ function _join_outer_dict(dsl, dsr, ranges, onleft, onright, oncols_left, oncols
410414
push!(index(newds), new_var_name)
411415
setformat!(newds, index(newds)[new_var_name], getformat(dsr, _names(dsr)[right_cols[j]]))
412416
end
417+
if source
418+
insertcols!(newds, source_col_name => source_col)
419+
end
413420
true, newds
414421

415422
end

src/join/main.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -461,7 +461,7 @@ julia> outerjoin(dsl, dsr, on = :year, mapformats = true) # Use formats for data
461461
4 │ 2012 true missing
462462
```
463463
"""
464-
function DataAPI.outerjoin(dsl::AbstractDataset, dsr::AbstractDataset; on = nothing, makeunique = false, mapformats::Union{Bool, Vector{Bool}} = true, stable = false, alg = HeapSort, check = true, accelerate = false, method = :sort, threads::Bool = true)
464+
function DataAPI.outerjoin(dsl::AbstractDataset, dsr::AbstractDataset; on = nothing, makeunique = false, mapformats::Union{Bool, Vector{Bool}} = true, stable = false, alg = HeapSort, check = true, accelerate = false, method = :sort, threads::Bool = true, source::Bool = false, source_name = :source)
465465
!(method in (:hash, :sort)) && throw(ArgumentError("method must be :hash or :sort"))
466466
on === nothing && throw(ArgumentError("`on` keyword must be specified"))
467467
if !(on isa AbstractVector)
@@ -485,7 +485,7 @@ function DataAPI.outerjoin(dsl::AbstractDataset, dsr::AbstractDataset; on = noth
485485
else
486486
throw(ArgumentError("`on` keyword must be a vector of column names or a vector of pairs of column names"))
487487
end
488-
_join_outer(dsl, dsr, nrow(dsr) < typemax(Int32) ? Val(Int32) : Val(Int64), onleft = onleft, onright = onright, makeunique = makeunique, mapformats = mapformats, stable = stable, alg = alg, check = check, accelerate = accelerate, method = method, threads = threads)
488+
_join_outer(dsl, dsr, nrow(dsr) < typemax(Int32) ? Val(Int32) : Val(Int64), onleft = onleft, onright = onright, makeunique = makeunique, mapformats = mapformats, stable = stable, alg = alg, check = check, accelerate = accelerate, method = method, threads = threads, source = source, source_col_name = source_name)
489489
end
490490

491491
"""

test/join.jl

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1303,6 +1303,31 @@ closefinance_tol10ms_noexact = Dataset([Union{Missing, DateTime}[DateTime("2016-
13031303
@test closejoin(dsl, dsr, on = :x1, direction = :nearest, border = :none, threads = false) == Dataset(x1=[.3,.74,.53,.30, .65,1], y = [missing,3,3,missing,3, missing])
13041304
@test closejoin(dsl, dsr, on = :x1, direction = :nearest, border = :none, method = :hash, threads = false) == Dataset(x1=[.3,.74,.53,.30, .65,1], y = [missing,3,3,missing,3, missing])
13051305

1306+
dsl = Dataset(x1 = [1,2,3,4], y = [100,200,300,400])
1307+
dsr = Dataset(x1 = [2,1,5,6], y1 = [-100,-200,-300,-400])
1308+
out1 = outerjoin(dsl, dsr, on = :x1, source = true)
1309+
out2 = outerjoin(dsr, dsl, on = :x1, source = true)
1310+
out1_t = Dataset(AbstractVector[Union{Missing, Int64}[1, 2, 3, 4, 5, 6], Union{Missing, Int64}[100, 200, 300, 400, missing, missing], Union{Missing, Int64}[-200, -100, missing, missing, -300, -400], Union{Missing, String}["both", "both", "left", "left", "right", "right"]], ["x1", "y", "y1", "source"])
1311+
out2_t = Dataset(AbstractVector[Union{Missing, Int64}[2, 1, 5, 6, 3, 4], Union{Missing, Int64}[-100, -200, -300, -400, missing, missing], Union{Missing, Int64}[200, 100, missing, missing, 300, 400], Union{Missing, String}["both", "both", "left", "left", "right", "right"]], ["x1", "y1", "y", "source"])
1312+
@test out1 == out1_t
1313+
@test out2 == out2_t
1314+
dsl = Dataset(x1 = [1,2,3,4], y = [100,200,300,400])
1315+
dsr = Dataset(x1 = [2,1], y1 = [-100,missing])
1316+
out1 = outerjoin(dsl, view(dsr,[2,1],:), on = :x1, source = true)
1317+
out2 = outerjoin(view(dsr,[2,1],:), dsl, on = :x1, source = true)
1318+
out1_t = Dataset(AbstractVector[Union{Missing, Int64}[1, 2, 3, 4], Union{Missing, Int64}[100, 200, 300, 400], Union{Missing, Int64}[missing, -100, missing, missing], Union{Missing, String}["both", "both", "left", "left"]], ["x1", "y", "y1", "source"])
1319+
out2_t = Dataset(AbstractVector[Union{Missing, Int64}[1, 2, 3, 4], Union{Missing, Int64}[missing, -100, missing, missing], Union{Missing, Int64}[100, 200, 300, 400], Union{Missing, String}["both", "both", "right", "right"]], ["x1", "y1", "y", "source"])
1320+
@test out1 == out1_t
1321+
@test out2 == out2_t
1322+
dsl = Dataset(x1 = [1,2,3,4],x2=[1,1,1,1], y = [100,200,300,400])
1323+
dsr = Dataset(x1 = [2,1,5,6],x2= [1,1,1,1], y1 = [-100,-200,-300,-400])
1324+
out1 = outerjoin(dsl, dsr, on = [:x1, :x2], source = true)
1325+
out2 = outerjoin(dsr, dsl, on = [:x1, :x2], source = true)
1326+
out1_t = Dataset(AbstractVector[Union{Missing, Int64}[1, 2, 3, 4, 5, 6], Union{Missing, Int64}[1, 1, 1, 1, 1, 1], Union{Missing, Int64}[100, 200, 300, 400, missing, missing], Union{Missing, Int64}[-200, -100, missing, missing, -300, -400], Union{Missing, String}["both", "both", "left", "left", "right", "right"]], ["x1", "x2", "y", "y1", "source"])
1327+
out2_t = Dataset(AbstractVector[Union{Missing, Int64}[2, 1, 5, 6, 3, 4], Union{Missing, Int64}[1, 1, 1, 1, 1, 1], Union{Missing, Int64}[-100, -200, -300, -400, missing, missing], Union{Missing, Int64}[200, 100, missing, missing, 300, 400], Union{Missing, String}["both", "both", "left", "left", "right", "right"]], ["x1", "x2", "y1", "y", "source"])
1328+
@test out1 == out1_t
1329+
@test out2 == out2_t
1330+
13061331
end
13071332

13081333
@testset "Test empty inputs 1" begin

0 commit comments

Comments
 (0)