@@ -49,7 +49,7 @@ def isclose(a: float, b: float, rel_tol: float = 0.25, abs_tol: float = 1) -> bo
4949
5050
5151def mock_int_dtype (n : int , dtype : DataType ) -> int :
52- """Returns equivalent of `n` that mocks `dtype` behaviour"""
52+ """Returns equivalent of `n` that mocks `dtype` behaviour. """
5353 nbits = dh .dtype_nbits [dtype ]
5454 mask = (1 << nbits ) - 1
5555 n &= mask
@@ -76,6 +76,7 @@ def unary_assert_against_refimpl(
7676 expr_template : Optional [str ] = None ,
7777 res_stype : Optional [ScalarType ] = None ,
7878 filter_ : Callable [[Scalar ], bool ] = default_filter ,
79+ strict_check : bool = False ,
7980):
8081 if in_ .shape != res .shape :
8182 raise ValueError (f"{ res .shape = } , but should be { in_ .shape = } " )
@@ -101,7 +102,7 @@ def unary_assert_against_refimpl(
101102 f_i = sh .fmt_idx ("x" , idx )
102103 f_o = sh .fmt_idx ("out" , idx )
103104 expr = expr_template .format (f_i , expected )
104- if dh .is_float_dtype (res .dtype ):
105+ if not strict_check and dh .is_float_dtype (res .dtype ):
105106 assert isclose (scalar_o , expected ), (
106107 f"{ f_o } ={ scalar_o } , but should be roughly { expr } [{ func_name } ()]\n "
107108 f"{ f_i } ={ scalar_i } "
@@ -125,6 +126,7 @@ def binary_assert_against_refimpl(
125126 right_sym : str = "x2" ,
126127 res_name : str = "out" ,
127128 filter_ : Callable [[Scalar ], bool ] = default_filter ,
129+ strict_check : bool = False ,
128130):
129131 if expr_template is None :
130132 expr_template = func_name + "({}, {})={}"
@@ -150,7 +152,7 @@ def binary_assert_against_refimpl(
150152 f_r = sh .fmt_idx (right_sym , r_idx )
151153 f_o = sh .fmt_idx (res_name , o_idx )
152154 expr = expr_template .format (f_l , f_r , expected )
153- if dh .is_float_dtype (res .dtype ):
155+ if not strict_check and dh .is_float_dtype (res .dtype ):
154156 assert isclose (scalar_o , expected ), (
155157 f"{ f_o } ={ scalar_o } , but should be roughly { expr } [{ func_name } ()]\n "
156158 f"{ f_l } ={ scalar_l } , { f_r } ={ scalar_r } "
@@ -366,6 +368,7 @@ def binary_param_assert_against_refimpl(
366368 refimpl : Callable [[Scalar , Scalar ], Scalar ],
367369 res_stype : Optional [ScalarType ] = None ,
368370 filter_ : Callable [[Scalar ], bool ] = default_filter ,
371+ strict_check : bool = False ,
369372):
370373 expr_template = "({} " + op_sym + " {})={}"
371374 if ctx .right_is_scalar :
@@ -390,7 +393,7 @@ def binary_param_assert_against_refimpl(
390393 f_l = sh .fmt_idx (ctx .left_sym , idx )
391394 f_o = sh .fmt_idx (ctx .res_name , idx )
392395 expr = expr_template .format (f_l , right , expected )
393- if dh .is_float_dtype (left .dtype ):
396+ if not strict_check and dh .is_float_dtype (left .dtype ):
394397 assert isclose (scalar_o , expected ), (
395398 f"{ f_o } ={ scalar_o } , but should be roughly { expr } "
396399 f"[{ ctx .func_name } ()]\n "
@@ -415,6 +418,7 @@ def binary_param_assert_against_refimpl(
415418 refimpl = refimpl ,
416419 expr_template = expr_template ,
417420 filter_ = filter_ ,
421+ strict_check = strict_check ,
418422 )
419423
420424
@@ -670,14 +674,7 @@ def test_ceil(x):
670674 out = xp .ceil (x )
671675 ph .assert_dtype ("ceil" , x .dtype , out .dtype )
672676 ph .assert_shape ("ceil" , out .shape , x .shape )
673- finite = ah .isfinite (x )
674- ah .assert_integral (out [finite ])
675- assert ah .all (ah .less_equal (x [finite ], out [finite ]))
676- assert ah .all (
677- ah .less_equal (out [finite ] - x [finite ], ah .one (x [finite ].shape , x .dtype ))
678- )
679- integers = ah .isintegral (x )
680- ah .assert_exactly_equal (out [integers ], x [integers ])
677+ unary_assert_against_refimpl ("ceil" , x , out , math .ceil , strict_check = True )
681678
682679
683680@given (xps .arrays (dtype = xps .floating_dtypes (), shape = hh .shapes ()))
@@ -759,18 +756,10 @@ def test_expm1(x):
759756
760757@given (xps .arrays (dtype = xps .numeric_dtypes (), shape = hh .shapes ()))
761758def test_floor (x ):
762- # This test is almost identical to test_ceil
763759 out = xp .floor (x )
764760 ph .assert_dtype ("floor" , x .dtype , out .dtype )
765761 ph .assert_shape ("floor" , out .shape , x .shape )
766- finite = ah .isfinite (x )
767- ah .assert_integral (out [finite ])
768- assert ah .all (ah .less_equal (out [finite ], x [finite ]))
769- assert ah .all (
770- ah .less_equal (x [finite ] - out [finite ], ah .one (x [finite ].shape , x .dtype ))
771- )
772- integers = ah .isintegral (x )
773- ah .assert_exactly_equal (out [integers ], x [integers ])
762+ unary_assert_against_refimpl ("floor" , x , out , math .floor , strict_check = True )
774763
775764
776765@pytest .mark .parametrize (
@@ -1122,29 +1111,9 @@ def test_remainder(ctx, data):
11221111@given (xps .arrays (dtype = xps .numeric_dtypes (), shape = hh .shapes ()))
11231112def test_round (x ):
11241113 out = xp .round (x )
1125-
11261114 ph .assert_dtype ("round" , x .dtype , out .dtype )
1127-
11281115 ph .assert_shape ("round" , out .shape , x .shape )
1129-
1130- # Test that the out is integral
1131- finite = ah .isfinite (x )
1132- ah .assert_integral (out [finite ])
1133-
1134- # round(x) should be the neaoutt integer to x. The case where there is a
1135- # tie (round to even) is already handled by the special cases tests.
1136-
1137- # This is the same strategy used in the mask in the
1138- # test_round_special_cases_one_arg_two_integers_equally_close special
1139- # cases test.
1140- floor = xp .floor (x )
1141- ceil = xp .ceil (x )
1142- over = xp .subtract (x , floor )
1143- under = xp .subtract (ceil , x )
1144- round_down = ah .less (over , under )
1145- round_up = ah .less (under , over )
1146- ah .assert_exactly_equal (out [round_down ], floor [round_down ])
1147- ah .assert_exactly_equal (out [round_up ], ceil [round_up ])
1116+ unary_assert_against_refimpl ("round" , x , out , round , strict_check = True )
11481117
11491118
11501119@given (xps .arrays (dtype = xps .numeric_dtypes (), shape = hh .shapes ()))
@@ -1246,8 +1215,4 @@ def test_trunc(x):
12461215 out = xp .trunc (x )
12471216 ph .assert_dtype ("trunc" , x .dtype , out .dtype )
12481217 ph .assert_shape ("trunc" , out .shape , x .shape )
1249- if dh .is_int_dtype (x .dtype ):
1250- ah .assert_exactly_equal (out , x )
1251- else :
1252- finite = ah .isfinite (x )
1253- ah .assert_integral (out [finite ])
1218+ unary_assert_against_refimpl ("trunc" , x , out , math .trunc , strict_check = True )
0 commit comments