@@ -281,40 +281,69 @@ impl<T: Float> Complex<T> {
281281 ///
282282 /// The branch satisfies `-π/2 ≤ arg(sqrt(z)) ≤ π/2`.
283283 #[ inline]
284- pub fn sqrt ( self ) -> Self {
285- if self . im . is_zero ( ) {
286- if self . re . is_sign_positive ( ) {
287- // simple positive real √r, and copy `im` for its sign
288- Self :: new ( self . re . sqrt ( ) , self . im )
289- } else {
290- // √(r e^(iπ)) = √r e^(iπ/2) = i√r
291- // √(r e^(-iπ)) = √r e^(-iπ/2) = -i√r
292- let re = T :: zero ( ) ;
293- let im = ( -self . re ) . sqrt ( ) ;
294- if self . im . is_sign_positive ( ) {
295- Self :: new ( re, im)
296- } else {
297- Self :: new ( re, -im)
298- }
299- }
300- } else if self . re . is_zero ( ) {
301- // √(r e^(iπ/2)) = √r e^(iπ/4) = √(r/2) + i√(r/2)
302- // √(r e^(-iπ/2)) = √r e^(-iπ/4) = √(r/2) - i√(r/2)
303- let one = T :: one ( ) ;
304- let two = one + one;
305- let x = ( self . im . abs ( ) / two) . sqrt ( ) ;
306- if self . im . is_sign_positive ( ) {
307- Self :: new ( x, x)
284+ #[ allow( clippy:: eq_op) ]
285+ pub fn sqrt ( mut self ) -> Self {
286+ // TODO: rounding for very tiny subnormal numbers isn't perfect yet so
287+ // the assert shown fails in the very worst case this leads to about
288+ // 10% accuracy loss (see example below). As the magnitude increase the
289+ // error quickly drops to basically zero.
290+ //
291+ // glibc handles that (but other implementations like musl and numpy do
292+ // not) by upscaling very small values. That upscaling (and particularly
293+ // it's reversal) are weird and hard to understand (and rely on mantissa
294+ // bit size which we can't get out of the trait). In general the glibc
295+ // implementation is ever so subtley different and I wouldn't want to
296+ // introduce bugs by trying to adapt the underflow handling.
297+ //
298+ // assert_eq!(
299+ // Complex64::new(5.212e-324, 5.212e-324).sqrt(),
300+ // Complex64::new(2.4421097261308304e-162, 1.0115549693666347e-162)
301+ // );
302+
303+ if self . re . is_zero ( ) && self . im . is_zero ( ) {
304+ // 0 +/- 0 i
305+ return Self :: new ( T :: zero ( ) , self . im ) ;
306+ }
307+ if self . im . is_infinite ( ) {
308+ // inf +/- inf i
309+ return Self :: new ( T :: infinity ( ) , self . im ) ;
310+ }
311+ if self . re . is_nan ( ) {
312+ // nan + nan i
313+ return Self :: new ( self . re , ( self . im - self . im ) / ( self . im - self . im ) ) ;
314+ }
315+ if self . re . is_infinite ( ) {
316+ // √(inf +/- NaN i) = inf +/- NaN i
317+ // √(inf +/- x i) = inf +/- 0 i
318+ // √(-inf +/- NaN i) = NaN +/- inf i
319+ // √(-inf +/- x i) = 0 +/- inf i
320+
321+ if self . re . is_sign_negative ( ) {
322+ return Self :: new ( ( self . im - self . im ) . abs ( ) , self . re . copysign ( self . im ) ) ;
308323 } else {
309- Self :: new ( x , -x )
324+ return Self :: new ( self . re , ( self . im - self . im ) . copysign ( self . im ) ) ;
310325 }
326+ }
327+ let two = T :: one ( ) + T :: one ( ) ;
328+ let four = two + two;
329+ let overflow = T :: max_value ( ) / ( T :: one ( ) + T :: sqrt ( two) ) ;
330+ let max_magnitude = self . re . abs ( ) . max ( self . im . abs ( ) ) ;
331+ let scale = max_magnitude >= overflow;
332+ if scale {
333+ self = self / four;
334+ }
335+ if self . re . is_sign_negative ( ) {
336+ let tmp = ( ( -self . re + self . norm ( ) ) / two) . sqrt ( ) ;
337+ self . re = self . im . abs ( ) / ( two * tmp) ;
338+ self . im = tmp. copysign ( self . im ) ;
311339 } else {
312- // formula: sqrt(r e^(it)) = sqrt(r) e^(it/2)
313- let one = T :: one ( ) ;
314- let two = one + one;
315- let ( r, theta) = self . to_polar ( ) ;
316- Self :: from_polar ( r. sqrt ( ) , theta / two)
340+ self . re = ( ( self . re + self . norm ( ) ) / two) . sqrt ( ) ;
341+ self . im = self . im / ( two * self . re ) ;
317342 }
343+ if scale {
344+ self = self * two;
345+ }
346+ self
318347 }
319348
320349 /// Computes the principal value of the cube root of `self`.
@@ -2065,6 +2094,34 @@ pub(crate) mod test {
20652094 }
20662095 }
20672096
2097+ #[ test]
2098+ fn test_sqrt_nan ( ) {
2099+ assert ! ( close_naninf(
2100+ Complex64 :: new( f64 :: INFINITY , f64 :: NAN ) . sqrt( ) ,
2101+ Complex64 :: new( f64 :: INFINITY , f64 :: NAN ) ,
2102+ ) ) ;
2103+ assert ! ( close_naninf(
2104+ Complex64 :: new( f64 :: NEG_INFINITY , -f64 :: NAN ) . sqrt( ) ,
2105+ Complex64 :: new( f64 :: NAN , f64 :: NEG_INFINITY ) ,
2106+ ) ) ;
2107+ assert ! ( close_naninf(
2108+ Complex64 :: new( f64 :: NEG_INFINITY , f64 :: NAN ) . sqrt( ) ,
2109+ Complex64 :: new( f64 :: NAN , f64 :: INFINITY ) ,
2110+ ) ) ;
2111+ for x in ( -100 ..100 ) . map ( f64:: from) {
2112+ // √(inf + x i) = inf + 0 i
2113+ assert ! ( close_naninf(
2114+ Complex64 :: new( f64 :: INFINITY , x) . sqrt( ) ,
2115+ Complex64 :: new( f64 :: INFINITY , 0.0 . copysign( x) ) ,
2116+ ) ) ;
2117+ // √(-inf + x i) = 0 + inf i
2118+ assert ! ( close_naninf(
2119+ Complex64 :: new( f64 :: NEG_INFINITY , x) . sqrt( ) ,
2120+ Complex64 :: new( 0.0 , f64 :: INFINITY . copysign( x) ) ,
2121+ ) ) ;
2122+ }
2123+ }
2124+
20682125 #[ test]
20692126 fn test_cbrt ( ) {
20702127 assert ! ( close( _0_0i. cbrt( ) , _0_0i) ) ;
0 commit comments