From 15307eb80333442dcb2421f36f85bbdae2d65b81 Mon Sep 17 00:00:00 2001 From: Miikka Salminen Date: Sat, 25 Oct 2025 15:10:51 +0300 Subject: [PATCH] Fix signed overflow with midpoint interpolation The simple algorithm is replaced with an Euclid division and remainder based slightly more complex one when the inputs are integers. Moreover, the implementation is separated for integers and floating point numbers, relying on macros instead of generics with trait bounds. A regression test that failed before the changes is added along with property-based testing ensuring that the results match those that the previous version output (when within the non-overflowing limits). --- src/maybe_nan/impl_not_none.rs | 15 +++- src/quantile/interpolate.rs | 121 ++++++++++++++++++++++++++++----- tests/quantile.rs | 12 +++- 3 files changed, 130 insertions(+), 18 deletions(-) diff --git a/src/maybe_nan/impl_not_none.rs b/src/maybe_nan/impl_not_none.rs index 2ab4f075..e437de4e 100644 --- a/src/maybe_nan/impl_not_none.rs +++ b/src/maybe_nan/impl_not_none.rs @@ -1,5 +1,5 @@ use super::NotNone; -use num_traits::{FromPrimitive, ToPrimitive}; +use num_traits::{Euclid, FromPrimitive, ToPrimitive}; use std::cmp; use std::fmt; use std::ops::{Add, Deref, DerefMut, Div, Mul, Rem, Sub}; @@ -101,6 +101,19 @@ impl Rem for NotNone { } } +impl Euclid for NotNone { + #[inline] + fn div_euclid(&self, rhs: &Self) -> Self { + let result = self.deref().div_euclid(rhs.deref()); + NotNone(Some(result)) + } + #[inline] + fn rem_euclid(&self, rhs: &Self) -> Self { + let result = self.deref().rem_euclid(rhs.deref()); + NotNone(Some(result)) + } +} + impl ToPrimitive for NotNone { #[inline] fn to_isize(&self) -> Option { diff --git a/src/quantile/interpolate.rs b/src/quantile/interpolate.rs index 0b7871d9..f0069949 100644 --- a/src/quantile/interpolate.rs +++ b/src/quantile/interpolate.rs @@ -1,6 +1,8 @@ //! Interpolation strategies. use noisy_float::types::N64; -use num_traits::{Float, FromPrimitive, NumOps, ToPrimitive}; +use num_traits::{Euclid, Float, FromPrimitive, NumOps, ToPrimitive}; + +use crate::maybe_nan::NotNone; fn float_quantile_index(q: N64, len: usize) -> N64 { q * ((len - 1) as f64) @@ -104,25 +106,69 @@ impl Interpolate for Nearest { private_impl! {} } -impl Interpolate for Midpoint -where - T: NumOps + Clone + FromPrimitive, -{ - fn needs_lower(_q: N64, _len: usize) -> bool { - true - } - fn needs_higher(_q: N64, _len: usize) -> bool { - true +macro_rules! impl_midpoint_interpolate_for_float { + ($($t:ty),*) => { + $( + impl Interpolate<$t> for Midpoint { + fn needs_lower(_q: N64, _len: usize) -> bool { + true + } + fn needs_higher(_q: N64, _len: usize) -> bool { + true + } + fn interpolate(lower: Option<$t>, higher: Option<$t>, _q: N64, _len: usize) -> $t { + let lower = lower.unwrap(); + let higher = higher.unwrap(); + lower + (higher - lower) / 2.0 + } + private_impl! {} + } + )* } - fn interpolate(lower: Option, higher: Option, _q: N64, _len: usize) -> T { - let denom = T::from_u8(2).unwrap(); - let lower = lower.unwrap(); - let higher = higher.unwrap(); - lower.clone() + (higher.clone() - lower.clone()) / denom.clone() +} + +impl_midpoint_interpolate_for_float!(f32, f64); + +macro_rules! impl_midpoint_interpolate_for_integer { + ($($t:ty),*) => { + $( + impl Interpolate<$t> for Midpoint { + fn needs_lower(_q: N64, _len: usize) -> bool { + true + } + fn needs_higher(_q: N64, _len: usize) -> bool { + true + } + fn interpolate(lower: Option<$t>, higher: Option<$t>, _q: N64, _len: usize) -> $t { + let two = <$t>::from_u8(2).unwrap(); + let (lower_half, lower_rem) = lower.unwrap().div_rem_euclid(&two); + let (higher_half, higher_rem) = higher.unwrap().div_rem_euclid(&two); + lower_half + higher_half + (lower_rem * higher_rem) + } + private_impl! {} + } + )* } - private_impl! {} } +impl_midpoint_interpolate_for_integer!( + i8, i16, i32, i64, i128, isize, u8, u16, u32, u64, u128, usize +); +impl_midpoint_interpolate_for_integer!( + NotNone, + NotNone, + NotNone, + NotNone, + NotNone, + NotNone, + NotNone, + NotNone, + NotNone, + NotNone, + NotNone, + NotNone +); + impl Interpolate for Linear where T: NumOps + Clone + FromPrimitive + ToPrimitive, @@ -143,3 +189,46 @@ where } private_impl! {} } + +#[cfg(test)] +mod tests { + use super::*; + use noisy_float::types::n64; + use quickcheck::TestResult; + use quickcheck_macros::quickcheck; + + #[derive(Clone, Copy, Debug)] + struct LowerHigherPair(T, T); + + impl quickcheck::Arbitrary for LowerHigherPair { + fn arbitrary(g: &mut G) -> Self { + let (l, h) = loop { + let (l, h) = (i64::arbitrary(g), i64::arbitrary(g)); + if l > h || h.checked_sub(l).is_none() { + continue; + } + break (l, h); + }; + LowerHigherPair(l, h) + } + } + + impl From> for (i64, i64) { + fn from(value: LowerHigherPair) -> Self { + (value.0, value.1) + } + } + + fn naive_midpoint_i64(lower: i64, higher: i64) -> i64 { + // Overflows when higher is very big and lower is very small + lower + (higher - lower) / 2 + } + + #[quickcheck] + fn test_midpoint_algo_eq_naive_algo_i64(lh: LowerHigherPair) -> TestResult { + let (lower, higher) = lh.into(); + let naive = naive_midpoint_i64(lower, higher); + let midpoint = Midpoint::interpolate(Some(lower), Some(higher), n64(0.0), 0); + TestResult::from_bool(naive == midpoint) + } +} diff --git a/tests/quantile.rs b/tests/quantile.rs index 9d58071f..5aa3e684 100644 --- a/tests/quantile.rs +++ b/tests/quantile.rs @@ -268,7 +268,7 @@ fn test_quantile_axis_skipnan_mut_linear_opt_i32() { } #[test] -fn test_midpoint_overflow() { +fn test_midpoint_overflow_unsigned() { // Regression test // This triggered an overflow panic with a naive Midpoint implementation: (a+b)/2 let mut a: Array1 = array![129, 130, 130, 131]; @@ -277,6 +277,16 @@ fn test_midpoint_overflow() { assert_eq!(median, expected_median); } +#[test] +fn test_midpoint_overflow_signed() { + // Regression test + // This triggered an overflow panic with a naive Midpoint implementation: b+(a-b)/2 + let mut a: Array1 = array![i64::MIN, i64::MAX]; + let median = a.quantile_mut(n64(0.5), &Midpoint).unwrap(); + let expected_median = -1; + assert_eq!(median, expected_median); +} + #[quickcheck] fn test_quantiles_mut(xs: Vec) -> bool { let v = Array::from(xs.clone());