11from __future__ import annotations
22
3+ from functools import partial
34import re
45from typing import (
56 TYPE_CHECKING ,
2728)
2829from pandas .core .dtypes .missing import isna
2930
31+ from pandas .core .arrays ._arrow_string_mixins import ArrowStringArrayMixin
3032from pandas .core .arrays .arrow import ArrowExtensionArray
3133from pandas .core .arrays .boolean import BooleanDtype
3234from pandas .core .arrays .integer import Int64Dtype
@@ -113,10 +115,11 @@ class ArrowStringArray(ObjectStringArrayMixin, ArrowExtensionArray, BaseStringAr
113115 # error: Incompatible types in assignment (expression has type "StringDtype",
114116 # base class "ArrowExtensionArray" defined the type as "ArrowDtype")
115117 _dtype : StringDtype # type: ignore[assignment]
118+ _storage = "pyarrow"
116119
117120 def __init__ (self , values ) -> None :
118121 super ().__init__ (values )
119- self ._dtype = StringDtype (storage = "pyarrow" )
122+ self ._dtype = StringDtype (storage = self . _storage )
120123
121124 if not pa .types .is_string (self ._pa_array .type ) and not (
122125 pa .types .is_dictionary (self ._pa_array .type )
@@ -144,7 +147,10 @@ def _from_sequence(cls, scalars, dtype: Dtype | None = None, copy: bool = False)
144147
145148 if dtype and not (isinstance (dtype , str ) and dtype == "string" ):
146149 dtype = pandas_dtype (dtype )
147- assert isinstance (dtype , StringDtype ) and dtype .storage == "pyarrow"
150+ assert isinstance (dtype , StringDtype ) and dtype .storage in (
151+ "pyarrow" ,
152+ "pyarrow_numpy" ,
153+ )
148154
149155 if isinstance (scalars , BaseMaskedArray ):
150156 # avoid costly conversion to object dtype in ensure_string_array and
@@ -178,6 +184,10 @@ def insert(self, loc: int, item) -> ArrowStringArray:
178184 raise TypeError ("Scalar must be NA or str" )
179185 return super ().insert (loc , item )
180186
@classmethod
def _result_converter(cls, values, na=None):
    """Convert a pyarrow boolean result into a pandas array.

    ``values`` is handed to ``BooleanDtype().__from_arrow__``. ``na`` is
    accepted but not used by this implementation.
    """
    boolean_dtype = BooleanDtype()
    return boolean_dtype.__from_arrow__(values)
190+
181191 def _maybe_convert_setitem_value (self , value ):
182192 """Maybe convert value to be pyarrow compatible."""
183193 if is_scalar (value ):
@@ -313,7 +323,7 @@ def _str_contains(
313323 result = pc .match_substring_regex (self ._pa_array , pat , ignore_case = not case )
314324 else :
315325 result = pc .match_substring (self ._pa_array , pat , ignore_case = not case )
316- result = BooleanDtype (). __from_arrow__ (result )
326+ result = self . _result_converter (result , na = na )
317327 if not isna (na ):
318328 result [isna (result )] = bool (na )
319329 return result
@@ -322,7 +332,7 @@ def _str_startswith(self, pat: str, na=None):
322332 result = pc .starts_with (self ._pa_array , pattern = pat )
323333 if not isna (na ):
324334 result = result .fill_null (na )
325- result = BooleanDtype (). __from_arrow__ (result )
335+ result = self . _result_converter (result )
326336 if not isna (na ):
327337 result [isna (result )] = bool (na )
328338 return result
@@ -331,7 +341,7 @@ def _str_endswith(self, pat: str, na=None):
331341 result = pc .ends_with (self ._pa_array , pattern = pat )
332342 if not isna (na ):
333343 result = result .fill_null (na )
334- result = BooleanDtype (). __from_arrow__ (result )
344+ result = self . _result_converter (result )
335345 if not isna (na ):
336346 result [isna (result )] = bool (na )
337347 return result
@@ -369,39 +379,39 @@ def _str_fullmatch(
369379
def _str_isalnum(self):
    """Run pyarrow's ``utf8_is_alnum`` on the backing data and convert the result."""
    return self._result_converter(pc.utf8_is_alnum(self._pa_array))
373383
def _str_isalpha(self):
    """Run pyarrow's ``utf8_is_alpha`` on the backing data and convert the result."""
    return self._result_converter(pc.utf8_is_alpha(self._pa_array))
377387
def _str_isdecimal(self):
    """Run pyarrow's ``utf8_is_decimal`` on the backing data and convert the result."""
    return self._result_converter(pc.utf8_is_decimal(self._pa_array))
381391
def _str_isdigit(self):
    """Run pyarrow's ``utf8_is_digit`` on the backing data and convert the result."""
    return self._result_converter(pc.utf8_is_digit(self._pa_array))
385395
def _str_islower(self):
    """Run pyarrow's ``utf8_is_lower`` on the backing data and convert the result."""
    return self._result_converter(pc.utf8_is_lower(self._pa_array))
389399
def _str_isnumeric(self):
    """Run pyarrow's ``utf8_is_numeric`` on the backing data and convert the result."""
    return self._result_converter(pc.utf8_is_numeric(self._pa_array))
393403
def _str_isspace(self):
    """Run pyarrow's ``utf8_is_space`` on the backing data and convert the result."""
    return self._result_converter(pc.utf8_is_space(self._pa_array))
397407
def _str_istitle(self):
    """Run pyarrow's ``utf8_is_title`` on the backing data and convert the result."""
    return self._result_converter(pc.utf8_is_title(self._pa_array))
401411
def _str_isupper(self):
    """Run pyarrow's ``utf8_is_upper`` on the backing data and convert the result."""
    return self._result_converter(pc.utf8_is_upper(self._pa_array))
405415
406416 def _str_len (self ):
407417 result = pc .utf8_length (self ._pa_array )
@@ -433,3 +443,114 @@ def _str_rstrip(self, to_strip=None):
433443 else :
434444 result = pc .utf8_rtrim (self ._pa_array , characters = to_strip )
435445 return type (self )(result )
446+
447+
class ArrowStringArrayNumpySemantics(ArrowStringArray):
    """
    ``ArrowStringArray`` variant using the ``"pyarrow_numpy"`` storage name.

    The overrides below hand results back as NumPy arrays: boolean
    string-method results go through ``to_numpy(na_value=np.nan)``,
    comparisons become ``np.bool_`` arrays with missing entries set to
    ``False``, and ``int32`` counts/lengths/positions are widened to
    ``int64``.
    """

    # Storage name picked up by ``StringDtype(storage=self._storage)`` in
    # the parent ``__init__``.
    _storage = "pyarrow_numpy"

    @classmethod
    def _result_converter(cls, values, na=None):
        """
        Convert a pyarrow result to a NumPy array.

        Parameters
        ----------
        values : pyarrow array-like
            Result of a pyarrow compute kernel; must support ``fill_null``.
        na : scalar, optional
            When not missing (per ``isna``), nulls in ``values`` are replaced
            by ``bool(na)`` before conversion.

        Returns
        -------
        numpy.ndarray
            Remaining nulls are represented as ``np.nan``.
        """
        if not isna(na):
            values = values.fill_null(bool(na))
        return ArrowExtensionArray(values).to_numpy(na_value=np.nan)

    def __getattribute__(self, item):
        # ArrowStringArray and we both inherit from ArrowExtensionArray, which
        # creates inheritance problems (Diamond inheritance)
        # Anything defined directly on ArrowStringArrayMixin (other than
        # ``_pa_array``) is taken from the mixin and bound to this instance.
        if item in ArrowStringArrayMixin.__dict__ and item != "_pa_array":
            return partial(getattr(ArrowStringArrayMixin, item), self)
        return super().__getattribute__(item)

    def _str_map(
        self, f, na_value=None, dtype: Dtype | None = None, convert: bool = True
    ):
        """
        Map ``f`` over the non-missing elements of this array.

        Parameters
        ----------
        f : callable
            Function applied to each non-missing element.
        na_value : scalar, optional
            Value used at missing positions; defaults to ``self.dtype.na_value``.
            For integer and boolean ``dtype`` it is overridden with ``np.nan``
            and ``False`` respectively.
        dtype : Dtype, optional
            Requested result dtype; defaults to ``self.dtype``.
        convert : bool, default True
            Only consulted in the integer/boolean fallback path, where an
            object result is passed through ``lib.maybe_convert_objects``.

        Returns
        -------
        numpy.ndarray or same type as ``self``
            An instance of ``type(self)`` for string dtypes, otherwise the
            array returned by ``lib.map_infer_mask``.
        """
        if dtype is None:
            dtype = self.dtype
        if na_value is None:
            na_value = self.dtype.na_value

        mask = isna(self)
        arr = np.asarray(self)

        if is_integer_dtype(dtype) or is_bool_dtype(dtype):
            na_value = np.nan if is_integer_dtype(dtype) else False
            try:
                return lib.map_infer_mask(
                    arr,
                    f,
                    mask.view("uint8"),
                    convert=False,
                    na_value=na_value,
                    dtype=np.dtype(dtype),  # type: ignore[arg-type]
                )
            except ValueError:
                # Retry without forcing the dtype and optionally let pandas
                # infer a better dtype from the object result.
                result = lib.map_infer_mask(
                    arr,
                    f,
                    mask.view("uint8"),
                    convert=False,
                    na_value=na_value,
                )
                if convert and result.dtype == object:
                    result = lib.maybe_convert_objects(result)
                return result

        elif is_string_dtype(dtype) and not is_object_dtype(dtype):
            # i.e. StringDtype
            result = lib.map_infer_mask(
                arr, f, mask.view("uint8"), convert=False, na_value=na_value
            )
            result = pa.array(result, mask=mask, type=pa.string(), from_pandas=True)
            return type(self)(result)
        else:
            # This is when the result type is object. We reach this when
            # -> We know the result type is truly object (e.g. .encode returns bytes
            #    or .findall returns a list).
            # -> We don't know the result type. E.g. `.get` can return anything.
            return lib.map_infer_mask(arr, f, mask.view("uint8"))

    def _convert_int_dtype(self, result):
        """Return ``result`` widened to ``np.int64`` if it is ``np.int32``."""
        if result.dtype == np.int32:
            result = result.astype(np.int64)
        return result

    def _str_count(self, pat: str, flags: int = 0):
        """
        Count regex matches of ``pat`` in each element.

        Non-zero ``flags`` are delegated to the parent implementation;
        otherwise ``pc.count_substring_regex`` is used and the result is
        returned as a NumPy array.
        """
        if flags:
            return super()._str_count(pat, flags)
        result = pc.count_substring_regex(self._pa_array, pat).to_numpy()
        return self._convert_int_dtype(result)

    def _str_len(self):
        """Return ``pc.utf8_length`` of each element as a NumPy array."""
        result = pc.utf8_length(self._pa_array).to_numpy()
        return self._convert_int_dtype(result)

    def _str_find(self, sub: str, start: int = 0, end: int | None = None):
        """
        Find the first position of ``sub`` in each element.

        Parameters
        ----------
        sub : str
            Substring to search for.
        start : int, default 0
            Left edge of the search window.
        end : int, optional
            Right edge of the search window.

        Returns
        -------
        numpy.ndarray
            Positions from the pyarrow fast paths (``-1`` where not found),
            or whatever the parent ``_str_find`` returns for the remaining
            ``start``/``end`` combinations.
        """
        if start > 0 and end is not None:
            slices = pc.utf8_slice_codeunits(self._pa_array, start, stop=end)
            result = pc.find_substring(slices, sub)
            not_found = pc.equal(result, -1)
            # The match index is relative to the slice, which begins at
            # ``start``; shift by ``start`` (not by the window length) to get
            # the position in the original string.
            offset_result = pc.add(result, start)
            result = pc.if_else(not_found, result, offset_result)
        elif start == 0 and end is None:
            result = pc.find_substring(self._pa_array, sub)
        else:
            # Includes negative ``start``: the slice origin then depends on
            # each element's length, so no constant offset would be correct.
            return super()._str_find(sub, start, end)
        return self._convert_int_dtype(result.to_numpy())

    def _cmp_method(self, other, op):
        """Compare via the parent and return ``np.bool_`` with NA as ``False``."""
        result = super()._cmp_method(other, op)
        return result.to_numpy(np.bool_, na_value=False)

    def value_counts(self, dropna: bool = True):
        """
        Return the parent's ``value_counts`` with NumPy-backed values.

        The index and name of the parent result are kept; only the values are
        replaced by ``result._values.to_numpy()``.
        """
        from pandas import Series

        result = super().value_counts(dropna)
        return Series(
            result._values.to_numpy(), index=result.index, name=result.name, copy=False
        )
0 commit comments