Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 70 additions & 8 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 4 additions & 4 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,13 @@ repository = "https://github.com/rust-ndarray/ndarray-stats"
documentation = "https://docs.rs/ndarray-stats/"
readme = "README.md"

description = "Statistical routines for ArrayBase, the n-dimensional array data structure provided by ndarray."
description = "Statistical routines for the n-dimensional array data structures provided by ndarray."

keywords = ["array", "multidimensional", "statistics", "matrix", "ndarray"]
categories = ["data-structures", "science"]

[dependencies]
ndarray = "0.16.0"
ndarray = "0.17.1"
noisy_float = "0.2.0"
num-integer = "0.1"
num-traits = "0.2"
Expand All @@ -29,10 +29,10 @@ itertools = { version = "0.13", default-features = false }
indexmap = "2.4"

[dev-dependencies]
ndarray = { version = "0.16.1", features = ["approx"] }
ndarray = { version = "0.17.1", features = ["approx"] }
criterion = "0.5.1"
quickcheck = { version = "0.9.2", default-features = false }
ndarray-rand = "0.15.0"
ndarray-rand = "0.16.0"
approx = "0.5"
quickcheck_macros = "1.0.0"
num-bigint = "0.4.0"
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
[![Crate](https://img.shields.io/crates/v/ndarray-stats.svg)](https://crates.io/crates/ndarray-stats)
[![Documentation](https://docs.rs/ndarray-stats/badge.svg)](https://docs.rs/ndarray-stats)

This crate provides statistical methods for [`ndarray`]'s `ArrayBase` type.
This crate provides statistical methods for [`ndarray`]'s `ArrayRef` type.

Currently available routines include:
- order statistics (minimum, maximum, median, quantiles, etc.);
Expand Down
4 changes: 2 additions & 2 deletions benches/deviation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@ fn sq_l2_dist(c: &mut Criterion) {
group.plot_config(PlotConfiguration::default().summary_scale(AxisScale::Logarithmic));
for len in &lens {
group.bench_with_input(format!("{}", len), len, |b, &len| {
let data = Array::random(len, Uniform::new(0.0, 1.0));
let data2 = Array::random(len, Uniform::new(0.0, 1.0));
let data = Array::random(len, Uniform::new(0.0, 1.0).unwrap());
let data2 = Array::random(len, Uniform::new(0.0, 1.0).unwrap());

b.iter(|| black_box(data.sq_l2_dist(&data2).unwrap()))
});
Expand Down
4 changes: 2 additions & 2 deletions benches/summary_statistics.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@ fn weighted_std(c: &mut Criterion) {
group.plot_config(PlotConfiguration::default().summary_scale(AxisScale::Logarithmic));
for len in &lens {
group.bench_with_input(format!("{}", len), len, |b, &len| {
let data = Array::random(len, Uniform::new(0.0, 1.0));
let mut weights = Array::random(len, Uniform::new(0.0, 1.0));
let data = Array::random(len, Uniform::new(0.0, 1.0).unwrap());
let mut weights = Array::random(len, Uniform::new(0.0, 1.0).unwrap());
weights /= weights.sum();
b.iter_batched(
|| data.clone(),
Expand Down
24 changes: 10 additions & 14 deletions src/correlation.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,10 @@
use crate::errors::EmptyInput;
use ndarray::prelude::*;
use ndarray::Data;
use num_traits::{Float, FromPrimitive};

/// Extension trait for `ArrayBase` providing functions
/// Extension trait for `ndarray` providing functions
/// to compute different correlation measures.
pub trait CorrelationExt<A, S>
where
S: Data<Elem = A>,
{
pub trait CorrelationExt<A> {
/// Return the covariance matrix `C` for a 2-dimensional
/// array of observations `M`.
///
Expand Down Expand Up @@ -125,10 +121,7 @@ where
private_decl! {}
}

impl<A: 'static, S> CorrelationExt<A, S> for ArrayBase<S, Ix2>
where
S: Data<Elem = A>,
{
impl<A: 'static> CorrelationExt<A> for ArrayRef2<A> {
fn cov(&self, ddof: A) -> Result<Array2<A>, EmptyInput>
where
A: Float + FromPrimitive,
Expand All @@ -147,7 +140,7 @@ where
let mean = self.mean_axis(observation_axis);
match mean {
Some(mean) => {
let denoised = self - &mean.insert_axis(observation_axis);
let denoised = self - mean.insert_axis(observation_axis);
let covariance = denoised.dot(&denoised.t());
Ok(covariance.mapv_into(|x| x / dof))
}
Expand Down Expand Up @@ -208,7 +201,7 @@ mod cov_tests {
let n_observations = 4;
let a = Array::random(
(n_random_variables, n_observations),
Uniform::new(-bound.abs(), bound.abs()),
Uniform::new(-bound.abs(), bound.abs()).unwrap(),
);
let covariance = a.cov(1.).unwrap();
abs_diff_eq!(covariance, &covariance.t(), epsilon = 1e-8)
Expand All @@ -219,7 +212,10 @@ mod cov_tests {
fn test_invalid_ddof() {
let n_random_variables = 3;
let n_observations = 4;
let a = Array::random((n_random_variables, n_observations), Uniform::new(0., 10.));
let a = Array::random(
(n_random_variables, n_observations),
Uniform::new(0., 10.).unwrap(),
);
let invalid_ddof = (n_observations as f64) + rand::random::<f64>().abs();
let _ = a.cov(invalid_ddof);
}
Expand Down Expand Up @@ -299,7 +295,7 @@ mod pearson_correlation_tests {
let n_observations = 4;
let a = Array::random(
(n_random_variables, n_observations),
Uniform::new(-bound.abs(), bound.abs()),
Uniform::new(-bound.abs(), bound.abs()).unwrap(),
);
let pearson_correlation = a.pearson_correlation().unwrap();
abs_diff_eq!(
Expand Down
Loading
Loading