@@ -1791,6 +1791,10 @@ def _check_binary_with_scalars(func_data, x1x2):
17911791 )
17921792
17931793
1794+ def _filter_zero (x ):
1795+ return x != 0 if dh .is_scalar (x ) else (not xp .any (x == 0 ))
1796+
1797+
17941798@pytest .mark .min_version ("2024.12" )
17951799@pytest .mark .parametrize ('func_data' ,
17961800 # xp_func, name, refimpl, kwargs, expected_dtype
@@ -1812,11 +1816,19 @@ def _check_binary_with_scalars(func_data, x1x2):
18121816 (xp .less_equal , "les_equal" , operator .le , {}, xp .bool ),
18131817 (xp .greater , "greater" , operator .gt , {}, xp .bool ),
18141818 (xp .greater_equal , "greater_equal" , operator .ge , {}, xp .bool ),
1819+ (xp .remainder , "remainder" , operator .mod , {}, None ),
1820+ (xp .floor_divide , "floor_divide" , operator .floordiv , {}, None ),
18151821 ],
18161822 ids = lambda func_data : func_data [1 ] # use names for test IDs
18171823)
18181824@given (x1x2 = hh .array_and_py_scalar (dh .real_float_dtypes ))
18191825def test_binary_with_scalars_real (func_data , x1x2 ):
1826+
1827+ if func_data [1 ] == "remainder" :
1828+ assume (_filter_zero (x1x2 [1 ]))
1829+ if func_data [1 ] == "floor_divide" :
1830+ assume (_filter_zero (x1x2 [0 ]) and _filter_zero (x1x2 [1 ]))
1831+
18201832 _check_binary_with_scalars (func_data , x1x2 )
18211833
18221834
0 commit comments