@@ -703,9 +703,9 @@ def test_abs(ctx, data):
703703 abs , # type: ignore
704704 res_stype = float if x .dtype in dh .complex_dtypes else None ,
705705 expr_template = "abs({})={}" ,
706- filter_ = lambda s : (
707- s == float ("infinity" ) or (math .isfinite (s ) and not ph .is_neg_zero (s ))
708- ),
706+ # filter_=lambda s: (
707+ # s == float("infinity") or (cmath .isfinite(s) and not ph.is_neg_zero(s))
708+ # ),
709709 )
710710
711711
@@ -714,8 +714,10 @@ def test_acos(x):
714714 out = xp .acos (x )
715715 ph .assert_dtype ("acos" , in_dtype = x .dtype , out_dtype = out .dtype )
716716 ph .assert_shape ("acos" , out_shape = out .shape , expected = x .shape )
717+ refimpl = cmath .acos if x .dtype in dh .complex_dtypes else math .acos
718+ filter_ = default_filter if x .dtype in dh .complex_dtypes else lambda s : default_filter (s ) and - 1 <= s <= 1
717719 unary_assert_against_refimpl (
718- "acos" , x , out , math . acos , filter_ = lambda s : default_filter ( s ) and - 1 <= s <= 1
720+ "acos" , x , out , refimpl , filter_ = filter_
719721 )
720722
721723
@@ -724,8 +726,10 @@ def test_acosh(x):
724726 out = xp .acosh (x )
725727 ph .assert_dtype ("acosh" , in_dtype = x .dtype , out_dtype = out .dtype )
726728 ph .assert_shape ("acosh" , out_shape = out .shape , expected = x .shape )
729+ refimpl = cmath .acosh if x .dtype in dh .complex_dtypes else math .acosh
730+ filter_ = default_filter if x .dtype in dh .complex_dtypes else lambda s : default_filter (s ) and s >= 1
727731 unary_assert_against_refimpl (
728- "acosh" , x , out , math . acosh , filter_ = lambda s : default_filter ( s ) and s >= 1
732+ "acosh" , x , out , refimpl , filter_ = filter_
729733 )
730734
731735
@@ -748,8 +752,10 @@ def test_asin(x):
748752 out = xp .asin (x )
749753 ph .assert_dtype ("asin" , in_dtype = x .dtype , out_dtype = out .dtype )
750754 ph .assert_shape ("asin" , out_shape = out .shape , expected = x .shape )
755+ refimpl = cmath .asin if x .dtype in dh .complex_dtypes else math .asin
756+ filter_ = default_filter if x .dtype in dh .complex_dtypes else lambda s : default_filter (s ) and - 1 <= s <= 1
751757 unary_assert_against_refimpl (
752- "asin" , x , out , math . asin , filter_ = lambda s : default_filter ( s ) and - 1 <= s <= 1
758+ "asin" , x , out , refimpl , filter_ = filter_
753759 )
754760
755761
@@ -758,36 +764,41 @@ def test_asinh(x):
758764 out = xp .asinh (x )
759765 ph .assert_dtype ("asinh" , in_dtype = x .dtype , out_dtype = out .dtype )
760766 ph .assert_shape ("asinh" , out_shape = out .shape , expected = x .shape )
761- unary_assert_against_refimpl ("asinh" , x , out , math .asinh )
767+ refimpl = cmath .asinh if x .dtype in dh .complex_dtypes else math .asinh
768+ unary_assert_against_refimpl ("asinh" , x , out , refimpl )
762769
763770
764771@given (hh .arrays (dtype = hh .all_floating_dtypes (), shape = hh .shapes ()))
765772def test_atan (x ):
766773 out = xp .atan (x )
767774 ph .assert_dtype ("atan" , in_dtype = x .dtype , out_dtype = out .dtype )
768775 ph .assert_shape ("atan" , out_shape = out .shape , expected = x .shape )
769- unary_assert_against_refimpl ("atan" , x , out , math .atan )
776+ refimpl = cmath .atan if x .dtype in dh .complex_dtypes else math .atan
777+ unary_assert_against_refimpl ("atan" , x , out , refimpl )
770778
771779
772780@given (* hh .two_mutual_arrays (dh .real_float_dtypes ))
773781def test_atan2 (x1 , x2 ):
774782 out = xp .atan2 (x1 , x2 )
775783 ph .assert_dtype ("atan2" , in_dtype = [x1 .dtype , x2 .dtype ], out_dtype = out .dtype )
776784 ph .assert_result_shape ("atan2" , in_shapes = [x1 .shape , x2 .shape ], out_shape = out .shape )
777- binary_assert_against_refimpl ("atan2" , x1 , x2 , out , math .atan2 )
785+ refimpl = cmath .atan2 if x1 .dtype in dh .complex_dtypes else math .atan2
786+ binary_assert_against_refimpl ("atan2" , x1 , x2 , out , refimpl )
778787
779788
780789@given (hh .arrays (dtype = hh .all_floating_dtypes (), shape = hh .shapes ()))
781790def test_atanh (x ):
782791 out = xp .atanh (x )
783792 ph .assert_dtype ("atanh" , in_dtype = x .dtype , out_dtype = out .dtype )
784793 ph .assert_shape ("atanh" , out_shape = out .shape , expected = x .shape )
794+ refimpl = cmath .atanh if x .dtype in dh .complex_dtypes else math .atanh
795+ filter_ = default_filter if x .dtype in dh .complex_dtypes else lambda s : default_filter (s ) and - 1 < s < 1
785796 unary_assert_against_refimpl (
786797 "atanh" ,
787798 x ,
788799 out ,
789- math . atanh ,
790- filter_ = lambda s : default_filter ( s ) and - 1 <= s <= 1 ,
800+ refimpl ,
801+ filter_ = filter_ ,
791802 )
792803
793804
@@ -1065,15 +1076,17 @@ def test_cos(x):
10651076 out = xp .cos (x )
10661077 ph .assert_dtype ("cos" , in_dtype = x .dtype , out_dtype = out .dtype )
10671078 ph .assert_shape ("cos" , out_shape = out .shape , expected = x .shape )
1068- unary_assert_against_refimpl ("cos" , x , out , math .cos )
1079+ refimpl = cmath .cos if x .dtype in dh .complex_dtypes else math .cos
1080+ unary_assert_against_refimpl ("cos" , x , out , refimpl )
10691081
10701082
10711083@given (hh .arrays (dtype = hh .all_floating_dtypes (), shape = hh .shapes ()))
10721084def test_cosh (x ):
10731085 out = xp .cosh (x )
10741086 ph .assert_dtype ("cosh" , in_dtype = x .dtype , out_dtype = out .dtype )
10751087 ph .assert_shape ("cosh" , out_shape = out .shape , expected = x .shape )
1076- unary_assert_against_refimpl ("cosh" , x , out , math .cosh )
1088+ refimpl = cmath .cosh if x .dtype in dh .complex_dtypes else math .cosh
1089+ unary_assert_against_refimpl ("cosh" , x , out , refimpl )
10771090
10781091
10791092@pytest .mark .parametrize ("ctx" , make_binary_params ("divide" , dh .all_float_dtypes ))
@@ -1097,7 +1110,7 @@ def test_divide(ctx, data):
10971110 res ,
10981111 "/" ,
10991112 operator .truediv ,
1100- filter_ = lambda s : math .isfinite (s ) and s != 0 ,
1113+ filter_ = lambda s : cmath .isfinite (s ) and s != 0 ,
11011114 )
11021115
11031116
@@ -1134,23 +1147,45 @@ def test_exp(x):
11341147 out = xp .exp (x )
11351148 ph .assert_dtype ("exp" , in_dtype = x .dtype , out_dtype = out .dtype )
11361149 ph .assert_shape ("exp" , out_shape = out .shape , expected = x .shape )
1137- unary_assert_against_refimpl ("exp" , x , out , math .exp )
1150+ refimpl = cmath .exp if x .dtype in dh .complex_dtypes else math .exp
1151+ unary_assert_against_refimpl ("exp" , x , out , refimpl )
11381152
11391153
11401154@given (hh .arrays (dtype = hh .all_floating_dtypes (), shape = hh .shapes ()))
11411155def test_expm1 (x ):
11421156 out = xp .expm1 (x )
11431157 ph .assert_dtype ("expm1" , in_dtype = x .dtype , out_dtype = out .dtype )
11441158 ph .assert_shape ("expm1" , out_shape = out .shape , expected = x .shape )
1145- unary_assert_against_refimpl ("expm1" , x , out , math .expm1 )
1159+ if x .dtype in dh .complex_dtypes :
1160+ def refimpl (z ):
1161+ # There's no cmath.expm1. Use
1162+ #
1163+ # exp(x+yi) - 1
1164+ # = exp(x)exp(yi) - 1
1165+ # = exp(x)(cos(y) + sin(y)i) - 1
1166+ # = (exp(x) - 1)cos(y) + (cos(y) - 1) + exp(x)sin(y)i
1167+ # = expm1(x)cos(y) - 2sin(y/2)^2 + exp(x)sin(y)i
1168+ #
1169+ # where 1 - cos(y) = 2sin(y/2)^2 is used to avoid loss of
1170+ # significance near y = 0.
1171+ re , im = z .real , z .imag
1172+ return math .expm1 (re )* math .cos (im ) - 2 * math .sin (im / 2 )** 2 + 1j * math .exp (re )* math .sin (im )
1173+ else :
1174+ refimpl = math .expm1
1175+ unary_assert_against_refimpl ("expm1" , x , out , refimpl )
11461176
11471177
11481178@given (hh .arrays (dtype = hh .real_dtypes , shape = hh .shapes ()))
11491179def test_floor (x ):
11501180 out = xp .floor (x )
11511181 ph .assert_dtype ("floor" , in_dtype = x .dtype , out_dtype = out .dtype )
11521182 ph .assert_shape ("floor" , out_shape = out .shape , expected = x .shape )
1153- unary_assert_against_refimpl ("floor" , x , out , math .floor , strict_check = True )
1183+ if x .dtype in dh .complex_dtypes :
1184+ def refimpl (z ):
1185+ return complex (math .floor (z .real ), math .floor (z .imag ))
1186+ else :
1187+ refimpl = math .floor
1188+ unary_assert_against_refimpl ("floor" , x , out , refimpl , strict_check = True )
11541189
11551190
11561191@pytest .mark .parametrize ("ctx" , make_binary_params ("floor_divide" , dh .real_dtypes ))
@@ -1236,23 +1271,26 @@ def test_isfinite(x):
12361271 out = xp .isfinite (x )
12371272 ph .assert_dtype ("isfinite" , in_dtype = x .dtype , out_dtype = out .dtype , expected = xp .bool )
12381273 ph .assert_shape ("isfinite" , out_shape = out .shape , expected = x .shape )
1239- unary_assert_against_refimpl ("isfinite" , x , out , math .isfinite , res_stype = bool )
1274+ refimpl = cmath .isfinite if x .dtype in dh .complex_dtypes else math .isfinite
1275+ unary_assert_against_refimpl ("isfinite" , x , out , refimpl , res_stype = bool )
12401276
12411277
12421278@given (hh .arrays (dtype = hh .numeric_dtypes , shape = hh .shapes ()))
12431279def test_isinf (x ):
12441280 out = xp .isinf (x )
12451281 ph .assert_dtype ("isfinite" , in_dtype = x .dtype , out_dtype = out .dtype , expected = xp .bool )
12461282 ph .assert_shape ("isinf" , out_shape = out .shape , expected = x .shape )
1247- unary_assert_against_refimpl ("isinf" , x , out , math .isinf , res_stype = bool )
1283+ refimpl = cmath .isinf if x .dtype in dh .complex_dtypes else math .isinf
1284+ unary_assert_against_refimpl ("isinf" , x , out , refimpl , res_stype = bool )
12481285
12491286
12501287@given (hh .arrays (dtype = hh .numeric_dtypes , shape = hh .shapes ()))
12511288def test_isnan (x ):
12521289 out = xp .isnan (x )
12531290 ph .assert_dtype ("isnan" , in_dtype = x .dtype , out_dtype = out .dtype , expected = xp .bool )
12541291 ph .assert_shape ("isnan" , out_shape = out .shape , expected = x .shape )
1255- unary_assert_against_refimpl ("isnan" , x , out , math .isnan , res_stype = bool )
1292+ refimpl = cmath .isnan if x .dtype in dh .complex_dtypes else math .isnan
1293+ unary_assert_against_refimpl ("isnan" , x , out , refimpl , res_stype = bool )
12561294
12571295
12581296@pytest .mark .parametrize ("ctx" , make_binary_params ("less" , dh .real_dtypes ))
@@ -1300,8 +1338,10 @@ def test_log(x):
13001338 out = xp .log (x )
13011339 ph .assert_dtype ("log" , in_dtype = x .dtype , out_dtype = out .dtype )
13021340 ph .assert_shape ("log" , out_shape = out .shape , expected = x .shape )
1341+ refimpl = cmath .log if x .dtype in dh .complex_dtypes else math .log
1342+ filter_ = default_filter if x .dtype in dh .complex_dtypes else lambda s : default_filter (s ) and s > 0
13031343 unary_assert_against_refimpl (
1304- "log" , x , out , math . log , filter_ = lambda s : default_filter ( s ) and s >= 1
1344+ "log" , x , out , refimpl , filter_ = filter_
13051345 )
13061346
13071347
@@ -1310,8 +1350,19 @@ def test_log1p(x):
13101350 out = xp .log1p (x )
13111351 ph .assert_dtype ("log1p" , in_dtype = x .dtype , out_dtype = out .dtype )
13121352 ph .assert_shape ("log1p" , out_shape = out .shape , expected = x .shape )
1353+ # There isn't a cmath.log1p, and implementing one isn't straightforward
1354+ # (see
1355+ # https://stackoverflow.com/questions/78318212/unexpected-behaviour-of-log1p-numpy).
1356+ # For now, just use log(1+p) for complex inputs, which should hopefully be
1357+ # fine given the very loose numerical tolerances we use. If it isn't, we
1358+ # can try using something like a series expansion for small p.
1359+ if x .dtype in dh .complex_dtypes :
1360+ refimpl = lambda z : cmath .log (1 + z )
1361+ else :
1362+ refimpl = math .log1p
1363+ filter_ = default_filter if x .dtype in dh .complex_dtypes else lambda s : default_filter (s ) and s > - 1
13131364 unary_assert_against_refimpl (
1314- "log1p" , x , out , math . log1p , filter_ = lambda s : default_filter ( s ) and s >= 1
1365+ "log1p" , x , out , refimpl , filter_ = filter_
13151366 )
13161367
13171368
@@ -1320,8 +1371,13 @@ def test_log2(x):
13201371 out = xp .log2 (x )
13211372 ph .assert_dtype ("log2" , in_dtype = x .dtype , out_dtype = out .dtype )
13221373 ph .assert_shape ("log2" , out_shape = out .shape , expected = x .shape )
1374+ if x .dtype in dh .complex_dtypes :
1375+ refimpl = lambda z : cmath .log (z )/ math .log (2 )
1376+ else :
1377+ refimpl = math .log2
1378+ filter_ = default_filter if x .dtype in dh .complex_dtypes else lambda s : default_filter (s ) and s > 0
13231379 unary_assert_against_refimpl (
1324- "log2" , x , out , math . log2 , filter_ = lambda s : default_filter ( s ) and s > 1
1380+ "log2" , x , out , refimpl , filter_ = filter_
13251381 )
13261382
13271383
@@ -1330,12 +1386,17 @@ def test_log10(x):
13301386 out = xp .log10 (x )
13311387 ph .assert_dtype ("log10" , in_dtype = x .dtype , out_dtype = out .dtype )
13321388 ph .assert_shape ("log10" , out_shape = out .shape , expected = x .shape )
1389+ if x .dtype in dh .complex_dtypes :
1390+ refimpl = lambda z : cmath .log (z )/ math .log (10 )
1391+ else :
1392+ refimpl = math .log10
1393+ filter_ = default_filter if x .dtype in dh .complex_dtypes else lambda s : default_filter (s ) and s > 0
13331394 unary_assert_against_refimpl (
1334- "log10" , x , out , math . log10 , filter_ = lambda s : default_filter ( s ) and s > 0
1395+ "log10" , x , out , refimpl , filter_ = filter_
13351396 )
13361397
13371398
1338- def logaddexp (l : float , r : float ) -> float :
1399+ def logaddexp_refimpl (l : float , r : float ) -> float :
13391400 return math .log (math .exp (l ) + math .exp (r ))
13401401
13411402
@@ -1344,7 +1405,7 @@ def test_logaddexp(x1, x2):
13441405 out = xp .logaddexp (x1 , x2 )
13451406 ph .assert_dtype ("logaddexp" , in_dtype = [x1 .dtype , x2 .dtype ], out_dtype = out .dtype )
13461407 ph .assert_result_shape ("logaddexp" , in_shapes = [x1 .shape , x2 .shape ], out_shape = out .shape )
1347- binary_assert_against_refimpl ("logaddexp" , x1 , x2 , out , logaddexp )
1408+ binary_assert_against_refimpl ("logaddexp" , x1 , x2 , out , logaddexp_refimpl )
13481409
13491410
13501411@given (* hh .two_mutual_arrays ([xp .bool ]))
@@ -1521,7 +1582,11 @@ def test_round(x):
15211582 out = xp .round (x )
15221583 ph .assert_dtype ("round" , in_dtype = x .dtype , out_dtype = out .dtype )
15231584 ph .assert_shape ("round" , out_shape = out .shape , expected = x .shape )
1524- unary_assert_against_refimpl ("round" , x , out , round , strict_check = True )
1585+ if x .dtype in dh .complex_dtypes :
1586+ refimpl = lambda z : complex (round (z .real ), round (z .imag ))
1587+ else :
1588+ refimpl = round
1589+ unary_assert_against_refimpl ("round" , x , out , refimpl , strict_check = True )
15251590
15261591
15271592@pytest .mark .min_version ("2023.12" )
@@ -1539,13 +1604,12 @@ def test_sign(x):
15391604 out = xp .sign (x )
15401605 ph .assert_dtype ("sign" , in_dtype = x .dtype , out_dtype = out .dtype )
15411606 ph .assert_shape ("sign" , out_shape = out .shape , expected = x .shape )
1542- refimpl = lambda x : x / math . abs (x ) if x != 0 else 0
1607+ refimpl = lambda x : x / abs (x ) if x != 0 else 0
15431608 unary_assert_against_refimpl (
15441609 "sign" ,
15451610 x ,
15461611 out ,
15471612 refimpl ,
1548- filter_ = lambda s : s != 0 ,
15491613 strict_check = True ,
15501614 )
15511615
@@ -1555,15 +1619,17 @@ def test_sin(x):
15551619 out = xp .sin (x )
15561620 ph .assert_dtype ("sin" , in_dtype = x .dtype , out_dtype = out .dtype )
15571621 ph .assert_shape ("sin" , out_shape = out .shape , expected = x .shape )
1558- unary_assert_against_refimpl ("sin" , x , out , math .sin )
1622+ refimpl = cmath .sin if x .dtype in dh .complex_dtypes else math .sin
1623+ unary_assert_against_refimpl ("sin" , x , out , refimpl )
15591624
15601625
15611626@given (hh .arrays (dtype = hh .all_floating_dtypes (), shape = hh .shapes ()))
15621627def test_sinh (x ):
15631628 out = xp .sinh (x )
15641629 ph .assert_dtype ("sinh" , in_dtype = x .dtype , out_dtype = out .dtype )
15651630 ph .assert_shape ("sinh" , out_shape = out .shape , expected = x .shape )
1566- unary_assert_against_refimpl ("sinh" , x , out , math .sinh )
1631+ refimpl = cmath .sinh if x .dtype in dh .complex_dtypes else math .sinh
1632+ unary_assert_against_refimpl ("sinh" , x , out , refimpl )
15671633
15681634
15691635@given (hh .arrays (dtype = hh .numeric_dtypes , shape = hh .shapes ()))
@@ -1581,8 +1647,10 @@ def test_sqrt(x):
15811647 out = xp .sqrt (x )
15821648 ph .assert_dtype ("sqrt" , in_dtype = x .dtype , out_dtype = out .dtype )
15831649 ph .assert_shape ("sqrt" , out_shape = out .shape , expected = x .shape )
1650+ refimpl = cmath .sqrt if x .dtype in dh .complex_dtypes else math .sqrt
1651+ filter_ = default_filter if x .dtype in dh .complex_dtypes else lambda s : default_filter (s ) and s >= 0
15841652 unary_assert_against_refimpl (
1585- "sqrt" , x , out , math . sqrt , filter_ = lambda s : default_filter ( s ) and s >= 0
1653+ "sqrt" , x , out , refimpl , filter_ = filter_
15861654 )
15871655
15881656
@@ -1605,15 +1673,17 @@ def test_tan(x):
16051673 out = xp .tan (x )
16061674 ph .assert_dtype ("tan" , in_dtype = x .dtype , out_dtype = out .dtype )
16071675 ph .assert_shape ("tan" , out_shape = out .shape , expected = x .shape )
1608- unary_assert_against_refimpl ("tan" , x , out , math .tan )
1676+ refimpl = cmath .tan if x .dtype in dh .complex_dtypes else math .tan
1677+ unary_assert_against_refimpl ("tan" , x , out , refimpl )
16091678
16101679
16111680@given (hh .arrays (dtype = hh .all_floating_dtypes (), shape = hh .shapes ()))
16121681def test_tanh (x ):
16131682 out = xp .tanh (x )
16141683 ph .assert_dtype ("tanh" , in_dtype = x .dtype , out_dtype = out .dtype )
16151684 ph .assert_shape ("tanh" , out_shape = out .shape , expected = x .shape )
1616- unary_assert_against_refimpl ("tanh" , x , out , math .tanh )
1685+ refimpl = cmath .tanh if x .dtype in dh .complex_dtypes else math .tanh
1686+ unary_assert_against_refimpl ("tanh" , x , out , refimpl )
16171687
16181688
16191689@given (hh .arrays (dtype = hh .real_dtypes , shape = xps .array_shapes ()))
0 commit comments