Commit 7346b83
Initial AVX-512 port of ChaCha20.
Benchmark results on Zen 4:

AVX2 (Zen 4):
test chacha20_bench1_16b   ... bench:    23.71 ns/iter (+/- 0.89)   = 695 MB/s
test chacha20_bench2_256b  ... bench:    82.98 ns/iter (+/- 7.64)   = 3121 MB/s
test chacha20_bench3_1kib  ... bench:   302.03 ns/iter (+/- 3.59)   = 3390 MB/s
test chacha20_bench4_16kib ... bench: 4,677.58 ns/iter (+/- 161.42) = 3503 MB/s

AVX-512 (Zen 4):
test chacha20_bench1_16b   ... bench:    25.07 ns/iter (+/- 0.90)   = 640 MB/s
test chacha20_bench2_256b  ... bench:    79.66 ns/iter (+/- 1.18)   = 3240 MB/s
test chacha20_bench3_1kib  ... bench:   275.32 ns/iter (+/- 4.13)   = 3723 MB/s
test chacha20_bench4_16kib ... bench: 4,201.84 ns/iter (+/- 24.18)  = 3900 MB/s

Much greater speedups are achievable for long inputs if PAR_BLOCKS is raised to 8, but that also causes a 2x slowdown for short inputs (< 512 bytes). The StreamCipherBackend API does not appear to offer a way to vary the degree of parallelism with the input size. (A sketch of this constraint follows the file summary below.)
1 parent: 0e3296d

4 files changed: 365 additions, 6 deletions
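
The parallelism trade-off described in the commit message is fixed at compile time. As a minimal sketch, assuming the layout used by the new avx512.rs module below (the doubled constants and the U8 binding are hypothetical, not part of this commit):

    // Hypothetical PAR_BLOCKS = 8 variant: each `__m512i` still holds one
    // state row for four blocks, so doubling the block count doubles the
    // number of 512-bit lane groups per row.
    const PAR_BLOCKS: usize = 8;
    const N: usize = PAR_BLOCKS / 4; // = 2 lane groups instead of 1

    // The degree of parallelism exposed to the `cipher` crate is a single
    // associated type, with no input-length-dependent fallback:
    //     impl ParBlocksSizeUser for Backend<R, V> { type ParBlocksSize = U8; }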

chacha20/Cargo.toml

Lines changed: 5 additions & 2 deletions
@@ -3,7 +3,7 @@ name = "chacha20"
 version = "0.10.0-rc.2"
 authors = ["RustCrypto Developers"]
 edition = "2024"
-rust-version = "1.85"
+rust-version = "1.89"
 documentation = "https://docs.rs/chacha20"
 readme = "README.md"
 repository = "https://github.com/RustCrypto/stream-ciphers"
@@ -20,7 +20,9 @@ rand_core-compatible RNGs based on those ciphers.
 
 [dependencies]
 cfg-if = "1"
-cipher = { version = "0.5.0-rc.1", optional = true, features = ["stream-wrapper"] }
+cipher = { version = "0.5.0-rc.1", optional = true, features = [
+    "stream-wrapper",
+] }
 rand_core = { version = "0.10.0-rc.1", optional = true, default-features = false }
 
 # `zeroize` is an explicit dependency because this crate may be used without the `cipher` crate
@@ -51,6 +53,7 @@ check-cfg = [
     'cfg(chacha20_force_soft)',
     'cfg(chacha20_force_sse2)',
     'cfg(chacha20_force_avx2)',
+    'cfg(chacha20_force_avx512)',
 ]
 
 [lints.clippy]
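
Following the existing chacha20_force_* convention, the new chacha20_force_avx512 cfg is presumably meant to be passed via RUSTFLAGS (e.g. --cfg chacha20_force_avx512) to pin backend selection during testing; this hunk only registers it with the check-cfg lint so such builds compile warning-free.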

chacha20/src/backends.rs

Lines changed: 1 addition & 0 deletions
@@ -13,6 +13,7 @@ cfg_if! {
         pub(crate) mod sse2;
     } else {
         pub(crate) mod soft;
+        pub(crate) mod avx512;
         pub(crate) mod avx2;
         pub(crate) mod sse2;
     }

chacha20/src/backends/avx512.rs

Lines changed: 350 additions & 0 deletions
@@ -0,0 +1,350 @@
#![allow(unsafe_op_in_unsafe_fn)]
use crate::{Rounds, Variant};
use core::marker::PhantomData;

#[cfg(feature = "rng")]
use crate::ChaChaCore;

#[cfg(feature = "cipher")]
use crate::{STATE_WORDS, chacha::Block};

#[cfg(feature = "cipher")]
use cipher::{
    BlockSizeUser, ParBlocks, ParBlocksSizeUser, StreamCipherBackend, StreamCipherClosure,
    consts::{U4, U64},
};

#[cfg(target_arch = "x86")]
use core::arch::x86::*;
#[cfg(target_arch = "x86_64")]
use core::arch::x86_64::*;

/// Number of blocks processed in parallel.
const PAR_BLOCKS: usize = 4;
/// Number of `__m512i` to store parallel blocks.
const N: usize = PAR_BLOCKS / 4;
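
// Layout: each `__m512i` packs one 16-byte state row for four blocks, one
// block per 128-bit lane. Rows 0-2 are identical across lanes (broadcast);
// only the counter row is given per-lane block offsets. `N` counts these
// four-block lane groups, so N = 1 for PAR_BLOCKS = 4.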

#[inline]
#[target_feature(enable = "avx512f")]
#[cfg(feature = "cipher")]
pub(crate) unsafe fn inner<R, F, V>(state: &mut [u32; STATE_WORDS], f: F)
where
    R: Rounds,
    F: StreamCipherClosure<BlockSize = U64>,
    V: Variant,
{
    let state_ptr = state.as_ptr() as *const __m128i;
    let v = [
        _mm512_broadcast_i32x4(_mm_loadu_si128(state_ptr.add(0))),
        _mm512_broadcast_i32x4(_mm_loadu_si128(state_ptr.add(1))),
        _mm512_broadcast_i32x4(_mm_loadu_si128(state_ptr.add(2))),
    ];
    let mut c = _mm512_broadcast_i32x4(_mm_loadu_si128(state_ptr.add(3)));
    c = match size_of::<V::Counter>() {
        4 => _mm512_add_epi32(
            c,
            _mm512_set_epi32(0, 0, 0, 3, 0, 0, 0, 2, 0, 0, 0, 1, 0, 0, 0, 0),
        ),
        8 => _mm512_add_epi64(c, _mm512_set_epi64(0, 3, 0, 2, 0, 1, 0, 0)),
        _ => unreachable!(),
    };
    let mut ctr = [c; N];
    for i in 0..N {
        ctr[i] = c;
        c = match size_of::<V::Counter>() {
            4 => _mm512_add_epi32(
                c,
                _mm512_set_epi32(0, 0, 0, 4, 0, 0, 0, 4, 0, 0, 0, 4, 0, 0, 0, 4),
            ),
            8 => _mm512_add_epi64(c, _mm512_set_epi64(0, 4, 0, 4, 0, 4, 0, 4)),
            _ => unreachable!(),
        };
    }
    let mut backend = Backend::<R, V> {
        v,
        ctr,
        _pd: PhantomData,
    };

    f.call(&mut backend);

    state[12] = _mm256_extract_epi32::<0>(_mm512_extracti32x8_epi32::<0>(backend.ctr[0])) as u32;
    match size_of::<V::Counter>() {
        4 => {}
        8 => {
            state[13] =
                _mm256_extract_epi32::<1>(_mm512_extracti32x8_epi32::<0>(backend.ctr[0])) as u32
        }
        _ => unreachable!(),
    }
}

#[inline]
#[target_feature(enable = "avx512f")]
#[cfg(feature = "rng")]
pub(crate) unsafe fn rng_inner<R, V>(core: &mut ChaChaCore<R, V>, buffer: &mut [u32; 64])
where
    R: Rounds,
    V: Variant,
{
    let state_ptr = core.state.as_ptr() as *const __m128i;
    let v = [
        _mm512_broadcast_i32x4(_mm_loadu_si128(state_ptr.add(0))),
        _mm512_broadcast_i32x4(_mm_loadu_si128(state_ptr.add(1))),
        _mm512_broadcast_i32x4(_mm_loadu_si128(state_ptr.add(2))),
    ];
    // The RNG always uses a 64-bit counter spanning state words 12 and 13.
    let mut c = _mm512_broadcast_i32x4(_mm_loadu_si128(state_ptr.add(3)));
    c = _mm512_add_epi64(c, _mm512_set_epi64(0, 3, 0, 2, 0, 1, 0, 0));
    let mut ctr = [c; N];
    for i in 0..N {
        ctr[i] = c;
        c = _mm512_add_epi64(c, _mm512_set_epi64(0, 4, 0, 4, 0, 4, 0, 4));
    }
    let mut backend = Backend::<R, V> {
        v,
        ctr,
        _pd: PhantomData,
    };

    backend.rng_gen_par_ks_blocks(buffer);

    core.state[12] =
        _mm256_extract_epi32::<0>(_mm512_extracti32x8_epi32::<0>(backend.ctr[0])) as u32;
    core.state[13] =
        _mm256_extract_epi32::<1>(_mm512_extracti32x8_epi32::<0>(backend.ctr[0])) as u32;
}

struct Backend<R: Rounds, V: Variant> {
    v: [__m512i; 3],
    ctr: [__m512i; N],
    _pd: PhantomData<(R, V)>,
}

#[cfg(feature = "cipher")]
impl<R: Rounds, V: Variant> BlockSizeUser for Backend<R, V> {
    type BlockSize = U64;
}

#[cfg(feature = "cipher")]
impl<R: Rounds, V: Variant> ParBlocksSizeUser for Backend<R, V> {
    type ParBlocksSize = U4;
}
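
// `U4` mirrors `PAR_BLOCKS`: the `cipher` crate hands this backend exactly
// four blocks per `gen_par_ks_blocks` call.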

#[cfg(feature = "cipher")]
impl<R: Rounds, V: Variant> StreamCipherBackend for Backend<R, V> {
    #[inline(always)]
    fn gen_ks_block(&mut self, block: &mut Block) {
        unsafe {
            let res = rounds::<R>(&self.v, &self.ctr);
            for c in self.ctr.iter_mut() {
                *c = match size_of::<V::Counter>() {
                    4 => _mm512_add_epi32(
                        *c,
                        _mm512_set_epi32(0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1),
                    ),
                    8 => _mm512_add_epi64(*c, _mm512_set_epi64(0, 1, 0, 1, 0, 1, 0, 1)),
                    _ => unreachable!(),
                };
            }

            let block_ptr = block.as_mut_ptr() as *mut __m128i;

            for i in 0..4 {
                _mm_storeu_si128(block_ptr.add(i), _mm512_extracti32x4_epi32::<0>(res[0][i]));
            }
        }
    }

    #[inline(always)]
    fn gen_par_ks_blocks(&mut self, blocks: &mut ParBlocks<Self>) {
        unsafe {
            let vs = rounds::<R>(&self.v, &self.ctr);

            let pb = PAR_BLOCKS as i32;
            for c in self.ctr.iter_mut() {
                *c = match size_of::<V::Counter>() {
                    4 => _mm512_add_epi32(
                        *c,
                        _mm512_set_epi32(0, 0, 0, pb, 0, 0, 0, pb, 0, 0, 0, pb, 0, 0, 0, pb),
                    ),
                    8 => _mm512_add_epi64(
                        *c,
                        _mm512_set_epi64(0, pb as i64, 0, pb as i64, 0, pb as i64, 0, pb as i64),
                    ),
                    _ => unreachable!(),
                }
            }

            let mut block_ptr = blocks.as_mut_ptr() as *mut __m128i;
            for v in vs {
                let t: [__m128i; 16] = core::mem::transmute(v);
                for i in 0..4 {
                    _mm_storeu_si128(block_ptr.add(i), t[4 * i]);
                    _mm_storeu_si128(block_ptr.add(4 + i), t[4 * i + 1]);
                    _mm_storeu_si128(block_ptr.add(8 + i), t[4 * i + 2]);
                    _mm_storeu_si128(block_ptr.add(12 + i), t[4 * i + 3]);
                }
                block_ptr = block_ptr.add(16);
            }
        }
    }
}

#[cfg(feature = "rng")]
impl<R: Rounds, V: Variant> Backend<R, V> {
    #[inline(always)]
    fn rng_gen_par_ks_blocks(&mut self, blocks: &mut [u32; 64]) {
        unsafe {
            let vs = rounds::<R>(&self.v, &self.ctr);

            let pb = PAR_BLOCKS as i64;
            for c in self.ctr.iter_mut() {
                *c = _mm512_add_epi64(*c, _mm512_set_epi64(0, pb, 0, pb, 0, pb, 0, pb));
            }

            let mut block_ptr = blocks.as_mut_ptr() as *mut __m128i;
            for v in vs {
                let t: [__m128i; 16] = core::mem::transmute(v);
                for i in 0..4 {
                    _mm_storeu_si128(block_ptr.add(i), t[4 * i]);
                    _mm_storeu_si128(block_ptr.add(4 + i), t[4 * i + 1]);
                    _mm_storeu_si128(block_ptr.add(8 + i), t[4 * i + 2]);
                    _mm_storeu_si128(block_ptr.add(12 + i), t[4 * i + 3]);
                }
                block_ptr = block_ptr.add(16);
            }
        }
    }
}

#[inline]
#[target_feature(enable = "avx512f")]
unsafe fn rounds<R: Rounds>(v: &[__m512i; 3], c: &[__m512i; N]) -> [[__m512i; 4]; N] {
    let mut vs: [[__m512i; 4]; N] = [[_mm512_setzero_si512(); 4]; N];
    for i in 0..N {
        vs[i] = [v[0], v[1], v[2], c[i]];
    }
    for _ in 0..R::COUNT {
        double_quarter_round(&mut vs);
    }

    for i in 0..N {
        for j in 0..3 {
            vs[i][j] = _mm512_add_epi32(vs[i][j], v[j]);
        }
        vs[i][3] = _mm512_add_epi32(vs[i][3], c[i]);
    }

    vs
}

#[inline]
#[target_feature(enable = "avx512f")]
unsafe fn double_quarter_round(v: &mut [[__m512i; 4]; N]) {
    add_xor_rot(v);
    rows_to_cols(v);
    add_xor_rot(v);
    cols_to_rows(v);
}

/// The goal of this function is to transform the state words from:
/// ```text
/// [a0, a1, a2, a3]    [ 0,  1,  2,  3]
/// [b0, b1, b2, b3] == [ 4,  5,  6,  7]
/// [c0, c1, c2, c3]    [ 8,  9, 10, 11]
/// [d0, d1, d2, d3]    [12, 13, 14, 15]
/// ```
///
/// to:
/// ```text
/// [a0, a1, a2, a3]    [ 0,  1,  2,  3]
/// [b1, b2, b3, b0] == [ 5,  6,  7,  4]
/// [c2, c3, c0, c1]    [10, 11,  8,  9]
/// [d3, d0, d1, d2]    [15, 12, 13, 14]
/// ```
///
/// so that we can apply [`add_xor_rot`] to the resulting columns, and have it compute the
/// "diagonal rounds" (as defined in RFC 7539) in parallel. In practice, this shuffle is
/// non-optimal: the last state word to be altered in `add_xor_rot` is `b`, so the shuffle
/// blocks on the result of `b` being calculated.
///
/// We can optimize this by observing that the four quarter rounds in `add_xor_rot` are
/// data-independent: they only access a single column of the state, and thus the order of
/// the columns does not matter. We therefore instead shuffle the other three state words,
/// to obtain the following equivalent layout:
/// ```text
/// [a3, a0, a1, a2]    [ 3,  0,  1,  2]
/// [b0, b1, b2, b3] == [ 4,  5,  6,  7]
/// [c1, c2, c3, c0]    [ 9, 10, 11,  8]
/// [d2, d3, d0, d1]    [14, 15, 12, 13]
/// ```
///
/// See <https://github.com/sneves/blake2-avx2/pull/4> for additional details. The earliest
/// known occurrence of this optimization is in floodyberry's SSE4 ChaCha code from 2014:
/// - <https://github.com/floodyberry/chacha-opt/blob/0ab65cb99f5016633b652edebaf3691ceb4ff753/chacha_blocks_ssse3-64.S#L639-L643>
#[inline]
#[target_feature(enable = "avx512f")]
unsafe fn rows_to_cols(vs: &mut [[__m512i; 4]; N]) {
    // c >>>= 32; d >>>= 64; a >>>= 96;
    for [a, _, c, d] in vs {
        *c = _mm512_shuffle_epi32::<0b_00_11_10_01>(*c); // _MM_SHUFFLE(0, 3, 2, 1)
        *d = _mm512_shuffle_epi32::<0b_01_00_11_10>(*d); // _MM_SHUFFLE(1, 0, 3, 2)
        *a = _mm512_shuffle_epi32::<0b_10_01_00_11>(*a); // _MM_SHUFFLE(2, 1, 0, 3)
    }
}

/// The goal of this function is to transform the state words from:
/// ```text
/// [a3, a0, a1, a2]    [ 3,  0,  1,  2]
/// [b0, b1, b2, b3] == [ 4,  5,  6,  7]
/// [c1, c2, c3, c0]    [ 9, 10, 11,  8]
/// [d2, d3, d0, d1]    [14, 15, 12, 13]
/// ```
///
/// to:
/// ```text
/// [a0, a1, a2, a3]    [ 0,  1,  2,  3]
/// [b0, b1, b2, b3] == [ 4,  5,  6,  7]
/// [c0, c1, c2, c3]    [ 8,  9, 10, 11]
/// [d0, d1, d2, d3]    [12, 13, 14, 15]
/// ```
///
/// reversing the transformation of [`rows_to_cols`].
#[inline]
#[target_feature(enable = "avx512f")]
unsafe fn cols_to_rows(vs: &mut [[__m512i; 4]; N]) {
    // c <<<= 32; d <<<= 64; a <<<= 96;
    for [a, _, c, d] in vs {
        *c = _mm512_shuffle_epi32::<0b_10_01_00_11>(*c); // _MM_SHUFFLE(2, 1, 0, 3)
        *d = _mm512_shuffle_epi32::<0b_01_00_11_10>(*d); // _MM_SHUFFLE(1, 0, 3, 2)
        *a = _mm512_shuffle_epi32::<0b_00_11_10_01>(*a); // _MM_SHUFFLE(0, 3, 2, 1)
    }
}

#[inline]
#[target_feature(enable = "avx512f")]
unsafe fn add_xor_rot(vs: &mut [[__m512i; 4]; N]) {
    // a += b; d ^= a; d <<<= (16, 16, 16, 16);
    for [a, b, _, d] in vs.iter_mut() {
        *a = _mm512_add_epi32(*a, *b);
        *d = _mm512_xor_si512(*d, *a);
        *d = _mm512_rol_epi32::<16>(*d);
    }

    // c += d; b ^= c; b <<<= (12, 12, 12, 12);
    for [_, b, c, d] in vs.iter_mut() {
        *c = _mm512_add_epi32(*c, *d);
        *b = _mm512_xor_si512(*b, *c);
        *b = _mm512_rol_epi32::<12>(*b);
    }

    // a += b; d ^= a; d <<<= (8, 8, 8, 8);
    for [a, b, _, d] in vs.iter_mut() {
        *a = _mm512_add_epi32(*a, *b);
        *d = _mm512_xor_si512(*d, *a);
        *d = _mm512_rol_epi32::<8>(*d);
    }

    // c += d; b ^= c; b <<<= (7, 7, 7, 7);
    for [_, b, c, d] in vs.iter_mut() {
        *c = _mm512_add_epi32(*c, *d);
        *b = _mm512_xor_si512(*b, *c);
        *b = _mm512_rol_epi32::<7>(*b);
    }
}
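
For reference, a scalar rendering of the quarter-round sequence that add_xor_rot applies lane-wise (RFC 7539, section 2.1). This helper is illustrative only, not part of the commit:

    // One ChaCha quarter round on four state words. `add_xor_rot` performs
    // exactly this sequence on whole rows, for four blocks at once per
    // `__m512i`, using `_mm512_rol_epi32` for the rotations.
    fn quarter_round(a: &mut u32, b: &mut u32, c: &mut u32, d: &mut u32) {
        *a = a.wrapping_add(*b); *d ^= *a; *d = d.rotate_left(16);
        *c = c.wrapping_add(*d); *b ^= *c; *b = b.rotate_left(12);
        *a = a.wrapping_add(*b); *d ^= *a; *d = d.rotate_left(8);
        *c = c.wrapping_add(*d); *b ^= *c; *b = b.rotate_left(7);
    }

The new backend is only reachable through the high-level cipher traits. A typical invocation, sketched with the crate's public API on the assumption that it is unchanged by this commit (key and nonce values are placeholders):

    use chacha20::ChaCha20;
    use chacha20::cipher::{KeyIvInit, StreamCipher};

    fn demo() {
        let key = [0x42u8; 32]; // placeholder key
        let nonce = [0x24u8; 12]; // placeholder nonce
        let mut cipher = ChaCha20::new(&key.into(), &nonce.into());
        let mut buf = [0u8; 1024]; // zeroed input yields the raw keystream
        cipher.apply_keystream(&mut buf); // backend selected by CPU features
    }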
