@@ -734,14 +734,23 @@ pub const fn swap<T>(x: &mut T, y: &mut T) {
     // a backend can choose to implement using the block optimization, or not.
     #[cfg(not(any(target_arch = "spirv")))]
     {
+        // Types with an alignment bigger than usize are almost always used for
+        // hand-tuned SIMD optimizations, so we don't get in their way.
+        //
         // For types that are larger multiples of their alignment, the simple way
         // tends to copy the whole thing to stack rather than doing it one part
-        // at a time, so instead treat them as one-element slices and piggy-back
-        // the slice optimizations that will split up the swaps.
-        if size_of::<T>() / align_of::<T>() > 4 {
-            // SAFETY: exclusive references always point to one non-overlapping
-            // element and are non-null and properly aligned.
-            return unsafe { ptr::swap_nonoverlapping(x, y, 1) };
+        // at a time, so instead try to split them into chunks that fit into registers
+        // and swap the chunks one by one.
+        if const {
+            let size = size_of::<T>();
+            let align = align_of::<T>();
+            // Oddly, LLVM sometimes optimizes `4*align` fine while failing at `3*align`.
+            // Proof: https://godbolt.org/z/MhnqvjjPz
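+            // For example (illustrative): on x86_64 this sends `[u16; 3]` (size == 3*align)
+            // and `[u64; 3]` (usize alignment, size > 2*align) down the chunked path,
+            // while `[u64; 2]` keeps using the simple swap.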
+            (align < align_of::<usize>() && (size == 3 * align || size > 4 * align))
+                // And with usize alignment it fails even for `4*align`.
+                || (align == align_of::<usize>() && size > 2 * align)
+        } {
+            return swap_chunked(x, y);
         }
     }
 
@@ -784,6 +793,236 @@ pub(crate) const fn swap_simple<T>(x: &mut T, y: &mut T) {
     }
 }
 
+// This version swaps two values in chunks of usize and smaller,
+// using unaligned reads and writes because they have been cheap
+// on modern x86_64 processors for at least 10 years now (as of 2022-07-04).
+// https://lemire.me/blog/2012/05/31/data-alignment-for-speed-myth-or-reality/
+// It also generates fewer instructions and memory accesses: https://godbolt.org/z/Mr4rWfoad
+// Feel free to add other targets as well if they have fast unaligned accesses.
+//
+// This should be done by the backend optimizer, but it currently fails to do so.
+#[cfg(target_arch = "x86_64")]
+#[rustc_const_unstable(feature = "const_swap", issue = "83163")]
+#[inline]
+const fn swap_chunked<T: Sized>(x: &mut T, y: &mut T) {
+    // Algorithm:
+    // 1. Swap the first `n * ZMM_BYTES` bytes using `u64` chunks, which LLVM reliably autovectorizes.
+    // 2. Force the backend to use SIMD registers (YMM or XMM) via `force_swap_simd`.
+    // 3. Swap the remaining bytes using integer chunks, each half the size of the previous one.
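+    //
+    // Illustrative example (not exhaustive): for a 100-byte type, step 1 swaps the first
+    // 64 bytes as `u64` chunks, step 2 swaps the next 32 bytes via `force_swap_simd::<YMM_BYTES>`,
+    // and step 3 finishes the remaining 4 bytes with a single `u32`.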
+
+    // Note: The current version of this function is optimized for x86_64.
+    // If you enable it for another architecture, check the generated code first.
+    const XMM_BYTES: usize = 128 / 8;
+    const YMM_BYTES: usize = 256 / 8;
+    const ZMM_BYTES: usize = 512 / 8;
+
+    /// This function autovectorizes well when `size_bytes` is divisible by the biggest
+    /// available SIMD register size. If it is not, it generates SIMD operations for the
+    /// divisible prefix and plain integer reads and writes for the rest.
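+    /// For example (illustrative): with `size_bytes == 24` and only SSE available, the first
+    /// 16 bytes would typically become one XMM load/store pair and the last 8 bytes a plain
+    /// 64-bit move, though the exact codegen depends on the backend.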
+    /// # Safety
+    /// 1. Must not overlap. 2. Must have at least `size_bytes` valid bytes.
+    /// 3. `size_bytes` must be exactly divisible by `size_of::<u64>()`.
+    #[inline]
+    const unsafe fn swap_simple_u64_chunks(
+        size_bytes: usize,
+        x: *mut MaybeUninit<u8>,
+        y: *mut MaybeUninit<u8>,
+    ) {
+        let mut byte_offset = 0;
+        while byte_offset < size_bytes {
+            // SAFETY:
+            // The caller must ensure pointer validity and range.
+            // We use unaligned reads/writes.
+            unsafe {
+                let x = x.add(byte_offset).cast();
+                let y = y.add(byte_offset).cast();
+                // Same rationale for 2 reads and 2 writes
+                // as in `swap_simple`.
+                let tmp_x: MaybeUninit<u64> = ptr::read_unaligned(x);
+                let tmp_y: MaybeUninit<u64> = ptr::read_unaligned(y);
+                ptr::write_unaligned(x, tmp_y);
+                ptr::write_unaligned(y, tmp_x);
+            }
+            byte_offset += const { size_of::<u64>() };
+        }
+    }
+
+    /// This function swaps the next `SIZE_BYTES` bytes using SIMD registers.
+    /// It has a drawback: it always emits a swap of exactly `SIZE_BYTES`.
+    /// E.g. if a type is ZMM_BYTES large but the swap is implemented using only
+    /// `force_swap_simd::<XMM_BYTES>`, it would use only XMM registers.
+    /// So this function should be called only once per swap, because if it were
+    /// called twice, it would be better to use a bigger register instead.
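+    /// For example (illustrative): swapping a 32-byte tail with two `XMM_BYTES` calls stays
+    /// in XMM registers, while a single `YMM_BYTES` call lets AVX builds fold the temporaries
+    /// into one YMM register.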
+    ///
+    /// It is OK to call it with YMM_BYTES even if only SSE is enabled, because
+    /// it would then just use 4 XMM registers.
+    /// # Safety
+    /// 1. Must not overlap. 2. Must have at least `SIZE_BYTES` valid bytes.
+    #[inline]
+    const unsafe fn force_swap_simd<const SIZE_BYTES: usize>(
+        x: *mut MaybeUninit<u8>,
+        y: *mut MaybeUninit<u8>,
+    ) {
+        const {
+            assert!(
+                SIZE_BYTES == XMM_BYTES || SIZE_BYTES == YMM_BYTES,
+                "Must have valid SIMD register size",
+            );
+        }
+        // SAFETY: We require valid non-overlapping pointers with `SIZE_BYTES` valid bytes.
+        // We checked above that `SIZE_BYTES` matches a SIMD register size.
+        unsafe {
+            // Don't use an array for the temporary here because it ends up on the stack.
+            // E.g. it would copy from memory to a register,
+            // from the register to the stack, from the stack to a register,
+            // and from the register to the destination.
+
+            const {
+                assert!(XMM_BYTES == size_of::<u64>() * 2, "Check number of temporaries below.");
+                assert!(YMM_BYTES == size_of::<u64>() * 4, "Check number of temporaries below.");
+            }
+            let x: *mut MaybeUninit<u64> = x.cast();
+            let y: *mut MaybeUninit<u64> = y.cast();
+            // Use separate variables instead.
+            // They are successfully folded into a YMM register when compiled with AVX,
+            // into a pair of XMM registers in `force_swap_simd::<YMM_BYTES>` without AVX,
+            // or into a single XMM register in `force_swap_simd::<XMM_BYTES>`.
+            let x0: MaybeUninit<u64> = ptr::read_unaligned(x);
+            let x1: MaybeUninit<u64> = ptr::read_unaligned(x.add(1));
+            let x2: MaybeUninit<u64>;
+            let x3: MaybeUninit<u64>;
+            if const { SIZE_BYTES == YMM_BYTES } {
+                x2 = ptr::read_unaligned(x.add(2));
+                x3 = ptr::read_unaligned(x.add(3));
+            } else {
+                x2 = MaybeUninit::uninit();
+                x3 = MaybeUninit::uninit();
+            }
+
+            // Unlike the simple swap, we need to use a direct move here
+            // instead of a temporary value for `y` like in `swap_simple`,
+            // because that causes the temporaries for `x` to be copied to the stack.
+
+            // Cast to `MaybeUninit<u8>` because `copy_nonoverlapping` requires correct alignment.
+            ptr::copy_nonoverlapping::<MaybeUninit<u8>>(y.cast(), x.cast(), SIZE_BYTES);
+
+            ptr::write_unaligned(y, x0);
+            ptr::write_unaligned(y.add(1), x1);
+            if const { SIZE_BYTES == YMM_BYTES } {
+                ptr::write_unaligned(y.add(2), x2);
+                ptr::write_unaligned(y.add(3), x3);
+            }
+        }
+    }
+
+    /// Swaps the first `size_of::<ChunkTy>()` bytes of the tail.
+    /// # Safety
+    /// `x` and `y` must not overlap.
+    /// `x` and `y` must have at least `size_of::<ChunkTy>()` valid bytes.
+    #[inline]
+    const unsafe fn swap_tail<ChunkTy: Copy>(x: *mut MaybeUninit<u8>, y: *mut MaybeUninit<u8>) {
+        // SAFETY: The caller must ensure pointer validity.
+        // We use unaligned reads/writes.
+        unsafe {
+            // Same rationale for 2 reads and 2 writes
+            // as in `swap_simple`.
+            let tmp_x: MaybeUninit<ChunkTy> = ptr::read_unaligned(x.cast());
+            let tmp_y: MaybeUninit<ChunkTy> = ptr::read_unaligned(y.cast());
+            ptr::write_unaligned(x.cast(), tmp_y);
+            ptr::write_unaligned(y.cast(), tmp_x);
+        }
+    }
+
+    const {
+        assert!(size_of::<T>() <= usize::MAX / 4, "We assume that overflows cannot happen.");
+    }
+
+    let x: *mut MaybeUninit<u8> = (x as *mut T).cast();
+    let y: *mut MaybeUninit<u8> = (y as *mut T).cast();
+
+    // I would like to add detection of available SIMD features here,
+    // but since the standard library is distributed precompiled,
+    // `cfg!(target_feature = "xxx")` evaluates to false here
+    // even if the final binary is built with those features enabled.
+
+    let size = const { size_of::<T>() };
+    let mut byte_offset = 0;
+
+    // This part autovectorizes to use the biggest SIMD register available.
+    // SAFETY: The pointers are valid because they come from references,
+    // `exactly_divisible <= size`, and we removed the remainder.
+    // The whole function doesn't contain any place that can panic,
+    // so the function cannot be interrupted before swapping ends.
+    unsafe {
+        let exactly_divisible = const {
+            let size = size_of::<T>();
+            size - size % ZMM_BYTES
+        };
+        swap_simple_u64_chunks(exactly_divisible, x, y);
+        byte_offset += exactly_divisible;
+    }
+    if byte_offset + YMM_BYTES <= size {
+        // SAFETY: The pointers don't overlap because mutable references don't overlap.
+        // We just checked the range.
+        // The whole function doesn't contain any place that can panic,
+        // so the function cannot be interrupted before swapping ends.
+        unsafe {
+            // We need to do this only once: the tail after the ZMM_BYTES chunks is shorter
+            // than ZMM_BYTES, so it cannot contain more than one YMM_BYTES chunk.
+            // It is OK to do this even if AVX is not enabled, because it would just use 4 XMM registers.
+            force_swap_simd::<YMM_BYTES>(x.add(byte_offset), y.add(byte_offset));
+            byte_offset += YMM_BYTES;
+        }
+    }
+    if byte_offset + XMM_BYTES <= size {
+        // SAFETY: The pointers don't overlap because mutable references don't overlap.
+        // We just checked the range.
+        // The whole function doesn't contain any place that can panic,
+        // so the function cannot be interrupted before swapping ends.
+        unsafe {
+            // We need to do this only once: the tail after the YMM_BYTES chunk cannot contain more than one XMM_BYTES chunk.
+            force_swap_simd::<XMM_BYTES>(x.add(byte_offset), y.add(byte_offset));
+            byte_offset += XMM_BYTES;
+        }
+    }
+
+    macro_rules! swap_tail_by {
+        ($t:ty) => {
+            // SAFETY: The pointers don't overlap because mutable references don't overlap.
+            // We never access the pointers in a way that requires alignment,
+            // and the whole function doesn't contain any place that can panic,
+            // so it is drop safe.
+            // We have already swapped the first `size_of::<T>() / 16 * 16` bytes.
+            // We try the `swap_tail` calls in order from bigger to smaller chunks.
+            unsafe {
+                if byte_offset + const { size_of::<$t>() } <= size {
+                    swap_tail::<$t>(x.add(byte_offset), y.add(byte_offset));
+                    byte_offset += const { size_of::<$t>() };
+                }
+            }
+        };
+    }
+    swap_tail_by!(u64);
+    swap_tail_by!(u32);
+    swap_tail_by!(u16);
+    // Swapping by `u8` is guaranteed to finish all remaining bytes
+    // because it has `size_of == 1`.
+    swap_tail_by!(u8);
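+    // Illustrative example: a 13-byte tail is finished as one `u64` (8 bytes),
+    // then one `u32` (4 bytes), then one `u8` (1 byte); the `u16` step is skipped.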
+    let _ = byte_offset;
+}
+
+// For platforms with slow unaligned accesses
+// we can just delegate to `ptr::swap_nonoverlapping`.
+#[cfg(not(target_arch = "x86_64"))]
+#[rustc_const_unstable(feature = "const_swap", issue = "83163")]
+#[inline]
+pub(crate) const fn swap_chunked<T: Sized>(x: &mut T, y: &mut T) {
+    // SAFETY: exclusive references always point to one non-overlapping
+    // element and are non-null and properly aligned.
+    unsafe {
+        ptr::swap_nonoverlapping(x, y, 1);
+    }
+}
+
 /// Replaces `dest` with the default value of `T`, returning the previous `dest` value.
 ///
 /// * If you want to replace the values of two variables, see [`swap`].