diff --git a/Cargo.lock b/Cargo.lock index 8b99b1e2..eeb26ec4 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -209,6 +209,18 @@ dependencies = [ "wasi 0.11.1+wasi-snapshot-preview1", ] +[[package]] +name = "getrandom" +version = "0.3.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "899def5c37c4fd7b2664648c28120ecec138e4d395b459e5ca34f9cce2dd77fd" +dependencies = [ + "cfg-if", + "libc", + "r-efi", + "wasip2", +] + [[package]] name = "half" version = "2.2.1" @@ -315,9 +327,9 @@ checksum = "f52b00d39961fc5b2736ea853c9cc86238e165017a493d1d5c8eac6bdc4cc273" [[package]] name = "ndarray" -version = "0.16.1" +version = "0.17.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "882ed72dce9365842bf196bdeedf5055305f11fc8c03dee7bb0194a6cad34841" +checksum = "0c7c9125e8f6f10c9da3aad044cc918cf8784fa34de857b1aa68038eb05a50a9" dependencies = [ "approx", "matrixmultiply", @@ -331,12 +343,12 @@ dependencies = [ [[package]] name = "ndarray-rand" -version = "0.15.0" +version = "0.16.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f093b3db6fd194718dcdeea6bd8c829417deae904e3fcc7732dabcd4416d25d8" +checksum = "180f724d496e84764e8ecf28fbe1da74ef231ec4ba15be65a9100be8445d73e3" dependencies = [ "ndarray", - "rand 0.8.5", + "rand 0.9.2", "rand_distr", ] @@ -509,6 +521,12 @@ dependencies = [ "proc-macro2", ] +[[package]] +name = "r-efi" +version = "5.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "69cdb34c158ceb288df11e18b4bd39de994f6657d83847bdffdbd7f346754b0f" + [[package]] name = "rand" version = "0.7.3" @@ -533,6 +551,16 @@ dependencies = [ "rand_core 0.6.4", ] +[[package]] +name = "rand" +version = "0.9.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6db2770f06117d490610c7488547d543617b21bfa07796d7a12f6f1bd53850d1" +dependencies = [ + "rand_chacha 0.9.0", + "rand_core 0.9.3", +] + [[package]] name = "rand_chacha" version = "0.2.2" @@ -553,6 +581,16 @@ dependencies = [ "rand_core 0.6.4", ] +[[package]] +name = "rand_chacha" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d3022b5f1df60f26e1ffddd6c66e8aa15de382ae63b3a0c1bfc0e4d3e3f325cb" +dependencies = [ + "ppv-lite86", + "rand_core 0.9.3", +] + [[package]] name = "rand_core" version = "0.5.1" @@ -571,14 +609,23 @@ dependencies = [ "getrandom 0.2.16", ] +[[package]] +name = "rand_core" +version = "0.9.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "99d9a13982dcf210057a8a78572b2217b667c3beacbf3a0d8b454f6f82837d38" +dependencies = [ + "getrandom 0.3.4", +] + [[package]] name = "rand_distr" -version = "0.4.3" +version = "0.5.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "32cb0b9bc82b0a0876c2dd994a7e7a2683d3e7390ca40e6886785ef0c7e3ee31" +checksum = "6a8615d50dcf34fa31f7ab52692afec947c4dd0ab803cc87cb3b0b4570ff7463" dependencies = [ "num-traits", - "rand 0.8.5", + "rand 0.9.2", ] [[package]] @@ -758,6 +805,15 @@ version = "0.11.1+wasi-snapshot-preview1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ccf3ec651a847eb01de73ccad15eb7d99f80485de043efb2f370cd654f4ea44b" +[[package]] +name = "wasip2" +version = "1.0.1+wasi-0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0562428422c63773dad2c345a1882263bbf4d65cf3f42e90921f787ef5ad58e7" +dependencies = [ + "wit-bindgen", +] + [[package]] name = "wasm-bindgen" version = "0.2.105" @@ -902,6 +958,12 @@ version = "0.53.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d6bbff5f0aada427a1e5a6da5f1f98158182f26556f345ac9e04d36d0ebed650" +[[package]] +name = "wit-bindgen" +version = "0.46.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f17a85883d4e6d00e8a97c586de764dabcc06133f7f1d55dce5cdc070ad7fe59" + [[package]] name = "zerocopy" version = "0.8.27" diff --git a/Cargo.toml b/Cargo.toml index 9377d81e..b0e3bdbb 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -14,13 +14,13 @@ repository = "https://github.com/rust-ndarray/ndarray-stats" documentation = "https://docs.rs/ndarray-stats/" readme = "README.md" -description = "Statistical routines for ArrayBase, the n-dimensional array data structure provided by ndarray." +description = "Statistical routines for the n-dimensional array data structures provided by ndarray." keywords = ["array", "multidimensional", "statistics", "matrix", "ndarray"] categories = ["data-structures", "science"] [dependencies] -ndarray = "0.16.0" +ndarray = "0.17.1" noisy_float = "0.2.0" num-integer = "0.1" num-traits = "0.2" @@ -29,10 +29,10 @@ itertools = { version = "0.13", default-features = false } indexmap = "2.4" [dev-dependencies] -ndarray = { version = "0.16.1", features = ["approx"] } +ndarray = { version = "0.17.1", features = ["approx"] } criterion = "0.5.1" quickcheck = { version = "0.9.2", default-features = false } -ndarray-rand = "0.15.0" +ndarray-rand = "0.16.0" approx = "0.5" quickcheck_macros = "1.0.0" num-bigint = "0.4.0" diff --git a/README.md b/README.md index 3a565da8..77aa65c4 100644 --- a/README.md +++ b/README.md @@ -5,7 +5,7 @@ [![Crate](https://img.shields.io/crates/v/ndarray-stats.svg)](https://crates.io/crates/ndarray-stats) [![Documentation](https://docs.rs/ndarray-stats/badge.svg)](https://docs.rs/ndarray-stats) -This crate provides statistical methods for [`ndarray`]'s `ArrayBase` type. +This crate provides statistical methods for [`ndarray`]'s `ArrayRef` type. Currently available routines include: - order statistics (minimum, maximum, median, quantiles, etc.); diff --git a/benches/deviation.rs b/benches/deviation.rs index c0ceecb5..2cd9b917 100644 --- a/benches/deviation.rs +++ b/benches/deviation.rs @@ -12,8 +12,8 @@ fn sq_l2_dist(c: &mut Criterion) { group.plot_config(PlotConfiguration::default().summary_scale(AxisScale::Logarithmic)); for len in &lens { group.bench_with_input(format!("{}", len), len, |b, &len| { - let data = Array::random(len, Uniform::new(0.0, 1.0)); - let data2 = Array::random(len, Uniform::new(0.0, 1.0)); + let data = Array::random(len, Uniform::new(0.0, 1.0).unwrap()); + let data2 = Array::random(len, Uniform::new(0.0, 1.0).unwrap()); b.iter(|| black_box(data.sq_l2_dist(&data2).unwrap())) }); diff --git a/benches/summary_statistics.rs b/benches/summary_statistics.rs index 5796fc02..64f22be8 100644 --- a/benches/summary_statistics.rs +++ b/benches/summary_statistics.rs @@ -12,8 +12,8 @@ fn weighted_std(c: &mut Criterion) { group.plot_config(PlotConfiguration::default().summary_scale(AxisScale::Logarithmic)); for len in &lens { group.bench_with_input(format!("{}", len), len, |b, &len| { - let data = Array::random(len, Uniform::new(0.0, 1.0)); - let mut weights = Array::random(len, Uniform::new(0.0, 1.0)); + let data = Array::random(len, Uniform::new(0.0, 1.0).unwrap()); + let mut weights = Array::random(len, Uniform::new(0.0, 1.0).unwrap()); weights /= weights.sum(); b.iter_batched( || data.clone(), diff --git a/src/correlation.rs b/src/correlation.rs index 5ae194ba..613b5f25 100644 --- a/src/correlation.rs +++ b/src/correlation.rs @@ -1,14 +1,10 @@ use crate::errors::EmptyInput; use ndarray::prelude::*; -use ndarray::Data; use num_traits::{Float, FromPrimitive}; -/// Extension trait for `ArrayBase` providing functions +/// Extension trait for `ndarray` providing functions /// to compute different correlation measures. -pub trait CorrelationExt -where - S: Data, -{ +pub trait CorrelationExt { /// Return the covariance matrix `C` for a 2-dimensional /// array of observations `M`. /// @@ -125,10 +121,7 @@ where private_decl! {} } -impl CorrelationExt for ArrayBase -where - S: Data, -{ +impl CorrelationExt for ArrayRef2 { fn cov(&self, ddof: A) -> Result, EmptyInput> where A: Float + FromPrimitive, @@ -147,7 +140,7 @@ where let mean = self.mean_axis(observation_axis); match mean { Some(mean) => { - let denoised = self - &mean.insert_axis(observation_axis); + let denoised = self - mean.insert_axis(observation_axis); let covariance = denoised.dot(&denoised.t()); Ok(covariance.mapv_into(|x| x / dof)) } @@ -208,7 +201,7 @@ mod cov_tests { let n_observations = 4; let a = Array::random( (n_random_variables, n_observations), - Uniform::new(-bound.abs(), bound.abs()), + Uniform::new(-bound.abs(), bound.abs()).unwrap(), ); let covariance = a.cov(1.).unwrap(); abs_diff_eq!(covariance, &covariance.t(), epsilon = 1e-8) @@ -219,7 +212,10 @@ mod cov_tests { fn test_invalid_ddof() { let n_random_variables = 3; let n_observations = 4; - let a = Array::random((n_random_variables, n_observations), Uniform::new(0., 10.)); + let a = Array::random( + (n_random_variables, n_observations), + Uniform::new(0., 10.).unwrap(), + ); let invalid_ddof = (n_observations as f64) + rand::random::().abs(); let _ = a.cov(invalid_ddof); } @@ -299,7 +295,7 @@ mod pearson_correlation_tests { let n_observations = 4; let a = Array::random( (n_random_variables, n_observations), - Uniform::new(-bound.abs(), bound.abs()), + Uniform::new(-bound.abs(), bound.abs()).unwrap(), ); let pearson_correlation = a.pearson_correlation().unwrap(); abs_diff_eq!( diff --git a/src/deviation.rs b/src/deviation.rs index de85885f..3c357466 100644 --- a/src/deviation.rs +++ b/src/deviation.rs @@ -1,15 +1,14 @@ -use ndarray::{ArrayBase, Data, Dimension, Zip}; +use ndarray::{ArrayRef, Dimension, Zip}; use num_traits::{Signed, ToPrimitive}; use std::convert::Into; use std::ops::AddAssign; use crate::errors::MultiInputError; -/// An extension trait for `ArrayBase` providing functions +/// An extension trait for `ndarray` providing functions /// to compute different deviation measures. -pub trait DeviationExt +pub trait DeviationExt where - S: Data, D: Dimension, { /// Counts the number of indices at which the elements of the arrays `self` @@ -19,10 +18,9 @@ where /// /// * `MultiInputError::EmptyInput` if `self` is empty /// * `MultiInputError::ShapeMismatch` if `self` and `other` don't have the same shape - fn count_eq(&self, other: &ArrayBase) -> Result + fn count_eq(&self, other: &ArrayRef) -> Result where - A: PartialEq, - T: Data; + A: PartialEq; /// Counts the number of indices at which the elements of the arrays `self` /// and `other` are not equal. @@ -31,10 +29,9 @@ where /// /// * `MultiInputError::EmptyInput` if `self` is empty /// * `MultiInputError::ShapeMismatch` if `self` and `other` don't have the same shape - fn count_neq(&self, other: &ArrayBase) -> Result + fn count_neq(&self, other: &ArrayRef) -> Result where - A: PartialEq, - T: Data; + A: PartialEq; /// Computes the [squared L2 distance] between `self` and `other`. /// @@ -52,10 +49,9 @@ where /// * `MultiInputError::ShapeMismatch` if `self` and `other` don't have the same shape /// /// [squared L2 distance]: https://en.wikipedia.org/wiki/Euclidean_distance#Squared_Euclidean_distance - fn sq_l2_dist(&self, other: &ArrayBase) -> Result + fn sq_l2_dist(&self, other: &ArrayRef) -> Result where - A: AddAssign + Clone + Signed, - T: Data; + A: AddAssign + Clone + Signed; /// Computes the [L2 distance] between `self` and `other`. /// @@ -75,10 +71,9 @@ where /// **Panics** if the type cast from `A` to `f64` fails. /// /// [L2 distance]: https://en.wikipedia.org/wiki/Euclidean_distance - fn l2_dist(&self, other: &ArrayBase) -> Result + fn l2_dist(&self, other: &ArrayRef) -> Result where - A: AddAssign + Clone + Signed + ToPrimitive, - T: Data; + A: AddAssign + Clone + Signed + ToPrimitive; /// Computes the [L1 distance] between `self` and `other`. /// @@ -96,10 +91,9 @@ where /// * `MultiInputError::ShapeMismatch` if `self` and `other` don't have the same shape /// /// [L1 distance]: https://en.wikipedia.org/wiki/Taxicab_geometry - fn l1_dist(&self, other: &ArrayBase) -> Result + fn l1_dist(&self, other: &ArrayRef) -> Result where - A: AddAssign + Clone + Signed, - T: Data; + A: AddAssign + Clone + Signed; /// Computes the [L∞ distance] between `self` and `other`. /// @@ -116,10 +110,9 @@ where /// * `MultiInputError::ShapeMismatch` if `self` and `other` don't have the same shape /// /// [L∞ distance]: https://en.wikipedia.org/wiki/Chebyshev_distance - fn linf_dist(&self, other: &ArrayBase) -> Result + fn linf_dist(&self, other: &ArrayRef) -> Result where - A: Clone + PartialOrd + Signed, - T: Data; + A: Clone + PartialOrd + Signed; /// Computes the [mean absolute error] between `self` and `other`. /// @@ -139,10 +132,9 @@ where /// **Panics** if the type cast from `A` to `f64` fails. /// /// [mean absolute error]: https://en.wikipedia.org/wiki/Mean_absolute_error - fn mean_abs_err(&self, other: &ArrayBase) -> Result + fn mean_abs_err(&self, other: &ArrayRef) -> Result where - A: AddAssign + Clone + Signed + ToPrimitive, - T: Data; + A: AddAssign + Clone + Signed + ToPrimitive; /// Computes the [mean squared error] between `self` and `other`. /// @@ -162,10 +154,9 @@ where /// **Panics** if the type cast from `A` to `f64` fails. /// /// [mean squared error]: https://en.wikipedia.org/wiki/Mean_squared_error - fn mean_sq_err(&self, other: &ArrayBase) -> Result + fn mean_sq_err(&self, other: &ArrayRef) -> Result where - A: AddAssign + Clone + Signed + ToPrimitive, - T: Data; + A: AddAssign + Clone + Signed + ToPrimitive; /// Computes the unnormalized [root-mean-square error] between `self` and `other`. /// @@ -183,10 +174,9 @@ where /// **Panics** if the type cast from `A` to `f64` fails. /// /// [root-mean-square error]: https://en.wikipedia.org/wiki/Root-mean-square_deviation - fn root_mean_sq_err(&self, other: &ArrayBase) -> Result + fn root_mean_sq_err(&self, other: &ArrayRef) -> Result where - A: AddAssign + Clone + Signed + ToPrimitive, - T: Data; + A: AddAssign + Clone + Signed + ToPrimitive; /// Computes the [peak signal-to-noise ratio] between `self` and `other`. /// @@ -205,27 +195,24 @@ where /// **Panics** if the type cast from `A` to `f64` fails. /// /// [peak signal-to-noise ratio]: https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio - fn peak_signal_to_noise_ratio( + fn peak_signal_to_noise_ratio( &self, - other: &ArrayBase, + other: &ArrayRef, maxv: A, ) -> Result where - A: AddAssign + Clone + Signed + ToPrimitive, - T: Data; + A: AddAssign + Clone + Signed + ToPrimitive; private_decl! {} } -impl DeviationExt for ArrayBase +impl DeviationExt for ArrayRef where - S: Data, D: Dimension, { - fn count_eq(&self, other: &ArrayBase) -> Result + fn count_eq(&self, other: &ArrayRef) -> Result where A: PartialEq, - T: Data, { return_err_if_empty!(self); return_err_unless_same_shape!(self, other); @@ -241,18 +228,16 @@ where Ok(count) } - fn count_neq(&self, other: &ArrayBase) -> Result + fn count_neq(&self, other: &ArrayRef) -> Result where A: PartialEq, - T: Data, { self.count_eq(other).map(|n_eq| self.len() - n_eq) } - fn sq_l2_dist(&self, other: &ArrayBase) -> Result + fn sq_l2_dist(&self, other: &ArrayRef) -> Result where A: AddAssign + Clone + Signed, - T: Data, { return_err_if_empty!(self); return_err_unless_same_shape!(self, other); @@ -268,10 +253,9 @@ where Ok(result) } - fn l2_dist(&self, other: &ArrayBase) -> Result + fn l2_dist(&self, other: &ArrayRef) -> Result where A: AddAssign + Clone + Signed + ToPrimitive, - T: Data, { let sq_l2_dist = self .sq_l2_dist(other)? @@ -281,10 +265,9 @@ where Ok(sq_l2_dist.sqrt()) } - fn l1_dist(&self, other: &ArrayBase) -> Result + fn l1_dist(&self, other: &ArrayRef) -> Result where A: AddAssign + Clone + Signed, - T: Data, { return_err_if_empty!(self); return_err_unless_same_shape!(self, other); @@ -299,10 +282,9 @@ where Ok(result) } - fn linf_dist(&self, other: &ArrayBase) -> Result + fn linf_dist(&self, other: &ArrayRef) -> Result where A: Clone + PartialOrd + Signed, - T: Data, { return_err_if_empty!(self); return_err_unless_same_shape!(self, other); @@ -320,10 +302,9 @@ where Ok(max) } - fn mean_abs_err(&self, other: &ArrayBase) -> Result + fn mean_abs_err(&self, other: &ArrayRef) -> Result where A: AddAssign + Clone + Signed + ToPrimitive, - T: Data, { let l1_dist = self .l1_dist(other)? @@ -334,10 +315,9 @@ where Ok(l1_dist / n) } - fn mean_sq_err(&self, other: &ArrayBase) -> Result + fn mean_sq_err(&self, other: &ArrayRef) -> Result where A: AddAssign + Clone + Signed + ToPrimitive, - T: Data, { let sq_l2_dist = self .sq_l2_dist(other)? @@ -348,23 +328,21 @@ where Ok(sq_l2_dist / n) } - fn root_mean_sq_err(&self, other: &ArrayBase) -> Result + fn root_mean_sq_err(&self, other: &ArrayRef) -> Result where A: AddAssign + Clone + Signed + ToPrimitive, - T: Data, { let msd = self.mean_sq_err(other)?; Ok(msd.sqrt()) } - fn peak_signal_to_noise_ratio( + fn peak_signal_to_noise_ratio( &self, - other: &ArrayBase, + other: &ArrayRef, maxv: A, ) -> Result where A: AddAssign + Clone + Signed + ToPrimitive, - T: Data, { let maxv_f = maxv.to_f64().expect("failed cast from type A to f64"); let msd = self.mean_sq_err(&other)?; diff --git a/src/entropy.rs b/src/entropy.rs index e029729b..4ba9972c 100644 --- a/src/entropy.rs +++ b/src/entropy.rs @@ -1,14 +1,13 @@ //! Information theory (e.g. entropy, KL divergence, etc.). use crate::errors::{EmptyInput, MultiInputError, ShapeMismatch}; -use ndarray::{Array, ArrayBase, Data, Dimension, Zip}; +use ndarray::{Array, ArrayRef, Dimension, Zip}; use num_traits::Float; -/// Extension trait for `ArrayBase` providing methods +/// Extension trait for `ndarray` providing methods /// to compute information theory quantities /// (e.g. entropy, Kullback–Leibler divergence, etc.). -pub trait EntropyExt +pub trait EntropyExt where - S: Data, D: Dimension, { /// Computes the [entropy] *S* of the array values, defined as @@ -74,9 +73,8 @@ where /// /// [Kullback-Leibler divergence]: https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence /// [Information Theory]: https://en.wikipedia.org/wiki/Information_theory - fn kl_divergence(&self, q: &ArrayBase) -> Result + fn kl_divergence(&self, q: &ArrayRef) -> Result where - S2: Data, A: Float; /// Computes the [cross entropy] *H(p,q)* between two arrays, @@ -116,17 +114,15 @@ where /// [Information Theory]: https://en.wikipedia.org/wiki/Information_theory /// [optimization problems]: https://en.wikipedia.org/wiki/Cross-entropy_method /// [machine learning]: https://en.wikipedia.org/wiki/Cross_entropy#Cross-entropy_error_function_and_logistic_regression - fn cross_entropy(&self, q: &ArrayBase) -> Result + fn cross_entropy(&self, q: &ArrayRef) -> Result where - S2: Data, A: Float; private_decl! {} } -impl EntropyExt for ArrayBase +impl EntropyExt for ArrayRef where - S: Data, D: Dimension, { fn entropy(&self) -> Result @@ -149,10 +145,9 @@ where } } - fn kl_divergence(&self, q: &ArrayBase) -> Result + fn kl_divergence(&self, q: &ArrayRef) -> Result where A: Float, - S2: Data, { if self.is_empty() { return Err(MultiInputError::EmptyInput); @@ -182,9 +177,8 @@ where Ok(kl_divergence) } - fn cross_entropy(&self, q: &ArrayBase) -> Result + fn cross_entropy(&self, q: &ArrayRef) -> Result where - S2: Data, A: Float, { if self.is_empty() { diff --git a/src/histogram/grid.rs b/src/histogram/grid.rs index 57e85061..1423fcf4 100644 --- a/src/histogram/grid.rs +++ b/src/histogram/grid.rs @@ -2,7 +2,7 @@ use super::{bins::Bins, errors::BinsBuildError, strategies::BinsBuildingStrategy}; use itertools::izip; -use ndarray::{ArrayBase, Axis, Data, Ix1, Ix2}; +use ndarray::{ArrayRef, Axis, Ix1, Ix2}; use std::ops::Range; /// An orthogonal partition of a rectangular region in an *n*-dimensional space, e.g. @@ -200,10 +200,7 @@ impl Grid { /// Some(vec![1, 0, 1]), /// ); /// ``` - pub fn index_of(&self, point: &ArrayBase) -> Option> - where - S: Data, - { + pub fn index_of(&self, point: &ArrayRef) -> Option> { assert_eq!( point.len(), self.ndim(), @@ -337,10 +334,7 @@ where /// [`strategy`]: strategies/index.html /// [`BinsBuildError`]: errors/enum.BinsBuildError.html /// [Trait-level examples]: struct.GridBuilder.html#examples - pub fn from_array(array: &ArrayBase) -> Result - where - S: Data, - { + pub fn from_array(array: &ArrayRef) -> Result { let bin_builders = array .axis_iter(Axis(1)) .map(|data| B::from_array(&data)) diff --git a/src/histogram/histograms.rs b/src/histogram/histograms.rs index 603a5019..831b23dc 100644 --- a/src/histogram/histograms.rs +++ b/src/histogram/histograms.rs @@ -1,7 +1,6 @@ use super::errors::BinNotFound; use super::grid::Grid; use ndarray::prelude::*; -use ndarray::Data; /// Histogram data structure. pub struct Histogram { @@ -45,10 +44,7 @@ impl Histogram { /// assert_eq!(histogram_matrix, expected.into_dyn()); /// # Ok::<(), Box>(()) /// ``` - pub fn add_observation(&mut self, observation: &ArrayBase) -> Result<(), BinNotFound> - where - S: Data, - { + pub fn add_observation(&mut self, observation: &ArrayRef) -> Result<(), BinNotFound> { match self.grid.index_of(observation) { Some(bin_index) => { self.counts[&*bin_index] += 1; @@ -75,11 +71,8 @@ impl Histogram { } } -/// Extension trait for `ArrayBase` providing methods to compute histograms. -pub trait HistogramExt -where - S: Data, -{ +/// Extension trait for `ArrayRef` providing methods to compute histograms. +pub trait HistogramExt { /// Returns the [histogram](https://en.wikipedia.org/wiki/Histogram) /// for a 2-dimensional array of points `M`. /// @@ -141,9 +134,8 @@ where private_decl! {} } -impl HistogramExt for ArrayBase +impl HistogramExt for ArrayRef where - S: Data, A: Ord, { fn histogram(&self, grid: Grid) -> Histogram { diff --git a/src/histogram/strategies.rs b/src/histogram/strategies.rs index a1522109..44255981 100644 --- a/src/histogram/strategies.rs +++ b/src/histogram/strategies.rs @@ -50,7 +50,7 @@ use crate::{ histogram::{errors::BinsBuildError, Bins, Edges}, quantile::{interpolate::Nearest, Quantile1dExt, QuantileExt}, }; -use ndarray::{prelude::*, Data}; +use ndarray::prelude::*; use noisy_float::types::n64; use num_traits::{FromPrimitive, NumOps, Zero}; @@ -75,9 +75,8 @@ pub trait BinsBuildingStrategy { /// See each of the struct-level documentation for details on errors an implementor may return. /// /// [`Bins`]: ../struct.Bins.html - fn from_array(array: &ArrayBase) -> Result + fn from_array(array: &ArrayRef) -> Result where - S: Data, Self: std::marker::Sized; /// Returns a [`Bins`] instance, according to parameters inferred from observations. @@ -263,10 +262,7 @@ where /// Returns `Err(BinsBuildError::Strategy)` if the array is constant. /// Returns `Err(BinsBuildError::EmptyInput)` if `a.len()==0`. /// Returns `Ok(Self)` otherwise. - fn from_array(a: &ArrayBase) -> Result - where - S: Data, - { + fn from_array(a: &ArrayRef) -> Result { let n_elems = a.len(); // casting `n_elems: usize` to `f64` may casus off-by-one error here if `n_elems` > 2 ^ 53, // but it's not relevant here @@ -309,10 +305,7 @@ where /// Returns `Err(BinsBuildError::Strategy)` if the array is constant. /// Returns `Err(BinsBuildError::EmptyInput)` if `a.len()==0`. /// Returns `Ok(Self)` otherwise. - fn from_array(a: &ArrayBase) -> Result - where - S: Data, - { + fn from_array(a: &ArrayRef) -> Result { let n_elems = a.len(); // casting `n_elems: usize` to `f64` may casus off-by-one error here if `n_elems` > 2 ^ 53, // but it's not relevant here @@ -355,10 +348,7 @@ where /// Returns `Err(BinsBuildError::Strategy)` if the array is constant. /// Returns `Err(BinsBuildError::EmptyInput)` if `a.len()==0`. /// Returns `Ok(Self)` otherwise. - fn from_array(a: &ArrayBase) -> Result - where - S: Data, - { + fn from_array(a: &ArrayRef) -> Result { let n_elems = a.len(); // casting `n_elems: usize` to `f64` may casus off-by-one error here if `n_elems` > 2 ^ 53, // but it's not relevant here @@ -401,10 +391,7 @@ where /// Returns `Err(BinsBuildError::Strategy)` if `IQR==0`. /// Returns `Err(BinsBuildError::EmptyInput)` if `a.len()==0`. /// Returns `Ok(Self)` otherwise. - fn from_array(a: &ArrayBase) -> Result - where - S: Data, - { + fn from_array(a: &ArrayRef) -> Result { let n_points = a.len(); if n_points == 0 { return Err(BinsBuildError::EmptyInput); @@ -458,10 +445,7 @@ where /// Returns `Err(BinsBuildError::Strategy)` if `IQR==0`. /// Returns `Err(BinsBuildError::EmptyInput)` if `a.len()==0`. /// Returns `Ok(Self)` otherwise. - fn from_array(a: &ArrayBase) -> Result - where - S: Data, - { + fn from_array(a: &ArrayRef) -> Result { let fd_builder = FreedmanDiaconis::from_array(&a); let sturges_builder = Sturges::from_array(&a); match (fd_builder, sturges_builder) { diff --git a/src/lib.rs b/src/lib.rs index 4ae11004..ef4c5c4b 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,4 +1,4 @@ -//! The [`ndarray-stats`] crate exposes statistical routines for `ArrayBase`, +//! The [`ndarray-stats`] crate exposes statistical routines for `ArrayRef`, //! the *n*-dimensional array data structure provided by [`ndarray`]. //! //! Currently available routines include: diff --git a/src/maybe_nan/mod.rs b/src/maybe_nan/mod.rs index 02cce16d..b0da02c8 100644 --- a/src/maybe_nan/mod.rs +++ b/src/maybe_nan/mod.rs @@ -1,5 +1,5 @@ use ndarray::prelude::*; -use ndarray::{s, Data, DataMut, RemoveAxis}; +use ndarray::{s, RemoveAxis}; use noisy_float::types::{N32, N64}; use std::mem; @@ -260,11 +260,10 @@ impl NotNone { } } -/// Extension trait for `ArrayBase` providing NaN-related functionality. -pub trait MaybeNanExt +/// Extension trait for `ArrayRef` providing NaN-related functionality. +pub trait MaybeNanExt where A: MaybeNan, - S: Data, D: Dimension, { /// Traverse the non-NaN array elements and apply a fold, returning the @@ -321,17 +320,15 @@ where fn map_axis_skipnan_mut<'a, B, F>(&'a mut self, axis: Axis, mapping: F) -> Array where A: 'a, - S: DataMut, D: RemoveAxis, F: FnMut(ArrayViewMut1<'a, A::NotNan>) -> B; private_decl! {} } -impl MaybeNanExt for ArrayBase +impl MaybeNanExt for ArrayRef where A: MaybeNan, - S: Data, D: Dimension, { fn fold_skipnan<'a, F, B>(&'a self, init: B, mut f: F) -> B @@ -396,7 +393,6 @@ where ) -> Array where A: 'a, - S: DataMut, D: RemoveAxis, F: FnMut(ArrayViewMut1<'a, A::NotNan>) -> B, { diff --git a/src/quantile/mod.rs b/src/quantile/mod.rs index 3fea4a65..94b41602 100644 --- a/src/quantile/mod.rs +++ b/src/quantile/mod.rs @@ -4,14 +4,13 @@ use crate::errors::QuantileError; use crate::errors::{EmptyInput, MinMaxError, MinMaxError::UndefinedOrder}; use crate::{MaybeNan, MaybeNanExt}; use ndarray::prelude::*; -use ndarray::{Data, DataMut, RemoveAxis, Zip}; +use ndarray::{RemoveAxis, Zip}; use noisy_float::types::N64; use std::cmp; -/// Quantile methods for `ArrayBase`. -pub trait QuantileExt +/// Quantile methods for `ArrayRef`. +pub trait QuantileExt where - S: Data, D: Dimension, { /// Finds the index of the minimum value of the array. @@ -214,7 +213,6 @@ where where D: RemoveAxis, A: Ord + Clone, - S: DataMut, I: Interpolate; /// A bulk version of [`quantile_axis_mut`], optimized to retrieve multiple @@ -249,17 +247,15 @@ where /// assert_eq!(quantile, data.quantile_axis_mut(axis, q, &Nearest).unwrap()); /// } /// ``` - fn quantiles_axis_mut( + fn quantiles_axis_mut( &mut self, axis: Axis, - qs: &ArrayBase, + qs: &ArrayRef, interpolate: &I, ) -> Result, QuantileError> where D: RemoveAxis, A: Ord + Clone, - S: DataMut, - S2: Data, I: Interpolate; /// Return the `q`th quantile of the data along the specified axis, skipping NaN values. @@ -275,15 +271,13 @@ where D: RemoveAxis, A: MaybeNan, A::NotNan: Clone + Ord, - S: DataMut, I: Interpolate; private_decl! {} } -impl QuantileExt for ArrayBase +impl QuantileExt for ArrayRef where - S: Data, D: Dimension, { fn argmin(&self) -> Result @@ -420,17 +414,15 @@ where })) } - fn quantiles_axis_mut( + fn quantiles_axis_mut( &mut self, axis: Axis, - qs: &ArrayBase, + qs: &ArrayRef, interpolate: &I, ) -> Result, QuantileError> where D: RemoveAxis, A: Ord + Clone, - S: DataMut, - S2: Data, I: Interpolate, { // Minimize number of type parameters to avoid monomorphization bloat. @@ -509,7 +501,6 @@ where where D: RemoveAxis, A: Ord + Clone, - S: DataMut, I: Interpolate, { self.quantiles_axis_mut(axis, &aview1(&[q]), interpolate) @@ -526,7 +517,6 @@ where D: RemoveAxis, A: MaybeNan, A::NotNan: Clone + Ord, - S: DataMut, I: Interpolate, { if !((q >= 0.) && (q <= 1.)) { @@ -557,10 +547,7 @@ where } /// Quantile methods for 1-D arrays. -pub trait Quantile1dExt -where - S: Data, -{ +pub trait Quantile1dExt { /// Return the qth quantile of the data. /// /// `q` needs to be a float between 0 and 1, bounds included. @@ -593,7 +580,6 @@ where fn quantile_mut(&mut self, q: N64, interpolate: &I) -> Result where A: Ord + Clone, - S: DataMut, I: Interpolate; /// A bulk version of [`quantile_mut`], optimized to retrieve multiple @@ -611,28 +597,22 @@ where /// used to retrieve them. /// /// [`quantile_mut`]: #tymethod.quantile_mut - fn quantiles_mut( + fn quantiles_mut( &mut self, - qs: &ArrayBase, + qs: &ArrayRef, interpolate: &I, ) -> Result, QuantileError> where A: Ord + Clone, - S: DataMut, - S2: Data, I: Interpolate; private_decl! {} } -impl Quantile1dExt for ArrayBase -where - S: Data, -{ +impl Quantile1dExt for ArrayRef { fn quantile_mut(&mut self, q: N64, interpolate: &I) -> Result where A: Ord + Clone, - S: DataMut, I: Interpolate, { Ok(self @@ -640,15 +620,13 @@ where .into_scalar()) } - fn quantiles_mut( + fn quantiles_mut( &mut self, - qs: &ArrayBase, + qs: &ArrayRef, interpolate: &I, ) -> Result, QuantileError> where A: Ord + Clone, - S: DataMut, - S2: Data, I: Interpolate, { self.quantiles_axis_mut(Axis(0), qs, interpolate) diff --git a/src/sort.rs b/src/sort.rs index f43a95b1..e38e205b 100644 --- a/src/sort.rs +++ b/src/sort.rs @@ -1,14 +1,11 @@ use indexmap::IndexMap; use ndarray::prelude::*; -use ndarray::{Data, DataMut, Slice}; +use ndarray::Slice; use rand::prelude::*; use rand::thread_rng; /// Methods for sorting and partitioning 1-D arrays. -pub trait Sort1dExt -where - S: Data, -{ +pub trait Sort1dExt { /// Return the element that would occupy the `i`-th position if /// the array were sorted in increasing order. /// @@ -30,8 +27,7 @@ where /// **Panics** if `i` is greater than or equal to `n`. fn get_from_sorted_mut(&mut self, i: usize) -> A where - A: Ord + Clone, - S: DataMut; + A: Ord + Clone; /// A bulk version of [`get_from_sorted_mut`], optimized to retrieve multiple /// indexes at once. @@ -44,11 +40,9 @@ where /// where `n` is the length of the array.. /// /// [`get_from_sorted_mut`]: #tymethod.get_from_sorted_mut - fn get_many_from_sorted_mut(&mut self, indexes: &ArrayBase) -> IndexMap + fn get_many_from_sorted_mut(&mut self, indexes: &ArrayRef1) -> IndexMap where - A: Ord + Clone, - S: DataMut, - S2: Data; + A: Ord + Clone; /// Partitions the array in increasing order based on the value initially /// located at `pivot_index` and returns the new index of the value. @@ -96,20 +90,15 @@ where /// ``` fn partition_mut(&mut self, pivot_index: usize) -> usize where - A: Ord + Clone, - S: DataMut; + A: Ord + Clone; private_decl! {} } -impl Sort1dExt for ArrayBase -where - S: Data, -{ +impl Sort1dExt for ArrayRef { fn get_from_sorted_mut(&mut self, i: usize) -> A where A: Ord + Clone, - S: DataMut, { let n = self.len(); if n == 1 { @@ -130,11 +119,9 @@ where } } - fn get_many_from_sorted_mut(&mut self, indexes: &ArrayBase) -> IndexMap + fn get_many_from_sorted_mut(&mut self, indexes: &ArrayRef1) -> IndexMap where A: Ord + Clone, - S: DataMut, - S2: Data, { let mut deduped_indexes: Vec = indexes.to_vec(); deduped_indexes.sort_unstable(); @@ -146,7 +133,6 @@ where fn partition_mut(&mut self, pivot_index: usize) -> usize where A: Ord + Clone, - S: DataMut, { let pivot_value = self[pivot_index].clone(); self.swap(pivot_index, 0); @@ -195,13 +181,12 @@ where /// using the same indexes. /// /// [get_many_from_sorted_mut]: ../trait.Sort1dExt.html#tymethod.get_many_from_sorted_mut -pub(crate) fn get_many_from_sorted_mut_unchecked( - array: &mut ArrayBase, +pub(crate) fn get_many_from_sorted_mut_unchecked( + array: &mut ArrayRef1, indexes: &[usize], ) -> IndexMap where A: Ord + Clone, - S: DataMut, { if indexes.is_empty() { return IndexMap::new(); diff --git a/src/summary_statistics/means.rs b/src/summary_statistics/means.rs index d5226263..92cd1014 100644 --- a/src/summary_statistics/means.rs +++ b/src/summary_statistics/means.rs @@ -1,13 +1,12 @@ use super::SummaryStatisticsExt; use crate::errors::{EmptyInput, MultiInputError, ShapeMismatch}; -use ndarray::{Array, ArrayBase, Axis, Data, Dimension, Ix1, RemoveAxis}; +use ndarray::{Array, ArrayBase, ArrayRef, Axis, Data, Dimension, Ix1, RemoveAxis}; use num_integer::IterBinomial; use num_traits::{Float, FromPrimitive, Zero}; use std::ops::{Add, AddAssign, Div, Mul}; -impl SummaryStatisticsExt for ArrayBase +impl SummaryStatisticsExt for ArrayRef where - S: Data, D: Dimension, { fn mean(&self) -> Result @@ -33,7 +32,7 @@ where Ok(weighted_sum / weights.sum()) } - fn weighted_sum(&self, weights: &ArrayBase) -> Result + fn weighted_sum(&self, weights: &ArrayRef) -> Result where A: Copy + Mul + Zero, { @@ -47,7 +46,7 @@ where fn weighted_mean_axis( &self, axis: Axis, - weights: &ArrayBase, + weights: &ArrayRef, ) -> Result, MultiInputError> where A: Copy + Div + Mul + Zero, @@ -63,7 +62,7 @@ where fn weighted_sum_axis( &self, axis: Axis, - weights: &ArrayBase, + weights: &ArrayRef, ) -> Result, MultiInputError> where A: Copy + Mul + Zero, @@ -130,7 +129,7 @@ where fn weighted_var_axis( &self, axis: Axis, - weights: &ArrayBase, + weights: &ArrayRef, ddof: A, ) -> Result, MultiInputError> where @@ -161,7 +160,7 @@ where fn weighted_std_axis( &self, axis: Axis, - weights: &ArrayBase, + weights: &ArrayRef, ddof: A, ) -> Result, MultiInputError> where @@ -245,14 +244,13 @@ where } /// Private function for `weighted_var` without conditions and asserts. -fn inner_weighted_var( - arr: &ArrayBase, - weights: &ArrayBase, +fn inner_weighted_var( + arr: &ArrayRef, + weights: &ArrayRef, ddof: A, zero: A, ) -> Result where - S: Data, A: AddAssign + Float + FromPrimitive, D: Dimension, { diff --git a/src/summary_statistics/mod.rs b/src/summary_statistics/mod.rs index 1f8fe000..e31dde1d 100644 --- a/src/summary_statistics/mod.rs +++ b/src/summary_statistics/mod.rs @@ -1,14 +1,13 @@ //! Summary statistics (e.g. mean, variance, etc.). use crate::errors::{EmptyInput, MultiInputError}; -use ndarray::{Array, ArrayBase, Axis, Data, Dimension, Ix1, RemoveAxis}; +use ndarray::{Array, ArrayRef, Axis, Dimension, Ix1, RemoveAxis}; use num_traits::{Float, FromPrimitive, Zero}; use std::ops::{Add, AddAssign, Div, Mul}; -/// Extension trait for `ArrayBase` providing methods +/// Extension trait for `ArrayRef` providing methods /// to compute several summary statistics (e.g. mean, variance, etc.). -pub trait SummaryStatisticsExt +pub trait SummaryStatisticsExt where - S: Data, D: Dimension, { /// Returns the [`arithmetic mean`] x̅ of all elements in the array: @@ -93,7 +92,7 @@ where fn weighted_mean_axis( &self, axis: Axis, - weights: &ArrayBase, + weights: &ArrayRef, ) -> Result, MultiInputError> where A: Copy + Div + Mul + Zero, @@ -116,7 +115,7 @@ where fn weighted_sum_axis( &self, axis: Axis, - weights: &ArrayBase, + weights: &ArrayRef, ) -> Result, MultiInputError> where A: Copy + Mul + Zero, @@ -203,7 +202,7 @@ where fn weighted_var_axis( &self, axis: Axis, - weights: &ArrayBase, + weights: &ArrayRef, ddof: A, ) -> Result, MultiInputError> where @@ -225,7 +224,7 @@ where fn weighted_std_axis( &self, axis: Axis, - weights: &ArrayBase, + weights: &ArrayRef, ddof: A, ) -> Result, MultiInputError> where diff --git a/tests/summary_statistics.rs b/tests/summary_statistics.rs index 5269e332..1062a09a 100644 --- a/tests/summary_statistics.rs +++ b/tests/summary_statistics.rs @@ -297,7 +297,7 @@ fn weighted_var_algo_eq_simple_algo() { for axis in 0..3 { let axis = Axis(axis); - let weights = Array::random(a.len_of(axis), Uniform::new(0.0, 1.0)); + let weights = Array::random(a.len_of(axis), Uniform::new(0.0, 1.0).unwrap()); let mean = a .weighted_mean_axis(axis, &weights) .unwrap() @@ -327,7 +327,7 @@ fn test_central_moment_with_empty_array_of_floats() { fn test_zeroth_central_moment_is_one() { let n = 50; let bound: f64 = 200.; - let a = Array::random(n, Uniform::new(-bound.abs(), bound.abs())); + let a = Array::random(n, Uniform::new(-bound.abs(), bound.abs()).unwrap()); assert_eq!(a.central_moment(0).unwrap(), 1.); } @@ -335,7 +335,7 @@ fn test_zeroth_central_moment_is_one() { fn test_first_central_moment_is_zero() { let n = 50; let bound: f64 = 200.; - let a = Array::random(n, Uniform::new(-bound.abs(), bound.abs())); + let a = Array::random(n, Uniform::new(-bound.abs(), bound.abs()).unwrap()); assert_eq!(a.central_moment(1).unwrap(), 0.); } @@ -374,7 +374,7 @@ fn test_bulk_central_moments() { // Test that the bulk method is coherent with the non-bulk method let n = 50; let bound: f64 = 200.; - let a = Array::random(n, Uniform::new(-bound.abs(), bound.abs())); + let a = Array::random(n, Uniform::new(-bound.abs(), bound.abs()).unwrap()); let order = 10; let central_moments = a.central_moments(order).unwrap(); for i in 0..=order {