44
55import inspect
66import math
7+ import operator
78import re
89from dataclasses import dataclass , field
910from decimal import ROUND_HALF_EVEN , Decimal
2425from . import xps
2526from ._array_module import mod as xp
2627from .stubs import category_to_funcs
28+ from .test_operators_and_elementwise_functions import (
29+ oneway_broadcastable_shapes ,
30+ oneway_promotable_dtypes ,
31+ )
2732
2833pytestmark = pytest .mark .ci
2934
@@ -1138,6 +1143,8 @@ def parse_binary_docstring(docstring: str) -> List[BinaryCase]:
11381143
11391144unary_params = []
11401145binary_params = []
1146+ iop_params = []
1147+ func_to_op : Dict [str , str ] = {v : k for k , v in dh .op_to_func .items ()}
11411148for stub in category_to_funcs ["elementwise" ]:
11421149 if stub .__doc__ is None :
11431150 warn (f"{ stub .__name__ } () stub has no docstring" )
@@ -1157,20 +1164,39 @@ def parse_binary_docstring(docstring: str) -> List[BinaryCase]:
11571164 continue
11581165 if param_names [0 ] == "x" :
11591166 if cases := parse_unary_docstring (stub .__doc__ ):
1160- for case in cases :
1161- id_ = f"{ stub .__name__ } ({ case .cond_expr } ) -> { case .result_expr } "
1162- p = pytest .param (stub .__name__ , func , case , id = id_ )
1163- unary_params .append (p )
1167+ func_name_to_func = {stub .__name__ : func }
1168+ if stub .__name__ in func_to_op .keys ():
1169+ op_name = func_to_op [stub .__name__ ]
1170+ op = getattr (operator , op_name )
1171+ func_name_to_func [op_name ] = op
1172+ for func_name , func in func_name_to_func .items ():
1173+ for case in cases :
1174+ id_ = f"{ func_name } ({ case .cond_expr } ) -> { case .result_expr } "
1175+ p = pytest .param (func_name , func , case , id = id_ )
1176+ unary_params .append (p )
11641177 continue
11651178 if len (sig .parameters ) == 1 :
11661179 warn (f"{ func = } has one parameter '{ param_names [0 ]} ' which is not named 'x'" )
11671180 continue
11681181 if param_names [0 ] == "x1" and param_names [1 ] == "x2" :
11691182 if cases := parse_binary_docstring (stub .__doc__ ):
1170- for case in cases :
1171- id_ = f"{ stub .__name__ } ({ case .cond_expr } ) -> { case .result_expr } "
1172- p = pytest .param (stub .__name__ , func , case , id = id_ )
1173- binary_params .append (p )
1183+ func_name_to_func = {stub .__name__ : func }
1184+ if stub .__name__ in func_to_op .keys ():
1185+ op_name = func_to_op [stub .__name__ ]
1186+ op = getattr (operator , op_name )
1187+ func_name_to_func [op_name ] = op
1188+ # We collect inplaceoperator test cases seperately
1189+ iop_name = "__i" + op_name [2 :]
1190+ iop = getattr (operator , iop_name )
1191+ for case in cases :
1192+ id_ = f"{ iop_name } ({ case .cond_expr } ) -> { case .result_expr } "
1193+ p = pytest .param (iop_name , iop , case , id = id_ )
1194+ iop_params .append (p )
1195+ for func_name , func in func_name_to_func .items ():
1196+ for case in cases :
1197+ id_ = f"{ func_name } ({ case .cond_expr } ) -> { case .result_expr } "
1198+ p = pytest .param (func_name , func , case , id = id_ )
1199+ binary_params .append (p )
11741200 continue
11751201 else :
11761202 warn (
@@ -1264,3 +1290,55 @@ def test_binary(func_name, func, case, x1, x2, data):
12641290 )
12651291 break
12661292 assume (good_example )
1293+
1294+
1295+ @pytest .mark .parametrize ("iop_name, iop, case" , iop_params )
1296+ @given (
1297+ oneway_dtypes = oneway_promotable_dtypes (dh .float_dtypes ),
1298+ oneway_shapes = oneway_broadcastable_shapes (),
1299+ data = st .data (),
1300+ )
1301+ def test_iop (iop_name , iop , case , oneway_dtypes , oneway_shapes , data ):
1302+ x1 = data .draw (
1303+ xps .arrays (dtype = oneway_dtypes .result_dtype , shape = oneway_shapes .result_shape ),
1304+ label = "x1" ,
1305+ )
1306+ x2 = data .draw (
1307+ xps .arrays (dtype = oneway_dtypes .input_dtype , shape = oneway_shapes .input_shape ),
1308+ label = "x2" ,
1309+ )
1310+
1311+ all_indices = list (sh .iter_indices (x1 .shape , x2 .shape , x1 .shape ))
1312+
1313+ indices_strat = st .shared (st .sampled_from (all_indices ))
1314+ set_x1_idx = data .draw (indices_strat .map (lambda t : t [0 ]), label = "set x1 idx" )
1315+ set_x1_value = data .draw (case .x1_cond_from_dtype (x1 .dtype ), label = "set x1 value" )
1316+ x1 [set_x1_idx ] = set_x1_value
1317+ note (f"{ x1 = } " )
1318+ set_x2_idx = data .draw (indices_strat .map (lambda t : t [1 ]), label = "set x2 idx" )
1319+ set_x2_value = data .draw (case .x2_cond_from_dtype (x2 .dtype ), label = "set x2 value" )
1320+ x2 [set_x2_idx ] = set_x2_value
1321+ note (f"{ x2 = } " )
1322+
1323+ res = xp .asarray (x1 , copy = True )
1324+ iop (res , x2 )
1325+ # sanity check
1326+ ph .assert_result_shape (iop_name , [x1 .shape , x2 .shape ], res .shape )
1327+
1328+ good_example = False
1329+ for l_idx , r_idx , o_idx in all_indices :
1330+ l = float (x1 [l_idx ])
1331+ r = float (x2 [r_idx ])
1332+ if case .cond (l , r ):
1333+ good_example = True
1334+ o = float (res [o_idx ])
1335+ f_left = f"{ sh .fmt_idx ('x1' , l_idx )} ={ l } "
1336+ f_right = f"{ sh .fmt_idx ('x2' , r_idx )} ={ r } "
1337+ f_out = f"{ sh .fmt_idx ('out' , o_idx )} ={ o } "
1338+ assert case .check_result (l , r , o ), (
1339+ f"{ f_out } , but should be { case .result_expr } [{ iop_name } ()]\n "
1340+ f"condition: { case } \n "
1341+ f"{ f_left } , { f_right } "
1342+ )
1343+ break
1344+ assume (good_example )
0 commit comments