Skip to content

Commit ff231fa

Browse files
committed
Add RNG support for avx512 (not benchmarked)
1 parent 5c1ad62 commit ff231fa

File tree

5 files changed: +51 −47 lines changed

chacha20/Cargo.toml

Lines changed: 1 addition & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -20,9 +20,7 @@ rand_core-compatible RNGs based on those ciphers.
2020

2121
[dependencies]
2222
cfg-if = "1"
23-
cipher = { version = "0.5.0-rc.1", optional = true, features = [
24-
"stream-wrapper",
25-
] }
23+
cipher = { version = "0.5.0-rc.1", optional = true, features = ["stream-wrapper"] }
2624
rand_core = { version = "0.10.0-rc.1", optional = true, default-features = false }
2725

2826
# `zeroize` is an explicit dependency because this crate may be used without the `cipher` crate

chacha20/src/backends.rs

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -7,7 +7,9 @@ cfg_if! {
77
pub(crate) mod soft;
88
} else if #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] {
99
cfg_if! {
10-
if #[cfg(chacha20_force_avx2)] {
10+
if #[cfg(chacha20_force_avx512)] {
11+
pub(crate) mod avx512;
12+
} else if #[cfg(chacha20_force_avx2)] {
1113
pub(crate) mod avx2;
1214
} else if #[cfg(chacha20_force_sse2)] {
1315
pub(crate) mod sse2;

chacha20/src/backends/avx512.rs

Lines changed: 22 additions & 37 deletions
Original file line number | Diff line number | Diff line change
@@ -85,36 +85,46 @@ where
8585
}
8686

