Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions distr_test/tests/cdf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -437,6 +437,38 @@ fn poisson() {
}
}

#[test]
fn truncated_normal() {
let parameters = [
(0.0, 1.0, -1.0, 1.0),
(0.0, 1.0, 0.0, 2.0),
(1.0, 2.0, -1.0, 3.0),
(5.0, 0.5, 4.0, 6.0),
(10.0, 1.0, 8.0, 12.0),
];

for (seed, (mu, sigma, lower, upper)) in parameters.into_iter().enumerate() {
let dist = rand_distr::NormalTruncated::new(mu, sigma, lower, upper).unwrap();
let analytic = |x| {
if x < lower {
0.0
} else if x > upper {
1.0
} else {
let standard_lower = (lower - mu) / sigma;
let standard_upper = (upper - mu) / sigma;
let standard_x = (x - mu) / sigma;

let normal = statrs::distribution::Normal::new(0.0, 1.0).unwrap();

let z = normal.cdf(standard_upper) - normal.cdf(standard_lower);
(normal.cdf(standard_x) - normal.cdf(standard_lower)) / z
}
};
test_continuous(seed as u64, dist, analytic);
}
}

fn ln_factorial(n: u64) -> f64 {
(n as f64 + 1.0).lgamma().0
}
Expand Down
3 changes: 3 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@
//! - Misc. distributions
//! - [`InverseGaussian`] distribution
//! - [`NormalInverseGaussian`] distribution
//! - [`TruncatedNormal`] distribution

#[cfg(feature = "alloc")]
extern crate alloc;
Expand Down Expand Up @@ -112,6 +113,7 @@ pub use self::normal::{Error as NormalError, LogNormal, Normal, StandardNormal};
pub use self::normal_inverse_gaussian::{
Error as NormalInverseGaussianError, NormalInverseGaussian,
};
pub use self::normal_truncated::{Error as NormalTruncatedError, NormalTruncated};
pub use self::pareto::{Error as ParetoError, Pareto};
pub use self::pert::{Pert, PertBuilder, PertError};
pub use self::poisson::{Error as PoissonError, Poisson};
Expand Down Expand Up @@ -198,6 +200,7 @@ mod hypergeometric;
mod inverse_gaussian;
mod normal;
mod normal_inverse_gaussian;
mod normal_truncated;
mod pareto;
mod pert;
pub(crate) mod poisson;
Expand Down
211 changes: 211 additions & 0 deletions src/normal_truncated.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,211 @@
#[allow(unused_imports)]
use num_traits::Float;
use rand::{Rng, distr::Distribution};

/// The [truncated normal distribution](https://en.wikipedia.org/wiki/Truncated_normal_distribution).
///
/// # Current Implementation
/// We follow the approach described in
/// Robert, Christian P. (1995). "Simulation of truncated normal variables".
/// Statistics and Computing. 5 (2): 121–125.

#[derive(Debug)]
pub struct NormalTruncated(Method);

#[derive(Debug)]
enum Method {
Rejection(NormalTruncatedRejection),
OneSided(bool, NormalTruncatedOneSided), // bool indicates if lower bound is used
TwoSided(NormalTruncatedTwoSided),
}

#[derive(Debug)]
/// Errors that can occur when constructing a `NormalTruncated` distribution.
pub enum Error {
/// The standard deviation was not positive.
InvalidStdDev,
/// The lower bound was not less than the upper bound.
InvalidBounds,
}

