Skip to content

Commit 437c9e1

Browse files
committed
resolve cut
1 parent 12d8d84 commit 437c9e1

File tree

5 files changed

+52
-59
lines changed

5 files changed

+52
-59
lines changed

pandas-stubs/core/algorithms.pyi

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -27,7 +27,9 @@ from pandas._typing import (
2727
# with extension types return the same type while standard type return ndarray
2828

2929
@overload
30-
def unique(values: PeriodIndex) -> PeriodIndex: ...
30+
def unique( # type: ignore[overload-overlap] # pyright: ignore[reportOverlappingOverload]
31+
values: PeriodIndex,
32+
) -> PeriodIndex: ...
3133
@overload
3234
def unique(
3335
values: CategoricalIndex[S1, GenericT_co],

pandas-stubs/core/reshape/tile.pyi

Lines changed: 15 additions & 24 deletions
Original file line number | Diff line number | Diff line change
@@ -1,21 +1,21 @@
11
from collections.abc import Sequence
22
from typing import (
3+
Any,
34
Literal,
45
overload,
56
)
67

78
import numpy as np
89
from pandas import (
910
Categorical,
10-
CategoricalDtype,
1111
DatetimeIndex,
1212
Index,
13-
Interval,
1413
IntervalIndex,
1514
Timestamp,
1615
)
1716
from pandas.core.series import Series
1817

18+
from pandas._libs.interval import Interval
1919
from pandas._typing import (
2020
IntervalT,
2121
Label,
@@ -49,7 +49,7 @@ def cut(
4949
ordered: bool = ...,
5050
) -> tuple[npt.NDArray[np.intp], IntervalIndex[IntervalT]]: ...
5151
@overload
52-
def cut( # pyright: ignore[reportOverlappingOverload]
52+
def cut(
5353
x: Series[Timestamp],
5454
bins: (
5555
int
@@ -66,7 +66,7 @@ def cut( # pyright: ignore[reportOverlappingOverload]
6666
include_lowest: bool = ...,
6767
duplicates: Literal["raise", "drop"] = ...,
6868
ordered: bool = ...,
69-
) -> tuple[Series, DatetimeIndex]: ...
69+
) -> tuple[Series[Any, Categorical], DatetimeIndex]: ...
7070
@overload
7171
def cut(
7272
x: Series[Timestamp],
@@ -79,10 +79,10 @@ def cut(
7979
include_lowest: bool = ...,
8080
duplicates: Literal["raise", "drop"] = ...,
8181
ordered: bool = ...,
82-
) -> tuple[Series, DatetimeIndex]: ...
82+
) -> tuple[Series[Any, Categorical], DatetimeIndex]: ...
8383
@overload
8484
def cut(
85-
x: Series,
85+
x: Series[int] | Series[float],
8686
bins: int | Series | Index[int] | Index[float] | Sequence[int] | Sequence[float],
8787
right: bool = ...,
8888
labels: Literal[False] | Sequence[Label] | None = ...,
@@ -92,10 +92,10 @@ def cut(
9292
include_lowest: bool = ...,
9393
duplicates: Literal["raise", "drop"] = ...,
9494
ordered: bool = ...,
95-
) -> tuple[Series, npt.NDArray]: ...
95+
) -> tuple[Series[Any, Categorical], npt.NDArray]: ...
9696
@overload
9797
def cut(
98-
x: Series,
98+
x: Series[int] | Series[float],
9999
bins: IntervalIndex[Interval[int]] | IntervalIndex[Interval[float]],
100100
right: bool = ...,
101101
labels: Sequence[Label] | None = ...,
@@ -105,7 +105,7 @@ def cut(
105105
include_lowest: bool = ...,
106106
duplicates: Literal["raise", "drop"] = ...,
107107
ordered: bool = ...,
108-
) -> tuple[Series, IntervalIndex]: ...
108+
) -> tuple[Series[Any, Categorical], IntervalIndex]: ...
109109
@overload
110110
def cut(
111111
x: Index | npt.NDArray | Sequence[int] | Sequence[float],
@@ -158,11 +158,10 @@ def cut(
158158
x: Series[Timestamp],
159159
bins: (
160160
int
161-
| Series[Timestamp]
161+
| Sequence[np.datetime64 | Timestamp]
162162
| DatetimeIndex
163-
| Sequence[Timestamp]
164-
| Sequence[np.datetime64]
165163
| IntervalIndex[Interval[Timestamp]]
164+
| Series[Timestamp]
166165
),
167166
right: bool = ...,
168167
labels: Literal[False] | Sequence[Label] | None = ...,
@@ -171,27 +170,19 @@ def cut(
171170
include_lowest: bool = ...,
172171
duplicates: Literal["raise", "drop"] = ...,
173172
ordered: bool = ...,
174-
) -> Series[CategoricalDtype]: ...
173+
) -> Series[Any, Categorical]: ...
175174
@overload
176175
def cut(
177-
x: Series,
178-
bins: (
179-
int
180-
| Series
181-
| Index[int]
182-
| Index[float]
183-
| Sequence[int]
184-
| Sequence[float]
185-
| IntervalIndex
186-
),
176+
x: Series[int] | Series[float],
177+
bins: int | Sequence[float] | Index[int] | Index[float] | IntervalIndex | Series,
187178
right: bool = ...,
188179
labels: Literal[False] | Sequence[Label] | None = ...,
189180
retbins: Literal[False] = False,
190181
precision: int = ...,
191182
include_lowest: bool = ...,
192183
duplicates: Literal["raise", "drop"] = ...,
193184
ordered: bool = ...,
194-
) -> Series: ...
185+
) -> Series[Any, Categorical]: ...
195186
@overload
196187
def cut(
197188
x: Index | npt.NDArray | Sequence[int] | Sequence[float],

pandas-stubs/core/series.pyi

Lines changed: 14 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -47,21 +47,17 @@ from matplotlib.axes import (
4747
SubplotBase,
4848
)
4949
import numpy as np
50-
from pandas import (
51-
Index,
52-
Period,
53-
PeriodDtype,
54-
Timedelta,
55-
Timestamp,
56-
)
5750
from pandas.core.api import (
5851
Int8Dtype as Int8Dtype,
5952
Int16Dtype as Int16Dtype,
6053
Int32Dtype as Int32Dtype,
6154
Int64Dtype as Int64Dtype,
6255
)
6356
from pandas.core.arrays.boolean import BooleanDtype
64-
from pandas.core.arrays.categorical import CategoricalAccessor
57+
from pandas.core.arrays.categorical import (
58+
Categorical,
59+
CategoricalAccessor,
60+
)
6561
from pandas.core.arrays.datetimes import DatetimeArray
6662
from pandas.core.arrays.timedeltas import TimedeltaArray
6763
from pandas.core.base import (
@@ -81,6 +77,7 @@ from pandas.core.groupby.generic import SeriesGroupBy
8177
from pandas.core.groupby.groupby import BaseGroupBy
8278
from pandas.core.indexers import BaseIndexer
8379
from pandas.core.indexes.accessors import DtDescriptor
80+
from pandas.core.indexes.base import Index
8481
from pandas.core.indexes.category import CategoricalIndex
8582
from pandas.core.indexes.datetimes import DatetimeIndex
8683
from pandas.core.indexes.interval import IntervalIndex
@@ -117,6 +114,9 @@ from pandas._libs.lib import _NoDefaultDoNotUse
117114
from pandas._libs.missing import NAType
118115
from pandas._libs.tslibs import BaseOffset
119116
from pandas._libs.tslibs.nattype import NaTType
117+
from pandas._libs.tslibs.period import Period
118+
from pandas._libs.tslibs.timedeltas import Timedelta
119+
from pandas._libs.tslibs.timestamps import Timestamp
120120
from pandas._typing import (
121121
S1,
122122
S2,
@@ -216,7 +216,7 @@ from pandas._typing import (
216216
)
217217

218218
from pandas.core.dtypes.base import ExtensionDtype
219-
from pandas.core.dtypes.dtypes import CategoricalDtype
219+
from pandas.core.dtypes.dtypes import PeriodDtype
220220

221221
from pandas.plotting import PlotAccessor
222222

@@ -387,7 +387,7 @@ class Series(IndexOpsMixin[S1, A1_co, GenericT_co], ElementOpsMixin[S1], NDFrame
387387
dtype: CategoryDtypeArg,
388388
name: Hashable = ...,
389389
copy: bool = ...,
390-
) -> Series[CategoricalDtype]: ...
390+
) -> Series[Any, Categorical]: ...
391391
@overload
392392
def __new__(
393393
cls,
@@ -863,9 +863,9 @@ class Series(IndexOpsMixin[S1, A1_co, GenericT_co], ElementOpsMixin[S1], NDFrame
863863
@overload
864864
def unique(self: Series[Never]) -> np.ndarray: ...
865865
@overload
866-
def unique(self: Series[Timestamp]) -> DatetimeArray: ... # type: ignore[overload-overlap]
866+
def unique(self: Series[Timestamp]) -> DatetimeArray: ...
867867
@overload
868-
def unique(self: Series[Timedelta]) -> TimedeltaArray: ... # type: ignore[overload-overlap]
868+
def unique(self: Series[Timedelta]) -> TimedeltaArray: ...
869869
@overload
870870
def unique(self) -> np.ndarray: ...
871871
@overload
@@ -1449,7 +1449,7 @@ class Series(IndexOpsMixin[S1, A1_co, GenericT_co], ElementOpsMixin[S1], NDFrame
14491449
dtype: CategoryDtypeArg,
14501450
copy: _bool = ...,
14511451
errors: IgnoreRaise = ...,
1452-
) -> Series[CategoricalDtype]: ...
1452+
) -> Series[S1, Categorical]: ...
14531453
@overload
14541454
def astype(
14551455
self,
@@ -3625,7 +3625,7 @@ class Series(IndexOpsMixin[S1, A1_co, GenericT_co], ElementOpsMixin[S1], NDFrame
36253625
axis: int = 0,
36263626
) -> Series[BaseOffset]: ...
36273627
@overload
3628-
def __truediv__( # type: ignore[overload-overlap]
3628+
def __truediv__(
36293629
self: Series[Never], other: complex | NumListLike | Index | Series
36303630
) -> Series: ...
36313631
@overload

tests/series/test_series.py

Lines changed: 3 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -50,8 +50,6 @@
5050
Scalar,
5151
)
5252

53-
from pandas.core.dtypes.dtypes import CategoricalDtype # noqa F401
54-
5553
from tests import (
5654
PD_LTE_23,
5755
TYPE_CHECKING_INVALID_USAGE,
@@ -1827,7 +1825,7 @@ def test_categorical_codes() -> None:
18271825

18281826
# GH1383
18291827
sr = pd.Series([1], dtype="category")
1830-
check(assert_type(sr, "pd.Series[CategoricalDtype]"), pd.Series, np.integer)
1828+
check(assert_type(sr, "pd.Series[Any, pd.Categorical]"), pd.Series, np.integer)
18311829

18321830

18331831
def test_relops() -> None:
@@ -2915,8 +2913,8 @@ def test_astype_categorical(cast_arg: CategoryDtypeArg, target_type: type) -> No
29152913

29162914
if TYPE_CHECKING:
29172915
# pandas category
2918-
assert_type(s.astype(pd.CategoricalDtype()), "pd.Series[pd.CategoricalDtype]")
2919-
assert_type(s.astype(cast_arg), "pd.Series[pd.CategoricalDtype]")
2916+
assert_type(s.astype(pd.CategoricalDtype()), "pd.Series[str, pd.Categorical]")
2917+
assert_type(s.astype(cast_arg), "pd.Series[str, pd.Categorical]")
29202918

29212919

29222920
@pytest.mark.parametrize("cast_arg, target_type", ASTYPE_OBJECT_ARGS, ids=repr)

tests/test_pandas.py

Lines changed: 17 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -1030,9 +1030,13 @@ def test_cut() -> None:
10301030
g = pd.cut(pd.Series([1, 2, 3, 4, 5, 6, 7, 8]), 4, precision=1, duplicates="drop")
10311031
h = pd.cut(pd.Series([1, 2, 3, 4, 5, 6, 7, 8]), 4, labels=False, duplicates="raise")
10321032
i = pd.cut(pd.Series([1, 2, 3, 4, 5, 6, 7, 8]), 4, labels=["1", "2", "3", "4"])
1033-
check(assert_type(g, pd.Series), pd.Series)
1034-
check(assert_type(h, pd.Series), pd.Series)
1035-
check(assert_type(i, pd.Series), pd.Series)
1033+
check(assert_type(g, pd.Series[Any, pd.Categorical]), pd.Series, pd.Interval[float])
1034+
check(assert_type(h, pd.Series[Any, pd.Categorical]), pd.Series, pd.Interval[float])
1035+
check(
1036+
assert_type(i, pd.Series[Any, pd.Categorical]),
1037+
pd.Series,
1038+
pd.Interval[float],
1039+
)
10361040

10371041
j0, j1 = pd.cut(
10381042
pd.Series([1, 2, 3, 4, 5, 6, 7, 8]),
@@ -1059,13 +1063,13 @@ def test_cut() -> None:
10591063
intval_idx,
10601064
retbins=True,
10611065
)
1062-
check(assert_type(j0, pd.Series), pd.Series)
1066+
check(assert_type(j0, pd.Series[Any, pd.Categorical]), pd.Series)
10631067
check(assert_type(j1, npt.NDArray), np.ndarray)
1064-
check(assert_type(k0, pd.Series), pd.Series)
1068+
check(assert_type(k0, pd.Series[Any, pd.Categorical]), pd.Series)
10651069
check(assert_type(k1, npt.NDArray), np.ndarray)
1066-
check(assert_type(l0, pd.Series), pd.Series)
1070+
check(assert_type(l0, pd.Series[Any, pd.Categorical]), pd.Series)
10671071
check(assert_type(l1, npt.NDArray), np.ndarray)
1068-
check(assert_type(m0, pd.Series), pd.Series)
1072+
check(assert_type(m0, pd.Series[Any, pd.Categorical]), pd.Series)
10691073
check(assert_type(m1, pd.IntervalIndex), pd.IntervalIndex)
10701074

10711075
n0, n1 = pd.cut([1, 2, 3, 4, 5, 6, 7, 8], intval_idx, retbins=True)
@@ -1076,26 +1080,24 @@ def test_cut() -> None:
10761080
check(
10771081
assert_type(
10781082
pd.cut(s1, bins=[np.datetime64("2020-01-03"), np.datetime64("2020-09-01")]),
1079-
"pd.Series[pd.CategoricalDtype]",
1083+
"pd.Series[Any, pd.Categorical]",
10801084
),
10811085
pd.Series,
1086+
pd.Interval[pd.Timestamp],
10821087
)
10831088
check(
1084-
assert_type(
1085-
pd.cut(s1, bins=10),
1086-
"pd.Series[pd.CategoricalDtype]",
1087-
),
1089+
assert_type(pd.cut(s1, bins=10), pd.Series[Any, pd.Categorical]),
10881090
pd.Series,
10891091
pd.Interval,
10901092
)
10911093
s0r, s1r = pd.cut(s1, bins=10, retbins=True)
1092-
check(assert_type(s0r, pd.Series), pd.Series, pd.Interval)
1094+
check(assert_type(s0r, pd.Series[Any, pd.Categorical]), pd.Series, pd.Interval)
10931095
check(assert_type(s1r, pd.DatetimeIndex), pd.DatetimeIndex, pd.Timestamp)
10941096
s0rlf, s1rlf = pd.cut(s1, bins=10, labels=False, retbins=True)
1095-
check(assert_type(s0rlf, pd.Series), pd.Series, np.integer)
1097+
check(assert_type(s0rlf, pd.Series[Any, pd.Categorical]), pd.Series, np.integer)
10961098
check(assert_type(s1rlf, pd.DatetimeIndex), pd.DatetimeIndex, pd.Timestamp)
10971099
s0rls, s1rls = pd.cut(s1, bins=4, labels=["1", "2", "3", "4"], retbins=True)
1098-
check(assert_type(s0rls, pd.Series), pd.Series, str)
1100+
check(assert_type(s0rls, pd.Series[Any, pd.Categorical]), pd.Series, str)
10991101
check(assert_type(s1rls, pd.DatetimeIndex), pd.DatetimeIndex, pd.Timestamp)
11001102

11011103

0 commit comments

Comments
 (0)