#![allow(unsafe_op_in_unsafe_fn)]
use crate::{Rounds, Variant};
use core::marker::PhantomData;

#[cfg(feature = "rng")]
use crate::ChaChaCore;

#[cfg(feature = "cipher")]
use crate::{STATE_WORDS, chacha::Block};

#[cfg(feature = "cipher")]
use cipher::{
    BlockSizeUser, ParBlocks, ParBlocksSizeUser, StreamCipherBackend, StreamCipherClosure,
    consts::{U4, U64},
};

#[cfg(target_arch = "x86")]
use core::arch::x86::*;
#[cfg(target_arch = "x86_64")]
use core::arch::x86_64::*;

/// Number of blocks processed in parallel.
const PAR_BLOCKS: usize = 4;
/// Number of `__m512i` to store parallel blocks.
const N: usize = PAR_BLOCKS / 4;
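
// Data layout: each ChaCha block is a 4x4 matrix of u32 words, i.e. four 128-bit rows.
// One `__m512i` packs the same row of four consecutive blocks, one block per 128-bit lane,
// so a group of PAR_BLOCKS blocks needs N = PAR_BLOCKS / 4 registers per row.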

#[inline]
#[target_feature(enable = "avx512f")]
#[cfg(feature = "cipher")]
pub(crate) unsafe fn inner<R, F, V>(state: &mut [u32; STATE_WORDS], f: F)
where
    R: Rounds,
    F: StreamCipherClosure<BlockSize = U64>,
    V: Variant,
{
    let state_ptr = state.as_ptr() as *const __m128i;
    let v = [
        _mm512_broadcast_i32x4(_mm_loadu_si128(state_ptr.add(0))),
        _mm512_broadcast_i32x4(_mm_loadu_si128(state_ptr.add(1))),
        _mm512_broadcast_i32x4(_mm_loadu_si128(state_ptr.add(2))),
    ];
    let mut c = _mm512_broadcast_i32x4(_mm_loadu_si128(state_ptr.add(3)));
    c = match size_of::<V::Counter>() {
        4 => _mm512_add_epi32(
            c,
            _mm512_set_epi32(0, 0, 0, 3, 0, 0, 0, 2, 0, 0, 0, 1, 0, 0, 0, 0),
        ),
        8 => _mm512_add_epi64(c, _mm512_set_epi64(0, 3, 0, 2, 0, 1, 0, 0)),
        _ => unreachable!(),
    };
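    // Lane i of `c` now holds the fourth state row with its block counter advanced by i,
    // using 32-bit or 64-bit counter arithmetic depending on the variant.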
    let mut ctr = [c; N];
    for i in 0..N {
        ctr[i] = c;
        c = match size_of::<V::Counter>() {
            4 => _mm512_add_epi32(
                c,
                _mm512_set_epi32(0, 0, 0, 4, 0, 0, 0, 4, 0, 0, 0, 4, 0, 0, 0, 4),
            ),
            8 => _mm512_add_epi64(c, _mm512_set_epi64(0, 4, 0, 4, 0, 4, 0, 4)),
            _ => unreachable!(),
        };
    }
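    // With N == 1 this loop runs a single time; `c` ends up holding the counter row for the
    // next group of PAR_BLOCKS blocks (unused below, but keeps the loop generic over N).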
    let mut backend = Backend::<R, V> {
        v,
        ctr,
        _pd: PhantomData,
    };

    f.call(&mut backend);

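    // Write the block counter from lane 0 back into the saved state so the serial position
    // stays in sync: word 12 always, and word 13 as well for variants with a 64-bit counter.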
    state[12] = _mm256_extract_epi32::<0>(_mm512_extracti32x8_epi32::<0>(backend.ctr[0])) as u32;
    match size_of::<V::Counter>() {
        4 => {}
        8 => {
            state[13] =
                _mm256_extract_epi32::<1>(_mm512_extracti32x8_epi32::<0>(backend.ctr[0])) as u32
        }
        _ => unreachable!(),
    }
}
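
// Safety note: `inner` and `rng_inner` are unsafe `#[target_feature]` functions; callers
// must confirm at runtime (e.g. via the `cpufeatures` crate) that the CPU supports AVX-512F
// before dispatching to this backend.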

#[inline]
#[target_feature(enable = "avx512f")]
#[cfg(feature = "rng")]
pub(crate) unsafe fn rng_inner<R, V>(core: &mut ChaChaCore<R, V>, buffer: &mut [u32; 64])
where
    R: Rounds,
    V: Variant,
{
    let state_ptr = core.state.as_ptr() as *const __m128i;
    let v = [
        _mm512_broadcast_i32x4(_mm_loadu_si128(state_ptr.add(0))),
        _mm512_broadcast_i32x4(_mm_loadu_si128(state_ptr.add(1))),
        _mm512_broadcast_i32x4(_mm_loadu_si128(state_ptr.add(2))),
    ];
    let mut c = _mm512_broadcast_i32x4(_mm_loadu_si128(state_ptr.add(3)));
    // The RNG uses a 64-bit counter: stagger the four lanes by 0, 1, 2 and 3 blocks.
    c = _mm512_add_epi64(c, _mm512_set_epi64(0, 3, 0, 2, 0, 1, 0, 0));
    let mut ctr = [c; N];
    for i in 0..N {
        ctr[i] = c;
        c = _mm512_add_epi64(c, _mm512_set_epi64(0, 4, 0, 4, 0, 4, 0, 4));
    }
    let mut backend = Backend::<R, V> {
        v,
        ctr,
        _pd: PhantomData,
    };

    backend.rng_gen_par_ks_blocks(buffer);

    core.state[12] =
        _mm256_extract_epi32::<0>(_mm512_extracti32x8_epi32::<0>(backend.ctr[0])) as u32;
    core.state[13] =
        _mm256_extract_epi32::<1>(_mm512_extracti32x8_epi32::<0>(backend.ctr[0])) as u32;
}

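/// SIMD working state for one group of `PAR_BLOCKS` blocks: `v` holds the first three state
/// rows broadcast into all four 128-bit lanes, and `ctr` holds the fourth row (counter and
/// nonce) with each lane's counter offset by its block index.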
struct Backend<R: Rounds, V: Variant> {
    v: [__m512i; 3],
    ctr: [__m512i; N],
    _pd: PhantomData<(R, V)>,
}

#[cfg(feature = "cipher")]
impl<R: Rounds, V: Variant> BlockSizeUser for Backend<R, V> {
    type BlockSize = U64;
}

#[cfg(feature = "cipher")]
impl<R: Rounds, V: Variant> ParBlocksSizeUser for Backend<R, V> {
    type ParBlocksSize = U4;
}

#[cfg(feature = "cipher")]
impl<R: Rounds, V: Variant> StreamCipherBackend for Backend<R, V> {
    #[inline(always)]
    fn gen_ks_block(&mut self, block: &mut Block) {
        unsafe {
            let res = rounds::<R>(&self.v, &self.ctr);
            for c in self.ctr.iter_mut() {
                *c = match size_of::<V::Counter>() {
                    4 => _mm512_add_epi32(
                        *c,
                        _mm512_set_epi32(0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1),
                    ),
                    8 => _mm512_add_epi64(*c, _mm512_set_epi64(0, 1, 0, 1, 0, 1, 0, 1)),
                    _ => unreachable!(),
                };
            }

            let block_ptr = block.as_mut_ptr() as *mut __m128i;

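            // Only one block was requested: store the lowest 128-bit lane (block 0) of each
            // of the four result rows; the remaining lanes are discarded.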
            for i in 0..4 {
                _mm_storeu_si128(block_ptr.add(i), _mm512_extracti32x4_epi32::<0>(res[0][i]));
            }
        }
    }

    #[inline(always)]
    fn gen_par_ks_blocks(&mut self, blocks: &mut ParBlocks<Self>) {
        unsafe {
            let vs = rounds::<R>(&self.v, &self.ctr);

            let pb = PAR_BLOCKS as i32;
            for c in self.ctr.iter_mut() {
                *c = match size_of::<V::Counter>() {
                    4 => _mm512_add_epi32(
                        *c,
                        _mm512_set_epi32(0, 0, 0, pb, 0, 0, 0, pb, 0, 0, 0, pb, 0, 0, 0, pb),
                    ),
                    8 => _mm512_add_epi64(
                        *c,
                        _mm512_set_epi64(0, pb as i64, 0, pb as i64, 0, pb as i64, 0, pb as i64),
                    ),
                    _ => unreachable!(),
                }
            }

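            // De-interleave: one [__m512i; 4] transmutes to sixteen __m128i, where
            // t[4 * i + j] is row i of block j; the stores below gather each block's four
            // rows into a contiguous 64-byte output block.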
            let mut block_ptr = blocks.as_mut_ptr() as *mut __m128i;
            for v in vs {
                let t: [__m128i; 16] = core::mem::transmute(v);
                for i in 0..4 {
                    _mm_storeu_si128(block_ptr.add(i), t[4 * i]);
                    _mm_storeu_si128(block_ptr.add(4 + i), t[4 * i + 1]);
                    _mm_storeu_si128(block_ptr.add(8 + i), t[4 * i + 2]);
                    _mm_storeu_si128(block_ptr.add(12 + i), t[4 * i + 3]);
                }
                block_ptr = block_ptr.add(16);
            }
        }
    }
}

#[cfg(feature = "rng")]
impl<R: Rounds, V: Variant> Backend<R, V> {
    #[inline(always)]
    fn rng_gen_par_ks_blocks(&mut self, blocks: &mut [u32; 64]) {
        unsafe {
            let vs = rounds::<R>(&self.v, &self.ctr);

            // The RNG counter is 64-bit: advance every lane by PAR_BLOCKS blocks.
            let pb = PAR_BLOCKS as i64;
            for c in self.ctr.iter_mut() {
                *c = _mm512_add_epi64(*c, _mm512_set_epi64(0, pb, 0, pb, 0, pb, 0, pb));
            }

            // Same de-interleave as `gen_par_ks_blocks`: t[4 * i + j] is row i of block j.
            let mut block_ptr = blocks.as_mut_ptr() as *mut __m128i;
            for v in vs {
                let t: [__m128i; 16] = core::mem::transmute(v);
                for i in 0..4 {
                    _mm_storeu_si128(block_ptr.add(i), t[4 * i]);
                    _mm_storeu_si128(block_ptr.add(4 + i), t[4 * i + 1]);
                    _mm_storeu_si128(block_ptr.add(8 + i), t[4 * i + 2]);
                    _mm_storeu_si128(block_ptr.add(12 + i), t[4 * i + 3]);
                }
                block_ptr = block_ptr.add(16);
            }
        }
    }
}

#[inline]
#[target_feature(enable = "avx512f")]
unsafe fn rounds<R: Rounds>(v: &[__m512i; 3], c: &[__m512i; N]) -> [[__m512i; 4]; N] {
    let mut vs: [[__m512i; 4]; N] = [[_mm512_setzero_si512(); 4]; N];
    for i in 0..N {
        vs[i] = [v[0], v[1], v[2], c[i]];
    }
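    // Each iteration is one double round: a column round followed by a diagonal round, so
    // `R::COUNT` is half the nominal round count (e.g. 10 for ChaCha20).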
    for _ in 0..R::COUNT {
        double_quarter_round(&mut vs);
    }

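    // Feed-forward per the ChaCha definition: add the original input words (and each lane's
    // counter row) back into the permuted state.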
    for i in 0..N {
        for j in 0..3 {
            vs[i][j] = _mm512_add_epi32(vs[i][j], v[j]);
        }
        vs[i][3] = _mm512_add_epi32(vs[i][3], c[i]);
    }

    vs
}

#[inline]
#[target_feature(enable = "avx512f")]
unsafe fn double_quarter_round(v: &mut [[__m512i; 4]; N]) {
    add_xor_rot(v);
    rows_to_cols(v);
    add_xor_rot(v);
    cols_to_rows(v);
}

/// The goal of this function is to transform the state words from:
/// ```text
/// [a0, a1, a2, a3]    [ 0,  1,  2,  3]
/// [b0, b1, b2, b3] == [ 4,  5,  6,  7]
/// [c0, c1, c2, c3]    [ 8,  9, 10, 11]
/// [d0, d1, d2, d3]    [12, 13, 14, 15]
/// ```
///
/// to:
/// ```text
/// [a0, a1, a2, a3]    [ 0,  1,  2,  3]
/// [b1, b2, b3, b0] == [ 5,  6,  7,  4]
/// [c2, c3, c0, c1]    [10, 11,  8,  9]
/// [d3, d0, d1, d2]    [15, 12, 13, 14]
/// ```
///
/// so that we can apply [`add_xor_rot`] to the resulting columns, and have it compute the
/// "diagonal rounds" (as defined in RFC 7539) in parallel. In practice, this shuffle is
/// non-optimal: the last state word to be altered in `add_xor_rot` is `b`, so the shuffle
/// blocks on the result of `b` being calculated.
///
/// We can optimize this by observing that the four quarter rounds in `add_xor_rot` are
/// data-independent: they only access a single column of the state, and thus the order of
/// the columns does not matter. We therefore instead shuffle the other three state words,
/// to obtain the following equivalent layout:
/// ```text
/// [a3, a0, a1, a2]    [ 3,  0,  1,  2]
/// [b0, b1, b2, b3] == [ 4,  5,  6,  7]
/// [c1, c2, c3, c0]    [ 9, 10, 11,  8]
/// [d2, d3, d0, d1]    [14, 15, 12, 13]
/// ```
///
/// See <https://github.com/sneves/blake2-avx2/pull/4> for additional details. The earliest
/// known occurrence of this optimization is in floodyberry's SSE4 ChaCha code from 2014:
/// - <https://github.com/floodyberry/chacha-opt/blob/0ab65cb99f5016633b652edebaf3691ceb4ff753/chacha_blocks_ssse3-64.S#L639-L643>
#[inline]
#[target_feature(enable = "avx512f")]
unsafe fn rows_to_cols(vs: &mut [[__m512i; 4]; N]) {
    // c >>>= 32; d >>>= 64; a >>>= 96;
    for [a, _, c, d] in vs {
        *c = _mm512_shuffle_epi32::<0b_00_11_10_01>(*c); // _MM_SHUFFLE(0, 3, 2, 1)
        *d = _mm512_shuffle_epi32::<0b_01_00_11_10>(*d); // _MM_SHUFFLE(1, 0, 3, 2)
        *a = _mm512_shuffle_epi32::<0b_10_01_00_11>(*a); // _MM_SHUFFLE(2, 1, 0, 3)
    }
}

/// The goal of this function is to transform the state words from:
/// ```text
/// [a3, a0, a1, a2]    [ 3,  0,  1,  2]
/// [b0, b1, b2, b3] == [ 4,  5,  6,  7]
/// [c1, c2, c3, c0]    [ 9, 10, 11,  8]
/// [d2, d3, d0, d1]    [14, 15, 12, 13]
/// ```
///
/// to:
/// ```text
/// [a0, a1, a2, a3]    [ 0,  1,  2,  3]
/// [b0, b1, b2, b3] == [ 4,  5,  6,  7]
/// [c0, c1, c2, c3]    [ 8,  9, 10, 11]
/// [d0, d1, d2, d3]    [12, 13, 14, 15]
/// ```
///
/// reversing the transformation of [`rows_to_cols`].
#[inline]
#[target_feature(enable = "avx512f")]
unsafe fn cols_to_rows(vs: &mut [[__m512i; 4]; N]) {
    // c <<<= 32; d <<<= 64; a <<<= 96;
    for [a, _, c, d] in vs {
        *c = _mm512_shuffle_epi32::<0b_10_01_00_11>(*c); // _MM_SHUFFLE(2, 1, 0, 3)
        *d = _mm512_shuffle_epi32::<0b_01_00_11_10>(*d); // _MM_SHUFFLE(1, 0, 3, 2)
        *a = _mm512_shuffle_epi32::<0b_00_11_10_01>(*a); // _MM_SHUFFLE(0, 3, 2, 1)
    }
}

#[inline]
#[target_feature(enable = "avx512f")]
unsafe fn add_xor_rot(vs: &mut [[__m512i; 4]; N]) {
    // a += b; d ^= a; d <<<= (16, 16, 16, 16);
    for [a, b, _, d] in vs.iter_mut() {
        *a = _mm512_add_epi32(*a, *b);
        *d = _mm512_xor_si512(*d, *a);
        *d = _mm512_rol_epi32::<16>(*d);
    }

    // c += d; b ^= c; b <<<= (12, 12, 12, 12);
    for [_, b, c, d] in vs.iter_mut() {
        *c = _mm512_add_epi32(*c, *d);
        *b = _mm512_xor_si512(*b, *c);
        *b = _mm512_rol_epi32::<12>(*b);
    }

    // a += b; d ^= a; d <<<= (8, 8, 8, 8);
    for [a, b, _, d] in vs.iter_mut() {
        *a = _mm512_add_epi32(*a, *b);
        *d = _mm512_xor_si512(*d, *a);
        *d = _mm512_rol_epi32::<8>(*d);
    }

    // c += d; b ^= c; b <<<= (7, 7, 7, 7);
    for [_, b, c, d] in vs.iter_mut() {
        *c = _mm512_add_epi32(*c, *d);
        *b = _mm512_xor_si512(*b, *c);
        *b = _mm512_rol_epi32::<7>(*b);
    }
}