impl NormalTruncated {
/// Constructs a new `NormalTruncated` distribution with the given
/// mean, standard deviation, lower bound, and upper bound.
pub fn new(mean: f64, stddev: f64, lower: f64, upper: f64) -> Result<Self, Error> {
if !(stddev > 0.0) {
return Err(Error::InvalidStdDev);
}
if !(lower < upper) {
return Err(Error::InvalidBounds);
}

let std_lower = (lower - mean) / stddev;
let std_upper = (upper - mean) / stddev;

if upper == f64::INFINITY {
// Threshold can probably be tuned better for performance
if std_lower >= 0.5 {
// One sided truncation, lower bound only
Ok(NormalTruncated(Method::OneSided(
true,
NormalTruncatedOneSided::new(mean, stddev, std_lower),
)))
} else {
// We use naive rejection sampling
// Also catches the case where both bounds are infinite
Ok(NormalTruncated(Method::Rejection(
NormalTruncatedRejection {
normal: crate::Normal::new(mean, stddev).unwrap(),
lower,
upper,
},
)))
}
} else if lower == f64::NEG_INFINITY {
// Threshold can probably be tuned better for performance
if std_upper <= -0.5 {
// One sided truncation, upper bound only
Ok(NormalTruncated(Method::OneSided(
false,
NormalTruncatedOneSided::new(-mean, stddev, -std_upper),
)))
} else {
// We use naive rejection sampling
Ok(NormalTruncated(Method::Rejection(
NormalTruncatedRejection {
normal: crate::Normal::new(mean, stddev).unwrap(),
lower,
upper,
},
)))
}
} else {
let diff = std_upper - std_lower;
// Threshold can probably be tuned better for performance
if diff >= 1.0 && std_lower <= 1.0 && std_upper >= -1.0 {
// Naive rejection sampling
Ok(NormalTruncated(Method::Rejection(
NormalTruncatedRejection {
normal: crate::Normal::new(mean, stddev).unwrap(),
lower,
upper,
},
)))
} else {
// Two sided truncation
Ok(NormalTruncated(Method::TwoSided(
NormalTruncatedTwoSided::new(mean, stddev, std_lower, std_upper),
)))
}
}
}
}

impl Distribution<f64> for NormalTruncated {
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> f64 {
match &self.0 {
Method::Rejection(rej) => rej.sample(rng),
Method::OneSided(true, one_sided) => one_sided.sample(rng),
Method::OneSided(false, one_sided) => -one_sided.sample(rng),
Method::TwoSided(two_sided) => two_sided.sample(rng),
}
}
}

/// A truncated normal distribution using naive rejection sampling.
/// We use this when the acceptance rate is high enough.
#[derive(Debug)]
struct NormalTruncatedRejection {
normal: crate::Normal<f64>,
lower: f64,
upper: f64,
}

impl Distribution<f64> for NormalTruncatedRejection {
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> f64 {
let mut sample;
loop {
sample = self.normal.sample(rng);
if sample >= self.lower && sample <= self.upper {
break;
}
}
sample
}
}

#[derive(Debug)]
struct NormalTruncatedOneSided {
alpha_star: f64,
lower_bound: f64,
exp_distribution: crate::Exp<f64>,
mu: f64,
sigma: f64,
}

impl NormalTruncatedOneSided {
fn new(mu: f64, sigma: f64, standard_lower_bound: f64) -> Self {
let alpha_star = (standard_lower_bound + (standard_lower_bound.powi(2) + 4.0).sqrt()) / 2.0;
let lambda = alpha_star;
NormalTruncatedOneSided {
alpha_star,
lower_bound: standard_lower_bound,
exp_distribution: crate::Exp::new(lambda).unwrap(),
mu,
sigma,
}
}
}

impl Distribution<f64> for NormalTruncatedOneSided {
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> f64 {
loop {
let z = self.exp_distribution.sample(rng) + self.lower_bound;
let u: f64 = rng.random();
let rho = (-0.5 * (z - self.alpha_star).powi(2)).exp();
if u <= rho {
return self.mu + self.sigma * z;
}
}
}
}

#[derive(Debug)]
struct NormalTruncatedTwoSided {
mu: f64,
sigma: f64,
// In standard normal coordinates
standard_lower: f64,
// In standard normal coordinates
standard_upper: f64,
}

impl NormalTruncatedTwoSided {
fn new(mu: f64, sigma: f64, standard_lower: f64, standard_upper: f64) -> Self {
NormalTruncatedTwoSided {
mu,
sigma,
standard_lower,
standard_upper,
}
}
}

impl Distribution<f64> for NormalTruncatedTwoSided {
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> f64 {
loop {
let z = rng.random_range(self.standard_lower..self.standard_upper);
let u: f64 = rng.random();
let rho = if self.standard_lower <= 0.0 && self.standard_upper >= 0.0 {
(-0.5 * z.powi(2)).exp()
} else if self.standard_upper < 0.0 {
(0.5 * (self.standard_upper.powi(2) - z.powi(2))).exp()
} else {
(0.5 * (self.standard_lower.powi(2) - z.powi(2))).exp()
};
if u <= rho {
return self.mu + self.sigma * z;
}
}
}
}
Loading