@@ -303,7 +303,7 @@ def construct_array_type(self) -> type_t[BaseStringArray]:
303303 elif self .storage == "pyarrow" and self ._na_value is libmissing .NA :
304304 return ArrowStringArray
305305 elif self .storage == "python" :
306- return StringArrayNumpySemantics
306+ return StringArray
307307 else :
308308 return ArrowStringArray
309309
@@ -490,8 +490,10 @@ def _str_map_str_or_object(
490490 )
491491 # error: "BaseStringArray" has no attribute "_from_pyarrow_array"
492492 return self ._from_pyarrow_array (result ) # type: ignore[attr-defined]
493- # error: Too many arguments for "BaseStringArray"
494- return type (self )(result ) # type: ignore[call-arg]
493+ else :
494+ # StringArray
495+ # error: Too many arguments for "BaseStringArray"
496+ return type (self )(result , dtype = self .dtype ) # type: ignore[call-arg]
495497
496498 else :
497499 # This is when the result type is object. We reach this when
@@ -581,6 +583,8 @@ class StringArray(BaseStringArray, NumpyExtensionArray): # type: ignore[misc]
581583 nan-likes(``None``, ``np.nan``) for the ``values`` parameter
582584 in addition to strings and :attr:`pandas.NA`
583585
586+ dtype : StringDtype
587+ Dtype for the array.
584588 copy : bool, default False
585589 Whether to copy the array of data.
586590
@@ -635,36 +639,56 @@ class StringArray(BaseStringArray, NumpyExtensionArray): # type: ignore[misc]
635639
636640 # undo the NumpyExtensionArray hack
637641 _typ = "extension"
638- _storage = "python"
639- _na_value : libmissing .NAType | float = libmissing .NA
640642
def __init__(
    self, values, *, dtype: StringDtype | None = None, copy: bool = False
) -> None:
    """
    Initialise the array from ``values`` and attach ``dtype`` to it.

    Parameters
    ----------
    values : array-like
        Data to wrap; it is passed through ``extract_array`` first.
    dtype : StringDtype, optional
        Dtype for the array. ``StringDtype()`` is used when omitted.
    copy : bool, default False
        Forwarded to the parent initialiser.
    """
    resolved_dtype = StringDtype() if dtype is None else dtype
    unwrapped = extract_array(values)

    super().__init__(unwrapped, copy=copy)

    # NOTE(review): input that is already an instance of this class is not
    # re-validated, even if ``dtype`` differs from the input's own dtype --
    # confirm this is intended.
    same_type_input = isinstance(unwrapped, type(self))
    if not same_type_input:
        self._validate(resolved_dtype)

    NDArrayBacked.__init__(self, self._ndarray, resolved_dtype)
