2929//! that the items are not compatible (e.g. that a type doesn't implement a
3030//! necessary trait).
3131
32- use crate :: rand:: distributions:: Distribution ;
32+ use crate :: rand:: distributions:: { Distribution , Uniform } ;
3333use crate :: rand:: rngs:: SmallRng ;
34+ use crate :: rand:: seq:: index;
3435use crate :: rand:: { thread_rng, Rng , SeedableRng } ;
3536
36- use ndarray:: ShapeBuilder ;
37+ use ndarray:: { Array , Axis , RemoveAxis , ShapeBuilder } ;
3738use ndarray:: { ArrayBase , DataOwned , Dimension } ;
39+ #[ cfg( feature = "quickcheck" ) ]
40+ use quickcheck:: { Arbitrary , Gen } ;
3841
3942/// [`rand`](https://docs.rs/rand/0.7), re-exported for convenience and version-compatibility.
4043pub mod rand {
@@ -59,9 +62,9 @@ pub mod rand_distr {
5962/// low-quality random numbers, and reproducibility is not guaranteed. See its
6063/// documentation for information. You can select a different RNG with
6164/// [`.random_using()`](#tymethod.random_using).
62- pub trait RandomExt < S , D >
65+ pub trait RandomExt < S , A , D >
6366where
64- S : DataOwned ,
67+ S : DataOwned < Elem = A > ,
6568 D : Dimension ,
6669{
6770 /// Create an array with shape `dim` with elements drawn from
@@ -116,21 +119,125 @@ where
116119 IdS : Distribution < S :: Elem > ,
117120 R : Rng + ?Sized ,
118121 Sh : ShapeBuilder < Dim = D > ;
122+
123+ /// Sample `n_samples` lanes slicing along `axis` using the default RNG.
124+ ///
125+ /// If `strategy==SamplingStrategy::WithoutReplacement`, each lane can only be sampled once.
126+ /// If `strategy==SamplingStrategy::WithReplacement`, each lane can be sampled multiple times.
127+ ///
128+ /// ***Panics*** when:
129+ /// - creation of the RNG fails;
130+ /// - `n_samples` is greater than the length of `axis` (if sampling without replacement);
131+ /// - length of `axis` is 0.
132+ ///
133+ /// ```
134+ /// use ndarray::{array, Axis};
135+ /// use ndarray_rand::{RandomExt, SamplingStrategy};
136+ ///
137+ /// # fn main() {
138+ /// let a = array![
139+ /// [1., 2., 3.],
140+ /// [4., 5., 6.],
141+ /// [7., 8., 9.],
142+ /// [10., 11., 12.],
143+ /// ];
144+ /// // Sample 2 rows, without replacement
145+ /// let sample_rows = a.sample_axis(Axis(0), 2, SamplingStrategy::WithoutReplacement);
146+ /// println!("{:?}", sample_rows);
147+ /// // Example Output: (1st and 3rd rows)
148+ /// // [
149+ /// // [1., 2., 3.],
150+ /// // [7., 8., 9.]
151+ /// // ]
152+ /// // Sample 2 columns, with replacement
153+ /// let sample_columns = a.sample_axis(Axis(1), 1, SamplingStrategy::WithReplacement);
154+ /// println!("{:?}", sample_columns);
155+ /// // Example Output: (2nd column, sampled twice)
156+ /// // [
157+ /// // [2., 2.],
158+ /// // [5., 5.],
159+ /// // [8., 8.],
160+ /// // [11., 11.]
161+ /// // ]
162+ /// # }
163+ /// ```
164+ fn sample_axis ( & self , axis : Axis , n_samples : usize , strategy : SamplingStrategy ) -> Array < A , D >
165+ where
166+ A : Copy ,
167+ D : RemoveAxis ;
168+
169+ /// Sample `n_samples` lanes slicing along `axis` using the specified RNG `rng`.
170+ ///
171+ /// If `strategy==SamplingStrategy::WithoutReplacement`, each lane can only be sampled once.
172+ /// If `strategy==SamplingStrategy::WithReplacement`, each lane can be sampled multiple times.
173+ ///
174+ /// ***Panics*** when:
175+ /// - creation of the RNG fails;
176+ /// - `n_samples` is greater than the length of `axis` (if sampling without replacement);
177+ /// - length of `axis` is 0.
178+ ///
179+ /// ```
180+ /// use ndarray::{array, Axis};
181+ /// use ndarray_rand::{RandomExt, SamplingStrategy};
182+ /// use ndarray_rand::rand::SeedableRng;
183+ /// use rand_isaac::isaac64::Isaac64Rng;
184+ ///
185+ /// # fn main() {
186+ /// // Get a seeded random number generator for reproducibility (Isaac64 algorithm)
187+ /// let seed = 42;
188+ /// let mut rng = Isaac64Rng::seed_from_u64(seed);
189+ ///
190+ /// let a = array![
191+ /// [1., 2., 3.],
192+ /// [4., 5., 6.],
193+ /// [7., 8., 9.],
194+ /// [10., 11., 12.],
195+ /// ];
196+ /// // Sample 2 rows, without replacement
197+ /// let sample_rows = a.sample_axis_using(Axis(0), 2, SamplingStrategy::WithoutReplacement, &mut rng);
198+ /// println!("{:?}", sample_rows);
199+ /// // Example Output: (1st and 3rd rows)
200+ /// // [
201+ /// // [1., 2., 3.],
202+ /// // [7., 8., 9.]
203+ /// // ]
204+ ///
205+ /// // Sample 2 columns, with replacement
206+ /// let sample_columns = a.sample_axis_using(Axis(1), 1, SamplingStrategy::WithReplacement, &mut rng);
207+ /// println!("{:?}", sample_columns);
208+ /// // Example Output: (2nd column, sampled twice)
209+ /// // [
210+ /// // [2., 2.],
211+ /// // [5., 5.],
212+ /// // [8., 8.],
213+ /// // [11., 11.]
214+ /// // ]
215+ /// # }
216+ /// ```
217+ fn sample_axis_using < R > (
218+ & self ,
219+ axis : Axis ,
220+ n_samples : usize ,
221+ strategy : SamplingStrategy ,
222+ rng : & mut R ,
223+ ) -> Array < A , D >
224+ where
225+ R : Rng + ?Sized ,
226+ A : Copy ,
227+ D : RemoveAxis ;
119228}
120229
121- impl < S , D > RandomExt < S , D > for ArrayBase < S , D >
230+ impl < S , A , D > RandomExt < S , A , D > for ArrayBase < S , D >
122231where
123- S : DataOwned ,
232+ S : DataOwned < Elem = A > ,
124233 D : Dimension ,
125234{
126235 fn random < Sh , IdS > ( shape : Sh , dist : IdS ) -> ArrayBase < S , D >
127236 where
128237 IdS : Distribution < S :: Elem > ,
129238 Sh : ShapeBuilder < Dim = D > ,
130239 {
131- let mut rng =
132- SmallRng :: from_rng ( thread_rng ( ) ) . expect ( "create SmallRng from thread_rng failed" ) ;
133- Self :: random_using ( shape, dist, & mut rng)
240+ Self :: random_using ( shape, dist, & mut get_rng ( ) )
134241 }
135242
136243 fn random_using < Sh , IdS , R > ( shape : Sh , dist : IdS , rng : & mut R ) -> ArrayBase < S , D >
@@ -141,6 +248,66 @@ where
141248 {
142249 Self :: from_shape_fn ( shape, |_| dist. sample ( rng) )
143250 }
251+
252+ fn sample_axis ( & self , axis : Axis , n_samples : usize , strategy : SamplingStrategy ) -> Array < A , D >
253+ where
254+ A : Copy ,
255+ D : RemoveAxis ,
256+ {
257+ self . sample_axis_using ( axis, n_samples, strategy, & mut get_rng ( ) )
258+ }
259+
260+ fn sample_axis_using < R > (
261+ & self ,
262+ axis : Axis ,
263+ n_samples : usize ,
264+ strategy : SamplingStrategy ,
265+ rng : & mut R ,
266+ ) -> Array < A , D >
267+ where
268+ R : Rng + ?Sized ,
269+ A : Copy ,
270+ D : RemoveAxis ,
271+ {
272+ let indices: Vec < _ > = match strategy {
273+ SamplingStrategy :: WithReplacement => {
274+ let distribution = Uniform :: from ( 0 ..self . len_of ( axis) ) ;
275+ ( 0 ..n_samples) . map ( |_| distribution. sample ( rng) ) . collect ( )
276+ }
277+ SamplingStrategy :: WithoutReplacement => {
278+ index:: sample ( rng, self . len_of ( axis) , n_samples) . into_vec ( )
279+ }
280+ } ;
281+ self . select ( axis, & indices)
282+ }
283+ }
284+
285+ /// Used as parameter in [`sample_axis`] and [`sample_axis_using`] to determine
286+ /// if lanes from the original array should only be sampled once (*without replacement*) or
287+ /// multiple times (*with replacement*).
288+ ///
289+ /// [`sample_axis`]: trait.RandomExt.html#tymethod.sample_axis
290+ /// [`sample_axis_using`]: trait.RandomExt.html#tymethod.sample_axis_using
291+ #[ derive( Debug , Clone ) ]
292+ pub enum SamplingStrategy {
293+ WithReplacement ,
294+ WithoutReplacement ,
295+ }
296+
297+ // `Arbitrary` enables `quickcheck` to generate random `SamplingStrategy` values for testing.
298+ #[ cfg( feature = "quickcheck" ) ]
299+ impl Arbitrary for SamplingStrategy {
300+ fn arbitrary < G : Gen > ( g : & mut G ) -> Self {
301+ if g. gen_bool ( 0.5 ) {
302+ SamplingStrategy :: WithReplacement
303+ } else {
304+ SamplingStrategy :: WithoutReplacement
305+ }
306+ }
307+ }
308+
309+ fn get_rng ( ) -> SmallRng {
310+ SmallRng :: from_rng ( thread_rng ( ) ) . expect ( "create SmallRng from thread_rng failed" )
144311}
145312
146313/// A wrapper type that allows casting f64 distributions to f32
0 commit comments