@@ -281,40 +281,87 @@ impl<T: Float> Complex<T> {
281281 ///
282282 /// The branch satisfies `-π/2 ≤ arg(sqrt(z)) ≤ π/2`.
283283 #[ inline]
284- pub fn sqrt ( self ) -> Self {
285- if self . im . is_zero ( ) {
286- if self . re . is_sign_positive ( ) {
287- // simple positive real √r, and copy `im` for its sign
288- Self :: new ( self . re . sqrt ( ) , self . im )
284+ pub fn sqrt ( mut self ) -> Self {
285+ // complex sqrt algorithm based on the algorithm from
286+ // dl.acm.org/doi/abs/10.1145/363717.363780 with additional tweaks
287+ // to increase accuracy. Compared to a naive implementation that
288+ // reuses the complex exp/ln implementations, this algorithm has better
289+ // accuracy since both (real) sqrt and (real) hypot are guaranteed to
290+ // round perfectly. It's also faster since this implementation requires
291+ // fewer transcendental functions and those it does use (sqrt/hypot) are
292+ // faster compared to exp/sin/cos.
293+ //
294+ // The musl libc implementation was referenced while implementing the
295+ // algorithm here:
296+ // https://git.musl-libc.org/cgit/musl/tree/src/complex/csqrt.c
297+
298+ // TODO: rounding for very tiny subnormal numbers isn't perfect yet, so
299+ // the assert shown below fails. In the very worst case this leads to about
300+ // 10% accuracy loss (see example below). As the magnitude increases the
301+ // error quickly drops to basically zero.
302+ //
303+ // glibc handles that (but other implementations like musl and numpy do
304+ // not) by upscaling very small values. That upscaling (and particularly
305+ // its reversal) is weird and hard to understand (and relies on mantissa
306+ // bit size which we can't get out of the trait). In general the glibc
307+ // implementation is ever so subtly different and I wouldn't want to
308+ // introduce bugs by trying to adapt the underflow handling.
309+ //
310+ // assert_eq!(
311+ // Complex64::new(5.212e-324, 5.212e-324).sqrt(),
312+ // Complex64::new(2.4421097261308304e-162, 1.0115549693666347e-162)
313+ // );
314+
315+ // special cases for correct nan/inf handling
316+ // see https://en.cppreference.com/w/c/numeric/complex/csqrt
317+
318+ if self . re . is_zero ( ) && self . im . is_zero ( ) {
319+ // 0 +/- 0 i
320+ return Self :: new ( T :: zero ( ) , self . im ) ;
321+ }
322+ if self . im . is_infinite ( ) {
323+ // inf +/- inf i
324+ return Self :: new ( T :: infinity ( ) , self . im ) ;
325+ }
326+ if self . re . is_nan ( ) {
327+ // nan + nan i
328+ return Self :: new ( self . re , T :: nan ( ) ) ;
329+ }
330+ if self . re . is_infinite ( ) {
331+ // √(inf +/- NaN i) = inf +/- NaN i
332+ // √(inf +/- x i) = inf +/- 0 i
333+ // √(-inf +/- NaN i) = NaN +/- inf i
334+ // √(-inf +/- x i) = 0 +/- inf i
335+
336+ // if im is inf (or nan) this is nan, otherwise it's zero
337+ #[ allow( clippy:: eq_op) ]
338+ let zero_or_nan = self . im - self . im ;
339+ if self . re . is_sign_negative ( ) {
340+ return Self :: new ( zero_or_nan. abs ( ) , self . re . copysign ( self . im ) ) ;
289341 } else {
290- // √(r e^(iπ)) = √r e^(iπ/2) = i√r
291- // √(r e^(-iπ)) = √r e^(-iπ/2) = -i√r
292- let re = T :: zero ( ) ;
293- let im = ( -self . re ) . sqrt ( ) ;
294- if self . im . is_sign_positive ( ) {
295- Self :: new ( re, im)
296- } else {
297- Self :: new ( re, -im)
298- }
299- }
300- } else if self . re . is_zero ( ) {
301- // √(r e^(iπ/2)) = √r e^(iπ/4) = √(r/2) + i√(r/2)
302- // √(r e^(-iπ/2)) = √r e^(-iπ/4) = √(r/2) - i√(r/2)
303- let one = T :: one ( ) ;
304- let two = one + one;
305- let x = ( self . im . abs ( ) / two) . sqrt ( ) ;
306- if self . im . is_sign_positive ( ) {
307- Self :: new ( x, x)
308- } else {
309- Self :: new ( x, -x)
342+ return Self :: new ( self . re , zero_or_nan. copysign ( self . im ) ) ;
310343 }
344+ }
345+ let two = T :: one ( ) + T :: one ( ) ;
346+ let four = two + two;
347+ let overflow = T :: max_value ( ) / ( T :: one ( ) + T :: sqrt ( two) ) ;
348+ let max_magnitude = self . re . abs ( ) . max ( self . im . abs ( ) ) ;
349+ let scale = max_magnitude >= overflow;
350+ if scale {
351+ self = self / four;
352+ }
353+ if self . re . is_sign_negative ( ) {
354+ let tmp = ( ( -self . re + self . norm ( ) ) / two) . sqrt ( ) ;
355+ self . re = self . im . abs ( ) / ( two * tmp) ;
356+ self . im = tmp. copysign ( self . im ) ;
311357 } else {
312- // formula: sqrt(r e^(it )) = sqrt(r) e^(it/2)
313- let one = T :: one ( ) ;
314- let two = one + one ;
315- let ( r , theta ) = self . to_polar ( ) ;
316- Self :: from_polar ( r . sqrt ( ) , theta / two)
358+ self . re = ( ( self . re + self . norm ( ) ) / two ) . sqrt ( ) ;
359+ self . im = self . im / ( two * self . re ) ;
360+ }
361+ if scale {
362+ self = self * two;
317363 }
364+ self
318365 }
319366
320367 /// Computes the principal value of the cube root of `self`.
@@ -2065,6 +2112,50 @@ pub(crate) mod test {
20652112 }
20662113 }
20672114
2115+ #[ test]
2116+ fn test_sqrt_nan ( ) {
2117+ assert ! ( close_naninf(
2118+ Complex64 :: new( f64 :: INFINITY , f64 :: NAN ) . sqrt( ) ,
2119+ Complex64 :: new( f64 :: INFINITY , f64 :: NAN ) ,
2120+ ) ) ;
2121+ assert ! ( close_naninf(
2122+ Complex64 :: new( f64 :: NAN , f64 :: INFINITY ) . sqrt( ) ,
2123+ Complex64 :: new( f64 :: INFINITY , f64 :: INFINITY ) ,
2124+ ) ) ;
2125+ assert ! ( close_naninf(
2126+ Complex64 :: new( f64 :: NEG_INFINITY , -f64 :: NAN ) . sqrt( ) ,
2127+ Complex64 :: new( f64 :: NAN , f64 :: NEG_INFINITY ) ,
2128+ ) ) ;
2129+ assert ! ( close_naninf(
2130+ Complex64 :: new( f64 :: NEG_INFINITY , f64 :: NAN ) . sqrt( ) ,
2131+ Complex64 :: new( f64 :: NAN , f64 :: INFINITY ) ,
2132+ ) ) ;
2133+ assert ! ( close_naninf(
2134+ Complex64 :: new( -0.0 , 0.0 ) . sqrt( ) ,
2135+ Complex64 :: new( 0.0 , 0.0 ) ,
2136+ ) ) ;
2137+ for x in ( -100 ..100 ) . map ( f64:: from) {
2138+ assert ! ( close_naninf(
2139+ Complex64 :: new( x, f64 :: INFINITY ) . sqrt( ) ,
2140+ Complex64 :: new( f64 :: INFINITY , f64 :: INFINITY ) ,
2141+ ) ) ;
2142+ assert ! ( close_naninf(
2143+ Complex64 :: new( f64 :: NAN , x) . sqrt( ) ,
2144+ Complex64 :: new( f64 :: NAN , f64 :: NAN ) ,
2145+ ) ) ;
2146+ // √(inf + x i) = inf + 0 i
2147+ assert ! ( close_naninf(
2148+ Complex64 :: new( f64 :: INFINITY , x) . sqrt( ) ,
2149+ Complex64 :: new( f64 :: INFINITY , 0.0 . copysign( x) ) ,
2150+ ) ) ;
2151+ // √(-inf + x i) = 0 + inf i
2152+ assert ! ( close_naninf(
2153+ Complex64 :: new( f64 :: NEG_INFINITY , x) . sqrt( ) ,
2154+ Complex64 :: new( 0.0 , f64 :: INFINITY . copysign( x) ) ,
2155+ ) ) ;
2156+ }
2157+ }
2158+
20682159 #[ test]
20692160 fn test_cbrt ( ) {
20702161 assert ! ( close( _0_0i. cbrt( ) , _0_0i) ) ;
0 commit comments