11"""
22Test element-wise functions/operators against reference implementations.
33"""
4+ import cmath
45import math
56import operator
67from copy import copy
@@ -48,7 +49,7 @@ def mock_int_dtype(n: int, dtype: DataType) -> int:
4849def isclose (
4950 a : float ,
5051 b : float ,
51- M : float ,
52+ maximum : float ,
5253 * ,
5354 rel_tol : float = 0.25 ,
5455 abs_tol : float = 1 ,
@@ -61,12 +62,30 @@ def isclose(
6162 if math .isnan (a ) or math .isnan (b ):
6263 raise ValueError (f"{ a = } and { b = } , but input must be non-NaN" )
6364 if math .isinf (a ):
64- return math .isinf (b ) or abs (b ) > math .log (M )
65+ return math .isinf (b ) or abs (b ) > math .log (maximum )
6566 elif math .isinf (b ):
66- return math .isinf (a ) or abs (a ) > math .log (M )
67+ return math .isinf (a ) or abs (a ) > math .log (maximum )
6768 return math .isclose (a , b , rel_tol = rel_tol , abs_tol = abs_tol )
6869
6970
71+ def isclose_complex (
72+ a : complex ,
73+ b : complex ,
74+ maximum : float ,
75+ * ,
76+ rel_tol : float = 0.25 ,
77+ abs_tol : float = 1 ,
78+ ) -> bool :
79+ """Like isclose() but specifically for complex values."""
80+ if cmath .isnan (a ) or cmath .isnan (b ):
81+ raise ValueError (f"{ a = } and { b = } , but input must be non-NaN" )
82+ if cmath .isinf (a ):
83+ return cmath .isinf (b ) or abs (b ) > cmath .log (maximum )
84+ elif cmath .isinf (b ):
85+ return cmath .isinf (a ) or abs (a ) > cmath .log (maximum )
86+ return cmath .isclose (a , b , rel_tol = rel_tol , abs_tol = abs_tol )
87+
88+
7089def default_filter (s : Scalar ) -> bool :
7190 """Returns False when s is a non-finite or a signed zero.
7291
@@ -254,8 +273,7 @@ def unary_assert_against_refimpl(
254273 f"{ f_i } ={ scalar_i } "
255274 )
256275 if res .dtype in dh .complex_dtypes :
257- assert isclose (scalar_o .real , expected .real , M ), msg
258- assert isclose (scalar_o .imag , expected .imag , M ), msg
276+ assert isclose_complex (scalar_o , expected , M ), msg
259277 else :
260278 assert isclose (scalar_o , expected , M ), msg
261279 else :
@@ -330,8 +348,7 @@ def binary_assert_against_refimpl(
330348 f"{ f_l } ={ scalar_l } , { f_r } ={ scalar_r } "
331349 )
332350 if res .dtype in dh .complex_dtypes :
333- assert isclose (scalar_o .real , expected .real , M ), msg
334- assert isclose (scalar_o .imag , expected .imag , M ), msg
351+ assert isclose_complex (scalar_o , expected , M ), msg
335352 else :
336353 assert isclose (scalar_o , expected , M ), msg
337354 else :
@@ -403,8 +420,7 @@ def right_scalar_assert_against_refimpl(
403420 f"{ f_l } ={ scalar_l } "
404421 )
405422 if res .dtype in dh .complex_dtypes :
406- assert isclose (scalar_o .real , expected .real , M ), msg
407- assert isclose (scalar_o .imag , expected .imag , M ), msg
423+ assert isclose_complex (scalar_o , expected , M ), msg
408424 else :
409425 assert isclose (scalar_o , expected , M ), msg
410426 else :
@@ -1394,7 +1410,7 @@ def test_square(x):
13941410 ph .assert_dtype ("square" , in_dtype = x .dtype , out_dtype = out .dtype )
13951411 ph .assert_shape ("square" , out_shape = out .shape , expected = x .shape )
13961412 unary_assert_against_refimpl (
1397- "square" , x , out , lambda s : s ** 2 , expr_template = "{}²={}"
1413+ "square" , x , out , lambda s : s * s , expr_template = "{}²={}"
13981414 )
13991415
14001416
0 commit comments