@@ -1127,12 +1127,11 @@ def parse_binary_case_block(case_block: str) -> List[BinaryCase]:
11271127 return cases
11281128
11291129
1130- category_stub_pairs = [(c , s ) for c , stubs in category_to_funcs .items () for s in stubs ]
11311130unary_params = []
11321131binary_params = []
11331132iop_params = []
11341133func_to_op : Dict [str , str ] = {v : k for k , v in dh .op_to_func .items ()}
1135- for category , stub in category_stub_pairs :
1134+ for stub in category_to_funcs [ "elementwise" ] :
11361135 if stub .__doc__ is None :
11371136 warn (f"{ stub .__name__ } () stub has no docstring" )
11381137 continue
@@ -1153,56 +1152,51 @@ def parse_binary_case_block(case_block: str) -> List[BinaryCase]:
11531152 if len (sig .parameters ) == 0 :
11541153 warn (f"{ func = } has no parameters" )
11551154 continue
1156- if category == "elementwise" :
1157- if param_names [0 ] == "x" :
1158- if cases := parse_unary_case_block (case_block ):
1159- name_to_func = {stub .__name__ : func }
1160- if stub .__name__ in func_to_op .keys ():
1161- op_name = func_to_op [stub .__name__ ]
1162- op = getattr (operator , op_name )
1163- name_to_func [op_name ] = op
1164- for func_name , func in name_to_func .items ():
1165- for case in cases :
1166- id_ = f"{ func_name } ({ case .cond_expr } ) -> { case .result_expr } "
1167- p = pytest .param (func_name , func , case , id = id_ )
1168- unary_params .append (p )
1169- else :
1170- warn ("TODO" )
1171- continue
1172- if len (sig .parameters ) == 1 :
1173- warn (f"{ func = } has one parameter '{ param_names [0 ]} ' which is not named 'x'" )
1174- continue
1175- if param_names [0 ] == "x1" and param_names [1 ] == "x2" :
1176- if cases := parse_binary_case_block (case_block ):
1177- name_to_func = {stub .__name__ : func }
1178- if stub .__name__ in func_to_op .keys ():
1179- op_name = func_to_op [stub .__name__ ]
1180- op = getattr (operator , op_name )
1181- name_to_func [op_name ] = op
1182- # We collect inplace operator test cases seperately
1183- iop_name = "__i" + op_name [2 :]
1184- iop = getattr (operator , iop_name )
1185- for case in cases :
1186- id_ = f"{ iop_name } ({ case .cond_expr } ) -> { case .result_expr } "
1187- p = pytest .param (iop_name , iop , case , id = id_ )
1188- iop_params .append (p )
1189- for func_name , func in name_to_func .items ():
1190- for case in cases :
1191- id_ = f"{ func_name } ({ case .cond_expr } ) -> { case .result_expr } "
1192- p = pytest .param (func_name , func , case , id = id_ )
1193- binary_params .append (p )
1194- else :
1195- warn ("TODO" )
1196- continue
1155+ if param_names [0 ] == "x" :
1156+ if cases := parse_unary_case_block (case_block ):
1157+ name_to_func = {stub .__name__ : func }
1158+ if stub .__name__ in func_to_op .keys ():
1159+ op_name = func_to_op [stub .__name__ ]
1160+ op = getattr (operator , op_name )
1161+ name_to_func [op_name ] = op
1162+ for func_name , func in name_to_func .items ():
1163+ for case in cases :
1164+ id_ = f"{ func_name } ({ case .cond_expr } ) -> { case .result_expr } "
1165+ p = pytest .param (func_name , func , case , id = id_ )
1166+ unary_params .append (p )
11971167 else :
1198- warn (
1199- f"{ func = } starts with two parameters '{ param_names [0 ]} ' and "
1200- f"'{ param_names [1 ]} ', which are not named 'x1' and 'x2'"
1201- )
1202- elif category == "statistical" :
1203- pass # TODO
1168+ warn ("TODO" )
1169+ continue
1170+ if len (sig .parameters ) == 1 :
1171+ warn (f"{ func = } has one parameter '{ param_names [0 ]} ' which is not named 'x'" )
1172+ continue
1173+ if param_names [0 ] == "x1" and param_names [1 ] == "x2" :
1174+ if cases := parse_binary_case_block (case_block ):
1175+ name_to_func = {stub .__name__ : func }
1176+ if stub .__name__ in func_to_op .keys ():
1177+ op_name = func_to_op [stub .__name__ ]
1178+ op = getattr (operator , op_name )
1179+ name_to_func [op_name ] = op
1180+ # We collect inplace operator test cases seperately
1181+ iop_name = "__i" + op_name [2 :]
1182+ iop = getattr (operator , iop_name )
1183+ for case in cases :
1184+ id_ = f"{ iop_name } ({ case .cond_expr } ) -> { case .result_expr } "
1185+ p = pytest .param (iop_name , iop , case , id = id_ )
1186+ iop_params .append (p )
1187+ for func_name , func in name_to_func .items ():
1188+ for case in cases :
1189+ id_ = f"{ func_name } ({ case .cond_expr } ) -> { case .result_expr } "
1190+ p = pytest .param (func_name , func , case , id = id_ )
1191+ binary_params .append (p )
1192+ else :
1193+ warn ("TODO" )
1194+ continue
12041195 else :
1205- warn ("TODO" )
1196+ warn (
1197+ f"{ func = } starts with two parameters '{ param_names [0 ]} ' and "
1198+ f"'{ param_names [1 ]} ', which are not named 'x1' and 'x2'"
1199+ )
12061200
12071201
12081202# test_unary and test_binary naively generate arrays, i.e. arrays that might not
@@ -1342,3 +1336,24 @@ def test_iop(iop_name, iop, case, oneway_dtypes, oneway_shapes, data):
13421336 )
13431337 break
13441338 assume (good_example )
1339+
1340+
1341+ @pytest .mark .parametrize (
1342+ "func_name" , [f .__name__ for f in category_to_funcs ["statistical" ]]
1343+ )
1344+ @given (
1345+ x = xps .arrays (dtype = xps .floating_dtypes (), shape = hh .shapes (min_side = 1 )),
1346+ data = st .data (),
1347+ )
1348+ def test_nan_propagation (func_name , x , data ):
1349+ func = getattr (xp , func_name )
1350+ set_idx = data .draw (
1351+ xps .indices (x .shape , max_dims = 0 , allow_ellipsis = False ), label = "set idx"
1352+ )
1353+ x [set_idx ] = float ("nan" )
1354+ note (f"{ x = } " )
1355+
1356+ out = func (x )
1357+
1358+ ph .assert_shape (func_name , out .shape , ()) # sanity check
1359+ assert xp .isnan (out ), f"{ out = !r} , but should be NaN"
0 commit comments