diff --git a/src/maybe_nan/impl_not_none.rs b/src/maybe_nan/impl_not_none.rs index 2ab4f07..e437de4 100644 --- a/src/maybe_nan/impl_not_none.rs +++ b/src/maybe_nan/impl_not_none.rs @@ -1,5 +1,5 @@ use super::NotNone; -use num_traits::{FromPrimitive, ToPrimitive}; +use num_traits::{Euclid, FromPrimitive, ToPrimitive}; use std::cmp; use std::fmt; use std::ops::{Add, Deref, DerefMut, Div, Mul, Rem, Sub}; @@ -101,6 +101,19 @@ impl Rem for NotNone { } } +impl Euclid for NotNone { + #[inline] + fn div_euclid(&self, rhs: &Self) -> Self { + let result = self.deref().div_euclid(rhs.deref()); + NotNone(Some(result)) + } + #[inline] + fn rem_euclid(&self, rhs: &Self) -> Self { + let result = self.deref().rem_euclid(rhs.deref()); + NotNone(Some(result)) + } +} + impl ToPrimitive for NotNone { #[inline] fn to_isize(&self) -> Option { diff --git a/src/quantile/interpolate.rs b/src/quantile/interpolate.rs index 0b7871d..f006994 100644 --- a/src/quantile/interpolate.rs +++ b/src/quantile/interpolate.rs @@ -1,6 +1,8 @@ //! Interpolation strategies. use noisy_float::types::N64; -use num_traits::{Float, FromPrimitive, NumOps, ToPrimitive}; +use num_traits::{Euclid, Float, FromPrimitive, NumOps, ToPrimitive}; + +use crate::maybe_nan::NotNone; fn float_quantile_index(q: N64, len: usize) -> N64 { q * ((len - 1) as f64) @@ -104,25 +106,69 @@ impl Interpolate for Nearest { private_impl! {} } -impl Interpolate for Midpoint -where - T: NumOps + Clone + FromPrimitive, -{ - fn needs_lower(_q: N64, _len: usize) -> bool { - true - } - fn needs_higher(_q: N64, _len: usize) -> bool { - true +macro_rules! impl_midpoint_interpolate_for_float { + ($($t:ty),*) => { + $( + impl Interpolate<$t> for Midpoint { + fn needs_lower(_q: N64, _len: usize) -> bool { + true + } + fn needs_higher(_q: N64, _len: usize) -> bool { + true + } + fn interpolate(lower: Option<$t>, higher: Option<$t>, _q: N64, _len: usize) -> $t { + let lower = lower.unwrap(); + let higher = higher.unwrap(); + lower + (higher - lower) / 2.0 + } + private_impl! {} + } + )* } - fn interpolate(lower: Option, higher: Option, _q: N64, _len: usize) -> T { - let denom = T::from_u8(2).unwrap(); - let lower = lower.unwrap(); - let higher = higher.unwrap(); - lower.clone() + (higher.clone() - lower.clone()) / denom.clone() +} + +impl_midpoint_interpolate_for_float!(f32, f64); + +macro_rules! impl_midpoint_interpolate_for_integer { + ($($t:ty),*) => { + $( + impl Interpolate<$t> for Midpoint { + fn needs_lower(_q: N64, _len: usize) -> bool { + true + } + fn needs_higher(_q: N64, _len: usize) -> bool { + true + } + fn interpolate(lower: Option<$t>, higher: Option<$t>, _q: N64, _len: usize) -> $t { + let two = <$t>::from_u8(2).unwrap(); + let (lower_half, lower_rem) = lower.unwrap().div_rem_euclid(&two); + let (higher_half, higher_rem) = higher.unwrap().div_rem_euclid(&two); + lower_half + higher_half + (lower_rem * higher_rem) + } + private_impl! {} + } + )* } - private_impl! {} } +impl_midpoint_interpolate_for_integer!( + i8, i16, i32, i64, i128, isize, u8, u16, u32, u64, u128, usize +); +impl_midpoint_interpolate_for_integer!( + NotNone, + NotNone, + NotNone, + NotNone, + NotNone, + NotNone, + NotNone, + NotNone, + NotNone, + NotNone, + NotNone, + NotNone +); + impl Interpolate for Linear where T: NumOps + Clone + FromPrimitive + ToPrimitive, @@ -143,3 +189,46 @@ where } private_impl! {} } + +#[cfg(test)] +mod tests { + use super::*; + use noisy_float::types::n64; + use quickcheck::TestResult; + use quickcheck_macros::quickcheck; + + #[derive(Clone, Copy, Debug)] + struct LowerHigherPair(T, T); + + impl quickcheck::Arbitrary for LowerHigherPair { + fn arbitrary(g: &mut G) -> Self { + let (l, h) = loop { + let (l, h) = (i64::arbitrary(g), i64::arbitrary(g)); + if l > h || h.checked_sub(l).is_none() { + continue; + } + break (l, h); + }; + LowerHigherPair(l, h) + } + } + + impl From> for (i64, i64) { + fn from(value: LowerHigherPair) -> Self { + (value.0, value.1) + } + } + + fn naive_midpoint_i64(lower: i64, higher: i64) -> i64 { + // Overflows when higher is very big and lower is very small + lower + (higher - lower) / 2 + } + + #[quickcheck] + fn test_midpoint_algo_eq_naive_algo_i64(lh: LowerHigherPair) -> TestResult { + let (lower, higher) = lh.into(); + let naive = naive_midpoint_i64(lower, higher); + let midpoint = Midpoint::interpolate(Some(lower), Some(higher), n64(0.0), 0); + TestResult::from_bool(naive == midpoint) + } +} diff --git a/tests/quantile.rs b/tests/quantile.rs index 9d58071..5aa3e68 100644 --- a/tests/quantile.rs +++ b/tests/quantile.rs @@ -268,7 +268,7 @@ fn test_quantile_axis_skipnan_mut_linear_opt_i32() { } #[test] -fn test_midpoint_overflow() { +fn test_midpoint_overflow_unsigned() { // Regression test // This triggered an overflow panic with a naive Midpoint implementation: (a+b)/2 let mut a: Array1 = array![129, 130, 130, 131]; @@ -277,6 +277,16 @@ fn test_midpoint_overflow() { assert_eq!(median, expected_median); } +#[test] +fn test_midpoint_overflow_signed() { + // Regression test + // This triggered an overflow panic with a naive Midpoint implementation: b+(a-b)/2 + let mut a: Array1 = array![i64::MIN, i64::MAX]; + let median = a.quantile_mut(n64(0.5), &Midpoint).unwrap(); + let expected_median = -1; + assert_eq!(median, expected_median); +} + #[quickcheck] fn test_quantiles_mut(xs: Vec) -> bool { let v = Array::from(xs.clone());