@@ -663,25 +663,38 @@ def test_join():
663663 df1 = ctx .create_dataframe ([[batch ]], "r" )
664664
665665 df2 = df .join (df1 , on = "a" , how = "inner" )
666- df2 .show ()
667666 df2 = df2 .sort (column ("l.a" ))
668667 table = pa .Table .from_batches (df2 .collect ())
669668
670669 expected = {"a" : [1 , 2 ], "c" : [8 , 10 ], "b" : [4 , 5 ]}
671670 assert table .to_pydict () == expected
672671
673- df2 = df .join (df1 , left_on = "a" , right_on = "a" , how = "inner" )
674- df2 .show ()
672+ # Test the default behavior of dropping duplicate keys.
673+ # The result may contain a duplicate column name, and pa.Table()
674+ # hides that fact, so we explicitly check the resultant
675+ # arrays instead.
676+ df2 = df .join (df1 , left_on = "a" , right_on = "a" , how = "inner" , drop_duplicate_keys = True )
675677 df2 = df2 .sort (column ("l.a" ))
676- table = pa .Table .from_batches (df2 .collect ())
678+ result = df2 .collect ()[0 ]
679+ assert result .num_columns == 3
680+ assert result .column (0 ) == pa .array ([1 , 2 ], pa .int64 ())
681+ assert result .column (1 ) == pa .array ([4 , 5 ], pa .int64 ())
682+ assert result .column (2 ) == pa .array ([8 , 10 ], pa .int64 ())
677683
678- expected = {"a" : [1 , 2 ], "c" : [8 , 10 ], "b" : [4 , 5 ]}
679- assert table .to_pydict () == expected
684+ df2 = df .join (
685+ df1 , left_on = "a" , right_on = "a" , how = "inner" , drop_duplicate_keys = False
686+ )
687+ df2 = df2 .sort (column ("l.a" ))
688+ result = df2 .collect ()[0 ]
689+ assert result .num_columns == 4
690+ assert result .column (0 ) == pa .array ([1 , 2 ], pa .int64 ())
691+ assert result .column (1 ) == pa .array ([4 , 5 ], pa .int64 ())
692+ assert result .column (2 ) == pa .array ([1 , 2 ], pa .int64 ())
693+ assert result .column (3 ) == pa .array ([8 , 10 ], pa .int64 ())
680694
681695 # Verify we don't make a breaking change to pre-43.0.0
682696 # where users would pass join_keys as a positional argument
683697 df2 = df .join (df1 , (["a" ], ["a" ]), how = "inner" )
684- df2 .show ()
685698 df2 = df2 .sort (column ("l.a" ))
686699 table = pa .Table .from_batches (df2 .collect ())
687700
0 commit comments