8787
#[inline]
88-
#[target_feature(enable = "avx512")]
88+
#[target_feature(enable = "avx512f")]
8989
#[cfg(feature = "rng")]
9090
pub(crate) unsafe fn rng_inner<R, V>(core: &mut ChaChaCore<R, V>, buffer: &mut [u32; 64])
9191
where
9292
R: Rounds,
9393
V: Variant,
9494
{
95+
use core::slice;
96+
97+
use crate::rng::BLOCK_WORDS;
98+
9599
let state_ptr = core.state.as_ptr() as *const __m128i;
96100
let v = [
97-
_mm256_broadcastsi128_si256(_mm_loadu_si128(state_ptr.add(0))),
98-
_mm256_broadcastsi128_si256(_mm_loadu_si128(state_ptr.add(1))),
99-
_mm256_broadcastsi128_si256(_mm_loadu_si128(state_ptr.add(2))),
101+
_mm512_broadcast_i32x4(_mm_loadu_si128(state_ptr.add(0))),
102+
_mm512_broadcast_i32x4(_mm_loadu_si128(state_ptr.add(1))),
103+
_mm512_broadcast_i32x4(_mm_loadu_si128(state_ptr.add(2))),
100104
];
101-
let mut c = _mm256_broadcastsi128_si256(_mm_loadu_si128(state_ptr.add(3)));
102-
c = _mm256_add_epi64(c, _mm256_set_epi64x(0, 1, 0, 0));
103-
let mut ctr = [c; N];
104-
for i in 0..N {
105+
let mut c = _mm512_broadcast_i32x4(_mm_loadu_si128(state_ptr.add(3)));
106+
c = _mm512_add_epi64(c, _mm512_set_epi64(0, 3, 0, 2, 0, 1, 0, 0));
107+
let mut ctr = [c; MAX_N];
108+
for i in 0..MAX_N {
105109
ctr[i] = c;
106-
c = _mm256_add_epi64(c, _mm256_set_epi64x(0, 2, 0, 2));
110+
c = _mm512_add_epi64(c, _mm512_set_epi64(0, 4, 0, 4, 0, 4, 0, 4));
107111
}
108112
let mut backend = Backend::<R, V> {
109113
v,
110114
ctr,
111115
_pd: PhantomData,
112116
};
113117

114-
backend.rng_gen_par_ks_blocks(buffer);
118+
let buffer = slice::from_raw_parts_mut(
119+
buffer.as_mut_ptr().cast::<Block>(),
120+
buffer.len() / BLOCK_WORDS as usize,
121+
);
122+
backend.gen_par_ks_blocks_inner::<4, { 4 / BLOCKS_PER_VECTOR }>(buffer.try_into().unwrap());
115123

116-
core.state[12] = _mm256_extract_epi32(backend.ctr[0], 0) as u32;
117-
core.state[13] = _mm256_extract_epi32(backend.ctr[0], 1) as u32;
124+
core.state[12] =
125+
_mm256_extract_epi32::<0>(_mm512_extracti32x8_epi32::<0>(backend.ctr[0])) as u32;
126+
core.state[13] =
127+
_mm256_extract_epi32::<1>(_mm512_extracti32x8_epi32::<0>(backend.ctr[0])) as u32;
118128
}
119129

120130
struct Backend<R: Rounds, V: Variant> {
@@ -233,31 +243,6 @@ impl<R: Rounds, V: Variant> StreamCipherBackend for Backend<R, V> {
233243
}
234244
}
235245

236-
#[cfg(feature = "rng")]
237-
impl<R: Rounds, V: Variant> Backend<R, V> {
238-
#[inline(always)]
239-
fn rng_gen_par_ks_blocks(&mut self, blocks: &mut [u32; 64]) {
240-
unsafe {
241-
let vs = rounds::<R>(&self.v, &self.ctr);
242-
243-
let pb = PAR_BLOCKS as i32;
244-
for c in self.ctr.iter_mut() {
245-
*c = _mm256_add_epi64(*c, _mm256_set_epi64x(0, pb as i64, 0, pb as i64));
246-
}
247-
248-
let mut block_ptr = blocks.as_mut_ptr() as *mut __m128i;
249-
for v in vs {
250-
let t: [__m128i; 8] = core::mem::transmute(v);
251-
for i in 0..4 {
252-
_mm_storeu_si128(block_ptr.add(i), t[2 * i]);
253-
_mm_storeu_si128(block_ptr.add(4 + i), t[2 * i + 1]);
254-
}
255-
block_ptr = block_ptr.add(8);
256-
}
257-
}
258-
}
259-
}
260-
261246
#[inline]
262247
#[target_feature(enable = "avx512f")]
263248
unsafe fn rounds<const N: usize, R: Rounds>(

chacha20/src/lib.rs

Lines changed: 14 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -185,7 +185,12 @@ cfg_if! {
185185
type Tokens = ();
186186
} else if #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] {
187187
cfg_if! {
188-
if #[cfg(chacha20_force_avx2)] {
188+
if #[cfg(chacha20_force_avx512)] {
189+
#[cfg(not(target_feature = "avx512f"))]
190+
compile_error!("You must enable `avx512f` target feature with \
191+
`chacha20_force_avx512` configuration option");
192+
type Tokens = ();
193+
} else if #[cfg(chacha20_force_avx2)] {
189194
#[cfg(not(target_feature = "avx2"))]
190195
compile_error!("You must enable `avx2` target feature with \
191196
`chacha20_force_avx2` configuration option");
@@ -248,7 +253,9 @@ impl<R: Rounds, V: Variant> ChaChaCore<R, V> {
248253
let tokens = ();
249254
} else if #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] {
250255
cfg_if! {
251-
if #[cfg(chacha20_force_avx2)] {
256+
if #[cfg(chacha20_force_avx512)] {
257+
let tokens = ();
258+
} else if #[cfg(chacha20_force_avx2)] {
252259
let tokens = ();
253260
} else if #[cfg(chacha20_force_sse2)] {
254261
let tokens = ();
@@ -299,7 +306,11 @@ impl<R: Rounds, V: Variant> StreamCipherCore for ChaChaCore<R, V> {
299306
f.call(&mut backends::soft::Backend(self));
300307
} else if #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] {
301308
cfg_if! {
302-
if #[cfg(chacha20_force_avx2)] {
309+
if #[cfg(chacha20_force_avx512)] {
310+
unsafe {
311+
backends::avx512::inner::<R, _, V>(&mut self.state, f);
312+
}
313+
} else if #[cfg(chacha20_force_avx2)] {
303314
unsafe {
304315
backends::avx2::inner::<R, _, V>(&mut self.state, f);
305316
}

chacha20/src/rng.rs

Lines changed: 11 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -189,7 +189,11 @@ impl<R: Rounds, V: Variant> ChaChaCore<R, V> {
189189
backends::soft::Backend(self).gen_ks_blocks(buffer);
190190
} else if #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] {
191191
cfg_if! {
192-
if #[cfg(chacha20_force_avx2)] {
192+
if #[cfg(chacha20_force_avx512)] {
193+
unsafe {
194+
backends::avx512::rng_inner::<R, V>(self, buffer);
195+
}
196+
} else if #[cfg(chacha20_force_avx2)] {
193197
unsafe {
194198
backends::avx2::rng_inner::<R, V>(self, buffer);
195199
}
@@ -198,8 +202,12 @@ impl<R: Rounds, V: Variant> ChaChaCore<R, V> {
198202
backends::sse2::rng_inner::<R, V>(self, buffer);
199203
}
200204
} else {
201-
let (avx2_token, sse2_token) = self.tokens;
202-
if avx2_token.get() {
205+
let (avx512_token, avx2_token, sse2_token) = self.tokens;
206+
if avx512_token.get() {
207+
unsafe {
208+
backends::avx512::rng_inner::<R, V>(self, buffer);
209+
}
210+
} else if avx2_token.get() {
203211
unsafe {
204212
backends::avx2::rng_inner::<R, V>(self, buffer);
205213
}

0 commit comments

Comments
 (0)