|
1 | 1 | from __future__ import annotations |
2 | 2 |
|
3 | | -from functools import partial |
4 | 3 | import operator |
5 | 4 | import re |
6 | 5 | from typing import ( |
@@ -209,12 +208,17 @@ def dtype(self) -> StringDtype: # type: ignore[override] |
209 | 208 | return self._dtype |
210 | 209 |
|
211 | 210 | def insert(self, loc: int, item) -> ArrowStringArray: |
| 211 | + if self.dtype.na_value is np.nan and item is np.nan: |
| 212 | + item = libmissing.NA |
212 | 213 | if not isinstance(item, str) and item is not libmissing.NA: |
213 | 214 | raise TypeError("Scalar must be NA or str") |
214 | 215 | return super().insert(loc, item) |
215 | 216 |
|
216 | | - @classmethod |
217 | | - def _result_converter(cls, values, na=None): |
| 217 | + def _result_converter(self, values, na=None): |
| 218 | + if self.dtype.na_value is np.nan: |
| 219 | + if not isna(na): |
| 220 | + values = values.fill_null(bool(na)) |
| 221 | + return ArrowExtensionArray(values).to_numpy(na_value=np.nan) |
218 | 222 | return BooleanDtype().__from_arrow__(values) |
219 | 223 |
|
220 | 224 | def _maybe_convert_setitem_value(self, value): |
@@ -494,11 +498,30 @@ def _str_get_dummies(self, sep: str = "|"): |
494 | 498 | return dummies.astype(np.int64, copy=False), labels |
495 | 499 |
|
496 | 500 | def _convert_int_dtype(self, result): |
| 501 | + if self.dtype.na_value is np.nan: |
| 502 | + if isinstance(result, pa.Array): |
| 503 | + result = result.to_numpy(zero_copy_only=False) |
| 504 | + else: |
| 505 | + result = result.to_numpy() |
| 506 | + if result.dtype == np.int32: |
| 507 | + result = result.astype(np.int64) |
| 508 | + return result |
| 509 | + |
497 | 510 | return Int64Dtype().__from_arrow__(result) |
498 | 511 |
|
499 | 512 | def _reduce( |
500 | 513 | self, name: str, *, skipna: bool = True, keepdims: bool = False, **kwargs |
501 | 514 | ): |
| 515 | + if self.dtype.na_value is np.nan and name in ["any", "all"]: |
| 516 | + if not skipna: |
| 517 | + nas = pc.is_null(self._pa_array) |
| 518 | + arr = pc.or_kleene(nas, pc.not_equal(self._pa_array, "")) |
| 519 | + else: |
| 520 | + arr = pc.not_equal(self._pa_array, "") |
| 521 | + return ArrowExtensionArray(arr)._reduce( |
| 522 | + name, skipna=skipna, keepdims=keepdims, **kwargs |
| 523 | + ) |
| 524 | + |
502 | 525 | result = self._reduce_calc(name, skipna=skipna, keepdims=keepdims, **kwargs) |
503 | 526 | if name in ("argmin", "argmax") and isinstance(result, pa.Array): |
504 | 527 | return self._convert_int_dtype(result) |
@@ -529,67 +552,31 @@ def _rank( |
529 | 552 | ) |
530 | 553 | ) |
531 | 554 |
|
532 | | - |
533 | | -class ArrowStringArrayNumpySemantics(ArrowStringArray): |
534 | | - _storage = "pyarrow" |
535 | | - _na_value = np.nan |
536 | | - |
537 | | - @classmethod |
538 | | - def _result_converter(cls, values, na=None): |
539 | | - if not isna(na): |
540 | | - values = values.fill_null(bool(na)) |
541 | | - return ArrowExtensionArray(values).to_numpy(na_value=np.nan) |
542 | | - |
543 | | - def __getattribute__(self, item): |
544 | | - # ArrowStringArray and we both inherit from ArrowExtensionArray, which |
545 | | - # creates inheritance problems (Diamond inheritance) |
546 | | - if item in ArrowStringArrayMixin.__dict__ and item not in ( |
547 | | - "_pa_array", |
548 | | - "__dict__", |
549 | | - ): |
550 | | - return partial(getattr(ArrowStringArrayMixin, item), self) |
551 | | - return super().__getattribute__(item) |
552 | | - |
553 | | - def _convert_int_dtype(self, result): |
554 | | - if isinstance(result, pa.Array): |
555 | | - result = result.to_numpy(zero_copy_only=False) |
556 | | - else: |
557 | | - result = result.to_numpy() |
558 | | - if result.dtype == np.int32: |
559 | | - result = result.astype(np.int64) |
| 555 | + def value_counts(self, dropna: bool = True) -> Series: |
| 556 | + result = super().value_counts(dropna=dropna) |
| 557 | + if self.dtype.na_value is np.nan: |
| 558 | + res_values = result._values.to_numpy() |
| 559 | + return result._constructor( |
| 560 | + res_values, index=result.index, name=result.name, copy=False |
| 561 | + ) |
560 | 562 | return result |
561 | 563 |
|
562 | 564 | def _cmp_method(self, other, op): |
563 | 565 | result = super()._cmp_method(other, op) |
564 | | - if op == operator.ne: |
565 | | - return result.to_numpy(np.bool_, na_value=True) |
566 | | - else: |
567 | | - return result.to_numpy(np.bool_, na_value=False) |
568 | | - |
569 | | - def value_counts(self, dropna: bool = True) -> Series: |
570 | | - from pandas import Series |
571 | | - |
572 | | - result = super().value_counts(dropna) |
573 | | - return Series( |
574 | | - result._values.to_numpy(), index=result.index, name=result.name, copy=False |
575 | | - ) |
576 | | - |
577 | | - def _reduce( |
578 | | - self, name: str, *, skipna: bool = True, keepdims: bool = False, **kwargs |
579 | | - ): |
580 | | - if name in ["any", "all"]: |
581 | | - if not skipna and name == "all": |
582 | | - nas = pc.invert(pc.is_null(self._pa_array)) |
583 | | - arr = pc.and_kleene(nas, pc.not_equal(self._pa_array, "")) |
| 566 | + if self.dtype.na_value is np.nan: |
| 567 | + if op == operator.ne: |
| 568 | + return result.to_numpy(np.bool_, na_value=True) |
584 | 569 | else: |
585 | | - arr = pc.not_equal(self._pa_array, "") |
586 | | - return ArrowExtensionArray(arr)._reduce( |
587 | | - name, skipna=skipna, keepdims=keepdims, **kwargs |
588 | | - ) |
589 | | - else: |
590 | | - return super()._reduce(name, skipna=skipna, keepdims=keepdims, **kwargs) |
| 570 | + return result.to_numpy(np.bool_, na_value=False) |
| 571 | + return result |
591 | 572 |
|
592 | | - def insert(self, loc: int, item) -> ArrowStringArrayNumpySemantics: |
593 | | - if item is np.nan: |
594 | | - item = libmissing.NA |
595 | | - return super().insert(loc, item) # type: ignore[return-value] |
| 573 | + |
| 574 | +class ArrowStringArrayNumpySemantics(ArrowStringArray): |
| 575 | + _na_value = np.nan |
| 576 | + _str_get = ArrowStringArrayMixin._str_get |
| 577 | + _str_removesuffix = ArrowStringArrayMixin._str_removesuffix |
| 578 | + _str_capitalize = ArrowStringArrayMixin._str_capitalize |
| 579 | + _str_pad = ArrowStringArrayMixin._str_pad |
| 580 | + _str_title = ArrowStringArrayMixin._str_title |
| 581 | + _str_swapcase = ArrowStringArrayMixin._str_swapcase |
| 582 | + _str_slice_replace = ArrowStringArrayMixin._str_slice_replace |
0 commit comments