Skip to content

Commit 91442a8

Browse files
committed
Moved changes to take method to NumpyExtensionArray class.
1 parent 468de3f commit 91442a8

File tree

3 files changed

+108
-31
lines changed

3 files changed

+108
-31
lines changed

pandas/core/arrays/_mixins.py

Lines changed: 23 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -155,24 +155,11 @@ def view(self, dtype: Dtype | None = None) -> ArrayLike:
155155
# "ExtensionDtype | dtype[Any]"; expected "dtype[Any] | _HasDType[dtype[Any]]"
156156
return arr.view(dtype=dtype) # type: ignore[arg-type]
157157

158-
def take(
159-
self,
160-
indices: TakeIndexer,
161-
*,
162-
allow_fill: bool = False,
163-
fill_value: Any = None,
164-
axis: AxisInt = 0,
165-
) -> Self:
166-
if allow_fill:
167-
fill_value = self._validate_scalar(fill_value)
168158

169-
new_data = take(
170-
self._ndarray,
171-
indices,
172-
allow_fill=allow_fill,
173-
fill_value=fill_value,
174-
axis=axis,
175-
)
159+
# Notes on take method fix
160+
# Please remove these once this fix is ready to submit.
161+
# =======================================================================
162+
176163
# One of the base classes to this class: ExtensionArray, provides
177164
# the dtype property, but abstractly, so it leaves the implementation
178165
# of dtype storage up to its derived classes. Some of these derived
@@ -244,17 +231,25 @@ def take(
244231
# StringArray yes no
245232
# NumpyExtensionArray no no
246233

247-
if hasattr(self.dtype, "numpy_dtype") and self.dtype.numpy_dtype in [
248-
np.uint8,
249-
np.uint16,
250-
np.uint32,
251-
np.uint64,
252-
np.int8,
253-
np.int16,
254-
np.int32,
255-
np.int64,
256-
]:
257-
return type(self)(new_data)
234+
235+
def take(
236+
self,
237+
indices: TakeIndexer,
238+
*,
239+
allow_fill: bool = False,
240+
fill_value: Any = None,
241+
axis: AxisInt = 0,
242+
) -> Self:
243+
if allow_fill:
244+
fill_value = self._validate_scalar(fill_value)
245+
246+
new_data = take(
247+
self._ndarray,
248+
indices,
249+
allow_fill=allow_fill,
250+
fill_value=fill_value,
251+
axis=axis,
252+
)
258253
return self._from_backing_data(new_data)
259254

260255
# ------------------------------------------------------------------------

pandas/core/arrays/numpy_.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
InterpolateOptions,
4646
NpDtype,
4747
Scalar,
48+
TakeIndexer,
4849
npt,
4950
)
5051

@@ -350,6 +351,73 @@ def interpolate(
350351
return self
351352
return type(self)._simple_new(out_data, dtype=self.dtype)
352353

354+
def take(
355+
self,
356+
indices: TakeIndexer,
357+
*,
358+
allow_fill: bool = False,
359+
fill_value: Any = None,
360+
axis: AxisInt = 0,
361+
) -> Self:
362+
"""
363+
Take entries from this array at each index in a list of indices,
364+
producing an array containing only those entries.
365+
"""
366+
# See GH#62448.
367+
if self.dtype.numpy_dtype in [
368+
np.uint8,
369+
np.uint16,
370+
np.uint32,
371+
np.uint64,
372+
np.int8,
373+
np.int16,
374+
np.int32,
375+
np.int64
376+
]:
377+
# In this case, the resulting extension array should have a floating-point
378+
# dtype to match the result of the underlying take method when
379+
# NaN values need to be incorporated into it.
380+
# This occurs when allow_fill is True and fill_value is None.
381+
# (fill_value may be an arbitrary Python object, in which case
382+
# the result will be an array of objects.)
383+
384+
# Call the take method of NDArrayBackedExtensionArray
385+
386+
# TODO: How is the dtype of a newly constructed NumpyExtensionArray set?
387+
# It's set to match the dtype of its underlying array.
388+
389+
result = super().take(
390+
indices,
391+
allow_fill=allow_fill,
392+
fill_value=fill_value,
393+
axis=axis
394+
)
395+
return type(self)(result, copy=False)
396+
397+
# In this case, the resulting extension array will have a dtype
398+
# that matches that of the underlying Numpy array and we can link
399+
# to the underlying array without manipulating the extension's
400+
# dtype.
401+
402+
return super().take(
403+
indices,
404+
allow_fill=allow_fill,
405+
fill_value=fill_value,
406+
axis=axis
407+
)
408+
# result array dtype = self dtype
409+
410+
# Implementation steps:
411+
# Determine requirements for this method, including:
412+
# Argument types [done]
413+
# Return type [done]
414+
# Return dtype [done]
415+
# Write tests to check whether this method satisfies these requirements. [done]
416+
# Figure out what base class method to call to implement the take functionality. [done]
417+
# Implement the call. [done]
418+
# Check whether this method satisfies its requirements by running the tests.
419+
420+
353421
# ------------------------------------------------------------------------
354422
# Reductions
355423

pandas/tests/arrays/numpy_/test_numpy.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -325,14 +325,28 @@ def test_factorize_unsigned():
325325
tm.assert_extension_array_equal(res_unique, NumpyExtensionArray(exp_unique))
326326

327327

328-
@pytest.mark.parametrize("dtype", [np.uint32, np.uint64, np.int32, np.int64])
329-
def test_take_assigns_correct_dtype(dtype):
328+
# TODO: Add the smaller width dtypes to the parameter sets of these tests.
329+
@pytest.mark.parametrize("dtype", [np.uint8, np.uint16, np.uint32, np.uint64, np.int8, np.int16, np.int32, np.int64])
330+
def test_take_assigns_floating_point_dtype(dtype):
330331
# GH#62448.
331332
array = NumpyExtensionArray(np.array([1, 2, 3], dtype=dtype))
332333

333334
result = array.take([-1], allow_fill=True)
334335

335-
assert result.dtype == NumpyEADtype(np.float64)
336+
assert result.dtype.numpy_dtype == np.float64
337+
338+
result = array.take([-1], allow_fill=True, fill_value=5.0)
339+
340+
assert result.dtype.numpy_dtype == np.float64
341+
342+
@pytest.mark.parametrize("dtype", [np.uint8, np.uint16, np.uint32, np.uint64, np.int8, np.int16, np.int32, np.int64])
343+
def test_take_assigns_integer_dtype_when_fill_disallowed(dtype):
344+
# GH#62448.
345+
array = NumpyExtensionArray(np.array([1, 2, 3], dtype=dtype))
346+
347+
result = array.take([-1], allow_fill=False)
348+
349+
assert result.dtype.numpy_dtype == dtype
336350

337351

338352
# ----------------------------------------------------------------------------

0 commit comments

Comments
 (0)