11use indexmap:: IndexMap ;
22use ndarray:: prelude:: * ;
33use ndarray:: { Data , DataMut , Slice } ;
4- use rand:: prelude:: * ;
5- use rand:: thread_rng;
64
75/// Methods for sorting and partitioning 1-D arrays.
86pub trait Sort1dExt < A , S >
@@ -50,26 +48,21 @@ where
5048 S : DataMut ,
5149 S2 : Data < Elem = usize > ;
5250
53- /// Partitions the array in increasing order based on the value initially
54- /// located at `pivot_index` and returns the new index of the value.
51+ /// Partitions the array in increasing order based on the values initially located at `0` and
52+ /// `n` where `n` is the number of elements in the array and returns the new indexes of the
53+ /// values.
5554 ///
56- /// The elements are rearranged in such a way that the value initially
57- /// located at `pivot_index` is moved to the position it would be in an
58- /// array sorted in increasing order. The return value is the new index of
59- /// the value after rearrangement. All elements smaller than the value are
60- /// moved to its left and all elements equal or greater than the value are
61- /// moved to its right. The ordering of the elements in the two partitions
62- /// is undefined.
55+ /// The elements are rearranged in such a way that the values initially located at `0` and `n`
56+ /// are moved to the position it would be in an array sorted in increasing order. The return
57+ /// values are the new indexes of the values after rearrangement. All elements less than the
58+ /// values are moved to their left and all elements equal or greater than the values are moved
59+ /// to their right. The ordering of the elements in the three partitions is undefined.
6360 ///
64- /// `self` is shuffled **in place** to operate the desired partition:
65- /// no copy of the array is allocated.
61+ /// `self` is shuffled **in place**, no copy of the array is allocated.
6662 ///
67- /// The method uses Hoare's partition algorithm.
68- /// Complexity: O(`n`), where `n` is the number of elements in the array.
69- /// Average number of element swaps: n/6 - 1/3 (see
70- /// [link](https://cs.stackexchange.com/questions/11458/quicksort-partitioning-hoare-vs-lomuto/11550))
63+ /// This method implements the partitioning scheme of [Yaroslavskiy-Bentley-Bloch Quicksort].
7164 ///
72- /// **Panics** if `pivot_index` is greater than or equal to `n`.
65+ /// [Yaroslavskiy-Bentley-Bloch Quicksort]: https://www.wild-inter.net/publications/wild-2016
7366 ///
7467 /// # Example
7568 ///
@@ -78,23 +71,30 @@ where
7871 /// use ndarray_stats::Sort1dExt;
7972 ///
8073 /// let mut data = array![3, 1, 4, 5, 2];
81- /// let pivot_index = 2;
82- /// let pivot_value = data[pivot_index] ;
74+ /// // Sorted pivot values.
75+ /// let (lower_value, upper_value) = ( data[data.len() - 1], data[0]) ;
8376 ///
84- /// // Partition by the value located at `pivot_index`.
85- /// let new_index = data.partition_mut(pivot_index);
86- /// // The pivot value is now located at `new_index`.
87- /// assert_eq!(data[new_index], pivot_value);
88- /// // Elements less than that value are moved to the left.
89- /// for i in 0..new_index {
90- /// assert!(data[i] < pivot_value);
77+ /// // Partitions by the values located at `0` and `data.len() - 1`.
78+ /// let (lower_index, upper_index) = data.partition_mut();
79+ /// // The pivot values are now located at `lower_index` and `upper_index`.
80+ /// assert_eq!(data[lower_index], lower_value);
81+ /// assert_eq!(data[upper_index], upper_value);
82+ /// // Elements lower than the lower pivot value are moved to its left.
83+ /// for i in 0..lower_index {
84+ /// assert!(data[i] < lower_value);
85+ /// }
86+ /// // Elements greater than or equal the lower pivot value and less than or equal the upper
87+ /// // pivot value are moved between the two pivot indexes.
88+ /// for i in lower_index + 1..upper_index {
89+ /// assert!(lower_value <= data[i]);
90+ /// assert!(data[i] <= upper_value);
9191 /// }
92- /// // Elements greater than or equal to that value are moved to the right.
93- /// for i in (new_index + 1) ..data.len() {
94- /// assert!(data[i] >= pivot_value );
92+ /// // Elements greater than or equal the upper pivot value are moved to its right.
93+ /// for i in upper_index + 1..data.len() {
94+ /// assert!(upper_value <= data[i]);
9595 /// }
9696 /// ```
97- fn partition_mut ( & mut self , pivot_index : usize ) -> usize
97+ fn partition_mut ( & mut self ) -> ( usize , usize )
9898 where
9999 A : Ord + Clone ,
100100 S : DataMut ;
@@ -115,17 +115,20 @@ where
115115 if n == 1 {
116116 self [ 0 ] . clone ( )
117117 } else {
118- let mut rng = thread_rng ( ) ;
119- let pivot_index = rng. gen_range ( 0 ..n) ;
120- let partition_index = self . partition_mut ( pivot_index) ;
121- if i < partition_index {
122- self . slice_axis_mut ( Axis ( 0 ) , Slice :: from ( ..partition_index) )
118+ let ( lower_index, upper_index) = self . partition_mut ( ) ;
119+ if i < lower_index {
120+ self . slice_axis_mut ( Axis ( 0 ) , Slice :: from ( ..lower_index) )
123121 . get_from_sorted_mut ( i)
124- } else if i == partition_index {
122+ } else if i == lower_index {
123+ self [ i] . clone ( )
124+ } else if i < upper_index {
125+ self . slice_axis_mut ( Axis ( 0 ) , Slice :: from ( lower_index + 1 ..upper_index) )
126+ . get_from_sorted_mut ( i - ( lower_index + 1 ) )
127+ } else if i == upper_index {
125128 self [ i] . clone ( )
126129 } else {
127- self . slice_axis_mut ( Axis ( 0 ) , Slice :: from ( partition_index + 1 ..) )
128- . get_from_sorted_mut ( i - ( partition_index + 1 ) )
130+ self . slice_axis_mut ( Axis ( 0 ) , Slice :: from ( upper_index + 1 ..) )
131+ . get_from_sorted_mut ( i - ( upper_index + 1 ) )
129132 }
130133 }
131134 }
@@ -143,42 +146,51 @@ where
143146 get_many_from_sorted_mut_unchecked ( self , & deduped_indexes)
144147 }
145148
146- fn partition_mut ( & mut self , pivot_index : usize ) -> usize
149+ fn partition_mut ( & mut self ) -> ( usize , usize )
147150 where
148151 A : Ord + Clone ,
149152 S : DataMut ,
150153 {
151- let pivot_value = self [ pivot_index] . clone ( ) ;
152- self . swap ( pivot_index, 0 ) ;
153- let n = self . len ( ) ;
154- let mut i = 1 ;
155- let mut j = n - 1 ;
156- loop {
157- loop {
158- if i > j {
159- break ;
160- }
161- if self [ i] >= pivot_value {
162- break ;
154+ // Sort `lowermost` and `uppermost` elements and use them as dual pivot.
155+ let lowermost = 0 ;
156+ let uppermost = self . len ( ) - 1 ;
157+ if self [ lowermost] > self [ uppermost] {
158+ self . swap ( lowermost, uppermost) ;
159+ }
160+ // Increasing running and partition index starting after lower pivot.
161+ let mut index = lowermost + 1 ;
162+ let mut lower = lowermost + 1 ;
163+ // Decreasing partition index starting before upper pivot.
164+ let mut upper = uppermost - 1 ;
165+ // Swap elements at `index` into their partitions.
166+ while index <= upper {
167+ if self [ index] < self [ lowermost] {
168+ // Swap elements into lower partition.
169+ self . swap ( index, lower) ;
170+ lower += 1 ;
171+ } else if self [ index] >= self [ uppermost] {
172+ // Search first element of upper partition.
173+ while self [ upper] > self [ uppermost] && index < upper {
174+ upper -= 1 ;
163175 }
164- i += 1 ;
165- }
166- while pivot_value <= self [ j] {
167- if j == 1 {
168- break ;
176+ // Swap elements into upper partition.
177+ self . swap ( index, upper) ;
178+ if self [ index] < self [ lowermost] {
179+ // Swap swapped elements into lower partition.
180+ self . swap ( index, lower) ;
181+ lower += 1 ;
169182 }
170- j -= 1 ;
171- }
172- if i >= j {
173- break ;
174- } else {
175- self . swap ( i, j) ;
176- i += 1 ;
177- j -= 1 ;
183+ upper -= 1 ;
178184 }
185+ index += 1 ;
179186 }
180- self . swap ( 0 , i - 1 ) ;
181- i - 1
187+ lower -= 1 ;
188+ upper += 1 ;
189+ // Swap pivots to their new indexes.
190+ self . swap ( lowermost, lower) ;
191+ self . swap ( uppermost, upper) ;
192+ // Lower and upper pivot index.
193+ ( lower, upper)
182194 }
183195
184196 private_impl ! { }
@@ -249,50 +261,72 @@ fn _get_many_from_sorted_mut_unchecked<A>(
249261 return ;
250262 }
251263
252- // We pick a random pivot index: the corresponding element is the pivot value
253- let mut rng = thread_rng ( ) ;
254- let pivot_index = rng. gen_range ( 0 ..n) ;
264+ // We partition the array with respect to the two pivot values. The pivot values move to
265+ // `lower_index` and `upper_index`.
266+ //
267+ // Elements strictly less than the lower pivot value have indexes < `lower_index`.
268+ //
269+ // Elements greater than or equal the lower pivot value and less than or equal the upper pivot
270+ // value have indexes > `lower_index` and < `upper_index`.
271+ //
272+ // Elements less than or equal the upper pivot value have indexes > `upper_index`.
273+ let ( lower_index, upper_index) = array. partition_mut ( ) ;
255274
256- // We partition the array with respect to the pivot value.
257- // The pivot value moves to `array_partition_index`.
258- // Elements strictly smaller than the pivot value have indexes < `array_partition_index`.
259- // Elements greater or equal to the pivot value have indexes > `array_partition_index`.
260- let array_partition_index = array. partition_mut ( pivot_index) ;
275+ // We use a divide-and-conquer strategy, splitting the indexes we are searching for (`indexes`)
276+ // and the corresponding portions of the output slice (`values`) into partitions with respect to
277+ // `lower_index` and `upper_index`.
278+ let ( found_exact, split_index) = match indexes. binary_search ( & lower_index) {
279+ Ok ( index) => ( true , index) ,
280+ Err ( index) => ( false , index) ,
281+ } ;
282+ let ( lower_indexes, inner_indexes) = indexes. split_at_mut ( split_index) ;
283+ let ( lower_values, inner_values) = values. split_at_mut ( split_index) ;
284+ let ( upper_indexes, upper_values) = if found_exact {
285+ inner_values[ 0 ] = array[ lower_index] . clone ( ) ; // Write exactly found value.
286+ ( & mut inner_indexes[ 1 ..] , & mut inner_values[ 1 ..] )
287+ } else {
288+ ( inner_indexes, inner_values)
289+ } ;
261290
262- // We use a divide-and-conquer strategy, splitting the indexes we are
263- // searching for (`indexes`) and the corresponding portions of the output
264- // slice (`values`) into pieces with respect to `array_partition_index`.
265- let ( found_exact, index_split) = match indexes. binary_search ( & array_partition_index) {
291+ let ( found_exact, split_index) = match upper_indexes. binary_search ( & upper_index) {
266292 Ok ( index) => ( true , index) ,
267293 Err ( index) => ( false , index) ,
268294 } ;
269- let ( smaller_indexes , other_indexes ) = indexes . split_at_mut ( index_split ) ;
270- let ( smaller_values , other_values ) = values . split_at_mut ( index_split ) ;
271- let ( bigger_indexes , bigger_values ) = if found_exact {
272- other_values [ 0 ] = array[ array_partition_index ] . clone ( ) ; // Write exactly found value.
273- ( & mut other_indexes [ 1 ..] , & mut other_values [ 1 ..] )
295+ let ( inner_indexes , upper_indexes ) = upper_indexes . split_at_mut ( split_index ) ;
296+ let ( inner_values , upper_values ) = upper_values . split_at_mut ( split_index ) ;
297+ let ( upper_indexes , upper_values ) = if found_exact {
298+ upper_values [ 0 ] = array[ upper_index ] . clone ( ) ; // Write exactly found value.
299+ ( & mut upper_indexes [ 1 ..] , & mut upper_values [ 1 ..] )
274300 } else {
275- ( other_indexes , other_values )
301+ ( upper_indexes , upper_values )
276302 } ;
277303
278- // We search recursively for the values corresponding to strictly smaller
279- // indexes to the left of `partition_index`.
304+ // We search recursively for the values corresponding to indexes strictly less than
305+ // `lower_index` in the lower partition.
306+ _get_many_from_sorted_mut_unchecked (
307+ array. slice_axis_mut ( Axis ( 0 ) , Slice :: from ( ..lower_index) ) ,
308+ lower_indexes,
309+ lower_values,
310+ ) ;
311+
312+ // We search recursively for the values corresponding to indexes greater than or equal
313+ // `lower_index` in the inner partition, that is between the lower and upper partition. Since
314+ // only the inner partition of the array is passed in, the indexes need to be shifted by length
315+ // of the lower partition.
316+ inner_indexes. iter_mut ( ) . for_each ( |x| * x -= lower_index + 1 ) ;
280317 _get_many_from_sorted_mut_unchecked (
281- array. slice_axis_mut ( Axis ( 0 ) , Slice :: from ( ..array_partition_index ) ) ,
282- smaller_indexes ,
283- smaller_values ,
318+ array. slice_axis_mut ( Axis ( 0 ) , Slice :: from ( lower_index + 1 ..upper_index ) ) ,
319+ inner_indexes ,
320+ inner_values ,
284321 ) ;
285322
286- // We search recursively for the values corresponding to strictly bigger
287- // indexes to the right of `partition_index`. Since only the right portion
288- // of the array is passed in, the indexes need to be shifted by length of
289- // the removed portion.
290- bigger_indexes
291- . iter_mut ( )
292- . for_each ( |x| * x -= array_partition_index + 1 ) ;
323+ // We search recursively for the values corresponding to indexes greater than or equal
324+ // `upper_index` in the upper partition. Since only the upper partition of the array is passed
325+ // in, the indexes need to be shifted by the combined length of the lower and inner partition.
326+ upper_indexes. iter_mut ( ) . for_each ( |x| * x -= upper_index + 1 ) ;
293327 _get_many_from_sorted_mut_unchecked (
294- array. slice_axis_mut ( Axis ( 0 ) , Slice :: from ( array_partition_index + 1 ..) ) ,
295- bigger_indexes ,
296- bigger_values ,
328+ array. slice_axis_mut ( Axis ( 0 ) , Slice :: from ( upper_index + 1 ..) ) ,
329+ upper_indexes ,
330+ upper_values ,
297331 ) ;
298332}
0 commit comments