@@ -58,7 +58,8 @@ module Data.Vector.Generic.Mutable (
5858 ifoldr , ifoldr' , ifoldrM , ifoldrM' ,
5959
6060 -- * Modifying vectors
61- nextPermutation ,
61+ nextPermutation , nextPermutationBy ,
62+ prevPermutation , prevPermutationBy ,
6263
6364 -- ** Filling and copying
6465 set , copy , move , unsafeCopy , unsafeMove ,
@@ -91,9 +92,10 @@ import Data.Vector.Internal.Check
9192import Control.Monad.Primitive ( PrimMonad (.. ), RealWorld , stToPrim )
9293
9394import Prelude
94- ( Ord , Monad , Bool (.. ), Int , Maybe (.. ), Either (.. )
95+ ( Ord , Monad , Bool (.. ), Int , Maybe (.. ), Either (.. ), Ordering ( .. )
9596 , return , otherwise , flip , const , seq , min , max , not , pure
96- , (>>=) , (+) , (-) , (<) , (<=) , (>=) , (==) , (/=) , (.) , ($) , (=<<) , (>>) , (<$>) )
97+ , (>>=) , (+) , (-) , (<) , (<=) , (>) , (>=) , (==) , (/=) , (.) , ($) , (=<<) , (>>) , (<$>) )
98+ import Data.Bits ( Bits (shiftR ) )
9799
98100#include "vector.h"
99101
@@ -1213,6 +1215,47 @@ partitionWithUnknown f s
12131215-- Modifying vectors
12141216-- -----------------
12151217
1218+
1219+ -- | Compute the (lexicographically) next permutation of the given vector in-place.
1220+ -- Returns False when the input is the last item in the enumeration, i.e., if it is in
1221+ -- weakly descending order. In this case the vector will not get updated,
1222+ -- as opposed to the behavior of the C++ function @std::next_permutation@.
1223+ nextPermutation :: (PrimMonad m , Ord e , MVector v e ) => v (PrimState m ) e -> m Bool
1224+ {-# INLINE nextPermutation #-}
1225+ nextPermutation = nextPermutationByLt (<)
1226+
1227+ -- | Compute the (lexicographically) next permutation of the given vector in-place,
1228+ -- using the provided comparison function.
1229+ -- Returns False when the input is the last item in the enumeration, i.e., if it is in
1230+ -- weakly descending order. In this case the vector will not get updated,
1231+ -- as opposed to the behavior of the C++ function @std::next_permutation@.
1232+ --
1233+ -- @since 0.13.2.0
1234+ nextPermutationBy :: (PrimMonad m , MVector v e ) => (e -> e -> Ordering ) -> v (PrimState m ) e -> m Bool
1235+ {-# INLINE nextPermutationBy #-}
1236+ nextPermutationBy cmp = nextPermutationByLt (\ x y -> cmp x y == LT )
1237+
1238+ -- | Compute the (lexicographically) previous permutation of the given vector in-place.
1239+ -- Returns False when the input is the last item in the enumeration, i.e., if it is in
1240+ -- weakly ascending order. In this case the vector will not get updated,
1241+ -- as opposed to the behavior of the C++ function @std::prev_permutation@.
1242+ --
1243+ -- @since 0.13.2.0
1244+ prevPermutation :: (PrimMonad m , Ord e , MVector v e ) => v (PrimState m ) e -> m Bool
1245+ {-# INLINE prevPermutation #-}
1246+ prevPermutation = nextPermutationByLt (>)
1247+
1248+ -- | Compute the (lexicographically) previous permutation of the given vector in-place,
1249+ -- using the provided comparison function.
1250+ -- Returns False when the input is the last item in the enumeration, i.e., if it is in
1251+ -- weakly ascending order. In this case the vector will not get updated,
1252+ -- as opposed to the behavior of the C++ function @std::prev_permutation@.
1253+ --
1254+ -- @since 0.13.2.0
1255+ prevPermutationBy :: (PrimMonad m , MVector v e ) => (e -> e -> Ordering ) -> v (PrimState m ) e -> m Bool
1256+ {-# INLINE prevPermutationBy #-}
1257+ prevPermutationBy cmp = nextPermutationByLt (\ x y -> cmp x y == GT )
1258+
12161259{-
12171260http://en.wikipedia.org/wiki/Permutation#Algorithms_to_generate_permutations
12181261
@@ -1224,30 +1267,51 @@ a given permutation. It changes the given permutation in-place.
122412672. Find the largest index l greater than k such that a[k] < a[l].
122512683. Swap the value of a[k] with that of a[l].
122612694. Reverse the sequence from a[k + 1] up to and including the final element a[n]
1270+
1271+ The algorithm has been updated to look up the k in Step 1 beginning from the
1272+ last of the vector; which renders the algorithm to achieve the average time
1273+ complexity of O(1) each call. The worst case time complexity is still O(n).
1274+ The orginal implementation, which scanned the vector from the left, had the
1275+ time complexity of O(n) on the best case.
12271276-}
12281277
12291278-- | Compute the (lexicographically) next permutation of the given vector in-place.
1230- -- Returns False when the input is the last permutation.
1231- nextPermutation :: (PrimMonad m ,Ord e ,MVector v e ) => v (PrimState m ) e -> m Bool
1232- nextPermutation v
1233- | dim < 2 = return False
1234- | otherwise = do
1235- val <- unsafeRead v 0
1236- (k,l) <- loop val (- 1 ) 0 val 1
1237- if k < 0
1238- then return False
1239- else unsafeSwap v k l >>
1240- reverse (unsafeSlice (k+ 1 ) (dim- k- 1 ) v) >>
1241- return True
1242- where loop ! kval ! k ! l ! prev ! i
1243- | i == dim = return (k,l)
1244- | otherwise = do
1245- cur <- unsafeRead v i
1246- -- TODO: make tuple unboxed
1247- let (kval',k') = if prev < cur then (prev,i- 1 ) else (kval,k)
1248- l' = if kval' < cur then i else l
1249- loop kval' k' l' cur (i+ 1 )
1250- dim = length v
1279+ -- Here, the first argument should be a less-than comparison function.
1280+ -- Returns False when the input is the last permutation; in this case the vector
1281+ -- will not get updated, as opposed to the behavior of the C++ function
1282+ -- @std::next_permutation@.
1283+ nextPermutationByLt :: (PrimMonad m , MVector v e ) => (e -> e -> Bool ) -> v (PrimState m ) e -> m Bool
1284+ {-# INLINE nextPermutationByLt #-}
1285+ nextPermutationByLt lt v
1286+ | dim < 2 = return False
1287+ | otherwise = stToPrim $ do
1288+ ! vlast <- unsafeRead v (dim - 1 )
1289+ decrLoop (dim - 2 ) vlast
1290+ where
1291+ dim = length v
1292+ -- find the largest index k such that a[k] < a[k + 1], and then pass to the rest.
1293+ decrLoop ! i ! vi1 | i >= 0 = do
1294+ ! vi <- unsafeRead v i
1295+ if vi `lt` vi1 then swapLoop i vi (i+ 1 ) vi1 dim else decrLoop (i- 1 ) vi
1296+ decrLoop _ ! _ = return False
1297+ -- find the largest index l greater than k such that a[k] < a[l], and do the rest.
1298+ swapLoop ! k ! vk = go
1299+ where
1300+ -- binary search.
1301+ go ! l ! vl ! r | r - l <= 1 = do
1302+ -- Done; do the rest of the algorithm.
1303+ unsafeWrite v k vl
1304+ unsafeWrite v l vk
1305+ reverse $ unsafeSlice (k + 1 ) (dim - k - 1 ) v
1306+ return True
1307+ go ! l ! vl ! r = do
1308+ ! vmid <- unsafeRead v mid
1309+ if vk `lt` vmid
1310+ then go mid vmid r
1311+ else go l vl mid
1312+ where
1313+ ! mid = l + (r - l) `shiftR` 1
1314+
12511315
12521316-- $setup
12531317-- >>> import Prelude ((*))
0 commit comments