Skip to content

Commit cd730c8

Browse files
Merge pull request #537 from pydata/fix_tan
Fix numerical stability for tan/tanh
2 parents 2d24f28 + 74149e7 commit cd730c8

File tree

2 files changed

+40
-30
lines changed

2 files changed

+40
-30
lines changed

numexpr/complex_functions.hpp

Lines changed: 33 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -390,42 +390,45 @@ nc_sinh(std::complex<double> *x, std::complex<double> *r)
390390
static void
391391
nc_tan(std::complex<double> *x, std::complex<double> *r)
392392
{
393-
double sr,cr,shi,chi;
394-
double rs,is,rc,ic;
395-
double d;
396-
double xr=x->real(), xi=x->imag();
397-
sr = sin(xr);
398-
cr = cos(xr);
399-
shi = sinh(xi);
400-
chi = cosh(xi);
401-
rs = sr*chi;
402-
is = cr*shi;
403-
rc = cr*chi;
404-
ic = -sr*shi;
405-
d = rc*rc + ic*ic;
406-
r->real((rs*rc+is*ic)/d);
407-
r->imag((is*rc-rs*ic)/d);
393+
double xr = x->real();
394+
double xi = x->imag();
395+
double imag_part;
396+
397+
double denom = cos(2*xr) + cosh(2*xi);
398+
// handle overflows
399+
if (xi > 20) {
400+
imag_part = 1.0 / (1.0 + exp(-4*xi));
401+
} else if (xi < -20) {
402+
imag_part = -1.0 / (1.0 + exp(4*xi));
403+
} else {
404+
imag_part = sinh(2*xi) / denom;
405+
}
406+
double real_part = sin(2*xr) / denom;
407+
408+
r->real(real_part);
409+
r->imag(imag_part);
408410
return;
409411
}
410412

411413
static void
412414
nc_tanh(std::complex<double> *x, std::complex<double> *r)
413415
{
414-
double si,ci,shr,chr;
415-
double rs,is,rc,ic;
416-
double d;
417-
double xr=x->real(), xi=x->imag();
418-
si = sin(xi);
419-
ci = cos(xi);
420-
shr = sinh(xr);
421-
chr = cosh(xr);
422-
rs = ci*shr;
423-
is = si*chr;
424-
rc = ci*chr;
425-
ic = si*shr;
426-
d = rc*rc + ic*ic;
427-
r->real((rs*rc+is*ic)/d);
428-
r->imag((is*rc-rs*ic)/d);
416+
double xr = x->real();
417+
double xi = x->imag();
418+
double real_part;
419+
double denom = cosh(2*xr) + cos(2*xi);
420+
// handle overflows
421+
if (xr > 20) {
422+
real_part = 1.0 / (1.0 + exp(-4*xr));
423+
} else if (xr < -20) {
424+
real_part = -1.0 / (1.0 + exp(4*xr));
425+
} else {
426+
real_part = sinh(2*xr) / denom;
427+
}
428+
double imag_part = sin(2*xi) / denom;
429+
430+
r->real(real_part);
431+
r->imag(imag_part);
429432
return;
430433
}
431434

numexpr/tests/test_numexpr.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -480,6 +480,13 @@ def test_bitwise_operators(self):
480480
assert_array_equal(evaluate("x | y"), x | y) # or
481481
assert_array_equal(evaluate("~x"), ~x) # invert
482482

483+
def test_complex_tan(self):
484+
# old version of NumExpr had overflow problems
485+
x = np.arange(1, 400., step=16., dtype=np.complex128)
486+
y = 1j*np.arange(1, 400., step=16., dtype=np.complex128)
487+
assert_array_almost_equal(evaluate("tan(x + y)"), tan(x + y))
488+
assert_array_almost_equal(evaluate("tanh(x + y)"), tanh(x + y))
489+
483490
def test_maximum_minimum(self):
484491
for dtype in [float, double, int, np.int64]:
485492
x = arange(10, dtype=dtype)

0 commit comments

Comments
 (0)