Skip to content

Commit 1e7f874

Browse files
committed
Add unit tests for dropping duplicate keys or not
1 parent 4f9f190 commit 1e7f874

File tree

1 file changed

+20
-7
lines changed

1 file changed

+20
-7
lines changed

python/tests/test_dataframe.py

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -663,25 +663,38 @@ def test_join():
663663
df1 = ctx.create_dataframe([[batch]], "r")
664664

665665
df2 = df.join(df1, on="a", how="inner")
666-
df2.show()
667666
df2 = df2.sort(column("l.a"))
668667
table = pa.Table.from_batches(df2.collect())
669668

670669
expected = {"a": [1, 2], "c": [8, 10], "b": [4, 5]}
671670
assert table.to_pydict() == expected
672671

673-
df2 = df.join(df1, left_on="a", right_on="a", how="inner")
674-
df2.show()
672+
# Test the default behavior for dropping duplicate keys
673+
# Since we may have a duplicate column name and pa.Table()
674+
# hides the fact, instead we need to explicitly check the
675+
# resultant arrays.
676+
df2 = df.join(df1, left_on="a", right_on="a", how="inner", drop_duplicate_keys=True)
675677
df2 = df2.sort(column("l.a"))
676-
table = pa.Table.from_batches(df2.collect())
678+
result = df2.collect()[0]
679+
assert result.num_columns == 3
680+
assert result.column(0) == pa.array([1, 2], pa.int64())
681+
assert result.column(1) == pa.array([4, 5], pa.int64())
682+
assert result.column(2) == pa.array([8, 10], pa.int64())
677683

678-
expected = {"a": [1, 2], "c": [8, 10], "b": [4, 5]}
679-
assert table.to_pydict() == expected
684+
df2 = df.join(
685+
df1, left_on="a", right_on="a", how="inner", drop_duplicate_keys=False
686+
)
687+
df2 = df2.sort(column("l.a"))
688+
result = df2.collect()[0]
689+
assert result.num_columns == 4
690+
assert result.column(0) == pa.array([1, 2], pa.int64())
691+
assert result.column(1) == pa.array([4, 5], pa.int64())
692+
assert result.column(2) == pa.array([1, 2], pa.int64())
693+
assert result.column(3) == pa.array([8, 10], pa.int64())
680694

681695
# Verify we don't make a breaking change to pre-43.0.0
682696
# where users would pass join_keys as a positional argument
683697
df2 = df.join(df1, (["a"], ["a"]), how="inner")
684-
df2.show()
685698
df2 = df2.sort(column("l.a"))
686699
table = pa.Table.from_batches(df2.collect())
687700

0 commit comments

Comments
 (0)