From ee06809a4bc80917472b0a556707c7990856fbdc Mon Sep 17 00:00:00 2001 From: Benjamin Lieser Date: Thu, 22 May 2025 17:14:02 +0200 Subject: [PATCH 01/10] first draft --- src/lib.rs | 1 + src/normal_truncated.rs | 45 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 46 insertions(+) create mode 100644 src/normal_truncated.rs diff --git a/src/lib.rs b/src/lib.rs index e1a892d..b319912 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -198,6 +198,7 @@ mod hypergeometric; mod inverse_gaussian; mod normal; mod normal_inverse_gaussian; +mod normal_truncated; mod pareto; mod pert; pub(crate) mod poisson; diff --git a/src/normal_truncated.rs b/src/normal_truncated.rs new file mode 100644 index 0000000..161020f --- /dev/null +++ b/src/normal_truncated.rs @@ -0,0 +1,45 @@ +use rand::{distr::Distribution, Rng}; + +pub struct NormalTruncated { + mean: f64, + stddev: f64, + lower: f64, + upper: f64, +} + +pub enum Error { + NonPosStdDev, +} + +impl NormalTruncated { + pub fn new(mean: f64, stddev: f64, lower: f64, upper: f64) -> Result { + if stddev <= 0.0 { + return Err(Error::NonPosStdDev); + } + Ok(NormalTruncated { + mean, + stddev, + lower, + upper, + }) + } +} + +struct NormalTruncatedRejection { + normal: crate::Normal, + lower: f64, + upper: f64, +} + +impl Distribution for NormalTruncatedRejection { + fn sample(&self, rng: &mut R) -> f64 { + let mut sample; + loop { + sample = self.normal.sample(rng); + if sample >= self.lower && sample <= self.upper { + break; + } + } + sample + } +} \ No newline at end of file From fd06f66afcb2ad41f20f9d979b534ec32fdb0c88 Mon Sep 17 00:00:00 2001 From: Benjamin Lieser Date: Thu, 22 May 2025 17:30:52 +0200 Subject: [PATCH 02/10] second draft --- src/normal_truncated.rs | 27 ++++++++++++++++----------- 1 file changed, 16 insertions(+), 11 deletions(-) diff --git a/src/normal_truncated.rs b/src/normal_truncated.rs index 161020f..7cb81f4 100644 --- a/src/normal_truncated.rs +++ b/src/normal_truncated.rs @@ -1,10 +1,9 @@ use 
rand::{distr::Distribution, Rng}; -pub struct NormalTruncated { - mean: f64, - stddev: f64, - lower: f64, - upper: f64, +pub struct NormalTruncated(Method); +pub enum Method { + Rejection(NormalTruncatedRejection), + OneSided(), } pub enum Error { @@ -16,15 +15,21 @@ impl NormalTruncated { if stddev <= 0.0 { return Err(Error::NonPosStdDev); } - Ok(NormalTruncated { - mean, - stddev, - lower, - upper, - }) + + // When the lower bound is smaller than the mean, naive rejection sampling will have at least + if lower < mean { + return Ok(NormalTruncated(Method::Rejection(NormalTruncatedRejection { + normal: crate::Normal::new(mean, stddev).unwrap(), + lower, + upper, + }))); + } + todo!() } } +/// A truncated normal distribution using naive rejection sampling. +/// We use this when the acceptance rate is high enough. struct NormalTruncatedRejection { normal: crate::Normal, lower: f64, From f0b29b35872c4f2efb5b45b2491ce8157377e482 Mon Sep 17 00:00:00 2001 From: Benjamin Lieser Date: Mon, 27 Oct 2025 15:10:49 +0100 Subject: [PATCH 03/10] one sided case --- src/normal_truncated.rs | 36 ++++++++++++++++++++++++++++++++++++ 1 file changed, 36 insertions(+) diff --git a/src/normal_truncated.rs b/src/normal_truncated.rs index 7cb81f4..3a860c2 100644 --- a/src/normal_truncated.rs +++ b/src/normal_truncated.rs @@ -47,4 +47,40 @@ impl Distribution for NormalTruncatedRejection { } sample } +} + +struct NormalTruncatedOneSided { + alpha_star: f64, + lower_bound: f64, + exp_distribution: crate::Exp, + mu: f64, + sigma: f64, +} + +impl NormalTruncatedOneSided { + fn new(mu: f64, sigma: f64, lower: f64) -> Self { + let alpha = (lower - mu) / sigma; + let alpha_star = (alpha + (alpha.powi(2) + 4.0).sqrt()) / 2.0; + let lambda = alpha_star; + NormalTruncatedOneSided { + alpha_star, + lower_bound: lower - mu, + exp_distribution: crate::Exp::new(lambda).unwrap(), + mu, + sigma, + } + } +} + +impl Distribution for NormalTruncatedOneSided { + fn sample(&self, rng: &mut R) -> f64 { + loop { 
+ let z = self.exp_distribution.sample(rng) + self.lower_bound; + let u: f64 = rng.random(); + let rho = (-0.5 * (z - self.alpha_star).powi(2)).exp(); + if u <= rho { + return self.mu + self.sigma * z; + } + } + } } \ No newline at end of file From a8a4c6104aa4013e8ea141703c976fa48e5e9c27 Mon Sep 17 00:00:00 2001 From: Benjamin Lieser Date: Mon, 27 Oct 2025 16:12:15 +0100 Subject: [PATCH 04/10] two sided case --- src/normal_truncated.rs | 43 +++++++++++++++++++++++++++++++++++++++-- 1 file changed, 41 insertions(+), 2 deletions(-) diff --git a/src/normal_truncated.rs b/src/normal_truncated.rs index 3a860c2..03f9cca 100644 --- a/src/normal_truncated.rs +++ b/src/normal_truncated.rs @@ -59,8 +59,8 @@ struct NormalTruncatedOneSided { impl NormalTruncatedOneSided { fn new(mu: f64, sigma: f64, lower: f64) -> Self { - let alpha = (lower - mu) / sigma; - let alpha_star = (alpha + (alpha.powi(2) + 4.0).sqrt()) / 2.0; + let standart_lower_bound = (lower - mu) / sigma; + let alpha_star = (standart_lower_bound + (standart_lower_bound.powi(2) + 4.0).sqrt()) / 2.0; let lambda = alpha_star; NormalTruncatedOneSided { alpha_star, @@ -83,4 +83,43 @@ impl Distribution for NormalTruncatedOneSided { } } } +} + +struct NormalTruncatedTwoSided { + mu: f64, + sigma: f64, + // In standard normal coordinates + lower: f64, + // In standard normal coordinates + upper: f64, +} + +impl NormalTruncatedTwoSided { + fn new(mu: f64, sigma: f64, lower: f64, upper: f64) -> Self { + NormalTruncatedTwoSided { + mu, + sigma, + lower: (lower - mu) / sigma, + upper: (upper - mu) / sigma, + } + } +} + +impl Distribution for NormalTruncatedTwoSided { + fn sample(&self, rng: &mut R) -> f64 { + loop { + let z = rng.random_range(self.lower..self.upper); + let u: f64 = rng.random(); + let rho = if self.lower <= 0.0 && self.upper >= 0.0 { + (-0.5 * z.powi(2)).exp() + } else if self.upper < 0.0 { + (0.5 * (self.upper.powi(2) - z.powi(2))).exp() + } else { + (0.5 * (self.lower.powi(2) - z.powi(2))).exp() + }; + 
if u <= rho { + return self.mu + self.sigma * z; + } + } + } } \ No newline at end of file From 89dd17d95e06793a1fc3c05d347068b31eccec16 Mon Sep 17 00:00:00 2001 From: Benjamin Lieser Date: Mon, 27 Oct 2025 16:40:20 +0100 Subject: [PATCH 05/10] first prototype --- src/lib.rs | 2 + src/normal_truncated.rs | 146 +++++++++++++++++++++++++++++++--------- 2 files changed, 117 insertions(+), 31 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index b319912..8339926 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -80,6 +80,7 @@ //! - Misc. distributions //! - [`InverseGaussian`] distribution //! - [`NormalInverseGaussian`] distribution +//! - [`NormalTruncated`] distribution #[cfg(feature = "alloc")] extern crate alloc; @@ -112,6 +113,7 @@ pub use self::normal::{Error as NormalError, LogNormal, Normal, StandardNormal}; pub use self::normal_inverse_gaussian::{ Error as NormalInverseGaussianError, NormalInverseGaussian, }; +pub use self::normal_truncated::{Error as NormalTruncatedError, NormalTruncated}; pub use self::pareto::{Error as ParetoError, Pareto}; pub use self::pert::{Pert, PertBuilder, PertError}; pub use self::poisson::{Error as PoissonError, Poisson}; diff --git a/src/normal_truncated.rs b/src/normal_truncated.rs index 03f9cca..758d3f6 100644 --- a/src/normal_truncated.rs +++ b/src/normal_truncated.rs @@ -1,35 +1,118 @@ -use rand::{distr::Distribution, Rng}; +use rand::{Rng, distr::Distribution}; +/// The [truncated normal distribution](https://en.wikipedia.org/wiki/Truncated_normal_distribution). +/// +/// # Current Implementation +/// We follow the approach described in +/// Robert, Christian P. (1995). "Simulation of truncated normal variables". +/// Statistics and Computing. 5 (2): 121–125.
+ +#[derive(Debug)] pub struct NormalTruncated(Method); -pub enum Method { + +#[derive(Debug)] +enum Method { Rejection(NormalTruncatedRejection), - OneSided(), + OneSided(bool, NormalTruncatedOneSided), // bool indicates if lower bound is used + TwoSided(NormalTruncatedTwoSided), } +#[derive(Debug)] +/// Errors that can occur when constructing a `NormalTruncated` distribution. pub enum Error { - NonPosStdDev, + /// The standard deviation was not positive. + InvalidStdDev, + /// The lower bound was not less than the upper bound. + InvalidBounds, } impl NormalTruncated { + /// Constructs a new `NormalTruncated` distribution with the given + /// mean, standard deviation, lower bound, and upper bound. pub fn new(mean: f64, stddev: f64, lower: f64, upper: f64) -> Result { - if stddev <= 0.0 { - return Err(Error::NonPosStdDev); + if !(stddev > 0.0) { + return Err(Error::InvalidStdDev); + } + if !(lower < upper) { + return Err(Error::InvalidBounds); + } + + let std_lower = (lower - mean) / stddev; + let std_upper = (upper - mean) / stddev; + + if upper == f64::INFINITY { + // Threshold can probably be tuned better for performance + if std_lower >= 0.5 { + // One sided truncation, lower bound only + return Ok(NormalTruncated(Method::OneSided( + true, + NormalTruncatedOneSided::new(mean, stddev, std_lower), + ))); + } else { + // We use naive rejection sampling + // Also catches the case where both bounds are infinite + return Ok(NormalTruncated(Method::Rejection( + NormalTruncatedRejection { + normal: crate::Normal::new(mean, stddev).unwrap(), + lower, + upper, + }, + ))); + } + } else if lower == f64::NEG_INFINITY { + // Threshold can probably be tuned better for performance + if std_upper <= -0.5 { + // One sided truncation, upper bound only + return Ok(NormalTruncated(Method::OneSided( + false, + NormalTruncatedOneSided::new(-mean, stddev, -std_upper), + ))); + } else { + // We use naive rejection sampling + return Ok(NormalTruncated(Method::Rejection( + 
NormalTruncatedRejection { + normal: crate::Normal::new(mean, stddev).unwrap(), + lower, + upper, + }, + ))); + } + } else { + let diff = std_upper - std_lower; + // Threshold can probably be tuned better for performance + if diff >= 1.0 && std_lower <= 1.0 && std_upper >= -1.0 { + // Naive rejection sampling + return Ok(NormalTruncated(Method::Rejection( + NormalTruncatedRejection { + normal: crate::Normal::new(mean, stddev).unwrap(), + lower, + upper, + }, + ))); + } else { + // Two sided truncation + return Ok(NormalTruncated(Method::TwoSided( + NormalTruncatedTwoSided::new(mean, stddev, std_lower, std_upper), + ))); + } + } + } +} + +impl Distribution for NormalTruncated { + fn sample(&self, rng: &mut R) -> f64 { + match &self.0 { + Method::Rejection(rej) => rej.sample(rng), + Method::OneSided(true, one_sided) => one_sided.sample(rng), + Method::OneSided(false, one_sided) => -one_sided.sample(rng), + Method::TwoSided(two_sided) => two_sided.sample(rng), } - - // When the lower bound is smaller than the mean, naive rejection sampling will have at least - if lower < mean { - return Ok(NormalTruncated(Method::Rejection(NormalTruncatedRejection { - normal: crate::Normal::new(mean, stddev).unwrap(), - lower, - upper, - }))); - } - todo!() } } /// A truncated normal distribution using naive rejection sampling. /// We use this when the acceptance rate is high enough. 
+#[derive(Debug)] struct NormalTruncatedRejection { normal: crate::Normal, lower: f64, @@ -49,6 +132,7 @@ impl Distribution for NormalTruncatedRejection { } } +#[derive(Debug)] struct NormalTruncatedOneSided { alpha_star: f64, lower_bound: f64, @@ -58,13 +142,12 @@ struct NormalTruncatedOneSided { } impl NormalTruncatedOneSided { - fn new(mu: f64, sigma: f64, lower: f64) -> Self { - let standart_lower_bound = (lower - mu) / sigma; - let alpha_star = (standart_lower_bound + (standart_lower_bound.powi(2) + 4.0).sqrt()) / 2.0; + fn new(mu: f64, sigma: f64, standard_lower_bound: f64) -> Self { + let alpha_star = (standard_lower_bound + (standard_lower_bound.powi(2) + 4.0).sqrt()) / 2.0; let lambda = alpha_star; NormalTruncatedOneSided { alpha_star, - lower_bound: lower - mu, + lower_bound: standard_lower_bound, exp_distribution: crate::Exp::new(lambda).unwrap(), mu, sigma, @@ -85,22 +168,23 @@ impl Distribution for NormalTruncatedOneSided { } } +#[derive(Debug)] struct NormalTruncatedTwoSided { mu: f64, sigma: f64, // In standard normal coordinates - lower: f64, + standard_lower: f64, // In standard normal coordinates - upper: f64, + standard_upper: f64, } impl NormalTruncatedTwoSided { - fn new(mu: f64, sigma: f64, lower: f64, upper: f64) -> Self { + fn new(mu: f64, sigma: f64, standard_lower: f64, standard_upper: f64) -> Self { NormalTruncatedTwoSided { mu, sigma, - lower: (lower - mu) / sigma, - upper: (upper - mu) / sigma, + standard_lower, + standard_upper, } } } @@ -108,18 +192,18 @@ impl NormalTruncatedTwoSided { impl Distribution for NormalTruncatedTwoSided { fn sample(&self, rng: &mut R) -> f64 { loop { - let z = rng.random_range(self.lower..self.upper); + let z = rng.random_range(self.standard_lower..self.standard_upper); let u: f64 = rng.random(); - let rho = if self.lower <= 0.0 && self.upper >= 0.0 { + let rho = if self.standard_lower <= 0.0 && self.standard_upper >= 0.0 { (-0.5 * z.powi(2)).exp() - } else if self.upper < 0.0 { - (0.5 * (self.upper.powi(2) 
- z.powi(2))).exp() + } else if self.standard_upper < 0.0 { + (0.5 * (self.standard_upper.powi(2) - z.powi(2))).exp() } else { - (0.5 * (self.lower.powi(2) - z.powi(2))).exp() + (0.5 * (self.standard_lower.powi(2) - z.powi(2))).exp() }; if u <= rho { return self.mu + self.sigma * z; } } } -} \ No newline at end of file +} From 7378ba2051ea9885fa72d54f84b439a9ef51f620 Mon Sep 17 00:00:00 2001 From: Benjamin Lieser Date: Mon, 27 Oct 2025 16:50:29 +0100 Subject: [PATCH 06/10] added ks test --- distr_test/tests/cdf.rs | 36 ++++++++++++++++++++++++++++++++++++ 1 file changed, 36 insertions(+) diff --git a/distr_test/tests/cdf.rs b/distr_test/tests/cdf.rs index 2dd6639..34bb60b 100644 --- a/distr_test/tests/cdf.rs +++ b/distr_test/tests/cdf.rs @@ -437,6 +437,42 @@ fn poisson() { } } +#[test] +fn truncated_normal() { + let parameters = [ + (0.0, 1.0, -1.0, 1.0), + (0.0, 1.0, 0.0, 2.0), + (1.0, 2.0, -1.0, 3.0), + (5.0, 0.5, 4.0, 6.0), + (10.0, 1.0, 8.0, 12.0), + ]; + + for (seed, (mu, sigma, lower, upper)) in parameters.into_iter().enumerate() { + let dist = rand_distr::NormalTruncated::new(mu, sigma, lower, upper).unwrap(); + let analytic = |x| { + if x < lower { + 0.0 + } else if x > upper { + 1.0 + } else { + let standard_lower = (lower - mu) / sigma; + let standard_upper = (upper - mu) / sigma; + let standard_x = (x - mu) / sigma; + + let normal = statrs::distribution::Normal::new(0.0, 1.0).unwrap(); + + let z = normal + .cdf(standard_upper) + - normal.cdf(standard_lower); + (normal.cdf(standard_x) + - normal.cdf(standard_lower)) + / Z + } + }; + test_continuous(seed as u64, dist, analytic); + } +} + fn ln_factorial(n: u64) -> f64 { (n as f64 + 1.0).lgamma().0 } From a2afcec666f521001509dc52d80d74e3d0df5c76 Mon Sep 17 00:00:00 2001 From: Benjamin Lieser Date: Mon, 27 Oct 2025 16:56:01 +0100 Subject: [PATCH 07/10] fmt and num_traits Float --- distr_test/tests/cdf.rs | 8 ++------ src/normal_truncated.rs | 2 ++ 2 files changed, 4 insertions(+), 6 deletions(-) diff --git 
a/distr_test/tests/cdf.rs b/distr_test/tests/cdf.rs index 34bb60b..546c9ef 100644 --- a/distr_test/tests/cdf.rs +++ b/distr_test/tests/cdf.rs @@ -461,12 +461,8 @@ fn truncated_normal() { let normal = statrs::distribution::Normal::new(0.0, 1.0).unwrap(); - let z = normal - .cdf(standard_upper) - - normal.cdf(standard_lower); - (normal.cdf(standard_x) - - normal.cdf(standard_lower)) - / Z + let z = normal.cdf(standard_upper) - normal.cdf(standard_lower); + (normal.cdf(standard_x) - normal.cdf(standard_lower)) / Z } }; test_continuous(seed as u64, dist, analytic); diff --git a/src/normal_truncated.rs b/src/normal_truncated.rs index 758d3f6..64a0e1e 100644 --- a/src/normal_truncated.rs +++ b/src/normal_truncated.rs @@ -1,4 +1,6 @@ use rand::{Rng, distr::Distribution}; +#[allow(unused_imports)] +use num_traits::Float; /// The [truncated normal distribution](https://en.wikipedia.org/wiki/Truncated_normal_distribution). /// From 3a5f549d077c93007bada273e00372bbf288e483 Mon Sep 17 00:00:00 2001 From: Benjamin Lieser Date: Mon, 27 Oct 2025 16:58:20 +0100 Subject: [PATCH 08/10] clippy --- src/normal_truncated.rs | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/src/normal_truncated.rs b/src/normal_truncated.rs index 64a0e1e..724e683 100644 --- a/src/normal_truncated.rs +++ b/src/normal_truncated.rs @@ -46,56 +46,56 @@ impl NormalTruncated { // Threshold can probably be tuned better for performance if std_lower >= 0.5 { // One sided truncation, lower bound only - return Ok(NormalTruncated(Method::OneSided( + Ok(NormalTruncated(Method::OneSided( true, NormalTruncatedOneSided::new(mean, stddev, std_lower), - ))); + ))) } else { // We use naive rejection sampling // Also catches the case where both bounds are infinite - return Ok(NormalTruncated(Method::Rejection( + Ok(NormalTruncated(Method::Rejection( NormalTruncatedRejection { normal: crate::Normal::new(mean, stddev).unwrap(), lower, upper, }, - ))); + ))) } } else if lower == 
f64::NEG_INFINITY { // Threshold can probably be tuned better for performance if std_upper <= -0.5 { // One sided truncation, upper bound only - return Ok(NormalTruncated(Method::OneSided( + Ok(NormalTruncated(Method::OneSided( false, NormalTruncatedOneSided::new(-mean, stddev, -std_upper), - ))); + ))) } else { // We use naive rejection sampling - return Ok(NormalTruncated(Method::Rejection( + Ok(NormalTruncated(Method::Rejection( NormalTruncatedRejection { normal: crate::Normal::new(mean, stddev).unwrap(), lower, upper, }, - ))); + ))) } } else { let diff = std_upper - std_lower; // Threshold can probably be tuned better for performance if diff >= 1.0 && std_lower <= 1.0 && std_upper >= -1.0 { // Naive rejection sampling - return Ok(NormalTruncated(Method::Rejection( + Ok(NormalTruncated(Method::Rejection( NormalTruncatedRejection { normal: crate::Normal::new(mean, stddev).unwrap(), lower, upper, }, - ))); + ))) } else { // Two sided truncation - return Ok(NormalTruncated(Method::TwoSided( + Ok(NormalTruncated(Method::TwoSided( NormalTruncatedTwoSided::new(mean, stddev, std_lower, std_upper), - ))); + ))) } } } From c4f74e12e1710e8b7f617957ab3873e6e0e78f94 Mon Sep 17 00:00:00 2001 From: Benjamin Lieser Date: Mon, 27 Oct 2025 16:59:54 +0100 Subject: [PATCH 09/10] typo --- distr_test/tests/cdf.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/distr_test/tests/cdf.rs b/distr_test/tests/cdf.rs index 546c9ef..48c5d93 100644 --- a/distr_test/tests/cdf.rs +++ b/distr_test/tests/cdf.rs @@ -462,7 +462,7 @@ fn truncated_normal() { let normal = statrs::distribution::Normal::new(0.0, 1.0).unwrap(); let z = normal.cdf(standard_upper) - normal.cdf(standard_lower); - (normal.cdf(standard_x) - normal.cdf(standard_lower)) / Z + (normal.cdf(standard_x) - normal.cdf(standard_lower)) / z } }; test_continuous(seed as u64, dist, analytic); From e2e2f3dff4a25d0dd3b97cb410d5c2dd5064fbef Mon Sep 17 00:00:00 2001 From: Benjamin Lieser Date: Mon, 27 Oct 2025 17:02:50 
+0100 Subject: [PATCH 10/10] fmt --- src/normal_truncated.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/normal_truncated.rs b/src/normal_truncated.rs index 724e683..7e60618 100644 --- a/src/normal_truncated.rs +++ b/src/normal_truncated.rs @@ -1,6 +1,6 @@ -use rand::{Rng, distr::Distribution}; #[allow(unused_imports)] use num_traits::Float; +use rand::{Rng, distr::Distribution}; /// The [truncated normal distribution](https://en.wikipedia.org/wiki/Truncated_normal_distribution). ///