@@ -863,17 +863,22 @@ def get_op_from_name(self, op_name):
863863 short_opname = op_name .strip ("_" )
864864 if short_opname == "rtruediv" :
865865 # use the numpy version that won't raise on division by zero
866- return lambda x , y : np .divide (y , x )
866+
867+ def rtruediv (x , y ):
868+ return np .divide (y , x )
869+
870+ return rtruediv
867871 elif short_opname == "rfloordiv" :
868872 return lambda x , y : np .floor_divide (y , x )
869873
870874 return tm .get_op_from_name (op_name )
871875
872- def _patch_combine (self , obj , other , op ):
876+ def _combine (self , obj , other , op ):
873877 # BaseOpsUtil._combine can upcast expected dtype
874878 # (because it generates expected on python scalars)
875879 # while ArrowExtensionArray maintains original type
876880 expected = base .BaseArithmeticOpsTests ._combine (self , obj , other , op )
881+
877882 was_frame = False
878883 if isinstance (expected , pd .DataFrame ):
879884 was_frame = True
@@ -883,10 +888,37 @@ def _patch_combine(self, obj, other, op):
883888 expected_data = expected
884889 original_dtype = obj .dtype
885890
891+ orig_pa_type = original_dtype .pyarrow_dtype
892+ if not was_frame and isinstance (other , pd .Series ):
893+ # i.e. test_arith_series_with_array
894+ if not (
895+ pa .types .is_floating (orig_pa_type )
896+ or (
897+ pa .types .is_integer (orig_pa_type )
898+ and op .__name__ not in ["truediv" , "rtruediv" ]
899+ )
900+ or pa .types .is_duration (orig_pa_type )
901+ or pa .types .is_timestamp (orig_pa_type )
902+ or pa .types .is_date (orig_pa_type )
903+ or pa .types .is_decimal (orig_pa_type )
904+ ):
905+ # base class _combine always returns int64, while
906+ # ArrowExtensionArray does not upcast
907+ return expected
908+ elif not (
909+ (op is operator .floordiv and pa .types .is_integer (orig_pa_type ))
910+ or pa .types .is_duration (orig_pa_type )
911+ or pa .types .is_timestamp (orig_pa_type )
912+ or pa .types .is_date (orig_pa_type )
913+ or pa .types .is_decimal (orig_pa_type )
914+ ):
915+ # base class _combine always returns int64, while
916+ # ArrowExtensionArray does not upcast
917+ return expected
918+
886919 pa_expected = pa .array (expected_data ._values )
887920
888921 if pa .types .is_duration (pa_expected .type ):
889- orig_pa_type = original_dtype .pyarrow_dtype
890922 if pa .types .is_date (orig_pa_type ):
891923 if pa .types .is_date64 (orig_pa_type ):
892924 # TODO: why is this different vs date32?
@@ -907,7 +939,7 @@ def _patch_combine(self, obj, other, op):
907939 pa_expected = pa_expected .cast (f"duration[{ unit } ]" )
908940
909941 elif pa .types .is_decimal (pa_expected .type ) and pa .types .is_decimal (
910- original_dtype . pyarrow_dtype
942+ orig_pa_type
911943 ):
912944 # decimal precision can resize in the result type depending on data
913945 # just compare the float values
@@ -929,7 +961,7 @@ def _patch_combine(self, obj, other, op):
929961 return expected .astype (alt_dtype )
930962
931963 else :
932- pa_expected = pa_expected .cast (original_dtype . pyarrow_dtype )
964+ pa_expected = pa_expected .cast (orig_pa_type )
933965
934966 pd_expected = type (expected_data ._values )(pa_expected )
935967 if was_frame :
@@ -1043,9 +1075,7 @@ def _get_arith_xfail_marker(self, opname, pa_dtype):
10431075
10441076 return mark
10451077
1046- def test_arith_series_with_scalar (
1047- self , data , all_arithmetic_operators , request , monkeypatch
1048- ):
1078+ def test_arith_series_with_scalar (self , data , all_arithmetic_operators , request ):
10491079 pa_dtype = data .dtype .pyarrow_dtype
10501080
10511081 if all_arithmetic_operators == "__rmod__" and (
@@ -1061,24 +1091,9 @@ def test_arith_series_with_scalar(
10611091 if mark is not None :
10621092 request .node .add_marker (mark )
10631093
1064- if (
1065- (
1066- all_arithmetic_operators == "__floordiv__"
1067- and pa .types .is_integer (pa_dtype )
1068- )
1069- or pa .types .is_duration (pa_dtype )
1070- or pa .types .is_timestamp (pa_dtype )
1071- or pa .types .is_date (pa_dtype )
1072- or pa .types .is_decimal (pa_dtype )
1073- ):
1074- # BaseOpsUtil._combine always returns int64, while ArrowExtensionArray does
1075- # not upcast
1076- monkeypatch .setattr (TestBaseArithmeticOps , "_combine" , self ._patch_combine )
10771094 super ().test_arith_series_with_scalar (data , all_arithmetic_operators )
10781095
1079- def test_arith_frame_with_scalar (
1080- self , data , all_arithmetic_operators , request , monkeypatch
1081- ):
1096+ def test_arith_frame_with_scalar (self , data , all_arithmetic_operators , request ):
10821097 pa_dtype = data .dtype .pyarrow_dtype
10831098
10841099 if all_arithmetic_operators == "__rmod__" and (
@@ -1094,24 +1109,9 @@ def test_arith_frame_with_scalar(
10941109 if mark is not None :
10951110 request .node .add_marker (mark )
10961111
1097- if (
1098- (
1099- all_arithmetic_operators == "__floordiv__"
1100- and pa .types .is_integer (pa_dtype )
1101- )
1102- or pa .types .is_duration (pa_dtype )
1103- or pa .types .is_timestamp (pa_dtype )
1104- or pa .types .is_date (pa_dtype )
1105- or pa .types .is_decimal (pa_dtype )
1106- ):
1107- # BaseOpsUtil._combine always returns int64, while ArrowExtensionArray does
1108- # not upcast
1109- monkeypatch .setattr (TestBaseArithmeticOps , "_combine" , self ._patch_combine )
11101112 super ().test_arith_frame_with_scalar (data , all_arithmetic_operators )
11111113
1112- def test_arith_series_with_array (
1113- self , data , all_arithmetic_operators , request , monkeypatch
1114- ):
1114+ def test_arith_series_with_array (self , data , all_arithmetic_operators , request ):
11151115 pa_dtype = data .dtype .pyarrow_dtype
11161116
11171117 self .series_array_exc = self ._get_scalar_exception (
@@ -1147,18 +1147,6 @@ def test_arith_series_with_array(
11471147 # since ser.iloc[0] is a python scalar
11481148 other = pd .Series (pd .array ([ser .iloc [0 ]] * len (ser ), dtype = data .dtype ))
11491149
1150- if (
1151- pa .types .is_floating (pa_dtype )
1152- or (
1153- pa .types .is_integer (pa_dtype )
1154- and all_arithmetic_operators not in ["__truediv__" , "__rtruediv__" ]
1155- )
1156- or pa .types .is_duration (pa_dtype )
1157- or pa .types .is_timestamp (pa_dtype )
1158- or pa .types .is_date (pa_dtype )
1159- or pa .types .is_decimal (pa_dtype )
1160- ):
1161- monkeypatch .setattr (TestBaseArithmeticOps , "_combine" , self ._patch_combine )
11621150 self .check_opname (ser , op_name , other , exc = self .series_array_exc )
11631151
11641152 def test_add_series_with_extension_array (self , data , request ):
0 commit comments