@@ -166,7 +166,8 @@ def merge(
166166 validate = validate ,
167167 )
168168 else :
169- op = _MergeOperation (
169+ klass = _MergeOperation if how != "leftsemi" else _SemiMergeOperation
170+ op = klass (
170171 left_df ,
171172 right_df ,
172173 how = how ,
@@ -817,7 +818,6 @@ def _validate_tolerance(self, left_join_keys: list[ArrayLike]) -> None:
817818 # Overridden by AsOfMerge
818819 pass
819820
820- @final
821821 def _reindex_and_concat (
822822 self ,
823823 join_index : Index ,
@@ -945,7 +945,6 @@ def _indicator_post_merge(self, result: DataFrame) -> DataFrame:
945945 result = result .drop (labels = ["_left_indicator" , "_right_indicator" ], axis = 1 )
946946 return result
947947
948- @final
949948 def _maybe_restore_index_levels (self , result : DataFrame ) -> None :
950949 """
951950 Restore index levels specified as `on` parameters
@@ -989,7 +988,6 @@ def _maybe_restore_index_levels(self, result: DataFrame) -> None:
989988 if names_to_restore :
990989 result .set_index (names_to_restore , inplace = True )
991990
992- @final
993991 def _maybe_add_join_keys (
994992 self ,
995993 result : DataFrame ,
@@ -1740,7 +1738,8 @@ def get_join_indexers(
17401738 right = Index (rkey )
17411739
17421740 if (
1743- left .is_monotonic_increasing
1741+ how != "leftsemi"
1742+ and left .is_monotonic_increasing
17441743 and right .is_monotonic_increasing
17451744 and (left .is_unique or right .is_unique )
17461745 ):
@@ -1883,6 +1882,48 @@ def _convert_to_multiindex(index: Index) -> MultiIndex:
18831882 return tuple (join_levels ), tuple (join_codes ), tuple (join_names )
18841883
18851884
class _SemiMergeOperation(_MergeOperation):
    """
    Merge operation selected by ``merge`` when ``how == "leftsemi"``.

    The result is built from the left frame only: join keys are not added,
    index levels are not restored, and ``_reindex_and_concat`` returns the
    entries of ``self.left`` picked out by ``left_indexer``.

    Raises
    ------
    NotImplementedError
        If a truthy ``validate`` keyword is passed, or if ``left_index``,
        ``right_index``, ``indicator`` or ``sort`` is set on the operation.
    """

    def __init__(self, *args, **kwargs):
        # Only a ``validate`` passed by keyword is inspected; it is rejected
        # before the parent constructor runs.
        if kwargs.get("validate"):
            raise NotImplementedError("validate is not supported for semi-join.")

        super().__init__(*args, **kwargs)

        # Every guard raises, so independent checks are evaluated in the
        # same order and with the same outcome as a chained conditional.
        if self.left_index or self.right_index:
            raise NotImplementedError(
                "left_index or right_index are not supported for semi-join."
            )
        if self.indicator:
            raise NotImplementedError("indicator is not supported for semi-join.")
        if self.sort:
            raise NotImplementedError(
                "sort is not supported for semi-join. Sort your DataFrame afterwards."
            )

    def _maybe_add_join_keys(
        self,
        result: DataFrame,
        left_indexer: npt.NDArray[np.intp] | None,
        right_indexer: npt.NDArray[np.intp] | None,
    ) -> None:
        """Do nothing: a semi-join leaves ``result`` untouched here."""

    def _maybe_restore_index_levels(self, result: DataFrame) -> None:
        """Do nothing: a semi-join leaves ``result`` untouched here."""

    def _reindex_and_concat(
        self,
        join_index: Index,
        left_indexer: npt.NDArray[np.intp] | None,
        right_indexer: npt.NDArray[np.intp] | None,
    ) -> DataFrame:
        """
        Return the left frame restricted to ``left_indexer``.

        ``join_index`` and ``right_indexer`` are accepted for signature
        compatibility with the parent class but are not used.

        Returns
        -------
        DataFrame
            A full slice of ``self.left`` when ``left_indexer`` is ``None`` or
            is a range indexer over the whole frame; otherwise a frame built
            from the left block manager after taking ``left_indexer``.
        """
        frame = self.left[:]

        # Nothing to reorder or filter: hand back the slice as-is.
        if left_indexer is None or is_range_indexer(left_indexer, len(frame)):
            return frame

        taken_mgr = frame._mgr.take(left_indexer, axis=1, verify=False)
        return frame._constructor_from_mgr(taken_mgr, axes=taken_mgr.axes)
1926+
18861927class _OrderedMerge (_MergeOperation ):
18871928 _merge_type = "ordered_merge"
18881929
@@ -2470,7 +2511,7 @@ def _factorize_keys(
24702511 lk = ensure_int64 (lk .codes )
24712512 rk = ensure_int64 (rk .codes )
24722513
2473- elif isinstance (lk , ExtensionArray ) and lk .dtype == rk .dtype :
2514+ elif how != "leftsemi" and isinstance (lk , ExtensionArray ) and lk .dtype == rk .dtype :
24742515 if (isinstance (lk .dtype , ArrowDtype ) and is_string_dtype (lk .dtype )) or (
24752516 isinstance (lk .dtype , StringDtype )
24762517 and lk .dtype .storage in ["pyarrow" , "pyarrow_numpy" ]
@@ -2560,14 +2601,18 @@ def _factorize_keys(
25602601 lk_data , rk_data = lk , rk # type: ignore[assignment]
25612602 lk_mask , rk_mask = None , None
25622603
2563- hash_join_available = how == "inner" and not sort and lk . dtype . kind in "iufb"
2604+ hash_join_available = how == "inner" and not sort
25642605 if hash_join_available :
25652606 rlab = rizer .factorize (rk_data , mask = rk_mask )
25662607 if rizer .get_count () == len (rlab ):
25672608 ridx , lidx = rizer .hash_inner_join (lk_data , lk_mask )
25682609 return lidx , ridx , - 1
25692610 else :
25702611 llab = rizer .factorize (lk_data , mask = lk_mask )
2612+ elif how == "leftsemi" :
2613+ # populate hashtable for right and then do a hash join
2614+ rizer .factorize (rk_data , mask = rk_mask )
2615+ return rizer .hash_inner_join (lk_data , lk_mask )[1 ], None , - 1
25712616 else :
25722617 llab = rizer .factorize (lk_data , mask = lk_mask )
25732618 rlab = rizer .factorize (rk_data , mask = rk_mask )
0 commit comments