Skip to content

Commit 18626a9

Browse files
authored
type isin, remove redundant Iterable | Series (#1507)
* type `isin`, remove redundant `Iterable | Series` * Iterable -> Iterable[Any], extra test cases * add overloads for multiindex.isin * fix pyrefly * fix pyright * fix pyrefly * extra test case, use Collection * remove extra test case * extra test cases * mypy * ty
1 parent cf7478f commit 18626a9

File tree

7 files changed

+54
-6
lines changed

7 files changed

+54
-6
lines changed

pandas-stubs/core/frame.pyi

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1783,7 +1783,9 @@ class DataFrame(NDFrame, OpsMixin, _GetItemHack):
17831783
axis: Axis = 0,
17841784
copy: _bool = True,
17851785
) -> Self: ...
1786-
def isin(self, values: Iterable | Series | DataFrame | dict) -> Self: ...
1786+
def isin(
1787+
self, values: Iterable[Any] | Mapping[Hashable, Iterable[Any]] | DataFrame
1788+
) -> Self: ...
17871789
@property
17881790
def plot(self) -> PlotAccessor: ...
17891791
def hist(

pandas-stubs/core/indexes/base.pyi

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -576,7 +576,9 @@ class Index(IndexOpsMixin[S1], ElementOpsMixin[S1]):
576576
def map(
577577
self, mapper: Renamer, na_action: Literal["ignore"] | None = None
578578
) -> Index: ...
579-
def isin(self, values, level=...) -> np_1darray_bool: ...
579+
def isin(
580+
self, values: Iterable[Any], level: Level | None = None
581+
) -> np_1darray_bool: ...
580582
def slice_indexer(
581583
self,
582584
start: Label | None = None,

pandas-stubs/core/indexes/multi.pyi

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from collections.abc import (
22
Callable,
3+
Collection,
34
Hashable,
45
Iterable,
56
Mapping,
@@ -165,7 +166,14 @@ class MultiIndex(Index):
165166
def equal_levels(self, other): ...
166167
def insert(self, loc, item): ...
167168
def delete(self, loc): ...
168-
def isin(self, values, level=...) -> np_1darray_bool: ...
169+
@overload # type: ignore[override]
170+
def isin( # pyrefly: ignore[bad-override]
171+
self, values: Iterable[Any], level: Level
172+
) -> np_1darray_bool: ...
173+
@overload
174+
def isin( # ty: ignore[invalid-method-override] # pyright: ignore[reportIncompatibleMethodOverride]
175+
self, values: Collection[Iterable[Any]], level: None = None
176+
) -> np_1darray_bool: ...
169177
def set_names(
170178
self,
171179
names: Hashable | Sequence[Hashable] | Mapping[Any, Hashable],

pandas-stubs/core/series.pyi

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1305,7 +1305,7 @@ class Series(IndexOpsMixin[S1], ElementOpsMixin[S1], NDFrame):
13051305
show_counts: bool | None = ...,
13061306
) -> None: ...
13071307
def memory_usage(self, index: _bool = True, deep: _bool = False) -> int: ...
1308-
def isin(self, values: Iterable | Series[S1] | dict) -> Series[_bool]: ...
1308+
def isin(self, values: Iterable[Any]) -> Series[_bool]: ...
13091309
def between(
13101310
self,
13111311
left: Scalar | ListLikeU,

tests/indexes/test_indexes.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,8 +57,18 @@ def test_index_duplicated() -> None:
5757

5858
def test_index_isin() -> None:
5959
ind = pd.Index([1, 2, 3, 4, 5])
60-
isin = ind.isin([2, 4])
61-
check(assert_type(isin, np_1darray_bool), np_1darray_bool)
60+
check(assert_type(ind.isin([2, 4]), np_1darray_bool), np_1darray_bool)
61+
check(assert_type(ind.isin({2, 4}), np_1darray_bool), np_1darray_bool)
62+
check(assert_type(ind.isin(pd.Series([2, 4])), np_1darray_bool), np_1darray_bool)
63+
check(assert_type(ind.isin(pd.Index([2, 4])), np_1darray_bool), np_1darray_bool)
64+
check(assert_type(ind.isin(iter([2, "4"])), np_1darray_bool), np_1darray_bool)
65+
66+
mi = pd.MultiIndex.from_arrays([[1, 2, 3]])
67+
check(assert_type(mi.isin([[3]]), np_1darray_bool), np_1darray_bool)
68+
check(assert_type(mi.isin({iter([3])}), np_1darray_bool), np_1darray_bool)
69+
if TYPE_CHECKING_INVALID_USAGE:
70+
mi.isin({3}) # type: ignore[arg-type] # pyright: ignore[reportArgumentType]
71+
mi.isin(iter([[3]])) # type: ignore[call-overload] # pyright: ignore[reportArgumentType]
6272

6373

6474
def test_index_astype() -> None:

tests/series/test_series.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1606,6 +1606,17 @@ def test_series_min_max_sub_axis() -> None:
16061606
check(assert_type(df.max(axis=1), pd.Series), pd.Series)
16071607

16081608

1609+
def test_series_isin() -> None:
1610+
s = pd.Series([1, 2, 3, 4, 5])
1611+
check(assert_type(s.isin([3, 4]), "pd.Series[bool]"), pd.Series, np.bool_)
1612+
check(assert_type(s.isin({3, 4}), "pd.Series[bool]"), pd.Series, np.bool_)
1613+
check(
1614+
assert_type(s.isin(pd.Series([3, 4])), "pd.Series[bool]"), pd.Series, np.bool_
1615+
)
1616+
check(assert_type(s.isin(pd.Index([3, 4])), "pd.Series[bool]"), pd.Series, np.bool_)
1617+
check(assert_type(s.isin(iter([3, "4"])), "pd.Series[bool]"), pd.Series, np.bool_)
1618+
1619+
16091620
def test_series_index_isin() -> None:
16101621
s = pd.Series([1, 2, 3, 4, 5], index=[1, 2, 2, 3, 3])
16111622
t1 = s.loc[s.index.isin([1, 3])]

tests/test_frame.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from collections import (
44
OrderedDict,
5+
UserDict,
56
UserList,
67
defaultdict,
78
deque,
@@ -2936,6 +2937,20 @@ def test_getmultiindex_columns() -> None:
29362937
check(assert_type(df[li[0]], pd.Series), pd.Series)
29372938

29382939

2940+
def test_frame_isin() -> None:
2941+
df = pd.DataFrame({"x": [1, 2, 3, 4, 5]}, index=[1, 2, 3, 4, 5])
2942+
check(assert_type(df.isin([1, 3, 5]), pd.DataFrame), pd.DataFrame)
2943+
check(assert_type(df.isin({1, 3, 5}), pd.DataFrame), pd.DataFrame)
2944+
check(assert_type(df.isin(pd.Series([1, 3, 5])), pd.DataFrame), pd.DataFrame)
2945+
check(assert_type(df.isin(pd.Index([1, 3, 5])), pd.DataFrame), pd.DataFrame)
2946+
check(assert_type(df.isin(df), pd.DataFrame), pd.DataFrame)
2947+
check(assert_type(df.isin({"x": [1, 2]}), pd.DataFrame), pd.DataFrame)
2948+
check(
2949+
assert_type(df.isin(UserDict({"x": iter([1, "2"])})), pd.DataFrame),
2950+
pd.DataFrame,
2951+
)
2952+
2953+
29392954
def test_frame_getitem_isin() -> None:
29402955
df = pd.DataFrame({"x": [1, 2, 3, 4, 5]}, index=[1, 2, 3, 4, 5])
29412956
check(assert_type(df[df.index.isin([1, 3, 5])], pd.DataFrame), pd.DataFrame)

0 commit comments

Comments
 (0)