Skip to content

Commit 5071e11

Browse files
author
nli307
committed
When finding the common type, return the original type if there is only one common type. Return objects for date32 and date64 types when converting those types to numpy types.
1 parent 54c26ec commit 5071e11

File tree

3 files changed

+39
-0
lines changed

3 files changed

+39
-0
lines changed

pandas/core/dtypes/common.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1872,6 +1872,11 @@ def pandas_dtype(dtype) -> DtypeObj:
18721872
result = result()
18731873
return result
18741874

1875+
# try a pyarrow dtype
1876+
from pandas.core.dtypes.dtypes import ArrowDtype
1877+
if isinstance(dtype, ArrowDtype):
1878+
return ArrowDtype(dtype)
1879+
18751880
# try a numpy dtype
18761881
# raise a consistent TypeError if failed
18771882
try:

pandas/core/dtypes/dtypes.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2277,6 +2277,12 @@ def name(self) -> str: # type: ignore[override]
22772277
@cache_readonly
22782278
def numpy_dtype(self) -> np.dtype:
22792279
"""Return an instance of the related numpy dtype"""
2280+
if pa.types.is_date32(self.pyarrow_dtype) or pa.types.is_date64(
2281+
self.pyarrow_dtype
2282+
):
2283+
# date32 and date64 are pyarrow timestamps but do not have a
2284+
# corresponding numpy dtype.
2285+
return np.dtype(object)
22802286
if pa.types.is_timestamp(self.pyarrow_dtype):
22812287
# pa.timestamp(unit).to_pandas_dtype() returns ns units
22822288
# regardless of the pyarrow timestamp units.
@@ -2453,6 +2459,18 @@ def _get_common_dtype(self, dtypes: list[DtypeObj]) -> DtypeObj | None:
24532459

24542460
null_dtype = type(self)(pa.null())
24552461

2462+
# Cover cases where numpy does not have a corresponding dtype, but
2463+
# only one non-null dtype is received, or all dtypes are null.
2464+
single_dtype = {
2465+
dtype
2466+
for dtype in dtypes
2467+
if dtype != null_dtype
2468+
}
2469+
if len(single_dtype) == 0:
2470+
return null_dtype
2471+
if len(single_dtype) == 1:
2472+
return single_dtype.pop()
2473+
24562474
new_dtype = find_common_type(
24572475
[
24582476
dtype.numpy_dtype if isinstance(dtype, ArrowDtype) else dtype

pandas/tests/dtypes/cast/test_find_common_type.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -173,3 +173,19 @@ def test_interval_dtype_with_categorical(dtype):
173173

174174
result = find_common_type([dtype, cat.dtype])
175175
assert result == dtype
176+
177+
178+
@pytest.mark.parametrize(
179+
"dtypes,expected",
180+
[
181+
(
182+
["date32[pyarrow]", "null[pyarrow]"],
183+
"date32[day][pyarrow]",
184+
),
185+
],
186+
)
187+
def test_pyarrow_dtypes(dtypes, expected):
188+
"""Test finding common types with pyarrow dtypes not in numpy."""
189+
source_dtypes = [pandas_dtype(dtype) for dtype in dtypes]
190+
result = find_common_type(source_dtypes)
191+
assert result == pandas_dtype(expected)

0 commit comments

Comments
 (0)