Skip to content

Commit 468de3f

Browse files
committed
Added link to issue in test for take method.
1 parent cd2f2aa commit 468de3f

File tree

2 files changed

+65
-10
lines changed

2 files changed

+65
-10
lines changed

pandas/core/arrays/_mixins.py

Lines changed: 64 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,6 @@
4040
from pandas.core.dtypes.dtypes import (
4141
DatetimeTZDtype,
4242
ExtensionDtype,
43-
NumpyEADtype,
4443
PeriodDtype,
4544
)
4645
from pandas.core.dtypes.missing import array_equivalent
@@ -190,15 +189,70 @@ def take(
190189
# don't accept a dtype parameter, which I need to pass to set the
191190
# result's dtype to a floating-point type.
192191

193-
if self.dtype in [
194-
NumpyEADtype(np.uint8),
195-
NumpyEADtype(np.uint16),
196-
NumpyEADtype(np.uint32),
197-
NumpyEADtype(np.uint64),
198-
NumpyEADtype(np.int8),
199-
NumpyEADtype(np.int16),
200-
NumpyEADtype(np.int32),
201-
NumpyEADtype(np.int64),
192+
# All tests pass when I create a new extension array object with the
193+
# appropriate dtype (in the integer-dtype source case), however MyPy
194+
# complains about the missing dtype argument in the call to type(self)
195+
# below. By creating a new array object, this call produces an array
196+
# with a floating point dtype, even when the source dtype is integral.
197+
# I think this happens because the new array is created with the newly
198+
# produced data from the underlying take method, which has the
199+
# appropriate underlying dtype.
200+
201+
# Essentially, these extension arrays are wrappers around Numpy arrays
202+
# which have their own dtype and store the data. Thus, the new
203+
# extension array inherits the dtype from the Numpy array used
204+
# to create it.
205+
206+
# Unfortunately, some of the derived constructors of this class have a
207+
# positional dtype argument, while some do not. If I call a constructor
208+
# without specifying this argument, mypy will complain about the
209+
# missing argument in the case of constructors that require it, but
210+
# if I call the constructor with the dtype argument, the constructors
211+
# that don't have it will fail at runtime since they don't recognize
212+
# it.
213+
214+
# How can I get around this issue?
215+
# Ideas:
216+
# Modify the extension array type to allow modification of its dtype
217+
# after construction.
218+
219+
# Add a conditional branch to this method to call derived constructors
220+
# with or without the dtype argument, depending on their class.
221+
# This approach has the disadvantage of hardcoding information about
222+
# derived classes in this base class, which means that if someone
223+
# changes a constructor of a derived class to remove the dtype argument,
224+
# this method will break.
225+
226+
# Classes derived from this class include:
227+
228+
# Categorical
229+
# DatetimeLikeArrayMixin
230+
# DatelikeOps
231+
# PeriodArray
232+
# DatetimeArray
233+
# TimelikeOps
234+
# TimedeltaArray
235+
# NumpyExtensionArray
236+
# StringArray
237+
238+
# The types of extension arrays (within Pandas) derived from this class are:
239+
# Class name Constructor takes dtype argument Dtype argument required
240+
# Categorical yes no
241+
# PeriodArray yes no
242+
# DatetimeArray
243+
# TimedeltaArray
244+
# StringArray yes no
245+
# NumpyExtensionArray no no
246+
247+
if hasattr(self.dtype, "numpy_dtype") and self.dtype.numpy_dtype in [
248+
np.uint8,
249+
np.uint16,
250+
np.uint32,
251+
np.uint64,
252+
np.int8,
253+
np.int16,
254+
np.int32,
255+
np.int64,
202256
]:
203257
return type(self)(new_data)
204258
return self._from_backing_data(new_data)

pandas/tests/arrays/numpy_/test_numpy.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -327,6 +327,7 @@ def test_factorize_unsigned():
327327

328328
@pytest.mark.parametrize("dtype", [np.uint32, np.uint64, np.int32, np.int64])
329329
def test_take_assigns_correct_dtype(dtype):
330+
# GH#62448.
330331
array = NumpyExtensionArray(np.array([1, 2, 3], dtype=dtype))
331332

332333
result = array.take([-1], allow_fill=True)

0 commit comments

Comments
 (0)