652658
653- def _validate (self ) -> None :
659+ def _validate (self , dtype : StringDtype ) -> None :
654660 """Validate that we only store NA or strings."""
655- if len (self ._ndarray ) and not lib .is_string_array (self ._ndarray , skipna = True ):
656- raise ValueError ("StringArray requires a sequence of strings or pandas.NA" )
657- if self ._ndarray .dtype != "object" :
658- raise ValueError (
659- "StringArray requires a sequence of strings or pandas.NA. Got "
660- f"'{ self ._ndarray .dtype } ' dtype instead."
661- )
662- # Check to see if need to convert Na values to pd.NA
663- if self ._ndarray .ndim > 2 :
664- # Ravel if ndims > 2 b/c no cythonized version available
665- lib .convert_nans_to_NA (self ._ndarray .ravel ("K" ))
661+
662+ if dtype ._na_value is libmissing .NA :
663+ if len (self ._ndarray ) and not lib .is_string_array (
664+ self ._ndarray , skipna = True
665+ ):
666+ raise ValueError (
667+ "StringArray requires a sequence of strings or pandas.NA"
668+ )
669+ if self ._ndarray .dtype != "object" :
670+ raise ValueError (
671+ "StringArray requires a sequence of strings or pandas.NA. Got "
672+ f"'{ self ._ndarray .dtype } ' dtype instead."
673+ )
674+ # Check to see if need to convert Na values to pd.NA
675+ if self ._ndarray .ndim > 2 :
676+ # Ravel if ndims > 2 b/c no cythonized version available
677+ lib .convert_nans_to_NA (self ._ndarray .ravel ("K" ))
678+ else :
679+ lib .convert_nans_to_NA (self ._ndarray )
666680 else :
667- lib .convert_nans_to_NA (self ._ndarray )
681+ # Validate that we only store NaN or strings.
682+ if len (self ._ndarray ) and not lib .is_string_array (
683+ self ._ndarray , skipna = True
684+ ):
685+ raise ValueError ("StringArray requires a sequence of strings or NaN" )
686+ if self ._ndarray .dtype != "object" :
687+ raise ValueError (
688+ "StringArray requires a sequence of strings "
689+ "or NaN. Got '{self._ndarray.dtype}' dtype instead."
690+ )
691+ # TODO validate or force NA/None to NaN
668692
669693 def _validate_scalar (self , value ):
670694 # used by NDArrayBackedExtensionIndex.insert
@@ -732,8 +756,8 @@ def _cast_pointwise_result(self, values) -> ArrayLike:
732756 @classmethod
733757 def _empty (cls , shape , dtype ) -> StringArray :
734758 values = np .empty (shape , dtype = object )
735- values [:] = libmissing . NA
736- return cls (values ).astype (dtype , copy = False )
759+ values [:] = dtype . na_value
760+ return cls (values , dtype = dtype ).astype (dtype , copy = False )
737761
738762 def __arrow_array__ (self , type = None ):
739763 """
@@ -933,7 +957,7 @@ def _accumulate(self, name: str, *, skipna: bool = True, **kwargs) -> StringArra
933957 if self ._hasna :
934958 na_mask = cast ("npt.NDArray[np.bool_]" , isna (ndarray ))
935959 if np .all (na_mask ):
936- return type (self )(ndarray )
960+ return type (self )(ndarray , dtype = self . dtype )
937961 if skipna :
938962 if name == "cumsum" :
939963 ndarray = np .where (na_mask , "" , ndarray )
@@ -967,7 +991,7 @@ def _accumulate(self, name: str, *, skipna: bool = True, **kwargs) -> StringArra
967991 # Argument 2 to "where" has incompatible type "NAType | float"
968992 np_result = np .where (na_mask , self .dtype .na_value , np_result ) # type: ignore[arg-type]
969993
970- result = type (self )(np_result )
994+ result = type (self )(np_result , dtype = self . dtype )
971995 return result
972996
973997 def _wrap_reduction_result (self , axis : AxisInt | None , result ) -> Any :
@@ -1046,7 +1070,7 @@ def _cmp_method(self, other, op):
10461070 and other .dtype .na_value is libmissing .NA
10471071 ):
10481072 # NA has priority of NaN semantics
1049- return NotImplemented
1073+ return op ( self . astype ( other . dtype , copy = False ), other )
10501074
10511075 if isinstance (other , ArrowExtensionArray ):
10521076 if isinstance (other , BaseStringArray ):
@@ -1099,29 +1123,3 @@ def _cmp_method(self, other, op):
10991123 return res_arr
11001124
11011125 _arith_method = _cmp_method
1102-
1103-
1104- class StringArrayNumpySemantics (StringArray ):
1105- _storage = "python"
1106- _na_value = np .nan
1107-
1108- def _validate (self ) -> None :
1109- """Validate that we only store NaN or strings."""
1110- if len (self ._ndarray ) and not lib .is_string_array (self ._ndarray , skipna = True ):
1111- raise ValueError (
1112- "StringArrayNumpySemantics requires a sequence of strings or NaN"
1113- )
1114- if self ._ndarray .dtype != "object" :
1115- raise ValueError (
1116- "StringArrayNumpySemantics requires a sequence of strings or NaN. Got "
1117- f"'{ self ._ndarray .dtype } ' dtype instead."
1118- )
1119- # TODO validate or force NA/None to NaN
1120-
1121- @classmethod
1122- def _from_sequence (
1123- cls , scalars , * , dtype : Dtype | None = None , copy : bool = False
1124- ) -> Self :
1125- if dtype is None :
1126- dtype = StringDtype (storage = "python" , na_value = np .nan )
1127- return super ()._from_sequence (scalars , dtype = dtype , copy = copy )
0 commit comments