Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 14 additions & 1 deletion src/maybe_nan/impl_not_none.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use super::NotNone;
use num_traits::{FromPrimitive, ToPrimitive};
use num_traits::{Euclid, FromPrimitive, ToPrimitive};
use std::cmp;
use std::fmt;
use std::ops::{Add, Deref, DerefMut, Div, Mul, Rem, Sub};
Expand Down Expand Up @@ -101,6 +101,19 @@ impl<T: Rem> Rem for NotNone<T> {
}
}

impl<T: Euclid> Euclid for NotNone<T> {
#[inline]
fn div_euclid(&self, rhs: &Self) -> Self {
let result = self.deref().div_euclid(rhs.deref());
NotNone(Some(result))
}
#[inline]
fn rem_euclid(&self, rhs: &Self) -> Self {
let result = self.deref().rem_euclid(rhs.deref());
NotNone(Some(result))
}
}

impl<T: ToPrimitive> ToPrimitive for NotNone<T> {
#[inline]
fn to_isize(&self) -> Option<isize> {
Expand Down
121 changes: 105 additions & 16 deletions src/quantile/interpolate.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
//! Interpolation strategies.
use noisy_float::types::N64;
use num_traits::{Float, FromPrimitive, NumOps, ToPrimitive};
use num_traits::{Euclid, Float, FromPrimitive, NumOps, ToPrimitive};

use crate::maybe_nan::NotNone;

fn float_quantile_index(q: N64, len: usize) -> N64 {
q * ((len - 1) as f64)
Expand Down Expand Up @@ -104,25 +106,69 @@ impl<T> Interpolate<T> for Nearest {
private_impl! {}
}

impl<T> Interpolate<T> for Midpoint
where
T: NumOps + Clone + FromPrimitive,
{
fn needs_lower(_q: N64, _len: usize) -> bool {
true
}
fn needs_higher(_q: N64, _len: usize) -> bool {
true
macro_rules! impl_midpoint_interpolate_for_float {
($($t:ty),*) => {
$(
impl Interpolate<$t> for Midpoint {
fn needs_lower(_q: N64, _len: usize) -> bool {
true
}
fn needs_higher(_q: N64, _len: usize) -> bool {
true
}
fn interpolate(lower: Option<$t>, higher: Option<$t>, _q: N64, _len: usize) -> $t {
let lower = lower.unwrap();
let higher = higher.unwrap();
lower + (higher - lower) / 2.0
}
private_impl! {}
}
)*
}
fn interpolate(lower: Option<T>, higher: Option<T>, _q: N64, _len: usize) -> T {
let denom = T::from_u8(2).unwrap();
let lower = lower.unwrap();
let higher = higher.unwrap();
lower.clone() + (higher.clone() - lower.clone()) / denom.clone()
}

impl_midpoint_interpolate_for_float!(f32, f64);

macro_rules! impl_midpoint_interpolate_for_integer {
($($t:ty),*) => {
$(
impl Interpolate<$t> for Midpoint {
fn needs_lower(_q: N64, _len: usize) -> bool {
true
}
fn needs_higher(_q: N64, _len: usize) -> bool {
true
}
fn interpolate(lower: Option<$t>, higher: Option<$t>, _q: N64, _len: usize) -> $t {
let two = <$t>::from_u8(2).unwrap();
let (lower_half, lower_rem) = lower.unwrap().div_rem_euclid(&two);
let (higher_half, higher_rem) = higher.unwrap().div_rem_euclid(&two);
lower_half + higher_half + (lower_rem * higher_rem)
}
private_impl! {}
}
)*
}
private_impl! {}
}

impl_midpoint_interpolate_for_integer!(
i8, i16, i32, i64, i128, isize, u8, u16, u32, u64, u128, usize
);
impl_midpoint_interpolate_for_integer!(
NotNone<i8>,
NotNone<i16>,
NotNone<i32>,
NotNone<i64>,
NotNone<i128>,
NotNone<isize>,
NotNone<u8>,
NotNone<u16>,
NotNone<u32>,
NotNone<u64>,
NotNone<u128>,
NotNone<usize>
);

impl<T> Interpolate<T> for Linear
where
T: NumOps + Clone + FromPrimitive + ToPrimitive,
Expand All @@ -143,3 +189,46 @@ where
}
private_impl! {}
}

#[cfg(test)]
mod tests {
use super::*;
use noisy_float::types::n64;
use quickcheck::TestResult;
use quickcheck_macros::quickcheck;

#[derive(Clone, Copy, Debug)]
struct LowerHigherPair<T>(T, T);

impl quickcheck::Arbitrary for LowerHigherPair<i64> {
fn arbitrary<G: quickcheck::Gen>(g: &mut G) -> Self {
let (l, h) = loop {
let (l, h) = (i64::arbitrary(g), i64::arbitrary(g));
if l > h || h.checked_sub(l).is_none() {
continue;
}
break (l, h);
};
LowerHigherPair(l, h)
}
}

impl From<LowerHigherPair<i64>> for (i64, i64) {
fn from(value: LowerHigherPair<i64>) -> Self {
(value.0, value.1)
}
}

fn naive_midpoint_i64(lower: i64, higher: i64) -> i64 {
// Overflows when higher is very big and lower is very small
lower + (higher - lower) / 2
}

#[quickcheck]
fn test_midpoint_algo_eq_naive_algo_i64(lh: LowerHigherPair<i64>) -> TestResult {
let (lower, higher) = lh.into();
let naive = naive_midpoint_i64(lower, higher);
let midpoint = Midpoint::interpolate(Some(lower), Some(higher), n64(0.0), 0);
TestResult::from_bool(naive == midpoint)
}
}
12 changes: 11 additions & 1 deletion tests/quantile.rs
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,7 @@ fn test_quantile_axis_skipnan_mut_linear_opt_i32() {
}

#[test]
fn test_midpoint_overflow() {
fn test_midpoint_overflow_unsigned() {
// Regression test
// This triggered an overflow panic with a naive Midpoint implementation: (a+b)/2
let mut a: Array1<u8> = array![129, 130, 130, 131];
Expand All @@ -277,6 +277,16 @@ fn test_midpoint_overflow() {
assert_eq!(median, expected_median);
}

#[test]
fn test_midpoint_overflow_signed() {
// Regression test
// This triggered an overflow panic with a naive Midpoint implementation: b+(a-b)/2
let mut a: Array1<i64> = array![i64::MIN, i64::MAX];
let median = a.quantile_mut(n64(0.5), &Midpoint).unwrap();
let expected_median = -1;
assert_eq!(median, expected_median);
}

#[quickcheck]
fn test_quantiles_mut(xs: Vec<i64>) -> bool {
let v = Array::from(xs.clone());
Expand Down