From 7346b83d829346d1c15e1252799064994e7077e6 Mon Sep 17 00:00:00 2001 From: caelunshun Date: Sat, 1 Nov 2025 14:52:03 -0600 Subject: [PATCH 1/9] Initial AVX-512 port of ChaCha20. Benchmark results on Zen 4: AVX2 Zen4 test chacha20_bench1_16b ... bench: 23.71 ns/iter (+/- 0.89) = 695 MB/s test chacha20_bench2_256b ... bench: 82.98 ns/iter (+/- 7.64) = 3121 MB/s test chacha20_bench3_1kib ... bench: 302.03 ns/iter (+/- 3.59) = 3390 MB/s test chacha20_bench4_16kib ... bench: 4,677.58 ns/iter (+/- 161.42) = 3503 MB/s AVX512 Zen4 test chacha20_bench1_16b ... bench: 25.07 ns/iter (+/- 0.90) = 640 MB/s test chacha20_bench2_256b ... bench: 79.66 ns/iter (+/- 1.18) = 3240 MB/s test chacha20_bench3_1kib ... bench: 275.32 ns/iter (+/- 4.13) = 3723 MB/s test chacha20_bench4_16kib ... bench: 4,201.84 ns/iter (+/- 24.18) = 3900 MB/s Much greater speedups are achievable for long input sizes if we increase PAR_BLOCKS to 8, but this also causes a 2x slowdown for short inputs (< 512 bytes). The StreamCipherBackend API doesn't seem to have any way to support multiple degrees of parallelism depending on the input size. --- chacha20/Cargo.toml | 7 +- chacha20/src/backends.rs | 1 + chacha20/src/backends/avx512.rs | 350 ++++++++++++++++++++++++++++++++ chacha20/src/lib.rs | 13 +- 4 files changed, 365 insertions(+), 6 deletions(-) create mode 100644 chacha20/src/backends/avx512.rs diff --git a/chacha20/Cargo.toml b/chacha20/Cargo.toml index 6e732524..1bb2eb83 100644 --- a/chacha20/Cargo.toml +++ b/chacha20/Cargo.toml @@ -3,7 +3,7 @@ name = "chacha20" version = "0.10.0-rc.2" authors = ["RustCrypto Developers"] edition = "2024" -rust-version = "1.85" +rust-version = "1.89" documentation = "https://docs.rs/chacha20" readme = "README.md" repository = "https://github.com/RustCrypto/stream-ciphers" @@ -20,7 +20,9 @@ rand_core-compatible RNGs based on those ciphers. [dependencies] cfg-if = "1" -cipher = { version = "0.5.0-rc.1", optional = true, features = ["stream-wrapper"] } +cipher = { version = "0.5.0-rc.1", optional = true, features = [ + "stream-wrapper", +] } rand_core = { version = "0.10.0-rc.1", optional = true, default-features = false } # `zeroize` is an explicit dependency because this crate may be used without the `cipher` crate @@ -51,6 +53,7 @@ check-cfg = [ 'cfg(chacha20_force_soft)', 'cfg(chacha20_force_sse2)', 'cfg(chacha20_force_avx2)', + 'cfg(chacha20_force_avx512)', ] [lints.clippy] diff --git a/chacha20/src/backends.rs b/chacha20/src/backends.rs index 936e0b67..22bf7885 100644 --- a/chacha20/src/backends.rs +++ b/chacha20/src/backends.rs @@ -13,6 +13,7 @@ cfg_if! { pub(crate) mod sse2; } else { pub(crate) mod soft; + pub(crate) mod avx512; pub(crate) mod avx2; pub(crate) mod sse2; } diff --git a/chacha20/src/backends/avx512.rs b/chacha20/src/backends/avx512.rs new file mode 100644 index 00000000..9622606b --- /dev/null +++ b/chacha20/src/backends/avx512.rs @@ -0,0 +1,350 @@ +#![allow(unsafe_op_in_unsafe_fn)] +use crate::{Rounds, Variant}; +use core::marker::PhantomData; + +#[cfg(feature = "rng")] +use crate::ChaChaCore; + +#[cfg(feature = "cipher")] +use crate::{STATE_WORDS, chacha::Block}; + +#[cfg(feature = "cipher")] +use cipher::{ + BlockSizeUser, ParBlocks, ParBlocksSizeUser, StreamCipherBackend, StreamCipherClosure, + consts::{U4, U64}, +}; + +#[cfg(target_arch = "x86")] +use core::arch::x86::*; +#[cfg(target_arch = "x86_64")] +use core::arch::x86_64::*; + +/// Number of blocks processed in parallel. +const PAR_BLOCKS: usize = 4; +/// Number of `__m512i` to store parallel blocks. +const N: usize = PAR_BLOCKS / 4; + +#[inline] +#[target_feature(enable = "avx512f")] +#[cfg(feature = "cipher")] +pub(crate) unsafe fn inner(state: &mut [u32; STATE_WORDS], f: F) +where + R: Rounds, + F: StreamCipherClosure, + V: Variant, +{ + let state_ptr = state.as_ptr() as *const __m128i; + let v = [ + _mm512_broadcast_i32x4(_mm_loadu_si128(state_ptr.add(0))), + _mm512_broadcast_i32x4(_mm_loadu_si128(state_ptr.add(1))), + _mm512_broadcast_i32x4(_mm_loadu_si128(state_ptr.add(2))), + ]; + let mut c = _mm512_broadcast_i32x4(_mm_loadu_si128(state_ptr.add(3))); + c = match size_of::() { + 4 => _mm512_add_epi32( + c, + _mm512_set_epi32(0, 0, 0, 3, 0, 0, 0, 2, 0, 0, 0, 1, 0, 0, 0, 0), + ), + 8 => _mm512_add_epi64(c, _mm512_set_epi64(0, 3, 0, 2, 0, 1, 0, 0)), + _ => unreachable!(), + }; + let mut ctr = [c; N]; + for i in 0..N { + ctr[i] = c; + c = match size_of::() { + 4 => _mm512_add_epi32( + c, + _mm512_set_epi32(0, 0, 0, 4, 0, 0, 0, 4, 0, 0, 0, 4, 0, 0, 0, 4), + ), + 8 => _mm512_add_epi64(c, _mm512_set_epi64(0, 4, 0, 4, 0, 4, 0, 4)), + _ => unreachable!(), + }; + } + let mut backend = Backend:: { + v, + ctr, + _pd: PhantomData, + }; + + f.call(&mut backend); + + state[12] = _mm256_extract_epi32::<0>(_mm512_extracti32x8_epi32::<0>(backend.ctr[0])) as u32; + match size_of::() { + 4 => {} + 8 => { + state[13] = + _mm256_extract_epi32::<1>(_mm512_extracti32x8_epi32::<0>(backend.ctr[0])) as u32 + } + _ => unreachable!(), + } +} + +#[inline] +#[target_feature(enable = "avx2")] +#[cfg(feature = "rng")] +pub(crate) unsafe fn rng_inner(core: &mut ChaChaCore, buffer: &mut [u32; 64]) +where + R: Rounds, + V: Variant, +{ + let state_ptr = core.state.as_ptr() as *const __m128i; + let v = [ + _mm256_broadcastsi128_si256(_mm_loadu_si128(state_ptr.add(0))), + _mm256_broadcastsi128_si256(_mm_loadu_si128(state_ptr.add(1))), + _mm256_broadcastsi128_si256(_mm_loadu_si128(state_ptr.add(2))), + ]; + let mut c = _mm256_broadcastsi128_si256(_mm_loadu_si128(state_ptr.add(3))); + c = _mm256_add_epi64(c, _mm256_set_epi64x(0, 1, 0, 0)); + let mut ctr = [c; N]; + for i in 0..N { + ctr[i] = c; + c = _mm256_add_epi64(c, _mm256_set_epi64x(0, 2, 0, 2)); + } + let mut backend = Backend:: { + v, + ctr, + _pd: PhantomData, + }; + + backend.rng_gen_par_ks_blocks(buffer); + + core.state[12] = _mm256_extract_epi32(backend.ctr[0], 0) as u32; + core.state[13] = _mm256_extract_epi32(backend.ctr[0], 1) as u32; +} + +struct Backend { + v: [__m512i; 3], + ctr: [__m512i; N], + _pd: PhantomData<(R, V)>, +} + +#[cfg(feature = "cipher")] +impl BlockSizeUser for Backend { + type BlockSize = U64; +} + +#[cfg(feature = "cipher")] +impl ParBlocksSizeUser for Backend { + type ParBlocksSize = U4; +} + +#[cfg(feature = "cipher")] +impl StreamCipherBackend for Backend { + #[inline(always)] + fn gen_ks_block(&mut self, block: &mut Block) { + unsafe { + let res = rounds::(&self.v, &self.ctr); + for c in self.ctr.iter_mut() { + *c = match size_of::() { + 4 => _mm512_add_epi32( + *c, + _mm512_set_epi32(0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1), + ), + 8 => _mm512_add_epi64(*c, _mm512_set_epi64(0, 1, 0, 1, 0, 1, 0, 1)), + _ => unreachable!(), + }; + } + + let block_ptr = block.as_mut_ptr() as *mut __m128i; + + for i in 0..4 { + _mm_storeu_si128(block_ptr.add(i), _mm512_extracti32x4_epi32::<0>(res[0][i])); + } + } + } + + #[inline(always)] + fn gen_par_ks_blocks(&mut self, blocks: &mut ParBlocks) { + unsafe { + let vs = rounds::(&self.v, &self.ctr); + + let pb = PAR_BLOCKS as i32; + for c in self.ctr.iter_mut() { + *c = match size_of::() { + 4 => _mm512_add_epi32( + *c, + _mm512_set_epi32(0, 0, 0, pb, 0, 0, 0, pb, 0, 0, 0, pb, 0, 0, 0, pb), + ), + 8 => _mm512_add_epi64( + *c, + _mm512_set_epi64(0, pb as i64, 0, pb as i64, 0, pb as i64, 0, pb as i64), + ), + _ => unreachable!(), + } + } + + let mut block_ptr = blocks.as_mut_ptr() as *mut __m128i; + for v in vs { + let t: [__m128i; 16] = core::mem::transmute(v); + for i in 0..4 { + _mm_storeu_si128(block_ptr.add(i), t[4 * i]); + _mm_storeu_si128(block_ptr.add(4 + i), t[4 * i + 1]); + _mm_storeu_si128(block_ptr.add(8 + i), t[4 * i + 2]); + _mm_storeu_si128(block_ptr.add(12 + i), t[4 * i + 3]); + } + block_ptr = block_ptr.add(16); + } + } + } +} + +#[cfg(feature = "rng")] +impl Backend { + #[inline(always)] + fn rng_gen_par_ks_blocks(&mut self, blocks: &mut [u32; 64]) { + unsafe { + let vs = rounds::(&self.v, &self.ctr); + + let pb = PAR_BLOCKS as i32; + for c in self.ctr.iter_mut() { + *c = _mm256_add_epi64(*c, _mm256_set_epi64x(0, pb as i64, 0, pb as i64)); + } + + let mut block_ptr = blocks.as_mut_ptr() as *mut __m128i; + for v in vs { + let t: [__m128i; 8] = core::mem::transmute(v); + for i in 0..4 { + _mm_storeu_si128(block_ptr.add(i), t[2 * i]); + _mm_storeu_si128(block_ptr.add(4 + i), t[2 * i + 1]); + } + block_ptr = block_ptr.add(8); + } + } + } +} + +#[inline] +#[target_feature(enable = "avx512f")] +unsafe fn rounds(v: &[__m512i; 3], c: &[__m512i; N]) -> [[__m512i; 4]; N] { + let mut vs: [[__m512i; 4]; N] = [[_mm512_setzero_si512(); 4]; N]; + for i in 0..N { + vs[i] = [v[0], v[1], v[2], c[i]]; + } + for _ in 0..R::COUNT { + double_quarter_round(&mut vs); + } + + for i in 0..N { + for j in 0..3 { + vs[i][j] = _mm512_add_epi32(vs[i][j], v[j]); + } + vs[i][3] = _mm512_add_epi32(vs[i][3], c[i]); + } + + vs +} + +#[inline] +#[target_feature(enable = "avx2")] +unsafe fn double_quarter_round(v: &mut [[__m512i; 4]; N]) { + add_xor_rot(v); + rows_to_cols(v); + add_xor_rot(v); + cols_to_rows(v); +} + +/// The goal of this function is to transform the state words from: +/// ```text +/// [a0, a1, a2, a3] [ 0, 1, 2, 3] +/// [b0, b1, b2, b3] == [ 4, 5, 6, 7] +/// [c0, c1, c2, c3] [ 8, 9, 10, 11] +/// [d0, d1, d2, d3] [12, 13, 14, 15] +/// ``` +/// +/// to: +/// ```text +/// [a0, a1, a2, a3] [ 0, 1, 2, 3] +/// [b1, b2, b3, b0] == [ 5, 6, 7, 4] +/// [c2, c3, c0, c1] [10, 11, 8, 9] +/// [d3, d0, d1, d2] [15, 12, 13, 14] +/// ``` +/// +/// so that we can apply [`add_xor_rot`] to the resulting columns, and have it compute the +/// "diagonal rounds" (as defined in RFC 7539) in parallel. In practice, this shuffle is +/// non-optimal: the last state word to be altered in `add_xor_rot` is `b`, so the shuffle +/// blocks on the result of `b` being calculated. +/// +/// We can optimize this by observing that the four quarter rounds in `add_xor_rot` are +/// data-independent: they only access a single column of the state, and thus the order of +/// the columns does not matter. We therefore instead shuffle the other three state words, +/// to obtain the following equivalent layout: +/// ```text +/// [a3, a0, a1, a2] [ 3, 0, 1, 2] +/// [b0, b1, b2, b3] == [ 4, 5, 6, 7] +/// [c1, c2, c3, c0] [ 9, 10, 11, 8] +/// [d2, d3, d0, d1] [14, 15, 12, 13] +/// ``` +/// +/// See https://github.com/sneves/blake2-avx2/pull/4 for additional details. The earliest +/// known occurrence of this optimization is in floodyberry's SSE4 ChaCha code from 2014: +/// - https://github.com/floodyberry/chacha-opt/blob/0ab65cb99f5016633b652edebaf3691ceb4ff753/chacha_blocks_ssse3-64.S#L639-L643 +#[inline] +#[target_feature(enable = "avx512f")] +unsafe fn rows_to_cols(vs: &mut [[__m512i; 4]; N]) { + // c >>>= 32; d >>>= 64; a >>>= 96; + for [a, _, c, d] in vs { + *c = _mm512_shuffle_epi32::<0b_00_11_10_01>(*c); // _MM_SHUFFLE(0, 3, 2, 1) + *d = _mm512_shuffle_epi32::<0b_01_00_11_10>(*d); // _MM_SHUFFLE(1, 0, 3, 2) + *a = _mm512_shuffle_epi32::<0b_10_01_00_11>(*a); // _MM_SHUFFLE(2, 1, 0, 3) + } +} + +/// The goal of this function is to transform the state words from: +/// ```text +/// [a3, a0, a1, a2] [ 3, 0, 1, 2] +/// [b0, b1, b2, b3] == [ 4, 5, 6, 7] +/// [c1, c2, c3, c0] [ 9, 10, 11, 8] +/// [d2, d3, d0, d1] [14, 15, 12, 13] +/// ``` +/// +/// to: +/// ```text +/// [a0, a1, a2, a3] [ 0, 1, 2, 3] +/// [b0, b1, b2, b3] == [ 4, 5, 6, 7] +/// [c0, c1, c2, c3] [ 8, 9, 10, 11] +/// [d0, d1, d2, d3] [12, 13, 14, 15] +/// ``` +/// +/// reversing the transformation of [`rows_to_cols`]. +#[inline] +#[target_feature(enable = "avx512f")] +unsafe fn cols_to_rows(vs: &mut [[__m512i; 4]; N]) { + // c <<<= 32; d <<<= 64; a <<<= 96; + for [a, _, c, d] in vs { + *c = _mm512_shuffle_epi32::<0b_10_01_00_11>(*c); // _MM_SHUFFLE(2, 1, 0, 3) + *d = _mm512_shuffle_epi32::<0b_01_00_11_10>(*d); // _MM_SHUFFLE(1, 0, 3, 2) + *a = _mm512_shuffle_epi32::<0b_00_11_10_01>(*a); // _MM_SHUFFLE(0, 3, 2, 1) + } +} + +#[inline] +#[target_feature(enable = "avx512f")] +unsafe fn add_xor_rot(vs: &mut [[__m512i; 4]; N]) { + // a += b; d ^= a; d <<<= (16, 16, 16, 16); + for [a, b, _, d] in vs.iter_mut() { + *a = _mm512_add_epi32(*a, *b); + *d = _mm512_xor_si512(*d, *a); + *d = _mm512_rol_epi32::<16>(*d); + } + + // c += d; b ^= c; b <<<= (12, 12, 12, 12); + for [_, b, c, d] in vs.iter_mut() { + *c = _mm512_add_epi32(*c, *d); + *b = _mm512_xor_si512(*b, *c); + *b = _mm512_rol_epi32::<12>(*b); + } + + // a += b; d ^= a; d <<<= (8, 8, 8, 8); + for [a, b, _, d] in vs.iter_mut() { + *a = _mm512_add_epi32(*a, *b); + *d = _mm512_xor_si512(*d, *a); + *d = _mm512_rol_epi32::<8>(*d); + } + + // c += d; b ^= c; b <<<= (7, 7, 7, 7); + for [_, b, c, d] in vs.iter_mut() { + *c = _mm512_add_epi32(*c, *d); + *b = _mm512_xor_si512(*b, *c); + *b = _mm512_rol_epi32::<7>(*b); + } +} diff --git a/chacha20/src/lib.rs b/chacha20/src/lib.rs index d9846f71..f4c2fa90 100644 --- a/chacha20/src/lib.rs +++ b/chacha20/src/lib.rs @@ -196,9 +196,10 @@ cfg_if! { `chacha20_force_sse2` configuration option"); type Tokens = (); } else { + cpufeatures::new!(avx512_cpuid, "avx512f"); cpufeatures::new!(avx2_cpuid, "avx2"); cpufeatures::new!(sse2_cpuid, "sse2"); - type Tokens = (avx2_cpuid::InitToken, sse2_cpuid::InitToken); + type Tokens = (avx512_cpuid::InitToken, avx2_cpuid::InitToken, sse2_cpuid::InitToken); } } } else { @@ -252,7 +253,7 @@ impl ChaChaCore { } else if #[cfg(chacha20_force_sse2)] { let tokens = (); } else { - let tokens = (avx2_cpuid::init(), sse2_cpuid::init()); + let tokens = (avx512_cpuid::init(), avx2_cpuid::init(), sse2_cpuid::init()); } } } else { @@ -307,8 +308,12 @@ impl StreamCipherCore for ChaChaCore { backends::sse2::inner::(&mut self.state, f); } } else { - let (avx2_token, sse2_token) = self.tokens; - if avx2_token.get() { + let (avx512_token, avx2_token, sse2_token) = self.tokens; + if avx512_token.get() { + unsafe { + backends::avx512::inner::(&mut self.state, f); + } + } else if avx2_token.get() { unsafe { backends::avx2::inner::(&mut self.state, f); } From 20d700af571316b0481ff7e5321329e6c6688c7b Mon Sep 17 00:00:00 2001 From: caelunshun Date: Sat, 1 Nov 2025 15:08:30 -0600 Subject: [PATCH 2/9] Test PAR_BLOCKS=8 New throughput results on Zen 4: test chacha20_bench1_16b ... bench: 25.53 ns/iter (+/- 0.75) = 640 MB/s test chacha20_bench2_256b ... bench: 255.88 ns/iter (+/- 4.16) = 1003 MB/s test chacha20_bench3_1kib ... bench: 192.76 ns/iter (+/- 4.15) = 5333 MB/s test chacha20_bench4_16kib ... bench: 2,873.78 ns/iter (+/- 62.99) = 5702 MB/s 3x regression for 256b case, since minimum 512b is required to use parallel. --- chacha20/src/backends/avx512.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/chacha20/src/backends/avx512.rs b/chacha20/src/backends/avx512.rs index 9622606b..39662e19 100644 --- a/chacha20/src/backends/avx512.rs +++ b/chacha20/src/backends/avx512.rs @@ -11,7 +11,7 @@ use crate::{STATE_WORDS, chacha::Block}; #[cfg(feature = "cipher")] use cipher::{ BlockSizeUser, ParBlocks, ParBlocksSizeUser, StreamCipherBackend, StreamCipherClosure, - consts::{U4, U64}, + consts::{U8, U64}, }; #[cfg(target_arch = "x86")] @@ -20,7 +20,7 @@ use core::arch::x86::*; use core::arch::x86_64::*; /// Number of blocks processed in parallel. -const PAR_BLOCKS: usize = 4; +const PAR_BLOCKS: usize = 8; /// Number of `__m512i` to store parallel blocks. const N: usize = PAR_BLOCKS / 4; @@ -125,7 +125,7 @@ impl BlockSizeUser for Backend { #[cfg(feature = "cipher")] impl ParBlocksSizeUser for Backend { - type ParBlocksSize = U4; + type ParBlocksSize = U8; } #[cfg(feature = "cipher")] From 5c1ad622d5434ddf7172996c652d2fd9dbb3326b Mon Sep 17 00:00:00 2001 From: caelunshun Date: Sat, 1 Nov 2025 16:20:06 -0600 Subject: [PATCH 3/9] Get back 256b performance by specializing gen_tail_blocks for AVX-512 --- chacha20/src/backends/avx512.rs | 136 +++++++++++++++++++++----------- 1 file changed, 92 insertions(+), 44 deletions(-) diff --git a/chacha20/src/backends/avx512.rs b/chacha20/src/backends/avx512.rs index 39662e19..79332dec 100644 --- a/chacha20/src/backends/avx512.rs +++ b/chacha20/src/backends/avx512.rs @@ -11,7 +11,7 @@ use crate::{STATE_WORDS, chacha::Block}; #[cfg(feature = "cipher")] use cipher::{ BlockSizeUser, ParBlocks, ParBlocksSizeUser, StreamCipherBackend, StreamCipherClosure, - consts::{U8, U64}, + consts::{U16, U64}, }; #[cfg(target_arch = "x86")] @@ -19,10 +19,15 @@ use core::arch::x86::*; #[cfg(target_arch = "x86_64")] use core::arch::x86_64::*; -/// Number of blocks processed in parallel. -const PAR_BLOCKS: usize = 8; -/// Number of `__m512i` to store parallel blocks. -const N: usize = PAR_BLOCKS / 4; +/// Maximum number of blocks processed in parallel. +/// We also support 8 and 4 in gen_tail_blocks. +const MAX_PAR_BLOCKS: usize = 16; + +/// Divisor to compute `N`, the number of __m512i needed +/// to represent a number of parallel blocks. +const BLOCKS_PER_VECTOR: usize = 4; + +const MAX_N: usize = MAX_PAR_BLOCKS / BLOCKS_PER_VECTOR; #[inline] #[target_feature(enable = "avx512f")] @@ -48,8 +53,8 @@ where 8 => _mm512_add_epi64(c, _mm512_set_epi64(0, 3, 0, 2, 0, 1, 0, 0)), _ => unreachable!(), }; - let mut ctr = [c; N]; - for i in 0..N { + let mut ctr = [c; MAX_N]; + for i in 0..MAX_N { ctr[i] = c; c = match size_of::() { 4 => _mm512_add_epi32( @@ -80,7 +85,7 @@ where } #[inline] -#[target_feature(enable = "avx2")] +#[target_feature(enable = "avx512")] #[cfg(feature = "rng")] pub(crate) unsafe fn rng_inner(core: &mut ChaChaCore, buffer: &mut [u32; 64]) where @@ -114,7 +119,7 @@ where struct Backend { v: [__m512i; 3], - ctr: [__m512i; N], + ctr: [__m512i; MAX_N], _pd: PhantomData<(R, V)>, } @@ -125,7 +130,53 @@ impl BlockSizeUser for Backend { #[cfg(feature = "cipher")] impl ParBlocksSizeUser for Backend { - type ParBlocksSize = U8; + type ParBlocksSize = U16; +} + +#[cfg(feature = "cipher")] +impl Backend { + fn gen_par_ks_blocks_inner( + &mut self, + blocks: &mut [cipher::Block; PAR_BLOCKS], + ) { + assert!(PAR_BLOCKS.is_multiple_of(BLOCKS_PER_VECTOR)); + + unsafe { + let vs = rounds::(&self.v, &self.ctr[..N].try_into().unwrap()); + + let pb = blocks.len() as i32; + for c in self.ctr.iter_mut() { + *c = match size_of::() { + 4 => _mm512_add_epi32( + *c, + _mm512_set_epi32(0, 0, 0, pb, 0, 0, 0, pb, 0, 0, 0, pb, 0, 0, 0, pb), + ), + 8 => _mm512_add_epi64( + *c, + _mm512_set_epi64(0, pb as i64, 0, pb as i64, 0, pb as i64, 0, pb as i64), + ), + _ => unreachable!(), + } + } + + let mut block_ptr = blocks.as_mut_ptr() as *mut __m128i; + for (vi, v) in vs.into_iter().enumerate() { + let t: [__m128i; 16] = core::mem::transmute(v); + for i in 0..BLOCKS_PER_VECTOR { + _mm_storeu_si128(block_ptr.add(i), t[4 * i]); + _mm_storeu_si128(block_ptr.add(4 + i), t[4 * i + 1]); + _mm_storeu_si128(block_ptr.add(8 + i), t[4 * i + 2]); + _mm_storeu_si128(block_ptr.add(12 + i), t[4 * i + 3]); + } + + if vi == PAR_BLOCKS / BLOCKS_PER_VECTOR - 1 { + break; + } + + block_ptr = block_ptr.add(16); + } + } + } } #[cfg(feature = "cipher")] @@ -133,7 +184,7 @@ impl StreamCipherBackend for Backend { #[inline(always)] fn gen_ks_block(&mut self, block: &mut Block) { unsafe { - let res = rounds::(&self.v, &self.ctr); + let res = rounds::<1, R>(&self.v, self.ctr[..1].try_into().unwrap()); for c in self.ctr.iter_mut() { *c = match size_of::() { 4 => _mm512_add_epi32( @@ -155,35 +206,29 @@ impl StreamCipherBackend for Backend { #[inline(always)] fn gen_par_ks_blocks(&mut self, blocks: &mut ParBlocks) { - unsafe { - let vs = rounds::(&self.v, &self.ctr); + self.gen_par_ks_blocks_inner::( + blocks.as_mut_slice().try_into().unwrap(), + ); + } - let pb = PAR_BLOCKS as i32; - for c in self.ctr.iter_mut() { - *c = match size_of::() { - 4 => _mm512_add_epi32( - *c, - _mm512_set_epi32(0, 0, 0, pb, 0, 0, 0, pb, 0, 0, 0, pb, 0, 0, 0, pb), - ), - 8 => _mm512_add_epi64( - *c, - _mm512_set_epi64(0, pb as i64, 0, pb as i64, 0, pb as i64, 0, pb as i64), - ), - _ => unreachable!(), - } - } + #[inline(always)] + fn gen_tail_blocks(&mut self, mut blocks: &mut [cipher::Block]) { + while blocks.len() >= 8 { + self.gen_par_ks_blocks_inner::<8, { 8 / BLOCKS_PER_VECTOR }>( + (&mut blocks[..8]).try_into().unwrap(), + ); + blocks = &mut blocks[8..]; + } - let mut block_ptr = blocks.as_mut_ptr() as *mut __m128i; - for v in vs { - let t: [__m128i; 16] = core::mem::transmute(v); - for i in 0..4 { - _mm_storeu_si128(block_ptr.add(i), t[4 * i]); - _mm_storeu_si128(block_ptr.add(4 + i), t[4 * i + 1]); - _mm_storeu_si128(block_ptr.add(8 + i), t[4 * i + 2]); - _mm_storeu_si128(block_ptr.add(12 + i), t[4 * i + 3]); - } - block_ptr = block_ptr.add(16); - } + while blocks.len() >= 4 { + self.gen_par_ks_blocks_inner::<4, { 4 / BLOCKS_PER_VECTOR }>( + (&mut blocks[..4]).try_into().unwrap(), + ); + blocks = &mut blocks[4..]; + } + + for block in blocks { + self.gen_ks_block(block); } } } @@ -215,7 +260,10 @@ impl Backend { #[inline] #[target_feature(enable = "avx512f")] -unsafe fn rounds(v: &[__m512i; 3], c: &[__m512i; N]) -> [[__m512i; 4]; N] { +unsafe fn rounds( + v: &[__m512i; 3], + c: &[__m512i; N], +) -> [[__m512i; 4]; N] { let mut vs: [[__m512i; 4]; N] = [[_mm512_setzero_si512(); 4]; N]; for i in 0..N { vs[i] = [v[0], v[1], v[2], c[i]]; @@ -235,8 +283,8 @@ unsafe fn rounds(v: &[__m512i; 3], c: &[__m512i; N]) -> [[__m512i; 4] } #[inline] -#[target_feature(enable = "avx2")] -unsafe fn double_quarter_round(v: &mut [[__m512i; 4]; N]) { +#[target_feature(enable = "avx512f")] +unsafe fn double_quarter_round(v: &mut [[__m512i; 4]; N]) { add_xor_rot(v); rows_to_cols(v); add_xor_rot(v); @@ -280,7 +328,7 @@ unsafe fn double_quarter_round(v: &mut [[__m512i; 4]; N]) { /// - https://github.com/floodyberry/chacha-opt/blob/0ab65cb99f5016633b652edebaf3691ceb4ff753/chacha_blocks_ssse3-64.S#L639-L643 #[inline] #[target_feature(enable = "avx512f")] -unsafe fn rows_to_cols(vs: &mut [[__m512i; 4]; N]) { +unsafe fn rows_to_cols(vs: &mut [[__m512i; 4]; N]) { // c >>>= 32; d >>>= 64; a >>>= 96; for [a, _, c, d] in vs { *c = _mm512_shuffle_epi32::<0b_00_11_10_01>(*c); // _MM_SHUFFLE(0, 3, 2, 1) @@ -308,7 +356,7 @@ unsafe fn rows_to_cols(vs: &mut [[__m512i; 4]; N]) { /// reversing the transformation of [`rows_to_cols`]. #[inline] #[target_feature(enable = "avx512f")] -unsafe fn cols_to_rows(vs: &mut [[__m512i; 4]; N]) { +unsafe fn cols_to_rows(vs: &mut [[__m512i; 4]; N]) { // c <<<= 32; d <<<= 64; a <<<= 96; for [a, _, c, d] in vs { *c = _mm512_shuffle_epi32::<0b_10_01_00_11>(*c); // _MM_SHUFFLE(2, 1, 0, 3) @@ -319,7 +367,7 @@ unsafe fn cols_to_rows(vs: &mut [[__m512i; 4]; N]) { #[inline] #[target_feature(enable = "avx512f")] -unsafe fn add_xor_rot(vs: &mut [[__m512i; 4]; N]) { +unsafe fn add_xor_rot(vs: &mut [[__m512i; 4]; N]) { // a += b; d ^= a; d <<<= (16, 16, 16, 16); for [a, b, _, d] in vs.iter_mut() { *a = _mm512_add_epi32(*a, *b); From ff231fab0ed7b7448930ba05cab5a73269b01101 Mon Sep 17 00:00:00 2001 From: caelunshun Date: Sat, 1 Nov 2025 16:45:31 -0600 Subject: [PATCH 4/9] Add RNG support for avx512 (not benchmarked) --- chacha20/Cargo.toml | 4 +-- chacha20/src/backends.rs | 4 ++- chacha20/src/backends/avx512.rs | 59 ++++++++++++--------------------- chacha20/src/lib.rs | 17 ++++++++-- chacha20/src/rng.rs | 14 ++++++-- 5 files changed, 51 insertions(+), 47 deletions(-) diff --git a/chacha20/Cargo.toml b/chacha20/Cargo.toml index 1bb2eb83..38223e2a 100644 --- a/chacha20/Cargo.toml +++ b/chacha20/Cargo.toml @@ -20,9 +20,7 @@ rand_core-compatible RNGs based on those ciphers. [dependencies] cfg-if = "1" -cipher = { version = "0.5.0-rc.1", optional = true, features = [ - "stream-wrapper", -] } +cipher = { version = "0.5.0-rc.1", optional = true, features = ["stream-wrapper"] } rand_core = { version = "0.10.0-rc.1", optional = true, default-features = false } # `zeroize` is an explicit dependency because this crate may be used without the `cipher` crate diff --git a/chacha20/src/backends.rs b/chacha20/src/backends.rs index 22bf7885..23c8e539 100644 --- a/chacha20/src/backends.rs +++ b/chacha20/src/backends.rs @@ -7,7 +7,9 @@ cfg_if! { pub(crate) mod soft; } else if #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] { cfg_if! { - if #[cfg(chacha20_force_avx2)] { + if #[cfg(chacha20_force_avx512)] { + pub(crate) mod avx512; + } else if #[cfg(chacha20_force_avx2)] { pub(crate) mod avx2; } else if #[cfg(chacha20_force_sse2)] { pub(crate) mod sse2; diff --git a/chacha20/src/backends/avx512.rs b/chacha20/src/backends/avx512.rs index 79332dec..1277797b 100644 --- a/chacha20/src/backends/avx512.rs +++ b/chacha20/src/backends/avx512.rs @@ -85,25 +85,29 @@ where } #[inline] -#[target_feature(enable = "avx512")] +#[target_feature(enable = "avx512f")] #[cfg(feature = "rng")] pub(crate) unsafe fn rng_inner(core: &mut ChaChaCore, buffer: &mut [u32; 64]) where R: Rounds, V: Variant, { + use core::slice; + + use crate::rng::BLOCK_WORDS; + let state_ptr = core.state.as_ptr() as *const __m128i; let v = [ - _mm256_broadcastsi128_si256(_mm_loadu_si128(state_ptr.add(0))), - _mm256_broadcastsi128_si256(_mm_loadu_si128(state_ptr.add(1))), - _mm256_broadcastsi128_si256(_mm_loadu_si128(state_ptr.add(2))), + _mm512_broadcast_i32x4(_mm_loadu_si128(state_ptr.add(0))), + _mm512_broadcast_i32x4(_mm_loadu_si128(state_ptr.add(1))), + _mm512_broadcast_i32x4(_mm_loadu_si128(state_ptr.add(2))), ]; - let mut c = _mm256_broadcastsi128_si256(_mm_loadu_si128(state_ptr.add(3))); - c = _mm256_add_epi64(c, _mm256_set_epi64x(0, 1, 0, 0)); - let mut ctr = [c; N]; - for i in 0..N { + let mut c = _mm512_broadcast_i32x4(_mm_loadu_si128(state_ptr.add(3))); + c = _mm512_add_epi64(c, _mm512_set_epi64(0, 3, 0, 2, 0, 1, 0, 0)); + let mut ctr = [c; MAX_N]; + for i in 0..MAX_N { ctr[i] = c; - c = _mm256_add_epi64(c, _mm256_set_epi64x(0, 2, 0, 2)); + c = _mm512_add_epi64(c, _mm512_set_epi64(0, 4, 0, 4, 0, 4, 0, 4)); } let mut backend = Backend:: { v, @@ -111,10 +115,16 @@ where _pd: PhantomData, }; - backend.rng_gen_par_ks_blocks(buffer); + let buffer = slice::from_raw_parts_mut( + buffer.as_mut_ptr().cast::(), + buffer.len() / BLOCK_WORDS as usize, + ); + backend.gen_par_ks_blocks_inner::<4, { 4 / BLOCKS_PER_VECTOR }>(buffer.try_into().unwrap()); - core.state[12] = _mm256_extract_epi32(backend.ctr[0], 0) as u32; - core.state[13] = _mm256_extract_epi32(backend.ctr[0], 1) as u32; + core.state[12] = + _mm256_extract_epi32::<0>(_mm512_extracti32x8_epi32::<0>(backend.ctr[0])) as u32; + core.state[13] = + _mm256_extract_epi32::<1>(_mm512_extracti32x8_epi32::<0>(backend.ctr[0])) as u32; } struct Backend { @@ -233,31 +243,6 @@ impl StreamCipherBackend for Backend { } } -#[cfg(feature = "rng")] -impl Backend { - #[inline(always)] - fn rng_gen_par_ks_blocks(&mut self, blocks: &mut [u32; 64]) { - unsafe { - let vs = rounds::(&self.v, &self.ctr); - - let pb = PAR_BLOCKS as i32; - for c in self.ctr.iter_mut() { - *c = _mm256_add_epi64(*c, _mm256_set_epi64x(0, pb as i64, 0, pb as i64)); - } - - let mut block_ptr = blocks.as_mut_ptr() as *mut __m128i; - for v in vs { - let t: [__m128i; 8] = core::mem::transmute(v); - for i in 0..4 { - _mm_storeu_si128(block_ptr.add(i), t[2 * i]); - _mm_storeu_si128(block_ptr.add(4 + i), t[2 * i + 1]); - } - block_ptr = block_ptr.add(8); - } - } - } -} - #[inline] #[target_feature(enable = "avx512f")] unsafe fn rounds( diff --git a/chacha20/src/lib.rs b/chacha20/src/lib.rs index f4c2fa90..1908104d 100644 --- a/chacha20/src/lib.rs +++ b/chacha20/src/lib.rs @@ -185,7 +185,12 @@ cfg_if! { type Tokens = (); } else if #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] { cfg_if! { - if #[cfg(chacha20_force_avx2)] { + if #[cfg(chacha20_force_avx512)] { + #[cfg(not(target_feature = "avx512f"))] + compile_error!("You must enable `avx512f` target feature with \ + `chacha20_force_avx512` configuration option"); + type Tokens = (); + } else if #[cfg(chacha20_force_avx2)] { #[cfg(not(target_feature = "avx2"))] compile_error!("You must enable `avx2` target feature with \ `chacha20_force_avx2` configuration option"); @@ -248,7 +253,9 @@ impl ChaChaCore { let tokens = (); } else if #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] { cfg_if! { - if #[cfg(chacha20_force_avx2)] { + if #[cfg(chacha20_force_avx512)] { + let tokens = (); + } else if #[cfg(chacha20_force_avx2)] { let tokens = (); } else if #[cfg(chacha20_force_sse2)] { let tokens = (); @@ -299,7 +306,11 @@ impl StreamCipherCore for ChaChaCore { f.call(&mut backends::soft::Backend(self)); } else if #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] { cfg_if! { - if #[cfg(chacha20_force_avx2)] { + if #[cfg(chacha20_force_avx512)] { + unsafe { + backends::avx512::inner::(&mut self.state, f); + } + } else if #[cfg(chacha20_force_avx2)] { unsafe { backends::avx2::inner::(&mut self.state, f); } diff --git a/chacha20/src/rng.rs b/chacha20/src/rng.rs index 4c483442..79954c49 100644 --- a/chacha20/src/rng.rs +++ b/chacha20/src/rng.rs @@ -189,7 +189,11 @@ impl ChaChaCore { backends::soft::Backend(self).gen_ks_blocks(buffer); } else if #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] { cfg_if! { - if #[cfg(chacha20_force_avx2)] { + if #[cfg(chacha20_force_avx512)] { + unsafe { + backends::avx512::rng_inner::(self, buffer); + } + } else if #[cfg(chacha20_force_avx2)] { unsafe { backends::avx2::rng_inner::(self, buffer); } @@ -198,8 +202,12 @@ impl ChaChaCore { backends::sse2::rng_inner::(self, buffer); } } else { - let (avx2_token, sse2_token) = self.tokens; - if avx2_token.get() { + let (avx512_token, avx2_token, sse2_token) = self.tokens; + if avx512_token.get() { + unsafe { + backends::avx512::rng_inner::(self, buffer); + } + } else if avx2_token.get() { unsafe { backends::avx2::rng_inner::(self, buffer); } From 03b9b92bb805b0a1442eaa04736acb64a294d2e7 Mon Sep 17 00:00:00 2001 From: caelunshun Date: Sun, 2 Nov 2025 17:00:53 -0700 Subject: [PATCH 5/9] chacha20 avx512: Refactor design to avoid using 512-bit path even for short inputs This makes up about half the performance loss for 16-byte output. I suspect the remaining loss is due to different inlining decisions and probably insignificant. --- chacha20/src/backends/avx512.rs | 522 +++++++++++++++++++++++++------- chacha20/src/lib.rs | 4 +- 2 files changed, 407 insertions(+), 119 deletions(-) diff --git a/chacha20/src/backends/avx512.rs b/chacha20/src/backends/avx512.rs index 1277797b..f46baa82 100644 --- a/chacha20/src/backends/avx512.rs +++ b/chacha20/src/backends/avx512.rs @@ -38,49 +38,24 @@ where F: StreamCipherClosure, V: Variant, { - let state_ptr = state.as_ptr() as *const __m128i; - let v = [ - _mm512_broadcast_i32x4(_mm_loadu_si128(state_ptr.add(0))), - _mm512_broadcast_i32x4(_mm_loadu_si128(state_ptr.add(1))), - _mm512_broadcast_i32x4(_mm_loadu_si128(state_ptr.add(2))), - ]; - let mut c = _mm512_broadcast_i32x4(_mm_loadu_si128(state_ptr.add(3))); - c = match size_of::() { - 4 => _mm512_add_epi32( - c, - _mm512_set_epi32(0, 0, 0, 3, 0, 0, 0, 2, 0, 0, 0, 1, 0, 0, 0, 0), - ), - 8 => _mm512_add_epi64(c, _mm512_set_epi64(0, 3, 0, 2, 0, 1, 0, 0)), - _ => unreachable!(), - }; - let mut ctr = [c; MAX_N]; - for i in 0..MAX_N { - ctr[i] = c; - c = match size_of::() { - 4 => _mm512_add_epi32( - c, - _mm512_set_epi32(0, 0, 0, 4, 0, 0, 0, 4, 0, 0, 0, 4, 0, 0, 0, 4), - ), - 8 => _mm512_add_epi64(c, _mm512_set_epi64(0, 4, 0, 4, 0, 4, 0, 4)), - _ => unreachable!(), - }; - } + let simd_state = state.as_mut_ptr().cast::(); + let mut backend = Backend:: { - v, - ctr, + state: [ + _mm_loadu_epi32(simd_state), + _mm_loadu_epi32(simd_state.add(4)), + _mm_loadu_epi32(simd_state.add(8)), + ], + ctr: _mm_loadu_epi32(simd_state.add(12)), _pd: PhantomData, }; f.call(&mut backend); - state[12] = _mm256_extract_epi32::<0>(_mm512_extracti32x8_epi32::<0>(backend.ctr[0])) as u32; - match size_of::() { - 4 => {} - 8 => { - state[13] = - _mm256_extract_epi32::<1>(_mm512_extracti32x8_epi32::<0>(backend.ctr[0])) as u32 - } - _ => unreachable!(), + // Update counter in the persistent state + state[12] = _mm_extract_epi32::<0>(backend.ctr) as u32; + if size_of::() == 8 { + state[13] = _mm_extract_epi32::<1>(backend.ctr) as u32; } } @@ -128,117 +103,283 @@ where } struct Backend { - v: [__m512i; 3], - ctr: [__m512i; MAX_N], + state: [__m128i; 3], + ctr: __m128i, _pd: PhantomData<(R, V)>, } -#[cfg(feature = "cipher")] -impl BlockSizeUser for Backend { - type BlockSize = U64; -} - -#[cfg(feature = "cipher")] -impl ParBlocksSizeUser for Backend { - type ParBlocksSize = U16; -} - #[cfg(feature = "cipher")] impl Backend { - fn gen_par_ks_blocks_inner( - &mut self, - blocks: &mut [cipher::Block; PAR_BLOCKS], - ) { - assert!(PAR_BLOCKS.is_multiple_of(BLOCKS_PER_VECTOR)); + #[inline] + #[target_feature(enable = "avx512f", enable = "avx512vl")] + unsafe fn increment_ctr(&mut self, amount: usize) { + match size_of::() { + 4 => { + self.ctr = _mm_add_epi32(self.ctr, _mm_set_epi32(0, 0, 0, amount as i32)); + } + 8 => { + self.ctr = _mm_add_epi64(self.ctr, _mm_set_epi64x(0, amount as i64)); + } + _ => unreachable!(), + } + } - unsafe { - let vs = rounds::(&self.v, &self.ctr[..N].try_into().unwrap()); - - let pb = blocks.len() as i32; - for c in self.ctr.iter_mut() { - *c = match size_of::() { - 4 => _mm512_add_epi32( - *c, - _mm512_set_epi32(0, 0, 0, pb, 0, 0, 0, pb, 0, 0, 0, pb, 0, 0, 0, pb), - ), - 8 => _mm512_add_epi64( - *c, - _mm512_set_epi64(0, pb as i64, 0, pb as i64, 0, pb as i64, 0, pb as i64), - ), - _ => unreachable!(), + /// Generates blocks using the 512-bit-wide dispatch + /// with up to `N` vectors processed in parallel, producing + /// `N * BLOCKS_PER_VECTOR` blocks. + #[inline] + #[target_feature(enable = "avx512f", enable = "avx512vl")] + unsafe fn gen_blocks_fullwidth(&mut self, blocks: &mut [Block]) { + let par_blocks = N * BLOCKS_PER_VECTOR; + assert!(blocks.len() <= par_blocks); + + let mut ctrs = [_mm512_broadcast_i32x4(self.ctr); N]; + for i in 0..ctrs.len() { + match size_of::() { + 4 => { + ctrs[i] = _mm512_add_epi32( + ctrs[i], + _mm512_set_epi32( + 0, + 0, + 0, + (i * BLOCKS_PER_VECTOR + 3) as i32, + 0, + 0, + 0, + (i * BLOCKS_PER_VECTOR + 2) as i32, + 0, + 0, + 0, + (i * BLOCKS_PER_VECTOR + 1) as i32, + 0, + 0, + 0, + (i * BLOCKS_PER_VECTOR) as i32, + ), + ); } + 8 => { + ctrs[i] = _mm512_add_epi64( + ctrs[i], + _mm512_set_epi64( + 0, + (i * BLOCKS_PER_VECTOR + 3) as i64, + 0, + (i * BLOCKS_PER_VECTOR + 2) as i64, + 0, + (i * BLOCKS_PER_VECTOR + 1) as i64, + 0, + (i * BLOCKS_PER_VECTOR) as i64, + ), + ); + } + _ => unreachable!(), } + } - let mut block_ptr = blocks.as_mut_ptr() as *mut __m128i; - for (vi, v) in vs.into_iter().enumerate() { - let t: [__m128i; 16] = core::mem::transmute(v); - for i in 0..BLOCKS_PER_VECTOR { - _mm_storeu_si128(block_ptr.add(i), t[4 * i]); - _mm_storeu_si128(block_ptr.add(4 + i), t[4 * i + 1]); - _mm_storeu_si128(block_ptr.add(8 + i), t[4 * i + 2]); - _mm_storeu_si128(block_ptr.add(12 + i), t[4 * i + 3]); - } + self.increment_ctr(blocks.len()); + + let result = rounds::(&self.state.map(|v| _mm512_broadcast_i32x4(v)), &ctrs); + + for i in 0..N { + let result_vectors = result[i]; + + // We have our data in SIMD vectors in the following layout + // (using a, b, c, and d to indicate the resp. 4 rows of each block, + // and Bn to denote the nth block): + // result_vectors[0]: + // B0a0 B0a1 B0a2 B0a3 + // B1a0 B1a1 B1a2 B1a3 + // ... + // B3a0 B3a1 B3a2 B3a2 + // + // result_vectors[1]: + // B0b0 B0b1 B0b2 B0b3 + // B1b0 B1b1 B1b2 B1b3 + // ... + // B3b0 B3b1 B3b2 B3b2 + // + // and so on for result_vectors[2] (storing c values) and result_vectors[3] (storing d values). + // + // To store to memory, we need to transpose to the following format: + // transposed[0]: + // B0a0 B0a1 B0a2 B0a3 + // B0b0 B0b1 B0b2 B0b3 + // B0c0 B0c1 B0c2 B0c3 + // B0d0 B0d1 B0d2 B0d3 + // + // and so on, such that each 512-bit SIMD vector + // contains a single contiguous block. + // + // We achieve this transposition using the following + // sequence of shuffles. + + let temp_abab_block01 = _mm512_permutex2var_epi64( + result_vectors[0], + _mm512_setr_epi64(0, 1, 8, 9, 2, 3, 10, 11), + result_vectors[1], + ); + let temp_abab_block23 = _mm512_permutex2var_epi64( + result_vectors[0], + _mm512_setr_epi64(4, 5, 12, 13, 6, 7, 14, 15), + result_vectors[1], + ); - if vi == PAR_BLOCKS / BLOCKS_PER_VECTOR - 1 { - break; + let temp_cdcd_block01 = _mm512_permutex2var_epi64( + result_vectors[2], + _mm512_setr_epi64(0, 1, 8, 9, 2, 3, 10, 11), + result_vectors[3], + ); + let temp_cdcd_block23 = _mm512_permutex2var_epi64( + result_vectors[2], + _mm512_setr_epi64(4, 5, 12, 13, 6, 7, 14, 15), + result_vectors[3], + ); + + let block0 = + _mm512_shuffle_i32x4::<0b01_00_01_00>(temp_abab_block01, temp_cdcd_block01); + let block1 = + _mm512_shuffle_i32x4::<0b11_10_11_10>(temp_abab_block01, temp_cdcd_block01); + let block2 = + _mm512_shuffle_i32x4::<0b01_00_01_00>(temp_abab_block23, temp_cdcd_block23); + let block3 = + _mm512_shuffle_i32x4::<0b11_10_11_10>(temp_abab_block23, temp_cdcd_block23); + + for (j, src_block) in [block0, block1, block2, block3].into_iter().enumerate() { + let dst_index = i * BLOCKS_PER_VECTOR + j; + if dst_index < blocks.len() { + _mm512_storeu_si512((&raw mut blocks[dst_index]).cast(), src_block); } + } + } + } + + /// Generates up to 2 blocks using 256-bit vectors. + #[inline] + #[target_feature(enable = "avx512f", enable = "avx512vl")] + unsafe fn gen_blocks_halfwidth(&mut self, blocks: &mut [Block]) { + assert!(blocks.len() <= 2); + + let mut ctr = _mm256_broadcast_i32x4(self.ctr); + + match size_of::() { + 4 => { + ctr = _mm256_add_epi32(ctr, _mm256_set_epi32(0, 0, 0, 1, 0, 0, 0, 0)); + } + 8 => { + ctr = _mm256_add_epi64(ctr, _mm256_set_epi64x(0, 1, 0, 0)); + } + _ => unreachable!(), + } + + self.increment_ctr(blocks.len()); + + let block_vectors = rounds_halfwide::([ + _mm256_broadcast_i32x4(self.state[0]), + _mm256_broadcast_i32x4(self.state[1]), + _mm256_broadcast_i32x4(self.state[2]), + ctr, + ]); + + // Similar transpose operation as + // in gen_blocks_fullwidth. + + let block0_ab = _mm256_permutex2var_epi64( + block_vectors[0], + _mm256_setr_epi64x(0, 1, 4, 5), + block_vectors[1], + ); + let block0_cd = _mm256_permutex2var_epi64( + block_vectors[2], + _mm256_setr_epi64x(0, 1, 4, 5), + block_vectors[3], + ); + let block1_ab = _mm256_permutex2var_epi64( + block_vectors[0], + _mm256_setr_epi64x(2, 3, 6, 7), + block_vectors[1], + ); + let block1_cd = _mm256_permutex2var_epi64( + block_vectors[2], + _mm256_setr_epi64x(2, 3, 6, 7), + block_vectors[3], + ); - block_ptr = block_ptr.add(16); + for (i, (block_part_ab, block_part_cd)) in [(block0_ab, block0_cd), (block1_ab, block1_cd)] + .into_iter() + .enumerate() + { + if i < blocks.len() { + let dst = (&raw mut blocks[i]).cast::(); + _mm256_storeu_epi32(dst, block_part_ab); + _mm256_storeu_epi32( + dst.add(size_of::() / 2 / size_of::()), + block_part_cd, + ); } } } } +#[cfg(feature = "cipher")] +impl BlockSizeUser for Backend { + type BlockSize = U64; +} + +#[cfg(feature = "cipher")] +impl ParBlocksSizeUser for Backend { + type ParBlocksSize = U16; +} + #[cfg(feature = "cipher")] impl StreamCipherBackend for Backend { + #[inline] + fn gen_par_ks_blocks(&mut self, blocks: &mut ParBlocks) { + unsafe { self.gen_blocks_fullwidth::(blocks) } + } + #[inline(always)] fn gen_ks_block(&mut self, block: &mut Block) { + // Fallback for generating a single block using quarter-width vectors + // (128). + unsafe { - let res = rounds::<1, R>(&self.v, self.ctr[..1].try_into().unwrap()); - for c in self.ctr.iter_mut() { - *c = match size_of::() { - 4 => _mm512_add_epi32( - *c, - _mm512_set_epi32(0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1), - ), - 8 => _mm512_add_epi64(*c, _mm512_set_epi64(0, 1, 0, 1, 0, 1, 0, 1)), - _ => unreachable!(), - }; - } + let state = [self.state[0], self.state[1], self.state[2], self.ctr]; + + self.increment_ctr(1); - let block_ptr = block.as_mut_ptr() as *mut __m128i; + let result = rounds_quarterwide::(state); - for i in 0..4 { - _mm_storeu_si128(block_ptr.add(i), _mm512_extracti32x4_epi32::<0>(res[0][i])); + for row in 0..4 { + let dst = block.as_mut_ptr().cast::().add(row * 4); + _mm_storeu_epi32(dst, result[row]); } } } - #[inline(always)] - fn gen_par_ks_blocks(&mut self, blocks: &mut ParBlocks) { - self.gen_par_ks_blocks_inner::( - blocks.as_mut_slice().try_into().unwrap(), - ); - } - - #[inline(always)] - fn gen_tail_blocks(&mut self, mut blocks: &mut [cipher::Block]) { - while blocks.len() >= 8 { - self.gen_par_ks_blocks_inner::<8, { 8 / BLOCKS_PER_VECTOR }>( - (&mut blocks[..8]).try_into().unwrap(), - ); - blocks = &mut blocks[8..]; - } + #[inline] + fn gen_tail_blocks(&mut self, blocks: &mut [cipher::Block]) { + assert!(blocks.len() < MAX_PAR_BLOCKS); - while blocks.len() >= 4 { - self.gen_par_ks_blocks_inner::<4, { 4 / BLOCKS_PER_VECTOR }>( - (&mut blocks[..4]).try_into().unwrap(), - ); - blocks = &mut blocks[4..]; + if blocks.is_empty() { + return; } - for block in blocks { - self.gen_ks_block(block); + // Fallback for generating a number of blocks less than + // MAX_PAR_BLOCKS. + unsafe { + if blocks.len() == 1 { + self.gen_ks_block(&mut blocks[0]); + } else if blocks.len() == 2 { + self.gen_blocks_halfwidth(blocks); + } else if blocks.len() <= 4 { + self.gen_blocks_fullwidth::<1>(blocks); + } else if blocks.len() <= 8 { + self.gen_blocks_fullwidth::<2>(blocks); + } else { + self.gen_blocks_fullwidth::(blocks); + } } } } @@ -381,3 +522,150 @@ unsafe fn add_xor_rot(vs: &mut [[__m512i; 4]; N]) { *b = _mm512_rol_epi32::<7>(*b); } } + +// Below is another implementation of the round application +// that uses 256-bit vectors instead of 512-bit (but, unlike +// the avx2 module, can use new AVX-512 instructions like rotates). +// It is used for tail processing of shorter outputs, +// since 256-bit instructions can be faster and lower latency +// than 512-bit instructions on certain microarchitectures (e.g. Zen 4). + +#[inline] +#[target_feature(enable = "avx512f", enable = "avx512vl")] +unsafe fn rounds_halfwide(v_in: [__m256i; 4]) -> [__m256i; 4] { + let mut v = v_in; + + for _ in 0..R::COUNT { + double_quarter_round_halfwide(&mut v); + } + + for (a, b) in v.iter_mut().zip(v_in) { + *a = _mm256_add_epi32(*a, b); + } + + v +} + +#[inline] +#[target_feature(enable = "avx512f", enable = "avx512vl")] +unsafe fn double_quarter_round_halfwide(v: &mut [__m256i; 4]) { + add_xor_rot_halfwide(v); + rows_to_cols_halfwide(v); + add_xor_rot_halfwide(v); + cols_to_rows_halfwide(v); +} + +#[inline] +#[target_feature(enable = "avx512f", enable = "avx512vl")] +unsafe fn rows_to_cols_halfwide(v: &mut [__m256i; 4]) { + // c >>>= 32; d >>>= 64; a >>>= 96; + let [a, _, c, d] = v; + *c = _mm256_shuffle_epi32::<0b_00_11_10_01>(*c); // _MM_SHUFFLE(0, 3, 2, 1) + *d = _mm256_shuffle_epi32::<0b_01_00_11_10>(*d); // _MM_SHUFFLE(1, 0, 3, 2) + *a = _mm256_shuffle_epi32::<0b_10_01_00_11>(*a); // _MM_SHUFFLE(2, 1, 0, 3) +} + +#[inline] +#[target_feature(enable = "avx512f", enable = "avx512vl")] +unsafe fn cols_to_rows_halfwide(v: &mut [__m256i; 4]) { + // c <<<= 32; d <<<= 64; a <<<= 96; + let [a, _, c, d] = v; + *c = _mm256_shuffle_epi32::<0b_10_01_00_11>(*c); // _MM_SHUFFLE(2, 1, 0, 3) + *d = _mm256_shuffle_epi32::<0b_01_00_11_10>(*d); // _MM_SHUFFLE(1, 0, 3, 2) + *a = _mm256_shuffle_epi32::<0b_00_11_10_01>(*a); // _MM_SHUFFLE(0, 3, 2, 1) +} + +#[inline] +#[target_feature(enable = "avx512f", enable = "avx512vl")] +unsafe fn add_xor_rot_halfwide(v: &mut [__m256i; 4]) { + let [a, b, c, d] = v; + + // a += b; d ^= a; d <<<= (16, 16, 16, 16); + *a = _mm256_add_epi32(*a, *b); + *d = _mm256_xor_si256(*d, *a); + *d = _mm256_rol_epi32::<16>(*d); + + // c += d; b ^= c; b <<<= (12, 12, 12, 12); + *c = _mm256_add_epi32(*c, *d); + *b = _mm256_xor_si256(*b, *c); + *b = _mm256_rol_epi32::<12>(*b); + + // a += b; d ^= a; d <<<= (8, 8, 8, 8); + *a = _mm256_add_epi32(*a, *b); + *d = _mm256_xor_si256(*d, *a); + *d = _mm256_rol_epi32::<8>(*d); + + // c += d; b ^= c; b <<<= (7, 7, 7, 7); + *c = _mm256_add_epi32(*c, *d); + *b = _mm256_xor_si256(*b, *c); + *b = _mm256_rol_epi32::<7>(*b); +} + +// Finally, below is an implementation using 128-bit vectors +// for the case of generating a single block. + +#[inline(always)] +unsafe fn rounds_quarterwide(v_in: [__m128i; 4]) -> [__m128i; 4] { + let mut v = v_in; + + for _ in 0..R::COUNT { + double_quarter_round_quarterwide(&mut v); + } + + for (a, b) in v.iter_mut().zip(v_in) { + *a = _mm_add_epi32(*a, b); + } + + v +} + +#[inline(always)] +unsafe fn double_quarter_round_quarterwide(v: &mut [__m128i; 4]) { + add_xor_rot_quarterwide(v); + rows_to_cols_quarterwide(v); + add_xor_rot_quarterwide(v); + cols_to_rows_quarterwide(v); +} + +#[inline(always)] +unsafe fn rows_to_cols_quarterwide(v: &mut [__m128i; 4]) { + // c >>>= 32; d >>>= 64; a >>>= 96; + let [a, _, c, d] = v; + *c = _mm_shuffle_epi32::<0b_00_11_10_01>(*c); // _MM_SHUFFLE(0, 3, 2, 1) + *d = _mm_shuffle_epi32::<0b_01_00_11_10>(*d); // _MM_SHUFFLE(1, 0, 3, 2) + *a = _mm_shuffle_epi32::<0b_10_01_00_11>(*a); // _MM_SHUFFLE(2, 1, 0, 3) +} + +#[inline(always)] +unsafe fn cols_to_rows_quarterwide(v: &mut [__m128i; 4]) { + // c <<<= 32; d <<<= 64; a <<<= 96; + let [a, _, c, d] = v; + *c = _mm_shuffle_epi32::<0b_10_01_00_11>(*c); // _MM_SHUFFLE(2, 1, 0, 3) + *d = _mm_shuffle_epi32::<0b_01_00_11_10>(*d); // _MM_SHUFFLE(1, 0, 3, 2) + *a = _mm_shuffle_epi32::<0b_00_11_10_01>(*a); // _MM_SHUFFLE(0, 3, 2, 1) +} + +#[inline(always)] +unsafe fn add_xor_rot_quarterwide(v: &mut [__m128i; 4]) { + let [a, b, c, d] = v; + + // a += b; d ^= a; d <<<= (16, 16, 16, 16); + *a = _mm_add_epi32(*a, *b); + *d = _mm_xor_si128(*d, *a); + *d = _mm_rol_epi32::<16>(*d); + + // c += d; b ^= c; b <<<= (12, 12, 12, 12); + *c = _mm_add_epi32(*c, *d); + *b = _mm_xor_si128(*b, *c); + *b = _mm_rol_epi32::<12>(*b); + + // a += b; d ^= a; d <<<= (8, 8, 8, 8); + *a = _mm_add_epi32(*a, *b); + *d = _mm_xor_si128(*d, *a); + *d = _mm_rol_epi32::<8>(*d); + + // c += d; b ^= c; b <<<= (7, 7, 7, 7); + *c = _mm_add_epi32(*c, *d); + *b = _mm_xor_si128(*b, *c); + *b = _mm_rol_epi32::<7>(*b); +} diff --git a/chacha20/src/lib.rs b/chacha20/src/lib.rs index 1908104d..102dc193 100644 --- a/chacha20/src/lib.rs +++ b/chacha20/src/lib.rs @@ -186,7 +186,7 @@ cfg_if! { } else if #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] { cfg_if! { if #[cfg(chacha20_force_avx512)] { - #[cfg(not(target_feature = "avx512f"))] + #[cfg(not(all(target_feature = "avx512f", target_feature = "avx512vl")))] compile_error!("You must enable `avx512f` target feature with \ `chacha20_force_avx512` configuration option"); type Tokens = (); @@ -201,7 +201,7 @@ cfg_if! { `chacha20_force_sse2` configuration option"); type Tokens = (); } else { - cpufeatures::new!(avx512_cpuid, "avx512f"); + cpufeatures::new!(avx512_cpuid, "avx512f", "avx512vl"); cpufeatures::new!(avx2_cpuid, "avx2"); cpufeatures::new!(sse2_cpuid, "sse2"); type Tokens = (avx512_cpuid::InitToken, avx2_cpuid::InitToken, sse2_cpuid::InitToken); From 0442f2f4b82307c3c86c141f3d2ca9e84c80eadf Mon Sep 17 00:00:00 2001 From: caelunshun Date: Sun, 2 Nov 2025 19:33:26 -0700 Subject: [PATCH 6/9] Add long input test for chacha20 to test AVX-512 full parallel implementation --- chacha20/Cargo.toml | 1 + .../tests/data/chacha20_long_ciphertext.txt | 1 + .../tests/data/chacha20_long_plaintext.txt | 1 + chacha20/tests/kats.rs | 23 +++++++++++++++++++ 4 files changed, 26 insertions(+) create mode 100644 chacha20/tests/data/chacha20_long_ciphertext.txt create mode 100644 chacha20/tests/data/chacha20_long_plaintext.txt diff --git a/chacha20/Cargo.toml b/chacha20/Cargo.toml index 38223e2a..7200276a 100644 --- a/chacha20/Cargo.toml +++ b/chacha20/Cargo.toml @@ -34,6 +34,7 @@ cipher = { version = "0.5.0-rc.1", features = ["dev"] } hex-literal = "1" proptest = "1" rand_chacha = "0.9" +hex = "0.4" [features] default = ["cipher"] diff --git a/chacha20/tests/data/chacha20_long_ciphertext.txt b/chacha20/tests/data/chacha20_long_ciphertext.txt new file mode 100644 index 00000000..e580bfd2 --- /dev/null +++ b/chacha20/tests/data/chacha20_long_ciphertext.txt @@ -0,0 +1 @@ +d15b418e477b13f6c6ca78bcefb019acd32419c9de046236edfc32a8d4a9c8e474d73f20d92f90b9805c6c71c56e29878d8c369841cfe817cf4af403456f1452c76d748c9c5d08cd2e8cf4c5c7ae2f0ef722817112a0d4b6ba0e867b5fe122013bb3c861b776b7106a19a75fcd60eead7f2315b47905fa376383f674bf7bc7c8a55b4ab6ad8f6315bf9d211e74c16bb7606af8f87ae5e3df9c5b8c52d1044b5bc559e1150d034b44e0886170e355e4a29bb0b21f7126a47b8d0f83eef6e6bb483998c4bb01ae1da66e750e4d1d3cd42c37b3649b8f35efdc650a2beec08883d6f5f84ad30d1e58601ebb228a80d6c95c82abf66a78415a2e48a105b08e2dd1ca52c6ece53a8b131950d04197ec39eb5fd0a17731dc2f223ee6d23c289966d3446ff1c4950ceb30f85c2adf8e5e1f876ffde011e15b8673409d4a349e20d2101b26b82efd950511cd92110bdfd19d0e97d6dcf3264f13fd56d1efc322193cff04f5fed90be5b734a4788450abe949e02b0ea7f5765587808ba41ad22a3d22a22626114cdad4aba17fc7cd617a7be04e2ff5f49f50b28ffdfd273d66c79b20f7f08cb8271ebfa8df83cf153f57cd74bb7dae8f8b39c107e29cc75c7d155acf4df04d4419b1a22439dc438a1cfcaa1047a05488f11d4d65eba77eb796a425bc3832798abf7867e1f8924e40875386ab4b38bccf01c3b0269bb4dd439c9d9ee70e024c7f6ccafc9aefbb1a6bfe761b86e42c967ff3a435d7acdbad4adbaef92acb8044abf1fcbd77baef27b84d96f0f63193467a961e2ca12d63c06d713386a354705bb837470b385a2fb2b57a12d6ec81fd432320c674f94f6dc97a0576ebe05c8e47c493e77bcfba7425db1cd7756be05e342c4536eee23e719a6ab4e1af9be1331f328af13988b7c99764e997840844cc89386d196a8ec60bc7c1a929a58963e6f7c7b646c444c16cb8c1a793e97c659639c2f11efa0e5543f5ec8bdbef6e43c03de0468310e36c0eceac4bd96f80cbe4f13c3d204b3b661b9d11c84616fa6c613a0d5c2b8b31827dad90a0307df6d566f4321430f76649c3d28c68c4ab60c3134e07453df1821c5fc90b265cc8d5c2d8c8956dba77a0d9b528f097736c82b172aa195ca4b20ac85e7eec64851631003c6c3ce41110b27385b7ded00178a508e5ed52fcb64938b543a38ea68d22b924e171548ae558e7b3a2b436652e7dda4f7f96e0732ba81dcf91ec3b064cc04ddb55c31847df992aee116fb91581223b6f8588151d9ade8ed5b4796ea8899028e06676695a307924b2caee25790c03a4322f0035f59e4d02bb26d40c75658166d1a39fcbfade6ae37615efe0bd3cce825caf5fe74a0b1f0ddf4ee23fdd5acefd3a88b38439f8b831f6393d8eee93d4d23efa496c93c2088405786e3a5b8b0de02654df35c6910817b065cb2bbebeab4d6b31250ccb885575aa4880250aa805ec1a80b4093499128b406e3e624b2481347c4292d9db13afbf0a75bad9df5a2111f2f2326d2d7543128706310c9f7517a25aa03aca6503f533446732885ff884ccfc66f5c771e3330a5c22c0b43b191b88262d1bc14833011bbf12a44345e7095ceab0fa824ed55b0a5d2cae20f5047e46e882b3e69a9f65775e6b455545b55c24669a0516d23e6fd96d60698ed5bac8c02df05086cd9f4b3398b087a10f1b5eecbf13ba38a626a5e2cadae279875e7efcfc732b353dc803538fa65cf55eafad56647be291ab5fbe6424ea416b419d77d864a72efdb9930e0cea4ec626f281844374595e08766466863e40be5fbd1d9937fb5d61e9d7387ef98540d564a9359eb21f43d1ac36cfbc7c06c702973b8c6d32bda7e2167185e44d8a7abab372073a63ab898ad5ef5068bdb90685fb83bddb7038c5e8273bf8a957c931d9e30528250017f68b02a1de4098d6acd1430826806eb78fbe8d87b7e07a3fb2aec5a49be204b1a74cb034196224bc3505c005a38954a69fc6328e610677ddd2d420a333e7623a5bd52ef38d8acdafe5f8b34f6fa6a088c47492d775f3930ae97ca3c1ec74807d40e8793fab12136fae771923cd285c60948f1e37ba7fd8c74afd96c4f4240eeb84595384f0fb93254b31ef9a1ad6ea46ecfac12bf2e37dd9b1d4998cdc85ffdcbc67da3d7aa7f207cfb0623e6a9b438d25ce \ No newline at end of file diff --git a/chacha20/tests/data/chacha20_long_plaintext.txt b/chacha20/tests/data/chacha20_long_plaintext.txt new file mode 100644 index 00000000..33089d38 --- /dev/null +++ b/chacha20/tests/data/chacha20_long_plaintext.txt @@ -0,0 +1 @@ +45d146d7d3f7fdb60414ff3c037ef4c086ef27096128e027f73f2038fcb9c273fb498b8ec49c7b61e3b91e7b220de438b1d5eb4fda9e31a4ac649e58cd84ea083289923909ac5cba5ff7b0375741bd6b79183c37dd8366622c21cde394e2623d91d887fd53cf865d0913390578c44829703a22e8e2797c8561e89155f7e643058d8b400e1277419cf8403f453a3c98ea9592fa0827597af9f64742a4742ed72eec3b6d5c9a8f592c485904f0af326c8cdff34d1d4b9456e66eaab2ccf6c20f7a908dec956d686dca4727ee374d68a80e23863cdf9ca2a8359c80596b314d8ad4461957a00aa13c2b05500f5a69cb549ded9939177c53796c94acab5caf369697ec8099295f45cb5cb04fed565b7843c3c47e27720ecea281762237d447d690e5c933894e9ca3a9ce1d4e05bb5e7030f0715db6a96c2ac81e336f1ea918a650ecdb1b101a080e40a52733df4d025009f796977673501fe3587409821157487dc4e2ba1e0be6ce6290bb460b35b5ed7417b281c369ef25482172d4440806a22473b137aecfd6fb80160b66523f212cc09afd327089c6d8725a53ebc9e8272492b281988da2695db40f5e53062ab59016f7f00d608a7100be1c5b8583738bd0f373465a253837a6a6c2d1353beb609a4db6f537462e53eedd6634e35ad853ed27240ba4445eed14fe51d7529c67e6892d1470a229565785d39c2a51c5bf0421a9f311a787f50b06d152640a967b1882aec7a33a1102783a8d517d669e0cd24c4dc7146cf9cc367b838f2ed302a9840a0d73d2603feea40006d03604a4998b098dfd0db44fcfcddbe061f81077c4d2a80bb69abde8bc7a643630b6f285acd21d00ff3284de985db3b9c07c67732a10eff0fc3b2c5f5ec8efa6acafb0d8bf876040236c49723ffdb3f59232f60cfd4444dd56c0152c178f8d9034132f51b4eb1728f7194a2b8aa9f9115e8b344175d59bd0c578dc77b4d3073e9686730a02ef9ccd0299911bbace3eebe3a3122064ef9473df50d1eeaba4b52240c072be25915c729f88a2dde680a41d6ce79a3db8570a4c1ee69b65ff64b84fae1141eea02f3b66bbc86f59b85fda66e813027c2e498523023817067b69782ac3c2b33c1a79ca15f091a4354ce022333b8862854905ea892d3d2c26639005c6d855d447b3456f331e17ba6c100f488d0acc0400f7c8342895aedbe070358f50ab5eb7eaf35f9f00b56391329394ea82ca2f72b04c443abf7ac8705a7b1fc9927530dfa3aeec8451fb6b9c7a488f6ef6bbf299b387f89903631f563b51c9936fb7c14c5ee082a0eda14d7a7b7f9c20d4ca646f63ffb5cf6e6bd2b42c3e4202650a08e9d88cdf84ed98c18356681f6659865c4d663e306c25d2ad090188058b0630a26e2d15548f48c352831e880430fe620171378255a23c13c4434276a8ebd1645a5bd67c39c2e69539816b475662d1767fd86d91a561e6970f16ba865d5975fe84769093563de82dad0faf08b1199fdba4d8cb7624ad75e261195dde028d999eac02bf3923aba780cff354893bf527a0e4acfaac9aeff881f52daec509595de66841600cdf0d99ea44de259d1ababa224026ef9f2e7064581b6f305e966c5b29204689855808a53dd6652297c32c0870db7b065838869d3a164db78ab43b65fc616ff5adb94ece0cf5cc5e2bed92a494e650a611d46f8341dc28c71816f26ba18eeb8a4174bb00b12b86af4e5f8224eb19d61f70b34d404fa631cadfbff5be2fc9e4ed77a2f7a914fce22777ab94f759d792091c236a6a18fbbe728380f6038aa2d18cd8a837772a156506ee6e9b836b7976fe544299054810dd08b6c9c9fe9e41115984e770f90dba9dcc2949b6f76787aad02b9119baaebfd183d732bd10ca6eb2fafcafa954911bbdc570bae72409b88fd308099f532623f22dedf51c64e83fd85952f9764b349d18b503a3ab530b84e242064e56297a932eaa0bbf82b4c1d8b825c75326436c4073091dd5e224b207309e8af373269330b6d1e275318c8ee8c7754c6fcafabb4790be220937b98bffcbfce714e781cfb8a5d6c45065128f8d8a66a0cd769b50bc6187ea48f0910a98b58a6f119cc1f8abddc2272af320f88600ac76201ecb0ad91ccbda22aacf4d1679185f8f954644bab2f78da7873beb1d769509c639e43d8e \ No newline at end of file diff --git a/chacha20/tests/kats.rs b/chacha20/tests/kats.rs index 4e4aa33c..c0d02374 100644 --- a/chacha20/tests/kats.rs +++ b/chacha20/tests/kats.rs @@ -96,6 +96,29 @@ mod chacha20test { } } +// Long input test to check the full parallel AVX-512 implementation. +// Test data generated from random byte strings. +mod chacha20test_long { + use chacha20::{ChaCha20, KeyIvInit}; + use cipher::StreamCipher; + use hex_literal::hex; + + const KEY: [u8; 32] = hex!("d387cb6ea45656c892a19d3706d1835d8e3cb11865431fa7133a09d1a1fc78da"); + + const IV: [u8; 12] = hex!("689f6a394fe2048a2400e005"); + + #[test] + fn chacha20_encryption() { + let mut cipher = ChaCha20::new(&KEY.into(), &IV.into()); + let mut buf = hex::decode(include_str!("data/chacha20_long_plaintext.txt")).unwrap(); + + cipher.apply_keystream(&mut buf); + + let ciphertext = hex::decode(include_str!("data/chacha20_long_ciphertext.txt")).unwrap(); + assert_eq!(&buf[..], ciphertext); + } +} + #[rustfmt::skip] #[cfg(feature = "xchacha")] mod xchacha20 { From 9f4b9c9a3fc02e0a5a3d347be9a5a7e0348bfce3 Mon Sep 17 00:00:00 2001 From: caelunshun Date: Sun, 2 Nov 2025 19:35:32 -0700 Subject: [PATCH 7/9] Remove AVX-512 RNG backend, since RNG doesn't expose enough parallelism to make it worth the complexity --- chacha20/src/backends/avx512.rs | 46 --------------------------------- chacha20/src/rng.rs | 14 +++------- chacha20/tests/kats.rs | 1 + 3 files changed, 4 insertions(+), 57 deletions(-) diff --git a/chacha20/src/backends/avx512.rs b/chacha20/src/backends/avx512.rs index f46baa82..d8da59c8 100644 --- a/chacha20/src/backends/avx512.rs +++ b/chacha20/src/backends/avx512.rs @@ -2,9 +2,6 @@ use crate::{Rounds, Variant}; use core::marker::PhantomData; -#[cfg(feature = "rng")] -use crate::ChaChaCore; - #[cfg(feature = "cipher")] use crate::{STATE_WORDS, chacha::Block}; @@ -59,49 +56,6 @@ where } } -#[inline] -#[target_feature(enable = "avx512f")] -#[cfg(feature = "rng")] -pub(crate) unsafe fn rng_inner(core: &mut ChaChaCore, buffer: &mut [u32; 64]) -where - R: Rounds, - V: Variant, -{ - use core::slice; - - use crate::rng::BLOCK_WORDS; - - let state_ptr = core.state.as_ptr() as *const __m128i; - let v = [ - _mm512_broadcast_i32x4(_mm_loadu_si128(state_ptr.add(0))), - _mm512_broadcast_i32x4(_mm_loadu_si128(state_ptr.add(1))), - _mm512_broadcast_i32x4(_mm_loadu_si128(state_ptr.add(2))), - ]; - let mut c = _mm512_broadcast_i32x4(_mm_loadu_si128(state_ptr.add(3))); - c = _mm512_add_epi64(c, _mm512_set_epi64(0, 3, 0, 2, 0, 1, 0, 0)); - let mut ctr = [c; MAX_N]; - for i in 0..MAX_N { - ctr[i] = c; - c = _mm512_add_epi64(c, _mm512_set_epi64(0, 4, 0, 4, 0, 4, 0, 4)); - } - let mut backend = Backend:: { - v, - ctr, - _pd: PhantomData, - }; - - let buffer = slice::from_raw_parts_mut( - buffer.as_mut_ptr().cast::(), - buffer.len() / BLOCK_WORDS as usize, - ); - backend.gen_par_ks_blocks_inner::<4, { 4 / BLOCKS_PER_VECTOR }>(buffer.try_into().unwrap()); - - core.state[12] = - _mm256_extract_epi32::<0>(_mm512_extracti32x8_epi32::<0>(backend.ctr[0])) as u32; - core.state[13] = - _mm256_extract_epi32::<1>(_mm512_extracti32x8_epi32::<0>(backend.ctr[0])) as u32; -} - struct Backend { state: [__m128i; 3], ctr: __m128i, diff --git a/chacha20/src/rng.rs b/chacha20/src/rng.rs index 79954c49..4d2493d7 100644 --- a/chacha20/src/rng.rs +++ b/chacha20/src/rng.rs @@ -189,11 +189,7 @@ impl ChaChaCore { backends::soft::Backend(self).gen_ks_blocks(buffer); } else if #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] { cfg_if! { - if #[cfg(chacha20_force_avx512)] { - unsafe { - backends::avx512::rng_inner::(self, buffer); - } - } else if #[cfg(chacha20_force_avx2)] { + if #[cfg(chacha20_force_avx2)] { unsafe { backends::avx2::rng_inner::(self, buffer); } @@ -202,12 +198,8 @@ impl ChaChaCore { backends::sse2::rng_inner::(self, buffer); } } else { - let (avx512_token, avx2_token, sse2_token) = self.tokens; - if avx512_token.get() { - unsafe { - backends::avx512::rng_inner::(self, buffer); - } - } else if avx2_token.get() { + let (_avx512_token, avx2_token, sse2_token) = self.tokens; + if avx2_token.get() { unsafe { backends::avx2::rng_inner::(self, buffer); } diff --git a/chacha20/tests/kats.rs b/chacha20/tests/kats.rs index c0d02374..7987fe70 100644 --- a/chacha20/tests/kats.rs +++ b/chacha20/tests/kats.rs @@ -98,6 +98,7 @@ mod chacha20test { // Long input test to check the full parallel AVX-512 implementation. // Test data generated from random byte strings. +#[cfg(feature = "cipher")] mod chacha20test_long { use chacha20::{ChaCha20, KeyIvInit}; use cipher::StreamCipher; From 045a79e1426d79979d33d1fef3271a9bf2115b6e Mon Sep 17 00:00:00 2001 From: caelunshun Date: Mon, 3 Nov 2025 08:32:00 -0700 Subject: [PATCH 8/9] Gate AVX-512 behind chacha20_avx512 cfg --- chacha20/Cargo.toml | 3 ++- chacha20/README.md | 1 + chacha20/src/backends.rs | 3 ++- chacha20/src/lib.rs | 23 ++++++++++++++++++----- 4 files changed, 23 insertions(+), 7 deletions(-) diff --git a/chacha20/Cargo.toml b/chacha20/Cargo.toml index 7200276a..73255596 100644 --- a/chacha20/Cargo.toml +++ b/chacha20/Cargo.toml @@ -3,7 +3,7 @@ name = "chacha20" version = "0.10.0-rc.2" authors = ["RustCrypto Developers"] edition = "2024" -rust-version = "1.89" +rust-version = "1.85" documentation = "https://docs.rs/chacha20" readme = "README.md" repository = "https://github.com/RustCrypto/stream-ciphers" @@ -49,6 +49,7 @@ rustdoc-args = ["--cfg", "docsrs"] [lints.rust.unexpected_cfgs] level = "warn" check-cfg = [ + 'cfg(chacha20_avx512)', 'cfg(chacha20_force_soft)', 'cfg(chacha20_force_sse2)', 'cfg(chacha20_force_avx2)', diff --git a/chacha20/README.md b/chacha20/README.md index 21a82172..72ce44c6 100644 --- a/chacha20/README.md +++ b/chacha20/README.md @@ -32,6 +32,7 @@ work on stable Rust with the following `RUSTFLAGS`: - `x86` / `x86_64` - `avx2`: (~1.4cpb) `-Ctarget-cpu=haswell -Ctarget-feature=+avx2` - `sse2`: (~1.6cpb) `-Ctarget-feature=+sse2` (on by default on x86 CPUs) + - `avx512`: `-Ctarget-feature=+avx512f,+avx512vl --cfg chacha20_avx512` requires Rust 1.89+ - `aarch64` - `neon` (~2-3x faster than `soft`) requires Rust 1.61+ and the `neon` feature enabled - Portable diff --git a/chacha20/src/backends.rs b/chacha20/src/backends.rs index 23c8e539..c82c7a17 100644 --- a/chacha20/src/backends.rs +++ b/chacha20/src/backends.rs @@ -7,7 +7,7 @@ cfg_if! { pub(crate) mod soft; } else if #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] { cfg_if! { - if #[cfg(chacha20_force_avx512)] { + if #[cfg(all(chacha20_avx512, chacha20_force_avx512))] { pub(crate) mod avx512; } else if #[cfg(chacha20_force_avx2)] { pub(crate) mod avx2; @@ -15,6 +15,7 @@ cfg_if! { pub(crate) mod sse2; } else { pub(crate) mod soft; + #[cfg(chacha20_avx512)] pub(crate) mod avx512; pub(crate) mod avx2; pub(crate) mod sse2; diff --git a/chacha20/src/lib.rs b/chacha20/src/lib.rs index 102dc193..ca226440 100644 --- a/chacha20/src/lib.rs +++ b/chacha20/src/lib.rs @@ -185,9 +185,9 @@ cfg_if! { type Tokens = (); } else if #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] { cfg_if! { - if #[cfg(chacha20_force_avx512)] { + if #[cfg(all(chacha20_avx512, chacha20_force_avx512))] { #[cfg(not(all(target_feature = "avx512f", target_feature = "avx512vl")))] - compile_error!("You must enable `avx512f` target feature with \ + compile_error!("You must enable `avx512f` and `avx512vl` target features with \ `chacha20_force_avx512` configuration option"); type Tokens = (); } else if #[cfg(chacha20_force_avx2)] { @@ -201,10 +201,14 @@ cfg_if! { `chacha20_force_sse2` configuration option"); type Tokens = (); } else { + #[cfg(chacha20_avx512)] cpufeatures::new!(avx512_cpuid, "avx512f", "avx512vl"); cpufeatures::new!(avx2_cpuid, "avx2"); cpufeatures::new!(sse2_cpuid, "sse2"); + #[cfg(chacha20_avx512)] type Tokens = (avx512_cpuid::InitToken, avx2_cpuid::InitToken, sse2_cpuid::InitToken); + #[cfg(not(chacha20_avx512))] + type Tokens = (avx2_cpuid::InitToken, sse2_cpuid::InitToken); } } } else { @@ -259,8 +263,10 @@ impl ChaChaCore { let tokens = (); } else if #[cfg(chacha20_force_sse2)] { let tokens = (); - } else { + } else if #[cfg(chacha20_avx512)] { let tokens = (avx512_cpuid::init(), avx2_cpuid::init(), sse2_cpuid::init()); + } else { + let tokens = (avx2_cpuid::init(), sse2_cpuid::init()); } } } else { @@ -306,7 +312,7 @@ impl StreamCipherCore for ChaChaCore { f.call(&mut backends::soft::Backend(self)); } else if #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] { cfg_if! { - if #[cfg(chacha20_force_avx512)] { + if #[cfg(all(chacha20_avx512, chacha20_force_avx512))] { unsafe { backends::avx512::inner::(&mut self.state, f); } @@ -319,12 +325,19 @@ impl StreamCipherCore for ChaChaCore { backends::sse2::inner::(&mut self.state, f); } } else { + #[cfg(chacha20_avx512)] let (avx512_token, avx2_token, sse2_token) = self.tokens; + #[cfg(not(chacha20_avx512))] + let (avx2_token, sse2_token) = self.tokens; + + #[cfg(chacha20_avx512)] if avx512_token.get() { unsafe { backends::avx512::inner::(&mut self.state, f); } - } else if avx2_token.get() { + return; + } + if avx2_token.get() { unsafe { backends::avx2::inner::(&mut self.state, f); } From efff21d7757476afecf17f05398a8733774ac96d Mon Sep 17 00:00:00 2001 From: caelunshun Date: Mon, 3 Nov 2025 08:38:06 -0700 Subject: [PATCH 9/9] Add CI to run AVX-512 backend (from aes CI config) and fix build with RNG feature --- .github/workflows/chacha20.yml | 40 ++++++++++++++++++++++++++++++++++ Cargo.lock | 7 ++++++ chacha20/src/backends.rs | 3 +++ chacha20/src/backends/avx2.rs | 1 + chacha20/src/rng.rs | 7 +++++- 5 files changed, 57 insertions(+), 1 deletion(-) diff --git a/.github/workflows/chacha20.yml b/.github/workflows/chacha20.yml index d79583e7..ac5cf746 100644 --- a/.github/workflows/chacha20.yml +++ b/.github/workflows/chacha20.yml @@ -20,6 +20,9 @@ defaults: env: CARGO_INCREMENTAL: 0 RUSTFLAGS: "-Dwarnings" + # NOTE: The mirror number changes with each version so keep these in sync + SDE_FULL_VERSION_MIRROR: "859732" + SDE_FULL_VERSION: "9.58.0-2025-06-16" jobs: build: @@ -79,6 +82,43 @@ jobs: - run: cargo check --target ${{ matrix.target }} --all-features - run: cargo hack test --feature-powerset --target ${{ matrix.target }} + # Tests for the AVX-512 backend + avx512: + runs-on: ubuntu-latest + strategy: + matrix: + include: + - target: x86_64-unknown-linux-gnu + rust: stable + RUSTFLAGS: "-Dwarnings --cfg chacha20_avx512" + env: + CARGO_INCREMENTAL: 0 + RUSTFLAGS: ${{ matrix.RUSTFLAGS }} + steps: + - uses: actions/checkout@v4 + - name: Install Intel SDE + run: | + curl -JLO "https://downloadmirror.intel.com/${{ env.SDE_FULL_VERSION_MIRROR }}/sde-external-${{ env.SDE_FULL_VERSION }}-lin.tar.xz" + tar xvf sde-external-${{ env.SDE_FULL_VERSION }}-lin.tar.xz -C /opt + echo "/opt/sde-external-${{ env.SDE_FULL_VERSION }}-lin" >> $GITHUB_PATH + - uses: RustCrypto/actions/cargo-cache@master + - uses: dtolnay/rust-toolchain@master + with: + toolchain: ${{ matrix.rust }} + targets: ${{ matrix.target }} + # NOTE: Write a `.cargo/config.toml` to configure the target for AVX-512 + # NOTE: We use intel-sde as the runner since not all GitHub CI hosts support AVX512 + - name: write .cargo/config.toml + shell: bash + run: | + cd ../chacha20/.. + mkdir -p .cargo + echo '[target.${{ matrix.target }}]' > .cargo/config.toml + echo 'runner = "sde64 -future --"' >> .cargo/config.toml + - run: ${{ matrix.deps }} + - run: cargo test --target ${{ matrix.target }} + - run: cargo test --target ${{ matrix.target }} --all-features + # Tests for the AVX2 backend avx2: runs-on: ubuntu-latest diff --git a/Cargo.lock b/Cargo.lock index 584a71db..9ea64bb5 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -58,6 +58,7 @@ dependencies = [ "cfg-if", "cipher", "cpufeatures", + "hex", "hex-literal", "proptest", "rand_chacha", @@ -138,6 +139,12 @@ dependencies = [ "hex-literal", ] +[[package]] +name = "hex" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70" + [[package]] name = "hex-literal" version = "1.1.0" diff --git a/chacha20/src/backends.rs b/chacha20/src/backends.rs index c82c7a17..44f04c63 100644 --- a/chacha20/src/backends.rs +++ b/chacha20/src/backends.rs @@ -9,6 +9,9 @@ cfg_if! { cfg_if! { if #[cfg(all(chacha20_avx512, chacha20_force_avx512))] { pub(crate) mod avx512; + // AVX-2 backend needed for RNG if enabled + #[cfg(feature = "rng")] + pub(crate) mod avx2; } else if #[cfg(chacha20_force_avx2)] { pub(crate) mod avx2; } else if #[cfg(chacha20_force_sse2)] { diff --git a/chacha20/src/backends/avx2.rs b/chacha20/src/backends/avx2.rs index 7de35100..d5e4418b 100644 --- a/chacha20/src/backends/avx2.rs +++ b/chacha20/src/backends/avx2.rs @@ -27,6 +27,7 @@ const N: usize = PAR_BLOCKS / 2; #[inline] #[target_feature(enable = "avx2")] #[cfg(feature = "cipher")] +#[cfg_attr(chacha20_force_avx512, expect(unused))] pub(crate) unsafe fn inner(state: &mut [u32; STATE_WORDS], f: F) where R: Rounds, diff --git a/chacha20/src/rng.rs b/chacha20/src/rng.rs index 4d2493d7..303ac733 100644 --- a/chacha20/src/rng.rs +++ b/chacha20/src/rng.rs @@ -189,7 +189,8 @@ impl ChaChaCore { backends::soft::Backend(self).gen_ks_blocks(buffer); } else if #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] { cfg_if! { - if #[cfg(chacha20_force_avx2)] { + // AVX-512 doesn't support RNG, so use AVX-2 instead + if #[cfg(any(chacha20_force_avx2, chacha20_force_avx512))] { unsafe { backends::avx2::rng_inner::(self, buffer); } @@ -198,7 +199,11 @@ impl ChaChaCore { backends::sse2::rng_inner::(self, buffer); } } else { + #[cfg(chacha20_avx512)] let (_avx512_token, avx2_token, sse2_token) = self.tokens; + #[cfg(not(chacha20_avx512))] + let (avx2_token, sse2_token) = self.tokens; + if avx2_token.get() { unsafe { backends::avx2::rng_inner::(self, buffer);