Skip to content

Commit 44529d9

Browse files
committed
Add field to dataframe join to indicate if we should keep duplicate keys
1 parent e97ed57 commit 44529d9

File tree

3 files changed

+56
-13
lines changed

3 files changed

+56
-13
lines changed

python/datafusion/dataframe.py

Lines changed: 21 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -774,6 +774,7 @@ def join(
774774
left_on: None = None,
775775
right_on: None = None,
776776
join_keys: None = None,
777+
keep_duplicate_keys: bool = False,
777778
) -> DataFrame: ...
778779

779780
@overload
@@ -786,6 +787,7 @@ def join(
786787
left_on: str | Sequence[str],
787788
right_on: str | Sequence[str],
788789
join_keys: tuple[list[str], list[str]] | None = None,
790+
keep_duplicate_keys: bool = False,
789791
) -> DataFrame: ...
790792

791793
@overload
@@ -798,6 +800,7 @@ def join(
798800
join_keys: tuple[list[str], list[str]],
799801
left_on: None = None,
800802
right_on: None = None,
803+
keep_duplicate_keys: bool = False,
801804
) -> DataFrame: ...
802805

803806
def join(
@@ -809,6 +812,7 @@ def join(
809812
left_on: str | Sequence[str] | None = None,
810813
right_on: str | Sequence[str] | None = None,
811814
join_keys: tuple[list[str], list[str]] | None = None,
815+
keep_duplicate_keys: bool = False,
812816
) -> DataFrame:
813817
"""Join this :py:class:`DataFrame` with another :py:class:`DataFrame`.
814818
@@ -821,11 +825,23 @@ def join(
821825
"right", "full", "semi", "anti".
822826
left_on: Join column of the left dataframe.
823827
right_on: Join column of the right dataframe.
828+
keep_duplicate_keys: When False, the columns from the right DataFrame
829+
that have identical names in the ``on`` fields to the left DataFrame
830+
will be dropped.
824831
join_keys: Tuple of two lists of column names to join on. [Deprecated]
825832
826833
Returns:
827834
DataFrame after join.
828835
"""
836+
if join_keys is not None:
837+
warnings.warn(
838+
"`join_keys` is deprecated, use `on` or `left_on` with `right_on`",
839+
category=DeprecationWarning,
840+
stacklevel=2,
841+
)
842+
left_on = join_keys[0]
843+
right_on = join_keys[1]
844+
829845
# This check is to prevent breaking API changes where users prior to
830846
# DF 43.0.0 would pass the join_keys as a positional argument instead
831847
# of a keyword argument.
@@ -836,18 +852,10 @@ def join(
836852
and isinstance(on[1], list)
837853
):
838854
# We know this is safe because we've checked the types
839-
join_keys = on # type: ignore[assignment]
855+
left_on = on[0]
856+
right_on = on[1]
840857
on = None
841858

842-
if join_keys is not None:
843-
warnings.warn(
844-
"`join_keys` is deprecated, use `on` or `left_on` with `right_on`",
845-
category=DeprecationWarning,
846-
stacklevel=2,
847-
)
848-
left_on = join_keys[0]
849-
right_on = join_keys[1]
850-
851859
if on is not None:
852860
if left_on is not None or right_on is not None:
853861
error_msg = "`left_on` or `right_on` should not provided with `on`"
@@ -866,7 +874,9 @@ def join(
866874
if isinstance(right_on, str):
867875
right_on = [right_on]
868876

869-
return DataFrame(self.df.join(right.df, how, left_on, right_on))
877+
return DataFrame(
878+
self.df.join(right.df, how, left_on, right_on, keep_duplicate_keys)
879+
)
870880

871881
def join_on(
872882
self,

python/tests/test_dataframe.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -647,7 +647,6 @@ def test_unnest_without_nulls(nested_df):
647647
assert result.column(1) == pa.array([7, 8, 8, 9, 9, 9])
648648

649649

650-
@pytest.mark.filterwarnings("ignore:`join_keys`:DeprecationWarning")
651650
def test_join():
652651
ctx = SessionContext()
653652

src/dataframe.rs

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -629,6 +629,7 @@ impl PyDataFrame {
629629
how: &str,
630630
left_on: Vec<PyBackedStr>,
631631
right_on: Vec<PyBackedStr>,
632+
keep_duplicate_keys: bool,
632633
) -> PyDataFusionResult<Self> {
633634
let join_type = match how {
634635
"inner" => JoinType::Inner,
@@ -647,13 +648,46 @@ impl PyDataFrame {
647648
let left_keys = left_on.iter().map(|s| s.as_ref()).collect::<Vec<&str>>();
648649
let right_keys = right_on.iter().map(|s| s.as_ref()).collect::<Vec<&str>>();
649650

650-
let df = self.df.as_ref().clone().join(
651+
let mut df = self.df.as_ref().clone().join(
651652
right.df.as_ref().clone(),
652653
join_type,
653654
&left_keys,
654655
&right_keys,
655656
None,
656657
)?;
658+
659+
if !keep_duplicate_keys {
660+
let mutual_keys = left_keys
661+
.iter()
662+
.zip(right_keys.iter())
663+
.filter(|(l, r)| l == r)
664+
.map(|(key, _)| *key)
665+
.collect::<Vec<_>>();
666+
667+
let fields_to_drop = mutual_keys
668+
.iter()
669+
.map(|name| {
670+
df.logical_plan()
671+
.schema()
672+
.qualified_fields_with_unqualified_name(name)
673+
})
674+
.filter(|r| r.len() == 2)
675+
.map(|r| r[1])
676+
.collect::<Vec<_>>();
677+
678+
let expr: Vec<Expr> = df
679+
.logical_plan()
680+
.schema()
681+
.fields()
682+
.into_iter()
683+
.enumerate()
684+
.map(|(idx, _)| df.logical_plan().schema().qualified_field(idx))
685+
.filter(|(qualifier, f)| !fields_to_drop.contains(&(*qualifier, f)))
686+
.map(|(qualifier, field)| Expr::Column(Column::from((qualifier, field))))
687+
.collect();
688+
df = df.select(expr)?;
689+
}
690+
657691
Ok(Self::new(df))
658692
}
659693

0 commit comments

Comments
 (0)