@@ -397,7 +397,7 @@ def assert_scalar_equals(
397397 kw : dict = {},
398398):
399399 """
400- Assert a 0d array, convered to a scalar, is as expected, e.g.
400+ Assert a 0d array, converted to a scalar, is as expected, e.g.
401401
402402 >>> x = xp.ones(5, dtype=xp.uint8)
403403 >>> out = xp.sum(x)
@@ -407,6 +407,8 @@ def assert_scalar_equals(
407407
408408 >>> assert int(out) == 5
409409
410+ NOTE: This function does *exact* comparison, even for floats. For
411+ approximate float comparisons use assert_scalar_isclose
410412 """
411413 __tracebackhide__ = True
412414 repr_name = repr_name if idx == () else f"{ repr_name } [{ idx } ]"
@@ -418,8 +420,40 @@ def assert_scalar_equals(
418420 msg = f"{ repr_name } ={ out } , but should be { expected } [{ f_func } ]"
419421 assert cmath .isnan (out ), msg
420422 else :
421- msg = f"{ repr_name } ={ out } , but should be roughly { expected } [{ f_func } ]"
422- assert cmath .isclose (out , expected , rel_tol = 0.25 , abs_tol = 1 ), msg
423+ msg = f"{ repr_name } ={ out } , but should be { expected } [{ f_func } ]"
424+ assert out == expected , msg
425+
426+
427+ def assert_scalar_isclose (
428+ func_name : str ,
429+ * ,
430+ rel_tol : float = 0.25 ,
431+ abs_tol : float = 1 ,
432+ type_ : ScalarType ,
433+ idx : Shape ,
434+ out : Scalar ,
435+ expected : Scalar ,
436+ repr_name : str = "out" ,
437+ kw : dict = {},
438+ ):
439+ """
440+ Assert a 0d array, converted to a scalar, is close to the expected value, e.g.
441+
442+ >>> x = xp.ones(5., dtype=xp.float64)
443+ >>> out = xp.sum(x)
444+ >>> assert_scalar_isclose('sum', type_int, out=(), out=int(out), expected=5.)
445+
446+ is equivalent to
447+
448+ >>> assert math.isclose(float(out) == 5.)
449+
450+ """
451+ __tracebackhide__ = True
452+ repr_name = repr_name if idx == () else f"{ repr_name } [{ idx } ]"
453+ f_func = f"{ func_name } ({ fmt_kw (kw )} )"
454+ msg = f"{ repr_name } ={ out } , but should be roughly { expected } [{ f_func } ]"
455+ assert type_ in [float , complex ] # Sanity check
456+ assert cmath .isclose (out , expected , rel_tol = 0.25 , abs_tol = 1 ), msg
423457
424458
425459def assert_fill (
0 commit comments