|
1 | 1 | """ |
2 | 2 | Test element-wise functions/operators against reference implementations. |
3 | 3 | """ |
| 4 | +import cmath |
4 | 5 | import math |
5 | 6 | import operator |
6 | 7 | from copy import copy |
@@ -67,6 +68,29 @@ def isclose( |
67 | 68 | return math.isclose(a, b, rel_tol=rel_tol, abs_tol=abs_tol) |
68 | 69 |
|
69 | 70 |
|
| 71 | +def isclose_complex( |
| 72 | + a: complex, |
| 73 | + b: complex, |
| 74 | + M: float, |
| 75 | + *, |
| 76 | + rel_tol: float = 0.25, |
| 77 | + abs_tol: float = 1, |
| 78 | +) -> bool: |
| 79 | + """Wraps math.isclose with very generous defaults. |
| 80 | +
|
| 81 | + This is useful for many floating-point operations where the spec does not |
| 82 | + make accuracy requirements. |
| 83 | + """ |
| 84 | + if cmath.isnan(a) or cmath.isnan(b): |
| 85 | + raise ValueError(f"{a=} and {b=}, but input must be non-NaN") |
| 86 | + if cmath.isinf(a): |
| 87 | + return cmath.isinf(b) or abs(b) > cmath.log(M) |
| 88 | + elif cmath.isinf(b): |
| 89 | + return cmath.isinf(a) or abs(a) > cmath.log(M) |
| 90 | + return cmath.isclose(a, b, rel_tol=rel_tol, abs_tol=abs_tol) |
| 91 | + |
| 92 | + |
| 93 | + |
70 | 94 | def default_filter(s: Scalar) -> bool: |
71 | 95 | """Returns False when s is a non-finite or a signed zero. |
72 | 96 |
|
@@ -254,8 +278,7 @@ def unary_assert_against_refimpl( |
254 | 278 | f"{f_i}={scalar_i}" |
255 | 279 | ) |
256 | 280 | if res.dtype in dh.complex_dtypes: |
257 | | - assert isclose(scalar_o.real, expected.real, M), msg |
258 | | - assert isclose(scalar_o.imag, expected.imag, M), msg |
| 281 | + assert isclose_complex(scalar_o, expected, M), msg |
259 | 282 | else: |
260 | 283 | assert isclose(scalar_o, expected, M), msg |
261 | 284 | else: |
@@ -330,8 +353,7 @@ def binary_assert_against_refimpl( |
330 | 353 | f"{f_l}={scalar_l}, {f_r}={scalar_r}" |
331 | 354 | ) |
332 | 355 | if res.dtype in dh.complex_dtypes: |
333 | | - assert isclose(scalar_o.real, expected.real, M), msg |
334 | | - assert isclose(scalar_o.imag, expected.imag, M), msg |
| 356 | + assert isclose_complex(scalar_o, expected, M), msg |
335 | 357 | else: |
336 | 358 | assert isclose(scalar_o, expected, M), msg |
337 | 359 | else: |
@@ -403,8 +425,7 @@ def right_scalar_assert_against_refimpl( |
403 | 425 | f"{f_l}={scalar_l}" |
404 | 426 | ) |
405 | 427 | if res.dtype in dh.complex_dtypes: |
406 | | - assert isclose(scalar_o.real, expected.real, M), msg |
407 | | - assert isclose(scalar_o.imag, expected.imag, M), msg |
| 428 | + assert isclose_complex(scalar_o, expected, M), msg |
408 | 429 | else: |
409 | 430 | assert isclose(scalar_o, expected, M), msg |
410 | 431 | else: |
@@ -1394,7 +1415,7 @@ def test_square(x): |
1394 | 1415 | ph.assert_dtype("square", in_dtype=x.dtype, out_dtype=out.dtype) |
1395 | 1416 | ph.assert_shape("square", out_shape=out.shape, expected=x.shape) |
1396 | 1417 | unary_assert_against_refimpl( |
1397 | | - "square", x, out, lambda s: s**2, expr_template="{}²={}" |
| 1418 | + "square", x, out, lambda s: s*s, expr_template="{}²={}" |
1398 | 1419 | ) |
1399 | 1420 |
|
1400 | 1421 |
|
|
0 commit comments