@@ -354,47 +354,32 @@ def check_accumulate(self, ser, op_name, skipna):
354354 expected = getattr (ser .astype ("Float64" ), op_name )(skipna = skipna )
355355 tm .assert_series_equal (result , expected , check_dtype = False )
356356
357- @pytest .mark .parametrize ("skipna" , [True , False ])
358- def test_accumulate_series_raises (self , data , all_numeric_accumulations , skipna ):
359- pa_type = data .dtype .pyarrow_dtype
360- if (
361- (
362- pa .types .is_integer (pa_type )
363- or pa .types .is_floating (pa_type )
364- or pa .types .is_duration (pa_type )
365- )
366- and all_numeric_accumulations == "cumsum"
367- and not pa_version_under9p0
368- ):
369- pytest .skip ("These work, are tested by test_accumulate_series." )
357+ def _supports_accumulation (self , ser : pd .Series , op_name : str ) -> bool :
358+ # error: Item "dtype[Any]" of "dtype[Any] | ExtensionDtype" has no
359+ # attribute "pyarrow_dtype"
360+ pa_type = ser .dtype .pyarrow_dtype # type: ignore[union-attr]
370361
371- op_name = all_numeric_accumulations
372- ser = pd .Series (data )
373-
374- with pytest .raises (NotImplementedError ):
375- getattr (ser , op_name )(skipna = skipna )
376-
377- @pytest .mark .parametrize ("skipna" , [True , False ])
378- def test_accumulate_series (self , data , all_numeric_accumulations , skipna , request ):
379- pa_type = data .dtype .pyarrow_dtype
380- op_name = all_numeric_accumulations
381- ser = pd .Series (data )
382-
383- do_skip = False
384362 if pa .types .is_string (pa_type ) or pa .types .is_binary (pa_type ):
385363 if op_name in ["cumsum" , "cumprod" ]:
386- do_skip = True
364+ return False
387365 elif pa .types .is_temporal (pa_type ) and not pa .types .is_duration (pa_type ):
388366 if op_name in ["cumsum" , "cumprod" ]:
389- do_skip = True
367+ return False
390368 elif pa .types .is_duration (pa_type ):
391369 if op_name == "cumprod" :
392- do_skip = True
370+ return False
371+ return True
393372
394- if do_skip :
395- pytest .skip (
396- f"{ op_name } should *not* work, we test in "
397- "test_accumulate_series_raises that these correctly raise."
373+ @pytest .mark .parametrize ("skipna" , [True , False ])
374+ def test_accumulate_series (self , data , all_numeric_accumulations , skipna , request ):
375+ pa_type = data .dtype .pyarrow_dtype
376+ op_name = all_numeric_accumulations
377+ ser = pd .Series (data )
378+
379+ if not self ._supports_accumulation (ser , op_name ):
380+ # The base class test will check that we raise
381+ return super ().test_accumulate_series (
382+ data , all_numeric_accumulations , skipna
398383 )
399384
400385 if all_numeric_accumulations != "cumsum" or pa_version_under9p0 :
0 commit comments