@@ -60,14 +60,22 @@ def mock_int_dtype(n: int, dtype: DataType) -> int:
6060 return n
6161
6262
63+ def default_filter (s : Scalar ) -> bool :
64+ """Returns False when s is a non-finite or a signed zero.
65+
66+ Used by default as these values are typically special-cased.
67+ """
68+ return math .isfinite (s ) and s is not - 0.0 and s is not + 0.0
69+
70+
6371def unary_assert_against_refimpl (
6472 func_name : str ,
6573 in_ : Array ,
6674 res : Array ,
6775 refimpl : Callable [[Scalar ], Scalar ],
6876 expr_template : str ,
6977 res_stype : Optional [ScalarType ] = None ,
70- filter_ : Callable [[Scalar ], bool ] = math . isfinite ,
78+ filter_ : Callable [[Scalar ], bool ] = default_filter ,
7179):
7280 if in_ .shape != res .shape :
7381 raise ValueError (f"{ res .shape = } , but should be { in_ .shape = } " )
@@ -114,7 +122,7 @@ def binary_assert_against_refimpl(
114122 left_sym : str = "x1" ,
115123 right_sym : str = "x2" ,
116124 res_name : str = "out" ,
117- filter_ : Callable [[Scalar ], bool ] = math . isfinite ,
125+ filter_ : Callable [[Scalar ], bool ] = default_filter ,
118126):
119127 in_stype = dh .get_scalar_type (left .dtype )
120128 if res_stype is None :
@@ -353,7 +361,7 @@ def binary_param_assert_against_refimpl(
353361 refimpl : Callable [[Scalar , Scalar ], Scalar ],
354362 expr_template : str ,
355363 res_stype : Optional [ScalarType ] = None ,
356- filter_ : Callable [[Scalar ], bool ] = math . isfinite ,
364+ filter_ : Callable [[Scalar ], bool ] = default_filter ,
357365):
358366 if ctx .right_is_scalar :
359367 assert filter_ (right ) # sanity check
@@ -429,36 +437,30 @@ def test_abs(ctx, data):
429437 )
430438
431439
432- @given (xps .arrays (dtype = xps .floating_dtypes (), shape = hh .shapes ()))
440+ @given (
441+ xps .arrays (
442+ dtype = xps .floating_dtypes (),
443+ shape = hh .shapes (),
444+ elements = {"min_value" : - 1 , "max_value" : 1 },
445+ )
446+ )
433447def test_acos (x ):
434- res = xp .acos (x )
435- ph .assert_dtype ("acos" , x .dtype , res .dtype )
436- ph .assert_shape ("acos" , res .shape , x .shape )
437- ONE = ah .one (x .shape , x .dtype )
438- # Here (and elsewhere), should technically be res.dtype, but this is the
439- # same as x.dtype, as tested by the type_promotion tests.
440- PI = ah .π (x .shape , x .dtype )
441- ZERO = ah .zero (x .shape , x .dtype )
442- domain = ah .inrange (x , - ONE , ONE )
443- codomain = ah .inrange (res , ZERO , PI )
444- # acos maps [-1, 1] to [0, pi]. Values outside this domain are mapped to
445- # nan, which is already tested in the special cases.
446- ah .assert_exactly_equal (domain , codomain )
448+ out = xp .acos (x )
449+ ph .assert_dtype ("acos" , x .dtype , out .dtype )
450+ ph .assert_shape ("acos" , out .shape , x .shape )
451+ unary_assert_against_refimpl ("acos" , x , out , math .acos , "acos({})={}" )
447452
448453
449- @given (xps .arrays (dtype = xps .floating_dtypes (), shape = hh .shapes ()))
454+ @given (
455+ xps .arrays (
456+ dtype = xps .floating_dtypes (), shape = hh .shapes (), elements = {"min_value" : 1 }
457+ )
458+ )
450459def test_acosh (x ):
451- res = xp .acosh (x )
452- ph .assert_dtype ("acosh" , x .dtype , res .dtype )
453- ph .assert_shape ("acosh" , res .shape , x .shape )
454- ONE = ah .one (x .shape , x .dtype )
455- INFINITY = ah .infinity (x .shape , x .dtype )
456- ZERO = ah .zero (x .shape , x .dtype )
457- domain = ah .inrange (x , ONE , INFINITY )
458- codomain = ah .inrange (res , ZERO , INFINITY )
459- # acosh maps [-1, inf] to [0, inf]. Values outside this domain are mapped
460- # to nan, which is already tested in the special cases.
461- ah .assert_exactly_equal (domain , codomain )
460+ out = xp .acosh (x )
461+ ph .assert_dtype ("acosh" , x .dtype , out .dtype )
462+ ph .assert_shape ("acosh" , out .shape , x .shape )
463+ unary_assert_against_refimpl ("acosh" , x , out , math .acosh , "acosh({})={}" )
462464
463465
464466@pytest .mark .parametrize ("ctx," , make_binary_params ("add" , xps .numeric_dtypes ()))
@@ -479,101 +481,56 @@ def test_add(ctx, data):
479481 )
480482
481483
482- @given (xps .arrays (dtype = xps .floating_dtypes (), shape = hh .shapes ()))
484+ @given (
485+ xps .arrays (
486+ dtype = xps .floating_dtypes (),
487+ shape = hh .shapes (),
488+ elements = {"min_value" : - 1 , "max_value" : 1 },
489+ )
490+ )
483491def test_asin (x ):
484492 out = xp .asin (x )
485493 ph .assert_dtype ("asin" , x .dtype , out .dtype )
486494 ph .assert_shape ("asin" , out .shape , x .shape )
487- ONE = ah .one (x .shape , x .dtype )
488- PI = ah .π (x .shape , x .dtype )
489- domain = ah .inrange (x , - ONE , ONE )
490- codomain = ah .inrange (out , - PI / 2 , PI / 2 )
491- # asin maps [-1, 1] to [-pi/2, pi/2]. Values outside this domain are
492- # mapped to nan, which is already tested in the special cases.
493- ah .assert_exactly_equal (domain , codomain )
495+ unary_assert_against_refimpl ("asin" , x , out , math .asin , "asin({})={}" )
494496
495497
496498@given (xps .arrays (dtype = xps .floating_dtypes (), shape = hh .shapes ()))
497499def test_asinh (x ):
498500 out = xp .asinh (x )
499501 ph .assert_dtype ("asinh" , x .dtype , out .dtype )
500502 ph .assert_shape ("asinh" , out .shape , x .shape )
501- INFINITY = ah .infinity (x .shape , x .dtype )
502- domain = ah .inrange (x , - INFINITY , INFINITY )
503- codomain = ah .inrange (out , - INFINITY , INFINITY )
504- # asinh maps [-inf, inf] to [-inf, inf]. Values outside this domain are
505- # mapped to nan, which is already tested in the special cases.
506- ah .assert_exactly_equal (domain , codomain )
503+ unary_assert_against_refimpl ("asinh" , x , out , math .asinh , "asinh({})={}" )
507504
508505
509506@given (xps .arrays (dtype = xps .floating_dtypes (), shape = hh .shapes ()))
510507def test_atan (x ):
511508 out = xp .atan (x )
512509 ph .assert_dtype ("atan" , x .dtype , out .dtype )
513510 ph .assert_shape ("atan" , out .shape , x .shape )
514- INFINITY = ah .infinity (x .shape , x .dtype )
515- PI = ah .π (x .shape , x .dtype )
516- domain = ah .inrange (x , - INFINITY , INFINITY )
517- codomain = ah .inrange (out , - PI / 2 , PI / 2 )
518- # atan maps [-inf, inf] to [-pi/2, pi/2]. Values outside this domain are
519- # mapped to nan, which is already tested in the special cases.
520- ah .assert_exactly_equal (domain , codomain )
511+ unary_assert_against_refimpl ("atan" , x , out , math .atan , "atan({})={}" )
521512
522513
523514@given (* hh .two_mutual_arrays (dh .float_dtypes ))
524515def test_atan2 (x1 , x2 ):
525516 out = xp .atan2 (x1 , x2 )
526517 ph .assert_dtype ("atan2" , [x1 .dtype , x2 .dtype ], out .dtype )
527518 ph .assert_result_shape ("atan2" , [x1 .shape , x2 .shape ], out .shape )
528- INFINITY1 = ah .infinity (x1 .shape , x1 .dtype )
529- INFINITY2 = ah .infinity (x2 .shape , x2 .dtype )
530- PI = ah .π (out .shape , out .dtype )
531- domainx1 = ah .inrange (x1 , - INFINITY1 , INFINITY1 )
532- domainx2 = ah .inrange (x2 , - INFINITY2 , INFINITY2 )
533- # codomain = ah.inrange(out, -PI, PI, 1e-5)
534- codomain = ah .inrange (out , - PI , PI )
535- # atan2 maps [-inf, inf] x [-inf, inf] to [-pi, pi]. Values outside
536- # this domain are mapped to nan, which is already tested in the special
537- # cases.
538- ah .assert_exactly_equal (ah .logical_and (domainx1 , domainx2 ), codomain )
539- # From the spec:
540- #
541- # The mathematical signs of `x1_i` and `x2_i` determine the quadrant of
542- # each element-wise out. The quadrant (i.e., branch) is chosen such
543- # that each element-wise out is the signed angle in radians between the
544- # ray ending at the origin and passing through the point `(1,0)` and the
545- # ray ending at the origin and passing through the point `(x2_i, x1_i)`.
546-
547- # This is equivalent to atan2(x1, x2) has the same sign as x1 when x2 is
548- # finite.
549- pos_x1 = ah .positive_mathematical_sign (x1 )
550- neg_x1 = ah .negative_mathematical_sign (x1 )
551- pos_x2 = ah .positive_mathematical_sign (x2 )
552- neg_x2 = ah .negative_mathematical_sign (x2 )
553- pos_out = ah .positive_mathematical_sign (out )
554- neg_out = ah .negative_mathematical_sign (out )
555- ah .assert_exactly_equal (
556- ah .logical_or (ah .logical_and (pos_x1 , pos_x2 ), ah .logical_and (pos_x1 , neg_x2 )),
557- pos_out ,
558- )
559- ah .assert_exactly_equal (
560- ah .logical_or (ah .logical_and (neg_x1 , pos_x2 ), ah .logical_and (neg_x1 , neg_x2 )),
561- neg_out ,
562- )
519+ binary_assert_against_refimpl ("atan2" , x1 , x2 , out , math .atan2 , "atan2({})={}" )
563520
564521
565- @given (xps .arrays (dtype = xps .floating_dtypes (), shape = hh .shapes ()))
522+ @given (
523+ xps .arrays (
524+ dtype = xps .floating_dtypes (),
525+ shape = hh .shapes (),
526+ elements = {"min_value" : - 1 , "max_value" : 1 },
527+ )
528+ )
566529def test_atanh (x ):
567530 out = xp .atanh (x )
568531 ph .assert_dtype ("atanh" , x .dtype , out .dtype )
569532 ph .assert_shape ("atanh" , out .shape , x .shape )
570- ONE = ah .one (x .shape , x .dtype )
571- INFINITY = ah .infinity (x .shape , x .dtype )
572- domain = ah .inrange (x , - ONE , ONE )
573- codomain = ah .inrange (out , - INFINITY , INFINITY )
574- # atanh maps [-1, 1] to [-inf, inf]. Values outside this domain are
575- # mapped to nan, which is already tested in the special cases.
576- ah .assert_exactly_equal (domain , codomain )
533+ unary_assert_against_refimpl ("atanh" , x , out , math .atanh , "atanh({})={}" )
577534
578535
579536@pytest .mark .parametrize (
0 commit comments