Skip to content

Commit f5c7ed0

Browse files
committed
Add a `keep_duplicate_keys` parameter to DataFrame.join to control whether duplicate join key columns are kept
1 parent 2e1b713 commit f5c7ed0

File tree

3 files changed

+56
-13
lines changed

3 files changed

+56
-13
lines changed

python/datafusion/dataframe.py

Lines changed: 21 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -643,6 +643,7 @@ def join(
643643
left_on: None = None,
644644
right_on: None = None,
645645
join_keys: None = None,
646+
keep_duplicate_keys: bool = False,
646647
) -> DataFrame: ...
647648

648649
@overload
@@ -655,6 +656,7 @@ def join(
655656
left_on: str | Sequence[str],
656657
right_on: str | Sequence[str],
657658
join_keys: tuple[list[str], list[str]] | None = None,
659+
keep_duplicate_keys: bool = False,
658660
) -> DataFrame: ...
659661

660662
@overload
@@ -667,6 +669,7 @@ def join(
667669
join_keys: tuple[list[str], list[str]],
668670
left_on: None = None,
669671
right_on: None = None,
672+
keep_duplicate_keys: bool = False,
670673
) -> DataFrame: ...
671674

672675
def join(
@@ -678,6 +681,7 @@ def join(
678681
left_on: str | Sequence[str] | None = None,
679682
right_on: str | Sequence[str] | None = None,
680683
join_keys: tuple[list[str], list[str]] | None = None,
684+
keep_duplicate_keys: bool = False,
681685
) -> DataFrame:
682686
"""Join this :py:class:`DataFrame` with another :py:class:`DataFrame`.
683687
@@ -690,11 +694,23 @@ def join(
690694
"right", "full", "semi", "anti".
691695
left_on: Join column of the left dataframe.
692696
right_on: Join column of the right dataframe.
697+
keep_duplicate_keys: When False, a join column from the right DataFrame
698+
is dropped from the result if it has the same name as the left DataFrame
699+
column it is joined to. When True, both columns are kept.
693700
join_keys: Tuple of two lists of column names to join on. [Deprecated]
694701
695702
Returns:
696703
DataFrame after join.
697704
"""
705+
if join_keys is not None:
706+
warnings.warn(
707+
"`join_keys` is deprecated, use `on` or `left_on` with `right_on`",
708+
category=DeprecationWarning,
709+
stacklevel=2,
710+
)
711+
left_on = join_keys[0]
712+
right_on = join_keys[1]
713+
698714
# This check is to prevent breaking API changes where users prior to
699715
# DF 43.0.0 would pass the join_keys as a positional argument instead
700716
# of a keyword argument.
@@ -705,18 +721,10 @@ def join(
705721
and isinstance(on[1], list)
706722
):
707723
# We know this is safe because we've checked the types
708-
join_keys = on # type: ignore[assignment]
724+
left_on = on[0]
725+
right_on = on[1]
709726
on = None
710727

711-
if join_keys is not None:
712-
warnings.warn(
713-
"`join_keys` is deprecated, use `on` or `left_on` with `right_on`",
714-
category=DeprecationWarning,
715-
stacklevel=2,
716-
)
717-
left_on = join_keys[0]
718-
right_on = join_keys[1]
719-
720728
if on is not None:
721729
if left_on is not None or right_on is not None:
722730
error_msg = "`left_on` or `right_on` should not provided with `on`"
@@ -735,7 +743,9 @@ def join(
735743
if isinstance(right_on, str):
736744
right_on = [right_on]
737745

738-
return DataFrame(self.df.join(right.df, how, left_on, right_on))
746+
return DataFrame(
747+
self.df.join(right.df, how, left_on, right_on, keep_duplicate_keys)
748+
)
739749

740750
def join_on(
741751
self,

python/tests/test_dataframe.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -400,7 +400,6 @@ def test_unnest_without_nulls(nested_df):
400400
assert result.column(1) == pa.array([7, 8, 8, 9, 9, 9])
401401

402402

403-
@pytest.mark.filterwarnings("ignore:`join_keys`:DeprecationWarning")
404403
def test_join():
405404
ctx = SessionContext()
406405

src/dataframe.rs

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -566,6 +566,7 @@ impl PyDataFrame {
566566
how: &str,
567567
left_on: Vec<PyBackedStr>,
568568
right_on: Vec<PyBackedStr>,
569+
keep_duplicate_keys: bool,
569570
) -> PyDataFusionResult<Self> {
570571
let join_type = match how {
571572
"inner" => JoinType::Inner,
@@ -584,13 +585,46 @@ impl PyDataFrame {
584585
let left_keys = left_on.iter().map(|s| s.as_ref()).collect::<Vec<&str>>();
585586
let right_keys = right_on.iter().map(|s| s.as_ref()).collect::<Vec<&str>>();
586587

587-
let df = self.df.as_ref().clone().join(
588+
let mut df = self.df.as_ref().clone().join(
588589
right.df.as_ref().clone(),
589590
join_type,
590591
&left_keys,
591592
&right_keys,
592593
None,
593594
)?;
595+
596+
if !keep_duplicate_keys {
597+
let mutual_keys = left_keys
598+
.iter()
599+
.zip(right_keys.iter())
600+
.filter(|(l, r)| l == r)
601+
.map(|(key, _)| *key)
602+
.collect::<Vec<_>>();
603+
604+
let fields_to_drop = mutual_keys
605+
.iter()
606+
.map(|name| {
607+
df.logical_plan()
608+
.schema()
609+
.qualified_fields_with_unqualified_name(name)
610+
})
611+
.filter(|r| r.len() == 2)
612+
.map(|r| r[1])
613+
.collect::<Vec<_>>();
614+
615+
let expr: Vec<Expr> = df
616+
.logical_plan()
617+
.schema()
618+
.fields()
619+
.into_iter()
620+
.enumerate()
621+
.map(|(idx, _)| df.logical_plan().schema().qualified_field(idx))
622+
.filter(|(qualifier, f)| !fields_to_drop.contains(&(*qualifier, f)))
623+
.map(|(qualifier, field)| Expr::Column(Column::from((qualifier, field))))
624+
.collect();
625+
df = df.select(expr)?;
626+
}
627+
594628
Ok(Self::new(df))
595629
}
596630

0 commit comments

Comments
 (0)