1212
1313
1414class BaseOpsUtil (BaseExtensionTests ):
15+ series_scalar_exc : type [Exception ] | None = TypeError
16+ frame_scalar_exc : type [Exception ] | None = TypeError
17+ series_array_exc : type [Exception ] | None = TypeError
18+ divmod_exc : type [Exception ] | None = TypeError
19+
20+ def _get_expected_exception (
21+ self , op_name : str , obj , other
22+ ) -> type [Exception ] | None :
23+ # Find the Exception, if any we expect to raise calling
24+ # obj.__op_name__(other)
25+
26+ # The self.obj_bar_exc pattern isn't great in part because it can depend
27+ # on op_name or dtypes, but we use it here for backward-compatibility.
28+ if op_name in ["__divmod__" , "__rdivmod__" ]:
29+ return self .divmod_exc
30+ if isinstance (obj , pd .Series ) and isinstance (other , pd .Series ):
31+ return self .series_array_exc
32+ elif isinstance (obj , pd .Series ):
33+ return self .series_scalar_exc
34+ else :
35+ return self .frame_scalar_exc
36+
1537 def _cast_pointwise_result (self , op_name : str , obj , other , pointwise_result ):
1638 # In _check_op we check that the result of a pointwise operation
1739 # (found via _combine) matches the result of the vectorized
@@ -24,17 +46,21 @@ def _cast_pointwise_result(self, op_name: str, obj, other, pointwise_result):
2446 def get_op_from_name (self , op_name : str ):
2547 return tm .get_op_from_name (op_name )
2648
27- def check_opname (self , ser : pd .Series , op_name : str , other , exc = Exception ):
28- op = self .get_op_from_name (op_name )
29-
30- self ._check_op (ser , op , other , op_name , exc )
31-
32- # Subclasses are not expected to need to override _check_op or _combine.
49+ # Subclasses are not expected to need to override check_opname, _check_op,
50+ # _check_divmod_op, or _combine.
3351 # Ideally any relevant overriding can be done in _cast_pointwise_result,
3452 # get_op_from_name, and the specification of `exc`. If you find a use
3553 # case that still requires overriding _check_op or _combine, please let
3654 # us know at github.com/pandas-dev/pandas/issues
3755 @final
56+ def check_opname (self , ser : pd .Series , op_name : str , other ):
57+ exc = self ._get_expected_exception (op_name , ser , other )
58+ op = self .get_op_from_name (op_name )
59+
60+ self ._check_op (ser , op , other , op_name , exc )
61+
62+ # see comment on check_opname
63+ @final
3864 def _combine (self , obj , other , op ):
3965 if isinstance (obj , pd .DataFrame ):
4066 if len (obj .columns ) != 1 :
@@ -44,11 +70,14 @@ def _combine(self, obj, other, op):
4470 expected = obj .combine (other , op )
4571 return expected
4672
47- # see comment on _combine
73+ # see comment on check_opname
4874 @final
4975 def _check_op (
5076 self , ser : pd .Series , op , other , op_name : str , exc = NotImplementedError
5177 ):
78+ # Check that the Series/DataFrame arithmetic/comparison method matches
79+ # the pointwise result from _combine.
80+
5281 if exc is None :
5382 result = op (ser , other )
5483 expected = self ._combine (ser , other , op )
@@ -59,8 +88,14 @@ def _check_op(
5988 with pytest .raises (exc ):
6089 op (ser , other )
6190
62- def _check_divmod_op (self , ser : pd .Series , op , other , exc = Exception ):
63- # divmod has multiple return values, so check separately
91+ # see comment on check_opname
92+ @final
93+ def _check_divmod_op (self , ser : pd .Series , op , other ):
94+ # check that divmod behavior matches behavior of floordiv+mod
95+ if op is divmod :
96+ exc = self ._get_expected_exception ("__divmod__" , ser , other )
97+ else :
98+ exc = self ._get_expected_exception ("__rdivmod__" , ser , other )
6499 if exc is None :
65100 result_div , result_mod = op (ser , other )
66101 if op is divmod :
@@ -96,26 +131,24 @@ def test_arith_series_with_scalar(self, data, all_arithmetic_operators):
96131 # series & scalar
97132 op_name = all_arithmetic_operators
98133 ser = pd .Series (data )
99- self .check_opname (ser , op_name , ser .iloc [0 ], exc = self . series_scalar_exc )
134+ self .check_opname (ser , op_name , ser .iloc [0 ])
100135
101136 def test_arith_frame_with_scalar (self , data , all_arithmetic_operators ):
102137 # frame & scalar
103138 op_name = all_arithmetic_operators
104139 df = pd .DataFrame ({"A" : data })
105- self .check_opname (df , op_name , data [0 ], exc = self . frame_scalar_exc )
140+ self .check_opname (df , op_name , data [0 ])
106141
107142 def test_arith_series_with_array (self , data , all_arithmetic_operators ):
108143 # ndarray & other series
109144 op_name = all_arithmetic_operators
110145 ser = pd .Series (data )
111- self .check_opname (
112- ser , op_name , pd .Series ([ser .iloc [0 ]] * len (ser )), exc = self .series_array_exc
113- )
146+ self .check_opname (ser , op_name , pd .Series ([ser .iloc [0 ]] * len (ser )))
114147
115148 def test_divmod (self , data ):
116149 ser = pd .Series (data )
117- self ._check_divmod_op (ser , divmod , 1 , exc = self . divmod_exc )
118- self ._check_divmod_op (1 , ops .rdivmod , ser , exc = self . divmod_exc )
150+ self ._check_divmod_op (ser , divmod , 1 )
151+ self ._check_divmod_op (1 , ops .rdivmod , ser )
119152
120153 def test_divmod_series_array (self , data , data_for_twos ):
121154 ser = pd .Series (data )
0 commit comments