@@ -62,7 +62,8 @@ def mock_int_dtype(n: int, dtype: DataType) -> int:
6262#
6363# By default, floating-point functions/methods are loosely asserted against. Use
6464# `strict_check=True` when they should be strictly asserted against, i.e.
65- # when a function should return intergrals.
65+ # when a function should return intergrals. Likewise, use `strict_check=False`
66+ # when integer function/methods should be loosely asserted against.
6667
6768
6869def isclose (a : float , b : float , rel_tol : float = 0.25 , abs_tol : float = 1 ) -> bool :
@@ -92,7 +93,7 @@ def unary_assert_against_refimpl(
9293 expr_template : Optional [str ] = None ,
9394 res_stype : Optional [ScalarType ] = None ,
9495 filter_ : Callable [[Scalar ], bool ] = default_filter ,
95- strict_check : bool = False ,
96+ strict_check : Optional [ bool ] = None ,
9697):
9798 if in_ .shape != res .shape :
9899 raise ValueError (f"{ res .shape = } , but should be { in_ .shape = } " )
@@ -108,7 +109,7 @@ def unary_assert_against_refimpl(
108109 continue
109110 try :
110111 expected = refimpl (scalar_i )
111- except OverflowError :
112+ except Exception :
112113 continue
113114 if res .dtype != xp .bool :
114115 assert m is not None and M is not None # for mypy
@@ -118,7 +119,7 @@ def unary_assert_against_refimpl(
118119 f_i = sh .fmt_idx ("x" , idx )
119120 f_o = sh .fmt_idx ("out" , idx )
120121 expr = expr_template .format (f_i , expected )
121- if not strict_check and dh .is_float_dtype (res .dtype ):
122+ if strict_check == False or dh .is_float_dtype (res .dtype ):
122123 assert isclose (scalar_o , expected ), (
123124 f"{ f_o } ={ scalar_o } , but should be roughly { expr } [{ func_name } ()]\n "
124125 f"{ f_i } ={ scalar_i } "
@@ -142,7 +143,7 @@ def binary_assert_against_refimpl(
142143 right_sym : str = "x2" ,
143144 res_name : str = "out" ,
144145 filter_ : Callable [[Scalar ], bool ] = default_filter ,
145- strict_check : bool = False ,
146+ strict_check : Optional [ bool ] = None ,
146147):
147148 if expr_template is None :
148149 expr_template = func_name + "({}, {})={}"
@@ -157,7 +158,7 @@ def binary_assert_against_refimpl(
157158 continue
158159 try :
159160 expected = refimpl (scalar_l , scalar_r )
160- except OverflowError :
161+ except Exception :
161162 continue
162163 if res .dtype != xp .bool :
163164 assert m is not None and M is not None # for mypy
@@ -168,7 +169,7 @@ def binary_assert_against_refimpl(
168169 f_r = sh .fmt_idx (right_sym , r_idx )
169170 f_o = sh .fmt_idx (res_name , o_idx )
170171 expr = expr_template .format (f_l , f_r , expected )
171- if not strict_check and dh .is_float_dtype (res .dtype ):
172+ if strict_check == False or dh .is_float_dtype (res .dtype ):
172173 assert isclose (scalar_o , expected ), (
173174 f"{ f_o } ={ scalar_o } , but should be roughly { expr } [{ func_name } ()]\n "
174175 f"{ f_l } ={ scalar_l } , { f_r } ={ scalar_r } "
@@ -384,11 +385,12 @@ def binary_param_assert_against_refimpl(
384385 refimpl : Callable [[Scalar , Scalar ], Scalar ],
385386 res_stype : Optional [ScalarType ] = None ,
386387 filter_ : Callable [[Scalar ], bool ] = default_filter ,
387- strict_check : bool = False ,
388+ strict_check : Optional [ bool ] = None ,
388389):
389390 expr_template = "({} " + op_sym + " {})={}"
390391 if ctx .right_is_scalar :
391- assert filter_ (right ) # sanity check
392+ if filter_ (right ):
393+ return # short-circuit here as there will be nothing to test
392394 in_stype = dh .get_scalar_type (left .dtype )
393395 if res_stype is None :
394396 res_stype = in_stype
@@ -399,7 +401,7 @@ def binary_param_assert_against_refimpl(
399401 continue
400402 try :
401403 expected = refimpl (scalar_l , right )
402- except OverflowError :
404+ except Exception :
403405 continue
404406 if left .dtype != xp .bool :
405407 assert m is not None and M is not None # for mypy
@@ -409,7 +411,7 @@ def binary_param_assert_against_refimpl(
409411 f_l = sh .fmt_idx (ctx .left_sym , idx )
410412 f_o = sh .fmt_idx (ctx .res_name , idx )
411413 expr = expr_template .format (f_l , right , expected )
412- if not strict_check and dh .is_float_dtype (left .dtype ):
414+ if strict_check == False or dh .is_float_dtype (res .dtype ):
413415 assert isclose (scalar_o , expected ), (
414416 f"{ f_o } ={ scalar_o } , but should be roughly { expr } "
415417 f"[{ ctx .func_name } ()]\n "
@@ -704,16 +706,22 @@ def test_cosh(x):
704706def test_divide (ctx , data ):
705707 left = data .draw (ctx .left_strat , label = ctx .left_sym )
706708 right = data .draw (ctx .right_strat , label = ctx .right_sym )
709+ if ctx .right_is_scalar :
710+ assume
707711
708712 res = ctx .func (left , right )
709713
710714 binary_param_assert_dtype (ctx , left , right , res )
711715 binary_param_assert_shape (ctx , left , right , res )
712- # There isn't much we can test here. The spec doesn't require any behavior
713- # beyond the special cases, and indeed, there aren't many mathematical
714- # properties of division that strictly hold for floating-point numbers. We
715- # could test that this does implement IEEE 754 division, but we don't yet
716- # have those sorts in general for this module.
716+ binary_param_assert_against_refimpl (
717+ ctx ,
718+ left ,
719+ right ,
720+ res ,
721+ "/" ,
722+ operator .truediv ,
723+ filter_ = lambda s : math .isfinite (s ) and s != 0 ,
724+ )
717725
718726
719727@pytest .mark .parametrize ("ctx" , make_binary_params ("equal" , xps .scalar_dtypes ()))
@@ -836,17 +844,7 @@ def test_isfinite(x):
836844 out = ah .isfinite (x )
837845 ph .assert_dtype ("isfinite" , x .dtype , out .dtype , xp .bool )
838846 ph .assert_shape ("isfinite" , out .shape , x .shape )
839- if dh .is_int_dtype (x .dtype ):
840- ah .assert_exactly_equal (out , ah .true (x .shape ))
841- # Test that isfinite, isinf, and isnan are self-consistent.
842- inf = ah .logical_or (xp .isinf (x ), ah .isnan (x ))
843- ah .assert_exactly_equal (out , ah .logical_not (inf ))
844-
845- # Test the exact value by comparing to the math version
846- if dh .is_float_dtype (x .dtype ):
847- for idx in sh .ndindex (x .shape ):
848- s = float (x [idx ])
849- assert bool (out [idx ]) == math .isfinite (s )
847+ unary_assert_against_refimpl ("isfinite" , x , out , math .isfinite , res_stype = bool )
850848
851849
852850@given (xps .arrays (dtype = xps .numeric_dtypes (), shape = hh .shapes ()))
@@ -949,9 +947,10 @@ def test_log10(x):
949947def test_logaddexp (x1 , x2 ):
950948 out = xp .logaddexp (x1 , x2 )
951949 ph .assert_dtype ("logaddexp" , [x1 .dtype , x2 .dtype ], out .dtype )
952- # The spec doesn't require any behavior for this function. We could test
953- # that this is indeed an approximation of log(exp(x1) + exp(x2)), but we
954- # don't have tests for this sort of thing for any functions yet.
950+ ph .assert_result_shape ("logaddexp" , [x1 .shape , x2 .shape ], out .shape )
951+ binary_assert_against_refimpl (
952+ "logaddexp" , x1 , x2 , out , lambda l , r : math .log (math .exp (l ) + math .exp (r ))
953+ )
955954
956955
957956@given (* hh .two_mutual_arrays ([xp .bool ]))
@@ -1078,11 +1077,9 @@ def test_pow(ctx, data):
10781077
10791078 binary_param_assert_dtype (ctx , left , right , res )
10801079 binary_param_assert_shape (ctx , left , right , res )
1081- # There isn't much we can test here. The spec doesn't require any behavior
1082- # beyond the special cases, and indeed, there aren't many mathematical
1083- # properties of exponentiation that strictly hold for floating-point
1084- # numbers. We could test that this does implement IEEE 754 pow, but we
1085- # don't yet have those sorts in general for this module.
1080+ binary_param_assert_against_refimpl (
1081+ ctx , left , right , res , "**" , math .pow , strict_check = False
1082+ )
10861083
10871084
10881085@pytest .mark .parametrize ("ctx" , make_binary_params ("remainder" , xps .numeric_dtypes ()))
@@ -1110,28 +1107,14 @@ def test_round(x):
11101107 unary_assert_against_refimpl ("round" , x , out , round , strict_check = True )
11111108
11121109
1113- @given (xps .arrays (dtype = xps .numeric_dtypes (), shape = hh .shapes ()))
1110+ @given (xps .arrays (dtype = xps .numeric_dtypes (), shape = hh .shapes (), elements = finite_kw ))
11141111def test_sign (x ):
11151112 out = xp .sign (x )
11161113 ph .assert_dtype ("sign" , x .dtype , out .dtype )
11171114 ph .assert_shape ("sign" , out .shape , x .shape )
1118- scalar_type = dh .get_scalar_type (out .dtype )
1119- for idx in sh .ndindex (x .shape ):
1120- scalar_x = scalar_type (x [idx ])
1121- f_x = sh .fmt_idx ("x" , idx )
1122- if math .isnan (scalar_x ):
1123- continue
1124- if scalar_x == 0 :
1125- expected = 0
1126- expr = f"{ f_x } =0"
1127- else :
1128- expected = 1 if scalar_x > 0 else - 1
1129- expr = f"({ f_x } / |{ f_x } |)={ expected } "
1130- scalar_o = scalar_type (out [idx ])
1131- f_o = sh .fmt_idx ("out" , idx )
1132- assert (
1133- scalar_o == expected
1134- ), f"{ f_o } ={ scalar_o } , but should be { expr } [sign()]\n { f_x } ={ scalar_x } "
1115+ unary_assert_against_refimpl (
1116+ "sign" , x , out , lambda s : math .copysign (1 , s ), filter_ = lambda s : s != 0
1117+ )
11351118
11361119
11371120@given (xps .arrays (dtype = xps .floating_dtypes (), shape = hh .shapes ()))
0 commit comments